From 707e054e502938b72e0b5ee73495f662eac27269 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=B6rg=20Prante?= Date: Sat, 6 Jan 2024 00:06:56 +0100 Subject: [PATCH] initial commit, fork of netty 4.1.104.Final --- .gitignore | 16 + LICENSE.txt | 202 + NOTICE.txt | 30 + build.gradle | 35 + gradle.properties | 3 + gradle/compile/java.gradle | 37 + gradle/documentation/asciidoc.gradle | 19 + gradle/ide/idea.gradle | 8 + gradle/publish/forgejo.gradle | 16 + gradle/publish/ivy.gradle | 27 + gradle/publish/maven.gradle | 52 + gradle/publish/sonatype.gradle | 12 + gradle/quality/checkstyle.gradle | 19 + gradle/quality/checkstyle.xml | 333 + gradle/quality/cyclonedx.gradle | 11 + gradle/quality/pmd.gradle | 17 + .../pmd/category/java/bestpractices.xml | 1650 +++++ .../pmd/category/java/categories.properties | 10 + .../quality/pmd/category/java/codestyle.xml | 2176 ++++++ gradle/quality/pmd/category/java/design.xml | 1657 +++++ .../pmd/category/java/documentation.xml | 144 + .../quality/pmd/category/java/errorprone.xml | 3383 ++++++++++ .../pmd/category/java/multithreading.xml | 393 ++ .../quality/pmd/category/java/performance.xml | 1006 +++ gradle/quality/pmd/category/java/security.xml | 65 + gradle/quality/sonarqube.gradle | 10 + gradle/quality/spotbugs.gradle | 14 + gradle/repositories/maven.gradle | 4 + gradle/test/jmh.gradle | 22 + gradle/test/junit5.gradle | 44 + gradle/wrapper/gradle-wrapper.jar | Bin 0 -> 43462 bytes gradle/wrapper/gradle-wrapper.properties | 7 + gradlew | 249 + gradlew.bat | 92 + netty-buffer/build.gradle | 5 + .../java/io/netty/buffer/AbstractByteBuf.java | 1459 ++++ .../buffer/AbstractByteBufAllocator.java | 280 + .../netty/buffer/AbstractDerivedByteBuf.java | 129 + .../buffer/AbstractPooledDerivedByteBuf.java | 323 + .../AbstractReferenceCountedByteBuf.java | 120 + .../buffer/AbstractUnpooledSlicedByteBuf.java | 477 ++ .../buffer/AbstractUnsafeSwappedByteBuf.java | 171 + .../buffer/AdvancedLeakAwareByteBuf.java | 968 +++ 
.../AdvancedLeakAwareCompositeByteBuf.java | 1055 +++ .../main/java/io/netty/buffer/ByteBuf.java | 2492 +++++++ .../io/netty/buffer/ByteBufAllocator.java | 134 + .../netty/buffer/ByteBufAllocatorMetric.java | 28 + .../ByteBufAllocatorMetricProvider.java | 24 + .../io/netty/buffer/ByteBufConvertible.java | 32 + .../java/io/netty/buffer/ByteBufHolder.java | 63 + .../io/netty/buffer/ByteBufInputStream.java | 330 + .../io/netty/buffer/ByteBufOutputStream.java | 168 + .../io/netty/buffer/ByteBufProcessor.java | 136 + .../java/io/netty/buffer/ByteBufUtil.java | 1952 ++++++ .../io/netty/buffer/CompositeByteBuf.java | 2363 +++++++ .../io/netty/buffer/DefaultByteBufHolder.java | 158 + .../io/netty/buffer/DuplicatedByteBuf.java | 410 ++ .../java/io/netty/buffer/EmptyByteBuf.java | 1062 +++ .../netty/buffer/FixedCompositeByteBuf.java | 688 ++ .../java/io/netty/buffer/HeapByteBufUtil.java | 146 + .../io/netty/buffer/IntPriorityQueue.java | 107 + .../java/io/netty/buffer/LongLongHashMap.java | 129 + .../main/java/io/netty/buffer/PoolArena.java | 799 +++ .../java/io/netty/buffer/PoolArenaMetric.java | 155 + .../main/java/io/netty/buffer/PoolChunk.java | 709 ++ .../java/io/netty/buffer/PoolChunkList.java | 262 + .../io/netty/buffer/PoolChunkListMetric.java | 32 + .../java/io/netty/buffer/PoolChunkMetric.java | 37 + .../java/io/netty/buffer/PoolSubpage.java | 300 + .../io/netty/buffer/PoolSubpageMetric.java | 43 + .../java/io/netty/buffer/PoolThreadCache.java | 504 ++ .../java/io/netty/buffer/PooledByteBuf.java | 269 + .../netty/buffer/PooledByteBufAllocator.java | 763 +++ .../buffer/PooledByteBufAllocatorMetric.java | 124 + .../io/netty/buffer/PooledDirectByteBuf.java | 313 + .../netty/buffer/PooledDuplicatedByteBuf.java | 378 ++ .../io/netty/buffer/PooledHeapByteBuf.java | 254 + .../io/netty/buffer/PooledSlicedByteBuf.java | 441 ++ .../buffer/PooledUnsafeDirectByteBuf.java | 273 + .../netty/buffer/PooledUnsafeHeapByteBuf.java | 166 + .../java/io/netty/buffer/ReadOnlyByteBuf.java 
| 430 ++ .../netty/buffer/ReadOnlyByteBufferBuf.java | 485 ++ .../buffer/ReadOnlyUnsafeDirectByteBuf.java | 124 + .../netty/buffer/SimpleLeakAwareByteBuf.java | 175 + .../SimpleLeakAwareCompositeByteBuf.java | 126 + .../java/io/netty/buffer/SizeClasses.java | 413 ++ .../io/netty/buffer/SizeClassesMetric.java | 87 + .../java/io/netty/buffer/SlicedByteBuf.java | 49 + .../java/io/netty/buffer/SwappedByteBuf.java | 1066 +++ .../main/java/io/netty/buffer/Unpooled.java | 923 +++ .../buffer/UnpooledByteBufAllocator.java | 269 + .../netty/buffer/UnpooledDirectByteBuf.java | 654 ++ .../buffer/UnpooledDuplicatedByteBuf.java | 121 + .../io/netty/buffer/UnpooledHeapByteBuf.java | 556 ++ .../netty/buffer/UnpooledSlicedByteBuf.java | 126 + .../buffer/UnpooledUnsafeDirectByteBuf.java | 315 + .../buffer/UnpooledUnsafeHeapByteBuf.java | 282 + .../UnpooledUnsafeNoCleanerDirectByteBuf.java | 55 + .../io/netty/buffer/UnreleasableByteBuf.java | 133 + .../io/netty/buffer/UnsafeByteBufUtil.java | 691 ++ .../buffer/UnsafeDirectSwappedByteBuf.java | 67 + .../buffer/UnsafeHeapSwappedByteBuf.java | 63 + .../java/io/netty/buffer/WrappedByteBuf.java | 1049 +++ .../netty/buffer/WrappedCompositeByteBuf.java | 1284 ++++ .../WrappedUnpooledUnsafeDirectByteBuf.java | 32 + .../java/io/netty/buffer/package-info.java | 128 + .../AbstractMultiSearchProcessorFactory.java | 94 + .../AbstractSearchProcessorFactory.java | 115 + .../AhoCorasicSearchProcessorFactory.java | 191 + .../search/BitapSearchProcessorFactory.java | 77 + .../search/KmpSearchProcessorFactory.java | 91 + .../buffer/search/MultiSearchProcessor.java | 28 + .../search/MultiSearchProcessorFactory.java | 25 + .../netty/buffer/search/SearchProcessor.java | 30 + .../buffer/search/SearchProcessorFactory.java | 24 + .../io/netty/buffer/search/package-info.java | 20 + netty-buffer/src/main/java/module-info.java | 5 + .../buffer/AbstractByteBufAllocatorTest.java | 145 + .../io/netty/buffer/AbstractByteBufTest.java | 5992 +++++++++++++++++ 
.../buffer/AbstractCompositeByteBufTest.java | 1810 +++++ .../buffer/AbstractPooledByteBufTest.java | 147 + .../AbstractReferenceCountedByteBufTest.java | 358 + .../buffer/AdvancedLeakAwareByteBufTest.java | 53 + ...AdvancedLeakAwareCompositeByteBufTest.java | 31 + .../AlignedPooledByteBufAllocatorTest.java | 130 + .../buffer/BigEndianCompositeByteBufTest.java | 34 + .../buffer/BigEndianDirectByteBufTest.java | 49 + .../buffer/BigEndianHeapByteBufTest.java | 55 + .../BigEndianUnsafeDirectByteBufTest.java | 36 + ...ndianUnsafeNoCleanerDirectByteBufTest.java | 36 + .../io/netty/buffer/ByteBufAllocatorTest.java | 227 + .../netty/buffer/ByteBufDerivationTest.java | 217 + .../io/netty/buffer/ByteBufStreamTest.java | 310 + .../java/io/netty/buffer/ByteBufUtilTest.java | 1086 +++ .../io/netty/buffer/ByteProcessorTest.java | 69 + .../io/netty/buffer/ConsolidationTest.java | 77 + .../buffer/DefaultByteBufHolderTest.java | 108 + .../netty/buffer/DuplicatedByteBufTest.java | 89 + .../io/netty/buffer/EmptyByteBufTest.java | 115 + .../buffer/FixedCompositeByteBufTest.java | 526 ++ .../io/netty/buffer/IntPriorityQueueTest.java | 147 + .../LittleEndianCompositeByteBufTest.java | 26 + .../buffer/LittleEndianDirectByteBufTest.java | 39 + .../buffer/LittleEndianHeapByteBufTest.java | 33 + .../LittleEndianUnsafeDirectByteBufTest.java | 35 + ...ndianUnsafeNoCleanerDirectByteBufTest.java | 36 + .../io/netty/buffer/LongLongHashMapTest.java | 86 + .../netty/buffer/NoopResourceLeakTracker.java | 41 + .../java/io/netty/buffer/PoolArenaTest.java | 172 + ...oledAlignedBigEndianDirectByteBufTest.java | 54 + .../PooledBigEndianDirectByteBufTest.java | 33 + .../PooledBigEndianHeapByteBufTest.java | 27 + .../buffer/PooledByteBufAllocatorTest.java | 963 +++ .../PooledLittleEndianDirectByteBufTest.java | 34 + .../PooledLittleEndianHeapByteBufTest.java | 33 + .../io/netty/buffer/ReadOnlyByteBufTest.java | 300 + .../buffer/ReadOnlyByteBufferBufTest.java | 58 + 
.../ReadOnlyDirectByteBufferBufTest.java | 426 ++ ...ReadOnlyUnsafeDirectByteBufferBufTest.java | 53 + .../buffer/RetainedDuplicatedByteBufTest.java | 32 + .../buffer/RetainedSlicedByteBufTest.java | 30 + .../buffer/SimpleLeakAwareByteBufTest.java | 144 + .../SimpleLeakAwareCompositeByteBufTest.java | 165 + .../io/netty/buffer/SlicedByteBufTest.java | 355 + .../buffer/UnpooledByteBufAllocatorTest.java | 29 + .../java/io/netty/buffer/UnpooledTest.java | 821 +++ .../buffer/UnreleaseableByteBufTest.java | 54 + .../netty/buffer/UnsafeByteBufUtilTest.java | 252 + .../buffer/WrappedCompositeByteBufTest.java | 33 + .../WrappedUnpooledUnsafeByteBufTest.java | 264 + .../BitapSearchProcessorFactoryTest.java | 40 + .../search/MultiSearchProcessorTest.java | 107 + .../buffer/search/SearchProcessorTest.java | 167 + .../src/test/resources/logging.properties | 7 + netty-bzip2/build.gradle | 5 + .../java/io/netty/bzip2/Bzip2BitReader.java | 157 + .../java/io/netty/bzip2/Bzip2BitWriter.java | 120 + .../io/netty/bzip2/Bzip2BlockCompressor.java | 298 + .../netty/bzip2/Bzip2BlockDecompressor.java | 350 + .../java/io/netty/bzip2/Bzip2Constants.java | 106 + .../java/io/netty/bzip2/Bzip2DivSufSort.java | 2117 ++++++ .../io/netty/bzip2/Bzip2HuffmanAllocator.java | 184 + .../netty/bzip2/Bzip2HuffmanStageDecoder.java | 203 + .../netty/bzip2/Bzip2HuffmanStageEncoder.java | 374 + .../bzip2/Bzip2MTFAndRLE2StageEncoder.java | 185 + .../io/netty/bzip2/Bzip2MoveToFrontTable.java | 84 + .../main/java/io/netty/bzip2/Bzip2Rand.java | 77 + .../src/main/java/io/netty/bzip2/Crc32.java | 123 + .../netty/bzip2/DecompressionException.java | 53 + netty-bzip2/src/main/java/module-info.java | 6 + netty-channel-unix/build.gradle | 5 + .../java/io/netty/channel/unix/Buffer.java | 68 + .../channel/unix/DatagramSocketAddress.java | 57 + .../channel/unix/DomainDatagramChannel.java | 39 + .../unix/DomainDatagramChannelConfig.java | 80 + .../channel/unix/DomainDatagramPacket.java | 86 + 
.../unix/DomainDatagramSocketAddress.java | 48 + .../channel/unix/DomainSocketAddress.java | 67 + .../channel/unix/DomainSocketChannel.java | 33 + .../unix/DomainSocketChannelConfig.java | 80 + .../channel/unix/DomainSocketReadMode.java | 34 + .../java/io/netty/channel/unix/Errors.java | 223 + .../ErrorsStaticallyReferencedJniMethods.java | 47 + .../io/netty/channel/unix/FileDescriptor.java | 240 + .../unix/GenericUnixChannelOption.java | 51 + .../unix/IntegerUnixChannelOption.java | 32 + .../java/io/netty/channel/unix/IovArray.java | 237 + .../java/io/netty/channel/unix/Limits.java | 31 + .../LimitsStaticallyReferencedJniMethods.java | 37 + .../netty/channel/unix/NativeInetAddress.java | 111 + .../netty/channel/unix/PeerCredentials.java | 74 + .../unix/PreferredDirectByteBufAllocator.java | 130 + .../channel/unix/RawUnixChannelOption.java | 61 + .../channel/unix/SegmentedDatagramPacket.java | 109 + .../unix/ServerDomainSocketChannel.java | 30 + .../java/io/netty/channel/unix/Socket.java | 720 ++ .../unix/SocketWritableByteChannel.java | 86 + .../main/java/io/netty/channel/unix/Unix.java | 94 + .../io/netty/channel/unix/UnixChannel.java | 28 + .../netty/channel/unix/UnixChannelOption.java | 33 + .../netty/channel/unix/UnixChannelUtil.java | 62 + .../io/netty/channel/unix/package-info.java | 20 + .../src/main/java/module-info.java | 6 + .../channel/unix/UnixChannelUtilTest.java | 88 + .../src/test/resources/logging.properties | 7 + netty-channel/build.gradle | 8 + .../io/netty/bootstrap/AbstractBootstrap.java | 530 ++ .../bootstrap/AbstractBootstrapConfig.java | 135 + .../java/io/netty/bootstrap/Bootstrap.java | 353 + .../io/netty/bootstrap/BootstrapConfig.java | 63 + .../io/netty/bootstrap/ChannelFactory.java | 29 + .../ChannelInitializerExtension.java | 122 + .../ChannelInitializerExtensions.java | 128 + .../io/netty/bootstrap/FailedChannel.java | 107 + .../io/netty/bootstrap/ServerBootstrap.java | 309 + .../bootstrap/ServerBootstrapConfig.java | 105 + 
.../java/io/netty/bootstrap/package-info.java | 21 + .../io/netty/channel/AbstractChannel.java | 1215 ++++ .../AbstractChannelHandlerContext.java | 1324 ++++ .../AbstractCoalescingBufferQueue.java | 399 ++ .../io/netty/channel/AbstractEventLoop.java | 41 + .../netty/channel/AbstractEventLoopGroup.java | 27 + .../netty/channel/AbstractServerChannel.java | 82 + .../channel/AdaptiveRecvByteBufAllocator.java | 205 + .../io/netty/channel/AddressedEnvelope.java | 56 + .../main/java/io/netty/channel/Channel.java | 302 + .../java/io/netty/channel/ChannelConfig.java | 268 + .../netty/channel/ChannelDuplexHandler.java | 129 + .../io/netty/channel/ChannelException.java | 94 + .../java/io/netty/channel/ChannelFactory.java | 28 + .../channel/ChannelFlushPromiseNotifier.java | 273 + .../java/io/netty/channel/ChannelFuture.java | 212 + .../netty/channel/ChannelFutureListener.java | 75 + .../java/io/netty/channel/ChannelHandler.java | 219 + .../netty/channel/ChannelHandlerAdapter.java | 94 + .../netty/channel/ChannelHandlerContext.java | 174 + .../io/netty/channel/ChannelHandlerMask.java | 205 + .../main/java/io/netty/channel/ChannelId.java | 56 + .../netty/channel/ChannelInboundHandler.java | 75 + .../channel/ChannelInboundHandlerAdapter.java | 145 + .../netty/channel/ChannelInboundInvoker.java | 94 + .../io/netty/channel/ChannelInitializer.java | 159 + .../io/netty/channel/ChannelMetadata.java | 72 + .../java/io/netty/channel/ChannelOption.java | 166 + .../netty/channel/ChannelOutboundBuffer.java | 877 +++ .../netty/channel/ChannelOutboundHandler.java | 99 + .../ChannelOutboundHandlerAdapter.java | 127 + .../netty/channel/ChannelOutboundInvoker.java | 271 + .../io/netty/channel/ChannelPipeline.java | 632 ++ .../channel/ChannelPipelineException.java | 52 + .../channel/ChannelProgressiveFuture.java | 49 + .../ChannelProgressiveFutureListener.java | 28 + .../channel/ChannelProgressivePromise.java | 65 + .../java/io/netty/channel/ChannelPromise.java | 68 + 
.../channel/ChannelPromiseAggregator.java | 38 + .../netty/channel/ChannelPromiseNotifier.java | 48 + .../netty/channel/CoalescingBufferQueue.java | 86 + .../channel/CombinedChannelDuplexHandler.java | 616 ++ .../netty/channel/CompleteChannelFuture.java | 110 + .../channel/ConnectTimeoutException.java | 33 + .../channel/DefaultAddressedEnvelope.java | 129 + .../netty/channel/DefaultChannelConfig.java | 442 ++ .../channel/DefaultChannelHandlerContext.java | 34 + .../io/netty/channel/DefaultChannelId.java | 321 + .../netty/channel/DefaultChannelPipeline.java | 1511 +++++ .../DefaultChannelProgressivePromise.java | 179 + .../netty/channel/DefaultChannelPromise.java | 172 + .../io/netty/channel/DefaultEventLoop.java | 63 + .../netty/channel/DefaultEventLoopGroup.java | 75 + .../io/netty/channel/DefaultFileRegion.java | 192 + .../DefaultMaxBytesRecvByteBufAllocator.java | 195 + ...efaultMaxMessagesRecvByteBufAllocator.java | 171 + .../channel/DefaultMessageSizeEstimator.java | 72 + .../netty/channel/DefaultSelectStrategy.java | 32 + .../channel/DefaultSelectStrategyFactory.java | 30 + .../DelegatingChannelPromiseNotifier.java | 226 + .../main/java/io/netty/channel/EventLoop.java | 30 + .../io/netty/channel/EventLoopException.java | 41 + .../java/io/netty/channel/EventLoopGroup.java | 52 + .../channel/EventLoopTaskQueueFactory.java | 35 + .../ExtendedClosedChannelException.java | 33 + .../io/netty/channel/FailedChannelFuture.java | 63 + .../java/io/netty/channel/FileRegion.java | 101 + .../channel/FixedRecvByteBufAllocator.java | 61 + .../channel/MaxBytesRecvByteBufAllocator.java | 65 + .../MaxMessagesRecvByteBufAllocator.java | 35 + .../netty/channel/MessageSizeEstimator.java | 39 + .../channel/MultithreadEventLoopGroup.java | 100 + .../io/netty/channel/PendingBytesTracker.java | 104 + .../io/netty/channel/PendingWriteQueue.java | 330 + .../channel/PreferHeapByteBufAllocator.java | 135 + .../netty/channel/RecvByteBufAllocator.java | 188 + 
.../channel/ReflectiveChannelFactory.java | 55 + .../java/io/netty/channel/SelectStrategy.java | 52 + .../netty/channel/SelectStrategyFactory.java | 27 + .../java/io/netty/channel/ServerChannel.java | 27 + .../ServerChannelRecvByteBufAllocator.java | 35 + .../channel/SimpleChannelInboundHandler.java | 120 + .../SimpleUserEventChannelHandler.java | 120 + .../netty/channel/SingleThreadEventLoop.java | 217 + .../StacklessClosedChannelException.java | 43 + .../netty/channel/SucceededChannelFuture.java | 45 + .../channel/ThreadPerChannelEventLoop.java | 103 + .../ThreadPerChannelEventLoopGroup.java | 321 + .../io/netty/channel/VoidChannelPromise.java | 239 + .../netty/channel/WriteBufferWaterMark.java | 96 + .../channel/embedded/EmbeddedChannel.java | 930 +++ .../channel/embedded/EmbeddedChannelId.java | 65 + .../channel/embedded/EmbeddedEventLoop.java | 201 + .../embedded/EmbeddedSocketAddress.java | 27 + .../netty/channel/embedded/package-info.java | 22 + .../io/netty/channel/group/ChannelGroup.java | 278 + .../channel/group/ChannelGroupException.java | 49 + .../channel/group/ChannelGroupFuture.java | 175 + .../group/ChannelGroupFutureListener.java | 28 + .../netty/channel/group/ChannelMatcher.java | 32 + .../netty/channel/group/ChannelMatchers.java | 169 + .../netty/channel/group/CombinedIterator.java | 72 + .../channel/group/DefaultChannelGroup.java | 463 ++ .../group/DefaultChannelGroupFuture.java | 259 + .../channel/group/VoidChannelGroupFuture.java | 175 + .../io/netty/channel/group/package-info.java | 21 + .../netty/channel/internal/ChannelUtils.java | 24 + .../netty/channel/internal/package-info.java | 20 + .../io/netty/channel/local/LocalAddress.java | 97 + .../io/netty/channel/local/LocalChannel.java | 524 ++ .../channel/local/LocalChannelRegistry.java | 62 + .../channel/local/LocalEventLoopGroup.java | 60 + .../channel/local/LocalServerChannel.java | 181 + .../io/netty/channel/local/package-info.java | 21 + .../channel/nio/AbstractNioByteChannel.java | 352 + 
.../netty/channel/nio/AbstractNioChannel.java | 513 ++ .../nio/AbstractNioMessageChannel.java | 211 + .../io/netty/channel/nio/NioEventLoop.java | 894 +++ .../netty/channel/nio/NioEventLoopGroup.java | 186 + .../java/io/netty/channel/nio/NioTask.java | 41 + .../channel/nio/SelectedSelectionKeySet.java | 109 + .../nio/SelectedSelectionKeySetSelector.java | 80 + .../io/netty/channel/nio/package-info.java | 21 + .../channel/oio/AbstractOioByteChannel.java | 270 + .../netty/channel/oio/AbstractOioChannel.java | 166 + .../oio/AbstractOioMessageChannel.java | 113 + .../channel/oio/OioByteStreamChannel.java | 172 + .../netty/channel/oio/OioEventLoopGroup.java | 87 + .../io/netty/channel/oio/package-info.java | 21 + .../java/io/netty/channel/package-info.java | 22 + .../pool/AbstractChannelPoolHandler.java | 44 + .../channel/pool/AbstractChannelPoolMap.java | 153 + .../channel/pool/ChannelHealthChecker.java | 47 + .../io/netty/channel/pool/ChannelPool.java | 61 + .../channel/pool/ChannelPoolHandler.java | 48 + .../io/netty/channel/pool/ChannelPoolMap.java | 39 + .../netty/channel/pool/FixedChannelPool.java | 533 ++ .../netty/channel/pool/SimpleChannelPool.java | 440 ++ .../io/netty/channel/pool/package-info.java | 20 + .../socket/ChannelInputShutdownEvent.java | 36 + .../ChannelInputShutdownReadComplete.java | 27 + .../socket/ChannelOutputShutdownEvent.java | 33 + .../ChannelOutputShutdownException.java | 38 + .../netty/channel/socket/DatagramChannel.java | 165 + .../channel/socket/DatagramChannelConfig.java | 188 + .../netty/channel/socket/DatagramPacket.java | 88 + .../socket/DefaultDatagramChannelConfig.java | 435 ++ .../DefaultServerSocketChannelConfig.java | 209 + .../socket/DefaultSocketChannelConfig.java | 347 + .../netty/channel/socket/DuplexChannel.java | 81 + .../channel/socket/DuplexChannelConfig.java | 84 + .../socket/InternetProtocolFamily.java | 81 + .../channel/socket/ServerSocketChannel.java | 32 + .../socket/ServerSocketChannelConfig.java | 119 + 
.../netty/channel/socket/SocketChannel.java | 35 + .../channel/socket/SocketChannelConfig.java | 172 + .../channel/socket/nio/NioChannelOption.java | 123 + .../socket/nio/NioDatagramChannel.java | 625 ++ .../socket/nio/NioDatagramChannelConfig.java | 235 + .../socket/nio/NioServerSocketChannel.java | 250 + .../channel/socket/nio/NioSocketChannel.java | 537 ++ .../socket/nio/ProtocolFamilyConverter.java | 47 + .../socket/nio/SelectorProviderUtil.java | 71 + .../channel/socket/nio/package-info.java | 21 + .../oio/DefaultOioDatagramChannelConfig.java | 207 + .../DefaultOioServerSocketChannelConfig.java | 197 + .../oio/DefaultOioSocketChannelConfig.java | 225 + .../socket/oio/OioDatagramChannel.java | 460 ++ .../socket/oio/OioDatagramChannelConfig.java | 101 + .../socket/oio/OioServerSocketChannel.java | 207 + .../oio/OioServerSocketChannelConfig.java | 103 + .../channel/socket/oio/OioSocketChannel.java | 352 + .../socket/oio/OioSocketChannelConfig.java | 118 + .../channel/socket/oio/package-info.java | 21 + .../io/netty/channel/socket/package-info.java | 20 + netty-channel/src/main/java/module-info.java | 17 + .../generated/handlers/reflect-config.json | 121 + .../netty-transport/reflect-config.json | 33 + .../io/netty/bootstrap/BootstrapTest.java | 584 ++ .../netty/bootstrap/ServerBootstrapTest.java | 244 + .../StubChannelInitializerExtension.java | 47 + .../io/netty/channel/AbstractChannelTest.java | 251 + .../AbstractCoalescingBufferQueueTest.java | 91 + .../netty/channel/AbstractEventLoopTest.java | 95 + .../AdaptiveRecvByteBufAllocatorTest.java | 118 + .../io/netty/channel/BaseChannelTest.java | 89 + .../netty/channel/ChannelInitializerTest.java | 407 ++ .../io/netty/channel/ChannelOptionTest.java | 63 + .../channel/ChannelOutboundBufferTest.java | 539 ++ .../channel/CoalescingBufferQueueTest.java | 318 + .../CombinedChannelDuplexHandlerTest.java | 481 ++ .../channel/CompleteChannelFutureTest.java | 92 + .../netty/channel/DefaultChannelIdTest.java | 87 + 
.../DefaultChannelPipelineTailTest.java | 408 ++ .../channel/DefaultChannelPipelineTest.java | 2293 +++++++ .../channel/DefaultChannelPromiseTest.java | 56 + .../netty/channel/DefaultFileRegionTest.java | 120 + ...ltMaxMessagesRecvByteBufAllocatorTest.java | 76 + .../DelegatingChannelPromiseNotifierTest.java | 37 + .../channel/FailedChannelFutureTest.java | 46 + .../io/netty/channel/LoggingTestHandler.java | 171 + .../NativeImageHandlerMetadataTest.java | 28 + .../netty/channel/PendingWriteQueueTest.java | 415 ++ .../netty/channel/ReentrantChannelTest.java | 288 + .../SimpleUserEventChannelHandlerTest.java | 103 + .../channel/SingleThreadEventLoopTest.java | 584 ++ .../channel/SucceededChannelFutureTest.java | 33 + .../ThreadPerChannelEventLoopGroupTest.java | 117 + .../channel/embedded/CustomChannelId.java | 65 + .../embedded/EmbeddedChannelIdTest.java | 59 + .../channel/embedded/EmbeddedChannelTest.java | 796 +++ .../group/DefaultChannelGroupTest.java | 60 + .../netty/channel/local/LocalChannelTest.java | 1307 ++++ .../local/LocalTransportThreadModelTest.java | 610 ++ .../local/LocalTransportThreadModelTest2.java | 125 + .../local/LocalTransportThreadModelTest3.java | 336 + .../netty/channel/nio/NioEventLoopTest.java | 348 + .../nio/SelectedSelectionKeySetTest.java | 117 + .../netty/channel/oio/OioEventLoopTest.java | 117 + .../pool/AbstractChannelPoolMapTest.java | 173 + .../channel/pool/ChannelPoolTestUtils.java | 29 + .../pool/CountingChannelPoolHandler.java | 53 + .../pool/FixedChannelPoolMapDeadlockTest.java | 264 + .../channel/pool/FixedChannelPoolTest.java | 459 ++ .../channel/pool/SimpleChannelPoolTest.java | 401 ++ .../socket/InternetProtocolFamilyTest.java | 36 + .../socket/nio/AbstractNioChannelTest.java | 82 + .../socket/nio/NioDatagramChannelTest.java | 82 + .../nio/NioServerSocketChannelTest.java | 83 + .../socket/nio/NioSocketChannelTest.java | 305 + .../ChannelHandlerMetadataUtil.java | 246 + ...etty.bootstrap.ChannelInitializerExtension | 1 + 
.../src/test/resources/logging.properties | 7 + netty-handler-codec-compression/build.gradle | 17 + .../handler/codec/compression/Brotli.java | 83 + .../codec/compression/BrotliDecoder.java | 173 + .../codec/compression/BrotliEncoder.java | 278 + .../codec/compression/BrotliOptions.java | 47 + .../codec/compression/ByteBufChecksum.java | 142 + .../codec/compression/Bzip2Decoder.java | 348 + .../codec/compression/Bzip2Encoder.java | 243 + .../compression/CompressionException.java | 53 + .../codec/compression/CompressionOptions.java | 27 + .../codec/compression/CompressionUtil.java | 47 + .../handler/codec/compression/Crc32c.java | 125 + .../compression/DecompressionException.java | 53 + .../codec/compression/DeflateOptions.java | 57 + .../codec/compression/EncoderUtil.java | 57 + .../handler/codec/compression/FastLz.java | 560 ++ .../codec/compression/FastLzFrameDecoder.java | 208 + .../codec/compression/FastLzFrameEncoder.java | 172 + .../codec/compression/GzipOptions.java | 38 + .../codec/compression/JZlibDecoder.java | 204 + .../codec/compression/JZlibEncoder.java | 404 ++ .../codec/compression/JdkZlibDecoder.java | 511 ++ .../codec/compression/JdkZlibEncoder.java | 375 ++ .../codec/compression/Lz4Constants.java | 73 + .../codec/compression/Lz4FrameDecoder.java | 277 + .../codec/compression/Lz4FrameEncoder.java | 402 ++ .../codec/compression/Lz4XXHash32.java | 107 + .../handler/codec/compression/LzfDecoder.java | 242 + .../handler/codec/compression/LzfEncoder.java | 232 + .../handler/codec/compression/Snappy.java | 677 ++ .../codec/compression/SnappyFrameDecoder.java | 259 + .../codec/compression/SnappyFrameEncoder.java | 123 + .../compression/SnappyFramedDecoder.java | 25 + .../compression/SnappyFramedEncoder.java | 25 + .../codec/compression/SnappyOptions.java | 24 + .../StandardCompressionOptions.java | 136 + .../codec/compression/ZlibCodecFactory.java | 139 + .../codec/compression/ZlibDecoder.java | 95 + .../codec/compression/ZlibEncoder.java | 53 + 
.../handler/codec/compression/ZlibUtil.java | 85 + .../codec/compression/ZlibWrapper.java | 40 + .../netty/handler/codec/compression/Zstd.java | 75 + .../codec/compression/ZstdConstants.java | 40 + .../codec/compression/ZstdEncoder.java | 185 + .../codec/compression/ZstdOptions.java | 73 + .../codec/compression/package-info.java | 23 + .../src/main/java/module-info.java | 13 + .../compression/AbstractCompressionTest.java | 38 + .../compression/AbstractDecoderTest.java | 152 + .../compression/AbstractEncoderTest.java | 129 + .../compression/AbstractIntegrationTest.java | 185 + .../codec/compression/BrotliDecoderTest.java | 162 + .../codec/compression/BrotliEncoderTest.java | 83 + .../compression/ByteBufChecksumTest.java | 90 + .../codec/compression/Bzip2DecoderTest.java | 200 + .../codec/compression/Bzip2EncoderTest.java | 63 + .../compression/Bzip2IntegrationTest.java | 56 + .../compression/FastLzIntegrationTest.java | 117 + .../handler/codec/compression/JZlibTest.java | 29 + .../codec/compression/JdkZlibTest.java | 213 + .../LengthAwareLzfIntegrationTest.java | 28 + .../compression/Lz4FrameDecoderTest.java | 160 + .../compression/Lz4FrameEncoderTest.java | 323 + .../compression/Lz4FrameIntegrationTest.java | 31 + .../codec/compression/LzfDecoderTest.java | 73 + .../codec/compression/LzfEncoderTest.java | 39 + .../codec/compression/LzfIntegrationTest.java | 31 + .../compression/SnappyFrameDecoderTest.java | 225 + .../compression/SnappyFrameEncoderTest.java | 156 + .../compression/SnappyIntegrationTest.java | 117 + .../handler/codec/compression/SnappyTest.java | 359 + .../codec/compression/ZlibCrossTest1.java | 29 + .../codec/compression/ZlibCrossTest2.java | 45 + .../handler/codec/compression/ZlibTest.java | 479 ++ .../codec/compression/ZstdEncoderTest.java | 108 + .../src/test/resources/logging.properties | 7 + .../src/test/resources/multiple.gz | Bin 0 -> 46 bytes netty-handler-codec-http/build.gradle | 14 + .../codec/http/ClientCookieEncoder.java | 87 + 
.../codec/http/CombinedHttpHeaders.java | 329 + .../codec/http/ComposedLastHttpContent.java | 119 + .../codec/http/CompressionEncoderFactory.java | 27 + .../io/netty/handler/codec/http/Cookie.java | 221 + .../handler/codec/http/CookieDecoder.java | 369 + .../netty/handler/codec/http/CookieUtil.java | 104 + .../handler/codec/http/DefaultCookie.java | 195 + .../codec/http/DefaultFullHttpRequest.java | 228 + .../codec/http/DefaultFullHttpResponse.java | 255 + .../codec/http/DefaultHttpContent.java | 105 + .../codec/http/DefaultHttpHeaders.java | 446 ++ .../codec/http/DefaultHttpHeadersFactory.java | 313 + .../codec/http/DefaultHttpMessage.java | 107 + .../handler/codec/http/DefaultHttpObject.java | 63 + .../codec/http/DefaultHttpRequest.java | 149 + .../codec/http/DefaultHttpResponse.java | 145 + .../codec/http/DefaultLastHttpContent.java | 151 + .../handler/codec/http/EmptyHttpHeaders.java | 188 + .../handler/codec/http/FullHttpMessage.java | 48 + .../handler/codec/http/FullHttpRequest.java | 57 + .../handler/codec/http/FullHttpResponse.java | 54 + .../handler/codec/http/HttpChunkedInput.java | 119 + .../handler/codec/http/HttpClientCodec.java | 422 ++ .../codec/http/HttpClientUpgradeHandler.java | 277 + .../handler/codec/http/HttpConstants.java | 82 + .../netty/handler/codec/http/HttpContent.java | 54 + .../codec/http/HttpContentCompressor.java | 475 ++ .../codec/http/HttpContentDecoder.java | 288 + .../codec/http/HttpContentDecompressor.java | 86 + .../codec/http/HttpContentEncoder.java | 383 ++ .../handler/codec/http/HttpDecoderConfig.java | 225 + .../http/HttpExpectationFailedEvent.java | 25 + .../codec/http/HttpHeaderDateFormat.java | 104 + .../handler/codec/http/HttpHeaderNames.java | 386 ++ .../codec/http/HttpHeaderValidationUtil.java | 300 + .../handler/codec/http/HttpHeaderValues.java | 255 + .../netty/handler/codec/http/HttpHeaders.java | 1705 +++++ .../codec/http/HttpHeadersEncoder.java | 57 + .../codec/http/HttpHeadersFactory.java | 34 + 
.../netty/handler/codec/http/HttpMessage.java | 49 + .../codec/http/HttpMessageDecoderResult.java | 58 + .../handler/codec/http/HttpMessageUtil.java | 113 + .../netty/handler/codec/http/HttpMethod.java | 229 + .../netty/handler/codec/http/HttpObject.java | 27 + .../codec/http/HttpObjectAggregator.java | 574 ++ .../handler/codec/http/HttpObjectDecoder.java | 1233 ++++ .../handler/codec/http/HttpObjectEncoder.java | 597 ++ .../netty/handler/codec/http/HttpRequest.java | 78 + .../codec/http/HttpRequestDecoder.java | 359 + .../codec/http/HttpRequestEncoder.java | 80 + .../handler/codec/http/HttpResponse.java | 57 + .../codec/http/HttpResponseDecoder.java | 208 + .../codec/http/HttpResponseEncoder.java | 98 + .../codec/http/HttpResponseStatus.java | 649 ++ .../netty/handler/codec/http/HttpScheme.java | 70 + .../handler/codec/http/HttpServerCodec.java | 201 + .../http/HttpServerExpectContinueHandler.java | 97 + .../http/HttpServerKeepAliveHandler.java | 128 + .../codec/http/HttpServerUpgradeHandler.java | 453 ++ .../handler/codec/http/HttpStatusClass.java | 126 + .../io/netty/handler/codec/http/HttpUtil.java | 632 ++ .../netty/handler/codec/http/HttpVersion.java | 263 + .../handler/codec/http/LastHttpContent.java | 144 + .../codec/http/QueryStringDecoder.java | 393 ++ .../codec/http/QueryStringEncoder.java | 250 + .../codec/http/ReadOnlyHttpHeaders.java | 459 ++ .../codec/http/ServerCookieEncoder.java | 103 + .../http/TooLongHttpContentException.java | 54 + .../http/TooLongHttpHeaderException.java | 54 + .../codec/http/TooLongHttpLineException.java | 54 + .../http/cookie/ClientCookieDecoder.java | 263 + .../http/cookie/ClientCookieEncoder.java | 225 + .../handler/codec/http/cookie/Cookie.java | 146 + .../codec/http/cookie/CookieDecoder.java | 84 + .../codec/http/cookie/CookieEncoder.java | 52 + .../codec/http/cookie/CookieHeaderNames.java | 63 + .../handler/codec/http/cookie/CookieUtil.java | 183 + .../codec/http/cookie/DefaultCookie.java | 261 + 
.../http/cookie/ServerCookieDecoder.java | 175 + .../http/cookie/ServerCookieEncoder.java | 232 + .../codec/http/cookie/package-info.java | 20 + .../handler/codec/http/cors/CorsConfig.java | 455 ++ .../codec/http/cors/CorsConfigBuilder.java | 420 ++ .../handler/codec/http/cors/CorsHandler.java | 270 + .../handler/codec/http/cors/package-info.java | 20 + .../http/multipart/AbstractDiskHttpData.java | 484 ++ .../http/multipart/AbstractHttpData.java | 144 + .../multipart/AbstractMemoryHttpData.java | 303 + .../http/multipart/AbstractMixedHttpData.java | 279 + .../codec/http/multipart/Attribute.java | 59 + .../multipart/CaseIgnoringComparator.java | 56 + .../multipart/DefaultHttpDataFactory.java | 348 + .../http/multipart/DeleteFileOnExitHook.java | 82 + .../codec/http/multipart/DiskAttribute.java | 272 + .../codec/http/multipart/DiskFileUpload.java | 240 + .../codec/http/multipart/FileUpload.java | 84 + .../codec/http/multipart/FileUploadUtil.java | 33 + .../codec/http/multipart/HttpData.java | 243 + .../codec/http/multipart/HttpDataFactory.java | 93 + .../http/multipart/HttpPostBodyUtil.java | 269 + .../HttpPostMultipartRequestDecoder.java | 1394 ++++ .../multipart/HttpPostRequestDecoder.java | 341 + .../multipart/HttpPostRequestEncoder.java | 1347 ++++ .../HttpPostStandardRequestDecoder.java | 784 +++ .../http/multipart/InterfaceHttpData.java | 50 + .../InterfaceHttpPostRequestDecoder.java | 148 + .../http/multipart/InternalAttribute.java | 155 + .../codec/http/multipart/MemoryAttribute.java | 197 + .../http/multipart/MemoryFileUpload.java | 188 + .../codec/http/multipart/MixedAttribute.java | 157 + .../codec/http/multipart/MixedFileUpload.java | 131 + .../codec/http/multipart/package-info.java | 20 + .../handler/codec/http/package-info.java | 20 + .../http/websocketx/BinaryWebSocketFrame.java | 100 + .../http/websocketx/CloseWebSocketFrame.java | 206 + .../ContinuationWebSocketFrame.java | 137 + .../CorruptedWebSocketFrameException.java | 64 + 
.../http/websocketx/PingWebSocketFrame.java | 100 + .../http/websocketx/PongWebSocketFrame.java | 100 + .../http/websocketx/TextWebSocketFrame.java | 140 + .../http/websocketx/Utf8FrameValidator.java | 120 + .../codec/http/websocketx/Utf8Validator.java | 110 + .../websocketx/WebSocket00FrameDecoder.java | 148 + .../websocketx/WebSocket00FrameEncoder.java | 99 + .../websocketx/WebSocket07FrameDecoder.java | 115 + .../websocketx/WebSocket07FrameEncoder.java | 73 + .../websocketx/WebSocket08FrameDecoder.java | 488 ++ .../websocketx/WebSocket08FrameEncoder.java | 232 + .../websocketx/WebSocket13FrameDecoder.java | 111 + .../websocketx/WebSocket13FrameEncoder.java | 73 + .../websocketx/WebSocketChunkedInput.java | 114 + .../WebSocketClientHandshakeException.java | 55 + .../websocketx/WebSocketClientHandshaker.java | 784 +++ .../WebSocketClientHandshaker00.java | 340 + .../WebSocketClientHandshaker07.java | 346 + .../WebSocketClientHandshaker08.java | 348 + .../WebSocketClientHandshaker13.java | 360 + .../WebSocketClientHandshakerFactory.java | 290 + .../WebSocketClientProtocolConfig.java | 438 ++ .../WebSocketClientProtocolHandler.java | 393 ++ ...bSocketClientProtocolHandshakeHandler.java | 144 + .../http/websocketx/WebSocketCloseStatus.java | 330 + .../websocketx/WebSocketDecoderConfig.java | 165 + .../codec/http/websocketx/WebSocketFrame.java | 109 + .../websocketx/WebSocketFrameAggregator.java | 99 + .../websocketx/WebSocketFrameDecoder.java | 27 + .../websocketx/WebSocketFrameEncoder.java | 27 + .../WebSocketHandshakeException.java | 32 + .../websocketx/WebSocketProtocolHandler.java | 194 + .../http/websocketx/WebSocketScheme.java | 69 + .../WebSocketServerHandshakeException.java | 55 + .../websocketx/WebSocketServerHandshaker.java | 514 ++ .../WebSocketServerHandshaker00.java | 245 + .../WebSocketServerHandshaker07.java | 186 + .../WebSocketServerHandshaker08.java | 192 + .../WebSocketServerHandshaker13.java | 202 + .../WebSocketServerHandshakerFactory.java | 180 
+ .../WebSocketServerProtocolConfig.java | 296 + .../WebSocketServerProtocolHandler.java | 276 + ...bSocketServerProtocolHandshakeHandler.java | 179 + .../codec/http/websocketx/WebSocketUtil.java | 173 + .../http/websocketx/WebSocketVersion.java | 77 + .../extensions/WebSocketClientExtension.java | 23 + .../WebSocketClientExtensionHandler.java | 128 + .../WebSocketClientExtensionHandshaker.java | 41 + .../extensions/WebSocketExtension.java | 42 + .../extensions/WebSocketExtensionData.java | 52 + .../extensions/WebSocketExtensionDecoder.java | 26 + .../extensions/WebSocketExtensionEncoder.java | 26 + .../extensions/WebSocketExtensionFilter.java | 54 + .../WebSocketExtensionFilterProvider.java | 45 + .../extensions/WebSocketExtensionUtil.java | 127 + .../extensions/WebSocketServerExtension.java | 32 + .../WebSocketServerExtensionHandler.java | 263 + .../WebSocketServerExtensionHandshaker.java | 33 + .../compression/DeflateDecoder.java | 146 + .../compression/DeflateEncoder.java | 165 + ...DeflateFrameClientExtensionHandshaker.java | 124 + ...DeflateFrameServerExtensionHandshaker.java | 125 + .../compression/PerFrameDeflateDecoder.java | 75 + .../compression/PerFrameDeflateEncoder.java | 81 + ...ssageDeflateClientExtensionHandshaker.java | 231 + .../compression/PerMessageDeflateDecoder.java | 96 + .../compression/PerMessageDeflateEncoder.java | 101 + ...ssageDeflateServerExtensionHandshaker.java | 232 + .../WebSocketClientCompressionHandler.java | 38 + .../WebSocketServerCompressionHandler.java | 36 + .../extensions/compression/package-info.java | 33 + .../websocketx/extensions/package-info.java | 23 + .../codec/http/websocketx/package-info.java | 39 + .../netty/handler/codec/rtsp/RtspDecoder.java | 181 + .../netty/handler/codec/rtsp/RtspEncoder.java | 68 + .../handler/codec/rtsp/RtspHeaderNames.java | 207 + .../handler/codec/rtsp/RtspHeaderValues.java | 196 + .../netty/handler/codec/rtsp/RtspHeaders.java | 398 ++ .../netty/handler/codec/rtsp/RtspMethods.java | 133 + 
.../handler/codec/rtsp/RtspObjectDecoder.java | 92 + .../handler/codec/rtsp/RtspObjectEncoder.java | 44 + .../codec/rtsp/RtspRequestDecoder.java | 23 + .../codec/rtsp/RtspRequestEncoder.java | 23 + .../codec/rtsp/RtspResponseDecoder.java | 23 + .../codec/rtsp/RtspResponseEncoder.java | 23 + .../codec/rtsp/RtspResponseStatuses.java | 292 + .../handler/codec/rtsp/RtspVersions.java | 50 + .../handler/codec/rtsp/package-info.java | 21 + .../codec/spdy/DefaultSpdyDataFrame.java | 157 + .../codec/spdy/DefaultSpdyGoAwayFrame.java | 95 + .../codec/spdy/DefaultSpdyHeaders.java | 84 + .../codec/spdy/DefaultSpdyHeadersFrame.java | 120 + .../codec/spdy/DefaultSpdyPingFrame.java | 56 + .../codec/spdy/DefaultSpdyRstStreamFrame.java | 84 + .../codec/spdy/DefaultSpdySettingsFrame.java | 184 + .../codec/spdy/DefaultSpdyStreamFrame.java | 59 + .../codec/spdy/DefaultSpdySynReplyFrame.java | 81 + .../codec/spdy/DefaultSpdySynStreamFrame.java | 142 + .../spdy/DefaultSpdyWindowUpdateFrame.java | 78 + .../handler/codec/spdy/SpdyCodecUtil.java | 328 + .../handler/codec/spdy/SpdyDataFrame.java | 65 + .../netty/handler/codec/spdy/SpdyFrame.java | 23 + .../handler/codec/spdy/SpdyFrameCodec.java | 410 ++ .../handler/codec/spdy/SpdyFrameDecoder.java | 457 ++ .../codec/spdy/SpdyFrameDecoderDelegate.java | 99 + .../handler/codec/spdy/SpdyFrameEncoder.java | 161 + .../handler/codec/spdy/SpdyGoAwayFrame.java | 43 + .../codec/spdy/SpdyHeaderBlockDecoder.java | 50 + .../codec/spdy/SpdyHeaderBlockEncoder.java | 45 + .../spdy/SpdyHeaderBlockJZlibEncoder.java | 141 + .../codec/spdy/SpdyHeaderBlockRawDecoder.java | 306 + .../codec/spdy/SpdyHeaderBlockRawEncoder.java | 89 + .../spdy/SpdyHeaderBlockZlibDecoder.java | 125 + .../spdy/SpdyHeaderBlockZlibEncoder.java | 122 + .../netty/handler/codec/spdy/SpdyHeaders.java | 92 + .../handler/codec/spdy/SpdyHeadersFrame.java | 55 + .../handler/codec/spdy/SpdyHttpCodec.java | 51 + .../handler/codec/spdy/SpdyHttpDecoder.java | 463 ++ 
.../handler/codec/spdy/SpdyHttpEncoder.java | 331 + .../handler/codec/spdy/SpdyHttpHeaders.java | 51 + .../spdy/SpdyHttpResponseStreamIdHandler.java | 68 + .../handler/codec/spdy/SpdyPingFrame.java | 32 + .../codec/spdy/SpdyProtocolException.java | 87 + .../codec/spdy/SpdyRstStreamFrame.java | 38 + .../netty/handler/codec/spdy/SpdySession.java | 357 + .../codec/spdy/SpdySessionHandler.java | 854 +++ .../handler/codec/spdy/SpdySessionStatus.java | 111 + .../handler/codec/spdy/SpdySettingsFrame.java | 107 + .../handler/codec/spdy/SpdyStreamFrame.java | 43 + .../handler/codec/spdy/SpdyStreamStatus.java | 185 + .../handler/codec/spdy/SpdySynReplyFrame.java | 31 + .../codec/spdy/SpdySynStreamFrame.java | 65 + .../netty/handler/codec/spdy/SpdyVersion.java | 36 + .../codec/spdy/SpdyWindowUpdateFrame.java | 43 + .../handler/codec/spdy/package-info.java | 19 + .../src/main/java/module-info.java | 18 + .../codec/http/CombinedHttpHeadersTest.java | 387 ++ .../codec/http/DefaultHttpHeadersTest.java | 344 + .../codec/http/DefaultHttpRequestTest.java | 50 + .../codec/http/DefaultHttpResponseTest.java | 40 + .../EmptyHttpHeadersInitializationTest.java | 43 + .../codec/http/HttpChunkedInputTest.java | 166 + .../codec/http/HttpClientCodecTest.java | 440 ++ .../http/HttpClientUpgradeHandlerTest.java | 199 + .../HttpContentCompressorOptionsTest.java | 157 + .../codec/http/HttpContentCompressorTest.java | 1105 +++ .../codec/http/HttpContentDecoderTest.java | 874 +++ .../http/HttpContentDecompressorTest.java | 73 + .../codec/http/HttpContentEncoderTest.java | 465 ++ .../codec/http/HttpHeaderDateFormatTest.java | 68 + .../http/HttpHeaderValidationUtilTest.java | 584 ++ .../handler/codec/http/HttpHeadersTest.java | 106 + .../codec/http/HttpHeadersTestUtils.java | 139 + .../codec/http/HttpInvalidMessageTest.java | 122 + .../codec/http/HttpObjectAggregatorTest.java | 747 ++ .../codec/http/HttpRequestDecoderTest.java | 654 ++ .../codec/http/HttpRequestEncoderTest.java | 432 ++ 
.../codec/http/HttpResponseDecoderTest.java | 1125 ++++ .../codec/http/HttpResponseEncoderTest.java | 404 ++ .../codec/http/HttpResponseStatusTest.java | 147 + .../codec/http/HttpServerCodecTest.java | 185 + .../HttpServerExpectContinueHandlerTest.java | 85 + .../http/HttpServerKeepAliveHandlerTest.java | 235 + .../http/HttpServerUpgradeHandlerTest.java | 232 + .../handler/codec/http/HttpUtilTest.java | 449 ++ .../MultipleContentLengthHeadersTest.java | 130 + .../codec/http/QueryStringDecoderTest.java | 384 ++ .../codec/http/QueryStringEncoderTest.java | 81 + .../codec/http/ReadOnlyHttpHeadersTest.java | 166 + .../http/cookie/ClientCookieDecoderTest.java | 292 + .../http/cookie/ClientCookieEncoderTest.java | 73 + .../http/cookie/ServerCookieDecoderTest.java | 220 + .../http/cookie/ServerCookieEncoderTest.java | 160 + .../codec/http/cors/CorsConfigTest.java | 146 + .../codec/http/cors/CorsHandlerTest.java | 591 ++ .../multipart/AbstractDiskHttpDataTest.java | 128 + .../multipart/AbstractMemoryHttpDataTest.java | 212 + .../multipart/DefaultHttpDataFactoryTest.java | 164 + .../multipart/DeleteFileOnExitHookTest.java | 87 + .../http/multipart/DiskFileUploadTest.java | 297 + .../codec/http/multipart/HttpDataTest.java | 150 + .../HttpPostMultiPartRequestDecoderTest.java | 513 ++ .../multipart/HttpPostRequestDecoderTest.java | 1043 +++ .../multipart/HttpPostRequestEncoderTest.java | 471 ++ .../HttpPostStandardRequestDecoderTest.java | 90 + .../http/multipart/MemoryFileUploadTest.java | 30 + .../codec/http/multipart/MixedTest.java | 76 + .../websocketx/CloseWebSocketFrameTest.java | 104 + .../WebSocket00FrameEncoderTest.java | 47 + .../WebSocket08EncoderDecoderTest.java | 222 + .../WebSocket08FrameDecoderTest.java | 97 + .../WebSocketClientHandshaker00Test.java | 51 + .../WebSocketClientHandshaker07Test.java | 73 + .../WebSocketClientHandshaker08Test.java | 30 + .../WebSocketClientHandshaker13Test.java | 38 + .../WebSocketClientHandshakerTest.java | 517 ++ 
.../websocketx/WebSocketCloseStatusTest.java | 154 + .../WebSocketFrameAggregatorTest.java | 154 + .../WebSocketHandshakeExceptionTest.java | 76 + .../WebSocketHandshakeHandOverTest.java | 371 + .../WebSocketProtocolHandlerTest.java | 188 + .../websocketx/WebSocketRequestBuilder.java | 165 + .../WebSocketServerHandshaker00Test.java | 131 + .../WebSocketServerHandshaker07Test.java | 30 + .../WebSocketServerHandshaker08Test.java | 97 + .../WebSocketServerHandshaker13Test.java | 226 + .../WebSocketServerHandshakerFactoryTest.java | 55 + .../WebSocketServerHandshakerTest.java | 180 + .../WebSocketServerProtocolHandlerTest.java | 519 ++ .../WebSocketUtf8FrameValidatorTest.java | 79 + .../http/websocketx/WebSocketUtilTest.java | 74 + .../WebSocketClientExtensionHandlerTest.java | 277 + .../WebSocketExtensionFilterProviderTest.java | 33 + .../WebSocketExtensionFilterTest.java | 88 + .../WebSocketExtensionTestUtil.java | 121 + .../WebSocketExtensionUtilTest.java | 85 + .../WebSocketServerExtensionHandlerTest.java | 287 + ...ateFrameClientExtensionHandshakerTest.java | 88 + ...ateFrameServerExtensionHandshakerTest.java | 88 + .../PerFrameDeflateDecoderTest.java | 155 + .../PerFrameDeflateEncoderTest.java | 189 + ...eDeflateClientExtensionHandshakerTest.java | 248 + .../PerMessageDeflateDecoderTest.java | 400 ++ .../PerMessageDeflateEncoderTest.java | 324 + ...eDeflateServerExtensionHandshakerTest.java | 176 + ...WebSocketServerCompressionHandlerTest.java | 201 + .../handler/codec/rtsp/RtspDecoderTest.java | 72 + .../handler/codec/rtsp/RtspEncoderTest.java | 173 + .../codec/spdy/DefaultSpdyHeadersTest.java | 58 + .../codec/spdy/SpdyFrameDecoderTest.java | 1330 ++++ .../spdy/SpdyHeaderBlockRawDecoderTest.java | 516 ++ .../spdy/SpdyHeaderBlockZlibDecoderTest.java | 245 + .../codec/spdy/SpdySessionHandlerTest.java | 392 ++ .../src/test/resources/file-01.txt | 1 + .../src/test/resources/file-02.txt | 1 + .../src/test/resources/file-03.txt | 1 + 
.../test/resources/junit-platform.properties | 16 + .../src/test/resources/logging.properties | 7 + netty-handler-codec-http2/build.gradle | 15 + ...AbstractHttp2ConnectionHandlerBuilder.java | 660 ++ .../http2/AbstractHttp2StreamChannel.java | 1160 ++++ .../codec/http2/AbstractHttp2StreamFrame.java | 59 + ...tractInboundHttp2ToHttpAdapterBuilder.java | 136 + .../handler/codec/http2/CharSequenceMap.java | 48 + .../CleartextHttp2ServerUpgradeHandler.java | 107 + .../CompressorHttp2ConnectionEncoder.java | 423 ++ .../DecoratingHttp2ConnectionDecoder.java | 80 + .../DecoratingHttp2ConnectionEncoder.java | 73 + .../http2/DecoratingHttp2FrameWriter.java | 116 + .../codec/http2/DefaultHttp2Connection.java | 1080 +++ .../http2/DefaultHttp2ConnectionDecoder.java | 843 +++ .../http2/DefaultHttp2ConnectionEncoder.java | 634 ++ .../codec/http2/DefaultHttp2DataFrame.java | 198 + .../codec/http2/DefaultHttp2FrameReader.java | 775 +++ .../codec/http2/DefaultHttp2FrameWriter.java | 627 ++ .../codec/http2/DefaultHttp2GoAwayFrame.java | 179 + .../codec/http2/DefaultHttp2Headers.java | 303 + .../http2/DefaultHttp2HeadersDecoder.java | 213 + .../http2/DefaultHttp2HeadersEncoder.java | 109 + .../codec/http2/DefaultHttp2HeadersFrame.java | 116 + .../DefaultHttp2LocalFlowController.java | 648 ++ .../codec/http2/DefaultHttp2PingFrame.java | 75 + .../http2/DefaultHttp2PriorityFrame.java | 91 + .../http2/DefaultHttp2PushPromiseFrame.java | 101 + .../DefaultHttp2RemoteFlowController.java | 768 +++ .../codec/http2/DefaultHttp2ResetFrame.java | 85 + .../http2/DefaultHttp2SettingsAckFrame.java | 33 + .../http2/DefaultHttp2SettingsFrame.java | 63 + .../codec/http2/DefaultHttp2UnknownFrame.java | 140 + .../http2/DefaultHttp2WindowUpdateFrame.java | 54 + .../DelegatingDecompressorFrameListener.java | 435 ++ .../codec/http2/EmptyHttp2Headers.java | 83 + .../handler/codec/http2/HpackDecoder.java | 571 ++ .../codec/http2/HpackDynamicTable.java | 201 + .../handler/codec/http2/HpackEncoder.java | 555 
++ .../handler/codec/http2/HpackHeaderField.java | 69 + .../codec/http2/HpackHuffmanDecoder.java | 4736 +++++++++++++ .../codec/http2/HpackHuffmanEncoder.java | 194 + .../handler/codec/http2/HpackStaticTable.java | 257 + .../netty/handler/codec/http2/HpackUtil.java | 372 + .../http2/Http2ChannelDuplexHandler.java | 94 + .../codec/http2/Http2ClientUpgradeCodec.java | 175 + .../handler/codec/http2/Http2CodecUtil.java | 404 ++ .../handler/codec/http2/Http2Connection.java | 356 + .../codec/http2/Http2ConnectionAdapter.java | 52 + .../codec/http2/Http2ConnectionDecoder.java | 77 + .../codec/http2/Http2ConnectionEncoder.java | 68 + .../codec/http2/Http2ConnectionHandler.java | 1009 +++ .../http2/Http2ConnectionHandlerBuilder.java | 121 + ...onPrefaceAndSettingsFrameWrittenEvent.java | 31 + .../http2/Http2ControlFrameLimitEncoder.java | 113 + .../codec/http2/Http2DataChunkedInput.java | 116 + .../handler/codec/http2/Http2DataFrame.java | 73 + .../handler/codec/http2/Http2DataWriter.java | 45 + .../Http2EmptyDataFrameConnectionDecoder.java | 56 + .../http2/Http2EmptyDataFrameListener.java | 65 + .../netty/handler/codec/http2/Http2Error.java | 65 + .../codec/http2/Http2EventAdapter.java | 115 + .../handler/codec/http2/Http2Exception.java | 344 + .../netty/handler/codec/http2/Http2Flags.java | 207 + .../codec/http2/Http2FlowController.java | 81 + .../netty/handler/codec/http2/Http2Frame.java | 28 + .../codec/http2/Http2FrameAdapter.java | 90 + .../handler/codec/http2/Http2FrameCodec.java | 769 +++ .../codec/http2/Http2FrameCodecBuilder.java | 245 + .../codec/http2/Http2FrameListener.java | 220 + .../http2/Http2FrameListenerDecorator.java | 105 + .../handler/codec/http2/Http2FrameLogger.java | 176 + .../handler/codec/http2/Http2FrameReader.java | 62 + .../codec/http2/Http2FrameSizePolicy.java | 39 + .../handler/codec/http2/Http2FrameStream.java | 39 + .../codec/http2/Http2FrameStreamEvent.java | 52 + .../http2/Http2FrameStreamException.java | 47 + 
.../codec/http2/Http2FrameStreamVisitor.java | 38 + .../handler/codec/http2/Http2FrameTypes.java | 38 + .../handler/codec/http2/Http2FrameWriter.java | 229 + .../handler/codec/http2/Http2GoAwayFrame.java | 89 + .../handler/codec/http2/Http2Headers.java | 205 + .../codec/http2/Http2HeadersDecoder.java | 80 + .../codec/http2/Http2HeadersEncoder.java | 110 + .../codec/http2/Http2HeadersFrame.java | 40 + .../codec/http2/Http2InboundFrameLogger.java | 147 + .../codec/http2/Http2LifecycleManager.java | 98 + .../codec/http2/Http2LocalFlowController.java | 87 + .../codec/http2/Http2MaxRstFrameDecoder.java | 58 + .../codec/http2/Http2MaxRstFrameListener.java | 60 + .../Http2MultiplexActiveStreamsException.java | 33 + .../codec/http2/Http2MultiplexCodec.java | 346 + .../http2/Http2MultiplexCodecBuilder.java | 260 + .../codec/http2/Http2MultiplexHandler.java | 415 ++ .../http2/Http2NoMoreStreamIdsException.java | 36 + .../codec/http2/Http2OutboundFrameLogger.java | 139 + .../handler/codec/http2/Http2PingFrame.java | 36 + .../codec/http2/Http2PriorityFrame.java | 44 + .../http2/Http2PromisedRequestVerifier.java | 74 + .../codec/http2/Http2PushPromiseFrame.java | 55 + .../http2/Http2RemoteFlowController.java | 170 + .../handler/codec/http2/Http2ResetFrame.java | 28 + .../codec/http2/Http2SecurityUtil.java | 80 + .../codec/http2/Http2ServerUpgradeCodec.java | 213 + .../handler/codec/http2/Http2Settings.java | 282 + .../codec/http2/Http2SettingsAckFrame.java | 29 + .../codec/http2/Http2SettingsFrame.java | 28 + .../http2/Http2SettingsReceivedConsumer.java | 25 + .../handler/codec/http2/Http2Stream.java | 177 + .../codec/http2/Http2StreamChannel.java | 33 + .../http2/Http2StreamChannelBootstrap.java | 256 + .../codec/http2/Http2StreamChannelId.java | 76 + .../handler/codec/http2/Http2StreamFrame.java | 38 + .../Http2StreamFrameToHttpObjectCodec.java | 287 + .../codec/http2/Http2StreamVisitor.java | 31 + .../codec/http2/Http2UnknownFrame.java | 58 + 
.../codec/http2/Http2WindowUpdateFrame.java | 30 + .../codec/http2/HttpConversionUtil.java | 710 ++ .../http2/HttpToHttp2ConnectionHandler.java | 166 + .../HttpToHttp2ConnectionHandlerBuilder.java | 123 + .../http2/InboundHttp2ToHttpAdapter.java | 360 + .../InboundHttp2ToHttpAdapterBuilder.java | 65 + .../http2/InboundHttpToHttp2Adapter.java | 81 + .../handler/codec/http2/MaxCapacityQueue.java | 129 + .../codec/http2/ReadOnlyHttp2Headers.java | 892 +++ .../codec/http2/StreamBufferingEncoder.java | 382 ++ .../codec/http2/StreamByteDistributor.java | 112 + .../http2/UniformStreamByteDistributor.java | 205 + .../WeightedFairQueueByteDistributor.java | 803 +++ .../handler/codec/http2/package-info.java | 22 + .../src/main/java/module-info.java | 12 + ...tDecoratingHttp2ConnectionDecoderTest.java | 63 + ...airQueueByteDistributorDependencyTest.java | 72 + ...leartextHttp2ServerUpgradeHandlerTest.java | 291 + .../codec/http2/DataCompressionHttp2Test.java | 533 ++ .../DecoratingHttp2ConnectionEncoderTest.java | 54 + .../DefaultHttp2ConnectionDecoderTest.java | 1055 +++ .../DefaultHttp2ConnectionEncoderTest.java | 957 +++ .../http2/DefaultHttp2ConnectionTest.java | 731 ++ .../http2/DefaultHttp2FrameReaderTest.java | 453 ++ .../http2/DefaultHttp2FrameWriterTest.java | 390 ++ .../http2/DefaultHttp2HeadersDecoderTest.java | 308 + .../http2/DefaultHttp2HeadersEncoderTest.java | 70 + .../codec/http2/DefaultHttp2HeadersTest.java | 253 + .../DefaultHttp2LocalFlowControllerTest.java | 460 ++ .../DefaultHttp2PushPromiseFrameTest.java | 239 + .../DefaultHttp2RemoteFlowControllerTest.java | 1146 ++++ .../codec/http2/HashCollisionTest.java | 177 + .../handler/codec/http2/HpackDecoderTest.java | 914 +++ .../codec/http2/HpackDynamicTableTest.java | 142 + .../handler/codec/http2/HpackEncoderTest.java | 281 + .../handler/codec/http2/HpackHuffmanTest.java | 247 + .../codec/http2/HpackStaticTableTest.java | 76 + .../netty/handler/codec/http2/HpackTest.java | 61 + 
.../handler/codec/http2/HpackTestCase.java | 290 + .../http2/Http2ClientUpgradeCodecTest.java | 86 + .../http2/Http2ConnectionHandlerTest.java | 884 +++ .../http2/Http2ConnectionRoundtripTest.java | 1325 ++++ .../Http2ControlFrameLimitEncoderTest.java | 277 + .../http2/Http2DataChunkedInputTest.java | 177 + .../codec/http2/Http2DefaultFramesTest.java | 44 + ...p2EmptyDataFrameConnectionDecoderTest.java | 28 + .../Http2EmptyDataFrameListenerTest.java | 144 + .../codec/http2/Http2ExceptionTest.java | 50 + .../codec/http2/Http2FrameCodecTest.java | 941 +++ .../codec/http2/Http2FrameInboundWriter.java | 340 + .../codec/http2/Http2FrameRoundtripTest.java | 491 ++ .../codec/http2/Http2HeaderBlockIOTest.java | 101 + ...Http2MaxRstFrameConnectionDecoderTest.java | 28 + .../http2/Http2MaxRstFrameListenerTest.java | 68 + .../Http2MultiplexClientUpgradeTest.java | 96 + .../http2/Http2MultiplexCodecBuilderTest.java | 268 + .../Http2MultiplexCodecClientUpgradeTest.java | 34 + .../codec/http2/Http2MultiplexCodecTest.java | 40 + ...ttp2MultiplexHandlerClientUpgradeTest.java | 30 + .../http2/Http2MultiplexHandlerTest.java | 107 + .../codec/http2/Http2MultiplexTest.java | 1441 ++++ .../http2/Http2MultiplexTransportTest.java | 746 ++ .../codec/http2/Http2SecurityUtilTest.java | 49 + .../http2/Http2ServerUpgradeCodecTest.java | 107 + .../codec/http2/Http2SettingsTest.java | 253 + .../Http2StreamChannelBootstrapTest.java | 166 + .../codec/http2/Http2StreamChannelIdTest.java | 58 + ...Http2StreamFrameToHttpObjectCodecTest.java | 995 +++ .../handler/codec/http2/Http2TestUtil.java | 538 ++ .../codec/http2/HttpConversionUtilTest.java | 279 + .../HttpToHttp2ConnectionHandlerTest.java | 640 ++ .../codec/http2/InOrderHttp2Headers.java | 104 + .../http2/InboundHttp2ToHttpAdapterTest.java | 853 +++ .../codec/http2/LastInboundHandler.java | 222 + .../codec/http2/ReadOnlyHttp2HeadersTest.java | 298 + .../http2/StreamBufferingEncoderTest.java | 581 ++ .../codec/http2/TestChannelInitializer.java | 
122 + .../codec/http2/TestHeaderListener.java | 49 + ...reamByteDistributorFlowControllerTest.java | 22 + .../UniformStreamByteDistributorTest.java | 283 + ...ueueByteDistributorDependencyTreeTest.java | 980 +++ .../WeightedFairQueueByteDistributorTest.java | 964 +++ ...htedFairQueueRemoteFlowControllerTest.java | 22 + .../http2/testdata/testDuplicateHeaders.json | 66 + .../codec/http2/testdata/testEmpty.json | 14 + .../codec/http2/testdata/testEviction.json | 57 + .../testdata/testMaxHeaderTableSize.json | 55 + .../http2/testdata/testSpecExampleC2_1.json | 18 + .../http2/testdata/testSpecExampleC2_2.json | 17 + .../http2/testdata/testSpecExampleC2_3.json | 17 + .../http2/testdata/testSpecExampleC2_4.json | 17 + .../http2/testdata/testSpecExampleC3.json | 57 + .../http2/testdata/testSpecExampleC4.json | 58 + .../http2/testdata/testSpecExampleC5.json | 68 + .../http2/testdata/testSpecExampleC6.json | 68 + .../testdata/testStaticTableEntries.json | 72 + .../testStaticTableResponseEntries.json | 23 + .../test/resources/junit-platform.properties | 16 + .../src/test/resources/logging.properties | 7 + netty-handler-codec-protobuf/build.gradle | 4 + .../codec/protobuf/ProtobufDecoder.java | 133 + .../codec/protobuf/ProtobufEncoder.java | 74 + .../ProtobufVarint32FrameDecoder.java | 119 + .../ProtobufVarint32LengthFieldPrepender.java | 88 + .../handler/codec/protobuf/package-info.java | 23 + .../src/main/java/module-info.java | 8 + netty-handler-codec/build.gradle | 4 + .../handler/codec/AsciiHeadersEncoder.java | 121 + .../handler/codec/ByteToMessageCodec.java | 175 + .../handler/codec/ByteToMessageDecoder.java | 586 ++ .../codec/CharSequenceValueConverter.java | 150 + .../netty/handler/codec/CodecException.java | 51 + .../netty/handler/codec/CodecOutputList.java | 232 + .../codec/CorruptedFrameException.java | 52 + .../handler/codec/DatagramPacketDecoder.java | 115 + .../handler/codec/DatagramPacketEncoder.java | 147 + .../io/netty/handler/codec/DateFormatter.java | 448 
++ .../netty/handler/codec/DecoderException.java | 51 + .../io/netty/handler/codec/DecoderResult.java | 76 + .../handler/codec/DecoderResultProvider.java | 33 + .../netty/handler/codec/DefaultHeaders.java | 1446 ++++ .../handler/codec/DefaultHeadersImpl.java | 34 + .../codec/DelimiterBasedFrameDecoder.java | 332 + .../io/netty/handler/codec/Delimiters.java | 49 + .../io/netty/handler/codec/EmptyHeaders.java | 526 ++ .../netty/handler/codec/EncoderException.java | 51 + .../codec/FixedLengthFrameDecoder.java | 79 + .../java/io/netty/handler/codec/Headers.java | 998 +++ .../io/netty/handler/codec/HeadersUtils.java | 221 + .../codec/LengthFieldBasedFrameDecoder.java | 516 ++ .../handler/codec/LengthFieldPrepender.java | 201 + .../handler/codec/LineBasedFrameDecoder.java | 180 + .../codec/MessageAggregationException.java | 39 + .../handler/codec/MessageAggregator.java | 471 ++ .../handler/codec/MessageToByteEncoder.java | 160 + .../handler/codec/MessageToMessageCodec.java | 148 + .../codec/MessageToMessageDecoder.java | 121 + .../codec/MessageToMessageEncoder.java | 156 + .../PrematureChannelClosureException.java | 54 + .../codec/ProtocolDetectionResult.java | 80 + .../handler/codec/ProtocolDetectionState.java | 36 + .../netty/handler/codec/ReplayingDecoder.java | 424 ++ .../codec/ReplayingDecoderByteBuf.java | 1147 ++++ .../handler/codec/TooLongFrameException.java | 52 + .../UnsupportedMessageTypeException.java | 63 + .../codec/UnsupportedValueConverter.java | 125 + .../netty/handler/codec/ValueConverter.java | 57 + .../io/netty/handler/codec/base64/Base64.java | 429 ++ .../handler/codec/base64/Base64Decoder.java | 64 + .../handler/codec/base64/Base64Dialect.java | 207 + .../handler/codec/base64/Base64Encoder.java | 66 + .../handler/codec/base64/package-info.java | 23 + .../handler/codec/bytes/ByteArrayDecoder.java | 58 + .../handler/codec/bytes/ByteArrayEncoder.java | 59 + .../handler/codec/bytes/package-info.java | 21 + .../handler/codec/json/JsonObjectDecoder.java | 
237 + .../handler/codec/json/package-info.java | 20 + .../io/netty/handler/codec/package-info.java | 22 + .../serialization/CachingClassResolver.java | 46 + .../ClassLoaderClassResolver.java | 35 + .../codec/serialization/ClassResolver.java | 36 + .../codec/serialization/ClassResolvers.java | 118 + .../CompactObjectInputStream.java | 75 + .../CompactObjectOutputStream.java | 49 + .../CompatibleObjectEncoder.java | 101 + .../codec/serialization/ObjectDecoder.java | 92 + .../ObjectDecoderInputStream.java | 255 + .../codec/serialization/ObjectEncoder.java | 74 + .../ObjectEncoderOutputStream.java | 194 + .../codec/serialization/ReferenceMap.java | 102 + .../codec/serialization/SoftReferenceMap.java | 33 + .../codec/serialization/WeakReferenceMap.java | 33 + .../codec/serialization/package-info.java | 32 + .../handler/codec/string/LineEncoder.java | 94 + .../handler/codec/string/LineSeparator.java | 83 + .../handler/codec/string/StringDecoder.java | 79 + .../handler/codec/string/StringEncoder.java | 79 + .../handler/codec/string/package-info.java | 21 + .../handler/codec/xml/XmlFrameDecoder.java | 245 + .../netty/handler/codec/xml/package-info.java | 20 + .../src/main/java/module-info.java | 11 + .../handler/codec/ByteToMessageCodecTest.java | 113 + .../codec/ByteToMessageDecoderTest.java | 656 ++ .../codec/CharSequenceValueConverterTest.java | 92 + .../handler/codec/CodecOutputListTest.java | 54 + .../codec/DatagramPacketDecoderTest.java | 96 + .../codec/DatagramPacketEncoderTest.java | 139 + .../handler/codec/DateFormatterTest.java | 145 + .../handler/codec/DefaultHeadersTest.java | 833 +++ .../codec/DelimiterBasedFrameDecoderTest.java | 128 + .../netty/handler/codec/EmptyHeadersTest.java | 585 ++ .../LengthFieldBasedFrameDecoderTest.java | 89 + .../codec/LineBasedFrameDecoderTest.java | 216 + .../handler/codec/MessageAggregatorTest.java | 137 + .../codec/MessageToMessageEncoderTest.java | 86 + .../codec/ReplayingDecoderByteBufTest.java | 129 + 
.../handler/codec/ReplayingDecoderTest.java | 319 + .../handler/codec/base64/Base64Test.java | 191 + .../codec/bytes/ByteArrayDecoderTest.java | 58 + .../codec/bytes/ByteArrayEncoderTest.java | 67 + .../frame/DelimiterBasedFrameDecoderTest.java | 76 + .../LengthFieldBasedFrameDecoderTest.java | 73 + .../codec/frame/LengthFieldPrependerTest.java | 116 + .../handler/codec/frame/package-info.java | 20 + .../codec/json/JsonObjectDecoderTest.java | 418 ++ .../CompactObjectSerializationTest.java | 36 + .../CompatibleObjectEncoderTest.java | 79 + .../handler/codec/string/LineEncoderTest.java | 52 + .../codec/string/StringDecoderTest.java | 42 + .../codec/string/StringEncoderTest.java | 42 + .../codec/xml/XmlFrameDecoderTest.java | 230 + .../io/netty/handler/codec/xml/sample-01.xml | 1 + .../io/netty/handler/codec/xml/sample-02.xml | 3 + .../io/netty/handler/codec/xml/sample-03.xml | 65 + .../io/netty/handler/codec/xml/sample-04.xml | 752 +++ .../io/netty/handler/codec/xml/sample-05.xml | 81 + .../io/netty/handler/codec/xml/sample-06.xml | 62 + netty-handler-ssl/build.gradle | 18 + .../netty/handler/ssl/AbstractSniHandler.java | 222 + .../ssl/ApplicationProtocolAccessor.java | 30 + .../ssl/ApplicationProtocolConfig.java | 184 + .../handler/ssl/ApplicationProtocolNames.java | 59 + ...ApplicationProtocolNegotiationHandler.java | 210 + .../ssl/ApplicationProtocolNegotiator.java | 37 + .../handler/ssl/ApplicationProtocolUtil.java | 65 + .../io/netty/handler/ssl/AsyncRunnable.java | 20 + .../io/netty/handler/ssl/BouncyCastle.java | 54 + .../ssl/BouncyCastleAlpnSslEngine.java | 62 + .../handler/ssl/BouncyCastleAlpnSslUtils.java | 259 + .../handler/ssl/BouncyCastlePemReader.java | 223 + .../handler/ssl/CipherSuiteConverter.java | 516 ++ .../netty/handler/ssl/CipherSuiteFilter.java | 34 + .../java/io/netty/handler/ssl/Ciphers.java | 754 +++ .../java/io/netty/handler/ssl/ClientAuth.java | 38 + .../java/io/netty/handler/ssl/Conscrypt.java | 75 + 
.../handler/ssl/ConscryptAlpnSslEngine.java | 212 + .../ssl/DefaultOpenSslKeyMaterial.java | 126 + .../handler/ssl/DelegatingSslContext.java | 122 + .../EnhancingX509ExtendedTrustManager.java | 124 + .../handler/ssl/ExtendedOpenSslSession.java | 241 + .../io/netty/handler/ssl/GroupsConverter.java | 50 + .../ssl/IdentityCipherSuiteFilter.java | 64 + .../handler/ssl/Java7SslParametersUtils.java | 38 + .../io/netty/handler/ssl/Java8SslUtils.java | 114 + .../JdkAlpnApplicationProtocolNegotiator.java | 154 + .../netty/handler/ssl/JdkAlpnSslEngine.java | 207 + .../io/netty/handler/ssl/JdkAlpnSslUtils.java | 181 + .../ssl/JdkApplicationProtocolNegotiator.java | 162 + .../JdkBaseApplicationProtocolNegotiator.java | 209 + ...kDefaultApplicationProtocolNegotiator.java | 60 + .../handler/ssl/JdkSslClientContext.java | 313 + .../io/netty/handler/ssl/JdkSslContext.java | 514 ++ .../io/netty/handler/ssl/JdkSslEngine.java | 215 + .../handler/ssl/JdkSslServerContext.java | 317 + .../handler/ssl/NotSslRecordException.java | 48 + .../java/io/netty/handler/ssl/OpenSsl.java | 790 +++ .../OpenSslApplicationProtocolNegotiator.java | 40 + .../ssl/OpenSslAsyncPrivateKeyMethod.java | 58 + .../OpenSslCachingKeyMaterialProvider.java | 79 + .../OpenSslCachingX509KeyManagerFactory.java | 80 + ...penSslCertificateCompressionAlgorithm.java | 64 + .../OpenSslCertificateCompressionConfig.java | 137 + .../ssl/OpenSslCertificateException.java | 81 + .../handler/ssl/OpenSslClientContext.java | 211 + .../ssl/OpenSslClientSessionCache.java | 138 + .../io/netty/handler/ssl/OpenSslContext.java | 60 + .../handler/ssl/OpenSslContextOption.java | 77 + ...lDefaultApplicationProtocolNegotiator.java | 53 + .../io/netty/handler/ssl/OpenSslEngine.java | 41 + .../netty/handler/ssl/OpenSslEngineMap.java | 35 + .../netty/handler/ssl/OpenSslKeyMaterial.java | 59 + .../ssl/OpenSslKeyMaterialManager.java | 138 + .../ssl/OpenSslKeyMaterialProvider.java | 154 + ...enSslNpnApplicationProtocolNegotiator.java | 59 + 
.../netty/handler/ssl/OpenSslPrivateKey.java | 191 + .../handler/ssl/OpenSslPrivateKeyMethod.java | 62 + .../handler/ssl/OpenSslServerContext.java | 371 + .../ssl/OpenSslServerSessionContext.java | 50 + .../io/netty/handler/ssl/OpenSslSession.java | 62 + .../handler/ssl/OpenSslSessionCache.java | 492 ++ .../handler/ssl/OpenSslSessionContext.java | 229 + .../netty/handler/ssl/OpenSslSessionId.java | 66 + .../handler/ssl/OpenSslSessionStats.java | 253 + .../handler/ssl/OpenSslSessionTicketKey.java | 78 + .../ssl/OpenSslX509KeyManagerFactory.java | 416 ++ .../ssl/OpenSslX509TrustManagerWrapper.java | 202 + .../netty/handler/ssl/OptionalSslHandler.java | 117 + .../java/io/netty/handler/ssl/PemEncoded.java | 55 + .../io/netty/handler/ssl/PemPrivateKey.java | 230 + .../java/io/netty/handler/ssl/PemReader.java | 203 + .../java/io/netty/handler/ssl/PemValue.java | 105 + .../netty/handler/ssl/PemX509Certificate.java | 403 ++ .../handler/ssl/PseudoRandomFunction.java | 94 + .../ReferenceCountedOpenSslClientContext.java | 320 + .../ssl/ReferenceCountedOpenSslContext.java | 1146 ++++ .../ssl/ReferenceCountedOpenSslEngine.java | 2761 ++++++++ .../ReferenceCountedOpenSslServerContext.java | 298 + .../ssl/SignatureAlgorithmConverter.java | 74 + .../netty/handler/ssl/SniCompletionEvent.java | 54 + .../java/io/netty/handler/ssl/SniHandler.java | 234 + .../handler/ssl/SslClientHelloHandler.java | 349 + .../handler/ssl/SslCloseCompletionEvent.java | 37 + .../handler/ssl/SslClosedEngineException.java | 31 + .../netty/handler/ssl/SslCompletionEvent.java | 53 + .../java/io/netty/handler/ssl/SslContext.java | 1363 ++++ .../netty/handler/ssl/SslContextBuilder.java | 632 ++ .../netty/handler/ssl/SslContextOption.java | 86 + .../java/io/netty/handler/ssl/SslHandler.java | 2478 +++++++ .../ssl/SslHandshakeCompletionEvent.java | 39 + .../ssl/SslHandshakeTimeoutException.java | 28 + .../handler/ssl/SslMasterKeyHandler.java | 199 + .../io/netty/handler/ssl/SslProtocols.java | 76 + 
.../io/netty/handler/ssl/SslProvider.java | 115 + .../java/io/netty/handler/ssl/SslUtils.java | 509 ++ .../ssl/StacklessSSLHandshakeException.java | 46 + .../ssl/SupportedCipherSuiteFilter.java | 58 + .../handler/ssl/ocsp/OcspClientHandler.java | 57 + .../netty/handler/ssl/ocsp/package-info.java | 23 + .../io/netty/handler/ssl/package-info.java | 21 + .../BouncyCastleSelfSignedCertGenerator.java | 64 + .../util/FingerprintTrustManagerFactory.java | 266 + ...FingerprintTrustManagerFactoryBuilder.java | 87 + .../ssl/util/InsecureTrustManagerFactory.java | 77 + .../ssl/util/KeyManagerFactoryWrapper.java | 43 + .../ssl/util/LazyJavaxX509Certificate.java | 202 + .../handler/ssl/util/LazyX509Certificate.java | 242 + .../ssl/util/SelfSignedCertificate.java | 406 ++ .../ssl/util/SimpleKeyManagerFactory.java | 154 + .../ssl/util/SimpleTrustManagerFactory.java | 156 + .../ssl/util/ThreadLocalInsecureRandom.java | 101 + .../ssl/util/TrustManagerFactoryWrapper.java | 43 + .../ssl/util/X509KeyManagerWrapper.java | 78 + .../ssl/util/X509TrustManagerWrapper.java | 76 + .../netty/handler/ssl/util/package-info.java | 20 + .../src/main/java/module-info.java | 14 + .../ssl/AmazonCorrettoSslEngineTest.java | 103 + ...icationProtocolNegotiationHandlerTest.java | 232 + .../handler/ssl/CipherSuiteCanaryTest.java | 285 + .../handler/ssl/CipherSuiteConverterTest.java | 417 ++ .../io/netty/handler/ssl/CloseNotifyTest.java | 230 + .../ssl/ConscryptJdkSslEngineInteropTest.java | 90 + .../ConscryptOpenSslEngineInteropTest.java | 219 + .../handler/ssl/ConscryptSslEngineTest.java | 99 + .../handler/ssl/DelegatingSslContextTest.java | 60 + .../EnhancedX509ExtendedTrustManagerTest.java | 326 + .../ssl/IdentityCipherSuiteFilterTest.java | 46 + .../netty/handler/ssl/Java8SslTestUtils.java | 84 + .../ssl/JdkConscryptSslEngineInteropTest.java | 105 + .../ssl/JdkOpenSslEngineInteroptTest.java | 256 + .../handler/ssl/JdkSslClientContextTest.java | 29 + .../netty/handler/ssl/JdkSslEngineTest.java | 358 + 
.../handler/ssl/JdkSslRenegotiateTest.java | 24 + .../handler/ssl/JdkSslServerContextTest.java | 27 + ...OpenSslCachingKeyMaterialProviderTest.java | 92 + .../OpenSslCertificateCompressionTest.java | 441 ++ .../ssl/OpenSslCertificateExceptionTest.java | 61 + .../handler/ssl/OpenSslClientContextTest.java | 36 + .../OpenSslConscryptSslEngineInteropTest.java | 191 + .../netty/handler/ssl/OpenSslEngineTest.java | 1649 +++++ .../handler/ssl/OpenSslEngineTestParam.java | 34 + .../ssl/OpenSslErrorStackAssertSSLEngine.java | 442 ++ .../ssl/OpenSslJdkSslEngineInteroptTest.java | 203 + .../ssl/OpenSslKeyMaterialManagerTest.java | 83 + .../ssl/OpenSslKeyMaterialProviderTest.java | 180 + .../ssl/OpenSslPrivateKeyMethodTest.java | 481 ++ .../handler/ssl/OpenSslRenegotiateTest.java | 45 + .../handler/ssl/OpenSslServerContextTest.java | 34 + .../io/netty/handler/ssl/OpenSslTest.java | 31 + .../netty/handler/ssl/OpenSslTestUtils.java | 27 + ...nSslX509KeyManagerFactoryProviderTest.java | 38 + .../handler/ssl/OptionalSslHandlerTest.java | 123 + .../ssl/ParameterizedSslHandlerTest.java | 709 ++ .../io/netty/handler/ssl/PemEncodedTest.java | 123 + .../io/netty/handler/ssl/PemReaderTest.java | 91 + .../handler/ssl/PseudoRandomFunctionTest.java | 52 + .../ReferenceCountedOpenSslEngineTest.java | 112 + .../io/netty/handler/ssl/RenegotiateTest.java | 154 + .../io/netty/handler/ssl/SSLEngineTest.java | 4489 ++++++++++++ .../ssl/SignatureAlgorithmConverterTest.java | 59 + .../handler/ssl/SniClientJava8TestUtil.java | 349 + .../io/netty/handler/ssl/SniClientTest.java | 179 + .../io/netty/handler/ssl/SniHandlerTest.java | 858 +++ .../handler/ssl/SslContextBuilderTest.java | 428 ++ .../io/netty/handler/ssl/SslContextTest.java | 371 + .../ssl/SslContextTrustManagerTest.java | 122 + .../io/netty/handler/ssl/SslErrorTest.java | 312 + .../io/netty/handler/ssl/SslHandlerTest.java | 1885 ++++++ .../io/netty/handler/ssl/SslUtilsTest.java | 175 + .../io/netty/handler/ssl/ocsp/OcspTest.java | 535 ++ 
.../FingerprintTrustManagerFactoryTest.java | 141 + .../ssl/util/SelfSignedCertificateTest.java | 50 + .../handler/ssl/ec_params_unsupported.pem | 18 + .../netty/handler/ssl/generate-certificate.sh | 22 + .../io/netty/handler/ssl/generate-certs.sh | 85 + .../io/netty/handler/ssl/localhost_server.pem | 17 + .../io/netty/handler/ssl/mutual_auth_ca.pem | 19 + .../netty/handler/ssl/mutual_auth_client.p12 | Bin 0 -> 3997 bytes .../ssl/mutual_auth_invalid_client.p12 | Bin 0 -> 3949 bytes .../netty/handler/ssl/mutual_auth_server.p12 | Bin 0 -> 3149 bytes .../netty/handler/ssl/notlocalhost_server.pem | 17 + .../io/netty/handler/ssl/openssl.cnf | 123 + .../handler/ssl/rsaValidation-user-certs.p12 | Bin 0 -> 2605 bytes .../ssl/rsaValidations-server-keystore.p12 | Bin 0 -> 3456 bytes .../io/netty/handler/ssl/rsapss-ca-cert.cert | 21 + .../netty/handler/ssl/rsapss-signing-ext.txt | 21 + .../io/netty/handler/ssl/test2_encrypted.pem | 29 + .../netty/handler/ssl/test2_unencrypted.pem | 28 + .../io/netty/handler/ssl/test_encrypted.pem | 29 + .../handler/ssl/test_encrypted_empty_pass.pem | 29 + .../io/netty/handler/ssl/test_unencrypted.pem | 24 + .../io/netty/handler/ssl/tm_test_ca_1a.pem | 19 + .../io/netty/handler/ssl/tm_test_ca_1b.pem | 19 + .../io/netty/handler/ssl/tm_test_ca_2.pem | 19 + .../io/netty/handler/ssl/tm_test_eec_1.pem | 19 + .../io/netty/handler/ssl/tm_test_eec_2.pem | 19 + .../io/netty/handler/ssl/tm_test_eec_3.pem | 19 + .../src/test/resources/logging.properties | 9 + netty-handler/build.gradle | 6 + .../address/DynamicAddressConnectHandler.java | 82 + .../address/ResolveAddressHandler.java | 66 + .../netty/handler/address/package-info.java | 20 + .../handler/flow/FlowControlHandler.java | 256 + .../io/netty/handler/flow/package-info.java | 20 + .../flush/FlushConsolidationHandler.java | 220 + .../io/netty/handler/flush/package-info.java | 20 + .../ipfilter/AbstractRemoteAddressFilter.java | 109 + .../netty/handler/ipfilter/IpFilterRule.java | 36 + 
.../handler/ipfilter/IpFilterRuleType.java | 24 + .../handler/ipfilter/IpSubnetFilter.java | 226 + .../handler/ipfilter/IpSubnetFilterRule.java | 219 + .../IpSubnetFilterRuleComparator.java | 36 + .../handler/ipfilter/RuleBasedIpFilter.java | 92 + .../handler/ipfilter/UniqueIpFilter.java | 53 + .../netty/handler/ipfilter/package-info.java | 20 + .../netty/handler/logging/ByteBufFormat.java | 36 + .../io/netty/handler/logging/LogLevel.java | 46 + .../netty/handler/logging/LoggingHandler.java | 427 ++ .../netty/handler/logging/package-info.java | 20 + .../io/netty/handler/pcap/EthernetPacket.java | 81 + .../java/io/netty/handler/pcap/IPPacket.java | 111 + .../io/netty/handler/pcap/PcapHeaders.java | 69 + .../netty/handler/pcap/PcapWriteHandler.java | 861 +++ .../io/netty/handler/pcap/PcapWriter.java | 112 + .../java/io/netty/handler/pcap/State.java | 42 + .../java/io/netty/handler/pcap/TCPPacket.java | 82 + .../java/io/netty/handler/pcap/UDPPacket.java | 43 + .../io/netty/handler/pcap/package-info.java | 20 + .../io/netty/handler/stream/ChunkedFile.java | 170 + .../io/netty/handler/stream/ChunkedInput.java | 81 + .../netty/handler/stream/ChunkedNioFile.java | 181 + .../handler/stream/ChunkedNioStream.java | 143 + .../netty/handler/stream/ChunkedStream.java | 148 + .../handler/stream/ChunkedWriteHandler.java | 384 ++ .../io/netty/handler/stream/package-info.java | 22 + .../io/netty/handler/timeout/IdleState.java | 37 + .../netty/handler/timeout/IdleStateEvent.java | 85 + .../handler/timeout/IdleStateHandler.java | 587 ++ .../handler/timeout/ReadTimeoutException.java | 40 + .../handler/timeout/ReadTimeoutHandler.java | 103 + .../handler/timeout/TimeoutException.java | 40 + .../timeout/WriteTimeoutException.java | 40 + .../handler/timeout/WriteTimeoutHandler.java | 236 + .../netty/handler/timeout/package-info.java | 21 + .../AbstractTrafficShapingHandler.java | 658 ++ .../traffic/ChannelTrafficShapingHandler.java | 231 + .../traffic/GlobalChannelTrafficCounter.java | 
127 + .../GlobalChannelTrafficShapingHandler.java | 773 +++ .../traffic/GlobalTrafficShapingHandler.java | 401 ++ .../netty/handler/traffic/TrafficCounter.java | 619 ++ .../netty/handler/traffic/package-info.java | 60 + netty-handler/src/main/java/module-info.java | 16 + .../DynamicAddressConnectHandlerTest.java | 107 + .../address/ResolveAddressHandlerTest.java | 142 + .../handler/flow/FlowControlHandlerTest.java | 680 ++ .../flush/FlushConsolidationHandlerTest.java | 203 + .../handler/ipfilter/IpSubnetFilterTest.java | 220 + .../handler/ipfilter/UniqueIpFilterTest.java | 76 + .../CloseDetectingByteBufOutputStream.java | 46 + .../pcap/DiscardingStatsOutputStream.java | 38 + .../handler/pcap/PcapWriteHandlerTest.java | 680 ++ .../handler/stream/ChunkedStreamTest.java | 51 + .../stream/ChunkedWriteHandlerTest.java | 855 +++ .../handler/timeout/IdleStateEventTest.java | 35 + .../handler/timeout/IdleStateHandlerTest.java | 438 ++ .../timeout/WriteTimeoutHandlerTest.java | 61 + .../traffic/FileRegionThrottleTest.java | 168 + .../traffic/TrafficShapingHandlerTest.java | 125 + .../src/test/resources/logging.properties | 7 + .../tcnative/AsyncSSLPrivateKeyMethod.java | 54 + .../AsyncSSLPrivateKeyMethodAdapter.java | 51 + .../io/netty/internal/tcnative/AsyncTask.java | 27 + .../io/netty/internal/tcnative/Buffer.java | 54 + .../tcnative/CertificateCallback.java | 51 + .../tcnative/CertificateCallbackTask.java | 49 + .../tcnative/CertificateCompressionAlgo.java | 70 + .../CertificateRequestedCallback.java | 57 + .../tcnative/CertificateVerifier.java | 192 + .../tcnative/CertificateVerifierTask.java | 39 + .../io/netty/internal/tcnative/Library.java | 202 + .../NativeStaticallyReferencedJniMethods.java | 184 + .../internal/tcnative/ResultCallback.java | 39 + .../java/io/netty/internal/tcnative/SSL.java | 923 +++ .../netty/internal/tcnative/SSLContext.java | 763 +++ .../tcnative/SSLPrivateKeyMethod.java | 56 + .../SSLPrivateKeyMethodDecryptTask.java | 33 + 
.../tcnative/SSLPrivateKeyMethodSignTask.java | 34 + .../tcnative/SSLPrivateKeyMethodTask.java | 56 + .../netty/internal/tcnative/SSLSession.java | 80 + .../internal/tcnative/SSLSessionCache.java | 49 + .../io/netty/internal/tcnative/SSLTask.java | 69 + .../internal/tcnative/SessionTicketKey.java | 90 + .../internal/tcnative/SniHostNameMatcher.java | 27 + .../src/main/java/module-info.java | 3 + netty-jctools/build.gradle | 7 + netty-jctools/src/main/java/module-info.java | 9 + .../java/org/jctools/counters/Counter.java | 17 + .../org/jctools/counters/CountersFactory.java | 28 + .../counters/FixedSizeStripedLongCounter.java | 163 + .../FixedSizeStripedLongCounterV6.java | 34 + .../FixedSizeStripedLongCounterV8.java | 26 + .../org/jctools/counters/package-info.java | 14 + .../java/org/jctools/maps/AbstractEntry.java | 59 + .../org/jctools/maps/ConcurrentAutoTable.java | 219 + .../org/jctools/maps/NonBlockingHashMap.java | 1464 ++++ .../jctools/maps/NonBlockingHashMapLong.java | 1313 ++++ .../org/jctools/maps/NonBlockingHashSet.java | 60 + .../maps/NonBlockingIdentityHashMap.java | 1307 ++++ .../org/jctools/maps/NonBlockingSetInt.java | 476 ++ .../java/org/jctools/maps/package-info.java | 14 + .../org/jctools/queues/BaseLinkedQueue.java | 397 ++ .../queues/BaseMpscLinkedArrayQueue.java | 781 +++ .../queues/BaseSpscLinkedArrayQueue.java | 420 ++ .../queues/ConcurrentCircularArrayQueue.java | 170 + ...ConcurrentSequencedCircularArrayQueue.java | 33 + .../jctools/queues/IndexedQueueSizeUtil.java | 101 + .../jctools/queues/LinkedArrayQueueUtil.java | 33 + .../org/jctools/queues/LinkedQueueNode.java | 75 + .../jctools/queues/MessagePassingQueue.java | 316 + .../queues/MessagePassingQueueUtil.java | 125 + .../queues/MpUnboundedXaddArrayQueue.java | 469 ++ .../jctools/queues/MpUnboundedXaddChunk.java | 95 + .../org/jctools/queues/MpmcArrayQueue.java | 624 ++ .../queues/MpmcUnboundedXaddArrayQueue.java | 475 ++ .../queues/MpmcUnboundedXaddChunk.java | 66 + 
.../org/jctools/queues/MpscArrayQueue.java | 588 ++ .../MpscBlockingConsumerArrayQueue.java | 828 +++ .../jctools/queues/MpscChunkedArrayQueue.java | 102 + .../org/jctools/queues/MpscCompoundQueue.java | 374 + .../queues/MpscGrowableArrayQueue.java | 63 + .../org/jctools/queues/MpscLinkedQueue.java | 196 + .../queues/MpscUnboundedArrayQueue.java | 83 + .../queues/MpscUnboundedXaddArrayQueue.java | 346 + .../queues/MpscUnboundedXaddChunk.java | 26 + .../queues/QueueProgressIndicators.java | 36 + .../org/jctools/queues/SpmcArrayQueue.java | 456 ++ .../org/jctools/queues/SpscArrayQueue.java | 458 ++ .../jctools/queues/SpscChunkedArrayQueue.java | 114 + .../queues/SpscGrowableArrayQueue.java | 170 + .../org/jctools/queues/SpscLinkedQueue.java | 111 + .../queues/SpscUnboundedArrayQueue.java | 83 + .../org/jctools/queues/SupportsIterator.java | 24 + .../queues/atomic/AtomicQueueUtil.java | 101 + .../atomic/AtomicReferenceArrayQueue.java | 157 + .../queues/atomic/BaseLinkedAtomicQueue.java | 471 ++ .../BaseMpscLinkedAtomicArrayQueue.java | 794 +++ .../BaseSpscLinkedAtomicArrayQueue.java | 444 ++ .../queues/atomic/LinkedQueueAtomicNode.java | 69 + .../queues/atomic/MpmcAtomicArrayQueue.java | 652 ++ .../queues/atomic/MpscAtomicArrayQueue.java | 664 ++ .../atomic/MpscChunkedAtomicArrayQueue.java | 133 + .../atomic/MpscGrowableAtomicArrayQueue.java | 60 + .../queues/atomic/MpscLinkedAtomicQueue.java | 158 + .../atomic/MpscUnboundedAtomicArrayQueue.java | 111 + .../SequencedAtomicReferenceArrayQueue.java | 55 + .../queues/atomic/SpmcAtomicArrayQueue.java | 543 ++ .../queues/atomic/SpscAtomicArrayQueue.java | 516 ++ .../atomic/SpscChunkedAtomicArrayQueue.java | 104 + .../atomic/SpscGrowableAtomicArrayQueue.java | 147 + .../queues/atomic/SpscLinkedAtomicQueue.java | 110 + .../atomic/SpscUnboundedAtomicArrayQueue.java | 78 + .../jctools/queues/atomic/package-info.java | 14 + .../java/org/jctools/queues/package-info.java | 98 + .../unpadded/BaseLinkedUnpaddedQueue.java | 324 + 
.../BaseMpscLinkedUnpaddedArrayQueue.java | 649 ++ .../BaseSpscLinkedUnpaddedArrayQueue.java | 354 + .../ConcurrentCircularUnpaddedArrayQueue.java | 154 + ...ntSequencedCircularUnpaddedArrayQueue.java | 36 + .../unpadded/MpmcUnpaddedArrayQueue.java | 512 ++ .../MpscChunkedUnpaddedArrayQueue.java | 84 + .../MpscGrowableUnpaddedArrayQueue.java | 59 + .../unpadded/MpscLinkedUnpaddedQueue.java | 171 + .../MpscUnboundedUnpaddedArrayQueue.java | 62 + .../unpadded/MpscUnpaddedArrayQueue.java | 469 ++ .../unpadded/SpmcUnpaddedArrayQueue.java | 352 + .../SpscChunkedUnpaddedArrayQueue.java | 103 + .../SpscGrowableUnpaddedArrayQueue.java | 147 + .../unpadded/SpscLinkedUnpaddedQueue.java | 108 + .../SpscUnboundedUnpaddedArrayQueue.java | 77 + .../unpadded/SpscUnpaddedArrayQueue.java | 373 + .../jctools/queues/unpadded/package-info.java | 14 + .../java/org/jctools/util/InternalAPI.java | 30 + .../org/jctools/util/PaddedAtomicLong.java | 392 ++ .../org/jctools/util/PortableJvmInfo.java | 25 + .../src/main/java/org/jctools/util/Pow2.java | 61 + .../main/java/org/jctools/util/RangeUtil.java | 68 + .../org/jctools/util/SpscLookAheadUtil.java | 12 + .../java/org/jctools/util/UnsafeAccess.java | 114 + .../java/org/jctools/util/UnsafeJvmInfo.java | 19 + .../jctools/util/UnsafeLongArrayAccess.java | 114 + .../jctools/util/UnsafeRefArrayAccess.java | 121 + .../java/org/jctools/util/package-info.java | 14 + .../FixedSizeStripedLongCounterTest.java | 88 + .../org/jctools/maps/KeyAtomicityTest.java | 103 + .../maps/NBHMIdentityKeyAtomicityTest.java | 103 + .../maps/NBHMLongKeyAtomicityTest.java | 103 + .../java/org/jctools/maps/NBHMRemoveTest.java | 250 + .../org/jctools/maps/NBHMReplaceTest.java | 20 + .../NonBlockingHashMapGuavaTestSuite.java | 139 + .../linearizability_test/LincheckMapTest.java | 109 + .../linearizability_test/LincheckSetTest.java | 90 + ...NonBlockingHashMapLinearizabilityTest.java | 11 + ...lockingHashMapLongLinearizabilityTest.java | 11 + 
...NonBlockingHashSetLinearizabilityTest.java | 11 + ...ingIdentityHashMapLinearizabilityTest.java | 15 + .../NonBlockingSetIntLinearizabilityTest.java | 11 + .../maps/nbhm_test/NBHMID_Tester2.java | 721 ++ .../jctools/maps/nbhm_test/NBHML_Tester2.java | 663 ++ .../jctools/maps/nbhm_test/NBHM_Tester2.java | 728 ++ .../jctools/maps/nbhs_test/nbhs_tester.java | 175 + .../jctools/maps/nbhs_test/nbsi_tester.java | 250 + .../org/jctools/queues/MpqSanityTest.java | 1018 +++ .../queues/MpqSanityTestMpmcArray.java | 33 + .../MpqSanityTestMpmcUnboundedXadd.java | 214 + .../queues/MpqSanityTestMpscArray.java | 33 + .../MpqSanityTestMpscBlockingConsumer.java | 88 + ...anityTestMpscBlockingConsumerExtended.java | 74 + .../queues/MpqSanityTestMpscChunked.java | 46 + .../queues/MpqSanityTestMpscCompound.java | 31 + .../queues/MpqSanityTestMpscGrowable.java | 35 + .../queues/MpqSanityTestMpscLinked.java | 30 + .../queues/MpqSanityTestMpscUnbounded.java | 35 + .../MpqSanityTestMpscUnboundedXadd.java | 36 + .../queues/MpqSanityTestSpmcArray.java | 33 + .../queues/MpqSanityTestSpscArray.java | 33 + .../queues/MpqSanityTestSpscChunked.java | 46 + .../queues/MpqSanityTestSpscGrowable.java | 36 + .../queues/MpqSanityTestSpscLinked.java | 30 + .../queues/MpqSanityTestSpscUnbounded.java | 36 + .../queues/MpscArrayQueueSnapshotTest.java | 79 + .../MpscUnboundedArrayQueueSnapshotTest.java | 108 + .../org/jctools/queues/QueueSanityTest.java | 789 +++ .../queues/QueueSanityTestMpmcArray.java | 100 + .../QueueSanityTestMpmcUnboundedXadd.java | 38 + .../queues/QueueSanityTestMpscArray.java | 98 + .../QueueSanityTestMpscArrayExtended.java | 87 + .../QueueSanityTestMpscBlockingConsumer.java | 30 + ...TestMpscBlockingConsumerArrayExtended.java | 513 ++ ...scBlockingConsumerOfferBelowThreshold.java | 68 + .../queues/QueueSanityTestMpscChunked.java | 39 + .../QueueSanityTestMpscChunkedExtended.java | 16 + .../queues/QueueSanityTestMpscCompound.java | 32 + .../queues/QueueSanityTestMpscGrowable.java 
| 30 + .../queues/QueueSanityTestMpscLinked.java | 31 + ...ueueSanityTestMpscOfferBelowThreshold.java | 68 + .../QueueSanityTestMpscUnboundedArray.java | 36 + .../QueueSanityTestMpscUnboundedXadd.java | 36 + .../queues/QueueSanityTestSpmcArray.java | 36 + .../queues/QueueSanityTestSpscArray.java | 34 + .../QueueSanityTestSpscArrayExtended.java | 45 + .../queues/QueueSanityTestSpscChunked.java | 36 + .../QueueSanityTestSpscChunkedExtended.java | 17 + .../queues/QueueSanityTestSpscGrowable.java | 66 + .../QueueSanityTestSpscGrowableExtended.java | 35 + .../queues/QueueSanityTestSpscLinked.java | 31 + .../queues/QueueSanityTestSpscUnbounded.java | 37 + .../org/jctools/queues/ScQueueRemoveTest.java | 188 + .../queues/ScQueueRemoveTestMpscLinked.java | 10 + ...tomicArrayQueueOfferWithThresholdTest.java | 38 + .../MpscLinkedAtomicQueueRemoveTest.java | 12 + .../atomic/SpscAtomicArrayQueueTest.java | 58 + .../org/jctools/queues/matchers/Matchers.java | 22 + .../queues/spec/ConcurrentQueueSpec.java | 93 + .../org/jctools/queues/spec/Ordering.java | 20 + .../org/jctools/queues/spec/Preference.java | 20 + .../org/jctools/util/AtomicQueueFactory.java | 72 + .../jctools/util/PaddedAtomicLongTest.java | 214 + .../test/java/org/jctools/util/Pow2Test.java | 41 + .../java/org/jctools/util/QueueFactory.java | 81 + .../java/org/jctools/util/RangeUtilTest.java | 128 + .../test/java/org/jctools/util/TestUtil.java | 93 + .../jctools/util/UnpaddedQueueFactory.java | 73 + .../src/test/resources/logging.properties | 7 + netty-resolver/build.gradle | 4 + .../resolver/AbstractAddressResolver.java | 206 + .../io/netty/resolver/AddressResolver.java | 90 + .../netty/resolver/AddressResolverGroup.java | 131 + .../netty/resolver/CompositeNameResolver.java | 107 + .../resolver/DefaultAddressResolverGroup.java | 36 + .../DefaultHostsFileEntriesResolver.java | 148 + .../netty/resolver/DefaultNameResolver.java | 55 + .../io/netty/resolver/HostsFileEntries.java | 62 + 
.../resolver/HostsFileEntriesProvider.java | 317 + .../resolver/HostsFileEntriesResolver.java | 37 + .../io/netty/resolver/HostsFileParser.java | 123 + .../io/netty/resolver/InetNameResolver.java | 54 + .../resolver/InetSocketAddressResolver.java | 96 + .../java/io/netty/resolver/NameResolver.java | 73 + .../netty/resolver/NoopAddressResolver.java | 51 + .../resolver/NoopAddressResolverGroup.java | 36 + .../netty/resolver/ResolvedAddressTypes.java | 38 + .../RoundRobinInetAddressResolver.java | 106 + .../io/netty/resolver/SimpleNameResolver.java | 98 + .../java/io/netty/resolver/package-info.java | 20 + netty-resolver/src/main/java/module-info.java | 4 + .../DefaultHostsFileEntriesResolverTest.java | 195 + .../HostsFileEntriesProviderTest.java | 146 + .../netty/resolver/HostsFileParserTest.java | 103 + .../InetSocketAddressResolverTest.java | 38 + .../resources/io/netty/resolver/hosts-unicode | Bin 0 -> 426 bytes .../src/test/resources/logging.properties | 7 + netty-util/build.gradle | 5 + .../java/io/netty/util/AbstractConstant.java | 89 + .../netty/util/AbstractReferenceCounted.java | 95 + .../main/java/io/netty/util/AsciiString.java | 1883 ++++++ .../main/java/io/netty/util/AsyncMapping.java | 28 + .../main/java/io/netty/util/Attribute.java | 93 + .../main/java/io/netty/util/AttributeKey.java | 66 + .../main/java/io/netty/util/AttributeMap.java | 34 + .../java/io/netty/util/BooleanSupplier.java | 49 + .../java/io/netty/util/ByteProcessor.java | 148 + .../io/netty/util/ByteProcessorUtils.java | 25 + .../main/java/io/netty/util/CharsetUtil.java | 186 + .../src/main/java/io/netty/util/Constant.java | 32 + .../main/java/io/netty/util/ConstantPool.java | 115 + .../io/netty/util/DefaultAttributeMap.java | 212 + .../io/netty/util/DomainMappingBuilder.java | 77 + .../java/io/netty/util/DomainNameMapping.java | 151 + .../netty/util/DomainNameMappingBuilder.java | 205 + .../util/DomainWildcardMappingBuilder.java | 160 + .../java/io/netty/util/HashedWheelTimer.java | 870 
+++ .../java/io/netty/util/HashingStrategy.java | 75 + .../util/IllegalReferenceCountException.java | 49 + .../main/java/io/netty/util/IntSupplier.java | 29 + .../src/main/java/io/netty/util/Mapping.java | 27 + .../src/main/java/io/netty/util/NetUtil.java | 1102 +++ .../io/netty/util/NetUtilInitializations.java | 188 + .../main/java/io/netty/util/NettyRuntime.java | 105 + .../src/main/java/io/netty/util/Recycler.java | 485 ++ .../io/netty/util/ReferenceCountUtil.java | 210 + .../java/io/netty/util/ReferenceCounted.java | 77 + .../main/java/io/netty/util/ResourceLeak.java | 42 + .../io/netty/util/ResourceLeakDetector.java | 712 ++ .../util/ResourceLeakDetectorFactory.java | 198 + .../io/netty/util/ResourceLeakException.java | 70 + .../java/io/netty/util/ResourceLeakHint.java | 27 + .../io/netty/util/ResourceLeakTracker.java | 39 + .../src/main/java/io/netty/util/Signal.java | 118 + .../java/io/netty/util/SuppressForbidden.java | 32 + .../io/netty/util/ThreadDeathWatcher.java | 258 + .../src/main/java/io/netty/util/Timeout.java | 54 + .../src/main/java/io/netty/util/Timer.java | 47 + .../main/java/io/netty/util/TimerTask.java | 33 + .../netty/util/UncheckedBooleanSupplier.java | 49 + .../src/main/java/io/netty/util/Version.java | 202 + .../util/collection/ByteCollections.java | 313 + .../util/collection/ByteObjectHashMap.java | 723 ++ .../netty/util/collection/ByteObjectMap.java | 84 + .../util/collection/CharCollections.java | 313 + .../util/collection/CharObjectHashMap.java | 723 ++ .../netty/util/collection/CharObjectMap.java | 84 + .../netty/util/collection/IntCollections.java | 313 + .../util/collection/IntObjectHashMap.java | 723 ++ .../netty/util/collection/IntObjectMap.java | 84 + .../util/collection/LongCollections.java | 313 + .../util/collection/LongObjectHashMap.java | 723 ++ .../netty/util/collection/LongObjectMap.java | 84 + .../util/collection/ShortCollections.java | 313 + .../util/collection/ShortObjectHashMap.java | 723 ++ 
.../netty/util/collection/ShortObjectMap.java | 84 + .../concurrent/AbstractEventExecutor.java | 191 + .../AbstractEventExecutorGroup.java | 117 + .../netty/util/concurrent/AbstractFuture.java | 58 + .../AbstractScheduledEventExecutor.java | 341 + .../BlockingOperationException.java | 42 + .../netty/util/concurrent/CompleteFuture.java | 149 + .../util/concurrent/DefaultEventExecutor.java | 75 + .../DefaultEventExecutorChooserFactory.java | 76 + .../concurrent/DefaultEventExecutorGroup.java | 61 + .../concurrent/DefaultFutureListeners.java | 86 + .../concurrent/DefaultProgressivePromise.java | 129 + .../netty/util/concurrent/DefaultPromise.java | 887 +++ .../util/concurrent/DefaultThreadFactory.java | 122 + .../netty/util/concurrent/EventExecutor.java | 71 + .../EventExecutorChooserFactory.java | 42 + .../util/concurrent/EventExecutorGroup.java | 107 + .../netty/util/concurrent/FailedFuture.java | 67 + .../util/concurrent/FastThreadLocal.java | 279 + .../concurrent/FastThreadLocalRunnable.java | 39 + .../concurrent/FastThreadLocalThread.java | 128 + .../java/io/netty/util/concurrent/Future.java | 164 + .../netty/util/concurrent/FutureListener.java | 29 + .../concurrent/GenericFutureListener.java | 32 + .../GenericProgressiveFutureListener.java | 28 + .../util/concurrent/GlobalEventExecutor.java | 303 + .../concurrent/ImmediateEventExecutor.java | 162 + .../util/concurrent/ImmediateExecutor.java | 35 + .../MultithreadEventExecutorGroup.java | 227 + .../NonStickyEventExecutorGroup.java | 345 + .../util/concurrent/OrderedEventExecutor.java | 22 + .../util/concurrent/ProgressiveFuture.java | 47 + .../util/concurrent/ProgressivePromise.java | 65 + .../io/netty/util/concurrent/Promise.java | 90 + .../util/concurrent/PromiseAggregator.java | 110 + .../util/concurrent/PromiseCombiner.java | 176 + .../util/concurrent/PromiseNotifier.java | 133 + .../io/netty/util/concurrent/PromiseTask.java | 189 + .../concurrent/RejectedExecutionHandler.java | 28 + 
.../concurrent/RejectedExecutionHandlers.java | 72 + .../util/concurrent/ScheduledFuture.java | 23 + .../util/concurrent/ScheduledFutureTask.java | 219 + .../concurrent/SingleThreadEventExecutor.java | 1130 ++++ .../util/concurrent/SucceededFuture.java | 50 + .../concurrent/ThreadPerTaskExecutor.java | 33 + .../util/concurrent/ThreadProperties.java | 61 + .../util/concurrent/UnaryPromiseNotifier.java | 55 + .../UnorderedThreadPoolEventExecutor.java | 293 + .../netty/util/concurrent/package-info.java | 20 + .../util/internal/AppendableCharSequence.java | 169 + .../util/internal/ClassInitializerUtil.java | 49 + .../java/io/netty/util/internal/Cleaner.java | 29 + .../io/netty/util/internal/CleanerJava9.java | 93 + .../io/netty/util/internal/ConcurrentSet.java | 69 + .../util/internal/ConstantTimeUtils.java | 136 + .../util/internal/DefaultPriorityQueue.java | 295 + .../io/netty/util/internal/EmptyArrays.java | 43 + .../util/internal/EmptyPriorityQueue.java | 161 + .../io/netty/util/internal/IntegerHolder.java | 25 + .../util/internal/InternalThreadLocalMap.java | 398 ++ .../netty/util/internal/LongAdderCounter.java | 27 + .../io/netty/util/internal/LongCounter.java | 29 + .../netty/util/internal/MacAddressUtil.java | 269 + .../java/io/netty/util/internal/MathUtil.java | 97 + .../util/internal/NativeLibraryLoader.java | 567 ++ .../util/internal/NativeLibraryUtil.java | 46 + .../internal/NoOpTypeParameterMatcher.java | 24 + .../io/netty/util/internal/ObjectCleaner.java | 147 + .../io/netty/util/internal/ObjectPool.java | 91 + .../io/netty/util/internal/ObjectUtil.java | 329 + .../util/internal/OutOfDirectMemoryError.java | 30 + .../io/netty/util/internal/PendingWrite.java | 99 + .../util/internal/PlatformDependent.java | 1635 +++++ .../util/internal/PlatformDependent0.java | 931 +++ .../io/netty/util/internal/PriorityQueue.java | 47 + .../util/internal/PriorityQueueNode.java | 45 + .../internal/PromiseNotificationUtil.java | 77 + 
.../netty/util/internal/ReadOnlyIterator.java | 42 + .../util/internal/RecyclableArrayList.java | 148 + .../util/internal/ReferenceCountUpdater.java | 188 + .../netty/util/internal/ReflectionUtil.java | 53 + .../io/netty/util/internal/ResourcesUtil.java | 41 + .../io/netty/util/internal/SocketUtils.java | 118 + .../io/netty/util/internal/StringUtil.java | 719 ++ .../internal/SuppressJava6Requirement.java | 32 + .../util/internal/SystemPropertyUtil.java | 164 + .../util/internal/ThreadExecutorMap.java | 96 + .../util/internal/ThreadLocalRandom.java | 384 ++ .../io/netty/util/internal/ThrowableUtil.java | 88 + .../util/internal/TypeParameterMatcher.java | 165 + .../UnpaddedInternalThreadLocalMap.java | 24 + .../io/netty/util/internal/UnstableApi.java | 47 + .../logging/AbstractInternalLogger.java | 237 + .../internal/logging/FormattingTuple.java | 61 + .../internal/logging/InternalLogLevel.java | 42 + .../util/internal/logging/InternalLogger.java | 485 ++ .../logging/InternalLoggerFactory.java | 82 + .../util/internal/logging/JdkLogger.java | 646 ++ .../internal/logging/JdkLoggerFactory.java | 41 + .../internal/logging/MessageFormatter.java | 396 ++ .../util/internal/logging/package-info.java | 20 + .../io/netty/util/internal/package-info.java | 21 + .../main/java/io/netty/util/package-info.java | 20 + netty-util/src/main/java/module-info.java | 32 + .../util/AbstractReferenceCountedTest.java | 231 + .../netty/util/AsciiStringCharacterTest.java | 436 ++ .../io/netty/util/AsciiStringMemoryTest.java | 173 + .../java/io/netty/util/AttributeKeyTest.java | 63 + .../java/io/netty/util/ConstantPoolTest.java | 108 + .../netty/util/DefaultAttributeMapTest.java | 131 + .../io/netty/util/DomainNameMappingTest.java | 246 + .../DomainWildcardMappingBuilderTest.java | 121 + .../io/netty/util/HashedWheelTimerTest.java | 299 + .../test/java/io/netty/util/NetUtilTest.java | 820 +++ .../java/io/netty/util/NettyRuntimeTests.java | 206 + .../util/RecyclerFastThreadLocalTest.java | 74 + 
.../test/java/io/netty/util/RecyclerTest.java | 476 ++ .../netty/util/ResourceLeakDetectorTest.java | 263 + .../RunInFastThreadLocalThreadExtension.java | 58 + .../io/netty/util/ThreadDeathWatcherTest.java | 144 + .../AbstractScheduledEventExecutorTest.java | 173 + .../util/concurrent/DefaultPromiseTest.java | 643 ++ .../concurrent/DefaultThreadFactoryTest.java | 297 + .../util/concurrent/FastThreadLocalTest.java | 363 + .../concurrent/GlobalEventExecutorTest.java | 179 + .../concurrent/ImmediateExecutorTest.java | 54 + .../NonStickyEventExecutorGroupTest.java | 178 + .../concurrent/PromiseAggregatorTest.java | 150 + .../util/concurrent/PromiseCombinerTest.java | 267 + .../util/concurrent/PromiseNotifierTest.java | 110 + .../SingleThreadEventExecutorTest.java | 430 ++ .../UnorderedThreadPoolEventExecutorTest.java | 135 + .../internal/AppendableCharSequenceTest.java | 110 + .../internal/DefaultPriorityQueueTest.java | 322 + .../util/internal/MacAddressUtilTest.java | 198 + .../io/netty/util/internal/MathUtilTest.java | 91 + .../internal/NativeLibraryLoaderTest.java | 127 + .../util/internal/ObjectCleanerTest.java | 142 + .../netty/util/internal/ObjectUtilTest.java | 595 ++ .../util/internal/OsClassifiersTest.java | 121 + .../util/internal/PlatformDependent0Test.java | 94 + .../util/internal/PlatformDependentTest.java | 161 + .../netty/util/internal/StringUtilTest.java | 648 ++ .../util/internal/SystemPropertyUtilTest.java | 140 + .../util/internal/ThreadExecutorMapTest.java | 65 + .../util/internal/ThreadLocalRandomTest.java | 37 + .../internal/TypeParameterMatcherTest.java | 157 + .../logging/AbstractInternalLoggerTest.java | 150 + .../logging/InternalLoggerFactoryTest.java | 188 + .../logging/JdkLoggerFactoryTest.java | 31 + .../logging/MessageFormatterTest.java | 325 + .../src/test/resources/logging.properties | 7 + netty-zlib/NOTICE.txt | 3 + .../src/main/java/io/netty/zlib/Adler32.java | 108 + .../src/main/java/io/netty/zlib/CRC32.java | 147 + 
.../src/main/java/io/netty/zlib/Checksum.java | 13 + .../src/main/java/io/netty/zlib/Deflate.java | 1736 +++++ .../src/main/java/io/netty/zlib/Deflater.java | 142 + .../io/netty/zlib/DeflaterOutputStream.java | 151 + .../java/io/netty/zlib/GZIPException.java | 11 + .../main/java/io/netty/zlib/GZIPHeader.java | 160 + .../java/io/netty/zlib/GZIPInputStream.java | 117 + .../java/io/netty/zlib/GZIPOutputStream.java | 63 + .../main/java/io/netty/zlib/InfBlocks.java | 666 ++ .../src/main/java/io/netty/zlib/InfCodes.java | 690 ++ .../src/main/java/io/netty/zlib/InfTree.java | 490 ++ .../src/main/java/io/netty/zlib/Inflate.java | 764 +++ .../src/main/java/io/netty/zlib/Inflater.java | 131 + .../io/netty/zlib/InflaterInputStream.java | 225 + .../src/main/java/io/netty/zlib/JZlib.java | 61 + .../main/java/io/netty/zlib/StaticTree.java | 114 + .../src/main/java/io/netty/zlib/Tree.java | 336 + .../main/java/io/netty/zlib/ZInputStream.java | 101 + .../java/io/netty/zlib/ZOutputStream.java | 137 + .../src/main/java/io/netty/zlib/ZStream.java | 360 + .../java/io/netty/zlib/ZStreamException.java | 11 + netty-zlib/src/main/java/module-info.java | 3 + settings.gradle | 75 + 1973 files changed, 421464 insertions(+) create mode 100644 .gitignore create mode 100644 LICENSE.txt create mode 100644 NOTICE.txt create mode 100644 build.gradle create mode 100644 gradle.properties create mode 100644 gradle/compile/java.gradle create mode 100644 gradle/documentation/asciidoc.gradle create mode 100644 gradle/ide/idea.gradle create mode 100644 gradle/publish/forgejo.gradle create mode 100644 gradle/publish/ivy.gradle create mode 100644 gradle/publish/maven.gradle create mode 100644 gradle/publish/sonatype.gradle create mode 100644 gradle/quality/checkstyle.gradle create mode 100644 gradle/quality/checkstyle.xml create mode 100644 gradle/quality/cyclonedx.gradle create mode 100644 gradle/quality/pmd.gradle create mode 100644 gradle/quality/pmd/category/java/bestpractices.xml create mode 100644 
gradle/quality/pmd/category/java/categories.properties create mode 100644 gradle/quality/pmd/category/java/codestyle.xml create mode 100644 gradle/quality/pmd/category/java/design.xml create mode 100644 gradle/quality/pmd/category/java/documentation.xml create mode 100644 gradle/quality/pmd/category/java/errorprone.xml create mode 100644 gradle/quality/pmd/category/java/multithreading.xml create mode 100644 gradle/quality/pmd/category/java/performance.xml create mode 100644 gradle/quality/pmd/category/java/security.xml create mode 100644 gradle/quality/sonarqube.gradle create mode 100644 gradle/quality/spotbugs.gradle create mode 100644 gradle/repositories/maven.gradle create mode 100644 gradle/test/jmh.gradle create mode 100644 gradle/test/junit5.gradle create mode 100644 gradle/wrapper/gradle-wrapper.jar create mode 100644 gradle/wrapper/gradle-wrapper.properties create mode 100755 gradlew create mode 100644 gradlew.bat create mode 100644 netty-buffer/build.gradle create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractPooledDerivedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractUnpooledSlicedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AbstractUnsafeSwappedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AdvancedLeakAwareByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocator.java create mode 
100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetricProvider.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufConvertible.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufHolder.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufOutputStream.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufProcessor.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ByteBufUtil.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/DefaultByteBufHolder.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/DuplicatedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/HeapByteBufUtil.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/IntPriorityQueue.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/LongLongHashMap.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolArena.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolArenaMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolChunk.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolChunkList.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolChunkListMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolChunkMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolSubpage.java create mode 100644 
netty-buffer/src/main/java/io/netty/buffer/PoolSubpageMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PoolThreadCache.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocatorMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledDuplicatedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledSlicedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeHeapByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/SimpleLeakAwareByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/SimpleLeakAwareCompositeByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/SizeClasses.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/SizeClassesMetric.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/SlicedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/Unpooled.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledByteBufAllocator.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java create 
mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledDuplicatedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledHeapByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledSlicedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeNoCleanerDirectByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnreleasableByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnsafeByteBufUtil.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnsafeDirectSwappedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/UnsafeHeapSwappedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/WrappedUnpooledUnsafeDirectByteBuf.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/package-info.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/AbstractMultiSearchProcessorFactory.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/AbstractSearchProcessorFactory.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/AhoCorasicSearchProcessorFactory.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/BitapSearchProcessorFactory.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/KmpSearchProcessorFactory.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessor.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessorFactory.java create mode 
100644 netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessor.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessorFactory.java create mode 100644 netty-buffer/src/main/java/io/netty/buffer/search/package-info.java create mode 100644 netty-buffer/src/main/java/module-info.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AbstractPooledByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/AlignedPooledByteBufAllocatorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/BigEndianCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/BigEndianDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/BigEndianHeapByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeNoCleanerDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ByteBufAllocatorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ByteBufDerivationTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java create mode 100644 
netty-buffer/src/test/java/io/netty/buffer/ByteProcessorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ConsolidationTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/DefaultByteBufHolderTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/DuplicatedByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/EmptyByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/IntPriorityQueueTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/LittleEndianCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/LittleEndianDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/LittleEndianHeapByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/LittleEndianUnsafeDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/LittleEndianUnsafeNoCleanerDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/LongLongHashMapTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/NoopResourceLeakTracker.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PoolArenaTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PooledAlignedBigEndianDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PooledBigEndianDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PooledBigEndianHeapByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PooledByteBufAllocatorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianDirectByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianHeapByteBufTest.java create mode 100644 
netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufferBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBufferBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/RetainedDuplicatedByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/RetainedSlicedByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/UnpooledByteBufAllocatorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/UnpooledTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/UnreleaseableByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/UnsafeByteBufUtilTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/WrappedCompositeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/WrappedUnpooledUnsafeByteBufTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/search/BitapSearchProcessorFactoryTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/search/MultiSearchProcessorTest.java create mode 100644 netty-buffer/src/test/java/io/netty/buffer/search/SearchProcessorTest.java create mode 100644 netty-buffer/src/test/resources/logging.properties create mode 100644 netty-bzip2/build.gradle create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitReader.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitWriter.java create mode 100644 
netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockCompressor.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockDecompressor.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Constants.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2DivSufSort.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanAllocator.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageDecoder.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageEncoder.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MTFAndRLE2StageEncoder.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MoveToFrontTable.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Rand.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/Crc32.java create mode 100644 netty-bzip2/src/main/java/io/netty/bzip2/DecompressionException.java create mode 100644 netty-bzip2/src/main/java/module-info.java create mode 100644 netty-channel-unix/build.gradle create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/Buffer.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DatagramSocketAddress.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketAddress.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannel.java create mode 100644 
netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannelConfig.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketReadMode.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/Errors.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/ErrorsStaticallyReferencedJniMethods.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/FileDescriptor.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/GenericUnixChannelOption.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/IntegerUnixChannelOption.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/IovArray.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/Limits.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/LimitsStaticallyReferencedJniMethods.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/NativeInetAddress.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/PeerCredentials.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/PreferredDirectByteBufAllocator.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/RawUnixChannelOption.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/SegmentedDatagramPacket.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/ServerDomainSocketChannel.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/Socket.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/SocketWritableByteChannel.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/Unix.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannel.java create mode 100644 
netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelOption.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelUtil.java create mode 100644 netty-channel-unix/src/main/java/io/netty/channel/unix/package-info.java create mode 100644 netty-channel-unix/src/main/java/module-info.java create mode 100644 netty-channel-unix/src/test/java/io/netty/channel/unix/UnixChannelUtilTest.java create mode 100644 netty-channel-unix/src/test/resources/logging.properties create mode 100644 netty-channel/build.gradle create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrap.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrapConfig.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/Bootstrap.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/BootstrapConfig.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/ChannelFactory.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtension.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtensions.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/FailedChannel.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrap.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrapConfig.java create mode 100644 netty-channel/src/main/java/io/netty/bootstrap/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AbstractChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AbstractEventLoop.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/AbstractEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AbstractServerChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/AddressedEnvelope.java create mode 100644 netty-channel/src/main/java/io/netty/channel/Channel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelDuplexHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelFactory.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelFutureListener.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelHandlerAdapter.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelHandlerContext.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelHandlerMask.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelId.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelInboundHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelInboundHandlerAdapter.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelInboundInvoker.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelInitializer.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelMetadata.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/ChannelOption.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelOutboundBuffer.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandlerAdapter.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelOutboundInvoker.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelPipeline.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelPipelineException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFutureListener.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelProgressivePromise.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelPromise.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelPromiseAggregator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ChannelPromiseNotifier.java create mode 100644 netty-channel/src/main/java/io/netty/channel/CoalescingBufferQueue.java create mode 100644 netty-channel/src/main/java/io/netty/channel/CombinedChannelDuplexHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/CompleteChannelFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ConnectTimeoutException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultAddressedEnvelope.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultChannelId.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/DefaultChannelPipeline.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultChannelProgressivePromise.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultChannelPromise.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultEventLoop.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultFileRegion.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultMaxBytesRecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategy.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategyFactory.java create mode 100644 netty-channel/src/main/java/io/netty/channel/DelegatingChannelPromiseNotifier.java create mode 100644 netty-channel/src/main/java/io/netty/channel/EventLoop.java create mode 100644 netty-channel/src/main/java/io/netty/channel/EventLoopException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/EventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/EventLoopTaskQueueFactory.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ExtendedClosedChannelException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/FailedChannelFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/FileRegion.java create mode 100644 netty-channel/src/main/java/io/netty/channel/FixedRecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/MaxBytesRecvByteBufAllocator.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/MaxMessagesRecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/MessageSizeEstimator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/MultithreadEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/PendingBytesTracker.java create mode 100644 netty-channel/src/main/java/io/netty/channel/PendingWriteQueue.java create mode 100644 netty-channel/src/main/java/io/netty/channel/PreferHeapByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/RecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ReflectiveChannelFactory.java create mode 100644 netty-channel/src/main/java/io/netty/channel/SelectStrategy.java create mode 100644 netty-channel/src/main/java/io/netty/channel/SelectStrategyFactory.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ServerChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ServerChannelRecvByteBufAllocator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/SimpleChannelInboundHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/SimpleUserEventChannelHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/SingleThreadEventLoop.java create mode 100644 netty-channel/src/main/java/io/netty/channel/StacklessClosedChannelException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/SucceededChannelFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java create mode 100644 netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/VoidChannelPromise.java create mode 100644 netty-channel/src/main/java/io/netty/channel/WriteBufferWaterMark.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedChannelId.java create mode 100644 netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java create mode 100644 netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedSocketAddress.java create mode 100644 netty-channel/src/main/java/io/netty/channel/embedded/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/ChannelGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/ChannelGroupException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/ChannelGroupFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/ChannelGroupFutureListener.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/ChannelMatcher.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/ChannelMatchers.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/CombinedIterator.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/DefaultChannelGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/DefaultChannelGroupFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/VoidChannelGroupFuture.java create mode 100644 netty-channel/src/main/java/io/netty/channel/group/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/internal/ChannelUtils.java create mode 100644 netty-channel/src/main/java/io/netty/channel/internal/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/local/LocalAddress.java create mode 100644 netty-channel/src/main/java/io/netty/channel/local/LocalChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/local/LocalChannelRegistry.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/local/LocalEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/local/LocalServerChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/local/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/AbstractNioChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/NioEventLoop.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/NioEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/NioTask.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySetSelector.java create mode 100644 netty-channel/src/main/java/io/netty/channel/nio/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/oio/AbstractOioChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java create mode 100644 netty-channel/src/main/java/io/netty/channel/oio/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolMap.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/pool/ChannelHealthChecker.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/ChannelPool.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/ChannelPoolHandler.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/ChannelPoolMap.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/FixedChannelPool.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/SimpleChannelPool.java create mode 100644 netty-channel/src/main/java/io/netty/channel/pool/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/ChannelInputShutdownEvent.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/ChannelInputShutdownReadComplete.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/ChannelOutputShutdownEvent.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DatagramChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DatagramChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DatagramPacket.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DefaultDatagramChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DefaultSocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DuplexChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/DuplexChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/SocketChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/SocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/ProtocolFamilyConverter.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/SelectorProviderUtil.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/nio/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioDatagramChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java create mode 100644 
netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/oio/package-info.java create mode 100644 netty-channel/src/main/java/io/netty/channel/socket/package-info.java create mode 100644 netty-channel/src/main/java/module-info.java create mode 100644 netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/generated/handlers/reflect-config.json create mode 100644 netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/reflect-config.json create mode 100644 netty-channel/src/test/java/io/netty/bootstrap/BootstrapTest.java create mode 100644 netty-channel/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java create mode 100644 netty-channel/src/test/java/io/netty/bootstrap/StubChannelInitializerExtension.java create mode 100644 netty-channel/src/test/java/io/netty/channel/AbstractChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/AbstractEventLoopTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/AdaptiveRecvByteBufAllocatorTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/BaseChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/ChannelInitializerTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/ChannelOptionTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/CoalescingBufferQueueTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/CombinedChannelDuplexHandlerTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/CompleteChannelFutureTest.java 
create mode 100644 netty-channel/src/test/java/io/netty/channel/DefaultChannelIdTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/DefaultChannelPromiseTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/DefaultFileRegionTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocatorTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/DelegatingChannelPromiseNotifierTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/FailedChannelFutureTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/LoggingTestHandler.java create mode 100644 netty-channel/src/test/java/io/netty/channel/NativeImageHandlerMetadataTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/PendingWriteQueueTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/ReentrantChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/SingleThreadEventLoopTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/SucceededChannelFutureTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/ThreadPerChannelEventLoopGroupTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/embedded/CustomChannelId.java create mode 100644 netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelIdTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/group/DefaultChannelGroupTest.java create mode 100644 
netty-channel/src/test/java/io/netty/channel/local/LocalChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest2.java create mode 100644 netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest3.java create mode 100644 netty-channel/src/test/java/io/netty/channel/nio/NioEventLoopTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/oio/OioEventLoopTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/pool/AbstractChannelPoolMapTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/pool/ChannelPoolTestUtils.java create mode 100644 netty-channel/src/test/java/io/netty/channel/pool/CountingChannelPoolHandler.java create mode 100644 netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolMapDeadlockTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/pool/SimpleChannelPoolTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/socket/InternetProtocolFamilyTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/socket/nio/AbstractNioChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/socket/nio/NioDatagramChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/socket/nio/NioServerSocketChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/channel/socket/nio/NioSocketChannelTest.java create mode 100644 netty-channel/src/test/java/io/netty/nativeimage/ChannelHandlerMetadataUtil.java create mode 100644 netty-channel/src/test/resources/META-INF/services/io.netty.bootstrap.ChannelInitializerExtension 
create mode 100644 netty-channel/src/test/resources/logging.properties create mode 100644 netty-handler-codec-compression/build.gradle create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Brotli.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ByteBufChecksum.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Decoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Encoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionException.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Crc32c.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DecompressionException.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DeflateOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/EncoderUtil.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLz.java create mode 100644 
netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLzFrameDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLzFrameEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/GzipOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JZlibDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JZlibEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JdkZlibDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4Constants.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4XXHash32.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Snappy.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFrameDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFrameEncoder.java create mode 100644 
netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibCodecFactory.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibDecoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibUtil.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibWrapper.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Zstd.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java create mode 100644 netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/package-info.java create mode 100644 netty-handler-codec-compression/src/main/java/module-info.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractCompressionTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractDecoderTest.java create mode 
100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractEncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractIntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliDecoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliEncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ByteBufChecksumTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2DecoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2EncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2IntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/FastLzIntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/JZlibTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/JdkZlibTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LengthAwareLzfIntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameIntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfDecoderTest.java create mode 100644 
netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfEncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfIntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameDecoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameEncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyIntegrationTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest1.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest2.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibTest.java create mode 100644 netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java create mode 100644 netty-handler-codec-compression/src/test/resources/logging.properties create mode 100644 netty-handler-codec-compression/src/test/resources/multiple.gz create mode 100644 netty-handler-codec-http/build.gradle create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ClientCookieEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CompressionEncoderFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/Cookie.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultCookie.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpResponse.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpContent.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeadersFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpMessage.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpObject.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultLastHttpContent.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/EmptyHttpHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpMessage.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpRequest.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpResponse.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkedInput.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpConstants.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContent.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentCompressor.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpExpectationFailedEvent.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderDateFormat.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderNames.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValidationUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessage.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageDecoderResult.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMethod.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObject.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectAggregator.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java create mode 100755 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequest.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponse.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpScheme.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerExpectContinueHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerKeepAliveHandler.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpStatusClass.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/LastHttpContent.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ReadOnlyHttpHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ServerCookieEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpContentException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpHeaderException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpLineException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ClientCookieDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ClientCookieEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/Cookie.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieHeaderNames.java create 
mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/DefaultCookie.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfigBuilder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpData.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMixedHttpData.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/Attribute.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/CaseIgnoringComparator.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DeleteFileOnExitHook.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DiskAttribute.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DiskFileUpload.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/FileUpload.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/FileUploadUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpData.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpDataFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostBodyUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoder.java create mode 100755 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpData.java create mode 100755 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpPostRequestDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InternalAttribute.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryAttribute.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryFileUpload.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedAttribute.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedFileUpload.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/BinaryWebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/ContinuationWebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CorruptedWebSocketFrameException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PingWebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PongWebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/TextWebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8Validator.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameDecoder.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketChunkedInput.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakeException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolConfig.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandler.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientProtocolHandshakeHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketCloseStatus.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketDecoderConfig.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregator.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketScheme.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakeException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactory.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketVersion.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtension.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtension.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionData.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionEncoder.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilter.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProvider.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtension.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshaker.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketClientCompressionHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderNames.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderValues.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspMethods.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspObjectDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspObjectEncoder.java create 
mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseStatuses.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspVersions.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyDataFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyGoAwayFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyPingFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyRstStreamFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySettingsFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyCodecUtil.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyDataFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameCodec.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoderDelegate.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyGoAwayFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockJZlibEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeadersFrame.java create mode 
100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpCodec.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpEncoder.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpHeaders.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpResponseStreamIdHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyPingFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyProtocolException.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyRstStreamFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySession.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySessionHandler.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySessionStatus.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySettingsFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamStatus.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySynReplyFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySynStreamFrame.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyVersion.java create mode 100644 netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyWindowUpdateFrame.java create mode 100644 
netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/package-info.java create mode 100644 netty-handler-codec-http/src/main/java/module-info.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/CombinedHttpHeadersTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/EmptyHttpHeadersInitializationTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpChunkedInputTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientUpgradeHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorOptionsTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderDateFormatTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderValidationUtilTest.java create mode 100644 
netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTestUtils.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseStatusTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerExpectContinueHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerUpgradeHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/MultipleContentLengthHeadersTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringEncoderTest.java create mode 100644 
netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/ReadOnlyHttpHeadersTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ClientCookieDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ClientCookieEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpDataTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpDataTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactoryTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DeleteFileOnExitHookTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DiskFileUploadTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpDataTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostMultiPartRequestDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoderTest.java create mode 100755 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoderTest.java create mode 100644 
netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MemoryFileUploadTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MixedTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrameTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08EncoderDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketCloseStatusTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregatorTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeExceptionTest.java create mode 100644 
netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13Test.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactoryTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtf8FrameValidatorTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProviderTest.java create mode 100644 
netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtilTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshakerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshakerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java create mode 100644 
netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspEncoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoderTest.java create mode 100644 netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdySessionHandlerTest.java create mode 100644 netty-handler-codec-http/src/test/resources/file-01.txt create mode 100644 netty-handler-codec-http/src/test/resources/file-02.txt create mode 100644 netty-handler-codec-http/src/test/resources/file-03.txt create mode 100644 netty-handler-codec-http/src/test/resources/junit-platform.properties create mode 100644 netty-handler-codec-http/src/test/resources/logging.properties create mode 100644 netty-handler-codec-http2/build.gradle create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractInboundHttp2ToHttpAdapterBuilder.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CharSequenceMap.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2FrameWriter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2DataFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Headers.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoder.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PingFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PriorityFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowController.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ResetFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2SettingsAckFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2SettingsFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2UnknownFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2WindowUpdateFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/EmptyHttp2Headers.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDynamicTable.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHeaderField.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackStaticTable.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackUtil.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ChannelDuplexHandler.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodec.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionAdapter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandlerBuilder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataChunkedInput.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListener.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Error.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EventAdapter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Flags.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FlowController.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Frame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameAdapter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListener.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListenerDecorator.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameReader.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameSizePolicy.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStream.java 
create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamEvent.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamException.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamVisitor.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameTypes.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2GoAwayFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Headers.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersEncoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2InboundFrameLogger.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2LifecycleManager.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2LocalFlowController.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MaxRstFrameDecoder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MaxRstFrameListener.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexActiveStreamsException.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2NoMoreStreamIdsException.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2OutboundFrameLogger.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PingFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PriorityFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PromisedRequestVerifier.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PushPromiseFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ResetFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SecurityUtil.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodec.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Settings.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsAckFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsReceivedConsumer.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Stream.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannel.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrap.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelId.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamVisitor.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2WindowUpdateFrame.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerBuilder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterBuilder.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttpToHttp2Adapter.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/MaxCapacityQueue.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/ReadOnlyHttp2Headers.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java create mode 100644 
netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamByteDistributor.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java create mode 100644 netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/package-info.java create mode 100644 netty-handler-codec-http2/src/main/java/module-info.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractDecoratingHttp2ConnectionDecoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractWeightedFairQueueByteDistributorDependencyTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriterTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoderTest.java create mode 100644 
netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowControllerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HashCollisionTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDynamicTableTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackHuffmanTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackStaticTableTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTestCase.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoderTest.java create mode 100644 
netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DataChunkedInputTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DefaultFramesTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListenerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ExceptionTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameRoundtripTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2HeaderBlockIOTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameConnectionDecoderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameListenerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexClientUpgradeTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilderTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerClientUpgradeTest.java create mode 100644 
netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SecurityUtilTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodecTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SettingsTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrapTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelIdTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InOrderHttp2Headers.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/ReadOnlyHttp2HeadersTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java create 
mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestHeaderListener.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorFlowControllerTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorDependencyTreeTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java create mode 100644 netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueRemoteFlowControllerTest.java create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testDuplicateHeaders.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEmpty.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEviction.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testMaxHeaderTableSize.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_1.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_2.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_3.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_4.json create mode 100644 
netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC3.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC4.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC5.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC6.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableEntries.json create mode 100644 netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableResponseEntries.json create mode 100644 netty-handler-codec-http2/src/test/resources/junit-platform.properties create mode 100644 netty-handler-codec-http2/src/test/resources/logging.properties create mode 100644 netty-handler-codec-protobuf/build.gradle create mode 100644 netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java create mode 100644 netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufEncoder.java create mode 100644 netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java create mode 100644 netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32LengthFieldPrepender.java create mode 100644 netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/package-info.java create mode 100644 netty-handler-codec-protobuf/src/main/java/module-info.java create mode 100644 netty-handler-codec/build.gradle create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/AsciiHeadersEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageCodec.java create mode 100644 
netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/CodecException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/CodecOutputList.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/CorruptedFrameException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DateFormatter.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResult.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResultProvider.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DefaultHeaders.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DefaultHeadersImpl.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/Delimiters.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/EmptyHeaders.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/EncoderException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/Headers.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/HeadersUtils.java create mode 100644 
netty-handler-codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/LengthFieldPrepender.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/MessageAggregationException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/MessageAggregator.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToByteEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageCodec.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/PrematureChannelClosureException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoderByteBuf.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/TooLongFrameException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedMessageTypeException.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedValueConverter.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/ValueConverter.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64.java create mode 100644 
netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64Decoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64Dialect.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64Encoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/base64/package-info.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/package-info.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/json/package-info.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/package-info.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CachingClassResolver.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassLoaderClassResolver.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassResolver.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassResolvers.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CompactObjectInputStream.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CompactObjectOutputStream.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CompatibleObjectEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ObjectDecoder.java create mode 100644 
netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ObjectDecoderInputStream.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ObjectEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ObjectEncoderOutputStream.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ReferenceMap.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/SoftReferenceMap.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/WeakReferenceMap.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/package-info.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineSeparator.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringEncoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/string/package-info.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java create mode 100644 netty-handler-codec/src/main/java/io/netty/handler/codec/xml/package-info.java create mode 100644 netty-handler-codec/src/main/java/module-info.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageCodecTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java create mode 100644 
netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketEncoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/DateFormatterTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/DelimiterBasedFrameDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/EmptyHeadersTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/LengthFieldBasedFrameDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/MessageAggregatorTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderByteBufTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/base64/Base64Test.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayEncoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/frame/DelimiterBasedFrameDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldBasedFrameDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldPrependerTest.java create mode 100644 
netty-handler-codec/src/test/java/io/netty/handler/codec/frame/package-info.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/json/JsonObjectDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompactObjectSerializationTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompatibleObjectEncoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/string/LineEncoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringDecoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringEncoderTest.java create mode 100644 netty-handler-codec/src/test/java/io/netty/handler/codec/xml/XmlFrameDecoderTest.java create mode 100644 netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-01.xml create mode 100644 netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-02.xml create mode 100644 netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-03.xml create mode 100644 netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-04.xml create mode 100644 netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-05.xml create mode 100644 netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-06.xml create mode 100644 netty-handler-ssl/build.gradle create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolAccessor.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolConfig.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolNames.java create mode 100644 
netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolUtil.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/AsyncRunnable.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastle.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastleAlpnSslEngine.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastleAlpnSslUtils.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastlePemReader.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/CipherSuiteConverter.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/CipherSuiteFilter.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/Ciphers.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ClientAuth.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/Conscrypt.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/GroupsConverter.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java create mode 100644 
netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java7SslParametersUtils.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java8SslUtils.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslEngine.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslUtils.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkBaseApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkDefaultApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslEngine.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/NotSslRecordException.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSsl.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCertificateCompressionAlgorithm.java create mode 100644 
netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCertificateCompressionConfig.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCertificateException.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslClientContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslClientSessionCache.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslContextOption.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslDefaultApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslEngine.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslEngineMap.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialProvider.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslNpnApplicationProtocolNegotiator.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslPrivateKey.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslPrivateKeyMethod.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslServerContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSession.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionCache.java create mode 100644 
netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionId.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionStats.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionTicketKey.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactory.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslX509TrustManagerWrapper.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemEncoded.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemPrivateKey.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemReader.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemValue.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemX509Certificate.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/PseudoRandomFunction.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslClientContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SignatureAlgorithmConverter.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SniCompletionEvent.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SniHandler.java create mode 100644 
netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslCloseCompletionEvent.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClosedEngineException.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslCompletionEvent.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContext.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextBuilder.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextOption.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeCompletionEvent.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeTimeoutException.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslMasterKeyHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProtocols.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProvider.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslUtils.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/StacklessSSLHandshakeException.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/package-info.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/package-info.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/BouncyCastleSelfSignedCertGenerator.java create mode 100644 
netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryBuilder.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/KeyManagerFactoryWrapper.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyJavaxX509Certificate.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleKeyManagerFactory.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/ThreadLocalInsecureRandom.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/TrustManagerFactoryWrapper.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509KeyManagerWrapper.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509TrustManagerWrapper.java create mode 100644 netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/package-info.java create mode 100644 netty-handler-ssl/src/main/java/module-info.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/CipherSuiteCanaryTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/CipherSuiteConverterTest.java create mode 100644 
netty-handler-ssl/src/test/java/io/netty/handler/ssl/CloseNotifyTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/DelegatingSslContextTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/IdentityCipherSuiteFilterTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslRenegotiateTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslClientContextTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java create mode 100644 
netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslRenegotiateTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslServerContextTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslTestUtils.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/OptionalSslHandlerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ParameterizedSslHandlerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/PemEncodedTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/PemReaderTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/PseudoRandomFunctionTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngineTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/RenegotiateTest.java create mode 100644 
netty-handler-ssl/src/test/java/io/netty/handler/ssl/SSLEngineTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniHandlerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTrustManagerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslErrorTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslHandlerTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslUtilsTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryTest.java create mode 100644 netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/SelfSignedCertificateTest.java create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem create mode 100755 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certificate.sh create mode 100755 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certs.sh create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/localhost_server.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_ca.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_client.p12 create mode 100644 
netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_invalid_client.p12 create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_server.p12 create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/notlocalhost_server.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/openssl.cnf create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsaValidation-user-certs.p12 create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsaValidations-server-keystore.p12 create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-ca-cert.cert create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-signing-ext.txt create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_encrypted.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_unencrypted.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted_empty_pass.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_unencrypted.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1a.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1b.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_2.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_1.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_2.pem create mode 100644 netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_3.pem create mode 100644 netty-handler-ssl/src/test/resources/logging.properties create mode 100644 netty-handler/build.gradle create mode 100644 
netty-handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/address/ResolveAddressHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/address/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/flow/FlowControlHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/flow/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/flush/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/AbstractRemoteAddressFilter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRule.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRuleType.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRule.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRuleComparator.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/RuleBasedIpFilter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/ipfilter/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/logging/ByteBufFormat.java create mode 100644 netty-handler/src/main/java/io/netty/handler/logging/LogLevel.java create mode 100644 netty-handler/src/main/java/io/netty/handler/logging/LoggingHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/logging/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/EthernetPacket.java create mode 100644 
netty-handler/src/main/java/io/netty/handler/pcap/IPPacket.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/PcapHeaders.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/PcapWriter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/State.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/TCPPacket.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/UDPPacket.java create mode 100644 netty-handler/src/main/java/io/netty/handler/pcap/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/ChunkedFile.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/ChunkedInput.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioFile.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioStream.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/ChunkedStream.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/stream/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/IdleState.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/IdleStateEvent.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/IdleStateHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutException.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/TimeoutException.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutException.java create mode 100644 
netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/timeout/package-info.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/ChannelTrafficShapingHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficCounter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficShapingHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/GlobalTrafficShapingHandler.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/TrafficCounter.java create mode 100644 netty-handler/src/main/java/io/netty/handler/traffic/package-info.java create mode 100644 netty-handler/src/main/java/module-info.java create mode 100644 netty-handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/address/ResolveAddressHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/flow/FlowControlHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/flush/FlushConsolidationHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/pcap/CloseDetectingByteBufOutputStream.java create mode 100644 netty-handler/src/test/java/io/netty/handler/pcap/DiscardingStatsOutputStream.java create mode 100644 netty-handler/src/test/java/io/netty/handler/pcap/PcapWriteHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/stream/ChunkedStreamTest.java create mode 100644 
netty-handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/timeout/IdleStateEventTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/timeout/WriteTimeoutHandlerTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/traffic/FileRegionThrottleTest.java create mode 100644 netty-handler/src/test/java/io/netty/handler/traffic/TrafficShapingHandlerTest.java create mode 100644 netty-handler/src/test/resources/logging.properties create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/AsyncSSLPrivateKeyMethod.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/AsyncSSLPrivateKeyMethodAdapter.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/AsyncTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/Buffer.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateCallback.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateCallbackTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateCompressionAlgo.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateRequestedCallback.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateVerifier.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateVerifierTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/Library.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/NativeStaticallyReferencedJniMethods.java create mode 100644 
netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/ResultCallback.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSL.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLContext.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethod.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethodDecryptTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethodSignTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethodTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLSession.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLSessionCache.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLTask.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SessionTicketKey.java create mode 100644 netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SniHostNameMatcher.java create mode 100644 netty-internal-tcnative/src/main/java/module-info.java create mode 100644 netty-jctools/build.gradle create mode 100644 netty-jctools/src/main/java/module-info.java create mode 100644 netty-jctools/src/main/java/org/jctools/counters/Counter.java create mode 100644 netty-jctools/src/main/java/org/jctools/counters/CountersFactory.java create mode 100644 netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounter.java create mode 100644 netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV6.java create mode 100644 netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV8.java create mode 100644 netty-jctools/src/main/java/org/jctools/counters/package-info.java 
create mode 100644 netty-jctools/src/main/java/org/jctools/maps/AbstractEntry.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/ConcurrentAutoTable.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashMap.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashMapLong.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashSet.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/NonBlockingIdentityHashMap.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/NonBlockingSetInt.java create mode 100644 netty-jctools/src/main/java/org/jctools/maps/package-info.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/BaseLinkedQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/BaseMpscLinkedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/BaseSpscLinkedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/ConcurrentCircularArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/ConcurrentSequencedCircularArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/IndexedQueueSizeUtil.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/LinkedArrayQueueUtil.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/LinkedQueueNode.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueueUtil.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddChunk.java create mode 100755 netty-jctools/src/main/java/org/jctools/queues/MpmcArrayQueue.java create mode 100644 
netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddChunk.java create mode 100755 netty-jctools/src/main/java/org/jctools/queues/MpscArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscBlockingConsumerArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscChunkedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscCompoundQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscGrowableArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscLinkedQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddChunk.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/QueueProgressIndicators.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SpmcArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SpscArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SpscChunkedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SpscGrowableArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SpscLinkedQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SpscUnboundedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/SupportsIterator.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicQueueUtil.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicReferenceArrayQueue.java create mode 100644 
netty-jctools/src/main/java/org/jctools/queues/atomic/BaseLinkedAtomicQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/BaseMpscLinkedAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/BaseSpscLinkedAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/LinkedQueueAtomicNode.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/MpmcAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/MpscAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/MpscChunkedAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/MpscGrowableAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/MpscLinkedAtomicQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/MpscUnboundedAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SequencedAtomicReferenceArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SpmcAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SpscAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SpscChunkedAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SpscGrowableAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SpscLinkedAtomicQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/SpscUnboundedAtomicArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/atomic/package-info.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/package-info.java create mode 100644 
netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseLinkedUnpaddedQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseMpscLinkedUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseSpscLinkedUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentCircularUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentSequencedCircularUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/MpmcUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscChunkedUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscGrowableUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscLinkedUnpaddedQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscUnboundedUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/SpmcUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscChunkedUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscGrowableUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscLinkedUnpaddedQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscUnboundedUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscUnpaddedArrayQueue.java create mode 100644 netty-jctools/src/main/java/org/jctools/queues/unpadded/package-info.java create mode 100644 
netty-jctools/src/main/java/org/jctools/util/InternalAPI.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/PaddedAtomicLong.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/PortableJvmInfo.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/Pow2.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/RangeUtil.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/SpscLookAheadUtil.java create mode 100755 netty-jctools/src/main/java/org/jctools/util/UnsafeAccess.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/UnsafeJvmInfo.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/UnsafeLongArrayAccess.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/UnsafeRefArrayAccess.java create mode 100644 netty-jctools/src/main/java/org/jctools/util/package-info.java create mode 100644 netty-jctools/src/test/java/org/jctools/counters/FixedSizeStripedLongCounterTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/KeyAtomicityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/NBHMIdentityKeyAtomicityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/NBHMLongKeyAtomicityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/NBHMRemoveTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/NBHMReplaceTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/NonBlockingHashMapGuavaTestSuite.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckMapTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckSetTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLinearizabilityTest.java create mode 100644 
netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLongLinearizabilityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashSetLinearizabilityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingIdentityHashMapLinearizabilityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingSetIntLinearizabilityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHMID_Tester2.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHML_Tester2.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHM_Tester2.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbhs_tester.java create mode 100644 netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbsi_tester.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcUnboundedXadd.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumer.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumerExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscChunked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscCompound.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscGrowable.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscLinked.java create mode 100644 
netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscUnbounded.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscUnboundedXadd.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestSpmcArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestSpscArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestSpscChunked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestSpscGrowable.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestSpscLinked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestSpscUnbounded.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpscArrayQueueSnapshotTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/MpscUnboundedArrayQueueSnapshotTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcUnboundedXadd.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArrayExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumer.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerArrayExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerOfferBelowThreshold.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunked.java create mode 100644 
netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunkedExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscCompound.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscGrowable.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscLinked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscOfferBelowThreshold.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedXadd.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpmcArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArray.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArrayExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunkedExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowable.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowableExtended.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscLinked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscUnbounded.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTestMpscLinked.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/atomic/MpscAtomicArrayQueueOfferWithThresholdTest.java create mode 100644 
netty-jctools/src/test/java/org/jctools/queues/atomic/MpscLinkedAtomicQueueRemoveTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/atomic/SpscAtomicArrayQueueTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/matchers/Matchers.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/spec/ConcurrentQueueSpec.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/spec/Ordering.java create mode 100644 netty-jctools/src/test/java/org/jctools/queues/spec/Preference.java create mode 100644 netty-jctools/src/test/java/org/jctools/util/AtomicQueueFactory.java create mode 100644 netty-jctools/src/test/java/org/jctools/util/PaddedAtomicLongTest.java create mode 100755 netty-jctools/src/test/java/org/jctools/util/Pow2Test.java create mode 100644 netty-jctools/src/test/java/org/jctools/util/QueueFactory.java create mode 100644 netty-jctools/src/test/java/org/jctools/util/RangeUtilTest.java create mode 100644 netty-jctools/src/test/java/org/jctools/util/TestUtil.java create mode 100644 netty-jctools/src/test/java/org/jctools/util/UnpaddedQueueFactory.java create mode 100644 netty-jctools/src/test/resources/logging.properties create mode 100644 netty-resolver/build.gradle create mode 100644 netty-resolver/src/main/java/io/netty/resolver/AbstractAddressResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/AddressResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/AddressResolverGroup.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/CompositeNameResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/DefaultAddressResolverGroup.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/DefaultNameResolver.java create mode 100644 
netty-resolver/src/main/java/io/netty/resolver/HostsFileEntries.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/HostsFileEntriesProvider.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/HostsFileEntriesResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/HostsFileParser.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/InetNameResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/InetSocketAddressResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/NameResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolverGroup.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/ResolvedAddressTypes.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/RoundRobinInetAddressResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java create mode 100644 netty-resolver/src/main/java/io/netty/resolver/package-info.java create mode 100644 netty-resolver/src/main/java/module-info.java create mode 100644 netty-resolver/src/test/java/io/netty/resolver/DefaultHostsFileEntriesResolverTest.java create mode 100644 netty-resolver/src/test/java/io/netty/resolver/HostsFileEntriesProviderTest.java create mode 100644 netty-resolver/src/test/java/io/netty/resolver/HostsFileParserTest.java create mode 100644 netty-resolver/src/test/java/io/netty/resolver/InetSocketAddressResolverTest.java create mode 100644 netty-resolver/src/test/resources/io/netty/resolver/hosts-unicode create mode 100644 netty-resolver/src/test/resources/logging.properties create mode 100644 netty-util/build.gradle create mode 100644 netty-util/src/main/java/io/netty/util/AbstractConstant.java create mode 100644 
netty-util/src/main/java/io/netty/util/AbstractReferenceCounted.java create mode 100644 netty-util/src/main/java/io/netty/util/AsciiString.java create mode 100644 netty-util/src/main/java/io/netty/util/AsyncMapping.java create mode 100644 netty-util/src/main/java/io/netty/util/Attribute.java create mode 100644 netty-util/src/main/java/io/netty/util/AttributeKey.java create mode 100644 netty-util/src/main/java/io/netty/util/AttributeMap.java create mode 100644 netty-util/src/main/java/io/netty/util/BooleanSupplier.java create mode 100644 netty-util/src/main/java/io/netty/util/ByteProcessor.java create mode 100644 netty-util/src/main/java/io/netty/util/ByteProcessorUtils.java create mode 100644 netty-util/src/main/java/io/netty/util/CharsetUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/Constant.java create mode 100644 netty-util/src/main/java/io/netty/util/ConstantPool.java create mode 100644 netty-util/src/main/java/io/netty/util/DefaultAttributeMap.java create mode 100644 netty-util/src/main/java/io/netty/util/DomainMappingBuilder.java create mode 100644 netty-util/src/main/java/io/netty/util/DomainNameMapping.java create mode 100644 netty-util/src/main/java/io/netty/util/DomainNameMappingBuilder.java create mode 100644 netty-util/src/main/java/io/netty/util/DomainWildcardMappingBuilder.java create mode 100644 netty-util/src/main/java/io/netty/util/HashedWheelTimer.java create mode 100644 netty-util/src/main/java/io/netty/util/HashingStrategy.java create mode 100644 netty-util/src/main/java/io/netty/util/IllegalReferenceCountException.java create mode 100644 netty-util/src/main/java/io/netty/util/IntSupplier.java create mode 100644 netty-util/src/main/java/io/netty/util/Mapping.java create mode 100644 netty-util/src/main/java/io/netty/util/NetUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/NetUtilInitializations.java create mode 100644 netty-util/src/main/java/io/netty/util/NettyRuntime.java create mode 100644 
netty-util/src/main/java/io/netty/util/Recycler.java create mode 100644 netty-util/src/main/java/io/netty/util/ReferenceCountUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/ReferenceCounted.java create mode 100644 netty-util/src/main/java/io/netty/util/ResourceLeak.java create mode 100644 netty-util/src/main/java/io/netty/util/ResourceLeakDetector.java create mode 100644 netty-util/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java create mode 100644 netty-util/src/main/java/io/netty/util/ResourceLeakException.java create mode 100644 netty-util/src/main/java/io/netty/util/ResourceLeakHint.java create mode 100644 netty-util/src/main/java/io/netty/util/ResourceLeakTracker.java create mode 100644 netty-util/src/main/java/io/netty/util/Signal.java create mode 100644 netty-util/src/main/java/io/netty/util/SuppressForbidden.java create mode 100644 netty-util/src/main/java/io/netty/util/ThreadDeathWatcher.java create mode 100644 netty-util/src/main/java/io/netty/util/Timeout.java create mode 100644 netty-util/src/main/java/io/netty/util/Timer.java create mode 100644 netty-util/src/main/java/io/netty/util/TimerTask.java create mode 100644 netty-util/src/main/java/io/netty/util/UncheckedBooleanSupplier.java create mode 100644 netty-util/src/main/java/io/netty/util/Version.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/ByteCollections.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/ByteObjectHashMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/ByteObjectMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/CharCollections.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/CharObjectHashMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/CharObjectMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/IntCollections.java create mode 100644 
netty-util/src/main/java/io/netty/util/collection/IntObjectHashMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/IntObjectMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/LongCollections.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/LongObjectHashMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/LongObjectMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/ShortCollections.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/ShortObjectHashMap.java create mode 100644 netty-util/src/main/java/io/netty/util/collection/ShortObjectMap.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutorGroup.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/AbstractFuture.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/BlockingOperationException.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/CompleteFuture.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorChooserFactory.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorGroup.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultFutureListeners.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultProgressivePromise.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultPromise.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/DefaultThreadFactory.java create mode 100644 
netty-util/src/main/java/io/netty/util/concurrent/EventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/EventExecutorChooserFactory.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/EventExecutorGroup.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/FailedFuture.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocal.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalRunnable.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalThread.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/Future.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/FutureListener.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/GenericFutureListener.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/GenericProgressiveFutureListener.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/GlobalEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ImmediateEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ImmediateExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/MultithreadEventExecutorGroup.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/OrderedEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ProgressiveFuture.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ProgressivePromise.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/Promise.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/PromiseAggregator.java create mode 100644 
netty-util/src/main/java/io/netty/util/concurrent/PromiseCombiner.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/PromiseNotifier.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/PromiseTask.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/RejectedExecutionHandler.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/RejectedExecutionHandlers.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ScheduledFuture.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/SingleThreadEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/SucceededFuture.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ThreadPerTaskExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/ThreadProperties.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/UnaryPromiseNotifier.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutor.java create mode 100644 netty-util/src/main/java/io/netty/util/concurrent/package-info.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/AppendableCharSequence.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ClassInitializerUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/Cleaner.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/CleanerJava9.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ConcurrentSet.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ConstantTimeUtils.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/DefaultPriorityQueue.java create mode 100644 
netty-util/src/main/java/io/netty/util/internal/EmptyArrays.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/EmptyPriorityQueue.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/IntegerHolder.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/InternalThreadLocalMap.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/LongAdderCounter.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/LongCounter.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/MacAddressUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/MathUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/NativeLibraryLoader.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/NativeLibraryUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/NoOpTypeParameterMatcher.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ObjectCleaner.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ObjectPool.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ObjectUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/OutOfDirectMemoryError.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/PendingWrite.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/PlatformDependent.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/PlatformDependent0.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/PriorityQueue.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/PriorityQueueNode.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/PromiseNotificationUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ReadOnlyIterator.java create mode 100644 
netty-util/src/main/java/io/netty/util/internal/RecyclableArrayList.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ReferenceCountUpdater.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ReflectionUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ResourcesUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/SocketUtils.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/StringUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/SuppressJava6Requirement.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/SystemPropertyUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ThreadExecutorMap.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ThreadLocalRandom.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/ThrowableUtil.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/TypeParameterMatcher.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/UnpaddedInternalThreadLocalMap.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/UnstableApi.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/AbstractInternalLogger.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/FormattingTuple.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/InternalLogLevel.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/InternalLogger.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/JdkLogger.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/JdkLoggerFactory.java create mode 100644 
netty-util/src/main/java/io/netty/util/internal/logging/MessageFormatter.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/logging/package-info.java create mode 100644 netty-util/src/main/java/io/netty/util/internal/package-info.java create mode 100644 netty-util/src/main/java/io/netty/util/package-info.java create mode 100644 netty-util/src/main/java/module-info.java create mode 100644 netty-util/src/test/java/io/netty/util/AbstractReferenceCountedTest.java create mode 100644 netty-util/src/test/java/io/netty/util/AsciiStringCharacterTest.java create mode 100644 netty-util/src/test/java/io/netty/util/AsciiStringMemoryTest.java create mode 100644 netty-util/src/test/java/io/netty/util/AttributeKeyTest.java create mode 100644 netty-util/src/test/java/io/netty/util/ConstantPoolTest.java create mode 100644 netty-util/src/test/java/io/netty/util/DefaultAttributeMapTest.java create mode 100644 netty-util/src/test/java/io/netty/util/DomainNameMappingTest.java create mode 100644 netty-util/src/test/java/io/netty/util/DomainWildcardMappingBuilderTest.java create mode 100644 netty-util/src/test/java/io/netty/util/HashedWheelTimerTest.java create mode 100644 netty-util/src/test/java/io/netty/util/NetUtilTest.java create mode 100644 netty-util/src/test/java/io/netty/util/NettyRuntimeTests.java create mode 100644 netty-util/src/test/java/io/netty/util/RecyclerFastThreadLocalTest.java create mode 100644 netty-util/src/test/java/io/netty/util/RecyclerTest.java create mode 100644 netty-util/src/test/java/io/netty/util/ResourceLeakDetectorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/RunInFastThreadLocalThreadExtension.java create mode 100644 netty-util/src/test/java/io/netty/util/ThreadDeathWatcherTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/AbstractScheduledEventExecutorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java create mode 100644 
netty-util/src/test/java/io/netty/util/concurrent/DefaultThreadFactoryTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/GlobalEventExecutorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/ImmediateExecutorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/PromiseAggregatorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/PromiseNotifierTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutorTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/DefaultPriorityQueueTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/MacAddressUtilTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/MathUtilTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/ObjectCleanerTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/ObjectUtilTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/OsClassifiersTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/PlatformDependent0Test.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/PlatformDependentTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/StringUtilTest.java create mode 
100644 netty-util/src/test/java/io/netty/util/internal/SystemPropertyUtilTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/ThreadExecutorMapTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/ThreadLocalRandomTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/TypeParameterMatcherTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/logging/AbstractInternalLoggerTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/logging/InternalLoggerFactoryTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/logging/JdkLoggerFactoryTest.java create mode 100644 netty-util/src/test/java/io/netty/util/internal/logging/MessageFormatterTest.java create mode 100644 netty-util/src/test/resources/logging.properties create mode 100644 netty-zlib/NOTICE.txt create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Adler32.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/CRC32.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Checksum.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Deflate.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Deflater.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/DeflaterOutputStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/GZIPException.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/GZIPHeader.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/GZIPInputStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/GZIPOutputStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/InfBlocks.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/InfCodes.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/InfTree.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Inflate.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Inflater.java 
create mode 100644 netty-zlib/src/main/java/io/netty/zlib/InflaterInputStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/JZlib.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/StaticTree.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/Tree.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/ZInputStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/ZOutputStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/ZStream.java create mode 100644 netty-zlib/src/main/java/io/netty/zlib/ZStreamException.java create mode 100644 netty-zlib/src/main/java/module-info.java create mode 100644 settings.gradle diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..021874f --- /dev/null +++ b/.gitignore @@ -0,0 +1,16 @@ +/.settings +/.classpath +/.project +/.gradle +**/data +**/work +**/logs +**/.idea +**/target +**/out +**/build +.DS_Store +*.iml +*~ +*.key +*.crt diff --git a/LICENSE.txt b/LICENSE.txt new file mode 100644 index 0000000..d645695 --- /dev/null +++ b/LICENSE.txt @@ -0,0 +1,202 @@ + + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. 
For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. 
For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. 
If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. 
You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. 
Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. 
+ + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. diff --git a/NOTICE.txt b/NOTICE.txt new file mode 100644 index 0000000..d546c83 --- /dev/null +++ b/NOTICE.txt @@ -0,0 +1,30 @@ + +The following changes were performed on the original source code: + +- removed slf4j, log4j, log4j2 logging +- removed internal classes for GraalVM (SCM) +- removed internal classes for Blockhound +- removed jetbrains annotations +- private copy of jctools in io.netty.jctools +- removed SecurityManager code +- add module info +- removed lzma dependency (too old for module) +- use JdkZLibDecoder/JdkZlibEncoder in websocketx +- removed JettyAlpnSslEngine +- removed JettyNpnSslEngine +- removed NPN +- use of javax.security.cert.X509Certificate replaced by java.security.cert.Certificate +- private copy of com.jcraft.zlib in io.netty.zlib +- precompiled io.netty.util.collection classes added +- refactored SSL handler to separate subproject netty-handler-ssl +- refactored compression codecs to separate subproject netty-handler-codec-compression +- moved netty-tcnative/openssl-classes to netty-internal-tcnative +- removed logging handler test +- removed native image handler test + +Challenges for Netty build on JDK 21 + +- unmaintained com.jcraft.jzlib +- JCTools uses sun.misc.Unsafe, not VarHandles +- PlatformDependent uses sun.misc.Unsafe +- finalize() in PoolThreadCache, PoolArena diff --git a/build.gradle b/build.gradle new file mode 100644 index 
0000000..2d9bd74 --- /dev/null +++ b/build.gradle @@ -0,0 +1,35 @@ + +plugins { + id 'maven-publish' + id 'signing' + id "io.github.gradle-nexus.publish-plugin" version "2.0.0-rc-1" +} + +wrapper { + gradleVersion = libs.versions.gradle.get() + distributionType = Wrapper.DistributionType.ALL +} + +ext { + user = 'joerg' + name = 'undertow' + description = 'Undertow port forked from quarkus-http' + inceptionYear = '2023' + url = 'https://xbib.org/' + user + '/' + name + scmUrl = 'https://xbib.org/' + user + '/' + name + scmConnection = 'scm:git:git://xbib.org/' + user + '/' + name + '.git' + scmDeveloperConnection = 'scm:git:ssh://forgejo@xbib.org:' + user + '/' + name + '.git' + issueManagementSystem = 'Forgejo' + issueManagementUrl = ext.scmUrl + '/issues' + licenseName = 'The Apache License, Version 2.0' + licenseUrl = 'http://www.apache.org/licenses/LICENSE-2.0.txt' +} + +subprojects { + apply from: rootProject.file('gradle/repositories/maven.gradle') + apply from: rootProject.file('gradle/compile/java.gradle') + apply from: rootProject.file('gradle/test/junit5.gradle') + apply from: rootProject.file('gradle/publish/maven.gradle') +} +apply from: rootProject.file('gradle/publish/sonatype.gradle') +apply from: rootProject.file('gradle/publish/forgejo.gradle') diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 0000000..3e9aad9 --- /dev/null +++ b/gradle.properties @@ -0,0 +1,3 @@ +group = org.xbib +name = netty +version = 4.1.104 diff --git a/gradle/compile/java.gradle b/gradle/compile/java.gradle new file mode 100644 index 0000000..5cb1fda --- /dev/null +++ b/gradle/compile/java.gradle @@ -0,0 +1,37 @@ + +apply plugin: 'java-library' + +java { + toolchain { + languageVersion = JavaLanguageVersion.of(21) + } + modularity.inferModulePath.set(true) + withSourcesJar() + withJavadocJar() +} + +jar { + manifest { + attributes('Implementation-Version': project.version) + attributes('X-Java-Compiler-Version': 
JavaLanguageVersion.of(21).toString()) + } +} + +tasks.withType(JavaCompile) { + options.fork = true + options.forkOptions.jvmArgs += [ + '-Duser.language=en', + '-Duser.country=US', + ] + options.compilerArgs += [ + '-Xlint:all', + '--add-exports=jdk.unsupported/sun.misc=org.xbib.io.netty.jctools', + '--add-exports=java.base/jdk.internal.misc=org.xbib.io.netty.util' + ] + options.encoding = 'UTF-8' +} + +tasks.withType(Javadoc) { + options.addStringOption('Xdoclint:none', '-quiet') + options.encoding = 'UTF-8' +} diff --git a/gradle/documentation/asciidoc.gradle b/gradle/documentation/asciidoc.gradle new file mode 100644 index 0000000..da6dd7e --- /dev/null +++ b/gradle/documentation/asciidoc.gradle @@ -0,0 +1,19 @@ +apply plugin: 'org.xbib.gradle.plugin.asciidoctor' + +asciidoctor { + backends 'html5' + outputDir = file("${rootProject.projectDir}/docs") + separateOutputDirs = false + attributes 'source-highlighter': 'coderay', + idprefix: '', + idseparator: '-', + toc: 'left', + doctype: 'book', + icons: 'font', + encoding: 'utf-8', + sectlink: true, + sectanchors: true, + linkattrs: true, + imagesdir: 'img', + stylesheet: "${projectDir}/src/docs/asciidoc/css/foundation.css" +} diff --git a/gradle/ide/idea.gradle b/gradle/ide/idea.gradle new file mode 100644 index 0000000..5bd2095 --- /dev/null +++ b/gradle/ide/idea.gradle @@ -0,0 +1,8 @@ +apply plugin: 'idea' + +idea { + module { + outputDir file('build/classes/java/main') + testOutputDir file('build/classes/java/test') + } +} diff --git a/gradle/publish/forgejo.gradle b/gradle/publish/forgejo.gradle new file mode 100644 index 0000000..b99b2fb --- /dev/null +++ b/gradle/publish/forgejo.gradle @@ -0,0 +1,16 @@ +if (project.hasProperty('forgeJoToken')) { + publishing { + repositories { + maven { + url 'https://xbib.org/api/packages/joerg/maven' + credentials(HttpHeaderCredentials) { + name = "Authorization" + value = "token ${project.property('forgeJoToken')}" + } + authentication { + 
header(HttpHeaderAuthentication) + } + } + } + } +} diff --git a/gradle/publish/ivy.gradle b/gradle/publish/ivy.gradle new file mode 100644 index 0000000..fe0a848 --- /dev/null +++ b/gradle/publish/ivy.gradle @@ -0,0 +1,27 @@ +apply plugin: 'ivy-publish' + +publishing { + repositories { + ivy { + url = "https://xbib.org/repo" + } + } + publications { + ivy(IvyPublication) { + from components.java + descriptor { + license { + name = 'The Apache License, Version 2.0' + url = 'http://www.apache.org/licenses/LICENSE-2.0.txt' + } + author { + name = 'Jörg Prante' + url = 'http://example.com/users/jane' + } + descriptor.description { + text = rootProject.ext.description + } + } + } + } +} \ No newline at end of file diff --git a/gradle/publish/maven.gradle b/gradle/publish/maven.gradle new file mode 100644 index 0000000..fbdb729 --- /dev/null +++ b/gradle/publish/maven.gradle @@ -0,0 +1,52 @@ + +publishing { + publications { + "${project.name}"(MavenPublication) { + from components.java + pom { + artifactId = project.name + name = project.name + version = project.version + description = rootProject.ext.description + url = rootProject.ext.url + inceptionYear = rootProject.ext.inceptionYear + packaging = 'jar' + organization { + name = 'xbib' + url = 'https://xbib.org' + } + developers { + developer { + id = 'jprante' + name = 'Jörg Prante' + email = 'joergprante@gmail.com' + url = 'https://xbib.org/joerg' + } + } + scm { + url = rootProject.ext.scmUrl + connection = rootProject.ext.scmConnection + developerConnection = rootProject.ext.scmDeveloperConnection + } + issueManagement { + system = rootProject.ext.issueManagementSystem + url = rootProject.ext.issueManagementUrl + } + licenses { + license { + name = rootProject.ext.licenseName + url = rootProject.ext.licenseUrl + distribution = 'repo' + } + } + } + } + } +} + +if (project.hasProperty("signing.keyId")) { + apply plugin: 'signing' + signing { + sign publishing.publications."${project.name}" + } +} diff --git 
a/gradle/publish/sonatype.gradle b/gradle/publish/sonatype.gradle new file mode 100644 index 0000000..02744cd --- /dev/null +++ b/gradle/publish/sonatype.gradle @@ -0,0 +1,12 @@ + +if (project.hasProperty('ossrhUsername') && project.hasProperty('ossrhPassword')) { + nexusPublishing { + repositories { + sonatype { + username = project.property('ossrhUsername') + password = project.property('ossrhPassword') + packageGroup = "org.xbib" + } + } + } +} diff --git a/gradle/quality/checkstyle.gradle b/gradle/quality/checkstyle.gradle new file mode 100644 index 0000000..707900d --- /dev/null +++ b/gradle/quality/checkstyle.gradle @@ -0,0 +1,19 @@ +apply plugin: 'checkstyle' + +tasks.withType(Checkstyle) { + ignoreFailures = true + reports { + xml.getRequired().set(true) + html.getRequired().set(true) + } +} + +checkstyle { + toolVersion = '10.4' + configFile = rootProject.file('gradle/quality/checkstyle.xml') + ignoreFailures = true + showViolations = false + checkstyleMain { + source = sourceSets.main.allSource + } +} diff --git a/gradle/quality/checkstyle.xml b/gradle/quality/checkstyle.xml new file mode 100644 index 0000000..66a9aae --- /dev/null +++ b/gradle/quality/checkstyle.xml @@ -0,0 +1,333 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/gradle/quality/cyclonedx.gradle b/gradle/quality/cyclonedx.gradle new file mode 100644 index 0000000..a6bf41b --- /dev/null +++ b/gradle/quality/cyclonedx.gradle @@ -0,0 +1,11 @@ +cyclonedxBom { + includeConfigs = [ 'runtimeClasspath' ] + skipConfigs = [ 'compileClasspath', 
'testCompileClasspath' ] + projectType = "library" + schemaVersion = "1.4" + destination = file("build/reports") + outputName = "bom" + outputFormat = "json" + includeBomSerialNumber = true + componentVersion = "2.0.0" +} diff --git a/gradle/quality/pmd.gradle b/gradle/quality/pmd.gradle new file mode 100644 index 0000000..55fcfda --- /dev/null +++ b/gradle/quality/pmd.gradle @@ -0,0 +1,17 @@ + +apply plugin: 'pmd' + +tasks.withType(Pmd) { + ignoreFailures = true + reports { + xml.getRequired().set(true) + html.getRequired().set(true) + } +} + +pmd { + ignoreFailures = true + consoleOutput = false + toolVersion = "6.51.0" + ruleSetFiles = rootProject.files('gradle/quality/pmd/category/java/bestpractices.xml') +} diff --git a/gradle/quality/pmd/category/java/bestpractices.xml b/gradle/quality/pmd/category/java/bestpractices.xml new file mode 100644 index 0000000..6bf15a0 --- /dev/null +++ b/gradle/quality/pmd/category/java/bestpractices.xml @@ -0,0 +1,1650 @@ + + + + + + Rules which enforce generally accepted best practices. + + + + + The abstract class does not contain any abstract methods. An abstract class suggests + an incomplete implementation, which is to be completed by subclasses implementing the + abstract methods. If the class is intended to be used as a base class only (not to be instantiated + directly) a protected constructor can be provided prevent direct instantiation. + + 3 + + + + + + + + + + + + + + + Instantiation by way of private constructors from outside of the constructor's class often causes the + generation of an accessor. A factory method, or non-privatization of the constructor can eliminate this + situation. The generated class file is actually an interface. It gives the accessing class the ability + to invoke a new hidden package scope constructor that takes the interface as a supplementary parameter. + This turns a private constructor effectively into one with package scope, and is challenging to discern. 
+ + 3 + + + + + + + + When accessing a private field / method from another class, the Java compiler will generate a accessor methods + with package-private visibility. This adds overhead, and to the dex method count on Android. This situation can + be avoided by changing the visibility of the field / method from private to package-private. + + 3 + + + + + + + + Constructors and methods receiving arrays should clone objects and store the copy. + This prevents future changes from the user from affecting the original array. + + 3 + + + + + + + + Avoid printStackTrace(); use a logger call instead. + + 3 + + + + + + + + + + + + + + + Reassigning loop variables can lead to hard-to-find bugs. Prevent or limit how these variables can be changed. + + In foreach-loops, configured by the `foreachReassign` property: + - `deny`: Report any reassignment of the loop variable in the loop body. _This is the default._ + - `allow`: Don't check the loop variable. + - `firstOnly`: Report any reassignments of the loop variable, except as the first statement in the loop body. + _This is useful if some kind of normalization or clean-up of the value before using is permitted, but any other change of the variable is not._ + + In for-loops, configured by the `forReassign` property: + - `deny`: Report any reassignment of the control variable in the loop body. _This is the default._ + - `allow`: Don't check the control variable. + - `skip`: Report any reassignments of the control variable, except conditional increments/decrements (`++`, `--`, `+=`, `-=`). + _This prevents accidental reassignments or unconditional increments of the control variable._ + + 3 + + + + + + + + Reassigning values to incoming parameters is not recommended. Use temporary local variables instead. + + 2 + + + + + + + + StringBuffers/StringBuilders can grow considerably, and so may become a source of memory leaks + if held within objects with long lifetimes. 
+ + 3 + + + + + + + + + + + + + + + Application with hard-coded IP addresses can become impossible to deploy in some cases. + Externalizing IP adresses is preferable. + + 3 + + + + + + + + Always check the return values of navigation methods (next, previous, first, last) of a ResultSet. + If the value return is 'false', it should be handled properly. + + 3 + + + + + + + + Avoid constants in interfaces. Interfaces should define types, constants are implementation details + better placed in classes or enums. See Effective Java, item 19. + + 3 + + + + + + + + + + + + + + + + By convention, the default label should be the last label in a switch statement. + + 3 + + + + + + + + + + + + + + + Reports loops that can be safely replaced with the foreach syntax. The rule considers loops over + lists, arrays and iterators. A loop is safe to replace if it only uses the index variable to + access an element of the list or array, only has one update statement, and loops through *every* + element of the list or array left to right. + + 3 + + l) { + for (int i = 0; i < l.size(); i++) { // pre Java 1.5 + System.out.println(l.get(i)); + } + + for (String s : l) { // post Java 1.5 + System.out.println(s); + } + } +} +]]> + + + + + + Having a lot of control variables in a 'for' loop makes it harder to see what range of values + the loop iterates over. By default this rule allows a regular 'for' loop with only one variable. + + 3 + + + + //ForInit/LocalVariableDeclaration[count(VariableDeclarator) > $maximumVariables] + + + + + + + + + + Whenever using a log level, one should check if the loglevel is actually enabled, or + otherwise skip the associate String creation and manipulation. + + 2 + + + + + + + + In JUnit 3, test suites are indicated by the suite() method. In JUnit 4, suites are indicated + through the @RunWith(Suite.class) annotation. + + 3 + + + + + + + + + + + + + + + In JUnit 3, the tearDown method was used to clean up all data entities required in running tests. 
+ JUnit 4 skips the tearDown method and executes all methods annotated with @After after running each test. + JUnit 5 introduced @AfterEach and @AfterAll annotations to execute methods after each test or after all tests in the class, respectively. + + 3 + + + + + + + + + + + + + + + In JUnit 3, the setUp method was used to set up all data entities required in running tests. + JUnit 4 skips the setUp method and executes all methods annotated with @Before before all tests. + JUnit 5 introduced @BeforeEach and @BeforeAll annotations to execute methods before each test or before all tests in the class, respectively. + + 3 + + + + + + + + + + + + + + + In JUnit 3, the framework executed all methods which started with the word test as a unit test. + In JUnit 4, only methods annotated with the @Test annotation are executed. + In JUnit 5, one of the following annotations should be used for tests: @Test, @RepeatedTest, @TestFactory, @TestTemplate or @ParameterizedTest. + + 3 + + + + + + + + + + + + + + + + + JUnit assertions should include an informative message - i.e., use the three-argument version of + assertEquals(), not the two-argument version. + + 3 + + + + + + + + Unit tests should not contain too many asserts. Many asserts are indicative of a complex test, for which + it is harder to verify correctness. Consider breaking the test scenario into multiple, shorter test scenarios. + Customize the maximum number of assertions used by this Rule to suit your needs. + + This rule checks for JUnit4, JUnit5 and TestNG Tests, as well as methods starting with "test". + + 3 + + + + + $maximumAsserts] +]]> + + + + + + + + + + + JUnit tests should include at least one assertion. This makes the tests more robust, and using assert + with messages provide the developer a clearer idea of what the test does. + + 3 + + + + + + + + In JUnit4, use the @Test(expected) annotation to denote tests that should throw exceptions. 
+ + 3 + + + + + + + + The use of implementation types (i.e., HashSet) as object references limits your ability to use alternate + implementations in the future as requirements change. Whenever available, referencing objects + by their interface types (i.e, Set) provides much more flexibility. + + 3 + + list = new ArrayList<>(); + + public HashSet getFoo() { + return new HashSet(); + } + + // preferred approach + private List list = new ArrayList<>(); + + public Set getFoo() { + return new HashSet(); + } +} +]]> + + + + + + Exposing internal arrays to the caller violates object encapsulation since elements can be + removed or replaced outside of the object that owns it. It is safer to return a copy of the array. + + 3 + + + + + + + + + Annotating overridden methods with @Override ensures at compile time that + the method really overrides one, which helps refactoring and clarifies intent. + + 3 + + + + + + + + Java allows the use of several variables declaration of the same type on one line. However, it + can lead to quite messy code. This rule looks for several declarations on the same line. + + 4 + + + + 1] + [$strictMode or count(distinct-values(VariableDeclarator/@BeginLine)) != count(VariableDeclarator)] +| +//FieldDeclaration + [count(VariableDeclarator) > 1] + [$strictMode or count(distinct-values(VariableDeclarator/@BeginLine)) != count(VariableDeclarator)] +]]> + + + + + + + + + + + + + Position literals first in comparisons, if the second argument is null then NullPointerExceptions + can be avoided, they will just return false. + + 3 + + + + + + + + + + + + + + + Position literals first in comparisons, if the second argument is null then NullPointerExceptions + can be avoided, they will just return false. + + 3 + + + + + + + + + + + + + + + Throwing a new exception from a catch block without passing the original exception into the + new exception will cause the original stack trace to be lost making it difficult to debug + effectively. 
+ + 3 + + + + + + + + Consider replacing Enumeration usages with the newer java.util.Iterator + + 3 + + + + + + + + + + + + + + + Consider replacing Hashtable usage with the newer java.util.Map if thread safety is not required. + + 3 + + + //Type/ReferenceType/ClassOrInterfaceType[@Image='Hashtable'] + + + + + + + + + + Consider replacing Vector usages with the newer java.util.ArrayList if expensive thread-safe operations are not required. + + 3 + + + //Type/ReferenceType/ClassOrInterfaceType[@Image='Vector'] + + + + + + + + + + All switch statements should include a default option to catch any unspecified values. + + 3 + + + + + + + + + + + + + + References to System.(out|err).print are usually intended for debugging purposes and can remain in + the codebase even in production code. By using a logger one can enable/disable this behaviour at + will (and by priority) and avoid clogging the Standard out log. + + 2 + + + + + + + + + + + + + + + Avoid passing parameters to methods or constructors without actually referencing them in the method body. + + 3 + + + + + + + + Avoid unused import statements to prevent unwanted dependencies. + This rule will also find unused on demand imports, i.e. import com.foo.*. + + 4 + + + + + + + + Detects when a local variable is declared and/or assigned, but not used. + + 3 + + + + + + + + Detects when a private field is declared and/or assigned a value, but not used. + + 3 + + + + + + + + Unused Private Method detects when a private method is declared but is unused. + + 3 + + + + + + + + This rule detects JUnit assertions in object equality. These assertions should be made by more specific methods, like assertEquals. + + 3 + + + + + + + + + + + + + + + This rule detects JUnit assertions in object references equality. These assertions should be made by + more specific methods, like assertNull, assertNotNull. + + 3 + + + + + + + + + + + + + + + This rule detects JUnit assertions in object references equality. 
These assertions should be made + by more specific methods, like assertSame, assertNotSame. + + 3 + + + + + + + + + + + + + + + When asserting a value is the same as a literal or Boxed boolean, use assertTrue/assertFalse, instead of assertEquals. + + 3 + + + + + + + + + + + + + + + The isEmpty() method on java.util.Collection is provided to determine if a collection has any elements. + Comparing the value of size() to 0 does not convey intent as well as the isEmpty() method. + + 3 + + + + + + + + Java 7 introduced the try-with-resources statement. This statement ensures that each resource is closed at the end + of the statement. It avoids the need of explicitly closing the resources in a finally block. Additionally exceptions + are better handled: If an exception occurred both in the `try` block and `finally` block, then the exception from + the try block was suppressed. With the `try`-with-resources statement, the exception thrown from the try-block is + preserved. + + 3 + + + + + + + + + + + + + + + + + Java 5 introduced the varargs parameter declaration for methods and constructors. This syntactic + sugar provides flexibility for users of these methods and constructors, allowing them to avoid + having to deal with the creation of an array. 
+ + 4 + + + + + + + + + + + + + + diff --git a/gradle/quality/pmd/category/java/categories.properties b/gradle/quality/pmd/category/java/categories.properties new file mode 100644 index 0000000..8ef5eac --- /dev/null +++ b/gradle/quality/pmd/category/java/categories.properties @@ -0,0 +1,10 @@ + +rulesets.filenames=\ + category/java/bestpractices.xml,\ + category/java/codestyle.xml,\ + category/java/design.xml,\ + category/java/documentation.xml,\ + category/java/errorprone.xml,\ + category/java/multithreading.xml,\ + category/java/performance.xml,\ + category/java/security.xml diff --git a/gradle/quality/pmd/category/java/codestyle.xml b/gradle/quality/pmd/category/java/codestyle.xml new file mode 100644 index 0000000..ac2f0a0 --- /dev/null +++ b/gradle/quality/pmd/category/java/codestyle.xml @@ -0,0 +1,2176 @@ + + + + + + Rules which enforce a specific coding style. + + + + + Abstract classes should be named 'AbstractXXX'. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by {% rule java/codestyle/ClassNamingConventions %}. + + 3 + + + + + + + + + + + + + + + + + + 3 + + + + + + + + Avoid using dollar signs in variable/method/class/interface names. + + 3 + + + + + + + Avoid using final local variables, turn them into fields. + 3 + + + + + + + + + + + + + + + Prefixing parameters by 'in' or 'out' pollutes the name of the parameters and reduces code readability. + To indicate whether or not a parameter will be modify in a method, its better to document method + behavior with Javadoc. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the more general rule {% rule java/codestyle/FormalParameterNamingConventions %}. + + 4 + + + + + + + + + + + + + + + + + + Do not use protected fields in final classes since they cannot be subclassed. + Clarify your intent by using private or package access modifiers instead. 
+ + 3 + + + + + + + + + + + + + + + Do not use protected methods in most final classes since they cannot be subclassed. This should + only be allowed in final classes that extend other classes with protected methods (whose + visibility cannot be reduced). Clarify your intent by using private or package access modifiers instead. + + 3 + + + + + + + + + + + + + + + Unnecessary reliance on Java Native Interface (JNI) calls directly reduces application portability + and increases the maintenance burden. + + 2 + + + //Name[starts-with(@Image,'System.loadLibrary')] + + + + + + + + + + Methods that return boolean results should be named as predicate statements to denote this. + I.e, 'isReady()', 'hasValues()', 'canCommit()', 'willFail()', etc. Avoid the use of the 'get' + prefix for these methods. + + 4 + + + + + + + + + + + + + + + + It is a good practice to call super() in a constructor. If super() is not called but + another constructor (such as an overloaded constructor) is called, this rule will not report it. + + 3 + + + + 0 ] +/ClassOrInterfaceBody + /ClassOrInterfaceBodyDeclaration + /ConstructorDeclaration[ count (.//ExplicitConstructorInvocation)=0 ] +]]> + + + + + + + + + + + Configurable naming conventions for type declarations. This rule reports + type declarations which do not match the regex that applies to their + specific kind (e.g. enum or interface). Each regex can be configured through + properties. + + By default this rule uses the standard Java naming convention (Pascal case), + and reports utility class names not ending with 'Util'. + + 1 + + + + + + + + To avoid mistakes if we want that a Method, Constructor, Field or Nested class have a default access modifier + we must add a comment at the beginning of it's declaration. + By default the comment must be /* default */ or /* package */, if you want another, you have to provide a regular expression. + This rule ignores by default all cases that have a @VisibleForTesting annotation. 
Use the + property "ignoredAnnotations" to customize the recognized annotations. + + 3 + + + + + + + + Avoid negation within an "if" expression with an "else" clause. For example, rephrase: + `if (x != y) diff(); else same();` as: `if (x == y) same(); else diff();`. + + Most "if (x != y)" cases without an "else" are often return cases, so consistent use of this + rule makes the code easier to read. Also, this resolves trivial ordering problems, such + as "does the error case go first?" or "does the common case go first?". + + 3 + + + + + + + + Enforce a policy for braces on control statements. It is recommended to use braces on 'if ... else' + statements and loop statements, even if they are optional. This usually makes the code clearer, and + helps prepare the future when you need to add another statement. That said, this rule lets you control + which statements are required to have braces via properties. + + From 6.2.0 on, this rule supersedes WhileLoopMustUseBraces, ForLoopMustUseBraces, IfStmtMustUseBraces, + and IfElseStmtMustUseBraces. + + 3 + + + + + + + + + + + + + 1 + or (some $stmt (: in only the block statements until the next label :) + in following-sibling::BlockStatement except following-sibling::SwitchLabel[1]/following-sibling::BlockStatement + satisfies not($stmt/Statement/Block))] + ]]> + + + + + + + + + + Use explicit scoping instead of accidental usage of default package private level. + The rule allows methods and fields annotated with Guava's @VisibleForTesting. + + 3 + + + + + + + + + + + + Avoid importing anything from the package 'java.lang'. These classes are automatically imported (JLS 7.5.3). + + 4 + + + + + + + + Duplicate or overlapping import statements should be avoided. + + 4 + + + + + + + + Empty or auto-generated methods in an abstract class should be tagged as abstract. This helps to remove their inapproprate + usage by developers who should be implementing their own versions in the concrete subclasses. 
+ + 1 + + + + + + + + + + + + + + No need to explicitly extend Object. + 4 + + + + + + + + + + + + + + + Fields should be declared at the top of the class, before any method declarations, constructors, initializers or inner classes. + + 3 + + + + + + + + + Configurable naming conventions for field declarations. This rule reports variable declarations + which do not match the regex that applies to their specific kind ---e.g. constants (static final), + enum constant, final field. Each regex can be configured through properties. + + By default this rule uses the standard Java naming convention (Camel case), and uses the ALL_UPPER + convention for constants and enum constants. + + 1 + + + + + + + + Some for loops can be simplified to while loops, this makes them more concise. + + 3 + + + + + + + + + + + + + + + Avoid using 'for' statements without using curly braces. If the code formatting or + indentation is lost then it becomes difficult to separate the code being controlled + from the rest. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/codestyle/ControlStatementBraces %}. + + 3 + + + //ForStatement[not(Statement/Block)] + + + + + + + + + + Configurable naming conventions for formal parameters of methods and lambdas. + This rule reports formal parameters which do not match the regex that applies to their + specific kind (e.g. lambda parameter, or final formal parameter). Each regex can be + configured through properties. + + By default this rule uses the standard Java naming convention (Camel case). + + 1 + + lambda1 = s_str -> { }; + + // lambda parameters with an explicit type can be configured separately + Consumer lambda1 = (String str) -> { }; + + } + + } + ]]> + + + + + + Names for references to generic values should be limited to a single uppercase letter. + + 4 + + + + 1 + or + string:upper-case(@Image) != @Image +] +]]> + + + + + extends BaseDao { + // This is ok... 
+} + +public interface GenericDao { + // Also this +} + +public interface GenericDao { + // 'e' should be an 'E' +} + +public interface GenericDao { + // 'EF' is not ok. +} +]]> + + + + + + + Identical `catch` branches use up vertical space and increase the complexity of code without + adding functionality. It's better style to collapse identical branches into a single multi-catch + branch. + + 3 + + + + + + + + Avoid using if..else statements without using surrounding braces. If the code formatting + or indentation is lost then it becomes difficult to separate the code being controlled + from the rest. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/codestyle/ControlStatementBraces %}. + + 3 + + + + + + + + + + + + + + + Avoid using if statements without using braces to surround the code block. If the code + formatting or indentation is lost then it becomes difficult to separate the code being + controlled from the rest. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/codestyle/ControlStatementBraces %}. + + 3 + + + + + + + + + + + + + + + This rule finds Linguistic Naming Antipatterns. It checks for fields, that are named, as if they should + be boolean but have a different type. It also checks for methods, that according to their name, should + return a boolean, but don't. Further, it checks, that getters return something and setters won't. + Finally, it checks that methods, that start with "to" - so called transform methods - actually return + something, since according to their name, they should convert or transform one object into another. + There is additionally an option, to check for methods that contain "To" in their name - which are + also transform methods. However, this is disabled by default, since this detection is prone to + false positives. 
+ + For more information, see [Linguistic Antipatterns - What They Are and How + Developers Perceive Them](https://doi.org/10.1007/s10664-014-9350-8). + + 3 + + + + + + + + The Local Home interface of a Session EJB should be suffixed by 'LocalHome'. + + 4 + + + + + + + + + + + + + + + The Local Interface of a Session EJB should be suffixed by 'Local'. + + 4 + + + + + + + + + + + + + + + A local variable assigned only once can be declared final. + + 3 + + + + + + + + Configurable naming conventions for local variable declarations and other locally-scoped + variables. This rule reports variable declarations which do not match the regex that applies to their + specific kind (e.g. final variable, or catch-clause parameter). Each regex can be configured through + properties. + + By default this rule uses the standard Java naming convention (Camel case). + + 1 + + + + + + + + Fields, formal arguments, or local variable names that are too long can make the code difficult to follow. + + 3 + + + + + $minimum] +]]> + + + + + + + + + + + The EJB Specification states that any MessageDrivenBean or SessionBean should be suffixed by 'Bean'. + + 4 + + + + + + + + + + + + + + + A method argument that is never re-assigned within the method can be declared final. + + 3 + + + + + + + + Configurable naming conventions for method declarations. This rule reports + method declarations which do not match the regex that applies to their + specific kind (e.g. JUnit test or native method). Each regex can be + configured through properties. + + By default this rule uses the standard Java naming convention (Camel case). + + 1 + + + + + + + + Detects when a non-field has a name starting with 'm_'. This usually denotes a field and could be confusing. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the more general rule + {% rule java/codestyle/LocalVariableNamingConventions %}. 
+ + 3 + + + + + + + + + + + + + + + Detects when a class or interface does not have a package definition. + + 3 + + + //ClassOrInterfaceDeclaration[count(preceding::PackageDeclaration) = 0] + + + + + + + + + + Since Java 1.7, numeric literals can use underscores to separate digits. This rule enforces that + numeric literals above a certain length use these underscores to increase readability. + + The rule only supports decimal (base 10) literals for now. The acceptable length under which literals + are not required to have underscores is configurable via a property. Even under that length, underscores + that are misplaced (not making groups of 3 digits) are reported. + + 3 + + + + + + + + + + + + + + + + + A method should have only one exit point, and that should be the last statement in the method. + + 3 + + 0) { + return "hey"; // first exit + } + return "hi"; // second exit + } +} +]]> + + + + + + Detects when a package definition contains uppercase characters. + + 3 + + + //PackageDeclaration/Name[lower-case(@Image)!=@Image] + + + + + + + + + + Checks for variables that are defined before they might be used. A reference is deemed to be premature if it is created right before a block of code that doesn't use it that also has the ability to return or throw an exception. + + 3 + + + + + + + + Remote Interface of a Session EJB should not have a suffix. + + 4 + + + + + + + + + + + + + + + A Remote Home interface type of a Session EJB should be suffixed by 'Home'. + + 4 + + + + + + + + + + + + + + + Short Classnames with fewer than e.g. five characters are not recommended. + + 4 + + + + + + + + + + + + + + + + Method names that are very short are not helpful to the reader. + + 3 + + + + + + + + + + + + + + + + Fields, local variables, or parameter names that are very short are not helpful to the reader. 
+ + 3 + + + + + + + + + + + + + + + + + Field names using all uppercase characters - Sun's Java naming conventions indicating constants - should + be declared as final. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the more general rule {% rule java/codestyle/FieldNamingConventions %}. + + 3 + + + + + + + + + + + + + + + If you overuse the static import feature, it can make your program unreadable and + unmaintainable, polluting its namespace with all the static members you import. + Readers of your code (including you, a few months after you wrote it) will not know + which class a static member comes from (Sun 1.5 Language Guide). + + 3 + + + + + $maximumStaticImports] +]]> + + + + + + + + + + + + Avoid the use of value in annotations when it's the only element. + + 3 + + + + + + + + + This rule detects when a constructor is not necessary; i.e., when there is only one constructor and the + constructor is identical to the default constructor. The default constructor should has same access + modifier as the declaring class. In an enum type, the default constructor is implicitly private. + + 3 + + + + + + + + Import statements allow the use of non-fully qualified names. The use of a fully qualified name + which is covered by an import statement is redundant. Consider using the non-fully qualified name. + + 4 + + + + + + + + Avoid the creation of unnecessary local variables + + 3 + + + + + + + + Fields in interfaces and annotations are automatically `public static final`, and methods are `public abstract`. + Classes, interfaces or annotations nested in an interface or annotation are automatically `public static` + (all nested interfaces and annotations are automatically static). + Nested enums are automatically `static`. + For historical reasons, modifiers which are implied by the context are accepted by the compiler, but are superfluous. + + 3 + + + + + + + + Avoid the use of unnecessary return statements. 
+ + 3 + + + + + + + + Use the diamond operator to let the type be inferred automatically. With the Diamond operator it is possible + to avoid duplication of the type parameters. + Instead, the compiler is now able to infer the parameter types for constructor calls, + which makes the code also more readable. + + 3 + + + + + + + + + strings = new ArrayList(); // unnecessary duplication of type parameters +List stringsWithDiamond = new ArrayList<>(); // using the diamond operator is more concise +]]> + + + + + Useless parentheses should be removed. + 4 + + + + 1] + /PrimaryPrefix/Expression + [not(./CastExpression)] + [not(./ConditionalExpression)] + [not(./AdditiveExpression)] + [not(./AssignmentOperator)] +| +//Expression[not(parent::PrimaryPrefix)]/PrimaryExpression[count(*)=1] + /PrimaryPrefix/Expression +| +//Expression/ConditionalAndExpression/PrimaryExpression/PrimaryPrefix/Expression[ + count(*)=1 and + count(./CastExpression)=0 and + count(./EqualityExpression/MultiplicativeExpression)=0 and + count(./ConditionalExpression)=0 and + count(./ConditionalOrExpression)=0] +| +//Expression/ConditionalOrExpression/PrimaryExpression/PrimaryPrefix/Expression[ + count(*)=1 and + not(./CastExpression) and + not(./ConditionalExpression) and + not(./EqualityExpression/MultiplicativeExpression)] +| +//Expression/ConditionalExpression/PrimaryExpression/PrimaryPrefix/Expression[ + count(*)=1 and + not(./CastExpression) and + not(./EqualityExpression)] +| +//Expression/AdditiveExpression[not(./PrimaryExpression/PrimaryPrefix/Literal[@StringLiteral='true'])] + /PrimaryExpression[1]/PrimaryPrefix/Expression[ + count(*)=1 and + not(./CastExpression) and + not(./AdditiveExpression[@Image = '-']) and + not(./ShiftExpression) and + not(./RelationalExpression) and + not(./InstanceOfExpression) and + not(./EqualityExpression) and + not(./AndExpression) and + not(./ExclusiveOrExpression) and + not(./InclusiveOrExpression) and + not(./ConditionalAndExpression) and + 
not(./ConditionalOrExpression) and + not(./ConditionalExpression)] +| +//Expression/EqualityExpression/PrimaryExpression/PrimaryPrefix/Expression[ + count(*)=1 and + not(./CastExpression) and + not(./AndExpression) and + not(./InclusiveOrExpression) and + not(./ExclusiveOrExpression) and + not(./ConditionalExpression) and + not(./ConditionalAndExpression) and + not(./ConditionalOrExpression) and + not(./EqualityExpression)] +]]> + + + + + + + + + + + Reports qualified this usages in the same class. + + 3 + + + + + + + + + + + + + + + A variable naming conventions rule - customize this to your liking. Currently, it + checks for final variables that should be fully capitalized and non-final variables + that should not include underscores. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the more general rules {% rule java/codestyle/FieldNamingConventions %}, + {% rule java/codestyle/FormalParameterNamingConventions %}, and + {% rule java/codestyle/LocalVariableNamingConventions %}. + + 1 + + + + + + + + Avoid using 'while' statements without using braces to surround the code block. If the code + formatting or indentation is lost then it becomes difficult to separate the code being + controlled from the rest. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/codestyle/ControlStatementBraces %}. + + 3 + + + //WhileStatement[not(Statement/Block)] + + + + + + + + diff --git a/gradle/quality/pmd/category/java/design.xml b/gradle/quality/pmd/category/java/design.xml new file mode 100644 index 0000000..ded3d80 --- /dev/null +++ b/gradle/quality/pmd/category/java/design.xml @@ -0,0 +1,1657 @@ + + + + + + Rules that help you discover design issues. + + + + + If an abstract class does not provides any methods, it may be acting as a simple data container + that is not meant to be instantiated. 
In this case, it is probably better to use a private or + protected constructor in order to prevent instantiation than make the class misleadingly abstract. + + 1 + + + + + + + + + + + + + + + Avoid catching generic exceptions such as NullPointerException, RuntimeException, Exception in try-catch block + + 3 + + + + + + + + + + + + + + + Avoid creating deeply nested if-then statements since they are harder to read and error-prone to maintain. + + 3 + + y) { + if (y>z) { + if (z==x) { + // !! too deep + } + } + } + } +} +]]> + + + + + + Catch blocks that merely rethrow a caught exception only add to code size and runtime complexity. + + 3 + + + + + + + + + + + + + + + Catch blocks that merely rethrow a caught exception wrapped inside a new instance of the same type only add to + code size and runtime complexity. + + 3 + + + + + + + + + + + + + + + *Effective Java, 3rd Edition, Item 72: Favor the use of standard exceptions* +> +>Arguably, every erroneous method invocation boils down to an illegal argument or state, +but other exceptions are standardly used for certain kinds of illegal arguments and states. +If a caller passes null in some parameter for which null values are prohibited, convention dictates that +NullPointerException be thrown rather than IllegalArgumentException. + +To implement that, you are encouraged to use `java.util.Objects.requireNonNull()` +(introduced in Java 1.7). This method is designed primarily for doing parameter +validation in methods and constructors with multiple parameters. + +Your parameter validation could thus look like the following: +``` +public class Foo { + private String exampleValue; + + void setExampleValue(String exampleValue) { + // check, throw and assignment in a single standard call + this.exampleValue = Objects.requireNonNull(exampleValue, "exampleValue must not be null!"); + } + } +``` +]]> + + 1 + + + + + + + + + + + + + + + Avoid throwing certain exception types. 
Rather than throw a raw RuntimeException, Throwable, + Exception, or Error, use a subclassed exception or error instead. + + 1 + + + + + + + + + + + + + + + A class with only private constructors should be final, unless the private constructor + is invoked by a inner class. + + 1 + + + + = 1 ] +[count(./ClassOrInterfaceBody/ClassOrInterfaceBodyDeclaration/ConstructorDeclaration[(@Public = 'true') or (@Protected = 'true') or (@PackagePrivate = 'true')]) = 0 ] +[not(.//ClassOrInterfaceDeclaration)] +]]> + + + + + + + + + + + Sometimes two consecutive 'if' statements can be consolidated by separating their conditions with a boolean short-circuit operator. + + 3 + + + + + + + + + + + + + + + This rule counts the number of unique attributes, local variables, and return types within an object. + A number higher than the specified threshold can indicate a high degree of coupling. + + 3 + + + + + + + = 10. +Additionnally, classes with many methods of moderate complexity get reported as well once the total of their +methods' complexities reaches 80, even if none of the methods was directly reported. + +Reported methods should be broken down into several smaller methods. Reported classes should probably be broken down +into subcomponents.]]> + + 3 + + + + + + + + Data Classes are simple data holders, which reveal most of their state, and + without complex functionality. The lack of functionality may indicate that + their behaviour is defined elsewhere, which is a sign of poor data-behaviour + proximity. By directly exposing their internals, Data Classes break encapsulation, + and therefore reduce the system's maintainability and understandability. Moreover, + classes tend to strongly rely on their data representation, which makes for a brittle + design. + + Refactoring a Data Class should focus on restoring a good data-behaviour proximity. In + most cases, that means moving the operations defined on the data back into the class. 
+ In some other cases it may make sense to remove entirely the class and move the data + into the former client classes. + + 3 + + + + + + + + Errors are system exceptions. Do not extend them. + + 3 + + + + + + + + + + + + + + + Using Exceptions as form of flow control is not recommended as they obscure true exceptions when debugging. + Either add the necessary validation or use an alternate control structure. + + 3 + + + + + + + + Excessive class file lengths are usually indications that the class may be burdened with excessive + responsibilities that could be provided by external classes or functions. In breaking these methods + apart the code becomes more manageable and ripe for reuse. + + 3 + + + + + + + + A high number of imports can indicate a high degree of coupling within an object. This rule + counts the number of unique imports and reports a violation if the count is above the + user-specified threshold. + + 3 + + + + + + + + When methods are excessively long this usually indicates that the method is doing more than its + name/signature might suggest. They also become challenging for others to digest since excessive + scrolling causes readers to lose focus. + Try to reduce the method length by creating helper methods and removing any copy/pasted code. + + 3 + + + + + + + + Methods with numerous parameters are a challenge to maintain, especially if most of them share the + same datatype. These situations usually denote the need for new objects to wrap the numerous parameters. + + 3 + + + + + + + + Classes with large numbers of public methods and attributes require disproportionate testing efforts + since combinational side effects grow rapidly and increase risk. Refactoring these classes into + smaller ones not only increases testability and reliability but also allows new variations to be + developed easily. 
+ + 3 + + + + + + + + If a final field is assigned to a compile-time constant, it could be made static, thus saving overhead + in each object at runtime. + + 3 + + + + + + + + + + + + + + + The God Class rule detects the God Class design flaw using metrics. God classes do too many things, + are very big and overly complex. They should be split apart to be more object-oriented. + The rule uses the detection strategy described in "Object-Oriented Metrics in Practice". + The violations are reported against the entire class. + + See also the references: + + Michele Lanza and Radu Marinescu. Object-Oriented Metrics in Practice: + Using Software Metrics to Characterize, Evaluate, and Improve the Design + of Object-Oriented Systems. Springer, Berlin, 1 edition, October 2006. Page 80. + + 3 + + + + + Identifies private fields whose values never change once object initialization ends either in the declaration + of the field or by a constructor. This helps in converting existing classes to becoming immutable ones. + + 3 + + + + + + + + The Law of Demeter is a simple rule, that says "only talk to friends". It helps to reduce coupling between classes + or objects. + + See also the references: + + * Andrew Hunt, David Thomas, and Ward Cunningham. The Pragmatic Programmer. From Journeyman to Master. Addison-Wesley Longman, Amsterdam, October 1999.; + * K.J. Lieberherr and I.M. Holland. Assuring good style for object-oriented programs. Software, IEEE, 6(5):38–48, 1989.; + * <http://www.ccs.neu.edu/home/lieber/LoD.html> + * <http://en.wikipedia.org/wiki/Law_of_Demeter> + + 3 + + + + + + + + Use opposite operator instead of negating the whole expression with a logic complement operator. + + 3 + + + + + + + + + = + return false; + } + + return true; +} +]]> + + + + + + Avoid using classes from the configured package hierarchy outside of the package hierarchy, + except when using one of the configured allowed classes. 
+ + 3 + + + + + + + + Complexity directly affects maintenance costs is determined by the number of decision points in a method + plus one for the method entry. The decision points include 'if', 'while', 'for', and 'case labels' calls. + Generally, numbers ranging from 1-4 denote low complexity, 5-7 denote moderate complexity, 8-10 denote + high complexity, and 11+ is very high complexity. Modified complexity treats switch statements as a single + decision point. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/design/CyclomaticComplexity %}. + + 3 + + + + + + + + This rule uses the NCSS (Non-Commenting Source Statements) algorithm to determine the number of lines + of code for a given constructor. NCSS ignores comments, and counts actual statements. Using this algorithm, + lines of code that are split are counted as one. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/design/NcssCount %}. + + 3 + + + + + + + + This rule uses the NCSS (Non-Commenting Source Statements) metric to determine the number of lines + of code in a class, method or constructor. NCSS ignores comments, blank lines, and only counts actual + statements. For more details on the calculation, see the documentation of + the [NCSS metric](/pmd_java_metrics_index.html#non-commenting-source-statements-ncss). + + 3 + + + + + + + + This rule uses the NCSS (Non-Commenting Source Statements) algorithm to determine the number of lines + of code for a given method. NCSS ignores comments, and counts actual statements. Using this algorithm, + lines of code that are split are counted as one. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/design/NcssCount %}. + + 3 + + + + + + + + This rule uses the NCSS (Non-Commenting Source Statements) algorithm to determine the number of lines + of code for a given type. 
NCSS ignores comments, and counts actual statements. Using this algorithm, + lines of code that are split are counted as one. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/design/NcssCount %}. + + 3 + + + + + + + + The NPath complexity of a method is the number of acyclic execution paths through that method. + While cyclomatic complexity counts the number of decision points in a method, NPath counts the number of + full paths from the beginning to the end of the block of the method. That metric grows exponentially, as + it multiplies the complexity of statements in the same block. For more details on the calculation, see the + documentation of the [NPath metric](/pmd_java_metrics_index.html#npath-complexity-npath). + + A threshold of 200 is generally considered the point where measures should be taken to reduce + complexity and increase readability. + + 3 + + + + + + + + A method/constructor shouldn't explicitly throw the generic java.lang.Exception, since it + is unclear which exceptions that can be thrown from the methods. It might be + difficult to document and understand such vague interfaces. Use either a class + derived from RuntimeException or a checked exception. + + 3 + + + + + + + + + + 3 + + + + + + + + + + + + + + + Avoid negation in an assertTrue or assertFalse test. + + For example, rephrase: + + assertTrue(!expr); + + as: + + assertFalse(expr); + + + 3 + + + + + + + + + + + + + + + Avoid unnecessary comparisons in boolean expressions, they serve no purpose and impacts readability. + + 3 + + + + + + + + + + + + + + + Avoid unnecessary if-then-else statements when returning a boolean. The result of + the conditional test can be returned instead. + + 3 + + + + + + + + No need to check for null before an instanceof; the instanceof keyword returns false when given a null argument. 
+ + 3 + + + + + + + + + + + + + + + Fields whose scopes are limited to just single methods do not rely on the containing + object to provide them to other methods. They may be better implemented as local variables + within those methods. + + 3 + + + + + + + + Complexity directly affects maintenance costs is determined by the number of decision points in a method + plus one for the method entry. The decision points include 'if', 'while', 'for', and 'case labels' calls. + Generally, numbers ranging from 1-4 denote low complexity, 5-7 denote moderate complexity, 8-10 denote + high complexity, and 11+ is very high complexity. + + This rule is deprecated and will be removed with PMD 7.0.0. The rule is replaced + by the rule {% rule java/design/CyclomaticComplexity %}. + + 3 + + + + + + + + A high ratio of statements to labels in a switch statement implies that the switch statement + is overloaded. Consider moving the statements into new methods or creating subclasses based + on the switch variable. + + 3 + + + + + + + + Classes that have too many fields can become unwieldy and could be redesigned to have fewer fields, + possibly through grouping related fields in new objects. For example, a class with individual + city/state/zip fields could park them within a single Address field. + + 3 + + + + + + + + A class with too many methods is probably a good suspect for refactoring, in order to reduce its + complexity and find a way to have more fine grained objects. + + 3 + + + + + + $maxmethods + ] +]]> + + + + + + + + The overriding method merely calls the same method defined in a superclass. + + 3 + + + + + + + + When you write a public method, you should be thinking in terms of an API. If your method is public, it means other class + will use it, therefore, you want (or need) to offer a comprehensive and evolutive API. If you pass a lot of information + as a simple series of Strings, you may think of using an Object to represent all those information. 
You'll get a simpler + API (such as doWork(Workload workload), rather than a tedious series of Strings) and more importantly, if you need at some + point to pass extra data, you'll be able to do so by simply modifying or extending Workload without any modification to + your API. + + 3 + + + + 3 +] +]]> + + + + + + + + + + + For classes that only have static methods, consider making them utility classes. + Note that this doesn't apply to abstract classes, since their subclasses may + well include non-static methods. Also, if you want this class to be a utility class, + remember to add a private constructor to prevent instantiation. + (Note, that this use was known before PMD 5.1.0 as UseSingleton). + + 3 + + + + + + diff --git a/gradle/quality/pmd/category/java/documentation.xml b/gradle/quality/pmd/category/java/documentation.xml new file mode 100644 index 0000000..34b351a --- /dev/null +++ b/gradle/quality/pmd/category/java/documentation.xml @@ -0,0 +1,144 @@ + + + + + + Rules that are related to code documentation. + + + + + A rule for the politically correct... we don't want to offend anyone. + + 3 + + + + + + + + Denotes whether comments are required (or unwanted) for specific language elements. + + 3 + + + + + + + + Determines whether the dimensions of non-header comments found are within the specified limits. + + 3 + + + + + + + + Uncommented Empty Constructor finds instances where a constructor does not + contain statements, but there is no comment. By explicitly commenting empty + constructors it is easier to distinguish between intentional (commented) + and unintentional empty constructors. + + 3 + + + + + + + + + + + + + + + + Uncommented Empty Method Body finds instances where a method body does not contain + statements, but there is no comment. By explicitly commenting empty method bodies + it is easier to distinguish between intentional (commented) and unintentional + empty methods. 
+ + 3 + + + + + + + + + + + + + \ No newline at end of file diff --git a/gradle/quality/pmd/category/java/errorprone.xml b/gradle/quality/pmd/category/java/errorprone.xml new file mode 100644 index 0000000..cf289c3 --- /dev/null +++ b/gradle/quality/pmd/category/java/errorprone.xml @@ -0,0 +1,3383 @@ + + + + + + Rules to detect constructs that are either broken, extremely confusing or prone to runtime errors. + + + + + Avoid assignments in operands; this can make code more complicated and harder to read. + + 3 + + + + + + + + Identifies a possible unsafe usage of a static field. + + 3 + + + + + + + + Methods such as getDeclaredConstructors(), getDeclaredConstructor(Class[]) and setAccessible(), + as the interface PrivilegedAction, allow for the runtime alteration of variable, class, or + method visibility, even if they are private. This violates the principle of encapsulation. + + 3 + + + + + + + + + + + + + + + Use of the term 'assert' will conflict with newer versions of Java since it is a reserved word. + + 2 + + + //VariableDeclaratorId[@Image='assert'] + + + + + + + + + + Using a branching statement as the last part of a loop may be a bug, and/or is confusing. + Ensure that the usage is not a bug, or consider using another approach. + + 2 + + 25) { + break; + } +} +]]> + + + + + + The method Object.finalize() is called by the garbage collector on an object when garbage collection determines + that there are no more references to the object. It should not be invoked by application logic. + + Note that Oracle has declared Object.finalize() as deprecated since JDK 9. + + 3 + + + + + + + + Code should never throw NullPointerExceptions under normal circumstances. A catch block may hide the + original error, causing other, more subtle problems later on. + + 3 + + + + + + + + + + + + + + + Catching Throwable errors is not recommended since its scope is very broad. It includes runtime issues such as + OutOfMemoryError that should be exposed and managed separately. 
+ + 3 + + + + + + + + One might assume that the result of "new BigDecimal(0.1)" is exactly equal to 0.1, but it is actually + equal to .1000000000000000055511151231257827021181583404541015625. + This is because 0.1 cannot be represented exactly as a double (or as a binary fraction of any finite + length). Thus, the long value that is being passed in to the constructor is not exactly equal to 0.1, + appearances notwithstanding. + + The (String) constructor, on the other hand, is perfectly predictable: 'new BigDecimal("0.1")' is + exactly equal to 0.1, as one would expect. Therefore, it is generally recommended that the + (String) constructor be used in preference to this one. + + 3 + + + + + + + + + + + + + + + Code containing duplicate String literals can usually be improved by declaring the String as a constant field. + + 3 + + + + + + + + Use of the term 'enum' will conflict with newer versions of Java since it is a reserved word. + + 2 + + + //VariableDeclaratorId[@Image='enum'] + + + + + + + + + + It can be confusing to have a field name with the same name as a method. While this is permitted, + having information (field) and actions (method) is not clear naming. Developers versed in + Smalltalk often prefer this approach as the methods denote accessor methods. + + 3 + + + + + + + + It is somewhat confusing to have a field name matching the declaring class name. + This probably means that type and/or field names should be chosen more carefully. + + 3 + + + + + + + + Each caught exception type should be handled in its own catch clause. + + 3 + + + + + + + + + + + + + + + Avoid using hard-coded literals in conditional statements. By declaring them as static variables + or private members with descriptive names maintainability is enhanced. By default, the literals "-1" and "0" are ignored. + More exceptions can be defined with the property "ignoreMagicNumbers". 
+ + 3 + + + + + + + + + + + = 0) { } // alternative approach + + if (aDouble > 0.0) {} // magic number 0.0 + if (aDouble >= Double.MIN_VALUE) {} // preferred approach +} +]]> + + + + + + Statements in a catch block that invoke accessors on the exception without using the information + only add to code size. Either remove the invocation, or use the return result. + + 2 + + + + + + + + + + + + + + + The use of multiple unary operators may be problematic, and/or confusing. + Ensure that the intended usage is not a bug, or consider simplifying the expression. + + 2 + + + + + + + + Integer literals should not start with zero since this denotes that the rest of literal will be + interpreted as an octal value. + + 3 + + + + + + + + Avoid equality comparisons with Double.NaN. Due to the implicit lack of representation + precision when comparing floating point numbers these are likely to cause logic errors. + + 3 + + + + + + + + + + + + + + + If a class is a bean, or is referenced by a bean directly or indirectly it needs to be serializable. + Member variables need to be marked as transient, static, or have accessor methods in the class. Marking + variables as transient is the safest and easiest modification. Accessor methods should follow the Java + naming conventions, i.e. for a variable named foo, getFoo() and setFoo() accessor methods should be provided. + + 3 + + + + + + + + The null check is broken since it will throw a NullPointerException itself. + It is likely that you used || instead of && or vice versa. + + 2 + + + + + + + Super should be called at the start of the method + 3 + + + + + + + + + + + + + + + Super should be called at the end of the method + + 3 + + + + + + + + + + + + + + + The skip() method may skip a smaller number of bytes than requested. Check the returned value to find out if it was the case or not. 
+ + 3 + + + + + + + + When deriving an array of a specific class from your Collection, one should provide an array of + the same class as the parameter of the toArray() method. Doing otherwise you will will result + in a ClassCastException. + + 3 + + + + + + + + + + + + + + + The java Manual says "By convention, classes that implement this interface should override + Object.clone (which is protected) with a public method." + + 3 + + + + + + + + + + + + + + + The method clone() should only be implemented if the class implements the Cloneable interface with the exception of + a final method that only throws CloneNotSupportedException. + + The rule can also detect, if the class implements or extends a Cloneable class. + + 3 + + + + + + + + If a class implements cloneable the return type of the method clone() must be the class name. That way, the caller + of the clone method doesn't need to cast the returned clone to the correct type. + + Note: This is only possible with Java 1.5 or higher. + + 3 + + + + + + + + + + + + + + + The method clone() should throw a CloneNotSupportedException. + + 3 + + + + + + + + + + + + + + + Ensure that resources (like Connection, Statement, and ResultSet objects) are always closed after use. + + 3 + + + + + + + + Use equals() to compare object references; avoid comparing them with ==. + + 3 + + + + + + + + Calling overridable methods during construction poses a risk of invoking methods on an incompletely + constructed object and can be difficult to debug. + It may leave the sub-class unable to construct its superclass or forced to replicate the construction + process completely within itself, losing the ability to call super(). If the default constructor + contains a call to an overridable method, the subclass may be completely uninstantiable. Note that + this includes method calls throughout the control flow graph - i.e., if a constructor Foo() calls a + private method bar() that calls a public method buz(), this denotes a problem. 
+ + 1 + + + + + + + The dataflow analysis tracks local definitions, undefinitions and references to variables on different paths on the data flow. + From those informations there can be found various problems. + + 1. UR - Anomaly: There is a reference to a variable that was not defined before. This is a bug and leads to an error. + 2. DU - Anomaly: A recently defined variable is undefined. These anomalies may appear in normal source text. + 3. DD - Anomaly: A recently defined variable is redefined. This is ominous but don't have to be a bug. + + 5 + + dd-anomaly + foo(buz); + buz = 2; +} // buz is undefined when leaving scope -> du-anomaly +]]> + + + + + + Calls to System.gc(), Runtime.getRuntime().gc(), and System.runFinalization() are not advised. Code should have the + same behavior whether the garbage collection is disabled using the option -Xdisableexplicitgc or not. + Moreover, "modern" jvms do a very good job handling garbage collections. If memory usage issues unrelated to memory + leaks develop within an application, it should be dealt with JVM options rather than within the code itself. + + 2 + + + + + + + + + + + + + + + Web applications should not call System.exit(), since only the web container or the + application server should stop the JVM. This rule also checks for the equivalent call Runtime.getRuntime().exit(). + + 3 + + + + + + + + + + + + + + + Extend Exception or RuntimeException instead of Throwable. + + 3 + + + + + + + + + + + + + + + Use Environment.getExternalStorageDirectory() instead of "/sdcard" + + 3 + + + //Literal[starts-with(@Image,'"/sdcard')] + + + + + + + + + + Throwing exceptions within a 'finally' block is confusing since they may mask other exceptions + or code defects. + Note: This is a PMD implementation of the Lint4j rule "A throw in a finally block" + + 4 + + + //FinallyStatement[descendant::ThrowStatement] + + + + + + + + + + Avoid importing anything from the 'sun.*' packages. 
These packages are not portable and are likely to change. + + 4 + + + + + + + + Don't use floating point for loop indices. If you must use floating point, use double + unless you're certain that float provides enough precision and you have a compelling + performance need (space or time). + + 3 + + + + + + + + + + + + + + + Empty Catch Block finds instances where an exception is caught, but nothing is done. + In most circumstances, this swallows an exception which should either be acted on + or reported. + + 3 + + + + + + + + + + + + + + + + + Empty finalize methods serve no purpose and should be removed. Note that Oracle has declared Object.finalize() as deprecated since JDK 9. + + 3 + + + + + + + + + + + + + + + Empty finally blocks serve no purpose and should be removed. + + 3 + + + + + + + + + + + + + + + Empty If Statement finds instances where a condition is checked but nothing is done about it. + + 3 + + + + + + + + + + + + + + + Empty initializers serve no purpose and should be removed. + + 3 + + + //Initializer/Block[count(*)=0] + + + + + + + + + + Empty block statements serve no purpose and should be removed. + + 3 + + + //BlockStatement/Statement/Block[count(*) = 0] + + + + + + + + + + An empty statement (or a semicolon by itself) that is not used as the sole body of a 'for' + or 'while' loop is probably a bug. It could also be a double semicolon, which has no purpose + and should be removed. + + 3 + + + + + + + + + + + + + + + Empty switch statements serve no purpose and should be removed. + + 3 + + + //SwitchStatement[count(*) = 1] + + + + + + + + + + Empty synchronized blocks serve no purpose and should be removed. + + 3 + + + //SynchronizedStatement/Block[1][count(*) = 0] + + + + + + + + + + Avoid empty try blocks - what's the point? + + 3 + + + + + + + + + + + + + + + Empty While Statement finds all instances where a while statement does nothing. 
+ If it is a timing loop, then you should use Thread.sleep() for it; if it is + a while loop that does a lot in the exit expression, rewrite it to make it clearer. + + 3 + + + + + + + + + + + + + + + Tests for null should not use the equals() method. The '==' operator should be used instead. + + 1 + + + + + + + + + + + + + + + If the finalize() is implemented, its last action should be to call super.finalize. Note that Oracle has declared Object.finalize() as deprecated since JDK 9. + + 3 + + + + + + + + + + + + + + + + If the finalize() is implemented, it should do something besides just calling super.finalize(). Note that Oracle has declared Object.finalize() as deprecated since JDK 9. + + 3 + + + + + + + + + + + + + + + Methods named finalize() should not have parameters. It is confusing and most likely an attempt to + overload Object.finalize(). It will not be called by the VM. + + Note that Oracle has declared Object.finalize() as deprecated since JDK 9. + + 3 + + + + 0]] +]]> + + + + + + + + + + + When overriding the finalize(), the new method should be set as protected. If made public, + other classes may invoke it at inappropriate times. + + Note that Oracle has declared Object.finalize() as deprecated since JDK 9. + + 3 + + + + + + + + + + + + + + + Avoid idempotent operations - they have no effect. + + 3 + + + + + + + + There is no need to import a type that lives in the same package. + + 3 + + + + + + + + Avoid instantiating an object just to call getClass() on it; use the .class public member instead. + + 4 + + + + + + + + + + + + + + + Check for messages in slf4j loggers with non matching number of arguments and placeholders. + + 5 + + + + + + + + Avoid jumbled loop incrementers - its usually a mistake, and is confusing even if intentional. + + 3 + + + + + + + + + + + + + + Some JUnit framework methods are easy to misspell. + + 3 + + + + + + + + + + + + + + + The suite() method in a JUnit test needs to be both public and static. 
+ + 3 + + + + + + + + + + + + + + + In most cases, the Logger reference can be declared as static and final. + + 2 + + + + + + + + + + + + + + + Non-constructor methods should not have the same name as the enclosing class. + + 3 + + + + + + + + The null check here is misplaced. If the variable is null a NullPointerException will be thrown. + Either the check is useless (the variable will never be "null") or it is incorrect. + + 3 + + + + + + + + + + + + + + + + + + Switch statements without break or return statements for each case option + may indicate problematic behaviour. Empty cases are ignored as these indicate an intentional fall-through. + + 3 + + + + + + + + + + + + + + + Serializable classes should provide a serialVersionUID field. + The serialVersionUID field is also needed for abstract base classes. Each individual class in the inheritance + chain needs an own serialVersionUID field. See also [Should an abstract class have a serialVersionUID](https://stackoverflow.com/questions/893259/should-an-abstract-class-have-a-serialversionuid). + + 3 + + + + + + + + + + + + + + + A class that has private constructors and does not have any static methods or fields cannot be used. + + 3 + + + + + + + + + + + + + + + Normally only one logger is used in each class. + + 2 + + + + + + + + A non-case label (e.g. a named break/continue label) was present in a switch statement. + This legal, but confusing. It is easy to mix up the case labels and the non-case labels. + + 3 + + + //SwitchStatement//BlockStatement/Statement/LabeledStatement + + + + + + + + + + A non-static initializer block will be called any time a constructor is invoked (just prior to + invoking the constructor). While this is a valid language construct, it is rarely used and is + confusing. + + 3 + + + + + + + + + + + + + + + Assigning a "null" to a variable (outside of its declaration) is usually bad form. 
Sometimes, this type + of assignment is an indication that the programmer doesn't completely understand what is going on in the code. + + NOTE: This sort of assignment may used in some cases to dereference objects and encourage garbage collection. + + 3 + + + + + + + + Override both public boolean Object.equals(Object other), and public int Object.hashCode(), or override neither. Even if you are inheriting a hashCode() from a parent class, consider implementing hashCode and explicitly delegating to your superclass. + + 3 + + + + + + + + Object clone() should be implemented with super.clone(). + + 2 + + + + 0 +] +]]> + + + + + + + + + + + A logger should normally be defined private static final and be associated with the correct class. + Private final Log log; is also allowed for rare cases where loggers need to be passed around, + with the restriction that the logger needs to be passed into the constructor. + + 3 + + + + + + + + + + + + + + + + For any method that returns an array, it is a better to return an empty array rather than a + null reference. This removes the need for null checking all results and avoids inadvertent + NullPointerExceptions. + + 1 + + + + + + + + + + + + + + + Avoid returning from a finally block, this can discard exceptions. + + 3 + + + + //FinallyStatement//ReturnStatement except //FinallyStatement//(MethodDeclaration|LambdaExpression)//ReturnStatement + + + + + + + + + + Be sure to specify a Locale when creating SimpleDateFormat instances to ensure that locale-appropriate + formatting is used. + + 3 + + + + + + + + + + + + + + + Some classes contain overloaded getInstance. The problem with overloaded getInstance methods + is that the instance created using the overloaded method is not cached and so, + for each call and new objects will be created for every invocation. + + 2 + + + + + + + + Some classes contain overloaded getInstance. 
The problem with overloaded getInstance methods + is that the instance created using the overloaded method is not cached and so, + for each call and new objects will be created for every invocation. + + 2 + + + + + + + + According to the J2EE specification, an EJB should not have any static fields + with write access. However, static read-only fields are allowed. This ensures proper + behavior especially when instances are distributed by the container on several JREs. + + 3 + + + + + + + + + + + + + + + Individual character values provided as initialization arguments will be converted into integers. + This can lead to internal buffer sizes that are larger than expected. Some examples: + + ``` + new StringBuffer() // 16 + new StringBuffer(6) // 6 + new StringBuffer("hello world") // 11 + 16 = 27 + new StringBuffer('A') // chr(A) = 65 + new StringBuffer("A") // 1 + 16 = 17 + + new StringBuilder() // 16 + new StringBuilder(6) // 6 + new StringBuilder("hello world") // 11 + 16 = 27 + new StringBuilder('C') // chr(C) = 67 + new StringBuilder("A") // 1 + 16 = 17 + ``` + + 4 + + + + + + + + + + + + + + + The method name and parameter number are suspiciously close to equals(Object), which can denote an + intention to override the equals(Object) method. + + 2 + + + + + + + + + + + + + + + The method name and return type are suspiciously close to hashCode(), which may denote an intention + to override the hashCode() method. + + 3 + + + + + + + + A suspicious octal escape sequence was found inside a String literal. + The Java language specification (section 3.10.6) says an octal + escape sequence inside a literal String shall consist of a backslash + followed by: + + OctalDigit | OctalDigit OctalDigit | ZeroToThree OctalDigit OctalDigit + + Any octal escape sequence followed by non-octal digits can be confusing, + e.g. "\038" is interpreted as the octal escape sequence "\03" followed by + the literal character "8". + + 3 + + + + + + + + Test classes end with the suffix Test. 
Having a non-test class with that name is not a good practice, + since most people will assume it is a test case. Test classes have test methods named testXXX. + + 3 + + + + + + + + Do not use "if" statements whose conditionals are always true or always false. + + 3 + + + + + + + + + + + + + + + A JUnit test assertion with a boolean literal is unnecessary since it always will evaluate to the same thing. + Consider using flow control (in case of assertTrue(false) or similar) or simply removing + statements like assertTrue(true) and assertFalse(false). If you just want a test to halt after finding + an error, use the fail() method and provide an indication message of why it did. + + 3 + + + + + + + + + + + + + + + Using equalsIgnoreCase() is faster than using toUpperCase/toLowerCase().equals() + + 3 + + + + + + + + Avoid the use temporary objects when converting primitives to Strings. Use the static conversion methods + on the wrapper classes instead. + + 3 + + + + + + + + After checking an object reference for null, you should invoke equals() on that object rather than passing it to another object's equals() method. + + 3 + + + + + + + + + + + + + + + To make sure the full stacktrace is printed out, use the logging statement with two arguments: a String and a Throwable. + + 3 + + + + + + + + + + + + + + + Using '==' or '!=' to compare strings only works if intern version is used on both sides. + Use the equals() method instead. + + 3 + + + + + + + + + + + + + + + An operation on an Immutable object (String, BigDecimal or BigInteger) won't change the object itself + since the result of the operation is a new object. Therefore, ignoring the operation result is an error. + + 3 + + + + + + + + When doing String.toLowerCase()/toUpperCase() conversions, use Locales to avoids problems with languages that + have unusual conventions, i.e. Turkish. + + 3 + + + + + + + + + + + + + + + In J2EE, the getClassLoader() method might not work as expected. 
Use + Thread.currentThread().getContextClassLoader() instead. + + 3 + + + //PrimarySuffix[@Image='getClassLoader'] + + + + + + + + diff --git a/gradle/quality/pmd/category/java/multithreading.xml b/gradle/quality/pmd/category/java/multithreading.xml new file mode 100644 index 0000000..d3e8327 --- /dev/null +++ b/gradle/quality/pmd/category/java/multithreading.xml @@ -0,0 +1,393 @@ + + + + + + Rules that flag issues when dealing with multiple threads of execution. + + + + + Method-level synchronization can cause problems when new code is added to the method. + Block-level synchronization helps to ensure that only the code that needs synchronization + gets it. + + 3 + + + //MethodDeclaration[@Synchronized='true'] + + + + + + + + + + Avoid using java.lang.ThreadGroup; although it is intended to be used in a threaded environment + it contains methods that are not thread-safe. + + 3 + + + + + + + + + + + + + + + Use of the keyword 'volatile' is generally used to fine tune a Java application, and therefore, requires + a good expertise of the Java Memory Model. Moreover, its range of action is somewhat misknown. Therefore, + the volatile keyword should not be used for maintenance purpose and portability. + + 2 + + + //FieldDeclaration[contains(@Volatile,'true')] + + + + + + + + + + The J2EE specification explicitly forbids the use of threads. + + 3 + + + //ClassOrInterfaceType[@Image = 'Thread' or @Image = 'Runnable'] + + + + + + + + + + Explicitly calling Thread.run() method will execute in the caller's thread of control. Instead, call Thread.start() for the intended behavior. + + 4 + + + + + + + + + + + + + + + Partially created objects can be returned by the Double Checked Locking pattern when used in Java. + An optimizing JRE may assign a reference to the baz variable before it calls the constructor of the object the + reference points to. + + Note: With Java 5, you can make Double checked locking work, if you declare the variable to be `volatile`. 
+ + For more details refer to: <http://www.javaworld.com/javaworld/jw-02-2001/jw-0209-double.html> + or <http://www.cs.umd.edu/~pugh/java/memoryModel/DoubleCheckedLocking.html> + + 1 + + + + + + + + Non-thread safe singletons can result in bad state changes. Eliminate + static singletons if possible by instantiating the object directly. Static + singletons are usually not needed as only a single instance exists anyway. + Other possible fixes are to synchronize the entire method or to use an + [initialize-on-demand holder class](https://en.wikipedia.org/wiki/Initialization-on-demand_holder_idiom). + + Refrain from using the double-checked locking pattern. The Java Memory Model doesn't + guarantee it to work unless the variable is declared as `volatile`, adding an uneeded + performance penalty. [Reference](http://www.cs.umd.edu/~pugh/java/memoryModel/DoubleCheckedLocking.html) + + See Effective Java, item 48. + + 3 + + + + + + + + SimpleDateFormat instances are not synchronized. Sun recommends using separate format instances + for each thread. If multiple threads must access a static formatter, the formatter must be + synchronized either on method or block level. + + This rule has been deprecated in favor of the rule {% rule UnsynchronizedStaticFormatter %}. + + 3 + + + + + + + + Instances of `java.text.Format` are generally not synchronized. + Sun recommends using separate format instances for each thread. + If multiple threads must access a static formatter, the formatter must be + synchronized either on method or block level. + + 3 + + + + + + + + Since Java5 brought a new implementation of the Map designed for multi-threaded access, you can + perform efficient map reads without blocking other threads. + + 3 + + + + + + + + + + + + + + + Thread.notify() awakens a thread monitoring the object. If more than one thread is monitoring, then only + one is chosen. The thread chosen is arbitrary; thus its usually safer to call notifyAll() instead. 
+ + 3 + + + + + + + + + + + + + \ No newline at end of file diff --git a/gradle/quality/pmd/category/java/performance.xml b/gradle/quality/pmd/category/java/performance.xml new file mode 100644 index 0000000..1ce2d8d --- /dev/null +++ b/gradle/quality/pmd/category/java/performance.xml @@ -0,0 +1,1006 @@ + + + + + + Rules that flag suboptimal code. + + + + + The conversion of literals to strings by concatenating them with empty strings is inefficient. + It is much better to use one of the type-specific toString() methods instead. + + 3 + + + + + + + + + + + + + + + Avoid concatenating characters as strings in StringBuffer/StringBuilder.append methods. + + 3 + + + + + + + + Instead of manually copying data between two arrays, use the efficient Arrays.copyOf or System.arraycopy method instead. + + 3 + + + + + + + + + + + + + + + The FileInputStream and FileOutputStream classes contains a finalizer method which will cause garbage + collection pauses. + See [JDK-8080225](https://bugs.openjdk.java.net/browse/JDK-8080225) for details. + + The FileReader and FileWriter constructors instantiate FileInputStream and FileOutputStream, + again causing garbage collection issues while finalizer methods are called. + + * Use `Files.newInputStream(Paths.get(fileName))` instead of `new FileInputStream(fileName)`. + * Use `Files.newOutputStream(Paths.get(fileName))` instead of `new FileOutputStream(fileName)`. + * Use `Files.newBufferedReader(Paths.get(fileName))` instead of `new FileReader(fileName)`. + * Use `Files.newBufferedWriter(Paths.get(fileName))` instead of `new FileWriter(fileName)`. + + Please note, that the `java.nio` API does not throw a `FileNotFoundException` anymore, instead + it throws a `NoSuchFileException`. If your code dealt explicitly with a `FileNotFoundException`, + then this needs to be adjusted. Both exceptions are subclasses of `IOException`, so catching + that one covers both. 
+ + 1 + + + + + + + + + + + + + + + New objects created within loops should be checked to see if they can created outside them and reused. + + 3 + + + + + + + + Java uses the 'short' type to reduce memory usage, not to optimize calculation. In fact, the JVM does not have any + arithmetic capabilities for the short type: the JVM must convert the short into an int, do the proper calculation + and convert the int back to a short. Thus any storage gains found through use of the 'short' type may be offset by + adverse impacts on performance. + + 1 + + + + + + + + + + + + + + + Don't create instances of already existing BigInteger (BigInteger.ZERO, BigInteger.ONE) and + for Java 1.5 onwards, BigInteger.TEN and BigDecimal (BigDecimal.ZERO, BigDecimal.ONE, BigDecimal.TEN) + + 3 + + + + + + + + Avoid instantiating Boolean objects; you can reference Boolean.TRUE, Boolean.FALSE, or call Boolean.valueOf() instead. + Note that new Boolean() is deprecated since JDK 9 for that reason. + + 2 + + + + + + + + Calling new Byte() causes memory allocation that can be avoided by the static Byte.valueOf(). + It makes use of an internal cache that recycles earlier instances making it more memory efficient. + Note that new Byte() is deprecated since JDK 9 for that reason. + + 2 + + + + + + + + + + + + + + + Consecutive calls to StringBuffer/StringBuilder .append should be chained, reusing the target object. This can improve the performance + by producing a smaller bytecode, reducing overhead and improving inlining. A complete analysis can be found [here](https://github.com/pmd/pmd/issues/202#issuecomment-274349067) + + 3 + + + + + + + + Consecutively calling StringBuffer/StringBuilder.append(...) with literals should be avoided. + Since the literals are constants, they can already be combined into a single String literal and this String + can be appended in a single method call. 
+ + 3 + + + + + + + + + + 3 + + 0) { + doSomething(); + } +} +]]> + + + + + + Avoid concatenating non-literals in a StringBuffer constructor or append() since intermediate buffers will + need to be be created and destroyed by the JVM. + + 3 + + + + + + + + Failing to pre-size a StringBuffer or StringBuilder properly could cause it to re-size many times + during runtime. This rule attempts to determine the total number the characters that are actually + passed into StringBuffer.append(), but represents a best guess "worst case" scenario. An empty + StringBuffer/StringBuilder constructor initializes the object to 16 characters. This default + is assumed if the length of the constructor can not be determined. + + 3 + + + + + + + + Calling new Integer() causes memory allocation that can be avoided by the static Integer.valueOf(). + It makes use of an internal cache that recycles earlier instances making it more memory efficient. + Note that new Integer() is deprecated since JDK 9 for that reason. + + 2 + + + + + + + + + + + + + + + Calling new Long() causes memory allocation that can be avoided by the static Long.valueOf(). + It makes use of an internal cache that recycles earlier instances making it more memory efficient. + Note that new Long() is deprecated since JDK 9 for that reason. + + 2 + + + + + + + + + + + + + + + Calls to a collection's `toArray(E[])` method should specify a target array of zero size. This allows the JVM + to optimize the memory allocation and copying as much as possible. + + Previous versions of this rule (pre PMD 6.0.0) suggested the opposite, but current JVM implementations + perform always better, when they have full control over the target array. And allocation an array via + reflection is nowadays as fast as the direct allocation. 
+ + See also [Arrays of Wisdom of the Ancients](https://shipilev.net/blog/2016/arrays-wisdom-ancients/) + + Note: If you don't need an array of the correct type, then the simple `toArray()` method without an array + is faster, but returns only an array of type `Object[]`. + + 3 + + + + + + + + + foos = getFoos(); + +// much better; this one allows the jvm to allocate an array of the correct size and effectively skip +// the zeroing, since each array element will be overridden anyways +Foo[] fooArray = foos.toArray(new Foo[0]); + +// inefficient, the array needs to be zeroed out by the jvm before it is handed over to the toArray method +Foo[] fooArray = foos.toArray(new Foo[foos.size()]); +]]> + + + + + + Java will initialize fields with known default values so any explicit initialization of those same defaults + is redundant and results in a larger class file (approximately three additional bytecode instructions per field). + + 3 + + + + + + + + Since it passes in a literal of length 1, calls to (string).startsWith can be rewritten using (string).charAt(0) + at the expense of some readability. + + 3 + + + + + + + + + + + + + + + Calling new Short() causes memory allocation that can be avoided by the static Short.valueOf(). + It makes use of an internal cache that recycles earlier instances making it more memory efficient. + Note that new Short() is deprecated since JDK 9 for that reason. + + 2 + + + + + + + + + + + + + + + Avoid instantiating String objects; this is usually unnecessary since they are immutable and can be safely shared. + + 2 + + + + + + + + Avoid calling toString() on objects already known to be string instances; this is unnecessary. + + 3 + + + + + + + + Switch statements are intended to be used to support complex branching behaviour. Using a switch for only a few + cases is ill-advised, since switches are not as easy to understand as if-then statements. In these cases use the + if-then statement to increase code readability. 
+ + 3 + + + + + + + + + + + + + + + + Most wrapper classes provide static conversion methods that avoid the need to create intermediate objects + just to create the primitive forms. Using these avoids the cost of creating objects that also need to be + garbage-collected later. + + 3 + + + + + + + + ArrayList is a much better Collection implementation than Vector if thread-safe operation is not required. + + 3 + + + + 0] + //AllocationExpression/ClassOrInterfaceType + [@Image='Vector' or @Image='java.util.Vector'] +]]> + + + + + + + + + + + (Arrays.asList(...)) if that is inconvenient for you (e.g. because of concurrent access). +]]> + + 3 + + + + + + + + + l= new ArrayList<>(100); + for (int i=0; i< 100; i++) { + l.add(ints[i]); + } + for (int i=0; i< 100; i++) { + l.add(a[i].toString()); // won't trigger the rule + } + } +} +]]> + + + + + + Use String.indexOf(char) when checking for the index of a single character; it executes faster. + + 3 + + + + + + + + No need to call String.valueOf to append to a string; just use the valueOf() argument directly. + + 3 + + + + + + + + The use of the '+=' operator for appending strings causes the JVM to create and use an internal StringBuffer. + If a non-trivial number of these concatenations are being used then the explicit use of a StringBuilder or + threadsafe StringBuffer is recommended to avoid this. + + 3 + + + + + + + + Use StringBuffer.length() to determine StringBuffer length rather than using StringBuffer.toString().equals("") + or StringBuffer.toString().length() == ... + + 3 + + + + + + + + + diff --git a/gradle/quality/pmd/category/java/security.xml b/gradle/quality/pmd/category/java/security.xml new file mode 100644 index 0000000..dbad352 --- /dev/null +++ b/gradle/quality/pmd/category/java/security.xml @@ -0,0 +1,65 @@ + + + + + + Rules that flag potential security flaws. + + + + + Do not use hard coded values for cryptographic operations. Please store keys outside of source code. 
+ + 3 + + + + + + + + Do not use hard coded initialization vector in cryptographic operations. Please use a randomly generated IV. + + 3 + + + + + + diff --git a/gradle/quality/sonarqube.gradle b/gradle/quality/sonarqube.gradle new file mode 100644 index 0000000..8243dd3 --- /dev/null +++ b/gradle/quality/sonarqube.gradle @@ -0,0 +1,10 @@ +/* +sonarqube { + properties { + property "sonar.projectName", "${project.group} ${project.name}" + property "sonar.sourceEncoding", "UTF-8" + property "sonar.tests", "src/test/java" + property "sonar.scm.provider", "git" + } +} +*/ \ No newline at end of file diff --git a/gradle/quality/spotbugs.gradle b/gradle/quality/spotbugs.gradle new file mode 100644 index 0000000..83a40f9 --- /dev/null +++ b/gradle/quality/spotbugs.gradle @@ -0,0 +1,14 @@ +apply plugin: "com.github.spotbugs" + +spotbugs { + effort = "min" + reportLevel = "low" + ignoreFailures = true +} + +spotbugsMain { + reports { + xml.getRequired().set(false) + html.getRequired().set(true) + } +} diff --git a/gradle/repositories/maven.gradle b/gradle/repositories/maven.gradle new file mode 100644 index 0000000..ec58acb --- /dev/null +++ b/gradle/repositories/maven.gradle @@ -0,0 +1,4 @@ +repositories { + mavenLocal() + mavenCentral() +} diff --git a/gradle/test/jmh.gradle b/gradle/test/jmh.gradle new file mode 100644 index 0000000..26de618 --- /dev/null +++ b/gradle/test/jmh.gradle @@ -0,0 +1,22 @@ +sourceSets { + jmh { + java.srcDirs = ['src/jmh/java'] + resources.srcDirs = ['src/jmh/resources'] + compileClasspath += sourceSets.main.runtimeClasspath + } +} + +dependencies { + jmhImplementation 'org.openjdk.jmh:jmh-core:1.37' + jmhAnnotationProcessor 'org.openjdk.jmh:jmh-generator-annprocess:1.37' +} + +task jmh(type: JavaExec, group: 'jmh', dependsOn: jmhClasses) { + mainClass.set('org.openjdk.jmh.Main') + classpath = sourceSets.jmh.compileClasspath + sourceSets.jmh.runtimeClasspath + project.file('build/reports/jmh').mkdirs() + args '-rf', 'json' + args '-rff', 
project.file('build/reports/jmh/result.json') +} + +classes.finalizedBy(jmhClasses) diff --git a/gradle/test/junit5.gradle b/gradle/test/junit5.gradle new file mode 100644 index 0000000..043a91c --- /dev/null +++ b/gradle/test/junit5.gradle @@ -0,0 +1,44 @@ + +dependencies { + testImplementation testLibs.junit.jupiter.api + testImplementation testLibs.junit.jupiter.params + testImplementation testLibs.hamcrest + testRuntimeOnly testLibs.junit.jupiter.engine + testRuntimeOnly testLibs.junit.vintage.engine + testRuntimeOnly testLibs.junit.jupiter.platform.launcher +} + +test { + useJUnitPlatform() + failFast = false + testLogging { + events 'STARTED', 'PASSED', 'FAILED', 'SKIPPED' + showStandardStreams = true + } + minHeapSize = "1g" // initial heap size + maxHeapSize = "2g" // maximum heap size + jvmArgs '--add-exports=java.base/jdk.internal=ALL-UNNAMED', + '--add-exports=java.base/jdk.internal.misc=ALL-UNNAMED', + '--add-exports=java.base/sun.nio.ch=ALL-UNNAMED', + '--add-exports=jdk.unsupported/sun.misc=ALL-UNNAMED', + '--add-opens=java.base/java.lang=ALL-UNNAMED', + '--add-opens=java.base/java.lang.reflect=ALL-UNNAMED', + '--add-opens=java.base/java.io=ALL-UNNAMED', + '--add-opens=java.base/java.util=ALL-UNNAMED', + '--add-opens=java.base/jdk.internal=ALL-UNNAMED', + '--add-opens=java.base/jdk.internal.misc=ALL-UNNAMED', + '--add-opens=jdk.unsupported/sun.misc=ALL-UNNAMED', + '-Dio.netty.bootstrap.extensions=serviceload' + systemProperty 'java.util.logging.config.file', 'src/test/resources/logging.properties' + systemProperty "nativeImage.handlerMetadataGroupId", "io.netty" + systemProperty "nativeimage.handlerMetadataArtifactId", "netty-transport" + afterSuite { desc, result -> + if (!desc.parent) { + println "\nTest result: ${result.resultType}" + println "Test summary: ${result.testCount} tests, " + + "${result.successfulTestCount} succeeded, " + + "${result.failedTestCount} failed, " + + "${result.skippedTestCount} skipped" + } + } +} diff --git 
a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 0000000000000000000000000000000000000000..d64cd4917707c1f8861d8cb53dd15194d4248596 GIT binary patch literal 43462 zcma&NWl&^owk(X(xVyW%ySuwf;qI=D6|RlDJ2cR^yEKh!@I- zp9QeisK*rlxC>+~7Dk4IxIRsKBHqdR9b3+fyL=ynHmIDe&|>O*VlvO+%z5;9Z$|DJ zb4dO}-R=MKr^6EKJiOrJdLnCJn>np?~vU-1sSFgPu;pthGwf}bG z(1db%xwr#x)r+`4AGu$j7~u2MpVs3VpLp|mx&;>`0p0vH6kF+D2CY0fVdQOZ@h;A` z{infNyvmFUiu*XG}RNMNwXrbec_*a3N=2zJ|Wh5z* z5rAX$JJR{#zP>KY**>xHTuw?|-Rg|o24V)74HcfVT;WtQHXlE+_4iPE8QE#DUm%x0 zEKr75ur~W%w#-My3Tj`hH6EuEW+8K-^5P62$7Sc5OK+22qj&Pd1;)1#4tKihi=~8C zHiQSst0cpri6%OeaR`PY>HH_;CPaRNty%WTm4{wDK8V6gCZlG@U3$~JQZ;HPvDJcT1V{ z?>H@13MJcCNe#5z+MecYNi@VT5|&UiN1D4ATT+%M+h4c$t;C#UAs3O_q=GxK0}8%8 z8J(_M9bayxN}69ex4dzM_P3oh@ZGREjVvn%%r7=xjkqxJP4kj}5tlf;QosR=%4L5y zWhgejO=vao5oX%mOHbhJ8V+SG&K5dABn6!WiKl{|oPkq(9z8l&Mm%(=qGcFzI=eLu zWc_oCLyf;hVlB@dnwY98?75B20=n$>u3b|NB28H0u-6Rpl((%KWEBOfElVWJx+5yg z#SGqwza7f}$z;n~g%4HDU{;V{gXIhft*q2=4zSezGK~nBgu9-Q*rZ#2f=Q}i2|qOp z!!y4p)4o=LVUNhlkp#JL{tfkhXNbB=Ox>M=n6soptJw-IDI|_$is2w}(XY>a=H52d z3zE$tjPUhWWS+5h=KVH&uqQS=$v3nRs&p$%11b%5qtF}S2#Pc`IiyBIF4%A!;AVoI zXU8-Rpv!DQNcF~(qQnyyMy=-AN~U>#&X1j5BLDP{?K!%h!;hfJI>$mdLSvktEr*89 zdJHvby^$xEX0^l9g$xW-d?J;L0#(`UT~zpL&*cEh$L|HPAu=P8`OQZV!-}l`noSp_ zQ-1$q$R-gDL)?6YaM!=8H=QGW$NT2SeZlb8PKJdc=F-cT@j7Xags+Pr*jPtlHFnf- zh?q<6;)27IdPc^Wdy-mX%2s84C1xZq9Xms+==F4);O`VUASmu3(RlgE#0+#giLh-& zcxm3_e}n4{%|X zJp{G_j+%`j_q5}k{eW&TlP}J2wtZ2^<^E(O)4OQX8FDp6RJq!F{(6eHWSD3=f~(h} zJXCf7=r<16X{pHkm%yzYI_=VDP&9bmI1*)YXZeB}F? 
z(%QsB5fo*FUZxK$oX~X^69;x~j7ms8xlzpt-T15e9}$4T-pC z6PFg@;B-j|Ywajpe4~bk#S6(fO^|mm1hKOPfA%8-_iGCfICE|=P_~e;Wz6my&)h_~ zkv&_xSAw7AZ%ThYF(4jADW4vg=oEdJGVOs>FqamoL3Np8>?!W#!R-0%2Bg4h?kz5I zKV-rKN2n(vUL%D<4oj@|`eJ>0i#TmYBtYmfla;c!ATW%;xGQ0*TW@PTlGG><@dxUI zg>+3SiGdZ%?5N=8uoLA|$4isK$aJ%i{hECP$bK{J#0W2gQ3YEa zZQ50Stn6hqdfxJ*9#NuSLwKFCUGk@c=(igyVL;;2^wi4o30YXSIb2g_ud$ zgpCr@H0qWtk2hK8Q|&wx)}4+hTYlf;$a4#oUM=V@Cw#!$(nOFFpZ;0lc!qd=c$S}Z zGGI-0jg~S~cgVT=4Vo)b)|4phjStD49*EqC)IPwyeKBLcN;Wu@Aeph;emROAwJ-0< z_#>wVm$)ygH|qyxZaet&(Vf%pVdnvKWJn9`%DAxj3ot;v>S$I}jJ$FLBF*~iZ!ZXE zkvui&p}fI0Y=IDX)mm0@tAd|fEHl~J&K}ZX(Mm3cm1UAuwJ42+AO5@HwYfDH7ipIc zmI;1J;J@+aCNG1M`Btf>YT>~c&3j~Qi@Py5JT6;zjx$cvOQW@3oQ>|}GH?TW-E z1R;q^QFjm5W~7f}c3Ww|awg1BAJ^slEV~Pk`Kd`PS$7;SqJZNj->it4DW2l15}xP6 zoCl$kyEF%yJni0(L!Z&14m!1urXh6Btj_5JYt1{#+H8w?5QI%% zo-$KYWNMJVH?Hh@1n7OSu~QhSswL8x0=$<8QG_zepi_`y_79=nK=_ZP_`Em2UI*tyQoB+r{1QYZCpb?2OrgUw#oRH$?^Tj!Req>XiE#~B|~ z+%HB;=ic+R@px4Ld8mwpY;W^A%8%l8$@B@1m5n`TlKI6bz2mp*^^^1mK$COW$HOfp zUGTz-cN9?BGEp}5A!mDFjaiWa2_J2Iq8qj0mXzk; z66JBKRP{p%wN7XobR0YjhAuW9T1Gw3FDvR5dWJ8ElNYF94eF3ebu+QwKjtvVu4L zI9ip#mQ@4uqVdkl-TUQMb^XBJVLW(-$s;Nq;@5gr4`UfLgF$adIhd?rHOa%D);whv z=;krPp~@I+-Z|r#s3yCH+c1US?dnm+C*)r{m+86sTJusLdNu^sqLrfWed^ndHXH`m zd3#cOe3>w-ga(Dus_^ppG9AC>Iq{y%%CK+Cro_sqLCs{VLuK=dev>OL1dis4(PQ5R zcz)>DjEkfV+MO;~>VUlYF00SgfUo~@(&9$Iy2|G0T9BSP?&T22>K46D zL*~j#yJ?)^*%J3!16f)@Y2Z^kS*BzwfAQ7K96rFRIh>#$*$_Io;z>ux@}G98!fWR@ zGTFxv4r~v)Gsd|pF91*-eaZ3Qw1MH$K^7JhWIdX%o$2kCbvGDXy)a?@8T&1dY4`;L z4Kn+f%SSFWE_rpEpL9bnlmYq`D!6F%di<&Hh=+!VI~j)2mfil03T#jJ_s?}VV0_hp z7T9bWxc>Jm2Z0WMU?`Z$xE74Gu~%s{mW!d4uvKCx@WD+gPUQ zV0vQS(Ig++z=EHN)BR44*EDSWIyT~R4$FcF*VEY*8@l=218Q05D2$|fXKFhRgBIEE zdDFB}1dKkoO^7}{5crKX!p?dZWNz$m>1icsXG2N+((x0OIST9Zo^DW_tytvlwXGpn zs8?pJXjEG;T@qrZi%#h93?FP$!&P4JA(&H61tqQi=opRzNpm zkrG}$^t9&XduK*Qa1?355wd8G2CI6QEh@Ua>AsD;7oRUNLPb76m4HG3K?)wF~IyS3`fXuNM>${?wmB zpVz;?6_(Fiadfd{vUCBM*_kt$+F3J+IojI;9L(gc9n3{sEZyzR9o!_mOwFC#tQ{Q~ 
zP3-`#uK#tP3Q7~Q;4H|wjZHO8h7e4IuBxl&vz2w~D8)w=Wtg31zpZhz%+kzSzL*dV zwp@{WU4i;hJ7c2f1O;7Mz6qRKeASoIv0_bV=i@NMG*l<#+;INk-^`5w@}Dj~;k=|}qM1vq_P z|GpBGe_IKq|LNy9SJhKOQ$c=5L{Dv|Q_lZl=-ky*BFBJLW9&y_C|!vyM~rQx=!vun z?rZJQB5t}Dctmui5i31C_;_}CEn}_W%>oSXtt>@kE1=JW*4*v4tPp;O6 zmAk{)m!)}34pTWg8{i>($%NQ(Tl;QC@J@FfBoc%Gr&m560^kgSfodAFrIjF}aIw)X zoXZ`@IsMkc8_=w%-7`D6Y4e*CG8k%Ud=GXhsTR50jUnm+R*0A(O3UKFg0`K;qp1bl z7``HN=?39ic_kR|^R^~w-*pa?Vj#7|e9F1iRx{GN2?wK!xR1GW!qa=~pjJb-#u1K8 zeR?Y2i-pt}yJq;SCiVHODIvQJX|ZJaT8nO+(?HXbLefulKKgM^B(UIO1r+S=7;kLJ zcH}1J=Px2jsh3Tec&v8Jcbng8;V-`#*UHt?hB(pmOipKwf3Lz8rG$heEB30Sg*2rx zV<|KN86$soN(I!BwO`1n^^uF2*x&vJ$2d$>+`(romzHP|)K_KkO6Hc>_dwMW-M(#S zK(~SiXT1@fvc#U+?|?PniDRm01)f^#55;nhM|wi?oG>yBsa?~?^xTU|fX-R(sTA+5 zaq}-8Tx7zrOy#3*JLIIVsBmHYLdD}!0NP!+ITW+Thn0)8SS!$@)HXwB3tY!fMxc#1 zMp3H?q3eD?u&Njx4;KQ5G>32+GRp1Ee5qMO0lZjaRRu&{W<&~DoJNGkcYF<5(Ab+J zgO>VhBl{okDPn78<%&e2mR{jwVCz5Og;*Z;;3%VvoGo_;HaGLWYF7q#jDX=Z#Ml`H z858YVV$%J|e<1n`%6Vsvq7GmnAV0wW4$5qQ3uR@1i>tW{xrl|ExywIc?fNgYlA?C5 zh$ezAFb5{rQu6i7BSS5*J-|9DQ{6^BVQ{b*lq`xS@RyrsJN?-t=MTMPY;WYeKBCNg z^2|pN!Q^WPJuuO4!|P@jzt&tY1Y8d%FNK5xK(!@`jO2aEA*4 zkO6b|UVBipci?){-Ke=+1;mGlND8)6+P;8sq}UXw2hn;fc7nM>g}GSMWu&v&fqh

iViYT=fZ(|3Ox^$aWPp4a8h24tD<|8-!aK0lHgL$N7Efw}J zVIB!7=T$U`ao1?upi5V4Et*-lTG0XvExbf!ya{cua==$WJyVG(CmA6Of*8E@DSE%L z`V^$qz&RU$7G5mg;8;=#`@rRG`-uS18$0WPN@!v2d{H2sOqP|!(cQ@ zUHo!d>>yFArLPf1q`uBvY32miqShLT1B@gDL4XoVTK&@owOoD)OIHXrYK-a1d$B{v zF^}8D3Y^g%^cnvScOSJR5QNH+BI%d|;J;wWM3~l>${fb8DNPg)wrf|GBP8p%LNGN# z3EaIiItgwtGgT&iYCFy9-LG}bMI|4LdmmJt@V@% zb6B)1kc=T)(|L@0;wr<>=?r04N;E&ef+7C^`wPWtyQe(*pD1pI_&XHy|0gIGHMekd zF_*M4yi6J&Z4LQj65)S zXwdM{SwUo%3SbPwFsHgqF@V|6afT|R6?&S;lw=8% z3}@9B=#JI3@B*#4s!O))~z zc>2_4Q_#&+5V`GFd?88^;c1i7;Vv_I*qt!_Yx*n=;rj!82rrR2rQ8u5(Ejlo{15P% zs~!{%XJ>FmJ})H^I9bn^Re&38H{xA!0l3^89k(oU;bZWXM@kn$#aoS&Y4l^-WEn-fH39Jb9lA%s*WsKJQl?n9B7_~P z-XM&WL7Z!PcoF6_D>V@$CvUIEy=+Z&0kt{szMk=f1|M+r*a43^$$B^MidrT0J;RI` z(?f!O<8UZkm$_Ny$Hth1J#^4ni+im8M9mr&k|3cIgwvjAgjH z8`N&h25xV#v*d$qBX5jkI|xOhQn!>IYZK7l5#^P4M&twe9&Ey@@GxYMxBZq2e7?`q z$~Szs0!g{2fGcp9PZEt|rdQ6bhAgpcLHPz?f-vB?$dc*!9OL?Q8mn7->bFD2Si60* z!O%y)fCdMSV|lkF9w%x~J*A&srMyYY3{=&$}H zGQ4VG_?$2X(0|vT0{=;W$~icCI{b6W{B!Q8xdGhF|D{25G_5_+%s(46lhvNLkik~R z>nr(&C#5wwOzJZQo9m|U<;&Wk!_#q|V>fsmj1g<6%hB{jGoNUPjgJslld>xmODzGjYc?7JSuA?A_QzjDw5AsRgi@Y|Z0{F{!1=!NES-#*f^s4l0Hu zz468))2IY5dmD9pa*(yT5{EyP^G>@ZWumealS-*WeRcZ}B%gxq{MiJ|RyX-^C1V=0 z@iKdrGi1jTe8Ya^x7yyH$kBNvM4R~`fbPq$BzHum-3Zo8C6=KW@||>zsA8-Y9uV5V z#oq-f5L5}V<&wF4@X@<3^C%ptp6+Ce)~hGl`kwj)bsAjmo_GU^r940Z-|`<)oGnh7 zFF0Tde3>ui?8Yj{sF-Z@)yQd~CGZ*w-6p2U<8}JO-sRsVI5dBji`01W8A&3$?}lxBaC&vn0E$c5tW* zX>5(zzZ=qn&!J~KdsPl;P@bmA-Pr8T*)eh_+Dv5=Ma|XSle6t(k8qcgNyar{*ReQ8 zTXwi=8vr>!3Ywr+BhggHDw8ke==NTQVMCK`$69fhzEFB*4+H9LIvdt-#IbhZvpS}} zO3lz;P?zr0*0$%-Rq_y^k(?I{Mk}h@w}cZpMUp|ucs55bcloL2)($u%mXQw({Wzc~ z;6nu5MkjP)0C(@%6Q_I_vsWrfhl7Zpoxw#WoE~r&GOSCz;_ro6i(^hM>I$8y>`!wW z*U^@?B!MMmb89I}2(hcE4zN2G^kwyWCZp5JG>$Ez7zP~D=J^LMjSM)27_0B_X^C(M z`fFT+%DcKlu?^)FCK>QzSnV%IsXVcUFhFdBP!6~se&xxrIxsvySAWu++IrH;FbcY$ z2DWTvSBRfLwdhr0nMx+URA$j3i7_*6BWv#DXfym?ZRDcX9C?cY9sD3q)uBDR3uWg= z(lUIzB)G$Hr!){>E{s4Dew+tb9kvToZp-1&c?y2wn@Z~(VBhqz`cB;{E4(P3N2*nJ 
z_>~g@;UF2iG{Kt(<1PyePTKahF8<)pozZ*xH~U-kfoAayCwJViIrnqwqO}7{0pHw$ zs2Kx?s#vQr7XZ264>5RNKSL8|Ty^=PsIx^}QqOOcfpGUU4tRkUc|kc7-!Ae6!+B{o~7nFpm3|G5^=0#Bnm6`V}oSQlrX(u%OWnC zoLPy&Q;1Jui&7ST0~#+}I^&?vcE*t47~Xq#YwvA^6^} z`WkC)$AkNub|t@S!$8CBlwbV~?yp&@9h{D|3z-vJXgzRC5^nYm+PyPcgRzAnEi6Q^gslXYRv4nycsy-SJu?lMps-? zV`U*#WnFsdPLL)Q$AmD|0`UaC4ND07+&UmOu!eHruzV|OUox<+Jl|Mr@6~C`T@P%s zW7sgXLF2SSe9Fl^O(I*{9wsFSYb2l%-;&Pi^dpv!{)C3d0AlNY6!4fgmSgj_wQ*7Am7&$z;Jg&wgR-Ih;lUvWS|KTSg!&s_E9_bXBkZvGiC6bFKDWZxsD$*NZ#_8bl zG1P-#@?OQzED7@jlMJTH@V!6k;W>auvft)}g zhoV{7$q=*;=l{O>Q4a@ ziMjf_u*o^PsO)#BjC%0^h>Xp@;5$p{JSYDt)zbb}s{Kbt!T*I@Pk@X0zds6wsefuU zW$XY%yyRGC94=6mf?x+bbA5CDQ2AgW1T-jVAJbm7K(gp+;v6E0WI#kuACgV$r}6L? zd|Tj?^%^*N&b>Dd{Wr$FS2qI#Ucs1yd4N+RBUQiSZGujH`#I)mG&VKoDh=KKFl4=G z&MagXl6*<)$6P}*Tiebpz5L=oMaPrN+caUXRJ`D?=K9!e0f{@D&cZLKN?iNP@X0aF zE(^pl+;*T5qt?1jRC=5PMgV!XNITRLS_=9{CJExaQj;lt!&pdzpK?8p>%Mb+D z?yO*uSung=-`QQ@yX@Hyd4@CI^r{2oiu`%^bNkz+Nkk!IunjwNC|WcqvX~k=><-I3 zDQdbdb|!v+Iz01$w@aMl!R)koD77Xp;eZwzSl-AT zr@Vu{=xvgfq9akRrrM)}=!=xcs+U1JO}{t(avgz`6RqiiX<|hGG1pmop8k6Q+G_mv zJv|RfDheUp2L3=^C=4aCBMBn0aRCU(DQwX-W(RkRwmLeuJYF<0urcaf(=7)JPg<3P zQs!~G)9CT18o!J4{zX{_e}4eS)U-E)0FAt}wEI(c0%HkxgggW;(1E=>J17_hsH^sP z%lT0LGgbUXHx-K*CI-MCrP66UP0PvGqM$MkeLyqHdbgP|_Cm!7te~b8p+e6sQ_3k| zVcwTh6d83ltdnR>D^)BYQpDKlLk3g0Hdcgz2}%qUs9~~Rie)A-BV1mS&naYai#xcZ z(d{8=-LVpTp}2*y)|gR~;qc7fp26}lPcLZ#=JpYcn3AT9(UIdOyg+d(P5T7D&*P}# zQCYplZO5|7+r19%9e`v^vfSS1sbX1c%=w1;oyruXB%Kl$ACgKQ6=qNWLsc=28xJjg zwvsI5-%SGU|3p>&zXVl^vVtQT3o-#$UT9LI@Npz~6=4!>mc431VRNN8od&Ul^+G_kHC`G=6WVWM z%9eWNyy(FTO|A+@x}Ou3CH)oi;t#7rAxdIXfNFwOj_@Y&TGz6P_sqiB`Q6Lxy|Q{`|fgmRG(k+!#b*M+Z9zFce)f-7;?Km5O=LHV9f9_87; zF7%R2B+$?@sH&&-$@tzaPYkw0;=i|;vWdI|Wl3q_Zu>l;XdIw2FjV=;Mq5t1Q0|f< zs08j54Bp`3RzqE=2enlkZxmX6OF+@|2<)A^RNQpBd6o@OXl+i)zO%D4iGiQNuXd+zIR{_lb96{lc~bxsBveIw6umhShTX+3@ZJ=YHh@ zWY3(d0azg;7oHn>H<>?4@*RQbi>SmM=JrHvIG(~BrvI)#W(EAeO6fS+}mxxcc+X~W6&YVl86W9WFSS}Vz-f9vS?XUDBk)3TcF z8V?$4Q)`uKFq>xT=)Y9mMFVTUk*NIA!0$?RP6Ig0TBmUFrq*Q-Agq~DzxjStQyJ({ 
zBeZ;o5qUUKg=4Hypm|}>>L=XKsZ!F$yNTDO)jt4H0gdQ5$f|d&bnVCMMXhNh)~mN z@_UV6D7MVlsWz+zM+inZZp&P4fj=tm6fX)SG5H>OsQf_I8c~uGCig$GzuwViK54bcgL;VN|FnyQl>Ed7(@>=8$a_UKIz|V6CeVSd2(P z0Uu>A8A+muM%HLFJQ9UZ5c)BSAv_zH#1f02x?h9C}@pN@6{>UiAp>({Fn(T9Q8B z^`zB;kJ5b`>%dLm+Ol}ty!3;8f1XDSVX0AUe5P#@I+FQ-`$(a;zNgz)4x5hz$Hfbg z!Q(z26wHLXko(1`;(BAOg_wShpX0ixfWq3ponndY+u%1gyX)_h=v1zR#V}#q{au6; z!3K=7fQwnRfg6FXtNQmP>`<;!N137paFS%y?;lb1@BEdbvQHYC{976l`cLqn;b8lp zIDY>~m{gDj(wfnK!lpW6pli)HyLEiUrNc%eXTil|F2s(AY+LW5hkKb>TQ3|Q4S9rr zpDs4uK_co6XPsn_z$LeS{K4jFF`2>U`tbgKdyDne`xmR<@6AA+_hPNKCOR-Zqv;xk zu5!HsBUb^!4uJ7v0RuH-7?l?}b=w5lzzXJ~gZcxRKOovSk@|#V+MuX%Y+=;14i*%{)_gSW9(#4%)AV#3__kac1|qUy!uyP{>?U#5wYNq}y$S9pCc zFc~4mgSC*G~j0u#qqp9 z${>3HV~@->GqEhr_Xwoxq?Hjn#=s2;i~g^&Hn|aDKpA>Oc%HlW(KA1?BXqpxB;Ydx)w;2z^MpjJ(Qi(X!$5RC z*P{~%JGDQqojV>2JbEeCE*OEu!$XJ>bWA9Oa_Hd;y)F%MhBRi*LPcdqR8X`NQ&1L# z5#9L*@qxrx8n}LfeB^J{%-?SU{FCwiWyHp682F+|pa+CQa3ZLzBqN1{)h4d6+vBbV zC#NEbQLC;}me3eeYnOG*nXOJZEU$xLZ1<1Y=7r0(-U0P6-AqwMAM`a(Ed#7vJkn6plb4eI4?2y3yOTGmmDQ!z9`wzbf z_OY#0@5=bnep;MV0X_;;SJJWEf^E6Bd^tVJ9znWx&Ks8t*B>AM@?;D4oWUGc z!H*`6d7Cxo6VuyS4Eye&L1ZRhrRmN6Lr`{NL(wDbif|y&z)JN>Fl5#Wi&mMIr5i;x zBx}3YfF>>8EC(fYnmpu~)CYHuHCyr5*`ECap%t@y=jD>!_%3iiE|LN$mK9>- zHdtpy8fGZtkZF?%TW~29JIAfi2jZT8>OA7=h;8T{{k?c2`nCEx9$r zS+*&vt~2o^^J+}RDG@+9&M^K*z4p{5#IEVbz`1%`m5c2};aGt=V?~vIM}ZdPECDI)47|CWBCfDWUbxBCnmYivQ*0Nu_xb*C>~C9(VjHM zxe<*D<#dQ8TlpMX2c@M<9$w!RP$hpG4cs%AI){jp*Sj|*`m)5(Bw*A0$*i-(CA5#%>a)$+jI2C9r6|(>J8InryENI z$NohnxDUB;wAYDwrb*!N3noBTKPpPN}~09SEL18tkG zxgz(RYU_;DPT{l?Q$+eaZaxnsWCA^ds^0PVRkIM%bOd|G2IEBBiz{&^JtNsODs;5z zICt_Zj8wo^KT$7Bg4H+y!Df#3mbl%%?|EXe!&(Vmac1DJ*y~3+kRKAD=Ovde4^^%~ zw<9av18HLyrf*_>Slp;^i`Uy~`mvBjZ|?Ad63yQa#YK`4+c6;pW4?XIY9G1(Xh9WO8{F-Aju+nS9Vmv=$Ac0ienZ+p9*O%NG zMZKy5?%Z6TAJTE?o5vEr0r>f>hb#2w2U3DL64*au_@P!J!TL`oH2r*{>ffu6|A7tv zL4juf$DZ1MW5ZPsG!5)`k8d8c$J$o;%EIL0va9&GzWvkS%ZsGb#S(?{!UFOZ9<$a| zY|a+5kmD5N&{vRqkgY>aHsBT&`rg|&kezoD)gP0fsNYHsO#TRc_$n6Lf1Z{?+DLziXlHrq4sf(!>O{?Tj;Eh@%)+nRE_2VxbN&&%%caU#JDU%vL3}Cb 
zsb4AazPI{>8H&d=jUaZDS$-0^AxE@utGs;-Ez_F(qC9T=UZX=>ok2k2 ziTn{K?y~a5reD2A)P${NoI^>JXn>`IeArow(41c-Wm~)wiryEP(OS{YXWi7;%dG9v zI?mwu1MxD{yp_rrk!j^cKM)dc4@p4Ezyo%lRN|XyD}}>v=Xoib0gOcdXrQ^*61HNj z=NP|pd>@yfvr-=m{8$3A8TQGMTE7g=z!%yt`8`Bk-0MMwW~h^++;qyUP!J~ykh1GO z(FZ59xuFR$(WE;F@UUyE@Sp>`aVNjyj=Ty>_Vo}xf`e7`F;j-IgL5`1~-#70$9_=uBMq!2&1l zomRgpD58@)YYfvLtPW}{C5B35R;ZVvB<<#)x%srmc_S=A7F@DW8>QOEGwD6suhwCg z>Pa+YyULhmw%BA*4yjDp|2{!T98~<6Yfd(wo1mQ!KWwq0eg+6)o1>W~f~kL<-S+P@$wx*zeI|1t7z#Sxr5 zt6w+;YblPQNplq4Z#T$GLX#j6yldXAqj>4gAnnWtBICUnA&-dtnlh=t0Ho_vEKwV` z)DlJi#!@nkYV#$!)@>udAU*hF?V`2$Hf=V&6PP_|r#Iv*J$9)pF@X3`k;5})9^o4y z&)~?EjX5yX12O(BsFy-l6}nYeuKkiq`u9145&3Ssg^y{5G3Pse z9w(YVa0)N-fLaBq1`P!_#>SS(8fh_5!f{UrgZ~uEdeMJIz7DzI5!NHHqQtm~#CPij z?=N|J>nPR6_sL7!f4hD_|KH`vf8(Wpnj-(gPWH+ZvID}%?~68SwhPTC3u1_cB`otq z)U?6qo!ZLi5b>*KnYHWW=3F!p%h1;h{L&(Q&{qY6)_qxNfbP6E3yYpW!EO+IW3?@J z);4>g4gnl^8klu7uA>eGF6rIGSynacogr)KUwE_R4E5Xzi*Qir@b-jy55-JPC8c~( zo!W8y9OGZ&`xmc8;=4-U9=h{vCqfCNzYirONmGbRQlR`WWlgnY+1wCXbMz&NT~9*| z6@FrzP!LX&{no2!Ln_3|I==_4`@}V?4a;YZKTdw;vT<+K+z=uWbW(&bXEaWJ^W8Td z-3&1bY^Z*oM<=M}LVt>_j+p=2Iu7pZmbXrhQ_k)ysE9yXKygFNw$5hwDn(M>H+e1&9BM5!|81vd%r%vEm zqxY3?F@fb6O#5UunwgAHR9jp_W2zZ}NGp2%mTW@(hz7$^+a`A?mb8|_G*GNMJ) zjqegXQio=i@AINre&%ofexAr95aop5C+0MZ0m-l=MeO8m3epm7U%vZB8+I+C*iNFM z#T3l`gknX;D$-`2XT^Cg*vrv=RH+P;_dfF++cP?B_msQI4j+lt&rX2)3GaJx%W*Nn zkML%D{z5tpHH=dksQ*gzc|}gzW;lwAbxoR07VNgS*-c3d&8J|;@3t^ zVUz*J*&r7DFRuFVDCJDK8V9NN5hvpgGjwx+5n)qa;YCKe8TKtdnh{I7NU9BCN!0dq zczrBk8pE{{@vJa9ywR@mq*J=v+PG;?fwqlJVhijG!3VmIKs>9T6r7MJpC)m!Tc#>g zMtVsU>wbwFJEfwZ{vB|ZlttNe83)$iz`~#8UJ^r)lJ@HA&G#}W&ZH*;k{=TavpjWE z7hdyLZPf*X%Gm}i`Y{OGeeu^~nB8=`{r#TUrM-`;1cBvEd#d!kPqIgYySYhN-*1;L z^byj%Yi}Gx)Wnkosi337BKs}+5H5dth1JA{Ir-JKN$7zC)*}hqeoD(WfaUDPT>0`- z(6sa0AoIqASwF`>hP}^|)a_j2s^PQn*qVC{Q}htR z5-)duBFXT_V56-+UohKXlq~^6uf!6sA#ttk1o~*QEy_Y-S$gAvq47J9Vtk$5oA$Ct zYhYJ@8{hsC^98${!#Ho?4y5MCa7iGnfz}b9jE~h%EAAv~Qxu)_rAV;^cygV~5r_~?l=B`zObj7S=H=~$W zPtI_m%g$`kL_fVUk9J@>EiBH 
zOO&jtn~&`hIFMS5S`g8w94R4H40mdNUH4W@@XQk1sr17b{@y|JB*G9z1|CrQjd+GX z6+KyURG3;!*BQrentw{B2R&@2&`2}n(z-2&X7#r!{yg@Soy}cRD~j zj9@UBW+N|4HW4AWapy4wfUI- zZ`gSL6DUlgj*f1hSOGXG0IVH8HxK?o2|3HZ;KW{K+yPAlxtb)NV_2AwJm|E)FRs&& z=c^e7bvUsztY|+f^k7NXs$o1EUq>cR7C0$UKi6IooHWlK_#?IWDkvywnzg&ThWo^? z2O_N{5X39#?eV9l)xI(>@!vSB{DLt*oY!K1R8}_?%+0^C{d9a%N4 zoxHVT1&Lm|uDX%$QrBun5e-F`HJ^T$ zmzv)p@4ZHd_w9!%Hf9UYNvGCw2TTTbrj9pl+T9%-_-}L(tES>Or-}Z4F*{##n3~L~TuxjirGuIY#H7{%$E${?p{Q01 zi6T`n;rbK1yIB9jmQNycD~yZq&mbIsFWHo|ZAChSFPQa<(%d8mGw*V3fh|yFoxOOiWJd(qvVb!Z$b88cg->N=qO*4k~6;R==|9ihg&riu#P~s4Oap9O7f%crSr^rljeIfXDEg>wi)&v*a%7zpz<9w z*r!3q9J|390x`Zk;g$&OeN&ctp)VKRpDSV@kU2Q>jtok($Y-*x8_$2piTxun81@vt z!Vj?COa0fg2RPXMSIo26T=~0d`{oGP*eV+$!0I<(4azk&Vj3SiG=Q!6mX0p$z7I}; z9BJUFgT-K9MQQ-0@Z=^7R<{bn2Fm48endsSs`V7_@%8?Bxkqv>BDoVcj?K#dV#uUP zL1ND~?D-|VGKe3Rw_7-Idpht>H6XRLh*U7epS6byiGvJpr%d}XwfusjH9g;Z98H`x zyde%%5mhGOiL4wljCaWCk-&uE4_OOccb9c!ZaWt4B(wYl!?vyzl%7n~QepN&eFUrw zFIOl9c({``6~QD+43*_tzP{f2x41h(?b43^y6=iwyB)2os5hBE!@YUS5?N_tXd=h( z)WE286Fbd>R4M^P{!G)f;h<3Q>Fipuy+d2q-)!RyTgt;wr$(?9ox3;q+{E*ZQHhOn;lM`cjnu9 zXa48ks-v(~b*;MAI<>YZH(^NV8vjb34beE<_cwKlJoR;k6lJNSP6v}uiyRD?|0w+X@o1ONrH8a$fCxXpf? 
z?$DL0)7|X}Oc%h^zrMKWc-NS9I0Utu@>*j}b@tJ=ixQSJ={4@854wzW@E>VSL+Y{i z#0b=WpbCZS>kUCO_iQz)LoE>P5LIG-hv9E+oG}DtlIDF>$tJ1aw9^LuhLEHt?BCj& z(O4I8v1s#HUi5A>nIS-JK{v!7dJx)^Yg%XjNmlkWAq2*cv#tHgz`Y(bETc6CuO1VkN^L-L3j_x<4NqYb5rzrLC-7uOv z!5e`GZt%B782C5-fGnn*GhDF$%(qP<74Z}3xx+{$4cYKy2ikxI7B2N+2r07DN;|-T->nU&!=Cm#rZt%O_5c&1Z%nlWq3TKAW0w zQqemZw_ue--2uKQsx+niCUou?HjD`xhEjjQd3%rrBi82crq*~#uA4+>vR<_S{~5ce z-2EIl?~s z1=GVL{NxP1N3%=AOaC}j_Fv=ur&THz zyO!d9kHq|c73kpq`$+t+8Bw7MgeR5~`d7ChYyGCBWSteTB>8WAU(NPYt2Dk`@#+}= zI4SvLlyk#pBgVigEe`?NG*vl7V6m+<}%FwPV=~PvvA)=#ths==DRTDEYh4V5}Cf$z@#;< zyWfLY_5sP$gc3LLl2x+Ii)#b2nhNXJ{R~vk`s5U7Nyu^3yFg&D%Txwj6QezMX`V(x z=C`{76*mNb!qHHs)#GgGZ_7|vkt9izl_&PBrsu@}L`X{95-2jf99K)0=*N)VxBX2q z((vkpP2RneSIiIUEnGb?VqbMb=Zia+rF~+iqslydE34cSLJ&BJW^3knX@M;t*b=EA zNvGzv41Ld_T+WT#XjDB840vovUU^FtN_)G}7v)1lPetgpEK9YS^OWFkPoE{ovj^=@ zO9N$S=G$1ecndT_=5ehth2Lmd1II-PuT~C9`XVePw$y8J#dpZ?Tss<6wtVglm(Ok7 z3?^oi@pPio6l&!z8JY(pJvG=*pI?GIOu}e^EB6QYk$#FJQ%^AIK$I4epJ+9t?KjqA+bkj&PQ*|vLttme+`9G=L% ziadyMw_7-M)hS(3E$QGNCu|o23|%O+VN7;Qggp?PB3K-iSeBa2b}V4_wY`G1Jsfz4 z9|SdB^;|I8E8gWqHKx!vj_@SMY^hLEIbSMCuE?WKq=c2mJK z8LoG-pnY!uhqFv&L?yEuxo{dpMTsmCn)95xanqBrNPTgXP((H$9N${Ow~Is-FBg%h z53;|Y5$MUN)9W2HBe2TD`ct^LHI<(xWrw}$qSoei?}s)&w$;&!14w6B6>Yr6Y8b)S z0r71`WmAvJJ`1h&poLftLUS6Ir zC$bG9!Im_4Zjse)#K=oJM9mHW1{%l8sz$1o?ltdKlLTxWWPB>Vk22czVt|1%^wnN@*!l)}?EgtvhC>vlHm^t+ogpgHI1_$1ox9e;>0!+b(tBrmXRB`PY1vp-R**8N7 zGP|QqI$m(Rdu#=(?!(N}G9QhQ%o!aXE=aN{&wtGP8|_qh+7a_j_sU5|J^)vxq;# zjvzLn%_QPHZZIWu1&mRAj;Sa_97p_lLq_{~j!M9N^1yp3U_SxRqK&JnR%6VI#^E12 z>CdOVI^_9aPK2eZ4h&^{pQs}xsijXgFYRIxJ~N7&BB9jUR1fm!(xl)mvy|3e6-B3j zJn#ajL;bFTYJ2+Q)tDjx=3IklO@Q+FFM}6UJr6km7hj7th9n_&JR7fnqC!hTZoM~T zBeaVFp%)0cbPhejX<8pf5HyRUj2>aXnXBqDJe73~J%P(2C?-RT{c3NjE`)om! 
zl$uewSgWkE66$Kb34+QZZvRn`fob~Cl9=cRk@Es}KQm=?E~CE%spXaMO6YmrMl%9Q zlA3Q$3|L1QJ4?->UjT&CBd!~ru{Ih^in&JXO=|<6J!&qp zRe*OZ*cj5bHYlz!!~iEKcuE|;U4vN1rk$xq6>bUWD*u(V@8sG^7>kVuo(QL@Ki;yL zWC!FT(q{E8#on>%1iAS0HMZDJg{Z{^!De(vSIq&;1$+b)oRMwA3nc3mdTSG#3uYO_ z>+x;7p4I;uHz?ZB>dA-BKl+t-3IB!jBRgdvAbW!aJ(Q{aT>+iz?91`C-xbe)IBoND z9_Xth{6?(y3rddwY$GD65IT#f3<(0o#`di{sh2gm{dw*#-Vnc3r=4==&PU^hCv$qd zjw;>i&?L*Wq#TxG$mFIUf>eK+170KG;~+o&1;Tom9}}mKo23KwdEM6UonXgc z!6N(@k8q@HPw{O8O!lAyi{rZv|DpgfU{py+j(X_cwpKqcalcqKIr0kM^%Br3SdeD> zHSKV94Yxw;pjzDHo!Q?8^0bb%L|wC;4U^9I#pd5O&eexX+Im{ z?jKnCcsE|H?{uGMqVie_C~w7GX)kYGWAg%-?8|N_1#W-|4F)3YTDC+QSq1s!DnOML3@d`mG%o2YbYd#jww|jD$gotpa)kntakp#K;+yo-_ZF9qrNZw<%#C zuPE@#3RocLgPyiBZ+R_-FJ_$xP!RzWm|aN)S+{$LY9vvN+IW~Kf3TsEIvP+B9Mtm! zpfNNxObWQpLoaO&cJh5>%slZnHl_Q~(-Tfh!DMz(dTWld@LG1VRF`9`DYKhyNv z2pU|UZ$#_yUx_B_|MxUq^glT}O5Xt(Vm4Mr02><%C)@v;vPb@pT$*yzJ4aPc_FZ3z z3}PLoMBIM>q_9U2rl^sGhk1VUJ89=*?7|v`{!Z{6bqFMq(mYiA?%KbsI~JwuqVA9$H5vDE+VocjX+G^%bieqx->s;XWlKcuv(s%y%D5Xbc9+ zc(_2nYS1&^yL*ey664&4`IoOeDIig}y-E~_GS?m;D!xv5-xwz+G`5l6V+}CpeJDi^ z%4ed$qowm88=iYG+(`ld5Uh&>Dgs4uPHSJ^TngXP_V6fPyl~>2bhi20QB%lSd#yYn zO05?KT1z@?^-bqO8Cg`;ft>ilejsw@2%RR7;`$Vs;FmO(Yr3Fp`pHGr@P2hC%QcA|X&N2Dn zYf`MqXdHi%cGR@%y7Rg7?d3?an){s$zA{!H;Ie5exE#c~@NhQUFG8V=SQh%UxUeiV zd7#UcYqD=lk-}sEwlpu&H^T_V0{#G?lZMxL7ih_&{(g)MWBnCZxtXg znr#}>U^6!jA%e}@Gj49LWG@*&t0V>Cxc3?oO7LSG%~)Y5}f7vqUUnQ;STjdDU}P9IF9d9<$;=QaXc zL1^X7>fa^jHBu_}9}J~#-oz3Oq^JmGR#?GO7b9a(=R@fw@}Q{{@`Wy1vIQ#Bw?>@X z-_RGG@wt|%u`XUc%W{J z>iSeiz8C3H7@St3mOr_mU+&bL#Uif;+Xw-aZdNYUpdf>Rvu0i0t6k*}vwU`XNO2he z%miH|1tQ8~ZK!zmL&wa3E;l?!!XzgV#%PMVU!0xrDsNNZUWKlbiOjzH-1Uoxm8E#r`#2Sz;-o&qcqB zC-O_R{QGuynW14@)7&@yw1U}uP(1cov)twxeLus0s|7ayrtT8c#`&2~Fiu2=R;1_4bCaD=*E@cYI>7YSnt)nQc zohw5CsK%m?8Ack)qNx`W0_v$5S}nO|(V|RZKBD+btO?JXe|~^Qqur%@eO~<8-L^9d z=GA3-V14ng9L29~XJ>a5k~xT2152zLhM*@zlp2P5Eu}bywkcqR;ISbas&#T#;HZSf z2m69qTV(V@EkY(1Dk3`}j)JMo%ZVJ*5eB zYOjIisi+igK0#yW*gBGj?@I{~mUOvRFQR^pJbEbzFxTubnrw(Muk%}jI+vXmJ;{Q6 
zrSobKD>T%}jV4Ub?L1+MGOD~0Ir%-`iTnWZN^~YPrcP5y3VMAzQ+&en^VzKEb$K!Q z<7Dbg&DNXuow*eD5yMr+#08nF!;%4vGrJI++5HdCFcGLfMW!KS*Oi@=7hFwDG!h2< zPunUEAF+HncQkbfFj&pbzp|MU*~60Z(|Ik%Tn{BXMN!hZOosNIseT?R;A`W?=d?5X zK(FB=9mZusYahp|K-wyb={rOpdn=@;4YI2W0EcbMKyo~-#^?h`BA9~o285%oY zfifCh5Lk$SY@|2A@a!T2V+{^!psQkx4?x0HSV`(w9{l75QxMk!)U52Lbhn{8ol?S) zCKo*7R(z!uk<6*qO=wh!Pul{(qq6g6xW;X68GI_CXp`XwO zxuSgPRAtM8K7}5E#-GM!*ydOOG_{A{)hkCII<|2=ma*71ci_-}VPARm3crFQjLYV! z9zbz82$|l01mv`$WahE2$=fAGWkd^X2kY(J7iz}WGS z@%MyBEO=A?HB9=^?nX`@nh;7;laAjs+fbo!|K^mE!tOB>$2a_O0y-*uaIn8k^6Y zSbuv;5~##*4Y~+y7Z5O*3w4qgI5V^17u*ZeupVGH^nM&$qmAk|anf*>r zWc5CV;-JY-Z@Uq1Irpb^O`L_7AGiqd*YpGUShb==os$uN3yYvb`wm6d=?T*it&pDk zo`vhw)RZX|91^^Wa_ti2zBFyWy4cJu#g)_S6~jT}CC{DJ_kKpT`$oAL%b^!2M;JgT zM3ZNbUB?}kP(*YYvXDIH8^7LUxz5oE%kMhF!rnPqv!GiY0o}NR$OD=ITDo9r%4E>E0Y^R(rS^~XjWyVI6 zMOR5rPXhTp*G*M&X#NTL`Hu*R+u*QNoiOKg4CtNPrjgH>c?Hi4MUG#I917fx**+pJfOo!zFM&*da&G_x)L(`k&TPI*t3e^{crd zX<4I$5nBQ8Ax_lmNRa~E*zS-R0sxkz`|>7q_?*e%7bxqNm3_eRG#1ae3gtV9!fQpY z+!^a38o4ZGy9!J5sylDxZTx$JmG!wg7;>&5H1)>f4dXj;B+@6tMlL=)cLl={jLMxY zbbf1ax3S4>bwB9-$;SN2?+GULu;UA-35;VY*^9Blx)Jwyb$=U!D>HhB&=jSsd^6yw zL)?a|>GxU!W}ocTC(?-%z3!IUhw^uzc`Vz_g>-tv)(XA#JK^)ZnC|l1`@CdX1@|!| z_9gQ)7uOf?cR@KDp97*>6X|;t@Y`k_N@)aH7gY27)COv^P3ya9I{4z~vUjLR9~z1Z z5=G{mVtKH*&$*t0@}-i_v|3B$AHHYale7>E+jP`ClqG%L{u;*ff_h@)al?RuL7tOO z->;I}>%WI{;vbLP3VIQ^iA$4wl6@0sDj|~112Y4OFjMs`13!$JGkp%b&E8QzJw_L5 zOnw9joc0^;O%OpF$Qp)W1HI!$4BaXX84`%@#^dk^hFp^pQ@rx4g(8Xjy#!X%+X5Jd@fs3amGT`}mhq#L97R>OwT5-m|h#yT_-v@(k$q7P*9X~T*3)LTdzP!*B} z+SldbVWrrwQo9wX*%FyK+sRXTa@O?WM^FGWOE?S`R(0P{<6p#f?0NJvnBia?k^fX2 zNQs7K-?EijgHJY}&zsr;qJ<*PCZUd*x|dD=IQPUK_nn)@X4KWtqoJNHkT?ZWL_hF? 
zS8lp2(q>;RXR|F;1O}EE#}gCrY~#n^O`_I&?&z5~7N;zL0)3Tup`%)oHMK-^r$NT% zbFg|o?b9w(q@)6w5V%si<$!U<#}s#x@0aX-hP>zwS#9*75VXA4K*%gUc>+yzupTDBOKH8WR4V0pM(HrfbQ&eJ79>HdCvE=F z|J>s;;iDLB^3(9}?biKbxf1$lI!*Z%*0&8UUq}wMyPs_hclyQQi4;NUY+x2qy|0J; zhn8;5)4ED1oHwg+VZF|80<4MrL97tGGXc5Sw$wAI#|2*cvQ=jB5+{AjMiDHmhUC*a zlmiZ`LAuAn_}hftXh;`Kq0zblDk8?O-`tnilIh|;3lZp@F_osJUV9`*R29M?7H{Fy z`nfVEIDIWXmU&YW;NjU8)EJpXhxe5t+scf|VXM!^bBlwNh)~7|3?fWwo_~ZFk(22% zTMesYw+LNx3J-_|DM~`v93yXe=jPD{q;li;5PD?Dyk+b? zo21|XpT@)$BM$%F=P9J19Vi&1#{jM3!^Y&fr&_`toi`XB1!n>sbL%U9I5<7!@?t)~ z;&H%z>bAaQ4f$wIzkjH70;<8tpUoxzKrPhn#IQfS%9l5=Iu))^XC<58D!-O z{B+o5R^Z21H0T9JQ5gNJnqh#qH^na|z92=hONIM~@_iuOi|F>jBh-?aA20}Qx~EpDGElELNn~|7WRXRFnw+Wdo`|# zBpU=Cz3z%cUJ0mx_1($X<40XEIYz(`noWeO+x#yb_pwj6)R(__%@_Cf>txOQ74wSJ z0#F3(zWWaR-jMEY$7C*3HJrohc79>MCUu26mfYN)f4M~4gD`}EX4e}A!U}QV8!S47 z6y-U-%+h`1n`*pQuKE%Av0@)+wBZr9mH}@vH@i{v(m-6QK7Ncf17x_D=)32`FOjjo zg|^VPf5c6-!FxN{25dvVh#fog=NNpXz zfB$o+0jbRkHH{!TKhE709f+jI^$3#v1Nmf80w`@7-5$1Iv_`)W^px8P-({xwb;D0y z7LKDAHgX<84?l!I*Dvi2#D@oAE^J|g$3!)x1Ua;_;<@#l1fD}lqU2_tS^6Ht$1Wl} zBESo7o^)9-Tjuz$8YQSGhfs{BQV6zW7dA?0b(Dbt=UnQs&4zHfe_sj{RJ4uS-vQpC zX;Bbsuju4%!o8?&m4UZU@~ZZjeFF6ex2ss5_60_JS_|iNc+R0GIjH1@Z z=rLT9%B|WWgOrR7IiIwr2=T;Ne?30M!@{%Qf8o`!>=s<2CBpCK_TWc(DX51>e^xh8 z&@$^b6CgOd7KXQV&Y4%}_#uN*mbanXq(2=Nj`L7H7*k(6F8s6{FOw@(DzU`4-*77{ zF+dxpv}%mFpYK?>N_2*#Y?oB*qEKB}VoQ@bzm>ptmVS_EC(#}Lxxx730trt0G)#$b zE=wVvtqOct1%*9}U{q<)2?{+0TzZzP0jgf9*)arV)*e!f`|jgT{7_9iS@e)recI#z zbzolURQ+TOzE!ymqvBY7+5NnAbWxvMLsLTwEbFqW=CPyCsmJ}P1^V30|D5E|p3BC5 z)3|qgw@ra7aXb-wsa|l^in~1_fm{7bS9jhVRkYVO#U{qMp z)Wce+|DJ}4<2gp8r0_xfZpMo#{Hl2MfjLcZdRB9(B(A(f;+4s*FxV{1F|4d`*sRNd zp4#@sEY|?^FIJ;tmH{@keZ$P(sLh5IdOk@k^0uB^BWr@pk6mHy$qf&~rI>P*a;h0C{%oA*i!VjWn&D~O#MxN&f@1Po# zKN+ zrGrkSjcr?^R#nGl<#Q722^wbYcgW@{+6CBS<1@%dPA8HC!~a`jTz<`g_l5N1M@9wn9GOAZ>nqNgq!yOCbZ@1z`U_N`Z>}+1HIZxk*5RDc&rd5{3qjRh8QmT$VyS;jK z;AF+r6XnnCp=wQYoG|rT2@8&IvKq*IB_WvS%nt%e{MCFm`&W*#LXc|HrD?nVBo=(8*=Aq?u$sDA_sC_RPDUiQ+wnIJET8vx$&fxkW~kP9qXKt 
zozR)@xGC!P)CTkjeWvXW5&@2?)qt)jiYWWBU?AUtzAN}{JE1I)dfz~7$;}~BmQF`k zpn11qmObXwRB8&rnEG*#4Xax3XBkKlw(;tb?Np^i+H8m(Wyz9k{~ogba@laiEk;2! zV*QV^6g6(QG%vX5Um#^sT&_e`B1pBW5yVth~xUs#0}nv?~C#l?W+9Lsb_5)!71rirGvY zTIJ$OPOY516Y|_014sNv+Z8cc5t_V=i>lWV=vNu#!58y9Zl&GsMEW#pPYPYGHQ|;vFvd*9eM==$_=vc7xnyz0~ zY}r??$<`wAO?JQk@?RGvkWVJlq2dk9vB(yV^vm{=NVI8dhsX<)O(#nr9YD?I?(VmQ z^r7VfUBn<~p3()8yOBjm$#KWx!5hRW)5Jl7wY@ky9lNM^jaT##8QGVsYeaVywmpv>X|Xj7gWE1Ezai&wVLt3p)k4w~yrskT-!PR!kiyQlaxl(( zXhF%Q9x}1TMt3~u@|#wWm-Vq?ZerK={8@~&@9r5JW}r#45#rWii};t`{5#&3$W)|@ zbAf2yDNe0q}NEUvq_Quq3cTjcw z@H_;$hu&xllCI9CFDLuScEMg|x{S7GdV8<&Mq=ezDnRZAyX-8gv97YTm0bg=d)(>N z+B2FcqvI9>jGtnK%eO%y zoBPkJTk%y`8TLf4)IXPBn`U|9>O~WL2C~C$z~9|0m*YH<-vg2CD^SX#&)B4ngOSG$ zV^wmy_iQk>dfN@Pv(ckfy&#ak@MLC7&Q6Ro#!ezM*VEh`+b3Jt%m(^T&p&WJ2Oqvj zs-4nq0TW6cv~(YI$n0UkfwN}kg3_fp?(ijSV#tR9L0}l2qjc7W?i*q01=St0eZ=4h zyGQbEw`9OEH>NMuIe)hVwYHsGERWOD;JxEiO7cQv%pFCeR+IyhwQ|y@&^24k+|8fD zLiOWFNJ2&vu2&`Jv96_z-Cd5RLgmeY3*4rDOQo?Jm`;I_(+ejsPM03!ly!*Cu}Cco zrQSrEDHNyzT(D5s1rZq!8#?f6@v6dB7a-aWs(Qk>N?UGAo{gytlh$%_IhyL7h?DLXDGx zgxGEBQoCAWo-$LRvM=F5MTle`M})t3vVv;2j0HZY&G z22^iGhV@uaJh(XyyY%} zd4iH_UfdV#T=3n}(Lj^|n;O4|$;xhu*8T3hR1mc_A}fK}jfZ7LX~*n5+`8N2q#rI$ z@<_2VANlYF$vIH$ zl<)+*tIWW78IIINA7Rr7i{<;#^yzxoLNkXL)eSs=%|P>$YQIh+ea_3k z_s7r4%j7%&*NHSl?R4k%1>Z=M9o#zxY!n8sL5>BO-ZP;T3Gut>iLS@U%IBrX6BA3k z)&@q}V8a{X<5B}K5s(c(LQ=%v1ocr`t$EqqY0EqVjr65usa=0bkf|O#ky{j3)WBR(((L^wmyHRzoWuL2~WTC=`yZ zn%VX`L=|Ok0v7?s>IHg?yArBcync5rG#^+u)>a%qjES%dRZoIyA8gQ;StH z1Ao7{<&}6U=5}4v<)1T7t!J_CL%U}CKNs-0xWoTTeqj{5{?Be$L0_tk>M9o8 zo371}S#30rKZFM{`H_(L`EM9DGp+Mifk&IP|C2Zu_)Ghr4Qtpmkm1osCf@%Z$%t+7 zYH$Cr)Ro@3-QDeQJ8m+x6%;?YYT;k6Z0E-?kr>x33`H%*ueBD7Zx~3&HtWn0?2Wt} zTG}*|v?{$ajzt}xPzV%lL1t-URi8*Zn)YljXNGDb>;!905Td|mpa@mHjIH%VIiGx- zd@MqhpYFu4_?y5N4xiHn3vX&|e6r~Xt> zZG`aGq|yTNjv;9E+Txuoa@A(9V7g?1_T5FzRI;!=NP1Kqou1z5?%X~Wwb{trRfd>i z8&y^H)8YnKyA_Fyx>}RNmQIczT?w2J4SNvI{5J&}Wto|8FR(W;Qw#b1G<1%#tmYzQ zQ2mZA-PAdi%RQOhkHy9Ea#TPSw?WxwL@H@cbkZwIq0B!@ns}niALidmn&W?!Vd4Gj 
zO7FiuV4*6Mr^2xlFSvM;Cp_#r8UaqIzHJQg_z^rEJw&OMm_8NGAY2)rKvki|o1bH~ z$2IbfVeY2L(^*rMRU1lM5Y_sgrDS`Z??nR2lX;zyR=c%UyGb*%TC-Dil?SihkjrQy~TMv6;BMs7P8il`H7DmpVm@rJ;b)hW)BL)GjS154b*xq-NXq2cwE z^;VP7ua2pxvCmxrnqUYQMH%a%nHmwmI33nJM(>4LznvY*k&C0{8f*%?zggpDgkuz&JBx{9mfb@wegEl2v!=}Sq2Gaty0<)UrOT0{MZtZ~j5y&w zXlYa_jY)I_+VA-^#mEox#+G>UgvM!Ac8zI<%JRXM_73Q!#i3O|)lOP*qBeJG#BST0 zqohi)O!|$|2SeJQo(w6w7%*92S})XfnhrH_Z8qe!G5>CglP=nI7JAOW?(Z29;pXJ9 zR9`KzQ=WEhy*)WH>$;7Cdz|>*i>=##0bB)oU0OR>>N<21e4rMCHDemNi2LD>Nc$;& zQRFthpWniC1J6@Zh~iJCoLOxN`oCKD5Q4r%ynwgUKPlIEd#?QViIqovY|czyK8>6B zSP%{2-<;%;1`#0mG^B(8KbtXF;Nf>K#Di72UWE4gQ%(_26Koiad)q$xRL~?pN71ZZ zujaaCx~jXjygw;rI!WB=xrOJO6HJ!!w}7eiivtCg5K|F6$EXa)=xUC za^JXSX98W`7g-tm@uo|BKj39Dl;sg5ta;4qjo^pCh~{-HdLl6qI9Ix6f$+qiZ$}s= zNguKrU;u+T@ko(Vr1>)Q%h$?UKXCY>3se%&;h2osl2D zE4A9bd7_|^njDd)6cI*FupHpE3){4NQ*$k*cOWZ_?CZ>Z4_fl@n(mMnYK62Q1d@+I zr&O))G4hMihgBqRIAJkLdk(p(D~X{-oBUA+If@B}j& zsHbeJ3RzTq96lB7d($h$xTeZ^gP0c{t!Y0c)aQE;$FY2!mACg!GDEMKXFOPI^)nHZ z`aSPJpvV0|bbrzhWWkuPURlDeN%VT8tndV8?d)eN*i4I@u zVKl^6{?}A?P)Fsy?3oi#clf}L18t;TjNI2>eI&(ezDK7RyqFxcv%>?oxUlonv(px) z$vnPzRH`y5A(x!yOIfL0bmgeMQB$H5wenx~!ujQK*nUBW;@Em&6Xv2%s(~H5WcU2R z;%Nw<$tI)a`Ve!>x+qegJnQsN2N7HaKzrFqM>`6R*gvh%O*-%THt zrB$Nk;lE;z{s{r^PPm5qz(&lM{sO*g+W{sK+m3M_z=4=&CC>T`{X}1Vg2PEfSj2x_ zmT*(x;ov%3F?qoEeeM>dUn$a*?SIGyO8m806J1W1o+4HRhc2`9$s6hM#qAm zChQ87b~GEw{ADfs+5}FJ8+|bIlIv(jT$Ap#hSHoXdd9#w<#cA<1Rkq^*EEkknUd4& zoIWIY)sAswy6fSERVm&!SO~#iN$OgOX*{9@_BWFyJTvC%S++ilSfCrO(?u=Dc?CXZ zzCG&0yVR{Z`|ZF0eEApWEo#s9osV>F{uK{QA@BES#&;#KsScf>y zvs?vIbI>VrT<*!;XmQS=bhq%46-aambZ(8KU-wOO2=en~D}MCToB_u;Yz{)1ySrPZ z@=$}EvjTdzTWU7c0ZI6L8=yP+YRD_eMMos}b5vY^S*~VZysrkq<`cK3>>v%uy7jgq z0ilW9KjVDHLv0b<1K_`1IkbTOINs0=m-22c%M~l=^S}%hbli-3?BnNq?b`hx^HX2J zIe6ECljRL0uBWb`%{EA=%!i^4sMcj+U_TaTZRb+~GOk z^ZW!nky0n*Wb*r+Q|9H@ml@Z5gU&W`(z4-j!OzC1wOke`TRAYGZVl$PmQ16{3196( zO*?`--I}Qf(2HIwb2&1FB^!faPA2=sLg(@6P4mN)>Dc3i(B0;@O-y2;lM4akD>@^v z=u>*|!s&9zem70g7zfw9FXl1bpJW(C#5w#uy5!V?Q(U35A~$dR%LDVnq@}kQm13{} 
zd53q3N(s$Eu{R}k2esbftfjfOITCL;jWa$}(mmm}d(&7JZ6d3%IABCapFFYjdEjdK z&4Edqf$G^MNAtL=uCDRs&Fu@FXRgX{*0<(@c3|PNHa>L%zvxWS={L8%qw`STm+=Rd zA}FLspESSIpE_^41~#5yI2bJ=9`oc;GIL!JuW&7YetZ?0H}$$%8rW@*J37L-~Rsx!)8($nI4 zZhcZ2^=Y+p4YPl%j!nFJA|*M^gc(0o$i3nlphe+~-_m}jVkRN{spFs(o0ajW@f3K{ zDV!#BwL322CET$}Y}^0ixYj2w>&Xh12|R8&yEw|wLDvF!lZ#dOTHM9pK6@Nm-@9Lnng4ZHBgBSrr7KI8YCC9DX5Kg|`HsiwJHg2(7#nS;A{b3tVO?Z% za{m5b3rFV6EpX;=;n#wltDv1LE*|g5pQ+OY&*6qCJZc5oDS6Z6JD#6F)bWxZSF@q% z+1WV;m!lRB!n^PC>RgQCI#D1br_o^#iPk>;K2hB~0^<~)?p}LG%kigm@moD#q3PE+ zA^Qca)(xnqw6x>XFhV6ku9r$E>bWNrVH9fum0?4s?Rn2LG{Vm_+QJHse6xa%nzQ?k zKug4PW~#Gtb;#5+9!QBgyB@q=sk9=$S{4T>wjFICStOM?__fr+Kei1 z3j~xPqW;W@YkiUM;HngG!;>@AITg}vAE`M2Pj9Irl4w1fo4w<|Bu!%rh%a(Ai^Zhi zs92>v5;@Y(Zi#RI*ua*h`d_7;byQSa*v9E{2x$<-_=5Z<7{%)}4XExANcz@rK69T0x3%H<@frW>RA8^swA+^a(FxK| zFl3LD*ImHN=XDUkrRhp6RY5$rQ{bRgSO*(vEHYV)3Mo6Jy3puiLmU&g82p{qr0F?ohmbz)f2r{X2|T2 z$4fdQ=>0BeKbiVM!e-lIIs8wVTuC_m7}y4A_%ikI;Wm5$9j(^Y z(cD%U%k)X>_>9~t8;pGzL6L-fmQO@K; zo&vQzMlgY95;1BSkngY)e{`n0!NfVgf}2mB3t}D9@*N;FQ{HZ3Pb%BK6;5#-O|WI( zb6h@qTLU~AbVW#_6?c!?Dj65Now7*pU{h!1+eCV^KCuPAGs28~3k@ueL5+u|Z-7}t z9|lskE`4B7W8wMs@xJa{#bsCGDFoRSNSnmNYB&U7 zVGKWe%+kFB6kb)e;TyHfqtU6~fRg)f|>=5(N36)0+C z`hv65J<$B}WUc!wFAb^QtY31yNleq4dzmG`1wHTj=c*=hay9iD071Hc?oYoUk|M*_ zU1GihAMBsM@5rUJ(qS?9ZYJ6@{bNqJ`2Mr+5#hKf?doa?F|+^IR!8lq9)wS3tF_9n zW_?hm)G(M+MYb?V9YoX^_mu5h-LP^TL^!Q9Z7|@sO(rg_4+@=PdI)WL(B7`!K^ND- z-uIuVDCVEdH_C@c71YGYT^_Scf_dhB8Z2Xy6vGtBSlYud9vggOqv^L~F{BraSE_t} zIkP+Hp2&nH^-MNEs}^`oMLy11`PQW$T|K(`Bu*(f@)mv1-qY(_YG&J2M2<7k;;RK~ zL{Fqj9yCz8(S{}@c)S!65aF<=&eLI{hAMErCx&>i7OeDN>okvegO87OaG{Jmi<|}D zaT@b|0X{d@OIJ7zvT>r+eTzgLq~|Dpu)Z&db-P4z*`M$UL51lf>FLlq6rfG)%doyp z)3kk_YIM!03eQ8Vu_2fg{+osaEJPtJ-s36R+5_AEG12`NG)IQ#TF9c@$99%0iye+ zUzZ57=m2)$D(5Nx!n)=5Au&O0BBgwxIBaeI(mro$#&UGCr<;C{UjJVAbVi%|+WP(a zL$U@TYCxJ=1{Z~}rnW;7UVb7+ZnzgmrogDxhjLGo>c~MiJAWs&&;AGg@%U?Y^0JhL ze(x6Z74JG6FlOFK(T}SXQfhr}RIFl@QXKnIcXYF)5|V~e-}suHILKT-k|<*~Ij|VF zC;t@=uj=hot~*!C68G8hTA%8SzOfETOXQ|3FSaIEjvBJp(A)7SWUi5!Eu#yWgY+;n 
zlm<$+UDou*V+246_o#V4kMdto8hF%%Lki#zPh}KYXmMf?hrN0;>Mv%`@{0Qn`Ujp) z=lZe+13>^Q!9zT);H<(#bIeRWz%#*}sgUX9P|9($kexOyKIOc`dLux}c$7It4u|Rl z6SSkY*V~g_B-hMPo_ak>>z@AVQ(_N)VY2kB3IZ0G(iDUYw+2d7W^~(Jq}KY=JnWS( z#rzEa&0uNhJ>QE8iiyz;n2H|SV#Og+wEZv=f2%1ELX!SX-(d3tEj$5$1}70Mp<&eI zCkfbByL7af=qQE@5vDVxx1}FSGt_a1DoE3SDI+G)mBAna)KBG4p8Epxl9QZ4BfdAN zFnF|Y(umr;gRgG6NLQ$?ZWgllEeeq~z^ZS7L?<(~O&$5|y)Al^iMKy}&W+eMm1W z7EMU)u^ke(A1#XCV>CZ71}P}0x)4wtHO8#JRG3MA-6g=`ZM!FcICCZ{IEw8Dm2&LQ z1|r)BUG^0GzI6f946RrBlfB1Vs)~8toZf~7)+G;pv&XiUO(%5bm)pl=p>nV^o*;&T z;}@oZSibzto$arQgfkp|z4Z($P>dTXE{4O=vY0!)kDO* zGF8a4wq#VaFpLfK!iELy@?-SeRrdz%F*}hjKcA*y@mj~VD3!it9lhRhX}5YOaR9$} z3mS%$2Be7{l(+MVx3 z(4?h;P!jnRmX9J9sYN#7i=iyj_5q7n#X(!cdqI2lnr8T$IfOW<_v`eB!d9xY1P=2q&WtOXY=D9QYteP)De?S4}FK6#6Ma z=E*V+#s8>L;8aVroK^6iKo=MH{4yEZ_>N-N z`(|;aOATba1^asjxlILk<4}f~`39dBFlxj>Dw(hMYKPO3EEt1@S`1lxFNM+J@uB7T zZ8WKjz7HF1-5&2=l=fqF-*@>n5J}jIxdDwpT?oKM3s8Nr`x8JnN-kCE?~aM1H!hAE z%%w(3kHfGwMnMmNj(SU(w42OrC-euI>Dsjk&jz3ts}WHqmMpzQ3vZrsXrZ|}+MHA7 z068obeXZTsO*6RS@o3x80E4ok``rV^Y3hr&C1;|ZZ0|*EKO`$lECUYG2gVFtUTw)R z4Um<0ZzlON`zTdvVdL#KFoMFQX*a5wM0Czp%wTtfK4Sjs)P**RW&?lP$(<}q%r68Z zS53Y!d@&~ne9O)A^tNrXHhXBkj~$8j%pT1%%mypa9AW5E&s9)rjF4@O3ytH{0z6riz|@< zB~UPh*wRFg2^7EbQrHf0y?E~dHlkOxof_a?M{LqQ^C!i2dawHTPYUE=X@2(3<=OOxs8qn_(y>pU>u^}3y&df{JarR0@VJn0f+U%UiF=$Wyq zQvnVHESil@d|8&R<%}uidGh7@u^(%?$#|&J$pvFC-n8&A>utA=n3#)yMkz+qnG3wd zP7xCnF|$9Dif@N~L)Vde3hW8W!UY0BgT2v(wzp;tlLmyk2%N|0jfG$%<;A&IVrOI< z!L)o>j>;dFaqA3pL}b-Je(bB@VJ4%!JeX@3x!i{yIeIso^=n?fDX`3bU=eG7sTc%g%ye8$v8P@yKE^XD=NYxTb zbf!Mk=h|otpqjFaA-vs5YOF-*GwWPc7VbaOW&stlANnCN8iftFMMrUdYNJ_Bnn5Vt zxfz@Ah|+4&P;reZxp;MmEI7C|FOv8NKUm8njF7Wb6Gi7DeODLl&G~}G4be&*Hi0Qw z5}77vL0P+7-B%UL@3n1&JPxW^d@vVwp?u#gVcJqY9#@-3X{ok#UfW3<1fb%FT`|)V~ggq z(3AUoUS-;7)^hCjdT0Kf{i}h)mBg4qhtHHBti=~h^n^OTH5U*XMgDLIR@sre`AaB$ zg)IGBET_4??m@cx&c~bA80O7B8CHR7(LX7%HThkeC*@vi{-pL%e)yXp!B2InafbDF zjPXf1mko3h59{lT6EEbxKO1Z5GF71)WwowO6kY|6tjSVSWdQ}NsK2x{>i|MKZK8%Q 
zfu&_0D;CO-Jg0#YmyfctyJ!mRJp)e#@O0mYdp|8x;G1%OZQ3Q847YWTyy|%^cpA;m zze0(5p{tMu^lDkpe?HynyO?a1$_LJl2L&mpeKu%8YvgRNr=%2z${%WThHG=vrWY@4 zsA`OP#O&)TetZ>s%h!=+CE15lOOls&nvC~$Qz0Ph7tHiP;O$i|eDwpT{cp>+)0-|; zY$|bB+Gbel>5aRN3>c0x)4U=|X+z+{ zn*_p*EQoquRL+=+p;=lm`d71&1NqBz&_ph)MXu(Nv6&XE7(RsS)^MGj5Q?Fwude-(sq zjJ>aOq!7!EN>@(fK7EE#;i_BGvli`5U;r!YA{JRodLBc6-`n8K+Fjgwb%sX;j=qHQ z7&Tr!)!{HXoO<2BQrV9Sw?JRaLXV8HrsNevvnf>Y-6|{T!pYLl7jp$-nEE z#X!4G4L#K0qG_4Z;Cj6=;b|Be$hi4JvMH!-voxqx^@8cXp`B??eFBz2lLD8RRaRGh zn7kUfy!YV~p(R|p7iC1Rdgt$_24i0cd-S8HpG|`@my70g^y`gu%#Tf_L21-k?sRRZHK&at(*ED0P8iw{7?R$9~OF$Ko;Iu5)ur5<->x!m93Eb zFYpIx60s=Wxxw=`$aS-O&dCO_9?b1yKiPCQmSQb>T)963`*U+Ydj5kI(B(B?HNP8r z*bfSBpSu)w(Z3j7HQoRjUG(+d=IaE~tv}y14zHHs|0UcN52fT8V_<@2ep_ee{QgZG zmgp8iv4V{k;~8@I%M3<#B;2R>Ef(Gg_cQM7%}0s*^)SK6!Ym+~P^58*wnwV1BW@eG z4sZLqsUvBbFsr#8u7S1r4teQ;t)Y@jnn_m5jS$CsW1um!p&PqAcc8!zyiXHVta9QC zY~wCwCF0U%xiQPD_INKtTb;A|Zf29(mu9NI;E zc-e>*1%(LSXB`g}kd`#}O;veb<(sk~RWL|f3ljxCnEZDdNSTDV6#Td({6l&y4IjKF z^}lIUq*ZUqgTPumD)RrCN{M^jhY>E~1pn|KOZ5((%F)G|*ZQ|r4zIbrEiV%42hJV8 z3xS)=!X1+=olbdGJ=yZil?oXLct8FM{(6ikLL3E%=q#O6(H$p~gQu6T8N!plf!96| z&Q3=`L~>U0zZh;z(pGR2^S^{#PrPxTRHD1RQOON&f)Siaf`GLj#UOk&(|@0?zm;Sx ztsGt8=29-MZs5CSf1l1jNFtNt5rFNZxJPvkNu~2}7*9468TWm>nN9TP&^!;J{-h)_ z7WsHH9|F%I`Pb!>KAS3jQWKfGivTVkMJLO-HUGM_a4UQ_%RgL6WZvrW+Z4ujZn;y@ zz9$=oO!7qVTaQAA^BhX&ZxS*|5dj803M=k&2%QrXda`-Q#IoZL6E(g+tN!6CA!CP* zCpWtCujIea)ENl0liwVfj)Nc<9mV%+e@=d`haoZ*`B7+PNjEbXBkv=B+Pi^~L#EO$D$ZqTiD8f<5$eyb54-(=3 zh)6i8i|jp(@OnRrY5B8t|LFXFQVQ895n*P16cEKTrT*~yLH6Z4e*bZ5otpRDri&+A zfNbK1D5@O=sm`fN=WzWyse!za5n%^+6dHPGX#8DyIK>?9qyX}2XvBWVqbP%%D)7$= z=#$WulZlZR<{m#gU7lwqK4WS1Ne$#_P{b17qe$~UOXCl>5b|6WVh;5vVnR<%d+Lnp z$uEmML38}U4vaW8>shm6CzB(Wei3s#NAWE3)a2)z@i{4jTn;;aQS)O@l{rUM`J@K& l00vQ5JBs~;vo!vr%%-k{2_Fq1Mn4QF81S)AQ99zk{{c4yR+0b! 
literal 0 HcmV?d00001 diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 0000000..e6aba25 --- /dev/null +++ b/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,7 @@ +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-8.5-all.zip +networkTimeout=10000 +validateDistributionUrl=true +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists diff --git a/gradlew b/gradlew new file mode 100755 index 0000000..1aa94a4 --- /dev/null +++ b/gradlew @@ -0,0 +1,249 @@ +#!/bin/sh + +# +# Copyright © 2015-2021 the original authors. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +############################################################################## +# +# Gradle start up script for POSIX generated by Gradle. +# +# Important for running: +# +# (1) You need a POSIX-compliant shell to run this script. 
If your /bin/sh is +# noncompliant, but you have some other compliant shell such as ksh or +# bash, then to run this script, type that shell name before the whole +# command line, like: +# +# ksh Gradle +# +# Busybox and similar reduced shells will NOT work, because this script +# requires all of these POSIX shell features: +# * functions; +# * expansions «$var», «${var}», «${var:-default}», «${var+SET}», +# «${var#prefix}», «${var%suffix}», and «$( cmd )»; +# * compound commands having a testable exit status, especially «case»; +# * various built-in commands including «command», «set», and «ulimit». +# +# Important for patching: +# +# (2) This script targets any POSIX shell, so it avoids extensions provided +# by Bash, Ksh, etc; in particular arrays are avoided. +# +# The "traditional" practice of packing multiple parameters into a +# space-separated string is a well documented source of bugs and security +# problems, so this is (mostly) avoided, by progressively accumulating +# options in "$@", and eventually passing that to Java. +# +# Where the inherited environment variables (DEFAULT_JVM_OPTS, JAVA_OPTS, +# and GRADLE_OPTS) rely on word-splitting, this is performed explicitly; +# see the in-line comments for details. +# +# There are tweaks for specific operating systems such as AIX, CygWin, +# Darwin, MinGW, and NonStop. +# +# (3) This script is generated from the Groovy template +# https://github.com/gradle/gradle/blob/HEAD/subprojects/plugins/src/main/resources/org/gradle/api/internal/plugins/unixStartScript.txt +# within the Gradle project. +# +# You can find Gradle at https://github.com/gradle/gradle/. +# +############################################################################## + +# Attempt to set APP_HOME + +# Resolve links: $0 may be a link +app_path=$0 + +# Need this for daisy-chained symlinks. 
+while + APP_HOME=${app_path%"${app_path##*/}"} # leaves a trailing /; empty if no leading path + [ -h "$app_path" ] +do + ls=$( ls -ld "$app_path" ) + link=${ls#*' -> '} + case $link in #( + /*) app_path=$link ;; #( + *) app_path=$APP_HOME$link ;; + esac +done + +# This is normally unused +# shellcheck disable=SC2034 +APP_BASE_NAME=${0##*/} +# Discard cd standard output in case $CDPATH is set (https://github.com/gradle/gradle/issues/25036) +APP_HOME=$( cd "${APP_HOME:-./}" > /dev/null && pwd -P ) || exit + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD=maximum + +warn () { + echo "$*" +} >&2 + +die () { + echo + echo "$*" + echo + exit 1 +} >&2 + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +nonstop=false +case "$( uname )" in #( + CYGWIN* ) cygwin=true ;; #( + Darwin* ) darwin=true ;; #( + MSYS* | MINGW* ) msys=true ;; #( + NONSTOP* ) nonstop=true ;; +esac + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD=$JAVA_HOME/jre/sh/java + else + JAVACMD=$JAVA_HOME/bin/java + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD=java + if ! command -v java >/dev/null 2>&1 + then + die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +fi + +# Increase the maximum file descriptors if we can. +if ! "$cygwin" && ! "$darwin" && ! "$nonstop" ; then + case $MAX_FD in #( + max*) + # In POSIX sh, ulimit -H is undefined. That's why the result is checked to see if it worked. 
+ # shellcheck disable=SC2039,SC3045 + MAX_FD=$( ulimit -H -n ) || + warn "Could not query maximum file descriptor limit" + esac + case $MAX_FD in #( + '' | soft) :;; #( + *) + # In POSIX sh, ulimit -n is undefined. That's why the result is checked to see if it worked. + # shellcheck disable=SC2039,SC3045 + ulimit -n "$MAX_FD" || + warn "Could not set maximum file descriptor limit to $MAX_FD" + esac +fi + +# Collect all arguments for the java command, stacking in reverse order: +# * args from the command line +# * the main class name +# * -classpath +# * -D...appname settings +# * --module-path (only if needed) +# * DEFAULT_JVM_OPTS, JAVA_OPTS, and GRADLE_OPTS environment variables. + +# For Cygwin or MSYS, switch paths to Windows format before running java +if "$cygwin" || "$msys" ; then + APP_HOME=$( cygpath --path --mixed "$APP_HOME" ) + CLASSPATH=$( cygpath --path --mixed "$CLASSPATH" ) + + JAVACMD=$( cygpath --unix "$JAVACMD" ) + + # Now convert the arguments - kludge to limit ourselves to /bin/sh + for arg do + if + case $arg in #( + -*) false ;; # don't mess with options #( + /?*) t=${arg#/} t=/${t%%/*} # looks like a POSIX filepath + [ -e "$t" ] ;; #( + *) false ;; + esac + then + arg=$( cygpath --path --ignore --mixed "$arg" ) + fi + # Roll the args list around exactly as many times as the number of + # args, so each arg winds up back in the position where it started, but + # possibly modified. + # + # NB: a `for` loop captures its iteration list before it begins, so + # changing the positional parameters here affects neither the number of + # iterations, nor the values presented in `arg`. + shift # remove old arg + set -- "$@" "$arg" # push replacement arg + done +fi + + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. 
+DEFAULT_JVM_OPTS='"-Xmx64m" "-Xms64m"' + +# Collect all arguments for the java command: +# * DEFAULT_JVM_OPTS, JAVA_OPTS, JAVA_OPTS, and optsEnvironmentVar are not allowed to contain shell fragments, +# and any embedded shellness will be escaped. +# * For example: A user cannot expect ${Hostname} to be expanded, as it is an environment variable and will be +# treated as '${Hostname}' itself on the command line. + +set -- \ + "-Dorg.gradle.appname=$APP_BASE_NAME" \ + -classpath "$CLASSPATH" \ + org.gradle.wrapper.GradleWrapperMain \ + "$@" + +# Stop when "xargs" is not available. +if ! command -v xargs >/dev/null 2>&1 +then + die "xargs is not available" +fi + +# Use "xargs" to parse quoted args. +# +# With -n1 it outputs one arg per line, with the quotes and backslashes removed. +# +# In Bash we could simply go: +# +# readarray ARGS < <( xargs -n1 <<<"$var" ) && +# set -- "${ARGS[@]}" "$@" +# +# but POSIX shell has neither arrays nor command substitution, so instead we +# post-process each arg (as a line of input to sed) to backslash-escape any +# character that might be a shell metacharacter, then use eval to reverse +# that process (while maintaining the separation between arguments), and wrap +# the whole thing up as a single "set" statement. +# +# This will of course break if any of these variables contains a newline or +# an unmatched quote. +# + +eval "set -- $( + printf '%s\n' "$DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS" | + xargs -n1 | + sed ' s~[^-[:alnum:]+,./:=@_]~\\&~g; ' | + tr '\n' ' ' + )" '"$@"' + +exec "$JAVACMD" "$@" diff --git a/gradlew.bat b/gradlew.bat new file mode 100644 index 0000000..6689b85 --- /dev/null +++ b/gradlew.bat @@ -0,0 +1,92 @@ +@rem +@rem Copyright 2015 the original author or authors. +@rem +@rem Licensed under the Apache License, Version 2.0 (the "License"); +@rem you may not use this file except in compliance with the License. 
+@rem You may obtain a copy of the License at +@rem +@rem https://www.apache.org/licenses/LICENSE-2.0 +@rem +@rem Unless required by applicable law or agreed to in writing, software +@rem distributed under the License is distributed on an "AS IS" BASIS, +@rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +@rem See the License for the specific language governing permissions and +@rem limitations under the License. +@rem + +@if "%DEBUG%"=="" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +set DIRNAME=%~dp0 +if "%DIRNAME%"=="" set DIRNAME=. +@rem This is normally unused +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Resolve any "." and ".." in APP_HOME to make it shorter. +for %%i in ("%APP_HOME%") do set APP_HOME=%%~fi + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS="-Xmx64m" "-Xms64m" + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if %ERRORLEVEL% equ 0 goto execute + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto execute + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. 
+ +goto fail + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %* + +:end +@rem End local scope for the variables with windows NT shell +if %ERRORLEVEL% equ 0 goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +set EXIT_CODE=%ERRORLEVEL% +if %EXIT_CODE% equ 0 set EXIT_CODE=1 +if not ""=="%GRADLE_EXIT_CONSOLE%" exit %EXIT_CODE% +exit /b %EXIT_CODE% + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/netty-buffer/build.gradle b/netty-buffer/build.gradle new file mode 100644 index 0000000..e3a09c8 --- /dev/null +++ b/netty-buffer/build.gradle @@ -0,0 +1,5 @@ +dependencies { + api project(':netty-util') + testImplementation testLibs.mockito.core + testImplementation testLibs.assertj +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java new file mode 100644 index 0000000..7bff10a --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/AbstractByteBuf.java @@ -0,0 +1,1459 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.AsciiString; +import io.netty.util.ByteProcessor; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakDetectorFactory; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; + +import static io.netty.util.internal.MathUtil.isOutOfBounds; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * A skeletal implementation of a buffer. + */ +public abstract class AbstractByteBuf extends ByteBuf { + private static final String LEGACY_PROP_CHECK_ACCESSIBLE = "io.netty.buffer.bytebuf.checkAccessible"; + private static final String PROP_CHECK_ACCESSIBLE = "io.netty.buffer.checkAccessible"; + static final boolean checkAccessible; // accessed from CompositeByteBuf + private static final String PROP_CHECK_BOUNDS = "io.netty.buffer.checkBounds"; + private static final boolean checkBounds; + + static { + if (SystemPropertyUtil.contains(PROP_CHECK_ACCESSIBLE)) { + checkAccessible = SystemPropertyUtil.getBoolean(PROP_CHECK_ACCESSIBLE, true); + } else { + checkAccessible = SystemPropertyUtil.getBoolean(LEGACY_PROP_CHECK_ACCESSIBLE, true); + } + checkBounds = SystemPropertyUtil.getBoolean(PROP_CHECK_BOUNDS, true); + } + + static final ResourceLeakDetector leakDetector = + ResourceLeakDetectorFactory.instance().newResourceLeakDetector(ByteBuf.class); + + int readerIndex; + int writerIndex; + private int markedReaderIndex; + private int 
markedWriterIndex; + private int maxCapacity; + + protected AbstractByteBuf(int maxCapacity) { + checkPositiveOrZero(maxCapacity, "maxCapacity"); + this.maxCapacity = maxCapacity; + } + + @Override + public boolean isReadOnly() { + return false; + } + + @SuppressWarnings("deprecation") + @Override + public ByteBuf asReadOnly() { + if (isReadOnly()) { + return this; + } + return Unpooled.unmodifiableBuffer(this); + } + + @Override + public int maxCapacity() { + return maxCapacity; + } + + protected final void maxCapacity(int maxCapacity) { + this.maxCapacity = maxCapacity; + } + + @Override + public int readerIndex() { + return readerIndex; + } + + private static void checkIndexBounds(final int readerIndex, final int writerIndex, final int capacity) { + if (readerIndex < 0 || readerIndex > writerIndex || writerIndex > capacity) { + throw new IndexOutOfBoundsException(String.format( + "readerIndex: %d, writerIndex: %d (expected: 0 <= readerIndex <= writerIndex <= capacity(%d))", + readerIndex, writerIndex, capacity)); + } + } + + @Override + public ByteBuf readerIndex(int readerIndex) { + if (checkBounds) { + checkIndexBounds(readerIndex, writerIndex, capacity()); + } + this.readerIndex = readerIndex; + return this; + } + + @Override + public int writerIndex() { + return writerIndex; + } + + @Override + public ByteBuf writerIndex(int writerIndex) { + if (checkBounds) { + checkIndexBounds(readerIndex, writerIndex, capacity()); + } + this.writerIndex = writerIndex; + return this; + } + + @Override + public ByteBuf setIndex(int readerIndex, int writerIndex) { + if (checkBounds) { + checkIndexBounds(readerIndex, writerIndex, capacity()); + } + setIndex0(readerIndex, writerIndex); + return this; + } + + @Override + public ByteBuf clear() { + readerIndex = writerIndex = 0; + return this; + } + + @Override + public boolean isReadable() { + return writerIndex > readerIndex; + } + + @Override + public boolean isReadable(int numBytes) { + return writerIndex - readerIndex >= 
numBytes; + } + + @Override + public boolean isWritable() { + return capacity() > writerIndex; + } + + @Override + public boolean isWritable(int numBytes) { + return capacity() - writerIndex >= numBytes; + } + + @Override + public int readableBytes() { + return writerIndex - readerIndex; + } + + @Override + public int writableBytes() { + return capacity() - writerIndex; + } + + @Override + public int maxWritableBytes() { + return maxCapacity() - writerIndex; + } + + @Override + public ByteBuf markReaderIndex() { + markedReaderIndex = readerIndex; + return this; + } + + @Override + public ByteBuf resetReaderIndex() { + readerIndex(markedReaderIndex); + return this; + } + + @Override + public ByteBuf markWriterIndex() { + markedWriterIndex = writerIndex; + return this; + } + + @Override + public ByteBuf resetWriterIndex() { + writerIndex(markedWriterIndex); + return this; + } + + @Override + public ByteBuf discardReadBytes() { + if (readerIndex == 0) { + ensureAccessible(); + return this; + } + + if (readerIndex != writerIndex) { + setBytes(0, this, readerIndex, writerIndex - readerIndex); + writerIndex -= readerIndex; + adjustMarkers(readerIndex); + readerIndex = 0; + } else { + ensureAccessible(); + adjustMarkers(readerIndex); + writerIndex = readerIndex = 0; + } + return this; + } + + @Override + public ByteBuf discardSomeReadBytes() { + if (readerIndex > 0) { + if (readerIndex == writerIndex) { + ensureAccessible(); + adjustMarkers(readerIndex); + writerIndex = readerIndex = 0; + return this; + } + + if (readerIndex >= capacity() >>> 1) { + setBytes(0, this, readerIndex, writerIndex - readerIndex); + writerIndex -= readerIndex; + adjustMarkers(readerIndex); + readerIndex = 0; + return this; + } + } + ensureAccessible(); + return this; + } + + protected final void adjustMarkers(int decrement) { + if (markedReaderIndex <= decrement) { + markedReaderIndex = 0; + if (markedWriterIndex <= decrement) { + markedWriterIndex = 0; + } else { + markedWriterIndex -= 
decrement; + } + } else { + markedReaderIndex -= decrement; + markedWriterIndex -= decrement; + } + } + + // Called after a capacity reduction + protected final void trimIndicesToCapacity(int newCapacity) { + if (writerIndex() > newCapacity) { + setIndex0(Math.min(readerIndex(), newCapacity), newCapacity); + } + } + + @Override + public ByteBuf ensureWritable(int minWritableBytes) { + ensureWritable0(checkPositiveOrZero(minWritableBytes, "minWritableBytes")); + return this; + } + + final void ensureWritable0(int minWritableBytes) { + final int writerIndex = writerIndex(); + final int targetCapacity = writerIndex + minWritableBytes; + // using non-short-circuit & to reduce branching - this is a hot path and targetCapacity should rarely overflow + if (targetCapacity >= 0 & targetCapacity <= capacity()) { + ensureAccessible(); + return; + } + if (checkBounds && (targetCapacity < 0 || targetCapacity > maxCapacity)) { + ensureAccessible(); + throw new IndexOutOfBoundsException(String.format( + "writerIndex(%d) + minWritableBytes(%d) exceeds maxCapacity(%d): %s", + writerIndex, minWritableBytes, maxCapacity, this)); + } + + // Normalize the target capacity to the power of 2. + final int fastWritable = maxFastWritableBytes(); + int newCapacity = fastWritable >= minWritableBytes ? writerIndex + fastWritable + : alloc().calculateNewCapacity(targetCapacity, maxCapacity); + + // Adjust to the new capacity. 
+ capacity(newCapacity); + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + ensureAccessible(); + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); + + if (minWritableBytes <= writableBytes()) { + return 0; + } + + final int maxCapacity = maxCapacity(); + final int writerIndex = writerIndex(); + if (minWritableBytes > maxCapacity - writerIndex) { + if (!force || capacity() == maxCapacity) { + return 1; + } + + capacity(maxCapacity); + return 3; + } + + int fastWritable = maxFastWritableBytes(); + int newCapacity = fastWritable >= minWritableBytes ? writerIndex + fastWritable + : alloc().calculateNewCapacity(writerIndex + minWritableBytes, maxCapacity); + + // Adjust to the new capacity. + capacity(newCapacity); + return 2; + } + + @Override + public ByteBuf order(ByteOrder endianness) { + if (endianness == order()) { + return this; + } + ObjectUtil.checkNotNull(endianness, "endianness"); + return newSwappedByteBuf(); + } + + /** + * Creates a new {@link SwappedByteBuf} for this {@link ByteBuf} instance. 
+ */ + protected SwappedByteBuf newSwappedByteBuf() { + return new SwappedByteBuf(this); + } + + @Override + public byte getByte(int index) { + checkIndex(index); + return _getByte(index); + } + + protected abstract byte _getByte(int index); + + @Override + public boolean getBoolean(int index) { + return getByte(index) != 0; + } + + @Override + public short getUnsignedByte(int index) { + return (short) (getByte(index) & 0xFF); + } + + @Override + public short getShort(int index) { + checkIndex(index, 2); + return _getShort(index); + } + + protected abstract short _getShort(int index); + + @Override + public short getShortLE(int index) { + checkIndex(index, 2); + return _getShortLE(index); + } + + protected abstract short _getShortLE(int index); + + @Override + public int getUnsignedShort(int index) { + return getShort(index) & 0xFFFF; + } + + @Override + public int getUnsignedShortLE(int index) { + return getShortLE(index) & 0xFFFF; + } + + @Override + public int getUnsignedMedium(int index) { + checkIndex(index, 3); + return _getUnsignedMedium(index); + } + + protected abstract int _getUnsignedMedium(int index); + + @Override + public int getUnsignedMediumLE(int index) { + checkIndex(index, 3); + return _getUnsignedMediumLE(index); + } + + protected abstract int _getUnsignedMediumLE(int index); + + @Override + public int getMedium(int index) { + int value = getUnsignedMedium(index); + if ((value & 0x800000) != 0) { + value |= 0xff000000; + } + return value; + } + + @Override + public int getMediumLE(int index) { + int value = getUnsignedMediumLE(index); + if ((value & 0x800000) != 0) { + value |= 0xff000000; + } + return value; + } + + @Override + public int getInt(int index) { + checkIndex(index, 4); + return _getInt(index); + } + + protected abstract int _getInt(int index); + + @Override + public int getIntLE(int index) { + checkIndex(index, 4); + return _getIntLE(index); + } + + protected abstract int _getIntLE(int index); + + @Override + public long 
getUnsignedInt(int index) { + return getInt(index) & 0xFFFFFFFFL; + } + + @Override + public long getUnsignedIntLE(int index) { + return getIntLE(index) & 0xFFFFFFFFL; + } + + @Override + public long getLong(int index) { + checkIndex(index, 8); + return _getLong(index); + } + + protected abstract long _getLong(int index); + + @Override + public long getLongLE(int index) { + checkIndex(index, 8); + return _getLongLE(index); + } + + protected abstract long _getLongLE(int index); + + @Override + public char getChar(int index) { + return (char) getShort(index); + } + + @Override + public float getFloat(int index) { + return Float.intBitsToFloat(getInt(index)); + } + + @Override + public double getDouble(int index) { + return Double.longBitsToDouble(getLong(index)); + } + + @Override + public ByteBuf getBytes(int index, byte[] dst) { + getBytes(index, dst, 0, dst.length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst) { + getBytes(index, dst, dst.writableBytes()); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int length) { + getBytes(index, dst, dst.writerIndex(), length); + dst.writerIndex(dst.writerIndex() + length); + return this; + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + if (CharsetUtil.US_ASCII.equals(charset) || CharsetUtil.ISO_8859_1.equals(charset)) { + // ByteBufUtil.getBytes(...) 
will return a new copy which the AsciiString uses directly + return new AsciiString(ByteBufUtil.getBytes(this, index, length, true), false); + } + return toString(index, length, charset); + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + CharSequence sequence = getCharSequence(readerIndex, length, charset); + readerIndex += length; + return sequence; + } + + @Override + public ByteBuf setByte(int index, int value) { + checkIndex(index); + _setByte(index, value); + return this; + } + + protected abstract void _setByte(int index, int value); + + @Override + public ByteBuf setBoolean(int index, boolean value) { + setByte(index, value? 1 : 0); + return this; + } + + @Override + public ByteBuf setShort(int index, int value) { + checkIndex(index, 2); + _setShort(index, value); + return this; + } + + protected abstract void _setShort(int index, int value); + + @Override + public ByteBuf setShortLE(int index, int value) { + checkIndex(index, 2); + _setShortLE(index, value); + return this; + } + + protected abstract void _setShortLE(int index, int value); + + @Override + public ByteBuf setChar(int index, int value) { + setShort(index, value); + return this; + } + + @Override + public ByteBuf setMedium(int index, int value) { + checkIndex(index, 3); + _setMedium(index, value); + return this; + } + + protected abstract void _setMedium(int index, int value); + + @Override + public ByteBuf setMediumLE(int index, int value) { + checkIndex(index, 3); + _setMediumLE(index, value); + return this; + } + + protected abstract void _setMediumLE(int index, int value); + + @Override + public ByteBuf setInt(int index, int value) { + checkIndex(index, 4); + _setInt(index, value); + return this; + } + + protected abstract void _setInt(int index, int value); + + @Override + public ByteBuf setIntLE(int index, int value) { + checkIndex(index, 4); + _setIntLE(index, value); + return this; + } + + protected abstract void _setIntLE(int index, int value); + 
+ @Override + public ByteBuf setFloat(int index, float value) { + setInt(index, Float.floatToRawIntBits(value)); + return this; + } + + @Override + public ByteBuf setLong(int index, long value) { + checkIndex(index, 8); + _setLong(index, value); + return this; + } + + protected abstract void _setLong(int index, long value); + + @Override + public ByteBuf setLongLE(int index, long value) { + checkIndex(index, 8); + _setLongLE(index, value); + return this; + } + + protected abstract void _setLongLE(int index, long value); + + @Override + public ByteBuf setDouble(int index, double value) { + setLong(index, Double.doubleToRawLongBits(value)); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src) { + setBytes(index, src, 0, src.length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src) { + setBytes(index, src, src.readableBytes()); + return this; + } + + private static void checkReadableBounds(final ByteBuf src, final int length) { + if (length > src.readableBytes()) { + throw new IndexOutOfBoundsException(String.format( + "length(%d) exceeds src.readableBytes(%d) where src is: %s", length, src.readableBytes(), src)); + } + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int length) { + checkIndex(index, length); + ObjectUtil.checkNotNull(src, "src"); + if (checkBounds) { + checkReadableBounds(src, length); + } + + setBytes(index, src, src.readerIndex(), length); + src.readerIndex(src.readerIndex() + length); + return this; + } + + @Override + public ByteBuf setZero(int index, int length) { + if (length == 0) { + return this; + } + + checkIndex(index, length); + + int nLong = length >>> 3; + int nBytes = length & 7; + for (int i = nLong; i > 0; i --) { + _setLong(index, 0); + index += 8; + } + if (nBytes == 4) { + _setInt(index, 0); + // Not need to update the index as we not will use it after this. 
+ } else if (nBytes < 4) { + for (int i = nBytes; i > 0; i --) { + _setByte(index, 0); + index ++; + } + } else { + _setInt(index, 0); + index += 4; + for (int i = nBytes - 4; i > 0; i --) { + _setByte(index, 0); + index ++; + } + } + return this; + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + return setCharSequence0(index, sequence, charset, false); + } + + private int setCharSequence0(int index, CharSequence sequence, Charset charset, boolean expand) { + if (charset.equals(CharsetUtil.UTF_8)) { + int length = ByteBufUtil.utf8MaxBytes(sequence); + if (expand) { + ensureWritable0(length); + checkIndex0(index, length); + } else { + checkIndex(index, length); + } + return ByteBufUtil.writeUtf8(this, index, length, sequence, sequence.length()); + } + if (charset.equals(CharsetUtil.US_ASCII) || charset.equals(CharsetUtil.ISO_8859_1)) { + int length = sequence.length(); + if (expand) { + ensureWritable0(length); + checkIndex0(index, length); + } else { + checkIndex(index, length); + } + return ByteBufUtil.writeAscii(this, index, sequence, length); + } + byte[] bytes = sequence.toString().getBytes(charset); + if (expand) { + ensureWritable0(bytes.length); + // setBytes(...) will take care of checking the indices. 
+ } + setBytes(index, bytes); + return bytes.length; + } + + @Override + public byte readByte() { + checkReadableBytes0(1); + int i = readerIndex; + byte b = _getByte(i); + readerIndex = i + 1; + return b; + } + + @Override + public boolean readBoolean() { + return readByte() != 0; + } + + @Override + public short readUnsignedByte() { + return (short) (readByte() & 0xFF); + } + + @Override + public short readShort() { + checkReadableBytes0(2); + short v = _getShort(readerIndex); + readerIndex += 2; + return v; + } + + @Override + public short readShortLE() { + checkReadableBytes0(2); + short v = _getShortLE(readerIndex); + readerIndex += 2; + return v; + } + + @Override + public int readUnsignedShort() { + return readShort() & 0xFFFF; + } + + @Override + public int readUnsignedShortLE() { + return readShortLE() & 0xFFFF; + } + + @Override + public int readMedium() { + int value = readUnsignedMedium(); + if ((value & 0x800000) != 0) { + value |= 0xff000000; + } + return value; + } + + @Override + public int readMediumLE() { + int value = readUnsignedMediumLE(); + if ((value & 0x800000) != 0) { + value |= 0xff000000; + } + return value; + } + + @Override + public int readUnsignedMedium() { + checkReadableBytes0(3); + int v = _getUnsignedMedium(readerIndex); + readerIndex += 3; + return v; + } + + @Override + public int readUnsignedMediumLE() { + checkReadableBytes0(3); + int v = _getUnsignedMediumLE(readerIndex); + readerIndex += 3; + return v; + } + + @Override + public int readInt() { + checkReadableBytes0(4); + int v = _getInt(readerIndex); + readerIndex += 4; + return v; + } + + @Override + public int readIntLE() { + checkReadableBytes0(4); + int v = _getIntLE(readerIndex); + readerIndex += 4; + return v; + } + + @Override + public long readUnsignedInt() { + return readInt() & 0xFFFFFFFFL; + } + + @Override + public long readUnsignedIntLE() { + return readIntLE() & 0xFFFFFFFFL; + } + + @Override + public long readLong() { + checkReadableBytes0(8); + long v = 
_getLong(readerIndex); + readerIndex += 8; + return v; + } + + @Override + public long readLongLE() { + checkReadableBytes0(8); + long v = _getLongLE(readerIndex); + readerIndex += 8; + return v; + } + + @Override + public char readChar() { + return (char) readShort(); + } + + @Override + public float readFloat() { + return Float.intBitsToFloat(readInt()); + } + + @Override + public double readDouble() { + return Double.longBitsToDouble(readLong()); + } + + @Override + public ByteBuf readBytes(int length) { + checkReadableBytes(length); + if (length == 0) { + return Unpooled.EMPTY_BUFFER; + } + + ByteBuf buf = alloc().buffer(length, maxCapacity); + buf.writeBytes(this, readerIndex, length); + readerIndex += length; + return buf; + } + + @Override + public ByteBuf readSlice(int length) { + checkReadableBytes(length); + ByteBuf slice = slice(readerIndex, length); + readerIndex += length; + return slice; + } + + @Override + public ByteBuf readRetainedSlice(int length) { + checkReadableBytes(length); + ByteBuf slice = retainedSlice(readerIndex, length); + readerIndex += length; + return slice; + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + checkReadableBytes(length); + getBytes(readerIndex, dst, dstIndex, length); + readerIndex += length; + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst) { + readBytes(dst, 0, dst.length); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst) { + readBytes(dst, dst.writableBytes()); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int length) { + if (checkBounds) { + if (length > dst.writableBytes()) { + throw new IndexOutOfBoundsException(String.format( + "length(%d) exceeds dst.writableBytes(%d) where dst is: %s", length, dst.writableBytes(), dst)); + } + } + readBytes(dst, dst.writerIndex(), length); + dst.writerIndex(dst.writerIndex() + length); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int 
dstIndex, int length) { + checkReadableBytes(length); + getBytes(readerIndex, dst, dstIndex, length); + readerIndex += length; + return this; + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + int length = dst.remaining(); + checkReadableBytes(length); + getBytes(readerIndex, dst); + readerIndex += length; + return this; + } + + @Override + public int readBytes(GatheringByteChannel out, int length) + throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length); + readerIndex += readBytes; + return readBytes; + } + + @Override + public int readBytes(FileChannel out, long position, int length) + throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, position, length); + readerIndex += readBytes; + return readBytes; + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) throws IOException { + checkReadableBytes(length); + getBytes(readerIndex, out, length); + readerIndex += length; + return this; + } + + @Override + public ByteBuf skipBytes(int length) { + checkReadableBytes(length); + readerIndex += length; + return this; + } + + @Override + public ByteBuf writeBoolean(boolean value) { + writeByte(value ? 
1 : 0); + return this; + } + + @Override + public ByteBuf writeByte(int value) { + ensureWritable0(1); + _setByte(writerIndex++, value); + return this; + } + + @Override + public ByteBuf writeShort(int value) { + ensureWritable0(2); + _setShort(writerIndex, value); + writerIndex += 2; + return this; + } + + @Override + public ByteBuf writeShortLE(int value) { + ensureWritable0(2); + _setShortLE(writerIndex, value); + writerIndex += 2; + return this; + } + + @Override + public ByteBuf writeMedium(int value) { + ensureWritable0(3); + _setMedium(writerIndex, value); + writerIndex += 3; + return this; + } + + @Override + public ByteBuf writeMediumLE(int value) { + ensureWritable0(3); + _setMediumLE(writerIndex, value); + writerIndex += 3; + return this; + } + + @Override + public ByteBuf writeInt(int value) { + ensureWritable0(4); + _setInt(writerIndex, value); + writerIndex += 4; + return this; + } + + @Override + public ByteBuf writeIntLE(int value) { + ensureWritable0(4); + _setIntLE(writerIndex, value); + writerIndex += 4; + return this; + } + + @Override + public ByteBuf writeLong(long value) { + ensureWritable0(8); + _setLong(writerIndex, value); + writerIndex += 8; + return this; + } + + @Override + public ByteBuf writeLongLE(long value) { + ensureWritable0(8); + _setLongLE(writerIndex, value); + writerIndex += 8; + return this; + } + + @Override + public ByteBuf writeChar(int value) { + writeShort(value); + return this; + } + + @Override + public ByteBuf writeFloat(float value) { + writeInt(Float.floatToRawIntBits(value)); + return this; + } + + @Override + public ByteBuf writeDouble(double value) { + writeLong(Double.doubleToRawLongBits(value)); + return this; + } + + @Override + public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { + ensureWritable(length); + setBytes(writerIndex, src, srcIndex, length); + writerIndex += length; + return this; + } + + @Override + public ByteBuf writeBytes(byte[] src) { + writeBytes(src, 0, src.length); + return 
this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src) { + writeBytes(src, src.readableBytes()); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int length) { + if (checkBounds) { + checkReadableBounds(src, length); + } + writeBytes(src, src.readerIndex(), length); + src.readerIndex(src.readerIndex() + length); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + ensureWritable(length); + setBytes(writerIndex, src, srcIndex, length); + writerIndex += length; + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuffer src) { + int length = src.remaining(); + ensureWritable0(length); + setBytes(writerIndex, src); + writerIndex += length; + return this; + } + + @Override + public int writeBytes(InputStream in, int length) + throws IOException { + ensureWritable(length); + int writtenBytes = setBytes(writerIndex, in, length); + if (writtenBytes > 0) { + writerIndex += writtenBytes; + } + return writtenBytes; + } + + @Override + public int writeBytes(ScatteringByteChannel in, int length) throws IOException { + ensureWritable(length); + int writtenBytes = setBytes(writerIndex, in, length); + if (writtenBytes > 0) { + writerIndex += writtenBytes; + } + return writtenBytes; + } + + @Override + public int writeBytes(FileChannel in, long position, int length) throws IOException { + ensureWritable(length); + int writtenBytes = setBytes(writerIndex, in, position, length); + if (writtenBytes > 0) { + writerIndex += writtenBytes; + } + return writtenBytes; + } + + @Override + public ByteBuf writeZero(int length) { + if (length == 0) { + return this; + } + + ensureWritable(length); + int wIndex = writerIndex; + checkIndex0(wIndex, length); + + int nLong = length >>> 3; + int nBytes = length & 7; + for (int i = nLong; i > 0; i --) { + _setLong(wIndex, 0); + wIndex += 8; + } + if (nBytes == 4) { + _setInt(wIndex, 0); + wIndex += 4; + } else if (nBytes < 4) { + for (int i = nBytes; i > 
0; i --) { + _setByte(wIndex, 0); + wIndex++; + } + } else { + _setInt(wIndex, 0); + wIndex += 4; + for (int i = nBytes - 4; i > 0; i --) { + _setByte(wIndex, 0); + wIndex++; + } + } + writerIndex = wIndex; + return this; + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + int written = setCharSequence0(writerIndex, sequence, charset, true); + writerIndex += written; + return written; + } + + @Override + public ByteBuf copy() { + return copy(readerIndex, readableBytes()); + } + + @Override + public ByteBuf duplicate() { + ensureAccessible(); + return new UnpooledDuplicatedByteBuf(this); + } + + @Override + public ByteBuf retainedDuplicate() { + return duplicate().retain(); + } + + @Override + public ByteBuf slice() { + return slice(readerIndex, readableBytes()); + } + + @Override + public ByteBuf retainedSlice() { + return slice().retain(); + } + + @Override + public ByteBuf slice(int index, int length) { + ensureAccessible(); + return new UnpooledSlicedByteBuf(this, index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + return slice(index, length).retain(); + } + + @Override + public ByteBuffer nioBuffer() { + return nioBuffer(readerIndex, readableBytes()); + } + + @Override + public ByteBuffer[] nioBuffers() { + return nioBuffers(readerIndex, readableBytes()); + } + + @Override + public String toString(Charset charset) { + return toString(readerIndex, readableBytes(), charset); + } + + @Override + public String toString(int index, int length, Charset charset) { + return ByteBufUtil.decodeString(this, index, length, charset); + } + + @Override + public int indexOf(int fromIndex, int toIndex, byte value) { + if (fromIndex <= toIndex) { + return ByteBufUtil.firstIndexOf(this, fromIndex, toIndex, value); + } + return ByteBufUtil.lastIndexOf(this, fromIndex, toIndex, value); + } + + @Override + public int bytesBefore(byte value) { + return bytesBefore(readerIndex(), readableBytes(), value); 
+ } + + @Override + public int bytesBefore(int length, byte value) { + checkReadableBytes(length); + return bytesBefore(readerIndex(), length, value); + } + + @Override + public int bytesBefore(int index, int length, byte value) { + int endIndex = indexOf(index, index + length, value); + if (endIndex < 0) { + return -1; + } + return endIndex - index; + } + + @Override + public int forEachByte(ByteProcessor processor) { + ensureAccessible(); + try { + return forEachByteAsc0(readerIndex, writerIndex, processor); + } catch (Exception e) { + PlatformDependent.throwException(e); + return -1; + } + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + checkIndex(index, length); + try { + return forEachByteAsc0(index, index + length, processor); + } catch (Exception e) { + PlatformDependent.throwException(e); + return -1; + } + } + + int forEachByteAsc0(int start, int end, ByteProcessor processor) throws Exception { + for (; start < end; ++start) { + if (!processor.process(_getByte(start))) { + return start; + } + } + + return -1; + } + + @Override + public int forEachByteDesc(ByteProcessor processor) { + ensureAccessible(); + try { + return forEachByteDesc0(writerIndex - 1, readerIndex, processor); + } catch (Exception e) { + PlatformDependent.throwException(e); + return -1; + } + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + checkIndex(index, length); + try { + return forEachByteDesc0(index + length - 1, index, processor); + } catch (Exception e) { + PlatformDependent.throwException(e); + return -1; + } + } + + int forEachByteDesc0(int rStart, final int rEnd, ByteProcessor processor) throws Exception { + for (; rStart >= rEnd; --rStart) { + if (!processor.process(_getByte(rStart))) { + return rStart; + } + } + return -1; + } + + @Override + public int hashCode() { + return ByteBufUtil.hashCode(this); + } + + @Override + public boolean equals(Object o) { + return o instanceof 
ByteBuf && ByteBufUtil.equals(this, (ByteBuf) o); + } + + @Override + public int compareTo(ByteBuf that) { + return ByteBufUtil.compare(this, that); + } + + @Override + public String toString() { + if (refCnt() == 0) { + return StringUtil.simpleClassName(this) + "(freed)"; + } + + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append("(ridx: ").append(readerIndex) + .append(", widx: ").append(writerIndex) + .append(", cap: ").append(capacity()); + if (maxCapacity != Integer.MAX_VALUE) { + buf.append('/').append(maxCapacity); + } + + ByteBuf unwrapped = unwrap(); + if (unwrapped != null) { + buf.append(", unwrapped: ").append(unwrapped); + } + buf.append(')'); + return buf.toString(); + } + + protected final void checkIndex(int index) { + checkIndex(index, 1); + } + + protected final void checkIndex(int index, int fieldLength) { + ensureAccessible(); + checkIndex0(index, fieldLength); + } + + private static void checkRangeBounds(final String indexName, final int index, + final int fieldLength, final int capacity) { + if (isOutOfBounds(index, fieldLength, capacity)) { + throw new IndexOutOfBoundsException(String.format( + "%s: %d, length: %d (expected: range(0, %d))", indexName, index, fieldLength, capacity)); + } + } + + final void checkIndex0(int index, int fieldLength) { + if (checkBounds) { + checkRangeBounds("index", index, fieldLength, capacity()); + } + } + + protected final void checkSrcIndex(int index, int length, int srcIndex, int srcCapacity) { + checkIndex(index, length); + if (checkBounds) { + checkRangeBounds("srcIndex", srcIndex, length, srcCapacity); + } + } + + protected final void checkDstIndex(int index, int length, int dstIndex, int dstCapacity) { + checkIndex(index, length); + if (checkBounds) { + checkRangeBounds("dstIndex", dstIndex, length, dstCapacity); + } + } + + protected final void checkDstIndex(int length, int dstIndex, int dstCapacity) { + checkReadableBytes(length); + if (checkBounds) { + 
checkRangeBounds("dstIndex", dstIndex, length, dstCapacity); + } + } + + /** + * Throws an {@link IndexOutOfBoundsException} if the current + * {@linkplain #readableBytes() readable bytes} of this buffer is less + * than the specified value. + */ + protected final void checkReadableBytes(int minimumReadableBytes) { + checkReadableBytes0(checkPositiveOrZero(minimumReadableBytes, "minimumReadableBytes")); + } + + protected final void checkNewCapacity(int newCapacity) { + ensureAccessible(); + if (checkBounds && (newCapacity < 0 || newCapacity > maxCapacity())) { + throw new IllegalArgumentException("newCapacity: " + newCapacity + + " (expected: 0-" + maxCapacity() + ')'); + } + } + + private void checkReadableBytes0(int minimumReadableBytes) { + ensureAccessible(); + if (checkBounds && readerIndex > writerIndex - minimumReadableBytes) { + throw new IndexOutOfBoundsException(String.format( + "readerIndex(%d) + length(%d) exceeds writerIndex(%d): %s", + readerIndex, minimumReadableBytes, writerIndex, this)); + } + } + + /** + * Should be called by every method that tries to access the buffers content to check + * if the buffer was released before. 
+ */ + protected final void ensureAccessible() { + if (checkAccessible && !isAccessible()) { + throw new IllegalReferenceCountException(0); + } + } + + final void setIndex0(int readerIndex, int writerIndex) { + this.readerIndex = readerIndex; + this.writerIndex = writerIndex; + } + + final void discardMarks() { + markedReaderIndex = markedWriterIndex = 0; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java b/netty-buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java new file mode 100644 index 0000000..3aa05ae --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/AbstractByteBufAllocator.java @@ -0,0 +1,280 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakTracker; +import io.netty.util.internal.MathUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +/** + * Skeletal {@link ByteBufAllocator} implementation to extend. 
+ */ +public abstract class AbstractByteBufAllocator implements ByteBufAllocator { + static final int DEFAULT_INITIAL_CAPACITY = 256; + static final int DEFAULT_MAX_CAPACITY = Integer.MAX_VALUE; + static final int DEFAULT_MAX_COMPONENTS = 16; + static final int CALCULATE_THRESHOLD = 1048576 * 4; // 4 MiB page + + static { + ResourceLeakDetector.addExclusions(AbstractByteBufAllocator.class, "toLeakAwareBuffer"); + } + + protected static ByteBuf toLeakAwareBuffer(ByteBuf buf) { + ResourceLeakTracker leak; + switch (ResourceLeakDetector.getLevel()) { + case SIMPLE: + leak = AbstractByteBuf.leakDetector.track(buf); + if (leak != null) { + buf = new SimpleLeakAwareByteBuf(buf, leak); + } + break; + case ADVANCED: + case PARANOID: + leak = AbstractByteBuf.leakDetector.track(buf); + if (leak != null) { + buf = new AdvancedLeakAwareByteBuf(buf, leak); + } + break; + default: + break; + } + return buf; + } + + protected static CompositeByteBuf toLeakAwareBuffer(CompositeByteBuf buf) { + ResourceLeakTracker leak; + switch (ResourceLeakDetector.getLevel()) { + case SIMPLE: + leak = AbstractByteBuf.leakDetector.track(buf); + if (leak != null) { + buf = new SimpleLeakAwareCompositeByteBuf(buf, leak); + } + break; + case ADVANCED: + case PARANOID: + leak = AbstractByteBuf.leakDetector.track(buf); + if (leak != null) { + buf = new AdvancedLeakAwareCompositeByteBuf(buf, leak); + } + break; + default: + break; + } + return buf; + } + + private final boolean directByDefault; + private final ByteBuf emptyBuf; + + /** + * Instance use heap buffers by default + */ + protected AbstractByteBufAllocator() { + this(false); + } + + /** + * Create new instance + * + * @param preferDirect {@code true} if {@link #buffer(int)} should try to allocate a direct buffer rather than + * a heap buffer + */ + protected AbstractByteBufAllocator(boolean preferDirect) { + directByDefault = preferDirect && PlatformDependent.hasUnsafe(); + emptyBuf = new EmptyByteBuf(this); + } + + @Override + public 
ByteBuf buffer() { + if (directByDefault) { + return directBuffer(); + } + return heapBuffer(); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + if (directByDefault) { + return directBuffer(initialCapacity); + } + return heapBuffer(initialCapacity); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + if (directByDefault) { + return directBuffer(initialCapacity, maxCapacity); + } + return heapBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf ioBuffer() { + if (PlatformDependent.hasUnsafe() || isDirectBufferPooled()) { + return directBuffer(DEFAULT_INITIAL_CAPACITY); + } + return heapBuffer(DEFAULT_INITIAL_CAPACITY); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + if (PlatformDependent.hasUnsafe() || isDirectBufferPooled()) { + return directBuffer(initialCapacity); + } + return heapBuffer(initialCapacity); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + if (PlatformDependent.hasUnsafe() || isDirectBufferPooled()) { + return directBuffer(initialCapacity, maxCapacity); + } + return heapBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf heapBuffer() { + return heapBuffer(DEFAULT_INITIAL_CAPACITY, DEFAULT_MAX_CAPACITY); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return heapBuffer(initialCapacity, DEFAULT_MAX_CAPACITY); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + if (initialCapacity == 0 && maxCapacity == 0) { + return emptyBuf; + } + validate(initialCapacity, maxCapacity); + return newHeapBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf directBuffer() { + return directBuffer(DEFAULT_INITIAL_CAPACITY, DEFAULT_MAX_CAPACITY); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return directBuffer(initialCapacity, DEFAULT_MAX_CAPACITY); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int 
maxCapacity) { + if (initialCapacity == 0 && maxCapacity == 0) { + return emptyBuf; + } + validate(initialCapacity, maxCapacity); + return newDirectBuffer(initialCapacity, maxCapacity); + } + + @Override + public CompositeByteBuf compositeBuffer() { + if (directByDefault) { + return compositeDirectBuffer(); + } + return compositeHeapBuffer(); + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + if (directByDefault) { + return compositeDirectBuffer(maxNumComponents); + } + return compositeHeapBuffer(maxNumComponents); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return compositeHeapBuffer(DEFAULT_MAX_COMPONENTS); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return toLeakAwareBuffer(new CompositeByteBuf(this, false, maxNumComponents)); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return compositeDirectBuffer(DEFAULT_MAX_COMPONENTS); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return toLeakAwareBuffer(new CompositeByteBuf(this, true, maxNumComponents)); + } + + private static void validate(int initialCapacity, int maxCapacity) { + checkPositiveOrZero(initialCapacity, "initialCapacity"); + if (initialCapacity > maxCapacity) { + throw new IllegalArgumentException(String.format( + "initialCapacity: %d (expected: not greater than maxCapacity(%d)", + initialCapacity, maxCapacity)); + } + } + + /** + * Create a heap {@link ByteBuf} with the given initialCapacity and maxCapacity. + */ + protected abstract ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity); + + /** + * Create a direct {@link ByteBuf} with the given initialCapacity and maxCapacity. 
+ */ + protected abstract ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity); + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(directByDefault: " + directByDefault + ')'; + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + checkPositiveOrZero(minNewCapacity, "minNewCapacity"); + if (minNewCapacity > maxCapacity) { + throw new IllegalArgumentException(String.format( + "minNewCapacity: %d (expected: not greater than maxCapacity(%d)", + minNewCapacity, maxCapacity)); + } + final int threshold = CALCULATE_THRESHOLD; // 4 MiB page + + if (minNewCapacity == threshold) { + return threshold; + } + + // If over threshold, do not double but just increase by threshold. + if (minNewCapacity > threshold) { + int newCapacity = minNewCapacity / threshold * threshold; + if (newCapacity > maxCapacity - threshold) { + newCapacity = maxCapacity; + } else { + newCapacity += threshold; + } + return newCapacity; + } + + // 64 <= newCapacity is a power of 2 <= threshold + final int newCapacity = MathUtil.findNextPositivePowerOfTwo(Math.max(minNewCapacity, 64)); + return Math.min(newCapacity, maxCapacity); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java new file mode 100644 index 0000000..c3765c8 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/AbstractDerivedByteBuf.java @@ -0,0 +1,129 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * Abstract base class for {@link ByteBuf} implementations that wrap another
 * {@link ByteBuf}.
 * <p>
 * Every reference-counting operation ({@link #refCnt()}, {@link #retain()},
 * {@link #touch()}, {@link #release()}) is forwarded to the buffer returned by
 * {@link #unwrap()}. The forwarding goes through package-private {@code *0}
 * hooks so that sibling classes in this package can redirect the delegation
 * target without changing the public, {@code final} entry points.
 *
 * @deprecated Do not use.
 */
@Deprecated
public abstract class AbstractDerivedByteBuf extends AbstractByteBuf {

    protected AbstractDerivedByteBuf(int maxCapacity) {
        super(maxCapacity);
    }

    @Override
    final boolean isAccessible() {
        return isAccessible0();
    }

    // Overridable delegation hook; defaults to the wrapped buffer's accessibility.
    boolean isAccessible0() {
        return unwrap().isAccessible();
    }

    @Override
    public final int refCnt() {
        return refCnt0();
    }

    int refCnt0() {
        return unwrap().refCnt();
    }

    @Override
    public final ByteBuf retain() {
        return retain0();
    }

    ByteBuf retain0() {
        unwrap().retain();
        return this;
    }

    @Override
    public final ByteBuf retain(int increment) {
        return retain0(increment);
    }

    ByteBuf retain0(int increment) {
        unwrap().retain(increment);
        return this;
    }

    @Override
    public final ByteBuf touch() {
        return touch0();
    }

    ByteBuf touch0() {
        unwrap().touch();
        return this;
    }

    @Override
    public final ByteBuf touch(Object hint) {
        return touch0(hint);
    }

    ByteBuf touch0(Object hint) {
        unwrap().touch(hint);
        return this;
    }

    @Override
    public final boolean release() {
        return release0();
    }

    boolean release0() {
        return unwrap().release();
    }

    @Override
    public final boolean release(int decrement) {
        return release0(decrement);
    }

    boolean release0(int decrement) {
        return unwrap().release(decrement);
    }

    @Override
    public boolean isReadOnly() {
        return unwrap().isReadOnly();
    }

    @Override
    public ByteBuffer internalNioBuffer(int index, int length) {
        // A derived buffer has no internal NIO buffer of its own; expose a regular one.
        return nioBuffer(index, length);
    }

    @Override
    public ByteBuffer nioBuffer(int index, int length) {
        return unwrap().nioBuffer(index, length);
    }

    @Override
    public boolean isContiguous() {
        return unwrap().isContiguous();
    }
}
    /**
     * Deallocations of a pooled derived buffer should always propagate through the entire chain
     * of derived buffers. This is because each pooled derived buffer maintains its own reference
     * count and we should respect each one. If deallocations cause a release of the "root parent"
     * then we may prematurely release the underlying content before all the derived buffers have
     * been released.
     */
    private ByteBuf parent;

    @SuppressWarnings("unchecked")
    AbstractPooledDerivedByteBuf(Handle recyclerHandle) {
        // maxCapacity is set later via init(...); 0 is just a placeholder.
        super(0);
        // NOTE(review): the generic type arguments on Handle/EnhancedHandle appear to have been
        // stripped in this view of the file — confirm against the original source.
        this.recyclerHandle = (EnhancedHandle) recyclerHandle;
    }

    // Called from within SimpleLeakAwareByteBuf and AdvancedLeakAwareByteBuf.
    final void parent(ByteBuf newParent) {
        assert newParent instanceof SimpleLeakAwareByteBuf;
        parent = newParent;
    }

    @Override
    public final AbstractByteBuf unwrap() {
        return rootParent;
    }

    /**
     * (Re-)initializes this possibly recycled instance to view {@code unwrapped}.
     * Ownership of the up-front retain on {@code wrapped} is transferred to this
     * buffer only if every initialization step succeeds; otherwise the retain is
     * undone in the {@code finally} block.
     * NOTE(review): the {@code <U extends ...>} type-parameter declaration appears
     * to have been stripped in this view of the file — confirm against the original.
     */
    final U init(
            AbstractByteBuf unwrapped, ByteBuf wrapped, int readerIndex, int writerIndex, int maxCapacity) {
        wrapped.retain(); // Retain up front to ensure the parent is accessible before doing more work.
        parent = wrapped;
        rootParent = unwrapped;

        try {
            maxCapacity(maxCapacity);
            setIndex0(readerIndex, writerIndex); // It is assumed the bounds checking is done by the caller.
            resetRefCnt();

            @SuppressWarnings("unchecked")
            final U castThis = (U) this;
            // Null out the local so the finally block knows the transfer succeeded.
            wrapped = null;
            return castThis;
        } finally {
            if (wrapped != null) {
                // Initialization failed: roll back both references and the retain.
                parent = rootParent = null;
                wrapped.release();
            }
        }
    }

    @Override
    protected final void deallocate() {
        // We need to first store a reference to the parent before recycle this instance. This is needed as
        // otherwise it is possible that the same AbstractPooledDerivedByteBuf is again obtained and init(...) is
        // called before we actually have a chance to call release(). This leads to call release() on the wrong parent.
        ByteBuf parent = this.parent;
        recyclerHandle.unguardedRecycle(this);
        parent.release();
    }

    @Override
    public final ByteBufAllocator alloc() {
        return unwrap().alloc();
    }

    @Override
    @Deprecated
    public final ByteOrder order() {
        return unwrap().order();
    }

    @Override
    public boolean isReadOnly() {
        return unwrap().isReadOnly();
    }

    @Override
    public final boolean isDirect() {
        return unwrap().isDirect();
    }

    @Override
    public boolean hasArray() {
        return unwrap().hasArray();
    }

    @Override
    public byte[] array() {
        return unwrap().array();
    }

    @Override
    public boolean hasMemoryAddress() {
        return unwrap().hasMemoryAddress();
    }

    @Override
    public boolean isContiguous() {
        return unwrap().isContiguous();
    }

    @Override
    public final int nioBufferCount() {
        return unwrap().nioBufferCount();
    }

    @Override
    public final ByteBuffer internalNioBuffer(int index, int length) {
        return nioBuffer(index, length);
    }

    @Override
    public final ByteBuf retainedSlice() {
        final int index = readerIndex();
        return retainedSlice(index, writerIndex() - index);
    }

    @Override
    public ByteBuf slice(int index, int length) {
        ensureAccessible();
        // All reference count methods should be inherited from this object (this is the "parent").
        return new PooledNonRetainedSlicedByteBuf(this, unwrap(), index, length);
    }

    final ByteBuf duplicate0() {
        ensureAccessible();
        // All reference count methods should be inherited from this object (this is the "parent").
        return new PooledNonRetainedDuplicateByteBuf(this, unwrap());
    }
    /**
     * A duplicate view whose reference-count operations are delegated to a separate
     * {@link ByteBuf} (the pooled derived "parent") instead of the buffer it wraps,
     * so all views share one reference count.
     */
    private static final class PooledNonRetainedDuplicateByteBuf extends UnpooledDuplicatedByteBuf {
        private final ByteBuf referenceCountDelegate;

        PooledNonRetainedDuplicateByteBuf(ByteBuf referenceCountDelegate, AbstractByteBuf buffer) {
            super(buffer);
            this.referenceCountDelegate = referenceCountDelegate;
        }

        @Override
        boolean isAccessible0() {
            return referenceCountDelegate.isAccessible();
        }

        @Override
        int refCnt0() {
            return referenceCountDelegate.refCnt();
        }

        @Override
        ByteBuf retain0() {
            referenceCountDelegate.retain();
            return this;
        }

        @Override
        ByteBuf retain0(int increment) {
            referenceCountDelegate.retain(increment);
            return this;
        }

        @Override
        ByteBuf touch0() {
            referenceCountDelegate.touch();
            return this;
        }

        @Override
        ByteBuf touch0(Object hint) {
            referenceCountDelegate.touch(hint);
            return this;
        }

        @Override
        boolean release0() {
            return referenceCountDelegate.release();
        }

        @Override
        boolean release0(int decrement) {
            return referenceCountDelegate.release(decrement);
        }

        @Override
        public ByteBuf duplicate() {
            ensureAccessible();
            // Non-retained: keeps delegating ref-counting to the same delegate.
            return new PooledNonRetainedDuplicateByteBuf(referenceCountDelegate, this);
        }

        @Override
        public ByteBuf retainedDuplicate() {
            return PooledDuplicatedByteBuf.newInstance(unwrap(), this, readerIndex(), writerIndex());
        }

        @Override
        public ByteBuf slice(int index, int length) {
            checkIndex(index, length);
            return new PooledNonRetainedSlicedByteBuf(referenceCountDelegate, unwrap(), index, length);
        }

        @Override
        public ByteBuf retainedSlice() {
            // Capacity is not allowed to change for a sliced ByteBuf, so length == capacity()
            return retainedSlice(readerIndex(), capacity());
        }

        @Override
        public ByteBuf retainedSlice(int index, int length) {
            return PooledSlicedByteBuf.newInstance(unwrap(), this, index, length);
        }
    }
    /**
     * A sliced view whose reference-count operations are delegated to a separate
     * {@link ByteBuf} (the pooled derived "parent") instead of the buffer it wraps,
     * so all views share one reference count. Indices are translated via {@code idx(...)}.
     */
    private static final class PooledNonRetainedSlicedByteBuf extends UnpooledSlicedByteBuf {
        private final ByteBuf referenceCountDelegate;

        PooledNonRetainedSlicedByteBuf(ByteBuf referenceCountDelegate,
                                       AbstractByteBuf buffer, int index, int length) {
            super(buffer, index, length);
            this.referenceCountDelegate = referenceCountDelegate;
        }

        @Override
        boolean isAccessible0() {
            return referenceCountDelegate.isAccessible();
        }

        @Override
        int refCnt0() {
            return referenceCountDelegate.refCnt();
        }

        @Override
        ByteBuf retain0() {
            referenceCountDelegate.retain();
            return this;
        }

        @Override
        ByteBuf retain0(int increment) {
            referenceCountDelegate.retain(increment);
            return this;
        }

        @Override
        ByteBuf touch0() {
            referenceCountDelegate.touch();
            return this;
        }

        @Override
        ByteBuf touch0(Object hint) {
            referenceCountDelegate.touch(hint);
            return this;
        }

        @Override
        boolean release0() {
            return referenceCountDelegate.release();
        }

        @Override
        boolean release0(int decrement) {
            return referenceCountDelegate.release(decrement);
        }

        @Override
        public ByteBuf duplicate() {
            ensureAccessible();
            // Translate the slice-relative indices back to the unwrapped buffer's space.
            return new PooledNonRetainedDuplicateByteBuf(referenceCountDelegate, unwrap())
                    .setIndex(idx(readerIndex()), idx(writerIndex()));
        }

        @Override
        public ByteBuf retainedDuplicate() {
            return PooledDuplicatedByteBuf.newInstance(unwrap(), this, idx(readerIndex()), idx(writerIndex()));
        }

        @Override
        public ByteBuf slice(int index, int length) {
            checkIndex(index, length);
            return new PooledNonRetainedSlicedByteBuf(referenceCountDelegate, unwrap(), idx(index), length);
        }

        @Override
        public ByteBuf retainedSlice() {
            // Capacity is not allowed to change for a sliced ByteBuf, so length == capacity()
            return retainedSlice(0, capacity());
        }

        @Override
        public ByteBuf retainedSlice(int index, int length) {
            return PooledSlicedByteBuf.newInstance(unwrap(), this, idx(index), length);
        }
    }
}
a/netty-buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java new file mode 100644 index 0000000..bb15579 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/AbstractReferenceCountedByteBuf.java @@ -0,0 +1,120 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +import io.netty.util.internal.ReferenceCountUpdater; + +/** + * Abstract base class for {@link ByteBuf} implementations that count references. 
+ */ +public abstract class AbstractReferenceCountedByteBuf extends AbstractByteBuf { + private static final long REFCNT_FIELD_OFFSET = + ReferenceCountUpdater.getUnsafeOffset(AbstractReferenceCountedByteBuf.class, "refCnt"); + private static final AtomicIntegerFieldUpdater AIF_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(AbstractReferenceCountedByteBuf.class, "refCnt"); + + private static final ReferenceCountUpdater updater = + new ReferenceCountUpdater() { + @Override + protected AtomicIntegerFieldUpdater updater() { + return AIF_UPDATER; + } + @Override + protected long unsafeOffset() { + return REFCNT_FIELD_OFFSET; + } + }; + + // Value might not equal "real" reference count, all access should be via the updater + @SuppressWarnings({"unused", "FieldMayBeFinal"}) + private volatile int refCnt; + + protected AbstractReferenceCountedByteBuf(int maxCapacity) { + super(maxCapacity); + updater.setInitialValue(this); + } + + @Override + boolean isAccessible() { + // Try to do non-volatile read for performance as the ensureAccessible() is racy anyway and only provide + // a best-effort guard. 
+ return updater.isLiveNonVolatile(this); + } + + @Override + public int refCnt() { + return updater.refCnt(this); + } + + /** + * An unsafe operation intended for use by a subclass that sets the reference count of the buffer directly + */ + protected final void setRefCnt(int refCnt) { + updater.setRefCnt(this, refCnt); + } + + /** + * An unsafe operation intended for use by a subclass that resets the reference count of the buffer to 1 + */ + protected final void resetRefCnt() { + updater.resetRefCnt(this); + } + + @Override + public ByteBuf retain() { + return updater.retain(this); + } + + @Override + public ByteBuf retain(int increment) { + return updater.retain(this, increment); + } + + @Override + public ByteBuf touch() { + return this; + } + + @Override + public ByteBuf touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return handleRelease(updater.release(this)); + } + + @Override + public boolean release(int decrement) { + return handleRelease(updater.release(this, decrement)); + } + + private boolean handleRelease(boolean result) { + if (result) { + deallocate(); + } + return result; + } + + /** + * Called once {@link #refCnt()} is equals 0. + */ + protected abstract void deallocate(); +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/AbstractUnpooledSlicedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/AbstractUnpooledSlicedByteBuf.java new file mode 100644 index 0000000..f1863ff --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/AbstractUnpooledSlicedByteBuf.java @@ -0,0 +1,477 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * A sliced view of an underlying {@link ByteBuf}.
 * <p>
 * Indices visible to callers start at 0 and are translated to the wrapped buffer's
 * index space by adding {@code adjustment} (see {@link #idx(int)}). The capacity is
 * fixed to the slice length and cannot be changed.
 */
abstract class AbstractUnpooledSlicedByteBuf extends AbstractDerivedByteBuf {
    // The fully unwrapped buffer this slice reads from / writes to.
    private final ByteBuf buffer;
    // Offset added to every caller-supplied index before delegating.
    private final int adjustment;

    AbstractUnpooledSlicedByteBuf(ByteBuf buffer, int index, int length) {
        super(length);
        checkSliceOutOfBounds(index, length, buffer);

        // Flatten nested slices/duplicates so delegation is a single hop and the
        // offsets are folded into one adjustment.
        if (buffer instanceof AbstractUnpooledSlicedByteBuf) {
            this.buffer = ((AbstractUnpooledSlicedByteBuf) buffer).buffer;
            adjustment = ((AbstractUnpooledSlicedByteBuf) buffer).adjustment + index;
        } else if (buffer instanceof DuplicatedByteBuf) {
            this.buffer = buffer.unwrap();
            adjustment = index;
        } else {
            this.buffer = buffer;
            adjustment = index;
        }

        initLength(length);
        writerIndex(length);
    }

    /**
     * Called by the constructor before {@link #writerIndex(int)}.
     * @param length the {@code length} argument from the constructor.
     */
    void initLength(int length) {
    }

    // The slice length; equal to the (fixed) capacity.
    int length() {
        return capacity();
    }

    @Override
    public ByteBuf unwrap() {
        return buffer;
    }

    @Override
    public ByteBufAllocator alloc() {
        return unwrap().alloc();
    }

    @Override
    @Deprecated
    public ByteOrder order() {
        return unwrap().order();
    }

    @Override
    public boolean isDirect() {
        return unwrap().isDirect();
    }

    @Override
    public ByteBuf capacity(int newCapacity) {
        // A slice's capacity is fixed by construction.
        throw new UnsupportedOperationException("sliced buffer");
    }

    @Override
    public boolean hasArray() {
        return unwrap().hasArray();
    }

    @Override
    public byte[] array() {
        return unwrap().array();
    }

    @Override
    public int arrayOffset() {
        return idx(unwrap().arrayOffset());
    }

    @Override
    public boolean hasMemoryAddress() {
        return unwrap().hasMemoryAddress();
    }

    @Override
    public long memoryAddress() {
        return unwrap().memoryAddress() + adjustment;
    }

    // ------------------------------------------------------------------
    // Getters: the public variants bounds-check against the slice, then
    // both variants delegate with the index translated via idx(...).
    // ------------------------------------------------------------------

    @Override
    public byte getByte(int index) {
        checkIndex0(index, 1);
        return unwrap().getByte(idx(index));
    }

    @Override
    protected byte _getByte(int index) {
        return unwrap().getByte(idx(index));
    }

    @Override
    public short getShort(int index) {
        checkIndex0(index, 2);
        return unwrap().getShort(idx(index));
    }

    @Override
    protected short _getShort(int index) {
        return unwrap().getShort(idx(index));
    }

    @Override
    public short getShortLE(int index) {
        checkIndex0(index, 2);
        return unwrap().getShortLE(idx(index));
    }

    @Override
    protected short _getShortLE(int index) {
        return unwrap().getShortLE(idx(index));
    }

    @Override
    public int getUnsignedMedium(int index) {
        checkIndex0(index, 3);
        return unwrap().getUnsignedMedium(idx(index));
    }

    @Override
    protected int _getUnsignedMedium(int index) {
        return unwrap().getUnsignedMedium(idx(index));
    }

    @Override
    public int getUnsignedMediumLE(int index) {
        checkIndex0(index, 3);
        return unwrap().getUnsignedMediumLE(idx(index));
    }

    @Override
    protected int _getUnsignedMediumLE(int index) {
        return unwrap().getUnsignedMediumLE(idx(index));
    }

    @Override
    public int getInt(int index) {
        checkIndex0(index, 4);
        return unwrap().getInt(idx(index));
    }

    @Override
    protected int _getInt(int index) {
        return unwrap().getInt(idx(index));
    }

    @Override
    public int getIntLE(int index) {
        checkIndex0(index, 4);
        return unwrap().getIntLE(idx(index));
    }

    @Override
    protected int _getIntLE(int index) {
        return unwrap().getIntLE(idx(index));
    }

    @Override
    public long getLong(int index) {
        checkIndex0(index, 8);
        return unwrap().getLong(idx(index));
    }

    @Override
    protected long _getLong(int index) {
        return unwrap().getLong(idx(index));
    }

    @Override
    public long getLongLE(int index) {
        checkIndex0(index, 8);
        return unwrap().getLongLE(idx(index));
    }

    @Override
    protected long _getLongLE(int index) {
        return unwrap().getLongLE(idx(index));
    }

    @Override
    public ByteBuf duplicate() {
        return unwrap().duplicate().setIndex(idx(readerIndex()), idx(writerIndex()));
    }

    @Override
    public ByteBuf copy(int index, int length) {
        checkIndex0(index, length);
        return unwrap().copy(idx(index), length);
    }

    @Override
    public ByteBuf slice(int index, int length) {
        checkIndex0(index, length);
        return unwrap().slice(idx(index), length);
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) {
        checkIndex0(index, length);
        unwrap().getBytes(idx(index), dst, dstIndex, length);
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) {
        checkIndex0(index, length);
        unwrap().getBytes(idx(index), dst, dstIndex, length);
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuffer dst) {
        checkIndex0(index, dst.remaining());
        unwrap().getBytes(idx(index), dst);
        return this;
    }

    // ------------------------------------------------------------------
    // Setters: same pattern as the getters above.
    // ------------------------------------------------------------------

    @Override
    public ByteBuf setByte(int index, int value) {
        checkIndex0(index, 1);
        unwrap().setByte(idx(index), value);
        return this;
    }

    @Override
    public CharSequence getCharSequence(int index, int length, Charset charset) {
        checkIndex0(index, length);
        return unwrap().getCharSequence(idx(index), length, charset);
    }

    @Override
    protected void _setByte(int index, int value) {
        unwrap().setByte(idx(index), value);
    }

    @Override
    public ByteBuf setShort(int index, int value) {
        checkIndex0(index, 2);
        unwrap().setShort(idx(index), value);
        return this;
    }

    @Override
    protected void _setShort(int index, int value) {
        unwrap().setShort(idx(index), value);
    }

    @Override
    public ByteBuf setShortLE(int index, int value) {
        checkIndex0(index, 2);
        unwrap().setShortLE(idx(index), value);
        return this;
    }

    @Override
    protected void _setShortLE(int index, int value) {
        unwrap().setShortLE(idx(index), value);
    }

    @Override
    public ByteBuf setMedium(int index, int value) {
        checkIndex0(index, 3);
        unwrap().setMedium(idx(index), value);
        return this;
    }

    @Override
    protected void _setMedium(int index, int value) {
        unwrap().setMedium(idx(index), value);
    }

    @Override
    public ByteBuf setMediumLE(int index, int value) {
        checkIndex0(index, 3);
        unwrap().setMediumLE(idx(index), value);
        return this;
    }

    @Override
    protected void _setMediumLE(int index, int value) {
        unwrap().setMediumLE(idx(index), value);
    }

    @Override
    public ByteBuf setInt(int index, int value) {
        checkIndex0(index, 4);
        unwrap().setInt(idx(index), value);
        return this;
    }

    @Override
    protected void _setInt(int index, int value) {
        unwrap().setInt(idx(index), value);
    }

    @Override
    public ByteBuf setIntLE(int index, int value) {
        checkIndex0(index, 4);
        unwrap().setIntLE(idx(index), value);
        return this;
    }

    @Override
    protected void _setIntLE(int index, int value) {
        unwrap().setIntLE(idx(index), value);
    }

    @Override
    public ByteBuf setLong(int index, long value) {
        checkIndex0(index, 8);
        unwrap().setLong(idx(index), value);
        return this;
    }

    @Override
    protected void _setLong(int index, long value) {
        unwrap().setLong(idx(index), value);
    }

    @Override
    public ByteBuf setLongLE(int index, long value) {
        checkIndex0(index, 8);
        unwrap().setLongLE(idx(index), value);
        return this;
    }

    @Override
    protected void _setLongLE(int index, long value) {
        unwrap().setLongLE(idx(index), value);
    }

    @Override
    public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) {
        checkIndex0(index, length);
        unwrap().setBytes(idx(index), src, srcIndex, length);
        return this;
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) {
        checkIndex0(index, length);
        unwrap().setBytes(idx(index), src, srcIndex, length);
        return this;
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuffer src) {
        checkIndex0(index, src.remaining());
        unwrap().setBytes(idx(index), src);
        return this;
    }

    // ------------------------------------------------------------------
    // Stream/channel transfers.
    // ------------------------------------------------------------------

    @Override
    public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException {
        checkIndex0(index, length);
        unwrap().getBytes(idx(index), out, length);
        return this;
    }

    @Override
    public int getBytes(int index, GatheringByteChannel out, int length) throws IOException {
        checkIndex0(index, length);
        return unwrap().getBytes(idx(index), out, length);
    }

    @Override
    public int getBytes(int index, FileChannel out, long position, int length) throws IOException {
        checkIndex0(index, length);
        return unwrap().getBytes(idx(index), out, position, length);
    }

    @Override
    public int setBytes(int index, InputStream in, int length) throws IOException {
        checkIndex0(index, length);
        return unwrap().setBytes(idx(index), in, length);
    }

    @Override
    public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException {
        checkIndex0(index, length);
        return unwrap().setBytes(idx(index), in, length);
    }

    @Override
    public int setBytes(int index, FileChannel in, long position, int length) throws IOException {
        checkIndex0(index, length);
        return unwrap().setBytes(idx(index), in, position, length);
    }

    @Override
    public int nioBufferCount() {
        return unwrap().nioBufferCount();
    }

    @Override
    public ByteBuffer nioBuffer(int index, int length) {
        checkIndex0(index, length);
        return unwrap().nioBuffer(idx(index), length);
    }

    @Override
    public ByteBuffer[] nioBuffers(int index, int length) {
        checkIndex0(index, length);
        return unwrap().nioBuffers(idx(index), length);
    }

    @Override
    public int forEachByte(int index, int length, ByteProcessor processor) {
        checkIndex0(index, length);
        // Translate the underlying buffer's hit index back to slice coordinates;
        // -1 ("not found") maps to -1 because a hit can never be before adjustment.
        int ret = unwrap().forEachByte(idx(index), length, processor);
        if (ret >= adjustment) {
            return ret - adjustment;
        } else {
            return -1;
        }
    }

    @Override
    public int forEachByteDesc(int index, int length, ByteProcessor processor) {
        checkIndex0(index, length);
        int ret = unwrap().forEachByteDesc(idx(index), length, processor);
        if (ret >= adjustment) {
            return ret - adjustment;
        } else {
            return -1;
        }
    }

    /**
     * Returns the index with the needed adjustment.
     */
    final int idx(int index) {
        return index + adjustment;
    }

    // Validates that [index, index + length) lies within the buffer being sliced.
    static void checkSliceOutOfBounds(int index, int length, ByteBuf buffer) {
        if (isOutOfBounds(index, length, buffer.capacity())) {
            throw new IndexOutOfBoundsException(buffer + ".slice(" + index + ", " + length + ')');
        }
    }
}
/**
 * Special {@link SwappedByteBuf} for {@link ByteBuf}s that is using unsafe.
 * <p>
 * Multi-byte primitives are read/written through the abstract {@code _get*}/{@code _set*}
 * hooks in native byte order; a byte swap is applied only when the requested order of
 * this swapped view differs from the platform's native order ({@code nativeByteOrder}).
 */
abstract class AbstractUnsafeSwappedByteBuf extends SwappedByteBuf {
    // true when this view's byte order equals the platform's native order,
    // i.e. no swap is needed around the raw unsafe access.
    private final boolean nativeByteOrder;
    private final AbstractByteBuf wrapped;

    AbstractUnsafeSwappedByteBuf(AbstractByteBuf buf) {
        super(buf);
        // Unaligned access is required for the raw multi-byte reads/writes below.
        assert PlatformDependent.isUnaligned();
        wrapped = buf;
        nativeByteOrder = BIG_ENDIAN_NATIVE_ORDER == (order() == ByteOrder.BIG_ENDIAN);
    }

    @Override
    public final long getLong(int index) {
        wrapped.checkIndex(index, 8);
        long v = _getLong(wrapped, index);
        return nativeByteOrder ? v : Long.reverseBytes(v);
    }

    @Override
    public final float getFloat(int index) {
        return Float.intBitsToFloat(getInt(index));
    }

    @Override
    public final double getDouble(int index) {
        return Double.longBitsToDouble(getLong(index));
    }

    @Override
    public final char getChar(int index) {
        return (char) getShort(index);
    }

    @Override
    public final long getUnsignedInt(int index) {
        return getInt(index) & 0xFFFFFFFFL;
    }

    @Override
    public final int getInt(int index) {
        wrapped.checkIndex(index, 4);
        int v = _getInt(wrapped, index);
        return nativeByteOrder ? v : Integer.reverseBytes(v);
    }

    @Override
    public final int getUnsignedShort(int index) {
        return getShort(index) & 0xFFFF;
    }

    @Override
    public final short getShort(int index) {
        wrapped.checkIndex(index, 2);
        short v = _getShort(wrapped, index);
        return nativeByteOrder ? v : Short.reverseBytes(v);
    }

    @Override
    public final ByteBuf setShort(int index, int value) {
        wrapped.checkIndex(index, 2);
        _setShort(wrapped, index, nativeByteOrder ? (short) value : Short.reverseBytes((short) value));
        return this;
    }

    @Override
    public final ByteBuf setInt(int index, int value) {
        wrapped.checkIndex(index, 4);
        _setInt(wrapped, index, nativeByteOrder ? value : Integer.reverseBytes(value));
        return this;
    }

    @Override
    public final ByteBuf setLong(int index, long value) {
        wrapped.checkIndex(index, 8);
        _setLong(wrapped, index, nativeByteOrder ? value : Long.reverseBytes(value));
        return this;
    }

    @Override
    public final ByteBuf setChar(int index, int value) {
        setShort(index, value);
        return this;
    }

    @Override
    public final ByteBuf setFloat(int index, float value) {
        setInt(index, Float.floatToRawIntBits(value));
        return this;
    }

    @Override
    public final ByteBuf setDouble(int index, double value) {
        setLong(index, Double.doubleToRawLongBits(value));
        return this;
    }

    // write* variants grow the buffer as needed and advance writerIndex.

    @Override
    public final ByteBuf writeShort(int value) {
        wrapped.ensureWritable0(2);
        _setShort(wrapped, wrapped.writerIndex, nativeByteOrder ? (short) value : Short.reverseBytes((short) value));
        wrapped.writerIndex += 2;
        return this;
    }

    @Override
    public final ByteBuf writeInt(int value) {
        wrapped.ensureWritable0(4);
        _setInt(wrapped, wrapped.writerIndex, nativeByteOrder ? value : Integer.reverseBytes(value));
        wrapped.writerIndex += 4;
        return this;
    }

    @Override
    public final ByteBuf writeLong(long value) {
        wrapped.ensureWritable0(8);
        _setLong(wrapped, wrapped.writerIndex, nativeByteOrder ? value : Long.reverseBytes(value));
        wrapped.writerIndex += 8;
        return this;
    }

    @Override
    public final ByteBuf writeChar(int value) {
        writeShort(value);
        return this;
    }

    @Override
    public final ByteBuf writeFloat(float value) {
        writeInt(Float.floatToRawIntBits(value));
        return this;
    }

    @Override
    public final ByteBuf writeDouble(double value) {
        writeLong(Double.doubleToRawLongBits(value));
        return this;
    }

    // Raw native-order accessors; implementations perform the actual unsafe access.
    protected abstract short _getShort(AbstractByteBuf wrapped, int index);
    protected abstract int _getInt(AbstractByteBuf wrapped, int index);
    protected abstract long _getLong(AbstractByteBuf wrapped, int index);
    protected abstract void _setShort(AbstractByteBuf wrapped, int index, short value);
    protected abstract void _setInt(AbstractByteBuf wrapped, int index, int value);
    protected abstract void _setLong(AbstractByteBuf wrapped, int index, long value);
}
 */

package io.netty.buffer;

import io.netty.util.ByteProcessor;
import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakTracker;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ScatteringByteChannel;
import java.nio.charset.Charset;

/**
 * A leak-aware {@link ByteBuf} wrapper used when {@link ResourceLeakDetector} runs at an
 * "advanced" (record-keeping) level. Every buffer operation is intercepted: ref-counting
 * operations ({@code retain}/{@code release}/{@code touch}) always call
 * {@code leak.record(...)}, and all other operations do so as well unless the
 * {@code io.netty.leakDetection.acquireAndReleaseOnly} system property is set. The recorded
 * stack traces are attached to the leak report if the wrapped buffer is never released.
 * All actual buffer work is delegated unchanged to the superclass / wrapped buffer.
 */
final class AdvancedLeakAwareByteBuf extends SimpleLeakAwareByteBuf {

    // If set to true we will only record stacktraces for touch(...), release(...) and retain(...) calls.
    private static final String PROP_ACQUIRE_AND_RELEASE_ONLY = "io.netty.leakDetection.acquireAndReleaseOnly";
    private static final boolean ACQUIRE_AND_RELEASE_ONLY;

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(AdvancedLeakAwareByteBuf.class);

    static {
        ACQUIRE_AND_RELEASE_ONLY = SystemPropertyUtil.getBoolean(PROP_ACQUIRE_AND_RELEASE_ONLY, false);

        if (logger.isDebugEnabled()) {
            logger.debug("-D{}: {}", PROP_ACQUIRE_AND_RELEASE_ONLY, ACQUIRE_AND_RELEASE_ONLY);
        }

        // Keep this class's own record-keeping frames out of the captured stack traces so
        // leak reports point at user code, not at this wrapper.
        ResourceLeakDetector.addExclusions(
                AdvancedLeakAwareByteBuf.class, "touch", "recordLeakNonRefCountingOperation");
    }

    AdvancedLeakAwareByteBuf(ByteBuf buf, ResourceLeakTracker leak) {
        super(buf, leak);
    }

    AdvancedLeakAwareByteBuf(ByteBuf wrapped, ByteBuf trackedByteBuf, ResourceLeakTracker leak) {
        super(wrapped, trackedByteBuf, leak);
    }

    /**
     * Records the current access on {@code leak} unless the detector was configured (via the
     * system property above) to record only acquire/release calls. Also used by
     * {@code AdvancedLeakAwareCompositeByteBuf}.
     */
    static void recordLeakNonRefCountingOperation(ResourceLeakTracker leak) {
        if (!ACQUIRE_AND_RELEASE_ONLY) {
            leak.record();
        }
    }

    // Every override below follows the same pattern: record the access for leak diagnostics,
    // then delegate to the superclass without altering arguments or results.

    @Override
    public ByteBuf order(ByteOrder endianness) {
        recordLeakNonRefCountingOperation(leak);
        return super.order(endianness);
    }

    @Override
    public ByteBuf slice() {
        recordLeakNonRefCountingOperation(leak);
        return super.slice();
    }

    @Override
    public ByteBuf slice(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.slice(index, length);
    }

    @Override
    public ByteBuf retainedSlice() {
        recordLeakNonRefCountingOperation(leak);
        return super.retainedSlice();
    }

    @Override
    public ByteBuf retainedSlice(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.retainedSlice(index, length);
    }

    @Override
    public ByteBuf retainedDuplicate() {
        recordLeakNonRefCountingOperation(leak);
        return super.retainedDuplicate();
    }

    @Override
    public ByteBuf readRetainedSlice(int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.readRetainedSlice(length);
    }

    @Override
    public ByteBuf duplicate() {
        recordLeakNonRefCountingOperation(leak);
        return super.duplicate();
    }

    @Override
    public ByteBuf readSlice(int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.readSlice(length);
    }

    @Override
    public ByteBuf discardReadBytes() {
        recordLeakNonRefCountingOperation(leak);
        return super.discardReadBytes();
    }

    @Override
    public ByteBuf discardSomeReadBytes() {
        recordLeakNonRefCountingOperation(leak);
        return super.discardSomeReadBytes();
    }

    @Override
    public ByteBuf ensureWritable(int minWritableBytes) {
        recordLeakNonRefCountingOperation(leak);
        return super.ensureWritable(minWritableBytes);
    }

    @Override
    public int ensureWritable(int minWritableBytes, boolean force) {
        recordLeakNonRefCountingOperation(leak);
        return super.ensureWritable(minWritableBytes, force);
    }

    @Override
    public boolean getBoolean(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBoolean(index);
    }

    @Override
    public byte getByte(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getByte(index);
    }

    @Override
    public short getUnsignedByte(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedByte(index);
    }

    @Override
    public short getShort(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getShort(index);
    }

    @Override
    public int getUnsignedShort(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedShort(index);
    }

    @Override
    public int getMedium(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getMedium(index);
    }

    @Override
    public int getUnsignedMedium(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedMedium(index);
    }

    @Override
    public int getInt(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getInt(index);
    }

    @Override
    public long getUnsignedInt(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedInt(index);
    }

    @Override
    public long getLong(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getLong(index);
    }

    @Override
    public char getChar(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getChar(index);
    }

    @Override
    public float getFloat(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getFloat(index);
    }

    @Override
    public double getDouble(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getDouble(index);
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, dst);
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, dst, length);
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, dst, dstIndex, length);
    }

    @Override
    public ByteBuf getBytes(int index, byte[] dst) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, dst);
    }

    @Override
    public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, dst, dstIndex, length);
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuffer dst) {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, dst);
    }

    @Override
    public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, out, length);
    }

    @Override
    public int getBytes(int index, GatheringByteChannel out, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, out, length);
    }

    @Override
    public CharSequence getCharSequence(int index, int length, Charset charset) {
        recordLeakNonRefCountingOperation(leak);
        return super.getCharSequence(index, length, charset);
    }

    @Override
    public ByteBuf setBoolean(int index, boolean value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBoolean(index, value);
    }

    @Override
    public ByteBuf setByte(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setByte(index, value);
    }

    @Override
    public ByteBuf setShort(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setShort(index, value);
    }

    @Override
    public ByteBuf setMedium(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setMedium(index, value);
    }

    @Override
    public ByteBuf setInt(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setInt(index, value);
    }

    @Override
    public ByteBuf setLong(int index, long value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setLong(index, value);
    }

    @Override
    public ByteBuf setChar(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setChar(index, value);
    }

    @Override
    public ByteBuf setFloat(int index, float value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setFloat(index, value);
    }

    @Override
    public ByteBuf setDouble(int index, double value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setDouble(index, value);
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuf src) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, src);
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuf src, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, src, length);
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, src, srcIndex, length);
    }

    @Override
    public ByteBuf setBytes(int index, byte[] src) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, src);
    }

    @Override
    public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, src, srcIndex, length);
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuffer src) {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, src);
    }

    @Override
    public int setBytes(int index, InputStream in, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, in, length);
    }

    @Override
    public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, in, length);
    }

    @Override
    public ByteBuf setZero(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.setZero(index, length);
    }

    @Override
    public int setCharSequence(int index, CharSequence sequence, Charset charset) {
        recordLeakNonRefCountingOperation(leak);
        return super.setCharSequence(index, sequence, charset);
    }

    @Override
    public boolean readBoolean() {
        recordLeakNonRefCountingOperation(leak);
        return super.readBoolean();
    }

    @Override
    public byte readByte() {
        recordLeakNonRefCountingOperation(leak);
        return super.readByte();
    }

    @Override
    public short readUnsignedByte() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedByte();
    }

    @Override
    public short readShort() {
        recordLeakNonRefCountingOperation(leak);
        return super.readShort();
    }

    @Override
    public int readUnsignedShort() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedShort();
    }

    @Override
    public int readMedium() {
        recordLeakNonRefCountingOperation(leak);
        return super.readMedium();
    }

    @Override
    public int readUnsignedMedium() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedMedium();
    }

    @Override
    public int readInt() {
        recordLeakNonRefCountingOperation(leak);
        return super.readInt();
    }

    @Override
    public long readUnsignedInt() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedInt();
    }

    @Override
    public long readLong() {
        recordLeakNonRefCountingOperation(leak);
        return super.readLong();
    }

    @Override
    public char readChar() {
        recordLeakNonRefCountingOperation(leak);
        return super.readChar();
    }

    @Override
    public float readFloat() {
        recordLeakNonRefCountingOperation(leak);
        return super.readFloat();
    }

    @Override
    public double readDouble() {
        recordLeakNonRefCountingOperation(leak);
        return super.readDouble();
    }

    @Override
    public ByteBuf readBytes(int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(length);
    }

    @Override
    public ByteBuf readBytes(ByteBuf dst) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(dst);
    }

    @Override
    public ByteBuf readBytes(ByteBuf dst, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(dst, length);
    }

    @Override
    public ByteBuf readBytes(ByteBuf dst, int dstIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(dst, dstIndex, length);
    }

    @Override
    public ByteBuf readBytes(byte[] dst) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(dst);
    }

    @Override
    public ByteBuf readBytes(byte[] dst, int dstIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(dst, dstIndex, length);
    }

    @Override
    public ByteBuf readBytes(ByteBuffer dst) {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(dst);
    }

    @Override
    public ByteBuf readBytes(OutputStream out, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(out, length);
    }

    @Override
    public int readBytes(GatheringByteChannel out, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(out, length);
    }

    @Override
    public CharSequence readCharSequence(int length, Charset charset) {
        recordLeakNonRefCountingOperation(leak);
        return super.readCharSequence(length, charset);
    }

    @Override
    public ByteBuf skipBytes(int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.skipBytes(length);
    }

    @Override
    public ByteBuf writeBoolean(boolean value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBoolean(value);
    }

    @Override
    public ByteBuf writeByte(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeByte(value);
    }

    @Override
    public ByteBuf writeShort(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeShort(value);
    }

    @Override
    public ByteBuf writeMedium(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeMedium(value);
    }

    @Override
    public ByteBuf writeInt(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeInt(value);
    }

    @Override
    public ByteBuf writeLong(long value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeLong(value);
    }

    @Override
    public ByteBuf writeChar(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeChar(value);
    }

    @Override
    public ByteBuf writeFloat(float value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeFloat(value);
    }

    @Override
    public ByteBuf writeDouble(double value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeDouble(value);
    }

    @Override
    public ByteBuf writeBytes(ByteBuf src) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(src);
    }

    @Override
    public ByteBuf writeBytes(ByteBuf src, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(src, length);
    }

    @Override
    public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(src, srcIndex, length);
    }

    @Override
    public ByteBuf writeBytes(byte[] src) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(src);
    }

    @Override
    public ByteBuf writeBytes(byte[] src, int srcIndex, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(src, srcIndex, length);
    }

    @Override
    public ByteBuf writeBytes(ByteBuffer src) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(src);
    }

    @Override
    public int writeBytes(InputStream in, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(in, length);
    }

    @Override
    public int writeBytes(ScatteringByteChannel in, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(in, length);
    }

    @Override
    public ByteBuf writeZero(int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeZero(length);
    }

    @Override
    public int indexOf(int fromIndex, int toIndex, byte value) {
        recordLeakNonRefCountingOperation(leak);
        return super.indexOf(fromIndex, toIndex, value);
    }

    @Override
    public int bytesBefore(byte value) {
        recordLeakNonRefCountingOperation(leak);
        return super.bytesBefore(value);
    }

    @Override
    public int bytesBefore(int length, byte value) {
        recordLeakNonRefCountingOperation(leak);
        return super.bytesBefore(length, value);
    }

    @Override
    public int bytesBefore(int index, int length, byte value) {
        recordLeakNonRefCountingOperation(leak);
        return super.bytesBefore(index, length, value);
    }

    @Override
    public int forEachByte(ByteProcessor processor) {
        recordLeakNonRefCountingOperation(leak);
        return super.forEachByte(processor);
    }

    @Override
    public int forEachByte(int index, int length, ByteProcessor processor) {
        recordLeakNonRefCountingOperation(leak);
        return super.forEachByte(index, length, processor);
    }

    @Override
    public int forEachByteDesc(ByteProcessor processor) {
        recordLeakNonRefCountingOperation(leak);
        return super.forEachByteDesc(processor);
    }

    @Override
    public int forEachByteDesc(int index, int length, ByteProcessor processor) {
        recordLeakNonRefCountingOperation(leak);
        return super.forEachByteDesc(index, length, processor);
    }

    @Override
    public ByteBuf copy() {
        recordLeakNonRefCountingOperation(leak);
        return super.copy();
    }

    @Override
    public ByteBuf copy(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.copy(index, length);
    }

    @Override
    public int nioBufferCount() {
        recordLeakNonRefCountingOperation(leak);
        return super.nioBufferCount();
    }

    @Override
    public ByteBuffer nioBuffer() {
        recordLeakNonRefCountingOperation(leak);
        return super.nioBuffer();
    }

    @Override
    public ByteBuffer nioBuffer(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.nioBuffer(index, length);
    }

    @Override
    public ByteBuffer[] nioBuffers() {
        recordLeakNonRefCountingOperation(leak);
        return super.nioBuffers();
    }

    @Override
    public ByteBuffer[] nioBuffers(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.nioBuffers(index, length);
    }

    @Override
    public ByteBuffer internalNioBuffer(int index, int length) {
        recordLeakNonRefCountingOperation(leak);
        return super.internalNioBuffer(index, length);
    }

    @Override
    public String toString(Charset charset) {
        recordLeakNonRefCountingOperation(leak);
        return super.toString(charset);
    }

    @Override
    public String toString(int index, int length, Charset charset) {
        recordLeakNonRefCountingOperation(leak);
        return super.toString(index, length, charset);
    }

    @Override
    public ByteBuf capacity(int newCapacity) {
        recordLeakNonRefCountingOperation(leak);
        return super.capacity(newCapacity);
    }

    @Override
    public short getShortLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getShortLE(index);
    }

    @Override
    public int getUnsignedShortLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedShortLE(index);
    }

    @Override
    public int getMediumLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getMediumLE(index);
    }

    @Override
    public int getUnsignedMediumLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedMediumLE(index);
    }

    @Override
    public int getIntLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getIntLE(index);
    }

    @Override
    public long getUnsignedIntLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getUnsignedIntLE(index);
    }

    @Override
    public long getLongLE(int index) {
        recordLeakNonRefCountingOperation(leak);
        return super.getLongLE(index);
    }

    @Override
    public ByteBuf setShortLE(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setShortLE(index, value);
    }

    @Override
    public ByteBuf setIntLE(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setIntLE(index, value);
    }

    @Override
    public ByteBuf setMediumLE(int index, int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setMediumLE(index, value);
    }

    @Override
    public ByteBuf setLongLE(int index, long value) {
        recordLeakNonRefCountingOperation(leak);
        return super.setLongLE(index, value);
    }

    @Override
    public short readShortLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readShortLE();
    }

    @Override
    public int readUnsignedShortLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedShortLE();
    }

    @Override
    public int readMediumLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readMediumLE();
    }

    @Override
    public int readUnsignedMediumLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedMediumLE();
    }

    @Override
    public int readIntLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readIntLE();
    }

    @Override
    public long readUnsignedIntLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readUnsignedIntLE();
    }

    @Override
    public long readLongLE() {
        recordLeakNonRefCountingOperation(leak);
        return super.readLongLE();
    }

    @Override
    public ByteBuf writeShortLE(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeShortLE(value);
    }

    @Override
    public ByteBuf writeMediumLE(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeMediumLE(value);
    }

    @Override
    public ByteBuf writeIntLE(int value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeIntLE(value);
    }

    @Override
    public ByteBuf writeLongLE(long value) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeLongLE(value);
    }

    @Override
    public int writeCharSequence(CharSequence sequence, Charset charset) {
        recordLeakNonRefCountingOperation(leak);
        return super.writeCharSequence(sequence, charset);
    }

    @Override
    public int getBytes(int index, FileChannel out, long position, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.getBytes(index, out, position, length);
    }

    @Override
    public int setBytes(int index, FileChannel in, long position, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.setBytes(index, in, position, length);
    }

    @Override
    public int readBytes(FileChannel out, long position, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.readBytes(out, position, length);
    }

    @Override
    public int writeBytes(FileChannel in, long position, int length) throws IOException {
        recordLeakNonRefCountingOperation(leak);
        return super.writeBytes(in, position, length);
    }

    @Override
    public ByteBuf asReadOnly() {
        recordLeakNonRefCountingOperation(leak);
        return super.asReadOnly();
    }

    // Ref-counting operations below always record, regardless of the
    // acquireAndReleaseOnly setting.

    @Override
    public ByteBuf retain() {
        leak.record();
        return super.retain();
    }

    @Override
    public ByteBuf retain(int increment) {
        leak.record();
        return super.retain(increment);
    }

    @Override
    public boolean release() {
        leak.record();
        return super.release();
    }

    @Override
    public boolean release(int decrement) {
        leak.record();
        return super.release(decrement);
    }

    @Override
    public ByteBuf touch() {
        leak.record();
        return this;
    }

    @Override
    public ByteBuf touch(Object hint) {
        leak.record(hint);
        return this;
    }

    @Override
    protected AdvancedLeakAwareByteBuf newLeakAwareByteBuf(
            ByteBuf buf, ByteBuf trackedByteBuf, ResourceLeakTracker leakTracker) {
        return new
AdvancedLeakAwareByteBuf(buf, trackedByteBuf, leakTracker); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBuf.java new file mode 100644 index 0000000..3eb404e --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBuf.java @@ -0,0 +1,1055 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + + +import io.netty.util.ByteProcessor; +import io.netty.util.ResourceLeakTracker; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; +import java.util.Iterator; +import java.util.List; + +import static io.netty.buffer.AdvancedLeakAwareByteBuf.recordLeakNonRefCountingOperation; + +final class AdvancedLeakAwareCompositeByteBuf extends SimpleLeakAwareCompositeByteBuf { + + AdvancedLeakAwareCompositeByteBuf(CompositeByteBuf wrapped, ResourceLeakTracker leak) { + super(wrapped, leak); + } + + @Override + public ByteBuf order(ByteOrder endianness) { + recordLeakNonRefCountingOperation(leak); + return super.order(endianness); + } + + @Override + public ByteBuf slice() { + recordLeakNonRefCountingOperation(leak); + return super.slice(); + } + + @Override + public ByteBuf retainedSlice() { + recordLeakNonRefCountingOperation(leak); + return super.retainedSlice(); + } + + @Override + public ByteBuf slice(int index, int length) { + recordLeakNonRefCountingOperation(leak); + return super.slice(index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + recordLeakNonRefCountingOperation(leak); + return super.retainedSlice(index, length); + } + + @Override + public ByteBuf duplicate() { + recordLeakNonRefCountingOperation(leak); + return super.duplicate(); + } + + @Override + public ByteBuf retainedDuplicate() { + recordLeakNonRefCountingOperation(leak); + return super.retainedDuplicate(); + } + + @Override + public ByteBuf readSlice(int length) { + recordLeakNonRefCountingOperation(leak); + return super.readSlice(length); + } + + @Override + public ByteBuf readRetainedSlice(int length) { + recordLeakNonRefCountingOperation(leak); + return 
super.readRetainedSlice(length); + } + + @Override + public ByteBuf asReadOnly() { + recordLeakNonRefCountingOperation(leak); + return super.asReadOnly(); + } + + @Override + public boolean isReadOnly() { + recordLeakNonRefCountingOperation(leak); + return super.isReadOnly(); + } + + @Override + public CompositeByteBuf discardReadBytes() { + recordLeakNonRefCountingOperation(leak); + return super.discardReadBytes(); + } + + @Override + public CompositeByteBuf discardSomeReadBytes() { + recordLeakNonRefCountingOperation(leak); + return super.discardSomeReadBytes(); + } + + @Override + public CompositeByteBuf ensureWritable(int minWritableBytes) { + recordLeakNonRefCountingOperation(leak); + return super.ensureWritable(minWritableBytes); + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + recordLeakNonRefCountingOperation(leak); + return super.ensureWritable(minWritableBytes, force); + } + + @Override + public boolean getBoolean(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getBoolean(index); + } + + @Override + public byte getByte(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getByte(index); + } + + @Override + public short getUnsignedByte(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedByte(index); + } + + @Override + public short getShort(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getShort(index); + } + + @Override + public int getUnsignedShort(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedShort(index); + } + + @Override + public int getMedium(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getMedium(index); + } + + @Override + public int getUnsignedMedium(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedMedium(index); + } + + @Override + public int getInt(int index) { + recordLeakNonRefCountingOperation(leak); + return 
super.getInt(index); + } + + @Override + public long getUnsignedInt(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedInt(index); + } + + @Override + public long getLong(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getLong(index); + } + + @Override + public char getChar(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getChar(index); + } + + @Override + public float getFloat(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getFloat(index); + } + + @Override + public double getDouble(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getDouble(index); + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst) { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, dst); + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst, int length) { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, dst, length); + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, dst, dstIndex, length); + } + + @Override + public CompositeByteBuf getBytes(int index, byte[] dst) { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, dst); + } + + @Override + public CompositeByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, dst, dstIndex, length); + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuffer dst) { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, dst); + } + + @Override + public CompositeByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, out, length); + } + + @Override + public int getBytes(int index, 
GatheringByteChannel out, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, out, length); + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + recordLeakNonRefCountingOperation(leak); + return super.getCharSequence(index, length, charset); + } + + @Override + public CompositeByteBuf setBoolean(int index, boolean value) { + recordLeakNonRefCountingOperation(leak); + return super.setBoolean(index, value); + } + + @Override + public CompositeByteBuf setByte(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setByte(index, value); + } + + @Override + public CompositeByteBuf setShort(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setShort(index, value); + } + + @Override + public CompositeByteBuf setMedium(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setMedium(index, value); + } + + @Override + public CompositeByteBuf setInt(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setInt(index, value); + } + + @Override + public CompositeByteBuf setLong(int index, long value) { + recordLeakNonRefCountingOperation(leak); + return super.setLong(index, value); + } + + @Override + public CompositeByteBuf setChar(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setChar(index, value); + } + + @Override + public CompositeByteBuf setFloat(int index, float value) { + recordLeakNonRefCountingOperation(leak); + return super.setFloat(index, value); + } + + @Override + public CompositeByteBuf setDouble(int index, double value) { + recordLeakNonRefCountingOperation(leak); + return super.setDouble(index, value); + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src) { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, src); + } + + @Override + public CompositeByteBuf setBytes(int 
index, ByteBuf src, int length) { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, src, length); + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, src, srcIndex, length); + } + + @Override + public CompositeByteBuf setBytes(int index, byte[] src) { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, src); + } + + @Override + public CompositeByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, src, srcIndex, length); + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuffer src) { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, src); + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, in, length); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, in, length); + } + + @Override + public CompositeByteBuf setZero(int index, int length) { + recordLeakNonRefCountingOperation(leak); + return super.setZero(index, length); + } + + @Override + public boolean readBoolean() { + recordLeakNonRefCountingOperation(leak); + return super.readBoolean(); + } + + @Override + public byte readByte() { + recordLeakNonRefCountingOperation(leak); + return super.readByte(); + } + + @Override + public short readUnsignedByte() { + recordLeakNonRefCountingOperation(leak); + return super.readUnsignedByte(); + } + + @Override + public short readShort() { + recordLeakNonRefCountingOperation(leak); + return super.readShort(); + } + + @Override + public int readUnsignedShort() { + recordLeakNonRefCountingOperation(leak); + return 
super.readUnsignedShort(); + } + + @Override + public int readMedium() { + recordLeakNonRefCountingOperation(leak); + return super.readMedium(); + } + + @Override + public int readUnsignedMedium() { + recordLeakNonRefCountingOperation(leak); + return super.readUnsignedMedium(); + } + + @Override + public int readInt() { + recordLeakNonRefCountingOperation(leak); + return super.readInt(); + } + + @Override + public long readUnsignedInt() { + recordLeakNonRefCountingOperation(leak); + return super.readUnsignedInt(); + } + + @Override + public long readLong() { + recordLeakNonRefCountingOperation(leak); + return super.readLong(); + } + + @Override + public char readChar() { + recordLeakNonRefCountingOperation(leak); + return super.readChar(); + } + + @Override + public float readFloat() { + recordLeakNonRefCountingOperation(leak); + return super.readFloat(); + } + + @Override + public double readDouble() { + recordLeakNonRefCountingOperation(leak); + return super.readDouble(); + } + + @Override + public ByteBuf readBytes(int length) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(length); + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(dst); + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst, int length) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(dst, length); + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(dst, dstIndex, length); + } + + @Override + public CompositeByteBuf readBytes(byte[] dst) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(dst); + } + + @Override + public CompositeByteBuf readBytes(byte[] dst, int dstIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(dst, dstIndex, length); + } + + @Override + public CompositeByteBuf 
readBytes(ByteBuffer dst) { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(dst); + } + + @Override + public CompositeByteBuf readBytes(OutputStream out, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(out, length); + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(out, length); + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + recordLeakNonRefCountingOperation(leak); + return super.readCharSequence(length, charset); + } + + @Override + public CompositeByteBuf skipBytes(int length) { + recordLeakNonRefCountingOperation(leak); + return super.skipBytes(length); + } + + @Override + public CompositeByteBuf writeBoolean(boolean value) { + recordLeakNonRefCountingOperation(leak); + return super.writeBoolean(value); + } + + @Override + public CompositeByteBuf writeByte(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeByte(value); + } + + @Override + public CompositeByteBuf writeShort(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeShort(value); + } + + @Override + public CompositeByteBuf writeMedium(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeMedium(value); + } + + @Override + public CompositeByteBuf writeInt(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeInt(value); + } + + @Override + public CompositeByteBuf writeLong(long value) { + recordLeakNonRefCountingOperation(leak); + return super.writeLong(value); + } + + @Override + public CompositeByteBuf writeChar(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeChar(value); + } + + @Override + public CompositeByteBuf writeFloat(float value) { + recordLeakNonRefCountingOperation(leak); + return super.writeFloat(value); + } + + @Override + public 
CompositeByteBuf writeDouble(double value) { + recordLeakNonRefCountingOperation(leak); + return super.writeDouble(value); + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src) { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(src); + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src, int length) { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(src, length); + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(src, srcIndex, length); + } + + @Override + public CompositeByteBuf writeBytes(byte[] src) { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(src); + } + + @Override + public CompositeByteBuf writeBytes(byte[] src, int srcIndex, int length) { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(src, srcIndex, length); + } + + @Override + public CompositeByteBuf writeBytes(ByteBuffer src) { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(src); + } + + @Override + public int writeBytes(InputStream in, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(in, length); + } + + @Override + public int writeBytes(ScatteringByteChannel in, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(in, length); + } + + @Override + public CompositeByteBuf writeZero(int length) { + recordLeakNonRefCountingOperation(leak); + return super.writeZero(length); + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + recordLeakNonRefCountingOperation(leak); + return super.writeCharSequence(sequence, charset); + } + + @Override + public int indexOf(int fromIndex, int toIndex, byte value) { + recordLeakNonRefCountingOperation(leak); + return super.indexOf(fromIndex, toIndex, value); + } + + @Override + public int 
bytesBefore(byte value) { + recordLeakNonRefCountingOperation(leak); + return super.bytesBefore(value); + } + + @Override + public int bytesBefore(int length, byte value) { + recordLeakNonRefCountingOperation(leak); + return super.bytesBefore(length, value); + } + + @Override + public int bytesBefore(int index, int length, byte value) { + recordLeakNonRefCountingOperation(leak); + return super.bytesBefore(index, length, value); + } + + @Override + public int forEachByte(ByteProcessor processor) { + recordLeakNonRefCountingOperation(leak); + return super.forEachByte(processor); + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + recordLeakNonRefCountingOperation(leak); + return super.forEachByte(index, length, processor); + } + + @Override + public int forEachByteDesc(ByteProcessor processor) { + recordLeakNonRefCountingOperation(leak); + return super.forEachByteDesc(processor); + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + recordLeakNonRefCountingOperation(leak); + return super.forEachByteDesc(index, length, processor); + } + + @Override + public ByteBuf copy() { + recordLeakNonRefCountingOperation(leak); + return super.copy(); + } + + @Override + public ByteBuf copy(int index, int length) { + recordLeakNonRefCountingOperation(leak); + return super.copy(index, length); + } + + @Override + public int nioBufferCount() { + recordLeakNonRefCountingOperation(leak); + return super.nioBufferCount(); + } + + @Override + public ByteBuffer nioBuffer() { + recordLeakNonRefCountingOperation(leak); + return super.nioBuffer(); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + recordLeakNonRefCountingOperation(leak); + return super.nioBuffer(index, length); + } + + @Override + public ByteBuffer[] nioBuffers() { + recordLeakNonRefCountingOperation(leak); + return super.nioBuffers(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + 
recordLeakNonRefCountingOperation(leak); + return super.nioBuffers(index, length); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + recordLeakNonRefCountingOperation(leak); + return super.internalNioBuffer(index, length); + } + + @Override + public String toString(Charset charset) { + recordLeakNonRefCountingOperation(leak); + return super.toString(charset); + } + + @Override + public String toString(int index, int length, Charset charset) { + recordLeakNonRefCountingOperation(leak); + return super.toString(index, length, charset); + } + + @Override + public CompositeByteBuf capacity(int newCapacity) { + recordLeakNonRefCountingOperation(leak); + return super.capacity(newCapacity); + } + + @Override + public short getShortLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getShortLE(index); + } + + @Override + public int getUnsignedShortLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedShortLE(index); + } + + @Override + public int getUnsignedMediumLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedMediumLE(index); + } + + @Override + public int getMediumLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getMediumLE(index); + } + + @Override + public int getIntLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getIntLE(index); + } + + @Override + public long getUnsignedIntLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getUnsignedIntLE(index); + } + + @Override + public long getLongLE(int index) { + recordLeakNonRefCountingOperation(leak); + return super.getLongLE(index); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setShortLE(index, value); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setMediumLE(index, 
value); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + recordLeakNonRefCountingOperation(leak); + return super.setIntLE(index, value); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + recordLeakNonRefCountingOperation(leak); + return super.setLongLE(index, value); + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + recordLeakNonRefCountingOperation(leak); + return super.setCharSequence(index, sequence, charset); + } + + @Override + public short readShortLE() { + recordLeakNonRefCountingOperation(leak); + return super.readShortLE(); + } + + @Override + public int readUnsignedShortLE() { + recordLeakNonRefCountingOperation(leak); + return super.readUnsignedShortLE(); + } + + @Override + public int readMediumLE() { + recordLeakNonRefCountingOperation(leak); + return super.readMediumLE(); + } + + @Override + public int readUnsignedMediumLE() { + recordLeakNonRefCountingOperation(leak); + return super.readUnsignedMediumLE(); + } + + @Override + public int readIntLE() { + recordLeakNonRefCountingOperation(leak); + return super.readIntLE(); + } + + @Override + public long readUnsignedIntLE() { + recordLeakNonRefCountingOperation(leak); + return super.readUnsignedIntLE(); + } + + @Override + public long readLongLE() { + recordLeakNonRefCountingOperation(leak); + return super.readLongLE(); + } + + @Override + public ByteBuf writeShortLE(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeShortLE(value); + } + + @Override + public ByteBuf writeMediumLE(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeMediumLE(value); + } + + @Override + public ByteBuf writeIntLE(int value) { + recordLeakNonRefCountingOperation(leak); + return super.writeIntLE(value); + } + + @Override + public ByteBuf writeLongLE(long value) { + recordLeakNonRefCountingOperation(leak); + return super.writeLongLE(value); + } + + @Override + public CompositeByteBuf 
addComponent(ByteBuf buffer) { + recordLeakNonRefCountingOperation(leak); + return super.addComponent(buffer); + } + + @Override + public CompositeByteBuf addComponents(ByteBuf... buffers) { + recordLeakNonRefCountingOperation(leak); + return super.addComponents(buffers); + } + + @Override + public CompositeByteBuf addComponents(Iterable buffers) { + recordLeakNonRefCountingOperation(leak); + return super.addComponents(buffers); + } + + @Override + public CompositeByteBuf addComponent(int cIndex, ByteBuf buffer) { + recordLeakNonRefCountingOperation(leak); + return super.addComponent(cIndex, buffer); + } + + @Override + public CompositeByteBuf addComponents(int cIndex, ByteBuf... buffers) { + recordLeakNonRefCountingOperation(leak); + return super.addComponents(cIndex, buffers); + } + + @Override + public CompositeByteBuf addComponents(int cIndex, Iterable buffers) { + recordLeakNonRefCountingOperation(leak); + return super.addComponents(cIndex, buffers); + } + + @Override + public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer) { + recordLeakNonRefCountingOperation(leak); + return super.addComponent(increaseWriterIndex, buffer); + } + + @Override + public CompositeByteBuf addComponents(boolean increaseWriterIndex, ByteBuf... 
buffers) { + recordLeakNonRefCountingOperation(leak); + return super.addComponents(increaseWriterIndex, buffers); + } + + @Override + public CompositeByteBuf addComponents(boolean increaseWriterIndex, Iterable buffers) { + recordLeakNonRefCountingOperation(leak); + return super.addComponents(increaseWriterIndex, buffers); + } + + @Override + public CompositeByteBuf addComponent(boolean increaseWriterIndex, int cIndex, ByteBuf buffer) { + recordLeakNonRefCountingOperation(leak); + return super.addComponent(increaseWriterIndex, cIndex, buffer); + } + + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, ByteBuf buffer) { + recordLeakNonRefCountingOperation(leak); + return super.addFlattenedComponents(increaseWriterIndex, buffer); + } + + @Override + public CompositeByteBuf removeComponent(int cIndex) { + recordLeakNonRefCountingOperation(leak); + return super.removeComponent(cIndex); + } + + @Override + public CompositeByteBuf removeComponents(int cIndex, int numComponents) { + recordLeakNonRefCountingOperation(leak); + return super.removeComponents(cIndex, numComponents); + } + + @Override + public Iterator iterator() { + recordLeakNonRefCountingOperation(leak); + return super.iterator(); + } + + @Override + public List decompose(int offset, int length) { + recordLeakNonRefCountingOperation(leak); + return super.decompose(offset, length); + } + + @Override + public CompositeByteBuf consolidate() { + recordLeakNonRefCountingOperation(leak); + return super.consolidate(); + } + + @Override + public CompositeByteBuf discardReadComponents() { + recordLeakNonRefCountingOperation(leak); + return super.discardReadComponents(); + } + + @Override + public CompositeByteBuf consolidate(int cIndex, int numComponents) { + recordLeakNonRefCountingOperation(leak); + return super.consolidate(cIndex, numComponents); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + 
recordLeakNonRefCountingOperation(leak); + return super.getBytes(index, out, position, length); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.setBytes(index, in, position, length); + } + + @Override + public int readBytes(FileChannel out, long position, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.readBytes(out, position, length); + } + + @Override + public int writeBytes(FileChannel in, long position, int length) throws IOException { + recordLeakNonRefCountingOperation(leak); + return super.writeBytes(in, position, length); + } + + @Override + public CompositeByteBuf retain() { + leak.record(); + return super.retain(); + } + + @Override + public CompositeByteBuf retain(int increment) { + leak.record(); + return super.retain(increment); + } + + @Override + public boolean release() { + leak.record(); + return super.release(); + } + + @Override + public boolean release(int decrement) { + leak.record(); + return super.release(decrement); + } + + @Override + public CompositeByteBuf touch() { + leak.record(); + return this; + } + + @Override + public CompositeByteBuf touch(Object hint) { + leak.record(hint); + return this; + } + + @Override + protected AdvancedLeakAwareByteBuf newLeakAwareByteBuf( + ByteBuf wrapped, ByteBuf trackedByteBuf, ResourceLeakTracker leakTracker) { + return new AdvancedLeakAwareByteBuf(wrapped, trackedByteBuf, leakTracker); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBuf.java new file mode 100644 index 0000000..b0fc1f2 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBuf.java @@ -0,0 +1,2492 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except 
in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.ReferenceCounted; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; +import java.nio.charset.UnsupportedCharsetException; + +/** + * A random and sequential accessible sequence of zero or more bytes (octets). + * This interface provides an abstract view for one or more primitive byte + * arrays ({@code byte[]}) and {@linkplain ByteBuffer NIO buffers}. + * + *
+ * <h3>Creation of a buffer</h3>
+ * + * It is recommended to create a new buffer using the helper methods in + * {@link Unpooled} rather than calling an individual implementation's + * constructor. + * + *
+ * <h3>Random Access Indexing</h3>
+ * + * Just like an ordinary primitive byte array, {@link ByteBuf} uses + * zero-based indexing. + * It means the index of the first byte is always {@code 0} and the index of the last byte is + * always {@link #capacity() capacity - 1}. For example, to iterate all bytes of a buffer, you + * can do the following, regardless of its internal implementation: + * + *
+ * {@link ByteBuf} buffer = ...;
+ * for (int i = 0; i < buffer.capacity(); i ++) {
+ *     byte b = buffer.getByte(i);
+ *     System.out.println((char) b);
+ * }
+ * </pre>
+ * + *
+ * <h3>Sequential Access Indexing</h3>
+ * + * {@link ByteBuf} provides two pointer variables to support sequential + * read and write operations - {@link #readerIndex() readerIndex} for a read + * operation and {@link #writerIndex() writerIndex} for a write operation + * respectively. The following diagram shows how a buffer is segmented into + * three areas by the two pointers: + * + *
+ *      +-------------------+------------------+------------------+
+ *      | discardable bytes |  readable bytes  |  writable bytes  |
+ *      |                   |     (CONTENT)    |                  |
+ *      +-------------------+------------------+------------------+
+ *      |                   |                  |                  |
+ *      0      <=      readerIndex   <=   writerIndex    <=    capacity
+ * </pre>
+ * + *
+ * <h4>Readable bytes (the actual content)</h4>
+ * + * This segment is where the actual data is stored. Any operation whose name + * starts with {@code read} or {@code skip} will get or skip the data at the + * current {@link #readerIndex() readerIndex} and increase it by the number of + * read bytes. If the argument of the read operation is also a + * {@link ByteBuf} and no destination index is specified, the specified + * buffer's {@link #writerIndex() writerIndex} is increased together. + *

+ * If there's not enough content left, {@link IndexOutOfBoundsException} is + * raised. The default value of newly allocated, wrapped or copied buffer's + * {@link #readerIndex() readerIndex} is {@code 0}. + * + *
+ * <pre>
+ * // Iterates the readable bytes of a buffer.
+ * {@link ByteBuf} buffer = ...;
+ * while (buffer.isReadable()) {
+ *     System.out.println(buffer.readByte());
+ * }
+ * </pre>
+ * + *
+ * <h4>Writable bytes</h4>
+ * + * This segment is a undefined space which needs to be filled. Any operation + * whose name starts with {@code write} will write the data at the current + * {@link #writerIndex() writerIndex} and increase it by the number of written + * bytes. If the argument of the write operation is also a {@link ByteBuf}, + * and no source index is specified, the specified buffer's + * {@link #readerIndex() readerIndex} is increased together. + *

+ * If there's not enough writable bytes left, {@link IndexOutOfBoundsException} + * is raised. The default value of newly allocated buffer's + * {@link #writerIndex() writerIndex} is {@code 0}. The default value of + * wrapped or copied buffer's {@link #writerIndex() writerIndex} is the + * {@link #capacity() capacity} of the buffer. + * + *
+ * <pre>
+ * // Fills the writable bytes of a buffer with random integers.
+ * {@link ByteBuf} buffer = ...;
+ * while (buffer.maxWritableBytes() >= 4) {
+ *     buffer.writeInt(random.nextInt());
+ * }
+ * </pre>
+ * + *
+ * <h4>Discardable bytes</h4>
+ * + * This segment contains the bytes which were read already by a read operation. + * Initially, the size of this segment is {@code 0}, but its size increases up + * to the {@link #writerIndex() writerIndex} as read operations are executed. + * The read bytes can be discarded by calling {@link #discardReadBytes()} to + * reclaim unused area as depicted by the following diagram: + * + *
+ *  BEFORE discardReadBytes()
+ *
+ *      +-------------------+------------------+------------------+
+ *      | discardable bytes |  readable bytes  |  writable bytes  |
+ *      +-------------------+------------------+------------------+
+ *      |                   |                  |                  |
+ *      0      <=      readerIndex   <=   writerIndex    <=    capacity
+ *
+ *
+ *  AFTER discardReadBytes()
+ *
+ *      +------------------+--------------------------------------+
+ *      |  readable bytes  |    writable bytes (got more space)   |
+ *      +------------------+--------------------------------------+
+ *      |                  |                                      |
+ * readerIndex (0) <= writerIndex (decreased)        <=        capacity
+ * </pre>
+ * + * Please note that there is no guarantee about the content of writable bytes + * after calling {@link #discardReadBytes()}. The writable bytes will not be + * moved in most cases and could even be filled with completely different data + * depending on the underlying buffer implementation. + * + *
+ * <h4>Clearing the buffer indexes</h4>
+ * + * You can set both {@link #readerIndex() readerIndex} and + * {@link #writerIndex() writerIndex} to {@code 0} by calling {@link #clear()}. + * It does not clear the buffer content (e.g. filling with {@code 0}) but just + * clears the two pointers. Please also note that the semantic of this + * operation is different from {@link ByteBuffer#clear()}. + * + *
+ *  BEFORE clear()
+ *
+ *      +-------------------+------------------+------------------+
+ *      | discardable bytes |  readable bytes  |  writable bytes  |
+ *      +-------------------+------------------+------------------+
+ *      |                   |                  |                  |
+ *      0      <=      readerIndex   <=   writerIndex    <=    capacity
+ *
+ *
+ *  AFTER clear()
+ *
+ *      +---------------------------------------------------------+
+ *      |             writable bytes (got more space)             |
+ *      +---------------------------------------------------------+
+ *      |                                                         |
+ *      0 = readerIndex = writerIndex            <=            capacity
+ * </pre>
+ * + *
+ * <h3>Search operations</h3>
+ * + * For simple single-byte searches, use {@link #indexOf(int, int, byte)} and {@link #bytesBefore(int, int, byte)}. + * {@link #bytesBefore(byte)} is especially useful when you deal with a {@code NUL}-terminated string. + * For complicated searches, use {@link #forEachByte(int, int, ByteProcessor)} with a {@link ByteProcessor} + * implementation. + * + *
+ * <h3>Mark and reset</h3>
+ * + * There are two marker indexes in every buffer. One is for storing + * {@link #readerIndex() readerIndex} and the other is for storing + * {@link #writerIndex() writerIndex}. You can always reposition one of the + * two indexes by calling a reset method. It works in a similar fashion to + * the mark and reset methods in {@link InputStream} except that there's no + * {@code readlimit}. + * + *
+ * <h3>Derived buffers</h3>
+ * + * You can create a view of an existing buffer by calling one of the following methods: + *
+ * <ul>
+ *   <li>{@link #duplicate()}</li>
+ *   <li>{@link #slice()}</li>
+ *   <li>{@link #slice(int, int)}</li>
+ *   <li>{@link #readSlice(int)}</li>
+ *   <li>{@link #retainedDuplicate()}</li>
+ *   <li>{@link #retainedSlice()}</li>
+ *   <li>{@link #retainedSlice(int, int)}</li>
+ *   <li>{@link #readRetainedSlice(int)}</li>
+ * </ul>
+ * A derived buffer will have an independent {@link #readerIndex() readerIndex}, + * {@link #writerIndex() writerIndex} and marker indexes, while it shares + * other internal data representation, just like a NIO buffer does. + *

+ * In case a completely fresh copy of an existing buffer is required, please + * call {@link #copy()} method instead. + * + *
+ * <h4>Non-retained and retained derived buffers</h4>
+ * + * Note that the {@link #duplicate()}, {@link #slice()}, {@link #slice(int, int)} and {@link #readSlice(int)} does NOT + * call {@link #retain()} on the returned derived buffer, and thus its reference count will NOT be increased. If you + * need to create a derived buffer with increased reference count, consider using {@link #retainedDuplicate()}, + * {@link #retainedSlice()}, {@link #retainedSlice(int, int)} and {@link #readRetainedSlice(int)} which may return + * a buffer implementation that produces less garbage. + * + *
+ * <h3>Conversion to existing JDK types</h3>
+ * + *
+ * <h4>Byte array</h4>
+ * + * If a {@link ByteBuf} is backed by a byte array (i.e. {@code byte[]}), + * you can access it directly via the {@link #array()} method. To determine + * if a buffer is backed by a byte array, {@link #hasArray()} should be used. + * + *
+ * <h4>NIO Buffers</h4>
+ * + * If a {@link ByteBuf} can be converted into an NIO {@link ByteBuffer} which shares its + * content (i.e. view buffer), you can get it via the {@link #nioBuffer()} method. To determine + * if a buffer can be converted into an NIO buffer, use {@link #nioBufferCount()}. + * + *
+ * <h4>Strings</h4>
+ * + * Various {@link #toString(Charset)} methods convert a {@link ByteBuf} + * into a {@link String}. Please note that {@link #toString()} is not a + * conversion method. + * + *
+ * <h4>I/O Streams</h4>
+ * + * Please refer to {@link ByteBufInputStream} and + * {@link ByteBufOutputStream}. + */ +public abstract class ByteBuf implements ReferenceCounted, Comparable, ByteBufConvertible { + + /** + * Returns the number of bytes (octets) this buffer can contain. + */ + public abstract int capacity(); + + /** + * Adjusts the capacity of this buffer. If the {@code newCapacity} is less than the current + * capacity, the content of this buffer is truncated. If the {@code newCapacity} is greater + * than the current capacity, the buffer is appended with unspecified data whose length is + * {@code (newCapacity - currentCapacity)}. + * + * @throws IllegalArgumentException if the {@code newCapacity} is greater than {@link #maxCapacity()} + */ + public abstract ByteBuf capacity(int newCapacity); + + /** + * Returns the maximum allowed capacity of this buffer. This value provides an upper + * bound on {@link #capacity()}. + */ + public abstract int maxCapacity(); + + /** + * Returns the {@link ByteBufAllocator} which created this buffer. + */ + public abstract ByteBufAllocator alloc(); + + /** + * Returns the endianness + * of this buffer. + * + * @deprecated use the Little Endian accessors, e.g. {@code getShortLE}, {@code getIntLE} + * instead of creating a buffer with swapped {@code endianness}. + */ + @Deprecated + public abstract ByteOrder order(); + + /** + * Returns a buffer with the specified {@code endianness} which shares the whole region, + * indexes, and marks of this buffer. Modifying the content, the indexes, or the marks of the + * returned buffer or this buffer affects each other's content, indexes, and marks. If the + * specified {@code endianness} is identical to this buffer's byte order, this method can + * return {@code this}. This method does not modify {@code readerIndex} or {@code writerIndex} + * of this buffer. + * + * @deprecated use the Little Endian accessors, e.g. 
{@code getShortLE}, {@code getIntLE} + * instead of creating a buffer with swapped {@code endianness}. + */ + @Deprecated + public abstract ByteBuf order(ByteOrder endianness); + + /** + * Return the underlying buffer instance if this buffer is a wrapper of another buffer. + * + * @return {@code null} if this buffer is not a wrapper + */ + public abstract ByteBuf unwrap(); + + /** + * Returns {@code true} if and only if this buffer is backed by an + * NIO direct buffer. + */ + public abstract boolean isDirect(); + + /** + * Returns {@code true} if and only if this buffer is read-only. + */ + public abstract boolean isReadOnly(); + + /** + * Returns a read-only version of this buffer. + */ + public abstract ByteBuf asReadOnly(); + + /** + * Returns the {@code readerIndex} of this buffer. + */ + public abstract int readerIndex(); + + /** + * Sets the {@code readerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code readerIndex} is + * less than {@code 0} or + * greater than {@code this.writerIndex} + */ + public abstract ByteBuf readerIndex(int readerIndex); + + /** + * Returns the {@code writerIndex} of this buffer. + */ + public abstract int writerIndex(); + + /** + * Sets the {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code writerIndex} is + * less than {@code this.readerIndex} or + * greater than {@code this.capacity} + */ + public abstract ByteBuf writerIndex(int writerIndex); + + /** + * Sets the {@code readerIndex} and {@code writerIndex} of this buffer + * in one shot. This method is useful when you have to worry about the + * invocation order of {@link #readerIndex(int)} and {@link #writerIndex(int)} + * methods. For example, the following code will fail: + * + *
+     * // Create a buffer whose readerIndex, writerIndex and capacity are
+     * // 0, 0 and 8 respectively.
+     * {@link ByteBuf} buf = {@link Unpooled}.buffer(8);
+     *
+     * // IndexOutOfBoundsException is thrown because the specified
+     * // readerIndex (2) cannot be greater than the current writerIndex (0).
+     * buf.readerIndex(2);
+     * buf.writerIndex(4);
+     * </pre>
+ * + * The following code will also fail: + * + *
+     * // Create a buffer whose readerIndex, writerIndex and capacity are
+     * // 0, 8 and 8 respectively.
+     * {@link ByteBuf} buf = {@link Unpooled}.wrappedBuffer(new byte[8]);
+     *
+     * // readerIndex becomes 8.
+     * buf.readLong();
+     *
+     * // IndexOutOfBoundsException is thrown because the specified
+     * // writerIndex (4) cannot be less than the current readerIndex (8).
+     * buf.writerIndex(4);
+     * buf.readerIndex(2);
+     * 
+ * + * By contrast, this method guarantees that it never + * throws an {@link IndexOutOfBoundsException} as long as the specified + * indexes meet basic constraints, regardless what the current index + * values of the buffer are: + * + *
+     * // No matter what the current state of the buffer is, the following
+     * // call always succeeds as long as the capacity of the buffer is not
+     * // less than 4.
+     * buf.setIndex(2, 4);
+     * 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code readerIndex} is less than 0, + * if the specified {@code writerIndex} is less than the specified + * {@code readerIndex} or if the specified {@code writerIndex} is + * greater than {@code this.capacity} + */ + public abstract ByteBuf setIndex(int readerIndex, int writerIndex); + + /** + * Returns the number of readable bytes which is equal to + * {@code (this.writerIndex - this.readerIndex)}. + */ + public abstract int readableBytes(); + + /** + * Returns the number of writable bytes which is equal to + * {@code (this.capacity - this.writerIndex)}. + */ + public abstract int writableBytes(); + + /** + * Returns the maximum possible number of writable bytes, which is equal to + * {@code (this.maxCapacity - this.writerIndex)}. + */ + public abstract int maxWritableBytes(); + + /** + * Returns the maximum number of bytes which can be written for certain without involving + * an internal reallocation or data-copy. The returned value will be ≥ {@link #writableBytes()} + * and ≤ {@link #maxWritableBytes()}. + */ + public int maxFastWritableBytes() { + return writableBytes(); + } + + /** + * Returns {@code true} + * if and only if {@code (this.writerIndex - this.readerIndex)} is greater + * than {@code 0}. + */ + public abstract boolean isReadable(); + + /** + * Returns {@code true} if and only if this buffer contains equal to or more than the specified number of elements. + */ + public abstract boolean isReadable(int size); + + /** + * Returns {@code true} + * if and only if {@code (this.capacity - this.writerIndex)} is greater + * than {@code 0}. + */ + public abstract boolean isWritable(); + + /** + * Returns {@code true} if and only if this buffer has enough room to allow writing the specified number of + * elements. + */ + public abstract boolean isWritable(int size); + + /** + * Sets the {@code readerIndex} and {@code writerIndex} of this buffer to + * {@code 0}. 
+ * This method is identical to {@link #setIndex(int, int) setIndex(0, 0)}. + *

+ * Please note that the behavior of this method is different + * from that of NIO buffer, which sets the {@code limit} to + * the {@code capacity} of the buffer. + */ + public abstract ByteBuf clear(); + + /** + * Marks the current {@code readerIndex} in this buffer. You can + * reposition the current {@code readerIndex} to the marked + * {@code readerIndex} by calling {@link #resetReaderIndex()}. + * The initial value of the marked {@code readerIndex} is {@code 0}. + */ + public abstract ByteBuf markReaderIndex(); + + /** + * Repositions the current {@code readerIndex} to the marked + * {@code readerIndex} in this buffer. + * + * @throws IndexOutOfBoundsException + * if the current {@code writerIndex} is less than the marked + * {@code readerIndex} + */ + public abstract ByteBuf resetReaderIndex(); + + /** + * Marks the current {@code writerIndex} in this buffer. You can + * reposition the current {@code writerIndex} to the marked + * {@code writerIndex} by calling {@link #resetWriterIndex()}. + * The initial value of the marked {@code writerIndex} is {@code 0}. + */ + public abstract ByteBuf markWriterIndex(); + + /** + * Repositions the current {@code writerIndex} to the marked + * {@code writerIndex} in this buffer. + * + * @throws IndexOutOfBoundsException + * if the current {@code readerIndex} is greater than the marked + * {@code writerIndex} + */ + public abstract ByteBuf resetWriterIndex(); + + /** + * Discards the bytes between the 0th index and {@code readerIndex}. + * It moves the bytes between {@code readerIndex} and {@code writerIndex} + * to the 0th index, and sets {@code readerIndex} and {@code writerIndex} + * to {@code 0} and {@code oldWriterIndex - oldReaderIndex} respectively. + *

+ * Please refer to the class documentation for more detailed explanation. + */ + public abstract ByteBuf discardReadBytes(); + + /** + * Similar to {@link ByteBuf#discardReadBytes()} except that this method might discard + * some, all, or none of read bytes depending on its internal implementation to reduce + * overall memory bandwidth consumption at the cost of potentially additional memory + * consumption. + */ + public abstract ByteBuf discardSomeReadBytes(); + + /** + * Expands the buffer {@link #capacity()} to make sure the number of + * {@linkplain #writableBytes() writable bytes} is equal to or greater than the + * specified value. If there are enough writable bytes in this buffer, this method + * returns with no side effect. + * + * @param minWritableBytes + * the expected minimum number of writable bytes + * @throws IndexOutOfBoundsException + * if {@link #writerIndex()} + {@code minWritableBytes} > {@link #maxCapacity()}. + * @see #capacity(int) + */ + public abstract ByteBuf ensureWritable(int minWritableBytes); + + /** + * Expands the buffer {@link #capacity()} to make sure the number of + * {@linkplain #writableBytes() writable bytes} is equal to or greater than the + * specified value. Unlike {@link #ensureWritable(int)}, this method returns a status code. + * + * @param minWritableBytes + * the expected minimum number of writable bytes + * @param force + * When {@link #writerIndex()} + {@code minWritableBytes} > {@link #maxCapacity()}: + *

    + *        <ul>
    + *        <li>{@code true}  - the capacity of the buffer is expanded to {@link #maxCapacity()}</li>
    + *        <li>{@code false} - the capacity of the buffer is unchanged</li>
    + *        </ul>
+ * @return {@code 0} if the buffer has enough writable bytes, and its capacity is unchanged. + * {@code 1} if the buffer does not have enough bytes, and its capacity is unchanged. + * {@code 2} if the buffer has enough writable bytes, and its capacity has been increased. + * {@code 3} if the buffer does not have enough bytes, but its capacity has been + * increased to its maximum. + */ + public abstract int ensureWritable(int minWritableBytes, boolean force); + + /** + * Gets a boolean at the specified absolute {@code index} in this buffer. + * This method does not modify the {@code readerIndex} or {@code writerIndex} + * of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 1} is greater than {@code this.capacity} + */ + public abstract boolean getBoolean(int index); + + /** + * Gets a byte at the specified absolute {@code index} in this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 1} is greater than {@code this.capacity} + */ + public abstract byte getByte(int index); + + /** + * Gets an unsigned byte at the specified absolute {@code index} in this + * buffer. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 1} is greater than {@code this.capacity} + */ + public abstract short getUnsignedByte(int index); + + /** + * Gets a 16-bit short integer at the specified absolute {@code index} in + * this buffer. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract short getShort(int index); + + /** + * Gets a 16-bit short integer at the specified absolute {@code index} in + * this buffer in Little Endian Byte Order. This method does not modify + * {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract short getShortLE(int index); + + /** + * Gets an unsigned 16-bit short integer at the specified absolute + * {@code index} in this buffer. This method does not modify + * {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract int getUnsignedShort(int index); + + /** + * Gets an unsigned 16-bit short integer at the specified absolute + * {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract int getUnsignedShortLE(int index); + + /** + * Gets a 24-bit medium integer at the specified absolute {@code index} in + * this buffer. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 3} is greater than {@code this.capacity} + */ + public abstract int getMedium(int index); + + /** + * Gets a 24-bit medium integer at the specified absolute {@code index} in + * this buffer in the Little Endian Byte Order. This method does not + * modify {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 3} is greater than {@code this.capacity} + */ + public abstract int getMediumLE(int index); + + /** + * Gets an unsigned 24-bit medium integer at the specified absolute + * {@code index} in this buffer. This method does not modify + * {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 3} is greater than {@code this.capacity} + */ + public abstract int getUnsignedMedium(int index); + + /** + * Gets an unsigned 24-bit medium integer at the specified absolute + * {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 3} is greater than {@code this.capacity} + */ + public abstract int getUnsignedMediumLE(int index); + + /** + * Gets a 32-bit integer at the specified absolute {@code index} in + * this buffer. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract int getInt(int index); + + /** + * Gets a 32-bit integer at the specified absolute {@code index} in + * this buffer with Little Endian Byte Order. This method does not + * modify {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract int getIntLE(int index); + + /** + * Gets an unsigned 32-bit integer at the specified absolute {@code index} + * in this buffer. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract long getUnsignedInt(int index); + + /** + * Gets an unsigned 32-bit integer at the specified absolute {@code index} + * in this buffer in Little Endian Byte Order. This method does not + * modify {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract long getUnsignedIntLE(int index); + + /** + * Gets a 64-bit long integer at the specified absolute {@code index} in + * this buffer. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public abstract long getLong(int index); + + /** + * Gets a 64-bit long integer at the specified absolute {@code index} in + * this buffer in Little Endian Byte Order. This method does not + * modify {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public abstract long getLongLE(int index); + + /** + * Gets a 2-byte UTF-16 character at the specified absolute + * {@code index} in this buffer. This method does not modify + * {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract char getChar(int index); + + /** + * Gets a 32-bit floating point number at the specified absolute + * {@code index} in this buffer. This method does not modify + * {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract float getFloat(int index); + + /** + * Gets a 32-bit floating point number at the specified absolute + * {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public float getFloatLE(int index) { + return Float.intBitsToFloat(getIntLE(index)); + } + + /** + * Gets a 64-bit floating point number at the specified absolute + * {@code index} in this buffer. This method does not modify + * {@code readerIndex} or {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public abstract double getDouble(int index); + + /** + * Gets a 64-bit floating point number at the specified absolute + * {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public double getDoubleLE(int index) { + return Double.longBitsToDouble(getLongLE(index)); + } + + /** + * Transfers this buffer's data to the specified destination starting at + * the specified absolute {@code index} until the destination becomes + * non-writable. This method is basically same with + * {@link #getBytes(int, ByteBuf, int, int)}, except that this + * method increases the {@code writerIndex} of the destination by the + * number of the transferred bytes while + * {@link #getBytes(int, ByteBuf, int, int)} does not. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * the source buffer (i.e. {@code this}). 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + dst.writableBytes} is greater than + * {@code this.capacity} + */ + public abstract ByteBuf getBytes(int index, ByteBuf dst); + + /** + * Transfers this buffer's data to the specified destination starting at + * the specified absolute {@code index}. This method is basically same + * with {@link #getBytes(int, ByteBuf, int, int)}, except that this + * method increases the {@code writerIndex} of the destination by the + * number of the transferred bytes while + * {@link #getBytes(int, ByteBuf, int, int)} does not. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * the source buffer (i.e. {@code this}). + * + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0}, + * if {@code index + length} is greater than + * {@code this.capacity}, or + * if {@code length} is greater than {@code dst.writableBytes} + */ + public abstract ByteBuf getBytes(int index, ByteBuf dst, int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} + * of both the source (i.e. {@code this}) and the destination. 
+ * + * @param dstIndex the first index of the destination + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0}, + * if the specified {@code dstIndex} is less than {@code 0}, + * if {@code index + length} is greater than + * {@code this.capacity}, or + * if {@code dstIndex + length} is greater than + * {@code dst.capacity} + */ + public abstract ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + dst.length} is greater than + * {@code this.capacity} + */ + public abstract ByteBuf getBytes(int index, byte[] dst); + + /** + * Transfers this buffer's data to the specified destination starting at + * the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} + * of this buffer. + * + * @param dstIndex the first index of the destination + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0}, + * if the specified {@code dstIndex} is less than {@code 0}, + * if {@code index + length} is greater than + * {@code this.capacity}, or + * if {@code dstIndex + length} is greater than + * {@code dst.length} + */ + public abstract ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the specified absolute {@code index} until the destination's position + * reaches its limit. 
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer while the destination's {@code position} will be increased. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + dst.remaining()} is greater than + * {@code this.capacity} + */ + public abstract ByteBuf getBytes(int index, ByteBuffer dst); + + /** + * Transfers this buffer's data to the specified stream starting at the + * specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than + * {@code this.capacity} + * @throws IOException + * if the specified stream threw an exception during I/O + */ + public abstract ByteBuf getBytes(int index, OutputStream out, int length) throws IOException; + + /** + * Transfers this buffer's data to the specified channel starting at the + * specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes written out to the specified channel + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than + * {@code this.capacity} + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int getBytes(int index, GatheringByteChannel out, int length) throws IOException; + + /** + * Transfers this buffer's data starting at the specified absolute {@code index} + * to the specified channel starting at the given file position. 
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. This method does not modify the channel's position. + * + * @param position the file position at which the transfer is to begin + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes written out to the specified channel + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than + * {@code this.capacity} + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int getBytes(int index, FileChannel out, long position, int length) throws IOException; + + /** + * Gets a {@link CharSequence} with the given length at the given index. + * + * @param length the length to read + * @param charset that should be used + * @return the sequence + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract CharSequence getCharSequence(int index, int length, Charset charset); + + /** + * Sets the specified boolean at the specified absolute {@code index} in this + * buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 1} is greater than {@code this.capacity} + */ + public abstract ByteBuf setBoolean(int index, boolean value); + + /** + * Sets the specified byte at the specified absolute {@code index} in this + * buffer. The 24 high-order bits of the specified value are ignored. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 1} is greater than {@code this.capacity} + */ + public abstract ByteBuf setByte(int index, int value); + + /** + * Sets the specified 16-bit short integer at the specified absolute + * {@code index} in this buffer. The 16 high-order bits of the specified + * value are ignored. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract ByteBuf setShort(int index, int value); + + /** + * Sets the specified 16-bit short integer at the specified absolute + * {@code index} in this buffer with the Little Endian Byte Order. + * The 16 high-order bits of the specified value are ignored. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract ByteBuf setShortLE(int index, int value); + + /** + * Sets the specified 24-bit medium integer at the specified absolute + * {@code index} in this buffer. Please note that the most significant + * byte is ignored in the specified value. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 3} is greater than {@code this.capacity} + */ + public abstract ByteBuf setMedium(int index, int value); + + /** + * Sets the specified 24-bit medium integer at the specified absolute + * {@code index} in this buffer in the Little Endian Byte Order. 
+ * Please note that the most significant byte is ignored in the + * specified value. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 3} is greater than {@code this.capacity} + */ + public abstract ByteBuf setMediumLE(int index, int value); + + /** + * Sets the specified 32-bit integer at the specified absolute + * {@code index} in this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract ByteBuf setInt(int index, int value); + + /** + * Sets the specified 32-bit integer at the specified absolute + * {@code index} in this buffer with Little Endian byte order + * . + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract ByteBuf setIntLE(int index, int value); + + /** + * Sets the specified 64-bit long integer at the specified absolute + * {@code index} in this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public abstract ByteBuf setLong(int index, long value); + + /** + * Sets the specified 64-bit long integer at the specified absolute + * {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public abstract ByteBuf setLongLE(int index, long value); + + /** + * Sets the specified 2-byte UTF-16 character at the specified absolute + * {@code index} in this buffer. + * The 16 high-order bits of the specified value are ignored. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 2} is greater than {@code this.capacity} + */ + public abstract ByteBuf setChar(int index, int value); + + /** + * Sets the specified 32-bit floating-point number at the specified + * absolute {@code index} in this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public abstract ByteBuf setFloat(int index, float value); + + /** + * Sets the specified 32-bit floating-point number at the specified + * absolute {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 4} is greater than {@code this.capacity} + */ + public ByteBuf setFloatLE(int index, float value) { + return setIntLE(index, Float.floatToRawIntBits(value)); + } + + /** + * Sets the specified 64-bit floating-point number at the specified + * absolute {@code index} in this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public abstract ByteBuf setDouble(int index, double value); + + /** + * Sets the specified 64-bit floating-point number at the specified + * absolute {@code index} in this buffer in Little Endian Byte Order. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * {@code index + 8} is greater than {@code this.capacity} + */ + public ByteBuf setDoubleLE(int index, double value) { + return setLongLE(index, Double.doubleToRawLongBits(value)); + } + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the specified absolute {@code index} until the source buffer becomes + * unreadable. This method is basically same with + * {@link #setBytes(int, ByteBuf, int, int)}, except that this + * method increases the {@code readerIndex} of the source buffer by + * the number of the transferred bytes while + * {@link #setBytes(int, ByteBuf, int, int)} does not. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer (i.e. {@code this}). + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + src.readableBytes} is greater than + * {@code this.capacity} + */ + public abstract ByteBuf setBytes(int index, ByteBuf src); + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the specified absolute {@code index}. This method is basically same + * with {@link #setBytes(int, ByteBuf, int, int)}, except that this + * method increases the {@code readerIndex} of the source buffer by + * the number of the transferred bytes while + * {@link #setBytes(int, ByteBuf, int, int)} does not. 
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer (i.e. {@code this}). + * + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0}, + * if {@code index + length} is greater than + * {@code this.capacity}, or + * if {@code length} is greater than {@code src.readableBytes} + */ + public abstract ByteBuf setBytes(int index, ByteBuf src, int length); + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} + * of both the source (i.e. {@code this}) and the destination. + * + * @param srcIndex the first index of the source + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0}, + * if the specified {@code srcIndex} is less than {@code 0}, + * if {@code index + length} is greater than + * {@code this.capacity}, or + * if {@code srcIndex + length} is greater than + * {@code src.capacity} + */ + public abstract ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length); + + /** + * Transfers the specified source array's data to this buffer starting at + * the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + src.length} is greater than + * {@code this.capacity} + */ + public abstract ByteBuf setBytes(int index, byte[] src); + + /** + * Transfers the specified source array's data to this buffer starting at + * the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0}, + * if the specified {@code srcIndex} is less than {@code 0}, + * if {@code index + length} is greater than + * {@code this.capacity}, or + * if {@code srcIndex + length} is greater than {@code src.length} + */ + public abstract ByteBuf setBytes(int index, byte[] src, int srcIndex, int length); + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the specified absolute {@code index} until the source buffer's position + * reaches its limit. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + src.remaining()} is greater than + * {@code this.capacity} + */ + public abstract ByteBuf setBytes(int index, ByteBuffer src); + + /** + * Transfers the content of the specified source stream to this buffer + * starting at the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @param length the number of bytes to transfer + * + * @return the actual number of bytes read in from the specified channel. + * {@code -1} if the specified {@link InputStream} reached EOF. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than {@code this.capacity} + * @throws IOException + * if the specified stream threw an exception during I/O + */ + public abstract int setBytes(int index, InputStream in, int length) throws IOException; + + /** + * Transfers the content of the specified source channel to this buffer + * starting at the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. 
+ * + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes read in from the specified channel. + * {@code -1} if the specified channel is closed or it reached EOF. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than {@code this.capacity} + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int setBytes(int index, ScatteringByteChannel in, int length) throws IOException; + + /** + * Transfers the content of the specified source channel starting at the given file position + * to this buffer starting at the specified absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. This method does not modify the channel's position. + * + * @param position the file position at which the transfer is to begin + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes read in from the specified channel. + * {@code -1} if the specified channel is closed or it reached EOF. + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than {@code this.capacity} + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int setBytes(int index, FileChannel in, long position, int length) throws IOException; + + /** + * Fills this buffer with NUL (0x00) starting at the specified + * absolute {@code index}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. 
+ * + * @param length the number of NULs to write to the buffer + * + * @throws IndexOutOfBoundsException + * if the specified {@code index} is less than {@code 0} or + * if {@code index + length} is greater than {@code this.capacity} + */ + public abstract ByteBuf setZero(int index, int length); + + /** + * Writes the specified {@link CharSequence} at the given {@code index}. + * The {@code writerIndex} is not modified by this method. + * + * @param index on which the sequence should be written + * @param sequence to write + * @param charset that should be used. + * @return the written number of bytes. + * @throws IndexOutOfBoundsException + * if the sequence at the given index would be out of bounds of the buffer capacity + */ + public abstract int setCharSequence(int index, CharSequence sequence, Charset charset); + + /** + * Gets a boolean at the current {@code readerIndex} and increases + * the {@code readerIndex} by {@code 1} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 1} + */ + public abstract boolean readBoolean(); + + /** + * Gets a byte at the current {@code readerIndex} and increases + * the {@code readerIndex} by {@code 1} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 1} + */ + public abstract byte readByte(); + + /** + * Gets an unsigned byte at the current {@code readerIndex} and increases + * the {@code readerIndex} by {@code 1} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 1} + */ + public abstract short readUnsignedByte(); + + /** + * Gets a 16-bit short integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 2} in this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 2} + */ + public abstract short readShort(); + + /** + * Gets a 16-bit short integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 2} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 2} + */ + public abstract short readShortLE(); + + /** + * Gets an unsigned 16-bit short integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 2} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 2} + */ + public abstract int readUnsignedShort(); + + /** + * Gets an unsigned 16-bit short integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 2} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 2} + */ + public abstract int readUnsignedShortLE(); + + /** + * Gets a 24-bit medium integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 3} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 3} + */ + public abstract int readMedium(); + + /** + * Gets a 24-bit medium integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the + * {@code readerIndex} by {@code 3} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 3} + */ + public abstract int readMediumLE(); + + /** + * Gets an unsigned 24-bit medium integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 3} in this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 3} + */ + public abstract int readUnsignedMedium(); + + /** + * Gets an unsigned 24-bit medium integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 3} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 3} + */ + public abstract int readUnsignedMediumLE(); + + /** + * Gets a 32-bit integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 4} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 4} + */ + public abstract int readInt(); + + /** + * Gets a 32-bit integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 4} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 4} + */ + public abstract int readIntLE(); + + /** + * Gets an unsigned 32-bit integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 4} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 4} + */ + public abstract long readUnsignedInt(); + + /** + * Gets an unsigned 32-bit integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 4} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 4} + */ + public abstract long readUnsignedIntLE(); + + /** + * Gets a 64-bit integer at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 8} in this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 8} + */ + public abstract long readLong(); + + /** + * Gets a 64-bit integer at the current {@code readerIndex} + * in the Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 8} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 8} + */ + public abstract long readLongLE(); + + /** + * Gets a 2-byte UTF-16 character at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 2} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 2} + */ + public abstract char readChar(); + + /** + * Gets a 32-bit floating point number at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 4} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 4} + */ + public abstract float readFloat(); + + /** + * Gets a 32-bit floating point number at the current {@code readerIndex} + * in Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 4} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 4} + */ + public float readFloatLE() { + return Float.intBitsToFloat(readIntLE()); + } + + /** + * Gets a 64-bit floating point number at the current {@code readerIndex} + * and increases the {@code readerIndex} by {@code 8} in this buffer. + * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 8} + */ + public abstract double readDouble(); + + /** + * Gets a 64-bit floating point number at the current {@code readerIndex} + * in Little Endian Byte Order and increases the {@code readerIndex} + * by {@code 8} in this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if {@code this.readableBytes} is less than {@code 8} + */ + public double readDoubleLE() { + return Double.longBitsToDouble(readLongLE()); + } + + /** + * Transfers this buffer's data to a newly created buffer starting at + * the current {@code readerIndex} and increases the {@code readerIndex} + * by the number of the transferred bytes (= {@code length}). + * The returned buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0} and {@code length} respectively. + * + * @param length the number of bytes to transfer + * + * @return the newly created buffer which contains the transferred bytes + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract ByteBuf readBytes(int length); + + /** + * Returns a new slice of this buffer's sub-region starting at the current + * {@code readerIndex} and increases the {@code readerIndex} by the size + * of the new slice (= {@code length}). + *

+ * Also be aware that this method will NOT call {@link #retain()} and so the + * reference count will NOT be increased. + * + * @param length the size of the new slice + * + * @return the newly created slice + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract ByteBuf readSlice(int length); + + /** + * Returns a new retained slice of this buffer's sub-region starting at the current + * {@code readerIndex} and increases the {@code readerIndex} by the size + * of the new slice (= {@code length}). + *

+ * Note that this method returns a {@linkplain #retain() retained} buffer unlike {@link #readSlice(int)}. + * This method behaves similarly to {@code readSlice(...).retain()} except that this method may return + * a buffer implementation that produces less garbage. + * + * @param length the size of the new slice + * + * @return the newly created slice + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract ByteBuf readRetainedSlice(int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the current {@code readerIndex} until the destination becomes + * non-writable, and increases the {@code readerIndex} by the number of the + * transferred bytes. This method is basically same with + * {@link #readBytes(ByteBuf, int, int)}, except that this method + * increases the {@code writerIndex} of the destination by the number of + * the transferred bytes while {@link #readBytes(ByteBuf, int, int)} + * does not. + * + * @throws IndexOutOfBoundsException + * if {@code dst.writableBytes} is greater than + * {@code this.readableBytes} + */ + public abstract ByteBuf readBytes(ByteBuf dst); + + /** + * Transfers this buffer's data to the specified destination starting at + * the current {@code readerIndex} and increases the {@code readerIndex} + * by the number of the transferred bytes (= {@code length}). This method + * is basically same with {@link #readBytes(ByteBuf, int, int)}, + * except that this method increases the {@code writerIndex} of the + * destination by the number of the transferred bytes (= {@code length}) + * while {@link #readBytes(ByteBuf, int, int)} does not. 
+ * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} or + * if {@code length} is greater than {@code dst.writableBytes} + */ + public abstract ByteBuf readBytes(ByteBuf dst, int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the current {@code readerIndex} and increases the {@code readerIndex} + * by the number of the transferred bytes (= {@code length}). + * + * @param dstIndex the first index of the destination + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code dstIndex} is less than {@code 0}, + * if {@code length} is greater than {@code this.readableBytes}, or + * if {@code dstIndex + length} is greater than + * {@code dst.capacity} + */ + public abstract ByteBuf readBytes(ByteBuf dst, int dstIndex, int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the current {@code readerIndex} and increases the {@code readerIndex} + * by the number of the transferred bytes (= {@code dst.length}). + * + * @throws IndexOutOfBoundsException + * if {@code dst.length} is greater than {@code this.readableBytes} + */ + public abstract ByteBuf readBytes(byte[] dst); + + /** + * Transfers this buffer's data to the specified destination starting at + * the current {@code readerIndex} and increases the {@code readerIndex} + * by the number of the transferred bytes (= {@code length}). 
+ * + * @param dstIndex the first index of the destination + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code dstIndex} is less than {@code 0}, + * if {@code length} is greater than {@code this.readableBytes}, or + * if {@code dstIndex + length} is greater than {@code dst.length} + */ + public abstract ByteBuf readBytes(byte[] dst, int dstIndex, int length); + + /** + * Transfers this buffer's data to the specified destination starting at + * the current {@code readerIndex} until the destination's position + * reaches its limit, and increases the {@code readerIndex} by the + * number of the transferred bytes. + * + * @throws IndexOutOfBoundsException + * if {@code dst.remaining()} is greater than + * {@code this.readableBytes} + */ + public abstract ByteBuf readBytes(ByteBuffer dst); + + /** + * Transfers this buffer's data to the specified stream starting at the + * current {@code readerIndex}. + * + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + * @throws IOException + * if the specified stream threw an exception during I/O + */ + public abstract ByteBuf readBytes(OutputStream out, int length) throws IOException; + + /** + * Transfers this buffer's data to the specified stream starting at the + * current {@code readerIndex}. 
+ * + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes written out to the specified channel + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int readBytes(GatheringByteChannel out, int length) throws IOException; + + /** + * Gets a {@link CharSequence} with the given length at the current {@code readerIndex} + * and increases the {@code readerIndex} by the given length. + * + * @param length the length to read + * @param charset that should be used + * @return the sequence + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract CharSequence readCharSequence(int length, Charset charset); + + /** + * Transfers this buffer's data starting at the current {@code readerIndex} + * to the specified channel starting at the given file position. + * This method does not modify the channel's position. + * + * @param position the file position at which the transfer is to begin + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes written out to the specified channel + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int readBytes(FileChannel out, long position, int length) throws IOException; + + /** + * Increases the current {@code readerIndex} by the specified + * {@code length} in this buffer. 
+ * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract ByteBuf skipBytes(int length); + + /** + * Sets the specified boolean at the current {@code writerIndex} + * and increases the {@code writerIndex} by {@code 1} in this buffer. + * If {@code this.writableBytes} is less than {@code 1}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeBoolean(boolean value); + + /** + * Sets the specified byte at the current {@code writerIndex} + * and increases the {@code writerIndex} by {@code 1} in this buffer. + * The 24 high-order bits of the specified value are ignored. + * If {@code this.writableBytes} is less than {@code 1}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeByte(int value); + + /** + * Sets the specified 16-bit short integer at the current + * {@code writerIndex} and increases the {@code writerIndex} by {@code 2} + * in this buffer. The 16 high-order bits of the specified value are ignored. + * If {@code this.writableBytes} is less than {@code 2}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeShort(int value); + + /** + * Sets the specified 16-bit short integer in the Little Endian Byte + * Order at the current {@code writerIndex} and increases the + * {@code writerIndex} by {@code 2} in this buffer. + * The 16 high-order bits of the specified value are ignored. + * If {@code this.writableBytes} is less than {@code 2}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeShortLE(int value); + + /** + * Sets the specified 24-bit medium integer at the current + * {@code writerIndex} and increases the {@code writerIndex} by {@code 3} + * in this buffer. 
+ * If {@code this.writableBytes} is less than {@code 3}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeMedium(int value); + + /** + * Sets the specified 24-bit medium integer at the current + * {@code writerIndex} in the Little Endian Byte Order and + * increases the {@code writerIndex} by {@code 3} in this + * buffer. + * If {@code this.writableBytes} is less than {@code 3}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeMediumLE(int value); + + /** + * Sets the specified 32-bit integer at the current {@code writerIndex} + * and increases the {@code writerIndex} by {@code 4} in this buffer. + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeInt(int value); + + /** + * Sets the specified 32-bit integer at the current {@code writerIndex} + * in the Little Endian Byte Order and increases the {@code writerIndex} + * by {@code 4} in this buffer. + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeIntLE(int value); + + /** + * Sets the specified 64-bit long integer at the current + * {@code writerIndex} and increases the {@code writerIndex} by {@code 8} + * in this buffer. + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeLong(long value); + + /** + * Sets the specified 64-bit long integer at the current + * {@code writerIndex} in the Little Endian Byte Order and + * increases the {@code writerIndex} by {@code 8} + * in this buffer. 
+ * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeLongLE(long value); + + /** + * Sets the specified 2-byte UTF-16 character at the current + * {@code writerIndex} and increases the {@code writerIndex} by {@code 2} + * in this buffer. The 16 high-order bits of the specified value are ignored. + * If {@code this.writableBytes} is less than {@code 2}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeChar(int value); + + /** + * Sets the specified 32-bit floating point number at the current + * {@code writerIndex} and increases the {@code writerIndex} by {@code 4} + * in this buffer. + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeFloat(float value); + + /** + * Sets the specified 32-bit floating point number at the current + * {@code writerIndex} in Little Endian Byte Order and increases + * the {@code writerIndex} by {@code 4} in this buffer. + * If {@code this.writableBytes} is less than {@code 4}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public ByteBuf writeFloatLE(float value) { + return writeIntLE(Float.floatToRawIntBits(value)); + } + + /** + * Sets the specified 64-bit floating point number at the current + * {@code writerIndex} and increases the {@code writerIndex} by {@code 8} + * in this buffer. + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. 
+ */ + public abstract ByteBuf writeDouble(double value); + + /** + * Sets the specified 64-bit floating point number at the current + * {@code writerIndex} in Little Endian Byte Order and increases + * the {@code writerIndex} by {@code 8} in this buffer. + * If {@code this.writableBytes} is less than {@code 8}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public ByteBuf writeDoubleLE(double value) { + return writeLongLE(Double.doubleToRawLongBits(value)); + } + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the current {@code writerIndex} until the source buffer becomes + * unreadable, and increases the {@code writerIndex} by the number of + * the transferred bytes. This method is basically same with + * {@link #writeBytes(ByteBuf, int, int)}, except that this method + * increases the {@code readerIndex} of the source buffer by the number of + * the transferred bytes while {@link #writeBytes(ByteBuf, int, int)} + * does not. + * If {@code this.writableBytes} is less than {@code src.readableBytes}, + * {@link #ensureWritable(int)} will be called in an attempt to expand + * capacity to accommodate. + */ + public abstract ByteBuf writeBytes(ByteBuf src); + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the current {@code writerIndex} and increases the {@code writerIndex} + * by the number of the transferred bytes (= {@code length}). This method + * is basically same with {@link #writeBytes(ByteBuf, int, int)}, + * except that this method increases the {@code readerIndex} of the source + * buffer by the number of the transferred bytes (= {@code length}) while + * {@link #writeBytes(ByteBuf, int, int)} does not. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. 
+ * + * @param length the number of bytes to transfer + * @throws IndexOutOfBoundsException if {@code length} is greater than {@code src.readableBytes} + */ + public abstract ByteBuf writeBytes(ByteBuf src, int length); + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the current {@code writerIndex} and increases the {@code writerIndex} + * by the number of the transferred bytes (= {@code length}). + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + * + * @param srcIndex the first index of the source + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code srcIndex} is less than {@code 0}, or + * if {@code srcIndex + length} is greater than {@code src.capacity} + */ + public abstract ByteBuf writeBytes(ByteBuf src, int srcIndex, int length); + + /** + * Transfers the specified source array's data to this buffer starting at + * the current {@code writerIndex} and increases the {@code writerIndex} + * by the number of the transferred bytes (= {@code src.length}). + * If {@code this.writableBytes} is less than {@code src.length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + */ + public abstract ByteBuf writeBytes(byte[] src); + + /** + * Transfers the specified source array's data to this buffer starting at + * the current {@code writerIndex} and increases the {@code writerIndex} + * by the number of the transferred bytes (= {@code length}). + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. 
+ * + * @param srcIndex the first index of the source + * @param length the number of bytes to transfer + * + * @throws IndexOutOfBoundsException + * if the specified {@code srcIndex} is less than {@code 0}, or + * if {@code srcIndex + length} is greater than {@code src.length} + */ + public abstract ByteBuf writeBytes(byte[] src, int srcIndex, int length); + + /** + * Transfers the specified source buffer's data to this buffer starting at + * the current {@code writerIndex} until the source buffer's position + * reaches its limit, and increases the {@code writerIndex} by the + * number of the transferred bytes. + * If {@code this.writableBytes} is less than {@code src.remaining()}, + * {@link #ensureWritable(int)} will be called in an attempt to expand + * capacity to accommodate. + */ + public abstract ByteBuf writeBytes(ByteBuffer src); + + /** + * Transfers the content of the specified stream to this buffer + * starting at the current {@code writerIndex} and increases the + * {@code writerIndex} by the number of the transferred bytes. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + * + * @param length the number of bytes to transfer + * + * @return the actual number of bytes read in from the specified channel. + * {@code -1} if the specified {@link InputStream} reached EOF. + * + * @throws IOException if the specified stream threw an exception during I/O + */ + public abstract int writeBytes(InputStream in, int length) throws IOException; + + /** + * Transfers the content of the specified channel to this buffer + * starting at the current {@code writerIndex} and increases the + * {@code writerIndex} by the number of the transferred bytes. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. 
+ * + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes read in from the specified channel. + * {@code -1} if the specified channel is closed or it reached EOF. + * + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int writeBytes(ScatteringByteChannel in, int length) throws IOException; + + /** + * Transfers the content of the specified channel starting at the given file position + * to this buffer starting at the current {@code writerIndex} and increases the + * {@code writerIndex} by the number of the transferred bytes. + * This method does not modify the channel's position. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + * + * @param position the file position at which the transfer is to begin + * @param length the maximum number of bytes to transfer + * + * @return the actual number of bytes read in from the specified channel. + * {@code -1} if the specified channel is closed or it reached EOF. + * + * @throws IOException + * if the specified channel threw an exception during I/O + */ + public abstract int writeBytes(FileChannel in, long position, int length) throws IOException; + + /** + * Fills this buffer with NUL (0x00) starting at the current + * {@code writerIndex} and increases the {@code writerIndex} by the + * specified {@code length}. + * If {@code this.writableBytes} is less than {@code length}, {@link #ensureWritable(int)} + * will be called in an attempt to expand capacity to accommodate. + * + * @param length the number of NULs to write to the buffer + */ + public abstract ByteBuf writeZero(int length); + + /** + * Writes the specified {@link CharSequence} at the current {@code writerIndex} and increases + * the {@code writerIndex} by the number of written bytes + * in this buffer. 
+ * If {@code this.writableBytes} is not large enough to write the whole sequence, + * {@link #ensureWritable(int)} will be called in an attempt to expand capacity to accommodate. + * + * @param sequence to write + * @param charset that should be used + * @return the written number of bytes + */ + public abstract int writeCharSequence(CharSequence sequence, Charset charset); + + /** + * Locates the first occurrence of the specified {@code value} in this + * buffer. The search takes place from the specified {@code fromIndex} + * (inclusive) to the specified {@code toIndex} (exclusive). + *
+ * <p>
+ * If {@code fromIndex} is greater than {@code toIndex}, the search is + * performed in a reversed order from {@code fromIndex} (exclusive) + * down to {@code toIndex} (inclusive). + *
+ * <p>
+ * Note that the lower index is always included and higher always excluded. + *
+ * <p>
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @return the absolute index of the first occurrence if found. + * {@code -1} otherwise. + */ + public abstract int indexOf(int fromIndex, int toIndex, byte value); + + /** + * Locates the first occurrence of the specified {@code value} in this + * buffer. The search takes place from the current {@code readerIndex} + * (inclusive) to the current {@code writerIndex} (exclusive). + *
+ * <p>
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @return the number of bytes between the current {@code readerIndex} + * and the first occurrence if found. {@code -1} otherwise. + */ + public abstract int bytesBefore(byte value); + + /** + * Locates the first occurrence of the specified {@code value} in this + * buffer. The search starts from the current {@code readerIndex} + * (inclusive) and lasts for the specified {@code length}. + *
+ * <p>
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @return the number of bytes between the current {@code readerIndex} + * and the first occurrence if found. {@code -1} otherwise. + * + * @throws IndexOutOfBoundsException + * if {@code length} is greater than {@code this.readableBytes} + */ + public abstract int bytesBefore(int length, byte value); + + /** + * Locates the first occurrence of the specified {@code value} in this + * buffer. The search starts from the specified {@code index} (inclusive) + * and lasts for the specified {@code length}. + *
+ * <p>
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @return the number of bytes between the specified {@code index} + * and the first occurrence if found. {@code -1} otherwise. + * + * @throws IndexOutOfBoundsException + * if {@code index + length} is greater than {@code this.capacity} + */ + public abstract int bytesBefore(int index, int length, byte value); + + /** + * Iterates over the readable bytes of this buffer with the specified {@code processor} in ascending order. + * + * @return {@code -1} if the processor iterated to or beyond the end of the readable bytes. + * The last-visited index If the {@link ByteProcessor#process(byte)} returned {@code false}. + */ + public abstract int forEachByte(ByteProcessor processor); + + /** + * Iterates over the specified area of this buffer with the specified {@code processor} in ascending order. + * (i.e. {@code index}, {@code (index + 1)}, .. {@code (index + length - 1)}) + * + * @return {@code -1} if the processor iterated to or beyond the end of the specified area. + * The last-visited index If the {@link ByteProcessor#process(byte)} returned {@code false}. + */ + public abstract int forEachByte(int index, int length, ByteProcessor processor); + + /** + * Iterates over the readable bytes of this buffer with the specified {@code processor} in descending order. + * + * @return {@code -1} if the processor iterated to or beyond the beginning of the readable bytes. + * The last-visited index If the {@link ByteProcessor#process(byte)} returned {@code false}. + */ + public abstract int forEachByteDesc(ByteProcessor processor); + + /** + * Iterates over the specified area of this buffer with the specified {@code processor} in descending order. + * (i.e. {@code (index + length - 1)}, {@code (index + length - 2)}, ... {@code index}) + * + * + * @return {@code -1} if the processor iterated to or beyond the beginning of the specified area. 
+ * The last-visited index If the {@link ByteProcessor#process(byte)} returned {@code false}. + */ + public abstract int forEachByteDesc(int index, int length, ByteProcessor processor); + + /** + * Returns a copy of this buffer's readable bytes. Modifying the content + * of the returned buffer or this buffer does not affect each other at all. + * This method is identical to {@code buf.copy(buf.readerIndex(), buf.readableBytes())}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + */ + public abstract ByteBuf copy(); + + /** + * Returns a copy of this buffer's sub-region. Modifying the content of + * the returned buffer or this buffer does not affect each other at all. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + */ + public abstract ByteBuf copy(int index, int length); + + /** + * Returns a slice of this buffer's readable bytes. Modifying the content + * of the returned buffer or this buffer affects each other's content + * while they maintain separate indexes and marks. This method is + * identical to {@code buf.slice(buf.readerIndex(), buf.readableBytes())}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + *
+ * <p>
+ * Also be aware that this method will NOT call {@link #retain()} and so the + * reference count will NOT be increased. + */ + public abstract ByteBuf slice(); + + /** + * Returns a retained slice of this buffer's readable bytes. Modifying the content + * of the returned buffer or this buffer affects each other's content + * while they maintain separate indexes and marks. This method is + * identical to {@code buf.slice(buf.readerIndex(), buf.readableBytes())}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + *
+ * <p>
+ * Note that this method returns a {@linkplain #retain() retained} buffer unlike {@link #slice()}. + * This method behaves similarly to {@code slice().retain()} except that this method may return + * a buffer implementation that produces less garbage. + */ + public abstract ByteBuf retainedSlice(); + + /** + * Returns a slice of this buffer's sub-region. Modifying the content of + * the returned buffer or this buffer affects each other's content while + * they maintain separate indexes and marks. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + *
+ * <p>
+ * Also be aware that this method will NOT call {@link #retain()} and so the + * reference count will NOT be increased. + */ + public abstract ByteBuf slice(int index, int length); + + /** + * Returns a retained slice of this buffer's sub-region. Modifying the content of + * the returned buffer or this buffer affects each other's content while + * they maintain separate indexes and marks. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + *
+ * <p>
+ * Note that this method returns a {@linkplain #retain() retained} buffer unlike {@link #slice(int, int)}. + * This method behaves similarly to {@code slice(...).retain()} except that this method may return + * a buffer implementation that produces less garbage. + */ + public abstract ByteBuf retainedSlice(int index, int length); + + /** + * Returns a buffer which shares the whole region of this buffer. + * Modifying the content of the returned buffer or this buffer affects + * each other's content while they maintain separate indexes and marks. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + *
+ * <p>
+ * The reader and writer marks will not be duplicated. Also be aware that this method will + * NOT call {@link #retain()} and so the reference count will NOT be increased. + * @return A buffer whose readable content is equivalent to the buffer returned by {@link #slice()}. + * However this buffer will share the capacity of the underlying buffer, and therefore allows access to all of the + * underlying content if necessary. + */ + public abstract ByteBuf duplicate(); + + /** + * Returns a retained buffer which shares the whole region of this buffer. + * Modifying the content of the returned buffer or this buffer affects + * each other's content while they maintain separate indexes and marks. + * This method is identical to {@code buf.slice(0, buf.capacity())}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + *
+ * <p>
+ * Note that this method returns a {@linkplain #retain() retained} buffer unlike {@link #slice(int, int)}. + * This method behaves similarly to {@code duplicate().retain()} except that this method may return + * a buffer implementation that produces less garbage. + */ + public abstract ByteBuf retainedDuplicate(); + + /** + * Returns the maximum number of NIO {@link ByteBuffer}s that consist this buffer. Note that {@link #nioBuffers()} + * or {@link #nioBuffers(int, int)} might return a less number of {@link ByteBuffer}s. + * + * @return {@code -1} if this buffer has no underlying {@link ByteBuffer}. + * the number of the underlying {@link ByteBuffer}s if this buffer has at least one underlying + * {@link ByteBuffer}. Note that this method does not return {@code 0} to avoid confusion. + * + * @see #nioBuffer() + * @see #nioBuffer(int, int) + * @see #nioBuffers() + * @see #nioBuffers(int, int) + */ + public abstract int nioBufferCount(); + + /** + * Exposes this buffer's readable bytes as an NIO {@link ByteBuffer}. The returned buffer + * either share or contains the copied content of this buffer, while changing the position + * and limit of the returned NIO buffer does not affect the indexes and marks of this buffer. + * This method is identical to {@code buf.nioBuffer(buf.readerIndex(), buf.readableBytes())}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of this buffer. + * Please note that the returned NIO buffer will not see the changes of this buffer if this buffer + * is a dynamic buffer and it adjusted its capacity. + * + * @throws UnsupportedOperationException + * if this buffer cannot create a {@link ByteBuffer} that shares the content with itself + * + * @see #nioBufferCount() + * @see #nioBuffers() + * @see #nioBuffers(int, int) + */ + public abstract ByteBuffer nioBuffer(); + + /** + * Exposes this buffer's sub-region as an NIO {@link ByteBuffer}. 
The returned buffer + * either share or contains the copied content of this buffer, while changing the position + * and limit of the returned NIO buffer does not affect the indexes and marks of this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of this buffer. + * Please note that the returned NIO buffer will not see the changes of this buffer if this buffer + * is a dynamic buffer and it adjusted its capacity. + * + * @throws UnsupportedOperationException + * if this buffer cannot create a {@link ByteBuffer} that shares the content with itself + * + * @see #nioBufferCount() + * @see #nioBuffers() + * @see #nioBuffers(int, int) + */ + public abstract ByteBuffer nioBuffer(int index, int length); + + /** + * Internal use only: Exposes the internal NIO buffer. + */ + public abstract ByteBuffer internalNioBuffer(int index, int length); + + /** + * Exposes this buffer's readable bytes as an NIO {@link ByteBuffer}'s. The returned buffer + * either share or contains the copied content of this buffer, while changing the position + * and limit of the returned NIO buffer does not affect the indexes and marks of this buffer. + * This method does not modify {@code readerIndex} or {@code writerIndex} of this buffer. + * Please note that the returned NIO buffer will not see the changes of this buffer if this buffer + * is a dynamic buffer and it adjusted its capacity. + * + * + * @throws UnsupportedOperationException + * if this buffer cannot create a {@link ByteBuffer} that shares the content with itself + * + * @see #nioBufferCount() + * @see #nioBuffer() + * @see #nioBuffer(int, int) + */ + public abstract ByteBuffer[] nioBuffers(); + + /** + * Exposes this buffer's bytes as an NIO {@link ByteBuffer}'s for the specified index and length + * The returned buffer either share or contains the copied content of this buffer, while changing + * the position and limit of the returned NIO buffer does not affect the indexes and marks of this buffer. 
+ * This method does not modify {@code readerIndex} or {@code writerIndex} of this buffer. Please note that the + * returned NIO buffer will not see the changes of this buffer if this buffer is a dynamic + * buffer and it adjusted its capacity. + * + * @throws UnsupportedOperationException + * if this buffer cannot create a {@link ByteBuffer} that shares the content with itself + * + * @see #nioBufferCount() + * @see #nioBuffer() + * @see #nioBuffer(int, int) + */ + public abstract ByteBuffer[] nioBuffers(int index, int length); + + /** + * Returns {@code true} if and only if this buffer has a backing byte array. + * If this method returns true, you can safely call {@link #array()} and + * {@link #arrayOffset()}. + */ + public abstract boolean hasArray(); + + /** + * Returns the backing byte array of this buffer. + * + * @throws UnsupportedOperationException + * if there no accessible backing byte array + */ + public abstract byte[] array(); + + /** + * Returns the offset of the first byte within the backing byte array of + * this buffer. + * + * @throws UnsupportedOperationException + * if there no accessible backing byte array + */ + public abstract int arrayOffset(); + + /** + * Returns {@code true} if and only if this buffer has a reference to the low-level memory address that points + * to the backing data. + */ + public abstract boolean hasMemoryAddress(); + + /** + * Returns the low-level memory address that point to the first byte of ths backing data. + * + * @throws UnsupportedOperationException + * if this buffer does not support accessing the low-level memory address + */ + public abstract long memoryAddress(); + + /** + * Returns {@code true} if this {@link ByteBuf} implementation is backed by a single memory region. + * Composite buffer implementations must return false even if they currently hold ≤ 1 components. 
+ * For buffers that return {@code true}, it's guaranteed that a successful call to {@link #discardReadBytes()} + * will increase the value of {@link #maxFastWritableBytes()} by the current {@code readerIndex}. + *
+ * <p>
+ * This method will return {@code false} by default, and a {@code false} return value does not necessarily + * mean that the implementation is composite or that it is not backed by a single memory region. + */ + public boolean isContiguous() { + return false; + } + + /** + * A {@code ByteBuf} can turn into itself. + * @return This {@code ByteBuf} instance. + */ + @Override + public ByteBuf asByteBuf() { + return this; + } + + /** + * Decodes this buffer's readable bytes into a string with the specified + * character set name. This method is identical to + * {@code buf.toString(buf.readerIndex(), buf.readableBytes(), charsetName)}. + * This method does not modify {@code readerIndex} or {@code writerIndex} of + * this buffer. + * + * @throws UnsupportedCharsetException + * if the specified character set name is not supported by the + * current VM + */ + public abstract String toString(Charset charset); + + /** + * Decodes this buffer's sub-region into a string with the specified + * character set. This method does not modify {@code readerIndex} or + * {@code writerIndex} of this buffer. + */ + public abstract String toString(int index, int length, Charset charset); + + /** + * Returns a hash code which was calculated from the content of this + * buffer. If there's a byte array which is + * {@linkplain #equals(Object) equal to} this array, both arrays should + * return the same value. + */ + @Override + public abstract int hashCode(); + + /** + * Determines if the content of the specified buffer is identical to the + * content of this array. 'Identical' here means: + *

+ * <ul>
+ * <li>the size of the contents of the two buffers are same and</li>
+ * <li>every single byte of the content of the two buffers are same.</li>
+ * </ul>
+ * Please note that it does not compare {@link #readerIndex()} nor + * {@link #writerIndex()}. This method also returns {@code false} for + * {@code null} and an object which is not an instance of + * {@link ByteBuf} type. + */ + @Override + public abstract boolean equals(Object obj); + + /** + * Compares the content of the specified buffer to the content of this + * buffer. Comparison is performed in the same manner with the string + * comparison functions of various languages such as {@code strcmp}, + * {@code memcmp} and {@link String#compareTo(String)}. + */ + @Override + public abstract int compareTo(ByteBuf buffer); + + /** + * Returns the string representation of this buffer. This method does not + * necessarily return the whole content of the buffer but returns + * the values of the key properties such as {@link #readerIndex()}, + * {@link #writerIndex()} and {@link #capacity()}. + */ + @Override + public abstract String toString(); + + @Override + public abstract ByteBuf retain(int increment); + + @Override + public abstract ByteBuf retain(); + + @Override + public abstract ByteBuf touch(); + + @Override + public abstract ByteBuf touch(Object hint); + + /** + * Used internally by {@link AbstractByteBuf#ensureAccessible()} to try to guard + * against using the buffer after it was released (best-effort). + */ + boolean isAccessible() { + return refCnt() != 0; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocator.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocator.java new file mode 100644 index 0000000..802f5c8 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocator.java @@ -0,0 +1,134 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +/** + * Implementations are responsible to allocate buffers. Implementations of this interface are expected to be + * thread-safe. + */ +public interface ByteBufAllocator { + + ByteBufAllocator DEFAULT = ByteBufUtil.DEFAULT_ALLOCATOR; + + /** + * Allocate a {@link ByteBuf}. If it is a direct or heap buffer + * depends on the actual implementation. + */ + ByteBuf buffer(); + + /** + * Allocate a {@link ByteBuf} with the given initial capacity. + * If it is a direct or heap buffer depends on the actual implementation. + */ + ByteBuf buffer(int initialCapacity); + + /** + * Allocate a {@link ByteBuf} with the given initial capacity and the given + * maximal capacity. If it is a direct or heap buffer depends on the actual + * implementation. + */ + ByteBuf buffer(int initialCapacity, int maxCapacity); + + /** + * Allocate a {@link ByteBuf}, preferably a direct buffer which is suitable for I/O. + */ + ByteBuf ioBuffer(); + + /** + * Allocate a {@link ByteBuf}, preferably a direct buffer which is suitable for I/O. + */ + ByteBuf ioBuffer(int initialCapacity); + + /** + * Allocate a {@link ByteBuf}, preferably a direct buffer which is suitable for I/O. + */ + ByteBuf ioBuffer(int initialCapacity, int maxCapacity); + + /** + * Allocate a heap {@link ByteBuf}. + */ + ByteBuf heapBuffer(); + + /** + * Allocate a heap {@link ByteBuf} with the given initial capacity. + */ + ByteBuf heapBuffer(int initialCapacity); + + /** + * Allocate a heap {@link ByteBuf} with the given initial capacity and the given + * maximal capacity. 
+ */ + ByteBuf heapBuffer(int initialCapacity, int maxCapacity); + + /** + * Allocate a direct {@link ByteBuf}. + */ + ByteBuf directBuffer(); + + /** + * Allocate a direct {@link ByteBuf} with the given initial capacity. + */ + ByteBuf directBuffer(int initialCapacity); + + /** + * Allocate a direct {@link ByteBuf} with the given initial capacity and the given + * maximal capacity. + */ + ByteBuf directBuffer(int initialCapacity, int maxCapacity); + + /** + * Allocate a {@link CompositeByteBuf}. + * If it is a direct or heap buffer depends on the actual implementation. + */ + CompositeByteBuf compositeBuffer(); + + /** + * Allocate a {@link CompositeByteBuf} with the given maximum number of components that can be stored in it. + * If it is a direct or heap buffer depends on the actual implementation. + */ + CompositeByteBuf compositeBuffer(int maxNumComponents); + + /** + * Allocate a heap {@link CompositeByteBuf}. + */ + CompositeByteBuf compositeHeapBuffer(); + + /** + * Allocate a heap {@link CompositeByteBuf} with the given maximum number of components that can be stored in it. + */ + CompositeByteBuf compositeHeapBuffer(int maxNumComponents); + + /** + * Allocate a direct {@link CompositeByteBuf}. + */ + CompositeByteBuf compositeDirectBuffer(); + + /** + * Allocate a direct {@link CompositeByteBuf} with the given maximum number of components that can be stored in it. + */ + CompositeByteBuf compositeDirectBuffer(int maxNumComponents); + + /** + * Returns {@code true} if direct {@link ByteBuf}'s are pooled + */ + boolean isDirectBufferPooled(); + + /** + * Calculate the new capacity of a {@link ByteBuf} that is used when a {@link ByteBuf} needs to expand by the + * {@code minNewCapacity} with {@code maxCapacity} as upper-bound. 
+ */ + int calculateNewCapacity(int minNewCapacity, int maxCapacity); + } diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetric.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetric.java new file mode 100644 index 0000000..7f3ffbd --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetric.java @@ -0,0 +1,28 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +public interface ByteBufAllocatorMetric { + /** + * Returns the number of bytes of heap memory used by a {@link ByteBufAllocator} or {@code -1} if unknown. + */ + long usedHeapMemory(); + + /** + * Returns the number of bytes of direct memory used by a {@link ByteBufAllocator} or {@code -1} if unknown. + */ + long usedDirectMemory(); +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetricProvider.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetricProvider.java new file mode 100644 index 0000000..84b0184 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufAllocatorMetricProvider.java @@ -0,0 +1,24 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +public interface ByteBufAllocatorMetricProvider { + + /** + * Returns a {@link ByteBufAllocatorMetric} for a {@link ByteBufAllocator}. + */ + ByteBufAllocatorMetric metric(); +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufConvertible.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufConvertible.java new file mode 100644 index 0000000..853fb37 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufConvertible.java @@ -0,0 +1,32 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +/** + * An interface that can be implemented by any object that know how to turn itself into a {@link ByteBuf}. + * All {@link ByteBuf} classes implement this interface, and return themselves. + */ +public interface ByteBufConvertible { + /** + * Turn this object into a {@link ByteBuf}. + * This does not increment the reference count of the {@link ByteBuf} instance. 
+ * The conversion or exposure of the {@link ByteBuf} must be idempotent, so that this method can be called + * either once, or multiple times, without causing any change in program behaviour. + * + * @return A {@link ByteBuf} instance from this object. + */ + ByteBuf asByteBuf(); +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufHolder.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufHolder.java new file mode 100644 index 0000000..c506dd9 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufHolder.java @@ -0,0 +1,63 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.ReferenceCounted; + +/** + * A packet which is send or receive. + */ +public interface ByteBufHolder extends ReferenceCounted { + + /** + * Return the data which is held by this {@link ByteBufHolder}. + */ + ByteBuf content(); + + /** + * Creates a deep copy of this {@link ByteBufHolder}. + */ + ByteBufHolder copy(); + + /** + * Duplicates this {@link ByteBufHolder}. Be aware that this will not automatically call {@link #retain()}. + */ + ByteBufHolder duplicate(); + + /** + * Duplicates this {@link ByteBufHolder}. This method returns a retained duplicate unlike {@link #duplicate()}. 
+ * + * @see ByteBuf#retainedDuplicate() + */ + ByteBufHolder retainedDuplicate(); + + /** + * Returns a new {@link ByteBufHolder} which contains the specified {@code content}. + */ + ByteBufHolder replace(ByteBuf content); + + @Override + ByteBufHolder retain(); + + @Override + ByteBufHolder retain(int increment); + + @Override + ByteBufHolder touch(); + + @Override + ByteBufHolder touch(Object hint); +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java new file mode 100644 index 0000000..a0b6127 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufInputStream.java @@ -0,0 +1,330 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; + +/** + * An {@link InputStream} which reads data from a {@link ByteBuf}. + *
+ * <p>
+ * A read operation against this stream will occur at the {@code readerIndex} + * of its underlying buffer and the {@code readerIndex} will increase during + * the read operation. Please note that it only reads up to the number of + * readable bytes determined at the moment of construction. Therefore, + * updating {@link ByteBuf#writerIndex()} will not affect the return + * value of {@link #available()}. + *
+ * <p>
+ * This stream implements {@link DataInput} for your convenience. + * The endianness of the stream is not always big endian but depends on + * the endianness of the underlying buffer. + * + * @see ByteBufOutputStream + */ +public class ByteBufInputStream extends InputStream implements DataInput { + private final ByteBuf buffer; + private final int startIndex; + private final int endIndex; + private boolean closed; + /** + * To preserve backwards compatibility (which didn't transfer ownership) we support a conditional flag which + * indicates if {@link #buffer} should be released when this {@link InputStream} is closed. + * However in future releases ownership should always be transferred and callers of this class should call + * {@link ReferenceCounted#retain()} if necessary. + */ + private final boolean releaseOnClose; + + /** + * Creates a new stream which reads data from the specified {@code buffer} + * starting at the current {@code readerIndex} and ending at the current + * {@code writerIndex}. + * @param buffer The buffer which provides the content for this {@link InputStream}. + */ + public ByteBufInputStream(ByteBuf buffer) { + this(buffer, buffer.readableBytes()); + } + + /** + * Creates a new stream which reads data from the specified {@code buffer} + * starting at the current {@code readerIndex} and ending at + * {@code readerIndex + length}. + * @param buffer The buffer which provides the content for this {@link InputStream}. + * @param length The length of the buffer to use for this {@link InputStream}. + * @throws IndexOutOfBoundsException + * if {@code readerIndex + length} is greater than + * {@code writerIndex} + */ + public ByteBufInputStream(ByteBuf buffer, int length) { + this(buffer, length, false); + } + + /** + * Creates a new stream which reads data from the specified {@code buffer} + * starting at the current {@code readerIndex} and ending at the current + * {@code writerIndex}. 
+ * @param buffer The buffer which provides the content for this {@link InputStream}. + * @param releaseOnClose {@code true} means that when {@link #close()} is called then {@link ByteBuf#release()} will + * be called on {@code buffer}. + */ + public ByteBufInputStream(ByteBuf buffer, boolean releaseOnClose) { + this(buffer, buffer.readableBytes(), releaseOnClose); + } + + /** + * Creates a new stream which reads data from the specified {@code buffer} + * starting at the current {@code readerIndex} and ending at + * {@code readerIndex + length}. + * @param buffer The buffer which provides the content for this {@link InputStream}. + * @param length The length of the buffer to use for this {@link InputStream}. + * @param releaseOnClose {@code true} means that when {@link #close()} is called then {@link ByteBuf#release()} will + * be called on {@code buffer}. + * @throws IndexOutOfBoundsException + * if {@code readerIndex + length} is greater than + * {@code writerIndex} + */ + public ByteBufInputStream(ByteBuf buffer, int length, boolean releaseOnClose) { + ObjectUtil.checkNotNull(buffer, "buffer"); + if (length < 0) { + if (releaseOnClose) { + buffer.release(); + } + checkPositiveOrZero(length, "length"); + } + if (length > buffer.readableBytes()) { + if (releaseOnClose) { + buffer.release(); + } + throw new IndexOutOfBoundsException("Too many bytes to be read - Needs " + + length + ", maximum is " + buffer.readableBytes()); + } + + this.releaseOnClose = releaseOnClose; + this.buffer = buffer; + startIndex = buffer.readerIndex(); + endIndex = startIndex + length; + buffer.markReaderIndex(); + } + + /** + * Returns the number of read bytes by this stream so far. + */ + public int readBytes() { + return buffer.readerIndex() - startIndex; + } + + @Override + public void close() throws IOException { + try { + super.close(); + } finally { + // The Closable interface says "If the stream is already closed then invoking this method has no effect." 
+ if (releaseOnClose && !closed) { + closed = true; + buffer.release(); + } + } + } + + @Override + public int available() throws IOException { + return endIndex - buffer.readerIndex(); + } + + // Suppress a warning since the class is not thread-safe + @Override + public void mark(int readlimit) { + buffer.markReaderIndex(); + } + + @Override + public boolean markSupported() { + return true; + } + + @Override + public int read() throws IOException { + int available = available(); + if (available == 0) { + return -1; + } + return buffer.readByte() & 0xff; + } + + @Override + public int read(byte[] b, int off, int len) throws IOException { + int available = available(); + if (available == 0) { + return -1; + } + + len = Math.min(available, len); + buffer.readBytes(b, off, len); + return len; + } + + // Suppress a warning since the class is not thread-safe + @Override + public void reset() throws IOException { + buffer.resetReaderIndex(); + } + + @Override + public long skip(long n) throws IOException { + if (n > Integer.MAX_VALUE) { + return skipBytes(Integer.MAX_VALUE); + } else { + return skipBytes((int) n); + } + } + + @Override + public boolean readBoolean() throws IOException { + checkAvailable(1); + return read() != 0; + } + + @Override + public byte readByte() throws IOException { + int available = available(); + if (available == 0) { + throw new EOFException(); + } + return buffer.readByte(); + } + + @Override + public char readChar() throws IOException { + return (char) readShort(); + } + + @Override + public double readDouble() throws IOException { + return Double.longBitsToDouble(readLong()); + } + + @Override + public float readFloat() throws IOException { + return Float.intBitsToFloat(readInt()); + } + + @Override + public void readFully(byte[] b) throws IOException { + readFully(b, 0, b.length); + } + + @Override + public void readFully(byte[] b, int off, int len) throws IOException { + checkAvailable(len); + buffer.readBytes(b, off, len); + } + + 
@Override + public int readInt() throws IOException { + checkAvailable(4); + return buffer.readInt(); + } + + private StringBuilder lineBuf; + + @Override + public String readLine() throws IOException { + int available = available(); + if (available == 0) { + return null; + } + + if (lineBuf != null) { + lineBuf.setLength(0); + } + + loop: do { + int c = buffer.readUnsignedByte(); + --available; + switch (c) { + case '\n': + break loop; + + case '\r': + if (available > 0 && (char) buffer.getUnsignedByte(buffer.readerIndex()) == '\n') { + buffer.skipBytes(1); + --available; + } + break loop; + + default: + if (lineBuf == null) { + lineBuf = new StringBuilder(); + } + lineBuf.append((char) c); + } + } while (available > 0); + + return lineBuf != null && lineBuf.length() > 0 ? lineBuf.toString() : StringUtil.EMPTY_STRING; + } + + @Override + public long readLong() throws IOException { + checkAvailable(8); + return buffer.readLong(); + } + + @Override + public short readShort() throws IOException { + checkAvailable(2); + return buffer.readShort(); + } + + @Override + public String readUTF() throws IOException { + return DataInputStream.readUTF(this); + } + + @Override + public int readUnsignedByte() throws IOException { + return readByte() & 0xff; + } + + @Override + public int readUnsignedShort() throws IOException { + return readShort() & 0xffff; + } + + @Override + public int skipBytes(int n) throws IOException { + int nBytes = Math.min(available(), n); + buffer.skipBytes(nBytes); + return nBytes; + } + + private void checkAvailable(int fieldSize) throws IOException { + if (fieldSize < 0) { + throw new IndexOutOfBoundsException("fieldSize cannot be a negative number"); + } + if (fieldSize > available()) { + throw new EOFException("fieldSize is too long! 
Length is " + fieldSize + + ", but maximum is " + available()); + } + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufOutputStream.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufOutputStream.java new file mode 100644 index 0000000..85ba4c5 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufOutputStream.java @@ -0,0 +1,168 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; + +import java.io.DataOutput; +import java.io.DataOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +/** + * An {@link OutputStream} which writes data to a {@link ByteBuf}. + *

+ * A write operation against this stream will occur at the {@code writerIndex} + * of its underlying buffer and the {@code writerIndex} will increase during + * the write operation. + *

+ * This stream implements {@link DataOutput} for your convenience. + * The endianness of the stream is not always big endian but depends on + * the endianness of the underlying buffer. + * + * @see ByteBufInputStream + */ +public class ByteBufOutputStream extends OutputStream implements DataOutput { + + private final ByteBuf buffer; + private final int startIndex; + private DataOutputStream utf8out; // lazily-instantiated + private boolean closed; + + /** + * Creates a new stream which writes data to the specified {@code buffer}. + */ + public ByteBufOutputStream(ByteBuf buffer) { + this.buffer = ObjectUtil.checkNotNull(buffer, "buffer"); + startIndex = buffer.writerIndex(); + } + + /** + * Returns the number of written bytes by this stream so far. + */ + public int writtenBytes() { + return buffer.writerIndex() - startIndex; + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + if (len == 0) { + return; + } + + buffer.writeBytes(b, off, len); + } + + @Override + public void write(byte[] b) throws IOException { + buffer.writeBytes(b); + } + + @Override + public void write(int b) throws IOException { + buffer.writeByte(b); + } + + @Override + public void writeBoolean(boolean v) throws IOException { + buffer.writeBoolean(v); + } + + @Override + public void writeByte(int v) throws IOException { + buffer.writeByte(v); + } + + @Override + public void writeBytes(String s) throws IOException { + buffer.writeCharSequence(s, CharsetUtil.US_ASCII); + } + + @Override + public void writeChar(int v) throws IOException { + buffer.writeChar(v); + } + + @Override + public void writeChars(String s) throws IOException { + int len = s.length(); + for (int i = 0 ; i < len ; i ++) { + buffer.writeChar(s.charAt(i)); + } + } + + @Override + public void writeDouble(double v) throws IOException { + buffer.writeDouble(v); + } + + @Override + public void writeFloat(float v) throws IOException { + buffer.writeFloat(v); + } + + @Override + public void 
writeInt(int v) throws IOException { + buffer.writeInt(v); + } + + @Override + public void writeLong(long v) throws IOException { + buffer.writeLong(v); + } + + @Override + public void writeShort(int v) throws IOException { + buffer.writeShort((short) v); + } + + @Override + public void writeUTF(String s) throws IOException { + DataOutputStream out = utf8out; + if (out == null) { + if (closed) { + throw new IOException("The stream is closed"); + } + // Suppress a warning since the stream is closed in the close() method + utf8out = out = new DataOutputStream(this); + } + out.writeUTF(s); + } + + /** + * Returns the buffer where this stream is writing data. + */ + public ByteBuf buffer() { + return buffer; + } + + @Override + public void close() throws IOException { + if (closed) { + return; + } + closed = true; + + try { + super.close(); + } finally { + if (utf8out != null) { + utf8out.close(); + } + } + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufProcessor.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufProcessor.java new file mode 100644 index 0000000..9e27987 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufProcessor.java @@ -0,0 +1,136 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import io.netty.util.ByteProcessor; + +/** + * @deprecated Use {@link ByteProcessor}. 
 */
@Deprecated
public interface ByteBufProcessor extends ByteProcessor {

    // NOTE(review): the constant names suggest process(value) returns true to
    // CONTINUE scanning (i.e. the byte searched for was NOT yet found), which is
    // why e.g. FIND_NUL returns "value != 0" -- confirm against ByteProcessor docs.

    /**
     * @deprecated Use {@link ByteProcessor#FIND_NUL}.
     */
    @Deprecated
    ByteBufProcessor FIND_NUL = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value != 0;
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_NON_NUL}.
     */
    @Deprecated
    ByteBufProcessor FIND_NON_NUL = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value == 0;
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_CR}.
     */
    @Deprecated
    ByteBufProcessor FIND_CR = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value != '\r';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_NON_CR}.
     */
    @Deprecated
    ByteBufProcessor FIND_NON_CR = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value == '\r';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_LF}.
     */
    @Deprecated
    ByteBufProcessor FIND_LF = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value != '\n';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_NON_LF}.
     */
    @Deprecated
    ByteBufProcessor FIND_NON_LF = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value == '\n';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_CRLF}.
     */
    @Deprecated
    ByteBufProcessor FIND_CRLF = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value != '\r' && value != '\n';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_NON_CRLF}.
     */
    @Deprecated
    ByteBufProcessor FIND_NON_CRLF = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value == '\r' || value == '\n';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_LINEAR_WHITESPACE}.
     */
    @Deprecated
    ByteBufProcessor FIND_LINEAR_WHITESPACE = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value != ' ' && value != '\t';
        }
    };

    /**
     * @deprecated Use {@link ByteProcessor#FIND_NON_LINEAR_WHITESPACE}.
     */
    @Deprecated
    ByteBufProcessor FIND_NON_LINEAR_WHITESPACE = new ByteBufProcessor() {
        @Override
        public boolean process(byte value) throws Exception {
            return value == ' ' || value == '\t';
        }
    };
}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/ByteBufUtil.java b/netty-buffer/src/main/java/io/netty/buffer/ByteBufUtil.java
new file mode 100644
index 0000000..3f0c588
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/ByteBufUtil.java
@@ -0,0 +1,1952 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.buffer; + +import io.netty.util.AsciiString; +import io.netty.util.ByteProcessor; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.Recycler.EnhancedHandle; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.MathUtil; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.CharBuffer; +import java.nio.charset.CharacterCodingException; +import java.nio.charset.Charset; +import java.nio.charset.CharsetDecoder; +import java.nio.charset.CharsetEncoder; +import java.nio.charset.CoderResult; +import java.nio.charset.CodingErrorAction; +import java.util.Arrays; +import java.util.Locale; + +import static io.netty.util.internal.MathUtil.isOutOfBounds; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static io.netty.util.internal.StringUtil.NEWLINE; +import static io.netty.util.internal.StringUtil.isSurrogate; + +/** + * A collection of utility methods that is related with handling {@link ByteBuf}, + * such as the generation of hex dump and swapping an integer's byte order. 
+ */ +public final class ByteBufUtil { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ByteBufUtil.class); + private static final FastThreadLocal BYTE_ARRAYS = new FastThreadLocal() { + @Override + protected byte[] initialValue() throws Exception { + return PlatformDependent.allocateUninitializedArray(MAX_TL_ARRAY_LEN); + } + }; + + private static final byte WRITE_UTF_UNKNOWN = (byte) '?'; + private static final int MAX_CHAR_BUFFER_SIZE; + private static final int THREAD_LOCAL_BUFFER_SIZE; + private static final int MAX_BYTES_PER_CHAR_UTF8 = + (int) CharsetUtil.encoder(CharsetUtil.UTF_8).maxBytesPerChar(); + + static final int WRITE_CHUNK_SIZE = 8192; + static final ByteBufAllocator DEFAULT_ALLOCATOR; + + static { + String allocType = SystemPropertyUtil.get( + "io.netty.allocator.type", PlatformDependent.isAndroid() ? "unpooled" : "pooled"); + allocType = allocType.toLowerCase(Locale.US).trim(); + + ByteBufAllocator alloc; + if ("unpooled".equals(allocType)) { + alloc = UnpooledByteBufAllocator.DEFAULT; + logger.debug("-Dio.netty.allocator.type: {}", allocType); + } else if ("pooled".equals(allocType)) { + alloc = PooledByteBufAllocator.DEFAULT; + logger.debug("-Dio.netty.allocator.type: {}", allocType); + } else { + alloc = PooledByteBufAllocator.DEFAULT; + logger.debug("-Dio.netty.allocator.type: pooled (unknown: {})", allocType); + } + + DEFAULT_ALLOCATOR = alloc; + + THREAD_LOCAL_BUFFER_SIZE = SystemPropertyUtil.getInt("io.netty.threadLocalDirectBufferSize", 0); + logger.debug("-Dio.netty.threadLocalDirectBufferSize: {}", THREAD_LOCAL_BUFFER_SIZE); + + MAX_CHAR_BUFFER_SIZE = SystemPropertyUtil.getInt("io.netty.maxThreadLocalCharBufferSize", 16 * 1024); + logger.debug("-Dio.netty.maxThreadLocalCharBufferSize: {}", MAX_CHAR_BUFFER_SIZE); + } + + static final int MAX_TL_ARRAY_LEN = 1024; + + /** + * Allocates a new array if minLength > {@link ByteBufUtil#MAX_TL_ARRAY_LEN} + */ + static byte[] threadLocalTempArray(int minLength) { 
+ return minLength <= MAX_TL_ARRAY_LEN ? BYTE_ARRAYS.get() + : PlatformDependent.allocateUninitializedArray(minLength); + } + + /** + * @return whether the specified buffer has a nonzero ref count + */ + public static boolean isAccessible(ByteBuf buffer) { + return buffer.isAccessible(); + } + + /** + * @throws IllegalReferenceCountException if the buffer has a zero ref count + * @return the passed in buffer + */ + public static ByteBuf ensureAccessible(ByteBuf buffer) { + if (!buffer.isAccessible()) { + throw new IllegalReferenceCountException(buffer.refCnt()); + } + return buffer; + } + + /** + * Returns a hex dump + * of the specified buffer's readable bytes. + */ + public static String hexDump(ByteBuf buffer) { + return hexDump(buffer, buffer.readerIndex(), buffer.readableBytes()); + } + + /** + * Returns a hex dump + * of the specified buffer's sub-region. + */ + public static String hexDump(ByteBuf buffer, int fromIndex, int length) { + return HexUtil.hexDump(buffer, fromIndex, length); + } + + /** + * Returns a hex dump + * of the specified byte array. + */ + public static String hexDump(byte[] array) { + return hexDump(array, 0, array.length); + } + + /** + * Returns a hex dump + * of the specified byte array's sub-region. + */ + public static String hexDump(byte[] array, int fromIndex, int length) { + return HexUtil.hexDump(array, fromIndex, length); + } + + /** + * Decode a 2-digit hex byte from within a string. 
 */
public static byte decodeHexByte(CharSequence s, int pos) {
    // Interprets s.charAt(pos) and s.charAt(pos + 1) as one two-digit hex byte.
    return StringUtil.decodeHexByte(s, pos);
}

/**
 * Decodes a string generated by {@link #hexDump(byte[])}
 */
public static byte[] decodeHexDump(CharSequence hexDump) {
    return StringUtil.decodeHexDump(hexDump, 0, hexDump.length());
}

/**
 * Decodes part of a string generated by {@link #hexDump(byte[])}
 */
public static byte[] decodeHexDump(CharSequence hexDump, int fromIndex, int length) {
    return StringUtil.decodeHexDump(hexDump, fromIndex, length);
}

/**
 * Used to determine if the return value of {@link ByteBuf#ensureWritable(int, boolean)} means that there is
 * adequate space and a write operation will succeed.
 * @param ensureWritableResult The return value from {@link ByteBuf#ensureWritable(int, boolean)}.
 * @return {@code true} if {@code ensureWritableResult} means that there is adequate space and a write operation
 * will succeed.
 */
public static boolean ensureWritableSuccess(int ensureWritableResult) {
    // NOTE(review): assumes result code 0 == "enough space, capacity unchanged" and
    // 2 == "enough space after capacity increase" -- confirm against
    // ByteBuf.ensureWritable(int, boolean) javadoc.
    return ensureWritableResult == 0 || ensureWritableResult == 2;
}

/**
 * Calculates the hash code of the specified buffer. This method is
 * useful when implementing a new buffer type.
 */
public static int hashCode(ByteBuf buffer) {
    final int aLen = buffer.readableBytes();
    // Process 4 bytes at a time, then the trailing 0-3 bytes individually.
    final int intCount = aLen >>> 2;
    final int byteCount = aLen & 3;

    int hashCode = EmptyByteBuf.EMPTY_BYTE_BUF_HASH_CODE;
    int arrayIndex = buffer.readerIndex();
    // Normalize to big-endian int reads so the hash is independent of the buffer's byte order.
    if (buffer.order() == ByteOrder.BIG_ENDIAN) {
        for (int i = intCount; i > 0; i --) {
            hashCode = 31 * hashCode + buffer.getInt(arrayIndex);
            arrayIndex += 4;
        }
    } else {
        for (int i = intCount; i > 0; i --) {
            hashCode = 31 * hashCode + swapInt(buffer.getInt(arrayIndex));
            arrayIndex += 4;
        }
    }

    for (int i = byteCount; i > 0; i --) {
        hashCode = 31 * hashCode + buffer.getByte(arrayIndex ++);
    }

    // Never return 0 (reserved as the "empty" sentinel).
    if (hashCode == 0) {
        hashCode = 1;
    }

    return hashCode;
}

/**
 * Returns the reader index of needle in haystack, or -1 if needle is not in haystack.
 * This method uses the Two-Way
 * string matching algorithm, which yields O(1) space complexity and excellent performance.
 */
public static int indexOf(ByteBuf needle, ByteBuf haystack) {
    if (haystack == null || needle == null) {
        return -1;
    }

    if (needle.readableBytes() > haystack.readableBytes()) {
        return -1;
    }

    // n = haystack length, m = needle length (in readable bytes).
    int n = haystack.readableBytes();
    int m = needle.readableBytes();
    if (m == 0) {
        return 0;
    }

    // When the needle has only one byte that can be read,
    // the ByteBuf.indexOf() can be used
    if (m == 1) {
        return haystack.indexOf(haystack.readerIndex(), haystack.writerIndex(),
                needle.getByte(needle.readerIndex()));
    }

    int i;
    int j = 0;
    int aStartIndex = needle.readerIndex();
    int bStartIndex = haystack.readerIndex();
    // Compute the critical factorization of the needle: maximal suffix under
    // both byte orderings; each packed long holds (position << 32) | period.
    long suffixes = maxSuf(needle, m, aStartIndex, true);
    long prefixes = maxSuf(needle, m, aStartIndex, false);
    // ell = position of the critical factorization, per = the needle's period candidate.
    int ell = Math.max((int) (suffixes >> 32), (int) (prefixes >> 32));
    int per = Math.max((int) suffixes, (int) prefixes);
    int memory;
    int length = Math.min(m - per, ell + 1);

    // Branch 1: the needle is periodic (its prefix of the computed length repeats
    // with period 'per'); 'memory' remembers already-matched bytes across shifts.
    if (equals(needle, aStartIndex, needle, aStartIndex + per, length)) {
        memory = -1;
        while (j <= n - m) {
            i = Math.max(ell, memory) + 1;
            while (i < m && needle.getByte(i + aStartIndex) == haystack.getByte(i + j + bStartIndex)) {
                ++i;
            }
            if (i > n) {
                return -1;
            }
            if (i >= m) {
                // Right half matched; verify the left half backwards down to 'memory'.
                i = ell;
                while (i > memory && needle.getByte(i + aStartIndex) == haystack.getByte(i + j + bStartIndex)) {
                    --i;
                }
                if (i <= memory) {
                    return j + bStartIndex;
                }
                j += per;
                memory = m - per - 1;
            } else {
                j += i - ell;
                memory = -1;
            }
        }
    } else {
        // Branch 2: aperiodic needle; use the larger shift without memory.
        per = Math.max(ell + 1, m - ell - 1) + 1;
        while (j <= n - m) {
            i = ell + 1;
            while (i < m && needle.getByte(i + aStartIndex) == haystack.getByte(i + j + bStartIndex)) {
                ++i;
            }
            if (i > n) {
                return -1;
            }
            if (i >= m) {
                i = ell;
                while (i >= 0 && needle.getByte(i + aStartIndex) == haystack.getByte(i + j + bStartIndex)) {
                    --i;
                }
                if (i < 0) {
                    return j + bStartIndex;
                }
                j += per;
            } else {
                j += i - ell;
            }
        }
    }
    return -1;
}

// Computes the maximal suffix of x[start..start+m) under normal (isSuffix == true)
// or reversed (isSuffix == false) byte ordering, as used by the Two-Way algorithm.
// Returns the suffix's start offset in the high 32 bits and its period in the low 32 bits.
private static long maxSuf(ByteBuf x, int m, int start, boolean isSuffix) {
    int p = 1;          // period of the current candidate suffix
    int ms = -1;        // start offset (relative) of the maximal suffix found so far
    int j = start;
    int k = 1;
    byte a;
    byte b;
    while (j + k < m) {
        a = x.getByte(j + k);
        b = x.getByte(ms + k);
        boolean suffix = isSuffix ? a < b : a > b;
        if (suffix) {
            // Current candidate loses; restart comparison after it.
            j += k;
            k = 1;
            p = j - ms;
        } else if (a == b) {
            if (k != p) {
                ++k;
            } else {
                j += p;
                k = 1;
            }
        } else {
            // New maximal suffix starts at j.
            ms = j;
            j = ms + 1;
            k = p = 1;
        }
    }
    return ((long) ms << 32) + p;
}

/**
 * Returns {@code true} if and only if the two specified buffers are
 * identical to each other for {@code length} bytes starting at {@code aStartIndex}
 * index for the {@code a} buffer and {@code bStartIndex} index for the {@code b} buffer.
 * A more compact way to express this is:

+ * {@code a[aStartIndex : aStartIndex + length] == b[bStartIndex : bStartIndex + length]} + */ + public static boolean equals(ByteBuf a, int aStartIndex, ByteBuf b, int bStartIndex, int length) { + checkNotNull(a, "a"); + checkNotNull(b, "b"); + // All indexes and lengths must be non-negative + checkPositiveOrZero(aStartIndex, "aStartIndex"); + checkPositiveOrZero(bStartIndex, "bStartIndex"); + checkPositiveOrZero(length, "length"); + + if (a.writerIndex() - length < aStartIndex || b.writerIndex() - length < bStartIndex) { + return false; + } + + final int longCount = length >>> 3; + final int byteCount = length & 7; + + if (a.order() == b.order()) { + for (int i = longCount; i > 0; i --) { + if (a.getLong(aStartIndex) != b.getLong(bStartIndex)) { + return false; + } + aStartIndex += 8; + bStartIndex += 8; + } + } else { + for (int i = longCount; i > 0; i --) { + if (a.getLong(aStartIndex) != swapLong(b.getLong(bStartIndex))) { + return false; + } + aStartIndex += 8; + bStartIndex += 8; + } + } + + for (int i = byteCount; i > 0; i --) { + if (a.getByte(aStartIndex) != b.getByte(bStartIndex)) { + return false; + } + aStartIndex ++; + bStartIndex ++; + } + + return true; + } + + /** + * Returns {@code true} if and only if the two specified buffers are + * identical to each other as described in {@link ByteBuf#equals(Object)}. + * This method is useful when implementing a new buffer type. + */ + public static boolean equals(ByteBuf bufferA, ByteBuf bufferB) { + if (bufferA == bufferB) { + return true; + } + final int aLen = bufferA.readableBytes(); + if (aLen != bufferB.readableBytes()) { + return false; + } + return equals(bufferA, bufferA.readerIndex(), bufferB, bufferB.readerIndex(), aLen); + } + + /** + * Compares the two specified buffers as described in {@link ByteBuf#compareTo(ByteBuf)}. + * This method is useful when implementing a new buffer type. 
+ */ + public static int compare(ByteBuf bufferA, ByteBuf bufferB) { + if (bufferA == bufferB) { + return 0; + } + final int aLen = bufferA.readableBytes(); + final int bLen = bufferB.readableBytes(); + final int minLength = Math.min(aLen, bLen); + final int uintCount = minLength >>> 2; + final int byteCount = minLength & 3; + int aIndex = bufferA.readerIndex(); + int bIndex = bufferB.readerIndex(); + + if (uintCount > 0) { + boolean bufferAIsBigEndian = bufferA.order() == ByteOrder.BIG_ENDIAN; + final long res; + int uintCountIncrement = uintCount << 2; + + if (bufferA.order() == bufferB.order()) { + res = bufferAIsBigEndian ? compareUintBigEndian(bufferA, bufferB, aIndex, bIndex, uintCountIncrement) : + compareUintLittleEndian(bufferA, bufferB, aIndex, bIndex, uintCountIncrement); + } else { + res = bufferAIsBigEndian ? compareUintBigEndianA(bufferA, bufferB, aIndex, bIndex, uintCountIncrement) : + compareUintBigEndianB(bufferA, bufferB, aIndex, bIndex, uintCountIncrement); + } + if (res != 0) { + // Ensure we not overflow when cast + return (int) Math.min(Integer.MAX_VALUE, Math.max(Integer.MIN_VALUE, res)); + } + aIndex += uintCountIncrement; + bIndex += uintCountIncrement; + } + + for (int aEnd = aIndex + byteCount; aIndex < aEnd; ++aIndex, ++bIndex) { + int comp = bufferA.getUnsignedByte(aIndex) - bufferB.getUnsignedByte(bIndex); + if (comp != 0) { + return comp; + } + } + + return aLen - bLen; + } + + private static long compareUintBigEndian( + ByteBuf bufferA, ByteBuf bufferB, int aIndex, int bIndex, int uintCountIncrement) { + for (int aEnd = aIndex + uintCountIncrement; aIndex < aEnd; aIndex += 4, bIndex += 4) { + long comp = bufferA.getUnsignedInt(aIndex) - bufferB.getUnsignedInt(bIndex); + if (comp != 0) { + return comp; + } + } + return 0; + } + + private static long compareUintLittleEndian( + ByteBuf bufferA, ByteBuf bufferB, int aIndex, int bIndex, int uintCountIncrement) { + for (int aEnd = aIndex + uintCountIncrement; aIndex < aEnd; aIndex += 4, 
bIndex += 4) { + long comp = uintFromLE(bufferA.getUnsignedIntLE(aIndex)) - uintFromLE(bufferB.getUnsignedIntLE(bIndex)); + if (comp != 0) { + return comp; + } + } + return 0; + } + + private static long compareUintBigEndianA( + ByteBuf bufferA, ByteBuf bufferB, int aIndex, int bIndex, int uintCountIncrement) { + for (int aEnd = aIndex + uintCountIncrement; aIndex < aEnd; aIndex += 4, bIndex += 4) { + long a = bufferA.getUnsignedInt(aIndex); + long b = uintFromLE(bufferB.getUnsignedIntLE(bIndex)); + long comp = a - b; + if (comp != 0) { + return comp; + } + } + return 0; + } + + private static long compareUintBigEndianB( + ByteBuf bufferA, ByteBuf bufferB, int aIndex, int bIndex, int uintCountIncrement) { + for (int aEnd = aIndex + uintCountIncrement; aIndex < aEnd; aIndex += 4, bIndex += 4) { + long a = uintFromLE(bufferA.getUnsignedIntLE(aIndex)); + long b = bufferB.getUnsignedInt(bIndex); + long comp = a - b; + if (comp != 0) { + return comp; + } + } + return 0; + } + + private static long uintFromLE(long value) { + return Long.reverseBytes(value) >>> Integer.SIZE; + } + + private static final class SWARByteSearch { + + private static long compilePattern(byte byteToFind) { + return (byteToFind & 0xFFL) * 0x101010101010101L; + } + + private static int firstAnyPattern(long word, long pattern, boolean leading) { + long input = word ^ pattern; + long tmp = (input & 0x7F7F7F7F7F7F7F7FL) + 0x7F7F7F7F7F7F7F7FL; + tmp = ~(tmp | input | 0x7F7F7F7F7F7F7F7FL); + final int binaryPosition = leading? 
Long.numberOfLeadingZeros(tmp) : Long.numberOfTrailingZeros(tmp); + return binaryPosition >>> 3; + } + } + + private static int unrolledFirstIndexOf(AbstractByteBuf buffer, int fromIndex, int byteCount, byte value) { + assert byteCount > 0 && byteCount < 8; + if (buffer._getByte(fromIndex) == value) { + return fromIndex; + } + if (byteCount == 1) { + return -1; + } + if (buffer._getByte(fromIndex + 1) == value) { + return fromIndex + 1; + } + if (byteCount == 2) { + return -1; + } + if (buffer._getByte(fromIndex + 2) == value) { + return fromIndex + 2; + } + if (byteCount == 3) { + return -1; + } + if (buffer._getByte(fromIndex + 3) == value) { + return fromIndex + 3; + } + if (byteCount == 4) { + return -1; + } + if (buffer._getByte(fromIndex + 4) == value) { + return fromIndex + 4; + } + if (byteCount == 5) { + return -1; + } + if (buffer._getByte(fromIndex + 5) == value) { + return fromIndex + 5; + } + if (byteCount == 6) { + return -1; + } + if (buffer._getByte(fromIndex + 6) == value) { + return fromIndex + 6; + } + return -1; + } + + /** + * This is using a SWAR (SIMD Within A Register) batch read technique to minimize bound-checks and improve memory + * usage while searching for {@code value}. 
+ */ + static int firstIndexOf(AbstractByteBuf buffer, int fromIndex, int toIndex, byte value) { + fromIndex = Math.max(fromIndex, 0); + if (fromIndex >= toIndex || buffer.capacity() == 0) { + return -1; + } + final int length = toIndex - fromIndex; + buffer.checkIndex(fromIndex, length); + if (!PlatformDependent.isUnaligned()) { + return linearFirstIndexOf(buffer, fromIndex, toIndex, value); + } + assert PlatformDependent.isUnaligned(); + int offset = fromIndex; + final int byteCount = length & 7; + if (byteCount > 0) { + final int index = unrolledFirstIndexOf(buffer, fromIndex, byteCount, value); + if (index != -1) { + return index; + } + offset += byteCount; + if (offset == toIndex) { + return -1; + } + } + final int longCount = length >>> 3; + final ByteOrder nativeOrder = ByteOrder.nativeOrder(); + final boolean isNative = nativeOrder == buffer.order(); + final boolean useLE = nativeOrder == ByteOrder.LITTLE_ENDIAN; + final long pattern = SWARByteSearch.compilePattern(value); + for (int i = 0; i < longCount; i++) { + // use the faster available getLong + final long word = useLE? buffer._getLongLE(offset) : buffer._getLong(offset); + int index = SWARByteSearch.firstAnyPattern(word, pattern, isNative); + if (index < Long.BYTES) { + return offset + index; + } + offset += Long.BYTES; + } + return -1; + } + + private static int linearFirstIndexOf(AbstractByteBuf buffer, int fromIndex, int toIndex, byte value) { + for (int i = fromIndex; i < toIndex; i++) { + if (buffer._getByte(i) == value) { + return i; + } + } + return -1; + } + + /** + * The default implementation of {@link ByteBuf#indexOf(int, int, byte)}. + * This method is useful when implementing a new buffer type. + */ + public static int indexOf(ByteBuf buffer, int fromIndex, int toIndex, byte value) { + return buffer.indexOf(fromIndex, toIndex, value); + } + + /** + * Toggles the endianness of the specified 16-bit short integer. 
+ */ + public static short swapShort(short value) { + return Short.reverseBytes(value); + } + + /** + * Toggles the endianness of the specified 24-bit medium integer. + */ + public static int swapMedium(int value) { + int swapped = value << 16 & 0xff0000 | value & 0xff00 | value >>> 16 & 0xff; + if ((swapped & 0x800000) != 0) { + swapped |= 0xff000000; + } + return swapped; + } + + /** + * Toggles the endianness of the specified 32-bit integer. + */ + public static int swapInt(int value) { + return Integer.reverseBytes(value); + } + + /** + * Toggles the endianness of the specified 64-bit long integer. + */ + public static long swapLong(long value) { + return Long.reverseBytes(value); + } + + /** + * Writes a big-endian 16-bit short integer to the buffer. + */ + @SuppressWarnings("deprecation") + public static ByteBuf writeShortBE(ByteBuf buf, int shortValue) { + return buf.order() == ByteOrder.BIG_ENDIAN? buf.writeShort(shortValue) : + buf.writeShort(swapShort((short) shortValue)); + } + + /** + * Sets a big-endian 16-bit short integer to the buffer. + */ + @SuppressWarnings("deprecation") + public static ByteBuf setShortBE(ByteBuf buf, int index, int shortValue) { + return buf.order() == ByteOrder.BIG_ENDIAN? buf.setShort(index, shortValue) : + buf.setShort(index, swapShort((short) shortValue)); + } + + /** + * Writes a big-endian 24-bit medium integer to the buffer. + */ + @SuppressWarnings("deprecation") + public static ByteBuf writeMediumBE(ByteBuf buf, int mediumValue) { + return buf.order() == ByteOrder.BIG_ENDIAN? buf.writeMedium(mediumValue) : + buf.writeMedium(swapMedium(mediumValue)); + } + + /** + * Reads a big-endian unsigned 16-bit short integer from the buffer. + */ + @SuppressWarnings("deprecation") + public static int readUnsignedShortBE(ByteBuf buf) { + return buf.order() == ByteOrder.BIG_ENDIAN? buf.readUnsignedShort() : + swapShort((short) buf.readUnsignedShort()) & 0xFFFF; + } + + /** + * Reads a big-endian 32-bit integer from the buffer. 
+ */ + @SuppressWarnings("deprecation") + public static int readIntBE(ByteBuf buf) { + return buf.order() == ByteOrder.BIG_ENDIAN? buf.readInt() : + swapInt(buf.readInt()); + } + + /** + * Read the given amount of bytes into a new {@link ByteBuf} that is allocated from the {@link ByteBufAllocator}. + */ + public static ByteBuf readBytes(ByteBufAllocator alloc, ByteBuf buffer, int length) { + boolean release = true; + ByteBuf dst = alloc.buffer(length); + try { + buffer.readBytes(dst); + release = false; + return dst; + } finally { + if (release) { + dst.release(); + } + } + } + + static int lastIndexOf(AbstractByteBuf buffer, int fromIndex, int toIndex, byte value) { + assert fromIndex > toIndex; + final int capacity = buffer.capacity(); + fromIndex = Math.min(fromIndex, capacity); + if (fromIndex < 0 || capacity == 0) { + return -1; + } + buffer.checkIndex(toIndex, fromIndex - toIndex); + for (int i = fromIndex - 1; i >= toIndex; i--) { + if (buffer._getByte(i) == value) { + return i; + } + } + + return -1; + } + + private static CharSequence checkCharSequenceBounds(CharSequence seq, int start, int end) { + if (MathUtil.isOutOfBounds(start, end - start, seq.length())) { + throw new IndexOutOfBoundsException("expected: 0 <= start(" + start + ") <= end (" + end + + ") <= seq.length(" + seq.length() + ')'); + } + return seq; + } + + /** + * Encode a {@link CharSequence} in UTF-8 and write + * it to a {@link ByteBuf} allocated with {@code alloc}. + * @param alloc The allocator used to allocate a new {@link ByteBuf}. + * @param seq The characters to write into a buffer. + * @return The {@link ByteBuf} which contains the UTF-8 encoded + * result. + */ + public static ByteBuf writeUtf8(ByteBufAllocator alloc, CharSequence seq) { + // UTF-8 uses max. 3 bytes per char, so calculate the worst case. 
+ ByteBuf buf = alloc.buffer(utf8MaxBytes(seq)); + writeUtf8(buf, seq); + return buf; + } + + /** + * Encode a {@link CharSequence} in UTF-8 and write + * it to a {@link ByteBuf}. + *

+ * It behaves like {@link #reserveAndWriteUtf8(ByteBuf, CharSequence, int)} with {@code reserveBytes} + * computed by {@link #utf8MaxBytes(CharSequence)}.
+ * This method returns the actual number of bytes written. + */ + public static int writeUtf8(ByteBuf buf, CharSequence seq) { + int seqLength = seq.length(); + return reserveAndWriteUtf8Seq(buf, seq, 0, seqLength, utf8MaxBytes(seqLength)); + } + + /** + * Equivalent to {@link #writeUtf8(ByteBuf, CharSequence) writeUtf8(buf, seq.subSequence(start, end))} + * but avoids subsequence object allocation. + */ + public static int writeUtf8(ByteBuf buf, CharSequence seq, int start, int end) { + checkCharSequenceBounds(seq, start, end); + return reserveAndWriteUtf8Seq(buf, seq, start, end, utf8MaxBytes(end - start)); + } + + /** + * Encode a {@link CharSequence} in UTF-8 and write + * it into {@code reserveBytes} of a {@link ByteBuf}. + *

+ * The {@code reserveBytes} must be computed (ie eagerly using {@link #utf8MaxBytes(CharSequence)} + * or exactly with {@link #utf8Bytes(CharSequence)}) to ensure this method to not fail: for performance reasons + * the index checks will be performed using just {@code reserveBytes}.
+ * This method returns the actual number of bytes written. + */ + public static int reserveAndWriteUtf8(ByteBuf buf, CharSequence seq, int reserveBytes) { + return reserveAndWriteUtf8Seq(buf, seq, 0, seq.length(), reserveBytes); + } + + /** + * Equivalent to {@link #reserveAndWriteUtf8(ByteBuf, CharSequence, int) + * reserveAndWriteUtf8(buf, seq.subSequence(start, end), reserveBytes)} but avoids + * subsequence object allocation if possible. + * + * @return actual number of bytes written + */ + public static int reserveAndWriteUtf8(ByteBuf buf, CharSequence seq, int start, int end, int reserveBytes) { + return reserveAndWriteUtf8Seq(buf, checkCharSequenceBounds(seq, start, end), start, end, reserveBytes); + } + + private static int reserveAndWriteUtf8Seq(ByteBuf buf, CharSequence seq, int start, int end, int reserveBytes) { + for (;;) { + if (buf instanceof WrappedCompositeByteBuf) { + // WrappedCompositeByteBuf is a sub-class of AbstractByteBuf so it needs special handling. + buf = buf.unwrap(); + } else if (buf instanceof AbstractByteBuf) { + AbstractByteBuf byteBuf = (AbstractByteBuf) buf; + byteBuf.ensureWritable0(reserveBytes); + int written = writeUtf8(byteBuf, byteBuf.writerIndex, reserveBytes, seq, start, end); + byteBuf.writerIndex += written; + return written; + } else if (buf instanceof WrappedByteBuf) { + // Unwrap as the wrapped buffer may be an AbstractByteBuf and so we can use fast-path. 
+ buf = buf.unwrap(); + } else { + byte[] bytes = seq.subSequence(start, end).toString().getBytes(CharsetUtil.UTF_8); + buf.writeBytes(bytes); + return bytes.length; + } + } + } + + static int writeUtf8(AbstractByteBuf buffer, int writerIndex, int reservedBytes, CharSequence seq, int len) { + return writeUtf8(buffer, writerIndex, reservedBytes, seq, 0, len); + } + + // Fast-Path implementation + static int writeUtf8(AbstractByteBuf buffer, int writerIndex, int reservedBytes, + CharSequence seq, int start, int end) { + if (seq instanceof AsciiString) { + writeAsciiString(buffer, writerIndex, (AsciiString) seq, start, end); + return end - start; + } + if (PlatformDependent.hasUnsafe()) { + if (buffer.hasArray()) { + return unsafeWriteUtf8(buffer.array(), PlatformDependent.byteArrayBaseOffset(), + buffer.arrayOffset() + writerIndex, seq, start, end); + } + if (buffer.hasMemoryAddress()) { + return unsafeWriteUtf8(null, buffer.memoryAddress(), writerIndex, seq, start, end); + } + } else { + if (buffer.hasArray()) { + return safeArrayWriteUtf8(buffer.array(), buffer.arrayOffset() + writerIndex, seq, start, end); + } + if (buffer.isDirect()) { + assert buffer.nioBufferCount() == 1; + final ByteBuffer internalDirectBuffer = buffer.internalNioBuffer(writerIndex, reservedBytes); + final int bufferPosition = internalDirectBuffer.position(); + return safeDirectWriteUtf8(internalDirectBuffer, bufferPosition, seq, start, end); + } + } + return safeWriteUtf8(buffer, writerIndex, seq, start, end); + } + + // AsciiString Fast-Path implementation - no explicit bound-checks + static void writeAsciiString(AbstractByteBuf buffer, int writerIndex, AsciiString seq, int start, int end) { + final int begin = seq.arrayOffset() + start; + final int length = end - start; + if (PlatformDependent.hasUnsafe()) { + if (buffer.hasArray()) { + PlatformDependent.copyMemory(seq.array(), begin, + buffer.array(), buffer.arrayOffset() + writerIndex, length); + return; + } + if 
(buffer.hasMemoryAddress()) { + PlatformDependent.copyMemory(seq.array(), begin, buffer.memoryAddress() + writerIndex, length); + return; + } + } + if (buffer.hasArray()) { + System.arraycopy(seq.array(), begin, buffer.array(), buffer.arrayOffset() + writerIndex, length); + return; + } + buffer.setBytes(writerIndex, seq.array(), begin, length); + } + + // Safe off-heap Fast-Path implementation + private static int safeDirectWriteUtf8(ByteBuffer buffer, int writerIndex, CharSequence seq, int start, int end) { + assert !(seq instanceof AsciiString); + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq.charAt(i); + if (c < 0x80) { + buffer.put(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer.put(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer.put(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer.put(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + // Surrogate Pair consumes 2 characters. + if (++i == end) { + buffer.put(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method is copied here to NOT allow inlining of writeUtf8 + // and increase the chance to inline CharSequence::charAt instead + char c2 = seq.charAt(i); + if (!Character.isLowSurrogate(c2)) { + buffer.put(writerIndex++, WRITE_UTF_UNKNOWN); + buffer.put(writerIndex++, Character.isHighSurrogate(c2)? WRITE_UTF_UNKNOWN : (byte) c2); + } else { + int codePoint = Character.toCodePoint(c, c2); + // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. 
+ buffer.put(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer.put(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer.put(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer.put(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + } + } else { + buffer.put(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer.put(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer.put(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + return writerIndex - oldWriterIndex; + } + + // Safe off-heap Fast-Path implementation + private static int safeWriteUtf8(AbstractByteBuf buffer, int writerIndex, CharSequence seq, int start, int end) { + assert !(seq instanceof AsciiString); + int oldWriterIndex = writerIndex; + + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. + for (int i = start; i < end; i++) { + char c = seq.charAt(i); + if (c < 0x80) { + buffer._setByte(writerIndex++, (byte) c); + } else if (c < 0x800) { + buffer._setByte(writerIndex++, (byte) (0xc0 | (c >> 6))); + buffer._setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer._setByte(writerIndex++, WRITE_UTF_UNKNOWN); + continue; + } + // Surrogate Pair consumes 2 characters. + if (++i == end) { + buffer._setByte(writerIndex++, WRITE_UTF_UNKNOWN); + break; + } + // Extra method is copied here to NOT allow inlining of writeUtf8 + // and increase the chance to inline CharSequence::charAt instead + char c2 = seq.charAt(i); + if (!Character.isLowSurrogate(c2)) { + buffer._setByte(writerIndex++, WRITE_UTF_UNKNOWN); + buffer._setByte(writerIndex++, Character.isHighSurrogate(c2)? WRITE_UTF_UNKNOWN : c2); + } else { + int codePoint = Character.toCodePoint(c, c2); + // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. 
+ buffer._setByte(writerIndex++, (byte) (0xf0 | (codePoint >> 18))); + buffer._setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + buffer._setByte(writerIndex++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + buffer._setByte(writerIndex++, (byte) (0x80 | (codePoint & 0x3f))); + } + } else { + buffer._setByte(writerIndex++, (byte) (0xe0 | (c >> 12))); + buffer._setByte(writerIndex++, (byte) (0x80 | ((c >> 6) & 0x3f))); + buffer._setByte(writerIndex++, (byte) (0x80 | (c & 0x3f))); + } + } + return writerIndex - oldWriterIndex; + } + + // safe byte[] Fast-Path implementation + private static int safeArrayWriteUtf8(byte[] buffer, int writerIndex, CharSequence seq, int start, int end) { + int oldWriterIndex = writerIndex; + for (int i = start; i < end; i++) { + char c = seq.charAt(i); + if (c < 0x80) { + buffer[writerIndex++] = (byte) c; + } else if (c < 0x800) { + buffer[writerIndex++] = (byte) (0xc0 | (c >> 6)); + buffer[writerIndex++] = (byte) (0x80 | (c & 0x3f)); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + buffer[writerIndex++] = WRITE_UTF_UNKNOWN; + continue; + } + // Surrogate Pair consumes 2 characters. + if (++i == end) { + buffer[writerIndex++] = WRITE_UTF_UNKNOWN; + break; + } + char c2 = seq.charAt(i); + // Extra method is copied here to NOT allow inlining of writeUtf8 + // and increase the chance to inline CharSequence::charAt instead + if (!Character.isLowSurrogate(c2)) { + buffer[writerIndex++] = WRITE_UTF_UNKNOWN; + buffer[writerIndex++] = (byte) (Character.isHighSurrogate(c2)? WRITE_UTF_UNKNOWN : c2); + } else { + int codePoint = Character.toCodePoint(c, c2); + // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. 
+ buffer[writerIndex++] = (byte) (0xf0 | (codePoint >> 18)); + buffer[writerIndex++] = (byte) (0x80 | ((codePoint >> 12) & 0x3f)); + buffer[writerIndex++] = (byte) (0x80 | ((codePoint >> 6) & 0x3f)); + buffer[writerIndex++] = (byte) (0x80 | (codePoint & 0x3f)); + } + } else { + buffer[writerIndex++] = (byte) (0xe0 | (c >> 12)); + buffer[writerIndex++] = (byte) (0x80 | ((c >> 6) & 0x3f)); + buffer[writerIndex++] = (byte) (0x80 | (c & 0x3f)); + } + } + return writerIndex - oldWriterIndex; + } + + // unsafe Fast-Path implementation + private static int unsafeWriteUtf8(byte[] buffer, long memoryOffset, int writerIndex, + CharSequence seq, int start, int end) { + assert !(seq instanceof AsciiString); + long writerOffset = memoryOffset + writerIndex; + final long oldWriterOffset = writerOffset; + for (int i = start; i < end; i++) { + char c = seq.charAt(i); + if (c < 0x80) { + PlatformDependent.putByte(buffer, writerOffset++, (byte) c); + } else if (c < 0x800) { + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0xc0 | (c >> 6))); + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0x80 | (c & 0x3f))); + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + PlatformDependent.putByte(buffer, writerOffset++, WRITE_UTF_UNKNOWN); + continue; + } + // Surrogate Pair consumes 2 characters. + if (++i == end) { + PlatformDependent.putByte(buffer, writerOffset++, WRITE_UTF_UNKNOWN); + break; + } + char c2 = seq.charAt(i); + // Extra method is copied here to NOT allow inlining of writeUtf8 + // and increase the chance to inline CharSequence::charAt instead + if (!Character.isLowSurrogate(c2)) { + PlatformDependent.putByte(buffer, writerOffset++, WRITE_UTF_UNKNOWN); + PlatformDependent.putByte(buffer, writerOffset++, + (byte) (Character.isHighSurrogate(c2)? WRITE_UTF_UNKNOWN : c2)); + } else { + int codePoint = Character.toCodePoint(c, c2); + // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. 
+ PlatformDependent.putByte(buffer, writerOffset++, (byte) (0xf0 | (codePoint >> 18))); + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0x80 | ((codePoint >> 12) & 0x3f))); + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0x80 | ((codePoint >> 6) & 0x3f))); + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0x80 | (codePoint & 0x3f))); + } + } else { + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0xe0 | (c >> 12))); + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0x80 | ((c >> 6) & 0x3f))); + PlatformDependent.putByte(buffer, writerOffset++, (byte) (0x80 | (c & 0x3f))); + } + } + return (int) (writerOffset - oldWriterOffset); + } + + /** + * Returns max bytes length of UTF8 character sequence of the given length. + */ + public static int utf8MaxBytes(final int seqLength) { + return seqLength * MAX_BYTES_PER_CHAR_UTF8; + } + + /** + * Returns max bytes length of UTF8 character sequence. + *

+ * It behaves like {@link #utf8MaxBytes(int)} applied to {@code seq} {@link CharSequence#length()}. + */ + public static int utf8MaxBytes(CharSequence seq) { + if (seq instanceof AsciiString) { + return seq.length(); + } + return utf8MaxBytes(seq.length()); + } + + /** + * Returns the exact bytes length of UTF8 character sequence. + *

+ * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, CharSequence)}. + */ + public static int utf8Bytes(final CharSequence seq) { + return utf8ByteCount(seq, 0, seq.length()); + } + + /** + * Equivalent to {@link #utf8Bytes(CharSequence) utf8Bytes(seq.subSequence(start, end))} + * but avoids subsequence object allocation. + *

+ * This method is producing the exact length according to {@link #writeUtf8(ByteBuf, CharSequence, int, int)}. + */ + public static int utf8Bytes(final CharSequence seq, int start, int end) { + return utf8ByteCount(checkCharSequenceBounds(seq, start, end), start, end); + } + + private static int utf8ByteCount(final CharSequence seq, int start, int end) { + if (seq instanceof AsciiString) { + return end - start; + } + int i = start; + // ASCII fast path + while (i < end && seq.charAt(i) < 0x80) { + ++i; + } + // !ASCII is packed in a separate method to let the ASCII case be smaller + return i < end ? (i - start) + utf8BytesNonAscii(seq, i, end) : i - start; + } + + private static int utf8BytesNonAscii(final CharSequence seq, final int start, final int end) { + int encodedLength = 0; + for (int i = start; i < end; i++) { + final char c = seq.charAt(i); + // making it 100% branchless isn't rewarding due to the many bit operations necessary! + if (c < 0x800) { + // branchless version of: (c <= 127 ? 0:1) + 1 + encodedLength += ((0x7f - c) >>> 31) + 1; + } else if (isSurrogate(c)) { + if (!Character.isHighSurrogate(c)) { + encodedLength++; + // WRITE_UTF_UNKNOWN + continue; + } + // Surrogate Pair consumes 2 characters. + if (++i == end) { + encodedLength++; + // WRITE_UTF_UNKNOWN + break; + } + if (!Character.isLowSurrogate(seq.charAt(i))) { + // WRITE_UTF_UNKNOWN + (Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2) + encodedLength += 2; + continue; + } + // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630. + encodedLength += 4; + } else { + encodedLength += 3; + } + } + return encodedLength; + } + + /** + * Encode a {@link CharSequence} in ASCII and write + * it to a {@link ByteBuf} allocated with {@code alloc}. + * @param alloc The allocator used to allocate a new {@link ByteBuf}. + * @param seq The characters to write into a buffer. + * @return The {@link ByteBuf} which contains the ASCII encoded + * result. 
+ */ + public static ByteBuf writeAscii(ByteBufAllocator alloc, CharSequence seq) { + // ASCII uses 1 byte per char + ByteBuf buf = alloc.buffer(seq.length()); + writeAscii(buf, seq); + return buf; + } + + /** + * Encode a {@link CharSequence} in ASCII and write it + * to a {@link ByteBuf}. + * + * This method returns the actual number of bytes written. + */ + public static int writeAscii(ByteBuf buf, CharSequence seq) { + // ASCII uses 1 byte per char + for (;;) { + if (buf instanceof WrappedCompositeByteBuf) { + // WrappedCompositeByteBuf is a sub-class of AbstractByteBuf so it needs special handling. + buf = buf.unwrap(); + } else if (buf instanceof AbstractByteBuf) { + final int len = seq.length(); + AbstractByteBuf byteBuf = (AbstractByteBuf) buf; + byteBuf.ensureWritable0(len); + if (seq instanceof AsciiString) { + writeAsciiString(byteBuf, byteBuf.writerIndex, (AsciiString) seq, 0, len); + } else { + final int written = writeAscii(byteBuf, byteBuf.writerIndex, seq, len); + assert written == len; + } + byteBuf.writerIndex += len; + return len; + } else if (buf instanceof WrappedByteBuf) { + // Unwrap as the wrapped buffer may be an AbstractByteBuf and so we can use fast-path. + buf = buf.unwrap(); + } else { + byte[] bytes = seq.toString().getBytes(CharsetUtil.US_ASCII); + buf.writeBytes(bytes); + return bytes.length; + } + } + } + + static int writeAscii(AbstractByteBuf buffer, int writerIndex, CharSequence seq, int len) { + if (seq instanceof AsciiString) { + writeAsciiString(buffer, writerIndex, (AsciiString) seq, 0, len); + } else { + writeAsciiCharSequence(buffer, writerIndex, seq, len); + } + return len; + } + + private static int writeAsciiCharSequence(AbstractByteBuf buffer, int writerIndex, CharSequence seq, int len) { + // We can use the _set methods as these not need to do any index checks and reference checks. + // This is possible as we called ensureWritable(...) before. 
+ for (int i = 0; i < len; i++) { + buffer._setByte(writerIndex++, AsciiString.c2b(seq.charAt(i))); + } + return len; + } + + /** + * Encode the given {@link CharBuffer} using the given {@link Charset} into a new {@link ByteBuf} which + * is allocated via the {@link ByteBufAllocator}. + */ + public static ByteBuf encodeString(ByteBufAllocator alloc, CharBuffer src, Charset charset) { + return encodeString0(alloc, false, src, charset, 0); + } + + /** + * Encode the given {@link CharBuffer} using the given {@link Charset} into a new {@link ByteBuf} which + * is allocated via the {@link ByteBufAllocator}. + * + * @param alloc The {@link ByteBufAllocator} to allocate {@link ByteBuf}. + * @param src The {@link CharBuffer} to encode. + * @param charset The specified {@link Charset}. + * @param extraCapacity the extra capacity to alloc except the space for decoding. + */ + public static ByteBuf encodeString(ByteBufAllocator alloc, CharBuffer src, Charset charset, int extraCapacity) { + return encodeString0(alloc, false, src, charset, extraCapacity); + } + + static ByteBuf encodeString0(ByteBufAllocator alloc, boolean enforceHeap, CharBuffer src, Charset charset, + int extraCapacity) { + final CharsetEncoder encoder = CharsetUtil.encoder(charset); + int length = (int) ((double) src.remaining() * encoder.maxBytesPerChar()) + extraCapacity; + boolean release = true; + final ByteBuf dst; + if (enforceHeap) { + dst = alloc.heapBuffer(length); + } else { + dst = alloc.buffer(length); + } + try { + final ByteBuffer dstBuf = dst.internalNioBuffer(dst.readerIndex(), length); + final int pos = dstBuf.position(); + CoderResult cr = encoder.encode(src, dstBuf, true); + if (!cr.isUnderflow()) { + cr.throwException(); + } + cr = encoder.flush(dstBuf); + if (!cr.isUnderflow()) { + cr.throwException(); + } + dst.writerIndex(dst.writerIndex() + dstBuf.position() - pos); + release = false; + return dst; + } catch (CharacterCodingException x) { + throw new IllegalStateException(x); + } 
finally { + if (release) { + dst.release(); + } + } + } + + @SuppressWarnings("deprecation") + static String decodeString(ByteBuf src, int readerIndex, int len, Charset charset) { + if (len == 0) { + return StringUtil.EMPTY_STRING; + } + final byte[] array; + final int offset; + + if (src.hasArray()) { + array = src.array(); + offset = src.arrayOffset() + readerIndex; + } else { + array = threadLocalTempArray(len); + offset = 0; + src.getBytes(readerIndex, array, 0, len); + } + if (CharsetUtil.US_ASCII.equals(charset)) { + // Fast-path for US-ASCII which is used frequently. + return new String(array, 0, offset, len); + } + return new String(array, offset, len, charset); + } + + /** + * Returns a cached thread-local direct buffer, if available. + * + * @return a cached thread-local direct buffer, if available. {@code null} otherwise. + */ + public static ByteBuf threadLocalDirectBuffer() { + if (THREAD_LOCAL_BUFFER_SIZE <= 0) { + return null; + } + + if (PlatformDependent.hasUnsafe()) { + return ThreadLocalUnsafeDirectByteBuf.newInstance(); + } else { + return ThreadLocalDirectByteBuf.newInstance(); + } + } + + /** + * Create a copy of the underlying storage from {@code buf} into a byte array. + * The copy will start at {@link ByteBuf#readerIndex()} and copy {@link ByteBuf#readableBytes()} bytes. + */ + public static byte[] getBytes(ByteBuf buf) { + return getBytes(buf, buf.readerIndex(), buf.readableBytes()); + } + + /** + * Create a copy of the underlying storage from {@code buf} into a byte array. + * The copy will start at {@code start} and copy {@code length} bytes. + */ + public static byte[] getBytes(ByteBuf buf, int start, int length) { + return getBytes(buf, start, length, true); + } + + /** + * Return an array of the underlying storage from {@code buf} into a byte array. + * The copy will start at {@code start} and copy {@code length} bytes. + * If {@code copy} is true a copy will be made of the memory. 
+ * If {@code copy} is false the underlying storage will be shared, if possible. + */ + public static byte[] getBytes(ByteBuf buf, int start, int length, boolean copy) { + int capacity = buf.capacity(); + if (isOutOfBounds(start, length, capacity)) { + throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= start + length(" + length + + ") <= " + "buf.capacity(" + capacity + ')'); + } + + if (buf.hasArray()) { + int baseOffset = buf.arrayOffset() + start; + byte[] bytes = buf.array(); + if (copy || baseOffset != 0 || length != bytes.length) { + return Arrays.copyOfRange(bytes, baseOffset, baseOffset + length); + } else { + return bytes; + } + } + + byte[] bytes = PlatformDependent.allocateUninitializedArray(length); + buf.getBytes(start, bytes); + return bytes; + } + + /** + * Copies the all content of {@code src} to a {@link ByteBuf} using {@link ByteBuf#writeBytes(byte[], int, int)}. + * + * @param src the source string to copy + * @param dst the destination buffer + */ + public static void copy(AsciiString src, ByteBuf dst) { + copy(src, 0, dst, src.length()); + } + + /** + * Copies the content of {@code src} to a {@link ByteBuf} using {@link ByteBuf#setBytes(int, byte[], int, int)}. + * Unlike the {@link #copy(AsciiString, ByteBuf)} and {@link #copy(AsciiString, int, ByteBuf, int)} methods, + * this method do not increase a {@code writerIndex} of {@code dst} buffer. 
+ * + * @param src the source string to copy + * @param srcIdx the starting offset of characters to copy + * @param dst the destination buffer + * @param dstIdx the starting offset in the destination buffer + * @param length the number of characters to copy + */ + public static void copy(AsciiString src, int srcIdx, ByteBuf dst, int dstIdx, int length) { + if (isOutOfBounds(srcIdx, length, src.length())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= srcIdx(" + srcIdx + ") <= srcIdx + length(" + + length + ") <= srcLen(" + src.length() + ')'); + } + + checkNotNull(dst, "dst").setBytes(dstIdx, src.array(), srcIdx + src.arrayOffset(), length); + } + + /** + * Copies the content of {@code src} to a {@link ByteBuf} using {@link ByteBuf#writeBytes(byte[], int, int)}. + * + * @param src the source string to copy + * @param srcIdx the starting offset of characters to copy + * @param dst the destination buffer + * @param length the number of characters to copy + */ + public static void copy(AsciiString src, int srcIdx, ByteBuf dst, int length) { + if (isOutOfBounds(srcIdx, length, src.length())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= srcIdx(" + srcIdx + ") <= srcIdx + length(" + + length + ") <= srcLen(" + src.length() + ')'); + } + + checkNotNull(dst, "dst").writeBytes(src.array(), srcIdx + src.arrayOffset(), length); + } + + /** + * Returns a multi-line hexadecimal dump of the specified {@link ByteBuf} that is easy to read by humans. + */ + public static String prettyHexDump(ByteBuf buffer) { + return prettyHexDump(buffer, buffer.readerIndex(), buffer.readableBytes()); + } + + /** + * Returns a multi-line hexadecimal dump of the specified {@link ByteBuf} that is easy to read by humans, + * starting at the given {@code offset} using the given {@code length}. 
+ */ + public static String prettyHexDump(ByteBuf buffer, int offset, int length) { + return HexUtil.prettyHexDump(buffer, offset, length); + } + + /** + * Appends the prettified multi-line hexadecimal dump of the specified {@link ByteBuf} to the specified + * {@link StringBuilder} that is easy to read by humans. + */ + public static void appendPrettyHexDump(StringBuilder dump, ByteBuf buf) { + appendPrettyHexDump(dump, buf, buf.readerIndex(), buf.readableBytes()); + } + + /** + * Appends the prettified multi-line hexadecimal dump of the specified {@link ByteBuf} to the specified + * {@link StringBuilder} that is easy to read by humans, starting at the given {@code offset} using + * the given {@code length}. + */ + public static void appendPrettyHexDump(StringBuilder dump, ByteBuf buf, int offset, int length) { + HexUtil.appendPrettyHexDump(dump, buf, offset, length); + } + + /* Separate class so that the expensive static initialization is only done when needed */ + private static final class HexUtil { + + private static final char[] BYTE2CHAR = new char[256]; + private static final char[] HEXDUMP_TABLE = new char[256 * 4]; + private static final String[] HEXPADDING = new String[16]; + private static final String[] HEXDUMP_ROWPREFIXES = new String[65536 >>> 4]; + private static final String[] BYTE2HEX = new String[256]; + private static final String[] BYTEPADDING = new String[16]; + + static { + final char[] DIGITS = "0123456789abcdef".toCharArray(); + for (int i = 0; i < 256; i ++) { + HEXDUMP_TABLE[ i << 1 ] = DIGITS[i >>> 4 & 0x0F]; + HEXDUMP_TABLE[(i << 1) + 1] = DIGITS[i & 0x0F]; + } + + int i; + + // Generate the lookup table for hex dump paddings + for (i = 0; i < HEXPADDING.length; i ++) { + int padding = HEXPADDING.length - i; + StringBuilder buf = new StringBuilder(padding * 3); + for (int j = 0; j < padding; j ++) { + buf.append(" "); + } + HEXPADDING[i] = buf.toString(); + } + + // Generate the lookup table for the start-offset header in each row (up to 
64KiB). + for (i = 0; i < HEXDUMP_ROWPREFIXES.length; i ++) { + StringBuilder buf = new StringBuilder(12); + buf.append(NEWLINE); + buf.append(Long.toHexString(i << 4 & 0xFFFFFFFFL | 0x100000000L)); + buf.setCharAt(buf.length() - 9, '|'); + buf.append('|'); + HEXDUMP_ROWPREFIXES[i] = buf.toString(); + } + + // Generate the lookup table for byte-to-hex-dump conversion + for (i = 0; i < BYTE2HEX.length; i ++) { + BYTE2HEX[i] = ' ' + StringUtil.byteToHexStringPadded(i); + } + + // Generate the lookup table for byte dump paddings + for (i = 0; i < BYTEPADDING.length; i ++) { + int padding = BYTEPADDING.length - i; + StringBuilder buf = new StringBuilder(padding); + for (int j = 0; j < padding; j ++) { + buf.append(' '); + } + BYTEPADDING[i] = buf.toString(); + } + + // Generate the lookup table for byte-to-char conversion + for (i = 0; i < BYTE2CHAR.length; i ++) { + if (i <= 0x1f || i >= 0x7f) { + BYTE2CHAR[i] = '.'; + } else { + BYTE2CHAR[i] = (char) i; + } + } + } + + private static String hexDump(ByteBuf buffer, int fromIndex, int length) { + checkPositiveOrZero(length, "length"); + if (length == 0) { + return ""; + } + + int endIndex = fromIndex + length; + char[] buf = new char[length << 1]; + + int srcIdx = fromIndex; + int dstIdx = 0; + for (; srcIdx < endIndex; srcIdx ++, dstIdx += 2) { + System.arraycopy( + HEXDUMP_TABLE, buffer.getUnsignedByte(srcIdx) << 1, + buf, dstIdx, 2); + } + + return new String(buf); + } + + private static String hexDump(byte[] array, int fromIndex, int length) { + checkPositiveOrZero(length, "length"); + if (length == 0) { + return ""; + } + + int endIndex = fromIndex + length; + char[] buf = new char[length << 1]; + + int srcIdx = fromIndex; + int dstIdx = 0; + for (; srcIdx < endIndex; srcIdx ++, dstIdx += 2) { + System.arraycopy( + HEXDUMP_TABLE, (array[srcIdx] & 0xFF) << 1, + buf, dstIdx, 2); + } + + return new String(buf); + } + + private static String prettyHexDump(ByteBuf buffer, int offset, int length) { + if (length == 0) { 
+ return StringUtil.EMPTY_STRING; + } else { + int rows = length / 16 + ((length & 15) == 0? 0 : 1) + 4; + StringBuilder buf = new StringBuilder(rows * 80); + appendPrettyHexDump(buf, buffer, offset, length); + return buf.toString(); + } + } + + private static void appendPrettyHexDump(StringBuilder dump, ByteBuf buf, int offset, int length) { + if (isOutOfBounds(offset, length, buf.capacity())) { + throw new IndexOutOfBoundsException( + "expected: " + "0 <= offset(" + offset + ") <= offset + length(" + length + + ") <= " + "buf.capacity(" + buf.capacity() + ')'); + } + if (length == 0) { + return; + } + dump.append( + " +-------------------------------------------------+" + + NEWLINE + " | 0 1 2 3 4 5 6 7 8 9 a b c d e f |" + + NEWLINE + "+--------+-------------------------------------------------+----------------+"); + + final int fullRows = length >>> 4; + final int remainder = length & 0xF; + + // Dump the rows which have 16 bytes. + for (int row = 0; row < fullRows; row ++) { + int rowStartIndex = (row << 4) + offset; + + // Per-row prefix. + appendHexDumpRowPrefix(dump, row, rowStartIndex); + + // Hex dump + int rowEndIndex = rowStartIndex + 16; + for (int j = rowStartIndex; j < rowEndIndex; j ++) { + dump.append(BYTE2HEX[buf.getUnsignedByte(j)]); + } + dump.append(" |"); + + // ASCII dump + for (int j = rowStartIndex; j < rowEndIndex; j ++) { + dump.append(BYTE2CHAR[buf.getUnsignedByte(j)]); + } + dump.append('|'); + } + + // Dump the last row which has less than 16 bytes. 
+ if (remainder != 0) { + int rowStartIndex = (fullRows << 4) + offset; + appendHexDumpRowPrefix(dump, fullRows, rowStartIndex); + + // Hex dump + int rowEndIndex = rowStartIndex + remainder; + for (int j = rowStartIndex; j < rowEndIndex; j ++) { + dump.append(BYTE2HEX[buf.getUnsignedByte(j)]); + } + dump.append(HEXPADDING[remainder]); + dump.append(" |"); + + // Ascii dump + for (int j = rowStartIndex; j < rowEndIndex; j ++) { + dump.append(BYTE2CHAR[buf.getUnsignedByte(j)]); + } + dump.append(BYTEPADDING[remainder]); + dump.append('|'); + } + + dump.append(NEWLINE + + "+--------+-------------------------------------------------+----------------+"); + } + + private static void appendHexDumpRowPrefix(StringBuilder dump, int row, int rowStartIndex) { + if (row < HEXDUMP_ROWPREFIXES.length) { + dump.append(HEXDUMP_ROWPREFIXES[row]); + } else { + dump.append(NEWLINE); + dump.append(Long.toHexString(rowStartIndex & 0xFFFFFFFFL | 0x100000000L)); + dump.setCharAt(dump.length() - 9, '|'); + dump.append('|'); + } + } + } + + static final class ThreadLocalUnsafeDirectByteBuf extends UnpooledUnsafeDirectByteBuf { + + private static final ObjectPool RECYCLER = + ObjectPool.newPool(new ObjectCreator() { + @Override + public ThreadLocalUnsafeDirectByteBuf newObject(Handle handle) { + return new ThreadLocalUnsafeDirectByteBuf(handle); + } + }); + + static ThreadLocalUnsafeDirectByteBuf newInstance() { + ThreadLocalUnsafeDirectByteBuf buf = RECYCLER.get(); + buf.resetRefCnt(); + return buf; + } + + private final EnhancedHandle handle; + + private ThreadLocalUnsafeDirectByteBuf(Handle handle) { + super(UnpooledByteBufAllocator.DEFAULT, 256, Integer.MAX_VALUE); + this.handle = (EnhancedHandle) handle; + } + + @Override + protected void deallocate() { + if (capacity() > THREAD_LOCAL_BUFFER_SIZE) { + super.deallocate(); + } else { + clear(); + handle.unguardedRecycle(this); + } + } + } + + static final class ThreadLocalDirectByteBuf extends UnpooledDirectByteBuf { + + private static 
final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public ThreadLocalDirectByteBuf newObject(Handle handle) { + return new ThreadLocalDirectByteBuf(handle); + } + }); + + static ThreadLocalDirectByteBuf newInstance() { + ThreadLocalDirectByteBuf buf = RECYCLER.get(); + buf.resetRefCnt(); + return buf; + } + + private final EnhancedHandle handle; + + private ThreadLocalDirectByteBuf(Handle handle) { + super(UnpooledByteBufAllocator.DEFAULT, 256, Integer.MAX_VALUE); + this.handle = (EnhancedHandle) handle; + } + + @Override + protected void deallocate() { + if (capacity() > THREAD_LOCAL_BUFFER_SIZE) { + super.deallocate(); + } else { + clear(); + handle.unguardedRecycle(this); + } + } + } + + /** + * Returns {@code true} if the given {@link ByteBuf} is valid text using the given {@link Charset}, + * otherwise return {@code false}. + * + * @param buf The given {@link ByteBuf}. + * @param charset The specified {@link Charset}. + */ + public static boolean isText(ByteBuf buf, Charset charset) { + return isText(buf, buf.readerIndex(), buf.readableBytes(), charset); + } + + /** + * Returns {@code true} if the specified {@link ByteBuf} starting at {@code index} with {@code length} is valid + * text using the given {@link Charset}, otherwise return {@code false}. + * + * @param buf The given {@link ByteBuf}. + * @param index The start index of the specified buffer. + * @param length The length of the specified buffer. + * @param charset The specified {@link Charset}. 
+ * + * @throws IndexOutOfBoundsException if {@code index} + {@code length} is greater than {@code buf.readableBytes} + */ + public static boolean isText(ByteBuf buf, int index, int length, Charset charset) { + checkNotNull(buf, "buf"); + checkNotNull(charset, "charset"); + final int maxIndex = buf.readerIndex() + buf.readableBytes(); + if (index < 0 || length < 0 || index > maxIndex - length) { + throw new IndexOutOfBoundsException("index: " + index + " length: " + length); + } + if (charset.equals(CharsetUtil.UTF_8)) { + return isUtf8(buf, index, length); + } else if (charset.equals(CharsetUtil.US_ASCII)) { + return isAscii(buf, index, length); + } else { + CharsetDecoder decoder = CharsetUtil.decoder(charset, CodingErrorAction.REPORT, CodingErrorAction.REPORT); + try { + if (buf.nioBufferCount() == 1) { + decoder.decode(buf.nioBuffer(index, length)); + } else { + ByteBuf heapBuffer = buf.alloc().heapBuffer(length); + try { + heapBuffer.writeBytes(buf, index, length); + decoder.decode(heapBuffer.internalNioBuffer(heapBuffer.readerIndex(), length)); + } finally { + heapBuffer.release(); + } + } + return true; + } catch (CharacterCodingException ignore) { + return false; + } + } + } + + /** + * Aborts on a byte which is not a valid ASCII character. + */ + private static final ByteProcessor FIND_NON_ASCII = new ByteProcessor() { + @Override + public boolean process(byte value) { + return value >= 0; + } + }; + + /** + * Returns {@code true} if the specified {@link ByteBuf} starting at {@code index} with {@code length} is valid + * ASCII text, otherwise return {@code false}. + * + * @param buf The given {@link ByteBuf}. + * @param index The start index of the specified buffer. + * @param length The length of the specified buffer. 
+ */ + private static boolean isAscii(ByteBuf buf, int index, int length) { + return buf.forEachByte(index, length, FIND_NON_ASCII) == -1; + } + + /** + * Returns {@code true} if the specified {@link ByteBuf} starting at {@code index} with {@code length} is valid + * UTF8 text, otherwise return {@code false}. + * + * @param buf The given {@link ByteBuf}. + * @param index The start index of the specified buffer. + * @param length The length of the specified buffer. + * + * @see + * UTF-8 Definition + * + *

+     * 1. Bytes format of UTF-8
+     *
+     * The table below summarizes the format of these different octet types.
+     * The letter x indicates bits available for encoding bits of the character number.
+     *
+     * Char. number range  |        UTF-8 octet sequence
+     *    (hexadecimal)    |              (binary)
+     * --------------------+---------------------------------------------
+     * 0000 0000-0000 007F | 0xxxxxxx
+     * 0000 0080-0000 07FF | 110xxxxx 10xxxxxx
+     * 0000 0800-0000 FFFF | 1110xxxx 10xxxxxx 10xxxxxx
+     * 0001 0000-0010 FFFF | 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx
     * </pre>
     *
     * <pre>
+     * 2. Syntax of UTF-8 Byte Sequences
+     *
+     * UTF8-octets = *( UTF8-char )
+     * UTF8-char   = UTF8-1 / UTF8-2 / UTF8-3 / UTF8-4
+     * UTF8-1      = %x00-7F
+     * UTF8-2      = %xC2-DF UTF8-tail
+     * UTF8-3      = %xE0 %xA0-BF UTF8-tail /
+     *               %xE1-EC 2( UTF8-tail ) /
+     *               %xED %x80-9F UTF8-tail /
+     *               %xEE-EF 2( UTF8-tail )
+     * UTF8-4      = %xF0 %x90-BF 2( UTF8-tail ) /
+     *               %xF1-F3 3( UTF8-tail ) /
+     *               %xF4 %x80-8F 2( UTF8-tail )
+     * UTF8-tail   = %x80-BF
     * </pre>
+ */ + private static boolean isUtf8(ByteBuf buf, int index, int length) { + final int endIndex = index + length; + while (index < endIndex) { + byte b1 = buf.getByte(index++); + byte b2, b3, b4; + if ((b1 & 0x80) == 0) { + // 1 byte + continue; + } + if ((b1 & 0xE0) == 0xC0) { + // 2 bytes + // + // Bit/Byte pattern + // 110xxxxx 10xxxxxx + // C2..DF 80..BF + if (index >= endIndex) { // no enough bytes + return false; + } + b2 = buf.getByte(index++); + if ((b2 & 0xC0) != 0x80) { // 2nd byte not starts with 10 + return false; + } + if ((b1 & 0xFF) < 0xC2) { // out of lower bound + return false; + } + } else if ((b1 & 0xF0) == 0xE0) { + // 3 bytes + // + // Bit/Byte pattern + // 1110xxxx 10xxxxxx 10xxxxxx + // E0 A0..BF 80..BF + // E1..EC 80..BF 80..BF + // ED 80..9F 80..BF + // E1..EF 80..BF 80..BF + if (index > endIndex - 2) { // no enough bytes + return false; + } + b2 = buf.getByte(index++); + b3 = buf.getByte(index++); + if ((b2 & 0xC0) != 0x80 || (b3 & 0xC0) != 0x80) { // 2nd or 3rd bytes not start with 10 + return false; + } + if ((b1 & 0x0F) == 0x00 && (b2 & 0xFF) < 0xA0) { // out of lower bound + return false; + } + if ((b1 & 0x0F) == 0x0D && (b2 & 0xFF) > 0x9F) { // out of upper bound + return false; + } + } else if ((b1 & 0xF8) == 0xF0) { + // 4 bytes + // + // Bit/Byte pattern + // 11110xxx 10xxxxxx 10xxxxxx 10xxxxxx + // F0 90..BF 80..BF 80..BF + // F1..F3 80..BF 80..BF 80..BF + // F4 80..8F 80..BF 80..BF + if (index > endIndex - 3) { // no enough bytes + return false; + } + b2 = buf.getByte(index++); + b3 = buf.getByte(index++); + b4 = buf.getByte(index++); + if ((b2 & 0xC0) != 0x80 || (b3 & 0xC0) != 0x80 || (b4 & 0xC0) != 0x80) { + // 2nd, 3rd or 4th bytes not start with 10 + return false; + } + if ((b1 & 0xFF) > 0xF4 // b1 invalid + || (b1 & 0xFF) == 0xF0 && (b2 & 0xFF) < 0x90 // b2 out of lower bound + || (b1 & 0xFF) == 0xF4 && (b2 & 0xFF) > 0x8F) { // b2 out of upper bound + return false; + } + } else { + return false; + } + } + return true; + } + 
+ /** + * Read bytes from the given {@link ByteBuffer} into the given {@link OutputStream} using the {@code position} and + * {@code length}. The position and limit of the given {@link ByteBuffer} may be adjusted. + */ + static void readBytes(ByteBufAllocator allocator, ByteBuffer buffer, int position, int length, OutputStream out) + throws IOException { + if (buffer.hasArray()) { + out.write(buffer.array(), position + buffer.arrayOffset(), length); + } else { + int chunkLen = Math.min(length, WRITE_CHUNK_SIZE); + buffer.clear().position(position); + + if (length <= MAX_TL_ARRAY_LEN || !allocator.isDirectBufferPooled()) { + getBytes(buffer, threadLocalTempArray(chunkLen), 0, chunkLen, out, length); + } else { + // if direct buffers are pooled chances are good that heap buffers are pooled as well. + ByteBuf tmpBuf = allocator.heapBuffer(chunkLen); + try { + byte[] tmp = tmpBuf.array(); + int offset = tmpBuf.arrayOffset(); + getBytes(buffer, tmp, offset, chunkLen, out, length); + } finally { + tmpBuf.release(); + } + } + } + } + + private static void getBytes(ByteBuffer inBuffer, byte[] in, int inOffset, int inLen, OutputStream out, int outLen) + throws IOException { + do { + int len = Math.min(inLen, outLen); + inBuffer.get(in, inOffset, len); + out.write(in, inOffset, len); + outLen -= len; + } while (outLen > 0); + } + + /** + * Set {@link AbstractByteBuf#leakDetector}'s {@link ResourceLeakDetector.LeakListener}. + * + * @param leakListener If leakListener is not null, it will be notified once a ByteBuf leak is detected. 
+ */ + public static void setLeakListener(ResourceLeakDetector.LeakListener leakListener) { + AbstractByteBuf.leakDetector.setLeakListener(leakListener); + } + + private ByteBufUtil() { } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java new file mode 100644 index 0000000..4ad8613 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/CompositeByteBuf.java @@ -0,0 +1,2363 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.RecyclableArrayList; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A virtual buffer which shows multiple buffers as a single merged buffer. It is recommended to use + * {@link ByteBufAllocator#compositeBuffer()} or {@link Unpooled#wrappedBuffer(ByteBuf...)} instead of calling the + * constructor explicitly. 
+ */ +public class CompositeByteBuf extends AbstractReferenceCountedByteBuf implements Iterable { + + private static final ByteBuffer EMPTY_NIO_BUFFER = Unpooled.EMPTY_BUFFER.nioBuffer(); + private static final Iterator EMPTY_ITERATOR = Collections.emptyList().iterator(); + + private final ByteBufAllocator alloc; + private final boolean direct; + private final int maxNumComponents; + + private int componentCount; + private Component[] components; // resized when needed + + private boolean freed; + + private CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, int initSize) { + super(AbstractByteBufAllocator.DEFAULT_MAX_CAPACITY); + + this.alloc = ObjectUtil.checkNotNull(alloc, "alloc"); + if (maxNumComponents < 1) { + throw new IllegalArgumentException( + "maxNumComponents: " + maxNumComponents + " (expected: >= 1)"); + } + + this.direct = direct; + this.maxNumComponents = maxNumComponents; + components = newCompArray(initSize, maxNumComponents); + } + + public CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents) { + this(alloc, direct, maxNumComponents, 0); + } + + public CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, ByteBuf... buffers) { + this(alloc, direct, maxNumComponents, buffers, 0); + } + + CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, + ByteBuf[] buffers, int offset) { + this(alloc, direct, maxNumComponents, buffers.length - offset); + + addComponents0(false, 0, buffers, offset); + consolidateIfNeeded(); + setIndex0(0, capacity()); + } + + public CompositeByteBuf( + ByteBufAllocator alloc, boolean direct, int maxNumComponents, Iterable buffers) { + this(alloc, direct, maxNumComponents, + buffers instanceof Collection ? 
((Collection) buffers).size() : 0); + + addComponents(false, 0, buffers); + setIndex(0, capacity()); + } + + // support passing arrays of other types instead of having to copy to a ByteBuf[] first + interface ByteWrapper { + ByteBuf wrap(T bytes); + boolean isEmpty(T bytes); + } + + static final ByteWrapper BYTE_ARRAY_WRAPPER = new ByteWrapper() { + @Override + public ByteBuf wrap(byte[] bytes) { + return Unpooled.wrappedBuffer(bytes); + } + @Override + public boolean isEmpty(byte[] bytes) { + return bytes.length == 0; + } + }; + + static final ByteWrapper BYTE_BUFFER_WRAPPER = new ByteWrapper() { + @Override + public ByteBuf wrap(ByteBuffer bytes) { + return Unpooled.wrappedBuffer(bytes); + } + @Override + public boolean isEmpty(ByteBuffer bytes) { + return !bytes.hasRemaining(); + } + }; + + CompositeByteBuf(ByteBufAllocator alloc, boolean direct, int maxNumComponents, + ByteWrapper wrapper, T[] buffers, int offset) { + this(alloc, direct, maxNumComponents, buffers.length - offset); + + addComponents0(false, 0, wrapper, buffers, offset); + consolidateIfNeeded(); + setIndex(0, capacity()); + } + + private static Component[] newCompArray(int initComponents, int maxNumComponents) { + int capacityGuess = Math.min(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS, maxNumComponents); + return new Component[Math.max(initComponents, capacityGuess)]; + } + + // Special constructor used by WrappedCompositeByteBuf + CompositeByteBuf(ByteBufAllocator alloc) { + super(Integer.MAX_VALUE); + this.alloc = alloc; + direct = false; + maxNumComponents = 0; + components = null; + } + + /** + * Add the given {@link ByteBuf}. + *

+ * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. + * If you need to have it increased use {@link #addComponent(boolean, ByteBuf)}. + *

+ * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this + * {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponent(ByteBuf buffer) { + return addComponent(false, buffer); + } + + /** + * Add the given {@link ByteBuf}s. + *

+ * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. + * If you need to have it increased use {@link #addComponents(boolean, ByteBuf[])}. + *

+ * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this + * {@link CompositeByteBuf}. + * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponents(ByteBuf... buffers) { + return addComponents(false, buffers); + } + + /** + * Add the given {@link ByteBuf}s. + *

+ * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. + * If you need to have it increased use {@link #addComponents(boolean, Iterable)}. + *

+ * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this + * {@link CompositeByteBuf}. + * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponents(Iterable buffers) { + return addComponents(false, buffers); + } + + /** + * Add the given {@link ByteBuf} on the specific index. + *

+ * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. + * If you need to have it increased use {@link #addComponent(boolean, int, ByteBuf)}. + *

+ * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param cIndex the index on which the {@link ByteBuf} will be added. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this + * {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponent(int cIndex, ByteBuf buffer) { + return addComponent(false, cIndex, buffer); + } + + /** + * Add the given {@link ByteBuf} and increase the {@code writerIndex} if {@code increaseWriterIndex} is + * {@code true}. + * + * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this + * {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer) { + return addComponent(increaseWriterIndex, componentCount, buffer); + } + + /** + * Add the given {@link ByteBuf}s and increase the {@code writerIndex} if {@code increaseWriterIndex} is + * {@code true}. + * + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this + * {@link CompositeByteBuf}. + * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponents(boolean increaseWriterIndex, ByteBuf... buffers) { + checkNotNull(buffers, "buffers"); + addComponents0(increaseWriterIndex, componentCount, buffers, 0); + consolidateIfNeeded(); + return this; + } + + /** + * Add the given {@link ByteBuf}s and increase the {@code writerIndex} if {@code increaseWriterIndex} is + * {@code true}. + * + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this + * {@link CompositeByteBuf}. 
+ * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponents(boolean increaseWriterIndex, Iterable buffers) { + return addComponents(increaseWriterIndex, componentCount, buffers); + } + + /** + * Add the given {@link ByteBuf} on the specific index and increase the {@code writerIndex} + * if {@code increaseWriterIndex} is {@code true}. + * + * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param cIndex the index on which the {@link ByteBuf} will be added. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this + * {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponent(boolean increaseWriterIndex, int cIndex, ByteBuf buffer) { + checkNotNull(buffer, "buffer"); + addComponent0(increaseWriterIndex, cIndex, buffer); + consolidateIfNeeded(); + return this; + } + + private static void checkForOverflow(int capacity, int readableBytes) { + if (capacity + readableBytes < 0) { + throw new IllegalArgumentException("Can't increase by " + readableBytes + " as capacity(" + capacity + ")" + + " would overflow " + Integer.MAX_VALUE); + } + } + + /** + * Precondition is that {@code buffer != null}. + */ + private int addComponent0(boolean increaseWriterIndex, int cIndex, ByteBuf buffer) { + assert buffer != null; + boolean wasAdded = false; + try { + checkComponentIndex(cIndex); + + // No need to consolidate - just add a component to the list. + Component c = newComponent(ensureAccessible(buffer), 0); + int readableBytes = c.length(); + + // Check if we would overflow. 
+ // See https://github.com/netty/netty/issues/10194 + checkForOverflow(capacity(), readableBytes); + + addComp(cIndex, c); + wasAdded = true; + if (readableBytes > 0 && cIndex < componentCount - 1) { + updateComponentOffsets(cIndex); + } else if (cIndex > 0) { + c.reposition(components[cIndex - 1].endOffset); + } + if (increaseWriterIndex) { + writerIndex += readableBytes; + } + return cIndex; + } finally { + if (!wasAdded) { + buffer.release(); + } + } + } + + private static ByteBuf ensureAccessible(final ByteBuf buf) { + if (checkAccessible && !buf.isAccessible()) { + throw new IllegalReferenceCountException(0); + } + return buf; + } + + @SuppressWarnings("deprecation") + private Component newComponent(final ByteBuf buf, final int offset) { + final int srcIndex = buf.readerIndex(); + final int len = buf.readableBytes(); + + // unpeel any intermediate outer layers (UnreleasableByteBuf, LeakAwareByteBufs, SwappedByteBuf) + ByteBuf unwrapped = buf; + int unwrappedIndex = srcIndex; + while (unwrapped instanceof WrappedByteBuf || unwrapped instanceof SwappedByteBuf) { + unwrapped = unwrapped.unwrap(); + } + + // unwrap if already sliced + if (unwrapped instanceof AbstractUnpooledSlicedByteBuf) { + unwrappedIndex += ((AbstractUnpooledSlicedByteBuf) unwrapped).idx(0); + unwrapped = unwrapped.unwrap(); + } else if (unwrapped instanceof PooledSlicedByteBuf) { + unwrappedIndex += ((PooledSlicedByteBuf) unwrapped).adjustment; + unwrapped = unwrapped.unwrap(); + } else if (unwrapped instanceof DuplicatedByteBuf || unwrapped instanceof PooledDuplicatedByteBuf) { + unwrapped = unwrapped.unwrap(); + } + + // We don't need to slice later to expose the internal component if the readable range + // is already the entire buffer + final ByteBuf slice = buf.capacity() == len ? 
buf : null; + + return new Component(buf.order(ByteOrder.BIG_ENDIAN), srcIndex, + unwrapped.order(ByteOrder.BIG_ENDIAN), unwrappedIndex, offset, len, slice); + } + + /** + * Add the given {@link ByteBuf}s on the specific index + *

+ * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. + * If you need to have it increased you need to handle it by your own. + *

+ * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this + * {@link CompositeByteBuf}. + * @param cIndex the index on which the {@link ByteBuf} will be added. {@link ByteBuf#release()} ownership of all + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects is transferred to this + * {@link CompositeByteBuf}. + * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all {@link ByteBuf#release()} + * ownership of all {@link ByteBuf} objects is transferred to this {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponents(int cIndex, ByteBuf... buffers) { + checkNotNull(buffers, "buffers"); + addComponents0(false, cIndex, buffers, 0); + consolidateIfNeeded(); + return this; + } + + private CompositeByteBuf addComponents0(boolean increaseWriterIndex, + final int cIndex, ByteBuf[] buffers, int arrOffset) { + final int len = buffers.length, count = len - arrOffset; + + int readableBytes = 0; + int capacity = capacity(); + for (int i = arrOffset; i < buffers.length; i++) { + ByteBuf b = buffers[i]; + if (b == null) { + break; + } + readableBytes += b.readableBytes(); + + // Check if we would overflow. + // See https://github.com/netty/netty/issues/10194 + checkForOverflow(capacity, readableBytes); + } + // only set ci after we've shifted so that finally block logic is always correct + int ci = Integer.MAX_VALUE; + try { + checkComponentIndex(cIndex); + shiftComps(cIndex, count); // will increase componentCount + int nextOffset = cIndex > 0 ? 
components[cIndex - 1].endOffset : 0; + for (ci = cIndex; arrOffset < len; arrOffset++, ci++) { + ByteBuf b = buffers[arrOffset]; + if (b == null) { + break; + } + Component c = newComponent(ensureAccessible(b), nextOffset); + components[ci] = c; + nextOffset = c.endOffset; + } + return this; + } finally { + // ci is now the index following the last successfully added component + if (ci < componentCount) { + if (ci < cIndex + count) { + // we bailed early + removeCompRange(ci, cIndex + count); + for (; arrOffset < len; ++arrOffset) { + ReferenceCountUtil.safeRelease(buffers[arrOffset]); + } + } + updateComponentOffsets(ci); // only need to do this here for components after the added ones + } + if (increaseWriterIndex && ci > cIndex && ci <= componentCount) { + writerIndex += components[ci - 1].endOffset - components[cIndex].offset; + } + } + } + + private int addComponents0(boolean increaseWriterIndex, int cIndex, + ByteWrapper wrapper, T[] buffers, int offset) { + checkComponentIndex(cIndex); + + // No need for consolidation + for (int i = offset, len = buffers.length; i < len; i++) { + T b = buffers[i]; + if (b == null) { + break; + } + if (!wrapper.isEmpty(b)) { + cIndex = addComponent0(increaseWriterIndex, cIndex, wrapper.wrap(b)) + 1; + int size = componentCount; + if (cIndex > size) { + cIndex = size; + } + } + } + return cIndex; + } + + /** + * Add the given {@link ByteBuf}s on the specific index + * + * Be aware that this method does not increase the {@code writerIndex} of the {@link CompositeByteBuf}. + * If you need to have it increased you need to handle it by your own. + *

+ * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects in {@code buffers} is transferred to this + * {@link CompositeByteBuf}. + * @param cIndex the index on which the {@link ByteBuf} will be added. + * @param buffers the {@link ByteBuf}s to add. {@link ByteBuf#release()} ownership of all + * {@link ByteBuf#release()} ownership of all {@link ByteBuf} objects is transferred to this + * {@link CompositeByteBuf}. + */ + public CompositeByteBuf addComponents(int cIndex, Iterable buffers) { + return addComponents(false, cIndex, buffers); + } + + /** + * Add the given {@link ByteBuf} and increase the {@code writerIndex} if {@code increaseWriterIndex} is + * {@code true}. If the provided buffer is a {@link CompositeByteBuf} itself, a "shallow copy" of its + * readable components will be performed. Thus the actual number of new components added may vary + * and in particular will be zero if the provided buffer is not readable. + *

+ * {@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link CompositeByteBuf}. + * @param buffer the {@link ByteBuf} to add. {@link ByteBuf#release()} ownership is transferred to this + * {@link CompositeByteBuf}. + */ + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, ByteBuf buffer) { + checkNotNull(buffer, "buffer"); + final int ridx = buffer.readerIndex(); + final int widx = buffer.writerIndex(); + if (ridx == widx) { + buffer.release(); + return this; + } + if (!(buffer instanceof CompositeByteBuf)) { + addComponent0(increaseWriterIndex, componentCount, buffer); + consolidateIfNeeded(); + return this; + } + final CompositeByteBuf from; + if (buffer instanceof WrappedCompositeByteBuf) { + from = (CompositeByteBuf) buffer.unwrap(); + } else { + from = (CompositeByteBuf) buffer; + } + from.checkIndex(ridx, widx - ridx); + final Component[] fromComponents = from.components; + final int compCountBefore = componentCount; + final int writerIndexBefore = writerIndex; + try { + for (int cidx = from.toComponentIndex0(ridx), newOffset = capacity();; cidx++) { + final Component component = fromComponents[cidx]; + final int compOffset = component.offset; + final int fromIdx = Math.max(ridx, compOffset); + final int toIdx = Math.min(widx, component.endOffset); + final int len = toIdx - fromIdx; + if (len > 0) { // skip empty components + addComp(componentCount, new Component( + component.srcBuf.retain(), component.srcIdx(fromIdx), + component.buf, component.idx(fromIdx), newOffset, len, null)); + } + if (widx == toIdx) { + break; + } + newOffset += len; + } + if (increaseWriterIndex) { + writerIndex = writerIndexBefore + (widx - ridx); + } + consolidateIfNeeded(); + buffer.release(); + buffer = null; + return this; + } finally { + if (buffer != null) { + // if we did not succeed, attempt to rollback any components that were added + if (increaseWriterIndex) { + writerIndex = writerIndexBefore; + } + for (int cidx = 
componentCount - 1; cidx >= compCountBefore; cidx--) { + components[cidx].free(); + removeComp(cidx); + } + } + } + } + + // TODO optimize further, similar to ByteBuf[] version + // (difference here is that we don't know *always* know precise size increase in advance, + // but we do in the most common case that the Iterable is a Collection) + private CompositeByteBuf addComponents(boolean increaseIndex, int cIndex, Iterable buffers) { + if (buffers instanceof ByteBuf) { + // If buffers also implements ByteBuf (e.g. CompositeByteBuf), it has to go to addComponent(ByteBuf). + return addComponent(increaseIndex, cIndex, (ByteBuf) buffers); + } + checkNotNull(buffers, "buffers"); + Iterator it = buffers.iterator(); + try { + checkComponentIndex(cIndex); + + // No need for consolidation + while (it.hasNext()) { + ByteBuf b = it.next(); + if (b == null) { + break; + } + cIndex = addComponent0(increaseIndex, cIndex, b) + 1; + cIndex = Math.min(cIndex, componentCount); + } + } finally { + while (it.hasNext()) { + ReferenceCountUtil.safeRelease(it.next()); + } + } + consolidateIfNeeded(); + return this; + } + + /** + * This should only be called as last operation from a method as this may adjust the underlying + * array of components and so affect the index etc. + */ + private void consolidateIfNeeded() { + // Consolidate if the number of components will exceed the allowed maximum by the current + // operation. 
+ int size = componentCount; + if (size > maxNumComponents) { + consolidate0(0, size); + } + } + + private void checkComponentIndex(int cIndex) { + ensureAccessible(); + if (cIndex < 0 || cIndex > componentCount) { + throw new IndexOutOfBoundsException(String.format( + "cIndex: %d (expected: >= 0 && <= numComponents(%d))", + cIndex, componentCount)); + } + } + + private void checkComponentIndex(int cIndex, int numComponents) { + ensureAccessible(); + if (cIndex < 0 || cIndex + numComponents > componentCount) { + throw new IndexOutOfBoundsException(String.format( + "cIndex: %d, numComponents: %d " + + "(expected: cIndex >= 0 && cIndex + numComponents <= totalNumComponents(%d))", + cIndex, numComponents, componentCount)); + } + } + + private void updateComponentOffsets(int cIndex) { + int size = componentCount; + if (size <= cIndex) { + return; + } + + int nextIndex = cIndex > 0 ? components[cIndex - 1].endOffset : 0; + for (; cIndex < size; cIndex++) { + Component c = components[cIndex]; + c.reposition(nextIndex); + nextIndex = c.endOffset; + } + } + + /** + * Remove the {@link ByteBuf} from the given index. + * + * @param cIndex the index on from which the {@link ByteBuf} will be remove + */ + public CompositeByteBuf removeComponent(int cIndex) { + checkComponentIndex(cIndex); + Component comp = components[cIndex]; + if (lastAccessed == comp) { + lastAccessed = null; + } + comp.free(); + removeComp(cIndex); + if (comp.length() > 0) { + // Only need to call updateComponentOffsets if the length was > 0 + updateComponentOffsets(cIndex); + } + return this; + } + + /** + * Remove the number of {@link ByteBuf}s starting from the given index. 
+ * + * @param cIndex the index on which the {@link ByteBuf}s will be started to removed + * @param numComponents the number of components to remove + */ + public CompositeByteBuf removeComponents(int cIndex, int numComponents) { + checkComponentIndex(cIndex, numComponents); + + if (numComponents == 0) { + return this; + } + int endIndex = cIndex + numComponents; + boolean needsUpdate = false; + for (int i = cIndex; i < endIndex; ++i) { + Component c = components[i]; + if (c.length() > 0) { + needsUpdate = true; + } + if (lastAccessed == c) { + lastAccessed = null; + } + c.free(); + } + removeCompRange(cIndex, endIndex); + + if (needsUpdate) { + // Only need to call updateComponentOffsets if the length was > 0 + updateComponentOffsets(cIndex); + } + return this; + } + + @Override + public Iterator iterator() { + ensureAccessible(); + return componentCount == 0 ? EMPTY_ITERATOR : new CompositeByteBufIterator(); + } + + @Override + protected int forEachByteAsc0(int start, int end, ByteProcessor processor) throws Exception { + if (end <= start) { + return -1; + } + for (int i = toComponentIndex0(start), length = end - start; length > 0; i++) { + Component c = components[i]; + if (c.offset == c.endOffset) { + continue; // empty + } + ByteBuf s = c.buf; + int localStart = c.idx(start); + int localLength = Math.min(length, c.endOffset - start); + // avoid additional checks in AbstractByteBuf case + int result = s instanceof AbstractByteBuf + ? 
((AbstractByteBuf) s).forEachByteAsc0(localStart, localStart + localLength, processor) + : s.forEachByte(localStart, localLength, processor); + if (result != -1) { + return result - c.adjustment; + } + start += localLength; + length -= localLength; + } + return -1; + } + + @Override + protected int forEachByteDesc0(int rStart, int rEnd, ByteProcessor processor) throws Exception { + if (rEnd > rStart) { // rStart *and* rEnd are inclusive + return -1; + } + for (int i = toComponentIndex0(rStart), length = 1 + rStart - rEnd; length > 0; i--) { + Component c = components[i]; + if (c.offset == c.endOffset) { + continue; // empty + } + ByteBuf s = c.buf; + int localRStart = c.idx(length + rEnd); + int localLength = Math.min(length, localRStart), localIndex = localRStart - localLength; + // avoid additional checks in AbstractByteBuf case + int result = s instanceof AbstractByteBuf + ? ((AbstractByteBuf) s).forEachByteDesc0(localRStart - 1, localIndex, processor) + : s.forEachByteDesc(localIndex, localLength, processor); + + if (result != -1) { + return result - c.adjustment; + } + length -= localLength; + } + return -1; + } + + /** + * Same with {@link #slice(int, int)} except that this method returns a list. + */ + public List decompose(int offset, int length) { + checkIndex(offset, length); + if (length == 0) { + return Collections.emptyList(); + } + + int componentId = toComponentIndex0(offset); + int bytesToSlice = length; + // The first component + Component firstC = components[componentId]; + + // It's important to use srcBuf and NOT buf as we need to return the "original" source buffer and not the + // unwrapped one as otherwise we could loose the ability to correctly update the reference count on the + // returned buffer. 
+ ByteBuf slice = firstC.srcBuf.slice(firstC.srcIdx(offset), Math.min(firstC.endOffset - offset, bytesToSlice)); + bytesToSlice -= slice.readableBytes(); + + if (bytesToSlice == 0) { + return Collections.singletonList(slice); + } + + List sliceList = new ArrayList(componentCount - componentId); + sliceList.add(slice); + + // Add all the slices until there is nothing more left and then return the List. + do { + Component component = components[++componentId]; + // It's important to use srcBuf and NOT buf as we need to return the "original" source buffer and not the + // unwrapped one as otherwise we could loose the ability to correctly update the reference count on the + // returned buffer. + slice = component.srcBuf.slice(component.srcIdx(component.offset), + Math.min(component.length(), bytesToSlice)); + bytesToSlice -= slice.readableBytes(); + sliceList.add(slice); + } while (bytesToSlice > 0); + + return sliceList; + } + + @Override + public boolean isDirect() { + int size = componentCount; + if (size == 0) { + return false; + } + for (int i = 0; i < size; i++) { + if (!components[i].buf.isDirect()) { + return false; + } + } + return true; + } + + @Override + public boolean hasArray() { + switch (componentCount) { + case 0: + return true; + case 1: + return components[0].buf.hasArray(); + default: + return false; + } + } + + @Override + public byte[] array() { + switch (componentCount) { + case 0: + return EmptyArrays.EMPTY_BYTES; + case 1: + return components[0].buf.array(); + default: + throw new UnsupportedOperationException(); + } + } + + @Override + public int arrayOffset() { + switch (componentCount) { + case 0: + return 0; + case 1: + Component c = components[0]; + return c.idx(c.buf.arrayOffset()); + default: + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean hasMemoryAddress() { + switch (componentCount) { + case 0: + return Unpooled.EMPTY_BUFFER.hasMemoryAddress(); + case 1: + return 
components[0].buf.hasMemoryAddress(); + default: + return false; + } + } + + @Override + public long memoryAddress() { + switch (componentCount) { + case 0: + return Unpooled.EMPTY_BUFFER.memoryAddress(); + case 1: + Component c = components[0]; + return c.buf.memoryAddress() + c.adjustment; + default: + throw new UnsupportedOperationException(); + } + } + + @Override + public int capacity() { + int size = componentCount; + return size > 0 ? components[size - 1].endOffset : 0; + } + + @Override + public CompositeByteBuf capacity(int newCapacity) { + checkNewCapacity(newCapacity); + + final int size = componentCount, oldCapacity = capacity(); + if (newCapacity > oldCapacity) { + final int paddingLength = newCapacity - oldCapacity; + ByteBuf padding = allocBuffer(paddingLength).setIndex(0, paddingLength); + addComponent0(false, size, padding); + if (componentCount >= maxNumComponents) { + // FIXME: No need to create a padding buffer and consolidate. + // Just create a big single buffer and put the current content there. + consolidateIfNeeded(); + } + } else if (newCapacity < oldCapacity) { + lastAccessed = null; + int i = size - 1; + for (int bytesToTrim = oldCapacity - newCapacity; i >= 0; i--) { + Component c = components[i]; + final int cLength = c.length(); + if (bytesToTrim < cLength) { + // Trim the last component + c.endOffset -= bytesToTrim; + ByteBuf slice = c.slice; + if (slice != null) { + // We must replace the cached slice with a derived one to ensure that + // it can later be released properly in the case of PooledSlicedByteBuf. 
+ c.slice = slice.slice(0, c.length()); + } + break; + } + c.free(); + bytesToTrim -= cLength; + } + removeCompRange(i + 1, size); + + if (readerIndex() > newCapacity) { + setIndex0(newCapacity, newCapacity); + } else if (writerIndex > newCapacity) { + writerIndex = newCapacity; + } + } + return this; + } + + @Override + public ByteBufAllocator alloc() { + return alloc; + } + + @Override + public ByteOrder order() { + return ByteOrder.BIG_ENDIAN; + } + + /** + * Return the current number of {@link ByteBuf}'s that are composed in this instance + */ + public int numComponents() { + return componentCount; + } + + /** + * Return the max number of {@link ByteBuf}'s that are composed in this instance + */ + public int maxNumComponents() { + return maxNumComponents; + } + + /** + * Return the index for the given offset + */ + public int toComponentIndex(int offset) { + checkIndex(offset); + return toComponentIndex0(offset); + } + + private int toComponentIndex0(int offset) { + int size = componentCount; + if (offset == 0) { // fast-path zero offset + for (int i = 0; i < size; i++) { + if (components[i].endOffset > 0) { + return i; + } + } + } + if (size <= 2) { // fast-path for 1 and 2 component count + return size == 1 || offset < components[0].endOffset ? 
0 : 1; + } + for (int low = 0, high = size; low <= high;) { + int mid = low + high >>> 1; + Component c = components[mid]; + if (offset >= c.endOffset) { + low = mid + 1; + } else if (offset < c.offset) { + high = mid - 1; + } else { + return mid; + } + } + + throw new Error("should not reach here"); + } + + public int toByteIndex(int cIndex) { + checkComponentIndex(cIndex); + return components[cIndex].offset; + } + + @Override + public byte getByte(int index) { + Component c = findComponent(index); + return c.buf.getByte(c.idx(index)); + } + + @Override + protected byte _getByte(int index) { + Component c = findComponent0(index); + return c.buf.getByte(c.idx(index)); + } + + @Override + protected short _getShort(int index) { + Component c = findComponent0(index); + if (index + 2 <= c.endOffset) { + return c.buf.getShort(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff); + } else { + return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8); + } + } + + @Override + protected short _getShortLE(int index) { + Component c = findComponent0(index); + if (index + 2 <= c.endOffset) { + return c.buf.getShortLE(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8); + } else { + return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff); + } + } + + @Override + protected int _getUnsignedMedium(int index) { + Component c = findComponent0(index); + if (index + 3 <= c.endOffset) { + return c.buf.getUnsignedMedium(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (_getShort(index) & 0xffff) << 8 | _getByte(index + 2) & 0xff; + } else { + return _getShort(index) & 0xFFFF | (_getByte(index + 2) & 0xFF) << 16; + } + } + + @Override + protected int _getUnsignedMediumLE(int index) { + Component c = findComponent0(index); + if (index + 3 <= c.endOffset) 
{ + return c.buf.getUnsignedMediumLE(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return _getShortLE(index) & 0xffff | (_getByte(index + 2) & 0xff) << 16; + } else { + return (_getShortLE(index) & 0xffff) << 8 | _getByte(index + 2) & 0xff; + } + } + + @Override + protected int _getInt(int index) { + Component c = findComponent0(index); + if (index + 4 <= c.endOffset) { + return c.buf.getInt(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (_getShort(index) & 0xffff) << 16 | _getShort(index + 2) & 0xffff; + } else { + return _getShort(index) & 0xFFFF | (_getShort(index + 2) & 0xFFFF) << 16; + } + } + + @Override + protected int _getIntLE(int index) { + Component c = findComponent0(index); + if (index + 4 <= c.endOffset) { + return c.buf.getIntLE(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return _getShortLE(index) & 0xffff | (_getShortLE(index + 2) & 0xffff) << 16; + } else { + return (_getShortLE(index) & 0xffff) << 16 | _getShortLE(index + 2) & 0xffff; + } + } + + @Override + protected long _getLong(int index) { + Component c = findComponent0(index); + if (index + 8 <= c.endOffset) { + return c.buf.getLong(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (_getInt(index) & 0xffffffffL) << 32 | _getInt(index + 4) & 0xffffffffL; + } else { + return _getInt(index) & 0xFFFFFFFFL | (_getInt(index + 4) & 0xFFFFFFFFL) << 32; + } + } + + @Override + protected long _getLongLE(int index) { + Component c = findComponent0(index); + if (index + 8 <= c.endOffset) { + return c.buf.getLongLE(c.idx(index)); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return _getIntLE(index) & 0xffffffffL | (_getIntLE(index + 4) & 0xffffffffL) << 32; + } else { + return (_getIntLE(index) & 0xffffffffL) << 32 | _getIntLE(index + 4) & 0xffffffffL; + } + } + + @Override + public CompositeByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.length); + 
if (length == 0) { + return this; + } + + int i = toComponentIndex0(index); + while (length > 0) { + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), dst, dstIndex, localLength); + index += localLength; + dstIndex += localLength; + length -= localLength; + i ++; + } + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuffer dst) { + int limit = dst.limit(); + int length = dst.remaining(); + + checkIndex(index, length); + if (length == 0) { + return this; + } + + int i = toComponentIndex0(index); + try { + while (length > 0) { + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + dst.limit(dst.position() + localLength); + c.buf.getBytes(c.idx(index), dst); + index += localLength; + length -= localLength; + i ++; + } + } finally { + dst.limit(limit); + } + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (length == 0) { + return this; + } + + int i = toComponentIndex0(index); + while (length > 0) { + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), dst, dstIndex, localLength); + index += localLength; + dstIndex += localLength; + length -= localLength; + i ++; + } + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) + throws IOException { + int count = nioBufferCount(); + if (count == 1) { + return out.write(internalNioBuffer(index, length)); + } else { + long writtenBytes = out.write(nioBuffers(index, length)); + if (writtenBytes > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) writtenBytes; + } + } + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) + throws IOException { + int count = nioBufferCount(); + if 
(count == 1) { + return out.write(internalNioBuffer(index, length), position); + } else { + long writtenBytes = 0; + for (ByteBuffer buf : nioBuffers(index, length)) { + writtenBytes += out.write(buf, position + writtenBytes); + } + if (writtenBytes > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } + return (int) writtenBytes; + } + } + + @Override + public CompositeByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + checkIndex(index, length); + if (length == 0) { + return this; + } + + int i = toComponentIndex0(index); + while (length > 0) { + Component c = components[i]; + int localLength = Math.min(length, c.endOffset - index); + c.buf.getBytes(c.idx(index), out, localLength); + index += localLength; + length -= localLength; + i ++; + } + return this; + } + + @Override + public CompositeByteBuf setByte(int index, int value) { + Component c = findComponent(index); + c.buf.setByte(c.idx(index), value); + return this; + } + + @Override + protected void _setByte(int index, int value) { + Component c = findComponent0(index); + c.buf.setByte(c.idx(index), value); + } + + @Override + public CompositeByteBuf setShort(int index, int value) { + checkIndex(index, 2); + _setShort(index, value); + return this; + } + + @Override + protected void _setShort(int index, int value) { + Component c = findComponent0(index); + if (index + 2 <= c.endOffset) { + c.buf.setShort(c.idx(index), value); + } else if (order() == ByteOrder.BIG_ENDIAN) { + _setByte(index, (byte) (value >>> 8)); + _setByte(index + 1, (byte) value); + } else { + _setByte(index, (byte) value); + _setByte(index + 1, (byte) (value >>> 8)); + } + } + + @Override + protected void _setShortLE(int index, int value) { + Component c = findComponent0(index); + if (index + 2 <= c.endOffset) { + c.buf.setShortLE(c.idx(index), value); + } else if (order() == ByteOrder.BIG_ENDIAN) { + _setByte(index, (byte) value); + _setByte(index + 1, (byte) (value >>> 8)); + } else { + _setByte(index, 
(byte) (value >>> 8)); + _setByte(index + 1, (byte) value); + } + } + + @Override + public CompositeByteBuf setMedium(int index, int value) { + checkIndex(index, 3); + _setMedium(index, value); + return this; + } + + @Override + protected void _setMedium(int index, int value) { + Component c = findComponent0(index); + if (index + 3 <= c.endOffset) { + c.buf.setMedium(c.idx(index), value); + } else if (order() == ByteOrder.BIG_ENDIAN) { + _setShort(index, (short) (value >> 8)); + _setByte(index + 2, (byte) value); + } else { + _setShort(index, (short) value); + _setByte(index + 2, (byte) (value >>> 16)); + } + } + + @Override + protected void _setMediumLE(int index, int value) { + Component c = findComponent0(index); + if (index + 3 <= c.endOffset) { + c.buf.setMediumLE(c.idx(index), value); + } else if (order() == ByteOrder.BIG_ENDIAN) { + _setShortLE(index, (short) value); + _setByte(index + 2, (byte) (value >>> 16)); + } else { + _setShortLE(index, (short) (value >> 8)); + _setByte(index + 2, (byte) value); + } + } + + @Override + public CompositeByteBuf setInt(int index, int value) { + checkIndex(index, 4); + _setInt(index, value); + return this; + } + + @Override + protected void _setInt(int index, int value) { + Component c = findComponent0(index); + if (index + 4 <= c.endOffset) { + c.buf.setInt(c.idx(index), value); + } else if (order() == ByteOrder.BIG_ENDIAN) { + _setShort(index, (short) (value >>> 16)); + _setShort(index + 2, (short) value); + } else { + _setShort(index, (short) value); + _setShort(index + 2, (short) (value >>> 16)); + } + } + + @Override + protected void _setIntLE(int index, int value) { + Component c = findComponent0(index); + if (index + 4 <= c.endOffset) { + c.buf.setIntLE(c.idx(index), value); + } else if (order() == ByteOrder.BIG_ENDIAN) { + _setShortLE(index, (short) value); + _setShortLE(index + 2, (short) (value >>> 16)); + } else { + _setShortLE(index, (short) (value >>> 16)); + _setShortLE(index + 2, (short) value); + } + } + + 
    @Override
    public CompositeByteBuf setLong(int index, long value) {
        checkIndex(index, 8);
        _setLong(index, value);
        return this;
    }

    @Override
    protected void _setLong(int index, long value) {
        Component c = findComponent0(index);
        if (index + 8 <= c.endOffset) {
            c.buf.setLong(c.idx(index), value);
        } else if (order() == ByteOrder.BIG_ENDIAN) {
            // Value straddles a component boundary: write as two ints.
            _setInt(index, (int) (value >>> 32));
            _setInt(index + 4, (int) value);
        } else {
            _setInt(index, (int) value);
            _setInt(index + 4, (int) (value >>> 32));
        }
    }

    @Override
    protected void _setLongLE(int index, long value) {
        Component c = findComponent0(index);
        if (index + 8 <= c.endOffset) {
            c.buf.setLongLE(c.idx(index), value);
        } else if (order() == ByteOrder.BIG_ENDIAN) {
            _setIntLE(index, (int) value);
            _setIntLE(index + 4, (int) (value >>> 32));
        } else {
            _setIntLE(index, (int) (value >>> 32));
            _setIntLE(index + 4, (int) value);
        }
    }

    @Override
    public CompositeByteBuf setBytes(int index, byte[] src, int srcIndex, int length) {
        checkSrcIndex(index, length, srcIndex, src.length);
        if (length == 0) {
            return this;
        }

        // Walk components, copying into each until length bytes have been written.
        int i = toComponentIndex0(index);
        while (length > 0) {
            Component c = components[i];
            int localLength = Math.min(length, c.endOffset - index);
            c.buf.setBytes(c.idx(index), src, srcIndex, localLength);
            index += localLength;
            srcIndex += localLength;
            length -= localLength;
            i ++;
        }
        return this;
    }

    @Override
    public CompositeByteBuf setBytes(int index, ByteBuffer src) {
        int limit = src.limit();
        int length = src.remaining();

        checkIndex(index, length);
        if (length == 0) {
            return this;
        }

        int i = toComponentIndex0(index);
        try {
            while (length > 0) {
                Component c = components[i];
                int localLength = Math.min(length, c.endOffset - index);
                // Temporarily cap the limit so exactly localLength bytes are consumed from src.
                src.limit(src.position() + localLength);
                c.buf.setBytes(c.idx(index), src);
                index += localLength;
                length -= localLength;
                i ++;
            }
        } finally {
            // Always restore the caller's limit.
            src.limit(limit);
        }
        return this;
    }

    @Override
    public CompositeByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) {
        checkSrcIndex(index, length, srcIndex, src.capacity());
        if (length == 0) {
            return this;
        }

        int i = toComponentIndex0(index);
        while (length > 0) {
            Component c = components[i];
            int localLength = Math.min(length, c.endOffset - index);
            c.buf.setBytes(c.idx(index), src, srcIndex, localLength);
            index += localLength;
            srcIndex += localLength;
            length -= localLength;
            i ++;
        }
        return this;
    }

    @Override
    public int setBytes(int index, InputStream in, int length) throws IOException {
        checkIndex(index, length);
        if (length == 0) {
            // Mirrors the stream contract: returns -1 on EOF, 0 otherwise.
            return in.read(EmptyArrays.EMPTY_BYTES);
        }

        int i = toComponentIndex0(index);
        int readBytes = 0;
        do {
            Component c = components[i];
            int localLength = Math.min(length, c.endOffset - index);
            if (localLength == 0) {
                // Skip empty buffer
                i++;
                continue;
            }
            int localReadBytes = c.buf.setBytes(c.idx(index), in, localLength);
            if (localReadBytes < 0) {
                // EOF: report -1 only if nothing was read at all.
                if (readBytes == 0) {
                    return -1;
                } else {
                    break;
                }
            }

            index += localReadBytes;
            length -= localReadBytes;
            readBytes += localReadBytes;
            if (localReadBytes == localLength) {
                // Component filled; advance to the next one.
                i ++;
            }
        } while (length > 0);

        return readBytes;
    }

    @Override
    public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException {
        checkIndex(index, length);
        if (length == 0) {
            return in.read(EMPTY_NIO_BUFFER);
        }

        int i = toComponentIndex0(index);
        int readBytes = 0;
        do {
            Component c = components[i];
            int localLength = Math.min(length, c.endOffset - index);
            if (localLength == 0) {
                // Skip empty buffer
                i++;
                continue;
            }
            int localReadBytes = c.buf.setBytes(c.idx(index), in, localLength);

            if (localReadBytes == 0) {
                // Non-blocking channel has no data available right now.
                break;
            }

            if (localReadBytes < 0) {
                if (readBytes == 0) {
                    return -1;
                } else {
                    break;
                }
            }

            index += localReadBytes;
            length -= localReadBytes;
            readBytes += localReadBytes;
            if (localReadBytes == localLength) {
                i ++;
            }
        } while (length > 0);

        return readBytes;
    }

    @Override
    public int setBytes(int index, FileChannel in, long position, int length) throws IOException {
        checkIndex(index, length);
        if (length == 0) {
            return in.read(EMPTY_NIO_BUFFER, position);
        }

        int i = toComponentIndex0(index);
        int readBytes = 0;
        do {
            Component c = components[i];
            int localLength = Math.min(length, c.endOffset - index);
            if (localLength == 0) {
                // Skip empty buffer
                i++;
                continue;
            }
            int localReadBytes = c.buf.setBytes(c.idx(index), in, position + readBytes, localLength);

            if (localReadBytes == 0) {
                break;
            }

            if (localReadBytes < 0) {
                if (readBytes == 0) {
                    return -1;
                } else {
                    break;
                }
            }

            index += localReadBytes;
            length -= localReadBytes;
            readBytes += localReadBytes;
            if (localReadBytes == localLength) {
                i ++;
            }
        } while (length > 0);

        return readBytes;
    }

    @Override
    public ByteBuf copy(int index, int length) {
        checkIndex(index, length);
        // Copies into a single freshly-allocated buffer.
        ByteBuf dst = allocBuffer(length);
        if (length != 0) {
            copyTo(index, length, toComponentIndex0(index), dst);
        }
        return dst;
    }

    // Copies length bytes starting at index (whose component is componentId) into dst,
    // then marks dst fully written.
    private void copyTo(int index, int length, int componentId, ByteBuf dst) {
        int dstIndex = 0;
        int i = componentId;

        while (length > 0) {
            Component c = components[i];
            int localLength = Math.min(length, c.endOffset - index);
            c.buf.getBytes(c.idx(index), dst, dstIndex, localLength);
            index += localLength;
            dstIndex += localLength;
            length -= localLength;
            i ++;
        }

        dst.writerIndex(dst.capacity());
    }

    /**
     * Return the {@link ByteBuf} on the specified index
     *
     * @param cIndex the index for which the {@link ByteBuf} should be returned
     * @return buf the {@link ByteBuf} on the specified index
     */
    public ByteBuf component(int cIndex) {
        checkComponentIndex(cIndex);
        return components[cIndex].duplicate();
    }

    /**
     * Return the {@link ByteBuf} on the specified index
     *
     * @param offset the offset for which the {@link ByteBuf} should be returned
     * @return the {@link ByteBuf} on the specified index
     */
    public ByteBuf componentAtOffset(int offset) {
        return findComponent(offset).duplicate();
    }

    /**
     * Return the internal {@link ByteBuf} on the specified index. Note that updating the indexes of the returned
     * buffer will lead to an undefined behavior of this buffer.
     *
     * @param cIndex the index for which the {@link ByteBuf} should be returned
     */
    public ByteBuf internalComponent(int cIndex) {
        checkComponentIndex(cIndex);
        return components[cIndex].slice();
    }

    /**
     * Return the internal {@link ByteBuf} on the specified offset. Note that updating the indexes of the returned
     * buffer will lead to an undefined behavior of this buffer.
     *
     * @param offset the offset for which the {@link ByteBuf} should be returned
     */
    public ByteBuf internalComponentAtOffset(int offset) {
        return findComponent(offset).slice();
    }

    // weak cache - check it first when looking for component
    private Component lastAccessed;

    // Finds the component containing offset, with range checking (used by public accessors).
    private Component findComponent(int offset) {
        Component la = lastAccessed;
        if (la != null && offset >= la.offset && offset < la.endOffset) {
            ensureAccessible();
            return la;
        }
        checkIndex(offset);
        return findIt(offset);
    }

    // Same as findComponent but without range checking (callers already validated the index).
    private Component findComponent0(int offset) {
        Component la = lastAccessed;
        if (la != null && offset >= la.offset && offset < la.endOffset) {
            return la;
        }
        return findIt(offset);
    }

    // Binary search for the component containing offset; caches the result in lastAccessed.
    private Component findIt(int offset) {
        for (int low = 0, high = componentCount; low <= high;) {
            int mid = low + high >>> 1;
            Component c = components[mid];
            if (c == null) {
                throw new IllegalStateException("No component found for offset. " +
                        "Composite buffer layout might be outdated, e.g. from a discardReadBytes call.");
            }
            if (offset >= c.endOffset) {
                low = mid + 1;
            } else if (offset < c.offset) {
                high = mid - 1;
            } else {
                lastAccessed = c;
                return c;
            }
        }

        throw new Error("should not reach here");
    }

    @Override
    public int nioBufferCount() {
        int size = componentCount;
        switch (size) {
        case 0:
            return 1;
        case 1:
            return components[0].buf.nioBufferCount();
        default:
            // Sum over all components.
            int count = 0;
            for (int i = 0; i < size; i++) {
                count += components[i].buf.nioBufferCount();
            }
            return count;
        }
    }

    @Override
    public ByteBuffer internalNioBuffer(int index, int length) {
        switch (componentCount) {
        case 0:
            return EMPTY_NIO_BUFFER;
        case 1:
            return components[0].internalNioBuffer(index, length);
        default:
            throw new UnsupportedOperationException();
        }
    }

    @Override
    public ByteBuffer nioBuffer(int index, int length) {
        checkIndex(index, length);

        switch (componentCount) {
        case 0:
            return EMPTY_NIO_BUFFER;
        case 1:
            Component c = components[0];
            ByteBuf buf = c.buf;
            if (buf.nioBufferCount() == 1) {
                return buf.nioBuffer(c.idx(index), length);
            }
            break;
        default:
            break;
        }

        // Multiple underlying buffers: merge them into one newly-allocated ByteBuffer.
        ByteBuffer[] buffers = nioBuffers(index, length);

        if (buffers.length == 1) {
            return buffers[0];
        }

        ByteBuffer merged = ByteBuffer.allocate(length).order(order());
        for (ByteBuffer buf: buffers) {
            merged.put(buf);
        }

        merged.flip();
        return merged;
    }

    @Override
    public ByteBuffer[] nioBuffers(int index, int length) {
        checkIndex(index, length);
        if (length == 0) {
            return new ByteBuffer[] { EMPTY_NIO_BUFFER };
        }

        RecyclableArrayList buffers = RecyclableArrayList.newInstance(componentCount);
        try {
            int i = toComponentIndex0(index);
            while (length > 0) {
                Component c = components[i];
                ByteBuf s = c.buf;
                int localLength = Math.min(length, c.endOffset - index);
                switch (s.nioBufferCount()) {
                case 0:
                    throw new UnsupportedOperationException();
                case 1:
                    buffers.add(s.nioBuffer(c.idx(index), localLength));
                    break;
                default:
                    Collections.addAll(buffers, s.nioBuffers(c.idx(index), localLength));
                }

                index += localLength;
                length -= localLength;
                i ++;
            }

            return buffers.toArray(EmptyArrays.EMPTY_BYTE_BUFFERS);
        } finally {
            buffers.recycle();
        }
    }

    /**
     * Consolidate the composed {@link ByteBuf}s
     */
    public CompositeByteBuf consolidate() {
        ensureAccessible();
        consolidate0(0, componentCount);
        return this;
    }

    /**
     * Consolidate the composed {@link ByteBuf}s
     *
     * @param cIndex the index on which to start to compose
     * @param numComponents the number of components to compose
     */
    public CompositeByteBuf consolidate(int cIndex, int numComponents) {
        checkComponentIndex(cIndex, numComponents);
        consolidate0(cIndex, numComponents);
        return this;
    }

    // Merges numComponents components starting at cIndex into a single newly-allocated component.
    private void consolidate0(int cIndex, int numComponents) {
        if (numComponents <= 1) {
            return;
        }

        final int endCIndex = cIndex + numComponents;
        final int startOffset = cIndex != 0 ? components[cIndex].offset : 0;
        final int capacity = components[endCIndex - 1].endOffset - startOffset;
        final ByteBuf consolidated = allocBuffer(capacity);

        for (int i = cIndex; i < endCIndex; i ++) {
            components[i].transferTo(consolidated);
        }
        lastAccessed = null;
        removeCompRange(cIndex + 1, endCIndex);
        components[cIndex] = newComponent(consolidated, 0);
        if (cIndex != 0 || numComponents != componentCount) {
            updateComponentOffsets(cIndex);
        }
    }

    /**
     * Discard all {@link ByteBuf}s which are read.
     */
    public CompositeByteBuf discardReadComponents() {
        ensureAccessible();
        final int readerIndex = readerIndex();
        if (readerIndex == 0) {
            return this;
        }

        // Discard everything if (readerIndex = writerIndex = capacity).
+ int writerIndex = writerIndex(); + if (readerIndex == writerIndex && writerIndex == capacity()) { + for (int i = 0, size = componentCount; i < size; i++) { + components[i].free(); + } + lastAccessed = null; + clearComps(); + setIndex(0, 0); + adjustMarkers(readerIndex); + return this; + } + + // Remove read components. + int firstComponentId = 0; + Component c = null; + for (int size = componentCount; firstComponentId < size; firstComponentId++) { + c = components[firstComponentId]; + if (c.endOffset > readerIndex) { + break; + } + c.free(); + } + if (firstComponentId == 0) { + return this; // Nothing to discard + } + Component la = lastAccessed; + if (la != null && la.endOffset <= readerIndex) { + lastAccessed = null; + } + removeCompRange(0, firstComponentId); + + // Update indexes and markers. + int offset = c.offset; + updateComponentOffsets(0); + setIndex(readerIndex - offset, writerIndex - offset); + adjustMarkers(offset); + return this; + } + + @Override + public CompositeByteBuf discardReadBytes() { + ensureAccessible(); + final int readerIndex = readerIndex(); + if (readerIndex == 0) { + return this; + } + + // Discard everything if (readerIndex = writerIndex = capacity). + int writerIndex = writerIndex(); + if (readerIndex == writerIndex && writerIndex == capacity()) { + for (int i = 0, size = componentCount; i < size; i++) { + components[i].free(); + } + lastAccessed = null; + clearComps(); + setIndex(0, 0); + adjustMarkers(readerIndex); + return this; + } + + int firstComponentId = 0; + Component c = null; + for (int size = componentCount; firstComponentId < size; firstComponentId++) { + c = components[firstComponentId]; + if (c.endOffset > readerIndex) { + break; + } + c.free(); + } + + // Replace the first readable component with a new slice. 
+ int trimmedBytes = readerIndex - c.offset; + c.offset = 0; + c.endOffset -= readerIndex; + c.srcAdjustment += readerIndex; + c.adjustment += readerIndex; + ByteBuf slice = c.slice; + if (slice != null) { + // We must replace the cached slice with a derived one to ensure that + // it can later be released properly in the case of PooledSlicedByteBuf. + c.slice = slice.slice(trimmedBytes, c.length()); + } + Component la = lastAccessed; + if (la != null && la.endOffset <= readerIndex) { + lastAccessed = null; + } + + removeCompRange(0, firstComponentId); + + // Update indexes and markers. + updateComponentOffsets(0); + setIndex(0, writerIndex - readerIndex); + adjustMarkers(readerIndex); + return this; + } + + private ByteBuf allocBuffer(int capacity) { + return direct ? alloc().directBuffer(capacity) : alloc().heapBuffer(capacity); + } + + @Override + public String toString() { + String result = super.toString(); + result = result.substring(0, result.length() - 1); + return result + ", components=" + componentCount + ')'; + } + + private static final class Component { + final ByteBuf srcBuf; // the originally added buffer + final ByteBuf buf; // srcBuf unwrapped zero or more times + + int srcAdjustment; // index of the start of this CompositeByteBuf relative to srcBuf + int adjustment; // index of the start of this CompositeByteBuf relative to buf + + int offset; // offset of this component within this CompositeByteBuf + int endOffset; // end offset of this component within this CompositeByteBuf + + private ByteBuf slice; // cached slice, may be null + + Component(ByteBuf srcBuf, int srcOffset, ByteBuf buf, int bufOffset, + int offset, int len, ByteBuf slice) { + this.srcBuf = srcBuf; + this.srcAdjustment = srcOffset - offset; + this.buf = buf; + this.adjustment = bufOffset - offset; + this.offset = offset; + this.endOffset = offset + len; + this.slice = slice; + } + + int srcIdx(int index) { + return index + srcAdjustment; + } + + int idx(int index) { + return 
index + adjustment; + } + + int length() { + return endOffset - offset; + } + + void reposition(int newOffset) { + int move = newOffset - offset; + endOffset += move; + srcAdjustment -= move; + adjustment -= move; + offset = newOffset; + } + + // copy then release + void transferTo(ByteBuf dst) { + dst.writeBytes(buf, idx(offset), length()); + free(); + } + + ByteBuf slice() { + ByteBuf s = slice; + if (s == null) { + slice = s = srcBuf.slice(srcIdx(offset), length()); + } + return s; + } + + ByteBuf duplicate() { + return srcBuf.duplicate(); + } + + ByteBuffer internalNioBuffer(int index, int length) { + // Some buffers override this so we must use srcBuf + return srcBuf.internalNioBuffer(srcIdx(index), length); + } + + void free() { + slice = null; + // Release the original buffer since it may have a different + // refcount to the unwrapped buf (e.g. if PooledSlicedByteBuf) + srcBuf.release(); + } + } + + @Override + public CompositeByteBuf readerIndex(int readerIndex) { + super.readerIndex(readerIndex); + return this; + } + + @Override + public CompositeByteBuf writerIndex(int writerIndex) { + super.writerIndex(writerIndex); + return this; + } + + @Override + public CompositeByteBuf setIndex(int readerIndex, int writerIndex) { + super.setIndex(readerIndex, writerIndex); + return this; + } + + @Override + public CompositeByteBuf clear() { + super.clear(); + return this; + } + + @Override + public CompositeByteBuf markReaderIndex() { + super.markReaderIndex(); + return this; + } + + @Override + public CompositeByteBuf resetReaderIndex() { + super.resetReaderIndex(); + return this; + } + + @Override + public CompositeByteBuf markWriterIndex() { + super.markWriterIndex(); + return this; + } + + @Override + public CompositeByteBuf resetWriterIndex() { + super.resetWriterIndex(); + return this; + } + + @Override + public CompositeByteBuf ensureWritable(int minWritableBytes) { + super.ensureWritable(minWritableBytes); + return this; + } + + @Override + public 
CompositeByteBuf getBytes(int index, ByteBuf dst) { + return getBytes(index, dst, dst.writableBytes()); + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst, int length) { + getBytes(index, dst, dst.writerIndex(), length); + dst.writerIndex(dst.writerIndex() + length); + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, byte[] dst) { + return getBytes(index, dst, 0, dst.length); + } + + @Override + public CompositeByteBuf setBoolean(int index, boolean value) { + return setByte(index, value? 1 : 0); + } + + @Override + public CompositeByteBuf setChar(int index, int value) { + return setShort(index, value); + } + + @Override + public CompositeByteBuf setFloat(int index, float value) { + return setInt(index, Float.floatToRawIntBits(value)); + } + + @Override + public CompositeByteBuf setDouble(int index, double value) { + return setLong(index, Double.doubleToRawLongBits(value)); + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src) { + super.setBytes(index, src, src.readableBytes()); + return this; + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src, int length) { + super.setBytes(index, src, length); + return this; + } + + @Override + public CompositeByteBuf setBytes(int index, byte[] src) { + return setBytes(index, src, 0, src.length); + } + + @Override + public CompositeByteBuf setZero(int index, int length) { + super.setZero(index, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst) { + super.readBytes(dst, dst.writableBytes()); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst, int length) { + super.readBytes(dst, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + super.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(byte[] dst) { + super.readBytes(dst, 0, dst.length); + 
return this; + } + + @Override + public CompositeByteBuf readBytes(byte[] dst, int dstIndex, int length) { + super.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuffer dst) { + super.readBytes(dst); + return this; + } + + @Override + public CompositeByteBuf readBytes(OutputStream out, int length) throws IOException { + super.readBytes(out, length); + return this; + } + + @Override + public CompositeByteBuf skipBytes(int length) { + super.skipBytes(length); + return this; + } + + @Override + public CompositeByteBuf writeBoolean(boolean value) { + writeByte(value ? 1 : 0); + return this; + } + + @Override + public CompositeByteBuf writeByte(int value) { + ensureWritable0(1); + _setByte(writerIndex++, value); + return this; + } + + @Override + public CompositeByteBuf writeShort(int value) { + super.writeShort(value); + return this; + } + + @Override + public CompositeByteBuf writeMedium(int value) { + super.writeMedium(value); + return this; + } + + @Override + public CompositeByteBuf writeInt(int value) { + super.writeInt(value); + return this; + } + + @Override + public CompositeByteBuf writeLong(long value) { + super.writeLong(value); + return this; + } + + @Override + public CompositeByteBuf writeChar(int value) { + super.writeShort(value); + return this; + } + + @Override + public CompositeByteBuf writeFloat(float value) { + super.writeInt(Float.floatToRawIntBits(value)); + return this; + } + + @Override + public CompositeByteBuf writeDouble(double value) { + super.writeLong(Double.doubleToRawLongBits(value)); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src) { + super.writeBytes(src, src.readableBytes()); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src, int length) { + super.writeBytes(src, length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + super.writeBytes(src, 
srcIndex, length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(byte[] src) { + super.writeBytes(src, 0, src.length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(byte[] src, int srcIndex, int length) { + super.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuffer src) { + super.writeBytes(src); + return this; + } + + @Override + public CompositeByteBuf writeZero(int length) { + super.writeZero(length); + return this; + } + + @Override + public CompositeByteBuf retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public CompositeByteBuf retain() { + super.retain(); + return this; + } + + @Override + public CompositeByteBuf touch() { + return this; + } + + @Override + public CompositeByteBuf touch(Object hint) { + return this; + } + + @Override + public ByteBuffer[] nioBuffers() { + return nioBuffers(readerIndex(), readableBytes()); + } + + @Override + public CompositeByteBuf discardSomeReadBytes() { + return discardReadComponents(); + } + + @Override + protected void deallocate() { + if (freed) { + return; + } + + freed = true; + // We're not using foreach to avoid creating an iterator. 
+ // see https://github.com/netty/netty/issues/2642 + for (int i = 0, size = componentCount; i < size; i++) { + components[i].free(); + } + } + + @Override + boolean isAccessible() { + return !freed; + } + + @Override + public ByteBuf unwrap() { + return null; + } + + private final class CompositeByteBufIterator implements Iterator { + private final int size = numComponents(); + private int index; + + @Override + public boolean hasNext() { + return size > index; + } + + @Override + public ByteBuf next() { + if (size != numComponents()) { + throw new ConcurrentModificationException(); + } + if (!hasNext()) { + throw new NoSuchElementException(); + } + try { + return components[index++].slice(); + } catch (IndexOutOfBoundsException e) { + throw new ConcurrentModificationException(); + } + } + + @Override + public void remove() { + throw new UnsupportedOperationException("Read-Only"); + } + } + + // Component array manipulation - range checking omitted + + private void clearComps() { + removeCompRange(0, componentCount); + } + + private void removeComp(int i) { + removeCompRange(i, i + 1); + } + + private void removeCompRange(int from, int to) { + if (from >= to) { + return; + } + final int size = componentCount; + assert from >= 0 && to <= size; + if (to < size) { + System.arraycopy(components, to, components, from, size - to); + } + int newSize = size - to + from; + for (int i = newSize; i < size; i++) { + components[i] = null; + } + componentCount = newSize; + } + + private void addComp(int i, Component c) { + shiftComps(i, 1); + components[i] = c; + } + + private void shiftComps(int i, int count) { + final int size = componentCount, newSize = size + count; + assert i >= 0 && i <= size && count > 0; + if (newSize > components.length) { + // grow the array + int newArrSize = Math.max(size + (size >> 1), newSize); + Component[] newArr; + if (i == size) { + newArr = Arrays.copyOf(components, newArrSize, Component[].class); + } else { + newArr = new 
Component[newArrSize]; + if (i > 0) { + System.arraycopy(components, 0, newArr, 0, i); + } + if (i < size) { + System.arraycopy(components, i, newArr, i + count, size - i); + } + } + components = newArr; + } else if (i < size) { + System.arraycopy(components, i, components, i + count, size - i); + } + componentCount = newSize; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/DefaultByteBufHolder.java b/netty-buffer/src/main/java/io/netty/buffer/DefaultByteBufHolder.java new file mode 100644 index 0000000..0dc1d4c --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/DefaultByteBufHolder.java @@ -0,0 +1,158 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +/** + * Default implementation of a {@link ByteBufHolder} that holds it's data in a {@link ByteBuf}. + * + */ +public class DefaultByteBufHolder implements ByteBufHolder { + + private final ByteBuf data; + + public DefaultByteBufHolder(ByteBuf data) { + this.data = ObjectUtil.checkNotNull(data, "data"); + } + + @Override + public ByteBuf content() { + return ByteBufUtil.ensureAccessible(data); + } + + /** + * {@inheritDoc} + *
    /**
     * {@inheritDoc}
     * <p>
     * This method calls {@code replace(content().copy())} by default.
     */
    @Override
    public ByteBufHolder copy() {
        return replace(data.copy());
    }

    /**
     * {@inheritDoc}
     * <p>
     * This method calls {@code replace(content().duplicate())} by default.
     */
    @Override
    public ByteBufHolder duplicate() {
        return replace(data.duplicate());
    }

    /**
     * {@inheritDoc}
     * <p>
     * This method calls {@code replace(content().retainedDuplicate())} by default.
     */
    @Override
    public ByteBufHolder retainedDuplicate() {
        return replace(data.retainedDuplicate());
    }

    /**
     * {@inheritDoc}
     * <p>
     * Override this method to return a new instance of this object whose content is set to the specified
     * {@code content}. The default implementation of {@link #copy()}, {@link #duplicate()} and
     * {@link #retainedDuplicate()} invokes this method to create a copy.
     */
    @Override
    public ByteBufHolder replace(ByteBuf content) {
        return new DefaultByteBufHolder(content);
    }

    // Reference counting is delegated wholesale to the wrapped buffer.

    @Override
    public int refCnt() {
        return data.refCnt();
    }

    @Override
    public ByteBufHolder retain() {
        data.retain();
        return this;
    }

    @Override
    public ByteBufHolder retain(int increment) {
        data.retain(increment);
        return this;
    }

    @Override
    public ByteBufHolder touch() {
        data.touch();
        return this;
    }

    @Override
    public ByteBufHolder touch(Object hint) {
        data.touch(hint);
        return this;
    }

    @Override
    public boolean release() {
        return data.release();
    }

    @Override
    public boolean release(int decrement) {
        return data.release(decrement);
    }

    /**
     * Return {@link ByteBuf#toString()} without checking the reference count first. This is useful to implement
     * {@link #toString()}.
     */
    protected final String contentToString() {
        return data.toString();
    }

    @Override
    public String toString() {
        return StringUtil.simpleClassName(this) + '(' + contentToString() + ')';
    }

    /**
     * This implementation of the {@code equals} operation is restricted to
     * work only with instances of the same class. The reason for that is that
     * Netty library already has a number of classes that extend {@link DefaultByteBufHolder} and
     * override {@code equals} method with an additional comparison logic and we
     * need the symmetric property of the {@code equals} operation to be preserved.
     *
     * @param o the reference object with which to compare.
     * @return {@code true} if this object is the same as the obj
     *         argument; {@code false} otherwise.
     */
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }
        // getClass() comparison (not instanceof) keeps equals symmetric across
        // the subclass hierarchy, as documented above.
        if (o != null && getClass() == o.getClass()) {
            return data.equals(((DefaultByteBufHolder) o).data);
        }
        return false;
    }

    @Override
    public int hashCode() {
        return data.hashCode();
    }
}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/DuplicatedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/DuplicatedByteBuf.java
new file mode 100644
index 0000000..59862a4
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/DuplicatedByteBuf.java
@@ -0,0 +1,410 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

import io.netty.util.ByteProcessor;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.FileChannel;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ScatteringByteChannel;

/**
 * A derived buffer which simply forwards all data access requests to its
 * parent. It is recommended to use {@link ByteBuf#duplicate()} instead
 * of calling the constructor explicitly.
 *
 * @deprecated Do not use.
 */
@Deprecated
public class DuplicatedByteBuf extends AbstractDerivedByteBuf {

    private final ByteBuf buffer; // the buffer every operation is forwarded to

    public DuplicatedByteBuf(ByteBuf buffer) {
        this(buffer, buffer.readerIndex(), buffer.writerIndex());
    }

    DuplicatedByteBuf(ByteBuf buffer, int readerIndex, int writerIndex) {
        super(buffer.maxCapacity());

        // Collapse nested duplicates/derived buffers so `buffer` always refers
        // to the underlying (non-derived) buffer.
        if (buffer instanceof DuplicatedByteBuf) {
            this.buffer = ((DuplicatedByteBuf) buffer).buffer;
        } else if (buffer instanceof AbstractPooledDerivedByteBuf) {
            this.buffer = buffer.unwrap();
        } else {
            this.buffer = buffer;
        }

        setIndex(readerIndex, writerIndex);
        markReaderIndex();
        markWriterIndex();
    }

    @Override
    public ByteBuf unwrap() {
        return buffer;
    }

    @Override
    public ByteBufAllocator alloc() {
        return unwrap().alloc();
    }

    @Override
    @Deprecated
    public ByteOrder order() {
        return unwrap().order();
    }

    @Override
    public boolean isDirect() {
        return unwrap().isDirect();
    }

    @Override
    public int capacity() {
        return unwrap().capacity();
    }

    @Override
    public ByteBuf capacity(int newCapacity) {
        unwrap().capacity(newCapacity);
        return this;
    }

    @Override
    public boolean hasArray() {
        return unwrap().hasArray();
    }

    @Override
    public byte[] array() {
        return unwrap().array();
    }

    @Override
    public int arrayOffset() {
        return unwrap().arrayOffset();
    }

    @Override
    public boolean hasMemoryAddress() {
        return unwrap().hasMemoryAddress();
    }

    @Override
    public long memoryAddress() {
        return unwrap().memoryAddress();
    }

    // All accessors below forward straight to the parent. Note that the
    // protected _getX/_setX hooks forward to the parent's *public* methods,
    // so the parent performs its own bounds checking.

    @Override
    public byte getByte(int index) {
        return unwrap().getByte(index);
    }

    @Override
    protected byte _getByte(int index) {
        return unwrap().getByte(index);
    }

    @Override
    public short getShort(int index) {
        return unwrap().getShort(index);
    }

    @Override
    protected short _getShort(int index) {
        return unwrap().getShort(index);
    }

    @Override
    public short getShortLE(int index) {
        return unwrap().getShortLE(index);
    }

    @Override
    protected short _getShortLE(int index) {
        return unwrap().getShortLE(index);
    }

    @Override
    public int getUnsignedMedium(int index) {
        return unwrap().getUnsignedMedium(index);
    }

    @Override
    protected int _getUnsignedMedium(int index) {
        return unwrap().getUnsignedMedium(index);
    }

    @Override
    public int getUnsignedMediumLE(int index) {
        return unwrap().getUnsignedMediumLE(index);
    }

    @Override
    protected int _getUnsignedMediumLE(int index) {
        return unwrap().getUnsignedMediumLE(index);
    }

    @Override
    public int getInt(int index) {
        return unwrap().getInt(index);
    }

    @Override
    protected int _getInt(int index) {
        return unwrap().getInt(index);
    }

    @Override
    public int getIntLE(int index) {
        return unwrap().getIntLE(index);
    }

    @Override
    protected int _getIntLE(int index) {
        return unwrap().getIntLE(index);
    }

    @Override
    public long getLong(int index) {
        return unwrap().getLong(index);
    }

    @Override
    protected long _getLong(int index) {
        return unwrap().getLong(index);
    }

    @Override
    public long getLongLE(int index) {
        return unwrap().getLongLE(index);
    }

    @Override
    protected long _getLongLE(int index) {
        return unwrap().getLongLE(index);
    }

    @Override
    public ByteBuf copy(int index, int length) {
        return unwrap().copy(index, length);
    }

    @Override
    public ByteBuf slice(int index, int length) {
        return unwrap().slice(index, length);
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) {
        unwrap().getBytes(index, dst, dstIndex, length);
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) {
        unwrap().getBytes(index, dst, dstIndex, length);
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuffer dst) {
        unwrap().getBytes(index, dst);
        return this;
    }

    @Override
    public ByteBuf setByte(int index, int value) {
        unwrap().setByte(index, value);
        return this;
    }

    @Override
    protected void _setByte(int index, int value) {
        unwrap().setByte(index, value);
    }

    @Override
    public ByteBuf setShort(int index, int value) {
        unwrap().setShort(index, value);
        return this;
    }

    @Override
    protected void _setShort(int index, int value) {
        unwrap().setShort(index, value);
    }

    @Override
    public ByteBuf setShortLE(int index, int value) {
        unwrap().setShortLE(index, value);
        return this;
    }

    @Override
    protected void _setShortLE(int index, int value) {
        unwrap().setShortLE(index, value);
    }

    @Override
    public ByteBuf setMedium(int index, int value) {
        unwrap().setMedium(index, value);
        return this;
    }

    @Override
    protected void _setMedium(int index, int value) {
        unwrap().setMedium(index, value);
    }

    @Override
    public ByteBuf setMediumLE(int index, int value) {
        unwrap().setMediumLE(index, value);
        return this;
    }

    @Override
    protected void _setMediumLE(int index, int value) {
        unwrap().setMediumLE(index, value);
    }

    @Override
    public ByteBuf setInt(int index, int value) {
        unwrap().setInt(index, value);
        return this;
    }

    @Override
    protected void _setInt(int index, int value) {
        unwrap().setInt(index, value);
    }

    @Override
    public ByteBuf setIntLE(int index, int value) {
        unwrap().setIntLE(index, value);
        return this;
    }

    @Override
    protected void _setIntLE(int index, int value) {
        unwrap().setIntLE(index, value);
    }

    @Override
    public ByteBuf setLong(int index, long value) {
        unwrap().setLong(index, value);
        return this;
    }

    @Override
    protected void _setLong(int index, long value) {
        unwrap().setLong(index, value);
    }

    @Override
    public ByteBuf setLongLE(int index, long value) {
        unwrap().setLongLE(index, value);
        return this;
    }

    @Override
    protected void _setLongLE(int index, long value) {
        unwrap().setLongLE(index, value);
    }

    @Override
    public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) {
        unwrap().setBytes(index, src, srcIndex, length);
        return this;
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) {
        unwrap().setBytes(index, src, srcIndex, length);
        return this;
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuffer src) {
        unwrap().setBytes(index, src);
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, OutputStream out, int length)
            throws IOException {
        unwrap().getBytes(index, out, length);
        return this;
    }

    @Override
    public int getBytes(int index, GatheringByteChannel out, int length)
            throws IOException {
        return unwrap().getBytes(index, out, length);
    }

    @Override
    public int getBytes(int index, FileChannel out, long position, int length)
            throws IOException {
        return unwrap().getBytes(index, out, position, length);
    }

    @Override
    public int setBytes(int index, InputStream in, int length)
            throws IOException {
        return unwrap().setBytes(index, in, length);
    }

    @Override
    public int setBytes(int index, ScatteringByteChannel in, int length)
            throws IOException {
        return unwrap().setBytes(index, in, length);
    }

    @Override
    public int setBytes(int index, FileChannel in, long position, int length)
            throws IOException {
        return unwrap().setBytes(index, in, position, length);
    }

    @Override
    public int nioBufferCount() {
        return unwrap().nioBufferCount();
    }

    @Override
    public ByteBuffer[] nioBuffers(int index, int length) {
        return unwrap().nioBuffers(index, length);
    }

    @Override
    public int forEachByte(int index, int length, ByteProcessor processor) {
        return unwrap().forEachByte(index, length, processor);
    }

    @Override
    public int forEachByteDesc(int index, int length, ByteProcessor processor) {
        return unwrap().forEachByteDesc(index, length, processor);
    }
}

diff --git a/netty-buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java
b/netty-buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java new file mode 100644 index 0000000..6a0ddb5 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/EmptyByteBuf.java @@ -0,0 +1,1062 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.util.ByteProcessor; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; + +/** + * An empty {@link ByteBuf} whose capacity and maximum capacity are all {@code 0}. 
+ */ +public final class EmptyByteBuf extends ByteBuf { + + static final int EMPTY_BYTE_BUF_HASH_CODE = 1; + private static final ByteBuffer EMPTY_BYTE_BUFFER = ByteBuffer.allocateDirect(0); + private static final long EMPTY_BYTE_BUFFER_ADDRESS; + + static { + long emptyByteBufferAddress = 0; + try { + if (PlatformDependent.hasUnsafe()) { + emptyByteBufferAddress = PlatformDependent.directBufferAddress(EMPTY_BYTE_BUFFER); + } + } catch (Throwable t) { + // Ignore + } + EMPTY_BYTE_BUFFER_ADDRESS = emptyByteBufferAddress; + } + + private final ByteBufAllocator alloc; + private final ByteOrder order; + private final String str; + private EmptyByteBuf swapped; + + public EmptyByteBuf(ByteBufAllocator alloc) { + this(alloc, ByteOrder.BIG_ENDIAN); + } + + private EmptyByteBuf(ByteBufAllocator alloc, ByteOrder order) { + this.alloc = ObjectUtil.checkNotNull(alloc, "alloc"); + this.order = order; + str = StringUtil.simpleClassName(this) + (order == ByteOrder.BIG_ENDIAN? "BE" : "LE"); + } + + @Override + public int capacity() { + return 0; + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBufAllocator alloc() { + return alloc; + } + + @Override + public ByteOrder order() { + return order; + } + + @Override + public ByteBuf unwrap() { + return null; + } + + @Override + public ByteBuf asReadOnly() { + return Unpooled.unmodifiableBuffer(this); + } + + @Override + public boolean isReadOnly() { + return false; + } + + @Override + public boolean isDirect() { + return true; + } + + @Override + public int maxCapacity() { + return 0; + } + + @Override + public ByteBuf order(ByteOrder endianness) { + if (ObjectUtil.checkNotNull(endianness, "endianness") == order()) { + return this; + } + + EmptyByteBuf swapped = this.swapped; + if (swapped != null) { + return swapped; + } + + this.swapped = swapped = new EmptyByteBuf(alloc(), endianness); + return swapped; + } + + @Override + public int readerIndex() { 
+ return 0; + } + + @Override + public ByteBuf readerIndex(int readerIndex) { + return checkIndex(readerIndex); + } + + @Override + public int writerIndex() { + return 0; + } + + @Override + public ByteBuf writerIndex(int writerIndex) { + return checkIndex(writerIndex); + } + + @Override + public ByteBuf setIndex(int readerIndex, int writerIndex) { + checkIndex(readerIndex); + checkIndex(writerIndex); + return this; + } + + @Override + public int readableBytes() { + return 0; + } + + @Override + public int writableBytes() { + return 0; + } + + @Override + public int maxWritableBytes() { + return 0; + } + + @Override + public boolean isReadable() { + return false; + } + + @Override + public boolean isWritable() { + return false; + } + + @Override + public ByteBuf clear() { + return this; + } + + @Override + public ByteBuf markReaderIndex() { + return this; + } + + @Override + public ByteBuf resetReaderIndex() { + return this; + } + + @Override + public ByteBuf markWriterIndex() { + return this; + } + + @Override + public ByteBuf resetWriterIndex() { + return this; + } + + @Override + public ByteBuf discardReadBytes() { + return this; + } + + @Override + public ByteBuf discardSomeReadBytes() { + return this; + } + + @Override + public ByteBuf ensureWritable(int minWritableBytes) { + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); + if (minWritableBytes != 0) { + throw new IndexOutOfBoundsException(); + } + return this; + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + checkPositiveOrZero(minWritableBytes, "minWritableBytes"); + + if (minWritableBytes == 0) { + return 0; + } + + return 1; + } + + @Override + public boolean getBoolean(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public byte getByte(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public short getUnsignedByte(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public short 
getShort(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public short getShortLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getUnsignedShort(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getUnsignedShortLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getMedium(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getMediumLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getUnsignedMedium(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getUnsignedMediumLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getInt(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int getIntLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public long getUnsignedInt(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public long getUnsignedIntLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public long getLong(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public long getLongLE(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public char getChar(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public float getFloat(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public double getDouble(int index) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst) { + return checkIndex(index, dst.writableBytes()); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + return checkIndex(index, 
length); + } + + @Override + public ByteBuf getBytes(int index, byte[] dst) { + return checkIndex(index, dst.length); + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + return checkIndex(index, dst.remaining()); + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) { + return checkIndex(index, length); + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) { + checkIndex(index, length); + return 0; + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) { + checkIndex(index, length); + return 0; + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + checkIndex(index, length); + return null; + } + + @Override + public ByteBuf setBoolean(int index, boolean value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setByte(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setShort(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setMedium(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setInt(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setLong(int index, long value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + throw new 
IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setChar(int index, int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setFloat(int index, float value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setDouble(int index, double value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf setBytes(int index, byte[] src) { + return checkIndex(index, src.length); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + return checkIndex(index, src.remaining()); + } + + @Override + public int setBytes(int index, InputStream in, int length) { + checkIndex(index, length); + return 0; + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) { + checkIndex(index, length); + return 0; + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) { + checkIndex(index, length); + return 0; + } + + @Override + public ByteBuf setZero(int index, int length) { + return checkIndex(index, length); + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + throw new IndexOutOfBoundsException(); + } + + @Override + public boolean readBoolean() { + throw new IndexOutOfBoundsException(); + } + + @Override + public byte readByte() { + throw new IndexOutOfBoundsException(); + } + + @Override + public short readUnsignedByte() { + throw new IndexOutOfBoundsException(); + } + + 
@Override + public short readShort() { + throw new IndexOutOfBoundsException(); + } + + @Override + public short readShortLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readUnsignedShort() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readUnsignedShortLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readMedium() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readMediumLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readUnsignedMedium() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readUnsignedMediumLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readInt() { + throw new IndexOutOfBoundsException(); + } + + @Override + public int readIntLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public long readUnsignedInt() { + throw new IndexOutOfBoundsException(); + } + + @Override + public long readUnsignedIntLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public long readLong() { + throw new IndexOutOfBoundsException(); + } + + @Override + public long readLongLE() { + throw new IndexOutOfBoundsException(); + } + + @Override + public char readChar() { + throw new IndexOutOfBoundsException(); + } + + @Override + public float readFloat() { + throw new IndexOutOfBoundsException(); + } + + @Override + public double readDouble() { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf readBytes(int length) { + return checkLength(length); + } + + @Override + public ByteBuf readSlice(int length) { + return checkLength(length); + } + + @Override + public ByteBuf readRetainedSlice(int length) { + return checkLength(length); + } + + @Override + public ByteBuf readBytes(ByteBuf dst) { + return checkLength(dst.writableBytes()); + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int length) { + return 
checkLength(length); + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + return checkLength(length); + } + + @Override + public ByteBuf readBytes(byte[] dst) { + return checkLength(dst.length); + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + return checkLength(length); + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + return checkLength(dst.remaining()); + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) { + return checkLength(length); + } + + @Override + public int readBytes(GatheringByteChannel out, int length) { + checkLength(length); + return 0; + } + + @Override + public int readBytes(FileChannel out, long position, int length) { + checkLength(length); + return 0; + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + checkLength(length); + return StringUtil.EMPTY_STRING; + } + + @Override + public ByteBuf skipBytes(int length) { + return checkLength(length); + } + + @Override + public ByteBuf writeBoolean(boolean value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeByte(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeShort(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeShortLE(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeMedium(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeMediumLE(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeInt(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeIntLE(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeLong(long value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeLongLE(long value) { + 
throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeChar(int value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeFloat(float value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeDouble(double value) { + throw new IndexOutOfBoundsException(); + } + + @Override + public ByteBuf writeBytes(ByteBuf src) { + return checkLength(src.readableBytes()); + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int length) { + return checkLength(length); + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + return checkLength(length); + } + + @Override + public ByteBuf writeBytes(byte[] src) { + return checkLength(src.length); + } + + @Override + public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { + return checkLength(length); + } + + @Override + public ByteBuf writeBytes(ByteBuffer src) { + return checkLength(src.remaining()); + } + + @Override + public int writeBytes(InputStream in, int length) { + checkLength(length); + return 0; + } + + @Override + public int writeBytes(ScatteringByteChannel in, int length) { + checkLength(length); + return 0; + } + + @Override + public int writeBytes(FileChannel in, long position, int length) { + checkLength(length); + return 0; + } + + @Override + public ByteBuf writeZero(int length) { + return checkLength(length); + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + throw new IndexOutOfBoundsException(); + } + + @Override + public int indexOf(int fromIndex, int toIndex, byte value) { + checkIndex(fromIndex); + checkIndex(toIndex); + return -1; + } + + @Override + public int bytesBefore(byte value) { + return -1; + } + + @Override + public int bytesBefore(int length, byte value) { + checkLength(length); + return -1; + } + + @Override + public int bytesBefore(int index, int length, byte value) { + checkIndex(index, length); + return -1; + } + + 
@Override + public int forEachByte(ByteProcessor processor) { + return -1; + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + checkIndex(index, length); + return -1; + } + + @Override + public int forEachByteDesc(ByteProcessor processor) { + return -1; + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + checkIndex(index, length); + return -1; + } + + @Override + public ByteBuf copy() { + return this; + } + + @Override + public ByteBuf copy(int index, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf slice() { + return this; + } + + @Override + public ByteBuf retainedSlice() { + return this; + } + + @Override + public ByteBuf slice(int index, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + return checkIndex(index, length); + } + + @Override + public ByteBuf duplicate() { + return this; + } + + @Override + public ByteBuf retainedDuplicate() { + return this; + } + + @Override + public int nioBufferCount() { + return 1; + } + + @Override + public ByteBuffer nioBuffer() { + return EMPTY_BYTE_BUFFER; + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + checkIndex(index, length); + return nioBuffer(); + } + + @Override + public ByteBuffer[] nioBuffers() { + return new ByteBuffer[] { EMPTY_BYTE_BUFFER }; + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + checkIndex(index, length); + return nioBuffers(); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + return EMPTY_BYTE_BUFFER; + } + + @Override + public boolean hasArray() { + return true; + } + + @Override + public byte[] array() { + return EmptyArrays.EMPTY_BYTES; + } + + @Override + public int arrayOffset() { + return 0; + } + + @Override + public boolean hasMemoryAddress() { + return EMPTY_BYTE_BUFFER_ADDRESS != 0; + } + + @Override + 
public long memoryAddress() { + if (hasMemoryAddress()) { + return EMPTY_BYTE_BUFFER_ADDRESS; + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean isContiguous() { + return true; + } + + @Override + public String toString(Charset charset) { + return ""; + } + + @Override + public String toString(int index, int length, Charset charset) { + checkIndex(index, length); + return toString(charset); + } + + @Override + public int hashCode() { + return EMPTY_BYTE_BUF_HASH_CODE; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof ByteBuf && !((ByteBuf) obj).isReadable(); + } + + @Override + public int compareTo(ByteBuf buffer) { + return buffer.isReadable()? -1 : 0; + } + + @Override + public String toString() { + return str; + } + + @Override + public boolean isReadable(int size) { + return false; + } + + @Override + public boolean isWritable(int size) { + return false; + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public ByteBuf retain() { + return this; + } + + @Override + public ByteBuf retain(int increment) { + return this; + } + + @Override + public ByteBuf touch() { + return this; + } + + @Override + public ByteBuf touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } + + private ByteBuf checkIndex(int index) { + if (index != 0) { + throw new IndexOutOfBoundsException(); + } + return this; + } + + private ByteBuf checkIndex(int index, int length) { + checkPositiveOrZero(length, "length"); + if (index != 0 || length != 0) { + throw new IndexOutOfBoundsException(); + } + return this; + } + + private ByteBuf checkLength(int length) { + checkPositiveOrZero(length, "length"); + if (length != 0) { + throw new IndexOutOfBoundsException(); + } + return this; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java 
b/netty-buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java
new file mode 100644
index 0000000..c9f397f
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/FixedCompositeByteBuf.java
@@ -0,0 +1,688 @@
+/*
+ * Copyright 2013 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.buffer;
+
+import io.netty.util.internal.EmptyArrays;
+import io.netty.util.internal.RecyclableArrayList;
+
+import java.io.IOException;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.nio.ByteBuffer;
+import java.nio.ByteOrder;
+import java.nio.ReadOnlyBufferException;
+import java.nio.channels.FileChannel;
+import java.nio.channels.GatheringByteChannel;
+import java.nio.channels.ScatteringByteChannel;
+import java.util.Collections;
+
+/**
+ * {@link ByteBuf} implementation which allows wrapping an array of {@link ByteBuf} in a read-only mode.
+ * This is useful for writing an array of {@link ByteBuf}s.
+ */
+final class FixedCompositeByteBuf extends AbstractReferenceCountedByteBuf {
+    private static final ByteBuf[] EMPTY = { Unpooled.EMPTY_BUFFER };
+    private final int nioBufferCount;
+    private final int capacity;
+    private final ByteBufAllocator allocator;
+    private final ByteOrder order;
+    private final ByteBuf[] buffers;
+    private final boolean direct;
+
+    FixedCompositeByteBuf(ByteBufAllocator allocator, ByteBuf...
buffers) { + super(AbstractByteBufAllocator.DEFAULT_MAX_CAPACITY); + if (buffers.length == 0) { + this.buffers = EMPTY; + order = ByteOrder.BIG_ENDIAN; + nioBufferCount = 1; + capacity = 0; + direct = Unpooled.EMPTY_BUFFER.isDirect(); + } else { + ByteBuf b = buffers[0]; + this.buffers = buffers; + boolean direct = true; + int nioBufferCount = b.nioBufferCount(); + int capacity = b.readableBytes(); + order = b.order(); + for (int i = 1; i < buffers.length; i++) { + b = buffers[i]; + if (buffers[i].order() != order) { + throw new IllegalArgumentException("All ByteBufs need to have same ByteOrder"); + } + nioBufferCount += b.nioBufferCount(); + capacity += b.readableBytes(); + if (!b.isDirect()) { + direct = false; + } + } + this.nioBufferCount = nioBufferCount; + this.capacity = capacity; + this.direct = direct; + } + setIndex(0, capacity()); + this.allocator = allocator; + } + + @Override + public boolean isWritable() { + return false; + } + + @Override + public boolean isWritable(int size) { + return false; + } + + @Override + public ByteBuf discardReadBytes() { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setByte(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setByte(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setShort(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setShort(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setShortLE(int index, int value) { 
+ throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setMedium(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setMedium(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setMediumLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setInt(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setInt(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setIntLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setLong(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setLong(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setLongLE(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, InputStream in, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public int capacity() { + return capacity; + } + + @Override + public int maxCapacity() { + return capacity; + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } + + @Override + public ByteOrder order() { + return order; + } + + @Override + public ByteBuf unwrap() { + return null; + } + + @Override + public boolean isDirect() { + return direct; + } + + private Component findComponent(int index) { + int readable = 0; + for (int i = 0 ; i < 
buffers.length; i++) {
+            Component comp = null;
+            ByteBuf b = buffers[i];
+            if (b instanceof Component) {
+                comp = (Component) b;
+                b = comp.buf;
+            }
+            readable += b.readableBytes();
+            if (index < readable) {
+                if (comp == null) {
+                    // Create a new component and store it in the array so it does not create a new object
+                    // on the next access.
+                    comp = new Component(i, readable - b.readableBytes(), b);
+                    buffers[i] = comp;
+                }
+                return comp;
+            }
+        }
+        throw new IllegalStateException();
+    }
+
+    /**
+     * Return the {@link ByteBuf} stored at the given index of the array.
+     */
+    private ByteBuf buffer(int i) {
+        ByteBuf b = buffers[i];
+        return b instanceof Component ? ((Component) b).buf : b;
+    }
+
+    @Override
+    public byte getByte(int index) {
+        return _getByte(index);
+    }
+
+    @Override
+    protected byte _getByte(int index) {
+        Component c = findComponent(index);
+        return c.buf.getByte(index - c.offset);
+    }
+
+    @Override
+    protected short _getShort(int index) {
+        Component c = findComponent(index);
+        if (index + 2 <= c.endOffset) {
+            return c.buf.getShort(index - c.offset);
+        } else if (order() == ByteOrder.BIG_ENDIAN) {
+            return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff);
+        } else {
+            return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8);
+        }
+    }
+
+    @Override
+    protected short _getShortLE(int index) {
+        Component c = findComponent(index);
+        if (index + 2 <= c.endOffset) {
+            return c.buf.getShortLE(index - c.offset);
+        } else if (order() == ByteOrder.BIG_ENDIAN) {
+            return (short) (_getByte(index) & 0xff | (_getByte(index + 1) & 0xff) << 8);
+        } else {
+            return (short) ((_getByte(index) & 0xff) << 8 | _getByte(index + 1) & 0xff);
+        }
+    }
+
+    @Override
+    protected int _getUnsignedMedium(int index) {
+        Component c = findComponent(index);
+        if (index + 3 <= c.endOffset) {
+            return c.buf.getUnsignedMedium(index - c.offset);
+        } else if (order() == ByteOrder.BIG_ENDIAN) {
+            return (_getShort(index) & 0xffff) << 8 | _getByte(index
+ 2) & 0xff; + } else { + return _getShort(index) & 0xFFFF | (_getByte(index + 2) & 0xFF) << 16; + } + } + + @Override + protected int _getUnsignedMediumLE(int index) { + Component c = findComponent(index); + if (index + 3 <= c.endOffset) { + return c.buf.getUnsignedMediumLE(index - c.offset); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return _getShortLE(index) & 0xffff | (_getByte(index + 2) & 0xff) << 16; + } else { + return (_getShortLE(index) & 0xffff) << 8 | _getByte(index + 2) & 0xff; + } + } + + @Override + protected int _getInt(int index) { + Component c = findComponent(index); + if (index + 4 <= c.endOffset) { + return c.buf.getInt(index - c.offset); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (_getShort(index) & 0xffff) << 16 | _getShort(index + 2) & 0xffff; + } else { + return _getShort(index) & 0xFFFF | (_getShort(index + 2) & 0xFFFF) << 16; + } + } + + @Override + protected int _getIntLE(int index) { + Component c = findComponent(index); + if (index + 4 <= c.endOffset) { + return c.buf.getIntLE(index - c.offset); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return _getShortLE(index) & 0xFFFF | (_getShortLE(index + 2) & 0xFFFF) << 16; + } else { + return (_getShortLE(index) & 0xffff) << 16 | _getShortLE(index + 2) & 0xffff; + } + } + + @Override + protected long _getLong(int index) { + Component c = findComponent(index); + if (index + 8 <= c.endOffset) { + return c.buf.getLong(index - c.offset); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return (_getInt(index) & 0xffffffffL) << 32 | _getInt(index + 4) & 0xffffffffL; + } else { + return _getInt(index) & 0xFFFFFFFFL | (_getInt(index + 4) & 0xFFFFFFFFL) << 32; + } + } + + @Override + protected long _getLongLE(int index) { + Component c = findComponent(index); + if (index + 8 <= c.endOffset) { + return c.buf.getLongLE(index - c.offset); + } else if (order() == ByteOrder.BIG_ENDIAN) { + return _getIntLE(index) & 0xffffffffL | (_getIntLE(index + 4) & 0xffffffffL) << 32; + } 
else { + return (_getIntLE(index) & 0xffffffffL) << 32 | _getIntLE(index + 4) & 0xffffffffL; + } + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.length); + if (length == 0) { + return this; + } + + Component c = findComponent(index); + int i = c.index; + int adjustment = c.offset; + ByteBuf s = c.buf; + for (;;) { + int localLength = Math.min(length, s.readableBytes() - (index - adjustment)); + s.getBytes(index - adjustment, dst, dstIndex, localLength); + index += localLength; + dstIndex += localLength; + length -= localLength; + adjustment += s.readableBytes(); + if (length <= 0) { + break; + } + s = buffer(++i); + } + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + int limit = dst.limit(); + int length = dst.remaining(); + + checkIndex(index, length); + if (length == 0) { + return this; + } + + try { + Component c = findComponent(index); + int i = c.index; + int adjustment = c.offset; + ByteBuf s = c.buf; + for (;;) { + int localLength = Math.min(length, s.readableBytes() - (index - adjustment)); + dst.limit(dst.position() + localLength); + s.getBytes(index - adjustment, dst); + index += localLength; + length -= localLength; + adjustment += s.readableBytes(); + if (length <= 0) { + break; + } + s = buffer(++i); + } + } finally { + dst.limit(limit); + } + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (length == 0) { + return this; + } + + Component c = findComponent(index); + int i = c.index; + int adjustment = c.offset; + ByteBuf s = c.buf; + for (;;) { + int localLength = Math.min(length, s.readableBytes() - (index - adjustment)); + s.getBytes(index - adjustment, dst, dstIndex, localLength); + index += localLength; + dstIndex += localLength; + length -= localLength; + adjustment += s.readableBytes(); + if 
(length <= 0) { + break; + } + s = buffer(++i); + } + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) + throws IOException { + int count = nioBufferCount(); + if (count == 1) { + return out.write(internalNioBuffer(index, length)); + } else { + long writtenBytes = out.write(nioBuffers(index, length)); + if (writtenBytes > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) writtenBytes; + } + } + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) + throws IOException { + int count = nioBufferCount(); + if (count == 1) { + return out.write(internalNioBuffer(index, length), position); + } else { + long writtenBytes = 0; + for (ByteBuffer buf : nioBuffers(index, length)) { + writtenBytes += out.write(buf, position + writtenBytes); + } + if (writtenBytes > Integer.MAX_VALUE) { + return Integer.MAX_VALUE; + } else { + return (int) writtenBytes; + } + } + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + checkIndex(index, length); + if (length == 0) { + return this; + } + + Component c = findComponent(index); + int i = c.index; + int adjustment = c.offset; + ByteBuf s = c.buf; + for (;;) { + int localLength = Math.min(length, s.readableBytes() - (index - adjustment)); + s.getBytes(index - adjustment, out, localLength); + index += localLength; + length -= localLength; + adjustment += s.readableBytes(); + if (length <= 0) { + break; + } + s = buffer(++i); + } + return this; + } + + @Override + public ByteBuf copy(int index, int length) { + checkIndex(index, length); + boolean release = true; + ByteBuf buf = alloc().buffer(length); + try { + buf.writeBytes(this, index, length); + release = false; + return buf; + } finally { + if (release) { + buf.release(); + } + } + } + + @Override + public int nioBufferCount() { + return nioBufferCount; + } + + @Override + public ByteBuffer nioBuffer(int index, int 
length) { + checkIndex(index, length); + if (buffers.length == 1) { + ByteBuf buf = buffer(0); + if (buf.nioBufferCount() == 1) { + return buf.nioBuffer(index, length); + } + } + ByteBuffer merged = ByteBuffer.allocate(length).order(order()); + ByteBuffer[] buffers = nioBuffers(index, length); + + //noinspection ForLoopReplaceableByForEach + for (int i = 0; i < buffers.length; i++) { + merged.put(buffers[i]); + } + + merged.flip(); + return merged; + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + if (buffers.length == 1) { + return buffer(0).internalNioBuffer(index, length); + } + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + checkIndex(index, length); + if (length == 0) { + return EmptyArrays.EMPTY_BYTE_BUFFERS; + } + + RecyclableArrayList array = RecyclableArrayList.newInstance(buffers.length); + try { + Component c = findComponent(index); + int i = c.index; + int adjustment = c.offset; + ByteBuf s = c.buf; + for (;;) { + int localLength = Math.min(length, s.readableBytes() - (index - adjustment)); + switch (s.nioBufferCount()) { + case 0: + throw new UnsupportedOperationException(); + case 1: + array.add(s.nioBuffer(index - adjustment, localLength)); + break; + default: + Collections.addAll(array, s.nioBuffers(index - adjustment, localLength)); + } + + index += localLength; + length -= localLength; + adjustment += s.readableBytes(); + if (length <= 0) { + break; + } + s = buffer(++i); + } + + return array.toArray(EmptyArrays.EMPTY_BYTE_BUFFERS); + } finally { + array.recycle(); + } + } + + @Override + public boolean hasArray() { + switch (buffers.length) { + case 0: + return true; + case 1: + return buffer(0).hasArray(); + default: + return false; + } + } + + @Override + public byte[] array() { + switch (buffers.length) { + case 0: + return EmptyArrays.EMPTY_BYTES; + case 1: + return buffer(0).array(); + default: + throw new 
UnsupportedOperationException(); + } + } + + @Override + public int arrayOffset() { + switch (buffers.length) { + case 0: + return 0; + case 1: + return buffer(0).arrayOffset(); + default: + throw new UnsupportedOperationException(); + } + } + + @Override + public boolean hasMemoryAddress() { + switch (buffers.length) { + case 0: + return Unpooled.EMPTY_BUFFER.hasMemoryAddress(); + case 1: + return buffer(0).hasMemoryAddress(); + default: + return false; + } + } + + @Override + public long memoryAddress() { + switch (buffers.length) { + case 0: + return Unpooled.EMPTY_BUFFER.memoryAddress(); + case 1: + return buffer(0).memoryAddress(); + default: + throw new UnsupportedOperationException(); + } + } + + @Override + protected void deallocate() { + for (int i = 0; i < buffers.length; i++) { + buffer(i).release(); + } + } + + @Override + public String toString() { + String result = super.toString(); + result = result.substring(0, result.length() - 1); + return result + ", components=" + buffers.length + ')'; + } + + private static final class Component extends WrappedByteBuf { + private final int index; + private final int offset; + private final int endOffset; + + Component(int index, int offset, ByteBuf buf) { + super(buf); + this.index = index; + this.offset = offset; + endOffset = offset + buf.readableBytes(); + } + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/HeapByteBufUtil.java b/netty-buffer/src/main/java/io/netty/buffer/HeapByteBufUtil.java new file mode 100644 index 0000000..9f7972a --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/HeapByteBufUtil.java @@ -0,0 +1,146 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

/**
 * Utility class for heap buffers.
 *
 * <p>Reads and writes primitive values from/to {@code byte[]} storage in both
 * big-endian (network order) and little-endian ({@code *LE}) layouts. Bytes are
 * assembled and scattered explicitly, so results never depend on the platform's
 * native byte order. No bounds checking is performed here; callers are expected
 * to have validated {@code index} already.
 */
final class HeapByteBufUtil {

    /** Returns the (signed) byte stored at {@code index}. */
    static byte getByte(byte[] memory, int index) {
        return memory[index];
    }

    /** Reads a big-endian signed 16-bit value starting at {@code index}. */
    static short getShort(byte[] memory, int index) {
        int assembled = (memory[index] & 0xff) << 8 | memory[index + 1] & 0xff;
        return (short) assembled;
    }

    /** Reads a little-endian signed 16-bit value starting at {@code index}. */
    static short getShortLE(byte[] memory, int index) {
        int assembled = memory[index] & 0xff | (memory[index + 1] & 0xff) << 8;
        return (short) assembled;
    }

    /** Reads a big-endian unsigned 24-bit value starting at {@code index}. */
    static int getUnsignedMedium(byte[] memory, int index) {
        int value = 0;
        for (int i = 0; i < 3; i++) {
            value = value << 8 | memory[index + i] & 0xff;
        }
        return value;
    }

    /** Reads a little-endian unsigned 24-bit value starting at {@code index}. */
    static int getUnsignedMediumLE(byte[] memory, int index) {
        int value = 0;
        // Iterate from the most significant byte (highest offset) downwards.
        for (int i = 2; i >= 0; i--) {
            value = value << 8 | memory[index + i] & 0xff;
        }
        return value;
    }

    /** Reads a big-endian signed 32-bit value starting at {@code index}. */
    static int getInt(byte[] memory, int index) {
        int value = 0;
        for (int i = 0; i < 4; i++) {
            value = value << 8 | memory[index + i] & 0xff;
        }
        return value;
    }

    /** Reads a little-endian signed 32-bit value starting at {@code index}. */
    static int getIntLE(byte[] memory, int index) {
        int value = 0;
        for (int i = 3; i >= 0; i--) {
            value = value << 8 | memory[index + i] & 0xff;
        }
        return value;
    }

    /** Reads a big-endian signed 64-bit value starting at {@code index}. */
    static long getLong(byte[] memory, int index) {
        long value = 0;
        for (int i = 0; i < 8; i++) {
            value = value << 8 | memory[index + i] & 0xff;
        }
        return value;
    }

    /** Reads a little-endian signed 64-bit value starting at {@code index}. */
    static long getLongLE(byte[] memory, int index) {
        long value = 0;
        for (int i = 7; i >= 0; i--) {
            value = value << 8 | memory[index + i] & 0xff;
        }
        return value;
    }

    /** Writes the low 8 bits of {@code value} at {@code index}. */
    static void setByte(byte[] memory, int index, int value) {
        memory[index] = (byte) value;
    }

    /** Writes the low 16 bits of {@code value} big-endian at {@code index}. */
    static void setShort(byte[] memory, int index, int value) {
        // Emit the least significant byte at the highest offset first.
        for (int i = 1; i >= 0; i--) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes the low 16 bits of {@code value} little-endian at {@code index}. */
    static void setShortLE(byte[] memory, int index, int value) {
        for (int i = 0; i < 2; i++) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes the low 24 bits of {@code value} big-endian at {@code index}. */
    static void setMedium(byte[] memory, int index, int value) {
        for (int i = 2; i >= 0; i--) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes the low 24 bits of {@code value} little-endian at {@code index}. */
    static void setMediumLE(byte[] memory, int index, int value) {
        for (int i = 0; i < 3; i++) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes {@code value} big-endian at {@code index}. */
    static void setInt(byte[] memory, int index, int value) {
        for (int i = 3; i >= 0; i--) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes {@code value} little-endian at {@code index}. */
    static void setIntLE(byte[] memory, int index, int value) {
        for (int i = 0; i < 4; i++) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes {@code value} big-endian at {@code index}. */
    static void setLong(byte[] memory, int index, long value) {
        for (int i = 7; i >= 0; i--) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    /** Writes {@code value} little-endian at {@code index}. */
    static void setLongLE(byte[] memory, int index, long value) {
        for (int i = 0; i < 8; i++) {
            memory[index + i] = (byte) value;
            value >>>= 8;
        }
    }

    private HeapByteBufUtil() {
        // Static utility holder; never instantiated.
    }
}
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

/**
 * Internal primitive priority queue, used by {@link PoolChunk}.
 * The implementation is based on the binary heap, as described in Algorithms by Sedgewick and Wayne.
 *
 * <p>Stores {@code int} handles in a 1-based array-backed min-heap (smallest value
 * at the root). {@link #NO_VALUE} is reserved as the "queue is empty" sentinel and
 * therefore may never be offered.
 */
final class IntPriorityQueue {
    /** Sentinel returned by {@link #peek()} and {@link #poll()} when the queue is empty. */
    public static final int NO_VALUE = -1;

    // 1-based heap storage; slot 0 is never used. Initial capacity: 8 values.
    private int[] heap = new int[9];
    // Number of stored values; also the index of the last occupied slot.
    private int count;

    /**
     * Inserts {@code handle}, growing the backing array when full.
     *
     * @throws IllegalArgumentException if {@code handle} equals {@link #NO_VALUE}
     */
    public void offer(int handle) {
        if (handle == NO_VALUE) {
            throw new IllegalArgumentException("The NO_VALUE (" + NO_VALUE + ") cannot be added to the queue.");
        }
        count++;
        if (count == heap.length) {
            // Double the usable capacity; the +1/-1 accounts for the unused slot 0.
            heap = Arrays.copyOf(heap, 1 + (heap.length - 1) * 2);
        }
        heap[count] = handle;
        bubbleUp(count);
    }

    /** Removes the first occurrence of {@code value}, if present; otherwise a no-op. */
    public void remove(int value) {
        for (int slot = 1; slot <= count; slot++) {
            if (heap[slot] == value) {
                // Fill the hole with the last element, then restore heap order
                // in whichever direction is needed.
                heap[slot] = heap[count--];
                bubbleUp(slot);
                bubbleDown(slot);
                return;
            }
        }
    }

    /** Returns the smallest value without removing it, or {@link #NO_VALUE} if empty. */
    public int peek() {
        return count == 0 ? NO_VALUE : heap[1];
    }

    /** Removes and returns the smallest value, or {@link #NO_VALUE} if empty. */
    public int poll() {
        if (count == 0) {
            return NO_VALUE;
        }
        int smallest = heap[1];
        heap[1] = heap[count];
        heap[count] = 0;
        count--;
        bubbleDown(1);
        return smallest;
    }

    /** Returns {@code true} when no values are stored. */
    public boolean isEmpty() {
        return count == 0;
    }

    // Moves the value at 'slot' towards the root while it is smaller than its parent.
    private void bubbleUp(int slot) {
        while (slot > 1) {
            int parent = slot >> 1;
            if (heap[parent] <= heap[slot]) {
                break;
            }
            exchange(parent, slot);
            slot = parent;
        }
    }

    // Moves the value at 'slot' towards the leaves while it is larger than a child.
    private void bubbleDown(int slot) {
        for (;;) {
            int child = slot << 1;
            if (child > count) {
                break;
            }
            if (child < count && heap[child] > heap[child + 1]) {
                child++; // Prefer the smaller of the two children.
            }
            if (heap[slot] <= heap[child]) {
                break;
            }
            exchange(slot, child);
            slot = child;
        }
    }

    // Swaps the contents of two heap slots.
    private void exchange(int a, int b) {
        int tmp = heap[a];
        heap[a] = heap[b];
        heap[b] = tmp;
    }
}
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

/**
 * Internal primitive map implementation that is specifically optimised for the runs availability map use case in {@link
 * PoolChunk}.
 *
 * <p>Open-addressing hash map from {@code long} keys to {@code long} values, stored as
 * consecutive [key, value] pairs in a flat {@code long[]}. A key slot holding {@code 0}
 * marks an empty pair, so the real key {@code 0} is kept out-of-band in {@code zeroVal}.
 * Every operation probes at most {@code maxProbe} pairs; {@code put} expands and
 * re-hashes the table when a probe sequence is exhausted. Not thread-safe.
 */
final class LongLongHashMap {
    // ~1 clears the low bit so masked indices are always even, i.e. always land
    // on a key slot and never on a value slot.
    private static final int MASK_TEMPLATE = ~1;
    private int mask;            // array.length - 1 with the low bit cleared (even indices only)
    private long[] array;        // flat [key0, val0, key1, val1, ...] pair storage
    private int maxProbe;        // max number of pairs inspected per operation
    private long zeroVal;        // value bound to key 0 (which cannot live in 'array')
    private final long emptyVal; // sentinel reported for absent keys

    LongLongHashMap(long emptyVal) {
        this.emptyVal = emptyVal;
        zeroVal = emptyVal; // key 0 starts out "absent"
        int initialSize = 32;
        array = new long[initialSize];
        // NOTE(review): this assignment is immediately overwritten by
        // computeMaskAndProbe() below; it appears to be redundant.
        mask = initialSize - 1;
        computeMaskAndProbe();
    }

    /**
     * Binds {@code value} to {@code key} and returns the previously bound value,
     * or {@code emptyVal} if the key was absent. Never fails: when the probe
     * sequence is full, the table is expanded and the insert retried.
     */
    public long put(long key, long value) {
        if (key == 0) {
            // Key 0 collides with the empty-slot marker; store it in the side field.
            long prev = zeroVal;
            zeroVal = value;
            return prev;
        }

        for (;;) {
            int index = index(key);
            for (int i = 0; i < maxProbe; i++) {
                long existing = array[index];
                if (existing == key || existing == 0) {
                    long prev = existing == 0? emptyVal : array[index + 1];
                    array[index] = key;
                    array[index + 1] = value;
                    for (; i < maxProbe; i++) { // Nerf any existing misplaced entries.
                        // If a stale copy of this key sits further down the probe chain
                        // (possible when an earlier slot was freed by remove), clear it
                        // and report its value as the previous binding.
                        index = index + 2 & mask;
                        if (array[index] == key) {
                            array[index] = 0;
                            prev = array[index + 1];
                            break;
                        }
                    }
                    return prev;
                }
                index = index + 2 & mask;
            }
            expand(); // Grow array and re-hash.
        }
    }

    /**
     * Unbinds {@code key}; a no-op when the key is absent. Only the key slot is
     * cleared — the stale value is harmless because lookups match keys first.
     */
    public void remove(long key) {
        if (key == 0) {
            zeroVal = emptyVal;
            return;
        }
        int index = index(key);
        for (int i = 0; i < maxProbe; i++) {
            long existing = array[index];
            if (existing == key) {
                array[index] = 0;
                break;
            }
            index = index + 2 & mask;
        }
    }

    /**
     * Returns the value bound to {@code key}, or {@code emptyVal} when the key is
     * absent (or not found within {@code maxProbe} probes).
     */
    public long get(long key) {
        if (key == 0) {
            return zeroVal;
        }
        int index = index(key);
        for (int i = 0; i < maxProbe; i++) {
            long existing = array[index];
            if (existing == key) {
                return array[index + 1];
            }
            index = index + 2 & mask;
        }
        return emptyVal;
    }

    // Computes the initial probe position for 'key'.
    private int index(long key) {
        // Hash with murmur64, and mask. The mask keeps the index in range and
        // even, so it always addresses a key slot.
        key ^= key >>> 33;
        key *= 0xff51afd7ed558ccdL;
        key ^= key >>> 33;
        key *= 0xc4ceb9fe1a85ec53L;
        key ^= key >>> 33;
        return (int) key & mask;
    }

    // Doubles the table and re-inserts every non-empty pair. zeroVal is untouched.
    private void expand() {
        long[] prev = array;
        array = new long[prev.length * 2];
        computeMaskAndProbe();
        for (int i = 0; i < prev.length; i += 2) {
            long key = prev[i];
            if (key != 0) {
                long val = prev[i + 1];
                put(key, val);
            }
        }
    }

    // Recomputes mask and probe limit for the current capacity. Math.log is the
    // natural logarithm, so the probe limit grows very slowly (3 for 32, 4 for 64).
    private void computeMaskAndProbe() {
        int length = array.length;
        mask = length - 1 & MASK_TEMPLATE;
        maxProbe = (int) Math.log(length);
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import io.netty.util.internal.LongCounter; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.locks.ReentrantLock; + +import static io.netty.buffer.PoolChunk.isSubpage; +import static java.lang.Math.max; + +abstract class PoolArena implements PoolArenaMetric { + private static final boolean HAS_UNSAFE = PlatformDependent.hasUnsafe(); + + enum SizeClass { + Small, + Normal + } + + final PooledByteBufAllocator parent; + + final PoolSubpage[] smallSubpagePools; + + private final PoolChunkList q050; + private final PoolChunkList q025; + private final PoolChunkList q000; + private final PoolChunkList qInit; + private final PoolChunkList q075; + private final PoolChunkList q100; + + private final List chunkListMetrics; + + // Metrics for allocations and deallocations + private long allocationsNormal; + // We need to use the LongCounter here as this is not guarded via synchronized block. + private final LongCounter allocationsSmall = PlatformDependent.newLongCounter(); + private final LongCounter allocationsHuge = PlatformDependent.newLongCounter(); + private final LongCounter activeBytesHuge = PlatformDependent.newLongCounter(); + + private long deallocationsSmall; + private long deallocationsNormal; + + // We need to use the LongCounter here as this is not guarded via synchronized block. 
+ private final LongCounter deallocationsHuge = PlatformDependent.newLongCounter(); + + // Number of thread caches backed by this arena. + final AtomicInteger numThreadCaches = new AtomicInteger(); + + // TODO: Test if adding padding helps under contention + //private long pad0, pad1, pad2, pad3, pad4, pad5, pad6, pad7; + + private final ReentrantLock lock = new ReentrantLock(); + + final SizeClasses sizeClass; + + protected PoolArena(PooledByteBufAllocator parent, SizeClasses sizeClass) { + assert null != sizeClass; + this.parent = parent; + this.sizeClass = sizeClass; + smallSubpagePools = newSubpagePoolArray(sizeClass.nSubpages); + for (int i = 0; i < smallSubpagePools.length; i ++) { + smallSubpagePools[i] = newSubpagePoolHead(i); + } + + q100 = new PoolChunkList(this, null, 100, Integer.MAX_VALUE, sizeClass.chunkSize); + q075 = new PoolChunkList(this, q100, 75, 100, sizeClass.chunkSize); + q050 = new PoolChunkList(this, q075, 50, 100, sizeClass.chunkSize); + q025 = new PoolChunkList(this, q050, 25, 75, sizeClass.chunkSize); + q000 = new PoolChunkList(this, q025, 1, 50, sizeClass.chunkSize); + qInit = new PoolChunkList(this, q000, Integer.MIN_VALUE, 25, sizeClass.chunkSize); + + q100.prevList(q075); + q075.prevList(q050); + q050.prevList(q025); + q025.prevList(q000); + q000.prevList(null); + qInit.prevList(qInit); + + List metrics = new ArrayList<>(6); + metrics.add(qInit); + metrics.add(q000); + metrics.add(q025); + metrics.add(q050); + metrics.add(q075); + metrics.add(q100); + chunkListMetrics = Collections.unmodifiableList(metrics); + } + + private PoolSubpage newSubpagePoolHead(int index) { + PoolSubpage head = new PoolSubpage(index); + head.prev = head; + head.next = head; + return head; + } + + @SuppressWarnings("unchecked") + private PoolSubpage[] newSubpagePoolArray(int size) { + return new PoolSubpage[size]; + } + + abstract boolean isDirect(); + + PooledByteBuf allocate(PoolThreadCache cache, int reqCapacity, int maxCapacity) { + PooledByteBuf buf = 
newByteBuf(maxCapacity); + allocate(cache, buf, reqCapacity); + return buf; + } + + private void allocate(PoolThreadCache cache, PooledByteBuf buf, final int reqCapacity) { + final int sizeIdx = sizeClass.size2SizeIdx(reqCapacity); + + if (sizeIdx <= sizeClass.smallMaxSizeIdx) { + tcacheAllocateSmall(cache, buf, reqCapacity, sizeIdx); + } else if (sizeIdx < sizeClass.nSizes) { + tcacheAllocateNormal(cache, buf, reqCapacity, sizeIdx); + } else { + int normCapacity = sizeClass.directMemoryCacheAlignment > 0 + ? sizeClass.normalizeSize(reqCapacity) : reqCapacity; + // Huge allocations are never served via the cache so just call allocateHuge + allocateHuge(buf, normCapacity); + } + } + + private void tcacheAllocateSmall(PoolThreadCache cache, PooledByteBuf buf, final int reqCapacity, + final int sizeIdx) { + + if (cache.allocateSmall(this, buf, reqCapacity, sizeIdx)) { + // was able to allocate out of the cache so move on + return; + } + + /* + * Synchronize on the head. This is needed as {@link PoolChunk#allocateSubpage(int)} and + * {@link PoolChunk#free(long)} may modify the doubly linked list as well. 
+ */ + final PoolSubpage head = smallSubpagePools[sizeIdx]; + final boolean needsNormalAllocation; + head.lock(); + try { + final PoolSubpage s = head.next; + needsNormalAllocation = s == head; + if (!needsNormalAllocation) { + assert s.doNotDestroy && s.elemSize == sizeClass.sizeIdx2size(sizeIdx) : "doNotDestroy=" + + s.doNotDestroy + ", elemSize=" + s.elemSize + ", sizeIdx=" + sizeIdx; + long handle = s.allocate(); + assert handle >= 0; + s.chunk.initBufWithSubpage(buf, null, handle, reqCapacity, cache); + } + } finally { + head.unlock(); + } + + if (needsNormalAllocation) { + lock(); + try { + allocateNormal(buf, reqCapacity, sizeIdx, cache); + } finally { + unlock(); + } + } + + incSmallAllocation(); + } + + private void tcacheAllocateNormal(PoolThreadCache cache, PooledByteBuf buf, final int reqCapacity, + final int sizeIdx) { + if (cache.allocateNormal(this, buf, reqCapacity, sizeIdx)) { + // was able to allocate out of the cache so move on + return; + } + lock(); + try { + allocateNormal(buf, reqCapacity, sizeIdx, cache); + ++allocationsNormal; + } finally { + unlock(); + } + } + + private void allocateNormal(PooledByteBuf buf, int reqCapacity, int sizeIdx, PoolThreadCache threadCache) { + assert lock.isHeldByCurrentThread(); + if (q050.allocate(buf, reqCapacity, sizeIdx, threadCache) || + q025.allocate(buf, reqCapacity, sizeIdx, threadCache) || + q000.allocate(buf, reqCapacity, sizeIdx, threadCache) || + qInit.allocate(buf, reqCapacity, sizeIdx, threadCache) || + q075.allocate(buf, reqCapacity, sizeIdx, threadCache)) { + return; + } + + // Add a new chunk. 
+ PoolChunk c = newChunk(sizeClass.pageSize, sizeClass.nPSizes, sizeClass.pageShifts, sizeClass.chunkSize); + boolean success = c.allocate(buf, reqCapacity, sizeIdx, threadCache); + assert success; + qInit.add(c); + } + + private void incSmallAllocation() { + allocationsSmall.increment(); + } + + private void allocateHuge(PooledByteBuf buf, int reqCapacity) { + PoolChunk chunk = newUnpooledChunk(reqCapacity); + activeBytesHuge.add(chunk.chunkSize()); + buf.initUnpooled(chunk, reqCapacity); + allocationsHuge.increment(); + } + + void free(PoolChunk chunk, ByteBuffer nioBuffer, long handle, int normCapacity, PoolThreadCache cache) { + chunk.decrementPinnedMemory(normCapacity); + if (chunk.unpooled) { + int size = chunk.chunkSize(); + destroyChunk(chunk); + activeBytesHuge.add(-size); + deallocationsHuge.increment(); + } else { + SizeClass sizeClass = sizeClass(handle); + if (cache != null && cache.add(this, chunk, nioBuffer, handle, normCapacity, sizeClass)) { + // cached so not free it. + return; + } + + freeChunk(chunk, handle, normCapacity, sizeClass, nioBuffer, false); + } + } + + private static SizeClass sizeClass(long handle) { + return isSubpage(handle) ? SizeClass.Small : SizeClass.Normal; + } + + void freeChunk(PoolChunk chunk, long handle, int normCapacity, SizeClass sizeClass, ByteBuffer nioBuffer, + boolean finalizer) { + final boolean destroyChunk; + lock(); + try { + // We only call this if freeChunk is not called because of the PoolThreadCache finalizer as otherwise this + // may fail due lazy class-loading in for example tomcat. + if (!finalizer) { + switch (sizeClass) { + case Normal: + ++deallocationsNormal; + break; + case Small: + ++deallocationsSmall; + break; + default: + throw new Error(); + } + } + destroyChunk = !chunk.parent.free(chunk, handle, normCapacity, nioBuffer); + } finally { + unlock(); + } + if (destroyChunk) { + // destroyChunk not need to be called while holding the synchronized lock. 
+ destroyChunk(chunk); + } + } + + void reallocate(final PooledByteBuf buf, int newCapacity) { + assert newCapacity >= 0 && newCapacity <= buf.maxCapacity(); + + final int oldCapacity; + final PoolChunk oldChunk; + final ByteBuffer oldNioBuffer; + final long oldHandle; + final T oldMemory; + final int oldOffset; + final int oldMaxLength; + final PoolThreadCache oldCache; + + // We synchronize on the ByteBuf itself to ensure there is no "concurrent" reallocations for the same buffer. + // We do this to ensure the ByteBuf internal fields that are used to allocate / free are not accessed + // concurrently. This is important as otherwise we might end up corrupting our internal state of our data + // structures. + // + // Also note we don't use a Lock here but just synchronized even tho this might seem like a bad choice for Loom. + // This is done to minimize the overhead per ByteBuf. The time this would block another thread should be + // relative small and so not be a problem for Loom. + // See https://github.com/netty/netty/issues/13467 + synchronized (buf) { + oldCapacity = buf.length; + if (oldCapacity == newCapacity) { + return; + } + + oldChunk = buf.chunk; + oldNioBuffer = buf.tmpNioBuf; + oldHandle = buf.handle; + oldMemory = buf.memory; + oldOffset = buf.offset; + oldMaxLength = buf.maxLength; + oldCache = buf.cache; + + // This does not touch buf's reader/writer indices + allocate(parent.threadCache(), buf, newCapacity); + } + int bytesToCopy; + if (newCapacity > oldCapacity) { + bytesToCopy = oldCapacity; + } else { + buf.trimIndicesToCapacity(newCapacity); + bytesToCopy = newCapacity; + } + memoryCopy(oldMemory, oldOffset, buf, bytesToCopy); + free(oldChunk, oldNioBuffer, oldHandle, oldMaxLength, oldCache); + } + + @Override + public int numThreadCaches() { + return numThreadCaches.get(); + } + + @Override + public int numTinySubpages() { + return 0; + } + + @Override + public int numSmallSubpages() { + return smallSubpagePools.length; + } + + @Override + 
public int numChunkLists() { + return chunkListMetrics.size(); + } + + @Override + public List tinySubpages() { + return Collections.emptyList(); + } + + @Override + public List smallSubpages() { + return subPageMetricList(smallSubpagePools); + } + + @Override + public List chunkLists() { + return chunkListMetrics; + } + + private static List subPageMetricList(PoolSubpage[] pages) { + List metrics = new ArrayList(); + for (PoolSubpage head : pages) { + if (head.next == head) { + continue; + } + PoolSubpage s = head.next; + while (true) { + metrics.add(s); + s = s.next; + if (s == head) { + break; + } + } + } + return metrics; + } + + @Override + public long numAllocations() { + final long allocsNormal; + lock(); + try { + allocsNormal = allocationsNormal; + } finally { + unlock(); + } + return allocationsSmall.value() + allocsNormal + allocationsHuge.value(); + } + + @Override + public long numTinyAllocations() { + return 0; + } + + @Override + public long numSmallAllocations() { + return allocationsSmall.value(); + } + + @Override + public long numNormalAllocations() { + lock(); + try { + return allocationsNormal; + } finally { + unlock(); + } + } + + @Override + public long numDeallocations() { + final long deallocs; + lock(); + try { + deallocs = deallocationsSmall + deallocationsNormal; + } finally { + unlock(); + } + return deallocs + deallocationsHuge.value(); + } + + @Override + public long numTinyDeallocations() { + return 0; + } + + @Override + public long numSmallDeallocations() { + lock(); + try { + return deallocationsSmall; + } finally { + unlock(); + } + } + + @Override + public long numNormalDeallocations() { + lock(); + try { + return deallocationsNormal; + } finally { + unlock(); + } + } + + @Override + public long numHugeAllocations() { + return allocationsHuge.value(); + } + + @Override + public long numHugeDeallocations() { + return deallocationsHuge.value(); + } + + @Override + public long numActiveAllocations() { + long val = 
allocationsSmall.value() + allocationsHuge.value() + - deallocationsHuge.value(); + lock(); + try { + val += allocationsNormal - (deallocationsSmall + deallocationsNormal); + } finally { + unlock(); + } + return max(val, 0); + } + + @Override + public long numActiveTinyAllocations() { + return 0; + } + + @Override + public long numActiveSmallAllocations() { + return max(numSmallAllocations() - numSmallDeallocations(), 0); + } + + @Override + public long numActiveNormalAllocations() { + final long val; + lock(); + try { + val = allocationsNormal - deallocationsNormal; + } finally { + unlock(); + } + return max(val, 0); + } + + @Override + public long numActiveHugeAllocations() { + return max(numHugeAllocations() - numHugeDeallocations(), 0); + } + + @Override + public long numActiveBytes() { + long val = activeBytesHuge.value(); + lock(); + try { + for (PoolChunkListMetric chunkListMetric : chunkListMetrics) { + for (PoolChunkMetric m : chunkListMetric) { + val += m.chunkSize(); + } + } + } finally { + unlock(); + } + return max(0, val); + } + + /** + * Return the number of bytes that are currently pinned to buffer instances, by the arena. The pinned memory is not + * accessible for use by any other allocation, until the buffers using have all been released. + */ + public long numPinnedBytes() { + long val = activeBytesHuge.value(); // Huge chunks are exact-sized for the buffers they were allocated to. 
+ lock(); + try { + for (PoolChunkListMetric chunkListMetric : chunkListMetrics) { + for (PoolChunkMetric m : chunkListMetric) { + val += ((PoolChunk) m).pinnedBytes(); + } + } + } finally { + unlock(); + } + return max(0, val); + } + + protected abstract PoolChunk newChunk(int pageSize, int maxPageIdx, int pageShifts, int chunkSize); + protected abstract PoolChunk newUnpooledChunk(int capacity); + protected abstract PooledByteBuf newByteBuf(int maxCapacity); + protected abstract void memoryCopy(T src, int srcOffset, PooledByteBuf dst, int length); + protected abstract void destroyChunk(PoolChunk chunk); + + @Override + public String toString() { + lock(); + try { + StringBuilder buf = new StringBuilder() + .append("Chunk(s) at 0~25%:") + .append(StringUtil.NEWLINE) + .append(qInit) + .append(StringUtil.NEWLINE) + .append("Chunk(s) at 0~50%:") + .append(StringUtil.NEWLINE) + .append(q000) + .append(StringUtil.NEWLINE) + .append("Chunk(s) at 25~75%:") + .append(StringUtil.NEWLINE) + .append(q025) + .append(StringUtil.NEWLINE) + .append("Chunk(s) at 50~100%:") + .append(StringUtil.NEWLINE) + .append(q050) + .append(StringUtil.NEWLINE) + .append("Chunk(s) at 75~100%:") + .append(StringUtil.NEWLINE) + .append(q075) + .append(StringUtil.NEWLINE) + .append("Chunk(s) at 100%:") + .append(StringUtil.NEWLINE) + .append(q100) + .append(StringUtil.NEWLINE) + .append("small subpages:"); + appendPoolSubPages(buf, smallSubpagePools); + buf.append(StringUtil.NEWLINE); + return buf.toString(); + } finally { + unlock(); + } + } + + private static void appendPoolSubPages(StringBuilder buf, PoolSubpage[] subpages) { + for (int i = 0; i < subpages.length; i ++) { + PoolSubpage head = subpages[i]; + if (head.next == head || head.next == null) { + continue; + } + + buf.append(StringUtil.NEWLINE) + .append(i) + .append(": "); + PoolSubpage s = head.next; + while (s != null) { + buf.append(s); + s = s.next; + if (s == head) { + break; + } + } + } + } + + @Override + protected final void 
finalize() throws Throwable { + try { + super.finalize(); + } finally { + destroyPoolSubPages(smallSubpagePools); + destroyPoolChunkLists(qInit, q000, q025, q050, q075, q100); + } + } + + private static void destroyPoolSubPages(PoolSubpage[] pages) { + for (PoolSubpage page : pages) { + page.destroy(); + } + } + + private void destroyPoolChunkLists(PoolChunkList... chunkLists) { + for (PoolChunkList chunkList: chunkLists) { + chunkList.destroy(this); + } + } + + static final class HeapArena extends PoolArena { + + HeapArena(PooledByteBufAllocator parent, SizeClasses sizeClass) { + super(parent, sizeClass); + } + + private static byte[] newByteArray(int size) { + return PlatformDependent.allocateUninitializedArray(size); + } + + @Override + boolean isDirect() { + return false; + } + + @Override + protected PoolChunk newChunk(int pageSize, int maxPageIdx, int pageShifts, int chunkSize) { + return new PoolChunk( + this, null, newByteArray(chunkSize), pageSize, pageShifts, chunkSize, maxPageIdx); + } + + @Override + protected PoolChunk newUnpooledChunk(int capacity) { + return new PoolChunk(this, null, newByteArray(capacity), capacity); + } + + @Override + protected void destroyChunk(PoolChunk chunk) { + // Rely on GC. + } + + @Override + protected PooledByteBuf newByteBuf(int maxCapacity) { + return HAS_UNSAFE ? 
PooledUnsafeHeapByteBuf.newUnsafeInstance(maxCapacity) + : PooledHeapByteBuf.newInstance(maxCapacity); + } + + @Override + protected void memoryCopy(byte[] src, int srcOffset, PooledByteBuf dst, int length) { + if (length == 0) { + return; + } + + System.arraycopy(src, srcOffset, dst.memory, dst.offset, length); + } + } + + static final class DirectArena extends PoolArena { + + DirectArena(PooledByteBufAllocator parent, SizeClasses sizeClass) { + super(parent, sizeClass); + } + + @Override + boolean isDirect() { + return true; + } + + @Override + protected PoolChunk newChunk(int pageSize, int maxPageIdx, + int pageShifts, int chunkSize) { + if (sizeClass.directMemoryCacheAlignment == 0) { + ByteBuffer memory = allocateDirect(chunkSize); + return new PoolChunk(this, memory, memory, pageSize, pageShifts, + chunkSize, maxPageIdx); + } + + final ByteBuffer base = allocateDirect(chunkSize + sizeClass.directMemoryCacheAlignment); + final ByteBuffer memory = PlatformDependent.alignDirectBuffer(base, sizeClass.directMemoryCacheAlignment); + return new PoolChunk(this, base, memory, pageSize, + pageShifts, chunkSize, maxPageIdx); + } + + @Override + protected PoolChunk newUnpooledChunk(int capacity) { + if (sizeClass.directMemoryCacheAlignment == 0) { + ByteBuffer memory = allocateDirect(capacity); + return new PoolChunk(this, memory, memory, capacity); + } + + final ByteBuffer base = allocateDirect(capacity + sizeClass.directMemoryCacheAlignment); + final ByteBuffer memory = PlatformDependent.alignDirectBuffer(base, sizeClass.directMemoryCacheAlignment); + return new PoolChunk(this, base, memory, capacity); + } + + private static ByteBuffer allocateDirect(int capacity) { + return PlatformDependent.useDirectBufferNoCleaner() ? 
+ PlatformDependent.allocateDirectNoCleaner(capacity) : ByteBuffer.allocateDirect(capacity); + } + + @Override + protected void destroyChunk(PoolChunk chunk) { + if (PlatformDependent.useDirectBufferNoCleaner()) { + PlatformDependent.freeDirectNoCleaner((ByteBuffer) chunk.base); + } else { + PlatformDependent.freeDirectBuffer((ByteBuffer) chunk.base); + } + } + + @Override + protected PooledByteBuf newByteBuf(int maxCapacity) { + if (HAS_UNSAFE) { + return PooledUnsafeDirectByteBuf.newInstance(maxCapacity); + } else { + return PooledDirectByteBuf.newInstance(maxCapacity); + } + } + + @Override + protected void memoryCopy(ByteBuffer src, int srcOffset, PooledByteBuf dstBuf, int length) { + if (length == 0) { + return; + } + + if (HAS_UNSAFE) { + PlatformDependent.copyMemory( + PlatformDependent.directBufferAddress(src) + srcOffset, + PlatformDependent.directBufferAddress(dstBuf.memory) + dstBuf.offset, length); + } else { + // We must duplicate the NIO buffers because they may be accessed by other Netty buffers. 
+ src = src.duplicate(); + ByteBuffer dst = dstBuf.internalNioBuffer(); + src.position(srcOffset).limit(srcOffset + length); + dst.position(dstBuf.offset); + dst.put(src); + } + } + } + + void lock() { + lock.lock(); + } + + void unlock() { + lock.unlock(); + } + + @Override + public int sizeIdx2size(int sizeIdx) { + return sizeClass.sizeIdx2size(sizeIdx); + } + + @Override + public int sizeIdx2sizeCompute(int sizeIdx) { + return sizeClass.sizeIdx2sizeCompute(sizeIdx); + } + + @Override + public long pageIdx2size(int pageIdx) { + return sizeClass.pageIdx2size(pageIdx); + } + + @Override + public long pageIdx2sizeCompute(int pageIdx) { + return sizeClass.pageIdx2sizeCompute(pageIdx); + } + + @Override + public int size2SizeIdx(int size) { + return sizeClass.size2SizeIdx(size); + } + + @Override + public int pages2pageIdx(int pages) { + return sizeClass.pages2pageIdx(pages); + } + + @Override + public int pages2pageIdxFloor(int pages) { + return sizeClass.pages2pageIdxFloor(pages); + } + + @Override + public int normalizeSize(int size) { + return sizeClass.normalizeSize(size); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolArenaMetric.java b/netty-buffer/src/main/java/io/netty/buffer/PoolArenaMetric.java new file mode 100644 index 0000000..b11a3c4 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PoolArenaMetric.java @@ -0,0 +1,155 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import java.util.List; + +/** + * Expose metrics for an arena. + */ +public interface PoolArenaMetric extends SizeClassesMetric { + + /** + * Returns the number of thread caches backed by this arena. + */ + int numThreadCaches(); + + /** + * Returns the number of tiny sub-pages for the arena. + * + * @deprecated Tiny sub-pages have been merged into small sub-pages. + */ + @Deprecated + int numTinySubpages(); + + /** + * Returns the number of small sub-pages for the arena. + */ + int numSmallSubpages(); + + /** + * Returns the number of chunk lists for the arena. + */ + int numChunkLists(); + + /** + * Returns an unmodifiable {@link List} which holds {@link PoolSubpageMetric}s for tiny sub-pages. + * + * @deprecated Tiny sub-pages have been merged into small sub-pages. + */ + @Deprecated + List tinySubpages(); + + /** + * Returns an unmodifiable {@link List} which holds {@link PoolSubpageMetric}s for small sub-pages. + */ + List smallSubpages(); + + /** + * Returns an unmodifiable {@link List} which holds {@link PoolChunkListMetric}s. + */ + List chunkLists(); + + /** + * Return the number of allocations done via the arena. This includes all sizes. + */ + long numAllocations(); + + /** + * Return the number of tiny allocations done via the arena. + * + * @deprecated Tiny allocations have been merged into small allocations. + */ + @Deprecated + long numTinyAllocations(); + + /** + * Return the number of small allocations done via the arena. + */ + long numSmallAllocations(); + + /** + * Return the number of normal allocations done via the arena. + */ + long numNormalAllocations(); + + /** + * Return the number of huge allocations done via the arena. + */ + long numHugeAllocations(); + + /** + * Return the number of deallocations done via the arena. This includes all sizes. 
+ */ + long numDeallocations(); + + /** + * Return the number of tiny deallocations done via the arena. + * + * @deprecated Tiny deallocations have been merged into small deallocations. + */ + @Deprecated + long numTinyDeallocations(); + + /** + * Return the number of small deallocations done via the arena. + */ + long numSmallDeallocations(); + + /** + * Return the number of normal deallocations done via the arena. + */ + long numNormalDeallocations(); + + /** + * Return the number of huge deallocations done via the arena. + */ + long numHugeDeallocations(); + + /** + * Return the number of currently active allocations. + */ + long numActiveAllocations(); + + /** + * Return the number of currently active tiny allocations. + * + * @deprecated Tiny allocations have been merged into small allocations. + */ + @Deprecated + long numActiveTinyAllocations(); + + /** + * Return the number of currently active small allocations. + */ + long numActiveSmallAllocations(); + + /** + * Return the number of currently active normal allocations. + */ + long numActiveNormalAllocations(); + + /** + * Return the number of currently active huge allocations. + */ + long numActiveHugeAllocations(); + + /** + * Return the number of active bytes that are currently allocated by the arena. + */ + long numActiveBytes(); +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolChunk.java b/netty-buffer/src/main/java/io/netty/buffer/PoolChunk.java new file mode 100644 index 0000000..2d53374 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PoolChunk.java @@ -0,0 +1,709 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.LongCounter; +import io.netty.util.internal.PlatformDependent; + +import java.nio.ByteBuffer; +import java.util.ArrayDeque; +import java.util.Deque; +import java.util.PriorityQueue; +import java.util.concurrent.locks.ReentrantLock; + +/** + * Description of algorithm for PageRun/PoolSubpage allocation from PoolChunk + * + * Notation: The following terms are important to understand the code + * > page - a page is the smallest unit of memory chunk that can be allocated + * > run - a run is a collection of pages + * > chunk - a chunk is a collection of runs + * > in this code chunkSize = maxPages * pageSize + * + * To begin we allocate a byte array of size = chunkSize + * Whenever a ByteBuf of given size needs to be created we search for the first position + * in the byte array that has enough empty space to accommodate the requested size and + * return a (long) handle that encodes this offset information, (this memory segment is then + * marked as reserved so it is always used by exactly one ByteBuf and no more) + * + * For simplicity all sizes are normalized according to {@link PoolArena#sizeClass#size2SizeIdx(int)} method. + * This ensures that when we request for memory segments of size > pageSize the normalizedCapacity + * equals the next nearest size in {@link SizeClasses}. 
+ * + * + * A chunk has the following layout: + * + * /-----------------\ + * | run | + * | | + * | | + * |-----------------| + * | run | + * | | + * |-----------------| + * | unalloctated | + * | (freed) | + * | | + * |-----------------| + * | subpage | + * |-----------------| + * | unallocated | + * | (freed) | + * | ... | + * | ... | + * | ... | + * | | + * | | + * | | + * \-----------------/ + * + * + * handle: + * ------- + * a handle is a long number, the bit layout of a run looks like: + * + * oooooooo ooooooos ssssssss ssssssue bbbbbbbb bbbbbbbb bbbbbbbb bbbbbbbb + * + * o: runOffset (page offset in the chunk), 15bit + * s: size (number of pages) of this run, 15bit + * u: isUsed?, 1bit + * e: isSubpage?, 1bit + * b: bitmapIdx of subpage, zero if it's not subpage, 32bit + * + * runsAvailMap: + * ------ + * a map which manages all runs (used and not in used). + * For each run, the first runOffset and last runOffset are stored in runsAvailMap. + * key: runOffset + * value: handle + * + * runsAvail: + * ---------- + * an array of {@link PriorityQueue}. + * Each queue manages same size of runs. + * Runs are sorted by offset, so that we always allocate runs with smaller offset. + * + * + * Algorithm: + * ---------- + * + * As we allocate runs, we update values stored in runsAvailMap and runsAvail so that the property is maintained. + * + * Initialization - + * In the beginning we store the initial run which is the whole chunk. + * The initial run: + * runOffset = 0 + * size = chunkSize + * isUsed = no + * isSubpage = no + * bitmapIdx = 0 + * + * + * Algorithm: [allocateRun(size)] + * ---------- + * 1) find the first avail run using in runsAvails according to size + * 2) if pages of run is larger than request pages then split it, and save the tailing run + * for later using + * + * Algorithm: [allocateSubpage(size)] + * ---------- + * 1) find a not full subpage according to size. 
+ * if it already exists just return, otherwise allocate a new PoolSubpage and call init() + * note that this subpage object is added to subpagesPool in the PoolArena when we init() it + * 2) call subpage.allocate() + * + * Algorithm: [free(handle, length, nioBuffer)] + * ---------- + * 1) if it is a subpage, return the slab back into this subpage + * 2) if the subpage is not used or it is a run, then start free this run + * 3) merge continuous avail runs + * 4) save the merged run + * + */ +final class PoolChunk implements PoolChunkMetric { + private static final int SIZE_BIT_LENGTH = 15; + private static final int INUSED_BIT_LENGTH = 1; + private static final int SUBPAGE_BIT_LENGTH = 1; + private static final int BITMAP_IDX_BIT_LENGTH = 32; + + static final int IS_SUBPAGE_SHIFT = BITMAP_IDX_BIT_LENGTH; + static final int IS_USED_SHIFT = SUBPAGE_BIT_LENGTH + IS_SUBPAGE_SHIFT; + static final int SIZE_SHIFT = INUSED_BIT_LENGTH + IS_USED_SHIFT; + static final int RUN_OFFSET_SHIFT = SIZE_BIT_LENGTH + SIZE_SHIFT; + + final PoolArena arena; + final Object base; + final T memory; + final boolean unpooled; + + /** + * store the first page and last page of each avail run + */ + private final LongLongHashMap runsAvailMap; + + /** + * manage all avail runs + */ + private final IntPriorityQueue[] runsAvail; + + private final ReentrantLock runsAvailLock; + + /** + * manage all subpages in this chunk + */ + private final PoolSubpage[] subpages; + + /** + * Accounting of pinned memory – memory that is currently in use by ByteBuf instances. + */ + private final LongCounter pinnedBytes = PlatformDependent.newLongCounter(); + + private final int pageSize; + private final int pageShifts; + private final int chunkSize; + + // Use as cache for ByteBuffer created from the memory. These are just duplicates and so are only a container + // around the memory itself. 
These are often needed for operations within the Pooled*ByteBuf and so + // may produce extra GC, which can be greatly reduced by caching the duplicates. + // + // This may be null if the PoolChunk is unpooled as pooling the ByteBuffer instances does not make any sense here. + private final Deque cachedNioBuffers; + + int freeBytes; + + PoolChunkList parent; + PoolChunk prev; + PoolChunk next; + + // TODO: Test if adding padding helps under contention + //private long pad0, pad1, pad2, pad3, pad4, pad5, pad6, pad7; + + @SuppressWarnings("unchecked") + PoolChunk(PoolArena arena, Object base, T memory, int pageSize, int pageShifts, int chunkSize, int maxPageIdx) { + unpooled = false; + this.arena = arena; + this.base = base; + this.memory = memory; + this.pageSize = pageSize; + this.pageShifts = pageShifts; + this.chunkSize = chunkSize; + freeBytes = chunkSize; + + runsAvail = newRunsAvailqueueArray(maxPageIdx); + runsAvailLock = new ReentrantLock(); + runsAvailMap = new LongLongHashMap(-1); + subpages = new PoolSubpage[chunkSize >> pageShifts]; + + //insert initial run, offset = 0, pages = chunkSize / pageSize + int pages = chunkSize >> pageShifts; + long initHandle = (long) pages << SIZE_SHIFT; + insertAvailRun(0, pages, initHandle); + + cachedNioBuffers = new ArrayDeque(8); + } + + /** Creates a special chunk that is not pooled. 
*/ + PoolChunk(PoolArena arena, Object base, T memory, int size) { + unpooled = true; + this.arena = arena; + this.base = base; + this.memory = memory; + pageSize = 0; + pageShifts = 0; + runsAvailMap = null; + runsAvail = null; + runsAvailLock = null; + subpages = null; + chunkSize = size; + cachedNioBuffers = null; + } + + private static IntPriorityQueue[] newRunsAvailqueueArray(int size) { + IntPriorityQueue[] queueArray = new IntPriorityQueue[size]; + for (int i = 0; i < queueArray.length; i++) { + queueArray[i] = new IntPriorityQueue(); + } + return queueArray; + } + + private void insertAvailRun(int runOffset, int pages, long handle) { + int pageIdxFloor = arena.sizeClass.pages2pageIdxFloor(pages); + IntPriorityQueue queue = runsAvail[pageIdxFloor]; + assert isRun(handle); + queue.offer((int) (handle >> BITMAP_IDX_BIT_LENGTH)); + + //insert first page of run + insertAvailRun0(runOffset, handle); + if (pages > 1) { + //insert last page of run + insertAvailRun0(lastPage(runOffset, pages), handle); + } + } + + private void insertAvailRun0(int runOffset, long handle) { + long pre = runsAvailMap.put(runOffset, handle); + assert pre == -1; + } + + private void removeAvailRun(long handle) { + int pageIdxFloor = arena.sizeClass.pages2pageIdxFloor(runPages(handle)); + runsAvail[pageIdxFloor].remove((int) (handle >> BITMAP_IDX_BIT_LENGTH)); + removeAvailRun0(handle); + } + + private void removeAvailRun0(long handle) { + int runOffset = runOffset(handle); + int pages = runPages(handle); + //remove first page of run + runsAvailMap.remove(runOffset); + if (pages > 1) { + //remove last page of run + runsAvailMap.remove(lastPage(runOffset, pages)); + } + } + + private static int lastPage(int runOffset, int pages) { + return runOffset + pages - 1; + } + + private long getAvailRunByOffset(int runOffset) { + return runsAvailMap.get(runOffset); + } + + @Override + public int usage() { + final int freeBytes; + if (this.unpooled) { + freeBytes = this.freeBytes; + } else { + 
runsAvailLock.lock(); + try { + freeBytes = this.freeBytes; + } finally { + runsAvailLock.unlock(); + } + } + return usage(freeBytes); + } + + private int usage(int freeBytes) { + if (freeBytes == 0) { + return 100; + } + + int freePercentage = (int) (freeBytes * 100L / chunkSize); + if (freePercentage == 0) { + return 99; + } + return 100 - freePercentage; + } + + boolean allocate(PooledByteBuf buf, int reqCapacity, int sizeIdx, PoolThreadCache cache) { + final long handle; + if (sizeIdx <= arena.sizeClass.smallMaxSizeIdx) { + final PoolSubpage nextSub; + // small + // Obtain the head of the PoolSubPage pool that is owned by the PoolArena and synchronize on it. + // This is need as we may add it back and so alter the linked-list structure. + PoolSubpage head = arena.smallSubpagePools[sizeIdx]; + head.lock(); + try { + nextSub = head.next; + if (nextSub != head) { + assert nextSub.doNotDestroy && nextSub.elemSize == arena.sizeClass.sizeIdx2size(sizeIdx) : + "doNotDestroy=" + nextSub.doNotDestroy + ", elemSize=" + nextSub.elemSize + ", sizeIdx=" + + sizeIdx; + handle = nextSub.allocate(); + assert handle >= 0; + assert isSubpage(handle); + nextSub.chunk.initBufWithSubpage(buf, null, handle, reqCapacity, cache); + return true; + } + handle = allocateSubpage(sizeIdx, head); + if (handle < 0) { + return false; + } + assert isSubpage(handle); + } finally { + head.unlock(); + } + } else { + // normal + // runSize must be multiple of pageSize + int runSize = arena.sizeClass.sizeIdx2size(sizeIdx); + handle = allocateRun(runSize); + if (handle < 0) { + return false; + } + assert !isSubpage(handle); + } + + ByteBuffer nioBuffer = cachedNioBuffers != null? 
cachedNioBuffers.pollLast() : null; + initBuf(buf, nioBuffer, handle, reqCapacity, cache); + return true; + } + + private long allocateRun(int runSize) { + int pages = runSize >> pageShifts; + int pageIdx = arena.sizeClass.pages2pageIdx(pages); + + runsAvailLock.lock(); + try { + //find first queue which has at least one big enough run + int queueIdx = runFirstBestFit(pageIdx); + if (queueIdx == -1) { + return -1; + } + + //get run with min offset in this queue + IntPriorityQueue queue = runsAvail[queueIdx]; + long handle = queue.poll(); + assert handle != IntPriorityQueue.NO_VALUE; + handle <<= BITMAP_IDX_BIT_LENGTH; + assert !isUsed(handle) : "invalid handle: " + handle; + + removeAvailRun0(handle); + + handle = splitLargeRun(handle, pages); + + int pinnedSize = runSize(pageShifts, handle); + freeBytes -= pinnedSize; + return handle; + } finally { + runsAvailLock.unlock(); + } + } + + private int calculateRunSize(int sizeIdx) { + int maxElements = 1 << pageShifts - SizeClasses.LOG2_QUANTUM; + int runSize = 0; + int nElements; + + final int elemSize = arena.sizeClass.sizeIdx2size(sizeIdx); + + //find lowest common multiple of pageSize and elemSize + do { + runSize += pageSize; + nElements = runSize / elemSize; + } while (nElements < maxElements && runSize != nElements * elemSize); + + while (nElements > maxElements) { + runSize -= pageSize; + nElements = runSize / elemSize; + } + + assert nElements > 0; + assert runSize <= chunkSize; + assert runSize >= elemSize; + + return runSize; + } + + private int runFirstBestFit(int pageIdx) { + if (freeBytes == chunkSize) { + return arena.sizeClass.nPSizes - 1; + } + for (int i = pageIdx; i < arena.sizeClass.nPSizes; i++) { + IntPriorityQueue queue = runsAvail[i]; + if (queue != null && !queue.isEmpty()) { + return i; + } + } + return -1; + } + + private long splitLargeRun(long handle, int needPages) { + assert needPages > 0; + + int totalPages = runPages(handle); + assert needPages <= totalPages; + + int remPages = 
totalPages - needPages; + + if (remPages > 0) { + int runOffset = runOffset(handle); + + // keep track of trailing unused pages for later use + int availOffset = runOffset + needPages; + long availRun = toRunHandle(availOffset, remPages, 0); + insertAvailRun(availOffset, remPages, availRun); + + // not avail + return toRunHandle(runOffset, needPages, 1); + } + + //mark it as used + handle |= 1L << IS_USED_SHIFT; + return handle; + } + + /** + * Create / initialize a new PoolSubpage of normCapacity. Any PoolSubpage created / initialized here is added to + * subpage pool in the PoolArena that owns this PoolChunk. + * + * @param sizeIdx sizeIdx of normalized size + * @param head head of subpages + * + * @return index in memoryMap + */ + private long allocateSubpage(int sizeIdx, PoolSubpage head) { + //allocate a new run + int runSize = calculateRunSize(sizeIdx); + //runSize must be multiples of pageSize + long runHandle = allocateRun(runSize); + if (runHandle < 0) { + return -1; + } + + int runOffset = runOffset(runHandle); + assert subpages[runOffset] == null; + int elemSize = arena.sizeClass.sizeIdx2size(sizeIdx); + + PoolSubpage subpage = new PoolSubpage(head, this, pageShifts, runOffset, + runSize(pageShifts, runHandle), elemSize); + + subpages[runOffset] = subpage; + return subpage.allocate(); + } + + /** + * Free a subpage or a run of pages When a subpage is freed from PoolSubpage, it might be added back to subpage pool + * of the owning PoolArena. 
If the subpage pool in PoolArena has at least one other PoolSubpage of given elemSize,
+     * we can completely free the owning Page so it is available for subsequent allocations
+     *
+     * @param handle handle to free
+     * @param normCapacity normalized capacity of the freed allocation (not read in this method)
+     * @param nioBuffer internal NIO buffer that may be cached for reuse with this chunk, may be {@code null}
+     */
+    void free(long handle, int normCapacity, ByteBuffer nioBuffer) {
+        if (isSubpage(handle)) {
+            int sIdx = runOffset(handle);
+            PoolSubpage subpage = subpages[sIdx];
+            assert subpage != null;
+            PoolSubpage head = subpage.chunk.arena.smallSubpagePools[subpage.headIndex];
+            // Obtain the head of the PoolSubPage pool that is owned by the PoolArena and synchronize on it.
+            // This is needed as we may add it back and so alter the linked-list structure.
+            head.lock();
+            try {
+                assert subpage.doNotDestroy;
+                if (subpage.free(head, bitmapIdx(handle))) {
+                    // the subpage is still used, do not free it
+                    return;
+                }
+                assert !subpage.doNotDestroy;
+                // Null out slot in the array as it was freed and we should not use it anymore.
+                subpages[sIdx] = null;
+            } finally {
+                head.unlock();
+            }
+        }
+
+        int runSize = runSize(pageShifts, handle);
+        // start free run
+        runsAvailLock.lock();
+        try {
+            // collapse continuous runs, successfully collapsed runs
+            // will be removed from runsAvail and runsAvailMap
+            long finalRun = collapseRuns(handle);
+
+            // set run as not used
+            finalRun &= ~(1L << IS_USED_SHIFT);
+            // if it is a subpage, set it to run
+            finalRun &= ~(1L << IS_SUBPAGE_SHIFT);
+
+            insertAvailRun(runOffset(finalRun), runPages(finalRun), finalRun);
+            freeBytes += runSize;
+        } finally {
+            runsAvailLock.unlock();
+        }
+
+        // Recycle the ByteBuffer wrapper if there is still room in the per-chunk cache.
+        if (nioBuffer != null && cachedNioBuffers != null &&
+            cachedNioBuffers.size() < PooledByteBufAllocator.DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK) {
+            cachedNioBuffers.offer(nioBuffer);
+        }
+    }
+
+    /**
+     * Merges the freed run with any adjacent free runs on both sides and returns the combined run handle.
+     */
+    private long collapseRuns(long handle) {
+        return collapseNext(collapsePast(handle));
+    }
+
+    /**
+     * Repeatedly merges the run with the free run that ends immediately before it, if one exists.
+     */
+    private long collapsePast(long handle) {
+        for (;;) {
+            int runOffset = runOffset(handle);
+            int runPages = runPages(handle);
+
+            long pastRun = getAvailRunByOffset(runOffset - 1);
+            if (pastRun == -1) {
+                return handle;
+            }
+
+            int pastOffset = runOffset(pastRun);
+            int pastPages = runPages(pastRun);
+
+            // is continuous
+            if (pastRun != handle && pastOffset + pastPages == runOffset) {
+                // remove past run
+                removeAvailRun(pastRun);
+                handle = toRunHandle(pastOffset, pastPages + runPages, 0);
+            } else {
+                return handle;
+            }
+        }
+    }
+
+    /**
+     * Repeatedly merges the run with the free run that starts immediately after it, if one exists.
+     */
+    private long collapseNext(long handle) {
+        for (;;) {
+            int runOffset = runOffset(handle);
+            int runPages = runPages(handle);
+
+            long nextRun = getAvailRunByOffset(runOffset + runPages);
+            if (nextRun == -1) {
+                return handle;
+            }
+
+            int nextOffset = runOffset(nextRun);
+            int nextPages = runPages(nextRun);
+
+            // is continuous
+            if (nextRun != handle && runOffset + runPages == nextOffset) {
+                // remove next run
+                removeAvailRun(nextRun);
+                handle = toRunHandle(runOffset, runPages + nextPages, 0);
+            } else {
+                return handle;
+            }
+        }
+    }
+
+    /**
+     * Packs the page offset, the page count and the used bit into a run handle (subpage bit left clear).
+     */
+    private static long toRunHandle(int runOffset, int runPages, int inUsed) {
+        return (long) runOffset << RUN_OFFSET_SHIFT
+               | (long) runPages << SIZE_SHIFT
+               | (long) inUsed << IS_USED_SHIFT;
+    }
+
+    /**
+     * Initializes the given buffer with the memory identified by {@code handle}; dispatches to the
+     * subpage variant when the handle's subpage bit is set.
+     */
+    void initBuf(PooledByteBuf buf, ByteBuffer nioBuffer, long handle, int reqCapacity,
+                 PoolThreadCache threadCache) {
+        if (isSubpage(handle)) {
+            initBufWithSubpage(buf, nioBuffer, handle, reqCapacity, threadCache);
+        } else {
+            int maxLength = runSize(pageShifts, handle);
+            buf.init(this, nioBuffer, handle, runOffset(handle) << pageShifts,
+                     reqCapacity, maxLength, arena.parent.threadCache());
+        }
+    }
+
+    /**
+     * Initializes the given buffer with a subpage element addressed by the handle's run offset and bitmap index.
+     */
+    void initBufWithSubpage(PooledByteBuf buf, ByteBuffer nioBuffer, long handle, int reqCapacity,
+                            PoolThreadCache threadCache) {
+        int runOffset = runOffset(handle);
+        int bitmapIdx = bitmapIdx(handle);
+
+        PoolSubpage s = subpages[runOffset];
+        assert s.isDoNotDestroy();
+        assert reqCapacity <= s.elemSize : reqCapacity + "<=" + s.elemSize;
+
+        // Byte offset of the element = start of the run + element index within the subpage.
+        int offset = (runOffset << pageShifts) + bitmapIdx * s.elemSize;
+        buf.init(this, nioBuffer, handle, offset, reqCapacity, s.elemSize, threadCache);
+    }
+
+    void incrementPinnedMemory(int delta) {
+        assert delta > 0;
+        pinnedBytes.add(delta);
+    }
+
+    void decrementPinnedMemory(int delta) {
+        assert delta > 0;
+        pinnedBytes.add(-delta);
+    }
+
+    @Override
+    public int chunkSize() {
+        return chunkSize;
+    }
+
+    @Override
+    public int freeBytes() {
+        if (this.unpooled) {
+            return freeBytes;
+        }
+        // Read under the lock so a concurrent free/allocate cannot publish a stale value.
+        runsAvailLock.lock();
+        try {
+            return freeBytes;
+        } finally {
+            runsAvailLock.unlock();
+        }
+    }
+
+    public int pinnedBytes() {
+        return (int) pinnedBytes.value();
+    }
+
+    @Override
+    public String toString() {
+        final int freeBytes;
+        if (this.unpooled) {
+            freeBytes = this.freeBytes;
+        } else {
+            runsAvailLock.lock();
+            try {
+                freeBytes = this.freeBytes;
+            } finally {
+                runsAvailLock.unlock();
+            }
+        }
+
+        return new StringBuilder()
+                .append("Chunk(")
+                .append(Integer.toHexString(System.identityHashCode(this)))
+                .append(": ")
+                .append(usage(freeBytes))
+                .append("%, ")
+                .append(chunkSize - freeBytes)
+                .append('/')
+                .append(chunkSize)
+                .append(')')
+                .toString();
+    }
+
+    void destroy() {
+        arena.destroyChunk(this);
+    }
+
+    // Extracts the page offset stored in the upper bits of a handle.
+    static int runOffset(long handle) {
+        return (int) (handle >> RUN_OFFSET_SHIFT);
+    }
+
+    // Size of a run in bytes (page count << pageShifts).
+    static int runSize(int pageShifts, long handle) {
+        return runPages(handle) << pageShifts;
+    }
+
+    // Number of pages in the run; the size field is 15 bits wide.
+    static int runPages(long handle) {
+        return (int) (handle >> SIZE_SHIFT & 0x7fff);
+    }
+
+    static boolean isUsed(long handle) {
+        return (handle >> IS_USED_SHIFT & 1) == 1L;
+    }
+
+    static boolean isRun(long handle) {
+        return !isSubpage(handle);
+    }
+
+    static boolean isSubpage(long handle) {
+        return (handle >> IS_SUBPAGE_SHIFT & 1) == 1L;
+    }
+
+    // The low 32 bits of a subpage handle carry the bitmap index of the element.
+    static int bitmapIdx(long handle) {
+        return (int) handle;
+    }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolChunkList.java b/netty-buffer/src/main/java/io/netty/buffer/PoolChunkList.java
new file mode 100644
index 0000000..19daec7
--- /dev/null
+++
b/netty-buffer/src/main/java/io/netty/buffer/PoolChunkList.java
@@ -0,0 +1,262 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.netty.buffer;
+
+import io.netty.util.internal.StringUtil;
+
+import java.util.ArrayList;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.List;
+
+import static java.lang.Math.*;
+
+import java.nio.ByteBuffer;
+
+final class PoolChunkList implements PoolChunkListMetric {
+    private static final Iterator EMPTY_METRICS = Collections.emptyList().iterator();
+    private final PoolArena arena;
+    private final PoolChunkList nextList;
+    private final int minUsage;
+    private final int maxUsage;
+    private final int maxCapacity;
+    private PoolChunk head;
+    private final int freeMinThreshold;
+    private final int freeMaxThreshold;
+
+    // This is only updated once, when the linked list of PoolChunkLists is created in the PoolArena constructor.
+    private PoolChunkList prevList;
+
+    // TODO: Test if adding padding helps under contention
+    //private long pad0, pad1, pad2, pad3, pad4, pad5, pad6, pad7;
+
+    PoolChunkList(PoolArena arena, PoolChunkList nextList, int minUsage, int maxUsage, int chunkSize) {
+        assert minUsage <= maxUsage;
+        this.arena = arena;
+        this.nextList = nextList;
+        this.minUsage = minUsage;
+        this.maxUsage = maxUsage;
+        maxCapacity = calculateMaxCapacity(minUsage, chunkSize);
+
+        // the thresholds are aligned with PoolChunk.usage() logic:
+        // 1) basic logic: usage() = 100 - freeBytes * 100L / chunkSize
+        //    so, for example: (usage() >= maxUsage) condition can be transformed in the following way:
+        //      100 - freeBytes * 100L / chunkSize >= maxUsage
+        //      freeBytes <= chunkSize * (100 - maxUsage) / 100
+        //      let freeMinThreshold = chunkSize * (100 - maxUsage) / 100, then freeBytes <= freeMinThreshold
+        //
+        // 2) usage() returns an int value and has a floor rounding during a calculation,
+        //    to be aligned absolute thresholds should be shifted for "the rounding step":
+        //      freeBytes * 100 / chunkSize < 1
+        //      the condition can be converted to: freeBytes < 1 * chunkSize / 100
+        //    this is why we have + 0.99999999 shifts. An example why just +1 shift cannot be used:
+        //      freeBytes = 16777216 == freeMaxThreshold: 16777216, usage = 0 < minUsage: 1, chunkSize: 16777216
+        //    At the same time we want to have zero thresholds in case of (maxUsage == 100) and (minUsage == 100).
+        //
+        freeMinThreshold = (maxUsage == 100) ? 0 : (int) (chunkSize * (100.0 - maxUsage + 0.99999999) / 100L);
+        freeMaxThreshold = (minUsage == 100) ? 0 : (int) (chunkSize * (100.0 - minUsage + 0.99999999) / 100L);
+    }
+
+    /**
+     * Calculates the maximum capacity of a buffer that will ever be possible to allocate out of the {@link PoolChunk}s
+     * that belong to the {@link PoolChunkList} with the given {@code minUsage} and {@code maxUsage} settings.
+     */
+    private static int calculateMaxCapacity(int minUsage, int chunkSize) {
+        minUsage = minUsage0(minUsage);
+
+        if (minUsage == 100) {
+            // If the minUsage is 100 we can not allocate anything out of this list.
+            return 0;
+        }
+
+        // Calculate the maximum amount of bytes that can be allocated from a PoolChunk in this PoolChunkList.
+        //
+        // As an example:
+        // - If a PoolChunkList has minUsage == 25 we are allowed to allocate at most 75% of the chunkSize because
+        //   this is the maximum amount available in any PoolChunk in this PoolChunkList.
+        return  (int) (chunkSize * (100L - minUsage) / 100L);
+    }
+
+    // One-shot setter used by PoolArena while wiring the doubly-linked list of PoolChunkLists.
+    void prevList(PoolChunkList prevList) {
+        assert this.prevList == null;
+        this.prevList = prevList;
+    }
+
+    /**
+     * Tries to allocate {@code sizeIdx} out of one of the chunks in this list; a chunk whose free space drops
+     * to/under {@code freeMinThreshold} is promoted to {@code nextList}. Returns {@code true} on success.
+     */
+    boolean allocate(PooledByteBuf buf, int reqCapacity, int sizeIdx, PoolThreadCache threadCache) {
+        int normCapacity = arena.sizeClass.sizeIdx2size(sizeIdx);
+        if (normCapacity > maxCapacity) {
+            // Either this PoolChunkList is empty or the requested capacity is larger than the capacity which can
+            // be handled by the PoolChunks that are contained in this PoolChunkList.
+            return false;
+        }
+
+        for (PoolChunk cur = head; cur != null; cur = cur.next) {
+            if (cur.allocate(buf, reqCapacity, sizeIdx, threadCache)) {
+                if (cur.freeBytes <= freeMinThreshold) {
+                    remove(cur);
+                    nextList.add(cur);
+                }
+                return true;
+            }
+        }
+        return false;
+    }
+
+    /**
+     * Frees the given handle back into the chunk. Returns {@code false} when the chunk became fully free and
+     * should be destroyed by the caller (i.e. it fell off the front of the linked list).
+     */
+    boolean free(PoolChunk chunk, long handle, int normCapacity, ByteBuffer nioBuffer) {
+        chunk.free(handle, normCapacity, nioBuffer);
+        if (chunk.freeBytes > freeMaxThreshold) {
+            remove(chunk);
+            // Move the PoolChunk down the PoolChunkList linked-list.
+            return move0(chunk);
+        }
+        return true;
+    }
+
+    private boolean move(PoolChunk chunk) {
+        assert chunk.usage() < maxUsage;
+
+        if (chunk.freeBytes > freeMaxThreshold) {
+            // Move the PoolChunk down the PoolChunkList linked-list.
+            return move0(chunk);
+        }
+
+        // PoolChunk fits into this PoolChunkList, adding it here.
+        add0(chunk);
+        return true;
+    }
+
+    /**
+     * Moves the {@link PoolChunk} down the {@link PoolChunkList} linked-list so it will end up in the right
+     * {@link PoolChunkList} that has the correct minUsage / maxUsage in respect to {@link PoolChunk#usage()}.
+     */
+    private boolean move0(PoolChunk chunk) {
+        if (prevList == null) {
+            // There is no previous PoolChunkList so return false which result in having the PoolChunk destroyed and
+            // all memory associated with the PoolChunk will be released.
+            assert chunk.usage() == 0;
+            return false;
+        }
+        return prevList.move(chunk);
+    }
+
+    // Adds a chunk, forwarding it to the next list when it is already too full for this one.
+    void add(PoolChunk chunk) {
+        if (chunk.freeBytes <= freeMinThreshold) {
+            nextList.add(chunk);
+            return;
+        }
+        add0(chunk);
+    }
+
+    /**
+     * Adds the {@link PoolChunk} to this {@link PoolChunkList} (always at the head of the list).
+     */
+    void add0(PoolChunk chunk) {
+        chunk.parent = this;
+        if (head == null) {
+            head = chunk;
+            chunk.prev = null;
+            chunk.next = null;
+        } else {
+            chunk.prev = null;
+            chunk.next = head;
+            head.prev = chunk;
+            head = chunk;
+        }
+    }
+
+    // Unlinks the chunk from this list's doubly-linked chain.
+    private void remove(PoolChunk cur) {
+        if (cur == head) {
+            head = cur.next;
+            if (head != null) {
+                head.prev = null;
+            }
+        } else {
+            PoolChunk next = cur.next;
+            cur.prev.next = next;
+            if (next != null) {
+                next.prev = cur.prev;
+            }
+        }
+    }
+
+    @Override
+    public int minUsage() {
+        return minUsage0(minUsage);
+    }
+
+    @Override
+    public int maxUsage() {
+        return min(maxUsage, 100);
+    }
+
+    // Usage percentages reported to metrics are clamped to at least 1.
+    private static int minUsage0(int value) {
+        return max(1, value);
+    }
+
+    @Override
+    public Iterator iterator() {
+        // Snapshot the chunk list under the arena lock so iteration is safe against concurrent add/remove.
+        arena.lock();
+        try {
+            if (head == null) {
+                return EMPTY_METRICS;
+            }
+            List metrics = new ArrayList();
+            for (PoolChunk cur = head;;) {
+                metrics.add(cur);
+                cur = cur.next;
+                if (cur == null) {
+                    break;
+                }
+            }
+            return metrics.iterator();
+        } finally {
+            arena.unlock();
+        }
+    }
+
+    @Override
+    public String toString() {
+        StringBuilder buf = new StringBuilder();
+        arena.lock();
+        try {
+            if (head == null) {
+                return "none";
+            }
+
+            for (PoolChunk cur = head;;) {
+                buf.append(cur);
+                cur = cur.next;
+                if (cur == null) {
+                    break;
+                }
+                buf.append(StringUtil.NEWLINE);
+            }
+        } finally {
+            arena.unlock();
+        }
+        return buf.toString();
+    }
+
+    // Destroys every chunk in this list; called when the owning arena is torn down.
+    void destroy(PoolArena arena) {
+        PoolChunk chunk = head;
+        while (chunk != null) {
+            arena.destroyChunk(chunk);
+            chunk = chunk.next;
+        }
+        head = null;
+    }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolChunkListMetric.java b/netty-buffer/src/main/java/io/netty/buffer/PoolChunkListMetric.java
new file mode 100644
index 0000000..ec45561
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/PoolChunkListMetric.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.buffer;
+
+/**
+ * Metrics for a list of chunks.
+ */
+public interface PoolChunkListMetric extends Iterable {
+
+    /**
+     * Return the minimum usage of the chunk list before which chunks are promoted to the previous list.
+     */
+    int minUsage();
+
+    /**
+     * Return the maximum usage of the chunk list after which chunks are promoted to the next list.
+     */
+    int maxUsage();
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolChunkMetric.java b/netty-buffer/src/main/java/io/netty/buffer/PoolChunkMetric.java
new file mode 100644
index 0000000..a006785
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/PoolChunkMetric.java
@@ -0,0 +1,37 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.buffer;
+
+/**
+ * Metrics for a chunk.
+ */
+public interface PoolChunkMetric {
+
+    /**
+     * Return the percentage of the current usage of the chunk.
+     */
+    int usage();
+
+    /**
+     * Return the size of the chunk in bytes, this is the maximum of bytes that can be served out of the chunk.
+     */
+    int chunkSize();
+
+    /**
+     * Return the number of free bytes in the chunk.
+     */
+    int freeBytes();
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolSubpage.java b/netty-buffer/src/main/java/io/netty/buffer/PoolSubpage.java
new file mode 100644
index 0000000..17e0442
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/PoolSubpage.java
@@ -0,0 +1,300 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.netty.buffer;
+
+import java.util.concurrent.locks.ReentrantLock;
+
+import static io.netty.buffer.PoolChunk.RUN_OFFSET_SHIFT;
+import static io.netty.buffer.PoolChunk.SIZE_SHIFT;
+import static io.netty.buffer.PoolChunk.IS_USED_SHIFT;
+import static io.netty.buffer.PoolChunk.IS_SUBPAGE_SHIFT;
+
+final class PoolSubpage implements PoolSubpageMetric {
+
+    final PoolChunk chunk;
+    final int elemSize;
+    private final int pageShifts;
+    private final int runOffset;
+    private final int runSize;
+    // One bit per element: a set bit means the element at that index is allocated.
+    private final long[] bitmap;
+    private final int bitmapLength;
+    private final int maxNumElems;
+    final int headIndex;
+
+    PoolSubpage prev;
+    PoolSubpage next;
+
+    boolean doNotDestroy;
+    // Hint for the next free element index, or -1 when a bitmap scan is required.
+    private int nextAvail;
+    private int numAvail;
+
+    final ReentrantLock lock;
+
+    // TODO: Test if adding padding helps under contention
+    //private long pad0, pad1, pad2, pad3, pad4, pad5, pad6, pad7;
+
+    /** Special constructor that creates a linked list head */
+    PoolSubpage(int headIndex) {
+        chunk = null;
+        lock = new ReentrantLock();
+        pageShifts = -1;
+        runOffset = -1;
+        elemSize = -1;
+        runSize = -1;
+        bitmap = null;
+        bitmapLength = -1;
+        maxNumElems = 0;
+        this.headIndex = headIndex;
+    }
+
+    PoolSubpage(PoolSubpage head, PoolChunk chunk, int pageShifts, int runOffset, int runSize, int elemSize) {
+        this.headIndex = head.headIndex;
+        this.chunk = chunk;
+        this.pageShifts = pageShifts;
+        this.runOffset = runOffset;
+        this.runSize = runSize;
+        this.elemSize = elemSize;
+
+        doNotDestroy = true;
+
+        maxNumElems = numAvail = runSize / elemSize;
+        // One long covers 64 elements; round the bitmap length up for the remainder.
+        int bitmapLength = maxNumElems >>> 6;
+        if ((maxNumElems & 63) != 0) {
+            bitmapLength ++;
+        }
+        this.bitmapLength = bitmapLength;
+        bitmap = new long[bitmapLength];
+        nextAvail = 0;
+
+        // Only the head of each pool owns a lock; regular subpages synchronize on their head.
+        lock = null;
+        addToPool(head);
+    }
+
+    /**
+     * Returns the bitmap index of the subpage allocation.
+     */
+    long allocate() {
+        if (numAvail == 0 || !doNotDestroy) {
+            return -1;
+        }
+
+        final int bitmapIdx = getNextAvail();
+        if (bitmapIdx < 0) {
+            removeFromPool(); // Subpage appear to be in an invalid state. Remove to prevent repeated errors.
+            throw new AssertionError("No next available bitmap index found (bitmapIdx = " + bitmapIdx + "), " +
+                    "even though there are supposed to be (numAvail = " + numAvail + ") " +
+                    "out of (maxNumElems = " + maxNumElems + ") available indexes.");
+        }
+        int q = bitmapIdx >>> 6;
+        int r = bitmapIdx & 63;
+        assert (bitmap[q] >>> r & 1) == 0;
+        bitmap[q] |= 1L << r;
+
+        if (-- numAvail == 0) {
+            // Fully used subpages are removed from the arena pool until an element is freed again.
+            removeFromPool();
+        }
+
+        return toHandle(bitmapIdx);
+    }
+
+    /**
+     * @return {@code true} if this subpage is in use.
+     *         {@code false} if this subpage is not used by its chunk and thus it's OK to be released.
+     */
+    boolean free(PoolSubpage head, int bitmapIdx) {
+        int q = bitmapIdx >>> 6;
+        int r = bitmapIdx & 63;
+        assert (bitmap[q] >>> r & 1) != 0;
+        bitmap[q] ^= 1L << r;
+
+        setNextAvail(bitmapIdx);
+
+        if (numAvail ++ == 0) {
+            addToPool(head);
+            /* When maxNumElems == 1, the maximum numAvail is also 1.
+             * Each of these PoolSubpages will go in here when they do free operation.
+             * If they return true directly from here, then the rest of the code will be unreachable
+             * and they will not actually be recycled. So return true only on maxNumElems > 1. */
+            if (maxNumElems > 1) {
+                return true;
+            }
+        }
+
+        if (numAvail != maxNumElems) {
+            return true;
+        } else {
+            // Subpage not in use (numAvail == maxNumElems)
+            if (prev == next) {
+                // Do not remove if this subpage is the only one left in the pool.
+                return true;
+            }
+
+            // Remove this subpage from the pool if there are other subpages left in the pool.
+            doNotDestroy = false;
+            removeFromPool();
+            return false;
+        }
+    }
+
+    // Links this subpage right after the head of its arena pool.
+    private void addToPool(PoolSubpage head) {
+        assert prev == null && next == null;
+        prev = head;
+        next = head.next;
+        next.prev = this;
+        head.next = this;
+    }
+
+    // Unlinks this subpage from its arena pool.
+    private void removeFromPool() {
+        assert prev != null && next != null;
+        prev.next = next;
+        next.prev = prev;
+        next = null;
+        prev = null;
+    }
+
+    private void setNextAvail(int bitmapIdx) {
+        nextAvail = bitmapIdx;
+    }
+
+    // Uses the cached hint when present, otherwise falls back to scanning the bitmap.
+    private int getNextAvail() {
+        int nextAvail = this.nextAvail;
+        if (nextAvail >= 0) {
+            this.nextAvail = -1;
+            return nextAvail;
+        }
+        return findNextAvail();
+    }
+
+    // Scans for the first bitmap word that still has a clear bit.
+    private int findNextAvail() {
+        for (int i = 0; i < bitmapLength; i ++) {
+            long bits = bitmap[i];
+            if (~bits != 0) {
+                return findNextAvail0(i, bits);
+            }
+        }
+        return -1;
+    }
+
+    // Finds the index of the first clear bit inside the word, ignoring padding bits past maxNumElems.
+    private int findNextAvail0(int i, long bits) {
+        final int baseVal = i << 6;
+        for (int j = 0; j < 64; j ++) {
+            if ((bits & 1) == 0) {
+                int val = baseVal | j;
+                if (val < maxNumElems) {
+                    return val;
+                } else {
+                    break;
+                }
+            }
+            bits >>>= 1;
+        }
+        return -1;
+    }
+
+    // Encodes a handle with both the used and subpage bits set, plus the element's bitmap index.
+    private long toHandle(int bitmapIdx) {
+        int pages = runSize >> pageShifts;
+        return (long) runOffset << RUN_OFFSET_SHIFT
+               | (long) pages << SIZE_SHIFT
+               | 1L << IS_USED_SHIFT
+               | 1L << IS_SUBPAGE_SHIFT
+               | bitmapIdx;
+    }
+
+    @Override
+    public String toString() {
+        final int numAvail;
+        if (chunk == null) {
+            // This is the head so there is no need to synchronize at all as these never change.
+            numAvail = 0;
+        } else {
+            final boolean doNotDestroy;
+            PoolSubpage head = chunk.arena.smallSubpagePools[headIndex];
+            head.lock();
+            try {
+                doNotDestroy = this.doNotDestroy;
+                numAvail = this.numAvail;
+            } finally {
+                head.unlock();
+            }
+            if (!doNotDestroy) {
+                // Not used for creating the String.
+                return "(" + runOffset + ": not in use)";
+            }
+        }
+
+        return "(" + this.runOffset + ": " + (this.maxNumElems - numAvail) + '/' + this.maxNumElems +
+                ", offset: " + this.runOffset + ", length: " + this.runSize + ", elemSize: " + this.elemSize + ')';
+    }
+
+    @Override
+    public int maxNumElements() {
+        return maxNumElems;
+    }
+
+    @Override
+    public int numAvailable() {
+        if (chunk == null) {
+            // It's the head.
+            return 0;
+        }
+        PoolSubpage head = chunk.arena.smallSubpagePools[headIndex];
+        head.lock();
+        try {
+            return numAvail;
+        } finally {
+            head.unlock();
+        }
+    }
+
+    @Override
+    public int elementSize() {
+        return elemSize;
+    }
+
+    @Override
+    public int pageSize() {
+        return 1 << pageShifts;
+    }
+
+    boolean isDoNotDestroy() {
+        if (chunk == null) {
+            // It's the head.
+            return true;
+        }
+        PoolSubpage head = chunk.arena.smallSubpagePools[headIndex];
+        head.lock();
+        try {
+            return doNotDestroy;
+        } finally {
+            head.unlock();
+        }
+    }
+
+    void destroy() {
+        if (chunk != null) {
+            chunk.destroy();
+        }
+    }
+
+    // lock()/unlock() are only meaningful on pool heads, which are the only subpages holding a lock.
+    void lock() {
+        lock.lock();
+    }
+
+    void unlock() {
+        lock.unlock();
+    }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolSubpageMetric.java b/netty-buffer/src/main/java/io/netty/buffer/PoolSubpageMetric.java
new file mode 100644
index 0000000..c010273
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/PoolSubpageMetric.java
@@ -0,0 +1,43 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.buffer;
+
+/**
+ * Metrics for a sub-page.
+ */
+public interface PoolSubpageMetric {
+
+    /**
+     * Return the number of maximal elements that can be allocated out of the sub-page.
+     */
+    int maxNumElements();
+
+    /**
+     * Return the number of available elements to be allocated.
+     */
+    int numAvailable();
+
+    /**
+     * Return the size (in bytes) of the elements that will be allocated.
+     */
+    int elementSize();
+
+    /**
+     * Return the page size (in bytes) of this page.
+     */
+    int pageSize();
+}
+
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PoolThreadCache.java b/netty-buffer/src/main/java/io/netty/buffer/PoolThreadCache.java
new file mode 100644
index 0000000..3c10b5a
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/PoolThreadCache.java
@@ -0,0 +1,504 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ + +package io.netty.buffer; + + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.buffer.PoolArena.SizeClass; +import io.netty.util.Recycler.EnhancedHandle; +import io.netty.util.internal.MathUtil; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Acts a Thread cache for allocations. This implementation is moduled after + * jemalloc and the descripted + * technics of + * + * Scalable memory allocation using jemalloc. + */ +final class PoolThreadCache { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PoolThreadCache.class); + private static final int INTEGER_SIZE_MINUS_ONE = Integer.SIZE - 1; + + final PoolArena heapArena; + final PoolArena directArena; + + // Hold the caches for the different size classes, which are small and normal. + private final MemoryRegionCache[] smallSubPageHeapCaches; + private final MemoryRegionCache[] smallSubPageDirectCaches; + private final MemoryRegionCache[] normalHeapCaches; + private final MemoryRegionCache[] normalDirectCaches; + + private final int freeSweepAllocationThreshold; + private final AtomicBoolean freed = new AtomicBoolean(); + @SuppressWarnings("unused") // Field is only here for the finalizer. 
+ private final FreeOnFinalize freeOnFinalize; + + private int allocations; + + // TODO: Test if adding padding helps under contention + //private long pad0, pad1, pad2, pad3, pad4, pad5, pad6, pad7; + + PoolThreadCache(PoolArena heapArena, PoolArena directArena, + int smallCacheSize, int normalCacheSize, int maxCachedBufferCapacity, + int freeSweepAllocationThreshold, boolean useFinalizer) { + checkPositiveOrZero(maxCachedBufferCapacity, "maxCachedBufferCapacity"); + this.freeSweepAllocationThreshold = freeSweepAllocationThreshold; + this.heapArena = heapArena; + this.directArena = directArena; + if (directArena != null) { + smallSubPageDirectCaches = createSubPageCaches(smallCacheSize, directArena.sizeClass.nSubpages); + normalDirectCaches = createNormalCaches(normalCacheSize, maxCachedBufferCapacity, directArena); + directArena.numThreadCaches.getAndIncrement(); + } else { + // No directArea is configured so just null out all caches + smallSubPageDirectCaches = null; + normalDirectCaches = null; + } + if (heapArena != null) { + // Create the caches for the heap allocations + smallSubPageHeapCaches = createSubPageCaches(smallCacheSize, heapArena.sizeClass.nSubpages); + normalHeapCaches = createNormalCaches(normalCacheSize, maxCachedBufferCapacity, heapArena); + heapArena.numThreadCaches.getAndIncrement(); + } else { + // No heapArea is configured so just null out all caches + smallSubPageHeapCaches = null; + normalHeapCaches = null; + } + + // Only check if there are caches in use. + if ((smallSubPageDirectCaches != null || normalDirectCaches != null + || smallSubPageHeapCaches != null || normalHeapCaches != null) + && freeSweepAllocationThreshold < 1) { + throw new IllegalArgumentException("freeSweepAllocationThreshold: " + + freeSweepAllocationThreshold + " (expected: > 0)"); + } + freeOnFinalize = useFinalizer ? 
new FreeOnFinalize(this) : null; + } + + private static MemoryRegionCache[] createSubPageCaches( + int cacheSize, int numCaches) { + if (cacheSize > 0 && numCaches > 0) { + @SuppressWarnings("unchecked") + MemoryRegionCache[] cache = new MemoryRegionCache[numCaches]; + for (int i = 0; i < cache.length; i++) { + // TODO: maybe use cacheSize / cache.length + cache[i] = new SubPageMemoryRegionCache(cacheSize); + } + return cache; + } else { + return null; + } + } + + @SuppressWarnings("unchecked") + private static MemoryRegionCache[] createNormalCaches( + int cacheSize, int maxCachedBufferCapacity, PoolArena area) { + if (cacheSize > 0 && maxCachedBufferCapacity > 0) { + int max = Math.min(area.sizeClass.chunkSize, maxCachedBufferCapacity); + // Create as many normal caches as we support based on how many sizeIdx we have and what the upper + // bound is that we want to cache in general. + List> cache = new ArrayList>() ; + for (int idx = area.sizeClass.nSubpages; idx < area.sizeClass.nSizes && + area.sizeClass.sizeIdx2size(idx) <= max; idx++) { + cache.add(new NormalMemoryRegionCache(cacheSize)); + } + return cache.toArray(new MemoryRegionCache[0]); + } else { + return null; + } + } + + // val > 0 + static int log2(int val) { + return INTEGER_SIZE_MINUS_ONE - Integer.numberOfLeadingZeros(val); + } + + /** + * Try to allocate a small buffer out of the cache. Returns {@code true} if successful {@code false} otherwise + */ + boolean allocateSmall(PoolArena area, PooledByteBuf buf, int reqCapacity, int sizeIdx) { + return allocate(cacheForSmall(area, sizeIdx), buf, reqCapacity); + } + + /** + * Try to allocate a normal buffer out of the cache. 
Returns {@code true} if successful {@code false} otherwise + */ + boolean allocateNormal(PoolArena area, PooledByteBuf buf, int reqCapacity, int sizeIdx) { + return allocate(cacheForNormal(area, sizeIdx), buf, reqCapacity); + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + private boolean allocate(MemoryRegionCache cache, PooledByteBuf buf, int reqCapacity) { + if (cache == null) { + // no cache found so just return false here + return false; + } + boolean allocated = cache.allocate(buf, reqCapacity, this); + if (++ allocations >= freeSweepAllocationThreshold) { + allocations = 0; + trim(); + } + return allocated; + } + + /** + * Add {@link PoolChunk} and {@code handle} to the cache if there is enough room. + * Returns {@code true} if it fit into the cache {@code false} otherwise. + */ + @SuppressWarnings({ "unchecked", "rawtypes" }) + boolean add(PoolArena area, PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int normCapacity, SizeClass sizeClass) { + int sizeIdx = area.sizeClass.size2SizeIdx(normCapacity); + MemoryRegionCache cache = cache(area, sizeIdx, sizeClass); + if (cache == null) { + return false; + } + if (freed.get()) { + return false; + } + return cache.add(chunk, nioBuffer, handle, normCapacity); + } + + private MemoryRegionCache cache(PoolArena area, int sizeIdx, SizeClass sizeClass) { + switch (sizeClass) { + case Normal: + return cacheForNormal(area, sizeIdx); + case Small: + return cacheForSmall(area, sizeIdx); + default: + throw new Error(); + } + } + + /** + * Should be called if the Thread that uses this cache is about to exist to release resources out of the cache + */ + void free(boolean finalizer) { + // As free() may be called either by the finalizer or by FastThreadLocal.onRemoval(...) we need to ensure + // we only call this one time. 
+ if (freed.compareAndSet(false, true)) { + int numFreed = free(smallSubPageDirectCaches, finalizer) + + free(normalDirectCaches, finalizer) + + free(smallSubPageHeapCaches, finalizer) + + free(normalHeapCaches, finalizer); + + if (numFreed > 0 && logger.isDebugEnabled()) { + logger.debug("Freed {} thread-local buffer(s) from thread: {}", numFreed, + Thread.currentThread().getName()); + } + + if (directArena != null) { + directArena.numThreadCaches.getAndDecrement(); + } + + if (heapArena != null) { + heapArena.numThreadCaches.getAndDecrement(); + } + } else { + // See https://github.com/netty/netty/issues/12749 + checkCacheMayLeak(smallSubPageDirectCaches, "SmallSubPageDirectCaches"); + checkCacheMayLeak(normalDirectCaches, "NormalDirectCaches"); + checkCacheMayLeak(smallSubPageHeapCaches, "SmallSubPageHeapCaches"); + checkCacheMayLeak(normalHeapCaches, "NormalHeapCaches"); + } + } + + private static void checkCacheMayLeak(MemoryRegionCache[] caches, String type) { + for (MemoryRegionCache cache : caches) { + if (!cache.queue.isEmpty()) { + logger.debug("{} memory may leak.", type); + return; + } + } + } + + private static int free(MemoryRegionCache[] caches, boolean finalizer) { + if (caches == null) { + return 0; + } + + int numFreed = 0; + for (MemoryRegionCache c: caches) { + numFreed += free(c, finalizer); + } + return numFreed; + } + + private static int free(MemoryRegionCache cache, boolean finalizer) { + if (cache == null) { + return 0; + } + return cache.free(finalizer); + } + + void trim() { + trim(smallSubPageDirectCaches); + trim(normalDirectCaches); + trim(smallSubPageHeapCaches); + trim(normalHeapCaches); + } + + private static void trim(MemoryRegionCache[] caches) { + if (caches == null) { + return; + } + for (MemoryRegionCache c: caches) { + trim(c); + } + } + + private static void trim(MemoryRegionCache cache) { + if (cache == null) { + return; + } + cache.trim(); + } + + private MemoryRegionCache cacheForSmall(PoolArena area, int sizeIdx) { + if 
(area.isDirect()) { + return cache(smallSubPageDirectCaches, sizeIdx); + } + return cache(smallSubPageHeapCaches, sizeIdx); + } + + private MemoryRegionCache cacheForNormal(PoolArena area, int sizeIdx) { + // We need to subtract area.sizeClass.nSubpages as sizeIdx is the overall index for all sizes. + int idx = sizeIdx - area.sizeClass.nSubpages; + if (area.isDirect()) { + return cache(normalDirectCaches, idx); + } + return cache(normalHeapCaches, idx); + } + + private static MemoryRegionCache cache(MemoryRegionCache[] cache, int sizeIdx) { + if (cache == null || sizeIdx > cache.length - 1) { + return null; + } + return cache[sizeIdx]; + } + + /** + * Cache used for buffers which are backed by TINY or SMALL size. + */ + private static final class SubPageMemoryRegionCache extends MemoryRegionCache { + SubPageMemoryRegionCache(int size) { + super(size, SizeClass.Small); + } + + @Override + protected void initBuf( + PoolChunk chunk, ByteBuffer nioBuffer, long handle, PooledByteBuf buf, int reqCapacity, + PoolThreadCache threadCache) { + chunk.initBufWithSubpage(buf, nioBuffer, handle, reqCapacity, threadCache); + } + } + + /** + * Cache used for buffers which are backed by NORMAL size. 
+ */ + private static final class NormalMemoryRegionCache extends MemoryRegionCache { + NormalMemoryRegionCache(int size) { + super(size, SizeClass.Normal); + } + + @Override + protected void initBuf( + PoolChunk chunk, ByteBuffer nioBuffer, long handle, PooledByteBuf buf, int reqCapacity, + PoolThreadCache threadCache) { + chunk.initBuf(buf, nioBuffer, handle, reqCapacity, threadCache); + } + } + + private abstract static class MemoryRegionCache { + private final int size; + private final Queue> queue; + private final SizeClass sizeClass; + private int allocations; + + MemoryRegionCache(int size, SizeClass sizeClass) { + this.size = MathUtil.safeFindNextPositivePowerOfTwo(size); + queue = PlatformDependent.newFixedMpscQueue(this.size); + this.sizeClass = sizeClass; + } + + /** + * Init the {@link PooledByteBuf} using the provided chunk and handle with the capacity restrictions. + */ + protected abstract void initBuf(PoolChunk chunk, ByteBuffer nioBuffer, long handle, + PooledByteBuf buf, int reqCapacity, PoolThreadCache threadCache); + + /** + * Add to cache if not already full. + */ + @SuppressWarnings("unchecked") + public final boolean add(PoolChunk chunk, ByteBuffer nioBuffer, long handle, int normCapacity) { + Entry entry = newEntry(chunk, nioBuffer, handle, normCapacity); + boolean queued = queue.offer(entry); + if (!queued) { + // If it was not possible to cache the chunk, immediately recycle the entry + entry.unguardedRecycle(); + } + + return queued; + } + + /** + * Allocate something out of the cache if possible and remove the entry from the cache. + */ + public final boolean allocate(PooledByteBuf buf, int reqCapacity, PoolThreadCache threadCache) { + Entry entry = queue.poll(); + if (entry == null) { + return false; + } + initBuf(entry.chunk, entry.nioBuffer, entry.handle, buf, reqCapacity, threadCache); + entry.unguardedRecycle(); + + // allocations is not thread-safe which is fine as this is only called from the same thread all time. 
+ ++ allocations; + return true; + } + + /** + * Clear out this cache and free up all previous cached {@link PoolChunk}s and {@code handle}s. + */ + public final int free(boolean finalizer) { + return free(Integer.MAX_VALUE, finalizer); + } + + private int free(int max, boolean finalizer) { + int numFreed = 0; + for (; numFreed < max; numFreed++) { + Entry entry = queue.poll(); + if (entry != null) { + freeEntry(entry, finalizer); + } else { + // all cleared + return numFreed; + } + } + return numFreed; + } + + /** + * Free up cached {@link PoolChunk}s if not allocated frequently enough. + */ + public final void trim() { + int free = size - allocations; + allocations = 0; + + // We not even allocated all the number that are + if (free > 0) { + free(free, false); + } + } + + @SuppressWarnings({ "unchecked", "rawtypes" }) + private void freeEntry(Entry entry, boolean finalizer) { + // Capture entry state before we recycle the entry object. + PoolChunk chunk = entry.chunk; + long handle = entry.handle; + ByteBuffer nioBuffer = entry.nioBuffer; + int normCapacity = entry.normCapacity; + + if (!finalizer) { + // recycle now so PoolChunk can be GC'ed. This will only be done if this is not freed because of + // a finalizer. 
+ entry.recycle(); + } + + chunk.arena.freeChunk(chunk, handle, normCapacity, sizeClass, nioBuffer, finalizer); + } + + static final class Entry { + final EnhancedHandle> recyclerHandle; + PoolChunk chunk; + ByteBuffer nioBuffer; + long handle = -1; + int normCapacity; + + Entry(Handle> recyclerHandle) { + this.recyclerHandle = (EnhancedHandle>) recyclerHandle; + } + + void recycle() { + chunk = null; + nioBuffer = null; + handle = -1; + recyclerHandle.recycle(this); + } + + void unguardedRecycle() { + chunk = null; + nioBuffer = null; + handle = -1; + recyclerHandle.unguardedRecycle(this); + } + } + + @SuppressWarnings("rawtypes") + private static Entry newEntry(PoolChunk chunk, ByteBuffer nioBuffer, long handle, int normCapacity) { + Entry entry = RECYCLER.get(); + entry.chunk = chunk; + entry.nioBuffer = nioBuffer; + entry.handle = handle; + entry.normCapacity = normCapacity; + return entry; + } + + @SuppressWarnings("rawtypes") + private static final ObjectPool RECYCLER = ObjectPool.newPool(new ObjectCreator() { + @SuppressWarnings("unchecked") + @Override + public Entry newObject(Handle handle) { + return new Entry(handle); + } + }); + } + + private static final class FreeOnFinalize { + private final PoolThreadCache cache; + + private FreeOnFinalize(PoolThreadCache cache) { + this.cache = cache; + } + + /// TODO: In the future when we move to Java9+ we should use java.lang.ref.Cleaner. 
+ @SuppressWarnings({"FinalizeDeclaration", "deprecation"}) + @Override + protected void finalize() throws Throwable { + try { + super.finalize(); + } finally { + cache.free(true); + } + } + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledByteBuf.java new file mode 100644 index 0000000..8e404c9 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledByteBuf.java @@ -0,0 +1,269 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.buffer; + +import io.netty.util.Recycler.EnhancedHandle; +import io.netty.util.internal.ObjectPool.Handle; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +abstract class PooledByteBuf extends AbstractReferenceCountedByteBuf { + + private final EnhancedHandle> recyclerHandle; + + protected PoolChunk chunk; + protected long handle; + protected T memory; + protected int offset; + protected int length; + int maxLength; + PoolThreadCache cache; + ByteBuffer tmpNioBuf; + private ByteBufAllocator allocator; + + @SuppressWarnings("unchecked") + protected PooledByteBuf(Handle> recyclerHandle, int maxCapacity) { + super(maxCapacity); + this.recyclerHandle = (EnhancedHandle>) recyclerHandle; + } + + void init(PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int offset, int length, int maxLength, PoolThreadCache cache) { + init0(chunk, nioBuffer, handle, offset, length, maxLength, cache); + } + + void initUnpooled(PoolChunk chunk, int length) { + init0(chunk, null, 0, 0, length, length, null); + } + + private void init0(PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int offset, int length, int maxLength, PoolThreadCache cache) { + assert handle >= 0; + assert chunk != null; + assert !PoolChunk.isSubpage(handle) || + chunk.arena.sizeClass.size2SizeIdx(maxLength) <= chunk.arena.sizeClass.smallMaxSizeIdx: + "Allocated small sub-page handle for a buffer size that isn't \"small.\""; + + chunk.incrementPinnedMemory(maxLength); + this.chunk = chunk; + memory = chunk.memory; + tmpNioBuf = nioBuffer; + allocator = chunk.arena.parent; + this.cache = cache; + this.handle = handle; + this.offset = offset; + this.length = length; + this.maxLength = maxLength; + } + + /** + * Method must be called before reuse this {@link 
PooledByteBufAllocator} + */ + final void reuse(int maxCapacity) { + maxCapacity(maxCapacity); + resetRefCnt(); + setIndex0(0, 0); + discardMarks(); + } + + @Override + public final int capacity() { + return length; + } + + @Override + public int maxFastWritableBytes() { + return Math.min(maxLength, maxCapacity()) - writerIndex; + } + + @Override + public final ByteBuf capacity(int newCapacity) { + if (newCapacity == length) { + ensureAccessible(); + return this; + } + checkNewCapacity(newCapacity); + if (!chunk.unpooled) { + // If the request capacity does not require reallocation, just update the length of the memory. + if (newCapacity > length) { + if (newCapacity <= maxLength) { + length = newCapacity; + return this; + } + } else if (newCapacity > maxLength >>> 1 && + (maxLength > 512 || newCapacity > maxLength - 16)) { + // here newCapacity < length + length = newCapacity; + trimIndicesToCapacity(newCapacity); + return this; + } + } + + // Reallocation required. + chunk.arena.reallocate(this, newCapacity); + return this; + } + + @Override + public final ByteBufAllocator alloc() { + return allocator; + } + + @Override + public final ByteOrder order() { + return ByteOrder.BIG_ENDIAN; + } + + @Override + public final ByteBuf unwrap() { + return null; + } + + @Override + public final ByteBuf retainedDuplicate() { + return PooledDuplicatedByteBuf.newInstance(this, this, readerIndex(), writerIndex()); + } + + @Override + public final ByteBuf retainedSlice() { + final int index = readerIndex(); + return retainedSlice(index, writerIndex() - index); + } + + @Override + public final ByteBuf retainedSlice(int index, int length) { + return PooledSlicedByteBuf.newInstance(this, this, index, length); + } + + protected final ByteBuffer internalNioBuffer() { + ByteBuffer tmpNioBuf = this.tmpNioBuf; + if (tmpNioBuf == null) { + this.tmpNioBuf = tmpNioBuf = newInternalNioBuffer(memory); + } else { + tmpNioBuf.clear(); + } + return tmpNioBuf; + } + + protected abstract 
ByteBuffer newInternalNioBuffer(T memory); + + @Override + protected final void deallocate() { + if (handle >= 0) { + final long handle = this.handle; + this.handle = -1; + memory = null; + chunk.arena.free(chunk, tmpNioBuf, handle, maxLength, cache); + tmpNioBuf = null; + chunk = null; + cache = null; + this.recyclerHandle.unguardedRecycle(this); + } + } + + protected final int idx(int index) { + return offset + index; + } + + final ByteBuffer _internalNioBuffer(int index, int length, boolean duplicate) { + index = idx(index); + ByteBuffer buffer = duplicate ? newInternalNioBuffer(memory) : internalNioBuffer(); + buffer.limit(index + length).position(index); + return buffer; + } + + ByteBuffer duplicateInternalNioBuffer(int index, int length) { + checkIndex(index, length); + return _internalNioBuffer(index, length, true); + } + + @Override + public final ByteBuffer internalNioBuffer(int index, int length) { + checkIndex(index, length); + return _internalNioBuffer(index, length, false); + } + + @Override + public final int nioBufferCount() { + return 1; + } + + @Override + public final ByteBuffer nioBuffer(int index, int length) { + return duplicateInternalNioBuffer(index, length).slice(); + } + + @Override + public final ByteBuffer[] nioBuffers(int index, int length) { + return new ByteBuffer[] { nioBuffer(index, length) }; + } + + @Override + public final boolean isContiguous() { + return true; + } + + @Override + public final int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return out.write(duplicateInternalNioBuffer(index, length)); + } + + @Override + public final int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = out.write(_internalNioBuffer(readerIndex, length, false)); + readerIndex += readBytes; + return readBytes; + } + + @Override + public final int getBytes(int index, FileChannel out, long position, int length) throws IOException { + return 
out.write(duplicateInternalNioBuffer(index, length), position); + } + + @Override + public final int readBytes(FileChannel out, long position, int length) throws IOException { + checkReadableBytes(length); + int readBytes = out.write(_internalNioBuffer(readerIndex, length, false), position); + readerIndex += readBytes; + return readBytes; + } + + @Override + public final int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + try { + return in.read(internalNioBuffer(index, length)); + } catch (ClosedChannelException ignored) { + return -1; + } + } + + @Override + public final int setBytes(int index, FileChannel in, long position, int length) throws IOException { + try { + return in.read(internalNioBuffer(index, length), position); + } catch (ClosedChannelException ignored) { + return -1; + } + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java b/netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java new file mode 100644 index 0000000..cfd1c75 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocator.java @@ -0,0 +1,763 @@ +package io.netty.buffer; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.util.NettyRuntime; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.concurrent.FastThreadLocalThread; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThreadExecutorMap; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.TimeUnit; + +public class PooledByteBufAllocator extends AbstractByteBufAllocator implements 
ByteBufAllocatorMetricProvider { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PooledByteBufAllocator.class); + private static final int DEFAULT_NUM_HEAP_ARENA; + private static final int DEFAULT_NUM_DIRECT_ARENA; + + private static final int DEFAULT_PAGE_SIZE; + private static final int DEFAULT_MAX_ORDER; // 8192 << 9 = 4 MiB per chunk + private static final int DEFAULT_SMALL_CACHE_SIZE; + private static final int DEFAULT_NORMAL_CACHE_SIZE; + static final int DEFAULT_MAX_CACHED_BUFFER_CAPACITY; + private static final int DEFAULT_CACHE_TRIM_INTERVAL; + private static final long DEFAULT_CACHE_TRIM_INTERVAL_MILLIS; + private static final boolean DEFAULT_USE_CACHE_FOR_ALL_THREADS; + private static final int DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT; + static final int DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK; + + private static final int MIN_PAGE_SIZE = 4096; + private static final int MAX_CHUNK_SIZE = (int) (((long) Integer.MAX_VALUE + 1) / 2); + + private static final int CACHE_NOT_USED = 0; + + private final Runnable trimTask = new Runnable() { + @Override + public void run() { + PooledByteBufAllocator.this.trimCurrentThreadCache(); + } + }; + + static { + int defaultAlignment = SystemPropertyUtil.getInt( + "io.netty.allocator.directMemoryCacheAlignment", 0); + int defaultPageSize = SystemPropertyUtil.getInt("io.netty.allocator.pageSize", 8192); + Throwable pageSizeFallbackCause = null; + try { + validateAndCalculatePageShifts(defaultPageSize, defaultAlignment); + } catch (Throwable t) { + pageSizeFallbackCause = t; + defaultPageSize = 8192; + defaultAlignment = 0; + } + DEFAULT_PAGE_SIZE = defaultPageSize; + DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT = defaultAlignment; + + int defaultMaxOrder = SystemPropertyUtil.getInt("io.netty.allocator.maxOrder", 9); + Throwable maxOrderFallbackCause = null; + try { + validateAndCalculateChunkSize(DEFAULT_PAGE_SIZE, defaultMaxOrder); + } catch (Throwable t) { + maxOrderFallbackCause = t; + 
defaultMaxOrder = 9; + } + DEFAULT_MAX_ORDER = defaultMaxOrder; + + // Determine reasonable default for nHeapArena and nDirectArena. + // Assuming each arena has 3 chunks, the pool should not consume more than 50% of max memory. + final Runtime runtime = Runtime.getRuntime(); + + /* + * We use 2 * available processors by default to reduce contention as we use 2 * available processors for the + * number of EventLoops in NIO and EPOLL as well. If we choose a smaller number we will run into hot spots as + * allocation and de-allocation needs to be synchronized on the PoolArena. + * + * See https://github.com/netty/netty/issues/3888. + */ + final int defaultMinNumArena = NettyRuntime.availableProcessors() * 2; + final int defaultChunkSize = DEFAULT_PAGE_SIZE << DEFAULT_MAX_ORDER; + DEFAULT_NUM_HEAP_ARENA = Math.max(0, + SystemPropertyUtil.getInt( + "io.netty.allocator.numHeapArenas", + (int) Math.min( + defaultMinNumArena, + runtime.maxMemory() / defaultChunkSize / 2 / 3))); + DEFAULT_NUM_DIRECT_ARENA = Math.max(0, + SystemPropertyUtil.getInt( + "io.netty.allocator.numDirectArenas", + (int) Math.min( + defaultMinNumArena, + PlatformDependent.maxDirectMemory() / defaultChunkSize / 2 / 3))); + + // cache sizes + DEFAULT_SMALL_CACHE_SIZE = SystemPropertyUtil.getInt("io.netty.allocator.smallCacheSize", 256); + DEFAULT_NORMAL_CACHE_SIZE = SystemPropertyUtil.getInt("io.netty.allocator.normalCacheSize", 64); + + // 32 kb is the default maximum capacity of the cached buffer. 
Similar to what is explained in + // 'Scalable memory allocation using jemalloc' + DEFAULT_MAX_CACHED_BUFFER_CAPACITY = SystemPropertyUtil.getInt( + "io.netty.allocator.maxCachedBufferCapacity", 32 * 1024); + + // the number of threshold of allocations when cached entries will be freed up if not frequently used + DEFAULT_CACHE_TRIM_INTERVAL = SystemPropertyUtil.getInt( + "io.netty.allocator.cacheTrimInterval", 8192); + + if (SystemPropertyUtil.contains("io.netty.allocation.cacheTrimIntervalMillis")) { + logger.warn("-Dio.netty.allocation.cacheTrimIntervalMillis is deprecated," + + " use -Dio.netty.allocator.cacheTrimIntervalMillis"); + + if (SystemPropertyUtil.contains("io.netty.allocator.cacheTrimIntervalMillis")) { + // Both system properties are specified. Use the non-deprecated one. + DEFAULT_CACHE_TRIM_INTERVAL_MILLIS = SystemPropertyUtil.getLong( + "io.netty.allocator.cacheTrimIntervalMillis", 0); + } else { + DEFAULT_CACHE_TRIM_INTERVAL_MILLIS = SystemPropertyUtil.getLong( + "io.netty.allocation.cacheTrimIntervalMillis", 0); + } + } else { + DEFAULT_CACHE_TRIM_INTERVAL_MILLIS = SystemPropertyUtil.getLong( + "io.netty.allocator.cacheTrimIntervalMillis", 0); + } + + DEFAULT_USE_CACHE_FOR_ALL_THREADS = SystemPropertyUtil.getBoolean( + "io.netty.allocator.useCacheForAllThreads", false); + + // Use 1023 by default as we use an ArrayDeque as backing storage which will then allocate an internal array + // of 1024 elements. Otherwise we would allocate 2048 and only use 1024 which is wasteful. 
+ DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK = SystemPropertyUtil.getInt( + "io.netty.allocator.maxCachedByteBuffersPerChunk", 1023); + + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.allocator.numHeapArenas: {}", DEFAULT_NUM_HEAP_ARENA); + logger.debug("-Dio.netty.allocator.numDirectArenas: {}", DEFAULT_NUM_DIRECT_ARENA); + if (pageSizeFallbackCause == null) { + logger.debug("-Dio.netty.allocator.pageSize: {}", DEFAULT_PAGE_SIZE); + } else { + logger.debug("-Dio.netty.allocator.pageSize: {}", DEFAULT_PAGE_SIZE, pageSizeFallbackCause); + } + if (maxOrderFallbackCause == null) { + logger.debug("-Dio.netty.allocator.maxOrder: {}", DEFAULT_MAX_ORDER); + } else { + logger.debug("-Dio.netty.allocator.maxOrder: {}", DEFAULT_MAX_ORDER, maxOrderFallbackCause); + } + logger.debug("-Dio.netty.allocator.chunkSize: {}", DEFAULT_PAGE_SIZE << DEFAULT_MAX_ORDER); + logger.debug("-Dio.netty.allocator.smallCacheSize: {}", DEFAULT_SMALL_CACHE_SIZE); + logger.debug("-Dio.netty.allocator.normalCacheSize: {}", DEFAULT_NORMAL_CACHE_SIZE); + logger.debug("-Dio.netty.allocator.maxCachedBufferCapacity: {}", DEFAULT_MAX_CACHED_BUFFER_CAPACITY); + logger.debug("-Dio.netty.allocator.cacheTrimInterval: {}", DEFAULT_CACHE_TRIM_INTERVAL); + logger.debug("-Dio.netty.allocator.cacheTrimIntervalMillis: {}", DEFAULT_CACHE_TRIM_INTERVAL_MILLIS); + logger.debug("-Dio.netty.allocator.useCacheForAllThreads: {}", DEFAULT_USE_CACHE_FOR_ALL_THREADS); + logger.debug("-Dio.netty.allocator.maxCachedByteBuffersPerChunk: {}", + DEFAULT_MAX_CACHED_BYTEBUFFERS_PER_CHUNK); + } + } + + public static final PooledByteBufAllocator DEFAULT = + new PooledByteBufAllocator(PlatformDependent.directBufferPreferred()); + + private final PoolArena[] heapArenas; + private final PoolArena[] directArenas; + private final int smallCacheSize; + private final int normalCacheSize; + private final List heapArenaMetrics; + private final List directArenaMetrics; + private final PoolThreadLocalCache threadCache; + private 
final int chunkSize; + private final PooledByteBufAllocatorMetric metric; + + public PooledByteBufAllocator() { + this(false); + } + + @SuppressWarnings("deprecation") + public PooledByteBufAllocator(boolean preferDirect) { + this(preferDirect, DEFAULT_NUM_HEAP_ARENA, DEFAULT_NUM_DIRECT_ARENA, DEFAULT_PAGE_SIZE, DEFAULT_MAX_ORDER); + } + + @SuppressWarnings("deprecation") + public PooledByteBufAllocator(int nHeapArena, int nDirectArena, int pageSize, int maxOrder) { + this(false, nHeapArena, nDirectArena, pageSize, maxOrder); + } + + /** + * @deprecated use + * {@link PooledByteBufAllocator#PooledByteBufAllocator(boolean, int, int, int, int, int, int, boolean)} + */ + @Deprecated + public PooledByteBufAllocator(boolean preferDirect, int nHeapArena, int nDirectArena, int pageSize, int maxOrder) { + this(preferDirect, nHeapArena, nDirectArena, pageSize, maxOrder, + 0, DEFAULT_SMALL_CACHE_SIZE, DEFAULT_NORMAL_CACHE_SIZE); + } + + /** + * @deprecated use + * {@link PooledByteBufAllocator#PooledByteBufAllocator(boolean, int, int, int, int, int, int, boolean)} + */ + @Deprecated + public PooledByteBufAllocator(boolean preferDirect, int nHeapArena, int nDirectArena, int pageSize, int maxOrder, + int tinyCacheSize, int smallCacheSize, int normalCacheSize) { + this(preferDirect, nHeapArena, nDirectArena, pageSize, maxOrder, smallCacheSize, + normalCacheSize, DEFAULT_USE_CACHE_FOR_ALL_THREADS, DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT); + } + + /** + * @deprecated use + * {@link PooledByteBufAllocator#PooledByteBufAllocator(boolean, int, int, int, int, int, int, boolean)} + */ + @Deprecated + public PooledByteBufAllocator(boolean preferDirect, int nHeapArena, + int nDirectArena, int pageSize, int maxOrder, int tinyCacheSize, + int smallCacheSize, int normalCacheSize, + boolean useCacheForAllThreads) { + this(preferDirect, nHeapArena, nDirectArena, pageSize, maxOrder, + smallCacheSize, normalCacheSize, + useCacheForAllThreads); + } + + public PooledByteBufAllocator(boolean 
preferDirect, int nHeapArena, + int nDirectArena, int pageSize, int maxOrder, + int smallCacheSize, int normalCacheSize, + boolean useCacheForAllThreads) { + this(preferDirect, nHeapArena, nDirectArena, pageSize, maxOrder, + smallCacheSize, normalCacheSize, + useCacheForAllThreads, DEFAULT_DIRECT_MEMORY_CACHE_ALIGNMENT); + } + + /** + * @deprecated use + * {@link PooledByteBufAllocator#PooledByteBufAllocator(boolean, int, int, int, int, int, int, boolean, int)} + */ + @Deprecated + public PooledByteBufAllocator(boolean preferDirect, int nHeapArena, int nDirectArena, int pageSize, int maxOrder, + int tinyCacheSize, int smallCacheSize, int normalCacheSize, + boolean useCacheForAllThreads, int directMemoryCacheAlignment) { + this(preferDirect, nHeapArena, nDirectArena, pageSize, maxOrder, + smallCacheSize, normalCacheSize, + useCacheForAllThreads, directMemoryCacheAlignment); + } + + public PooledByteBufAllocator(boolean preferDirect, int nHeapArena, int nDirectArena, int pageSize, int maxOrder, + int smallCacheSize, int normalCacheSize, + boolean useCacheForAllThreads, int directMemoryCacheAlignment) { + super(preferDirect); + threadCache = new PoolThreadLocalCache(useCacheForAllThreads); + this.smallCacheSize = smallCacheSize; + this.normalCacheSize = normalCacheSize; + + if (directMemoryCacheAlignment != 0) { + if (!PlatformDependent.hasAlignDirectByteBuffer()) { + throw new UnsupportedOperationException("Buffer alignment is not supported. " + + "Either Unsafe or ByteBuffer.alignSlice() must be available."); + } + + // Ensure page size is a whole multiple of the alignment, or bump it to the next whole multiple. 
+ pageSize = (int) PlatformDependent.align(pageSize, directMemoryCacheAlignment); + } + + chunkSize = validateAndCalculateChunkSize(pageSize, maxOrder); + + checkPositiveOrZero(nHeapArena, "nHeapArena"); + checkPositiveOrZero(nDirectArena, "nDirectArena"); + + checkPositiveOrZero(directMemoryCacheAlignment, "directMemoryCacheAlignment"); + if (directMemoryCacheAlignment > 0 && !isDirectMemoryCacheAlignmentSupported()) { + throw new IllegalArgumentException("directMemoryCacheAlignment is not supported"); + } + + if ((directMemoryCacheAlignment & -directMemoryCacheAlignment) != directMemoryCacheAlignment) { + throw new IllegalArgumentException("directMemoryCacheAlignment: " + + directMemoryCacheAlignment + " (expected: power of two)"); + } + + int pageShifts = validateAndCalculatePageShifts(pageSize, directMemoryCacheAlignment); + + if (nHeapArena > 0) { + heapArenas = newArenaArray(nHeapArena); + List metrics = new ArrayList(heapArenas.length); + final SizeClasses sizeClasses = new SizeClasses(pageSize, pageShifts, chunkSize, 0); + for (int i = 0; i < heapArenas.length; i ++) { + PoolArena.HeapArena arena = new PoolArena.HeapArena(this, sizeClasses); + heapArenas[i] = arena; + metrics.add(arena); + } + heapArenaMetrics = Collections.unmodifiableList(metrics); + } else { + heapArenas = null; + heapArenaMetrics = Collections.emptyList(); + } + + if (nDirectArena > 0) { + directArenas = newArenaArray(nDirectArena); + List metrics = new ArrayList(directArenas.length); + final SizeClasses sizeClasses = new SizeClasses(pageSize, pageShifts, chunkSize, + directMemoryCacheAlignment); + for (int i = 0; i < directArenas.length; i ++) { + PoolArena.DirectArena arena = new PoolArena.DirectArena(this, sizeClasses); + directArenas[i] = arena; + metrics.add(arena); + } + directArenaMetrics = Collections.unmodifiableList(metrics); + } else { + directArenas = null; + directArenaMetrics = Collections.emptyList(); + } + metric = new PooledByteBufAllocatorMetric(this); + } + + 
@SuppressWarnings("unchecked") + private static PoolArena[] newArenaArray(int size) { + return new PoolArena[size]; + } + + private static int validateAndCalculatePageShifts(int pageSize, int alignment) { + if (pageSize < MIN_PAGE_SIZE) { + throw new IllegalArgumentException("pageSize: " + pageSize + " (expected: " + MIN_PAGE_SIZE + ')'); + } + + if ((pageSize & pageSize - 1) != 0) { + throw new IllegalArgumentException("pageSize: " + pageSize + " (expected: power of 2)"); + } + + if (pageSize < alignment) { + throw new IllegalArgumentException("Alignment cannot be greater than page size. " + + "Alignment: " + alignment + ", page size: " + pageSize + '.'); + } + + // Logarithm base 2. At this point we know that pageSize is a power of two. + return Integer.SIZE - 1 - Integer.numberOfLeadingZeros(pageSize); + } + + private static int validateAndCalculateChunkSize(int pageSize, int maxOrder) { + if (maxOrder > 14) { + throw new IllegalArgumentException("maxOrder: " + maxOrder + " (expected: 0-14)"); + } + + // Ensure the resulting chunkSize does not overflow. + int chunkSize = pageSize; + for (int i = maxOrder; i > 0; i --) { + if (chunkSize > MAX_CHUNK_SIZE / 2) { + throw new IllegalArgumentException(String.format( + "pageSize (%d) << maxOrder (%d) must not exceed %d", pageSize, maxOrder, MAX_CHUNK_SIZE)); + } + chunkSize <<= 1; + } + return chunkSize; + } + + @Override + protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) { + PoolThreadCache cache = threadCache.get(); + PoolArena heapArena = cache.heapArena; + + final ByteBuf buf; + if (heapArena != null) { + buf = heapArena.allocate(cache, initialCapacity, maxCapacity); + } else { + buf = PlatformDependent.hasUnsafe() ? 
+ new UnpooledUnsafeHeapByteBuf(this, initialCapacity, maxCapacity) : + new UnpooledHeapByteBuf(this, initialCapacity, maxCapacity); + } + + return toLeakAwareBuffer(buf); + } + + @Override + protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + PoolThreadCache cache = threadCache.get(); + PoolArena directArena = cache.directArena; + + final ByteBuf buf; + if (directArena != null) { + buf = directArena.allocate(cache, initialCapacity, maxCapacity); + } else { + buf = PlatformDependent.hasUnsafe() ? + UnsafeByteBufUtil.newUnsafeDirectByteBuf(this, initialCapacity, maxCapacity) : + new UnpooledDirectByteBuf(this, initialCapacity, maxCapacity); + } + + return toLeakAwareBuffer(buf); + } + + /** + * Default number of heap arenas - System Property: io.netty.allocator.numHeapArenas - default 2 * cores + */ + public static int defaultNumHeapArena() { + return DEFAULT_NUM_HEAP_ARENA; + } + + /** + * Default number of direct arenas - System Property: io.netty.allocator.numDirectArenas - default 2 * cores + */ + public static int defaultNumDirectArena() { + return DEFAULT_NUM_DIRECT_ARENA; + } + + /** + * Default buffer page size - System Property: io.netty.allocator.pageSize - default 8192 + */ + public static int defaultPageSize() { + return DEFAULT_PAGE_SIZE; + } + + /** + * Default maximum order - System Property: io.netty.allocator.maxOrder - default 9 + */ + public static int defaultMaxOrder() { + return DEFAULT_MAX_ORDER; + } + + /** + * Default thread caching behavior - System Property: io.netty.allocator.useCacheForAllThreads - default false + */ + public static boolean defaultUseCacheForAllThreads() { + return DEFAULT_USE_CACHE_FOR_ALL_THREADS; + } + + /** + * Default prefer direct - System Property: io.netty.noPreferDirect - default false + */ + public static boolean defaultPreferDirect() { + return PlatformDependent.directBufferPreferred(); + } + + /** + * Default tiny cache size - default 0 + * + * @deprecated Tiny caches have been merged 
into small caches. + */ + @Deprecated + public static int defaultTinyCacheSize() { + return 0; + } + + /** + * Default small cache size - System Property: io.netty.allocator.smallCacheSize - default 256 + */ + public static int defaultSmallCacheSize() { + return DEFAULT_SMALL_CACHE_SIZE; + } + + /** + * Default normal cache size - System Property: io.netty.allocator.normalCacheSize - default 64 + */ + public static int defaultNormalCacheSize() { + return DEFAULT_NORMAL_CACHE_SIZE; + } + + /** + * Return {@code true} if direct memory cache alignment is supported, {@code false} otherwise. + */ + public static boolean isDirectMemoryCacheAlignmentSupported() { + return PlatformDependent.hasUnsafe(); + } + + @Override + public boolean isDirectBufferPooled() { + return directArenas != null; + } + + /** + * @deprecated will be removed + * Returns {@code true} if the calling {@link Thread} has a {@link ThreadLocal} cache for the allocated + * buffers. + */ + @Deprecated + public boolean hasThreadLocalCache() { + return threadCache.isSet(); + } + + /** + * @deprecated will be removed + * Free all cached buffers for the calling {@link Thread}. 
+ */ + @Deprecated + public void freeThreadLocalCache() { + threadCache.remove(); + } + + private final class PoolThreadLocalCache extends FastThreadLocal { + private final boolean useCacheForAllThreads; + + PoolThreadLocalCache(boolean useCacheForAllThreads) { + this.useCacheForAllThreads = useCacheForAllThreads; + } + + @Override + protected synchronized PoolThreadCache initialValue() { + final PoolArena heapArena = leastUsedArena(heapArenas); + final PoolArena directArena = leastUsedArena(directArenas); + + final Thread current = Thread.currentThread(); + final EventExecutor executor = ThreadExecutorMap.currentExecutor(); + + if (useCacheForAllThreads || + // If the current thread is a FastThreadLocalThread we will always use the cache + current instanceof FastThreadLocalThread || + // The Thread is used by an EventExecutor, let's use the cache as the chances are good that we + // will allocate a lot! + executor != null) { + final PoolThreadCache cache = new PoolThreadCache( + heapArena, directArena, smallCacheSize, normalCacheSize, + DEFAULT_MAX_CACHED_BUFFER_CAPACITY, DEFAULT_CACHE_TRIM_INTERVAL, true); + + if (DEFAULT_CACHE_TRIM_INTERVAL_MILLIS > 0) { + if (executor != null) { + executor.scheduleAtFixedRate(trimTask, DEFAULT_CACHE_TRIM_INTERVAL_MILLIS, + DEFAULT_CACHE_TRIM_INTERVAL_MILLIS, TimeUnit.MILLISECONDS); + } + } + return cache; + } + // No caching so just use 0 as sizes. 
+ return new PoolThreadCache(heapArena, directArena, 0, 0, 0, 0, false); + } + + @Override + protected void onRemoval(PoolThreadCache threadCache) { + threadCache.free(false); + } + + private PoolArena leastUsedArena(PoolArena[] arenas) { + if (arenas == null || arenas.length == 0) { + return null; + } + + PoolArena minArena = arenas[0]; + //optimized + //If it is the first execution, directly return minarena and reduce the number of for loop comparisons below + if (minArena.numThreadCaches.get() == CACHE_NOT_USED) { + return minArena; + } + for (int i = 1; i < arenas.length; i++) { + PoolArena arena = arenas[i]; + if (arena.numThreadCaches.get() < minArena.numThreadCaches.get()) { + minArena = arena; + } + } + + return minArena; + } + } + + @Override + public PooledByteBufAllocatorMetric metric() { + return metric; + } + + /** + * Return the number of heap arenas. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#numHeapArenas()}. + */ + @Deprecated + public int numHeapArenas() { + return heapArenaMetrics.size(); + } + + /** + * Return the number of direct arenas. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#numDirectArenas()}. + */ + @Deprecated + public int numDirectArenas() { + return directArenaMetrics.size(); + } + + /** + * Return a {@link List} of all heap {@link PoolArenaMetric}s that are provided by this pool. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#heapArenas()}. + */ + @Deprecated + public List heapArenas() { + return heapArenaMetrics; + } + + /** + * Return a {@link List} of all direct {@link PoolArenaMetric}s that are provided by this pool. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#directArenas()}. + */ + @Deprecated + public List directArenas() { + return directArenaMetrics; + } + + /** + * Return the number of thread local caches used by this {@link PooledByteBufAllocator}. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#numThreadLocalCaches()}. 
+ */ + @Deprecated + public int numThreadLocalCaches() { + PoolArena[] arenas = heapArenas != null ? heapArenas : directArenas; + if (arenas == null) { + return 0; + } + + int total = 0; + for (PoolArena arena : arenas) { + total += arena.numThreadCaches.get(); + } + + return total; + } + + /** + * Return the size of the tiny cache. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#tinyCacheSize()}. + */ + @Deprecated + public int tinyCacheSize() { + return 0; + } + + /** + * Return the size of the small cache. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#smallCacheSize()}. + */ + @Deprecated + public int smallCacheSize() { + return smallCacheSize; + } + + /** + * Return the size of the normal cache. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#normalCacheSize()}. + */ + @Deprecated + public int normalCacheSize() { + return normalCacheSize; + } + + /** + * Return the chunk size for an arena. + * + * @deprecated use {@link PooledByteBufAllocatorMetric#chunkSize()}. + */ + @Deprecated + public final int chunkSize() { + return chunkSize; + } + + final long usedHeapMemory() { + return usedMemory(heapArenas); + } + + final long usedDirectMemory() { + return usedMemory(directArenas); + } + + private static long usedMemory(PoolArena[] arenas) { + if (arenas == null) { + return -1; + } + long used = 0; + for (PoolArena arena : arenas) { + used += arena.numActiveBytes(); + if (used < 0) { + return Long.MAX_VALUE; + } + } + return used; + } + + /** + * Returns the number of bytes of heap memory that is currently pinned to heap buffers allocated by a + * {@link ByteBufAllocator}, or {@code -1} if unknown. + * A buffer can pin more memory than its {@linkplain ByteBuf#capacity() capacity} might indicate, + * due to implementation details of the allocator. 
+ */ + public final long pinnedHeapMemory() { + return pinnedMemory(heapArenas); + } + + /** + * Returns the number of bytes of direct memory that is currently pinned to direct buffers allocated by a + * {@link ByteBufAllocator}, or {@code -1} if unknown. + * A buffer can pin more memory than its {@linkplain ByteBuf#capacity() capacity} might indicate, + * due to implementation details of the allocator. + */ + public final long pinnedDirectMemory() { + return pinnedMemory(directArenas); + } + + private static long pinnedMemory(PoolArena[] arenas) { + if (arenas == null) { + return -1; + } + long used = 0; + for (PoolArena arena : arenas) { + used += arena.numPinnedBytes(); + if (used < 0) { + return Long.MAX_VALUE; + } + } + return used; + } + + final PoolThreadCache threadCache() { + PoolThreadCache cache = threadCache.get(); + assert cache != null; + return cache; + } + + /** + * Trim thread local cache for the current {@link Thread}, which will give back any cached memory that was not + * allocated frequently since the last trim operation. + * + * Returns {@code true} if a cache for the current {@link Thread} exists and so was trimmed, false otherwise. + */ + public boolean trimCurrentThreadCache() { + PoolThreadCache cache = threadCache.getIfExists(); + if (cache != null) { + cache.trim(); + return true; + } + return false; + } + + /** + * Returns the status of the allocator (which contains all metrics) as string. Be aware this may be expensive + * and so should not called too frequently. + */ + public String dumpStats() { + int heapArenasLen = heapArenas == null ? 0 : heapArenas.length; + StringBuilder buf = new StringBuilder(512) + .append(heapArenasLen) + .append(" heap arena(s):") + .append(StringUtil.NEWLINE); + if (heapArenasLen > 0) { + for (PoolArena a: heapArenas) { + buf.append(a); + } + } + + int directArenasLen = directArenas == null ? 
0 : directArenas.length; + + buf.append(directArenasLen) + .append(" direct arena(s):") + .append(StringUtil.NEWLINE); + if (directArenasLen > 0) { + for (PoolArena a: directArenas) { + buf.append(a); + } + } + + return buf.toString(); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocatorMetric.java b/netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocatorMetric.java new file mode 100644 index 0000000..f20dfff --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledByteBufAllocatorMetric.java @@ -0,0 +1,124 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.StringUtil; + +import java.util.List; + +/** + * Exposed metric for {@link PooledByteBufAllocator}. + */ +@SuppressWarnings("deprecation") +public final class PooledByteBufAllocatorMetric implements ByteBufAllocatorMetric { + + private final PooledByteBufAllocator allocator; + + PooledByteBufAllocatorMetric(PooledByteBufAllocator allocator) { + this.allocator = allocator; + } + + /** + * Return the number of heap arenas. + */ + public int numHeapArenas() { + return allocator.numHeapArenas(); + } + + /** + * Return the number of direct arenas. 
+ */ + public int numDirectArenas() { + return allocator.numDirectArenas(); + } + + /** + * Return a {@link List} of all heap {@link PoolArenaMetric}s that are provided by this pool. + */ + public List heapArenas() { + return allocator.heapArenas(); + } + + /** + * Return a {@link List} of all direct {@link PoolArenaMetric}s that are provided by this pool. + */ + public List directArenas() { + return allocator.directArenas(); + } + + /** + * Return the number of thread local caches used by this {@link PooledByteBufAllocator}. + */ + public int numThreadLocalCaches() { + return allocator.numThreadLocalCaches(); + } + + /** + * Return the size of the tiny cache. + * + * @deprecated Tiny caches have been merged into small caches. + */ + @Deprecated + public int tinyCacheSize() { + return allocator.tinyCacheSize(); + } + + /** + * Return the size of the small cache. + */ + public int smallCacheSize() { + return allocator.smallCacheSize(); + } + + /** + * Return the size of the normal cache. + */ + public int normalCacheSize() { + return allocator.normalCacheSize(); + } + + /** + * Return the chunk size for an arena. 
+ */ + public int chunkSize() { + return allocator.chunkSize(); + } + + @Override + public long usedHeapMemory() { + return allocator.usedHeapMemory(); + } + + @Override + public long usedDirectMemory() { + return allocator.usedDirectMemory(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(256); + sb.append(StringUtil.simpleClassName(this)) + .append("(usedHeapMemory: ").append(usedHeapMemory()) + .append("; usedDirectMemory: ").append(usedDirectMemory()) + .append("; numHeapArenas: ").append(numHeapArenas()) + .append("; numDirectArenas: ").append(numDirectArenas()) + .append("; smallCacheSize: ").append(smallCacheSize()) + .append("; normalCacheSize: ").append(normalCacheSize()) + .append("; numThreadLocalCaches: ").append(numThreadLocalCaches()) + .append("; chunkSize: ").append(chunkSize()).append(')'); + return sb.toString(); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java new file mode 100644 index 0000000..338b677 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledDirectByteBuf.java @@ -0,0 +1,313 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.buffer; + +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +final class PooledDirectByteBuf extends PooledByteBuf { + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public PooledDirectByteBuf newObject(Handle handle) { + return new PooledDirectByteBuf(handle, 0); + } + }); + + static PooledDirectByteBuf newInstance(int maxCapacity) { + PooledDirectByteBuf buf = RECYCLER.get(); + buf.reuse(maxCapacity); + return buf; + } + + private PooledDirectByteBuf(Handle recyclerHandle, int maxCapacity) { + super(recyclerHandle, maxCapacity); + } + + @Override + protected ByteBuffer newInternalNioBuffer(ByteBuffer memory) { + return memory.duplicate(); + } + + @Override + public boolean isDirect() { + return true; + } + + @Override + protected byte _getByte(int index) { + return memory.get(idx(index)); + } + + @Override + protected short _getShort(int index) { + return memory.getShort(idx(index)); + } + + @Override + protected short _getShortLE(int index) { + return ByteBufUtil.swapShort(_getShort(index)); + } + + @Override + protected int _getUnsignedMedium(int index) { + index = idx(index); + return (memory.get(index) & 0xff) << 16 | + (memory.get(index + 1) & 0xff) << 8 | + memory.get(index + 2) & 0xff; + } + + @Override + protected int _getUnsignedMediumLE(int index) { + index = idx(index); + return memory.get(index) & 0xff | + (memory.get(index + 1) & 0xff) << 8 | + (memory.get(index + 2) & 0xff) << 16; + } + + @Override + protected int _getInt(int index) { + return memory.getInt(idx(index)); + } + + @Override + protected int _getIntLE(int index) { + return ByteBufUtil.swapInt(_getInt(index)); + } + + @Override + protected long _getLong(int index) { + return 
memory.getLong(idx(index)); + } + + @Override + protected long _getLongLE(int index) { + return ByteBufUtil.swapLong(_getLong(index)); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (dst.hasArray()) { + getBytes(index, dst.array(), dst.arrayOffset() + dstIndex, length); + } else if (dst.nioBufferCount() > 0) { + for (ByteBuffer bb: dst.nioBuffers(dstIndex, length)) { + int bbLen = bb.remaining(); + getBytes(index, bb); + index += bbLen; + } + } else { + dst.setBytes(dstIndex, this, index, length); + } + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.length); + _internalNioBuffer(index, length, true).get(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + checkDstIndex(length, dstIndex, dst.length); + _internalNioBuffer(readerIndex, length, false).get(dst, dstIndex, length); + readerIndex += length; + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + dst.put(duplicateInternalNioBuffer(index, dst.remaining())); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + int length = dst.remaining(); + checkReadableBytes(length); + dst.put(_internalNioBuffer(readerIndex, length, false)); + readerIndex += length; + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + getBytes(index, out, length, false); + return this; + } + + private void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException { + checkIndex(index, length); + if (length == 0) { + return; + } + ByteBufUtil.readBytes(alloc(), internal ? 
internalNioBuffer() : memory.duplicate(), idx(index), length, out); + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) throws IOException { + checkReadableBytes(length); + getBytes(readerIndex, out, length, true); + readerIndex += length; + return this; + } + + @Override + protected void _setByte(int index, int value) { + memory.put(idx(index), (byte) value); + } + + @Override + protected void _setShort(int index, int value) { + memory.putShort(idx(index), (short) value); + } + + @Override + protected void _setShortLE(int index, int value) { + _setShort(index, ByteBufUtil.swapShort((short) value)); + } + + @Override + protected void _setMedium(int index, int value) { + index = idx(index); + memory.put(index, (byte) (value >>> 16)); + memory.put(index + 1, (byte) (value >>> 8)); + memory.put(index + 2, (byte) value); + } + + @Override + protected void _setMediumLE(int index, int value) { + index = idx(index); + memory.put(index, (byte) value); + memory.put(index + 1, (byte) (value >>> 8)); + memory.put(index + 2, (byte) (value >>> 16)); + } + + @Override + protected void _setInt(int index, int value) { + memory.putInt(idx(index), value); + } + + @Override + protected void _setIntLE(int index, int value) { + _setInt(index, ByteBufUtil.swapInt(value)); + } + + @Override + protected void _setLong(int index, long value) { + memory.putLong(idx(index), value); + } + + @Override + protected void _setLongLE(int index, long value) { + _setLong(index, ByteBufUtil.swapLong(value)); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + checkSrcIndex(index, length, srcIndex, src.capacity()); + if (src.hasArray()) { + setBytes(index, src.array(), src.arrayOffset() + srcIndex, length); + } else if (src.nioBufferCount() > 0) { + for (ByteBuffer bb: src.nioBuffers(srcIndex, length)) { + int bbLen = bb.remaining(); + setBytes(index, bb); + index += bbLen; + } + } else { + src.getBytes(srcIndex, this, index, length); 
+ } + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + checkSrcIndex(index, length, srcIndex, src.length); + _internalNioBuffer(index, length, false).put(src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + int length = src.remaining(); + checkIndex(index, length); + ByteBuffer tmpBuf = internalNioBuffer(); + if (src == tmpBuf) { + src = src.duplicate(); + } + + index = idx(index); + tmpBuf.limit(index + length).position(index); + tmpBuf.put(src); + return this; + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + checkIndex(index, length); + byte[] tmp = ByteBufUtil.threadLocalTempArray(length); + int readBytes = in.read(tmp, 0, length); + if (readBytes <= 0) { + return readBytes; + } + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.position(idx(index)); + tmpBuf.put(tmp, 0, readBytes); + return readBytes; + } + + @Override + public ByteBuf copy(int index, int length) { + checkIndex(index, length); + ByteBuf copy = alloc().directBuffer(length, maxCapacity()); + return copy.writeBytes(this, index, length); + } + + @Override + public boolean hasArray() { + return false; + } + + @Override + public byte[] array() { + throw new UnsupportedOperationException("direct buffer"); + } + + @Override + public int arrayOffset() { + throw new UnsupportedOperationException("direct buffer"); + } + + @Override + public boolean hasMemoryAddress() { + return false; + } + + @Override + public long memoryAddress() { + throw new UnsupportedOperationException(); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledDuplicatedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledDuplicatedByteBuf.java new file mode 100644 index 0000000..717a249 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledDuplicatedByteBuf.java @@ -0,0 +1,378 @@ +/* + * Copyright 2016 The Netty Project 
+ * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +final class PooledDuplicatedByteBuf extends AbstractPooledDerivedByteBuf { + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public PooledDuplicatedByteBuf newObject(Handle handle) { + return new PooledDuplicatedByteBuf(handle); + } + }); + + static PooledDuplicatedByteBuf newInstance(AbstractByteBuf unwrapped, ByteBuf wrapped, + int readerIndex, int writerIndex) { + final PooledDuplicatedByteBuf duplicate = RECYCLER.get(); + duplicate.init(unwrapped, wrapped, readerIndex, writerIndex, unwrapped.maxCapacity()); + duplicate.markReaderIndex(); + duplicate.markWriterIndex(); + + return duplicate; + } + + private PooledDuplicatedByteBuf(Handle handle) { + super(handle); + } + + @Override + public int capacity() { + return unwrap().capacity(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + unwrap().capacity(newCapacity); + return this; 
+ } + + @Override + public int arrayOffset() { + return unwrap().arrayOffset(); + } + + @Override + public long memoryAddress() { + return unwrap().memoryAddress(); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + return unwrap().nioBuffer(index, length); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + return unwrap().nioBuffers(index, length); + } + + @Override + public ByteBuf copy(int index, int length) { + return unwrap().copy(index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + return PooledSlicedByteBuf.newInstance(unwrap(), this, index, length); + } + + @Override + public ByteBuf duplicate() { + return duplicate0().setIndex(readerIndex(), writerIndex()); + } + + @Override + public ByteBuf retainedDuplicate() { + return PooledDuplicatedByteBuf.newInstance(unwrap(), this, readerIndex(), writerIndex()); + } + + @Override + public byte getByte(int index) { + return unwrap().getByte(index); + } + + @Override + protected byte _getByte(int index) { + return unwrap()._getByte(index); + } + + @Override + public short getShort(int index) { + return unwrap().getShort(index); + } + + @Override + protected short _getShort(int index) { + return unwrap()._getShort(index); + } + + @Override + public short getShortLE(int index) { + return unwrap().getShortLE(index); + } + + @Override + protected short _getShortLE(int index) { + return unwrap()._getShortLE(index); + } + + @Override + public int getUnsignedMedium(int index) { + return unwrap().getUnsignedMedium(index); + } + + @Override + protected int _getUnsignedMedium(int index) { + return unwrap()._getUnsignedMedium(index); + } + + @Override + public int getUnsignedMediumLE(int index) { + return unwrap().getUnsignedMediumLE(index); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return unwrap()._getUnsignedMediumLE(index); + } + + @Override + public int getInt(int index) { + return unwrap().getInt(index); + 
} + + @Override + protected int _getInt(int index) { + return unwrap()._getInt(index); + } + + @Override + public int getIntLE(int index) { + return unwrap().getIntLE(index); + } + + @Override + protected int _getIntLE(int index) { + return unwrap()._getIntLE(index); + } + + @Override + public long getLong(int index) { + return unwrap().getLong(index); + } + + @Override + protected long _getLong(int index) { + return unwrap()._getLong(index); + } + + @Override + public long getLongLE(int index) { + return unwrap().getLongLE(index); + } + + @Override + protected long _getLongLE(int index) { + return unwrap()._getLongLE(index); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + unwrap().getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + unwrap().getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + unwrap().getBytes(index, dst); + return this; + } + + @Override + public ByteBuf setByte(int index, int value) { + unwrap().setByte(index, value); + return this; + } + + @Override + protected void _setByte(int index, int value) { + unwrap()._setByte(index, value); + } + + @Override + public ByteBuf setShort(int index, int value) { + unwrap().setShort(index, value); + return this; + } + + @Override + protected void _setShort(int index, int value) { + unwrap()._setShort(index, value); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + unwrap().setShortLE(index, value); + return this; + } + + @Override + protected void _setShortLE(int index, int value) { + unwrap()._setShortLE(index, value); + } + + @Override + public ByteBuf setMedium(int index, int value) { + unwrap().setMedium(index, value); + return this; + } + + @Override + protected void _setMedium(int index, int value) { + unwrap()._setMedium(index, value); + } + + @Override + 
public ByteBuf setMediumLE(int index, int value) { + unwrap().setMediumLE(index, value); + return this; + } + + @Override + protected void _setMediumLE(int index, int value) { + unwrap()._setMediumLE(index, value); + } + + @Override + public ByteBuf setInt(int index, int value) { + unwrap().setInt(index, value); + return this; + } + + @Override + protected void _setInt(int index, int value) { + unwrap()._setInt(index, value); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + unwrap().setIntLE(index, value); + return this; + } + + @Override + protected void _setIntLE(int index, int value) { + unwrap()._setIntLE(index, value); + } + + @Override + public ByteBuf setLong(int index, long value) { + unwrap().setLong(index, value); + return this; + } + + @Override + protected void _setLong(int index, long value) { + unwrap()._setLong(index, value); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + unwrap().setLongLE(index, value); + return this; + } + + @Override + protected void _setLongLE(int index, long value) { + unwrap().setLongLE(index, value); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + unwrap().setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + unwrap().setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + unwrap().setBytes(index, src); + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) + throws IOException { + unwrap().getBytes(index, out, length); + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) + throws IOException { + return unwrap().getBytes(index, out, length); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) + throws IOException { + return 
unwrap().getBytes(index, out, position, length);
+    }
+
+    @Override
+    public int setBytes(int index, InputStream in, int length)
+            throws IOException {
+        return unwrap().setBytes(index, in, length);
+    }
+
+    @Override
+    public int setBytes(int index, ScatteringByteChannel in, int length)
+            throws IOException {
+        return unwrap().setBytes(index, in, length);
+    }
+
+    @Override
+    public int setBytes(int index, FileChannel in, long position, int length)
+            throws IOException {
+        return unwrap().setBytes(index, in, position, length);
+    }
+
+    @Override
+    public int forEachByte(int index, int length, ByteProcessor processor) {
+        return unwrap().forEachByte(index, length, processor);
+    }
+
+    @Override
+    public int forEachByteDesc(int index, int length, ByteProcessor processor) {
+        return unwrap().forEachByteDesc(index, length, processor);
+    }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java
new file mode 100644
index 0000000..1825cc7
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/PooledHeapByteBuf.java
@@ -0,0 +1,254 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ + +package io.netty.buffer; + +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.PlatformDependent; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +class PooledHeapByteBuf extends PooledByteBuf { + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public PooledHeapByteBuf newObject(Handle handle) { + return new PooledHeapByteBuf(handle, 0); + } + }); + + static PooledHeapByteBuf newInstance(int maxCapacity) { + PooledHeapByteBuf buf = RECYCLER.get(); + buf.reuse(maxCapacity); + return buf; + } + + PooledHeapByteBuf(Handle recyclerHandle, int maxCapacity) { + super(recyclerHandle, maxCapacity); + } + + @Override + public final boolean isDirect() { + return false; + } + + @Override + protected byte _getByte(int index) { + return HeapByteBufUtil.getByte(memory, idx(index)); + } + + @Override + protected short _getShort(int index) { + return HeapByteBufUtil.getShort(memory, idx(index)); + } + + @Override + protected short _getShortLE(int index) { + return HeapByteBufUtil.getShortLE(memory, idx(index)); + } + + @Override + protected int _getUnsignedMedium(int index) { + return HeapByteBufUtil.getUnsignedMedium(memory, idx(index)); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return HeapByteBufUtil.getUnsignedMediumLE(memory, idx(index)); + } + + @Override + protected int _getInt(int index) { + return HeapByteBufUtil.getInt(memory, idx(index)); + } + + @Override + protected int _getIntLE(int index) { + return HeapByteBufUtil.getIntLE(memory, idx(index)); + } + + @Override + protected long _getLong(int index) { + return HeapByteBufUtil.getLong(memory, idx(index)); + } + + @Override + protected long _getLongLE(int index) { + return HeapByteBufUtil.getLongLE(memory, idx(index)); + } + + 
@Override + public final ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (dst.hasMemoryAddress()) { + PlatformDependent.copyMemory(memory, idx(index), dst.memoryAddress() + dstIndex, length); + } else if (dst.hasArray()) { + getBytes(index, dst.array(), dst.arrayOffset() + dstIndex, length); + } else { + dst.setBytes(dstIndex, memory, idx(index), length); + } + return this; + } + + @Override + public final ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.length); + System.arraycopy(memory, idx(index), dst, dstIndex, length); + return this; + } + + @Override + public final ByteBuf getBytes(int index, ByteBuffer dst) { + int length = dst.remaining(); + checkIndex(index, length); + dst.put(memory, idx(index), length); + return this; + } + + @Override + public final ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + checkIndex(index, length); + out.write(memory, idx(index), length); + return this; + } + + @Override + protected void _setByte(int index, int value) { + HeapByteBufUtil.setByte(memory, idx(index), value); + } + + @Override + protected void _setShort(int index, int value) { + HeapByteBufUtil.setShort(memory, idx(index), value); + } + + @Override + protected void _setShortLE(int index, int value) { + HeapByteBufUtil.setShortLE(memory, idx(index), value); + } + + @Override + protected void _setMedium(int index, int value) { + HeapByteBufUtil.setMedium(memory, idx(index), value); + } + + @Override + protected void _setMediumLE(int index, int value) { + HeapByteBufUtil.setMediumLE(memory, idx(index), value); + } + + @Override + protected void _setInt(int index, int value) { + HeapByteBufUtil.setInt(memory, idx(index), value); + } + + @Override + protected void _setIntLE(int index, int value) { + HeapByteBufUtil.setIntLE(memory, idx(index), value); + } + + @Override + protected void 
_setLong(int index, long value) { + HeapByteBufUtil.setLong(memory, idx(index), value); + } + + @Override + protected void _setLongLE(int index, long value) { + HeapByteBufUtil.setLongLE(memory, idx(index), value); + } + + @Override + public final ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + checkSrcIndex(index, length, srcIndex, src.capacity()); + if (src.hasMemoryAddress()) { + PlatformDependent.copyMemory(src.memoryAddress() + srcIndex, memory, idx(index), length); + } else if (src.hasArray()) { + setBytes(index, src.array(), src.arrayOffset() + srcIndex, length); + } else { + src.getBytes(srcIndex, memory, idx(index), length); + } + return this; + } + + @Override + public final ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + checkSrcIndex(index, length, srcIndex, src.length); + System.arraycopy(src, srcIndex, memory, idx(index), length); + return this; + } + + @Override + public final ByteBuf setBytes(int index, ByteBuffer src) { + int length = src.remaining(); + checkIndex(index, length); + src.get(memory, idx(index), length); + return this; + } + + @Override + public final int setBytes(int index, InputStream in, int length) throws IOException { + checkIndex(index, length); + return in.read(memory, idx(index), length); + } + + @Override + public final ByteBuf copy(int index, int length) { + checkIndex(index, length); + ByteBuf copy = alloc().heapBuffer(length, maxCapacity()); + return copy.writeBytes(memory, idx(index), length); + } + + @Override + final ByteBuffer duplicateInternalNioBuffer(int index, int length) { + checkIndex(index, length); + return ByteBuffer.wrap(memory, idx(index), length).slice(); + } + + @Override + public final boolean hasArray() { + return true; + } + + @Override + public final byte[] array() { + ensureAccessible(); + return memory; + } + + @Override + public final int arrayOffset() { + return offset; + } + + @Override + public final boolean hasMemoryAddress() { + return false; + } + + 
@Override + public final long memoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + protected final ByteBuffer newInternalNioBuffer(byte[] memory) { + return ByteBuffer.wrap(memory); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledSlicedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledSlicedByteBuf.java new file mode 100644 index 0000000..054b637 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledSlicedByteBuf.java @@ -0,0 +1,441 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +import static io.netty.buffer.AbstractUnpooledSlicedByteBuf.checkSliceOutOfBounds; + +final class PooledSlicedByteBuf extends AbstractPooledDerivedByteBuf { + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public PooledSlicedByteBuf newObject(Handle handle) { + return new PooledSlicedByteBuf(handle); + } + }); + + static PooledSlicedByteBuf newInstance(AbstractByteBuf unwrapped, ByteBuf wrapped, + int index, int length) { + checkSliceOutOfBounds(index, length, unwrapped); + return newInstance0(unwrapped, wrapped, index, length); + } + + private static PooledSlicedByteBuf newInstance0(AbstractByteBuf unwrapped, ByteBuf wrapped, + int adjustment, int length) { + final PooledSlicedByteBuf slice = RECYCLER.get(); + slice.init(unwrapped, wrapped, 0, length, length); + slice.discardMarks(); + slice.adjustment = adjustment; + + return slice; + } + + int adjustment; + + private PooledSlicedByteBuf(Handle handle) { + super(handle); + } + + @Override + public int capacity() { + return maxCapacity(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new UnsupportedOperationException("sliced buffer"); + } + + @Override + public int arrayOffset() { + return idx(unwrap().arrayOffset()); + } + + @Override + public long memoryAddress() { + return unwrap().memoryAddress() + adjustment; + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + checkIndex0(index, length); + return unwrap().nioBuffer(idx(index), length); 
+ } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + checkIndex0(index, length); + return unwrap().nioBuffers(idx(index), length); + } + + @Override + public ByteBuf copy(int index, int length) { + checkIndex0(index, length); + return unwrap().copy(idx(index), length); + } + + @Override + public ByteBuf slice(int index, int length) { + checkIndex0(index, length); + return super.slice(idx(index), length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + checkIndex0(index, length); + return PooledSlicedByteBuf.newInstance0(unwrap(), this, idx(index), length); + } + + @Override + public ByteBuf duplicate() { + return duplicate0().setIndex(idx(readerIndex()), idx(writerIndex())); + } + + @Override + public ByteBuf retainedDuplicate() { + return PooledDuplicatedByteBuf.newInstance(unwrap(), this, idx(readerIndex()), idx(writerIndex())); + } + + @Override + public byte getByte(int index) { + checkIndex0(index, 1); + return unwrap().getByte(idx(index)); + } + + @Override + protected byte _getByte(int index) { + return unwrap()._getByte(idx(index)); + } + + @Override + public short getShort(int index) { + checkIndex0(index, 2); + return unwrap().getShort(idx(index)); + } + + @Override + protected short _getShort(int index) { + return unwrap()._getShort(idx(index)); + } + + @Override + public short getShortLE(int index) { + checkIndex0(index, 2); + return unwrap().getShortLE(idx(index)); + } + + @Override + protected short _getShortLE(int index) { + return unwrap()._getShortLE(idx(index)); + } + + @Override + public int getUnsignedMedium(int index) { + checkIndex0(index, 3); + return unwrap().getUnsignedMedium(idx(index)); + } + + @Override + protected int _getUnsignedMedium(int index) { + return unwrap()._getUnsignedMedium(idx(index)); + } + + @Override + public int getUnsignedMediumLE(int index) { + checkIndex0(index, 3); + return unwrap().getUnsignedMediumLE(idx(index)); + } + + @Override + protected int 
_getUnsignedMediumLE(int index) { + return unwrap()._getUnsignedMediumLE(idx(index)); + } + + @Override + public int getInt(int index) { + checkIndex0(index, 4); + return unwrap().getInt(idx(index)); + } + + @Override + protected int _getInt(int index) { + return unwrap()._getInt(idx(index)); + } + + @Override + public int getIntLE(int index) { + checkIndex0(index, 4); + return unwrap().getIntLE(idx(index)); + } + + @Override + protected int _getIntLE(int index) { + return unwrap()._getIntLE(idx(index)); + } + + @Override + public long getLong(int index) { + checkIndex0(index, 8); + return unwrap().getLong(idx(index)); + } + + @Override + protected long _getLong(int index) { + return unwrap()._getLong(idx(index)); + } + + @Override + public long getLongLE(int index) { + checkIndex0(index, 8); + return unwrap().getLongLE(idx(index)); + } + + @Override + protected long _getLongLE(int index) { + return unwrap()._getLongLE(idx(index)); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkIndex0(index, length); + unwrap().getBytes(idx(index), dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkIndex0(index, length); + unwrap().getBytes(idx(index), dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + checkIndex0(index, dst.remaining()); + unwrap().getBytes(idx(index), dst); + return this; + } + + @Override + public ByteBuf setByte(int index, int value) { + checkIndex0(index, 1); + unwrap().setByte(idx(index), value); + return this; + } + + @Override + protected void _setByte(int index, int value) { + unwrap()._setByte(idx(index), value); + } + + @Override + public ByteBuf setShort(int index, int value) { + checkIndex0(index, 2); + unwrap().setShort(idx(index), value); + return this; + } + + @Override + protected void _setShort(int index, int value) { + 
unwrap()._setShort(idx(index), value); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + checkIndex0(index, 2); + unwrap().setShortLE(idx(index), value); + return this; + } + + @Override + protected void _setShortLE(int index, int value) { + unwrap()._setShortLE(idx(index), value); + } + + @Override + public ByteBuf setMedium(int index, int value) { + checkIndex0(index, 3); + unwrap().setMedium(idx(index), value); + return this; + } + + @Override + protected void _setMedium(int index, int value) { + unwrap()._setMedium(idx(index), value); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + checkIndex0(index, 3); + unwrap().setMediumLE(idx(index), value); + return this; + } + + @Override + protected void _setMediumLE(int index, int value) { + unwrap()._setMediumLE(idx(index), value); + } + + @Override + public ByteBuf setInt(int index, int value) { + checkIndex0(index, 4); + unwrap().setInt(idx(index), value); + return this; + } + + @Override + protected void _setInt(int index, int value) { + unwrap()._setInt(idx(index), value); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + checkIndex0(index, 4); + unwrap().setIntLE(idx(index), value); + return this; + } + + @Override + protected void _setIntLE(int index, int value) { + unwrap()._setIntLE(idx(index), value); + } + + @Override + public ByteBuf setLong(int index, long value) { + checkIndex0(index, 8); + unwrap().setLong(idx(index), value); + return this; + } + + @Override + protected void _setLong(int index, long value) { + unwrap()._setLong(idx(index), value); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + checkIndex0(index, 8); + unwrap().setLongLE(idx(index), value); + return this; + } + + @Override + protected void _setLongLE(int index, long value) { + unwrap().setLongLE(idx(index), value); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + checkIndex0(index, length); + 
unwrap().setBytes(idx(index), src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + checkIndex0(index, length); + unwrap().setBytes(idx(index), src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + checkIndex0(index, src.remaining()); + unwrap().setBytes(idx(index), src); + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) + throws IOException { + checkIndex0(index, length); + unwrap().getBytes(idx(index), out, length); + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) + throws IOException { + checkIndex0(index, length); + return unwrap().getBytes(idx(index), out, length); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) + throws IOException { + checkIndex0(index, length); + return unwrap().getBytes(idx(index), out, position, length); + } + + @Override + public int setBytes(int index, InputStream in, int length) + throws IOException { + checkIndex0(index, length); + return unwrap().setBytes(idx(index), in, length); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) + throws IOException { + checkIndex0(index, length); + return unwrap().setBytes(idx(index), in, length); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) + throws IOException { + checkIndex0(index, length); + return unwrap().setBytes(idx(index), in, position, length); + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + checkIndex0(index, length); + int ret = unwrap().forEachByte(idx(index), length, processor); + if (ret < adjustment) { + return -1; + } + return ret - adjustment; + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + checkIndex0(index, 
length); + int ret = unwrap().forEachByteDesc(idx(index), length, processor); + if (ret < adjustment) { + return -1; + } + return ret - adjustment; + } + + private int idx(int index) { + return index + adjustment; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java new file mode 100644 index 0000000..c67486a --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeDirectByteBuf.java @@ -0,0 +1,273 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.buffer; + +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.PlatformDependent; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +final class PooledUnsafeDirectByteBuf extends PooledByteBuf { + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public PooledUnsafeDirectByteBuf newObject(Handle handle) { + return new PooledUnsafeDirectByteBuf(handle, 0); + } + }); + + static PooledUnsafeDirectByteBuf newInstance(int maxCapacity) { + PooledUnsafeDirectByteBuf buf = RECYCLER.get(); + buf.reuse(maxCapacity); + return buf; + } + + private long memoryAddress; + + private PooledUnsafeDirectByteBuf(Handle recyclerHandle, int maxCapacity) { + super(recyclerHandle, maxCapacity); + } + + @Override + void init(PoolChunk chunk, ByteBuffer nioBuffer, + long handle, int offset, int length, int maxLength, PoolThreadCache cache) { + super.init(chunk, nioBuffer, handle, offset, length, maxLength, cache); + initMemoryAddress(); + } + + @Override + void initUnpooled(PoolChunk chunk, int length) { + super.initUnpooled(chunk, length); + initMemoryAddress(); + } + + private void initMemoryAddress() { + memoryAddress = PlatformDependent.directBufferAddress(memory) + offset; + } + + @Override + protected ByteBuffer newInternalNioBuffer(ByteBuffer memory) { + return memory.duplicate(); + } + + @Override + public boolean isDirect() { + return true; + } + + @Override + protected byte _getByte(int index) { + return UnsafeByteBufUtil.getByte(addr(index)); + } + + @Override + protected short _getShort(int index) { + return UnsafeByteBufUtil.getShort(addr(index)); + } + + @Override + protected short _getShortLE(int index) { + return UnsafeByteBufUtil.getShortLE(addr(index)); + } + + @Override + protected int 
_getUnsignedMedium(int index) { + return UnsafeByteBufUtil.getUnsignedMedium(addr(index)); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return UnsafeByteBufUtil.getUnsignedMediumLE(addr(index)); + } + + @Override + protected int _getInt(int index) { + return UnsafeByteBufUtil.getInt(addr(index)); + } + + @Override + protected int _getIntLE(int index) { + return UnsafeByteBufUtil.getIntLE(addr(index)); + } + + @Override + protected long _getLong(int index) { + return UnsafeByteBufUtil.getLong(addr(index)); + } + + @Override + protected long _getLongLE(int index) { + return UnsafeByteBufUtil.getLongLE(addr(index)); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + UnsafeByteBufUtil.getBytes(this, addr(index), index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + UnsafeByteBufUtil.getBytes(this, addr(index), index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + UnsafeByteBufUtil.getBytes(this, addr(index), index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + UnsafeByteBufUtil.getBytes(this, addr(index), index, out, length); + return this; + } + + @Override + protected void _setByte(int index, int value) { + UnsafeByteBufUtil.setByte(addr(index), (byte) value); + } + + @Override + protected void _setShort(int index, int value) { + UnsafeByteBufUtil.setShort(addr(index), value); + } + + @Override + protected void _setShortLE(int index, int value) { + UnsafeByteBufUtil.setShortLE(addr(index), value); + } + + @Override + protected void _setMedium(int index, int value) { + UnsafeByteBufUtil.setMedium(addr(index), value); + } + + @Override + protected void _setMediumLE(int index, int value) { + UnsafeByteBufUtil.setMediumLE(addr(index), value); + } + + @Override + 
protected void _setInt(int index, int value) { + UnsafeByteBufUtil.setInt(addr(index), value); + } + + @Override + protected void _setIntLE(int index, int value) { + UnsafeByteBufUtil.setIntLE(addr(index), value); + } + + @Override + protected void _setLong(int index, long value) { + UnsafeByteBufUtil.setLong(addr(index), value); + } + + @Override + protected void _setLongLE(int index, long value) { + UnsafeByteBufUtil.setLongLE(addr(index), value); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + UnsafeByteBufUtil.setBytes(this, addr(index), index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + UnsafeByteBufUtil.setBytes(this, addr(index), index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + UnsafeByteBufUtil.setBytes(this, addr(index), index, src); + return this; + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + return UnsafeByteBufUtil.setBytes(this, addr(index), index, in, length); + } + + @Override + public ByteBuf copy(int index, int length) { + return UnsafeByteBufUtil.copy(this, addr(index), index, length); + } + + @Override + public boolean hasArray() { + return false; + } + + @Override + public byte[] array() { + throw new UnsupportedOperationException("direct buffer"); + } + + @Override + public int arrayOffset() { + throw new UnsupportedOperationException("direct buffer"); + } + + @Override + public boolean hasMemoryAddress() { + return true; + } + + @Override + public long memoryAddress() { + ensureAccessible(); + return memoryAddress; + } + + private long addr(int index) { + return memoryAddress + index; + } + + @Override + protected SwappedByteBuf newSwappedByteBuf() { + if (PlatformDependent.isUnaligned()) { + // Only use if unaligned access is supported otherwise there is no gain. 
+ return new UnsafeDirectSwappedByteBuf(this); + } + return super.newSwappedByteBuf(); + } + + @Override + public ByteBuf setZero(int index, int length) { + checkIndex(index, length); + UnsafeByteBufUtil.setZero(addr(index), length); + return this; + } + + @Override + public ByteBuf writeZero(int length) { + ensureWritable(length); + int wIndex = writerIndex; + UnsafeByteBufUtil.setZero(addr(wIndex), length); + writerIndex = wIndex + length; + return this; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeHeapByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeHeapByteBuf.java new file mode 100644 index 0000000..9a8674e --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/PooledUnsafeHeapByteBuf.java @@ -0,0 +1,166 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.PlatformDependent; + +final class PooledUnsafeHeapByteBuf extends PooledHeapByteBuf { + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public PooledUnsafeHeapByteBuf newObject(Handle handle) { + return new PooledUnsafeHeapByteBuf(handle, 0); + } + }); + + static PooledUnsafeHeapByteBuf newUnsafeInstance(int maxCapacity) { + PooledUnsafeHeapByteBuf buf = RECYCLER.get(); + buf.reuse(maxCapacity); + return buf; + } + + private PooledUnsafeHeapByteBuf(Handle recyclerHandle, int maxCapacity) { + super(recyclerHandle, maxCapacity); + } + + @Override + protected byte _getByte(int index) { + return UnsafeByteBufUtil.getByte(memory, idx(index)); + } + + @Override + protected short _getShort(int index) { + return UnsafeByteBufUtil.getShort(memory, idx(index)); + } + + @Override + protected short _getShortLE(int index) { + return UnsafeByteBufUtil.getShortLE(memory, idx(index)); + } + + @Override + protected int _getUnsignedMedium(int index) { + return UnsafeByteBufUtil.getUnsignedMedium(memory, idx(index)); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return UnsafeByteBufUtil.getUnsignedMediumLE(memory, idx(index)); + } + + @Override + protected int _getInt(int index) { + return UnsafeByteBufUtil.getInt(memory, idx(index)); + } + + @Override + protected int _getIntLE(int index) { + return UnsafeByteBufUtil.getIntLE(memory, idx(index)); + } + + @Override + protected long _getLong(int index) { + return UnsafeByteBufUtil.getLong(memory, idx(index)); + } + + @Override + protected long _getLongLE(int index) { + return UnsafeByteBufUtil.getLongLE(memory, idx(index)); + } + + @Override + protected void _setByte(int index, int value) { + UnsafeByteBufUtil.setByte(memory, idx(index), value); + 
} + + @Override + protected void _setShort(int index, int value) { + UnsafeByteBufUtil.setShort(memory, idx(index), value); + } + + @Override + protected void _setShortLE(int index, int value) { + UnsafeByteBufUtil.setShortLE(memory, idx(index), value); + } + + @Override + protected void _setMedium(int index, int value) { + UnsafeByteBufUtil.setMedium(memory, idx(index), value); + } + + @Override + protected void _setMediumLE(int index, int value) { + UnsafeByteBufUtil.setMediumLE(memory, idx(index), value); + } + + @Override + protected void _setInt(int index, int value) { + UnsafeByteBufUtil.setInt(memory, idx(index), value); + } + + @Override + protected void _setIntLE(int index, int value) { + UnsafeByteBufUtil.setIntLE(memory, idx(index), value); + } + + @Override + protected void _setLong(int index, long value) { + UnsafeByteBufUtil.setLong(memory, idx(index), value); + } + + @Override + protected void _setLongLE(int index, long value) { + UnsafeByteBufUtil.setLongLE(memory, idx(index), value); + } + + @Override + public ByteBuf setZero(int index, int length) { + if (PlatformDependent.javaVersion() >= 7) { + checkIndex(index, length); + // Only do on java7+ as the needed Unsafe call was only added there. + UnsafeByteBufUtil.setZero(memory, idx(index), length); + return this; + } + return super.setZero(index, length); + } + + @Override + public ByteBuf writeZero(int length) { + if (PlatformDependent.javaVersion() >= 7) { + // Only do on java7+ as the needed Unsafe call was only added there. + ensureWritable(length); + int wIndex = writerIndex; + UnsafeByteBufUtil.setZero(memory, idx(wIndex), length); + writerIndex = wIndex + length; + return this; + } + return super.writeZero(length); + } + + @Override + @Deprecated + protected SwappedByteBuf newSwappedByteBuf() { + if (PlatformDependent.isUnaligned()) { + // Only use if unaligned access is supported otherwise there is no gain. 
+ return new UnsafeHeapSwappedByteBuf(this); + } + return super.newSwappedByteBuf(); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBuf.java new file mode 100644 index 0000000..7d5b651 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBuf.java @@ -0,0 +1,430 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +/** + * A derived buffer which forbids any write requests to its parent. It is + * recommended to use {@link Unpooled#unmodifiableBuffer(ByteBuf)} + * instead of calling the constructor explicitly. + * + * @deprecated Do not use. 
+ */ +@Deprecated +public class ReadOnlyByteBuf extends AbstractDerivedByteBuf { + + private final ByteBuf buffer; + + public ReadOnlyByteBuf(ByteBuf buffer) { + super(buffer.maxCapacity()); + + if (buffer instanceof ReadOnlyByteBuf || buffer instanceof DuplicatedByteBuf) { + this.buffer = buffer.unwrap(); + } else { + this.buffer = buffer; + } + setIndex(buffer.readerIndex(), buffer.writerIndex()); + } + + @Override + public boolean isReadOnly() { + return true; + } + + @Override + public boolean isWritable() { + return false; + } + + @Override + public boolean isWritable(int numBytes) { + return false; + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + return 1; + } + + @Override + public ByteBuf ensureWritable(int minWritableBytes) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf unwrap() { + return buffer; + } + + @Override + public ByteBufAllocator alloc() { + return unwrap().alloc(); + } + + @Override + @Deprecated + public ByteOrder order() { + return unwrap().order(); + } + + @Override + public boolean isDirect() { + return unwrap().isDirect(); + } + + @Override + public boolean hasArray() { + return false; + } + + @Override + public byte[] array() { + throw new ReadOnlyBufferException(); + } + + @Override + public int arrayOffset() { + throw new ReadOnlyBufferException(); + } + + @Override + public boolean hasMemoryAddress() { + return unwrap().hasMemoryAddress(); + } + + @Override + public long memoryAddress() { + return unwrap().memoryAddress(); + } + + @Override + public ByteBuf discardReadBytes() { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + throw new 
ReadOnlyBufferException(); + } + + @Override + public ByteBuf setByte(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setByte(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setShort(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setShort(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setShortLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setMedium(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setMedium(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setMediumLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setInt(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setInt(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setIntLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setLong(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setLong(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setLongLE(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + public 
int setBytes(int index, InputStream in, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) + throws IOException { + return unwrap().getBytes(index, out, length); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) + throws IOException { + return unwrap().getBytes(index, out, position, length); + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) + throws IOException { + unwrap().getBytes(index, out, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + unwrap().getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + unwrap().getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + unwrap().getBytes(index, dst); + return this; + } + + @Override + public ByteBuf duplicate() { + return new ReadOnlyByteBuf(this); + } + + @Override + public ByteBuf copy(int index, int length) { + return unwrap().copy(index, length); + } + + @Override + public ByteBuf slice(int index, int length) { + return Unpooled.unmodifiableBuffer(unwrap().slice(index, length)); + } + + @Override + public byte getByte(int index) { + return unwrap().getByte(index); + } + + @Override + protected byte _getByte(int index) { + return unwrap().getByte(index); + } + + @Override + public short getShort(int index) { + return unwrap().getShort(index); + } + + @Override + protected short _getShort(int index) { + return unwrap().getShort(index); + } + + 
@Override + public short getShortLE(int index) { + return unwrap().getShortLE(index); + } + + @Override + protected short _getShortLE(int index) { + return unwrap().getShortLE(index); + } + + @Override + public int getUnsignedMedium(int index) { + return unwrap().getUnsignedMedium(index); + } + + @Override + protected int _getUnsignedMedium(int index) { + return unwrap().getUnsignedMedium(index); + } + + @Override + public int getUnsignedMediumLE(int index) { + return unwrap().getUnsignedMediumLE(index); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return unwrap().getUnsignedMediumLE(index); + } + + @Override + public int getInt(int index) { + return unwrap().getInt(index); + } + + @Override + protected int _getInt(int index) { + return unwrap().getInt(index); + } + + @Override + public int getIntLE(int index) { + return unwrap().getIntLE(index); + } + + @Override + protected int _getIntLE(int index) { + return unwrap().getIntLE(index); + } + + @Override + public long getLong(int index) { + return unwrap().getLong(index); + } + + @Override + protected long _getLong(int index) { + return unwrap().getLong(index); + } + + @Override + public long getLongLE(int index) { + return unwrap().getLongLE(index); + } + + @Override + protected long _getLongLE(int index) { + return unwrap().getLongLE(index); + } + + @Override + public int nioBufferCount() { + return unwrap().nioBufferCount(); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + return unwrap().nioBuffer(index, length).asReadOnlyBuffer(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + return unwrap().nioBuffers(index, length); + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + return unwrap().forEachByte(index, length, processor); + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + return unwrap().forEachByteDesc(index, length, processor); + } + + 
@Override + public int capacity() { + return unwrap().capacity(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf asReadOnly() { + return this; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java b/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java new file mode 100644 index 0000000..ec0e5d6 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyByteBufferBuf.java @@ -0,0 +1,485 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.StringUtil; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + + +/** + * Read-only ByteBuf which wraps a read-only ByteBuffer. 
+ */ +class ReadOnlyByteBufferBuf extends AbstractReferenceCountedByteBuf { + + protected final ByteBuffer buffer; + private final ByteBufAllocator allocator; + private ByteBuffer tmpNioBuf; + + ReadOnlyByteBufferBuf(ByteBufAllocator allocator, ByteBuffer buffer) { + super(buffer.remaining()); + if (!buffer.isReadOnly()) { + throw new IllegalArgumentException("must be a readonly buffer: " + StringUtil.simpleClassName(buffer)); + } + + this.allocator = allocator; + this.buffer = buffer.slice().order(ByteOrder.BIG_ENDIAN); + writerIndex(this.buffer.limit()); + } + + @Override + protected void deallocate() { } + + @Override + public boolean isWritable() { + return false; + } + + @Override + public boolean isWritable(int numBytes) { + return false; + } + + @Override + public ByteBuf ensureWritable(int minWritableBytes) { + throw new ReadOnlyBufferException(); + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + return 1; + } + + @Override + public byte getByte(int index) { + ensureAccessible(); + return _getByte(index); + } + + @Override + protected byte _getByte(int index) { + return buffer.get(index); + } + + @Override + public short getShort(int index) { + ensureAccessible(); + return _getShort(index); + } + + @Override + protected short _getShort(int index) { + return buffer.getShort(index); + } + + @Override + public short getShortLE(int index) { + ensureAccessible(); + return _getShortLE(index); + } + + @Override + protected short _getShortLE(int index) { + return ByteBufUtil.swapShort(buffer.getShort(index)); + } + + @Override + public int getUnsignedMedium(int index) { + ensureAccessible(); + return _getUnsignedMedium(index); + } + + @Override + protected int _getUnsignedMedium(int index) { + return (getByte(index) & 0xff) << 16 | + (getByte(index + 1) & 0xff) << 8 | + getByte(index + 2) & 0xff; + } + + @Override + public int getUnsignedMediumLE(int index) { + ensureAccessible(); + return _getUnsignedMediumLE(index); + } + + 
@Override + protected int _getUnsignedMediumLE(int index) { + return getByte(index) & 0xff | + (getByte(index + 1) & 0xff) << 8 | + (getByte(index + 2) & 0xff) << 16; + } + + @Override + public int getInt(int index) { + ensureAccessible(); + return _getInt(index); + } + + @Override + protected int _getInt(int index) { + return buffer.getInt(index); + } + + @Override + public int getIntLE(int index) { + ensureAccessible(); + return _getIntLE(index); + } + + @Override + protected int _getIntLE(int index) { + return ByteBufUtil.swapInt(buffer.getInt(index)); + } + + @Override + public long getLong(int index) { + ensureAccessible(); + return _getLong(index); + } + + @Override + protected long _getLong(int index) { + return buffer.getLong(index); + } + + @Override + public long getLongLE(int index) { + ensureAccessible(); + return _getLongLE(index); + } + + @Override + protected long _getLongLE(int index) { + return ByteBufUtil.swapLong(buffer.getLong(index)); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (dst.hasArray()) { + getBytes(index, dst.array(), dst.arrayOffset() + dstIndex, length); + } else if (dst.nioBufferCount() > 0) { + for (ByteBuffer bb: dst.nioBuffers(dstIndex, length)) { + int bbLen = bb.remaining(); + getBytes(index, bb); + index += bbLen; + } + } else { + dst.setBytes(dstIndex, this, index, length); + } + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.length); + + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index).limit(index + length); + tmpBuf.get(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + checkIndex(index, dst.remaining()); + + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index).limit(index + dst.remaining()); + 
dst.put(tmpBuf); + return this; + } + + @Override + public ByteBuf setByte(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setByte(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setShort(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setShort(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setShortLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setMedium(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setMedium(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setMediumLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setInt(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setInt(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setIntLE(int index, int value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setLong(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setLong(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + protected void _setLongLE(int index, long value) { + throw new ReadOnlyBufferException(); + } + + @Override + 
public int capacity() { + return maxCapacity(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBufAllocator alloc() { + return allocator; + } + + @Override + public ByteOrder order() { + return ByteOrder.BIG_ENDIAN; + } + + @Override + public ByteBuf unwrap() { + return null; + } + + @Override + public boolean isReadOnly() { + return buffer.isReadOnly(); + } + + @Override + public boolean isDirect() { + return buffer.isDirect(); + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + ensureAccessible(); + if (length == 0) { + return this; + } + + if (buffer.hasArray()) { + out.write(buffer.array(), index + buffer.arrayOffset(), length); + } else { + byte[] tmp = ByteBufUtil.threadLocalTempArray(length); + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index); + tmpBuf.get(tmp, 0, length); + out.write(tmp, 0, length); + } + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + ensureAccessible(); + if (length == 0) { + return 0; + } + + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index).limit(index + length); + return out.write(tmpBuf); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + ensureAccessible(); + if (length == 0) { + return 0; + } + + ByteBuffer tmpBuf = internalNioBuffer(); + tmpBuf.clear().position(index).limit(index + length); + return out.write(tmpBuf, position); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + throw new ReadOnlyBufferException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + throw new ReadOnlyBufferException(); 
+ } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + throw new ReadOnlyBufferException(); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + throw new ReadOnlyBufferException(); + } + + protected final ByteBuffer internalNioBuffer() { + ByteBuffer tmpNioBuf = this.tmpNioBuf; + if (tmpNioBuf == null) { + this.tmpNioBuf = tmpNioBuf = buffer.duplicate(); + } + return tmpNioBuf; + } + + @Override + public ByteBuf copy(int index, int length) { + ensureAccessible(); + ByteBuffer src; + try { + src = internalNioBuffer().clear().position(index).limit(index + length); + } catch (IllegalArgumentException ignored) { + throw new IndexOutOfBoundsException("Too many bytes to read - Need " + (index + length)); + } + ByteBuf dst = src.isDirect() ? 
alloc().directBuffer(length) : alloc().heapBuffer(length); + dst.writeBytes(src); + return dst; + } + + @Override + public int nioBufferCount() { + return 1; + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + return new ByteBuffer[] { nioBuffer(index, length) }; + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + checkIndex(index, length); + return buffer.duplicate().position(index).limit(index + length); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + ensureAccessible(); + return internalNioBuffer().clear().position(index).limit(index + length); + } + + @Override + public final boolean isContiguous() { + return true; + } + + @Override + public boolean hasArray() { + return buffer.hasArray(); + } + + @Override + public byte[] array() { + return buffer.array(); + } + + @Override + public int arrayOffset() { + return buffer.arrayOffset(); + } + + @Override + public boolean hasMemoryAddress() { + return false; + } + + @Override + public long memoryAddress() { + throw new UnsupportedOperationException(); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBuf.java new file mode 100644 index 0000000..c0e07d3 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBuf.java @@ -0,0 +1,124 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;

import java.nio.ByteBuffer;

/**
 * Read-only ByteBuf which wraps a read-only direct ByteBuffer and uses unsafe for best performance.
 */
final class ReadOnlyUnsafeDirectByteBuf extends ReadOnlyByteBufferBuf {

    private final long memoryAddress;

    ReadOnlyUnsafeDirectByteBuf(ByteBufAllocator allocator, ByteBuffer byteBuffer) {
        super(allocator, byteBuffer);
        // Use buffer as the super class will slice the passed in ByteBuffer which means the memoryAddress
        // may be different if the position != 0.
        memoryAddress = PlatformDependent.directBufferAddress(buffer);
    }

    // Single-value reads performed directly against native memory.

    @Override
    protected byte _getByte(int index) {
        return UnsafeByteBufUtil.getByte(addr(index));
    }

    @Override
    protected short _getShort(int index) {
        return UnsafeByteBufUtil.getShort(addr(index));
    }

    @Override
    protected int _getUnsignedMedium(int index) {
        return UnsafeByteBufUtil.getUnsignedMedium(addr(index));
    }

    @Override
    protected int _getInt(int index) {
        return UnsafeByteBufUtil.getInt(addr(index));
    }

    @Override
    protected long _getLong(int index) {
        return UnsafeByteBufUtil.getLong(addr(index));
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) {
        checkIndex(index, length);
        ObjectUtil.checkNotNull(dst, "dst");
        if (dstIndex < 0 || dstIndex > dst.capacity() - length) {
            throw new IndexOutOfBoundsException("dstIndex: " + dstIndex);
        }

        // Pick the fastest available copy path for the destination.
        if (dst.hasMemoryAddress()) {
            PlatformDependent.copyMemory(addr(index), dst.memoryAddress() + dstIndex, length);
        } else if (dst.hasArray()) {
            PlatformDependent.copyMemory(addr(index), dst.array(), dst.arrayOffset() + dstIndex, length);
        } else {
            dst.setBytes(dstIndex, this, index, length);
        }
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) {
        checkIndex(index, length);
        ObjectUtil.checkNotNull(dst, "dst");
        if (dstIndex < 0 || dstIndex > dst.length - length) {
            throw new IndexOutOfBoundsException(String.format(
                    "dstIndex: %d, length: %d (expected: range(0, %d))", dstIndex, length, dst.length));
        }

        if (length != 0) {
            PlatformDependent.copyMemory(addr(index), dst, dstIndex, length);
        }
        return this;
    }

    @Override
    public ByteBuf copy(int index, int length) {
        checkIndex(index, length);
        ByteBuf copy = alloc().directBuffer(length, maxCapacity());
        if (length != 0) {
            if (copy.hasMemoryAddress()) {
                PlatformDependent.copyMemory(addr(index), copy.memoryAddress(), length);
                copy.setIndex(0, length);
            } else {
                copy.writeBytes(this, index, length);
            }
        }
        return copy;
    }

    @Override
    public boolean hasMemoryAddress() {
        return true;
    }

    @Override
    public long memoryAddress() {
        return memoryAddress;
    }

    /** Translates a buffer index to its absolute native address. */
    private long addr(int index) {
        return memoryAddress + index;
    }
}
See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.netty.buffer;

import io.netty.util.ResourceLeakDetector;
import io.netty.util.ResourceLeakTracker;
import io.netty.util.internal.ObjectUtil;

import java.nio.ByteOrder;

/**
 * {@link WrappedByteBuf} that reports a leak to a {@link ResourceLeakTracker} when the
 * buffer is garbage-collected without having been fully released.
 */
class SimpleLeakAwareByteBuf extends WrappedByteBuf {

    /**
     * This object's is associated with the {@link ResourceLeakTracker}. When {@link ResourceLeakTracker#close(Object)}
     * is called this object will be used as the argument. It is also assumed that this object is used when
     * {@link ResourceLeakDetector#track(Object)} is called to create {@link #leak}.
     */
    private final ByteBuf trackedByteBuf;
    // NOTE(review): the generic argument <ByteBuf> was stripped (raw type) in the
    // imported sources; restored to match the ResourceLeakTracker API.
    final ResourceLeakTracker<ByteBuf> leak;

    SimpleLeakAwareByteBuf(ByteBuf wrapped, ByteBuf trackedByteBuf, ResourceLeakTracker<ByteBuf> leak) {
        super(wrapped);
        this.trackedByteBuf = ObjectUtil.checkNotNull(trackedByteBuf, "trackedByteBuf");
        this.leak = ObjectUtil.checkNotNull(leak, "leak");
    }

    SimpleLeakAwareByteBuf(ByteBuf wrapped, ResourceLeakTracker<ByteBuf> leak) {
        this(wrapped, wrapped, leak);
    }

    // Shared (non-retained) derived views keep using this buffer's tracker.

    @Override
    public ByteBuf slice() {
        return newSharedLeakAwareByteBuf(super.slice());
    }

    @Override
    public ByteBuf retainedSlice() {
        return unwrappedDerived(super.retainedSlice());
    }

    @Override
    public ByteBuf retainedSlice(int index, int length) {
        return unwrappedDerived(super.retainedSlice(index, length));
    }

    @Override
    public ByteBuf retainedDuplicate() {
        return unwrappedDerived(super.retainedDuplicate());
    }

    @Override
    public ByteBuf readRetainedSlice(int length) {
        return unwrappedDerived(super.readRetainedSlice(length));
    }

    @Override
    public ByteBuf slice(int index, int length) {
        return newSharedLeakAwareByteBuf(super.slice(index, length));
    }

    @Override
    public ByteBuf duplicate() {
        return newSharedLeakAwareByteBuf(super.duplicate());
    }

    @Override
    public ByteBuf readSlice(int length) {
        return newSharedLeakAwareByteBuf(super.readSlice(length));
    }

    @Override
    public ByteBuf asReadOnly() {
        return newSharedLeakAwareByteBuf(super.asReadOnly());
    }

    @Override
    public ByteBuf touch() {
        // Simple (non-advanced) leak detection does not record touch points.
        return this;
    }

    @Override
    public ByteBuf touch(Object hint) {
        return this;
    }

    @Override
    public boolean release() {
        if (super.release()) {
            closeLeak();
            return true;
        }
        return false;
    }

    @Override
    public boolean release(int decrement) {
        if (super.release(decrement)) {
            closeLeak();
            return true;
        }
        return false;
    }

    private void closeLeak() {
        // Close the ResourceLeakTracker with the tracked ByteBuf as argument. This must be the same that was used when
        // calling DefaultResourceLeak.track(...).
        boolean closed = leak.close(trackedByteBuf);
        assert closed;
    }

    @Override
    public ByteBuf order(ByteOrder endianness) {
        if (order() == endianness) {
            return this;
        } else {
            return newSharedLeakAwareByteBuf(super.order(endianness));
        }
    }

    private ByteBuf unwrappedDerived(ByteBuf derived) {
        // We only need to unwrap SwappedByteBuf implementations as these will be the only ones that may end up in
        // the AbstractLeakAwareByteBuf implementations beside slices / duplicates and "real" buffers.
        ByteBuf unwrappedDerived = unwrapSwapped(derived);

        if (unwrappedDerived instanceof AbstractPooledDerivedByteBuf) {
            // Update the parent to point to this buffer so we correctly close the ResourceLeakTracker.
            ((AbstractPooledDerivedByteBuf) unwrappedDerived).parent(this);

            // force tracking of derived buffers (see issue #13414)
            return newLeakAwareByteBuf(derived, AbstractByteBuf.leakDetector.trackForcibly(derived));
        }
        return newSharedLeakAwareByteBuf(derived);
    }

    @SuppressWarnings("deprecation")
    private static ByteBuf unwrapSwapped(ByteBuf buf) {
        if (buf instanceof SwappedByteBuf) {
            do {
                buf = buf.unwrap();
            } while (buf instanceof SwappedByteBuf);

            return buf;
        }
        return buf;
    }

    /** Wraps a derived view that shares this buffer's lifetime and tracker. */
    private SimpleLeakAwareByteBuf newSharedLeakAwareByteBuf(ByteBuf wrapped) {
        return newLeakAwareByteBuf(wrapped, trackedByteBuf, leak);
    }

    /** Wraps a retained derived view with its own freshly created tracker. */
    private SimpleLeakAwareByteBuf newLeakAwareByteBuf(ByteBuf wrapped, ResourceLeakTracker<ByteBuf> leakTracker) {
        return newLeakAwareByteBuf(wrapped, wrapped, leakTracker);
    }

    /** Factory hook subclasses override to return their own leak-aware wrapper type. */
    protected SimpleLeakAwareByteBuf newLeakAwareByteBuf(
            ByteBuf buf, ByteBuf trackedByteBuf, ResourceLeakTracker<ByteBuf> leakTracker) {
        return new SimpleLeakAwareByteBuf(buf, trackedByteBuf, leakTracker);
    }
}
See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

import io.netty.util.ResourceLeakTracker;
import io.netty.util.internal.ObjectUtil;

import java.nio.ByteOrder;

/**
 * {@link WrappedCompositeByteBuf} that closes its {@link ResourceLeakTracker} once the
 * composite buffer is fully released; derived views are wrapped in
 * {@link SimpleLeakAwareByteBuf} sharing the same tracker.
 */
class SimpleLeakAwareCompositeByteBuf extends WrappedCompositeByteBuf {

    // NOTE(review): the generic argument <ByteBuf> was stripped (raw type) in the
    // imported sources; restored to match the ResourceLeakTracker API.
    final ResourceLeakTracker<ByteBuf> leak;

    SimpleLeakAwareCompositeByteBuf(CompositeByteBuf wrapped, ResourceLeakTracker<ByteBuf> leak) {
        super(wrapped);
        this.leak = ObjectUtil.checkNotNull(leak, "leak");
    }

    @Override
    public boolean release() {
        // Call unwrap() before just in case that super.release() will change the ByteBuf instance that is returned
        // by unwrap().
        ByteBuf unwrapped = unwrap();
        if (super.release()) {
            closeLeak(unwrapped);
            return true;
        }
        return false;
    }

    @Override
    public boolean release(int decrement) {
        // Call unwrap() before just in case that super.release() will change the ByteBuf instance that is returned
        // by unwrap().
        ByteBuf unwrapped = unwrap();
        if (super.release(decrement)) {
            closeLeak(unwrapped);
            return true;
        }
        return false;
    }

    private void closeLeak(ByteBuf trackedByteBuf) {
        // Close the ResourceLeakTracker with the tracked ByteBuf as argument. This must be the same that was used when
        // calling DefaultResourceLeak.track(...).
        boolean closed = leak.close(trackedByteBuf);
        assert closed;
    }

    @Override
    public ByteBuf order(ByteOrder endianness) {
        if (order() == endianness) {
            return this;
        } else {
            return newLeakAwareByteBuf(super.order(endianness));
        }
    }

    // Derived views share this buffer's tracker so the leak is reported once.

    @Override
    public ByteBuf slice() {
        return newLeakAwareByteBuf(super.slice());
    }

    @Override
    public ByteBuf retainedSlice() {
        return newLeakAwareByteBuf(super.retainedSlice());
    }

    @Override
    public ByteBuf slice(int index, int length) {
        return newLeakAwareByteBuf(super.slice(index, length));
    }

    @Override
    public ByteBuf retainedSlice(int index, int length) {
        return newLeakAwareByteBuf(super.retainedSlice(index, length));
    }

    @Override
    public ByteBuf duplicate() {
        return newLeakAwareByteBuf(super.duplicate());
    }

    @Override
    public ByteBuf retainedDuplicate() {
        return newLeakAwareByteBuf(super.retainedDuplicate());
    }

    @Override
    public ByteBuf readSlice(int length) {
        return newLeakAwareByteBuf(super.readSlice(length));
    }

    @Override
    public ByteBuf readRetainedSlice(int length) {
        return newLeakAwareByteBuf(super.readRetainedSlice(length));
    }

    @Override
    public ByteBuf asReadOnly() {
        return newLeakAwareByteBuf(super.asReadOnly());
    }

    private SimpleLeakAwareByteBuf newLeakAwareByteBuf(ByteBuf wrapped) {
        return newLeakAwareByteBuf(wrapped, unwrap(), leak);
    }

    /** Factory hook subclasses override to return their own leak-aware wrapper type. */
    protected SimpleLeakAwareByteBuf newLeakAwareByteBuf(
            ByteBuf wrapped, ByteBuf trackedByteBuf, ResourceLeakTracker<ByteBuf> leakTracker) {
        return new SimpleLeakAwareByteBuf(wrapped, trackedByteBuf, leakTracker);
    }
}
package io.netty.buffer;

import static io.netty.buffer.PoolThreadCache.*;

/**
 * SizeClasses requires {@code pageShifts} to be defined prior to inclusion,
 * and it in turn defines:
 * <p>
 *   LOG2_SIZE_CLASS_GROUP: Log of size class count for each size doubling.
 *   LOG2_MAX_LOOKUP_SIZE: Log of max size class in the lookup table.
 *   sizeClasses: Complete table of [index, log2Group, log2Delta, nDelta, isMultiPageSize,
 *                 isSubPage, log2DeltaLookup] tuples.
 *     index: Size class index.
 *     log2Group: Log of group base size (no deltas added).
 *     log2Delta: Log of delta to previous size class.
 *     nDelta: Delta multiplier.
 *     isMultiPageSize: 'yes' if a multiple of the page size, 'no' otherwise.
 *     isSubPage: 'yes' if a subpage size class, 'no' otherwise.
 *     log2DeltaLookup: Same as log2Delta if a lookup table size class, 'no'
 *                      otherwise.
 * <p>
 *   nSubpages: Number of subpages size classes.
 *   nSizes: Number of size classes.
 *   nPSizes: Number of size classes that are multiples of pageSize.
 *
 *   smallMaxSizeIdx: Maximum small size class index.
 *
 *   lookupMaxClass: Maximum size class included in lookup table.
 *   log2NormalMinClass: Log of minimum normal size class.
 * <p>
 *   The first size class and spacing are 1 &lt;&lt; LOG2_QUANTUM.
 *   Each group has 1 &lt;&lt; LOG2_SIZE_CLASS_GROUP of size classes.
 *
 *   size = 1 &lt;&lt; log2Group + nDelta * (1 &lt;&lt; log2Delta)
 *
 *   The first size class has an unusual encoding, because the size has to be
 *   split between group and delta*nDelta.
 *
 *   If pageShift = 13, sizeClasses looks like this:
 *
 *   (index, log2Group, log2Delta, nDelta, isMultiPageSize, isSubPage, log2DeltaLookup)
 * <p>
 *   ( 0,     4,        4,         0,       no,             yes,        4)
 *   ( 1,     4,        4,         1,       no,             yes,        4)
 *   ( 2,     4,        4,         2,       no,             yes,        4)
 *   ( 3,     4,        4,         3,       no,             yes,        4)
 * <p>
 *   ( 4,     6,        4,         1,       no,             yes,        4)
 *   ( 5,     6,        4,         2,       no,             yes,        4)
 *   ( 6,     6,        4,         3,       no,             yes,        4)
 *   ( 7,     6,        4,         4,       no,             yes,        4)
 * <p>
 *   ( 8,     7,        5,         1,       no,             yes,        5)
 *   ( 9,     7,        5,         2,       no,             yes,        5)
 *   ( 10,    7,        5,         3,       no,             yes,        5)
 *   ( 11,    7,        5,         4,       no,             yes,        5)
 *   ...
 *   ...
 *   ( 72,    23,       21,        1,       yes,            no,        no)
 *   ( 73,    23,       21,        2,       yes,            no,        no)
 *   ( 74,    23,       21,        3,       yes,            no,        no)
 *   ( 75,    23,       21,        4,       yes,            no,        no)
 * <p>
 *   ( 76,    24,       22,        1,       yes,            no,        no)
 */
final class SizeClasses implements SizeClassesMetric {

    static final int LOG2_QUANTUM = 4;

    private static final int LOG2_SIZE_CLASS_GROUP = 2;
    private static final int LOG2_MAX_LOOKUP_SIZE = 12;

    // Column indices into the per-size-class short[7] tuple
    // [index, log2Group, log2Delta, nDelta, isMultiPageSize, isSubPage, log2DeltaLookup].
    private static final int LOG2GROUP_IDX = 1;
    private static final int LOG2DELTA_IDX = 2;
    private static final int NDELTA_IDX = 3;
    private static final int PAGESIZE_IDX = 4;
    private static final int SUBPAGE_IDX = 5;
    private static final int LOG2_DELTA_LOOKUP_IDX = 6;

    // Boolean flags stored in the short[] tuples.
    private static final byte no = 0, yes = 1;

    final int pageSize;
    final int pageShifts;
    final int chunkSize;
    final int directMemoryCacheAlignment;

    final int nSizes;
    final int nSubpages;
    final int nPSizes;
    final int lookupMaxSize;
    final int smallMaxSizeIdx;

    private final int[] pageIdx2sizeTab;

    // lookup table for sizeIdx <= smallMaxSizeIdx
    private final int[] sizeIdx2sizeTab;

    // lookup table used for size <= lookupMaxClass
    // spacing is 1 << LOG2_QUANTUM, so the size of array is lookupMaxClass >> LOG2_QUANTUM
    private final int[] size2idxTab;

    /**
     * Builds the complete size-class table plus the derived lookup tables for
     * the given chunk geometry.
     *
     * @param pageSize                   the page size in bytes (1 &lt;&lt; pageShifts)
     * @param pageShifts                 log2 of the page size
     * @param chunkSize                  total chunk size; must equal the largest generated size class
     * @param directMemoryCacheAlignment alignment for direct buffers, or &lt;= 0 for none
     */
    SizeClasses(int pageSize, int pageShifts, int chunkSize, int directMemoryCacheAlignment) {
        // Number of size-class groups needed to reach chunkSize.
        int group = log2(chunkSize) - LOG2_QUANTUM - LOG2_SIZE_CLASS_GROUP + 1;

        //generate size classes
        //[index, log2Group, log2Delta, nDelta, isMultiPageSize, isSubPage, log2DeltaLookup]
        short[][] sizeClasses = new short[group << LOG2_SIZE_CLASS_GROUP][7];

        int normalMaxSize = -1;
        int nSizes = 0;
        int size = 0;

        int log2Group = LOG2_QUANTUM;
        int log2Delta = LOG2_QUANTUM;
        int ndeltaLimit = 1 << LOG2_SIZE_CLASS_GROUP;

        //First small group, nDelta start at 0.
        //first size class is 1 << LOG2_QUANTUM
        for (int nDelta = 0; nDelta < ndeltaLimit; nDelta++, nSizes++) {
            short[] sizeClass = newSizeClass(nSizes, log2Group, log2Delta, nDelta, pageShifts);
            sizeClasses[nSizes] = sizeClass;
            size = sizeOf(sizeClass, directMemoryCacheAlignment);
        }

        log2Group += LOG2_SIZE_CLASS_GROUP;

        //All remaining groups, nDelta start at 1.
        for (; size < chunkSize; log2Group++, log2Delta++) {
            for (int nDelta = 1; nDelta <= ndeltaLimit && size < chunkSize; nDelta++, nSizes++) {
                short[] sizeClass = newSizeClass(nSizes, log2Group, log2Delta, nDelta, pageShifts);
                sizeClasses[nSizes] = sizeClass;
                size = normalMaxSize = sizeOf(sizeClass, directMemoryCacheAlignment);
            }
        }

        //chunkSize must be normalMaxSize
        assert chunkSize == normalMaxSize;

        // Single pass over the table to count page-multiple and subpage
        // classes and to find the largest class covered by the lookup table.
        int smallMaxSizeIdx = 0;
        int lookupMaxSize = 0;
        int nPSizes = 0;
        int nSubpages = 0;
        for (int idx = 0; idx < nSizes; idx++) {
            short[] sz = sizeClasses[idx];
            if (sz[PAGESIZE_IDX] == yes) {
                nPSizes++;
            }
            if (sz[SUBPAGE_IDX] == yes) {
                nSubpages++;
                smallMaxSizeIdx = idx;
            }
            if (sz[LOG2_DELTA_LOOKUP_IDX] != no) {
                lookupMaxSize = sizeOf(sz, directMemoryCacheAlignment);
            }
        }
        this.smallMaxSizeIdx = smallMaxSizeIdx;
        this.lookupMaxSize = lookupMaxSize;
        this.nPSizes = nPSizes;
        this.nSubpages = nSubpages;
        this.nSizes = nSizes;

        this.pageSize = pageSize;
        this.pageShifts = pageShifts;
        this.chunkSize = chunkSize;
        this.directMemoryCacheAlignment = directMemoryCacheAlignment;

        //generate lookup tables
        this.sizeIdx2sizeTab = newIdx2SizeTab(sizeClasses, nSizes, directMemoryCacheAlignment);
        this.pageIdx2sizeTab = newPageIdx2sizeTab(sizeClasses, nSizes, nPSizes, directMemoryCacheAlignment);
        this.size2idxTab = newSize2idxTab(lookupMaxSize, sizeClasses);
    }

    //calculate size class
    // Builds one [index, log2Group, log2Delta, nDelta, isMultiPageSize,
    // isSubPage, log2DeltaLookup] tuple for the class at the given position.
    private static short[] newSizeClass(int index, int log2Group, int log2Delta, int nDelta, int pageShifts) {
        short isMultiPageSize;
        if (log2Delta >= pageShifts) {
            // Delta itself is a page multiple, so every class in the group is.
            isMultiPageSize = yes;
        } else {
            int pageSize = 1 << pageShifts;
            int size = calculateSize(log2Group, nDelta, log2Delta);

            isMultiPageSize = size == size / pageSize * pageSize? yes : no;
        }

        int log2Ndelta = nDelta == 0? 0 : log2(nDelta);

        // 'remove' marks classes whose size is not an exact power of two.
        byte remove = 1 << log2Ndelta < nDelta? yes : no;

        // The last class of a group reaches the next power of two.
        int log2Size = log2Delta + log2Ndelta == log2Group? log2Group + 1 : log2Group;
        if (log2Size == log2Group) {
            remove = yes;
        }

        short isSubpage = log2Size < pageShifts + LOG2_SIZE_CLASS_GROUP? yes : no;

        int log2DeltaLookup = log2Size < LOG2_MAX_LOOKUP_SIZE ||
                              log2Size == LOG2_MAX_LOOKUP_SIZE && remove == no
                ? log2Delta : no;

        return new short[] {
                (short) index, (short) log2Group, (short) log2Delta,
                (short) nDelta, isMultiPageSize, isSubpage, (short) log2DeltaLookup
        };
    }

    // Builds the sizeIdx -> size lookup table.
    private static int[] newIdx2SizeTab(short[][] sizeClasses, int nSizes, int directMemoryCacheAlignment) {
        int[] sizeIdx2sizeTab = new int[nSizes];

        for (int i = 0; i < nSizes; i++) {
            short[] sizeClass = sizeClasses[i];
            sizeIdx2sizeTab[i] = sizeOf(sizeClass, directMemoryCacheAlignment);
        }
        return sizeIdx2sizeTab;
    }

    // size = (1 << log2Group) + nDelta * (1 << log2Delta)
    private static int calculateSize(int log2Group, int nDelta, int log2Delta) {
        return (1 << log2Group) + (nDelta << log2Delta);
    }

    // Decodes a size-class tuple into its (possibly alignment-padded) byte size.
    private static int sizeOf(short[] sizeClass, int directMemoryCacheAlignment) {
        int log2Group = sizeClass[LOG2GROUP_IDX];
        int log2Delta = sizeClass[LOG2DELTA_IDX];
        int nDelta = sizeClass[NDELTA_IDX];

        int size = calculateSize(log2Group, nDelta, log2Delta);

        return alignSizeIfNeeded(size, directMemoryCacheAlignment);
    }

    // Builds the pageIdx -> size lookup table over the page-multiple classes only.
    private static int[] newPageIdx2sizeTab(short[][] sizeClasses, int nSizes, int nPSizes,
                                            int directMemoryCacheAlignment) {
        int[] pageIdx2sizeTab = new int[nPSizes];
        int pageIdx = 0;
        for (int i = 0; i < nSizes; i++) {
            short[] sizeClass = sizeClasses[i];
            if (sizeClass[PAGESIZE_IDX] == yes) {
                pageIdx2sizeTab[pageIdx++] = sizeOf(sizeClass, directMemoryCacheAlignment);
            }
        }
        return pageIdx2sizeTab;
    }

    // Builds the (size >> LOG2_QUANTUM) -> sizeIdx lookup table: each class i
    // fills (1 << (log2Delta - LOG2_QUANTUM)) consecutive quantum-spaced slots.
    private static int[] newSize2idxTab(int lookupMaxSize, short[][] sizeClasses) {
        int[] size2idxTab = new int[lookupMaxSize >> LOG2_QUANTUM];
        int idx = 0;
        int size = 0;

        for (int i = 0; size <= lookupMaxSize; i++) {
            int log2Delta = sizeClasses[i][LOG2DELTA_IDX];
            int times = 1 << log2Delta - LOG2_QUANTUM;

            while (size <= lookupMaxSize && times-- > 0) {
                size2idxTab[idx++] = i;
                size = idx + 1 << LOG2_QUANTUM;
            }
        }
        return size2idxTab;
    }

    @Override
    public int sizeIdx2size(int sizeIdx) {
        return sizeIdx2sizeTab[sizeIdx];
    }

    @Override
    public int sizeIdx2sizeCompute(int sizeIdx) {
        // Recover (group, position-in-group) from the flat index, then the
        // group base size and the in-group delta. Relies on Java precedence:
        // '+'/'-' bind tighter than '<<'.
        int group = sizeIdx >> LOG2_SIZE_CLASS_GROUP;
        int mod = sizeIdx & (1 << LOG2_SIZE_CLASS_GROUP) - 1;

        int groupSize = group == 0? 0 :
                1 << LOG2_QUANTUM + LOG2_SIZE_CLASS_GROUP - 1 << group;

        int shift = group == 0? 1 : group;
        int lgDelta = shift + LOG2_QUANTUM - 1;
        int modSize = mod + 1 << lgDelta;

        return groupSize + modSize;
    }

    @Override
    public long pageIdx2size(int pageIdx) {
        return pageIdx2sizeTab[pageIdx];
    }

    @Override
    public long pageIdx2sizeCompute(int pageIdx) {
        // Same decomposition as sizeIdx2sizeCompute, but page-granular and
        // widened to long since page-multiple sizes can be large.
        int group = pageIdx >> LOG2_SIZE_CLASS_GROUP;
        int mod = pageIdx & (1 << LOG2_SIZE_CLASS_GROUP) - 1;

        long groupSize = group == 0? 0 :
                1L << pageShifts + LOG2_SIZE_CLASS_GROUP - 1 << group;

        int shift = group == 0? 1 : group;
        int log2Delta = shift + pageShifts - 1;
        int modSize = mod + 1 << log2Delta;

        return groupSize + modSize;
    }

    @Override
    public int size2SizeIdx(int size) {
        if (size == 0) {
            return 0;
        }
        // Sizes above chunkSize have no class; nSizes acts as the sentinel.
        if (size > chunkSize) {
            return nSizes;
        }

        size = alignSizeIfNeeded(size, directMemoryCacheAlignment);

        if (size <= lookupMaxSize) {
            //size-1 / MIN_TINY
            return size2idxTab[size - 1 >> LOG2_QUANTUM];
        }

        // x = ceil(log2(size)); derive the group and in-group slot from it.
        int x = log2((size << 1) - 1);
        int shift = x < LOG2_SIZE_CLASS_GROUP + LOG2_QUANTUM + 1
                ? 0 : x - (LOG2_SIZE_CLASS_GROUP + LOG2_QUANTUM);

        int group = shift << LOG2_SIZE_CLASS_GROUP;

        int log2Delta = x < LOG2_SIZE_CLASS_GROUP + LOG2_QUANTUM + 1
                ? LOG2_QUANTUM : x - LOG2_SIZE_CLASS_GROUP - 1;

        int mod = size - 1 >> log2Delta & (1 << LOG2_SIZE_CLASS_GROUP) - 1;

        return group + mod;
    }

    @Override
    public int pages2pageIdx(int pages) {
        return pages2pageIdxCompute(pages, false);
    }

    @Override
    public int pages2pageIdxFloor(int pages) {
        return pages2pageIdxCompute(pages, true);
    }

    // Maps a page count to a pageIdx; with floor=true the result is rounded
    // down to the largest class not exceeding the requested pages.
    private int pages2pageIdxCompute(int pages, boolean floor) {
        int pageSize = pages << pageShifts;
        if (pageSize > chunkSize) {
            return nPSizes;
        }

        int x = log2((pageSize << 1) - 1);

        int shift = x < LOG2_SIZE_CLASS_GROUP + pageShifts
                ? 0 : x - (LOG2_SIZE_CLASS_GROUP + pageShifts);

        int group = shift << LOG2_SIZE_CLASS_GROUP;

        int log2Delta = x < LOG2_SIZE_CLASS_GROUP + pageShifts + 1?
                pageShifts : x - LOG2_SIZE_CLASS_GROUP - 1;

        int mod = pageSize - 1 >> log2Delta & (1 << LOG2_SIZE_CLASS_GROUP) - 1;

        int pageIdx = group + mod;

        // The computed class rounds up; step back one class when flooring.
        if (floor && pageIdx2sizeTab[pageIdx] > pages << pageShifts) {
            pageIdx--;
        }

        return pageIdx;
    }

    // Round size up to the nearest multiple of alignment.
    private static int alignSizeIfNeeded(int size, int directMemoryCacheAlignment) {
        if (directMemoryCacheAlignment <= 0) {
            return size;
        }
        // Alignment is a power of two, so the remainder is a simple mask.
        int delta = size & directMemoryCacheAlignment - 1;
        return delta == 0? size : size + directMemoryCacheAlignment - delta;
    }

    @Override
    public int normalizeSize(int size) {
        if (size == 0) {
            return sizeIdx2sizeTab[0];
        }
        size = alignSizeIfNeeded(size, directMemoryCacheAlignment);
        if (size <= lookupMaxSize) {
            int ret = sizeIdx2sizeTab[size2idxTab[size - 1 >> LOG2_QUANTUM]];
            assert ret == normalizeSizeCompute(size);
            return ret;
        }
        return normalizeSizeCompute(size);
    }

    // Rounds size up to the next class boundary without consulting the tables.
    private static int normalizeSizeCompute(int size) {
        int x = log2((size << 1) - 1);
        int log2Delta = x < LOG2_SIZE_CLASS_GROUP + LOG2_QUANTUM + 1
                ? LOG2_QUANTUM : x - LOG2_SIZE_CLASS_GROUP - 1;
        int delta = 1 << log2Delta;
        int delta_mask = delta - 1;
        return size + delta_mask & ~delta_mask;
    }
}
/**
 * Exposes metrics for a {@code SizeClasses}.
 */
public interface SizeClassesMetric {

    /**
     * Computes size from the lookup table according to {@code sizeIdx}.
     *
     * @return size
     */
    int sizeIdx2size(int sizeIdx);

    /**
     * Computes size directly (without the lookup table) according to {@code sizeIdx}.
     *
     * @return size
     */
    int sizeIdx2sizeCompute(int sizeIdx);

    /**
     * Computes size from the lookup table according to {@code pageIdx}.
     *
     * @return size which is a multiple of pageSize.
     */
    long pageIdx2size(int pageIdx);

    /**
     * Computes size directly (without the lookup table) according to {@code pageIdx}.
     *
     * @return size which is a multiple of pageSize
     */
    long pageIdx2sizeCompute(int pageIdx);

    /**
     * Normalizes the request size up to the nearest size class.
     *
     * @param size request size
     *
     * @return sizeIdx of the size class
     */
    int size2SizeIdx(int size);

    /**
     * Normalizes the request size up to the nearest pageSize class.
     *
     * @param pages multiples of pageSizes
     *
     * @return pageIdx of the pageSize class
     */
    int pages2pageIdx(int pages);

    /**
     * Normalizes the request size down to the nearest pageSize class.
     *
     * @param pages multiples of pageSizes
     *
     * @return pageIdx of the pageSize class
     */
    int pages2pageIdxFloor(int pages);

    /**
     * Normalizes the usable size that would result from allocating an object
     * with the specified size and alignment.
     *
     * @param size request size
     *
     * @return normalized size
     */
    int normalizeSize(int size);
}
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +/** + * A derived buffer which exposes its parent's sub-region only. It is + * recommended to use {@link ByteBuf#slice()} and + * {@link ByteBuf#slice(int, int)} instead of calling the constructor + * explicitly. + * + * @deprecated Do not use. + */ +@Deprecated +public class SlicedByteBuf extends AbstractUnpooledSlicedByteBuf { + + private int length; + + public SlicedByteBuf(ByteBuf buffer, int index, int length) { + super(buffer, index, length); + } + + @Override + final void initLength(int length) { + this.length = length; + } + + @Override + final int length() { + return length; + } + + @Override + public int capacity() { + return length; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java new file mode 100644 index 0000000..3b1b4d5 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/SwappedByteBuf.java @@ -0,0 +1,1066 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.internal.ObjectUtil; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; + +/** + * Wrapper which swap the {@link ByteOrder} of a {@link ByteBuf}. + * + * @deprecated use the Little Endian accessors, e.g. {@code getShortLE}, {@code getIntLE} + * instead. + */ +@Deprecated +public class SwappedByteBuf extends ByteBuf { + + private final ByteBuf buf; + private final ByteOrder order; + + public SwappedByteBuf(ByteBuf buf) { + this.buf = ObjectUtil.checkNotNull(buf, "buf"); + if (buf.order() == ByteOrder.BIG_ENDIAN) { + order = ByteOrder.LITTLE_ENDIAN; + } else { + order = ByteOrder.BIG_ENDIAN; + } + } + + @Override + public ByteOrder order() { + return order; + } + + @Override + public ByteBuf order(ByteOrder endianness) { + if (ObjectUtil.checkNotNull(endianness, "endianness") == order) { + return this; + } + return buf; + } + + @Override + public ByteBuf unwrap() { + return buf; + } + + @Override + public ByteBufAllocator alloc() { + return buf.alloc(); + } + + @Override + public int capacity() { + return buf.capacity(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + buf.capacity(newCapacity); + return this; + } + + @Override + public int maxCapacity() { + return buf.maxCapacity(); + } + + @Override + public boolean isReadOnly() { + return buf.isReadOnly(); + } + + @Override + public ByteBuf asReadOnly() { + return Unpooled.unmodifiableBuffer(this); + } + + @Override + public boolean isDirect() { + return buf.isDirect(); + } + + @Override + public int readerIndex() { + return buf.readerIndex(); + } + + @Override + public ByteBuf readerIndex(int readerIndex) { + buf.readerIndex(readerIndex); + 
return this; + } + + @Override + public int writerIndex() { + return buf.writerIndex(); + } + + @Override + public ByteBuf writerIndex(int writerIndex) { + buf.writerIndex(writerIndex); + return this; + } + + @Override + public ByteBuf setIndex(int readerIndex, int writerIndex) { + buf.setIndex(readerIndex, writerIndex); + return this; + } + + @Override + public int readableBytes() { + return buf.readableBytes(); + } + + @Override + public int writableBytes() { + return buf.writableBytes(); + } + + @Override + public int maxWritableBytes() { + return buf.maxWritableBytes(); + } + + @Override + public int maxFastWritableBytes() { + return buf.maxFastWritableBytes(); + } + + @Override + public boolean isReadable() { + return buf.isReadable(); + } + + @Override + public boolean isReadable(int size) { + return buf.isReadable(size); + } + + @Override + public boolean isWritable() { + return buf.isWritable(); + } + + @Override + public boolean isWritable(int size) { + return buf.isWritable(size); + } + + @Override + public ByteBuf clear() { + buf.clear(); + return this; + } + + @Override + public ByteBuf markReaderIndex() { + buf.markReaderIndex(); + return this; + } + + @Override + public ByteBuf resetReaderIndex() { + buf.resetReaderIndex(); + return this; + } + + @Override + public ByteBuf markWriterIndex() { + buf.markWriterIndex(); + return this; + } + + @Override + public ByteBuf resetWriterIndex() { + buf.resetWriterIndex(); + return this; + } + + @Override + public ByteBuf discardReadBytes() { + buf.discardReadBytes(); + return this; + } + + @Override + public ByteBuf discardSomeReadBytes() { + buf.discardSomeReadBytes(); + return this; + } + + @Override + public ByteBuf ensureWritable(int writableBytes) { + buf.ensureWritable(writableBytes); + return this; + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + return buf.ensureWritable(minWritableBytes, force); + } + + @Override + public boolean getBoolean(int index) { + return 
buf.getBoolean(index); + } + + @Override + public byte getByte(int index) { + return buf.getByte(index); + } + + @Override + public short getUnsignedByte(int index) { + return buf.getUnsignedByte(index); + } + + @Override + public short getShort(int index) { + return ByteBufUtil.swapShort(buf.getShort(index)); + } + + @Override + public short getShortLE(int index) { + return buf.getShortLE(index); + } + + @Override + public int getUnsignedShort(int index) { + return getShort(index) & 0xFFFF; + } + + @Override + public int getUnsignedShortLE(int index) { + return getShortLE(index) & 0xFFFF; + } + + @Override + public int getMedium(int index) { + return ByteBufUtil.swapMedium(buf.getMedium(index)); + } + + @Override + public int getMediumLE(int index) { + return buf.getMediumLE(index); + } + + @Override + public int getUnsignedMedium(int index) { + return getMedium(index) & 0xFFFFFF; + } + + @Override + public int getUnsignedMediumLE(int index) { + return getMediumLE(index) & 0xFFFFFF; + } + + @Override + public int getInt(int index) { + return ByteBufUtil.swapInt(buf.getInt(index)); + } + + @Override + public int getIntLE(int index) { + return buf.getIntLE(index); + } + + @Override + public long getUnsignedInt(int index) { + return getInt(index) & 0xFFFFFFFFL; + } + + @Override + public long getUnsignedIntLE(int index) { + return getIntLE(index) & 0xFFFFFFFFL; + } + + @Override + public long getLong(int index) { + return ByteBufUtil.swapLong(buf.getLong(index)); + } + + @Override + public long getLongLE(int index) { + return buf.getLongLE(index); + } + + @Override + public char getChar(int index) { + return (char) getShort(index); + } + + @Override + public float getFloat(int index) { + return Float.intBitsToFloat(getInt(index)); + } + + @Override + public double getDouble(int index) { + return Double.longBitsToDouble(getLong(index)); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst) { + buf.getBytes(index, dst); + return this; + } + + @Override + 
public ByteBuf getBytes(int index, ByteBuf dst, int length) { + buf.getBytes(index, dst, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + buf.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst) { + buf.getBytes(index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + buf.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + buf.getBytes(index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + buf.getBytes(index, out, length); + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return buf.getBytes(index, out, length); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + return buf.getBytes(index, out, position, length); + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + return buf.getCharSequence(index, length, charset); + } + + @Override + public ByteBuf setBoolean(int index, boolean value) { + buf.setBoolean(index, value); + return this; + } + + @Override + public ByteBuf setByte(int index, int value) { + buf.setByte(index, value); + return this; + } + + @Override + public ByteBuf setShort(int index, int value) { + buf.setShort(index, ByteBufUtil.swapShort((short) value)); + return this; + } + + @Override + public ByteBuf setShortLE(int index, int value) { + buf.setShortLE(index, (short) value); + return this; + } + + @Override + public ByteBuf setMedium(int index, int value) { + buf.setMedium(index, ByteBufUtil.swapMedium(value)); + return this; + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + 
buf.setMediumLE(index, value); + return this; + } + + @Override + public ByteBuf setInt(int index, int value) { + buf.setInt(index, ByteBufUtil.swapInt(value)); + return this; + } + + @Override + public ByteBuf setIntLE(int index, int value) { + buf.setIntLE(index, value); + return this; + } + + @Override + public ByteBuf setLong(int index, long value) { + buf.setLong(index, ByteBufUtil.swapLong(value)); + return this; + } + + @Override + public ByteBuf setLongLE(int index, long value) { + buf.setLongLE(index, value); + return this; + } + + @Override + public ByteBuf setChar(int index, int value) { + setShort(index, value); + return this; + } + + @Override + public ByteBuf setFloat(int index, float value) { + setInt(index, Float.floatToRawIntBits(value)); + return this; + } + + @Override + public ByteBuf setDouble(int index, double value) { + setLong(index, Double.doubleToRawLongBits(value)); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src) { + buf.setBytes(index, src); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int length) { + buf.setBytes(index, src, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + buf.setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src) { + buf.setBytes(index, src); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + buf.setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + buf.setBytes(index, src); + return this; + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + return buf.setBytes(index, in, length); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + return buf.setBytes(index, in, 
length); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + return buf.setBytes(index, in, position, length); + } + + @Override + public ByteBuf setZero(int index, int length) { + buf.setZero(index, length); + return this; + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + return buf.setCharSequence(index, sequence, charset); + } + + @Override + public boolean readBoolean() { + return buf.readBoolean(); + } + + @Override + public byte readByte() { + return buf.readByte(); + } + + @Override + public short readUnsignedByte() { + return buf.readUnsignedByte(); + } + + @Override + public short readShort() { + return ByteBufUtil.swapShort(buf.readShort()); + } + + @Override + public short readShortLE() { + return buf.readShortLE(); + } + + @Override + public int readUnsignedShort() { + return readShort() & 0xFFFF; + } + + @Override + public int readUnsignedShortLE() { + return readShortLE() & 0xFFFF; + } + + @Override + public int readMedium() { + return ByteBufUtil.swapMedium(buf.readMedium()); + } + + @Override + public int readMediumLE() { + return buf.readMediumLE(); + } + + @Override + public int readUnsignedMedium() { + return readMedium() & 0xFFFFFF; + } + + @Override + public int readUnsignedMediumLE() { + return readMediumLE() & 0xFFFFFF; + } + + @Override + public int readInt() { + return ByteBufUtil.swapInt(buf.readInt()); + } + + @Override + public int readIntLE() { + return buf.readIntLE(); + } + + @Override + public long readUnsignedInt() { + return readInt() & 0xFFFFFFFFL; + } + + @Override + public long readUnsignedIntLE() { + return readIntLE() & 0xFFFFFFFFL; + } + + @Override + public long readLong() { + return ByteBufUtil.swapLong(buf.readLong()); + } + + @Override + public long readLongLE() { + return buf.readLongLE(); + } + + @Override + public char readChar() { + return (char) readShort(); + } + + @Override + public float readFloat() 
{ + return Float.intBitsToFloat(readInt()); + } + + @Override + public double readDouble() { + return Double.longBitsToDouble(readLong()); + } + + @Override + public ByteBuf readBytes(int length) { + return buf.readBytes(length).order(order()); + } + + @Override + public ByteBuf readSlice(int length) { + return buf.readSlice(length).order(order); + } + + @Override + public ByteBuf readRetainedSlice(int length) { + return buf.readRetainedSlice(length).order(order); + } + + @Override + public ByteBuf readBytes(ByteBuf dst) { + buf.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int length) { + buf.readBytes(dst, length); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + buf.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst) { + buf.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + buf.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + buf.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) throws IOException { + buf.readBytes(out, length); + return this; + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + return buf.readBytes(out, length); + } + + @Override + public int readBytes(FileChannel out, long position, int length) throws IOException { + return buf.readBytes(out, position, length); + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + return buf.readCharSequence(length, charset); + } + + @Override + public ByteBuf skipBytes(int length) { + buf.skipBytes(length); + return this; + } + + @Override + public ByteBuf writeBoolean(boolean value) { + buf.writeBoolean(value); + return this; + } + + @Override + public ByteBuf writeByte(int value) { + 
buf.writeByte(value); + return this; + } + + @Override + public ByteBuf writeShort(int value) { + buf.writeShort(ByteBufUtil.swapShort((short) value)); + return this; + } + + @Override + public ByteBuf writeShortLE(int value) { + buf.writeShortLE((short) value); + return this; + } + + @Override + public ByteBuf writeMedium(int value) { + buf.writeMedium(ByteBufUtil.swapMedium(value)); + return this; + } + + @Override + public ByteBuf writeMediumLE(int value) { + buf.writeMediumLE(value); + return this; + } + + @Override + public ByteBuf writeInt(int value) { + buf.writeInt(ByteBufUtil.swapInt(value)); + return this; + } + + @Override + public ByteBuf writeIntLE(int value) { + buf.writeIntLE(value); + return this; + } + + @Override + public ByteBuf writeLong(long value) { + buf.writeLong(ByteBufUtil.swapLong(value)); + return this; + } + + @Override + public ByteBuf writeLongLE(long value) { + buf.writeLongLE(value); + return this; + } + + @Override + public ByteBuf writeChar(int value) { + writeShort(value); + return this; + } + + @Override + public ByteBuf writeFloat(float value) { + writeInt(Float.floatToRawIntBits(value)); + return this; + } + + @Override + public ByteBuf writeDouble(double value) { + writeLong(Double.doubleToRawLongBits(value)); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src) { + buf.writeBytes(src); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int length) { + buf.writeBytes(src, length); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + buf.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public ByteBuf writeBytes(byte[] src) { + buf.writeBytes(src); + return this; + } + + @Override + public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { + buf.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuffer src) { + buf.writeBytes(src); + return this; + } + + 
@Override + public int writeBytes(InputStream in, int length) throws IOException { + return buf.writeBytes(in, length); + } + + @Override + public int writeBytes(ScatteringByteChannel in, int length) throws IOException { + return buf.writeBytes(in, length); + } + + @Override + public int writeBytes(FileChannel in, long position, int length) throws IOException { + return buf.writeBytes(in, position, length); + } + + @Override + public ByteBuf writeZero(int length) { + buf.writeZero(length); + return this; + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + return buf.writeCharSequence(sequence, charset); + } + + @Override + public int indexOf(int fromIndex, int toIndex, byte value) { + return buf.indexOf(fromIndex, toIndex, value); + } + + @Override + public int bytesBefore(byte value) { + return buf.bytesBefore(value); + } + + @Override + public int bytesBefore(int length, byte value) { + return buf.bytesBefore(length, value); + } + + @Override + public int bytesBefore(int index, int length, byte value) { + return buf.bytesBefore(index, length, value); + } + + @Override + public int forEachByte(ByteProcessor processor) { + return buf.forEachByte(processor); + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + return buf.forEachByte(index, length, processor); + } + + @Override + public int forEachByteDesc(ByteProcessor processor) { + return buf.forEachByteDesc(processor); + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + return buf.forEachByteDesc(index, length, processor); + } + + @Override + public ByteBuf copy() { + return buf.copy().order(order); + } + + @Override + public ByteBuf copy(int index, int length) { + return buf.copy(index, length).order(order); + } + + @Override + public ByteBuf slice() { + return buf.slice().order(order); + } + + @Override + public ByteBuf retainedSlice() { + return buf.retainedSlice().order(order); + } 
+ + @Override + public ByteBuf slice(int index, int length) { + return buf.slice(index, length).order(order); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + return buf.retainedSlice(index, length).order(order); + } + + @Override + public ByteBuf duplicate() { + return buf.duplicate().order(order); + } + + @Override + public ByteBuf retainedDuplicate() { + return buf.retainedDuplicate().order(order); + } + + @Override + public int nioBufferCount() { + return buf.nioBufferCount(); + } + + @Override + public ByteBuffer nioBuffer() { + return buf.nioBuffer().order(order); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + return buf.nioBuffer(index, length).order(order); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + return nioBuffer(index, length); + } + + @Override + public ByteBuffer[] nioBuffers() { + ByteBuffer[] nioBuffers = buf.nioBuffers(); + for (int i = 0; i < nioBuffers.length; i++) { + nioBuffers[i] = nioBuffers[i].order(order); + } + return nioBuffers; + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + ByteBuffer[] nioBuffers = buf.nioBuffers(index, length); + for (int i = 0; i < nioBuffers.length; i++) { + nioBuffers[i] = nioBuffers[i].order(order); + } + return nioBuffers; + } + + @Override + public boolean hasArray() { + return buf.hasArray(); + } + + @Override + public byte[] array() { + return buf.array(); + } + + @Override + public int arrayOffset() { + return buf.arrayOffset(); + } + + @Override + public boolean hasMemoryAddress() { + return buf.hasMemoryAddress(); + } + + @Override + public boolean isContiguous() { + return buf.isContiguous(); + } + + @Override + public long memoryAddress() { + return buf.memoryAddress(); + } + + @Override + public String toString(Charset charset) { + return buf.toString(charset); + } + + @Override + public String toString(int index, int length, Charset charset) { + return buf.toString(index, length, 
charset); + } + + @Override + public int refCnt() { + return buf.refCnt(); + } + + @Override + final boolean isAccessible() { + return buf.isAccessible(); + } + + @Override + public ByteBuf retain() { + buf.retain(); + return this; + } + + @Override + public ByteBuf retain(int increment) { + buf.retain(increment); + return this; + } + + @Override + public ByteBuf touch() { + buf.touch(); + return this; + } + + @Override + public ByteBuf touch(Object hint) { + buf.touch(hint); + return this; + } + + @Override + public boolean release() { + return buf.release(); + } + + @Override + public boolean release(int decrement) { + return buf.release(decrement); + } + + @Override + public int hashCode() { + return buf.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj instanceof ByteBuf) { + return ByteBufUtil.equals(this, (ByteBuf) obj); + } + return false; + } + + @Override + public int compareTo(ByteBuf buffer) { + return ByteBufUtil.compare(this, buffer); + } + + @Override + public String toString() { + return "Swapped(" + buf + ')'; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/Unpooled.java b/netty-buffer/src/main/java/io/netty/buffer/Unpooled.java new file mode 100644 index 0000000..6d9a90b --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/Unpooled.java @@ -0,0 +1,923 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.buffer.CompositeByteBuf.ByteWrapper; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.util.Arrays; + + +/** + * Creates a new {@link ByteBuf} by allocating new space or by wrapping + * or copying existing byte arrays, byte buffers and a string. + * + *

+ * <h3>Use static import</h3>

+ * This class is intended to be used with the Java 5 static import statement: + * + *
+ * import static io.netty.buffer.{@link Unpooled}.*;
+ *
+ * {@link ByteBuf} heapBuffer    = buffer(128);
+ * {@link ByteBuf} directBuffer  = directBuffer(256);
+ * {@link ByteBuf} wrappedBuffer = wrappedBuffer(new byte[128], new byte[256]);
+ * {@link ByteBuf} copiedBuffer  = copiedBuffer({@link ByteBuffer}.allocate(128));
+ * 
+ * + *

+ * <h3>Allocating a new buffer</h3>

+ * + * Three buffer types are provided out of the box. + * + *
+ * <ul>
+ * <li>{@link #buffer(int)} allocates a new fixed-capacity heap buffer.</li>
+ * <li>{@link #directBuffer(int)} allocates a new fixed-capacity direct buffer.</li>
+ * </ul>
+ * + *

+ * <h3>Creating a wrapped buffer</h3>

+ * + * Wrapped buffer is a buffer which is a view of one or more existing + * byte arrays and byte buffers. Any changes in the content of the original + * array or buffer will be visible in the wrapped buffer. Various wrapper + * methods are provided and their name is all {@code wrappedBuffer()}. + * You might want to take a look at the methods that accept varargs closely if + * you want to create a buffer which is composed of more than one array to + * reduce the number of memory copy. + * + *

+ * <h3>Creating a copied buffer</h3>

+ * + * Copied buffer is a deep copy of one or more existing byte arrays, byte + * buffers or a string. Unlike a wrapped buffer, there's no shared data + * between the original data and the copied buffer. Various copy methods are + * provided and their name is all {@code copiedBuffer()}. It is also convenient + * to use this operation to merge multiple buffers into one buffer. + */ +public final class Unpooled { + + private static final ByteBufAllocator ALLOC = UnpooledByteBufAllocator.DEFAULT; + + /** + * Big endian byte order. + */ + public static final ByteOrder BIG_ENDIAN = ByteOrder.BIG_ENDIAN; + + /** + * Little endian byte order. + */ + public static final ByteOrder LITTLE_ENDIAN = ByteOrder.LITTLE_ENDIAN; + + /** + * A buffer whose capacity is {@code 0}. + */ + @SuppressWarnings("checkstyle:StaticFinalBuffer") // EmptyByteBuf is not writeable or readable. + public static final ByteBuf EMPTY_BUFFER = ALLOC.buffer(0, 0); + + static { + assert EMPTY_BUFFER instanceof EmptyByteBuf: "EMPTY_BUFFER must be an EmptyByteBuf."; + } + + /** + * Creates a new big-endian Java heap buffer with reasonably small initial capacity, which + * expands its capacity boundlessly on demand. + */ + public static ByteBuf buffer() { + return ALLOC.heapBuffer(); + } + + /** + * Creates a new big-endian direct buffer with reasonably small initial capacity, which + * expands its capacity boundlessly on demand. + */ + public static ByteBuf directBuffer() { + return ALLOC.directBuffer(); + } + + /** + * Creates a new big-endian Java heap buffer with the specified {@code capacity}, which + * expands its capacity boundlessly on demand. The new buffer's {@code readerIndex} and + * {@code writerIndex} are {@code 0}. + */ + public static ByteBuf buffer(int initialCapacity) { + return ALLOC.heapBuffer(initialCapacity); + } + + /** + * Creates a new big-endian direct buffer with the specified {@code capacity}, which + * expands its capacity boundlessly on demand. 
The new buffer's {@code readerIndex} and + * {@code writerIndex} are {@code 0}. + */ + public static ByteBuf directBuffer(int initialCapacity) { + return ALLOC.directBuffer(initialCapacity); + } + + /** + * Creates a new big-endian Java heap buffer with the specified + * {@code initialCapacity}, that may grow up to {@code maxCapacity} + * The new buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0}. + */ + public static ByteBuf buffer(int initialCapacity, int maxCapacity) { + return ALLOC.heapBuffer(initialCapacity, maxCapacity); + } + + /** + * Creates a new big-endian direct buffer with the specified + * {@code initialCapacity}, that may grow up to {@code maxCapacity}. + * The new buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0}. + */ + public static ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return ALLOC.directBuffer(initialCapacity, maxCapacity); + } + + /** + * Creates a new big-endian buffer which wraps the specified {@code array}. + * A modification on the specified array's content will be visible to the + * returned buffer. + */ + public static ByteBuf wrappedBuffer(byte[] array) { + if (array.length == 0) { + return EMPTY_BUFFER; + } + return new UnpooledHeapByteBuf(ALLOC, array, array.length); + } + + /** + * Creates a new big-endian buffer which wraps the sub-region of the + * specified {@code array}. A modification on the specified array's + * content will be visible to the returned buffer. + */ + public static ByteBuf wrappedBuffer(byte[] array, int offset, int length) { + if (length == 0) { + return EMPTY_BUFFER; + } + + if (offset == 0 && length == array.length) { + return wrappedBuffer(array); + } + + return wrappedBuffer(array).slice(offset, length); + } + + /** + * Creates a new buffer which wraps the specified NIO buffer's current + * slice. A modification on the specified buffer's content will be + * visible to the returned buffer. 
+ */ + public static ByteBuf wrappedBuffer(ByteBuffer buffer) { + if (!buffer.hasRemaining()) { + return EMPTY_BUFFER; + } + if (!buffer.isDirect() && buffer.hasArray()) { + return wrappedBuffer( + buffer.array(), + buffer.arrayOffset() + buffer.position(), + buffer.remaining()).order(buffer.order()); + } else if (PlatformDependent.hasUnsafe()) { + if (buffer.isReadOnly()) { + if (buffer.isDirect()) { + return new ReadOnlyUnsafeDirectByteBuf(ALLOC, buffer); + } else { + return new ReadOnlyByteBufferBuf(ALLOC, buffer); + } + } else { + return new UnpooledUnsafeDirectByteBuf(ALLOC, buffer, buffer.remaining()); + } + } else { + if (buffer.isReadOnly()) { + return new ReadOnlyByteBufferBuf(ALLOC, buffer); + } else { + return new UnpooledDirectByteBuf(ALLOC, buffer, buffer.remaining()); + } + } + } + + /** + * Creates a new buffer which wraps the specified memory address. If {@code doFree} is true the + * memoryAddress will automatically be freed once the reference count of the {@link ByteBuf} reaches {@code 0}. + */ + public static ByteBuf wrappedBuffer(long memoryAddress, int size, boolean doFree) { + return new WrappedUnpooledUnsafeDirectByteBuf(ALLOC, memoryAddress, size, doFree); + } + + /** + * Creates a new buffer which wraps the specified buffer's readable bytes. + * A modification on the specified buffer's content will be visible to the + * returned buffer. + * @param buffer The buffer to wrap. Reference count ownership of this variable is transferred to this method. + * @return The readable portion of the {@code buffer}, or an empty buffer if there is no readable portion. + * The caller is responsible for releasing this buffer. + */ + public static ByteBuf wrappedBuffer(ByteBuf buffer) { + if (buffer.isReadable()) { + return buffer.slice(); + } else { + buffer.release(); + return EMPTY_BUFFER; + } + } + + /** + * Creates a new big-endian composite buffer which wraps the specified + * arrays without copying them. 
A modification on the specified arrays' + * content will be visible to the returned buffer. + */ + public static ByteBuf wrappedBuffer(byte[]... arrays) { + return wrappedBuffer(arrays.length, arrays); + } + + /** + * Creates a new big-endian composite buffer which wraps the readable bytes of the + * specified buffers without copying them. A modification on the content + * of the specified buffers will be visible to the returned buffer. + * @param buffers The buffers to wrap. Reference count ownership of all variables is transferred to this method. + * @return The readable portion of the {@code buffers}. The caller is responsible for releasing this buffer. + */ + public static ByteBuf wrappedBuffer(ByteBuf... buffers) { + return wrappedBuffer(buffers.length, buffers); + } + + /** + * Creates a new big-endian composite buffer which wraps the slices of the specified + * NIO buffers without copying them. A modification on the content of the + * specified buffers will be visible to the returned buffer. + */ + public static ByteBuf wrappedBuffer(ByteBuffer... buffers) { + return wrappedBuffer(buffers.length, buffers); + } + + static ByteBuf wrappedBuffer(int maxNumComponents, ByteWrapper wrapper, T[] array) { + switch (array.length) { + case 0: + break; + case 1: + if (!wrapper.isEmpty(array[0])) { + return wrapper.wrap(array[0]); + } + break; + default: + for (int i = 0, len = array.length; i < len; i++) { + T bytes = array[i]; + if (bytes == null) { + return EMPTY_BUFFER; + } + if (!wrapper.isEmpty(bytes)) { + return new CompositeByteBuf(ALLOC, false, maxNumComponents, wrapper, array, i); + } + } + } + + return EMPTY_BUFFER; + } + + /** + * Creates a new big-endian composite buffer which wraps the specified + * arrays without copying them. A modification on the specified arrays' + * content will be visible to the returned buffer. + */ + public static ByteBuf wrappedBuffer(int maxNumComponents, byte[]... 
arrays) { + return wrappedBuffer(maxNumComponents, CompositeByteBuf.BYTE_ARRAY_WRAPPER, arrays); + } + + /** + * Creates a new big-endian composite buffer which wraps the readable bytes of the + * specified buffers without copying them. A modification on the content + * of the specified buffers will be visible to the returned buffer. + * @param maxNumComponents Advisement as to how many independent buffers are allowed to exist before + * consolidation occurs. + * @param buffers The buffers to wrap. Reference count ownership of all variables is transferred to this method. + * @return The readable portion of the {@code buffers}. The caller is responsible for releasing this buffer. + */ + public static ByteBuf wrappedBuffer(int maxNumComponents, ByteBuf... buffers) { + switch (buffers.length) { + case 0: + break; + case 1: + ByteBuf buffer = buffers[0]; + if (buffer.isReadable()) { + return wrappedBuffer(buffer.order(BIG_ENDIAN)); + } else { + buffer.release(); + } + break; + default: + for (int i = 0; i < buffers.length; i++) { + ByteBuf buf = buffers[i]; + if (buf.isReadable()) { + return new CompositeByteBuf(ALLOC, false, maxNumComponents, buffers, i); + } + buf.release(); + } + break; + } + return EMPTY_BUFFER; + } + + /** + * Creates a new big-endian composite buffer which wraps the slices of the specified + * NIO buffers without copying them. A modification on the content of the + * specified buffers will be visible to the returned buffer. + */ + public static ByteBuf wrappedBuffer(int maxNumComponents, ByteBuffer... buffers) { + return wrappedBuffer(maxNumComponents, CompositeByteBuf.BYTE_BUFFER_WRAPPER, buffers); + } + + /** + * Returns a new big-endian composite buffer with no components. + */ + public static CompositeByteBuf compositeBuffer() { + return compositeBuffer(AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS); + } + + /** + * Returns a new big-endian composite buffer with no components. 
+ */ + public static CompositeByteBuf compositeBuffer(int maxNumComponents) { + return new CompositeByteBuf(ALLOC, false, maxNumComponents); + } + + /** + * Creates a new big-endian buffer whose content is a copy of the + * specified {@code array}. The new buffer's {@code readerIndex} and + * {@code writerIndex} are {@code 0} and {@code array.length} respectively. + */ + public static ByteBuf copiedBuffer(byte[] array) { + if (array.length == 0) { + return EMPTY_BUFFER; + } + return wrappedBuffer(array.clone()); + } + + /** + * Creates a new big-endian buffer whose content is a copy of the + * specified {@code array}'s sub-region. The new buffer's + * {@code readerIndex} and {@code writerIndex} are {@code 0} and + * the specified {@code length} respectively. + */ + public static ByteBuf copiedBuffer(byte[] array, int offset, int length) { + if (length == 0) { + return EMPTY_BUFFER; + } + byte[] copy = PlatformDependent.allocateUninitializedArray(length); + System.arraycopy(array, offset, copy, 0, length); + return wrappedBuffer(copy); + } + + /** + * Creates a new buffer whose content is a copy of the specified + * {@code buffer}'s current slice. The new buffer's {@code readerIndex} + * and {@code writerIndex} are {@code 0} and {@code buffer.remaining} + * respectively. + */ + public static ByteBuf copiedBuffer(ByteBuffer buffer) { + int length = buffer.remaining(); + if (length == 0) { + return EMPTY_BUFFER; + } + byte[] copy = PlatformDependent.allocateUninitializedArray(length); + // Duplicate the buffer so we not adjust the position during our get operation. + // See https://github.com/netty/netty/issues/3896 + ByteBuffer duplicate = buffer.duplicate(); + duplicate.get(copy); + return wrappedBuffer(copy).order(duplicate.order()); + } + + /** + * Creates a new buffer whose content is a copy of the specified + * {@code buffer}'s readable bytes. 
The new buffer's {@code readerIndex} + * and {@code writerIndex} are {@code 0} and {@code buffer.readableBytes} + * respectively. + */ + public static ByteBuf copiedBuffer(ByteBuf buffer) { + int readable = buffer.readableBytes(); + if (readable > 0) { + ByteBuf copy = buffer(readable); + copy.writeBytes(buffer, buffer.readerIndex(), readable); + return copy; + } else { + return EMPTY_BUFFER; + } + } + + /** + * Creates a new big-endian buffer whose content is a merged copy of + * the specified {@code arrays}. The new buffer's {@code readerIndex} + * and {@code writerIndex} are {@code 0} and the sum of all arrays' + * {@code length} respectively. + */ + public static ByteBuf copiedBuffer(byte[]... arrays) { + switch (arrays.length) { + case 0: + return EMPTY_BUFFER; + case 1: + if (arrays[0].length == 0) { + return EMPTY_BUFFER; + } else { + return copiedBuffer(arrays[0]); + } + } + + // Merge the specified arrays into one array. + int length = 0; + for (byte[] a: arrays) { + if (Integer.MAX_VALUE - length < a.length) { + throw new IllegalArgumentException( + "The total length of the specified arrays is too big."); + } + length += a.length; + } + + if (length == 0) { + return EMPTY_BUFFER; + } + + byte[] mergedArray = PlatformDependent.allocateUninitializedArray(length); + for (int i = 0, j = 0; i < arrays.length; i ++) { + byte[] a = arrays[i]; + System.arraycopy(a, 0, mergedArray, j, a.length); + j += a.length; + } + + return wrappedBuffer(mergedArray); + } + + /** + * Creates a new buffer whose content is a merged copy of the specified + * {@code buffers}' readable bytes. The new buffer's {@code readerIndex} + * and {@code writerIndex} are {@code 0} and the sum of all buffers' + * {@code readableBytes} respectively. + * + * @throws IllegalArgumentException + * if the specified buffers' endianness are different from each + * other + */ + public static ByteBuf copiedBuffer(ByteBuf... 
buffers) { + switch (buffers.length) { + case 0: + return EMPTY_BUFFER; + case 1: + return copiedBuffer(buffers[0]); + } + + // Merge the specified buffers into one buffer. + ByteOrder order = null; + int length = 0; + for (ByteBuf b: buffers) { + int bLen = b.readableBytes(); + if (bLen <= 0) { + continue; + } + if (Integer.MAX_VALUE - length < bLen) { + throw new IllegalArgumentException( + "The total length of the specified buffers is too big."); + } + length += bLen; + if (order != null) { + if (!order.equals(b.order())) { + throw new IllegalArgumentException("inconsistent byte order"); + } + } else { + order = b.order(); + } + } + + if (length == 0) { + return EMPTY_BUFFER; + } + + byte[] mergedArray = PlatformDependent.allocateUninitializedArray(length); + for (int i = 0, j = 0; i < buffers.length; i ++) { + ByteBuf b = buffers[i]; + int bLen = b.readableBytes(); + b.getBytes(b.readerIndex(), mergedArray, j, bLen); + j += bLen; + } + + return wrappedBuffer(mergedArray).order(order); + } + + /** + * Creates a new buffer whose content is a merged copy of the specified + * {@code buffers}' slices. The new buffer's {@code readerIndex} and + * {@code writerIndex} are {@code 0} and the sum of all buffers' + * {@code remaining} respectively. + * + * @throws IllegalArgumentException + * if the specified buffers' endianness are different from each + * other + */ + public static ByteBuf copiedBuffer(ByteBuffer... buffers) { + switch (buffers.length) { + case 0: + return EMPTY_BUFFER; + case 1: + return copiedBuffer(buffers[0]); + } + + // Merge the specified buffers into one buffer. 
+ ByteOrder order = null; + int length = 0; + for (ByteBuffer b: buffers) { + int bLen = b.remaining(); + if (bLen <= 0) { + continue; + } + if (Integer.MAX_VALUE - length < bLen) { + throw new IllegalArgumentException( + "The total length of the specified buffers is too big."); + } + length += bLen; + if (order != null) { + if (!order.equals(b.order())) { + throw new IllegalArgumentException("inconsistent byte order"); + } + } else { + order = b.order(); + } + } + + if (length == 0) { + return EMPTY_BUFFER; + } + + byte[] mergedArray = PlatformDependent.allocateUninitializedArray(length); + for (int i = 0, j = 0; i < buffers.length; i ++) { + // Duplicate the buffer so we not adjust the position during our get operation. + // See https://github.com/netty/netty/issues/3896 + ByteBuffer b = buffers[i].duplicate(); + int bLen = b.remaining(); + b.get(mergedArray, j, bLen); + j += bLen; + } + + return wrappedBuffer(mergedArray).order(order); + } + + /** + * Creates a new big-endian buffer whose content is the specified + * {@code string} encoded in the specified {@code charset}. + * The new buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0} and the length of the encoded string respectively. + */ + public static ByteBuf copiedBuffer(CharSequence string, Charset charset) { + ObjectUtil.checkNotNull(string, "string"); + if (CharsetUtil.UTF_8.equals(charset)) { + return copiedBufferUtf8(string); + } + if (CharsetUtil.US_ASCII.equals(charset)) { + return copiedBufferAscii(string); + } + if (string instanceof CharBuffer) { + return copiedBuffer((CharBuffer) string, charset); + } + + return copiedBuffer(CharBuffer.wrap(string), charset); + } + + private static ByteBuf copiedBufferUtf8(CharSequence string) { + boolean release = true; + // Mimic the same behavior as other copiedBuffer implementations. 
+ ByteBuf buffer = ALLOC.heapBuffer(ByteBufUtil.utf8Bytes(string)); + try { + ByteBufUtil.writeUtf8(buffer, string); + release = false; + return buffer; + } finally { + if (release) { + buffer.release(); + } + } + } + + private static ByteBuf copiedBufferAscii(CharSequence string) { + boolean release = true; + // Mimic the same behavior as other copiedBuffer implementations. + ByteBuf buffer = ALLOC.heapBuffer(string.length()); + try { + ByteBufUtil.writeAscii(buffer, string); + release = false; + return buffer; + } finally { + if (release) { + buffer.release(); + } + } + } + + /** + * Creates a new big-endian buffer whose content is a subregion of + * the specified {@code string} encoded in the specified {@code charset}. + * The new buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0} and the length of the encoded string respectively. + */ + public static ByteBuf copiedBuffer( + CharSequence string, int offset, int length, Charset charset) { + ObjectUtil.checkNotNull(string, "string"); + if (length == 0) { + return EMPTY_BUFFER; + } + + if (string instanceof CharBuffer) { + CharBuffer buf = (CharBuffer) string; + if (buf.hasArray()) { + return copiedBuffer( + buf.array(), + buf.arrayOffset() + buf.position() + offset, + length, charset); + } + + buf = buf.slice(); + buf.limit(length); + buf.position(offset); + return copiedBuffer(buf, charset); + } + + return copiedBuffer(CharBuffer.wrap(string, offset, offset + length), charset); + } + + /** + * Creates a new big-endian buffer whose content is the specified + * {@code array} encoded in the specified {@code charset}. + * The new buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0} and the length of the encoded string respectively. 
+ */ + public static ByteBuf copiedBuffer(char[] array, Charset charset) { + ObjectUtil.checkNotNull(array, "array"); + return copiedBuffer(array, 0, array.length, charset); + } + + /** + * Creates a new big-endian buffer whose content is a subregion of + * the specified {@code array} encoded in the specified {@code charset}. + * The new buffer's {@code readerIndex} and {@code writerIndex} are + * {@code 0} and the length of the encoded string respectively. + */ + public static ByteBuf copiedBuffer(char[] array, int offset, int length, Charset charset) { + ObjectUtil.checkNotNull(array, "array"); + if (length == 0) { + return EMPTY_BUFFER; + } + return copiedBuffer(CharBuffer.wrap(array, offset, length), charset); + } + + private static ByteBuf copiedBuffer(CharBuffer buffer, Charset charset) { + return ByteBufUtil.encodeString0(ALLOC, true, buffer, charset, 0); + } + + /** + * Creates a read-only buffer which disallows any modification operations + * on the specified {@code buffer}. The new buffer has the same + * {@code readerIndex} and {@code writerIndex} with the specified + * {@code buffer}. + * + * @deprecated Use {@link ByteBuf#asReadOnly()}. + */ + @Deprecated + public static ByteBuf unmodifiableBuffer(ByteBuf buffer) { + ByteOrder endianness = buffer.order(); + if (endianness == BIG_ENDIAN) { + return new ReadOnlyByteBuf(buffer); + } + + return new ReadOnlyByteBuf(buffer.order(BIG_ENDIAN)).order(LITTLE_ENDIAN); + } + + /** + * Creates a new 4-byte big-endian buffer that holds the specified 32-bit integer. + */ + public static ByteBuf copyInt(int value) { + ByteBuf buf = buffer(4); + buf.writeInt(value); + return buf; + } + + /** + * Create a big-endian buffer that holds a sequence of the specified 32-bit integers. + */ + public static ByteBuf copyInt(int... 
values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 4); + for (int v: values) { + buffer.writeInt(v); + } + return buffer; + } + + /** + * Creates a new 2-byte big-endian buffer that holds the specified 16-bit integer. + */ + public static ByteBuf copyShort(int value) { + ByteBuf buf = buffer(2); + buf.writeShort(value); + return buf; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified 16-bit integers. + */ + public static ByteBuf copyShort(short... values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 2); + for (int v: values) { + buffer.writeShort(v); + } + return buffer; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified 16-bit integers. + */ + public static ByteBuf copyShort(int... values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 2); + for (int v: values) { + buffer.writeShort(v); + } + return buffer; + } + + /** + * Creates a new 3-byte big-endian buffer that holds the specified 24-bit integer. + */ + public static ByteBuf copyMedium(int value) { + ByteBuf buf = buffer(3); + buf.writeMedium(value); + return buf; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified 24-bit integers. + */ + public static ByteBuf copyMedium(int... values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 3); + for (int v: values) { + buffer.writeMedium(v); + } + return buffer; + } + + /** + * Creates a new 8-byte big-endian buffer that holds the specified 64-bit integer. + */ + public static ByteBuf copyLong(long value) { + ByteBuf buf = buffer(8); + buf.writeLong(value); + return buf; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified 64-bit integers. 
+ */ + public static ByteBuf copyLong(long... values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 8); + for (long v: values) { + buffer.writeLong(v); + } + return buffer; + } + + /** + * Creates a new single-byte big-endian buffer that holds the specified boolean value. + */ + public static ByteBuf copyBoolean(boolean value) { + ByteBuf buf = buffer(1); + buf.writeBoolean(value); + return buf; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified boolean values. + */ + public static ByteBuf copyBoolean(boolean... values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length); + for (boolean v: values) { + buffer.writeBoolean(v); + } + return buffer; + } + + /** + * Creates a new 4-byte big-endian buffer that holds the specified 32-bit floating point number. + */ + public static ByteBuf copyFloat(float value) { + ByteBuf buf = buffer(4); + buf.writeFloat(value); + return buf; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified 32-bit floating point numbers. + */ + public static ByteBuf copyFloat(float... values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 4); + for (float v: values) { + buffer.writeFloat(v); + } + return buffer; + } + + /** + * Creates a new 8-byte big-endian buffer that holds the specified 64-bit floating point number. + */ + public static ByteBuf copyDouble(double value) { + ByteBuf buf = buffer(8); + buf.writeDouble(value); + return buf; + } + + /** + * Create a new big-endian buffer that holds a sequence of the specified 64-bit floating point numbers. + */ + public static ByteBuf copyDouble(double... 
values) { + if (values == null || values.length == 0) { + return EMPTY_BUFFER; + } + ByteBuf buffer = buffer(values.length * 8); + for (double v: values) { + buffer.writeDouble(v); + } + return buffer; + } + + /** + * Return a unreleasable view on the given {@link ByteBuf} which will just ignore release and retain calls. + */ + public static ByteBuf unreleasableBuffer(ByteBuf buf) { + return new UnreleasableByteBuf(buf); + } + + /** + * Wrap the given {@link ByteBuf}s in an unmodifiable {@link ByteBuf}. Be aware the returned {@link ByteBuf} will + * not try to slice the given {@link ByteBuf}s to reduce GC-Pressure. + * + * @deprecated Use {@link #wrappedUnmodifiableBuffer(ByteBuf...)}. + */ + @Deprecated + public static ByteBuf unmodifiableBuffer(ByteBuf... buffers) { + return wrappedUnmodifiableBuffer(true, buffers); + } + + /** + * Wrap the given {@link ByteBuf}s in an unmodifiable {@link ByteBuf}. Be aware the returned {@link ByteBuf} will + * not try to slice the given {@link ByteBuf}s to reduce GC-Pressure. + * + * The returned {@link ByteBuf} may wrap the provided array directly, and so should not be subsequently modified. + */ + public static ByteBuf wrappedUnmodifiableBuffer(ByteBuf... buffers) { + return wrappedUnmodifiableBuffer(false, buffers); + } + + private static ByteBuf wrappedUnmodifiableBuffer(boolean copy, ByteBuf... 
buffers) {
        switch (buffers.length) {
        case 0:
            return EMPTY_BUFFER;
        case 1:
            // A single buffer needs no composite wrapper; a read-only view suffices.
            return buffers[0].asReadOnly();
        default:
            if (copy) {
                // Defensive copy so later caller mutation of the array cannot affect the wrapper.
                buffers = Arrays.copyOf(buffers, buffers.length, ByteBuf[].class);
            }
            return new FixedCompositeByteBuf(ALLOC, buffers);
        }
    }

    private Unpooled() {
        // Unused
    }
}
// ===== file: netty-buffer/src/main/java/io/netty/buffer/UnpooledByteBufAllocator.java =====
// Copyright 2012 The Netty Project. Licensed under the Apache License 2.0
// (https://www.apache.org/licenses/LICENSE-2.0).
package io.netty.buffer;

import io.netty.util.internal.LongCounter;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;

import java.nio.ByteBuffer;

/**
 * Simplistic {@link ByteBufAllocator} implementation that does not pool anything.
 *
 * Every allocation goes straight to the heap or to direct memory; the "Instrumented" subclasses
 * below only add used-memory accounting on top of the plain unpooled buffer types.
 */
public final class UnpooledByteBufAllocator extends AbstractByteBufAllocator implements ByteBufAllocatorMetricProvider {

    // Tracks used heap/direct bytes; exposed via metric().
    private final UnpooledByteBufAllocatorMetric metric = new UnpooledByteBufAllocatorMetric();
    private final boolean disableLeakDetector;
    // True when direct buffers are allocated without a Cleaner (requires Unsafe support).
    private final boolean noCleaner;

    /**
     * Default instance which uses leak-detection for direct buffers.
     */
    public static final UnpooledByteBufAllocator DEFAULT =
            new UnpooledByteBufAllocator(PlatformDependent.directBufferPreferred());

    /**
     * Create a new instance which uses leak-detection for direct buffers.
     *
     * @param preferDirect {@code true} if {@link #buffer(int)} should try to allocate a direct buffer rather than
     *                     a heap buffer
     */
    public UnpooledByteBufAllocator(boolean preferDirect) {
        this(preferDirect, false);
    }

    /**
     * Create a new instance
     *
     * @param preferDirect {@code true} if {@link #buffer(int)} should try to allocate a direct buffer rather than
     *                     a heap buffer
     * @param disableLeakDetector {@code true} if the leak-detection should be disabled completely for this
     *                            allocator. This can be useful if the user just want to depend on the GC to handle
     *                            direct buffers when not explicit released.
     */
    public UnpooledByteBufAllocator(boolean preferDirect, boolean disableLeakDetector) {
        this(preferDirect, disableLeakDetector, PlatformDependent.useDirectBufferNoCleaner());
    }

    /**
     * Create a new instance
     *
     * @param preferDirect {@code true} if {@link #buffer(int)} should try to allocate a direct buffer rather than
     *                     a heap buffer
     * @param disableLeakDetector {@code true} if the leak-detection should be disabled completely for this
     *                            allocator. This can be useful if the user just want to depend on the GC to handle
     *                            direct buffers when not explicit released.
     * @param tryNoCleaner {@code true} if we should try to use {@link PlatformDependent#allocateDirectNoCleaner(int)}
     *                     to allocate direct memory.
     */
    public UnpooledByteBufAllocator(boolean preferDirect, boolean disableLeakDetector, boolean tryNoCleaner) {
        super(preferDirect);
        this.disableLeakDetector = disableLeakDetector;
        // No-cleaner allocation is only possible when Unsafe and the required constructor exist.
        noCleaner = tryNoCleaner && PlatformDependent.hasUnsafe()
                && PlatformDependent.hasDirectBufferNoCleanerConstructor();
    }

    @Override
    protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) {
        return PlatformDependent.hasUnsafe() ?
                new InstrumentedUnpooledUnsafeHeapByteBuf(this, initialCapacity, maxCapacity) :
                new InstrumentedUnpooledHeapByteBuf(this, initialCapacity, maxCapacity);
    }

    @Override
    protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) {
        final ByteBuf buf;
        if (PlatformDependent.hasUnsafe()) {
            buf = noCleaner ? new InstrumentedUnpooledUnsafeNoCleanerDirectByteBuf(this, initialCapacity, maxCapacity) :
                    new InstrumentedUnpooledUnsafeDirectByteBuf(this, initialCapacity, maxCapacity);
        } else {
            buf = new InstrumentedUnpooledDirectByteBuf(this, initialCapacity, maxCapacity);
        }
        // Leak-aware wrapping is skipped when the user disabled leak detection for this allocator.
        return disableLeakDetector ? buf : toLeakAwareBuffer(buf);
    }

    @Override
    public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) {
        CompositeByteBuf buf = new CompositeByteBuf(this, false, maxNumComponents);
        return disableLeakDetector ? buf : toLeakAwareBuffer(buf);
    }

    @Override
    public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) {
        CompositeByteBuf buf = new CompositeByteBuf(this, true, maxNumComponents);
        return disableLeakDetector ? buf : toLeakAwareBuffer(buf);
    }

    @Override
    public boolean isDirectBufferPooled() {
        return false;
    }

    @Override
    public ByteBufAllocatorMetric metric() {
        return metric;
    }

    // The four package-private hooks below are called by the instrumented buffer subclasses
    // to keep the used-memory counters in sync with actual (de)allocations.
    void incrementDirect(int amount) {
        metric.directCounter.add(amount);
    }

    void decrementDirect(int amount) {
        metric.directCounter.add(-amount);
    }

    void incrementHeap(int amount) {
        metric.heapCounter.add(amount);
    }

    void decrementHeap(int amount) {
        metric.heapCounter.add(-amount);
    }

    /** Heap buffer (Unsafe variant) that reports array (de)allocation to the allocator metric. */
    private static final class InstrumentedUnpooledUnsafeHeapByteBuf extends UnpooledUnsafeHeapByteBuf {
        InstrumentedUnpooledUnsafeHeapByteBuf(UnpooledByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
            super(alloc, initialCapacity, maxCapacity);
        }

        @Override
        protected byte[] allocateArray(int initialCapacity) {
            byte[] bytes = super.allocateArray(initialCapacity);
            ((UnpooledByteBufAllocator) alloc()).incrementHeap(bytes.length);
            return bytes;
        }

        @Override
        protected void freeArray(byte[] array) {
            // Capture the length before freeing; the reference must not be used afterwards.
            int length = array.length;
            super.freeArray(array);
            ((UnpooledByteBufAllocator) alloc()).decrementHeap(length);
        }
    }

    /** Heap buffer (plain variant) that reports array (de)allocation to the allocator metric. */
    private static final class InstrumentedUnpooledHeapByteBuf extends UnpooledHeapByteBuf {
        InstrumentedUnpooledHeapByteBuf(UnpooledByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
            super(alloc, initialCapacity, maxCapacity);
        }

        @Override
        protected byte[] allocateArray(int initialCapacity) {
            byte[] bytes = super.allocateArray(initialCapacity);
            ((UnpooledByteBufAllocator) alloc()).incrementHeap(bytes.length);
            return bytes;
        }

        @Override
        protected void freeArray(byte[] array) {
            int length = array.length;
            super.freeArray(array);
            ((UnpooledByteBufAllocator) alloc()).decrementHeap(length);
        }
    }

    /** Direct buffer without a Cleaner; also instruments in-place reallocation. */
    private static final class InstrumentedUnpooledUnsafeNoCleanerDirectByteBuf
            extends UnpooledUnsafeNoCleanerDirectByteBuf {
        InstrumentedUnpooledUnsafeNoCleanerDirectByteBuf(
                UnpooledByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
            super(alloc, initialCapacity, maxCapacity);
        }

        @Override
        protected ByteBuffer allocateDirect(int initialCapacity) {
            ByteBuffer buffer = super.allocateDirect(initialCapacity);
            ((UnpooledByteBufAllocator) alloc()).incrementDirect(buffer.capacity());
            return buffer;
        }

        @Override
        ByteBuffer reallocateDirect(ByteBuffer oldBuffer, int initialCapacity) {
            // Only the capacity delta is accounted, since the old buffer is replaced.
            int capacity = oldBuffer.capacity();
            ByteBuffer buffer = super.reallocateDirect(oldBuffer, initialCapacity);
            ((UnpooledByteBufAllocator) alloc()).incrementDirect(buffer.capacity() - capacity);
            return buffer;
        }

        @Override
        protected void freeDirect(ByteBuffer buffer) {
            int capacity = buffer.capacity();
            super.freeDirect(buffer);
            ((UnpooledByteBufAllocator) alloc()).decrementDirect(capacity);
        }
    }

    /** Direct buffer (Unsafe variant) that reports (de)allocation to the allocator metric. */
    private static final class InstrumentedUnpooledUnsafeDirectByteBuf extends UnpooledUnsafeDirectByteBuf {
        InstrumentedUnpooledUnsafeDirectByteBuf(
                UnpooledByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
            super(alloc, initialCapacity, maxCapacity);
        }

        @Override
        protected ByteBuffer allocateDirect(int initialCapacity) {
            ByteBuffer buffer = super.allocateDirect(initialCapacity);
            ((UnpooledByteBufAllocator) alloc()).incrementDirect(buffer.capacity());
            return buffer;
        }

        @Override
        protected void freeDirect(ByteBuffer buffer) {
            int capacity = buffer.capacity();
            super.freeDirect(buffer);
            ((UnpooledByteBufAllocator) alloc()).decrementDirect(capacity);
        }
    }

    /** Direct buffer (NIO-only variant) that reports (de)allocation to the allocator metric. */
    private static final class InstrumentedUnpooledDirectByteBuf extends UnpooledDirectByteBuf {
        InstrumentedUnpooledDirectByteBuf(
                UnpooledByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
            super(alloc, initialCapacity, maxCapacity);
        }

        @Override
        protected ByteBuffer allocateDirect(int initialCapacity) {
            ByteBuffer buffer = super.allocateDirect(initialCapacity);
            ((UnpooledByteBufAllocator) alloc()).incrementDirect(buffer.capacity());
            return buffer;
        }

        @Override
        protected void freeDirect(ByteBuffer buffer) {
            int capacity = buffer.capacity();
            super.freeDirect(buffer);
            ((UnpooledByteBufAllocator) alloc()).decrementDirect(capacity);
        }
    }

    /** Exposes the used heap/direct byte counters kept by the instrumented buffers above. */
    private static final class UnpooledByteBufAllocatorMetric implements ByteBufAllocatorMetric {
        final LongCounter directCounter = PlatformDependent.newLongCounter();
        final LongCounter heapCounter = PlatformDependent.newLongCounter();

        @Override
        public long usedHeapMemory() {
            return heapCounter.value();
        }

        @Override
        public long usedDirectMemory() {
            return directCounter.value();
        }

        @Override
        public String toString() {
            return StringUtil.simpleClassName(this) +
                    "(usedHeapMemory: " + usedHeapMemory() + "; usedDirectMemory: " + usedDirectMemory() + ')';
        }
    }
}
// ===== file: netty-buffer/src/main/java/io/netty/buffer/UnpooledDirectByteBuf.java =====
// Copyright 2012 The Netty Project. Licensed under the Apache License 2.0
// (https://www.apache.org/licenses/LICENSE-2.0).
package io.netty.buffer;

import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;

import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.FileChannel;
import java.nio.channels.GatheringByteChannel;
import java.nio.channels.ScatteringByteChannel;

import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;

/**
 * A NIO {@link ByteBuffer} based buffer. It is recommended to use
 * {@link UnpooledByteBufAllocator#directBuffer(int, int)}, {@link Unpooled#directBuffer(int)} and
 * {@link Unpooled#wrappedBuffer(ByteBuffer)} instead of calling the constructor explicitly.
 */
public class UnpooledDirectByteBuf extends AbstractReferenceCountedByteBuf {

    private final ByteBufAllocator alloc;

    ByteBuffer buffer; // accessed by UnpooledUnsafeNoCleanerDirectByteBuf.reallocateDirect()
    // Cached duplicate of 'buffer' used for internal temporary position/limit manipulation.
    private ByteBuffer tmpNioBuf;
    private int capacity;
    // When true the underlying ByteBuffer was supplied by the caller and must not be freed here.
    private boolean doNotFree;

    /**
     * Creates a new direct buffer.
     *
     * @param initialCapacity the initial capacity of the underlying direct buffer
     * @param maxCapacity     the maximum capacity of the underlying direct buffer
     */
    public UnpooledDirectByteBuf(ByteBufAllocator alloc, int initialCapacity, int maxCapacity) {
        super(maxCapacity);
        ObjectUtil.checkNotNull(alloc, "alloc");
        checkPositiveOrZero(initialCapacity, "initialCapacity");
        checkPositiveOrZero(maxCapacity, "maxCapacity");
        if (initialCapacity > maxCapacity) {
            throw new IllegalArgumentException(String.format(
                    "initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity));
        }

        this.alloc = alloc;
        setByteBuffer(allocateDirect(initialCapacity), false);
    }

    /**
     * Creates a new direct buffer by wrapping the specified initial buffer.
     *
     * @param maxCapacity the maximum capacity of the underlying direct buffer
     */
    protected UnpooledDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer, int maxCapacity) {
        this(alloc, initialBuffer, maxCapacity, false, true);
    }

    // doFree: whether this buffer takes ownership and may free initialBuffer on deallocate.
    // slice:  whether to operate on a slice() of initialBuffer instead of the buffer itself.
    UnpooledDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer,
            int maxCapacity, boolean doFree, boolean slice) {
        super(maxCapacity);
        ObjectUtil.checkNotNull(alloc, "alloc");
        ObjectUtil.checkNotNull(initialBuffer, "initialBuffer");
        if (!initialBuffer.isDirect()) {
            throw new IllegalArgumentException("initialBuffer is not a direct buffer.");
        }
        if (initialBuffer.isReadOnly()) {
            throw new IllegalArgumentException("initialBuffer is a read-only buffer.");
        }

        int initialCapacity = initialBuffer.remaining();
        if (initialCapacity > maxCapacity) {
            throw new IllegalArgumentException(String.format(
                    "initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity));
        }

        this.alloc = alloc;
        doNotFree = !doFree;
        setByteBuffer((slice ? initialBuffer.slice() : initialBuffer).order(ByteOrder.BIG_ENDIAN), false);
        // The wrapped buffer's remaining bytes are considered already written.
        writerIndex(initialCapacity);
    }

    /**
     * Allocate a new direct {@link ByteBuffer} with the given initialCapacity.
     */
    protected ByteBuffer allocateDirect(int initialCapacity) {
        return ByteBuffer.allocateDirect(initialCapacity);
    }

    /**
     * Free a direct {@link ByteBuffer}
     */
    protected void freeDirect(ByteBuffer buffer) {
        PlatformDependent.freeDirectBuffer(buffer);
    }

    // Installs a new backing buffer, optionally freeing the previous one (unless it was
    // caller-supplied, in which case doNotFree is consumed once and the buffer is kept alive).
    void setByteBuffer(ByteBuffer buffer, boolean tryFree) {
        if (tryFree) {
            ByteBuffer oldBuffer = this.buffer;
            if (oldBuffer != null) {
                if (doNotFree) {
                    doNotFree = false;
                } else {
                    freeDirect(oldBuffer);
                }
            }
        }

        this.buffer = buffer;
        tmpNioBuf = null; // invalidate cached duplicate of the old buffer
        capacity = buffer.remaining();
    }

    @Override
    public boolean isDirect() {
        return true;
    }

    @Override
    public int capacity() {
        return capacity;
    }

    @Override
    public ByteBuf capacity(int newCapacity) {
        checkNewCapacity(newCapacity);
        int oldCapacity = capacity;
        if (newCapacity == oldCapacity) {
            return this;
        }
        int bytesToCopy;
        if (newCapacity > oldCapacity) {
            bytesToCopy = oldCapacity;
        } else {
            // Shrinking: clamp reader/writer indices before copying the surviving prefix.
            trimIndicesToCapacity(newCapacity);
            bytesToCopy = newCapacity;
        }
        ByteBuffer oldBuffer = buffer;
        ByteBuffer newBuffer = allocateDirect(newCapacity);
        oldBuffer.position(0).limit(bytesToCopy);
        newBuffer.position(0).limit(bytesToCopy);
        newBuffer.put(oldBuffer).clear();
        setByteBuffer(newBuffer, true);
        return this;
    }

    @Override
    public ByteBufAllocator alloc() {
        return alloc;
    }

    @Override
    public ByteOrder order() {
        return ByteOrder.BIG_ENDIAN;
    }

    @Override
    public boolean hasArray() {
        return false;
    }

    @Override
    public byte[] array() {
        throw new UnsupportedOperationException("direct buffer");
    }

    @Override
    public int arrayOffset() {
        throw new UnsupportedOperationException("direct buffer");
    }

    @Override
    public boolean hasMemoryAddress() {
        return false;
    }

    @Override
    public long memoryAddress() {
        throw new UnsupportedOperationException();
    }

    // Public accessors check accessibility; the _-prefixed variants are the unchecked
    // primitives used internally and by subclasses.
    @Override
    public byte getByte(int index) {
        ensureAccessible();
        return _getByte(index);
    }

    @Override
    protected byte _getByte(int index) {
        return buffer.get(index);
    }

    @Override
    public short getShort(int index) {
        ensureAccessible();
        return _getShort(index);
    }

    @Override
    protected short _getShort(int index) {
        return buffer.getShort(index);
    }

    @Override
    protected short _getShortLE(int index) {
        return ByteBufUtil.swapShort(buffer.getShort(index));
    }

    @Override
    public int getUnsignedMedium(int index) {
        ensureAccessible();
        return _getUnsignedMedium(index);
    }

    @Override
    protected int _getUnsignedMedium(int index) {
        return (getByte(index) & 0xff) << 16 |
                (getByte(index + 1) & 0xff) << 8 |
                getByte(index + 2) & 0xff;
    }

    @Override
    protected int _getUnsignedMediumLE(int index) {
        return getByte(index) & 0xff |
                (getByte(index + 1) & 0xff) << 8 |
                (getByte(index + 2) & 0xff) << 16;
    }

    @Override
    public int getInt(int index) {
        ensureAccessible();
        return _getInt(index);
    }

    @Override
    protected int _getInt(int index) {
        return buffer.getInt(index);
    }

    @Override
    protected int _getIntLE(int index) {
        return ByteBufUtil.swapInt(buffer.getInt(index));
    }

    @Override
    public long getLong(int index) {
        ensureAccessible();
        return _getLong(index);
    }

    @Override
    protected long _getLong(int index) {
        return buffer.getLong(index);
    }

    @Override
    protected long _getLongLE(int index) {
        return ByteBufUtil.swapLong(buffer.getLong(index));
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) {
        checkDstIndex(index, length, dstIndex, dst.capacity());
        if (dst.hasArray()) {
            getBytes(index, dst.array(), dst.arrayOffset() + dstIndex, length);
        } else if (dst.nioBufferCount() > 0) {
            for (ByteBuffer bb: dst.nioBuffers(dstIndex, length)) {
                int bbLen = bb.remaining();
                getBytes(index, bb);
                index += bbLen;
            }
        } else {
            // Fall back to letting the destination pull from this buffer.
            dst.setBytes(dstIndex, this, index, length);
        }
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) {
        getBytes(index, dst, dstIndex, length, false);
        return this;
    }

    // internal=true reuses the cached duplicate (single-caller temp view);
    // internal=false duplicates so concurrent readers do not share position/limit.
    void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) {
        checkDstIndex(index, length, dstIndex, dst.length);

        ByteBuffer tmpBuf;
        if (internal) {
            tmpBuf = internalNioBuffer();
        } else {
            tmpBuf = buffer.duplicate();
        }
        tmpBuf.clear().position(index).limit(index + length);
        tmpBuf.get(dst, dstIndex, length);
    }

    @Override
    public ByteBuf readBytes(byte[] dst, int dstIndex, int length) {
        checkReadableBytes(length);
        getBytes(readerIndex, dst, dstIndex, length, true);
        readerIndex += length;
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, ByteBuffer dst) {
        getBytes(index, dst, false);
        return this;
    }

    void getBytes(int index, ByteBuffer dst, boolean internal) {
        checkIndex(index, dst.remaining());

        ByteBuffer tmpBuf;
        if (internal) {
            tmpBuf = internalNioBuffer();
        } else {
            tmpBuf = buffer.duplicate();
        }
        tmpBuf.clear().position(index).limit(index + dst.remaining());
        dst.put(tmpBuf);
    }

    @Override
    public ByteBuf readBytes(ByteBuffer dst) {
        int length = dst.remaining();
        checkReadableBytes(length);
        getBytes(readerIndex, dst, true);
        readerIndex += length;
        return this;
    }

    @Override
    public ByteBuf setByte(int index, int value) {
        ensureAccessible();
        _setByte(index, value);
        return this;
    }

    @Override
    protected void _setByte(int index, int value) {
        buffer.put(index, (byte) value);
    }

    @Override
    public ByteBuf setShort(int index, int value) {
        ensureAccessible();
        _setShort(index, value);
        return this;
    }

    @Override
    protected void _setShort(int index, int value) {
        buffer.putShort(index, (short) value);
    }

    @Override
    protected void _setShortLE(int index, int value) {
        buffer.putShort(index, ByteBufUtil.swapShort((short) value));
    }

    @Override
    public ByteBuf setMedium(int index, int value) {
        ensureAccessible();
        _setMedium(index, value);
        return this;
    }

    @Override
    protected void _setMedium(int index, int value) {
        setByte(index, (byte) (value >>> 16));
        setByte(index + 1, (byte) (value >>> 8));
        setByte(index + 2, (byte) value);
    }

    @Override
    protected void _setMediumLE(int index, int value) {
        setByte(index, (byte) value);
        setByte(index + 1, (byte) (value >>> 8));
        setByte(index + 2, (byte) (value >>> 16));
    }

    @Override
    public ByteBuf setInt(int index, int value) {
        ensureAccessible();
        _setInt(index, value);
        return this;
    }

    @Override
    protected void _setInt(int index, int value) {
        buffer.putInt(index, value);
    }

    @Override
    protected void _setIntLE(int index, int value) {
        buffer.putInt(index, ByteBufUtil.swapInt(value));
    }

    @Override
    public ByteBuf setLong(int index, long value) {
        ensureAccessible();
        _setLong(index, value);
        return this;
    }

    @Override
    protected void _setLong(int index, long value) {
        buffer.putLong(index, value);
    }

    @Override
    protected void _setLongLE(int index, long value) {
        buffer.putLong(index, ByteBufUtil.swapLong(value));
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) {
        checkSrcIndex(index, length, srcIndex, src.capacity());
        if (src.nioBufferCount() > 0) {
            for (ByteBuffer bb: src.nioBuffers(srcIndex, length)) {
                int bbLen = bb.remaining();
                setBytes(index, bb);
                index += bbLen;
            }
        } else {
            src.getBytes(srcIndex, this, index, length);
        }
        return this;
    }

    @Override
    public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) {
        checkSrcIndex(index, length, srcIndex, src.length);
        ByteBuffer tmpBuf = internalNioBuffer();
        tmpBuf.clear().position(index).limit(index + length);
        tmpBuf.put(src, srcIndex, length);
        return this;
    }

    @Override
    public ByteBuf setBytes(int index, ByteBuffer src) {
        ensureAccessible();
        ByteBuffer tmpBuf = internalNioBuffer();
        if (src == tmpBuf) {
            // Copying a buffer into itself: work on a duplicate to avoid clobbering positions.
            src = src.duplicate();
        }

        tmpBuf.clear().position(index).limit(index + src.remaining());
        tmpBuf.put(src);
        return this;
    }

    @Override
    public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException {
        getBytes(index, out, length, false);
        return this;
    }

    void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException {
        ensureAccessible();
        if (length == 0) {
            return;
        }
        ByteBufUtil.readBytes(alloc(), internal ? internalNioBuffer() : buffer.duplicate(), index, length, out);
    }

    @Override
    public ByteBuf readBytes(OutputStream out, int length) throws IOException {
        checkReadableBytes(length);
        getBytes(readerIndex, out, length, true);
        readerIndex += length;
        return this;
    }

    @Override
    public int getBytes(int index, GatheringByteChannel out, int length) throws IOException {
        return getBytes(index, out, length, false);
    }

    private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException {
        ensureAccessible();
        if (length == 0) {
            return 0;
        }

        ByteBuffer tmpBuf;
        if (internal) {
            tmpBuf = internalNioBuffer();
        } else {
            tmpBuf = buffer.duplicate();
        }
        tmpBuf.clear().position(index).limit(index + length);
        return out.write(tmpBuf);
    }

    @Override
    public int getBytes(int index, FileChannel out, long position, int length) throws IOException {
        return getBytes(index, out, position, length, false);
    }

    private int getBytes(int index, FileChannel out, long position, int length, boolean internal) throws IOException {
        ensureAccessible();
        if (length == 0) {
            return 0;
        }

        ByteBuffer tmpBuf = internal ? internalNioBuffer() : buffer.duplicate();
        tmpBuf.clear().position(index).limit(index + length);
        return out.write(tmpBuf, position);
    }

    @Override
    public int readBytes(GatheringByteChannel out, int length) throws IOException {
        checkReadableBytes(length);
        // Advance by the number of bytes the channel actually accepted (may be < length).
        int readBytes = getBytes(readerIndex, out, length, true);
        readerIndex += readBytes;
        return readBytes;
    }

    @Override
    public int readBytes(FileChannel out, long position, int length) throws IOException {
        checkReadableBytes(length);
        int readBytes = getBytes(readerIndex, out, position, length, true);
        readerIndex += readBytes;
        return readBytes;
    }

    @Override
    public int setBytes(int index, InputStream in, int length) throws IOException {
        ensureAccessible();
        if (buffer.hasArray()) {
            // NOTE(review): direct buffers normally have no accessible array; this branch
            // presumably covers exotic array-backed direct buffers — confirm if ever hit.
            return in.read(buffer.array(), buffer.arrayOffset() + index, length);
        } else {
            // Stage through a thread-local heap array, then copy into the direct buffer.
            byte[] tmp = ByteBufUtil.threadLocalTempArray(length);
            int readBytes = in.read(tmp, 0, length);
            if (readBytes <= 0) {
                return readBytes;
            }
            ByteBuffer tmpBuf = internalNioBuffer();
            tmpBuf.clear().position(index);
            tmpBuf.put(tmp, 0, readBytes);
            return readBytes;
        }
    }

    @Override
    public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException {
        ensureAccessible();
        ByteBuffer tmpBuf = internalNioBuffer();
        tmpBuf.clear().position(index).limit(index + length);
        try {
            return in.read(tmpBuf);
        } catch (ClosedChannelException ignored) {
            // A closed channel is reported as end-of-stream rather than propagated.
            return -1;
        }
    }

    @Override
    public int setBytes(int index, FileChannel in, long position, int length) throws IOException {
        ensureAccessible();
        ByteBuffer tmpBuf = internalNioBuffer();
        tmpBuf.clear().position(index).limit(index + length);
        try {
            return in.read(tmpBuf, position);
        } catch (ClosedChannelException ignored) {
            return -1;
        }
    }

    @Override
    public int nioBufferCount() {
        return 1;
    }

    @Override
    public ByteBuffer[] nioBuffers(int index, int length) {
        return new ByteBuffer[] { nioBuffer(index, length) };
    }

    @Override
    public final boolean isContiguous() {
        return true;
    }

    @Override
    public ByteBuf copy(int index, int length) {
        ensureAccessible();
        ByteBuffer src;
        try {
            src = (ByteBuffer) buffer.duplicate().clear().position(index).limit(index + length);
        } catch (IllegalArgumentException ignored) {
            // position/limit past the buffer end surfaces as IllegalArgumentException from NIO.
            throw new IndexOutOfBoundsException("Too many bytes to read - Need " + (index + length));
        }

        return alloc().directBuffer(length, maxCapacity()).writeBytes(src);
    }

    @Override
    public ByteBuffer internalNioBuffer(int index, int length) {
        checkIndex(index, length);
        return (ByteBuffer) internalNioBuffer().clear().position(index).limit(index + length);
    }

    // Lazily creates (and caches) a duplicate of the backing buffer for internal use.
    private ByteBuffer internalNioBuffer() {
        ByteBuffer tmpNioBuf = this.tmpNioBuf;
        if (tmpNioBuf == null) {
            this.tmpNioBuf = tmpNioBuf = buffer.duplicate();
        }
        return tmpNioBuf;
    }

    @Override
    public ByteBuffer nioBuffer(int index, int length) {
        checkIndex(index, length);
        // slice() yields an independent view so the caller cannot disturb this buffer's state.
        return ((ByteBuffer) buffer.duplicate().position(index).limit(index + length)).slice();
    }

    @Override
    protected void deallocate() {
        ByteBuffer buffer = this.buffer;
        if (buffer == null) {
            return;
        }

        this.buffer = null;

        if (!doNotFree) {
            freeDirect(buffer);
        }
    }

    @Override
    public ByteBuf unwrap() {
        return null;
    }
}
// ===== file: netty-buffer/src/main/java/io/netty/buffer/UnpooledDuplicatedByteBuf.java =====
// Copyright 2015 The Netty Project. Licensed under the Apache License 2.0
// (https://www.apache.org/licenses/LICENSE-2.0).
package io.netty.buffer;

/**
 * {@link DuplicatedByteBuf} implementation that can do optimizations because it knows the duplicated buffer
 * is of type {@link AbstractByteBuf}.
 *
 * Every _-prefixed accessor below bypasses the public checked API and delegates straight to the
 * wrapped {@link AbstractByteBuf}'s unchecked primitive, avoiding per-call bounds re-validation.
 */
class UnpooledDuplicatedByteBuf extends DuplicatedByteBuf {
    UnpooledDuplicatedByteBuf(AbstractByteBuf buffer) {
        super(buffer);
    }

    @Override
    public AbstractByteBuf unwrap() {
        // Safe: the constructor only accepts AbstractByteBuf.
        return (AbstractByteBuf) super.unwrap();
    }

    @Override
    protected byte _getByte(int index) {
        return unwrap()._getByte(index);
    }

    @Override
    protected short _getShort(int index) {
        return unwrap()._getShort(index);
    }

    @Override
    protected short _getShortLE(int index) {
        return unwrap()._getShortLE(index);
    }

    @Override
    protected int _getUnsignedMedium(int index) {
        return unwrap()._getUnsignedMedium(index);
    }

    @Override
    protected int _getUnsignedMediumLE(int index) {
        return unwrap()._getUnsignedMediumLE(index);
    }

    @Override
    protected int _getInt(int index) {
        return unwrap()._getInt(index);
    }

    @Override
    protected int _getIntLE(int index) {
        return unwrap()._getIntLE(index);
    }

    @Override
    protected long _getLong(int index) {
        return unwrap()._getLong(index);
    }

    @Override
    protected long _getLongLE(int index) {
        return unwrap()._getLongLE(index);
    }

    @Override
    protected void _setByte(int index, int value) {
        unwrap()._setByte(index, value);
    }

    @Override
    protected void _setShort(int index, int value) {
        unwrap()._setShort(index, value);
    }

    @Override
    protected void _setShortLE(int index, int value) {
        unwrap()._setShortLE(index, value);
    }

    @Override
    protected void _setMedium(int index, int value) {
        unwrap()._setMedium(index, value);
    }

    @Override
    protected void _setMediumLE(int index, int value) {
        unwrap()._setMediumLE(index, value);
    }

    @Override
    protected void _setInt(int index, int value) {
        unwrap()._setInt(index, value);
    }

    @Override
    protected void _setIntLE(int index, int value) {
        unwrap()._setIntLE(index, value);
    }

    @Override
    protected void _setLong(int index, long value) {
        unwrap()._setLong(index, value);
    }

    @Override
    protected void _setLongLE(int index, long value) {
        unwrap()._setLongLE(index, value);
    }
}
+ */ +package io.netty.buffer; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Big endian Java heap buffer implementation. It is recommended to use + * {@link UnpooledByteBufAllocator#heapBuffer(int, int)}, {@link Unpooled#buffer(int)} and + * {@link Unpooled#wrappedBuffer(byte[])} instead of calling the constructor explicitly. + */ +public class UnpooledHeapByteBuf extends AbstractReferenceCountedByteBuf { + + private final ByteBufAllocator alloc; + byte[] array; + private ByteBuffer tmpNioBuf; + + /** + * Creates a new heap buffer with a newly allocated byte array. + * + * @param initialCapacity the initial capacity of the underlying byte array + * @param maxCapacity the max capacity of the underlying byte array + */ + public UnpooledHeapByteBuf(ByteBufAllocator alloc, int initialCapacity, int maxCapacity) { + super(maxCapacity); + + if (initialCapacity > maxCapacity) { + throw new IllegalArgumentException(String.format( + "initialCapacity(%d) > maxCapacity(%d)", initialCapacity, maxCapacity)); + } + + this.alloc = checkNotNull(alloc, "alloc"); + setArray(allocateArray(initialCapacity)); + setIndex(0, 0); + } + + /** + * Creates a new heap buffer with an existing byte array. 
+ * + * @param initialArray the initial underlying byte array + * @param maxCapacity the max capacity of the underlying byte array + */ + protected UnpooledHeapByteBuf(ByteBufAllocator alloc, byte[] initialArray, int maxCapacity) { + super(maxCapacity); + + checkNotNull(alloc, "alloc"); + checkNotNull(initialArray, "initialArray"); + if (initialArray.length > maxCapacity) { + throw new IllegalArgumentException(String.format( + "initialCapacity(%d) > maxCapacity(%d)", initialArray.length, maxCapacity)); + } + + this.alloc = alloc; + setArray(initialArray); + setIndex(0, initialArray.length); + } + + protected byte[] allocateArray(int initialCapacity) { + return new byte[initialCapacity]; + } + + protected void freeArray(byte[] array) { + // NOOP + } + + private void setArray(byte[] initialArray) { + array = initialArray; + tmpNioBuf = null; + } + + @Override + public ByteBufAllocator alloc() { + return alloc; + } + + @Override + public ByteOrder order() { + return ByteOrder.BIG_ENDIAN; + } + + @Override + public boolean isDirect() { + return false; + } + + @Override + public int capacity() { + return array.length; + } + + @Override + public ByteBuf capacity(int newCapacity) { + checkNewCapacity(newCapacity); + byte[] oldArray = array; + int oldCapacity = oldArray.length; + if (newCapacity == oldCapacity) { + return this; + } + + int bytesToCopy; + if (newCapacity > oldCapacity) { + bytesToCopy = oldCapacity; + } else { + trimIndicesToCapacity(newCapacity); + bytesToCopy = newCapacity; + } + byte[] newArray = allocateArray(newCapacity); + System.arraycopy(oldArray, 0, newArray, 0, bytesToCopy); + setArray(newArray); + freeArray(oldArray); + return this; + } + + @Override + public boolean hasArray() { + return true; + } + + @Override + public byte[] array() { + ensureAccessible(); + return array; + } + + @Override + public int arrayOffset() { + return 0; + } + + @Override + public boolean hasMemoryAddress() { + return false; + } + + @Override + public long 
memoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.capacity()); + if (dst.hasMemoryAddress()) { + PlatformDependent.copyMemory(array, index, dst.memoryAddress() + dstIndex, length); + } else if (dst.hasArray()) { + getBytes(index, dst.array(), dst.arrayOffset() + dstIndex, length); + } else { + dst.setBytes(dstIndex, array, index, length); + } + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkDstIndex(index, length, dstIndex, dst.length); + System.arraycopy(array, index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + ensureAccessible(); + dst.put(array, index, dst.remaining()); + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + ensureAccessible(); + out.write(array, index, length); + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + ensureAccessible(); + return getBytes(index, out, length, false); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + ensureAccessible(); + return getBytes(index, out, position, length, false); + } + + private int getBytes(int index, GatheringByteChannel out, int length, boolean internal) throws IOException { + ensureAccessible(); + ByteBuffer tmpBuf; + if (internal) { + tmpBuf = internalNioBuffer(); + } else { + tmpBuf = ByteBuffer.wrap(array); + } + return out.write(tmpBuf.clear().position(index).limit(index + length)); + } + + private int getBytes(int index, FileChannel out, long position, int length, boolean internal) throws IOException { + ensureAccessible(); + ByteBuffer tmpBuf = internal ? 
internalNioBuffer() : ByteBuffer.wrap(array); + return out.write(tmpBuf.clear().position(index).limit(index + length), position); + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, length, true); + readerIndex += readBytes; + return readBytes; + } + + @Override + public int readBytes(FileChannel out, long position, int length) throws IOException { + checkReadableBytes(length); + int readBytes = getBytes(readerIndex, out, position, length, true); + readerIndex += readBytes; + return readBytes; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + checkSrcIndex(index, length, srcIndex, src.capacity()); + if (src.hasMemoryAddress()) { + PlatformDependent.copyMemory(src.memoryAddress() + srcIndex, array, index, length); + } else if (src.hasArray()) { + setBytes(index, src.array(), src.arrayOffset() + srcIndex, length); + } else { + src.getBytes(srcIndex, array, index, length); + } + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + checkSrcIndex(index, length, srcIndex, src.length); + System.arraycopy(src, srcIndex, array, index, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + ensureAccessible(); + src.get(array, index, src.remaining()); + return this; + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + ensureAccessible(); + return in.read(array, index, length); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + ensureAccessible(); + try { + return in.read(internalNioBuffer().clear().position(index).limit(index + length)); + } catch (ClosedChannelException ignored) { + return -1; + } + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws 
IOException { + ensureAccessible(); + try { + return in.read(internalNioBuffer().clear().position(index).limit(index + length), position); + } catch (ClosedChannelException ignored) { + return -1; + } + } + + @Override + public int nioBufferCount() { + return 1; + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + ensureAccessible(); + return ByteBuffer.wrap(array, index, length).slice(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + return new ByteBuffer[] { nioBuffer(index, length) }; + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + checkIndex(index, length); + return internalNioBuffer().clear().position(index).limit(index + length); + } + + @Override + public final boolean isContiguous() { + return true; + } + + @Override + public byte getByte(int index) { + ensureAccessible(); + return _getByte(index); + } + + @Override + protected byte _getByte(int index) { + return HeapByteBufUtil.getByte(array, index); + } + + @Override + public short getShort(int index) { + ensureAccessible(); + return _getShort(index); + } + + @Override + protected short _getShort(int index) { + return HeapByteBufUtil.getShort(array, index); + } + + @Override + public short getShortLE(int index) { + ensureAccessible(); + return _getShortLE(index); + } + + @Override + protected short _getShortLE(int index) { + return HeapByteBufUtil.getShortLE(array, index); + } + + @Override + public int getUnsignedMedium(int index) { + ensureAccessible(); + return _getUnsignedMedium(index); + } + + @Override + protected int _getUnsignedMedium(int index) { + return HeapByteBufUtil.getUnsignedMedium(array, index); + } + + @Override + public int getUnsignedMediumLE(int index) { + ensureAccessible(); + return _getUnsignedMediumLE(index); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return HeapByteBufUtil.getUnsignedMediumLE(array, index); + } + + @Override + public int getInt(int index) { + 
ensureAccessible(); + return _getInt(index); + } + + @Override + protected int _getInt(int index) { + return HeapByteBufUtil.getInt(array, index); + } + + @Override + public int getIntLE(int index) { + ensureAccessible(); + return _getIntLE(index); + } + + @Override + protected int _getIntLE(int index) { + return HeapByteBufUtil.getIntLE(array, index); + } + + @Override + public long getLong(int index) { + ensureAccessible(); + return _getLong(index); + } + + @Override + protected long _getLong(int index) { + return HeapByteBufUtil.getLong(array, index); + } + + @Override + public long getLongLE(int index) { + ensureAccessible(); + return _getLongLE(index); + } + + @Override + protected long _getLongLE(int index) { + return HeapByteBufUtil.getLongLE(array, index); + } + + @Override + public ByteBuf setByte(int index, int value) { + ensureAccessible(); + _setByte(index, value); + return this; + } + + @Override + protected void _setByte(int index, int value) { + HeapByteBufUtil.setByte(array, index, value); + } + + @Override + public ByteBuf setShort(int index, int value) { + ensureAccessible(); + _setShort(index, value); + return this; + } + + @Override + protected void _setShort(int index, int value) { + HeapByteBufUtil.setShort(array, index, value); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + ensureAccessible(); + _setShortLE(index, value); + return this; + } + + @Override + protected void _setShortLE(int index, int value) { + HeapByteBufUtil.setShortLE(array, index, value); + } + + @Override + public ByteBuf setMedium(int index, int value) { + ensureAccessible(); + _setMedium(index, value); + return this; + } + + @Override + protected void _setMedium(int index, int value) { + HeapByteBufUtil.setMedium(array, index, value); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + ensureAccessible(); + _setMediumLE(index, value); + return this; + } + + @Override + protected void _setMediumLE(int index, int value) { + 
HeapByteBufUtil.setMediumLE(array, index, value); + } + + @Override + public ByteBuf setInt(int index, int value) { + ensureAccessible(); + _setInt(index, value); + return this; + } + + @Override + protected void _setInt(int index, int value) { + HeapByteBufUtil.setInt(array, index, value); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + ensureAccessible(); + _setIntLE(index, value); + return this; + } + + @Override + protected void _setIntLE(int index, int value) { + HeapByteBufUtil.setIntLE(array, index, value); + } + + @Override + public ByteBuf setLong(int index, long value) { + ensureAccessible(); + _setLong(index, value); + return this; + } + + @Override + protected void _setLong(int index, long value) { + HeapByteBufUtil.setLong(array, index, value); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + ensureAccessible(); + _setLongLE(index, value); + return this; + } + + @Override + protected void _setLongLE(int index, long value) { + HeapByteBufUtil.setLongLE(array, index, value); + } + + @Override + public ByteBuf copy(int index, int length) { + checkIndex(index, length); + return alloc().heapBuffer(length, maxCapacity()).writeBytes(array, index, length); + } + + private ByteBuffer internalNioBuffer() { + ByteBuffer tmpNioBuf = this.tmpNioBuf; + if (tmpNioBuf == null) { + this.tmpNioBuf = tmpNioBuf = ByteBuffer.wrap(array); + } + return tmpNioBuf; + } + + @Override + protected void deallocate() { + freeArray(array); + array = EmptyArrays.EMPTY_BYTES; + } + + @Override + public ByteBuf unwrap() { + return null; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/UnpooledSlicedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/UnpooledSlicedByteBuf.java new file mode 100644 index 0000000..3c5a765 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/UnpooledSlicedByteBuf.java @@ -0,0 +1,126 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under 
the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +/** + * A special {@link AbstractUnpooledSlicedByteBuf} that can make optimizations because it knows the sliced buffer is of + * type {@link AbstractByteBuf}. + */ +class UnpooledSlicedByteBuf extends AbstractUnpooledSlicedByteBuf { + UnpooledSlicedByteBuf(AbstractByteBuf buffer, int index, int length) { + super(buffer, index, length); + } + + @Override + public int capacity() { + return maxCapacity(); + } + + @Override + public AbstractByteBuf unwrap() { + return (AbstractByteBuf) super.unwrap(); + } + + @Override + protected byte _getByte(int index) { + return unwrap()._getByte(idx(index)); + } + + @Override + protected short _getShort(int index) { + return unwrap()._getShort(idx(index)); + } + + @Override + protected short _getShortLE(int index) { + return unwrap()._getShortLE(idx(index)); + } + + @Override + protected int _getUnsignedMedium(int index) { + return unwrap()._getUnsignedMedium(idx(index)); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return unwrap()._getUnsignedMediumLE(idx(index)); + } + + @Override + protected int _getInt(int index) { + return unwrap()._getInt(idx(index)); + } + + @Override + protected int _getIntLE(int index) { + return unwrap()._getIntLE(idx(index)); + } + + @Override + protected long _getLong(int index) { + return unwrap()._getLong(idx(index)); + } + + @Override + protected long _getLongLE(int index) { + return 
unwrap()._getLongLE(idx(index)); + } + + @Override + protected void _setByte(int index, int value) { + unwrap()._setByte(idx(index), value); + } + + @Override + protected void _setShort(int index, int value) { + unwrap()._setShort(idx(index), value); + } + + @Override + protected void _setShortLE(int index, int value) { + unwrap()._setShortLE(idx(index), value); + } + + @Override + protected void _setMedium(int index, int value) { + unwrap()._setMedium(idx(index), value); + } + + @Override + protected void _setMediumLE(int index, int value) { + unwrap()._setMediumLE(idx(index), value); + } + + @Override + protected void _setInt(int index, int value) { + unwrap()._setInt(idx(index), value); + } + + @Override + protected void _setIntLE(int index, int value) { + unwrap()._setIntLE(idx(index), value); + } + + @Override + protected void _setLong(int index, long value) { + unwrap()._setLong(idx(index), value); + } + + @Override + protected void _setLongLE(int index, long value) { + unwrap()._setLongLE(idx(index), value); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java new file mode 100644 index 0000000..6d1e326 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeDirectByteBuf.java @@ -0,0 +1,315 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; + +/** + * A NIO {@link ByteBuffer} based buffer. It is recommended to use + * {@link UnpooledByteBufAllocator#directBuffer(int, int)}, {@link Unpooled#directBuffer(int)} and + * {@link Unpooled#wrappedBuffer(ByteBuffer)} instead of calling the constructor explicitly. + */ +public class UnpooledUnsafeDirectByteBuf extends UnpooledDirectByteBuf { + + long memoryAddress; + + /** + * Creates a new direct buffer. + * + * @param initialCapacity the initial capacity of the underlying direct buffer + * @param maxCapacity the maximum capacity of the underlying direct buffer + */ + public UnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, int initialCapacity, int maxCapacity) { + super(alloc, initialCapacity, maxCapacity); + } + + /** + * Creates a new direct buffer by wrapping the specified initial buffer. + * + * @param maxCapacity the maximum capacity of the underlying direct buffer + */ + protected UnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer, int maxCapacity) { + // We never try to free the buffer if it was provided by the end-user as we don't know if this is a duplicate or + // a slice. This is done to prevent an IllegalArgumentException when using Java9 as Unsafe.invokeCleaner(...) + // will check if the given buffer is either a duplicate or slice and in this case throw an + // IllegalArgumentException. + // + // See https://hg.openjdk.java.net/jdk9/hs-demo/jdk/file/0d2ab72ba600/src/jdk.unsupported/share/classes/ + // sun/misc/Unsafe.java#l1250 + // + // We also call slice() explicitly here to preserve behaviour with previous netty releases.
+ super(alloc, initialBuffer, maxCapacity, /* doFree = */ false, /* slice = */ true); + } + + UnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, ByteBuffer initialBuffer, int maxCapacity, boolean doFree) { + super(alloc, initialBuffer, maxCapacity, doFree, false); + } + + @Override + final void setByteBuffer(ByteBuffer buffer, boolean tryFree) { + super.setByteBuffer(buffer, tryFree); + memoryAddress = PlatformDependent.directBufferAddress(buffer); + } + + @Override + public boolean hasMemoryAddress() { + return true; + } + + @Override + public long memoryAddress() { + ensureAccessible(); + return memoryAddress; + } + + @Override + public byte getByte(int index) { + checkIndex(index); + return _getByte(index); + } + + @Override + protected byte _getByte(int index) { + return UnsafeByteBufUtil.getByte(addr(index)); + } + + @Override + public short getShort(int index) { + checkIndex(index, 2); + return _getShort(index); + } + + @Override + protected short _getShort(int index) { + return UnsafeByteBufUtil.getShort(addr(index)); + } + + @Override + protected short _getShortLE(int index) { + return UnsafeByteBufUtil.getShortLE(addr(index)); + } + + @Override + public int getUnsignedMedium(int index) { + checkIndex(index, 3); + return _getUnsignedMedium(index); + } + + @Override + protected int _getUnsignedMedium(int index) { + return UnsafeByteBufUtil.getUnsignedMedium(addr(index)); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return UnsafeByteBufUtil.getUnsignedMediumLE(addr(index)); + } + + @Override + public int getInt(int index) { + checkIndex(index, 4); + return _getInt(index); + } + + @Override + protected int _getInt(int index) { + return UnsafeByteBufUtil.getInt(addr(index)); + } + + @Override + protected int _getIntLE(int index) { + return UnsafeByteBufUtil.getIntLE(addr(index)); + } + + @Override + public long getLong(int index) { + checkIndex(index, 8); + return _getLong(index); + } + + @Override + protected long _getLong(int index) 
{ + return UnsafeByteBufUtil.getLong(addr(index)); + } + + @Override + protected long _getLongLE(int index) { + return UnsafeByteBufUtil.getLongLE(addr(index)); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + UnsafeByteBufUtil.getBytes(this, addr(index), index, dst, dstIndex, length); + return this; + } + + @Override + void getBytes(int index, byte[] dst, int dstIndex, int length, boolean internal) { + UnsafeByteBufUtil.getBytes(this, addr(index), index, dst, dstIndex, length); + } + + @Override + void getBytes(int index, ByteBuffer dst, boolean internal) { + UnsafeByteBufUtil.getBytes(this, addr(index), index, dst); + } + + @Override + public ByteBuf setByte(int index, int value) { + checkIndex(index); + _setByte(index, value); + return this; + } + + @Override + protected void _setByte(int index, int value) { + UnsafeByteBufUtil.setByte(addr(index), value); + } + + @Override + public ByteBuf setShort(int index, int value) { + checkIndex(index, 2); + _setShort(index, value); + return this; + } + + @Override + protected void _setShort(int index, int value) { + UnsafeByteBufUtil.setShort(addr(index), value); + } + + @Override + protected void _setShortLE(int index, int value) { + UnsafeByteBufUtil.setShortLE(addr(index), value); + } + + @Override + public ByteBuf setMedium(int index, int value) { + checkIndex(index, 3); + _setMedium(index, value); + return this; + } + + @Override + protected void _setMedium(int index, int value) { + UnsafeByteBufUtil.setMedium(addr(index), value); + } + + @Override + protected void _setMediumLE(int index, int value) { + UnsafeByteBufUtil.setMediumLE(addr(index), value); + } + + @Override + public ByteBuf setInt(int index, int value) { + checkIndex(index, 4); + _setInt(index, value); + return this; + } + + @Override + protected void _setInt(int index, int value) { + UnsafeByteBufUtil.setInt(addr(index), value); + } + + @Override + protected void _setIntLE(int index, int value) { + 
UnsafeByteBufUtil.setIntLE(addr(index), value); + } + + @Override + public ByteBuf setLong(int index, long value) { + checkIndex(index, 8); + _setLong(index, value); + return this; + } + + @Override + protected void _setLong(int index, long value) { + UnsafeByteBufUtil.setLong(addr(index), value); + } + + @Override + protected void _setLongLE(int index, long value) { + UnsafeByteBufUtil.setLongLE(addr(index), value); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + UnsafeByteBufUtil.setBytes(this, addr(index), index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + UnsafeByteBufUtil.setBytes(this, addr(index), index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + UnsafeByteBufUtil.setBytes(this, addr(index), index, src); + return this; + } + + @Override + void getBytes(int index, OutputStream out, int length, boolean internal) throws IOException { + UnsafeByteBufUtil.getBytes(this, addr(index), index, out, length); + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + return UnsafeByteBufUtil.setBytes(this, addr(index), index, in, length); + } + + @Override + public ByteBuf copy(int index, int length) { + return UnsafeByteBufUtil.copy(this, addr(index), index, length); + } + + final long addr(int index) { + return memoryAddress + index; + } + + @Override + protected SwappedByteBuf newSwappedByteBuf() { + if (PlatformDependent.isUnaligned()) { + // Only use if unaligned access is supported otherwise there is no gain. 
+ return new UnsafeDirectSwappedByteBuf(this); + } + return super.newSwappedByteBuf(); + } + + @Override + public ByteBuf setZero(int index, int length) { + checkIndex(index, length); + UnsafeByteBufUtil.setZero(addr(index), length); + return this; + } + + @Override + public ByteBuf writeZero(int length) { + ensureWritable(length); + int wIndex = writerIndex; + UnsafeByteBufUtil.setZero(addr(wIndex), length); + writerIndex = wIndex + length; + return this; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java new file mode 100644 index 0000000..2644df4 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/UnpooledUnsafeHeapByteBuf.java @@ -0,0 +1,282 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; + +/** + * Big endian Java heap buffer implementation. It is recommended to use + * {@link UnpooledByteBufAllocator#heapBuffer(int, int)}, {@link Unpooled#buffer(int)} and + * {@link Unpooled#wrappedBuffer(byte[])} instead of calling the constructor explicitly. + */ +public class UnpooledUnsafeHeapByteBuf extends UnpooledHeapByteBuf { + + /** + * Creates a new heap buffer with a newly allocated byte array. 
+ * + * @param initialCapacity the initial capacity of the underlying byte array + * @param maxCapacity the max capacity of the underlying byte array + */ + public UnpooledUnsafeHeapByteBuf(ByteBufAllocator alloc, int initialCapacity, int maxCapacity) { + super(alloc, initialCapacity, maxCapacity); + } + + @Override + protected byte[] allocateArray(int initialCapacity) { + return PlatformDependent.allocateUninitializedArray(initialCapacity); + } + + @Override + public byte getByte(int index) { + checkIndex(index); + return _getByte(index); + } + + @Override + protected byte _getByte(int index) { + return UnsafeByteBufUtil.getByte(array, index); + } + + @Override + public short getShort(int index) { + checkIndex(index, 2); + return _getShort(index); + } + + @Override + protected short _getShort(int index) { + return UnsafeByteBufUtil.getShort(array, index); + } + + @Override + public short getShortLE(int index) { + checkIndex(index, 2); + return _getShortLE(index); + } + + @Override + protected short _getShortLE(int index) { + return UnsafeByteBufUtil.getShortLE(array, index); + } + + @Override + public int getUnsignedMedium(int index) { + checkIndex(index, 3); + return _getUnsignedMedium(index); + } + + @Override + protected int _getUnsignedMedium(int index) { + return UnsafeByteBufUtil.getUnsignedMedium(array, index); + } + + @Override + public int getUnsignedMediumLE(int index) { + checkIndex(index, 3); + return _getUnsignedMediumLE(index); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + return UnsafeByteBufUtil.getUnsignedMediumLE(array, index); + } + + @Override + public int getInt(int index) { + checkIndex(index, 4); + return _getInt(index); + } + + @Override + protected int _getInt(int index) { + return UnsafeByteBufUtil.getInt(array, index); + } + + @Override + public int getIntLE(int index) { + checkIndex(index, 4); + return _getIntLE(index); + } + + @Override + protected int _getIntLE(int index) { + return 
        UnsafeByteBufUtil.getIntLE(array, index);
    }

    // ------------------------------------------------------------------
    // Bounds-checked primitive accessors. Each public method validates the
    // index/length via checkIndex(...) and then delegates to its unchecked
    // _xxx counterpart, which reads/writes the backing byte[] through
    // UnsafeByteBufUtil (default big-endian; the *LE variants are explicit
    // little-endian).
    // NOTE(review): the enclosing class header is outside this view —
    // presumably UnpooledUnsafeHeapByteBuf; confirm against the file head.
    // ------------------------------------------------------------------

    @Override
    public long getLong(int index) {
        checkIndex(index, 8);
        return _getLong(index);
    }

    @Override
    protected long _getLong(int index) {
        return UnsafeByteBufUtil.getLong(array, index);
    }

    @Override
    public long getLongLE(int index) {
        checkIndex(index, 8);
        return _getLongLE(index);
    }

    @Override
    protected long _getLongLE(int index) {
        return UnsafeByteBufUtil.getLongLE(array, index);
    }

    @Override
    public ByteBuf setByte(int index, int value) {
        checkIndex(index);
        _setByte(index, value);
        return this;
    }

    @Override
    protected void _setByte(int index, int value) {
        UnsafeByteBufUtil.setByte(array, index, value);
    }

    @Override
    public ByteBuf setShort(int index, int value) {
        checkIndex(index, 2);
        _setShort(index, value);
        return this;
    }

    @Override
    protected void _setShort(int index, int value) {
        UnsafeByteBufUtil.setShort(array, index, value);
    }

    @Override
    public ByteBuf setShortLE(int index, int value) {
        checkIndex(index, 2);
        _setShortLE(index, value);
        return this;
    }

    @Override
    protected void _setShortLE(int index, int value) {
        UnsafeByteBufUtil.setShortLE(array, index, value);
    }

    @Override
    public ByteBuf setMedium(int index, int value) {
        checkIndex(index, 3);
        _setMedium(index, value);
        return this;
    }

    @Override
    protected void _setMedium(int index, int value) {
        UnsafeByteBufUtil.setMedium(array, index, value);
    }

    @Override
    public ByteBuf setMediumLE(int index, int value) {
        checkIndex(index, 3);
        _setMediumLE(index, value);
        return this;
    }

    @Override
    protected void _setMediumLE(int index, int value) {
        UnsafeByteBufUtil.setMediumLE(array, index, value);
    }

    @Override
    public ByteBuf setInt(int index, int value) {
        checkIndex(index, 4);
        _setInt(index, value);
        return this;
    }

    @Override
    protected void _setInt(int index, int value) {
        UnsafeByteBufUtil.setInt(array, index, value);
    }

    @Override
    public ByteBuf setIntLE(int index, int value) {
        checkIndex(index, 4);
        _setIntLE(index, value);
        return this;
    }

    @Override
    protected void _setIntLE(int index, int value) {
        UnsafeByteBufUtil.setIntLE(array, index, value);
    }

    @Override
    public ByteBuf setLong(int index, long value) {
        checkIndex(index, 8);
        _setLong(index, value);
        return this;
    }

    @Override
    protected void _setLong(int index, long value) {
        UnsafeByteBufUtil.setLong(array, index, value);
    }

    @Override
    public ByteBuf setLongLE(int index, long value) {
        checkIndex(index, 8);
        _setLongLE(index, value);
        return this;
    }

    @Override
    protected void _setLongLE(int index, long value) {
        UnsafeByteBufUtil.setLongLE(array, index, value);
    }

    @Override
    public ByteBuf setZero(int index, int length) {
        if (PlatformDependent.javaVersion() >= 7) {
            // Only do on java7+ as the needed Unsafe call was only added there.
            checkIndex(index, length);
            UnsafeByteBufUtil.setZero(array, index, length);
            return this;
        }
        return super.setZero(index, length);
    }

    @Override
    public ByteBuf writeZero(int length) {
        if (PlatformDependent.javaVersion() >= 7) {
            // Only do on java7+ as the needed Unsafe call was only added there.
            ensureWritable(length);
            // Read writerIndex once and publish the advanced index only after
            // the memory has actually been zeroed.
            int wIndex = writerIndex;
            UnsafeByteBufUtil.setZero(array, wIndex, length);
            writerIndex = wIndex + length;
            return this;
        }
        return super.writeZero(length);
    }

    @Override
    @Deprecated
    protected SwappedByteBuf newSwappedByteBuf() {
        if (PlatformDependent.isUnaligned()) {
            // Only use if unaligned access is supported otherwise there is no gain.
            return new UnsafeHeapSwappedByteBuf(this);
        }
        return super.newSwappedByteBuf();
    }
}
+ */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; + +import java.nio.ByteBuffer; + +class UnpooledUnsafeNoCleanerDirectByteBuf extends UnpooledUnsafeDirectByteBuf { + + UnpooledUnsafeNoCleanerDirectByteBuf(ByteBufAllocator alloc, int initialCapacity, int maxCapacity) { + super(alloc, initialCapacity, maxCapacity); + } + + @Override + protected ByteBuffer allocateDirect(int initialCapacity) { + return PlatformDependent.allocateDirectNoCleaner(initialCapacity); + } + + ByteBuffer reallocateDirect(ByteBuffer oldBuffer, int initialCapacity) { + return PlatformDependent.reallocateDirectNoCleaner(oldBuffer, initialCapacity); + } + + @Override + protected void freeDirect(ByteBuffer buffer) { + PlatformDependent.freeDirectNoCleaner(buffer); + } + + @Override + public ByteBuf capacity(int newCapacity) { + checkNewCapacity(newCapacity); + + int oldCapacity = capacity(); + if (newCapacity == oldCapacity) { + return this; + } + + trimIndicesToCapacity(newCapacity); + setByteBuffer(reallocateDirect(buffer, newCapacity), false); + return this; + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/UnreleasableByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/UnreleasableByteBuf.java new file mode 100644 index 0000000..a4c6ef7 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/UnreleasableByteBuf.java @@ -0,0 +1,133 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

import io.netty.util.internal.ObjectUtil;

import java.nio.ByteOrder;

/**
 * A {@link ByteBuf} implementation that wraps another buffer to prevent a user from increasing or decreasing the
 * wrapped buffer's reference count.
 *
 * <p>All derived views (slices, duplicates) are wrapped again so the protection cannot be escaped
 * through them, and all retain/release operations are no-ops.
 */
final class UnreleasableByteBuf extends WrappedByteBuf {

    // Lazily created byte-order-swapped view; cached so repeated order(...) calls
    // return the same instance.
    private SwappedByteBuf swappedBuf;

    UnreleasableByteBuf(ByteBuf buf) {
        // Unwrap nested UnreleasableByteBufs so wrappers never stack.
        super(buf instanceof UnreleasableByteBuf ? buf.unwrap() : buf);
    }

    @Override
    public ByteBuf order(ByteOrder endianness) {
        if (ObjectUtil.checkNotNull(endianness, "endianness") == order()) {
            return this;
        }

        SwappedByteBuf swappedBuf = this.swappedBuf;
        if (swappedBuf == null) {
            // Wrap *this* (not buf) so the swapped view is still unreleasable.
            this.swappedBuf = swappedBuf = new SwappedByteBuf(this);
        }
        return swappedBuf;
    }

    @Override
    public ByteBuf asReadOnly() {
        return buf.isReadOnly() ? this : new UnreleasableByteBuf(buf.asReadOnly());
    }

    @Override
    public ByteBuf readSlice(int length) {
        return new UnreleasableByteBuf(buf.readSlice(length));
    }

    @Override
    public ByteBuf readRetainedSlice(int length) {
        // We could call buf.readSlice(..), and then call buf.release(). However this creates a leak in unit tests
        // because the release method on UnreleasableByteBuf will never allow the leak record to be cleaned up.
        // So we just use readSlice(..) because the end result should be logically equivalent.
        return readSlice(length);
    }

    @Override
    public ByteBuf slice() {
        return new UnreleasableByteBuf(buf.slice());
    }

    @Override
    public ByteBuf retainedSlice() {
        // We could call buf.retainedSlice(), and then call buf.release(). However this creates a leak in unit tests
        // because the release method on UnreleasableByteBuf will never allow the leak record to be cleaned up.
        // So we just use slice() because the end result should be logically equivalent.
        return slice();
    }

    @Override
    public ByteBuf slice(int index, int length) {
        return new UnreleasableByteBuf(buf.slice(index, length));
    }

    @Override
    public ByteBuf retainedSlice(int index, int length) {
        // We could call buf.retainedSlice(..), and then call buf.release(). However this creates a leak in unit tests
        // because the release method on UnreleasableByteBuf will never allow the leak record to be cleaned up.
        // So we just use slice(..) because the end result should be logically equivalent.
        return slice(index, length);
    }

    @Override
    public ByteBuf duplicate() {
        return new UnreleasableByteBuf(buf.duplicate());
    }

    @Override
    public ByteBuf retainedDuplicate() {
        // We could call buf.retainedDuplicate(), and then call buf.release(). However this creates a leak in unit tests
        // because the release method on UnreleasableByteBuf will never allow the leak record to be cleaned up.
        // So we just use duplicate() because the end result should be logically equivalent.
        return duplicate();
    }

    // Reference-count mutators are all no-ops: retain/touch return this, release reports false
    // (buffer never deallocated through this wrapper).

    @Override
    public ByteBuf retain(int increment) {
        return this;
    }

    @Override
    public ByteBuf retain() {
        return this;
    }

    @Override
    public ByteBuf touch() {
        return this;
    }

    @Override
    public ByteBuf touch(Object hint) {
        return this;
    }

    @Override
    public boolean release() {
        return false;
    }

    @Override
    public boolean release(int decrement) {
        return false;
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.ReadOnlyBufferException; + +import static io.netty.util.internal.MathUtil.isOutOfBounds; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.PlatformDependent.BIG_ENDIAN_NATIVE_ORDER; + +/** + * All operations get and set as {@link ByteOrder#BIG_ENDIAN}. + */ +final class UnsafeByteBufUtil { + private static final boolean UNALIGNED = PlatformDependent.isUnaligned(); + private static final byte ZERO = 0; + private static final int MAX_HAND_ROLLED_SET_ZERO_BYTES = 64; + + static byte getByte(long address) { + return PlatformDependent.getByte(address); + } + + static short getShort(long address) { + if (UNALIGNED) { + short v = PlatformDependent.getShort(address); + return BIG_ENDIAN_NATIVE_ORDER ? v : Short.reverseBytes(v); + } + return (short) (PlatformDependent.getByte(address) << 8 | PlatformDependent.getByte(address + 1) & 0xff); + } + + static short getShortLE(long address) { + if (UNALIGNED) { + short v = PlatformDependent.getShort(address); + return BIG_ENDIAN_NATIVE_ORDER ? 
Short.reverseBytes(v) : v; + } + return (short) (PlatformDependent.getByte(address) & 0xff | PlatformDependent.getByte(address + 1) << 8); + } + + static int getUnsignedMedium(long address) { + if (UNALIGNED) { + return (PlatformDependent.getByte(address) & 0xff) << 16 | + (BIG_ENDIAN_NATIVE_ORDER ? PlatformDependent.getShort(address + 1) + : Short.reverseBytes(PlatformDependent.getShort(address + 1))) & 0xffff; + } + return (PlatformDependent.getByte(address) & 0xff) << 16 | + (PlatformDependent.getByte(address + 1) & 0xff) << 8 | + PlatformDependent.getByte(address + 2) & 0xff; + } + + static int getUnsignedMediumLE(long address) { + if (UNALIGNED) { + return (PlatformDependent.getByte(address) & 0xff) | + ((BIG_ENDIAN_NATIVE_ORDER ? Short.reverseBytes(PlatformDependent.getShort(address + 1)) + : PlatformDependent.getShort(address + 1)) & 0xffff) << 8; + } + return PlatformDependent.getByte(address) & 0xff | + (PlatformDependent.getByte(address + 1) & 0xff) << 8 | + (PlatformDependent.getByte(address + 2) & 0xff) << 16; + } + + static int getInt(long address) { + if (UNALIGNED) { + int v = PlatformDependent.getInt(address); + return BIG_ENDIAN_NATIVE_ORDER ? v : Integer.reverseBytes(v); + } + return PlatformDependent.getByte(address) << 24 | + (PlatformDependent.getByte(address + 1) & 0xff) << 16 | + (PlatformDependent.getByte(address + 2) & 0xff) << 8 | + PlatformDependent.getByte(address + 3) & 0xff; + } + + static int getIntLE(long address) { + if (UNALIGNED) { + int v = PlatformDependent.getInt(address); + return BIG_ENDIAN_NATIVE_ORDER ? Integer.reverseBytes(v) : v; + } + return PlatformDependent.getByte(address) & 0xff | + (PlatformDependent.getByte(address + 1) & 0xff) << 8 | + (PlatformDependent.getByte(address + 2) & 0xff) << 16 | + PlatformDependent.getByte(address + 3) << 24; + } + + static long getLong(long address) { + if (UNALIGNED) { + long v = PlatformDependent.getLong(address); + return BIG_ENDIAN_NATIVE_ORDER ? 
v : Long.reverseBytes(v); + } + return ((long) PlatformDependent.getByte(address)) << 56 | + (PlatformDependent.getByte(address + 1) & 0xffL) << 48 | + (PlatformDependent.getByte(address + 2) & 0xffL) << 40 | + (PlatformDependent.getByte(address + 3) & 0xffL) << 32 | + (PlatformDependent.getByte(address + 4) & 0xffL) << 24 | + (PlatformDependent.getByte(address + 5) & 0xffL) << 16 | + (PlatformDependent.getByte(address + 6) & 0xffL) << 8 | + (PlatformDependent.getByte(address + 7)) & 0xffL; + } + + static long getLongLE(long address) { + if (UNALIGNED) { + long v = PlatformDependent.getLong(address); + return BIG_ENDIAN_NATIVE_ORDER ? Long.reverseBytes(v) : v; + } + return (PlatformDependent.getByte(address)) & 0xffL | + (PlatformDependent.getByte(address + 1) & 0xffL) << 8 | + (PlatformDependent.getByte(address + 2) & 0xffL) << 16 | + (PlatformDependent.getByte(address + 3) & 0xffL) << 24 | + (PlatformDependent.getByte(address + 4) & 0xffL) << 32 | + (PlatformDependent.getByte(address + 5) & 0xffL) << 40 | + (PlatformDependent.getByte(address + 6) & 0xffL) << 48 | + ((long) PlatformDependent.getByte(address + 7)) << 56; + } + + static void setByte(long address, int value) { + PlatformDependent.putByte(address, (byte) value); + } + + static void setShort(long address, int value) { + if (UNALIGNED) { + PlatformDependent.putShort( + address, BIG_ENDIAN_NATIVE_ORDER ? (short) value : Short.reverseBytes((short) value)); + } else { + PlatformDependent.putByte(address, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 1, (byte) value); + } + } + + static void setShortLE(long address, int value) { + if (UNALIGNED) { + PlatformDependent.putShort( + address, BIG_ENDIAN_NATIVE_ORDER ? 
Short.reverseBytes((short) value) : (short) value); + } else { + PlatformDependent.putByte(address, (byte) value); + PlatformDependent.putByte(address + 1, (byte) (value >>> 8)); + } + } + + static void setMedium(long address, int value) { + PlatformDependent.putByte(address, (byte) (value >>> 16)); + if (UNALIGNED) { + PlatformDependent.putShort(address + 1, BIG_ENDIAN_NATIVE_ORDER ? (short) value + : Short.reverseBytes((short) value)); + } else { + PlatformDependent.putByte(address + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 2, (byte) value); + } + } + + static void setMediumLE(long address, int value) { + PlatformDependent.putByte(address, (byte) value); + if (UNALIGNED) { + PlatformDependent.putShort(address + 1, BIG_ENDIAN_NATIVE_ORDER ? Short.reverseBytes((short) (value >>> 8)) + : (short) (value >>> 8)); + } else { + PlatformDependent.putByte(address + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 2, (byte) (value >>> 16)); + } + } + + static void setInt(long address, int value) { + if (UNALIGNED) { + PlatformDependent.putInt(address, BIG_ENDIAN_NATIVE_ORDER ? value : Integer.reverseBytes(value)); + } else { + PlatformDependent.putByte(address, (byte) (value >>> 24)); + PlatformDependent.putByte(address + 1, (byte) (value >>> 16)); + PlatformDependent.putByte(address + 2, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 3, (byte) value); + } + } + + static void setIntLE(long address, int value) { + if (UNALIGNED) { + PlatformDependent.putInt(address, BIG_ENDIAN_NATIVE_ORDER ? Integer.reverseBytes(value) : value); + } else { + PlatformDependent.putByte(address, (byte) value); + PlatformDependent.putByte(address + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 2, (byte) (value >>> 16)); + PlatformDependent.putByte(address + 3, (byte) (value >>> 24)); + } + } + + static void setLong(long address, long value) { + if (UNALIGNED) { + PlatformDependent.putLong(address, BIG_ENDIAN_NATIVE_ORDER ? 
value : Long.reverseBytes(value)); + } else { + PlatformDependent.putByte(address, (byte) (value >>> 56)); + PlatformDependent.putByte(address + 1, (byte) (value >>> 48)); + PlatformDependent.putByte(address + 2, (byte) (value >>> 40)); + PlatformDependent.putByte(address + 3, (byte) (value >>> 32)); + PlatformDependent.putByte(address + 4, (byte) (value >>> 24)); + PlatformDependent.putByte(address + 5, (byte) (value >>> 16)); + PlatformDependent.putByte(address + 6, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 7, (byte) value); + } + } + + static void setLongLE(long address, long value) { + if (UNALIGNED) { + PlatformDependent.putLong(address, BIG_ENDIAN_NATIVE_ORDER ? Long.reverseBytes(value) : value); + } else { + PlatformDependent.putByte(address, (byte) value); + PlatformDependent.putByte(address + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(address + 2, (byte) (value >>> 16)); + PlatformDependent.putByte(address + 3, (byte) (value >>> 24)); + PlatformDependent.putByte(address + 4, (byte) (value >>> 32)); + PlatformDependent.putByte(address + 5, (byte) (value >>> 40)); + PlatformDependent.putByte(address + 6, (byte) (value >>> 48)); + PlatformDependent.putByte(address + 7, (byte) (value >>> 56)); + } + } + + static byte getByte(byte[] array, int index) { + return PlatformDependent.getByte(array, index); + } + + static short getShort(byte[] array, int index) { + if (UNALIGNED) { + short v = PlatformDependent.getShort(array, index); + return BIG_ENDIAN_NATIVE_ORDER ? v : Short.reverseBytes(v); + } + return (short) (PlatformDependent.getByte(array, index) << 8 | + PlatformDependent.getByte(array, index + 1) & 0xff); + } + + static short getShortLE(byte[] array, int index) { + if (UNALIGNED) { + short v = PlatformDependent.getShort(array, index); + return BIG_ENDIAN_NATIVE_ORDER ? 
Short.reverseBytes(v) : v; + } + return (short) (PlatformDependent.getByte(array, index) & 0xff | + PlatformDependent.getByte(array, index + 1) << 8); + } + + static int getUnsignedMedium(byte[] array, int index) { + if (UNALIGNED) { + return (PlatformDependent.getByte(array, index) & 0xff) << 16 | + (BIG_ENDIAN_NATIVE_ORDER ? PlatformDependent.getShort(array, index + 1) + : Short.reverseBytes(PlatformDependent.getShort(array, index + 1))) + & 0xffff; + } + return (PlatformDependent.getByte(array, index) & 0xff) << 16 | + (PlatformDependent.getByte(array, index + 1) & 0xff) << 8 | + PlatformDependent.getByte(array, index + 2) & 0xff; + } + + static int getUnsignedMediumLE(byte[] array, int index) { + if (UNALIGNED) { + return (PlatformDependent.getByte(array, index) & 0xff) | + ((BIG_ENDIAN_NATIVE_ORDER ? Short.reverseBytes(PlatformDependent.getShort(array, index + 1)) + : PlatformDependent.getShort(array, index + 1)) & 0xffff) << 8; + } + return PlatformDependent.getByte(array, index) & 0xff | + (PlatformDependent.getByte(array, index + 1) & 0xff) << 8 | + (PlatformDependent.getByte(array, index + 2) & 0xff) << 16; + } + + static int getInt(byte[] array, int index) { + if (UNALIGNED) { + int v = PlatformDependent.getInt(array, index); + return BIG_ENDIAN_NATIVE_ORDER ? v : Integer.reverseBytes(v); + } + return PlatformDependent.getByte(array, index) << 24 | + (PlatformDependent.getByte(array, index + 1) & 0xff) << 16 | + (PlatformDependent.getByte(array, index + 2) & 0xff) << 8 | + PlatformDependent.getByte(array, index + 3) & 0xff; + } + + static int getIntLE(byte[] array, int index) { + if (UNALIGNED) { + int v = PlatformDependent.getInt(array, index); + return BIG_ENDIAN_NATIVE_ORDER ? 
Integer.reverseBytes(v) : v; + } + return PlatformDependent.getByte(array, index) & 0xff | + (PlatformDependent.getByte(array, index + 1) & 0xff) << 8 | + (PlatformDependent.getByte(array, index + 2) & 0xff) << 16 | + PlatformDependent.getByte(array, index + 3) << 24; + } + + static long getLong(byte[] array, int index) { + if (UNALIGNED) { + long v = PlatformDependent.getLong(array, index); + return BIG_ENDIAN_NATIVE_ORDER ? v : Long.reverseBytes(v); + } + return ((long) PlatformDependent.getByte(array, index)) << 56 | + (PlatformDependent.getByte(array, index + 1) & 0xffL) << 48 | + (PlatformDependent.getByte(array, index + 2) & 0xffL) << 40 | + (PlatformDependent.getByte(array, index + 3) & 0xffL) << 32 | + (PlatformDependent.getByte(array, index + 4) & 0xffL) << 24 | + (PlatformDependent.getByte(array, index + 5) & 0xffL) << 16 | + (PlatformDependent.getByte(array, index + 6) & 0xffL) << 8 | + (PlatformDependent.getByte(array, index + 7)) & 0xffL; + } + + static long getLongLE(byte[] array, int index) { + if (UNALIGNED) { + long v = PlatformDependent.getLong(array, index); + return BIG_ENDIAN_NATIVE_ORDER ? Long.reverseBytes(v) : v; + } + return PlatformDependent.getByte(array, index) & 0xffL | + (PlatformDependent.getByte(array, index + 1) & 0xffL) << 8 | + (PlatformDependent.getByte(array, index + 2) & 0xffL) << 16 | + (PlatformDependent.getByte(array, index + 3) & 0xffL) << 24 | + (PlatformDependent.getByte(array, index + 4) & 0xffL) << 32 | + (PlatformDependent.getByte(array, index + 5) & 0xffL) << 40 | + (PlatformDependent.getByte(array, index + 6) & 0xffL) << 48 | + ((long) PlatformDependent.getByte(array, index + 7)) << 56; + } + + static void setByte(byte[] array, int index, int value) { + PlatformDependent.putByte(array, index, (byte) value); + } + + static void setShort(byte[] array, int index, int value) { + if (UNALIGNED) { + PlatformDependent.putShort(array, index, + BIG_ENDIAN_NATIVE_ORDER ? 
(short) value : Short.reverseBytes((short) value)); + } else { + PlatformDependent.putByte(array, index, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 1, (byte) value); + } + } + + static void setShortLE(byte[] array, int index, int value) { + if (UNALIGNED) { + PlatformDependent.putShort(array, index, + BIG_ENDIAN_NATIVE_ORDER ? Short.reverseBytes((short) value) : (short) value); + } else { + PlatformDependent.putByte(array, index, (byte) value); + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 8)); + } + } + + static void setMedium(byte[] array, int index, int value) { + PlatformDependent.putByte(array, index, (byte) (value >>> 16)); + if (UNALIGNED) { + PlatformDependent.putShort(array, index + 1, + BIG_ENDIAN_NATIVE_ORDER ? (short) value + : Short.reverseBytes((short) value)); + } else { + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 2, (byte) value); + } + } + + static void setMediumLE(byte[] array, int index, int value) { + PlatformDependent.putByte(array, index, (byte) value); + if (UNALIGNED) { + PlatformDependent.putShort(array, index + 1, + BIG_ENDIAN_NATIVE_ORDER ? Short.reverseBytes((short) (value >>> 8)) + : (short) (value >>> 8)); + } else { + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 2, (byte) (value >>> 16)); + } + } + + static void setInt(byte[] array, int index, int value) { + if (UNALIGNED) { + PlatformDependent.putInt(array, index, BIG_ENDIAN_NATIVE_ORDER ? 
value : Integer.reverseBytes(value)); + } else { + PlatformDependent.putByte(array, index, (byte) (value >>> 24)); + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 16)); + PlatformDependent.putByte(array, index + 2, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 3, (byte) value); + } + } + + static void setIntLE(byte[] array, int index, int value) { + if (UNALIGNED) { + PlatformDependent.putInt(array, index, BIG_ENDIAN_NATIVE_ORDER ? Integer.reverseBytes(value) : value); + } else { + PlatformDependent.putByte(array, index, (byte) value); + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 2, (byte) (value >>> 16)); + PlatformDependent.putByte(array, index + 3, (byte) (value >>> 24)); + } + } + + static void setLong(byte[] array, int index, long value) { + if (UNALIGNED) { + PlatformDependent.putLong(array, index, BIG_ENDIAN_NATIVE_ORDER ? value : Long.reverseBytes(value)); + } else { + PlatformDependent.putByte(array, index, (byte) (value >>> 56)); + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 48)); + PlatformDependent.putByte(array, index + 2, (byte) (value >>> 40)); + PlatformDependent.putByte(array, index + 3, (byte) (value >>> 32)); + PlatformDependent.putByte(array, index + 4, (byte) (value >>> 24)); + PlatformDependent.putByte(array, index + 5, (byte) (value >>> 16)); + PlatformDependent.putByte(array, index + 6, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 7, (byte) value); + } + } + + static void setLongLE(byte[] array, int index, long value) { + if (UNALIGNED) { + PlatformDependent.putLong(array, index, BIG_ENDIAN_NATIVE_ORDER ? 
Long.reverseBytes(value) : value); + } else { + PlatformDependent.putByte(array, index, (byte) value); + PlatformDependent.putByte(array, index + 1, (byte) (value >>> 8)); + PlatformDependent.putByte(array, index + 2, (byte) (value >>> 16)); + PlatformDependent.putByte(array, index + 3, (byte) (value >>> 24)); + PlatformDependent.putByte(array, index + 4, (byte) (value >>> 32)); + PlatformDependent.putByte(array, index + 5, (byte) (value >>> 40)); + PlatformDependent.putByte(array, index + 6, (byte) (value >>> 48)); + PlatformDependent.putByte(array, index + 7, (byte) (value >>> 56)); + } + } + + private static void batchSetZero(byte[] data, int index, int length) { + int longBatches = length / 8; + for (int i = 0; i < longBatches; i++) { + PlatformDependent.putLong(data, index, ZERO); + index += 8; + } + final int remaining = length % 8; + for (int i = 0; i < remaining; i++) { + PlatformDependent.putByte(data, index + i, ZERO); + } + } + + static void setZero(byte[] array, int index, int length) { + if (length == 0) { + return; + } + // fast-path for small writes to avoid thread-state change JDK's handling + if (UNALIGNED && length <= MAX_HAND_ROLLED_SET_ZERO_BYTES) { + batchSetZero(array, index, length); + } else { + PlatformDependent.setMemory(array, index, length, ZERO); + } + } + + static ByteBuf copy(AbstractByteBuf buf, long addr, int index, int length) { + buf.checkIndex(index, length); + ByteBuf copy = buf.alloc().directBuffer(length, buf.maxCapacity()); + if (length != 0) { + if (copy.hasMemoryAddress()) { + PlatformDependent.copyMemory(addr, copy.memoryAddress(), length); + copy.setIndex(0, length); + } else { + copy.writeBytes(buf, index, length); + } + } + return copy; + } + + static int setBytes(AbstractByteBuf buf, long addr, int index, InputStream in, int length) throws IOException { + buf.checkIndex(index, length); + ByteBuf tmpBuf = buf.alloc().heapBuffer(length); + try { + byte[] tmp = tmpBuf.array(); + int offset = tmpBuf.arrayOffset(); + int 
readBytes = in.read(tmp, offset, length); + if (readBytes > 0) { + PlatformDependent.copyMemory(tmp, offset, addr, readBytes); + } + return readBytes; + } finally { + tmpBuf.release(); + } + } + + static void getBytes(AbstractByteBuf buf, long addr, int index, ByteBuf dst, int dstIndex, int length) { + buf.checkIndex(index, length); + checkNotNull(dst, "dst"); + if (isOutOfBounds(dstIndex, length, dst.capacity())) { + throw new IndexOutOfBoundsException("dstIndex: " + dstIndex); + } + + if (dst.hasMemoryAddress()) { + PlatformDependent.copyMemory(addr, dst.memoryAddress() + dstIndex, length); + } else if (dst.hasArray()) { + PlatformDependent.copyMemory(addr, dst.array(), dst.arrayOffset() + dstIndex, length); + } else { + dst.setBytes(dstIndex, buf, index, length); + } + } + + static void getBytes(AbstractByteBuf buf, long addr, int index, byte[] dst, int dstIndex, int length) { + buf.checkIndex(index, length); + checkNotNull(dst, "dst"); + if (isOutOfBounds(dstIndex, length, dst.length)) { + throw new IndexOutOfBoundsException("dstIndex: " + dstIndex); + } + if (length != 0) { + PlatformDependent.copyMemory(addr, dst, dstIndex, length); + } + } + + static void getBytes(AbstractByteBuf buf, long addr, int index, ByteBuffer dst) { + buf.checkIndex(index, dst.remaining()); + if (dst.remaining() == 0) { + return; + } + + if (dst.isDirect()) { + if (dst.isReadOnly()) { + // We need to check if dst is ready-only so we not write something in it by using Unsafe. 
+ throw new ReadOnlyBufferException(); + } + // Copy to direct memory + long dstAddress = PlatformDependent.directBufferAddress(dst); + PlatformDependent.copyMemory(addr, dstAddress + dst.position(), dst.remaining()); + dst.position(dst.position() + dst.remaining()); + } else if (dst.hasArray()) { + // Copy to array + PlatformDependent.copyMemory(addr, dst.array(), dst.arrayOffset() + dst.position(), dst.remaining()); + dst.position(dst.position() + dst.remaining()); + } else { + dst.put(buf.nioBuffer()); + } + } + + static void setBytes(AbstractByteBuf buf, long addr, int index, ByteBuf src, int srcIndex, int length) { + buf.checkIndex(index, length); + checkNotNull(src, "src"); + if (isOutOfBounds(srcIndex, length, src.capacity())) { + throw new IndexOutOfBoundsException("srcIndex: " + srcIndex); + } + + if (length != 0) { + if (src.hasMemoryAddress()) { + PlatformDependent.copyMemory(src.memoryAddress() + srcIndex, addr, length); + } else if (src.hasArray()) { + PlatformDependent.copyMemory(src.array(), src.arrayOffset() + srcIndex, addr, length); + } else { + src.getBytes(srcIndex, buf, index, length); + } + } + } + + static void setBytes(AbstractByteBuf buf, long addr, int index, byte[] src, int srcIndex, int length) { + buf.checkIndex(index, length); + // we need to check not null for src as it may cause the JVM crash + // See https://github.com/netty/netty/issues/10791 + checkNotNull(src, "src"); + if (isOutOfBounds(srcIndex, length, src.length)) { + throw new IndexOutOfBoundsException("srcIndex: " + srcIndex); + } + + if (length != 0) { + PlatformDependent.copyMemory(src, srcIndex, addr, length); + } + } + + static void setBytes(AbstractByteBuf buf, long addr, int index, ByteBuffer src) { + final int length = src.remaining(); + if (length == 0) { + return; + } + + if (src.isDirect()) { + buf.checkIndex(index, length); + // Copy from direct memory + long srcAddress = PlatformDependent.directBufferAddress(src); + PlatformDependent.copyMemory(srcAddress + 
src.position(), addr, length); + src.position(src.position() + length); + } else if (src.hasArray()) { + buf.checkIndex(index, length); + // Copy from array + PlatformDependent.copyMemory(src.array(), src.arrayOffset() + src.position(), addr, length); + src.position(src.position() + length); + } else { + if (length < 8) { + setSingleBytes(buf, addr, index, src, length); + } else { + //no need to checkIndex: internalNioBuffer is already taking care of it + assert buf.nioBufferCount() == 1; + final ByteBuffer internalBuffer = buf.internalNioBuffer(index, length); + internalBuffer.put(src); + } + } + } + + private static void setSingleBytes(final AbstractByteBuf buf, final long addr, final int index, + final ByteBuffer src, final int length) { + buf.checkIndex(index, length); + final int srcPosition = src.position(); + final int srcLimit = src.limit(); + long dstAddr = addr; + for (int srcIndex = srcPosition; srcIndex < srcLimit; srcIndex++) { + final byte value = src.get(srcIndex); + PlatformDependent.putByte(dstAddr, value); + dstAddr++; + } + src.position(srcLimit); + } + + static void getBytes(AbstractByteBuf buf, long addr, int index, OutputStream out, int length) throws IOException { + buf.checkIndex(index, length); + if (length != 0) { + int len = Math.min(length, ByteBufUtil.WRITE_CHUNK_SIZE); + if (len <= ByteBufUtil.MAX_TL_ARRAY_LEN || !buf.alloc().isDirectBufferPooled()) { + getBytes(addr, ByteBufUtil.threadLocalTempArray(len), 0, len, out, length); + } else { + // if direct buffers are pooled chances are good that heap buffers are pooled as well. 
+ ByteBuf tmpBuf = buf.alloc().heapBuffer(len); + try { + byte[] tmp = tmpBuf.array(); + int offset = tmpBuf.arrayOffset(); + getBytes(addr, tmp, offset, len, out, length); + } finally { + tmpBuf.release(); + } + } + } + } + + private static void getBytes(long inAddr, byte[] in, int inOffset, int inLen, OutputStream out, int outLen) + throws IOException { + do { + int len = Math.min(inLen, outLen); + PlatformDependent.copyMemory(inAddr, in, inOffset, len); + out.write(in, inOffset, len); + outLen -= len; + inAddr += len; + } while (outLen > 0); + } + + private static void batchSetZero(long addr, int length) { + int longBatches = length / 8; + for (int i = 0; i < longBatches; i++) { + PlatformDependent.putLong(addr, ZERO); + addr += 8; + } + final int remaining = length % 8; + for (int i = 0; i < remaining; i++) { + PlatformDependent.putByte(addr + i, ZERO); + } + } + + static void setZero(long addr, int length) { + if (length == 0) { + return; + } + // fast-path for small writes to avoid thread-state change JDK's handling + if (length <= MAX_HAND_ROLLED_SET_ZERO_BYTES) { + if (!UNALIGNED) { + // write bytes until the address is aligned + int bytesToGetAligned = zeroTillAligned(addr, length); + addr += bytesToGetAligned; + length -= bytesToGetAligned; + if (length == 0) { + return; + } + assert addr % 8 == 0; + } + batchSetZero(addr, length); + } else { + PlatformDependent.setMemory(addr, length, ZERO); + } + } + + private static int zeroTillAligned(long addr, int length) { + // write bytes until the address is aligned + int bytesToGetAligned = Math.min((int) (addr % 8), length); + for (int i = 0; i < bytesToGetAligned; i++) { + PlatformDependent.putByte(addr + i, ZERO); + } + return bytesToGetAligned; + } + + static UnpooledUnsafeDirectByteBuf newUnsafeDirectByteBuf( + ByteBufAllocator alloc, int initialCapacity, int maxCapacity) { + if (PlatformDependent.useDirectBufferNoCleaner()) { + return new UnpooledUnsafeNoCleanerDirectByteBuf(alloc, initialCapacity, 
maxCapacity);
+        }
+        return new UnpooledUnsafeDirectByteBuf(alloc, initialCapacity, maxCapacity);
+    }
+
+    // Utility class: static helpers only, no instances.
+    private UnsafeByteBufUtil() { }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/UnsafeDirectSwappedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/UnsafeDirectSwappedByteBuf.java
new file mode 100644
index 0000000..dca920e
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/UnsafeDirectSwappedByteBuf.java
@@ -0,0 +1,67 @@
+/*
+* Copyright 2014 The Netty Project
+*
+* The Netty Project licenses this file to you under the Apache License,
+* version 2.0 (the "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at:
+*
+* https://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+* License for the specific language governing permissions and limitations
+* under the License.
+*/
+
+package io.netty.buffer;
+
+import io.netty.util.internal.PlatformDependent;
+
+/**
+ * Special {@link SwappedByteBuf} for {@link ByteBuf}s that are backed by a {@code memoryAddress}.
+ */
+final class UnsafeDirectSwappedByteBuf extends AbstractUnsafeSwappedByteBuf {
+
+    UnsafeDirectSwappedByteBuf(AbstractByteBuf buf) {
+        super(buf);
+    }
+
+    // Absolute native address of the byte at `index` in `wrapped`.
+    private static long addr(AbstractByteBuf wrapped, int index) {
+        // We need to call wrapped.memoryAddress() every time and NOT cache it, as it may change if the buffer expands.
+        // See:
+        // - https://github.com/netty/netty/issues/2587
+        // - https://github.com/netty/netty/issues/2580
+        return wrapped.memoryAddress() + index;
+    }
+
+    // The _get*/_set* hooks below perform raw native-order accesses; byte-order swapping is
+    // presumably applied by the AbstractUnsafeSwappedByteBuf superclass (not visible in this file)
+    // -- confirm against that class.
+    @Override
+    protected long _getLong(AbstractByteBuf wrapped, int index) {
+        return PlatformDependent.getLong(addr(wrapped, index));
+    }
+
+    @Override
+    protected int _getInt(AbstractByteBuf wrapped, int index) {
+        return PlatformDependent.getInt(addr(wrapped, index));
+    }
+
+    @Override
+    protected short _getShort(AbstractByteBuf wrapped, int index) {
+        return PlatformDependent.getShort(addr(wrapped, index));
+    }
+
+    @Override
+    protected void _setShort(AbstractByteBuf wrapped, int index, short value) {
+        PlatformDependent.putShort(addr(wrapped, index), value);
+    }
+
+    @Override
+    protected void _setInt(AbstractByteBuf wrapped, int index, int value) {
+        PlatformDependent.putInt(addr(wrapped, index), value);
+    }
+
+    @Override
+    protected void _setLong(AbstractByteBuf wrapped, int index, long value) {
+        PlatformDependent.putLong(addr(wrapped, index), value);
+    }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/UnsafeHeapSwappedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/UnsafeHeapSwappedByteBuf.java
new file mode 100644
index 0000000..8de2870
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/UnsafeHeapSwappedByteBuf.java
@@ -0,0 +1,63 @@
+/*
+* Copyright 2014 The Netty Project
+*
+* The Netty Project licenses this file to you under the Apache License,
+* version 2.0 (the "License"); you may not use this file except in compliance
+* with the License. You may obtain a copy of the License at:
+*
+* https://www.apache.org/licenses/LICENSE-2.0
+*
+* Unless required by applicable law or agreed to in writing, software
+* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+* License for the specific language governing permissions and limitations
+* under the License.
+*/
+
+package io.netty.buffer;
+
+import io.netty.util.internal.PlatformDependent;
+
+/**
+ * Special {@link SwappedByteBuf} for {@link ByteBuf}s that use unsafe to access the byte array.
+ */
+final class UnsafeHeapSwappedByteBuf extends AbstractUnsafeSwappedByteBuf {
+
+    UnsafeHeapSwappedByteBuf(AbstractByteBuf buf) {
+        super(buf);
+    }
+
+    // Translates a buffer index into an absolute offset within the backing array.
+    private static int idx(ByteBuf wrapped, int index) {
+        return wrapped.arrayOffset() + index;
+    }
+
+    // The _get*/_set* hooks below perform raw native-order array accesses; byte-order swapping
+    // is presumably applied by the AbstractUnsafeSwappedByteBuf superclass (not visible in this
+    // file) -- confirm against that class.
+    @Override
+    protected long _getLong(AbstractByteBuf wrapped, int index) {
+        return PlatformDependent.getLong(wrapped.array(), idx(wrapped, index));
+    }
+
+    @Override
+    protected int _getInt(AbstractByteBuf wrapped, int index) {
+        return PlatformDependent.getInt(wrapped.array(), idx(wrapped, index));
+    }
+
+    @Override
+    protected short _getShort(AbstractByteBuf wrapped, int index) {
+        return PlatformDependent.getShort(wrapped.array(), idx(wrapped, index));
+    }
+
+    @Override
+    protected void _setShort(AbstractByteBuf wrapped, int index, short value) {
+        PlatformDependent.putShort(wrapped.array(), idx(wrapped, index), value);
+    }
+
+    @Override
+    protected void _setInt(AbstractByteBuf wrapped, int index, int value) {
+        PlatformDependent.putInt(wrapped.array(), idx(wrapped, index), value);
+    }
+
+    @Override
+    protected void _setLong(AbstractByteBuf wrapped, int index, long value) {
+        PlatformDependent.putLong(wrapped.array(), idx(wrapped, index), value);
+    }
+}
diff --git a/netty-buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java
new file mode 100644
index 0000000..f5f0b7b
--- /dev/null
+++ b/netty-buffer/src/main/java/io/netty/buffer/WrappedByteBuf.java
@@ -0,0 +1,1049 @@
+/*
+ * Copyright 2013 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; + +/** + * Wraps another {@link ByteBuf}. + * + * It's important that the {@link #readerIndex()} and {@link #writerIndex()} will not do any adjustments on the + * indices on the fly because of internal optimizations made by {@link ByteBufUtil#writeAscii(ByteBuf, CharSequence)} + * and {@link ByteBufUtil#writeUtf8(ByteBuf, CharSequence)}. 
+ */ +class WrappedByteBuf extends ByteBuf { + + protected final ByteBuf buf; + + protected WrappedByteBuf(ByteBuf buf) { + this.buf = ObjectUtil.checkNotNull(buf, "buf"); + } + + @Override + public final boolean hasMemoryAddress() { + return buf.hasMemoryAddress(); + } + + @Override + public boolean isContiguous() { + return buf.isContiguous(); + } + + @Override + public final long memoryAddress() { + return buf.memoryAddress(); + } + + @Override + public final int capacity() { + return buf.capacity(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + buf.capacity(newCapacity); + return this; + } + + @Override + public final int maxCapacity() { + return buf.maxCapacity(); + } + + @Override + public final ByteBufAllocator alloc() { + return buf.alloc(); + } + + @Override + public final ByteOrder order() { + return buf.order(); + } + + @Override + public ByteBuf order(ByteOrder endianness) { + return buf.order(endianness); + } + + @Override + public final ByteBuf unwrap() { + return buf; + } + + @Override + public ByteBuf asReadOnly() { + return buf.asReadOnly(); + } + + @Override + public boolean isReadOnly() { + return buf.isReadOnly(); + } + + @Override + public final boolean isDirect() { + return buf.isDirect(); + } + + @Override + public final int readerIndex() { + return buf.readerIndex(); + } + + @Override + public final ByteBuf readerIndex(int readerIndex) { + buf.readerIndex(readerIndex); + return this; + } + + @Override + public final int writerIndex() { + return buf.writerIndex(); + } + + @Override + public final ByteBuf writerIndex(int writerIndex) { + buf.writerIndex(writerIndex); + return this; + } + + @Override + public ByteBuf setIndex(int readerIndex, int writerIndex) { + buf.setIndex(readerIndex, writerIndex); + return this; + } + + @Override + public final int readableBytes() { + return buf.readableBytes(); + } + + @Override + public final int writableBytes() { + return buf.writableBytes(); + } + + @Override + public final int 
maxWritableBytes() { + return buf.maxWritableBytes(); + } + + @Override + public int maxFastWritableBytes() { + return buf.maxFastWritableBytes(); + } + + @Override + public final boolean isReadable() { + return buf.isReadable(); + } + + @Override + public final boolean isWritable() { + return buf.isWritable(); + } + + @Override + public final ByteBuf clear() { + buf.clear(); + return this; + } + + @Override + public final ByteBuf markReaderIndex() { + buf.markReaderIndex(); + return this; + } + + @Override + public final ByteBuf resetReaderIndex() { + buf.resetReaderIndex(); + return this; + } + + @Override + public final ByteBuf markWriterIndex() { + buf.markWriterIndex(); + return this; + } + + @Override + public final ByteBuf resetWriterIndex() { + buf.resetWriterIndex(); + return this; + } + + @Override + public ByteBuf discardReadBytes() { + buf.discardReadBytes(); + return this; + } + + @Override + public ByteBuf discardSomeReadBytes() { + buf.discardSomeReadBytes(); + return this; + } + + @Override + public ByteBuf ensureWritable(int minWritableBytes) { + buf.ensureWritable(minWritableBytes); + return this; + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + return buf.ensureWritable(minWritableBytes, force); + } + + @Override + public boolean getBoolean(int index) { + return buf.getBoolean(index); + } + + @Override + public byte getByte(int index) { + return buf.getByte(index); + } + + @Override + public short getUnsignedByte(int index) { + return buf.getUnsignedByte(index); + } + + @Override + public short getShort(int index) { + return buf.getShort(index); + } + + @Override + public short getShortLE(int index) { + return buf.getShortLE(index); + } + + @Override + public int getUnsignedShort(int index) { + return buf.getUnsignedShort(index); + } + + @Override + public int getUnsignedShortLE(int index) { + return buf.getUnsignedShortLE(index); + } + + @Override + public int getMedium(int index) { + return 
buf.getMedium(index); + } + + @Override + public int getMediumLE(int index) { + return buf.getMediumLE(index); + } + + @Override + public int getUnsignedMedium(int index) { + return buf.getUnsignedMedium(index); + } + + @Override + public int getUnsignedMediumLE(int index) { + return buf.getUnsignedMediumLE(index); + } + + @Override + public int getInt(int index) { + return buf.getInt(index); + } + + @Override + public int getIntLE(int index) { + return buf.getIntLE(index); + } + + @Override + public long getUnsignedInt(int index) { + return buf.getUnsignedInt(index); + } + + @Override + public long getUnsignedIntLE(int index) { + return buf.getUnsignedIntLE(index); + } + + @Override + public long getLong(int index) { + return buf.getLong(index); + } + + @Override + public long getLongLE(int index) { + return buf.getLongLE(index); + } + + @Override + public char getChar(int index) { + return buf.getChar(index); + } + + @Override + public float getFloat(int index) { + return buf.getFloat(index); + } + + @Override + public double getDouble(int index) { + return buf.getDouble(index); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst) { + buf.getBytes(index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int length) { + buf.getBytes(index, dst, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + buf.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst) { + buf.getBytes(index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + buf.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + buf.getBytes(index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + 
buf.getBytes(index, out, length); + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return buf.getBytes(index, out, length); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + return buf.getBytes(index, out, position, length); + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + return buf.getCharSequence(index, length, charset); + } + + @Override + public ByteBuf setBoolean(int index, boolean value) { + buf.setBoolean(index, value); + return this; + } + + @Override + public ByteBuf setByte(int index, int value) { + buf.setByte(index, value); + return this; + } + + @Override + public ByteBuf setShort(int index, int value) { + buf.setShort(index, value); + return this; + } + + @Override + public ByteBuf setShortLE(int index, int value) { + buf.setShortLE(index, value); + return this; + } + + @Override + public ByteBuf setMedium(int index, int value) { + buf.setMedium(index, value); + return this; + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + buf.setMediumLE(index, value); + return this; + } + + @Override + public ByteBuf setInt(int index, int value) { + buf.setInt(index, value); + return this; + } + + @Override + public ByteBuf setIntLE(int index, int value) { + buf.setIntLE(index, value); + return this; + } + + @Override + public ByteBuf setLong(int index, long value) { + buf.setLong(index, value); + return this; + } + + @Override + public ByteBuf setLongLE(int index, long value) { + buf.setLongLE(index, value); + return this; + } + + @Override + public ByteBuf setChar(int index, int value) { + buf.setChar(index, value); + return this; + } + + @Override + public ByteBuf setFloat(int index, float value) { + buf.setFloat(index, value); + return this; + } + + @Override + public ByteBuf setDouble(int index, double value) { + buf.setDouble(index, value); + 
return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src) { + buf.setBytes(index, src); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int length) { + buf.setBytes(index, src, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + buf.setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src) { + buf.setBytes(index, src); + return this; + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + buf.setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + buf.setBytes(index, src); + return this; + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + return buf.setBytes(index, in, length); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + return buf.setBytes(index, in, length); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + return buf.setBytes(index, in, position, length); + } + + @Override + public ByteBuf setZero(int index, int length) { + buf.setZero(index, length); + return this; + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + return buf.setCharSequence(index, sequence, charset); + } + + @Override + public boolean readBoolean() { + return buf.readBoolean(); + } + + @Override + public byte readByte() { + return buf.readByte(); + } + + @Override + public short readUnsignedByte() { + return buf.readUnsignedByte(); + } + + @Override + public short readShort() { + return buf.readShort(); + } + + @Override + public short readShortLE() { + return buf.readShortLE(); + } + + @Override + public int readUnsignedShort() { + return 
buf.readUnsignedShort(); + } + + @Override + public int readUnsignedShortLE() { + return buf.readUnsignedShortLE(); + } + + @Override + public int readMedium() { + return buf.readMedium(); + } + + @Override + public int readMediumLE() { + return buf.readMediumLE(); + } + + @Override + public int readUnsignedMedium() { + return buf.readUnsignedMedium(); + } + + @Override + public int readUnsignedMediumLE() { + return buf.readUnsignedMediumLE(); + } + + @Override + public int readInt() { + return buf.readInt(); + } + + @Override + public int readIntLE() { + return buf.readIntLE(); + } + + @Override + public long readUnsignedInt() { + return buf.readUnsignedInt(); + } + + @Override + public long readUnsignedIntLE() { + return buf.readUnsignedIntLE(); + } + + @Override + public long readLong() { + return buf.readLong(); + } + + @Override + public long readLongLE() { + return buf.readLongLE(); + } + + @Override + public char readChar() { + return buf.readChar(); + } + + @Override + public float readFloat() { + return buf.readFloat(); + } + + @Override + public double readDouble() { + return buf.readDouble(); + } + + @Override + public ByteBuf readBytes(int length) { + return buf.readBytes(length); + } + + @Override + public ByteBuf readSlice(int length) { + return buf.readSlice(length); + } + + @Override + public ByteBuf readRetainedSlice(int length) { + return buf.readRetainedSlice(length); + } + + @Override + public ByteBuf readBytes(ByteBuf dst) { + buf.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int length) { + buf.readBytes(dst, length); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + buf.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst) { + buf.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + buf.readBytes(dst, dstIndex, length); + return 
this; + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + buf.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) throws IOException { + buf.readBytes(out, length); + return this; + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + return buf.readBytes(out, length); + } + + @Override + public int readBytes(FileChannel out, long position, int length) throws IOException { + return buf.readBytes(out, position, length); + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + return buf.readCharSequence(length, charset); + } + + @Override + public ByteBuf skipBytes(int length) { + buf.skipBytes(length); + return this; + } + + @Override + public ByteBuf writeBoolean(boolean value) { + buf.writeBoolean(value); + return this; + } + + @Override + public ByteBuf writeByte(int value) { + buf.writeByte(value); + return this; + } + + @Override + public ByteBuf writeShort(int value) { + buf.writeShort(value); + return this; + } + + @Override + public ByteBuf writeShortLE(int value) { + buf.writeShortLE(value); + return this; + } + + @Override + public ByteBuf writeMedium(int value) { + buf.writeMedium(value); + return this; + } + + @Override + public ByteBuf writeMediumLE(int value) { + buf.writeMediumLE(value); + return this; + } + + @Override + public ByteBuf writeInt(int value) { + buf.writeInt(value); + return this; + } + + @Override + public ByteBuf writeIntLE(int value) { + buf.writeIntLE(value); + return this; + } + + @Override + public ByteBuf writeLong(long value) { + buf.writeLong(value); + return this; + } + + @Override + public ByteBuf writeLongLE(long value) { + buf.writeLongLE(value); + return this; + } + + @Override + public ByteBuf writeChar(int value) { + buf.writeChar(value); + return this; + } + + @Override + public ByteBuf writeFloat(float value) { + buf.writeFloat(value); + return this; + } + + @Override + 
public ByteBuf writeDouble(double value) { + buf.writeDouble(value); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src) { + buf.writeBytes(src); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int length) { + buf.writeBytes(src, length); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + buf.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public ByteBuf writeBytes(byte[] src) { + buf.writeBytes(src); + return this; + } + + @Override + public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { + buf.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public ByteBuf writeBytes(ByteBuffer src) { + buf.writeBytes(src); + return this; + } + + @Override + public int writeBytes(InputStream in, int length) throws IOException { + return buf.writeBytes(in, length); + } + + @Override + public int writeBytes(ScatteringByteChannel in, int length) throws IOException { + return buf.writeBytes(in, length); + } + + @Override + public int writeBytes(FileChannel in, long position, int length) throws IOException { + return buf.writeBytes(in, position, length); + } + + @Override + public ByteBuf writeZero(int length) { + buf.writeZero(length); + return this; + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + return buf.writeCharSequence(sequence, charset); + } + + @Override + public int indexOf(int fromIndex, int toIndex, byte value) { + return buf.indexOf(fromIndex, toIndex, value); + } + + @Override + public int bytesBefore(byte value) { + return buf.bytesBefore(value); + } + + @Override + public int bytesBefore(int length, byte value) { + return buf.bytesBefore(length, value); + } + + @Override + public int bytesBefore(int index, int length, byte value) { + return buf.bytesBefore(index, length, value); + } + + @Override + public int forEachByte(ByteProcessor processor) { + return 
buf.forEachByte(processor); + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + return buf.forEachByte(index, length, processor); + } + + @Override + public int forEachByteDesc(ByteProcessor processor) { + return buf.forEachByteDesc(processor); + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + return buf.forEachByteDesc(index, length, processor); + } + + @Override + public ByteBuf copy() { + return buf.copy(); + } + + @Override + public ByteBuf copy(int index, int length) { + return buf.copy(index, length); + } + + @Override + public ByteBuf slice() { + return buf.slice(); + } + + @Override + public ByteBuf retainedSlice() { + return buf.retainedSlice(); + } + + @Override + public ByteBuf slice(int index, int length) { + return buf.slice(index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + return buf.retainedSlice(index, length); + } + + @Override + public ByteBuf duplicate() { + return buf.duplicate(); + } + + @Override + public ByteBuf retainedDuplicate() { + return buf.retainedDuplicate(); + } + + @Override + public int nioBufferCount() { + return buf.nioBufferCount(); + } + + @Override + public ByteBuffer nioBuffer() { + return buf.nioBuffer(); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + return buf.nioBuffer(index, length); + } + + @Override + public ByteBuffer[] nioBuffers() { + return buf.nioBuffers(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + return buf.nioBuffers(index, length); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + return buf.internalNioBuffer(index, length); + } + + @Override + public boolean hasArray() { + return buf.hasArray(); + } + + @Override + public byte[] array() { + return buf.array(); + } + + @Override + public int arrayOffset() { + return buf.arrayOffset(); + } + + @Override + public String 
toString(Charset charset) { + return buf.toString(charset); + } + + @Override + public String toString(int index, int length, Charset charset) { + return buf.toString(index, length, charset); + } + + @Override + public int hashCode() { + return buf.hashCode(); + } + + @Override + @SuppressWarnings("EqualsWhichDoesntCheckParameterClass") + public boolean equals(Object obj) { + return buf.equals(obj); + } + + @Override + public int compareTo(ByteBuf buffer) { + return buf.compareTo(buffer); + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + '(' + buf.toString() + ')'; + } + + @Override + public ByteBuf retain(int increment) { + buf.retain(increment); + return this; + } + + @Override + public ByteBuf retain() { + buf.retain(); + return this; + } + + @Override + public ByteBuf touch() { + buf.touch(); + return this; + } + + @Override + public ByteBuf touch(Object hint) { + buf.touch(hint); + return this; + } + + @Override + public final boolean isReadable(int size) { + return buf.isReadable(size); + } + + @Override + public final boolean isWritable(int size) { + return buf.isWritable(size); + } + + @Override + public final int refCnt() { + return buf.refCnt(); + } + + @Override + public boolean release() { + return buf.release(); + } + + @Override + public boolean release(int decrement) { + return buf.release(decrement); + } + + @Override + final boolean isAccessible() { + return buf.isAccessible(); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java new file mode 100644 index 0000000..5a31a5e --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/WrappedCompositeByteBuf.java @@ -0,0 +1,1284 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; +import java.util.Iterator; +import java.util.List; + +class WrappedCompositeByteBuf extends CompositeByteBuf { + + private final CompositeByteBuf wrapped; + + WrappedCompositeByteBuf(CompositeByteBuf wrapped) { + super(wrapped.alloc()); + this.wrapped = wrapped; + } + + @Override + public boolean release() { + return wrapped.release(); + } + + @Override + public boolean release(int decrement) { + return wrapped.release(decrement); + } + + @Override + public final int maxCapacity() { + return wrapped.maxCapacity(); + } + + @Override + public final int readerIndex() { + return wrapped.readerIndex(); + } + + @Override + public final int writerIndex() { + return wrapped.writerIndex(); + } + + @Override + public final boolean isReadable() { + return wrapped.isReadable(); + } + + @Override + public final boolean isReadable(int numBytes) { + return wrapped.isReadable(numBytes); + } + + @Override + public final boolean isWritable() { + return wrapped.isWritable(); + } + + @Override + public final boolean isWritable(int numBytes) { + return wrapped.isWritable(numBytes); + } + + @Override + public final int readableBytes() { + return wrapped.readableBytes(); + } + + 
@Override + public final int writableBytes() { + return wrapped.writableBytes(); + } + + @Override + public final int maxWritableBytes() { + return wrapped.maxWritableBytes(); + } + + @Override + public int maxFastWritableBytes() { + return wrapped.maxFastWritableBytes(); + } + + @Override + public int ensureWritable(int minWritableBytes, boolean force) { + return wrapped.ensureWritable(minWritableBytes, force); + } + + @Override + public ByteBuf order(ByteOrder endianness) { + return wrapped.order(endianness); + } + + @Override + public boolean getBoolean(int index) { + return wrapped.getBoolean(index); + } + + @Override + public short getUnsignedByte(int index) { + return wrapped.getUnsignedByte(index); + } + + @Override + public short getShort(int index) { + return wrapped.getShort(index); + } + + @Override + public short getShortLE(int index) { + return wrapped.getShortLE(index); + } + + @Override + public int getUnsignedShort(int index) { + return wrapped.getUnsignedShort(index); + } + + @Override + public int getUnsignedShortLE(int index) { + return wrapped.getUnsignedShortLE(index); + } + + @Override + public int getUnsignedMedium(int index) { + return wrapped.getUnsignedMedium(index); + } + + @Override + public int getUnsignedMediumLE(int index) { + return wrapped.getUnsignedMediumLE(index); + } + + @Override + public int getMedium(int index) { + return wrapped.getMedium(index); + } + + @Override + public int getMediumLE(int index) { + return wrapped.getMediumLE(index); + } + + @Override + public int getInt(int index) { + return wrapped.getInt(index); + } + + @Override + public int getIntLE(int index) { + return wrapped.getIntLE(index); + } + + @Override + public long getUnsignedInt(int index) { + return wrapped.getUnsignedInt(index); + } + + @Override + public long getUnsignedIntLE(int index) { + return wrapped.getUnsignedIntLE(index); + } + + @Override + public long getLong(int index) { + return wrapped.getLong(index); + } + + @Override + public long 
getLongLE(int index) { + return wrapped.getLongLE(index); + } + + @Override + public char getChar(int index) { + return wrapped.getChar(index); + } + + @Override + public float getFloat(int index) { + return wrapped.getFloat(index); + } + + @Override + public double getDouble(int index) { + return wrapped.getDouble(index); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + return wrapped.setShortLE(index, value); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + return wrapped.setMediumLE(index, value); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + return wrapped.setIntLE(index, value); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + return wrapped.setLongLE(index, value); + } + + @Override + public byte readByte() { + return wrapped.readByte(); + } + + @Override + public boolean readBoolean() { + return wrapped.readBoolean(); + } + + @Override + public short readUnsignedByte() { + return wrapped.readUnsignedByte(); + } + + @Override + public short readShort() { + return wrapped.readShort(); + } + + @Override + public short readShortLE() { + return wrapped.readShortLE(); + } + + @Override + public int readUnsignedShort() { + return wrapped.readUnsignedShort(); + } + + @Override + public int readUnsignedShortLE() { + return wrapped.readUnsignedShortLE(); + } + + @Override + public int readMedium() { + return wrapped.readMedium(); + } + + @Override + public int readMediumLE() { + return wrapped.readMediumLE(); + } + + @Override + public int readUnsignedMedium() { + return wrapped.readUnsignedMedium(); + } + + @Override + public int readUnsignedMediumLE() { + return wrapped.readUnsignedMediumLE(); + } + + @Override + public int readInt() { + return wrapped.readInt(); + } + + @Override + public int readIntLE() { + return wrapped.readIntLE(); + } + + @Override + public long readUnsignedInt() { + return wrapped.readUnsignedInt(); + } + + @Override + public long readUnsignedIntLE() 
{ + return wrapped.readUnsignedIntLE(); + } + + @Override + public long readLong() { + return wrapped.readLong(); + } + + @Override + public long readLongLE() { + return wrapped.readLongLE(); + } + + @Override + public char readChar() { + return wrapped.readChar(); + } + + @Override + public float readFloat() { + return wrapped.readFloat(); + } + + @Override + public double readDouble() { + return wrapped.readDouble(); + } + + @Override + public ByteBuf readBytes(int length) { + return wrapped.readBytes(length); + } + + @Override + public ByteBuf slice() { + return wrapped.slice(); + } + + @Override + public ByteBuf retainedSlice() { + return wrapped.retainedSlice(); + } + + @Override + public ByteBuf slice(int index, int length) { + return wrapped.slice(index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + return wrapped.retainedSlice(index, length); + } + + @Override + public ByteBuffer nioBuffer() { + return wrapped.nioBuffer(); + } + + @Override + public String toString(Charset charset) { + return wrapped.toString(charset); + } + + @Override + public String toString(int index, int length, Charset charset) { + return wrapped.toString(index, length, charset); + } + + @Override + public int indexOf(int fromIndex, int toIndex, byte value) { + return wrapped.indexOf(fromIndex, toIndex, value); + } + + @Override + public int bytesBefore(byte value) { + return wrapped.bytesBefore(value); + } + + @Override + public int bytesBefore(int length, byte value) { + return wrapped.bytesBefore(length, value); + } + + @Override + public int bytesBefore(int index, int length, byte value) { + return wrapped.bytesBefore(index, length, value); + } + + @Override + public int forEachByte(ByteProcessor processor) { + return wrapped.forEachByte(processor); + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + return wrapped.forEachByte(index, length, processor); + } + + @Override + public int 
forEachByteDesc(ByteProcessor processor) { + return wrapped.forEachByteDesc(processor); + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + return wrapped.forEachByteDesc(index, length, processor); + } + + @Override + protected int forEachByteAsc0(int start, int end, ByteProcessor processor) throws Exception { + return wrapped.forEachByteAsc0(start, end, processor); + } + + @Override + protected int forEachByteDesc0(int rStart, int rEnd, ByteProcessor processor) throws Exception { + return wrapped.forEachByteDesc0(rStart, rEnd, processor); + } + + @Override + public final int hashCode() { + return wrapped.hashCode(); + } + + @Override + public final boolean equals(Object o) { + return wrapped.equals(o); + } + + @Override + public final int compareTo(ByteBuf that) { + return wrapped.compareTo(that); + } + + @Override + public final int refCnt() { + return wrapped.refCnt(); + } + + @Override + final boolean isAccessible() { + return wrapped.isAccessible(); + } + + @Override + public ByteBuf duplicate() { + return wrapped.duplicate(); + } + + @Override + public ByteBuf retainedDuplicate() { + return wrapped.retainedDuplicate(); + } + + @Override + public ByteBuf readSlice(int length) { + return wrapped.readSlice(length); + } + + @Override + public ByteBuf readRetainedSlice(int length) { + return wrapped.readRetainedSlice(length); + } + + @Override + public int readBytes(GatheringByteChannel out, int length) throws IOException { + return wrapped.readBytes(out, length); + } + + @Override + public ByteBuf writeShortLE(int value) { + return wrapped.writeShortLE(value); + } + + @Override + public ByteBuf writeMediumLE(int value) { + return wrapped.writeMediumLE(value); + } + + @Override + public ByteBuf writeIntLE(int value) { + return wrapped.writeIntLE(value); + } + + @Override + public ByteBuf writeLongLE(long value) { + return wrapped.writeLongLE(value); + } + + @Override + public int writeBytes(InputStream in, int length) 
throws IOException { + return wrapped.writeBytes(in, length); + } + + @Override + public int writeBytes(ScatteringByteChannel in, int length) throws IOException { + return wrapped.writeBytes(in, length); + } + + @Override + public ByteBuf copy() { + return wrapped.copy(); + } + + @Override + public CompositeByteBuf addComponent(ByteBuf buffer) { + wrapped.addComponent(buffer); + return this; + } + + @Override + public CompositeByteBuf addComponents(ByteBuf... buffers) { + wrapped.addComponents(buffers); + return this; + } + + @Override + public CompositeByteBuf addComponents(Iterable buffers) { + wrapped.addComponents(buffers); + return this; + } + + @Override + public CompositeByteBuf addComponent(int cIndex, ByteBuf buffer) { + wrapped.addComponent(cIndex, buffer); + return this; + } + + @Override + public CompositeByteBuf addComponents(int cIndex, ByteBuf... buffers) { + wrapped.addComponents(cIndex, buffers); + return this; + } + + @Override + public CompositeByteBuf addComponents(int cIndex, Iterable buffers) { + wrapped.addComponents(cIndex, buffers); + return this; + } + + @Override + public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer) { + wrapped.addComponent(increaseWriterIndex, buffer); + return this; + } + + @Override + public CompositeByteBuf addComponents(boolean increaseWriterIndex, ByteBuf... 
buffers) { + wrapped.addComponents(increaseWriterIndex, buffers); + return this; + } + + @Override + public CompositeByteBuf addComponents(boolean increaseWriterIndex, Iterable buffers) { + wrapped.addComponents(increaseWriterIndex, buffers); + return this; + } + + @Override + public CompositeByteBuf addComponent(boolean increaseWriterIndex, int cIndex, ByteBuf buffer) { + wrapped.addComponent(increaseWriterIndex, cIndex, buffer); + return this; + } + + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, ByteBuf buffer) { + wrapped.addFlattenedComponents(increaseWriterIndex, buffer); + return this; + } + + @Override + public CompositeByteBuf removeComponent(int cIndex) { + wrapped.removeComponent(cIndex); + return this; + } + + @Override + public CompositeByteBuf removeComponents(int cIndex, int numComponents) { + wrapped.removeComponents(cIndex, numComponents); + return this; + } + + @Override + public Iterator iterator() { + return wrapped.iterator(); + } + + @Override + public List decompose(int offset, int length) { + return wrapped.decompose(offset, length); + } + + @Override + public final boolean isDirect() { + return wrapped.isDirect(); + } + + @Override + public final boolean hasArray() { + return wrapped.hasArray(); + } + + @Override + public final byte[] array() { + return wrapped.array(); + } + + @Override + public final int arrayOffset() { + return wrapped.arrayOffset(); + } + + @Override + public final boolean hasMemoryAddress() { + return wrapped.hasMemoryAddress(); + } + + @Override + public final long memoryAddress() { + return wrapped.memoryAddress(); + } + + @Override + public final int capacity() { + return wrapped.capacity(); + } + + @Override + public CompositeByteBuf capacity(int newCapacity) { + wrapped.capacity(newCapacity); + return this; + } + + @Override + public final ByteBufAllocator alloc() { + return wrapped.alloc(); + } + + @Override + public final ByteOrder order() { + return wrapped.order(); + } 
+ + @Override + public final int numComponents() { + return wrapped.numComponents(); + } + + @Override + public final int maxNumComponents() { + return wrapped.maxNumComponents(); + } + + @Override + public final int toComponentIndex(int offset) { + return wrapped.toComponentIndex(offset); + } + + @Override + public final int toByteIndex(int cIndex) { + return wrapped.toByteIndex(cIndex); + } + + @Override + public byte getByte(int index) { + return wrapped.getByte(index); + } + + @Override + protected final byte _getByte(int index) { + return wrapped._getByte(index); + } + + @Override + protected final short _getShort(int index) { + return wrapped._getShort(index); + } + + @Override + protected final short _getShortLE(int index) { + return wrapped._getShortLE(index); + } + + @Override + protected final int _getUnsignedMedium(int index) { + return wrapped._getUnsignedMedium(index); + } + + @Override + protected final int _getUnsignedMediumLE(int index) { + return wrapped._getUnsignedMediumLE(index); + } + + @Override + protected final int _getInt(int index) { + return wrapped._getInt(index); + } + + @Override + protected final int _getIntLE(int index) { + return wrapped._getIntLE(index); + } + + @Override + protected final long _getLong(int index) { + return wrapped._getLong(index); + } + + @Override + protected final long _getLongLE(int index) { + return wrapped._getLongLE(index); + } + + @Override + public CompositeByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + wrapped.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuffer dst) { + wrapped.getBytes(index, dst); + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + wrapped.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + return 
wrapped.getBytes(index, out, length); + } + + @Override + public CompositeByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + wrapped.getBytes(index, out, length); + return this; + } + + @Override + public CompositeByteBuf setByte(int index, int value) { + wrapped.setByte(index, value); + return this; + } + + @Override + protected final void _setByte(int index, int value) { + wrapped._setByte(index, value); + } + + @Override + public CompositeByteBuf setShort(int index, int value) { + wrapped.setShort(index, value); + return this; + } + + @Override + protected final void _setShort(int index, int value) { + wrapped._setShort(index, value); + } + + @Override + protected final void _setShortLE(int index, int value) { + wrapped._setShortLE(index, value); + } + + @Override + public CompositeByteBuf setMedium(int index, int value) { + wrapped.setMedium(index, value); + return this; + } + + @Override + protected final void _setMedium(int index, int value) { + wrapped._setMedium(index, value); + } + + @Override + protected final void _setMediumLE(int index, int value) { + wrapped._setMediumLE(index, value); + } + + @Override + public CompositeByteBuf setInt(int index, int value) { + wrapped.setInt(index, value); + return this; + } + + @Override + protected final void _setInt(int index, int value) { + wrapped._setInt(index, value); + } + + @Override + protected final void _setIntLE(int index, int value) { + wrapped._setIntLE(index, value); + } + + @Override + public CompositeByteBuf setLong(int index, long value) { + wrapped.setLong(index, value); + return this; + } + + @Override + protected final void _setLong(int index, long value) { + wrapped._setLong(index, value); + } + + @Override + protected final void _setLongLE(int index, long value) { + wrapped._setLongLE(index, value); + } + + @Override + public CompositeByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + wrapped.setBytes(index, src, srcIndex, length); + return this; 
+ } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuffer src) { + wrapped.setBytes(index, src); + return this; + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + wrapped.setBytes(index, src, srcIndex, length); + return this; + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + return wrapped.setBytes(index, in, length); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + return wrapped.setBytes(index, in, length); + } + + @Override + public ByteBuf copy(int index, int length) { + return wrapped.copy(index, length); + } + + @Override + public final ByteBuf component(int cIndex) { + return wrapped.component(cIndex); + } + + @Override + public final ByteBuf componentAtOffset(int offset) { + return wrapped.componentAtOffset(offset); + } + + @Override + public final ByteBuf internalComponent(int cIndex) { + return wrapped.internalComponent(cIndex); + } + + @Override + public final ByteBuf internalComponentAtOffset(int offset) { + return wrapped.internalComponentAtOffset(offset); + } + + @Override + public int nioBufferCount() { + return wrapped.nioBufferCount(); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + return wrapped.internalNioBuffer(index, length); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + return wrapped.nioBuffer(index, length); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + return wrapped.nioBuffers(index, length); + } + + @Override + public CompositeByteBuf consolidate() { + wrapped.consolidate(); + return this; + } + + @Override + public CompositeByteBuf consolidate(int cIndex, int numComponents) { + wrapped.consolidate(cIndex, numComponents); + return this; + } + + @Override + public CompositeByteBuf discardReadComponents() { + wrapped.discardReadComponents(); + return this; + } 
+ + @Override + public CompositeByteBuf discardReadBytes() { + wrapped.discardReadBytes(); + return this; + } + + @Override + public final String toString() { + return wrapped.toString(); + } + + @Override + public final CompositeByteBuf readerIndex(int readerIndex) { + wrapped.readerIndex(readerIndex); + return this; + } + + @Override + public final CompositeByteBuf writerIndex(int writerIndex) { + wrapped.writerIndex(writerIndex); + return this; + } + + @Override + public final CompositeByteBuf setIndex(int readerIndex, int writerIndex) { + wrapped.setIndex(readerIndex, writerIndex); + return this; + } + + @Override + public final CompositeByteBuf clear() { + wrapped.clear(); + return this; + } + + @Override + public final CompositeByteBuf markReaderIndex() { + wrapped.markReaderIndex(); + return this; + } + + @Override + public final CompositeByteBuf resetReaderIndex() { + wrapped.resetReaderIndex(); + return this; + } + + @Override + public final CompositeByteBuf markWriterIndex() { + wrapped.markWriterIndex(); + return this; + } + + @Override + public final CompositeByteBuf resetWriterIndex() { + wrapped.resetWriterIndex(); + return this; + } + + @Override + public CompositeByteBuf ensureWritable(int minWritableBytes) { + wrapped.ensureWritable(minWritableBytes); + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst) { + wrapped.getBytes(index, dst); + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, ByteBuf dst, int length) { + wrapped.getBytes(index, dst, length); + return this; + } + + @Override + public CompositeByteBuf getBytes(int index, byte[] dst) { + wrapped.getBytes(index, dst); + return this; + } + + @Override + public CompositeByteBuf setBoolean(int index, boolean value) { + wrapped.setBoolean(index, value); + return this; + } + + @Override + public CompositeByteBuf setChar(int index, int value) { + wrapped.setChar(index, value); + return this; + } + + @Override + public 
CompositeByteBuf setFloat(int index, float value) { + wrapped.setFloat(index, value); + return this; + } + + @Override + public CompositeByteBuf setDouble(int index, double value) { + wrapped.setDouble(index, value); + return this; + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src) { + wrapped.setBytes(index, src); + return this; + } + + @Override + public CompositeByteBuf setBytes(int index, ByteBuf src, int length) { + wrapped.setBytes(index, src, length); + return this; + } + + @Override + public CompositeByteBuf setBytes(int index, byte[] src) { + wrapped.setBytes(index, src); + return this; + } + + @Override + public CompositeByteBuf setZero(int index, int length) { + wrapped.setZero(index, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst) { + wrapped.readBytes(dst); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst, int length) { + wrapped.readBytes(dst, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + wrapped.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(byte[] dst) { + wrapped.readBytes(dst); + return this; + } + + @Override + public CompositeByteBuf readBytes(byte[] dst, int dstIndex, int length) { + wrapped.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public CompositeByteBuf readBytes(ByteBuffer dst) { + wrapped.readBytes(dst); + return this; + } + + @Override + public CompositeByteBuf readBytes(OutputStream out, int length) throws IOException { + wrapped.readBytes(out, length); + return this; + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + return wrapped.getBytes(index, out, position, length); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + return wrapped.setBytes(index, 
in, position, length); + } + + @Override + public boolean isReadOnly() { + return wrapped.isReadOnly(); + } + + @Override + public ByteBuf asReadOnly() { + return wrapped.asReadOnly(); + } + + @Override + protected SwappedByteBuf newSwappedByteBuf() { + return wrapped.newSwappedByteBuf(); + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + return wrapped.getCharSequence(index, length, charset); + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + return wrapped.readCharSequence(length, charset); + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + return wrapped.setCharSequence(index, sequence, charset); + } + + @Override + public int readBytes(FileChannel out, long position, int length) throws IOException { + return wrapped.readBytes(out, position, length); + } + + @Override + public int writeBytes(FileChannel in, long position, int length) throws IOException { + return wrapped.writeBytes(in, position, length); + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + return wrapped.writeCharSequence(sequence, charset); + } + + @Override + public CompositeByteBuf skipBytes(int length) { + wrapped.skipBytes(length); + return this; + } + + @Override + public CompositeByteBuf writeBoolean(boolean value) { + wrapped.writeBoolean(value); + return this; + } + + @Override + public CompositeByteBuf writeByte(int value) { + wrapped.writeByte(value); + return this; + } + + @Override + public CompositeByteBuf writeShort(int value) { + wrapped.writeShort(value); + return this; + } + + @Override + public CompositeByteBuf writeMedium(int value) { + wrapped.writeMedium(value); + return this; + } + + @Override + public CompositeByteBuf writeInt(int value) { + wrapped.writeInt(value); + return this; + } + + @Override + public CompositeByteBuf writeLong(long value) { + wrapped.writeLong(value); + return this; + } + 
+ @Override + public CompositeByteBuf writeChar(int value) { + wrapped.writeChar(value); + return this; + } + + @Override + public CompositeByteBuf writeFloat(float value) { + wrapped.writeFloat(value); + return this; + } + + @Override + public CompositeByteBuf writeDouble(double value) { + wrapped.writeDouble(value); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src) { + wrapped.writeBytes(src); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src, int length) { + wrapped.writeBytes(src, length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + wrapped.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(byte[] src) { + wrapped.writeBytes(src); + return this; + } + + @Override + public CompositeByteBuf writeBytes(byte[] src, int srcIndex, int length) { + wrapped.writeBytes(src, srcIndex, length); + return this; + } + + @Override + public CompositeByteBuf writeBytes(ByteBuffer src) { + wrapped.writeBytes(src); + return this; + } + + @Override + public CompositeByteBuf writeZero(int length) { + wrapped.writeZero(length); + return this; + } + + @Override + public CompositeByteBuf retain(int increment) { + wrapped.retain(increment); + return this; + } + + @Override + public CompositeByteBuf retain() { + wrapped.retain(); + return this; + } + + @Override + public CompositeByteBuf touch() { + wrapped.touch(); + return this; + } + + @Override + public CompositeByteBuf touch(Object hint) { + wrapped.touch(hint); + return this; + } + + @Override + public ByteBuffer[] nioBuffers() { + return wrapped.nioBuffers(); + } + + @Override + public CompositeByteBuf discardSomeReadBytes() { + wrapped.discardSomeReadBytes(); + return this; + } + + @Override + public final void deallocate() { + wrapped.deallocate(); + } + + @Override + public final ByteBuf unwrap() { + return wrapped; + } +} diff --git 
a/netty-buffer/src/main/java/io/netty/buffer/WrappedUnpooledUnsafeDirectByteBuf.java b/netty-buffer/src/main/java/io/netty/buffer/WrappedUnpooledUnsafeDirectByteBuf.java new file mode 100644 index 0000000..dd8493c --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/WrappedUnpooledUnsafeDirectByteBuf.java @@ -0,0 +1,32 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; + +import java.nio.ByteBuffer; + +final class WrappedUnpooledUnsafeDirectByteBuf extends UnpooledUnsafeDirectByteBuf { + + WrappedUnpooledUnsafeDirectByteBuf(ByteBufAllocator alloc, long memoryAddress, int size, boolean doFree) { + super(alloc, PlatformDependent.directBuffer(memoryAddress, size), size, doFree); + } + + @Override + protected void freeDirect(ByteBuffer buffer) { + PlatformDependent.freeMemory(memoryAddress); + } +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/package-info.java b/netty-buffer/src/main/java/io/netty/buffer/package-info.java new file mode 100644 index 0000000..4ed7939 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/package-info.java @@ -0,0 +1,128 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Abstraction of a byte buffer - the fundamental data structure + * to represent a low-level binary and text message. + * + * Netty uses its own buffer API instead of NIO {@link java.nio.ByteBuffer} to + * represent a sequence of bytes. This approach has significant advantage over + * using {@link java.nio.ByteBuffer}. Netty's new buffer type, + * {@link io.netty.buffer.ByteBuf}, has been designed from ground + * up to address the problems of {@link java.nio.ByteBuffer} and to meet the + * daily needs of network application developers. To list a few cool features: + *
    + *
+ * <ul>
+ *   <li>You can define your buffer type if necessary.</li>
+ *   <li>Transparent zero copy is achieved by built-in composite buffer type.</li>
+ *   <li>A dynamic buffer type is provided out-of-the-box, whose capacity is
+ *       expanded on demand, just like {@link java.lang.StringBuffer}.</li>
+ *   <li>There's no need to call the {@code flip()} method anymore.</li>
+ *   <li>It is often faster than {@link java.nio.ByteBuffer}.</li>
+ * </ul>
  • + *
+ * + *

+ * <h3>Extensibility</h3>

+ * + * {@link io.netty.buffer.ByteBuf} has rich set of operations + * optimized for rapid protocol implementation. For example, + * {@link io.netty.buffer.ByteBuf} provides various operations + * for accessing unsigned values and strings and searching for certain byte + * sequence in a buffer. You can also extend or wrap existing buffer type + * to add convenient accessors. The custom buffer type still implements + * {@link io.netty.buffer.ByteBuf} interface rather than + * introducing an incompatible type. + * + *

+ * <h3>Transparent Zero Copy</h3>

+ * + * To lift up the performance of a network application to the extreme, you need + * to reduce the number of memory copy operation. You might have a set of + * buffers that could be sliced and combined to compose a whole message. Netty + * provides a composite buffer which allows you to create a new buffer from the + * arbitrary number of existing buffers with no memory copy. For example, a + * message could be composed of two parts; header and body. In a modularized + * application, the two parts could be produced by different modules and + * assembled later when the message is sent out. + *
+ * +--------+----------+
+ * | header |   body   |
+ * +--------+----------+
+ * 
+ * If {@link java.nio.ByteBuffer} were used, you would have to create a new big + * buffer and copy the two parts into the new buffer. Alternatively, you can + * perform a gathering write operation in NIO, but it restricts you to represent + * the composite of buffers as an array of {@link java.nio.ByteBuffer}s rather + * than a single buffer, breaking the abstraction and introducing complicated + * state management. Moreover, it's of no use if you are not going to read or + * write from an NIO channel. + *
+ * // The composite type is incompatible with the component type.
+ * ByteBuffer[] message = new ByteBuffer[] { header, body };
+ * 
+ * By contrast, {@link io.netty.buffer.ByteBuf} does not have such + * caveats because it is fully extensible and has a built-in composite buffer + * type. + *
+ * // The composite type is compatible with the component type.
+ * {@link io.netty.buffer.ByteBuf} message = {@link io.netty.buffer.Unpooled}.wrappedBuffer(header, body);
+ *
+ * // Therefore, you can even create a composite by mixing a composite and an
+ * // ordinary buffer.
+ * {@link io.netty.buffer.ByteBuf} messageWithFooter = {@link io.netty.buffer.Unpooled}.wrappedBuffer(message, footer);
+ *
+ * // Because the composite is still a {@link io.netty.buffer.ByteBuf}, you can access its content
+ * // easily, and the accessor method will behave just like it's a single buffer
+ * // even if the region you want to access spans over multiple components.  The
+ * // unsigned integer being read here is located across body and footer.
+ * messageWithFooter.getUnsignedInt(
+ *     messageWithFooter.readableBytes() - footer.readableBytes() - 1);
+ * 
+ * + *

+ * <h3>Automatic Capacity Extension</h3>

+ * + * Many protocols define variable length messages, which means there's no way to + * determine the length of a message until you construct the message or it is + * difficult and inconvenient to calculate the length precisely. It is just + * like when you build a {@link java.lang.String}. You often estimate the length + * of the resulting string and let {@link java.lang.StringBuffer} expand itself + * on demand. + *
+ * // A new dynamic buffer is created.  Internally, the actual buffer is created
+ * // lazily to avoid potentially wasted memory space.
+ * {@link io.netty.buffer.ByteBuf} b = {@link io.netty.buffer.Unpooled}.buffer(4);
+ *
+ * // When the first write attempt is made, the internal buffer is created with
+ * // the specified initial capacity (4).
+ * b.writeByte('1');
+ *
+ * b.writeByte('2');
+ * b.writeByte('3');
+ * b.writeByte('4');
+ *
+ * // When the number of written bytes exceeds the initial capacity (4), the
+ * // internal buffer is reallocated automatically with a larger capacity.
+ * b.writeByte('5');
+ * 
+ * + *

+ * <h3>Better Performance</h3>

+ * + * Most frequently used buffer implementation of + * {@link io.netty.buffer.ByteBuf} is a very thin wrapper of a + * byte array (i.e. {@code byte[]}). Unlike {@link java.nio.ByteBuffer}, it has + * no complicated boundary check and index compensation, and therefore it is + * easier for a JVM to optimize the buffer access. More complicated buffer + * implementation is used only for sliced or composite buffers, and it performs + * as well as {@link java.nio.ByteBuffer}. + */ +package io.netty.buffer; diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/AbstractMultiSearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/AbstractMultiSearchProcessorFactory.java new file mode 100644 index 0000000..e8c6067 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/AbstractMultiSearchProcessorFactory.java @@ -0,0 +1,94 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +/** + * Base class for precomputed factories that create {@link MultiSearchProcessor}s. + *
+ * The purpose of {@link MultiSearchProcessor} is to perform efficient simultaneous search for multiple {@code needles} + * in the {@code haystack}, while scanning every byte of the input sequentially, only once. While it can also be used + * to search for just a single {@code needle}, using a {@link SearchProcessorFactory} would be more efficient for + * doing that. + *
+ * See the documentation of {@link AbstractSearchProcessorFactory} for a comprehensive description of common usage. + * In addition to the functionality provided by {@link SearchProcessor}, {@link MultiSearchProcessor} adds + * a method to get the index of the {@code needle} found at the current position of the {@link MultiSearchProcessor} - + * {@link MultiSearchProcessor#getFoundNeedleId()}. + *
+ * Note: in some cases one {@code needle} can be a suffix of another {@code needle}, eg. {@code {"BC", "ABC"}}, + * and there can potentially be multiple {@code needles} found ending at the same position of the {@code haystack}. + * In such case {@link MultiSearchProcessor#getFoundNeedleId()} returns the index of the longest matching {@code needle} + * in the array of {@code needles}. + *
+ * Usage example (given that the {@code haystack} is a {@link io.netty.buffer.ByteBuf} containing "ABCD" and the + * {@code needles} are "AB", "BC" and "CD"): + *
+ *      MultiSearchProcessorFactory factory = MultiSearchProcessorFactory.newAhoCorasicSearchProcessorFactory(
+ *          "AB".getBytes(CharsetUtil.UTF_8), "BC".getBytes(CharsetUtil.UTF_8), "CD".getBytes(CharsetUtil.UTF_8));
+ *      MultiSearchProcessor processor = factory.newSearchProcessor();
+ *
+ *      int idx1 = haystack.forEachByte(processor);
+ *      // idx1 is 1 (index of the last character of the occurrence of "AB" in the haystack)
+ *      // processor.getFoundNeedleId() is 0 (index of "AB" in needles[])
+ *
+ *      int continueFrom1 = idx1 + 1;
+ *      // continue the search starting from the next character
+ *
+ *      int idx2 = haystack.forEachByte(continueFrom1, haystack.readableBytes() - continueFrom1, processor);
+ *      // idx2 is 2 (index of the last character of the occurrence of "BC" in the haystack)
+ *      // processor.getFoundNeedleId() is 1 (index of "BC" in needles[])
+ *
+ *      int continueFrom2 = idx2 + 1;
+ *
+ *      int idx3 = haystack.forEachByte(continueFrom2, haystack.readableBytes() - continueFrom2, processor);
+ *      // idx3 is 3 (index of the last character of the occurrence of "CD" in the haystack)
+ *      // processor.getFoundNeedleId() is 2 (index of "CD" in needles[])
+ *
+ *      int continueFrom3 = idx3 + 1;
+ *
+ *      int idx4 = haystack.forEachByte(continueFrom3, haystack.readableBytes() - continueFrom3, processor);
+ *      // idx4 is -1 (no more occurrences of any of the needles)
+ *
+ *      // This search session is complete, processor should be discarded.
+ *      // To search for the same needles again, reuse the same {@link AbstractMultiSearchProcessorFactory}
+ *      // to get a new MultiSearchProcessor.
+ * 
+ */ +public abstract class AbstractMultiSearchProcessorFactory implements MultiSearchProcessorFactory { + + /** + * Creates a {@link MultiSearchProcessorFactory} based on + * Aho–Corasick + * string search algorithm. + *
+ * Precomputation (this method) time is linear in the size of input ({@code O(Σ|needles|)}). + *
+ * The factory allocates and retains an array of 256 * X ints plus another array of X ints, where X + * is the sum of lengths of each entry of {@code needles} minus the sum of lengths of repeated + * prefixes of the {@code needles}. + *
+ * Search (the actual application of {@link MultiSearchProcessor}) time is linear in the size of + * {@link io.netty.buffer.ByteBuf} on which the search is performed ({@code O(|haystack|)}). + * Every byte of {@link io.netty.buffer.ByteBuf} is processed only once, sequentually, regardles of + * the number of {@code needles} being searched for. + * + * @param needles a varargs array of arrays of bytes to search for + * @return a new instance of {@link AhoCorasicSearchProcessorFactory} precomputed for the given {@code needles} + */ + public static AhoCorasicSearchProcessorFactory newAhoCorasicSearchProcessorFactory(byte[] ...needles) { + return new AhoCorasicSearchProcessorFactory(needles); + } + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/AbstractSearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/AbstractSearchProcessorFactory.java new file mode 100644 index 0000000..a044e24 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/AbstractSearchProcessorFactory.java @@ -0,0 +1,115 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +/** + * Base class for precomputed factories that create {@link SearchProcessor}s. + *
+ * Different factories implement different search algorithms with performance characteristics that + * depend on a use case, so it is advisable to benchmark a concrete use case with different algorithms + * before choosing one of them. + *
+ * A concrete instance of {@link AbstractSearchProcessorFactory} is built for searching for a concrete sequence of bytes + * (the {@code needle}), it contains precomputed data needed to perform the search, and is meant to be reused + * whenever searching for the same {@code needle}. + *
+ * Note: implementations of {@link SearchProcessor} scan the {@link io.netty.buffer.ByteBuf} sequentially, + * one byte after another, without doing any random access. As a result, when using {@link SearchProcessor} + * with such methods as {@link io.netty.buffer.ByteBuf#forEachByte}, these methods return the index of the last byte + * of the found byte sequence within the {@link io.netty.buffer.ByteBuf} (which might feel counterintuitive, + * and different from {@link io.netty.buffer.ByteBufUtil#indexOf} which returns the index of the first byte + * of found sequence). + *
+ * A {@link SearchProcessor} is implemented as a + * Finite State Automaton that contains a + * small internal state which is updated with every byte processed. As a result, an instance of {@link SearchProcessor} + * should not be reused across independent search sessions (eg. for searching in different + * {@link io.netty.buffer.ByteBuf}s). A new instance should be created with {@link AbstractSearchProcessorFactory} for + * every search session. However, a {@link SearchProcessor} can (and should) be reused within the search session, + * eg. when searching for all occurrences of the {@code needle} within the same {@code haystack}. That way, it can + * also detect overlapping occurrences of the {@code needle} (eg. a string "ABABAB" contains two occurrences of "BAB" + * that overlap by one character "B"). For this to work correctly, after an occurrence of the {@code needle} is + * found ending at index {@code idx}, the search should continue starting from the index {@code idx + 1}. + *
+ * Example (given that the {@code haystack} is a {@link io.netty.buffer.ByteBuf} containing "ABABAB" and + * the {@code needle} is "BAB"): + *
+ *     SearchProcessorFactory factory =
+ *         SearchProcessorFactory.newKmpSearchProcessorFactory(needle.getBytes(CharsetUtil.UTF_8));
+ *     SearchProcessor processor = factory.newSearchProcessor();
+ *
+ *     int idx1 = haystack.forEachByte(processor);
+ *     // idx1 is 3 (index of the last character of the first occurrence of the needle in the haystack)
+ *
+ *     int continueFrom1 = idx1 + 1;
+ *     // continue the search starting from the next character
+ *
+ *     int idx2 = haystack.forEachByte(continueFrom1, haystack.readableBytes() - continueFrom1, processor);
+ *     // idx2 is 5 (index of the last character of the second occurrence of the needle in the haystack)
+ *
+ *     int continueFrom2 = idx2 + 1;
+ *     // continue the search starting from the next character
+ *
+ *     int idx3 = haystack.forEachByte(continueFrom2, haystack.readableBytes() - continueFrom2, processor);
+ *     // idx3 is -1 (no more occurrences of the needle)
+ *
+ *     // After this search session is complete, processor should be discarded.
+ *     // To search for the same needle again, reuse the same factory to get a new SearchProcessor.
+ * 
+ */ +public abstract class AbstractSearchProcessorFactory implements SearchProcessorFactory { + + /** + * Creates a {@link SearchProcessorFactory} based on + * Knuth-Morris-Pratt + * string search algorithm. It is a reasonable default choice among the provided algorithms. + *
+ * Precomputation (this method) time is linear in the size of input ({@code O(|needle|)}). + *
+ * The factory allocates and retains an int array of size {@code needle.length + 1}, and retains a reference + * to the {@code needle} itself. + *
+ * Search (the actual application of {@link SearchProcessor}) time is linear in the size of + * {@link io.netty.buffer.ByteBuf} on which the search is performed ({@code O(|haystack|)}). + * Every byte of {@link io.netty.buffer.ByteBuf} is processed only once, sequentially. + * + * @param needle an array of bytes to search for + * @return a new instance of {@link KmpSearchProcessorFactory} precomputed for the given {@code needle} + */ + public static KmpSearchProcessorFactory newKmpSearchProcessorFactory(byte[] needle) { + return new KmpSearchProcessorFactory(needle); + } + + /** + * Creates a {@link SearchProcessorFactory} based on Bitap string search algorithm. + * It is a jump free algorithm that has very stable performance (the contents of the inputs have a minimal + * effect on it). The limitation is that the {@code needle} can be no more than 64 bytes long. + *
+ * Precomputation (this method) time is linear in the size of the input ({@code O(|needle|)}). + *
+ * The factory allocates and retains a long[256] array. + *
+ * Search (the actual application of {@link SearchProcessor}) time is linear in the size of + * {@link io.netty.buffer.ByteBuf} on which the search is performed ({@code O(|haystack|)}). + * Every byte of {@link io.netty.buffer.ByteBuf} is processed only once, sequentially. + * + * @param needle an array of no more than 64 bytes to search for + * @return a new instance of {@link BitapSearchProcessorFactory} precomputed for the given {@code needle} + */ + public static BitapSearchProcessorFactory newBitapSearchProcessorFactory(byte[] needle) { + return new BitapSearchProcessorFactory(needle); + } + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/AhoCorasicSearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/AhoCorasicSearchProcessorFactory.java new file mode 100644 index 0000000..5ee27c1 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/AhoCorasicSearchProcessorFactory.java @@ -0,0 +1,191 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +import io.netty.util.internal.PlatformDependent; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Queue; + +/** + * Implements Aho–Corasick + * string search algorithm. 
+ * Use static {@link AbstractMultiSearchProcessorFactory#newAhoCorasicSearchProcessorFactory} + * to create an instance of this factory. + * Use {@link AhoCorasicSearchProcessorFactory#newSearchProcessor} to get an instance of + * {@link io.netty.util.ByteProcessor} implementation for performing the actual search. + * @see AbstractMultiSearchProcessorFactory + */ +public class AhoCorasicSearchProcessorFactory extends AbstractMultiSearchProcessorFactory { + + private final int[] jumpTable; + private final int[] matchForNeedleId; + + static final int BITS_PER_SYMBOL = 8; + static final int ALPHABET_SIZE = 1 << BITS_PER_SYMBOL; + + private static class Context { + int[] jumpTable; + int[] matchForNeedleId; + } + + public static class Processor implements MultiSearchProcessor { + + private final int[] jumpTable; + private final int[] matchForNeedleId; + private long currentPosition; + + Processor(int[] jumpTable, int[] matchForNeedleId) { + this.jumpTable = jumpTable; + this.matchForNeedleId = matchForNeedleId; + } + + @Override + public boolean process(byte value) { + currentPosition = PlatformDependent.getInt(jumpTable, currentPosition | (value & 0xffL)); + if (currentPosition < 0) { + currentPosition = -currentPosition; + return false; + } + return true; + } + + @Override + public int getFoundNeedleId() { + return matchForNeedleId[(int) currentPosition >> AhoCorasicSearchProcessorFactory.BITS_PER_SYMBOL]; + } + + @Override + public void reset() { + currentPosition = 0; + } + } + + AhoCorasicSearchProcessorFactory(byte[] ...needles) { + + for (byte[] needle: needles) { + if (needle.length == 0) { + throw new IllegalArgumentException("Needle must be non empty"); + } + } + + Context context = buildTrie(needles); + jumpTable = context.jumpTable; + matchForNeedleId = context.matchForNeedleId; + + linkSuffixes(); + + for (int i = 0; i < jumpTable.length; i++) { + if (matchForNeedleId[jumpTable[i] >> BITS_PER_SYMBOL] >= 0) { + jumpTable[i] = -jumpTable[i]; + } + } + } + + 
private static Context buildTrie(byte[][] needles) { + + ArrayList jumpTableBuilder = new ArrayList(ALPHABET_SIZE); + for (int i = 0; i < ALPHABET_SIZE; i++) { + jumpTableBuilder.add(-1); + } + + ArrayList matchForBuilder = new ArrayList(); + matchForBuilder.add(-1); + + for (int needleId = 0; needleId < needles.length; needleId++) { + byte[] needle = needles[needleId]; + int currentPosition = 0; + + for (byte ch0: needle) { + + final int ch = ch0 & 0xff; + final int next = currentPosition + ch; + + if (jumpTableBuilder.get(next) == -1) { + jumpTableBuilder.set(next, jumpTableBuilder.size()); + for (int i = 0; i < ALPHABET_SIZE; i++) { + jumpTableBuilder.add(-1); + } + matchForBuilder.add(-1); + } + + currentPosition = jumpTableBuilder.get(next); + } + + matchForBuilder.set(currentPosition >> BITS_PER_SYMBOL, needleId); + } + + Context context = new Context(); + + context.jumpTable = new int[jumpTableBuilder.size()]; + for (int i = 0; i < jumpTableBuilder.size(); i++) { + context.jumpTable[i] = jumpTableBuilder.get(i); + } + + context.matchForNeedleId = new int[matchForBuilder.size()]; + for (int i = 0; i < matchForBuilder.size(); i++) { + context.matchForNeedleId[i] = matchForBuilder.get(i); + } + + return context; + } + + private void linkSuffixes() { + + Queue queue = new ArrayDeque(); + queue.add(0); + + int[] suffixLinks = new int[matchForNeedleId.length]; + Arrays.fill(suffixLinks, -1); + + while (!queue.isEmpty()) { + + final int v = queue.remove(); + int vPosition = v >> BITS_PER_SYMBOL; + final int u = suffixLinks[vPosition] == -1 ? 0 : suffixLinks[vPosition]; + + if (matchForNeedleId[vPosition] == -1) { + matchForNeedleId[vPosition] = matchForNeedleId[u >> BITS_PER_SYMBOL]; + } + + for (int ch = 0; ch < ALPHABET_SIZE; ch++) { + + final int vIndex = v | ch; + final int uIndex = u | ch; + + final int jumpV = jumpTable[vIndex]; + final int jumpU = jumpTable[uIndex]; + + if (jumpV != -1) { + suffixLinks[jumpV >> BITS_PER_SYMBOL] = v > 0 && jumpU != -1 ? 
jumpU : 0; + queue.add(jumpV); + } else { + jumpTable[vIndex] = jumpU != -1 ? jumpU : 0; + } + } + } + } + + /** + * Returns a new {@link Processor}. + */ + @Override + public Processor newSearchProcessor() { + return new Processor(jumpTable, matchForNeedleId); + } + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/BitapSearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/BitapSearchProcessorFactory.java new file mode 100644 index 0000000..bb4a7c5 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/BitapSearchProcessorFactory.java @@ -0,0 +1,77 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +import io.netty.util.internal.PlatformDependent; + +/** + * Implements Bitap string search algorithm. + * Use static {@link AbstractSearchProcessorFactory#newBitapSearchProcessorFactory} + * to create an instance of this factory. + * Use {@link BitapSearchProcessorFactory#newSearchProcessor} to get an instance of {@link io.netty.util.ByteProcessor} + * implementation for performing the actual search. 
+ * @see AbstractSearchProcessorFactory + */ +public class BitapSearchProcessorFactory extends AbstractSearchProcessorFactory { + + private final long[] bitMasks = new long[256]; + private final long successBit; + + public static class Processor implements SearchProcessor { + + private final long[] bitMasks; + private final long successBit; + private long currentMask; + + Processor(long[] bitMasks, long successBit) { + this.bitMasks = bitMasks; + this.successBit = successBit; + } + + @Override + public boolean process(byte value) { + currentMask = ((currentMask << 1) | 1) & PlatformDependent.getLong(bitMasks, value & 0xffL); + return (currentMask & successBit) == 0; + } + + @Override + public void reset() { + currentMask = 0; + } + } + + BitapSearchProcessorFactory(byte[] needle) { + if (needle.length > 64) { + throw new IllegalArgumentException("Maximum supported search pattern length is 64, got " + needle.length); + } + + long bit = 1L; + for (byte c: needle) { + bitMasks[c & 0xff] |= bit; + bit <<= 1; + } + + successBit = 1L << (needle.length - 1); + } + + /** + * Returns a new {@link Processor}. + */ + @Override + public Processor newSearchProcessor() { + return new Processor(bitMasks, successBit); + } + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/KmpSearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/KmpSearchProcessorFactory.java new file mode 100644 index 0000000..5b16b7f --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/KmpSearchProcessorFactory.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +import io.netty.util.internal.PlatformDependent; + +/** + * Implements + * Knuth-Morris-Pratt + * string search algorithm. + * Use static {@link AbstractSearchProcessorFactory#newKmpSearchProcessorFactory} + * to create an instance of this factory. + * Use {@link KmpSearchProcessorFactory#newSearchProcessor} to get an instance of {@link io.netty.util.ByteProcessor} + * implementation for performing the actual search. + * @see AbstractSearchProcessorFactory + */ +public class KmpSearchProcessorFactory extends AbstractSearchProcessorFactory { + + private final int[] jumpTable; + private final byte[] needle; + + public static class Processor implements SearchProcessor { + + private final byte[] needle; + private final int[] jumpTable; + private long currentPosition; + + Processor(byte[] needle, int[] jumpTable) { + this.needle = needle; + this.jumpTable = jumpTable; + } + + @Override + public boolean process(byte value) { + while (currentPosition > 0 && PlatformDependent.getByte(needle, currentPosition) != value) { + currentPosition = PlatformDependent.getInt(jumpTable, currentPosition); + } + if (PlatformDependent.getByte(needle, currentPosition) == value) { + currentPosition++; + } + if (currentPosition == needle.length) { + currentPosition = PlatformDependent.getInt(jumpTable, currentPosition); + return false; + } + + return true; + } + + @Override + public void reset() { + currentPosition = 0; + } + } + + KmpSearchProcessorFactory(byte[] needle) { + this.needle = needle.clone(); + this.jumpTable = new 
int[needle.length + 1]; + + int j = 0; + for (int i = 1; i < needle.length; i++) { + while (j > 0 && needle[j] != needle[i]) { + j = jumpTable[j]; + } + if (needle[j] == needle[i]) { + j++; + } + jumpTable[i + 1] = j; + } + } + + /** + * Returns a new {@link Processor}. + */ + @Override + public Processor newSearchProcessor() { + return new Processor(needle, jumpTable); + } + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessor.java b/netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessor.java new file mode 100644 index 0000000..f7e9987 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessor.java @@ -0,0 +1,28 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +/** + * Interface for {@link SearchProcessor} that implements simultaneous search for multiple strings. 
+ * @see MultiSearchProcessorFactory + */ +public interface MultiSearchProcessor extends SearchProcessor { + + /** + * @return the index of found search string (if any, or -1 if none) at current position of this MultiSearchProcessor + */ + int getFoundNeedleId(); + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessorFactory.java new file mode 100644 index 0000000..176ea8a --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/MultiSearchProcessorFactory.java @@ -0,0 +1,25 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +public interface MultiSearchProcessorFactory extends SearchProcessorFactory { + + /** + * Returns a new {@link MultiSearchProcessor}. 
+ */ + @Override + MultiSearchProcessor newSearchProcessor(); + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessor.java b/netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessor.java new file mode 100644 index 0000000..baefd25 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessor.java @@ -0,0 +1,30 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +import io.netty.util.ByteProcessor; + +/** + * Interface for {@link ByteProcessor} that implements string search. + * @see SearchProcessorFactory + */ +public interface SearchProcessor extends ByteProcessor { + + /** + * Resets the state of SearchProcessor. + */ + void reset(); + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessorFactory.java b/netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessorFactory.java new file mode 100644 index 0000000..17679d7 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/SearchProcessorFactory.java @@ -0,0 +1,24 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.buffer.search; + +public interface SearchProcessorFactory { + + /** + * Returns a new {@link SearchProcessor}. + */ + SearchProcessor newSearchProcessor(); + +} diff --git a/netty-buffer/src/main/java/io/netty/buffer/search/package-info.java b/netty-buffer/src/main/java/io/netty/buffer/search/package-info.java new file mode 100644 index 0000000..630e341 --- /dev/null +++ b/netty-buffer/src/main/java/io/netty/buffer/search/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Utility classes for performing efficient substring search within {@link io.netty.buffer.ByteBuf}. 
+ */ +package io.netty.buffer.search; diff --git a/netty-buffer/src/main/java/module-info.java b/netty-buffer/src/main/java/module-info.java new file mode 100644 index 0000000..29639d5 --- /dev/null +++ b/netty-buffer/src/main/java/module-info.java @@ -0,0 +1,5 @@ +module org.xbib.io.netty.buffer { + exports io.netty.buffer; + exports io.netty.buffer.search; + requires org.xbib.io.netty.util; +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java b/netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java new file mode 100644 index 0000000..e6e064b --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufAllocatorTest.java @@ -0,0 +1,145 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public abstract class AbstractByteBufAllocatorTest extends ByteBufAllocatorTest { + + @Override + protected abstract T newAllocator(boolean preferDirect); + + protected abstract T newUnpooledAllocator(); + + @Override + protected boolean isDirectExpected(boolean preferDirect) { + return preferDirect && PlatformDependent.hasUnsafe(); + } + + @Override + protected final int defaultMaxCapacity() { + return AbstractByteBufAllocator.DEFAULT_MAX_CAPACITY; + } + + @Override + protected final int defaultMaxComponents() { + return AbstractByteBufAllocator.DEFAULT_MAX_COMPONENTS; + } + + @Test + public void testCalculateNewCapacity() { + testCalculateNewCapacity(true); + testCalculateNewCapacity(false); + } + + private void testCalculateNewCapacity(boolean preferDirect) { + T allocator = newAllocator(preferDirect); + assertEquals(8, allocator.calculateNewCapacity(1, 8)); + assertEquals(7, allocator.calculateNewCapacity(1, 7)); + assertEquals(64, allocator.calculateNewCapacity(1, 129)); + assertEquals(AbstractByteBufAllocator.CALCULATE_THRESHOLD, + allocator.calculateNewCapacity(AbstractByteBufAllocator.CALCULATE_THRESHOLD, + AbstractByteBufAllocator.CALCULATE_THRESHOLD + 1)); + assertEquals(AbstractByteBufAllocator.CALCULATE_THRESHOLD * 2, + allocator.calculateNewCapacity(AbstractByteBufAllocator.CALCULATE_THRESHOLD + 1, + AbstractByteBufAllocator.CALCULATE_THRESHOLD * 4)); + try { + allocator.calculateNewCapacity(8, 7); + fail(); + } catch (IllegalArgumentException e) { + // expected + } + + try { + allocator.calculateNewCapacity(-1, 8); + fail(); + } catch (IllegalArgumentException e) { + // expected + } + } + + @Test + public void testUnsafeHeapBufferAndUnsafeDirectBuffer() { + 
T allocator = newUnpooledAllocator(); + ByteBuf directBuffer = allocator.directBuffer(); + assertInstanceOf(directBuffer, + PlatformDependent.hasUnsafe() ? UnpooledUnsafeDirectByteBuf.class : UnpooledDirectByteBuf.class); + directBuffer.release(); + + ByteBuf heapBuffer = allocator.heapBuffer(); + assertInstanceOf(heapBuffer, + PlatformDependent.hasUnsafe() ? UnpooledUnsafeHeapByteBuf.class : UnpooledHeapByteBuf.class); + heapBuffer.release(); + } + + protected static void assertInstanceOf(ByteBuf buffer, Class clazz) { + // Unwrap if needed + assertTrue(clazz.isInstance(buffer instanceof SimpleLeakAwareByteBuf ? buffer.unwrap() : buffer)); + } + + @Test + public void testUsedDirectMemory() { + T allocator = newAllocator(true); + ByteBufAllocatorMetric metric = ((ByteBufAllocatorMetricProvider) allocator).metric(); + assertEquals(0, metric.usedDirectMemory()); + ByteBuf buffer = allocator.directBuffer(1024, 4096); + int capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedDirectMemory()); + + // Double the size of the buffer + buffer.capacity(capacity << 1); + capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedDirectMemory(), buffer.toString()); + + buffer.release(); + assertEquals(expectedUsedMemoryAfterRelease(allocator, capacity), metric.usedDirectMemory()); + } + + @Test + public void testUsedHeapMemory() { + T allocator = newAllocator(true); + ByteBufAllocatorMetric metric = ((ByteBufAllocatorMetricProvider) allocator).metric(); + + assertEquals(0, metric.usedHeapMemory()); + ByteBuf buffer = allocator.heapBuffer(1024, 4096); + int capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedHeapMemory()); + + // Double the size of the buffer + buffer.capacity(capacity << 1); + capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedHeapMemory()); + + buffer.release(); + 
assertEquals(expectedUsedMemoryAfterRelease(allocator, capacity), metric.usedHeapMemory()); + } + + protected long expectedUsedMemory(T allocator, int capacity) { + return capacity; + } + + protected long expectedUsedMemoryAfterRelease(T allocator, int capacity) { + return 0; + } + + protected void trimCaches(T allocator) { + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java new file mode 100644 index 0000000..6a2c98f --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AbstractByteBufTest.java @@ -0,0 +1,5992 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.CharBuffer; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.Channels; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.channels.WritableByteChannel; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static io.netty.buffer.Unpooled.LITTLE_ENDIAN; +import static io.netty.buffer.Unpooled.buffer; +import static io.netty.buffer.Unpooled.copiedBuffer; +import static io.netty.buffer.Unpooled.directBuffer; +import static io.netty.buffer.Unpooled.unreleasableBuffer; +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static io.netty.util.internal.EmptyArrays.EMPTY_BYTES; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +/** + * An abstract test class for channel buffers + */ +public abstract class AbstractByteBufTest { + + private static final int CAPACITY = 4096; // Must be even + private static final int BLOCK_SIZE = 128; + private static final int JAVA_BYTEBUFFER_CONSISTENCY_ITERATIONS = 100; + + private long seed; + private Random random; + private ByteBuf buffer; + + protected final ByteBuf newBuffer(int capacity) { + return newBuffer(capacity, Integer.MAX_VALUE); + } + + protected abstract ByteBuf newBuffer(int capacity, int maxCapacity); + + protected boolean discardReadBytesDoesNotMoveWritableBytes() { + return true; + } + + @BeforeEach + public void init() { + buffer = newBuffer(CAPACITY); + seed = System.currentTimeMillis(); + random = new Random(seed); + } + + @AfterEach + public void dispose() { + if (buffer != null) { + assertThat(buffer.release(), is(true)); + assertThat(buffer.refCnt(), is(0)); + + try { + buffer.release(); + } catch (Exception e) { + // Ignore. + } + buffer = null; + } + } + + @Test + public void comparableInterfaceNotViolated() { + assumeFalse(buffer.isReadOnly()); + buffer.writerIndex(buffer.readerIndex()); + assumeTrue(buffer.writableBytes() >= 4); + + buffer.writeLong(0); + ByteBuf buffer2 = newBuffer(CAPACITY); + assumeFalse(buffer2.isReadOnly()); + buffer2.writerIndex(buffer2.readerIndex()); + // Write an unsigned integer that will cause buffer.getUnsignedInt() - buffer2.getUnsignedInt() to underflow the + // int type and wrap around on the negative side. 
+ buffer2.writeLong(0xF0000000L); + assertTrue(buffer.compareTo(buffer2) < 0); + assertTrue(buffer2.compareTo(buffer) > 0); + buffer2.release(); + } + + @Test + public void initialState() { + assertEquals(CAPACITY, buffer.capacity()); + assertEquals(0, buffer.readerIndex()); + } + + @Test + public void readerIndexBoundaryCheck1() { + try { + buffer.writerIndex(0); + } catch (IndexOutOfBoundsException e) { + fail(); + } + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.readerIndex(-1); + } + }); + } + + @Test + public void readerIndexBoundaryCheck2() { + try { + buffer.writerIndex(buffer.capacity()); + } catch (IndexOutOfBoundsException e) { + fail(); + } + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.readerIndex(buffer.capacity() + 1); + } + }); + } + + @Test + public void readerIndexBoundaryCheck3() { + try { + buffer.writerIndex(CAPACITY / 2); + } catch (IndexOutOfBoundsException e) { + fail(); + } + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.readerIndex(CAPACITY * 3 / 2); + } + }); + } + + @Test + public void readerIndexBoundaryCheck4() { + buffer.writerIndex(0); + buffer.readerIndex(0); + buffer.writerIndex(buffer.capacity()); + buffer.readerIndex(buffer.capacity()); + } + + @Test + public void writerIndexBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.writerIndex(-1); + } + }); + } + + @Test + public void writerIndexBoundaryCheck2() { + try { + buffer.writerIndex(CAPACITY); + buffer.readerIndex(CAPACITY); + } catch (IndexOutOfBoundsException e) { + fail(); + } + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.writerIndex(buffer.capacity() + 1); + } + }); + } + + @Test + public void writerIndexBoundaryCheck3() { + try 
{ + buffer.writerIndex(CAPACITY); + buffer.readerIndex(CAPACITY / 2); + } catch (IndexOutOfBoundsException e) { + fail(); + } + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.writerIndex(CAPACITY / 4); + } + }); + } + + @Test + public void writerIndexBoundaryCheck4() { + buffer.writerIndex(0); + buffer.readerIndex(0); + buffer.writerIndex(CAPACITY); + + buffer.writeBytes(ByteBuffer.wrap(EMPTY_BYTES)); + } + + @Test + public void getBooleanBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBoolean(-1); + } + }); + } + + @Test + public void getBooleanBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBoolean(buffer.capacity()); + } + }); + } + + @Test + public void getByteBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getByte(-1); + } + }); + } + + @Test + public void getByteBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getByte(buffer.capacity()); + } + }); + } + + @Test + public void getShortBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getShort(-1); + } + }); + } + + @Test + public void getShortBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getShort(buffer.capacity() - 1); + } + }); + } + + @Test + public void getMediumBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getMedium(-1); + } + }); + } + + @Test + public void getMediumBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { 
+ @Override + public void execute() { + buffer.getMedium(buffer.capacity() - 2); + } + }); + } + + @Test + public void getIntBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getInt(-1); + } + }); + } + + @Test + public void getIntBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getInt(buffer.capacity() - 3); + } + }); + } + + @Test + public void getLongBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getLong(-1); + } + }); + } + + @Test + public void getLongBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getLong(buffer.capacity() - 7); + } + }); + } + + @Test + public void getByteArrayBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBytes(-1, EMPTY_BYTES); + } + }); + } + + @Test + public void getByteArrayBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBytes(-1, EMPTY_BYTES, 0, 0); + } + }); + } + + @Test + public void getByteArrayBoundaryCheck3() { + byte[] dst = new byte[4]; + buffer.setInt(0, 0x01020304); + try { + buffer.getBytes(0, dst, -1, 4); + fail(); + } catch (IndexOutOfBoundsException e) { + // Success + } + + // No partial copy is expected. + assertEquals(0, dst[0]); + assertEquals(0, dst[1]); + assertEquals(0, dst[2]); + assertEquals(0, dst[3]); + } + + @Test + public void getByteArrayBoundaryCheck4() { + byte[] dst = new byte[4]; + buffer.setInt(0, 0x01020304); + try { + buffer.getBytes(0, dst, 1, 4); + fail(); + } catch (IndexOutOfBoundsException e) { + // Success + } + + // No partial copy is expected. 
+ assertEquals(0, dst[0]); + assertEquals(0, dst[1]); + assertEquals(0, dst[2]); + assertEquals(0, dst[3]); + } + + @Test + public void getByteBufferBoundaryCheck() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBytes(-1, ByteBuffer.allocate(0)); + } + }); + } + + @Test + public void copyBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.copy(-1, 0); + } + }); + } + + @Test + public void copyBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.copy(0, buffer.capacity() + 1); + } + }); + } + + @Test + public void copyBoundaryCheck3() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.copy(buffer.capacity() + 1, 0); + } + }); + } + + @Test + public void copyBoundaryCheck4() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.copy(buffer.capacity(), 1); + } + }); + } + + @Test + public void setIndexBoundaryCheck1() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.setIndex(-1, CAPACITY); + } + }); + } + + @Test + public void setIndexBoundaryCheck2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.setIndex(CAPACITY / 2, CAPACITY / 4); + } + }); + } + + @Test + public void setIndexBoundaryCheck3() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.setIndex(0, CAPACITY + 1); + } + }); + } + + @Test + public void getByteBufferState() { + ByteBuffer dst = ByteBuffer.allocate(4); + dst.position(1); + dst.limit(3); + + buffer.setByte(0, (byte) 1); + buffer.setByte(1, (byte) 2); + buffer.setByte(2, (byte) 3); + 
buffer.setByte(3, (byte) 4); + buffer.getBytes(1, dst); + + assertEquals(3, dst.position()); + assertEquals(3, dst.limit()); + + dst.clear(); + assertEquals(0, dst.get(0)); + assertEquals(2, dst.get(1)); + assertEquals(3, dst.get(2)); + assertEquals(0, dst.get(3)); + } + + @Test + public void getDirectByteBufferBoundaryCheck() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBytes(-1, ByteBuffer.allocateDirect(0)); + } + }); + } + + @Test + public void getDirectByteBufferState() { + ByteBuffer dst = ByteBuffer.allocateDirect(4); + dst.position(1); + dst.limit(3); + + buffer.setByte(0, (byte) 1); + buffer.setByte(1, (byte) 2); + buffer.setByte(2, (byte) 3); + buffer.setByte(3, (byte) 4); + buffer.getBytes(1, dst); + + assertEquals(3, dst.position()); + assertEquals(3, dst.limit()); + + dst.clear(); + assertEquals(0, dst.get(0)); + assertEquals(2, dst.get(1)); + assertEquals(3, dst.get(2)); + assertEquals(0, dst.get(3)); + } + + @Test + public void testRandomByteAccess() { + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + buffer.setByte(i, value); + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + assertEquals(value, buffer.getByte(i)); + } + } + + @Test + public void testRandomUnsignedByteAccess() { + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + buffer.setByte(i, value); + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i ++) { + int value = random.nextInt() & 0xFF; + assertEquals(value, buffer.getUnsignedByte(i)); + } + } + + @Test + public void testRandomShortAccess() { + testRandomShortAccess(true); + } + @Test + public void testRandomShortLEAccess() { + testRandomShortAccess(false); + } + + private void testRandomShortAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 1; i += 2) { + short value = 
(short) random.nextInt(); + if (testBigEndian) { + buffer.setShort(i, value); + } else { + buffer.setShortLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 1; i += 2) { + short value = (short) random.nextInt(); + if (testBigEndian) { + assertEquals(value, buffer.getShort(i)); + } else { + assertEquals(value, buffer.getShortLE(i)); + } + } + } + + @Test + public void testShortConsistentWithByteBuffer() { + testShortConsistentWithByteBuffer(true, true); + testShortConsistentWithByteBuffer(true, false); + testShortConsistentWithByteBuffer(false, true); + testShortConsistentWithByteBuffer(false, false); + } + + private void testShortConsistentWithByteBuffer(boolean direct, boolean testBigEndian) { + for (int i = 0; i < JAVA_BYTEBUFFER_CONSISTENCY_ITERATIONS; ++i) { + ByteBuffer javaBuffer = direct ? ByteBuffer.allocateDirect(buffer.capacity()) + : ByteBuffer.allocate(buffer.capacity()); + if (!testBigEndian) { + javaBuffer = javaBuffer.order(ByteOrder.LITTLE_ENDIAN); + } + + short expected = (short) (random.nextInt() & 0xFFFF); + javaBuffer.putShort(expected); + + final int bufferIndex = buffer.capacity() - 2; + if (testBigEndian) { + buffer.setShort(bufferIndex, expected); + } else { + buffer.setShortLE(bufferIndex, expected); + } + javaBuffer.flip(); + + short javaActual = javaBuffer.getShort(); + assertEquals(expected, javaActual); + assertEquals(javaActual, testBigEndian ? 
buffer.getShort(bufferIndex) + : buffer.getShortLE(bufferIndex)); + } + } + + @Test + public void testRandomUnsignedShortAccess() { + testRandomUnsignedShortAccess(true); + } + + @Test + public void testRandomUnsignedShortLEAccess() { + testRandomUnsignedShortAccess(false); + } + + private void testRandomUnsignedShortAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 1; i += 2) { + short value = (short) random.nextInt(); + if (testBigEndian) { + buffer.setShort(i, value); + } else { + buffer.setShortLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 1; i += 2) { + int value = random.nextInt() & 0xFFFF; + if (testBigEndian) { + assertEquals(value, buffer.getUnsignedShort(i)); + } else { + assertEquals(value, buffer.getUnsignedShortLE(i)); + } + } + } + + @Test + public void testRandomMediumAccess() { + testRandomMediumAccess(true); + } + + @Test + public void testRandomMediumLEAccess() { + testRandomMediumAccess(false); + } + + private void testRandomMediumAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 2; i += 3) { + int value = random.nextInt(); + if (testBigEndian) { + buffer.setMedium(i, value); + } else { + buffer.setMediumLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 2; i += 3) { + int value = random.nextInt() << 8 >> 8; + if (testBigEndian) { + assertEquals(value, buffer.getMedium(i)); + } else { + assertEquals(value, buffer.getMediumLE(i)); + } + } + } + + @Test + public void testRandomUnsignedMediumAccess() { + testRandomUnsignedMediumAccess(true); + } + + @Test + public void testRandomUnsignedMediumLEAccess() { + testRandomUnsignedMediumAccess(false); + } + + private void testRandomUnsignedMediumAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 2; i += 3) { + int value = random.nextInt(); + if (testBigEndian) { + buffer.setMedium(i, value); + } else { + buffer.setMediumLE(i, value); + } + } + + 
random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 2; i += 3) { + int value = random.nextInt() & 0x00FFFFFF; + if (testBigEndian) { + assertEquals(value, buffer.getUnsignedMedium(i)); + } else { + assertEquals(value, buffer.getUnsignedMediumLE(i)); + } + } + } + + @Test + public void testMediumConsistentWithByteBuffer() { + testMediumConsistentWithByteBuffer(true, true); + testMediumConsistentWithByteBuffer(true, false); + testMediumConsistentWithByteBuffer(false, true); + testMediumConsistentWithByteBuffer(false, false); + } + + private void testMediumConsistentWithByteBuffer(boolean direct, boolean testBigEndian) { + for (int i = 0; i < JAVA_BYTEBUFFER_CONSISTENCY_ITERATIONS; ++i) { + ByteBuffer javaBuffer = direct ? ByteBuffer.allocateDirect(buffer.capacity()) + : ByteBuffer.allocate(buffer.capacity()); + if (!testBigEndian) { + javaBuffer = javaBuffer.order(ByteOrder.LITTLE_ENDIAN); + } + + int expected = random.nextInt() & 0x00FFFFFF; + javaBuffer.putInt(expected); + + final int bufferIndex = buffer.capacity() - 3; + if (testBigEndian) { + buffer.setMedium(bufferIndex, expected); + } else { + buffer.setMediumLE(bufferIndex, expected); + } + javaBuffer.flip(); + + int javaActual = javaBuffer.getInt(); + assertEquals(expected, javaActual); + assertEquals(javaActual, testBigEndian ? 
buffer.getUnsignedMedium(bufferIndex) + : buffer.getUnsignedMediumLE(bufferIndex)); + } + } + + @Test + public void testRandomIntAccess() { + testRandomIntAccess(true); + } + + @Test + public void testRandomIntLEAccess() { + testRandomIntAccess(false); + } + + private void testRandomIntAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 3; i += 4) { + int value = random.nextInt(); + if (testBigEndian) { + buffer.setInt(i, value); + } else { + buffer.setIntLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 3; i += 4) { + int value = random.nextInt(); + if (testBigEndian) { + assertEquals(value, buffer.getInt(i)); + } else { + assertEquals(value, buffer.getIntLE(i)); + } + } + } + + @Test + public void testIntConsistentWithByteBuffer() { + testIntConsistentWithByteBuffer(true, true); + testIntConsistentWithByteBuffer(true, false); + testIntConsistentWithByteBuffer(false, true); + testIntConsistentWithByteBuffer(false, false); + } + + private void testIntConsistentWithByteBuffer(boolean direct, boolean testBigEndian) { + for (int i = 0; i < JAVA_BYTEBUFFER_CONSISTENCY_ITERATIONS; ++i) { + ByteBuffer javaBuffer = direct ? ByteBuffer.allocateDirect(buffer.capacity()) + : ByteBuffer.allocate(buffer.capacity()); + if (!testBigEndian) { + javaBuffer = javaBuffer.order(ByteOrder.LITTLE_ENDIAN); + } + + int expected = random.nextInt(); + javaBuffer.putInt(expected); + + final int bufferIndex = buffer.capacity() - 4; + if (testBigEndian) { + buffer.setInt(bufferIndex, expected); + } else { + buffer.setIntLE(bufferIndex, expected); + } + javaBuffer.flip(); + + int javaActual = javaBuffer.getInt(); + assertEquals(expected, javaActual); + assertEquals(javaActual, testBigEndian ? 
buffer.getInt(bufferIndex) + : buffer.getIntLE(bufferIndex)); + } + } + + @Test + public void testRandomUnsignedIntAccess() { + testRandomUnsignedIntAccess(true); + } + + @Test + public void testRandomUnsignedIntLEAccess() { + testRandomUnsignedIntAccess(false); + } + + private void testRandomUnsignedIntAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 3; i += 4) { + int value = random.nextInt(); + if (testBigEndian) { + buffer.setInt(i, value); + } else { + buffer.setIntLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 3; i += 4) { + long value = random.nextInt() & 0xFFFFFFFFL; + if (testBigEndian) { + assertEquals(value, buffer.getUnsignedInt(i)); + } else { + assertEquals(value, buffer.getUnsignedIntLE(i)); + } + } + } + + @Test + public void testRandomLongAccess() { + testRandomLongAccess(true); + } + + @Test + public void testRandomLongLEAccess() { + testRandomLongAccess(false); + } + + private void testRandomLongAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 7; i += 8) { + long value = random.nextLong(); + if (testBigEndian) { + buffer.setLong(i, value); + } else { + buffer.setLongLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 7; i += 8) { + long value = random.nextLong(); + if (testBigEndian) { + assertEquals(value, buffer.getLong(i)); + } else { + assertEquals(value, buffer.getLongLE(i)); + } + } + } + + @Test + public void testLongConsistentWithByteBuffer() { + testLongConsistentWithByteBuffer(true, true); + testLongConsistentWithByteBuffer(true, false); + testLongConsistentWithByteBuffer(false, true); + testLongConsistentWithByteBuffer(false, false); + } + + private void testLongConsistentWithByteBuffer(boolean direct, boolean testBigEndian) { + for (int i = 0; i < JAVA_BYTEBUFFER_CONSISTENCY_ITERATIONS; ++i) { + ByteBuffer javaBuffer = direct ? 
ByteBuffer.allocateDirect(buffer.capacity()) + : ByteBuffer.allocate(buffer.capacity()); + if (!testBigEndian) { + javaBuffer = javaBuffer.order(ByteOrder.LITTLE_ENDIAN); + } + + long expected = random.nextLong(); + javaBuffer.putLong(expected); + + final int bufferIndex = buffer.capacity() - 8; + if (testBigEndian) { + buffer.setLong(bufferIndex, expected); + } else { + buffer.setLongLE(bufferIndex, expected); + } + javaBuffer.flip(); + + long javaActual = javaBuffer.getLong(); + assertEquals(expected, javaActual); + assertEquals(javaActual, testBigEndian ? buffer.getLong(bufferIndex) + : buffer.getLongLE(bufferIndex)); + } + } + + @Test + public void testRandomFloatAccess() { + testRandomFloatAccess(true); + } + + @Test + public void testRandomFloatLEAccess() { + testRandomFloatAccess(false); + } + + private void testRandomFloatAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 7; i += 8) { + float value = random.nextFloat(); + if (testBigEndian) { + buffer.setFloat(i, value); + } else { + buffer.setFloatLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 7; i += 8) { + float expected = random.nextFloat(); + float actual = testBigEndian? buffer.getFloat(i) : buffer.getFloatLE(i); + assertEquals(expected, actual, 0.01); + } + } + + @Test + public void testRandomDoubleAccess() { + testRandomDoubleAccess(true); + } + + @Test + public void testRandomDoubleLEAccess() { + testRandomDoubleAccess(false); + } + + private void testRandomDoubleAccess(boolean testBigEndian) { + for (int i = 0; i < buffer.capacity() - 7; i += 8) { + double value = random.nextDouble(); + if (testBigEndian) { + buffer.setDouble(i, value); + } else { + buffer.setDoubleLE(i, value); + } + } + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() - 7; i += 8) { + double expected = random.nextDouble(); + double actual = testBigEndian? 
buffer.getDouble(i) : buffer.getDoubleLE(i); + assertEquals(expected, actual, 0.01); + } + } + + @Test + public void testSetZero() { + buffer.clear(); + while (buffer.isWritable()) { + buffer.writeByte((byte) 0xFF); + } + + for (int i = 0; i < buffer.capacity();) { + int length = Math.min(buffer.capacity() - i, random.nextInt(32)); + buffer.setZero(i, length); + i += length; + } + + for (int i = 0; i < buffer.capacity(); i ++) { + assertEquals(0, buffer.getByte(i)); + } + } + + @Test + public void testSequentialByteAccess() { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + buffer.writeByte(value); + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + assertEquals(value, buffer.readByte()); + } + + assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testSequentialUnsignedByteAccess() { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + buffer.writeByte(value); + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i ++) { + int value = random.nextInt() & 0xFF; + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + assertEquals(value, buffer.readUnsignedByte()); + } + + 
assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testSequentialShortAccess() { + testSequentialShortAccess(true); + } + + @Test + public void testSequentialShortLEAccess() { + testSequentialShortAccess(false); + } + + private void testSequentialShortAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i += 2) { + short value = (short) random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeShort(value); + } else { + buffer.writeShortLE(value); + } + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i += 2) { + short value = (short) random.nextInt(); + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readShort()); + } else { + assertEquals(value, buffer.readShortLE()); + } + } + + assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testSequentialUnsignedShortAccess() { + testSequentialUnsignedShortAccess(true); + } + + @Test + public void testSequentialUnsignedShortLEAccess() { + testSequentialUnsignedShortAccess(true); + } + + private void testSequentialUnsignedShortAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i += 2) { + short value = (short) random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeShort(value); + } else { + buffer.writeShortLE(value); + } + } + + 
assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i += 2) { + int value = random.nextInt() & 0xFFFF; + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readUnsignedShort()); + } else { + assertEquals(value, buffer.readUnsignedShortLE()); + } + } + + assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testSequentialMediumAccess() { + testSequentialMediumAccess(true); + } + @Test + public void testSequentialMediumLEAccess() { + testSequentialMediumAccess(false); + } + + private void testSequentialMediumAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() / 3 * 3; i += 3) { + int value = random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeMedium(value); + } else { + buffer.writeMediumLE(value); + } + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity() / 3 * 3, buffer.writerIndex()); + assertEquals(buffer.capacity() % 3, buffer.writableBytes()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() / 3 * 3; i += 3) { + int value = random.nextInt() << 8 >> 8; + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readMedium()); + } else { + assertEquals(value, buffer.readMediumLE()); + } + } + + assertEquals(buffer.capacity() / 3 * 3, buffer.readerIndex()); + assertEquals(buffer.capacity() / 3 * 3, buffer.writerIndex()); + assertEquals(0, buffer.readableBytes()); + assertEquals(buffer.capacity() % 3, buffer.writableBytes()); + } + + @Test + public void 
testSequentialUnsignedMediumAccess() { + testSequentialUnsignedMediumAccess(true); + } + + @Test + public void testSequentialUnsignedMediumLEAccess() { + testSequentialUnsignedMediumAccess(false); + } + + private void testSequentialUnsignedMediumAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() / 3 * 3; i += 3) { + int value = random.nextInt() & 0x00FFFFFF; + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeMedium(value); + } else { + buffer.writeMediumLE(value); + } + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity() / 3 * 3, buffer.writerIndex()); + assertEquals(buffer.capacity() % 3, buffer.writableBytes()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity() / 3 * 3; i += 3) { + int value = random.nextInt() & 0x00FFFFFF; + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readUnsignedMedium()); + } else { + assertEquals(value, buffer.readUnsignedMediumLE()); + } + } + + assertEquals(buffer.capacity() / 3 * 3, buffer.readerIndex()); + assertEquals(buffer.capacity() / 3 * 3, buffer.writerIndex()); + assertEquals(0, buffer.readableBytes()); + assertEquals(buffer.capacity() % 3, buffer.writableBytes()); + } + + @Test + public void testSequentialIntAccess() { + testSequentialIntAccess(true); + } + + @Test + public void testSequentialIntLEAccess() { + testSequentialIntAccess(false); + } + + private void testSequentialIntAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i += 4) { + int value = random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeInt(value); + } else { + buffer.writeIntLE(value); + } + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + 
assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i += 4) { + int value = random.nextInt(); + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readInt()); + } else { + assertEquals(value, buffer.readIntLE()); + } + } + + assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testSequentialUnsignedIntAccess() { + testSequentialUnsignedIntAccess(true); + } + + @Test + public void testSequentialUnsignedIntLEAccess() { + testSequentialUnsignedIntAccess(false); + } + + private void testSequentialUnsignedIntAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i += 4) { + int value = random.nextInt(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeInt(value); + } else { + buffer.writeIntLE(value); + } + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i += 4) { + long value = random.nextInt() & 0xFFFFFFFFL; + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readUnsignedInt()); + } else { + assertEquals(value, buffer.readUnsignedIntLE()); + } + } + + assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testSequentialLongAccess() { + testSequentialLongAccess(true); + } + + @Test + public void testSequentialLongLEAccess() { + testSequentialLongAccess(false); + } + + private void 
testSequentialLongAccess(boolean testBigEndian) { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i += 8) { + long value = random.nextLong(); + assertEquals(i, buffer.writerIndex()); + assertTrue(buffer.isWritable()); + if (testBigEndian) { + buffer.writeLong(value); + } else { + buffer.writeLongLE(value); + } + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isWritable()); + + random.setSeed(seed); + for (int i = 0; i < buffer.capacity(); i += 8) { + long value = random.nextLong(); + assertEquals(i, buffer.readerIndex()); + assertTrue(buffer.isReadable()); + if (testBigEndian) { + assertEquals(value, buffer.readLong()); + } else { + assertEquals(value, buffer.readLongLE()); + } + } + + assertEquals(buffer.capacity(), buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + assertFalse(buffer.isReadable()); + assertFalse(buffer.isWritable()); + } + + @Test + public void testByteArrayTransfer() { + byte[] value = new byte[BLOCK_SIZE * 2]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value); + buffer.setBytes(i, value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + } + + random.setSeed(seed); + byte[] expectedValue = new byte[BLOCK_SIZE * 2]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValue); + int valueOffset = random.nextInt(BLOCK_SIZE); + buffer.getBytes(i, value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue[j], value[j]); + } + } + } + + @Test + public void testRandomByteArrayTransfer1() { + byte[] value = new byte[BLOCK_SIZE]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value); + buffer.setBytes(i, value); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE]; + ByteBuf expectedValue = 
wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + buffer.getBytes(i, value); + for (int j = 0; j < BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value[j]); + } + } + } + + @Test + public void testRandomByteArrayTransfer2() { + byte[] value = new byte[BLOCK_SIZE * 2]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value); + buffer.setBytes(i, value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + buffer.getBytes(i, value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value[j]); + } + } + } + + @Test + public void testRandomHeapBufferTransfer1() { + byte[] valueContent = new byte[BLOCK_SIZE]; + ByteBuf value = wrappedBuffer(valueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + value.setIndex(0, BLOCK_SIZE); + buffer.setBytes(i, value); + assertEquals(BLOCK_SIZE, value.readerIndex()); + assertEquals(BLOCK_SIZE, value.writerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + value.clear(); + buffer.getBytes(i, value); + assertEquals(0, value.readerIndex()); + assertEquals(BLOCK_SIZE, value.writerIndex()); + for (int j = 0; j < BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + } 
+ } + + @Test + public void testRandomHeapBufferTransfer2() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = wrappedBuffer(valueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + buffer.setBytes(i, value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + buffer.getBytes(i, value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + } + } + + @Test + public void testRandomDirectBufferTransfer() { + byte[] tmp = new byte[BLOCK_SIZE * 2]; + ByteBuf value = directBuffer(BLOCK_SIZE * 2); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(tmp); + value.setBytes(0, tmp, 0, value.capacity()); + buffer.setBytes(i, value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + } + + random.setSeed(seed); + ByteBuf expectedValue = directBuffer(BLOCK_SIZE * 2); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(tmp); + expectedValue.setBytes(0, tmp, 0, expectedValue.capacity()); + int valueOffset = random.nextInt(BLOCK_SIZE); + buffer.getBytes(i, value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + } + value.release(); + expectedValue.release(); + } + + @Test + public void testRandomByteBufferTransfer() { + ByteBuffer value = ByteBuffer.allocate(BLOCK_SIZE * 2); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value.array()); + 
value.clear().position(random.nextInt(BLOCK_SIZE)); + value.limit(value.position() + BLOCK_SIZE); + buffer.setBytes(i, value); + } + + random.setSeed(seed); + ByteBuffer expectedValue = ByteBuffer.allocate(BLOCK_SIZE * 2); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValue.array()); + int valueOffset = random.nextInt(BLOCK_SIZE); + value.clear().position(valueOffset).limit(valueOffset + BLOCK_SIZE); + buffer.getBytes(i, value); + assertEquals(valueOffset + BLOCK_SIZE, value.position()); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.get(j), value.get(j)); + } + } + } + + @Test + public void testSequentialByteArrayTransfer1() { + byte[] value = new byte[BLOCK_SIZE]; + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(value); + } + + random.setSeed(seed); + byte[] expectedValue = new byte[BLOCK_SIZE]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValue); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + buffer.readBytes(value); + for (int j = 0; j < BLOCK_SIZE; j ++) { + assertEquals(expectedValue[j], value[j]); + } + } + } + + @Test + public void testSequentialByteArrayTransfer2() { + byte[] value = new byte[BLOCK_SIZE * 2]; + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + int readerIndex = random.nextInt(BLOCK_SIZE); + buffer.writeBytes(value, readerIndex, BLOCK_SIZE); + } + + random.setSeed(seed); + byte[] expectedValue = new byte[BLOCK_SIZE * 2]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + 
random.nextBytes(expectedValue); + int valueOffset = random.nextInt(BLOCK_SIZE); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + buffer.readBytes(value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue[j], value[j]); + } + } + } + + @Test + public void testSequentialHeapBufferTransfer1() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = wrappedBuffer(valueContent); + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + assertEquals(0, value.readerIndex()); + assertEquals(valueContent.length, value.writerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + buffer.readBytes(value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + assertEquals(0, value.readerIndex()); + assertEquals(valueContent.length, value.writerIndex()); + } + } + + @Test + public void testSequentialHeapBufferTransfer2() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = wrappedBuffer(valueContent); + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + int readerIndex = 
random.nextInt(BLOCK_SIZE); + value.readerIndex(readerIndex); + value.writerIndex(readerIndex + BLOCK_SIZE); + buffer.writeBytes(value); + assertEquals(readerIndex + BLOCK_SIZE, value.writerIndex()); + assertEquals(value.writerIndex(), value.readerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + value.readerIndex(valueOffset); + value.writerIndex(valueOffset); + buffer.readBytes(value, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + assertEquals(valueOffset, value.readerIndex()); + assertEquals(valueOffset + BLOCK_SIZE, value.writerIndex()); + } + } + + @Test + public void testSequentialDirectBufferTransfer1() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = directBuffer(BLOCK_SIZE * 2); + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + value.setBytes(0, valueContent); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + assertEquals(0, value.readerIndex()); + assertEquals(0, value.writerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + value.setBytes(0, valueContent); + assertEquals(i, buffer.readerIndex()); + 
assertEquals(CAPACITY, buffer.writerIndex()); + buffer.readBytes(value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + assertEquals(0, value.readerIndex()); + assertEquals(0, value.writerIndex()); + } + value.release(); + expectedValue.release(); + } + + @Test + public void testSequentialDirectBufferTransfer2() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = directBuffer(BLOCK_SIZE * 2); + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + value.setBytes(0, valueContent); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + int readerIndex = random.nextInt(BLOCK_SIZE); + value.readerIndex(0); + value.writerIndex(readerIndex + BLOCK_SIZE); + value.readerIndex(readerIndex); + buffer.writeBytes(value); + assertEquals(readerIndex + BLOCK_SIZE, value.writerIndex()); + assertEquals(value.writerIndex(), value.readerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + value.setBytes(0, valueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + value.readerIndex(valueOffset); + value.writerIndex(valueOffset); + buffer.readBytes(value, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + assertEquals(valueOffset, value.readerIndex()); + assertEquals(valueOffset + BLOCK_SIZE, value.writerIndex()); + } + value.release(); + expectedValue.release(); + } + + @Test + public void 
testSequentialByteBufferBackedHeapBufferTransfer1() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = wrappedBuffer(ByteBuffer.allocate(BLOCK_SIZE * 2)); + value.writerIndex(0); + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + value.setBytes(0, valueContent); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(value, random.nextInt(BLOCK_SIZE), BLOCK_SIZE); + assertEquals(0, value.readerIndex()); + assertEquals(0, value.writerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + value.setBytes(0, valueContent); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + buffer.readBytes(value, valueOffset, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + assertEquals(0, value.readerIndex()); + assertEquals(0, value.writerIndex()); + } + } + + @Test + public void testSequentialByteBufferBackedHeapBufferTransfer2() { + byte[] valueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf value = wrappedBuffer(ByteBuffer.allocate(BLOCK_SIZE * 2)); + value.writerIndex(0); + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(valueContent); + value.setBytes(0, valueContent); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + int readerIndex = random.nextInt(BLOCK_SIZE); + value.readerIndex(0); + value.writerIndex(readerIndex + BLOCK_SIZE); + value.readerIndex(readerIndex); + buffer.writeBytes(value); + assertEquals(readerIndex + 
BLOCK_SIZE, value.writerIndex()); + assertEquals(value.writerIndex(), value.readerIndex()); + } + + random.setSeed(seed); + byte[] expectedValueContent = new byte[BLOCK_SIZE * 2]; + ByteBuf expectedValue = wrappedBuffer(expectedValueContent); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValueContent); + value.setBytes(0, valueContent); + int valueOffset = random.nextInt(BLOCK_SIZE); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + value.readerIndex(valueOffset); + value.writerIndex(valueOffset); + buffer.readBytes(value, BLOCK_SIZE); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.getByte(j), value.getByte(j)); + } + assertEquals(valueOffset, value.readerIndex()); + assertEquals(valueOffset + BLOCK_SIZE, value.writerIndex()); + } + } + + @Test + public void testSequentialByteBufferTransfer() { + buffer.writerIndex(0); + ByteBuffer value = ByteBuffer.allocate(BLOCK_SIZE * 2); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(value.array()); + value.clear().position(random.nextInt(BLOCK_SIZE)); + value.limit(value.position() + BLOCK_SIZE); + buffer.writeBytes(value); + } + + random.setSeed(seed); + ByteBuffer expectedValue = ByteBuffer.allocate(BLOCK_SIZE * 2); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValue.array()); + int valueOffset = random.nextInt(BLOCK_SIZE); + value.clear().position(valueOffset).limit(valueOffset + BLOCK_SIZE); + buffer.readBytes(value); + assertEquals(valueOffset + BLOCK_SIZE, value.position()); + for (int j = valueOffset; j < valueOffset + BLOCK_SIZE; j ++) { + assertEquals(expectedValue.get(j), value.get(j)); + } + } + } + + @Test + public void testSequentialCopiedBufferTransfer1() { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + 
byte[] value = new byte[BLOCK_SIZE]; + random.nextBytes(value); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(value); + } + + random.setSeed(seed); + byte[] expectedValue = new byte[BLOCK_SIZE]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValue); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + ByteBuf actualValue = buffer.readBytes(BLOCK_SIZE); + assertEquals(wrappedBuffer(expectedValue), actualValue); + + // Make sure if it is a copied buffer. + actualValue.setByte(0, (byte) (actualValue.getByte(0) + 1)); + assertFalse(buffer.getByte(i) == actualValue.getByte(0)); + actualValue.release(); + } + } + + @Test + public void testSequentialSlice1() { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + byte[] value = new byte[BLOCK_SIZE]; + random.nextBytes(value); + assertEquals(0, buffer.readerIndex()); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(value); + } + + random.setSeed(seed); + byte[] expectedValue = new byte[BLOCK_SIZE]; + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + random.nextBytes(expectedValue); + assertEquals(i, buffer.readerIndex()); + assertEquals(CAPACITY, buffer.writerIndex()); + ByteBuf actualValue = buffer.readSlice(BLOCK_SIZE); + assertEquals(buffer.order(), actualValue.order()); + assertEquals(wrappedBuffer(expectedValue), actualValue); + + // Make sure if it is a sliced buffer. 
+ actualValue.setByte(0, (byte) (actualValue.getByte(0) + 1)); + assertEquals(buffer.getByte(i), actualValue.getByte(0)); + } + } + + @Test + public void testWriteZero() { + try { + buffer.writeZero(-1); + fail(); + } catch (IllegalArgumentException e) { + // Expected + } + + buffer.clear(); + while (buffer.isWritable()) { + buffer.writeByte((byte) 0xFF); + } + + buffer.clear(); + for (int i = 0; i < buffer.capacity();) { + int length = Math.min(buffer.capacity() - i, random.nextInt(32)); + buffer.writeZero(length); + i += length; + } + + assertEquals(0, buffer.readerIndex()); + assertEquals(buffer.capacity(), buffer.writerIndex()); + + for (int i = 0; i < buffer.capacity(); i ++) { + assertEquals(0, buffer.getByte(i)); + } + } + + @Test + public void testDiscardReadBytes() { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i += 4) { + buffer.writeInt(i); + } + ByteBuf copy = copiedBuffer(buffer); + + // Make sure there's no effect if called when readerIndex is 0. + buffer.readerIndex(CAPACITY / 4); + buffer.markReaderIndex(); + buffer.writerIndex(CAPACITY / 3); + buffer.markWriterIndex(); + buffer.readerIndex(0); + buffer.writerIndex(CAPACITY / 2); + buffer.discardReadBytes(); + + assertEquals(0, buffer.readerIndex()); + assertEquals(CAPACITY / 2, buffer.writerIndex()); + assertEquals(copy.slice(0, CAPACITY / 2), buffer.slice(0, CAPACITY / 2)); + buffer.resetReaderIndex(); + assertEquals(CAPACITY / 4, buffer.readerIndex()); + buffer.resetWriterIndex(); + assertEquals(CAPACITY / 3, buffer.writerIndex()); + + // Make sure bytes after writerIndex is not copied. 
+ buffer.readerIndex(1); + buffer.writerIndex(CAPACITY / 2); + buffer.discardReadBytes(); + + assertEquals(0, buffer.readerIndex()); + assertEquals(CAPACITY / 2 - 1, buffer.writerIndex()); + assertEquals(copy.slice(1, CAPACITY / 2 - 1), buffer.slice(0, CAPACITY / 2 - 1)); + + if (discardReadBytesDoesNotMoveWritableBytes()) { + // If writable bytes were copied, the test should fail to avoid unnecessary memory bandwidth consumption. + assertFalse(copy.slice(CAPACITY / 2, CAPACITY / 2).equals(buffer.slice(CAPACITY / 2 - 1, CAPACITY / 2))); + } else { + assertEquals(copy.slice(CAPACITY / 2, CAPACITY / 2), buffer.slice(CAPACITY / 2 - 1, CAPACITY / 2)); + } + + // Marks also should be relocated. + buffer.resetReaderIndex(); + assertEquals(CAPACITY / 4 - 1, buffer.readerIndex()); + buffer.resetWriterIndex(); + assertEquals(CAPACITY / 3 - 1, buffer.writerIndex()); + copy.release(); + } + + /** + * The similar test case with {@link #testDiscardReadBytes()} but this one + * discards a large chunk at once. + */ + @Test + public void testDiscardReadBytes2() { + buffer.writerIndex(0); + for (int i = 0; i < buffer.capacity(); i ++) { + buffer.writeByte((byte) i); + } + ByteBuf copy = copiedBuffer(buffer); + + // Discard the first (CAPACITY / 2 - 1) bytes. 
+ buffer.setIndex(CAPACITY / 2 - 1, CAPACITY - 1); + buffer.discardReadBytes(); + assertEquals(0, buffer.readerIndex()); + assertEquals(CAPACITY / 2, buffer.writerIndex()); + for (int i = 0; i < CAPACITY / 2; i ++) { + assertEquals(copy.slice(CAPACITY / 2 - 1 + i, CAPACITY / 2 - i), buffer.slice(i, CAPACITY / 2 - i)); + } + copy.release(); + } + + @Test + public void testStreamTransfer1() throws Exception { + byte[] expected = new byte[buffer.capacity()]; + random.nextBytes(expected); + + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + ByteArrayInputStream in = new ByteArrayInputStream(expected, i, BLOCK_SIZE); + assertEquals(BLOCK_SIZE, buffer.setBytes(i, in, BLOCK_SIZE)); + assertEquals(-1, buffer.setBytes(i, in, 0)); + } + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + buffer.getBytes(i, out, BLOCK_SIZE); + } + + assertTrue(Arrays.equals(expected, out.toByteArray())); + } + + @Test + public void testStreamTransfer2() throws Exception { + byte[] expected = new byte[buffer.capacity()]; + random.nextBytes(expected); + buffer.clear(); + + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + ByteArrayInputStream in = new ByteArrayInputStream(expected, i, BLOCK_SIZE); + assertEquals(i, buffer.writerIndex()); + buffer.writeBytes(in, BLOCK_SIZE); + assertEquals(i + BLOCK_SIZE, buffer.writerIndex()); + } + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + assertEquals(i, buffer.readerIndex()); + buffer.readBytes(out, BLOCK_SIZE); + assertEquals(i + BLOCK_SIZE, buffer.readerIndex()); + } + + assertTrue(Arrays.equals(expected, out.toByteArray())); + } + + @Test + public void testCopy() { + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + buffer.setByte(i, value); + } + + final int readerIndex = CAPACITY / 
3; + final int writerIndex = CAPACITY * 2 / 3; + buffer.setIndex(readerIndex, writerIndex); + + // Make sure all properties are copied. + ByteBuf copy = buffer.copy(); + assertEquals(0, copy.readerIndex()); + assertEquals(buffer.readableBytes(), copy.writerIndex()); + assertEquals(buffer.readableBytes(), copy.capacity()); + assertSame(buffer.order(), copy.order()); + for (int i = 0; i < copy.capacity(); i ++) { + assertEquals(buffer.getByte(i + readerIndex), copy.getByte(i)); + } + + // Make sure the buffer content is independent from each other. + buffer.setByte(readerIndex, (byte) (buffer.getByte(readerIndex) + 1)); + assertTrue(buffer.getByte(readerIndex) != copy.getByte(0)); + copy.setByte(1, (byte) (copy.getByte(1) + 1)); + assertTrue(buffer.getByte(readerIndex + 1) != copy.getByte(1)); + copy.release(); + } + + @Test + public void testDuplicate() { + for (int i = 0; i < buffer.capacity(); i ++) { + byte value = (byte) random.nextInt(); + buffer.setByte(i, value); + } + + final int readerIndex = CAPACITY / 3; + final int writerIndex = CAPACITY * 2 / 3; + buffer.setIndex(readerIndex, writerIndex); + + // Make sure all properties are copied. + ByteBuf duplicate = buffer.duplicate(); + assertSame(buffer.order(), duplicate.order()); + assertEquals(buffer.readableBytes(), duplicate.readableBytes()); + assertEquals(0, buffer.compareTo(duplicate)); + + // Make sure the buffer content is shared. 
+ buffer.setByte(readerIndex, (byte) (buffer.getByte(readerIndex) + 1)); + assertEquals(buffer.getByte(readerIndex), duplicate.getByte(duplicate.readerIndex())); + duplicate.setByte(duplicate.readerIndex(), (byte) (duplicate.getByte(duplicate.readerIndex()) + 1)); + assertEquals(buffer.getByte(readerIndex), duplicate.getByte(duplicate.readerIndex())); + } + + @Test + public void testSliceEndianness() throws Exception { + assertEquals(buffer.order(), buffer.slice(0, buffer.capacity()).order()); + assertEquals(buffer.order(), buffer.slice(0, buffer.capacity() - 1).order()); + assertEquals(buffer.order(), buffer.slice(1, buffer.capacity() - 1).order()); + assertEquals(buffer.order(), buffer.slice(1, buffer.capacity() - 2).order()); + } + + @Test + public void testSliceIndex() throws Exception { + assertEquals(0, buffer.slice(0, buffer.capacity()).readerIndex()); + assertEquals(0, buffer.slice(0, buffer.capacity() - 1).readerIndex()); + assertEquals(0, buffer.slice(1, buffer.capacity() - 1).readerIndex()); + assertEquals(0, buffer.slice(1, buffer.capacity() - 2).readerIndex()); + + assertEquals(buffer.capacity(), buffer.slice(0, buffer.capacity()).writerIndex()); + assertEquals(buffer.capacity() - 1, buffer.slice(0, buffer.capacity() - 1).writerIndex()); + assertEquals(buffer.capacity() - 1, buffer.slice(1, buffer.capacity() - 1).writerIndex()); + assertEquals(buffer.capacity() - 2, buffer.slice(1, buffer.capacity() - 2).writerIndex()); + } + + @Test + public void testRetainedSliceIndex() throws Exception { + ByteBuf retainedSlice = buffer.retainedSlice(0, buffer.capacity()); + assertEquals(0, retainedSlice.readerIndex()); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(0, buffer.capacity() - 1); + assertEquals(0, retainedSlice.readerIndex()); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(1, buffer.capacity() - 1); + assertEquals(0, retainedSlice.readerIndex()); + retainedSlice.release(); + + retainedSlice = 
buffer.retainedSlice(1, buffer.capacity() - 2); + assertEquals(0, retainedSlice.readerIndex()); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(0, buffer.capacity()); + assertEquals(buffer.capacity(), retainedSlice.writerIndex()); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(0, buffer.capacity() - 1); + assertEquals(buffer.capacity() - 1, retainedSlice.writerIndex()); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(1, buffer.capacity() - 1); + assertEquals(buffer.capacity() - 1, retainedSlice.writerIndex()); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(1, buffer.capacity() - 2); + assertEquals(buffer.capacity() - 2, retainedSlice.writerIndex()); + retainedSlice.release(); + } + + @Test + @SuppressWarnings("ObjectEqualsNull") + public void testEquals() { + assertFalse(buffer.equals(null)); + assertFalse(buffer.equals(new Object())); + + byte[] value = new byte[32]; + buffer.setIndex(0, value.length); + random.nextBytes(value); + buffer.setBytes(0, value); + + assertEquals(buffer, wrappedBuffer(value)); + assertEquals(buffer, wrappedBuffer(value).order(LITTLE_ENDIAN)); + + value[0] ++; + assertFalse(buffer.equals(wrappedBuffer(value))); + assertFalse(buffer.equals(wrappedBuffer(value).order(LITTLE_ENDIAN))); + } + + @Test + public void testCompareTo() { + try { + buffer.compareTo(null); + fail(); + } catch (NullPointerException e) { + // Expected + } + + // Fill the random stuff + byte[] value = new byte[32]; + random.nextBytes(value); + // Prevent overflow / underflow + if (value[0] == 0) { + value[0] ++; + } else if (value[0] == -1) { + value[0] --; + } + + buffer.setIndex(0, value.length); + buffer.setBytes(0, value); + + assertEquals(0, buffer.compareTo(wrappedBuffer(value))); + assertEquals(0, buffer.compareTo(wrappedBuffer(value).order(LITTLE_ENDIAN))); + + value[0] ++; + assertTrue(buffer.compareTo(wrappedBuffer(value)) < 0); + 
assertTrue(buffer.compareTo(wrappedBuffer(value).order(LITTLE_ENDIAN)) < 0); + value[0] -= 2; + assertTrue(buffer.compareTo(wrappedBuffer(value)) > 0); + assertTrue(buffer.compareTo(wrappedBuffer(value).order(LITTLE_ENDIAN)) > 0); + value[0] ++; + + assertTrue(buffer.compareTo(wrappedBuffer(value, 0, 31)) > 0); + assertTrue(buffer.compareTo(wrappedBuffer(value, 0, 31).order(LITTLE_ENDIAN)) > 0); + assertTrue(buffer.slice(0, 31).compareTo(wrappedBuffer(value)) < 0); + assertTrue(buffer.slice(0, 31).compareTo(wrappedBuffer(value).order(LITTLE_ENDIAN)) < 0); + + ByteBuf retainedSlice = buffer.retainedSlice(0, 31); + assertTrue(retainedSlice.compareTo(wrappedBuffer(value)) < 0); + retainedSlice.release(); + + retainedSlice = buffer.retainedSlice(0, 31); + assertTrue(retainedSlice.compareTo(wrappedBuffer(value).order(LITTLE_ENDIAN)) < 0); + retainedSlice.release(); + } + + @Test + public void testCompareTo2() { + byte[] bytes = {1, 2, 3, 4}; + byte[] bytesReversed = {4, 3, 2, 1}; + + ByteBuf buf1 = newBuffer(4).clear().writeBytes(bytes).order(ByteOrder.LITTLE_ENDIAN); + ByteBuf buf2 = newBuffer(4).clear().writeBytes(bytesReversed).order(ByteOrder.LITTLE_ENDIAN); + ByteBuf buf3 = newBuffer(4).clear().writeBytes(bytes).order(ByteOrder.BIG_ENDIAN); + ByteBuf buf4 = newBuffer(4).clear().writeBytes(bytesReversed).order(ByteOrder.BIG_ENDIAN); + try { + assertEquals(buf1.compareTo(buf2), buf3.compareTo(buf4)); + assertEquals(buf2.compareTo(buf1), buf4.compareTo(buf3)); + assertEquals(buf1.compareTo(buf3), buf2.compareTo(buf4)); + assertEquals(buf3.compareTo(buf1), buf4.compareTo(buf2)); + } finally { + buf1.release(); + buf2.release(); + buf3.release(); + buf4.release(); + } + } + + @Test + public void testToString() { + ByteBuf copied = copiedBuffer("Hello, World!", CharsetUtil.ISO_8859_1); + buffer.clear(); + buffer.writeBytes(copied); + assertEquals("Hello, World!", buffer.toString(CharsetUtil.ISO_8859_1)); + copied.release(); + } + + @Test + @Timeout(value = 10000, unit = 
TimeUnit.MILLISECONDS) + public void testToStringMultipleThreads() throws Throwable { + buffer.clear(); + buffer.writeBytes("Hello, World!".getBytes(CharsetUtil.ISO_8859_1)); + + final AtomicInteger counter = new AtomicInteger(30000); + final AtomicReference errorRef = new AtomicReference(); + List threads = new ArrayList(); + for (int i = 0; i < 10; i++) { + Thread thread = new Thread(new Runnable() { + @Override + public void run() { + try { + while (errorRef.get() == null && counter.decrementAndGet() > 0) { + assertEquals("Hello, World!", buffer.toString(CharsetUtil.ISO_8859_1)); + } + } catch (Throwable cause) { + errorRef.compareAndSet(null, cause); + } + } + }); + threads.add(thread); + } + for (Thread thread : threads) { + thread.start(); + } + + for (Thread thread : threads) { + thread.join(); + } + + Throwable error = errorRef.get(); + if (error != null) { + throw error; + } + } + + @Test + public void testSWARIndexOf() { + ByteBuf buffer = newBuffer(16); + buffer.clear(); + // Ensure the buffer is completely zero'ed. + buffer.setZero(0, buffer.capacity()); + buffer.writeByte((byte) 0); // 0 + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); // 7 + + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 0); + buffer.writeByte((byte) 1); // 11 + buffer.writeByte((byte) 2); + buffer.writeByte((byte) 3); + buffer.writeByte((byte) 4); + buffer.writeByte((byte) 1); + assertEquals(11, buffer.indexOf(0, 12, (byte) 1)); + assertEquals(12, buffer.indexOf(0, 16, (byte) 2)); + assertEquals(-1, buffer.indexOf(0, 11, (byte) 1)); + assertEquals(11, buffer.indexOf(0, 16, (byte) 1)); + buffer.release(); + } + + @Test + public void testIndexOf() { + buffer.clear(); + // Ensure the buffer is completely zero'ed. 
+ buffer.setZero(0, buffer.capacity()); + + buffer.writeByte((byte) 1); + buffer.writeByte((byte) 2); + buffer.writeByte((byte) 3); + buffer.writeByte((byte) 2); + buffer.writeByte((byte) 1); + + assertEquals(-1, buffer.indexOf(1, 4, (byte) 1)); + assertEquals(-1, buffer.indexOf(4, 1, (byte) 1)); + assertEquals(1, buffer.indexOf(1, 4, (byte) 2)); + assertEquals(3, buffer.indexOf(4, 1, (byte) 2)); + + try { + buffer.indexOf(0, buffer.capacity() + 1, (byte) 0); + fail(); + } catch (IndexOutOfBoundsException expected) { + // expected + } + + try { + buffer.indexOf(buffer.capacity(), -1, (byte) 0); + fail(); + } catch (IndexOutOfBoundsException expected) { + // expected + } + + assertEquals(4, buffer.indexOf(buffer.capacity() + 1, 0, (byte) 1)); + assertEquals(0, buffer.indexOf(-1, buffer.capacity(), (byte) 1)); + } + + @Test + public void testIndexOfReleaseBuffer() { + ByteBuf buffer = releasedBuffer(); + if (buffer.capacity() != 0) { + try { + buffer.indexOf(0, 1, (byte) 1); + fail(); + } catch (IllegalReferenceCountException expected) { + // expected + } + } else { + assertEquals(-1, buffer.indexOf(0, 1, (byte) 1)); + } + } + + @Test + public void testNioBuffer1() { + assumeTrue(buffer.nioBufferCount() == 1); + + byte[] value = new byte[buffer.capacity()]; + random.nextBytes(value); + buffer.clear(); + buffer.writeBytes(value); + + assertRemainingEquals(ByteBuffer.wrap(value), buffer.nioBuffer()); + } + + @Test + public void testToByteBuffer2() { + assumeTrue(buffer.nioBufferCount() == 1); + + byte[] value = new byte[buffer.capacity()]; + random.nextBytes(value); + buffer.clear(); + buffer.writeBytes(value); + + for (int i = 0; i < buffer.capacity() - BLOCK_SIZE + 1; i += BLOCK_SIZE) { + assertRemainingEquals(ByteBuffer.wrap(value, i, BLOCK_SIZE), buffer.nioBuffer(i, BLOCK_SIZE)); + } + } + + private static void assertRemainingEquals(ByteBuffer expected, ByteBuffer actual) { + int remaining = expected.remaining(); + int remaining2 = actual.remaining(); + + 
assertEquals(remaining, remaining2); + byte[] array1 = new byte[remaining]; + byte[] array2 = new byte[remaining2]; + expected.get(array1); + actual.get(array2); + assertArrayEquals(array1, array2); + } + + @Test + public void testToByteBuffer3() { + assumeTrue(buffer.nioBufferCount() == 1); + + assertEquals(buffer.order(), buffer.nioBuffer().order()); + } + + @Test + public void testSkipBytes1() { + buffer.setIndex(CAPACITY / 4, CAPACITY / 2); + + buffer.skipBytes(CAPACITY / 4); + assertEquals(CAPACITY / 4 * 2, buffer.readerIndex()); + + try { + buffer.skipBytes(CAPACITY / 4 + 1); + fail(); + } catch (IndexOutOfBoundsException e) { + // Expected + } + + // Should remain unchanged. + assertEquals(CAPACITY / 4 * 2, buffer.readerIndex()); + } + + @Test + public void testHashCode() { + ByteBuf elemA = buffer(15); + ByteBuf elemB = directBuffer(15); + elemA.writeBytes(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5 }); + elemB.writeBytes(new byte[] { 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 9 }); + + Set set = new HashSet(); + set.add(elemA); + set.add(elemB); + + assertEquals(2, set.size()); + ByteBuf elemACopy = elemA.copy(); + assertTrue(set.contains(elemACopy)); + + ByteBuf elemBCopy = elemB.copy(); + assertTrue(set.contains(elemBCopy)); + + buffer.clear(); + buffer.writeBytes(elemA.duplicate()); + + assertTrue(set.remove(buffer)); + assertFalse(set.contains(elemA)); + assertEquals(1, set.size()); + + buffer.clear(); + buffer.writeBytes(elemB.duplicate()); + assertTrue(set.remove(buffer)); + assertFalse(set.contains(elemB)); + assertEquals(0, set.size()); + elemA.release(); + elemB.release(); + elemACopy.release(); + elemBCopy.release(); + } + + // Test case for https://github.com/netty/netty/issues/325 + @Test + public void testDiscardAllReadBytes() { + buffer.writerIndex(buffer.capacity()); + buffer.readerIndex(buffer.writerIndex()); + buffer.discardReadBytes(); + } + + @Test + public void testForEachByte() { + buffer.clear(); + for (int i = 0; i < 
CAPACITY; i ++) { + buffer.writeByte(i + 1); + } + + final AtomicInteger lastIndex = new AtomicInteger(); + buffer.setIndex(CAPACITY / 4, CAPACITY * 3 / 4); + assertThat(buffer.forEachByte(new ByteProcessor() { + int i = CAPACITY / 4; + + @Override + public boolean process(byte value) throws Exception { + assertThat(value, is((byte) (i + 1))); + lastIndex.set(i); + i ++; + return true; + } + }), is(-1)); + + assertThat(lastIndex.get(), is(CAPACITY * 3 / 4 - 1)); + } + + @Test + public void testForEachByteAbort() { + buffer.clear(); + for (int i = 0; i < CAPACITY; i ++) { + buffer.writeByte(i + 1); + } + + final int stop = CAPACITY / 2; + assertThat(buffer.forEachByte(CAPACITY / 3, CAPACITY / 3, new ByteProcessor() { + int i = CAPACITY / 3; + + @Override + public boolean process(byte value) throws Exception { + assertThat(value, is((byte) (i + 1))); + if (i == stop) { + return false; + } + + i++; + return true; + } + }), is(stop)); + } + + @Test + public void testForEachByteDesc() { + buffer.clear(); + for (int i = 0; i < CAPACITY; i ++) { + buffer.writeByte(i + 1); + } + + final AtomicInteger lastIndex = new AtomicInteger(); + assertThat(buffer.forEachByteDesc(CAPACITY / 4, CAPACITY * 2 / 4, new ByteProcessor() { + int i = CAPACITY * 3 / 4 - 1; + + @Override + public boolean process(byte value) throws Exception { + assertThat(value, is((byte) (i + 1))); + lastIndex.set(i); + i --; + return true; + } + }), is(-1)); + + assertThat(lastIndex.get(), is(CAPACITY / 4)); + } + + @Test + public void testInternalNioBuffer() { + testInternalNioBuffer(128); + testInternalNioBuffer(1024); + testInternalNioBuffer(4 * 1024); + testInternalNioBuffer(64 * 1024); + testInternalNioBuffer(32 * 1024 * 1024); + testInternalNioBuffer(64 * 1024 * 1024); + } + + private void testInternalNioBuffer(int a) { + ByteBuf buffer = newBuffer(2); + ByteBuffer buf = buffer.internalNioBuffer(buffer.readerIndex(), 1); + assertEquals(1, buf.remaining()); + + byte[] data = new byte[a]; + 
PlatformDependent.threadLocalRandom().nextBytes(data); + buffer.writeBytes(data); + + buf = buffer.internalNioBuffer(buffer.readerIndex(), a); + assertEquals(a, buf.remaining()); + + for (int i = 0; i < a; i++) { + assertEquals(data[i], buf.get()); + } + assertFalse(buf.hasRemaining()); + buffer.release(); + } + + @Test + public void testDuplicateReadGatheringByteChannelMultipleThreads() throws Exception { + testReadGatheringByteChannelMultipleThreads(false); + } + + @Test + public void testSliceReadGatheringByteChannelMultipleThreads() throws Exception { + testReadGatheringByteChannelMultipleThreads(true); + } + + private void testReadGatheringByteChannelMultipleThreads(final boolean slice) throws Exception { + final byte[] bytes = new byte[8]; + random.nextBytes(bytes); + + final ByteBuf buffer = newBuffer(8); + buffer.writeBytes(bytes); + final CountDownLatch latch = new CountDownLatch(60000); + final CyclicBarrier barrier = new CyclicBarrier(11); + for (int i = 0; i < 10; i++) { + new Thread(new Runnable() { + @Override + public void run() { + while (latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + + while (buf.isReadable()) { + try { + buf.readBytes(channel, buf.readableBytes()); + } catch (IOException e) { + // Never happens + return; + } + } + assertArrayEquals(bytes, channel.writtenBytes()); + latch.countDown(); + } + try { + barrier.await(); + } catch (Exception e) { + // ignore + } + } + }).start(); + } + latch.await(10, TimeUnit.SECONDS); + barrier.await(5, TimeUnit.SECONDS); + buffer.release(); + } + + @Test + public void testDuplicateReadOutputStreamMultipleThreads() throws Exception { + testReadOutputStreamMultipleThreads(false); + } + + @Test + public void testSliceReadOutputStreamMultipleThreads() throws Exception { + testReadOutputStreamMultipleThreads(true); + } + + private void 
testReadOutputStreamMultipleThreads(final boolean slice) throws Exception { + final byte[] bytes = new byte[8]; + random.nextBytes(bytes); + + final ByteBuf buffer = newBuffer(8); + buffer.writeBytes(bytes); + final CountDownLatch latch = new CountDownLatch(60000); + final CyclicBarrier barrier = new CyclicBarrier(11); + for (int i = 0; i < 10; i++) { + new Thread(new Runnable() { + @Override + public void run() { + while (latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + ByteArrayOutputStream out = new ByteArrayOutputStream(); + + while (buf.isReadable()) { + try { + buf.readBytes(out, buf.readableBytes()); + } catch (IOException e) { + // Never happens + return; + } + } + assertArrayEquals(bytes, out.toByteArray()); + latch.countDown(); + } + try { + barrier.await(); + } catch (Exception e) { + // ignore + } + } + }).start(); + } + latch.await(10, TimeUnit.SECONDS); + barrier.await(5, TimeUnit.SECONDS); + buffer.release(); + } + + @Test + public void testDuplicateBytesInArrayMultipleThreads() throws Exception { + testBytesInArrayMultipleThreads(false); + } + + @Test + public void testSliceBytesInArrayMultipleThreads() throws Exception { + testBytesInArrayMultipleThreads(true); + } + + private void testBytesInArrayMultipleThreads(final boolean slice) throws Exception { + final byte[] bytes = new byte[8]; + random.nextBytes(bytes); + + final ByteBuf buffer = newBuffer(8); + buffer.writeBytes(bytes); + final AtomicReference cause = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(60000); + final CyclicBarrier barrier = new CyclicBarrier(11); + for (int i = 0; i < 10; i++) { + new Thread(new Runnable() { + @Override + public void run() { + while (cause.get() == null && latch.getCount() > 0) { + ByteBuf buf; + if (slice) { + buf = buffer.slice(); + } else { + buf = buffer.duplicate(); + } + + byte[] array = new byte[8]; + buf.readBytes(array); + + assertArrayEquals(bytes, 
array); + + Arrays.fill(array, (byte) 0); + buf.getBytes(0, array); + assertArrayEquals(bytes, array); + + latch.countDown(); + } + try { + barrier.await(); + } catch (Exception e) { + // ignore + } + } + }).start(); + } + latch.await(10, TimeUnit.SECONDS); + barrier.await(5, TimeUnit.SECONDS); + assertNull(cause.get()); + buffer.release(); + } + + @Test + public void readByteThrowsIndexOutOfBoundsException() { + final ByteBuf buffer = newBuffer(8); + try { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.writeByte(0); + assertEquals((byte) 0, buffer.readByte()); + buffer.readByte(); + } + }); + } finally { + buffer.release(); + } + } + + @Test + @SuppressWarnings("ForLoopThatDoesntUseLoopVariable") + public void testNioBufferExposeOnlyRegion() { + final ByteBuf buffer = newBuffer(8); + byte[] data = new byte[8]; + random.nextBytes(data); + buffer.writeBytes(data); + + ByteBuffer nioBuf = buffer.nioBuffer(1, data.length - 2); + assertEquals(0, nioBuf.position()); + assertEquals(6, nioBuf.remaining()); + + for (int i = 1; nioBuf.hasRemaining(); i++) { + assertEquals(data[i], nioBuf.get()); + } + buffer.release(); + } + + @Test + public void ensureWritableWithForceDoesNotThrow() { + ensureWritableDoesNotThrow(true); + } + + @Test + public void ensureWritableWithOutForceDoesNotThrow() { + ensureWritableDoesNotThrow(false); + } + + private void ensureWritableDoesNotThrow(boolean force) { + final ByteBuf buffer = newBuffer(8); + buffer.writerIndex(buffer.capacity()); + buffer.ensureWritable(8, force); + buffer.release(); + } + + // See: + // - https://github.com/netty/netty/issues/2587 + // - https://github.com/netty/netty/issues/2580 + @Test + public void testLittleEndianWithExpand() { + ByteBuf buffer = newBuffer(0).order(LITTLE_ENDIAN); + buffer.writeInt(0x12345678); + assertEquals("78563412", ByteBufUtil.hexDump(buffer)); + buffer.release(); + } + + private ByteBuf releasedBuffer() { + ByteBuf buffer = 
newBuffer(8); + // Clear the buffer so we are sure the reader and writer indices are 0. + // This is important as we may return a slice from newBuffer(...). + buffer.clear(); + assertTrue(buffer.release()); + return buffer; + } + + @Test + public void testDiscardReadBytesAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().discardReadBytes(); + } + }); + } + + @Test + public void testDiscardSomeReadBytesAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().discardSomeReadBytes(); + } + }); + } + + @Test + public void testEnsureWritableAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().ensureWritable(16); + } + }); + } + + @Test + public void testGetBooleanAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBoolean(0); + } + }); + } + + @Test + public void testGetByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getByte(0); + } + }); + } + + @Test + public void testGetUnsignedByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getUnsignedByte(0); + } + }); + } + + @Test + public void testGetShortAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getShort(0); + } + }); + } + + @Test + public void testGetShortLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getShortLE(0); + } + }); + } + + @Test + public void 
testGetUnsignedShortAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getUnsignedShort(0); + } + }); + } + + @Test + public void testGetUnsignedShortLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getUnsignedShortLE(0); + } + }); + } + + @Test + public void testGetMediumAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getMedium(0); + } + }); + } + + @Test + public void testGetMediumLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getMediumLE(0); + } + }); + } + + @Test + public void testGetUnsignedMediumAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getUnsignedMedium(0); + } + }); + } + + @Test + public void testGetIntAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getInt(0); + } + }); + } + + @Test + public void testGetIntLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getIntLE(0); + } + }); + } + + @Test + public void testGetUnsignedIntAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getUnsignedInt(0); + } + }); + } + + @Test + public void testGetUnsignedIntLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getUnsignedIntLE(0); + } + }); + } + + @Test + public void testGetLongAfterRelease() { + 
assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getLong(0); + } + }); + } + + @Test + public void testGetLongLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getLongLE(0); + } + }); + } + + @Test + public void testGetCharAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getChar(0); + } + }); + } + + @Test + public void testGetFloatAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getFloat(0); + } + }); + } + + @Test + public void testGetFloatLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getFloatLE(0); + } + }); + } + + @Test + public void testGetDoubleAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getDouble(0); + } + }); + } + + @Test + public void testGetDoubleLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getDoubleLE(0); + } + }); + } + + @Test + public void testGetBytesAfterRelease() { + final ByteBuf buffer = buffer(8); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBytes(0, buffer); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testGetBytesAfterRelease2() { + final ByteBuf buffer = buffer(); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBytes(0, buffer, 1); + } + }); + } finally { + 
buffer.release(); + } + } + + @Test + public void testGetBytesAfterRelease3() { + final ByteBuf buffer = buffer(); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBytes(0, buffer, 0, 1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testGetBytesAfterRelease4() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBytes(0, new byte[8]); + } + }); + } + + @Test + public void testGetBytesAfterRelease5() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBytes(0, new byte[8], 0, 1); + } + }); + } + + @Test + public void testGetBytesAfterRelease6() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().getBytes(0, ByteBuffer.allocate(8)); + } + }); + } + + @Test + public void testGetBytesAfterRelease7() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().getBytes(0, new ByteArrayOutputStream(), 1); + } + }); + } + + @Test + public void testGetBytesAfterRelease8() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().getBytes(0, new DevNullGatheringByteChannel(), 1); + } + }); + } + + @Test + public void testSetBooleanAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBoolean(0, true); + } + }); + } + + @Test + public void testSetByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setByte(0, 1); + } + }); + } + + @Test + public void 
testSetShortAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setShort(0, 1); + } + }); + } + + @Test + public void testSetShortLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setShortLE(0, 1); + } + }); + } + + @Test + public void testSetMediumAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setMedium(0, 1); + } + }); + } + + @Test + public void testSetMediumLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setMediumLE(0, 1); + } + }); + } + + @Test + public void testSetIntAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setInt(0, 1); + } + }); + } + + @Test + public void testSetIntLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setIntLE(0, 1); + } + }); + } + + @Test + public void testSetLongAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setLong(0, 1); + } + }); + } + + @Test + public void testSetLongLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setLongLE(0, 1); + } + }); + } + + @Test + public void testSetCharAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setChar(0, 1); + } + }); + } + + @Test + public void testSetFloatAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new 
Executable() { + @Override + public void execute() { + releasedBuffer().setFloat(0, 1); + } + }); + } + + @Test + public void testSetDoubleAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setDouble(0, 1); + } + }); + } + + @Test + public void testSetBytesAfterRelease() { + final ByteBuf buffer = buffer(); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBytes(0, buffer); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testSetBytesAfterRelease2() { + final ByteBuf buffer = buffer(); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBytes(0, buffer, 1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testSetBytesAfterRelease3() { + final ByteBuf buffer = buffer(); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBytes(0, buffer, 0, 1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testSetUsAsciiCharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceAfterRelease0(CharsetUtil.US_ASCII); + } + }); + } + + @Test + public void testSetIso88591CharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceAfterRelease0(CharsetUtil.ISO_8859_1); + } + }); + } + + @Test + public void testSetUtf8CharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceAfterRelease0(CharsetUtil.UTF_8); + } + }); + } + + @Test + public void 
testSetUtf16CharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceAfterRelease0(CharsetUtil.UTF_16); + } + }); + } + + private void testSetCharSequenceAfterRelease0(Charset charset) { + releasedBuffer().setCharSequence(0, "x", charset); + } + + @Test + public void testSetBytesAfterRelease4() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBytes(0, new byte[8]); + } + }); + } + + @Test + public void testSetBytesAfterRelease5() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBytes(0, new byte[8], 0, 1); + } + }); + } + + @Test + public void testSetBytesAfterRelease6() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setBytes(0, ByteBuffer.allocate(8)); + } + }); + } + + @Test + public void testSetBytesAfterRelease7() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().setBytes(0, new ByteArrayInputStream(new byte[8]), 1); + } + }); + } + + @Test + public void testSetBytesAfterRelease8() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().setBytes(0, new TestScatteringByteChannel(), 1); + } + }); + } + + @Test + public void testSetZeroAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().setZero(0, 1); + } + }); + } + + @Test + public void testReadBooleanAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBoolean(); + } + }); + } + + 
@Test + public void testReadByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readByte(); + } + }); + } + + @Test + public void testReadUnsignedByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedByte(); + } + }); + } + + @Test + public void testReadShortAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readShort(); + } + }); + } + + @Test + public void testReadShortLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readShortLE(); + } + }); + } + + @Test + public void testReadUnsignedShortAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedShort(); + } + }); + } + + @Test + public void testReadUnsignedShortLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedShortLE(); + } + }); + } + + @Test + public void testReadMediumAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readMedium(); + } + }); + } + + @Test + public void testReadMediumLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readMediumLE(); + } + }); + } + + @Test + public void testReadUnsignedMediumAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedMedium(); + } + }); + } + + @Test + public void 
testReadUnsignedMediumLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedMediumLE(); + } + }); + } + + @Test + public void testReadIntAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readInt(); + } + }); + } + + @Test + public void testReadIntLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readIntLE(); + } + }); + } + + @Test + public void testReadUnsignedIntAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedInt(); + } + }); + } + + @Test + public void testReadUnsignedIntLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readUnsignedIntLE(); + } + }); + } + + @Test + public void testReadLongAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readLong(); + } + }); + } + + @Test + public void testReadLongLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readLongLE(); + } + }); + } + + @Test + public void testReadCharAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readChar(); + } + }); + } + + @Test + public void testReadFloatAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readFloat(); + } + }); + } + + @Test + public void testReadFloatLEAfterRelease() { + 
assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readFloatLE(); + } + }); + } + + @Test + public void testReadDoubleAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readDouble(); + } + }); + } + + @Test + public void testReadDoubleLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readDoubleLE(); + } + }); + } + + @Test + public void testReadBytesAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(1); + } + }); + } + + @Test + public void testReadBytesAfterRelease2() { + final ByteBuf buffer = buffer(8); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(buffer); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testReadBytesAfterRelease3() { + final ByteBuf buffer = buffer(8); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(buffer); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testReadBytesAfterRelease4() { + final ByteBuf buffer = buffer(8); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(buffer, 0, 1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testReadBytesAfterRelease5() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(new byte[8]); + } + }); + } + + @Test + public void testReadBytesAfterRelease6() { + 
assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(new byte[8], 0, 1); + } + }); + } + + @Test + public void testReadBytesAfterRelease7() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().readBytes(ByteBuffer.allocate(8)); + } + }); + } + + @Test + public void testReadBytesAfterRelease8() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().readBytes(new ByteArrayOutputStream(), 1); + } + }); + } + + @Test + public void testReadBytesAfterRelease9() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().readBytes(new ByteArrayOutputStream(), 1); + } + }); + } + + @Test + public void testReadBytesAfterRelease10() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().readBytes(new DevNullGatheringByteChannel(), 1); + } + }); + } + + @Test + public void testWriteBooleanAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeBoolean(true); + } + }); + } + + @Test + public void testWriteByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeByte(1); + } + }); + } + + @Test + public void testWriteShortAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeShort(1); + } + }); + } + + @Test + public void testWriteShortLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() 
{ + releasedBuffer().writeShortLE(1); + } + }); + } + + @Test + public void testWriteMediumAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeMedium(1); + } + }); + } + + @Test + public void testWriteMediumLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeMediumLE(1); + } + }); + } + + @Test + public void testWriteIntAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeInt(1); + } + }); + } + + @Test + public void testWriteIntLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeIntLE(1); + } + }); + } + + @Test + public void testWriteLongAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeLong(1); + } + }); + } + + @Test + public void testWriteLongLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeLongLE(1); + } + }); + } + + @Test + public void testWriteCharAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeChar(1); + } + }); + } + + @Test + public void testWriteFloatAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeFloat(1); + } + }); + } + + @Test + public void testWriteFloatLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeFloatLE(1); + } + }); + } + + @Test + public void 
testWriteDoubleAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeDouble(1); + } + }); + } + + @Test + public void testWriteDoubleLEAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeDoubleLE(1); + } + }); + } + + @Test + public void testWriteBytesAfterRelease() { + final ByteBuf buffer = buffer(8); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeBytes(buffer); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testWriteBytesAfterRelease2() { + final ByteBuf buffer = copiedBuffer(new byte[8]); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeBytes(buffer, 1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testWriteBytesAfterRelease3() { + final ByteBuf buffer = buffer(8); + try { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeBytes(buffer, 0, 1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testWriteBytesAfterRelease4() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeBytes(new byte[8]); + } + }); + } + + @Test + public void testWriteBytesAfterRelease5() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeBytes(new byte[8], 0, 1); + } + }); + } + + @Test + public void testWriteBytesAfterRelease6() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + 
releasedBuffer().writeBytes(ByteBuffer.allocate(8)); + } + }); + } + + @Test + public void testWriteBytesAfterRelease7() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().writeBytes(new ByteArrayInputStream(new byte[8]), 1); + } + }); + } + + @Test + public void testWriteBytesAfterRelease8() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() throws IOException { + releasedBuffer().writeBytes(new TestScatteringByteChannel(), 1); + } + }); + } + + @Test + public void testWriteZeroAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().writeZero(1); + } + }); + } + + @Test + public void testWriteUsAsciiCharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testWriteCharSequenceAfterRelease0(CharsetUtil.US_ASCII); + } + }); + } + + @Test + public void testWriteIso88591CharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testWriteCharSequenceAfterRelease0(CharsetUtil.ISO_8859_1); + } + }); + } + + @Test + public void testWriteUtf8CharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + testWriteCharSequenceAfterRelease0(CharsetUtil.UTF_8); + } + }); + } + + @Test + public void testWriteUtf16CharSequenceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, + new Executable() { + @Override + public void execute() { + testWriteCharSequenceAfterRelease0(CharsetUtil.UTF_16); + } + }); + } + + private void testWriteCharSequenceAfterRelease0(Charset charset) { + releasedBuffer().writeCharSequence("x", charset); + } + + @Test + public void 
testForEachByteAfterRelease() { + assertThrows(IllegalReferenceCountException.class, + new Executable() { + @Override + public void execute() { + releasedBuffer().forEachByte(new TestByteProcessor()); + } + }); + } + + @Test + public void testForEachByteAfterRelease1() { + assertThrows(IllegalReferenceCountException.class, + new Executable() { + @Override + public void execute() { + releasedBuffer().forEachByte(0, 1, new TestByteProcessor()); + } + }); + } + + @Test + public void testForEachByteDescAfterRelease() { + assertThrows(IllegalReferenceCountException.class, + new Executable() { + @Override + public void execute() { + releasedBuffer().forEachByteDesc(new TestByteProcessor()); + } + }); + } + + @Test + public void testForEachByteDescAfterRelease1() { + assertThrows(IllegalReferenceCountException.class, + new Executable() { + @Override + public void execute() { + releasedBuffer().forEachByteDesc(0, 1, new TestByteProcessor()); + } + }); + } + + @Test + public void testCopyAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().copy(); + } + }); + } + + @Test + public void testCopyAfterRelease1() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().copy(); + } + }); + } + + @Test + public void testNioBufferAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().nioBuffer(); + } + }); + } + + @Test + public void testNioBufferAfterRelease1() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().nioBuffer(0, 1); + } + }); + } + + @Test + public void testInternalNioBufferAfterRelease() { + testInternalNioBufferAfterRelease0(IllegalReferenceCountException.class); + } + + protected void testInternalNioBufferAfterRelease0(final Class 
expectedException) { + final ByteBuf releasedBuffer = releasedBuffer(); + assertThrows(expectedException, new Executable() { + @Override + public void execute() { + releasedBuffer.internalNioBuffer(releasedBuffer.readerIndex(), 1); + } + }); + } + + @Test + public void testNioBuffersAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().nioBuffers(); + } + }); + } + + @Test + public void testNioBuffersAfterRelease2() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().nioBuffers(0, 1); + } + }); + } + + @Test + public void testArrayAfterRelease() { + ByteBuf buf = releasedBuffer(); + if (buf.hasArray()) { + try { + buf.array(); + fail(); + } catch (IllegalReferenceCountException e) { + // expected + } + } + } + + @Test + public void testMemoryAddressAfterRelease() { + ByteBuf buf = releasedBuffer(); + if (buf.hasMemoryAddress()) { + try { + buf.memoryAddress(); + fail(); + } catch (IllegalReferenceCountException e) { + // expected + } + } + } + + @Test + public void testSliceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().slice(); + } + }); + } + + @Test + public void testSliceAfterRelease2() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().slice(0, 1); + } + }); + } + + private static void assertSliceFailAfterRelease(ByteBuf... 
bufs) { + for (ByteBuf buf : bufs) { + if (buf.refCnt() > 0) { + buf.release(); + } + } + for (ByteBuf buf : bufs) { + try { + assertEquals(0, buf.refCnt()); + buf.slice(); + fail(); + } catch (IllegalReferenceCountException ignored) { + // as expected + } + } + } + + @Test + public void testSliceAfterReleaseRetainedSlice() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + assertSliceFailAfterRelease(buf, buf2); + } + + @Test + public void testSliceAfterReleaseRetainedSliceDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + ByteBuf buf3 = buf2.duplicate(); + assertSliceFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testSliceAfterReleaseRetainedSliceRetainedDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + ByteBuf buf3 = buf2.retainedDuplicate(); + assertSliceFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testSliceAfterReleaseRetainedDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + assertSliceFailAfterRelease(buf, buf2); + } + + @Test + public void testSliceAfterReleaseRetainedDuplicateSlice() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + ByteBuf buf3 = buf2.slice(0, 1); + assertSliceFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testRetainedSliceAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().retainedSlice(); + } + }); + } + + @Test + public void testRetainedSliceAfterRelease2() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().retainedSlice(0, 1); + } + }); + } + + private static void assertRetainedSliceFailAfterRelease(ByteBuf... 
bufs) { + for (ByteBuf buf : bufs) { + if (buf.refCnt() > 0) { + buf.release(); + } + } + for (ByteBuf buf : bufs) { + try { + assertEquals(0, buf.refCnt()); + buf.retainedSlice(); + fail(); + } catch (IllegalReferenceCountException ignored) { + // as expected + } + } + } + + @Test + public void testRetainedSliceAfterReleaseRetainedSlice() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + assertRetainedSliceFailAfterRelease(buf, buf2); + } + + @Test + public void testRetainedSliceAfterReleaseRetainedSliceDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + ByteBuf buf3 = buf2.duplicate(); + assertRetainedSliceFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testRetainedSliceAfterReleaseRetainedSliceRetainedDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + ByteBuf buf3 = buf2.retainedDuplicate(); + assertRetainedSliceFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testRetainedSliceAfterReleaseRetainedDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + assertRetainedSliceFailAfterRelease(buf, buf2); + } + + @Test + public void testRetainedSliceAfterReleaseRetainedDuplicateSlice() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + ByteBuf buf3 = buf2.slice(0, 1); + assertRetainedSliceFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testDuplicateAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().duplicate(); + } + }); + } + + @Test + public void testRetainedDuplicateAfterRelease() { + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + releasedBuffer().retainedDuplicate(); + } + }); + } + + private static void assertDuplicateFailAfterRelease(ByteBuf... 
bufs) { + for (ByteBuf buf : bufs) { + if (buf.refCnt() > 0) { + buf.release(); + } + } + for (ByteBuf buf : bufs) { + try { + assertEquals(0, buf.refCnt()); + buf.duplicate(); + fail(); + } catch (IllegalReferenceCountException ignored) { + // as expected + } + } + } + + @Test + public void testDuplicateAfterReleaseRetainedSliceDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + ByteBuf buf3 = buf2.duplicate(); + assertDuplicateFailAfterRelease(buf, buf2, buf3); + } + + @Test + public void testDuplicateAfterReleaseRetainedDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + assertDuplicateFailAfterRelease(buf, buf2); + } + + @Test + public void testDuplicateAfterReleaseRetainedDuplicateSlice() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + ByteBuf buf3 = buf2.slice(0, 1); + assertDuplicateFailAfterRelease(buf, buf2, buf3); + } + + private static void assertRetainedDuplicateFailAfterRelease(ByteBuf... 
bufs) { + for (ByteBuf buf : bufs) { + if (buf.refCnt() > 0) { + buf.release(); + } + } + for (ByteBuf buf : bufs) { + try { + assertEquals(0, buf.refCnt()); + buf.retainedDuplicate(); + fail(); + } catch (IllegalReferenceCountException ignored) { + // as expected + } + } + } + + @Test + public void testRetainedDuplicateAfterReleaseRetainedDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedDuplicate(); + assertRetainedDuplicateFailAfterRelease(buf, buf2); + } + + @Test + public void testRetainedDuplicateAfterReleaseDuplicate() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.duplicate(); + assertRetainedDuplicateFailAfterRelease(buf, buf2); + } + + @Test + public void testRetainedDuplicateAfterReleaseRetainedSlice() { + ByteBuf buf = newBuffer(1); + ByteBuf buf2 = buf.retainedSlice(0, 1); + assertRetainedDuplicateFailAfterRelease(buf, buf2); + } + + @Test + public void testSliceRelease() { + ByteBuf buf = newBuffer(8); + assertEquals(1, buf.refCnt()); + assertTrue(buf.slice().release()); + assertEquals(0, buf.refCnt()); + } + + @Test + public void testReadSliceOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testReadSliceOutOfBounds(false); + } + }); + } + + @Test + public void testReadRetainedSliceOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testReadSliceOutOfBounds(true); + } + }); + } + + private void testReadSliceOutOfBounds(boolean retainedSlice) { + ByteBuf buf = newBuffer(100); + try { + buf.writeZero(50); + if (retainedSlice) { + buf.readRetainedSlice(51); + } else { + buf.readSlice(51); + } + fail(); + } finally { + buf.release(); + } + } + + @Test + public void testWriteUsAsciiCharSequenceExpand() { + testWriteCharSequenceExpand(CharsetUtil.US_ASCII); + } + + @Test + public void testWriteUtf8CharSequenceExpand() { + testWriteCharSequenceExpand(CharsetUtil.UTF_8); + } + + @Test + 
public void testWriteIso88591CharSequenceExpand() { + testWriteCharSequenceExpand(CharsetUtil.ISO_8859_1); + } + @Test + public void testWriteUtf16CharSequenceExpand() { + testWriteCharSequenceExpand(CharsetUtil.UTF_16); + } + + private void testWriteCharSequenceExpand(Charset charset) { + ByteBuf buf = newBuffer(1); + try { + int writerIndex = buf.capacity() - 1; + buf.writerIndex(writerIndex); + int written = buf.writeCharSequence("AB", charset); + assertEquals(writerIndex, buf.writerIndex() - written); + } finally { + buf.release(); + } + } + + @Test + public void testSetUsAsciiCharSequenceNoExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceNoExpand(CharsetUtil.US_ASCII); + } + }); + } + + @Test + public void testSetUtf8CharSequenceNoExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceNoExpand(CharsetUtil.UTF_8); + } + }); + } + + @Test + public void testSetIso88591CharSequenceNoExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceNoExpand(CharsetUtil.ISO_8859_1); + } + }); + } + + @Test + public void testSetUtf16CharSequenceNoExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSetCharSequenceNoExpand(CharsetUtil.UTF_16); + } + }); + } + + private void testSetCharSequenceNoExpand(Charset charset) { + ByteBuf buf = newBuffer(1); + try { + buf.setCharSequence(0, "AB", charset); + } finally { + buf.release(); + } + } + + @Test + public void testSetUsAsciiCharSequence() { + testSetGetCharSequence(CharsetUtil.US_ASCII); + } + + @Test + public void testSetUtf8CharSequence() { + testSetGetCharSequence(CharsetUtil.UTF_8); + } + + @Test + public void testSetIso88591CharSequence() { + testSetGetCharSequence(CharsetUtil.ISO_8859_1); + } + + @Test + 
public void testSetUtf16CharSequence() { + testSetGetCharSequence(CharsetUtil.UTF_16); + } + + private static final CharBuffer EXTENDED_ASCII_CHARS, ASCII_CHARS; + + static { + char[] chars = new char[256]; + for (char c = 0; c < chars.length; c++) { + chars[c] = c; + } + EXTENDED_ASCII_CHARS = CharBuffer.wrap(chars); + ASCII_CHARS = CharBuffer.wrap(chars, 0, 128); + } + + private void testSetGetCharSequence(Charset charset) { + ByteBuf buf = newBuffer(1024); + CharBuffer sequence = CharsetUtil.US_ASCII.equals(charset) + ? ASCII_CHARS : EXTENDED_ASCII_CHARS; + int bytes = buf.setCharSequence(1, sequence, charset); + assertEquals(sequence, CharBuffer.wrap(buf.getCharSequence(1, bytes, charset))); + buf.release(); + } + + @Test + public void testWriteReadUsAsciiCharSequence() { + testWriteReadCharSequence(CharsetUtil.US_ASCII); + } + + @Test + public void testWriteReadUtf8CharSequence() { + testWriteReadCharSequence(CharsetUtil.UTF_8); + } + + @Test + public void testWriteReadIso88591CharSequence() { + testWriteReadCharSequence(CharsetUtil.ISO_8859_1); + } + + @Test + public void testWriteReadUtf16CharSequence() { + testWriteReadCharSequence(CharsetUtil.UTF_16); + } + + private void testWriteReadCharSequence(Charset charset) { + ByteBuf buf = newBuffer(1024); + CharBuffer sequence = CharsetUtil.US_ASCII.equals(charset) + ? 
ASCII_CHARS : EXTENDED_ASCII_CHARS; + buf.writerIndex(1); + int bytes = buf.writeCharSequence(sequence, charset); + buf.readerIndex(1); + assertEquals(sequence, CharBuffer.wrap(buf.readCharSequence(bytes, charset))); + buf.release(); + } + + @Test + public void testRetainedSliceIndexOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(true, true, true); + } + }); + } + + @Test + public void testRetainedSliceLengthOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(true, true, false); + } + }); + } + + @Test + public void testMixedSliceAIndexOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(true, false, true); + } + }); + } + + @Test + public void testMixedSliceALengthOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(true, false, false); + } + }); + } + + @Test + public void testMixedSliceBIndexOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(false, true, true); + } + }); + } + + @Test + public void testMixedSliceBLengthOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(false, true, false); + } + }); + } + + @Test + public void testSliceIndexOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(false, false, true); + } + }); + } + + @Test + public void testSliceLengthOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + testSliceOutOfBounds(false, false, false); 
+ } + }); + } + + @Test + public void testRetainedSliceAndRetainedDuplicateContentIsExpected() { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected1 = newBuffer(6).resetWriterIndex(); + ByteBuf expected2 = newBuffer(5).resetWriterIndex(); + ByteBuf expected3 = newBuffer(4).resetWriterIndex(); + ByteBuf expected4 = newBuffer(3).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected1.writeBytes(new byte[] {2, 3, 4, 5, 6, 7}); + expected2.writeBytes(new byte[] {3, 4, 5, 6, 7}); + expected3.writeBytes(new byte[] {4, 5, 6, 7}); + expected4.writeBytes(new byte[] {5, 6, 7}); + + ByteBuf slice1 = buf.retainedSlice(buf.readerIndex() + 1, 6); + assertEquals(0, slice1.compareTo(expected1)); + assertEquals(0, slice1.compareTo(buf.slice(buf.readerIndex() + 1, 6))); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + // Advance the reader index on the slice. + slice1.readByte(); + + ByteBuf dup1 = slice1.retainedDuplicate(); + assertEquals(0, dup1.compareTo(expected2)); + assertEquals(0, dup1.compareTo(slice1.duplicate())); + + // Advance the reader index on dup1. + dup1.readByte(); + + ByteBuf dup2 = dup1.duplicate(); + assertEquals(0, dup2.compareTo(expected3)); + + // Advance the reader index on dup2. + dup2.readByte(); + + ByteBuf slice2 = dup2.retainedSlice(dup2.readerIndex(), 3); + assertEquals(0, slice2.compareTo(expected4)); + assertEquals(0, slice2.compareTo(dup2.slice(dup2.readerIndex(), 3))); + + // Cleanup the expected buffers used for testing. 
+ assertTrue(expected1.release()); + assertTrue(expected2.release()); + assertTrue(expected3.release()); + assertTrue(expected4.release()); + + slice2.release(); + dup2.release(); + + assertEquals(slice2.refCnt(), dup2.refCnt()); + assertEquals(dup2.refCnt(), dup1.refCnt()); + + // The handler is now done with the original slice + assertTrue(slice1.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. + assertEquals(0, buf.refCnt()); + assertEquals(0, slice1.refCnt()); + assertEquals(0, slice2.refCnt()); + assertEquals(0, dup1.refCnt()); + assertEquals(0, dup2.refCnt()); + } + + @Test + public void testRetainedDuplicateAndRetainedSliceContentIsExpected() { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected1 = newBuffer(6).resetWriterIndex(); + ByteBuf expected2 = newBuffer(5).resetWriterIndex(); + ByteBuf expected3 = newBuffer(4).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected1.writeBytes(new byte[] {2, 3, 4, 5, 6, 7}); + expected2.writeBytes(new byte[] {3, 4, 5, 6, 7}); + expected3.writeBytes(new byte[] {5, 6, 7}); + + ByteBuf dup1 = buf.retainedDuplicate(); + assertEquals(0, dup1.compareTo(buf)); + assertEquals(0, dup1.compareTo(buf.slice())); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + // Advance the reader index on the dup. + dup1.readByte(); + + ByteBuf slice1 = dup1.retainedSlice(dup1.readerIndex(), 6); + assertEquals(0, slice1.compareTo(expected1)); + assertEquals(0, slice1.compareTo(slice1.duplicate())); + + // Advance the reader index on slice1. + slice1.readByte(); + + ByteBuf dup2 = slice1.duplicate(); + assertEquals(0, dup2.compareTo(slice1)); + + // Advance the reader index on dup2. 
+ dup2.readByte(); + + ByteBuf slice2 = dup2.retainedSlice(dup2.readerIndex() + 1, 3); + assertEquals(0, slice2.compareTo(expected3)); + assertEquals(0, slice2.compareTo(dup2.slice(dup2.readerIndex() + 1, 3))); + + // Cleanup the expected buffers used for testing. + assertTrue(expected1.release()); + assertTrue(expected2.release()); + assertTrue(expected3.release()); + + slice2.release(); + slice1.release(); + + assertEquals(slice2.refCnt(), dup2.refCnt()); + assertEquals(dup2.refCnt(), slice1.refCnt()); + + // The handler is now done with the original slice + assertTrue(dup1.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. + assertEquals(0, buf.refCnt()); + assertEquals(0, slice1.refCnt()); + assertEquals(0, slice2.refCnt()); + assertEquals(0, dup1.refCnt()); + assertEquals(0, dup2.refCnt()); + } + + @Test + public void testRetainedSliceContents() { + testSliceContents(true); + } + + @Test + public void testMultipleLevelRetainedSlice1() { + testMultipleLevelRetainedSliceWithNonRetained(true, true); + } + + @Test + public void testMultipleLevelRetainedSlice2() { + testMultipleLevelRetainedSliceWithNonRetained(true, false); + } + + @Test + public void testMultipleLevelRetainedSlice3() { + testMultipleLevelRetainedSliceWithNonRetained(false, true); + } + + @Test + public void testMultipleLevelRetainedSlice4() { + testMultipleLevelRetainedSliceWithNonRetained(false, false); + } + + @Test + public void testRetainedSliceReleaseOriginal1() { + testSliceReleaseOriginal(true, true); + } + + @Test + public void testRetainedSliceReleaseOriginal2() { + testSliceReleaseOriginal(true, false); + } + + @Test + public void testRetainedSliceReleaseOriginal3() { + testSliceReleaseOriginal(false, true); + } + + @Test + public void testRetainedSliceReleaseOriginal4() { + testSliceReleaseOriginal(false, false); + } + + @Test + public void 
testRetainedDuplicateReleaseOriginal1() { + testDuplicateReleaseOriginal(true, true); + } + + @Test + public void testRetainedDuplicateReleaseOriginal2() { + testDuplicateReleaseOriginal(true, false); + } + + @Test + public void testRetainedDuplicateReleaseOriginal3() { + testDuplicateReleaseOriginal(false, true); + } + + @Test + public void testRetainedDuplicateReleaseOriginal4() { + testDuplicateReleaseOriginal(false, false); + } + + @Test + public void testMultipleRetainedSliceReleaseOriginal1() { + testMultipleRetainedSliceReleaseOriginal(true, true); + } + + @Test + public void testMultipleRetainedSliceReleaseOriginal2() { + testMultipleRetainedSliceReleaseOriginal(true, false); + } + + @Test + public void testMultipleRetainedSliceReleaseOriginal3() { + testMultipleRetainedSliceReleaseOriginal(false, true); + } + + @Test + public void testMultipleRetainedSliceReleaseOriginal4() { + testMultipleRetainedSliceReleaseOriginal(false, false); + } + + @Test + public void testMultipleRetainedDuplicateReleaseOriginal1() { + testMultipleRetainedDuplicateReleaseOriginal(true, true); + } + + @Test + public void testMultipleRetainedDuplicateReleaseOriginal2() { + testMultipleRetainedDuplicateReleaseOriginal(true, false); + } + + @Test + public void testMultipleRetainedDuplicateReleaseOriginal3() { + testMultipleRetainedDuplicateReleaseOriginal(false, true); + } + + @Test + public void testMultipleRetainedDuplicateReleaseOriginal4() { + testMultipleRetainedDuplicateReleaseOriginal(false, false); + } + + @Test + public void testSliceContents() { + testSliceContents(false); + } + + @Test + public void testRetainedDuplicateContents() { + testDuplicateContents(true); + } + + @Test + public void testDuplicateContents() { + testDuplicateContents(false); + } + + @Test + public void testDuplicateCapacityChange() { + testDuplicateCapacityChange(false); + } + + @Test + public void testRetainedDuplicateCapacityChange() { + testDuplicateCapacityChange(true); + } + + @Test + public void 
testSliceCapacityChange() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testSliceCapacityChange(false); + } + }); + } + + @Test + public void testRetainedSliceCapacityChange() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testSliceCapacityChange(true); + } + }); + } + + @Test + public void testRetainedSliceUnreleasable1() { + testRetainedSliceUnreleasable(true, true); + } + + @Test + public void testRetainedSliceUnreleasable2() { + testRetainedSliceUnreleasable(true, false); + } + + @Test + public void testRetainedSliceUnreleasable3() { + testRetainedSliceUnreleasable(false, true); + } + + @Test + public void testRetainedSliceUnreleasable4() { + testRetainedSliceUnreleasable(false, false); + } + + @Test + public void testReadRetainedSliceUnreleasable1() { + testReadRetainedSliceUnreleasable(true, true); + } + + @Test + public void testReadRetainedSliceUnreleasable2() { + testReadRetainedSliceUnreleasable(true, false); + } + + @Test + public void testReadRetainedSliceUnreleasable3() { + testReadRetainedSliceUnreleasable(false, true); + } + + @Test + public void testReadRetainedSliceUnreleasable4() { + testReadRetainedSliceUnreleasable(false, false); + } + + @Test + public void testRetainedDuplicateUnreleasable1() { + testRetainedDuplicateUnreleasable(true, true); + } + + @Test + public void testRetainedDuplicateUnreleasable2() { + testRetainedDuplicateUnreleasable(true, false); + } + + @Test + public void testRetainedDuplicateUnreleasable3() { + testRetainedDuplicateUnreleasable(false, true); + } + + @Test + public void testRetainedDuplicateUnreleasable4() { + testRetainedDuplicateUnreleasable(false, false); + } + + private void testRetainedSliceUnreleasable(boolean initRetainedSlice, boolean finalRetainedSlice) { + ByteBuf buf = newBuffer(8); + ByteBuf buf1 = initRetainedSlice ? 
buf.retainedSlice() : buf.slice().retain(); + ByteBuf buf2 = unreleasableBuffer(buf1); + ByteBuf buf3 = finalRetainedSlice ? buf2.retainedSlice() : buf2.slice().retain(); + assertFalse(buf3.release()); + assertFalse(buf2.release()); + buf1.release(); + assertTrue(buf.release()); + assertEquals(0, buf1.refCnt()); + assertEquals(0, buf.refCnt()); + } + + private void testReadRetainedSliceUnreleasable(boolean initRetainedSlice, boolean finalRetainedSlice) { + ByteBuf buf = newBuffer(8); + ByteBuf buf1 = initRetainedSlice ? buf.retainedSlice() : buf.slice().retain(); + ByteBuf buf2 = unreleasableBuffer(buf1); + ByteBuf buf3 = finalRetainedSlice ? buf2.readRetainedSlice(buf2.readableBytes()) + : buf2.readSlice(buf2.readableBytes()).retain(); + assertFalse(buf3.release()); + assertFalse(buf2.release()); + buf1.release(); + assertTrue(buf.release()); + assertEquals(0, buf1.refCnt()); + assertEquals(0, buf.refCnt()); + } + + private void testRetainedDuplicateUnreleasable(boolean initRetainedDuplicate, boolean finalRetainedDuplicate) { + ByteBuf buf = newBuffer(8); + ByteBuf buf1 = initRetainedDuplicate ? buf.retainedDuplicate() : buf.duplicate().retain(); + ByteBuf buf2 = unreleasableBuffer(buf1); + ByteBuf buf3 = finalRetainedDuplicate ? buf2.retainedDuplicate() : buf2.duplicate().retain(); + assertFalse(buf3.release()); + assertFalse(buf2.release()); + buf1.release(); + assertTrue(buf.release()); + assertEquals(0, buf1.refCnt()); + assertEquals(0, buf.refCnt()); + } + + private void testDuplicateCapacityChange(boolean retainedDuplicate) { + ByteBuf buf = newBuffer(8); + ByteBuf dup = retainedDuplicate ? 
buf.retainedDuplicate() : buf.duplicate(); + try { + dup.capacity(10); + assertEquals(buf.capacity(), dup.capacity()); + dup.capacity(5); + assertEquals(buf.capacity(), dup.capacity()); + } finally { + if (retainedDuplicate) { + dup.release(); + } + buf.release(); + } + } + + private void testSliceCapacityChange(boolean retainedSlice) { + ByteBuf buf = newBuffer(8); + ByteBuf slice = retainedSlice ? buf.retainedSlice(buf.readerIndex() + 1, 3) + : buf.slice(buf.readerIndex() + 1, 3); + try { + slice.capacity(10); + } finally { + if (retainedSlice) { + slice.release(); + } + buf.release(); + } + } + + private void testSliceOutOfBounds(boolean initRetainedSlice, boolean finalRetainedSlice, boolean indexOutOfBounds) { + ByteBuf buf = newBuffer(8); + ByteBuf slice = initRetainedSlice ? buf.retainedSlice(buf.readerIndex() + 1, 2) + : buf.slice(buf.readerIndex() + 1, 2); + try { + assertEquals(2, slice.capacity()); + assertEquals(2, slice.maxCapacity()); + final int index = indexOutOfBounds ? 3 : 0; + final int length = indexOutOfBounds ? 0 : 3; + if (finalRetainedSlice) { + // This is expected to fail ... so no need to release. + slice.retainedSlice(index, length); + } else { + slice.slice(index, length); + } + } finally { + if (initRetainedSlice) { + slice.release(); + } + buf.release(); + } + } + + private void testSliceContents(boolean retainedSlice) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected = newBuffer(3).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected.writeBytes(new byte[] {4, 5, 6}); + ByteBuf slice = retainedSlice ? 
buf.retainedSlice(buf.readerIndex() + 3, 3) + : buf.slice(buf.readerIndex() + 3, 3); + try { + assertEquals(0, slice.compareTo(expected)); + assertEquals(0, slice.compareTo(slice.duplicate())); + ByteBuf b = slice.retainedDuplicate(); + assertEquals(0, slice.compareTo(b)); + b.release(); + assertEquals(0, slice.compareTo(slice.slice(0, slice.capacity()))); + } finally { + if (retainedSlice) { + slice.release(); + } + buf.release(); + expected.release(); + } + } + + private void testSliceReleaseOriginal(boolean retainedSlice1, boolean retainedSlice2) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected1 = newBuffer(3).resetWriterIndex(); + ByteBuf expected2 = newBuffer(2).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected1.writeBytes(new byte[] {6, 7, 8}); + expected2.writeBytes(new byte[] {7, 8}); + ByteBuf slice1 = retainedSlice1 ? buf.retainedSlice(buf.readerIndex() + 5, 3) + : buf.slice(buf.readerIndex() + 5, 3).retain(); + assertEquals(0, slice1.compareTo(expected1)); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + ByteBuf slice2 = retainedSlice2 ? slice1.retainedSlice(slice1.readerIndex() + 1, 2) + : slice1.slice(slice1.readerIndex() + 1, 2).retain(); + assertEquals(0, slice2.compareTo(expected2)); + + // Cleanup the expected buffers used for testing. + assertTrue(expected1.release()); + assertTrue(expected2.release()); + + // The handler created a slice of the slice and is now done with it. + slice2.release(); + + // The handler is now done with the original slice + assertTrue(slice1.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. 
+ assertEquals(0, buf.refCnt()); + assertEquals(0, slice1.refCnt()); + assertEquals(0, slice2.refCnt()); + } + + private void testMultipleLevelRetainedSliceWithNonRetained(boolean doSlice1, boolean doSlice2) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected1 = newBuffer(6).resetWriterIndex(); + ByteBuf expected2 = newBuffer(4).resetWriterIndex(); + ByteBuf expected3 = newBuffer(2).resetWriterIndex(); + ByteBuf expected4SliceSlice = newBuffer(1).resetWriterIndex(); + ByteBuf expected4DupSlice = newBuffer(1).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected1.writeBytes(new byte[] {2, 3, 4, 5, 6, 7}); + expected2.writeBytes(new byte[] {3, 4, 5, 6}); + expected3.writeBytes(new byte[] {4, 5}); + expected4SliceSlice.writeBytes(new byte[] {5}); + expected4DupSlice.writeBytes(new byte[] {4}); + + ByteBuf slice1 = buf.retainedSlice(buf.readerIndex() + 1, 6); + assertEquals(0, slice1.compareTo(expected1)); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + ByteBuf slice2 = slice1.retainedSlice(slice1.readerIndex() + 1, 4); + assertEquals(0, slice2.compareTo(expected2)); + assertEquals(0, slice2.compareTo(slice2.duplicate())); + assertEquals(0, slice2.compareTo(slice2.slice())); + + ByteBuf tmpBuf = slice2.retainedDuplicate(); + assertEquals(0, slice2.compareTo(tmpBuf)); + tmpBuf.release(); + tmpBuf = slice2.retainedSlice(); + assertEquals(0, slice2.compareTo(tmpBuf)); + tmpBuf.release(); + + ByteBuf slice3 = doSlice1 ? slice2.slice(slice2.readerIndex() + 1, 2) : slice2.duplicate(); + if (doSlice1) { + assertEquals(0, slice3.compareTo(expected3)); + } else { + assertEquals(0, slice3.compareTo(expected2)); + } + + ByteBuf slice4 = doSlice2 ? 
slice3.slice(slice3.readerIndex() + 1, 1) : slice3.duplicate(); + if (doSlice1 && doSlice2) { + assertEquals(0, slice4.compareTo(expected4SliceSlice)); + } else if (doSlice2) { + assertEquals(0, slice4.compareTo(expected4DupSlice)); + } else { + assertEquals(0, slice3.compareTo(slice4)); + } + + // Cleanup the expected buffers used for testing. + assertTrue(expected1.release()); + assertTrue(expected2.release()); + assertTrue(expected3.release()); + assertTrue(expected4SliceSlice.release()); + assertTrue(expected4DupSlice.release()); + + // Slice 4, 3, and 2 should effectively "share" a reference count. + slice4.release(); + assertEquals(slice3.refCnt(), slice2.refCnt()); + assertEquals(slice3.refCnt(), slice4.refCnt()); + + // Slice 1 should also release the original underlying buffer without throwing exceptions + assertTrue(slice1.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. + assertEquals(0, buf.refCnt()); + assertEquals(0, slice1.refCnt()); + assertEquals(0, slice2.refCnt()); + assertEquals(0, slice3.refCnt()); + } + + private void testDuplicateReleaseOriginal(boolean retainedDuplicate1, boolean retainedDuplicate2) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected = newBuffer(8).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected.writeBytes(buf, buf.readerIndex(), buf.readableBytes()); + ByteBuf dup1 = retainedDuplicate1 ? buf.retainedDuplicate() + : buf.duplicate().retain(); + assertEquals(0, dup1.compareTo(expected)); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + ByteBuf dup2 = retainedDuplicate2 ? dup1.retainedDuplicate() + : dup1.duplicate().retain(); + assertEquals(0, dup2.compareTo(expected)); + + // Cleanup the expected buffers used for testing. 
+ assertTrue(expected.release()); + + // The handler created a slice of the slice and is now done with it. + dup2.release(); + + // The handler is now done with the original slice + assertTrue(dup1.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. + assertEquals(0, buf.refCnt()); + assertEquals(0, dup1.refCnt()); + assertEquals(0, dup2.refCnt()); + } + + private void testMultipleRetainedSliceReleaseOriginal(boolean retainedSlice1, boolean retainedSlice2) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected1 = newBuffer(3).resetWriterIndex(); + ByteBuf expected2 = newBuffer(2).resetWriterIndex(); + ByteBuf expected3 = newBuffer(2).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected1.writeBytes(new byte[] {6, 7, 8}); + expected2.writeBytes(new byte[] {7, 8}); + expected3.writeBytes(new byte[] {6, 7}); + ByteBuf slice1 = retainedSlice1 ? buf.retainedSlice(buf.readerIndex() + 5, 3) + : buf.slice(buf.readerIndex() + 5, 3).retain(); + assertEquals(0, slice1.compareTo(expected1)); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + ByteBuf slice2 = retainedSlice2 ? slice1.retainedSlice(slice1.readerIndex() + 1, 2) + : slice1.slice(slice1.readerIndex() + 1, 2).retain(); + assertEquals(0, slice2.compareTo(expected2)); + + // The handler created a slice of the slice and is now done with it. + slice2.release(); + + ByteBuf slice3 = slice1.retainedSlice(slice1.readerIndex(), 2); + assertEquals(0, slice3.compareTo(expected3)); + + // The handler created another slice of the slice and is now done with it. + slice3.release(); + + // The handler is now done with the original slice + assertTrue(slice1.release()); + + // Cleanup the expected buffers used for testing. 
+ assertTrue(expected1.release()); + assertTrue(expected2.release()); + assertTrue(expected3.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. + assertEquals(0, buf.refCnt()); + assertEquals(0, slice1.refCnt()); + assertEquals(0, slice2.refCnt()); + assertEquals(0, slice3.refCnt()); + } + + private void testMultipleRetainedDuplicateReleaseOriginal(boolean retainedDuplicate1, boolean retainedDuplicate2) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + ByteBuf expected = newBuffer(8).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + expected.writeBytes(buf, buf.readerIndex(), buf.readableBytes()); + ByteBuf dup1 = retainedDuplicate1 ? buf.retainedDuplicate() + : buf.duplicate().retain(); + assertEquals(0, dup1.compareTo(expected)); + // Simulate a handler that releases the original buffer, and propagates a slice. + buf.release(); + + ByteBuf dup2 = retainedDuplicate2 ? dup1.retainedDuplicate() + : dup1.duplicate().retain(); + assertEquals(0, dup2.compareTo(expected)); + assertEquals(0, dup2.compareTo(dup2.duplicate())); + assertEquals(0, dup2.compareTo(dup2.slice())); + + ByteBuf tmpBuf = dup2.retainedDuplicate(); + assertEquals(0, dup2.compareTo(tmpBuf)); + tmpBuf.release(); + tmpBuf = dup2.retainedSlice(); + assertEquals(0, dup2.compareTo(tmpBuf)); + tmpBuf.release(); + + // The handler created a slice of the slice and is now done with it. + dup2.release(); + + ByteBuf dup3 = dup1.retainedDuplicate(); + assertEquals(0, dup3.compareTo(expected)); + + // The handler created another slice of the slice and is now done with it. + dup3.release(); + + // The handler is now done with the original slice + assertTrue(dup1.release()); + + // Cleanup the expected buffers used for testing. 
+ assertTrue(expected.release()); + + // Reference counting may be shared, or may be independently tracked, but at this point all buffers should + // be deallocated and have a reference count of 0. + assertEquals(0, buf.refCnt()); + assertEquals(0, dup1.refCnt()); + assertEquals(0, dup2.refCnt()); + assertEquals(0, dup3.refCnt()); + } + + private void testDuplicateContents(boolean retainedDuplicate) { + ByteBuf buf = newBuffer(8).resetWriterIndex(); + buf.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8}); + ByteBuf dup = retainedDuplicate ? buf.retainedDuplicate() : buf.duplicate(); + try { + assertEquals(0, dup.compareTo(buf)); + assertEquals(0, dup.compareTo(dup.duplicate())); + ByteBuf b = dup.retainedDuplicate(); + assertEquals(0, dup.compareTo(b)); + b.release(); + assertEquals(0, dup.compareTo(dup.slice(dup.readerIndex(), dup.readableBytes()))); + } finally { + if (retainedDuplicate) { + dup.release(); + } + buf.release(); + } + } + + @Test + public void testDuplicateRelease() { + ByteBuf buf = newBuffer(8); + assertEquals(1, buf.refCnt()); + assertTrue(buf.duplicate().release()); + assertEquals(0, buf.refCnt()); + } + + // Test-case trying to reproduce: + // https://github.com/netty/netty/issues/2843 + @Test + public void testRefCnt() throws Exception { + testRefCnt0(false); + } + + // Test-case trying to reproduce: + // https://github.com/netty/netty/issues/2843 + @Test + public void testRefCnt2() throws Exception { + testRefCnt0(true); + } + + @Test + public void testEmptyNioBuffers() throws Exception { + ByteBuf buffer = newBuffer(8); + buffer.clear(); + assertFalse(buffer.isReadable()); + ByteBuffer[] nioBuffers = buffer.nioBuffers(); + assertEquals(1, nioBuffers.length); + assertFalse(nioBuffers[0].hasRemaining()); + buffer.release(); + } + + @Test + public void testGetReadOnlyDirectDst() { + testGetReadOnlyDst(true); + } + + @Test + public void testGetReadOnlyHeapDst() { + testGetReadOnlyDst(false); + } + + private void testGetReadOnlyDst(boolean direct) 
{ + byte[] bytes = { 'a', 'b', 'c', 'd' }; + + ByteBuf buffer = newBuffer(bytes.length); + buffer.writeBytes(bytes); + + ByteBuffer dst = direct ? ByteBuffer.allocateDirect(bytes.length) : ByteBuffer.allocate(bytes.length); + ByteBuffer readOnlyDst = dst.asReadOnlyBuffer(); + try { + buffer.getBytes(0, readOnlyDst); + fail(); + } catch (ReadOnlyBufferException e) { + // expected + } + assertEquals(0, readOnlyDst.position()); + buffer.release(); + } + + @Test + public void testReadBytesAndWriteBytesWithFileChannel() throws IOException { + File file = PlatformDependent.createTempFile("file-channel", ".tmp", null); + RandomAccessFile randomAccessFile = null; + try { + randomAccessFile = new RandomAccessFile(file, "rw"); + FileChannel channel = randomAccessFile.getChannel(); + // channelPosition should never be changed + long channelPosition = channel.position(); + + byte[] bytes = {'a', 'b', 'c', 'd'}; + int len = bytes.length; + ByteBuf buffer = newBuffer(len); + buffer.resetReaderIndex(); + buffer.resetWriterIndex(); + buffer.writeBytes(bytes); + + int oldReaderIndex = buffer.readerIndex(); + assertEquals(len, buffer.readBytes(channel, 10, len)); + assertEquals(oldReaderIndex + len, buffer.readerIndex()); + assertEquals(channelPosition, channel.position()); + + ByteBuf buffer2 = newBuffer(len); + buffer2.resetReaderIndex(); + buffer2.resetWriterIndex(); + int oldWriterIndex = buffer2.writerIndex(); + assertEquals(len, buffer2.writeBytes(channel, 10, len)); + assertEquals(channelPosition, channel.position()); + assertEquals(oldWriterIndex + len, buffer2.writerIndex()); + assertEquals('a', buffer2.getByte(0)); + assertEquals('b', buffer2.getByte(1)); + assertEquals('c', buffer2.getByte(2)); + assertEquals('d', buffer2.getByte(3)); + buffer.release(); + buffer2.release(); + } finally { + if (randomAccessFile != null) { + randomAccessFile.close(); + } + file.delete(); + } + } + + @Test + public void testGetBytesAndSetBytesWithFileChannel() throws IOException { + File 
file = PlatformDependent.createTempFile("file-channel", ".tmp", null); + RandomAccessFile randomAccessFile = null; + try { + randomAccessFile = new RandomAccessFile(file, "rw"); + FileChannel channel = randomAccessFile.getChannel(); + // channelPosition should never be changed + long channelPosition = channel.position(); + + byte[] bytes = {'a', 'b', 'c', 'd'}; + int len = bytes.length; + ByteBuf buffer = newBuffer(len); + buffer.resetReaderIndex(); + buffer.resetWriterIndex(); + buffer.writeBytes(bytes); + + int oldReaderIndex = buffer.readerIndex(); + assertEquals(len, buffer.getBytes(oldReaderIndex, channel, 10, len)); + assertEquals(oldReaderIndex, buffer.readerIndex()); + assertEquals(channelPosition, channel.position()); + + ByteBuf buffer2 = newBuffer(len); + buffer2.resetReaderIndex(); + buffer2.resetWriterIndex(); + int oldWriterIndex = buffer2.writerIndex(); + assertEquals(buffer2.setBytes(oldWriterIndex, channel, 10, len), len); + assertEquals(channelPosition, channel.position()); + + assertEquals(oldWriterIndex, buffer2.writerIndex()); + assertEquals('a', buffer2.getByte(oldWriterIndex)); + assertEquals('b', buffer2.getByte(oldWriterIndex + 1)); + assertEquals('c', buffer2.getByte(oldWriterIndex + 2)); + assertEquals('d', buffer2.getByte(oldWriterIndex + 3)); + + buffer.release(); + buffer2.release(); + } finally { + if (randomAccessFile != null) { + randomAccessFile.close(); + } + file.delete(); + } + } + + @Test + public void testReadBytes() { + ByteBuf buffer = newBuffer(8); + byte[] bytes = new byte[8]; + buffer.writeBytes(bytes); + + ByteBuf buffer2 = buffer.readBytes(4); + assertSame(buffer.alloc(), buffer2.alloc()); + assertEquals(4, buffer.readerIndex()); + assertTrue(buffer.release()); + assertEquals(0, buffer.refCnt()); + assertTrue(buffer2.release()); + assertEquals(0, buffer2.refCnt()); + } + + @Test + public void testForEachByteDesc2() { + byte[] expected = {1, 2, 3, 4}; + ByteBuf buf = newBuffer(expected.length); + try { + 
buf.writeBytes(expected); + final byte[] bytes = new byte[expected.length]; + int i = buf.forEachByteDesc(new ByteProcessor() { + private int index = bytes.length - 1; + + @Override + public boolean process(byte value) throws Exception { + bytes[index--] = value; + return true; + } + }); + assertEquals(-1, i); + assertArrayEquals(expected, bytes); + } finally { + buf.release(); + } + } + + @Test + public void testForEachByte2() { + byte[] expected = {1, 2, 3, 4}; + ByteBuf buf = newBuffer(expected.length); + try { + buf.writeBytes(expected); + final byte[] bytes = new byte[expected.length]; + int i = buf.forEachByte(new ByteProcessor() { + private int index; + + @Override + public boolean process(byte value) throws Exception { + bytes[index++] = value; + return true; + } + }); + assertEquals(-1, i); + assertArrayEquals(expected, bytes); + } finally { + buf.release(); + } + } + + @Test + public void testGetBytesByteBuffer() { + byte[] bytes = {'a', 'b', 'c', 'd', 'e', 'f', 'g'}; + // Ensure destination buffer is bigger then what is in the ByteBuf. 
+ final ByteBuffer nioBuffer = ByteBuffer.allocate(bytes.length + 1); + final ByteBuf buffer = newBuffer(bytes.length); + try { + buffer.writeBytes(bytes); + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBytes(buffer.readerIndex(), nioBuffer); + } + }); + } finally { + buffer.release(); + } + } + + private void testRefCnt0(final boolean parameter) throws Exception { + for (int i = 0; i < 10; i++) { + final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch innerLatch = new CountDownLatch(1); + + final ByteBuf buffer = newBuffer(4); + assertEquals(1, buffer.refCnt()); + final AtomicInteger cnt = new AtomicInteger(Integer.MAX_VALUE); + Thread t1 = new Thread(new Runnable() { + @Override + public void run() { + boolean released; + if (parameter) { + released = buffer.release(buffer.refCnt()); + } else { + released = buffer.release(); + } + assertTrue(released); + Thread t2 = new Thread(new Runnable() { + @Override + public void run() { + cnt.set(buffer.refCnt()); + latch.countDown(); + } + }); + t2.start(); + try { + // Keep Thread alive a bit so the ThreadLocal caches are not freed + innerLatch.await(); + } catch (InterruptedException ignore) { + // ignore + } + } + }); + t1.start(); + + latch.await(); + assertEquals(0, cnt.get()); + innerLatch.countDown(); + } + } + + static final class TestGatheringByteChannel implements GatheringByteChannel { + private final ByteArrayOutputStream out = new ByteArrayOutputStream(); + private final WritableByteChannel channel = Channels.newChannel(out); + private final int limit; + TestGatheringByteChannel(int limit) { + this.limit = limit; + } + + TestGatheringByteChannel() { + this(Integer.MAX_VALUE); + } + + @Override + public long write(ByteBuffer[] srcs, int offset, int length) throws IOException { + long written = 0; + for (; offset < length; offset++) { + written += write(srcs[offset]); + if (written >= limit) { + break; + } + } + return 
written; + } + + @Override + public long write(ByteBuffer[] srcs) throws IOException { + return write(srcs, 0, srcs.length); + } + + @Override + public int write(ByteBuffer src) throws IOException { + int oldLimit = src.limit(); + if (limit < src.remaining()) { + src.limit(src.position() + limit); + } + int w = channel.write(src); + src.limit(oldLimit); + return w; + } + + @Override + public boolean isOpen() { + return channel.isOpen(); + } + + @Override + public void close() throws IOException { + channel.close(); + } + + public byte[] writtenBytes() { + return out.toByteArray(); + } + } + + private static final class DevNullGatheringByteChannel implements GatheringByteChannel { + @Override + public long write(ByteBuffer[] srcs, int offset, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public long write(ByteBuffer[] srcs) { + throw new UnsupportedOperationException(); + } + + @Override + public int write(ByteBuffer src) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isOpen() { + return false; + } + + @Override + public void close() { + throw new UnsupportedOperationException(); + } + } + + private static final class TestScatteringByteChannel implements ScatteringByteChannel { + @Override + public long read(ByteBuffer[] dsts, int offset, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public long read(ByteBuffer[] dsts) { + throw new UnsupportedOperationException(); + } + + @Override + public int read(ByteBuffer dst) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isOpen() { + return false; + } + + @Override + public void close() { + throw new UnsupportedOperationException(); + } + } + + private static final class TestByteProcessor implements ByteProcessor { + @Override + public boolean process(byte value) throws Exception { + return true; + } + } + + @Test + public void testCapacityEnforceMaxCapacity() { + final ByteBuf buffer = 
newBuffer(3, 13); + assertEquals(13, buffer.maxCapacity()); + assertEquals(3, buffer.capacity()); + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + buffer.capacity(14); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testCapacityNegative() { + final ByteBuf buffer = newBuffer(3, 13); + assertEquals(13, buffer.maxCapacity()); + assertEquals(3, buffer.capacity()); + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + buffer.capacity(-1); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testCapacityDecrease() { + ByteBuf buffer = newBuffer(3, 13); + assertEquals(13, buffer.maxCapacity()); + assertEquals(3, buffer.capacity()); + try { + buffer.capacity(2); + assertEquals(2, buffer.capacity()); + assertEquals(13, buffer.maxCapacity()); + } finally { + buffer.release(); + } + } + + @Test + public void testCapacityIncrease() { + ByteBuf buffer = newBuffer(3, 13); + assertEquals(13, buffer.maxCapacity()); + assertEquals(3, buffer.capacity()); + try { + buffer.capacity(4); + assertEquals(4, buffer.capacity()); + assertEquals(13, buffer.maxCapacity()); + } finally { + buffer.release(); + } + } + + @Test + public void testReaderIndexLargerThanWriterIndex() { + String content1 = "hello"; + String content2 = "world"; + int length = content1.length() + content2.length(); + final ByteBuf buffer = newBuffer(length); + buffer.setIndex(0, 0); + buffer.writeCharSequence(content1, CharsetUtil.US_ASCII); + buffer.markWriterIndex(); + buffer.skipBytes(content1.length()); + buffer.writeCharSequence(content2, CharsetUtil.US_ASCII); + buffer.skipBytes(content2.length()); + assertTrue(buffer.readerIndex() <= buffer.writerIndex()); + + try { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.resetWriterIndex(); + } + }); + } finally { + buffer.release(); 
+ } + } + + @Test + public void testMaxFastWritableBytes() { + ByteBuf buffer = newBuffer(150, 500).writerIndex(100); + assertEquals(50, buffer.writableBytes()); + assertEquals(150, buffer.capacity()); + assertEquals(500, buffer.maxCapacity()); + assertEquals(400, buffer.maxWritableBytes()); + // Default implementation has fast writable == writable + assertEquals(50, buffer.maxFastWritableBytes()); + buffer.release(); + } + + @Test + public void testEnsureWritableIntegerOverflow() { + ByteBuf buffer = newBuffer(CAPACITY); + buffer.writerIndex(buffer.readerIndex()); + buffer.writeByte(1); + try { + buffer.ensureWritable(Integer.MAX_VALUE); + fail(); + } catch (IndexOutOfBoundsException e) { + // expected + } finally { + buffer.release(); + } + } + + @Test + public void testEndiannessIndexOf() { + buffer.clear(); + final int v = 0x02030201; + buffer.writeIntLE(v); + buffer.writeByte(0x01); + + assertEquals(-1, buffer.indexOf(1, 4, (byte) 1)); + assertEquals(-1, buffer.indexOf(4, 1, (byte) 1)); + assertEquals(1, buffer.indexOf(1, 4, (byte) 2)); + assertEquals(3, buffer.indexOf(4, 1, (byte) 2)); + } + + @Test + public void explicitLittleEndianReadMethodsMustAlwaysUseLittleEndianByteOrder() { + buffer.clear(); + buffer.writeBytes(new byte[] {0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08}); + assertEquals(0x0201, buffer.readShortLE()); + buffer.readerIndex(0); + assertEquals(0x0201, buffer.readUnsignedShortLE()); + buffer.readerIndex(0); + assertEquals(0x030201, buffer.readMediumLE()); + buffer.readerIndex(0); + assertEquals(0x030201, buffer.readUnsignedMediumLE()); + buffer.readerIndex(0); + assertEquals(0x04030201, buffer.readIntLE()); + buffer.readerIndex(0); + assertEquals(0x04030201, buffer.readUnsignedIntLE()); + buffer.readerIndex(0); + assertEquals(0x04030201, Float.floatToRawIntBits(buffer.readFloatLE())); + buffer.readerIndex(0); + assertEquals(0x0807060504030201L, buffer.readLongLE()); + buffer.readerIndex(0); + assertEquals(0x0807060504030201L, 
Double.doubleToRawLongBits(buffer.readDoubleLE())); + buffer.readerIndex(0); + + assertEquals(0x0201, buffer.getShortLE(0)); + assertEquals(0x0201, buffer.getUnsignedShortLE(0)); + assertEquals(0x030201, buffer.getMediumLE(0)); + assertEquals(0x030201, buffer.getUnsignedMediumLE(0)); + assertEquals(0x04030201, buffer.getIntLE(0)); + assertEquals(0x04030201, buffer.getUnsignedIntLE(0)); + assertEquals(0x04030201, Float.floatToRawIntBits(buffer.getFloatLE(0))); + assertEquals(0x0807060504030201L, buffer.getLongLE(0)); + assertEquals(0x0807060504030201L, Double.doubleToRawLongBits(buffer.getDoubleLE(0))); + } + + @Test + public void explicitLittleEndianWriteMethodsMustAlwaysUseLittleEndianByteOrder() { + buffer.clear(); + buffer.writeShortLE(0x0102); + assertEquals(0x0102, buffer.readShortLE()); + buffer.clear(); + buffer.writeMediumLE(0x010203); + assertEquals(0x010203, buffer.readMediumLE()); + buffer.clear(); + buffer.writeIntLE(0x01020304); + assertEquals(0x01020304, buffer.readIntLE()); + buffer.clear(); + buffer.writeFloatLE(Float.intBitsToFloat(0x01020304)); + assertEquals(0x01020304, Float.floatToRawIntBits(buffer.readFloatLE())); + buffer.clear(); + buffer.writeLongLE(0x0102030405060708L); + assertEquals(0x0102030405060708L, buffer.readLongLE()); + buffer.clear(); + buffer.writeDoubleLE(Double.longBitsToDouble(0x0102030405060708L)); + assertEquals(0x0102030405060708L, Double.doubleToRawLongBits(buffer.readDoubleLE())); + + buffer.setShortLE(0, 0x0102); + assertEquals(0x0102, buffer.getShortLE(0)); + buffer.setMediumLE(0, 0x010203); + assertEquals(0x010203, buffer.getMediumLE(0)); + buffer.setIntLE(0, 0x01020304); + assertEquals(0x01020304, buffer.getIntLE(0)); + buffer.setFloatLE(0, Float.intBitsToFloat(0x01020304)); + assertEquals(0x01020304, Float.floatToRawIntBits(buffer.getFloatLE(0))); + buffer.setLongLE(0, 0x0102030405060708L); + assertEquals(0x0102030405060708L, buffer.getLongLE(0)); + buffer.setDoubleLE(0, 
Double.longBitsToDouble(0x0102030405060708L)); + assertEquals(0x0102030405060708L, Double.doubleToRawLongBits(buffer.getDoubleLE(0))); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java new file mode 100644 index 0000000..4bb21b1 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AbstractCompositeByteBufTest.java @@ -0,0 +1,1810 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.util.ArrayList; +import java.util.Collections; +import java.util.ConcurrentModificationException; +import java.util.Iterator; +import java.util.List; +import java.util.NoSuchElementException; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.buffer; +import static io.netty.buffer.Unpooled.compositeBuffer; +import static io.netty.buffer.Unpooled.directBuffer; +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static io.netty.util.internal.EmptyArrays.EMPTY_BYTES; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * An abstract test class for composite channel buffers + */ +public abstract class AbstractCompositeByteBufTest extends AbstractByteBufTest { + + private static final ByteBufAllocator ALLOC = UnpooledByteBufAllocator.DEFAULT; + + private final ByteOrder order; + + protected AbstractCompositeByteBufTest(ByteOrder order) { + this.order = 
ObjectUtil.checkNotNull(order, "order"); + } + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + Assumptions.assumeTrue(maxCapacity == Integer.MAX_VALUE); + + List buffers = new ArrayList(); + for (int i = 0; i < length + 45; i += 45) { + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[1])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[2])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[3])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[4])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[5])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[6])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[7])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[8])); + buffers.add(EMPTY_BUFFER); + buffers.add(wrappedBuffer(new byte[9])); + buffers.add(EMPTY_BUFFER); + } + + ByteBuf buffer; + // Ensure that we are really testing a CompositeByteBuf + switch (buffers.size()) { + case 0: + buffer = compositeBuffer(Integer.MAX_VALUE); + break; + case 1: + buffer = compositeBuffer(Integer.MAX_VALUE).addComponent(buffers.get(0)); + break; + default: + buffer = wrappedBuffer(Integer.MAX_VALUE, buffers.toArray(new ByteBuf[0])); + break; + } + buffer = buffer.order(order); + + // Truncate to the requested capacity. + buffer.capacity(length); + + assertEquals(length, buffer.capacity()); + assertEquals(length, buffer.readableBytes()); + assertFalse(buffer.isWritable()); + buffer.writerIndex(0); + return buffer; + } + + protected CompositeByteBuf newCompositeBuffer() { + return compositeBuffer(); + } + + // Composite buffer does not waste bandwidth on discardReadBytes, but + // the test will fail in strict mode. 
+ @Override + protected boolean discardReadBytesDoesNotMoveWritableBytes() { + return false; + } + + @Test + public void testIsContiguous() { + ByteBuf buf = newBuffer(4); + assertFalse(buf.isContiguous()); + buf.release(); + } + + /** + * Tests the "getBufferFor" method + */ + @Test + public void testComponentAtOffset() { + CompositeByteBuf buf = (CompositeByteBuf) wrappedBuffer(new byte[]{1, 2, 3, 4, 5}, + new byte[]{4, 5, 6, 7, 8, 9, 26}); + + //Ensure that a random place will be fine + assertEquals(5, buf.componentAtOffset(2).capacity()); + + //Loop through each byte + + byte index = 0; + + while (index < buf.capacity()) { + ByteBuf _buf = buf.componentAtOffset(index++); + assertNotNull(_buf); + assertTrue(_buf.capacity() > 0); + assertTrue(_buf.getByte(0) > 0); + assertTrue(_buf.getByte(_buf.readableBytes() - 1) > 0); + } + + buf.release(); + } + + @Test + public void testToComponentIndex() { + CompositeByteBuf buf = (CompositeByteBuf) wrappedBuffer(new byte[]{1, 2, 3, 4, 5}, + new byte[]{4, 5, 6, 7, 8, 9, 26}, new byte[]{10, 9, 8, 7, 6, 5, 33}); + + // spot checks + assertEquals(0, buf.toComponentIndex(4)); + assertEquals(1, buf.toComponentIndex(5)); + assertEquals(2, buf.toComponentIndex(15)); + + //Loop through each byte + + byte index = 0; + + while (index < buf.capacity()) { + int cindex = buf.toComponentIndex(index++); + assertTrue(cindex >= 0 && cindex < buf.numComponents()); + } + + buf.release(); + } + + @Test + public void testToByteIndex() { + CompositeByteBuf buf = (CompositeByteBuf) wrappedBuffer(new byte[]{1, 2, 3, 4, 5}, + new byte[]{4, 5, 6, 7, 8, 9, 26}, new byte[]{10, 9, 8, 7, 6, 5, 33}); + + // spot checks + assertEquals(0, buf.toByteIndex(0)); + assertEquals(5, buf.toByteIndex(1)); + assertEquals(12, buf.toByteIndex(2)); + + buf.release(); + } + + @Test + public void testDiscardReadBytes3() { + ByteBuf a, b; + a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 1, 
2, 3, 4, 5, 6, 7, 8, 9, 10 }, 0, 5).order(order), + wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }, 5, 5).order(order)); + a.skipBytes(6); + a.markReaderIndex(); + b.skipBytes(6); + b.markReaderIndex(); + assertEquals(a.readerIndex(), b.readerIndex()); + a.readerIndex(a.readerIndex() - 1); + b.readerIndex(b.readerIndex() - 1); + assertEquals(a.readerIndex(), b.readerIndex()); + a.writerIndex(a.writerIndex() - 1); + a.markWriterIndex(); + b.writerIndex(b.writerIndex() - 1); + b.markWriterIndex(); + assertEquals(a.writerIndex(), b.writerIndex()); + a.writerIndex(a.writerIndex() + 1); + b.writerIndex(b.writerIndex() + 1); + assertEquals(a.writerIndex(), b.writerIndex()); + assertTrue(ByteBufUtil.equals(a, b)); + // now discard + a.discardReadBytes(); + b.discardReadBytes(); + assertEquals(a.readerIndex(), b.readerIndex()); + assertEquals(a.writerIndex(), b.writerIndex()); + assertTrue(ByteBufUtil.equals(a, b)); + a.resetReaderIndex(); + b.resetReaderIndex(); + assertEquals(a.readerIndex(), b.readerIndex()); + a.resetWriterIndex(); + b.resetWriterIndex(); + assertEquals(a.writerIndex(), b.writerIndex()); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + } + + @Test + public void testAutoConsolidation() { + CompositeByteBuf buf = compositeBuffer(2); + + buf.addComponent(wrappedBuffer(new byte[] { 1 })); + assertEquals(1, buf.numComponents()); + + buf.addComponent(wrappedBuffer(new byte[] { 2, 3 })); + assertEquals(2, buf.numComponents()); + + buf.addComponent(wrappedBuffer(new byte[] { 4, 5, 6 })); + + assertEquals(1, buf.numComponents()); + assertTrue(buf.hasArray()); + assertNotNull(buf.array()); + assertEquals(0, buf.arrayOffset()); + + buf.release(); + } + + @Test + public void testCompositeToSingleBuffer() { + CompositeByteBuf buf = compositeBuffer(3); + + buf.addComponent(wrappedBuffer(new byte[] {1, 2, 3})); + assertEquals(1, buf.numComponents()); + + buf.addComponent(wrappedBuffer(new byte[] {4})); + assertEquals(2, 
buf.numComponents()); + + buf.addComponent(wrappedBuffer(new byte[] {5, 6})); + assertEquals(3, buf.numComponents()); + + // NOTE: hard-coding 6 here, since it seems like addComponent doesn't bump the writer index. + // I'm unsure as to whether or not this is correct behavior + ByteBuffer nioBuffer = buf.nioBuffer(0, 6); + byte[] bytes = nioBuffer.array(); + assertEquals(6, bytes.length); + assertArrayEquals(new byte[] {1, 2, 3, 4, 5, 6}, bytes); + + buf.release(); + } + + @Test + public void testFullConsolidation() { + CompositeByteBuf buf = compositeBuffer(Integer.MAX_VALUE); + buf.addComponent(wrappedBuffer(new byte[] { 1 })); + buf.addComponent(wrappedBuffer(new byte[] { 2, 3 })); + buf.addComponent(wrappedBuffer(new byte[] { 4, 5, 6 })); + buf.consolidate(); + + assertEquals(1, buf.numComponents()); + assertTrue(buf.hasArray()); + assertNotNull(buf.array()); + assertEquals(0, buf.arrayOffset()); + + buf.release(); + } + + @Test + public void testRangedConsolidation() { + CompositeByteBuf buf = compositeBuffer(Integer.MAX_VALUE); + buf.addComponent(wrappedBuffer(new byte[] { 1 })); + buf.addComponent(wrappedBuffer(new byte[] { 2, 3 })); + buf.addComponent(wrappedBuffer(new byte[] { 4, 5, 6 })); + buf.addComponent(wrappedBuffer(new byte[] { 7, 8, 9, 10 })); + buf.consolidate(1, 2); + + assertEquals(3, buf.numComponents()); + assertEquals(wrappedBuffer(new byte[] { 1 }), buf.component(0)); + assertEquals(wrappedBuffer(new byte[] { 2, 3, 4, 5, 6 }), buf.component(1)); + assertEquals(wrappedBuffer(new byte[] { 7, 8, 9, 10 }), buf.component(2)); + + buf.release(); + } + + @Test + public void testCompositeWrappedBuffer() { + ByteBuf header = buffer(12).order(order); + ByteBuf payload = buffer(512).order(order); + + header.writeBytes(new byte[12]); + payload.writeBytes(new byte[512]); + + ByteBuf buffer = wrappedBuffer(header, payload); + + assertEquals(12, header.readableBytes()); + assertEquals(512, payload.readableBytes()); + + assertEquals(12 + 512, 
buffer.readableBytes()); + assertEquals(2, buffer.nioBufferCount()); + + buffer.release(); + } + + @Test + public void testSeveralBuffersEquals() { + ByteBuf a, b; + // XXX Same tests with several buffers in wrappedCheckedBuffer + // Different length. + a = wrappedBuffer(new byte[] { 1 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 1 }).order(order), + wrappedBuffer(new byte[] { 2 }).order(order)); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Same content, same firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[]{1}).order(order), + wrappedBuffer(new byte[]{2}).order(order), + wrappedBuffer(new byte[]{3}).order(order)); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Same content, different firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4 }, 1, 2).order(order), + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4 }, 3, 1).order(order)); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Different content, same firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 1, 2 }).order(order), + wrappedBuffer(new byte[] { 4 }).order(order)); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Different content, different firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 0, 1, 2, 4, 5 }, 1, 2).order(order), + wrappedBuffer(new byte[] { 0, 1, 2, 4, 5 }, 3, 1).order(order)); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Same content, same firstIndex, long length. 
+ a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 1, 2, 3 }).order(order), + wrappedBuffer(new byte[] { 4, 5, 6 }).order(order), + wrappedBuffer(new byte[] { 7, 8, 9, 10 }).order(order)); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Same content, different firstIndex, long length. + a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 1, 5).order(order), + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 6, 5).order(order)); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Different content, same firstIndex, long length. + a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 1, 2, 3, 4, 6 }).order(order), + wrappedBuffer(new byte[] { 7, 8, 5, 9, 10 }).order(order)); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + + // Different content, different firstIndex, long length. 
+ a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 6, 7, 8, 5, 9, 10, 11 }, 1, 5).order(order), + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 6, 7, 8, 5, 9, 10, 11 }, 6, 5).order(order)); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + } + + @Test + public void testWrappedBuffer() { + + ByteBuf a = wrappedBuffer(wrappedBuffer(ByteBuffer.allocateDirect(16))); + assertEquals(16, a.capacity()); + a.release(); + + a = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 }).order(order)); + ByteBuf b = wrappedBuffer(wrappedBuffer(new byte[][] { new byte[] { 1, 2, 3 } }).order(order)); + assertEquals(a, b); + + a.release(); + b.release(); + + a = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 }).order(order)); + b = wrappedBuffer(wrappedBuffer( + new byte[] { 1 }, + new byte[] { 2 }, + new byte[] { 3 }).order(order)); + assertEquals(a, b); + + a.release(); + b.release(); + + a = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 }).order(order)); + b = wrappedBuffer(new ByteBuf[] { + wrappedBuffer(new byte[] { 1, 2, 3 }).order(order) + }); + assertEquals(a, b); + + a.release(); + b.release(); + + a = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 }).order(order)); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 1 }).order(order), + wrappedBuffer(new byte[] { 2 }).order(order), + wrappedBuffer(new byte[] { 3 }).order(order)); + assertEquals(a, b); + + a.release(); + b.release(); + + a = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 })).order(order); + b = wrappedBuffer(wrappedBuffer(new ByteBuffer[] { + ByteBuffer.wrap(new byte[] { 1, 2, 3 }) + })); + assertEquals(a, b); + + a.release(); + b.release(); + + a = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 }).order(order)); + b = wrappedBuffer(wrappedBuffer( + ByteBuffer.wrap(new byte[] { 1 }), + ByteBuffer.wrap(new byte[] { 2 }), + ByteBuffer.wrap(new byte[] { 3 }))); + assertEquals(a, b); + + 
a.release(); + b.release(); + } + + @Test + public void testWrittenBuffersEquals() { + //XXX Same tests than testEquals with written AggregateChannelBuffers + ByteBuf a, b, c; + // Different length. + a = wrappedBuffer(new byte[] { 1 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 1 }, new byte[1])).order(order); + c = wrappedBuffer(new byte[] { 2 }).order(order); + + // to enable writeBytes + b.writerIndex(b.writerIndex() - 1); + b.writeBytes(c); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Same content, same firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 1 }, new byte[2])).order(order); + c = wrappedBuffer(new byte[] { 2 }).order(order); + + // to enable writeBytes + b.writerIndex(b.writerIndex() - 2); + b.writeBytes(c); + c.release(); + c = wrappedBuffer(new byte[] { 3 }).order(order); + + b.writeBytes(c); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Same content, different firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 0, 1, 2, 3, 4 }, 1, 3)).order(order); + c = wrappedBuffer(new byte[] { 0, 1, 2, 3, 4 }, 3, 1).order(order); + // to enable writeBytes + b.writerIndex(b.writerIndex() - 1); + b.writeBytes(c); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Different content, same firstIndex, short length. + a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2 }, new byte[1])).order(order); + c = wrappedBuffer(new byte[] { 4 }).order(order); + // to enable writeBytes + b.writerIndex(b.writerIndex() - 1); + b.writeBytes(c); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Different content, different firstIndex, short length. 
+ a = wrappedBuffer(new byte[] { 1, 2, 3 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 0, 1, 2, 4, 5 }, 1, 3)).order(order); + c = wrappedBuffer(new byte[] { 0, 1, 2, 4, 5 }, 3, 1).order(order); + // to enable writeBytes + b.writerIndex(b.writerIndex() - 1); + b.writeBytes(c); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Same content, same firstIndex, long length. + a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3 }, new byte[7])).order(order); + c = wrappedBuffer(new byte[] { 4, 5, 6 }).order(order); + + // to enable writeBytes + b.writerIndex(b.writerIndex() - 7); + b.writeBytes(c); + c.release(); + c = wrappedBuffer(new byte[] { 7, 8, 9, 10 }).order(order); + b.writeBytes(c); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Same content, different firstIndex, long length. + a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 1, 10)).order(order); + c = wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 6, 5).order(order); + // to enable writeBytes + b.writerIndex(b.writerIndex() - 5); + b.writeBytes(c); + assertTrue(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Different content, same firstIndex, long length. + a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer(wrappedBuffer(new byte[] { 1, 2, 3, 4, 6 }, new byte[5])).order(order); + c = wrappedBuffer(new byte[] { 7, 8, 5, 9, 10 }).order(order); + // to enable writeBytes + b.writerIndex(b.writerIndex() - 5); + b.writeBytes(c); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + + // Different content, different firstIndex, long length. 
+ a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 }).order(order); + b = wrappedBuffer( + wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 6, 7, 8, 5, 9, 10, 11 }, 1, 10)).order(order); + c = wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 6, 7, 8, 5, 9, 10, 11 }, 6, 5).order(order); + // to enable writeBytes + b.writerIndex(b.writerIndex() - 5); + b.writeBytes(c); + assertFalse(ByteBufUtil.equals(a, b)); + + a.release(); + b.release(); + c.release(); + } + + @Test + public void testEmptyBuffer() { + ByteBuf b = wrappedBuffer(new byte[]{1, 2}, new byte[]{3, 4}); + b.readBytes(new byte[4]); + b.readBytes(EMPTY_BYTES); + b.release(); + } + + // Test for https://github.com/netty/netty/issues/1060 + @Test + public void testReadWithEmptyCompositeBuffer() { + ByteBuf buf = compositeBuffer(); + int n = 65; + for (int i = 0; i < n; i ++) { + buf.writeByte(1); + assertEquals(1, buf.readByte()); + } + buf.release(); + } + + @SuppressWarnings("deprecation") + @Test + public void testComponentMustBeDuplicate() { + CompositeByteBuf buf = compositeBuffer(); + buf.addComponent(buffer(4, 6).setIndex(1, 3)); + assertThat(buf.component(0), is(instanceOf(AbstractDerivedByteBuf.class))); + assertThat(buf.component(0).capacity(), is(4)); + assertThat(buf.component(0).maxCapacity(), is(6)); + assertThat(buf.component(0).readableBytes(), is(2)); + buf.release(); + } + + @Test + public void testReferenceCounts1() { + ByteBuf c1 = buffer().writeByte(1); + ByteBuf c2 = buffer().writeByte(2).retain(); + ByteBuf c3 = buffer().writeByte(3).retain(2); + + CompositeByteBuf buf = compositeBuffer(); + assertThat(buf.refCnt(), is(1)); + buf.addComponents(c1, c2, c3); + + assertThat(buf.refCnt(), is(1)); + + // Ensure that c[123]'s refCount did not change. 
+ assertThat(c1.refCnt(), is(1)); + assertThat(c2.refCnt(), is(2)); + assertThat(c3.refCnt(), is(3)); + + assertThat(buf.component(0).refCnt(), is(1)); + assertThat(buf.component(1).refCnt(), is(2)); + assertThat(buf.component(2).refCnt(), is(3)); + + c3.release(2); + c2.release(); + buf.release(); + } + + @Test + public void testReferenceCounts2() { + ByteBuf c1 = buffer().writeByte(1); + ByteBuf c2 = buffer().writeByte(2).retain(); + ByteBuf c3 = buffer().writeByte(3).retain(2); + + CompositeByteBuf bufA = compositeBuffer(); + bufA.addComponents(c1, c2, c3).writerIndex(3); + + CompositeByteBuf bufB = compositeBuffer(); + bufB.addComponents(bufA); + + // Ensure that bufA.refCnt() did not change. + assertThat(bufA.refCnt(), is(1)); + + // Ensure that c[123]'s refCnt did not change. + assertThat(c1.refCnt(), is(1)); + assertThat(c2.refCnt(), is(2)); + assertThat(c3.refCnt(), is(3)); + + // This should decrease bufA.refCnt(). + bufB.release(); + assertThat(bufB.refCnt(), is(0)); + + // Ensure bufA.refCnt() changed. + assertThat(bufA.refCnt(), is(0)); + + // Ensure that c[123]'s refCnt also changed due to the deallocation of bufA. + assertThat(c1.refCnt(), is(0)); + assertThat(c2.refCnt(), is(1)); + assertThat(c3.refCnt(), is(2)); + + c3.release(2); + c2.release(); + } + + @Test + public void testReferenceCounts3() { + ByteBuf c1 = buffer().writeByte(1); + ByteBuf c2 = buffer().writeByte(2).retain(); + ByteBuf c3 = buffer().writeByte(3).retain(2); + + CompositeByteBuf buf = compositeBuffer(); + assertThat(buf.refCnt(), is(1)); + + List components = new ArrayList(); + Collections.addAll(components, c1, c2, c3); + buf.addComponents(components); + + // Ensure that c[123]'s refCount did not change. 
+ assertThat(c1.refCnt(), is(1)); + assertThat(c2.refCnt(), is(2)); + assertThat(c3.refCnt(), is(3)); + + assertThat(buf.component(0).refCnt(), is(1)); + assertThat(buf.component(1).refCnt(), is(2)); + assertThat(buf.component(2).refCnt(), is(3)); + + c3.release(2); + c2.release(); + buf.release(); + } + + @Test + public void testNestedLayout() { + CompositeByteBuf buf = compositeBuffer(); + buf.addComponent( + compositeBuffer() + .addComponent(wrappedBuffer(new byte[]{1, 2})) + .addComponent(wrappedBuffer(new byte[]{3, 4})).slice(1, 2)); + + ByteBuffer[] nioBuffers = buf.nioBuffers(0, 2); + assertThat(nioBuffers.length, is(2)); + assertThat(nioBuffers[0].remaining(), is(1)); + assertThat(nioBuffers[0].get(), is((byte) 2)); + assertThat(nioBuffers[1].remaining(), is(1)); + assertThat(nioBuffers[1].get(), is((byte) 3)); + + buf.release(); + } + + @Test + public void testRemoveLastComponent() { + CompositeByteBuf buf = compositeBuffer(); + buf.addComponent(wrappedBuffer(new byte[]{1, 2})); + assertEquals(1, buf.numComponents()); + buf.removeComponent(0); + assertEquals(0, buf.numComponents()); + buf.release(); + } + + @Test + public void testCopyEmpty() { + CompositeByteBuf buf = compositeBuffer(); + assertEquals(0, buf.numComponents()); + + ByteBuf copy = buf.copy(); + assertEquals(0, copy.readableBytes()); + + buf.release(); + copy.release(); + } + + @Test + public void testDuplicateEmpty() { + CompositeByteBuf buf = compositeBuffer(); + assertEquals(0, buf.numComponents()); + assertEquals(0, buf.duplicate().readableBytes()); + + buf.release(); + } + + @Test + public void testRemoveLastComponentWithOthersLeft() { + CompositeByteBuf buf = compositeBuffer(); + buf.addComponent(wrappedBuffer(new byte[]{1, 2})); + buf.addComponent(wrappedBuffer(new byte[]{1, 2})); + assertEquals(2, buf.numComponents()); + buf.removeComponent(1); + assertEquals(1, buf.numComponents()); + buf.release(); + } + + @Test + public void testRemoveComponents() { + CompositeByteBuf buf = 
compositeBuffer(); + for (int i = 0; i < 10; i++) { + buf.addComponent(wrappedBuffer(new byte[]{1, 2})); + } + assertEquals(10, buf.numComponents()); + assertEquals(20, buf.capacity()); + buf.removeComponents(4, 3); + assertEquals(7, buf.numComponents()); + assertEquals(14, buf.capacity()); + buf.release(); + } + + @Test + public void testGatheringWritesHeap() throws Exception { + testGatheringWrites(buffer().order(order), buffer().order(order)); + } + + @Test + public void testGatheringWritesDirect() throws Exception { + testGatheringWrites(directBuffer().order(order), directBuffer().order(order)); + } + + @Test + public void testGatheringWritesMixes() throws Exception { + testGatheringWrites(buffer().order(order), directBuffer().order(order)); + } + + @Test + public void testGatheringWritesHeapPooled() throws Exception { + testGatheringWrites(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.heapBuffer().order(order)); + } + + @Test + public void testGatheringWritesDirectPooled() throws Exception { + testGatheringWrites(PooledByteBufAllocator.DEFAULT.directBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order)); + } + + @Test + public void testGatheringWritesMixesPooled() throws Exception { + testGatheringWrites(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order)); + } + + private static void testGatheringWrites(ByteBuf buf1, ByteBuf buf2) throws Exception { + CompositeByteBuf buf = compositeBuffer(); + buf.addComponent(buf1.writeBytes(new byte[]{1, 2})); + buf.addComponent(buf2.writeBytes(new byte[]{1, 2})); + buf.writerIndex(3); + buf.readerIndex(1); + + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + + buf.readBytes(channel, 2); + + byte[] data = new byte[2]; + buf.getBytes(1, data); + assertArrayEquals(data, channel.writtenBytes()); + + buf.release(); + } + + @Test + public void 
testGatheringWritesPartialHeap() throws Exception { + testGatheringWritesPartial(buffer().order(order), buffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialDirect() throws Exception { + testGatheringWritesPartial(directBuffer().order(order), directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialMixes() throws Exception { + testGatheringWritesPartial(buffer().order(order), directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialHeapSlice() throws Exception { + testGatheringWritesPartial(buffer().order(order), buffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialDirectSlice() throws Exception { + testGatheringWritesPartial(directBuffer().order(order), directBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialMixesSlice() throws Exception { + testGatheringWritesPartial(buffer().order(order), directBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialHeapPooled() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialDirectPooled() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.directBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialMixesPooled() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), false); + } + + @Test + public void testGatheringWritesPartialHeapPooledSliced() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), true); + } + + @Test + 
public void testGatheringWritesPartialDirectPooledSliced() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.directBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), true); + } + + @Test + public void testGatheringWritesPartialMixesPooledSliced() throws Exception { + testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer().order(order), + PooledByteBufAllocator.DEFAULT.directBuffer().order(order), true); + } + + private static void testGatheringWritesPartial(ByteBuf buf1, ByteBuf buf2, boolean slice) throws Exception { + CompositeByteBuf buf = compositeBuffer(); + buf1.writeBytes(new byte[]{1, 2, 3, 4}); + buf2.writeBytes(new byte[]{1, 2, 3, 4}); + if (slice) { + buf1 = buf1.readerIndex(1).slice(); + buf2 = buf2.writerIndex(3).slice(); + buf.addComponent(buf1); + buf.addComponent(buf2); + buf.writerIndex(6); + } else { + buf.addComponent(buf1); + buf.addComponent(buf2); + buf.writerIndex(7); + buf.readerIndex(1); + } + + TestGatheringByteChannel channel = new TestGatheringByteChannel(1); + + while (buf.isReadable()) { + buf.readBytes(channel, buf.readableBytes()); + } + + byte[] data = new byte[6]; + + if (slice) { + buf.getBytes(0, data); + } else { + buf.getBytes(1, data); + } + assertArrayEquals(data, channel.writtenBytes()); + + buf.release(); + } + + @Test + public void testGatheringWritesSingleHeap() throws Exception { + testGatheringWritesSingleBuf(buffer().order(order)); + } + + @Test + public void testGatheringWritesSingleDirect() throws Exception { + testGatheringWritesSingleBuf(directBuffer().order(order)); + } + + private static void testGatheringWritesSingleBuf(ByteBuf buf1) throws Exception { + CompositeByteBuf buf = compositeBuffer(); + buf.addComponent(buf1.writeBytes(new byte[]{1, 2, 3, 4})); + buf.writerIndex(3); + buf.readerIndex(1); + + TestGatheringByteChannel channel = new TestGatheringByteChannel(); + buf.readBytes(channel, 2); + + byte[] data = new byte[2]; + 
buf.getBytes(1, data); + assertArrayEquals(data, channel.writtenBytes()); + + buf.release(); + } + + @Override + @Test + public void testInternalNioBuffer() { + CompositeByteBuf buf = compositeBuffer(); + assertEquals(0, buf.internalNioBuffer(0, 0).remaining()); + + // If non-derived buffer is added, its internal buffer should be returned + ByteBuf concreteBuffer = directBuffer().writeByte(1); + buf.addComponent(concreteBuffer); + assertSame(concreteBuffer.internalNioBuffer(0, 1), buf.internalNioBuffer(0, 1)); + buf.release(); + + // In derived cases, the original internal buffer must not be used + buf = compositeBuffer(); + concreteBuffer = directBuffer().writeByte(1); + buf.addComponent(concreteBuffer.slice()); + assertNotSame(concreteBuffer.internalNioBuffer(0, 1), buf.internalNioBuffer(0, 1)); + buf.release(); + + buf = compositeBuffer(); + concreteBuffer = directBuffer().writeByte(1); + buf.addComponent(concreteBuffer.duplicate()); + assertNotSame(concreteBuffer.internalNioBuffer(0, 1), buf.internalNioBuffer(0, 1)); + buf.release(); + } + + @Test + public void testisDirectMultipleBufs() { + CompositeByteBuf buf = compositeBuffer(); + assertFalse(buf.isDirect()); + + buf.addComponent(directBuffer().writeByte(1)); + + assertTrue(buf.isDirect()); + buf.addComponent(directBuffer().writeByte(1)); + assertTrue(buf.isDirect()); + + buf.addComponent(buffer().writeByte(1)); + assertFalse(buf.isDirect()); + + buf.release(); + } + + // See https://github.com/netty/netty/issues/1976 + @Test + public void testDiscardSomeReadBytes() { + CompositeByteBuf cbuf = compositeBuffer(); + int len = 8 * 4; + for (int i = 0; i < len; i += 4) { + ByteBuf buf = buffer().writeInt(i); + cbuf.capacity(cbuf.writerIndex()).addComponent(buf).writerIndex(i + 4); + } + cbuf.writeByte(1); + + byte[] me = new byte[len]; + cbuf.readBytes(me); + cbuf.readByte(); + + cbuf.discardSomeReadBytes(); + + cbuf.release(); + } + + @Test + public void testAddEmptyBufferRelease() { + CompositeByteBuf cbuf = 
compositeBuffer(); + ByteBuf buf = buffer(); + assertEquals(1, buf.refCnt()); + cbuf.addComponent(buf); + assertEquals(1, buf.refCnt()); + + cbuf.release(); + assertEquals(0, buf.refCnt()); + } + + @Test + public void testAddEmptyBuffersRelease() { + CompositeByteBuf cbuf = compositeBuffer(); + ByteBuf buf = buffer(); + ByteBuf buf2 = buffer().writeInt(1); + ByteBuf buf3 = buffer(); + + assertEquals(1, buf.refCnt()); + assertEquals(1, buf2.refCnt()); + assertEquals(1, buf3.refCnt()); + + cbuf.addComponents(buf, buf2, buf3); + assertEquals(1, buf.refCnt()); + assertEquals(1, buf2.refCnt()); + assertEquals(1, buf3.refCnt()); + + cbuf.release(); + assertEquals(0, buf.refCnt()); + assertEquals(0, buf2.refCnt()); + assertEquals(0, buf3.refCnt()); + } + + @Test + public void testAddEmptyBufferInMiddle() { + CompositeByteBuf cbuf = compositeBuffer(); + ByteBuf buf1 = buffer().writeByte((byte) 1); + cbuf.addComponent(true, buf1); + cbuf.addComponent(true, EMPTY_BUFFER); + ByteBuf buf3 = buffer().writeByte((byte) 2); + cbuf.addComponent(true, buf3); + + assertEquals(2, cbuf.readableBytes()); + assertEquals((byte) 1, cbuf.readByte()); + assertEquals((byte) 2, cbuf.readByte()); + + assertSame(EMPTY_BUFFER, cbuf.internalComponent(1)); + assertNotSame(EMPTY_BUFFER, cbuf.internalComponentAtOffset(1)); + cbuf.release(); + } + + @Test + public void testInsertEmptyBufferInMiddle() { + CompositeByteBuf cbuf = compositeBuffer(); + ByteBuf buf1 = buffer().writeByte((byte) 1); + cbuf.addComponent(true, buf1); + ByteBuf buf2 = buffer().writeByte((byte) 2); + cbuf.addComponent(true, buf2); + + // insert empty one between the first two + cbuf.addComponent(true, 1, EMPTY_BUFFER); + + assertEquals(2, cbuf.readableBytes()); + assertEquals((byte) 1, cbuf.readByte()); + assertEquals((byte) 2, cbuf.readByte()); + + assertEquals(2, cbuf.capacity()); + assertEquals(3, cbuf.numComponents()); + + byte[] dest = new byte[2]; + // should skip over the empty one, not throw a java.lang.Error :) + 
cbuf.getBytes(0, dest); + + assertArrayEquals(new byte[] {1, 2}, dest); + + cbuf.release(); + } + + @Test + public void testAddFlattenedComponents() { + testAddFlattenedComponents(false); + } + + @Test + public void testAddFlattenedComponentsWithWrappedComposite() { + testAddFlattenedComponents(true); + } + + private void testAddFlattenedComponents(boolean addWrapped) { + ByteBuf b1 = Unpooled.wrappedBuffer(new byte[] { 1, 2, 3 }); + CompositeByteBuf newComposite = newCompositeBuffer() + .addComponent(true, b1) + .addFlattenedComponents(true, b1.retain()) + .addFlattenedComponents(true, Unpooled.EMPTY_BUFFER); + + assertEquals(2, newComposite.numComponents()); + assertEquals(6, newComposite.capacity()); + assertEquals(6, newComposite.writerIndex()); + + // It is important to use a pooled allocator here to ensure + // the slices returned by readRetainedSlice are of type + // PooledSlicedByteBuf, which maintains an independent refcount + // (so that we can be sure to cover this case) + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer() + .writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + // use mixture of slice and retained slice + ByteBuf s1 = buffer.readRetainedSlice(2); + ByteBuf s2 = s1.retainedSlice(0, 2); + ByteBuf s3 = buffer.slice(0, 2).retain(); + ByteBuf s4 = s2.retainedSlice(0, 2); + buffer.release(); + + CompositeByteBuf compositeToAdd = compositeBuffer() + .addComponent(s1) + .addComponent(Unpooled.EMPTY_BUFFER) + .addComponents(s2, s3, s4); + // set readable range to be from middle of first component + // to middle of penultimate component + compositeToAdd.setIndex(1, 5); + + assertEquals(1, compositeToAdd.refCnt()); + assertEquals(1, s4.refCnt()); + + ByteBuf compositeCopy = compositeToAdd.copy(); + + if (addWrapped) { + compositeToAdd = new WrappedCompositeByteBuf(compositeToAdd); + } + newComposite.addFlattenedComponents(true, compositeToAdd); + + // verify that added range matches + ByteBufUtil.equals(compositeCopy, 0, + 
newComposite, 6, compositeCopy.readableBytes()); + + // should not include empty component or last component + // (latter outside of the readable range) + assertEquals(5, newComposite.numComponents()); + assertEquals(10, newComposite.capacity()); + assertEquals(10, newComposite.writerIndex()); + + assertEquals(0, compositeToAdd.refCnt()); + // s4 wasn't in added range so should have been jettisoned + assertEquals(0, s4.refCnt()); + assertEquals(1, newComposite.refCnt()); + + // releasing composite should release the remaining components + newComposite.release(); + assertEquals(0, newComposite.refCnt()); + assertEquals(0, s1.refCnt()); + assertEquals(0, s2.refCnt()); + assertEquals(0, s3.refCnt()); + assertEquals(0, b1.refCnt()); + } + + @Test + public void testIterator() { + CompositeByteBuf cbuf = newCompositeBuffer(); + cbuf.addComponent(EMPTY_BUFFER); + cbuf.addComponent(EMPTY_BUFFER); + + Iterator it = cbuf.iterator(); + assertTrue(it.hasNext()); + assertSame(EMPTY_BUFFER, it.next()); + assertTrue(it.hasNext()); + assertSame(EMPTY_BUFFER, it.next()); + assertFalse(it.hasNext()); + + try { + it.next(); + fail(); + } catch (NoSuchElementException e) { + //Expected + } + cbuf.release(); + } + + @Test + public void testEmptyIterator() { + CompositeByteBuf cbuf = newCompositeBuffer(); + + Iterator it = cbuf.iterator(); + assertFalse(it.hasNext()); + + try { + it.next(); + fail(); + } catch (NoSuchElementException e) { + //Expected + } + cbuf.release(); + } + + @Test + public void testIteratorConcurrentModificationAdd() { + CompositeByteBuf cbuf = newCompositeBuffer(); + cbuf.addComponent(EMPTY_BUFFER); + + final Iterator it = cbuf.iterator(); + cbuf.addComponent(EMPTY_BUFFER); + + assertTrue(it.hasNext()); + try { + assertThrows(ConcurrentModificationException.class, new Executable() { + @Override + public void execute() { + it.next(); + } + }); + } finally { + cbuf.release(); + } + } + + @Test + public void testIteratorConcurrentModificationRemove() { + 
CompositeByteBuf cbuf = newCompositeBuffer(); + cbuf.addComponent(EMPTY_BUFFER); + + final Iterator it = cbuf.iterator(); + cbuf.removeComponent(0); + + assertTrue(it.hasNext()); + try { + assertThrows(ConcurrentModificationException.class, new Executable() { + @Override + public void execute() { + it.next(); + } + }); + } finally { + cbuf.release(); + } + } + + @Test + public void testReleasesItsComponents() { + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(); // 1 + + buffer.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + ByteBuf s1 = buffer.readSlice(2).retain(); // 2 + ByteBuf s2 = s1.readSlice(2).retain(); // 3 + ByteBuf s3 = s2.readSlice(2).retain(); // 4 + ByteBuf s4 = s3.readSlice(2).retain(); // 5 + + ByteBuf composite = PooledByteBufAllocator.DEFAULT.compositeBuffer() + .addComponent(s1) + .addComponents(s2, s3, s4) + .order(ByteOrder.LITTLE_ENDIAN); + + assertEquals(1, composite.refCnt()); + assertEquals(5, buffer.refCnt()); + + // releasing composite should release the 4 components + ReferenceCountUtil.release(composite); + assertEquals(0, composite.refCnt()); + assertEquals(1, buffer.refCnt()); + + // last remaining ref to buffer + ReferenceCountUtil.release(buffer); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testReleasesItsComponents2() { + // It is important to use a pooled allocator here to ensure + // the slices returned by readRetainedSlice are of type + // PooledSlicedByteBuf, which maintains an independent refcount + // (so that we can be sure to cover this case) + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(); // 1 + + buffer.writeBytes(new byte[] {1, 2, 3, 4, 5, 6, 7, 8, 9, 10}); + + // use readRetainedSlice this time - produces different kind of slices + ByteBuf s1 = buffer.readRetainedSlice(2); // 2 + ByteBuf s2 = s1.readRetainedSlice(2); // 3 + ByteBuf s3 = s2.readRetainedSlice(2); // 4 + ByteBuf s4 = s3.readRetainedSlice(2); // 5 + + ByteBuf composite = newCompositeBuffer() + 
.addComponent(s1) + .addComponents(s2, s3, s4) + .order(ByteOrder.LITTLE_ENDIAN); + + assertEquals(1, composite.refCnt()); + assertEquals(2, buffer.refCnt()); + + // releasing composite should release the 4 components + composite.release(); + assertEquals(0, composite.refCnt()); + assertEquals(1, buffer.refCnt()); + + // last remaining ref to buffer + buffer.release(); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testReleasesOnShrink() { + + ByteBuf b1 = Unpooled.buffer(2).writeShort(1); + ByteBuf b2 = Unpooled.buffer(2).writeShort(2); + + // composite takes ownership of s1 and s2 + ByteBuf composite = newCompositeBuffer() + .addComponents(b1, b2); + + assertEquals(4, composite.capacity()); + + // reduce capacity down to two, will drop the second component + composite.capacity(2); + assertEquals(2, composite.capacity()); + + // releasing composite should release the components + composite.release(); + assertEquals(0, composite.refCnt()); + assertEquals(0, b1.refCnt()); + assertEquals(0, b2.refCnt()); + } + + @Test + public void testReleasesOnShrink2() { + // It is important to use a pooled allocator here to ensure + // the slices returned by readRetainedSlice are of type + // PooledSlicedByteBuf, which maintains an independent refcount + // (so that we can be sure to cover this case) + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.buffer(); + + buffer.writeShort(1).writeShort(2); + + ByteBuf b1 = buffer.readRetainedSlice(2); + ByteBuf b2 = b1.retainedSlice(b1.readerIndex(), 2); + + // composite takes ownership of b1 and b2 + ByteBuf composite = newCompositeBuffer() + .addComponents(b1, b2); + + assertEquals(4, composite.capacity()); + + // reduce capacity down to two, will drop the second component + composite.capacity(2); + assertEquals(2, composite.capacity()); + + // releasing composite should release the components + composite.release(); + assertEquals(0, composite.refCnt()); + assertEquals(0, b1.refCnt()); + assertEquals(0, b2.refCnt()); + 
+ // release last remaining ref to buffer + buffer.release(); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testAllocatorIsSameWhenCopy() { + testAllocatorIsSameWhenCopy(false); + } + + @Test + public void testAllocatorIsSameWhenCopyUsingIndexAndLength() { + testAllocatorIsSameWhenCopy(true); + } + + private void testAllocatorIsSameWhenCopy(boolean withIndexAndLength) { + ByteBuf buffer = newBuffer(8); + buffer.writeZero(4); + ByteBuf copy = withIndexAndLength ? buffer.copy(0, 4) : buffer.copy(); + assertEquals(buffer, copy); + assertEquals(buffer.isDirect(), copy.isDirect()); + assertSame(buffer.alloc(), copy.alloc()); + buffer.release(); + copy.release(); + } + + @Test + public void testDecomposeMultiple() { + testDecompose(150, 500, 3); + } + + @Test + public void testDecomposeOne() { + testDecompose(310, 50, 1); + } + + @Test + public void testDecomposeNone() { + testDecompose(310, 0, 0); + } + + private void testDecompose(int offset, int length, int expectedListSize) { + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + ByteBuf buf = wrappedBuffer(bytes); + + CompositeByteBuf composite = newCompositeBuffer(); + composite.addComponents(true, + buf.retainedSlice(100, 200), + buf.retainedSlice(300, 400), + buf.retainedSlice(700, 100)); + + ByteBuf slice = composite.slice(offset, length); + List bufferList = composite.decompose(offset, length); + assertEquals(expectedListSize, bufferList.size()); + ByteBuf wrapped = wrappedBuffer(bufferList.toArray(new ByteBuf[0])); + + assertEquals(slice, wrapped); + composite.release(); + buf.release(); + + for (ByteBuf buffer: bufferList) { + assertEquals(0, buffer.refCnt()); + } + } + + @Test + public void testDecomposeReturnNonUnwrappedBuffer() { + ByteBuf buf = PooledByteBufAllocator.DEFAULT.buffer(1024); + buf.writeZero(1024); + ByteBuf sliced = buf.retainedSlice(100, 200); + sliced.retain(); + assertEquals(2, sliced.refCnt()); + CompositeByteBuf composite = 
newCompositeBuffer(); + composite.addComponents(true, sliced); + + List bufferList = composite.decompose(0, 100); + assertEquals(1, bufferList.size()); + ByteBuf decomposed = bufferList.get(0); + assertSame(sliced.refCnt(), decomposed.refCnt()); + decomposed.release(); + + assertSame(sliced.refCnt(), decomposed.refCnt()); + + composite.release(); + buf.release(); + + for (ByteBuf buffer: bufferList) { + assertEquals(0, buffer.refCnt()); + } + } + + @Test + public void testDecomposeReturnNonUnwrappedBuffers() { + ByteBuf buf = PooledByteBufAllocator.DEFAULT.buffer(1024); + buf.writeZero(1024); + ByteBuf sliced = buf.retainedSlice(100, 200); + ByteBuf sliced2 = buf.retainedSlice(400, 100); + sliced.retain(); + sliced2.retain(); + assertEquals(2, sliced.refCnt()); + CompositeByteBuf composite = compositeBuffer(); + composite.addComponents(true, sliced); + composite.addComponents(true, sliced2); + + List bufferList = composite.decompose(100, 150); + assertEquals(2, bufferList.size()); + ByteBuf decomposed = bufferList.get(0); + ByteBuf decomposed2 = bufferList.get(1); + assertSame(sliced.refCnt(), decomposed.refCnt()); + decomposed.release(); + decomposed2.release(); + + assertSame(sliced.refCnt(), decomposed.refCnt()); + assertSame(sliced2.refCnt(), decomposed2.refCnt()); + + composite.release(); + buf.release(); + + for (ByteBuf buffer: bufferList) { + assertEquals(0, buffer.refCnt()); + } + } + + @Test + public void testComponentsLessThanLowerBound() { + try { + new CompositeByteBuf(ALLOC, true, 0); + fail(); + } catch (IllegalArgumentException e) { + assertEquals("maxNumComponents: 0 (expected: >= 1)", e.getMessage()); + } + } + + @Test + public void testComponentsEqualToLowerBound() { + assertCompositeBufCreated(1); + } + + @Test + public void testComponentsGreaterThanLowerBound() { + assertCompositeBufCreated(5); + } + + /** + * Assert that a new {@linkplain CompositeByteBuf} was created successfully with the desired number of max + * components. 
+ */ + private static void assertCompositeBufCreated(int expectedMaxComponents) { + CompositeByteBuf buf = new CompositeByteBuf(ALLOC, true, expectedMaxComponents); + + assertEquals(expectedMaxComponents, buf.maxNumComponents()); + assertTrue(buf.release()); + } + + @Test + public void testDiscardSomeReadBytesCorrectlyUpdatesLastAccessed() { + testDiscardCorrectlyUpdatesLastAccessed(true); + } + + @Test + public void testDiscardReadBytesCorrectlyUpdatesLastAccessed() { + testDiscardCorrectlyUpdatesLastAccessed(false); + } + + private void testDiscardCorrectlyUpdatesLastAccessed(boolean discardSome) { + CompositeByteBuf cbuf = newCompositeBuffer(); + List buffers = new ArrayList(4); + for (int i = 0; i < 4; i++) { + ByteBuf buf = buffer().writeInt(i); + cbuf.addComponent(true, buf); + buffers.add(buf); + } + + // Skip the first 2 bytes which means even if we call discard*ReadBytes() later we can no drop the first + // component as it is still used. + cbuf.skipBytes(2); + if (discardSome) { + cbuf.discardSomeReadBytes(); + } else { + cbuf.discardReadBytes(); + } + assertEquals(4, cbuf.numComponents()); + + // Now skip 3 bytes which means we should be able to drop the first component on the next discard*ReadBytes() + // call. + cbuf.skipBytes(3); + + if (discardSome) { + cbuf.discardSomeReadBytes(); + } else { + cbuf.discardReadBytes(); + } + assertEquals(3, cbuf.numComponents()); + // Now skip again 3 bytes which should bring our readerIndex == start of the 3 component. + cbuf.skipBytes(3); + + // Read one int (4 bytes) which should bring our readerIndex == start of the 4 component. + assertEquals(2, cbuf.readInt()); + if (discardSome) { + cbuf.discardSomeReadBytes(); + } else { + cbuf.discardReadBytes(); + } + + // Now all except the last component should have been dropped / released. 
+ assertEquals(1, cbuf.numComponents()); + assertEquals(3, cbuf.readInt()); + if (discardSome) { + cbuf.discardSomeReadBytes(); + } else { + cbuf.discardReadBytes(); + } + assertEquals(0, cbuf.numComponents()); + + // These should have been released already. + for (ByteBuf buffer: buffers) { + assertEquals(0, buffer.refCnt()); + } + assertTrue(cbuf.release()); + } + + // See https://github.com/netty/netty/issues/11612 + @Test + public void testAddComponentWithNullEntry() { + final ByteBuf buffer = Unpooled.buffer(8).writeZero(8); + final CompositeByteBuf compositeByteBuf = compositeBuffer(Integer.MAX_VALUE); + try { + compositeByteBuf.addComponents(true, new ByteBuf[] { buffer, null }); + assertEquals(8, compositeByteBuf.readableBytes()); + assertEquals(1, compositeByteBuf.numComponents()); + } finally { + compositeByteBuf.release(); + } + } + + @Test + public void testOverflowWhileAddingComponent() { + int capacity = 1024 * 1024; // 1MB + final ByteBuf buffer = Unpooled.buffer(capacity).writeZero(capacity); + final CompositeByteBuf compositeByteBuf = compositeBuffer(Integer.MAX_VALUE); + + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + for (int i = 0; i >= 0; i += buffer.readableBytes()) { + ByteBuf duplicate = buffer.duplicate(); + compositeByteBuf.addComponent(duplicate); + duplicate.retain(); + } + } + }); + } finally { + compositeByteBuf.release(); + } + } + + @Test + public void testOverflowWhileAddingComponentsViaVarargs() { + int capacity = 1024 * 1024; // 1MB + final ByteBuf buffer = Unpooled.buffer(capacity).writeZero(capacity); + final CompositeByteBuf compositeByteBuf = compositeBuffer(Integer.MAX_VALUE); + + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + for (int i = 0; i >= 0; i += buffer.readableBytes()) { + ByteBuf duplicate = buffer.duplicate(); + compositeByteBuf.addComponents(duplicate); + duplicate.retain(); + } + } + 
}); + } finally { + compositeByteBuf.release(); + } + } + + @Test + public void testOverflowWhileAddingComponentsViaIterable() { + int capacity = 1024 * 1024; // 1MB + final ByteBuf buffer = Unpooled.buffer(capacity).writeZero(capacity); + final CompositeByteBuf compositeByteBuf = compositeBuffer(Integer.MAX_VALUE); + + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + for (int i = 0; i >= 0; i += buffer.readableBytes()) { + ByteBuf duplicate = buffer.duplicate(); + compositeByteBuf.addComponents(Collections.singletonList(duplicate)); + duplicate.retain(); + } + } + }); + } finally { + compositeByteBuf.release(); + } + } + + @Test + public void testOverflowWhileUseConstructorWithOffset() { + int capacity = 1024 * 1024; // 1MB + final ByteBuf buffer = Unpooled.buffer(capacity).writeZero(capacity); + final List buffers = new ArrayList(); + for (long i = 0; i <= Integer.MAX_VALUE; i += capacity) { + buffers.add(buffer.duplicate()); + } + // Add one more + buffers.add(buffer.duplicate()); + + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + ByteBuf[] bufferArray = buffers.toArray(new ByteBuf[0]); + new CompositeByteBuf(ALLOC, false, Integer.MAX_VALUE, bufferArray, 0); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testNotOverflowWhileUseConstructorWithOffset() { + int capacity = 1024 * 1024; // 1MB + final ByteBuf buffer = Unpooled.buffer(capacity).writeZero(capacity); + final List buffers = new ArrayList(); + for (long i = 0; i <= Integer.MAX_VALUE; i += capacity) { + buffers.add(buffer.duplicate()); + } + // Add one more + buffers.add(buffer.duplicate()); + + ByteBuf[] bufferArray = buffers.toArray(new ByteBuf[0]); + CompositeByteBuf compositeByteBuf = + new CompositeByteBuf(ALLOC, false, Integer.MAX_VALUE, bufferArray, bufferArray.length - 1); + compositeByteBuf.release(); + } + + @Test + public void 
sliceOfCompositeBufferMustThrowISEAfterDiscardBytes() { + CompositeByteBuf composite = compositeBuffer(); + composite.addComponent(true, buffer(8).writeZero(8)); + + ByteBuf slice = composite.retainedSlice(); + composite.skipBytes(slice.readableBytes()); + composite.discardSomeReadBytes(); + + try { + slice.readByte(); + fail("Expected readByte of discarded slice to throw."); + } catch (IllegalStateException ignore) { + // Good. + } finally { + slice.release(); + composite.release(); + } + } + + @Test + public void forEachByteOnNestedCompositeByteBufMustSeeEntireFlattenedContents() { + CompositeByteBuf buf = newCompositeBuffer(); + buf.addComponent(true, newCompositeBuffer().addComponents( + true, + wrappedBuffer(new byte[] {1, 2, 3}), + wrappedBuffer(new byte[] {4, 5, 6}))); + final byte[] arrayAsc = new byte[6]; + final byte[] arrayDesc = new byte[6]; + buf.forEachByte(new ByteProcessor() { + int index; + + @Override + public boolean process(byte value) throws Exception { + arrayAsc[index++] = value; + return true; + } + }); + buf.forEachByteDesc(new ByteProcessor() { + int index; + + @Override + public boolean process(byte value) throws Exception { + arrayDesc[index++] = value; + return true; + } + }); + assertArrayEquals(new byte[] {1, 2, 3, 4, 5, 6}, arrayAsc); + assertArrayEquals(new byte[] {6, 5, 4, 3, 2, 1}, arrayDesc); + buf.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AbstractPooledByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/AbstractPooledByteBufTest.java new file mode 100644 index 0000000..a4a0e61 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AbstractPooledByteBufTest.java @@ -0,0 +1,147 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class AbstractPooledByteBufTest extends AbstractByteBufTest { + + protected abstract ByteBuf alloc(int length, int maxCapacity); + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + ByteBuf buffer = alloc(length, maxCapacity); + + // Testing if the writerIndex and readerIndex are correct when allocate and also after we reset the mark. 
+ assertEquals(0, buffer.writerIndex()); + assertEquals(0, buffer.readerIndex()); + buffer.resetReaderIndex(); + buffer.resetWriterIndex(); + assertEquals(0, buffer.writerIndex()); + assertEquals(0, buffer.readerIndex()); + return buffer; + } + + @Test + public void ensureWritableWithEnoughSpaceShouldNotThrow() { + ByteBuf buf = newBuffer(1, 10); + buf.ensureWritable(3); + assertThat(buf.writableBytes(), is(greaterThanOrEqualTo(3))); + buf.release(); + } + + @Test + public void ensureWritableWithNotEnoughSpaceShouldThrow() { + final ByteBuf buf = newBuffer(1, 10); + try { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buf.ensureWritable(11); + } + }); + } finally { + buf.release(); + } + } + + @Override + @Test + public void testMaxFastWritableBytes() { + ByteBuf buffer = newBuffer(150, 500).writerIndex(100); + assertEquals(50, buffer.writableBytes()); + assertEquals(150, buffer.capacity()); + assertEquals(500, buffer.maxCapacity()); + assertEquals(400, buffer.maxWritableBytes()); + + int chunkSize = pooledByteBuf(buffer).maxLength; + assertTrue(chunkSize >= 150); + int remainingInAlloc = Math.min(chunkSize - 100, 400); + assertEquals(remainingInAlloc, buffer.maxFastWritableBytes()); + + // write up to max, chunk alloc should not change (same handle) + long handleBefore = pooledByteBuf(buffer).handle; + buffer.writeBytes(new byte[remainingInAlloc]); + assertEquals(handleBefore, pooledByteBuf(buffer).handle); + + assertEquals(0, buffer.maxFastWritableBytes()); + // writing one more should trigger a reallocation (new handle) + buffer.writeByte(7); + assertNotEquals(handleBefore, pooledByteBuf(buffer).handle); + + // should not exceed maxCapacity even if chunk alloc does + buffer.capacity(500); + assertEquals(500 - buffer.writerIndex(), buffer.maxFastWritableBytes()); + buffer.release(); + } + + private static PooledByteBuf pooledByteBuf(ByteBuf buffer) { + // might need to unwrap if swapped (LE) and/or 
leak-aware-wrapped + while (!(buffer instanceof PooledByteBuf)) { + buffer = buffer.unwrap(); + } + return (PooledByteBuf) buffer; + } + + @Test + public void testEnsureWritableDoesntGrowTooMuch() { + ByteBuf buffer = newBuffer(150, 500).writerIndex(100); + + assertEquals(50, buffer.writableBytes()); + int fastWritable = buffer.maxFastWritableBytes(); + assertTrue(fastWritable > 50); + + long handleBefore = pooledByteBuf(buffer).handle; + + // capacity expansion should not cause reallocation + // (should grow precisely the specified amount) + buffer.ensureWritable(fastWritable); + assertEquals(handleBefore, pooledByteBuf(buffer).handle); + assertEquals(100 + fastWritable, buffer.capacity()); + assertEquals(buffer.writableBytes(), buffer.maxFastWritableBytes()); + buffer.release(); + } + + @Test + public void testIsContiguous() { + ByteBuf buf = newBuffer(4); + assertTrue(buf.isContiguous()); + buf.release(); + } + + @Test + public void distinctBuffersMustNotOverlap() { + ByteBuf a = newBuffer(16384); + ByteBuf b = newBuffer(65536); + a.setByte(a.capacity() - 1, 1); + b.setByte(0, 2); + try { + assertEquals(1, a.getByte(a.capacity() - 1)); + } finally { + a.release(); + b.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java new file mode 100644 index 0000000..8525f67 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AbstractReferenceCountedByteBufTest.java @@ -0,0 +1,358 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.IllegalReferenceCountException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class AbstractReferenceCountedByteBufTest { + + @Test + public void testRetainOverflow() { + final AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(Integer.MAX_VALUE); + assertEquals(Integer.MAX_VALUE, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(); + } + }); + } + + @Test + public void testRetainOverflow2() { + final AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertEquals(1, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(Integer.MAX_VALUE); + } + }); + } + + @Test + public void testReleaseOverflow() { + final 
AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(0); + assertEquals(0, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.release(Integer.MAX_VALUE); + } + }); + } + + @Test + public void testReleaseErrorMessage() { + AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + try { + referenceCounted.release(1); + fail("IllegalReferenceCountException didn't occur"); + } catch (IllegalReferenceCountException e) { + assertEquals("refCnt: 0, decrement: 1", e.getMessage()); + } + } + + @Test + public void testRetainResurrect() { + final AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(); + } + }); + } + + @Test + public void testRetainResurrect2() { + final AbstractReferenceCountedByteBuf referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(2); + } + }); + } + + private static AbstractReferenceCountedByteBuf newReferenceCounted() { + return new AbstractReferenceCountedByteBuf(Integer.MAX_VALUE) { + + @Override + protected byte _getByte(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected short _getShort(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected short _getShortLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getUnsignedMedium(int index) { + throw new 
UnsupportedOperationException(); + } + + @Override + protected int _getUnsignedMediumLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getInt(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected int _getIntLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected long _getLong(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected long _getLongLE(int index) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setByte(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setShort(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setShortLE(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setMedium(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setMediumLE(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setInt(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setIntLE(int index, int value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setLong(int index, long value) { + throw new UnsupportedOperationException(); + } + + @Override + protected void _setLongLE(int index, long value) { + throw new UnsupportedOperationException(); + } + + @Override + public int capacity() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBufAllocator alloc() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteOrder order() { + throw new UnsupportedOperationException(); + } + + @Override + public 
ByteBuf unwrap() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isDirect() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + throw new UnsupportedOperationException(); + } + + @Override + public int setBytes(int index, InputStream in, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int setBytes(int index, ScatteringByteChannel in, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf copy(int index, int length) { + throw new UnsupportedOperationException(); + } + + 
@Override + public int nioBufferCount() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer nioBuffer(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasArray() { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] array() { + throw new UnsupportedOperationException(); + } + + @Override + public int arrayOffset() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasMemoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + public long memoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + protected void deallocate() { + // NOOP + } + + @Override + public AbstractReferenceCountedByteBuf touch(Object hint) { + throw new UnsupportedOperationException(); + } + }; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareByteBufTest.java new file mode 100644 index 0000000..e79b35b --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareByteBufTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import static io.netty.buffer.Unpooled.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +import org.junit.jupiter.api.Test; + +import io.netty.util.CharsetUtil; +import io.netty.util.ResourceLeakTracker; + +public class AdvancedLeakAwareByteBufTest extends SimpleLeakAwareByteBufTest { + + @Override + protected Class leakClass() { + return AdvancedLeakAwareByteBuf.class; + } + + @Override + protected SimpleLeakAwareByteBuf wrap(ByteBuf buffer, ResourceLeakTracker tracker) { + return new AdvancedLeakAwareByteBuf(buffer, tracker); + } + + @Test + public void testAddComponentWithLeakAwareByteBuf() { + NoopResourceLeakTracker tracker = new NoopResourceLeakTracker(); + + ByteBuf buffer = wrappedBuffer("hello world".getBytes(CharsetUtil.US_ASCII)).slice(6, 5); + ByteBuf leakAwareBuf = wrap(buffer, tracker); + + CompositeByteBuf composite = compositeBuffer(); + composite.addComponent(true, leakAwareBuf); + byte[] result = new byte[5]; + ByteBuf bb = composite.component(0); + bb.readBytes(result); + assertArrayEquals("world".getBytes(CharsetUtil.US_ASCII), result); + composite.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBufTest.java new file mode 100644 index 0000000..e5742f4 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AdvancedLeakAwareCompositeByteBufTest.java @@ -0,0 +1,31 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.ResourceLeakTracker; + +public class AdvancedLeakAwareCompositeByteBufTest extends SimpleLeakAwareCompositeByteBufTest { + + @Override + protected SimpleLeakAwareCompositeByteBuf wrap(CompositeByteBuf buffer, ResourceLeakTracker tracker) { + return new AdvancedLeakAwareCompositeByteBuf(buffer, tracker); + } + + @Override + protected Class leakClass() { + return AdvancedLeakAwareByteBuf.class; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/AlignedPooledByteBufAllocatorTest.java b/netty-buffer/src/test/java/io/netty/buffer/AlignedPooledByteBufAllocatorTest.java new file mode 100644 index 0000000..c09efff --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/AlignedPooledByteBufAllocatorTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AlignedPooledByteBufAllocatorTest extends PooledByteBufAllocatorTest { + @Override + protected PooledByteBufAllocator newAllocator(boolean preferDirect) { + int directMemoryCacheAlignment = 1; + return new PooledByteBufAllocator( + preferDirect, + PooledByteBufAllocator.defaultNumHeapArena(), + PooledByteBufAllocator.defaultNumDirectArena(), + PooledByteBufAllocator.defaultPageSize(), + 11, + PooledByteBufAllocator.defaultSmallCacheSize(), + 64, + PooledByteBufAllocator.defaultUseCacheForAllThreads(), + directMemoryCacheAlignment); + } + + // https://github.com/netty/netty/issues/11955 + @Test + public void testCorrectElementSize() { + ByteBufAllocator allocator = new PooledByteBufAllocator( + true, + PooledByteBufAllocator.defaultNumHeapArena(), + PooledByteBufAllocator.defaultNumDirectArena(), + PooledByteBufAllocator.defaultPageSize(), + 11, + PooledByteBufAllocator.defaultSmallCacheSize(), + 64, + PooledByteBufAllocator.defaultUseCacheForAllThreads(), + 64); + + ByteBuf a = allocator.directBuffer(0, 16384); + ByteBuf b = allocator.directBuffer(0, 16384); + a.capacity(16); + assertEquals(16, a.capacity()); + b.capacity(16); + assertEquals(16, b.capacity()); + a.capacity(17); + assertEquals(17, a.capacity()); + b.capacity(18); + assertEquals(18, b.capacity()); + assertTrue(a.release()); + assertTrue(b.release()); + } + + @Test + public void testDirectSubpageReleaseLock() { + int initialCapacity = 0; + int directMemoryCacheAlignment = 32; + PooledByteBufAllocator allocator = new PooledByteBufAllocator( + true, + 0, + 1, + PooledByteBufAllocator.defaultPageSize(), + PooledByteBufAllocator.defaultMaxOrder(), + 0, + 0, + false, + directMemoryCacheAlignment); + 
+ final PooledByteBuf byteBuf = pooledByteBuf(allocator.directBuffer(initialCapacity, 16)); + // Get the smallSubpagePools[] array in arena. + @SuppressWarnings("unchecked") + PoolSubpage[] smallSubpagePools = (PoolSubpage[]) byteBuf.chunk.arena.smallSubpagePools; + PoolSubpage head = null; + for (PoolSubpage subpage : smallSubpagePools) { + if (subpage.next != subpage) { + // Find the head subpage which the byteBuf belongs to. + head = subpage; + break; + } + } + assertNotNull(head); + Thread t1 = new Thread(new Runnable() { + @Override + public void run() { + // Because the head subpage was already locked in the main thread, so this should hang and wait. + byteBuf.release(); + } + }); + t1.setDaemon(true); + // Intentionally lock the head subpage in main thread. + head.lock(); + try { + t1.start(); + long start = System.nanoTime(); + while (!head.lock.hasQueuedThread(t1)) { + if ((System.nanoTime() - start) > TimeUnit.SECONDS.toNanos(3)) { + break; + } + } + assertTrue(head.lock.hasQueuedThread(t1), + "The t1 thread should still be waiting for the head lock."); + } finally { + head.unlock(); + } + } + + private static PooledByteBuf pooledByteBuf(ByteBuf buffer) { + // might need to unwrap if swapped (LE) and/or leak-aware-wrapped + while (!(buffer instanceof PooledByteBuf)) { + buffer = buffer.unwrap(); + } + return (PooledByteBuf) buffer; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/BigEndianCompositeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/BigEndianCompositeByteBufTest.java new file mode 100644 index 0000000..50d3e12 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/BigEndianCompositeByteBufTest.java @@ -0,0 +1,34 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; + +/** + * Tests big-endian composite channel buffers + */ +public class BigEndianCompositeByteBufTest extends AbstractCompositeByteBufTest { + public BigEndianCompositeByteBufTest() { + super(Unpooled.BIG_ENDIAN); + } + + @Override + @Test + public void testInternalNioBufferAfterRelease() { + testInternalNioBufferAfterRelease0(UnsupportedOperationException.class); + } + +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/BigEndianDirectByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/BigEndianDirectByteBufTest.java new file mode 100644 index 0000000..e7a1463 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/BigEndianDirectByteBufTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.nio.ByteOrder; + +import org.junit.jupiter.api.Test; + +/** + * Tests big-endian direct channel buffers + */ +public class BigEndianDirectByteBufTest extends AbstractByteBufTest { + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + ByteBuf buffer = newDirectBuffer(length, maxCapacity); + assertSame(ByteOrder.BIG_ENDIAN, buffer.order()); + assertEquals(0, buffer.writerIndex()); + return buffer; + } + + protected ByteBuf newDirectBuffer(int length, int maxCapacity) { + return new UnpooledDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, length, maxCapacity); + } + + @Test + public void testIsContiguous() { + ByteBuf buf = newBuffer(4); + assertTrue(buf.isContiguous()); + buf.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/BigEndianHeapByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/BigEndianHeapByteBufTest.java new file mode 100644 index 0000000..e9accda --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/BigEndianHeapByteBufTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Tests big-endian heap channel buffers + */ +public class BigEndianHeapByteBufTest extends AbstractByteBufTest { + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + ByteBuf buffer = Unpooled.buffer(length, maxCapacity); + assertEquals(0, buffer.writerIndex()); + return buffer; + } + + @Test + public void shouldNotAllowNullInConstructor1() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new UnpooledHeapByteBuf(null, new byte[1], 0); + } + }); + } + + @Test + public void shouldNotAllowNullInConstructor2() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new UnpooledHeapByteBuf(UnpooledByteBufAllocator.DEFAULT, null, 0); + } + }); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeDirectByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeDirectByteBufTest.java new file mode 100644 index 0000000..a2de7f6 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeDirectByteBufTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; + +public class BigEndianUnsafeDirectByteBufTest extends BigEndianDirectByteBufTest { + + @BeforeEach + @Override + public void init() { + Assumptions.assumeTrue(PlatformDependent.hasUnsafe(), "sun.misc.Unsafe not found, skip tests"); + super.init(); + } + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + return new UnpooledUnsafeDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, length, maxCapacity); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeNoCleanerDirectByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeNoCleanerDirectByteBufTest.java new file mode 100644 index 0000000..19c053e --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/BigEndianUnsafeNoCleanerDirectByteBufTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; + +public class BigEndianUnsafeNoCleanerDirectByteBufTest extends BigEndianDirectByteBufTest { + + @BeforeEach + @Override + public void init() { + Assumptions.assumeTrue(PlatformDependent.useDirectBufferNoCleaner(), + "java.nio.DirectByteBuffer.(long, int) not found, skip tests"); + super.init(); + } + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + return new UnpooledUnsafeNoCleanerDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, length, maxCapacity); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ByteBufAllocatorTest.java b/netty-buffer/src/test/java/io/netty/buffer/ByteBufAllocatorTest.java new file mode 100644 index 0000000..b538953 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ByteBufAllocatorTest.java @@ -0,0 +1,227 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
*/
package io.netty.buffer;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Base test for {@link ByteBufAllocator} implementations. Concrete subclasses supply
 * the allocator under test plus its expected defaults; each test allocates one buffer,
 * verifies directness, capacity and max capacity (or component count for composites),
 * and always releases the buffer in a finally block so leak detection stays clean.
 */
public abstract class ByteBufAllocatorTest {

    // Expected maxCapacity when a buffer is allocated without an explicit one.
    protected abstract int defaultMaxCapacity();

    // Expected maxNumComponents when a composite buffer is allocated without an explicit one.
    protected abstract int defaultMaxComponents();

    // Creates the allocator under test.
    protected abstract ByteBufAllocator newAllocator(boolean preferDirect);

    @Test
    public void testBuffer() {
        testBuffer(true);
        testBuffer(false);
    }

    private void testBuffer(boolean preferDirect) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        ByteBuf buffer = allocator.buffer(1);
        try {
            assertBuffer(buffer, isDirectExpected(preferDirect), 1, defaultMaxCapacity());
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testBufferWithCapacity() {
        testBufferWithCapacity(true, 8);
        testBufferWithCapacity(false, 8);
    }

    private void testBufferWithCapacity(boolean preferDirect, int maxCapacity) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        ByteBuf buffer = allocator.buffer(1, maxCapacity);
        try {
            assertBuffer(buffer, isDirectExpected(preferDirect), 1, maxCapacity);
        } finally {
            buffer.release();
        }
    }

    // Whether buffer(...) is expected to hand back a direct buffer for the given preference.
    protected abstract boolean isDirectExpected(boolean preferDirect);

    @Test
    public void testHeapBuffer() {
        testHeapBuffer(true);
        testHeapBuffer(false);
    }

    private void testHeapBuffer(boolean preferDirect) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        ByteBuf buffer = allocator.heapBuffer(1);
        try {
            // heapBuffer(...) must return a heap buffer regardless of the direct preference.
            assertBuffer(buffer, false, 1, defaultMaxCapacity());
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testHeapBufferMaxCapacity() {
        testHeapBuffer(true, 8);
        testHeapBuffer(false, 8);
    }

    private void testHeapBuffer(boolean preferDirect, int maxCapacity) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        ByteBuf buffer = allocator.heapBuffer(1, maxCapacity);
        try {
            assertBuffer(buffer, false, 1, maxCapacity);
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testDirectBuffer() {
        testDirectBuffer(true);
        testDirectBuffer(false);
    }

    private void testDirectBuffer(boolean preferDirect) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        ByteBuf buffer = allocator.directBuffer(1);
        try {
            // directBuffer(...) must return a direct buffer regardless of the preference.
            assertBuffer(buffer, true, 1, defaultMaxCapacity());
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testDirectBufferMaxCapacity() {
        testDirectBuffer(true, 8);
        testDirectBuffer(false, 8);
    }

    private void testDirectBuffer(boolean preferDirect, int maxCapacity) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        ByteBuf buffer = allocator.directBuffer(1, maxCapacity);
        try {
            assertBuffer(buffer, true, 1, maxCapacity);
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testCompositeBuffer() {
        testCompositeBuffer(true);
        testCompositeBuffer(false);
    }

    private void testCompositeBuffer(boolean preferDirect) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        CompositeByteBuf buffer = allocator.compositeBuffer();
        try {
            assertCompositeByteBuf(buffer, defaultMaxComponents());
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testCompositeBufferWithCapacity() {
        // Exercises compositeBuffer(maxNumComponents) directly; delegating to the
        // heap-specific helper here would leave this overload untested.
        testCompositeBufferWithCapacity(true, 8);
        testCompositeBufferWithCapacity(false, 8);
    }

    private void testCompositeBufferWithCapacity(boolean preferDirect, int maxNumComponents) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        CompositeByteBuf buffer = allocator.compositeBuffer(maxNumComponents);
        try {
            assertCompositeByteBuf(buffer, maxNumComponents);
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testCompositeHeapBuffer() {
        testCompositeHeapBuffer(true);
        testCompositeHeapBuffer(false);
    }

    private void testCompositeHeapBuffer(boolean preferDirect) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        CompositeByteBuf buffer = allocator.compositeHeapBuffer();
        try {
            assertCompositeByteBuf(buffer, defaultMaxComponents());
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testCompositeHeapBufferWithCapacity() {
        testCompositeHeapBufferWithCapacity(true, 8);
        testCompositeHeapBufferWithCapacity(false, 8);
    }

    private void testCompositeHeapBufferWithCapacity(boolean preferDirect, int maxNumComponents) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        CompositeByteBuf buffer = allocator.compositeHeapBuffer(maxNumComponents);
        try {
            assertCompositeByteBuf(buffer, maxNumComponents);
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testCompositeDirectBuffer() {
        testCompositeDirectBuffer(true);
        testCompositeDirectBuffer(false);
    }

    private void testCompositeDirectBuffer(boolean preferDirect) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        CompositeByteBuf buffer = allocator.compositeDirectBuffer();
        try {
            assertCompositeByteBuf(buffer, defaultMaxComponents());
        } finally {
            buffer.release();
        }
    }

    @Test
    public void testCompositeDirectBufferWithCapacity() {
        testCompositeDirectBufferWithCapacity(true, 8);
        testCompositeDirectBufferWithCapacity(false, 8);
    }

    private void testCompositeDirectBufferWithCapacity(boolean preferDirect, int maxNumComponents) {
        ByteBufAllocator allocator = newAllocator(preferDirect);
        CompositeByteBuf buffer = allocator.compositeDirectBuffer(maxNumComponents);
        try {
            assertCompositeByteBuf(buffer, maxNumComponents);
        } finally {
            buffer.release();
        }
    }

    // Common assertions for a freshly allocated (non-composite) buffer.
    private static void assertBuffer(
            ByteBuf buffer, boolean expectedDirect, int expectedCapacity, int expectedMaxCapacity) {
        assertEquals(expectedDirect, buffer.isDirect());
        assertEquals(expectedCapacity, buffer.capacity());
        assertEquals(expectedMaxCapacity, buffer.maxCapacity());
    }

    // Common assertions for a freshly allocated, still-empty composite buffer.
    private void assertCompositeByteBuf(
            CompositeByteBuf buffer, int expectedMaxNumComponents) {
        assertEquals(0, buffer.numComponents());
        assertEquals(expectedMaxNumComponents, buffer.maxNumComponents());
        assertBuffer(buffer, false, 0, defaultMaxCapacity());
    }
}
diff --git a/netty-buffer/src/test/java/io/netty/buffer/ByteBufDerivationTest.java b/netty-buffer/src/test/java/io/netty/buffer/ByteBufDerivationTest.java
new file mode 100644
index 0000000..ab4a59c
--- /dev/null
+++
b/netty-buffer/src/test/java/io/netty/buffer/ByteBufDerivationTest.java @@ -0,0 +1,217 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteOrder; +import java.util.Random; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; + +/** + * Tests wrapping a wrapped buffer does not go way too deep chaining. 
*/
public class ByteBufDerivationTest {

    @Test
    public void testSlice() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf slice = buf.slice(1, 7);

        // A slice wraps the source buffer directly and exposes only the sliced window.
        assertThat(slice, instanceOf(AbstractUnpooledSlicedByteBuf.class));
        assertThat(slice.unwrap(), sameInstance(buf));
        assertThat(slice.readerIndex(), is(0));
        assertThat(slice.writerIndex(), is(7));
        assertThat(slice.capacity(), is(7));
        assertThat(slice.maxCapacity(), is(7));

        // Moving indexes on the slice must not leak back into the wrapped buffer.
        slice.setIndex(1, 6);
        assertThat(buf.readerIndex(), is(1));
        assertThat(buf.writerIndex(), is(7));
    }

    @Test
    public void testSliceOfSlice() throws Exception {
        ByteBuf buf = Unpooled.buffer(8);
        ByteBuf slice = buf.slice(1, 7);
        ByteBuf slice2 = slice.slice(0, 6);

        // A slice of a slice unwraps to the original buffer instead of nesting another level.
        assertThat(slice2, not(sameInstance(slice)));
        assertThat(slice2, instanceOf(AbstractUnpooledSlicedByteBuf.class));
        assertThat(slice2.unwrap(), sameInstance(buf));
        assertThat(slice2.writerIndex(), is(6));
        assertThat(slice2.capacity(), is(6));
    }

    @Test
    public void testDuplicate() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf dup = buf.duplicate();

        // A duplicate shares content and capacity but keeps independent indexes.
        assertThat(dup, instanceOf(DuplicatedByteBuf.class));
        assertThat(dup.unwrap(), sameInstance(buf));
        assertThat(dup.readerIndex(), is(buf.readerIndex()));
        assertThat(dup.writerIndex(), is(buf.writerIndex()));
        assertThat(dup.capacity(), is(buf.capacity()));
        assertThat(dup.maxCapacity(), is(buf.maxCapacity()));

        dup.setIndex(2, 6);
        assertThat(buf.readerIndex(), is(1));
        assertThat(buf.writerIndex(), is(7));
    }

    @Test
    public void testDuplicateOfDuplicate() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf dup = buf.duplicate().setIndex(2, 6);
        ByteBuf dup2 = dup.duplicate();

        // Duplicating a duplicate also unwraps straight to the original buffer.
        assertThat(dup2, not(sameInstance(dup)));
        assertThat(dup2, instanceOf(DuplicatedByteBuf.class));
        assertThat(dup2.unwrap(), sameInstance(buf));
        assertThat(dup2.readerIndex(), is(dup.readerIndex()));
        assertThat(dup2.writerIndex(), is(dup.writerIndex()));
        assertThat(dup2.capacity(), is(dup.capacity()));
        assertThat(dup2.maxCapacity(), is(dup.maxCapacity()));
    }

    @Test
    public void testReadOnly() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf ro = Unpooled.unmodifiableBuffer(buf);

        assertThat(ro, instanceOf(ReadOnlyByteBuf.class));
        assertThat(ro.unwrap(), sameInstance(buf));
        assertThat(ro.readerIndex(), is(buf.readerIndex()));
        assertThat(ro.writerIndex(), is(buf.writerIndex()));
        assertThat(ro.capacity(), is(buf.capacity()));
        assertThat(ro.maxCapacity(), is(buf.maxCapacity()));

        // Index changes on the read-only view are independent of the wrapped buffer.
        ro.setIndex(2, 6);
        assertThat(buf.readerIndex(), is(1));
    }

    @Test
    public void testReadOnlyOfReadOnly() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf ro = Unpooled.unmodifiableBuffer(buf).setIndex(2, 6);
        ByteBuf ro2 = Unpooled.unmodifiableBuffer(ro);

        // Wrapping a read-only view again unwraps to the original buffer, not the first view.
        assertThat(ro2, not(sameInstance(ro)));
        assertThat(ro2, instanceOf(ReadOnlyByteBuf.class));
        assertThat(ro2.unwrap(), sameInstance(buf));
        assertThat(ro2.readerIndex(), is(ro.readerIndex()));
        assertThat(ro2.writerIndex(), is(ro.writerIndex()));
        assertThat(ro2.capacity(), is(ro.capacity()));
        assertThat(ro2.maxCapacity(), is(ro.maxCapacity()));
    }

    @Test
    public void testReadOnlyOfDuplicate() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf dup = buf.duplicate().setIndex(2, 6);
        ByteBuf ro = Unpooled.unmodifiableBuffer(dup);

        assertThat(ro, instanceOf(ReadOnlyByteBuf.class));
        assertThat(ro.unwrap(), sameInstance(buf));
        assertThat(ro.readerIndex(), is(dup.readerIndex()));
        assertThat(ro.writerIndex(), is(dup.writerIndex()));
        assertThat(ro.capacity(), is(dup.capacity()));
        assertThat(ro.maxCapacity(), is(dup.maxCapacity()));
    }

    @Test
    public void testDuplicateOfReadOnly() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf ro = Unpooled.unmodifiableBuffer(buf).setIndex(2, 6);
        ByteBuf dup = ro.duplicate();

        // A duplicate of a read-only view stays read-only and still unwraps to the original.
        assertThat(dup, instanceOf(ReadOnlyByteBuf.class));
        assertThat(dup.unwrap(), sameInstance(buf));
        assertThat(dup.readerIndex(), is(ro.readerIndex()));
        assertThat(dup.writerIndex(), is(ro.writerIndex()));
        assertThat(dup.capacity(), is(ro.capacity()));
        assertThat(dup.maxCapacity(), is(ro.maxCapacity()));
    }

    @Test
    public void testSwap() throws Exception {
        ByteBuf buf = Unpooled.buffer(8).setIndex(1, 7);
        ByteBuf swapped = buf.order(ByteOrder.LITTLE_ENDIAN);

        assertThat(swapped, instanceOf(SwappedByteBuf.class));
        assertThat(swapped.unwrap(), sameInstance(buf));
        // order(...) is idempotent: re-requesting the same order returns the same instance,
        // and requesting the original order returns the wrapped buffer itself.
        assertThat(swapped.order(ByteOrder.LITTLE_ENDIAN), sameInstance(swapped));
        assertThat(swapped.order(ByteOrder.BIG_ENDIAN), sameInstance(buf));

        // Unlike slices/duplicates, a swapped view shares indexes with the wrapped buffer.
        buf.setIndex(2, 6);
        assertThat(swapped.readerIndex(), is(2));
        assertThat(swapped.writerIndex(), is(6));
    }

    @Test
    public void testMixture() throws Exception {
        // Randomly stack slice/duplicate/order/read-only derivations and verify the
        // wrapper chain never grows beyond a small constant depth.
        ByteBuf buf = Unpooled.buffer(10000);
        ByteBuf derived = buf;
        Random rnd = new Random();
        for (int i = 0; i < buf.capacity(); i ++) {
            ByteBuf newDerived;
            switch (rnd.nextInt(4)) {
            case 0:
                newDerived = derived.slice(1, derived.capacity() - 1);
                break;
            case 1:
                newDerived = derived.duplicate();
                break;
            case 2:
                newDerived = derived.order(
                        derived.order() == ByteOrder.BIG_ENDIAN ? ByteOrder.LITTLE_ENDIAN : ByteOrder.BIG_ENDIAN);
                break;
            case 3:
                newDerived = Unpooled.unmodifiableBuffer(derived);
                break;
            default:
                throw new Error();
            }

            assertThat("nest level of " + newDerived, nestLevel(newDerived), is(lessThanOrEqualTo(3)));
            assertThat(
                    "nest level of " + newDerived.order(ByteOrder.BIG_ENDIAN),
                    nestLevel(newDerived.order(ByteOrder.BIG_ENDIAN)), is(lessThanOrEqualTo(2)));

            derived = newDerived;
        }
    }

    // Counts how many wrapper layers sit between the given buffer (normalized to
    // big-endian first) and the innermost, unwrapped buffer.
    private static int nestLevel(ByteBuf buf) {
        int depth = 0;
        for (ByteBuf b = buf.order(ByteOrder.BIG_ENDIAN);;) {
            if (b.unwrap() == null && !(b instanceof SwappedByteBuf)) {
                break;
            }
            depth ++;
            if (b instanceof SwappedByteBuf) {
                b = b.order(ByteOrder.BIG_ENDIAN);
            } else {
                b = b.unwrap();
            }
        }
        return depth;
    }
}
diff --git a/netty-buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java b/netty-buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java
new file mode 100644
index 0000000..8c3375b
--- /dev/null
+++ b/netty-buffer/src/test/java/io/netty/buffer/ByteBufStreamTest.java
@@ -0,0 +1,310 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.EOFException; +import java.io.IOException; +import java.nio.charset.Charset; + +import static io.netty.util.internal.EmptyArrays.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * Tests channel buffer streams + */ +public class ByteBufStreamTest { + + @Test + public void testAll() throws Exception { + ByteBuf buf = Unpooled.buffer(0, 65536); + + try { + new ByteBufOutputStream(null); + fail(); + } catch (NullPointerException e) { + // Expected + } + + ByteBufOutputStream out = new ByteBufOutputStream(buf); + try { + assertSame(buf, out.buffer()); + out.writeBoolean(true); + out.writeBoolean(false); + out.writeByte(42); + out.writeByte(224); + out.writeBytes("Hello, World!"); + out.writeChars("Hello, World"); + out.writeChar('!'); + out.writeDouble(42.0); + out.writeFloat(42.0f); + out.writeInt(42); + out.writeLong(42); + out.writeShort(42); + out.writeShort(49152); + out.writeUTF("Hello, World!"); + out.writeBytes("The first line\r\r\n"); + out.write(EMPTY_BYTES); + out.write(new byte[]{1, 2, 3, 4}); + out.write(new byte[]{1, 3, 3, 4}, 0, 0); + } finally { + out.close(); + } + + try { + new ByteBufInputStream(null, true); + fail(); + } catch (NullPointerException e) { + // Expected + } + + try { + new ByteBufInputStream(null, 0, true); + fail(); + } catch (NullPointerException e) { + // Expected + } + + try { + new ByteBufInputStream(buf.retainedSlice(), -1, true); + } catch 
(IllegalArgumentException e) { + // Expected + } + + try { + new ByteBufInputStream(buf.retainedSlice(), buf.capacity() + 1, true); + } catch (IndexOutOfBoundsException e) { + // Expected + } + + ByteBufInputStream in = new ByteBufInputStream(buf, true); + try { + assertTrue(in.markSupported()); + in.mark(Integer.MAX_VALUE); + + assertEquals(buf.writerIndex(), in.skip(Long.MAX_VALUE)); + assertFalse(buf.isReadable()); + + in.reset(); + assertEquals(0, buf.readerIndex()); + + assertEquals(4, in.skip(4)); + assertEquals(4, buf.readerIndex()); + in.reset(); + + assertTrue(in.readBoolean()); + assertFalse(in.readBoolean()); + assertEquals(42, in.readByte()); + assertEquals(224, in.readUnsignedByte()); + + byte[] tmp = new byte[13]; + in.readFully(tmp); + assertEquals("Hello, World!", new String(tmp, "ISO-8859-1")); + + assertEquals('H', in.readChar()); + assertEquals('e', in.readChar()); + assertEquals('l', in.readChar()); + assertEquals('l', in.readChar()); + assertEquals('o', in.readChar()); + assertEquals(',', in.readChar()); + assertEquals(' ', in.readChar()); + assertEquals('W', in.readChar()); + assertEquals('o', in.readChar()); + assertEquals('r', in.readChar()); + assertEquals('l', in.readChar()); + assertEquals('d', in.readChar()); + assertEquals('!', in.readChar()); + + assertEquals(42.0, in.readDouble(), 0.0); + assertEquals(42.0f, in.readFloat(), 0.0); + assertEquals(42, in.readInt()); + assertEquals(42, in.readLong()); + assertEquals(42, in.readShort()); + assertEquals(49152, in.readUnsignedShort()); + + assertEquals("Hello, World!", in.readUTF()); + assertEquals("The first line", in.readLine()); + assertEquals("", in.readLine()); + + assertEquals(4, in.read(tmp)); + assertEquals(1, tmp[0]); + assertEquals(2, tmp[1]); + assertEquals(3, tmp[2]); + assertEquals(4, tmp[3]); + + assertEquals(-1, in.read()); + assertEquals(-1, in.read(tmp)); + + try { + in.readByte(); + fail(); + } catch (EOFException e) { + // Expected + } + + try { + in.readFully(tmp, 0, -1); 
+ fail(); + } catch (IndexOutOfBoundsException e) { + // Expected + } + + try { + in.readFully(tmp); + fail(); + } catch (EOFException e) { + // Expected + } + } finally { + // Ownership was transferred to the ByteBufOutputStream, before we close we must retain the underlying + // buffer. + buf.retain(); + in.close(); + } + + assertEquals(buf.readerIndex(), in.readBytes()); + buf.release(); + } + + @Test + public void testReadLine() throws Exception { + Charset utf8 = Charset.forName("UTF-8"); + ByteBuf buf = Unpooled.buffer(); + ByteBufInputStream in = new ByteBufInputStream(buf, true); + + String s = in.readLine(); + assertNull(s); + in.close(); + + ByteBuf buf2 = Unpooled.buffer(); + int charCount = 7; //total chars in the string below without new line characters + byte[] abc = "\na\n\nb\r\nc\nd\ne".getBytes(utf8); + buf2.writeBytes(abc); + + ByteBufInputStream in2 = new ByteBufInputStream(buf2, true); + in2.mark(charCount); + assertEquals("", in2.readLine()); + assertEquals("a", in2.readLine()); + assertEquals("", in2.readLine()); + assertEquals("b", in2.readLine()); + assertEquals("c", in2.readLine()); + assertEquals("d", in2.readLine()); + assertEquals("e", in2.readLine()); + assertNull(in.readLine()); + + in2.reset(); + int count = 0; + while (in2.readLine() != null) { + ++count; + if (count > charCount) { + fail("readLine() should have returned null"); + } + } + assertEquals(charCount, count); + in2.close(); + } + + @Test + public void testRead() throws Exception { + // case1 + ByteBuf buf = Unpooled.buffer(16); + buf.writeBytes(new byte[]{1, 2, 3, 4, 5, 6}); + + ByteBufInputStream in = new ByteBufInputStream(buf, 3); + + assertEquals(1, in.read()); + assertEquals(2, in.read()); + assertEquals(3, in.read()); + assertEquals(-1, in.read()); + assertEquals(-1, in.read()); + assertEquals(-1, in.read()); + + buf.release(); + in.close(); + + // case2 + ByteBuf buf2 = Unpooled.buffer(16); + buf2.writeBytes(new byte[]{1, 2, 3, 4, 5, 6}); + + ByteBufInputStream in2 
= new ByteBufInputStream(buf2, 4); + + assertEquals(1, in2.read()); + assertEquals(2, in2.read()); + assertEquals(3, in2.read()); + assertEquals(4, in2.read()); + assertNotEquals(5, in2.read()); + assertEquals(-1, in2.read()); + + buf2.release(); + in2.close(); + } + + @Test + public void testReadLineLengthRespected1() throws Exception { + // case1 + ByteBuf buf = Unpooled.buffer(16); + buf.writeBytes(new byte[] { 1, 2, 3, 4, 5, 6 }); + + ByteBufInputStream in = new ByteBufInputStream(buf, 0); + + assertNull(in.readLine()); + buf.release(); + in.close(); + } + + @Test + public void testReadLineLengthRespected2() throws Exception { + ByteBuf buf2 = Unpooled.buffer(16); + buf2.writeBytes(new byte[] { 'A', 'B', '\n', 'C', 'E', 'F'}); + + ByteBufInputStream in2 = new ByteBufInputStream(buf2, 4); + + assertEquals("AB", in2.readLine()); + assertEquals("C", in2.readLine()); + assertNull(in2.readLine()); + buf2.release(); + in2.close(); + } + + @Test + public void testReadByteLengthRespected() throws Exception { + // case1 + ByteBuf buf = Unpooled.buffer(16); + buf.writeBytes(new byte[] { 1, 2, 3, 4, 5, 6 }); + + final ByteBufInputStream in = new ByteBufInputStream(buf, 0); + try { + assertThrows(EOFException.class, new Executable() { + @Override + public void execute() throws IOException { + in.readBoolean(); + } + }); + } finally { + buf.release(); + in.close(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java b/netty-buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java new file mode 100644 index 0000000..a726f8f --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ByteBufUtilTest.java @@ -0,0 +1,1086 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.nio.ByteOrder; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.Random; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static io.netty.buffer.Unpooled.unreleasableBuffer; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class ByteBufUtilTest { + private static final String PARAMETERIZED_NAME = "bufferType = {0}"; + + private enum BufferType { + DIRECT_UNPOOLED, DIRECT_POOLED, HEAP_POOLED, HEAP_UNPOOLED + } + + private ByteBuf buffer(BufferType bufferType, int capacity) { + switch (bufferType) { + 
+ case DIRECT_UNPOOLED: + return Unpooled.directBuffer(capacity); + case HEAP_UNPOOLED: + return Unpooled.buffer(capacity); + case DIRECT_POOLED: + return PooledByteBufAllocator.DEFAULT.directBuffer(capacity); + case HEAP_POOLED: + return PooledByteBufAllocator.DEFAULT.buffer(capacity); + default: + throw new AssertionError("unexpected buffer type: " + bufferType); + } + } + + public static Collection noUnsafe() { + return Arrays.asList(new Object[][] { + { BufferType.DIRECT_POOLED }, + { BufferType.DIRECT_UNPOOLED }, + { BufferType.HEAP_POOLED }, + { BufferType.HEAP_UNPOOLED } + }); + } + + @Test + public void decodeRandomHexBytesWithEvenLength() { + decodeRandomHexBytes(256); + } + + @Test + public void decodeRandomHexBytesWithOddLength() { + decodeRandomHexBytes(257); + } + + private static void decodeRandomHexBytes(int len) { + byte[] b = new byte[len]; + Random rand = new Random(); + rand.nextBytes(b); + String hexDump = ByteBufUtil.hexDump(b); + for (int i = 0; i <= len; i++) { // going over sub-strings of various lengths including empty byte[]. 
+ byte[] b2 = Arrays.copyOfRange(b, i, b.length); + byte[] decodedBytes = ByteBufUtil.decodeHexDump(hexDump, i * 2, (len - i) * 2); + assertArrayEquals(b2, decodedBytes); + } + } + + @Test + public void decodeHexDumpWithOddLength() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + ByteBufUtil.decodeHexDump("abc"); + } + }); + } + + @Test + public void decodeHexDumpWithInvalidChar() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + ByteBufUtil.decodeHexDump("fg"); + } + }); + } + + @Test + public void testIndexOf() { + ByteBuf haystack = Unpooled.copiedBuffer("abc123", CharsetUtil.UTF_8); + assertEquals(0, ByteBufUtil.indexOf(Unpooled.copiedBuffer("a", CharsetUtil.UTF_8), haystack)); + assertEquals(1, ByteBufUtil.indexOf(Unpooled.copiedBuffer("bc".getBytes(CharsetUtil.UTF_8)), haystack)); + assertEquals(2, ByteBufUtil.indexOf(Unpooled.copiedBuffer("c".getBytes(CharsetUtil.UTF_8)), haystack)); + assertEquals(0, ByteBufUtil.indexOf(Unpooled.copiedBuffer("abc12".getBytes(CharsetUtil.UTF_8)), haystack)); + assertEquals(-1, ByteBufUtil.indexOf(Unpooled.copiedBuffer("abcdef".getBytes(CharsetUtil.UTF_8)), haystack)); + assertEquals(-1, ByteBufUtil.indexOf(Unpooled.copiedBuffer("abc12x".getBytes(CharsetUtil.UTF_8)), haystack)); + assertEquals(-1, ByteBufUtil.indexOf(Unpooled.copiedBuffer("abc123def".getBytes(CharsetUtil.UTF_8)), haystack)); + + final ByteBuf needle = Unpooled.copiedBuffer("abc12", CharsetUtil.UTF_8); + haystack.readerIndex(1); + needle.readerIndex(1); + assertEquals(1, ByteBufUtil.indexOf(needle, haystack)); + haystack.readerIndex(2); + needle.readerIndex(3); + assertEquals(3, ByteBufUtil.indexOf(needle, haystack)); + haystack.readerIndex(1); + needle.readerIndex(2); + assertEquals(2, ByteBufUtil.indexOf(needle, haystack)); + haystack.release(); + + haystack = new WrappedByteBuf(Unpooled.copiedBuffer("abc123", 
CharsetUtil.UTF_8)); + assertEquals(0, ByteBufUtil.indexOf(Unpooled.copiedBuffer("a", CharsetUtil.UTF_8), haystack)); + assertEquals(1, ByteBufUtil.indexOf(Unpooled.copiedBuffer("bc".getBytes(CharsetUtil.UTF_8)), haystack)); + assertEquals(-1, ByteBufUtil.indexOf(Unpooled.copiedBuffer("abcdef".getBytes(CharsetUtil.UTF_8)), haystack)); + haystack.release(); + + haystack = Unpooled.copiedBuffer("123aab123", CharsetUtil.UTF_8); + assertEquals(3, ByteBufUtil.indexOf(Unpooled.copiedBuffer("aab", CharsetUtil.UTF_8), haystack)); + haystack.release(); + needle.release(); + } + + @Test + public void equalsBufferSubsections() { + byte[] b1 = new byte[128]; + byte[] b2 = new byte[256]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length); + assertTrue(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, length)); + } + + private static int random(Random r, int min, int max) { + return r.nextInt((max - min) + 1) + min; + } + + @Test + public void notEqualsBufferSubsections() { + byte[] b1 = new byte[50]; + byte[] b2 = new byte[256]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length); + // Randomly pick an index in the range that will be compared and make the value at that index differ between + // the 2 arrays. 
+ int diffIndex = random(rand, iB1, iB1 + length - 1); + ++b1[diffIndex]; + assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, length)); + } + + @Test + public void notEqualsBufferOverflow() { + byte[] b1 = new byte[8]; + byte[] b2 = new byte[16]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length - 1); + assertFalse(ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, + Math.max(b1.length, b2.length) * 2)); + } + + @Test + public void notEqualsBufferUnderflow() { + final byte[] b1 = new byte[8]; + final byte[] b2 = new byte[16]; + Random rand = new Random(); + rand.nextBytes(b1); + rand.nextBytes(b2); + final int iB1 = b1.length / 2; + final int iB2 = iB1 + b1.length; + final int length = b1.length - iB1; + System.arraycopy(b1, iB1, b2, iB2, length - 1); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + ByteBufUtil.equals(Unpooled.wrappedBuffer(b1), iB1, Unpooled.wrappedBuffer(b2), iB2, -1); + } + }); + } + + @SuppressWarnings("deprecation") + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void writeShortBE(BufferType bufferType) { + int expected = 0x1234; + + ByteBuf buf = buffer(bufferType, 2).order(ByteOrder.BIG_ENDIAN); + ByteBufUtil.writeShortBE(buf, expected); + assertEquals(expected, buf.readShort()); + buf.resetReaderIndex(); + assertEquals(ByteBufUtil.swapShort((short) expected), buf.readShortLE()); + buf.release(); + + buf = buffer(bufferType, 2).order(ByteOrder.LITTLE_ENDIAN); + ByteBufUtil.writeShortBE(buf, expected); + assertEquals(ByteBufUtil.swapShort((short) expected), buf.readShortLE()); + buf.resetReaderIndex(); + assertEquals(ByteBufUtil.swapShort((short) expected), buf.readShort()); + buf.release(); + } + + 
@SuppressWarnings("deprecation") + @Test + public void setShortBE() { + int shortValue = 0x1234; + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[2]).order(ByteOrder.BIG_ENDIAN); + ByteBufUtil.setShortBE(buf, 0, shortValue); + assertEquals(shortValue, buf.readShort()); + buf.resetReaderIndex(); + assertEquals(ByteBufUtil.swapShort((short) shortValue), buf.readShortLE()); + buf.release(); + + buf = Unpooled.wrappedBuffer(new byte[2]).order(ByteOrder.LITTLE_ENDIAN); + ByteBufUtil.setShortBE(buf, 0, shortValue); + assertEquals(ByteBufUtil.swapShort((short) shortValue), buf.readShortLE()); + buf.resetReaderIndex(); + assertEquals(ByteBufUtil.swapShort((short) shortValue), buf.readShort()); + buf.release(); + } + + @SuppressWarnings("deprecation") + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void writeMediumBE(BufferType bufferType) { + int mediumValue = 0x123456; + + ByteBuf buf = buffer(bufferType, 4).order(ByteOrder.BIG_ENDIAN); + ByteBufUtil.writeMediumBE(buf, mediumValue); + assertEquals(mediumValue, buf.readMedium()); + buf.resetReaderIndex(); + assertEquals(ByteBufUtil.swapMedium(mediumValue), buf.readMediumLE()); + buf.release(); + + buf = buffer(bufferType, 4).order(ByteOrder.LITTLE_ENDIAN); + ByteBufUtil.writeMediumBE(buf, mediumValue); + assertEquals(ByteBufUtil.swapMedium(mediumValue), buf.readMediumLE()); + buf.resetReaderIndex(); + assertEquals(ByteBufUtil.swapMedium(mediumValue), buf.readMedium()); + buf.release(); + } + + @SuppressWarnings("deprecation") + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void readUnsignedShortBE(BufferType bufferType) { + int shortValue = 0x1234; // unsigned short + int swappedShortValue = 0x3412; // swapped version of the value above + + ByteBuf buf = buffer(bufferType, 2).order(ByteOrder.BIG_ENDIAN); + buf.writeShort(shortValue); + assertEquals(shortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.clear(); + buf.writeShortLE(shortValue); + 
assertEquals(swappedShortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.release(); + + buf = buffer(bufferType, 2).order(ByteOrder.LITTLE_ENDIAN); + buf.writeShort(shortValue); + assertEquals(swappedShortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.clear(); + buf.writeShortLE(shortValue); + assertEquals(swappedShortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.release(); + + shortValue = 0xfedc; // unsigned short + swappedShortValue = 0xdcfe; // swapped version of the value above + + buf = buffer(bufferType, 2).order(ByteOrder.BIG_ENDIAN); + buf.writeShort(shortValue); + assertEquals(shortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.clear(); + buf.writeShortLE(shortValue); + assertEquals(swappedShortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.release(); + + buf = buffer(bufferType, 2).order(ByteOrder.LITTLE_ENDIAN); + buf.writeShort(shortValue); + assertEquals(swappedShortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.clear(); + buf.writeShortLE(shortValue); + assertEquals(swappedShortValue, ByteBufUtil.readUnsignedShortBE(buf)); + buf.release(); + } + + @SuppressWarnings("deprecation") + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void readIntBE(BufferType bufferType) { + int intValue = 0x12345678; + + ByteBuf buf = buffer(bufferType, 4).order(ByteOrder.BIG_ENDIAN); + buf.writeInt(intValue); + assertEquals(intValue, ByteBufUtil.readIntBE(buf)); + buf.clear(); + buf.writeIntLE(intValue); + assertEquals(ByteBufUtil.swapInt(intValue), ByteBufUtil.readIntBE(buf)); + buf.release(); + + buf = buffer(bufferType, 4).order(ByteOrder.LITTLE_ENDIAN); + buf.writeInt(intValue); + assertEquals(ByteBufUtil.swapInt(intValue), ByteBufUtil.readIntBE(buf)); + buf.clear(); + buf.writeIntLE(intValue); + assertEquals(ByteBufUtil.swapInt(intValue), ByteBufUtil.readIntBE(buf)); + buf.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void 
testWriteUsAscii(BufferType bufferType) { + String usAscii = "NettyRocks"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.getBytes(CharsetUtil.US_ASCII)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeAscii(buf2, usAscii); + + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUsAsciiSwapped(BufferType bufferType) { + String usAscii = "NettyRocks"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.getBytes(CharsetUtil.US_ASCII)); + SwappedByteBuf buf2 = new SwappedByteBuf(buffer(bufferType, 16)); + ByteBufUtil.writeAscii(buf2, usAscii); + + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUsAsciiWrapped(BufferType bufferType) { + String usAscii = "NettyRocks"; + ByteBuf buf = unreleasableBuffer(buffer(bufferType, 16)); + assertWrapped(buf); + buf.writeBytes(usAscii.getBytes(CharsetUtil.US_ASCII)); + ByteBuf buf2 = unreleasableBuffer(buffer(bufferType, 16)); + assertWrapped(buf2); + ByteBufUtil.writeAscii(buf2, usAscii); + + assertEquals(buf, buf2); + + buf.unwrap().release(); + buf2.unwrap().release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUsAsciiComposite(BufferType bufferType) { + String usAscii = "NettyRocks"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.getBytes(CharsetUtil.US_ASCII)); + ByteBuf buf2 = Unpooled.compositeBuffer().addComponent( + buffer(bufferType, 8)).addComponent(buffer(bufferType, 24)); + // write some byte so we start writing with an offset. + buf2.writeByte(1); + ByteBufUtil.writeAscii(buf2, usAscii); + + // Skip the previously written byte. 
+ assertEquals(buf, buf2.skipBytes(1)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUsAsciiCompositeWrapped(BufferType bufferType) { + String usAscii = "NettyRocks"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.getBytes(CharsetUtil.US_ASCII)); + ByteBuf buf2 = new WrappedCompositeByteBuf(Unpooled.compositeBuffer().addComponent( + buffer(bufferType, 8)).addComponent(buffer(bufferType, 24))); + // write some byte so we start writing with an offset. + buf2.writeByte(1); + ByteBufUtil.writeAscii(buf2, usAscii); + + // Skip the previously written byte. + assertEquals(buf, buf2.skipBytes(1)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8(BufferType bufferType) { + String usAscii = "Some UTF-8 like äÄ∏ŒŒ"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, usAscii); + + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8Composite(BufferType bufferType) { + String utf8 = "Some UTF-8 like äÄ∏ŒŒ"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(utf8.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = Unpooled.compositeBuffer().addComponent( + buffer(bufferType, 8)).addComponent(buffer(bufferType, 24)); + // write some byte so we start writing with an offset. + buf2.writeByte(1); + ByteBufUtil.writeUtf8(buf2, utf8); + + // Skip the previously written byte. 
+ assertEquals(buf, buf2.skipBytes(1)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8CompositeWrapped(BufferType bufferType) { + String utf8 = "Some UTF-8 like äÄ∏ŒŒ"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(utf8.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = new WrappedCompositeByteBuf(Unpooled.compositeBuffer().addComponent( + buffer(bufferType, 8)).addComponent(buffer(bufferType, 24))); + // write some byte so we start writing with an offset. + buf2.writeByte(1); + ByteBufUtil.writeUtf8(buf2, utf8); + + // Skip the previously written byte. + assertEquals(buf, buf2.skipBytes(1)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8Surrogates(BufferType bufferType) { + // leading surrogate + trailing surrogate + String surrogateString = new StringBuilder(2) + .append('a') + .append('\uD800') + .append('\uDC00') + .append('b') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidOnlyTrailingSurrogate(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('a') + .append('\uDC00') + .append('b') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + 
buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidOnlyLeadingSurrogate(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('a') + .append('\uD800') + .append('b') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidSurrogatesSwitched(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('a') + .append('\uDC00') + .append('\uD800') + .append('b') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidTwoLeadingSurrogates(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('a') + .append('\uD800') + .append('\uD800') + .append('b') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void 
testWriteUtf8InvalidTwoTrailingSurrogates(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('a') + .append('\uDC00') + .append('\uDC00') + .append('b') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidEndOnLeadingSurrogate(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('\uD800') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidEndOnTrailingSurrogate(BufferType bufferType) { + String surrogateString = new StringBuilder(2) + .append('\uDC00') + .toString(); + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(surrogateString.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, surrogateString); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), ByteBufUtil.utf8Bytes(surrogateString)); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUsAsciiString(BufferType bufferType) { + AsciiString usAscii = new AsciiString("NettyRocks"); + int expectedCapacity = usAscii.length(); + ByteBuf buf = buffer(bufferType, expectedCapacity); + 
buf.writeBytes(usAscii.toString().getBytes(CharsetUtil.US_ASCII)); + ByteBuf buf2 = buffer(bufferType, expectedCapacity); + ByteBufUtil.writeAscii(buf2, usAscii); + + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8Wrapped(BufferType bufferType) { + String usAscii = "Some UTF-8 like äÄ∏ŒŒ"; + ByteBuf buf = unreleasableBuffer(buffer(bufferType, 16)); + assertWrapped(buf); + buf.writeBytes(usAscii.getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = unreleasableBuffer(buffer(bufferType, 16)); + assertWrapped(buf2); + ByteBufUtil.writeUtf8(buf2, usAscii); + + assertEquals(buf, buf2); + + buf.unwrap().release(); + buf2.unwrap().release(); + } + + private static void assertWrapped(ByteBuf buf) { + assertThat(buf, instanceOf(WrappedByteBuf.class)); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8Subsequence(BufferType bufferType) { + String usAscii = "Some UTF-8 like äÄ∏ŒŒ"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.substring(5, 18).getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, usAscii, 5, 18); + + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8SubsequenceSplitSurrogate(BufferType bufferType) { + String usAscii = "\uD800\uDC00"; // surrogate pair: one code point, two chars + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.substring(0, 1).getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + ByteBufUtil.writeUtf8(buf2, usAscii, 0, 1); + + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testReserveAndWriteUtf8Subsequence(BufferType bufferType) { + String 
usAscii = "Some UTF-8 like äÄ∏ŒŒ"; + ByteBuf buf = buffer(bufferType, 16); + buf.writeBytes(usAscii.substring(5, 18).getBytes(CharsetUtil.UTF_8)); + ByteBuf buf2 = buffer(bufferType, 16); + int count = ByteBufUtil.reserveAndWriteUtf8(buf2, usAscii, 5, 18, 16); + + assertEquals(buf, buf2); + assertEquals(buf.readableBytes(), count); + + buf.release(); + buf2.release(); + } + + @Test + public void testUtf8BytesSubsequence() { + String usAscii = "Some UTF-8 like äÄ∏ŒŒ"; + assertEquals(usAscii.substring(5, 18).getBytes(CharsetUtil.UTF_8).length, + ByteBufUtil.utf8Bytes(usAscii, 5, 18)); + } + + private static final int[][] INVALID_RANGES = new int[][] { + { -1, 5 }, { 5, 30 }, { 10, 5 } + }; + + interface TestMethod { + int invoke(Object... args); + } + + private void testInvalidSubsequences(BufferType bufferType, TestMethod method) { + for (int [] range : INVALID_RANGES) { + ByteBuf buf = buffer(bufferType, 16); + try { + method.invoke(buf, "Some UTF-8 like äÄ∏ŒŒ", range[0], range[1]); + fail("Did not throw IndexOutOfBoundsException for range (" + range[0] + ", " + range[1] + ")"); + } catch (IndexOutOfBoundsException iiobe) { + // expected + } finally { + assertFalse(buf.isReadable()); + buf.release(); + } + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testWriteUtf8InvalidSubsequences(BufferType bufferType) { + testInvalidSubsequences(bufferType, new TestMethod() { + @Override + public int invoke(Object... args) { + return ByteBufUtil.writeUtf8((ByteBuf) args[0], (String) args[1], + (Integer) args[2], (Integer) args[3]); + } + }); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testReserveAndWriteUtf8InvalidSubsequences(BufferType bufferType) { + testInvalidSubsequences(bufferType, new TestMethod() { + @Override + public int invoke(Object... 
args) { + return ByteBufUtil.reserveAndWriteUtf8((ByteBuf) args[0], (String) args[1], + (Integer) args[2], (Integer) args[3], 32); + } + }); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testUtf8BytesInvalidSubsequences(BufferType bufferType) { + testInvalidSubsequences(bufferType, + new TestMethod() { + @Override + public int invoke(Object... args) { + return ByteBufUtil.utf8Bytes((String) args[1], (Integer) args[2], (Integer) args[3]); + } + }); + } + + @Test + public void testDecodeUsAscii() { + testDecodeString("This is a test", CharsetUtil.US_ASCII); + } + + @Test + public void testDecodeUtf8() { + testDecodeString("Some UTF-8 like äÄ∏ŒŒ", CharsetUtil.UTF_8); + } + + private static void testDecodeString(String text, Charset charset) { + ByteBuf buffer = Unpooled.copiedBuffer(text, charset); + assertEquals(text, ByteBufUtil.decodeString(buffer, 0, buffer.readableBytes(), charset)); + buffer.release(); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testToStringDoesNotThrowIndexOutOfBounds(BufferType bufferType) { + CompositeByteBuf buffer = Unpooled.compositeBuffer(); + try { + byte[] bytes = "1234".getBytes(CharsetUtil.UTF_8); + buffer.addComponent(buffer(bufferType, bytes.length).writeBytes(bytes)); + buffer.addComponent(buffer(bufferType, bytes.length).writeBytes(bytes)); + assertEquals("1234", buffer.toString(bytes.length, bytes.length, CharsetUtil.UTF_8)); + } finally { + buffer.release(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testIsTextWithUtf8(BufferType bufferType) { + byte[][] validUtf8Bytes = { + "netty".getBytes(CharsetUtil.UTF_8), + {(byte) 0x24}, + {(byte) 0xC2, (byte) 0xA2}, + {(byte) 0xE2, (byte) 0x82, (byte) 0xAC}, + {(byte) 0xF0, (byte) 0x90, (byte) 0x8D, (byte) 0x88}, + {(byte) 0x24, + (byte) 0xC2, (byte) 0xA2, + (byte) 0xE2, (byte) 0x82, (byte) 0xAC, + (byte) 0xF0, (byte) 0x90, (byte) 
0x8D, (byte) 0x88} // multiple characters + }; + for (byte[] bytes : validUtf8Bytes) { + assertIsText(bufferType, bytes, true, CharsetUtil.UTF_8); + } + byte[][] invalidUtf8Bytes = { + {(byte) 0x80}, + {(byte) 0xF0, (byte) 0x82, (byte) 0x82, (byte) 0xAC}, // Overlong encodings + {(byte) 0xC2}, // not enough bytes + {(byte) 0xE2, (byte) 0x82}, // not enough bytes + {(byte) 0xF0, (byte) 0x90, (byte) 0x8D}, // not enough bytes + {(byte) 0xC2, (byte) 0xC0}, // not correct bytes + {(byte) 0xE2, (byte) 0x82, (byte) 0xC0}, // not correct bytes + {(byte) 0xF0, (byte) 0x90, (byte) 0x8D, (byte) 0xC0}, // not correct bytes + {(byte) 0xC1, (byte) 0x80}, // out of lower bound + {(byte) 0xE0, (byte) 0x80, (byte) 0x80}, // out of lower bound + {(byte) 0xED, (byte) 0xAF, (byte) 0x80} // out of upper bound + }; + for (byte[] bytes : invalidUtf8Bytes) { + assertIsText(bufferType, bytes, false, CharsetUtil.UTF_8); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testIsTextWithoutOptimization(BufferType bufferType) { + byte[] validBytes = {(byte) 0x01, (byte) 0xD8, (byte) 0x37, (byte) 0xDC}; + byte[] invalidBytes = {(byte) 0x01, (byte) 0xD8}; + + assertIsText(bufferType, validBytes, true, CharsetUtil.UTF_16LE); + assertIsText(bufferType, invalidBytes, false, CharsetUtil.UTF_16LE); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testIsTextWithAscii(BufferType bufferType) { + byte[] validBytes = {(byte) 0x00, (byte) 0x01, (byte) 0x37, (byte) 0x7F}; + byte[] invalidBytes = {(byte) 0x80, (byte) 0xFF}; + + assertIsText(bufferType, validBytes, true, CharsetUtil.US_ASCII); + assertIsText(bufferType, invalidBytes, false, CharsetUtil.US_ASCII); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testIsTextWithInvalidIndexAndLength(BufferType bufferType) { + ByteBuf buffer = buffer(bufferType, 4); + try { + buffer.writeBytes(new byte[4]); + int[][] 
validIndexLengthPairs = { + {4, 0}, + {0, 4}, + {1, 3}, + }; + for (int[] pair : validIndexLengthPairs) { + assertTrue(ByteBufUtil.isText(buffer, pair[0], pair[1], CharsetUtil.US_ASCII)); + } + int[][] invalidIndexLengthPairs = { + {4, 1}, + {-1, 2}, + {3, -1}, + {3, -2}, + {5, 0}, + {1, 5}, + }; + for (int[] pair : invalidIndexLengthPairs) { + try { + ByteBufUtil.isText(buffer, pair[0], pair[1], CharsetUtil.US_ASCII); + fail("Expected IndexOutOfBoundsException"); + } catch (IndexOutOfBoundsException e) { + // expected + } + } + } finally { + buffer.release(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testUtf8Bytes(BufferType bufferType) { + final String s = "Some UTF-8 like äÄ∏ŒŒ"; + checkUtf8Bytes(bufferType, s); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testUtf8BytesWithSurrogates(BufferType bufferType) { + final String s = "a\uD800\uDC00b"; + checkUtf8Bytes(bufferType, s); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testUtf8BytesWithNonSurrogates3Bytes(BufferType bufferType) { + final String s = "a\uE000b"; + checkUtf8Bytes(bufferType, s); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testUtf8BytesWithNonSurrogatesNonAscii(BufferType bufferType) { + final char nonAscii = (char) 0x81; + final String s = "a" + nonAscii + "b"; + checkUtf8Bytes(bufferType, s); + } + + private void checkUtf8Bytes(BufferType bufferType, final CharSequence charSequence) { + final ByteBuf buf = buffer(bufferType, ByteBufUtil.utf8MaxBytes(charSequence)); + try { + final int writtenBytes = ByteBufUtil.writeUtf8(buf, charSequence); + final int utf8Bytes = ByteBufUtil.utf8Bytes(charSequence); + assertEquals(writtenBytes, utf8Bytes); + } finally { + buf.release(); + } + } + + private void assertIsText(BufferType bufferType, byte[] bytes, boolean expected, Charset charset) { + 
ByteBuf buffer = buffer(bufferType, bytes.length); + try { + buffer.writeBytes(bytes); + assertEquals(expected, ByteBufUtil.isText(buffer, charset)); + } finally { + buffer.release(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testIsTextMultiThreaded(BufferType bufferType) throws Throwable { + assumeTrue(bufferType == BufferType.HEAP_UNPOOLED); + final ByteBuf buffer = Unpooled.copiedBuffer("Hello, World!", CharsetUtil.ISO_8859_1); + + try { + final AtomicInteger counter = new AtomicInteger(60000); + final AtomicReference errorRef = new AtomicReference(); + List threads = new ArrayList(); + for (int i = 0; i < 10; i++) { + Thread thread = new Thread(new Runnable() { + @Override + public void run() { + try { + while (errorRef.get() == null && counter.decrementAndGet() > 0) { + assertTrue(ByteBufUtil.isText(buffer, CharsetUtil.ISO_8859_1)); + } + } catch (Throwable cause) { + errorRef.compareAndSet(null, cause); + } + } + }); + threads.add(thread); + } + for (Thread thread : threads) { + thread.start(); + } + + for (Thread thread : threads) { + thread.join(); + } + + Throwable error = errorRef.get(); + if (error != null) { + throw error; + } + } finally { + buffer.release(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testGetBytes(BufferType bufferType) { + final ByteBuf buf = buffer(bufferType, 4); + try { + checkGetBytes(buf); + } finally { + buf.release(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testGetBytesHeapWithNonZeroArrayOffset(BufferType bufferType) { + assumeTrue(bufferType == BufferType.HEAP_UNPOOLED); + final ByteBuf buf = buffer(bufferType, 5); + try { + buf.setByte(0, 0x05); + + final ByteBuf slice = buf.slice(1, 4); + slice.writerIndex(0); + + assertTrue(slice.hasArray()); + assertThat(slice.arrayOffset(), is(1)); + assertThat(slice.array().length, is(buf.capacity())); + + 
checkGetBytes(slice); + } finally { + buf.release(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("noUnsafe") + public void testGetBytesHeapWithArrayLengthGreaterThanCapacity(BufferType bufferType) { + assumeTrue(bufferType == BufferType.HEAP_UNPOOLED); + final ByteBuf buf = buffer(bufferType, 5); + try { + buf.setByte(4, 0x05); + + final ByteBuf slice = buf.slice(0, 4); + slice.writerIndex(0); + + assertTrue(slice.hasArray()); + assertThat(slice.arrayOffset(), is(0)); + assertThat(slice.array().length, greaterThan(slice.capacity())); + + checkGetBytes(slice); + } finally { + buf.release(); + } + } + + private static void checkGetBytes(final ByteBuf buf) { + buf.writeInt(0x01020304); + + byte[] expected = { 0x01, 0x02, 0x03, 0x04 }; + assertArrayEquals(expected, ByteBufUtil.getBytes(buf)); + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 0, buf.readableBytes(), false)); + + expected = new byte[] { 0x01, 0x02, 0x03 }; + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 0, 3)); + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 0, 3, false)); + + expected = new byte[] { 0x02, 0x03, 0x04 }; + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 1, 3)); + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 1, 3, false)); + + expected = new byte[] { 0x02, 0x03 }; + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 1, 2)); + assertArrayEquals(expected, ByteBufUtil.getBytes(buf, 1, 2, false)); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ByteProcessorTest.java b/netty-buffer/src/test/java/io/netty/buffer/ByteProcessorTest.java new file mode 100644 index 0000000..01c7b7f --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ByteProcessorTest.java @@ -0,0 +1,69 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import io.netty.util.ByteProcessor; +import io.netty.util.CharsetUtil; + +import org.junit.jupiter.api.Test; + +public class ByteProcessorTest { + @Test + public void testForward() { + final ByteBuf buf = + Unpooled.copiedBuffer("abc\r\n\ndef\r\rghi\n\njkl\0\0mno \t\tx", CharsetUtil.ISO_8859_1); + final int length = buf.readableBytes(); + + assertEquals(3, buf.forEachByte(0, length, ByteProcessor.FIND_CRLF)); + assertEquals(6, buf.forEachByte(3, length - 3, ByteProcessor.FIND_NON_CRLF)); + assertEquals(9, buf.forEachByte(6, length - 6, ByteProcessor.FIND_CR)); + assertEquals(11, buf.forEachByte(9, length - 9, ByteProcessor.FIND_NON_CR)); + assertEquals(14, buf.forEachByte(11, length - 11, ByteProcessor.FIND_LF)); + assertEquals(16, buf.forEachByte(14, length - 14, ByteProcessor.FIND_NON_LF)); + assertEquals(19, buf.forEachByte(16, length - 16, ByteProcessor.FIND_NUL)); + assertEquals(21, buf.forEachByte(19, length - 19, ByteProcessor.FIND_NON_NUL)); + assertEquals(24, buf.forEachByte(19, length - 19, ByteProcessor.FIND_ASCII_SPACE)); + assertEquals(24, buf.forEachByte(21, length - 21, ByteProcessor.FIND_LINEAR_WHITESPACE)); + assertEquals(28, buf.forEachByte(24, length - 24, ByteProcessor.FIND_NON_LINEAR_WHITESPACE)); + assertEquals(-1, buf.forEachByte(28, length - 28, ByteProcessor.FIND_LINEAR_WHITESPACE)); + + buf.release(); + } + + @Test + public void testBackward() { + final ByteBuf buf = + Unpooled.copiedBuffer("abc\r\n\ndef\r\rghi\n\njkl\0\0mno 
\t\tx", CharsetUtil.ISO_8859_1); + final int length = buf.readableBytes(); + + assertEquals(27, buf.forEachByteDesc(0, length, ByteProcessor.FIND_LINEAR_WHITESPACE)); + assertEquals(25, buf.forEachByteDesc(0, length, ByteProcessor.FIND_ASCII_SPACE)); + assertEquals(23, buf.forEachByteDesc(0, 28, ByteProcessor.FIND_NON_LINEAR_WHITESPACE)); + assertEquals(20, buf.forEachByteDesc(0, 24, ByteProcessor.FIND_NUL)); + assertEquals(18, buf.forEachByteDesc(0, 21, ByteProcessor.FIND_NON_NUL)); + assertEquals(15, buf.forEachByteDesc(0, 19, ByteProcessor.FIND_LF)); + assertEquals(13, buf.forEachByteDesc(0, 16, ByteProcessor.FIND_NON_LF)); + assertEquals(10, buf.forEachByteDesc(0, 14, ByteProcessor.FIND_CR)); + assertEquals(8, buf.forEachByteDesc(0, 11, ByteProcessor.FIND_NON_CR)); + assertEquals(5, buf.forEachByteDesc(0, 9, ByteProcessor.FIND_CRLF)); + assertEquals(2, buf.forEachByteDesc(0, 6, ByteProcessor.FIND_NON_CRLF)); + assertEquals(-1, buf.forEachByteDesc(0, 3, ByteProcessor.FIND_CRLF)); + + buf.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ConsolidationTest.java b/netty-buffer/src/test/java/io/netty/buffer/ConsolidationTest.java new file mode 100644 index 0000000..6d304cc --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ConsolidationTest.java @@ -0,0 +1,77 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static io.netty.buffer.Unpooled.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Tests buffer consolidation + */ +public class ConsolidationTest { + @Test + public void shouldWrapInSequence() { + ByteBuf currentBuffer = wrappedBuffer(wrappedBuffer("a".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("=".getBytes(CharsetUtil.US_ASCII))); + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("1".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("&".getBytes(CharsetUtil.US_ASCII))); + + ByteBuf copy = currentBuffer.copy(); + String s = copy.toString(CharsetUtil.US_ASCII); + assertEquals("a=1&", s); + + currentBuffer.release(); + copy.release(); + } + + @Test + public void shouldConsolidationInSequence() { + ByteBuf currentBuffer = wrappedBuffer(wrappedBuffer("a".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("=".getBytes(CharsetUtil.US_ASCII))); + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("1".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("&".getBytes(CharsetUtil.US_ASCII))); + + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("b".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("=".getBytes(CharsetUtil.US_ASCII))); + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("2".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("&".getBytes(CharsetUtil.US_ASCII))); + + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("c".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("=".getBytes(CharsetUtil.US_ASCII))); + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("3".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("&".getBytes(CharsetUtil.US_ASCII))); + + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("d".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("=".getBytes(CharsetUtil.US_ASCII))); + currentBuffer = wrappedBuffer(currentBuffer, 
wrappedBuffer("4".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("&".getBytes(CharsetUtil.US_ASCII))); + + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("e".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("=".getBytes(CharsetUtil.US_ASCII))); + currentBuffer = wrappedBuffer(currentBuffer, wrappedBuffer("5".getBytes(CharsetUtil.US_ASCII)), + wrappedBuffer("&".getBytes(CharsetUtil.US_ASCII))); + + ByteBuf copy = currentBuffer.copy(); + String s = copy.toString(CharsetUtil.US_ASCII); + assertEquals("a=1&b=2&c=3&d=4&e=5&", s); + + currentBuffer.release(); + copy.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/DefaultByteBufHolderTest.java b/netty-buffer/src/test/java/io/netty/buffer/DefaultByteBufHolderTest.java new file mode 100644 index 0000000..2a623aa --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/DefaultByteBufHolderTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DefaultByteBufHolderTest { + + @Test + public void testToString() { + ByteBufHolder holder = new DefaultByteBufHolder(Unpooled.buffer()); + assertEquals(1, holder.refCnt()); + assertNotNull(holder.toString()); + assertTrue(holder.release()); + assertNotNull(holder.toString()); + } + + @Test + public void testEqualsAndHashCode() { + ByteBufHolder holder = new DefaultByteBufHolder(Unpooled.EMPTY_BUFFER); + ByteBufHolder copy = holder.copy(); + try { + assertEquals(holder, copy); + assertEquals(holder.hashCode(), copy.hashCode()); + } finally { + holder.release(); + copy.release(); + } + } + + @SuppressWarnings("SimplifiableJUnitAssertion") + @Test + public void testDifferentClassesAreNotEqual() { + // all objects here have EMPTY_BUFFER data but are instances of different classes + // so we want to check that none of them are equal to another. 
+ ByteBufHolder dflt = new DefaultByteBufHolder(Unpooled.EMPTY_BUFFER); + ByteBufHolder other = new OtherByteBufHolder(Unpooled.EMPTY_BUFFER, 123); + ByteBufHolder constant1 = new DefaultByteBufHolder(Unpooled.EMPTY_BUFFER) { + // intentionally empty + }; + ByteBufHolder constant2 = new DefaultByteBufHolder(Unpooled.EMPTY_BUFFER) { + // intentionally empty + }; + try { + // not using 'assertNotEquals' to be explicit about which object we are calling .equals() on + assertFalse(dflt.equals(other)); + assertFalse(dflt.equals(constant1)); + assertFalse(constant1.equals(dflt)); + assertFalse(constant1.equals(other)); + assertFalse(constant1.equals(constant2)); + } finally { + dflt.release(); + other.release(); + constant1.release(); + constant2.release(); + } + } + + private static class OtherByteBufHolder extends DefaultByteBufHolder { + + private final int extraField; + + OtherByteBufHolder(final ByteBuf data, final int extraField) { + super(data); + this.extraField = extraField; + } + + @Override + public boolean equals(final Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + if (!super.equals(o)) { + return false; + } + final OtherByteBufHolder that = (OtherByteBufHolder) o; + return extraField == that.extraField; + } + + @Override + public int hashCode() { + int result = super.hashCode(); + result = 31 * result + extraField; + return result; + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/DuplicatedByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/DuplicatedByteBufTest.java new file mode 100644 index 0000000..201d1ed --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/DuplicatedByteBufTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Tests duplicated channel buffers + */ +public class DuplicatedByteBufTest extends AbstractByteBufTest { + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + ByteBuf wrapped = Unpooled.buffer(length, maxCapacity); + ByteBuf buffer = new DuplicatedByteBuf(wrapped); + assertEquals(wrapped.writerIndex(), buffer.writerIndex()); + assertEquals(wrapped.readerIndex(), buffer.readerIndex()); + return buffer; + } + + @Test + public void testIsContiguous() { + ByteBuf buf = newBuffer(4); + assertEquals(buf.unwrap().isContiguous(), buf.isContiguous()); + buf.release(); + } + + @Test + public void shouldNotAllowNullInConstructor() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DuplicatedByteBuf(null); + } + }); + } + + // See https://github.com/netty/netty/issues/1800 + @Test + public void testIncreaseCapacityWrapped() { + ByteBuf buffer = newBuffer(8); + ByteBuf wrapped = buffer.unwrap(); + wrapped.writeByte(0); + wrapped.readerIndex(wrapped.readerIndex() + 1); + buffer.writerIndex(buffer.writerIndex() + 1); + wrapped.capacity(wrapped.capacity() * 2); + + assertEquals((byte) 0, buffer.readByte()); + } + + @Test + public void testMarksInitialized() { + ByteBuf wrapped = Unpooled.buffer(8); + try { + wrapped.writerIndex(6); + 
wrapped.readerIndex(1); + ByteBuf duplicate = new DuplicatedByteBuf(wrapped); + + // Test writer mark + duplicate.writerIndex(duplicate.writerIndex() + 1); + duplicate.resetWriterIndex(); + assertEquals(wrapped.writerIndex(), duplicate.writerIndex()); + + // Test reader mark + duplicate.readerIndex(duplicate.readerIndex() + 1); + duplicate.resetReaderIndex(); + assertEquals(wrapped.readerIndex(), duplicate.readerIndex()); + } finally { + wrapped.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/EmptyByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/EmptyByteBufTest.java new file mode 100644 index 0000000..3ea8012 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/EmptyByteBufTest.java @@ -0,0 +1,115 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class EmptyByteBufTest { + + @Test + public void testIsContiguous() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + assertTrue(empty.isContiguous()); + } + + @Test + public void testIsWritable() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + assertFalse(empty.isWritable()); + assertFalse(empty.isWritable(1)); + } + + @Test + public void testWriteEmptyByteBuf() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + empty.writeBytes(Unpooled.EMPTY_BUFFER); // Ok + ByteBuf nonEmpty = UnpooledByteBufAllocator.DEFAULT.buffer().writeBoolean(false); + try { + empty.writeBytes(nonEmpty); + fail(); + } catch (IndexOutOfBoundsException ignored) { + // Ignore. 
+ } finally { + nonEmpty.release(); + } + } + + @Test + public void testIsReadable() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + assertFalse(empty.isReadable()); + assertFalse(empty.isReadable(1)); + } + + @Test + public void testArray() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + assertThat(empty.hasArray(), is(true)); + assertThat(empty.array().length, is(0)); + assertThat(empty.arrayOffset(), is(0)); + } + + @Test + public void testNioBuffer() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + assertThat(empty.nioBufferCount(), is(1)); + assertThat(empty.nioBuffer().position(), is(0)); + assertThat(empty.nioBuffer().limit(), is(0)); + assertThat(empty.nioBuffer(), is(sameInstance(empty.nioBuffer()))); + assertThat(empty.nioBuffer(), is(sameInstance(empty.internalNioBuffer(empty.readerIndex(), 0)))); + } + + @Test + public void testMemoryAddress() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + if (empty.hasMemoryAddress()) { + assertThat(empty.memoryAddress(), is(not(0L))); + } else { + try { + empty.memoryAddress(); + fail(); + } catch (UnsupportedOperationException ignored) { + // Ignore. 
+ } + } + } + + @Test + public void consistentEqualsAndHashCodeWithAbstractBytebuf() { + ByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + ByteBuf emptyAbstract = new UnpooledHeapByteBuf(UnpooledByteBufAllocator.DEFAULT, 0, 0); + assertEquals(emptyAbstract, empty); + assertEquals(emptyAbstract.hashCode(), empty.hashCode()); + assertEquals(EmptyByteBuf.EMPTY_BYTE_BUF_HASH_CODE, empty.hashCode()); + assertTrue(emptyAbstract.release()); + assertFalse(empty.release()); + } + + @Test + public void testGetCharSequence() { + EmptyByteBuf empty = new EmptyByteBuf(UnpooledByteBufAllocator.DEFAULT); + assertEquals("", empty.readCharSequence(0, CharsetUtil.US_ASCII)); + } + +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java new file mode 100644 index 0000000..df53c10 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/FixedCompositeByteBufTest.java @@ -0,0 +1,526 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

import org.junit.jupiter.api.Assumptions;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.OutputStream;
import java.nio.ByteBuffer;
import java.nio.ReadOnlyBufferException;
import java.nio.channels.ScatteringByteChannel;
import java.nio.charset.Charset;

import static io.netty.buffer.Unpooled.*;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Tests for {@link FixedCompositeByteBuf}: a read-only composite over a fixed
 * set of component buffers.
 */
public class FixedCompositeByteBufTest {

    private static ByteBuf newBuffer(ByteBuf... buffers) {
        return new FixedCompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, buffers);
    }

    /**
     * Asserts that the given mutating operation fails with
     * {@link ReadOnlyBufferException} and releases {@code buf} afterwards.
     * Shared by all the setter tests below.
     */
    private static void assertReadOnly(ByteBuf buf, Executable write) {
        try {
            assertThrows(ReadOnlyBufferException.class, write);
        } finally {
            buf.release();
        }
    }

    @Test
    public void testSetBoolean() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setBoolean(0, true));
    }

    @Test
    public void testSetByte() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setByte(0, 1));
    }

    @Test
    public void testSetBytesWithByteBuf() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        final ByteBuf src = wrappedBuffer(new byte[4]);
        try {
            assertReadOnly(buf, () -> buf.setBytes(0, src));
        } finally {
            src.release();
        }
    }

    @Test
    public void testSetBytesWithByteBuffer() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setBytes(0, ByteBuffer.wrap(new byte[4])));
    }

    @Test
    public void testSetBytesWithInputStream() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setBytes(0, new ByteArrayInputStream(new byte[4]), 4));
    }

    @Test
    public void testSetBytesWithChannel() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        // The channel is a stub; the call must fail before any read happens.
        assertReadOnly(buf, () -> buf.setBytes(0, new ScatteringByteChannel() {
            @Override
            public long read(ByteBuffer[] dsts, int offset, int length) {
                return 0;
            }

            @Override
            public long read(ByteBuffer[] dsts) {
                return 0;
            }

            @Override
            public int read(ByteBuffer dst) {
                return 0;
            }

            @Override
            public boolean isOpen() {
                return true;
            }

            @Override
            public void close() {
            }
        }, 4));
    }

    @Test
    public void testSetChar() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setChar(0, 'b'));
    }

    @Test
    public void testSetDouble() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setDouble(0, 1));
    }

    @Test
    public void testSetFloat() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setFloat(0, 1));
    }

    @Test
    public void testSetInt() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setInt(0, 1));
    }

    @Test
    public void testSetLong() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setLong(0, 1));
    }

    @Test
    public void testSetMedium() {
        final ByteBuf buf = newBuffer(wrappedBuffer(new byte[8]));
        assertReadOnly(buf, () -> buf.setMedium(0, 1));
    }

    @Test
    public void testGatheringWritesHeap() throws Exception {
        testGatheringWrites(buffer(), buffer());
    }

    @Test
    public void testGatheringWritesDirect() throws Exception {
        testGatheringWrites(directBuffer(), directBuffer());
    }

    @Test
    public void testGatheringWritesMixes() throws Exception {
        testGatheringWrites(buffer(), directBuffer());
    }

    @Test
    public void testGatheringWritesHeapPooled() throws Exception {
        testGatheringWrites(PooledByteBufAllocator.DEFAULT.heapBuffer(),
                PooledByteBufAllocator.DEFAULT.heapBuffer());
    }

    @Test
    public void testGatheringWritesDirectPooled() throws Exception {
        testGatheringWrites(PooledByteBufAllocator.DEFAULT.directBuffer(),
                PooledByteBufAllocator.DEFAULT.directBuffer());
    }

    @Test
    public void testGatheringWritesMixesPooled() throws Exception {
        testGatheringWrites(PooledByteBufAllocator.DEFAULT.heapBuffer(),
                PooledByteBufAllocator.DEFAULT.directBuffer());
    }

    /** Writes the readable region to a gathering channel and compares bytes. */
    private static void testGatheringWrites(ByteBuf buf1, ByteBuf buf2) throws Exception {
        CompositeByteBuf buf = compositeBuffer();
        buf.addComponent(buf1.writeBytes(new byte[]{1, 2}));
        buf.addComponent(buf2.writeBytes(new byte[]{1, 2}));
        buf.writerIndex(3);
        buf.readerIndex(1);

        AbstractByteBufTest.TestGatheringByteChannel channel = new AbstractByteBufTest.TestGatheringByteChannel();
        buf.readBytes(channel, 2);

        byte[] data = new byte[2];
        buf.getBytes(1, data);
        buf.release();

        assertArrayEquals(data, channel.writtenBytes());
    }

    @Test
    public void testGatheringWritesPartialHeap() throws Exception {
        testGatheringWritesPartial(buffer(), buffer());
    }

    @Test
    public void testGatheringWritesPartialDirect() throws Exception {
        testGatheringWritesPartial(directBuffer(), directBuffer());
    }

    @Test
    public void testGatheringWritesPartialMixes() throws Exception {
        testGatheringWritesPartial(buffer(), directBuffer());
    }

    @Test
    public void testGatheringWritesPartialHeapPooled() throws Exception {
        testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer(),
                PooledByteBufAllocator.DEFAULT.heapBuffer());
    }

    @Test
    public void testGatheringWritesPartialDirectPooled() throws Exception {
        testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.directBuffer(),
                PooledByteBufAllocator.DEFAULT.directBuffer());
    }

    @Test
    public void testGatheringWritesPartialMixesPooled() throws Exception {
        testGatheringWritesPartial(PooledByteBufAllocator.DEFAULT.heapBuffer(),
                PooledByteBufAllocator.DEFAULT.directBuffer());
    }

    /**
     * Same as {@link #testGatheringWrites} but through a channel that accepts
     * only one byte at a time, forcing repeated partial writes.
     */
    private static void testGatheringWritesPartial(ByteBuf buf1, ByteBuf buf2) throws Exception {
        buf1.writeBytes(new byte[]{1, 2, 3, 4});
        buf2.writeBytes(new byte[]{1, 2, 3, 4});
        ByteBuf buf = newBuffer(buf1, buf2);
        AbstractByteBufTest.TestGatheringByteChannel channel = new AbstractByteBufTest.TestGatheringByteChannel(1);

        while (buf.isReadable()) {
            buf.readBytes(channel, buf.readableBytes());
        }

        byte[] data = new byte[8];
        buf.getBytes(0, data);
        assertArrayEquals(data, channel.writtenBytes());
        buf.release();
    }

    @Test
    public void testGatheringWritesSingleHeap() throws Exception {
        testGatheringWritesSingleBuf(buffer());
    }

    @Test
    public void testGatheringWritesSingleDirect() throws Exception {
        testGatheringWritesSingleBuf(directBuffer());
    }

    /** Gathering write through a fixed composite holding exactly one component. */
    private static void testGatheringWritesSingleBuf(ByteBuf buf1) throws Exception {
        ByteBuf buf = newBuffer(buf1.writeBytes(new byte[]{1, 2, 3, 4}));
        buf.readerIndex(1);

        AbstractByteBufTest.TestGatheringByteChannel channel = new AbstractByteBufTest.TestGatheringByteChannel();
        buf.readBytes(channel, 2);

        byte[] data = new byte[2];
        buf.getBytes(1, data);
        assertArrayEquals(data, channel.writtenBytes());

        buf.release();
    }

    @Test
    public void testCopyingToOtherBuffer() {
        ByteBuf buf1 = directBuffer(10);
        ByteBuf buf2 = buffer(10);
        ByteBuf buf3 = directBuffer(10);
        buf1.writeBytes("a".getBytes(Charset.defaultCharset()));
        buf2.writeBytes("b".getBytes(Charset.defaultCharset()));
        buf3.writeBytes("c".getBytes(Charset.defaultCharset()));
        ByteBuf composite = unmodifiableBuffer(buf1, buf2, buf3);
        ByteBuf copy = directBuffer(3);
        ByteBuf copy2 = buffer(3);
        copy.setBytes(0, composite, 0, 3);
        copy2.setBytes(0, composite, 0, 3);
        copy.writerIndex(3);
        copy2.writerIndex(3);
        // Direct and heap copies must both match the composite exactly.
        assertEquals(0, ByteBufUtil.compare(copy, composite));
        assertEquals(0, ByteBufUtil.compare(copy2, composite));
        assertEquals(0, ByteBufUtil.compare(copy, copy2));
        copy.release();
        copy2.release();
        composite.release();
    }

    @Test
    public void testCopyingToOutputStream() throws IOException {
        ByteBuf buf1 = directBuffer(10);
        ByteBuf buf2 = buffer(10);
        ByteBuf buf3 = directBuffer(10);
        buf1.writeBytes("a".getBytes(Charset.defaultCharset()));
        buf2.writeBytes("b".getBytes(Charset.defaultCharset()));
        buf3.writeBytes("c".getBytes(Charset.defaultCharset()));
        ByteBuf composite = unmodifiableBuffer(buf1, buf2, buf3);
        ByteBuf copy = directBuffer(3);
        ByteBuf copy2 = buffer(3);
        OutputStream copyStream = new ByteBufOutputStream(copy);
        OutputStream copy2Stream = new ByteBufOutputStream(copy2);
        try {
            composite.getBytes(0, copyStream, 3);
            composite.getBytes(0, copy2Stream, 3);
            assertEquals(0, ByteBufUtil.compare(copy, composite));
            assertEquals(0, ByteBufUtil.compare(copy2, composite));
            assertEquals(0, ByteBufUtil.compare(copy, copy2));
        } finally {
            copy.release();
            copy2.release();
            copyStream.close();
            copy2Stream.close();
            composite.release();
        }
    }

    @Test
    public void testExtractNioBuffers() {
        ByteBuf buf1 = directBuffer(10);
        ByteBuf buf2 = buffer(10);
        ByteBuf buf3 = directBuffer(10);
        buf1.writeBytes("a".getBytes(Charset.defaultCharset()));
        buf2.writeBytes("b".getBytes(Charset.defaultCharset()));
        buf3.writeBytes("c".getBytes(Charset.defaultCharset()));
        ByteBuf composite = unmodifiableBuffer(buf1, buf2, buf3);
        // One NIO buffer per component, each limited to its readable byte.
        ByteBuffer[] byteBuffers = composite.nioBuffers(0, 3);
        assertEquals(3, byteBuffers.length);
        assertEquals(1, byteBuffers[0].limit());
        assertEquals(1, byteBuffers[1].limit());
        assertEquals(1, byteBuffers[2].limit());
        composite.release();
    }

    @Test
    public void testEmptyArray() {
        ByteBuf buf = newBuffer(new ByteBuf[0]);
        buf.release();
    }

    @Test
    public void testHasMemoryAddressWithSingleBuffer() {
        ByteBuf buf1 = directBuffer(10);
        if (!buf1.hasMemoryAddress()) {
            // Environment without memory-address support; nothing to verify.
            buf1.release();
            return;
        }
        ByteBuf buf = newBuffer(buf1);
        assertTrue(buf.hasMemoryAddress());
        assertEquals(buf1.memoryAddress(), buf.memoryAddress());
        buf.release();
    }

    @Test
    public void testHasMemoryAddressWhenEmpty() {
        Assumptions.assumeTrue(EMPTY_BUFFER.hasMemoryAddress());
        ByteBuf buf = newBuffer(new ByteBuf[0]);
        assertTrue(buf.hasMemoryAddress());
        assertEquals(EMPTY_BUFFER.memoryAddress(), buf.memoryAddress());
        buf.release();
    }

    @Test
    public void testHasNoMemoryAddressWhenMultipleBuffers() {
        ByteBuf buf1 = directBuffer(10);
        if (!buf1.hasMemoryAddress()) {
            buf1.release();
            return;
        }

        ByteBuf buf2 = directBuffer(10);
        ByteBuf buf = newBuffer(buf1, buf2);
        // Multiple components cannot expose a single contiguous address.
        assertFalse(buf.hasMemoryAddress());
        try {
            buf.memoryAddress();
            fail();
        } catch (UnsupportedOperationException expected) {
            // expected
        } finally {
            buf.release();
        }
    }

    @Test
    public void testHasArrayWithSingleBuffer() {
        ByteBuf buf1 = buffer(10);
        ByteBuf buf = newBuffer(buf1);
        assertTrue(buf.hasArray());
        assertArrayEquals(buf1.array(), buf.array());
        buf.release();
    }

    @Test
    public void testHasArrayWhenEmptyAndIsDirect() {
        ByteBuf buf = newBuffer(new ByteBuf[0]);
        assertTrue(buf.hasArray());
        assertArrayEquals(EMPTY_BUFFER.array(), buf.array());
        assertEquals(EMPTY_BUFFER.isDirect(), buf.isDirect());
        assertEquals(EMPTY_BUFFER.memoryAddress(), buf.memoryAddress());
        buf.release();
    }

    @Test
    public void testHasNoArrayWhenMultipleBuffers() {
        ByteBuf buf1 = buffer(10);
        ByteBuf buf2 = buffer(10);
        final ByteBuf buf = newBuffer(buf1, buf2);
        // Multiple components cannot expose a single backing array either.
        assertFalse(buf.hasArray());
        try {
            assertThrows(UnsupportedOperationException.class, buf::array);
        } finally {
            buf.release();
        }
    }
}
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;

import io.netty.util.internal.ThreadLocalRandom;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.ListIterator;

import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.jupiter.api.Assertions.*;

class IntPriorityQueueTest {
    @Test
    public void mustThrowWhenAddingNoValue() {
        final IntPriorityQueue pq = new IntPriorityQueue();
        // NO_VALUE is the empty-queue sentinel, so it must be rejected as an element.
        assertThrows(IllegalArgumentException.class, new Executable() {
            @Override
            public void execute() {
                pq.offer(IntPriorityQueue.NO_VALUE);
            }
        });
    }

    @Test
    public void mustReturnValuesInOrder() {
        ThreadLocalRandom tlr = ThreadLocalRandom.current();
        int initialValues = tlr.nextInt(5, 30);
        List<Integer> values = new ArrayList<Integer>();
        for (int i = 0; i < initialValues; i++) {
            values.add(tlr.nextInt(0, Integer.MAX_VALUE));
        }
        IntPriorityQueue pq = new IntPriorityQueue();
        assertTrue(pq.isEmpty());
        for (Integer value : values) {
            pq.offer(value);
        }
        // Polling must return the smallest remaining value each time.
        Collections.sort(values);
        int valuesToRemove = initialValues / 2;
        ListIterator<Integer> itr = values.listIterator();
        for (int i = 0; i < valuesToRemove; i++) {
            assertTrue(itr.hasNext());
            assertThat(pq.poll()).isEqualTo(itr.next());
            itr.remove();
        }
        // Interleave more offers and drain completely.
        int moreValues = tlr.nextInt(5, 30);
        for (int i = 0; i < moreValues; i++) {
            int value = tlr.nextInt(0, Integer.MAX_VALUE);
            pq.offer(value);
            values.add(value);
        }
        Collections.sort(values);
        itr = values.listIterator();
        while (itr.hasNext()) {
            assertThat(pq.poll()).isEqualTo(itr.next());
        }
        assertTrue(pq.isEmpty());
        assertThat(pq.poll()).isEqualTo(IntPriorityQueue.NO_VALUE);
    }

    @Test
    public void internalRemoveOfAllElements() {
        ThreadLocalRandom tlr = ThreadLocalRandom.current();
        int initialValues = tlr.nextInt(5, 30);
        List<Integer> values = new ArrayList<Integer>();
        IntPriorityQueue pq = new IntPriorityQueue();
        for (int i = 0; i < initialValues; i++) {
            int value = tlr.nextInt(0, Integer.MAX_VALUE);
            pq.offer(value);
            values.add(value);
        }
        for (Integer value : values) {
            pq.remove(value);
        }
        assertTrue(pq.isEmpty());
        assertThat(pq.poll()).isEqualTo(IntPriorityQueue.NO_VALUE);
    }

    @Test
    public void internalRemoveMustPreserveOrder() {
        ThreadLocalRandom tlr = ThreadLocalRandom.current();
        int initialValues = tlr.nextInt(1, 30);
        List<Integer> values = new ArrayList<Integer>();
        IntPriorityQueue pq = new IntPriorityQueue();
        for (int i = 0; i < initialValues; i++) {
            int value = tlr.nextInt(0, Integer.MAX_VALUE);
            pq.offer(value);
            values.add(value);
        }

        // Remove an arbitrary middle element; the heap must stay ordered.
        Integer toRemove = values.get(values.size() / 2);
        values.remove(toRemove);
        pq.remove(toRemove);

        Collections.sort(values);
        for (Integer value : values) {
            assertThat(pq.poll()).isEqualTo(value);
        }
        assertTrue(pq.isEmpty());
        assertThat(pq.poll()).isEqualTo(IntPriorityQueue.NO_VALUE);
    }

    @Test
    public void mustSupportDuplicateValues() {
        IntPriorityQueue pq = new IntPriorityQueue();
        pq.offer(10);
        pq.offer(5);
        pq.offer(6);
        pq.offer(5);
        pq.offer(10);
        pq.offer(10);
        pq.offer(6);
        // remove(10) deletes only a single occurrence of 10.
        pq.remove(10);
        assertThat(pq.peek()).isEqualTo(5);
        assertThat(pq.peek()).isEqualTo(5);
        assertThat(pq.poll()).isEqualTo(5);
        assertThat(pq.peek()).isEqualTo(5);
        assertThat(pq.poll()).isEqualTo(5);
        assertThat(pq.peek()).isEqualTo(6);
        assertThat(pq.poll()).isEqualTo(6);
        assertThat(pq.peek()).isEqualTo(6);
        assertThat(pq.peek()).isEqualTo(6);
        assertThat(pq.poll()).isEqualTo(6);
        assertThat(pq.peek()).isEqualTo(10);
        assertThat(pq.poll()).isEqualTo(10);
        assertThat(pq.poll()).isEqualTo(10);
        assertTrue(pq.isEmpty());
        assertThat(pq.poll()).isEqualTo(IntPriorityQueue.NO_VALUE);
        assertThat(pq.peek()).isEqualTo(IntPriorityQueue.NO_VALUE);
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.buffer;


/**
 * Tests little-endian composite channel buffers
 */
public class LittleEndianCompositeByteBufTest extends AbstractCompositeByteBufTest {
    // Runs the shared composite-buffer test suite with LITTLE_ENDIAN byte order.
    public LittleEndianCompositeByteBufTest() {
        super(Unpooled.LITTLE_ENDIAN);
    }
}
 */
package io.netty.buffer;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertSame;

import java.nio.ByteOrder;

/**
 * Tests little-endian direct channel buffers.
 */
public class LittleEndianDirectByteBufTest extends AbstractByteBufTest {

    /**
     * Creates the buffer under test: a direct buffer viewed in little-endian order.
     */
    @Override
    protected ByteBuf newBuffer(int length, int maxCapacity) {
        ByteBuf buffer = newDirectBuffer(length, maxCapacity).order(ByteOrder.LITTLE_ENDIAN);
        // Sanity-check the swapped view before handing it to the shared tests.
        assertSame(ByteOrder.LITTLE_ENDIAN, buffer.order());
        assertEquals(0, buffer.writerIndex());
        return buffer;
    }

    // Factory hook for subclasses (e.g. the unsafe variants) to supply a different
    // direct-buffer implementation while keeping the order() wrapper above.
    protected ByteBuf newDirectBuffer(int length, int maxCapacity) {
        return new UnpooledDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, length, maxCapacity);
    }
}
/*
 */
package io.netty.buffer;

import static org.junit.jupiter.api.Assertions.assertEquals;

import java.nio.ByteOrder;

/**
 * Tests little-endian heap channel buffers.
 */
public class LittleEndianHeapByteBufTest extends AbstractByteBufTest {

    /**
     * Creates the buffer under test: an unpooled heap buffer viewed in little-endian order.
     */
    @Override
    protected ByteBuf newBuffer(int length, int maxCapacity) {
        ByteBuf buffer = Unpooled.buffer(length, maxCapacity).order(ByteOrder.LITTLE_ENDIAN);
        // A freshly allocated buffer must have nothing written yet.
        assertEquals(0, buffer.writerIndex());
        return buffer;
    }
}
/*
+ */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; + +public class LittleEndianUnsafeDirectByteBufTest extends LittleEndianDirectByteBufTest { + + @BeforeEach + @Override + public void init() { + Assumptions.assumeTrue(PlatformDependent.hasUnsafe(), "sun.misc.Unsafe not found, skip tests"); + super.init(); + } + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + return new UnpooledUnsafeDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, length, maxCapacity); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/LittleEndianUnsafeNoCleanerDirectByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/LittleEndianUnsafeNoCleanerDirectByteBufTest.java new file mode 100644 index 0000000..6396148 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/LittleEndianUnsafeNoCleanerDirectByteBufTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; + +public class LittleEndianUnsafeNoCleanerDirectByteBufTest extends LittleEndianDirectByteBufTest { + + @BeforeEach + @Override + public void init() { + Assumptions.assumeTrue(PlatformDependent.useDirectBufferNoCleaner(), + "java.nio.DirectByteBuffer.(long, int) not found, skip tests"); + super.init(); + } + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + return new UnpooledUnsafeNoCleanerDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, length, maxCapacity); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/LongLongHashMapTest.java b/netty-buffer/src/test/java/io/netty/buffer/LongLongHashMapTest.java new file mode 100644 index 0000000..915b430 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/LongLongHashMapTest.java @@ -0,0 +1,86 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.internal.ThreadLocalRandom; +import org.junit.jupiter.api.Test; + +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.PrimitiveIterator.OfLong; +import java.util.Set; + +import static org.assertj.core.api.Assertions.assertThat; + +class LongLongHashMapTest { + @Test + public void zeroPutGetAndRemove() { + LongLongHashMap map = new LongLongHashMap(-1); + assertThat(map.put(0, 42)).isEqualTo(-1); + assertThat(map.get(0)).isEqualTo(42); + assertThat(map.put(0, 24)).isEqualTo(42); + assertThat(map.get(0)).isEqualTo(24); + map.remove(0); + assertThat(map.get(0)).isEqualTo(-1); + } + + @Test + public void mustHandleCollisions() { + LongLongHashMap map = new LongLongHashMap(-1); + Set set = new HashSet(); + long v = 1; + for (int i = 0; i < 63; i++) { + assertThat(map.put(v, v)).isEqualTo(-1); + set.add(v); + v <<= 1; + } + for (Long value : set) { + assertThat(map.get(value)).isEqualTo(value); + assertThat(map.put(value, -value)).isEqualTo(value); + assertThat(map.get(value)).isEqualTo(-value); + map.remove(value); + assertThat(map.get(value)).isEqualTo(-1); + } + } + + @Test + public void randomOperations() { + int operations = 6000; + ThreadLocalRandom tlr = ThreadLocalRandom.current(); + Map expected = new HashMap(); + LongLongHashMap actual = new LongLongHashMap(-1); + OfLong itr = tlr.longs(0, operations).limit(operations * 50).iterator(); + while (itr.hasNext()) { + long value = itr.nextLong(); + if (expected.containsKey(value)) { + assertThat(actual.get(value)).isEqualTo(expected.get(value)); + if (tlr.nextBoolean()) { + actual.remove(value); + expected.remove(value); + assertThat(actual.get(value)).isEqualTo(-1); + } else { + long v = expected.get(value); + assertThat(actual.put(value, -v)).isEqualTo(expected.put(value, -v)); + } + } else { + assertThat(actual.get(value)).isEqualTo(-1); + assertThat(actual.put(value, value)).isEqualTo(-1); + expected.put(value, 
value); + } + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/NoopResourceLeakTracker.java b/netty-buffer/src/test/java/io/netty/buffer/NoopResourceLeakTracker.java new file mode 100644 index 0000000..8f31f80 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/NoopResourceLeakTracker.java @@ -0,0 +1,41 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.ResourceLeakTracker; + +import java.util.concurrent.atomic.AtomicBoolean; + + +final class NoopResourceLeakTracker extends AtomicBoolean implements ResourceLeakTracker { + + private static final long serialVersionUID = 7874092436796083851L; + + @Override + public void record() { + // NOOP + } + + @Override + public void record(Object hint) { + // NOOP + } + + @Override + public boolean close(T trackedObject) { + return compareAndSet(false, true); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/PoolArenaTest.java b/netty-buffer/src/test/java/io/netty/buffer/PoolArenaTest.java new file mode 100644 index 0000000..7da96da --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/PoolArenaTest.java @@ -0,0 +1,172 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PoolArenaTest { + + private static final int PAGE_SIZE = 8192; + private static final int PAGE_SHIFTS = 11; + //chunkSize = pageSize * (2 ^ pageShifts) + private static final int CHUNK_SIZE = 16777216; + + @Test + public void testNormalizeCapacity() { + SizeClasses sc = new SizeClasses(PAGE_SIZE, PAGE_SHIFTS, CHUNK_SIZE, 0); + PoolArena arena = new PoolArena.DirectArena(null, sc); + int[] reqCapacities = {0, 15, 510, 1024, 1023, 1025}; + int[] expectedResult = {16, 16, 512, 1024, 1024, 1280}; + for (int i = 0; i < reqCapacities.length; i ++) { + assertEquals(expectedResult[i], + arena.sizeClass.sizeIdx2size(arena.sizeClass.size2SizeIdx(reqCapacities[i]))); + } + } + + @Test + public void testNormalizeAlignedCapacity() { + SizeClasses sc = new SizeClasses(PAGE_SIZE, PAGE_SHIFTS, CHUNK_SIZE, 64); + PoolArena arena = new PoolArena.DirectArena(null, sc); + int[] reqCapacities = {0, 15, 510, 1024, 1023, 1025}; + int[] expectedResult = {64, 64, 512, 1024, 1024, 1280}; + for (int i = 0; i < reqCapacities.length; i ++) { + assertEquals(expectedResult[i], + arena.sizeClass.sizeIdx2size(arena.sizeClass.size2SizeIdx(reqCapacities[i]))); + } + } + + @Test + public void testSize2SizeIdx() { + SizeClasses sc = new SizeClasses(PAGE_SIZE, PAGE_SHIFTS, CHUNK_SIZE, 0); 
+ PoolArena arena = new PoolArena.DirectArena(null, sc); + + for (int sz = 0; sz <= CHUNK_SIZE; sz++) { + int sizeIdx = arena.sizeClass.size2SizeIdx(sz); + assertTrue(sz <= arena.sizeClass.sizeIdx2size(sizeIdx)); + if (sizeIdx > 0) { + assertTrue(sz > arena.sizeClass.sizeIdx2size(sizeIdx - 1)); + } + } + } + + @Test + public void testPages2PageIdx() { + int pageShifts = PAGE_SHIFTS; + SizeClasses sc = new SizeClasses(PAGE_SIZE, PAGE_SHIFTS, CHUNK_SIZE, 0); + PoolArena arena = new PoolArena.DirectArena(null, sc); + + int maxPages = CHUNK_SIZE >> pageShifts; + for (int pages = 1; pages <= maxPages; pages++) { + int pageIdxFloor = arena.sizeClass.pages2pageIdxFloor(pages); + assertTrue(pages << pageShifts >= arena.sizeClass.pageIdx2size(pageIdxFloor)); + if (pageIdxFloor > 0 && pages < maxPages) { + assertTrue(pages << pageShifts < arena.sizeClass.pageIdx2size(pageIdxFloor + 1)); + } + + int pageIdxCeiling = arena.sizeClass.pages2pageIdx(pages); + assertTrue(pages << pageShifts <= arena.sizeClass.pageIdx2size(pageIdxCeiling)); + if (pageIdxCeiling > 0) { + assertTrue(pages << pageShifts > arena.sizeClass.pageIdx2size(pageIdxCeiling - 1)); + } + } + } + + @Test + public void testSizeIdx2size() { + SizeClasses sc = new SizeClasses(PAGE_SIZE, PAGE_SHIFTS, CHUNK_SIZE, 0); + PoolArena arena = new PoolArena.DirectArena(null, sc); + for (int i = 0; i < arena.sizeClass.nSizes; i++) { + assertEquals(arena.sizeClass.sizeIdx2sizeCompute(i), arena.sizeClass.sizeIdx2size(i)); + } + } + + @Test + public void testPageIdx2size() { + SizeClasses sc = new SizeClasses(PAGE_SIZE, PAGE_SHIFTS, CHUNK_SIZE, 0); + PoolArena arena = new PoolArena.DirectArena(null, sc); + for (int i = 0; i < arena.sizeClass.nPSizes; i++) { + assertEquals(arena.sizeClass.pageIdx2sizeCompute(i), arena.sizeClass.pageIdx2size(i)); + } + } + + @Test + public void testAllocationCounter() { + final PooledByteBufAllocator allocator = new PooledByteBufAllocator( + true, // preferDirect + 0, // nHeapArena + 1, // 
nDirectArena + 8192, // pageSize + 11, // maxOrder + 0, // tinyCacheSize + 0, // smallCacheSize + 0, // normalCacheSize + true // useCacheForAllThreads + ); + + // create small buffer + final ByteBuf b1 = allocator.directBuffer(800); + // create normal buffer + final ByteBuf b2 = allocator.directBuffer(8192 * 5); + + assertNotNull(b1); + assertNotNull(b2); + + // then release buffer to deallocated memory while threadlocal cache has been disabled + // allocations counter value must equals deallocations counter value + assertTrue(b1.release()); + assertTrue(b2.release()); + + assertTrue(allocator.directArenas().size() >= 1); + final PoolArenaMetric metric = allocator.directArenas().get(0); + + assertEquals(2, metric.numDeallocations()); + assertEquals(2, metric.numAllocations()); + + assertEquals(1, metric.numSmallDeallocations()); + assertEquals(1, metric.numSmallAllocations()); + assertEquals(1, metric.numNormalDeallocations()); + assertEquals(1, metric.numNormalAllocations()); + } + + @Test + public void testDirectArenaMemoryCopy() { + ByteBuf src = PooledByteBufAllocator.DEFAULT.directBuffer(512); + ByteBuf dst = PooledByteBufAllocator.DEFAULT.directBuffer(512); + + PooledByteBuf pooledSrc = unwrapIfNeeded(src); + PooledByteBuf pooledDst = unwrapIfNeeded(dst); + + // This causes the internal reused ByteBuffer duplicate limit to be set to 128 + pooledDst.writeBytes(ByteBuffer.allocate(128)); + // Ensure internal ByteBuffer duplicate limit is properly reset (used in memoryCopy non-Unsafe case) + pooledDst.chunk.arena.memoryCopy(pooledSrc.memory, 0, pooledDst, 512); + + src.release(); + dst.release(); + } + + @SuppressWarnings("unchecked") + private PooledByteBuf unwrapIfNeeded(ByteBuf buf) { + return (PooledByteBuf) (buf instanceof PooledByteBuf ? 
buf : buf.unwrap()); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/PooledAlignedBigEndianDirectByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/PooledAlignedBigEndianDirectByteBufTest.java new file mode 100644 index 0000000..5ee3de6 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/PooledAlignedBigEndianDirectByteBufTest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.buffer;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;

import java.nio.ByteOrder;

import static org.junit.jupiter.api.Assertions.assertSame;

/**
 * Runs the big-endian direct-buffer suite against a pooled allocator constructed with an
 * explicit direct-memory cache alignment (1), exercising the aligned allocation path.
 */
public class PooledAlignedBigEndianDirectByteBufTest extends PooledBigEndianDirectByteBufTest {
    private static final int directMemoryCacheAlignment = 1;
    // Shared across all tests in this class; created once in setUpAllocator().
    private static PooledByteBufAllocator allocator;

    @BeforeAll
    public static void setUpAllocator() {
        allocator = new PooledByteBufAllocator(
                true,
                PooledByteBufAllocator.defaultNumHeapArena(),
                PooledByteBufAllocator.defaultNumDirectArena(),
                PooledByteBufAllocator.defaultPageSize(),
                11,
                PooledByteBufAllocator.defaultSmallCacheSize(),
                64,
                PooledByteBufAllocator.defaultUseCacheForAllThreads(),
                directMemoryCacheAlignment);
    }

    @AfterAll
    public static void releaseAllocator() {
        // Drop the static reference so the arenas can be reclaimed after the class finishes.
        allocator = null;
    }

    @Override
    protected ByteBuf alloc(int length, int maxCapacity) {
        ByteBuf buffer = allocator.directBuffer(length, maxCapacity);
        assertSame(ByteOrder.BIG_ENDIAN, buffer.order());
        return buffer;
    }
}
/*
 */
package io.netty.buffer;

import java.nio.ByteOrder;

import static org.junit.jupiter.api.Assertions.assertSame;

/**
 * Tests big-endian direct channel buffers.
 */
public class PooledBigEndianDirectByteBufTest extends AbstractPooledByteBufTest {

    /**
     * Allocates the pooled direct buffer under test and verifies it defaults to big-endian.
     */
    @Override
    protected ByteBuf alloc(int length, int maxCapacity) {
        ByteBuf buffer = PooledByteBufAllocator.DEFAULT.directBuffer(length, maxCapacity);
        assertSame(ByteOrder.BIG_ENDIAN, buffer.order());
        return buffer;
    }
}
/*
 */
package io.netty.buffer;

/**
 * Tests big-endian heap channel buffers.
 */
public class PooledBigEndianHeapByteBufTest extends AbstractPooledByteBufTest {

    /** Allocates the pooled heap buffer under test. */
    @Override
    protected ByteBuf alloc(int length, int maxCapacity) {
        return PooledByteBufAllocator.DEFAULT.heapBuffer(length, maxCapacity);
    }
}
/*
+ */ + +package io.netty.buffer; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.concurrent.FastThreadLocalThread; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; +import org.junit.jupiter.api.Timeout; + +import static io.netty.buffer.PoolChunk.runOffset; +import static io.netty.buffer.PoolChunk.runPages; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PooledByteBufAllocatorTest extends AbstractByteBufAllocatorTest { + + @Override + protected PooledByteBufAllocator newAllocator(boolean preferDirect) { + return new PooledByteBufAllocator(preferDirect); + } + + @Override + protected PooledByteBufAllocator newUnpooledAllocator() { + return new PooledByteBufAllocator(0, 0, 8192, 1); + } + + @Override + protected long expectedUsedMemory(PooledByteBufAllocator allocator, int capacity) { + return allocator.metric().chunkSize(); + } + + @Override + protected long expectedUsedMemoryAfterRelease(PooledByteBufAllocator allocator, int capacity) { + // This is the case as allocations 
will start in qInit and chunks in qInit will never be released until + // these are moved to q000. + // See https://www.bsdcan.org/2006/papers/jemalloc.pdf + return allocator.metric().chunkSize(); + } + + @Override + protected void trimCaches(PooledByteBufAllocator allocator) { + allocator.trimCurrentThreadCache(); + } + + @Test + public void testTrim() { + PooledByteBufAllocator allocator = newAllocator(true); + + // Should return false as we never allocated from this thread yet. + assertFalse(allocator.trimCurrentThreadCache()); + + ByteBuf directBuffer = allocator.directBuffer(); + + assertTrue(directBuffer.release()); + + // Should return true now a cache exists for the calling thread. + assertTrue(allocator.trimCurrentThreadCache()); + } + + @Test + public void testPooledUnsafeHeapBufferAndUnsafeDirectBuffer() { + PooledByteBufAllocator allocator = newAllocator(true); + ByteBuf directBuffer = allocator.directBuffer(); + assertInstanceOf(directBuffer, + PlatformDependent.hasUnsafe() ? PooledUnsafeDirectByteBuf.class : PooledDirectByteBuf.class); + directBuffer.release(); + + ByteBuf heapBuffer = allocator.heapBuffer(); + assertInstanceOf(heapBuffer, + PlatformDependent.hasUnsafe() ? 
PooledUnsafeHeapByteBuf.class : PooledHeapByteBuf.class); + heapBuffer.release(); + } + + @Test + public void testIOBuffersAreDirectWhenUnsafeAvailableOrDirectBuffersPooled() { + PooledByteBufAllocator allocator = newAllocator(true); + ByteBuf ioBuffer = allocator.ioBuffer(); + + assertTrue(ioBuffer.isDirect()); + ioBuffer.release(); + + PooledByteBufAllocator unpooledAllocator = newUnpooledAllocator(); + ioBuffer = unpooledAllocator.ioBuffer(); + + if (PlatformDependent.hasUnsafe()) { + assertTrue(ioBuffer.isDirect()); + } else { + assertFalse(ioBuffer.isDirect()); + } + ioBuffer.release(); + } + + @Test + public void testWithoutUseCacheForAllThreads() { + assertThat(Thread.currentThread()).isNotInstanceOf(FastThreadLocalThread.class); + + PooledByteBufAllocator pool = new PooledByteBufAllocator( + /*preferDirect=*/ false, + /*nHeapArena=*/ 1, + /*nDirectArena=*/ 1, + /*pageSize=*/8192, + /*maxOrder=*/ 9, + /*tinyCacheSize=*/ 0, + /*smallCacheSize=*/ 0, + /*normalCacheSize=*/ 0, + /*useCacheForAllThreads=*/ false); + ByteBuf buf = pool.buffer(1); + buf.release(); + } + + @Test + public void testArenaMetricsNoCache() { + testArenaMetrics0(new PooledByteBufAllocator(true, 2, 2, 8192, 9, 0, 0, 0), 100, 0, 100, 100); + } + + @Test + public void testArenaMetricsCache() { + testArenaMetrics0(new PooledByteBufAllocator(true, 2, 2, 8192, 9, 1000, 1000, 1000, true, 0), 100, 1, 1, 0); + } + + @Test + public void testArenaMetricsNoCacheAlign() { + Assumptions.assumeTrue(PooledByteBufAllocator.isDirectMemoryCacheAlignmentSupported()); + testArenaMetrics0(new PooledByteBufAllocator(true, 2, 2, 8192, 9, 0, 0, 0, true, 64), 100, 0, 100, 100); + } + + @Test + public void testArenaMetricsCacheAlign() { + Assumptions.assumeTrue(PooledByteBufAllocator.isDirectMemoryCacheAlignmentSupported()); + testArenaMetrics0(new PooledByteBufAllocator(true, 2, 2, 8192, 9, 1000, 1000, 1000, true, 64), 100, 1, 1, 0); + } + + private static void testArenaMetrics0( + PooledByteBufAllocator 
allocator, int num, int expectedActive, int expectedAlloc, int expectedDealloc) { + for (int i = 0; i < num; i++) { + assertTrue(allocator.directBuffer().release()); + assertTrue(allocator.heapBuffer().release()); + } + + assertArenaMetrics(allocator.metric().directArenas(), expectedActive, expectedAlloc, expectedDealloc); + assertArenaMetrics(allocator.metric().heapArenas(), expectedActive, expectedAlloc, expectedDealloc); + } + + private static void assertArenaMetrics( + List arenaMetrics, int expectedActive, int expectedAlloc, int expectedDealloc) { + long active = 0; + long alloc = 0; + long dealloc = 0; + for (PoolArenaMetric arena : arenaMetrics) { + active += arena.numActiveAllocations(); + alloc += arena.numAllocations(); + dealloc += arena.numDeallocations(); + } + assertEquals(expectedActive, active); + assertEquals(expectedAlloc, alloc); + assertEquals(expectedDealloc, dealloc); + } + + @Test + public void testPoolChunkListMetric() { + for (PoolArenaMetric arenaMetric: PooledByteBufAllocator.DEFAULT.metric().heapArenas()) { + assertPoolChunkListMetric(arenaMetric); + } + } + + private static void assertPoolChunkListMetric(PoolArenaMetric arenaMetric) { + List lists = arenaMetric.chunkLists(); + assertEquals(6, lists.size()); + assertPoolChunkListMetric(lists.get(0), 1, 25); + assertPoolChunkListMetric(lists.get(1), 1, 50); + assertPoolChunkListMetric(lists.get(2), 25, 75); + assertPoolChunkListMetric(lists.get(4), 75, 100); + assertPoolChunkListMetric(lists.get(5), 100, 100); + } + + private static void assertPoolChunkListMetric(PoolChunkListMetric m, int min, int max) { + assertEquals(min, m.minUsage()); + assertEquals(max, m.maxUsage()); + } + + @Test + public void testSmallSubpageMetric() { + PooledByteBufAllocator allocator = new PooledByteBufAllocator(true, 1, 1, 8192, 9, 0, 0, 0); + ByteBuf buffer = allocator.heapBuffer(500); + try { + PoolArenaMetric metric = allocator.metric().heapArenas().get(0); + PoolSubpageMetric subpageMetric = 
metric.smallSubpages().get(0); + assertEquals(1, subpageMetric.maxNumElements() - subpageMetric.numAvailable()); + } finally { + buffer.release(); + } + } + + @Test + public void testAllocNotNull() { + PooledByteBufAllocator allocator = new PooledByteBufAllocator(true, 1, 1, 8192, 9, 0, 0, 0); + // Huge allocation + testAllocNotNull(allocator, allocator.metric().chunkSize() + 1); + // Normal allocation + testAllocNotNull(allocator, 1024); + // Small allocation + testAllocNotNull(allocator, 512); + testAllocNotNull(allocator, 1); + } + + private static void testAllocNotNull(PooledByteBufAllocator allocator, int capacity) { + ByteBuf buffer = allocator.heapBuffer(capacity); + assertNotNull(buffer.alloc()); + assertTrue(buffer.release()); + assertNotNull(buffer.alloc()); + } + + @Test + public void testFreePoolChunk() { + int chunkSize = 16 * 1024 * 1024; + PooledByteBufAllocator allocator = new PooledByteBufAllocator(true, 1, 0, 8192, 11, 0, 0, 0); + ByteBuf buffer = allocator.heapBuffer(chunkSize); + List arenas = allocator.metric().heapArenas(); + assertEquals(1, arenas.size()); + List lists = arenas.get(0).chunkLists(); + assertEquals(6, lists.size()); + + assertFalse(lists.get(0).iterator().hasNext()); + assertFalse(lists.get(1).iterator().hasNext()); + assertFalse(lists.get(2).iterator().hasNext()); + assertFalse(lists.get(3).iterator().hasNext()); + assertFalse(lists.get(4).iterator().hasNext()); + + // Must end up in the 6th PoolChunkList + assertTrue(lists.get(5).iterator().hasNext()); + assertTrue(buffer.release()); + + // Should be completely removed and so all PoolChunkLists must be empty + assertFalse(lists.get(0).iterator().hasNext()); + assertFalse(lists.get(1).iterator().hasNext()); + assertFalse(lists.get(2).iterator().hasNext()); + assertFalse(lists.get(3).iterator().hasNext()); + assertFalse(lists.get(4).iterator().hasNext()); + assertFalse(lists.get(5).iterator().hasNext()); + } + + @Test + public void testCollapse() { + int pageSize = 8192; + //no 
cache + ByteBufAllocator allocator = new PooledByteBufAllocator(true, 0, 1, 8192, 9, 0, 0, 0); + + ByteBuf b1 = allocator.buffer(pageSize * 4); + ByteBuf b2 = allocator.buffer(pageSize * 5); + ByteBuf b3 = allocator.buffer(pageSize * 6); + + b2.release(); + b3.release(); + + ByteBuf b4 = allocator.buffer(pageSize * 10); + + PooledByteBuf b = unwrapIfNeeded(b4); + + //b2 and b3 are collapsed, b4 should start at offset 4 + assertEquals(4, runOffset(b.handle)); + assertEquals(10, runPages(b.handle)); + + b1.release(); + b4.release(); + + //all ByteBuf are collapsed, b5 should start at offset 0 + ByteBuf b5 = allocator.buffer(pageSize * 20); + b = unwrapIfNeeded(b5); + + assertEquals(0, runOffset(b.handle)); + assertEquals(20, runPages(b.handle)); + + b5.release(); + } + + @Test + public void testAllocateSmallOffset() { + int pageSize = 8192; + ByteBufAllocator allocator = new PooledByteBufAllocator(true, 0, 1, 8192, 9, 0, 0, 0); + + int size = pageSize * 5; + + ByteBuf[] bufs = new ByteBuf[10]; + for (int i = 0; i < 10; i++) { + bufs[i] = allocator.buffer(size); + } + + for (int i = 0; i < 5; i++) { + bufs[i].release(); + } + + //make sure we always allocate runs with small offset + for (int i = 0; i < 5; i++) { + ByteBuf buf = allocator.buffer(size); + PooledByteBuf unwrapedBuf = unwrapIfNeeded(buf); + assertEquals(runOffset(unwrapedBuf.handle), i * 5); + bufs[i] = buf; + } + + //release at reverse order + for (int i = 10 - 1; i >= 5; i--) { + bufs[i].release(); + } + + for (int i = 5; i < 10; i++) { + ByteBuf buf = allocator.buffer(size); + PooledByteBuf unwrapedBuf = unwrapIfNeeded(buf); + assertEquals(runOffset(unwrapedBuf.handle), i * 5); + bufs[i] = buf; + } + + for (int i = 0; i < 10; i++) { + bufs[i].release(); + } + } + + @Disabled + @Test + @Timeout(value = 4000, unit = MILLISECONDS) + public void testThreadCacheDestroyedByThreadCleaner() throws InterruptedException { + testThreadCacheDestroyed(false); + } + + @Disabled + @Test + @Timeout(value = 4000, unit 
= MILLISECONDS) + public void testThreadCacheDestroyedAfterExitRun() throws InterruptedException { + testThreadCacheDestroyed(true); + } + + private static void testThreadCacheDestroyed(boolean useRunnable) throws InterruptedException { + int numArenas = 11; + final PooledByteBufAllocator allocator = + new PooledByteBufAllocator(numArenas, numArenas, 8192, 1); + + final AtomicBoolean threadCachesCreated = new AtomicBoolean(true); + + final Runnable task = new Runnable() { + @Override + public void run() { + ByteBuf buf = allocator.newHeapBuffer(1024, 1024); + for (int i = 0; i < buf.capacity(); i++) { + buf.writeByte(0); + } + + // Make sure that thread caches are actually created, + // so that down below we are not testing for zero + // thread caches without any of them ever having been initialized. + if (allocator.metric().numThreadLocalCaches() == 0) { + threadCachesCreated.set(false); + } + + buf.release(); + } + }; + + for (int i = 0; i < numArenas; i++) { + final FastThreadLocalThread thread; + if (useRunnable) { + thread = new FastThreadLocalThread(task); + assertTrue(thread.willCleanupFastThreadLocals()); + } else { + thread = new FastThreadLocalThread() { + @Override + public void run() { + task.run(); + } + }; + assertFalse(thread.willCleanupFastThreadLocals()); + } + thread.start(); + thread.join(); + } + + // Wait for the ThreadDeathWatcher to have destroyed all thread caches + while (allocator.metric().numThreadLocalCaches() > 0) { + // Signal we want to have a GC run to ensure we can process our ThreadCleanerReference + System.gc(); + System.runFinalization(); + LockSupport.parkNanos(MILLISECONDS.toNanos(100)); + } + + assertTrue(threadCachesCreated.get()); + } + + @Test + @Timeout(value = 3000, unit = MILLISECONDS) + public void testNumThreadCachesWithNoDirectArenas() throws InterruptedException { + int numHeapArenas = 1; + final PooledByteBufAllocator allocator = + new PooledByteBufAllocator(numHeapArenas, 0, 8192, 1); + + ThreadCache tcache0 = 
createNewThreadCache(allocator); + assertEquals(1, allocator.metric().numThreadLocalCaches()); + + ThreadCache tcache1 = createNewThreadCache(allocator); + assertEquals(2, allocator.metric().numThreadLocalCaches()); + + tcache0.destroy(); + assertEquals(1, allocator.metric().numThreadLocalCaches()); + + tcache1.destroy(); + assertEquals(0, allocator.metric().numThreadLocalCaches()); + } + + @Test + @Timeout(value = 3000, unit = MILLISECONDS) + public void testThreadCacheToArenaMappings() throws InterruptedException { + int numArenas = 2; + final PooledByteBufAllocator allocator = + new PooledByteBufAllocator(numArenas, numArenas, 8192, 1); + + ThreadCache tcache0 = createNewThreadCache(allocator); + ThreadCache tcache1 = createNewThreadCache(allocator); + assertEquals(2, allocator.metric().numThreadLocalCaches()); + assertEquals(1, allocator.metric().heapArenas().get(0).numThreadCaches()); + assertEquals(1, allocator.metric().heapArenas().get(1).numThreadCaches()); + assertEquals(1, allocator.metric().directArenas().get(0).numThreadCaches()); + assertEquals(1, allocator.metric().directArenas().get(0).numThreadCaches()); + + tcache1.destroy(); + + assertEquals(1, allocator.metric().numThreadLocalCaches()); + assertEquals(1, allocator.metric().heapArenas().get(0).numThreadCaches()); + assertEquals(0, allocator.metric().heapArenas().get(1).numThreadCaches()); + assertEquals(1, allocator.metric().directArenas().get(0).numThreadCaches()); + assertEquals(0, allocator.metric().directArenas().get(1).numThreadCaches()); + + ThreadCache tcache2 = createNewThreadCache(allocator); + assertEquals(2, allocator.metric().numThreadLocalCaches()); + assertEquals(1, allocator.metric().heapArenas().get(0).numThreadCaches()); + assertEquals(1, allocator.metric().heapArenas().get(1).numThreadCaches()); + assertEquals(1, allocator.metric().directArenas().get(0).numThreadCaches()); + assertEquals(1, allocator.metric().directArenas().get(1).numThreadCaches()); + + tcache0.destroy(); + 
assertEquals(1, allocator.metric().numThreadLocalCaches()); + + tcache2.destroy(); + assertEquals(0, allocator.metric().numThreadLocalCaches()); + assertEquals(0, allocator.metric().heapArenas().get(0).numThreadCaches()); + assertEquals(0, allocator.metric().heapArenas().get(1).numThreadCaches()); + assertEquals(0, allocator.metric().directArenas().get(0).numThreadCaches()); + assertEquals(0, allocator.metric().directArenas().get(1).numThreadCaches()); + } + + private static ThreadCache createNewThreadCache(final PooledByteBufAllocator allocator) + throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch cacheLatch = new CountDownLatch(1); + final Thread t = new FastThreadLocalThread(new Runnable() { + + @Override + public void run() { + ByteBuf buf = allocator.newHeapBuffer(1024, 1024); + + // Countdown the latch after we allocated a buffer. At this point the cache must exists. + cacheLatch.countDown(); + + buf.writeZero(buf.capacity()); + + try { + latch.await(); + } catch (InterruptedException e) { + throw new IllegalStateException(e); + } + + buf.release(); + + FastThreadLocal.removeAll(); + } + }); + t.start(); + + // Wait until we allocated a buffer and so be sure the thread was started and the cache exists. + cacheLatch.await(); + + return new ThreadCache() { + @Override + public void destroy() throws InterruptedException { + latch.countDown(); + t.join(); + } + }; + } + + private interface ThreadCache { + void destroy() throws InterruptedException; + } + + @Test + public void testConcurrentUsage() throws Throwable { + long runningTime = MILLISECONDS.toNanos(SystemPropertyUtil.getLong( + "io.netty.buffer.PooledByteBufAllocatorTest.testConcurrentUsageTime", 15000)); + + // We use no caches and only one arena to maximize the chance of hitting the race-condition we + // had before. 
+ ByteBufAllocator allocator = new PooledByteBufAllocator(true, 0, 1, 8192, 9, 0, 0, 0); + List threads = new ArrayList(); + try { + for (int i = 0; i < 64; i++) { + AllocationThread thread = new AllocationThread(allocator); + thread.start(); + threads.add(thread); + } + + long start = System.nanoTime(); + while (!isExpired(start, runningTime)) { + checkForErrors(threads); + Thread.sleep(100); + } + } finally { + // First mark all AllocationThreads to complete their work and then wait until these are complete + // and rethrow if there was any error. + for (AllocationThread t : threads) { + t.markAsFinished(); + } + + for (AllocationThread t: threads) { + t.joinAndCheckForError(); + } + } + } + + private static boolean isExpired(long start, long expireTime) { + return System.nanoTime() - start > expireTime; + } + + private static void checkForErrors(List threads) throws Throwable { + for (AllocationThread t : threads) { + if (t.isFinished()) { + t.checkForError(); + } + } + } + + private static final class AllocationThread extends Thread { + + private static final int[] ALLOCATION_SIZES = new int[16 * 1024]; + static { + for (int i = 0; i < ALLOCATION_SIZES.length; i++) { + ALLOCATION_SIZES[i] = i; + } + } + + private final Queue buffers = new ConcurrentLinkedQueue(); + private final ByteBufAllocator allocator; + private final AtomicReference finish = new AtomicReference(); + + AllocationThread(ByteBufAllocator allocator) { + this.allocator = allocator; + } + + @Override + public void run() { + try { + int idx = 0; + while (finish.get() == null) { + for (int i = 0; i < 10; i++) { + int len = ALLOCATION_SIZES[Math.abs(idx++ % ALLOCATION_SIZES.length)]; + ByteBuf buf = allocator.directBuffer(len, Integer.MAX_VALUE); + assertEquals(len, buf.writableBytes()); + while (buf.isWritable()) { + buf.writeByte(i); + } + + buffers.offer(buf); + } + releaseBuffersAndCheckContent(); + } + } catch (Throwable cause) { + finish.set(cause); + } finally { + 
releaseBuffersAndCheckContent(); + } + } + + private void releaseBuffersAndCheckContent() { + int i = 0; + while (!buffers.isEmpty()) { + ByteBuf buf = buffers.poll(); + while (buf.isReadable()) { + assertEquals(i, buf.readByte()); + } + buf.release(); + i++; + } + } + + boolean isFinished() { + return finish.get() != null; + } + + void markAsFinished() { + finish.compareAndSet(null, Boolean.TRUE); + } + + void joinAndCheckForError() throws Throwable { + try { + // Mark as finish if not already done but ensure we not override the previous set error. + join(); + } finally { + releaseBuffersAndCheckContent(); + } + checkForError(); + } + + void checkForError() throws Throwable { + Object obj = finish.get(); + if (obj instanceof Throwable) { + throw (Throwable) obj; + } + } + } + + @SuppressWarnings("unchecked") + private static PooledByteBuf unwrapIfNeeded(ByteBuf buf) { + return (PooledByteBuf) (buf instanceof PooledByteBuf ? buf : buf.unwrap()); + } + + @Test + public void testCacheWorksForNormalAllocations() { + int maxCachedBufferCapacity = PooledByteBufAllocator.DEFAULT_MAX_CACHED_BUFFER_CAPACITY; + final PooledByteBufAllocator allocator = + new PooledByteBufAllocator(true, 0, 1, + PooledByteBufAllocator.defaultPageSize(), PooledByteBufAllocator.defaultMaxOrder(), + 128, 128, true); + ByteBuf buffer = allocator.directBuffer(maxCachedBufferCapacity); + assertEquals(1, allocator.metric().directArenas().get(0).numNormalAllocations()); + buffer.release(); + + buffer = allocator.directBuffer(maxCachedBufferCapacity); + // Should come out of the cache so the count should not be incremented + assertEquals(1, allocator.metric().directArenas().get(0).numNormalAllocations()); + buffer.release(); + + // Should be allocated without cache and also not put back in a cache. 
+ buffer = allocator.directBuffer(maxCachedBufferCapacity + 1); + assertEquals(2, allocator.metric().directArenas().get(0).numNormalAllocations()); + buffer.release(); + + buffer = allocator.directBuffer(maxCachedBufferCapacity + 1); + assertEquals(3, allocator.metric().directArenas().get(0).numNormalAllocations()); + buffer.release(); + } + + @Test + public void testNormalPoolSubpageRelease() { + // 16 < elemSize <= 7168 or 8192 < elemSize <= 28672, 1 < subpage.maxNumElems <= 256 + // 7168 <= elemSize <= 8192, subpage.maxNumElems == 1 + int elemSize = 8192; + int length = 1024; + ByteBuf[] byteBufs = new ByteBuf[length]; + final PooledByteBufAllocator allocator = new PooledByteBufAllocator(false, 32, 32, 8192, 11, 256, 64, false, 0); + + for (int i = 0; i < length; i++) { + byteBufs[i] = allocator.heapBuffer(elemSize, elemSize); + } + PoolChunk chunk = unwrapIfNeeded(byteBufs[0]).chunk; + + int beforeFreeBytes = chunk.freeBytes(); + for (int i = 0; i < length; i++) { + byteBufs[i].release(); + } + int afterFreeBytes = chunk.freeBytes(); + + assertTrue(beforeFreeBytes < afterFreeBytes); + } + + @Override + @Test + public void testUsedDirectMemory() { + for (int power = 0; power < 8; power++) { + int initialCapacity = 1024 << power; + testUsedDirectMemory(initialCapacity); + } + } + + private void testUsedDirectMemory(int initialCapacity) { + PooledByteBufAllocator allocator = newAllocator(true); + ByteBufAllocatorMetric metric = allocator.metric(); + assertEquals(0, metric.usedDirectMemory()); + assertEquals(0, allocator.pinnedDirectMemory()); + ByteBuf buffer = allocator.directBuffer(initialCapacity, 4 * initialCapacity); + int capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedDirectMemory()); + assertThat(allocator.pinnedDirectMemory()) + .isGreaterThanOrEqualTo(capacity) + .isLessThanOrEqualTo(metric.usedDirectMemory()); + + // Double the size of the buffer + buffer.capacity(capacity << 1); + capacity = 
buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedDirectMemory(), buffer.toString()); + assertThat(allocator.pinnedDirectMemory()) + .isGreaterThanOrEqualTo(capacity) + .isLessThanOrEqualTo(metric.usedDirectMemory()); + + buffer.release(); + assertEquals(expectedUsedMemoryAfterRelease(allocator, capacity), metric.usedDirectMemory()); + assertThat(allocator.pinnedDirectMemory()) + .isGreaterThanOrEqualTo(0) + .isLessThanOrEqualTo(metric.usedDirectMemory()); + trimCaches(allocator); + assertEquals(0, allocator.pinnedDirectMemory()); + + int[] capacities = new int[30]; + Random rng = new Random(); + for (int i = 0; i < capacities.length; i++) { + capacities[i] = initialCapacity / 4 + rng.nextInt(8 * initialCapacity); + } + ByteBuf[] bufs = new ByteBuf[capacities.length]; + for (int i = 0; i < 20; i++) { + bufs[i] = allocator.directBuffer(capacities[i], 2 * capacities[i]); + } + for (int i = 0; i < 10; i++) { + bufs[i].release(); + } + for (int i = 20; i < 30; i++) { + bufs[i] = allocator.directBuffer(capacities[i], 2 * capacities[i]); + } + for (int i = 0; i < 10; i++) { + bufs[i] = allocator.directBuffer(capacities[i], 2 * capacities[i]); + } + for (int i = 0; i < 30; i++) { + bufs[i].release(); + } + trimCaches(allocator); + assertEquals(0, allocator.pinnedDirectMemory()); + } + + @Override + @Test + public void testUsedHeapMemory() { + for (int power = 0; power < 8; power++) { + int initialCapacity = 1024 << power; + testUsedHeapMemory(initialCapacity); + } + } + + private void testUsedHeapMemory(int initialCapacity) { + PooledByteBufAllocator allocator = newAllocator(true); + ByteBufAllocatorMetric metric = allocator.metric(); + + assertEquals(0, metric.usedHeapMemory()); + assertEquals(0, allocator.pinnedDirectMemory()); + ByteBuf buffer = allocator.heapBuffer(initialCapacity, 4 * initialCapacity); + int capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedHeapMemory()); + 
assertThat(allocator.pinnedHeapMemory()) + .isGreaterThanOrEqualTo(capacity) + .isLessThanOrEqualTo(metric.usedHeapMemory()); + + // Double the size of the buffer + buffer.capacity(capacity << 1); + capacity = buffer.capacity(); + assertEquals(expectedUsedMemory(allocator, capacity), metric.usedHeapMemory()); + assertThat(allocator.pinnedHeapMemory()) + .isGreaterThanOrEqualTo(capacity) + .isLessThanOrEqualTo(metric.usedHeapMemory()); + + buffer.release(); + assertEquals(expectedUsedMemoryAfterRelease(allocator, capacity), metric.usedHeapMemory()); + assertThat(allocator.pinnedHeapMemory()) + .isGreaterThanOrEqualTo(0) + .isLessThanOrEqualTo(metric.usedHeapMemory()); + trimCaches(allocator); + assertEquals(0, allocator.pinnedHeapMemory()); + + int[] capacities = new int[30]; + Random rng = new Random(); + for (int i = 0; i < capacities.length; i++) { + capacities[i] = initialCapacity / 4 + rng.nextInt(8 * initialCapacity); + } + ByteBuf[] bufs = new ByteBuf[capacities.length]; + for (int i = 0; i < 20; i++) { + bufs[i] = allocator.heapBuffer(capacities[i], 2 * capacities[i]); + } + for (int i = 0; i < 10; i++) { + bufs[i].release(); + } + for (int i = 20; i < 30; i++) { + bufs[i] = allocator.heapBuffer(capacities[i], 2 * capacities[i]); + } + for (int i = 0; i < 10; i++) { + bufs[i] = allocator.heapBuffer(capacities[i], 2 * capacities[i]); + } + for (int i = 0; i < 30; i++) { + bufs[i].release(); + } + trimCaches(allocator); + assertEquals(0, allocator.pinnedDirectMemory()); + } + + @Test + public void pinnedMemoryMustReflectBuffersInUseWithThreadLocalCaching() { + pinnedMemoryMustReflectBuffersInUse(true); + } + + @Test + public void pinnedMemoryMustReflectBuffersInUseWithoutThreadLocalCaching() { + pinnedMemoryMustReflectBuffersInUse(false); + } + + private static void pinnedMemoryMustReflectBuffersInUse(boolean useThreadLocalCaching) { + int smallCacheSize; + int normalCacheSize; + if (useThreadLocalCaching) { + smallCacheSize = 
PooledByteBufAllocator.defaultSmallCacheSize(); + normalCacheSize = PooledByteBufAllocator.defaultNormalCacheSize(); + } else { + smallCacheSize = 0; + normalCacheSize = 0; + } + int directMemoryCacheAlignment = 0; + PooledByteBufAllocator alloc = new PooledByteBufAllocator( + PooledByteBufAllocator.defaultPreferDirect(), + 1, + 1, + PooledByteBufAllocator.defaultPageSize(), + PooledByteBufAllocator.defaultMaxOrder(), + smallCacheSize, + normalCacheSize, + useThreadLocalCaching, + directMemoryCacheAlignment); + PooledByteBufAllocatorMetric metric = alloc.metric(); + AtomicLong capSum = new AtomicLong(); + + for (long index = 0; index < 10000; index++) { + ThreadLocalRandom rnd = ThreadLocalRandom.current(); + int bufCount = rnd.nextInt(1, 100); + List buffers = new ArrayList(bufCount); + + if (index % 2 == 0) { + // ensure that we allocate a small buffer + for (int i = 0; i < bufCount; i++) { + ByteBuf buf = alloc.directBuffer(rnd.nextInt(8, 128)); + buffers.add(buf); + capSum.addAndGet(buf.capacity()); + } + } else { + // allocate a larger buffer + for (int i = 0; i < bufCount; i++) { + ByteBuf buf = alloc.directBuffer(rnd.nextInt(1024, 1024 * 100)); + buffers.add(buf); + capSum.addAndGet(buf.capacity()); + } + } + + if (index % 100 == 0) { + long used = usedMemory(metric.directArenas()); + long pinned = alloc.pinnedDirectMemory(); + assertThat(capSum.get()).isLessThanOrEqualTo(pinned); + assertThat(pinned).isLessThanOrEqualTo(used); + } + + for (ByteBuf buffer : buffers) { + buffer.release(); + } + capSum.set(0); + // After releasing all buffers, pinned memory must be zero + assertThat(alloc.pinnedDirectMemory()).isZero(); + } + } + + /** + * Returns an estimate of bytes used by currently in-use buffers + */ + private static long usedMemory(List arenas) { + long totalUsed = 0; + for (PoolArenaMetric arenaMetrics : arenas) { + for (PoolChunkListMetric arenaMetric : arenaMetrics.chunkLists()) { + for (PoolChunkMetric chunkMetric : arenaMetric) { + // 
chunkMetric.chunkSize() returns maximum of bytes that can be served out of the chunk + // and chunkMetric.freeBytes() returns the bytes that are not yet allocated by in-use buffers + totalUsed += chunkMetric.chunkSize() - chunkMetric.freeBytes(); + } + } + } + return totalUsed; + } + + @Test + public void testCapacityChangeDoesntThrowAssertionError() throws Exception { + ByteBufAllocator allocator = newAllocator(true); + List buffers = new ArrayList(); + try { + for (int i = 0; i < 31; i++) { + buffers.add(allocator.heapBuffer()); + } + + final ByteBuf buf = allocator.heapBuffer(); + buffers.add(buf); + final AtomicReference assertionRef = new AtomicReference(); + Runnable capacityChangeTask = new Runnable() { + @Override + public void run() { + try { + buf.capacity(512); + } catch (AssertionError e) { + assertionRef.compareAndSet(null, e); + throw e; + } + } + }; + Thread thread1 = new Thread(capacityChangeTask); + Thread thread2 = new Thread(capacityChangeTask); + + thread1.start(); + thread2.start(); + + thread1.join(); + thread2.join(); + + buffers.add(allocator.heapBuffer()); + buffers.add(allocator.heapBuffer()); + + AssertionError error = assertionRef.get(); + if (error != null) { + throw error; + } + } finally { + for (ByteBuf buffer: buffers) { + buffer.release(); + } + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianDirectByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianDirectByteBufTest.java new file mode 100644 index 0000000..2e0300a --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianDirectByteBufTest.java @@ -0,0 +1,34 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import java.nio.ByteOrder; + +import static org.junit.jupiter.api.Assertions.assertSame; + +/** + * Tests little-endian direct channel buffers + */ +public class PooledLittleEndianDirectByteBufTest extends AbstractPooledByteBufTest { + + @Override + protected ByteBuf alloc(int length, int maxCapacity) { + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.directBuffer(length, maxCapacity) + .order(ByteOrder.LITTLE_ENDIAN); + assertSame(ByteOrder.LITTLE_ENDIAN, buffer.order()); + return buffer; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianHeapByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianHeapByteBufTest.java new file mode 100644 index 0000000..bbaa007 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/PooledLittleEndianHeapByteBufTest.java @@ -0,0 +1,33 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import java.nio.ByteOrder; + +import static org.junit.jupiter.api.Assertions.assertSame; + +/** + * Tests little-endian heap channel buffers + */ +public class PooledLittleEndianHeapByteBufTest extends AbstractPooledByteBufTest { + + @Override + protected ByteBuf alloc(int length, int maxCapacity) { + ByteBuf buffer = PooledByteBufAllocator.DEFAULT.heapBuffer(length, maxCapacity).order(ByteOrder.LITTLE_ENDIAN); + assertSame(ByteOrder.LITTLE_ENDIAN, buffer.order()); + return buffer; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufTest.java new file mode 100644 index 0000000..07ce509 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufTest.java @@ -0,0 +1,300 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; + +import static io.netty.buffer.ByteBufUtil.ensureWritableSuccess; +import static io.netty.buffer.Unpooled.BIG_ENDIAN; +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.LITTLE_ENDIAN; +import static io.netty.buffer.Unpooled.buffer; +import static io.netty.buffer.Unpooled.unmodifiableBuffer; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +/** + * Tests read-only channel buffers + */ +public class ReadOnlyByteBufTest { + + @Test + public void shouldNotAllowNullInConstructor() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new ReadOnlyByteBuf(null); + } + }); + } + + @Test + public void testUnmodifiableBuffer() { + assertThat(unmodifiableBuffer(buffer(1))).isInstanceOf(ReadOnlyByteBuf.class); + } + + @Test + public void testUnwrap() { + ByteBuf buf = buffer(1); + assertSame(buf, unmodifiableBuffer(buf).unwrap()); + } + + @Test + public void shouldHaveSameByteOrder() { + ByteBuf buf = buffer(1); + assertSame(BIG_ENDIAN, unmodifiableBuffer(buf).order()); + buf = buf.order(LITTLE_ENDIAN); + assertSame(LITTLE_ENDIAN, unmodifiableBuffer(buf).order()); + } + + @Test + public void 
shouldReturnReadOnlyDerivedBuffer() { + ByteBuf buf = unmodifiableBuffer(buffer(1)); + assertThat(buf.duplicate()).isInstanceOf(ReadOnlyByteBuf.class); + assertThat(buf.slice()).isInstanceOf(ReadOnlyByteBuf.class); + assertThat(buf.slice(0, 1)).isInstanceOf(ReadOnlyByteBuf.class); + assertThat(buf.duplicate()).isInstanceOf(ReadOnlyByteBuf.class); + } + + @Test + public void shouldReturnWritableCopy() { + ByteBuf buf = unmodifiableBuffer(buffer(1)); + assertThat(buf.copy()).isNotInstanceOf(ReadOnlyByteBuf.class); + } + + @Test + public void shouldForwardReadCallsBlindly() throws Exception { + ByteBuf buf = mock(ByteBuf.class); + when(buf.order()).thenReturn(BIG_ENDIAN); + when(buf.maxCapacity()).thenReturn(65536); + when(buf.readerIndex()).thenReturn(0); + when(buf.writerIndex()).thenReturn(0); + when(buf.capacity()).thenReturn(0); + + when(buf.getBytes(1, (GatheringByteChannel) null, 2)).thenReturn(3); + when(buf.getBytes(4, (OutputStream) null, 5)).thenReturn(buf); + when(buf.getBytes(6, (byte[]) null, 7, 8)).thenReturn(buf); + when(buf.getBytes(9, (ByteBuf) null, 10, 11)).thenReturn(buf); + when(buf.getBytes(12, (ByteBuffer) null)).thenReturn(buf); + when(buf.getByte(13)).thenReturn(Byte.valueOf((byte) 14)); + when(buf.getShort(15)).thenReturn(Short.valueOf((short) 16)); + when(buf.getUnsignedMedium(17)).thenReturn(18); + when(buf.getInt(19)).thenReturn(20); + when(buf.getLong(21)).thenReturn(22L); + + ByteBuffer bb = ByteBuffer.allocate(100); + + when(buf.nioBuffer(23, 24)).thenReturn(bb); + when(buf.capacity()).thenReturn(27); + + ByteBuf roBuf = unmodifiableBuffer(buf); + assertEquals(3, roBuf.getBytes(1, (GatheringByteChannel) null, 2)); + roBuf.getBytes(4, (OutputStream) null, 5); + roBuf.getBytes(6, (byte[]) null, 7, 8); + roBuf.getBytes(9, (ByteBuf) null, 10, 11); + roBuf.getBytes(12, (ByteBuffer) null); + assertEquals((byte) 14, roBuf.getByte(13)); + assertEquals((short) 16, roBuf.getShort(15)); + assertEquals(18, roBuf.getUnsignedMedium(17)); + 
assertEquals(20, roBuf.getInt(19)); + assertEquals(22L, roBuf.getLong(21)); + + ByteBuffer roBB = roBuf.nioBuffer(23, 24); + assertEquals(100, roBB.capacity()); + assertTrue(roBB.isReadOnly()); + + assertEquals(27, roBuf.capacity()); + } + + @Test + public void shouldRejectDiscardReadBytes() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).discardReadBytes(); + } + }); + } + + @Test + public void shouldRejectSetByte() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setByte(0, (byte) 0); + } + }); + } + + @Test + public void shouldRejectSetShort() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setShort(0, (short) 0); + } + }); + } + + @Test + public void shouldRejectSetMedium() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setMedium(0, 0); + } + }); + } + + @Test + public void shouldRejectSetInt() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setInt(0, 0); + } + }); + } + + @Test + public void shouldRejectSetLong() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setLong(0, 0); + } + }); + } + + @Test + public void shouldRejectSetBytes1() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() throws IOException { + unmodifiableBuffer(EMPTY_BUFFER).setBytes(0, (InputStream) null, 0); + } + }); + } + + @Test + public void shouldRejectSetBytes2() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + 
public void execute() throws IOException { + unmodifiableBuffer(EMPTY_BUFFER).setBytes(0, (ScatteringByteChannel) null, 0); + } + }); + } + + @Test + public void shouldRejectSetBytes3() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() throws IOException { + unmodifiableBuffer(EMPTY_BUFFER).setBytes(0, (byte[]) null, 0, 0); + } + }); + } + + @Test + public void shouldRejectSetBytes4() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setBytes(0, (ByteBuf) null, 0, 0); + } + }); + } + + @Test + public void shouldRejectSetBytes5() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + unmodifiableBuffer(EMPTY_BUFFER).setBytes(0, (ByteBuffer) null); + } + }); + } + + @Test + public void shouldIndicateNotWritable() { + assertFalse(unmodifiableBuffer(buffer(1)).isWritable()); + } + + @Test + public void shouldIndicateNotWritableAnyNumber() { + assertFalse(unmodifiableBuffer(buffer(1)).isWritable(1)); + } + + @Test + public void ensureWritableIntStatusShouldFailButNotThrow() { + ensureWritableIntStatusShouldFailButNotThrow(false); + } + + @Test + public void ensureWritableForceIntStatusShouldFailButNotThrow() { + ensureWritableIntStatusShouldFailButNotThrow(true); + } + + private static void ensureWritableIntStatusShouldFailButNotThrow(boolean force) { + ByteBuf buf = buffer(1); + ByteBuf readOnly = buf.asReadOnly(); + int result = readOnly.ensureWritable(1, force); + assertEquals(1, result); + assertFalse(ensureWritableSuccess(result)); + readOnly.release(); + } + + @Test + public void ensureWritableShouldThrow() { + ByteBuf buf = buffer(1); + final ByteBuf readOnly = buf.asReadOnly(); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + readOnly.ensureWritable(1); + } + }); + } finally { + 
buf.release(); + } + } + + @Test + public void asReadOnly() { + ByteBuf buf = buffer(1); + ByteBuf readOnly = buf.asReadOnly(); + assertTrue(readOnly.isReadOnly()); + assertSame(readOnly, readOnly.asReadOnly()); + readOnly.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufferBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufferBufTest.java new file mode 100644 index 0000000..743d92e --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyByteBufferBufTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ReadOnlyByteBufferBufTest extends ReadOnlyDirectByteBufferBufTest { + @Override + protected ByteBuffer allocate(int size) { + return ByteBuffer.allocate(size); + } + + @Test + public void testCopyDirect() { + testCopy(true); + } + + @Test + public void testCopyHeap() { + testCopy(false); + } + + private static void testCopy(boolean direct) { + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + ByteBuffer nioBuffer = direct ? 
ByteBuffer.allocateDirect(bytes.length) : ByteBuffer.allocate(bytes.length); + nioBuffer.put(bytes).flip(); + + ByteBuf buf = new ReadOnlyByteBufferBuf(UnpooledByteBufAllocator.DEFAULT, nioBuffer.asReadOnlyBuffer()); + ByteBuf copy = buf.copy(); + + assertEquals(buf, copy); + assertEquals(buf.alloc(), copy.alloc()); + assertEquals(buf.isDirect(), copy.isDirect()); + + copy.release(); + buf.release(); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java new file mode 100644 index 0000000..053c704 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyDirectByteBufferBufTest.java @@ -0,0 +1,426 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.ReadOnlyBufferException; +import java.nio.channels.FileChannel; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ReadOnlyDirectByteBufferBufTest { + + protected ByteBuf buffer(ByteBuffer buffer) { + return new ReadOnlyByteBufferBuf(UnpooledByteBufAllocator.DEFAULT, buffer); + } + + protected ByteBuffer allocate(int size) { + return ByteBuffer.allocateDirect(size); + } + + @Test + public void testIsContiguous() { + ByteBuf buf = buffer(allocate(4).asReadOnlyBuffer()); + assertTrue(buf.isContiguous()); + buf.release(); + } + + @Test + public void testConstructWithWritable() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + buffer(ReadOnlyDirectByteBufferBufTest.this.allocate(1)); + } + }); + } + + @Test + public void shouldIndicateNotWritable() { + ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()).clear(); + try { + assertFalse(buf.isWritable()); + } finally { + buf.release(); + } + } + + @Test + public void shouldIndicateNotWritableAnyNumber() { + ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()).clear(); + try { + assertFalse(buf.isWritable(1)); + } finally { + buf.release(); + } + } + + @Test + public void ensureWritableIntStatusShouldFailButNotThrow() { + ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()).clear(); + try { + int result = buf.ensureWritable(1, false); + assertEquals(1, result); + } 
finally { + buf.release(); + } + } + + @Test + public void ensureWritableForceIntStatusShouldFailButNotThrow() { + ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()).clear(); + try { + int result = buf.ensureWritable(1, true); + assertEquals(1, result); + } finally { + buf.release(); + } + } + + @Test + public void ensureWritableShouldThrow() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()).clear(); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.ensureWritable(1); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testSetByte() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setByte(0, 1); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testSetInt() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setInt(0, 1); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testSetShort() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setShort(0, 1); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testSetMedium() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setMedium(0, 1); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testSetLong() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setLong(0, 1); + } + }); + } finally { + 
buf.release(); + } + } + + @Test + public void testSetBytesViaArray() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setBytes(0, "test".getBytes()); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testSetBytesViaBuffer() { + final ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + final ByteBuf copy = Unpooled.copyInt(1); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() { + buf.setBytes(0, copy); + } + }); + } finally { + buf.release(); + copy.release(); + } + } + + @Test + public void testSetBytesViaStream() throws IOException { + final ByteBuf buf = buffer(ByteBuffer.allocateDirect(8).asReadOnlyBuffer()); + try { + assertThrows(ReadOnlyBufferException.class, new Executable() { + @Override + public void execute() throws Throwable { + buf.setBytes(0, new ByteArrayInputStream("test".getBytes()), 2); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void testGetReadByte() { + ByteBuf buf = buffer( + ((ByteBuffer) allocate(2).put(new byte[] { (byte) 1, (byte) 2 }).flip()).asReadOnlyBuffer()); + + assertEquals(1, buf.getByte(0)); + assertEquals(2, buf.getByte(1)); + + assertEquals(1, buf.readByte()); + assertEquals(2, buf.readByte()); + assertFalse(buf.isReadable()); + + buf.release(); + } + + @Test + public void testGetReadInt() { + ByteBuf buf = buffer(((ByteBuffer) allocate(8).putInt(1).putInt(2).flip()).asReadOnlyBuffer()); + + assertEquals(1, buf.getInt(0)); + assertEquals(2, buf.getInt(4)); + + assertEquals(1, buf.readInt()); + assertEquals(2, buf.readInt()); + assertFalse(buf.isReadable()); + + buf.release(); + } + + @Test + public void testGetReadShort() { + ByteBuf buf = buffer(((ByteBuffer) allocate(8) + .putShort((short) 1).putShort((short) 2).flip()).asReadOnlyBuffer()); + + assertEquals(1, buf.getShort(0)); + 
assertEquals(2, buf.getShort(2)); + + assertEquals(1, buf.readShort()); + assertEquals(2, buf.readShort()); + assertFalse(buf.isReadable()); + + buf.release(); + } + + @Test + public void testGetReadLong() { + ByteBuf buf = buffer(((ByteBuffer) allocate(16) + .putLong(1).putLong(2).flip()).asReadOnlyBuffer()); + + assertEquals(1, buf.getLong(0)); + assertEquals(2, buf.getLong(8)); + + assertEquals(1, buf.readLong()); + assertEquals(2, buf.readLong()); + assertFalse(buf.isReadable()); + + buf.release(); + } + + @Test + public void testGetBytesByteBuffer() { + byte[] bytes = {'a', 'b', 'c', 'd', 'e', 'f', 'g'}; + // Ensure destination buffer is bigger then what is in the ByteBuf. + final ByteBuffer nioBuffer = ByteBuffer.allocate(bytes.length + 1); + final ByteBuf buffer = buffer(((ByteBuffer) allocate(bytes.length) + .put(bytes).flip()).asReadOnlyBuffer()); + try { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + buffer.getBytes(buffer.readerIndex(), nioBuffer); + } + }); + } finally { + buffer.release(); + } + } + + @Test + public void testCopy() { + ByteBuf buf = buffer(((ByteBuffer) allocate(16).putLong(1).putLong(2).flip()).asReadOnlyBuffer()); + ByteBuf copy = buf.copy(); + + assertEquals(buf, copy); + + buf.release(); + copy.release(); + } + + @Test + public void testCopyWithOffset() { + ByteBuf buf = buffer(((ByteBuffer) allocate(16).putLong(1).putLong(2).flip()).asReadOnlyBuffer()); + ByteBuf copy = buf.copy(1, 9); + + assertEquals(buf.slice(1, 9), copy); + + buf.release(); + copy.release(); + } + + // Test for https://github.com/netty/netty/issues/1708 + @Test + public void testWrapBufferWithNonZeroPosition() { + ByteBuf buf = buffer(((ByteBuffer) allocate(16) + .putLong(1).flip().position(1)).asReadOnlyBuffer()); + + ByteBuf slice = buf.slice(); + assertEquals(buf, slice); + + buf.release(); + } + + @Test + public void testWrapBufferRoundTrip() { + ByteBuf buf = buffer(((ByteBuffer) 
allocate(16).putInt(1).putInt(2).flip()).asReadOnlyBuffer()); + + assertEquals(1, buf.readInt()); + + ByteBuffer nioBuffer = buf.nioBuffer(); + + // Ensure this can be accessed without throwing a BufferUnderflowException + assertEquals(2, nioBuffer.getInt()); + + buf.release(); + } + + @Test + public void testWrapMemoryMapped() throws Exception { + File file = PlatformDependent.createTempFile("netty-test", "tmp", null); + FileChannel output = null; + FileChannel input = null; + ByteBuf b1 = null; + ByteBuf b2 = null; + + try { + output = new RandomAccessFile(file, "rw").getChannel(); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + output.write(ByteBuffer.wrap(bytes)); + + input = new RandomAccessFile(file, "r").getChannel(); + ByteBuffer m = input.map(FileChannel.MapMode.READ_ONLY, 0, input.size()); + + b1 = buffer(m); + + ByteBuffer dup = m.duplicate(); + dup.position(2); + dup.limit(4); + + b2 = buffer(dup); + + assertEquals(b2, b1.slice(2, 2)); + } finally { + if (b1 != null) { + b1.release(); + } + if (b2 != null) { + b2.release(); + } + if (output != null) { + output.close(); + } + if (input != null) { + input.close(); + } + file.delete(); + } + } + + @Test + public void testMemoryAddress() { + ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertFalse(buf.hasMemoryAddress()); + try { + buf.memoryAddress(); + fail(); + } catch (UnsupportedOperationException expected) { + // expected + } + } finally { + buf.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBufferBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBufferBufTest.java new file mode 100644 index 0000000..e091174 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/ReadOnlyUnsafeDirectByteBufferBufTest.java @@ -0,0 +1,53 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * 
version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class ReadOnlyUnsafeDirectByteBufferBufTest extends ReadOnlyDirectByteBufferBufTest { + + /** + * Needs unsafe to run + */ + @BeforeAll + public static void assumeConditions() { + assumeTrue(PlatformDependent.hasUnsafe(), "sun.misc.Unsafe not found, skip tests"); + } + + @Override + protected ByteBuf buffer(ByteBuffer buffer) { + return new ReadOnlyUnsafeDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, buffer); + } + + @Test + @Override + public void testMemoryAddress() { + ByteBuf buf = buffer(allocate(8).asReadOnlyBuffer()); + try { + assertTrue(buf.hasMemoryAddress()); + buf.memoryAddress(); + } finally { + buf.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/RetainedDuplicatedByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/RetainedDuplicatedByteBufTest.java new file mode 100644 index 0000000..5bc5702 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/RetainedDuplicatedByteBufTest.java @@ -0,0 +1,32 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file 
except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.buffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RetainedDuplicatedByteBufTest extends DuplicatedByteBufTest { + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + ByteBuf wrapped = Unpooled.buffer(length, maxCapacity); + ByteBuf buffer = wrapped.retainedDuplicate(); + wrapped.release(); + + assertEquals(wrapped.writerIndex(), buffer.writerIndex()); + assertEquals(wrapped.readerIndex(), buffer.readerIndex()); + return buffer; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/RetainedSlicedByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/RetainedSlicedByteBufTest.java new file mode 100644 index 0000000..898ad2f --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/RetainedSlicedByteBufTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.buffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class RetainedSlicedByteBufTest extends SlicedByteBufTest { + + @Override + protected ByteBuf newSlice(ByteBuf buffer, int offset, int length) { + ByteBuf slice = buffer.retainedSlice(offset, length); + buffer.release(); + assertEquals(buffer.refCnt(), slice.refCnt()); + return slice; + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareByteBufTest.java new file mode 100644 index 0000000..7844cce --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareByteBufTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.ResourceLeakTracker; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayDeque; +import java.util.Queue; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SimpleLeakAwareByteBufTest extends BigEndianHeapByteBufTest { + private final Class clazz = leakClass(); + private final Queue> trackers = new ArrayDeque>(); + + @Override + protected final ByteBuf newBuffer(int capacity, int maxCapacity) { + return wrap(super.newBuffer(capacity, maxCapacity)); + } + + private ByteBuf wrap(ByteBuf buffer) { + NoopResourceLeakTracker tracker = new NoopResourceLeakTracker(); + ByteBuf leakAwareBuf = wrap(buffer, tracker); + trackers.add(tracker); + return leakAwareBuf; + } + + protected SimpleLeakAwareByteBuf wrap(ByteBuf buffer, ResourceLeakTracker tracker) { + return new SimpleLeakAwareByteBuf(buffer, tracker); + } + + @BeforeEach + @Override + public void init() { + super.init(); + trackers.clear(); + } + + @AfterEach + @Override + public void dispose() { + super.dispose(); + + for (;;) { + NoopResourceLeakTracker tracker = trackers.poll(); + + if (tracker == null) { + break; + } + assertTrue(tracker.get()); + } + } + + protected Class leakClass() { + return SimpleLeakAwareByteBuf.class; + } + + @Test + public void testWrapSlice() { + assertWrapped(newBuffer(8).slice()); + } + + @Test + public void testWrapSlice2() { + assertWrapped(newBuffer(8).slice(0, 1)); + } + + @Test + public void testWrapReadSlice() { + ByteBuf buffer = newBuffer(8); + if (buffer.isReadable()) { + assertWrapped(buffer.readSlice(1)); + } else { + assertTrue(buffer.release()); + } + } + + @Test + public void testWrapRetainedSlice() { + ByteBuf buffer = newBuffer(8); + assertWrapped(buffer.retainedSlice()); + assertTrue(buffer.release()); + } + + @Test + public void 
testWrapRetainedSlice2() { + ByteBuf buffer = newBuffer(8); + if (buffer.isReadable()) { + assertWrapped(buffer.retainedSlice(0, 1)); + } + assertTrue(buffer.release()); + } + + @Test + public void testWrapReadRetainedSlice() { + ByteBuf buffer = newBuffer(8); + if (buffer.isReadable()) { + assertWrapped(buffer.readRetainedSlice(1)); + } + assertTrue(buffer.release()); + } + + @Test + public void testWrapDuplicate() { + assertWrapped(newBuffer(8).duplicate()); + } + + @Test + public void testWrapRetainedDuplicate() { + ByteBuf buffer = newBuffer(8); + assertWrapped(buffer.retainedDuplicate()); + assertTrue(buffer.release()); + } + + @Test + public void testWrapReadOnly() { + assertWrapped(newBuffer(8).asReadOnly()); + } + + protected final void assertWrapped(ByteBuf buf) { + try { + assertSame(clazz, buf.getClass()); + } finally { + buf.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareCompositeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareCompositeByteBufTest.java new file mode 100644 index 0000000..43369c5 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/SimpleLeakAwareCompositeByteBufTest.java @@ -0,0 +1,165 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +import io.netty.util.ByteProcessor; +import io.netty.util.ResourceLeakTracker; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayDeque; +import java.util.Queue; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SimpleLeakAwareCompositeByteBufTest extends WrappedCompositeByteBufTest { + + private final Class clazz = leakClass(); + private final Queue> trackers = new ArrayDeque>(); + + @Override + protected final WrappedCompositeByteBuf wrap(CompositeByteBuf buffer) { + NoopResourceLeakTracker tracker = new NoopResourceLeakTracker(); + WrappedCompositeByteBuf leakAwareBuf = wrap(buffer, tracker); + trackers.add(tracker); + return leakAwareBuf; + } + + protected SimpleLeakAwareCompositeByteBuf wrap(CompositeByteBuf buffer, ResourceLeakTracker tracker) { + return new SimpleLeakAwareCompositeByteBuf(buffer, tracker); + } + + @BeforeEach + @Override + public void init() { + super.init(); + trackers.clear(); + } + + @AfterEach + @Override + public void dispose() { + super.dispose(); + + for (;;) { + NoopResourceLeakTracker tracker = trackers.poll(); + + if (tracker == null) { + break; + } + assertTrue(tracker.get()); + } + } + + protected Class leakClass() { + return SimpleLeakAwareByteBuf.class; + } + + @Test + public void testWrapSlice() { + assertWrapped(newBuffer(8).slice()); + } + + @Test + public void testWrapSlice2() { + assertWrapped(newBuffer(8).slice(0, 1)); + } + + @Test + public void testWrapReadSlice() { + ByteBuf buffer = newBuffer(8); + if (buffer.isReadable()) { + assertWrapped(buffer.readSlice(1)); + } else { + assertTrue(buffer.release()); + } + } + + @Test + public void testWrapRetainedSlice() 
{ + ByteBuf buffer = newBuffer(8); + assertWrapped(buffer.retainedSlice()); + assertTrue(buffer.release()); + } + + @Test + public void testWrapRetainedSlice2() { + ByteBuf buffer = newBuffer(8); + if (buffer.isReadable()) { + assertWrapped(buffer.retainedSlice(0, 1)); + } + assertTrue(buffer.release()); + } + + @Test + public void testWrapReadRetainedSlice() { + ByteBuf buffer = newBuffer(8); + if (buffer.isReadable()) { + assertWrapped(buffer.readRetainedSlice(1)); + } + assertTrue(buffer.release()); + } + + @Test + public void testWrapDuplicate() { + assertWrapped(newBuffer(8).duplicate()); + } + + @Test + public void testWrapRetainedDuplicate() { + ByteBuf buffer = newBuffer(8); + assertWrapped(buffer.retainedDuplicate()); + assertTrue(buffer.release()); + } + + @Test + public void testWrapReadOnly() { + assertWrapped(newBuffer(8).asReadOnly()); + } + + @Test + public void forEachByteUnderLeakDetectionShouldNotThrowException() { + CompositeByteBuf buf = (CompositeByteBuf) newBuffer(8); + assertThat(buf, CoreMatchers.instanceOf(SimpleLeakAwareCompositeByteBuf.class)); + CompositeByteBuf comp = (CompositeByteBuf) newBuffer(8); + assertThat(comp, CoreMatchers.instanceOf(SimpleLeakAwareCompositeByteBuf.class)); + + ByteBuf inner = comp.alloc().directBuffer(1).writeByte(0); + comp.addComponent(true, inner); + buf.addComponent(true, comp); + + assertEquals(-1, buf.forEachByte(new ByteProcessor() { + @Override + public boolean process(byte value) { + return true; + } + })); + assertTrue(buf.release()); + } + + protected final void assertWrapped(ByteBuf buf) { + try { + assertSame(clazz, buf.getClass()); + } finally { + buf.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java new file mode 100644 index 0000000..e6e20ff --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/SlicedByteBufTest.java @@ -0,0 +1,355 @@ +/* + * Copyright 2012 The Netty 
Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; + +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.nio.ByteBuffer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests sliced channel buffers + */ +public class SlicedByteBufTest extends AbstractByteBufTest { + + @Override + protected final ByteBuf newBuffer(int length, int maxCapacity) { + Assumptions.assumeTrue(maxCapacity == Integer.MAX_VALUE); + int offset = length == 0 ? 
0 : PlatformDependent.threadLocalRandom().nextInt(length); + ByteBuf buffer = Unpooled.buffer(length * 2); + ByteBuf slice = newSlice(buffer, offset, length); + assertEquals(0, slice.readerIndex()); + assertEquals(length, slice.writerIndex()); + return slice; + } + + protected ByteBuf newSlice(ByteBuf buffer, int offset, int length) { + return buffer.slice(offset, length); + } + + @Test + public void testIsContiguous() { + ByteBuf buf = newBuffer(4); + assertEquals(buf.unwrap().isContiguous(), buf.isContiguous()); + buf.release(); + } + + @Test + public void shouldNotAllowNullInConstructor() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new SlicedByteBuf(null, 0, 0); + } + }); + } + + @Test + @Override + public void testInternalNioBuffer() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testInternalNioBuffer(); + } + }); + } + + @Test + @Override + public void testDuplicateReadGatheringByteChannelMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + SlicedByteBufTest.super.testDuplicateReadGatheringByteChannelMultipleThreads(); + } + }); + } + + @Test + @Override + public void testSliceReadGatheringByteChannelMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + SlicedByteBufTest.super.testSliceReadGatheringByteChannelMultipleThreads(); + } + }); + } + + @Test + @Override + public void testDuplicateReadOutputStreamMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + SlicedByteBufTest.super.testDuplicateReadOutputStreamMultipleThreads(); + } + }); + } + + @Test + @Override + public void testSliceReadOutputStreamMultipleThreads() { + 
assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + SlicedByteBufTest.super.testSliceReadOutputStreamMultipleThreads(); + } + }); + } + + @Test + @Override + public void testDuplicateBytesInArrayMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + SlicedByteBufTest.super.testDuplicateBytesInArrayMultipleThreads(); + } + }); + } + + @Test + @Override + public void testSliceBytesInArrayMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + SlicedByteBufTest.super.testSliceBytesInArrayMultipleThreads(); + } + }); + } + + @Test + @Override + public void testNioBufferExposeOnlyRegion() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testNioBufferExposeOnlyRegion(); + } + }); + } + + @Test + @Override + public void testGetReadOnlyDirectDst() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testGetReadOnlyDirectDst(); + } + }); + } + + @Test + @Override + public void testGetReadOnlyHeapDst() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testGetReadOnlyHeapDst(); + } + }); + } + + @Test + @Override + public void testLittleEndianWithExpand() { + // ignore for SlicedByteBuf + } + + @Test + @Override + public void testReadBytes() { + // ignore for SlicedByteBuf + } + + @Test + @Override + public void testForEachByteDesc2() { + // Ignore for SlicedByteBuf + } + + @Test + @Override + public void testForEachByte2() { + // Ignore for SlicedByteBuf + } + + @Disabled("Sliced ByteBuf objects don't allow the capacity to change. 
So this test would fail and shouldn't be run") + @Override + public void testDuplicateCapacityChange() { + } + + @Disabled("Sliced ByteBuf objects don't allow the capacity to change. So this test would fail and shouldn't be run") + @Override + public void testRetainedDuplicateCapacityChange() { + } + + @Test + public void testReaderIndexAndMarks() { + ByteBuf wrapped = Unpooled.buffer(16); + try { + wrapped.writerIndex(14); + wrapped.readerIndex(2); + wrapped.markWriterIndex(); + wrapped.markReaderIndex(); + ByteBuf slice = wrapped.slice(4, 4); + assertEquals(0, slice.readerIndex()); + assertEquals(4, slice.writerIndex()); + + slice.readerIndex(slice.readerIndex() + 1); + slice.resetReaderIndex(); + assertEquals(0, slice.readerIndex()); + + slice.writerIndex(slice.writerIndex() - 1); + slice.resetWriterIndex(); + assertEquals(0, slice.writerIndex()); + } finally { + wrapped.release(); + } + } + + @Test + public void sliceEmptyNotLeak() { + ByteBuf buffer = Unpooled.buffer(8).retain(); + assertEquals(2, buffer.refCnt()); + + ByteBuf slice1 = buffer.slice(); + assertEquals(2, slice1.refCnt()); + + ByteBuf slice2 = slice1.slice(); + assertEquals(2, slice2.refCnt()); + + assertFalse(slice2.release()); + assertEquals(1, buffer.refCnt()); + assertEquals(1, slice1.refCnt()); + assertEquals(1, slice2.refCnt()); + + assertTrue(slice2.release()); + + assertEquals(0, buffer.refCnt()); + assertEquals(0, slice1.refCnt()); + assertEquals(0, slice2.refCnt()); + } + + @Override + @Test + public void testGetBytesByteBuffer() { + byte[] bytes = {'a', 'b', 'c', 'd', 'e', 'f', 'g'}; + // Ensure destination buffer is bigger then what is wrapped in the ByteBuf. 
+ final ByteBuffer nioBuffer = ByteBuffer.allocate(bytes.length + 1); + final ByteBuf wrappedBuffer = Unpooled.wrappedBuffer(bytes).slice(0, bytes.length - 1); + try { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + wrappedBuffer.getBytes(wrappedBuffer.readerIndex(), nioBuffer); + } + }); + } finally { + wrappedBuffer.release(); + } + } + + @Test + @Override + public void testWriteUsAsciiCharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testWriteUsAsciiCharSequenceExpand(); + } + }); + } + + @Test + @Override + public void testWriteUtf8CharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testWriteUtf8CharSequenceExpand(); + } + }); + } + + @Test + @Override + public void testWriteIso88591CharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testWriteIso88591CharSequenceExpand(); + } + }); + } + + @Test + @Override + public void testWriteUtf16CharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + SlicedByteBufTest.super.testWriteUtf16CharSequenceExpand(); + } + }); + } + + @Test + public void ensureWritableWithEnoughSpaceShouldNotThrow() { + ByteBuf slice = newBuffer(10); + ByteBuf unwrapped = slice.unwrap(); + unwrapped.writerIndex(unwrapped.writerIndex() + 5); + slice.writerIndex(slice.readerIndex()); + + // Run ensureWritable and verify this doesn't change any indexes. 
+ int originalWriterIndex = slice.writerIndex(); + int originalReadableBytes = slice.readableBytes(); + slice.ensureWritable(originalWriterIndex - slice.writerIndex()); + assertEquals(originalWriterIndex, slice.writerIndex()); + assertEquals(originalReadableBytes, slice.readableBytes()); + slice.release(); + } + + @Test + public void ensureWritableWithNotEnoughSpaceShouldThrow() { + final ByteBuf slice = newBuffer(10); + ByteBuf unwrapped = slice.unwrap(); + unwrapped.writerIndex(unwrapped.writerIndex() + 5); + try { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + slice.ensureWritable(1); + } + }); + } finally { + slice.release(); + } + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/UnpooledByteBufAllocatorTest.java b/netty-buffer/src/test/java/io/netty/buffer/UnpooledByteBufAllocatorTest.java new file mode 100644 index 0000000..bd51ef3 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/UnpooledByteBufAllocatorTest.java @@ -0,0 +1,29 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer; + +public class UnpooledByteBufAllocatorTest extends AbstractByteBufAllocatorTest { + + @Override + protected UnpooledByteBufAllocator newAllocator(boolean preferDirect) { + return new UnpooledByteBufAllocator(preferDirect); + } + + @Override + protected UnpooledByteBufAllocator newUnpooledAllocator() { + return new UnpooledByteBufAllocator(false); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/UnpooledTest.java b/netty-buffer/src/test/java/io/netty/buffer/UnpooledTest.java new file mode 100644 index 0000000..efc1daf --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/UnpooledTest.java @@ -0,0 +1,821 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.buffer;

import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import org.mockito.Mockito;

import java.io.InputStream;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.channels.ScatteringByteChannel;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import static io.netty.buffer.Unpooled.*;
import static io.netty.util.internal.EmptyArrays.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Tests channel buffers created through the {@link Unpooled} static factories:
 * wrapping, copying, equality/compare/hash utilities, zero-length handling
 * (which must yield the {@code EMPTY_BUFFER} singleton), unmodifiable views,
 * and little-endian accessors.
 */
public class UnpooledTest {

    private static final ByteBuf[] EMPTY_BYTE_BUFS = new ByteBuf[0];
    private static final byte[][] EMPTY_BYTES_2D = new byte[0][];

    @Test
    public void testCompositeWrappedBuffer() {
        ByteBuf header = buffer(12);
        ByteBuf payload = buffer(512);

        header.writeBytes(new byte[12]);
        payload.writeBytes(new byte[512]);

        ByteBuf buffer = wrappedBuffer(header, payload);

        // Wrapping must not consume from the components.
        assertEquals(12, header.readableBytes());
        assertEquals(512, payload.readableBytes());

        assertEquals(12 + 512, buffer.readableBytes());
        assertEquals(2, buffer.nioBufferCount());

        buffer.release();
    }

    @Test
    public void testHashCode() {
        // FIX(review): use the typed map instead of raw types; avoids the unchecked
        // Entry and makes the expected-hash lookup self-describing.
        Map<byte[], Integer> map = new LinkedHashMap<byte[], Integer>();
        map.put(EMPTY_BYTES, 1);
        map.put(new byte[] { 1 }, 32);
        map.put(new byte[] { 2 }, 33);
        map.put(new byte[] { 0, 1 }, 962);
        map.put(new byte[] { 1, 2 }, 994);
        map.put(new byte[] { 0, 1, 2, 3, 4, 5 }, 63504931);
        map.put(new byte[] { 6, 7, 8, 9, 0, 1 }, (int) 97180294697L);
        map.put(new byte[] { -1, -1, -1, (byte) 0xE1 }, 1);

        for (Entry<byte[], Integer> e: map.entrySet()) {
            ByteBuf buffer = wrappedBuffer(e.getKey());
            assertEquals(
                    e.getValue().intValue(),
                    ByteBufUtil.hashCode(buffer));
            buffer.release();
        }
    }

    @Test
    public void testEquals() {
        ByteBuf a, b;

        // Different length.
        a = wrappedBuffer(new byte[] { 1 });
        b = wrappedBuffer(new byte[] { 1, 2 });
        assertFalse(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Same content, same firstIndex, short length.
        a = wrappedBuffer(new byte[] { 1, 2, 3 });
        b = wrappedBuffer(new byte[] { 1, 2, 3 });
        assertTrue(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Same content, different firstIndex, short length.
        a = wrappedBuffer(new byte[] { 1, 2, 3 });
        b = wrappedBuffer(new byte[] { 0, 1, 2, 3, 4 }, 1, 3);
        assertTrue(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Different content, same firstIndex, short length.
        a = wrappedBuffer(new byte[] { 1, 2, 3 });
        b = wrappedBuffer(new byte[] { 1, 2, 4 });
        assertFalse(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Different content, different firstIndex, short length.
        a = wrappedBuffer(new byte[] { 1, 2, 3 });
        b = wrappedBuffer(new byte[] { 0, 1, 2, 4, 5 }, 1, 3);
        assertFalse(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Same content, same firstIndex, long length.
        a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        b = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        assertTrue(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Same content, different firstIndex, long length.
        a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        b = wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11}, 1, 10);
        assertTrue(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Different content, same firstIndex, long length.
        a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        b = wrappedBuffer(new byte[] { 1, 2, 3, 4, 6, 7, 8, 5, 9, 10 });
        assertFalse(ByteBufUtil.equals(a, b));
        a.release();
        b.release();

        // Different content, different firstIndex, long length.
        a = wrappedBuffer(new byte[] { 1, 2, 3, 4, 5, 6, 7, 8, 9, 10 });
        b = wrappedBuffer(new byte[] { 0, 1, 2, 3, 4, 6, 7, 8, 5, 9, 10, 11 }, 1, 10);
        assertFalse(ByteBufUtil.equals(a, b));
        a.release();
        b.release();
    }

    @Test
    public void testCompare() {
        // Buffers listed in strictly ascending ByteBufUtil.compare order.
        // FIX(review): typed list instead of a raw List.
        List<ByteBuf> expected = new ArrayList<ByteBuf>();
        expected.add(wrappedBuffer(new byte[]{1}));
        expected.add(wrappedBuffer(new byte[]{1, 2}));
        expected.add(wrappedBuffer(new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10}));
        expected.add(wrappedBuffer(new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12}));
        expected.add(wrappedBuffer(new byte[]{2}));
        expected.add(wrappedBuffer(new byte[]{2, 3}));
        expected.add(wrappedBuffer(new byte[]{2, 3, 4, 5, 6, 7, 8, 9, 10, 11}));
        expected.add(wrappedBuffer(new byte[]{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13}));
        expected.add(wrappedBuffer(new byte[]{2, 3, 4}, 1, 1));
        expected.add(wrappedBuffer(new byte[]{1, 2, 3, 4}, 2, 2));
        expected.add(wrappedBuffer(new byte[]{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, 1, 10));
        expected.add(wrappedBuffer(new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, 2, 12));
        expected.add(wrappedBuffer(new byte[]{2, 3, 4, 5}, 2, 1));
        expected.add(wrappedBuffer(new byte[]{1, 2, 3, 4, 5}, 3, 2));
        expected.add(wrappedBuffer(new byte[]{2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14}, 2, 10));
        expected.add(wrappedBuffer(new byte[]{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15}, 3, 12));

        for (int i = 0; i < expected.size(); i ++) {
            for (int j = 0; j < expected.size(); j ++) {
                if (i == j) {
                    assertEquals(0, ByteBufUtil.compare(expected.get(i), expected.get(j)));
                } else if (i < j) {
                    assertTrue(ByteBufUtil.compare(expected.get(i), expected.get(j)) < 0);
                } else {
                    assertTrue(ByteBufUtil.compare(expected.get(i), expected.get(j)) > 0);
                }
            }
        }
        for (ByteBuf buffer: expected) {
            buffer.release();
        }
    }

    @Test
    public void shouldReturnEmptyBufferWhenLengthIsZero() {
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(EMPTY_BYTES));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(new byte[8], 0, 0));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(new byte[8], 8, 0));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(ByteBuffer.allocateDirect(0)));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(EMPTY_BUFFER));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(EMPTY_BYTES_2D));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(new byte[][] { EMPTY_BYTES }));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(EMPTY_BYTE_BUFFERS));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(new ByteBuffer[] { ByteBuffer.allocate(0) }));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(ByteBuffer.allocate(0), ByteBuffer.allocate(0)));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(EMPTY_BYTE_BUFS));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(new ByteBuf[] { buffer(0) }));
        assertSameAndRelease(EMPTY_BUFFER, wrappedBuffer(buffer(0), buffer(0)));

        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(EMPTY_BYTES));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(new byte[8], 0, 0));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(new byte[8], 8, 0));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(ByteBuffer.allocateDirect(0)));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(EMPTY_BUFFER));
        assertSame(EMPTY_BUFFER, copiedBuffer(EMPTY_BYTES_2D));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(new byte[][] { EMPTY_BYTES }));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(EMPTY_BYTE_BUFFERS));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(new ByteBuffer[] { ByteBuffer.allocate(0) }));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(ByteBuffer.allocate(0), ByteBuffer.allocate(0)));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(EMPTY_BYTE_BUFS));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(new ByteBuf[] { buffer(0) }));
        assertSameAndRelease(EMPTY_BUFFER, copiedBuffer(buffer(0), buffer(0)));
    }

    @Test
    public void testCompare2() {
        // Comparison must treat bytes as unsigned: 0xFF > 0x00.
        ByteBuf expected = wrappedBuffer(new byte[]{(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF});
        ByteBuf actual = wrappedBuffer(new byte[]{(byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00});
        assertTrue(ByteBufUtil.compare(expected, actual) > 0);
        expected.release();
        actual.release();

        expected = wrappedBuffer(new byte[]{(byte) 0xFF});
        actual = wrappedBuffer(new byte[]{(byte) 0x00});
        assertTrue(ByteBufUtil.compare(expected, actual) > 0);
        expected.release();
        actual.release();
    }

    @Test
    public void shouldAllowEmptyBufferToCreateCompositeBuffer() {
        ByteBuf buf = wrappedBuffer(
                EMPTY_BUFFER,
                wrappedBuffer(new byte[16]).order(LITTLE_ENDIAN),
                EMPTY_BUFFER);
        try {
            assertEquals(16, buf.capacity());
        } finally {
            buf.release();
        }
    }

    @Test
    public void testWrappedBuffer() {
        ByteBuf buffer = wrappedBuffer(ByteBuffer.allocateDirect(16));
        assertEquals(16, buffer.capacity());
        buffer.release();

        assertEqualsAndRelease(
                wrappedBuffer(new byte[] { 1, 2, 3 }),
                wrappedBuffer(new byte[][] { new byte[] { 1, 2, 3 } }));

        assertEqualsAndRelease(
                wrappedBuffer(new byte[] { 1, 2, 3 }),
                wrappedBuffer(new byte[] { 1 }, new byte[] { 2 }, new byte[] { 3 }));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                wrappedBuffer(new ByteBuf[] { wrappedBuffer(new byte[] { 1, 2, 3 }) }));

        assertEqualsAndRelease(
                wrappedBuffer(new byte[] { 1, 2, 3 }),
                wrappedBuffer(wrappedBuffer(new byte[] { 1 }),
                        wrappedBuffer(new byte[] { 2 }), wrappedBuffer(new byte[] { 3 })));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                wrappedBuffer(new ByteBuffer[] { ByteBuffer.wrap(new byte[] { 1, 2, 3 }) }));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                wrappedBuffer(ByteBuffer.wrap(new byte[] { 1 }),
                        ByteBuffer.wrap(new byte[] { 2 }), ByteBuffer.wrap(new byte[] { 3 })));
    }

    @Test
    public void testSingleWrappedByteBufReleased() {
        ByteBuf buf = buffer(12).writeByte(0);
        ByteBuf wrapped = wrappedBuffer(buf);
        assertTrue(wrapped.release());
        assertEquals(0, buf.refCnt());
    }

    @Test
    public void testSingleUnReadableWrappedByteBufReleased() {
        ByteBuf buf = buffer(12);
        ByteBuf wrapped = wrappedBuffer(buf);
        assertFalse(wrapped.release()); // EMPTY_BUFFER cannot be released
        assertEquals(0, buf.refCnt());
    }

    @Test
    public void testMultiByteBufReleased() {
        ByteBuf buf1 = buffer(12).writeByte(0);
        ByteBuf buf2 = buffer(12).writeByte(0);
        ByteBuf wrapped = wrappedBuffer(16, buf1, buf2);
        assertTrue(wrapped.release());
        assertEquals(0, buf1.refCnt());
        assertEquals(0, buf2.refCnt());
    }

    @Test
    public void testMultiUnReadableByteBufReleased() {
        ByteBuf buf1 = buffer(12);
        ByteBuf buf2 = buffer(12);
        ByteBuf wrapped = wrappedBuffer(16, buf1, buf2);
        assertFalse(wrapped.release()); // EMPTY_BUFFER cannot be released
        assertEquals(0, buf1.refCnt());
        assertEquals(0, buf2.refCnt());
    }

    @Test
    public void testCopiedBufferUtf8() {
        testCopiedBufferCharSequence("Some UTF_8 like äÄ∏ŒŒ", CharsetUtil.UTF_8);
    }

    @Test
    public void testCopiedBufferAscii() {
        testCopiedBufferCharSequence("Some US_ASCII", CharsetUtil.US_ASCII);
    }

    @Test
    public void testCopiedBufferSomeOtherCharset() {
        testCopiedBufferCharSequence("Some ISO_8859_1", CharsetUtil.ISO_8859_1);
    }

    /** Round-trips {@code sequence} through copiedBuffer/toString in {@code charset}. */
    private static void testCopiedBufferCharSequence(CharSequence sequence, Charset charset) {
        ByteBuf copied = copiedBuffer(sequence, charset);
        try {
            assertEquals(sequence, copied.toString(charset));
        } finally {
            copied.release();
        }
    }

    @Test
    public void testCopiedBuffer() {
        ByteBuf copied = copiedBuffer(ByteBuffer.allocateDirect(16));
        assertEquals(16, copied.capacity());
        copied.release();

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                copiedBuffer(new byte[][] { new byte[] { 1, 2, 3 } }));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                copiedBuffer(new byte[] { 1 }, new byte[] { 2 }, new byte[] { 3 }));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                copiedBuffer(new ByteBuf[] { wrappedBuffer(new byte[] { 1, 2, 3 })}));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                copiedBuffer(wrappedBuffer(new byte[] { 1 }),
                        wrappedBuffer(new byte[] { 2 }), wrappedBuffer(new byte[] { 3 })));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                copiedBuffer(new ByteBuffer[] { ByteBuffer.wrap(new byte[] { 1, 2, 3 }) }));

        assertEqualsAndRelease(wrappedBuffer(new byte[] { 1, 2, 3 }),
                copiedBuffer(ByteBuffer.wrap(new byte[] { 1 }),
                        ByteBuffer.wrap(new byte[] { 2 }), ByteBuffer.wrap(new byte[] { 3 })));
    }

    /** Asserts content equality, then releases both buffers. */
    private static void assertEqualsAndRelease(ByteBuf expected, ByteBuf actual) {
        assertEquals(expected, actual);
        expected.release();
        actual.release();
    }

    /**
     * Asserts reference identity (callers expect the {@code EMPTY_BUFFER} singleton,
     * not merely an equal buffer), then releases both.
     * FIX(review): previously used assertEquals, which silently weakened every
     * singleton check in shouldReturnEmptyBufferWhenLengthIsZero.
     */
    private static void assertSameAndRelease(ByteBuf expected, ByteBuf actual) {
        assertSame(expected, actual);
        expected.release();
        actual.release();
    }

    @Test
    public void testHexDump() {
        assertEquals("", ByteBufUtil.hexDump(EMPTY_BUFFER));

        ByteBuf buffer = wrappedBuffer(new byte[]{ 0x12, 0x34, 0x56 });
        assertEquals("123456", ByteBufUtil.hexDump(buffer));
        buffer.release();

        buffer = wrappedBuffer(new byte[]{
                0x12, 0x34, 0x56, 0x78,
                (byte) 0x90, (byte) 0xAB, (byte) 0xCD, (byte) 0xEF
        });
        assertEquals("1234567890abcdef", ByteBufUtil.hexDump(buffer));
        buffer.release();
    }

    @Test
    public void testSwapMedium() {
        assertEquals(0x563412, ByteBufUtil.swapMedium(0x123456));
        assertEquals(0x80, ByteBufUtil.swapMedium(0x800000));
    }

    @Test
    public void testUnmodifiableBuffer() throws Exception {
        ByteBuf buf = unmodifiableBuffer(buffer(16));

        try {
            buf.discardReadBytes();
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setByte(0, (byte) 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setBytes(0, EMPTY_BUFFER, 0, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setBytes(0, EMPTY_BYTES, 0, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setBytes(0, ByteBuffer.allocate(0));
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setShort(0, (short) 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setMedium(0, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setInt(0, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        try {
            buf.setLong(0, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }

        // The rejection must happen before any byte is pulled from the source.
        InputStream inputStream = Mockito.mock(InputStream.class);
        try {
            buf.setBytes(0, inputStream, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }
        Mockito.verifyZeroInteractions(inputStream);

        ScatteringByteChannel scatteringByteChannel = Mockito.mock(ScatteringByteChannel.class);
        try {
            buf.setBytes(0, scatteringByteChannel, 0);
            fail();
        } catch (UnsupportedOperationException e) {
            // Expected
        }
        Mockito.verifyZeroInteractions(scatteringByteChannel);
        buf.release();
    }

    @Test
    public void testWrapSingleInt() {
        ByteBuf buffer = copyInt(42);
        assertEquals(4, buffer.capacity());
        assertEquals(42, buffer.readInt());
        assertFalse(buffer.isReadable());
        buffer.release();
    }

    @Test
    public void testWrapInt() {
        ByteBuf buffer = copyInt(1, 4);
        assertEquals(8, buffer.capacity());
        assertEquals(1, buffer.readInt());
        assertEquals(4, buffer.readInt());
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyInt(null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyInt(new int[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapSingleShort() {
        ByteBuf buffer = copyShort(42);
        assertEquals(2, buffer.capacity());
        assertEquals(42, buffer.readShort());
        assertFalse(buffer.isReadable());
        buffer.release();
    }

    @Test
    public void testWrapShortFromShortArray() {
        ByteBuf buffer = copyShort(new short[]{1, 4});
        assertEquals(4, buffer.capacity());
        assertEquals(1, buffer.readShort());
        assertEquals(4, buffer.readShort());
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyShort((short[]) null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyShort(new short[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapShortFromIntArray() {
        ByteBuf buffer = copyShort(1, 4);
        assertEquals(4, buffer.capacity());
        assertEquals(1, buffer.readShort());
        assertEquals(4, buffer.readShort());
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyShort((int[]) null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyShort(new int[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapSingleMedium() {
        ByteBuf buffer = copyMedium(42);
        assertEquals(3, buffer.capacity());
        assertEquals(42, buffer.readMedium());
        assertFalse(buffer.isReadable());
        buffer.release();
    }

    @Test
    public void testWrapMedium() {
        ByteBuf buffer = copyMedium(1, 4);
        assertEquals(6, buffer.capacity());
        assertEquals(1, buffer.readMedium());
        assertEquals(4, buffer.readMedium());
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyMedium(null);
        // FIX(review): assert on the buffer we just assigned; the old code called
        // copyMedium(null) a second time and checked that fresh object instead.
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyMedium(new int[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapSingleLong() {
        ByteBuf buffer = copyLong(42);
        assertEquals(8, buffer.capacity());
        assertEquals(42, buffer.readLong());
        assertFalse(buffer.isReadable());
        buffer.release();
    }

    @Test
    public void testWrapLong() {
        ByteBuf buffer = copyLong(1, 4);
        assertEquals(16, buffer.capacity());
        assertEquals(1, buffer.readLong());
        assertEquals(4, buffer.readLong());
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyLong(null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyLong(new long[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapSingleFloat() {
        ByteBuf buffer = copyFloat(42);
        assertEquals(4, buffer.capacity());
        assertEquals(42, buffer.readFloat(), 0.01);
        assertFalse(buffer.isReadable());
        buffer.release();
    }

    @Test
    public void testWrapFloat() {
        ByteBuf buffer = copyFloat(1, 4);
        assertEquals(8, buffer.capacity());
        assertEquals(1, buffer.readFloat(), 0.01);
        assertEquals(4, buffer.readFloat(), 0.01);
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyFloat(null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyFloat(new float[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapSingleDouble() {
        ByteBuf buffer = copyDouble(42);
        assertEquals(8, buffer.capacity());
        assertEquals(42, buffer.readDouble(), 0.01);
        assertFalse(buffer.isReadable());
        buffer.release();
    }

    @Test
    public void testWrapDouble() {
        ByteBuf buffer = copyDouble(1, 4);
        assertEquals(16, buffer.capacity());
        assertEquals(1, buffer.readDouble(), 0.01);
        assertEquals(4, buffer.readDouble(), 0.01);
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyDouble(null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyDouble(new double[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void testWrapBoolean() {
        ByteBuf buffer = copyBoolean(true, false);
        assertEquals(2, buffer.capacity());
        assertTrue(buffer.readBoolean());
        assertFalse(buffer.readBoolean());
        assertFalse(buffer.isReadable());
        buffer.release();

        buffer = copyBoolean(null);
        assertEquals(0, buffer.capacity());
        buffer.release();

        buffer = copyBoolean(new boolean[] {});
        assertEquals(0, buffer.capacity());
        buffer.release();
    }

    @Test
    public void wrappedReadOnlyDirectBuffer() {
        ByteBuffer buffer = ByteBuffer.allocateDirect(12);
        for (int i = 0; i < 12; i++) {
            buffer.put((byte) i);
        }
        buffer.flip();
        ByteBuf wrapped = wrappedBuffer(buffer.asReadOnlyBuffer());
        for (int i = 0; i < 12; i++) {
            assertEquals((byte) i, wrapped.readByte());
        }
        wrapped.release();
    }

    @Test
    public void skipBytesNegativeLength() {
        final ByteBuf buf = buffer(8);
        try {
            assertThrows(IllegalArgumentException.class, new Executable() {
                @Override
                public void execute() throws Throwable {
                    buf.skipBytes(-1);
                }
            });
        } finally {
            buf.release();
        }
    }

    // See https://github.com/netty/netty/issues/5597
    @Test
    public void testWrapByteBufArrayStartsWithNonReadable() {
        ByteBuf buffer1 = buffer(8);
        ByteBuf buffer2 = buffer(8).writeZero(8); // Ensure the ByteBuf is readable.
        ByteBuf buffer3 = buffer(8);
        ByteBuf buffer4 = buffer(8).writeZero(8); // Ensure the ByteBuf is readable.

        ByteBuf wrapped = wrappedBuffer(buffer1, buffer2, buffer3, buffer4);
        assertEquals(16, wrapped.readableBytes());
        assertTrue(wrapped.release());
        assertEquals(0, buffer1.refCnt());
        assertEquals(0, buffer2.refCnt());
        assertEquals(0, buffer3.refCnt());
        assertEquals(0, buffer4.refCnt());
        assertEquals(0, wrapped.refCnt());
    }

    @Test
    public void testGetBytesByteBuffer() {
        byte[] bytes = {'a', 'b', 'c', 'd', 'e', 'f', 'g'};
        // Ensure destination buffer is bigger then what is wrapped in the ByteBuf.
        final ByteBuffer nioBuffer = ByteBuffer.allocate(bytes.length + 1);
        final ByteBuf wrappedBuffer = wrappedBuffer(bytes);
        try {
            assertThrows(IndexOutOfBoundsException.class, new Executable() {
                @Override
                public void execute() throws Throwable {
                    wrappedBuffer.getBytes(wrappedBuffer.readerIndex(), nioBuffer);
                }
            });
        } finally {
            wrappedBuffer.release();
        }
    }

    @Test
    public void testGetBytesByteBuffer2() {
        byte[] bytes = {'a', 'b', 'c', 'd', 'e', 'f', 'g'};
        // Ensure destination buffer is bigger then what is wrapped in the ByteBuf.
        final ByteBuffer nioBuffer = ByteBuffer.allocate(bytes.length + 1);
        final ByteBuf wrappedBuffer = wrappedBuffer(bytes, 0, bytes.length);
        try {
            assertThrows(IndexOutOfBoundsException.class, new Executable() {
                @Override
                public void execute() throws Throwable {
                    wrappedBuffer.getBytes(wrappedBuffer.readerIndex(), nioBuffer);
                }
            });
        } finally {
            wrappedBuffer.release();
        }
    }

    @SuppressWarnings("deprecation")
    @Test
    public void littleEndianWriteOnLittleEndianBufferMustStoreLittleEndianValue() {
        ByteBuf b = buffer(1024).order(ByteOrder.LITTLE_ENDIAN);

        b.writeShortLE(0x0102);
        assertEquals((short) 0x0102, b.getShortLE(0));
        assertEquals((short) 0x0102, b.getShort(0));
        b.clear();

        b.writeMediumLE(0x010203);
        assertEquals(0x010203, b.getMediumLE(0));
        assertEquals(0x010203, b.getMedium(0));
        b.clear();

        b.writeIntLE(0x01020304);
        assertEquals(0x01020304, b.getIntLE(0));
        assertEquals(0x01020304, b.getInt(0));
        b.clear();

        b.writeLongLE(0x0102030405060708L);
        assertEquals(0x0102030405060708L, b.getLongLE(0));
        assertEquals(0x0102030405060708L, b.getLong(0));
    }

    @Test
    public void littleEndianWriteOnDefaultBufferMustStoreLittleEndianValue() {
        ByteBuf b = buffer(1024);

        b.writeShortLE(0x0102);
        assertEquals((short) 0x0102, b.getShortLE(0));
        assertEquals((short) 0x0201, b.getShort(0));
        b.clear();

        b.writeMediumLE(0x010203);
        assertEquals(0x010203, b.getMediumLE(0));
        assertEquals(0x030201, b.getMedium(0));
        b.clear();

        b.writeIntLE(0x01020304);
        assertEquals(0x01020304, b.getIntLE(0));
        assertEquals(0x04030201, b.getInt(0));
        b.clear();

        b.writeLongLE(0x0102030405060708L);
        assertEquals(0x0102030405060708L, b.getLongLE(0));
        assertEquals(0x0807060504030201L, b.getLong(0));
    }
}
/dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/UnreleaseableByteBufTest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import org.junit.jupiter.api.Test; + +import static io.netty.buffer.Unpooled.buffer; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UnreleaseableByteBufTest { + + @Test + public void testCantRelease() { + ByteBuf buf = Unpooled.unreleasableBuffer(Unpooled.copyInt(1)); + assertEquals(1, buf.refCnt()); + assertFalse(buf.release()); + assertEquals(1, buf.refCnt()); + assertFalse(buf.release()); + assertEquals(1, buf.refCnt()); + + buf.retain(5); + assertEquals(1, buf.refCnt()); + + buf.retain(); + assertEquals(1, buf.refCnt()); + + assertTrue(buf.unwrap().release()); + assertEquals(0, buf.refCnt()); + } + + @Test + public void testWrappedReadOnly() { + ByteBuf buf = Unpooled.unreleasableBuffer(buffer(1).asReadOnly()); + assertSame(buf, buf.asReadOnly()); + + assertTrue(buf.unwrap().release()); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/UnsafeByteBufUtilTest.java b/netty-buffer/src/test/java/io/netty/buffer/UnsafeByteBufUtilTest.java new file mode 100644 index 
0000000..307e7ee --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/UnsafeByteBufUtilTest.java @@ -0,0 +1,252 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.nio.ByteBuffer; + +import static io.netty.util.internal.PlatformDependent.directBufferAddress; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class UnsafeByteBufUtilTest { + @BeforeEach + public void checkHasUnsafe() { + Assumptions.assumeTrue(PlatformDependent.hasUnsafe(), "sun.misc.Unsafe not found, skip tests"); + } + + @Test + public void testSetBytesOnReadOnlyByteBuffer() throws Exception { + byte[] testData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + int length = testData.length; + + ByteBuffer readOnlyBuffer = ByteBuffer.wrap(testData).asReadOnlyBuffer(); + + UnpooledByteBufAllocator alloc = new UnpooledByteBufAllocator(true); + UnpooledDirectByteBuf targetBuffer = new UnpooledDirectByteBuf(alloc, length, length); + + try { + 
UnsafeByteBufUtil.setBytes(targetBuffer, directBufferAddress(targetBuffer.nioBuffer()), 0, readOnlyBuffer); + + byte[] check = new byte[length]; + targetBuffer.getBytes(0, check, 0, length); + + assertArrayEquals(testData, check, "The byte array's copy does not equal the original"); + } finally { + targetBuffer.release(); + } + } + + @Test + public void testSetBytesOnReadOnlyByteBufferWithPooledAlloc() throws Exception { + byte[] testData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + int length = testData.length; + + ByteBuffer readOnlyBuffer = ByteBuffer.wrap(testData).asReadOnlyBuffer(); + + int pageSize = 4096; + + // create memory pool with one page + ByteBufAllocator alloc = new PooledByteBufAllocator(true, 1, 1, pageSize, 0); + UnpooledDirectByteBuf targetBuffer = new UnpooledDirectByteBuf(alloc, length, length); + + ByteBuf b1 = alloc.heapBuffer(16); + ByteBuf b2 = alloc.heapBuffer(16); + + try { + // just check that two following buffers share same array but different offset + assertEquals(pageSize, b1.array().length); + assertArrayEquals(b1.array(), b2.array()); + assertNotEquals(b1.arrayOffset(), b2.arrayOffset()); + + UnsafeByteBufUtil.setBytes(targetBuffer, directBufferAddress(targetBuffer.nioBuffer()), 0, readOnlyBuffer); + + byte[] check = new byte[length]; + targetBuffer.getBytes(0, check, 0, length); + + assertArrayEquals(testData, check, "The byte array's copy does not equal the original"); + } finally { + targetBuffer.release(); + b1.release(); + b2.release(); + } + } + + @Test + public void testSetBytesWithByteArray() { + final byte[] testData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + final int length = testData.length; + + final UnpooledByteBufAllocator alloc = new UnpooledByteBufAllocator(true); + final UnpooledDirectByteBuf targetBuffer = new UnpooledDirectByteBuf(alloc, length, length); + + try { + UnsafeByteBufUtil.setBytes(targetBuffer, + directBufferAddress(targetBuffer.nioBuffer()), 0, testData, 0, length); + + final byte[] check = new byte[length]; + 
targetBuffer.getBytes(0, check, 0, length); + + assertArrayEquals(testData, check, "The byte array's copy does not equal the original"); + } finally { + targetBuffer.release(); + } + } + + @Test + public void testSetBytesWithZeroLength() { + final byte[] testData = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}; + final int length = testData.length; + + final UnpooledByteBufAllocator alloc = new UnpooledByteBufAllocator(true); + final UnpooledDirectByteBuf targetBuffer = new UnpooledDirectByteBuf(alloc, length, length); + + try { + final byte[] beforeSet = new byte[length]; + targetBuffer.getBytes(0, beforeSet, 0, length); + + UnsafeByteBufUtil.setBytes(targetBuffer, + directBufferAddress(targetBuffer.nioBuffer()), 0, testData, 0, 0); + + final byte[] check = new byte[length]; + targetBuffer.getBytes(0, check, 0, length); + + assertArrayEquals(beforeSet, check); + } finally { + targetBuffer.release(); + } + } + + @Test + public void testSetBytesWithNullByteArray() { + + final UnpooledByteBufAllocator alloc = new UnpooledByteBufAllocator(true); + final UnpooledDirectByteBuf targetBuffer = new UnpooledDirectByteBuf(alloc, 8, 8); + + try { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + UnsafeByteBufUtil.setBytes(targetBuffer, + directBufferAddress(targetBuffer.nioBuffer()), 0, (byte[]) null, 0, 8); + } + }); + } finally { + targetBuffer.release(); + } + } + + @Test + public void testSetBytesOutOfBounds() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + // negative index + testSetBytesOutOfBounds0(4, 4, -1, 0, 4); + } + }); + } + + @Test + public void testSetBytesOutOfBounds2() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + // negative length + testSetBytesOutOfBounds0(4, 4, 0, 0, -1); + } + }); + } + + @Test + public void testSetBytesOutOfBounds3() { + assertThrows(IndexOutOfBoundsException.class, new 
Executable() { + @Override + public void execute() { + // buffer length oversize + testSetBytesOutOfBounds0(4, 8, 0, 0, 5); + } + }); + } + + @Test + public void testSetBytesOutOfBounds4() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + // buffer length oversize + testSetBytesOutOfBounds0(4, 4, 3, 0, 3); + } + }); + } + + @Test + public void testSetBytesOutOfBounds5() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + // negative srcIndex + testSetBytesOutOfBounds0(4, 4, 0, -1, 4); + } + }); + } + + @Test + public void testSetBytesOutOfBounds6() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + // src length oversize + testSetBytesOutOfBounds0(8, 4, 0, 0, 5); + } + }); + } + + @Test + public void testSetBytesOutOfBounds7() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + // src length oversize + testSetBytesOutOfBounds0(4, 4, 0, 1, 4); + } + }); + } + + private static void testSetBytesOutOfBounds0(int lengthOfBuffer, + int lengthOfBytes, + int index, + int srcIndex, + int length) { + final UnpooledByteBufAllocator alloc = new UnpooledByteBufAllocator(true); + final UnpooledDirectByteBuf targetBuffer = new UnpooledDirectByteBuf(alloc, lengthOfBuffer, lengthOfBuffer); + + try { + UnsafeByteBufUtil.setBytes(targetBuffer, + directBufferAddress(targetBuffer.nioBuffer()), index, new byte[lengthOfBytes], srcIndex, length); + } finally { + targetBuffer.release(); + } + } + +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/WrappedCompositeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/WrappedCompositeByteBufTest.java new file mode 100644 index 0000000..c4993c8 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/WrappedCompositeByteBufTest.java @@ -0,0 +1,33 @@ +/* + * Copyright 2016 The Netty Project + * + 
* The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +public class WrappedCompositeByteBufTest extends BigEndianCompositeByteBufTest { + + @Override + protected final ByteBuf newBuffer(int length, int maxCapacity) { + return wrap((CompositeByteBuf) super.newBuffer(length, maxCapacity)); + } + + protected WrappedCompositeByteBuf wrap(CompositeByteBuf buffer) { + return new WrappedCompositeByteBuf(buffer); + } + + @Override + protected CompositeByteBuf newCompositeBuffer() { + return wrap(super.newCompositeBuffer()); + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/WrappedUnpooledUnsafeByteBufTest.java b/netty-buffer/src/test/java/io/netty/buffer/WrappedUnpooledUnsafeByteBufTest.java new file mode 100644 index 0000000..7a1563d --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/WrappedUnpooledUnsafeByteBufTest.java @@ -0,0 +1,264 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class WrappedUnpooledUnsafeByteBufTest extends BigEndianUnsafeDirectByteBufTest { + + @BeforeEach + @Override + public void init() { + Assumptions.assumeTrue(PlatformDependent.useDirectBufferNoCleaner(), + "PlatformDependent.useDirectBufferNoCleaner() returned false, skip tests"); + super.init(); + } + + @Override + protected ByteBuf newBuffer(int length, int maxCapacity) { + Assumptions.assumeTrue(maxCapacity == Integer.MAX_VALUE); + + return new WrappedUnpooledUnsafeDirectByteBuf(UnpooledByteBufAllocator.DEFAULT, + PlatformDependent.allocateMemory(length), length, true); + } + + @Test + @Override + public void testInternalNioBuffer() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testInternalNioBuffer(); + } + }); + } + + @Test + @Override + public void testDuplicateReadGatheringByteChannelMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + WrappedUnpooledUnsafeByteBufTest.super.testDuplicateReadGatheringByteChannelMultipleThreads(); + } + }); + } + + @Test + @Override + public void testSliceReadGatheringByteChannelMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + WrappedUnpooledUnsafeByteBufTest.super.testSliceReadGatheringByteChannelMultipleThreads(); + } + }); + } + + @Test + @Override + public void testDuplicateReadOutputStreamMultipleThreads() { + 
assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + WrappedUnpooledUnsafeByteBufTest.super.testDuplicateReadOutputStreamMultipleThreads(); + } + }); + } + + @Test + @Override + public void testSliceReadOutputStreamMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + WrappedUnpooledUnsafeByteBufTest.super.testSliceReadOutputStreamMultipleThreads(); + } + }); + } + + @Test + @Override + public void testDuplicateBytesInArrayMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + WrappedUnpooledUnsafeByteBufTest.super.testDuplicateBytesInArrayMultipleThreads(); + } + }); + } + + @Test + @Override + public void testSliceBytesInArrayMultipleThreads() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Exception { + WrappedUnpooledUnsafeByteBufTest.super.testSliceBytesInArrayMultipleThreads(); + } + }); + } + + @Test + @Override + public void testNioBufferExposeOnlyRegion() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testNioBufferExposeOnlyRegion(); + } + }); + } + + @Test + @Override + public void testGetReadOnlyDirectDst() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testGetReadOnlyDirectDst(); + } + }); + } + + @Test + @Override + public void testGetReadOnlyHeapDst() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testGetReadOnlyHeapDst(); + } + }); + } + + @Test + @Override + public void testReadBytes() { + assertThrows(IndexOutOfBoundsException.class, new 
Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testReadBytes(); + } + }); + } + + @Test + @Override + public void testDuplicateCapacityChange() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testDuplicateCapacityChange(); + } + }); + } + + @Test + @Override + public void testRetainedDuplicateCapacityChange() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testRetainedDuplicateCapacityChange(); + } + }); + } + + @Test + @Override + public void testLittleEndianWithExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testLittleEndianWithExpand(); + } + }); + } + + @Test + @Override + public void testWriteUsAsciiCharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testWriteUsAsciiCharSequenceExpand(); + } + }); + } + + @Test + @Override + public void testWriteUtf8CharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testWriteUtf8CharSequenceExpand(); + } + }); + } + + @Test + @Override + public void testWriteIso88591CharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testWriteIso88591CharSequenceExpand(); + } + }); + } + + @Test + @Override + public void testWriteUtf16CharSequenceExpand() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testWriteUtf16CharSequenceExpand(); + } + }); + } + 
+ @Test + @Override + public void testGetBytesByteBuffer() { + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() { + WrappedUnpooledUnsafeByteBufTest.super.testGetBytesByteBuffer(); + } + }); + } + + @Test + @Override + public void testForEachByteDesc2() { + // Ignore + } + + @Test + @Override + public void testForEachByte2() { + // Ignore + } +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/search/BitapSearchProcessorFactoryTest.java b/netty-buffer/src/test/java/io/netty/buffer/search/BitapSearchProcessorFactoryTest.java new file mode 100644 index 0000000..42b6849 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/search/BitapSearchProcessorFactoryTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer.search; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class BitapSearchProcessorFactoryTest { + + @Test + public void testAcceptMaximumLengthNeedle() { + new BitapSearchProcessorFactory(new byte[64]); + } + + @Test + public void testRejectTooLongNeedle() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new BitapSearchProcessorFactory(new byte[65]); + } + }); + } + +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/search/MultiSearchProcessorTest.java b/netty-buffer/src/test/java/io/netty/buffer/search/MultiSearchProcessorTest.java new file mode 100644 index 0000000..1a48982 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/search/MultiSearchProcessorTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.buffer.search; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MultiSearchProcessorTest { + + @Test + public void testSearchForMultiple() { + final ByteBuf haystack = Unpooled.copiedBuffer("one two three one", CharsetUtil.UTF_8); + final int length = haystack.readableBytes(); + + final MultiSearchProcessor processor = AbstractMultiSearchProcessorFactory.newAhoCorasicSearchProcessorFactory( + bytes("one"), + bytes("two"), + bytes("three") + ).newSearchProcessor(); + + assertEquals(-1, processor.getFoundNeedleId()); + + assertEquals(2, haystack.forEachByte(processor)); + assertEquals(0, processor.getFoundNeedleId()); // index of "one" in needles[] + + assertEquals(6, haystack.forEachByte(3, length - 3, processor)); + assertEquals(1, processor.getFoundNeedleId()); // index of "two" in needles[] + + assertEquals(12, haystack.forEachByte(7, length - 7, processor)); + assertEquals(2, processor.getFoundNeedleId()); // index of "three" in needles[] + + assertEquals(16, haystack.forEachByte(13, length - 13, processor)); + assertEquals(0, processor.getFoundNeedleId()); // index of "one" in needles[] + + assertEquals(-1, haystack.forEachByte(17, length - 17, processor)); + + haystack.release(); + } + + @Test + public void testSearchForMultipleOverlapping() { + final ByteBuf haystack = Unpooled.copiedBuffer("abcd", CharsetUtil.UTF_8); + final int length = haystack.readableBytes(); + + final MultiSearchProcessor processor = AbstractMultiSearchProcessorFactory.newAhoCorasicSearchProcessorFactory( + bytes("ab"), + bytes("bc"), + bytes("cd") + ).newSearchProcessor(); + + assertEquals(1, haystack.forEachByte(processor)); + assertEquals(0, processor.getFoundNeedleId()); // index of "ab" in needles[] + + assertEquals(2, haystack.forEachByte(2, length - 2, processor)); + assertEquals(1, 
processor.getFoundNeedleId()); // index of "bc" in needles[] + + assertEquals(3, haystack.forEachByte(3, length - 3, processor)); + assertEquals(2, processor.getFoundNeedleId()); // index of "cd" in needles[] + + haystack.release(); + } + + @Test + public void findLongerNeedleInCaseOfSuffixMatch() { + final ByteBuf haystack = Unpooled.copiedBuffer("xabcx", CharsetUtil.UTF_8); + + final MultiSearchProcessor processor1 = AbstractMultiSearchProcessorFactory.newAhoCorasicSearchProcessorFactory( + bytes("abc"), + bytes("bc") + ).newSearchProcessor(); + + assertEquals(3, haystack.forEachByte(processor1)); // end of "abc" in haystack + assertEquals(0, processor1.getFoundNeedleId()); // index of "abc" in needles[] + + final MultiSearchProcessor processor2 = AbstractMultiSearchProcessorFactory.newAhoCorasicSearchProcessorFactory( + bytes("bc"), + bytes("abc") + ).newSearchProcessor(); + + assertEquals(3, haystack.forEachByte(processor2)); // end of "abc" in haystack + assertEquals(1, processor2.getFoundNeedleId()); // index of "abc" in needles[] + + haystack.release(); + } + + private static byte[] bytes(String s) { + return s.getBytes(CharsetUtil.UTF_8); + } + +} diff --git a/netty-buffer/src/test/java/io/netty/buffer/search/SearchProcessorTest.java b/netty-buffer/src/test/java/io/netty/buffer/search/SearchProcessorTest.java new file mode 100644 index 0000000..b24caf0 --- /dev/null +++ b/netty-buffer/src/test/java/io/netty/buffer/search/SearchProcessorTest.java @@ -0,0 +1,167 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.buffer.search; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.EnumSource; + +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class SearchProcessorTest { + + private enum Algorithm { + KNUTH_MORRIS_PRATT { + @Override + SearchProcessorFactory newFactory(byte[] needle) { + return AbstractSearchProcessorFactory.newKmpSearchProcessorFactory(needle); + } + }, + BITAP { + @Override + SearchProcessorFactory newFactory(byte[] needle) { + return AbstractSearchProcessorFactory.newBitapSearchProcessorFactory(needle); + } + }, + AHO_CORASIC { + @Override + SearchProcessorFactory newFactory(byte[] needle) { + return AbstractMultiSearchProcessorFactory.newAhoCorasicSearchProcessorFactory(needle); + } + }; + abstract SearchProcessorFactory newFactory(byte[] needle); + } + + @ParameterizedTest + @EnumSource(Algorithm.class) + public void testSearch(Algorithm algorithm) { + final ByteBuf haystack = Unpooled.copiedBuffer("abc☺", CharsetUtil.UTF_8); + + assertEquals(0, haystack.forEachByte(factory(algorithm, "a").newSearchProcessor())); + assertEquals(1, haystack.forEachByte(factory(algorithm, "ab").newSearchProcessor())); + assertEquals(2, haystack.forEachByte(factory(algorithm, "abc").newSearchProcessor())); + assertEquals(5, haystack.forEachByte(factory(algorithm, "abc☺").newSearchProcessor())); + assertEquals(-1, 
haystack.forEachByte(factory(algorithm, "abc☺☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "abc☺x").newSearchProcessor())); + + assertEquals(1, haystack.forEachByte(factory(algorithm, "b").newSearchProcessor())); + assertEquals(2, haystack.forEachByte(factory(algorithm, "bc").newSearchProcessor())); + assertEquals(5, haystack.forEachByte(factory(algorithm, "bc☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "bc☺☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "bc☺x").newSearchProcessor())); + + assertEquals(2, haystack.forEachByte(factory(algorithm, "c").newSearchProcessor())); + assertEquals(5, haystack.forEachByte(factory(algorithm, "c☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "c☺☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "c☺x").newSearchProcessor())); + + assertEquals(5, haystack.forEachByte(factory(algorithm, "☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "☺☺").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "☺x").newSearchProcessor())); + + assertEquals(-1, haystack.forEachByte(factory(algorithm, "z").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "aa").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "ba").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "abcd").newSearchProcessor())); + assertEquals(-1, haystack.forEachByte(factory(algorithm, "abcde").newSearchProcessor())); + + haystack.release(); + } + + @ParameterizedTest + @EnumSource(Algorithm.class) + public void testRepeating(Algorithm algorithm) { + final ByteBuf haystack = Unpooled.copiedBuffer("abcababc", CharsetUtil.UTF_8); + final int length = haystack.readableBytes(); + SearchProcessor processor = factory(algorithm, 
"ab").newSearchProcessor(); + + assertEquals(1, haystack.forEachByte(processor)); + assertEquals(4, haystack.forEachByte(2, length - 2, processor)); + assertEquals(6, haystack.forEachByte(5, length - 5, processor)); + assertEquals(-1, haystack.forEachByte(7, length - 7, processor)); + + haystack.release(); + } + + @ParameterizedTest + @EnumSource(Algorithm.class) + public void testOverlapping(Algorithm algorithm) { + final ByteBuf haystack = Unpooled.copiedBuffer("ababab", CharsetUtil.UTF_8); + final int length = haystack.readableBytes(); + SearchProcessor processor = factory(algorithm, "bab").newSearchProcessor(); + + assertEquals(3, haystack.forEachByte(processor)); + assertEquals(5, haystack.forEachByte(4, length - 4, processor)); + assertEquals(-1, haystack.forEachByte(6, length - 6, processor)); + + haystack.release(); + } + + @ParameterizedTest + @EnumSource(Algorithm.class) + public void testLongInputs(Algorithm algorithm) { + final int haystackLen = 1024; + final int needleLen = 64; + + final byte[] haystackBytes = new byte[haystackLen]; + haystackBytes[haystackLen - 1] = 1; + final ByteBuf haystack = Unpooled.copiedBuffer(haystackBytes); // 00000...00001 + + final byte[] needleBytes = new byte[needleLen]; // 000...000 + assertEquals(needleLen - 1, haystack.forEachByte(factory(algorithm, needleBytes).newSearchProcessor())); + + needleBytes[needleLen - 1] = 1; // 000...001 + assertEquals(haystackLen - 1, haystack.forEachByte(factory(algorithm, needleBytes).newSearchProcessor())); + + needleBytes[needleLen - 1] = 2; // 000...002 + assertEquals(-1, haystack.forEachByte(factory(algorithm, needleBytes).newSearchProcessor())); + + needleBytes[needleLen - 1] = 0; + needleBytes[0] = 1; // 100...000 + assertEquals(-1, haystack.forEachByte(factory(algorithm, needleBytes).newSearchProcessor())); + } + + @ParameterizedTest + @EnumSource(Algorithm.class) + public void testUniqueLen64Substrings(Algorithm algorithm) { + final byte[] haystackBytes = new byte[32 * 65]; // 
1, 2, 2, 3, 3, 3, 4, 4, 4, 4, ... + int pos = 0; + for (int i = 1; i <= 64; i++) { + for (int j = 0; j < i; j++) { + haystackBytes[pos++] = (byte) i; + } + } + final ByteBuf haystack = Unpooled.copiedBuffer(haystackBytes); + + for (int start = 0; start < haystackBytes.length - 64; start++) { + final byte[] needle = Arrays.copyOfRange(haystackBytes, start, start + 64); + assertEquals(start + 63, haystack.forEachByte(factory(algorithm, needle).newSearchProcessor())); + } + } + + private SearchProcessorFactory factory(Algorithm algorithm, byte[] needle) { + return algorithm.newFactory(needle); + } + + private SearchProcessorFactory factory(Algorithm algorithm, String needle) { + return factory(algorithm, needle.getBytes(CharsetUtil.UTF_8)); + } + +} diff --git a/netty-buffer/src/test/resources/logging.properties b/netty-buffer/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-buffer/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-bzip2/build.gradle b/netty-bzip2/build.gradle new file mode 100644 index 0000000..653b374 --- /dev/null +++ b/netty-bzip2/build.gradle @@ -0,0 +1,5 @@ +dependencies { + api project(':netty-buffer') + api project(':netty-handler-codec') + api project(':netty-util') +} \ No newline at end of file diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitReader.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitReader.java new file mode 100644 index 0000000..911c9ed --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitReader.java @@ -0,0 +1,157 @@ +/* + * 
Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +import io.netty.buffer.ByteBuf; + +/** + * An bit reader that allows the reading of single bit booleans, bit strings of + * arbitrary length (up to 32 bits), and bit aligned 32-bit integers. A single byte + * at a time is read from the {@link ByteBuf} when more bits are required. + */ +public class Bzip2BitReader { + /** + * Maximum count of possible readable bytes to check. + */ + private static final int MAX_COUNT_OF_READABLE_BYTES = Integer.MAX_VALUE >>> 3; + + /** + * The {@link ByteBuf} from which to read data. + */ + private ByteBuf in; + + /** + * A buffer of bits read from the input stream that have not yet been returned. + */ + private long bitBuffer; + + /** + * The number of bits currently buffered in {@link #bitBuffer}. + */ + private int bitCount; + + /** + * Set the {@link ByteBuf} from which to read data. + */ + public void setByteBuf(ByteBuf in) { + this.in = in; + } + + /** + * Reads up to 32 bits from the {@link ByteBuf}. 
+ * @param count The number of bits to read (maximum {@code 32} as a size of {@code int}) + * @return The bits requested, right-aligned within the integer + */ + public int readBits(final int count) { + if (count < 0 || count > 32) { + throw new IllegalArgumentException("count: " + count + " (expected: 0-32 )"); + } + int bitCount = this.bitCount; + long bitBuffer = this.bitBuffer; + + if (bitCount < count) { + long readData; + int offset; + switch (in.readableBytes()) { + case 1: { + readData = in.readUnsignedByte(); + offset = 8; + break; + } + case 2: { + readData = in.readUnsignedShort(); + offset = 16; + break; + } + case 3: { + readData = in.readUnsignedMedium(); + offset = 24; + break; + } + default: { + readData = in.readUnsignedInt(); + offset = 32; + break; + } + } + + bitBuffer = bitBuffer << offset | readData; + bitCount += offset; + this.bitBuffer = bitBuffer; + } + + this.bitCount = bitCount -= count; + return (int) (bitBuffer >>> bitCount & (count != 32 ? (1 << count) - 1 : 0xFFFFFFFFL)); + } + + /** + * Reads a single bit from the {@link ByteBuf}. + * @return {@code true} if the bit read was {@code 1}, otherwise {@code false} + */ + public boolean readBoolean() { + return readBits(1) != 0; + } + + /** + * Reads 32 bits of input as an integer. + * @return The integer read + */ + public int readInt() { + return readBits(32); + } + + /** + * Refill the {@link ByteBuf} by one byte. + */ + public void refill() { + int readData = in.readUnsignedByte(); + bitBuffer = bitBuffer << 8 | readData; + bitCount += 8; + } + + /** + * Checks that at least one bit is available for reading. + * @return {@code true} if one bit is available for reading, otherwise {@code false} + */ + public boolean isReadable() { + return bitCount > 0 || in.isReadable(); + } + + /** + * Checks that the specified number of bits is available for reading. 
+ * @param count The number of bits to check + * @return {@code true} if {@code count} bits are available for reading, otherwise {@code false} + */ + public boolean hasReadableBits(int count) { + if (count < 0) { + throw new IllegalArgumentException("count: " + count + " (expected value greater than 0)"); + } + return bitCount >= count || (in.readableBytes() << 3 & Integer.MAX_VALUE) >= count - bitCount; + } + + /** + * Checks that the specified number of bytes is available for reading. + * @param count The number of bytes to check + * @return {@code true} if {@code count} bytes are available for reading, otherwise {@code false} + */ + public boolean hasReadableBytes(int count) { + if (count < 0 || count > MAX_COUNT_OF_READABLE_BYTES) { + throw new IllegalArgumentException("count: " + count + + " (expected: 0-" + MAX_COUNT_OF_READABLE_BYTES + ')'); + } + return hasReadableBits(count << 3); + } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitWriter.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitWriter.java new file mode 100644 index 0000000..d455b6e --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BitWriter.java @@ -0,0 +1,120 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.bzip2; + +import io.netty.buffer.ByteBuf; + +/** + * A bit writer that allows the writing of single bit booleans, unary numbers, bit strings + * of arbitrary length (up to 32 bits), and bit aligned 32-bit integers. A single byte at a + * time is written to the {@link ByteBuf} when sufficient bits have been accumulated. + */ +public final class Bzip2BitWriter { + /** + * A buffer of bits waiting to be written to the output stream. + */ + private long bitBuffer; + + /** + * The number of bits currently buffered in {@link #bitBuffer}. + */ + private int bitCount; + + /** + * Writes up to 32 bits to the output {@link ByteBuf}. + * @param count The number of bits to write (maximum {@code 32} as a size of {@code int}) + * @param value The bits to write + */ + public void writeBits(ByteBuf out, final int count, final long value) { + if (count < 0 || count > 32) { + throw new IllegalArgumentException("count: " + count + " (expected: 0-32)"); + } + int bitCount = this.bitCount; + long bitBuffer = this.bitBuffer | value << 64 - count >>> bitCount; + bitCount += count; + + if (bitCount >= 32) { + out.writeInt((int) (bitBuffer >>> 32)); + bitBuffer <<= 32; + bitCount -= 32; + } + this.bitBuffer = bitBuffer; + this.bitCount = bitCount; + } + + /** + * Writes a single bit to the output {@link ByteBuf}. + * @param value The bit to write + */ + public void writeBoolean(ByteBuf out, final boolean value) { + int bitCount = this.bitCount + 1; + long bitBuffer = this.bitBuffer | (value ? 1L << 64 - bitCount : 0L); + + if (bitCount == 32) { + out.writeInt((int) (bitBuffer >>> 32)); + bitBuffer = 0; + bitCount = 0; + } + this.bitBuffer = bitBuffer; + this.bitCount = bitCount; + } + + /** + * Writes a zero-terminated unary number to the output {@link ByteBuf}. 
+ * Example of the output for value = 6: {@code 1111110} + * @param value The number of {@code 1} to write + */ + public void writeUnary(ByteBuf out, int value) { + if (value < 0) { + throw new IllegalArgumentException("value: " + value + " (expected 0 or more)"); + } + while (value-- > 0) { + writeBoolean(out, true); + } + writeBoolean(out, false); + } + + /** + * Writes an integer as 32 bits to the output {@link ByteBuf}. + * @param value The integer to write + */ + public void writeInt(ByteBuf out, final int value) { + writeBits(out, 32, value); + } + + /** + * Writes any remaining bits to the output {@link ByteBuf}, + * zero padding to a whole byte as required. + */ + public void flush(ByteBuf out) { + final int bitCount = this.bitCount; + + if (bitCount > 0) { + final long bitBuffer = this.bitBuffer; + final int shiftToRight = 64 - bitCount; + + if (bitCount <= 8) { + out.writeByte((int) (bitBuffer >>> shiftToRight << 8 - bitCount)); + } else if (bitCount <= 16) { + out.writeShort((int) (bitBuffer >>> shiftToRight << 16 - bitCount)); + } else if (bitCount <= 24) { + out.writeMedium((int) (bitBuffer >>> shiftToRight << 24 - bitCount)); + } else { + out.writeInt((int) (bitBuffer >>> shiftToRight << 32 - bitCount)); + } + } + } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockCompressor.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockCompressor.java new file mode 100644 index 0000000..7435ca0 --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockCompressor.java @@ -0,0 +1,298 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.ByteProcessor; + +import static io.netty.bzip2.Bzip2Constants.BLOCK_HEADER_MAGIC_1; +import static io.netty.bzip2.Bzip2Constants.BLOCK_HEADER_MAGIC_2; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SYMBOL_RANGE_SIZE; + +/** + * Compresses and writes a single Bzip2 block.

+ * + * Block encoding consists of the following stages:
+ * 1. Run-Length Encoding[1] - {@link #write(int)}
+ * 2. Burrows Wheeler Transform - {@link #close(ByteBuf)} (through {@link Bzip2DivSufSort})
+ * 3. Write block header - {@link #close(ByteBuf)}
+ * 4. Move To Front Transform - {@link #close(ByteBuf)} (through {@link Bzip2HuffmanStageEncoder})
+ * 5. Run-Length Encoding[2] - {@link #close(ByteBuf)} (through {@link Bzip2HuffmanStageEncoder})
+ * 6. Create and write Huffman tables - {@link #close(ByteBuf)} (through {@link Bzip2HuffmanStageEncoder})
+ * 7. Huffman encode and write data - {@link #close(ByteBuf)} (through {@link Bzip2HuffmanStageEncoder}) + */ +public final class Bzip2BlockCompressor { + private final ByteProcessor writeProcessor = new ByteProcessor() { + @Override + public boolean process(byte value) throws Exception { + return write(value); + } + }; + + /** + * A writer that provides bit-level writes. + */ + private final Bzip2BitWriter writer; + + /** + * CRC builder for the block. + */ + private final Crc32 crc = new Crc32(); + + /** + * The RLE'd block data. + */ + private final byte[] block; + + /** + * Current length of the data within the {@link #block} array. + */ + private int blockLength; + + /** + * A limit beyond which new data will not be accepted into the block. + */ + private final int blockLengthLimit; + + /** + * The values that are present within the RLE'd block data. For each index, {@code true} if that + * value is present within the data, otherwise {@code false}. + */ + private final boolean[] blockValuesPresent = new boolean[256]; + + /** + * The Burrows Wheeler Transformed block data. + */ + private final int[] bwtBlock; + + /** + * The current RLE value being accumulated (undefined when {@link #rleLength} is 0). + */ + private int rleCurrentValue = -1; + + /** + * The repeat count of the current RLE value. + */ + private int rleLength; + + /** + * @param writer The {@link Bzip2BitWriter} which provides bit-level writes + * @param blockSize The declared block size in bytes. 
Up to this many bytes will be accepted + * into the block after Run-Length Encoding is applied + */ + public Bzip2BlockCompressor(final Bzip2BitWriter writer, final int blockSize) { + this.writer = writer; + + // One extra byte is added to allow for the block wrap applied in close() + block = new byte[blockSize + 1]; + bwtBlock = new int[blockSize + 1]; + blockLengthLimit = blockSize - 6; // 5 bytes for one RLE run plus one byte - see {@link #write(int)} + } + + /** + * Write the Huffman symbol to output byte map. + */ + private void writeSymbolMap(ByteBuf out) { + Bzip2BitWriter writer = this.writer; + + final boolean[] blockValuesPresent = this.blockValuesPresent; + final boolean[] condensedInUse = new boolean[16]; + + for (int i = 0; i < condensedInUse.length; i++) { + for (int j = 0, k = i << 4; j < HUFFMAN_SYMBOL_RANGE_SIZE; j++, k++) { + if (blockValuesPresent[k]) { + condensedInUse[i] = true; + break; + } + } + } + + for (boolean isCondensedInUse : condensedInUse) { + writer.writeBoolean(out, isCondensedInUse); + } + + for (int i = 0; i < condensedInUse.length; i++) { + if (condensedInUse[i]) { + for (int j = 0, k = i << 4; j < HUFFMAN_SYMBOL_RANGE_SIZE; j++, k++) { + writer.writeBoolean(out, blockValuesPresent[k]); + } + } + } + } + + /** + * Writes an RLE run to the block array, updating the block CRC and present values array as required. 
+ * @param value The value to write + * @param runLength The run length of the value to write + */ + private void writeRun(final int value, int runLength) { + final int blockLength = this.blockLength; + final byte[] block = this.block; + + blockValuesPresent[value] = true; + crc.updateCRC(value, runLength); + + final byte byteValue = (byte) value; + switch (runLength) { + case 1: + block[blockLength] = byteValue; + this.blockLength = blockLength + 1; + break; + case 2: + block[blockLength] = byteValue; + block[blockLength + 1] = byteValue; + this.blockLength = blockLength + 2; + break; + case 3: + block[blockLength] = byteValue; + block[blockLength + 1] = byteValue; + block[blockLength + 2] = byteValue; + this.blockLength = blockLength + 3; + break; + default: + runLength -= 4; + blockValuesPresent[runLength] = true; + block[blockLength] = byteValue; + block[blockLength + 1] = byteValue; + block[blockLength + 2] = byteValue; + block[blockLength + 3] = byteValue; + block[blockLength + 4] = (byte) runLength; + this.blockLength = blockLength + 5; + break; + } + } + + /** + * Writes a byte to the block, accumulating to an RLE run where possible. 
+ * @param value The byte to write + * @return {@code true} if the byte was written, or {@code false} if the block is already full + */ + public boolean write(final int value) { + if (blockLength > blockLengthLimit) { + return false; + } + final int rleCurrentValue = this.rleCurrentValue; + final int rleLength = this.rleLength; + + if (rleLength == 0) { + this.rleCurrentValue = value; + this.rleLength = 1; + } else if (rleCurrentValue != value) { + // This path commits us to write 6 bytes - one RLE run (5 bytes) plus one extra + writeRun(rleCurrentValue & 0xff, rleLength); + this.rleCurrentValue = value; + this.rleLength = 1; + } else { + if (rleLength == 254) { + writeRun(rleCurrentValue & 0xff, 255); + this.rleLength = 0; + } else { + this.rleLength = rleLength + 1; + } + } + return true; + } + + /** + * Writes an array to the block. + * @param buffer The buffer to write + * @param offset The offset within the input data to write from + * @param length The number of bytes of input data to write + * @return The actual number of input bytes written. May be less than the number requested, or + * zero if the block is already full + */ + public int write(final ByteBuf buffer, int offset, int length) { + int index = buffer.forEachByte(offset, length, writeProcessor); + return index == -1 ? length : index - offset; + } + + /** + * Compresses and writes out the block. 
+ */ + public void close(ByteBuf out) { + // If an RLE run is in progress, write it out + if (rleLength > 0) { + writeRun(rleCurrentValue & 0xff, rleLength); + } + + // Apply a one byte block wrap required by the BWT implementation + block[blockLength] = block[0]; + + // Perform the Burrows Wheeler Transform + Bzip2DivSufSort divSufSort = new Bzip2DivSufSort(block, bwtBlock, blockLength); + int bwtStartPointer = divSufSort.bwt(); + + Bzip2BitWriter writer = this.writer; + + // Write out the block header + writer.writeBits(out, 24, BLOCK_HEADER_MAGIC_1); + writer.writeBits(out, 24, BLOCK_HEADER_MAGIC_2); + writer.writeInt(out, crc.getCRC()); + writer.writeBoolean(out, false); // Randomised block flag. We never create randomised blocks + writer.writeBits(out, 24, bwtStartPointer); + + // Write out the symbol map + writeSymbolMap(out); + + // Perform the Move To Front Transform and Run-Length Encoding[2] stages + Bzip2MTFAndRLE2StageEncoder mtfEncoder = new Bzip2MTFAndRLE2StageEncoder(bwtBlock, blockLength, + blockValuesPresent); + mtfEncoder.encode(); + + // Perform the Huffman Encoding stage and write out the encoded data + Bzip2HuffmanStageEncoder huffmanEncoder = new Bzip2HuffmanStageEncoder(writer, + mtfEncoder.mtfBlock(), + mtfEncoder.mtfLength(), + mtfEncoder.mtfAlphabetSize(), + mtfEncoder.mtfSymbolFrequencies()); + huffmanEncoder.encode(out); + } + + /** + * Gets available size of the current block. + * @return Number of available bytes which can be written + */ + public int availableSize() { + if (blockLength == 0) { + return blockLengthLimit + 2; + } + return blockLengthLimit - blockLength + 1; + } + + /** + * Determines if the block is full and ready for compression. + * @return {@code true} if the block is full, otherwise {@code false} + */ + public boolean isFull() { + return blockLength > blockLengthLimit; + } + + /** + * Determines if any bytes have been written to the block. 
+ * @return {@code true} if one or more bytes has been written to the block, otherwise {@code false} + */ + public boolean isEmpty() { + return blockLength == 0 && rleLength == 0; + } + + /** + * Gets the CRC of the completed block. Only valid after calling {@link #close(ByteBuf)}. + * @return The block's CRC + */ + public int crc() { + return crc.getCRC(); + } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockDecompressor.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockDecompressor.java new file mode 100644 index 0000000..09476b0 --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2BlockDecompressor.java @@ -0,0 +1,350 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_DECODE_MAX_CODE_LENGTH; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SYMBOL_RUNA; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SYMBOL_RUNB; +import static io.netty.bzip2.Bzip2Constants.MAX_BLOCK_LENGTH; + +/** + * Reads and decompresses a single Bzip2 block.

+ * + * Block decoding consists of the following stages:
+ * 1. Read block header
+ * 2. Read Huffman tables
+ * 3. Read and decode Huffman encoded data - {@link #decodeHuffmanData(Bzip2HuffmanStageDecoder)}
+ * 4. Run-Length Decoding[2] - {@link #decodeHuffmanData(Bzip2HuffmanStageDecoder)}
+ * 5. Inverse Move To Front Transform - {@link #decodeHuffmanData(Bzip2HuffmanStageDecoder)}
+ * 6. Inverse Burrows Wheeler Transform - {@link #initialiseInverseBWT()}
+ * 7. Run-Length Decoding[1] - {@link #read()}
+ * 8. Optional Block De-Randomisation - {@link #read()} (through {@link #decodeNextBWTByte()}) + */ +public final class Bzip2BlockDecompressor { + /** + * A reader that provides bit-level reads. + */ + private final Bzip2BitReader reader; + + /** + * Calculates the block CRC from the fully decoded bytes of the block. + */ + private final Crc32 crc = new Crc32(); + + /** + * The CRC of the current block as read from the block header. + */ + private final int blockCRC; + + /** + * {@code true} if the current block is randomised, otherwise {@code false}. + */ + private final boolean blockRandomised; + + /* Huffman Decoding stage */ + /** + * The end-of-block Huffman symbol. Decoding of the block ends when this is encountered. + */ + public int huffmanEndOfBlockSymbol; + + /** + * Bitmap, of ranges of 16 bytes, present/not present. + */ + public int huffmanInUse16; + + /** + * A map from Huffman symbol index to output character. Some types of data (e.g. ASCII text) + * may contain only a limited number of byte values; Huffman symbols are only allocated to + * those values that actually occur in the uncompressed data. + */ + public final byte[] huffmanSymbolMap = new byte[256]; + + /* Move To Front stage */ + /** + * Counts of each byte value within the {@link Bzip2BlockDecompressor#huffmanSymbolMap} data. + * Collected at the Move To Front stage, consumed by the Inverse Burrows Wheeler Transform stage. + */ + private final int[] bwtByteCounts = new int[256]; + + /** + * The Burrows-Wheeler Transform processed data. Read at the Move To Front stage, consumed by the + * Inverse Burrows Wheeler Transform stage. + */ + private final byte[] bwtBlock; + + /** + * Starting pointer into BWT for after untransform. 
+ */ + private final int bwtStartPointer; + + /* Inverse Burrows-Wheeler Transform stage */ + /** + * At each position contains the union of :- + * An output character (8 bits) + * A pointer from each position to its successor (24 bits, left shifted 8 bits) + * As the pointer cannot exceed the maximum block size of 900k, 24 bits is more than enough to + * hold it; Folding the character data into the spare bits while performing the inverse BWT, + * when both pieces of information are available, saves a large number of memory accesses in + * the final decoding stages. + */ + private int[] bwtMergedPointers; + + /** + * The current merged pointer into the Burrow-Wheeler Transform array. + */ + private int bwtCurrentMergedPointer; + + /** + * The actual length in bytes of the current block at the Inverse Burrows Wheeler Transform + * stage (before final Run-Length Decoding). + */ + private int bwtBlockLength; + + /** + * The number of output bytes that have been decoded up to the Inverse Burrows Wheeler Transform stage. + */ + private int bwtBytesDecoded; + + /* Run-Length Encoding and Random Perturbation stage */ + /** + * The most recently RLE decoded byte. + */ + private int rleLastDecodedByte = -1; + + /** + * The number of previous identical output bytes decoded. After 4 identical bytes, the next byte + * decoded is an RLE repeat count. + */ + private int rleAccumulator; + + /** + * The RLE repeat count of the current decoded byte. When this reaches zero, a new byte is decoded. + */ + private int rleRepeat; + + /** + * If the current block is randomised, the position within the RNUMS randomisation array. + */ + private int randomIndex; + + /** + * If the current block is randomised, the remaining count at the current RNUMS position. + */ + private int randomCount = Bzip2Rand.rNums(0) - 1; + + /** + * Table for Move To Front transformations. 
+ */ + private final Bzip2MoveToFrontTable symbolMTF = new Bzip2MoveToFrontTable(); + + // This variables is used to save current state if we haven't got enough readable bits + private int repeatCount; + private int repeatIncrement = 1; + private int mtfValue; + + public Bzip2BlockDecompressor(final int blockSize, final int blockCRC, final boolean blockRandomised, + final int bwtStartPointer, final Bzip2BitReader reader) { + + bwtBlock = new byte[blockSize]; + + this.blockCRC = blockCRC; + this.blockRandomised = blockRandomised; + this.bwtStartPointer = bwtStartPointer; + + this.reader = reader; + } + + /** + * Reads the Huffman encoded data from the input stream, performs Run-Length Decoding and + * applies the Move To Front transform to reconstruct the Burrows-Wheeler Transform array. + */ + public boolean decodeHuffmanData(final Bzip2HuffmanStageDecoder huffmanDecoder) { + final Bzip2BitReader reader = this.reader; + final byte[] bwtBlock = this.bwtBlock; + final byte[] huffmanSymbolMap = this.huffmanSymbolMap; + final int streamBlockSize = this.bwtBlock.length; + final int huffmanEndOfBlockSymbol = this.huffmanEndOfBlockSymbol; + final int[] bwtByteCounts = this.bwtByteCounts; + final Bzip2MoveToFrontTable symbolMTF = this.symbolMTF; + + int bwtBlockLength = this.bwtBlockLength; + int repeatCount = this.repeatCount; + int repeatIncrement = this.repeatIncrement; + int mtfValue = this.mtfValue; + + for (;;) { + if (!reader.hasReadableBits(HUFFMAN_DECODE_MAX_CODE_LENGTH)) { + this.bwtBlockLength = bwtBlockLength; + this.repeatCount = repeatCount; + this.repeatIncrement = repeatIncrement; + this.mtfValue = mtfValue; + return false; + } + final int nextSymbol = huffmanDecoder.nextSymbol(); + + if (nextSymbol == HUFFMAN_SYMBOL_RUNA) { + repeatCount += repeatIncrement; + repeatIncrement <<= 1; + } else if (nextSymbol == HUFFMAN_SYMBOL_RUNB) { + repeatCount += repeatIncrement << 1; + repeatIncrement <<= 1; + } else { + if (repeatCount > 0) { + if (bwtBlockLength + 
repeatCount > streamBlockSize) { + throw new DecompressionException("block exceeds declared block size"); + } + final byte nextByte = huffmanSymbolMap[mtfValue]; + bwtByteCounts[nextByte & 0xff] += repeatCount; + while (--repeatCount >= 0) { + bwtBlock[bwtBlockLength++] = nextByte; + } + + repeatCount = 0; + repeatIncrement = 1; + } + + if (nextSymbol == huffmanEndOfBlockSymbol) { + break; + } + + if (bwtBlockLength >= streamBlockSize) { + throw new DecompressionException("block exceeds declared block size"); + } + + mtfValue = symbolMTF.indexToFront(nextSymbol - 1) & 0xff; + + final byte nextByte = huffmanSymbolMap[mtfValue]; + bwtByteCounts[nextByte & 0xff]++; + bwtBlock[bwtBlockLength++] = nextByte; + } + } + if (bwtBlockLength > MAX_BLOCK_LENGTH) { + throw new DecompressionException("block length exceeds max block length: " + + bwtBlockLength + " > " + MAX_BLOCK_LENGTH); + } + + this.bwtBlockLength = bwtBlockLength; + initialiseInverseBWT(); + return true; + } + + /** + * Set up the Inverse Burrows-Wheeler Transform merged pointer array. 
+ */ + private void initialiseInverseBWT() { + final int bwtStartPointer = this.bwtStartPointer; + final byte[] bwtBlock = this.bwtBlock; + final int[] bwtMergedPointers = new int[bwtBlockLength]; + final int[] characterBase = new int[256]; + + if (bwtStartPointer < 0 || bwtStartPointer >= bwtBlockLength) { + throw new DecompressionException("start pointer invalid"); + } + + // Cumulative character counts + System.arraycopy(bwtByteCounts, 0, characterBase, 1, 255); + for (int i = 2; i <= 255; i++) { + characterBase[i] += characterBase[i - 1]; + } + + // Merged-Array Inverse Burrows-Wheeler Transform + // Combining the output characters and forward pointers into a single array here, where we + // have already read both of the corresponding values, cuts down on memory accesses in the + // final walk through the array + for (int i = 0; i < bwtBlockLength; i++) { + int value = bwtBlock[i] & 0xff; + bwtMergedPointers[characterBase[value]++] = (i << 8) + value; + } + + this.bwtMergedPointers = bwtMergedPointers; + bwtCurrentMergedPointer = bwtMergedPointers[bwtStartPointer]; + } + + /** + * Decodes a byte from the final Run-Length Encoding stage, pulling a new byte from the + * Burrows-Wheeler Transform stage when required. 
+ * @return The decoded byte, or -1 if there are no more bytes + */ + public int read() { + while (rleRepeat < 1) { + if (bwtBytesDecoded == bwtBlockLength) { + return -1; + } + + int nextByte = decodeNextBWTByte(); + if (nextByte != rleLastDecodedByte) { + // New byte, restart accumulation + rleLastDecodedByte = nextByte; + rleRepeat = 1; + rleAccumulator = 1; + crc.updateCRC(nextByte); + } else { + if (++rleAccumulator == 4) { + // Accumulation complete, start repetition + int rleRepeat = decodeNextBWTByte() + 1; + this.rleRepeat = rleRepeat; + rleAccumulator = 0; + crc.updateCRC(nextByte, rleRepeat); + } else { + rleRepeat = 1; + crc.updateCRC(nextByte); + } + } + } + rleRepeat--; + + return rleLastDecodedByte; + } + + /** + * Decodes a byte from the Burrows-Wheeler Transform stage. If the block has randomisation + * applied, reverses the randomisation. + * @return The decoded byte + */ + private int decodeNextBWTByte() { + int mergedPointer = bwtCurrentMergedPointer; + int nextDecodedByte = mergedPointer & 0xff; + bwtCurrentMergedPointer = bwtMergedPointers[mergedPointer >>> 8]; + + if (blockRandomised) { + if (--randomCount == 0) { + nextDecodedByte ^= 1; + randomIndex = (randomIndex + 1) % 512; + randomCount = Bzip2Rand.rNums(randomIndex); + } + } + bwtBytesDecoded++; + + return nextDecodedByte; + } + + public int blockLength() { + return bwtBlockLength; + } + + /** + * Verify and return the block CRC. This method may only be called + * after all of the block's bytes have been read. 
+ * @return The block CRC + */ + public int checkCRC() { + final int computedBlockCRC = crc.getCRC(); + if (blockCRC != computedBlockCRC) { + throw new DecompressionException("block CRC error"); + } + return computedBlockCRC; + } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Constants.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Constants.java new file mode 100644 index 0000000..6e4e0f9 --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Constants.java @@ -0,0 +1,106 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +/** + * Constants for both the Bzip2Encoder and the Bzip2Decoder. + */ +public final class Bzip2Constants { + + /** + * Magic number of Bzip2 stream. + */ + public static final int MAGIC_NUMBER = 'B' << 16 | 'Z' << 8 | 'h'; + + /** + * Block header magic number. Equals to BCD (pi). + */ + public static final int BLOCK_HEADER_MAGIC_1 = 0x314159; + public static final int BLOCK_HEADER_MAGIC_2 = 0x265359; + + /** + * End of stream magic number. Equals to BCD sqrt(pi). + */ + public static final int END_OF_STREAM_MAGIC_1 = 0x177245; + public static final int END_OF_STREAM_MAGIC_2 = 0x385090; + + /** + * Base block size. + */ + public static final int BASE_BLOCK_SIZE = 100000; + + /** + * Minimum and maximum size of one block. + * Must be multiplied by {@link Bzip2Constants#BASE_BLOCK_SIZE}. 
+ */ + public static final int MIN_BLOCK_SIZE = 1; + public static final int MAX_BLOCK_SIZE = 9; + + public static final int MAX_BLOCK_LENGTH = MAX_BLOCK_SIZE * BASE_BLOCK_SIZE; + + /** + * Maximum possible Huffman alphabet size. + */ + public static final int HUFFMAN_MAX_ALPHABET_SIZE = 258; + + /** + * The longest Huffman code length created by the encoder. + */ + public static final int HUFFMAN_ENCODE_MAX_CODE_LENGTH = 20; + + /** + * The longest Huffman code length accepted by the decoder. + */ + public static final int HUFFMAN_DECODE_MAX_CODE_LENGTH = 23; + + /** + * Huffman symbols used for run-length encoding. + */ + public static final int HUFFMAN_SYMBOL_RUNA = 0; + public static final int HUFFMAN_SYMBOL_RUNB = 1; + + /** + * Huffman symbols range size for Huffman used map. + */ + public static final int HUFFMAN_SYMBOL_RANGE_SIZE = 16; + + /** + * Maximum length of zero-terminated bit runs of MTF'ed Huffman table. + */ + public static final int HUFFMAN_SELECTOR_LIST_MAX_LENGTH = 6; + + /** + * Number of symbols decoded after which a new Huffman table is selected. + */ + public static final int HUFFMAN_GROUP_RUN_LENGTH = 50; + + /** + * Maximum possible number of Huffman table selectors. + */ + public static final int MAX_SELECTORS = 2 + 900000 / HUFFMAN_GROUP_RUN_LENGTH; // 18002 + + /** + * Minimum number of alternative Huffman tables. + */ + public static final int HUFFMAN_MINIMUM_TABLES = 2; + + /** + * Maximum number of alternative Huffman tables. 
+ */ + public static final int HUFFMAN_MAXIMUM_TABLES = 6; + + private Bzip2Constants() { } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2DivSufSort.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2DivSufSort.java new file mode 100644 index 0000000..c9b4bad --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2DivSufSort.java @@ -0,0 +1,2117 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +/** + * DivSufSort suffix array generator.
+ * + * Based on libdivsufsort 1.2.3 patched to support Bzip2.
+ * This is a simple conversion of the original C with two minor bugfixes applied (see "BUGFIX" + * comments within the class). Documentation within the class is largely absent. + */ +final class Bzip2DivSufSort { + + private static final int STACK_SIZE = 64; + private static final int BUCKET_A_SIZE = 256; + private static final int BUCKET_B_SIZE = 65536; + private static final int SS_BLOCKSIZE = 1024; + private static final int INSERTIONSORT_THRESHOLD = 8; + + private static final int[] LOG_2_TABLE = { + -1, 0, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 3, 3, 3, 3, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, 4, + 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, 5, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7 + }; + + private final int[] SA; + private final byte[] T; + private final int n; + + /** + * @param block The input array + * @param bwtBlock The output array + * @param blockLength The length of the input data + */ + Bzip2DivSufSort(final byte[] block, final int[] bwtBlock, final int blockLength) { + T = block; + SA = bwtBlock; + n = blockLength; + } + + private static void swapElements(final int[] array1, final int idx1, final int[] array2, final int idx2) { + final int temp = array1[idx1]; + array1[idx1] = array2[idx2]; + array2[idx2] = temp; + } + + private int ssCompare(final int p1, final int p2, final int depth) { + final int[] SA = this.SA; + final byte[] T = this.T; + + // pointers 
within T + final int U1n = SA[p1 + 1] + 2; + final int U2n = SA[p2 + 1] + 2; + + int U1 = depth + SA[p1]; + int U2 = depth + SA[p2]; + + while (U1 < U1n && U2 < U2n && T[U1] == T[U2]) { + ++U1; + ++U2; + } + + return U1 < U1n ? + U2 < U2n ? (T[U1] & 0xff) - (T[U2] & 0xff) : 1 + : U2 < U2n ? -1 : 0; + } + + private int ssCompareLast(int pa, int p1, int p2, int depth, int size) { + final int[] SA = this.SA; + final byte[] T = this.T; + + int U1 = depth + SA[p1]; + int U2 = depth + SA[p2]; + int U1n = size; + int U2n = SA[p2 + 1] + 2; + + while (U1 < U1n && U2 < U2n && T[U1] == T[U2]) { + ++U1; + ++U2; + } + + if (U1 < U1n) { + return U2 < U2n ? (T[U1] & 0xff) - (T[U2] & 0xff) : 1; + } + if (U2 == U2n) { + return 1; + } + + U1 %= size; + U1n = SA[pa] + 2; + while (U1 < U1n && U2 < U2n && T[U1] == T[U2]) { + ++U1; + ++U2; + } + + return U1 < U1n ? + U2 < U2n ? (T[U1] & 0xff) - (T[U2] & 0xff) : 1 + : U2 < U2n ? -1 : 0; + } + + private void ssInsertionSort(int pa, int first, int last, int depth) { + final int[] SA = this.SA; + + int i, j; // pointer within SA + int t; + int r; + + for (i = last - 2; first <= i; --i) { + for (t = SA[i], j = i + 1; 0 < (r = ssCompare(pa + t, pa + SA[j], depth));) { + do { + SA[j - 1] = SA[j]; + } while (++j < last && SA[j] < 0); + if (last <= j) { + break; + } + } + if (r == 0) { + SA[j] = ~SA[j]; + } + SA[j - 1] = t; + } + } + + private void ssFixdown(int td, int pa, int sa, int i, int size) { + final int[] SA = this.SA; + final byte[] T = this.T; + + int j, k; + int v; + int c, d, e; + + for (v = SA[sa + i], c = T[td + SA[pa + v]] & 0xff; (j = 2 * i + 1) < size; SA[sa + i] = SA[sa + k], i = k) { + d = T[td + SA[pa + SA[sa + (k = j++)]]] & 0xff; + if (d < (e = T[td + SA[pa + SA[sa + j]]] & 0xff)) { + k = j; + d = e; + } + if (d <= c) { + break; + } + } + SA[sa + i] = v; + } + + private void ssHeapSort(int td, int pa, int sa, int size) { + final int[] SA = this.SA; + final byte[] T = this.T; + + int i, m; + int t; + + m = size; + if (size 
% 2 == 0) { + m--; + if ((T[td + SA[pa + SA[sa + m / 2]]] & 0xff) < (T[td + SA[pa + SA[sa + m]]] & 0xff)) { + swapElements(SA, sa + m, SA, sa + m / 2); + } + } + + for (i = m / 2 - 1; 0 <= i; --i) { + ssFixdown(td, pa, sa, i, m); + } + + if (size % 2 == 0) { + swapElements(SA, sa, SA, sa + m); + ssFixdown(td, pa, sa, 0, m); + } + + for (i = m - 1; 0 < i; --i) { + t = SA[sa]; + SA[sa] = SA[sa + i]; + ssFixdown(td, pa, sa, 0, i); + SA[sa + i] = t; + } + } + + private int ssMedian3(final int td, final int pa, int v1, int v2, int v3) { + final int[] SA = this.SA; + final byte[] T = this.T; + + int T_v1 = T[td + SA[pa + SA[v1]]] & 0xff; + int T_v2 = T[td + SA[pa + SA[v2]]] & 0xff; + int T_v3 = T[td + SA[pa + SA[v3]]] & 0xff; + + if (T_v1 > T_v2) { + final int temp = v1; + v1 = v2; + v2 = temp; + final int T_vtemp = T_v1; + T_v1 = T_v2; + T_v2 = T_vtemp; + } + if (T_v2 > T_v3) { + if (T_v1 > T_v3) { + return v1; + } + return v3; + } + return v2; + } + + private int ssMedian5(final int td, final int pa, int v1, int v2, int v3, int v4, int v5) { + final int[] SA = this.SA; + final byte[] T = this.T; + + int T_v1 = T[td + SA[pa + SA[v1]]] & 0xff; + int T_v2 = T[td + SA[pa + SA[v2]]] & 0xff; + int T_v3 = T[td + SA[pa + SA[v3]]] & 0xff; + int T_v4 = T[td + SA[pa + SA[v4]]] & 0xff; + int T_v5 = T[td + SA[pa + SA[v5]]] & 0xff; + int temp; + int T_vtemp; + + if (T_v2 > T_v3) { + temp = v2; + v2 = v3; + v3 = temp; + T_vtemp = T_v2; + T_v2 = T_v3; + T_v3 = T_vtemp; + } + if (T_v4 > T_v5) { + temp = v4; + v4 = v5; + v5 = temp; + T_vtemp = T_v4; + T_v4 = T_v5; + T_v5 = T_vtemp; + } + if (T_v2 > T_v4) { + temp = v2; + v4 = temp; + T_vtemp = T_v2; + T_v4 = T_vtemp; + temp = v3; + v3 = v5; + v5 = temp; + T_vtemp = T_v3; + T_v3 = T_v5; + T_v5 = T_vtemp; + } + if (T_v1 > T_v3) { + temp = v1; + v1 = v3; + v3 = temp; + T_vtemp = T_v1; + T_v1 = T_v3; + T_v3 = T_vtemp; + } + if (T_v1 > T_v4) { + temp = v1; + v4 = temp; + T_vtemp = T_v1; + T_v4 = T_vtemp; + v3 = v5; + T_v3 = T_v5; + } + if 
(T_v3 > T_v4) { + return v4; + } + return v3; + } + + private int ssPivot(final int td, final int pa, final int first, final int last) { + int middle; + int t; + + t = last - first; + middle = first + t / 2; + + if (t <= 512) { + if (t <= 32) { + return ssMedian3(td, pa, first, middle, last - 1); + } + t >>= 2; + return ssMedian5(td, pa, first, first + t, middle, last - 1 - t, last - 1); + } + t >>= 3; + return ssMedian3( + td, pa, + ssMedian3(td, pa, first, first + t, first + (t << 1)), + ssMedian3(td, pa, middle - t, middle, middle + t), + ssMedian3(td, pa, last - 1 - (t << 1), last - 1 - t, last - 1) + ); + } + + private static int ssLog(final int n) { + return (n & 0xff00) != 0 ? + 8 + LOG_2_TABLE[n >> 8 & 0xff] + : LOG_2_TABLE[n & 0xff]; + } + + private int ssSubstringPartition(final int pa, final int first, final int last, final int depth) { + final int[] SA = this.SA; + + int a, b; + int t; + + for (a = first - 1, b = last;;) { + while (++a < b && (SA[pa + SA[a]] + depth >= SA[pa + SA[a] + 1] + 1)) { + SA[a] = ~SA[a]; + } + --b; + while (a < b && (SA[pa + SA[b]] + depth < SA[pa + SA[b] + 1] + 1)) { + --b; + } + + if (b <= a) { + break; + } + t = ~SA[b]; + SA[b] = SA[a]; + SA[a] = t; + } + if (first < a) { + SA[first] = ~SA[first]; + } + return a; + } + + private static class StackEntry { + final int a; + final int b; + final int c; + final int d; + + StackEntry(final int a, final int b, final int c, final int d) { + this.a = a; + this.b = b; + this.c = c; + this.d = d; + } + } + + private void ssMultiKeyIntroSort(final int pa, int first, int last, int depth) { + final int[] SA = this.SA; + final byte[] T = this.T; + + final StackEntry[] stack = new StackEntry[STACK_SIZE]; + + int Td; + int a, b, c, d, e, f; + int s, t; + int ssize; + int limit; + int v, x = 0; + + for (ssize = 0, limit = ssLog(last - first);;) { + if (last - first <= INSERTIONSORT_THRESHOLD) { + if (1 < last - first) { + ssInsertionSort(pa, first, last, depth); + } + if (ssize == 0) { + 
return; + } + StackEntry entry = stack[--ssize]; + first = entry.a; + last = entry.b; + depth = entry.c; + limit = entry.d; + continue; + } + + Td = depth; + if (limit-- == 0) { + ssHeapSort(Td, pa, first, last - first); + } + if (limit < 0) { + for (a = first + 1, v = T[Td + SA[pa + SA[first]]] & 0xff; a < last; ++a) { + if ((x = T[Td + SA[pa + SA[a]]] & 0xff) != v) { + if (1 < a - first) { + break; + } + v = x; + first = a; + } + } + if ((T[Td + SA[pa + SA[first]] - 1] & 0xff) < v) { + first = ssSubstringPartition(pa, first, a, depth); + } + if (a - first <= last - a) { + if (1 < a - first) { + stack[ssize++] = new StackEntry(a, last, depth, -1); + last = a; + depth += 1; + limit = ssLog(a - first); + } else { + first = a; + limit = -1; + } + } else { + if (1 < last - a) { + stack[ssize++] = new StackEntry(first, a, depth + 1, ssLog(a - first)); + first = a; + limit = -1; + } else { + last = a; + depth += 1; + limit = ssLog(a - first); + } + } + continue; + } + + a = ssPivot(Td, pa, first, last); + v = T[Td + SA[pa + SA[a]]] & 0xff; + swapElements(SA, first, SA, a); + + b = first + 1; + while (b < last && (x = T[Td + SA[pa + SA[b]]] & 0xff) == v) { + ++b; + } + if ((a = b) < last && x < v) { + while (++b < last && (x = T[Td + SA[pa + SA[b]]] & 0xff) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + } + + c = last - 1; + while (b < c && (x = T[Td + SA[pa + SA[c]]] & 0xff) == v) { + --c; + } + if (b < (d = c) && x > v) { + while (b < --c && (x = T[Td + SA[pa + SA[c]]] & 0xff) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + while (b < c) { + swapElements(SA, b, SA, c); + while (++b < c && (x = T[Td + SA[pa + SA[b]]] & 0xff) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + while (b < --c && (x = T[Td + SA[pa + SA[c]]] & 0xff) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + + if (a <= d) { + c = b - 1; + + if ((s = a - first) > (t = b - a)) { + s = t; + } + for (e = first, 
f = b - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + if ((s = d - c) > (t = last - d - 1)) { + s = t; + } + for (e = b, f = last - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + + a = first + (b - a); + c = last - (d - c); + b = v <= (T[Td + SA[pa + SA[a]] - 1] & 0xff) ? a : ssSubstringPartition(pa, a, c, depth); + + if (a - first <= last - c) { + if (last - c <= c - b) { + stack[ssize++] = new StackEntry(b, c, depth + 1, ssLog(c - b)); + stack[ssize++] = new StackEntry(c, last, depth, limit); + last = a; + } else if (a - first <= c - b) { + stack[ssize++] = new StackEntry(c, last, depth, limit); + stack[ssize++] = new StackEntry(b, c, depth + 1, ssLog(c - b)); + last = a; + } else { + stack[ssize++] = new StackEntry(c, last, depth, limit); + stack[ssize++] = new StackEntry(first, a, depth, limit); + first = b; + last = c; + depth += 1; + limit = ssLog(c - b); + } + } else { + if (a - first <= c - b) { + stack[ssize++] = new StackEntry(b, c, depth + 1, ssLog(c - b)); + stack[ssize++] = new StackEntry(first, a, depth, limit); + first = c; + } else if (last - c <= c - b) { + stack[ssize++] = new StackEntry(first, a, depth, limit); + stack[ssize++] = new StackEntry(b, c, depth + 1, ssLog(c - b)); + first = c; + } else { + stack[ssize++] = new StackEntry(first, a, depth, limit); + stack[ssize++] = new StackEntry(c, last, depth, limit); + first = b; + last = c; + depth += 1; + limit = ssLog(c - b); + } + } + } else { + limit += 1; + if ((T[Td + SA[pa + SA[first]] - 1] & 0xff) < v) { + first = ssSubstringPartition(pa, first, last, depth); + limit = ssLog(last - first); + } + depth += 1; + } + } + } + + private static void ssBlockSwap(final int[] array1, final int first1, + final int[] array2, final int first2, final int size) { + int a, b; + int i; + for (i = size, a = first1, b = first2; 0 < i; --i, ++a, ++b) { + swapElements(array1, a, array2, b); + } + } + + private void ssMergeForward(final int pa, int[] buf, final int bufoffset, + 
final int first, final int middle, final int last, final int depth) { + final int[] SA = this.SA; + + int bufend; + int i, j, k; + int t; + int r; + + bufend = bufoffset + (middle - first) - 1; + ssBlockSwap(buf, bufoffset, SA, first, middle - first); + + for (t = SA[first], i = first, j = bufoffset, k = middle;;) { + r = ssCompare(pa + buf[j], pa + SA[k], depth); + if (r < 0) { + do { + SA[i++] = buf[j]; + if (bufend <= j) { + buf[j] = t; + return; + } + buf[j++] = SA[i]; + } while (buf[j] < 0); + } else if (r > 0) { + do { + SA[i++] = SA[k]; + SA[k++] = SA[i]; + if (last <= k) { + while (j < bufend) { + SA[i++] = buf[j]; buf[j++] = SA[i]; + } + SA[i] = buf[j]; buf[j] = t; + return; + } + } while (SA[k] < 0); + } else { + SA[k] = ~SA[k]; + do { + SA[i++] = buf[j]; + if (bufend <= j) { + buf[j] = t; + return; + } + buf[j++] = SA[i]; + } while (buf[j] < 0); + + do { + SA[i++] = SA[k]; + SA[k++] = SA[i]; + if (last <= k) { + while (j < bufend) { + SA[i++] = buf[j]; + buf[j++] = SA[i]; + } + SA[i] = buf[j]; buf[j] = t; + return; + } + } while (SA[k] < 0); + } + } + } + + private void ssMergeBackward(final int pa, int[] buf, final int bufoffset, + final int first, final int middle, final int last, final int depth) { + final int[] SA = this.SA; + + int p1, p2; + int bufend; + int i, j, k; + int t; + int r; + int x; + + bufend = bufoffset + (last - middle); + ssBlockSwap(buf, bufoffset, SA, middle, last - middle); + + x = 0; + if (buf[bufend - 1] < 0) { + x |= 1; + p1 = pa + ~buf[bufend - 1]; + } else { + p1 = pa + buf[bufend - 1]; + } + if (SA[middle - 1] < 0) { + x |= 2; + p2 = pa + ~SA[middle - 1]; + } else { + p2 = pa + SA[middle - 1]; + } + for (t = SA[last - 1], i = last - 1, j = bufend - 1, k = middle - 1;;) { + + r = ssCompare(p1, p2, depth); + if (r > 0) { + if ((x & 1) != 0) { + do { + SA[i--] = buf[j]; + buf[j--] = SA[i]; + } while (buf[j] < 0); + x ^= 1; + } + SA[i--] = buf[j]; + if (j <= bufoffset) { + buf[j] = t; + return; + } + buf[j--] = SA[i]; + + if 
(buf[j] < 0) { + x |= 1; + p1 = pa + ~buf[j]; + } else { + p1 = pa + buf[j]; + } + } else if (r < 0) { + if ((x & 2) != 0) { + do { + SA[i--] = SA[k]; + SA[k--] = SA[i]; + } while (SA[k] < 0); + x ^= 2; + } + SA[i--] = SA[k]; + SA[k--] = SA[i]; + if (k < first) { + while (bufoffset < j) { + SA[i--] = buf[j]; + buf[j--] = SA[i]; + } + SA[i] = buf[j]; + buf[j] = t; + return; + } + + if (SA[k] < 0) { + x |= 2; + p2 = pa + ~SA[k]; + } else { + p2 = pa + SA[k]; + } + } else { + if ((x & 1) != 0) { + do { + SA[i--] = buf[j]; + buf[j--] = SA[i]; + } while (buf[j] < 0); + x ^= 1; + } + SA[i--] = ~buf[j]; + if (j <= bufoffset) { + buf[j] = t; + return; + } + buf[j--] = SA[i]; + + if ((x & 2) != 0) { + do { + SA[i--] = SA[k]; + SA[k--] = SA[i]; + } while (SA[k] < 0); + x ^= 2; + } + SA[i--] = SA[k]; + SA[k--] = SA[i]; + if (k < first) { + while (bufoffset < j) { + SA[i--] = buf[j]; + buf[j--] = SA[i]; + } + SA[i] = buf[j]; + buf[j] = t; + return; + } + + if (buf[j] < 0) { + x |= 1; + p1 = pa + ~buf[j]; + } else { + p1 = pa + buf[j]; + } + if (SA[k] < 0) { + x |= 2; + p2 = pa + ~SA[k]; + } else { + p2 = pa + SA[k]; + } + } + } + } + + private static int getIDX(final int a) { + return 0 <= a ? 
a : ~a; + } + + private void ssMergeCheckEqual(final int pa, final int depth, final int a) { + final int[] SA = this.SA; + + if (0 <= SA[a] && ssCompare(pa + getIDX(SA[a - 1]), pa + SA[a], depth) == 0) { + SA[a] = ~SA[a]; + } + } + + private void ssMerge(final int pa, int first, int middle, int last, int[] buf, + final int bufoffset, final int bufsize, final int depth) { + final int[] SA = this.SA; + + final StackEntry[] stack = new StackEntry[STACK_SIZE]; + + int i, j; + int m, len, half; + int ssize; + int check, next; + + for (check = 0, ssize = 0;;) { + + if (last - middle <= bufsize) { + if (first < middle && middle < last) { + ssMergeBackward(pa, buf, bufoffset, first, middle, last, depth); + } + + if ((check & 1) != 0) { + ssMergeCheckEqual(pa, depth, first); + } + if ((check & 2) != 0) { + ssMergeCheckEqual(pa, depth, last); + } + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + first = entry.a; + middle = entry.b; + last = entry.c; + check = entry.d; + continue; + } + + if (middle - first <= bufsize) { + if (first < middle) { + ssMergeForward(pa, buf, bufoffset, first, middle, last, depth); + } + if ((check & 1) != 0) { + ssMergeCheckEqual(pa, depth, first); + } + if ((check & 2) != 0) { + ssMergeCheckEqual(pa, depth, last); + } + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + first = entry.a; + middle = entry.b; + last = entry.c; + check = entry.d; + continue; + } + + for (m = 0, len = Math.min(middle - first, last - middle), half = len >> 1; + 0 < len; + len = half, half >>= 1) { + + if (ssCompare(pa + getIDX(SA[middle + m + half]), + pa + getIDX(SA[middle - m - half - 1]), depth) < 0) { + m += half + 1; + half -= (len & 1) ^ 1; + } + } + + if (0 < m) { + ssBlockSwap(SA, middle - m, SA, middle, m); + i = j = middle; + next = 0; + if (middle + m < last) { + if (SA[middle + m] < 0) { + while (SA[i - 1] < 0) { + --i; + } + SA[middle + m] = ~SA[middle + m]; + } + for (j = middle; SA[j] < 0;) { + ++j; + } + next = 
1; + } + if (i - first <= last - j) { + stack[ssize++] = new StackEntry(j, middle + m, last, (check & 2) | (next & 1)); + middle -= m; + last = i; + check &= 1; + } else { + if (i == middle && middle == j) { + next <<= 1; + } + stack[ssize++] = new StackEntry(first, middle - m, i, (check & 1) | (next & 2)); + first = j; + middle += m; + check = (check & 2) | (next & 1); + } + } else { + if ((check & 1) != 0) { + ssMergeCheckEqual(pa, depth, first); + } + ssMergeCheckEqual(pa, depth, middle); + if ((check & 2) != 0) { + ssMergeCheckEqual(pa, depth, last); + } + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + first = entry.a; + middle = entry.b; + last = entry.c; + check = entry.d; + } + } + } + + private void subStringSort(final int pa, int first, final int last, + final int[] buf, final int bufoffset, final int bufsize, + final int depth, final boolean lastsuffix, final int size) { + final int[] SA = this.SA; + + int a, b; + int[] curbuf; + int curbufoffset; + int i, j, k; + int curbufsize; + + if (lastsuffix) { + ++first; + } + for (a = first, i = 0; a + SS_BLOCKSIZE < last; a += SS_BLOCKSIZE, ++i) { + ssMultiKeyIntroSort(pa, a, a + SS_BLOCKSIZE, depth); + curbuf = SA; + curbufoffset = a + SS_BLOCKSIZE; + curbufsize = last - (a + SS_BLOCKSIZE); + if (curbufsize <= bufsize) { + curbufsize = bufsize; + curbuf = buf; + curbufoffset = bufoffset; + } + for (b = a, k = SS_BLOCKSIZE, j = i; (j & 1) != 0; b -= k, k <<= 1, j >>>= 1) { + ssMerge(pa, b - k, b, b + k, curbuf, curbufoffset, curbufsize, depth); + } + } + + ssMultiKeyIntroSort(pa, a, last, depth); + + for (k = SS_BLOCKSIZE; i != 0; k <<= 1, i >>= 1) { + if ((i & 1) != 0) { + ssMerge(pa, a - k, a, last, buf, bufoffset, bufsize, depth); + a -= k; + } + } + + if (lastsuffix) { + int r; + for (a = first, i = SA[first - 1], r = 1; + a < last && (SA[a] < 0 || 0 < (r = ssCompareLast(pa, pa + i, pa + SA[a], depth, size))); + ++a) { + SA[a - 1] = SA[a]; + } + if (r == 0) { + SA[a] = ~SA[a]; + } + 
SA[a - 1] = i; + } + } + + /*----------------------------------------------------------------------------*/ + + private int trGetC(final int isa, final int isaD, final int isaN, final int p) { + return isaD + p < isaN ? + SA[isaD + p] + : SA[isa + ((isaD - isa + p) % (isaN - isa))]; + } + + private void trFixdown(final int isa, final int isaD, final int isaN, final int sa, int i, final int size) { + final int[] SA = this.SA; + + int j, k; + int v; + int c, d, e; + + for (v = SA[sa + i], c = trGetC(isa, isaD, isaN, v); (j = 2 * i + 1) < size; SA[sa + i] = SA[sa + k], i = k) { + k = j++; + d = trGetC(isa, isaD, isaN, SA[sa + k]); + if (d < (e = trGetC(isa, isaD, isaN, SA[sa + j]))) { + k = j; + d = e; + } + if (d <= c) { + break; + } + } + SA[sa + i] = v; + } + + private void trHeapSort(final int isa, final int isaD, final int isaN, final int sa, final int size) { + final int[] SA = this.SA; + + int i, m; + int t; + + m = size; + if (size % 2 == 0) { + m--; + if (trGetC(isa, isaD, isaN, SA[sa + m / 2]) < trGetC(isa, isaD, isaN, SA[sa + m])) { + swapElements(SA, sa + m, SA, sa + m / 2); + } + } + + for (i = m / 2 - 1; 0 <= i; --i) { + trFixdown(isa, isaD, isaN, sa, i, m); + } + + if (size % 2 == 0) { + swapElements(SA, sa, SA, sa + m); + trFixdown(isa, isaD, isaN, sa, 0, m); + } + + for (i = m - 1; 0 < i; --i) { + t = SA[sa]; + SA[sa] = SA[sa + i]; + trFixdown(isa, isaD, isaN, sa, 0, i); + SA[sa + i] = t; + } + } + + private void trInsertionSort(final int isa, final int isaD, final int isaN, int first, int last) { + final int[] SA = this.SA; + + int a, b; + int t, r; + + for (a = first + 1; a < last; ++a) { + for (t = SA[a], b = a - 1; 0 > (r = trGetC(isa, isaD, isaN, t) - trGetC(isa, isaD, isaN, SA[b]));) { + do { + SA[b + 1] = SA[b]; + } while (first <= --b && SA[b] < 0); + if (b < first) { + break; + } + } + if (r == 0) { + SA[b] = ~SA[b]; + } + SA[b + 1] = t; + } + } + + private static int trLog(int n) { + return (n & 0xffff0000) != 0 ? + (n & 0xff000000) != 0 ? 
24 + LOG_2_TABLE[n >> 24 & 0xff] : LOG_2_TABLE[n >> 16 & 0xff + 16] + : (n & 0x0000ff00) != 0 ? 8 + LOG_2_TABLE[n >> 8 & 0xff] : LOG_2_TABLE[n & 0xff]; + } + + private int trMedian3(final int isa, final int isaD, final int isaN, int v1, int v2, int v3) { + final int[] SA = this.SA; + + int SA_v1 = trGetC(isa, isaD, isaN, SA[v1]); + int SA_v2 = trGetC(isa, isaD, isaN, SA[v2]); + int SA_v3 = trGetC(isa, isaD, isaN, SA[v3]); + + if (SA_v1 > SA_v2) { + final int temp = v1; + v1 = v2; + v2 = temp; + final int SA_vtemp = SA_v1; + SA_v1 = SA_v2; + SA_v2 = SA_vtemp; + } + if (SA_v2 > SA_v3) { + if (SA_v1 > SA_v3) { + return v1; + } + return v3; + } + + return v2; + } + + private int trMedian5(final int isa, final int isaD, final int isaN, int v1, int v2, int v3, int v4, int v5) { + final int[] SA = this.SA; + + int SA_v1 = trGetC(isa, isaD, isaN, SA[v1]); + int SA_v2 = trGetC(isa, isaD, isaN, SA[v2]); + int SA_v3 = trGetC(isa, isaD, isaN, SA[v3]); + int SA_v4 = trGetC(isa, isaD, isaN, SA[v4]); + int SA_v5 = trGetC(isa, isaD, isaN, SA[v5]); + int temp; + int SA_vtemp; + + if (SA_v2 > SA_v3) { + temp = v2; + v2 = v3; + v3 = temp; + SA_vtemp = SA_v2; + SA_v2 = SA_v3; + SA_v3 = SA_vtemp; + } + if (SA_v4 > SA_v5) { + temp = v4; + v4 = v5; + v5 = temp; + SA_vtemp = SA_v4; + SA_v4 = SA_v5; + SA_v5 = SA_vtemp; + } + if (SA_v2 > SA_v4) { + temp = v2; + v4 = temp; + SA_vtemp = SA_v2; + SA_v4 = SA_vtemp; + temp = v3; + v3 = v5; + v5 = temp; + SA_vtemp = SA_v3; + SA_v3 = SA_v5; + SA_v5 = SA_vtemp; + } + if (SA_v1 > SA_v3) { + temp = v1; + v1 = v3; + v3 = temp; + SA_vtemp = SA_v1; + SA_v1 = SA_v3; + SA_v3 = SA_vtemp; + } + if (SA_v1 > SA_v4) { + temp = v1; + v4 = temp; + SA_vtemp = SA_v1; + SA_v4 = SA_vtemp; + v3 = v5; + SA_v3 = SA_v5; + } + if (SA_v3 > SA_v4) { + return v4; + } + return v3; + } + + private int trPivot(final int isa, final int isaD, final int isaN, final int first, final int last) { + final int middle; + int t; + + t = last - first; + middle = first + t / 2; + + if (t 
<= 512) { + if (t <= 32) { + return trMedian3(isa, isaD, isaN, first, middle, last - 1); + } + t >>= 2; + return trMedian5( + isa, isaD, isaN, + first, first + t, + middle, + last - 1 - t, last - 1 + ); + } + t >>= 3; + return trMedian3( + isa, isaD, isaN, + trMedian3(isa, isaD, isaN, first, first + t, first + (t << 1)), + trMedian3(isa, isaD, isaN, middle - t, middle, middle + t), + trMedian3(isa, isaD, isaN, last - 1 - (t << 1), last - 1 - t, last - 1) + ); + } + + /*---------------------------------------------------------------------------*/ + + private void lsUpdateGroup(final int isa, final int first, final int last) { + final int[] SA = this.SA; + + int a, b; + int t; + + for (a = first; a < last; ++a) { + if (0 <= SA[a]) { + b = a; + do { + SA[isa + SA[a]] = a; + } while (++a < last && 0 <= SA[a]); + SA[b] = b - a; + if (last <= a) { + break; + } + } + b = a; + do { + SA[a] = ~SA[a]; + } while (SA[++a] < 0); + t = a; + do { + SA[isa + SA[b]] = t; + } while (++b <= a); + } + } + + private void lsIntroSort(final int isa, final int isaD, final int isaN, int first, int last) { + final int[] SA = this.SA; + + final StackEntry[] stack = new StackEntry[STACK_SIZE]; + + int a, b, c, d, e, f; + int s, t; + int limit; + int v, x = 0; + int ssize; + + for (ssize = 0, limit = trLog(last - first);;) { + if (last - first <= INSERTIONSORT_THRESHOLD) { + if (1 < last - first) { + trInsertionSort(isa, isaD, isaN, first, last); + lsUpdateGroup(isa, first, last); + } else if (last - first == 1) { + SA[first] = -1; + } + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + first = entry.a; + last = entry.b; + limit = entry.c; + continue; + } + + if (limit-- == 0) { + trHeapSort(isa, isaD, isaN, first, last - first); + for (a = last - 1; first < a; a = b) { + for (x = trGetC(isa, isaD, isaN, SA[a]), b = a - 1; + first <= b && trGetC(isa, isaD, isaN, SA[b]) == x; + --b) { + SA[b] = ~SA[b]; + } + } + lsUpdateGroup(isa, first, last); + if (ssize == 0) { + return; 
+ } + StackEntry entry = stack[--ssize]; + first = entry.a; + last = entry.b; + limit = entry.c; + continue; + } + + a = trPivot(isa, isaD, isaN, first, last); + swapElements(SA, first, SA, a); + v = trGetC(isa, isaD, isaN, SA[first]); + + b = first + 1; + while (b < last && (x = trGetC(isa, isaD, isaN, SA[b])) == v) { + ++b; + } + if ((a = b) < last && x < v) { + while (++b < last && (x = trGetC(isa, isaD, isaN, SA[b])) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + } + + c = last - 1; + while (b < c && (x = trGetC(isa, isaD, isaN, SA[c])) == v) { + --c; + } + if (b < (d = c) && x > v) { + while (b < --c && (x = trGetC(isa, isaD, isaN, SA[c])) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + while (b < c) { + swapElements(SA, b, SA, c); + while (++b < c && (x = trGetC(isa, isaD, isaN, SA[b])) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + while (b < --c && (x = trGetC(isa, isaD, isaN, SA[c])) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + + if (a <= d) { + c = b - 1; + + if ((s = a - first) > (t = b - a)) { + s = t; + } + for (e = first, f = b - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + if ((s = d - c) > (t = last - d - 1)) { + s = t; + } + for (e = b, f = last - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + + a = first + (b - a); + b = last - (d - c); + + for (c = first, v = a - 1; c < a; ++c) { + SA[isa + SA[c]] = v; + } + if (b < last) { + for (c = a, v = b - 1; c < b; ++c) { + SA[isa + SA[c]] = v; + } + } + if ((b - a) == 1) { + SA[a] = - 1; + } + + if (a - first <= last - b) { + if (first < a) { + stack[ssize++] = new StackEntry(b, last, limit, 0); + last = a; + } else { + first = b; + } + } else { + if (b < last) { + stack[ssize++] = new StackEntry(first, a, limit, 0); + first = b; + } else { + last = a; + } + } + } else { + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + first = entry.a; + last = 
entry.b; + limit = entry.c; + } + } + } + + private void lsSort(final int isa, final int n, final int depth) { + final int[] SA = this.SA; + + int isaD; + int first, last, i; + int t, skip; + + for (isaD = isa + depth; -n < SA[0]; isaD += isaD - isa) { + first = 0; + skip = 0; + do { + if ((t = SA[first]) < 0) { + first -= t; + skip += t; + } else { + if (skip != 0) { + SA[first + skip] = skip; + skip = 0; + } + last = SA[isa + t] + 1; + lsIntroSort(isa, isaD, isa + n, first, last); + first = last; + } + } while (first < n); + if (skip != 0) { + SA[first + skip] = skip; + } + if (n < isaD - isa) { + first = 0; + do { + if ((t = SA[first]) < 0) { + first -= t; + } else { + last = SA[isa + t] + 1; + for (i = first; i < last; ++i) { + SA[isa + SA[i]] = i; + } + first = last; + } + } while (first < n); + break; + } + } + } + + /*---------------------------------------------------------------------------*/ + + private static class PartitionResult { + final int first; + final int last; + + PartitionResult(final int first, final int last) { + this.first = first; + this.last = last; + } + } + + private PartitionResult trPartition(final int isa, final int isaD, final int isaN, + int first, int last, final int v) { + final int[] SA = this.SA; + + int a, b, c, d, e, f; + int t, s; + int x = 0; + + b = first; + while (b < last && (x = trGetC(isa, isaD, isaN, SA[b])) == v) { + ++b; + } + if ((a = b) < last && x < v) { + while (++b < last && (x = trGetC(isa, isaD, isaN, SA[b])) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + } + + c = last - 1; + while (b < c && (x = trGetC(isa, isaD, isaN, SA[c])) == v) { + --c; + } + if (b < (d = c) && x > v) { + while (b < --c && (x = trGetC(isa, isaD, isaN, SA[c])) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + while (b < c) { + swapElements(SA, b, SA, c); + while (++b < c && (x = trGetC(isa, isaD, isaN, SA[b])) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + while (b 
< --c && (x = trGetC(isa, isaD, isaN, SA[c])) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + + if (a <= d) { + c = b - 1; + if ((s = a - first) > (t = b - a)) { + s = t; + } + for (e = first, f = b - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + if ((s = d - c) > (t = last - d - 1)) { + s = t; + } + for (e = b, f = last - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + first += b - a; + last -= d - c; + } + return new PartitionResult(first, last); + } + + private void trCopy(final int isa, final int isaN, final int first, + final int a, final int b, final int last, final int depth) { + final int[] SA = this.SA; + + int c, d, e; + int s, v; + + v = b - 1; + + for (c = first, d = a - 1; c <= d; ++c) { + if ((s = SA[c] - depth) < 0) { + s += isaN - isa; + } + if (SA[isa + s] == v) { + SA[++d] = s; + SA[isa + s] = d; + } + } + for (c = last - 1, e = d + 1, d = b; e < d; --c) { + if ((s = SA[c] - depth) < 0) { + s += isaN - isa; + } + if (SA[isa + s] == v) { + SA[--d] = s; + SA[isa + s] = d; + } + } + } + + private void trIntroSort(final int isa, int isaD, int isaN, int first, + int last, final TRBudget budget, final int size) { + final int[] SA = this.SA; + + final StackEntry[] stack = new StackEntry[STACK_SIZE]; + + int a, b, c, d, e, f; + int s, t; + int v, x = 0; + int limit, next; + int ssize; + + for (ssize = 0, limit = trLog(last - first);;) { + if (limit < 0) { + if (limit == -1) { + if (!budget.update(size, last - first)) { + break; + } + PartitionResult result = trPartition(isa, isaD - 1, isaN, first, last, last - 1); + a = result.first; + b = result.last; + if (first < a || b < last) { + if (a < last) { + for (c = first, v = a - 1; c < a; ++c) { + SA[isa + SA[c]] = v; + } + } + if (b < last) { + for (c = a, v = b - 1; c < b; ++c) { + SA[isa + SA[c]] = v; + } + } + + stack[ssize++] = new StackEntry(0, a, b, 0); + stack[ssize++] = new StackEntry(isaD - 1, first, last, -2); + if (a - first <= last - b) 
{ + if (1 < a - first) { + stack[ssize++] = new StackEntry(isaD, b, last, trLog(last - b)); + last = a; limit = trLog(a - first); + } else if (1 < last - b) { + first = b; limit = trLog(last - b); + } else { + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + isaD = entry.a; + first = entry.b; + last = entry.c; + limit = entry.d; + } + } else { + if (1 < last - b) { + stack[ssize++] = new StackEntry(isaD, first, a, trLog(a - first)); + first = b; + limit = trLog(last - b); + } else if (1 < a - first) { + last = a; + limit = trLog(a - first); + } else { + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + isaD = entry.a; + first = entry.b; + last = entry.c; + limit = entry.d; + } + } + } else { + for (c = first; c < last; ++c) { + SA[isa + SA[c]] = c; + } + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + isaD = entry.a; + first = entry.b; + last = entry.c; + limit = entry.d; + } + } else if (limit == -2) { + a = stack[--ssize].b; + b = stack[ssize].c; + trCopy(isa, isaN, first, a, b, last, isaD - isa); + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + isaD = entry.a; + first = entry.b; + last = entry.c; + limit = entry.d; + } else { + if (0 <= SA[first]) { + a = first; + do { + SA[isa + SA[a]] = a; + } while (++a < last && 0 <= SA[a]); + first = a; + } + if (first < last) { + a = first; + do { + SA[a] = ~SA[a]; + } while (SA[++a] < 0); + next = SA[isa + SA[a]] != SA[isaD + SA[a]] ? 
trLog(a - first + 1) : -1; + if (++a < last) { + for (b = first, v = a - 1; b < a; ++b) { + SA[isa + SA[b]] = v; + } + } + + if (a - first <= last - a) { + stack[ssize++] = new StackEntry(isaD, a, last, -3); + isaD += 1; last = a; limit = next; + } else { + if (1 < last - a) { + stack[ssize++] = new StackEntry(isaD + 1, first, a, next); + first = a; limit = -3; + } else { + isaD += 1; last = a; limit = next; + } + } + } else { + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + isaD = entry.a; + first = entry.b; + last = entry.c; + limit = entry.d; + } + } + continue; + } + + if (last - first <= INSERTIONSORT_THRESHOLD) { + if (!budget.update(size, last - first)) { + break; + } + trInsertionSort(isa, isaD, isaN, first, last); + limit = -3; + continue; + } + + if (limit-- == 0) { + if (!budget.update(size, last - first)) { + break; + } + trHeapSort(isa, isaD, isaN, first, last - first); + for (a = last - 1; first < a; a = b) { + for (x = trGetC(isa, isaD, isaN, SA[a]), b = a - 1; + first <= b && trGetC(isa, isaD, isaN, SA[b]) == x; + --b) { + SA[b] = ~SA[b]; + } + } + limit = -3; + continue; + } + + a = trPivot(isa, isaD, isaN, first, last); + + swapElements(SA, first, SA, a); + v = trGetC(isa, isaD, isaN, SA[first]); + + b = first + 1; + while (b < last && (x = trGetC(isa, isaD, isaN, SA[b])) == v) { + ++b; + } + if ((a = b) < last && x < v) { + while (++b < last && (x = trGetC(isa, isaD, isaN, SA[b])) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + } + + c = last - 1; + while (b < c && (x = trGetC(isa, isaD, isaN, SA[c])) == v) { + --c; + } + if (b < (d = c) && x > v) { + while (b < --c && (x = trGetC(isa, isaD, isaN, SA[c])) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + while (b < c) { + swapElements(SA, b, SA, c); + while (++b < c && (x = trGetC(isa, isaD, isaN, SA[b])) <= v) { + if (x == v) { + swapElements(SA, b, SA, a); + ++a; + } + } + while (b < --c && (x = trGetC(isa, isaD, isaN, 
SA[c])) >= v) { + if (x == v) { + swapElements(SA, c, SA, d); + --d; + } + } + } + + if (a <= d) { + c = b - 1; + + if ((s = a - first) > (t = b - a)) { + s = t; + } + for (e = first, f = b - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + if ((s = d - c) > (t = last - d - 1)) { + s = t; + } + for (e = b, f = last - s; 0 < s; --s, ++e, ++f) { + swapElements(SA, e, SA, f); + } + + a = first + (b - a); + b = last - (d - c); + next = SA[isa + SA[a]] != v ? trLog(b - a) : -1; + + for (c = first, v = a - 1; c < a; ++c) { + SA[isa + SA[c]] = v; + } + if (b < last) { + for (c = a, v = b - 1; c < b; ++c) { + SA[isa + SA[c]] = v; } + } + + if (a - first <= last - b) { + if (last - b <= b - a) { + if (1 < a - first) { + stack[ssize++] = new StackEntry(isaD + 1, a, b, next); + stack[ssize++] = new StackEntry(isaD, b, last, limit); + last = a; + } else if (1 < last - b) { + stack[ssize++] = new StackEntry(isaD + 1, a, b, next); + first = b; + } else if (1 < b - a) { + isaD += 1; + first = a; + last = b; + limit = next; + } else { + if (ssize == 0) { + return; + } + StackEntry entry = stack[--ssize]; + isaD = entry.a; + first = entry.b; + last = entry.c; + limit = entry.d; + } + } else if (a - first <= b - a) { + if (1 < a - first) { + stack[ssize++] = new StackEntry(isaD, b, last, limit); + stack[ssize++] = new StackEntry(isaD + 1, a, b, next); + last = a; + } else if (1 < b - a) { + stack[ssize++] = new StackEntry(isaD, b, last, limit); + isaD += 1; + first = a; + last = b; + limit = next; + } else { + first = b; + } + } else { + if (1 < b - a) { + stack[ssize++] = new StackEntry(isaD, b, last, limit); + stack[ssize++] = new StackEntry(isaD, first, a, limit); + isaD += 1; + first = a; + last = b; + limit = next; + } else { + stack[ssize++] = new StackEntry(isaD, b, last, limit); + last = a; + } + } + } else { + if (a - first <= b - a) { + if (1 < last - b) { + stack[ssize++] = new StackEntry(isaD + 1, a, b, next); + stack[ssize++] = new StackEntry(isaD, first, 
a, limit); + first = b; + } else if (1 < a - first) { + stack[ssize++] = new StackEntry(isaD + 1, a, b, next); + last = a; + } else if (1 < b - a) { + isaD += 1; + first = a; + last = b; + limit = next; + } else { + stack[ssize++] = new StackEntry(isaD, first, last, limit); + } + } else if (last - b <= b - a) { + if (1 < last - b) { + stack[ssize++] = new StackEntry(isaD, first, a, limit); + stack[ssize++] = new StackEntry(isaD + 1, a, b, next); + first = b; + } else if (1 < b - a) { + stack[ssize++] = new StackEntry(isaD, first, a, limit); + isaD += 1; + first = a; + last = b; + limit = next; + } else { + last = a; + } + } else { + if (1 < b - a) { + stack[ssize++] = new StackEntry(isaD, first, a, limit); + stack[ssize++] = new StackEntry(isaD, b, last, limit); + isaD += 1; + first = a; + last = b; + limit = next; + } else { + stack[ssize++] = new StackEntry(isaD, first, a, limit); + first = b; + } + } + } + } else { + if (!budget.update(size, last - first)) { + break; // BUGFIX : Added to prevent an infinite loop in the original code + } + limit += 1; isaD += 1; + } + } + + for (s = 0; s < ssize; ++s) { + if (stack[s].d == -3) { + lsUpdateGroup(isa, stack[s].b, stack[s].c); + } + } + } + + private static class TRBudget { + int budget; + int chance; + + TRBudget(final int budget, final int chance) { + this.budget = budget; + this.chance = chance; + } + + boolean update(final int size, final int n) { + budget -= n; + if (budget <= 0) { + if (--chance == 0) { + return false; + } + budget += size; + } + return true; + } + } + + private void trSort(final int isa, final int n, final int depth) { + final int[] SA = this.SA; + + int first = 0, last; + int t; + + if (-n < SA[0]) { + TRBudget budget = new TRBudget(n, trLog(n) * 2 / 3 + 1); + do { + if ((t = SA[first]) < 0) { + first -= t; + } else { + last = SA[isa + t] + 1; + if (1 < last - first) { + trIntroSort(isa, isa + depth, isa + n, first, last, budget, n); + if (budget.chance == 0) { + /* Switch to 
Larsson-Sadakane sorting algorithm */ + if (0 < first) { + SA[0] = -first; + } + lsSort(isa, n, depth); + break; + } + } + first = last; + } + } while (first < n); + } + } + + /*---------------------------------------------------------------------------*/ + + private static int BUCKET_B(final int c0, final int c1) { + return (c1 << 8) | c0; + } + + private static int BUCKET_BSTAR(final int c0, final int c1) { + return (c0 << 8) | c1; + } + + private int sortTypeBstar(final int[] bucketA, final int[] bucketB) { + final byte[] T = this.T; + final int[] SA = this.SA; + final int n = this.n; + final int[] tempbuf = new int[256]; + + int[] buf; + int PAb, ISAb, bufoffset; + int i, j, k, t, m, bufsize; + int c0, c1; + int flag; + + for (i = 1, flag = 1; i < n; ++i) { + if (T[i - 1] != T[i]) { + if ((T[i - 1] & 0xff) > (T[i] & 0xff)) { + flag = 0; + } + break; + } + } + i = n - 1; + m = n; + + int ti, ti1, t0; + if ((ti = T[i] & 0xff) < (t0 = T[0] & 0xff) || (T[i] == T[0] && flag != 0)) { + if (flag == 0) { + ++bucketB[BUCKET_BSTAR(ti, t0)]; + SA[--m] = i; + } else { + ++bucketB[BUCKET_B(ti, t0)]; + } + for (--i; 0 <= i && (ti = T[i] & 0xff) <= (ti1 = T[i + 1] & 0xff); --i) { + ++bucketB[BUCKET_B(ti, ti1)]; + } + } + + while (0 <= i) { + do { + ++bucketA[T[i] & 0xff]; + } while (0 <= --i && (T[i] & 0xff) >= (T[i + 1] & 0xff)); + if (0 <= i) { + ++bucketB[BUCKET_BSTAR(T[i] & 0xff, T[i + 1] & 0xff)]; + SA[--m] = i; + for (--i; 0 <= i && (ti = T[i] & 0xff) <= (ti1 = T[i + 1] & 0xff); --i) { + ++bucketB[BUCKET_B(ti, ti1)]; + } + } + } + m = n - m; + if (m == 0) { + for (i = 0; i < n; ++i) { + SA[i] = i; + } + return 0; + } + + for (c0 = 0, i = -1, j = 0; c0 < 256; ++c0) { + t = i + bucketA[c0]; + bucketA[c0] = i + j; + i = t + bucketB[BUCKET_B(c0, c0)]; + for (c1 = c0 + 1; c1 < 256; ++c1) { + j += bucketB[BUCKET_BSTAR(c0, c1)]; + bucketB[(c0 << 8) | c1] = j; + i += bucketB[BUCKET_B(c0, c1)]; + } + } + + PAb = n - m; + ISAb = m; + for (i = m - 2; 0 <= i; --i) { + t = SA[PAb + 
i]; + c0 = T[t] & 0xff; + c1 = T[t + 1] & 0xff; + SA[--bucketB[BUCKET_BSTAR(c0, c1)]] = i; + } + t = SA[PAb + m - 1]; + c0 = T[t] & 0xff; + c1 = T[t + 1] & 0xff; + SA[--bucketB[BUCKET_BSTAR(c0, c1)]] = m - 1; + + buf = SA; + bufoffset = m; + bufsize = n - 2 * m; + if (bufsize <= 256) { + buf = tempbuf; + bufoffset = 0; + bufsize = 256; + } + + for (c0 = 255, j = m; 0 < j; --c0) { + for (c1 = 255; c0 < c1; j = i, --c1) { + i = bucketB[BUCKET_BSTAR(c0, c1)]; + if (1 < j - i) { + subStringSort(PAb, i, j, buf, bufoffset, bufsize, 2, SA[i] == m - 1, n); + } + } + } + + for (i = m - 1; 0 <= i; --i) { + if (0 <= SA[i]) { + j = i; + do { + SA[ISAb + SA[i]] = i; + } while (0 <= --i && 0 <= SA[i]); + SA[i + 1] = i - j; + if (i <= 0) { + break; + } + } + j = i; + do { + SA[ISAb + (SA[i] = ~SA[i])] = j; + } while (SA[--i] < 0); + SA[ISAb + SA[i]] = j; + } + + trSort(ISAb, m, 1); + + i = n - 1; j = m; + if ((T[i] & 0xff) < (T[0] & 0xff) || (T[i] == T[0] && flag != 0)) { + if (flag == 0) { + SA[SA[ISAb + --j]] = i; + } + for (--i; 0 <= i && (T[i] & 0xff) <= (T[i + 1] & 0xff);) { + --i; + } + } + while (0 <= i) { + for (--i; 0 <= i && (T[i] & 0xff) >= (T[i + 1] & 0xff);) { + --i; + } + if (0 <= i) { + SA[SA[ISAb + --j]] = i; + for (--i; 0 <= i && (T[i] & 0xff) <= (T[i + 1] & 0xff);) { + --i; + } + } + } + + for (c0 = 255, i = n - 1, k = m - 1; 0 <= c0; --c0) { + for (c1 = 255; c0 < c1; --c1) { + t = i - bucketB[BUCKET_B(c0, c1)]; + bucketB[BUCKET_B(c0, c1)] = i + 1; + + for (i = t, j = bucketB[BUCKET_BSTAR(c0, c1)]; j <= k; --i, --k) { + SA[i] = SA[k]; + } + } + t = i - bucketB[BUCKET_B(c0, c0)]; + bucketB[BUCKET_B(c0, c0)] = i + 1; + if (c0 < 255) { + bucketB[BUCKET_BSTAR(c0, c0 + 1)] = t + 1; + } + i = bucketA[c0]; + } + return m; + } + + private int constructBWT(final int[] bucketA, final int[] bucketB) { + final byte[] T = this.T; + final int[] SA = this.SA; + final int n = this.n; + + int i, j, t = 0; + int s, s1; + int c0, c1, c2 = 0; + int orig = -1; + + for (c1 = 254; 0 
<= c1; --c1) { + for (i = bucketB[BUCKET_BSTAR(c1, c1 + 1)], j = bucketA[c1 + 1], t = 0, c2 = -1; + i <= j; + --j) { + if (0 <= (s1 = s = SA[j])) { + if (--s < 0) { + s = n - 1; + } + if ((c0 = T[s] & 0xff) <= c1) { + SA[j] = ~s1; + if (0 < s && (T[s - 1] & 0xff) > c0) { + s = ~s; + } + if (c2 == c0) { + SA[--t] = s; + } else { + if (0 <= c2) { + bucketB[BUCKET_B(c2, c1)] = t; + } + SA[t = bucketB[BUCKET_B(c2 = c0, c1)] - 1] = s; + } + } + } else { + SA[j] = ~s; + } + } + } + + for (i = 0; i < n; ++i) { + if (0 <= (s1 = s = SA[i])) { + if (--s < 0) { + s = n - 1; + } + if ((c0 = T[s] & 0xff) >= (T[s + 1] & 0xff)) { + if (0 < s && (T[s - 1] & 0xff) < c0) { + s = ~s; + } + if (c0 == c2) { + SA[++t] = s; + } else { + if (c2 != -1) { + bucketA[c2] = t; // BUGFIX: Original code can write to bucketA[-1] + } + SA[t = bucketA[c2 = c0] + 1] = s; + } + } + } else { + s1 = ~s1; + } + + if (s1 == 0) { + SA[i] = T[n - 1]; + orig = i; + } else { + SA[i] = T[s1 - 1]; + } + } + return orig; + } + + /** + * Performs a Burrows Wheeler Transform on the input array. 
 * @return the index of the first character of the input array within the output array
 */
public int bwt() {
    final int[] SA = this.SA;
    final byte[] T = this.T;
    final int n = this.n;

    // Character-pair histogram buckets used by the suffix sort and BWT construction.
    final int[] bucketA = new int[BUCKET_A_SIZE];
    final int[] bucketB = new int[BUCKET_B_SIZE];

    // Degenerate inputs: empty block transforms to nothing; a single byte is its own BWT
    // with the original character at index 0.
    if (n == 0) {
        return 0;
    }
    if (n == 1) {
        SA[0] = T[0];
        return 0;
    }

    // Sort the type-B* suffixes; if any exist, build the BWT from the sorted order.
    // m == 0 means the whole block is a run with no B* suffixes, in which case the
    // transform start index is 0.
    int m = sortTypeBstar(bucketA, bucketB);
    if (0 < m) {
        return constructBWT(bucketA, bucketB);
    }
    return 0;
}
}
diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanAllocator.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanAllocator.java
new file mode 100644
index 0000000..8cbfef5
--- /dev/null
+++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanAllocator.java
@@ -0,0 +1,184 @@
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.bzip2;

/**
 * An in-place, length restricted Canonical Huffman code length allocator.
 * Based on the algorithm proposed by R. L. Milidi'u, A. A. Pessoa and E. S. Laber in
 * In-place Length-Restricted Prefix Coding
 * and incorporating additional ideas from the implementation of
 * shcodec by Simakov Alexander.
 */
final class Bzip2HuffmanAllocator {
    /**
     * Locates the boundary of the current tree level within the combined node/leaf array.
     * @param array The code length array
     * @param i The input position
     * @param nodesToMove The number of internal nodes to be relocated
     * @return The smallest {@code k} such that {@code nodesToMove <= k <= i} and
     *         {@code i <= (array[k] % array.length)}
     */
    private static int first(final int[] array, int i, final int nodesToMove) {
        final int length = array.length;
        final int limit = i;
        int k = array.length - 2;

        // Gallop backwards with growing strides until the predicate flips,
        // bracketing the answer between i (exclusive) and k (inclusive).
        while (i >= nodesToMove && array[i] % length > limit) {
            k = i;
            i -= limit - i + 1;
        }
        i = Math.max(nodesToMove - 1, i);

        // Binary search within the bracket for the smallest qualifying k.
        while (k > i + 1) {
            int temp = i + k >>> 1;
            if (array[temp] % length > limit) {
                k = temp;
            } else {
                i = temp;
            }
        }
        return k;
    }

    /**
     * Fills the code array with extended parent pointers.
     * On entry {@code array} holds sorted symbol frequencies; on exit each slot holds the
     * index of its parent node, offset by {@code array.length} for the second child
     * (so {@code value % array.length} recovers the parent index).
     * @param array The code length array
     */
    private static void setExtendedParentPointers(final int[] array) {
        final int length = array.length;
        array[0] += array[1];

        // headNode scans already-built internal nodes, topNode scans remaining leaves;
        // each iteration merges the two cheapest available nodes into tailNode.
        for (int headNode = 0, tailNode = 1, topNode = 2; tailNode < length - 1; tailNode++) {
            int temp;
            if (topNode >= length || array[headNode] < array[topNode]) {
                temp = array[headNode];
                array[headNode++] = tailNode;
            } else {
                temp = array[topNode++];
            }

            if (topNode >= length || (headNode < tailNode && array[headNode] < array[topNode])) {
                temp += array[headNode];
                // Second-child links are tagged by adding 'length' so they can be
                // distinguished from first-child links via the modulo above.
                array[headNode++] = tailNode + length;
            } else {
                temp += array[topNode++];
            }
            array[tailNode] = temp;
        }
    }

    /**
     * Finds the number of nodes to relocate in order to achieve a given code length limit.
     * @param array The code length array
     * @param maximumLength The maximum bit length for the generated codes
     * @return The number of nodes to relocate
     */
    private static int findNodesToRelocate(final int[] array, final int maximumLength) {
        // Walk down maximumLength - 2 levels from the root region of the parent-pointer
        // array; what remains is the set of nodes that would exceed the length limit.
        int currentNode = array.length - 2;
        for (int currentDepth = 1; currentDepth < maximumLength - 1 && currentNode > 1; currentDepth++) {
            currentNode = first(array, currentNode - 1, 0);
        }
        return currentNode;
    }

    /**
     * A final allocation pass with no code length limit.
     * Converts the parent-pointer array (in place) into code lengths, deepest codes last.
     * @param array The code length array
     */
    private static void allocateNodeLengths(final int[] array) {
        int firstNode = array.length - 2;
        int nextNode = array.length - 1;

        // Each loop iteration processes one tree depth; availableNodes doubles per level
        // minus the internal nodes consumed at this level.
        for (int currentDepth = 1, availableNodes = 2; availableNodes > 0; currentDepth++) {
            final int lastNode = firstNode;
            firstNode = first(array, lastNode - 1, 0);

            for (int i = availableNodes - (lastNode - firstNode); i > 0; i--) {
                array[nextNode--] = currentDepth;
            }

            availableNodes = (lastNode - firstNode) << 1;
        }
    }

    /**
     * A final allocation pass that relocates nodes in order to achieve a maximum code length limit.
     * @param array The code length array
     * @param nodesToMove The number of internal nodes to be relocated
     * @param insertDepth The depth at which to insert relocated nodes
     */
    private static void allocateNodeLengthsWithRelocation(final int[] array,
            final int nodesToMove, final int insertDepth) {
        int firstNode = array.length - 2;
        int nextNode = array.length - 1;
        // When inserting at depth 1 the root's two slots are consumed immediately.
        int currentDepth = insertDepth == 1 ? 2 : 1;
        int nodesLeftToMove = insertDepth == 1 ? nodesToMove - 2 : nodesToMove;

        for (int availableNodes = currentDepth << 1; availableNodes > 0; currentDepth++) {
            final int lastNode = firstNode;
            firstNode = firstNode <= nodesToMove ? firstNode : first(array, lastNode - 1, nodesToMove);

            // 'offset' is the number of relocated nodes spliced in at this depth.
            int offset = 0;
            if (currentDepth >= insertDepth) {
                offset = Math.min(nodesLeftToMove, 1 << (currentDepth - insertDepth));
            } else if (currentDepth == insertDepth - 1) {
                offset = 1;
                if (array[firstNode] == lastNode) {
                    firstNode++;
                }
            }

            for (int i = availableNodes - (lastNode - firstNode + offset); i > 0; i--) {
                array[nextNode--] = currentDepth;
            }

            nodesLeftToMove -= offset;
            availableNodes = (lastNode - firstNode + offset) << 1;
        }
    }

    /**
     * Allocates Canonical Huffman code lengths in place based on a sorted frequency array.
     * @param array On input, a sorted array of symbol frequencies; On output, an array of Canonical
     *              Huffman code lengths
     * @param maximumLength The maximum code length. Must be at least {@code ceil(log2(array.length))}
     */
    static void allocateHuffmanCodeLengths(final int[] array, final int maximumLength) {
        // Alphabets of size 1 or 2 need no tree: every code is one bit.
        // Larger alphabets fall past the switch into the three-pass algorithm.
        switch (array.length) {
            case 2:
                array[1] = 1;
                // fall through
            case 1:
                array[0] = 1;
                return;
        }

        /* Pass 1 : Set extended parent pointers */
        setExtendedParentPointers(array);

        /* Pass 2 : Find number of nodes to relocate in order to achieve maximum code length */
        int nodesToRelocate = findNodesToRelocate(array, maximumLength);

        /* Pass 3 : Generate code lengths */
        if (array[0] % array.length >= nodesToRelocate) {
            // Unrestricted allocation already satisfies the length limit.
            allocateNodeLengths(array);
        } else {
            // 32 - numberOfLeadingZeros(x - 1) == ceil(log2(x)) for x > 1.
            int insertDepth = maximumLength - (32 - Integer.numberOfLeadingZeros(nodesToRelocate - 1));
            allocateNodeLengthsWithRelocation(array, nodesToRelocate, insertDepth);
        }
    }

    // Utility class: no instances.
    private Bzip2HuffmanAllocator() { }
}
diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageDecoder.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageDecoder.java
new file mode 100644
index 0000000..8af3edf
--- /dev/null
+++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageDecoder.java
@@ -0,0 +1,203 @@
/*
 * Copyright 2014 The Netty Project
 *
 *
The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.bzip2;

import static io.netty.bzip2.Bzip2Constants.HUFFMAN_DECODE_MAX_CODE_LENGTH;
import static io.netty.bzip2.Bzip2Constants.HUFFMAN_GROUP_RUN_LENGTH;
import static io.netty.bzip2.Bzip2Constants.HUFFMAN_MAX_ALPHABET_SIZE;

/**
 * A decoder for the Bzip2 Huffman coding stage.
 */
public final class Bzip2HuffmanStageDecoder {
    /**
     * A reader that provides bit-level reads.
     */
    private final Bzip2BitReader reader;

    /**
     * The Huffman table number to use for each group of 50 symbols.
     */
    public byte[] selectors;

    /**
     * The minimum code length for each Huffman table.
     */
    private final int[] minimumLengths;

    /**
     * An array of values for each Huffman table that must be subtracted from the numerical value of
     * a Huffman code of a given bit length to give its canonical code index.
     */
    private final int[][] codeBases;

    /**
     * An array of values for each Huffman table that gives the highest numerical value of a Huffman
     * code of a given bit length.
     */
    private final int[][] codeLimits;

    /**
     * A mapping for each Huffman table from canonical code index to output symbol.
     */
    private final int[][] codeSymbols;

    /**
     * The Huffman table for the current group.
     */
    private int currentTable;

    /**
     * The index of the current group within the selectors array.
     */
    private int groupIndex = -1;

    /**
     * The byte position within the current group. A new group is selected every 50 decoded bytes.
     */
    private int groupPosition = -1;

    /**
     * Total number of used Huffman tables in range 2..6.
     */
    public final int totalTables;

    /**
     * The total number of codes (uniform for each table).
     */
    public final int alphabetSize;

    /**
     * Table for Move To Front transformations.
     */
    public final Bzip2MoveToFrontTable tableMTF = new Bzip2MoveToFrontTable();

    // For saving state if end of current ByteBuf was reached
    public int currentSelector;

    /**
     * The Canonical Huffman code lengths for each table.
     */
    public final byte[][] tableCodeLengths;

    // For saving state if end of current ByteBuf was reached
    public int currentGroup;
    public int currentLength = -1;
    public int currentAlpha;
    public boolean modifyLength;

    /**
     * @param reader provides bit-level reads of the compressed stream
     * @param totalTables number of Huffman tables in use (2..6)
     * @param alphabetSize number of codes in each table
     */
    public Bzip2HuffmanStageDecoder(final Bzip2BitReader reader, final int totalTables, final int alphabetSize) {
        this.reader = reader;
        this.totalTables = totalTables;
        this.alphabetSize = alphabetSize;

        minimumLengths = new int[totalTables];
        // +2 / +1 sizing leaves room for the sentinel slots used while building the
        // cumulative base/limit arrays in createHuffmanDecodingTables().
        codeBases = new int[totalTables][HUFFMAN_DECODE_MAX_CODE_LENGTH + 2];
        codeLimits = new int[totalTables][HUFFMAN_DECODE_MAX_CODE_LENGTH + 1];
        codeSymbols = new int[totalTables][HUFFMAN_MAX_ALPHABET_SIZE];
        tableCodeLengths = new byte[totalTables][HUFFMAN_MAX_ALPHABET_SIZE];
    }

    /**
     * Constructs Huffman decoding tables from lists of Canonical Huffman code lengths.
     * Must be called after {@link #tableCodeLengths} and {@link #selectors} have been populated.
     */
    public void createHuffmanDecodingTables() {
        final int alphabetSize = this.alphabetSize;

        for (int table = 0; table < tableCodeLengths.length; table++) {
            final int[] tableBases = codeBases[table];
            final int[] tableLimits = codeLimits[table];
            final int[] tableSymbols = codeSymbols[table];
            final byte[] codeLengths = tableCodeLengths[table];

            int minimumLength = HUFFMAN_DECODE_MAX_CODE_LENGTH;
            int maximumLength = 0;

            // Find the minimum and maximum code length for the table
            for (int i = 0; i < alphabetSize; i++) {
                final byte currLength = codeLengths[i];
                maximumLength = Math.max(currLength, maximumLength);
                minimumLength = Math.min(currLength, minimumLength);
            }
            minimumLengths[table] = minimumLength;

            // Calculate the first output symbol for each code length:
            // histogram the lengths (offset by one), then prefix-sum into cumulative counts.
            for (int i = 0; i < alphabetSize; i++) {
                tableBases[codeLengths[i] + 1]++;
            }
            for (int i = 1, b = tableBases[0]; i < HUFFMAN_DECODE_MAX_CODE_LENGTH + 2; i++) {
                b += tableBases[i];
                tableBases[i] = b;
            }

            // Calculate the first and last Huffman code for each code length (codes at a given
            // length are sequential in value). After this loop tableBases[i] is the value to
            // subtract from a code of length i to obtain its canonical symbol index, and
            // tableLimits[i] is the highest code value of length i.
            for (int i = minimumLength, code = 0; i <= maximumLength; i++) {
                int base = code;
                code += tableBases[i + 1] - tableBases[i];
                tableBases[i] = base - tableBases[i];
                tableLimits[i] = code - 1;
                code <<= 1;
            }

            // Populate the mapping from canonical code index to output symbol
            for (int bitLength = minimumLength, codeIndex = 0; bitLength <= maximumLength; bitLength++) {
                for (int symbol = 0; symbol < alphabetSize; symbol++) {
                    if (codeLengths[symbol] == bitLength) {
                        tableSymbols[codeIndex++] = symbol;
                    }
                }
            }
        }

        currentTable = selectors[0];
    }

    /**
     * Decodes and returns the next symbol.
     * @return The decoded symbol
     * @throws DecompressionException if the selector list is exhausted or no valid code is found
     */
    public int nextSymbol() {
        // Move to next group selector if required (a new table every
        // HUFFMAN_GROUP_RUN_LENGTH decoded symbols).
        if (++groupPosition % HUFFMAN_GROUP_RUN_LENGTH == 0) {
            groupIndex++;
            if (groupIndex == selectors.length) {
                throw new DecompressionException("error decoding block");
            }
            currentTable = selectors[groupIndex] & 0xff;
        }

        final Bzip2BitReader reader = this.reader;
        final int currentTable = this.currentTable;
        final int[] tableLimits = codeLimits[currentTable];
        final int[] tableBases = codeBases[currentTable];
        final int[] tableSymbols = codeSymbols[currentTable];
        int codeLength = minimumLengths[currentTable];

        // Starting with the minimum bit length for the table, read additional bits one at a time
        // until a complete code is recognised
        int codeBits = reader.readBits(codeLength);
        for (; codeLength <= HUFFMAN_DECODE_MAX_CODE_LENGTH; codeLength++) {
            if (codeBits <= tableLimits[codeLength]) {
                // Convert the code to a symbol index and return
                return tableSymbols[codeBits - tableBases[codeLength]];
            }
            codeBits = codeBits << 1 | reader.readBits(1);
        }

        throw new DecompressionException("a valid code was not recognised");
    }
}
diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageEncoder.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageEncoder.java
new file mode 100644
index 0000000..ac63e34
--- /dev/null
+++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2HuffmanStageEncoder.java
@@ -0,0 +1,374 @@
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +import io.netty.buffer.ByteBuf; + +import java.util.Arrays; + +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_ENCODE_MAX_CODE_LENGTH; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_GROUP_RUN_LENGTH; + +/** + * An encoder for the Bzip2 Huffman encoding stage. + */ +public final class Bzip2HuffmanStageEncoder { + /** + * Used in initial Huffman table generation. + */ + private static final int HUFFMAN_HIGH_SYMBOL_COST = 15; + + /** + * The {@link Bzip2BitWriter} to which the Huffman tables and data is written. + */ + private final Bzip2BitWriter writer; + + /** + * The output of the Move To Front Transform and Run Length Encoding[2] stages. + */ + private final char[] mtfBlock; + + /** + * The actual number of values contained in the {@link #mtfBlock} array. + */ + private final int mtfLength; + + /** + * The number of unique values in the {@link #mtfBlock} array. + */ + private final int mtfAlphabetSize; + + /** + * The global frequencies of values within the {@link #mtfBlock} array. + */ + private final int[] mtfSymbolFrequencies; + + /** + * The Canonical Huffman code lengths for each table. + */ + private final int[][] huffmanCodeLengths; + + /** + * Merged code symbols for each table. The value at each position is ((code length << 24) | code). + */ + private final int[][] huffmanMergedCodeSymbols; + + /** + * The selectors for each segment. 
+ */ + private final byte[] selectors; + + /** + * @param writer The {@link Bzip2BitWriter} which provides bit-level writes + * @param mtfBlock The MTF block data + * @param mtfLength The actual length of the MTF block + * @param mtfAlphabetSize The size of the MTF block's alphabet + * @param mtfSymbolFrequencies The frequencies the MTF block's symbols + */ + Bzip2HuffmanStageEncoder(final Bzip2BitWriter writer, final char[] mtfBlock, + final int mtfLength, final int mtfAlphabetSize, final int[] mtfSymbolFrequencies) { + this.writer = writer; + this.mtfBlock = mtfBlock; + this.mtfLength = mtfLength; + this.mtfAlphabetSize = mtfAlphabetSize; + this.mtfSymbolFrequencies = mtfSymbolFrequencies; + + final int totalTables = selectTableCount(mtfLength); + + huffmanCodeLengths = new int[totalTables][mtfAlphabetSize]; + huffmanMergedCodeSymbols = new int[totalTables][mtfAlphabetSize]; + selectors = new byte[(mtfLength + HUFFMAN_GROUP_RUN_LENGTH - 1) / HUFFMAN_GROUP_RUN_LENGTH]; + } + + /** + * Selects an appropriate table count for a given MTF length. + * @param mtfLength The length to select a table count for + * @return The selected table count + */ + private static int selectTableCount(final int mtfLength) { + if (mtfLength >= 2400) { + return 6; + } + if (mtfLength >= 1200) { + return 5; + } + if (mtfLength >= 600) { + return 4; + } + if (mtfLength >= 200) { + return 3; + } + return 2; + } + + /** + * Generate a Huffman code length table for a given list of symbol frequencies. 
+ * @param alphabetSize The total number of symbols + * @param symbolFrequencies The frequencies of the symbols + * @param codeLengths The array to which the generated code lengths should be written + */ + private static void generateHuffmanCodeLengths(final int alphabetSize, + final int[] symbolFrequencies, final int[] codeLengths) { + + final int[] mergedFrequenciesAndIndices = new int[alphabetSize]; + final int[] sortedFrequencies = new int[alphabetSize]; + + // The Huffman allocator needs its input symbol frequencies to be sorted, but we need to + // return code lengths in the same order as the corresponding frequencies are passed in. + + // The symbol frequency and index are merged into a single array of + // integers - frequency in the high 23 bits, index in the low 9 bits. + // 2^23 = 8,388,608 which is higher than the maximum possible frequency for one symbol in a block + // 2^9 = 512 which is higher than the maximum possible alphabet size (== 258) + // Sorting this array simultaneously sorts the frequencies and + // leaves a lookup that can be used to cheaply invert the sort. 
+ for (int i = 0; i < alphabetSize; i++) { + mergedFrequenciesAndIndices[i] = (symbolFrequencies[i] << 9) | i; + } + Arrays.sort(mergedFrequenciesAndIndices); + for (int i = 0; i < alphabetSize; i++) { + sortedFrequencies[i] = mergedFrequenciesAndIndices[i] >>> 9; + } + + // Allocate code lengths - the allocation is in place, + // so the code lengths will be in the sortedFrequencies array afterwards + Bzip2HuffmanAllocator.allocateHuffmanCodeLengths(sortedFrequencies, HUFFMAN_ENCODE_MAX_CODE_LENGTH); + + // Reverse the sort to place the code lengths in the same order as the symbols whose frequencies were passed in + for (int i = 0; i < alphabetSize; i++) { + codeLengths[mergedFrequenciesAndIndices[i] & 0x1ff] = sortedFrequencies[i]; + } + } + + /** + * Generate initial Huffman code length tables, giving each table a different low cost section + * of the alphabet that is roughly equal in overall cumulative frequency. Note that the initial + * tables are invalid for actual Huffman code generation, and only serve as the seed for later + * iterative optimisation in {@link #optimiseSelectorsAndHuffmanTables(boolean)}. 
+ */ + private void generateHuffmanOptimisationSeeds() { + final int[][] huffmanCodeLengths = this.huffmanCodeLengths; + final int[] mtfSymbolFrequencies = this.mtfSymbolFrequencies; + final int mtfAlphabetSize = this.mtfAlphabetSize; + + final int totalTables = huffmanCodeLengths.length; + + int remainingLength = mtfLength; + int lowCostEnd = -1; + + for (int i = 0; i < totalTables; i++) { + + final int targetCumulativeFrequency = remainingLength / (totalTables - i); + final int lowCostStart = lowCostEnd + 1; + int actualCumulativeFrequency = 0; + + while (actualCumulativeFrequency < targetCumulativeFrequency && lowCostEnd < mtfAlphabetSize - 1) { + actualCumulativeFrequency += mtfSymbolFrequencies[++lowCostEnd]; + } + + if (lowCostEnd > lowCostStart && i != 0 && i != totalTables - 1 && (totalTables - i & 1) == 0) { + actualCumulativeFrequency -= mtfSymbolFrequencies[lowCostEnd--]; + } + + final int[] tableCodeLengths = huffmanCodeLengths[i]; + for (int j = 0; j < mtfAlphabetSize; j++) { + if (j < lowCostStart || j > lowCostEnd) { + tableCodeLengths[j] = HUFFMAN_HIGH_SYMBOL_COST; + } + } + + remainingLength -= actualCumulativeFrequency; + } + } + + /** + * Co-optimise the selector list and the alternative Huffman table code lengths. This method is + * called repeatedly in the hope that the total encoded size of the selectors, the Huffman code + * lengths and the block data encoded with them will converge towards a minimum.
+     * If the data is highly incompressible, it is possible that the total encoded size will
+     * instead diverge (increase) slightly.
+     * @param storeSelectors If {@code true}, write out the (final) chosen selectors
+     */
+    private void optimiseSelectorsAndHuffmanTables(final boolean storeSelectors) {
+        final char[] mtfBlock = this.mtfBlock;
+        final byte[] selectors = this.selectors;
+        final int[][] huffmanCodeLengths = this.huffmanCodeLengths;
+        final int mtfLength = this.mtfLength;
+        final int mtfAlphabetSize = this.mtfAlphabetSize;
+
+        final int totalTables = huffmanCodeLengths.length;
+        // Per-table symbol frequencies, accumulated from the groups that choose each table
+        final int[][] tableFrequencies = new int[totalTables][mtfAlphabetSize];
+
+        int selectorIndex = 0;
+
+        // Find the best table for each group of 50 block bytes based on the current Huffman code lengths
+        for (int groupStart = 0; groupStart < mtfLength;) {
+
+            final int groupEnd = Math.min(groupStart + HUFFMAN_GROUP_RUN_LENGTH, mtfLength) - 1;
+
+            // Calculate the cost of this group when encoded by each table; a symbol's cost in bits
+            // is its current code length in that table
+            int[] cost = new int[totalTables];
+            for (int i = groupStart; i <= groupEnd; i++) {
+                final int value = mtfBlock[i];
+                for (int j = 0; j < totalTables; j++) {
+                    cost[j] += huffmanCodeLengths[j][value];
+                }
+            }
+
+            // Find the table with the least cost for this group (ties go to the lowest index)
+            byte bestTable = 0;
+            int bestCost = cost[0];
+            for (byte i = 1 ; i < totalTables; i++) {
+                final int tableCost = cost[i];
+                if (tableCost < bestCost) {
+                    bestCost = tableCost;
+                    bestTable = i;
+                }
+            }
+
+            // Accumulate symbol frequencies for the table chosen for this block
+            final int[] bestGroupFrequencies = tableFrequencies[bestTable];
+            for (int i = groupStart; i <= groupEnd; i++) {
+                bestGroupFrequencies[mtfBlock[i]]++;
+            }
+
+            // Store a selector indicating the table chosen for this block
+            if (storeSelectors) {
+                selectors[selectorIndex++] = bestTable;
+            }
+            groupStart = groupEnd + 1;
+        }
+
+        // Generate new Huffman code lengths based on the frequencies for each table accumulated in this iteration
+        for (int i = 0; i < totalTables; i++) {
+            generateHuffmanCodeLengths(mtfAlphabetSize, tableFrequencies[i], huffmanCodeLengths[i]);
+        }
+    }
+ + /** + * Assigns Canonical Huffman codes based on the calculated lengths. + */ + private void assignHuffmanCodeSymbols() { + final int[][] huffmanMergedCodeSymbols = this.huffmanMergedCodeSymbols; + final int[][] huffmanCodeLengths = this.huffmanCodeLengths; + final int mtfAlphabetSize = this.mtfAlphabetSize; + + final int totalTables = huffmanCodeLengths.length; + + for (int i = 0; i < totalTables; i++) { + final int[] tableLengths = huffmanCodeLengths[i]; + + int minimumLength = 32; + int maximumLength = 0; + for (int j = 0; j < mtfAlphabetSize; j++) { + final int length = tableLengths[j]; + if (length > maximumLength) { + maximumLength = length; + } + if (length < minimumLength) { + minimumLength = length; + } + } + + int code = 0; + for (int j = minimumLength; j <= maximumLength; j++) { + for (int k = 0; k < mtfAlphabetSize; k++) { + if ((huffmanCodeLengths[i][k] & 0xff) == j) { + huffmanMergedCodeSymbols[i][k] = (j << 24) | code; + code++; + } + } + code <<= 1; + } + } + } + + /** + * Write out the selector list and Huffman tables. 
+     */
+    private void writeSelectorsAndHuffmanTables(ByteBuf out) {
+        final Bzip2BitWriter writer = this.writer;
+        final byte[] selectors = this.selectors;
+        final int totalSelectors = selectors.length;
+        final int[][] huffmanCodeLengths = this.huffmanCodeLengths;
+        final int totalTables = huffmanCodeLengths.length;
+        final int mtfAlphabetSize = this.mtfAlphabetSize;
+
+        writer.writeBits(out, 3, totalTables);
+        writer.writeBits(out, 15, totalSelectors);
+
+        // Write the selectors, MTF transformed and then unary encoded
+        Bzip2MoveToFrontTable selectorMTF = new Bzip2MoveToFrontTable();
+        for (byte selector : selectors) {
+            writer.writeUnary(out, selectorMTF.valueToFront(selector));
+        }
+
+        // Write the Huffman tables: the first code length is sent as a 5 bit value; every
+        // subsequent length is delta coded against the previous one (2 bits per step:
+        // 2 = increment, 3 = decrement), terminated by a 0 bit
+        for (final int[] tableLengths : huffmanCodeLengths) {
+            int currentLength = tableLengths[0];
+
+            writer.writeBits(out, 5, currentLength);
+
+            for (int j = 0; j < mtfAlphabetSize; j++) {
+                final int codeLength = tableLengths[j];
+                final int value = currentLength < codeLength ? 2 : 3;
+                int delta = Math.abs(codeLength - currentLength);
+                while (delta-- > 0) {
+                    writer.writeBits(out, 2, value);
+                }
+                writer.writeBoolean(out, false);
+                currentLength = codeLength;
+            }
+        }
+    }
+
+    /**
+     * Writes out the encoded block data.
+     */
+    private void writeBlockData(ByteBuf out) {
+        final Bzip2BitWriter writer = this.writer;
+        final int[][] huffmanMergedCodeSymbols = this.huffmanMergedCodeSymbols;
+        final byte[] selectors = this.selectors;
+        final int mtfLength = this.mtfLength;
+
+        int selectorIndex = 0;
+        for (int mtfIndex = 0; mtfIndex < mtfLength;) {
+            final int groupEnd = Math.min(mtfIndex + HUFFMAN_GROUP_RUN_LENGTH, mtfLength) - 1;
+            // Each group of symbols is coded with the table named by the next selector
+            final int[] tableMergedCodeSymbols = huffmanMergedCodeSymbols[selectors[selectorIndex++]];
+
+            while (mtfIndex <= groupEnd) {
+                // Bits 24+ of the merged entry hold the code length, the low bits hold the code
+                final int mergedCodeSymbol = tableMergedCodeSymbols[mtfBlock[mtfIndex++]];
+                writer.writeBits(out, mergedCodeSymbol >>> 24, mergedCodeSymbol);
+            }
+        }
+    }
+
+    /**
+     * Encodes and writes the block data.
+     */
+    void encode(ByteBuf out) {
+        // Create optimised selector list and Huffman tables
+        generateHuffmanOptimisationSeeds();
+        // Three refinement passes, then a final pass that also records the chosen selectors
+        for (int i = 3; i >= 0; i--) {
+            optimiseSelectorsAndHuffmanTables(i == 0);
+        }
+        assignHuffmanCodeSymbols();
+
+        // Write out the tables and the block data encoded with them
+        writeSelectorsAndHuffmanTables(out);
+        writeBlockData(out);
+    }
+}
diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MTFAndRLE2StageEncoder.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MTFAndRLE2StageEncoder.java
new file mode 100644
index 0000000..9ddc984
--- /dev/null
+++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MTFAndRLE2StageEncoder.java
@@ -0,0 +1,185 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.bzip2;
+
+import static io.netty.bzip2.Bzip2Constants.HUFFMAN_MAX_ALPHABET_SIZE;
+import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SYMBOL_RUNA;
+import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SYMBOL_RUNB;
+
+/**
+ * An encoder for the Bzip2 Move To Front Transform and Run-Length Encoding[2] stages.
+ * Although conceptually these two stages are separate, it is computationally efficient to perform
+ * them in one pass.
+ */
+public final class Bzip2MTFAndRLE2StageEncoder {
+    /**
+     * The Burrows-Wheeler transformed block.
+     */
+    private final int[] bwtBlock;
+
+    /**
+     * Actual length of the data in the {@link #bwtBlock} array.
+     */
+    private final int bwtLength;
+
+    /**
+     * At each position, {@code true} if the byte value with that index is present within the block,
+     * otherwise {@code false}.
+     */
+    private final boolean[] bwtValuesPresent;
+
+    /**
+     * The output of the Move To Front Transform and Run-Length Encoding[2] stages.
+     */
+    private final char[] mtfBlock;
+
+    /**
+     * The actual number of values contained in the {@link #mtfBlock} array.
+     */
+    private int mtfLength;
+
+    /**
+     * The global frequencies of values within the {@link #mtfBlock} array.
+     */
+    private final int[] mtfSymbolFrequencies = new int[HUFFMAN_MAX_ALPHABET_SIZE];
+
+    /**
+     * The encoded alphabet size.
+     */
+    private int alphabetSize;
+
+    /**
+     * @param bwtBlock The Burrows Wheeler Transformed block data
+     * @param bwtLength The actual length of the BWT data
+     * @param bwtValuesPresent The values that are present within the BWT data. For each index,
+     *            {@code true} if that value is present within the data, otherwise {@code false}
+     */
+    Bzip2MTFAndRLE2StageEncoder(final int[] bwtBlock, final int bwtLength, final boolean[] bwtValuesPresent) {
+        this.bwtBlock = bwtBlock;
+        this.bwtLength = bwtLength;
+        this.bwtValuesPresent = bwtValuesPresent;
+        // One extra slot for the end-of-block symbol appended by encode()
+        mtfBlock = new char[bwtLength + 1];
+    }
+
+    /**
+     * Performs the Move To Front transform and Run Length Encoding[2] stages.
+     */
+    void encode() {
+        final int bwtLength = this.bwtLength;
+        final boolean[] bwtValuesPresent = this.bwtValuesPresent;
+        final int[] bwtBlock = this.bwtBlock;
+        final char[] mtfBlock = this.mtfBlock;
+        final int[] mtfSymbolFrequencies = this.mtfSymbolFrequencies;
+        final byte[] huffmanSymbolMap = new byte[256];
+        final Bzip2MoveToFrontTable symbolMTF = new Bzip2MoveToFrontTable();
+
+        // Map only the byte values actually present in the block onto a compact symbol range
+        int totalUniqueValues = 0;
+        for (int i = 0; i < huffmanSymbolMap.length; i++) {
+            if (bwtValuesPresent[i]) {
+                huffmanSymbolMap[i] = (byte) totalUniqueValues++;
+            }
+        }
+        final int endOfBlockSymbol = totalUniqueValues + 1;
+
+        int mtfIndex = 0;
+        int repeatCount = 0;
+        int totalRunAs = 0;
+        int totalRunBs = 0;
+        for (int i = 0; i < bwtLength; i++) {
+            // Move To Front
+            final int mtfPosition = symbolMTF.valueToFront(huffmanSymbolMap[bwtBlock[i] & 0xff]);
+            // Run Length Encode
+            if (mtfPosition == 0) {
+                // Value was already at the front of the MTF list: extend the current zero run
+                repeatCount++;
+            } else {
+                if (repeatCount > 0) {
+                    // Flush the pending zero run as a sequence of RUNA/RUNB symbols
+                    // (bzip2's zero run length code)
+                    repeatCount--;
+                    while (true) {
+                        if ((repeatCount & 1) == 0) {
+                            mtfBlock[mtfIndex++] = HUFFMAN_SYMBOL_RUNA;
+                            totalRunAs++;
+                        } else {
+                            mtfBlock[mtfIndex++] = HUFFMAN_SYMBOL_RUNB;
+                            totalRunBs++;
+                        }
+
+                        if (repeatCount <= 1) {
+                            break;
+                        }
+                        repeatCount = (repeatCount - 2) >>> 1;
+                    }
+                    repeatCount = 0;
+                }
+                // Non-zero MTF positions are shifted up by one to make room for RUNA/RUNB
+                mtfBlock[mtfIndex++] = (char) (mtfPosition + 1);
+                mtfSymbolFrequencies[mtfPosition + 1]++;
+            }
+        }
+
+        // Flush any zero run left over at the end of the block (same code as above)
+        if (repeatCount > 0) {
+            repeatCount--;
+            while (true) {
+                if ((repeatCount & 1) == 0) {
+                    mtfBlock[mtfIndex++] = HUFFMAN_SYMBOL_RUNA;
+                    totalRunAs++;
+                } else {
+                    mtfBlock[mtfIndex++] = HUFFMAN_SYMBOL_RUNB;
+                    totalRunBs++;
+                }
+
+                if (repeatCount <= 1) {
+                    break;
+                }
+                repeatCount = (repeatCount - 2) >>> 1;
+            }
+        }
+
+        // Terminate the block and fold the run symbol counts into the global frequencies
+        mtfBlock[mtfIndex] = (char) endOfBlockSymbol;
+        mtfSymbolFrequencies[endOfBlockSymbol]++;
+        mtfSymbolFrequencies[HUFFMAN_SYMBOL_RUNA] += totalRunAs;
+        mtfSymbolFrequencies[HUFFMAN_SYMBOL_RUNB] += totalRunBs;
+
+        mtfLength = mtfIndex + 1;
+        alphabetSize = endOfBlockSymbol + 1;
+    }
+
+    /**
+     * @return
+     *         The encoded MTF block (the internal array, not a copy)
+     */
+    char[] mtfBlock() {
+        return mtfBlock;
+    }
+
+    /**
+     * @return The actual length of the MTF block
+     */
+    int mtfLength() {
+        return mtfLength;
+    }
+
+    /**
+     * @return The size of the MTF block's alphabet
+     */
+    int mtfAlphabetSize() {
+        return alphabetSize;
+    }
+
+    /**
+     * @return The frequencies of the MTF block's symbols
+     */
+    int[] mtfSymbolFrequencies() {
+        return mtfSymbolFrequencies;
+    }
+}
diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MoveToFrontTable.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MoveToFrontTable.java
new file mode 100644
index 0000000..f7abd57
--- /dev/null
+++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2MoveToFrontTable.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.bzip2;
+
+/**
+ * A 256 entry Move To Front transform.
+ */
+public final class Bzip2MoveToFrontTable {
+    /**
+     * The Move To Front list.
+ */ + private final byte[] mtf = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, + 16, 17, 18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, + 32, 33, 34, 35, 36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47, + 48, 49, 50, 51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63, + 64, 65, 66, 67, 68, 69, 70, 71, 72, 73, 74, 75, 76, 77, 78, 79, + 80, 81, 82, 83, 84, 85, 86, 87, 88, 89, 90, 91, 92, 93, 94, 95, + 96, 97, 98, 99, 100, 101, 102, 103, 104, 105, 106, 107, 108, 109, 110, 111, + 112, 113, 114, 115, 116, 117, 118, 119, 120, 121, 122, 123, 124, 125, 126, 127, + (byte) 128, (byte) 129, (byte) 130, (byte) 131, (byte) 132, (byte) 133, (byte) 134, (byte) 135, + (byte) 136, (byte) 137, (byte) 138, (byte) 139, (byte) 140, (byte) 141, (byte) 142, (byte) 143, + (byte) 144, (byte) 145, (byte) 146, (byte) 147, (byte) 148, (byte) 149, (byte) 150, (byte) 151, + (byte) 152, (byte) 153, (byte) 154, (byte) 155, (byte) 156, (byte) 157, (byte) 158, (byte) 159, + (byte) 160, (byte) 161, (byte) 162, (byte) 163, (byte) 164, (byte) 165, (byte) 166, (byte) 167, + (byte) 168, (byte) 169, (byte) 170, (byte) 171, (byte) 172, (byte) 173, (byte) 174, (byte) 175, + (byte) 176, (byte) 177, (byte) 178, (byte) 179, (byte) 180, (byte) 181, (byte) 182, (byte) 183, + (byte) 184, (byte) 185, (byte) 186, (byte) 187, (byte) 188, (byte) 189, (byte) 190, (byte) 191, + (byte) 192, (byte) 193, (byte) 194, (byte) 195, (byte) 196, (byte) 197, (byte) 198, (byte) 199, + (byte) 200, (byte) 201, (byte) 202, (byte) 203, (byte) 204, (byte) 205, (byte) 206, (byte) 207, + (byte) 208, (byte) 209, (byte) 210, (byte) 211, (byte) 212, (byte) 213, (byte) 214, (byte) 215, + (byte) 216, (byte) 217, (byte) 218, (byte) 219, (byte) 220, (byte) 221, (byte) 222, (byte) 223, + (byte) 224, (byte) 225, (byte) 226, (byte) 227, (byte) 228, (byte) 229, (byte) 230, (byte) 231, + (byte) 232, (byte) 233, (byte) 234, (byte) 235, (byte) 236, (byte) 237, (byte) 238, (byte) 239, + (byte) 240, (byte) 241, (byte) 242, (byte) 
243, (byte) 244, (byte) 245, (byte) 246, (byte) 247, + (byte) 248, (byte) 249, (byte) 250, (byte) 251, (byte) 252, (byte) 253, (byte) 254, (byte) 255 + }; + + /** + * Moves a value to the head of the MTF list (forward Move To Front transform). + * @param value The value to move + * @return The position the value moved from + */ + public int valueToFront(final byte value) { + int index = 0; + byte temp = mtf[0]; + if (value != temp) { + mtf[0] = value; + while (value != temp) { + index++; + final byte temp2 = temp; + temp = mtf[index]; + mtf[index] = temp2; + } + } + return index; + } + + /** + * Gets the value from a given index and moves it to the front of the MTF list (inverse Move To Front transform). + * @param index The index to move + * @return The value at the given index + */ + public byte indexToFront(final int index) { + final byte value = mtf[index]; + System.arraycopy(mtf, 0, mtf, 1, index); + mtf[0] = value; + + return value; + } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Rand.java b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Rand.java new file mode 100644 index 0000000..16c4d2f --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/Bzip2Rand.java @@ -0,0 +1,77 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +/** + * Random numbers for decompress Bzip2 blocks. 
+ */
+final class Bzip2Rand {
+    /**
+     * The Bzip2 specification originally included the optional addition of a slight pseudo-random
+     * perturbation to the input data, in order to work around the block sorting algorithm's non-
+     * optimal performance on some types of input. The current mainline bzip2 does not require this
+     * and will not create randomised blocks, but compatibility is still required for old data (and
+     * third party compressors that haven't caught up). When decompressing a randomised block, for
+     * each value N in this array, a 1 will be XOR'd onto the output of the Burrows-Wheeler
+     * transform stage after N bytes, then the next N taken from the following entry.
+     */
+    private static final int[] RNUMS = {
+        619, 720, 127, 481, 931, 816, 813, 233, 566, 247, 985, 724, 205, 454, 863, 491,
+        741, 242, 949, 214, 733, 859, 335, 708, 621, 574, 73, 654, 730, 472, 419, 436,
+        278, 496, 867, 210, 399, 680, 480, 51, 878, 465, 811, 169, 869, 675, 611, 697,
+        867, 561, 862, 687, 507, 283, 482, 129, 807, 591, 733, 623, 150, 238, 59, 379,
+        684, 877, 625, 169, 643, 105, 170, 607, 520, 932, 727, 476, 693, 425, 174, 647,
+        73, 122, 335, 530, 442, 853, 695, 249, 445, 515, 909, 545, 703, 919, 874, 474,
+        882, 500, 594, 612, 641, 801, 220, 162, 819, 984, 589, 513, 495, 799, 161, 604,
+        958, 533, 221, 400, 386, 867, 600, 782, 382, 596, 414, 171, 516, 375, 682, 485,
+        911, 276, 98, 553, 163, 354, 666, 933, 424, 341, 533, 870, 227, 730, 475, 186,
+        263, 647, 537, 686, 600, 224, 469, 68, 770, 919, 190, 373, 294, 822, 808, 206,
+        184, 943, 795, 384, 383, 461, 404, 758, 839, 887, 715, 67, 618, 276, 204, 918,
+        873, 777, 604, 560, 951, 160, 578, 722, 79, 804, 96, 409, 713, 940, 652, 934,
+        970, 447, 318, 353, 859, 672, 112, 785, 645, 863, 803, 350, 139, 93, 354, 99,
+        820, 908, 609, 772, 154, 274, 580, 184, 79, 626, 630, 742, 653, 282, 762, 623,
+        680, 81, 927, 626, 789, 125, 411, 521, 938, 300, 821, 78, 343, 175, 128, 250,
+        170, 774, 972, 275, 999, 639, 495, 78, 352, 126, 857, 956, 358, 619, 580, 124,
+        737, 594, 701, 612, 669, 112, 134, 694, 363, 992, 809, 743, 168, 974, 944, 375,
+        748, 52, 600, 747, 642, 182, 862, 81, 344, 805, 988, 739, 511, 655, 814, 334,
+        249, 515, 897, 955, 664, 981, 649, 113, 974, 459, 893, 228, 433, 837, 553, 268,
+        926, 240, 102, 654, 459, 51, 686, 754, 806, 760, 493, 403, 415, 394, 687, 700,
+        946, 670, 656, 610, 738, 392, 760, 799, 887, 653, 978, 321, 576, 617, 626, 502,
+        894, 679, 243, 440, 680, 879, 194, 572, 640, 724, 926, 56, 204, 700, 707, 151,
+        457, 449, 797, 195, 791, 558, 945, 679, 297, 59, 87, 824, 713, 663, 412, 693,
+        342, 606, 134, 108, 571, 364, 631, 212, 174, 643, 304, 329, 343, 97, 430, 751,
+        497, 314, 983, 374, 822, 928, 140, 206, 73, 263, 980, 736, 876, 478, 430, 305,
+        170, 514, 364, 692, 829, 82, 855, 953, 676, 246, 369, 970, 294, 750, 807, 827,
+        150, 790, 288, 923, 804, 378, 215, 828, 592, 281, 565, 555, 710, 82, 896, 831,
+        547, 261, 524, 462, 293, 465, 502, 56, 661, 821, 976, 991, 658, 869, 905, 758,
+        745, 193, 768, 550, 608, 933, 378, 286, 215, 979, 792, 961, 61, 688, 793, 644,
+        986, 403, 106, 366, 905, 644, 372, 567, 466, 434, 645, 210, 389, 550, 919, 135,
+        780, 773, 635, 389, 707, 100, 626, 958, 165, 504, 920, 176, 193, 713, 857, 265,
+        203, 50, 668, 108, 645, 990, 626, 197, 510, 357, 358, 850, 858, 364, 936, 638
+    };
+
+    /**
+     * Return the random number at a specific index.
+     *
+     * @param i the index; must be within {@code [0, 512)}, the length of the table
+     * @return the random number
+     */
+    static int rNums(int i) {
+        return RNUMS[i];
+    }
+
+    // Static utility holder - not instantiable.
+    private Bzip2Rand() { }
+}
diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/Crc32.java b/netty-bzip2/src/main/java/io/netty/bzip2/Crc32.java
new file mode 100644
index 0000000..920a56b
--- /dev/null
+++ b/netty-bzip2/src/main/java/io/netty/bzip2/Crc32.java
@@ -0,0 +1,123 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.bzip2;
+
+/**
+ * A CRC32 calculator. Note: this is the MSB-first CRC-32 variant used by bzip2
+ * (polynomial {@code 0x04c11db7}), not the reflected variant implemented by
+ * {@code java.util.zip.CRC32}.
+ */
+final class Crc32 {
+    /**
+     * A static CRC lookup table.
+ */ + private static final int[] crc32Table = { + 0x00000000, 0x04c11db7, 0x09823b6e, 0x0d4326d9, + 0x130476dc, 0x17c56b6b, 0x1a864db2, 0x1e475005, + 0x2608edb8, 0x22c9f00f, 0x2f8ad6d6, 0x2b4bcb61, + 0x350c9b64, 0x31cd86d3, 0x3c8ea00a, 0x384fbdbd, + 0x4c11db70, 0x48d0c6c7, 0x4593e01e, 0x4152fda9, + 0x5f15adac, 0x5bd4b01b, 0x569796c2, 0x52568b75, + 0x6a1936c8, 0x6ed82b7f, 0x639b0da6, 0x675a1011, + 0x791d4014, 0x7ddc5da3, 0x709f7b7a, 0x745e66cd, + 0x9823b6e0, 0x9ce2ab57, 0x91a18d8e, 0x95609039, + 0x8b27c03c, 0x8fe6dd8b, 0x82a5fb52, 0x8664e6e5, + 0xbe2b5b58, 0xbaea46ef, 0xb7a96036, 0xb3687d81, + 0xad2f2d84, 0xa9ee3033, 0xa4ad16ea, 0xa06c0b5d, + 0xd4326d90, 0xd0f37027, 0xddb056fe, 0xd9714b49, + 0xc7361b4c, 0xc3f706fb, 0xceb42022, 0xca753d95, + 0xf23a8028, 0xf6fb9d9f, 0xfbb8bb46, 0xff79a6f1, + 0xe13ef6f4, 0xe5ffeb43, 0xe8bccd9a, 0xec7dd02d, + 0x34867077, 0x30476dc0, 0x3d044b19, 0x39c556ae, + 0x278206ab, 0x23431b1c, 0x2e003dc5, 0x2ac12072, + 0x128e9dcf, 0x164f8078, 0x1b0ca6a1, 0x1fcdbb16, + 0x018aeb13, 0x054bf6a4, 0x0808d07d, 0x0cc9cdca, + 0x7897ab07, 0x7c56b6b0, 0x71159069, 0x75d48dde, + 0x6b93dddb, 0x6f52c06c, 0x6211e6b5, 0x66d0fb02, + 0x5e9f46bf, 0x5a5e5b08, 0x571d7dd1, 0x53dc6066, + 0x4d9b3063, 0x495a2dd4, 0x44190b0d, 0x40d816ba, + 0xaca5c697, 0xa864db20, 0xa527fdf9, 0xa1e6e04e, + 0xbfa1b04b, 0xbb60adfc, 0xb6238b25, 0xb2e29692, + 0x8aad2b2f, 0x8e6c3698, 0x832f1041, 0x87ee0df6, + 0x99a95df3, 0x9d684044, 0x902b669d, 0x94ea7b2a, + 0xe0b41de7, 0xe4750050, 0xe9362689, 0xedf73b3e, + 0xf3b06b3b, 0xf771768c, 0xfa325055, 0xfef34de2, + 0xc6bcf05f, 0xc27dede8, 0xcf3ecb31, 0xcbffd686, + 0xd5b88683, 0xd1799b34, 0xdc3abded, 0xd8fba05a, + 0x690ce0ee, 0x6dcdfd59, 0x608edb80, 0x644fc637, + 0x7a089632, 0x7ec98b85, 0x738aad5c, 0x774bb0eb, + 0x4f040d56, 0x4bc510e1, 0x46863638, 0x42472b8f, + 0x5c007b8a, 0x58c1663d, 0x558240e4, 0x51435d53, + 0x251d3b9e, 0x21dc2629, 0x2c9f00f0, 0x285e1d47, + 0x36194d42, 0x32d850f5, 0x3f9b762c, 0x3b5a6b9b, + 0x0315d626, 0x07d4cb91, 0x0a97ed48, 0x0e56f0ff, 
+ 0x1011a0fa, 0x14d0bd4d, 0x19939b94, 0x1d528623, + 0xf12f560e, 0xf5ee4bb9, 0xf8ad6d60, 0xfc6c70d7, + 0xe22b20d2, 0xe6ea3d65, 0xeba91bbc, 0xef68060b, + 0xd727bbb6, 0xd3e6a601, 0xdea580d8, 0xda649d6f, + 0xc423cd6a, 0xc0e2d0dd, 0xcda1f604, 0xc960ebb3, + 0xbd3e8d7e, 0xb9ff90c9, 0xb4bcb610, 0xb07daba7, + 0xae3afba2, 0xaafbe615, 0xa7b8c0cc, 0xa379dd7b, + 0x9b3660c6, 0x9ff77d71, 0x92b45ba8, 0x9675461f, + 0x8832161a, 0x8cf30bad, 0x81b02d74, 0x857130c3, + 0x5d8a9099, 0x594b8d2e, 0x5408abf7, 0x50c9b640, + 0x4e8ee645, 0x4a4ffbf2, 0x470cdd2b, 0x43cdc09c, + 0x7b827d21, 0x7f436096, 0x7200464f, 0x76c15bf8, + 0x68860bfd, 0x6c47164a, 0x61043093, 0x65c52d24, + 0x119b4be9, 0x155a565e, 0x18197087, 0x1cd86d30, + 0x029f3d35, 0x065e2082, 0x0b1d065b, 0x0fdc1bec, + 0x3793a651, 0x3352bbe6, 0x3e119d3f, 0x3ad08088, + 0x2497d08d, 0x2056cd3a, 0x2d15ebe3, 0x29d4f654, + 0xc5a92679, 0xc1683bce, 0xcc2b1d17, 0xc8ea00a0, + 0xd6ad50a5, 0xd26c4d12, 0xdf2f6bcb, 0xdbee767c, + 0xe3a1cbc1, 0xe760d676, 0xea23f0af, 0xeee2ed18, + 0xf0a5bd1d, 0xf464a0aa, 0xf9278673, 0xfde69bc4, + 0x89b8fd09, 0x8d79e0be, 0x803ac667, 0x84fbdbd0, + 0x9abc8bd5, 0x9e7d9662, 0x933eb0bb, 0x97ffad0c, + 0xafb010b1, 0xab710d06, 0xa6322bdf, 0xa2f33668, + 0xbcb4666d, 0xb8757bda, 0xb5365d03, 0xb1f740b4 + }; + + /** + * The current CRC. + */ + private int crc = 0xffffffff; + + /** + * @return The current CRC. + */ + public int getCRC() { + return ~crc; + } + + /** + * Update the CRC with a single byte. + * @param value The value to update the CRC with + */ + public void updateCRC(final int value) { + final int crc = this.crc; + this.crc = crc << 8 ^ crc32Table[(crc >> 24 ^ value) & 0xff]; + } + + /** + * Update the CRC with a sequence of identical bytes. 
+ * @param value The value to update the CRC with + * @param count The number of bytes + */ + public void updateCRC(final int value, int count) { + while (count-- > 0) { + updateCRC(value); + } + } +} diff --git a/netty-bzip2/src/main/java/io/netty/bzip2/DecompressionException.java b/netty-bzip2/src/main/java/io/netty/bzip2/DecompressionException.java new file mode 100644 index 0000000..c9e1d13 --- /dev/null +++ b/netty-bzip2/src/main/java/io/netty/bzip2/DecompressionException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bzip2; + +import io.netty.handler.codec.DecoderException; + +/** + * A {@link DecoderException} that is raised when decompression failed. + */ +public class DecompressionException extends DecoderException { + + private static final long serialVersionUID = 3546272712208105199L; + + /** + * Creates a new instance. + */ + public DecompressionException() { + } + + /** + * Creates a new instance. + */ + public DecompressionException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public DecompressionException(String message) { + super(message); + } + + /** + * Creates a new instance. 
+     */
+    public DecompressionException(Throwable cause) {
+        // Message is derived from the cause by DecoderException/Throwable
+        super(cause);
+    }
+}
diff --git a/netty-bzip2/src/main/java/module-info.java b/netty-bzip2/src/main/java/module-info.java
new file mode 100644
index 0000000..ff06a93
--- /dev/null
+++ b/netty-bzip2/src/main/java/module-info.java
@@ -0,0 +1,6 @@
+module org.xbib.io.netty.bziptwo {
+    exports io.netty.bzip2;
+    requires org.xbib.io.netty.buffer;
+    requires org.xbib.io.netty.handler.codec;
+    requires org.xbib.io.netty.util;
+}
diff --git a/netty-channel-unix/build.gradle b/netty-channel-unix/build.gradle
new file mode 100644
index 0000000..d4c655e
--- /dev/null
+++ b/netty-channel-unix/build.gradle
@@ -0,0 +1,5 @@
+dependencies {
+    api project(':netty-buffer')
+    api project(':netty-channel')
+    api project(':netty-util')
+}
diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/Buffer.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/Buffer.java
new file mode 100644
index 0000000..6a177a9
--- /dev/null
+++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/Buffer.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2018 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ +package io.netty.channel.unix; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +@UnstableApi +public final class Buffer { + + private Buffer() { } + + /** + * Free the direct {@link ByteBuffer}. + */ + public static void free(ByteBuffer buffer) { + PlatformDependent.freeDirectBuffer(buffer); + } + + /** + * Returns a new {@link ByteBuffer} which has the same {@link ByteOrder} as the native order of the machine. + */ + public static ByteBuffer allocateDirectWithNativeOrder(int capacity) { + return ByteBuffer.allocateDirect(capacity).order( + PlatformDependent.BIG_ENDIAN_NATIVE_ORDER ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN); + } + + /** + * Returns the memory address of the given direct {@link ByteBuffer}. + */ + public static long memoryAddress(ByteBuffer buffer) { + assert buffer.isDirect(); + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.directBufferAddress(buffer); + } + return memoryAddress0(buffer); + } + + /** + * Returns the size of a pointer. + */ + public static int addressSize() { + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.addressSize(); + } + return addressSize0(); + } + + // If Unsafe can not be used we will need to do JNI calls. + private static native int addressSize0(); + private static native long memoryAddress0(ByteBuffer buffer); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DatagramSocketAddress.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DatagramSocketAddress.java new file mode 100644 index 0000000..523bbce --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DatagramSocketAddress.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
+ * You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.channel.unix;
+
+import java.net.Inet6Address;
+import java.net.InetAddress;
+import java.net.InetSocketAddress;
+import java.net.UnknownHostException;
+
+/**
+ * Act as special {@link InetSocketAddress} to be able to easily pass all needed data from JNI without the need
+ * to create more objects than needed.
+ * <p>
+ * Internal usage only!
+ */
+public final class DatagramSocketAddress extends InetSocketAddress {
+    private static final long serialVersionUID = 3094819287843178401L;
+
+    // holds the amount of received bytes
+    private final int receivedAmount;
+    // local address supplied alongside the remote one by the native receive path (see constructor)
+    private final DatagramSocketAddress localAddress;
+
+    DatagramSocketAddress(byte[] addr, int scopeId, int port, int receivedAmount, DatagramSocketAddress local)
+            throws UnknownHostException {
+        super(newAddress(addr, scopeId), port);
+        this.receivedAmount = receivedAmount;
+        localAddress = local;
+    }
+
+    public DatagramSocketAddress localAddress() {
+        return localAddress;
+    }
+
+    public int receivedAmount() {
+        return receivedAmount;
+    }
+
+    // 4 bytes => IPv4 address; anything else is treated as IPv6 with the given scope id
+    private static InetAddress newAddress(byte[] bytes, int scopeId) throws UnknownHostException {
+        if (bytes.length == 4) {
+            return InetAddress.getByAddress(bytes);
+        }
+        return Inet6Address.getByAddress(null, bytes, scopeId);
+    }
+}
diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java
new file mode 100644
index 0000000..a26ef95
--- /dev/null
+++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannel.java
@@ -0,0 +1,39 @@
+/*
+ * Copyright 2021 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ +package io.netty.channel.unix; + +import io.netty.channel.Channel; + +/** + * A {@link UnixChannel} that supports communication via + * UNIX domain datagram sockets. + */ +public interface DomainDatagramChannel extends UnixChannel, Channel { + + @Override + DomainDatagramChannelConfig config(); + + /** + * Return {@code true} if the {@link DomainDatagramChannel} is connected to the remote peer. + */ + boolean isConnected(); + + @Override + DomainSocketAddress localAddress(); + + @Override + DomainSocketAddress remoteAddress(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java new file mode 100644 index 0000000..68b1a97 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramChannelConfig.java @@ -0,0 +1,80 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +/** + * A {@link ChannelConfig} for a {@link DomainDatagramChannel}. + * + *

Available options

+ * + * In addition to the options provided by {@link ChannelConfig}, + * {@link DomainDatagramChannelConfig} allows the following options in the option map: + * + * + * + * + * + * + * + *
NameAssociated setter method
{@link ChannelOption#SO_SNDBUF}{@link #setSendBufferSize(int)}
+ */ +public interface DomainDatagramChannelConfig extends ChannelConfig { + + @Override + DomainDatagramChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + DomainDatagramChannelConfig setAutoClose(boolean autoClose); + + @Override + DomainDatagramChannelConfig setAutoRead(boolean autoRead); + + @Override + DomainDatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + @Deprecated + DomainDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + DomainDatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + DomainDatagramChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + /** + * Sets the {@link java.net.StandardSocketOptions#SO_SNDBUF} option. + */ + DomainDatagramChannelConfig setSendBufferSize(int sendBufferSize); + + /** + * Gets the {@link java.net.StandardSocketOptions#SO_SNDBUF} option. + */ + int getSendBufferSize(); + + @Override + DomainDatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); + + @Override + DomainDatagramChannelConfig setWriteSpinCount(int writeSpinCount); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java new file mode 100644 index 0000000..39a1cd3 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramPacket.java @@ -0,0 +1,86 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.channel.DefaultAddressedEnvelope; + +/** + * The message container that is used for {@link DomainDatagramChannel} to communicate with the remote peer. + */ +public final class DomainDatagramPacket + extends DefaultAddressedEnvelope implements ByteBufHolder { + + /** + * Create a new instance with the specified packet {@code data} and {@code recipient} address. + */ + public DomainDatagramPacket(ByteBuf data, DomainSocketAddress recipient) { + super(data, recipient); + } + + /** + * Create a new instance with the specified packet {@code data}, {@code recipient} address, and {@code sender} + * address. 
+ */ + public DomainDatagramPacket(ByteBuf data, DomainSocketAddress recipient, DomainSocketAddress sender) { + super(data, recipient, sender); + } + + @Override + public DomainDatagramPacket copy() { + return replace(content().copy()); + } + + @Override + public DomainDatagramPacket duplicate() { + return replace(content().duplicate()); + } + + @Override + public DomainDatagramPacket replace(ByteBuf content) { + return new DomainDatagramPacket(content, recipient(), sender()); + } + + @Override + public DomainDatagramPacket retain() { + super.retain(); + return this; + } + + @Override + public DomainDatagramPacket retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public DomainDatagramPacket retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public DomainDatagramPacket touch() { + super.touch(); + return this; + } + + @Override + public DomainDatagramPacket touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java new file mode 100644 index 0000000..b67c670 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainDatagramSocketAddress.java @@ -0,0 +1,48 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.util.CharsetUtil; + +/** + * Act as special {@link DomainSocketAddress} to be able to easily pass all needed data from JNI without the need + * to create more objects then needed. + *

+ * Internal usage only! + */ +public final class DomainDatagramSocketAddress extends DomainSocketAddress { + + private static final long serialVersionUID = -5925732678737768223L; + + private final DomainDatagramSocketAddress localAddress; + // holds the amount of received bytes + private final int receivedAmount; + + public DomainDatagramSocketAddress(byte[] socketPath, int receivedAmount, + DomainDatagramSocketAddress localAddress) { + super(new String(socketPath, CharsetUtil.UTF_8)); + this.localAddress = localAddress; + this.receivedAmount = receivedAmount; + } + + public DomainDatagramSocketAddress localAddress() { + return localAddress; + } + + public int receivedAmount() { + return receivedAmount; + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketAddress.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketAddress.java new file mode 100644 index 0000000..c73dba4 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketAddress.java @@ -0,0 +1,67 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.util.internal.ObjectUtil; + +import java.io.File; +import java.net.SocketAddress; + +/** + * A address for a + * Unix Domain Socket. 
+ */ +public class DomainSocketAddress extends SocketAddress { + private static final long serialVersionUID = -6934618000832236893L; + private final String socketPath; + + public DomainSocketAddress(String socketPath) { + this.socketPath = ObjectUtil.checkNotNull(socketPath, "socketPath"); + } + + public DomainSocketAddress(File file) { + this(file.getPath()); + } + + /** + * The path to the domain socket. + */ + public String path() { + return socketPath; + } + + @Override + public String toString() { + return path(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof DomainSocketAddress)) { + return false; + } + + return ((DomainSocketAddress) o).socketPath.equals(socketPath); + } + + @Override + public int hashCode() { + return socketPath.hashCode(); + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannel.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannel.java new file mode 100644 index 0000000..97bebaa --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannel.java @@ -0,0 +1,33 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.channel.socket.DuplexChannel; + +/** + * A {@link UnixChannel} that supports communication via + * Unix Domain Socket. 
+ */ +public interface DomainSocketChannel extends UnixChannel, DuplexChannel { + @Override + DomainSocketAddress remoteAddress(); + + @Override + DomainSocketAddress localAddress(); + + @Override + DomainSocketChannelConfig config(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannelConfig.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannelConfig.java new file mode 100644 index 0000000..56696de --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketChannelConfig.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +/** + * Special {@link ChannelConfig} for {@link DomainSocketChannel}s. 
+ */ +public interface DomainSocketChannelConfig extends ChannelConfig { + + @Override + @Deprecated + DomainSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + DomainSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + DomainSocketChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + DomainSocketChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + DomainSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + DomainSocketChannelConfig setAutoRead(boolean autoRead); + + @Override + DomainSocketChannelConfig setAutoClose(boolean autoClose); + + @Override + @Deprecated + DomainSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark); + + @Override + @Deprecated + DomainSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark); + + @Override + DomainSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); + + @Override + DomainSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + /** + * Change the {@link DomainSocketReadMode} for the channel. The default is + * {@link DomainSocketReadMode#BYTES} which means bytes will be read from the + * {@link io.netty.channel.Channel} and passed through the pipeline. If + * {@link DomainSocketReadMode#FILE_DESCRIPTORS} is used + * {@link FileDescriptor}s will be passed through the {@link io.netty.channel.ChannelPipeline}. + * + * This setting can be modified on the fly if needed. + */ + DomainSocketChannelConfig setReadMode(DomainSocketReadMode mode); + + /** + * Return the {@link DomainSocketReadMode} for the channel. 
+ */ + DomainSocketReadMode getReadMode(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketReadMode.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketReadMode.java new file mode 100644 index 0000000..60d2680 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/DomainSocketReadMode.java @@ -0,0 +1,34 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; + +/** + * Different modes of reading from a {@link DomainSocketChannel}. + */ +public enum DomainSocketReadMode { + + /** + * Read {@link ByteBuf}s from the {@link DomainSocketChannel}. + */ + BYTES, + + /** + * Read {@link FileDescriptor}s from the {@link DomainSocketChannel}. + */ + FILE_DESCRIPTORS +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/Errors.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/Errors.java new file mode 100644 index 0000000..cbb0577 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/Errors.java @@ -0,0 +1,223 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
import io.netty.util.internal.EmptyArrays;

import java.io.FileNotFoundException;
import java.io.IOException;
import java.net.ConnectException;
import java.net.NoRouteToHostException;
import java.nio.channels.AlreadyConnectedException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ConnectionPendingException;
import java.nio.channels.NotYetConnectedException;

import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoEAGAIN;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoEBADF;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoECONNRESET;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoEINPROGRESS;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoENOENT;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoENOTCONN;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoEPIPE;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errnoEWOULDBLOCK;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errorEALREADY;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errorECONNREFUSED;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errorEHOSTUNREACH;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errorEISCONN;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.errorENETUNREACH;
import static io.netty.channel.unix.ErrorsStaticallyReferencedJniMethods.strError;

/**
 * Translates negative {@code errno} return values of native calls into Java exceptions.
 *
 * <p>Internal usage only!
 *
 * <p>Static members which call JNI methods must be defined in {@link ErrorsStaticallyReferencedJniMethods}.
 */
public final class Errors {
    // As all our JNI methods return -errno on error we need to compare with the negative errno codes.
    public static final int ERRNO_ENOENT_NEGATIVE = -errnoENOENT();
    public static final int ERRNO_ENOTCONN_NEGATIVE = -errnoENOTCONN();
    public static final int ERRNO_EBADF_NEGATIVE = -errnoEBADF();
    public static final int ERRNO_EPIPE_NEGATIVE = -errnoEPIPE();
    public static final int ERRNO_ECONNRESET_NEGATIVE = -errnoECONNRESET();
    public static final int ERRNO_EAGAIN_NEGATIVE = -errnoEAGAIN();
    public static final int ERRNO_EWOULDBLOCK_NEGATIVE = -errnoEWOULDBLOCK();
    public static final int ERRNO_EINPROGRESS_NEGATIVE = -errnoEINPROGRESS();
    public static final int ERROR_ECONNREFUSED_NEGATIVE = -errorECONNREFUSED();
    public static final int ERROR_EISCONN_NEGATIVE = -errorEISCONN();
    public static final int ERROR_EALREADY_NEGATIVE = -errorEALREADY();
    public static final int ERROR_ENETUNREACH_NEGATIVE = -errorENETUNREACH();
    public static final int ERROR_EHOSTUNREACH_NEGATIVE = -errorEHOSTUNREACH();

    /**
     * Holds the mappings for errno codes to String messages.
     * This eliminates the need to call back into JNI to get the right String message on an exception
     * and thus is faster.
     *
     * Choose an array length which should give us enough space in the future even when more errno codes
     * will be added.
     */
    private static final String[] ERRORS = new String[2048];

    static {
        for (int i = 0; i < ERRORS.length; i++) {
            // This is ok as strerror returns 'Unknown error i' when the message is not known.
            ERRORS[i] = strError(i);
        }
    }

    /**
     * Internal usage only!
     */
    public static final class NativeIoException extends IOException {
        private static final long serialVersionUID = 8222160204268655526L;
        private final int expectedErr;
        private final boolean fillInStackTrace;

        public NativeIoException(String method, int expectedErr) {
            this(method, expectedErr, true);
        }

        public NativeIoException(String method, int expectedErr, boolean fillInStackTrace) {
            super(method + "(..) failed: " + errnoString(-expectedErr));
            this.expectedErr = expectedErr;
            this.fillInStackTrace = fillInStackTrace;
        }

        public int expectedErr() {
            return expectedErr;
        }

        @Override
        public synchronized Throwable fillInStackTrace() {
            // Skipping the stack trace makes pre-instantiated exceptions cheap to throw repeatedly.
            if (fillInStackTrace) {
                return super.fillInStackTrace();
            }
            return this;
        }
    }

    static final class NativeConnectException extends ConnectException {
        private static final long serialVersionUID = -5532328671712318161L;
        private final int expectedErr;
        NativeConnectException(String method, int expectedErr) {
            super(method + "(..) failed: " + errnoString(-expectedErr));
            this.expectedErr = expectedErr;
        }

        int expectedErr() {
            return expectedErr;
        }
    }

    /**
     * Returns {@code false} when the connect is still in progress, throws when the errno signals a failure.
     */
    public static boolean handleConnectErrno(String method, int err) throws IOException {
        if (err == ERRNO_EINPROGRESS_NEGATIVE || err == ERROR_EALREADY_NEGATIVE) {
            // connect not complete yet need to wait for EPOLLOUT event.
            // EALREADY has been observed when using tcp fast open on centos8.
            return false;
        }
        throw newConnectException0(method, err);
    }

    /**
     * @deprecated Use {@link #handleConnectErrno(String, int)}.
     * @param method The native method name which caused the errno.
     * @param err the negative value of the errno.
     * @throws IOException The errno translated into an exception.
     */
    @Deprecated
    public static void throwConnectException(String method, int err) throws IOException {
        if (err == ERROR_EALREADY_NEGATIVE) {
            throw new ConnectionPendingException();
        }
        throw newConnectException0(method, err);
    }

    private static String errnoString(int err) {
        // Serve the message from the cache when possible; otherwise fall back to a JNI call.
        // FIX: the previous check 'err < ERRORS.length - 1' was off by one and never used the
        // last cached slot; it also did not guard against negative values, which would have
        // raised an ArrayIndexOutOfBoundsException instead of resolving the message.
        if (err >= 0 && err < ERRORS.length) {
            return ERRORS[err];
        }
        return strError(err);
    }

    private static IOException newConnectException0(String method, int err) {
        if (err == ERROR_ENETUNREACH_NEGATIVE || err == ERROR_EHOSTUNREACH_NEGATIVE) {
            return new NoRouteToHostException();
        }
        if (err == ERROR_EISCONN_NEGATIVE) {
            throw new AlreadyConnectedException();
        }
        if (err == ERRNO_ENOENT_NEGATIVE) {
            return new FileNotFoundException();
        }
        return new ConnectException(method + "(..) failed: " + errnoString(-err));
    }

    public static NativeIoException newConnectionResetException(String method, int errnoNegative) {
        NativeIoException exception = new NativeIoException(method, errnoNegative, false);
        exception.setStackTrace(EmptyArrays.EMPTY_STACK_TRACE);
        return exception;
    }

    public static NativeIoException newIOException(String method, int err) {
        return new NativeIoException(method, err);
    }

    @Deprecated
    public static int ioResult(String method, int err, NativeIoException resetCause,
                               ClosedChannelException closedCause) throws IOException {
        // network stack saturated... try again later
        if (err == ERRNO_EAGAIN_NEGATIVE || err == ERRNO_EWOULDBLOCK_NEGATIVE) {
            return 0;
        }
        if (err == resetCause.expectedErr()) {
            throw resetCause;
        }
        if (err == ERRNO_EBADF_NEGATIVE) {
            throw closedCause;
        }
        if (err == ERRNO_ENOTCONN_NEGATIVE) {
            throw new NotYetConnectedException();
        }
        if (err == ERRNO_ENOENT_NEGATIVE) {
            throw new FileNotFoundException();
        }

        // TODO: We could even go further and use a pre-instantiated IOException for the other error codes, but for
        // all other errors it may be better to just include a stack trace.
        throw newIOException(method, err);
    }

    public static int ioResult(String method, int err) throws IOException {
        // network stack saturated... try again later
        if (err == ERRNO_EAGAIN_NEGATIVE || err == ERRNO_EWOULDBLOCK_NEGATIVE) {
            return 0;
        }
        if (err == ERRNO_EBADF_NEGATIVE) {
            throw new ClosedChannelException();
        }
        if (err == ERRNO_ENOTCONN_NEGATIVE) {
            throw new NotYetConnectedException();
        }
        if (err == ERRNO_ENOENT_NEGATIVE) {
            throw new FileNotFoundException();
        }

        throw new NativeIoException(method, err, false);
    }

    private Errors() { }
}
package io.netty.channel.unix;

/**
 * This class is necessary to break the following cyclic dependency:
 * <ol>
 * <li>JNI_OnLoad</li>
 * <li>JNI calls FindClass because RegisterNatives (used to register JNI methods) requires a class</li>
 * <li>FindClass loads the class, but static member variables of that class attempt to call a JNI method which has
 * not yet been registered.</li>
 * <li>{@code java.lang.UnsatisfiedLinkError} is thrown because the native method has not yet been registered.</li>
 * </ol>
 * Static members which call JNI methods must not be declared in this class!
 */
final class ErrorsStaticallyReferencedJniMethods {

    private ErrorsStaticallyReferencedJniMethods() { }

    // Each errno*/error* method returns the platform's value of the corresponding errno constant.
    static native int errnoENOENT();
    static native int errnoEBADF();
    static native int errnoEPIPE();
    static native int errnoECONNRESET();
    static native int errnoENOTCONN();
    static native int errnoEAGAIN();
    static native int errnoEWOULDBLOCK();
    static native int errnoEINPROGRESS();
    static native int errorECONNREFUSED();
    static native int errorEISCONN();
    static native int errorEALREADY();
    static native int errorENETUNREACH();
    static native int errorEHOSTUNREACH();

    // Returns the platform's strerror(3) message for the given errno value.
    static native String strError(int err);
}
+ */ +package io.netty.channel.unix; + + +import java.io.File; +import java.io.IOException; +import java.nio.ByteBuffer; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +import static io.netty.channel.unix.Errors.ioResult; +import static io.netty.channel.unix.Errors.newIOException; +import static io.netty.channel.unix.Limits.IOV_MAX; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Math.min; + +/** + * Native {@link FileDescriptor} implementation which allows to wrap an {@code int} and provide a + * {@link FileDescriptor} for it. + */ +public class FileDescriptor { + + private static final AtomicIntegerFieldUpdater stateUpdater = + AtomicIntegerFieldUpdater.newUpdater(FileDescriptor.class, "state"); + + private static final int STATE_CLOSED_MASK = 1; + private static final int STATE_INPUT_SHUTDOWN_MASK = 1 << 1; + private static final int STATE_OUTPUT_SHUTDOWN_MASK = 1 << 2; + private static final int STATE_ALL_MASK = STATE_CLOSED_MASK | + STATE_INPUT_SHUTDOWN_MASK | + STATE_OUTPUT_SHUTDOWN_MASK; + + /** + * Bit map = [Output Shutdown | Input Shutdown | Closed] + */ + volatile int state; + final int fd; + + public FileDescriptor(int fd) { + checkPositiveOrZero(fd, "fd"); + this.fd = fd; + } + + /** + * Return the int value of the filedescriptor. + */ + public final int intValue() { + return fd; + } + + protected boolean markClosed() { + for (;;) { + int state = this.state; + if (isClosed(state)) { + return false; + } + // Once a close operation happens, the channel is considered shutdown. + if (casState(state, state | STATE_ALL_MASK)) { + return true; + } + } + } + + /** + * Close the file descriptor. + */ + public void close() throws IOException { + if (markClosed()) { + int res = close(fd); + if (res < 0) { + throw newIOException("close", res); + } + } + } + + /** + * Returns {@code true} if the file descriptor is open. 
+ */ + public boolean isOpen() { + return !isClosed(state); + } + + public final int write(ByteBuffer buf, int pos, int limit) throws IOException { + int res = write(fd, buf, pos, limit); + if (res >= 0) { + return res; + } + return ioResult("write", res); + } + + public final int writeAddress(long address, int pos, int limit) throws IOException { + int res = writeAddress(fd, address, pos, limit); + if (res >= 0) { + return res; + } + return ioResult("writeAddress", res); + } + + public final long writev(ByteBuffer[] buffers, int offset, int length, long maxBytesToWrite) throws IOException { + long res = writev(fd, buffers, offset, min(IOV_MAX, length), maxBytesToWrite); + if (res >= 0) { + return res; + } + return ioResult("writev", (int) res); + } + + public final long writevAddresses(long memoryAddress, int length) throws IOException { + long res = writevAddresses(fd, memoryAddress, length); + if (res >= 0) { + return res; + } + return ioResult("writevAddresses", (int) res); + } + + public final int read(ByteBuffer buf, int pos, int limit) throws IOException { + int res = read(fd, buf, pos, limit); + if (res > 0) { + return res; + } + if (res == 0) { + return -1; + } + return ioResult("read", res); + } + + public final int readAddress(long address, int pos, int limit) throws IOException { + int res = readAddress(fd, address, pos, limit); + if (res > 0) { + return res; + } + if (res == 0) { + return -1; + } + return ioResult("readAddress", res); + } + + @Override + public String toString() { + return "FileDescriptor{" + + "fd=" + fd + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof FileDescriptor)) { + return false; + } + + return fd == ((FileDescriptor) o).fd; + } + + @Override + public int hashCode() { + return fd; + } + + /** + * Open a new {@link FileDescriptor} for the given path. 
+ */ + public static FileDescriptor from(String path) throws IOException { + int res = open(checkNotNull(path, "path")); + if (res < 0) { + throw newIOException("open", res); + } + return new FileDescriptor(res); + } + + /** + * Open a new {@link FileDescriptor} for the given {@link File}. + */ + public static FileDescriptor from(File file) throws IOException { + return from(checkNotNull(file, "file").getPath()); + } + + /** + * @return [0] = read end, [1] = write end + */ + public static FileDescriptor[] pipe() throws IOException { + long res = newPipe(); + if (res < 0) { + throw newIOException("newPipe", (int) res); + } + return new FileDescriptor[]{new FileDescriptor((int) (res >>> 32)), new FileDescriptor((int) res)}; + } + + final boolean casState(int expected, int update) { + return stateUpdater.compareAndSet(this, expected, update); + } + + static boolean isClosed(int state) { + return (state & STATE_CLOSED_MASK) != 0; + } + + static boolean isInputShutdown(int state) { + return (state & STATE_INPUT_SHUTDOWN_MASK) != 0; + } + + static boolean isOutputShutdown(int state) { + return (state & STATE_OUTPUT_SHUTDOWN_MASK) != 0; + } + + static int inputShutdown(int state) { + return state | STATE_INPUT_SHUTDOWN_MASK; + } + + static int outputShutdown(int state) { + return state | STATE_OUTPUT_SHUTDOWN_MASK; + } + + private static native int open(String path); + private static native int close(int fd); + + private static native int write(int fd, ByteBuffer buf, int pos, int limit); + private static native int writeAddress(int fd, long address, int pos, int limit); + private static native long writev(int fd, ByteBuffer[] buffers, int offset, int length, long maxBytesToWrite); + private static native long writevAddresses(int fd, long memoryAddress, int length); + + private static native int read(int fd, ByteBuffer buf, int pos, int limit); + private static native int readAddress(int fd, long address, int pos, int limit); + + private static native long newPipe(); +} 
diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/GenericUnixChannelOption.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/GenericUnixChannelOption.java
new file mode 100644
index 0000000..6c3f6f9
--- /dev/null
+++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/GenericUnixChannelOption.java
@@ -0,0 +1,51 @@
+/*
+ * Copyright 2022 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.channel.unix;
+
+/**
+ * A generic socket option. See man setsockopt.
+ *
+ * @param <T> the value type
+ */
+public abstract class GenericUnixChannelOption<T> extends UnixChannelOption<T> {
+
+    private final int level;
+    private final int optname;
+
+    GenericUnixChannelOption(String name, int level, int optname) {
+        super(name);
+        this.level = level;
+        this.optname = optname;
+    }
+
+    /**
+     * Returns the level. See man setsockopt
+     *
+     * @return the level.
+     */
+    public int level() {
+        return level;
+    }
+
+    /**
+     * Returns the optname. See man setsockopt
+     *
+     * @return the optname.
+     */
+    public int optname() {
+        return optname;
+    }
+}
diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/IntegerUnixChannelOption.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/IntegerUnixChannelOption.java
new file mode 100644
index 0000000..bfc6183
--- /dev/null
+++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/IntegerUnixChannelOption.java
@@ -0,0 +1,32 @@
+/*
+ * Copyright 2022 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.channel.unix;
+
+/**
+ * A {@link GenericUnixChannelOption} which uses an {@link Integer} as {@code optval}.
+ */
+public final class IntegerUnixChannelOption extends GenericUnixChannelOption<Integer> {
+    /**
+     * Creates a new instance.
+     *
+     * @param name the name that is used.
+     * @param level the level.
+     * @param optname the optname.
+ */ + public IntegerUnixChannelOption(String name, int level, int optname) { + super(name, level, optname); + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/IovArray.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/IovArray.java new file mode 100644 index 0000000..6071ccc --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/IovArray.java @@ -0,0 +1,237 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelOutboundBuffer.MessageProcessor; +import io.netty.util.internal.PlatformDependent; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; + +import static io.netty.channel.unix.Limits.IOV_MAX; +import static io.netty.channel.unix.Limits.SSIZE_MAX; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static java.lang.Math.min; + +/** + * Represent an array of struct array and so can be passed directly over via JNI without the need to do any more + * array copies. + * + * The buffers are written out directly into direct memory to match the struct iov. See also {@code man writev}. + * + *
+ * struct iovec {
+ *   void  *iov_base;
+ *   size_t iov_len;
+ * };
+ * 
+ * + * See also + * Efficient JNI programming IV: Wrapping native data objects. + */ +public final class IovArray implements MessageProcessor { + + /** The size of an address which should be 8 for 64 bits and 4 for 32 bits. */ + private static final int ADDRESS_SIZE = Buffer.addressSize(); + + /** + * The size of an {@code iovec} struct in bytes. This is calculated as we have 2 entries each of the size of the + * address. + */ + public static final int IOV_SIZE = 2 * ADDRESS_SIZE; + + /** + * The needed memory to hold up to {@code IOV_MAX} iov entries, where {@code IOV_MAX} signified + * the maximum number of {@code iovec} structs that can be passed to {@code writev(...)}. + */ + private static final int MAX_CAPACITY = IOV_MAX * IOV_SIZE; + + private final long memoryAddress; + private final ByteBuf memory; + private int count; + private long size; + private long maxBytes = SSIZE_MAX; + + public IovArray() { + this(Unpooled.wrappedBuffer(Buffer.allocateDirectWithNativeOrder(MAX_CAPACITY)).setIndex(0, 0)); + } + + @SuppressWarnings("deprecation") + public IovArray(ByteBuf memory) { + assert memory.writerIndex() == 0; + assert memory.readerIndex() == 0; + this.memory = PlatformDependent.hasUnsafe() ? memory : memory.order( + PlatformDependent.BIG_ENDIAN_NATIVE_ORDER ? ByteOrder.BIG_ENDIAN : ByteOrder.LITTLE_ENDIAN); + if (memory.hasMemoryAddress()) { + memoryAddress = memory.memoryAddress(); + } else { + // Fallback to using JNI as we were not be able to access the address otherwise. + memoryAddress = Buffer.memoryAddress(memory.internalNioBuffer(0, memory.capacity())); + } + } + + public void clear() { + count = 0; + size = 0; + } + + /** + * @deprecated Use {@link #add(ByteBuf, int, int)} + */ + @Deprecated + public boolean add(ByteBuf buf) { + return add(buf, buf.readerIndex(), buf.readableBytes()); + } + + public boolean add(ByteBuf buf, int offset, int len) { + if (count == IOV_MAX) { + // No more room! 
+ return false; + } + if (buf.nioBufferCount() == 1) { + if (len == 0) { + return true; + } + if (buf.hasMemoryAddress()) { + return add(memoryAddress, buf.memoryAddress() + offset, len); + } else { + ByteBuffer nioBuffer = buf.internalNioBuffer(offset, len); + return add(memoryAddress, Buffer.memoryAddress(nioBuffer) + nioBuffer.position(), len); + } + } else { + ByteBuffer[] buffers = buf.nioBuffers(offset, len); + for (ByteBuffer nioBuffer : buffers) { + final int remaining = nioBuffer.remaining(); + if (remaining != 0 && + (!add(memoryAddress, Buffer.memoryAddress(nioBuffer) + nioBuffer.position(), remaining) + || count == IOV_MAX)) { + return false; + } + } + return true; + } + } + + private boolean add(long memoryAddress, long addr, int len) { + assert addr != 0; + + // If there is at least 1 entry then we enforce the maximum bytes. We want to accept at least one entry so we + // will attempt to write some data and make progress. + if ((maxBytes - len < size && count > 0) || + // Check if we have enough space left + memory.capacity() < (count + 1) * IOV_SIZE) { + // If the size + len will overflow SSIZE_MAX we stop populate the IovArray. This is done as linux + // not allow to write more bytes then SSIZE_MAX with one writev(...) call and so will + // return 'EINVAL', which will raise an IOException. 
+ // + // See also: + // - https://linux.die.net//man/2/writev + return false; + } + final int baseOffset = idx(count); + final int lengthOffset = baseOffset + ADDRESS_SIZE; + + size += len; + ++count; + + if (ADDRESS_SIZE == 8) { + // 64bit + if (PlatformDependent.hasUnsafe()) { + PlatformDependent.putLong(baseOffset + memoryAddress, addr); + PlatformDependent.putLong(lengthOffset + memoryAddress, len); + } else { + memory.setLong(baseOffset, addr); + memory.setLong(lengthOffset, len); + } + } else { + assert ADDRESS_SIZE == 4; + if (PlatformDependent.hasUnsafe()) { + PlatformDependent.putInt(baseOffset + memoryAddress, (int) addr); + PlatformDependent.putInt(lengthOffset + memoryAddress, len); + } else { + memory.setInt(baseOffset, (int) addr); + memory.setInt(lengthOffset, len); + } + } + return true; + } + + /** + * Returns the number if iov entries. + */ + public int count() { + return count; + } + + /** + * Returns the size in bytes + */ + public long size() { + return size; + } + + /** + * Set the maximum amount of bytes that can be added to this {@link IovArray} via {@link #add(ByteBuf, int, int)} + *

+ * This will not impact the existing state of the {@link IovArray}, and only applies to subsequent calls to + * {@link #add(ByteBuf)}. + *

+ * In order to ensure some progress is made at least one {@link ByteBuf} will be accepted even if it's size exceeds + * this value. + * @param maxBytes the maximum amount of bytes that can be added to this {@link IovArray}. + */ + public void maxBytes(long maxBytes) { + this.maxBytes = min(SSIZE_MAX, checkPositive(maxBytes, "maxBytes")); + } + + /** + * Get the maximum amount of bytes that can be added to this {@link IovArray}. + * @return the maximum amount of bytes that can be added to this {@link IovArray}. + */ + public long maxBytes() { + return maxBytes; + } + + /** + * Returns the {@code memoryAddress} for the given {@code offset}. + */ + public long memoryAddress(int offset) { + return memoryAddress + idx(offset); + } + + /** + * Release the {@link IovArray}. Once release further using of it may crash the JVM! + */ + public void release() { + memory.release(); + } + + @Override + public boolean processMessage(Object msg) throws Exception { + if (msg instanceof ByteBuf) { + ByteBuf buffer = (ByteBuf) msg; + return add(buffer, buffer.readerIndex(), buffer.readableBytes()); + } + return false; + } + + private static int idx(int index) { + return IOV_SIZE * index; + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/Limits.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/Limits.java new file mode 100644 index 0000000..606c829 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/Limits.java @@ -0,0 +1,31 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import static io.netty.channel.unix.LimitsStaticallyReferencedJniMethods.iovMax; +import static io.netty.channel.unix.LimitsStaticallyReferencedJniMethods.sizeOfjlong; +import static io.netty.channel.unix.LimitsStaticallyReferencedJniMethods.ssizeMax; +import static io.netty.channel.unix.LimitsStaticallyReferencedJniMethods.uioMaxIov; + +public final class Limits { + public static final int IOV_MAX = iovMax(); + public static final int UIO_MAX_IOV = uioMaxIov(); + public static final long SSIZE_MAX = ssizeMax(); + + public static final int SIZEOF_JLONG = sizeOfjlong(); + + private Limits() { } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/LimitsStaticallyReferencedJniMethods.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/LimitsStaticallyReferencedJniMethods.java new file mode 100644 index 0000000..5480eea --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/LimitsStaticallyReferencedJniMethods.java @@ -0,0 +1,37 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +/** + * This class is necessary to break the following cyclic dependency: + *

+ * <ol>
+ * <li>JNI_OnLoad</li>
+ * <li>JNI Calls FindClass because RegisterNatives (used to register JNI methods) requires a class</li>
+ * <li>FindClass loads the class, but static members variables of that class attempt to call a JNI method which has not
+ * yet been registered.</li>
+ * <li>java.lang.UnsatisfiedLinkError is thrown because native method has not yet been registered.</li>
+ * </ol>
+ * Static members which call JNI methods must not be declared in this class! + */ +final class LimitsStaticallyReferencedJniMethods { + private LimitsStaticallyReferencedJniMethods() { } + + static native long ssizeMax(); + static native int iovMax(); + static native int uioMaxIov(); + static native int sizeOfjlong(); + static native int udsSunPathSize(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/NativeInetAddress.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/NativeInetAddress.java new file mode 100644 index 0000000..d4c6904 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/NativeInetAddress.java @@ -0,0 +1,111 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +/** + * Internal usage only! 
+ */ +public final class NativeInetAddress { + private static final byte[] IPV4_MAPPED_IPV6_PREFIX = { + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, (byte) 0xff, (byte) 0xff }; + final byte[] address; + final int scopeId; + + public static NativeInetAddress newInstance(InetAddress addr) { + byte[] bytes = addr.getAddress(); + if (addr instanceof Inet6Address) { + return new NativeInetAddress(bytes, ((Inet6Address) addr).getScopeId()); + } else { + // convert to ipv4 mapped ipv6 address; + return new NativeInetAddress(ipv4MappedIpv6Address(bytes)); + } + } + + public NativeInetAddress(byte[] address, int scopeId) { + this.address = address; + this.scopeId = scopeId; + } + + public NativeInetAddress(byte[] address) { + this(address, 0); + } + + public byte[] address() { + return address; + } + + public int scopeId() { + return scopeId; + } + + public static byte[] ipv4MappedIpv6Address(byte[] ipv4) { + byte[] address = new byte[16]; + copyIpv4MappedIpv6Address(ipv4, address); + return address; + } + + public static void copyIpv4MappedIpv6Address(byte[] ipv4, byte[] ipv6) { + System.arraycopy(IPV4_MAPPED_IPV6_PREFIX, 0, ipv6, 0, IPV4_MAPPED_IPV6_PREFIX.length); + System.arraycopy(ipv4, 0, ipv6, 12, ipv4.length); + } + + public static InetSocketAddress address(byte[] addr, int offset, int len) { + // The last 4 bytes are always the port + final int port = decodeInt(addr, offset + len - 4); + final InetAddress address; + + try { + switch (len) { + // 8 bytes: + // - 4 == ipaddress + // - 4 == port + case 8: + byte[] ipv4 = new byte[4]; + System.arraycopy(addr, offset, ipv4, 0, 4); + address = InetAddress.getByAddress(ipv4); + break; + + // 24 bytes: + // - 16 == ipaddress + // - 4 == scopeId + // - 4 == port + case 24: + byte[] ipv6 = new byte[16]; + System.arraycopy(addr, offset, ipv6, 0, 16); + int scopeId = decodeInt(addr, offset + len - 8); + address = Inet6Address.getByAddress(null, ipv6, scopeId); + break; + default: + throw new Error(); + } + 
return new InetSocketAddress(address, port); + } catch (UnknownHostException e) { + throw new Error("Should never happen", e); + } + } + + static int decodeInt(byte[] addr, int index) { + return (addr[index] & 0xff) << 24 | + (addr[index + 1] & 0xff) << 16 | + (addr[index + 2] & 0xff) << 8 | + addr[index + 3] & 0xff; + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/PeerCredentials.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/PeerCredentials.java new file mode 100644 index 0000000..f05e92a --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/PeerCredentials.java @@ -0,0 +1,74 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel.unix; + +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.internal.EmptyArrays.EMPTY_INTS; + +/** + * User credentials discovered for the peer unix domain socket. + * + * The PID, UID and GID of the user connected on the other side of the unix domain socket + * For details see: + * SO_PEERCRED + */ +@UnstableApi +public final class PeerCredentials { + private final int pid; + private final int uid; + private final int[] gids; + + // These values are set by JNI via Socket.peerCredentials() + PeerCredentials(int p, int u, int... gids) { + pid = p; + uid = u; + this.gids = gids == null ? 
EMPTY_INTS : gids; + } + + /** + * Get the PID of the peer process. + *

+ * This is currently not populated on MacOS and BSD based systems. + * @return The PID of the peer process. + */ + public int pid() { + return pid; + } + + public int uid() { + return uid; + } + + public int[] gids() { + return gids.clone(); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(128); + sb.append("UserCredentials[pid=").append(pid).append("; uid=").append(uid).append("; gids=["); + if (gids.length > 0) { + sb.append(gids[0]); + for (int i = 1; i < gids.length; ++i) { + sb.append(", ").append(gids[i]); + } + } + sb.append(']'); + return sb.toString(); + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/PreferredDirectByteBufAllocator.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/PreferredDirectByteBufAllocator.java new file mode 100644 index 0000000..41b2997 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/PreferredDirectByteBufAllocator.java @@ -0,0 +1,130 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.internal.UnstableApi; + +@UnstableApi +public final class PreferredDirectByteBufAllocator implements ByteBufAllocator { + private ByteBufAllocator allocator; + + public void updateAllocator(ByteBufAllocator allocator) { + this.allocator = allocator; + } + + @Override + public ByteBuf buffer() { + return allocator.directBuffer(); + } + + @Override + public ByteBuf buffer(int initialCapacity) { + return allocator.directBuffer(initialCapacity); + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + return allocator.directBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf ioBuffer() { + return allocator.directBuffer(); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return allocator.directBuffer(initialCapacity); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return allocator.directBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf heapBuffer() { + return allocator.heapBuffer(); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return allocator.heapBuffer(initialCapacity); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return allocator.heapBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf directBuffer() { + return allocator.directBuffer(); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return allocator.directBuffer(initialCapacity); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return allocator.directBuffer(initialCapacity, maxCapacity); + } + + @Override + public CompositeByteBuf compositeBuffer() { + return allocator.compositeDirectBuffer(); + } + + @Override + public CompositeByteBuf compositeBuffer(int 
maxNumComponents) { + return allocator.compositeDirectBuffer(maxNumComponents); + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return allocator.compositeHeapBuffer(); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return allocator.compositeHeapBuffer(maxNumComponents); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return allocator.compositeDirectBuffer(); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return allocator.compositeDirectBuffer(maxNumComponents); + } + + @Override + public boolean isDirectBufferPooled() { + return allocator.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return allocator.calculateNewCapacity(minNewCapacity, maxCapacity); + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/RawUnixChannelOption.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/RawUnixChannelOption.java new file mode 100644 index 0000000..1ba237d --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/RawUnixChannelOption.java @@ -0,0 +1,61 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */
+package io.netty.channel.unix;
+
+
+import io.netty.util.internal.ObjectUtil;
+
+import java.nio.ByteBuffer;
+
+/**
+ * A {@link GenericUnixChannelOption} which uses an {@link ByteBuffer} as {@code optval}. The user is responsible
+ * to fill the {@link ByteBuffer} in a correct manner, so it works with the {@code level} and {@code optname}.
+ */
+public final class RawUnixChannelOption extends GenericUnixChannelOption<ByteBuffer> {
+
+    private final int length;
+
+    /**
+     * Creates a new instance.
+     *
+     * @param name the name that is used.
+     * @param level the level.
+     * @param length the expected length of the optvalue.
+     * @param optname the optname.
+     */
+    public RawUnixChannelOption(String name, int level, int optname, int length) {
+        super(name, level, optname);
+        this.length = ObjectUtil.checkPositive(length, "length");
+    }
+
+    /**
+     * The length of the optval.
+     *
+     * @return the length.
+     */
+    public int length() {
+        return length;
+    }
+
+    @Override
+    public void validate(ByteBuffer value) {
+        super.validate(value);
+        if (value.remaining() != length) {
+            throw new IllegalArgumentException("Length of value does not match. Expected " +
+                    length + ", but got " + value.remaining());
+        }
+    }
+}
diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/SegmentedDatagramPacket.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/SegmentedDatagramPacket.java
new file mode 100644
index 0000000..01f21ce
--- /dev/null
+++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/SegmentedDatagramPacket.java
@@ -0,0 +1,109 @@
+/*
+ * Copyright 2021 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.socket.DatagramPacket; +import io.netty.util.internal.ObjectUtil; + +import java.net.InetSocketAddress; + +/** + * Allows to use GSO + * if the underlying OS supports it. Before using this you should ensure your system support it. + */ +public class SegmentedDatagramPacket extends DatagramPacket { + + private final int segmentSize; + + /** + * Create a new instance. + * + * @param data the {@link ByteBuf} which must be continguous. + * @param segmentSize the segment size. + * @param recipient the recipient. + */ + public SegmentedDatagramPacket(ByteBuf data, int segmentSize, InetSocketAddress recipient) { + super(data, recipient); + this.segmentSize = ObjectUtil.checkPositive(segmentSize, "segmentSize"); + } + + /** + * Create a new instance. + * + * @param data the {@link ByteBuf} which must be continguous. + * @param segmentSize the segment size. + * @param recipient the recipient. + */ + public SegmentedDatagramPacket(ByteBuf data, int segmentSize, + InetSocketAddress recipient, InetSocketAddress sender) { + super(data, recipient, sender); + this.segmentSize = ObjectUtil.checkPositive(segmentSize, "segmentSize"); + } + + /** + * Return the size of each segment (the last segment can be smaller). + * + * @return size of segments. 
+ */ + public int segmentSize() { + return segmentSize; + } + + @Override + public SegmentedDatagramPacket copy() { + return new SegmentedDatagramPacket(content().copy(), segmentSize, recipient(), sender()); + } + + @Override + public SegmentedDatagramPacket duplicate() { + return new SegmentedDatagramPacket(content().duplicate(), segmentSize, recipient(), sender()); + } + + @Override + public SegmentedDatagramPacket retainedDuplicate() { + return new SegmentedDatagramPacket(content().retainedDuplicate(), segmentSize, recipient(), sender()); + } + + @Override + public SegmentedDatagramPacket replace(ByteBuf content) { + return new SegmentedDatagramPacket(content, segmentSize, recipient(), sender()); + } + + @Override + public SegmentedDatagramPacket retain() { + super.retain(); + return this; + } + + @Override + public SegmentedDatagramPacket retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public SegmentedDatagramPacket touch() { + super.touch(); + return this; + } + + @Override + public SegmentedDatagramPacket touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/ServerDomainSocketChannel.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/ServerDomainSocketChannel.java new file mode 100644 index 0000000..78b5c48 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/ServerDomainSocketChannel.java @@ -0,0 +1,30 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.channel.ServerChannel; + +/** + * {@link ServerChannel} that accepts {@link DomainSocketChannel}'s via + * Unix Domain Socket. + */ +public interface ServerDomainSocketChannel extends ServerChannel, UnixChannel { + @Override + DomainSocketAddress remoteAddress(); + + @Override + DomainSocketAddress localAddress(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/Socket.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/Socket.java new file mode 100644 index 0000000..7e20967 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/Socket.java @@ -0,0 +1,720 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.unix; + +import io.netty.channel.ChannelException; +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; + +import java.io.IOException; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.PortUnreachableException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; + +import static io.netty.channel.unix.Errors.ERRNO_EAGAIN_NEGATIVE; +import static io.netty.channel.unix.Errors.ERRNO_EINPROGRESS_NEGATIVE; +import static io.netty.channel.unix.Errors.ERRNO_EWOULDBLOCK_NEGATIVE; +import static io.netty.channel.unix.Errors.ERROR_ECONNREFUSED_NEGATIVE; +import static io.netty.channel.unix.Errors.handleConnectErrno; +import static io.netty.channel.unix.Errors.ioResult; +import static io.netty.channel.unix.Errors.newIOException; +import static io.netty.channel.unix.NativeInetAddress.address; +import static io.netty.channel.unix.NativeInetAddress.ipv4MappedIpv6Address; + +/** + * Provides a JNI bridge to native socket operations. + * Internal usage only! + */ +public class Socket extends FileDescriptor { + + private static volatile boolean isIpv6Preferred; + + @Deprecated + public static final int UDS_SUN_PATH_SIZE = 100; + + protected final boolean ipv6; + + public Socket(int fd) { + super(fd); + ipv6 = isIPv6(fd); + } + /** + * Returns {@code true} if we should use IPv6 internally, {@code false} otherwise. + */ + private boolean useIpv6(InetAddress address) { + return useIpv6(this, address); + } + + /** + * Returns {@code true} if the given socket and address combination should use IPv6 internally, + * {@code false} otherwise. 
+ */ + protected static boolean useIpv6(Socket socket, InetAddress address) { + return socket.ipv6 || address instanceof Inet6Address; + } + + public final void shutdown() throws IOException { + shutdown(true, true); + } + + public final void shutdown(boolean read, boolean write) throws IOException { + for (;;) { + // We need to only shutdown what has not been shutdown yet, and if there is no change we should not + // shutdown anything. This is because if the underlying FD is reused and we still have an object which + // represents the previous incarnation of the FD we need to be sure we don't inadvertently shutdown the + // "new" FD without explicitly having a change. + final int oldState = state; + if (isClosed(oldState)) { + throw new ClosedChannelException(); + } + int newState = oldState; + if (read && !isInputShutdown(newState)) { + newState = inputShutdown(newState); + } + if (write && !isOutputShutdown(newState)) { + newState = outputShutdown(newState); + } + + // If there is no change in state, then we should not take any action. 
+ if (newState == oldState) { + return; + } + if (casState(oldState, newState)) { + break; + } + } + int res = shutdown(fd, read, write); + if (res < 0) { + ioResult("shutdown", res); + } + } + + public final boolean isShutdown() { + int state = this.state; + return isInputShutdown(state) && isOutputShutdown(state); + } + + public final boolean isInputShutdown() { + return isInputShutdown(state); + } + + public final boolean isOutputShutdown() { + return isOutputShutdown(state); + } + + public final int sendTo(ByteBuffer buf, int pos, int limit, InetAddress addr, int port) throws IOException { + return sendTo(buf, pos, limit, addr, port, false); + } + + public final int sendTo(ByteBuffer buf, int pos, int limit, InetAddress addr, int port, boolean fastOpen) + throws IOException { + // just duplicate the toNativeInetAddress code here to minimize object creation as this method is expected + // to be called frequently + byte[] address; + int scopeId; + if (addr instanceof Inet6Address) { + address = addr.getAddress(); + scopeId = ((Inet6Address) addr).getScopeId(); + } else { + // convert to ipv4 mapped ipv6 address; + scopeId = 0; + address = ipv4MappedIpv6Address(addr.getAddress()); + } + int flags = fastOpen ? msgFastopen() : 0; + int res = sendTo(fd, useIpv6(addr), buf, pos, limit, address, scopeId, port, flags); + if (res >= 0) { + return res; + } + if (res == ERRNO_EINPROGRESS_NEGATIVE && fastOpen) { + // This happens when we (as a client) have no pre-existing cookie for doing a fast-open connection. + // In this case, our TCP connection will be established normally, but no data was transmitted at this time. + // We'll just transmit the data with normal writes later. 
+ return 0; + } + if (res == ERROR_ECONNREFUSED_NEGATIVE) { + throw new PortUnreachableException("sendTo failed"); + } + return ioResult("sendTo", res); + } + + public final int sendToDomainSocket(ByteBuffer buf, int pos, int limit, byte[] path) throws IOException { + int res = sendToDomainSocket(fd, buf, pos, limit, path); + if (res >= 0) { + return res; + } + return ioResult("sendToDomainSocket", res); + } + + public final int sendToAddress(long memoryAddress, int pos, int limit, InetAddress addr, int port) + throws IOException { + return sendToAddress(memoryAddress, pos, limit, addr, port, false); + } + + public final int sendToAddress(long memoryAddress, int pos, int limit, InetAddress addr, int port, + boolean fastOpen) throws IOException { + // just duplicate the toNativeInetAddress code here to minimize object creation as this method is expected + // to be called frequently + byte[] address; + int scopeId; + if (addr instanceof Inet6Address) { + address = addr.getAddress(); + scopeId = ((Inet6Address) addr).getScopeId(); + } else { + // convert to ipv4 mapped ipv6 address; + scopeId = 0; + address = ipv4MappedIpv6Address(addr.getAddress()); + } + int flags = fastOpen ? msgFastopen() : 0; + int res = sendToAddress(fd, useIpv6(addr), memoryAddress, pos, limit, address, scopeId, port, flags); + if (res >= 0) { + return res; + } + if (res == ERRNO_EINPROGRESS_NEGATIVE && fastOpen) { + // This happens when we (as a client) have no pre-existing cookie for doing a fast-open connection. + // In this case, our TCP connection will be established normally, but no data was transmitted at this time. + // We'll just transmit the data with normal writes later. 
+ return 0; + } + if (res == ERROR_ECONNREFUSED_NEGATIVE) { + throw new PortUnreachableException("sendToAddress failed"); + } + return ioResult("sendToAddress", res); + } + + public final int sendToAddressDomainSocket(long memoryAddress, int pos, int limit, byte[] path) throws IOException { + int res = sendToAddressDomainSocket(fd, memoryAddress, pos, limit, path); + if (res >= 0) { + return res; + } + return ioResult("sendToAddressDomainSocket", res); + } + + public final int sendToAddresses(long memoryAddress, int length, InetAddress addr, int port) throws IOException { + return sendToAddresses(memoryAddress, length, addr, port, false); + } + + public final int sendToAddresses(long memoryAddress, int length, InetAddress addr, int port, boolean fastOpen) + throws IOException { + // just duplicate the toNativeInetAddress code here to minimize object creation as this method is expected + // to be called frequently + byte[] address; + int scopeId; + if (addr instanceof Inet6Address) { + address = addr.getAddress(); + scopeId = ((Inet6Address) addr).getScopeId(); + } else { + // convert to ipv4 mapped ipv6 address; + scopeId = 0; + address = ipv4MappedIpv6Address(addr.getAddress()); + } + int flags = fastOpen ? msgFastopen() : 0; + int res = sendToAddresses(fd, useIpv6(addr), memoryAddress, length, address, scopeId, port, flags); + if (res >= 0) { + return res; + } + if (res == ERRNO_EINPROGRESS_NEGATIVE && fastOpen) { + // This happens when we (as a client) have no pre-existing cookie for doing a fast-open connection. + // In this case, our TCP connection will be established normally, but no data was transmitted at this time. + // We'll just transmit the data with normal writes later. 
+ return 0; + } + if (res == ERROR_ECONNREFUSED_NEGATIVE) { + throw new PortUnreachableException("sendToAddresses failed"); + } + return ioResult("sendToAddresses", res); + } + + public final int sendToAddressesDomainSocket(long memoryAddress, int length, byte[] path) throws IOException { + int res = sendToAddressesDomainSocket(fd, memoryAddress, length, path); + if (res >= 0) { + return res; + } + return ioResult("sendToAddressesDomainSocket", res); + } + + public final DatagramSocketAddress recvFrom(ByteBuffer buf, int pos, int limit) throws IOException { + return recvFrom(fd, buf, pos, limit); + } + + public final DatagramSocketAddress recvFromAddress(long memoryAddress, int pos, int limit) throws IOException { + return recvFromAddress(fd, memoryAddress, pos, limit); + } + + public final DomainDatagramSocketAddress recvFromDomainSocket(ByteBuffer buf, int pos, int limit) + throws IOException { + return recvFromDomainSocket(fd, buf, pos, limit); + } + + public final DomainDatagramSocketAddress recvFromAddressDomainSocket(long memoryAddress, int pos, int limit) + throws IOException { + return recvFromAddressDomainSocket(fd, memoryAddress, pos, limit); + } + + public int recv(ByteBuffer buf, int pos, int limit) throws IOException { + int res = recv(intValue(), buf, pos, limit); + if (res > 0) { + return res; + } + if (res == 0) { + return -1; + } + return ioResult("recv", res); + } + + public int recvAddress(long address, int pos, int limit) throws IOException { + int res = recvAddress(intValue(), address, pos, limit); + if (res > 0) { + return res; + } + if (res == 0) { + return -1; + } + return ioResult("recvAddress", res); + } + + public int send(ByteBuffer buf, int pos, int limit) throws IOException { + int res = send(intValue(), buf, pos, limit); + if (res >= 0) { + return res; + } + return ioResult("send", res); + } + + public int sendAddress(long address, int pos, int limit) throws IOException { + int res = sendAddress(intValue(), address, pos, limit); + if 
(res >= 0) { + return res; + } + return ioResult("sendAddress", res); + } + + public final int recvFd() throws IOException { + int res = recvFd(fd); + if (res > 0) { + return res; + } + if (res == 0) { + return -1; + } + + if (res == ERRNO_EAGAIN_NEGATIVE || res == ERRNO_EWOULDBLOCK_NEGATIVE) { + // Everything consumed so just return -1 here. + return 0; + } + throw newIOException("recvFd", res); + } + + public final int sendFd(int fdToSend) throws IOException { + int res = sendFd(fd, fdToSend); + if (res >= 0) { + return res; + } + if (res == ERRNO_EAGAIN_NEGATIVE || res == ERRNO_EWOULDBLOCK_NEGATIVE) { + // Everything consumed so just return -1 here. + return -1; + } + throw newIOException("sendFd", res); + } + + public final boolean connect(SocketAddress socketAddress) throws IOException { + int res; + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress inetSocketAddress = (InetSocketAddress) socketAddress; + InetAddress inetAddress = inetSocketAddress.getAddress(); + NativeInetAddress address = NativeInetAddress.newInstance(inetAddress); + res = connect(fd, useIpv6(inetAddress), address.address, address.scopeId, inetSocketAddress.getPort()); + } else if (socketAddress instanceof DomainSocketAddress) { + DomainSocketAddress unixDomainSocketAddress = (DomainSocketAddress) socketAddress; + res = connectDomainSocket(fd, unixDomainSocketAddress.path().getBytes(CharsetUtil.UTF_8)); + } else { + throw new Error("Unexpected SocketAddress implementation " + socketAddress); + } + if (res < 0) { + return handleConnectErrno("connect", res); + } + return true; + } + + public final boolean finishConnect() throws IOException { + int res = finishConnect(fd); + if (res < 0) { + return handleConnectErrno("finishConnect", res); + } + return true; + } + + public final void disconnect() throws IOException { + int res = disconnect(fd, ipv6); + if (res < 0) { + handleConnectErrno("disconnect", res); + } + } + + public final void bind(SocketAddress socketAddress) 
throws IOException { + if (socketAddress instanceof InetSocketAddress) { + InetSocketAddress addr = (InetSocketAddress) socketAddress; + InetAddress inetAddress = addr.getAddress(); + NativeInetAddress address = NativeInetAddress.newInstance(inetAddress); + int res = bind(fd, useIpv6(inetAddress), address.address, address.scopeId, addr.getPort()); + if (res < 0) { + throw newIOException("bind", res); + } + } else if (socketAddress instanceof DomainSocketAddress) { + DomainSocketAddress addr = (DomainSocketAddress) socketAddress; + int res = bindDomainSocket(fd, addr.path().getBytes(CharsetUtil.UTF_8)); + if (res < 0) { + throw newIOException("bind", res); + } + } else { + throw new Error("Unexpected SocketAddress implementation " + socketAddress); + } + } + + public final void listen(int backlog) throws IOException { + int res = listen(fd, backlog); + if (res < 0) { + throw newIOException("listen", res); + } + } + + public final int accept(byte[] addr) throws IOException { + int res = accept(fd, addr); + if (res >= 0) { + return res; + } + if (res == ERRNO_EAGAIN_NEGATIVE || res == ERRNO_EWOULDBLOCK_NEGATIVE) { + // Everything consumed so just return -1 here. + return -1; + } + throw newIOException("accept", res); + } + + public final InetSocketAddress remoteAddress() { + byte[] addr = remoteAddress(fd); + // addr may be null if getpeername failed. + // See https://github.com/netty/netty/issues/3328 + return addr == null ? null : address(addr, 0, addr.length); + } + + public final DomainSocketAddress remoteDomainSocketAddress() { + byte[] addr = remoteDomainSocketAddress(fd); + return addr == null ? null : new DomainSocketAddress(new String(addr)); + } + + public final InetSocketAddress localAddress() { + byte[] addr = localAddress(fd); + // addr may be null if getpeername failed. + // See https://github.com/netty/netty/issues/3328 + return addr == null ? 
null : address(addr, 0, addr.length); + } + + public final DomainSocketAddress localDomainSocketAddress() { + byte[] addr = localDomainSocketAddress(fd); + return addr == null ? null : new DomainSocketAddress(new String(addr)); + } + + public final int getReceiveBufferSize() throws IOException { + return getReceiveBufferSize(fd); + } + + public final int getSendBufferSize() throws IOException { + return getSendBufferSize(fd); + } + + public final boolean isKeepAlive() throws IOException { + return isKeepAlive(fd) != 0; + } + + public final boolean isTcpNoDelay() throws IOException { + return isTcpNoDelay(fd) != 0; + } + + public final boolean isReuseAddress() throws IOException { + return isReuseAddress(fd) != 0; + } + + public final boolean isReusePort() throws IOException { + return isReusePort(fd) != 0; + } + + public final boolean isBroadcast() throws IOException { + return isBroadcast(fd) != 0; + } + + public final int getSoLinger() throws IOException { + return getSoLinger(fd); + } + + public final int getSoError() throws IOException { + return getSoError(fd); + } + + public final int getTrafficClass() throws IOException { + return getTrafficClass(fd, ipv6); + } + + public final void setKeepAlive(boolean keepAlive) throws IOException { + setKeepAlive(fd, keepAlive ? 1 : 0); + } + + public final void setReceiveBufferSize(int receiveBufferSize) throws IOException { + setReceiveBufferSize(fd, receiveBufferSize); + } + + public final void setSendBufferSize(int sendBufferSize) throws IOException { + setSendBufferSize(fd, sendBufferSize); + } + + public final void setTcpNoDelay(boolean tcpNoDelay) throws IOException { + setTcpNoDelay(fd, tcpNoDelay ? 1 : 0); + } + + public final void setSoLinger(int soLinger) throws IOException { + setSoLinger(fd, soLinger); + } + + public final void setReuseAddress(boolean reuseAddress) throws IOException { + setReuseAddress(fd, reuseAddress ? 
1 : 0); + } + + public final void setReusePort(boolean reusePort) throws IOException { + setReusePort(fd, reusePort ? 1 : 0); + } + + public final void setBroadcast(boolean broadcast) throws IOException { + setBroadcast(fd, broadcast ? 1 : 0); + } + + public final void setTrafficClass(int trafficClass) throws IOException { + setTrafficClass(fd, ipv6, trafficClass); + } + + public void setIntOpt(int level, int optname, int optvalue) throws IOException { + setIntOpt(fd, level, optname, optvalue); + } + + public void setRawOpt(int level, int optname, ByteBuffer optvalue) throws IOException { + int limit = optvalue.limit(); + if (optvalue.isDirect()) { + setRawOptAddress(fd, level, optname, + Buffer.memoryAddress(optvalue) + optvalue.position(), optvalue.remaining()); + } else if (optvalue.hasArray()) { + setRawOptArray(fd, level, optname, + optvalue.array(), optvalue.arrayOffset() + optvalue.position(), optvalue.remaining()); + } else { + byte[] bytes = new byte[optvalue.remaining()]; + optvalue.duplicate().get(bytes); + setRawOptArray(fd, level, optname, bytes, 0, bytes.length); + } + optvalue.position(limit); + } + + public int getIntOpt(int level, int optname) throws IOException { + return getIntOpt(fd, level, optname); + } + + public void getRawOpt(int level, int optname, ByteBuffer out) throws IOException { + if (out.isDirect()) { + getRawOptAddress(fd, level, optname, Buffer.memoryAddress(out) + out.position() , out.remaining()); + } else if (out.hasArray()) { + getRawOptArray(fd, level, optname, out.array(), out.position() + out.arrayOffset(), out.remaining()); + } else { + byte[] outArray = new byte[out.remaining()]; + getRawOptArray(fd, level, optname, outArray, 0, outArray.length); + out.put(outArray); + } + out.position(out.limit()); + } + + public static boolean isIPv6Preferred() { + return isIpv6Preferred; + } + + public static boolean shouldUseIpv6(InternetProtocolFamily family) { + return family == null ? 
isIPv6Preferred() : + family == InternetProtocolFamily.IPv6; + } + + private static native boolean isIPv6Preferred0(boolean ipv4Preferred); + + private static native boolean isIPv6(int fd); + + @Override + public String toString() { + return "Socket{" + + "fd=" + fd + + '}'; + } + + public static Socket newSocketStream() { + return new Socket(newSocketStream0()); + } + + public static Socket newSocketDgram() { + return new Socket(newSocketDgram0()); + } + + public static Socket newSocketDomain() { + return new Socket(newSocketDomain0()); + } + + public static Socket newSocketDomainDgram() { + return new Socket(newSocketDomainDgram0()); + } + + public static void initialize() { + isIpv6Preferred = isIPv6Preferred0(NetUtil.isIpV4StackPreferred()); + } + + protected static int newSocketStream0() { + return newSocketStream0(isIPv6Preferred()); + } + + protected static int newSocketStream0(InternetProtocolFamily protocol) { + return newSocketStream0(shouldUseIpv6(protocol)); + } + + protected static int newSocketStream0(boolean ipv6) { + int res = newSocketStreamFd(ipv6); + if (res < 0) { + throw new ChannelException(newIOException("newSocketStream", res)); + } + return res; + } + + protected static int newSocketDgram0() { + return newSocketDgram0(isIPv6Preferred()); + } + + protected static int newSocketDgram0(InternetProtocolFamily family) { + return newSocketDgram0(shouldUseIpv6(family)); + } + + protected static int newSocketDgram0(boolean ipv6) { + int res = newSocketDgramFd(ipv6); + if (res < 0) { + throw new ChannelException(newIOException("newSocketDgram", res)); + } + return res; + } + + protected static int newSocketDomain0() { + int res = newSocketDomainFd(); + if (res < 0) { + throw new ChannelException(newIOException("newSocketDomain", res)); + } + return res; + } + + protected static int newSocketDomainDgram0() { + int res = newSocketDomainDgramFd(); + if (res < 0) { + throw new ChannelException(newIOException("newSocketDomainDgram", res)); + } + return 
res; + } + + private static native int shutdown(int fd, boolean read, boolean write); + private static native int connect(int fd, boolean ipv6, byte[] address, int scopeId, int port); + private static native int connectDomainSocket(int fd, byte[] path); + private static native int finishConnect(int fd); + private static native int disconnect(int fd, boolean ipv6); + private static native int bind(int fd, boolean ipv6, byte[] address, int scopeId, int port); + private static native int bindDomainSocket(int fd, byte[] path); + private static native int listen(int fd, int backlog); + private static native int accept(int fd, byte[] addr); + + private static native byte[] remoteAddress(int fd); + private static native byte[] remoteDomainSocketAddress(int fd); + private static native byte[] localAddress(int fd); + private static native byte[] localDomainSocketAddress(int fd); + + private static native int send(int fd, ByteBuffer buf, int pos, int limit); + private static native int sendAddress(int fd, long address, int pos, int limit); + private static native int recv(int fd, ByteBuffer buf, int pos, int limit); + + private static native int recvAddress(int fd, long address, int pos, int limit); + + private static native int sendTo( + int fd, boolean ipv6, ByteBuffer buf, int pos, int limit, byte[] address, int scopeId, int port, + int flags); + + private static native int sendToAddress( + int fd, boolean ipv6, long memoryAddress, int pos, int limit, byte[] address, int scopeId, int port, + int flags); + + private static native int sendToAddresses( + int fd, boolean ipv6, long memoryAddress, int length, byte[] address, int scopeId, int port, + int flags); + + private static native int sendToDomainSocket(int fd, ByteBuffer buf, int pos, int limit, byte[] path); + private static native int sendToAddressDomainSocket(int fd, long memoryAddress, int pos, int limit, byte[] path); + private static native int sendToAddressesDomainSocket(int fd, long memoryAddress, int length, 
byte[] path); + + private static native DatagramSocketAddress recvFrom( + int fd, ByteBuffer buf, int pos, int limit) throws IOException; + private static native DatagramSocketAddress recvFromAddress( + int fd, long memoryAddress, int pos, int limit) throws IOException; + private static native DomainDatagramSocketAddress recvFromDomainSocket( + int fd, ByteBuffer buf, int pos, int limit) throws IOException; + private static native DomainDatagramSocketAddress recvFromAddressDomainSocket( + int fd, long memoryAddress, int pos, int limit) throws IOException; + private static native int recvFd(int fd); + private static native int sendFd(int socketFd, int fd); + private static native int msgFastopen(); + + private static native int newSocketStreamFd(boolean ipv6); + private static native int newSocketDgramFd(boolean ipv6); + private static native int newSocketDomainFd(); + private static native int newSocketDomainDgramFd(); + + private static native int isReuseAddress(int fd) throws IOException; + private static native int isReusePort(int fd) throws IOException; + private static native int getReceiveBufferSize(int fd) throws IOException; + private static native int getSendBufferSize(int fd) throws IOException; + private static native int isKeepAlive(int fd) throws IOException; + private static native int isTcpNoDelay(int fd) throws IOException; + private static native int isBroadcast(int fd) throws IOException; + private static native int getSoLinger(int fd) throws IOException; + private static native int getSoError(int fd) throws IOException; + private static native int getTrafficClass(int fd, boolean ipv6) throws IOException; + + private static native void setReuseAddress(int fd, int reuseAddress) throws IOException; + private static native void setReusePort(int fd, int reuseAddress) throws IOException; + private static native void setKeepAlive(int fd, int keepAlive) throws IOException; + private static native void setReceiveBufferSize(int fd, int receiveBufferSize) 
throws IOException; + private static native void setSendBufferSize(int fd, int sendBufferSize) throws IOException; + private static native void setTcpNoDelay(int fd, int tcpNoDelay) throws IOException; + private static native void setSoLinger(int fd, int soLinger) throws IOException; + private static native void setBroadcast(int fd, int broadcast) throws IOException; + private static native void setTrafficClass(int fd, boolean ipv6, int trafficClass) throws IOException; + + private static native void setIntOpt(int fd, int level, int optname, int optvalue) throws IOException; + private static native void setRawOptArray(int fd, int level, int optname, byte[] optvalue, int offset, int length) + throws IOException; + private static native void setRawOptAddress(int fd, int level, int optname, long optvalueMemoryAddress, int length) + throws IOException; + private static native int getIntOpt(int fd, int level, int optname) throws IOException; + private static native void getRawOptArray(int fd, int level, int optname, byte[] out, int offset, int length) + throws IOException; + private static native void getRawOptAddress(int fd, int level, int optname, long outMemoryAddress, int length) + throws IOException; +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/SocketWritableByteChannel.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/SocketWritableByteChannel.java new file mode 100644 index 0000000..4b06033 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/SocketWritableByteChannel.java @@ -0,0 +1,86 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.internal.ObjectUtil; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; + +public abstract class SocketWritableByteChannel implements WritableByteChannel { + protected final FileDescriptor fd; + + protected SocketWritableByteChannel(FileDescriptor fd) { + this.fd = ObjectUtil.checkNotNull(fd, "fd"); + } + + protected int write(ByteBuffer buf, int pos, int limit) throws IOException { + return fd.write(buf, pos, limit); + } + + @Override + public final int write(java.nio.ByteBuffer src) throws java.io.IOException { + final int written; + int position = src.position(); + int limit = src.limit(); + if (src.isDirect()) { + written = write(src, position, src.limit()); + } else { + final int readableBytes = limit - position; + io.netty.buffer.ByteBuf buffer = null; + try { + if (readableBytes == 0) { + buffer = io.netty.buffer.Unpooled.EMPTY_BUFFER; + } else { + final ByteBufAllocator alloc = alloc(); + if (alloc.isDirectBufferPooled()) { + buffer = alloc.directBuffer(readableBytes); + } else { + buffer = io.netty.buffer.ByteBufUtil.threadLocalDirectBuffer(); + if (buffer == null) { + buffer = io.netty.buffer.Unpooled.directBuffer(readableBytes); + } + } + } + buffer.writeBytes(src.duplicate()); + java.nio.ByteBuffer nioBuffer = buffer.internalNioBuffer(buffer.readerIndex(), readableBytes); + written = write(nioBuffer, nioBuffer.position(), nioBuffer.limit()); + } finally { + if (buffer != null) { + 
buffer.release(); + } + } + } + if (written > 0) { + src.position(position + written); + } + return written; + } + + @Override + public final boolean isOpen() { + return fd.isOpen(); + } + + @Override + public final void close() throws java.io.IOException { + fd.close(); + } + + protected abstract ByteBufAllocator alloc(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/Unix.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/Unix.java new file mode 100644 index 0000000..e1fbb13 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/Unix.java @@ -0,0 +1,94 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.util.internal.ClassInitializerUtil; +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.PortUnreachableException; +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Tells if {@code netty-transport-native-unix} is + * supported. + */ +public final class Unix { + private static final AtomicBoolean registered = new AtomicBoolean(); + + static { + // Preload all classes that will be used in the OnLoad(...) function of JNI to eliminate the possiblity of a + // class-loader deadlock. 
This is a workaround for https://github.com/netty/netty/issues/11209. + + // This needs to match all the classes that are loaded via NETTY_JNI_UTIL_LOAD_CLASS or looked up via + // NETTY_JNI_UTIL_FIND_CLASS. + ClassInitializerUtil.tryLoadClasses(Unix.class, + // netty_unix_errors + OutOfMemoryError.class, RuntimeException.class, ClosedChannelException.class, + IOException.class, PortUnreachableException.class, + + // netty_unix_socket + DatagramSocketAddress.class, DomainDatagramSocketAddress.class, InetSocketAddress.class + ); + } + + /** + * Internal method... Should never be called from the user. + * + * @param registerTask + */ + @UnstableApi + public static synchronized void registerInternal(Runnable registerTask) { + registerTask.run(); + Socket.initialize(); + } + + /** + * Returns {@code true} if and only if the {@code + * netty_transport_native_unix} is available. + */ + @Deprecated + public static boolean isAvailable() { + return false; + } + + /** + * Ensure that {@code netty_transport_native_unix} is + * available. + * + * @throws UnsatisfiedLinkError if unavailable + */ + @Deprecated + public static void ensureAvailability() { + throw new UnsupportedOperationException(); + } + + /** + * Returns the cause of unavailability of + * {@code netty_transport_native_unix}. + * + * @return the cause if unavailable. {@code null} if available. 
+ */ + @Deprecated + public static Throwable unavailabilityCause() { + return new UnsupportedOperationException(); + } + + private Unix() { + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannel.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannel.java new file mode 100644 index 0000000..6119918 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannel.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.channel.Channel; + +/** + * {@link Channel} that expose operations that are only present on {@code UNIX} like systems. + */ +public interface UnixChannel extends Channel { + /** + * Returns the {@link FileDescriptor} that is used by this {@link Channel}. + */ + FileDescriptor fd(); +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelOption.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelOption.java new file mode 100644 index 0000000..2249ae1 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelOption.java @@ -0,0 +1,33 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.unix; + +import io.netty.channel.ChannelOption; + +public class UnixChannelOption extends ChannelOption { + public static final ChannelOption SO_REUSEPORT = valueOf(UnixChannelOption.class, "SO_REUSEPORT"); + public static final ChannelOption DOMAIN_SOCKET_READ_MODE = + ChannelOption.valueOf(UnixChannelOption.class, "DOMAIN_SOCKET_READ_MODE"); + + @SuppressWarnings({ "unused", "deprecation" }) + protected UnixChannelOption() { + super(null); + } + + UnixChannelOption(String name) { + super(name); + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelUtil.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelUtil.java new file mode 100644 index 0000000..ada5111 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/UnixChannelUtil.java @@ -0,0 +1,62 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.PlatformDependent; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +import static io.netty.channel.unix.Limits.IOV_MAX; + +public final class UnixChannelUtil { + + private UnixChannelUtil() { + } + + /** + * Checks if the specified buffer has memory address or is composed of n(n <= IOV_MAX) NIO direct buffers. + * (We check this because otherwise we need to make it a new direct buffer.) + */ + public static boolean isBufferCopyNeededForWrite(ByteBuf byteBuf) { + return isBufferCopyNeededForWrite(byteBuf, IOV_MAX); + } + + static boolean isBufferCopyNeededForWrite(ByteBuf byteBuf, int iovMax) { + return !byteBuf.hasMemoryAddress() && (!byteBuf.isDirect() || byteBuf.nioBufferCount() > iovMax); + } + + public static InetSocketAddress computeRemoteAddr(InetSocketAddress remoteAddr, InetSocketAddress osRemoteAddr) { + if (osRemoteAddr != null) { + if (PlatformDependent.javaVersion() >= 7) { + try { + // Only try to construct a new InetSocketAddress if we are using java >= 7 as getHostString() does not + // exist in earlier releases and so the retrieval of the hostname could block the EventLoop if a + // reverse lookup would be needed. + return new InetSocketAddress(InetAddress.getByAddress(remoteAddr.getHostString(), + osRemoteAddr.getAddress().getAddress()), + osRemoteAddr.getPort()); + } catch (UnknownHostException ignore) { + // Should never happen but fallback to osRemoteAddr anyway. 
+ } + } + return osRemoteAddr; + } + return remoteAddr; + } +} diff --git a/netty-channel-unix/src/main/java/io/netty/channel/unix/package-info.java b/netty-channel-unix/src/main/java/io/netty/channel/unix/package-info.java new file mode 100644 index 0000000..e2cc3c4 --- /dev/null +++ b/netty-channel-unix/src/main/java/io/netty/channel/unix/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Unix specific transport. 
+ */ +package io.netty.channel.unix; diff --git a/netty-channel-unix/src/main/java/module-info.java b/netty-channel-unix/src/main/java/module-info.java new file mode 100644 index 0000000..5c4b772 --- /dev/null +++ b/netty-channel-unix/src/main/java/module-info.java @@ -0,0 +1,6 @@ +module org.xbib.io.netty.channel.unix { + exports io.netty.channel.unix; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.util; +} diff --git a/netty-channel-unix/src/test/java/io/netty/channel/unix/UnixChannelUtilTest.java b/netty-channel-unix/src/test/java/io/netty/channel/unix/UnixChannelUtilTest.java new file mode 100644 index 0000000..83895ad --- /dev/null +++ b/netty-channel-unix/src/test/java/io/netty/channel/unix/UnixChannelUtilTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel.unix; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.jupiter.api.Test; + +import java.util.Collections; +import java.util.LinkedList; +import java.util.List; + +import static io.netty.channel.unix.UnixChannelUtil.isBufferCopyNeededForWrite; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UnixChannelUtilTest { + + private static final int IOV_MAX = 1024; + + @Test + public void testPooledAllocatorIsBufferCopyNeededForWrite() { + testIsBufferCopyNeededForWrite(PooledByteBufAllocator.DEFAULT); + } + + @Test + public void testUnPooledAllocatorIsBufferCopyNeededForWrite() { + testIsBufferCopyNeededForWrite(UnpooledByteBufAllocator.DEFAULT); + } + + private static void testIsBufferCopyNeededForWrite(ByteBufAllocator alloc) { + ByteBuf byteBuf = alloc.directBuffer(); + assertFalse(isBufferCopyNeededForWrite(byteBuf, IOV_MAX)); + assertFalse(isBufferCopyNeededForWrite(byteBuf.asReadOnly(), IOV_MAX)); + assertTrue(byteBuf.release()); + + byteBuf = alloc.heapBuffer(); + assertTrue(isBufferCopyNeededForWrite(byteBuf, IOV_MAX)); + assertTrue(isBufferCopyNeededForWrite(byteBuf.asReadOnly(), IOV_MAX)); + assertTrue(byteBuf.release()); + + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, 2, 0, false); + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, IOV_MAX + 1, 0, true); + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, 0, 2, true); + assertCompositeByteBufIsBufferCopyNeededForWrite(alloc, 1, 1, true); + } + + private static void assertCompositeByteBufIsBufferCopyNeededForWrite(ByteBufAllocator alloc, int numDirect, + int numHeap, boolean expected) { + CompositeByteBuf comp = 
alloc.compositeBuffer(numDirect + numHeap); + List byteBufs = new LinkedList(); + + while (numDirect > 0) { + byteBufs.add(alloc.directBuffer(1)); + numDirect--; + } + while (numHeap > 0) { + byteBufs.add(alloc.heapBuffer(1)); + numHeap--; + } + + Collections.shuffle(byteBufs); + for (ByteBuf byteBuf : byteBufs) { + comp.addComponent(byteBuf); + } + + assertEquals(expected, isBufferCopyNeededForWrite(comp, IOV_MAX), byteBufs.toString()); + assertTrue(comp.release()); + } +} diff --git a/netty-channel-unix/src/test/resources/logging.properties b/netty-channel-unix/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-channel-unix/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-channel/build.gradle b/netty-channel/build.gradle new file mode 100644 index 0000000..9ca7960 --- /dev/null +++ b/netty-channel/build.gradle @@ -0,0 +1,8 @@ +dependencies { + api project(':netty-buffer') + api project(':netty-resolver') + testImplementation testLibs.mockito.core + testImplementation testLibs.testlibs + testImplementation testLibs.gson + testImplementation testLibs.reflections +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrap.java b/netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrap.java new file mode 100644 index 0000000..e6d9ef0 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrap.java @@ -0,0 +1,530 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the 
"License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ReflectiveChannelFactory; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.logging.InternalLogger; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * {@link AbstractBootstrap} is a helper class that makes it easy to bootstrap a {@link Channel}. It support + * method-chaining to provide an easy way to configure the {@link AbstractBootstrap}. + * + *

When not used in a {@link ServerBootstrap} context, the {@link #bind()} methods are useful for connectionless + * transports such as datagram (UDP).

+ */ +public abstract class AbstractBootstrap, C extends Channel> implements Cloneable { + @SuppressWarnings("unchecked") + private static final Map.Entry, Object>[] EMPTY_OPTION_ARRAY = new Map.Entry[0]; + @SuppressWarnings("unchecked") + private static final Map.Entry, Object>[] EMPTY_ATTRIBUTE_ARRAY = new Map.Entry[0]; + + volatile EventLoopGroup group; + @SuppressWarnings("deprecation") + private volatile ChannelFactory channelFactory; + private volatile SocketAddress localAddress; + + // The order in which ChannelOptions are applied is important they may depend on each other for validation + // purposes. + private final Map, Object> options = new LinkedHashMap, Object>(); + private final Map, Object> attrs = new ConcurrentHashMap, Object>(); + private volatile ChannelHandler handler; + private volatile ClassLoader extensionsClassLoader; + + AbstractBootstrap() { + // Disallow extending from a different package. + } + + AbstractBootstrap(AbstractBootstrap bootstrap) { + group = bootstrap.group; + channelFactory = bootstrap.channelFactory; + handler = bootstrap.handler; + localAddress = bootstrap.localAddress; + synchronized (bootstrap.options) { + options.putAll(bootstrap.options); + } + attrs.putAll(bootstrap.attrs); + extensionsClassLoader = bootstrap.extensionsClassLoader; + } + + /** + * The {@link EventLoopGroup} which is used to handle all the events for the to-be-created + * {@link Channel} + */ + public B group(EventLoopGroup group) { + ObjectUtil.checkNotNull(group, "group"); + if (this.group != null) { + throw new IllegalStateException("group set already"); + } + this.group = group; + return self(); + } + + @SuppressWarnings("unchecked") + private B self() { + return (B) this; + } + + /** + * The {@link Class} which is used to create {@link Channel} instances from. + * You either use this or {@link #channelFactory(io.netty.channel.ChannelFactory)} if your + * {@link Channel} implementation has no no-args constructor. 
 + */ + public B channel(Class channelClass) { + return channelFactory(new ReflectiveChannelFactory( + ObjectUtil.checkNotNull(channelClass, "channelClass") + )); + } + + /** + * @deprecated Use {@link #channelFactory(io.netty.channel.ChannelFactory)} instead. + */ + @Deprecated + public B channelFactory(ChannelFactory channelFactory) { + ObjectUtil.checkNotNull(channelFactory, "channelFactory"); + if (this.channelFactory != null) { + throw new IllegalStateException("channelFactory set already"); + } + + this.channelFactory = channelFactory; + return self(); + } + + /** + * {@link io.netty.channel.ChannelFactory} which is used to create {@link Channel} instances from + * when calling {@link #bind()}. This method is usually only used if {@link #channel(Class)} + * is not working for you because of some more complex needs. If your {@link Channel} implementation + * has a no-args constructor, it's highly recommended to just use {@link #channel(Class)} to + * simplify your code. + */ + @SuppressWarnings({ "unchecked", "deprecation" }) + public B channelFactory(io.netty.channel.ChannelFactory channelFactory) { + return channelFactory((ChannelFactory) channelFactory); + } + + /** + * The {@link SocketAddress} which is used to bind the local "end" to. 
+ */ + public B localAddress(SocketAddress localAddress) { + this.localAddress = localAddress; + return self(); + } + + /** + * @see #localAddress(SocketAddress) + */ + public B localAddress(int inetPort) { + return localAddress(new InetSocketAddress(inetPort)); + } + + /** + * @see #localAddress(SocketAddress) + */ + public B localAddress(String inetHost, int inetPort) { + return localAddress(SocketUtils.socketAddress(inetHost, inetPort)); + } + + /** + * @see #localAddress(SocketAddress) + */ + public B localAddress(InetAddress inetHost, int inetPort) { + return localAddress(new InetSocketAddress(inetHost, inetPort)); + } + + /** + * Allow to specify a {@link ChannelOption} which is used for the {@link Channel} instances once they got + * created. Use a value of {@code null} to remove a previous set {@link ChannelOption}. + */ + public B option(ChannelOption option, T value) { + ObjectUtil.checkNotNull(option, "option"); + synchronized (options) { + if (value == null) { + options.remove(option); + } else { + options.put(option, value); + } + } + return self(); + } + + /** + * Allow to specify an initial attribute of the newly created {@link Channel}. If the {@code value} is + * {@code null}, the attribute of the specified {@code key} is removed. + */ + public B attr(AttributeKey key, T value) { + ObjectUtil.checkNotNull(key, "key"); + if (value == null) { + attrs.remove(key); + } else { + attrs.put(key, value); + } + return self(); + } + + /** + * Load {@link ChannelInitializerExtension}s using the given class loader. + *

+ * By default, the extensions will be loaded by the same class loader that loaded this bootstrap class. + * + * @param classLoader The class loader to use for loading {@link ChannelInitializerExtension}s. + * @return This bootstrap. + */ + public B extensionsClassLoader(ClassLoader classLoader) { + extensionsClassLoader = classLoader; + return self(); + } + + /** + * Validate all the parameters. Sub-classes may override this, but should + * call the super method in that case. + */ + public B validate() { + if (group == null) { + throw new IllegalStateException("group not set"); + } + if (channelFactory == null) { + throw new IllegalStateException("channel or channelFactory not set"); + } + return self(); + } + + /** + * Returns a deep clone of this bootstrap which has the identical configuration. This method is useful when making + * multiple {@link Channel}s with similar settings. Please note that this method does not clone the + * {@link EventLoopGroup} deeply but shallowly, making the group a shared resource. + */ + @Override + @SuppressWarnings("CloneDoesntDeclareCloneNotSupportedException") + public abstract B clone(); + + /** + * Create a new {@link Channel} and register it with an {@link EventLoop}. + */ + public ChannelFuture register() { + validate(); + return initAndRegister(); + } + + /** + * Create a new {@link Channel} and bind it. + */ + public ChannelFuture bind() { + validate(); + SocketAddress localAddress = this.localAddress; + if (localAddress == null) { + throw new IllegalStateException("localAddress not set"); + } + return doBind(localAddress); + } + + /** + * Create a new {@link Channel} and bind it. + */ + public ChannelFuture bind(int inetPort) { + return bind(new InetSocketAddress(inetPort)); + } + + /** + * Create a new {@link Channel} and bind it. + */ + public ChannelFuture bind(String inetHost, int inetPort) { + return bind(SocketUtils.socketAddress(inetHost, inetPort)); + } + + /** + * Create a new {@link Channel} and bind it. 
+ */ + public ChannelFuture bind(InetAddress inetHost, int inetPort) { + return bind(new InetSocketAddress(inetHost, inetPort)); + } + + /** + * Create a new {@link Channel} and bind it. + */ + public ChannelFuture bind(SocketAddress localAddress) { + validate(); + return doBind(ObjectUtil.checkNotNull(localAddress, "localAddress")); + } + + private ChannelFuture doBind(final SocketAddress localAddress) { + final ChannelFuture regFuture = initAndRegister(); + final Channel channel = regFuture.channel(); + if (regFuture.cause() != null) { + return regFuture; + } + + if (regFuture.isDone()) { + // At this point we know that the registration was complete and successful. + ChannelPromise promise = channel.newPromise(); + doBind0(regFuture, channel, localAddress, promise); + return promise; + } else { + // Registration future is almost always fulfilled already, but just in case it's not. + final PendingRegistrationPromise promise = new PendingRegistrationPromise(channel); + regFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Throwable cause = future.cause(); + if (cause != null) { + // Registration on the EventLoop failed so fail the ChannelPromise directly to not cause an + // IllegalStateException once we try to access the EventLoop of the Channel. + promise.setFailure(cause); + } else { + // Registration was successful, so set the correct executor to use. 
+ // See https://github.com/netty/netty/issues/2586 + promise.registered(); + + doBind0(regFuture, channel, localAddress, promise); + } + } + }); + return promise; + } + } + + final ChannelFuture initAndRegister() { + Channel channel = null; + try { + channel = channelFactory.newChannel(); + init(channel); + } catch (Throwable t) { + if (channel != null) { + // channel can be null if newChannel crashed (eg SocketException("too many open files")) + channel.unsafe().closeForcibly(); + // as the Channel is not registered yet we need to force the usage of the GlobalEventExecutor + return new DefaultChannelPromise(channel, GlobalEventExecutor.INSTANCE).setFailure(t); + } + // as the Channel is not registered yet we need to force the usage of the GlobalEventExecutor + return new DefaultChannelPromise(new FailedChannel(), GlobalEventExecutor.INSTANCE).setFailure(t); + } + + ChannelFuture regFuture = config().group().register(channel); + if (regFuture.cause() != null) { + if (channel.isRegistered()) { + channel.close(); + } else { + channel.unsafe().closeForcibly(); + } + } + + // If we are here and the promise is not failed, it's one of the following cases: + // 1) If we attempted registration from the event loop, the registration has been completed at this point. + // i.e. It's safe to attempt bind() or connect() now because the channel has been registered. + // 2) If we attempted registration from the other thread, the registration request has been successfully + // added to the event loop's task queue for later execution. + // i.e. It's safe to attempt bind() or connect() now: + // because bind() or connect() will be executed *after* the scheduled registration task is executed + // because register(), bind(), and connect() are all bound to the same thread. 
 + + return regFuture; + } + + abstract void init(Channel channel) throws Exception; + + Collection getInitializerExtensions() { + ClassLoader loader = extensionsClassLoader; + if (loader == null) { + loader = getClass().getClassLoader(); + } + return ChannelInitializerExtensions.getExtensions().extensions(loader); + } + + private static void doBind0( + final ChannelFuture regFuture, final Channel channel, + final SocketAddress localAddress, final ChannelPromise promise) { + + // This method is invoked before channelRegistered() is triggered. Give user handlers a chance to set up + // the pipeline in its channelRegistered() implementation. + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + if (regFuture.isSuccess()) { + channel.bind(localAddress, promise).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } else { + promise.setFailure(regFuture.cause()); + } + } + }); + } + + /** + * the {@link ChannelHandler} to use for serving the requests. + */ + public B handler(ChannelHandler handler) { + this.handler = ObjectUtil.checkNotNull(handler, "handler"); + return self(); + } + + /** + * Returns the configured {@link EventLoopGroup} or {@code null} if none is configured yet. + * + * @deprecated Use {@link #config()} instead. + */ + @Deprecated + public final EventLoopGroup group() { + return group; + } + + /** + * Returns the {@link AbstractBootstrapConfig} object that can be used to obtain the current config + * of the bootstrap. 
+ */ + public abstract AbstractBootstrapConfig config(); + + final Map.Entry, Object>[] newOptionsArray() { + return newOptionsArray(options); + } + + static Map.Entry, Object>[] newOptionsArray(Map, Object> options) { + synchronized (options) { + return new LinkedHashMap, Object>(options).entrySet().toArray(EMPTY_OPTION_ARRAY); + } + } + + final Map.Entry, Object>[] newAttributesArray() { + return newAttributesArray(attrs0()); + } + + static Map.Entry, Object>[] newAttributesArray(Map, Object> attributes) { + return attributes.entrySet().toArray(EMPTY_ATTRIBUTE_ARRAY); + } + + final Map, Object> options0() { + return options; + } + + final Map, Object> attrs0() { + return attrs; + } + + final SocketAddress localAddress() { + return localAddress; + } + + @SuppressWarnings("deprecation") + final ChannelFactory channelFactory() { + return channelFactory; + } + + final ChannelHandler handler() { + return handler; + } + + final Map, Object> options() { + synchronized (options) { + return copiedMap(options); + } + } + + final Map, Object> attrs() { + return copiedMap(attrs); + } + + static Map copiedMap(Map map) { + if (map.isEmpty()) { + return Collections.emptyMap(); + } + return Collections.unmodifiableMap(new HashMap(map)); + } + + static void setAttributes(Channel channel, Map.Entry, Object>[] attrs) { + for (Map.Entry, Object> e: attrs) { + @SuppressWarnings("unchecked") + AttributeKey key = (AttributeKey) e.getKey(); + channel.attr(key).set(e.getValue()); + } + } + + static void setChannelOptions( + Channel channel, Map.Entry, Object>[] options, InternalLogger logger) { + for (Map.Entry, Object> e: options) { + setChannelOption(channel, e.getKey(), e.getValue(), logger); + } + } + + @SuppressWarnings("unchecked") + private static void setChannelOption( + Channel channel, ChannelOption option, Object value, InternalLogger logger) { + try { + if (!channel.config().setOption((ChannelOption) option, value)) { + logger.warn("Unknown channel option '{}' for channel 
'{}'", option, channel); + } + } catch (Throwable t) { + logger.warn( + "Failed to set channel option '{}' with value '{}' for channel '{}'", option, value, channel, t); + } + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append('(').append(config()).append(')'); + return buf.toString(); + } + + static final class PendingRegistrationPromise extends DefaultChannelPromise { + + // Is set to the correct EventExecutor once the registration was successful. Otherwise it will + // stay null and so the GlobalEventExecutor.INSTANCE will be used for notifications. + private volatile boolean registered; + + PendingRegistrationPromise(Channel channel) { + super(channel); + } + + void registered() { + registered = true; + } + + @Override + protected EventExecutor executor() { + if (registered) { + // If the registration was a success executor is set. + // + // See https://github.com/netty/netty/issues/2586 + return super.executor(); + } + // The registration failed so we can only use the GlobalEventExecutor as last resort to notify. + return GlobalEventExecutor.INSTANCE; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrapConfig.java b/netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrapConfig.java new file mode 100644 index 0000000..a682b9c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/AbstractBootstrapConfig.java @@ -0,0 +1,135 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.util.AttributeKey; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.net.SocketAddress; +import java.util.Map; + +/** + * Exposes the configuration of an {@link AbstractBootstrap}. + */ +public abstract class AbstractBootstrapConfig, C extends Channel> { + + protected final B bootstrap; + + protected AbstractBootstrapConfig(B bootstrap) { + this.bootstrap = ObjectUtil.checkNotNull(bootstrap, "bootstrap"); + } + + /** + * Returns the configured local address or {@code null} if non is configured yet. + */ + public final SocketAddress localAddress() { + return bootstrap.localAddress(); + } + + /** + * Returns the configured {@link ChannelFactory} or {@code null} if non is configured yet. + */ + @SuppressWarnings("deprecation") + public final ChannelFactory channelFactory() { + return bootstrap.channelFactory(); + } + + /** + * Returns the configured {@link ChannelHandler} or {@code null} if non is configured yet. + */ + public final ChannelHandler handler() { + return bootstrap.handler(); + } + + /** + * Returns a copy of the configured options. + */ + public final Map, Object> options() { + return bootstrap.options(); + } + + /** + * Returns a copy of the configured attributes. 
+ */ + public final Map, Object> attrs() { + return bootstrap.attrs(); + } + + /** + * Returns the configured {@link EventLoopGroup} or {@code null} if non is configured yet. + */ + @SuppressWarnings("deprecation") + public final EventLoopGroup group() { + return bootstrap.group(); + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append('('); + EventLoopGroup group = group(); + if (group != null) { + buf.append("group: ") + .append(StringUtil.simpleClassName(group)) + .append(", "); + } + @SuppressWarnings("deprecation") + ChannelFactory factory = channelFactory(); + if (factory != null) { + buf.append("channelFactory: ") + .append(factory) + .append(", "); + } + SocketAddress localAddress = localAddress(); + if (localAddress != null) { + buf.append("localAddress: ") + .append(localAddress) + .append(", "); + } + + Map, Object> options = options(); + if (!options.isEmpty()) { + buf.append("options: ") + .append(options) + .append(", "); + } + Map, Object> attrs = attrs(); + if (!attrs.isEmpty()) { + buf.append("attrs: ") + .append(attrs) + .append(", "); + } + ChannelHandler handler = handler(); + if (handler != null) { + buf.append("handler: ") + .append(handler) + .append(", "); + } + if (buf.charAt(buf.length() - 1) == '(') { + buf.append(')'); + } else { + buf.setCharAt(buf.length() - 2, ')'); + buf.setLength(buf.length() - 1); + } + return buf.toString(); + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/Bootstrap.java b/netty-channel/src/main/java/io/netty/bootstrap/Bootstrap.java new file mode 100644 index 0000000..cfba85f --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/Bootstrap.java @@ -0,0 +1,353 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.resolver.AddressResolver; +import io.netty.resolver.AddressResolverGroup; +import io.netty.resolver.DefaultAddressResolverGroup; +import io.netty.resolver.NameResolver; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.Collection; + +/** + * A {@link Bootstrap} that makes it easy to bootstrap a {@link Channel} to use + * for clients. + * + *

The {@link #bind()} methods are useful in combination with connectionless transports such as datagram (UDP). + * For regular TCP connections, please use the provided {@link #connect()} methods.

+ */ +public class Bootstrap extends AbstractBootstrap { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Bootstrap.class); + + private final BootstrapConfig config = new BootstrapConfig(this); + + private ExternalAddressResolver externalResolver; + private volatile boolean disableResolver; + private volatile SocketAddress remoteAddress; + + public Bootstrap() { } + + private Bootstrap(Bootstrap bootstrap) { + super(bootstrap); + externalResolver = bootstrap.externalResolver; + disableResolver = bootstrap.disableResolver; + remoteAddress = bootstrap.remoteAddress; + } + + /** + * Sets the {@link NameResolver} which will resolve the address of the unresolved named address. + * + * @param resolver the {@link NameResolver} for this {@code Bootstrap}; may be {@code null}, in which case a default + * resolver will be used + * + * @see io.netty.resolver.DefaultAddressResolverGroup + */ + public Bootstrap resolver(AddressResolverGroup resolver) { + externalResolver = resolver == null ? null : new ExternalAddressResolver(resolver); + disableResolver = false; + return this; + } + + /** + * Disables address name resolution. Name resolution may be re-enabled with + * {@link Bootstrap#resolver(AddressResolverGroup)} + */ + public Bootstrap disableResolver() { + externalResolver = null; + disableResolver = true; + return this; + } + + /** + * The {@link SocketAddress} to connect to once the {@link #connect()} method + * is called. 
+ */ + public Bootstrap remoteAddress(SocketAddress remoteAddress) { + this.remoteAddress = remoteAddress; + return this; + } + + /** + * @see #remoteAddress(SocketAddress) + */ + public Bootstrap remoteAddress(String inetHost, int inetPort) { + remoteAddress = InetSocketAddress.createUnresolved(inetHost, inetPort); + return this; + } + + /** + * @see #remoteAddress(SocketAddress) + */ + public Bootstrap remoteAddress(InetAddress inetHost, int inetPort) { + remoteAddress = new InetSocketAddress(inetHost, inetPort); + return this; + } + + /** + * Connect a {@link Channel} to the remote peer. + */ + public ChannelFuture connect() { + validate(); + SocketAddress remoteAddress = this.remoteAddress; + if (remoteAddress == null) { + throw new IllegalStateException("remoteAddress not set"); + } + + return doResolveAndConnect(remoteAddress, config.localAddress()); + } + + /** + * Connect a {@link Channel} to the remote peer. + */ + public ChannelFuture connect(String inetHost, int inetPort) { + return connect(InetSocketAddress.createUnresolved(inetHost, inetPort)); + } + + /** + * Connect a {@link Channel} to the remote peer. + */ + public ChannelFuture connect(InetAddress inetHost, int inetPort) { + return connect(new InetSocketAddress(inetHost, inetPort)); + } + + /** + * Connect a {@link Channel} to the remote peer. + */ + public ChannelFuture connect(SocketAddress remoteAddress) { + ObjectUtil.checkNotNull(remoteAddress, "remoteAddress"); + validate(); + return doResolveAndConnect(remoteAddress, config.localAddress()); + } + + /** + * Connect a {@link Channel} to the remote peer. 
+ */ + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + ObjectUtil.checkNotNull(remoteAddress, "remoteAddress"); + validate(); + return doResolveAndConnect(remoteAddress, localAddress); + } + + /** + * @see #connect() + */ + private ChannelFuture doResolveAndConnect(final SocketAddress remoteAddress, final SocketAddress localAddress) { + final ChannelFuture regFuture = initAndRegister(); + final Channel channel = regFuture.channel(); + + if (regFuture.isDone()) { + if (!regFuture.isSuccess()) { + return regFuture; + } + return doResolveAndConnect0(channel, remoteAddress, localAddress, channel.newPromise()); + } else { + // Registration future is almost always fulfilled already, but just in case it's not. + final PendingRegistrationPromise promise = new PendingRegistrationPromise(channel); + regFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Directly obtain the cause and do a null check so we only need one volatile read in case of a + // failure. + Throwable cause = future.cause(); + if (cause != null) { + // Registration on the EventLoop failed so fail the ChannelPromise directly to not cause an + // IllegalStateException once we try to access the EventLoop of the Channel. + promise.setFailure(cause); + } else { + // Registration was successful, so set the correct executor to use. 
+ // See https://github.com/netty/netty/issues/2586 + promise.registered(); + doResolveAndConnect0(channel, remoteAddress, localAddress, promise); + } + } + }); + return promise; + } + } + + private ChannelFuture doResolveAndConnect0(final Channel channel, SocketAddress remoteAddress, + final SocketAddress localAddress, final ChannelPromise promise) { + try { + if (disableResolver) { + doConnect(remoteAddress, localAddress, promise); + return promise; + } + + final EventLoop eventLoop = channel.eventLoop(); + AddressResolver resolver; + try { + resolver = ExternalAddressResolver.getOrDefault(externalResolver).getResolver(eventLoop); + } catch (Throwable cause) { + channel.close(); + return promise.setFailure(cause); + } + + if (!resolver.isSupported(remoteAddress) || resolver.isResolved(remoteAddress)) { + // Resolver has no idea about what to do with the specified remote address or it's resolved already. + doConnect(remoteAddress, localAddress, promise); + return promise; + } + + final Future resolveFuture = resolver.resolve(remoteAddress); + + if (resolveFuture.isDone()) { + final Throwable resolveFailureCause = resolveFuture.cause(); + + if (resolveFailureCause != null) { + // Failed to resolve immediately + channel.close(); + promise.setFailure(resolveFailureCause); + } else { + // Succeeded to resolve immediately; cached? (or did a blocking lookup) + doConnect(resolveFuture.getNow(), localAddress, promise); + } + return promise; + } + + // Wait until the name resolution is finished. 
+ resolveFuture.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (future.cause() != null) { + channel.close(); + promise.setFailure(future.cause()); + } else { + doConnect(future.getNow(), localAddress, promise); + } + } + }); + } catch (Throwable cause) { + promise.tryFailure(cause); + } + return promise; + } + + private static void doConnect( + final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise connectPromise) { + + // This method is invoked before channelRegistered() is triggered. Give user handlers a chance to set up + // the pipeline in its channelRegistered() implementation. + final Channel channel = connectPromise.channel(); + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + if (localAddress == null) { + channel.connect(remoteAddress, connectPromise); + } else { + channel.connect(remoteAddress, localAddress, connectPromise); + } + connectPromise.addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + } + }); + } + + @Override + void init(Channel channel) { + ChannelPipeline p = channel.pipeline(); + p.addLast(config.handler()); + + setChannelOptions(channel, newOptionsArray(), logger); + setAttributes(channel, newAttributesArray()); + Collection extensions = getInitializerExtensions(); + if (!extensions.isEmpty()) { + for (ChannelInitializerExtension extension : extensions) { + try { + extension.postInitializeClientChannel(channel); + } catch (Exception e) { + logger.warn("Exception thrown from postInitializeClientChannel", e); + } + } + } + } + + @Override + public Bootstrap validate() { + super.validate(); + if (config.handler() == null) { + throw new IllegalStateException("handler not set"); + } + return this; + } + + @Override + @SuppressWarnings("CloneDoesntCallSuperClone") + public Bootstrap clone() { + return new Bootstrap(this); + } + + /** + * Returns a deep clone of this bootstrap which has the identical 
configuration except that it uses + * the given {@link EventLoopGroup}. This method is useful when making multiple {@link Channel}s with similar + * settings. + */ + public Bootstrap clone(EventLoopGroup group) { + Bootstrap bs = new Bootstrap(this); + bs.group = group; + return bs; + } + + @Override + public final BootstrapConfig config() { + return config; + } + + final SocketAddress remoteAddress() { + return remoteAddress; + } + + final AddressResolverGroup resolver() { + if (disableResolver) { + return null; + } + return ExternalAddressResolver.getOrDefault(externalResolver); + } + + /* Holder to avoid NoClassDefFoundError in case netty-resolver dependency is excluded + (e.g. some address families do not need name resolution) */ + static final class ExternalAddressResolver { + final AddressResolverGroup resolverGroup; + + @SuppressWarnings("unchecked") + ExternalAddressResolver(AddressResolverGroup resolverGroup) { + this.resolverGroup = (AddressResolverGroup) resolverGroup; + } + + @SuppressWarnings("unchecked") + static AddressResolverGroup getOrDefault(ExternalAddressResolver externalResolver) { + if (externalResolver == null) { + AddressResolverGroup defaultResolverGroup = DefaultAddressResolverGroup.INSTANCE; + return (AddressResolverGroup) defaultResolverGroup; + } + return externalResolver.resolverGroup; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/BootstrapConfig.java b/netty-channel/src/main/java/io/netty/bootstrap/BootstrapConfig.java new file mode 100644 index 0000000..e80ddb2 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/BootstrapConfig.java @@ -0,0 +1,63 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.resolver.AddressResolverGroup; + +import java.net.SocketAddress; + +/** + * Exposes the configuration of a {@link Bootstrap}. + */ +public final class BootstrapConfig extends AbstractBootstrapConfig { + + BootstrapConfig(Bootstrap bootstrap) { + super(bootstrap); + } + + /** + * Returns the configured remote address or {@code null} if non is configured yet. + */ + public SocketAddress remoteAddress() { + return bootstrap.remoteAddress(); + } + + /** + * Returns the configured {@link AddressResolverGroup}, {@code null} if resolver was disabled + * with {@link Bootstrap#disableResolver()}, or the default if not configured yet. 
+ */ + public AddressResolverGroup resolver() { + return bootstrap.resolver(); + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(super.toString()); + buf.setLength(buf.length() - 1); + AddressResolverGroup resolver = resolver(); + if (resolver != null) { + buf.append(", resolver: ") + .append(resolver); + } + SocketAddress remoteAddress = remoteAddress(); + if (remoteAddress != null) { + buf.append(", remoteAddress: ") + .append(remoteAddress); + } + return buf.append(')').toString(); + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/ChannelFactory.java b/netty-channel/src/main/java/io/netty/bootstrap/ChannelFactory.java new file mode 100644 index 0000000..9c72e60 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/ChannelFactory.java @@ -0,0 +1,29 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; + +/** + * @deprecated Use {@link io.netty.channel.ChannelFactory} instead. + */ +@Deprecated +public interface ChannelFactory { + /** + * Creates a new channel. 
+ */ + T newChannel(); +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtension.java b/netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtension.java new file mode 100644 index 0000000..09fe0f8 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtension.java @@ -0,0 +1,122 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ServerChannel; + +/** + * A channel initializer extension make it possible to enforce rules and apply modifications across multiple, + * disconnected uses of Netty within the same JVM process. + *

+ * For instance, application-level firewall rules can be injected into all uses of Netty within an application, + * without making changes to such uses that are otherwise outside the purview of the application code, + * such as 3rd-party libraries. + *

+ * Channel initializer extensions are not enabled by default, because of their power to influence Netty + * pipelines across libraries, frameworks, and use-cases. + * Extensions must be explicitly enabled by setting the {@value #EXTENSIONS_SYSTEM_PROPERTY} to {@code serviceload}. + *

+ * All channel initializer extensions that are available on the classpath will be + * {@linkplain java.util.ServiceLoader#load(Class) service-loaded} and used by all {@link AbstractBootstrap} subclasses. + *

+ * Note that this feature will not work for Netty uses that are shaded and relocated into other libraries. + * The classes in a relocated Netty library are technically distinct and incompatible types. This means the + * service-loader in non-relocated Netty will not see types from a relocated Netty, and vice versa. + */ +public abstract class ChannelInitializerExtension { + /** + * The name of the system property that control initializer extensions. + *

+ * These extensions can potentially be a security liability, so they are disabled by default. + *

+ * To enable the extensions, application operators can explicitly opt in by setting this system property to the + * value {@code serviceload}. This will enable all the extensions that are available through the service loader + * mechanism. + *

+ * To load and log (at INFO level) all available extensions without actually running them, set this system property + * to the value {@code log}. + */ + public static final String EXTENSIONS_SYSTEM_PROPERTY = "io.netty.bootstrap.extensions"; + + /** + * Get the "priority" of this extension. If multiple extensions are avilable, then they will be called in their + * priority order, from lowest to highest. + *

+ * Implementers are encouraged to pick a number between {@code -100.0} and {@code 100.0}, where extensions that have + * no particular opinion on their priority are encouraged to return {@code 0.0}. + *

+ * Extensions with lower priority will get called first, while extensions with greater priority may be able to + * observe the effects of extensions with lesser priority. + *

+ * Note that if multiple extensions have the same priority, then their relative order will be unpredictable. + * As such, implementations should always take into consideration that other extensions might be called before + * or after them. + *

+ * Override this method to specify your own priority. + * The default implementation just returns {@code 0}. + * + * @return The priority. + */ + public double priority() { + return 0; + } + + /** + * Called by {@link Bootstrap} after the initialization of the given client channel. + *

+ * The method is allowed to modify the handlers in the pipeline, the channel attributes, or the channel options. + * The method must refrain from doing any I/O, or from closing the channel. + *

+ * Override this method to add your own callback logic. + * The default implementation does nothing. + * + * @param channel The channel that was initialized. + */ + public void postInitializeClientChannel(Channel channel) { + } + + /** + * Called by {@link ServerBootstrap} after the initialization of the given server listener channel. + * The listener channel is responsible for invoking the {@code accept(2)} system call, + * and for producing child channels. + *

+ * The method is allowed to modify the handlers in the pipeline, the channel attributes, or the channel options. + * The method must refrain from doing any I/O, or from closing the channel. + *

+ * Override this method to add your own callback logic. + * The default implementation does nothing. + * + * @param channel The channel that was initialized. + */ + public void postInitializeServerListenerChannel(ServerChannel channel) { + } + + /** + * Called by {@link ServerBootstrap} after the initialization of the given child channel. + * A child channel is a newly established connection from a client to the server. + *

+ * The method is allowed to modify the handlers in the pipeline, the channel attributes, or the channel options. + * The method must refrain from doing any I/O, or from closing the channel. + *

+ * Override this method to add your own callback logic. + * The default implementation does nothing. + * + * @param channel The channel that was initialized. + */ + public void postInitializeServerChildChannel(Channel channel) { + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtensions.java b/netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtensions.java new file mode 100644 index 0000000..f4cbb4b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/ChannelInitializerExtensions.java @@ -0,0 +1,128 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogLevel; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.lang.ref.WeakReference; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Comparator; +import java.util.ServiceLoader; + +/** + * The configurable facade that decides what {@link ChannelInitializerExtension}s to load and where to find them. 
+ */ +abstract class ChannelInitializerExtensions { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializerExtensions.class); + private static volatile ChannelInitializerExtensions implementation; + + private ChannelInitializerExtensions() { + } + + /** + * Get the configuration extensions, which is a no-op implementation by default, + * or a service-loading implementation if the {@code io.netty.bootstrap.extensions} system property is + * {@code serviceload}. + */ + static ChannelInitializerExtensions getExtensions() { + ChannelInitializerExtensions impl = implementation; + if (impl == null) { + synchronized (ChannelInitializerExtensions.class) { + impl = implementation; + if (impl != null) { + return impl; + } + String extensionProp = SystemPropertyUtil.get(ChannelInitializerExtension.EXTENSIONS_SYSTEM_PROPERTY); + logger.debug("-Dio.netty.bootstrap.extensions: {}", extensionProp); + if ("serviceload".equalsIgnoreCase(extensionProp)) { + impl = new ServiceLoadingExtensions(InternalLogLevel.DEBUG, true); + } else if ("log".equalsIgnoreCase(extensionProp)) { + impl = new ServiceLoadingExtensions(InternalLogLevel.INFO, false); + } else { + impl = new EmptyExtensions(); + } + implementation = impl; + } + } + return impl; + } + + /** + * Get the list of available extensions. The list is unmodifiable. 
+ */ + abstract Collection extensions(ClassLoader cl); + + private static final class EmptyExtensions extends ChannelInitializerExtensions { + @Override + Collection extensions(ClassLoader cl) { + return Collections.emptyList(); + } + } + + private static final class ServiceLoadingExtensions extends ChannelInitializerExtensions { + private final InternalLogLevel logLevel; + private final boolean loadAndCache; + + private WeakReference classLoader; + private Collection extensions; + + ServiceLoadingExtensions(InternalLogLevel logLevel, boolean loadAndCache) { + this.logLevel = logLevel; + this.loadAndCache = loadAndCache; + } + + @SuppressWarnings("AssignmentOrReturnOfFieldWithMutableType") + @Override + synchronized Collection extensions(ClassLoader cl) { + ClassLoader configured = classLoader == null ? null : classLoader.get(); + if (configured == null || configured != cl) { + Collection loaded = serviceLoadExtensions(logLevel, cl); + classLoader = new WeakReference(cl); + extensions = loadAndCache ? 
loaded : Collections.emptyList(); + } + return extensions; + } + + private static Collection serviceLoadExtensions( + InternalLogLevel logLevel, ClassLoader cl) { + ArrayList extensions = new ArrayList(); + + ServiceLoader loader = ServiceLoader.load( + ChannelInitializerExtension.class, cl); + logger.log(logLevel, "Loader: {}", loader); + for (ChannelInitializerExtension extension : loader) { + logger.log(logLevel, "Loaded extension: {}", extension.getClass()); + extensions.add(extension); + } + + if (!extensions.isEmpty()) { + Collections.sort(extensions, new Comparator() { + @Override + public int compare(ChannelInitializerExtension a, ChannelInitializerExtension b) { + return Double.compare(a.priority(), b.priority()); + } + }); + return Collections.unmodifiableList(extensions); + } + return Collections.emptyList(); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/FailedChannel.java b/netty-channel/src/main/java/io/netty/bootstrap/FailedChannel.java new file mode 100644 index 0000000..96ef7c0 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/FailedChannel.java @@ -0,0 +1,107 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.bootstrap; + +import io.netty.channel.AbstractChannel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.EventLoop; + +import java.net.SocketAddress; + +final class FailedChannel extends AbstractChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private final ChannelConfig config = new DefaultChannelConfig(this); + + FailedChannel() { + super(null); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new FailedChannelUnsafe(); + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return false; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { + throw new UnsupportedOperationException(); + } + + @Override + protected void doDisconnect() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doClose() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doBeginRead() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return false; + } + + @Override + public boolean isActive() { + return false; + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + private final class FailedChannelUnsafe extends AbstractUnsafe { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + } + } +} diff 
--git a/netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrap.java b/netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrap.java new file mode 100644 index 0000000..1781f4e --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrap.java @@ -0,0 +1,309 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.util.AttributeKey; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.Collection; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.TimeUnit; + +/** + * {@link Bootstrap} sub-class which allows easy bootstrap of {@link ServerChannel} + * + */ +public 
class ServerBootstrap extends AbstractBootstrap { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ServerBootstrap.class); + + // The order in which child ChannelOptions are applied is important they may depend on each other for validation + // purposes. + private final Map, Object> childOptions = new LinkedHashMap, Object>(); + private final Map, Object> childAttrs = new ConcurrentHashMap, Object>(); + private final ServerBootstrapConfig config = new ServerBootstrapConfig(this); + private volatile EventLoopGroup childGroup; + private volatile ChannelHandler childHandler; + + public ServerBootstrap() { } + + private ServerBootstrap(ServerBootstrap bootstrap) { + super(bootstrap); + childGroup = bootstrap.childGroup; + childHandler = bootstrap.childHandler; + synchronized (bootstrap.childOptions) { + childOptions.putAll(bootstrap.childOptions); + } + childAttrs.putAll(bootstrap.childAttrs); + } + + /** + * Specify the {@link EventLoopGroup} which is used for the parent (acceptor) and the child (client). + */ + @Override + public ServerBootstrap group(EventLoopGroup group) { + return group(group, group); + } + + /** + * Set the {@link EventLoopGroup} for the parent (acceptor) and the child (client). These + * {@link EventLoopGroup}'s are used to handle all the events and IO for {@link ServerChannel} and + * {@link Channel}'s. + */ + public ServerBootstrap group(EventLoopGroup parentGroup, EventLoopGroup childGroup) { + super.group(parentGroup); + if (this.childGroup != null) { + throw new IllegalStateException("childGroup set already"); + } + this.childGroup = ObjectUtil.checkNotNull(childGroup, "childGroup"); + return this; + } + + /** + * Allow to specify a {@link ChannelOption} which is used for the {@link Channel} instances once they get created + * (after the acceptor accepted the {@link Channel}). Use a value of {@code null} to remove a previous set + * {@link ChannelOption}. 
+ */ + public ServerBootstrap childOption(ChannelOption childOption, T value) { + ObjectUtil.checkNotNull(childOption, "childOption"); + synchronized (childOptions) { + if (value == null) { + childOptions.remove(childOption); + } else { + childOptions.put(childOption, value); + } + } + return this; + } + + /** + * Set the specific {@link AttributeKey} with the given value on every child {@link Channel}. If the value is + * {@code null} the {@link AttributeKey} is removed + */ + public ServerBootstrap childAttr(AttributeKey childKey, T value) { + ObjectUtil.checkNotNull(childKey, "childKey"); + if (value == null) { + childAttrs.remove(childKey); + } else { + childAttrs.put(childKey, value); + } + return this; + } + + /** + * Set the {@link ChannelHandler} which is used to serve the request for the {@link Channel}'s. + */ + public ServerBootstrap childHandler(ChannelHandler childHandler) { + this.childHandler = ObjectUtil.checkNotNull(childHandler, "childHandler"); + return this; + } + + @Override + void init(Channel channel) { + setChannelOptions(channel, newOptionsArray(), logger); + setAttributes(channel, newAttributesArray()); + + ChannelPipeline p = channel.pipeline(); + + final EventLoopGroup currentChildGroup = childGroup; + final ChannelHandler currentChildHandler = childHandler; + final Entry, Object>[] currentChildOptions = newOptionsArray(childOptions); + final Entry, Object>[] currentChildAttrs = newAttributesArray(childAttrs); + final Collection extensions = getInitializerExtensions(); + + p.addLast(new ChannelInitializer() { + @Override + public void initChannel(final Channel ch) { + final ChannelPipeline pipeline = ch.pipeline(); + ChannelHandler handler = config.handler(); + if (handler != null) { + pipeline.addLast(handler); + } + + ch.eventLoop().execute(new Runnable() { + @Override + public void run() { + pipeline.addLast(new ServerBootstrapAcceptor( + ch, currentChildGroup, currentChildHandler, currentChildOptions, currentChildAttrs, + 
extensions)); + } + }); + } + }); + if (!extensions.isEmpty() && channel instanceof ServerChannel) { + ServerChannel serverChannel = (ServerChannel) channel; + for (ChannelInitializerExtension extension : extensions) { + try { + extension.postInitializeServerListenerChannel(serverChannel); + } catch (Exception e) { + logger.warn("Exception thrown from postInitializeServerListenerChannel", e); + } + } + } + } + + @Override + public ServerBootstrap validate() { + super.validate(); + if (childHandler == null) { + throw new IllegalStateException("childHandler not set"); + } + if (childGroup == null) { + logger.warn("childGroup is not set. Using parentGroup instead."); + childGroup = config.group(); + } + return this; + } + + private static class ServerBootstrapAcceptor extends ChannelInboundHandlerAdapter { + + private final EventLoopGroup childGroup; + private final ChannelHandler childHandler; + private final Entry, Object>[] childOptions; + private final Entry, Object>[] childAttrs; + private final Runnable enableAutoReadTask; + private final Collection extensions; + + ServerBootstrapAcceptor( + final Channel channel, EventLoopGroup childGroup, ChannelHandler childHandler, + Entry, Object>[] childOptions, Entry, Object>[] childAttrs, + Collection extensions) { + this.childGroup = childGroup; + this.childHandler = childHandler; + this.childOptions = childOptions; + this.childAttrs = childAttrs; + this.extensions = extensions; + + // Task which is scheduled to re-enable auto-read. + // It's important to create this Runnable before we try to submit it as otherwise the URLClassLoader may + // not be able to load the class because of the file limit it already reached. 
+ // + // See https://github.com/netty/netty/issues/1328 + enableAutoReadTask = new Runnable() { + @Override + public void run() { + channel.config().setAutoRead(true); + } + }; + } + + @Override + @SuppressWarnings("unchecked") + public void channelRead(ChannelHandlerContext ctx, Object msg) { + final Channel child = (Channel) msg; + + child.pipeline().addLast(childHandler); + + setChannelOptions(child, childOptions, logger); + setAttributes(child, childAttrs); + + if (!extensions.isEmpty()) { + for (ChannelInitializerExtension extension : extensions) { + try { + extension.postInitializeServerChildChannel(child); + } catch (Exception e) { + logger.warn("Exception thrown from postInitializeServerChildChannel", e); + } + } + } + + try { + childGroup.register(child).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + forceClose(child, future.cause()); + } + } + }); + } catch (Throwable t) { + forceClose(child, t); + } + } + + private static void forceClose(Channel child, Throwable t) { + child.unsafe().closeForcibly(); + logger.warn("Failed to register an accepted channel: {}", child, t); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + final ChannelConfig config = ctx.channel().config(); + if (config.isAutoRead()) { + // stop accept new connections for 1 second to allow the channel to recover + // See https://github.com/netty/netty/issues/1328 + config.setAutoRead(false); + ctx.channel().eventLoop().schedule(enableAutoReadTask, 1, TimeUnit.SECONDS); + } + // still let the exceptionCaught event flow through the pipeline to give the user + // a chance to do something with it + ctx.fireExceptionCaught(cause); + } + } + + @Override + @SuppressWarnings("CloneDoesntCallSuperClone") + public ServerBootstrap clone() { + return new ServerBootstrap(this); + } + + /** + * Return the configured {@link 
EventLoopGroup} which will be used for the child channels or {@code null} + * if non is configured yet. + * + * @deprecated Use {@link #config()} instead. + */ + @Deprecated + public EventLoopGroup childGroup() { + return childGroup; + } + + final ChannelHandler childHandler() { + return childHandler; + } + + final Map, Object> childOptions() { + synchronized (childOptions) { + return copiedMap(childOptions); + } + } + + final Map, Object> childAttrs() { + return copiedMap(childAttrs); + } + + @Override + public final ServerBootstrapConfig config() { + return config; + } +} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrapConfig.java b/netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrapConfig.java new file mode 100644 index 0000000..676a803 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/ServerBootstrapConfig.java @@ -0,0 +1,105 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.util.AttributeKey; +import io.netty.util.internal.StringUtil; + +import java.util.Map; + +/** + * Exposes the configuration of a {@link ServerBootstrapConfig}. 
+ */ +public final class ServerBootstrapConfig extends AbstractBootstrapConfig { + + ServerBootstrapConfig(ServerBootstrap bootstrap) { + super(bootstrap); + } + + /** + * Returns the configured {@link EventLoopGroup} which will be used for the child channels or {@code null} + * if non is configured yet. + */ + @SuppressWarnings("deprecation") + public EventLoopGroup childGroup() { + return bootstrap.childGroup(); + } + + /** + * Returns the configured {@link ChannelHandler} be used for the child channels or {@code null} + * if non is configured yet. + */ + public ChannelHandler childHandler() { + return bootstrap.childHandler(); + } + + /** + * Returns a copy of the configured options which will be used for the child channels. + */ + public Map, Object> childOptions() { + return bootstrap.childOptions(); + } + + /** + * Returns a copy of the configured attributes which will be used for the child channels. + */ + public Map, Object> childAttrs() { + return bootstrap.childAttrs(); + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(super.toString()); + buf.setLength(buf.length() - 1); + buf.append(", "); + EventLoopGroup childGroup = childGroup(); + if (childGroup != null) { + buf.append("childGroup: "); + buf.append(StringUtil.simpleClassName(childGroup)); + buf.append(", "); + } + Map, Object> childOptions = childOptions(); + if (!childOptions.isEmpty()) { + buf.append("childOptions: "); + buf.append(childOptions); + buf.append(", "); + } + Map, Object> childAttrs = childAttrs(); + if (!childAttrs.isEmpty()) { + buf.append("childAttrs: "); + buf.append(childAttrs); + buf.append(", "); + } + ChannelHandler childHandler = childHandler(); + if (childHandler != null) { + buf.append("childHandler: "); + buf.append(childHandler); + buf.append(", "); + } + if (buf.charAt(buf.length() - 1) == '(') { + buf.append(')'); + } else { + buf.setCharAt(buf.length() - 2, ')'); + buf.setLength(buf.length() - 1); + } + + return buf.toString(); + } 
+} diff --git a/netty-channel/src/main/java/io/netty/bootstrap/package-info.java b/netty-channel/src/main/java/io/netty/bootstrap/package-info.java new file mode 100644 index 0000000..563688d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/bootstrap/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * The helper classes with fluent API which enable an easy implementation of + * typical client side and server side channel initialization. + */ +package io.netty.bootstrap; diff --git a/netty-channel/src/main/java/io/netty/channel/AbstractChannel.java b/netty-channel/src/main/java/io/netty/channel/AbstractChannel.java new file mode 100644 index 0000000..79a0949 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AbstractChannel.java @@ -0,0 +1,1215 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.channel.socket.ChannelOutputShutdownException; +import io.netty.util.DefaultAttributeMap; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.ConnectException; +import java.net.InetSocketAddress; +import java.net.NoRouteToHostException; +import java.net.SocketAddress; +import java.net.SocketException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NotYetConnectedException; +import java.util.concurrent.Executor; +import java.util.concurrent.RejectedExecutionException; + +/** + * A skeletal {@link Channel} implementation. 
+ */ +public abstract class AbstractChannel extends DefaultAttributeMap implements Channel { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractChannel.class); + + private final Channel parent; + private final ChannelId id; + private final Unsafe unsafe; + private final DefaultChannelPipeline pipeline; + private final VoidChannelPromise unsafeVoidPromise = new VoidChannelPromise(this, false); + private final CloseFuture closeFuture = new CloseFuture(this); + + private volatile SocketAddress localAddress; + private volatile SocketAddress remoteAddress; + private volatile EventLoop eventLoop; + private volatile boolean registered; + private boolean closeInitiated; + private Throwable initialCloseCause; + + /** Cache for the string representation of this channel */ + private boolean strValActive; + private String strVal; + + /** + * Creates a new instance. + * + * @param parent + * the parent of this channel. {@code null} if there's no parent. + */ + protected AbstractChannel(Channel parent) { + this.parent = parent; + id = newId(); + unsafe = newUnsafe(); + pipeline = newChannelPipeline(); + } + + /** + * Creates a new instance. + * + * @param parent + * the parent of this channel. {@code null} if there's no parent. + */ + protected AbstractChannel(Channel parent, ChannelId id) { + this.parent = parent; + this.id = id; + unsafe = newUnsafe(); + pipeline = newChannelPipeline(); + } + + protected final int maxMessagesPerWrite() { + ChannelConfig config = config(); + if (config instanceof DefaultChannelConfig) { + return ((DefaultChannelConfig) config).getMaxMessagesPerWrite(); + } + Integer value = config.getOption(ChannelOption.MAX_MESSAGES_PER_WRITE); + if (value == null) { + return Integer.MAX_VALUE; + } + return value; + } + + @Override + public final ChannelId id() { + return id; + } + + /** + * Returns a new {@link DefaultChannelId} instance. 
Subclasses may override this method to assign custom + * {@link ChannelId}s to {@link Channel}s that use the {@link AbstractChannel#AbstractChannel(Channel)} constructor. + */ + protected ChannelId newId() { + return DefaultChannelId.newInstance(); + } + + /** + * Returns a new {@link DefaultChannelPipeline} instance. + */ + protected DefaultChannelPipeline newChannelPipeline() { + return new DefaultChannelPipeline(this); + } + + @Override + public boolean isWritable() { + ChannelOutboundBuffer buf = unsafe.outboundBuffer(); + return buf != null && buf.isWritable(); + } + + @Override + public long bytesBeforeUnwritable() { + ChannelOutboundBuffer buf = unsafe.outboundBuffer(); + // isWritable() is currently assuming if there is no outboundBuffer then the channel is not writable. + // We should be consistent with that here. + return buf != null ? buf.bytesBeforeUnwritable() : 0; + } + + @Override + public long bytesBeforeWritable() { + ChannelOutboundBuffer buf = unsafe.outboundBuffer(); + // isWritable() is currently assuming if there is no outboundBuffer then the channel is not writable. + // We should be consistent with that here. + return buf != null ? buf.bytesBeforeWritable() : Long.MAX_VALUE; + } + + @Override + public Channel parent() { + return parent; + } + + @Override + public ChannelPipeline pipeline() { + return pipeline; + } + + @Override + public ByteBufAllocator alloc() { + return config().getAllocator(); + } + + @Override + public EventLoop eventLoop() { + EventLoop eventLoop = this.eventLoop; + if (eventLoop == null) { + throw new IllegalStateException("channel not registered to an event loop"); + } + return eventLoop; + } + + @Override + public SocketAddress localAddress() { + SocketAddress localAddress = this.localAddress; + if (localAddress == null) { + try { + this.localAddress = localAddress = unsafe().localAddress(); + } catch (Error e) { + throw e; + } catch (Throwable t) { + // Sometimes fails on a closed socket in Windows. 
+ return null; + } + } + return localAddress; + } + + /** + * @deprecated no use-case for this. + */ + @Deprecated + protected void invalidateLocalAddress() { + localAddress = null; + } + + @Override + public SocketAddress remoteAddress() { + SocketAddress remoteAddress = this.remoteAddress; + if (remoteAddress == null) { + try { + this.remoteAddress = remoteAddress = unsafe().remoteAddress(); + } catch (Error e) { + throw e; + } catch (Throwable t) { + // Sometimes fails on a closed socket in Windows. + return null; + } + } + return remoteAddress; + } + + /** + * @deprecated no use-case for this. + */ + @Deprecated + protected void invalidateRemoteAddress() { + remoteAddress = null; + } + + @Override + public boolean isRegistered() { + return registered; + } + + @Override + public ChannelFuture bind(SocketAddress localAddress) { + return pipeline.bind(localAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress) { + return pipeline.connect(remoteAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return pipeline.connect(remoteAddress, localAddress); + } + + @Override + public ChannelFuture disconnect() { + return pipeline.disconnect(); + } + + @Override + public ChannelFuture close() { + return pipeline.close(); + } + + @Override + public ChannelFuture deregister() { + return pipeline.deregister(); + } + + @Override + public Channel flush() { + pipeline.flush(); + return this; + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return pipeline.bind(localAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return pipeline.connect(remoteAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return pipeline.connect(remoteAddress, localAddress, promise); 
+ } + + @Override + public ChannelFuture disconnect(ChannelPromise promise) { + return pipeline.disconnect(promise); + } + + @Override + public ChannelFuture close(ChannelPromise promise) { + return pipeline.close(promise); + } + + @Override + public ChannelFuture deregister(ChannelPromise promise) { + return pipeline.deregister(promise); + } + + @Override + public Channel read() { + pipeline.read(); + return this; + } + + @Override + public ChannelFuture write(Object msg) { + return pipeline.write(msg); + } + + @Override + public ChannelFuture write(Object msg, ChannelPromise promise) { + return pipeline.write(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg) { + return pipeline.writeAndFlush(msg); + } + + @Override + public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + return pipeline.writeAndFlush(msg, promise); + } + + @Override + public ChannelPromise newPromise() { + return pipeline.newPromise(); + } + + @Override + public ChannelProgressivePromise newProgressivePromise() { + return pipeline.newProgressivePromise(); + } + + @Override + public ChannelFuture newSucceededFuture() { + return pipeline.newSucceededFuture(); + } + + @Override + public ChannelFuture newFailedFuture(Throwable cause) { + return pipeline.newFailedFuture(cause); + } + + @Override + public ChannelFuture closeFuture() { + return closeFuture; + } + + @Override + public Unsafe unsafe() { + return unsafe; + } + + /** + * Create a new {@link AbstractUnsafe} instance which will be used for the life-time of the {@link Channel} + */ + protected abstract AbstractUnsafe newUnsafe(); + + /** + * Returns the ID of this channel. + */ + @Override + public final int hashCode() { + return id.hashCode(); + } + + /** + * Returns {@code true} if and only if the specified object is identical + * with this channel (i.e: {@code this == o}). 
+ */ + @Override + public final boolean equals(Object o) { + return this == o; + } + + @Override + public final int compareTo(Channel o) { + if (this == o) { + return 0; + } + + return id().compareTo(o.id()); + } + + /** + * Returns the {@link String} representation of this channel. The returned + * string contains the {@linkplain #hashCode() ID}, {@linkplain #localAddress() local address}, + * and {@linkplain #remoteAddress() remote address} of this channel for + * easier identification. + */ + @Override + public String toString() { + boolean active = isActive(); + if (strValActive == active && strVal != null) { + return strVal; + } + + SocketAddress remoteAddr = remoteAddress(); + SocketAddress localAddr = localAddress(); + if (remoteAddr != null) { + StringBuilder buf = new StringBuilder(96) + .append("[id: 0x") + .append(id.asShortText()) + .append(", L:") + .append(localAddr) + .append(active? " - " : " ! ") + .append("R:") + .append(remoteAddr) + .append(']'); + strVal = buf.toString(); + } else if (localAddr != null) { + StringBuilder buf = new StringBuilder(64) + .append("[id: 0x") + .append(id.asShortText()) + .append(", L:") + .append(localAddr) + .append(']'); + strVal = buf.toString(); + } else { + StringBuilder buf = new StringBuilder(16) + .append("[id: 0x") + .append(id.asShortText()) + .append(']'); + strVal = buf.toString(); + } + + strValActive = active; + return strVal; + } + + @Override + public final ChannelPromise voidPromise() { + return pipeline.voidPromise(); + } + + /** + * {@link Unsafe} implementation which sub-classes must extend and use. 
+ */ + protected abstract class AbstractUnsafe implements Unsafe { + + private volatile ChannelOutboundBuffer outboundBuffer = new ChannelOutboundBuffer(AbstractChannel.this); + private RecvByteBufAllocator.Handle recvHandle; + private boolean inFlush0; + /** true if the channel has never been registered, false otherwise */ + private boolean neverRegistered = true; + + private void assertEventLoop() { + assert !registered || eventLoop.inEventLoop(); + } + + @Override + public RecvByteBufAllocator.Handle recvBufAllocHandle() { + if (recvHandle == null) { + recvHandle = config().getRecvByteBufAllocator().newHandle(); + } + return recvHandle; + } + + @Override + public final ChannelOutboundBuffer outboundBuffer() { + return outboundBuffer; + } + + @Override + public final SocketAddress localAddress() { + return localAddress0(); + } + + @Override + public final SocketAddress remoteAddress() { + return remoteAddress0(); + } + + @Override + public final void register(EventLoop eventLoop, final ChannelPromise promise) { + ObjectUtil.checkNotNull(eventLoop, "eventLoop"); + if (isRegistered()) { + promise.setFailure(new IllegalStateException("registered to an event loop already")); + return; + } + if (!isCompatible(eventLoop)) { + promise.setFailure( + new IllegalStateException("incompatible event loop type: " + eventLoop.getClass().getName())); + return; + } + + AbstractChannel.this.eventLoop = eventLoop; + + if (eventLoop.inEventLoop()) { + register0(promise); + } else { + try { + eventLoop.execute(new Runnable() { + @Override + public void run() { + register0(promise); + } + }); + } catch (Throwable t) { + logger.warn( + "Force-closing a channel whose registration task was not accepted by an event loop: {}", + AbstractChannel.this, t); + closeForcibly(); + closeFuture.setClosed(); + safeSetFailure(promise, t); + } + } + } + + private void register0(ChannelPromise promise) { + try { + // check if the channel is still open as it could be closed in the mean time when the 
register + // call was outside of the eventLoop + if (!promise.setUncancellable() || !ensureOpen(promise)) { + return; + } + boolean firstRegistration = neverRegistered; + doRegister(); + neverRegistered = false; + registered = true; + + // Ensure we call handlerAdded(...) before we actually notify the promise. This is needed as the + // user may already fire events through the pipeline in the ChannelFutureListener. + pipeline.invokeHandlerAddedIfNeeded(); + + safeSetSuccess(promise); + pipeline.fireChannelRegistered(); + // Only fire a channelActive if the channel has never been registered. This prevents firing + // multiple channel actives if the channel is deregistered and re-registered. + if (isActive()) { + if (firstRegistration) { + pipeline.fireChannelActive(); + } else if (config().isAutoRead()) { + // This channel was registered before and autoRead() is set. This means we need to begin read + // again so that we process inbound data. + // + // See https://github.com/netty/netty/issues/4805 + beginRead(); + } + } + } catch (Throwable t) { + // Close the channel directly to avoid FD leak. + closeForcibly(); + closeFuture.setClosed(); + safeSetFailure(promise, t); + } + } + + @Override + public final void bind(final SocketAddress localAddress, final ChannelPromise promise) { + assertEventLoop(); + + if (!promise.setUncancellable() || !ensureOpen(promise)) { + return; + } + + // See: https://github.com/netty/netty/issues/576 + if (Boolean.TRUE.equals(config().getOption(ChannelOption.SO_BROADCAST)) && + localAddress instanceof InetSocketAddress && + !((InetSocketAddress) localAddress).getAddress().isAnyLocalAddress() && + !PlatformDependent.isWindows() && !PlatformDependent.maybeSuperUser()) { + // Warn a user about the fact that a non-root user can't receive a + // broadcast packet on *nix if the socket is bound on non-wildcard address. 
+ logger.warn( + "A non-root user can't receive a broadcast packet if the socket " + + "is not bound to a wildcard address; binding to a non-wildcard " + + "address (" + localAddress + ") anyway as requested."); + } + + boolean wasActive = isActive(); + try { + doBind(localAddress); + } catch (Throwable t) { + safeSetFailure(promise, t); + closeIfClosed(); + return; + } + + if (!wasActive && isActive()) { + invokeLater(new Runnable() { + @Override + public void run() { + pipeline.fireChannelActive(); + } + }); + } + + safeSetSuccess(promise); + } + + @Override + public final void disconnect(final ChannelPromise promise) { + assertEventLoop(); + + if (!promise.setUncancellable()) { + return; + } + + boolean wasActive = isActive(); + try { + doDisconnect(); + // Reset remoteAddress and localAddress + remoteAddress = null; + localAddress = null; + } catch (Throwable t) { + safeSetFailure(promise, t); + closeIfClosed(); + return; + } + + if (wasActive && !isActive()) { + invokeLater(new Runnable() { + @Override + public void run() { + pipeline.fireChannelInactive(); + } + }); + } + + safeSetSuccess(promise); + closeIfClosed(); // doDisconnect() might have closed the channel + } + + @Override + public void close(final ChannelPromise promise) { + assertEventLoop(); + + ClosedChannelException closedChannelException = + StacklessClosedChannelException.newInstance(AbstractChannel.class, "close(ChannelPromise)"); + close(promise, closedChannelException, closedChannelException, false); + } + + /** + * Shutdown the output portion of the corresponding {@link Channel}. + * For example this will clean up the {@link ChannelOutboundBuffer} and not allow any more writes. + */ + @UnstableApi + public final void shutdownOutput(final ChannelPromise promise) { + assertEventLoop(); + shutdownOutput(promise, null); + } + + /** + * Shutdown the output portion of the corresponding {@link Channel}. 
         * For example this will clean up the {@link ChannelOutboundBuffer} and not allow any more writes.
         * @param cause The cause which may provide rationale for the shutdown.
         */
        private void shutdownOutput(final ChannelPromise promise, Throwable cause) {
            if (!promise.setUncancellable()) {
                return;
            }

            final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
            if (outboundBuffer == null) {
                promise.setFailure(new ClosedChannelException());
                return;
            }
            this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer.

            final Throwable shutdownCause = cause == null ?
                    new ChannelOutputShutdownException("Channel output shutdown") :
                    new ChannelOutputShutdownException("Channel output shutdown", cause);

            // When a side enables SO_LINGER and calls shutdownOutput(...) to start TCP half-closure
            // we can not call doDeregister here because we should ensure this side in fin_wait2 state
            // can still receive and process the data which is sent by the other side in the close_wait state.
            // See https://github.com/netty/netty/issues/11981
            try {
                // The shutdown function does not block regardless of the SO_LINGER setting on the socket
                // so we don't need to use GlobalEventExecutor to execute the shutdown
                doShutdownOutput();
                promise.setSuccess();
            } catch (Throwable err) {
                promise.setFailure(err);
            } finally {
                closeOutboundBufferForShutdown(pipeline, outboundBuffer, shutdownCause);
            }
        }

        // Fails and closes the (already detached) outbound buffer with the shutdown cause and then
        // notifies the pipeline that the output side of the channel was shut down.
        private void closeOutboundBufferForShutdown(
                ChannelPipeline pipeline, ChannelOutboundBuffer buffer, Throwable cause) {
            buffer.failFlushed(cause, false);
            buffer.close(cause, true);
            pipeline.fireUserEventTriggered(ChannelOutputShutdownEvent.INSTANCE);
        }

        /**
         * Closes the channel, failing all queued writes. The first call performs the actual close;
         * subsequent calls either complete immediately (already closed) or just attach a listener to
         * the pending close future.
         *
         * @param promise    notified once the close completed
         * @param cause      used to fail the flushed messages
         * @param closeCause used to close the outbound buffer
         * @param notify     whether failing the flushed messages may fire channelWritabilityChanged
         */
        private void close(final ChannelPromise promise, final Throwable cause,
                           final ClosedChannelException closeCause, final boolean notify) {
            if (!promise.setUncancellable()) {
                return;
            }

            if (closeInitiated) {
                if (closeFuture.isDone()) {
                    // Closed already.
                    safeSetSuccess(promise);
                } else if (!(promise instanceof VoidChannelPromise)) { // Only needed if no VoidChannelPromise.
                    // This means close() was called before so we just register a listener and return
                    closeFuture.addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture future) throws Exception {
                            promise.setSuccess();
                        }
                    });
                }
                return;
            }

            closeInitiated = true;

            final boolean wasActive = isActive();
            final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
            this.outboundBuffer = null; // Disallow adding any messages and flushes to outboundBuffer.
            Executor closeExecutor = prepareToClose();
            if (closeExecutor != null) {
                closeExecutor.execute(new Runnable() {
                    @Override
                    public void run() {
                        try {
                            // Execute the close.
                            doClose0(promise);
                        } finally {
                            // Call invokeLater so closeAndDeregister is executed in the EventLoop again!
                            invokeLater(new Runnable() {
                                @Override
                                public void run() {
                                    if (outboundBuffer != null) {
                                        // Fail all the queued messages
                                        outboundBuffer.failFlushed(cause, notify);
                                        outboundBuffer.close(closeCause);
                                    }
                                    fireChannelInactiveAndDeregister(wasActive);
                                }
                            });
                        }
                    }
                });
            } else {
                try {
                    // Close the channel and fail the queued messages in all cases.
                    doClose0(promise);
                } finally {
                    if (outboundBuffer != null) {
                        // Fail all the queued messages.
                        outboundBuffer.failFlushed(cause, notify);
                        outboundBuffer.close(closeCause);
                    }
                }
                if (inFlush0) {
                    // Don't fire inactive/deregister from inside a flush; defer to the event loop.
                    invokeLater(new Runnable() {
                        @Override
                        public void run() {
                            fireChannelInactiveAndDeregister(wasActive);
                        }
                    });
                } else {
                    fireChannelInactiveAndDeregister(wasActive);
                }
            }
        }

        // Performs the transport close and completes the promise. The close future is marked closed
        // whether or not doClose() threw.
        private void doClose0(ChannelPromise promise) {
            try {
                doClose();
                closeFuture.setClosed();
                safeSetSuccess(promise);
            } catch (Throwable t) {
                closeFuture.setClosed();
                safeSetFailure(promise, t);
            }
        }

        // Only fire channelInactive if the channel transitioned from active to inactive.
        private void fireChannelInactiveAndDeregister(final boolean wasActive) {
            deregister(voidPromise(), wasActive && !isActive());
        }

        @Override
        public final void closeForcibly() {
            assertEventLoop();

            try {
                doClose();
            } catch (Exception e) {
                logger.warn("Failed to close a channel.", e);
            }
        }

        @Override
        public final void deregister(final ChannelPromise promise) {
            assertEventLoop();

            deregister(promise, false);
        }

        private void deregister(final ChannelPromise promise, final boolean fireChannelInactive) {
            if (!promise.setUncancellable()) {
                return;
            }

            if (!registered) {
                safeSetSuccess(promise);
                return;
            }

            // As a user may call deregister() from within any method while doing processing in the ChannelPipeline,
            // we need to ensure we do the actual deregister operation later. This is needed as for example,
            // we may be in the ByteToMessageDecoder.callDecode(...) method and so still try to do processing in
            // the old EventLoop while the user already registered the Channel to a new EventLoop. Without delaying
            // the deregister operation this could lead to having a handler invoked by different EventLoops (and so
            // threads).
            //
            // See:
            // https://github.com/netty/netty/issues/4435
            invokeLater(new Runnable() {
                @Override
                public void run() {
                    try {
                        doDeregister();
                    } catch (Throwable t) {
                        logger.warn("Unexpected exception occurred while deregistering a channel.", t);
                    } finally {
                        if (fireChannelInactive) {
                            pipeline.fireChannelInactive();
                        }
                        // Some transports like local and AIO do not allow the deregistration of
                        // an open channel. Their doDeregister() calls close(). Consequently,
                        // close() calls deregister() again - no need to fire channelUnregistered, so check
                        // if it was registered.
                        if (registered) {
                            registered = false;
                            pipeline.fireChannelUnregistered();
                        }
                        safeSetSuccess(promise);
                    }
                }
            });
        }

        @Override
        public final void beginRead() {
            assertEventLoop();

            try {
                doBeginRead();
            } catch (final Exception e) {
                // Report the failure through the pipeline from within the event loop, then close.
                invokeLater(new Runnable() {
                    @Override
                    public void run() {
                        pipeline.fireExceptionCaught(e);
                    }
                });
                close(voidPromise());
            }
        }

        @Override
        public final void write(Object msg, ChannelPromise promise) {
            assertEventLoop();

            ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
            if (outboundBuffer == null) {
                try {
                    // release message now to prevent resource-leak
                    ReferenceCountUtil.release(msg);
                } finally {
                    // If the outboundBuffer is null we know the channel was closed and so
                    // need to fail the future right away. If it is not null the handling of the rest
                    // will be done in flush0()
                    // See https://github.com/netty/netty/issues/2362
                    safeSetFailure(promise,
                            newClosedChannelException(initialCloseCause, "write(Object, ChannelPromise)"));
                }
                return;
            }

            int size;
            try {
                msg = filterOutboundMessage(msg);
                size = pipeline.estimatorHandle().size(msg);
                if (size < 0) {
                    size = 0;
                }
            } catch (Throwable t) {
                // filterOutboundMessage/size estimation failed: release the message and fail the promise.
                try {
                    ReferenceCountUtil.release(msg);
                } finally {
                    safeSetFailure(promise, t);
                }
                return;
            }

            outboundBuffer.addMessage(msg, size, promise);
        }

        @Override
        public final void flush() {
            assertEventLoop();

            ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
            if (outboundBuffer == null) {
                return;
            }

            outboundBuffer.addFlush();
            flush0();
        }

        @SuppressWarnings("deprecation")
        protected void flush0() {
            if (inFlush0) {
                // Avoid re-entrance
                return;
            }

            final ChannelOutboundBuffer outboundBuffer = this.outboundBuffer;
            if (outboundBuffer == null || outboundBuffer.isEmpty()) {
                return;
            }

            inFlush0 = true;

            // Mark all pending write requests as failure if the channel is inactive.
            if (!isActive()) {
                try {
                    // Check if we need to generate the exception at all.
                    if (!outboundBuffer.isEmpty()) {
                        if (isOpen()) {
                            outboundBuffer.failFlushed(new NotYetConnectedException(), true);
                        } else {
                            // Do not trigger channelWritabilityChanged because the channel is closed already.
                            outboundBuffer.failFlushed(newClosedChannelException(initialCloseCause, "flush0()"), false);
                        }
                    }
                } finally {
                    inFlush0 = false;
                }
                return;
            }

            try {
                doWrite(outboundBuffer);
            } catch (Throwable t) {
                handleWriteError(t);
            } finally {
                inFlush0 = false;
            }
        }

        protected final void handleWriteError(Throwable t) {
            if (t instanceof IOException && config().isAutoClose()) {
                /*
                 * Just call {@link #close(ChannelPromise, Throwable, boolean)} here which will take care of
                 * failing all flushed messages and also ensure the actual close of the underlying transport
                 * will happen before the promises are notified.
                 *
                 * This is needed as otherwise {@link #isActive()} , {@link #isOpen()} and {@link #isWritable()}
                 * may still return {@code true} even if the channel should be closed as result of the exception.
                 */
                initialCloseCause = t;
                close(voidPromise(), t, newClosedChannelException(t, "flush0()"), false);
            } else {
                // Otherwise only shut down the output side; fall back to a full close if that fails too.
                try {
                    shutdownOutput(voidPromise(), t);
                } catch (Throwable t2) {
                    initialCloseCause = t;
                    close(voidPromise(), t2, newClosedChannelException(t, "flush0()"), false);
                }
            }
        }

        // Creates a stack-trace-free ClosedChannelException that points at the given method,
        // chaining the original cause when one is known.
        private ClosedChannelException newClosedChannelException(Throwable cause, String method) {
            ClosedChannelException exception =
                    StacklessClosedChannelException.newInstance(AbstractChannel.AbstractUnsafe.class, method);
            if (cause != null) {
                exception.initCause(cause);
            }
            return exception;
        }

        @Override
        public final ChannelPromise voidPromise() {
            assertEventLoop();

            return unsafeVoidPromise;
        }

        // Returns true if the channel is open; otherwise fails the promise with a
        // ClosedChannelException (chaining the initial close cause) and returns false.
        protected final boolean ensureOpen(ChannelPromise promise) {
            if (isOpen()) {
                return true;
            }

            safeSetFailure(promise, newClosedChannelException(initialCloseCause, "ensureOpen(ChannelPromise)"));
            return false;
        }

        /**
         * Marks the specified {@code promise} as success. If the {@code promise} is done already, log a message.
         */
        protected final void safeSetSuccess(ChannelPromise promise) {
            if (!(promise instanceof VoidChannelPromise) && !promise.trySuccess()) {
                logger.warn("Failed to mark a promise as success because it is done already: {}", promise);
            }
        }

        /**
         * Marks the specified {@code promise} as failure. If the {@code promise} is done already, log a message.
         */
        protected final void safeSetFailure(ChannelPromise promise, Throwable cause) {
            if (!(promise instanceof VoidChannelPromise) && !promise.tryFailure(cause)) {
                logger.warn("Failed to mark a promise as failure because it's done already: {}", promise, cause);
            }
        }

        // Closes the channel (with a void promise) if it is no longer open.
        protected final void closeIfClosed() {
            if (isOpen()) {
                return;
            }
            close(voidPromise());
        }

        private void invokeLater(Runnable task) {
            try {
                // This method is used by outbound operation implementations to trigger an inbound event later.
                // They do not trigger an inbound event immediately because an outbound operation might have been
                // triggered by another inbound event handler method. If fired immediately, the call stack
                // will look like this for example:
                //
                //   handlerA.inboundBufferUpdated() - (1) an inbound handler method closes a connection.
                //   -> handlerA.ctx.close()
                //      -> channel.unsafe.close()
                //         -> handlerA.channelInactive() - (2) another inbound handler method called while in (1) yet
                //
                // which means the execution of two inbound handler methods of the same handler overlap undesirably.
                eventLoop().execute(task);
            } catch (RejectedExecutionException e) {
                logger.warn("Can't invoke task later as EventLoop rejected it", e);
            }
        }

        /**
         * Appends the remote address to the message of the exceptions caused by connection attempt failure.
         */
        protected final Throwable annotateConnectException(Throwable cause, SocketAddress remoteAddress) {
            if (cause instanceof ConnectException) {
                return new AnnotatedConnectException((ConnectException) cause, remoteAddress);
            }
            if (cause instanceof NoRouteToHostException) {
                return new AnnotatedNoRouteToHostException((NoRouteToHostException) cause, remoteAddress);
            }
            if (cause instanceof SocketException) {
                return new AnnotatedSocketException((SocketException) cause, remoteAddress);
            }

            return cause;
        }

        /**
         * Prepares to close the {@link Channel}. If this method returns an {@link Executor}, the
         * caller must call the {@link Executor#execute(Runnable)} method with a task that calls
         * {@link #doClose()} on the returned {@link Executor}. If this method returns {@code null},
         * {@link #doClose()} must be called from the caller thread. (i.e. {@link EventLoop})
         */
        protected Executor prepareToClose() {
            return null;
        }
    }

    /**
     * Return {@code true} if the given {@link EventLoop} is compatible with this instance.
     */
    protected abstract boolean isCompatible(EventLoop loop);

    /**
     * Returns the {@link SocketAddress} which is bound locally.
     */
    protected abstract SocketAddress localAddress0();

    /**
     * Return the {@link SocketAddress} which the {@link Channel} is connected to.
     */
    protected abstract SocketAddress remoteAddress0();

    /**
     * Is called after the {@link Channel} is registered with its {@link EventLoop} as part of the register process.
     *
     * Sub-classes may override this method
     */
    protected void doRegister() throws Exception {
        // NOOP
    }

    /**
     * Bind the {@link Channel} to the {@link SocketAddress}
     */
    protected abstract void doBind(SocketAddress localAddress) throws Exception;

    /**
     * Disconnect this {@link Channel} from its remote peer
     */
    protected abstract void doDisconnect() throws Exception;

    /**
     * Close the {@link Channel}
     */
    protected abstract void doClose() throws Exception;

    /**
     * Called when conditions justify shutting down the output portion of the channel. This may happen if a write
     * operation throws an exception.
     */
    @UnstableApi
    protected void doShutdownOutput() throws Exception {
        // By default a transport that cannot half-close simply closes completely.
        doClose();
    }

    /**
     * Deregister the {@link Channel} from its {@link EventLoop}.
     *
     * Sub-classes may override this method
     */
    protected void doDeregister() throws Exception {
        // NOOP
    }

    /**
     * Schedule a read operation.
     */
    protected abstract void doBeginRead() throws Exception;

    /**
     * Flush the content of the given buffer to the remote peer.
     */
    protected abstract void doWrite(ChannelOutboundBuffer in) throws Exception;

    /**
     * Invoked when a new message is added to a {@link ChannelOutboundBuffer} of this {@link AbstractChannel}, so that
     * the {@link Channel} implementation converts the message to another. (e.g. heap buffer -> direct buffer)
     */
    protected Object filterOutboundMessage(Object msg) throws Exception {
        return msg;
    }

    // Delegates to DefaultFileRegion.validate; subclasses may call this before transferring a region.
    protected void validateFileRegion(DefaultFileRegion region, long position) throws IOException {
        DefaultFileRegion.validate(region, position);
    }

    // Close future that can only be completed via the package-private setClosed(); the public
    // success/failure mutators are disabled on purpose so only the channel itself closes it.
    static final class CloseFuture extends DefaultChannelPromise {

        CloseFuture(AbstractChannel ch) {
            super(ch);
        }

        @Override
        public ChannelPromise setSuccess() {
            throw new IllegalStateException();
        }

        @Override
        public ChannelPromise setFailure(Throwable cause) {
            throw new IllegalStateException();
        }

        @Override
        public boolean trySuccess() {
            throw new IllegalStateException();
        }

        @Override
        public boolean tryFailure(Throwable cause) {
            throw new IllegalStateException();
        }

        boolean setClosed() {
            return super.trySuccess();
        }
    }

    // ConnectException with the remote address appended to the message and no stack trace of its own.
    private static final class AnnotatedConnectException extends ConnectException {

        private static final long serialVersionUID = 3901958112696433556L;

        AnnotatedConnectException(ConnectException exception, SocketAddress remoteAddress) {
            super(exception.getMessage() + ": " + remoteAddress);
            initCause(exception);
        }

        // Suppress a warning since this method doesn't need synchronization
        @Override
        public Throwable fillInStackTrace() {
            return this;
        }
    }

    // NoRouteToHostException with the remote address appended to the message and no stack trace of its own.
    private static final class AnnotatedNoRouteToHostException extends NoRouteToHostException {

        private static final long serialVersionUID = -6801433937592080623L;

        AnnotatedNoRouteToHostException(NoRouteToHostException exception, SocketAddress remoteAddress) {
            super(exception.getMessage() + ": " + remoteAddress);
            initCause(exception);
        }

        // Suppress a warning since this method doesn't need synchronization
        @Override
        public Throwable fillInStackTrace() {
            return this;
        }
    }

    // SocketException with the remote address appended to the message and no stack trace of its own.
    private static final class AnnotatedSocketException extends SocketException {

        private static final long serialVersionUID = 3896743275010454039L;

AnnotatedSocketException(SocketException exception, SocketAddress remoteAddress) { + super(exception.getMessage() + ": " + remoteAddress); + initCause(exception); + } + + // Suppress a warning since this method doesn't need synchronization + @Override + public Throwable fillInStackTrace() { + return this; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java b/netty-channel/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java new file mode 100644 index 0000000..d3953f4 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AbstractChannelHandlerContext.java @@ -0,0 +1,1324 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ResourceLeakHint; +import io.netty.util.concurrent.AbstractEventExecutor; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.OrderedEventExecutor; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.PromiseNotificationUtil; +import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.SocketAddress; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; + +import static io.netty.channel.ChannelHandlerMask.MASK_BIND; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_ACTIVE; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_INACTIVE; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_READ; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_READ_COMPLETE; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_REGISTERED; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_UNREGISTERED; +import static io.netty.channel.ChannelHandlerMask.MASK_CHANNEL_WRITABILITY_CHANGED; +import static io.netty.channel.ChannelHandlerMask.MASK_CLOSE; +import static io.netty.channel.ChannelHandlerMask.MASK_CONNECT; +import static io.netty.channel.ChannelHandlerMask.MASK_DEREGISTER; +import static io.netty.channel.ChannelHandlerMask.MASK_DISCONNECT; +import static io.netty.channel.ChannelHandlerMask.MASK_EXCEPTION_CAUGHT; +import static io.netty.channel.ChannelHandlerMask.MASK_FLUSH; +import 
static io.netty.channel.ChannelHandlerMask.MASK_ONLY_INBOUND; +import static io.netty.channel.ChannelHandlerMask.MASK_ONLY_OUTBOUND; +import static io.netty.channel.ChannelHandlerMask.MASK_READ; +import static io.netty.channel.ChannelHandlerMask.MASK_USER_EVENT_TRIGGERED; +import static io.netty.channel.ChannelHandlerMask.MASK_WRITE; +import static io.netty.channel.ChannelHandlerMask.mask; + +abstract class AbstractChannelHandlerContext implements ChannelHandlerContext, ResourceLeakHint { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractChannelHandlerContext.class); + volatile AbstractChannelHandlerContext next; + volatile AbstractChannelHandlerContext prev; + + private static final AtomicIntegerFieldUpdater HANDLER_STATE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(AbstractChannelHandlerContext.class, "handlerState"); + + /** + * {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} is about to be called. + */ + private static final int ADD_PENDING = 1; + /** + * {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} was called. + */ + private static final int ADD_COMPLETE = 2; + /** + * {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} was called. + */ + private static final int REMOVE_COMPLETE = 3; + /** + * Neither {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} + * nor {@link ChannelHandler#handlerRemoved(ChannelHandlerContext)} was called. + */ + private static final int INIT = 0; + + private final DefaultChannelPipeline pipeline; + private final String name; + private final boolean ordered; + private final int executionMask; + + // Will be set to null if no child executor should be used, otherwise it will be set to the + // child executor. + final EventExecutor executor; + private ChannelFuture succeededFuture; + + // Lazily instantiated tasks used to trigger events to a handler with different executor. 
    // There is no need to make this volatile as at worse it will just create a few more instances then needed.
    private Tasks invokeTasks;

    private volatile int handlerState = INIT;

    AbstractChannelHandlerContext(DefaultChannelPipeline pipeline, EventExecutor executor,
                                  String name, Class handlerClass) {
        // NOTE(review): handlerClass is a raw Class here; presumably it should carry a type
        // parameter — confirm against ChannelHandlerMask.mask(...) before changing.
        this.name = ObjectUtil.checkNotNull(name, "name");
        this.pipeline = pipeline;
        this.executor = executor;
        this.executionMask = mask(handlerClass);
        // It's ordered if it's driven by the EventLoop or the given Executor is an instance of OrderedEventExecutor.
        ordered = executor == null || executor instanceof OrderedEventExecutor;
    }

    @Override
    public Channel channel() {
        return pipeline.channel();
    }

    @Override
    public ChannelPipeline pipeline() {
        return pipeline;
    }

    @Override
    public ByteBufAllocator alloc() {
        return channel().config().getAllocator();
    }

    @Override
    public EventExecutor executor() {
        // Fall back to the channel's event loop when no child executor was assigned.
        if (executor == null) {
            return channel().eventLoop();
        } else {
            return executor;
        }
    }

    @Override
    public String name() {
        return name;
    }

    @Override
    public ChannelHandlerContext fireChannelRegistered() {
        invokeChannelRegistered(findContextInbound(MASK_CHANNEL_REGISTERED));
        return this;
    }

    // Dispatches inline when already on next's executor, otherwise schedules the invocation on it.
    static void invokeChannelRegistered(final AbstractChannelHandlerContext next) {
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelRegistered();
        } else {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    next.invokeChannelRegistered();
                }
            });
        }
    }

    private void invokeChannelRegistered() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelRegistered(this);
                } else if (handler instanceof ChannelInboundHandlerAdapter) {
                    ((ChannelInboundHandlerAdapter) handler).channelRegistered(this);
                } else {
                    ((ChannelInboundHandler) handler).channelRegistered(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            // Handler not (yet/anymore) usable: forward the event to the next inbound context.
            fireChannelRegistered();
        }
    }

    @Override
    public ChannelHandlerContext fireChannelUnregistered() {
        invokeChannelUnregistered(findContextInbound(MASK_CHANNEL_UNREGISTERED));
        return this;
    }

    // Dispatches inline when already on next's executor, otherwise schedules the invocation on it.
    static void invokeChannelUnregistered(final AbstractChannelHandlerContext next) {
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelUnregistered();
        } else {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    next.invokeChannelUnregistered();
                }
            });
        }
    }

    private void invokeChannelUnregistered() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelUnregistered(this);
                } else if (handler instanceof ChannelInboundHandlerAdapter) {
                    ((ChannelInboundHandlerAdapter) handler).channelUnregistered(this);
                } else {
                    ((ChannelInboundHandler) handler).channelUnregistered(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireChannelUnregistered();
        }
    }

    @Override
    public ChannelHandlerContext fireChannelActive() {
        invokeChannelActive(findContextInbound(MASK_CHANNEL_ACTIVE));
        return this;
    }

    // Dispatches inline when already on next's executor, otherwise schedules the invocation on it.
    static void invokeChannelActive(final AbstractChannelHandlerContext next) {
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelActive();
        } else {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    next.invokeChannelActive();
                }
            });
        }
    }

    private void invokeChannelActive() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelActive(this);
                } else if (handler instanceof ChannelInboundHandlerAdapter) {
                    ((ChannelInboundHandlerAdapter) handler).channelActive(this);
                } else {
                    ((ChannelInboundHandler) handler).channelActive(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireChannelActive();
        }
    }

    @Override
    public ChannelHandlerContext fireChannelInactive() {
        invokeChannelInactive(findContextInbound(MASK_CHANNEL_INACTIVE));
        return this;
    }

    // Dispatches inline when already on next's executor, otherwise schedules the invocation on it.
    static void invokeChannelInactive(final AbstractChannelHandlerContext next) {
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelInactive();
        } else {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    next.invokeChannelInactive();
                }
            });
        }
    }

    private void invokeChannelInactive() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelInactive(this);
                } else if (handler instanceof ChannelInboundHandlerAdapter) {
                    ((ChannelInboundHandlerAdapter) handler).channelInactive(this);
                } else {
                    ((ChannelInboundHandler) handler).channelInactive(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireChannelInactive();
        }
    }

    @Override
    public
ChannelHandlerContext fireExceptionCaught(final Throwable cause) {
        invokeExceptionCaught(findContextInbound(MASK_EXCEPTION_CAUGHT), cause);
        return this;
    }

    // Delivers exceptionCaught to {@code next}. Unlike the other invoke* statics, the executor.execute()
    // call itself is guarded: if submission fails there is no further handler to notify, so we can only log.
    static void invokeExceptionCaught(final AbstractChannelHandlerContext next, final Throwable cause) {
        ObjectUtil.checkNotNull(cause, "cause");
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeExceptionCaught(cause);
        } else {
            try {
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        next.invokeExceptionCaught(cause);
                    }
                });
            } catch (Throwable t) {
                if (logger.isWarnEnabled()) {
                    logger.warn("Failed to submit an exceptionCaught() event.", t);
                    logger.warn("The exceptionCaught() event that was failed to submit was:", cause);
                }
            }
        }
    }

    // Invokes the handler's exceptionCaught(). An exception thrown from exceptionCaught() itself is
    // deliberately not re-fired (that could loop forever); it is only logged, with the original cause.
    private void invokeExceptionCaught(final Throwable cause) {
        if (invokeHandler()) {
            try {
                handler().exceptionCaught(this, cause);
            } catch (Throwable error) {
                if (logger.isDebugEnabled()) {
                    logger.debug(
                        "An exception {}" +
                        "was thrown by a user handler's exceptionCaught() " +
                        "method while handling the following exception:",
                        ThrowableUtil.stackTraceToString(error), cause);
                } else if (logger.isWarnEnabled()) {
                    logger.warn(
                        "An exception '{}' [enable DEBUG level for full stacktrace] " +
                        "was thrown by a user handler's exceptionCaught() " +
                        "method while handling the following exception:", error, cause);
                }
            }
        } else {
            fireExceptionCaught(cause);
        }
    }

    @Override
    public ChannelHandlerContext fireUserEventTriggered(final Object event) {
        invokeUserEventTriggered(findContextInbound(MASK_USER_EVENT_TRIGGERED), event);
        return this;
    }

    // Delivers the user event to {@code next} on its executor (inline if already on the event loop).
    static void invokeUserEventTriggered(final AbstractChannelHandlerContext next, final Object event) {
        ObjectUtil.checkNotNull(event, "event");
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeUserEventTriggered(event);
        } else {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    next.invokeUserEventTriggered(event);
                }
            });
        }
    }

    private void invokeUserEventTriggered(Object event) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.userEventTriggered(this, event);
                } else if (handler instanceof ChannelInboundHandlerAdapter) {
                    ((ChannelInboundHandlerAdapter) handler).userEventTriggered(this, event);
                } else {
                    ((ChannelInboundHandler) handler).userEventTriggered(this, event);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireUserEventTriggered(event);
        }
    }

    @Override
    public ChannelHandlerContext fireChannelRead(final Object msg) {
        invokeChannelRead(findContextInbound(MASK_CHANNEL_READ), msg);
        return this;
    }

    // Delivers channelRead to {@code next}. touch() records {@code next} as the last hop for
    // leak-detection bookkeeping before the (possibly asynchronous) hand-off.
    static void invokeChannelRead(final AbstractChannelHandlerContext next, Object msg) {
        final Object m = next.pipeline.touch(ObjectUtil.checkNotNull(msg, "msg"), next);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelRead(m);
        } else {
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    next.invokeChannelRead(m);
                }
            });
        }
    }

    private void invokeChannelRead(Object msg) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelRead(this, msg);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).channelRead(this, msg);
                } else {
                    ((ChannelInboundHandler)
handler).channelRead(this, msg);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireChannelRead(msg);
        }
    }

    @Override
    public ChannelHandlerContext fireChannelReadComplete() {
        invokeChannelReadComplete(findContextInbound(MASK_CHANNEL_READ_COMPLETE));
        return this;
    }

    // Delivers channelReadComplete to {@code next}. This event fires once per read loop, so the
    // off-loop Runnable is cached in a lazily-created Tasks holder instead of allocating each time.
    static void invokeChannelReadComplete(final AbstractChannelHandlerContext next) {
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelReadComplete();
        } else {
            Tasks tasks = next.invokeTasks;
            if (tasks == null) {
                next.invokeTasks = tasks = new Tasks(next);
            }
            executor.execute(tasks.invokeChannelReadCompleteTask);
        }
    }

    private void invokeChannelReadComplete() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelReadComplete(this);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).channelReadComplete(this);
                } else {
                    ((ChannelInboundHandler) handler).channelReadComplete(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireChannelReadComplete();
        }
    }

    @Override
    public ChannelHandlerContext fireChannelWritabilityChanged() {
        invokeChannelWritabilityChanged(findContextInbound(MASK_CHANNEL_WRITABILITY_CHANGED));
        return this;
    }

    // Delivers channelWritabilityChanged to {@code next}, reusing the cached Runnable from Tasks
    // when the hand-off must hop executors (same pattern as invokeChannelReadComplete above).
    static void invokeChannelWritabilityChanged(final AbstractChannelHandlerContext next) {
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeChannelWritabilityChanged();
        } else {
            Tasks tasks = next.invokeTasks;
            if (tasks == null) {
                next.invokeTasks = tasks = new Tasks(next);
            }
            executor.execute(tasks.invokeChannelWritableStateChangedTask);
        }
    }

    private void invokeChannelWritabilityChanged() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.channelWritabilityChanged(this);
                } else if (handler instanceof ChannelInboundHandlerAdapter) {
                    ((ChannelInboundHandlerAdapter) handler).channelWritabilityChanged(this);
                } else {
                    ((ChannelInboundHandler) handler).channelWritabilityChanged(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            fireChannelWritabilityChanged();
        }
    }

    // The promise-less outbound operations below simply delegate to their promise-taking twins
    // with a freshly allocated promise.

    @Override
    public ChannelFuture bind(SocketAddress localAddress) {
        return bind(localAddress, newPromise());
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress) {
        return connect(remoteAddress, newPromise());
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) {
        return connect(remoteAddress, localAddress, newPromise());
    }

    @Override
    public ChannelFuture disconnect() {
        return disconnect(newPromise());
    }

    @Override
    public ChannelFuture close() {
        return close(newPromise());
    }

    @Override
    public ChannelFuture deregister() {
        return deregister(newPromise());
    }

    @Override
    public ChannelFuture bind(final SocketAddress localAddress, final ChannelPromise promise) {
        ObjectUtil.checkNotNull(localAddress, "localAddress");
        if (isNotValidPromise(promise, false)) {
            // cancelled
            return promise;
        }

        final AbstractChannelHandlerContext next = findContextOutbound(MASK_BIND);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeBind(localAddress, promise);
        } else {
            // safeExecute fails the promise (and releases msg, null here) if submission is rejected.
            safeExecute(executor, new Runnable() {
                @Override
                public void run() {
                    next.invokeBind(localAddress, promise);
                }
            }, promise, null, false);
        }
        return promise;
    }

    private void invokeBind(SocketAddress localAddress, ChannelPromise promise) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.bind(this, localAddress, promise);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).bind(this, localAddress, promise);
                } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                    ((ChannelOutboundHandlerAdapter) handler).bind(this, localAddress, promise);
                } else {
                    ((ChannelOutboundHandler) handler).bind(this, localAddress, promise);
                }
            } catch (Throwable t) {
                // Outbound failures are reported through the promise, not via exceptionCaught.
                notifyOutboundHandlerException(t, promise);
            }
        } else {
            bind(localAddress, promise);
        }
    }

    @Override
    public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) {
        return connect(remoteAddress, null, promise);
    }

    @Override
    public ChannelFuture connect(
            final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) {
        ObjectUtil.checkNotNull(remoteAddress, "remoteAddress");

        if (isNotValidPromise(promise, false)) {
            // cancelled
            return promise;
        }

        final AbstractChannelHandlerContext next = findContextOutbound(MASK_CONNECT);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeConnect(remoteAddress, localAddress, promise);
        } else {
            safeExecute(executor, new Runnable() {
                @Override
                public void run() {
                    next.invokeConnect(remoteAddress, localAddress, promise);
                }
            }, promise, null, false);
        }
        return promise;
    }

    private void invokeConnect(SocketAddress remoteAddress, SocketAddress localAddress,
ChannelPromise promise) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.connect(this, remoteAddress, localAddress, promise);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).connect(this, remoteAddress, localAddress, promise);
                } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                    ((ChannelOutboundHandlerAdapter) handler).connect(this, remoteAddress, localAddress, promise);
                } else {
                    ((ChannelOutboundHandler) handler).connect(this, remoteAddress, localAddress, promise);
                }
            } catch (Throwable t) {
                notifyOutboundHandlerException(t, promise);
            }
        } else {
            connect(remoteAddress, localAddress, promise);
        }
    }

    @Override
    public ChannelFuture disconnect(final ChannelPromise promise) {
        if (!channel().metadata().hasDisconnect()) {
            // Translate disconnect to close if the channel has no notion of disconnect-reconnect.
            // So far, UDP/IP is the only transport that has such behavior.
            return close(promise);
        }
        if (isNotValidPromise(promise, false)) {
            // cancelled
            return promise;
        }

        final AbstractChannelHandlerContext next = findContextOutbound(MASK_DISCONNECT);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeDisconnect(promise);
        } else {
            safeExecute(executor, new Runnable() {
                @Override
                public void run() {
                    next.invokeDisconnect(promise);
                }
            }, promise, null, false);
        }
        return promise;
    }

    private void invokeDisconnect(ChannelPromise promise) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.disconnect(this, promise);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).disconnect(this, promise);
                } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                    ((ChannelOutboundHandlerAdapter) handler).disconnect(this, promise);
                } else {
                    ((ChannelOutboundHandler) handler).disconnect(this, promise);
                }
            } catch (Throwable t) {
                notifyOutboundHandlerException(t, promise);
            }
        } else {
            disconnect(promise);
        }
    }

    @Override
    public ChannelFuture close(final ChannelPromise promise) {
        if (isNotValidPromise(promise, false)) {
            // cancelled
            return promise;
        }

        final AbstractChannelHandlerContext next = findContextOutbound(MASK_CLOSE);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeClose(promise);
        } else {
            safeExecute(executor, new Runnable() {
                @Override
                public void run() {
                    next.invokeClose(promise);
                }
            }, promise, null, false);
        }

        return promise;
    }

    private void invokeClose(ChannelPromise promise) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.close(this, promise);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).close(this, promise);
                } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                    ((ChannelOutboundHandlerAdapter) handler).close(this, promise);
                } else {
                    ((ChannelOutboundHandler) handler).close(this, promise);
                }
            } catch (Throwable t) {
                notifyOutboundHandlerException(t, promise);
            }
        } else {
            close(promise);
        }
    }

    @Override
    public ChannelFuture deregister(final ChannelPromise promise) {
        if (isNotValidPromise(promise, false)) {
            // cancelled
            return promise;
        }

        final AbstractChannelHandlerContext next = findContextOutbound(MASK_DEREGISTER);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeDeregister(promise);
        } else {
            safeExecute(executor, new Runnable() {
                @Override
                public void run() {
                    next.invokeDeregister(promise);
                }
            }, promise, null, false);
        }

        return promise;
    }

    private void invokeDeregister(ChannelPromise promise) {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.deregister(this, promise);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).deregister(this, promise);
                } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                    ((ChannelOutboundHandlerAdapter) handler).deregister(this, promise);
                } else {
((ChannelOutboundHandler) handler).deregister(this, promise);
                }
            } catch (Throwable t) {
                notifyOutboundHandlerException(t, promise);
            }
        } else {
            deregister(promise);
        }
    }

    @Override
    public ChannelHandlerContext read() {
        final AbstractChannelHandlerContext next = findContextOutbound(MASK_READ);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeRead();
        } else {
            // read() is frequent and promise-less, so the off-loop Runnable is cached in Tasks.
            Tasks tasks = next.invokeTasks;
            if (tasks == null) {
                next.invokeTasks = tasks = new Tasks(next);
            }
            executor.execute(tasks.invokeReadTask);
        }

        return this;
    }

    private void invokeRead() {
        if (invokeHandler()) {
            try {
                // DON'T CHANGE
                // Duplex handlers implements both out/in interfaces causing a scalability issue
                // see https://bugs.openjdk.org/browse/JDK-8180450
                final ChannelHandler handler = handler();
                final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
                if (handler == headContext) {
                    headContext.read(this);
                } else if (handler instanceof ChannelDuplexHandler) {
                    ((ChannelDuplexHandler) handler).read(this);
                } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                    ((ChannelOutboundHandlerAdapter) handler).read(this);
                } else {
                    ((ChannelOutboundHandler) handler).read(this);
                }
            } catch (Throwable t) {
                invokeExceptionCaught(t);
            }
        } else {
            read();
        }
    }

    @Override
    public ChannelFuture write(Object msg) {
        return write(msg, newPromise());
    }

    @Override
    public ChannelFuture write(final Object msg, final ChannelPromise promise) {
        write(msg, false, promise);

        return promise;
    }

    // Invoked by the event loop (directly or via WriteTask) once the write reaches this context.
    void invokeWrite(Object msg, ChannelPromise promise) {
        if (invokeHandler()) {
            invokeWrite0(msg, promise);
        } else {
            write(msg, promise);
        }
    }

    private void invokeWrite0(Object msg, ChannelPromise promise) {
        try {
            // DON'T CHANGE
            // Duplex handlers implements both out/in interfaces causing a scalability issue
            // see https://bugs.openjdk.org/browse/JDK-8180450
            final ChannelHandler handler = handler();
            final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
            if (handler == headContext) {
                headContext.write(this, msg, promise);
            } else if (handler instanceof ChannelDuplexHandler) {
                ((ChannelDuplexHandler) handler).write(this, msg, promise);
            } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                ((ChannelOutboundHandlerAdapter) handler).write(this, msg, promise);
            } else {
                ((ChannelOutboundHandler) handler).write(this, msg, promise);
            }
        } catch (Throwable t) {
            notifyOutboundHandlerException(t, promise);
        }
    }

    @Override
    public ChannelHandlerContext flush() {
        final AbstractChannelHandlerContext next = findContextOutbound(MASK_FLUSH);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            next.invokeFlush();
        } else {
            Tasks tasks = next.invokeTasks;
            if (tasks == null) {
                next.invokeTasks = tasks = new Tasks(next);
            }
            // flush has no caller-supplied promise, so failures on submit go to the void promise.
            safeExecute(executor, tasks.invokeFlushTask, channel().voidPromise(), null, false);
        }

        return this;
    }

    private void invokeFlush() {
        if (invokeHandler()) {
            invokeFlush0();
        } else {
            flush();
        }
    }

    private void invokeFlush0() {
        try {
            // DON'T CHANGE
            // Duplex handlers implements both out/in interfaces causing a scalability issue
            // see https://bugs.openjdk.org/browse/JDK-8180450
            final ChannelHandler handler = handler();
            final DefaultChannelPipeline.HeadContext headContext = pipeline.head;
            if (handler == headContext) {
                headContext.flush(this);
            } else if (handler instanceof ChannelDuplexHandler) {
                ((ChannelDuplexHandler) handler).flush(this);
            } else if (handler instanceof ChannelOutboundHandlerAdapter) {
                ((ChannelOutboundHandlerAdapter) handler).flush(this);
            } else {
                ((ChannelOutboundHandler) handler).flush(this);
            }
        } catch (Throwable t) {
            invokeExceptionCaught(t);
        }
    }

    @Override
    public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) {
        write(msg, true,
promise);
        return promise;
    }

    void invokeWriteAndFlush(Object msg, ChannelPromise promise) {
        if (invokeHandler()) {
            invokeWrite0(msg, promise);
            invokeFlush0();
        } else {
            writeAndFlush(msg, promise);
        }
    }

    // Common path for write() and writeAndFlush(): validates the promise (releasing msg on failure),
    // finds the next outbound context, and either invokes inline or schedules a pooled WriteTask.
    private void write(Object msg, boolean flush, ChannelPromise promise) {
        ObjectUtil.checkNotNull(msg, "msg");
        try {
            if (isNotValidPromise(promise, true)) {
                ReferenceCountUtil.release(msg);
                // cancelled
                return;
            }
        } catch (RuntimeException e) {
            // Invalid promise: release the message before propagating so it is not leaked.
            ReferenceCountUtil.release(msg);
            throw e;
        }

        final AbstractChannelHandlerContext next = findContextOutbound(flush ?
                (MASK_WRITE | MASK_FLUSH) : MASK_WRITE);
        final Object m = pipeline.touch(msg, next);
        EventExecutor executor = next.executor();
        if (executor.inEventLoop()) {
            if (flush) {
                next.invokeWriteAndFlush(m, promise);
            } else {
                next.invokeWrite(m, promise);
            }
        } else {
            final WriteTask task = WriteTask.newInstance(next, m, promise, flush);
            if (!safeExecute(executor, task, promise, m, !flush)) {
                // We failed to submit the WriteTask. We need to cancel it so we decrement the pending bytes
                // and put it back in the Recycler for re-use later.
                //
                // See https://github.com/netty/netty/issues/8343.
                task.cancel();
            }
        }
    }

    @Override
    public ChannelFuture writeAndFlush(Object msg) {
        return writeAndFlush(msg, newPromise());
    }

    private static void notifyOutboundHandlerException(Throwable cause, ChannelPromise promise) {
        // Only log if the given promise is not of type VoidChannelPromise as tryFailure(...) is expected to return
        // false.
        PromiseNotificationUtil.tryFailure(promise, cause, promise instanceof VoidChannelPromise ?
                null : logger);
    }

    @Override
    public ChannelPromise newPromise() {
        return new DefaultChannelPromise(channel(), executor());
    }

    @Override
    public ChannelProgressivePromise newProgressivePromise() {
        return new DefaultChannelProgressivePromise(channel(), executor());
    }

    @Override
    public ChannelFuture newSucceededFuture() {
        // Lazily created and cached; a succeeded future is immutable so racing initializers are harmless.
        // NOTE(review): the succeededFuture field declaration is outside this chunk — confirm it is safe
        // to publish without synchronization as this idiom assumes.
        ChannelFuture succeededFuture = this.succeededFuture;
        if (succeededFuture == null) {
            this.succeededFuture = succeededFuture = new SucceededChannelFuture(channel(), executor());
        }
        return succeededFuture;
    }

    @Override
    public ChannelFuture newFailedFuture(Throwable cause) {
        return new FailedChannelFuture(channel(), executor(), cause);
    }

    // Returns true when the (already done) promise was cancelled and the operation must be skipped;
    // throws for promises that are done, belong to another channel, or are of a disallowed type.
    private boolean isNotValidPromise(ChannelPromise promise, boolean allowVoidPromise) {
        ObjectUtil.checkNotNull(promise, "promise");

        if (promise.isDone()) {
            // Check if the promise was cancelled and if so signal that the processing of the operation
            // should not be performed.
            //
            // See https://github.com/netty/netty/issues/2349
            if (promise.isCancelled()) {
                return true;
            }
            throw new IllegalArgumentException("promise already done: " + promise);
        }

        if (promise.channel() != channel()) {
            throw new IllegalArgumentException(String.format(
                    "promise.channel does not match: %s (expected: %s)", promise.channel(), channel()));
        }

        // Fast path: the common concrete promise type needs none of the checks below.
        if (promise.getClass() == DefaultChannelPromise.class) {
            return false;
        }

        if (!allowVoidPromise && promise instanceof VoidChannelPromise) {
            throw new IllegalArgumentException(
                    StringUtil.simpleClassName(VoidChannelPromise.class) + " not allowed for this operation");
        }

        if (promise instanceof AbstractChannel.CloseFuture) {
            throw new IllegalArgumentException(
                    StringUtil.simpleClassName(AbstractChannel.CloseFuture.class) + " not allowed in a pipeline");
        }
        return false;
    }

    // Walks forward through the pipeline to the next context interested in any event in {@code mask}.
    private AbstractChannelHandlerContext findContextInbound(int mask) {
        AbstractChannelHandlerContext ctx = this;
        EventExecutor
currentExecutor = executor();
        do {
            ctx = ctx.next;
        } while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_INBOUND));
        return ctx;
    }

    // Walks backward through the pipeline to the next context interested in any event in {@code mask}.
    private AbstractChannelHandlerContext findContextOutbound(int mask) {
        AbstractChannelHandlerContext ctx = this;
        EventExecutor currentExecutor = executor();
        do {
            ctx = ctx.prev;
        } while (skipContext(ctx, currentExecutor, mask, MASK_ONLY_OUTBOUND));
        return ctx;
    }

    private static boolean skipContext(
            AbstractChannelHandlerContext ctx, EventExecutor currentExecutor, int mask, int onlyMask) {
        // Ensure we correctly handle MASK_EXCEPTION_CAUGHT which is not included in the MASK_ONLY_* masks.
        // NOTE(review): the original comment read "... not included in the MASK_EXCEPTION_CAUGHT", which is
        // self-contradictory; the intent appears to be the MASK_ONLY_* masks — confirm against upstream.
        return (ctx.executionMask & (onlyMask | mask)) == 0 ||
                // We can only skip if the EventExecutor is the same as otherwise we need to ensure we offload
                // everything to preserve ordering.
                //
                // See https://github.com/netty/netty/issues/10067
                (ctx.executor() == currentExecutor && (ctx.executionMask & mask) == 0);
    }

    @Override
    public ChannelPromise voidPromise() {
        return channel().voidPromise();
    }

    final void setRemoved() {
        handlerState = REMOVE_COMPLETE;
    }

    // CAS loop: transitions to ADD_COMPLETE unless the handler was already removed.
    // Returns false if the state was REMOVE_COMPLETE (handler removed before the add finished).
    final boolean setAddComplete() {
        for (;;) {
            int oldState = handlerState;
            if (oldState == REMOVE_COMPLETE) {
                return false;
            }
            // Ensure we never update when the handlerState is REMOVE_COMPLETE already.
            // oldState is usually ADD_PENDING but can also be REMOVE_COMPLETE when an EventExecutor is used that is not
            // exposing ordering guarantees.
            if (HANDLER_STATE_UPDATER.compareAndSet(this, oldState, ADD_COMPLETE)) {
                return true;
            }
        }
    }

    final void setAddPending() {
        boolean updated = HANDLER_STATE_UPDATER.compareAndSet(this, INIT, ADD_PENDING);
        assert updated; // This should always be true as it MUST be called before setAddComplete() or setRemoved().
    }

    final void callHandlerAdded() throws Exception {
        // We must call setAddComplete before calling handlerAdded. Otherwise if the handlerAdded method generates
        // any pipeline events ctx.handler() will miss them because the state will not allow it.
        if (setAddComplete()) {
            handler().handlerAdded(this);
        }
    }

    final void callHandlerRemoved() throws Exception {
        try {
            // Only call handlerRemoved(...) if we called handlerAdded(...) before.
            if (handlerState == ADD_COMPLETE) {
                handler().handlerRemoved(this);
            }
        } finally {
            // Mark the handler as removed in any case.
            setRemoved();
        }
    }

    /**
     * Makes best possible effort to detect if {@link ChannelHandler#handlerAdded(ChannelHandlerContext)} was called
     * yet. If not return {@code false} and if called or could not detect return {@code true}.
     *
     * If this method returns {@code false} we will not invoke the {@link ChannelHandler} but just forward the event.
     * This is needed as {@link DefaultChannelPipeline} may already put the {@link ChannelHandler} in the linked-list
     * but not called {@link ChannelHandler#handlerAdded(ChannelHandlerContext)}.
     */
    private boolean invokeHandler() {
        // Store in local variable to reduce volatile reads.
int handlerState = this.handlerState;
        // An unordered executor gives no add-before-invoke guarantee, so ADD_PENDING must count as invokable.
        return handlerState == ADD_COMPLETE || (!ordered && handlerState == ADD_PENDING);
    }

    @Override
    public boolean isRemoved() {
        return handlerState == REMOVE_COMPLETE;
    }

    // NOTE(review): attr/hasAttr and the ObjectPool/Handle declarations below appear with raw types in this
    // copy; upstream Netty declares them generically (e.g. <T> Attribute<T> attr(AttributeKey<T> key),
    // ObjectPool<WriteTask>). The angle-bracketed parameters look stripped by extraction — confirm.
    @Override
    public Attribute attr(AttributeKey key) {
        return channel().attr(key);
    }

    @Override
    public boolean hasAttr(AttributeKey key) {
        return channel().hasAttr(key);
    }

    // Submits the runnable, returning true on success. On rejection it releases the message (if any)
    // and fails the promise, returning false. {@code lazy} requests lazyExecute when supported.
    private static boolean safeExecute(EventExecutor executor, Runnable runnable,
            ChannelPromise promise, Object msg, boolean lazy) {
        try {
            if (lazy && executor instanceof AbstractEventExecutor) {
                ((AbstractEventExecutor) executor).lazyExecute(runnable);
            } else {
                executor.execute(runnable);
            }
            return true;
        } catch (Throwable cause) {
            try {
                if (msg != null) {
                    ReferenceCountUtil.release(msg);
                }
            } finally {
                promise.setFailure(cause);
            }
            return false;
        }
    }

    @Override
    public String toHintString() {
        return '\'' + name + "' will handle the message from this point.";
    }

    @Override
    public String toString() {
        return StringUtil.simpleClassName(ChannelHandlerContext.class) + '(' + name + ", " + channel() + ')';
    }

    /**
     * Recyclable task that carries a write (optionally with flush) across executors while accounting
     * for pending outbound bytes. The sign bit of {@code size} encodes the flush flag.
     */
    static final class WriteTask implements Runnable {
        private static final ObjectPool RECYCLER = ObjectPool.newPool(new ObjectCreator() {
            @Override
            public WriteTask newObject(Handle handle) {
                return new WriteTask(handle);
            }
        });

        static WriteTask newInstance(AbstractChannelHandlerContext ctx,
                Object msg, ChannelPromise promise, boolean flush) {
            WriteTask task = RECYCLER.get();
            init(task, ctx, msg, promise, flush);
            return task;
        }

        private static final boolean ESTIMATE_TASK_SIZE_ON_SUBMIT =
                SystemPropertyUtil.getBoolean("io.netty.transport.estimateSizeOnSubmit", true);

        // Assuming compressed oops, 12 bytes obj header, 4 ref fields and one int field
        private static final int WRITE_TASK_OVERHEAD =
                SystemPropertyUtil.getInt("io.netty.transport.writeTaskSizeOverhead", 32);

        private final Handle handle;
        private AbstractChannelHandlerContext ctx;
        private Object msg;
        private ChannelPromise promise;
        private int size; // sign bit controls flush

        @SuppressWarnings("unchecked")
        private WriteTask(Handle handle) {
            this.handle = (Handle) handle;
        }

        protected static void init(WriteTask task, AbstractChannelHandlerContext ctx,
                Object msg, ChannelPromise promise, boolean flush) {
            task.ctx = ctx;
            task.msg = msg;
            task.promise = promise;

            if (ESTIMATE_TASK_SIZE_ON_SUBMIT) {
                // Count the message plus the task's own footprint against the channel's write buffer.
                task.size = ctx.pipeline.estimatorHandle().size(msg) + WRITE_TASK_OVERHEAD;
                ctx.pipeline.incrementPendingOutboundBytes(task.size);
            } else {
                task.size = 0;
            }
            if (flush) {
                // Flush is encoded in the sign bit so no extra field is needed.
                task.size |= Integer.MIN_VALUE;
            }
        }

        @Override
        public void run() {
            try {
                decrementPendingOutboundBytes();
                if (size >= 0) {
                    ctx.invokeWrite(msg, promise);
                } else {
                    ctx.invokeWriteAndFlush(msg, promise);
                }
            } finally {
                recycle();
            }
        }

        // Called when submission to the executor failed: undo the byte accounting and recycle.
        void cancel() {
            try {
                decrementPendingOutboundBytes();
            } finally {
                recycle();
            }
        }

        private void decrementPendingOutboundBytes() {
            if (ESTIMATE_TASK_SIZE_ON_SUBMIT) {
                // Mask off the sign bit to recover the recorded byte count.
                ctx.pipeline.decrementPendingOutboundBytes(size & Integer.MAX_VALUE);
            }
        }

        private void recycle() {
            // Set to null so the GC can collect them directly
            ctx = null;
            msg = null;
            promise = null;
            handle.recycle(this);
        }
    }

    /**
     * Per-context cache of promise-less Runnables (read-complete, read, writability-changed, flush)
     * so repeated cross-executor hand-offs do not allocate a new Runnable each time.
     */
    private static final class Tasks {
        private final AbstractChannelHandlerContext next;
        private final Runnable invokeChannelReadCompleteTask = new Runnable() {
            @Override
            public void run() {
                next.invokeChannelReadComplete();
            }
        };
        private final Runnable invokeReadTask = new Runnable() {
            @Override
            public void run() {
                next.invokeRead();
            }
        };
        private final Runnable invokeChannelWritableStateChangedTask = new Runnable() {
            @Override
            public void run() {
                next.invokeChannelWritabilityChanged();
            }
        };
        private final Runnable invokeFlushTask = new Runnable() {
            @Override
public void run() { + next.invokeFlush(); + } + }; + + Tasks(AbstractChannelHandlerContext next) { + this.next = next; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java b/netty-channel/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java new file mode 100644 index 0000000..33f3d81 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AbstractCoalescingBufferQueue.java @@ -0,0 +1,399 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.ArrayDeque; + +import static io.netty.util.ReferenceCountUtil.safeRelease; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static io.netty.util.internal.PlatformDependent.throwException; + +@UnstableApi +public abstract class AbstractCoalescingBufferQueue { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractCoalescingBufferQueue.class); + private final ArrayDeque bufAndListenerPairs; + private final PendingBytesTracker tracker; + private int readableBytes; + + /** + * Create a new instance. + * + * @param channel the {@link Channel} which will have the {@link Channel#isWritable()} reflect the amount of queued + * buffers or {@code null} if there is no writability state updated. + * @param initSize the initial size of the underlying queue. + */ + protected AbstractCoalescingBufferQueue(Channel channel, int initSize) { + bufAndListenerPairs = new ArrayDeque(initSize); + tracker = channel == null ? null : PendingBytesTracker.newTracker(channel); + } + + /** + * Add a buffer to the front of the queue and associate a promise with it that should be completed when + * all the buffer's bytes have been consumed from the queue and written. + * @param buf to add to the head of the queue + * @param promise to complete when all the bytes have been consumed and written, can be void. 
+ */ + public final void addFirst(ByteBuf buf, ChannelPromise promise) { + addFirst(buf, toChannelFutureListener(promise)); + } + + private void addFirst(ByteBuf buf, ChannelFutureListener listener) { + if (listener != null) { + bufAndListenerPairs.addFirst(listener); + } + bufAndListenerPairs.addFirst(buf); + incrementReadableBytes(buf.readableBytes()); + } + + /** + * Add a buffer to the end of the queue. + */ + public final void add(ByteBuf buf) { + add(buf, (ChannelFutureListener) null); + } + + /** + * Add a buffer to the end of the queue and associate a promise with it that should be completed when + * all the buffer's bytes have been consumed from the queue and written. + * @param buf to add to the tail of the queue + * @param promise to complete when all the bytes have been consumed and written, can be void. + */ + public final void add(ByteBuf buf, ChannelPromise promise) { + // buffers are added before promises so that we naturally 'consume' the entire buffer during removal + // before we complete it's promise. + add(buf, toChannelFutureListener(promise)); + } + + /** + * Add a buffer to the end of the queue and associate a listener with it that should be completed when + * all the buffers bytes have been consumed from the queue and written. + * @param buf to add to the tail of the queue + * @param listener to notify when all the bytes have been consumed and written, can be {@code null}. + */ + public final void add(ByteBuf buf, ChannelFutureListener listener) { + // buffers are added before promises so that we naturally 'consume' the entire buffer during removal + // before we complete it's promise. + bufAndListenerPairs.add(buf); + if (listener != null) { + bufAndListenerPairs.add(listener); + } + incrementReadableBytes(buf.readableBytes()); + } + + /** + * Remove the first {@link ByteBuf} from the queue. + * @param aggregatePromise used to aggregate the promises and listeners for the returned buffer. 
+ * @return the first {@link ByteBuf} from the queue. + */ + public final ByteBuf removeFirst(ChannelPromise aggregatePromise) { + Object entry = bufAndListenerPairs.poll(); + if (entry == null) { + return null; + } + assert entry instanceof ByteBuf; + ByteBuf result = (ByteBuf) entry; + + decrementReadableBytes(result.readableBytes()); + + entry = bufAndListenerPairs.peek(); + if (entry instanceof ChannelFutureListener) { + aggregatePromise.addListener((ChannelFutureListener) entry); + bufAndListenerPairs.poll(); + } + return result; + } + + /** + * Remove a {@link ByteBuf} from the queue with the specified number of bytes. Any added buffer who's bytes are + * fully consumed during removal will have it's promise completed when the passed aggregate {@link ChannelPromise} + * completes. + * + * @param alloc The allocator used if a new {@link ByteBuf} is generated during the aggregation process. + * @param bytes the maximum number of readable bytes in the returned {@link ByteBuf}, if {@code bytes} is greater + * than {@link #readableBytes} then a buffer of length {@link #readableBytes} is returned. + * @param aggregatePromise used to aggregate the promises and listeners for the constituent buffers. + * @return a {@link ByteBuf} composed of the enqueued buffers. + */ + public final ByteBuf remove(ByteBufAllocator alloc, int bytes, ChannelPromise aggregatePromise) { + checkPositiveOrZero(bytes, "bytes"); + checkNotNull(aggregatePromise, "aggregatePromise"); + + // Use isEmpty rather than readableBytes==0 as we may have a promise associated with an empty buffer. 
+ if (bufAndListenerPairs.isEmpty()) { + assert readableBytes == 0; + return removeEmptyValue(); + } + bytes = Math.min(bytes, readableBytes); + + ByteBuf toReturn = null; + ByteBuf entryBuffer = null; + int originalBytes = bytes; + try { + for (;;) { + Object entry = bufAndListenerPairs.poll(); + if (entry == null) { + break; + } + // fast-path vs abstract type + if (entry instanceof ByteBuf) { + entryBuffer = (ByteBuf) entry; + int bufferBytes = entryBuffer.readableBytes(); + + if (bufferBytes > bytes) { + // Add the buffer back to the queue as we can't consume all of it. + bufAndListenerPairs.addFirst(entryBuffer); + if (bytes > 0) { + // Take a slice of what we can consume and retain it. + entryBuffer = entryBuffer.readRetainedSlice(bytes); + // we end here, so if this is the only buffer to return, skip composing + toReturn = toReturn == null ? entryBuffer + : compose(alloc, toReturn, entryBuffer); + bytes = 0; + } + break; + } + + bytes -= bufferBytes; + if (toReturn == null) { + // if there are no more bytes in the queue after this, there's no reason to compose + toReturn = bufferBytes == readableBytes + ? entryBuffer + : composeFirst(alloc, entryBuffer); + } else { + toReturn = compose(alloc, toReturn, entryBuffer); + } + entryBuffer = null; + } else if (entry instanceof DelegatingChannelPromiseNotifier) { + aggregatePromise.addListener((DelegatingChannelPromiseNotifier) entry); + } else if (entry instanceof ChannelFutureListener) { + aggregatePromise.addListener((ChannelFutureListener) entry); + } + } + } catch (Throwable cause) { + safeRelease(entryBuffer); + safeRelease(toReturn); + aggregatePromise.setFailure(cause); + throwException(cause); + } + decrementReadableBytes(originalBytes - bytes); + return toReturn; + } + + /** + * The number of readable bytes. + */ + public final int readableBytes() { + return readableBytes; + } + + /** + * Are there pending buffers in the queue. 
+ */ + public final boolean isEmpty() { + return bufAndListenerPairs.isEmpty(); + } + + /** + * Release all buffers in the queue and complete all listeners and promises. + */ + public final void releaseAndFailAll(ChannelOutboundInvoker invoker, Throwable cause) { + releaseAndCompleteAll(invoker.newFailedFuture(cause)); + } + + /** + * Copy all pending entries in this queue into the destination queue. + * @param dest to copy pending buffers to. + */ + public final void copyTo(AbstractCoalescingBufferQueue dest) { + dest.bufAndListenerPairs.addAll(bufAndListenerPairs); + dest.incrementReadableBytes(readableBytes); + } + + /** + * Writes all remaining elements in this queue. + * @param ctx The context to write all elements to. + */ + public final void writeAndRemoveAll(ChannelHandlerContext ctx) { + Throwable pending = null; + ByteBuf previousBuf = null; + for (;;) { + Object entry = bufAndListenerPairs.poll(); + try { + if (entry == null) { + if (previousBuf != null) { + decrementReadableBytes(previousBuf.readableBytes()); + ctx.write(previousBuf, ctx.voidPromise()); + } + break; + } + + if (entry instanceof ByteBuf) { + if (previousBuf != null) { + decrementReadableBytes(previousBuf.readableBytes()); + ctx.write(previousBuf, ctx.voidPromise()); + } + previousBuf = (ByteBuf) entry; + } else if (entry instanceof ChannelPromise) { + decrementReadableBytes(previousBuf.readableBytes()); + ctx.write(previousBuf, (ChannelPromise) entry); + previousBuf = null; + } else { + decrementReadableBytes(previousBuf.readableBytes()); + ctx.write(previousBuf).addListener((ChannelFutureListener) entry); + previousBuf = null; + } + } catch (Throwable t) { + if (pending == null) { + pending = t; + } else { + logger.info("Throwable being suppressed because Throwable {} is already pending", pending, t); + } + } + } + if (pending != null) { + throw new IllegalStateException(pending); + } + } + + @Override + public String toString() { + return "bytes: " + readableBytes + " buffers: " + 
(size() >> 1); + } + + /** + * Calculate the result of {@code current + next}. + */ + protected abstract ByteBuf compose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next); + + /** + * Compose {@code cumulation} and {@code next} into a new {@link CompositeByteBuf}. + */ + protected final ByteBuf composeIntoComposite(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) { + // Create a composite buffer to accumulate this pair and potentially all the buffers + // in the queue. Using +2 as we have already dequeued current and next. + CompositeByteBuf composite = alloc.compositeBuffer(size() + 2); + try { + composite.addComponent(true, cumulation); + composite.addComponent(true, next); + } catch (Throwable cause) { + composite.release(); + safeRelease(next); + throwException(cause); + } + return composite; + } + + /** + * Compose {@code cumulation} and {@code next} into a new {@link ByteBufAllocator#ioBuffer()}. + * @param alloc The allocator to use to allocate the new buffer. + * @param cumulation The current cumulation. + * @param next The next buffer. + * @return The result of {@code cumulation + next}. + */ + protected final ByteBuf copyAndCompose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) { + ByteBuf newCumulation = alloc.ioBuffer(cumulation.readableBytes() + next.readableBytes()); + try { + newCumulation.writeBytes(cumulation).writeBytes(next); + } catch (Throwable cause) { + newCumulation.release(); + safeRelease(next); + throwException(cause); + } + cumulation.release(); + next.release(); + return newCumulation; + } + + /** + * Calculate the first {@link ByteBuf} which will be used in subsequent calls to + * {@link #compose(ByteBufAllocator, ByteBuf, ByteBuf)}. + */ + protected ByteBuf composeFirst(ByteBufAllocator allocator, ByteBuf first) { + return first; + } + + /** + * The value to return when {@link #remove(ByteBufAllocator, int, ChannelPromise)} is called but the queue is empty. 
+ * @return the {@link ByteBuf} which represents an empty queue. + */ + protected abstract ByteBuf removeEmptyValue(); + + /** + * Get the number of elements in this queue added via one of the {@link #add(ByteBuf)} methods. + * @return the number of elements in this queue. + */ + protected final int size() { + return bufAndListenerPairs.size(); + } + + private void releaseAndCompleteAll(ChannelFuture future) { + Throwable pending = null; + for (;;) { + Object entry = bufAndListenerPairs.poll(); + if (entry == null) { + break; + } + try { + if (entry instanceof ByteBuf) { + ByteBuf buffer = (ByteBuf) entry; + decrementReadableBytes(buffer.readableBytes()); + safeRelease(buffer); + } else { + ((ChannelFutureListener) entry).operationComplete(future); + } + } catch (Throwable t) { + if (pending == null) { + pending = t; + } else { + logger.info("Throwable being suppressed because Throwable {} is already pending", pending, t); + } + } + } + if (pending != null) { + throw new IllegalStateException(pending); + } + } + + private void incrementReadableBytes(int increment) { + int nextReadableBytes = readableBytes + increment; + if (nextReadableBytes < readableBytes) { + throw new IllegalStateException("buffer queue length overflow: " + readableBytes + " + " + increment); + } + readableBytes = nextReadableBytes; + if (tracker != null) { + tracker.incrementPendingOutboundBytes(increment); + } + } + + private void decrementReadableBytes(int decrement) { + readableBytes -= decrement; + assert readableBytes >= 0; + if (tracker != null) { + tracker.decrementPendingOutboundBytes(decrement); + } + } + + private static ChannelFutureListener toChannelFutureListener(ChannelPromise promise) { + return promise.isVoid() ? 
null : new DelegatingChannelPromiseNotifier(promise); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/AbstractEventLoop.java b/netty-channel/src/main/java/io/netty/channel/AbstractEventLoop.java new file mode 100644 index 0000000..02b65da --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AbstractEventLoop.java @@ -0,0 +1,41 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel; + +import io.netty.util.concurrent.AbstractEventExecutor; + +/** + * Skeletal implementation of {@link EventLoop}. 
+ */ +public abstract class AbstractEventLoop extends AbstractEventExecutor implements EventLoop { + + protected AbstractEventLoop() { } + + protected AbstractEventLoop(EventLoopGroup parent) { + super(parent); + } + + @Override + public EventLoopGroup parent() { + return (EventLoopGroup) super.parent(); + } + + @Override + public EventLoop next() { + return (EventLoop) super.next(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/AbstractEventLoopGroup.java b/netty-channel/src/main/java/io/netty/channel/AbstractEventLoopGroup.java new file mode 100644 index 0000000..3311087 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AbstractEventLoopGroup.java @@ -0,0 +1,27 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel; + +import io.netty.util.concurrent.AbstractEventExecutorGroup; + +/** + * Skeletal implementation of {@link EventLoopGroup}. 
+ */ +public abstract class AbstractEventLoopGroup extends AbstractEventExecutorGroup implements EventLoopGroup { + @Override + public abstract EventLoop next(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/AbstractServerChannel.java b/netty-channel/src/main/java/io/netty/channel/AbstractServerChannel.java new file mode 100644 index 0000000..222108a --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AbstractServerChannel.java @@ -0,0 +1,82 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import java.net.SocketAddress; + +/** + * A skeletal server-side {@link Channel} implementation. A server-side + * {@link Channel} does not allow the following operations: + *
+ * <ul>
+ * <li>{@link #connect(SocketAddress, ChannelPromise)}</li>
+ * <li>{@link #disconnect(ChannelPromise)}</li>
+ * <li>{@link #write(Object, ChannelPromise)}</li>
+ * <li>{@link #flush()}</li>
+ * <li>and the shortcut methods which call the methods mentioned above</li>
+ * </ul>
+ */ +public abstract class AbstractServerChannel extends AbstractChannel implements ServerChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); + + /** + * Creates a new instance. + */ + protected AbstractServerChannel() { + super(null); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + public SocketAddress remoteAddress() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doDisconnect() throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new DefaultServerUnsafe(); + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + protected final Object filterOutboundMessage(Object msg) { + throw new UnsupportedOperationException(); + } + + private final class DefaultServerUnsafe extends AbstractUnsafe { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + safeSetFailure(promise, new UnsupportedOperationException()); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java b/netty-channel/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java new file mode 100644 index 0000000..46ef634 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AdaptiveRecvByteBufAllocator.java @@ -0,0 +1,205 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import java.util.ArrayList; +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static java.lang.Math.max; +import static java.lang.Math.min; + +/** + * The {@link RecvByteBufAllocator} that automatically increases and + * decreases the predicted buffer size on feed back. + *
+ * <p>
+ * It gradually increases the expected number of readable bytes if the previous + * read fully filled the allocated buffer. It gradually decreases the expected + * number of readable bytes if the read operation was not able to fill a certain + * amount of the allocated buffer two times consecutively. Otherwise, it keeps + * returning the same prediction. + */ +public class AdaptiveRecvByteBufAllocator extends DefaultMaxMessagesRecvByteBufAllocator { + + static final int DEFAULT_MINIMUM = 64; + // Use an initial value that is bigger than the common MTU of 1500 + static final int DEFAULT_INITIAL = 2048; + static final int DEFAULT_MAXIMUM = 65536; + + private static final int INDEX_INCREMENT = 4; + private static final int INDEX_DECREMENT = 1; + + private static final int[] SIZE_TABLE; + + static { + List sizeTable = new ArrayList(); + for (int i = 16; i < 512; i += 16) { + sizeTable.add(i); + } + + // Suppress a warning since i becomes negative when an integer overflow happens + for (int i = 512; i > 0; i <<= 1) { + sizeTable.add(i); + } + + SIZE_TABLE = new int[sizeTable.size()]; + for (int i = 0; i < SIZE_TABLE.length; i ++) { + SIZE_TABLE[i] = sizeTable.get(i); + } + } + + /** + * @deprecated There is state for {@link #maxMessagesPerRead()} which is typically based upon channel type. 
+ */ + @Deprecated + public static final AdaptiveRecvByteBufAllocator DEFAULT = new AdaptiveRecvByteBufAllocator(); + + private static int getSizeTableIndex(final int size) { + for (int low = 0, high = SIZE_TABLE.length - 1;;) { + if (high < low) { + return low; + } + if (high == low) { + return high; + } + + int mid = low + high >>> 1; + int a = SIZE_TABLE[mid]; + int b = SIZE_TABLE[mid + 1]; + if (size > b) { + low = mid + 1; + } else if (size < a) { + high = mid - 1; + } else if (size == a) { + return mid; + } else { + return mid + 1; + } + } + } + + private final class HandleImpl extends MaxMessageHandle { + private final int minIndex; + private final int maxIndex; + private int index; + private int nextReceiveBufferSize; + private boolean decreaseNow; + + HandleImpl(int minIndex, int maxIndex, int initial) { + this.minIndex = minIndex; + this.maxIndex = maxIndex; + + index = getSizeTableIndex(initial); + nextReceiveBufferSize = SIZE_TABLE[index]; + } + + @Override + public void lastBytesRead(int bytes) { + // If we read as much as we asked for we should check if we need to ramp up the size of our next guess. + // This helps adjust more quickly when large amounts of data is pending and can avoid going back to + // the selector to check for more data. Going back to the selector can add significant latency for large + // data transfers. 
+ if (bytes == attemptedBytesRead()) { + record(bytes); + } + super.lastBytesRead(bytes); + } + + @Override + public int guess() { + return nextReceiveBufferSize; + } + + private void record(int actualReadBytes) { + if (actualReadBytes <= SIZE_TABLE[max(0, index - INDEX_DECREMENT)]) { + if (decreaseNow) { + index = max(index - INDEX_DECREMENT, minIndex); + nextReceiveBufferSize = SIZE_TABLE[index]; + decreaseNow = false; + } else { + decreaseNow = true; + } + } else if (actualReadBytes >= nextReceiveBufferSize) { + index = min(index + INDEX_INCREMENT, maxIndex); + nextReceiveBufferSize = SIZE_TABLE[index]; + decreaseNow = false; + } + } + + @Override + public void readComplete() { + record(totalBytesRead()); + } + } + + private final int minIndex; + private final int maxIndex; + private final int initial; + + /** + * Creates a new predictor with the default parameters. With the default + * parameters, the expected buffer size starts from {@code 1024}, does not + * go down below {@code 64}, and does not go up above {@code 65536}. + */ + public AdaptiveRecvByteBufAllocator() { + this(DEFAULT_MINIMUM, DEFAULT_INITIAL, DEFAULT_MAXIMUM); + } + + /** + * Creates a new predictor with the specified parameters. 
+ * + * @param minimum the inclusive lower bound of the expected buffer size + * @param initial the initial buffer size when no feed back was received + * @param maximum the inclusive upper bound of the expected buffer size + */ + public AdaptiveRecvByteBufAllocator(int minimum, int initial, int maximum) { + checkPositive(minimum, "minimum"); + if (initial < minimum) { + throw new IllegalArgumentException("initial: " + initial); + } + if (maximum < initial) { + throw new IllegalArgumentException("maximum: " + maximum); + } + + int minIndex = getSizeTableIndex(minimum); + if (SIZE_TABLE[minIndex] < minimum) { + this.minIndex = minIndex + 1; + } else { + this.minIndex = minIndex; + } + + int maxIndex = getSizeTableIndex(maximum); + if (SIZE_TABLE[maxIndex] > maximum) { + this.maxIndex = maxIndex - 1; + } else { + this.maxIndex = maxIndex; + } + + this.initial = initial; + } + + @SuppressWarnings("deprecation") + @Override + public Handle newHandle() { + return new HandleImpl(minIndex, maxIndex, initial); + } + + @Override + public AdaptiveRecvByteBufAllocator respectMaybeMoreData(boolean respectMaybeMoreData) { + super.respectMaybeMoreData(respectMaybeMoreData); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/AddressedEnvelope.java b/netty-channel/src/main/java/io/netty/channel/AddressedEnvelope.java new file mode 100644 index 0000000..fe4006c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/AddressedEnvelope.java @@ -0,0 +1,56 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel; + +import io.netty.util.ReferenceCounted; + +import java.net.SocketAddress; + +/** + * A message that wraps another message with a sender address and a recipient address. + * + * @param the type of the wrapped message + * @param the type of the address + */ +public interface AddressedEnvelope extends ReferenceCounted { + /** + * Returns the message wrapped by this envelope message. + */ + M content(); + + /** + * Returns the address of the sender of this message. + */ + A sender(); + + /** + * Returns the address of the recipient of this message. + */ + A recipient(); + + @Override + AddressedEnvelope retain(); + + @Override + AddressedEnvelope retain(int increment); + + @Override + AddressedEnvelope touch(); + + @Override + AddressedEnvelope touch(Object hint); +} diff --git a/netty-channel/src/main/java/io/netty/channel/Channel.java b/netty-channel/src/main/java/io/netty/channel/Channel.java new file mode 100644 index 0000000..bf5f1d1 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/Channel.java @@ -0,0 +1,302 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramPacket; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.AttributeMap; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; + + +/** + * A nexus to a network socket or a component which is capable of I/O + * operations such as read, write, connect, and bind. + *
+ * <p>
+ * A channel provides a user: + *
+ * <ul>
+ * <li>the current state of the channel (e.g. is it open? is it connected?),</li>
+ * <li>the {@linkplain ChannelConfig configuration parameters} of the channel (e.g. receive buffer size),</li>
+ * <li>the I/O operations that the channel supports (e.g. read, write, connect, and bind), and</li>
+ * <li>the {@link ChannelPipeline} which handles all I/O events and requests
+ *     associated with the channel.</li>
+ * </ul>
+ *
+ * <h3>All I/O operations are asynchronous.</h3>
+ * <p>
+ * All I/O operations in Netty are asynchronous. It means any I/O calls will + * return immediately with no guarantee that the requested I/O operation has + * been completed at the end of the call. Instead, you will be returned with + * a {@link ChannelFuture} instance which will notify you when the requested I/O + * operation has succeeded, failed, or canceled. + * + *
+ *
+ * <h3>Channels are hierarchical</h3>
+ * <p>
+ * A {@link Channel} can have a {@linkplain #parent() parent} depending on + * how it was created. For instance, a {@link SocketChannel}, that was accepted + * by {@link ServerSocketChannel}, will return the {@link ServerSocketChannel} + * as its parent on {@link #parent()}. + *
+ * <p>
+ * The semantics of the hierarchical structure depends on the transport + * implementation where the {@link Channel} belongs to. For example, you could + * write a new {@link Channel} implementation that creates the sub-channels that + * share one socket connection, as BEEP and + * SSH do. + * + *
+ *
+ * <h3>Downcast to access transport-specific operations</h3>
+ * <p>
+ * Some transports exposes additional operations that is specific to the + * transport. Down-cast the {@link Channel} to sub-type to invoke such + * operations. For example, with the old I/O datagram transport, multicast + * join / leave operations are provided by {@link DatagramChannel}. + * + *
+ *
+ * <h3>Release resources</h3>
+ * <p>
+ * It is important to call {@link #close()} or {@link #close(ChannelPromise)} to release all + * resources once you are done with the {@link Channel}. This ensures all resources are + * released in a proper way, i.e. filehandles. + */ +public interface Channel extends AttributeMap, ChannelOutboundInvoker, Comparable { + + /** + * Returns the globally unique identifier of this {@link Channel}. + */ + ChannelId id(); + + /** + * Return the {@link EventLoop} this {@link Channel} was registered to. + */ + EventLoop eventLoop(); + + /** + * Returns the parent of this channel. + * + * @return the parent channel. + * {@code null} if this channel does not have a parent channel. + */ + Channel parent(); + + /** + * Returns the configuration of this channel. + */ + ChannelConfig config(); + + /** + * Returns {@code true} if the {@link Channel} is open and may get active later + */ + boolean isOpen(); + + /** + * Returns {@code true} if the {@link Channel} is registered with an {@link EventLoop}. + */ + boolean isRegistered(); + + /** + * Return {@code true} if the {@link Channel} is active and so connected. + */ + boolean isActive(); + + /** + * Return the {@link ChannelMetadata} of the {@link Channel} which describe the nature of the {@link Channel}. + */ + ChannelMetadata metadata(); + + /** + * Returns the local address where this channel is bound to. The returned + * {@link SocketAddress} is supposed to be down-cast into more concrete + * type such as {@link InetSocketAddress} to retrieve the detailed + * information. + * + * @return the local address of this channel. + * {@code null} if this channel is not bound. + */ + SocketAddress localAddress(); + + /** + * Returns the remote address where this channel is connected to. The + * returned {@link SocketAddress} is supposed to be down-cast into more + * concrete type such as {@link InetSocketAddress} to retrieve the detailed + * information. + * + * @return the remote address of this channel. 
+ * {@code null} if this channel is not connected. + * If this channel is not connected but it can receive messages + * from arbitrary remote addresses (e.g. {@link DatagramChannel}, + * use {@link DatagramPacket#recipient()} to determine + * the origination of the received message as this method will + * return {@code null}. + */ + SocketAddress remoteAddress(); + + /** + * Returns the {@link ChannelFuture} which will be notified when this + * channel is closed. This method always returns the same future instance. + */ + ChannelFuture closeFuture(); + + /** + * Returns {@code true} if and only if the I/O thread will perform the + * requested write operation immediately. Any write requests made when + * this method returns {@code false} are queued until the I/O thread is + * ready to process the queued write requests. + */ + boolean isWritable(); + + /** + * Get how many bytes can be written until {@link #isWritable()} returns {@code false}. + * This quantity will always be non-negative. If {@link #isWritable()} is {@code false} then 0. + */ + long bytesBeforeUnwritable(); + + /** + * Get how many bytes must be drained from underlying buffers until {@link #isWritable()} returns {@code true}. + * This quantity will always be non-negative. If {@link #isWritable()} is {@code true} then 0. + */ + long bytesBeforeWritable(); + + /** + * Returns an internal-use-only object that provides unsafe operations. + */ + Unsafe unsafe(); + + /** + * Return the assigned {@link ChannelPipeline}. + */ + ChannelPipeline pipeline(); + + /** + * Return the assigned {@link ByteBufAllocator} which will be used to allocate {@link ByteBuf}s. + */ + ByteBufAllocator alloc(); + + @Override + Channel read(); + + @Override + Channel flush(); + + /** + * Unsafe operations that should never be called from user-code. These methods + * are only provided to implement the actual transport, and must be invoked from an I/O thread except for the + * following methods: + *

    + *
  • {@link #localAddress()}
  • + *
  • {@link #remoteAddress()}
  • + *
  • {@link #closeForcibly()}
  • + *
  • {@link #register(EventLoop, ChannelPromise)}
  • + *
  • {@link #deregister(ChannelPromise)}
  • + *
  • {@link #voidPromise()}
  • + *
+ */ + interface Unsafe { + + /** + * Return the assigned {@link RecvByteBufAllocator.Handle} which will be used to allocate {@link ByteBuf}'s when + * receiving data. + */ + RecvByteBufAllocator.Handle recvBufAllocHandle(); + + /** + * Return the {@link SocketAddress} to which is bound local or + * {@code null} if none. + */ + SocketAddress localAddress(); + + /** + * Return the {@link SocketAddress} to which is bound remote or + * {@code null} if none is bound yet. + */ + SocketAddress remoteAddress(); + + /** + * Register the {@link Channel} of the {@link ChannelPromise} and notify + * the {@link ChannelFuture} once the registration was complete. + */ + void register(EventLoop eventLoop, ChannelPromise promise); + + /** + * Bind the {@link SocketAddress} to the {@link Channel} of the {@link ChannelPromise} and notify + * it once its done. + */ + void bind(SocketAddress localAddress, ChannelPromise promise); + + /** + * Connect the {@link Channel} of the given {@link ChannelFuture} with the given remote {@link SocketAddress}. + * If a specific local {@link SocketAddress} should be used it need to be given as argument. Otherwise just + * pass {@code null} to it. + * + * The {@link ChannelPromise} will get notified once the connect operation was complete. + */ + void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise); + + /** + * Disconnect the {@link Channel} of the {@link ChannelFuture} and notify the {@link ChannelPromise} once the + * operation was complete. + */ + void disconnect(ChannelPromise promise); + + /** + * Close the {@link Channel} of the {@link ChannelPromise} and notify the {@link ChannelPromise} once the + * operation was complete. + */ + void close(ChannelPromise promise); + + /** + * Closes the {@link Channel} immediately without firing any events. Probably only useful + * when registration attempt failed. 
+ */ + void closeForcibly(); + + /** + * Deregister the {@link Channel} of the {@link ChannelPromise} from {@link EventLoop} and notify the + * {@link ChannelPromise} once the operation was complete. + */ + void deregister(ChannelPromise promise); + + /** + * Schedules a read operation that fills the inbound buffer of the first {@link ChannelInboundHandler} in the + * {@link ChannelPipeline}. If there's already a pending read operation, this method does nothing. + */ + void beginRead(); + + /** + * Schedules a write operation. + */ + void write(Object msg, ChannelPromise promise); + + /** + * Flush out all write operations scheduled via {@link #write(Object, ChannelPromise)}. + */ + void flush(); + + /** + * Return a special ChannelPromise which can be reused and passed to the operations in {@link Unsafe}. + * It will never be notified of a success or error and so is only a placeholder for operations + * that take a {@link ChannelPromise} as argument but for which you not want to get notified. + */ + ChannelPromise voidPromise(); + + /** + * Returns the {@link ChannelOutboundBuffer} of the {@link Channel} where the pending write requests are stored. + */ + ChannelOutboundBuffer outboundBuffer(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/ChannelConfig.java new file mode 100644 index 0000000..a907160 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelConfig.java @@ -0,0 +1,268 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.socket.SocketChannelConfig; + +import java.nio.ByteBuffer; +import java.nio.channels.WritableByteChannel; +import java.util.Map; + +/** + * A set of configuration properties of a {@link Channel}. + *

+ * Please down-cast to more specific configuration type such as + * {@link SocketChannelConfig} or use {@link #setOptions(Map)} to set the + * transport-specific properties: + *

+ * {@link Channel} ch = ...;
+ * {@link SocketChannelConfig} cfg = ({@link SocketChannelConfig}) ch.getConfig();
+ * cfg.setTcpNoDelay(false);
+ * 
+ * + *

Option map

+ * + * An option map property is a dynamic write-only property which allows + * the configuration of a {@link Channel} without down-casting its associated + * {@link ChannelConfig}. To update an option map, please call {@link #setOptions(Map)}. + *

+ * All {@link ChannelConfig} has the following options: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameAssociated setter method
{@link ChannelOption#CONNECT_TIMEOUT_MILLIS}{@link #setConnectTimeoutMillis(int)}
{@link ChannelOption#WRITE_SPIN_COUNT}{@link #setWriteSpinCount(int)}
{@link ChannelOption#WRITE_BUFFER_WATER_MARK}{@link #setWriteBufferWaterMark(WriteBufferWaterMark)}
{@link ChannelOption#ALLOCATOR}{@link #setAllocator(ByteBufAllocator)}
{@link ChannelOption#AUTO_READ}{@link #setAutoRead(boolean)}
+ *

+ * More options are available in the sub-types of {@link ChannelConfig}. For + * example, you can configure the parameters which are specific to a TCP/IP + * socket as explained in {@link SocketChannelConfig}. + */ +public interface ChannelConfig { + + /** + * Return all set {@link ChannelOption}'s. + */ + Map, Object> getOptions(); + + /** + * Sets the configuration properties from the specified {@link Map}. + */ + boolean setOptions(Map, ?> options); + + /** + * Return the value of the given {@link ChannelOption} + */ + T getOption(ChannelOption option); + + /** + * Sets a configuration property with the specified name and value. + * To override this method properly, you must call the super class: + *

+     * public boolean setOption(ChannelOption<T> option, T value) {
+     *     if (super.setOption(option, value)) {
+     *         return true;
+     *     }
+     *
+     *     if (option.equals(additionalOption)) {
+     *         ....
+     *         return true;
+     *     }
+     *
+     *     return false;
+     * }
+     * 
+ * + * @return {@code true} if and only if the property has been set + */ + boolean setOption(ChannelOption option, T value); + + /** + * Returns the connect timeout of the channel in milliseconds. If the + * {@link Channel} does not support connect operation, this property is not + * used at all, and therefore will be ignored. + * + * @return the connect timeout in milliseconds. {@code 0} if disabled. + */ + int getConnectTimeoutMillis(); + + /** + * Sets the connect timeout of the channel in milliseconds. If the + * {@link Channel} does not support connect operation, this property is not + * used at all, and therefore will be ignored. + * + * @param connectTimeoutMillis the connect timeout in milliseconds. + * {@code 0} to disable. + */ + ChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + /** + * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} and + * {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead()}. + *

+ * Returns the maximum number of messages to read per read loop, where each message read triggers + * a {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object) channelRead()} event. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure multiple messages. + */ + @Deprecated + int getMaxMessagesPerRead(); + + /** + * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} and + * {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead(int)}. + *

+ * Sets the maximum number of messages to read per read loop. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure multiple messages. + */ + @Deprecated + ChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + /** + * Returns the maximum loop count for a write operation until + * {@link WritableByteChannel#write(ByteBuffer)} returns a non-zero value. + * It is similar to what a spin lock is used for in concurrency programming. + * It improves memory utilization and write throughput depending on + * the platform that JVM runs on. The default value is {@code 16}. + */ + int getWriteSpinCount(); + + /** + * Sets the maximum loop count for a write operation until + * {@link WritableByteChannel#write(ByteBuffer)} returns a non-zero value. + * It is similar to what a spin lock is used for in concurrency programming. + * It improves memory utilization and write throughput depending on + * the platform that JVM runs on. The default value is {@code 16}. + * + * @throws IllegalArgumentException + * if the specified value is {@code 0} or less than {@code 0} + */ + ChannelConfig setWriteSpinCount(int writeSpinCount); + + /** + * Returns {@link ByteBufAllocator} which is used for the channel + * to allocate buffers. + */ + ByteBufAllocator getAllocator(); + + /** + * Set the {@link ByteBufAllocator} which is used for the channel + * to allocate buffers. + */ + ChannelConfig setAllocator(ByteBufAllocator allocator); + + /** + * Returns {@link RecvByteBufAllocator} which is used for the channel to allocate receive buffers. + */ + T getRecvByteBufAllocator(); + + /** + * Set the {@link RecvByteBufAllocator} which is used for the channel to allocate receive buffers. + */ + ChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + /** + * Returns {@code true} if and only if {@link ChannelHandlerContext#read()} will be invoked automatically so that + * a user application doesn't need to call it at all. 
The default value is {@code true}. + */ + boolean isAutoRead(); + + /** + * Sets if {@link ChannelHandlerContext#read()} will be invoked automatically so that a user application doesn't + * need to call it at all. The default value is {@code true}. + */ + ChannelConfig setAutoRead(boolean autoRead); + + /** + * Returns {@code true} if and only if the {@link Channel} will be closed automatically and immediately on + * write failure. The default is {@code true}. + */ + boolean isAutoClose(); + + /** + * Sets whether the {@link Channel} should be closed automatically and immediately on write failure. + * The default is {@code true}. + */ + ChannelConfig setAutoClose(boolean autoClose); + + /** + * Returns the high water mark of the write buffer. If the number of bytes + * queued in the write buffer exceeds this value, {@link Channel#isWritable()} + * will start to return {@code false}. + */ + int getWriteBufferHighWaterMark(); + + /** + *

+ * Sets the high water mark of the write buffer. If the number of bytes + * queued in the write buffer exceeds this value, {@link Channel#isWritable()} + * will start to return {@code false}. + */ + ChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark); + + /** + * Returns the low water mark of the write buffer. Once the number of bytes + * queued in the write buffer exceeded the + * {@linkplain #setWriteBufferHighWaterMark(int) high water mark} and then + * dropped down below this value, {@link Channel#isWritable()} will start to return + * {@code true} again. + */ + int getWriteBufferLowWaterMark(); + + /** + *

+ * Sets the low water mark of the write buffer. Once the number of bytes + * queued in the write buffer exceeded the + * {@linkplain #setWriteBufferHighWaterMark(int) high water mark} and then + * dropped down below this value, {@link Channel#isWritable()} will start to return + * {@code true} again. + */ + ChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark); + + /** + * Returns {@link MessageSizeEstimator} which is used for the channel + * to detect the size of a message. + */ + MessageSizeEstimator getMessageSizeEstimator(); + + /** + * Set the {@link MessageSizeEstimator} which is used for the channel + * to detect the size of a message. + */ + ChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + /** + * Returns the {@link WriteBufferWaterMark} which is used for setting the high and low + * water mark of the write buffer. + */ + WriteBufferWaterMark getWriteBufferWaterMark(); + + /** + * Set the {@link WriteBufferWaterMark} which is used for setting the high and low + * water mark of the write buffer. + */ + ChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelDuplexHandler.java b/netty-channel/src/main/java/io/netty/channel/ChannelDuplexHandler.java new file mode 100644 index 0000000..620aafb --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelDuplexHandler.java @@ -0,0 +1,129 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.ChannelHandlerMask.Skip; + +import java.net.SocketAddress; + +/** + * {@link ChannelHandler} implementation which represents a combination out of a {@link ChannelInboundHandler} and + * the {@link ChannelOutboundHandler}. + * + * It is a good starting point if your {@link ChannelHandler} implementation needs to intercept operations and also + * state updates. + */ +public class ChannelDuplexHandler extends ChannelInboundHandlerAdapter implements ChannelOutboundHandler { + + /** + * Calls {@link ChannelHandlerContext#bind(SocketAddress, ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + /** + * Calls {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress, ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + /** + * Calls {@link ChannelHandlerContext#disconnect(ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ */ + @Skip + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) + throws Exception { + ctx.disconnect(promise); + } + + /** + * Calls {@link ChannelHandlerContext#close(ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.close(promise); + } + + /** + * Calls {@link ChannelHandlerContext#deregister(ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + /** + * Calls {@link ChannelHandlerContext#read()} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } + + /** + * Calls {@link ChannelHandlerContext#write(Object, ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ctx.write(msg, promise); + } + + /** + * Calls {@link ChannelHandlerContext#flush()} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ */ + @Skip + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + ctx.flush(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelException.java b/netty-channel/src/main/java/io/netty/channel/ChannelException.java new file mode 100644 index 0000000..147f81c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelException.java @@ -0,0 +1,94 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; + +/** + * A {@link RuntimeException} which is thrown when an I/O operation fails. + */ +public class ChannelException extends RuntimeException { + + private static final long serialVersionUID = 2908618315971075004L; + + /** + * Creates a new exception. + */ + public ChannelException() { + } + + /** + * Creates a new exception. + */ + public ChannelException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new exception. + */ + public ChannelException(String message) { + super(message); + } + + /** + * Creates a new exception. 
+ */ + public ChannelException(Throwable cause) { + super(cause); + } + + @UnstableApi + @SuppressJava6Requirement(reason = "uses Java 7+ RuntimeException.(String, Throwable, boolean, boolean)" + + " but is guarded by version checks") + protected ChannelException(String message, Throwable cause, boolean shared) { + super(message, cause, false, true); + assert shared; + } + + static ChannelException newStatic(String message, Class clazz, String method) { + ChannelException exception; + if (PlatformDependent.javaVersion() >= 7) { + exception = new StacklessChannelException(message, null, true); + } else { + exception = new StacklessChannelException(message, null); + } + return ThrowableUtil.unknownStackTrace(exception, clazz, method); + } + + private static final class StacklessChannelException extends ChannelException { + private static final long serialVersionUID = -6384642137753538579L; + + StacklessChannelException(String message, Throwable cause) { + super(message, cause); + } + + StacklessChannelException(String message, Throwable cause, boolean shared) { + super(message, cause, shared); + } + + // Override fillInStackTrace() so we not populate the backtrace via a native call and so leak the + // Classloader. + + @Override + public Throwable fillInStackTrace() { + return this; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelFactory.java b/netty-channel/src/main/java/io/netty/channel/ChannelFactory.java new file mode 100644 index 0000000..c039e51 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelFactory.java @@ -0,0 +1,28 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +/** + * Creates a new {@link Channel}. + */ +@SuppressWarnings({ "ClassNameSameAsAncestorName", "deprecation" }) +public interface ChannelFactory extends io.netty.bootstrap.ChannelFactory { + /** + * Creates a new channel. + */ + @Override + T newChannel(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java b/netty-channel/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java new file mode 100644 index 0000000..86a4ce0 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelFlushPromiseNotifier.java @@ -0,0 +1,273 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.util.internal.ObjectUtil; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * This implementation allows to register {@link ChannelFuture} instances which will get notified once some amount of + * data was written and so a checkpoint was reached. + */ +public final class ChannelFlushPromiseNotifier { + + private long writeCounter; + private final Queue flushCheckpoints = new ArrayDeque(); + private final boolean tryNotify; + + /** + * Create a new instance + * + * @param tryNotify if {@code true} the {@link ChannelPromise}s will get notified with + * {@link ChannelPromise#trySuccess()} and {@link ChannelPromise#tryFailure(Throwable)}. + * Otherwise {@link ChannelPromise#setSuccess()} and {@link ChannelPromise#setFailure(Throwable)} + * is used + */ + public ChannelFlushPromiseNotifier(boolean tryNotify) { + this.tryNotify = tryNotify; + } + + /** + * Create a new instance which will use {@link ChannelPromise#setSuccess()} and + * {@link ChannelPromise#setFailure(Throwable)} to notify the {@link ChannelPromise}s. + */ + public ChannelFlushPromiseNotifier() { + this(false); + } + + /** + * @deprecated use {@link #add(ChannelPromise, long)} + */ + @Deprecated + public ChannelFlushPromiseNotifier add(ChannelPromise promise, int pendingDataSize) { + return add(promise, (long) pendingDataSize); + } + + /** + * Add a {@link ChannelPromise} to this {@link ChannelFlushPromiseNotifier} which will be notified after the given + * {@code pendingDataSize} was reached. 
+ */ + public ChannelFlushPromiseNotifier add(ChannelPromise promise, long pendingDataSize) { + ObjectUtil.checkNotNull(promise, "promise"); + checkPositiveOrZero(pendingDataSize, "pendingDataSize"); + long checkpoint = writeCounter + pendingDataSize; + if (promise instanceof FlushCheckpoint) { + FlushCheckpoint cp = (FlushCheckpoint) promise; + cp.flushCheckpoint(checkpoint); + flushCheckpoints.add(cp); + } else { + flushCheckpoints.add(new DefaultFlushCheckpoint(checkpoint, promise)); + } + return this; + } + /** + * Increase the current write counter by the given delta + */ + public ChannelFlushPromiseNotifier increaseWriteCounter(long delta) { + checkPositiveOrZero(delta, "delta"); + writeCounter += delta; + return this; + } + + /** + * Return the current write counter of this {@link ChannelFlushPromiseNotifier} + */ + public long writeCounter() { + return writeCounter; + } + + /** + * Notify all {@link ChannelFuture}s that were registered with {@link #add(ChannelPromise, int)} and + * whose pendingDataSize is smaller than the current writeCounter returned by {@link #writeCounter()}. + * + * After a {@link ChannelFuture} was notified it will be removed from this {@link ChannelFlushPromiseNotifier} and + * so not receive any more notifications. + */ + public ChannelFlushPromiseNotifier notifyPromises() { + notifyPromises0(null); + return this; + } + + /** + * @deprecated use {@link #notifyPromises()} + */ + @Deprecated + public ChannelFlushPromiseNotifier notifyFlushFutures() { + return notifyPromises(); + } + + /** + * Notify all {@link ChannelFuture}s that were registered with {@link #add(ChannelPromise, int)} and + * whose pendingDataSize is smaller than the current writeCounter returned by {@link #writeCounter()}. + * + * After a {@link ChannelFuture} was notified it will be removed from this {@link ChannelFlushPromiseNotifier} and + * so not receive any more notifications.
+ * + * The rest of the remaining {@link ChannelFuture}s will be failed with the given {@link Throwable}. + * + * So after this operation this {@link ChannelFlushPromiseNotifier} is empty. + */ + public ChannelFlushPromiseNotifier notifyPromises(Throwable cause) { + notifyPromises(); + for (;;) { + FlushCheckpoint cp = flushCheckpoints.poll(); + if (cp == null) { + break; + } + if (tryNotify) { + cp.promise().tryFailure(cause); + } else { + cp.promise().setFailure(cause); + } + } + return this; + } + + /** + * @deprecated use {@link #notifyPromises(Throwable)} + */ + @Deprecated + public ChannelFlushPromiseNotifier notifyFlushFutures(Throwable cause) { + return notifyPromises(cause); + } + + /** + * Notify all {@link ChannelFuture}s that were registered with {@link #add(ChannelPromise, int)} and + * whose pendingDataSize is smaller than the current writeCounter returned by {@link #writeCounter()} using + * the given cause1. + * + * After a {@link ChannelFuture} was notified it will be removed from this {@link ChannelFlushPromiseNotifier} and + * so not receive any more notifications.
+ * + * @param cause1 the {@link Throwable} which will be used to fail all of the {@link ChannelFuture}s which + * pendingDataSize is smaller then the current writeCounter returned by {@link #writeCounter()} + * @param cause2 the {@link Throwable} which will be used to fail the remaining {@link ChannelFuture}s + */ + public ChannelFlushPromiseNotifier notifyPromises(Throwable cause1, Throwable cause2) { + notifyPromises0(cause1); + for (;;) { + FlushCheckpoint cp = flushCheckpoints.poll(); + if (cp == null) { + break; + } + if (tryNotify) { + cp.promise().tryFailure(cause2); + } else { + cp.promise().setFailure(cause2); + } + } + return this; + } + + /** + * @deprecated use {@link #notifyPromises(Throwable, Throwable)} + */ + @Deprecated + public ChannelFlushPromiseNotifier notifyFlushFutures(Throwable cause1, Throwable cause2) { + return notifyPromises(cause1, cause2); + } + + private void notifyPromises0(Throwable cause) { + if (flushCheckpoints.isEmpty()) { + writeCounter = 0; + return; + } + + final long writeCounter = this.writeCounter; + for (;;) { + FlushCheckpoint cp = flushCheckpoints.peek(); + if (cp == null) { + // Reset the counter if there's nothing in the notification list. 
+ this.writeCounter = 0; + break; + } + + if (cp.flushCheckpoint() > writeCounter) { + if (writeCounter > 0 && flushCheckpoints.size() == 1) { + this.writeCounter = 0; + cp.flushCheckpoint(cp.flushCheckpoint() - writeCounter); + } + break; + } + + flushCheckpoints.remove(); + ChannelPromise promise = cp.promise(); + if (cause == null) { + if (tryNotify) { + promise.trySuccess(); + } else { + promise.setSuccess(); + } + } else { + if (tryNotify) { + promise.tryFailure(cause); + } else { + promise.setFailure(cause); + } + } + } + + // Avoid overflow + final long newWriteCounter = this.writeCounter; + if (newWriteCounter >= 0x8000000000L) { + // Reset the counter only when the counter grew pretty large + // so that we can reduce the cost of updating all entries in the notification list. + this.writeCounter = 0; + for (FlushCheckpoint cp: flushCheckpoints) { + cp.flushCheckpoint(cp.flushCheckpoint() - newWriteCounter); + } + } + } + + interface FlushCheckpoint { + long flushCheckpoint(); + void flushCheckpoint(long checkpoint); + ChannelPromise promise(); + } + + private static class DefaultFlushCheckpoint implements FlushCheckpoint { + private long checkpoint; + private final ChannelPromise future; + + DefaultFlushCheckpoint(long checkpoint, ChannelPromise future) { + this.checkpoint = checkpoint; + this.future = future; + } + + @Override + public long flushCheckpoint() { + return checkpoint; + } + + @Override + public void flushCheckpoint(long checkpoint) { + this.checkpoint = checkpoint; + } + + @Override + public ChannelPromise promise() { + return future; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelFuture.java b/netty-channel/src/main/java/io/netty/channel/ChannelFuture.java new file mode 100644 index 0000000..4222860 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelFuture.java @@ -0,0 +1,212 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel;

import io.netty.bootstrap.Bootstrap;
import io.netty.util.concurrent.BlockingOperationException;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import java.util.concurrent.TimeUnit;

/**
 * The result of an asynchronous {@link Channel} I/O operation.
 * <p>
 * All I/O operations in Netty are asynchronous.  It means any I/O calls will
 * return immediately with no guarantee that the requested I/O operation has
 * been completed at the end of the call.  Instead, you will be returned with
 * a {@link ChannelFuture} instance which gives you the information about the
 * result or status of the I/O operation.
 * <p>
 * A {@link ChannelFuture} is either <em>uncompleted</em> or <em>completed</em>.
 * When an I/O operation begins, a new future object is created.  The new future
 * is uncompleted initially - it is neither succeeded, failed, nor cancelled
 * because the I/O operation is not finished yet.  If the I/O operation is
 * finished either successfully, with failure, or by cancellation, the future is
 * marked as completed with more specific information, such as the cause of the
 * failure.  Please note that even failure and cancellation belong to the
 * completed state.
 * <pre>
 *                                      +---------------------------+
 *                                      | Completed successfully    |
 *                                      +---------------------------+
 *                                 +---->      isDone() = true      |
 * +--------------------------+    |    |   isSuccess() = true      |
 * |        Uncompleted       |    |    +===========================+
 * +--------------------------+    |    | Completed with failure    |
 * |      isDone() = false    |    |    +---------------------------+
 * |   isSuccess() = false    |----+---->      isDone() = true      |
 * | isCancelled() = false    |    |    |       cause() = non-null  |
 * |       cause() = null     |    |    +===========================+
 * +--------------------------+    |    | Completed by cancellation |
 *                                 |    +---------------------------+
 *                                 +---->      isDone() = true      |
 *                                      | isCancelled() = true      |
 *                                      +---------------------------+
 * </pre>
 *
 * Various methods are provided to let you check if the I/O operation has been
 * completed, wait for the completion, and retrieve the result of the I/O
 * operation. It also allows you to add {@link ChannelFutureListener}s so you
 * can get notified when the I/O operation is completed.
 *
 * <h3>Prefer {@link #addListener(GenericFutureListener)} to {@link #await()}</h3>
 *
 * It is recommended to prefer {@link #addListener(GenericFutureListener)} to
 * {@link #await()} wherever possible to get notified when an I/O operation is
 * done and to do any follow-up tasks.
 * <p>
 * {@link #addListener(GenericFutureListener)} is non-blocking.  It simply adds
 * the specified {@link ChannelFutureListener} to the {@link ChannelFuture}, and
 * I/O thread will notify the listeners when the I/O operation associated with
 * the future is done.  {@link ChannelFutureListener} yields the best
 * performance and resource utilization because it does not block at all, but
 * it could be tricky to implement a sequential logic if you are not used to
 * event-driven programming.
 * <p>
 * By contrast, {@link #await()} is a blocking operation.  Once called, the
 * caller thread blocks until the operation is done.  It is easier to implement
 * a sequential logic with {@link #await()}, but the caller thread blocks
 * unnecessarily until the I/O operation is done and there's relatively
 * expensive cost of inter-thread notification.  Moreover, there's a chance of
 * dead lock in a particular circumstance, which is described below.
 *
 * <h3>Do not call {@link #await()} inside {@link ChannelHandler}</h3>
 * <p>
 * The event handler methods in {@link ChannelHandler} are usually called by
 * an I/O thread.  If {@link #await()} is called by an event handler
 * method, which is called by the I/O thread, the I/O operation it is waiting
 * for might never complete because {@link #await()} can block the I/O
 * operation it is waiting for, which is a dead lock.
 * <pre>
 * // BAD - NEVER DO THIS
 * {@code @Override}
 * public void channelRead({@link ChannelHandlerContext} ctx, Object msg) {
 *     {@link ChannelFuture} future = ctx.channel().close();
 *     future.awaitUninterruptibly();
 *     // Perform post-closure operation
 *     // ...
 * }
 *
 * // GOOD
 * {@code @Override}
 * public void channelRead({@link ChannelHandlerContext} ctx, Object msg) {
 *     {@link ChannelFuture} future = ctx.channel().close();
 *     future.addListener(new {@link ChannelFutureListener}() {
 *         public void operationComplete({@link ChannelFuture} future) {
 *             // Perform post-closure operation
 *             // ...
 *         }
 *     });
 * }
 * </pre>
 * <p>
 * In spite of the disadvantages mentioned above, there are certainly the cases
 * where it is more convenient to call {@link #await()}. In such a case, please
 * make sure you do not call {@link #await()} in an I/O thread.  Otherwise,
 * {@link BlockingOperationException} will be raised to prevent a dead lock.
 *
 * <h3>Do not confuse I/O timeout and await timeout</h3>
 *
 * The timeout value you specify with {@link #await(long)},
 * {@link #await(long, TimeUnit)}, {@link #awaitUninterruptibly(long)}, or
 * {@link #awaitUninterruptibly(long, TimeUnit)} are not related with I/O
 * timeout at all.  If an I/O operation times out, the future will be marked as
 * 'completed with failure,' as depicted in the diagram above.  For example,
 * connect timeout should be configured via a transport-specific option:
 * <pre>
 * // BAD - NEVER DO THIS
 * {@link Bootstrap} b = ...;
 * {@link ChannelFuture} f = b.connect(...);
 * f.awaitUninterruptibly(10, TimeUnit.SECONDS);
 * if (f.isCancelled()) {
 *     // Connection attempt cancelled by user
 * } else if (!f.isSuccess()) {
 *     // You might get a NullPointerException here because the future
 *     // might not be completed yet.
 *     f.cause().printStackTrace();
 * } else {
 *     // Connection established successfully
 * }
 *
 * // GOOD
 * {@link Bootstrap} b = ...;
 * // Configure the connect timeout option.
 * b.option({@link ChannelOption}.CONNECT_TIMEOUT_MILLIS, 10000);
 * {@link ChannelFuture} f = b.connect(...);
 * f.awaitUninterruptibly();
 *
 * // Now we are sure the future is completed.
 * assert f.isDone();
 *
 * if (f.isCancelled()) {
 *     // Connection attempt cancelled by user
 * } else if (!f.isSuccess()) {
 *     f.cause().printStackTrace();
 * } else {
 *     // Connection established successfully
 * }
 * </pre>
 */
public interface ChannelFuture extends Future<Void> {

    /**
     * Returns a channel where the I/O operation associated with this
     * future takes place.
     */
    Channel channel();

    @Override
    ChannelFuture addListener(GenericFutureListener<? extends Future<? super Void>> listener);

    @Override
    ChannelFuture addListeners(GenericFutureListener<? extends Future<? super Void>>... listeners);

    @Override
    ChannelFuture removeListener(GenericFutureListener<? extends Future<? super Void>> listener);

    @Override
    ChannelFuture removeListeners(GenericFutureListener<? extends Future<? super Void>>... listeners);

    @Override
    ChannelFuture sync() throws InterruptedException;

    @Override
    ChannelFuture syncUninterruptibly();

    @Override
    ChannelFuture await() throws InterruptedException;

    @Override
    ChannelFuture awaitUninterruptibly();

    /**
     * Returns {@code true} if this {@link ChannelFuture} is a void future and so not allow to call any of the
     * following methods:
     * <ul>
     *     <li>{@link #addListener(GenericFutureListener)}</li>
     *     <li>{@link #addListeners(GenericFutureListener[])}</li>
     *     <li>{@link #await()}</li>
     *     <li>{@link #await(long, TimeUnit)}</li>
     *     <li>{@link #await(long)}</li>
     *     <li>{@link #awaitUninterruptibly()}</li>
     *     <li>{@link #sync()}</li>
     *     <li>{@link #syncUninterruptibly()}</li>
     * </ul>
     */
    boolean isVoid();
}

Return the control to the caller quickly

+ * + * {@link #operationComplete(Future)} is directly called by an I/O + * thread. Therefore, performing a time consuming task or a blocking operation + * in the handler method can cause an unexpected pause during I/O. If you need + * to perform a blocking operation on I/O completion, try to execute the + * operation in a different thread using a thread pool. + */ +public interface ChannelFutureListener extends GenericFutureListener { + + /** + * A {@link ChannelFutureListener} that closes the {@link Channel} which is + * associated with the specified {@link ChannelFuture}. + */ + ChannelFutureListener CLOSE = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + future.channel().close(); + } + }; + + /** + * A {@link ChannelFutureListener} that closes the {@link Channel} when the + * operation ended up with a failure or cancellation rather than a success. + */ + ChannelFutureListener CLOSE_ON_FAILURE = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + future.channel().close(); + } + } + }; + + /** + * A {@link ChannelFutureListener} that forwards the {@link Throwable} of the {@link ChannelFuture} into the + * {@link ChannelPipeline}. This mimics the old behavior of Netty 3. 
+ */ + ChannelFutureListener FIRE_EXCEPTION_ON_FAILURE = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + future.channel().pipeline().fireExceptionCaught(future.cause()); + } + } + }; + + // Just a type alias +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelHandler.java b/netty-channel/src/main/java/io/netty/channel/ChannelHandler.java new file mode 100644 index 0000000..36d668d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelHandler.java @@ -0,0 +1,219 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; + +import java.lang.annotation.Documented; +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Handles an I/O event or intercepts an I/O operation, and forwards it to its next handler in + * its {@link ChannelPipeline}. + * + *

Sub-types

+ *

+ * {@link ChannelHandler} itself does not provide many methods, but you usually have to implement one of its subtypes: + *

    + *
  • {@link ChannelInboundHandler} to handle inbound I/O events, and
  • + *
  • {@link ChannelOutboundHandler} to handle outbound I/O operations.
  • + *
+ *

+ *

+ * Alternatively, the following adapter classes are provided for your convenience: + *

    + *
  • {@link ChannelInboundHandlerAdapter} to handle inbound I/O events,
  • + *
  • {@link ChannelOutboundHandlerAdapter} to handle outbound I/O operations, and
  • + *
  • {@link ChannelDuplexHandler} to handle both inbound and outbound events
  • + *
+ *

+ *

+ * For more information, please refer to the documentation of each subtype. + *

+ * + *

The context object

+ *

+ * A {@link ChannelHandler} is provided with a {@link ChannelHandlerContext} + * object. A {@link ChannelHandler} is supposed to interact with the + * {@link ChannelPipeline} it belongs to via a context object. Using the + * context object, the {@link ChannelHandler} can pass events upstream or + * downstream, modify the pipeline dynamically, or store the information + * (using {@link AttributeKey}s) which is specific to the handler. + * + *

State management

+ * + * A {@link ChannelHandler} often needs to store some stateful information. + * The simplest and recommended approach is to use member variables: + *
+ * public interface Message {
+ *     // your methods here
+ * }
+ *
+ * public class DataServerHandler extends {@link SimpleChannelInboundHandler}<Message> {
+ *
+ *     private boolean loggedIn;
+ *
+ *     {@code @Override}
+ *     public void channelRead0({@link ChannelHandlerContext} ctx, Message message) {
+ *         if (message instanceof LoginMessage) {
+ *             authenticate((LoginMessage) message);
+ *             loggedIn = true;
+ *         } else (message instanceof GetDataMessage) {
+ *             if (loggedIn) {
+ *                 ctx.writeAndFlush(fetchSecret((GetDataMessage) message));
+ *             } else {
+ *                 fail();
+ *             }
+ *         }
+ *     }
+ *     ...
+ * }
+ * 
+ * Because the handler instance has a state variable which is dedicated to + * one connection, you have to create a new handler instance for each new + * channel to avoid a race condition where an unauthenticated client can get + * the confidential information: + *
+ * // Create a new handler instance per channel.
+ * // See {@link ChannelInitializer#initChannel(Channel)}.
+ * public class DataServerInitializer extends {@link ChannelInitializer}<{@link Channel}> {
+ *     {@code @Override}
+ *     public void initChannel({@link Channel} channel) {
+ *         channel.pipeline().addLast("handler", new DataServerHandler());
+ *     }
+ * }
+ *
+ * 
+ * + *

Using {@link AttributeKey}s

+ * + * Although it's recommended to use member variables to store the state of a + * handler, for some reason you might not want to create many handler instances. + * In such a case, you can use {@link AttributeKey}s which is provided by + * {@link ChannelHandlerContext}: + *
+ * public interface Message {
+ *     // your methods here
+ * }
+ *
+ * {@code @Sharable}
+ * public class DataServerHandler extends {@link SimpleChannelInboundHandler}<Message> {
+ *     private final {@link AttributeKey}<{@link Boolean}> auth =
+ *           {@link AttributeKey#valueOf(String) AttributeKey.valueOf("auth")};
+ *
+ *     {@code @Override}
+ *     public void channelRead({@link ChannelHandlerContext} ctx, Message message) {
+ *         {@link Attribute}<{@link Boolean}> attr = ctx.attr(auth);
+ *         if (message instanceof LoginMessage) {
+ *             authenticate((LoginMessage) o);
+ *             attr.set(true);
+ *         } else (message instanceof GetDataMessage) {
+ *             if (Boolean.TRUE.equals(attr.get())) {
+ *                 ctx.writeAndFlush(fetchSecret((GetDataMessage) o));
+ *             } else {
+ *                 fail();
+ *             }
+ *         }
+ *     }
+ *     ...
+ * }
+ * 
+ * Now that the state of the handler is attached to the {@link ChannelHandlerContext}, you can add the + * same handler instance to different pipelines: + *
+ * public class DataServerInitializer extends {@link ChannelInitializer}<{@link Channel}> {
+ *
+ *     private static final DataServerHandler SHARED = new DataServerHandler();
+ *
+ *     {@code @Override}
+ *     public void initChannel({@link Channel} channel) {
+ *         channel.pipeline().addLast("handler", SHARED);
+ *     }
+ * }
+ * 
+ * + * + *

The {@code @Sharable} annotation

+ *

+ * In the example above which used an {@link AttributeKey}, + * you might have noticed the {@code @Sharable} annotation. + *

+ * If a {@link ChannelHandler} is annotated with the {@code @Sharable} + * annotation, it means you can create an instance of the handler just once and + * add it to one or more {@link ChannelPipeline}s multiple times without + * a race condition. + *

+ * If this annotation is not specified, you have to create a new handler + * instance every time you add it to a pipeline because it has unshared state + * such as member variables. + *

+ * This annotation is provided for documentation purpose, just like + * the JCIP annotations. + * + *

Additional resources worth reading

+ *

+ * Please refer to the {@link ChannelHandler}, and + * {@link ChannelPipeline} to find out more about inbound and outbound operations, + * what fundamental differences they have, how they flow in a pipeline, and how to handle + * the operation in your application. + */ +public interface ChannelHandler { + + /** + * Gets called after the {@link ChannelHandler} was added to the actual context and it's ready to handle events. + */ + void handlerAdded(ChannelHandlerContext ctx) throws Exception; + + /** + * Gets called after the {@link ChannelHandler} was removed from the actual context and it doesn't handle events + * anymore. + */ + void handlerRemoved(ChannelHandlerContext ctx) throws Exception; + + /** + * Gets called if a {@link Throwable} was thrown. + * + * @deprecated if you want to handle this event you should implement {@link ChannelInboundHandler} and + * implement the method there. + */ + @Deprecated + void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception; + + /** + * Indicates that the same instance of the annotated {@link ChannelHandler} + * can be added to one or more {@link ChannelPipeline}s multiple times + * without a race condition. + *

+ * If this annotation is not specified, you have to create a new handler + * instance every time you add it to a pipeline because it has unshared + * state such as member variables. + *

+ * This annotation is provided for documentation purpose, just like + * the JCIP annotations. + */ + @Inherited + @Documented + @Target(ElementType.TYPE) + @Retention(RetentionPolicy.RUNTIME) + @interface Sharable { + // no value + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelHandlerAdapter.java b/netty-channel/src/main/java/io/netty/channel/ChannelHandlerAdapter.java new file mode 100644 index 0000000..c82cefc --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelHandlerAdapter.java @@ -0,0 +1,94 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel; + +import io.netty.channel.ChannelHandlerMask.Skip; +import io.netty.util.internal.InternalThreadLocalMap; + +import java.util.Map; + +/** + * Skeleton implementation of a {@link ChannelHandler}. + */ +public abstract class ChannelHandlerAdapter implements ChannelHandler { + + // Not using volatile because it's used only for a sanity check. 
+ boolean added; + + /** + * Throws {@link IllegalStateException} if {@link ChannelHandlerAdapter#isSharable()} returns {@code true} + */ + protected void ensureNotSharable() { + if (isSharable()) { + throw new IllegalStateException("ChannelHandler " + getClass().getName() + " is not allowed to be shared"); + } + } + + /** + * Return {@code true} if the implementation is {@link Sharable} and so can be added + * to different {@link ChannelPipeline}s. + */ + public boolean isSharable() { + /** + * Cache the result of {@link Sharable} annotation detection to workaround a condition. We use a + * {@link ThreadLocal} and {@link WeakHashMap} to eliminate the volatile write/reads. Using different + * {@link WeakHashMap} instances per {@link Thread} is good enough for us and the number of + * {@link Thread}s are quite limited anyway. + * + * See #2289. + */ + Class clazz = getClass(); + Map, Boolean> cache = InternalThreadLocalMap.get().handlerSharableCache(); + Boolean sharable = cache.get(clazz); + if (sharable == null) { + sharable = clazz.isAnnotationPresent(Sharable.class); + cache.put(clazz, sharable); + } + return sharable; + } + + /** + * Do nothing by default, sub-classes may override this method. + */ + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + // NOOP + } + + /** + * Do nothing by default, sub-classes may override this method. + */ + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // NOOP + } + + /** + * Calls {@link ChannelHandlerContext#fireExceptionCaught(Throwable)} to forward + * to the next {@link ChannelHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ * + * @deprecated is part of {@link ChannelInboundHandler} + */ + @Skip + @Override + @Deprecated + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.fireExceptionCaught(cause); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelHandlerContext.java b/netty-channel/src/main/java/io/netty/channel/ChannelHandlerContext.java new file mode 100644 index 0000000..c81b0dd --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelHandlerContext.java @@ -0,0 +1,174 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.AttributeMap; +import io.netty.util.concurrent.EventExecutor; + +/** + * Enables a {@link ChannelHandler} to interact with its {@link ChannelPipeline} + * and other handlers. Among other things a handler can notify the next {@link ChannelHandler} in the + * {@link ChannelPipeline} as well as modify the {@link ChannelPipeline} it belongs to dynamically. + * + *

Notify

+ * + * You can notify the closest handler in the same {@link ChannelPipeline} by calling one of the various methods + * provided here. + * + * Please refer to {@link ChannelPipeline} to understand how an event flows. + * + *

Modifying a pipeline

+ * + * You can get the {@link ChannelPipeline} your handler belongs to by calling + * {@link #pipeline()}. A non-trivial application could insert, remove, or + * replace handlers in the pipeline dynamically at runtime. + * + *

Retrieving for later use

+ * + * You can keep the {@link ChannelHandlerContext} for later use, such as + * triggering an event outside the handler methods, even from a different thread. + *
+ * public class MyHandler extends {@link ChannelDuplexHandler} {
+ *
+ *     private {@link ChannelHandlerContext} ctx;
+ *
+ *     public void beforeAdd({@link ChannelHandlerContext} ctx) {
+ *         this.ctx = ctx;
+ *     }
+ *
+ *     public void login(String username, password) {
+ *         ctx.write(new LoginMessage(username, password));
+ *     }
+ *     ...
+ * }
+ * 
+ * + *

Storing stateful information

+ * + * {@link #attr(AttributeKey)} allow you to + * store and access stateful information that is related with a {@link ChannelHandler} / {@link Channel} and its + * context. Please refer to {@link ChannelHandler} to learn various recommended + * ways to manage stateful information. + * + *

A handler can have more than one {@link ChannelHandlerContext}

+ * + * Please note that a {@link ChannelHandler} instance can be added to more than + * one {@link ChannelPipeline}. It means a single {@link ChannelHandler} + * instance can have more than one {@link ChannelHandlerContext} and therefore + * the single instance can be invoked with different + * {@link ChannelHandlerContext}s if it is added to one or more {@link ChannelPipeline}s more than once. + * Also note that a {@link ChannelHandler} that is supposed to be added to multiple {@link ChannelPipeline}s should + * be marked as {@link io.netty.channel.ChannelHandler.Sharable}. + * + *

Additional resources worth reading

+ *

+ * Please refer to the {@link ChannelHandler}, and + * {@link ChannelPipeline} to find out more about inbound and outbound operations, + * what fundamental differences they have, how they flow in a pipeline, and how to handle + * the operation in your application. + */ +public interface ChannelHandlerContext extends AttributeMap, ChannelInboundInvoker, ChannelOutboundInvoker { + + /** + * Return the {@link Channel} which is bound to the {@link ChannelHandlerContext}. + */ + Channel channel(); + + /** + * Returns the {@link EventExecutor} which is used to execute an arbitrary task. + */ + EventExecutor executor(); + + /** + * The unique name of the {@link ChannelHandlerContext}.The name was used when then {@link ChannelHandler} + * was added to the {@link ChannelPipeline}. This name can also be used to access the registered + * {@link ChannelHandler} from the {@link ChannelPipeline}. + */ + String name(); + + /** + * The {@link ChannelHandler} that is bound this {@link ChannelHandlerContext}. + */ + ChannelHandler handler(); + + /** + * Return {@code true} if the {@link ChannelHandler} which belongs to this context was removed + * from the {@link ChannelPipeline}. Note that this method is only meant to be called from with in the + * {@link EventLoop}. 
+ */ + boolean isRemoved(); + + @Override + ChannelHandlerContext fireChannelRegistered(); + + @Override + ChannelHandlerContext fireChannelUnregistered(); + + @Override + ChannelHandlerContext fireChannelActive(); + + @Override + ChannelHandlerContext fireChannelInactive(); + + @Override + ChannelHandlerContext fireExceptionCaught(Throwable cause); + + @Override + ChannelHandlerContext fireUserEventTriggered(Object evt); + + @Override + ChannelHandlerContext fireChannelRead(Object msg); + + @Override + ChannelHandlerContext fireChannelReadComplete(); + + @Override + ChannelHandlerContext fireChannelWritabilityChanged(); + + @Override + ChannelHandlerContext read(); + + @Override + ChannelHandlerContext flush(); + + /** + * Return the assigned {@link ChannelPipeline} + */ + ChannelPipeline pipeline(); + + /** + * Return the assigned {@link ByteBufAllocator} which will be used to allocate {@link ByteBuf}s. + */ + ByteBufAllocator alloc(); + + /** + * @deprecated Use {@link Channel#attr(AttributeKey)} + */ + @Deprecated + @Override + Attribute attr(AttributeKey key); + + /** + * @deprecated Use {@link Channel#hasAttr(AttributeKey)} + */ + @Deprecated + @Override + boolean hasAttr(AttributeKey key); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelHandlerMask.java b/netty-channel/src/main/java/io/netty/channel/ChannelHandlerMask.java new file mode 100644 index 0000000..837adec --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelHandlerMask.java @@ -0,0 +1,205 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.PlatformDependent; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.lang.annotation.ElementType; +import java.lang.annotation.Inherited; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; +import java.lang.reflect.Method; +import java.net.SocketAddress; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.Map; +import java.util.WeakHashMap; + +final class ChannelHandlerMask { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelHandlerMask.class); + + // Using to mask which methods must be called for a ChannelHandler. 
+ static final int MASK_EXCEPTION_CAUGHT = 1; + static final int MASK_CHANNEL_REGISTERED = 1 << 1; + static final int MASK_CHANNEL_UNREGISTERED = 1 << 2; + static final int MASK_CHANNEL_ACTIVE = 1 << 3; + static final int MASK_CHANNEL_INACTIVE = 1 << 4; + static final int MASK_CHANNEL_READ = 1 << 5; + static final int MASK_CHANNEL_READ_COMPLETE = 1 << 6; + static final int MASK_USER_EVENT_TRIGGERED = 1 << 7; + static final int MASK_CHANNEL_WRITABILITY_CHANGED = 1 << 8; + static final int MASK_BIND = 1 << 9; + static final int MASK_CONNECT = 1 << 10; + static final int MASK_DISCONNECT = 1 << 11; + static final int MASK_CLOSE = 1 << 12; + static final int MASK_DEREGISTER = 1 << 13; + static final int MASK_READ = 1 << 14; + static final int MASK_WRITE = 1 << 15; + static final int MASK_FLUSH = 1 << 16; + + static final int MASK_ONLY_INBOUND = MASK_CHANNEL_REGISTERED | + MASK_CHANNEL_UNREGISTERED | MASK_CHANNEL_ACTIVE | MASK_CHANNEL_INACTIVE | MASK_CHANNEL_READ | + MASK_CHANNEL_READ_COMPLETE | MASK_USER_EVENT_TRIGGERED | MASK_CHANNEL_WRITABILITY_CHANGED; + private static final int MASK_ALL_INBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_INBOUND; + static final int MASK_ONLY_OUTBOUND = MASK_BIND | MASK_CONNECT | MASK_DISCONNECT | + MASK_CLOSE | MASK_DEREGISTER | MASK_READ | MASK_WRITE | MASK_FLUSH; + private static final int MASK_ALL_OUTBOUND = MASK_EXCEPTION_CAUGHT | MASK_ONLY_OUTBOUND; + + private static final FastThreadLocal, Integer>> MASKS = + new FastThreadLocal, Integer>>() { + @Override + protected Map, Integer> initialValue() { + return new WeakHashMap, Integer>(32); + } + }; + + /** + * Return the {@code executionMask}. + */ + static int mask(Class clazz) { + // Try to obtain the mask from the cache first. If this fails calculate it and put it in the cache for fast + // lookup in the future. 
+ Map, Integer> cache = MASKS.get(); + Integer mask = cache.get(clazz); + if (mask == null) { + mask = mask0(clazz); + cache.put(clazz, mask); + } + return mask; + } + + /** + * Calculate the {@code executionMask}. + */ + private static int mask0(Class handlerType) { + int mask = MASK_EXCEPTION_CAUGHT; + try { + if (ChannelInboundHandler.class.isAssignableFrom(handlerType)) { + mask |= MASK_ALL_INBOUND; + + if (isSkippable(handlerType, "channelRegistered", ChannelHandlerContext.class)) { + mask &= ~MASK_CHANNEL_REGISTERED; + } + if (isSkippable(handlerType, "channelUnregistered", ChannelHandlerContext.class)) { + mask &= ~MASK_CHANNEL_UNREGISTERED; + } + if (isSkippable(handlerType, "channelActive", ChannelHandlerContext.class)) { + mask &= ~MASK_CHANNEL_ACTIVE; + } + if (isSkippable(handlerType, "channelInactive", ChannelHandlerContext.class)) { + mask &= ~MASK_CHANNEL_INACTIVE; + } + if (isSkippable(handlerType, "channelRead", ChannelHandlerContext.class, Object.class)) { + mask &= ~MASK_CHANNEL_READ; + } + if (isSkippable(handlerType, "channelReadComplete", ChannelHandlerContext.class)) { + mask &= ~MASK_CHANNEL_READ_COMPLETE; + } + if (isSkippable(handlerType, "channelWritabilityChanged", ChannelHandlerContext.class)) { + mask &= ~MASK_CHANNEL_WRITABILITY_CHANGED; + } + if (isSkippable(handlerType, "userEventTriggered", ChannelHandlerContext.class, Object.class)) { + mask &= ~MASK_USER_EVENT_TRIGGERED; + } + } + + if (ChannelOutboundHandler.class.isAssignableFrom(handlerType)) { + mask |= MASK_ALL_OUTBOUND; + + if (isSkippable(handlerType, "bind", ChannelHandlerContext.class, + SocketAddress.class, ChannelPromise.class)) { + mask &= ~MASK_BIND; + } + if (isSkippable(handlerType, "connect", ChannelHandlerContext.class, SocketAddress.class, + SocketAddress.class, ChannelPromise.class)) { + mask &= ~MASK_CONNECT; + } + if (isSkippable(handlerType, "disconnect", ChannelHandlerContext.class, ChannelPromise.class)) { + mask &= ~MASK_DISCONNECT; + } + if 
(isSkippable(handlerType, "close", ChannelHandlerContext.class, ChannelPromise.class)) { + mask &= ~MASK_CLOSE; + } + if (isSkippable(handlerType, "deregister", ChannelHandlerContext.class, ChannelPromise.class)) { + mask &= ~MASK_DEREGISTER; + } + if (isSkippable(handlerType, "read", ChannelHandlerContext.class)) { + mask &= ~MASK_READ; + } + if (isSkippable(handlerType, "write", ChannelHandlerContext.class, + Object.class, ChannelPromise.class)) { + mask &= ~MASK_WRITE; + } + if (isSkippable(handlerType, "flush", ChannelHandlerContext.class)) { + mask &= ~MASK_FLUSH; + } + } + + if (isSkippable(handlerType, "exceptionCaught", ChannelHandlerContext.class, Throwable.class)) { + mask &= ~MASK_EXCEPTION_CAUGHT; + } + } catch (Exception e) { + // Should never reach here. + PlatformDependent.throwException(e); + } + + return mask; + } + + @SuppressWarnings("rawtypes") + private static boolean isSkippable( + final Class handlerType, final String methodName, final Class... paramTypes) throws Exception { + return AccessController.doPrivileged(new PrivilegedExceptionAction() { + @Override + public Boolean run() throws Exception { + Method m; + try { + m = handlerType.getMethod(methodName, paramTypes); + } catch (NoSuchMethodException e) { + if (logger.isDebugEnabled()) { + logger.debug( + "Class {} missing method {}, assume we can not skip execution", handlerType, methodName, e); + } + return false; + } + return m.isAnnotationPresent(Skip.class); + } + }); + } + + private ChannelHandlerMask() { } + + /** + * Indicates that the annotated event handler method in {@link ChannelHandler} will not be invoked by + * {@link ChannelPipeline} and so MUST only be used when the {@link ChannelHandler} + * method does nothing except forward to the next {@link ChannelHandler} in the pipeline. + *

+ * <p>
+ * Note that this annotation is not {@linkplain Inherited inherited}. If a user overrides a method annotated with
+ * {@link Skip}, it will not be skipped anymore. Similarly, the user can override a method not annotated with
+ * {@link Skip} and simply pass the event through to the next handler, which reverses the behavior of the
+ * supertype.
+ * </p>
+ */ + @Target(ElementType.METHOD) + @Retention(RetentionPolicy.RUNTIME) + @interface Skip { + // no value + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelId.java b/netty-channel/src/main/java/io/netty/channel/ChannelId.java new file mode 100644 index 0000000..d8d8dd8 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelId.java @@ -0,0 +1,56 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel; + +import java.io.Serializable; + +/** + * Represents the globally unique identifier of a {@link Channel}. + *

+ * The identifier is generated from various sources listed in the following:
+ * <ul>
+ * <li>the MAC address (EUI-48 or EUI-64) of the network adapter, preferably a globally unique one,</li>
+ * <li>the current process ID,</li>
+ * <li>{@link System#currentTimeMillis()},</li>
+ * <li>{@link System#nanoTime()},</li>
+ * <li>a random 32-bit integer, and</li>
+ * <li>a sequentially incremented 32-bit integer.</li>
+ * </ul>
+ * <p>
+ * The global uniqueness of the generated identifier mostly depends on the MAC address and the current process ID,
+ * which are auto-detected at the class-loading time in best-effort manner. If all attempts to acquire them fail,
+ * a warning message is logged, and random values will be used instead. Alternatively, you can specify them manually
+ * via system properties:
+ * <ul>
+ * <li>{@code io.netty.machineId} - hexadecimal representation of 48 (or 64) bit integer,
+ *     optionally separated by colon or hyphen.</li>
+ * <li>{@code io.netty.processId} - an integer between 0 and 65535</li>
+ * </ul>
+ */ +public interface ChannelId extends Serializable, Comparable { + /** + * Returns the short but globally non-unique string representation of the {@link ChannelId}. + */ + String asShortText(); + + /** + * Returns the long yet globally unique string representation of the {@link ChannelId}. + */ + String asLongText(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelInboundHandler.java b/netty-channel/src/main/java/io/netty/channel/ChannelInboundHandler.java new file mode 100644 index 0000000..0bf136d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelInboundHandler.java @@ -0,0 +1,75 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +/** + * {@link ChannelHandler} which adds callbacks for state changes. This allows the user + * to hook in to state changes easily. 
+ */ +public interface ChannelInboundHandler extends ChannelHandler { + + /** + * The {@link Channel} of the {@link ChannelHandlerContext} was registered with its {@link EventLoop} + */ + void channelRegistered(ChannelHandlerContext ctx) throws Exception; + + /** + * The {@link Channel} of the {@link ChannelHandlerContext} was unregistered from its {@link EventLoop} + */ + void channelUnregistered(ChannelHandlerContext ctx) throws Exception; + + /** + * The {@link Channel} of the {@link ChannelHandlerContext} is now active + */ + void channelActive(ChannelHandlerContext ctx) throws Exception; + + /** + * The {@link Channel} of the {@link ChannelHandlerContext} was registered is now inactive and reached its + * end of lifetime. + */ + void channelInactive(ChannelHandlerContext ctx) throws Exception; + + /** + * Invoked when the current {@link Channel} has read a message from the peer. + */ + void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception; + + /** + * Invoked when the last message read by the current read operation has been consumed by + * {@link #channelRead(ChannelHandlerContext, Object)}. If {@link ChannelOption#AUTO_READ} is off, no further + * attempt to read an inbound data from the current {@link Channel} will be made until + * {@link ChannelHandlerContext#read()} is called. + */ + void channelReadComplete(ChannelHandlerContext ctx) throws Exception; + + /** + * Gets called if an user event was triggered. + */ + void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception; + + /** + * Gets called once the writable state of a {@link Channel} changed. You can check the state with + * {@link Channel#isWritable()}. + */ + void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception; + + /** + * Gets called if a {@link Throwable} was thrown. 
+ */ + @Override + @SuppressWarnings("deprecation") + void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelInboundHandlerAdapter.java b/netty-channel/src/main/java/io/netty/channel/ChannelInboundHandlerAdapter.java new file mode 100644 index 0000000..3fe68d3 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelInboundHandlerAdapter.java @@ -0,0 +1,145 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.ChannelHandlerMask.Skip; + +/** + * Abstract base class for {@link ChannelInboundHandler} implementations which provide + * implementations of all of their methods. + * + *

+ * <p>
+ * This implementation just forwards the operation to the next {@link ChannelHandler} in the
+ * {@link ChannelPipeline}. Sub-classes may override a method implementation to change this.
+ * </p>
+ * <p>
+ * Be aware that messages are not released after the {@link #channelRead(ChannelHandlerContext, Object)}
+ * method returns automatically. If you are looking for a {@link ChannelInboundHandler} implementation that
+ * releases the received messages automatically, please see {@link SimpleChannelInboundHandler}.
+ * </p>
+ */ +public class ChannelInboundHandlerAdapter extends ChannelHandlerAdapter implements ChannelInboundHandler { + + /** + * Calls {@link ChannelHandlerContext#fireChannelRegistered()} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelRegistered(); + } + + /** + * Calls {@link ChannelHandlerContext#fireChannelUnregistered()} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelUnregistered(); + } + + /** + * Calls {@link ChannelHandlerContext#fireChannelActive()} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelActive(); + } + + /** + * Calls {@link ChannelHandlerContext#fireChannelInactive()} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelInactive(); + } + + /** + * Calls {@link ChannelHandlerContext#fireChannelRead(Object)} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ */ + @Skip + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.fireChannelRead(msg); + } + + /** + * Calls {@link ChannelHandlerContext#fireChannelReadComplete()} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelReadComplete(); + } + + /** + * Calls {@link ChannelHandlerContext#fireUserEventTriggered(Object)} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + ctx.fireUserEventTriggered(evt); + } + + /** + * Calls {@link ChannelHandlerContext#fireChannelWritabilityChanged()} to forward + * to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelWritabilityChanged(); + } + + /** + * Calls {@link ChannelHandlerContext#fireExceptionCaught(Throwable)} to forward + * to the next {@link ChannelHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ */ + @Skip + @Override + @SuppressWarnings("deprecation") + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + throws Exception { + ctx.fireExceptionCaught(cause); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelInboundInvoker.java b/netty-channel/src/main/java/io/netty/channel/ChannelInboundInvoker.java new file mode 100644 index 0000000..4949089 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelInboundInvoker.java @@ -0,0 +1,94 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +public interface ChannelInboundInvoker { + + /** + * A {@link Channel} was registered to its {@link EventLoop}. + * + * This will result in having the {@link ChannelInboundHandler#channelRegistered(ChannelHandlerContext)} method + * called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelInboundInvoker fireChannelRegistered(); + + /** + * A {@link Channel} was unregistered from its {@link EventLoop}. + * + * This will result in having the {@link ChannelInboundHandler#channelUnregistered(ChannelHandlerContext)} method + * called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. 
+ */ + ChannelInboundInvoker fireChannelUnregistered(); + + /** + * A {@link Channel} is active now, which means it is connected. + * + * This will result in having the {@link ChannelInboundHandler#channelActive(ChannelHandlerContext)} method + * called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelInboundInvoker fireChannelActive(); + + /** + * A {@link Channel} is inactive now, which means it is closed. + * + * This will result in having the {@link ChannelInboundHandler#channelInactive(ChannelHandlerContext)} method + * called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelInboundInvoker fireChannelInactive(); + + /** + * A {@link Channel} received an {@link Throwable} in one of its inbound operations. + * + * This will result in having the {@link ChannelInboundHandler#exceptionCaught(ChannelHandlerContext, Throwable)} + * method called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelInboundInvoker fireExceptionCaught(Throwable cause); + + /** + * A {@link Channel} received an user defined event. + * + * This will result in having the {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} + * method called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelInboundInvoker fireUserEventTriggered(Object event); + + /** + * A {@link Channel} received a message. + * + * This will result in having the {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object)} + * method called of the next {@link ChannelInboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. 
+ */ + ChannelInboundInvoker fireChannelRead(Object msg); + + /** + * Triggers an {@link ChannelInboundHandler#channelReadComplete(ChannelHandlerContext)} + * event to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + */ + ChannelInboundInvoker fireChannelReadComplete(); + + /** + * Triggers an {@link ChannelInboundHandler#channelWritabilityChanged(ChannelHandlerContext)} + * event to the next {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + */ + ChannelInboundInvoker fireChannelWritabilityChanged(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelInitializer.java b/netty-channel/src/main/java/io/netty/channel/ChannelInitializer.java new file mode 100644 index 0000000..61d9112 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelInitializer.java @@ -0,0 +1,159 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A special {@link ChannelInboundHandler} which offers an easy way to initialize a {@link Channel} once it was + * registered to its {@link EventLoop}. + * + * Implementations are most often used in the context of {@link Bootstrap#handler(ChannelHandler)} , + * {@link ServerBootstrap#handler(ChannelHandler)} and {@link ServerBootstrap#childHandler(ChannelHandler)} to + * setup the {@link ChannelPipeline} of a {@link Channel}. + * + *
+ * <pre>
+ * public class MyChannelInitializer extends {@link ChannelInitializer}&lt;{@link Channel}&gt; {
+ *     public void initChannel({@link Channel} channel) {
+ *         channel.pipeline().addLast("myHandler", new MyHandler());
+ *     }
+ * }
+ *
+ * {@link ServerBootstrap} bootstrap = ...;
+ * ...
+ * bootstrap.childHandler(new MyChannelInitializer());
+ * ...
+ * </pre>
+ * Be aware that this class is marked as {@link Sharable} and so the implementation must be safe to be re-used. + * + * @param A sub-type of {@link Channel} + */ +@Sharable +public abstract class ChannelInitializer extends ChannelInboundHandlerAdapter { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelInitializer.class); + // We use a Set as a ChannelInitializer is usually shared between all Channels in a Bootstrap / + // ServerBootstrap. This way we can reduce the memory usage compared to use Attributes. + private final Set initMap = Collections.newSetFromMap( + new ConcurrentHashMap()); + + /** + * This method will be called once the {@link Channel} was registered. After the method returns this instance + * will be removed from the {@link ChannelPipeline} of the {@link Channel}. + * + * @param ch the {@link Channel} which was registered. + * @throws Exception is thrown if an error occurs. In that case it will be handled by + * {@link #exceptionCaught(ChannelHandlerContext, Throwable)} which will by default close + * the {@link Channel}. + */ + protected abstract void initChannel(C ch) throws Exception; + + @Override + @SuppressWarnings("unchecked") + public final void channelRegistered(ChannelHandlerContext ctx) throws Exception { + // Normally this method will never be called as handlerAdded(...) should call initChannel(...) and remove + // the handler. + if (initChannel(ctx)) { + // we called initChannel(...) so we need to call now pipeline.fireChannelRegistered() to ensure we not + // miss an event. + ctx.pipeline().fireChannelRegistered(); + + // We are done with init the Channel, removing all the state for the Channel now. + removeState(ctx); + } else { + // Called initChannel(...) before which is the expected behavior, so just forward the event. + ctx.fireChannelRegistered(); + } + } + + /** + * Handle the {@link Throwable} by logging and closing the {@link Channel}. Sub-classes may override this. 
+ */ + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (logger.isWarnEnabled()) { + logger.warn("Failed to initialize a channel. Closing: " + ctx.channel(), cause); + } + ctx.close(); + } + + /** + * {@inheritDoc} If override this method ensure you call super! + */ + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isRegistered()) { + // This should always be true with our current DefaultChannelPipeline implementation. + // The good thing about calling initChannel(...) in handlerAdded(...) is that there will be no ordering + // surprises if a ChannelInitializer will add another ChannelInitializer. This is as all handlers + // will be added in the expected order. + if (initChannel(ctx)) { + + // We are done with init the Channel, removing the initializer now. + removeState(ctx); + } + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + initMap.remove(ctx); + } + + @SuppressWarnings("unchecked") + private boolean initChannel(ChannelHandlerContext ctx) throws Exception { + if (initMap.add(ctx)) { // Guard against re-entrance. + try { + initChannel((C) ctx.channel()); + } catch (Throwable cause) { + // Explicitly call exceptionCaught(...) as we removed the handler before calling initChannel(...). + // We do so to prevent multiple calls to initChannel(...). + exceptionCaught(ctx, cause); + } finally { + if (!ctx.isRemoved()) { + ctx.pipeline().remove(this); + } + } + return true; + } + return false; + } + + private void removeState(final ChannelHandlerContext ctx) { + // The removal may happen in an async fashion if the EventExecutor we use does something funky. + if (ctx.isRemoved()) { + initMap.remove(ctx); + } else { + // The context is not removed yet which is most likely the case because a custom EventExecutor is used. 
+ // Let's schedule it on the EventExecutor to give it some more time to be completed in case it is offloaded. + ctx.executor().execute(new Runnable() { + @Override + public void run() { + initMap.remove(ctx); + } + }); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelMetadata.java b/netty-channel/src/main/java/io/netty/channel/ChannelMetadata.java new file mode 100644 index 0000000..f03b97b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelMetadata.java @@ -0,0 +1,72 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import java.net.SocketAddress; + +/** + * Represents the properties of a {@link Channel} implementation. + */ +public final class ChannelMetadata { + + private final boolean hasDisconnect; + private final int defaultMaxMessagesPerRead; + + /** + * Create a new instance + * + * @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation + * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} + * again, such as UDP/IP. 
+ */ + public ChannelMetadata(boolean hasDisconnect) { + this(hasDisconnect, 16); + } + + /** + * Create a new instance + * + * @param hasDisconnect {@code true} if and only if the channel has the {@code disconnect()} operation + * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} + * again, such as UDP/IP. + * @param defaultMaxMessagesPerRead If a {@link MaxMessagesRecvByteBufAllocator} is in use, then this value will be + * set for {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead()}. Must be {@code > 0}. + */ + public ChannelMetadata(boolean hasDisconnect, int defaultMaxMessagesPerRead) { + checkPositive(defaultMaxMessagesPerRead, "defaultMaxMessagesPerRead"); + this.hasDisconnect = hasDisconnect; + this.defaultMaxMessagesPerRead = defaultMaxMessagesPerRead; + } + + /** + * Returns {@code true} if and only if the channel has the {@code disconnect()} operation + * that allows a user to disconnect and then call {@link Channel#connect(SocketAddress)} again, + * such as UDP/IP. + */ + public boolean hasDisconnect() { + return hasDisconnect; + } + + /** + * If a {@link MaxMessagesRecvByteBufAllocator} is in use, then this is the default value for + * {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead()}. + */ + public int defaultMaxMessagesPerRead() { + return defaultMaxMessagesPerRead; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelOption.java b/netty-channel/src/main/java/io/netty/channel/ChannelOption.java new file mode 100644 index 0000000..0511285 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelOption.java @@ -0,0 +1,166 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.AbstractConstant; +import io.netty.util.ConstantPool; +import io.netty.util.internal.ObjectUtil; + +import java.net.InetAddress; +import java.net.NetworkInterface; + +/** + * A {@link ChannelOption} allows to configure a {@link ChannelConfig} in a type-safe + * way. Which {@link ChannelOption} is supported depends on the actual implementation + * of {@link ChannelConfig} and may depend on the nature of the transport it belongs + * to. + * + * @param the type of the value which is valid for the {@link ChannelOption} + */ +public class ChannelOption extends AbstractConstant> { + + private static final ConstantPool> pool = new ConstantPool>() { + @Override + protected ChannelOption newConstant(int id, String name) { + return new ChannelOption(id, name); + } + }; + + /** + * Returns the {@link ChannelOption} of the specified name. + */ + @SuppressWarnings("unchecked") + public static ChannelOption valueOf(String name) { + return (ChannelOption) pool.valueOf(name); + } + + /** + * Shortcut of {@link #valueOf(String) valueOf(firstNameComponent.getName() + "#" + secondNameComponent)}. + */ + @SuppressWarnings("unchecked") + public static ChannelOption valueOf(Class firstNameComponent, String secondNameComponent) { + return (ChannelOption) pool.valueOf(firstNameComponent, secondNameComponent); + } + + /** + * Returns {@code true} if a {@link ChannelOption} exists for the given {@code name}. 
+ */ + public static boolean exists(String name) { + return pool.exists(name); + } + + /** + * Creates a new {@link ChannelOption} for the given {@code name} or fail with an + * {@link IllegalArgumentException} if a {@link ChannelOption} for the given {@code name} exists. + * + * @deprecated use {@link #valueOf(String)}. + */ + @Deprecated + @SuppressWarnings("unchecked") + public static ChannelOption newInstance(String name) { + return (ChannelOption) pool.newInstance(name); + } + + public static final ChannelOption ALLOCATOR = valueOf("ALLOCATOR"); + public static final ChannelOption RCVBUF_ALLOCATOR = valueOf("RCVBUF_ALLOCATOR"); + public static final ChannelOption MESSAGE_SIZE_ESTIMATOR = valueOf("MESSAGE_SIZE_ESTIMATOR"); + + public static final ChannelOption CONNECT_TIMEOUT_MILLIS = valueOf("CONNECT_TIMEOUT_MILLIS"); + /** + * @deprecated Use {@link MaxMessagesRecvByteBufAllocator} + * and {@link MaxMessagesRecvByteBufAllocator#maxMessagesPerRead(int)}. + */ + @Deprecated + public static final ChannelOption MAX_MESSAGES_PER_READ = valueOf("MAX_MESSAGES_PER_READ"); + public static final ChannelOption MAX_MESSAGES_PER_WRITE = valueOf("MAX_MESSAGES_PER_WRITE"); + + public static final ChannelOption WRITE_SPIN_COUNT = valueOf("WRITE_SPIN_COUNT"); + /** + * @deprecated Use {@link #WRITE_BUFFER_WATER_MARK} + */ + @Deprecated + public static final ChannelOption WRITE_BUFFER_HIGH_WATER_MARK = valueOf("WRITE_BUFFER_HIGH_WATER_MARK"); + /** + * @deprecated Use {@link #WRITE_BUFFER_WATER_MARK} + */ + @Deprecated + public static final ChannelOption WRITE_BUFFER_LOW_WATER_MARK = valueOf("WRITE_BUFFER_LOW_WATER_MARK"); + public static final ChannelOption WRITE_BUFFER_WATER_MARK = + valueOf("WRITE_BUFFER_WATER_MARK"); + + public static final ChannelOption ALLOW_HALF_CLOSURE = valueOf("ALLOW_HALF_CLOSURE"); + public static final ChannelOption AUTO_READ = valueOf("AUTO_READ"); + + /** + * If {@code true} then the {@link Channel} is closed automatically and immediately on 
write failure. + * The default value is {@code true}. + */ + public static final ChannelOption AUTO_CLOSE = valueOf("AUTO_CLOSE"); + + public static final ChannelOption SO_BROADCAST = valueOf("SO_BROADCAST"); + public static final ChannelOption SO_KEEPALIVE = valueOf("SO_KEEPALIVE"); + public static final ChannelOption SO_SNDBUF = valueOf("SO_SNDBUF"); + public static final ChannelOption SO_RCVBUF = valueOf("SO_RCVBUF"); + public static final ChannelOption SO_REUSEADDR = valueOf("SO_REUSEADDR"); + public static final ChannelOption SO_LINGER = valueOf("SO_LINGER"); + public static final ChannelOption SO_BACKLOG = valueOf("SO_BACKLOG"); + public static final ChannelOption SO_TIMEOUT = valueOf("SO_TIMEOUT"); + + public static final ChannelOption IP_TOS = valueOf("IP_TOS"); + public static final ChannelOption IP_MULTICAST_ADDR = valueOf("IP_MULTICAST_ADDR"); + public static final ChannelOption IP_MULTICAST_IF = valueOf("IP_MULTICAST_IF"); + public static final ChannelOption IP_MULTICAST_TTL = valueOf("IP_MULTICAST_TTL"); + public static final ChannelOption IP_MULTICAST_LOOP_DISABLED = valueOf("IP_MULTICAST_LOOP_DISABLED"); + + public static final ChannelOption TCP_NODELAY = valueOf("TCP_NODELAY"); + /** + * Client-side TCP FastOpen. Sending data with the initial TCP handshake. + */ + public static final ChannelOption TCP_FASTOPEN_CONNECT = valueOf("TCP_FASTOPEN_CONNECT"); + + /** + * Server-side TCP FastOpen. Configures the maximum number of outstanding (waiting to be accepted) TFO connections. + */ + public static final ChannelOption TCP_FASTOPEN = valueOf(ChannelOption.class, "TCP_FASTOPEN"); + + @Deprecated + public static final ChannelOption DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION = + valueOf("DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION"); + + public static final ChannelOption SINGLE_EVENTEXECUTOR_PER_GROUP = + valueOf("SINGLE_EVENTEXECUTOR_PER_GROUP"); + + /** + * Creates a new {@link ChannelOption} with the specified unique {@code name}. 
+ */ + private ChannelOption(int id, String name) { + super(id, name); + } + + @Deprecated + protected ChannelOption(String name) { + this(pool.nextId(), name); + } + + /** + * Validate the value which is set for the {@link ChannelOption}. Sub-classes + * may override this for special checks. + */ + public void validate(T value) { + ObjectUtil.checkNotNull(value, "value"); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelOutboundBuffer.java b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundBuffer.java new file mode 100644 index 0000000..7008ef5 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundBuffer.java @@ -0,0 +1,877 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.Recycler.EnhancedHandle; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PromiseNotificationUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +import static java.lang.Math.min; + +/** + * (Transport implementors only) an internal data structure used by {@link AbstractChannel} to store its pending + * outbound write requests. + *

+ * All methods must be called by a transport implementation from an I/O thread, except the following ones: + *

    + *
  • {@link #isWritable()}
  • + *
  • {@link #getUserDefinedWritability(int)} and {@link #setUserDefinedWritability(int, boolean)}
  • + *
+ *

+ */ +public final class ChannelOutboundBuffer { + // Assuming a 64-bit JVM: + // - 16 bytes object header + // - 6 reference fields + // - 2 long fields + // - 2 int fields + // - 1 boolean field + // - padding + static final int CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD = + SystemPropertyUtil.getInt("io.netty.transport.outboundBufferEntrySizeOverhead", 96); + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ChannelOutboundBuffer.class); + + private static final FastThreadLocal NIO_BUFFERS = new FastThreadLocal() { + @Override + protected ByteBuffer[] initialValue() throws Exception { + return new ByteBuffer[1024]; + } + }; + + private final Channel channel; + + // Entry(flushedEntry) --> ... Entry(unflushedEntry) --> ... Entry(tailEntry) + // + // The Entry that is the first in the linked-list structure that was flushed + private Entry flushedEntry; + // The Entry which is the first unflushed in the linked-list structure + private Entry unflushedEntry; + // The Entry which represents the tail of the buffer + private Entry tailEntry; + // The number of flushed entries that are not written yet + private int flushed; + + private int nioBufferCount; + private long nioBufferSize; + + private boolean inFail; + + private static final AtomicLongFieldUpdater TOTAL_PENDING_SIZE_UPDATER = + AtomicLongFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "totalPendingSize"); + + @SuppressWarnings("UnusedDeclaration") + private volatile long totalPendingSize; + + private static final AtomicIntegerFieldUpdater UNWRITABLE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(ChannelOutboundBuffer.class, "unwritable"); + + @SuppressWarnings("UnusedDeclaration") + private volatile int unwritable; + + private volatile Runnable fireChannelWritabilityChangedTask; + + ChannelOutboundBuffer(AbstractChannel channel) { + this.channel = channel; + } + + /** + * Add given message to this {@link ChannelOutboundBuffer}. 
The given {@link ChannelPromise} will be notified once + * the message was written. + */ + public void addMessage(Object msg, int size, ChannelPromise promise) { + Entry entry = Entry.newInstance(msg, size, total(msg), promise); + if (tailEntry == null) { + flushedEntry = null; + } else { + Entry tail = tailEntry; + tail.next = entry; + } + tailEntry = entry; + if (unflushedEntry == null) { + unflushedEntry = entry; + } + + // increment pending bytes after adding message to the unflushed arrays. + // See https://github.com/netty/netty/issues/1619 + incrementPendingOutboundBytes(entry.pendingSize, false); + } + + /** + * Add a flush to this {@link ChannelOutboundBuffer}. This means all previous added messages are marked as flushed + * and so you will be able to handle them. + */ + public void addFlush() { + // There is no need to process all entries if there was already a flush before and no new messages + // where added in the meantime. + // + // See https://github.com/netty/netty/issues/2577 + Entry entry = unflushedEntry; + if (entry != null) { + if (flushedEntry == null) { + // there is no flushedEntry yet, so start with the entry + flushedEntry = entry; + } + do { + flushed ++; + if (!entry.promise.setUncancellable()) { + // Was cancelled so make sure we free up memory and notify about the freed bytes + int pending = entry.cancel(); + decrementPendingOutboundBytes(pending, false, true); + } + entry = entry.next; + } while (entry != null); + + // All flushed so reset unflushedEntry + unflushedEntry = null; + } + } + + /** + * Increment the pending bytes which will be written at some point. + * This method is thread-safe! 
+ */ + void incrementPendingOutboundBytes(long size) { + incrementPendingOutboundBytes(size, true); + } + + private void incrementPendingOutboundBytes(long size, boolean invokeLater) { + if (size == 0) { + return; + } + + long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); + if (newWriteBufferSize > channel.config().getWriteBufferHighWaterMark()) { + setUnwritable(invokeLater); + } + } + + /** + * Decrement the pending bytes which will be written at some point. + * This method is thread-safe! + */ + void decrementPendingOutboundBytes(long size) { + decrementPendingOutboundBytes(size, true, true); + } + + private void decrementPendingOutboundBytes(long size, boolean invokeLater, boolean notifyWritability) { + if (size == 0) { + return; + } + + long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); + if (notifyWritability && newWriteBufferSize < channel.config().getWriteBufferLowWaterMark()) { + setWritable(invokeLater); + } + } + + private static long total(Object msg) { + if (msg instanceof ByteBuf) { + return ((ByteBuf) msg).readableBytes(); + } + if (msg instanceof FileRegion) { + return ((FileRegion) msg).count(); + } + if (msg instanceof ByteBufHolder) { + return ((ByteBufHolder) msg).content().readableBytes(); + } + return -1; + } + + /** + * Return the current message to write or {@code null} if nothing was flushed before and so is ready to be written. + */ + public Object current() { + Entry entry = flushedEntry; + if (entry == null) { + return null; + } + + return entry.msg; + } + + /** + * Return the current message flush progress. + * @return {@code 0} if nothing was flushed before for the current message or there is no current message + */ + public long currentProgress() { + Entry entry = flushedEntry; + if (entry == null) { + return 0; + } + return entry.progress; + } + + /** + * Notify the {@link ChannelPromise} of the current message about writing progress. 
+ */ + public void progress(long amount) { + Entry e = flushedEntry; + assert e != null; + ChannelPromise p = e.promise; + long progress = e.progress + amount; + e.progress = progress; + assert p != null; + final Class promiseClass = p.getClass(); + // fast-path to save O(n) ChannelProgressivePromise's type check on OpenJDK + if (promiseClass == VoidChannelPromise.class || promiseClass == DefaultChannelPromise.class) { + return; + } + // this is going to save from type pollution due to https://bugs.openjdk.org/browse/JDK-8180450 + if (p instanceof DefaultChannelProgressivePromise) { + ((DefaultChannelProgressivePromise) p).tryProgress(progress, e.total); + } else if (p instanceof ChannelProgressivePromise) { + ((ChannelProgressivePromise) p).tryProgress(progress, e.total); + } + } + + /** + * Will remove the current message, mark its {@link ChannelPromise} as success and return {@code true}. If no + * flushed message exists at the time this method is called it will return {@code false} to signal that no more + * messages are ready to be handled. + */ + public boolean remove() { + Entry e = flushedEntry; + if (e == null) { + clearNioBuffers(); + return false; + } + Object msg = e.msg; + + ChannelPromise promise = e.promise; + int size = e.pendingSize; + + removeEntry(e); + + if (!e.cancelled) { + // only release message, notify and decrement if it was not canceled before. + ReferenceCountUtil.safeRelease(msg); + safeSuccess(promise); + decrementPendingOutboundBytes(size, false, true); + } + + // recycle the entry + e.unguardedRecycle(); + + return true; + } + + /** + * Will remove the current message, mark its {@link ChannelPromise} as failure using the given {@link Throwable} + * and return {@code true}. If no flushed message exists at the time this method is called it will return + * {@code false} to signal that no more messages are ready to be handled. 
+ */ + public boolean remove(Throwable cause) { + return remove0(cause, true); + } + + private boolean remove0(Throwable cause, boolean notifyWritability) { + Entry e = flushedEntry; + if (e == null) { + clearNioBuffers(); + return false; + } + Object msg = e.msg; + + ChannelPromise promise = e.promise; + int size = e.pendingSize; + + removeEntry(e); + + if (!e.cancelled) { + // only release message, fail and decrement if it was not canceled before. + ReferenceCountUtil.safeRelease(msg); + + safeFail(promise, cause); + decrementPendingOutboundBytes(size, false, notifyWritability); + } + + // recycle the entry + e.unguardedRecycle(); + + return true; + } + + private void removeEntry(Entry e) { + if (-- flushed == 0) { + // processed everything + flushedEntry = null; + if (e == tailEntry) { + tailEntry = null; + unflushedEntry = null; + } + } else { + flushedEntry = e.next; + } + } + + /** + * Removes the fully written entries and update the reader index of the partially written entry. + * This operation assumes all messages in this buffer is {@link ByteBuf}. + */ + public void removeBytes(long writtenBytes) { + for (;;) { + Object msg = current(); + if (!(msg instanceof ByteBuf)) { + assert writtenBytes == 0; + break; + } + + final ByteBuf buf = (ByteBuf) msg; + final int readerIndex = buf.readerIndex(); + final int readableBytes = buf.writerIndex() - readerIndex; + + if (readableBytes <= writtenBytes) { + if (writtenBytes != 0) { + progress(readableBytes); + writtenBytes -= readableBytes; + } + remove(); + } else { // readableBytes > writtenBytes + if (writtenBytes != 0) { + buf.readerIndex(readerIndex + (int) writtenBytes); + progress(writtenBytes); + } + break; + } + } + clearNioBuffers(); + } + + // Clear all ByteBuffer from the array so these can be GC'ed. 
+ // See https://github.com/netty/netty/issues/3837 + private void clearNioBuffers() { + int count = nioBufferCount; + if (count > 0) { + nioBufferCount = 0; + Arrays.fill(NIO_BUFFERS.get(), 0, count, null); + } + } + + /** + * Returns an array of direct NIO buffers if the currently pending messages are made of {@link ByteBuf} only. + * {@link #nioBufferCount()} and {@link #nioBufferSize()} will return the number of NIO buffers in the returned + * array and the total number of readable bytes of the NIO buffers respectively. + *

+ * Note that the returned array is reused and thus should not escape + * {@link AbstractChannel#doWrite(ChannelOutboundBuffer)}. + * Refer to {@link NioSocketChannel#doWrite(ChannelOutboundBuffer)} for an example. + *

+ */ + public ByteBuffer[] nioBuffers() { + return nioBuffers(Integer.MAX_VALUE, Integer.MAX_VALUE); + } + + /** + * Returns an array of direct NIO buffers if the currently pending messages are made of {@link ByteBuf} only. + * {@link #nioBufferCount()} and {@link #nioBufferSize()} will return the number of NIO buffers in the returned + * array and the total number of readable bytes of the NIO buffers respectively. + *

+ * Note that the returned array is reused and thus should not escape + * {@link AbstractChannel#doWrite(ChannelOutboundBuffer)}. + * Refer to {@link NioSocketChannel#doWrite(ChannelOutboundBuffer)} for an example. + *

+ * @param maxCount The maximum amount of buffers that will be added to the return value. + * @param maxBytes A hint toward the maximum number of bytes to include as part of the return value. Note that this + * value maybe exceeded because we make a best effort to include at least 1 {@link ByteBuffer} + * in the return value to ensure write progress is made. + */ + public ByteBuffer[] nioBuffers(int maxCount, long maxBytes) { + assert maxCount > 0; + assert maxBytes > 0; + long nioBufferSize = 0; + int nioBufferCount = 0; + final InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get(); + ByteBuffer[] nioBuffers = NIO_BUFFERS.get(threadLocalMap); + Entry entry = flushedEntry; + while (isFlushedEntry(entry) && entry.msg instanceof ByteBuf) { + if (!entry.cancelled) { + ByteBuf buf = (ByteBuf) entry.msg; + final int readerIndex = buf.readerIndex(); + final int readableBytes = buf.writerIndex() - readerIndex; + + if (readableBytes > 0) { + if (maxBytes - readableBytes < nioBufferSize && nioBufferCount != 0) { + // If the nioBufferSize + readableBytes will overflow maxBytes, and there is at least one entry + // we stop populate the ByteBuffer array. This is done for 2 reasons: + // 1. bsd/osx don't allow to write more bytes then Integer.MAX_VALUE with one writev(...) call + // and so will return 'EINVAL', which will raise an IOException. On Linux it may work depending + // on the architecture and kernel but to be safe we also enforce the limit here. + // 2. There is no sense in putting more data in the array than is likely to be accepted by the + // OS. 
+ // + // See also: + // - https://www.freebsd.org/cgi/man.cgi?query=write&sektion=2 + // - https://linux.die.net//man/2/writev + break; + } + nioBufferSize += readableBytes; + int count = entry.count; + if (count == -1) { + //noinspection ConstantValueVariableUse + entry.count = count = buf.nioBufferCount(); + } + int neededSpace = min(maxCount, nioBufferCount + count); + if (neededSpace > nioBuffers.length) { + nioBuffers = expandNioBufferArray(nioBuffers, neededSpace, nioBufferCount); + NIO_BUFFERS.set(threadLocalMap, nioBuffers); + } + if (count == 1) { + ByteBuffer nioBuf = entry.buf; + if (nioBuf == null) { + // cache ByteBuffer as it may need to create a new ByteBuffer instance if its a + // derived buffer + entry.buf = nioBuf = buf.internalNioBuffer(readerIndex, readableBytes); + } + nioBuffers[nioBufferCount++] = nioBuf; + } else { + // The code exists in an extra method to ensure the method is not too big to inline as this + // branch is not very likely to get hit very frequently. 
+ nioBufferCount = nioBuffers(entry, buf, nioBuffers, nioBufferCount, maxCount); + } + if (nioBufferCount >= maxCount) { + break; + } + } + } + entry = entry.next; + } + this.nioBufferCount = nioBufferCount; + this.nioBufferSize = nioBufferSize; + + return nioBuffers; + } + + private static int nioBuffers(Entry entry, ByteBuf buf, ByteBuffer[] nioBuffers, int nioBufferCount, int maxCount) { + ByteBuffer[] nioBufs = entry.bufs; + if (nioBufs == null) { + // cached ByteBuffers as they may be expensive to create in terms + // of Object allocation + entry.bufs = nioBufs = buf.nioBuffers(); + } + for (int i = 0; i < nioBufs.length && nioBufferCount < maxCount; ++i) { + ByteBuffer nioBuf = nioBufs[i]; + if (nioBuf == null) { + break; + } else if (!nioBuf.hasRemaining()) { + continue; + } + nioBuffers[nioBufferCount++] = nioBuf; + } + return nioBufferCount; + } + + private static ByteBuffer[] expandNioBufferArray(ByteBuffer[] array, int neededSpace, int size) { + int newCapacity = array.length; + do { + // double capacity until it is big enough + // See https://github.com/netty/netty/issues/1890 + newCapacity <<= 1; + + if (newCapacity < 0) { + throw new IllegalStateException(); + } + + } while (neededSpace > newCapacity); + + ByteBuffer[] newArray = new ByteBuffer[newCapacity]; + System.arraycopy(array, 0, newArray, 0, size); + + return newArray; + } + + /** + * Returns the number of {@link ByteBuffer} that can be written out of the {@link ByteBuffer} array that was + * obtained via {@link #nioBuffers()}. This method MUST be called after {@link #nioBuffers()} + * was called. + */ + public int nioBufferCount() { + return nioBufferCount; + } + + /** + * Returns the number of bytes that can be written out of the {@link ByteBuffer} array that was + * obtained via {@link #nioBuffers()}. This method MUST be called after {@link #nioBuffers()} + * was called. 
+ */ + public long nioBufferSize() { + return nioBufferSize; + } + + /** + * Returns {@code true} if and only if {@linkplain #totalPendingWriteBytes() the total number of pending bytes} did + * not exceed the write watermark of the {@link Channel} and + * no {@linkplain #setUserDefinedWritability(int, boolean) user-defined writability flag} has been set to + * {@code false}. + */ + public boolean isWritable() { + return unwritable == 0; + } + + /** + * Returns {@code true} if and only if the user-defined writability flag at the specified index is set to + * {@code true}. + */ + public boolean getUserDefinedWritability(int index) { + return (unwritable & writabilityMask(index)) == 0; + } + + /** + * Sets a user-defined writability flag at the specified index. + */ + public void setUserDefinedWritability(int index, boolean writable) { + if (writable) { + setUserDefinedWritability(index); + } else { + clearUserDefinedWritability(index); + } + } + + private void setUserDefinedWritability(int index) { + final int mask = ~writabilityMask(index); + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue & mask; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue != 0 && newValue == 0) { + fireChannelWritabilityChanged(true); + } + break; + } + } + } + + private void clearUserDefinedWritability(int index) { + final int mask = writabilityMask(index); + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue | mask; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue == 0 && newValue != 0) { + fireChannelWritabilityChanged(true); + } + break; + } + } + } + + private static int writabilityMask(int index) { + if (index < 1 || index > 31) { + throw new IllegalArgumentException("index: " + index + " (expected: 1~31)"); + } + return 1 << index; + } + + private void setWritable(boolean invokeLater) { + for (;;) { + final int oldValue = unwritable; + final int newValue = 
oldValue & ~1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue != 0 && newValue == 0) { + fireChannelWritabilityChanged(invokeLater); + } + break; + } + } + } + + private void setUnwritable(boolean invokeLater) { + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue | 1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue == 0) { + fireChannelWritabilityChanged(invokeLater); + } + break; + } + } + } + + private void fireChannelWritabilityChanged(boolean invokeLater) { + final ChannelPipeline pipeline = channel.pipeline(); + if (invokeLater) { + Runnable task = fireChannelWritabilityChangedTask; + if (task == null) { + fireChannelWritabilityChangedTask = task = new Runnable() { + @Override + public void run() { + pipeline.fireChannelWritabilityChanged(); + } + }; + } + channel.eventLoop().execute(task); + } else { + pipeline.fireChannelWritabilityChanged(); + } + } + + /** + * Returns the number of flushed messages in this {@link ChannelOutboundBuffer}. + */ + public int size() { + return flushed; + } + + /** + * Returns {@code true} if there are flushed messages in this {@link ChannelOutboundBuffer} or {@code false} + * otherwise. + */ + public boolean isEmpty() { + return flushed == 0; + } + + void failFlushed(Throwable cause, boolean notify) { + // Make sure that this method does not reenter. A listener added to the current promise can be notified by the + // current thread in the tryFailure() call of the loop below, and the listener can trigger another fail() call + // indirectly (usually by closing the channel.) 
+ // + // See https://github.com/netty/netty/issues/1501 + if (inFail) { + return; + } + + try { + inFail = true; + for (;;) { + if (!remove0(cause, notify)) { + break; + } + } + } finally { + inFail = false; + } + } + + void close(final Throwable cause, final boolean allowChannelOpen) { + if (inFail) { + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + close(cause, allowChannelOpen); + } + }); + return; + } + + inFail = true; + + if (!allowChannelOpen && channel.isOpen()) { + throw new IllegalStateException("close() must be invoked after the channel is closed."); + } + + if (!isEmpty()) { + throw new IllegalStateException("close() must be invoked after all flushed writes are handled."); + } + + // Release all unflushed messages. + try { + Entry e = unflushedEntry; + while (e != null) { + // Just decrease; do not trigger any events via decrementPendingOutboundBytes() + int size = e.pendingSize; + TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); + + if (!e.cancelled) { + ReferenceCountUtil.safeRelease(e.msg); + safeFail(e.promise, cause); + } + e = e.unguardedRecycleAndGetNext(); + } + } finally { + inFail = false; + } + clearNioBuffers(); + } + + void close(ClosedChannelException cause) { + close(cause, false); + } + + private static void safeSuccess(ChannelPromise promise) { + // Only log if the given promise is not of type VoidChannelPromise as trySuccess(...) is expected to return + // false. + PromiseNotificationUtil.trySuccess(promise, null, promise instanceof VoidChannelPromise ? null : logger); + } + + private static void safeFail(ChannelPromise promise, Throwable cause) { + // Only log if the given promise is not of type VoidChannelPromise as tryFailure(...) is expected to return + // false. + PromiseNotificationUtil.tryFailure(promise, cause, promise instanceof VoidChannelPromise ? 
null : logger); + } + + @Deprecated + public void recycle() { + // NOOP + } + + public long totalPendingWriteBytes() { + return totalPendingSize; + } + + /** + * Get how many bytes can be written until {@link #isWritable()} returns {@code false}. + * This quantity will always be non-negative. If {@link #isWritable()} is {@code false} then 0. + */ + public long bytesBeforeUnwritable() { + // +1 because writability doesn't change until the threshold is crossed (not equal to). + long bytes = channel.config().getWriteBufferHighWaterMark() - totalPendingSize + 1; + // If bytes is negative we know we are not writable, but if bytes is non-negative we have to check writability. + // Note that totalPendingSize and isWritable() use different volatile variables that are not synchronized + // together. totalPendingSize will be updated before isWritable(). + return bytes > 0 && isWritable() ? bytes : 0; + } + + /** + * Get how many bytes must be drained from the underlying buffer until {@link #isWritable()} returns {@code true}. + * This quantity will always be non-negative. If {@link #isWritable()} is {@code true} then 0. + */ + public long bytesBeforeWritable() { + // +1 because writability doesn't change until the threshold is crossed (not equal to). + long bytes = totalPendingSize - channel.config().getWriteBufferLowWaterMark() + 1; + // If bytes is negative we know we are writable, but if bytes is non-negative we have to check writability. + // Note that totalPendingSize and isWritable() use different volatile variables that are not synchronized + // together. totalPendingSize will be updated before isWritable(). + return bytes <= 0 || isWritable() ? 0 : bytes; + } + + /** + * Call {@link MessageProcessor#processMessage(Object)} for each flushed message + * in this {@link ChannelOutboundBuffer} until {@link MessageProcessor#processMessage(Object)} + * returns {@code false} or there are no more flushed messages to process. 
+ */ + public void forEachFlushedMessage(MessageProcessor processor) throws Exception { + ObjectUtil.checkNotNull(processor, "processor"); + + Entry entry = flushedEntry; + if (entry == null) { + return; + } + + do { + if (!entry.cancelled) { + if (!processor.processMessage(entry.msg)) { + return; + } + } + entry = entry.next; + } while (isFlushedEntry(entry)); + } + + private boolean isFlushedEntry(Entry e) { + return e != null && e != unflushedEntry; + } + + public interface MessageProcessor { + /** + * Will be called for each flushed message until it either there are no more flushed messages or this + * method returns {@code false}. + */ + boolean processMessage(Object msg) throws Exception; + } + + static final class Entry { + private static final ObjectPool RECYCLER = ObjectPool.newPool(new ObjectCreator() { + @Override + public Entry newObject(Handle handle) { + return new Entry(handle); + } + }); + + private final EnhancedHandle handle; + Entry next; + Object msg; + ByteBuffer[] bufs; + ByteBuffer buf; + ChannelPromise promise; + long progress; + long total; + int pendingSize; + int count = -1; + boolean cancelled; + + private Entry(Handle handle) { + this.handle = (EnhancedHandle) handle; + } + + static Entry newInstance(Object msg, int size, long total, ChannelPromise promise) { + Entry entry = RECYCLER.get(); + entry.msg = msg; + entry.pendingSize = size + CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD; + entry.total = total; + entry.promise = promise; + return entry; + } + + int cancel() { + if (!cancelled) { + cancelled = true; + int pSize = pendingSize; + + // release message and replace with an empty buffer + ReferenceCountUtil.safeRelease(msg); + msg = Unpooled.EMPTY_BUFFER; + + pendingSize = 0; + total = 0; + progress = 0; + bufs = null; + buf = null; + return pSize; + } + return 0; + } + + void unguardedRecycle() { + next = null; + bufs = null; + buf = null; + msg = null; + promise = null; + progress = 0; + total = 0; + pendingSize = 0; + count = -1; + 
cancelled = false; + handle.unguardedRecycle(this); + } + + Entry unguardedRecycleAndGetNext() { + Entry next = this.next; + unguardedRecycle(); + return next; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandler.java b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandler.java new file mode 100644 index 0000000..5d74e18 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandler.java @@ -0,0 +1,99 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import java.net.SocketAddress; + +/** + * {@link ChannelHandler} which will get notified for IO-outbound-operations. + */ +public interface ChannelOutboundHandler extends ChannelHandler { + /** + * Called once a bind operation is made. + * + * @param ctx the {@link ChannelHandlerContext} for which the bind operation is made + * @param localAddress the {@link SocketAddress} to which it should bound + * @param promise the {@link ChannelPromise} to notify once the operation completes + * @throws Exception thrown if an error occurs + */ + void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception; + + /** + * Called once a connect operation is made. 
+ * + * @param ctx the {@link ChannelHandlerContext} for which the connect operation is made + * @param remoteAddress the {@link SocketAddress} to which it should connect + * @param localAddress the {@link SocketAddress} which is used as source on connect + * @param promise the {@link ChannelPromise} to notify once the operation completes + * @throws Exception thrown if an error occurs + */ + void connect( + ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) throws Exception; + + /** + * Called once a disconnect operation is made. + * + * @param ctx the {@link ChannelHandlerContext} for which the disconnect operation is made + * @param promise the {@link ChannelPromise} to notify once the operation completes + * @throws Exception thrown if an error occurs + */ + void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception; + + /** + * Called once a close operation is made. + * + * @param ctx the {@link ChannelHandlerContext} for which the close operation is made + * @param promise the {@link ChannelPromise} to notify once the operation completes + * @throws Exception thrown if an error occurs + */ + void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception; + + /** + * Called once a deregister operation is made from the current registered {@link EventLoop}. + * + * @param ctx the {@link ChannelHandlerContext} for which the close operation is made + * @param promise the {@link ChannelPromise} to notify once the operation completes + * @throws Exception thrown if an error occurs + */ + void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception; + + /** + * Intercepts {@link ChannelHandlerContext#read()}. + */ + void read(ChannelHandlerContext ctx) throws Exception; + + /** + * Called once a write operation is made. The write operation will write the messages through the + * {@link ChannelPipeline}. 
Those are then ready to be flushed to the actual {@link Channel} once + * {@link Channel#flush()} is called + * + * @param ctx the {@link ChannelHandlerContext} for which the write operation is made + * @param msg the message to write + * @param promise the {@link ChannelPromise} to notify once the operation completes + * @throws Exception thrown if an error occurs + */ + void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception; + + /** + * Called once a flush operation is made. The flush operation will try to flush out all previous written messages + * that are pending. + * + * @param ctx the {@link ChannelHandlerContext} for which the flush operation is made + * @throws Exception thrown if an error occurs + */ + void flush(ChannelHandlerContext ctx) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandlerAdapter.java b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandlerAdapter.java new file mode 100644 index 0000000..0ab7e4f --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundHandlerAdapter.java @@ -0,0 +1,127 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.ChannelHandlerMask.Skip; + +import java.net.SocketAddress; + +/** + * Skeleton implementation of a {@link ChannelOutboundHandler}. 
This implementation just forwards each method call via + * the {@link ChannelHandlerContext}. + */ +public class ChannelOutboundHandlerAdapter extends ChannelHandlerAdapter implements ChannelOutboundHandler { + + /** + * Calls {@link ChannelHandlerContext#bind(SocketAddress, ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + /** + * Calls {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress, ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + /** + * Calls {@link ChannelHandlerContext#disconnect(ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) + throws Exception { + ctx.disconnect(promise); + } + + /** + * Calls {@link ChannelHandlerContext#close(ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ */ + @Skip + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) + throws Exception { + ctx.close(promise); + } + + /** + * Calls {@link ChannelHandlerContext#deregister(ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + /** + * Calls {@link ChannelHandlerContext#read()} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } + + /** + * Calls {@link ChannelHandlerContext#write(Object, ChannelPromise)} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. + */ + @Skip + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ctx.write(msg, promise); + } + + /** + * Calls {@link ChannelHandlerContext#flush()} to forward + * to the next {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. + * + * Sub-classes may override this method to change behavior. 
+ */ + @Skip + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + ctx.flush(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelOutboundInvoker.java b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundInvoker.java new file mode 100644 index 0000000..7a3ef29 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelOutboundInvoker.java @@ -0,0 +1,271 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.FutureListener; + +import java.net.ConnectException; +import java.net.SocketAddress; + +public interface ChannelOutboundInvoker { + + /** + * Request to bind to the given {@link SocketAddress} and notify the {@link ChannelFuture} once the operation + * completes, either because the operation was successful or because of an error. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#bind(ChannelHandlerContext, SocketAddress, ChannelPromise)} method + * called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture bind(SocketAddress localAddress); + + /** + * Request to connect to the given {@link SocketAddress} and notify the {@link ChannelFuture} once the operation + * completes, either because the operation was successful or because of an error. + *

+ * If the connection fails because of a connection timeout, the {@link ChannelFuture} will get failed with + * a {@link ConnectTimeoutException}. If it fails because of connection refused a {@link ConnectException} + * will be used. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#connect(ChannelHandlerContext, SocketAddress, SocketAddress, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture connect(SocketAddress remoteAddress); + + /** + * Request to connect to the given {@link SocketAddress} while bind to the localAddress and notify the + * {@link ChannelFuture} once the operation completes, either because the operation was successful or because of + * an error. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#connect(ChannelHandlerContext, SocketAddress, SocketAddress, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress); + + /** + * Request to disconnect from the remote peer and notify the {@link ChannelFuture} once the operation completes, + * either because the operation was successful or because of an error. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#disconnect(ChannelHandlerContext, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture disconnect(); + + /** + * Request to close the {@link Channel} and notify the {@link ChannelFuture} once the operation completes, + * either because the operation was successful or because of + * an error. + * + * After it is closed it is not possible to reuse it again. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#close(ChannelHandlerContext, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture close(); + + /** + * Request to deregister from the previous assigned {@link EventExecutor} and notify the + * {@link ChannelFuture} once the operation completes, either because the operation was successful or because of + * an error. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#deregister(ChannelHandlerContext, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + * + */ + ChannelFuture deregister(); + + /** + * Request to bind to the given {@link SocketAddress} and notify the {@link ChannelFuture} once the operation + * completes, either because the operation was successful or because of an error. + * + * The given {@link ChannelPromise} will be notified. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#bind(ChannelHandlerContext, SocketAddress, ChannelPromise)} method + * called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise); + + /** + * Request to connect to the given {@link SocketAddress} and notify the {@link ChannelFuture} once the operation + * completes, either because the operation was successful or because of an error. + * + * The given {@link ChannelFuture} will be notified. + * + *

+ * If the connection fails because of a connection timeout, the {@link ChannelFuture} will get failed with + * a {@link ConnectTimeoutException}. If it fails because of connection refused a {@link ConnectException} + * will be used. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#connect(ChannelHandlerContext, SocketAddress, SocketAddress, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise); + + /** + * Request to connect to the given {@link SocketAddress} while bind to the localAddress and notify the + * {@link ChannelFuture} once the operation completes, either because the operation was successful or because of + * an error. + * + * The given {@link ChannelPromise} will be notified and also returned. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#connect(ChannelHandlerContext, SocketAddress, SocketAddress, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise); + + /** + * Request to disconnect from the remote peer and notify the {@link ChannelFuture} once the operation completes, + * either because the operation was successful or because of an error. + * + * The given {@link ChannelPromise} will be notified. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#disconnect(ChannelHandlerContext, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture disconnect(ChannelPromise promise); + + /** + * Request to close the {@link Channel} and notify the {@link ChannelFuture} once the operation completes, + * either because the operation was successful or because of + * an error. + * + * After it is closed it is not possible to reuse it again. + * The given {@link ChannelPromise} will be notified. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#close(ChannelHandlerContext, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture close(ChannelPromise promise); + + /** + * Request to deregister from the previous assigned {@link EventExecutor} and notify the + * {@link ChannelFuture} once the operation completes, either because the operation was successful or because of + * an error. + * + * The given {@link ChannelPromise} will be notified. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#deregister(ChannelHandlerContext, ChannelPromise)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelFuture deregister(ChannelPromise promise); + + /** + * Request to Read data from the {@link Channel} into the first inbound buffer, triggers an + * {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object)} event if data was + * read, and triggers a + * {@link ChannelInboundHandler#channelReadComplete(ChannelHandlerContext) channelReadComplete} event so the + * handler can decide to continue reading. If there's a pending read operation already, this method does nothing. + *

+ * This will result in having the + * {@link ChannelOutboundHandler#read(ChannelHandlerContext)} + * method called of the next {@link ChannelOutboundHandler} contained in the {@link ChannelPipeline} of the + * {@link Channel}. + */ + ChannelOutboundInvoker read(); + + /** + * Request to write a message via this {@link ChannelHandlerContext} through the {@link ChannelPipeline}. + * This method will not request to actual flush, so be sure to call {@link #flush()} + * once you want to request to flush all pending data to the actual transport. + */ + ChannelFuture write(Object msg); + + /** + * Request to write a message via this {@link ChannelHandlerContext} through the {@link ChannelPipeline}. + * This method will not request to actual flush, so be sure to call {@link #flush()} + * once you want to request to flush all pending data to the actual transport. + */ + ChannelFuture write(Object msg, ChannelPromise promise); + + /** + * Request to flush all pending messages via this ChannelOutboundInvoker. + */ + ChannelOutboundInvoker flush(); + + /** + * Shortcut for call {@link #write(Object, ChannelPromise)} and {@link #flush()}. + */ + ChannelFuture writeAndFlush(Object msg, ChannelPromise promise); + + /** + * Shortcut for call {@link #write(Object)} and {@link #flush()}. + */ + ChannelFuture writeAndFlush(Object msg); + + /** + * Return a new {@link ChannelPromise}. + */ + ChannelPromise newPromise(); + + /** + * Return an new {@link ChannelProgressivePromise} + */ + ChannelProgressivePromise newProgressivePromise(); + + /** + * Create a new {@link ChannelFuture} which is marked as succeeded already. So {@link ChannelFuture#isSuccess()} + * will return {@code true}. All {@link FutureListener} added to it will be notified directly. Also + * every call of blocking methods will just return without blocking. + */ + ChannelFuture newSucceededFuture(); + + /** + * Create a new {@link ChannelFuture} which is marked as failed already. 
So {@link ChannelFuture#isSuccess()} + * will return {@code false}. All {@link FutureListener} added to it will be notified directly. Also + * every call of blocking methods will just return without blocking. + */ + ChannelFuture newFailedFuture(Throwable cause); + + /** + * Return a special ChannelPromise which can be reused for different operations. + *

+ * It's only supported to use + * it for {@link ChannelOutboundInvoker#write(Object, ChannelPromise)}. + *

+ *

+ * Be aware that the returned {@link ChannelPromise} will not support most operations and should only be used + * if you want to save an object allocation for every write operation. You will not be able to detect if the + * operation was complete, only if it failed as the implementation will call + * {@link ChannelPipeline#fireExceptionCaught(Throwable)} in this case. + *

+ * Be aware this is an expert feature and should be used with care! + */ + ChannelPromise voidPromise(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelPipeline.java b/netty-channel/src/main/java/io/netty/channel/ChannelPipeline.java new file mode 100644 index 0000000..c59a7f5 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelPipeline.java @@ -0,0 +1,632 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.EventExecutorGroup; +import io.netty.util.concurrent.UnorderedThreadPoolEventExecutor; + +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SocketChannel; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; + + +/** + * A list of {@link ChannelHandler}s which handles or intercepts inbound events and outbound operations of a + * {@link Channel}. {@link ChannelPipeline} implements an advanced form of the + * Intercepting Filter pattern + * to give a user full control over how an event is handled and how the {@link ChannelHandler}s in a pipeline + * interact with each other. + * + *

Creation of a pipeline

+ * + * Each channel has its own pipeline and it is created automatically when a new channel is created. + * + *

How an event flows in a pipeline

+ * + * The following diagram describes how I/O events are processed by {@link ChannelHandler}s in a {@link ChannelPipeline} + * typically. An I/O event is handled by either a {@link ChannelInboundHandler} or a {@link ChannelOutboundHandler} + * and be forwarded to its closest handler by calling the event propagation methods defined in + * {@link ChannelHandlerContext}, such as {@link ChannelHandlerContext#fireChannelRead(Object)} and + * {@link ChannelHandlerContext#write(Object)}. + * + *
+ *                                                 I/O Request
+ *                                            via {@link Channel} or
+ *                                        {@link ChannelHandlerContext}
+ *                                                      |
+ *  +---------------------------------------------------+---------------+
+ *  |                           ChannelPipeline         |               |
+ *  |                                                  \|/              |
+ *  |    +---------------------+            +-----------+----------+    |
+ *  |    | Inbound Handler  N  |            | Outbound Handler  1  |    |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |              /|\                                  |               |
+ *  |               |                                  \|/              |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |    | Inbound Handler N-1 |            | Outbound Handler  2  |    |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |              /|\                                  .               |
+ *  |               .                                   .               |
+ *  | ChannelHandlerContext.fireIN_EVT() ChannelHandlerContext.OUT_EVT()|
+ *  |        [ method call]                       [method call]         |
+ *  |               .                                   .               |
+ *  |               .                                  \|/              |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |    | Inbound Handler  2  |            | Outbound Handler M-1 |    |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |              /|\                                  |               |
+ *  |               |                                  \|/              |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |    | Inbound Handler  1  |            | Outbound Handler  M  |    |
+ *  |    +----------+----------+            +-----------+----------+    |
+ *  |              /|\                                  |               |
+ *  +---------------+-----------------------------------+---------------+
+ *                  |                                  \|/
+ *  +---------------+-----------------------------------+---------------+
+ *  |               |                                   |               |
+ *  |       [ Socket.read() ]                    [ Socket.write() ]     |
+ *  |                                                                   |
+ *  |  Netty Internal I/O Threads (Transport Implementation)            |
+ *  +-------------------------------------------------------------------+
+ * 
+ * An inbound event is handled by the inbound handlers in the bottom-up direction as shown on the left side of the + * diagram. An inbound handler usually handles the inbound data generated by the I/O thread on the bottom of the + * diagram. The inbound data is often read from a remote peer via the actual input operation such as + * {@link SocketChannel#read(ByteBuffer)}. If an inbound event goes beyond the top inbound handler, it is discarded + * silently, or logged if it needs your attention. + *

+ * An outbound event is handled by the outbound handler in the top-down direction as shown on the right side of the + * diagram. An outbound handler usually generates or transforms the outbound traffic such as write requests. + * If an outbound event goes beyond the bottom outbound handler, it is handled by an I/O thread associated with the + * {@link Channel}. The I/O thread often performs the actual output operation such as + * {@link SocketChannel#write(ByteBuffer)}. + *

+ * For example, let us assume that we created the following pipeline: + *

+ * {@link ChannelPipeline} p = ...;
+ * p.addLast("1", new InboundHandlerA());
+ * p.addLast("2", new InboundHandlerB());
+ * p.addLast("3", new OutboundHandlerA());
+ * p.addLast("4", new OutboundHandlerB());
+ * p.addLast("5", new InboundOutboundHandlerX());
+ * 
+ * In the example above, the class whose name starts with {@code Inbound} means it is an inbound handler.
+ * The class whose name starts with {@code Outbound} means it is an outbound handler.
+ *

+ * In the given example configuration, the handler evaluation order is 1, 2, 3, 4, 5 when an event goes inbound. + * When an event goes outbound, the order is 5, 4, 3, 2, 1. On top of this principle, {@link ChannelPipeline} skips + * the evaluation of certain handlers to shorten the stack depth: + *

    + *
  • 3 and 4 don't implement {@link ChannelInboundHandler}, and therefore the actual evaluation order of an inbound + * event will be: 1, 2, and 5.
  • + *
  • 1 and 2 don't implement {@link ChannelOutboundHandler}, and therefore the actual evaluation order of an
+ * outbound event will be: 5, 4, and 3.
  • + *
  • If 5 implements both {@link ChannelInboundHandler} and {@link ChannelOutboundHandler}, the evaluation order of
+ * an inbound and an outbound event could be 125 and 543 respectively.
  • + *
+ * + *

Forwarding an event to the next handler

+ * As you might have noticed in the diagram, a handler has to invoke the event propagation methods in
+ * {@link ChannelHandlerContext} to forward an event to its next handler. Those methods include:
+ *
    + *
  • Inbound event propagation methods: + *
      + *
    • {@link ChannelHandlerContext#fireChannelRegistered()}
    • + *
    • {@link ChannelHandlerContext#fireChannelActive()}
    • + *
    • {@link ChannelHandlerContext#fireChannelRead(Object)}
    • + *
    • {@link ChannelHandlerContext#fireChannelReadComplete()}
    • + *
    • {@link ChannelHandlerContext#fireExceptionCaught(Throwable)}
    • + *
    • {@link ChannelHandlerContext#fireUserEventTriggered(Object)}
    • + *
    • {@link ChannelHandlerContext#fireChannelWritabilityChanged()}
    • + *
    • {@link ChannelHandlerContext#fireChannelInactive()}
    • + *
    • {@link ChannelHandlerContext#fireChannelUnregistered()}
    • + *
    + *
  • + *
  • Outbound event propagation methods: + *
      + *
    • {@link ChannelHandlerContext#bind(SocketAddress, ChannelPromise)}
    • + *
    • {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress, ChannelPromise)}
    • + *
    • {@link ChannelHandlerContext#write(Object, ChannelPromise)}
    • + *
    • {@link ChannelHandlerContext#flush()}
    • + *
    • {@link ChannelHandlerContext#read()}
    • + *
    • {@link ChannelHandlerContext#disconnect(ChannelPromise)}
    • + *
    • {@link ChannelHandlerContext#close(ChannelPromise)}
    • + *
    • {@link ChannelHandlerContext#deregister(ChannelPromise)}
    • + *
    + *
  • + *
+ * + * and the following example shows how the event propagation is usually done: + * + *
+ * public class MyInboundHandler extends {@link ChannelInboundHandlerAdapter} {
+ *     {@code @Override}
+ *     public void channelActive({@link ChannelHandlerContext} ctx) {
+ *         System.out.println("Connected!");
+ *         ctx.fireChannelActive();
+ *     }
+ * }
+ *
+ * public class MyOutboundHandler extends {@link ChannelOutboundHandlerAdapter} {
+ *     {@code @Override}
+ *     public void close({@link ChannelHandlerContext} ctx, {@link ChannelPromise} promise) {
+ *         System.out.println("Closing ..");
+ *         ctx.close(promise);
+ *     }
+ * }
+ * 
+ * + *

Building a pipeline

+ *

+ * A user is supposed to have one or more {@link ChannelHandler}s in a pipeline to receive I/O events (e.g. read) and + * to request I/O operations (e.g. write and close). For example, a typical server will have the following handlers + * in each channel's pipeline, but your mileage may vary depending on the complexity and characteristics of the + * protocol and business logic: + * + *

    + *
  1. Protocol Decoder - translates binary data (e.g. {@link ByteBuf}) into a Java object.
  2. + *
  3. Protocol Encoder - translates a Java object into binary data.
  4. + *
  5. Business Logic Handler - performs the actual business logic (e.g. database access).
  6. + *
+ * + * and it could be represented as shown in the following example: + * + *
+ * static final {@link EventExecutorGroup} group = new {@link DefaultEventExecutorGroup}(16);
+ * ...
+ *
+ * {@link ChannelPipeline} pipeline = ch.pipeline();
+ *
+ * pipeline.addLast("decoder", new MyProtocolDecoder());
+ * pipeline.addLast("encoder", new MyProtocolEncoder());
+ *
+ * // Tell the pipeline to run MyBusinessLogicHandler's event handler methods
+ * // in a different thread than an I/O thread so that the I/O thread is not blocked by
+ * // a time-consuming task.
+ * // If your business logic is fully asynchronous or finished very quickly, you don't
+ * // need to specify a group.
+ * pipeline.addLast(group, "handler", new MyBusinessLogicHandler());
+ * 
+ * 
+ * Be aware that while using {@link DefaultEventLoopGroup} will offload the operation from the {@link EventLoop} it will
+ * still process tasks in a serial fashion per {@link ChannelHandlerContext} and so guarantee ordering. Due to the
+ * ordering it may still become a bottleneck. If ordering is not a requirement for your use-case you may want to
+ * consider using {@link UnorderedThreadPoolEventExecutor} to maximize the parallelism of the task execution.
+ *

Thread safety

+ *

+ * A {@link ChannelHandler} can be added or removed at any time because a {@link ChannelPipeline} is thread safe. + * For example, you can insert an encryption handler when sensitive information is about to be exchanged, and remove it + * after the exchange. + */ +public interface ChannelPipeline + extends ChannelInboundInvoker, ChannelOutboundInvoker, Iterable> { + + /** + * Inserts a {@link ChannelHandler} at the first position of this pipeline. + * + * @param name the name of the handler to insert first + * @param handler the handler to insert first + * + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified handler is {@code null} + */ + ChannelPipeline addFirst(String name, ChannelHandler handler); + + /** + * Inserts a {@link ChannelHandler} at the first position of this pipeline. + * + * @param group the {@link EventExecutorGroup} which will be used to execute the {@link ChannelHandler} + * methods + * @param name the name of the handler to insert first + * @param handler the handler to insert first + * + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified handler is {@code null} + */ + ChannelPipeline addFirst(EventExecutorGroup group, String name, ChannelHandler handler); + + /** + * Appends a {@link ChannelHandler} at the last position of this pipeline. + * + * @param name the name of the handler to append + * @param handler the handler to append + * + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified handler is {@code null} + */ + ChannelPipeline addLast(String name, ChannelHandler handler); + + /** + * Appends a {@link ChannelHandler} at the last position of this pipeline. 
+ * + * @param group the {@link EventExecutorGroup} which will be used to execute the {@link ChannelHandler} + * methods + * @param name the name of the handler to append + * @param handler the handler to append + * + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified handler is {@code null} + */ + ChannelPipeline addLast(EventExecutorGroup group, String name, ChannelHandler handler); + + /** + * Inserts a {@link ChannelHandler} before an existing handler of this + * pipeline. + * + * @param baseName the name of the existing handler + * @param name the name of the handler to insert before + * @param handler the handler to insert before + * + * @throws NoSuchElementException + * if there's no such entry with the specified {@code baseName} + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified baseName or handler is {@code null} + */ + ChannelPipeline addBefore(String baseName, String name, ChannelHandler handler); + + /** + * Inserts a {@link ChannelHandler} before an existing handler of this + * pipeline. 
+ * + * @param group the {@link EventExecutorGroup} which will be used to execute the {@link ChannelHandler} + * methods + * @param baseName the name of the existing handler + * @param name the name of the handler to insert before + * @param handler the handler to insert before + * + * @throws NoSuchElementException + * if there's no such entry with the specified {@code baseName} + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified baseName or handler is {@code null} + */ + ChannelPipeline addBefore(EventExecutorGroup group, String baseName, String name, ChannelHandler handler); + + /** + * Inserts a {@link ChannelHandler} after an existing handler of this + * pipeline. + * + * @param baseName the name of the existing handler + * @param name the name of the handler to insert after + * @param handler the handler to insert after + * + * @throws NoSuchElementException + * if there's no such entry with the specified {@code baseName} + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified baseName or handler is {@code null} + */ + ChannelPipeline addAfter(String baseName, String name, ChannelHandler handler); + + /** + * Inserts a {@link ChannelHandler} after an existing handler of this + * pipeline. 
+ * + * @param group the {@link EventExecutorGroup} which will be used to execute the {@link ChannelHandler} + * methods + * @param baseName the name of the existing handler + * @param name the name of the handler to insert after + * @param handler the handler to insert after + * + * @throws NoSuchElementException + * if there's no such entry with the specified {@code baseName} + * @throws IllegalArgumentException + * if there's an entry with the same name already in the pipeline + * @throws NullPointerException + * if the specified baseName or handler is {@code null} + */ + ChannelPipeline addAfter(EventExecutorGroup group, String baseName, String name, ChannelHandler handler); + + /** + * Inserts {@link ChannelHandler}s at the first position of this pipeline. + * + * @param handlers the handlers to insert first + * + */ + ChannelPipeline addFirst(ChannelHandler... handlers); + + /** + * Inserts {@link ChannelHandler}s at the first position of this pipeline. + * + * @param group the {@link EventExecutorGroup} which will be used to execute the {@link ChannelHandler}s + * methods. + * @param handlers the handlers to insert first + * + */ + ChannelPipeline addFirst(EventExecutorGroup group, ChannelHandler... handlers); + + /** + * Inserts {@link ChannelHandler}s at the last position of this pipeline. + * + * @param handlers the handlers to insert last + * + */ + ChannelPipeline addLast(ChannelHandler... handlers); + + /** + * Inserts {@link ChannelHandler}s at the last position of this pipeline. + * + * @param group the {@link EventExecutorGroup} which will be used to execute the {@link ChannelHandler}s + * methods. + * @param handlers the handlers to insert last + * + */ + ChannelPipeline addLast(EventExecutorGroup group, ChannelHandler... handlers); + + /** + * Removes the specified {@link ChannelHandler} from this pipeline. 
+ * + * @param handler the {@link ChannelHandler} to remove + * + * @throws NoSuchElementException + * if there's no such handler in this pipeline + * @throws NullPointerException + * if the specified handler is {@code null} + */ + ChannelPipeline remove(ChannelHandler handler); + + /** + * Removes the {@link ChannelHandler} with the specified name from this pipeline. + * + * @param name the name under which the {@link ChannelHandler} was stored. + * + * @return the removed handler + * + * @throws NoSuchElementException + * if there's no such handler with the specified name in this pipeline + * @throws NullPointerException + * if the specified name is {@code null} + */ + ChannelHandler remove(String name); + + /** + * Removes the {@link ChannelHandler} of the specified type from this pipeline. + * + * @param the type of the handler + * @param handlerType the type of the handler + * + * @return the removed handler + * + * @throws NoSuchElementException + * if there's no such handler of the specified type in this pipeline + * @throws NullPointerException + * if the specified handler type is {@code null} + */ + T remove(Class handlerType); + + /** + * Removes the first {@link ChannelHandler} in this pipeline. + * + * @return the removed handler + * + * @throws NoSuchElementException + * if this pipeline is empty + */ + ChannelHandler removeFirst(); + + /** + * Removes the last {@link ChannelHandler} in this pipeline. + * + * @return the removed handler + * + * @throws NoSuchElementException + * if this pipeline is empty + */ + ChannelHandler removeLast(); + + /** + * Replaces the specified {@link ChannelHandler} with a new handler in this pipeline. 
+ * + * @param oldHandler the {@link ChannelHandler} to be replaced + * @param newName the name under which the replacement should be added + * @param newHandler the {@link ChannelHandler} which is used as replacement + * + * @return itself + + * @throws NoSuchElementException + * if the specified old handler does not exist in this pipeline + * @throws IllegalArgumentException + * if a handler with the specified new name already exists in this + * pipeline, except for the handler to be replaced + * @throws NullPointerException + * if the specified old handler or new handler is + * {@code null} + */ + ChannelPipeline replace(ChannelHandler oldHandler, String newName, ChannelHandler newHandler); + + /** + * Replaces the {@link ChannelHandler} of the specified name with a new handler in this pipeline. + * + * @param oldName the name of the {@link ChannelHandler} to be replaced + * @param newName the name under which the replacement should be added + * @param newHandler the {@link ChannelHandler} which is used as replacement + * + * @return the removed handler + * + * @throws NoSuchElementException + * if the handler with the specified old name does not exist in this pipeline + * @throws IllegalArgumentException + * if a handler with the specified new name already exists in this + * pipeline, except for the handler to be replaced + * @throws NullPointerException + * if the specified old handler or new handler is + * {@code null} + */ + ChannelHandler replace(String oldName, String newName, ChannelHandler newHandler); + + /** + * Replaces the {@link ChannelHandler} of the specified type with a new handler in this pipeline. 
+ * + * @param oldHandlerType the type of the handler to be removed + * @param newName the name under which the replacement should be added + * @param newHandler the {@link ChannelHandler} which is used as replacement + * + * @return the removed handler + * + * @throws NoSuchElementException + * if the handler of the specified old handler type does not exist + * in this pipeline + * @throws IllegalArgumentException + * if a handler with the specified new name already exists in this + * pipeline, except for the handler to be replaced + * @throws NullPointerException + * if the specified old handler or new handler is + * {@code null} + */ + T replace(Class oldHandlerType, String newName, + ChannelHandler newHandler); + + /** + * Returns the first {@link ChannelHandler} in this pipeline. + * + * @return the first handler. {@code null} if this pipeline is empty. + */ + ChannelHandler first(); + + /** + * Returns the context of the first {@link ChannelHandler} in this pipeline. + * + * @return the context of the first handler. {@code null} if this pipeline is empty. + */ + ChannelHandlerContext firstContext(); + + /** + * Returns the last {@link ChannelHandler} in this pipeline. + * + * @return the last handler. {@code null} if this pipeline is empty. + */ + ChannelHandler last(); + + /** + * Returns the context of the last {@link ChannelHandler} in this pipeline. + * + * @return the context of the last handler. {@code null} if this pipeline is empty. + */ + ChannelHandlerContext lastContext(); + + /** + * Returns the {@link ChannelHandler} with the specified name in this + * pipeline. + * + * @return the handler with the specified name. + * {@code null} if there's no such handler in this pipeline. + */ + ChannelHandler get(String name); + + /** + * Returns the {@link ChannelHandler} of the specified type in this + * pipeline. + * + * @return the handler of the specified handler type. + * {@code null} if there's no such handler in this pipeline. 
+ */ + T get(Class handlerType); + + /** + * Returns the context object of the specified {@link ChannelHandler} in + * this pipeline. + * + * @return the context object of the specified handler. + * {@code null} if there's no such handler in this pipeline. + */ + ChannelHandlerContext context(ChannelHandler handler); + + /** + * Returns the context object of the {@link ChannelHandler} with the + * specified name in this pipeline. + * + * @return the context object of the handler with the specified name. + * {@code null} if there's no such handler in this pipeline. + */ + ChannelHandlerContext context(String name); + + /** + * Returns the context object of the {@link ChannelHandler} of the + * specified type in this pipeline. + * + * @return the context object of the handler of the specified type. + * {@code null} if there's no such handler in this pipeline. + */ + ChannelHandlerContext context(Class handlerType); + + /** + * Returns the {@link Channel} that this pipeline is attached to. + * + * @return the channel. {@code null} if this pipeline is not attached yet. + */ + Channel channel(); + + /** + * Returns the {@link List} of the handler names. + */ + List names(); + + /** + * Converts this pipeline into an ordered {@link Map} whose keys are + * handler names and whose values are handlers. 
+ */ + Map toMap(); + + @Override + ChannelPipeline fireChannelRegistered(); + + @Override + ChannelPipeline fireChannelUnregistered(); + + @Override + ChannelPipeline fireChannelActive(); + + @Override + ChannelPipeline fireChannelInactive(); + + @Override + ChannelPipeline fireExceptionCaught(Throwable cause); + + @Override + ChannelPipeline fireUserEventTriggered(Object event); + + @Override + ChannelPipeline fireChannelRead(Object msg); + + @Override + ChannelPipeline fireChannelReadComplete(); + + @Override + ChannelPipeline fireChannelWritabilityChanged(); + + @Override + ChannelPipeline flush(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelPipelineException.java b/netty-channel/src/main/java/io/netty/channel/ChannelPipelineException.java new file mode 100644 index 0000000..7c954a1 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelPipelineException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +/** + * A {@link ChannelException} which is thrown when a {@link ChannelPipeline} + * failed to execute an operation. + */ +public class ChannelPipelineException extends ChannelException { + + private static final long serialVersionUID = 3379174210419885980L; + + /** + * Creates a new instance. + */ + public ChannelPipelineException() { + } + + /** + * Creates a new instance. 
+ */ + public ChannelPipelineException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public ChannelPipelineException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public ChannelPipelineException(Throwable cause) { + super(cause); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFuture.java b/netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFuture.java new file mode 100644 index 0000000..c40786f --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFuture.java @@ -0,0 +1,49 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ProgressiveFuture; + +/** + * An special {@link ChannelFuture} which is used to indicate the {@link FileRegion} transfer progress + */ +public interface ChannelProgressiveFuture extends ChannelFuture, ProgressiveFuture { + @Override + ChannelProgressiveFuture addListener(GenericFutureListener> listener); + + @Override + ChannelProgressiveFuture addListeners(GenericFutureListener>... 
listeners); + + @Override + ChannelProgressiveFuture removeListener(GenericFutureListener> listener); + + @Override + ChannelProgressiveFuture removeListeners(GenericFutureListener>... listeners); + + @Override + ChannelProgressiveFuture sync() throws InterruptedException; + + @Override + ChannelProgressiveFuture syncUninterruptibly(); + + @Override + ChannelProgressiveFuture await() throws InterruptedException; + + @Override + ChannelProgressiveFuture awaitUninterruptibly(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFutureListener.java b/netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFutureListener.java new file mode 100644 index 0000000..dee0268 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelProgressiveFutureListener.java @@ -0,0 +1,28 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.GenericProgressiveFutureListener; + +import java.util.EventListener; + +/** + * An {@link EventListener} listener which will be called once the sending task associated with future is + * being transferred. 
+ */ +public interface ChannelProgressiveFutureListener extends GenericProgressiveFutureListener { + // Just a type alias +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelProgressivePromise.java b/netty-channel/src/main/java/io/netty/channel/ChannelProgressivePromise.java new file mode 100644 index 0000000..3eb17cc --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelProgressivePromise.java @@ -0,0 +1,65 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ProgressivePromise; + +/** + * Special {@link ChannelPromise} which will be notified once the associated bytes is transferring. + */ +public interface ChannelProgressivePromise extends ProgressivePromise, ChannelProgressiveFuture, ChannelPromise { + + @Override + ChannelProgressivePromise addListener(GenericFutureListener> listener); + + @Override + ChannelProgressivePromise addListeners(GenericFutureListener>... listeners); + + @Override + ChannelProgressivePromise removeListener(GenericFutureListener> listener); + + @Override + ChannelProgressivePromise removeListeners(GenericFutureListener>... 
listeners); + + @Override + ChannelProgressivePromise sync() throws InterruptedException; + + @Override + ChannelProgressivePromise syncUninterruptibly(); + + @Override + ChannelProgressivePromise await() throws InterruptedException; + + @Override + ChannelProgressivePromise awaitUninterruptibly(); + + @Override + ChannelProgressivePromise setSuccess(Void result); + + @Override + ChannelProgressivePromise setSuccess(); + + @Override + ChannelProgressivePromise setFailure(Throwable cause); + + @Override + ChannelProgressivePromise setProgress(long progress, long total); + + @Override + ChannelProgressivePromise unvoid(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelPromise.java b/netty-channel/src/main/java/io/netty/channel/ChannelPromise.java new file mode 100644 index 0000000..e14fade --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelPromise.java @@ -0,0 +1,68 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.Promise; + +/** + * Special {@link ChannelFuture} which is writable. 
+ */ +public interface ChannelPromise extends ChannelFuture, Promise { + + @Override + Channel channel(); + + @Override + ChannelPromise setSuccess(Void result); + + ChannelPromise setSuccess(); + + boolean trySuccess(); + + @Override + ChannelPromise setFailure(Throwable cause); + + @Override + ChannelPromise addListener(GenericFutureListener> listener); + + @Override + ChannelPromise addListeners(GenericFutureListener>... listeners); + + @Override + ChannelPromise removeListener(GenericFutureListener> listener); + + @Override + ChannelPromise removeListeners(GenericFutureListener>... listeners); + + @Override + ChannelPromise sync() throws InterruptedException; + + @Override + ChannelPromise syncUninterruptibly(); + + @Override + ChannelPromise await() throws InterruptedException; + + @Override + ChannelPromise awaitUninterruptibly(); + + /** + * Returns a new {@link ChannelPromise} if {@link #isVoid()} returns {@code true} otherwise itself. + */ + ChannelPromise unvoid(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelPromiseAggregator.java b/netty-channel/src/main/java/io/netty/channel/ChannelPromiseAggregator.java new file mode 100644 index 0000000..be56fbb --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelPromiseAggregator.java @@ -0,0 +1,38 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel; + +import io.netty.util.concurrent.PromiseAggregator; +import io.netty.util.concurrent.PromiseCombiner; + +/** + * @deprecated Use {@link PromiseCombiner} + * + * Class which is used to consolidate multiple channel futures into one, by + * listening to the individual futures and producing an aggregated result + * (success/failure) when all futures have completed. + */ +@Deprecated +public final class ChannelPromiseAggregator + extends PromiseAggregator + implements ChannelFutureListener { + + public ChannelPromiseAggregator(ChannelPromise aggregatePromise) { + super(aggregatePromise); + } + +} diff --git a/netty-channel/src/main/java/io/netty/channel/ChannelPromiseNotifier.java b/netty-channel/src/main/java/io/netty/channel/ChannelPromiseNotifier.java new file mode 100644 index 0000000..bcd8ba6 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ChannelPromiseNotifier.java @@ -0,0 +1,48 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.PromiseNotifier; + +/** + * ChannelFutureListener implementation which takes other {@link ChannelPromise}(s) and notifies them on completion. + * + * @deprecated use {@link PromiseNotifier}. 
+ */ +@Deprecated +public final class ChannelPromiseNotifier + extends PromiseNotifier + implements ChannelFutureListener { + + /** + * Create a new instance + * + * @param promises the {@link ChannelPromise}s to notify once this {@link ChannelFutureListener} is notified. + */ + public ChannelPromiseNotifier(ChannelPromise... promises) { + super(promises); + } + + /** + * Create a new instance + * + * @param logNotifyFailure {@code true} if logging should be done in case notification fails. + * @param promises the {@link ChannelPromise}s to notify once this {@link ChannelFutureListener} is notified. + */ + public ChannelPromiseNotifier(boolean logNotifyFailure, ChannelPromise... promises) { + super(logNotifyFailure, promises); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/CoalescingBufferQueue.java b/netty-channel/src/main/java/io/netty/channel/CoalescingBufferQueue.java new file mode 100644 index 0000000..0f194bd --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/CoalescingBufferQueue.java @@ -0,0 +1,86 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.ObjectUtil; + +/** + * A FIFO queue of bytes where producers add bytes by repeatedly adding {@link ByteBuf} and consumers take bytes in + * arbitrary lengths. This allows producers to add lots of small buffers and the consumer to take all the bytes + * out in a single buffer. Conversely the producer may add larger buffers and the consumer could take the bytes in + * many small buffers. + * + *

Bytes are added and removed with promises. If the last byte of a buffer added with a promise is removed then + * that promise will complete when the promise passed to {@link #remove} completes. + * + *

This functionality is useful for aggregating or partitioning writes into fixed size buffers for framing protocols + * such as HTTP2. + */ +public final class CoalescingBufferQueue extends AbstractCoalescingBufferQueue { + private final Channel channel; + + public CoalescingBufferQueue(Channel channel) { + this(channel, 4); + } + + public CoalescingBufferQueue(Channel channel, int initSize) { + this(channel, initSize, false); + } + + public CoalescingBufferQueue(Channel channel, int initSize, boolean updateWritability) { + super(updateWritability ? channel : null, initSize); + this.channel = ObjectUtil.checkNotNull(channel, "channel"); + } + + /** + * Remove a {@link ByteBuf} from the queue with the specified number of bytes. Any added buffer who's bytes are + * fully consumed during removal will have it's promise completed when the passed aggregate {@link ChannelPromise} + * completes. + * + * @param bytes the maximum number of readable bytes in the returned {@link ByteBuf}, if {@code bytes} is greater + * than {@link #readableBytes} then a buffer of length {@link #readableBytes} is returned. + * @param aggregatePromise used to aggregate the promises and listeners for the constituent buffers. + * @return a {@link ByteBuf} composed of the enqueued buffers. + */ + public ByteBuf remove(int bytes, ChannelPromise aggregatePromise) { + return remove(channel.alloc(), bytes, aggregatePromise); + } + + /** + * Release all buffers in the queue and complete all listeners and promises. 
+ */ + public void releaseAndFailAll(Throwable cause) { + releaseAndFailAll(channel, cause); + } + + @Override + protected ByteBuf compose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) { + if (cumulation instanceof CompositeByteBuf) { + CompositeByteBuf composite = (CompositeByteBuf) cumulation; + composite.addComponent(true, next); + return composite; + } + return composeIntoComposite(alloc, cumulation, next); + } + + @Override + protected ByteBuf removeEmptyValue() { + return Unpooled.EMPTY_BUFFER; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/CombinedChannelDuplexHandler.java b/netty-channel/src/main/java/io/netty/channel/CombinedChannelDuplexHandler.java new file mode 100644 index 0000000..e3f9f1f --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/CombinedChannelDuplexHandler.java @@ -0,0 +1,616 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.SocketAddress; + +/** + * Combines a {@link ChannelInboundHandler} and a {@link ChannelOutboundHandler} into one {@link ChannelHandler}. + */ +public class CombinedChannelDuplexHandler + extends ChannelDuplexHandler { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(CombinedChannelDuplexHandler.class); + + private DelegatingChannelHandlerContext inboundCtx; + private DelegatingChannelHandlerContext outboundCtx; + private volatile boolean handlerAdded; + + private I inboundHandler; + private O outboundHandler; + + /** + * Creates a new uninitialized instance. A class that extends this handler must invoke + * {@link #init(ChannelInboundHandler, ChannelOutboundHandler)} before adding this handler into a + * {@link ChannelPipeline}. + */ + protected CombinedChannelDuplexHandler() { + ensureNotSharable(); + } + + /** + * Creates a new instance that combines the specified two handlers into one. + */ + public CombinedChannelDuplexHandler(I inboundHandler, O outboundHandler) { + ensureNotSharable(); + init(inboundHandler, outboundHandler); + } + + /** + * Initialized this handler with the specified handlers. 
+ * + * @throws IllegalStateException if this handler was not constructed via the default constructor or + * if this handler does not implement all required handler interfaces + * @throws IllegalArgumentException if the specified handlers cannot be combined into one due to a conflict + * in the type hierarchy + */ + protected final void init(I inboundHandler, O outboundHandler) { + validate(inboundHandler, outboundHandler); + this.inboundHandler = inboundHandler; + this.outboundHandler = outboundHandler; + } + + private void validate(I inboundHandler, O outboundHandler) { + if (this.inboundHandler != null) { + throw new IllegalStateException( + "init() can not be invoked if " + CombinedChannelDuplexHandler.class.getSimpleName() + + " was constructed with non-default constructor."); + } + + ObjectUtil.checkNotNull(inboundHandler, "inboundHandler"); + ObjectUtil.checkNotNull(outboundHandler, "outboundHandler"); + + if (inboundHandler instanceof ChannelOutboundHandler) { + throw new IllegalArgumentException( + "inboundHandler must not implement " + + ChannelOutboundHandler.class.getSimpleName() + " to get combined."); + } + if (outboundHandler instanceof ChannelInboundHandler) { + throw new IllegalArgumentException( + "outboundHandler must not implement " + + ChannelInboundHandler.class.getSimpleName() + " to get combined."); + } + } + + protected final I inboundHandler() { + return inboundHandler; + } + + protected final O outboundHandler() { + return outboundHandler; + } + + private void checkAdded() { + if (!handlerAdded) { + throw new IllegalStateException("handler not added to pipeline yet"); + } + } + + /** + * Removes the {@link ChannelInboundHandler} that was combined in this {@link CombinedChannelDuplexHandler}. + */ + public final void removeInboundHandler() { + checkAdded(); + inboundCtx.remove(); + } + + /** + * Removes the {@link ChannelOutboundHandler} that was combined in this {@link CombinedChannelDuplexHandler}. 
+ */ + public final void removeOutboundHandler() { + checkAdded(); + outboundCtx.remove(); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (inboundHandler == null) { + throw new IllegalStateException( + "init() must be invoked before being added to a " + ChannelPipeline.class.getSimpleName() + + " if " + CombinedChannelDuplexHandler.class.getSimpleName() + + " was constructed with the default constructor."); + } + + outboundCtx = new DelegatingChannelHandlerContext(ctx, outboundHandler); + inboundCtx = new DelegatingChannelHandlerContext(ctx, inboundHandler) { + @SuppressWarnings("deprecation") + @Override + public ChannelHandlerContext fireExceptionCaught(Throwable cause) { + if (!outboundCtx.removed) { + try { + // We directly delegate to the ChannelOutboundHandler as this may override exceptionCaught(...) + // as well + outboundHandler.exceptionCaught(outboundCtx, cause); + } catch (Throwable error) { + if (logger.isDebugEnabled()) { + logger.debug( + "An exception {}" + + "was thrown by a user handler's exceptionCaught() " + + "method while handling the following exception:", + ThrowableUtil.stackTraceToString(error), cause); + } else if (logger.isWarnEnabled()) { + logger.warn( + "An exception '{}' [enable DEBUG level for full stacktrace] " + + "was thrown by a user handler's exceptionCaught() " + + "method while handling the following exception:", error, cause); + } + } + } else { + super.fireExceptionCaught(cause); + } + return this; + } + }; + + // The inboundCtx and outboundCtx were created and set now it's safe to call removeInboundHandler() and + // removeOutboundHandler(). 
+ handlerAdded = true; + + try { + inboundHandler.handlerAdded(inboundCtx); + } finally { + outboundHandler.handlerAdded(outboundCtx); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + inboundCtx.remove(); + } finally { + outboundCtx.remove(); + } + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelRegistered(inboundCtx); + } else { + inboundCtx.fireChannelRegistered(); + } + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelUnregistered(inboundCtx); + } else { + inboundCtx.fireChannelUnregistered(); + } + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelActive(inboundCtx); + } else { + inboundCtx.fireChannelActive(); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelInactive(inboundCtx); + } else { + inboundCtx.fireChannelInactive(); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.exceptionCaught(inboundCtx, cause); + } else { + inboundCtx.fireExceptionCaught(cause); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.userEventTriggered(inboundCtx, evt); + } else { + inboundCtx.fireUserEventTriggered(evt); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + assert 
ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelRead(inboundCtx, msg); + } else { + inboundCtx.fireChannelRead(msg); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelReadComplete(inboundCtx); + } else { + inboundCtx.fireChannelReadComplete(); + } + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + assert ctx == inboundCtx.ctx; + if (!inboundCtx.removed) { + inboundHandler.channelWritabilityChanged(inboundCtx); + } else { + inboundCtx.fireChannelWritabilityChanged(); + } + } + + @Override + public void bind( + ChannelHandlerContext ctx, + SocketAddress localAddress, ChannelPromise promise) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.bind(outboundCtx, localAddress, promise); + } else { + outboundCtx.bind(localAddress, promise); + } + } + + @Override + public void connect( + ChannelHandlerContext ctx, + SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.connect(outboundCtx, remoteAddress, localAddress, promise); + } else { + outboundCtx.connect(remoteAddress, localAddress, promise); + } + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.disconnect(outboundCtx, promise); + } else { + outboundCtx.disconnect(promise); + } + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.close(outboundCtx, promise); + } else { + outboundCtx.close(promise); + } + } + + @Override + public void 
deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.deregister(outboundCtx, promise); + } else { + outboundCtx.deregister(promise); + } + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.read(outboundCtx); + } else { + outboundCtx.read(); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.write(outboundCtx, msg, promise); + } else { + outboundCtx.write(msg, promise); + } + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + assert ctx == outboundCtx.ctx; + if (!outboundCtx.removed) { + outboundHandler.flush(outboundCtx); + } else { + outboundCtx.flush(); + } + } + + private static class DelegatingChannelHandlerContext implements ChannelHandlerContext { + + private final ChannelHandlerContext ctx; + private final ChannelHandler handler; + boolean removed; + + DelegatingChannelHandlerContext(ChannelHandlerContext ctx, ChannelHandler handler) { + this.ctx = ctx; + this.handler = handler; + } + + @Override + public Channel channel() { + return ctx.channel(); + } + + @Override + public EventExecutor executor() { + return ctx.executor(); + } + + @Override + public String name() { + return ctx.name(); + } + + @Override + public ChannelHandler handler() { + return ctx.handler(); + } + + @Override + public boolean isRemoved() { + return removed || ctx.isRemoved(); + } + + @Override + public ChannelHandlerContext fireChannelRegistered() { + ctx.fireChannelRegistered(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelUnregistered() { + ctx.fireChannelUnregistered(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelActive() 
{ + ctx.fireChannelActive(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelInactive() { + ctx.fireChannelInactive(); + return this; + } + + @Override + public ChannelHandlerContext fireExceptionCaught(Throwable cause) { + ctx.fireExceptionCaught(cause); + return this; + } + + @Override + public ChannelHandlerContext fireUserEventTriggered(Object event) { + ctx.fireUserEventTriggered(event); + return this; + } + + @Override + public ChannelHandlerContext fireChannelRead(Object msg) { + ctx.fireChannelRead(msg); + return this; + } + + @Override + public ChannelHandlerContext fireChannelReadComplete() { + ctx.fireChannelReadComplete(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelWritabilityChanged() { + ctx.fireChannelWritabilityChanged(); + return this; + } + + @Override + public ChannelFuture bind(SocketAddress localAddress) { + return ctx.bind(localAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress) { + return ctx.connect(remoteAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return ctx.connect(remoteAddress, localAddress); + } + + @Override + public ChannelFuture disconnect() { + return ctx.disconnect(); + } + + @Override + public ChannelFuture close() { + return ctx.close(); + } + + @Override + public ChannelFuture deregister() { + return ctx.deregister(); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return ctx.bind(localAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return ctx.connect(remoteAddress, promise); + } + + @Override + public ChannelFuture connect( + SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public ChannelFuture disconnect(ChannelPromise 
promise) { + return ctx.disconnect(promise); + } + + @Override + public ChannelFuture close(ChannelPromise promise) { + return ctx.close(promise); + } + + @Override + public ChannelFuture deregister(ChannelPromise promise) { + return ctx.deregister(promise); + } + + @Override + public ChannelHandlerContext read() { + ctx.read(); + return this; + } + + @Override + public ChannelFuture write(Object msg) { + return ctx.write(msg); + } + + @Override + public ChannelFuture write(Object msg, ChannelPromise promise) { + return ctx.write(msg, promise); + } + + @Override + public ChannelHandlerContext flush() { + ctx.flush(); + return this; + } + + @Override + public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + return ctx.writeAndFlush(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg) { + return ctx.writeAndFlush(msg); + } + + @Override + public ChannelPipeline pipeline() { + return ctx.pipeline(); + } + + @Override + public ByteBufAllocator alloc() { + return ctx.alloc(); + } + + @Override + public ChannelPromise newPromise() { + return ctx.newPromise(); + } + + @Override + public ChannelProgressivePromise newProgressivePromise() { + return ctx.newProgressivePromise(); + } + + @Override + public ChannelFuture newSucceededFuture() { + return ctx.newSucceededFuture(); + } + + @Override + public ChannelFuture newFailedFuture(Throwable cause) { + return ctx.newFailedFuture(cause); + } + + @Override + public ChannelPromise voidPromise() { + return ctx.voidPromise(); + } + + @Override + public Attribute attr(AttributeKey key) { + return ctx.channel().attr(key); + } + + @Override + public boolean hasAttr(AttributeKey key) { + return ctx.channel().hasAttr(key); + } + + final void remove() { + EventExecutor executor = executor(); + if (executor.inEventLoop()) { + remove0(); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + remove0(); + } + }); + } + } + + private void remove0() { + if (!removed) 
{ + removed = true; + try { + handler.handlerRemoved(this); + } catch (Throwable cause) { + fireExceptionCaught(new ChannelPipelineException( + handler.getClass().getName() + ".handlerRemoved() has thrown an exception.", cause)); + } + } + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/CompleteChannelFuture.java b/netty-channel/src/main/java/io/netty/channel/CompleteChannelFuture.java new file mode 100644 index 0000000..e25b058 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/CompleteChannelFuture.java @@ -0,0 +1,110 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.CompleteFuture; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.internal.ObjectUtil; + +/** + * A skeletal {@link ChannelFuture} implementation which represents a + * {@link ChannelFuture} which has been completed already. + */ +abstract class CompleteChannelFuture extends CompleteFuture implements ChannelFuture { + + private final Channel channel; + + /** + * Creates a new instance. 
+ * + * @param channel the {@link Channel} associated with this future + */ + protected CompleteChannelFuture(Channel channel, EventExecutor executor) { + super(executor); + this.channel = ObjectUtil.checkNotNull(channel, "channel"); + } + + @Override + protected EventExecutor executor() { + EventExecutor e = super.executor(); + if (e == null) { + return channel().eventLoop(); + } else { + return e; + } + } + + @Override + public ChannelFuture addListener(GenericFutureListener> listener) { + super.addListener(listener); + return this; + } + + @Override + public ChannelFuture addListeners(GenericFutureListener>... listeners) { + super.addListeners(listeners); + return this; + } + + @Override + public ChannelFuture removeListener(GenericFutureListener> listener) { + super.removeListener(listener); + return this; + } + + @Override + public ChannelFuture removeListeners(GenericFutureListener>... listeners) { + super.removeListeners(listeners); + return this; + } + + @Override + public ChannelFuture syncUninterruptibly() { + return this; + } + + @Override + public ChannelFuture sync() throws InterruptedException { + return this; + } + + @Override + public ChannelFuture await() throws InterruptedException { + return this; + } + + @Override + public ChannelFuture awaitUninterruptibly() { + return this; + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public Void getNow() { + return null; + } + + @Override + public boolean isVoid() { + return false; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ConnectTimeoutException.java b/netty-channel/src/main/java/io/netty/channel/ConnectTimeoutException.java new file mode 100644 index 0000000..1d49877 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ConnectTimeoutException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this 
file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import java.net.ConnectException; + +/** + * {@link ConnectException} which will be thrown if a connection could + * not be established because of a connection timeout. + */ +public class ConnectTimeoutException extends ConnectException { + private static final long serialVersionUID = 2317065249988317463L; + + public ConnectTimeoutException(String msg) { + super(msg); + } + + public ConnectTimeoutException() { + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultAddressedEnvelope.java b/netty-channel/src/main/java/io/netty/channel/DefaultAddressedEnvelope.java new file mode 100644 index 0000000..cf9267d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultAddressedEnvelope.java @@ -0,0 +1,129 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel; + +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.net.SocketAddress; + +/** + * The default {@link AddressedEnvelope} implementation. + * + * @param the type of the wrapped message + * @param the type of the recipient address + */ +public class DefaultAddressedEnvelope implements AddressedEnvelope { + + private final M message; + private final A sender; + private final A recipient; + + /** + * Creates a new instance with the specified {@code message}, {@code recipient} address, and + * {@code sender} address. + */ + public DefaultAddressedEnvelope(M message, A recipient, A sender) { + ObjectUtil.checkNotNull(message, "message"); + if (recipient == null && sender == null) { + throw new NullPointerException("recipient and sender"); + } + + this.message = message; + this.sender = sender; + this.recipient = recipient; + } + + /** + * Creates a new instance with the specified {@code message} and {@code recipient} address. + * The sender address becomes {@code null}. 
+ */ + public DefaultAddressedEnvelope(M message, A recipient) { + this(message, recipient, null); + } + + @Override + public M content() { + return message; + } + + @Override + public A sender() { + return sender; + } + + @Override + public A recipient() { + return recipient; + } + + @Override + public int refCnt() { + if (message instanceof ReferenceCounted) { + return ((ReferenceCounted) message).refCnt(); + } else { + return 1; + } + } + + @Override + public AddressedEnvelope retain() { + ReferenceCountUtil.retain(message); + return this; + } + + @Override + public AddressedEnvelope retain(int increment) { + ReferenceCountUtil.retain(message, increment); + return this; + } + + @Override + public boolean release() { + return ReferenceCountUtil.release(message); + } + + @Override + public boolean release(int decrement) { + return ReferenceCountUtil.release(message, decrement); + } + + @Override + public AddressedEnvelope touch() { + ReferenceCountUtil.touch(message); + return this; + } + + @Override + public AddressedEnvelope touch(Object hint) { + ReferenceCountUtil.touch(message, hint); + return this; + } + + @Override + public String toString() { + if (sender != null) { + return StringUtil.simpleClassName(this) + + '(' + sender + " => " + recipient + ", " + message + ')'; + } else { + return StringUtil.simpleClassName(this) + + "(=> " + recipient + ", " + message + ')'; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/DefaultChannelConfig.java new file mode 100644 index 0000000..c650601 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultChannelConfig.java @@ -0,0 +1,442 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.internal.ObjectUtil; + +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.Map.Entry; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static io.netty.channel.ChannelOption.ALLOCATOR; +import static io.netty.channel.ChannelOption.AUTO_CLOSE; +import static io.netty.channel.ChannelOption.AUTO_READ; +import static io.netty.channel.ChannelOption.CONNECT_TIMEOUT_MILLIS; +import static io.netty.channel.ChannelOption.MAX_MESSAGES_PER_READ; +import static io.netty.channel.ChannelOption.MAX_MESSAGES_PER_WRITE; +import static io.netty.channel.ChannelOption.MESSAGE_SIZE_ESTIMATOR; +import static io.netty.channel.ChannelOption.RCVBUF_ALLOCATOR; +import static io.netty.channel.ChannelOption.SINGLE_EVENTEXECUTOR_PER_GROUP; +import static io.netty.channel.ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK; +import static io.netty.channel.ChannelOption.WRITE_BUFFER_LOW_WATER_MARK; +import static io.netty.channel.ChannelOption.WRITE_BUFFER_WATER_MARK; +import static io.netty.channel.ChannelOption.WRITE_SPIN_COUNT; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * The default {@link ChannelConfig} implementation. 
+ */ +public class DefaultChannelConfig implements ChannelConfig { + private static final MessageSizeEstimator DEFAULT_MSG_SIZE_ESTIMATOR = DefaultMessageSizeEstimator.DEFAULT; + + private static final int DEFAULT_CONNECT_TIMEOUT = 30000; + + private static final AtomicIntegerFieldUpdater AUTOREAD_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(DefaultChannelConfig.class, "autoRead"); + private static final AtomicReferenceFieldUpdater WATERMARK_UPDATER = + AtomicReferenceFieldUpdater.newUpdater( + DefaultChannelConfig.class, WriteBufferWaterMark.class, "writeBufferWaterMark"); + + protected final Channel channel; + + private volatile ByteBufAllocator allocator = ByteBufAllocator.DEFAULT; + private volatile RecvByteBufAllocator rcvBufAllocator; + private volatile MessageSizeEstimator msgSizeEstimator = DEFAULT_MSG_SIZE_ESTIMATOR; + + private volatile int connectTimeoutMillis = DEFAULT_CONNECT_TIMEOUT; + private volatile int writeSpinCount = 16; + private volatile int maxMessagesPerWrite = Integer.MAX_VALUE; + + @SuppressWarnings("FieldMayBeFinal") + private volatile int autoRead = 1; + private volatile boolean autoClose = true; + private volatile WriteBufferWaterMark writeBufferWaterMark = WriteBufferWaterMark.DEFAULT; + private volatile boolean pinEventExecutor = true; + + public DefaultChannelConfig(Channel channel) { + this(channel, new AdaptiveRecvByteBufAllocator()); + } + + protected DefaultChannelConfig(Channel channel, RecvByteBufAllocator allocator) { + setRecvByteBufAllocator(allocator, channel.metadata()); + this.channel = channel; + } + + @Override + @SuppressWarnings("deprecation") + public Map, Object> getOptions() { + return getOptions( + null, + CONNECT_TIMEOUT_MILLIS, MAX_MESSAGES_PER_READ, WRITE_SPIN_COUNT, + ALLOCATOR, AUTO_READ, AUTO_CLOSE, RCVBUF_ALLOCATOR, WRITE_BUFFER_HIGH_WATER_MARK, + WRITE_BUFFER_LOW_WATER_MARK, WRITE_BUFFER_WATER_MARK, MESSAGE_SIZE_ESTIMATOR, + SINGLE_EVENTEXECUTOR_PER_GROUP, MAX_MESSAGES_PER_WRITE); + } + + protected 
Map, Object> getOptions( + Map, Object> result, ChannelOption... options) { + if (result == null) { + result = new IdentityHashMap, Object>(); + } + for (ChannelOption o: options) { + result.put(o, getOption(o)); + } + return result; + } + + @SuppressWarnings("unchecked") + @Override + public boolean setOptions(Map, ?> options) { + ObjectUtil.checkNotNull(options, "options"); + + boolean setAllOptions = true; + for (Entry, ?> e: options.entrySet()) { + if (!setOption((ChannelOption) e.getKey(), e.getValue())) { + setAllOptions = false; + } + } + + return setAllOptions; + } + + @Override + @SuppressWarnings({ "unchecked", "deprecation" }) + public T getOption(ChannelOption option) { + ObjectUtil.checkNotNull(option, "option"); + + if (option == CONNECT_TIMEOUT_MILLIS) { + return (T) Integer.valueOf(getConnectTimeoutMillis()); + } + if (option == MAX_MESSAGES_PER_READ) { + return (T) Integer.valueOf(getMaxMessagesPerRead()); + } + if (option == WRITE_SPIN_COUNT) { + return (T) Integer.valueOf(getWriteSpinCount()); + } + if (option == ALLOCATOR) { + return (T) getAllocator(); + } + if (option == RCVBUF_ALLOCATOR) { + return (T) getRecvByteBufAllocator(); + } + if (option == AUTO_READ) { + return (T) Boolean.valueOf(isAutoRead()); + } + if (option == AUTO_CLOSE) { + return (T) Boolean.valueOf(isAutoClose()); + } + if (option == WRITE_BUFFER_HIGH_WATER_MARK) { + return (T) Integer.valueOf(getWriteBufferHighWaterMark()); + } + if (option == WRITE_BUFFER_LOW_WATER_MARK) { + return (T) Integer.valueOf(getWriteBufferLowWaterMark()); + } + if (option == WRITE_BUFFER_WATER_MARK) { + return (T) getWriteBufferWaterMark(); + } + if (option == MESSAGE_SIZE_ESTIMATOR) { + return (T) getMessageSizeEstimator(); + } + if (option == SINGLE_EVENTEXECUTOR_PER_GROUP) { + return (T) Boolean.valueOf(getPinEventExecutorPerGroup()); + } + if (option == MAX_MESSAGES_PER_WRITE) { + return (T) Integer.valueOf(getMaxMessagesPerWrite()); + } + return null; + } + + @Override + 
@SuppressWarnings("deprecation") + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == CONNECT_TIMEOUT_MILLIS) { + setConnectTimeoutMillis((Integer) value); + } else if (option == MAX_MESSAGES_PER_READ) { + setMaxMessagesPerRead((Integer) value); + } else if (option == WRITE_SPIN_COUNT) { + setWriteSpinCount((Integer) value); + } else if (option == ALLOCATOR) { + setAllocator((ByteBufAllocator) value); + } else if (option == RCVBUF_ALLOCATOR) { + setRecvByteBufAllocator((RecvByteBufAllocator) value); + } else if (option == AUTO_READ) { + setAutoRead((Boolean) value); + } else if (option == AUTO_CLOSE) { + setAutoClose((Boolean) value); + } else if (option == WRITE_BUFFER_HIGH_WATER_MARK) { + setWriteBufferHighWaterMark((Integer) value); + } else if (option == WRITE_BUFFER_LOW_WATER_MARK) { + setWriteBufferLowWaterMark((Integer) value); + } else if (option == WRITE_BUFFER_WATER_MARK) { + setWriteBufferWaterMark((WriteBufferWaterMark) value); + } else if (option == MESSAGE_SIZE_ESTIMATOR) { + setMessageSizeEstimator((MessageSizeEstimator) value); + } else if (option == SINGLE_EVENTEXECUTOR_PER_GROUP) { + setPinEventExecutorPerGroup((Boolean) value); + } else if (option == MAX_MESSAGES_PER_WRITE) { + setMaxMessagesPerWrite((Integer) value); + } else { + return false; + } + + return true; + } + + protected void validate(ChannelOption option, T value) { + ObjectUtil.checkNotNull(option, "option").validate(value); + } + + @Override + public int getConnectTimeoutMillis() { + return connectTimeoutMillis; + } + + @Override + public ChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + checkPositiveOrZero(connectTimeoutMillis, "connectTimeoutMillis"); + this.connectTimeoutMillis = connectTimeoutMillis; + return this; + } + + /** + * {@inheritDoc} + *

+ * @throws IllegalStateException if {@link #getRecvByteBufAllocator()} does not return an object of type + * {@link MaxMessagesRecvByteBufAllocator}. + */ + @Override + @Deprecated + public int getMaxMessagesPerRead() { + try { + MaxMessagesRecvByteBufAllocator allocator = getRecvByteBufAllocator(); + return allocator.maxMessagesPerRead(); + } catch (ClassCastException e) { + throw new IllegalStateException("getRecvByteBufAllocator() must return an object of type " + + "MaxMessagesRecvByteBufAllocator", e); + } + } + + /** + * {@inheritDoc} + *

+ * @throws IllegalStateException if {@link #getRecvByteBufAllocator()} does not return an object of type + * {@link MaxMessagesRecvByteBufAllocator}. + */ + @Override + @Deprecated + public ChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + try { + MaxMessagesRecvByteBufAllocator allocator = getRecvByteBufAllocator(); + allocator.maxMessagesPerRead(maxMessagesPerRead); + return this; + } catch (ClassCastException e) { + throw new IllegalStateException("getRecvByteBufAllocator() must return an object of type " + + "MaxMessagesRecvByteBufAllocator", e); + } + } + + /** + * Get the maximum number of message to write per eventloop run. Once this limit is + * reached we will continue to process other events before trying to write the remaining messages. + */ + public int getMaxMessagesPerWrite() { + return maxMessagesPerWrite; + } + + /** + * Set the maximum number of message to write per eventloop run. Once this limit is + * reached we will continue to process other events before trying to write the remaining messages. + */ + public ChannelConfig setMaxMessagesPerWrite(int maxMessagesPerWrite) { + this.maxMessagesPerWrite = ObjectUtil.checkPositive(maxMessagesPerWrite, "maxMessagesPerWrite"); + return this; + } + + @Override + public int getWriteSpinCount() { + return writeSpinCount; + } + + @Override + public ChannelConfig setWriteSpinCount(int writeSpinCount) { + checkPositive(writeSpinCount, "writeSpinCount"); + // Integer.MAX_VALUE is used as a special value in the channel implementations to indicate the channel cannot + // accept any more data, and results in the writeOp being set on the selector (or execute a runnable which tries + // to flush later because the writeSpinCount quantum has been exhausted). This strategy prevents additional + // conditional logic in the channel implementations, and shouldn't be noticeable in practice. 
+ if (writeSpinCount == Integer.MAX_VALUE) { + --writeSpinCount; + } + this.writeSpinCount = writeSpinCount; + return this; + } + + @Override + public ByteBufAllocator getAllocator() { + return allocator; + } + + @Override + public ChannelConfig setAllocator(ByteBufAllocator allocator) { + this.allocator = ObjectUtil.checkNotNull(allocator, "allocator"); + return this; + } + + @SuppressWarnings("unchecked") + @Override + public T getRecvByteBufAllocator() { + return (T) rcvBufAllocator; + } + + @Override + public ChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + rcvBufAllocator = checkNotNull(allocator, "allocator"); + return this; + } + + /** + * Set the {@link RecvByteBufAllocator} which is used for the channel to allocate receive buffers. + * @param allocator the allocator to set. + * @param metadata Used to set the {@link ChannelMetadata#defaultMaxMessagesPerRead()} if {@code allocator} + * is of type {@link MaxMessagesRecvByteBufAllocator}. + */ + private void setRecvByteBufAllocator(RecvByteBufAllocator allocator, ChannelMetadata metadata) { + checkNotNull(allocator, "allocator"); + checkNotNull(metadata, "metadata"); + if (allocator instanceof MaxMessagesRecvByteBufAllocator) { + ((MaxMessagesRecvByteBufAllocator) allocator).maxMessagesPerRead(metadata.defaultMaxMessagesPerRead()); + } + setRecvByteBufAllocator(allocator); + } + + @Override + public boolean isAutoRead() { + return autoRead == 1; + } + + @Override + public ChannelConfig setAutoRead(boolean autoRead) { + boolean oldAutoRead = AUTOREAD_UPDATER.getAndSet(this, autoRead ? 1 : 0) == 1; + if (autoRead && !oldAutoRead) { + channel.read(); + } else if (!autoRead && oldAutoRead) { + autoReadCleared(); + } + return this; + } + + /** + * Is called once {@link #setAutoRead(boolean)} is called with {@code false} and {@link #isAutoRead()} was + * {@code true} before. 
+ */ + protected void autoReadCleared() { } + + @Override + public boolean isAutoClose() { + return autoClose; + } + + @Override + public ChannelConfig setAutoClose(boolean autoClose) { + this.autoClose = autoClose; + return this; + } + + @Override + public int getWriteBufferHighWaterMark() { + return writeBufferWaterMark.high(); + } + + @Override + public ChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + checkPositiveOrZero(writeBufferHighWaterMark, "writeBufferHighWaterMark"); + for (;;) { + WriteBufferWaterMark waterMark = writeBufferWaterMark; + if (writeBufferHighWaterMark < waterMark.low()) { + throw new IllegalArgumentException( + "writeBufferHighWaterMark cannot be less than " + + "writeBufferLowWaterMark (" + waterMark.low() + "): " + + writeBufferHighWaterMark); + } + if (WATERMARK_UPDATER.compareAndSet(this, waterMark, + new WriteBufferWaterMark(waterMark.low(), writeBufferHighWaterMark, false))) { + return this; + } + } + } + + @Override + public int getWriteBufferLowWaterMark() { + return writeBufferWaterMark.low(); + } + + @Override + public ChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + checkPositiveOrZero(writeBufferLowWaterMark, "writeBufferLowWaterMark"); + for (;;) { + WriteBufferWaterMark waterMark = writeBufferWaterMark; + if (writeBufferLowWaterMark > waterMark.high()) { + throw new IllegalArgumentException( + "writeBufferLowWaterMark cannot be greater than " + + "writeBufferHighWaterMark (" + waterMark.high() + "): " + + writeBufferLowWaterMark); + } + if (WATERMARK_UPDATER.compareAndSet(this, waterMark, + new WriteBufferWaterMark(writeBufferLowWaterMark, waterMark.high(), false))) { + return this; + } + } + } + + @Override + public ChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + this.writeBufferWaterMark = checkNotNull(writeBufferWaterMark, "writeBufferWaterMark"); + return this; + } + + @Override + public WriteBufferWaterMark getWriteBufferWaterMark() 
{ + return writeBufferWaterMark; + } + + @Override + public MessageSizeEstimator getMessageSizeEstimator() { + return msgSizeEstimator; + } + + @Override + public ChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + this.msgSizeEstimator = ObjectUtil.checkNotNull(estimator, "estimator"); + return this; + } + + private ChannelConfig setPinEventExecutorPerGroup(boolean pinEventExecutor) { + this.pinEventExecutor = pinEventExecutor; + return this; + } + + private boolean getPinEventExecutorPerGroup() { + return pinEventExecutor; + } + +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java b/netty-channel/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java new file mode 100644 index 0000000..0916793 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultChannelHandlerContext.java @@ -0,0 +1,34 @@ +/* +* Copyright 2014 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
*/
package io.netty.channel;

import io.netty.util.concurrent.EventExecutor;

/**
 * Default pipeline node: pairs an {@link AbstractChannelHandlerContext} with the
 * {@link ChannelHandler} instance it was created for and returns that handler on demand.
 */
final class DefaultChannelHandlerContext extends AbstractChannelHandlerContext {

    // The user-supplied handler this context wraps; immutable for the context's lifetime.
    private final ChannelHandler handler;

    DefaultChannelHandlerContext(
            DefaultChannelPipeline pipeline, EventExecutor executor, String name, ChannelHandler handler) {
        // Pass the handler's concrete class to the superclass constructor.
        super(pipeline, executor, name, handler.getClass());
        this.handler = handler;
    }

    @Override
    public ChannelHandler handler() {
        return handler;
    }
}
diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultChannelId.java b/netty-channel/src/main/java/io/netty/channel/DefaultChannelId.java
new file mode 100644
index 0000000..cfd00d4
--- /dev/null
+++ b/netty-channel/src/main/java/io/netty/channel/DefaultChannelId.java
@@ -0,0 +1,321 @@
+/*
+ * Copyright 2013 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ + +package io.netty.channel; + +import io.netty.buffer.ByteBufUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.MacAddressUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicInteger; + +import static io.netty.util.internal.MacAddressUtil.defaultMachineId; +import static io.netty.util.internal.MacAddressUtil.parseMAC; + +/** + * The default {@link ChannelId} implementation. + */ +public final class DefaultChannelId implements ChannelId { + + private static final long serialVersionUID = 3884076183504074063L; + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultChannelId.class); + private static final byte[] MACHINE_ID; + private static final int PROCESS_ID_LEN = 4; + private static final int PROCESS_ID; + private static final int SEQUENCE_LEN = 4; + private static final int TIMESTAMP_LEN = 8; + private static final int RANDOM_LEN = 4; + + private static final AtomicInteger nextSequence = new AtomicInteger(); + + /** + * Returns a new {@link DefaultChannelId} instance. + */ + public static DefaultChannelId newInstance() { + return new DefaultChannelId(); + } + + static { + int processId = -1; + String customProcessId = SystemPropertyUtil.get("io.netty.processId"); + if (customProcessId != null) { + try { + processId = Integer.parseInt(customProcessId); + } catch (NumberFormatException e) { + // Malformed input. 
+ } + + if (processId < 0) { + processId = -1; + logger.warn("-Dio.netty.processId: {} (malformed)", customProcessId); + } else if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.processId: {} (user-set)", processId); + } + } + + if (processId < 0) { + processId = defaultProcessId(); + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.processId: {} (auto-detected)", processId); + } + } + + PROCESS_ID = processId; + + byte[] machineId = null; + String customMachineId = SystemPropertyUtil.get("io.netty.machineId"); + if (customMachineId != null) { + try { + machineId = parseMAC(customMachineId); + } catch (Exception e) { + logger.warn("-Dio.netty.machineId: {} (malformed)", customMachineId, e); + } + if (machineId != null) { + logger.debug("-Dio.netty.machineId: {} (user-set)", customMachineId); + } + } + + if (machineId == null) { + machineId = defaultMachineId(); + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.machineId: {} (auto-detected)", MacAddressUtil.formatAddress(machineId)); + } + } + + MACHINE_ID = machineId; + } + + static int processHandlePid(ClassLoader loader) { + // pid is positive on unix, non{-1,0} on windows + int nilValue = -1; + if (PlatformDependent.javaVersion() >= 9) { + Long pid; + try { + Class processHandleImplType = Class.forName("java.lang.ProcessHandle", true, loader); + Method processHandleCurrent = processHandleImplType.getMethod("current"); + Object processHandleInstance = processHandleCurrent.invoke(null); + Method processHandlePid = processHandleImplType.getMethod("pid"); + pid = (Long) processHandlePid.invoke(processHandleInstance); + } catch (Exception e) { + logger.debug("Could not invoke ProcessHandle.current().pid();", e); + return nilValue; + } + if (pid > Integer.MAX_VALUE || pid < Integer.MIN_VALUE) { + throw new IllegalStateException("Current process ID exceeds int range: " + pid); + } + return pid.intValue(); + } + return nilValue; + } + + static int jmxPid(ClassLoader loader) { + String value; + 
try { + // Invoke java.lang.management.ManagementFactory.getRuntimeMXBean().getName() + Class mgmtFactoryType = Class.forName("java.lang.management.ManagementFactory", true, loader); + Class runtimeMxBeanType = Class.forName("java.lang.management.RuntimeMXBean", true, loader); + + Method getRuntimeMXBean = mgmtFactoryType.getMethod("getRuntimeMXBean", EmptyArrays.EMPTY_CLASSES); + Object bean = getRuntimeMXBean.invoke(null, EmptyArrays.EMPTY_OBJECTS); + Method getName = runtimeMxBeanType.getMethod("getName", EmptyArrays.EMPTY_CLASSES); + value = (String) getName.invoke(bean, EmptyArrays.EMPTY_OBJECTS); + } catch (Throwable t) { + logger.debug("Could not invoke ManagementFactory.getRuntimeMXBean().getName(); Android?", t); + try { + // Invoke android.os.Process.myPid() + Class processType = Class.forName("android.os.Process", true, loader); + Method myPid = processType.getMethod("myPid", EmptyArrays.EMPTY_CLASSES); + value = myPid.invoke(null, EmptyArrays.EMPTY_OBJECTS).toString(); + } catch (Throwable t2) { + logger.debug("Could not invoke Process.myPid(); not Android?", t2); + value = ""; + } + } + + int atIndex = value.indexOf('@'); + if (atIndex >= 0) { + value = value.substring(0, atIndex); + } + + int pid; + try { + pid = Integer.parseInt(value); + } catch (NumberFormatException e) { + // value did not contain an integer. 
+ pid = -1; + } + + if (pid < 0) { + pid = PlatformDependent.threadLocalRandom().nextInt(); + logger.warn("Failed to find the current process ID from '{}'; using a random value: {}", value, pid); + } + + return pid; + } + + static int defaultProcessId() { + ClassLoader loader = PlatformDependent.getClassLoader(DefaultChannelId.class); + int processId = processHandlePid(loader); + if (processId != -1) { + return processId; + } + return jmxPid(loader); + } + + private final byte[] data; + private final int hashCode; + + private transient String shortValue; + private transient String longValue; + + private DefaultChannelId() { + data = new byte[MACHINE_ID.length + PROCESS_ID_LEN + SEQUENCE_LEN + TIMESTAMP_LEN + RANDOM_LEN]; + int i = 0; + + // machineId + System.arraycopy(MACHINE_ID, 0, data, i, MACHINE_ID.length); + i += MACHINE_ID.length; + + // processId + i = writeInt(i, PROCESS_ID); + + // sequence + i = writeInt(i, nextSequence.getAndIncrement()); + + // timestamp (kind of) + i = writeLong(i, Long.reverse(System.nanoTime()) ^ System.currentTimeMillis()); + + // random + int random = PlatformDependent.threadLocalRandom().nextInt(); + i = writeInt(i, random); + assert i == data.length; + + hashCode = Arrays.hashCode(data); + } + + private int writeInt(int i, int value) { + data[i ++] = (byte) (value >>> 24); + data[i ++] = (byte) (value >>> 16); + data[i ++] = (byte) (value >>> 8); + data[i ++] = (byte) value; + return i; + } + + private int writeLong(int i, long value) { + data[i ++] = (byte) (value >>> 56); + data[i ++] = (byte) (value >>> 48); + data[i ++] = (byte) (value >>> 40); + data[i ++] = (byte) (value >>> 32); + data[i ++] = (byte) (value >>> 24); + data[i ++] = (byte) (value >>> 16); + data[i ++] = (byte) (value >>> 8); + data[i ++] = (byte) value; + return i; + } + + @Override + public String asShortText() { + String shortValue = this.shortValue; + if (shortValue == null) { + this.shortValue = shortValue = ByteBufUtil.hexDump(data, data.length - 
RANDOM_LEN, RANDOM_LEN); + } + return shortValue; + } + + @Override + public String asLongText() { + String longValue = this.longValue; + if (longValue == null) { + this.longValue = longValue = newLongValue(); + } + return longValue; + } + + private String newLongValue() { + StringBuilder buf = new StringBuilder(2 * data.length + 5); + int i = 0; + i = appendHexDumpField(buf, i, MACHINE_ID.length); + i = appendHexDumpField(buf, i, PROCESS_ID_LEN); + i = appendHexDumpField(buf, i, SEQUENCE_LEN); + i = appendHexDumpField(buf, i, TIMESTAMP_LEN); + i = appendHexDumpField(buf, i, RANDOM_LEN); + assert i == data.length; + return buf.substring(0, buf.length() - 1); + } + + private int appendHexDumpField(StringBuilder buf, int i, int length) { + buf.append(ByteBufUtil.hexDump(data, i, length)); + buf.append('-'); + i += length; + return i; + } + + @Override + public int hashCode() { + return hashCode; + } + + @Override + public int compareTo(final ChannelId o) { + if (this == o) { + // short circuit + return 0; + } + if (o instanceof DefaultChannelId) { + // lexicographic comparison + final byte[] otherData = ((DefaultChannelId) o).data; + int len1 = data.length; + int len2 = otherData.length; + int len = Math.min(len1, len2); + + for (int k = 0; k < len; k++) { + byte x = data[k]; + byte y = otherData[k]; + if (x != y) { + // treat these as unsigned bytes for comparison + return (x & 0xff) - (y & 0xff); + } + } + return len1 - len2; + } + + return asLongText().compareTo(o.asLongText()); + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof DefaultChannelId)) { + return false; + } + DefaultChannelId other = (DefaultChannelId) obj; + return hashCode == other.hashCode && Arrays.equals(data, other.data); + } + + @Override + public String toString() { + return asShortText(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultChannelPipeline.java 
b/netty-channel/src/main/java/io/netty/channel/DefaultChannelPipeline.java new file mode 100644 index 0000000..3bd860c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultChannelPipeline.java @@ -0,0 +1,1511 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.Channel.Unsafe; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.EventExecutorGroup; +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.WeakHashMap; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +/** + * The default {@link ChannelPipeline} implementation. It is usually created + * by a {@link Channel} implementation when the {@link Channel} is created. 
+ */ +public class DefaultChannelPipeline implements ChannelPipeline { + + static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultChannelPipeline.class); + + private static final String HEAD_NAME = generateName0(HeadContext.class); + private static final String TAIL_NAME = generateName0(TailContext.class); + + private static final FastThreadLocal, String>> nameCaches = + new FastThreadLocal, String>>() { + @Override + protected Map, String> initialValue() { + return new WeakHashMap, String>(); + } + }; + + private static final AtomicReferenceFieldUpdater ESTIMATOR = + AtomicReferenceFieldUpdater.newUpdater( + DefaultChannelPipeline.class, MessageSizeEstimator.Handle.class, "estimatorHandle"); + final HeadContext head; + final TailContext tail; + + private final Channel channel; + private final ChannelFuture succeededFuture; + private final VoidChannelPromise voidPromise; + private final boolean touch = ResourceLeakDetector.isEnabled(); + + private Map childExecutors; + private volatile MessageSizeEstimator.Handle estimatorHandle; + private boolean firstRegistration = true; + + /** + * This is the head of a linked list that is processed by {@link #callHandlerAddedForAllHandlers()} and so process + * all the pending {@link #callHandlerAdded0(AbstractChannelHandlerContext)}. + * + * We only keep the head because it is expected that the list is used infrequently and its size is small. + * Thus full iterations to do insertions is assumed to be a good compromised to saving memory and tail management + * complexity. + */ + private PendingHandlerCallback pendingHandlerCallbackHead; + + /** + * Set to {@code true} once the {@link AbstractChannel} is registered.Once set to {@code true} the value will never + * change. 
+ */ + private boolean registered; + + protected DefaultChannelPipeline(Channel channel) { + this.channel = ObjectUtil.checkNotNull(channel, "channel"); + succeededFuture = new SucceededChannelFuture(channel, null); + voidPromise = new VoidChannelPromise(channel, true); + + tail = new TailContext(this); + head = new HeadContext(this); + + head.next = tail; + tail.prev = head; + } + + final MessageSizeEstimator.Handle estimatorHandle() { + MessageSizeEstimator.Handle handle = estimatorHandle; + if (handle == null) { + handle = channel.config().getMessageSizeEstimator().newHandle(); + if (!ESTIMATOR.compareAndSet(this, null, handle)) { + handle = estimatorHandle; + } + } + return handle; + } + + final Object touch(Object msg, AbstractChannelHandlerContext next) { + return touch ? ReferenceCountUtil.touch(msg, next) : msg; + } + + private AbstractChannelHandlerContext newContext(EventExecutorGroup group, String name, ChannelHandler handler) { + return new DefaultChannelHandlerContext(this, childExecutor(group), name, handler); + } + + private EventExecutor childExecutor(EventExecutorGroup group) { + if (group == null) { + return null; + } + Boolean pinEventExecutor = channel.config().getOption(ChannelOption.SINGLE_EVENTEXECUTOR_PER_GROUP); + if (pinEventExecutor != null && !pinEventExecutor) { + return group.next(); + } + Map childExecutors = this.childExecutors; + if (childExecutors == null) { + // Use size of 4 as most people only use one extra EventExecutor. + childExecutors = this.childExecutors = new IdentityHashMap(4); + } + // Pin one of the child executors once and remember it so that the same child executor + // is used to fire events for the same channel. 
+ EventExecutor childExecutor = childExecutors.get(group); + if (childExecutor == null) { + childExecutor = group.next(); + childExecutors.put(group, childExecutor); + } + return childExecutor; + } + @Override + public final Channel channel() { + return channel; + } + + @Override + public final ChannelPipeline addFirst(String name, ChannelHandler handler) { + return addFirst(null, name, handler); + } + + @Override + public final ChannelPipeline addFirst(EventExecutorGroup group, String name, ChannelHandler handler) { + final AbstractChannelHandlerContext newCtx; + synchronized (this) { + checkMultiplicity(handler); + name = filterName(name, handler); + + newCtx = newContext(group, name, handler); + + addFirst0(newCtx); + + // If the registered is false it means that the channel was not registered on an eventLoop yet. + // In this case we add the context to the pipeline and add a task that will call + // ChannelHandler.handlerAdded(...) once the channel is registered. + if (!registered) { + newCtx.setAddPending(); + callHandlerCallbackLater(newCtx, true); + return this; + } + + EventExecutor executor = newCtx.executor(); + if (!executor.inEventLoop()) { + callHandlerAddedInEventLoop(newCtx, executor); + return this; + } + } + callHandlerAdded0(newCtx); + return this; + } + + private void addFirst0(AbstractChannelHandlerContext newCtx) { + AbstractChannelHandlerContext nextCtx = head.next; + newCtx.prev = head; + newCtx.next = nextCtx; + head.next = newCtx; + nextCtx.prev = newCtx; + } + + @Override + public final ChannelPipeline addLast(String name, ChannelHandler handler) { + return addLast(null, name, handler); + } + + @Override + public final ChannelPipeline addLast(EventExecutorGroup group, String name, ChannelHandler handler) { + final AbstractChannelHandlerContext newCtx; + synchronized (this) { + checkMultiplicity(handler); + + newCtx = newContext(group, filterName(name, handler), handler); + + addLast0(newCtx); + + // If the registered is false it means that 
the channel was not registered on an eventLoop yet. + // In this case we add the context to the pipeline and add a task that will call + // ChannelHandler.handlerAdded(...) once the channel is registered. + if (!registered) { + newCtx.setAddPending(); + callHandlerCallbackLater(newCtx, true); + return this; + } + + EventExecutor executor = newCtx.executor(); + if (!executor.inEventLoop()) { + callHandlerAddedInEventLoop(newCtx, executor); + return this; + } + } + callHandlerAdded0(newCtx); + return this; + } + + private void addLast0(AbstractChannelHandlerContext newCtx) { + AbstractChannelHandlerContext prev = tail.prev; + newCtx.prev = prev; + newCtx.next = tail; + prev.next = newCtx; + tail.prev = newCtx; + } + + @Override + public final ChannelPipeline addBefore(String baseName, String name, ChannelHandler handler) { + return addBefore(null, baseName, name, handler); + } + + @Override + public final ChannelPipeline addBefore( + EventExecutorGroup group, String baseName, String name, ChannelHandler handler) { + final AbstractChannelHandlerContext newCtx; + final AbstractChannelHandlerContext ctx; + synchronized (this) { + checkMultiplicity(handler); + name = filterName(name, handler); + ctx = getContextOrDie(baseName); + + newCtx = newContext(group, name, handler); + + addBefore0(ctx, newCtx); + + // If the registered is false it means that the channel was not registered on an eventLoop yet. + // In this case we add the context to the pipeline and add a task that will call + // ChannelHandler.handlerAdded(...) once the channel is registered. 
+ if (!registered) { + newCtx.setAddPending(); + callHandlerCallbackLater(newCtx, true); + return this; + } + + EventExecutor executor = newCtx.executor(); + if (!executor.inEventLoop()) { + callHandlerAddedInEventLoop(newCtx, executor); + return this; + } + } + callHandlerAdded0(newCtx); + return this; + } + + private static void addBefore0(AbstractChannelHandlerContext ctx, AbstractChannelHandlerContext newCtx) { + newCtx.prev = ctx.prev; + newCtx.next = ctx; + ctx.prev.next = newCtx; + ctx.prev = newCtx; + } + + private String filterName(String name, ChannelHandler handler) { + if (name == null) { + return generateName(handler); + } + checkDuplicateName(name); + return name; + } + + @Override + public final ChannelPipeline addAfter(String baseName, String name, ChannelHandler handler) { + return addAfter(null, baseName, name, handler); + } + + @Override + public final ChannelPipeline addAfter( + EventExecutorGroup group, String baseName, String name, ChannelHandler handler) { + final AbstractChannelHandlerContext newCtx; + final AbstractChannelHandlerContext ctx; + + synchronized (this) { + checkMultiplicity(handler); + name = filterName(name, handler); + ctx = getContextOrDie(baseName); + + newCtx = newContext(group, name, handler); + + addAfter0(ctx, newCtx); + + // If the registered is false it means that the channel was not registered on an eventLoop yet. + // In this case we remove the context from the pipeline and add a task that will call + // ChannelHandler.handlerRemoved(...) once the channel is registered. 
+ if (!registered) { + newCtx.setAddPending(); + callHandlerCallbackLater(newCtx, true); + return this; + } + EventExecutor executor = newCtx.executor(); + if (!executor.inEventLoop()) { + callHandlerAddedInEventLoop(newCtx, executor); + return this; + } + } + callHandlerAdded0(newCtx); + return this; + } + + private static void addAfter0(AbstractChannelHandlerContext ctx, AbstractChannelHandlerContext newCtx) { + newCtx.prev = ctx; + newCtx.next = ctx.next; + ctx.next.prev = newCtx; + ctx.next = newCtx; + } + + public final ChannelPipeline addFirst(ChannelHandler handler) { + return addFirst(null, handler); + } + + @Override + public final ChannelPipeline addFirst(ChannelHandler... handlers) { + return addFirst(null, handlers); + } + + @Override + public final ChannelPipeline addFirst(EventExecutorGroup executor, ChannelHandler... handlers) { + ObjectUtil.checkNotNull(handlers, "handlers"); + if (handlers.length == 0 || handlers[0] == null) { + return this; + } + + int size; + for (size = 1; size < handlers.length; size ++) { + if (handlers[size] == null) { + break; + } + } + + for (int i = size - 1; i >= 0; i --) { + ChannelHandler h = handlers[i]; + addFirst(executor, null, h); + } + + return this; + } + + public final ChannelPipeline addLast(ChannelHandler handler) { + return addLast(null, handler); + } + + @Override + public final ChannelPipeline addLast(ChannelHandler... handlers) { + return addLast(null, handlers); + } + + @Override + public final ChannelPipeline addLast(EventExecutorGroup executor, ChannelHandler... 
handlers) { + ObjectUtil.checkNotNull(handlers, "handlers"); + + for (ChannelHandler h: handlers) { + if (h == null) { + break; + } + addLast(executor, null, h); + } + + return this; + } + + private String generateName(ChannelHandler handler) { + Map, String> cache = nameCaches.get(); + Class handlerType = handler.getClass(); + String name = cache.get(handlerType); + if (name == null) { + name = generateName0(handlerType); + cache.put(handlerType, name); + } + + // It's not very likely for a user to put more than one handler of the same type, but make sure to avoid + // any name conflicts. Note that we don't cache the names generated here. + if (context0(name) != null) { + String baseName = name.substring(0, name.length() - 1); // Strip the trailing '0'. + for (int i = 1;; i ++) { + String newName = baseName + i; + if (context0(newName) == null) { + name = newName; + break; + } + } + } + return name; + } + + private static String generateName0(Class handlerType) { + return StringUtil.simpleClassName(handlerType) + "#0"; + } + + @Override + public final ChannelPipeline remove(ChannelHandler handler) { + remove(getContextOrDie(handler)); + return this; + } + + @Override + public final ChannelHandler remove(String name) { + return remove(getContextOrDie(name)).handler(); + } + + @SuppressWarnings("unchecked") + @Override + public final T remove(Class handlerType) { + return (T) remove(getContextOrDie(handlerType)).handler(); + } + + public final T removeIfExists(String name) { + return removeIfExists(context(name)); + } + + public final T removeIfExists(Class handlerType) { + return removeIfExists(context(handlerType)); + } + + public final T removeIfExists(ChannelHandler handler) { + return removeIfExists(context(handler)); + } + + @SuppressWarnings("unchecked") + private T removeIfExists(ChannelHandlerContext ctx) { + if (ctx == null) { + return null; + } + return (T) remove((AbstractChannelHandlerContext) ctx).handler(); + } + + private 
AbstractChannelHandlerContext remove(final AbstractChannelHandlerContext ctx) { + assert ctx != head && ctx != tail; + + synchronized (this) { + atomicRemoveFromHandlerList(ctx); + + // If the registered is false it means that the channel was not registered on an eventloop yet. + // In this case we remove the context from the pipeline and add a task that will call + // ChannelHandler.handlerRemoved(...) once the channel is registered. + if (!registered) { + callHandlerCallbackLater(ctx, false); + return ctx; + } + + EventExecutor executor = ctx.executor(); + if (!executor.inEventLoop()) { + executor.execute(new Runnable() { + @Override + public void run() { + callHandlerRemoved0(ctx); + } + }); + return ctx; + } + } + callHandlerRemoved0(ctx); + return ctx; + } + + /** + * Method is synchronized to make the handler removal from the double linked list atomic. + */ + private synchronized void atomicRemoveFromHandlerList(AbstractChannelHandlerContext ctx) { + AbstractChannelHandlerContext prev = ctx.prev; + AbstractChannelHandlerContext next = ctx.next; + prev.next = next; + next.prev = prev; + } + + @Override + public final ChannelHandler removeFirst() { + if (head.next == tail) { + throw new NoSuchElementException(); + } + return remove(head.next).handler(); + } + + @Override + public final ChannelHandler removeLast() { + if (head.next == tail) { + throw new NoSuchElementException(); + } + return remove(tail.prev).handler(); + } + + @Override + public final ChannelPipeline replace(ChannelHandler oldHandler, String newName, ChannelHandler newHandler) { + replace(getContextOrDie(oldHandler), newName, newHandler); + return this; + } + + @Override + public final ChannelHandler replace(String oldName, String newName, ChannelHandler newHandler) { + return replace(getContextOrDie(oldName), newName, newHandler); + } + + @Override + @SuppressWarnings("unchecked") + public final T replace( + Class oldHandlerType, String newName, ChannelHandler newHandler) { + return (T) 
replace(getContextOrDie(oldHandlerType), newName, newHandler); + } + + private ChannelHandler replace( + final AbstractChannelHandlerContext ctx, String newName, ChannelHandler newHandler) { + assert ctx != head && ctx != tail; + + final AbstractChannelHandlerContext newCtx; + synchronized (this) { + checkMultiplicity(newHandler); + if (newName == null) { + newName = generateName(newHandler); + } else { + boolean sameName = ctx.name().equals(newName); + if (!sameName) { + checkDuplicateName(newName); + } + } + + newCtx = newContext(ctx.executor, newName, newHandler); + + replace0(ctx, newCtx); + + // If the registered is false it means that the channel was not registered on an eventloop yet. + // In this case we replace the context in the pipeline + // and add a task that will call ChannelHandler.handlerAdded(...) and + // ChannelHandler.handlerRemoved(...) once the channel is registered. + if (!registered) { + callHandlerCallbackLater(newCtx, true); + callHandlerCallbackLater(ctx, false); + return ctx.handler(); + } + EventExecutor executor = ctx.executor(); + if (!executor.inEventLoop()) { + executor.execute(new Runnable() { + @Override + public void run() { + // Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked) + // because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and + // those event handlers must be called after handlerAdded(). + callHandlerAdded0(newCtx); + callHandlerRemoved0(ctx); + } + }); + return ctx.handler(); + } + } + // Invoke newHandler.handlerAdded() first (i.e. before oldHandler.handlerRemoved() is invoked) + // because callHandlerRemoved() will trigger channelRead() or flush() on newHandler and those + // event handlers must be called after handlerAdded(). 
+ callHandlerAdded0(newCtx); + callHandlerRemoved0(ctx); + return ctx.handler(); + } + + private static void replace0(AbstractChannelHandlerContext oldCtx, AbstractChannelHandlerContext newCtx) { + AbstractChannelHandlerContext prev = oldCtx.prev; + AbstractChannelHandlerContext next = oldCtx.next; + newCtx.prev = prev; + newCtx.next = next; + + // Finish the replacement of oldCtx with newCtx in the linked list. + // Note that this doesn't mean events will be sent to the new handler immediately + // because we are currently at the event handler thread and no more than one handler methods can be invoked + // at the same time (we ensured that in replace().) + prev.next = newCtx; + next.prev = newCtx; + + // update the reference to the replacement so forward of buffered content will work correctly + oldCtx.prev = newCtx; + oldCtx.next = newCtx; + } + + private static void checkMultiplicity(ChannelHandler handler) { + if (handler instanceof ChannelHandlerAdapter) { + ChannelHandlerAdapter h = (ChannelHandlerAdapter) handler; + if (!h.isSharable() && h.added) { + throw new ChannelPipelineException( + h.getClass().getName() + + " is not a @Sharable handler, so can't be added or removed multiple times."); + } + h.added = true; + } + } + + private void callHandlerAdded0(final AbstractChannelHandlerContext ctx) { + try { + ctx.callHandlerAdded(); + } catch (Throwable t) { + boolean removed = false; + try { + atomicRemoveFromHandlerList(ctx); + ctx.callHandlerRemoved(); + removed = true; + } catch (Throwable t2) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to remove a handler: " + ctx.name(), t2); + } + } + + if (removed) { + fireExceptionCaught(new ChannelPipelineException( + ctx.handler().getClass().getName() + + ".handlerAdded() has thrown an exception; removed.", t)); + } else { + fireExceptionCaught(new ChannelPipelineException( + ctx.handler().getClass().getName() + + ".handlerAdded() has thrown an exception; also failed to remove.", t)); + } + } + } + + 
private void callHandlerRemoved0(final AbstractChannelHandlerContext ctx) { + // Notify the complete removal. + try { + ctx.callHandlerRemoved(); + } catch (Throwable t) { + fireExceptionCaught(new ChannelPipelineException( + ctx.handler().getClass().getName() + ".handlerRemoved() has thrown an exception.", t)); + } + } + + final void invokeHandlerAddedIfNeeded() { + assert channel.eventLoop().inEventLoop(); + if (firstRegistration) { + firstRegistration = false; + // We are now registered to the EventLoop. It's time to call the callbacks for the ChannelHandlers, + // that were added before the registration was done. + callHandlerAddedForAllHandlers(); + } + } + + @Override + public final ChannelHandler first() { + ChannelHandlerContext first = firstContext(); + if (first == null) { + return null; + } + return first.handler(); + } + + @Override + public final ChannelHandlerContext firstContext() { + AbstractChannelHandlerContext first = head.next; + if (first == tail) { + return null; + } + return head.next; + } + + @Override + public final ChannelHandler last() { + AbstractChannelHandlerContext last = tail.prev; + if (last == head) { + return null; + } + return last.handler(); + } + + @Override + public final ChannelHandlerContext lastContext() { + AbstractChannelHandlerContext last = tail.prev; + if (last == head) { + return null; + } + return last; + } + + @Override + public final ChannelHandler get(String name) { + ChannelHandlerContext ctx = context(name); + if (ctx == null) { + return null; + } else { + return ctx.handler(); + } + } + + @SuppressWarnings("unchecked") + @Override + public final T get(Class handlerType) { + ChannelHandlerContext ctx = context(handlerType); + if (ctx == null) { + return null; + } else { + return (T) ctx.handler(); + } + } + + @Override + public final ChannelHandlerContext context(String name) { + return context0(ObjectUtil.checkNotNull(name, "name")); + } + + @Override + public final ChannelHandlerContext context(ChannelHandler 
handler) { + ObjectUtil.checkNotNull(handler, "handler"); + + AbstractChannelHandlerContext ctx = head.next; + for (;;) { + + if (ctx == null) { + return null; + } + + if (ctx.handler() == handler) { + return ctx; + } + + ctx = ctx.next; + } + } + + @Override + public final ChannelHandlerContext context(Class handlerType) { + ObjectUtil.checkNotNull(handlerType, "handlerType"); + + AbstractChannelHandlerContext ctx = head.next; + for (;;) { + if (ctx == null) { + return null; + } + if (handlerType.isAssignableFrom(ctx.handler().getClass())) { + return ctx; + } + ctx = ctx.next; + } + } + + @Override + public final List names() { + List list = new ArrayList(); + AbstractChannelHandlerContext ctx = head.next; + for (;;) { + if (ctx == null) { + return list; + } + list.add(ctx.name()); + ctx = ctx.next; + } + } + + @Override + public final Map toMap() { + Map map = new LinkedHashMap(); + AbstractChannelHandlerContext ctx = head.next; + for (;;) { + if (ctx == tail) { + return map; + } + map.put(ctx.name(), ctx.handler()); + ctx = ctx.next; + } + } + + @Override + public final Iterator> iterator() { + return toMap().entrySet().iterator(); + } + + /** + * Returns the {@link String} representation of this pipeline. 
+ */ + @Override + public final String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append('{'); + AbstractChannelHandlerContext ctx = head.next; + for (;;) { + if (ctx == tail) { + break; + } + + buf.append('(') + .append(ctx.name()) + .append(" = ") + .append(ctx.handler().getClass().getName()) + .append(')'); + + ctx = ctx.next; + if (ctx == tail) { + break; + } + + buf.append(", "); + } + buf.append('}'); + return buf.toString(); + } + + @Override + public final ChannelPipeline fireChannelRegistered() { + AbstractChannelHandlerContext.invokeChannelRegistered(head); + return this; + } + + @Override + public final ChannelPipeline fireChannelUnregistered() { + AbstractChannelHandlerContext.invokeChannelUnregistered(head); + return this; + } + + /** + * Removes all handlers from the pipeline one by one from tail (exclusive) to head (exclusive) to trigger + * handlerRemoved(). + * + * Note that we traverse up the pipeline ({@link #destroyUp(AbstractChannelHandlerContext, boolean)}) + * before traversing down ({@link #destroyDown(Thread, AbstractChannelHandlerContext, boolean)}) so that + * the handlers are removed after all events are handled. 
+ * + * See: https://github.com/netty/netty/issues/3156 + */ + private synchronized void destroy() { + destroyUp(head.next, false); + } + + private void destroyUp(AbstractChannelHandlerContext ctx, boolean inEventLoop) { + final Thread currentThread = Thread.currentThread(); + final AbstractChannelHandlerContext tail = this.tail; + for (;;) { + if (ctx == tail) { + destroyDown(currentThread, tail.prev, inEventLoop); + break; + } + + final EventExecutor executor = ctx.executor(); + if (!inEventLoop && !executor.inEventLoop(currentThread)) { + final AbstractChannelHandlerContext finalCtx = ctx; + executor.execute(new Runnable() { + @Override + public void run() { + destroyUp(finalCtx, true); + } + }); + break; + } + + ctx = ctx.next; + inEventLoop = false; + } + } + + private void destroyDown(Thread currentThread, AbstractChannelHandlerContext ctx, boolean inEventLoop) { + // We have reached at tail; now traverse backwards. + final AbstractChannelHandlerContext head = this.head; + for (;;) { + if (ctx == head) { + break; + } + + final EventExecutor executor = ctx.executor(); + if (inEventLoop || executor.inEventLoop(currentThread)) { + atomicRemoveFromHandlerList(ctx); + callHandlerRemoved0(ctx); + } else { + final AbstractChannelHandlerContext finalCtx = ctx; + executor.execute(new Runnable() { + @Override + public void run() { + destroyDown(Thread.currentThread(), finalCtx, true); + } + }); + break; + } + + ctx = ctx.prev; + inEventLoop = false; + } + } + + @Override + public final ChannelPipeline fireChannelActive() { + AbstractChannelHandlerContext.invokeChannelActive(head); + return this; + } + + @Override + public final ChannelPipeline fireChannelInactive() { + AbstractChannelHandlerContext.invokeChannelInactive(head); + return this; + } + + @Override + public final ChannelPipeline fireExceptionCaught(Throwable cause) { + AbstractChannelHandlerContext.invokeExceptionCaught(head, cause); + return this; + } + + @Override + public final ChannelPipeline 
fireUserEventTriggered(Object event) { + AbstractChannelHandlerContext.invokeUserEventTriggered(head, event); + return this; + } + + @Override + public final ChannelPipeline fireChannelRead(Object msg) { + AbstractChannelHandlerContext.invokeChannelRead(head, msg); + return this; + } + + @Override + public final ChannelPipeline fireChannelReadComplete() { + AbstractChannelHandlerContext.invokeChannelReadComplete(head); + return this; + } + + @Override + public final ChannelPipeline fireChannelWritabilityChanged() { + AbstractChannelHandlerContext.invokeChannelWritabilityChanged(head); + return this; + } + + @Override + public final ChannelFuture bind(SocketAddress localAddress) { + return tail.bind(localAddress); + } + + @Override + public final ChannelFuture connect(SocketAddress remoteAddress) { + return tail.connect(remoteAddress); + } + + @Override + public final ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return tail.connect(remoteAddress, localAddress); + } + + @Override + public final ChannelFuture disconnect() { + return tail.disconnect(); + } + + @Override + public final ChannelFuture close() { + return tail.close(); + } + + @Override + public final ChannelFuture deregister() { + return tail.deregister(); + } + + @Override + public final ChannelPipeline flush() { + tail.flush(); + return this; + } + + @Override + public final ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return tail.bind(localAddress, promise); + } + + @Override + public final ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return tail.connect(remoteAddress, promise); + } + + @Override + public final ChannelFuture connect( + SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return tail.connect(remoteAddress, localAddress, promise); + } + + @Override + public final ChannelFuture disconnect(ChannelPromise promise) { + return tail.disconnect(promise); + } + + 
@Override + public final ChannelFuture close(ChannelPromise promise) { + return tail.close(promise); + } + + @Override + public final ChannelFuture deregister(final ChannelPromise promise) { + return tail.deregister(promise); + } + + @Override + public final ChannelPipeline read() { + tail.read(); + return this; + } + + @Override + public final ChannelFuture write(Object msg) { + return tail.write(msg); + } + + @Override + public final ChannelFuture write(Object msg, ChannelPromise promise) { + return tail.write(msg, promise); + } + + @Override + public final ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + return tail.writeAndFlush(msg, promise); + } + + @Override + public final ChannelFuture writeAndFlush(Object msg) { + return tail.writeAndFlush(msg); + } + + @Override + public final ChannelPromise newPromise() { + return new DefaultChannelPromise(channel); + } + + @Override + public final ChannelProgressivePromise newProgressivePromise() { + return new DefaultChannelProgressivePromise(channel); + } + + @Override + public final ChannelFuture newSucceededFuture() { + return succeededFuture; + } + + @Override + public final ChannelFuture newFailedFuture(Throwable cause) { + return new FailedChannelFuture(channel, null, cause); + } + + @Override + public final ChannelPromise voidPromise() { + return voidPromise; + } + + private void checkDuplicateName(String name) { + if (context0(name) != null) { + throw new IllegalArgumentException("Duplicate handler name: " + name); + } + } + + private AbstractChannelHandlerContext context0(String name) { + AbstractChannelHandlerContext context = head.next; + while (context != tail) { + if (context.name().equals(name)) { + return context; + } + context = context.next; + } + return null; + } + + private AbstractChannelHandlerContext getContextOrDie(String name) { + AbstractChannelHandlerContext ctx = (AbstractChannelHandlerContext) context(name); + if (ctx == null) { + throw new NoSuchElementException(name); + 
} else { + return ctx; + } + } + + private AbstractChannelHandlerContext getContextOrDie(ChannelHandler handler) { + AbstractChannelHandlerContext ctx = (AbstractChannelHandlerContext) context(handler); + if (ctx == null) { + throw new NoSuchElementException(handler.getClass().getName()); + } else { + return ctx; + } + } + + private AbstractChannelHandlerContext getContextOrDie(Class handlerType) { + AbstractChannelHandlerContext ctx = (AbstractChannelHandlerContext) context(handlerType); + if (ctx == null) { + throw new NoSuchElementException(handlerType.getName()); + } else { + return ctx; + } + } + + private void callHandlerAddedForAllHandlers() { + final PendingHandlerCallback pendingHandlerCallbackHead; + synchronized (this) { + assert !registered; + + // This Channel itself was registered. + registered = true; + + pendingHandlerCallbackHead = this.pendingHandlerCallbackHead; + // Null out so it can be GC'ed. + this.pendingHandlerCallbackHead = null; + } + + // This must happen outside of the synchronized(...) block as otherwise handlerAdded(...) may be called while + // holding the lock and so produce a deadlock if handlerAdded(...) will try to add another handler from outside + // the EventLoop. + PendingHandlerCallback task = pendingHandlerCallbackHead; + while (task != null) { + task.execute(); + task = task.next; + } + } + + private void callHandlerCallbackLater(AbstractChannelHandlerContext ctx, boolean added) { + assert !registered; + + PendingHandlerCallback task = added ? new PendingHandlerAddedTask(ctx) : new PendingHandlerRemovedTask(ctx); + PendingHandlerCallback pending = pendingHandlerCallbackHead; + if (pending == null) { + pendingHandlerCallbackHead = task; + } else { + // Find the tail of the linked-list. 
+ while (pending.next != null) { + pending = pending.next; + } + pending.next = task; + } + } + + private void callHandlerAddedInEventLoop(final AbstractChannelHandlerContext newCtx, EventExecutor executor) { + newCtx.setAddPending(); + executor.execute(new Runnable() { + @Override + public void run() { + callHandlerAdded0(newCtx); + } + }); + } + + /** + * Called once a {@link Throwable} hit the end of the {@link ChannelPipeline} without been handled by the user + * in {@link ChannelHandler#exceptionCaught(ChannelHandlerContext, Throwable)}. + */ + protected void onUnhandledInboundException(Throwable cause) { + try { + logger.warn( + "An exceptionCaught() event was fired, and it reached at the tail of the pipeline. " + + "It usually means the last handler in the pipeline did not handle the exception.", + cause); + } finally { + ReferenceCountUtil.release(cause); + } + } + + /** + * Called once the {@link ChannelInboundHandler#channelActive(ChannelHandlerContext)}event hit + * the end of the {@link ChannelPipeline}. + */ + protected void onUnhandledInboundChannelActive() { + } + + /** + * Called once the {@link ChannelInboundHandler#channelInactive(ChannelHandlerContext)} event hit + * the end of the {@link ChannelPipeline}. + */ + protected void onUnhandledInboundChannelInactive() { + } + + /** + * Called once a message hit the end of the {@link ChannelPipeline} without been handled by the user + * in {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object)}. This method is responsible + * to call {@link ReferenceCountUtil#release(Object)} on the given msg at some point. + */ + protected void onUnhandledInboundMessage(Object msg) { + try { + logger.debug( + "Discarded inbound message {} that reached at the tail of the pipeline. 
" + + "Please check your pipeline configuration.", msg); + } finally { + ReferenceCountUtil.release(msg); + } + } + + /** + * Called once a message hit the end of the {@link ChannelPipeline} without been handled by the user + * in {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object)}. This method is responsible + * to call {@link ReferenceCountUtil#release(Object)} on the given msg at some point. + */ + protected void onUnhandledInboundMessage(ChannelHandlerContext ctx, Object msg) { + onUnhandledInboundMessage(msg); + if (logger.isDebugEnabled()) { + logger.debug("Discarded message pipeline : {}. Channel : {}.", + ctx.pipeline().names(), ctx.channel()); + } + } + + /** + * Called once the {@link ChannelInboundHandler#channelReadComplete(ChannelHandlerContext)} event hit + * the end of the {@link ChannelPipeline}. + */ + protected void onUnhandledInboundChannelReadComplete() { + } + + /** + * Called once an user event hit the end of the {@link ChannelPipeline} without been handled by the user + * in {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)}. This method is responsible + * to call {@link ReferenceCountUtil#release(Object)} on the given event at some point. + */ + protected void onUnhandledInboundUserEventTriggered(Object evt) { + // This may not be a configuration error and so don't log anything. + // The event may be superfluous for the current pipeline configuration. + ReferenceCountUtil.release(evt); + } + + /** + * Called once the {@link ChannelInboundHandler#channelWritabilityChanged(ChannelHandlerContext)} event hit + * the end of the {@link ChannelPipeline}. 
+ */ + protected void onUnhandledChannelWritabilityChanged() { + } + + @UnstableApi + protected void incrementPendingOutboundBytes(long size) { + ChannelOutboundBuffer buffer = channel.unsafe().outboundBuffer(); + if (buffer != null) { + buffer.incrementPendingOutboundBytes(size); + } + } + + @UnstableApi + protected void decrementPendingOutboundBytes(long size) { + ChannelOutboundBuffer buffer = channel.unsafe().outboundBuffer(); + if (buffer != null) { + buffer.decrementPendingOutboundBytes(size); + } + } + + // A special catch-all handler that handles both bytes and messages. + final class TailContext extends AbstractChannelHandlerContext implements ChannelInboundHandler { + + TailContext(DefaultChannelPipeline pipeline) { + super(pipeline, null, TAIL_NAME, TailContext.class); + setAddComplete(); + } + + @Override + public ChannelHandler handler() { + return this; + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) { } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + onUnhandledInboundChannelActive(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + onUnhandledInboundChannelInactive(); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + onUnhandledChannelWritabilityChanged(); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + onUnhandledInboundUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + onUnhandledInboundException(cause); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + onUnhandledInboundMessage(ctx, msg); + } + + @Override + public void 
channelReadComplete(ChannelHandlerContext ctx) { + onUnhandledInboundChannelReadComplete(); + } + } + + final class HeadContext extends AbstractChannelHandlerContext + implements ChannelOutboundHandler, ChannelInboundHandler { + + private final Unsafe unsafe; + + HeadContext(DefaultChannelPipeline pipeline) { + super(pipeline, null, HEAD_NAME, HeadContext.class); + unsafe = pipeline.channel().unsafe(); + setAddComplete(); + } + + @Override + public ChannelHandler handler() { + return this; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + // NOOP + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + // NOOP + } + + @Override + public void bind( + ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + unsafe.bind(localAddress, promise); + } + + @Override + public void connect( + ChannelHandlerContext ctx, + SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) { + unsafe.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + unsafe.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + unsafe.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + unsafe.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) { + unsafe.beginRead(); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + unsafe.write(msg, promise); + } + + @Override + public void flush(ChannelHandlerContext ctx) { + unsafe.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireExceptionCaught(cause); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + invokeHandlerAddedIfNeeded(); + ctx.fireChannelRegistered(); + } + + @Override 
+ public void channelUnregistered(ChannelHandlerContext ctx) { + ctx.fireChannelUnregistered(); + + // Remove all handlers sequentially if channel is closed and unregistered. + if (!channel.isOpen()) { + destroy(); + } + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.fireChannelActive(); + + readIfIsAutoRead(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + ctx.fireChannelInactive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.fireChannelReadComplete(); + + readIfIsAutoRead(); + } + + private void readIfIsAutoRead() { + if (channel.config().isAutoRead()) { + channel.read(); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ctx.fireUserEventTriggered(evt); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + ctx.fireChannelWritabilityChanged(); + } + } + + private abstract static class PendingHandlerCallback implements Runnable { + final AbstractChannelHandlerContext ctx; + PendingHandlerCallback next; + + PendingHandlerCallback(AbstractChannelHandlerContext ctx) { + this.ctx = ctx; + } + + abstract void execute(); + } + + private final class PendingHandlerAddedTask extends PendingHandlerCallback { + + PendingHandlerAddedTask(AbstractChannelHandlerContext ctx) { + super(ctx); + } + + @Override + public void run() { + callHandlerAdded0(ctx); + } + + @Override + void execute() { + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + callHandlerAdded0(ctx); + } else { + try { + executor.execute(this); + } catch (RejectedExecutionException e) { + if (logger.isWarnEnabled()) { + logger.warn( + "Can't invoke handlerAdded() as the EventExecutor {} rejected it, removing handler {}.", + executor, ctx.name(), e); + } + atomicRemoveFromHandlerList(ctx); + 
ctx.setRemoved(); + } + } + } + } + + private final class PendingHandlerRemovedTask extends PendingHandlerCallback { + + PendingHandlerRemovedTask(AbstractChannelHandlerContext ctx) { + super(ctx); + } + + @Override + public void run() { + callHandlerRemoved0(ctx); + } + + @Override + void execute() { + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + callHandlerRemoved0(ctx); + } else { + try { + executor.execute(this); + } catch (RejectedExecutionException e) { + if (logger.isWarnEnabled()) { + logger.warn( + "Can't invoke handlerRemoved() as the EventExecutor {} rejected it," + + " removing handler {}.", executor, ctx.name(), e); + } + // remove0(...) was call before so just call AbstractChannelHandlerContext.setRemoved(). + ctx.setRemoved(); + } + } + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultChannelProgressivePromise.java b/netty-channel/src/main/java/io/netty/channel/DefaultChannelProgressivePromise.java new file mode 100644 index 0000000..db13879 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultChannelProgressivePromise.java @@ -0,0 +1,179 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.channel.ChannelFlushPromiseNotifier.FlushCheckpoint; +import io.netty.util.concurrent.DefaultProgressivePromise; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; + +/** + * The default {@link ChannelProgressivePromise} implementation. It is recommended to use + * {@link Channel#newProgressivePromise()} to create a new {@link ChannelProgressivePromise} rather than calling the + * constructor explicitly. + */ +public class DefaultChannelProgressivePromise + extends DefaultProgressivePromise implements ChannelProgressivePromise, FlushCheckpoint { + + private final Channel channel; + private long checkpoint; + + /** + * Creates a new instance. + * + * @param channel + * the {@link Channel} associated with this future + */ + public DefaultChannelProgressivePromise(Channel channel) { + this.channel = channel; + } + + /** + * Creates a new instance. 
+ * + * @param channel + * the {@link Channel} associated with this future + */ + public DefaultChannelProgressivePromise(Channel channel, EventExecutor executor) { + super(executor); + this.channel = channel; + } + + @Override + protected EventExecutor executor() { + EventExecutor e = super.executor(); + if (e == null) { + return channel().eventLoop(); + } else { + return e; + } + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public ChannelProgressivePromise setSuccess() { + return setSuccess(null); + } + + @Override + public ChannelProgressivePromise setSuccess(Void result) { + super.setSuccess(result); + return this; + } + + @Override + public boolean trySuccess() { + return trySuccess(null); + } + + @Override + public ChannelProgressivePromise setFailure(Throwable cause) { + super.setFailure(cause); + return this; + } + + @Override + public ChannelProgressivePromise setProgress(long progress, long total) { + super.setProgress(progress, total); + return this; + } + + @Override + public ChannelProgressivePromise addListener(GenericFutureListener> listener) { + super.addListener(listener); + return this; + } + + @Override + public ChannelProgressivePromise addListeners(GenericFutureListener>... listeners) { + super.addListeners(listeners); + return this; + } + + @Override + public ChannelProgressivePromise removeListener(GenericFutureListener> listener) { + super.removeListener(listener); + return this; + } + + @Override + public ChannelProgressivePromise removeListeners( + GenericFutureListener>... 
listeners) { + super.removeListeners(listeners); + return this; + } + + @Override + public ChannelProgressivePromise sync() throws InterruptedException { + super.sync(); + return this; + } + + @Override + public ChannelProgressivePromise syncUninterruptibly() { + super.syncUninterruptibly(); + return this; + } + + @Override + public ChannelProgressivePromise await() throws InterruptedException { + super.await(); + return this; + } + + @Override + public ChannelProgressivePromise awaitUninterruptibly() { + super.awaitUninterruptibly(); + return this; + } + + @Override + public long flushCheckpoint() { + return checkpoint; + } + + @Override + public void flushCheckpoint(long checkpoint) { + this.checkpoint = checkpoint; + } + + @Override + public ChannelProgressivePromise promise() { + return this; + } + + @Override + protected void checkDeadLock() { + if (channel().isRegistered()) { + super.checkDeadLock(); + } + } + + @Override + public ChannelProgressivePromise unvoid() { + return this; + } + + @Override + public boolean isVoid() { + return false; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultChannelPromise.java b/netty-channel/src/main/java/io/netty/channel/DefaultChannelPromise.java new file mode 100644 index 0000000..92ed53c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultChannelPromise.java @@ -0,0 +1,172 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.ChannelFlushPromiseNotifier.FlushCheckpoint; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * The default {@link ChannelPromise} implementation. It is recommended to use {@link Channel#newPromise()} to create + * a new {@link ChannelPromise} rather than calling the constructor explicitly. + */ +public class DefaultChannelPromise extends DefaultPromise implements ChannelPromise, FlushCheckpoint { + + private final Channel channel; + private long checkpoint; + + /** + * Creates a new instance. + * + * @param channel + * the {@link Channel} associated with this future + */ + public DefaultChannelPromise(Channel channel) { + this.channel = checkNotNull(channel, "channel"); + } + + /** + * Creates a new instance. 
+ * + * @param channel + * the {@link Channel} associated with this future + */ + public DefaultChannelPromise(Channel channel, EventExecutor executor) { + super(executor); + this.channel = checkNotNull(channel, "channel"); + } + + @Override + protected EventExecutor executor() { + EventExecutor e = super.executor(); + if (e == null) { + return channel().eventLoop(); + } else { + return e; + } + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public ChannelPromise setSuccess() { + return setSuccess(null); + } + + @Override + public ChannelPromise setSuccess(Void result) { + super.setSuccess(result); + return this; + } + + @Override + public boolean trySuccess() { + return trySuccess(null); + } + + @Override + public ChannelPromise setFailure(Throwable cause) { + super.setFailure(cause); + return this; + } + + @Override + public ChannelPromise addListener(GenericFutureListener> listener) { + super.addListener(listener); + return this; + } + + @Override + public ChannelPromise addListeners(GenericFutureListener>... listeners) { + super.addListeners(listeners); + return this; + } + + @Override + public ChannelPromise removeListener(GenericFutureListener> listener) { + super.removeListener(listener); + return this; + } + + @Override + public ChannelPromise removeListeners(GenericFutureListener>... 
listeners) { + super.removeListeners(listeners); + return this; + } + + @Override + public ChannelPromise sync() throws InterruptedException { + super.sync(); + return this; + } + + @Override + public ChannelPromise syncUninterruptibly() { + super.syncUninterruptibly(); + return this; + } + + @Override + public ChannelPromise await() throws InterruptedException { + super.await(); + return this; + } + + @Override + public ChannelPromise awaitUninterruptibly() { + super.awaitUninterruptibly(); + return this; + } + + @Override + public long flushCheckpoint() { + return checkpoint; + } + + @Override + public void flushCheckpoint(long checkpoint) { + this.checkpoint = checkpoint; + } + + @Override + public ChannelPromise promise() { + return this; + } + + @Override + protected void checkDeadLock() { + if (channel().isRegistered()) { + super.checkDeadLock(); + } + } + + @Override + public ChannelPromise unvoid() { + return this; + } + + @Override + public boolean isVoid() { + return false; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultEventLoop.java b/netty-channel/src/main/java/io/netty/channel/DefaultEventLoop.java new file mode 100644 index 0000000..4f0d734 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultEventLoop.java @@ -0,0 +1,63 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.util.concurrent.DefaultThreadFactory; + +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; + +public class DefaultEventLoop extends SingleThreadEventLoop { + + public DefaultEventLoop() { + this((EventLoopGroup) null); + } + + public DefaultEventLoop(ThreadFactory threadFactory) { + this(null, threadFactory); + } + + public DefaultEventLoop(Executor executor) { + this(null, executor); + } + + public DefaultEventLoop(EventLoopGroup parent) { + this(parent, new DefaultThreadFactory(DefaultEventLoop.class)); + } + + public DefaultEventLoop(EventLoopGroup parent, ThreadFactory threadFactory) { + super(parent, threadFactory, true); + } + + public DefaultEventLoop(EventLoopGroup parent, Executor executor) { + super(parent, executor, true); + } + + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + runTask(task); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultEventLoopGroup.java b/netty-channel/src/main/java/io/netty/channel/DefaultEventLoopGroup.java new file mode 100644 index 0000000..d12c288 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultEventLoopGroup.java @@ -0,0 +1,75 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.channel;

import java.util.concurrent.Executor;
import java.util.concurrent.ThreadFactory;

/**
 * {@link MultithreadEventLoopGroup} which must be used for the local transport.
 * Each child loop is a {@link DefaultEventLoop}.
 */
public class DefaultEventLoopGroup extends MultithreadEventLoopGroup {

    /**
     * Create a new instance with the default number of threads.
     */
    public DefaultEventLoopGroup() {
        this(0, (ThreadFactory) null);
    }

    /**
     * Create a new instance.
     *
     * @param nThreads the number of threads to use, or {@code 0} for the default
     */
    public DefaultEventLoopGroup(int nThreads) {
        this(nThreads, (ThreadFactory) null);
    }

    /**
     * Create a new instance with the default number of threads and the given {@link ThreadFactory}.
     *
     * @param threadFactory the {@link ThreadFactory}, or {@code null} to use the default
     */
    public DefaultEventLoopGroup(ThreadFactory threadFactory) {
        this(0, threadFactory);
    }

    /**
     * Create a new instance.
     *
     * @param nThreads      the number of threads to use, or {@code 0} for the default
     * @param threadFactory the {@link ThreadFactory}, or {@code null} to use the default
     */
    public DefaultEventLoopGroup(int nThreads, ThreadFactory threadFactory) {
        super(nThreads, threadFactory);
    }

    /**
     * Create a new instance.
     *
     * @param nThreads the number of threads to use, or {@code 0} for the default
     * @param executor the {@link Executor} to use, or {@code null} if the default should be used
     */
    public DefaultEventLoopGroup(int nThreads, Executor executor) {
        super(nThreads, executor);
    }

    @Override
    protected EventLoop newChild(Executor executor, Object... args) throws Exception {
        return new DefaultEventLoop(this, executor);
    }
}
package io.netty.channel;

import io.netty.util.AbstractReferenceCounted;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.io.File;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.nio.channels.FileChannel;
import java.nio.channels.WritableByteChannel;

import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;

/**
 * Default {@link FileRegion} implementation which transfers data from a {@link FileChannel}
 * or {@link File}.
 *
 * <p>Be aware that the {@link FileChannel} will be automatically closed once {@link #refCnt()}
 * returns {@code 0}.
 */
public class DefaultFileRegion extends AbstractReferenceCounted implements FileRegion {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultFileRegion.class);
    private final File f;
    private final long position;
    private final long count;
    private long transferred;
    // Non-null once open; nulled out on deallocation so the channel can be closed exactly once.
    private FileChannel file;

    /**
     * Create a new instance from an already-open channel.
     *
     * @param fileChannel the {@link FileChannel} which should be transferred
     * @param position    the position from which the transfer should start (&gt;= 0)
     * @param count       the number of bytes to transfer (&gt;= 0)
     */
    public DefaultFileRegion(FileChannel fileChannel, long position, long count) {
        this.file = ObjectUtil.checkNotNull(fileChannel, "fileChannel");
        this.position = checkPositiveOrZero(position, "position");
        this.count = checkPositiveOrZero(count, "count");
        this.f = null;
    }

    /**
     * Create a new instance using the given {@link File}. The {@link File} will be opened lazily or
     * explicitly via {@link #open()}.
     *
     * @param file     the {@link File} which should be transferred
     * @param position the position from which the transfer should start (&gt;= 0)
     * @param count    the number of bytes to transfer (&gt;= 0)
     */
    public DefaultFileRegion(File file, long position, long count) {
        this.f = ObjectUtil.checkNotNull(file, "file");
        this.position = checkPositiveOrZero(position, "position");
        this.count = checkPositiveOrZero(count, "count");
    }

    /**
     * Returns {@code true} if the {@link FileRegion} has an open file-descriptor.
     */
    public boolean isOpen() {
        return file != null;
    }

    /**
     * Explicitly open the underlying file-descriptor if not done yet.
     *
     * @throws IOException if the backing file cannot be opened
     */
    public void open() throws IOException {
        // Only open if this DefaultFileRegion was not released yet.
        if (!isOpen() && refCnt() > 0) {
            file = new RandomAccessFile(f, "r").getChannel();
        }
    }

    @Override
    public long position() {
        return position;
    }

    @Override
    public long count() {
        return count;
    }

    @Deprecated
    @Override
    public long transfered() {
        return transferred;
    }

    @Override
    public long transferred() {
        return transferred;
    }

    @Override
    public long transferTo(WritableByteChannel target, long position) throws IOException {
        // Note: this local shadows the region's absolute start offset; it is the number of
        // bytes remaining from the caller-relative position.
        long count = this.count - position;
        if (count < 0 || position < 0) {
            throw new IllegalArgumentException(
                    "position out of range: " + position +
                    " (expected: 0 - " + (this.count - 1) + ')');
        }
        if (count == 0) {
            return 0L;
        }
        if (refCnt() == 0) {
            throw new IllegalReferenceCountException(0);
        }
        // Call open to make sure fc is initialized. This is a no-op if we called it before.
        open();

        long written = file.transferTo(this.position + position, count, target);
        if (written > 0) {
            transferred += written;
        } else if (written == 0) {
            // If the amount of written data is 0 we need to check if the requested count is
            // bigger than the actual file itself as it may have been truncated on disk.
            //
            // See https://github.com/netty/netty/issues/8868
            validate(this, position);
        }
        return written;
    }

    @Override
    protected void deallocate() {
        FileChannel file = this.file;

        if (file == null) {
            return;
        }
        this.file = null;

        try {
            file.close();
        } catch (IOException e) {
            logger.warn("Failed to close a file.", e);
        }
    }

    @Override
    public FileRegion retain() {
        super.retain();
        return this;
    }

    @Override
    public FileRegion retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public FileRegion touch() {
        return this;
    }

    @Override
    public FileRegion touch(Object hint) {
        return this;
    }

    /**
     * Verifies that the backing file is still at least as large as the requested region;
     * a zero-byte transfer usually means the file was truncated on disk.
     *
     * @throws IOException if the underlying file is smaller than the requested count
     */
    static void validate(DefaultFileRegion region, long position) throws IOException {
        // If the amount of written data is 0 we need to check if the requested count is
        // bigger than the actual file itself as it may have been truncated on disk.
        //
        // See https://github.com/netty/netty/issues/8868
        long size = region.file.size();
        long count = region.count - position;
        if (region.position + count + position > size) {
            throw new IOException("Underlying file size " + size + " smaller than requested count " + region.count);
        }
    }
}
package io.netty.channel;

import static io.netty.util.internal.ObjectUtil.checkPositive;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.UncheckedBooleanSupplier;

import java.util.AbstractMap;
import java.util.Map.Entry;

/**
 * The {@link RecvByteBufAllocator} that yields a buffer size prediction based upon
 * decrementing the value from the max bytes per read.
 */
public class DefaultMaxBytesRecvByteBufAllocator implements MaxBytesRecvByteBufAllocator {
    // Invariant: maxBytesPerIndividualRead <= maxBytesPerRead. Writes are synchronized;
    // independent reads only need volatile visibility.
    private volatile int maxBytesPerRead;
    private volatile int maxBytesPerIndividualRead;

    private final class HandleImpl implements ExtendedHandle {
        private int individualReadMax;
        private int bytesToRead;
        private int lastBytesRead;
        private int attemptBytesRead;
        private final UncheckedBooleanSupplier defaultMaybeMoreSupplier = new UncheckedBooleanSupplier() {
            @Override
            public boolean get() {
                // The last read filled the buffer we offered, so more data may be pending.
                return attemptBytesRead == lastBytesRead;
            }
        };

        @Override
        public ByteBuf allocate(ByteBufAllocator alloc) {
            return alloc.ioBuffer(guess());
        }

        @Override
        public int guess() {
            // Never offer more than the per-read cap or the remaining quota.
            return Math.min(bytesToRead, individualReadMax);
        }

        @Override
        public void reset(ChannelConfig config) {
            bytesToRead = maxBytesPerRead();
            individualReadMax = maxBytesPerIndividualRead();
        }

        @Override
        public void incMessagesRead(int amt) {
        }

        @Override
        public void lastBytesRead(int bytes) {
            lastBytesRead = bytes;
            // Ignore if bytes is negative, the interface contract states it will be detected
            // externally after call. The value may be "invalid" after this point, but it
            // doesn't matter because reading will be stopped.
            bytesToRead -= bytes;
        }

        @Override
        public int lastBytesRead() {
            return lastBytesRead;
        }

        @Override
        public boolean continueReading() {
            return continueReading(defaultMaybeMoreSupplier);
        }

        @Override
        public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) {
            // Keep reading while quota remains and the last read filled the buffer we provided.
            return bytesToRead > 0 && maybeMoreDataSupplier.get();
        }

        @Override
        public void readComplete() {
        }

        @Override
        public void attemptedBytesRead(int bytes) {
            attemptBytesRead = bytes;
        }

        @Override
        public int attemptedBytesRead() {
            return attemptBytesRead;
        }
    }

    public DefaultMaxBytesRecvByteBufAllocator() {
        this(64 * 1024, 64 * 1024);
    }

    public DefaultMaxBytesRecvByteBufAllocator(int maxBytesPerRead, int maxBytesPerIndividualRead) {
        checkMaxBytesPerReadPair(maxBytesPerRead, maxBytesPerIndividualRead);
        this.maxBytesPerRead = maxBytesPerRead;
        this.maxBytesPerIndividualRead = maxBytesPerIndividualRead;
    }

    @SuppressWarnings("deprecation")
    @Override
    public Handle newHandle() {
        return new HandleImpl();
    }

    @Override
    public int maxBytesPerRead() {
        return maxBytesPerRead;
    }

    @Override
    public DefaultMaxBytesRecvByteBufAllocator maxBytesPerRead(int maxBytesPerRead) {
        checkPositive(maxBytesPerRead, "maxBytesPerRead");
        // There is a dependency between this.maxBytesPerRead and this.maxBytesPerIndividualRead
        // (a < b). Write operations must be synchronized, but independent read operations can
        // just be volatile.
        synchronized (this) {
            final int maxBytesPerIndividualRead = maxBytesPerIndividualRead();
            if (maxBytesPerRead < maxBytesPerIndividualRead) {
                throw new IllegalArgumentException(
                        "maxBytesPerRead cannot be less than " +
                        "maxBytesPerIndividualRead (" + maxBytesPerIndividualRead + "): " + maxBytesPerRead);
            }

            this.maxBytesPerRead = maxBytesPerRead;
        }
        return this;
    }

    @Override
    public int maxBytesPerIndividualRead() {
        return maxBytesPerIndividualRead;
    }

    @Override
    public DefaultMaxBytesRecvByteBufAllocator maxBytesPerIndividualRead(int maxBytesPerIndividualRead) {
        checkPositive(maxBytesPerIndividualRead, "maxBytesPerIndividualRead");
        // There is a dependency between this.maxBytesPerRead and this.maxBytesPerIndividualRead
        // (a < b). Write operations must be synchronized, but independent read operations can
        // just be volatile.
        synchronized (this) {
            final int maxBytesPerRead = maxBytesPerRead();
            if (maxBytesPerIndividualRead > maxBytesPerRead) {
                throw new IllegalArgumentException(
                        "maxBytesPerIndividualRead cannot be greater than " +
                        "maxBytesPerRead (" + maxBytesPerRead + "): " + maxBytesPerIndividualRead);
            }

            this.maxBytesPerIndividualRead = maxBytesPerIndividualRead;
        }
        return this;
    }

    @Override
    public synchronized Entry<Integer, Integer> maxBytesPerReadPair() {
        // Snapshot both values atomically relative to the synchronized writers.
        return new AbstractMap.SimpleEntry<Integer, Integer>(maxBytesPerRead, maxBytesPerIndividualRead);
    }

    private static void checkMaxBytesPerReadPair(int maxBytesPerRead, int maxBytesPerIndividualRead) {
        checkPositive(maxBytesPerRead, "maxBytesPerRead");
        checkPositive(maxBytesPerIndividualRead, "maxBytesPerIndividualRead");
        if (maxBytesPerRead < maxBytesPerIndividualRead) {
            throw new IllegalArgumentException(
                    "maxBytesPerRead cannot be less than " +
                    "maxBytesPerIndividualRead (" + maxBytesPerIndividualRead + "): " + maxBytesPerRead);
        }
    }

    @Override
    public DefaultMaxBytesRecvByteBufAllocator maxBytesPerReadPair(int maxBytesPerRead,
            int maxBytesPerIndividualRead) {
        checkMaxBytesPerReadPair(maxBytesPerRead, maxBytesPerIndividualRead);
        // There is a dependency between this.maxBytesPerRead and this.maxBytesPerIndividualRead
        // (a < b). Write operations must be synchronized, but independent read operations can
        // just be volatile.
        synchronized (this) {
            this.maxBytesPerRead = maxBytesPerRead;
            this.maxBytesPerIndividualRead = maxBytesPerIndividualRead;
        }
        return this;
    }
}
+ */ +public abstract class DefaultMaxMessagesRecvByteBufAllocator implements MaxMessagesRecvByteBufAllocator { + private final boolean ignoreBytesRead; + private volatile int maxMessagesPerRead; + private volatile boolean respectMaybeMoreData = true; + + public DefaultMaxMessagesRecvByteBufAllocator() { + this(1); + } + + public DefaultMaxMessagesRecvByteBufAllocator(int maxMessagesPerRead) { + this(maxMessagesPerRead, false); + } + + DefaultMaxMessagesRecvByteBufAllocator(int maxMessagesPerRead, boolean ignoreBytesRead) { + this.ignoreBytesRead = ignoreBytesRead; + maxMessagesPerRead(maxMessagesPerRead); + } + + @Override + public int maxMessagesPerRead() { + return maxMessagesPerRead; + } + + @Override + public MaxMessagesRecvByteBufAllocator maxMessagesPerRead(int maxMessagesPerRead) { + checkPositive(maxMessagesPerRead, "maxMessagesPerRead"); + this.maxMessagesPerRead = maxMessagesPerRead; + return this; + } + + /** + * Determine if future instances of {@link #newHandle()} will stop reading if we think there is no more data. + * @param respectMaybeMoreData + *

    + *
  • {@code true} to stop reading if we think there is no more data. This may save a system call to read from + * the socket, but if data has arrived in a racy fashion we may give up our {@link #maxMessagesPerRead()} + * quantum and have to wait for the selector to notify us of more data.
  • + *
  • {@code false} to keep reading (up to {@link #maxMessagesPerRead()}) or until there is no data when we + * attempt to read.
  • + *
+ * @return {@code this}. + */ + public DefaultMaxMessagesRecvByteBufAllocator respectMaybeMoreData(boolean respectMaybeMoreData) { + this.respectMaybeMoreData = respectMaybeMoreData; + return this; + } + + /** + * Get if future instances of {@link #newHandle()} will stop reading if we think there is no more data. + * @return + *
    + *
  • {@code true} to stop reading if we think there is no more data. This may save a system call to read from + * the socket, but if data has arrived in a racy fashion we may give up our {@link #maxMessagesPerRead()} + * quantum and have to wait for the selector to notify us of more data.
  • + *
  • {@code false} to keep reading (up to {@link #maxMessagesPerRead()}) or until there is no data when we + * attempt to read.
  • + *
+ */ + public final boolean respectMaybeMoreData() { + return respectMaybeMoreData; + } + + /** + * Focuses on enforcing the maximum messages per read condition for {@link #continueReading()}. + */ + public abstract class MaxMessageHandle implements ExtendedHandle { + private ChannelConfig config; + private int maxMessagePerRead; + private int totalMessages; + private int totalBytesRead; + private int attemptedBytesRead; + private int lastBytesRead; + private final boolean respectMaybeMoreData = DefaultMaxMessagesRecvByteBufAllocator.this.respectMaybeMoreData; + private final UncheckedBooleanSupplier defaultMaybeMoreSupplier = new UncheckedBooleanSupplier() { + @Override + public boolean get() { + return attemptedBytesRead == lastBytesRead; + } + }; + + /** + * Only {@link ChannelConfig#getMaxMessagesPerRead()} is used. + */ + @Override + public void reset(ChannelConfig config) { + this.config = config; + maxMessagePerRead = maxMessagesPerRead(); + totalMessages = totalBytesRead = 0; + } + + @Override + public ByteBuf allocate(ByteBufAllocator alloc) { + return alloc.ioBuffer(guess()); + } + + @Override + public final void incMessagesRead(int amt) { + totalMessages += amt; + } + + @Override + public void lastBytesRead(int bytes) { + lastBytesRead = bytes; + if (bytes > 0) { + totalBytesRead += bytes; + } + } + + @Override + public final int lastBytesRead() { + return lastBytesRead; + } + + @Override + public boolean continueReading() { + return continueReading(defaultMaybeMoreSupplier); + } + + @Override + public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) { + return config.isAutoRead() && + (!respectMaybeMoreData || maybeMoreDataSupplier.get()) && + totalMessages < maxMessagePerRead && (ignoreBytesRead || totalBytesRead > 0); + } + + @Override + public void readComplete() { + } + + @Override + public int attemptedBytesRead() { + return attemptedBytesRead; + } + + @Override + public void attemptedBytesRead(int bytes) { + 
attemptedBytesRead = bytes; + } + + protected final int totalBytesRead() { + return totalBytesRead < 0 ? Integer.MAX_VALUE : totalBytesRead; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java b/netty-channel/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java new file mode 100644 index 0000000..ef4d542 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultMessageSizeEstimator.java @@ -0,0 +1,72 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; + +/** + * Default {@link MessageSizeEstimator} implementation which supports the estimation of the size of + * {@link ByteBuf}, {@link ByteBufHolder} and {@link FileRegion}. 
+ */ +public final class DefaultMessageSizeEstimator implements MessageSizeEstimator { + + private static final class HandleImpl implements Handle { + private final int unknownSize; + + private HandleImpl(int unknownSize) { + this.unknownSize = unknownSize; + } + + @Override + public int size(Object msg) { + if (msg instanceof ByteBuf) { + return ((ByteBuf) msg).readableBytes(); + } + if (msg instanceof ByteBufHolder) { + return ((ByteBufHolder) msg).content().readableBytes(); + } + if (msg instanceof FileRegion) { + return 0; + } + return unknownSize; + } + } + + /** + * Return the default implementation which returns {@code 8} for unknown messages. + */ + public static final MessageSizeEstimator DEFAULT = new DefaultMessageSizeEstimator(8); + + private final Handle handle; + + /** + * Create a new instance + * + * @param unknownSize The size which is returned for unknown messages. + */ + public DefaultMessageSizeEstimator(int unknownSize) { + checkPositiveOrZero(unknownSize, "unknownSize"); + handle = new HandleImpl(unknownSize); + } + + @Override + public Handle newHandle() { + return handle; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategy.java b/netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategy.java new file mode 100644 index 0000000..eafe2a7 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategy.java @@ -0,0 +1,32 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.IntSupplier; + +/** + * Default select strategy. + */ +final class DefaultSelectStrategy implements SelectStrategy { + static final SelectStrategy INSTANCE = new DefaultSelectStrategy(); + + private DefaultSelectStrategy() { } + + @Override + public int calculateStrategy(IntSupplier selectSupplier, boolean hasTasks) throws Exception { + return hasTasks ? selectSupplier.get() : SelectStrategy.SELECT; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategyFactory.java b/netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategyFactory.java new file mode 100644 index 0000000..a2fb175 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/DefaultSelectStrategyFactory.java @@ -0,0 +1,30 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +/** + * Factory which uses the default select strategy. 
/**
 * Factory which uses the default select strategy.
 */
public final class DefaultSelectStrategyFactory implements SelectStrategyFactory {

    // The factory is stateless, so a single shared instance is sufficient.
    public static final SelectStrategyFactory INSTANCE = new DefaultSelectStrategyFactory();

    private DefaultSelectStrategyFactory() { }

    /**
     * Returns the shared {@link DefaultSelectStrategy} instance.
     */
    @Override
    public SelectStrategy newSelectStrategy() {
        return DefaultSelectStrategy.INSTANCE;
    }
}
+ */ +package io.netty.channel; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.internal.PromiseNotificationUtil; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +@UnstableApi +public final class DelegatingChannelPromiseNotifier implements ChannelPromise, ChannelFutureListener { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(DelegatingChannelPromiseNotifier.class); + private final ChannelPromise delegate; + private final boolean logNotifyFailure; + + public DelegatingChannelPromiseNotifier(ChannelPromise delegate) { + this(delegate, !(delegate instanceof VoidChannelPromise)); + } + + public DelegatingChannelPromiseNotifier(ChannelPromise delegate, boolean logNotifyFailure) { + this.delegate = checkNotNull(delegate, "delegate"); + this.logNotifyFailure = logNotifyFailure; + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + InternalLogger internalLogger = logNotifyFailure ? 
logger : null; + if (future.isSuccess()) { + Void result = future.get(); + PromiseNotificationUtil.trySuccess(delegate, result, internalLogger); + } else if (future.isCancelled()) { + PromiseNotificationUtil.tryCancel(delegate, internalLogger); + } else { + Throwable cause = future.cause(); + PromiseNotificationUtil.tryFailure(delegate, cause, internalLogger); + } + } + + @Override + public Channel channel() { + return delegate.channel(); + } + + @Override + public ChannelPromise setSuccess(Void result) { + delegate.setSuccess(result); + return this; + } + + @Override + public ChannelPromise setSuccess() { + delegate.setSuccess(); + return this; + } + + @Override + public boolean trySuccess() { + return delegate.trySuccess(); + } + + @Override + public boolean trySuccess(Void result) { + return delegate.trySuccess(result); + } + + @Override + public ChannelPromise setFailure(Throwable cause) { + delegate.setFailure(cause); + return this; + } + + @Override + public ChannelPromise addListener(GenericFutureListener> listener) { + delegate.addListener(listener); + return this; + } + + @Override + public ChannelPromise addListeners(GenericFutureListener>... listeners) { + delegate.addListeners(listeners); + return this; + } + + @Override + public ChannelPromise removeListener(GenericFutureListener> listener) { + delegate.removeListener(listener); + return this; + } + + @Override + public ChannelPromise removeListeners(GenericFutureListener>... 
listeners) { + delegate.removeListeners(listeners); + return this; + } + + @Override + public boolean tryFailure(Throwable cause) { + return delegate.tryFailure(cause); + } + + @Override + public boolean setUncancellable() { + return delegate.setUncancellable(); + } + + @Override + public ChannelPromise await() throws InterruptedException { + delegate.await(); + return this; + } + + @Override + public ChannelPromise awaitUninterruptibly() { + delegate.awaitUninterruptibly(); + return this; + } + + @Override + public boolean isVoid() { + return delegate.isVoid(); + } + + @Override + public ChannelPromise unvoid() { + return isVoid() ? new DelegatingChannelPromiseNotifier(delegate.unvoid()) : this; + } + + @Override + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + return delegate.await(timeout, unit); + } + + @Override + public boolean await(long timeoutMillis) throws InterruptedException { + return delegate.await(timeoutMillis); + } + + @Override + public boolean awaitUninterruptibly(long timeout, TimeUnit unit) { + return delegate.awaitUninterruptibly(timeout, unit); + } + + @Override + public boolean awaitUninterruptibly(long timeoutMillis) { + return delegate.awaitUninterruptibly(timeoutMillis); + } + + @Override + public Void getNow() { + return delegate.getNow(); + } + + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return delegate.cancel(mayInterruptIfRunning); + } + + @Override + public boolean isCancelled() { + return delegate.isCancelled(); + } + + @Override + public boolean isDone() { + return delegate.isDone(); + } + + @Override + public Void get() throws InterruptedException, ExecutionException { + return delegate.get(); + } + + @Override + public Void get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + return delegate.get(timeout, unit); + } + + @Override + public ChannelPromise sync() throws InterruptedException { + delegate.sync(); + return 
this; + } + + @Override + public ChannelPromise syncUninterruptibly() { + delegate.syncUninterruptibly(); + return this; + } + + @Override + public boolean isSuccess() { + return delegate.isSuccess(); + } + + @Override + public boolean isCancellable() { + return delegate.isCancellable(); + } + + @Override + public Throwable cause() { + return delegate.cause(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/EventLoop.java b/netty-channel/src/main/java/io/netty/channel/EventLoop.java new file mode 100644 index 0000000..d9720da --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/EventLoop.java @@ -0,0 +1,30 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.concurrent.OrderedEventExecutor; + +/** + * Will handle all the I/O operations for a {@link Channel} once registered. + * + * One {@link EventLoop} instance will usually handle more than one {@link Channel} but this may depend on + * implementation details and internals. 
/**
 * Will handle all the I/O operations for a {@link Channel} once registered.
 *
 * One {@link EventLoop} instance will usually handle more than one {@link Channel} but this may depend on
 * implementation details and internals.
 */
public interface EventLoop extends OrderedEventExecutor, EventLoopGroup {
    /**
     * Returns the {@link EventLoopGroup} which is the parent of this {@link EventLoop}.
     */
    @Override
    EventLoopGroup parent();
}
/**
 * Special {@link ChannelException} which will be thrown by {@link EventLoop} and {@link EventLoopGroup}
 * implementations when an error occurs.
 */
public class EventLoopException extends ChannelException {

    private static final long serialVersionUID = -8969100344583703616L;

    /**
     * Creates a new instance with neither a message nor a cause.
     */
    public EventLoopException() {
    }

    /**
     * Creates a new instance with the given detail message and cause.
     */
    public EventLoopException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates a new instance with the given detail message.
     */
    public EventLoopException(String message) {
        super(message);
    }

    /**
     * Creates a new instance with the given cause.
     */
    public EventLoopException(Throwable cause) {
        super(cause);
    }

}
/**
 * Special {@link EventExecutorGroup} which allows registering {@link Channel}s that get
 * processed for later selection during the event loop.
 */
public interface EventLoopGroup extends EventExecutorGroup {
    /**
     * Return the next {@link EventLoop} to use.
     */
    @Override
    EventLoop next();

    /**
     * Register a {@link Channel} with this {@link EventLoop}. The returned {@link ChannelFuture}
     * will get notified once the registration was complete.
     */
    ChannelFuture register(Channel channel);

    /**
     * Register a {@link Channel} with this {@link EventLoop} using a {@link ChannelFuture}. The passed
     * {@link ChannelFuture} will get notified once the registration was complete and also will get returned.
     */
    ChannelFuture register(ChannelPromise promise);

    /**
     * Register a {@link Channel} with this {@link EventLoop}. The passed {@link ChannelFuture}
     * will get notified once the registration was complete and also will get returned.
     *
     * @deprecated Use {@link #register(ChannelPromise)} instead.
     */
    @Deprecated
    ChannelFuture register(Channel channel, ChannelPromise promise);
}
/**
 * Factory used to create {@link Queue} instances that will be used to store tasks for an {@link EventLoop}.
 *
 * Generally speaking the returned {@link Queue} MUST be thread-safe and depending on the {@link EventLoop}
 * implementation must be of type {@link java.util.concurrent.BlockingQueue}.
 */
public interface EventLoopTaskQueueFactory {

    /**
     * Returns a new {@link Queue} to use.
     *
     * @param maxCapacity the maximum amount of elements that can be stored in the {@link Queue} at a given point
     *                    in time.
     * @return the new queue.
     */
    Queue<Runnable> newTaskQueue(int maxCapacity);
}
/**
 * {@link ClosedChannelException} variant that can carry the original cause of the close
 * and that never captures a stack trace, keeping it cheap to construct and throw.
 */
final class ExtendedClosedChannelException extends ClosedChannelException {

    /**
     * Creates a new instance.
     *
     * @param cause the reason the channel was closed, or {@code null} if unknown
     */
    ExtendedClosedChannelException(Throwable cause) {
        super();
        if (cause == null) {
            return;
        }
        // ClosedChannelException offers no cause-taking constructor, so attach it afterwards.
        initCause(cause);
    }

    // Deliberate no-op: the stack trace carries no useful information here and capturing
    // it is expensive. The method does not touch shared state, so no synchronization is needed.
    @Override
    public Throwable fillInStackTrace() {
        return this;
    }
}
+ * + * @param channel the {@link Channel} associated with this future + * @param cause the cause of failure + */ + FailedChannelFuture(Channel channel, EventExecutor executor, Throwable cause) { + super(channel, executor); + this.cause = ObjectUtil.checkNotNull(cause, "cause"); + } + + @Override + public Throwable cause() { + return cause; + } + + @Override + public boolean isSuccess() { + return false; + } + + @Override + public ChannelFuture sync() { + PlatformDependent.throwException(cause); + return this; + } + + @Override + public ChannelFuture syncUninterruptibly() { + PlatformDependent.throwException(cause); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/FileRegion.java b/netty-channel/src/main/java/io/netty/channel/FileRegion.java new file mode 100644 index 0000000..fec73bc --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/FileRegion.java @@ -0,0 +1,101 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.ReferenceCounted; + +import java.io.IOException; +import java.nio.channels.FileChannel; +import java.nio.channels.WritableByteChannel; + +/** + * A region of a file that is sent via a {@link Channel} which supports + * zero-copy file transfer. + * + *

/**
 * A region of a file that is sent via a {@link Channel} which supports
 * zero-copy file transfer.
 *
 * <h3>Upgrade your JDK / JRE</h3>
 *
 * {@link FileChannel#transferTo(long, long, WritableByteChannel)} has at least
 * four known bugs in the old versions of Sun JDK and perhaps its derived ones.
 * Please upgrade your JDK to 1.6.0_18 or later version if you are going to use
 * zero-copy file transfer.
 * <ul>
 * <li>JDK bug 5103988
 *   - FileChannel.transferTo() should return -1 for EAGAIN instead throws IOException</li>
 * <li>JDK bug 6253145
 *   - FileChannel.transferTo() on Linux fails when going beyond 2GB boundary</li>
 * <li>JDK bug 6427312
 *   - FileChannel.transferTo() throws IOException "system call interrupted"</li>
 * <li>JDK bug 6470086
 *   - FileChannel.transferTo(2147483647, 1, channel) causes "Value too large" exception</li>
 * </ul>
 *
 * <h3>Check your operating system and JDK / JRE</h3>
 *
 * If your operating system (or JDK / JRE) does not support zero-copy file
 * transfer, sending a file with {@link FileRegion} might fail or yield worse
 * performance. For example, sending a large file doesn't work well in Windows.
 *
 * <h3>Not all transports support it</h3>
 */
public interface FileRegion extends ReferenceCounted {

    /**
     * Returns the offset in the file where the transfer began.
     */
    long position();

    /**
     * Returns the bytes which was transferred already.
     *
     * @deprecated Use {@link #transferred()} instead.
     */
    @Deprecated
    long transfered();

    /**
     * Returns the bytes which was transferred already.
     */
    long transferred();

    /**
     * Returns the number of bytes to transfer.
     */
    long count();

    /**
     * Transfers the content of this file region to the specified channel.
     *
     * @param target   the destination of the transfer
     * @param position the relative offset of the file where the transfer
     *                 begins from. For example, 0 will make the
     *                 transfer start from {@link #position()}th byte and
     *                 {@link #count()} - 1 will make the last
     *                 byte of the region transferred.
     */
    long transferTo(WritableByteChannel target, long position) throws IOException;

    @Override
    FileRegion retain();

    @Override
    FileRegion retain(int increment);

    @Override
    FileRegion touch();

    @Override
    FileRegion touch(Object hint);
}
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * The {@link RecvByteBufAllocator} that always yields the same buffer + * size prediction. This predictor ignores the feed back from the I/O thread. + */ +public class FixedRecvByteBufAllocator extends DefaultMaxMessagesRecvByteBufAllocator { + + private final int bufferSize; + + private final class HandleImpl extends MaxMessageHandle { + private final int bufferSize; + + HandleImpl(int bufferSize) { + this.bufferSize = bufferSize; + } + + @Override + public int guess() { + return bufferSize; + } + } + + /** + * Creates a new predictor that always returns the same prediction of + * the specified buffer size. + */ + public FixedRecvByteBufAllocator(int bufferSize) { + checkPositive(bufferSize, "bufferSize"); + this.bufferSize = bufferSize; + } + + @SuppressWarnings("deprecation") + @Override + public Handle newHandle() { + return new HandleImpl(bufferSize); + } + + @Override + public FixedRecvByteBufAllocator respectMaybeMoreData(boolean respectMaybeMoreData) { + super.respectMaybeMoreData(respectMaybeMoreData); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/MaxBytesRecvByteBufAllocator.java b/netty-channel/src/main/java/io/netty/channel/MaxBytesRecvByteBufAllocator.java new file mode 100644 index 0000000..ce7d79b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/MaxBytesRecvByteBufAllocator.java @@ -0,0 +1,65 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import java.util.Map.Entry; + +/** + * {@link RecvByteBufAllocator} that limits a read operation based upon a maximum value per individual read + * and a maximum amount when a read operation is attempted by the event loop. + */ +public interface MaxBytesRecvByteBufAllocator extends RecvByteBufAllocator { + /** + * Returns the maximum number of bytes to read per read loop. + * a {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object) channelRead()} event. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure bytes. + */ + int maxBytesPerRead(); + + /** + * Sets the maximum number of bytes to read per read loop. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure bytes. + */ + MaxBytesRecvByteBufAllocator maxBytesPerRead(int maxBytesPerRead); + + /** + * Returns the maximum number of bytes to read per individual read operation. + * a {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object) channelRead()} event. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure bytes. + */ + int maxBytesPerIndividualRead(); + + /** + * Sets the maximum number of bytes to read per individual read operation. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure bytes. 
+ */ + MaxBytesRecvByteBufAllocator maxBytesPerIndividualRead(int maxBytesPerIndividualRead); + + /** + * Atomic way to get the maximum number of bytes to read for a read loop and per individual read operation. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure bytes. + * @return The Key is from {@link #maxBytesPerRead()}. The Value is from {@link #maxBytesPerIndividualRead()} + */ + Entry maxBytesPerReadPair(); + + /** + * Sets the maximum number of bytes to read for a read loop and per individual read operation. + * If this value is greater than 1, an event loop might attempt to read multiple times to procure bytes. + * @param maxBytesPerRead see {@link #maxBytesPerRead(int)} + * @param maxBytesPerIndividualRead see {@link #maxBytesPerIndividualRead(int)} + */ + MaxBytesRecvByteBufAllocator maxBytesPerReadPair(int maxBytesPerRead, int maxBytesPerIndividualRead); +} diff --git a/netty-channel/src/main/java/io/netty/channel/MaxMessagesRecvByteBufAllocator.java b/netty-channel/src/main/java/io/netty/channel/MaxMessagesRecvByteBufAllocator.java new file mode 100644 index 0000000..3c40bbd --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/MaxMessagesRecvByteBufAllocator.java @@ -0,0 +1,35 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * {@link RecvByteBufAllocator} that limits the number of read operations that will be attempted when a read operation
 * is attempted by the event loop.
 */
public interface MaxMessagesRecvByteBufAllocator extends RecvByteBufAllocator {
    /**
     * Returns the maximum number of messages to read per read loop, i.e. per
     * {@link ChannelInboundHandler#channelRead(ChannelHandlerContext, Object) channelRead()} event.
     * If this value is greater than 1, an event loop might attempt to read multiple times to procure multiple messages.
     */
    int maxMessagesPerRead();

    /**
     * Sets the maximum number of messages to read per read loop.
     * If this value is greater than 1, an event loop might attempt to read multiple times to procure multiple messages.
     */
    MaxMessagesRecvByteBufAllocator maxMessagesPerRead(int maxMessagesPerRead);
}
/**
 * Responsible to estimate the size of a message. The size represents approximately how much memory the message will
 * reserve in memory.
 */
public interface MessageSizeEstimator {

    /**
     * Creates a new handle. The handle provides the actual operations.
     */
    Handle newHandle();

    interface Handle {

        /**
         * Calculate the size of the given message.
         *
         * @param msg The message for which the size should be calculated
         * @return size The size in bytes. The returned size must be &gt;= 0
         */
        int size(Object msg);
    }
}
/**
 * Abstract base class for {@link EventLoopGroup} implementations that handles their tasks with multiple threads at
 * the same time.
 */
public abstract class MultithreadEventLoopGroup extends MultithreadEventExecutorGroup implements EventLoopGroup {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(MultithreadEventLoopGroup.class);

    // Default thread count: 2 * available processors, overridable via -Dio.netty.eventLoopThreads
    // (never less than 1).
    private static final int DEFAULT_EVENT_LOOP_THREADS;

    static {
        DEFAULT_EVENT_LOOP_THREADS = Math.max(1, SystemPropertyUtil.getInt(
                "io.netty.eventLoopThreads", NettyRuntime.availableProcessors() * 2));

        if (logger.isDebugEnabled()) {
            logger.debug("-Dio.netty.eventLoopThreads: {}", DEFAULT_EVENT_LOOP_THREADS);
        }
    }

    /**
     * @see MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, Executor, Object...)
     */
    protected MultithreadEventLoopGroup(int nThreads, Executor executor, Object... args) {
        // nThreads == 0 selects the default thread count computed above.
        super(nThreads == 0 ? DEFAULT_EVENT_LOOP_THREADS : nThreads, executor, args);
    }

    /**
     * @see MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, ThreadFactory, Object...)
     */
    protected MultithreadEventLoopGroup(int nThreads, ThreadFactory threadFactory, Object... args) {
        super(nThreads == 0 ? DEFAULT_EVENT_LOOP_THREADS : nThreads, threadFactory, args);
    }

    /**
     * @see MultithreadEventExecutorGroup#MultithreadEventExecutorGroup(int, Executor,
     * EventExecutorChooserFactory, Object...)
     */
    protected MultithreadEventLoopGroup(int nThreads, Executor executor, EventExecutorChooserFactory chooserFactory,
                                        Object... args) {
        super(nThreads == 0 ? DEFAULT_EVENT_LOOP_THREADS : nThreads, executor, chooserFactory, args);
    }

    @Override
    protected ThreadFactory newDefaultThreadFactory() {
        // Event loop threads are created with maximum priority.
        return new DefaultThreadFactory(getClass(), Thread.MAX_PRIORITY);
    }

    @Override
    public EventLoop next() {
        return (EventLoop) super.next();
    }

    @Override
    protected abstract EventLoop newChild(Executor executor, Object... args) throws Exception;

    /**
     * Registers the {@link Channel} with the next {@link EventLoop} of this group.
     */
    @Override
    public ChannelFuture register(Channel channel) {
        return next().register(channel);
    }

    @Override
    public ChannelFuture register(ChannelPromise promise) {
        return next().register(promise);
    }

    @Deprecated
    @Override
    public ChannelFuture register(Channel channel, ChannelPromise promise) {
        return next().register(channel, promise);
    }

}
+ */ +package io.netty.channel; + +import io.netty.util.internal.ObjectUtil; + +abstract class PendingBytesTracker implements MessageSizeEstimator.Handle { + private final MessageSizeEstimator.Handle estimatorHandle; + + private PendingBytesTracker(MessageSizeEstimator.Handle estimatorHandle) { + this.estimatorHandle = ObjectUtil.checkNotNull(estimatorHandle, "estimatorHandle"); + } + + @Override + public final int size(Object msg) { + return estimatorHandle.size(msg); + } + + public abstract void incrementPendingOutboundBytes(long bytes); + public abstract void decrementPendingOutboundBytes(long bytes); + + static PendingBytesTracker newTracker(Channel channel) { + if (channel.pipeline() instanceof DefaultChannelPipeline) { + return new DefaultChannelPipelinePendingBytesTracker((DefaultChannelPipeline) channel.pipeline()); + } else { + ChannelOutboundBuffer buffer = channel.unsafe().outboundBuffer(); + MessageSizeEstimator.Handle handle = channel.config().getMessageSizeEstimator().newHandle(); + // We need to guard against null as channel.unsafe().outboundBuffer() may returned null + // if the channel was already closed when constructing the PendingBytesTracker. + // See https://github.com/netty/netty/issues/3967 + return buffer == null ? 
+ new NoopPendingBytesTracker(handle) : new ChannelOutboundBufferPendingBytesTracker(buffer, handle); + } + } + + private static final class DefaultChannelPipelinePendingBytesTracker extends PendingBytesTracker { + private final DefaultChannelPipeline pipeline; + + DefaultChannelPipelinePendingBytesTracker(DefaultChannelPipeline pipeline) { + super(pipeline.estimatorHandle()); + this.pipeline = pipeline; + } + + @Override + public void incrementPendingOutboundBytes(long bytes) { + pipeline.incrementPendingOutboundBytes(bytes); + } + + @Override + public void decrementPendingOutboundBytes(long bytes) { + pipeline.decrementPendingOutboundBytes(bytes); + } + } + + private static final class ChannelOutboundBufferPendingBytesTracker extends PendingBytesTracker { + private final ChannelOutboundBuffer buffer; + + ChannelOutboundBufferPendingBytesTracker( + ChannelOutboundBuffer buffer, MessageSizeEstimator.Handle estimatorHandle) { + super(estimatorHandle); + this.buffer = buffer; + } + + @Override + public void incrementPendingOutboundBytes(long bytes) { + buffer.incrementPendingOutboundBytes(bytes); + } + + @Override + public void decrementPendingOutboundBytes(long bytes) { + buffer.decrementPendingOutboundBytes(bytes); + } + } + + private static final class NoopPendingBytesTracker extends PendingBytesTracker { + + NoopPendingBytesTracker(MessageSizeEstimator.Handle estimatorHandle) { + super(estimatorHandle); + } + + @Override + public void incrementPendingOutboundBytes(long bytes) { + // Noop + } + + @Override + public void decrementPendingOutboundBytes(long bytes) { + // Noop + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/PendingWriteQueue.java b/netty-channel/src/main/java/io/netty/channel/PendingWriteQueue.java new file mode 100644 index 0000000..20d0e0d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/PendingWriteQueue.java @@ -0,0 +1,330 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file 
to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.PromiseCombiner; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * A queue of write operations which are pending for later execution. It also updates the + * {@linkplain Channel#isWritable() writability} of the associated {@link Channel}, so that + * the pending write operations are also considered to determine the writability. + */ +public final class PendingWriteQueue { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PendingWriteQueue.class); + // Assuming a 64-bit JVM: + // - 16 bytes object header + // - 4 reference fields + // - 1 long fields + private static final int PENDING_WRITE_OVERHEAD = + SystemPropertyUtil.getInt("io.netty.transport.pendingWriteSizeOverhead", 64); + + private final ChannelOutboundInvoker invoker; + private final EventExecutor executor; + private final PendingBytesTracker tracker; + + // head and tail pointers for the linked-list structure. If empty head and tail are null. 
+ private PendingWrite head; + private PendingWrite tail; + private int size; + private long bytes; + + public PendingWriteQueue(ChannelHandlerContext ctx) { + tracker = PendingBytesTracker.newTracker(ctx.channel()); + this.invoker = ctx; + this.executor = ctx.executor(); + } + + public PendingWriteQueue(Channel channel) { + tracker = PendingBytesTracker.newTracker(channel); + this.invoker = channel; + this.executor = channel.eventLoop(); + } + + /** + * Returns {@code true} if there are no pending write operations left in this queue. + */ + public boolean isEmpty() { + assert executor.inEventLoop(); + return head == null; + } + + /** + * Returns the number of pending write operations. + */ + public int size() { + assert executor.inEventLoop(); + return size; + } + + /** + * Returns the total number of bytes that are pending because of pending messages. This is only an estimate so + * it should only be treated as a hint. + */ + public long bytes() { + assert executor.inEventLoop(); + return bytes; + } + + private int size(Object msg) { + // It is possible for writes to be triggered from removeAndFailAll(). To preserve ordering, + // we should add them to the queue and let removeAndFailAll() fail them later. + int messageSize = tracker.size(msg); + if (messageSize < 0) { + // Size may be unknown so just use 0 + messageSize = 0; + } + return messageSize + PENDING_WRITE_OVERHEAD; + } + + /** + * Add the given {@code msg} and {@link ChannelPromise}. + */ + public void add(Object msg, ChannelPromise promise) { + assert executor.inEventLoop(); + ObjectUtil.checkNotNull(msg, "msg"); + ObjectUtil.checkNotNull(promise, "promise"); + // It is possible for writes to be triggered from removeAndFailAll(). To preserve ordering, + // we should add them to the queue and let removeAndFailAll() fail them later. 
+ int messageSize = size(msg); + + PendingWrite write = PendingWrite.newInstance(msg, messageSize, promise); + PendingWrite currentTail = tail; + if (currentTail == null) { + tail = head = write; + } else { + currentTail.next = write; + tail = write; + } + size ++; + bytes += messageSize; + tracker.incrementPendingOutboundBytes(write.size); + } + + /** + * Remove all pending write operation and performs them via + * {@link ChannelHandlerContext#write(Object, ChannelPromise)}. + * + * @return {@link ChannelFuture} if something was written and {@code null} + * if the {@link PendingWriteQueue} is empty. + */ + public ChannelFuture removeAndWriteAll() { + assert executor.inEventLoop(); + + if (isEmpty()) { + return null; + } + + ChannelPromise p = invoker.newPromise(); + PromiseCombiner combiner = new PromiseCombiner(executor); + try { + // It is possible for some of the written promises to trigger more writes. The new writes + // will "revive" the queue, so we need to write them up until the queue is empty. + for (PendingWrite write = head; write != null; write = head) { + head = tail = null; + size = 0; + bytes = 0; + + while (write != null) { + PendingWrite next = write.next; + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write, false); + if (!(promise instanceof VoidChannelPromise)) { + combiner.add(promise); + } + invoker.write(msg, promise); + write = next; + } + } + combiner.finish(p); + } catch (Throwable cause) { + p.setFailure(cause); + } + assertEmpty(); + return p; + } + + /** + * Remove all pending write operation and fail them with the given {@link Throwable}. The message will be released + * via {@link ReferenceCountUtil#safeRelease(Object)}. + */ + public void removeAndFailAll(Throwable cause) { + assert executor.inEventLoop(); + ObjectUtil.checkNotNull(cause, "cause"); + // It is possible for some of the failed promises to trigger more writes. 
The new writes + // will "revive" the queue, so we need to clean them up until the queue is empty. + for (PendingWrite write = head; write != null; write = head) { + head = tail = null; + size = 0; + bytes = 0; + while (write != null) { + PendingWrite next = write.next; + ReferenceCountUtil.safeRelease(write.msg); + ChannelPromise promise = write.promise; + recycle(write, false); + safeFail(promise, cause); + write = next; + } + } + assertEmpty(); + } + + /** + * Remove a pending write operation and fail it with the given {@link Throwable}. The message will be released via + * {@link ReferenceCountUtil#safeRelease(Object)}. + */ + public void removeAndFail(Throwable cause) { + assert executor.inEventLoop(); + ObjectUtil.checkNotNull(cause, "cause"); + + PendingWrite write = head; + if (write == null) { + return; + } + ReferenceCountUtil.safeRelease(write.msg); + ChannelPromise promise = write.promise; + safeFail(promise, cause); + recycle(write, true); + } + + private void assertEmpty() { + assert tail == null && head == null && size == 0; + } + + /** + * Removes a pending write operation and performs it via + * {@link ChannelHandlerContext#write(Object, ChannelPromise)}. + * + * @return {@link ChannelFuture} if something was written and {@code null} + * if the {@link PendingWriteQueue} is empty. + */ + public ChannelFuture removeAndWrite() { + assert executor.inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + Object msg = write.msg; + ChannelPromise promise = write.promise; + recycle(write, true); + return invoker.write(msg, promise); + } + + /** + * Removes a pending write operation and release it's message via {@link ReferenceCountUtil#safeRelease(Object)}. + * + * @return {@link ChannelPromise} of the pending write or {@code null} if the queue is empty. 
+ * + */ + public ChannelPromise remove() { + assert executor.inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + ChannelPromise promise = write.promise; + ReferenceCountUtil.safeRelease(write.msg); + recycle(write, true); + return promise; + } + + /** + * Return the current message or {@code null} if empty. + */ + public Object current() { + assert executor.inEventLoop(); + PendingWrite write = head; + if (write == null) { + return null; + } + return write.msg; + } + + private void recycle(PendingWrite write, boolean update) { + final PendingWrite next = write.next; + final long writeSize = write.size; + + if (update) { + if (next == null) { + // Handled last PendingWrite so rest head and tail + // Guard against re-entrance by directly reset + head = tail = null; + size = 0; + bytes = 0; + } else { + head = next; + size --; + bytes -= writeSize; + assert size > 0 && bytes >= 0; + } + } + + write.recycle(); + tracker.decrementPendingOutboundBytes(writeSize); + } + + private static void safeFail(ChannelPromise promise, Throwable cause) { + if (!(promise instanceof VoidChannelPromise) && !promise.tryFailure(cause)) { + logger.warn("Failed to mark a promise as failure because it's done already: {}", promise, cause); + } + } + + /** + * Holds all meta-data and construct the linked-list structure. 
+ */ + static final class PendingWrite { + private static final ObjectPool RECYCLER = ObjectPool.newPool(new ObjectCreator() { + @Override + public PendingWrite newObject(ObjectPool.Handle handle) { + return new PendingWrite(handle); + } + }); + + private final ObjectPool.Handle handle; + private PendingWrite next; + private long size; + private ChannelPromise promise; + private Object msg; + + private PendingWrite(ObjectPool.Handle handle) { + this.handle = handle; + } + + static PendingWrite newInstance(Object msg, int size, ChannelPromise promise) { + PendingWrite write = RECYCLER.get(); + write.size = size; + write.msg = msg; + write.promise = promise; + return write; + } + + private void recycle() { + size = 0; + next = null; + msg = null; + promise = null; + handle.recycle(this); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/PreferHeapByteBufAllocator.java b/netty-channel/src/main/java/io/netty/channel/PreferHeapByteBufAllocator.java new file mode 100644 index 0000000..f835b1a --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/PreferHeapByteBufAllocator.java @@ -0,0 +1,135 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.channel;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.UnstableApi;

/**
 * Wraps another {@link ByteBufAllocator} and uses heap buffers everywhere except when a direct buffer is
 * explicitly requested: {@code buffer(...)}, {@code ioBuffer(...)} and {@code compositeBuffer(...)} are all
 * redirected to the heap variants, while the {@code direct*} methods still delegate unchanged.
 */
@UnstableApi
public final class PreferHeapByteBufAllocator implements ByteBufAllocator {
    private final ByteBufAllocator allocator;

    public PreferHeapByteBufAllocator(ByteBufAllocator allocator) {
        this.allocator = ObjectUtil.checkNotNull(allocator, "allocator");
    }

    @Override
    public ByteBuf buffer() {
        return allocator.heapBuffer();
    }

    @Override
    public ByteBuf buffer(int initialCapacity) {
        return allocator.heapBuffer(initialCapacity);
    }

    @Override
    public ByteBuf buffer(int initialCapacity, int maxCapacity) {
        return allocator.heapBuffer(initialCapacity, maxCapacity);
    }

    // ioBuffer normally prefers direct memory; here it is redirected to heap buffers as well.
    @Override
    public ByteBuf ioBuffer() {
        return allocator.heapBuffer();
    }

    @Override
    public ByteBuf ioBuffer(int initialCapacity) {
        return allocator.heapBuffer(initialCapacity);
    }

    @Override
    public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) {
        return allocator.heapBuffer(initialCapacity, maxCapacity);
    }

    @Override
    public ByteBuf heapBuffer() {
        return allocator.heapBuffer();
    }

    @Override
    public ByteBuf heapBuffer(int initialCapacity) {
        return allocator.heapBuffer(initialCapacity);
    }

    @Override
    public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) {
        return allocator.heapBuffer(initialCapacity, maxCapacity);
    }

    // Explicit direct-buffer requests are honored and passed straight through.
    @Override
    public ByteBuf directBuffer() {
        return allocator.directBuffer();
    }

    @Override
    public ByteBuf directBuffer(int initialCapacity) {
        return allocator.directBuffer(initialCapacity);
    }

    @Override
    public ByteBuf directBuffer(int initialCapacity, int maxCapacity) {
        return allocator.directBuffer(initialCapacity, maxCapacity);
    }

    @Override
    public CompositeByteBuf compositeBuffer() {
        return allocator.compositeHeapBuffer();
    }

    @Override
    public CompositeByteBuf compositeBuffer(int maxNumComponents) {
        return allocator.compositeHeapBuffer(maxNumComponents);
    }

    @Override
    public CompositeByteBuf compositeHeapBuffer() {
        return allocator.compositeHeapBuffer();
    }

    @Override
    public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) {
        return allocator.compositeHeapBuffer(maxNumComponents);
    }

    @Override
    public CompositeByteBuf compositeDirectBuffer() {
        return allocator.compositeDirectBuffer();
    }

    @Override
    public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) {
        return allocator.compositeDirectBuffer(maxNumComponents);
    }

    @Override
    public boolean isDirectBufferPooled() {
        return allocator.isDirectBufferPooled();
    }

    @Override
    public int calculateNewCapacity(int minNewCapacity, int maxCapacity) {
        return allocator.calculateNewCapacity(minNewCapacity, maxCapacity);
    }
}
package io.netty.channel;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.UncheckedBooleanSupplier;
import io.netty.util.internal.UnstableApi;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * Allocates a new receive buffer whose capacity is probably large enough to read all inbound data and small enough
 * not to waste its space.
 */
public interface RecvByteBufAllocator {
    /**
     * Creates a new handle. The handle provides the actual operations and keeps the internal information which is
     * required for predicting an optimal buffer capacity.
     */
    Handle newHandle();

    /**
     * @deprecated Use {@link ExtendedHandle}.
     */
    @Deprecated
    interface Handle {
        /**
         * Creates a new receive buffer whose capacity is probably large enough to read all inbound data and small
         * enough not to waste its space.
         */
        ByteBuf allocate(ByteBufAllocator alloc);

        /**
         * Similar to {@link #allocate(ByteBufAllocator)} except that it does not allocate anything but just tells the
         * capacity.
         */
        int guess();

        /**
         * Reset any counters that have accumulated and recommend how many messages/bytes should be read for the next
         * read loop.
         * <p>
         * This may be used by {@link #continueReading()} to determine if the read operation should complete.
         * </p>
         * This is only ever a hint and may be ignored by the implementation.
         * @param config The channel configuration which may impact this object's behavior.
         */
        void reset(ChannelConfig config);

        /**
         * Increment the number of messages that have been read for the current read loop.
         * @param numMessages The amount to increment by.
         */
        void incMessagesRead(int numMessages);

        /**
         * Set the bytes that have been read for the last read operation.
         * This may be used to increment the number of bytes that have been read.
         * @param bytes The number of bytes from the previous read operation. This may be negative if an read error
         * occurs. If a negative value is seen it is expected to be return on the next call to
         * {@link #lastBytesRead()}. A negative value will signal a termination condition enforced externally
         * to this class and is not required to be enforced in {@link #continueReading()}.
         */
        void lastBytesRead(int bytes);

        /**
         * Get the amount of bytes for the previous read operation.
         * @return The amount of bytes for the previous read operation.
         */
        int lastBytesRead();

        /**
         * Set how many bytes the read operation will (or did) attempt to read.
         * @param bytes How many bytes the read operation will (or did) attempt to read.
         */
        void attemptedBytesRead(int bytes);

        /**
         * Get how many bytes the read operation will (or did) attempt to read.
         * @return How many bytes the read operation will (or did) attempt to read.
         */
        int attemptedBytesRead();

        /**
         * Determine if the current read loop should continue.
         * @return {@code true} if the read loop should continue reading. {@code false} if the read loop is complete.
         */
        boolean continueReading();

        /**
         * The read has completed.
         */
        void readComplete();
    }

    @SuppressWarnings("deprecation")
    @UnstableApi
    interface ExtendedHandle extends Handle {
        /**
         * Same as {@link Handle#continueReading()} except "more data" is determined by the supplier parameter.
         * @param maybeMoreDataSupplier A supplier that determines if there maybe more data to read.
         */
        boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier);
    }

    /**
     * A {@link Handle} which delegates all call to some other {@link Handle}.
     */
    class DelegatingHandle implements Handle {
        private final Handle delegate;

        public DelegatingHandle(Handle delegate) {
            this.delegate = checkNotNull(delegate, "delegate");
        }

        /**
         * Get the {@link Handle} which all methods will be delegated to.
         * @return the {@link Handle} which all methods will be delegated to.
         */
        protected final Handle delegate() {
            return delegate;
        }

        @Override
        public ByteBuf allocate(ByteBufAllocator alloc) {
            return delegate.allocate(alloc);
        }

        @Override
        public int guess() {
            return delegate.guess();
        }

        @Override
        public void reset(ChannelConfig config) {
            delegate.reset(config);
        }

        @Override
        public void incMessagesRead(int numMessages) {
            delegate.incMessagesRead(numMessages);
        }

        @Override
        public void lastBytesRead(int bytes) {
            delegate.lastBytesRead(bytes);
        }

        @Override
        public int lastBytesRead() {
            return delegate.lastBytesRead();
        }

        @Override
        public boolean continueReading() {
            return delegate.continueReading();
        }

        @Override
        public int attemptedBytesRead() {
            return delegate.attemptedBytesRead();
        }

        @Override
        public void attemptedBytesRead(int bytes) {
            delegate.attemptedBytesRead(bytes);
        }

        @Override
        public void readComplete() {
            delegate.readComplete();
        }
    }
}
package io.netty.channel;

import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.StringUtil;

import java.lang.reflect.Constructor;

/**
 * A {@link ChannelFactory} that instantiates a new {@link Channel} by invoking its default constructor reflectively.
 *
 * @param <T> the concrete {@link Channel} type produced by this factory
 */
public class ReflectiveChannelFactory<T extends Channel> implements ChannelFactory<T> {

    // Cached public no-arg constructor; resolving it once up front surfaces misconfiguration early.
    private final Constructor<? extends T> constructor;

    /**
     * @param clazz the channel class; must expose a public no-arg constructor
     * @throws IllegalArgumentException if {@code clazz} has no public no-arg constructor
     */
    public ReflectiveChannelFactory(Class<? extends T> clazz) {
        ObjectUtil.checkNotNull(clazz, "clazz");
        try {
            this.constructor = clazz.getConstructor();
        } catch (NoSuchMethodException e) {
            throw new IllegalArgumentException("Class " + StringUtil.simpleClassName(clazz) +
                    " does not have a public non-arg constructor", e);
        }
    }

    /**
     * Instantiates a new channel via the cached constructor.
     *
     * @throws ChannelException if instantiation fails for any reason
     */
    @Override
    public T newChannel() {
        try {
            return constructor.newInstance();
        } catch (Throwable t) {
            throw new ChannelException("Unable to create Channel from class " + constructor.getDeclaringClass(), t);
        }
    }

    @Override
    public String toString() {
        return StringUtil.simpleClassName(ReflectiveChannelFactory.class) +
                '(' + StringUtil.simpleClassName(constructor.getDeclaringClass()) + ".class)";
    }
}
package io.netty.channel;

import io.netty.util.IntSupplier;

/**
 * Select strategy interface.
 *
 * Provides the ability to control the behavior of the select loop. For example a blocking select
 * operation can be delayed or skipped entirely if there are events to process immediately.
 */
public interface SelectStrategy {

    /**
     * Indicates a blocking select should follow.
     */
    int SELECT = -1;
    /**
     * Indicates the IO loop should be retried, no blocking select to follow directly.
     */
    int CONTINUE = -2;
    /**
     * Indicates the IO loop to poll for new events without blocking.
     */
    int BUSY_WAIT = -3;

    /**
     * The {@link SelectStrategy} can be used to steer the outcome of a potential select
     * call.
     *
     * @param selectSupplier The supplier with the result of a select result.
     * @param hasTasks true if tasks are waiting to be processed.
     * @return {@link #SELECT} if the next step should be blocking select {@link #CONTINUE} if
     *         the next step should be to not select but rather jump back to the IO loop and try
     *         again. Any value >= 0 is treated as an indicator that work needs to be done.
     */
    int calculateStrategy(IntSupplier selectSupplier, boolean hasTasks) throws Exception;
}
package io.netty.channel;

/**
 * Factory that creates a new {@link SelectStrategy} every time.
 */
public interface SelectStrategyFactory {

    /**
     * Creates a new {@link SelectStrategy}.
     */
    SelectStrategy newSelectStrategy();
}
+} diff --git a/netty-channel/src/main/java/io/netty/channel/ServerChannelRecvByteBufAllocator.java b/netty-channel/src/main/java/io/netty/channel/ServerChannelRecvByteBufAllocator.java new file mode 100644 index 0000000..e02242b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ServerChannelRecvByteBufAllocator.java @@ -0,0 +1,35 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +/** + * {@link MaxMessagesRecvByteBufAllocator} implementation which should be used for {@link ServerChannel}s. 
// ---- io/netty/channel/ServerChannelRecvByteBufAllocator.java ----
/*
 * Copyright 2021 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel;

/**
 * {@link MaxMessagesRecvByteBufAllocator} implementation which should be used for {@link ServerChannel}s.
 */
public final class ServerChannelRecvByteBufAllocator extends DefaultMaxMessagesRecvByteBufAllocator {
    public ServerChannelRecvByteBufAllocator() {
        // One message per read; ignoreBytesRead=true because accepting channels
        // produce child channels, not bytes.
        super(1, true);
    }

    @Override
    public Handle newHandle() {
        return new MaxMessageHandle() {
            @Override
            public int guess() {
                // Accepted connections need no real size estimate; a small fixed guess suffices.
                return 128;
            }
        };
    }
}

// ---- io/netty/channel/SimpleChannelInboundHandler.java ----
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel;

import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.TypeParameterMatcher;

/**
 * {@link ChannelInboundHandlerAdapter} which allows you to explicitly handle only a specific
 * type of messages.
 *
 * For example here is an implementation which only handles {@link String} messages.
 *
 * <pre>
 *     public class StringHandler extends
 *             {@link SimpleChannelInboundHandler}&lt;{@link String}&gt; {
 *
 *         {@code @Override}
 *         protected void channelRead0({@link ChannelHandlerContext} ctx, {@link String} message)
 *                 throws {@link Exception} {
 *             System.out.println(message);
 *         }
 *     }
 * </pre>
 *
 * Be aware that depending on the constructor parameters it will release all handled messages by passing them to
 * {@link ReferenceCountUtil#release(Object)}. In this case you may need to use
 * {@link ReferenceCountUtil#retain(Object)} if you pass the object to the next handler in the {@link ChannelPipeline}.
 *
 * @param <I> the type of messages this handler matches and passes to {@code channelRead0}
 */
public abstract class SimpleChannelInboundHandler<I> extends ChannelInboundHandlerAdapter {

    // Runtime matcher for the reified type parameter I.
    private final TypeParameterMatcher matcher;
    // Whether handled (matched) messages are released after channelRead0 returns.
    private final boolean autoRelease;

    /**
     * see {@link #SimpleChannelInboundHandler(boolean)} with {@code true} as boolean parameter.
     */
    protected SimpleChannelInboundHandler() {
        this(true);
    }

    /**
     * Create a new instance which will try to detect the types to match out of the type parameter of the class.
     *
     * @param autoRelease {@code true} if handled messages should be released automatically by passing them to
     *                    {@link ReferenceCountUtil#release(Object)}.
     */
    protected SimpleChannelInboundHandler(boolean autoRelease) {
        matcher = TypeParameterMatcher.find(this, SimpleChannelInboundHandler.class, "I");
        this.autoRelease = autoRelease;
    }

    /**
     * see {@link #SimpleChannelInboundHandler(Class, boolean)} with {@code true} as boolean value.
     */
    protected SimpleChannelInboundHandler(Class<? extends I> inboundMessageType) {
        this(inboundMessageType, true);
    }

    /**
     * Create a new instance
     *
     * @param inboundMessageType The type of messages to match
     * @param autoRelease        {@code true} if handled messages should be released automatically by passing them to
     *                           {@link ReferenceCountUtil#release(Object)}.
     */
    protected SimpleChannelInboundHandler(Class<? extends I> inboundMessageType, boolean autoRelease) {
        matcher = TypeParameterMatcher.get(inboundMessageType);
        this.autoRelease = autoRelease;
    }

    /**
     * Returns {@code true} if the given message should be handled. If {@code false} it will be passed to the next
     * {@link ChannelInboundHandler} in the {@link ChannelPipeline}.
     */
    public boolean acceptInboundMessage(Object msg) throws Exception {
        return matcher.match(msg);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        boolean release = true;
        try {
            if (acceptInboundMessage(msg)) {
                @SuppressWarnings("unchecked")
                I imsg = (I) msg;
                channelRead0(ctx, imsg);
            } else {
                // Not our type: forward untouched and do not release — the next handler owns it now.
                release = false;
                ctx.fireChannelRead(msg);
            }
        } finally {
            if (autoRelease && release) {
                ReferenceCountUtil.release(msg);
            }
        }
    }

    /**
     * Is called for each message of type {@link I}.
     *
     * @param ctx the {@link ChannelHandlerContext} which this {@link SimpleChannelInboundHandler}
     *            belongs to
     * @param msg the message to handle
     * @throws Exception is thrown if an error occurred
     */
    protected abstract void channelRead0(ChannelHandlerContext ctx, I msg) throws Exception;
}
+ */ +package io.netty.channel; + +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.TypeParameterMatcher; + +/** + * {@link ChannelInboundHandlerAdapter} which allows to conveniently only handle a specific type of user events. + * + * For example, here is an implementation which only handle {@link String} user events. + * + *
+ *     public class StringEventHandler extends
+ *             {@link SimpleUserEventChannelHandler}<{@link String}> {
+ *
+ *         {@code @Override}
+ *         protected void eventReceived({@link ChannelHandlerContext} ctx, {@link String} evt)
+ *                 throws {@link Exception} {
+ *             System.out.println(evt);
+ *         }
+ *     }
+ * 
+ * + * Be aware that depending of the constructor parameters it will release all handled events by passing them to + * {@link ReferenceCountUtil#release(Object)}. In this case you may need to use + * {@link ReferenceCountUtil#retain(Object)} if you pass the object to the next handler in the {@link ChannelPipeline}. + */ +public abstract class SimpleUserEventChannelHandler extends ChannelInboundHandlerAdapter { + + private final TypeParameterMatcher matcher; + private final boolean autoRelease; + + /** + * see {@link #SimpleUserEventChannelHandler(boolean)} with {@code true} as boolean parameter. + */ + protected SimpleUserEventChannelHandler() { + this(true); + } + + /** + * Create a new instance which will try to detect the types to match out of the type parameter of the class. + * + * @param autoRelease {@code true} if handled events should be released automatically by passing them to + * {@link ReferenceCountUtil#release(Object)}. + */ + protected SimpleUserEventChannelHandler(boolean autoRelease) { + matcher = TypeParameterMatcher.find(this, SimpleUserEventChannelHandler.class, "I"); + this.autoRelease = autoRelease; + } + + /** + * see {@link #SimpleUserEventChannelHandler(Class, boolean)} with {@code true} as boolean value. + */ + protected SimpleUserEventChannelHandler(Class eventType) { + this(eventType, true); + } + + /** + * Create a new instance + * + * @param eventType The type of events to match + * @param autoRelease {@code true} if handled events should be released automatically by passing them to + * {@link ReferenceCountUtil#release(Object)}. + */ + protected SimpleUserEventChannelHandler(Class eventType, boolean autoRelease) { + matcher = TypeParameterMatcher.get(eventType); + this.autoRelease = autoRelease; + } + + /** + * Returns {@code true} if the given user event should be handled. If {@code false} it will be passed to the next + * {@link ChannelInboundHandler} in the {@link ChannelPipeline}. 
+ */ + protected boolean acceptEvent(Object evt) throws Exception { + return matcher.match(evt); + } + + @Override + public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + boolean release = true; + try { + if (acceptEvent(evt)) { + @SuppressWarnings("unchecked") + I ievt = (I) evt; + eventReceived(ctx, ievt); + } else { + release = false; + ctx.fireUserEventTriggered(evt); + } + } finally { + if (autoRelease && release) { + ReferenceCountUtil.release(evt); + } + } + } + + /** + * Is called for each user event triggered of type {@link I}. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link SimpleUserEventChannelHandler} belongs to + * @param evt the user event to handle + * + * @throws Exception is thrown if an error occurred + */ + protected abstract void eventReceived(ChannelHandlerContext ctx, I evt) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/SingleThreadEventLoop.java b/netty-channel/src/main/java/io/netty/channel/SingleThreadEventLoop.java new file mode 100644 index 0000000..f7136c6 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/SingleThreadEventLoop.java @@ -0,0 +1,217 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
// ---- io/netty/channel/SingleThreadEventLoop.java ----
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel;

import io.netty.util.concurrent.RejectedExecutionHandler;
import io.netty.util.concurrent.RejectedExecutionHandlers;
import io.netty.util.concurrent.SingleThreadEventExecutor;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.UnstableApi;

import java.util.Iterator;
import java.util.NoSuchElementException;
import java.util.Queue;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadFactory;

/**
 * Abstract base class for {@link EventLoop}s that execute all their submitted tasks in a single thread.
 */
public abstract class SingleThreadEventLoop extends SingleThreadEventExecutor implements EventLoop {

    protected static final int DEFAULT_MAX_PENDING_TASKS = Math.max(16,
            SystemPropertyUtil.getInt("io.netty.eventLoop.maxPendingTasks", Integer.MAX_VALUE));

    // Tasks run once per event-loop iteration, after the regular task queue is drained.
    private final Queue<Runnable> tailTasks;

    protected SingleThreadEventLoop(EventLoopGroup parent, ThreadFactory threadFactory, boolean addTaskWakesUp) {
        this(parent, threadFactory, addTaskWakesUp, DEFAULT_MAX_PENDING_TASKS, RejectedExecutionHandlers.reject());
    }

    protected SingleThreadEventLoop(EventLoopGroup parent, Executor executor, boolean addTaskWakesUp) {
        this(parent, executor, addTaskWakesUp, DEFAULT_MAX_PENDING_TASKS, RejectedExecutionHandlers.reject());
    }

    protected SingleThreadEventLoop(EventLoopGroup parent, ThreadFactory threadFactory,
                                    boolean addTaskWakesUp, int maxPendingTasks,
                                    RejectedExecutionHandler rejectedExecutionHandler) {
        super(parent, threadFactory, addTaskWakesUp, maxPendingTasks, rejectedExecutionHandler);
        tailTasks = newTaskQueue(maxPendingTasks);
    }

    protected SingleThreadEventLoop(EventLoopGroup parent, Executor executor,
                                    boolean addTaskWakesUp, int maxPendingTasks,
                                    RejectedExecutionHandler rejectedExecutionHandler) {
        super(parent, executor, addTaskWakesUp, maxPendingTasks, rejectedExecutionHandler);
        tailTasks = newTaskQueue(maxPendingTasks);
    }

    protected SingleThreadEventLoop(EventLoopGroup parent, Executor executor,
                                    boolean addTaskWakesUp, Queue<Runnable> taskQueue,
                                    Queue<Runnable> tailTaskQueue,
                                    RejectedExecutionHandler rejectedExecutionHandler) {
        super(parent, executor, addTaskWakesUp, taskQueue, rejectedExecutionHandler);
        tailTasks = ObjectUtil.checkNotNull(tailTaskQueue, "tailTaskQueue");
    }

    @Override
    public EventLoopGroup parent() {
        return (EventLoopGroup) super.parent();
    }

    @Override
    public EventLoop next() {
        return (EventLoop) super.next();
    }

    @Override
    public ChannelFuture register(Channel channel) {
        return register(new DefaultChannelPromise(channel, this));
    }

    @Override
    public ChannelFuture register(final ChannelPromise promise) {
        ObjectUtil.checkNotNull(promise, "promise");
        promise.channel().unsafe().register(this, promise);
        return promise;
    }

    @Deprecated
    @Override
    public ChannelFuture register(final Channel channel, final ChannelPromise promise) {
        ObjectUtil.checkNotNull(promise, "promise");
        ObjectUtil.checkNotNull(channel, "channel");
        channel.unsafe().register(this, promise);
        return promise;
    }

    /**
     * Adds a task to be run once at the end of next (or current) {@code eventloop} iteration.
     *
     * @param task to be added.
     */
    @UnstableApi
    public final void executeAfterEventLoopIteration(Runnable task) {
        ObjectUtil.checkNotNull(task, "task");
        if (isShutdown()) {
            reject();
        }

        if (!tailTasks.offer(task)) {
            reject(task);
        }

        if (wakesUpForTask(task)) {
            wakeup(inEventLoop());
        }
    }

    /**
     * Removes a task that was added previously via {@link #executeAfterEventLoopIteration(Runnable)}.
     *
     * @param task to be removed.
     *
     * @return {@code true} if the task was removed as a result of this call.
     */
    @UnstableApi
    final boolean removeAfterEventLoopIterationTask(Runnable task) {
        return tailTasks.remove(ObjectUtil.checkNotNull(task, "task"));
    }

    @Override
    protected void afterRunningAllTasks() {
        runAllTasksFrom(tailTasks);
    }

    @Override
    protected boolean hasTasks() {
        return super.hasTasks() || !tailTasks.isEmpty();
    }

    @Override
    public int pendingTasks() {
        return super.pendingTasks() + tailTasks.size();
    }

    /**
     * Returns the number of {@link Channel}s registered with this {@link EventLoop} or {@code -1}
     * if operation is not supported. The returned value is not guaranteed to be exact accurate and
     * should be viewed as a best effort.
     */
    @UnstableApi
    public int registeredChannels() {
        return -1;
    }

    /**
     * @return read-only iterator of active {@link Channel}s registered with this {@link EventLoop}.
     *         The returned value is not guaranteed to be exact accurate and
     *         should be viewed as a best effort. This method is expected to be called from within
     *         event loop.
     * @throws UnsupportedOperationException if operation is not supported by implementation.
     */
    @UnstableApi
    public Iterator<Channel> registeredChannelsIterator() {
        throw new UnsupportedOperationException("registeredChannelsIterator");
    }

    /**
     * Read-only {@link Iterator} over {@link Channel}s; {@link #remove()} is unsupported.
     */
    protected static final class ChannelsReadOnlyIterator implements Iterator<Channel> {
        private final Iterator<Channel> channelIterator;

        public ChannelsReadOnlyIterator(Iterable<Channel> channelIterable) {
            this.channelIterator =
                    ObjectUtil.checkNotNull(channelIterable, "channelIterable").iterator();
        }

        @Override
        public boolean hasNext() {
            return channelIterator.hasNext();
        }

        @Override
        public Channel next() {
            return channelIterator.next();
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("remove");
        }

        @SuppressWarnings("unchecked")
        public static <T> Iterator<T> empty() {
            // Safe: EMPTY never produces an element, so the cast cannot surface a wrong type.
            return (Iterator<T>) EMPTY;
        }

        private static final Iterator<Object> EMPTY = new Iterator<Object>() {
            @Override
            public boolean hasNext() {
                return false;
            }

            @Override
            public Object next() {
                throw new NoSuchElementException();
            }

            @Override
            public void remove() {
                throw new UnsupportedOperationException("remove");
            }
        };
    }
}
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.internal.ThrowableUtil; + +import java.nio.channels.ClosedChannelException; + +/** + * Cheap {@link ClosedChannelException} that does not fill in the stacktrace. + */ +final class StacklessClosedChannelException extends ClosedChannelException { + + private static final long serialVersionUID = -2214806025529435136L; + + private StacklessClosedChannelException() { } + + @Override + public Throwable fillInStackTrace() { + // Suppress a warning since this method doesn't need synchronization + return this; + } + + /** + * Creates a new {@link StacklessClosedChannelException} which has the origin of the given {@link Class} and method. + */ + static StacklessClosedChannelException newInstance(Class clazz, String method) { + return ThrowableUtil.unknownStackTrace(new StacklessClosedChannelException(), clazz, method); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/SucceededChannelFuture.java b/netty-channel/src/main/java/io/netty/channel/SucceededChannelFuture.java new file mode 100644 index 0000000..36f9e6c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/SucceededChannelFuture.java @@ -0,0 +1,45 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.util.concurrent.EventExecutor; + +/** + * The {@link CompleteChannelFuture} which is succeeded already. It is + * recommended to use {@link Channel#newSucceededFuture()} instead of + * calling the constructor of this future. + */ +final class SucceededChannelFuture extends CompleteChannelFuture { + + /** + * Creates a new instance. + * + * @param channel the {@link Channel} associated with this future + */ + SucceededChannelFuture(Channel channel, EventExecutor executor) { + super(channel, executor); + } + + @Override + public Throwable cause() { + return null; + } + + @Override + public boolean isSuccess() { + return true; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java b/netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java new file mode 100644 index 0000000..b92ac6b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoop.java @@ -0,0 +1,103 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +/** + * {@link SingleThreadEventLoop} which is used to handle OIO {@link Channel}'s. So in general there will be + * one {@link ThreadPerChannelEventLoop} per {@link Channel}. + * + * @deprecated this will be remove in the next-major release. 
+ */ +@Deprecated +public class ThreadPerChannelEventLoop extends SingleThreadEventLoop { + + private final ThreadPerChannelEventLoopGroup parent; + private Channel ch; + + public ThreadPerChannelEventLoop(ThreadPerChannelEventLoopGroup parent) { + super(parent, parent.executor, true); + this.parent = parent; + } + + @Override + public ChannelFuture register(ChannelPromise promise) { + return super.register(promise).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + ch = future.channel(); + } else { + deregister(); + } + } + }); + } + + @Deprecated + @Override + public ChannelFuture register(Channel channel, ChannelPromise promise) { + return super.register(channel, promise).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + ch = future.channel(); + } else { + deregister(); + } + } + }); + } + + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + + Channel ch = this.ch; + if (isShuttingDown()) { + if (ch != null) { + ch.unsafe().close(ch.unsafe().voidPromise()); + } + if (confirmShutdown()) { + break; + } + } else { + if (ch != null) { + // Handle deregistration + if (!ch.isRegistered()) { + runAllTasks(); + deregister(); + } + } + } + } + } + + protected void deregister() { + ch = null; + parent.activeChildren.remove(this); + parent.idleChildren.add(this); + } + + @Override + public int registeredChannels() { + return 1; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java b/netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java new file mode 100644 index 0000000..9d9b4cb --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/ThreadPerChannelEventLoopGroup.java @@ -0,0 +1,321 @@ 
// ---- io/netty/channel/ThreadPerChannelEventLoopGroup.java ----
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel;

import io.netty.util.concurrent.AbstractEventExecutorGroup;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.DefaultThreadFactory;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.concurrent.ThreadPerTaskExecutor;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.ReadOnlyIterator;

import java.util.Collections;
import java.util.Iterator;
import java.util.Queue;
import java.util.Set;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.Executor;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;

/**
 * An {@link EventLoopGroup} that creates one {@link EventLoop} per {@link Channel}.
 *
 * @deprecated this will be remove in the next-major release.
 */
@Deprecated
public class ThreadPerChannelEventLoopGroup extends AbstractEventExecutorGroup implements EventLoopGroup {

    private final Object[] childArgs;
    private final int maxChannels;
    final Executor executor;
    // Loops currently bound to a channel.
    final Set<EventLoop> activeChildren =
            Collections.newSetFromMap(PlatformDependent.<EventLoop, Boolean>newConcurrentHashMap());
    // Loops waiting to be reused by the next registration.
    final Queue<EventLoop> idleChildren = new ConcurrentLinkedQueue<EventLoop>();
    // Pre-built, stackless exception thrown when maxChannels is exceeded.
    private final ChannelException tooManyChannels;

    private volatile boolean shuttingDown;
    private final Promise<?> terminationFuture = new DefaultPromise<Void>(GlobalEventExecutor.INSTANCE);
    private final FutureListener<Object> childTerminationListener = new FutureListener<Object>() {
        @Override
        public void operationComplete(Future<Object> future) throws Exception {
            // Inefficient, but works.
            if (isTerminated()) {
                terminationFuture.trySuccess(null);
            }
        }
    };

    /**
     * Create a new {@link ThreadPerChannelEventLoopGroup} with no limit in place.
     */
    protected ThreadPerChannelEventLoopGroup() {
        this(0);
    }

    /**
     * Create a new {@link ThreadPerChannelEventLoopGroup}.
     *
     * @param maxChannels the maximum number of channels to handle with this instance. Once you try to register
     *                    a new {@link Channel} and the maximum is exceed it will throw an
     *                    {@link ChannelException} on the {@link #register(Channel)} and
     *                    {@link #register(ChannelPromise)} method.
     *                    Use {@code 0} to use no limit
     */
    protected ThreadPerChannelEventLoopGroup(int maxChannels) {
        this(maxChannels, (ThreadFactory) null);
    }

    /**
     * Create a new {@link ThreadPerChannelEventLoopGroup}.
     *
     * @param maxChannels   the maximum number of channels to handle with this instance. Once you try to register
     *                      a new {@link Channel} and the maximum is exceed it will throw an
     *                      {@link ChannelException} on the {@link #register(Channel)} and
     *                      {@link #register(ChannelPromise)} method.
     *                      Use {@code 0} to use no limit
     * @param threadFactory the {@link ThreadFactory} used to create new {@link Thread} instances that handle the
     *                      registered {@link Channel}s
     * @param args          arguments which will passed to each {@link #newChild(Object...)} call.
     */
    protected ThreadPerChannelEventLoopGroup(int maxChannels, ThreadFactory threadFactory, Object... args) {
        this(maxChannels, threadFactory == null ? null : new ThreadPerTaskExecutor(threadFactory), args);
    }

    /**
     * Create a new {@link ThreadPerChannelEventLoopGroup}.
     *
     * @param maxChannels the maximum number of channels to handle with this instance. Once you try to register
     *                    a new {@link Channel} and the maximum is exceed it will throw an
     *                    {@link ChannelException} on the {@link #register(Channel)} and
     *                    {@link #register(ChannelPromise)} method.
     *                    Use {@code 0} to use no limit
     * @param executor    the {@link Executor} used to create new {@link Thread} instances that handle the
     *                    registered {@link Channel}s
     * @param args        arguments which will passed to each {@link #newChild(Object...)} call.
     */
    protected ThreadPerChannelEventLoopGroup(int maxChannels, Executor executor, Object... args) {
        ObjectUtil.checkPositiveOrZero(maxChannels, "maxChannels");
        if (executor == null) {
            executor = new ThreadPerTaskExecutor(new DefaultThreadFactory(getClass()));
        }

        if (args == null) {
            childArgs = EmptyArrays.EMPTY_OBJECTS;
        } else {
            // Defensive copy: callers must not be able to mutate our child arguments later.
            childArgs = args.clone();
        }

        this.maxChannels = maxChannels;
        this.executor = executor;

        tooManyChannels =
                ChannelException.newStatic("too many channels (max: " + maxChannels + ')',
                        ThreadPerChannelEventLoopGroup.class, "nextChild()");
    }

    /**
     * Creates a new {@link EventLoop}. The default implementation creates a new {@link ThreadPerChannelEventLoop}.
     */
    protected EventLoop newChild(@SuppressWarnings("UnusedParameters") Object... args) throws Exception {
        return new ThreadPerChannelEventLoop(this);
    }

    @Override
    public Iterator<EventExecutor> iterator() {
        return new ReadOnlyIterator<EventExecutor>(activeChildren.iterator());
    }

    @Override
    public EventLoop next() {
        // Loops are assigned per-channel via register(), not round-robined.
        throw new UnsupportedOperationException();
    }

    @Override
    public Future<?> shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) {
        shuttingDown = true;

        for (EventLoop l: activeChildren) {
            l.shutdownGracefully(quietPeriod, timeout, unit);
        }
        for (EventLoop l: idleChildren) {
            l.shutdownGracefully(quietPeriod, timeout, unit);
        }

        // Notify the future if there was no children.
        if (isTerminated()) {
            terminationFuture.trySuccess(null);
        }

        return terminationFuture();
    }

    @Override
    public Future<?> terminationFuture() {
        return terminationFuture;
    }

    @Override
    @Deprecated
    public void shutdown() {
        shuttingDown = true;

        for (EventLoop l: activeChildren) {
            l.shutdown();
        }
        for (EventLoop l: idleChildren) {
            l.shutdown();
        }

        // Notify the future if there was no children.
        if (isTerminated()) {
            terminationFuture.trySuccess(null);
        }
    }

    @Override
    public boolean isShuttingDown() {
        for (EventLoop l: activeChildren) {
            if (!l.isShuttingDown()) {
                return false;
            }
        }
        for (EventLoop l: idleChildren) {
            if (!l.isShuttingDown()) {
                return false;
            }
        }
        return true;
    }

    @Override
    public boolean isShutdown() {
        for (EventLoop l: activeChildren) {
            if (!l.isShutdown()) {
                return false;
            }
        }
        for (EventLoop l: idleChildren) {
            if (!l.isShutdown()) {
                return false;
            }
        }
        return true;
    }

    @Override
    public boolean isTerminated() {
        for (EventLoop l: activeChildren) {
            if (!l.isTerminated()) {
                return false;
            }
        }
        for (EventLoop l: idleChildren) {
            if (!l.isTerminated()) {
                return false;
            }
        }
        return true;
    }

    @Override
    public boolean awaitTermination(long timeout, TimeUnit unit)
            throws InterruptedException {
        // Share one deadline across all children so the total wait honors the caller's timeout.
        long deadline = System.nanoTime() + unit.toNanos(timeout);
        for (EventLoop l: activeChildren) {
            for (;;) {
                long timeLeft = deadline - System.nanoTime();
                if (timeLeft <= 0) {
                    return isTerminated();
                }
                if (l.awaitTermination(timeLeft, TimeUnit.NANOSECONDS)) {
                    break;
                }
            }
        }
        for (EventLoop l: idleChildren) {
            for (;;) {
                long timeLeft = deadline - System.nanoTime();
                if (timeLeft <= 0) {
                    return isTerminated();
                }
                if (l.awaitTermination(timeLeft, TimeUnit.NANOSECONDS)) {
                    break;
                }
            }
        }
        return isTerminated();
    }

    @Override
    public ChannelFuture register(Channel channel) {
        ObjectUtil.checkNotNull(channel, "channel");
        try {
            EventLoop l = nextChild();
            return l.register(new DefaultChannelPromise(channel, l));
        } catch (Throwable t) {
            return new FailedChannelFuture(channel, GlobalEventExecutor.INSTANCE, t);
        }
    }

    @Override
    public ChannelFuture register(ChannelPromise promise) {
        try {
            return nextChild().register(promise);
        } catch (Throwable t) {
            promise.setFailure(t);
            return promise;
        }
    }

    @Deprecated
    @Override
    public ChannelFuture register(Channel channel, ChannelPromise promise) {
        ObjectUtil.checkNotNull(channel, "channel");
        try {
            return nextChild().register(channel, promise);
        } catch (Throwable t) {
            promise.setFailure(t);
            return promise;
        }
    }

    /**
     * Returns an idle child loop, or creates a new one if none is available.
     *
     * @throws RejectedExecutionException if the group is shutting down
     * @throws ChannelException           if {@code maxChannels} would be exceeded
     */
    private EventLoop nextChild() throws Exception {
        if (shuttingDown) {
            throw new RejectedExecutionException("shutting down");
        }

        EventLoop loop = idleChildren.poll();
        if (loop == null) {
            if (maxChannels > 0 && activeChildren.size() >= maxChannels) {
                throw tooManyChannels;
            }
            loop = newChild(childArgs);
            loop.terminationFuture().addListener(childTerminationListener);
        }
        activeChildren.add(loop);
        return loop;
    }
}
+ */ +package io.netty.channel; + +import io.netty.util.concurrent.AbstractFuture; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.UnstableApi; + +import java.util.concurrent.TimeUnit; + +@UnstableApi +public final class VoidChannelPromise extends AbstractFuture implements ChannelPromise { + + private final Channel channel; + // Will be null if we should not propagate exceptions through the pipeline on failure case. + private final ChannelFutureListener fireExceptionListener; + + /** + * Creates a new instance. + * + * @param channel the {@link Channel} associated with this future + */ + public VoidChannelPromise(final Channel channel, boolean fireException) { + ObjectUtil.checkNotNull(channel, "channel"); + this.channel = channel; + if (fireException) { + fireExceptionListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Throwable cause = future.cause(); + if (cause != null) { + fireException0(cause); + } + } + }; + } else { + fireExceptionListener = null; + } + } + + @Override + public VoidChannelPromise addListener(GenericFutureListener> listener) { + fail(); + return this; + } + + @Override + public VoidChannelPromise addListeners(GenericFutureListener>... listeners) { + fail(); + return this; + } + + @Override + public VoidChannelPromise removeListener(GenericFutureListener> listener) { + // NOOP + return this; + } + + @Override + public VoidChannelPromise removeListeners(GenericFutureListener>... 
listeners) { + // NOOP + return this; + } + + @Override + public VoidChannelPromise await() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + return this; + } + + @Override + public boolean await(long timeout, TimeUnit unit) { + fail(); + return false; + } + + @Override + public boolean await(long timeoutMillis) { + fail(); + return false; + } + + @Override + public VoidChannelPromise awaitUninterruptibly() { + fail(); + return this; + } + + @Override + public boolean awaitUninterruptibly(long timeout, TimeUnit unit) { + fail(); + return false; + } + + @Override + public boolean awaitUninterruptibly(long timeoutMillis) { + fail(); + return false; + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public boolean isDone() { + return false; + } + + @Override + public boolean isSuccess() { + return false; + } + + @Override + public boolean setUncancellable() { + return true; + } + + @Override + public boolean isCancellable() { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + @Override + public Throwable cause() { + return null; + } + + @Override + public VoidChannelPromise sync() { + fail(); + return this; + } + + @Override + public VoidChannelPromise syncUninterruptibly() { + fail(); + return this; + } + + @Override + public VoidChannelPromise setFailure(Throwable cause) { + fireException0(cause); + return this; + } + + @Override + public VoidChannelPromise setSuccess() { + return this; + } + + @Override + public boolean tryFailure(Throwable cause) { + fireException0(cause); + return false; + } + + /** + * {@inheritDoc} + * + * @param mayInterruptIfRunning this value has no effect in this implementation. 
+ */ + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } + + @Override + public boolean trySuccess() { + return false; + } + + private static void fail() { + throw new IllegalStateException("void future"); + } + + @Override + public VoidChannelPromise setSuccess(Void result) { + return this; + } + + @Override + public boolean trySuccess(Void result) { + return false; + } + + @Override + public Void getNow() { + return null; + } + + @Override + public ChannelPromise unvoid() { + ChannelPromise promise = new DefaultChannelPromise(channel); + if (fireExceptionListener != null) { + promise.addListener(fireExceptionListener); + } + return promise; + } + + @Override + public boolean isVoid() { + return true; + } + + private void fireException0(Throwable cause) { + // Only fire the exception if the channel is open and registered + // if not the pipeline is not setup and so it would hit the tail + // of the pipeline. + // See https://github.com/netty/netty/issues/1517 + if (fireExceptionListener != null && channel.isRegistered()) { + channel.pipeline().fireExceptionCaught(cause); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/WriteBufferWaterMark.java b/netty-channel/src/main/java/io/netty/channel/WriteBufferWaterMark.java new file mode 100644 index 0000000..cfb6140 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/WriteBufferWaterMark.java @@ -0,0 +1,96 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/**
 * WriteBufferWaterMark holds the low and high water marks for a channel's write buffer.
 *
 * <p>Once the number of bytes queued in the write buffer exceeds the
 * {@linkplain #high high water mark}, {@link Channel#isWritable()} starts returning
 * {@code false}.
 *
 * <p>After the queued bytes have exceeded the {@linkplain #high high water mark} and
 * subsequently dropped below the {@linkplain #low low water mark},
 * {@link Channel#isWritable()} starts returning {@code true} again.
 */
public final class WriteBufferWaterMark {

    private static final int DEFAULT_LOW_WATER_MARK = 32 * 1024;
    private static final int DEFAULT_HIGH_WATER_MARK = 64 * 1024;

    /** Shared default instance (32 KiB low / 64 KiB high), created without validation. */
    public static final WriteBufferWaterMark DEFAULT =
            new WriteBufferWaterMark(DEFAULT_LOW_WATER_MARK, DEFAULT_HIGH_WATER_MARK, false);

    private final int low;
    private final int high;

    /**
     * Create a new, validated instance.
     *
     * @param low  low water mark for the write buffer; must be {@code >= 0}
     * @param high high water mark for the write buffer; must be {@code >= low}
     */
    public WriteBufferWaterMark(int low, int high) {
        this(low, high, true);
    }

    /**
     * This constructor is needed to keep backward-compatibility; when {@code validate}
     * is {@code false} the arguments are accepted as-is.
     */
    WriteBufferWaterMark(int low, int high, boolean validate) {
        if (validate) {
            // Inlined equivalent of ObjectUtil.checkPositiveOrZero(low, "low"),
            // including its exact message format.
            if (low < 0) {
                throw new IllegalArgumentException("low : " + low + " (expected: >= 0)");
            }
            if (high < low) {
                throw new IllegalArgumentException(
                        "write buffer's high water mark cannot be less than " +
                                " low water mark (" + low + "): " +
                                high);
            }
        }
        this.low = low;
        this.high = high;
    }

    /**
     * Returns the low water mark for the write buffer.
     */
    public int low() {
        return low;
    }

    /**
     * Returns the high water mark for the write buffer.
     */
    public int high() {
        return high;
    }

    @Override
    public String toString() {
        // Same textual form as before; plain concatenation instead of a StringBuilder.
        return "WriteBufferWaterMark(low: " + low + ", high: " + high + ")";
    }

}
+ */ +package io.netty.channel.embedded; + +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.TimeUnit; + +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelId; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.RecyclableArrayList; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * Base class for {@link Channel} implementations that are used in an embedded fashion. 
+ */ +public class EmbeddedChannel extends AbstractChannel { + + private static final SocketAddress LOCAL_ADDRESS = new EmbeddedSocketAddress(); + private static final SocketAddress REMOTE_ADDRESS = new EmbeddedSocketAddress(); + + private static final ChannelHandler[] EMPTY_HANDLERS = new ChannelHandler[0]; + private enum State { OPEN, ACTIVE, CLOSED } + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(EmbeddedChannel.class); + + private static final ChannelMetadata METADATA_NO_DISCONNECT = new ChannelMetadata(false); + private static final ChannelMetadata METADATA_DISCONNECT = new ChannelMetadata(true); + + private final EmbeddedEventLoop loop = new EmbeddedEventLoop(); + private final ChannelFutureListener recordExceptionListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + recordException(future); + } + }; + + private final ChannelMetadata metadata; + private final ChannelConfig config; + + private Queue inboundMessages; + private Queue outboundMessages; + private Throwable lastException; + private State state; + + /** + * Create a new instance with an {@link EmbeddedChannelId} and an empty pipeline. + */ + public EmbeddedChannel() { + this(EMPTY_HANDLERS); + } + + /** + * Create a new instance with the specified ID and an empty pipeline. + * + * @param channelId the {@link ChannelId} that will be used to identify this channel + */ + public EmbeddedChannel(ChannelId channelId) { + this(channelId, EMPTY_HANDLERS); + } + + /** + * Create a new instance with the pipeline initialized with the specified handlers. + * + * @param handlers the {@link ChannelHandler}s which will be add in the {@link ChannelPipeline} + */ + public EmbeddedChannel(ChannelHandler... handlers) { + this(EmbeddedChannelId.INSTANCE, handlers); + } + + /** + * Create a new instance with the pipeline initialized with the specified handlers. 
+ * + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@code true} otherwise. + * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(boolean hasDisconnect, ChannelHandler... handlers) { + this(EmbeddedChannelId.INSTANCE, hasDisconnect, handlers); + } + + /** + * Create a new instance with the pipeline initialized with the specified handlers. + * + * @param register {@code true} if this {@link Channel} is registered to the {@link EventLoop} in the + * constructor. If {@code false} the user will need to call {@link #register()}. + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@code true} otherwise. + * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(boolean register, boolean hasDisconnect, ChannelHandler... handlers) { + this(EmbeddedChannelId.INSTANCE, register, hasDisconnect, handlers); + } + + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param channelId the {@link ChannelId} that will be used to identify this channel + * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(ChannelId channelId, ChannelHandler... handlers) { + this(channelId, false, handlers); + } + + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param channelId the {@link ChannelId} that will be used to identify this channel + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@code true} otherwise. 
+ * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, ChannelHandler... handlers) { + this(channelId, true, hasDisconnect, handlers); + } + + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param channelId the {@link ChannelId} that will be used to identify this channel + * @param register {@code true} if this {@link Channel} is registered to the {@link EventLoop} in the + * constructor. If {@code false} the user will need to call {@link #register()}. + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@code true} otherwise. + * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(ChannelId channelId, boolean register, boolean hasDisconnect, + ChannelHandler... handlers) { + this(null, channelId, register, hasDisconnect, handlers); + } + + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param parent the parent {@link Channel} of this {@link EmbeddedChannel}. + * @param channelId the {@link ChannelId} that will be used to identify this channel + * @param register {@code true} if this {@link Channel} is registered to the {@link EventLoop} in the + * constructor. If {@code false} the user will need to call {@link #register()}. + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@code true} otherwise. + * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(Channel parent, ChannelId channelId, boolean register, boolean hasDisconnect, + final ChannelHandler... 
handlers) { + super(parent, channelId); + metadata = metadata(hasDisconnect); + config = new DefaultChannelConfig(this); + setup(register, handlers); + } + + /** + * Create a new instance with the channel ID set to the given ID and the pipeline + * initialized with the specified handlers. + * + * @param channelId the {@link ChannelId} that will be used to identify this channel + * @param hasDisconnect {@code false} if this {@link Channel} will delegate {@link #disconnect()} + * to {@link #close()}, {@code true} otherwise. + * @param config the {@link ChannelConfig} which will be returned by {@link #config()}. + * @param handlers the {@link ChannelHandler}s which will be added to the {@link ChannelPipeline} + */ + public EmbeddedChannel(ChannelId channelId, boolean hasDisconnect, final ChannelConfig config, + final ChannelHandler... handlers) { + super(null, channelId); + metadata = metadata(hasDisconnect); + this.config = ObjectUtil.checkNotNull(config, "config"); + setup(true, handlers); + } + + private static ChannelMetadata metadata(boolean hasDisconnect) { + return hasDisconnect ? METADATA_DISCONNECT : METADATA_NO_DISCONNECT; + } + + private void setup(boolean register, final ChannelHandler... handlers) { + ObjectUtil.checkNotNull(handlers, "handlers"); + ChannelPipeline p = pipeline(); + p.addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + for (ChannelHandler h: handlers) { + if (h == null) { + break; + } + pipeline.addLast(h); + } + } + }); + if (register) { + ChannelFuture future = loop.register(this); + assert future.isDone(); + } + } + + /** + * Register this {@code Channel} on its {@link EventLoop}. 
+ */ + public void register() throws Exception { + ChannelFuture future = loop.register(this); + assert future.isDone(); + Throwable cause = future.cause(); + if (cause != null) { + PlatformDependent.throwException(cause); + } + } + + @Override + protected final DefaultChannelPipeline newChannelPipeline() { + return new EmbeddedChannelPipeline(this); + } + + @Override + public ChannelMetadata metadata() { + return metadata; + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return state != State.CLOSED; + } + + @Override + public boolean isActive() { + return state == State.ACTIVE; + } + + /** + * Returns the {@link Queue} which holds all the {@link Object}s that were received by this {@link Channel}. + */ + public Queue inboundMessages() { + if (inboundMessages == null) { + inboundMessages = new ArrayDeque(); + } + return inboundMessages; + } + + /** + * @deprecated use {@link #inboundMessages()} + */ + @Deprecated + public Queue lastInboundBuffer() { + return inboundMessages(); + } + + /** + * Returns the {@link Queue} which holds all the {@link Object}s that were written by this {@link Channel}. + */ + public Queue outboundMessages() { + if (outboundMessages == null) { + outboundMessages = new ArrayDeque(); + } + return outboundMessages; + } + + /** + * @deprecated use {@link #outboundMessages()} + */ + @Deprecated + public Queue lastOutboundBuffer() { + return outboundMessages(); + } + + /** + * Return received data from this {@link Channel} + */ + @SuppressWarnings("unchecked") + public T readInbound() { + T message = (T) poll(inboundMessages); + if (message != null) { + ReferenceCountUtil.touch(message, "Caller of readInbound() will handle the message from this point"); + } + return message; + } + + /** + * Read data from the outbound. This may return {@code null} if nothing is readable. 
+ */ + @SuppressWarnings("unchecked") + public T readOutbound() { + T message = (T) poll(outboundMessages); + if (message != null) { + ReferenceCountUtil.touch(message, "Caller of readOutbound() will handle the message from this point."); + } + return message; + } + + /** + * Write messages to the inbound of this {@link Channel}. + * + * @param msgs the messages to be written + * + * @return {@code true} if the write operation did add something to the inbound buffer + */ + public boolean writeInbound(Object... msgs) { + ensureOpen(); + if (msgs.length == 0) { + return isNotEmpty(inboundMessages); + } + + ChannelPipeline p = pipeline(); + for (Object m: msgs) { + p.fireChannelRead(m); + } + + flushInbound(false, voidPromise()); + return isNotEmpty(inboundMessages); + } + + /** + * Writes one message to the inbound of this {@link Channel} and does not flush it. This + * method is conceptually equivalent to {@link #write(Object)}. + * + * @see #writeOneOutbound(Object) + */ + public ChannelFuture writeOneInbound(Object msg) { + return writeOneInbound(msg, newPromise()); + } + + /** + * Writes one message to the inbound of this {@link Channel} and does not flush it. This + * method is conceptually equivalent to {@link #write(Object, ChannelPromise)}. + * + * @see #writeOneOutbound(Object, ChannelPromise) + */ + public ChannelFuture writeOneInbound(Object msg, ChannelPromise promise) { + if (checkOpen(true)) { + pipeline().fireChannelRead(msg); + } + return checkException(promise); + } + + /** + * Flushes the inbound of this {@link Channel}. This method is conceptually equivalent to {@link #flush()}. 
+ * + * @see #flushOutbound() + */ + public EmbeddedChannel flushInbound() { + flushInbound(true, voidPromise()); + return this; + } + + private ChannelFuture flushInbound(boolean recordException, ChannelPromise promise) { + if (checkOpen(recordException)) { + pipeline().fireChannelReadComplete(); + runPendingTasks(); + } + + return checkException(promise); + } + + /** + * Write messages to the outbound of this {@link Channel}. + * + * @param msgs the messages to be written + * @return bufferReadable returns {@code true} if the write operation did add something to the outbound buffer + */ + public boolean writeOutbound(Object... msgs) { + ensureOpen(); + if (msgs.length == 0) { + return isNotEmpty(outboundMessages); + } + + RecyclableArrayList futures = RecyclableArrayList.newInstance(msgs.length); + try { + for (Object m: msgs) { + if (m == null) { + break; + } + futures.add(write(m)); + } + + flushOutbound0(); + + int size = futures.size(); + for (int i = 0; i < size; i++) { + ChannelFuture future = (ChannelFuture) futures.get(i); + if (future.isDone()) { + recordException(future); + } else { + // The write may be delayed to run later by runPendingTasks() + future.addListener(recordExceptionListener); + } + } + + checkException(); + return isNotEmpty(outboundMessages); + } finally { + futures.recycle(); + } + } + + /** + * Writes one message to the outbound of this {@link Channel} and does not flush it. This + * method is conceptually equivalent to {@link #write(Object)}. + * + * @see #writeOneInbound(Object) + */ + public ChannelFuture writeOneOutbound(Object msg) { + return writeOneOutbound(msg, newPromise()); + } + + /** + * Writes one message to the outbound of this {@link Channel} and does not flush it. This + * method is conceptually equivalent to {@link #write(Object, ChannelPromise)}. 
+ * + * @see #writeOneInbound(Object, ChannelPromise) + */ + public ChannelFuture writeOneOutbound(Object msg, ChannelPromise promise) { + if (checkOpen(true)) { + return write(msg, promise); + } + return checkException(promise); + } + + /** + * Flushes the outbound of this {@link Channel}. This method is conceptually equivalent to {@link #flush()}. + * + * @see #flushInbound() + */ + public EmbeddedChannel flushOutbound() { + if (checkOpen(true)) { + flushOutbound0(); + } + checkException(voidPromise()); + return this; + } + + private void flushOutbound0() { + // We need to call runPendingTasks first as a ChannelOutboundHandler may used eventloop.execute(...) to + // delay the write on the next eventloop run. + runPendingTasks(); + + flush(); + } + + /** + * Mark this {@link Channel} as finished. Any further try to write data to it will fail. + * + * @return bufferReadable returns {@code true} if any of the used buffers has something left to read + */ + public boolean finish() { + return finish(false); + } + + /** + * Mark this {@link Channel} as finished and release all pending message in the inbound and outbound buffer. + * Any further try to write data to it will fail. + * + * @return bufferReadable returns {@code true} if any of the used buffers has something left to read + */ + public boolean finishAndReleaseAll() { + return finish(true); + } + + /** + * Mark this {@link Channel} as finished. Any further try to write data to it will fail. + * + * @param releaseAll if {@code true} all pending message in the inbound and outbound buffer are released. 
+ * @return bufferReadable returns {@code true} if any of the used buffers has something left to read + */ + private boolean finish(boolean releaseAll) { + close(); + try { + checkException(); + return isNotEmpty(inboundMessages) || isNotEmpty(outboundMessages); + } finally { + if (releaseAll) { + releaseAll(inboundMessages); + releaseAll(outboundMessages); + } + } + } + + /** + * Release all buffered inbound messages and return {@code true} if any were in the inbound buffer, {@code false} + * otherwise. + */ + public boolean releaseInbound() { + return releaseAll(inboundMessages); + } + + /** + * Release all buffered outbound messages and return {@code true} if any were in the outbound buffer, {@code false} + * otherwise. + */ + public boolean releaseOutbound() { + return releaseAll(outboundMessages); + } + + private static boolean releaseAll(Queue queue) { + if (isNotEmpty(queue)) { + for (;;) { + Object msg = queue.poll(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + return true; + } + return false; + } + + private void finishPendingTasks(boolean cancel) { + runPendingTasks(); + if (cancel) { + // Cancel all scheduled tasks that are left. + embeddedEventLoop().cancelScheduledTasks(); + } + } + + @Override + public final ChannelFuture close() { + return close(newPromise()); + } + + @Override + public final ChannelFuture disconnect() { + return disconnect(newPromise()); + } + + @Override + public final ChannelFuture close(ChannelPromise promise) { + // We need to call runPendingTasks() before calling super.close() as there may be something in the queue + // that needs to be run before the actual close takes place. + runPendingTasks(); + ChannelFuture future = super.close(promise); + + // Now finish everything else and cancel all scheduled tasks that were not ready set. 
+ finishPendingTasks(true); + return future; + } + + @Override + public final ChannelFuture disconnect(ChannelPromise promise) { + ChannelFuture future = super.disconnect(promise); + finishPendingTasks(!metadata.hasDisconnect()); + return future; + } + + private static boolean isNotEmpty(Queue queue) { + return queue != null && !queue.isEmpty(); + } + + private static Object poll(Queue queue) { + return queue != null ? queue.poll() : null; + } + + /** + * Run all tasks (which also includes scheduled tasks) that are pending in the {@link EventLoop} + * for this {@link Channel} + */ + public void runPendingTasks() { + try { + embeddedEventLoop().runTasks(); + } catch (Exception e) { + recordException(e); + } + + try { + embeddedEventLoop().runScheduledTasks(); + } catch (Exception e) { + recordException(e); + } + } + + /** + * Check whether this channel has any pending tasks that would be executed by a call to {@link #runPendingTasks()}. + * This includes normal tasks, and scheduled tasks where the deadline has expired. If this method returns + * {@code false}, a call to {@link #runPendingTasks()} would do nothing. + * + * @return {@code true} if there are any pending tasks, {@code false} otherwise. + */ + public boolean hasPendingTasks() { + return embeddedEventLoop().hasPendingNormalTasks() || + embeddedEventLoop().nextScheduledTask() == 0; + } + + /** + * Run all pending scheduled tasks in the {@link EventLoop} for this {@link Channel} and return the + * {@code nanoseconds} when the next scheduled task is ready to run. If no other task was scheduled it will return + * {@code -1}. 
+ */ + public long runScheduledPendingTasks() { + try { + return embeddedEventLoop().runScheduledTasks(); + } catch (Exception e) { + recordException(e); + return embeddedEventLoop().nextScheduledTask(); + } + } + + private void recordException(ChannelFuture future) { + if (!future.isSuccess()) { + recordException(future.cause()); + } + } + + private void recordException(Throwable cause) { + if (lastException == null) { + lastException = cause; + } else { + logger.warn( + "More than one exception was raised. " + + "Will report only the first one and log others.", cause); + } + } + + /** + * Advance the clock of the event loop of this channel by the given duration. Any scheduled tasks will execute + * sooner by the given time (but {@link #runScheduledPendingTasks()} still needs to be called). + */ + public void advanceTimeBy(long duration, TimeUnit unit) { + embeddedEventLoop().advanceTimeBy(unit.toNanos(duration)); + } + + /** + * Freeze the clock of this channel's event loop. Any scheduled tasks that are not already due will not run on + * future {@link #runScheduledPendingTasks()} calls. While the event loop is frozen, it is still possible to + * {@link #advanceTimeBy(long, TimeUnit) advance time} manually so that scheduled tasks execute. + */ + public void freezeTime() { + embeddedEventLoop().freezeTime(); + } + + /** + * Unfreeze an event loop that was {@link #freezeTime() frozen}. Time will continue at the point where + * {@link #freezeTime()} stopped it: if a task was scheduled ten minutes in the future and {@link #freezeTime()} + * was called, it will run ten minutes after this method is called again (assuming no + * {@link #advanceTimeBy(long, TimeUnit)} calls, and assuming pending scheduled tasks are run at that time using + * {@link #runScheduledPendingTasks()}). + */ + public void unfreezeTime() { + embeddedEventLoop().unfreezeTime(); + } + + /** + * Checks for the presence of an {@link Exception}. 
+ */ + private ChannelFuture checkException(ChannelPromise promise) { + Throwable t = lastException; + if (t != null) { + lastException = null; + + if (promise.isVoid()) { + PlatformDependent.throwException(t); + } + + return promise.setFailure(t); + } + + return promise.setSuccess(); + } + + /** + * Check if there was any {@link Throwable} received and if so rethrow it. + */ + public void checkException() { + checkException(voidPromise()); + } + + /** + * Returns {@code true} if the {@link Channel} is open and records optionally + * an {@link Exception} if it isn't. + */ + private boolean checkOpen(boolean recordException) { + if (!isOpen()) { + if (recordException) { + recordException(new ClosedChannelException()); + } + return false; + } + + return true; + } + + private EmbeddedEventLoop embeddedEventLoop() { + if (isRegistered()) { + return (EmbeddedEventLoop) super.eventLoop(); + } + + return loop; + } + + /** + * Ensure the {@link Channel} is open and if not throw an exception. + */ + protected final void ensureOpen() { + if (!checkOpen(true)) { + checkException(); + } + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return loop instanceof EmbeddedEventLoop; + } + + @Override + protected SocketAddress localAddress0() { + return isActive()? LOCAL_ADDRESS : null; + } + + @Override + protected SocketAddress remoteAddress0() { + return isActive()? 
REMOTE_ADDRESS : null; + } + + @Override + protected void doRegister() throws Exception { + state = State.ACTIVE; + } + + @Override + protected void doBind(SocketAddress localAddress) throws Exception { + // NOOP + } + + @Override + protected void doDisconnect() throws Exception { + if (!metadata.hasDisconnect()) { + doClose(); + } + } + + @Override + protected void doClose() throws Exception { + state = State.CLOSED; + } + + @Override + protected void doBeginRead() throws Exception { + // NOOP + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new EmbeddedUnsafe(); + } + + @Override + public Unsafe unsafe() { + return ((EmbeddedUnsafe) super.unsafe()).wrapped; + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + for (;;) { + Object msg = in.current(); + if (msg == null) { + break; + } + + ReferenceCountUtil.retain(msg); + handleOutboundMessage(msg); + in.remove(); + } + } + + /** + * Called for each outbound message. + * + * @see #doWrite(ChannelOutboundBuffer) + */ + protected void handleOutboundMessage(Object msg) { + outboundMessages().add(msg); + } + + /** + * Called for each inbound message. + */ + protected void handleInboundMessage(Object msg) { + inboundMessages().add(msg); + } + + private final class EmbeddedUnsafe extends AbstractUnsafe { + + // Delegates to the EmbeddedUnsafe instance but ensures runPendingTasks() is called after each operation + // that may change the state of the Channel and may schedule tasks for later execution. 
+ final Unsafe wrapped = new Unsafe() { + @Override + public RecvByteBufAllocator.Handle recvBufAllocHandle() { + return EmbeddedUnsafe.this.recvBufAllocHandle(); + } + + @Override + public SocketAddress localAddress() { + return EmbeddedUnsafe.this.localAddress(); + } + + @Override + public SocketAddress remoteAddress() { + return EmbeddedUnsafe.this.remoteAddress(); + } + + @Override + public void register(EventLoop eventLoop, ChannelPromise promise) { + EmbeddedUnsafe.this.register(eventLoop, promise); + runPendingTasks(); + } + + @Override + public void bind(SocketAddress localAddress, ChannelPromise promise) { + EmbeddedUnsafe.this.bind(localAddress, promise); + runPendingTasks(); + } + + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + EmbeddedUnsafe.this.connect(remoteAddress, localAddress, promise); + runPendingTasks(); + } + + @Override + public void disconnect(ChannelPromise promise) { + EmbeddedUnsafe.this.disconnect(promise); + runPendingTasks(); + } + + @Override + public void close(ChannelPromise promise) { + EmbeddedUnsafe.this.close(promise); + runPendingTasks(); + } + + @Override + public void closeForcibly() { + EmbeddedUnsafe.this.closeForcibly(); + runPendingTasks(); + } + + @Override + public void deregister(ChannelPromise promise) { + EmbeddedUnsafe.this.deregister(promise); + runPendingTasks(); + } + + @Override + public void beginRead() { + EmbeddedUnsafe.this.beginRead(); + runPendingTasks(); + } + + @Override + public void write(Object msg, ChannelPromise promise) { + EmbeddedUnsafe.this.write(msg, promise); + runPendingTasks(); + } + + @Override + public void flush() { + EmbeddedUnsafe.this.flush(); + runPendingTasks(); + } + + @Override + public ChannelPromise voidPromise() { + return EmbeddedUnsafe.this.voidPromise(); + } + + @Override + public ChannelOutboundBuffer outboundBuffer() { + return EmbeddedUnsafe.this.outboundBuffer(); + } + }; + + @Override + public void 
connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + safeSetSuccess(promise); + } + } + + private final class EmbeddedChannelPipeline extends DefaultChannelPipeline { + EmbeddedChannelPipeline(EmbeddedChannel channel) { + super(channel); + } + + @Override + protected void onUnhandledInboundException(Throwable cause) { + recordException(cause); + } + + @Override + protected void onUnhandledInboundMessage(ChannelHandlerContext ctx, Object msg) { + handleInboundMessage(msg); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedChannelId.java b/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedChannelId.java new file mode 100644 index 0000000..6e40369 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedChannelId.java @@ -0,0 +1,65 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.channel.embedded; + +import io.netty.channel.ChannelId; + +/** + * A dummy {@link ChannelId} implementation. 
+ */ +final class EmbeddedChannelId implements ChannelId { + + private static final long serialVersionUID = -251711922203466130L; + + static final ChannelId INSTANCE = new EmbeddedChannelId(); + + private EmbeddedChannelId() { } + + @Override + public String asShortText() { + return toString(); + } + + @Override + public String asLongText() { + return toString(); + } + + @Override + public int compareTo(final ChannelId o) { + if (o instanceof EmbeddedChannelId) { + return 0; + } + + return asLongText().compareTo(o.asLongText()); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof EmbeddedChannelId; + } + + @Override + public String toString() { + return "embedded"; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java b/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java new file mode 100644 index 0000000..a21eec2 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedEventLoop.java @@ -0,0 +1,201 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.embedded; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.util.concurrent.AbstractScheduledEventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.internal.ObjectUtil; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.TimeUnit; + +final class EmbeddedEventLoop extends AbstractScheduledEventExecutor implements EventLoop { + /** + * When time is not {@link #timeFrozen frozen}, the base time to subtract from {@link System#nanoTime()}. When time + * is frozen, this variable is unused. + * + * Initialized to {@link #initialNanoTime()} so that until one of the time mutator methods is called, + * {@link #getCurrentTimeNanos()} matches the default behavior. + */ + private long startTime = initialNanoTime(); + /** + * When time is frozen, the timestamp returned by {@link #getCurrentTimeNanos()}. When unfrozen, this is unused. + */ + private long frozenTimestamp; + /** + * Whether time is currently frozen. 
+ */ + private boolean timeFrozen; + + private final Queue tasks = new ArrayDeque(2); + + @Override + public EventLoopGroup parent() { + return (EventLoopGroup) super.parent(); + } + + @Override + public EventLoop next() { + return (EventLoop) super.next(); + } + + @Override + public void execute(Runnable command) { + tasks.add(ObjectUtil.checkNotNull(command, "command")); + } + + void runTasks() { + for (;;) { + Runnable task = tasks.poll(); + if (task == null) { + break; + } + + task.run(); + } + } + + boolean hasPendingNormalTasks() { + return !tasks.isEmpty(); + } + + long runScheduledTasks() { + long time = getCurrentTimeNanos(); + for (;;) { + Runnable task = pollScheduledTask(time); + if (task == null) { + return nextScheduledTaskNano(); + } + + task.run(); + } + } + + long nextScheduledTask() { + return nextScheduledTaskNano(); + } + + @Override + protected long getCurrentTimeNanos() { + if (timeFrozen) { + return frozenTimestamp; + } + return System.nanoTime() - startTime; + } + + void advanceTimeBy(long nanos) { + if (timeFrozen) { + frozenTimestamp += nanos; + } else { + // startTime is subtracted from nanoTime, so increasing the startTime will advance getCurrentTimeNanos + startTime -= nanos; + } + } + + void freezeTime() { + if (!timeFrozen) { + frozenTimestamp = getCurrentTimeNanos(); + timeFrozen = true; + } + } + + void unfreezeTime() { + if (timeFrozen) { + // we want getCurrentTimeNanos to continue right where frozenTimestamp left off: + // getCurrentTimeNanos = nanoTime - startTime = frozenTimestamp + // then solve for startTime + startTime = System.nanoTime() - frozenTimestamp; + timeFrozen = false; + } + } + + @Override + protected void cancelScheduledTasks() { + super.cancelScheduledTasks(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public Future terminationFuture() { + throw new UnsupportedOperationException(); + } + + 
@Override + @Deprecated + public void shutdown() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return false; + } + + @Override + public ChannelFuture register(Channel channel) { + return register(new DefaultChannelPromise(channel, this)); + } + + @Override + public ChannelFuture register(ChannelPromise promise) { + ObjectUtil.checkNotNull(promise, "promise"); + promise.channel().unsafe().register(this, promise); + return promise; + } + + @Deprecated + @Override + public ChannelFuture register(Channel channel, ChannelPromise promise) { + channel.unsafe().register(this, promise); + return promise; + } + + @Override + public boolean inEventLoop() { + return true; + } + + @Override + public boolean inEventLoop(Thread thread) { + return true; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedSocketAddress.java b/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedSocketAddress.java new file mode 100644 index 0000000..6b14bf4 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/embedded/EmbeddedSocketAddress.java @@ -0,0 +1,27 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.embedded;

import java.net.SocketAddress;

/**
 * Placeholder {@link SocketAddress} used for the local/remote address of an embedded channel.
 */
final class EmbeddedSocketAddress extends SocketAddress {
    private static final long serialVersionUID = 1400788804624980619L;

    @Override
    public String toString() {
        return "embedded";
    }
}
+ */ +package io.netty.channel.embedded; + diff --git a/netty-channel/src/main/java/io/netty/channel/group/ChannelGroup.java b/netty-channel/src/main/java/io/netty/channel/group/ChannelGroup.java new file mode 100644 index 0000000..b11f6f5 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/group/ChannelGroup.java @@ -0,0 +1,278 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.group; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelId; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.EventLoop; +import io.netty.channel.ServerChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.GlobalEventExecutor; + +import java.util.Set; + +/** + * A thread-safe {@link Set} that contains open {@link Channel}s and provides + * various bulk operations on them. Using {@link ChannelGroup}, you can + * categorize {@link Channel}s into a meaningful group (e.g. on a per-service + * or per-state basis.) A closed {@link Channel} is automatically removed from + * the collection, so that you don't need to worry about the life cycle of the + * added {@link Channel}. 
A {@link Channel} can belong to more than one + * {@link ChannelGroup}. + * + *

Broadcast a message to multiple {@link Channel}s

+ *

+ * If you need to broadcast a message to more than one {@link Channel}, you can + * add the {@link Channel}s associated with the recipients and call {@link ChannelGroup#write(Object)}: + *

+ * {@link ChannelGroup} recipients =
+ *         new {@link DefaultChannelGroup}({@link GlobalEventExecutor}.INSTANCE);
+ * recipients.add(channelA);
+ * recipients.add(channelB);
+ * ..
+ * recipients.write({@link Unpooled}.copiedBuffer(
+ *         "Service will shut down for maintenance in 5 minutes.",
+ *         {@link CharsetUtil}.UTF_8));
+ * 
+ * + *

Simplify shutdown process with {@link ChannelGroup}

+ *

+ * If both {@link ServerChannel}s and non-{@link ServerChannel}s exist in the + * same {@link ChannelGroup}, any requested I/O operations on the group are + * performed for the {@link ServerChannel}s first and then for the others. + *

+ * This rule is very useful when you shut down a server in one shot: + * + *

+ * {@link ChannelGroup} allChannels =
+ *         new {@link DefaultChannelGroup}({@link GlobalEventExecutor}.INSTANCE);
+ *
+ * public static void main(String[] args) throws Exception {
+ *     {@link ServerBootstrap} b = new {@link ServerBootstrap}(..);
+ *     ...
+ *     b.childHandler(new MyHandler());
+ *
+ *     // Start the server
+ *     b.getPipeline().addLast("handler", new MyHandler());
+ *     {@link Channel} serverChannel = b.bind(..).sync();
+ *     allChannels.add(serverChannel);
+ *
+ *     ... Wait until the shutdown signal reception ...
+ *
+ *     // Close the serverChannel and then all accepted connections.
+ *     allChannels.close().awaitUninterruptibly();
+ * }
+ *
+ * public class MyHandler extends {@link ChannelInboundHandlerAdapter} {
+ *     {@code @Override}
+ *     public void channelActive({@link ChannelHandlerContext} ctx) {
+ *         // closed on shutdown.
+ *         allChannels.add(ctx.channel());
+ *         super.channelActive(ctx);
+ *     }
+ * }
+ * 
+ */ +public interface ChannelGroup extends Set, Comparable { + + /** + * Returns the name of this group. A group name is purely for helping + * you to distinguish one group from others. + */ + String name(); + + /** + * Returns the {@link Channel} which has the specified {@link ChannelId}. + * + * @return the matching {@link Channel} if found. {@code null} otherwise. + */ + Channel find(ChannelId id); + + /** + * Writes the specified {@code message} to all {@link Channel}s in this + * group. If the specified {@code message} is an instance of + * {@link ByteBuf}, it is automatically + * {@linkplain ByteBuf#duplicate() duplicated} to avoid a race + * condition. The same is true for {@link ByteBufHolder}. Please note that this operation is asynchronous as + * {@link Channel#write(Object)} is. + * + * @return itself + */ + ChannelGroupFuture write(Object message); + + /** + * Writes the specified {@code message} to all {@link Channel}s in this + * group that are matched by the given {@link ChannelMatcher}. If the specified {@code message} is an instance of + * {@link ByteBuf}, it is automatically + * {@linkplain ByteBuf#duplicate() duplicated} to avoid a race + * condition. The same is true for {@link ByteBufHolder}. Please note that this operation is asynchronous as + * {@link Channel#write(Object)} is. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroupFuture write(Object message, ChannelMatcher matcher); + + /** + * Writes the specified {@code message} to all {@link Channel}s in this + * group that are matched by the given {@link ChannelMatcher}. If the specified {@code message} is an instance of + * {@link ByteBuf}, it is automatically + * {@linkplain ByteBuf#duplicate() duplicated} to avoid a race + * condition. The same is true for {@link ByteBufHolder}. Please note that this operation is asynchronous as + * {@link Channel#write(Object)} is. 
+ * + * If {@code voidPromise} is {@code true} {@link Channel#voidPromise()} is used for the writes and so the same + * restrictions to the returned {@link ChannelGroupFuture} apply as to a void promise. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroupFuture write(Object message, ChannelMatcher matcher, boolean voidPromise); + + /** + * Flush all {@link Channel}s in this + * group. If the specified {@code messages} are an instance of + * {@link ByteBuf}, it is automatically + * {@linkplain ByteBuf#duplicate() duplicated} to avoid a race + * condition. Please note that this operation is asynchronous as + * {@link Channel#write(Object)} is. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroup flush(); + + /** + * Flush all {@link Channel}s in this group that are matched by the given {@link ChannelMatcher}. + * If the specified {@code messages} are an instance of + * {@link ByteBuf}, it is automatically + * {@linkplain ByteBuf#duplicate() duplicated} to avoid a race + * condition. Please note that this operation is asynchronous as + * {@link Channel#write(Object)} is. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroup flush(ChannelMatcher matcher); + + /** + * Shortcut for calling {@link #write(Object)} and {@link #flush()}. + */ + ChannelGroupFuture writeAndFlush(Object message); + + /** + * @deprecated Use {@link #writeAndFlush(Object)} instead. + */ + @Deprecated + ChannelGroupFuture flushAndWrite(Object message); + + /** + * Shortcut for calling {@link #write(Object)} and {@link #flush()} and only act on + * {@link Channel}s that are matched by the {@link ChannelMatcher}. 
+ */ + ChannelGroupFuture writeAndFlush(Object message, ChannelMatcher matcher); + + /** + * Shortcut for calling {@link #write(Object, ChannelMatcher, boolean)} and {@link #flush()} and only act on + * {@link Channel}s that are matched by the {@link ChannelMatcher}. + */ + ChannelGroupFuture writeAndFlush(Object message, ChannelMatcher matcher, boolean voidPromise); + + /** + * @deprecated Use {@link #writeAndFlush(Object, ChannelMatcher)} instead. + */ + @Deprecated + ChannelGroupFuture flushAndWrite(Object message, ChannelMatcher matcher); + + /** + * Disconnects all {@link Channel}s in this group from their remote peers. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroupFuture disconnect(); + + /** + * Disconnects all {@link Channel}s in this group from their remote peers, + * that are matched by the given {@link ChannelMatcher}. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroupFuture disconnect(ChannelMatcher matcher); + + /** + * Closes all {@link Channel}s in this group. If the {@link Channel} is + * connected to a remote peer or bound to a local address, it is + * automatically disconnected and unbound. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroupFuture close(); + + /** + * Closes all {@link Channel}s in this group that are matched by the given {@link ChannelMatcher}. + * If the {@link Channel} is connected to a remote peer or bound to a local address, it is + * automatically disconnected and unbound. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + ChannelGroupFuture close(ChannelMatcher matcher); + + /** + * @deprecated This method will be removed in the next major feature release. 
+ * + * Deregister all {@link Channel}s in this group from their {@link EventLoop}. + * Please note that this operation is asynchronous as {@link Channel#deregister()} is. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + @Deprecated + ChannelGroupFuture deregister(); + + /** + * @deprecated This method will be removed in the next major feature release. + * + * Deregister all {@link Channel}s in this group from their {@link EventLoop} that are matched by the given + * {@link ChannelMatcher}. Please note that this operation is asynchronous as {@link Channel#deregister()} is. + * + * @return the {@link ChannelGroupFuture} instance that notifies when + * the operation is done for all channels + */ + @Deprecated + ChannelGroupFuture deregister(ChannelMatcher matcher); + + /** + * Returns the {@link ChannelGroupFuture} which will be notified when all {@link Channel}s that are part of this + * {@link ChannelGroup}, at the time of calling, are closed. + */ + ChannelGroupFuture newCloseFuture(); + + /** + * Returns the {@link ChannelGroupFuture} which will be notified when all {@link Channel}s that are part of this + * {@link ChannelGroup}, at the time of calling, are closed. + */ + ChannelGroupFuture newCloseFuture(ChannelMatcher matcher); +} diff --git a/netty-channel/src/main/java/io/netty/channel/group/ChannelGroupException.java b/netty-channel/src/main/java/io/netty/channel/group/ChannelGroupException.java new file mode 100644 index 0000000..727ac08 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/group/ChannelGroupException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.group;

import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelFuture;
import io.netty.util.internal.ObjectUtil;

import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.Map;

/**
 * {@link ChannelException} which holds {@link ChannelFuture}s that failed because of an error.
 */
public class ChannelGroupException extends ChannelException
        implements Iterable<Map.Entry<Channel, Throwable>> {
    private static final long serialVersionUID = -4093064295562629453L;

    // Fixed: raw Collection -> Collection<Map.Entry<Channel, Throwable>>.
    // Unmodifiable view of the per-channel failures; never empty.
    private final Collection<Map.Entry<Channel, Throwable>> failed;

    /**
     * @param causes the per-channel failure causes; must be non-null and non-empty
     * @throws IllegalArgumentException if {@code causes} is empty
     * @throws NullPointerException if {@code causes} is {@code null}
     */
    public ChannelGroupException(Collection<Map.Entry<Channel, Throwable>> causes) {
        ObjectUtil.checkNonEmpty(causes, "causes");

        failed = Collections.unmodifiableCollection(causes);
    }

    /**
     * Returns a {@link Iterator} which contains all the {@link Throwable} that was a cause of the failure and the
     * related {@link Channel}.
     */
    @Override
    public Iterator<Map.Entry<Channel, Throwable>> iterator() {
        return failed.iterator();
    }
}
You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.group;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import java.util.Iterator;

/**
 * The result of an asynchronous {@link ChannelGroup} operation.
 * {@link ChannelGroupFuture} is composed of {@link ChannelFuture}s which
 * represent the outcome of the individual I/O operations that affect the
 * {@link Channel}s in the {@link ChannelGroup}.
 *
 * <p>
 * All I/O operations in {@link ChannelGroup} are asynchronous: calls return
 * immediately with a {@link ChannelGroupFuture} that tells you when the
 * requested I/O operations have succeeded, failed, or been cancelled.
 *
 * <h3>Prefer {@link #addListener(GenericFutureListener)} to {@link #await()}</h3>
 *
 * {@link #addListener(GenericFutureListener)} is non-blocking and is notified
 * by the I/O thread when the operations complete; it yields the best
 * performance and resource utilization. By contrast, {@link #await()} blocks
 * the calling thread until all I/O operations are done.
 *
 * <h3>Do not call {@link #await()} inside a {@link ChannelHandler}</h3>
 *
 * Handler methods are usually invoked by an I/O thread. Calling
 * {@link #await()} there can block the very I/O operation the future is
 * waiting for — a dead lock. Register a listener instead:
 * <pre>
 * // GOOD
 * ChannelGroupFuture future = allChannels.close();
 * future.addListener(new ChannelGroupFutureListener() {
 *     public void operationComplete(ChannelGroupFuture future) {
 *         // Perform post-closure operation
 *     }
 * });
 * </pre>
 * If you do need to block, make sure it is not in an I/O thread; otherwise
 * {@link IllegalStateException} will be raised to prevent a dead lock.
 */
public interface ChannelGroupFuture extends Future<Void>, Iterable<ChannelFuture> {

    /**
     * Returns the {@link ChannelGroup} which is associated with this future.
     */
    ChannelGroup group();

    /**
     * Returns the {@link ChannelFuture} of the individual I/O operation which
     * is associated with the specified {@link Channel}.
     *
     * @return the matching {@link ChannelFuture} if found.
     *         {@code null} otherwise.
     */
    ChannelFuture find(Channel channel);

    /**
     * Returns {@code true} if and only if all I/O operations associated with
     * this future were successful without any failure.
     */
    @Override
    boolean isSuccess();

    @Override
    ChannelGroupException cause();

    /**
     * Returns {@code true} if and only if the I/O operations associated with
     * this future were partially successful with some failure.
     */
    boolean isPartialSuccess();

    /**
     * Returns {@code true} if and only if the I/O operations associated with
     * this future have failed partially with some success.
     */
    boolean isPartialFailure();

    @Override
    ChannelGroupFuture addListener(GenericFutureListener<? extends Future<? super Void>> listener);

    @Override
    ChannelGroupFuture addListeners(GenericFutureListener<? extends Future<? super Void>>... listeners);

    @Override
    ChannelGroupFuture removeListener(GenericFutureListener<? extends Future<? super Void>> listener);

    @Override
    ChannelGroupFuture removeListeners(GenericFutureListener<? extends Future<? super Void>>... listeners);

    @Override
    ChannelGroupFuture await() throws InterruptedException;

    @Override
    ChannelGroupFuture awaitUninterruptibly();

    @Override
    ChannelGroupFuture syncUninterruptibly();

    @Override
    ChannelGroupFuture sync() throws InterruptedException;

    /**
     * Returns the {@link Iterator} that enumerates all {@link ChannelFuture}s
     * which are associated with this future. Please note that the returned
     * {@link Iterator} is unmodifiable, which means a {@link ChannelFuture}
     * cannot be removed from this future.
     */
    @Override
    Iterator<ChannelFuture> iterator();
}
+ */ +public interface ChannelGroupFutureListener extends GenericFutureListener { + +} diff --git a/netty-channel/src/main/java/io/netty/channel/group/ChannelMatcher.java b/netty-channel/src/main/java/io/netty/channel/group/ChannelMatcher.java new file mode 100644 index 0000000..9749d19 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/group/ChannelMatcher.java @@ -0,0 +1,32 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.group; + + +import io.netty.channel.Channel; + +/** + * Allows to only match some {@link Channel}'s for operations in {@link ChannelGroup}. + * + * {@link ChannelMatchers} provide you with helper methods for usual needed implementations. + */ +public interface ChannelMatcher { + + /** + * Returns {@code true} if the operation should be also executed on the given {@link Channel}. 
+ */ + boolean matches(Channel channel); +} diff --git a/netty-channel/src/main/java/io/netty/channel/group/ChannelMatchers.java b/netty-channel/src/main/java/io/netty/channel/group/ChannelMatchers.java new file mode 100644 index 0000000..0179086 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/group/ChannelMatchers.java @@ -0,0 +1,169 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.group; + +import io.netty.channel.Channel; +import io.netty.channel.ServerChannel; + +/** + * Helper class which provides often used {@link ChannelMatcher} implementations. + */ +public final class ChannelMatchers { + + private static final ChannelMatcher ALL_MATCHER = new ChannelMatcher() { + @Override + public boolean matches(Channel channel) { + return true; + } + }; + + private static final ChannelMatcher SERVER_CHANNEL_MATCHER = isInstanceOf(ServerChannel.class); + private static final ChannelMatcher NON_SERVER_CHANNEL_MATCHER = isNotInstanceOf(ServerChannel.class); + + private ChannelMatchers() { + // static methods only + } + + /** + * Returns a {@link ChannelMatcher} that matches all {@link Channel}s. + */ + public static ChannelMatcher all() { + return ALL_MATCHER; + } + + /** + * Returns a {@link ChannelMatcher} that matches all {@link Channel}s except the given. 
+ */ + public static ChannelMatcher isNot(Channel channel) { + return invert(is(channel)); + } + + /** + * Returns a {@link ChannelMatcher} that matches the given {@link Channel}. + */ + public static ChannelMatcher is(Channel channel) { + return new InstanceMatcher(channel); + } + + /** + * Returns a {@link ChannelMatcher} that matches all {@link Channel}s that are an instance of sub-type of + * the given class. + */ + public static ChannelMatcher isInstanceOf(Class clazz) { + return new ClassMatcher(clazz); + } + + /** + * Returns a {@link ChannelMatcher} that matches all {@link Channel}s that are not an + * instance of sub-type of the given class. + */ + public static ChannelMatcher isNotInstanceOf(Class clazz) { + return invert(isInstanceOf(clazz)); + } + + /** + * Returns a {@link ChannelMatcher} that matches all {@link Channel}s that are of type {@link ServerChannel}. + */ + public static ChannelMatcher isServerChannel() { + return SERVER_CHANNEL_MATCHER; + } + + /** + * Returns a {@link ChannelMatcher} that matches all {@link Channel}s that are not of type + * {@link ServerChannel}. + */ + public static ChannelMatcher isNonServerChannel() { + return NON_SERVER_CHANNEL_MATCHER; + } + + /** + * Invert the given {@link ChannelMatcher}. + */ + public static ChannelMatcher invert(ChannelMatcher matcher) { + return new InvertMatcher(matcher); + } + + /** + * Return a composite of the given {@link ChannelMatcher}s. This means all {@link ChannelMatcher} must + * return {@code true} to match. + */ + public static ChannelMatcher compose(ChannelMatcher... matchers) { + if (matchers.length < 1) { + throw new IllegalArgumentException("matchers must at least contain one element"); + } + if (matchers.length == 1) { + return matchers[0]; + } + return new CompositeMatcher(matchers); + } + + private static final class CompositeMatcher implements ChannelMatcher { + private final ChannelMatcher[] matchers; + + CompositeMatcher(ChannelMatcher... 
matchers) { + this.matchers = matchers; + } + + @Override + public boolean matches(Channel channel) { + for (ChannelMatcher m: matchers) { + if (!m.matches(channel)) { + return false; + } + } + return true; + } + } + + private static final class InvertMatcher implements ChannelMatcher { + private final ChannelMatcher matcher; + + InvertMatcher(ChannelMatcher matcher) { + this.matcher = matcher; + } + + @Override + public boolean matches(Channel channel) { + return !matcher.matches(channel); + } + } + + private static final class InstanceMatcher implements ChannelMatcher { + private final Channel channel; + + InstanceMatcher(Channel channel) { + this.channel = channel; + } + + @Override + public boolean matches(Channel ch) { + return channel == ch; + } + } + + private static final class ClassMatcher implements ChannelMatcher { + private final Class clazz; + + ClassMatcher(Class clazz) { + this.clazz = clazz; + } + + @Override + public boolean matches(Channel ch) { + return clazz.isInstance(ch); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/group/CombinedIterator.java b/netty-channel/src/main/java/io/netty/channel/group/CombinedIterator.java new file mode 100644 index 0000000..e9ae345 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/group/CombinedIterator.java @@ -0,0 +1,72 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * An {@link Iterator} that exhausts a first iterator and then continues with a
 * second one.  {@link #remove()} delegates to whichever underlying iterator is
 * currently active.
 */
final class CombinedIterator<E> implements Iterator<E> {

    private final Iterator<E> i1;
    private final Iterator<E> i2;
    private Iterator<E> currentIterator;

    /**
     * @param i1 the iterator consumed first; must not be {@code null}
     * @param i2 the iterator consumed after {@code i1} is drained; must not be {@code null}
     * @throws NullPointerException if either iterator is {@code null}
     */
    CombinedIterator(Iterator<E> i1, Iterator<E> i2) {
        // Stdlib Objects.requireNonNull has the same contract as the
        // project-internal ObjectUtil.checkNotNull (NPE carrying the name).
        this.i1 = Objects.requireNonNull(i1, "i1");
        this.i2 = Objects.requireNonNull(i2, "i2");
        currentIterator = i1;
    }

    @Override
    public boolean hasNext() {
        for (;;) {
            if (currentIterator.hasNext()) {
                return true;
            }
            // First iterator drained: switch to the second exactly once.
            if (currentIterator == i1) {
                currentIterator = i2;
            } else {
                return false;
            }
        }
    }

    @Override
    public E next() {
        for (;;) {
            try {
                return currentIterator.next();
            } catch (NoSuchElementException e) {
                // i1 may drain between hasNext() and next(); fall through to i2
                // once, otherwise both are exhausted and the exception stands.
                if (currentIterator == i1) {
                    currentIterator = i2;
                } else {
                    throw e;
                }
            }
        }
    }

    @Override
    public void remove() {
        currentIterator.remove();
    }

}
+ */ +package io.netty.channel.group; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelId; +import io.netty.channel.ServerChannel; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +import java.util.AbstractSet; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * The default {@link ChannelGroup} implementation. + */ +public class DefaultChannelGroup extends AbstractSet implements ChannelGroup { + + private static final AtomicInteger nextId = new AtomicInteger(); + private final String name; + private final EventExecutor executor; + private final ConcurrentMap serverChannels = PlatformDependent.newConcurrentHashMap(); + private final ConcurrentMap nonServerChannels = PlatformDependent.newConcurrentHashMap(); + private final ChannelFutureListener remover = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + remove(future.channel()); + } + }; + private final VoidChannelGroupFuture voidFuture = new VoidChannelGroupFuture(this); + private final boolean stayClosed; + private volatile boolean closed; + + /** + * Creates a new group with a generated name and the provided {@link EventExecutor} to notify the + * {@link ChannelGroupFuture}s. + */ + public DefaultChannelGroup(EventExecutor executor) { + this(executor, false); + } + + /** + * Creates a new group with the specified {@code name} and {@link EventExecutor} to notify the + * {@link ChannelGroupFuture}s. 
Please note that different groups can have the same name, which means no + * duplicate check is done against group names. + */ + public DefaultChannelGroup(String name, EventExecutor executor) { + this(name, executor, false); + } + + /** + * Creates a new group with a generated name and the provided {@link EventExecutor} to notify the + * {@link ChannelGroupFuture}s. {@code stayClosed} defines whether or not, this group can be closed + * more than once. Adding channels to a closed group will immediately close them, too. This makes it + * easy, to shutdown server and child channels at once. + */ + public DefaultChannelGroup(EventExecutor executor, boolean stayClosed) { + this("group-0x" + Integer.toHexString(nextId.incrementAndGet()), executor, stayClosed); + } + + /** + * Creates a new group with the specified {@code name} and {@link EventExecutor} to notify the + * {@link ChannelGroupFuture}s. {@code stayClosed} defines whether or not, this group can be closed + * more than once. Adding channels to a closed group will immediately close them, too. This makes it + * easy, to shutdown server and child channels at once. Please note that different groups can have + * the same name, which means no duplicate check is done against group names. 
+ */ + public DefaultChannelGroup(String name, EventExecutor executor, boolean stayClosed) { + ObjectUtil.checkNotNull(name, "name"); + this.name = name; + this.executor = executor; + this.stayClosed = stayClosed; + } + + @Override + public String name() { + return name; + } + + @Override + public Channel find(ChannelId id) { + Channel c = nonServerChannels.get(id); + if (c != null) { + return c; + } else { + return serverChannels.get(id); + } + } + + @Override + public boolean isEmpty() { + return nonServerChannels.isEmpty() && serverChannels.isEmpty(); + } + + @Override + public int size() { + return nonServerChannels.size() + serverChannels.size(); + } + + @Override + public boolean contains(Object o) { + if (o instanceof ServerChannel) { + return serverChannels.containsValue(o); + } else if (o instanceof Channel) { + return nonServerChannels.containsValue(o); + } + return false; + } + + @Override + public boolean add(Channel channel) { + ConcurrentMap map = + channel instanceof ServerChannel? serverChannels : nonServerChannels; + + boolean added = map.putIfAbsent(channel.id(), channel) == null; + if (added) { + channel.closeFuture().addListener(remover); + } + + if (stayClosed && closed) { + + // First add channel, than check if closed. + // Seems inefficient at first, but this way a volatile + // gives us enough synchronization to be thread-safe. + // + // If true: Close right away. + // (Might be closed a second time by ChannelGroup.close(), but this is ok) + // + // If false: Channel will definitely be closed by the ChannelGroup. 
+ // (Because closed=true always happens-before ChannelGroup.close()) + // + // See https://github.com/netty/netty/issues/4020 + channel.close(); + } + + return added; + } + + @Override + public boolean remove(Object o) { + Channel c = null; + if (o instanceof ChannelId) { + c = nonServerChannels.remove(o); + if (c == null) { + c = serverChannels.remove(o); + } + } else if (o instanceof Channel) { + c = (Channel) o; + if (c instanceof ServerChannel) { + c = serverChannels.remove(c.id()); + } else { + c = nonServerChannels.remove(c.id()); + } + } + + if (c == null) { + return false; + } + + c.closeFuture().removeListener(remover); + return true; + } + + @Override + public void clear() { + nonServerChannels.clear(); + serverChannels.clear(); + } + + @Override + public Iterator iterator() { + return new CombinedIterator( + serverChannels.values().iterator(), + nonServerChannels.values().iterator()); + } + + @Override + public Object[] toArray() { + Collection channels = new ArrayList(size()); + channels.addAll(serverChannels.values()); + channels.addAll(nonServerChannels.values()); + return channels.toArray(); + } + + @Override + public T[] toArray(T[] a) { + Collection channels = new ArrayList(size()); + channels.addAll(serverChannels.values()); + channels.addAll(nonServerChannels.values()); + return channels.toArray(a); + } + + @Override + public ChannelGroupFuture close() { + return close(ChannelMatchers.all()); + } + + @Override + public ChannelGroupFuture disconnect() { + return disconnect(ChannelMatchers.all()); + } + + @Override + public ChannelGroupFuture deregister() { + return deregister(ChannelMatchers.all()); + } + + @Override + public ChannelGroupFuture write(Object message) { + return write(message, ChannelMatchers.all()); + } + + // Create a safe duplicate of the message to write it to a channel but not affect other writes. 
+ // See https://github.com/netty/netty/issues/1461 + private static Object safeDuplicate(Object message) { + if (message instanceof ByteBuf) { + return ((ByteBuf) message).retainedDuplicate(); + } else if (message instanceof ByteBufHolder) { + return ((ByteBufHolder) message).retainedDuplicate(); + } else { + return ReferenceCountUtil.retain(message); + } + } + + @Override + public ChannelGroupFuture write(Object message, ChannelMatcher matcher) { + return write(message, matcher, false); + } + + @Override + public ChannelGroupFuture write(Object message, ChannelMatcher matcher, boolean voidPromise) { + ObjectUtil.checkNotNull(message, "message"); + ObjectUtil.checkNotNull(matcher, "matcher"); + + final ChannelGroupFuture future; + if (voidPromise) { + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + c.write(safeDuplicate(message), c.voidPromise()); + } + } + future = voidFuture; + } else { + Map futures = new LinkedHashMap(nonServerChannels.size()); + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.write(safeDuplicate(message))); + } + } + future = new DefaultChannelGroupFuture(this, futures, executor); + } + ReferenceCountUtil.release(message); + return future; + } + + @Override + public ChannelGroup flush() { + return flush(ChannelMatchers.all()); + } + + @Override + public ChannelGroupFuture flushAndWrite(Object message) { + return writeAndFlush(message); + } + + @Override + public ChannelGroupFuture writeAndFlush(Object message) { + return writeAndFlush(message, ChannelMatchers.all()); + } + + @Override + public ChannelGroupFuture disconnect(ChannelMatcher matcher) { + ObjectUtil.checkNotNull(matcher, "matcher"); + + Map futures = + new LinkedHashMap(size()); + + for (Channel c: serverChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.disconnect()); + } + } + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.disconnect()); + } + } 
+ + return new DefaultChannelGroupFuture(this, futures, executor); + } + + @Override + public ChannelGroupFuture close(ChannelMatcher matcher) { + ObjectUtil.checkNotNull(matcher, "matcher"); + + Map futures = + new LinkedHashMap(size()); + + if (stayClosed) { + // It is important to set the closed to true, before closing channels. + // Our invariants are: + // closed=true happens-before ChannelGroup.close() + // ChannelGroup.add() happens-before checking closed==true + // + // See https://github.com/netty/netty/issues/4020 + closed = true; + } + + for (Channel c: serverChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.close()); + } + } + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.close()); + } + } + + return new DefaultChannelGroupFuture(this, futures, executor); + } + + @Override + public ChannelGroupFuture deregister(ChannelMatcher matcher) { + ObjectUtil.checkNotNull(matcher, "matcher"); + + Map futures = + new LinkedHashMap(size()); + + for (Channel c: serverChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.deregister()); + } + } + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.deregister()); + } + } + + return new DefaultChannelGroupFuture(this, futures, executor); + } + + @Override + public ChannelGroup flush(ChannelMatcher matcher) { + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + c.flush(); + } + } + return this; + } + + @Override + public ChannelGroupFuture flushAndWrite(Object message, ChannelMatcher matcher) { + return writeAndFlush(message, matcher); + } + + @Override + public ChannelGroupFuture writeAndFlush(Object message, ChannelMatcher matcher) { + return writeAndFlush(message, matcher, false); + } + + @Override + public ChannelGroupFuture writeAndFlush(Object message, ChannelMatcher matcher, boolean voidPromise) { + ObjectUtil.checkNotNull(message, "message"); + + final ChannelGroupFuture 
future; + if (voidPromise) { + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + c.writeAndFlush(safeDuplicate(message), c.voidPromise()); + } + } + future = voidFuture; + } else { + Map futures = new LinkedHashMap(nonServerChannels.size()); + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.writeAndFlush(safeDuplicate(message))); + } + } + future = new DefaultChannelGroupFuture(this, futures, executor); + } + ReferenceCountUtil.release(message); + return future; + } + + @Override + public ChannelGroupFuture newCloseFuture() { + return newCloseFuture(ChannelMatchers.all()); + } + + @Override + public ChannelGroupFuture newCloseFuture(ChannelMatcher matcher) { + Map futures = + new LinkedHashMap(size()); + + for (Channel c: serverChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.closeFuture()); + } + } + for (Channel c: nonServerChannels.values()) { + if (matcher.matches(c)) { + futures.put(c, c.closeFuture()); + } + } + + return new DefaultChannelGroupFuture(this, futures, executor); + } + + @Override + public int hashCode() { + return System.identityHashCode(this); + } + + @Override + public boolean equals(Object o) { + return this == o; + } + + @Override + public int compareTo(ChannelGroup o) { + int v = name().compareTo(o.name()); + if (v != 0) { + return v; + } + + return System.identityHashCode(this) - System.identityHashCode(o); + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(name: " + name() + ", size: " + size() + ')'; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/group/DefaultChannelGroupFuture.java b/netty-channel/src/main/java/io/netty/channel/group/DefaultChannelGroupFuture.java new file mode 100644 index 0000000..4b1a051 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/group/DefaultChannelGroupFuture.java @@ -0,0 +1,259 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty 
Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.group;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.util.concurrent.BlockingOperationException;
import io.netty.util.concurrent.DefaultPromise;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.ImmediateEventExecutor;
import io.netty.util.internal.ObjectUtil;

import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;


/**
 * The default {@link ChannelGroupFuture} implementation.
 */
final class DefaultChannelGroupFuture extends DefaultPromise<Void> implements ChannelGroupFuture {

    private final ChannelGroup group;
    // Unmodifiable snapshot of the per-channel futures aggregated by this future.
    private final Map<Channel, ChannelFuture> futures;
    private int successCount;
    private int failureCount;

    // Tallies child completions; completes this future once all children are done.
    private final ChannelFutureListener childListener = new ChannelFutureListener() {
        @Override
        public void operationComplete(ChannelFuture future) throws Exception {
            boolean success = future.isSuccess();
            boolean callSetDone;
            synchronized (DefaultChannelGroupFuture.this) {
                if (success) {
                    successCount ++;
                } else {
                    failureCount ++;
                }

                callSetDone = successCount + failureCount == futures.size();
                assert successCount + failureCount <= futures.size();
            }

            if (callSetDone) {
                if (failureCount > 0) {
                    List<Map.Entry<Channel, Throwable>> failed =
                            new ArrayList<Map.Entry<Channel, Throwable>>(failureCount);
                    for (ChannelFuture f: futures.values()) {
                        if (!f.isSuccess()) {
                            failed.add(new DefaultEntry<Channel, Throwable>(f.channel(), f.cause()));
                        }
                    }
                    setFailure0(new ChannelGroupException(failed));
                } else {
                    setSuccess0();
                }
            }
        }
    };

    /**
     * Creates a new instance.
     */
    DefaultChannelGroupFuture(ChannelGroup group, Collection<ChannelFuture> futures, EventExecutor executor) {
        super(executor);
        this.group = ObjectUtil.checkNotNull(group, "group");
        ObjectUtil.checkNotNull(futures, "futures");

        Map<Channel, ChannelFuture> futureMap = new LinkedHashMap<Channel, ChannelFuture>();
        for (ChannelFuture f: futures) {
            futureMap.put(f.channel(), f);
        }

        this.futures = Collections.unmodifiableMap(futureMap);

        for (ChannelFuture f: this.futures.values()) {
            f.addListener(childListener);
        }

        // Done on arrival?
        if (this.futures.isEmpty()) {
            setSuccess0();
        }
    }

    DefaultChannelGroupFuture(ChannelGroup group, Map<Channel, ChannelFuture> futures, EventExecutor executor) {
        super(executor);
        this.group = group;
        this.futures = Collections.unmodifiableMap(futures);
        for (ChannelFuture f: this.futures.values()) {
            f.addListener(childListener);
        }

        // Done on arrival?
        if (this.futures.isEmpty()) {
            setSuccess0();
        }
    }

    @Override
    public ChannelGroup group() {
        return group;
    }

    @Override
    public ChannelFuture find(Channel channel) {
        return futures.get(channel);
    }

    @Override
    public Iterator<ChannelFuture> iterator() {
        return futures.values().iterator();
    }

    @Override
    public synchronized boolean isPartialSuccess() {
        return successCount != 0 && successCount != futures.size();
    }

    @Override
    public synchronized boolean isPartialFailure() {
        return failureCount != 0 && failureCount != futures.size();
    }

    @Override
    public DefaultChannelGroupFuture addListener(GenericFutureListener<? extends Future<? super Void>> listener) {
        super.addListener(listener);
        return this;
    }

    @Override
    public DefaultChannelGroupFuture addListeners(GenericFutureListener<? extends Future<? super Void>>... listeners) {
        super.addListeners(listeners);
        return this;
    }

    @Override
    public DefaultChannelGroupFuture removeListener(GenericFutureListener<? extends Future<? super Void>> listener) {
        super.removeListener(listener);
        return this;
    }

    @Override
    public DefaultChannelGroupFuture removeListeners(
            GenericFutureListener<? extends Future<? super Void>>... listeners) {
        super.removeListeners(listeners);
        return this;
    }

    @Override
    public DefaultChannelGroupFuture await() throws InterruptedException {
        super.await();
        return this;
    }

    @Override
    public DefaultChannelGroupFuture awaitUninterruptibly() {
        super.awaitUninterruptibly();
        return this;
    }

    @Override
    public DefaultChannelGroupFuture syncUninterruptibly() {
        super.syncUninterruptibly();
        return this;
    }

    @Override
    public DefaultChannelGroupFuture sync() throws InterruptedException {
        super.sync();
        return this;
    }

    @Override
    public ChannelGroupException cause() {
        return (ChannelGroupException) super.cause();
    }

    private void setSuccess0() {
        super.setSuccess(null);
    }

    private void setFailure0(ChannelGroupException cause) {
        super.setFailure(cause);
    }

    // This future is completed only by its childListener; external completion is a bug.
    @Override
    public DefaultChannelGroupFuture setSuccess(Void result) {
        throw new IllegalStateException();
    }

    @Override
    public boolean trySuccess(Void result) {
        throw new IllegalStateException();
    }

    @Override
    public DefaultChannelGroupFuture setFailure(Throwable cause) {
        throw new IllegalStateException();
    }

    @Override
    public boolean tryFailure(Throwable cause) {
        throw new IllegalStateException();
    }

    @Override
    protected void checkDeadLock() {
        EventExecutor e = executor();
        if (e != null && e != ImmediateEventExecutor.INSTANCE && e.inEventLoop()) {
            throw new BlockingOperationException();
        }
    }

    /** Immutable {@link Map.Entry} used to report per-channel failures. */
    private static final class DefaultEntry<K, V> implements Map.Entry<K, V> {
        private final K key;
        private final V value;

        DefaultEntry(K key, V value) {
            this.key = key;
            this.value = value;
        }

        @Override
        public K getKey() {
            return key;
        }

        @Override
        public V getValue() {
            return value;
        }

        @Override
        public V setValue(V value) {
            throw new UnsupportedOperationException("read-only");
        }
    }
}
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.group;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import java.util.Collections;
import java.util.Iterator;
import java.util.concurrent.TimeUnit;

/**
 * A "void" {@link ChannelGroupFuture} that carries no completion state: it is never
 * successful, never failed and never done. Any attempt to add a listener or to block
 * on it is rejected with an {@link IllegalStateException}, mirroring the semantics of
 * a void {@link ChannelFuture}.
 */
// Suppress a warning about returning the same iterator since it always returns an empty iterator
final class VoidChannelGroupFuture implements ChannelGroupFuture {

    // Sharing a single iterator instance is safe here: an empty iterator has no
    // exhaustible state, so every caller observes the same (empty) iteration.
    private static final Iterator<ChannelFuture> EMPTY =
            Collections.<ChannelFuture>emptyList().iterator();
    private final ChannelGroup group;

    VoidChannelGroupFuture(ChannelGroup group) {
        this.group = group;
    }

    @Override
    public ChannelGroup group() {
        return group;
    }

    /**
     * Always {@code null}: a void future tracks no per-channel futures.
     */
    @Override
    public ChannelFuture find(Channel channel) {
        return null;
    }

    @Override
    public boolean isSuccess() {
        return false;
    }

    @Override
    public ChannelGroupException cause() {
        return null;
    }

    @Override
    public boolean isPartialSuccess() {
        return false;
    }

    @Override
    public boolean isPartialFailure() {
        return false;
    }

    @Override
    public ChannelGroupFuture addListener(GenericFutureListener<? extends Future<? super Void>> listener) {
        throw reject();
    }

    @Override
    public ChannelGroupFuture addListeners(GenericFutureListener<? extends Future<? super Void>>... listeners) {
        throw reject();
    }

    @Override
    public ChannelGroupFuture removeListener(GenericFutureListener<? extends Future<? super Void>> listener) {
        throw reject();
    }

    @Override
    public ChannelGroupFuture removeListeners(GenericFutureListener<? extends Future<? super Void>>... listeners) {
        throw reject();
    }

    @Override
    public ChannelGroupFuture await() {
        throw reject();
    }

    @Override
    public ChannelGroupFuture awaitUninterruptibly() {
        throw reject();
    }

    @Override
    public ChannelGroupFuture syncUninterruptibly() {
        throw reject();
    }

    @Override
    public ChannelGroupFuture sync() {
        throw reject();
    }

    @Override
    public Iterator<ChannelFuture> iterator() {
        return EMPTY;
    }

    @Override
    public boolean isCancellable() {
        return false;
    }

    @Override
    public boolean await(long timeout, TimeUnit unit) {
        throw reject();
    }

    @Override
    public boolean await(long timeoutMillis) {
        throw reject();
    }

    @Override
    public boolean awaitUninterruptibly(long timeout, TimeUnit unit) {
        throw reject();
    }

    @Override
    public boolean awaitUninterruptibly(long timeoutMillis) {
        throw reject();
    }

    @Override
    public Void getNow() {
        return null;
    }

    /**
     * {@inheritDoc}
     *
     * @param mayInterruptIfRunning this value has no effect in this implementation.
     */
    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        return false;
    }

    @Override
    public boolean isCancelled() {
        return false;
    }

    @Override
    public boolean isDone() {
        return false;
    }

    @Override
    public Void get() {
        throw reject();
    }

    @Override
    public Void get(long timeout, TimeUnit unit) {
        throw reject();
    }

    private static RuntimeException reject() {
        return new IllegalStateException("void future");
    }
}
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.internal;

/**
 * Internal constants shared by channel transport implementations.
 */
public final class ChannelUtils {

    /**
     * Lower threshold (in bytes) below which the attempted gathering-write size
     * is not shrunk any further.
     */
    public static final int MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD = 4096;

    /**
     * Sentinel write status signalling that the socket send buffer is full.
     */
    public static final int WRITE_STATUS_SNDBUF_FULL = Integer.MAX_VALUE;

    private ChannelUtils() {
        // Static-constant holder; never instantiated.
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.local;

import static io.netty.util.internal.ObjectUtil.checkNonEmptyAfterTrim;

import io.netty.channel.Channel;

import java.net.SocketAddress;
import java.util.UUID;

/**
 * An endpoint in the local transport. Each endpoint is identified by a unique
 * case-insensitive string.
 */
public final class LocalAddress extends SocketAddress implements Comparable<LocalAddress> {

    private static final long serialVersionUID = 4644331421130916435L;

    /** Wildcard address: binding to it allocates an ephemeral address. */
    public static final LocalAddress ANY = new LocalAddress("ANY");

    private final String id;
    // Cached "local:<id>" form so toString() allocates nothing.
    private final String strVal;

    /**
     * Creates a new ephemeral port based on the ID of the specified channel.
     * Note that we prepend an upper-case character so that it never conflicts with
     * the addresses created by a user, which are always lower-cased on construction time.
     */
    LocalAddress(Channel channel) {
        StringBuilder buf = new StringBuilder(16);
        buf.append("local:E");
        // OR-ing 0x100000000L guarantees a fixed-width 9-digit hex string, so the
        // setCharAt(7, ':') below always lands just after the leading digit.
        buf.append(Long.toHexString(channel.hashCode() & 0xFFFFFFFFL | 0x100000000L));
        buf.setCharAt(7, ':');
        id = buf.substring(6);
        strVal = buf.toString();
    }

    /**
     * Creates a new instance with the specified ID.
     *
     * @param id a non-empty identifier; stored lower-cased so lookups are
     *           case-insensitive
     */
    public LocalAddress(String id) {
        this.id = checkNonEmptyAfterTrim(id, "id").toLowerCase();
        strVal = "local:" + this.id;
    }

    /**
     * Creates a new instance with a random ID based on the given class.
     */
    public LocalAddress(Class<?> cls) {
        this(cls.getSimpleName() + '/' + UUID.randomUUID());
    }

    /**
     * Returns the ID of this address.
     */
    public String id() {
        return id;
    }

    @Override
    public int hashCode() {
        return id.hashCode();
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof LocalAddress)) {
            return false;
        }

        return id.equals(((LocalAddress) o).id);
    }

    @Override
    public int compareTo(LocalAddress o) {
        return id.compareTo(o.id);
    }

    @Override
    public String toString() {
        return strVal;
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.local;

import io.netty.buffer.ByteBuf;
import io.netty.channel.AbstractChannel;
import io.netty.channel.Channel;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelMetadata;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultChannelConfig;
import io.netty.channel.EventLoop;
import io.netty.channel.PreferHeapByteBufAllocator;
import io.netty.channel.RecvByteBufAllocator;
import io.netty.channel.SingleThreadEventLoop;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.SingleThreadEventExecutor;
import io.netty.util.internal.InternalThreadLocalMap;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.net.ConnectException;
import java.net.SocketAddress;
import java.nio.channels.AlreadyConnectedException;
import java.nio.channels.ClosedChannelException;
import java.nio.channels.ConnectionPendingException;
import java.nio.channels.NotYetConnectedException;
import java.util.Queue;
import java.util.concurrent.atomic.AtomicReferenceFieldUpdater;

/**
 * A {@link Channel} for the local transport.
 */
public class LocalChannel extends AbstractChannel {
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(LocalChannel.class);
    @SuppressWarnings({ "rawtypes" })
    private static final AtomicReferenceFieldUpdater<LocalChannel, Future> FINISH_READ_FUTURE_UPDATER =
            AtomicReferenceFieldUpdater.newUpdater(LocalChannel.class, Future.class, "finishReadFuture");
    private static final ChannelMetadata METADATA = new ChannelMetadata(false);
    // Bounds re-entrant reads triggered from within channelRead, to avoid StackOverflowError.
    private static final int MAX_READER_STACK_DEPTH = 8;

    private enum State { OPEN, BOUND, CONNECTED, CLOSED }

    private final ChannelConfig config = new DefaultChannelConfig(this);
    // To further optimize this we could write our own SPSC queue.
    final Queue<Object> inboundBuffer = PlatformDependent.newSpscQueue();
    private final Runnable readTask = new Runnable() {
        @Override
        public void run() {
            // Ensure the inboundBuffer is not empty as readInbound() will always call fireChannelReadComplete()
            if (!inboundBuffer.isEmpty()) {
                readInbound();
            }
        }
    };

    private final Runnable shutdownHook = new Runnable() {
        @Override
        public void run() {
            unsafe().close(unsafe().voidPromise());
        }
    };

    private volatile State state;
    private volatile LocalChannel peer;
    private volatile LocalAddress localAddress;
    private volatile LocalAddress remoteAddress;
    private volatile ChannelPromise connectPromise;
    private volatile boolean readInProgress;
    private volatile boolean writeInProgress;
    private volatile Future<?> finishReadFuture;

    public LocalChannel() {
        super(null);
        config().setAllocator(new PreferHeapByteBufAllocator(config.getAllocator()));
    }

    protected LocalChannel(LocalServerChannel parent, LocalChannel peer) {
        super(parent);
        config().setAllocator(new PreferHeapByteBufAllocator(config.getAllocator()));
        this.peer = peer;
        localAddress = parent.localAddress();
        remoteAddress = peer.localAddress();
    }

    @Override
    public ChannelMetadata metadata() {
        return METADATA;
    }

    @Override
    public ChannelConfig config() {
        return config;
    }

    @Override
    public LocalServerChannel parent() {
        return (LocalServerChannel) super.parent();
    }

    @Override
    public LocalAddress localAddress() {
        return (LocalAddress) super.localAddress();
    }

    @Override
    public LocalAddress remoteAddress() {
        return (LocalAddress) super.remoteAddress();
    }

    @Override
    public boolean isOpen() {
        return state != State.CLOSED;
    }

    @Override
    public boolean isActive() {
        return state == State.CONNECTED;
    }

    @Override
    protected AbstractUnsafe newUnsafe() {
        return new LocalUnsafe();
    }

    @Override
    protected boolean isCompatible(EventLoop loop) {
        return loop instanceof SingleThreadEventLoop;
    }

    @Override
    protected SocketAddress localAddress0() {
        return localAddress;
    }

    @Override
    protected SocketAddress remoteAddress0() {
        return remoteAddress;
    }

    @Override
    protected void doRegister() throws Exception {
        // Check if both peer and parent are non-null because this channel was created by a LocalServerChannel.
        // This is needed as a peer may not be null also if a LocalChannel was connected before and
        // deregistered / registered later again.
        //
        // See https://github.com/netty/netty/issues/2400
        if (peer != null && parent() != null) {
            // Store the peer in a local variable as it may be set to null if doClose() is called.
            // See https://github.com/netty/netty/issues/2144
            final LocalChannel peer = this.peer;
            state = State.CONNECTED;

            peer.remoteAddress = parent() == null ? null : parent().localAddress();
            peer.state = State.CONNECTED;

            // Always call peer.eventLoop().execute() even if peer.eventLoop().inEventLoop() is true.
            // This ensures that if both channels are on the same event loop, the peer's channelActive
            // event is triggered *after* this channel's channelRegistered event, so that this channel's
            // pipeline is fully initialized by ChannelInitializer before any channelRead events.
            peer.eventLoop().execute(new Runnable() {
                @Override
                public void run() {
                    ChannelPromise promise = peer.connectPromise;

                    // Only trigger fireChannelActive() if the promise was not null and was not completed yet.
                    // connectPromise may be set to null if doClose() was called in the meantime.
                    if (promise != null && promise.trySuccess()) {
                        peer.pipeline().fireChannelActive();
                    }
                }
            });
        }
        ((SingleThreadEventExecutor) eventLoop()).addShutdownHook(shutdownHook);
    }

    @Override
    protected void doBind(SocketAddress localAddress) throws Exception {
        this.localAddress =
                LocalChannelRegistry.register(this, this.localAddress,
                        localAddress);
        state = State.BOUND;
    }

    @Override
    protected void doDisconnect() throws Exception {
        doClose();
    }

    @Override
    protected void doClose() throws Exception {
        final LocalChannel peer = this.peer;
        State oldState = state;
        try {
            if (oldState != State.CLOSED) {
                // Update all internal state before the closeFuture is notified.
                if (localAddress != null) {
                    if (parent() == null) {
                        LocalChannelRegistry.unregister(localAddress);
                    }
                    localAddress = null;
                }

                // State change must happen before finishPeerRead to ensure writes are released either in doWrite or
                // channelRead.
                state = State.CLOSED;

                // Preserve order of event and force a read operation now before the close operation is processed.
                if (writeInProgress && peer != null) {
                    finishPeerRead(peer);
                }

                ChannelPromise promise = connectPromise;
                if (promise != null) {
                    // Use tryFailure() instead of setFailure() to avoid the race against cancel().
                    promise.tryFailure(new ClosedChannelException());
                    connectPromise = null;
                }
            }

            if (peer != null) {
                this.peer = null;
                // Always call peer.eventLoop().execute() even if peer.eventLoop().inEventLoop() is true.
                // This ensures that if both channels are on the same event loop, the peer's channelInActive
                // event is triggered *after* this peer's channelInActive event
                EventLoop peerEventLoop = peer.eventLoop();
                final boolean peerIsActive = peer.isActive();
                try {
                    peerEventLoop.execute(new Runnable() {
                        @Override
                        public void run() {
                            peer.tryClose(peerIsActive);
                        }
                    });
                } catch (Throwable cause) {
                    logger.warn("Releasing Inbound Queues for channels {}-{} because exception occurred!",
                            this, peer, cause);
                    if (peerEventLoop.inEventLoop()) {
                        peer.releaseInboundBuffers();
                    } else {
                        // inboundBuffers is a SPSC so we may leak if the event loop is shutdown prematurely or
                        // rejects the close Runnable but give a best effort.
                        peer.close();
                    }
                    PlatformDependent.throwException(cause);
                }
            }
        } finally {
            // Release all buffers if the Channel was already registered in the past and if it was not closed before.
            if (oldState != null && oldState != State.CLOSED) {
                // We need to release all the buffers that may be put into our inbound queue since we closed the Channel
                // to ensure we not leak any memory. This is fine as it basically gives the same guarantees as TCP which
                // means even if the promise was notified before its not really guaranteed that the "remote peer" will
                // see the buffer at all.
                releaseInboundBuffers();
            }
        }
    }

    private void tryClose(boolean isActive) {
        if (isActive) {
            unsafe().close(unsafe().voidPromise());
        } else {
            releaseInboundBuffers();
        }
    }

    @Override
    protected void doDeregister() throws Exception {
        // Just remove the shutdownHook as this Channel may be closed later or registered to another EventLoop
        ((SingleThreadEventExecutor) eventLoop()).removeShutdownHook(shutdownHook);
    }

    /**
     * Drains the inbound queue into the pipeline, opportunistically coalescing
     * adjacent small {@link ByteBuf}s into one allocation to reduce per-message
     * pipeline traversals.
     */
    private void readInbound() {
        RecvByteBufAllocator.Handle handle = unsafe().recvBufAllocHandle();
        handle.reset(config());
        ChannelPipeline pipeline = pipeline();
        do {
            Object received = inboundBuffer.poll();
            if (received == null) {
                break;
            }
            if (received instanceof ByteBuf && inboundBuffer.peek() instanceof ByteBuf) {
                ByteBuf msg = (ByteBuf) received;
                ByteBuf output = handle.allocate(alloc());
                if (msg.readableBytes() < output.writableBytes()) {
                    // We have an opportunity to coalesce buffers.
                    output.writeBytes(msg, msg.readerIndex(), msg.readableBytes());
                    msg.release();
                    while ((received = inboundBuffer.peek()) instanceof ByteBuf &&
                            ((ByteBuf) received).readableBytes() < output.writableBytes()) {
                        inboundBuffer.poll();
                        msg = (ByteBuf) received;
                        output.writeBytes(msg, msg.readerIndex(), msg.readableBytes());
                        msg.release();
                    }
                    handle.lastBytesRead(output.readableBytes());
                    received = output; // Send the coalesced buffer down the pipeline.
                } else {
                    // It won't be profitable to coalesce buffers this time around.
                    handle.lastBytesRead(output.capacity());
                    output.release();
                }
            }
            handle.incMessagesRead(1);
            pipeline.fireChannelRead(received);
        } while (handle.continueReading());
        handle.readComplete();
        pipeline.fireChannelReadComplete();
    }

    @Override
    protected void doBeginRead() throws Exception {
        if (readInProgress) {
            return;
        }

        Queue<Object> inboundBuffer = this.inboundBuffer;
        if (inboundBuffer.isEmpty()) {
            readInProgress = true;
            return;
        }

        final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get();
        final int stackDepth = threadLocals.localChannelReaderStackDepth();
        if (stackDepth < MAX_READER_STACK_DEPTH) {
            threadLocals.setLocalChannelReaderStackDepth(stackDepth + 1);
            try {
                readInbound();
            } finally {
                threadLocals.setLocalChannelReaderStackDepth(stackDepth);
            }
        } else {
            try {
                eventLoop().execute(readTask);
            } catch (Throwable cause) {
                logger.warn("Closing Local channels {}-{} because exception occurred!", this, peer, cause);
                close();
                peer.close();
                PlatformDependent.throwException(cause);
            }
        }
    }

    @Override
    protected void doWrite(ChannelOutboundBuffer in) throws Exception {
        switch (state) {
            case OPEN:
            case BOUND:
                throw new NotYetConnectedException();
            case CLOSED:
                throw new ClosedChannelException();
            case CONNECTED:
                break;
        }

        final LocalChannel peer = this.peer;

        writeInProgress = true;
        try {
            ClosedChannelException exception = null;
            for (;;) {
                Object msg = in.current();
                if (msg == null) {
                    break;
                }
                try {
                    // It is possible the peer could have closed while we are writing, and in this case we should
                    // simulate real socket behavior and ensure the write operation is failed.
                    if (peer.state == State.CONNECTED) {
                        peer.inboundBuffer.add(ReferenceCountUtil.retain(msg));
                        in.remove();
                    } else {
                        if (exception == null) {
                            exception = new ClosedChannelException();
                        }
                        in.remove(exception);
                    }
                } catch (Throwable cause) {
                    in.remove(cause);
                }
            }
        } finally {
            // The following situation may cause trouble:
            // 1. Write (with promise X)
            // 2. promise X is completed when in.remove() is called, and a listener on this promise calls close()
            // 3. Then the close event will be executed for the peer before the write events, when the write events
            // actually happened before the close event.
            writeInProgress = false;
        }

        finishPeerRead(peer);
    }

    private void finishPeerRead(final LocalChannel peer) {
        // If the peer is also writing, then we must schedule the event on the event loop to preserve read order.
        if (peer.eventLoop() == eventLoop() && !peer.writeInProgress) {
            finishPeerRead0(peer);
        } else {
            runFinishPeerReadTask(peer);
        }
    }

    private void runFinishPeerReadTask(final LocalChannel peer) {
        // If the peer is writing, we must wait until after reads are completed for that peer before we can read. So
        // we keep track of the task, and coordinate later that our read can't happen until the peer is done.
        final Runnable finishPeerReadTask = new Runnable() {
            @Override
            public void run() {
                finishPeerRead0(peer);
            }
        };
        try {
            if (peer.writeInProgress) {
                peer.finishReadFuture = peer.eventLoop().submit(finishPeerReadTask);
            } else {
                peer.eventLoop().execute(finishPeerReadTask);
            }
        } catch (Throwable cause) {
            logger.warn("Closing Local channels {}-{} because exception occurred!", this, peer, cause);
            close();
            peer.close();
            PlatformDependent.throwException(cause);
        }
    }

    private void releaseInboundBuffers() {
        assert eventLoop() == null || eventLoop().inEventLoop();
        readInProgress = false;
        Queue<Object> inboundBuffer = this.inboundBuffer;
        Object msg;
        while ((msg = inboundBuffer.poll()) != null) {
            ReferenceCountUtil.release(msg);
        }
    }

    private void finishPeerRead0(LocalChannel peer) {
        Future<?> peerFinishReadFuture = peer.finishReadFuture;
        if (peerFinishReadFuture != null) {
            if (!peerFinishReadFuture.isDone()) {
                runFinishPeerReadTask(peer);
                return;
            } else {
                // Lazy unset to make sure we don't prematurely unset it while scheduling a new task.
                FINISH_READ_FUTURE_UPDATER.compareAndSet(peer, peerFinishReadFuture, null);
            }
        }
        // We should only set readInProgress to false if there is any data that was read as otherwise we may miss to
        // forward data later on.
        if (peer.readInProgress && !peer.inboundBuffer.isEmpty()) {
            peer.readInProgress = false;
            peer.readInbound();
        }
    }

    private class LocalUnsafe extends AbstractUnsafe {

        @Override
        public void connect(final SocketAddress remoteAddress,
                SocketAddress localAddress, final ChannelPromise promise) {
            if (!promise.setUncancellable() || !ensureOpen(promise)) {
                return;
            }

            if (state == State.CONNECTED) {
                Exception cause = new AlreadyConnectedException();
                safeSetFailure(promise, cause);
                pipeline().fireExceptionCaught(cause);
                return;
            }

            if (connectPromise != null) {
                throw new ConnectionPendingException();
            }

            connectPromise = promise;

            if (state != State.BOUND) {
                // Not bound yet and no localAddress specified - get one.
                if (localAddress == null) {
                    localAddress = new LocalAddress(LocalChannel.this);
                }
            }

            if (localAddress != null) {
                try {
                    doBind(localAddress);
                } catch (Throwable t) {
                    safeSetFailure(promise, t);
                    close(voidPromise());
                    return;
                }
            }

            Channel boundChannel = LocalChannelRegistry.get(remoteAddress);
            if (!(boundChannel instanceof LocalServerChannel)) {
                Exception cause = new ConnectException("connection refused: " + remoteAddress);
                safeSetFailure(promise, cause);
                close(voidPromise());
                return;
            }

            LocalServerChannel serverChannel = (LocalServerChannel) boundChannel;
            peer = serverChannel.serve(LocalChannel.this);
        }
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.local;

import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.StringUtil;

import java.net.SocketAddress;
import java.util.concurrent.ConcurrentMap;

/**
 * Process-wide registry mapping bound {@link LocalAddress}es to their
 * {@link Channel}s so local-transport connects can locate their server peer.
 */
final class LocalChannelRegistry {

    private static final ConcurrentMap<LocalAddress, Channel> boundChannels =
            PlatformDependent.newConcurrentHashMap();

    /**
     * Binds {@code channel} to {@code localAddress} and returns the effective address.
     *
     * @param oldLocalAddress must be {@code null}; a non-null value means the channel
     *                        is already bound
     * @throws ChannelException if already bound, if the address type is not
     *                          {@link LocalAddress}, or if the address is in use
     */
    static LocalAddress register(
            Channel channel, LocalAddress oldLocalAddress, SocketAddress localAddress) {
        if (oldLocalAddress != null) {
            throw new ChannelException("already bound");
        }
        if (!(localAddress instanceof LocalAddress)) {
            throw new ChannelException("unsupported address type: " + StringUtil.simpleClassName(localAddress));
        }

        LocalAddress addr = (LocalAddress) localAddress;
        if (LocalAddress.ANY.equals(addr)) {
            // Wildcard bind: allocate an ephemeral address derived from the channel.
            addr = new LocalAddress(channel);
        }

        // putIfAbsent makes the bind atomic: the first binder wins, later ones fail.
        Channel boundChannel = boundChannels.putIfAbsent(addr, channel);
        if (boundChannel != null) {
            throw new ChannelException("address already in use by: " + boundChannel);
        }
        return addr;
    }

    static Channel get(SocketAddress localAddress) {
        return boundChannels.get(localAddress);
    }

    static void unregister(LocalAddress localAddress) {
        boundChannels.remove(localAddress);
    }

    private LocalChannelRegistry() {
        // Unused
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.local;

import io.netty.channel.DefaultEventLoopGroup;

import java.util.concurrent.ThreadFactory;

/**
 * Legacy alias kept for source compatibility; every constructor simply
 * delegates to {@link DefaultEventLoopGroup}.
 *
 * @deprecated Use {@link DefaultEventLoopGroup} instead.
 */
@Deprecated
public class LocalEventLoopGroup extends DefaultEventLoopGroup {

    /**
     * Creates a group with the default number of threads.
     */
    public LocalEventLoopGroup() { }

    /**
     * Creates a group with the given thread count.
     *
     * @param nThreads the number of threads to use
     */
    public LocalEventLoopGroup(int nThreads) {
        super(nThreads);
    }

    /**
     * Creates a group with the default number of threads built by the
     * supplied factory.
     *
     * @param threadFactory the {@link ThreadFactory} or {@code null} to use the default
     */
    public LocalEventLoopGroup(ThreadFactory threadFactory) {
        // 0 requests the default thread count from the superclass.
        super(0, threadFactory);
    }

    /**
     * Creates a group with the given thread count built by the supplied factory.
     *
     * @param nThreads the number of threads to use
     * @param threadFactory the {@link ThreadFactory} or {@code null} to use the default
     */
    public LocalEventLoopGroup(int nThreads, ThreadFactory threadFactory) {
        super(nThreads, threadFactory);
    }
}
+ */ +package io.netty.channel.local; + +import io.netty.channel.AbstractServerChannel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.EventLoop; +import io.netty.channel.PreferHeapByteBufAllocator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.ServerChannel; +import io.netty.channel.ServerChannelRecvByteBufAllocator; +import io.netty.channel.SingleThreadEventLoop; +import io.netty.util.concurrent.SingleThreadEventExecutor; + +import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * A {@link ServerChannel} for the local transport which allows in VM communication. + */ +public class LocalServerChannel extends AbstractServerChannel { + + private final ChannelConfig config = + new DefaultChannelConfig(this, new ServerChannelRecvByteBufAllocator()) { }; + private final Queue inboundBuffer = new ArrayDeque(); + private final Runnable shutdownHook = new Runnable() { + @Override + public void run() { + unsafe().close(unsafe().voidPromise()); + } + }; + + private volatile int state; // 0 - open, 1 - active, 2 - closed + private volatile LocalAddress localAddress; + private volatile boolean acceptInProgress; + + public LocalServerChannel() { + config().setAllocator(new PreferHeapByteBufAllocator(config.getAllocator())); + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public LocalAddress localAddress() { + return (LocalAddress) super.localAddress(); + } + + @Override + public LocalAddress remoteAddress() { + return (LocalAddress) super.remoteAddress(); + } + + @Override + public boolean isOpen() { + return state < 2; + } + + @Override + public boolean isActive() { + return state == 1; + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return loop instanceof SingleThreadEventLoop; + } + + @Override + protected SocketAddress localAddress0() { + 
return localAddress;
+    }
+
+    @Override
+    protected void doRegister() throws Exception {
+        // Tie the server channel's lifetime to its event loop: the hook closes this
+        // channel when the loop shuts down.
+        ((SingleThreadEventExecutor) eventLoop()).addShutdownHook(shutdownHook);
+    }
+
+    @Override
+    protected void doBind(SocketAddress localAddress) throws Exception {
+        // Publish this channel in the JVM-wide registry; state 1 == active (see the state field).
+        this.localAddress = LocalChannelRegistry.register(this, this.localAddress, localAddress);
+        state = 1;
+    }
+
+    @Override
+    protected void doClose() throws Exception {
+        if (state <= 1) {
+            // Update all internal state before the closeFuture is notified.
+            if (localAddress != null) {
+                LocalChannelRegistry.unregister(localAddress);
+                localAddress = null;
+            }
+            state = 2;
+        }
+    }
+
+    @Override
+    protected void doDeregister() throws Exception {
+        ((SingleThreadEventExecutor) eventLoop()).removeShutdownHook(shutdownHook);
+    }
+
+    @Override
+    protected void doBeginRead() throws Exception {
+        if (acceptInProgress) {
+            return;
+        }
+
+        // Nothing queued yet: remember that a read was requested so serve0(...) can
+        // deliver the next accepted child immediately.
+        Queue inboundBuffer = this.inboundBuffer;
+        if (inboundBuffer.isEmpty()) {
+            acceptInProgress = true;
+            return;
+        }
+
+        readInbound();
+    }
+
+    // Accepts a new in-VM connection: creates the child channel for the given peer and
+    // hands it to serve0(...) on this channel's event loop.
+    LocalChannel serve(final LocalChannel peer) {
+        final LocalChannel child = newLocalChannel(peer);
+        if (eventLoop().inEventLoop()) {
+            serve0(child);
+        } else {
+            eventLoop().execute(new Runnable() {
+                @Override
+                public void run() {
+                    serve0(child);
+                }
+            });
+        }
+        return child;
+    }
+
+    // Drains queued accepted channels into the pipeline as channelRead events,
+    // bounded by the RecvByteBufAllocator handle's continueReading() policy.
+    private void readInbound() {
+        RecvByteBufAllocator.Handle handle = unsafe().recvBufAllocHandle();
+        handle.reset(config());
+        ChannelPipeline pipeline = pipeline();
+        do {
+            Object m = inboundBuffer.poll();
+            if (m == null) {
+                break;
+            }
+            pipeline.fireChannelRead(m);
+        } while (handle.continueReading());
+        handle.readComplete();
+        pipeline.fireChannelReadComplete();
+    }
+
+    /**
+     * A factory method for {@link LocalChannel}s. Users may override it
+     * to create custom instances of {@link LocalChannel}s.
+ */ + protected LocalChannel newLocalChannel(LocalChannel peer) { + return new LocalChannel(this, peer); + } + + private void serve0(final LocalChannel child) { + inboundBuffer.add(child); + if (acceptInProgress) { + acceptInProgress = false; + + readInbound(); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/local/package-info.java b/netty-channel/src/main/java/io/netty/channel/local/package-info.java new file mode 100644 index 0000000..2956564 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/local/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * A virtual transport that enables the communication between the two + * parties in the same virtual machine. + */ +package io.netty.channel.local; diff --git a/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java b/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java new file mode 100644 index 0000000..bfa2bae --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioByteChannel.java @@ -0,0 +1,352 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.FileRegion; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.internal.ChannelUtils; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.channel.socket.SocketChannelConfig; +import io.netty.util.internal.StringUtil; + +import java.io.IOException; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; + +import static io.netty.channel.internal.ChannelUtils.WRITE_STATUS_SNDBUF_FULL; + +/** + * {@link AbstractNioChannel} base class for {@link Channel}s that operate on bytes. + */ +public abstract class AbstractNioByteChannel extends AbstractNioChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); + private static final String EXPECTED_TYPES = + " (expected: " + StringUtil.simpleClassName(ByteBuf.class) + ", " + + StringUtil.simpleClassName(FileRegion.class) + ')'; + + private final Runnable flushTask = new Runnable() { + @Override + public void run() { + // Calling flush0 directly to ensure we not try to flush messages that were added via write(...) in the + // meantime. 
+            ((AbstractNioUnsafe) unsafe()).flush0();
+        }
+    };
+    // Set once a read failed after input shutdown; further read-ready events are suppressed.
+    private boolean inputClosedSeenErrorOnRead;
+
+    /**
+     * Create a new instance
+     *
+     * @param parent            the parent {@link Channel} by which this instance was created. May be {@code null}
+     * @param ch                the underlying {@link SelectableChannel} on which it operates
+     */
+    protected AbstractNioByteChannel(Channel parent, SelectableChannel ch) {
+        super(parent, ch, SelectionKey.OP_READ);
+    }
+
+    /**
+     * Shutdown the input side of the channel.
+     */
+    protected abstract ChannelFuture shutdownInput();
+
+    protected boolean isInputShutdown0() {
+        return false;
+    }
+
+    @Override
+    protected AbstractNioUnsafe newUnsafe() {
+        return new NioByteUnsafe();
+    }
+
+    @Override
+    public ChannelMetadata metadata() {
+        return METADATA;
+    }
+
+    // True when the read loop must stop: input is shut down and either a read error was
+    // already seen or half-closure is not allowed by the config.
+    final boolean shouldBreakReadReady(ChannelConfig config) {
+        return isInputShutdown0() && (inputClosedSeenErrorOnRead || !isAllowHalfClosure(config));
+    }
+
+    private static boolean isAllowHalfClosure(ChannelConfig config) {
+        return config instanceof SocketChannelConfig &&
+                ((SocketChannelConfig) config).isAllowHalfClosure();
+    }
+
+    protected class NioByteUnsafe extends AbstractNioUnsafe {
+
+        // Invoked on EOF: either half-close the input (firing ChannelInputShutdownEvent)
+        // or close the channel entirely, depending on the allow-half-closure setting.
+        private void closeOnRead(ChannelPipeline pipeline) {
+            if (!isInputShutdown0()) {
+                if (isAllowHalfClosure(config())) {
+                    shutdownInput();
+                    pipeline.fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE);
+                } else {
+                    close(voidPromise());
+                }
+            } else if (!inputClosedSeenErrorOnRead) {
+                inputClosedSeenErrorOnRead = true;
+                pipeline.fireUserEventTriggered(ChannelInputShutdownReadComplete.INSTANCE);
+            }
+        }
+
+        // Delivers any bytes read before the failure, completes the read cycle, then
+        // propagates the cause through the pipeline.
+        private void handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf, Throwable cause, boolean close,
+                RecvByteBufAllocator.Handle allocHandle) {
+            if (byteBuf != null) {
+                if (byteBuf.isReadable()) {
+                    readPending = false;
+                    pipeline.fireChannelRead(byteBuf);
+                } else {
+                    byteBuf.release();
+                }
+            }
+            allocHandle.readComplete();
+            pipeline.fireChannelReadComplete();
+
pipeline.fireExceptionCaught(cause);
+
+            // If oom will close the read event, release connection.
+            // See https://github.com/netty/netty/issues/10434
+            if (close || cause instanceof OutOfMemoryError || cause instanceof IOException) {
+                closeOnRead(pipeline);
+            }
+        }
+
+        @Override
+        public final void read() {
+            final ChannelConfig config = config();
+            if (shouldBreakReadReady(config)) {
+                clearReadPending();
+                return;
+            }
+            final ChannelPipeline pipeline = pipeline();
+            final ByteBufAllocator allocator = config.getAllocator();
+            final RecvByteBufAllocator.Handle allocHandle = recvBufAllocHandle();
+            allocHandle.reset(config);
+
+            ByteBuf byteBuf = null;
+            boolean close = false;
+            try {
+                // Read in a loop until the allocator handle says to stop, EOF is hit,
+                // or the socket has no more data (lastBytesRead == 0).
+                do {
+                    byteBuf = allocHandle.allocate(allocator);
+                    allocHandle.lastBytesRead(doReadBytes(byteBuf));
+                    if (allocHandle.lastBytesRead() <= 0) {
+                        // nothing was read. release the buffer.
+                        byteBuf.release();
+                        byteBuf = null;
+                        close = allocHandle.lastBytesRead() < 0;
+                        if (close) {
+                            // There is nothing left to read as we received an EOF.
+                            readPending = false;
+                        }
+                        break;
+                    }
+
+                    allocHandle.incMessagesRead(1);
+                    readPending = false;
+                    pipeline.fireChannelRead(byteBuf);
+                    byteBuf = null;
+                } while (allocHandle.continueReading());
+
+                allocHandle.readComplete();
+                pipeline.fireChannelReadComplete();
+
+                if (close) {
+                    closeOnRead(pipeline);
+                }
+            } catch (Throwable t) {
+                handleReadException(pipeline, byteBuf, t, close, allocHandle);
+            } finally {
+                // Check if there is a readPending which was not processed yet.
+                // This could be for two reasons:
+                // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method
+                // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method
+                //
+                // See https://github.com/netty/netty/issues/2254
+                if (!readPending && !config.isAutoRead()) {
+                    removeReadOp();
+                }
+            }
+        }
+    }
+
+    /**
+     * Write objects to the OS.
+     * @param in the collection which contains objects to write.
+     * @return The value that should be decremented from the write quantum which starts at
+     * {@link ChannelConfig#getWriteSpinCount()}. The typical use cases are as follows:
+     * <ul>
+     *     <li>0 - if no write was attempted. This is appropriate if an empty {@link ByteBuf} (or other empty content)
+     *     is encountered</li>
+     *     <li>1 - if a single call to write data was made to the OS</li>
+     *     <li>{@link ChannelUtils#WRITE_STATUS_SNDBUF_FULL} - if an attempt to write data was made to the OS, but no
+     *     data was accepted</li>
+     * </ul>
+     * @throws Exception if an I/O exception occurs during write.
+     */
+    protected final int doWrite0(ChannelOutboundBuffer in) throws Exception {
+        Object msg = in.current();
+        if (msg == null) {
+            // Directly return here so incompleteWrite(...) is not called.
+            return 0;
+        }
+        return doWriteInternal(in, in.current());
+    }
+
+    // Writes a single ByteBuf or FileRegion message; returns 0 (nothing to do),
+    // 1 (one OS write made) or WRITE_STATUS_SNDBUF_FULL (OS accepted no data).
+    private int doWriteInternal(ChannelOutboundBuffer in, Object msg) throws Exception {
+        if (msg instanceof ByteBuf) {
+            ByteBuf buf = (ByteBuf) msg;
+            if (!buf.isReadable()) {
+                in.remove();
+                return 0;
+            }
+
+            final int localFlushedAmount = doWriteBytes(buf);
+            if (localFlushedAmount > 0) {
+                in.progress(localFlushedAmount);
+                if (!buf.isReadable()) {
+                    in.remove();
+                }
+                return 1;
+            }
+        } else if (msg instanceof FileRegion) {
+            FileRegion region = (FileRegion) msg;
+            if (region.transferred() >= region.count()) {
+                in.remove();
+                return 0;
+            }
+
+            long localFlushedAmount = doWriteFileRegion(region);
+            if (localFlushedAmount > 0) {
+                in.progress(localFlushedAmount);
+                if (region.transferred() >= region.count()) {
+                    in.remove();
+                }
+                return 1;
+            }
+        } else {
+            // Should not reach here. filterOutboundMessage(...) only lets ByteBuf/FileRegion through.
+            throw new Error();
+        }
+        return WRITE_STATUS_SNDBUF_FULL;
+    }
+
+    @Override
+    protected void doWrite(ChannelOutboundBuffer in) throws Exception {
+        int writeSpinCount = config().getWriteSpinCount();
+        do {
+            Object msg = in.current();
+            if (msg == null) {
+                // Wrote all messages.
+                clearOpWrite();
+                // Directly return here so incompleteWrite(...) is not called.
+                return;
+            }
+            writeSpinCount -= doWriteInternal(in, msg);
+        } while (writeSpinCount > 0);
+
+        // writeSpinCount < 0 means the send buffer filled up (WRITE_STATUS_SNDBUF_FULL),
+        // so OP_WRITE must be set; otherwise we simply ran out of quantum.
+        incompleteWrite(writeSpinCount < 0);
+    }
+
+    @Override
+    protected final Object filterOutboundMessage(Object msg) {
+        if (msg instanceof ByteBuf) {
+            ByteBuf buf = (ByteBuf) msg;
+            if (buf.isDirect()) {
+                return msg;
+            }
+
+            return newDirectBuffer(buf);
+        }
+
+        if (msg instanceof FileRegion) {
+            return msg;
+        }
+
+        throw new UnsupportedOperationException(
+                "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES);
+    }
+
+    protected final void incompleteWrite(boolean setOpWrite) {
+        // Did not write completely.
+        if (setOpWrite) {
+            setOpWrite();
+        } else {
+            // It is possible that we have set the write OP, woken up by NIO because the socket is writable, and then
+            // use our write quantum. In this case we no longer want to set the write OP because the socket is still
+            // writable (as far as we know). We will find out next time we attempt to write if the socket is writable
+            // and set the write OP if necessary.
+            clearOpWrite();
+
+            // Schedule flush again later so other tasks can be picked up in the meantime
+            eventLoop().execute(flushTask);
+        }
+    }
+
+    /**
+     * Write a {@link FileRegion}
+     *
+     * @param region        the {@link FileRegion} from which the bytes should be written
+     * @return amount       the amount of written bytes
+     */
+    protected abstract long doWriteFileRegion(FileRegion region) throws Exception;
+
+    /**
+     * Read bytes into the given {@link ByteBuf} and return the amount.
+     */
+    protected abstract int doReadBytes(ByteBuf buf) throws Exception;
+
+    /**
+     * Write bytes from the given {@link ByteBuf} to the underlying {@link java.nio.channels.Channel}.
+ * @param buf the {@link ByteBuf} from which the bytes should be written + * @return amount the amount of written bytes + */ + protected abstract int doWriteBytes(ByteBuf buf) throws Exception; + + protected final void setOpWrite() { + final SelectionKey key = selectionKey(); + // Check first if the key is still valid as it may be canceled as part of the deregistration + // from the EventLoop + // See https://github.com/netty/netty/issues/2104 + if (!key.isValid()) { + return; + } + final int interestOps = key.interestOps(); + if ((interestOps & SelectionKey.OP_WRITE) == 0) { + key.interestOps(interestOps | SelectionKey.OP_WRITE); + } + } + + protected final void clearOpWrite() { + final SelectionKey key = selectionKey(); + // Check first if the key is still valid as it may be canceled as part of the deregistration + // from the EventLoop + // See https://github.com/netty/netty/issues/2104 + if (!key.isValid()) { + return; + } + final int interestOps = key.interestOps(); + if ((interestOps & SelectionKey.OP_WRITE) != 0) { + key.interestOps(interestOps & ~SelectionKey.OP_WRITE); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioChannel.java b/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioChannel.java new file mode 100644 index 0000000..f6a515e --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioChannel.java @@ -0,0 +1,513 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelPromise; +import io.netty.channel.ConnectTimeoutException; +import io.netty.channel.EventLoop; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.Future; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.ConnectionPendingException; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.util.concurrent.TimeUnit; + +/** + * Abstract base class for {@link Channel} implementations which use a Selector based approach. + */ +public abstract class AbstractNioChannel extends AbstractChannel { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(AbstractNioChannel.class); + + private final SelectableChannel ch; + protected final int readInterestOp; + volatile SelectionKey selectionKey; + boolean readPending; + private final Runnable clearReadPendingRunnable = new Runnable() { + @Override + public void run() { + clearReadPending0(); + } + }; + + /** + * The future of the current connection attempt. If not null, subsequent + * connection attempts will fail. 
+     */
+    private ChannelPromise connectPromise;
+    // Scheduled task that fails connectPromise if the connect attempt exceeds the configured timeout.
+    private Future connectTimeoutFuture;
+    private SocketAddress requestedRemoteAddress;
+
+    /**
+     * Create a new instance
+     *
+     * @param parent            the parent {@link Channel} by which this instance was created. May be {@code null}
+     * @param ch                the underlying {@link SelectableChannel} on which it operates
+     * @param readInterestOp    the ops to set to receive data from the {@link SelectableChannel}
+     */
+    protected AbstractNioChannel(Channel parent, SelectableChannel ch, int readInterestOp) {
+        super(parent);
+        this.ch = ch;
+        this.readInterestOp = readInterestOp;
+        try {
+            // Selector-based channels must be non-blocking; fail construction otherwise.
+            ch.configureBlocking(false);
+        } catch (IOException e) {
+            try {
+                ch.close();
+            } catch (IOException e2) {
+                logger.warn(
+                        "Failed to close a partially initialized socket.", e2);
+            }
+
+            throw new ChannelException("Failed to enter non-blocking mode.", e);
+        }
+    }
+
+    @Override
+    public boolean isOpen() {
+        return ch.isOpen();
+    }
+
+    @Override
+    public NioUnsafe unsafe() {
+        return (NioUnsafe) super.unsafe();
+    }
+
+    protected SelectableChannel javaChannel() {
+        return ch;
+    }
+
+    @Override
+    public NioEventLoop eventLoop() {
+        return (NioEventLoop) super.eventLoop();
+    }
+
+    /**
+     * Return the current {@link SelectionKey}
+     */
+    protected SelectionKey selectionKey() {
+        assert selectionKey != null;
+        return selectionKey;
+    }
+
+    /**
+     * @deprecated No longer supported.
+     * No longer supported.
+     */
+    @Deprecated
+    protected boolean isReadPending() {
+        return readPending;
+    }
+
+    /**
+     * @deprecated Use {@link #clearReadPending()} if appropriate instead.
+     * No longer supported.
+     */
+    @Deprecated
+    protected void setReadPending(final boolean readPending) {
+        if (isRegistered()) {
+            // readPending may only be mutated from the event loop; hop over if needed.
+            EventLoop eventLoop = eventLoop();
+            if (eventLoop.inEventLoop()) {
+                setReadPending0(readPending);
+            } else {
+                eventLoop.execute(new Runnable() {
+                    @Override
+                    public void run() {
+                        setReadPending0(readPending);
+                    }
+                });
+            }
+        } else {
+            // Best effort if we are not registered yet clear readPending.
+            // NB: We only set the boolean field instead of calling clearReadPending0(), because the SelectionKey is
+            // not set yet so it would produce an assertion failure.
+            this.readPending = readPending;
+        }
+    }
+
+    /**
+     * Set read pending to {@code false}.
+     */
+    protected final void clearReadPending() {
+        if (isRegistered()) {
+            EventLoop eventLoop = eventLoop();
+            if (eventLoop.inEventLoop()) {
+                clearReadPending0();
+            } else {
+                eventLoop.execute(clearReadPendingRunnable);
+            }
+        } else {
+            // Best effort if we are not registered yet clear readPending. This happens during channel initialization.
+            // NB: We only set the boolean field instead of calling clearReadPending0(), because the SelectionKey is
+            // not set yet so it would produce an assertion failure.
+            readPending = false;
+        }
+    }
+
+    private void setReadPending0(boolean readPending) {
+        this.readPending = readPending;
+        if (!readPending) {
+            ((AbstractNioUnsafe) unsafe()).removeReadOp();
+        }
+    }
+
+    private void clearReadPending0() {
+        readPending = false;
+        ((AbstractNioUnsafe) unsafe()).removeReadOp();
+    }
+
+    /**
+     * Special {@link Unsafe} sub-type which allows to access the underlying {@link SelectableChannel}
+     */
+    public interface NioUnsafe extends Unsafe {
+        /**
+         * Return underlying {@link SelectableChannel}
+         */
+        SelectableChannel ch();
+
+        /**
+         * Finish connect
+         */
+        void finishConnect();
+
+        /**
+         * Read from underlying {@link SelectableChannel}
+         */
+        void read();
+
+        void forceFlush();
+    }
+
+    protected abstract class AbstractNioUnsafe extends AbstractUnsafe implements NioUnsafe {
+
+        // Removes the read interest op from the selection key, if it is still set.
+        protected final void removeReadOp() {
+            SelectionKey key = selectionKey();
+            // Check first if the key is still valid as it may be canceled as part of the deregistration
+            // from the EventLoop
+            // See https://github.com/netty/netty/issues/2104
+            if (!key.isValid()) {
+                return;
+            }
+            int interestOps = key.interestOps();
+            if ((interestOps & readInterestOp) != 0) {
+                // only remove readInterestOp if needed
+                key.interestOps(interestOps & ~readInterestOp);
+            }
+        }
+
+        @Override
+        public final SelectableChannel ch() {
+            return javaChannel();
+        }
+
+        @Override
+        public final void connect(
+                final SocketAddress remoteAddress, final SocketAddress localAddress, final ChannelPromise promise) {
+            if (!promise.setUncancellable() || !ensureOpen(promise)) {
+                return;
+            }
+
+            try {
+                if (connectPromise != null) {
+                    // Already a connect in process.
+                    throw new ConnectionPendingException();
+                }
+
+                boolean wasActive = isActive();
+                if (doConnect(remoteAddress, localAddress)) {
+                    // Connect completed immediately (possible for non-blocking connects).
+                    fulfillConnectPromise(promise, wasActive);
+                } else {
+                    connectPromise = promise;
+                    requestedRemoteAddress = remoteAddress;
+
+                    // Schedule connect timeout.
+                    final int connectTimeoutMillis = config().getConnectTimeoutMillis();
+                    if (connectTimeoutMillis > 0) {
+                        connectTimeoutFuture = eventLoop().schedule(new Runnable() {
+                            @Override
+                            public void run() {
+                                ChannelPromise connectPromise = AbstractNioChannel.this.connectPromise;
+                                if (connectPromise != null && !connectPromise.isDone()
+                                        && connectPromise.tryFailure(new ConnectTimeoutException(
+                                                "connection timed out after " + connectTimeoutMillis + " ms: " +
+                                                        remoteAddress))) {
+                                    close(voidPromise());
+                                }
+                            }
+                        }, connectTimeoutMillis, TimeUnit.MILLISECONDS);
+                    }
+
+                    // If the user cancels the promise, abort the attempt and close the channel.
+                    promise.addListener(new ChannelFutureListener() {
+                        @Override
+                        public void operationComplete(ChannelFuture future) throws Exception {
+                            if (future.isCancelled()) {
+                                if (connectTimeoutFuture != null) {
+                                    connectTimeoutFuture.cancel(false);
+                                }
+                                connectPromise = null;
+                                close(voidPromise());
+                            }
+                        }
+                    });
+                }
+            } catch (Throwable t) {
+                promise.tryFailure(annotateConnectException(t, remoteAddress));
+                closeIfClosed();
+            }
+        }
+
+        private void fulfillConnectPromise(ChannelPromise promise, boolean wasActive) {
+            if (promise == null) {
+                // Closed via cancellation and the promise has been notified already.
+                return;
+            }
+
+            // Get the state as trySuccess() may trigger an ChannelFutureListener that will close the Channel.
+            // We still need to ensure we call fireChannelActive() in this case.
+            boolean active = isActive();
+
+            // trySuccess() will return false if a user cancelled the connection attempt.
+            boolean promiseSet = promise.trySuccess();
+
+            // Regardless if the connection attempt was cancelled, channelActive() event should be triggered,
+            // because what happened is what happened.
+            if (!wasActive && active) {
+                pipeline().fireChannelActive();
+            }
+
+            // If a user cancelled the connection attempt, close the channel, which is followed by channelInactive().
+            if (!promiseSet) {
+                close(voidPromise());
+            }
+        }
+
+        private void fulfillConnectPromise(ChannelPromise promise, Throwable cause) {
+            if (promise == null) {
+                // Closed via cancellation and the promise has been notified already.
+                return;
+            }
+
+            // Use tryFailure() instead of setFailure() to avoid the race against cancel().
+            promise.tryFailure(cause);
+            closeIfClosed();
+        }
+
+        @Override
+        public final void finishConnect() {
+            // Note this method is invoked by the event loop only if the connection attempt was
+            // neither cancelled nor timed out.
+
+            assert eventLoop().inEventLoop();
+
+            try {
+                boolean wasActive = isActive();
+                doFinishConnect();
+                fulfillConnectPromise(connectPromise, wasActive);
+            } catch (Throwable t) {
+                fulfillConnectPromise(connectPromise, annotateConnectException(t, requestedRemoteAddress));
+            } finally {
+                // Check for null as the connectTimeoutFuture is only created if a connectTimeoutMillis > 0 is used
+                // See https://github.com/netty/netty/issues/1770
+                if (connectTimeoutFuture != null) {
+                    connectTimeoutFuture.cancel(false);
+                }
+                connectPromise = null;
+            }
+        }
+
+        @Override
+        protected final void flush0() {
+            // Flush immediately only when there's no pending flush.
+            // If there's a pending flush operation, event loop will call forceFlush() later,
+            // and thus there's no need to call it now.
+            if (!isFlushPending()) {
+                super.flush0();
+            }
+        }
+
+        @Override
+        public final void forceFlush() {
+            // directly call super.flush0() to force a flush now
+            super.flush0();
+        }
+
+        // A flush is pending when OP_WRITE is set: the selector will fire and forceFlush() later.
+        private boolean isFlushPending() {
+            SelectionKey selectionKey = selectionKey();
+            return selectionKey.isValid() && (selectionKey.interestOps() & SelectionKey.OP_WRITE) != 0;
+        }
+    }
+
+    @Override
+    protected boolean isCompatible(EventLoop loop) {
+        return loop instanceof NioEventLoop;
+    }
+
+    @Override
+    protected void doRegister() throws Exception {
+        boolean selected = false;
+        for (;;) {
+            try {
+                // Register with interestOps 0; doBeginRead() sets the read op later.
+                selectionKey = javaChannel().register(eventLoop().unwrappedSelector(), 0, this);
+                return;
+            } catch (CancelledKeyException e) {
+                if (!selected) {
+                    // Force the Selector to select now as the "canceled" SelectionKey may still be
+                    // cached and not removed because no Select.select(..) operation was called yet.
+                    eventLoop().selectNow();
+                    selected = true;
+                } else {
+                    // We forced a select operation on the selector before but the SelectionKey is still cached
+                    // for whatever reason. JDK bug ?
+                    throw e;
+                }
+            }
+        }
+    }
+
+    @Override
+    protected void doDeregister() throws Exception {
+        eventLoop().cancel(selectionKey());
+    }
+
+    @Override
+    protected void doBeginRead() throws Exception {
+        // Channel.read() or ChannelHandlerContext.read() was called
+        final SelectionKey selectionKey = this.selectionKey;
+        if (!selectionKey.isValid()) {
+            return;
+        }
+
+        readPending = true;
+
+        final int interestOps = selectionKey.interestOps();
+        if ((interestOps & readInterestOp) == 0) {
+            selectionKey.interestOps(interestOps | readInterestOp);
+        }
+    }
+
+    /**
+     * Connect to the remote peer
+     */
+    protected abstract boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception;
+
+    /**
+     * Finish the connect
+     */
+    protected abstract void doFinishConnect() throws Exception;
+
+    /**
+     * Returns an off-heap copy of the specified {@link ByteBuf}, and releases the original one.
+     * Note that this method does not create an off-heap copy if the allocation / deallocation cost is too high,
+     * but just returns the original {@link ByteBuf}.
+     */
+    protected final ByteBuf newDirectBuffer(ByteBuf buf) {
+        final int readableBytes = buf.readableBytes();
+        if (readableBytes == 0) {
+            ReferenceCountUtil.safeRelease(buf);
+            return Unpooled.EMPTY_BUFFER;
+        }
+
+        final ByteBufAllocator alloc = alloc();
+        if (alloc.isDirectBufferPooled()) {
+            ByteBuf directBuf = alloc.directBuffer(readableBytes);
+            directBuf.writeBytes(buf, buf.readerIndex(), readableBytes);
+            ReferenceCountUtil.safeRelease(buf);
+            return directBuf;
+        }
+
+        final ByteBuf directBuf = ByteBufUtil.threadLocalDirectBuffer();
+        if (directBuf != null) {
+            directBuf.writeBytes(buf, buf.readerIndex(), readableBytes);
+            ReferenceCountUtil.safeRelease(buf);
+            return directBuf;
+        }
+
+        // Allocating and deallocating an unpooled direct buffer is very expensive; give up.
+        return buf;
+    }
+
+    /**
+     * Returns an off-heap copy of the specified {@link ByteBuf}, and releases the specified holder.
+     * The caller must ensure that the holder releases the original {@link ByteBuf} when the holder is released by
+     * this method. Note that this method does not create an off-heap copy if the allocation / deallocation cost is
+     * too high, but just returns the original {@link ByteBuf}.
+     */
+    protected final ByteBuf newDirectBuffer(ReferenceCounted holder, ByteBuf buf) {
+        final int readableBytes = buf.readableBytes();
+        if (readableBytes == 0) {
+            ReferenceCountUtil.safeRelease(holder);
+            return Unpooled.EMPTY_BUFFER;
+        }
+
+        final ByteBufAllocator alloc = alloc();
+        if (alloc.isDirectBufferPooled()) {
+            ByteBuf directBuf = alloc.directBuffer(readableBytes);
+            directBuf.writeBytes(buf, buf.readerIndex(), readableBytes);
+            ReferenceCountUtil.safeRelease(holder);
+            return directBuf;
+        }
+
+        final ByteBuf directBuf = ByteBufUtil.threadLocalDirectBuffer();
+        if (directBuf != null) {
+            directBuf.writeBytes(buf, buf.readerIndex(), readableBytes);
+            ReferenceCountUtil.safeRelease(holder);
+            return directBuf;
+        }
+
+        // Allocating and deallocating an unpooled direct buffer is very expensive; give up.
+        if (holder != buf) {
+            // Ensure to call holder.release() to give the holder a chance to release other resources than its content.
+            buf.retain();
+            ReferenceCountUtil.safeRelease(holder);
+        }
+
+        return buf;
+    }
+
+    @Override
+    protected void doClose() throws Exception {
+        ChannelPromise promise = connectPromise;
+        if (promise != null) {
+            // Use tryFailure() instead of setFailure() to avoid the race against cancel().
+ promise.tryFailure(new ClosedChannelException()); + connectPromise = null; + } + + Future future = connectTimeoutFuture; + if (future != null) { + future.cancel(false); + connectTimeoutFuture = null; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java b/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java new file mode 100644 index 0000000..996d7b1 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/AbstractNioMessageChannel.java @@ -0,0 +1,211 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.ServerChannel; + +import java.io.IOException; +import java.net.PortUnreachableException; +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.util.ArrayList; +import java.util.List; + +/** + * {@link AbstractNioChannel} base class for {@link Channel}s that operate on messages. 
+ */
+public abstract class AbstractNioMessageChannel extends AbstractNioChannel {
+    // Set once a read error closed the input; doBeginRead() becomes a no-op afterwards.
+    boolean inputShutdown;
+
+    /**
+     * @see AbstractNioChannel#AbstractNioChannel(Channel, SelectableChannel, int)
+     */
+    protected AbstractNioMessageChannel(Channel parent, SelectableChannel ch, int readInterestOp) {
+        super(parent, ch, readInterestOp);
+    }
+
+    @Override
+    protected AbstractNioUnsafe newUnsafe() {
+        return new NioMessageUnsafe();
+    }
+
+    @Override
+    protected void doBeginRead() throws Exception {
+        if (inputShutdown) {
+            return;
+        }
+        super.doBeginRead();
+    }
+
+    protected boolean continueReading(RecvByteBufAllocator.Handle allocHandle) {
+        return allocHandle.continueReading();
+    }
+
+    private final class NioMessageUnsafe extends AbstractNioUnsafe {
+
+        // Reused scratch list filled by doReadMessages(...) and drained into the pipeline.
+        private final List readBuf = new ArrayList();
+
+        @Override
+        public void read() {
+            assert eventLoop().inEventLoop();
+            final ChannelConfig config = config();
+            final ChannelPipeline pipeline = pipeline();
+            final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
+            allocHandle.reset(config);
+
+            boolean closed = false;
+            Throwable exception = null;
+            try {
+                try {
+                    // Collect messages until the source is drained (0), EOF (< 0),
+                    // or the allocator handle says to stop.
+                    do {
+                        int localRead = doReadMessages(readBuf);
+                        if (localRead == 0) {
+                            break;
+                        }
+                        if (localRead < 0) {
+                            closed = true;
+                            break;
+                        }
+
+                        allocHandle.incMessagesRead(localRead);
+                    } while (continueReading(allocHandle));
+                } catch (Throwable t) {
+                    exception = t;
+                }
+
+                int size = readBuf.size();
+                for (int i = 0; i < size; i ++) {
+                    readPending = false;
+                    pipeline.fireChannelRead(readBuf.get(i));
+                }
+                readBuf.clear();
+                allocHandle.readComplete();
+                pipeline.fireChannelReadComplete();
+
+                if (exception != null) {
+                    closed = closeOnReadError(exception);
+
+                    pipeline.fireExceptionCaught(exception);
+                }
+
+                if (closed) {
+                    inputShutdown = true;
+                    if (isOpen()) {
+                        close(voidPromise());
+                    }
+                }
+            } finally {
+                // Check if there is a readPending which was not processed yet.
+                // This could be for two reasons:
+                // * The user called Channel.read() or ChannelHandlerContext.read() in channelRead(...) method
+                // * The user called Channel.read() or ChannelHandlerContext.read() in channelReadComplete(...) method
+                //
+                // See https://github.com/netty/netty/issues/2254
+                if (!readPending && !config.isAutoRead()) {
+                    removeReadOp();
+                }
+            }
+        }
+    }
+
+    @Override
+    protected void doWrite(ChannelOutboundBuffer in) throws Exception {
+        final SelectionKey key = selectionKey();
+        final int interestOps = key.interestOps();
+
+        int maxMessagesPerWrite = maxMessagesPerWrite();
+        while (maxMessagesPerWrite > 0) {
+            Object msg = in.current();
+            if (msg == null) {
+                break;
+            }
+            try {
+                boolean done = false;
+                // Retry the same message up to writeSpinCount times before giving up this round.
+                for (int i = config().getWriteSpinCount() - 1; i >= 0; i--) {
+                    if (doWriteMessage(msg, in)) {
+                        done = true;
+                        break;
+                    }
+                }
+
+                if (done) {
+                    maxMessagesPerWrite--;
+                    in.remove();
+                } else {
+                    break;
+                }
+            } catch (Exception e) {
+                if (continueOnWriteError()) {
+                    maxMessagesPerWrite--;
+                    in.remove(e);
+                } else {
+                    throw e;
+                }
+            }
+        }
+        if (in.isEmpty()) {
+            // Wrote all messages.
+            if ((interestOps & SelectionKey.OP_WRITE) != 0) {
+                key.interestOps(interestOps & ~SelectionKey.OP_WRITE);
+            }
+        } else {
+            // Did not write all messages.
+            if ((interestOps & SelectionKey.OP_WRITE) == 0) {
+                key.interestOps(interestOps | SelectionKey.OP_WRITE);
+            }
+        }
+    }
+
+    /**
+     * Returns {@code true} if we should continue the write loop on a write error.
+     */
+    protected boolean continueOnWriteError() {
+        return false;
+    }
+
+    protected boolean closeOnReadError(Throwable cause) {
+        if (!isActive()) {
+            // If the channel is not active anymore for whatever reason we should not try to continue reading.
+            return true;
+        }
+        if (cause instanceof PortUnreachableException) {
+            return false;
+        }
+        if (cause instanceof IOException) {
+            // ServerChannel should not be closed even on IOException because it can often continue
+            // accepting incoming connections. (e.g.
too many open files) + return !(this instanceof ServerChannel); + } + return true; + } + + /** + * Read messages into the given array and return the amount which was read. + */ + protected abstract int doReadMessages(List buf) throws Exception; + + /** + * Write a message to the underlying {@link java.nio.channels.Channel}. + * + * @return {@code true} if and only if the message has been written + */ + protected abstract boolean doWriteMessage(Object msg, ChannelOutboundBuffer in) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/NioEventLoop.java b/netty-channel/src/main/java/io/netty/channel/nio/NioEventLoop.java new file mode 100644 index 0000000..9d40733 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/NioEventLoop.java @@ -0,0 +1,894 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.nio; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopException; +import io.netty.channel.EventLoopTaskQueueFactory; +import io.netty.channel.SelectStrategy; +import io.netty.channel.SingleThreadEventLoop; +import io.netty.util.IntSupplier; +import io.netty.util.concurrent.RejectedExecutionHandler; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.ReflectionUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.lang.reflect.Field; +import java.nio.channels.CancelledKeyException; +import java.nio.channels.SelectableChannel; +import java.nio.channels.Selector; +import java.nio.channels.SelectionKey; + +import java.nio.channels.spi.SelectorProvider; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.atomic.AtomicLong; + +/** + * {@link SingleThreadEventLoop} implementation which register the {@link Channel}'s to a + * {@link Selector} and so does the multi-plexing of these in the event loop. + * + */ +public final class NioEventLoop extends SingleThreadEventLoop { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioEventLoop.class); + + private static final int CLEANUP_INTERVAL = 256; // XXX Hard-coded value, but won't need customization. 
+ + private static final boolean DISABLE_KEY_SET_OPTIMIZATION = + SystemPropertyUtil.getBoolean("io.netty.noKeySetOptimization", false); + + private static final int MIN_PREMATURE_SELECTOR_RETURNS = 3; + private static final int SELECTOR_AUTO_REBUILD_THRESHOLD; + + private final IntSupplier selectNowSupplier = new IntSupplier() { + @Override + public int get() throws Exception { + return selectNow(); + } + }; + + // Workaround for JDK NIO bug. + // + // See: + // - https://bugs.openjdk.java.net/browse/JDK-6427854 for first few dev (unreleased) builds of JDK 7 + // - https://bugs.openjdk.java.net/browse/JDK-6527572 for JDK prior to 5.0u15-rev and 6u10 + // - https://github.com/netty/netty/issues/203 + static { + if (PlatformDependent.javaVersion() < 7) { + final String key = "sun.nio.ch.bugLevel"; + final String bugLevel = SystemPropertyUtil.get(key); + if (bugLevel == null) { + try { + AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Void run() { + System.setProperty(key, ""); + return null; + } + }); + } catch (final SecurityException e) { + logger.debug("Unable to get/set System Property: " + key, e); + } + } + } + + int selectorAutoRebuildThreshold = SystemPropertyUtil.getInt("io.netty.selectorAutoRebuildThreshold", 512); + if (selectorAutoRebuildThreshold < MIN_PREMATURE_SELECTOR_RETURNS) { + selectorAutoRebuildThreshold = 0; + } + + SELECTOR_AUTO_REBUILD_THRESHOLD = selectorAutoRebuildThreshold; + + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.noKeySetOptimization: {}", DISABLE_KEY_SET_OPTIMIZATION); + logger.debug("-Dio.netty.selectorAutoRebuildThreshold: {}", SELECTOR_AUTO_REBUILD_THRESHOLD); + } + } + + /** + * The NIO {@link Selector}. 
+ */ + private Selector selector; + private Selector unwrappedSelector; + private SelectedSelectionKeySet selectedKeys; + + private final SelectorProvider provider; + + private static final long AWAKE = -1L; + private static final long NONE = Long.MAX_VALUE; + + // nextWakeupNanos is: + // AWAKE when EL is awake + // NONE when EL is waiting with no wakeup scheduled + // other value T when EL is waiting with wakeup scheduled at time T + private final AtomicLong nextWakeupNanos = new AtomicLong(AWAKE); + + private final SelectStrategy selectStrategy; + + private volatile int ioRatio = 50; + private int cancelledKeys; + private boolean needsToSelectAgain; + + NioEventLoop(NioEventLoopGroup parent, Executor executor, SelectorProvider selectorProvider, + SelectStrategy strategy, RejectedExecutionHandler rejectedExecutionHandler, + EventLoopTaskQueueFactory taskQueueFactory, EventLoopTaskQueueFactory tailTaskQueueFactory) { + super(parent, executor, false, newTaskQueue(taskQueueFactory), newTaskQueue(tailTaskQueueFactory), + rejectedExecutionHandler); + this.provider = ObjectUtil.checkNotNull(selectorProvider, "selectorProvider"); + this.selectStrategy = ObjectUtil.checkNotNull(strategy, "selectStrategy"); + final SelectorTuple selectorTuple = openSelector(); + this.selector = selectorTuple.selector; + this.unwrappedSelector = selectorTuple.unwrappedSelector; + } + + private static Queue newTaskQueue( + EventLoopTaskQueueFactory queueFactory) { + if (queueFactory == null) { + return newTaskQueue0(DEFAULT_MAX_PENDING_TASKS); + } + return queueFactory.newTaskQueue(DEFAULT_MAX_PENDING_TASKS); + } + + private static final class SelectorTuple { + final Selector unwrappedSelector; + final Selector selector; + + SelectorTuple(Selector unwrappedSelector) { + this.unwrappedSelector = unwrappedSelector; + this.selector = unwrappedSelector; + } + + SelectorTuple(Selector unwrappedSelector, Selector selector) { + this.unwrappedSelector = unwrappedSelector; + this.selector = 
selector; + } + } + + private SelectorTuple openSelector() { + final Selector unwrappedSelector; + try { + unwrappedSelector = provider.openSelector(); + } catch (IOException e) { + throw new ChannelException("failed to open a new selector", e); + } + + if (DISABLE_KEY_SET_OPTIMIZATION) { + return new SelectorTuple(unwrappedSelector); + } + + Object maybeSelectorImplClass = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + return Class.forName( + "sun.nio.ch.SelectorImpl", + false, + PlatformDependent.getSystemClassLoader()); + } catch (Throwable cause) { + return cause; + } + } + }); + + if (!(maybeSelectorImplClass instanceof Class) || + // ensure the current selector implementation is what we can instrument. + !((Class) maybeSelectorImplClass).isAssignableFrom(unwrappedSelector.getClass())) { + if (maybeSelectorImplClass instanceof Throwable) { + Throwable t = (Throwable) maybeSelectorImplClass; + logger.trace("failed to instrument a special java.util.Set into: {}", unwrappedSelector, t); + } + return new SelectorTuple(unwrappedSelector); + } + + final Class selectorImplClass = (Class) maybeSelectorImplClass; + final SelectedSelectionKeySet selectedKeySet = new SelectedSelectionKeySet(); + + Object maybeException = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + Field selectedKeysField = selectorImplClass.getDeclaredField("selectedKeys"); + Field publicSelectedKeysField = selectorImplClass.getDeclaredField("publicSelectedKeys"); + + if (PlatformDependent.javaVersion() >= 9 && PlatformDependent.hasUnsafe()) { + // Let us try to use sun.misc.Unsafe to replace the SelectionKeySet. + // This allows us to also do this in Java9+ without any extra flags. 
+ long selectedKeysFieldOffset = PlatformDependent.objectFieldOffset(selectedKeysField); + long publicSelectedKeysFieldOffset = + PlatformDependent.objectFieldOffset(publicSelectedKeysField); + + if (selectedKeysFieldOffset != -1 && publicSelectedKeysFieldOffset != -1) { + PlatformDependent.putObject( + unwrappedSelector, selectedKeysFieldOffset, selectedKeySet); + PlatformDependent.putObject( + unwrappedSelector, publicSelectedKeysFieldOffset, selectedKeySet); + return null; + } + // We could not retrieve the offset, lets try reflection as last-resort. + } + + Throwable cause = ReflectionUtil.trySetAccessible(selectedKeysField, true); + if (cause != null) { + return cause; + } + cause = ReflectionUtil.trySetAccessible(publicSelectedKeysField, true); + if (cause != null) { + return cause; + } + + selectedKeysField.set(unwrappedSelector, selectedKeySet); + publicSelectedKeysField.set(unwrappedSelector, selectedKeySet); + return null; + } catch (NoSuchFieldException e) { + return e; + } catch (IllegalAccessException e) { + return e; + } + } + }); + + if (maybeException instanceof Exception) { + selectedKeys = null; + Exception e = (Exception) maybeException; + logger.trace("failed to instrument a special java.util.Set into: {}", unwrappedSelector, e); + return new SelectorTuple(unwrappedSelector); + } + selectedKeys = selectedKeySet; + logger.trace("instrumented a special java.util.Set into: {}", unwrappedSelector); + return new SelectorTuple(unwrappedSelector, + new SelectedSelectionKeySetSelector(unwrappedSelector, selectedKeySet)); + } + + /** + * Returns the {@link SelectorProvider} used by this {@link NioEventLoop} to obtain the {@link Selector}. 
+ */ + public SelectorProvider selectorProvider() { + return provider; + } + + @Override + protected Queue newTaskQueue(int maxPendingTasks) { + return newTaskQueue0(maxPendingTasks); + } + + private static Queue newTaskQueue0(int maxPendingTasks) { + // This event loop never calls takeTask() + return maxPendingTasks == Integer.MAX_VALUE ? PlatformDependent.newMpscQueue() + : PlatformDependent.newMpscQueue(maxPendingTasks); + } + + /** + * Registers an arbitrary {@link SelectableChannel}, not necessarily created by Netty, to the {@link Selector} + * of this event loop. Once the specified {@link SelectableChannel} is registered, the specified {@code task} will + * be executed by this event loop when the {@link SelectableChannel} is ready. + */ + public void register(final SelectableChannel ch, final int interestOps, final NioTask task) { + ObjectUtil.checkNotNull(ch, "ch"); + if (interestOps == 0) { + throw new IllegalArgumentException("interestOps must be non-zero."); + } + if ((interestOps & ~ch.validOps()) != 0) { + throw new IllegalArgumentException( + "invalid interestOps: " + interestOps + "(validOps: " + ch.validOps() + ')'); + } + ObjectUtil.checkNotNull(task, "task"); + + if (isShutdown()) { + throw new IllegalStateException("event loop shut down"); + } + + if (inEventLoop()) { + register0(ch, interestOps, task); + } else { + try { + // Offload to the EventLoop as otherwise java.nio.channels.spi.AbstractSelectableChannel.register + // may block for a long time while trying to obtain an internal lock that may be hold while selecting. + submit(new Runnable() { + @Override + public void run() { + register0(ch, interestOps, task); + } + }).sync(); + } catch (InterruptedException ignore) { + // Even if interrupted we did schedule it so just mark the Thread as interrupted. 
+ Thread.currentThread().interrupt(); + } + } + } + + private void register0(SelectableChannel ch, int interestOps, NioTask task) { + try { + ch.register(unwrappedSelector, interestOps, task); + } catch (Exception e) { + throw new EventLoopException("failed to register a channel", e); + } + } + + /** + * Returns the percentage of the desired amount of time spent for I/O in the event loop. + */ + public int getIoRatio() { + return ioRatio; + } + + /** + * Sets the percentage of the desired amount of time spent for I/O in the event loop. Value range from 1-100. + * The default value is {@code 50}, which means the event loop will try to spend the same amount of time for I/O + * as for non-I/O tasks. The lower the number the more time can be spent on non-I/O tasks. If value set to + * {@code 100}, this feature will be disabled and event loop will not attempt to balance I/O and non-I/O tasks. + */ + public void setIoRatio(int ioRatio) { + if (ioRatio <= 0 || ioRatio > 100) { + throw new IllegalArgumentException("ioRatio: " + ioRatio + " (expected: 0 < ioRatio <= 100)"); + } + this.ioRatio = ioRatio; + } + + /** + * Replaces the current {@link Selector} of this event loop with newly created {@link Selector}s to work + * around the infamous epoll 100% CPU bug. 
+ */ + public void rebuildSelector() { + if (!inEventLoop()) { + execute(new Runnable() { + @Override + public void run() { + rebuildSelector0(); + } + }); + return; + } + rebuildSelector0(); + } + + @Override + public int registeredChannels() { + return selector.keys().size() - cancelledKeys; + } + + @Override + public Iterator registeredChannelsIterator() { + assert inEventLoop(); + final Set keys = selector.keys(); + if (keys.isEmpty()) { + return ChannelsReadOnlyIterator.empty(); + } + return new Iterator() { + final Iterator selectionKeyIterator = + ObjectUtil.checkNotNull(keys, "selectionKeys") + .iterator(); + Channel next; + boolean isDone; + + @Override + public boolean hasNext() { + if (isDone) { + return false; + } + Channel cur = next; + if (cur == null) { + cur = next = nextOrDone(); + return cur != null; + } + return true; + } + + @Override + public Channel next() { + if (isDone) { + throw new NoSuchElementException(); + } + Channel cur = next; + if (cur == null) { + cur = nextOrDone(); + if (cur == null) { + throw new NoSuchElementException(); + } + } + next = nextOrDone(); + return cur; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + private Channel nextOrDone() { + Iterator it = selectionKeyIterator; + while (it.hasNext()) { + SelectionKey key = it.next(); + if (key.isValid()) { + Object attachment = key.attachment(); + if (attachment instanceof AbstractNioChannel) { + return (AbstractNioChannel) attachment; + } + } + } + isDone = true; + return null; + } + }; + } + + private void rebuildSelector0() { + final Selector oldSelector = selector; + final SelectorTuple newSelectorTuple; + + if (oldSelector == null) { + return; + } + + try { + newSelectorTuple = openSelector(); + } catch (Exception e) { + logger.warn("Failed to create a new Selector.", e); + return; + } + + // Register all channels to the new Selector. 
+ int nChannels = 0; + for (SelectionKey key: oldSelector.keys()) { + Object a = key.attachment(); + try { + if (!key.isValid() || key.channel().keyFor(newSelectorTuple.unwrappedSelector) != null) { + continue; + } + + int interestOps = key.interestOps(); + key.cancel(); + SelectionKey newKey = key.channel().register(newSelectorTuple.unwrappedSelector, interestOps, a); + if (a instanceof AbstractNioChannel) { + // Update SelectionKey + ((AbstractNioChannel) a).selectionKey = newKey; + } + nChannels ++; + } catch (Exception e) { + logger.warn("Failed to re-register a Channel to the new Selector.", e); + if (a instanceof AbstractNioChannel) { + AbstractNioChannel ch = (AbstractNioChannel) a; + ch.unsafe().close(ch.unsafe().voidPromise()); + } else { + @SuppressWarnings("unchecked") + NioTask task = (NioTask) a; + invokeChannelUnregistered(task, key, e); + } + } + } + + selector = newSelectorTuple.selector; + unwrappedSelector = newSelectorTuple.unwrappedSelector; + + try { + // time to close the old selector as everything else is registered to the new one + oldSelector.close(); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to close the old Selector.", t); + } + } + + if (logger.isInfoEnabled()) { + logger.info("Migrated " + nChannels + " channel(s) to the new Selector."); + } + } + + @Override + protected void run() { + int selectCnt = 0; + for (;;) { + try { + int strategy; + try { + strategy = selectStrategy.calculateStrategy(selectNowSupplier, hasTasks()); + switch (strategy) { + case SelectStrategy.CONTINUE: + continue; + + case SelectStrategy.BUSY_WAIT: + // fall-through to SELECT since the busy-wait is not supported with NIO + + case SelectStrategy.SELECT: + long curDeadlineNanos = nextScheduledTaskDeadlineNanos(); + if (curDeadlineNanos == -1L) { + curDeadlineNanos = NONE; // nothing on the calendar + } + nextWakeupNanos.set(curDeadlineNanos); + try { + if (!hasTasks()) { + strategy = select(curDeadlineNanos); + } + } finally 
{ + // This update is just to help block unnecessary selector wakeups + // so use of lazySet is ok (no race condition) + nextWakeupNanos.lazySet(AWAKE); + } + // fall through + default: + } + } catch (IOException e) { + // If we receive an IOException here its because the Selector is messed up. Let's rebuild + // the selector and retry. https://github.com/netty/netty/issues/8566 + rebuildSelector0(); + selectCnt = 0; + handleLoopException(e); + continue; + } + + selectCnt++; + cancelledKeys = 0; + needsToSelectAgain = false; + final int ioRatio = this.ioRatio; + boolean ranTasks; + if (ioRatio == 100) { + try { + if (strategy > 0) { + processSelectedKeys(); + } + } finally { + // Ensure we always run tasks. + ranTasks = runAllTasks(); + } + } else if (strategy > 0) { + final long ioStartTime = System.nanoTime(); + try { + processSelectedKeys(); + } finally { + // Ensure we always run tasks. + final long ioTime = System.nanoTime() - ioStartTime; + ranTasks = runAllTasks(ioTime * (100 - ioRatio) / ioRatio); + } + } else { + ranTasks = runAllTasks(0); // This will run the minimum number of tasks + } + + if (ranTasks || strategy > 0) { + if (selectCnt > MIN_PREMATURE_SELECTOR_RETURNS && logger.isDebugEnabled()) { + logger.debug("Selector.select() returned prematurely {} times in a row for Selector {}.", + selectCnt - 1, selector); + } + selectCnt = 0; + } else if (unexpectedSelectorWakeup(selectCnt)) { // Unexpected wakeup (unusual case) + selectCnt = 0; + } + } catch (CancelledKeyException e) { + // Harmless exception - log anyway + if (logger.isDebugEnabled()) { + logger.debug(CancelledKeyException.class.getSimpleName() + " raised by a Selector {} - JDK bug?", + selector, e); + } + } catch (Error e) { + throw e; + } catch (Throwable t) { + handleLoopException(t); + } finally { + // Always handle shutdown even if the loop processing threw an exception. 
+ try { + if (isShuttingDown()) { + closeAll(); + if (confirmShutdown()) { + return; + } + } + } catch (Error e) { + throw e; + } catch (Throwable t) { + handleLoopException(t); + } + } + } + } + + // returns true if selectCnt should be reset + private boolean unexpectedSelectorWakeup(int selectCnt) { + if (Thread.interrupted()) { + // Thread was interrupted so reset selected keys and break so we not run into a busy loop. + // As this is most likely a bug in the handler of the user or it's client library we will + // also log it. + // + // See https://github.com/netty/netty/issues/2426 + if (logger.isDebugEnabled()) { + logger.debug("Selector.select() returned prematurely because " + + "Thread.currentThread().interrupt() was called. Use " + + "NioEventLoop.shutdownGracefully() to shutdown the NioEventLoop."); + } + return true; + } + if (SELECTOR_AUTO_REBUILD_THRESHOLD > 0 && + selectCnt >= SELECTOR_AUTO_REBUILD_THRESHOLD) { + // The selector returned prematurely many times in a row. + // Rebuild the selector to work around the problem. + logger.warn("Selector.select() returned prematurely {} times in a row; rebuilding Selector {}.", + selectCnt, selector); + rebuildSelector(); + return true; + } + return false; + } + + private static void handleLoopException(Throwable t) { + logger.warn("Unexpected exception in the selector loop.", t); + + // Prevent possible consecutive immediate failures that lead to + // excessive CPU consumption. + try { + Thread.sleep(1000); + } catch (InterruptedException e) { + // Ignore. 
+ } + } + + private void processSelectedKeys() { + if (selectedKeys != null) { + processSelectedKeysOptimized(); + } else { + processSelectedKeysPlain(selector.selectedKeys()); + } + } + + @Override + protected void cleanup() { + try { + selector.close(); + } catch (IOException e) { + logger.warn("Failed to close a selector.", e); + } + } + + void cancel(SelectionKey key) { + key.cancel(); + cancelledKeys ++; + if (cancelledKeys >= CLEANUP_INTERVAL) { + cancelledKeys = 0; + needsToSelectAgain = true; + } + } + + private void processSelectedKeysPlain(Set selectedKeys) { + // check if the set is empty and if so just return to not create garbage by + // creating a new Iterator every time even if there is nothing to process. + // See https://github.com/netty/netty/issues/597 + if (selectedKeys.isEmpty()) { + return; + } + + Iterator i = selectedKeys.iterator(); + for (;;) { + final SelectionKey k = i.next(); + final Object a = k.attachment(); + i.remove(); + + if (a instanceof AbstractNioChannel) { + processSelectedKey(k, (AbstractNioChannel) a); + } else { + @SuppressWarnings("unchecked") + NioTask task = (NioTask) a; + processSelectedKey(k, task); + } + + if (!i.hasNext()) { + break; + } + + if (needsToSelectAgain) { + selectAgain(); + selectedKeys = selector.selectedKeys(); + + // Create the iterator again to avoid ConcurrentModificationException + if (selectedKeys.isEmpty()) { + break; + } else { + i = selectedKeys.iterator(); + } + } + } + } + + private void processSelectedKeysOptimized() { + for (int i = 0; i < selectedKeys.size; ++i) { + final SelectionKey k = selectedKeys.keys[i]; + // null out entry in the array to allow to have it GC'ed once the Channel close + // See https://github.com/netty/netty/issues/2363 + selectedKeys.keys[i] = null; + + final Object a = k.attachment(); + + if (a instanceof AbstractNioChannel) { + processSelectedKey(k, (AbstractNioChannel) a); + } else { + @SuppressWarnings("unchecked") + NioTask task = (NioTask) a; + 
processSelectedKey(k, task); + } + + if (needsToSelectAgain) { + // null out entries in the array to allow to have it GC'ed once the Channel close + // See https://github.com/netty/netty/issues/2363 + selectedKeys.reset(i + 1); + + selectAgain(); + i = -1; + } + } + } + + private void processSelectedKey(SelectionKey k, AbstractNioChannel ch) { + final AbstractNioChannel.NioUnsafe unsafe = ch.unsafe(); + if (!k.isValid()) { + final EventLoop eventLoop; + try { + eventLoop = ch.eventLoop(); + } catch (Throwable ignored) { + // If the channel implementation throws an exception because there is no event loop, we ignore this + // because we are only trying to determine if ch is registered to this event loop and thus has authority + // to close ch. + return; + } + // Only close ch if ch is still registered to this EventLoop. ch could have deregistered from the event loop + // and thus the SelectionKey could be cancelled as part of the deregistration process, but the channel is + // still healthy and should not be closed. + // See https://github.com/netty/netty/issues/5125 + if (eventLoop == this) { + // close the channel if the key is not valid anymore + unsafe.close(unsafe.voidPromise()); + } + return; + } + + try { + int readyOps = k.readyOps(); + // We first need to call finishConnect() before try to trigger a read(...) or write(...) as otherwise + // the NIO JDK channel implementation may throw a NotYetConnectedException. + if ((readyOps & SelectionKey.OP_CONNECT) != 0) { + // remove OP_CONNECT as otherwise Selector.select(..) will always return without blocking + // See https://github.com/netty/netty/issues/924 + int ops = k.interestOps(); + ops &= ~SelectionKey.OP_CONNECT; + k.interestOps(ops); + + unsafe.finishConnect(); + } + + // Process OP_WRITE first as we may be able to write some queued buffers and so free memory. 
+ if ((readyOps & SelectionKey.OP_WRITE) != 0) { + // Call forceFlush which will also take care of clear the OP_WRITE once there is nothing left to write + unsafe.forceFlush(); + } + + // Also check for readOps of 0 to workaround possible JDK bug which may otherwise lead + // to a spin loop + if ((readyOps & (SelectionKey.OP_READ | SelectionKey.OP_ACCEPT)) != 0 || readyOps == 0) { + unsafe.read(); + } + } catch (CancelledKeyException ignored) { + unsafe.close(unsafe.voidPromise()); + } + } + + private static void processSelectedKey(SelectionKey k, NioTask task) { + int state = 0; + try { + task.channelReady(k.channel(), k); + state = 1; + } catch (Exception e) { + k.cancel(); + invokeChannelUnregistered(task, k, e); + state = 2; + } finally { + switch (state) { + case 0: + k.cancel(); + invokeChannelUnregistered(task, k, null); + break; + case 1: + if (!k.isValid()) { // Cancelled by channelReady() + invokeChannelUnregistered(task, k, null); + } + break; + default: + break; + } + } + } + + private void closeAll() { + selectAgain(); + Set keys = selector.keys(); + Collection channels = new ArrayList(keys.size()); + for (SelectionKey k: keys) { + Object a = k.attachment(); + if (a instanceof AbstractNioChannel) { + channels.add((AbstractNioChannel) a); + } else { + k.cancel(); + @SuppressWarnings("unchecked") + NioTask task = (NioTask) a; + invokeChannelUnregistered(task, k, null); + } + } + + for (AbstractNioChannel ch: channels) { + ch.unsafe().close(ch.unsafe().voidPromise()); + } + } + + private static void invokeChannelUnregistered(NioTask task, SelectionKey k, Throwable cause) { + try { + task.channelUnregistered(k.channel(), cause); + } catch (Exception e) { + logger.warn("Unexpected exception while running NioTask.channelUnregistered()", e); + } + } + + @Override + protected void wakeup(boolean inEventLoop) { + if (!inEventLoop && nextWakeupNanos.getAndSet(AWAKE) != AWAKE) { + selector.wakeup(); + } + } + + @Override + protected boolean 
beforeScheduledTaskSubmitted(long deadlineNanos) { + // Note this is also correct for the nextWakeupNanos == -1 (AWAKE) case + return deadlineNanos < nextWakeupNanos.get(); + } + + @Override + protected boolean afterScheduledTaskSubmitted(long deadlineNanos) { + // Note this is also correct for the nextWakeupNanos == -1 (AWAKE) case + return deadlineNanos < nextWakeupNanos.get(); + } + + Selector unwrappedSelector() { + return unwrappedSelector; + } + + int selectNow() throws IOException { + return selector.selectNow(); + } + + private int select(long deadlineNanos) throws IOException { + if (deadlineNanos == NONE) { + return selector.select(); + } + // Timeout will only be 0 if deadline is within 5 microsecs + long timeoutMillis = deadlineToDelayNanos(deadlineNanos + 995000L) / 1000000L; + return timeoutMillis <= 0 ? selector.selectNow() : selector.select(timeoutMillis); + } + + private void selectAgain() { + needsToSelectAgain = false; + try { + selector.selectNow(); + } catch (Throwable t) { + logger.warn("Failed to update SelectionKeys.", t); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/NioEventLoopGroup.java b/netty-channel/src/main/java/io/netty/channel/nio/NioEventLoopGroup.java new file mode 100644 index 0000000..9854e9c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/NioEventLoopGroup.java @@ -0,0 +1,186 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import io.netty.channel.Channel; +import io.netty.channel.DefaultSelectStrategyFactory; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopTaskQueueFactory; +import io.netty.channel.MultithreadEventLoopGroup; +import io.netty.channel.SelectStrategyFactory; +import io.netty.channel.SingleThreadEventLoop; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.EventExecutorChooserFactory; +import io.netty.util.concurrent.RejectedExecutionHandler; +import io.netty.util.concurrent.RejectedExecutionHandlers; + +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; + +/** + * {@link MultithreadEventLoopGroup} implementations which is used for NIO {@link Selector} based {@link Channel}s. + */ +public class NioEventLoopGroup extends MultithreadEventLoopGroup { + + /** + * Create a new instance using the default number of threads, the default {@link ThreadFactory} and + * the {@link SelectorProvider} which is returned by {@link SelectorProvider#provider()}. + */ + public NioEventLoopGroup() { + this(0); + } + + /** + * Create a new instance using the specified number of threads, {@link ThreadFactory} and the + * {@link SelectorProvider} which is returned by {@link SelectorProvider#provider()}. + */ + public NioEventLoopGroup(int nThreads) { + this(nThreads, (Executor) null); + } + + /** + * Create a new instance using the default number of threads, the given {@link ThreadFactory} and the + * {@link SelectorProvider} which is returned by {@link SelectorProvider#provider()}. 
+ */ + public NioEventLoopGroup(ThreadFactory threadFactory) { + this(0, threadFactory, SelectorProvider.provider()); + } + + /** + * Create a new instance using the specified number of threads, the given {@link ThreadFactory} and the + * {@link SelectorProvider} which is returned by {@link SelectorProvider#provider()}. + */ + public NioEventLoopGroup(int nThreads, ThreadFactory threadFactory) { + this(nThreads, threadFactory, SelectorProvider.provider()); + } + + public NioEventLoopGroup(int nThreads, Executor executor) { + this(nThreads, executor, SelectorProvider.provider()); + } + + /** + * Create a new instance using the specified number of threads, the given {@link ThreadFactory} and the given + * {@link SelectorProvider}. + */ + public NioEventLoopGroup( + int nThreads, ThreadFactory threadFactory, final SelectorProvider selectorProvider) { + this(nThreads, threadFactory, selectorProvider, DefaultSelectStrategyFactory.INSTANCE); + } + + public NioEventLoopGroup(int nThreads, ThreadFactory threadFactory, + final SelectorProvider selectorProvider, final SelectStrategyFactory selectStrategyFactory) { + super(nThreads, threadFactory, selectorProvider, selectStrategyFactory, RejectedExecutionHandlers.reject()); + } + + public NioEventLoopGroup( + int nThreads, Executor executor, final SelectorProvider selectorProvider) { + this(nThreads, executor, selectorProvider, DefaultSelectStrategyFactory.INSTANCE); + } + + public NioEventLoopGroup(int nThreads, Executor executor, final SelectorProvider selectorProvider, + final SelectStrategyFactory selectStrategyFactory) { + super(nThreads, executor, selectorProvider, selectStrategyFactory, RejectedExecutionHandlers.reject()); + } + + public NioEventLoopGroup(int nThreads, Executor executor, EventExecutorChooserFactory chooserFactory, + final SelectorProvider selectorProvider, + final SelectStrategyFactory selectStrategyFactory) { + super(nThreads, executor, chooserFactory, selectorProvider, selectStrategyFactory, + 
RejectedExecutionHandlers.reject()); + } + + public NioEventLoopGroup(int nThreads, Executor executor, EventExecutorChooserFactory chooserFactory, + final SelectorProvider selectorProvider, + final SelectStrategyFactory selectStrategyFactory, + final RejectedExecutionHandler rejectedExecutionHandler) { + super(nThreads, executor, chooserFactory, selectorProvider, selectStrategyFactory, rejectedExecutionHandler); + } + + public NioEventLoopGroup(int nThreads, Executor executor, EventExecutorChooserFactory chooserFactory, + final SelectorProvider selectorProvider, + final SelectStrategyFactory selectStrategyFactory, + final RejectedExecutionHandler rejectedExecutionHandler, + final EventLoopTaskQueueFactory taskQueueFactory) { + super(nThreads, executor, chooserFactory, selectorProvider, selectStrategyFactory, + rejectedExecutionHandler, taskQueueFactory); + } + + /** + * @param nThreads the number of threads that will be used by this instance. + * @param executor the Executor to use, or {@code null} if default one should be used. + * @param chooserFactory the {@link EventExecutorChooserFactory} to use. + * @param selectorProvider the {@link SelectorProvider} to use. + * @param selectStrategyFactory the {@link SelectStrategyFactory} to use. + * @param rejectedExecutionHandler the {@link RejectedExecutionHandler} to use. + * @param taskQueueFactory the {@link EventLoopTaskQueueFactory} to use for + * {@link SingleThreadEventLoop#execute(Runnable)}, + * or {@code null} if default one should be used. + * @param tailTaskQueueFactory the {@link EventLoopTaskQueueFactory} to use for + * {@link SingleThreadEventLoop#executeAfterEventLoopIteration(Runnable)}, + * or {@code null} if default one should be used. 
+ */ + public NioEventLoopGroup(int nThreads, Executor executor, EventExecutorChooserFactory chooserFactory, + SelectorProvider selectorProvider, + SelectStrategyFactory selectStrategyFactory, + RejectedExecutionHandler rejectedExecutionHandler, + EventLoopTaskQueueFactory taskQueueFactory, + EventLoopTaskQueueFactory tailTaskQueueFactory) { + super(nThreads, executor, chooserFactory, selectorProvider, selectStrategyFactory, + rejectedExecutionHandler, taskQueueFactory, tailTaskQueueFactory); + } + + /** + * Sets the percentage of the desired amount of time spent for I/O in the child event loops. The default value is + * {@code 50}, which means the event loop will try to spend the same amount of time for I/O as for non-I/O tasks. + */ + public void setIoRatio(int ioRatio) { + for (EventExecutor e: this) { + ((NioEventLoop) e).setIoRatio(ioRatio); + } + } + + /** + * Replaces the current {@link Selector}s of the child event loops with newly created {@link Selector}s to work + * around the infamous epoll 100% CPU bug. + */ + public void rebuildSelectors() { + for (EventExecutor e: this) { + ((NioEventLoop) e).rebuildSelector(); + } + } + + @Override + protected EventLoop newChild(Executor executor, Object... 
args) throws Exception { + SelectorProvider selectorProvider = (SelectorProvider) args[0]; + SelectStrategyFactory selectStrategyFactory = (SelectStrategyFactory) args[1]; + RejectedExecutionHandler rejectedExecutionHandler = (RejectedExecutionHandler) args[2]; + EventLoopTaskQueueFactory taskQueueFactory = null; + EventLoopTaskQueueFactory tailTaskQueueFactory = null; + + int argsLength = args.length; + if (argsLength > 3) { + taskQueueFactory = (EventLoopTaskQueueFactory) args[3]; + } + if (argsLength > 4) { + tailTaskQueueFactory = (EventLoopTaskQueueFactory) args[4]; + } + return new NioEventLoop(this, executor, selectorProvider, + selectStrategyFactory.newSelectStrategy(), + rejectedExecutionHandler, taskQueueFactory, tailTaskQueueFactory); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/NioTask.java b/netty-channel/src/main/java/io/netty/channel/nio/NioTask.java new file mode 100644 index 0000000..8b2d213 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/NioTask.java @@ -0,0 +1,41 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import java.nio.channels.SelectableChannel; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; + +/** + * An arbitrary task that can be executed by {@link NioEventLoop} when a {@link SelectableChannel} becomes ready. 
+ * + * @see NioEventLoop#register(SelectableChannel, int, NioTask) + */ +public interface NioTask { + /** + * Invoked when the {@link SelectableChannel} has been selected by the {@link Selector}. + */ + void channelReady(C ch, SelectionKey key) throws Exception; + + /** + * Invoked when the {@link SelectionKey} of the specified {@link SelectableChannel} has been cancelled and thus + * this {@link NioTask} will not be notified anymore. + * + * @param cause the cause of the unregistration. {@code null} if a user called {@link SelectionKey#cancel()} or + * the event loop has been shut down. + */ + void channelUnregistered(C ch, Throwable cause) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java b/netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java new file mode 100644 index 0000000..b9c72b6 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySet.java @@ -0,0 +1,109 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.nio; + +import java.nio.channels.SelectionKey; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.Objects; + +final class SelectedSelectionKeySet extends AbstractSet { + + SelectionKey[] keys; + int size; + + SelectedSelectionKeySet() { + keys = new SelectionKey[1024]; + } + + @Override + public boolean add(SelectionKey o) { + if (o == null) { + return false; + } + + if (size == keys.length) { + increaseCapacity(); + } + + keys[size++] = o; + return true; + } + + @Override + public boolean remove(Object o) { + return false; + } + + @Override + public boolean contains(Object o) { + SelectionKey[] array = keys; + for (int i = 0, s = size; i < s; i++) { + SelectionKey k = array[i]; + if (k.equals(o)) { + return true; + } + } + return false; + } + + @Override + public int size() { + return size; + } + + @Override + public Iterator iterator() { + return new Iterator() { + private int idx; + + @Override + public boolean hasNext() { + return idx < size; + } + + @Override + public SelectionKey next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return keys[idx++]; + } + + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + + void reset() { + reset(0); + } + + void reset(int start) { + Arrays.fill(keys, start, size, null); + size = 0; + } + + private void increaseCapacity() { + SelectionKey[] newKeys = new SelectionKey[keys.length << 1]; + System.arraycopy(keys, 0, newKeys, 0, size); + keys = newKeys; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySetSelector.java b/netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySetSelector.java new file mode 100644 index 0000000..7e45e47 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/SelectedSelectionKeySetSelector.java @@ -0,0 +1,80 @@ +/* + * Copyright 2017 The Netty 
Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import java.io.IOException; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.spi.SelectorProvider; +import java.util.Set; + +final class SelectedSelectionKeySetSelector extends Selector { + private final SelectedSelectionKeySet selectionKeys; + private final Selector delegate; + + SelectedSelectionKeySetSelector(Selector delegate, SelectedSelectionKeySet selectionKeys) { + this.delegate = delegate; + this.selectionKeys = selectionKeys; + } + + @Override + public boolean isOpen() { + return delegate.isOpen(); + } + + @Override + public SelectorProvider provider() { + return delegate.provider(); + } + + @Override + public Set keys() { + return delegate.keys(); + } + + @Override + public Set selectedKeys() { + return delegate.selectedKeys(); + } + + @Override + public int selectNow() throws IOException { + selectionKeys.reset(); + return delegate.selectNow(); + } + + @Override + public int select(long timeout) throws IOException { + selectionKeys.reset(); + return delegate.select(timeout); + } + + @Override + public int select() throws IOException { + selectionKeys.reset(); + return delegate.select(); + } + + @Override + public Selector wakeup() { + return delegate.wakeup(); + } + + @Override + public void close() throws IOException { + delegate.close(); + } +} diff --git 
a/netty-channel/src/main/java/io/netty/channel/nio/package-info.java b/netty-channel/src/main/java/io/netty/channel/nio/package-info.java new file mode 100644 index 0000000..c6c7688 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/nio/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * NIO-based channel + * API implementation - recommended for a large number of connections (>= 1000). + */ +package io.netty.channel.nio; diff --git a/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java b/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java new file mode 100644 index 0000000..26a5056 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioByteChannel.java @@ -0,0 +1,270 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.oio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.FileRegion; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.util.internal.StringUtil; + +import java.io.IOException; + +/** + * Abstract base class for OIO which reads and writes bytes from/to a Socket + * + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +public abstract class AbstractOioByteChannel extends AbstractOioChannel { + + private static final ChannelMetadata METADATA = new ChannelMetadata(false); + private static final String EXPECTED_TYPES = + " (expected: " + StringUtil.simpleClassName(ByteBuf.class) + ", " + + StringUtil.simpleClassName(FileRegion.class) + ')'; + + /** + * @see AbstractOioByteChannel#AbstractOioByteChannel(Channel) + */ + protected AbstractOioByteChannel(Channel parent) { + super(parent); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + /** + * Determine if the input side of this channel is shutdown. + * @return {@code true} if the input side of this channel is shutdown. + */ + protected abstract boolean isInputShutdown(); + + /** + * Shutdown the input side of this channel. + * @return A channel future that will complete when the shutdown is complete. 
+ */ + protected abstract ChannelFuture shutdownInput(); + + private void closeOnRead(ChannelPipeline pipeline) { + if (isOpen()) { + if (Boolean.TRUE.equals(config().getOption(ChannelOption.ALLOW_HALF_CLOSURE))) { + shutdownInput(); + pipeline.fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE); + } else { + unsafe().close(unsafe().voidPromise()); + } + pipeline.fireUserEventTriggered(ChannelInputShutdownReadComplete.INSTANCE); + } + } + + private void handleReadException(ChannelPipeline pipeline, ByteBuf byteBuf, Throwable cause, boolean close, + RecvByteBufAllocator.Handle allocHandle) { + if (byteBuf != null) { + if (byteBuf.isReadable()) { + readPending = false; + pipeline.fireChannelRead(byteBuf); + } else { + byteBuf.release(); + } + } + allocHandle.readComplete(); + pipeline.fireChannelReadComplete(); + pipeline.fireExceptionCaught(cause); + + // If oom will close the read event, release connection. + // See https://github.com/netty/netty/issues/10434 + if (close || cause instanceof OutOfMemoryError || cause instanceof IOException) { + closeOnRead(pipeline); + } + } + + @Override + protected void doRead() { + final ChannelConfig config = config(); + if (isInputShutdown() || !readPending) { + // We have to check readPending here because the Runnable to read could have been scheduled and later + // during the same read loop readPending was set to false. + return; + } + // In OIO we should set readPending to false even if the read was not successful so we can schedule + // another read on the event loop if no reads are done. 
+ readPending = false; + + final ChannelPipeline pipeline = pipeline(); + final ByteBufAllocator allocator = config.getAllocator(); + final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); + allocHandle.reset(config); + + ByteBuf byteBuf = null; + boolean close = false; + boolean readData = false; + try { + byteBuf = allocHandle.allocate(allocator); + do { + allocHandle.lastBytesRead(doReadBytes(byteBuf)); + if (allocHandle.lastBytesRead() <= 0) { + if (!byteBuf.isReadable()) { // nothing was read. release the buffer. + byteBuf.release(); + byteBuf = null; + close = allocHandle.lastBytesRead() < 0; + if (close) { + // There is nothing left to read as we received an EOF. + readPending = false; + } + } + break; + } else { + readData = true; + } + + final int available = available(); + if (available <= 0) { + break; + } + + // Oio collects consecutive read operations into 1 ByteBuf before propagating up the pipeline. + if (!byteBuf.isWritable()) { + final int capacity = byteBuf.capacity(); + final int maxCapacity = byteBuf.maxCapacity(); + if (capacity == maxCapacity) { + allocHandle.incMessagesRead(1); + readPending = false; + pipeline.fireChannelRead(byteBuf); + byteBuf = allocHandle.allocate(allocator); + } else { + final int writerIndex = byteBuf.writerIndex(); + if (writerIndex + available > maxCapacity) { + byteBuf.capacity(maxCapacity); + } else { + byteBuf.ensureWritable(available); + } + } + } + } while (allocHandle.continueReading()); + + if (byteBuf != null) { + // It is possible we allocated a buffer because the previous one was not writable, but then didn't use + // it because allocHandle.continueReading() returned false. 
+ if (byteBuf.isReadable()) { + readPending = false; + pipeline.fireChannelRead(byteBuf); + } else { + byteBuf.release(); + } + byteBuf = null; + } + + if (readData) { + allocHandle.readComplete(); + pipeline.fireChannelReadComplete(); + } + + if (close) { + closeOnRead(pipeline); + } + } catch (Throwable t) { + handleReadException(pipeline, byteBuf, t, close, allocHandle); + } finally { + if (readPending || config.isAutoRead() || !readData && isActive()) { + // Reading 0 bytes could mean there is a SocketTimeout and no data was actually read, so we + // should execute read() again because no data may have been read. + read(); + } + } + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + for (;;) { + Object msg = in.current(); + if (msg == null) { + // nothing left to write + break; + } + if (msg instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) msg; + int readableBytes = buf.readableBytes(); + while (readableBytes > 0) { + doWriteBytes(buf); + int newReadableBytes = buf.readableBytes(); + in.progress(readableBytes - newReadableBytes); + readableBytes = newReadableBytes; + } + in.remove(); + } else if (msg instanceof FileRegion) { + FileRegion region = (FileRegion) msg; + long transferred = region.transferred(); + doWriteFileRegion(region); + in.progress(region.transferred() - transferred); + in.remove(); + } else { + in.remove(new UnsupportedOperationException( + "unsupported message type: " + StringUtil.simpleClassName(msg))); + } + } + } + + @Override + protected final Object filterOutboundMessage(Object msg) throws Exception { + if (msg instanceof ByteBuf || msg instanceof FileRegion) { + return msg; + } + + throw new UnsupportedOperationException( + "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES); + } + + /** + * Return the number of bytes ready to read from the underlying Socket. + */ + protected abstract int available(); + + /** + * Read bytes from the underlying Socket. 
+ * + * @param buf the {@link ByteBuf} into which the read bytes will be written + * @return amount the number of bytes read. This may return a negative amount if the underlying + * Socket was closed + * @throws Exception is thrown if an error occurred + */ + protected abstract int doReadBytes(ByteBuf buf) throws Exception; + + /** + * Write the data which is hold by the {@link ByteBuf} to the underlying Socket. + * + * @param buf the {@link ByteBuf} which holds the data to transfer + * @throws Exception is thrown if an error occurred + */ + protected abstract void doWriteBytes(ByteBuf buf) throws Exception; + + /** + * Write the data which is hold by the {@link FileRegion} to the underlying Socket. + * + * @param region the {@link FileRegion} which holds the data to transfer + * @throws Exception is thrown if an error occurred + */ + protected abstract void doWriteFileRegion(FileRegion region) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioChannel.java b/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioChannel.java new file mode 100644 index 0000000..1c7dc07 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioChannel.java @@ -0,0 +1,166 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.oio; + +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.channel.ThreadPerChannelEventLoop; + +import java.net.SocketAddress; + +/** + * Abstract base class for {@link Channel} implementations that use Old-Blocking-IO + * + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated +public abstract class AbstractOioChannel extends AbstractChannel { + + protected static final int SO_TIMEOUT = 1000; + + boolean readPending; + boolean readWhenInactive; + final Runnable readTask = new Runnable() { + @Override + public void run() { + doRead(); + } + }; + private final Runnable clearReadPendingRunnable = new Runnable() { + @Override + public void run() { + readPending = false; + } + }; + + /** + * @see AbstractChannel#AbstractChannel(Channel) + */ + protected AbstractOioChannel(Channel parent) { + super(parent); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new DefaultOioUnsafe(); + } + + private final class DefaultOioUnsafe extends AbstractUnsafe { + @Override + public void connect( + final SocketAddress remoteAddress, + final SocketAddress localAddress, final ChannelPromise promise) { + if (!promise.setUncancellable() || !ensureOpen(promise)) { + return; + } + + try { + boolean wasActive = isActive(); + doConnect(remoteAddress, localAddress); + + // Get the state as trySuccess() may trigger an ChannelFutureListener that will close the Channel. + // We still need to ensure we call fireChannelActive() in this case. 
+ boolean active = isActive(); + + safeSetSuccess(promise); + if (!wasActive && active) { + pipeline().fireChannelActive(); + } + } catch (Throwable t) { + safeSetFailure(promise, annotateConnectException(t, remoteAddress)); + closeIfClosed(); + } + } + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return loop instanceof ThreadPerChannelEventLoop; + } + + /** + * Connect to the remote peer using the given localAddress if one is specified or {@code null} otherwise. + */ + protected abstract void doConnect( + SocketAddress remoteAddress, SocketAddress localAddress) throws Exception; + + @Override + protected void doBeginRead() throws Exception { + if (readPending) { + return; + } + if (!isActive()) { + readWhenInactive = true; + return; + } + + readPending = true; + eventLoop().execute(readTask); + } + + protected abstract void doRead(); + + /** + * @deprecated No longer supported. + * No longer supported. + */ + @Deprecated + protected boolean isReadPending() { + return readPending; + } + + /** + * @deprecated Use {@link #clearReadPending()} if appropriate instead. + * No longer supported. + */ + @Deprecated + protected void setReadPending(final boolean readPending) { + if (isRegistered()) { + EventLoop eventLoop = eventLoop(); + if (eventLoop.inEventLoop()) { + this.readPending = readPending; + } else { + eventLoop.execute(new Runnable() { + @Override + public void run() { + AbstractOioChannel.this.readPending = readPending; + } + }); + } + } else { + this.readPending = readPending; + } + } + + /** + * Set read pending to {@code false}. + */ + protected final void clearReadPending() { + if (isRegistered()) { + EventLoop eventLoop = eventLoop(); + if (eventLoop.inEventLoop()) { + readPending = false; + } else { + eventLoop.execute(clearReadPendingRunnable); + } + } else { + // Best effort if we are not registered yet clear readPending. This happens during channel initialization. 
+ readPending = false; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java b/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java new file mode 100644 index 0000000..7b37114 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/oio/AbstractOioMessageChannel.java @@ -0,0 +1,113 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.oio; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.RecvByteBufAllocator; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +/** + * Abstract base class for OIO which reads and writes objects from/to a Socket + * + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated +public abstract class AbstractOioMessageChannel extends AbstractOioChannel { + + private final List readBuf = new ArrayList(); + + protected AbstractOioMessageChannel(Channel parent) { + super(parent); + } + + @Override + protected void doRead() { + if (!readPending) { + // We have to check readPending here because the Runnable to read could have been scheduled and later + // during the same read loop readPending was set to false. 
+ return; + } + // In OIO we should set readPending to false even if the read was not successful so we can schedule + // another read on the event loop if no reads are done. + readPending = false; + + final ChannelConfig config = config(); + final ChannelPipeline pipeline = pipeline(); + final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); + allocHandle.reset(config); + + boolean closed = false; + Throwable exception = null; + try { + do { + // Perform a read. + int localRead = doReadMessages(readBuf); + if (localRead == 0) { + break; + } + if (localRead < 0) { + closed = true; + break; + } + + allocHandle.incMessagesRead(localRead); + } while (allocHandle.continueReading()); + } catch (Throwable t) { + exception = t; + } + + boolean readData = false; + int size = readBuf.size(); + if (size > 0) { + readData = true; + for (int i = 0; i < size; i++) { + readPending = false; + pipeline.fireChannelRead(readBuf.get(i)); + } + readBuf.clear(); + allocHandle.readComplete(); + pipeline.fireChannelReadComplete(); + } + + if (exception != null) { + if (exception instanceof IOException) { + closed = true; + } + + pipeline.fireExceptionCaught(exception); + } + + if (closed) { + if (isOpen()) { + unsafe().close(unsafe().voidPromise()); + } + } else if (readPending || config.isAutoRead() || !readData && isActive()) { + // Reading 0 bytes could mean there is a SocketTimeout and no data was actually read, so we + // should execute read() again because no data may have been read. + read(); + } + } + + /** + * Read messages into the given array and return the amount which was read. 
+ */ + protected abstract int doReadMessages(List msgs) throws Exception; +} diff --git a/netty-channel/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java b/netty-channel/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java new file mode 100644 index 0000000..e5e449f --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/oio/OioByteStreamChannel.java @@ -0,0 +1,172 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.oio; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.FileRegion; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.util.internal.ObjectUtil; + +import java.io.EOFException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NotYetConnectedException; +import java.nio.channels.WritableByteChannel; + +/** + * Abstract base class for OIO Channels that are based on streams. + * + * @deprecated use NIO / EPOLL / KQUEUE transport. 
+ */ +@Deprecated +public abstract class OioByteStreamChannel extends AbstractOioByteChannel { + + private static final InputStream CLOSED_IN = new InputStream() { + @Override + public int read() { + return -1; + } + }; + + private static final OutputStream CLOSED_OUT = new OutputStream() { + @Override + public void write(int b) throws IOException { + throw new ClosedChannelException(); + } + }; + + private InputStream is; + private OutputStream os; + private WritableByteChannel outChannel; + + /** + * Create a new instance + * + * @param parent the parent {@link Channel} which was used to create this instance. This can be null if the + * channel has no parent as it was created by your self. + */ + protected OioByteStreamChannel(Channel parent) { + super(parent); + } + + /** + * Activate this instance. After this call {@link #isActive()} will return {@code true}. + */ + protected final void activate(InputStream is, OutputStream os) { + if (this.is != null) { + throw new IllegalStateException("input was set already"); + } + if (this.os != null) { + throw new IllegalStateException("output was set already"); + } + this.is = ObjectUtil.checkNotNull(is, "is"); + this.os = ObjectUtil.checkNotNull(os, "os"); + if (readWhenInactive) { + eventLoop().execute(readTask); + readWhenInactive = false; + } + } + + @Override + public boolean isActive() { + InputStream is = this.is; + if (is == null || is == CLOSED_IN) { + return false; + } + + OutputStream os = this.os; + return !(os == null || os == CLOSED_OUT); + } + + @Override + protected int available() { + try { + return is.available(); + } catch (IOException ignored) { + return 0; + } + } + + @Override + protected int doReadBytes(ByteBuf buf) throws Exception { + final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); + allocHandle.attemptedBytesRead(Math.max(1, Math.min(available(), buf.maxWritableBytes()))); + return buf.writeBytes(is, allocHandle.attemptedBytesRead()); + } + + @Override + protected 
void doWriteBytes(ByteBuf buf) throws Exception { + OutputStream os = this.os; + if (os == null) { + throw new NotYetConnectedException(); + } + buf.readBytes(os, buf.readableBytes()); + } + + @Override + protected void doWriteFileRegion(FileRegion region) throws Exception { + OutputStream os = this.os; + if (os == null) { + throw new NotYetConnectedException(); + } + if (outChannel == null) { + outChannel = Channels.newChannel(os); + } + + long written = 0; + for (;;) { + long localWritten = region.transferTo(outChannel, written); + if (localWritten == -1) { + checkEOF(region); + return; + } + written += localWritten; + + if (written >= region.count()) { + return; + } + } + } + + private static void checkEOF(FileRegion region) throws IOException { + if (region.transferred() < region.count()) { + throw new EOFException("Expected to be able to write " + region.count() + " bytes, " + + "but only wrote " + region.transferred()); + } + } + + @Override + protected void doClose() throws Exception { + InputStream is = this.is; + OutputStream os = this.os; + this.is = CLOSED_IN; + this.os = CLOSED_OUT; + + try { + if (is != null) { + is.close(); + } + } finally { + if (os != null) { + os.close(); + } + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java b/netty-channel/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java new file mode 100644 index 0000000..2a25e97 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/oio/OioEventLoopGroup.java @@ -0,0 +1,87 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/*
 * You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.oio;


import io.netty.channel.Channel;
import io.netty.channel.ChannelException;
import io.netty.channel.ChannelPromise;
import io.netty.channel.EventLoop;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.ThreadPerChannelEventLoopGroup;

import java.util.concurrent.Executor;
import java.util.concurrent.ThreadFactory;

/**
 * {@link EventLoopGroup} which handles OIO {@link Channel}s. Each {@link Channel} gets its own
 * dedicated {@link EventLoop} so blocking I/O on one channel never stalls the others.
 *
 * @deprecated use NIO / EPOLL / KQUEUE transport.
 */
@Deprecated
public class OioEventLoopGroup extends ThreadPerChannelEventLoopGroup {

    /**
     * Create a new {@link OioEventLoopGroup} with no limit on the number of channels.
     */
    public OioEventLoopGroup() {
        this(0);
    }

    /**
     * Create a new {@link OioEventLoopGroup}.
     *
     * @param maxChannels the maximum number of channels this instance handles. Once the maximum is
     *                    exceeded, registering a further {@link Channel} makes
     *                    {@link #register(Channel)} and {@link #register(ChannelPromise)} throw a
     *                    {@link ChannelException}. Use {@code 0} for no limit.
     */
    public OioEventLoopGroup(int maxChannels) {
        // The cast disambiguates between the Executor and ThreadFactory overloads.
        this(maxChannels, (ThreadFactory) null);
    }

    /**
     * Create a new {@link OioEventLoopGroup}.
     *
     * @param maxChannels the maximum number of channels this instance handles. Once the maximum is
     *                    exceeded, registering a further {@link Channel} makes
     *                    {@link #register(Channel)} and {@link #register(ChannelPromise)} throw a
     *                    {@link ChannelException}. Use {@code 0} for no limit.
     * @param executor    the {@link Executor} used to create new {@link Thread} instances that handle
     *                    the registered {@link Channel}s
     */
    public OioEventLoopGroup(int maxChannels, Executor executor) {
        super(maxChannels, executor);
    }

    /**
     * Create a new {@link OioEventLoopGroup}.
     *
     * @param maxChannels   the maximum number of channels this instance handles. Once the maximum is
     *                      exceeded, registering a further {@link Channel} makes
     *                      {@link #register(Channel)} and {@link #register(ChannelPromise)} throw a
     *                      {@link ChannelException}. Use {@code 0} for no limit.
     * @param threadFactory the {@link ThreadFactory} used to create new {@link Thread} instances that
     *                      handle the registered {@link Channel}s
     */
    public OioEventLoopGroup(int maxChannels, ThreadFactory threadFactory) {
        super(maxChannels, threadFactory);
    }
}
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Old blocking I/O based channel API implementation - recommended for + * a small number of connections (< 1000). + */ +package io.netty.channel.oio; diff --git a/netty-channel/src/main/java/io/netty/channel/package-info.java b/netty-channel/src/main/java/io/netty/channel/package-info.java new file mode 100644 index 0000000..52ac05b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * The core channel API which is asynchronous and event-driven abstraction of + * various transports such as a + * NIO Channel. + */ +package io.netty.channel; diff --git a/netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolHandler.java b/netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolHandler.java new file mode 100644 index 0000000..ada532c --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolHandler.java @@ -0,0 +1,44 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.pool; + +import io.netty.channel.Channel; + +/** + * A skeletal {@link ChannelPoolHandler} implementation. + */ +public abstract class AbstractChannelPoolHandler implements ChannelPoolHandler { + + /** + * NOOP implementation, sub-classes may override this. + * + * {@inheritDoc} + */ + @Override + public void channelAcquired(@SuppressWarnings("unused") Channel ch) throws Exception { + // NOOP + } + + /** + * NOOP implementation, sub-classes may override this. + * + * {@inheritDoc} + */ + @Override + public void channelReleased(@SuppressWarnings("unused") Channel ch) throws Exception { + // NOOP + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolMap.java b/netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolMap.java new file mode 100644 index 0000000..2f104f0 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/pool/AbstractChannelPoolMap.java @@ -0,0 +1,153 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/*
 * See the License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.pool;

import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.ReadOnlyIterator;

import java.io.Closeable;
import java.util.Iterator;
import java.util.Map.Entry;
import java.util.concurrent.ConcurrentMap;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * A skeletal {@link ChannelPoolMap} implementation. To find the right {@link ChannelPool}
 * the {@link Object#hashCode()} and {@link Object#equals(Object)} is used.
 *
 * @param <K> the type of the key
 * @param <P> the type of the {@link ChannelPool}
 */
public abstract class AbstractChannelPoolMap<K, P extends ChannelPool>
        implements ChannelPoolMap<K, P>, Iterable<Entry<K, P>>, Closeable {

    private final ConcurrentMap<K, P> map = PlatformDependent.newConcurrentHashMap();

    @Override
    public final P get(K key) {
        P pool = map.get(checkNotNull(key, "key"));
        if (pool == null) {
            pool = newPool(key);
            P old = map.putIfAbsent(key, pool);
            if (old != null) {
                // Lost the race against a concurrent get(): destroy the pool we just
                // created and hand out the one that won.
                poolCloseAsyncIfSupported(pool);
                pool = old;
            }
        }
        return pool;
    }

    /**
     * Remove the {@link ChannelPool} from this {@link AbstractChannelPoolMap}. Returns {@code true} if removed,
     * {@code false} otherwise.
     *
     * If the removed pool extends {@link SimpleChannelPool} it will be closed asynchronously to avoid blocking in
     * this method.
     *
     * Please note that {@code null} keys are not allowed.
     */
    public final boolean remove(K key) {
        P pool = map.remove(checkNotNull(key, "key"));
        if (pool != null) {
            poolCloseAsyncIfSupported(pool);
            return true;
        }
        return false;
    }

    /**
     * Remove the {@link ChannelPool} from this {@link AbstractChannelPoolMap}. Returns a future that completes with a
     * {@code true} result if the pool has been removed by this call, otherwise the result is {@code false}.
     *
     * If the removed pool extends {@link SimpleChannelPool} it will be closed asynchronously to avoid blocking in
     * this method. The returned future will be completed once this asynchronous pool close operation completes.
     */
    private Future<Boolean> removeAsyncIfSupported(K key) {
        P pool = map.remove(checkNotNull(key, "key"));
        if (pool != null) {
            final Promise<Boolean> removePromise = GlobalEventExecutor.INSTANCE.newPromise();
            poolCloseAsyncIfSupported(pool).addListener(new GenericFutureListener<Future<? super Void>>() {
                @Override
                public void operationComplete(Future<? super Void> future) throws Exception {
                    if (future.isSuccess()) {
                        removePromise.setSuccess(Boolean.TRUE);
                    } else {
                        removePromise.setFailure(future.cause());
                    }
                }
            });
            return removePromise;
        }
        return GlobalEventExecutor.INSTANCE.newSucceededFuture(Boolean.FALSE);
    }

    /**
     * If the pool implementation supports asynchronous close, then use it to avoid a blocking close call in case
     * the ChannelPoolMap operations are called from an EventLoop.
     *
     * @param pool the ChannelPool to be closed
     */
    private static Future<Void> poolCloseAsyncIfSupported(ChannelPool pool) {
        if (pool instanceof SimpleChannelPool) {
            return ((SimpleChannelPool) pool).closeAsync();
        } else {
            try {
                pool.close();
                return GlobalEventExecutor.INSTANCE.<Void>newSucceededFuture(null);
            } catch (Exception e) {
                return GlobalEventExecutor.INSTANCE.newFailedFuture(e);
            }
        }
    }

    @Override
    public final Iterator<Entry<K, P>> iterator() {
        return new ReadOnlyIterator<Entry<K, P>>(map.entrySet().iterator());
    }

    /**
     * Returns the number of {@link ChannelPool}s currently in this {@link AbstractChannelPoolMap}.
     */
    public final int size() {
        return map.size();
    }

    /**
     * Returns {@code true} if the {@link AbstractChannelPoolMap} is empty, otherwise {@code false}.
     */
    public final boolean isEmpty() {
        return map.isEmpty();
    }

    @Override
    public final boolean contains(K key) {
        return map.containsKey(checkNotNull(key, "key"));
    }

    /**
     * Called once a new {@link ChannelPool} needs to be created as none exists yet for the {@code key}.
     */
    protected abstract P newPool(K key);

    @Override
    public final void close() {
        for (K key : map.keySet()) {
            // Wait for remove to finish to ensure that resources are released before returning from close
            removeAsyncIfSupported(key).syncUninterruptibly();
        }
    }
}
+ */ +public interface ChannelHealthChecker { + + /** + * {@link ChannelHealthChecker} implementation that checks if {@link Channel#isActive()} returns {@code true}. + */ + ChannelHealthChecker ACTIVE = new ChannelHealthChecker() { + @Override + public Future isHealthy(Channel channel) { + EventLoop loop = channel.eventLoop(); + return channel.isActive()? loop.newSucceededFuture(Boolean.TRUE) : loop.newSucceededFuture(Boolean.FALSE); + } + }; + + /** + * Check if the given channel is healthy which means it can be used. The returned {@link Future} is notified once + * the check is complete. If notified with {@link Boolean#TRUE} it can be used {@link Boolean#FALSE} otherwise. + * + * This method will be called by the {@link EventLoop} of the {@link Channel}. + */ + Future isHealthy(Channel channel); +} diff --git a/netty-channel/src/main/java/io/netty/channel/pool/ChannelPool.java b/netty-channel/src/main/java/io/netty/channel/pool/ChannelPool.java new file mode 100644 index 0000000..92b457b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/pool/ChannelPool.java @@ -0,0 +1,61 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 */
package io.netty.channel.pool;

import io.netty.channel.Channel;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;

import java.io.Closeable;

/**
 * Allows to acquire and release {@link Channel}s and so acts as a pool of these.
 */
public interface ChannelPool extends Closeable {

    /**
     * Acquire a {@link Channel} from this {@link ChannelPool}. The returned {@link Future} is notified once
     * the acquire is successful and failed otherwise.
     *
     * It is important that an acquired {@link Channel} is always released to the pool again, even if the
     * {@link Channel} is explicitly closed.
     */
    Future<Channel> acquire();

    /**
     * Acquire a {@link Channel} from this {@link ChannelPool}. The given {@link Promise} is notified once
     * the acquire is successful and failed otherwise.
     *
     * It is important that an acquired {@link Channel} is always released to the pool again, even if the
     * {@link Channel} is explicitly closed.
     */
    Future<Channel> acquire(Promise<Channel> promise);

    /**
     * Release a {@link Channel} back to this {@link ChannelPool}. The returned {@link Future} is notified once
     * the release is successful and failed otherwise. When failed the {@link Channel} will automatically be closed.
     */
    Future<Void> release(Channel channel);

    /**
     * Release a {@link Channel} back to this {@link ChannelPool}. The given {@link Promise} is notified once
     * the release is successful and failed otherwise. When failed the {@link Channel} will automatically be closed.
     */
    Future<Void> release(Channel channel, Promise<Void> promise);

    @Override
    void close();
}

/**
 * Handler which is called for various actions done by the {@link ChannelPool}.
 */
interface ChannelPoolHandler {

    /**
     * Called once a {@link Channel} was released by calling {@link ChannelPool#release(Channel)} or
     * {@link ChannelPool#release(Channel, Promise)}.
     *
     * This method will be called by the {@link EventLoop} of the {@link Channel}.
     */
    void channelReleased(Channel ch) throws Exception;

    /**
     * Called once a {@link Channel} was acquired by calling {@link ChannelPool#acquire()} or
     * {@link ChannelPool#acquire(Promise)}.
     *
     * This method will be called by the {@link EventLoop} of the {@link Channel}.
     */
    void channelAcquired(Channel ch) throws Exception;

    /**
     * Called once a new {@link Channel} is created in the {@link ChannelPool}.
     *
     * This method will be called by the {@link EventLoop} of the {@link Channel}.
     */
    void channelCreated(Channel ch) throws Exception;
}

the type of the {@link ChannelPool} + */ +public interface ChannelPoolMap { + /** + * Return the {@link ChannelPool} for the {@code code}. This will never return {@code null}, + * but create a new {@link ChannelPool} if non exists for they requested {@code key}. + * + * Please note that {@code null} keys are not allowed. + */ + P get(K key); + + /** + * Returns {@code true} if a {@link ChannelPool} exists for the given {@code key}. + * + * Please note that {@code null} keys are not allowed. + */ + boolean contains(K key); +} diff --git a/netty-channel/src/main/java/io/netty/channel/pool/FixedChannelPool.java b/netty-channel/src/main/java/io/netty/channel/pool/FixedChannelPool.java new file mode 100644 index 0000000..c0cc702 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/pool/FixedChannelPool.java @@ -0,0 +1,533 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.pool; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ObjectUtil; + +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.Callable; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +/** + * {@link ChannelPool} implementation that takes another {@link ChannelPool} implementation and enforce a maximum + * number of concurrent connections. + */ +public class FixedChannelPool extends SimpleChannelPool { + + public enum AcquireTimeoutAction { + /** + * Create a new connection when the timeout is detected. + */ + NEW, + + /** + * Fail the {@link Future} of the acquire call with a {@link TimeoutException}. + */ + FAIL + } + + private final EventExecutor executor; + private final long acquireTimeoutNanos; + private final Runnable timeoutTask; + + // There is no need to worry about synchronization as everything that modified the queue or counts is done + // by the above EventExecutor. + private final Queue pendingAcquireQueue = new ArrayDeque(); + private final int maxConnections; + private final int maxPendingAcquires; + private final AtomicInteger acquiredChannelCount = new AtomicInteger(); + private int pendingAcquireCount; + private boolean closed; + + /** + * Creates a new instance using the {@link ChannelHealthChecker#ACTIVE}. 
+ * + * @param bootstrap the {@link Bootstrap} that is used for connections + * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions + * @param maxConnections the number of maximal active connections, once this is reached new tries to acquire + * a {@link Channel} will be delayed until a connection is returned to the pool again. + */ + public FixedChannelPool(Bootstrap bootstrap, + ChannelPoolHandler handler, int maxConnections) { + this(bootstrap, handler, maxConnections, Integer.MAX_VALUE); + } + + /** + * Creates a new instance using the {@link ChannelHealthChecker#ACTIVE}. + * + * @param bootstrap the {@link Bootstrap} that is used for connections + * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions + * @param maxConnections the number of maximal active connections, once this is reached new tries to + * acquire a {@link Channel} will be delayed until a connection is returned to the + * pool again. + * @param maxPendingAcquires the maximum number of pending acquires. Once this is exceed acquire tries will + * be failed. + */ + public FixedChannelPool(Bootstrap bootstrap, + ChannelPoolHandler handler, int maxConnections, int maxPendingAcquires) { + this(bootstrap, handler, ChannelHealthChecker.ACTIVE, null, -1, maxConnections, maxPendingAcquires); + } + + /** + * Creates a new instance. + * + * @param bootstrap the {@link Bootstrap} that is used for connections + * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions + * @param healthCheck the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is + * still healthy when obtain from the {@link ChannelPool} + * @param action the {@link AcquireTimeoutAction} to use or {@code null} if non should be used. + * In this case {@param acquireTimeoutMillis} must be {@code -1}. 
+ * @param acquireTimeoutMillis the time (in milliseconds) after which an pending acquire must complete or + * the {@link AcquireTimeoutAction} takes place. + * @param maxConnections the number of maximal active connections, once this is reached new tries to + * acquire a {@link Channel} will be delayed until a connection is returned to the + * pool again. + * @param maxPendingAcquires the maximum number of pending acquires. Once this is exceed acquire tries will + * be failed. + */ + public FixedChannelPool(Bootstrap bootstrap, + ChannelPoolHandler handler, + ChannelHealthChecker healthCheck, AcquireTimeoutAction action, + final long acquireTimeoutMillis, + int maxConnections, int maxPendingAcquires) { + this(bootstrap, handler, healthCheck, action, acquireTimeoutMillis, maxConnections, maxPendingAcquires, true); + } + + /** + * Creates a new instance. + * + * @param bootstrap the {@link Bootstrap} that is used for connections + * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions + * @param healthCheck the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is + * still healthy when obtain from the {@link ChannelPool} + * @param action the {@link AcquireTimeoutAction} to use or {@code null} if non should be used. + * In this case {@param acquireTimeoutMillis} must be {@code -1}. + * @param acquireTimeoutMillis the time (in milliseconds) after which an pending acquire must complete or + * the {@link AcquireTimeoutAction} takes place. + * @param maxConnections the number of maximal active connections, once this is reached new tries to + * acquire a {@link Channel} will be delayed until a connection is returned to the + * pool again. + * @param maxPendingAcquires the maximum number of pending acquires. Once this is exceed acquire tries will + * be failed. + * @param releaseHealthCheck will check channel health before offering back if this parameter set to + * {@code true}. 
+ */ + public FixedChannelPool(Bootstrap bootstrap, + ChannelPoolHandler handler, + ChannelHealthChecker healthCheck, AcquireTimeoutAction action, + final long acquireTimeoutMillis, + int maxConnections, int maxPendingAcquires, final boolean releaseHealthCheck) { + this(bootstrap, handler, healthCheck, action, acquireTimeoutMillis, maxConnections, maxPendingAcquires, + releaseHealthCheck, true); + } + + /** + * Creates a new instance. + * + * @param bootstrap the {@link Bootstrap} that is used for connections + * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions + * @param healthCheck the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is + * still healthy when obtain from the {@link ChannelPool} + * @param action the {@link AcquireTimeoutAction} to use or {@code null} if non should be used. + * In this case {@param acquireTimeoutMillis} must be {@code -1}. + * @param acquireTimeoutMillis the time (in milliseconds) after which an pending acquire must complete or + * the {@link AcquireTimeoutAction} takes place. + * @param maxConnections the number of maximal active connections, once this is reached new tries to + * acquire a {@link Channel} will be delayed until a connection is returned to the + * pool again. + * @param maxPendingAcquires the maximum number of pending acquires. Once this is exceed acquire tries will + * be failed. + * @param releaseHealthCheck will check channel health before offering back if this parameter set to + * {@code true}. + * @param lastRecentUsed {@code true} {@link Channel} selection will be LIFO, if {@code false} FIFO. 
+ */ + public FixedChannelPool(Bootstrap bootstrap, + ChannelPoolHandler handler, + ChannelHealthChecker healthCheck, AcquireTimeoutAction action, + final long acquireTimeoutMillis, + int maxConnections, int maxPendingAcquires, + boolean releaseHealthCheck, boolean lastRecentUsed) { + super(bootstrap, handler, healthCheck, releaseHealthCheck, lastRecentUsed); + checkPositive(maxConnections, "maxConnections"); + checkPositive(maxPendingAcquires, "maxPendingAcquires"); + if (action == null && acquireTimeoutMillis == -1) { + timeoutTask = null; + acquireTimeoutNanos = -1; + } else if (action == null && acquireTimeoutMillis != -1) { + throw new NullPointerException("action"); + } else if (action != null && acquireTimeoutMillis < 0) { + throw new IllegalArgumentException("acquireTimeoutMillis: " + acquireTimeoutMillis + " (expected: >= 0)"); + } else { + acquireTimeoutNanos = TimeUnit.MILLISECONDS.toNanos(acquireTimeoutMillis); + switch (action) { + case FAIL: + timeoutTask = new TimeoutTask() { + @Override + public void onTimeout(AcquireTask task) { + // Fail the promise as we timed out. + task.promise.setFailure(new AcquireTimeoutException()); + } + }; + break; + case NEW: + timeoutTask = new TimeoutTask() { + @Override + public void onTimeout(AcquireTask task) { + // Increment the acquire count and delegate to super to actually acquire a Channel which will + // create a new connection. + task.acquired(); + + FixedChannelPool.super.acquire(task.promise); + } + }; + break; + default: + throw new Error(); + } + } + executor = bootstrap.config().group().next(); + this.maxConnections = maxConnections; + this.maxPendingAcquires = maxPendingAcquires; + } + + /** Returns the number of acquired channels that this pool thinks it has. 
*/ + public int acquiredChannelCount() { + return acquiredChannelCount.get(); + } + + @Override + public Future acquire(final Promise promise) { + try { + if (executor.inEventLoop()) { + acquire0(promise); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + acquire0(promise); + } + }); + } + } catch (Throwable cause) { + promise.tryFailure(cause); + } + return promise; + } + + private void acquire0(final Promise promise) { + try { + assert executor.inEventLoop(); + + if (closed) { + promise.setFailure(new IllegalStateException("FixedChannelPool was closed")); + return; + } + if (acquiredChannelCount.get() < maxConnections) { + assert acquiredChannelCount.get() >= 0; + + // We need to create a new promise as we need to ensure the AcquireListener runs in the correct + // EventLoop + Promise p = executor.newPromise(); + AcquireListener l = new AcquireListener(promise); + l.acquired(); + p.addListener(l); + super.acquire(p); + } else { + if (pendingAcquireCount >= maxPendingAcquires) { + tooManyOutstanding(promise); + } else { + AcquireTask task = new AcquireTask(promise); + if (pendingAcquireQueue.offer(task)) { + ++pendingAcquireCount; + + if (timeoutTask != null) { + task.timeoutFuture = executor.schedule(timeoutTask, acquireTimeoutNanos, + TimeUnit.NANOSECONDS); + } + } else { + tooManyOutstanding(promise); + } + } + + assert pendingAcquireCount > 0; + } + } catch (Throwable cause) { + promise.tryFailure(cause); + } + } + + private void tooManyOutstanding(Promise promise) { + promise.setFailure(new IllegalStateException("Too many outstanding acquire operations")); + } + + @Override + public Future release(final Channel channel, final Promise promise) { + ObjectUtil.checkNotNull(promise, "promise"); + final Promise p = executor.newPromise(); + super.release(channel, p.addListener(new FutureListener() { + + @Override + public void operationComplete(Future future) { + try { + assert executor.inEventLoop(); + + if (closed) { + // Since 
the pool is closed, we have no choice but to close the channel + channel.close(); + promise.setFailure(new IllegalStateException("FixedChannelPool was closed")); + return; + } + + if (future.isSuccess()) { + decrementAndRunTaskQueue(); + promise.setSuccess(null); + } else { + Throwable cause = future.cause(); + // Check if the exception was not because of we passed the Channel to the wrong pool. + if (!(cause instanceof IllegalArgumentException)) { + decrementAndRunTaskQueue(); + } + promise.setFailure(future.cause()); + } + } catch (Throwable cause) { + promise.tryFailure(cause); + } + } + })); + return promise; + } + + private void decrementAndRunTaskQueue() { + // We should never have a negative value. + int currentCount = acquiredChannelCount.decrementAndGet(); + assert currentCount >= 0; + + // Run the pending acquire tasks before notify the original promise so if the user would + // try to acquire again from the ChannelFutureListener and the pendingAcquireCount is >= + // maxPendingAcquires we may be able to run some pending tasks first and so allow to add + // more. + runTaskQueue(); + } + + private void runTaskQueue() { + while (acquiredChannelCount.get() < maxConnections) { + AcquireTask task = pendingAcquireQueue.poll(); + if (task == null) { + break; + } + + // Cancel the timeout if one was scheduled + ScheduledFuture timeoutFuture = task.timeoutFuture; + if (timeoutFuture != null) { + timeoutFuture.cancel(false); + } + + --pendingAcquireCount; + task.acquired(); + + super.acquire(task.promise); + } + + // We should never have a negative value. 
+ assert pendingAcquireCount >= 0; + assert acquiredChannelCount.get() >= 0; + } + + // AcquireTask extends AcquireListener to reduce object creations and so GC pressure + private final class AcquireTask extends AcquireListener { + final Promise promise; + final long expireNanoTime = System.nanoTime() + acquireTimeoutNanos; + ScheduledFuture timeoutFuture; + + AcquireTask(Promise promise) { + super(promise); + // We need to create a new promise as we need to ensure the AcquireListener runs in the correct + // EventLoop. + this.promise = executor.newPromise().addListener(this); + } + } + + private abstract class TimeoutTask implements Runnable { + @Override + public final void run() { + assert executor.inEventLoop(); + long nanoTime = System.nanoTime(); + for (;;) { + AcquireTask task = pendingAcquireQueue.peek(); + // Compare nanoTime as descripted in the javadocs of System.nanoTime() + // + // See https://docs.oracle.com/javase/7/docs/api/java/lang/System.html#nanoTime() + // See https://github.com/netty/netty/issues/3705 + if (task == null || nanoTime - task.expireNanoTime < 0) { + break; + } + pendingAcquireQueue.remove(); + + --pendingAcquireCount; + onTimeout(task); + } + } + + public abstract void onTimeout(AcquireTask task); + } + + private class AcquireListener implements FutureListener { + private final Promise originalPromise; + protected boolean acquired; + + AcquireListener(Promise originalPromise) { + this.originalPromise = originalPromise; + } + + @Override + public void operationComplete(Future future) throws Exception { + try { + assert executor.inEventLoop(); + + if (closed) { + if (future.isSuccess()) { + // Since the pool is closed, we have no choice but to close the channel + future.getNow().close(); + } + originalPromise.setFailure(new IllegalStateException("FixedChannelPool was closed")); + return; + } + + if (future.isSuccess()) { + originalPromise.setSuccess(future.getNow()); + } else { + if (acquired) { + decrementAndRunTaskQueue(); + } 
else { + runTaskQueue(); + } + + originalPromise.setFailure(future.cause()); + } + } catch (Throwable cause) { + originalPromise.tryFailure(cause); + } + } + + public void acquired() { + if (acquired) { + return; + } + acquiredChannelCount.incrementAndGet(); + acquired = true; + } + } + + @Override + public void close() { + try { + closeAsync().await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + throw new RuntimeException(e); + } + } + + /** + * Closes the pool in an async manner. + * + * @return Future which represents completion of the close task + */ + @Override + public Future closeAsync() { + if (executor.inEventLoop()) { + return close0(); + } else { + final Promise closeComplete = executor.newPromise(); + executor.execute(new Runnable() { + @Override + public void run() { + close0().addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + if (f.isSuccess()) { + closeComplete.setSuccess(null); + } else { + closeComplete.setFailure(f.cause()); + } + } + }); + } + }); + return closeComplete; + } + } + + private Future close0() { + assert executor.inEventLoop(); + + if (!closed) { + closed = true; + for (;;) { + AcquireTask task = pendingAcquireQueue.poll(); + if (task == null) { + break; + } + ScheduledFuture f = task.timeoutFuture; + if (f != null) { + f.cancel(false); + } + task.promise.setFailure(new ClosedChannelException()); + } + acquiredChannelCount.set(0); + pendingAcquireCount = 0; + + // Ensure we dispatch this on another Thread as close0 will be called from the EventExecutor and we need + // to ensure we will not block in a EventExecutor. 
+ return GlobalEventExecutor.INSTANCE.submit(new Callable() { + @Override + public Void call() throws Exception { + FixedChannelPool.super.close(); + return null; + } + }); + } + + return GlobalEventExecutor.INSTANCE.newSucceededFuture(null); + } + + private static final class AcquireTimeoutException extends TimeoutException { + + private AcquireTimeoutException() { + super("Acquire operation took longer then configured maximum time"); + } + + // Suppress a warning since the method doesn't need synchronization + @Override + public Throwable fillInStackTrace() { + return this; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/pool/SimpleChannelPool.java b/netty-channel/src/main/java/io/netty/channel/pool/SimpleChannelPool.java new file mode 100644 index 0000000..1c972e7 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/pool/SimpleChannelPool.java @@ -0,0 +1,440 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.channel.pool;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoop;
import io.netty.util.AttributeKey;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.GlobalEventExecutor;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;

import java.util.Deque;
import java.util.concurrent.Callable;

import static io.netty.util.internal.ObjectUtil.*;

/**
 * Simple {@link ChannelPool} implementation which will create new {@link Channel}s if someone tries to acquire
 * a {@link Channel} but none is in the pool atm. No limit on the maximal concurrent {@link Channel}s is enforced.
 *
 * This implementation uses LIFO order for {@link Channel}s in the {@link ChannelPool}.
 *
 */
public class SimpleChannelPool implements ChannelPool {
    // NOTE(review): generic type parameters (e.g. AttributeKey<SimpleChannelPool>, Deque<Channel>,
    // Future<Channel>) appear to have been stripped in this copy — confirm against upstream Netty 4.1.
    // Attribute used to tag a Channel with the pool that owns it, checked again on release.
    private static final AttributeKey POOL_KEY =
            AttributeKey.newInstance("io.netty.channel.pool.SimpleChannelPool");
    // Idle channels ready for reuse; concurrent deque so poll/offer are thread-safe.
    private final Deque deque = PlatformDependent.newConcurrentDeque();
    private final ChannelPoolHandler handler;
    private final ChannelHealthChecker healthCheck;
    private final Bootstrap bootstrap;
    private final boolean releaseHealthCheck;
    // true = LIFO (pollLast), false = FIFO (pollFirst).
    private final boolean lastRecentUsed;

    /**
     * Creates a new instance using the {@link ChannelHealthChecker#ACTIVE}.
     *
     * @param bootstrap the {@link Bootstrap} that is used for connections
     * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions
     */
    public SimpleChannelPool(Bootstrap bootstrap, final ChannelPoolHandler handler) {
        this(bootstrap, handler, ChannelHealthChecker.ACTIVE);
    }

    /**
     * Creates a new instance.
     *
     * @param bootstrap the {@link Bootstrap} that is used for connections
     * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions
     * @param healthCheck the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is
     *                    still healthy when obtained from the {@link ChannelPool}
     */
    public SimpleChannelPool(Bootstrap bootstrap, final ChannelPoolHandler handler, ChannelHealthChecker healthCheck) {
        this(bootstrap, handler, healthCheck, true);
    }

    /**
     * Creates a new instance.
     *
     * @param bootstrap the {@link Bootstrap} that is used for connections
     * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions
     * @param healthCheck the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is
     *                    still healthy when obtained from the {@link ChannelPool}
     * @param releaseHealthCheck will check channel health before offering back if this parameter set to {@code true};
     *                           otherwise, channel health is only checked at acquisition time
     */
    public SimpleChannelPool(Bootstrap bootstrap, final ChannelPoolHandler handler, ChannelHealthChecker healthCheck,
                             boolean releaseHealthCheck) {
        this(bootstrap, handler, healthCheck, releaseHealthCheck, true);
    }

    /**
     * Creates a new instance.
     *
     * @param bootstrap the {@link Bootstrap} that is used for connections
     * @param handler the {@link ChannelPoolHandler} that will be notified for the different pool actions
     * @param healthCheck the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is
     *                    still healthy when obtained from the {@link ChannelPool}
     * @param releaseHealthCheck will check channel health before offering back if this parameter set to {@code true};
     *                           otherwise, channel health is only checked at acquisition time
     * @param lastRecentUsed {@code true} {@link Channel} selection will be LIFO, if {@code false} FIFO.
     */
    public SimpleChannelPool(Bootstrap bootstrap, final ChannelPoolHandler handler, ChannelHealthChecker healthCheck,
                             boolean releaseHealthCheck, boolean lastRecentUsed) {
        this.handler = checkNotNull(handler, "handler");
        this.healthCheck = checkNotNull(healthCheck, "healthCheck");
        this.releaseHealthCheck = releaseHealthCheck;
        // Clone the original Bootstrap as we want to set our own handler
        this.bootstrap = checkNotNull(bootstrap, "bootstrap").clone();
        this.bootstrap.handler(new ChannelInitializer() {
            @Override
            protected void initChannel(Channel ch) throws Exception {
                assert ch.eventLoop().inEventLoop();
                handler.channelCreated(ch);
            }
        });
        this.lastRecentUsed = lastRecentUsed;
    }

    /**
     * Returns the {@link Bootstrap} this pool will use to open new connections.
     *
     * @return the {@link Bootstrap} this pool will use to open new connections
     */
    protected Bootstrap bootstrap() {
        return bootstrap;
    }

    /**
     * Returns the {@link ChannelPoolHandler} that will be notified for the different pool actions.
     *
     * @return the {@link ChannelPoolHandler} that will be notified for the different pool actions
     */
    protected ChannelPoolHandler handler() {
        return handler;
    }

    /**
     * Returns the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is healthy.
     *
     * @return the {@link ChannelHealthChecker} that will be used to check if a {@link Channel} is healthy
     */
    protected ChannelHealthChecker healthChecker() {
        return healthCheck;
    }

    /**
     * Indicates whether this pool will check the health of channels before offering them back into the pool.
     *
     * @return {@code true} if this pool will check the health of channels before offering them back into the pool, or
     * {@code false} if channel health is only checked at acquisition time
     */
    protected boolean releaseHealthCheck() {
        return releaseHealthCheck;
    }

    @Override
    public final Future acquire() {
        return acquire(bootstrap.config().group().next().newPromise());
    }

    @Override
    public Future acquire(final Promise promise) {
        return acquireHealthyFromPoolOrNew(checkNotNull(promise, "promise"));
    }

    /**
     * Tries to retrieve healthy channel from the pool if any or creates a new channel otherwise.
     * @param promise the promise to provide acquire result.
     * @return future for acquiring a channel.
     */
    private Future acquireHealthyFromPoolOrNew(final Promise promise) {
        try {
            final Channel ch = pollChannel();
            if (ch == null) {
                // No Channel left in the pool bootstrap a new Channel
                Bootstrap bs = bootstrap.clone();
                bs.attr(POOL_KEY, this);
                ChannelFuture f = connectChannel(bs);
                if (f.isDone()) {
                    notifyConnect(f, promise);
                } else {
                    f.addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture future) throws Exception {
                            notifyConnect(future, promise);
                        }
                    });
                }
            } else {
                // A pooled Channel exists: verify health on its own event loop before handing it out.
                EventLoop loop = ch.eventLoop();
                if (loop.inEventLoop()) {
                    doHealthCheck(ch, promise);
                } else {
                    loop.execute(new Runnable() {
                        @Override
                        public void run() {
                            doHealthCheck(ch, promise);
                        }
                    });
                }
            }
        } catch (Throwable cause) {
            promise.tryFailure(cause);
        }
        return promise;
    }

    // Completes the acquire promise once the connect attempt for a freshly bootstrapped Channel finished.
    private void notifyConnect(ChannelFuture future, Promise promise) {
        Channel channel = null;
        try {
            if (future.isSuccess()) {
                channel = future.channel();
                handler.channelAcquired(channel);
                if (!promise.trySuccess(channel)) {
                    // Promise was completed in the meantime (like cancelled), just release the channel again
                    release(channel);
                }
            } else {
                promise.tryFailure(future.cause());
            }
        } catch (Throwable cause) {
            closeAndFail(channel, cause, promise);
        }
    }

    // Runs the health check for a pooled Channel on its event loop; completes asynchronously if needed.
    private void doHealthCheck(final Channel channel, final Promise promise) {
        try {
            assert channel.eventLoop().inEventLoop();
            Future f = healthCheck.isHealthy(channel);
            if (f.isDone()) {
                notifyHealthCheck(f, channel, promise);
            } else {
                f.addListener(new FutureListener() {
                    @Override
                    public void operationComplete(Future future) {
                        notifyHealthCheck(future, channel, promise);
                    }
                });
            }
        } catch (Throwable cause) {
            closeAndFail(channel, cause, promise);
        }
    }

    // Hands the Channel to the caller if healthy; otherwise closes it and retries from the pool.
    private void notifyHealthCheck(Future future, Channel channel, Promise promise) {
        try {
            assert channel.eventLoop().inEventLoop();
            if (future.isSuccess() && future.getNow()) {
                channel.attr(POOL_KEY).set(this);
                handler.channelAcquired(channel);
                promise.setSuccess(channel);
            } else {
                closeChannel(channel);
                acquireHealthyFromPoolOrNew(promise);
            }
        } catch (Throwable cause) {
            closeAndFail(channel, cause, promise);
        }
    }

    /**
     * Bootstrap a new {@link Channel}. The default implementation uses {@link Bootstrap#connect()}, sub-classes may
     * override this.
     *
     * The {@link Bootstrap} that is passed in here is cloned via {@link Bootstrap#clone()}, so it is safe to modify.
     */
    protected ChannelFuture connectChannel(Bootstrap bs) {
        return bs.connect();
    }

    @Override
    public final Future release(Channel channel) {
        return release(channel, channel.eventLoop().newPromise());
    }

    @Override
    public Future release(final Channel channel, final Promise promise) {
        try {
            checkNotNull(channel, "channel");
            checkNotNull(promise, "promise");
            // Release bookkeeping must happen on the channel's own event loop.
            EventLoop loop = channel.eventLoop();
            if (loop.inEventLoop()) {
                doReleaseChannel(channel, promise);
            } else {
                loop.execute(new Runnable() {
                    @Override
                    public void run() {
                        doReleaseChannel(channel, promise);
                    }
                });
            }
        } catch (Throwable cause) {
            closeAndFail(channel, cause, promise);
        }
        return promise;
    }

    private void doReleaseChannel(Channel channel, Promise promise) {
        try {
            assert channel.eventLoop().inEventLoop();
            // Remove the POOL_KEY attribute from the Channel and check if it was acquired from this pool, if not fail.
            if (channel.attr(POOL_KEY).getAndSet(null) != this) {
                closeAndFail(channel,
                             // Better include a stacktrace here as this is an user error.
                             new IllegalArgumentException(
                                     "Channel " + channel + " was not acquired from this ChannelPool"),
                             promise);
            } else {
                if (releaseHealthCheck) {
                    doHealthCheckOnRelease(channel, promise);
                } else {
                    releaseAndOffer(channel, promise);
                }
            }
        } catch (Throwable cause) {
            closeAndFail(channel, cause, promise);
        }
    }

    private void doHealthCheckOnRelease(final Channel channel, final Promise promise) throws Exception {
        final Future f = healthCheck.isHealthy(channel);
        if (f.isDone()) {
            releaseAndOfferIfHealthy(channel, promise, f);
        } else {
            f.addListener(new FutureListener() {
                @Override
                public void operationComplete(Future future) throws Exception {
                    // Passes the captured `f` rather than `future`; they are the same future instance here.
                    releaseAndOfferIfHealthy(channel, promise, f);
                }
            });
        }
    }

    /**
     * Adds the channel back to the pool only if the channel is healthy.
     * @param channel the channel to put back to the pool
     * @param promise offer operation promise.
     * @param future the future that contains information if channel is healthy or not.
     * @throws Exception in case when failed to notify handler about release operation.
     */
    private void releaseAndOfferIfHealthy(Channel channel, Promise promise, Future future) {
        try {
            if (future.getNow()) { //channel turns out to be healthy, offering and releasing it.
                releaseAndOffer(channel, promise);
            } else { //channel not healthy, just releasing it.
                handler.channelReleased(channel);
                promise.setSuccess(null);
            }
        } catch (Throwable cause) {
            closeAndFail(channel, cause, promise);
        }
    }

    // Offers the channel back into the deque; fails the release if the pool refuses it.
    private void releaseAndOffer(Channel channel, Promise promise) throws Exception {
        if (offerChannel(channel)) {
            handler.channelReleased(channel);
            promise.setSuccess(null);
        } else {
            closeAndFail(channel, new ChannelPoolFullException(), promise);
        }
    }

    // Detaches the channel from this pool and closes it.
    private void closeChannel(Channel channel) throws Exception {
        channel.attr(POOL_KEY).getAndSet(null);
        channel.close();
    }

    // Best-effort close of the channel, then fails the promise with the given cause.
    private void closeAndFail(Channel channel, Throwable cause, Promise promise) {
        if (channel != null) {
            try {
                closeChannel(channel);
            } catch (Throwable t) {
                promise.tryFailure(t);
            }
        }
        promise.tryFailure(cause);
    }

    /**
     * Poll a {@link Channel} out of the internal storage to reuse it. This will return {@code null} if no
     * {@link Channel} is ready to be reused.
     *
     * Sub-classes may override {@link #pollChannel()} and {@link #offerChannel(Channel)}. Be aware that
     * implementations of these methods needs to be thread-safe!
     */
    protected Channel pollChannel() {
        return lastRecentUsed ? deque.pollLast() : deque.pollFirst();
    }

    /**
     * Offer a {@link Channel} back to the internal storage. This will return {@code true} if the {@link Channel}
     * could be added, {@code false} otherwise.
     *
     * Sub-classes may override {@link #pollChannel()} and {@link #offerChannel(Channel)}. Be aware that
     * implementations of these methods needs to be thread-safe!
     */
    protected boolean offerChannel(Channel channel) {
        return deque.offer(channel);
    }

    @Override
    public void close() {
        for (;;) {
            Channel channel = pollChannel();
            if (channel == null) {
                break;
            }
            // Just ignore any errors that are reported back from close().
            channel.close().awaitUninterruptibly();
        }
    }

    /**
     * Closes the pool in an async manner.
     *
     * @return Future which represents completion of the close task
     */
    public Future closeAsync() {
        // Execute close asynchronously in case this is being invoked on an eventloop to avoid blocking
        return GlobalEventExecutor.INSTANCE.submit(new Callable() {
            @Override
            public Void call() throws Exception {
                close();
                return null;
            }
        });
    }

    // Thrown when the internal storage rejects a released channel.
    private static final class ChannelPoolFullException extends IllegalStateException {

        private ChannelPoolFullException() {
            super("ChannelPool full");
        }

        // Suppress a warning since the method doesn't need synchronization
        @Override
        public Throwable fillInStackTrace() {
            return this;
        }
    }
}
+ */ +package io.netty.channel.pool; diff --git a/netty-channel/src/main/java/io/netty/channel/socket/ChannelInputShutdownEvent.java b/netty-channel/src/main/java/io/netty/channel/socket/ChannelInputShutdownEvent.java new file mode 100644 index 0000000..9e958b4 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/ChannelInputShutdownEvent.java @@ -0,0 +1,36 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; + +/** + * Special event which will be fired and passed to the + * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} methods once the input of + * a {@link SocketChannel} was shutdown and the {@link SocketChannelConfig#isAllowHalfClosure()} method returns + * {@code true}. 
/**
 * Special event fired to the
 * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} methods once the input side of
 * a {@link SocketChannel} was shut down while {@link SocketChannelConfig#isAllowHalfClosure()} returns
 * {@code true}.
 */
public final class ChannelInputShutdownEvent {

    /**
     * The singleton instance to use.
     */
    @SuppressWarnings("InstantiationOfUtilityClass")
    public static final ChannelInputShutdownEvent INSTANCE = new ChannelInputShutdownEvent();

    private ChannelInputShutdownEvent() {
    }
}
/**
 * User event signaling that the channel's input side was already shut down when a further shutdown
 * attempt was made. This typically indicates that there is no more data to read.
 */
public final class ChannelInputShutdownReadComplete {

    /** The sole instance of this event. */
    public static final ChannelInputShutdownReadComplete INSTANCE = new ChannelInputShutdownReadComplete();

    private ChannelInputShutdownReadComplete() {
    }
}
+ */ +@UnstableApi +public final class ChannelOutputShutdownEvent { + public static final ChannelOutputShutdownEvent INSTANCE = new ChannelOutputShutdownEvent(); + + private ChannelOutputShutdownEvent() { + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java b/netty-channel/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java new file mode 100644 index 0000000..03edb3b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/ChannelOutputShutdownException.java @@ -0,0 +1,38 @@ + + +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.util.internal.UnstableApi; + +import java.io.IOException; + +/** + * Used to fail pending writes when a channel's output has been shutdown. 
 */
@UnstableApi
public final class ChannelOutputShutdownException extends IOException {
    private static final long serialVersionUID = 6712549938359321378L;

    public ChannelOutputShutdownException(String msg) {
        super(msg);
    }

    public ChannelOutputShutdownException(String msg, Throwable cause) {
        super(msg, cause);
    }
}

package io.netty.channel.socket;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;

import java.net.InetAddress;
import java.net.InetSocketAddress;
import java.net.NetworkInterface;

/**
 * A UDP/IP {@link Channel}.
 */
public interface DatagramChannel extends Channel {
    @Override
    DatagramChannelConfig config();
    @Override
    InetSocketAddress localAddress();
    @Override
    InetSocketAddress remoteAddress();

    /**
     * Return {@code true} if the {@link DatagramChannel} is connected to the remote peer.
     */
    boolean isConnected();

    /**
     * Joins a multicast group and notifies the {@link ChannelFuture} once the operation completes.
     */
    ChannelFuture joinGroup(InetAddress multicastAddress);

    /**
     * Joins a multicast group and notifies the {@link ChannelFuture} once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture joinGroup(InetAddress multicastAddress, ChannelPromise future);

    /**
     * Joins the specified multicast group at the specified interface and notifies the {@link ChannelFuture}
     * once the operation completes.
     */
    ChannelFuture joinGroup(InetSocketAddress multicastAddress, NetworkInterface networkInterface);

    /**
     * Joins the specified multicast group at the specified interface and notifies the {@link ChannelFuture}
     * once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture joinGroup(
            InetSocketAddress multicastAddress, NetworkInterface networkInterface, ChannelPromise future);

    /**
     * Joins the specified multicast group at the specified interface, restricted to the given source address,
     * and notifies the {@link ChannelFuture} once the operation completes.
     */
    ChannelFuture joinGroup(InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source);

    /**
     * Joins the specified multicast group at the specified interface, restricted to the given source address,
     * and notifies the {@link ChannelFuture} once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture joinGroup(
            InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source, ChannelPromise future);

    /**
     * Leaves a multicast group and notifies the {@link ChannelFuture} once the operation completes.
     */
    ChannelFuture leaveGroup(InetAddress multicastAddress);

    /**
     * Leaves a multicast group and notifies the {@link ChannelFuture} once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture leaveGroup(InetAddress multicastAddress, ChannelPromise future);

    /**
     * Leaves a multicast group on a specified local interface and notifies the {@link ChannelFuture} once the
     * operation completes.
     */
    ChannelFuture leaveGroup(InetSocketAddress multicastAddress, NetworkInterface networkInterface);

    /**
     * Leaves a multicast group on a specified local interface and notifies the {@link ChannelFuture} once the
     * operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture leaveGroup(
            InetSocketAddress multicastAddress, NetworkInterface networkInterface, ChannelPromise future);

    /**
     * Leave the specified multicast group at the specified interface using the specified source and notifies
     * the {@link ChannelFuture} once the operation completes.
     *
     */
    ChannelFuture leaveGroup(
            InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source);

    /**
     * Leave the specified multicast group at the specified interface using the specified source and notifies
     * the {@link ChannelFuture} once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture leaveGroup(
            InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source,
            ChannelPromise future);

    /**
     * Block the given sourceToBlock address for the given multicastAddress on the given networkInterface and notifies
     * the {@link ChannelFuture} once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture block(
            InetAddress multicastAddress, NetworkInterface networkInterface,
            InetAddress sourceToBlock);

    /**
     * Block the given sourceToBlock address for the given multicastAddress on the given networkInterface and notifies
     * the {@link ChannelFuture} once the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture block(
            InetAddress multicastAddress, NetworkInterface networkInterface,
            InetAddress sourceToBlock, ChannelPromise future);

    /**
     * Block the given sourceToBlock address for the given multicastAddress and notifies the {@link ChannelFuture} once
     * the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture block(InetAddress multicastAddress, InetAddress sourceToBlock);

    /**
     * Block the given sourceToBlock address for the given multicastAddress and notifies the {@link ChannelFuture} once
     * the operation completes.
     *
     * The given {@link ChannelFuture} will be notified and also returned.
     */
    ChannelFuture block(
            InetAddress multicastAddress, InetAddress sourceToBlock, ChannelPromise future);
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.net.StandardSocketOptions; + +/** + * A {@link ChannelConfig} for a {@link DatagramChannel}. + * + *

Available options

+ * + * In addition to the options provided by {@link ChannelConfig}, + * {@link DatagramChannelConfig} allows the following options in the option map: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameAssociated setter method
{@link ChannelOption#SO_BROADCAST}{@link #setBroadcast(boolean)}
{@link ChannelOption#IP_MULTICAST_ADDR}{@link #setInterface(InetAddress)}
{@link ChannelOption#IP_MULTICAST_LOOP_DISABLED}{@link #setLoopbackModeDisabled(boolean)}
{@link ChannelOption#IP_MULTICAST_IF}{@link #setNetworkInterface(NetworkInterface)}
{@link ChannelOption#SO_REUSEADDR}{@link #setReuseAddress(boolean)}
{@link ChannelOption#SO_RCVBUF}{@link #setReceiveBufferSize(int)}
{@link ChannelOption#SO_SNDBUF}{@link #setSendBufferSize(int)}
{@link ChannelOption#IP_MULTICAST_TTL}{@link #setTimeToLive(int)}
{@link ChannelOption#IP_TOS}{@link #setTrafficClass(int)}
+ */ +public interface DatagramChannelConfig extends ChannelConfig { + + /** + * Gets the {@link StandardSocketOptions#SO_SNDBUF} option. + */ + int getSendBufferSize(); + + /** + * Sets the {@link StandardSocketOptions#SO_SNDBUF} option. + */ + DatagramChannelConfig setSendBufferSize(int sendBufferSize); + + /** + * Gets the {@link StandardSocketOptions#SO_RCVBUF} option. + */ + int getReceiveBufferSize(); + + /** + * Sets the {@link StandardSocketOptions#SO_RCVBUF} option. + */ + DatagramChannelConfig setReceiveBufferSize(int receiveBufferSize); + + /** + * Gets the {@link StandardSocketOptions#IP_TOS} option. + */ + int getTrafficClass(); + + /** + * Sets the {@link StandardSocketOptions#IP_TOS} option. + */ + DatagramChannelConfig setTrafficClass(int trafficClass); + + /** + * Gets the {@link StandardSocketOptions#SO_REUSEADDR} option. + */ + boolean isReuseAddress(); + + /** + * Gets the {@link StandardSocketOptions#SO_REUSEADDR} option. + */ + DatagramChannelConfig setReuseAddress(boolean reuseAddress); + + /** + * Gets the {@link StandardSocketOptions#SO_BROADCAST} option. + */ + boolean isBroadcast(); + + /** + * Sets the {@link StandardSocketOptions#SO_BROADCAST} option. + */ + DatagramChannelConfig setBroadcast(boolean broadcast); + + /** + * Gets the {@link StandardSocketOptions#IP_MULTICAST_LOOP} option. + * + * @return {@code true} if and only if the loopback mode has been disabled + */ + boolean isLoopbackModeDisabled(); + + /** + * Sets the {@link StandardSocketOptions#IP_MULTICAST_LOOP} option. + * + * @param loopbackModeDisabled + * {@code true} if and only if the loopback mode has been disabled + */ + DatagramChannelConfig setLoopbackModeDisabled(boolean loopbackModeDisabled); + + /** + * Gets the {@link StandardSocketOptions#IP_MULTICAST_TTL} option. + */ + int getTimeToLive(); + + /** + * Sets the {@link StandardSocketOptions#IP_MULTICAST_TTL} option. 
+ */ + DatagramChannelConfig setTimeToLive(int ttl); + + /** + * Gets the address of the network interface used for multicast packets. + */ + InetAddress getInterface(); + + /** + * Sets the address of the network interface used for multicast packets. + */ + DatagramChannelConfig setInterface(InetAddress interfaceAddress); + + /** + * Gets the {@link StandardSocketOptions#IP_MULTICAST_IF} option. + */ + NetworkInterface getNetworkInterface(); + + /** + * Sets the {@link StandardSocketOptions#IP_MULTICAST_IF} option. + */ + DatagramChannelConfig setNetworkInterface(NetworkInterface networkInterface); + + @Override + @Deprecated + DatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + DatagramChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + DatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + DatagramChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + DatagramChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + DatagramChannelConfig setAutoRead(boolean autoRead); + + @Override + DatagramChannelConfig setAutoClose(boolean autoClose); + + @Override + DatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + DatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); + +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/DatagramPacket.java b/netty-channel/src/main/java/io/netty/channel/socket/DatagramPacket.java new file mode 100644 index 0000000..c5d1cd7 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/DatagramPacket.java @@ -0,0 +1,88 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.channel.DefaultAddressedEnvelope; + +import java.net.InetSocketAddress; + +/** + * The message container that is used for {@link DatagramChannel} to communicate with the remote peer. + */ +public class DatagramPacket + extends DefaultAddressedEnvelope implements ByteBufHolder { + + /** + * Create a new instance with the specified packet {@code data} and {@code recipient} address. + */ + public DatagramPacket(ByteBuf data, InetSocketAddress recipient) { + super(data, recipient); + } + + /** + * Create a new instance with the specified packet {@code data}, {@code recipient} address, and {@code sender} + * address. 
+ */ + public DatagramPacket(ByteBuf data, InetSocketAddress recipient, InetSocketAddress sender) { + super(data, recipient, sender); + } + + @Override + public DatagramPacket copy() { + return replace(content().copy()); + } + + @Override + public DatagramPacket duplicate() { + return replace(content().duplicate()); + } + + @Override + public DatagramPacket retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public DatagramPacket replace(ByteBuf content) { + return new DatagramPacket(content, recipient(), sender()); + } + + @Override + public DatagramPacket retain() { + super.retain(); + return this; + } + + @Override + public DatagramPacket retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public DatagramPacket touch() { + super.touch(); + return this; + } + + @Override + public DatagramPacket touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/DefaultDatagramChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/DefaultDatagramChannelConfig.java new file mode 100644 index 0000000..32a8049 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/DefaultDatagramChannelConfig.java @@ -0,0 +1,435 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.FixedRecvByteBufAllocator; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.MulticastSocket; +import java.net.NetworkInterface; +import java.net.SocketException; +import java.util.Map; + +import static io.netty.channel.ChannelOption.*; + +/** + * The default {@link DatagramChannelConfig} implementation. + */ +public class DefaultDatagramChannelConfig extends DefaultChannelConfig implements DatagramChannelConfig { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultDatagramChannelConfig.class); + + private final DatagramSocket javaSocket; + private volatile boolean activeOnOpen; + + /** + * Creates a new instance. 
+ */ + public DefaultDatagramChannelConfig(DatagramChannel channel, DatagramSocket javaSocket) { + super(channel, new FixedRecvByteBufAllocator(2048)); + this.javaSocket = ObjectUtil.checkNotNull(javaSocket, "javaSocket"); + } + + protected final DatagramSocket javaSocket() { + return javaSocket; + } + + @Override + @SuppressWarnings("deprecation") + public Map, Object> getOptions() { + return getOptions( + super.getOptions(), + SO_BROADCAST, SO_RCVBUF, SO_SNDBUF, SO_REUSEADDR, IP_MULTICAST_LOOP_DISABLED, + IP_MULTICAST_ADDR, IP_MULTICAST_IF, IP_MULTICAST_TTL, IP_TOS, DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION); + } + + @Override + @SuppressWarnings({ "unchecked", "deprecation" }) + public T getOption(ChannelOption option) { + if (option == SO_BROADCAST) { + return (T) Boolean.valueOf(isBroadcast()); + } + if (option == SO_RCVBUF) { + return (T) Integer.valueOf(getReceiveBufferSize()); + } + if (option == SO_SNDBUF) { + return (T) Integer.valueOf(getSendBufferSize()); + } + if (option == SO_REUSEADDR) { + return (T) Boolean.valueOf(isReuseAddress()); + } + if (option == IP_MULTICAST_LOOP_DISABLED) { + return (T) Boolean.valueOf(isLoopbackModeDisabled()); + } + if (option == IP_MULTICAST_ADDR) { + return (T) getInterface(); + } + if (option == IP_MULTICAST_IF) { + return (T) getNetworkInterface(); + } + if (option == IP_MULTICAST_TTL) { + return (T) Integer.valueOf(getTimeToLive()); + } + if (option == IP_TOS) { + return (T) Integer.valueOf(getTrafficClass()); + } + if (option == DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) { + return (T) Boolean.valueOf(activeOnOpen); + } + return super.getOption(option); + } + + @Override + @SuppressWarnings("deprecation") + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == SO_BROADCAST) { + setBroadcast((Boolean) value); + } else if (option == SO_RCVBUF) { + setReceiveBufferSize((Integer) value); + } else if (option == SO_SNDBUF) { + setSendBufferSize((Integer) value); + } else if 
(option == SO_REUSEADDR) { + setReuseAddress((Boolean) value); + } else if (option == IP_MULTICAST_LOOP_DISABLED) { + setLoopbackModeDisabled((Boolean) value); + } else if (option == IP_MULTICAST_ADDR) { + setInterface((InetAddress) value); + } else if (option == IP_MULTICAST_IF) { + setNetworkInterface((NetworkInterface) value); + } else if (option == IP_MULTICAST_TTL) { + setTimeToLive((Integer) value); + } else if (option == IP_TOS) { + setTrafficClass((Integer) value); + } else if (option == DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) { + setActiveOnOpen((Boolean) value); + } else { + return super.setOption(option, value); + } + + return true; + } + + private void setActiveOnOpen(boolean activeOnOpen) { + if (channel.isRegistered()) { + throw new IllegalStateException("Can only changed before channel was registered"); + } + this.activeOnOpen = activeOnOpen; + } + + @Override + public boolean isBroadcast() { + try { + return javaSocket.getBroadcast(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public DatagramChannelConfig setBroadcast(boolean broadcast) { + try { + // See: https://github.com/netty/netty/issues/576 + if (broadcast && + !javaSocket.getLocalAddress().isAnyLocalAddress() && + !PlatformDependent.isWindows() && !PlatformDependent.maybeSuperUser()) { + // Warn a user about the fact that a non-root user can't receive a + // broadcast packet on *nix if the socket is bound on non-wildcard address. 
+ logger.warn( + "A non-root user can't receive a broadcast packet if the socket " + + "is not bound to a wildcard address; setting the SO_BROADCAST flag " + + "anyway as requested on the socket which is bound to " + + javaSocket.getLocalSocketAddress() + '.'); + } + + javaSocket.setBroadcast(broadcast); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public InetAddress getInterface() { + if (javaSocket instanceof MulticastSocket) { + try { + return ((MulticastSocket) javaSocket).getInterface(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public DatagramChannelConfig setInterface(InetAddress interfaceAddress) { + if (javaSocket instanceof MulticastSocket) { + try { + ((MulticastSocket) javaSocket).setInterface(interfaceAddress); + } catch (SocketException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + return this; + } + + @Override + public boolean isLoopbackModeDisabled() { + if (javaSocket instanceof MulticastSocket) { + try { + return ((MulticastSocket) javaSocket).getLoopbackMode(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public DatagramChannelConfig setLoopbackModeDisabled(boolean loopbackModeDisabled) { + if (javaSocket instanceof MulticastSocket) { + try { + ((MulticastSocket) javaSocket).setLoopbackMode(loopbackModeDisabled); + } catch (SocketException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + return this; + } + + @Override + public NetworkInterface getNetworkInterface() { + if (javaSocket instanceof MulticastSocket) { + try { + return ((MulticastSocket) javaSocket).getNetworkInterface(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } else { + throw 
new UnsupportedOperationException(); + } + } + + @Override + public DatagramChannelConfig setNetworkInterface(NetworkInterface networkInterface) { + if (javaSocket instanceof MulticastSocket) { + try { + ((MulticastSocket) javaSocket).setNetworkInterface(networkInterface); + } catch (SocketException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + return this; + } + + @Override + public boolean isReuseAddress() { + try { + return javaSocket.getReuseAddress(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public DatagramChannelConfig setReuseAddress(boolean reuseAddress) { + try { + javaSocket.setReuseAddress(reuseAddress); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public int getReceiveBufferSize() { + try { + return javaSocket.getReceiveBufferSize(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public DatagramChannelConfig setReceiveBufferSize(int receiveBufferSize) { + try { + javaSocket.setReceiveBufferSize(receiveBufferSize); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public int getSendBufferSize() { + try { + return javaSocket.getSendBufferSize(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public DatagramChannelConfig setSendBufferSize(int sendBufferSize) { + try { + javaSocket.setSendBufferSize(sendBufferSize); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public int getTimeToLive() { + if (javaSocket instanceof MulticastSocket) { + try { + return ((MulticastSocket) javaSocket).getTimeToLive(); + } catch (IOException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + } + + @Override + public DatagramChannelConfig setTimeToLive(int ttl) { + if 
(javaSocket instanceof MulticastSocket) { + try { + ((MulticastSocket) javaSocket).setTimeToLive(ttl); + } catch (IOException e) { + throw new ChannelException(e); + } + } else { + throw new UnsupportedOperationException(); + } + return this; + } + + @Override + public int getTrafficClass() { + try { + return javaSocket.getTrafficClass(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public DatagramChannelConfig setTrafficClass(int trafficClass) { + try { + javaSocket.setTrafficClass(trafficClass); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public DatagramChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } + + @Override + public DatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public DatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public DatagramChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public DatagramChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public DatagramChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + public DatagramChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return this; + } + + @Override + public DatagramChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + super.setWriteBufferHighWaterMark(writeBufferHighWaterMark); + return this; + } + + @Override + public DatagramChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + 
super.setWriteBufferLowWaterMark(writeBufferLowWaterMark); + return this; + } + + @Override + public DatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public DatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } + + @Override + public DatagramChannelConfig setMaxMessagesPerWrite(int maxMessagesPerWrite) { + super.setMaxMessagesPerWrite(maxMessagesPerWrite); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java new file mode 100644 index 0000000..c6bcfe5 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/DefaultServerSocketChannelConfig.java @@ -0,0 +1,209 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.ServerChannelRecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.util.NetUtil; +import io.netty.util.internal.ObjectUtil; + +import java.net.ServerSocket; +import java.net.SocketException; +import java.util.Map; + +import static io.netty.channel.ChannelOption.SO_BACKLOG; +import static io.netty.channel.ChannelOption.SO_RCVBUF; +import static io.netty.channel.ChannelOption.SO_REUSEADDR; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * The default {@link ServerSocketChannelConfig} implementation. + */ +public class DefaultServerSocketChannelConfig extends DefaultChannelConfig + implements ServerSocketChannelConfig { + + protected final ServerSocket javaSocket; + private volatile int backlog = NetUtil.SOMAXCONN; + + /** + * Creates a new instance. 
+ */ + public DefaultServerSocketChannelConfig(ServerSocketChannel channel, ServerSocket javaSocket) { + super(channel, new ServerChannelRecvByteBufAllocator()); + this.javaSocket = ObjectUtil.checkNotNull(javaSocket, "javaSocket"); + } + + @Override + public Map, Object> getOptions() { + return getOptions(super.getOptions(), SO_RCVBUF, SO_REUSEADDR, SO_BACKLOG); + } + + @SuppressWarnings("unchecked") + @Override + public T getOption(ChannelOption option) { + if (option == SO_RCVBUF) { + return (T) Integer.valueOf(getReceiveBufferSize()); + } + if (option == SO_REUSEADDR) { + return (T) Boolean.valueOf(isReuseAddress()); + } + if (option == SO_BACKLOG) { + return (T) Integer.valueOf(getBacklog()); + } + + return super.getOption(option); + } + + @Override + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == SO_RCVBUF) { + setReceiveBufferSize((Integer) value); + } else if (option == SO_REUSEADDR) { + setReuseAddress((Boolean) value); + } else if (option == SO_BACKLOG) { + setBacklog((Integer) value); + } else { + return super.setOption(option, value); + } + + return true; + } + + @Override + public boolean isReuseAddress() { + try { + return javaSocket.getReuseAddress(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public ServerSocketChannelConfig setReuseAddress(boolean reuseAddress) { + try { + javaSocket.setReuseAddress(reuseAddress); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public int getReceiveBufferSize() { + try { + return javaSocket.getReceiveBufferSize(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public ServerSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { + try { + javaSocket.setReceiveBufferSize(receiveBufferSize); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public 
ServerSocketChannelConfig setPerformancePreferences(int connectionTime, int latency, int bandwidth) { + javaSocket.setPerformancePreferences(connectionTime, latency, bandwidth); + return this; + } + + @Override + public int getBacklog() { + return backlog; + } + + @Override + public ServerSocketChannelConfig setBacklog(int backlog) { + checkPositiveOrZero(backlog, "backlog"); + this.backlog = backlog; + return this; + } + + @Override + public ServerSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public ServerSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public ServerSocketChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } + + @Override + public ServerSocketChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public ServerSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public ServerSocketChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + public ServerSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + super.setWriteBufferHighWaterMark(writeBufferHighWaterMark); + return this; + } + + @Override + public ServerSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + super.setWriteBufferLowWaterMark(writeBufferLowWaterMark); + return this; + } + + @Override + public ServerSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public ServerSocketChannelConfig 
setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/DefaultSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/DefaultSocketChannelConfig.java new file mode 100644 index 0000000..f9066a4 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/DefaultSocketChannelConfig.java @@ -0,0 +1,347 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; + +import java.net.Socket; +import java.net.SocketException; +import java.util.Map; + +import static io.netty.channel.ChannelOption.*; + +/** + * The default {@link SocketChannelConfig} implementation. + */ +public class DefaultSocketChannelConfig extends DefaultChannelConfig + implements SocketChannelConfig { + + protected final Socket javaSocket; + private volatile boolean allowHalfClosure; + + /** + * Creates a new instance. 
+ */ + public DefaultSocketChannelConfig(SocketChannel channel, Socket javaSocket) { + super(channel); + this.javaSocket = ObjectUtil.checkNotNull(javaSocket, "javaSocket"); + + // Enable TCP_NODELAY by default if possible. + if (PlatformDependent.canEnableTcpNoDelayByDefault()) { + try { + setTcpNoDelay(true); + } catch (Exception e) { + // Ignore. + } + } + } + + @Override + public Map, Object> getOptions() { + return getOptions( + super.getOptions(), + SO_RCVBUF, SO_SNDBUF, TCP_NODELAY, SO_KEEPALIVE, SO_REUSEADDR, SO_LINGER, IP_TOS, + ALLOW_HALF_CLOSURE); + } + + @SuppressWarnings("unchecked") + @Override + public T getOption(ChannelOption option) { + if (option == SO_RCVBUF) { + return (T) Integer.valueOf(getReceiveBufferSize()); + } + if (option == SO_SNDBUF) { + return (T) Integer.valueOf(getSendBufferSize()); + } + if (option == TCP_NODELAY) { + return (T) Boolean.valueOf(isTcpNoDelay()); + } + if (option == SO_KEEPALIVE) { + return (T) Boolean.valueOf(isKeepAlive()); + } + if (option == SO_REUSEADDR) { + return (T) Boolean.valueOf(isReuseAddress()); + } + if (option == SO_LINGER) { + return (T) Integer.valueOf(getSoLinger()); + } + if (option == IP_TOS) { + return (T) Integer.valueOf(getTrafficClass()); + } + if (option == ALLOW_HALF_CLOSURE) { + return (T) Boolean.valueOf(isAllowHalfClosure()); + } + + return super.getOption(option); + } + + @Override + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == SO_RCVBUF) { + setReceiveBufferSize((Integer) value); + } else if (option == SO_SNDBUF) { + setSendBufferSize((Integer) value); + } else if (option == TCP_NODELAY) { + setTcpNoDelay((Boolean) value); + } else if (option == SO_KEEPALIVE) { + setKeepAlive((Boolean) value); + } else if (option == SO_REUSEADDR) { + setReuseAddress((Boolean) value); + } else if (option == SO_LINGER) { + setSoLinger((Integer) value); + } else if (option == IP_TOS) { + setTrafficClass((Integer) value); + } else if (option == 
ALLOW_HALF_CLOSURE) { + setAllowHalfClosure((Boolean) value); + } else { + return super.setOption(option, value); + } + + return true; + } + + @Override + public int getReceiveBufferSize() { + try { + return javaSocket.getReceiveBufferSize(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public int getSendBufferSize() { + try { + return javaSocket.getSendBufferSize(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public int getSoLinger() { + try { + return javaSocket.getSoLinger(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public int getTrafficClass() { + try { + return javaSocket.getTrafficClass(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public boolean isKeepAlive() { + try { + return javaSocket.getKeepAlive(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public boolean isReuseAddress() { + try { + return javaSocket.getReuseAddress(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public boolean isTcpNoDelay() { + try { + return javaSocket.getTcpNoDelay(); + } catch (SocketException e) { + throw new ChannelException(e); + } + } + + @Override + public SocketChannelConfig setKeepAlive(boolean keepAlive) { + try { + javaSocket.setKeepAlive(keepAlive); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public SocketChannelConfig setPerformancePreferences( + int connectionTime, int latency, int bandwidth) { + javaSocket.setPerformancePreferences(connectionTime, latency, bandwidth); + return this; + } + + @Override + public SocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { + try { + javaSocket.setReceiveBufferSize(receiveBufferSize); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + 
public SocketChannelConfig setReuseAddress(boolean reuseAddress) { + try { + javaSocket.setReuseAddress(reuseAddress); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public SocketChannelConfig setSendBufferSize(int sendBufferSize) { + try { + javaSocket.setSendBufferSize(sendBufferSize); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public SocketChannelConfig setSoLinger(int soLinger) { + try { + if (soLinger < 0) { + javaSocket.setSoLinger(false, 0); + } else { + javaSocket.setSoLinger(true, soLinger); + } + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public SocketChannelConfig setTcpNoDelay(boolean tcpNoDelay) { + try { + javaSocket.setTcpNoDelay(tcpNoDelay); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public SocketChannelConfig setTrafficClass(int trafficClass) { + try { + javaSocket.setTrafficClass(trafficClass); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public boolean isAllowHalfClosure() { + return allowHalfClosure; + } + + @Override + public SocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure) { + this.allowHalfClosure = allowHalfClosure; + return this; + } + + @Override + public SocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public SocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public SocketChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } + + @Override + public SocketChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + 
return this; + } + + @Override + public SocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public SocketChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + public SocketChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return this; + } + + @Override + public SocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + super.setWriteBufferHighWaterMark(writeBufferHighWaterMark); + return this; + } + + @Override + public SocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + super.setWriteBufferLowWaterMark(writeBufferLowWaterMark); + return this; + } + + @Override + public SocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public SocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/DuplexChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/DuplexChannel.java new file mode 100644 index 0000000..0fc1c14 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/DuplexChannel.java @@ -0,0 +1,81 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPromise; + +import java.net.Socket; + +/** + * A duplex {@link Channel} that has two sides that can be shutdown independently. + */ +public interface DuplexChannel extends Channel { + /** + * Returns {@code true} if and only if the remote peer shut down its output so that no more + * data is received from this channel. Note that the semantic of this method is different from + * that of {@link Socket#shutdownInput()} and {@link Socket#isInputShutdown()}. + */ + boolean isInputShutdown(); + + /** + * @see Socket#shutdownInput() + */ + ChannelFuture shutdownInput(); + + /** + * Will shutdown the input and notify {@link ChannelPromise}. + * + * @see Socket#shutdownInput() + */ + ChannelFuture shutdownInput(ChannelPromise promise); + + /** + * @see Socket#isOutputShutdown() + */ + boolean isOutputShutdown(); + + /** + * @see Socket#shutdownOutput() + */ + ChannelFuture shutdownOutput(); + + /** + * Will shutdown the output and notify {@link ChannelPromise}. + * + * @see Socket#shutdownOutput() + */ + ChannelFuture shutdownOutput(ChannelPromise promise); + + /** + * Determine if both the input and output of this channel have been shutdown. + */ + boolean isShutdown(); + + /** + * Will shutdown the input and output sides of this channel. + * @return will be completed when both shutdown operations complete. + */ + ChannelFuture shutdown(); + + /** + * Will shutdown the input and output sides of this channel. + * @param promise will be completed when both shutdown operations complete. + * @return will be completed when both shutdown operations complete. 
+ */ + ChannelFuture shutdown(ChannelPromise promise); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/DuplexChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/DuplexChannelConfig.java new file mode 100644 index 0000000..d3757c4 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/DuplexChannelConfig.java @@ -0,0 +1,84 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +/** + * A {@link ChannelConfig} for a {@link DuplexChannel}. + * + *

Available options

+ * + * In addition to the options provided by {@link ChannelConfig}, + * {@link DuplexChannelConfig} allows the following options in the option map: + * + * + * + * + * + *
{@link ChannelOption#ALLOW_HALF_CLOSURE}{@link #setAllowHalfClosure(boolean)}
+ */ +public interface DuplexChannelConfig extends ChannelConfig { + + /** + * Returns {@code true} if and only if the channel should not close itself when its remote + * peer shuts down output to make the connection half-closed. If {@code false}, the connection + * is closed automatically when the remote peer shuts down output. + */ + boolean isAllowHalfClosure(); + + /** + * Sets whether the channel should not close itself when its remote peer shuts down output to + * make the connection half-closed. If {@code true} the connection is not closed when the + * remote peer shuts down output. Instead, + * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} + * is invoked with a {@link ChannelInputShutdownEvent} object. If {@code false}, the connection + * is closed automatically. + */ + DuplexChannelConfig setAllowHalfClosure(boolean allowHalfClosure); + + @Override + @Deprecated + DuplexChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + DuplexChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + DuplexChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + DuplexChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + DuplexChannelConfig setAutoRead(boolean autoRead); + + @Override + DuplexChannelConfig setAutoClose(boolean autoClose); + + @Override + DuplexChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + DuplexChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java b/netty-channel/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java new file mode 100644 index 0000000..57f7cf7 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/InternetProtocolFamily.java @@ -0,0 +1,81 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you 
under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.util.NetUtil; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; + +/** + * Internet Protocol (IP) families used byte the {@link DatagramChannel} + */ +public enum InternetProtocolFamily { + IPv4(Inet4Address.class, 1), + IPv6(Inet6Address.class, 2); + + private final Class addressType; + private final int addressNumber; + + InternetProtocolFamily(Class addressType, int addressNumber) { + this.addressType = addressType; + this.addressNumber = addressNumber; + } + + /** + * Returns the address type of this protocol family. + */ + public Class addressType() { + return addressType; + } + + /** + * Returns the + * address number + * of the family. + */ + public int addressNumber() { + return addressNumber; + } + + /** + * Returns the {@link InetAddress} that represent the {@code LOCALHOST} for the family. + */ + public InetAddress localhost() { + switch (this) { + case IPv4: + return NetUtil.LOCALHOST4; + case IPv6: + return NetUtil.LOCALHOST6; + default: + throw new IllegalStateException("Unsupported family " + this); + } + } + + /** + * Returns the {@link InternetProtocolFamily} for the given {@link InetAddress}. 
+ */ + public static InternetProtocolFamily of(InetAddress address) { + if (address instanceof Inet4Address) { + return IPv4; + } + if (address instanceof Inet6Address) { + return IPv6; + } + throw new IllegalArgumentException("address " + address + " not supported"); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannel.java new file mode 100644 index 0000000..b394238 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannel.java @@ -0,0 +1,32 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.channel.ServerChannel; + +import java.net.InetSocketAddress; + +/** + * A TCP/IP {@link ServerChannel} which accepts incoming TCP/IP connections. 
+ */ +public interface ServerSocketChannel extends ServerChannel { + @Override + ServerSocketChannelConfig config(); + @Override + InetSocketAddress localAddress(); + @Override + InetSocketAddress remoteAddress(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannelConfig.java new file mode 100644 index 0000000..1646bdb --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/ServerSocketChannelConfig.java @@ -0,0 +1,119 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +import java.net.ServerSocket; +import java.net.StandardSocketOptions; + +/** + * A {@link ChannelConfig} for a {@link ServerSocketChannel}. + * + *

Available options

+ * + * In addition to the options provided by {@link ChannelConfig}, + * {@link ServerSocketChannelConfig} allows the following options in the + * option map: + * + * + * + * + * + * + * + * + * + * + * + *
NameAssociated setter method
{@code "backlog"}{@link #setBacklog(int)}
{@code "reuseAddress"}{@link #setReuseAddress(boolean)}
{@code "receiveBufferSize"}{@link #setReceiveBufferSize(int)}
+ */ +public interface ServerSocketChannelConfig extends ChannelConfig { + + /** + * Gets the backlog value to specify when the channel binds to a local + * address. + */ + int getBacklog(); + + /** + * Sets the backlog value to specify when the channel binds to a local + * address. + */ + ServerSocketChannelConfig setBacklog(int backlog); + + /** + * Gets the {@link StandardSocketOptions#SO_REUSEADDR} option. + */ + boolean isReuseAddress(); + + /** + * Sets the {@link StandardSocketOptions#SO_REUSEADDR} option. + */ + ServerSocketChannelConfig setReuseAddress(boolean reuseAddress); + + /** + * Gets the {@link StandardSocketOptions#SO_RCVBUF} option. + */ + int getReceiveBufferSize(); + + /** + * Gets the {@link StandardSocketOptions#SO_SNDBUF} option. + */ + ServerSocketChannelConfig setReceiveBufferSize(int receiveBufferSize); + + /** + * Sets the performance preferences as specified in + * {@link ServerSocket#setPerformancePreferences(int, int, int)}. + */ + ServerSocketChannelConfig setPerformancePreferences(int connectionTime, int latency, int bandwidth); + + @Override + ServerSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + @Deprecated + ServerSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + ServerSocketChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + ServerSocketChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + ServerSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + ServerSocketChannelConfig setAutoRead(boolean autoRead); + + @Override + ServerSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + ServerSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark); + + @Override + ServerSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark); + + @Override + ServerSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark 
writeBufferWaterMark); + +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/SocketChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/SocketChannel.java new file mode 100644 index 0000000..0504e43 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/SocketChannel.java @@ -0,0 +1,35 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.channel.Channel; + +import java.net.InetSocketAddress; + +/** + * A TCP/IP socket {@link Channel}. + */ +public interface SocketChannel extends DuplexChannel { + @Override + ServerSocketChannel parent(); + + @Override + SocketChannelConfig config(); + @Override + InetSocketAddress localAddress(); + @Override + InetSocketAddress remoteAddress(); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/SocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/SocketChannelConfig.java new file mode 100644 index 0000000..db9213b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/SocketChannelConfig.java @@ -0,0 +1,172 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; + +import java.net.Socket; +import java.net.StandardSocketOptions; + +/** + * A {@link ChannelConfig} for a {@link SocketChannel}. + * + *

Available options

+ * + * In addition to the options provided by {@link DuplexChannelConfig}, + * {@link SocketChannelConfig} allows the following options in the option map: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameAssociated setter method
{@link ChannelOption#SO_KEEPALIVE}{@link #setKeepAlive(boolean)}
{@link ChannelOption#SO_REUSEADDR}{@link #setReuseAddress(boolean)}
{@link ChannelOption#SO_LINGER}{@link #setSoLinger(int)}
{@link ChannelOption#TCP_NODELAY}{@link #setTcpNoDelay(boolean)}
{@link ChannelOption#SO_RCVBUF}{@link #setReceiveBufferSize(int)}
{@link ChannelOption#SO_SNDBUF}{@link #setSendBufferSize(int)}
{@link ChannelOption#IP_TOS}{@link #setTrafficClass(int)}
{@link ChannelOption#ALLOW_HALF_CLOSURE}{@link #setAllowHalfClosure(boolean)}
+ */ +public interface SocketChannelConfig extends DuplexChannelConfig { + + /** + * Gets the {@link StandardSocketOptions#TCP_NODELAY} option. Please note that the default value of this option + * is {@code true} unlike the operating system default ({@code false}). However, for some buggy platforms, such as + * Android, that shows erratic behavior with Nagle's algorithm disabled, the default value remains to be + * {@code false}. + */ + boolean isTcpNoDelay(); + + /** + * Sets the {@link StandardSocketOptions#TCP_NODELAY} option. Please note that the default value of this option + * is {@code true} unlike the operating system default ({@code false}). However, for some buggy platforms, such as + * Android, that shows erratic behavior with Nagle's algorithm disabled, the default value remains to be + * {@code false}. + */ + SocketChannelConfig setTcpNoDelay(boolean tcpNoDelay); + + /** + * Gets the {@link StandardSocketOptions#SO_LINGER} option. + */ + int getSoLinger(); + + /** + * Sets the {@link StandardSocketOptions#SO_LINGER} option. + */ + SocketChannelConfig setSoLinger(int soLinger); + + /** + * Gets the {@link StandardSocketOptions#SO_SNDBUF} option. + */ + int getSendBufferSize(); + + /** + * Sets the {@link StandardSocketOptions#SO_SNDBUF} option. + */ + SocketChannelConfig setSendBufferSize(int sendBufferSize); + + /** + * Gets the {@link StandardSocketOptions#SO_RCVBUF} option. + */ + int getReceiveBufferSize(); + + /** + * Sets the {@link StandardSocketOptions#SO_RCVBUF} option. + */ + SocketChannelConfig setReceiveBufferSize(int receiveBufferSize); + + /** + * Gets the {@link StandardSocketOptions#SO_KEEPALIVE} option. + */ + boolean isKeepAlive(); + + /** + * Sets the {@link StandardSocketOptions#SO_KEEPALIVE} option. + */ + SocketChannelConfig setKeepAlive(boolean keepAlive); + + /** + * Gets the {@link StandardSocketOptions#IP_TOS} option. + */ + int getTrafficClass(); + + /** + * Sets the {@link StandardSocketOptions#IP_TOS} option. 
+ */ + SocketChannelConfig setTrafficClass(int trafficClass); + + /** + * Gets the {@link StandardSocketOptions#SO_REUSEADDR} option. + */ + boolean isReuseAddress(); + + /** + * Sets the {@link StandardSocketOptions#SO_REUSEADDR} option. + */ + SocketChannelConfig setReuseAddress(boolean reuseAddress); + + /** + * Sets the performance preferences as specified in + * {@link Socket#setPerformancePreferences(int, int, int)}. + */ + SocketChannelConfig setPerformancePreferences(int connectionTime, int latency, int bandwidth); + + @Override + SocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure); + + @Override + SocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + @Deprecated + SocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + SocketChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + SocketChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + SocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + SocketChannelConfig setAutoRead(boolean autoRead); + + @Override + SocketChannelConfig setAutoClose(boolean autoClose); + + @Override + SocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + SocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java new file mode 100644 index 0000000..b5c15e7 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioChannelOption.java @@ -0,0 +1,123 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.nio; + +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.util.internal.SuppressJava6Requirement; + +import java.io.IOException; +import java.nio.channels.Channel; +import java.nio.channels.ServerSocketChannel; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +/** + * Provides {@link ChannelOption} over a given {@link java.net.SocketOption} which is then passed through the underlying + * {@link java.nio.channels.NetworkChannel}. + */ +@SuppressJava6Requirement(reason = "Usage explicit by the user") +public final class NioChannelOption extends ChannelOption { + + private final java.net.SocketOption option; + + @SuppressWarnings("deprecation") + private NioChannelOption(java.net.SocketOption option) { + super(option.name()); + this.option = option; + } + + /** + * Returns a {@link ChannelOption} for the given {@link java.net.SocketOption}. + */ + public static ChannelOption of(java.net.SocketOption option) { + return new NioChannelOption(option); + } + + // It's important to not use java.nio.channels.NetworkChannel as otherwise the classes that sometimes call this + // method may not be used on Java 6, as method linking can happen eagerly even if this method was not actually + // called at runtime. + // + // See https://github.com/netty/netty/issues/8166 + + // Internal helper methods to remove code duplication between Nio*Channel implementations. 
+ @SuppressJava6Requirement(reason = "Usage guarded by java version check") + static boolean setOption(Channel jdkChannel, NioChannelOption option, T value) { + java.nio.channels.NetworkChannel channel = (java.nio.channels.NetworkChannel) jdkChannel; + if (!channel.supportedOptions().contains(option.option)) { + return false; + } + if (channel instanceof ServerSocketChannel && option.option == java.net.StandardSocketOptions.IP_TOS) { + // Skip IP_TOS as a workaround for a JDK bug: + // See https://mail.openjdk.java.net/pipermail/nio-dev/2018-August/005365.html + return false; + } + try { + channel.setOption(option.option, value); + return true; + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + static T getOption(Channel jdkChannel, NioChannelOption option) { + java.nio.channels.NetworkChannel channel = (java.nio.channels.NetworkChannel) jdkChannel; + + if (!channel.supportedOptions().contains(option.option)) { + return null; + } + if (channel instanceof ServerSocketChannel && option.option == java.net.StandardSocketOptions.IP_TOS) { + // Skip IP_TOS as a workaround for a JDK bug: + // See https://mail.openjdk.java.net/pipermail/nio-dev/2018-August/005365.html + return null; + } + try { + return channel.getOption(option.option); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @SuppressWarnings("unchecked") + static ChannelOption[] getOptions(Channel jdkChannel) { + java.nio.channels.NetworkChannel channel = (java.nio.channels.NetworkChannel) jdkChannel; + Set> supportedOpts = channel.supportedOptions(); + + if (channel instanceof ServerSocketChannel) { + List> extraOpts = new ArrayList>(supportedOpts.size()); + for (java.net.SocketOption opt : supportedOpts) { + if (opt == java.net.StandardSocketOptions.IP_TOS) { + // Skip IP_TOS as a workaround for a JDK bug: + // See 
https://mail.openjdk.java.net/pipermail/nio-dev/2018-August/005365.html + continue; + } + extraOpts.add(new NioChannelOption(opt)); + } + return extraOpts.toArray(new ChannelOption[0]); + } else { + ChannelOption[] extraOpts = new ChannelOption[supportedOpts.size()]; + + int i = 0; + for (java.net.SocketOption opt : supportedOpts) { + extraOpts[i++] = new NioChannelOption(opt); + } + return extraOpts; + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java new file mode 100644 index 0000000..467d004 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannel.java @@ -0,0 +1,625 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.AddressedEnvelope; +import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultAddressedEnvelope; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.nio.AbstractNioMessageChannel; +import io.netty.channel.socket.DatagramChannelConfig; +import io.netty.channel.socket.DatagramPacket; +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.util.UncheckedBooleanSupplier; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SuppressJava6Requirement; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.SocketAddress; +import java.net.SocketException; +import java.nio.ByteBuffer; +import java.nio.channels.DatagramChannel; +import java.nio.channels.MembershipKey; +import java.nio.channels.SelectionKey; +import java.nio.channels.UnresolvedAddressException; +import java.nio.channels.spi.SelectorProvider; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +/** + * An NIO datagram {@link Channel} that sends and receives an + * {@link AddressedEnvelope}. 
+ * + * @see AddressedEnvelope + * @see DatagramPacket + */ +public final class NioDatagramChannel + extends AbstractNioMessageChannel implements io.netty.channel.socket.DatagramChannel { + + private static final ChannelMetadata METADATA = new ChannelMetadata(true, 16); + private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider(); + private static final String EXPECTED_TYPES = + " (expected: " + StringUtil.simpleClassName(DatagramPacket.class) + ", " + + StringUtil.simpleClassName(AddressedEnvelope.class) + '<' + + StringUtil.simpleClassName(ByteBuf.class) + ", " + + StringUtil.simpleClassName(SocketAddress.class) + ">, " + + StringUtil.simpleClassName(ByteBuf.class) + ')'; + + private final DatagramChannelConfig config; + + private Map> memberships; + + private static DatagramChannel newSocket(SelectorProvider provider) { + try { + /** + * Use the {@link SelectorProvider} to open {@link SocketChannel} and so remove condition in + * {@link SelectorProvider#provider()} which is called by each DatagramChannel.open() otherwise. + * + * See #2308. + */ + return provider.openDatagramChannel(); + } catch (IOException e) { + throw new ChannelException("Failed to open a socket.", e); + } + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private static DatagramChannel newSocket(SelectorProvider provider, InternetProtocolFamily ipFamily) { + if (ipFamily == null) { + return newSocket(provider); + } + + checkJavaVersion(); + + try { + return provider.openDatagramChannel(ProtocolFamilyConverter.convert(ipFamily)); + } catch (IOException e) { + throw new ChannelException("Failed to open a socket.", e); + } + } + + private static void checkJavaVersion() { + if (PlatformDependent.javaVersion() < 7) { + throw new UnsupportedOperationException("Only supported on java 7+."); + } + } + + /** + * Create a new instance which will use the Operation Systems default {@link InternetProtocolFamily}. 
+ */
+ public NioDatagramChannel() {
+ this(newSocket(DEFAULT_SELECTOR_PROVIDER));
+ }
+
+ /**
+ * Create a new instance using the given {@link SelectorProvider}
+ * which will use the operating system's default {@link InternetProtocolFamily}.
+ */
+ public NioDatagramChannel(SelectorProvider provider) {
+ this(newSocket(provider));
+ }
+
+ /**
+ * Create a new instance using the given {@link InternetProtocolFamily}. If {@code null} is used it will depend
+ * on the operating system's default which will be chosen.
+ */
+ public NioDatagramChannel(InternetProtocolFamily ipFamily) {
+ this(newSocket(DEFAULT_SELECTOR_PROVIDER, ipFamily));
+ }
+
+ /**
+ * Create a new instance using the given {@link SelectorProvider} and {@link InternetProtocolFamily}.
+ * If {@link InternetProtocolFamily} is {@code null} it will depend on the operating system's default
+ * which will be chosen.
+ */
+ public NioDatagramChannel(SelectorProvider provider, InternetProtocolFamily ipFamily) {
+ this(newSocket(provider, ipFamily));
+ }
+
+ /**
+ * Create a new instance from the given {@link DatagramChannel}.
+ */ + public NioDatagramChannel(DatagramChannel socket) { + super(null, socket, SelectionKey.OP_READ); + config = new NioDatagramChannelConfig(this, socket); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + public DatagramChannelConfig config() { + return config; + } + + @Override + @SuppressWarnings("deprecation") + public boolean isActive() { + DatagramChannel ch = javaChannel(); + return ch.isOpen() && ( + config.getOption(ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) && isRegistered() + || ch.socket().isBound()); + } + + @Override + public boolean isConnected() { + return javaChannel().isConnected(); + } + + @Override + protected DatagramChannel javaChannel() { + return (DatagramChannel) super.javaChannel(); + } + + @Override + protected SocketAddress localAddress0() { + return javaChannel().socket().getLocalSocketAddress(); + } + + @Override + protected SocketAddress remoteAddress0() { + return javaChannel().socket().getRemoteSocketAddress(); + } + + @Override + protected void doBind(SocketAddress localAddress) throws Exception { + doBind0(localAddress); + } + + private void doBind0(SocketAddress localAddress) throws Exception { + if (PlatformDependent.javaVersion() >= 7) { + SocketUtils.bind(javaChannel(), localAddress); + } else { + javaChannel().socket().bind(localAddress); + } + } + + @Override + protected boolean doConnect(SocketAddress remoteAddress, + SocketAddress localAddress) throws Exception { + if (localAddress != null) { + doBind0(localAddress); + } + + boolean success = false; + try { + javaChannel().connect(remoteAddress); + success = true; + return true; + } finally { + if (!success) { + doClose(); + } + } + } + + @Override + protected void doFinishConnect() throws Exception { + throw new Error(); + } + + @Override + protected void doDisconnect() throws Exception { + javaChannel().disconnect(); + } + + @Override + protected void doClose() throws Exception { + javaChannel().close(); + } + + 
@Override + protected int doReadMessages(List buf) throws Exception { + DatagramChannel ch = javaChannel(); + DatagramChannelConfig config = config(); + RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); + + ByteBuf data = allocHandle.allocate(config.getAllocator()); + allocHandle.attemptedBytesRead(data.writableBytes()); + boolean free = true; + try { + ByteBuffer nioData = data.internalNioBuffer(data.writerIndex(), data.writableBytes()); + int pos = nioData.position(); + InetSocketAddress remoteAddress = (InetSocketAddress) ch.receive(nioData); + if (remoteAddress == null) { + return 0; + } + + allocHandle.lastBytesRead(nioData.position() - pos); + buf.add(new DatagramPacket(data.writerIndex(data.writerIndex() + allocHandle.lastBytesRead()), + localAddress(), remoteAddress)); + free = false; + return 1; + } catch (Throwable cause) { + PlatformDependent.throwException(cause); + return -1; + } finally { + if (free) { + data.release(); + } + } + } + + @Override + protected boolean doWriteMessage(Object msg, ChannelOutboundBuffer in) throws Exception { + final SocketAddress remoteAddress; + final ByteBuf data; + if (msg instanceof AddressedEnvelope) { + @SuppressWarnings("unchecked") + AddressedEnvelope envelope = (AddressedEnvelope) msg; + remoteAddress = envelope.recipient(); + data = envelope.content(); + } else { + data = (ByteBuf) msg; + remoteAddress = null; + } + + final int dataLen = data.readableBytes(); + if (dataLen == 0) { + return true; + } + + final ByteBuffer nioData = data.nioBufferCount() == 1 ? 
data.internalNioBuffer(data.readerIndex(), dataLen) + : data.nioBuffer(data.readerIndex(), dataLen); + final int writtenBytes; + if (remoteAddress != null) { + writtenBytes = javaChannel().send(nioData, remoteAddress); + } else { + writtenBytes = javaChannel().write(nioData); + } + return writtenBytes > 0; + } + + private static void checkUnresolved(AddressedEnvelope envelope) { + if (envelope.recipient() instanceof InetSocketAddress + && (((InetSocketAddress) envelope.recipient()).isUnresolved())) { + throw new UnresolvedAddressException(); + } + } + + @Override + protected Object filterOutboundMessage(Object msg) { + if (msg instanceof DatagramPacket) { + DatagramPacket p = (DatagramPacket) msg; + checkUnresolved(p); + ByteBuf content = p.content(); + if (isSingleDirectBuffer(content)) { + return p; + } + return new DatagramPacket(newDirectBuffer(p, content), p.recipient()); + } + + if (msg instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) msg; + if (isSingleDirectBuffer(buf)) { + return buf; + } + return newDirectBuffer(buf); + } + + if (msg instanceof AddressedEnvelope) { + @SuppressWarnings("unchecked") + AddressedEnvelope e = (AddressedEnvelope) msg; + checkUnresolved(e); + if (e.content() instanceof ByteBuf) { + ByteBuf content = (ByteBuf) e.content(); + if (isSingleDirectBuffer(content)) { + return e; + } + return new DefaultAddressedEnvelope(newDirectBuffer(e, content), e.recipient()); + } + } + + throw new UnsupportedOperationException( + "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES); + } + + /** + * Checks if the specified buffer is a direct buffer and is composed of a single NIO buffer. + * (We check this because otherwise we need to make it a non-composite buffer.) 
+ */ + private static boolean isSingleDirectBuffer(ByteBuf buf) { + return buf.isDirect() && buf.nioBufferCount() == 1; + } + + @Override + protected boolean continueOnWriteError() { + // Continue on write error as a DatagramChannel can write to multiple remote peers + // + // See https://github.com/netty/netty/issues/2665 + return true; + } + + @Override + public InetSocketAddress localAddress() { + return (InetSocketAddress) super.localAddress(); + } + + @Override + public InetSocketAddress remoteAddress() { + return (InetSocketAddress) super.remoteAddress(); + } + + @Override + public ChannelFuture joinGroup(InetAddress multicastAddress) { + return joinGroup(multicastAddress, newPromise()); + } + + @Override + public ChannelFuture joinGroup(InetAddress multicastAddress, ChannelPromise promise) { + try { + NetworkInterface iface = config.getNetworkInterface(); + if (iface == null) { + iface = NetworkInterface.getByInetAddress(localAddress().getAddress()); + } + return joinGroup( + multicastAddress, iface, null, promise); + } catch (SocketException e) { + promise.setFailure(e); + } + return promise; + } + + @Override + public ChannelFuture joinGroup( + InetSocketAddress multicastAddress, NetworkInterface networkInterface) { + return joinGroup(multicastAddress, networkInterface, newPromise()); + } + + @Override + public ChannelFuture joinGroup( + InetSocketAddress multicastAddress, NetworkInterface networkInterface, + ChannelPromise promise) { + return joinGroup(multicastAddress.getAddress(), networkInterface, null, promise); + } + + @Override + public ChannelFuture joinGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source) { + return joinGroup(multicastAddress, networkInterface, source, newPromise()); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @Override + public ChannelFuture joinGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, + InetAddress source, 
ChannelPromise promise) { + + checkJavaVersion(); + + ObjectUtil.checkNotNull(multicastAddress, "multicastAddress"); + ObjectUtil.checkNotNull(networkInterface, "networkInterface"); + + try { + MembershipKey key; + if (source == null) { + key = javaChannel().join(multicastAddress, networkInterface); + } else { + key = javaChannel().join(multicastAddress, networkInterface, source); + } + + synchronized (this) { + List keys = null; + if (memberships == null) { + memberships = new HashMap>(); + } else { + keys = memberships.get(multicastAddress); + } + if (keys == null) { + keys = new ArrayList(); + memberships.put(multicastAddress, keys); + } + keys.add(key); + } + + promise.setSuccess(); + } catch (Throwable e) { + promise.setFailure(e); + } + + return promise; + } + + @Override + public ChannelFuture leaveGroup(InetAddress multicastAddress) { + return leaveGroup(multicastAddress, newPromise()); + } + + @Override + public ChannelFuture leaveGroup(InetAddress multicastAddress, ChannelPromise promise) { + try { + return leaveGroup( + multicastAddress, NetworkInterface.getByInetAddress(localAddress().getAddress()), null, promise); + } catch (SocketException e) { + promise.setFailure(e); + } + return promise; + } + + @Override + public ChannelFuture leaveGroup( + InetSocketAddress multicastAddress, NetworkInterface networkInterface) { + return leaveGroup(multicastAddress, networkInterface, newPromise()); + } + + @Override + public ChannelFuture leaveGroup( + InetSocketAddress multicastAddress, + NetworkInterface networkInterface, ChannelPromise promise) { + return leaveGroup(multicastAddress.getAddress(), networkInterface, null, promise); + } + + @Override + public ChannelFuture leaveGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source) { + return leaveGroup(multicastAddress, networkInterface, source, newPromise()); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @Override + public ChannelFuture 
leaveGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source, + ChannelPromise promise) { + checkJavaVersion(); + + ObjectUtil.checkNotNull(multicastAddress, "multicastAddress"); + ObjectUtil.checkNotNull(networkInterface, "networkInterface"); + + synchronized (this) { + if (memberships != null) { + List keys = memberships.get(multicastAddress); + if (keys != null) { + Iterator keyIt = keys.iterator(); + + while (keyIt.hasNext()) { + MembershipKey key = keyIt.next(); + if (networkInterface.equals(key.networkInterface())) { + if (source == null && key.sourceAddress() == null || + source != null && source.equals(key.sourceAddress())) { + key.drop(); + keyIt.remove(); + } + } + } + if (keys.isEmpty()) { + memberships.remove(multicastAddress); + } + } + } + } + + promise.setSuccess(); + return promise; + } + + /** + * Block the given sourceToBlock address for the given multicastAddress on the given networkInterface + */ + @Override + public ChannelFuture block( + InetAddress multicastAddress, NetworkInterface networkInterface, + InetAddress sourceToBlock) { + return block(multicastAddress, networkInterface, sourceToBlock, newPromise()); + } + + /** + * Block the given sourceToBlock address for the given multicastAddress on the given networkInterface + */ + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @Override + public ChannelFuture block( + InetAddress multicastAddress, NetworkInterface networkInterface, + InetAddress sourceToBlock, ChannelPromise promise) { + checkJavaVersion(); + + ObjectUtil.checkNotNull(multicastAddress, "multicastAddress"); + ObjectUtil.checkNotNull(sourceToBlock, "sourceToBlock"); + ObjectUtil.checkNotNull(networkInterface, "networkInterface"); + + synchronized (this) { + if (memberships != null) { + List keys = memberships.get(multicastAddress); + for (MembershipKey key: keys) { + if (networkInterface.equals(key.networkInterface())) { + try { + key.block(sourceToBlock); + } 
catch (IOException e) {
+ promise.setFailure(e);
+ }
+ }
+ }
+ }
+ }
+ promise.setSuccess();
+ return promise;
+ }
+
+ /**
+ * Block the given sourceToBlock address for the given multicastAddress
+ *
+ */
+ @Override
+ public ChannelFuture block(InetAddress multicastAddress, InetAddress sourceToBlock) {
+ return block(multicastAddress, sourceToBlock, newPromise());
+ }
+
+ /**
+ * Block the given sourceToBlock address for the given multicastAddress
+ *
+ */
+ @Override
+ public ChannelFuture block(
+ InetAddress multicastAddress, InetAddress sourceToBlock, ChannelPromise promise) {
+ try {
+ return block(
+ multicastAddress,
+ NetworkInterface.getByInetAddress(localAddress().getAddress()),
+ sourceToBlock, promise);
+ } catch (SocketException e) {
+ promise.setFailure(e);
+ }
+ return promise;
+ }
+
+ @Override
+ @Deprecated
+ protected void setReadPending(boolean readPending) {
+ super.setReadPending(readPending);
+ }
+
+ void clearReadPending0() {
+ clearReadPending();
+ }
+
+ @Override
+ protected boolean closeOnReadError(Throwable cause) {
+ // We do not want to close on SocketException when using DatagramChannel as we usually can continue receiving.
+ // See https://github.com/netty/netty/issues/5893
+ if (cause instanceof SocketException) {
+ return false;
+ }
+ return super.closeOnReadError(cause);
+ }
+
+ @Override
+ protected boolean continueReading(RecvByteBufAllocator.Handle allocHandle) {
+ if (allocHandle instanceof RecvByteBufAllocator.ExtendedHandle) {
+ // We use the TRUE_SUPPLIER as it is also ok to read less than what we did try to read (as long
+ // as we read anything).
+ return ((RecvByteBufAllocator.ExtendedHandle) allocHandle) + .continueReading(UncheckedBooleanSupplier.TRUE_SUPPLIER); + } + return allocHandle.continueReading(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java new file mode 100644 index 0000000..0c8fc66 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioDatagramChannelConfig.java @@ -0,0 +1,235 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.nio; + +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.socket.DatagramChannelConfig; +import io.netty.channel.socket.DefaultDatagramChannelConfig; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SocketUtils; + +import java.lang.reflect.Method; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.net.SocketException; +import java.nio.channels.DatagramChannel; +import java.util.Enumeration; +import java.util.Map; + +/** + * The default {@link NioDatagramChannelConfig} implementation. 
+ */ +class NioDatagramChannelConfig extends DefaultDatagramChannelConfig { + + private static final Object IP_MULTICAST_TTL; + private static final Object IP_MULTICAST_IF; + private static final Object IP_MULTICAST_LOOP; + private static final Method GET_OPTION; + private static final Method SET_OPTION; + + static { + ClassLoader classLoader = PlatformDependent.getClassLoader(DatagramChannel.class); + Class socketOptionType = null; + try { + socketOptionType = Class.forName("java.net.SocketOption", true, classLoader); + } catch (Exception e) { + // Not Java 7+ + } + Class stdSocketOptionType = null; + try { + stdSocketOptionType = Class.forName("java.net.StandardSocketOptions", true, classLoader); + } catch (Exception e) { + // Not Java 7+ + } + + Object ipMulticastTtl = null; + Object ipMulticastIf = null; + Object ipMulticastLoop = null; + Method getOption = null; + Method setOption = null; + if (socketOptionType != null) { + try { + ipMulticastTtl = stdSocketOptionType.getDeclaredField("IP_MULTICAST_TTL").get(null); + } catch (Exception e) { + throw new Error("cannot locate the IP_MULTICAST_TTL field", e); + } + + try { + ipMulticastIf = stdSocketOptionType.getDeclaredField("IP_MULTICAST_IF").get(null); + } catch (Exception e) { + throw new Error("cannot locate the IP_MULTICAST_IF field", e); + } + + try { + ipMulticastLoop = stdSocketOptionType.getDeclaredField("IP_MULTICAST_LOOP").get(null); + } catch (Exception e) { + throw new Error("cannot locate the IP_MULTICAST_LOOP field", e); + } + + Class networkChannelClass = null; + try { + networkChannelClass = Class.forName("java.nio.channels.NetworkChannel", true, classLoader); + } catch (Throwable ignore) { + // Not Java 7+ + } + + if (networkChannelClass == null) { + getOption = null; + setOption = null; + } else { + try { + getOption = networkChannelClass.getDeclaredMethod("getOption", socketOptionType); + } catch (Exception e) { + throw new Error("cannot locate the getOption() method", e); + } + + try { + 
setOption = networkChannelClass.getDeclaredMethod("setOption", socketOptionType, Object.class); + } catch (Exception e) { + throw new Error("cannot locate the setOption() method", e); + } + } + } + IP_MULTICAST_TTL = ipMulticastTtl; + IP_MULTICAST_IF = ipMulticastIf; + IP_MULTICAST_LOOP = ipMulticastLoop; + GET_OPTION = getOption; + SET_OPTION = setOption; + } + + private final DatagramChannel javaChannel; + + NioDatagramChannelConfig(NioDatagramChannel channel, DatagramChannel javaChannel) { + super(channel, javaChannel.socket()); + this.javaChannel = javaChannel; + } + + @Override + public int getTimeToLive() { + return (Integer) getOption0(IP_MULTICAST_TTL); + } + + @Override + public DatagramChannelConfig setTimeToLive(int ttl) { + setOption0(IP_MULTICAST_TTL, ttl); + return this; + } + + @Override + public InetAddress getInterface() { + NetworkInterface inf = getNetworkInterface(); + if (inf != null) { + Enumeration addresses = SocketUtils.addressesFromNetworkInterface(inf); + if (addresses.hasMoreElements()) { + return addresses.nextElement(); + } + } + return null; + } + + @Override + public DatagramChannelConfig setInterface(InetAddress interfaceAddress) { + try { + setNetworkInterface(NetworkInterface.getByInetAddress(interfaceAddress)); + } catch (SocketException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public NetworkInterface getNetworkInterface() { + return (NetworkInterface) getOption0(IP_MULTICAST_IF); + } + + @Override + public DatagramChannelConfig setNetworkInterface(NetworkInterface networkInterface) { + setOption0(IP_MULTICAST_IF, networkInterface); + return this; + } + + @Override + public boolean isLoopbackModeDisabled() { + return (Boolean) getOption0(IP_MULTICAST_LOOP); + } + + @Override + public DatagramChannelConfig setLoopbackModeDisabled(boolean loopbackModeDisabled) { + setOption0(IP_MULTICAST_LOOP, loopbackModeDisabled); + return this; + } + + @Override + public DatagramChannelConfig 
setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + protected void autoReadCleared() { + ((NioDatagramChannel) channel).clearReadPending0(); + } + + private Object getOption0(Object option) { + if (GET_OPTION == null) { + throw new UnsupportedOperationException(); + } else { + try { + return GET_OPTION.invoke(javaChannel, option); + } catch (Exception e) { + throw new ChannelException(e); + } + } + } + + private void setOption0(Object option, Object value) { + if (SET_OPTION == null) { + throw new UnsupportedOperationException(); + } else { + try { + SET_OPTION.invoke(javaChannel, option, value); + } catch (Exception e) { + throw new ChannelException(e); + } + } + } + + @Override + public boolean setOption(ChannelOption option, T value) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.setOption(javaChannel, (NioChannelOption) option, value); + } + return super.setOption(option, value); + } + + @Override + public T getOption(ChannelOption option) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.getOption(javaChannel, (NioChannelOption) option); + } + return super.getOption(option); + } + + @SuppressWarnings("unchecked") + @Override + public Map, Object> getOptions() { + if (PlatformDependent.javaVersion() >= 7) { + return getOptions(super.getOptions(), NioChannelOption.getOptions(javaChannel)); + } + return super.getOptions(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java new file mode 100644 index 0000000..718132b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioServerSocketChannel.java @@ -0,0 +1,250 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * 
version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.nio; + +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.util.internal.SocketUtils; +import io.netty.channel.nio.AbstractNioMessageChannel; +import io.netty.channel.socket.DefaultServerSocketChannelConfig; +import io.netty.channel.socket.ServerSocketChannelConfig; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.SocketAddress; +import java.nio.channels.SelectionKey; +import java.nio.channels.ServerSocketChannel; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.List; +import java.util.Map; + +/** + * A {@link io.netty.channel.socket.ServerSocketChannel} implementation which uses + * NIO selector based implementation to accept new connections. 
+ */ +public class NioServerSocketChannel extends AbstractNioMessageChannel + implements io.netty.channel.socket.ServerSocketChannel { + + private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); + private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider(); + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioServerSocketChannel.class); + + private static final Method OPEN_SERVER_SOCKET_CHANNEL_WITH_FAMILY = + SelectorProviderUtil.findOpenMethod("openServerSocketChannel"); + + private static ServerSocketChannel newChannel(SelectorProvider provider, InternetProtocolFamily family) { + try { + ServerSocketChannel channel = + SelectorProviderUtil.newChannel(OPEN_SERVER_SOCKET_CHANNEL_WITH_FAMILY, provider, family); + return channel == null ? provider.openServerSocketChannel() : channel; + } catch (IOException e) { + throw new ChannelException("Failed to open a socket.", e); + } + } + + private final ServerSocketChannelConfig config; + + /** + * Create a new instance + */ + public NioServerSocketChannel() { + this(DEFAULT_SELECTOR_PROVIDER); + } + + /** + * Create a new instance using the given {@link SelectorProvider}. + */ + public NioServerSocketChannel(SelectorProvider provider) { + this(provider, null); + } + + /** + * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15). + */ + public NioServerSocketChannel(SelectorProvider provider, InternetProtocolFamily family) { + this(newChannel(provider, family)); + } + + /** + * Create a new instance using the given {@link ServerSocketChannel}. 
+ */ + public NioServerSocketChannel(ServerSocketChannel channel) { + super(null, channel, SelectionKey.OP_ACCEPT); + config = new NioServerSocketChannelConfig(this, javaChannel().socket()); + } + + @Override + public InetSocketAddress localAddress() { + return (InetSocketAddress) super.localAddress(); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + public ServerSocketChannelConfig config() { + return config; + } + + @Override + public boolean isActive() { + // As java.nio.ServerSocketChannel.isBound() will continue to return true even after the channel was closed + // we will also need to check if it is open. + return isOpen() && javaChannel().socket().isBound(); + } + + @Override + public InetSocketAddress remoteAddress() { + return null; + } + + @Override + protected ServerSocketChannel javaChannel() { + return (ServerSocketChannel) super.javaChannel(); + } + + @Override + protected SocketAddress localAddress0() { + return SocketUtils.localSocketAddress(javaChannel().socket()); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @Override + protected void doBind(SocketAddress localAddress) throws Exception { + if (PlatformDependent.javaVersion() >= 7) { + javaChannel().bind(localAddress, config.getBacklog()); + } else { + javaChannel().socket().bind(localAddress, config.getBacklog()); + } + } + + @Override + protected void doClose() throws Exception { + javaChannel().close(); + } + + @Override + protected int doReadMessages(List buf) throws Exception { + SocketChannel ch = SocketUtils.accept(javaChannel()); + + try { + if (ch != null) { + buf.add(new NioSocketChannel(this, ch)); + return 1; + } + } catch (Throwable t) { + logger.warn("Failed to create a new channel from an accepted socket.", t); + + try { + ch.close(); + } catch (Throwable t2) { + logger.warn("Failed to close a socket.", t2); + } + } + + return 0; + } + + // Unnecessary stuff + @Override + protected boolean doConnect( 
+ SocketAddress remoteAddress, SocketAddress localAddress) throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected void doFinishConnect() throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected SocketAddress remoteAddress0() {
+ return null;
+ }
+
+ @Override
+ protected void doDisconnect() throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected boolean doWriteMessage(Object msg, ChannelOutboundBuffer in) throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ @Override
+ protected final Object filterOutboundMessage(Object msg) throws Exception {
+ throw new UnsupportedOperationException();
+ }
+
+ private final class NioServerSocketChannelConfig extends DefaultServerSocketChannelConfig {
+ private NioServerSocketChannelConfig(NioServerSocketChannel channel, ServerSocket javaSocket) {
+ super(channel, javaSocket);
+ }
+
+ @Override
+ protected void autoReadCleared() {
+ clearReadPending();
+ }
+
+ @Override
+ public boolean setOption(ChannelOption option, T value) {
+ if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
+ return NioChannelOption.setOption(jdkChannel(), (NioChannelOption) option, value);
+ }
+ return super.setOption(option, value);
+ }
+
+ @Override
+ public T getOption(ChannelOption option) {
+ if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) {
+ return NioChannelOption.getOption(jdkChannel(), (NioChannelOption) option);
+ }
+ return super.getOption(option);
+ }
+
+ @Override
+ public Map, Object> getOptions() {
+ if (PlatformDependent.javaVersion() >= 7) {
+ return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel()));
+ }
+ return super.getOptions();
+ }
+
+ private ServerSocketChannel jdkChannel() {
+ return ((NioServerSocketChannel) channel).javaChannel();
+ }
+ }
+
+ // Override just to be able to call directly via unit tests.
+ @Override + protected boolean closeOnReadError(Throwable cause) { + return super.closeOnReadError(cause); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java new file mode 100644 index 0000000..c76ed5b --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/NioSocketChannel.java @@ -0,0 +1,537 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.nio; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.channel.FileRegion; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.nio.AbstractNioByteChannel; +import io.netty.channel.socket.DefaultSocketChannelConfig; +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannelConfig; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.lang.reflect.Method; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.SelectionKey; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.Map; +import java.util.concurrent.Executor; + +import static io.netty.channel.internal.ChannelUtils.MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD; + +/** + * {@link io.netty.channel.socket.SocketChannel} which uses NIO selector based implementation. 
+ */ +public class NioSocketChannel extends AbstractNioByteChannel implements io.netty.channel.socket.SocketChannel { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NioSocketChannel.class); + private static final SelectorProvider DEFAULT_SELECTOR_PROVIDER = SelectorProvider.provider(); + + private static final Method OPEN_SOCKET_CHANNEL_WITH_FAMILY = + SelectorProviderUtil.findOpenMethod("openSocketChannel"); + + private final SocketChannelConfig config; + + private static SocketChannel newChannel(SelectorProvider provider, InternetProtocolFamily family) { + try { + SocketChannel channel = SelectorProviderUtil.newChannel(OPEN_SOCKET_CHANNEL_WITH_FAMILY, provider, family); + return channel == null ? provider.openSocketChannel() : channel; + } catch (IOException e) { + throw new ChannelException("Failed to open a socket.", e); + } + } + + /** + * Create a new instance + */ + public NioSocketChannel() { + this(DEFAULT_SELECTOR_PROVIDER); + } + + /** + * Create a new instance using the given {@link SelectorProvider}. + */ + public NioSocketChannel(SelectorProvider provider) { + this(provider, null); + } + + /** + * Create a new instance using the given {@link SelectorProvider} and protocol family (supported only since JDK 15). + */ + public NioSocketChannel(SelectorProvider provider, InternetProtocolFamily family) { + this(newChannel(provider, family)); + } + + /** + * Create a new instance using the given {@link SocketChannel}. 
+ */ + public NioSocketChannel(SocketChannel socket) { + this(null, socket); + } + + /** + * Create a new instance + * + * @param parent the {@link Channel} which created this instance or {@code null} if it was created by the user + * @param socket the {@link SocketChannel} which will be used + */ + public NioSocketChannel(Channel parent, SocketChannel socket) { + super(parent, socket); + config = new NioSocketChannelConfig(this, socket.socket()); + } + + @Override + public ServerSocketChannel parent() { + return (ServerSocketChannel) super.parent(); + } + + @Override + public SocketChannelConfig config() { + return config; + } + + @Override + protected SocketChannel javaChannel() { + return (SocketChannel) super.javaChannel(); + } + + @Override + public boolean isActive() { + SocketChannel ch = javaChannel(); + return ch.isOpen() && ch.isConnected(); + } + + @Override + public boolean isOutputShutdown() { + return javaChannel().socket().isOutputShutdown() || !isActive(); + } + + @Override + public boolean isInputShutdown() { + return javaChannel().socket().isInputShutdown() || !isActive(); + } + + @Override + public boolean isShutdown() { + Socket socket = javaChannel().socket(); + return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive(); + } + + @Override + public InetSocketAddress localAddress() { + return (InetSocketAddress) super.localAddress(); + } + + @Override + public InetSocketAddress remoteAddress() { + return (InetSocketAddress) super.remoteAddress(); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @UnstableApi + @Override + protected final void doShutdownOutput() throws Exception { + if (PlatformDependent.javaVersion() >= 7) { + javaChannel().shutdownOutput(); + } else { + javaChannel().socket().shutdownOutput(); + } + } + + @Override + public ChannelFuture shutdownOutput() { + return shutdownOutput(newPromise()); + } + + @Override + public ChannelFuture shutdownOutput(final ChannelPromise promise) 
{ + final EventLoop loop = eventLoop(); + if (loop.inEventLoop()) { + ((AbstractUnsafe) unsafe()).shutdownOutput(promise); + } else { + loop.execute(new Runnable() { + @Override + public void run() { + ((AbstractUnsafe) unsafe()).shutdownOutput(promise); + } + }); + } + return promise; + } + + @Override + public ChannelFuture shutdownInput() { + return shutdownInput(newPromise()); + } + + @Override + protected boolean isInputShutdown0() { + return isInputShutdown(); + } + + @Override + public ChannelFuture shutdownInput(final ChannelPromise promise) { + EventLoop loop = eventLoop(); + if (loop.inEventLoop()) { + shutdownInput0(promise); + } else { + loop.execute(new Runnable() { + @Override + public void run() { + shutdownInput0(promise); + } + }); + } + return promise; + } + + @Override + public ChannelFuture shutdown() { + return shutdown(newPromise()); + } + + @Override + public ChannelFuture shutdown(final ChannelPromise promise) { + ChannelFuture shutdownOutputFuture = shutdownOutput(); + if (shutdownOutputFuture.isDone()) { + shutdownOutputDone(shutdownOutputFuture, promise); + } else { + shutdownOutputFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception { + shutdownOutputDone(shutdownOutputFuture, promise); + } + }); + } + return promise; + } + + private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) { + ChannelFuture shutdownInputFuture = shutdownInput(); + if (shutdownInputFuture.isDone()) { + shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise); + } else { + shutdownInputFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception { + shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise); + } + }); + } + } + + private static void shutdownDone(ChannelFuture shutdownOutputFuture, + ChannelFuture 
shutdownInputFuture, + ChannelPromise promise) { + Throwable shutdownOutputCause = shutdownOutputFuture.cause(); + Throwable shutdownInputCause = shutdownInputFuture.cause(); + if (shutdownOutputCause != null) { + if (shutdownInputCause != null) { + logger.debug("Exception suppressed because a previous exception occurred.", + shutdownInputCause); + } + promise.setFailure(shutdownOutputCause); + } else if (shutdownInputCause != null) { + promise.setFailure(shutdownInputCause); + } else { + promise.setSuccess(); + } + } + private void shutdownInput0(final ChannelPromise promise) { + try { + shutdownInput0(); + promise.setSuccess(); + } catch (Throwable t) { + promise.setFailure(t); + } + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private void shutdownInput0() throws Exception { + if (PlatformDependent.javaVersion() >= 7) { + javaChannel().shutdownInput(); + } else { + javaChannel().socket().shutdownInput(); + } + } + + @Override + protected SocketAddress localAddress0() { + return javaChannel().socket().getLocalSocketAddress(); + } + + @Override + protected SocketAddress remoteAddress0() { + return javaChannel().socket().getRemoteSocketAddress(); + } + + @Override + protected void doBind(SocketAddress localAddress) throws Exception { + doBind0(localAddress); + } + + private void doBind0(SocketAddress localAddress) throws Exception { + if (PlatformDependent.javaVersion() >= 7) { + SocketUtils.bind(javaChannel(), localAddress); + } else { + SocketUtils.bind(javaChannel().socket(), localAddress); + } + } + + @Override + protected boolean doConnect(SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + if (localAddress != null) { + doBind0(localAddress); + } + + boolean success = false; + try { + boolean connected = SocketUtils.connect(javaChannel(), remoteAddress); + if (!connected) { + selectionKey().interestOps(SelectionKey.OP_CONNECT); + } + success = true; + return connected; + } finally { + if (!success) 
{ + doClose(); + } + } + } + + @Override + protected void doFinishConnect() throws Exception { + if (!javaChannel().finishConnect()) { + throw new Error(); + } + } + + @Override + protected void doDisconnect() throws Exception { + doClose(); + } + + @Override + protected void doClose() throws Exception { + super.doClose(); + javaChannel().close(); + } + + @Override + protected int doReadBytes(ByteBuf byteBuf) throws Exception { + final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle(); + allocHandle.attemptedBytesRead(byteBuf.writableBytes()); + return byteBuf.writeBytes(javaChannel(), allocHandle.attemptedBytesRead()); + } + + @Override + protected int doWriteBytes(ByteBuf buf) throws Exception { + final int expectedWrittenBytes = buf.readableBytes(); + return buf.readBytes(javaChannel(), expectedWrittenBytes); + } + + @Override + protected long doWriteFileRegion(FileRegion region) throws Exception { + final long position = region.transferred(); + return region.transferTo(javaChannel(), position); + } + + private void adjustMaxBytesPerGatheringWrite(int attempted, int written, int oldMaxBytesPerGatheringWrite) { + // By default we track the SO_SNDBUF when ever it is explicitly set. However some OSes may dynamically change + // SO_SNDBUF (and other characteristics that determine how much data can be written at once) so we should try + // make a best effort to adjust as OS behavior changes. 
+ if (attempted == written) { + if (attempted << 1 > oldMaxBytesPerGatheringWrite) { + ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted << 1); + } + } else if (attempted > MAX_BYTES_PER_GATHERING_WRITE_ATTEMPTED_LOW_THRESHOLD && written < attempted >>> 1) { + ((NioSocketChannelConfig) config).setMaxBytesPerGatheringWrite(attempted >>> 1); + } + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + SocketChannel ch = javaChannel(); + int writeSpinCount = config().getWriteSpinCount(); + do { + if (in.isEmpty()) { + // All written so clear OP_WRITE + clearOpWrite(); + // Directly return here so incompleteWrite(...) is not called. + return; + } + + // Ensure the pending writes are made of ByteBufs only. + int maxBytesPerGatheringWrite = ((NioSocketChannelConfig) config).getMaxBytesPerGatheringWrite(); + ByteBuffer[] nioBuffers = in.nioBuffers(1024, maxBytesPerGatheringWrite); + int nioBufferCnt = in.nioBufferCount(); + + // Always use nioBuffers() to workaround data-corruption. + // See https://github.com/netty/netty/issues/2761 + switch (nioBufferCnt) { + case 0: + // We have something else beside ByteBuffers to write so fallback to normal writes. + writeSpinCount -= doWrite0(in); + break; + case 1: { + // Only one ByteBuf so use non-gathering write + // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need + // to check if the total size of all the buffers is non-zero. 
+ ByteBuffer buffer = nioBuffers[0]; + int attemptedBytes = buffer.remaining(); + final int localWrittenBytes = ch.write(buffer); + if (localWrittenBytes <= 0) { + incompleteWrite(true); + return; + } + adjustMaxBytesPerGatheringWrite(attemptedBytes, localWrittenBytes, maxBytesPerGatheringWrite); + in.removeBytes(localWrittenBytes); + --writeSpinCount; + break; + } + default: { + // Zero length buffers are not added to nioBuffers by ChannelOutboundBuffer, so there is no need + // to check if the total size of all the buffers is non-zero. + // We limit the max amount to int above so cast is safe + long attemptedBytes = in.nioBufferSize(); + final long localWrittenBytes = ch.write(nioBuffers, 0, nioBufferCnt); + if (localWrittenBytes <= 0) { + incompleteWrite(true); + return; + } + // Casting to int is safe because we limit the total amount of data in the nioBuffers to int above. + adjustMaxBytesPerGatheringWrite((int) attemptedBytes, (int) localWrittenBytes, + maxBytesPerGatheringWrite); + in.removeBytes(localWrittenBytes); + --writeSpinCount; + break; + } + } + } while (writeSpinCount > 0); + + incompleteWrite(writeSpinCount < 0); + } + + @Override + protected AbstractNioUnsafe newUnsafe() { + return new NioSocketChannelUnsafe(); + } + + private final class NioSocketChannelUnsafe extends NioByteUnsafe { + @Override + protected Executor prepareToClose() { + try { + if (javaChannel().isOpen() && config().getSoLinger() > 0) { + // We need to cancel this key of the channel so we may not end up in a eventloop spin + // because we try to read or write until the actual close happens which may be later due + // SO_LINGER handling. + // See https://github.com/netty/netty/issues/4449 + doDeregister(); + return GlobalEventExecutor.INSTANCE; + } + } catch (Throwable ignore) { + // Ignore the error as the underlying channel may be closed in the meantime and so + // getSoLinger() may produce an exception. In this case we just return null. 
+ // See https://github.com/netty/netty/issues/4449 + } + return null; + } + } + + private final class NioSocketChannelConfig extends DefaultSocketChannelConfig { + private volatile int maxBytesPerGatheringWrite = Integer.MAX_VALUE; + private NioSocketChannelConfig(NioSocketChannel channel, Socket javaSocket) { + super(channel, javaSocket); + calculateMaxBytesPerGatheringWrite(); + } + + @Override + protected void autoReadCleared() { + clearReadPending(); + } + + @Override + public NioSocketChannelConfig setSendBufferSize(int sendBufferSize) { + super.setSendBufferSize(sendBufferSize); + calculateMaxBytesPerGatheringWrite(); + return this; + } + + @Override + public boolean setOption(ChannelOption option, T value) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.setOption(jdkChannel(), (NioChannelOption) option, value); + } + return super.setOption(option, value); + } + + @Override + public T getOption(ChannelOption option) { + if (PlatformDependent.javaVersion() >= 7 && option instanceof NioChannelOption) { + return NioChannelOption.getOption(jdkChannel(), (NioChannelOption) option); + } + return super.getOption(option); + } + + @Override + public Map, Object> getOptions() { + if (PlatformDependent.javaVersion() >= 7) { + return getOptions(super.getOptions(), NioChannelOption.getOptions(jdkChannel())); + } + return super.getOptions(); + } + + void setMaxBytesPerGatheringWrite(int maxBytesPerGatheringWrite) { + this.maxBytesPerGatheringWrite = maxBytesPerGatheringWrite; + } + + int getMaxBytesPerGatheringWrite() { + return maxBytesPerGatheringWrite; + } + + private void calculateMaxBytesPerGatheringWrite() { + // Multiply by 2 to give some extra space in case the OS can process write data faster than we can provide. 
+ int newSendBufferSize = getSendBufferSize() << 1; + if (newSendBufferSize > 0) { + setMaxBytesPerGatheringWrite(newSendBufferSize); + } + } + + private SocketChannel jdkChannel() { + return ((NioSocketChannel) channel).javaChannel(); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/ProtocolFamilyConverter.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/ProtocolFamilyConverter.java new file mode 100644 index 0000000..527b4eb --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/ProtocolFamilyConverter.java @@ -0,0 +1,47 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.nio; + +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.util.internal.SuppressJava6Requirement; + +import java.net.ProtocolFamily; +import java.net.StandardProtocolFamily; + +/** + * Helper class which convert the {@link InternetProtocolFamily}. + */ +final class ProtocolFamilyConverter { + + private ProtocolFamilyConverter() { + // Utility class + } + + /** + * Convert the {@link InternetProtocolFamily}. This MUST only be called on jdk version >= 7. 
+ */ + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + public static ProtocolFamily convert(InternetProtocolFamily family) { + switch (family) { + case IPv4: + return StandardProtocolFamily.INET; + case IPv6: + return StandardProtocolFamily.INET6; + default: + throw new IllegalArgumentException(); + } + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/SelectorProviderUtil.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/SelectorProviderUtil.java new file mode 100644 index 0000000..0c05d01 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/SelectorProviderUtil.java @@ -0,0 +1,71 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.nio; + +import io.netty.channel.socket.InternetProtocolFamily; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.channels.Channel; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; + +final class SelectorProviderUtil { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SelectorProviderUtil.class); + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + static Method findOpenMethod(String methodName) { + if (PlatformDependent.javaVersion() >= 15) { + try { + return SelectorProvider.class.getMethod(methodName, java.net.ProtocolFamily.class); + } catch (Throwable e) { + logger.debug("SelectorProvider.{}(ProtocolFamily) not available, will use default", methodName, e); + } + } + return null; + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + static C newChannel(Method method, SelectorProvider provider, + InternetProtocolFamily family) throws IOException { + /** + * Use the {@link SelectorProvider} to open {@link SocketChannel} and so remove condition in + * {@link SelectorProvider#provider()} which is called by each SocketChannel.open() otherwise. + * + * See #2308. 
+ */ + if (family != null && method != null) { + try { + @SuppressWarnings("unchecked") + C channel = (C) method.invoke( + provider, ProtocolFamilyConverter.convert(family)); + return channel; + } catch (InvocationTargetException e) { + throw new IOException(e); + } catch (IllegalAccessException e) { + throw new IOException(e); + } + } + return null; + } + + private SelectorProviderUtil() { } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/nio/package-info.java b/netty-channel/src/main/java/io/netty/channel/socket/nio/package-info.java new file mode 100644 index 0000000..a0942f6 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/nio/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * NIO-based socket channel + * API implementation - recommended for a large number of connections (>= 1000). 
+ */ +package io.netty.channel.socket.nio; diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioDatagramChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioDatagramChannelConfig.java new file mode 100644 index 0000000..b490039 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioDatagramChannelConfig.java @@ -0,0 +1,207 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.PreferHeapByteBufAllocator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DefaultDatagramChannelConfig; + +import java.io.IOException; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.util.Map; + +import static io.netty.channel.ChannelOption.SO_TIMEOUT; + +final class DefaultOioDatagramChannelConfig extends DefaultDatagramChannelConfig implements OioDatagramChannelConfig { + + DefaultOioDatagramChannelConfig(DatagramChannel channel, DatagramSocket javaSocket) { + super(channel, javaSocket); + setAllocator(new PreferHeapByteBufAllocator(getAllocator())); + } + + @Override + public Map, Object> getOptions() { + return getOptions(super.getOptions(), SO_TIMEOUT); + } + + @SuppressWarnings("unchecked") + @Override + public T getOption(ChannelOption option) { + if (option == SO_TIMEOUT) { + return (T) Integer.valueOf(getSoTimeout()); + } + return super.getOption(option); + } + + @Override + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == SO_TIMEOUT) { + setSoTimeout((Integer) value); + } else { + return super.setOption(option, value); + } + return true; + } + + @Override + public OioDatagramChannelConfig setSoTimeout(int timeout) { + try { + javaSocket().setSoTimeout(timeout); + } catch (IOException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public int getSoTimeout() { + try { + return javaSocket().getSoTimeout(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + public OioDatagramChannelConfig setBroadcast(boolean 
broadcast) { + super.setBroadcast(broadcast); + return this; + } + + @Override + public OioDatagramChannelConfig setInterface(InetAddress interfaceAddress) { + super.setInterface(interfaceAddress); + return this; + } + + @Override + public OioDatagramChannelConfig setLoopbackModeDisabled(boolean loopbackModeDisabled) { + super.setLoopbackModeDisabled(loopbackModeDisabled); + return this; + } + + @Override + public OioDatagramChannelConfig setNetworkInterface(NetworkInterface networkInterface) { + super.setNetworkInterface(networkInterface); + return this; + } + + @Override + public OioDatagramChannelConfig setReuseAddress(boolean reuseAddress) { + super.setReuseAddress(reuseAddress); + return this; + } + + @Override + public OioDatagramChannelConfig setReceiveBufferSize(int receiveBufferSize) { + super.setReceiveBufferSize(receiveBufferSize); + return this; + } + + @Override + public OioDatagramChannelConfig setSendBufferSize(int sendBufferSize) { + super.setSendBufferSize(sendBufferSize); + return this; + } + + @Override + public OioDatagramChannelConfig setTimeToLive(int ttl) { + super.setTimeToLive(ttl); + return this; + } + + @Override + public OioDatagramChannelConfig setTrafficClass(int trafficClass) { + super.setTrafficClass(trafficClass); + return this; + } + + @Override + public OioDatagramChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } + + @Override + public OioDatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + public OioDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public OioDatagramChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public OioDatagramChannelConfig 
setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public OioDatagramChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + public OioDatagramChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return this; + } + + @Override + public OioDatagramChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + super.setWriteBufferHighWaterMark(writeBufferHighWaterMark); + return this; + } + + @Override + public OioDatagramChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + super.setWriteBufferLowWaterMark(writeBufferLowWaterMark); + return this; + } + + @Override + public OioDatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public OioDatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java new file mode 100644 index 0000000..bc65f70 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioServerSocketChannelConfig.java @@ -0,0 +1,197 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.PreferHeapByteBufAllocator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.DefaultServerSocketChannelConfig; +import io.netty.channel.socket.ServerSocketChannel; + +import java.io.IOException; +import java.net.ServerSocket; +import java.util.Map; + +import static io.netty.channel.ChannelOption.*; + +/** + * Default {@link OioServerSocketChannelConfig} implementation + * + * @deprecated use NIO / EPOLL / KQUEUE transport. 
+ */ +@Deprecated +public class DefaultOioServerSocketChannelConfig extends DefaultServerSocketChannelConfig implements + OioServerSocketChannelConfig { + + @Deprecated + public DefaultOioServerSocketChannelConfig(ServerSocketChannel channel, ServerSocket javaSocket) { + super(channel, javaSocket); + setAllocator(new PreferHeapByteBufAllocator(getAllocator())); + } + + DefaultOioServerSocketChannelConfig(OioServerSocketChannel channel, ServerSocket javaSocket) { + super(channel, javaSocket); + setAllocator(new PreferHeapByteBufAllocator(getAllocator())); + } + + @Override + public Map, Object> getOptions() { + return getOptions( + super.getOptions(), SO_TIMEOUT); + } + + @SuppressWarnings("unchecked") + @Override + public T getOption(ChannelOption option) { + if (option == SO_TIMEOUT) { + return (T) Integer.valueOf(getSoTimeout()); + } + return super.getOption(option); + } + + @Override + public boolean setOption(ChannelOption option, T value) { + validate(option, value); + + if (option == SO_TIMEOUT) { + setSoTimeout((Integer) value); + } else { + return super.setOption(option, value); + } + return true; + } + + @Override + public OioServerSocketChannelConfig setSoTimeout(int timeout) { + try { + javaSocket.setSoTimeout(timeout); + } catch (IOException e) { + throw new ChannelException(e); + } + return this; + } + + @Override + public int getSoTimeout() { + try { + return javaSocket.getSoTimeout(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + public OioServerSocketChannelConfig setBacklog(int backlog) { + super.setBacklog(backlog); + return this; + } + + @Override + public OioServerSocketChannelConfig setReuseAddress(boolean reuseAddress) { + super.setReuseAddress(reuseAddress); + return this; + } + + @Override + public OioServerSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) { + super.setReceiveBufferSize(receiveBufferSize); + return this; + } + + @Override + public OioServerSocketChannelConfig 
setPerformancePreferences(int connectionTime, int latency, int bandwidth) { + super.setPerformancePreferences(connectionTime, latency, bandwidth); + return this; + } + + @Override + public OioServerSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public OioServerSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public OioServerSocketChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } + + @Override + public OioServerSocketChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public OioServerSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public OioServerSocketChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + protected void autoReadCleared() { + if (channel instanceof OioServerSocketChannel) { + ((OioServerSocketChannel) channel).clearReadPending0(); + } + } + + @Override + public OioServerSocketChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return this; + } + + @Override + public OioServerSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + super.setWriteBufferHighWaterMark(writeBufferHighWaterMark); + return this; + } + + @Override + public OioServerSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + super.setWriteBufferLowWaterMark(writeBufferLowWaterMark); + return this; + } + + @Override + public OioServerSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + 
super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public OioServerSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java new file mode 100644 index 0000000..0dae3e2 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/DefaultOioSocketChannelConfig.java @@ -0,0 +1,225 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.PreferHeapByteBufAllocator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.DefaultSocketChannelConfig; +import io.netty.channel.socket.SocketChannel; + +import java.io.IOException; +import java.net.Socket; +import java.util.Map; + +import static io.netty.channel.ChannelOption.*; + +/** + * Default {@link OioSocketChannelConfig} implementation + * + * @deprecated use NIO / EPOLL / KQUEUE transport. 
+ */
+@Deprecated
+public class DefaultOioSocketChannelConfig extends DefaultSocketChannelConfig implements OioSocketChannelConfig {
+    @Deprecated
+    public DefaultOioSocketChannelConfig(SocketChannel channel, Socket javaSocket) {
+        super(channel, javaSocket);
+        setAllocator(new PreferHeapByteBufAllocator(getAllocator()));
+    }
+
+    DefaultOioSocketChannelConfig(OioSocketChannel channel, Socket javaSocket) {
+        super(channel, javaSocket);
+        setAllocator(new PreferHeapByteBufAllocator(getAllocator()));
+    }
+
+    @Override
+    public Map<ChannelOption<?>, Object> getOptions() {
+        return getOptions(
+                super.getOptions(), SO_TIMEOUT);
+    }
+
+    @SuppressWarnings("unchecked")
+    @Override
+    public <T> T getOption(ChannelOption<T> option) {
+        if (option == SO_TIMEOUT) {
+            return (T) Integer.valueOf(getSoTimeout());
+        }
+        return super.getOption(option);
+    }
+
+    @Override
+    public <T> boolean setOption(ChannelOption<T> option, T value) {
+        validate(option, value);
+
+        if (option == SO_TIMEOUT) {
+            setSoTimeout((Integer) value);
+        } else {
+            return super.setOption(option, value);
+        }
+        return true;
+    }
+
+    @Override
+    public OioSocketChannelConfig setSoTimeout(int timeout) {
+        try {
+            javaSocket.setSoTimeout(timeout);
+        } catch (IOException e) {
+            throw new ChannelException(e);
+        }
+        return this;
+    }
+
+    @Override
+    public int getSoTimeout() {
+        try {
+            return javaSocket.getSoTimeout();
+        } catch (IOException e) {
+            throw new ChannelException(e);
+        }
+    }
+
+    @Override
+    public OioSocketChannelConfig setTcpNoDelay(boolean tcpNoDelay) {
+        super.setTcpNoDelay(tcpNoDelay);
+        return this;
+    }
+
+    @Override
+    public OioSocketChannelConfig setSoLinger(int soLinger) {
+        super.setSoLinger(soLinger);
+        return this;
+    }
+
+    @Override
+    public OioSocketChannelConfig setSendBufferSize(int sendBufferSize) {
+        super.setSendBufferSize(sendBufferSize);
+        return this;
+    }
+
+    @Override
+    public OioSocketChannelConfig setReceiveBufferSize(int receiveBufferSize) {
+        super.setReceiveBufferSize(receiveBufferSize);
+        return this;
+    }
+
+
@Override + public OioSocketChannelConfig setKeepAlive(boolean keepAlive) { + super.setKeepAlive(keepAlive); + return this; + } + + @Override + public OioSocketChannelConfig setTrafficClass(int trafficClass) { + super.setTrafficClass(trafficClass); + return this; + } + + @Override + public OioSocketChannelConfig setReuseAddress(boolean reuseAddress) { + super.setReuseAddress(reuseAddress); + return this; + } + + @Override + public OioSocketChannelConfig setPerformancePreferences(int connectionTime, int latency, int bandwidth) { + super.setPerformancePreferences(connectionTime, latency, bandwidth); + return this; + } + + @Override + public OioSocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure) { + super.setAllowHalfClosure(allowHalfClosure); + return this; + } + + @Override + public OioSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis) { + super.setConnectTimeoutMillis(connectTimeoutMillis); + return this; + } + + @Override + @Deprecated + public OioSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead) { + super.setMaxMessagesPerRead(maxMessagesPerRead); + return this; + } + + @Override + public OioSocketChannelConfig setWriteSpinCount(int writeSpinCount) { + super.setWriteSpinCount(writeSpinCount); + return this; + } + + @Override + public OioSocketChannelConfig setAllocator(ByteBufAllocator allocator) { + super.setAllocator(allocator); + return this; + } + + @Override + public OioSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + super.setRecvByteBufAllocator(allocator); + return this; + } + + @Override + public OioSocketChannelConfig setAutoRead(boolean autoRead) { + super.setAutoRead(autoRead); + return this; + } + + @Override + protected void autoReadCleared() { + if (channel instanceof OioSocketChannel) { + ((OioSocketChannel) channel).clearReadPending0(); + } + } + + @Override + public OioSocketChannelConfig setAutoClose(boolean autoClose) { + super.setAutoClose(autoClose); + return 
this; + } + + @Override + public OioSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark) { + super.setWriteBufferHighWaterMark(writeBufferHighWaterMark); + return this; + } + + @Override + public OioSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark) { + super.setWriteBufferLowWaterMark(writeBufferLowWaterMark); + return this; + } + + @Override + public OioSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark) { + super.setWriteBufferWaterMark(writeBufferWaterMark); + return this; + } + + @Override + public OioSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + super.setMessageSizeEstimator(estimator); + return this; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java new file mode 100644 index 0000000..36bca4d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannel.java @@ -0,0 +1,460 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.AddressedEnvelope; +import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.oio.AbstractOioMessageChannel; +import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramChannelConfig; +import io.netty.channel.socket.DatagramPacket; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.MulticastSocket; +import java.net.NetworkInterface; +import java.net.SocketAddress; +import java.net.SocketException; +import java.net.SocketTimeoutException; +import java.nio.channels.NotYetConnectedException; +import java.nio.channels.UnresolvedAddressException; +import java.util.List; +import java.util.Locale; + +/** + * An OIO datagram {@link Channel} that sends and receives an + * {@link AddressedEnvelope}. + * + * @see AddressedEnvelope + * @see DatagramPacket + * @deprecated use NIO / EPOLL / KQUEUE transport. 
+ */ +@Deprecated +public class OioDatagramChannel extends AbstractOioMessageChannel + implements DatagramChannel { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(OioDatagramChannel.class); + + private static final ChannelMetadata METADATA = new ChannelMetadata(true); + private static final String EXPECTED_TYPES = + " (expected: " + StringUtil.simpleClassName(DatagramPacket.class) + ", " + + StringUtil.simpleClassName(AddressedEnvelope.class) + '<' + + StringUtil.simpleClassName(ByteBuf.class) + ", " + + StringUtil.simpleClassName(SocketAddress.class) + ">, " + + StringUtil.simpleClassName(ByteBuf.class) + ')'; + + private final MulticastSocket socket; + private final OioDatagramChannelConfig config; + private final java.net.DatagramPacket tmpPacket = new java.net.DatagramPacket(EmptyArrays.EMPTY_BYTES, 0); + + private static MulticastSocket newSocket() { + try { + return new MulticastSocket(null); + } catch (Exception e) { + throw new ChannelException("failed to create a new socket", e); + } + } + + /** + * Create a new instance with an new {@link MulticastSocket}. + */ + public OioDatagramChannel() { + this(newSocket()); + } + + /** + * Create a new instance from the given {@link MulticastSocket}. + * + * @param socket the {@link MulticastSocket} which is used by this instance + */ + public OioDatagramChannel(MulticastSocket socket) { + super(null); + + boolean success = false; + try { + socket.setSoTimeout(SO_TIMEOUT); + socket.setBroadcast(false); + success = true; + } catch (SocketException e) { + throw new ChannelException( + "Failed to configure the datagram socket timeout.", e); + } finally { + if (!success) { + socket.close(); + } + } + + this.socket = socket; + config = new DefaultOioDatagramChannelConfig(this, socket); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + /** + * {@inheritDoc} + * + * This can be safely cast to {@link OioDatagramChannelConfig}. 
+     */
+    @Override
+    // TODO: Change return type to OioDatagramChannelConfig in next major release
+    public DatagramChannelConfig config() {
+        return config;
+    }
+
+    @Override
+    public boolean isOpen() {
+        return !socket.isClosed();
+    }
+
+    @Override
+    @SuppressWarnings("deprecation")
+    public boolean isActive() {
+        return isOpen()
+            && (config.getOption(ChannelOption.DATAGRAM_CHANNEL_ACTIVE_ON_REGISTRATION) && isRegistered()
+                || socket.isBound());
+    }
+
+    @Override
+    public boolean isConnected() {
+        return socket.isConnected();
+    }
+
+    @Override
+    protected SocketAddress localAddress0() {
+        return socket.getLocalSocketAddress();
+    }
+
+    @Override
+    protected SocketAddress remoteAddress0() {
+        return socket.getRemoteSocketAddress();
+    }
+
+    @Override
+    protected void doBind(SocketAddress localAddress) throws Exception {
+        socket.bind(localAddress);
+    }
+
+    @Override
+    public InetSocketAddress localAddress() {
+        return (InetSocketAddress) super.localAddress();
+    }
+
+    @Override
+    public InetSocketAddress remoteAddress() {
+        return (InetSocketAddress) super.remoteAddress();
+    }
+
+    @Override
+    protected void doConnect(SocketAddress remoteAddress,
+            SocketAddress localAddress) throws Exception {
+        if (localAddress != null) {
+            socket.bind(localAddress);
+        }
+
+        boolean success = false;
+        try {
+            socket.connect(remoteAddress);
+            success = true;
+        } finally {
+            if (!success) {
+                try {
+                    socket.close();
+                } catch (Throwable t) {
+                    logger.warn("Failed to close a socket.", t);
+                }
+            }
+        }
+    }
+
+    @Override
+    protected void doDisconnect() throws Exception {
+        socket.disconnect();
+    }
+
+    @Override
+    protected void doClose() throws Exception {
+        socket.close();
+    }
+
+    @Override
+    protected int doReadMessages(List<Object> buf) throws Exception {
+        DatagramChannelConfig config = config();
+        final RecvByteBufAllocator.Handle allocHandle = unsafe().recvBufAllocHandle();
+
+        ByteBuf data = config.getAllocator().heapBuffer(allocHandle.guess());
+        boolean free = true;
+        try {
+            // Ensure we null out the address which may have been set before.
+            tmpPacket.setAddress(null);
+            tmpPacket.setData(data.array(), data.arrayOffset(), data.capacity());
+            socket.receive(tmpPacket);
+
+            InetSocketAddress remoteAddr = (InetSocketAddress) tmpPacket.getSocketAddress();
+
+            allocHandle.lastBytesRead(tmpPacket.getLength());
+            buf.add(new DatagramPacket(data.writerIndex(allocHandle.lastBytesRead()), localAddress(), remoteAddr));
+            free = false;
+            return 1;
+        } catch (SocketTimeoutException e) {
+            // Expected
+            return 0;
+        } catch (SocketException e) {
+            if (!e.getMessage().toLowerCase(Locale.US).contains("socket closed")) {
+                throw e;
+            }
+            return -1;
+        } catch (Throwable cause) {
+            PlatformDependent.throwException(cause);
+            return -1;
+        } finally {
+            if (free) {
+                data.release();
+            }
+        }
+    }
+
+    @Override
+    protected void doWrite(ChannelOutboundBuffer in) throws Exception {
+        for (;;) {
+            final Object o = in.current();
+            if (o == null) {
+                break;
+            }
+
+            final ByteBuf data;
+            final SocketAddress remoteAddress;
+            if (o instanceof AddressedEnvelope) {
+                @SuppressWarnings("unchecked")
+                AddressedEnvelope<ByteBuf, SocketAddress> envelope = (AddressedEnvelope<ByteBuf, SocketAddress>) o;
+                remoteAddress = envelope.recipient();
+                data = envelope.content();
+            } else {
+                data = (ByteBuf) o;
+                remoteAddress = null;
+            }
+
+            final int length = data.readableBytes();
+            try {
+                if (remoteAddress != null) {
+                    tmpPacket.setSocketAddress(remoteAddress);
+                } else {
+                    if (!isConnected()) {
+                        // If not connected we should throw a NotYetConnectedException() to be consistent with
+                        // NioDatagramChannel
+                        throw new NotYetConnectedException();
+                    }
+                    // Ensure we null out the address which may have been set before.
+                    tmpPacket.setAddress(null);
+                }
+                if (data.hasArray()) {
+                    tmpPacket.setData(data.array(), data.arrayOffset() + data.readerIndex(), length);
+                } else {
+                    tmpPacket.setData(ByteBufUtil.getBytes(data, data.readerIndex(), length));
+                }
+                socket.send(tmpPacket);
+                in.remove();
+            } catch (Exception e) {
+                // Continue on write error as a DatagramChannel can write to multiple remote peers
+                //
+                // See https://github.com/netty/netty/issues/2665
+                in.remove(e);
+            }
+        }
+    }
+
+    private static void checkUnresolved(AddressedEnvelope<?, ?> envelope) {
+        if (envelope.recipient() instanceof InetSocketAddress
+                && (((InetSocketAddress) envelope.recipient()).isUnresolved())) {
+            throw new UnresolvedAddressException();
+        }
+    }
+
+    @Override
+    protected Object filterOutboundMessage(Object msg) {
+        if (msg instanceof DatagramPacket) {
+            checkUnresolved((DatagramPacket) msg);
+            return msg;
+        }
+
+        if (msg instanceof ByteBuf) {
+            return msg;
+        }
+
+        if (msg instanceof AddressedEnvelope) {
+            @SuppressWarnings("unchecked")
+            AddressedEnvelope<Object, SocketAddress> e = (AddressedEnvelope<Object, SocketAddress>) msg;
+            checkUnresolved(e);
+            if (e.content() instanceof ByteBuf) {
+                return msg;
+            }
+        }
+
+        throw new UnsupportedOperationException(
+                "unsupported message type: " + StringUtil.simpleClassName(msg) + EXPECTED_TYPES);
+    }
+
+    @Override
+    public ChannelFuture joinGroup(InetAddress multicastAddress) {
+        return joinGroup(multicastAddress, newPromise());
+    }
+
+    @Override
+    public ChannelFuture joinGroup(InetAddress multicastAddress, ChannelPromise promise) {
+        ensureBound();
+        try {
+            socket.joinGroup(multicastAddress);
+            promise.setSuccess();
+        } catch (IOException e) {
+            promise.setFailure(e);
+        }
+        return promise;
+    }
+
+    @Override
+    public ChannelFuture joinGroup(InetSocketAddress multicastAddress, NetworkInterface networkInterface) {
+        return joinGroup(multicastAddress, networkInterface, newPromise());
+    }
+
+    @Override
+    public ChannelFuture joinGroup(
+            InetSocketAddress multicastAddress, NetworkInterface networkInterface,
+
ChannelPromise promise) { + ensureBound(); + try { + socket.joinGroup(multicastAddress, networkInterface); + promise.setSuccess(); + } catch (IOException e) { + promise.setFailure(e); + } + return promise; + } + + @Override + public ChannelFuture joinGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source) { + return newFailedFuture(new UnsupportedOperationException()); + } + + @Override + public ChannelFuture joinGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source, + ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + return promise; + } + + private void ensureBound() { + if (!isActive()) { + throw new IllegalStateException( + DatagramChannel.class.getName() + + " must be bound to join a group."); + } + } + + @Override + public ChannelFuture leaveGroup(InetAddress multicastAddress) { + return leaveGroup(multicastAddress, newPromise()); + } + + @Override + public ChannelFuture leaveGroup(InetAddress multicastAddress, ChannelPromise promise) { + try { + socket.leaveGroup(multicastAddress); + promise.setSuccess(); + } catch (IOException e) { + promise.setFailure(e); + } + return promise; + } + + @Override + public ChannelFuture leaveGroup( + InetSocketAddress multicastAddress, NetworkInterface networkInterface) { + return leaveGroup(multicastAddress, networkInterface, newPromise()); + } + + @Override + public ChannelFuture leaveGroup( + InetSocketAddress multicastAddress, NetworkInterface networkInterface, + ChannelPromise promise) { + try { + socket.leaveGroup(multicastAddress, networkInterface); + promise.setSuccess(); + } catch (IOException e) { + promise.setFailure(e); + } + return promise; + } + + @Override + public ChannelFuture leaveGroup( + InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source) { + return newFailedFuture(new UnsupportedOperationException()); + } + + @Override + public ChannelFuture leaveGroup( + 
InetAddress multicastAddress, NetworkInterface networkInterface, InetAddress source, + ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + return promise; + } + + @Override + public ChannelFuture block(InetAddress multicastAddress, + NetworkInterface networkInterface, InetAddress sourceToBlock) { + return newFailedFuture(new UnsupportedOperationException()); + } + + @Override + public ChannelFuture block(InetAddress multicastAddress, + NetworkInterface networkInterface, InetAddress sourceToBlock, + ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + return promise; + } + + @Override + public ChannelFuture block(InetAddress multicastAddress, + InetAddress sourceToBlock) { + return newFailedFuture(new UnsupportedOperationException()); + } + + @Override + public ChannelFuture block(InetAddress multicastAddress, + InetAddress sourceToBlock, ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + return promise; + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java new file mode 100644 index 0000000..4c1bd64 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioDatagramChannelConfig.java @@ -0,0 +1,101 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.DatagramChannelConfig; + +import java.net.InetAddress; +import java.net.NetworkInterface; + +/** + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated +public interface OioDatagramChannelConfig extends DatagramChannelConfig { + /** + * Sets the maximal time a operation on the underlying socket may block. + */ + OioDatagramChannelConfig setSoTimeout(int timeout); + + /** + * Returns the maximal time a operation on the underlying socket may block. + */ + int getSoTimeout(); + + @Override + OioDatagramChannelConfig setSendBufferSize(int sendBufferSize); + + @Override + OioDatagramChannelConfig setReceiveBufferSize(int receiveBufferSize); + + @Override + OioDatagramChannelConfig setTrafficClass(int trafficClass); + + @Override + OioDatagramChannelConfig setReuseAddress(boolean reuseAddress); + + @Override + OioDatagramChannelConfig setBroadcast(boolean broadcast); + + @Override + OioDatagramChannelConfig setLoopbackModeDisabled(boolean loopbackModeDisabled); + + @Override + OioDatagramChannelConfig setTimeToLive(int ttl); + + @Override + OioDatagramChannelConfig setInterface(InetAddress interfaceAddress); + + @Override + OioDatagramChannelConfig setNetworkInterface(NetworkInterface networkInterface); + + @Override + OioDatagramChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + OioDatagramChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + OioDatagramChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + OioDatagramChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + OioDatagramChannelConfig 
setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + OioDatagramChannelConfig setAutoRead(boolean autoRead); + + @Override + OioDatagramChannelConfig setAutoClose(boolean autoClose); + + @Override + OioDatagramChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); + + @Override + OioDatagramChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); + + @Override + OioDatagramChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark); + + @Override + OioDatagramChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java new file mode 100644 index 0000000..7d7e147 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannel.java @@ -0,0 +1,207 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.oio; + +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.oio.AbstractOioMessageChannel; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketTimeoutException; +import java.util.List; + +/** + * {@link ServerSocketChannel} which accepts new connections and create the {@link OioSocketChannel}'s for them. + * + * This implementation use Old-Blocking-IO. + * + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated +public class OioServerSocketChannel extends AbstractOioMessageChannel + implements ServerSocketChannel { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(OioServerSocketChannel.class); + + private static final ChannelMetadata METADATA = new ChannelMetadata(false, 1); + + private static ServerSocket newServerSocket() { + try { + return new ServerSocket(); + } catch (IOException e) { + throw new ChannelException("failed to create a server socket", e); + } + } + + final ServerSocket socket; + private final OioServerSocketChannelConfig config; + + /** + * Create a new instance with an new {@link Socket} + */ + public OioServerSocketChannel() { + this(newServerSocket()); + } + + /** + * Create a new instance from the given {@link ServerSocket} + * + * @param socket the {@link ServerSocket} which is used by this instance + */ + public OioServerSocketChannel(ServerSocket socket) { + super(null); + ObjectUtil.checkNotNull(socket, "socket"); + + boolean success = false; + try { + 
socket.setSoTimeout(SO_TIMEOUT);
+            success = true;
+        } catch (IOException e) {
+            throw new ChannelException(
+                    "Failed to set the server socket timeout.", e);
+        } finally {
+            if (!success) {
+                try {
+                    socket.close();
+                } catch (IOException e) {
+                    if (logger.isWarnEnabled()) {
+                        logger.warn(
+                                "Failed to close a partially initialized socket.", e);
+                    }
+                }
+            }
+        }
+        this.socket = socket;
+        config = new DefaultOioServerSocketChannelConfig(this, socket);
+    }
+
+    @Override
+    public InetSocketAddress localAddress() {
+        return (InetSocketAddress) super.localAddress();
+    }
+
+    @Override
+    public ChannelMetadata metadata() {
+        return METADATA;
+    }
+
+    @Override
+    public OioServerSocketChannelConfig config() {
+        return config;
+    }
+
+    @Override
+    public InetSocketAddress remoteAddress() {
+        return null;
+    }
+
+    @Override
+    public boolean isOpen() {
+        return !socket.isClosed();
+    }
+
+    @Override
+    public boolean isActive() {
+        return isOpen() && socket.isBound();
+    }
+
+    @Override
+    protected SocketAddress localAddress0() {
+        return SocketUtils.localSocketAddress(socket);
+    }
+
+    @Override
+    protected void doBind(SocketAddress localAddress) throws Exception {
+        socket.bind(localAddress, config.getBacklog());
+    }
+
+    @Override
+    protected void doClose() throws Exception {
+        socket.close();
+    }
+
+    @Override
+    protected int doReadMessages(List<Object> buf) throws Exception {
+        if (socket.isClosed()) {
+            return -1;
+        }
+
+        try {
+            Socket s = socket.accept();
+            try {
+                buf.add(new OioSocketChannel(this, s));
+                return 1;
+            } catch (Throwable t) {
+                logger.warn("Failed to create a new channel from an accepted socket.", t);
+                try {
+                    s.close();
+                } catch (Throwable t2) {
+                    logger.warn("Failed to close a socket.", t2);
+                }
+            }
+        } catch (SocketTimeoutException e) {
+            // Expected
+        }
+        return 0;
+    }
+
+    @Override
+    protected void doWrite(ChannelOutboundBuffer in) throws Exception {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    protected Object filterOutboundMessage(Object
msg) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + protected void doConnect( + SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doDisconnect() throws Exception { + throw new UnsupportedOperationException(); + } + + @Deprecated + @Override + protected void setReadPending(boolean readPending) { + super.setReadPending(readPending); + } + + final void clearReadPending0() { + super.clearReadPending(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java new file mode 100644 index 0000000..42cad8d --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioServerSocketChannelConfig.java @@ -0,0 +1,103 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.ServerSocketChannelConfig; + + +/** + * A {@link ServerSocketChannelConfig} for a {@link OioServerSocketChannel}. + * + *

<h3>Available options</h3>
+ *
+ * In addition to the options provided by {@link ServerSocketChannelConfig},
+ * {@link OioServerSocketChannelConfig} allows the following options in the
+ * option map:
+ *
+ * <table border="1" cellspacing="0" cellpadding="6">
+ * <tr>
+ * <th>Name</th><th>Associated setter method</th>
+ * </tr><tr>
+ * <td>{@link ChannelOption#SO_TIMEOUT}</td><td>{@link #setSoTimeout(int)}</td>
+ * </tr>
+ * </table>
+ * + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated +public interface OioServerSocketChannelConfig extends ServerSocketChannelConfig { + + /** + * Sets the maximal time a operation on the underlying socket may block. + */ + OioServerSocketChannelConfig setSoTimeout(int timeout); + + /** + * Returns the maximal time a operation on the underlying socket may block. + */ + int getSoTimeout(); + + @Override + OioServerSocketChannelConfig setBacklog(int backlog); + + @Override + OioServerSocketChannelConfig setReuseAddress(boolean reuseAddress); + + @Override + OioServerSocketChannelConfig setReceiveBufferSize(int receiveBufferSize); + + @Override + OioServerSocketChannelConfig setPerformancePreferences(int connectionTime, int latency, int bandwidth); + + @Override + OioServerSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + @Deprecated + OioServerSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + OioServerSocketChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + OioServerSocketChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + OioServerSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + OioServerSocketChannelConfig setAutoRead(boolean autoRead); + + @Override + OioServerSocketChannelConfig setAutoClose(boolean autoClose); + + @Override + OioServerSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark); + + @Override + OioServerSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark); + + @Override + OioServerSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark writeBufferWaterMark); + + @Override + OioServerSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java 
b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java new file mode 100644 index 0000000..f96faec --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannel.java @@ -0,0 +1,352 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelPromise; +import io.netty.channel.ConnectTimeoutException; +import io.netty.channel.EventLoop; +import io.netty.channel.oio.OioByteStreamChannel; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketTimeoutException; + +/** + * A {@link SocketChannel} which is using Old-Blocking-IO + * + * @deprecated use NIO / EPOLL / KQUEUE transport. 
+ */ +@Deprecated +public class OioSocketChannel extends OioByteStreamChannel implements SocketChannel { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(OioSocketChannel.class); + + private final Socket socket; + private final OioSocketChannelConfig config; + + /** + * Create a new instance with an new {@link Socket} + */ + public OioSocketChannel() { + this(new Socket()); + } + + /** + * Create a new instance from the given {@link Socket} + * + * @param socket the {@link Socket} which is used by this instance + */ + public OioSocketChannel(Socket socket) { + this(null, socket); + } + + /** + * Create a new instance from the given {@link Socket} + * + * @param parent the parent {@link Channel} which was used to create this instance. This can be null if the + * channel has no parent as it was created by your self. + * @param socket the {@link Socket} which is used by this instance + */ + public OioSocketChannel(Channel parent, Socket socket) { + super(parent); + this.socket = socket; + config = new DefaultOioSocketChannelConfig(this, socket); + + boolean success = false; + try { + if (socket.isConnected()) { + activate(socket.getInputStream(), socket.getOutputStream()); + } + socket.setSoTimeout(SO_TIMEOUT); + success = true; + } catch (Exception e) { + throw new ChannelException("failed to initialize a socket", e); + } finally { + if (!success) { + try { + socket.close(); + } catch (IOException e) { + logger.warn("Failed to close a socket.", e); + } + } + } + } + + @Override + public ServerSocketChannel parent() { + return (ServerSocketChannel) super.parent(); + } + + @Override + public OioSocketChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return !socket.isClosed(); + } + + @Override + public boolean isActive() { + return !socket.isClosed() && socket.isConnected(); + } + + @Override + public boolean isOutputShutdown() { + return socket.isOutputShutdown() || !isActive(); + } + + 
@Override + public boolean isInputShutdown() { + return socket.isInputShutdown() || !isActive(); + } + + @Override + public boolean isShutdown() { + return socket.isInputShutdown() && socket.isOutputShutdown() || !isActive(); + } + + @UnstableApi + @Override + protected final void doShutdownOutput() throws Exception { + shutdownOutput0(); + } + + @Override + public ChannelFuture shutdownOutput() { + return shutdownOutput(newPromise()); + } + + @Override + public ChannelFuture shutdownInput() { + return shutdownInput(newPromise()); + } + + @Override + public ChannelFuture shutdown() { + return shutdown(newPromise()); + } + + @Override + protected int doReadBytes(ByteBuf buf) throws Exception { + if (socket.isClosed()) { + return -1; + } + try { + return super.doReadBytes(buf); + } catch (SocketTimeoutException ignored) { + return 0; + } + } + + @Override + public ChannelFuture shutdownOutput(final ChannelPromise promise) { + EventLoop loop = eventLoop(); + if (loop.inEventLoop()) { + shutdownOutput0(promise); + } else { + loop.execute(new Runnable() { + @Override + public void run() { + shutdownOutput0(promise); + } + }); + } + return promise; + } + + private void shutdownOutput0(ChannelPromise promise) { + try { + shutdownOutput0(); + promise.setSuccess(); + } catch (Throwable t) { + promise.setFailure(t); + } + } + + private void shutdownOutput0() throws IOException { + socket.shutdownOutput(); + } + + @Override + public ChannelFuture shutdownInput(final ChannelPromise promise) { + EventLoop loop = eventLoop(); + if (loop.inEventLoop()) { + shutdownInput0(promise); + } else { + loop.execute(new Runnable() { + @Override + public void run() { + shutdownInput0(promise); + } + }); + } + return promise; + } + + private void shutdownInput0(ChannelPromise promise) { + try { + socket.shutdownInput(); + promise.setSuccess(); + } catch (Throwable t) { + promise.setFailure(t); + } + } + + @Override + public ChannelFuture shutdown(final ChannelPromise promise) { + 
ChannelFuture shutdownOutputFuture = shutdownOutput(); + if (shutdownOutputFuture.isDone()) { + shutdownOutputDone(shutdownOutputFuture, promise); + } else { + shutdownOutputFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(final ChannelFuture shutdownOutputFuture) throws Exception { + shutdownOutputDone(shutdownOutputFuture, promise); + } + }); + } + return promise; + } + + private void shutdownOutputDone(final ChannelFuture shutdownOutputFuture, final ChannelPromise promise) { + ChannelFuture shutdownInputFuture = shutdownInput(); + if (shutdownInputFuture.isDone()) { + shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise); + } else { + shutdownInputFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture shutdownInputFuture) throws Exception { + shutdownDone(shutdownOutputFuture, shutdownInputFuture, promise); + } + }); + } + } + + private static void shutdownDone(ChannelFuture shutdownOutputFuture, + ChannelFuture shutdownInputFuture, + ChannelPromise promise) { + Throwable shutdownOutputCause = shutdownOutputFuture.cause(); + Throwable shutdownInputCause = shutdownInputFuture.cause(); + if (shutdownOutputCause != null) { + if (shutdownInputCause != null) { + logger.debug("Exception suppressed because a previous exception occurred.", + shutdownInputCause); + } + promise.setFailure(shutdownOutputCause); + } else if (shutdownInputCause != null) { + promise.setFailure(shutdownInputCause); + } else { + promise.setSuccess(); + } + } + + @Override + public InetSocketAddress localAddress() { + return (InetSocketAddress) super.localAddress(); + } + + @Override + public InetSocketAddress remoteAddress() { + return (InetSocketAddress) super.remoteAddress(); + } + + @Override + protected SocketAddress localAddress0() { + return socket.getLocalSocketAddress(); + } + + @Override + protected SocketAddress remoteAddress0() { + return socket.getRemoteSocketAddress(); + } + + 
@Override + protected void doBind(SocketAddress localAddress) throws Exception { + SocketUtils.bind(socket, localAddress); + } + + @Override + protected void doConnect(SocketAddress remoteAddress, + SocketAddress localAddress) throws Exception { + if (localAddress != null) { + SocketUtils.bind(socket, localAddress); + } + + final int connectTimeoutMillis = config().getConnectTimeoutMillis(); + boolean success = false; + try { + SocketUtils.connect(socket, remoteAddress, connectTimeoutMillis); + activate(socket.getInputStream(), socket.getOutputStream()); + success = true; + } catch (SocketTimeoutException e) { + ConnectTimeoutException cause = new ConnectTimeoutException("connection timed out after " + + connectTimeoutMillis + " ms: " + remoteAddress); + cause.setStackTrace(e.getStackTrace()); + throw cause; + } finally { + if (!success) { + doClose(); + } + } + } + + @Override + protected void doDisconnect() throws Exception { + doClose(); + } + + @Override + protected void doClose() throws Exception { + socket.close(); + } + + protected boolean checkInputShutdown() { + if (isInputShutdown()) { + try { + Thread.sleep(config().getSoTimeout()); + } catch (Throwable e) { + // ignore + } + return true; + } + return false; + } + + @Deprecated + @Override + protected void setReadPending(boolean readPending) { + super.setReadPending(readPending); + } + + final void clearReadPending0() { + clearReadPending(); + } +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java new file mode 100644 index 0000000..781fa0f --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/OioSocketChannelConfig.java @@ -0,0 +1,118 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.socket.oio; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelOption; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.SocketChannelConfig; + +/** + * A {@link ChannelConfig} for a {@link OioSocketChannel}. + * + *
+ * <h3>Available options</h3>
+ * + * In addition to the options provided by {@link SocketChannelConfig}, + * {@link OioSocketChannelConfig} allows the following options in the + * option map: + * + * + * + * + * + * + * + *
+ * <table border="1" cellspacing="0" cellpadding="6">
+ * <caption>Available options</caption>
+ * <tr>
+ * <th>Name</th><th>Associated setter method</th>
+ * </tr><tr>
+ * <td>{@link ChannelOption#SO_TIMEOUT}</td><td>{@link #setSoTimeout(int)}</td>
+ * </tr>
+ * </table>
+ * + * @deprecated use NIO / EPOLL / KQUEUE transport. + */ +@Deprecated +public interface OioSocketChannelConfig extends SocketChannelConfig { + + /** + * Sets the maximal time a operation on the underlying socket may block. + */ + OioSocketChannelConfig setSoTimeout(int timeout); + + /** + * Returns the maximal time a operation on the underlying socket may block. + */ + int getSoTimeout(); + + @Override + OioSocketChannelConfig setTcpNoDelay(boolean tcpNoDelay); + + @Override + OioSocketChannelConfig setSoLinger(int soLinger); + + @Override + OioSocketChannelConfig setSendBufferSize(int sendBufferSize); + + @Override + OioSocketChannelConfig setReceiveBufferSize(int receiveBufferSize); + + @Override + OioSocketChannelConfig setKeepAlive(boolean keepAlive); + + @Override + OioSocketChannelConfig setTrafficClass(int trafficClass); + + @Override + OioSocketChannelConfig setReuseAddress(boolean reuseAddress); + + @Override + OioSocketChannelConfig setPerformancePreferences(int connectionTime, int latency, int bandwidth); + + @Override + OioSocketChannelConfig setAllowHalfClosure(boolean allowHalfClosure); + + @Override + OioSocketChannelConfig setConnectTimeoutMillis(int connectTimeoutMillis); + + @Override + @Deprecated + OioSocketChannelConfig setMaxMessagesPerRead(int maxMessagesPerRead); + + @Override + OioSocketChannelConfig setWriteSpinCount(int writeSpinCount); + + @Override + OioSocketChannelConfig setAllocator(ByteBufAllocator allocator); + + @Override + OioSocketChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator); + + @Override + OioSocketChannelConfig setAutoRead(boolean autoRead); + + @Override + OioSocketChannelConfig setAutoClose(boolean autoClose); + + @Override + OioSocketChannelConfig setWriteBufferHighWaterMark(int writeBufferHighWaterMark); + + @Override + OioSocketChannelConfig setWriteBufferLowWaterMark(int writeBufferLowWaterMark); + + @Override + OioSocketChannelConfig setWriteBufferWaterMark(WriteBufferWaterMark 
writeBufferWaterMark); + + @Override + OioSocketChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator); +} diff --git a/netty-channel/src/main/java/io/netty/channel/socket/oio/package-info.java b/netty-channel/src/main/java/io/netty/channel/socket/oio/package-info.java new file mode 100644 index 0000000..122e484 --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/oio/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Old blocking I/O based socket channel API implementation - recommended for + * a small number of connections (< 1000). + */ +package io.netty.channel.socket.oio; diff --git a/netty-channel/src/main/java/io/netty/channel/socket/package-info.java b/netty-channel/src/main/java/io/netty/channel/socket/package-info.java new file mode 100644 index 0000000..4d9bcdb --- /dev/null +++ b/netty-channel/src/main/java/io/netty/channel/socket/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Abstract TCP and UDP socket interfaces which extend the core channel API. + */ +package io.netty.channel.socket; diff --git a/netty-channel/src/main/java/module-info.java b/netty-channel/src/main/java/module-info.java new file mode 100644 index 0000000..2a47566 --- /dev/null +++ b/netty-channel/src/main/java/module-info.java @@ -0,0 +1,17 @@ +import io.netty.bootstrap.ChannelInitializerExtension; + +module org.xbib.io.netty.channel { + exports io.netty.bootstrap; + exports io.netty.channel; + exports io.netty.channel.embedded; + exports io.netty.channel.group; + exports io.netty.channel.local; + exports io.netty.channel.nio; + exports io.netty.channel.oio; + exports io.netty.channel.pool; + exports io.netty.channel.socket; + uses ChannelInitializerExtension; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.resolver; + requires org.xbib.io.netty.util; +} diff --git a/netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/generated/handlers/reflect-config.json b/netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/generated/handlers/reflect-config.json new file mode 100644 index 0000000..3fa65eb --- /dev/null +++ b/netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/generated/handlers/reflect-config.json @@ -0,0 +1,121 @@ +[ + { + "name": "io.netty.bootstrap.ServerBootstrap$1", + "condition": { + "typeReachable": "io.netty.bootstrap.ServerBootstrap$1" + }, + "queryAllPublicMethods": true + }, + { + "name": 
"io.netty.bootstrap.ServerBootstrap$ServerBootstrapAcceptor", + "condition": { + "typeReachable": "io.netty.bootstrap.ServerBootstrap$ServerBootstrapAcceptor" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelDuplexHandler", + "condition": { + "typeReachable": "io.netty.channel.ChannelDuplexHandler" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelHandler", + "condition": { + "typeReachable": "io.netty.channel.ChannelHandler" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelHandlerAdapter", + "condition": { + "typeReachable": "io.netty.channel.ChannelHandlerAdapter" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelInboundHandler", + "condition": { + "typeReachable": "io.netty.channel.ChannelInboundHandler" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelInboundHandlerAdapter", + "condition": { + "typeReachable": "io.netty.channel.ChannelInboundHandlerAdapter" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelInitializer", + "condition": { + "typeReachable": "io.netty.channel.ChannelInitializer" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelOutboundHandler", + "condition": { + "typeReachable": "io.netty.channel.ChannelOutboundHandler" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.ChannelOutboundHandlerAdapter", + "condition": { + "typeReachable": "io.netty.channel.ChannelOutboundHandlerAdapter" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.CombinedChannelDuplexHandler", + "condition": { + "typeReachable": "io.netty.channel.CombinedChannelDuplexHandler" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.DefaultChannelPipeline$HeadContext", + "condition": { + "typeReachable": "io.netty.channel.DefaultChannelPipeline$HeadContext" + }, + "queryAllPublicMethods": true + 
}, + { + "name": "io.netty.channel.DefaultChannelPipeline$TailContext", + "condition": { + "typeReachable": "io.netty.channel.DefaultChannelPipeline$TailContext" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.embedded.EmbeddedChannel$2", + "condition": { + "typeReachable": "io.netty.channel.embedded.EmbeddedChannel$2" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.pool.SimpleChannelPool$1", + "condition": { + "typeReachable": "io.netty.channel.pool.SimpleChannelPool$1" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.SimpleChannelInboundHandler", + "condition": { + "typeReachable": "io.netty.channel.SimpleChannelInboundHandler" + }, + "queryAllPublicMethods": true + }, + { + "name": "io.netty.channel.SimpleUserEventChannelHandler", + "condition": { + "typeReachable": "io.netty.channel.SimpleUserEventChannelHandler" + }, + "queryAllPublicMethods": true + } +] \ No newline at end of file diff --git a/netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/reflect-config.json b/netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/reflect-config.json new file mode 100644 index 0000000..c92cdb4 --- /dev/null +++ b/netty-channel/src/main/resources/META-INF/native-image/io.netty/netty-transport/reflect-config.json @@ -0,0 +1,33 @@ +[ + { + "name": "io.netty.channel.socket.nio.NioServerSocketChannel", + "methods": [ + { "name": "", "parameterTypes": [] } + ] + }, + { + "name": "sun.nio.ch.SelectorImpl", + "fields": [ + { "name": "selectedKeys", "allowUnsafeAccess" : true}, + { "name": "publicSelectedKeys", "allowUnsafeAccess" : true} + ] + }, + { + "name": "java.lang.management.ManagementFactory", + "methods": [ + { + "name": "getRuntimeMXBean", + "parameterTypes": [] + } + ] + }, + { + "name": "java.lang.management.RuntimeMXBean", + "methods": [ + { + "name": "getName", + "parameterTypes": [] + } + ] + } +] diff --git 
a/netty-channel/src/test/java/io/netty/bootstrap/BootstrapTest.java b/netty-channel/src/test/java/io/netty/bootstrap/BootstrapTest.java new file mode 100644 index 0000000..6ae3df6 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/bootstrap/BootstrapTest.java @@ -0,0 +1,584 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFactory; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.resolver.AbstractAddressResolver; +import io.netty.resolver.AddressResolver; +import io.netty.resolver.AddressResolverGroup; +import 
io.netty.resolver.DefaultAddressResolverGroup; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.net.ConnectException; +import java.net.SocketAddress; +import java.net.SocketException; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.nullValue; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class BootstrapTest { + + private static final EventLoopGroup groupA = new DefaultEventLoopGroup(1); + private static final EventLoopGroup groupB = new DefaultEventLoopGroup(1); + private static final ChannelInboundHandler dummyHandler = new DummyHandler(); + + @AfterAll + public static void destroy() { + groupA.shutdownGracefully(); + groupB.shutdownGracefully(); + groupA.terminationFuture().syncUninterruptibly(); + 
groupB.terminationFuture().syncUninterruptibly(); + } + + @Test + public void testOptionsCopied() { + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.option(ChannelOption.AUTO_READ, true); + Map.Entry, Object>[] channelOptions = bootstrapA.newOptionsArray(); + bootstrapA.option(ChannelOption.AUTO_READ, false); + assertEquals(ChannelOption.AUTO_READ, channelOptions[0].getKey()); + assertEquals(true, channelOptions[0].getValue()); + } + + @Test + public void testAttributesCopied() { + AttributeKey key = AttributeKey.valueOf(UUID.randomUUID().toString()); + String value = "value"; + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.attr(key, value); + Map.Entry, Object>[] attributesArray = bootstrapA.newAttributesArray(); + bootstrapA.attr(key, "value2"); + assertEquals(key, attributesArray[0].getKey()); + assertEquals(value, attributesArray[0].getValue()); + } + + @Test + public void optionsAndAttributesMustBeAvailableOnChannelInit() throws InterruptedException { + final AttributeKey key = AttributeKey.valueOf(UUID.randomUUID().toString()); + new Bootstrap() + .group(groupA) + .channel(LocalChannel.class) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4242) + .attr(key, "value") + .handler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) throws Exception { + Integer option = ch.config().getOption(ChannelOption.CONNECT_TIMEOUT_MILLIS); + assertEquals(4242, (int) option); + assertEquals("value", ch.attr(key).get()); + } + }) + .bind(LocalAddress.ANY).sync(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testBindDeadLock() throws Exception { + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.group(groupA); + bootstrapA.channel(LocalChannel.class); + bootstrapA.handler(dummyHandler); + + final Bootstrap bootstrapB = new Bootstrap(); + bootstrapB.group(groupB); + bootstrapB.channel(LocalChannel.class); + bootstrapB.handler(dummyHandler); + + List> bindFutures = new 
ArrayList>(); + + // Try to bind from each other. + for (int i = 0; i < 1024; i ++) { + bindFutures.add(groupA.next().submit(new Runnable() { + @Override + public void run() { + bootstrapB.bind(LocalAddress.ANY); + } + })); + + bindFutures.add(groupB.next().submit(new Runnable() { + @Override + public void run() { + bootstrapA.bind(LocalAddress.ANY); + } + })); + } + + for (Future f: bindFutures) { + f.sync(); + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testConnectDeadLock() throws Exception { + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.group(groupA); + bootstrapA.channel(LocalChannel.class); + bootstrapA.handler(dummyHandler); + + final Bootstrap bootstrapB = new Bootstrap(); + bootstrapB.group(groupB); + bootstrapB.channel(LocalChannel.class); + bootstrapB.handler(dummyHandler); + + List> bindFutures = new ArrayList>(); + + // Try to connect from each other. + for (int i = 0; i < 1024; i ++) { + bindFutures.add(groupA.next().submit(new Runnable() { + @Override + public void run() { + bootstrapB.connect(LocalAddress.ANY); + } + })); + + bindFutures.add(groupB.next().submit(new Runnable() { + @Override + public void run() { + bootstrapA.connect(LocalAddress.ANY); + } + })); + } + + for (Future f: bindFutures) { + f.sync(); + } + } + + @Test + public void testLateRegisterSuccess() throws Exception { + TestEventLoopGroup group = new TestEventLoopGroup(); + try { + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(group); + bootstrap.channel(LocalServerChannel.class); + bootstrap.childHandler(new DummyHandler()); + bootstrap.localAddress(new LocalAddress("1")); + ChannelFuture future = bootstrap.bind(); + assertFalse(future.isDone()); + group.promise.setSuccess(); + final BlockingQueue queue = new LinkedBlockingQueue(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + 
queue.add(future.channel().eventLoop().inEventLoop(Thread.currentThread())); + queue.add(future.isSuccess()); + } + }); + assertTrue(queue.take()); + assertTrue(queue.take()); + } finally { + group.shutdownGracefully(); + group.terminationFuture().sync(); + } + } + + @Test + public void testLateRegisterSuccessBindFailed() throws Exception { + TestEventLoopGroup group = new TestEventLoopGroup(); + try { + ServerBootstrap bootstrap = new ServerBootstrap(); + bootstrap.group(group); + bootstrap.channelFactory(new ChannelFactory() { + @Override + public ServerChannel newChannel() { + return new LocalServerChannel() { + @Override + public ChannelFuture bind(SocketAddress localAddress) { + // Close the Channel to emulate what NIO and others impl do on bind failure + // See https://github.com/netty/netty/issues/2586 + close(); + return newFailedFuture(new SocketException()); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + // Close the Channel to emulate what NIO and others impl do on bind failure + // See https://github.com/netty/netty/issues/2586 + close(); + return promise.setFailure(new SocketException()); + } + }; + } + }); + bootstrap.childHandler(new DummyHandler()); + bootstrap.localAddress(new LocalAddress("1")); + ChannelFuture future = bootstrap.bind(); + assertFalse(future.isDone()); + group.promise.setSuccess(); + final BlockingQueue queue = new LinkedBlockingQueue(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + queue.add(future.channel().eventLoop().inEventLoop(Thread.currentThread())); + queue.add(future.isSuccess()); + } + }); + assertTrue(queue.take()); + assertFalse(queue.take()); + } finally { + group.shutdownGracefully(); + group.terminationFuture().sync(); + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testLateRegistrationConnect() throws Exception { + EventLoopGroup 
group = new DelayedEventLoopGroup(); + try { + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.group(group); + bootstrapA.channel(LocalChannel.class); + bootstrapA.handler(dummyHandler); + assertThrows(ConnectException.class, new Executable() { + @Override + public void execute() { + bootstrapA.connect(LocalAddress.ANY).syncUninterruptibly(); + } + }); + } finally { + group.shutdownGracefully(); + } + } + + @Test + void testResolverDefault() throws Exception { + Bootstrap bootstrap = new Bootstrap(); + + assertTrue(bootstrap.config().toString().contains("resolver:")); + assertNotNull(bootstrap.config().resolver()); + assertEquals(DefaultAddressResolverGroup.class, bootstrap.config().resolver().getClass()); + } + + @Test + void testResolverDisabled() throws Exception { + Bootstrap bootstrap = new Bootstrap(); + bootstrap.disableResolver(); + + assertFalse(bootstrap.config().toString().contains("resolver:")); + assertNull(bootstrap.config().resolver()); + } + + @Test + public void testAsyncResolutionSuccess() throws Exception { + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.group(groupA); + bootstrapA.channel(LocalChannel.class); + bootstrapA.resolver(new TestAddressResolverGroup(true)); + bootstrapA.handler(dummyHandler); + + final ServerBootstrap bootstrapB = new ServerBootstrap(); + bootstrapB.group(groupB); + bootstrapB.channel(LocalServerChannel.class); + bootstrapB.childHandler(dummyHandler); + + assertTrue(bootstrapA.config().toString().contains("resolver:")); + assertThat(bootstrapA.resolver(), is(instanceOf(TestAddressResolverGroup.class))); + + SocketAddress localAddress = bootstrapB.bind(LocalAddress.ANY).sync().channel().localAddress(); + + // Connect to the server using the asynchronous resolver. 
+ bootstrapA.connect(localAddress).sync(); + } + + @Test + public void testAsyncResolutionFailure() throws Exception { + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.group(groupA); + bootstrapA.channel(LocalChannel.class); + bootstrapA.resolver(new TestAddressResolverGroup(false)); + bootstrapA.handler(dummyHandler); + + final ServerBootstrap bootstrapB = new ServerBootstrap(); + bootstrapB.group(groupB); + bootstrapB.channel(LocalServerChannel.class); + bootstrapB.childHandler(dummyHandler); + SocketAddress localAddress = bootstrapB.bind(LocalAddress.ANY).sync().channel().localAddress(); + + // Connect to the server using the asynchronous resolver. + ChannelFuture connectFuture = bootstrapA.connect(localAddress); + + // Should fail with the UnknownHostException. + assertThat(connectFuture.await(10000), is(true)); + assertThat(connectFuture.cause(), is(instanceOf(UnknownHostException.class))); + assertThat(connectFuture.channel().isOpen(), is(false)); + } + + @Test + public void testGetResolverFailed() throws Exception { + class TestException extends RuntimeException { } + + final Bootstrap bootstrapA = new Bootstrap(); + bootstrapA.group(groupA); + bootstrapA.channel(LocalChannel.class); + + bootstrapA.resolver(new AddressResolverGroup() { + @Override + protected AddressResolver newResolver(EventExecutor executor) { + throw new TestException(); + } + }); + bootstrapA.handler(dummyHandler); + + final ServerBootstrap bootstrapB = new ServerBootstrap(); + bootstrapB.group(groupB); + bootstrapB.channel(LocalServerChannel.class); + bootstrapB.childHandler(dummyHandler); + SocketAddress localAddress = bootstrapB.bind(LocalAddress.ANY).sync().channel().localAddress(); + + // Connect to the server using the asynchronous resolver. + ChannelFuture connectFuture = bootstrapA.connect(localAddress); + + // Should fail with the IllegalStateException. 
+ assertThat(connectFuture.await(10000), is(true)); + assertThat(connectFuture.cause(), instanceOf(IllegalStateException.class)); + assertThat(connectFuture.cause().getCause(), instanceOf(TestException.class)); + assertThat(connectFuture.channel().isOpen(), is(false)); + } + + @Test + public void testChannelFactoryFailureNotifiesPromise() throws Exception { + final RuntimeException exception = new RuntimeException("newChannel crash"); + + final Bootstrap bootstrap = new Bootstrap() + .handler(dummyHandler) + .group(groupA) + .channelFactory(new ChannelFactory() { + @Override + public Channel newChannel() { + throw exception; + } + }); + + ChannelFuture connectFuture = bootstrap.connect(LocalAddress.ANY); + + // Should fail with the RuntimeException. + assertThat(connectFuture.await(10000), is(true)); + assertThat(connectFuture.cause(), sameInstance((Throwable) exception)); + assertThat(connectFuture.channel(), is(not(nullValue()))); + } + + @Test + public void testChannelOptionOrderPreserve() throws InterruptedException { + final BlockingQueue> options = new LinkedBlockingQueue>(); + class ChannelConfigValidator extends DefaultChannelConfig { + ChannelConfigValidator(Channel channel) { + super(channel); + } + + @Override + public boolean setOption(ChannelOption option, T value) { + options.add(option); + return super.setOption(option, value); + } + } + final CountDownLatch latch = new CountDownLatch(1); + final Bootstrap bootstrap = new Bootstrap() + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + latch.countDown(); + } + }) + .group(groupA) + .channelFactory(new ChannelFactory() { + @Override + public Channel newChannel() { + return new LocalChannel() { + private ChannelConfigValidator config; + @Override + public synchronized ChannelConfig config() { + if (config == null) { + config = new ChannelConfigValidator(this); + } + return config; + } + }; + } + }) + .option(ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, 1) + 
.option(ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, 2); + + bootstrap.register().syncUninterruptibly(); + + latch.await(); + + // Check the order is the same as what we defined before. + assertSame(ChannelOption.WRITE_BUFFER_LOW_WATER_MARK, options.take()); + assertSame(ChannelOption.WRITE_BUFFER_HIGH_WATER_MARK, options.take()); + } + + @Test + void mustCallInitializerExtensions() throws Exception { + final Bootstrap cb = new Bootstrap(); + cb.group(groupA); + cb.handler(dummyHandler); + cb.channel(LocalChannel.class); + + StubChannelInitializerExtension.clearThreadLocals(); + + ChannelFuture future = cb.register(); + future.sync(); + final Channel expectedChannel = future.channel(); + + assertSame(expectedChannel, StubChannelInitializerExtension.lastSeenClientChannel.get()); + assertNull(StubChannelInitializerExtension.lastSeenChildChannel.get()); + assertNull(StubChannelInitializerExtension.lastSeenListenerChannel.get()); + expectedChannel.close().sync(); + } + + private static final class DelayedEventLoopGroup extends DefaultEventLoop { + @Override + public ChannelFuture register(final Channel channel, final ChannelPromise promise) { + // Delay registration + execute(new Runnable() { + @Override + public void run() { + DelayedEventLoopGroup.super.register(channel, promise); + } + }); + return promise; + } + } + + private static final class TestEventLoopGroup extends DefaultEventLoopGroup { + + ChannelPromise promise; + + TestEventLoopGroup() { + super(1); + } + + @Override + public ChannelFuture register(Channel channel) { + super.register(channel).syncUninterruptibly(); + promise = channel.newPromise(); + return promise; + } + + @Override + public ChannelFuture register(ChannelPromise promise) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture register(Channel channel, final ChannelPromise promise) { + throw new UnsupportedOperationException(); + } + } + + @Sharable + private static final class DummyHandler extends 
ChannelInboundHandlerAdapter { } + + private static final class TestAddressResolverGroup extends AddressResolverGroup { + + private final boolean success; + + TestAddressResolverGroup(boolean success) { + this.success = success; + } + + @Override + protected AddressResolver newResolver(EventExecutor executor) throws Exception { + return new AbstractAddressResolver(executor) { + + @Override + protected boolean doIsResolved(SocketAddress address) { + return false; + } + + @Override + protected void doResolve( + final SocketAddress unresolvedAddress, final Promise promise) { + executor().execute(new Runnable() { + @Override + public void run() { + if (success) { + promise.setSuccess(unresolvedAddress); + } else { + promise.setFailure(new UnknownHostException(unresolvedAddress.toString())); + } + } + }); + } + + @Override + protected void doResolveAll( + final SocketAddress unresolvedAddress, final Promise> promise) + throws Exception { + executor().execute(new Runnable() { + @Override + public void run() { + if (success) { + promise.setSuccess(Collections.singletonList(unresolvedAddress)); + } else { + promise.setFailure(new UnknownHostException(unresolvedAddress.toString())); + } + } + }); + } + }; + } + } +} diff --git a/netty-channel/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java b/netty-channel/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java new file mode 100644 index 0000000..7b24292 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/bootstrap/ServerBootstrapTest.java @@ -0,0 +1,244 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalEventLoopGroup; +import io.netty.channel.local.LocalServerChannel; +import io.netty.util.AttributeKey; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.UUID; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ServerBootstrapTest { + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testHandlerRegister() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference error = new AtomicReference(); + LocalEventLoopGroup group = 
new LocalEventLoopGroup(1); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.channel(LocalServerChannel.class) + .group(group) + .childHandler(new ChannelInboundHandlerAdapter()) + .handler(new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + try { + assertTrue(ctx.executor().inEventLoop()); + } catch (Throwable cause) { + error.set(cause); + } finally { + latch.countDown(); + } + } + }); + sb.register().syncUninterruptibly(); + latch.await(); + assertNull(error.get()); + } finally { + group.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testParentHandler() throws Exception { + testParentHandler(false); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testParentHandlerViaChannelInitializer() throws Exception { + testParentHandler(true); + } + + private static void testParentHandler(boolean channelInitializer) throws Exception { + final LocalAddress addr = new LocalAddress(UUID.randomUUID().toString()); + final CountDownLatch readLatch = new CountDownLatch(1); + final CountDownLatch initLatch = new CountDownLatch(1); + + final ChannelHandler handler = new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + initLatch.countDown(); + super.handlerAdded(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + readLatch.countDown(); + super.channelRead(ctx, msg); + } + }; + + EventLoopGroup group = new DefaultEventLoopGroup(1); + Channel sch = null; + Channel cch = null; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.channel(LocalServerChannel.class) + .group(group) + .childHandler(new ChannelInboundHandlerAdapter()); + if (channelInitializer) { + sb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + 
ch.pipeline().addLast(handler); + } + }); + } else { + sb.handler(handler); + } + + Bootstrap cb = new Bootstrap(); + cb.group(group) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter()); + + sch = sb.bind(addr).syncUninterruptibly().channel(); + + cch = cb.connect(addr).syncUninterruptibly().channel(); + + initLatch.await(); + readLatch.await(); + } finally { + if (sch != null) { + sch.close().syncUninterruptibly(); + } + if (cch != null) { + cch.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + } + } + + @Test + public void optionsAndAttributesMustBeAvailableOnChildChannelInit() throws InterruptedException { + EventLoopGroup group = new DefaultEventLoopGroup(1); + LocalAddress addr = new LocalAddress(UUID.randomUUID().toString()); + final AttributeKey key = AttributeKey.valueOf(UUID.randomUUID().toString()); + final AtomicBoolean requestServed = new AtomicBoolean(); + ServerBootstrap sb = new ServerBootstrap() + .group(group) + .channel(LocalServerChannel.class) + .childOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, 4242) + .childAttr(key, "value") + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) throws Exception { + Integer option = ch.config().getOption(ChannelOption.CONNECT_TIMEOUT_MILLIS); + assertEquals(4242, (int) option); + assertEquals("value", ch.attr(key).get()); + requestServed.set(true); + } + }); + Channel serverChannel = sb.bind(addr).syncUninterruptibly().channel(); + + Bootstrap cb = new Bootstrap(); + cb.group(group) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter()); + Channel clientChannel = cb.connect(addr).syncUninterruptibly().channel(); + serverChannel.close().syncUninterruptibly(); + clientChannel.close().syncUninterruptibly(); + group.shutdownGracefully(); + assertTrue(requestServed.get()); + } + + @Test + void mustCallInitializerExtensions() throws Exception { + LocalAddress addr = new 
LocalAddress(ServerBootstrapTest.class); + final AtomicReference expectedServerChannel = new AtomicReference(); + final AtomicReference expectedChildChannel = new AtomicReference(); + LocalEventLoopGroup group = new LocalEventLoopGroup(1); + final ServerBootstrap sb = new ServerBootstrap(); + sb.group(group); + sb.channel(LocalServerChannel.class); + sb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + expectedServerChannel.set(ch); + } + }); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + expectedChildChannel.set(ch); + } + }); + + StubChannelInitializerExtension.clearThreadLocals(); + group.submit(new Runnable() { + @Override + public void run() { + StubChannelInitializerExtension.clearThreadLocals(); + } + }).sync(); + + Channel serverChannel = sb.bind(addr).syncUninterruptibly().channel(); + + assertNull(StubChannelInitializerExtension.lastSeenClientChannel.get()); + assertNull(StubChannelInitializerExtension.lastSeenChildChannel.get()); + assertSame(expectedServerChannel.get(), StubChannelInitializerExtension.lastSeenListenerChannel.get()); + assertSame(serverChannel, StubChannelInitializerExtension.lastSeenListenerChannel.get()); + + Bootstrap cb = new Bootstrap(); + cb.group(group) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter()); + Channel clientChannel = cb.connect(addr).syncUninterruptibly().channel(); + + assertSame(clientChannel, StubChannelInitializerExtension.lastSeenClientChannel.get()); + group.submit(new Callable() { + @Override + public Object call() throws Exception { + assertSame(expectedChildChannel.get(), StubChannelInitializerExtension.lastSeenChildChannel.get()); + return null; + } + }).sync(); + assertSame(expectedServerChannel.get(), StubChannelInitializerExtension.lastSeenListenerChannel.get()); + assertSame(serverChannel, 
StubChannelInitializerExtension.lastSeenListenerChannel.get()); + + serverChannel.close().syncUninterruptibly(); + clientChannel.close().syncUninterruptibly(); + group.shutdownGracefully(); + } +} diff --git a/netty-channel/src/test/java/io/netty/bootstrap/StubChannelInitializerExtension.java b/netty-channel/src/test/java/io/netty/bootstrap/StubChannelInitializerExtension.java new file mode 100644 index 0000000..946fe9e --- /dev/null +++ b/netty-channel/src/test/java/io/netty/bootstrap/StubChannelInitializerExtension.java @@ -0,0 +1,47 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.bootstrap; + +import io.netty.channel.Channel; +import io.netty.channel.ServerChannel; +import io.netty.util.concurrent.FastThreadLocal; + +public class StubChannelInitializerExtension extends ChannelInitializerExtension { + static final FastThreadLocal lastSeenClientChannel = new FastThreadLocal(); + static final FastThreadLocal lastSeenListenerChannel = new FastThreadLocal(); + static final FastThreadLocal lastSeenChildChannel = new FastThreadLocal(); + + public static void clearThreadLocals() { + lastSeenChildChannel.remove(); + lastSeenClientChannel.remove(); + lastSeenListenerChannel.remove(); + } + + @Override + public void postInitializeClientChannel(Channel channel) { + lastSeenClientChannel.set(channel); + } + + @Override + public void postInitializeServerListenerChannel(ServerChannel channel) { + lastSeenListenerChannel.set(channel); + } + + @Override + public void postInitializeServerChildChannel(Channel channel) { + lastSeenChildChannel.set(channel); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/AbstractChannelTest.java b/netty-channel/src/test/java/io/netty/channel/AbstractChannelTest.java new file mode 100644 index 0000000..c0097b7 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/AbstractChannelTest.java @@ -0,0 +1,251 @@ +/* + * Copyright 2014 The Netty Project + + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + + * https://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; + +import io.netty.util.NetUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +public class AbstractChannelTest { + + @Test + public void ensureInitialRegistrationFiresActive() throws Throwable { + EventLoop eventLoop = mock(EventLoop.class); + // This allows us to have a single-threaded test + when(eventLoop.inEventLoop()).thenReturn(true); + + TestChannel channel = new TestChannel(); + ChannelInboundHandler handler = mock(ChannelInboundHandler.class); + channel.pipeline().addLast(handler); + + registerChannel(eventLoop, channel); + + verify(handler).handlerAdded(any(ChannelHandlerContext.class)); + verify(handler).channelRegistered(any(ChannelHandlerContext.class)); + verify(handler).channelActive(any(ChannelHandlerContext.class)); + } + + @Test + public void ensureSubsequentRegistrationDoesNotFireActive() throws Throwable { + final EventLoop eventLoop = mock(EventLoop.class); + // This allows us to have a single-threaded test + when(eventLoop.inEventLoop()).thenReturn(true); + + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) { + ((Runnable) invocationOnMock.getArgument(0)).run(); + return null; + } + }).when(eventLoop).execute(any(Runnable.class)); + + final TestChannel channel = new TestChannel(); + ChannelInboundHandler handler = mock(ChannelInboundHandler.class); + + 
channel.pipeline().addLast(handler); + + registerChannel(eventLoop, channel); + channel.unsafe().deregister(new DefaultChannelPromise(channel)); + + registerChannel(eventLoop, channel); + + verify(handler).handlerAdded(any(ChannelHandlerContext.class)); + + // Should register twice + verify(handler, times(2)) .channelRegistered(any(ChannelHandlerContext.class)); + verify(handler).channelActive(any(ChannelHandlerContext.class)); + verify(handler).channelUnregistered(any(ChannelHandlerContext.class)); + } + + @Test + public void ensureDefaultChannelId() { + TestChannel channel = new TestChannel(); + final ChannelId channelId = channel.id(); + assertTrue(channelId instanceof DefaultChannelId); + } + + @Test + @EnabledForJreRange(min = JRE.JAVA_9) + void processIdWithProcessHandleJava9() { + ClassLoader loader = PlatformDependent.getClassLoader(DefaultChannelId.class); + int processHandlePid = DefaultChannelId.processHandlePid(loader); + assertTrue(processHandlePid != -1); + assertEquals(DefaultChannelId.jmxPid(loader), processHandlePid); + assertEquals(DefaultChannelId.defaultProcessId(), processHandlePid); + } + + @Test + @EnabledForJreRange(max = JRE.JAVA_8) + void processIdWithJmxPrejava9() { + ClassLoader loader = PlatformDependent.getClassLoader(DefaultChannelId.class); + int processHandlePid = DefaultChannelId.processHandlePid(loader); + assertEquals(-1, processHandlePid); + assertEquals(DefaultChannelId.defaultProcessId(), DefaultChannelId.jmxPid(loader)); + } + + @Test + public void testClosedChannelExceptionCarryIOException() throws Exception { + final IOException ioException = new IOException(); + final Channel channel = new TestChannel() { + private boolean open = true; + private boolean active; + + @Override + protected AbstractUnsafe newUnsafe() { + return new AbstractUnsafe() { + @Override + public void connect( + SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + active = true; + promise.setSuccess(); + } + }; + } + + 
@Override + protected void doClose() { + active = false; + open = false; + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + throw ioException; + } + + @Override + public boolean isOpen() { + return open; + } + + @Override + public boolean isActive() { + return active; + } + }; + + EventLoop loop = new DefaultEventLoop(); + try { + registerChannel(loop, channel); + channel.connect(new InetSocketAddress(NetUtil.LOCALHOST, 8888)).sync(); + assertSame(ioException, channel.writeAndFlush("").await().cause()); + + assertClosedChannelException(channel.writeAndFlush(""), ioException); + assertClosedChannelException(channel.write(""), ioException); + assertClosedChannelException(channel.bind(new InetSocketAddress(NetUtil.LOCALHOST, 8888)), ioException); + } finally { + channel.close(); + loop.shutdownGracefully(); + } + } + + private static void assertClosedChannelException(ChannelFuture future, IOException expected) + throws InterruptedException { + Throwable cause = future.await().cause(); + assertTrue(cause instanceof ClosedChannelException); + assertSame(expected, cause.getCause()); + } + + private static void registerChannel(EventLoop eventLoop, Channel channel) throws Exception { + DefaultChannelPromise future = new DefaultChannelPromise(channel); + channel.unsafe().register(eventLoop, future); + future.sync(); // Cause any exceptions to be thrown + } + + private static class TestChannel extends AbstractChannel { + private static final ChannelMetadata TEST_METADATA = new ChannelMetadata(false); + + private final ChannelConfig config = new DefaultChannelConfig(this); + + TestChannel() { + super(null); + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public boolean isActive() { + return true; + } + + @Override + public ChannelMetadata metadata() { + return TEST_METADATA; + } + + @Override + protected AbstractUnsafe newUnsafe() { + 
return new AbstractUnsafe() { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + promise.setFailure(new UnsupportedOperationException()); + } + }; + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return true; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { } + + @Override + protected void doDisconnect() { } + + @Override + protected void doClose() { } + + @Override + protected void doBeginRead() { } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java b/netty-channel/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java new file mode 100644 index 0000000..d79013a --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/AbstractCoalescingBufferQueueTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 The Netty Project + + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + + * https://www.apache.org/licenses/LICENSE-2.0 + + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import java.nio.channels.ClosedChannelException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AbstractCoalescingBufferQueueTest { + + // See https://github.com/netty/netty/issues/10286 + @Test + public void testDecrementAllWhenWriteAndRemoveAll() { + testDecrementAll(true); + } + + // See https://github.com/netty/netty/issues/10286 + @Test + public void testDecrementAllWhenReleaseAndFailAll() { + testDecrementAll(false); + } + + private static void testDecrementAll(boolean write) { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.release(msg); + promise.setSuccess(); + } + }, new ChannelHandlerAdapter() { }); + final AbstractCoalescingBufferQueue queue = new AbstractCoalescingBufferQueue(channel, 128) { + @Override + protected ByteBuf compose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) { + return composeIntoComposite(alloc, cumulation, next); + } + + @Override + protected ByteBuf removeEmptyValue() { + return Unpooled.EMPTY_BUFFER; + } + }; + + final byte[] bytes = new byte[128]; + queue.add(Unpooled.wrappedBuffer(bytes), new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + queue.add(Unpooled.wrappedBuffer(bytes)); + assertEquals(bytes.length, queue.readableBytes()); + } + }); + + assertEquals(bytes.length, queue.readableBytes()); + + ChannelHandlerContext ctx = channel.pipeline().lastContext(); + if (write) { + 
queue.writeAndRemoveAll(ctx); + } else { + queue.releaseAndFailAll(ctx, new ClosedChannelException()); + } + ByteBuf buffer = queue.remove(channel.alloc(), 128, channel.newPromise()); + assertFalse(buffer.isReadable()); + buffer.release(); + + assertTrue(queue.isEmpty()); + assertEquals(0, queue.readableBytes()); + + assertFalse(channel.finish()); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/AbstractEventLoopTest.java b/netty-channel/src/test/java/io/netty/channel/AbstractEventLoopTest.java new file mode 100644 index 0000000..65f5a9b --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/AbstractEventLoopTest.java @@ -0,0 +1,95 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.EventExecutorGroup; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.*; + +public abstract class AbstractEventLoopTest { + + /** + * Test for https://github.com/netty/netty/issues/803 + */ + @Test + public void testReregister() { + EventLoopGroup group = newEventLoopGroup(); + EventLoopGroup group2 = newEventLoopGroup(); + final EventExecutorGroup eventExecutorGroup = new DefaultEventExecutorGroup(2); + + ServerBootstrap bootstrap = new ServerBootstrap(); + ChannelFuture future = bootstrap.channel(newChannel()).group(group) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) { + } + }).handler(new ChannelInitializer() { + @Override + public void initChannel(ServerSocketChannel ch) { + ch.pipeline().addLast(new TestChannelHandler()); + ch.pipeline().addLast(eventExecutorGroup, new TestChannelHandler2()); + } + }) + .bind(0).awaitUninterruptibly(); + + EventExecutor executor = future.channel().pipeline().context(TestChannelHandler2.class).executor(); + EventExecutor executor1 = future.channel().pipeline().context(TestChannelHandler.class).executor(); + future.channel().deregister().awaitUninterruptibly(); + Channel channel = group2.register(future.channel()).awaitUninterruptibly().channel(); + EventExecutor executorNew = channel.pipeline().context(TestChannelHandler.class).executor(); + assertNotSame(executor1, executorNew); + assertSame(executor, future.channel().pipeline().context(TestChannelHandler2.class).executor()); + } + + @Test + @Timeout(value 
= 5000, unit = TimeUnit.MILLISECONDS) + public void testShutdownGracefullyNoQuietPeriod() throws Exception { + EventLoopGroup loop = newEventLoopGroup(); + ServerBootstrap b = new ServerBootstrap(); + b.group(loop) + .channel(newChannel()) + .childHandler(new ChannelInboundHandlerAdapter()); + + // Not close the Channel to ensure the EventLoop is still shutdown in time. + b.bind(0).sync().channel(); + + Future f = loop.shutdownGracefully(0, 1, TimeUnit.MINUTES); + assertTrue(loop.awaitTermination(600, TimeUnit.MILLISECONDS)); + assertTrue(f.syncUninterruptibly().isSuccess()); + assertTrue(loop.isShutdown()); + assertTrue(loop.isTerminated()); + } + + private static final class TestChannelHandler extends ChannelDuplexHandler { } + + private static final class TestChannelHandler2 extends ChannelDuplexHandler { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { } + } + + protected abstract EventLoopGroup newEventLoopGroup(); + protected abstract Class newChannel(); +} diff --git a/netty-channel/src/test/java/io/netty/channel/AdaptiveRecvByteBufAllocatorTest.java b/netty-channel/src/test/java/io/netty/channel/AdaptiveRecvByteBufAllocatorTest.java new file mode 100644 index 0000000..f2ec812 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/AdaptiveRecvByteBufAllocatorTest.java @@ -0,0 +1,118 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class AdaptiveRecvByteBufAllocatorTest { + @Mock + private ChannelConfig config; + private final ByteBufAllocator alloc = UnpooledByteBufAllocator.DEFAULT; + private RecvByteBufAllocator.ExtendedHandle handle; + + @BeforeEach + public void setup() { + config = mock(ChannelConfig.class); + when(config.isAutoRead()).thenReturn(true); + AdaptiveRecvByteBufAllocator recvByteBufAllocator = new AdaptiveRecvByteBufAllocator(64, 512, 1024 * 1024 * 10); + handle = (RecvByteBufAllocator.ExtendedHandle) recvByteBufAllocator.newHandle(); + handle.reset(config); + } + + @Test + public void rampUpBeforeReadCompleteWhenLargeDataPending() { + // Simulate that there is always more data when we attempt to read so we should always ramp up. 
+ allocReadExpected(handle, alloc, 512); + allocReadExpected(handle, alloc, 8192); + allocReadExpected(handle, alloc, 131072); + allocReadExpected(handle, alloc, 2097152); + handle.readComplete(); + + handle.reset(config); + allocReadExpected(handle, alloc, 8388608); + } + + @Test + public void memoryAllocationIntervalsTest() { + computingNext(512, 512); + computingNext(8192, 1110); + computingNext(8192, 1200); + computingNext(4096, 1300); + computingNext(4096, 1500); + computingNext(2048, 1700); + computingNext(2048, 1550); + computingNext(2048, 2000); + computingNext(2048, 1900); + } + + private void computingNext(long expectedSize, int actualReadBytes) { + assertEquals(expectedSize, handle.guess()); + handle.reset(config); + handle.lastBytesRead(actualReadBytes); + handle.readComplete(); + } + + @Test + public void lastPartialReadDoesNotRampDown() { + allocReadExpected(handle, alloc, 512); + // Simulate there is just 1 byte remaining which is unread. However the total bytes in the current read cycle + // means that we should stay at the current step for the next ready cycle. + allocRead(handle, alloc, 8192, 1); + handle.readComplete(); + + handle.reset(config); + allocReadExpected(handle, alloc, 8192); + } + + @Test + public void lastPartialReadCanRampUp() { + allocReadExpected(handle, alloc, 512); + // We simulate there is just 1 less byte than we try to read, but because of the adaptive steps the total amount + // of bytes read for this read cycle steps up to prepare for the next read cycle. 
+ allocRead(handle, alloc, 8192, 8191); + handle.readComplete(); + + handle.reset(config); + allocReadExpected(handle, alloc, 131072); + } + + private static void allocReadExpected(RecvByteBufAllocator.ExtendedHandle handle, + ByteBufAllocator alloc, + int expectedSize) { + allocRead(handle, alloc, expectedSize, expectedSize); + } + + private static void allocRead(RecvByteBufAllocator.ExtendedHandle handle, + ByteBufAllocator alloc, + int expectedBufferSize, + int lastRead) { + ByteBuf buf = handle.allocate(alloc); + assertEquals(expectedBufferSize, buf.capacity()); + handle.attemptedBytesRead(expectedBufferSize); + handle.lastBytesRead(lastRead); + handle.incMessagesRead(1); + buf.release(); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/BaseChannelTest.java b/netty-channel/src/test/java/io/netty/channel/BaseChannelTest.java new file mode 100644 index 0000000..a015d8e --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/BaseChannelTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +class BaseChannelTest { + + private final LoggingTestHandler loggingTestHandler; + + BaseChannelTest() { + loggingTestHandler = new LoggingTestHandler(); + } + + ServerBootstrap getLocalServerBootstrap() { + EventLoopGroup serverGroup = new DefaultEventLoopGroup(); + ServerBootstrap sb = new ServerBootstrap(); + sb.group(serverGroup); + sb.channel(LocalServerChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + } + }); + + return sb; + } + + Bootstrap getLocalClientBootstrap() { + EventLoopGroup clientGroup = new DefaultEventLoopGroup(); + Bootstrap cb = new Bootstrap(); + cb.channel(LocalChannel.class); + cb.group(clientGroup); + + cb.handler(loggingTestHandler); + + return cb; + } + + static ByteBuf createTestBuf(int len) { + ByteBuf buf = Unpooled.buffer(len, len); + buf.setIndex(0, len); + return buf; + } + + void assertLog(String firstExpected, String... otherExpected) { + String actual = loggingTestHandler.getLog(); + if (firstExpected.equals(actual)) { + return; + } + for (String e: otherExpected) { + if (e.equals(actual)) { + return; + } + } + + // Let the comparison fail with the first expectation. + assertEquals(firstExpected, actual); + } + + void clearLog() { + loggingTestHandler.clear(); + } + + void setInterest(LoggingTestHandler.Event... 
events) { + loggingTestHandler.setInterest(events); + } + +} diff --git a/netty-channel/src/test/java/io/netty/channel/ChannelInitializerTest.java b/netty-channel/src/test/java/io/netty/channel/ChannelInitializerTest.java new file mode 100644 index 0000000..a228f63 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/ChannelInitializerTest.java @@ -0,0 +1,407 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Iterator; +import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; + +public class ChannelInitializerTest { + private static final int TIMEOUT_MILLIS = 1000; + private static final LocalAddress SERVER_ADDRESS = new LocalAddress("addr"); + private EventLoopGroup group; + private ServerBootstrap server; + private Bootstrap client; + private InspectableHandler testHandler; + + @BeforeEach + public void setUp() { + group = new DefaultEventLoopGroup(1); + server = new ServerBootstrap() + .group(group) + .channel(LocalServerChannel.class) + .localAddress(SERVER_ADDRESS); + client = new Bootstrap() + .group(group) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter()); + testHandler = new InspectableHandler(); + } + + @AfterEach + public void tearDown() { + group.shutdownGracefully(0, 
TIMEOUT_MILLIS, TimeUnit.MILLISECONDS).syncUninterruptibly(); + } + + @Test + public void testInitChannelThrowsRegisterFirst() { + testInitChannelThrows(true); + } + + @Test + public void testInitChannelThrowsRegisterAfter() { + testInitChannelThrows(false); + } + + private void testInitChannelThrows(boolean registerFirst) { + final Exception exception = new Exception(); + final AtomicReference causeRef = new AtomicReference(); + + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + if (registerFirst) { + group.register(pipeline.channel()).syncUninterruptibly(); + } + pipeline.addFirst(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + throw exception; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + causeRef.set(cause); + super.exceptionCaught(ctx, cause); + } + }); + + if (!registerFirst) { + group.register(pipeline.channel()).syncUninterruptibly(); + } + pipeline.channel().close().syncUninterruptibly(); + pipeline.channel().closeFuture().syncUninterruptibly(); + + assertSame(exception, causeRef.get()); + } + + @Test + public void testChannelInitializerInInitializerCorrectOrdering() { + final ChannelInboundHandlerAdapter handler1 = new ChannelInboundHandlerAdapter(); + final ChannelInboundHandlerAdapter handler2 = new ChannelInboundHandlerAdapter(); + final ChannelInboundHandlerAdapter handler3 = new ChannelInboundHandlerAdapter(); + final ChannelInboundHandlerAdapter handler4 = new ChannelInboundHandlerAdapter(); + + client.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(handler1); + ch.pipeline().addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(handler2); + ch.pipeline().addLast(handler3); + } + }); + ch.pipeline().addLast(handler4); + } + }).localAddress(LocalAddress.ANY); + + Channel channel = 
client.bind().syncUninterruptibly().channel(); + try { + // Execute some task on the EventLoop and wait until its done to be sure all handlers are added to the + // pipeline. + channel.eventLoop().submit(new Runnable() { + @Override + public void run() { + // NOOP + } + }).syncUninterruptibly(); + Iterator> handlers = channel.pipeline().iterator(); + assertSame(handler1, handlers.next().getValue()); + assertSame(handler2, handlers.next().getValue()); + assertSame(handler3, handlers.next().getValue()); + assertSame(handler4, handlers.next().getValue()); + assertFalse(handlers.hasNext()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test + public void testChannelInitializerReentrance() { + final AtomicInteger registeredCalled = new AtomicInteger(0); + final ChannelInboundHandlerAdapter handler1 = new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + registeredCalled.incrementAndGet(); + } + }; + final AtomicInteger initChannelCalled = new AtomicInteger(0); + client.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + initChannelCalled.incrementAndGet(); + ch.pipeline().addLast(handler1); + ch.pipeline().fireChannelRegistered(); + } + }).localAddress(LocalAddress.ANY); + + Channel channel = client.bind().syncUninterruptibly().channel(); + try { + // Execute some task on the EventLoop and wait until its done to be sure all handlers are added to the + // pipeline. 
+ channel.eventLoop().submit(new Runnable() { + @Override + public void run() { + // NOOP + } + }).syncUninterruptibly(); + assertEquals(1, initChannelCalled.get()); + assertEquals(2, registeredCalled.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test + @Timeout(value = TIMEOUT_MILLIS, unit = TimeUnit.MILLISECONDS) + public void firstHandlerInPipelineShouldReceiveChannelRegisteredEvent() { + testChannelRegisteredEventPropagation(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel channel) { + channel.pipeline().addFirst(testHandler); + } + }); + } + + @Test + @Timeout(value = TIMEOUT_MILLIS, unit = TimeUnit.MILLISECONDS) + public void lastHandlerInPipelineShouldReceiveChannelRegisteredEvent() { + testChannelRegisteredEventPropagation(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel channel) { + channel.pipeline().addLast(testHandler); + } + }); + } + + @Test + public void testAddFirstChannelInitializer() { + testAddChannelInitializer(true); + } + + @Test + public void testAddLastChannelInitializer() { + testAddChannelInitializer(false); + } + + private static void testAddChannelInitializer(final boolean first) { + final AtomicBoolean called = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ChannelHandler handler = new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + called.set(true); + } + }; + if (first) { + ch.pipeline().addFirst(handler); + } else { + ch.pipeline().addLast(handler); + } + } + }); + channel.finish(); + assertTrue(called.get()); + } + + private void testChannelRegisteredEventPropagation(ChannelInitializer init) { + Channel clientChannel = null, serverChannel = null; + try { + server.childHandler(init); + serverChannel = server.bind().syncUninterruptibly().channel(); + clientChannel = 
client.connect(SERVER_ADDRESS).syncUninterruptibly().channel(); + assertEquals(1, testHandler.channelRegisteredCount.get()); + } finally { + closeChannel(clientChannel); + closeChannel(serverChannel); + } + } + + @SuppressWarnings("deprecation") + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testChannelInitializerEventExecutor() throws Throwable { + final AtomicInteger invokeCount = new AtomicInteger(); + final AtomicInteger completeCount = new AtomicInteger(); + final AtomicReference errorRef = new AtomicReference(); + LocalAddress addr = new LocalAddress("test"); + + final EventExecutor executor = new DefaultEventLoop() { + private final ScheduledExecutorService execService = Executors.newSingleThreadScheduledExecutor(); + + @Override + public void shutdown() { + execService.shutdown(); + } + + @Override + public boolean inEventLoop(Thread thread) { + // Always return false which will ensure we always call execute(...) + return false; + } + + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + throw new IllegalStateException(); + } + + @Override + public Future terminationFuture() { + throw new IllegalStateException(); + } + + @Override + public boolean isShutdown() { + return execService.isShutdown(); + } + + @Override + public boolean isTerminated() { + return execService.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return execService.awaitTermination(timeout, unit); + } + + @Override + public void execute(Runnable command) { + execService.execute(command); + } + }; + + final CountDownLatch latch = new CountDownLatch(1); + ServerBootstrap serverBootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(group) + .localAddress(addr) + .childHandler(new ChannelInitializer() { + @Override + protected void 
initChannel(LocalChannel ch) { + ch.pipeline().addLast(executor, new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + invokeCount.incrementAndGet(); + ChannelHandlerContext ctx = ch.pipeline().context(this); + assertNotNull(ctx); + ch.pipeline().addAfter(ctx.executor(), + ctx.name(), null, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + // just drop on the floor. + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + latch.countDown(); + } + }); + completeCount.incrementAndGet(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause instanceof AssertionError) { + errorRef.set(cause); + } + } + }); + } + }); + + Channel server = serverBootstrap.bind().sync().channel(); + + Bootstrap clientBootstrap = new Bootstrap() + .channel(LocalChannel.class) + .group(group) + .remoteAddress(addr) + .handler(new ChannelInboundHandlerAdapter()); + + Channel client = clientBootstrap.connect().sync().channel(); + client.writeAndFlush("Hello World").sync(); + + client.close().sync(); + server.close().sync(); + + client.closeFuture().sync(); + server.closeFuture().sync(); + + // Wait until the handler is removed from the pipeline and so no more events are handled by it. 
+ latch.await(); + + assertEquals(1, invokeCount.get()); + assertEquals(invokeCount.get(), completeCount.get()); + + Throwable cause = errorRef.get(); + if (cause != null) { + throw cause; + } + + executor.shutdown(); + assertTrue(executor.awaitTermination(5, TimeUnit.SECONDS)); + } + + private static void closeChannel(Channel c) { + if (c != null) { + c.close().syncUninterruptibly(); + } + } + + private static final class InspectableHandler extends ChannelDuplexHandler { + final AtomicInteger channelRegisteredCount = new AtomicInteger(0); + + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + channelRegisteredCount.incrementAndGet(); + ctx.fireChannelRegistered(); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/ChannelOptionTest.java b/netty-channel/src/test/java/io/netty/channel/ChannelOptionTest.java new file mode 100644 index 0000000..170db51 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/ChannelOptionTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ChannelOptionTest { + + @Test + public void testExists() { + String name = "test"; + assertFalse(ChannelOption.exists(name)); + ChannelOption option = ChannelOption.valueOf(name); + + assertTrue(ChannelOption.exists(name)); + assertNotNull(option); + } + + @Test + public void testValueOf() { + String name = "test1"; + assertFalse(ChannelOption.exists(name)); + ChannelOption option = ChannelOption.valueOf(name); + ChannelOption option2 = ChannelOption.valueOf(name); + + assertSame(option, option2); + } + + @Test + public void testCreateOrFail() { + String name = "test2"; + assertFalse(ChannelOption.exists(name)); + ChannelOption option = ChannelOption.newInstance(name); + assertTrue(ChannelOption.exists(name)); + assertNotNull(option); + + try { + ChannelOption.newInstance(name); + fail(); + } catch (IllegalArgumentException e) { + // expected + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java b/netty-channel/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java new file mode 100644 index 0000000..ce92661 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/ChannelOutboundBufferTest.java @@ -0,0 +1,539 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.SingleThreadEventExecutor; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import static io.netty.buffer.Unpooled.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ChannelOutboundBufferTest { + + @Test + public void testEmptyNioBuffers() { + TestChannel channel = new TestChannel(); + ChannelOutboundBuffer buffer = new ChannelOutboundBuffer(channel); + assertEquals(0, buffer.nioBufferCount()); + ByteBuffer[] buffers = buffer.nioBuffers(); + assertNotNull(buffers); + for (ByteBuffer b: buffers) { + assertNull(b); + } + assertEquals(0, buffer.nioBufferCount()); + release(buffer); + } + + @Test + public void testNioBuffersCancelledRemoveBytes() { + TestChannel 
channel = new TestChannel(); + ChannelOutboundBuffer buffer = new ChannelOutboundBuffer(channel); + ByteBuf b1 = wrappedBuffer(new byte[] { 0 }); + int r1 = b1.readableBytes(); + ChannelPromise p1 = channel.newPromise(); + buffer.addMessage(b1, r1, p1); + + ByteBuf b2 = wrappedBuffer(new byte[] { 0, 1 }); + int r2 = b2.readableBytes(); + ChannelPromise p2 = channel.newPromise(); + buffer.addMessage(b2, r2, p2); + p2.cancel(false); + + ByteBuf b3 = wrappedBuffer(new byte[] { 0 }); + int r3 = b3.readableBytes(); + ChannelPromise p3 = channel.newPromise(); + buffer.addMessage(b3, r3, p3); + buffer.addFlush(); + + ByteBuffer[] buffers = buffer.nioBuffers(); + assertEquals(2, buffer.nioBufferCount()); + assertNotNull(buffers); + assertEquals(r1, buffers[0].remaining()); + assertEquals(r3, buffers[1].remaining()); + + buffer.removeBytes(r1 + r3); + assertEquals(0, b1.refCnt()); + assertEquals(0, b2.refCnt()); + assertEquals(0, b3.refCnt()); + + assertTrue(buffer.isEmpty()); + release(buffer); + } + + @Test + public void testNioBuffersSingleBacked() { + TestChannel channel = new TestChannel(); + + ChannelOutboundBuffer buffer = new ChannelOutboundBuffer(channel); + assertEquals(0, buffer.nioBufferCount()); + + ByteBuf buf = copiedBuffer("buf1", CharsetUtil.US_ASCII); + ByteBuffer nioBuf = buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes()); + buffer.addMessage(buf, buf.readableBytes(), channel.voidPromise()); + assertEquals(0, buffer.nioBufferCount(), "Should still be 0 as not flushed yet"); + buffer.addFlush(); + ByteBuffer[] buffers = buffer.nioBuffers(); + assertNotNull(buffers); + assertEquals(1, buffer.nioBufferCount(), "Should still be 0 as not flushed yet"); + for (int i = 0; i < buffer.nioBufferCount(); i++) { + if (i == 0) { + assertEquals(buffers[i], nioBuf); + } else { + assertNull(buffers[i]); + } + } + release(buffer); + } + + @Test + public void testNioBuffersExpand() { + TestChannel channel = new TestChannel(); + + ChannelOutboundBuffer buffer = 
new ChannelOutboundBuffer(channel); + + ByteBuf buf = directBuffer().writeBytes("buf1".getBytes(CharsetUtil.US_ASCII)); + for (int i = 0; i < 64; i++) { + buffer.addMessage(buf.copy(), buf.readableBytes(), channel.voidPromise()); + } + assertEquals(0, buffer.nioBufferCount(), "Should still be 0 as not flushed yet"); + buffer.addFlush(); + ByteBuffer[] buffers = buffer.nioBuffers(); + assertEquals(64, buffer.nioBufferCount()); + for (int i = 0; i < buffer.nioBufferCount(); i++) { + assertEquals(buffers[i], buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes())); + } + release(buffer); + buf.release(); + } + + @Test + public void testNioBuffersExpand2() { + TestChannel channel = new TestChannel(); + + ChannelOutboundBuffer buffer = new ChannelOutboundBuffer(channel); + + CompositeByteBuf comp = compositeBuffer(256); + ByteBuf buf = directBuffer().writeBytes("buf1".getBytes(CharsetUtil.US_ASCII)); + for (int i = 0; i < 65; i++) { + comp.addComponent(true, buf.copy()); + } + buffer.addMessage(comp, comp.readableBytes(), channel.voidPromise()); + + assertEquals(0, buffer.nioBufferCount(), "Should still be 0 as not flushed yet"); + buffer.addFlush(); + ByteBuffer[] buffers = buffer.nioBuffers(); + assertEquals(65, buffer.nioBufferCount()); + for (int i = 0; i < buffer.nioBufferCount(); i++) { + if (i < 65) { + assertEquals(buffers[i], buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes())); + } else { + assertNull(buffers[i]); + } + } + release(buffer); + buf.release(); + } + + @Test + public void testNioBuffersMaxCount() { + TestChannel channel = new TestChannel(); + + ChannelOutboundBuffer buffer = new ChannelOutboundBuffer(channel); + + CompositeByteBuf comp = compositeBuffer(256); + ByteBuf buf = directBuffer().writeBytes("buf1".getBytes(CharsetUtil.US_ASCII)); + for (int i = 0; i < 65; i++) { + comp.addComponent(true, buf.copy()); + } + assertEquals(65, comp.nioBufferCount()); + buffer.addMessage(comp, comp.readableBytes(), channel.voidPromise()); + 
assertEquals(0, buffer.nioBufferCount(), "Should still be 0 as not flushed yet"); + buffer.addFlush(); + final int maxCount = 10; // less than comp.nioBufferCount() + ByteBuffer[] buffers = buffer.nioBuffers(maxCount, Integer.MAX_VALUE); + assertTrue(buffer.nioBufferCount() <= maxCount, "Should not be greater than maxCount"); + for (int i = 0; i < buffer.nioBufferCount(); i++) { + assertEquals(buffers[i], buf.internalNioBuffer(buf.readerIndex(), buf.readableBytes())); + } + release(buffer); + buf.release(); + } + + private static void release(ChannelOutboundBuffer buffer) { + for (;;) { + if (!buffer.remove()) { + break; + } + } + } + + private static final class TestChannel extends AbstractChannel { + private static final ChannelMetadata TEST_METADATA = new ChannelMetadata(false); + private final ChannelConfig config = new DefaultChannelConfig(this); + + TestChannel() { + super(null); + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new TestUnsafe(); + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return false; + } + + @Override + protected SocketAddress localAddress0() { + throw new UnsupportedOperationException(); + } + + @Override + protected SocketAddress remoteAddress0() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doBind(SocketAddress localAddress) { + throw new UnsupportedOperationException(); + } + + @Override + protected void doDisconnect() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doClose() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doBeginRead() { + throw new UnsupportedOperationException(); + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return true; + } + + @Override + public boolean isActive() { + return true; + } 
+ + @Override + public ChannelMetadata metadata() { + return TEST_METADATA; + } + + final class TestUnsafe extends AbstractUnsafe { + @Override + public void connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + throw new UnsupportedOperationException(); + } + } + } + + @Test + public void testWritability() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128 + ChannelOutboundBuffer.CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD); + ch.config().setWriteBufferHighWaterMark(256 + ChannelOutboundBuffer.CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD); + + ch.write(buffer().writeZero(128)); + // Ensure exceeding the low watermark does not make channel unwritable. + ch.write(buffer().writeZero(2)); + assertThat(buf.toString(), is("")); + + ch.unsafe().outboundBuffer().addFlush(); + + // Ensure exceeding the high watermark makes channel unwritable. + ch.write(buffer().writeZero(127)); + assertThat(buf.toString(), is("false ")); + + // Ensure going down to the low watermark makes channel writable again by flushing the first write. 
+ assertThat(ch.unsafe().outboundBuffer().remove(), is(true)); + assertThat(ch.unsafe().outboundBuffer().remove(), is(true)); + assertThat(ch.unsafe().outboundBuffer().totalPendingWriteBytes(), + is(127L + ChannelOutboundBuffer.CHANNEL_OUTBOUND_BUFFER_ENTRY_OVERHEAD)); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + public void testUserDefinedWritability() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer(); + + // Ensure that the default value of a user-defined writability flag is true. + for (int i = 1; i <= 30; i ++) { + assertThat(cob.getUserDefinedWritability(i), is(true)); + } + + // Ensure that setting a user-defined writability flag to false affects channel.isWritable(); + cob.setUserDefinedWritability(1, false); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting a user-defined writability flag to true affects channel.isWritable(); + cob.setUserDefinedWritability(1, true); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + public void testUserDefinedWritability2() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer(); + + // Ensure that setting a user-defined 
writability flag to false affects channel.isWritable() + cob.setUserDefinedWritability(1, false); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting another user-defined writability flag to false does not trigger + // channelWritabilityChanged. + cob.setUserDefinedWritability(2, false); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting only one user-defined writability flag to true does not affect channel.isWritable() + cob.setUserDefinedWritability(1, true); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting all user-defined writability flags to true affects channel.isWritable() + cob.setUserDefinedWritability(2, true); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + public void testMixedWritability() { + final StringBuilder buf = new StringBuilder(); + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + buf.append(ctx.channel().isWritable()); + buf.append(' '); + } + }); + + ch.config().setWriteBufferLowWaterMark(128); + ch.config().setWriteBufferHighWaterMark(256); + + ChannelOutboundBuffer cob = ch.unsafe().outboundBuffer(); + + // Trigger channelWritabilityChanged() by writing a lot. + ch.write(buffer().writeZero(257)); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting a user-defined writability flag to false does not trigger channelWritabilityChanged() + cob.setUserDefinedWritability(1, false); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false ")); + + // Ensure reducing the totalPendingWriteBytes down to zero does not trigger channelWritabilityChanged() + // because of the user-defined writability flag. 
+ ch.flush(); + assertThat(cob.totalPendingWriteBytes(), is(0L)); + assertThat(buf.toString(), is("false ")); + + // Ensure that setting the user-defined writability flag to true triggers channelWritabilityChanged() + cob.setUserDefinedWritability(1, true); + ch.runPendingTasks(); + assertThat(buf.toString(), is("false true ")); + + safeClose(ch); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testWriteTaskRejected() throws Exception { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor( + null, new DefaultThreadFactory("executorPool"), + true, 1, RejectedExecutionHandlers.reject()) { + @Override + protected void run() { + do { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + } while (!confirmShutdown()); + } + + @Override + protected Queue newTaskQueue(int maxPendingTasks) { + return super.newTaskQueue(1); + } + }; + final CountDownLatch handlerAddedLatch = new CountDownLatch(1); + final CountDownLatch handlerRemovedLatch = new CountDownLatch(1); + EmbeddedChannel ch = new EmbeddedChannel(); + ch.pipeline().addLast(executor, "handler", new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + promise.setFailure(new AssertionError("Should not be called")); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handlerAddedLatch.countDown(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + handlerRemovedLatch.countDown(); + } + }); + + // Lets wait until we are sure the handler was added. 
+ handlerAddedLatch.await(); + + final CountDownLatch executeLatch = new CountDownLatch(1); + final CountDownLatch runLatch = new CountDownLatch(1); + executor.execute(new Runnable() { + @Override + public void run() { + try { + runLatch.countDown(); + executeLatch.await(); + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + }); + + runLatch.await(); + + executor.execute(new Runnable() { + @Override + public void run() { + // Will not be executed but ensure the pending count is 1. + } + }); + + assertEquals(1, executor.pendingTasks()); + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + + ByteBuf buffer = buffer(128).writeZero(128); + ChannelFuture future = ch.write(buffer); + ch.runPendingTasks(); + + assertTrue(future.cause() instanceof RejectedExecutionException); + assertEquals(0, buffer.refCnt()); + + // In case of rejected task we should not have anything pending. + assertEquals(0, ch.unsafe().outboundBuffer().totalPendingWriteBytes()); + executeLatch.countDown(); + + while (executor.pendingTasks() != 0) { + // Wait until there is no more pending task left. + Thread.sleep(10); + } + + ch.pipeline().remove("handler"); + + // Ensure we do not try to shutdown the executor before we handled everything for the Channel. Otherwise + // the Executor may reject when the Channel tries to add a task to it. 
+ handlerRemovedLatch.await(); + + safeClose(ch); + + executor.shutdownGracefully(); + } + + private static void safeClose(EmbeddedChannel ch) { + ch.finish(); + for (;;) { + ByteBuf m = ch.readOutbound(); + if (m == null) { + break; + } + m.release(); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/CoalescingBufferQueueTest.java b/netty-channel/src/test/java/io/netty/channel/CoalescingBufferQueueTest.java new file mode 100644 index 0000000..278d4c5 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/CoalescingBufferQueueTest.java @@ -0,0 +1,318 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for {@link CoalescingBufferQueue}. 
 */
public class CoalescingBufferQueueTest {

    // Fixture payloads: "cat" is 3 bytes, "mouse" is 5 bytes, so the full queue
    // holds 8 readable bytes ("catmouse") in most tests.
    private ByteBuf cat;
    private ByteBuf mouse;

    private ChannelPromise catPromise, emptyPromise;
    private ChannelPromise voidPromise;
    private ChannelFutureListener mouseListener;

    // Set by mouseListener when the mouse write completes (success or failure).
    private boolean mouseDone;
    private boolean mouseSuccess;

    private EmbeddedChannel channel;
    private CoalescingBufferQueue writeQueue;

    @BeforeEach
    public void setup() {
        mouseDone = false;
        mouseSuccess = false;
        channel = new EmbeddedChannel();
        // Initial capacity 16, update channel writability on add/remove.
        writeQueue = new CoalescingBufferQueue(channel, 16, true);
        catPromise = newPromise();
        mouseListener = new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                mouseDone = true;
                mouseSuccess = future.isSuccess();
            }
        };
        emptyPromise = newPromise();
        voidPromise = channel.voidPromise();

        cat = Unpooled.wrappedBuffer("cat".getBytes(CharsetUtil.US_ASCII));
        mouse = Unpooled.wrappedBuffer("mouse".getBytes(CharsetUtil.US_ASCII));
    }

    @AfterEach
    public void finish() {
        // No test should leave messages in the embedded channel.
        assertFalse(channel.finish());
    }

    @Test
    public void testAddFirstPromiseRetained() {
        writeQueue.add(cat, catPromise);
        assertQueueSize(3, false);
        writeQueue.add(mouse, mouseListener);
        assertQueueSize(8, false);
        ChannelPromise aggregatePromise = newPromise();
        // Read 7 of 8 bytes; the trailing "e" stays queued.
        assertEquals("catmous", dequeue(7, aggregatePromise));
        // Push the partially-read chunk back with the same aggregate promise:
        // completing the second aggregate must then complete the original writes.
        ByteBuf remainder = Unpooled.wrappedBuffer("mous".getBytes(CharsetUtil.US_ASCII));
        writeQueue.addFirst(remainder, aggregatePromise);
        ChannelPromise aggregatePromise2 = newPromise();
        assertEquals("mouse", dequeue(5, aggregatePromise2));
        aggregatePromise2.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertTrue(mouseSuccess);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
    }

    @Test
    public void testAddFirstVoidPromise() {
        writeQueue.add(cat, catPromise);
        assertQueueSize(3, false);
        writeQueue.add(mouse, mouseListener);
        assertQueueSize(8, false);
        ChannelPromise aggregatePromise = newPromise();
        assertEquals("catmous", dequeue(7, aggregatePromise));
        ByteBuf remainder = Unpooled.wrappedBuffer("mous".getBytes(CharsetUtil.US_ASCII));
        writeQueue.addFirst(remainder, voidPromise);
        ChannelPromise aggregatePromise2 = newPromise();
        assertEquals("mouse", dequeue(5, aggregatePromise2));
        aggregatePromise2.setSuccess();
        // Because we used a void promise above, we shouldn't complete catPromise until aggregatePromise is completed.
        assertFalse(catPromise.isSuccess());
        assertTrue(mouseSuccess);
        aggregatePromise.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertTrue(mouseSuccess);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
    }

    @Test
    public void testAggregateWithFullRead() {
        writeQueue.add(cat, catPromise);
        assertQueueSize(3, false);
        writeQueue.add(mouse, mouseListener);
        assertQueueSize(8, false);
        ChannelPromise aggregatePromise = newPromise();
        assertEquals("catmouse", dequeue(8, aggregatePromise));
        assertQueueSize(0, true);
        // The per-write promises only complete once the aggregate does.
        assertFalse(catPromise.isSuccess());
        assertFalse(mouseDone);
        aggregatePromise.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertTrue(mouseSuccess);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
    }

    @Test
    public void testWithVoidPromise() {
        writeQueue.add(cat, voidPromise);
        writeQueue.add(mouse, voidPromise);
        assertQueueSize(8, false);
        // Void-promise writes carry no completion bookkeeping; just verify the
        // coalesced byte stream and that the buffers are released.
        assertEquals("catm", dequeue(4, newPromise()));
        assertQueueSize(4, false);
        assertEquals("ouse", dequeue(4, newPromise()));
        assertQueueSize(0, true);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
    }

    @Test
    public void testAggregateWithPartialRead() {
        writeQueue.add(cat, catPromise);
        writeQueue.add(mouse, mouseListener);
        ChannelPromise aggregatePromise = newPromise();
        // First read consumes all of "cat" and one byte of "mouse".
        assertEquals("catm", dequeue(4, aggregatePromise));
        assertQueueSize(4, false);
        assertFalse(catPromise.isSuccess());
        assertFalse(mouseDone);
        // Completing the first aggregate finishes only the fully-consumed write.
        aggregatePromise.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertFalse(mouseDone);

        aggregatePromise = newPromise();
        assertEquals("ouse", dequeue(Integer.MAX_VALUE, aggregatePromise));
        assertQueueSize(0, true);
        assertFalse(mouseDone);
        aggregatePromise.setSuccess();
        assertTrue(mouseSuccess);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
    }

    @Test
    public void testReadExactAddedBufferSizeReturnsOriginal() {
        writeQueue.add(cat, catPromise);
        writeQueue.add(mouse, mouseListener);

        // Removing exactly one buffer's size must hand back the original buffer
        // (same identity, no copy), still retained for the caller.
        ChannelPromise aggregatePromise = newPromise();
        assertSame(cat, writeQueue.remove(3, aggregatePromise));
        assertFalse(catPromise.isSuccess());
        aggregatePromise.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertEquals(1, cat.refCnt());
        cat.release();

        aggregatePromise = newPromise();
        assertSame(mouse, writeQueue.remove(5, aggregatePromise));
        assertFalse(mouseDone);
        aggregatePromise.setSuccess();
        assertTrue(mouseSuccess);
        assertEquals(1, mouse.refCnt());
        mouse.release();
    }

    @Test
    public void testReadEmptyQueueReturnsEmptyBuffer() {
        // Not used in this test.
        cat.release();
        mouse.release();

        assertQueueSize(0, true);
        ChannelPromise aggregatePromise = newPromise();
        // Removing from an empty queue yields an empty buffer, not null.
        assertEquals("", dequeue(Integer.MAX_VALUE, aggregatePromise));
        assertQueueSize(0, true);
    }

    @Test
    public void testReleaseAndFailAll() {
        writeQueue.add(cat, catPromise);
        writeQueue.add(mouse, mouseListener);
        RuntimeException cause = new RuntimeException("ooops");
        // Fail everything: buffers released, promises failed with the given cause.
        writeQueue.releaseAndFailAll(cause);
        ChannelPromise aggregatePromise = newPromise();
        assertQueueSize(0, true);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
        assertSame(cause, catPromise.cause());
        // The queue must remain usable (and empty) afterwards.
        assertEquals("", dequeue(Integer.MAX_VALUE, aggregatePromise));
        assertQueueSize(0, true);
    }

    @Test
    public void testEmptyBuffersAreCoalesced() {
        // A zero-readable-byte buffer adds no bytes but its promise must still
        // be tied into the aggregate completion.
        ByteBuf empty = Unpooled.buffer(0, 1);
        assertQueueSize(0, true);
        writeQueue.add(cat, catPromise);
        writeQueue.add(empty, emptyPromise);
        assertQueueSize(3, false);
        ChannelPromise aggregatePromise = newPromise();
        assertEquals("cat", dequeue(3, aggregatePromise));
        assertQueueSize(0, true);
        assertFalse(catPromise.isSuccess());
        assertFalse(emptyPromise.isSuccess());
        aggregatePromise.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertTrue(emptyPromise.isSuccess());
        assertEquals(0, cat.refCnt());
        assertEquals(0, empty.refCnt());
    }

    @Test
    public void testMerge() {
        // copyTo() must transfer another queue's entries (bytes and promises).
        writeQueue.add(cat, catPromise);
        CoalescingBufferQueue otherQueue = new CoalescingBufferQueue(channel);
        otherQueue.add(mouse, mouseListener);
        otherQueue.copyTo(writeQueue);
        assertQueueSize(8, false);
        ChannelPromise aggregatePromise = newPromise();
        assertEquals("catmouse", dequeue(8, aggregatePromise));
        assertQueueSize(0, true);
        assertFalse(catPromise.isSuccess());
        assertFalse(mouseDone);
        aggregatePromise.setSuccess();
        assertTrue(catPromise.isSuccess());
        assertTrue(mouseSuccess);
        assertEquals(0, cat.refCnt());
        assertEquals(0, mouse.refCnt());
    }

    @Test
    public void testWritabilityChanged() {
        testWritabilityChanged0(false);
    }

    @Test
    public void testWritabilityChangedFailAll() {
        testWritabilityChanged0(true);
    }

    // Shared body: with watermarks low=3/high=4, queued bytes must toggle the
    // channel's writability; draining (via removeFirst or releaseAndFailAll)
    // must restore it.
    private void testWritabilityChanged0(boolean fail) {
        channel.config().setWriteBufferWaterMark(new WriteBufferWaterMark(3, 4));
        assertTrue(channel.isWritable());
        writeQueue.add(Unpooled.wrappedBuffer(new byte[] {1, 2, 3}));
        assertTrue(channel.isWritable());
        // 5 total bytes exceeds the high watermark (4) -> unwritable.
        writeQueue.add(Unpooled.wrappedBuffer(new byte[] {4, 5}));
        assertFalse(channel.isWritable());
        assertEquals(5, writeQueue.readableBytes());

        if (fail) {
            writeQueue.releaseAndFailAll(new IllegalStateException());
        } else {
            ByteBuf buffer = writeQueue.removeFirst(voidPromise);
            assertEquals(1, buffer.readByte());
            assertEquals(2, buffer.readByte());
            assertEquals(3, buffer.readByte());
            assertFalse(buffer.isReadable());
            buffer.release();
            // 2 remaining bytes is below the low watermark (3) -> writable again.
            assertTrue(channel.isWritable());

            buffer = writeQueue.removeFirst(voidPromise);
            assertEquals(4, buffer.readByte());
            assertEquals(5, buffer.readByte());
            assertFalse(buffer.isReadable());
            buffer.release();
        }

        assertTrue(channel.isWritable());
        assertTrue(writeQueue.isEmpty());
    }

    private ChannelPromise newPromise() {
        return channel.newPromise();
    }

    // Asserts both the queue's readable byte count and its emptiness flag.
    private void assertQueueSize(int size, boolean isEmpty) {
        assertEquals(size, writeQueue.readableBytes());
        if (isEmpty) {
            assertTrue(writeQueue.isEmpty());
        } else {
            assertFalse(writeQueue.isEmpty());
        }
    }

    // Removes up to numBytes from the queue and returns them decoded as ASCII;
    // the removed buffer itself is released here.
    private String dequeue(int numBytes, ChannelPromise aggregatePromise) {
        ByteBuf removed = writeQueue.remove(numBytes, aggregatePromise);
        String result = removed.toString(CharsetUtil.US_ASCII);
        ReferenceCountUtil.safeRelease(removed);
        return result;
    }
}
b/netty-channel/src/test/java/io/netty/channel/CombinedChannelDuplexHandlerTest.java @@ -0,0 +1,481 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class CombinedChannelDuplexHandlerTest { + + private static final Object MSG = new Object(); + private static final SocketAddress LOCAL_ADDRESS = new InetSocketAddress(0); + private static final SocketAddress REMOTE_ADDRESS = new InetSocketAddress(0); + private static final Throwable CAUSE = new Throwable(); + private static final Object USER_EVENT = new Object(); + + private enum Event { + REGISTERED, + UNREGISTERED, + ACTIVE, + INACTIVE, + 
CHANNEL_READ, + CHANNEL_READ_COMPLETE, + EXCEPTION_CAUGHT, + USER_EVENT_TRIGGERED, + CHANNEL_WRITABILITY_CHANGED, + HANDLER_ADDED, + HANDLER_REMOVED, + BIND, + CONNECT, + WRITE, + FLUSH, + READ, + REGISTER, + DEREGISTER, + CLOSE, + DISCONNECT + } + + @Test + public void testInboundRemoveBeforeAdded() { + final CombinedChannelDuplexHandler handler = + new CombinedChannelDuplexHandler( + new ChannelInboundHandlerAdapter(), new ChannelOutboundHandlerAdapter()); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + handler.removeInboundHandler(); + } + }); + } + + @Test + public void testOutboundRemoveBeforeAdded() { + final CombinedChannelDuplexHandler handler = + new CombinedChannelDuplexHandler( + new ChannelInboundHandlerAdapter(), new ChannelOutboundHandlerAdapter()); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + handler.removeOutboundHandler(); + } + }); + } + + @Test + public void testInboundHandlerImplementsOutboundHandler() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new CombinedChannelDuplexHandler( + new ChannelDuplexHandler(), new ChannelOutboundHandlerAdapter()); + } + }); + } + + @Test + public void testOutboundHandlerImplementsInboundHandler() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new CombinedChannelDuplexHandler( + new ChannelInboundHandlerAdapter(), new ChannelDuplexHandler()); + } + }); + } + + @Test + public void testInitNotCalledBeforeAdded() { + final CombinedChannelDuplexHandler handler = + new CombinedChannelDuplexHandler() { }; + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() throws Throwable { + handler.handlerAdded(null); + } + }); + } + + @Test + public void testExceptionCaughtBothCombinedHandlers() { + final Exception exception = new Exception(); + final 
Queue queue = new ArrayDeque(); + + ChannelInboundHandler inboundHandler = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + assertSame(exception, cause); + queue.add(this); + ctx.fireExceptionCaught(cause); + } + }; + ChannelOutboundHandler outboundHandler = new ChannelOutboundHandlerAdapter() { + @SuppressWarnings("deprecation") + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + assertSame(exception, cause); + queue.add(this); + ctx.fireExceptionCaught(cause); + } + }; + ChannelInboundHandler lastHandler = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + assertSame(exception, cause); + queue.add(this); + } + }; + EmbeddedChannel channel = new EmbeddedChannel( + new CombinedChannelDuplexHandler( + inboundHandler, outboundHandler), lastHandler); + channel.pipeline().fireExceptionCaught(exception); + assertFalse(channel.finish()); + assertSame(inboundHandler, queue.poll()); + assertSame(outboundHandler, queue.poll()); + assertSame(lastHandler, queue.poll()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testInboundEvents() { + InboundEventHandler inboundHandler = new InboundEventHandler(); + + CombinedChannelDuplexHandler handler = + new CombinedChannelDuplexHandler( + inboundHandler, new ChannelOutboundHandlerAdapter()); + + EmbeddedChannel channel = new EmbeddedChannel(); + channel.pipeline().addLast(handler); + assertEquals(Event.HANDLER_ADDED, inboundHandler.pollEvent()); + + doInboundOperations(channel); + assertInboundOperations(inboundHandler); + handler.removeInboundHandler(); + + assertEquals(Event.HANDLER_REMOVED, inboundHandler.pollEvent()); + + // These should not be handled by the inboundHandler anymore as it was removed before + doInboundOperations(channel); + + // Should have not received any more events as it was removed before via 
removeInboundHandler() + assertNull(inboundHandler.pollEvent()); + try { + channel.checkException(); + fail(); + } catch (Throwable cause) { + assertSame(CAUSE, cause); + } + + assertTrue(channel.finish()); + assertNull(inboundHandler.pollEvent()); + } + + @Test + public void testOutboundEvents() { + ChannelInboundHandler inboundHandler = new ChannelInboundHandlerAdapter(); + OutboundEventHandler outboundHandler = new OutboundEventHandler(); + + CombinedChannelDuplexHandler handler = + new CombinedChannelDuplexHandler( + inboundHandler, outboundHandler); + + EmbeddedChannel channel = new EmbeddedChannel(); + channel.pipeline().addLast(new OutboundEventHandler()); + channel.pipeline().addLast(handler); + + assertEquals(Event.HANDLER_ADDED, outboundHandler.pollEvent()); + + doOutboundOperations(channel); + + assertOutboundOperations(outboundHandler); + + handler.removeOutboundHandler(); + + assertEquals(Event.HANDLER_REMOVED, outboundHandler.pollEvent()); + + // These should not be handled by the inboundHandler anymore as it was removed before + doOutboundOperations(channel); + + // Should have not received any more events as it was removed before via removeInboundHandler() + assertNull(outboundHandler.pollEvent()); + assertFalse(channel.finish()); + assertNull(outboundHandler.pollEvent()); + } + + private static void doOutboundOperations(Channel channel) { + channel.pipeline().bind(LOCAL_ADDRESS).syncUninterruptibly(); + channel.pipeline().connect(REMOTE_ADDRESS, LOCAL_ADDRESS).syncUninterruptibly(); + channel.pipeline().write(MSG).syncUninterruptibly(); + channel.pipeline().flush(); + channel.pipeline().read(); + channel.pipeline().disconnect().syncUninterruptibly(); + channel.pipeline().close().syncUninterruptibly(); + channel.pipeline().deregister().syncUninterruptibly(); + } + + private static void assertOutboundOperations(OutboundEventHandler outboundHandler) { + assertEquals(Event.BIND, outboundHandler.pollEvent()); + assertEquals(Event.CONNECT, 
outboundHandler.pollEvent()); + assertEquals(Event.WRITE, outboundHandler.pollEvent()); + assertEquals(Event.FLUSH, outboundHandler.pollEvent()); + assertEquals(Event.READ, outboundHandler.pollEvent()); + assertEquals(Event.CLOSE, outboundHandler.pollEvent()); + assertEquals(Event.CLOSE, outboundHandler.pollEvent()); + assertEquals(Event.DEREGISTER, outboundHandler.pollEvent()); + } + + private static void doInboundOperations(Channel channel) { + channel.pipeline().fireChannelRegistered(); + channel.pipeline().fireChannelActive(); + channel.pipeline().fireChannelRead(MSG); + channel.pipeline().fireChannelReadComplete(); + channel.pipeline().fireExceptionCaught(CAUSE); + channel.pipeline().fireUserEventTriggered(USER_EVENT); + channel.pipeline().fireChannelWritabilityChanged(); + channel.pipeline().fireChannelInactive(); + channel.pipeline().fireChannelUnregistered(); + } + + private static void assertInboundOperations(InboundEventHandler handler) { + assertEquals(Event.REGISTERED, handler.pollEvent()); + assertEquals(Event.ACTIVE, handler.pollEvent()); + assertEquals(Event.CHANNEL_READ, handler.pollEvent()); + assertEquals(Event.CHANNEL_READ_COMPLETE, handler.pollEvent()); + assertEquals(Event.EXCEPTION_CAUGHT, handler.pollEvent()); + assertEquals(Event.USER_EVENT_TRIGGERED, handler.pollEvent()); + assertEquals(Event.CHANNEL_WRITABILITY_CHANGED, handler.pollEvent()); + assertEquals(Event.INACTIVE, handler.pollEvent()); + assertEquals(Event.UNREGISTERED, handler.pollEvent()); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testPromisesPassed() { + OutboundEventHandler outboundHandler = new OutboundEventHandler(); + EmbeddedChannel ch = new EmbeddedChannel(outboundHandler, + new CombinedChannelDuplexHandler( + new ChannelInboundHandlerAdapter(), new ChannelOutboundHandlerAdapter())); + ChannelPipeline pipeline = ch.pipeline(); + + ChannelPromise promise = ch.newPromise(); + pipeline.bind(LOCAL_ADDRESS, promise); + 
promise.syncUninterruptibly(); + + promise = ch.newPromise(); + pipeline.connect(REMOTE_ADDRESS, LOCAL_ADDRESS, promise); + promise.syncUninterruptibly(); + + promise = ch.newPromise(); + pipeline.close(promise); + promise.syncUninterruptibly(); + + promise = ch.newPromise(); + pipeline.disconnect(promise); + promise.syncUninterruptibly(); + + promise = ch.newPromise(); + pipeline.write(MSG, promise); + promise.syncUninterruptibly(); + + promise = ch.newPromise(); + pipeline.deregister(promise); + promise.syncUninterruptibly(); + ch.finish(); + } + + @Test + public void testNotSharable() { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + new CombinedChannelDuplexHandler() { + @Override + public boolean isSharable() { + return true; + } + }; + } + }); + } + + private static final class InboundEventHandler extends ChannelInboundHandlerAdapter { + private final Queue queue = new ArrayDeque(); + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + queue.add(Event.HANDLER_ADDED); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + queue.add(Event.HANDLER_REMOVED); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + queue.add(Event.REGISTERED); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { + queue.add(Event.UNREGISTERED); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + queue.add(Event.ACTIVE); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + queue.add(Event.INACTIVE); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + queue.add(Event.CHANNEL_READ); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + queue.add(Event.CHANNEL_READ_COMPLETE); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + queue.add(Event.USER_EVENT_TRIGGERED); + } + + @Override + 
public void channelWritabilityChanged(ChannelHandlerContext ctx) { + queue.add(Event.CHANNEL_WRITABILITY_CHANGED); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + queue.add(Event.EXCEPTION_CAUGHT); + } + + Event pollEvent() { + Object o = queue.poll(); + if (o instanceof AssertionError) { + throw (AssertionError) o; + } + return (Event) o; + } + } + + private static final class OutboundEventHandler extends ChannelOutboundHandlerAdapter { + private final Queue queue = new ArrayDeque(); + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + queue.add(Event.HANDLER_ADDED); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + queue.add(Event.HANDLER_REMOVED); + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + try { + assertSame(LOCAL_ADDRESS, localAddress); + queue.add(Event.BIND); + promise.setSuccess(); + } catch (AssertionError e) { + promise.setFailure(e); + } + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + try { + assertSame(REMOTE_ADDRESS, remoteAddress); + assertSame(LOCAL_ADDRESS, localAddress); + queue.add(Event.CONNECT); + promise.setSuccess(); + } catch (AssertionError e) { + promise.setFailure(e); + } + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + queue.add(Event.DISCONNECT); + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + queue.add(Event.CLOSE); + promise.setSuccess(); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + queue.add(Event.DEREGISTER); + promise.setSuccess(); + } + + @Override + public void read(ChannelHandlerContext ctx) { + queue.add(Event.READ); + } + + @Override + public void write(ChannelHandlerContext ctx, Object 
msg, ChannelPromise promise) { + try { + assertSame(MSG, msg); + queue.add(Event.WRITE); + promise.setSuccess(); + } catch (AssertionError e) { + promise.setFailure(e); + } + } + + @Override + public void flush(ChannelHandlerContext ctx) { + queue.add(Event.FLUSH); + } + + Event pollEvent() { + Object o = queue.poll(); + if (o instanceof AssertionError) { + throw (AssertionError) o; + } + return (Event) o; + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/CompleteChannelFutureTest.java b/netty-channel/src/test/java/io/netty/channel/CompleteChannelFutureTest.java new file mode 100644 index 0000000..3c0378b --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/CompleteChannelFutureTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mockito; + +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CompleteChannelFutureTest { + + @Test + public void shouldDisallowNullChannel() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new CompleteChannelFutureImpl(null); + } + }); + } + + @Test + public void shouldNotDoAnythingOnRemove() { + Channel channel = Mockito.mock(Channel.class); + CompleteChannelFuture future = new CompleteChannelFutureImpl(channel); + ChannelFutureListener l = Mockito.mock(ChannelFutureListener.class); + future.removeListener(l); + Mockito.verifyNoMoreInteractions(l); + Mockito.verifyZeroInteractions(channel); + } + + @Test + public void testConstantProperties() throws InterruptedException { + Channel channel = Mockito.mock(Channel.class); + CompleteChannelFuture future = new CompleteChannelFutureImpl(channel); + + assertSame(channel, future.channel()); + assertTrue(future.isDone()); + assertSame(future, future.await()); + assertTrue(future.await(1)); + assertTrue(future.await(1, TimeUnit.NANOSECONDS)); + assertSame(future, future.awaitUninterruptibly()); + assertTrue(future.awaitUninterruptibly(1)); + assertTrue(future.awaitUninterruptibly(1, TimeUnit.NANOSECONDS)); + Mockito.verifyZeroInteractions(channel); + } + + private static class CompleteChannelFutureImpl extends CompleteChannelFuture { + + CompleteChannelFutureImpl(Channel channel) { + super(channel, null); + } + + @Override + public Throwable cause() { + throw new Error(); + } + + @Override + public boolean isSuccess() { + throw new Error(); + } + + @Override + public ChannelFuture sync() throws InterruptedException { + throw new Error(); + } + 
+ @Override + public ChannelFuture syncUninterruptibly() { + throw new Error(); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/DefaultChannelIdTest.java b/netty-channel/src/test/java/io/netty/channel/DefaultChannelIdTest.java new file mode 100644 index 0000000..c036882 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/DefaultChannelIdTest.java @@ -0,0 +1,87 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.Test; + +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@SuppressWarnings("DynamicRegexReplaceableByCompiledPattern") +public class DefaultChannelIdTest { + @Test + public void testShortText() { + String text = DefaultChannelId.newInstance().asShortText(); + assertTrue(text.matches("^[0-9a-f]{8}$")); + } + + @Test + public void testLongText() { + String text = DefaultChannelId.newInstance().asLongText(); + assertTrue(text.matches("^[0-9a-f]{16}-[0-9a-f]{8}-[0-9a-f]{8}-[0-9a-f]{16}-[0-9a-f]{8}$")); + } + + @Test + public void testIdempotentMachineId() { + String a = DefaultChannelId.newInstance().asLongText().substring(0, 16); + String b = DefaultChannelId.newInstance().asLongText().substring(0, 16); + assertThat(a, is(b)); + } + + @Test + public void testIdempotentProcessId() { + String a = DefaultChannelId.newInstance().asLongText().substring(17, 21); + String b = DefaultChannelId.newInstance().asLongText().substring(17, 21); + assertThat(a, is(b)); + } + + @Test + public void testSerialization() throws Exception { + ChannelId a = DefaultChannelId.newInstance(); + ChannelId b; + + ByteBuf buf = Unpooled.buffer(); + ObjectOutputStream out = new ObjectOutputStream(new ByteBufOutputStream(buf)); + try { + out.writeObject(a); + out.flush(); + } finally { + out.close(); + } + + ObjectInputStream in = new ObjectInputStream(new ByteBufInputStream(buf, true)); + try { + b = (ChannelId) in.readObject(); + } finally { + in.close(); + } + + assertThat(a, is(b)); + 
assertThat(a, is(not(sameInstance(b)))); + assertThat(a.asLongText(), is(b.asLongText())); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java b/netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java new file mode 100644 index 0000000..69bdf8f --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTailTest.java @@ -0,0 +1,408 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import io.netty.bootstrap.Bootstrap; + +public class DefaultChannelPipelineTailTest { + + private static EventLoopGroup GROUP; + + @BeforeAll + public static void init() { + GROUP = new DefaultEventLoopGroup(1); + } + + @AfterAll + public static void destroy() { + GROUP.shutdownGracefully(); + } + + @Test + public void testOnUnhandledInboundChannelActive() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundChannelActive() { + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + .channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + try { + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + } finally { + channel.close(); + } + } + + @Test + public void testOnUnhandledInboundChannelInactive() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundChannelInactive() { + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + .channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + 
channel.close().syncUninterruptibly(); + + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + } + + @Test + public void testOnUnhandledInboundException() throws Exception { + final AtomicReference causeRef = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundException(Throwable cause) { + causeRef.set(cause); + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + .channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + try { + IOException ex = new IOException("testOnUnhandledInboundException"); + channel.pipeline().fireExceptionCaught(ex); + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + assertSame(ex, causeRef.get()); + } finally { + channel.close(); + } + } + + @Test + public void testOnUnhandledInboundMessage() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundMessage(Object msg) { + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + .channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + try { + channel.pipeline().fireChannelRead("testOnUnhandledInboundMessage"); + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + } finally { + channel.close(); + } + } + + @Test + public void testOnUnhandledInboundReadComplete() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundReadComplete() { + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + 
.channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + try { + channel.pipeline().fireChannelReadComplete(); + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + } finally { + channel.close(); + } + } + + @Test + public void testOnUnhandledInboundUserEventTriggered() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundUserEventTriggered(Object evt) { + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + .channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + try { + channel.pipeline().fireUserEventTriggered("testOnUnhandledInboundUserEventTriggered"); + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + } finally { + channel.close(); + } + } + + @Test + public void testOnUnhandledInboundWritabilityChanged() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + MyChannel myChannel = new MyChannel() { + @Override + protected void onUnhandledInboundWritabilityChanged() { + latch.countDown(); + } + }; + + Bootstrap bootstrap = new Bootstrap() + .channelFactory(new MyChannelFactory(myChannel)) + .group(GROUP) + .handler(new ChannelInboundHandlerAdapter()) + .remoteAddress(new InetSocketAddress(0)); + + Channel channel = bootstrap.connect() + .sync().channel(); + + try { + channel.pipeline().fireChannelWritabilityChanged(); + assertTrue(latch.await(1L, TimeUnit.SECONDS)); + } finally { + channel.close(); + } + } + + private static class MyChannelFactory implements ChannelFactory { + private final MyChannel channel; + + MyChannelFactory(MyChannel channel) { + this.channel = channel; + } + + 
@Override + public MyChannel newChannel() { + return channel; + } + } + + private abstract static class MyChannel extends AbstractChannel { + private static final ChannelMetadata METADATA = new ChannelMetadata(false); + + private final ChannelConfig config = new DefaultChannelConfig(this); + + private boolean active; + private boolean closed; + + protected MyChannel() { + super(null); + } + + @Override + protected DefaultChannelPipeline newChannelPipeline() { + return new MyChannelPipeline(this); + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return !closed; + } + + @Override + public boolean isActive() { + return isOpen() && active; + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + protected AbstractUnsafe newUnsafe() { + return new MyUnsafe(); + } + + @Override + protected boolean isCompatible(EventLoop loop) { + return true; + } + + @Override + protected SocketAddress localAddress0() { + return null; + } + + @Override + protected SocketAddress remoteAddress0() { + return null; + } + + @Override + protected void doBind(SocketAddress localAddress) { + } + + @Override + protected void doDisconnect() { + } + + @Override + protected void doClose() { + closed = true; + } + + @Override + protected void doBeginRead() { + } + + @Override + protected void doWrite(ChannelOutboundBuffer in) throws Exception { + throw new IOException(); + } + + protected void onUnhandledInboundChannelActive() { + } + + protected void onUnhandledInboundChannelInactive() { + } + + protected void onUnhandledInboundException(Throwable cause) { + } + + protected void onUnhandledInboundMessage(Object msg) { + } + + protected void onUnhandledInboundReadComplete() { + } + + protected void onUnhandledInboundUserEventTriggered(Object evt) { + } + + protected void onUnhandledInboundWritabilityChanged() { + } + + private class MyUnsafe extends AbstractUnsafe { + @Override + public void 
connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + if (!ensureOpen(promise)) { + return; + } + + if (!active) { + active = true; + pipeline().fireChannelActive(); + } + + promise.setSuccess(); + } + } + + private class MyChannelPipeline extends DefaultChannelPipeline { + + MyChannelPipeline(Channel channel) { + super(channel); + } + + @Override + protected void onUnhandledInboundChannelActive() { + MyChannel.this.onUnhandledInboundChannelActive(); + } + + @Override + protected void onUnhandledInboundChannelInactive() { + MyChannel.this.onUnhandledInboundChannelInactive(); + } + + @Override + protected void onUnhandledInboundException(Throwable cause) { + MyChannel.this.onUnhandledInboundException(cause); + } + + @Override + protected void onUnhandledInboundMessage(Object msg) { + MyChannel.this.onUnhandledInboundMessage(msg); + } + + @Override + protected void onUnhandledInboundChannelReadComplete() { + MyChannel.this.onUnhandledInboundReadComplete(); + } + + @Override + protected void onUnhandledInboundUserEventTriggered(Object evt) { + MyChannel.this.onUnhandledInboundUserEventTriggered(evt); + } + + @Override + protected void onUnhandledChannelWritabilityChanged() { + MyChannel.this.onUnhandledInboundWritabilityChanged(); + } + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java b/netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java new file mode 100644 index 0000000..fa86dd1 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/DefaultChannelPipelineTest.java @@ -0,0 +1,2293 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerMask.Skip; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.oio.OioEventLoopGroup; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.channel.socket.oio.OioSocketChannel; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.AbstractEventExecutor; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.EventExecutorGroup; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.UnorderedThreadPoolEventExecutor; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.net.SocketAddress; +import java.util.ArrayDeque; +import java.util.ArrayList; 
+import java.util.Collections; +import java.util.List; +import java.util.NoSuchElementException; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.fail; + +public class DefaultChannelPipelineTest { + + private static EventLoopGroup group; + + private Channel self; + private Channel peer; + + @BeforeAll + public static void beforeClass() throws Exception { + group = new DefaultEventLoopGroup(1); + } + + @AfterAll + public static void afterClass() throws Exception { + group.shutdownGracefully().sync(); + } + + private void setUp(final ChannelHandler... 
handlers) throws Exception { + final AtomicReference peerRef = new AtomicReference(); + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group).channel(LocalServerChannel.class); + sb.childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + peerRef.set(ctx.channel()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.release(msg); + } + }); + + ChannelFuture bindFuture = sb.bind(LocalAddress.ANY).sync(); + + Bootstrap b = new Bootstrap(); + b.group(group).channel(LocalChannel.class); + b.handler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) { + ch.pipeline().addLast(handlers); + } + }); + + self = b.connect(bindFuture.channel().localAddress()).sync().channel(); + peer = peerRef.get(); + + bindFuture.channel().close().sync(); + } + + @AfterEach + public void tearDown() throws Exception { + if (peer != null) { + peer.close(); + peer = null; + } + if (self != null) { + self = null; + } + } + + @Test + public void testFreeCalled() throws Exception { + final CountDownLatch free = new CountDownLatch(1); + + final ReferenceCounted holder = new AbstractReferenceCounted() { + @Override + protected void deallocate() { + free.countDown(); + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + }; + + StringInboundHandler handler = new StringInboundHandler(); + setUp(handler); + + peer.writeAndFlush(holder).sync(); + + assertTrue(free.await(10, TimeUnit.SECONDS)); + assertTrue(handler.called); + } + + private static final class StringInboundHandler extends ChannelInboundHandlerAdapter { + boolean called; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + called = true; + if (!(msg instanceof String)) { + ctx.fireChannelRead(msg); + } + } + } + + @Test + public void testRemoveChannelHandler() { + ChannelPipeline pipeline = new 
LocalChannel().pipeline(); + + ChannelHandler handler1 = newHandler(); + ChannelHandler handler2 = newHandler(); + ChannelHandler handler3 = newHandler(); + + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + pipeline.addLast("handler3", handler3); + assertSame(pipeline.get("handler1"), handler1); + assertSame(pipeline.get("handler2"), handler2); + assertSame(pipeline.get("handler3"), handler3); + + pipeline.remove(handler1); + assertNull(pipeline.get("handler1")); + pipeline.remove(handler2); + assertNull(pipeline.get("handler2")); + pipeline.remove(handler3); + assertNull(pipeline.get("handler3")); + } + + @Test + public void testRemoveIfExists() { + DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel()); + + ChannelHandler handler1 = newHandler(); + ChannelHandler handler2 = newHandler(); + ChannelHandler handler3 = newHandler(); + + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + pipeline.addLast("handler3", handler3); + + assertNotNull(pipeline.removeIfExists(handler1)); + assertNull(pipeline.get("handler1")); + + assertNotNull(pipeline.removeIfExists("handler2")); + assertNull(pipeline.get("handler2")); + + assertNotNull(pipeline.removeIfExists(TestHandler.class)); + assertNull(pipeline.get("handler3")); + } + + @Test + public void testRemoveIfExistsDoesNotThrowException() { + DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel()); + + ChannelHandler handler1 = newHandler(); + ChannelHandler handler2 = newHandler(); + pipeline.addLast("handler1", handler1); + + assertNull(pipeline.removeIfExists("handlerXXX")); + assertNull(pipeline.removeIfExists(handler2)); + assertNull(pipeline.removeIfExists(ChannelOutboundHandlerAdapter.class)); + assertNotNull(pipeline.get("handler1")); + } + + @Test + public void testRemoveThrowNoSuchElementException() { + final DefaultChannelPipeline pipeline = new DefaultChannelPipeline(new LocalChannel()); + + 
ChannelHandler handler1 = newHandler(); + pipeline.addLast("handler1", handler1); + + assertThrows(NoSuchElementException.class, new Executable() { + @Override + public void execute() { + pipeline.remove("handlerXXX"); + } + }); + } + + @Test + public void testReplaceChannelHandler() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelHandler handler1 = newHandler(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler1); + pipeline.addLast("handler3", handler1); + assertSame(pipeline.get("handler1"), handler1); + assertSame(pipeline.get("handler2"), handler1); + assertSame(pipeline.get("handler3"), handler1); + + ChannelHandler newHandler1 = newHandler(); + pipeline.replace("handler1", "handler1", newHandler1); + assertSame(pipeline.get("handler1"), newHandler1); + + ChannelHandler newHandler3 = newHandler(); + pipeline.replace("handler3", "handler3", newHandler3); + assertSame(pipeline.get("handler3"), newHandler3); + + ChannelHandler newHandler2 = newHandler(); + pipeline.replace("handler2", "handler2", newHandler2); + assertSame(pipeline.get("handler2"), newHandler2); + } + + @Test + public void testReplaceHandlerChecksDuplicateNames() { + final ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelHandler handler1 = newHandler(); + ChannelHandler handler2 = newHandler(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler2); + + final ChannelHandler newHandler1 = newHandler(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + pipeline.replace("handler1", "handler2", newHandler1); + } + }); + } + + @Test + public void testReplaceNameWithGenerated() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelHandler handler1 = newHandler(); + pipeline.addLast("handler1", handler1); + assertSame(pipeline.get("handler1"), handler1); + + ChannelHandler newHandler1 = newHandler(); + 
pipeline.replace("handler1", null, newHandler1); + assertSame(pipeline.get("DefaultChannelPipelineTest$TestHandler#0"), newHandler1); + assertNull(pipeline.get("handler1")); + } + + @Test + public void testRenameChannelHandler() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + ChannelHandler handler1 = newHandler(); + pipeline.addLast("handler1", handler1); + pipeline.addLast("handler2", handler1); + pipeline.addLast("handler3", handler1); + assertSame(pipeline.get("handler1"), handler1); + assertSame(pipeline.get("handler2"), handler1); + assertSame(pipeline.get("handler3"), handler1); + + ChannelHandler newHandler1 = newHandler(); + pipeline.replace("handler1", "newHandler1", newHandler1); + assertSame(pipeline.get("newHandler1"), newHandler1); + assertNull(pipeline.get("handler1")); + + ChannelHandler newHandler3 = newHandler(); + pipeline.replace("handler3", "newHandler3", newHandler3); + assertSame(pipeline.get("newHandler3"), newHandler3); + assertNull(pipeline.get("handler3")); + + ChannelHandler newHandler2 = newHandler(); + pipeline.replace("handler2", "newHandler2", newHandler2); + assertSame(pipeline.get("newHandler2"), newHandler2); + assertNull(pipeline.get("handler2")); + } + + @Test + public void testChannelHandlerContextNavigation() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + final int HANDLER_ARRAY_LEN = 5; + ChannelHandler[] firstHandlers = newHandlers(HANDLER_ARRAY_LEN); + ChannelHandler[] lastHandlers = newHandlers(HANDLER_ARRAY_LEN); + + pipeline.addFirst(firstHandlers); + pipeline.addLast(lastHandlers); + + verifyContextNumber(pipeline, HANDLER_ARRAY_LEN * 2); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testThrowInExceptionCaught() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + Channel channel = new LocalChannel(); + try { + group.register(channel).syncUninterruptibly(); + 
channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + class TestException extends Exception { } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + throw new TestException(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TestException) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + } + counter.incrementAndGet(); + throw new Exception(); + } + }); + + channel.pipeline().fireChannelReadComplete(); + latch.await(); + assertEquals(1, counter.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testThrowInOtherHandlerAfterInvokedFromExceptionCaught() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger counter = new AtomicInteger(); + Channel channel = new LocalChannel(); + try { + group.register(channel).syncUninterruptibly(); + channel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.fireChannelReadComplete(); + } + }, new ChannelInboundHandlerAdapter() { + class TestException extends Exception { } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + throw new TestException(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof TestException) { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + } + counter.incrementAndGet(); + throw new Exception(); + } + }); + + channel.pipeline().fireExceptionCaught(new Exception()); + latch.await(); + assertEquals(1, counter.get()); + } finally { + channel.close().syncUninterruptibly(); + } + } + + 
@Test + public void testFireChannelRegistered() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + latch.countDown(); + } + }); + } + }); + group.register(pipeline.channel()); + assertTrue(latch.await(2, TimeUnit.SECONDS)); + } + + @Test + public void testPipelineOperation() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + final int handlerNum = 5; + ChannelHandler[] handlers1 = newHandlers(handlerNum); + ChannelHandler[] handlers2 = newHandlers(handlerNum); + + final String prefixX = "x"; + for (int i = 0; i < handlerNum; i++) { + if (i % 2 == 0) { + pipeline.addFirst(prefixX + i, handlers1[i]); + } else { + pipeline.addLast(prefixX + i, handlers1[i]); + } + } + + for (int i = 0; i < handlerNum; i++) { + if (i % 2 != 0) { + pipeline.addBefore(prefixX + i, String.valueOf(i), handlers2[i]); + } else { + pipeline.addAfter(prefixX + i, String.valueOf(i), handlers2[i]); + } + } + + verifyContextNumber(pipeline, handlerNum * 2); + } + + @Test + public void testChannelHandlerContextOrder() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + + pipeline.addFirst("1", newHandler()); + pipeline.addLast("10", newHandler()); + + pipeline.addBefore("10", "5", newHandler()); + pipeline.addAfter("1", "3", newHandler()); + pipeline.addBefore("5", "4", newHandler()); + pipeline.addAfter("5", "6", newHandler()); + + pipeline.addBefore("1", "0", newHandler()); + pipeline.addAfter("10", "11", newHandler()); + + AbstractChannelHandlerContext ctx = (AbstractChannelHandlerContext) pipeline.firstContext(); + assertNotNull(ctx); + while (ctx != null) { + int i = toInt(ctx.name()); + int j = next(ctx); + if (j != -1) { + assertTrue(i < j); + } else 
{ + assertNull(ctx.next.next); + } + ctx = ctx.next; + } + + verifyContextNumber(pipeline, 8); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testLifeCycleAwareness() throws Exception { + setUp(); + + ChannelPipeline p = self.pipeline(); + + final List handlers = new ArrayList(); + final int COUNT = 20; + final CountDownLatch addLatch = new CountDownLatch(COUNT); + for (int i = 0; i < COUNT; i++) { + final LifeCycleAwareTestHandler handler = new LifeCycleAwareTestHandler("handler-" + i); + + // Add handler. + p.addFirst(handler.name, handler); + self.eventLoop().execute(new Runnable() { + @Override + public void run() { + // Validate handler life-cycle methods called. + handler.validate(true, false); + + // Store handler into the list. + handlers.add(handler); + + addLatch.countDown(); + } + }); + } + addLatch.await(); + + // Change the order of remove operations over all handlers in the pipeline. + Collections.shuffle(handlers); + + final CountDownLatch removeLatch = new CountDownLatch(COUNT); + + for (final LifeCycleAwareTestHandler handler : handlers) { + assertSame(handler, p.remove(handler.name)); + + self.eventLoop().execute(new Runnable() { + @Override + public void run() { + // Validate handler life-cycle methods called. 
+ handler.validate(true, true); + removeLatch.countDown(); + } + }); + } + removeLatch.await(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testRemoveAndForwardInbound() throws Exception { + final BufferedTestHandler handler1 = new BufferedTestHandler(); + final BufferedTestHandler handler2 = new BufferedTestHandler(); + + setUp(handler1, handler2); + + self.eventLoop().submit(new Runnable() { + @Override + public void run() { + ChannelPipeline p = self.pipeline(); + handler1.inboundBuffer.add(8); + assertEquals(8, handler1.inboundBuffer.peek()); + assertTrue(handler2.inboundBuffer.isEmpty()); + p.remove(handler1); + assertEquals(1, handler2.inboundBuffer.size()); + assertEquals(8, handler2.inboundBuffer.peek()); + } + }).sync(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testRemoveAndForwardOutbound() throws Exception { + final BufferedTestHandler handler1 = new BufferedTestHandler(); + final BufferedTestHandler handler2 = new BufferedTestHandler(); + + setUp(handler1, handler2); + + self.eventLoop().submit(new Runnable() { + @Override + public void run() { + ChannelPipeline p = self.pipeline(); + handler2.outboundBuffer.add(8); + assertEquals(8, handler2.outboundBuffer.peek()); + assertTrue(handler1.outboundBuffer.isEmpty()); + p.remove(handler2); + assertEquals(1, handler1.outboundBuffer.size()); + assertEquals(8, handler1.outboundBuffer.peek()); + } + }).sync(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testReplaceAndForwardOutbound() throws Exception { + final BufferedTestHandler handler1 = new BufferedTestHandler(); + final BufferedTestHandler handler2 = new BufferedTestHandler(); + + setUp(handler1); + + self.eventLoop().submit(new Runnable() { + @Override + public void run() { + ChannelPipeline p = self.pipeline(); + handler1.outboundBuffer.add(8); + assertEquals(8, handler1.outboundBuffer.peek()); + 
assertTrue(handler2.outboundBuffer.isEmpty()); + p.replace(handler1, "handler2", handler2); + assertEquals(8, handler2.outboundBuffer.peek()); + } + }).sync(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testReplaceAndForwardInboundAndOutbound() throws Exception { + final BufferedTestHandler handler1 = new BufferedTestHandler(); + final BufferedTestHandler handler2 = new BufferedTestHandler(); + + setUp(handler1); + + self.eventLoop().submit(new Runnable() { + @Override + public void run() { + ChannelPipeline p = self.pipeline(); + handler1.inboundBuffer.add(8); + handler1.outboundBuffer.add(8); + + assertEquals(8, handler1.inboundBuffer.peek()); + assertEquals(8, handler1.outboundBuffer.peek()); + assertTrue(handler2.inboundBuffer.isEmpty()); + assertTrue(handler2.outboundBuffer.isEmpty()); + + p.replace(handler1, "handler2", handler2); + assertEquals(8, handler2.outboundBuffer.peek()); + assertEquals(8, handler2.inboundBuffer.peek()); + } + }).sync(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testRemoveAndForwardInboundOutbound() throws Exception { + final BufferedTestHandler handler1 = new BufferedTestHandler(); + final BufferedTestHandler handler2 = new BufferedTestHandler(); + final BufferedTestHandler handler3 = new BufferedTestHandler(); + + setUp(handler1, handler2, handler3); + + self.eventLoop().submit(new Runnable() { + @Override + public void run() { + ChannelPipeline p = self.pipeline(); + handler2.inboundBuffer.add(8); + handler2.outboundBuffer.add(8); + + assertEquals(8, handler2.inboundBuffer.peek()); + assertEquals(8, handler2.outboundBuffer.peek()); + + assertEquals(0, handler1.outboundBuffer.size()); + assertEquals(0, handler3.inboundBuffer.size()); + + p.remove(handler2); + assertEquals(8, handler3.inboundBuffer.peek()); + assertEquals(8, handler1.outboundBuffer.peek()); + } + }).sync(); + } + + // Tests for https://github.com/netty/netty/issues/2349 + @Test + 
public void testCancelBind() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.bind(new LocalAddress("test"), promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelConnect() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.connect(new LocalAddress("test"), promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelDisconnect() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.disconnect(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelClose() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.close(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testWrongPromiseChannel() throws Exception { + final ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()).sync(); + + ChannelPipeline pipeline2 = new LocalChannel().pipeline(); + group.register(pipeline2.channel()).sync(); + + try { + final ChannelPromise promise2 = pipeline2.channel().newPromise(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + pipeline.close(promise2); + } + }); + } finally { + pipeline.close(); + pipeline2.close(); + } + } + + @Test + public void 
testUnexpectedVoidChannelPromise() throws Exception { + final ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()).sync(); + + try { + final ChannelPromise promise = new VoidChannelPromise(pipeline.channel(), false); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + pipeline.close(promise); + } + }); + } finally { + pipeline.close(); + } + } + + @Test + public void testUnexpectedVoidChannelPromiseCloseFuture() throws Exception { + final ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()).sync(); + + try { + final ChannelPromise promise = (ChannelPromise) pipeline.channel().closeFuture(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + pipeline.close(promise); + } + }); + } finally { + pipeline.close(); + } + } + + @Test + public void testCancelDeregister() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ChannelFuture future = pipeline.deregister(promise); + assertTrue(future.isCancelled()); + } + + @Test + public void testCancelWrite() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ByteBuf buffer = Unpooled.buffer(); + assertEquals(1, buffer.refCnt()); + ChannelFuture future = pipeline.write(buffer, promise); + assertTrue(future.isCancelled()); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testCancelWriteAndFlush() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()); + + ChannelPromise promise = pipeline.channel().newPromise(); + assertTrue(promise.cancel(false)); + ByteBuf buffer = Unpooled.buffer(); + 
assertEquals(1, buffer.refCnt()); + ChannelFuture future = pipeline.writeAndFlush(buffer, promise); + assertTrue(future.isCancelled()); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testFirstContextEmptyPipeline() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + assertNull(pipeline.firstContext()); + } + + @Test + public void testLastContextEmptyPipeline() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + assertNull(pipeline.lastContext()); + } + + @Test + public void testFirstHandlerEmptyPipeline() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + assertNull(pipeline.first()); + } + + @Test + public void testLastHandlerEmptyPipeline() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + assertNull(pipeline.last()); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testChannelInitializerException() throws Exception { + final IllegalStateException exception = new IllegalStateException(); + final AtomicReference error = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + throw exception; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); + error.set(cause); + latch.countDown(); + } + }); + latch.await(); + assertFalse(channel.isActive()); + assertSame(exception, error.get()); + } + + @Test + public void testChannelUnregistrationWithCustomExecutor() throws Exception { + final CountDownLatch channelLatch = new CountDownLatch(1); + final CountDownLatch handlerLatch = new CountDownLatch(1); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new WrapperExecutor(), + new 
ChannelInboundHandlerAdapter() { + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { + channelLatch.countDown(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + handlerLatch.countDown(); + } + }); + } + }); + Channel channel = pipeline.channel(); + group.register(channel); + channel.close(); + channel.deregister(); + assertTrue(channelLatch.await(2, TimeUnit.SECONDS)); + assertTrue(handlerLatch.await(2, TimeUnit.SECONDS)); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddHandlerBeforeRegisteredThenRemove() { + final EventLoop loop = group.next(); + + CheckEventExecutorHandler handler = new CheckEventExecutorHandler(loop); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addFirst(handler); + assertFalse(handler.addedPromise.isDone()); + group.register(pipeline.channel()); + handler.addedPromise.syncUninterruptibly(); + pipeline.remove(handler); + handler.removedPromise.syncUninterruptibly(); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddHandlerBeforeRegisteredThenReplace() throws Exception { + final EventLoop loop = group.next(); + final CountDownLatch latch = new CountDownLatch(1); + + CheckEventExecutorHandler handler = new CheckEventExecutorHandler(loop); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addFirst(handler); + assertFalse(handler.addedPromise.isDone()); + group.register(pipeline.channel()); + handler.addedPromise.syncUninterruptibly(); + pipeline.replace(handler, null, new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + latch.countDown(); + } + }); + handler.removedPromise.syncUninterruptibly(); + latch.await(); + } + + @Test + public void testAddRemoveHandlerNotRegistered() throws Throwable { + final AtomicReference error = new AtomicReference(); + ChannelHandler handler = new ErrorChannelHandler(error); + 
ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addFirst(handler); + pipeline.remove(handler); + + Throwable cause = error.get(); + if (cause != null) { + throw cause; + } + } + + @Test + public void testAddReplaceHandlerNotRegistered() throws Throwable { + final AtomicReference error = new AtomicReference(); + ChannelHandler handler = new ErrorChannelHandler(error); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addFirst(handler); + pipeline.replace(handler, null, new ErrorChannelHandler(error)); + + Throwable cause = error.get(); + if (cause != null) { + throw cause; + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHandlerAddedAndRemovedCalledInCorrectOrder() throws Throwable { + final EventExecutorGroup group1 = new DefaultEventExecutorGroup(1); + final EventExecutorGroup group2 = new DefaultEventExecutorGroup(1); + + try { + BlockingQueue addedQueue = new LinkedBlockingQueue(); + BlockingQueue removedQueue = new LinkedBlockingQueue(); + + CheckOrderHandler handler1 = new CheckOrderHandler(addedQueue, removedQueue); + CheckOrderHandler handler2 = new CheckOrderHandler(addedQueue, removedQueue); + CheckOrderHandler handler3 = new CheckOrderHandler(addedQueue, removedQueue); + CheckOrderHandler handler4 = new CheckOrderHandler(addedQueue, removedQueue); + + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(handler1); + group.register(pipeline.channel()).syncUninterruptibly(); + pipeline.addLast(group1, handler2); + pipeline.addLast(group2, handler3); + pipeline.addLast(handler4); + + assertTrue(removedQueue.isEmpty()); + pipeline.channel().close().syncUninterruptibly(); + assertHandler(addedQueue.take(), handler1); + + // Depending on timing this can be handler2 or handler3 as these use different EventExecutorGroups. 
+ assertHandler(addedQueue.take(), handler2, handler3, handler4); + assertHandler(addedQueue.take(), handler2, handler3, handler4); + assertHandler(addedQueue.take(), handler2, handler3, handler4); + + assertTrue(addedQueue.isEmpty()); + + assertHandler(removedQueue.take(), handler4); + assertHandler(removedQueue.take(), handler3); + assertHandler(removedQueue.take(), handler2); + assertHandler(removedQueue.take(), handler1); + assertTrue(removedQueue.isEmpty()); + } finally { + group1.shutdownGracefully(); + group2.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHandlerAddedExceptionFromChildHandlerIsPropagated() { + final EventExecutorGroup group1 = new DefaultEventExecutorGroup(1); + try { + final Promise promise = group1.next().newPromise(); + final AtomicBoolean handlerAdded = new AtomicBoolean(); + final Exception exception = new RuntimeException(); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(group1, new CheckExceptionHandler(exception, promise)); + pipeline.addFirst(new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + handlerAdded.set(true); + throw exception; + } + }); + assertFalse(handlerAdded.get()); + group.register(pipeline.channel()); + promise.syncUninterruptibly(); + } finally { + group1.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHandlerRemovedExceptionFromChildHandlerIsPropagated() { + final EventExecutorGroup group1 = new DefaultEventExecutorGroup(1); + try { + final Promise promise = group1.next().newPromise(); + String handlerName = "foo"; + final Exception exception = new RuntimeException(); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(handlerName, new ChannelHandlerAdapter() { + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + throw exception; 
+ } + }); + pipeline.addLast(group1, new CheckExceptionHandler(exception, promise)); + group.register(pipeline.channel()).syncUninterruptibly(); + pipeline.remove(handlerName); + promise.syncUninterruptibly(); + } finally { + group1.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHandlerAddedThrowsAndRemovedThrowsException() throws InterruptedException { + final EventExecutorGroup group1 = new DefaultEventExecutorGroup(1); + try { + final CountDownLatch latch = new CountDownLatch(1); + final Promise promise = group1.next().newPromise(); + final Exception exceptionAdded = new RuntimeException(); + final Exception exceptionRemoved = new RuntimeException(); + String handlerName = "foo"; + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(group1, new CheckExceptionHandler(exceptionAdded, promise)); + pipeline.addFirst(handlerName, new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + throw exceptionAdded; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // Execute this later so we are sure the exception is handled first. 
+ ctx.executor().execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + throw exceptionRemoved; + } + }); + group.register(pipeline.channel()).syncUninterruptibly(); + latch.await(); + assertNull(pipeline.context(handlerName)); + promise.syncUninterruptibly(); + } finally { + group1.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testAddRemoveHandlerCalledOnceRegistered() throws Throwable { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + CallbackCheckHandler handler = new CallbackCheckHandler(); + + pipeline.addFirst(handler); + pipeline.remove(handler); + + assertNull(handler.addedHandler.getNow()); + assertNull(handler.removedHandler.getNow()); + + group.register(pipeline.channel()).syncUninterruptibly(); + Throwable cause = handler.error.get(); + if (cause != null) { + throw cause; + } + + assertTrue(handler.addedHandler.get()); + assertTrue(handler.removedHandler.get()); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddReplaceHandlerCalledOnceRegistered() throws Throwable { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + CallbackCheckHandler handler = new CallbackCheckHandler(); + CallbackCheckHandler handler2 = new CallbackCheckHandler(); + + pipeline.addFirst(handler); + pipeline.replace(handler, null, handler2); + + assertNull(handler.addedHandler.getNow()); + assertNull(handler.removedHandler.getNow()); + assertNull(handler2.addedHandler.getNow()); + assertNull(handler2.removedHandler.getNow()); + + group.register(pipeline.channel()).syncUninterruptibly(); + Throwable cause = handler.error.get(); + if (cause != null) { + throw cause; + } + + assertTrue(handler.addedHandler.get()); + assertTrue(handler.removedHandler.get()); + + Throwable cause2 = handler2.error.get(); + if (cause2 != null) { + throw cause2; + } + + assertTrue(handler2.addedHandler.get()); + 
assertNull(handler2.removedHandler.getNow()); + pipeline.remove(handler2); + assertTrue(handler2.removedHandler.get()); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddBefore() throws Throwable { + ChannelPipeline pipeline1 = new LocalChannel().pipeline(); + ChannelPipeline pipeline2 = new LocalChannel().pipeline(); + + EventLoopGroup defaultGroup = new DefaultEventLoopGroup(2); + try { + EventLoop eventLoop1 = defaultGroup.next(); + EventLoop eventLoop2 = defaultGroup.next(); + + eventLoop1.register(pipeline1.channel()).syncUninterruptibly(); + eventLoop2.register(pipeline2.channel()).syncUninterruptibly(); + + CountDownLatch latch = new CountDownLatch(2 * 10); + for (int i = 0; i < 10; i++) { + eventLoop1.execute(new TestTask(pipeline2, latch)); + eventLoop2.execute(new TestTask(pipeline1, latch)); + } + latch.await(); + } finally { + defaultGroup.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddInListenerNio() { + testAddInListener(new NioSocketChannel(), new NioEventLoopGroup(1)); + } + + @SuppressWarnings("deprecation") + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddInListenerOio() { + testAddInListener(new OioSocketChannel(), new OioEventLoopGroup(1)); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testAddInListenerLocal() { + testAddInListener(new LocalChannel(), new DefaultEventLoopGroup(1)); + } + + private static void testAddInListener(Channel channel, EventLoopGroup group) { + ChannelPipeline pipeline1 = channel.pipeline(); + try { + final Object event = new Object(); + final Promise promise = ImmediateEventExecutor.INSTANCE.newPromise(); + group.register(pipeline1.channel()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + ChannelPipeline pipeline = future.channel().pipeline(); + final AtomicBoolean 
handlerAddedCalled = new AtomicBoolean(); + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handlerAddedCalled.set(true); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + promise.setSuccess(event); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + promise.setFailure(cause); + } + }); + if (!handlerAddedCalled.get()) { + promise.setFailure(new AssertionError("handlerAdded(...) should have been called")); + return; + } + // This event must be captured by the added handler. + pipeline.fireUserEventTriggered(event); + } + }); + assertSame(event, promise.syncUninterruptibly().getNow()); + } finally { + pipeline1.channel().close().syncUninterruptibly(); + group.shutdownGracefully(); + } + } + + @Test + public void testNullName() { + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.addLast(newHandler()); + pipeline.addLast(null, newHandler()); + pipeline.addFirst(newHandler()); + pipeline.addFirst(null, newHandler()); + + pipeline.addLast("test", newHandler()); + pipeline.addAfter("test", null, newHandler()); + + pipeline.addBefore("test", null, newHandler()); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testUnorderedEventExecutor() throws Throwable { + ChannelPipeline pipeline1 = new LocalChannel().pipeline(); + EventExecutorGroup eventExecutors = new UnorderedThreadPoolEventExecutor(2); + EventLoopGroup defaultGroup = new DefaultEventLoopGroup(1); + try { + EventLoop eventLoop1 = defaultGroup.next(); + eventLoop1.register(pipeline1.channel()).syncUninterruptibly(); + final CountDownLatch latch = new CountDownLatch(1); + pipeline1.addLast(eventExecutors, new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + // Just block one of the two threads. 
+ LockSupport.park(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + latch.countDown(); + } + }); + // Trigger an event, as we use UnorderedEventExecutor userEventTriggered should be called even when + // handlerAdded(...) blocks. + pipeline1.fireUserEventTriggered(""); + latch.await(); + } finally { + defaultGroup.shutdownGracefully(0, 0, TimeUnit.SECONDS).syncUninterruptibly(); + eventExecutors.shutdownGracefully(0, 0, TimeUnit.SECONDS).syncUninterruptibly(); + } + } + + @Test + public void testPinExecutor() { + EventExecutorGroup group = new DefaultEventExecutorGroup(2); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + ChannelPipeline pipeline2 = new LocalChannel().pipeline(); + + pipeline.addLast(group, "h1", new ChannelInboundHandlerAdapter()); + pipeline.addLast(group, "h2", new ChannelInboundHandlerAdapter()); + pipeline2.addLast(group, "h3", new ChannelInboundHandlerAdapter()); + + EventExecutor executor1 = pipeline.context("h1").executor(); + EventExecutor executor2 = pipeline.context("h2").executor(); + assertNotNull(executor1); + assertNotNull(executor2); + assertSame(executor1, executor2); + EventExecutor executor3 = pipeline2.context("h3").executor(); + assertNotNull(executor3); + assertNotSame(executor3, executor2); + group.shutdownGracefully(0, 0, TimeUnit.SECONDS); + } + + @Test + public void testNotPinExecutor() { + EventExecutorGroup group = new DefaultEventExecutorGroup(2); + ChannelPipeline pipeline = new LocalChannel().pipeline(); + pipeline.channel().config().setOption(ChannelOption.SINGLE_EVENTEXECUTOR_PER_GROUP, false); + + pipeline.addLast(group, "h1", new ChannelInboundHandlerAdapter()); + pipeline.addLast(group, "h2", new ChannelInboundHandlerAdapter()); + + EventExecutor executor1 = pipeline.context("h1").executor(); + EventExecutor executor2 = pipeline.context("h2").executor(); + assertNotNull(executor1); + assertNotNull(executor2); + assertNotSame(executor1, executor2); + 
group.shutdownGracefully(0, 0, TimeUnit.SECONDS); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testVoidPromiseNotify() { + ChannelPipeline pipeline1 = new LocalChannel().pipeline(); + + EventLoopGroup defaultGroup = new DefaultEventLoopGroup(1); + EventLoop eventLoop1 = defaultGroup.next(); + final Promise promise = eventLoop1.newPromise(); + final Exception exception = new IllegalArgumentException(); + try { + eventLoop1.register(pipeline1.channel()).syncUninterruptibly(); + pipeline1.addLast(new ChannelDuplexHandler() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + throw exception; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + promise.setSuccess(cause); + } + }); + pipeline1.write("test", pipeline1.voidPromise()); + assertSame(exception, promise.syncUninterruptibly().getNow()); + } finally { + pipeline1.channel().close().syncUninterruptibly(); + defaultGroup.shutdownGracefully(); + } + } + + // Test for https://github.com/netty/netty/issues/8676. + @Test + public void testHandlerRemovedOnlyCalledWhenHandlerAddedCalled() throws Exception { + EventLoopGroup group = new DefaultEventLoopGroup(1); + try { + final AtomicReference errorRef = new AtomicReference(); + + // As this only happens via a race we will verify 500 times. This was good enough to have it failed most of + // the time. + for (int i = 0; i < 500; i++) { + + ChannelPipeline pipeline = new LocalChannel().pipeline(); + group.register(pipeline.channel()).sync(); + + final CountDownLatch latch = new CountDownLatch(1); + + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + // Block just for a bit so we have a chance to trigger the race mentioned in the issue. 
+ latch.await(50, TimeUnit.MILLISECONDS); + } + }); + + // Close the pipeline which will call destroy0(). This will remove each handler in the pipeline and + // should call handlerRemoved(...) if and only if handlerAdded(...) was called for the handler before. + pipeline.close(); + + pipeline.addLast(new ChannelInboundHandlerAdapter() { + private boolean handerAddedCalled; + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + handerAddedCalled = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + if (!handerAddedCalled) { + errorRef.set(new AssertionError( + "handlerRemoved(...) called without handlerAdded(...) before")); + } + } + }); + + latch.countDown(); + + pipeline.channel().closeFuture().syncUninterruptibly(); + + // Schedule something on the EventLoop to ensure all other scheduled tasks had a chance to complete. + pipeline.channel().eventLoop().submit(new Runnable() { + @Override + public void run() { + // NOOP + } + }).syncUninterruptibly(); + Error error = errorRef.get(); + if (error != null) { + throw error; + } + } + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testSkipHandlerMethodsIfAnnotated() { + EmbeddedChannel channel = new EmbeddedChannel(true); + ChannelPipeline pipeline = channel.pipeline(); + + final class SkipHandler implements ChannelInboundHandler, ChannelOutboundHandler { + private int state = 2; + private Error errorRef; + + private void fail() { + errorRef = new AssertionError("Method should never been called"); + } + + @Skip + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + fail(); + ctx.bind(localAddress, promise); + } + + @Skip + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + fail(); + ctx.connect(remoteAddress, localAddress, promise); + } + + @Skip + @Override + public void 
disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + fail(); + ctx.disconnect(promise); + } + + @Skip + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + fail(); + ctx.close(promise); + } + + @Skip + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + fail(); + ctx.deregister(promise); + } + + @Skip + @Override + public void read(ChannelHandlerContext ctx) { + fail(); + ctx.read(); + } + + @Skip + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + fail(); + ctx.write(msg, promise); + } + + @Skip + @Override + public void flush(ChannelHandlerContext ctx) { + fail(); + ctx.flush(); + } + + @Skip + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + fail(); + ctx.fireChannelRegistered(); + } + + @Skip + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { + fail(); + ctx.fireChannelUnregistered(); + } + + @Skip + @Override + public void channelActive(ChannelHandlerContext ctx) { + fail(); + ctx.fireChannelActive(); + } + + @Skip + @Override + public void channelInactive(ChannelHandlerContext ctx) { + fail(); + ctx.fireChannelInactive(); + } + + @Skip + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + fail(); + ctx.fireChannelRead(msg); + } + + @Skip + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + fail(); + ctx.fireChannelReadComplete(); + } + + @Skip + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + fail(); + ctx.fireUserEventTriggered(evt); + } + + @Skip + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + fail(); + ctx.fireChannelWritabilityChanged(); + } + + @Skip + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + fail(); + ctx.fireExceptionCaught(cause); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + 
state--; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + state--; + } + + void assertSkipped() { + assertEquals(0, state); + Error error = errorRef; + if (error != null) { + throw error; + } + } + } + + final class OutboundCalledHandler extends ChannelOutboundHandlerAdapter { + private static final int MASK_BIND = 1; + private static final int MASK_CONNECT = 1 << 1; + private static final int MASK_DISCONNECT = 1 << 2; + private static final int MASK_CLOSE = 1 << 3; + private static final int MASK_DEREGISTER = 1 << 4; + private static final int MASK_READ = 1 << 5; + private static final int MASK_WRITE = 1 << 6; + private static final int MASK_FLUSH = 1 << 7; + private static final int MASK_ADDED = 1 << 8; + private static final int MASK_REMOVED = 1 << 9; + + private int executionMask; + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + executionMask |= MASK_ADDED; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + executionMask |= MASK_REMOVED; + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) { + executionMask |= MASK_BIND; + promise.setSuccess(); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + executionMask |= MASK_CONNECT; + promise.setSuccess(); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) { + executionMask |= MASK_DISCONNECT; + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) { + executionMask |= MASK_CLOSE; + promise.setSuccess(); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) { + executionMask |= MASK_DEREGISTER; + promise.setSuccess(); + } + + @Override + public void read(ChannelHandlerContext ctx) { + executionMask |= MASK_READ; + } + + @Override + public 
void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + executionMask |= MASK_WRITE; + promise.setSuccess(); + } + + @Override + public void flush(ChannelHandlerContext ctx) { + executionMask |= MASK_FLUSH; + } + + void assertCalled() { + assertCalled("handlerAdded", MASK_ADDED); + assertCalled("handlerRemoved", MASK_REMOVED); + assertCalled("bind", MASK_BIND); + assertCalled("connect", MASK_CONNECT); + assertCalled("disconnect", MASK_DISCONNECT); + assertCalled("close", MASK_CLOSE); + assertCalled("deregister", MASK_DEREGISTER); + assertCalled("read", MASK_READ); + assertCalled("write", MASK_WRITE); + assertCalled("flush", MASK_FLUSH); + } + + private void assertCalled(String methodName, int mask) { + assertTrue((executionMask & mask) != 0, methodName + " was not called"); + } + } + + final class InboundCalledHandler extends ChannelInboundHandlerAdapter { + + private static final int MASK_CHANNEL_REGISTER = 1; + private static final int MASK_CHANNEL_UNREGISTER = 1 << 1; + private static final int MASK_CHANNEL_ACTIVE = 1 << 2; + private static final int MASK_CHANNEL_INACTIVE = 1 << 3; + private static final int MASK_CHANNEL_READ = 1 << 4; + private static final int MASK_CHANNEL_READ_COMPLETE = 1 << 5; + private static final int MASK_USER_EVENT_TRIGGERED = 1 << 6; + private static final int MASK_CHANNEL_WRITABILITY_CHANGED = 1 << 7; + private static final int MASK_EXCEPTION_CAUGHT = 1 << 8; + private static final int MASK_ADDED = 1 << 9; + private static final int MASK_REMOVED = 1 << 10; + + private int executionMask; + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + executionMask |= MASK_ADDED; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + executionMask |= MASK_REMOVED; + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + executionMask |= MASK_CHANNEL_REGISTER; + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) { + executionMask |= 
MASK_CHANNEL_UNREGISTER; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + executionMask |= MASK_CHANNEL_ACTIVE; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + executionMask |= MASK_CHANNEL_INACTIVE; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + executionMask |= MASK_CHANNEL_READ; + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + executionMask |= MASK_CHANNEL_READ_COMPLETE; + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + executionMask |= MASK_USER_EVENT_TRIGGERED; + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + executionMask |= MASK_CHANNEL_WRITABILITY_CHANGED; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + executionMask |= MASK_EXCEPTION_CAUGHT; + } + + void assertCalled() { + assertCalled("handlerAdded", MASK_ADDED); + assertCalled("handlerRemoved", MASK_REMOVED); + assertCalled("channelRegistered", MASK_CHANNEL_REGISTER); + assertCalled("channelUnregistered", MASK_CHANNEL_UNREGISTER); + assertCalled("channelActive", MASK_CHANNEL_ACTIVE); + assertCalled("channelInactive", MASK_CHANNEL_INACTIVE); + assertCalled("channelRead", MASK_CHANNEL_READ); + assertCalled("channelReadComplete", MASK_CHANNEL_READ_COMPLETE); + assertCalled("userEventTriggered", MASK_USER_EVENT_TRIGGERED); + assertCalled("channelWritabilityChanged", MASK_CHANNEL_WRITABILITY_CHANGED); + assertCalled("exceptionCaught", MASK_EXCEPTION_CAUGHT); + } + + private void assertCalled(String methodName, int mask) { + assertTrue((executionMask & mask) != 0, methodName + " was not called"); + } + } + + OutboundCalledHandler outboundCalledHandler = new OutboundCalledHandler(); + SkipHandler skipHandler = new SkipHandler(); + InboundCalledHandler inboundCalledHandler = new InboundCalledHandler(); + pipeline.addLast(outboundCalledHandler, 
skipHandler, inboundCalledHandler); + + pipeline.fireChannelRegistered(); + pipeline.fireChannelUnregistered(); + pipeline.fireChannelActive(); + pipeline.fireChannelInactive(); + pipeline.fireChannelRead(""); + pipeline.fireChannelReadComplete(); + pipeline.fireChannelWritabilityChanged(); + pipeline.fireUserEventTriggered(""); + pipeline.fireExceptionCaught(new Exception()); + + pipeline.deregister().syncUninterruptibly(); + pipeline.bind(new SocketAddress() { + }).syncUninterruptibly(); + pipeline.connect(new SocketAddress() { + }).syncUninterruptibly(); + pipeline.disconnect().syncUninterruptibly(); + pipeline.close().syncUninterruptibly(); + pipeline.write(""); + pipeline.flush(); + pipeline.read(); + + pipeline.remove(outboundCalledHandler); + pipeline.remove(inboundCalledHandler); + pipeline.remove(skipHandler); + + assertFalse(channel.finish()); + + outboundCalledHandler.assertCalled(); + inboundCalledHandler.assertCalled(); + skipHandler.assertSkipped(); + } + + @Test + public void testWriteThrowsReleaseMessage() { + testWriteThrowsReleaseMessage0(false); + } + + @Test + public void testWriteAndFlushThrowsReleaseMessage() { + testWriteThrowsReleaseMessage0(true); + } + + private void testWriteThrowsReleaseMessage0(boolean flush) { + ReferenceCounted referenceCounted = new AbstractReferenceCounted() { + @Override + protected void deallocate() { + // NOOP + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + }; + assertEquals(1, referenceCounted.refCnt()); + + Channel channel = new LocalChannel(); + Channel channel2 = new LocalChannel(); + group.register(channel).syncUninterruptibly(); + group.register(channel2).syncUninterruptibly(); + + try { + if (flush) { + channel.writeAndFlush(referenceCounted, channel2.newPromise()); + } else { + channel.write(referenceCounted, channel2.newPromise()); + } + fail(); + } catch (IllegalArgumentException expected) { + // expected + } + assertEquals(0, referenceCounted.refCnt()); + + 
channel.close().syncUninterruptibly(); + channel2.close().syncUninterruptibly(); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testHandlerAddedFailedButHandlerStillRemoved() throws InterruptedException { + testHandlerAddedFailedButHandlerStillRemoved0(false); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testHandlerAddedFailedButHandlerStillRemovedWithLaterRegister() throws InterruptedException { + testHandlerAddedFailedButHandlerStillRemoved0(true); + } + + private static void testHandlerAddedFailedButHandlerStillRemoved0(boolean lateRegister) + throws InterruptedException { + EventExecutorGroup executorGroup = new DefaultEventExecutorGroup(16); + final int numHandlers = 32; + try { + Channel channel = new LocalChannel(); + channel.config().setOption(ChannelOption.SINGLE_EVENTEXECUTOR_PER_GROUP, false); + if (!lateRegister) { + group.register(channel).sync(); + } + channel.pipeline().addFirst(newHandler()); + + List latchList = new ArrayList(numHandlers); + for (int i = 0; i < numHandlers; i++) { + CountDownLatch latch = new CountDownLatch(1); + channel.pipeline().addFirst(executorGroup, "h" + i, new BadChannelHandler(latch)); + latchList.add(latch); + } + if (lateRegister) { + group.register(channel).sync(); + } + + for (int i = 0; i < numHandlers; i++) { + // Wait until the latch was countDown which means handlerRemoved(...) was called. 
+ latchList.get(i).await(); + assertNull(channel.pipeline().get("h" + i)); + } + } finally { + executorGroup.shutdownGracefully(); + } + } + + private static final class BadChannelHandler extends ChannelHandlerAdapter { + private final CountDownLatch latch; + + BadChannelHandler(CountDownLatch latch) { + this.latch = latch; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + TimeUnit.MILLISECONDS.sleep(10); + throw new RuntimeException(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + latch.countDown(); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void handlerAddedStateUpdatedBeforeHandlerAddedDoneForceEventLoop() throws InterruptedException { + handlerAddedStateUpdatedBeforeHandlerAddedDone(true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void handlerAddedStateUpdatedBeforeHandlerAddedDoneOnCallingThread() throws InterruptedException { + handlerAddedStateUpdatedBeforeHandlerAddedDone(false); + } + + private static void handlerAddedStateUpdatedBeforeHandlerAddedDone(boolean executeInEventLoop) + throws InterruptedException { + final ChannelPipeline pipeline = new LocalChannel().pipeline(); + final Object userEvent = new Object(); + final Object writeObject = new Object(); + final CountDownLatch doneLatch = new CountDownLatch(1); + + group.register(pipeline.channel()); + + Runnable r = new Runnable() { + @Override + public void run() { + pipeline.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == userEvent) { + ctx.write(writeObject); + } + ctx.fireUserEventTriggered(evt); + } + }); + pipeline.addFirst(new ChannelDuplexHandler() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ctx.fireUserEventTriggered(userEvent); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise 
promise) { + if (msg == writeObject) { + doneLatch.countDown(); + } + ctx.write(msg, promise); + } + }); + } + }; + + if (executeInEventLoop) { + pipeline.channel().eventLoop().execute(r); + } else { + r.run(); + } + + doneLatch.await(); + } + + private static final class TestTask implements Runnable { + + private final ChannelPipeline pipeline; + private final CountDownLatch latch; + + TestTask(ChannelPipeline pipeline, CountDownLatch latch) { + this.pipeline = pipeline; + this.latch = latch; + } + + @Override + public void run() { + pipeline.addLast(new ChannelInboundHandlerAdapter()); + latch.countDown(); + } + } + + private static final class CallbackCheckHandler extends ChannelHandlerAdapter { + final Promise addedHandler = ImmediateEventExecutor.INSTANCE.newPromise(); + final Promise removedHandler = ImmediateEventExecutor.INSTANCE.newPromise(); + final AtomicReference error = new AtomicReference(); + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (!addedHandler.trySuccess(true)) { + error.set(new AssertionError("handlerAdded(...) called multiple times: " + ctx.name())); + } else if (removedHandler.getNow() == Boolean.TRUE) { + error.set(new AssertionError("handlerRemoved(...) called before handlerAdded(...): " + ctx.name())); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + if (!removedHandler.trySuccess(true)) { + error.set(new AssertionError("handlerRemoved(...) called multiple times: " + ctx.name())); + } else if (addedHandler.getNow() == Boolean.FALSE) { + error.set(new AssertionError("handlerRemoved(...) 
called before handlerAdded(...): " + ctx.name())); + } + } + } + + private static final class CheckExceptionHandler extends ChannelInboundHandlerAdapter { + private final Throwable expected; + private final Promise promise; + + CheckExceptionHandler(Throwable expected, Promise promise) { + this.expected = expected; + this.promise = promise; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof ChannelPipelineException && cause.getCause() == expected) { + promise.setSuccess(null); + } else { + promise.setFailure(new AssertionError("cause not the expected instance")); + } + } + } + + private static void assertHandler(CheckOrderHandler actual, CheckOrderHandler... handlers) throws Throwable { + for (CheckOrderHandler h : handlers) { + if (h == actual) { + actual.checkError(); + return; + } + } + fail("handler was not one of the expected handlers"); + } + + private static final class CheckOrderHandler extends ChannelHandlerAdapter { + private final Queue addedQueue; + private final Queue removedQueue; + private final AtomicReference error = new AtomicReference(); + + CheckOrderHandler(Queue addedQueue, Queue removedQueue) { + this.addedQueue = addedQueue; + this.removedQueue = removedQueue; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + addedQueue.add(this); + checkExecutor(ctx); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + removedQueue.add(this); + checkExecutor(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + error.set(cause); + } + + void checkError() throws Throwable { + Throwable cause = error.get(); + if (cause != null) { + throw cause; + } + } + + private void checkExecutor(ChannelHandlerContext ctx) { + if (!ctx.executor().inEventLoop()) { + error.set(new AssertionError()); + } + } + } + + private static final class 
CheckEventExecutorHandler extends ChannelHandlerAdapter { + final EventExecutor executor; + final Promise addedPromise; + final Promise removedPromise; + + CheckEventExecutorHandler(EventExecutor executor) { + this.executor = executor; + addedPromise = executor.newPromise(); + removedPromise = executor.newPromise(); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + assertExecutor(ctx, addedPromise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + assertExecutor(ctx, removedPromise); + } + + private void assertExecutor(ChannelHandlerContext ctx, Promise promise) { + final boolean same; + try { + same = executor == ctx.executor(); + } catch (Throwable cause) { + promise.setFailure(cause); + return; + } + if (same) { + promise.setSuccess(null); + } else { + promise.setFailure(new AssertionError("EventExecutor not the same")); + } + } + } + private static final class ErrorChannelHandler extends ChannelHandlerAdapter { + private final AtomicReference error; + + ErrorChannelHandler(AtomicReference error) { + this.error = error; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + error.set(new AssertionError()); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + error.set(new AssertionError()); + } + } + + private static int next(AbstractChannelHandlerContext ctx) { + AbstractChannelHandlerContext next = ctx.next; + if (next == null) { + return Integer.MAX_VALUE; + } + + return toInt(next.name()); + } + + private static int toInt(String name) { + try { + return Integer.parseInt(name); + } catch (NumberFormatException e) { + return -1; + } + } + + private static void verifyContextNumber(ChannelPipeline pipeline, int expectedNumber) { + AbstractChannelHandlerContext ctx = (AbstractChannelHandlerContext) pipeline.firstContext(); + int handlerNumber = 0; + while (ctx != ((DefaultChannelPipeline) pipeline).tail) { + handlerNumber++; + ctx 
= ctx.next; + } + assertEquals(expectedNumber, handlerNumber); + } + + private static ChannelHandler[] newHandlers(int num) { + assert num > 0; + + ChannelHandler[] handlers = new ChannelHandler[num]; + for (int i = 0; i < num; i++) { + handlers[i] = newHandler(); + } + + return handlers; + } + + private static ChannelHandler newHandler() { + return new TestHandler(); + } + + @Sharable + private static class TestHandler extends ChannelDuplexHandler { } + + private static class BufferedTestHandler extends ChannelDuplexHandler { + final Queue inboundBuffer = new ArrayDeque(); + final Queue outboundBuffer = new ArrayDeque(); + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + outboundBuffer.add(msg); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + inboundBuffer.add(msg); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + if (!inboundBuffer.isEmpty()) { + for (Object o: inboundBuffer) { + ctx.fireChannelRead(o); + } + ctx.fireChannelReadComplete(); + } + if (!outboundBuffer.isEmpty()) { + for (Object o: outboundBuffer) { + ctx.write(o); + } + ctx.flush(); + } + } + } + + /** Test handler to validate life-cycle aware behavior. */ + private static final class LifeCycleAwareTestHandler extends ChannelHandlerAdapter { + private final String name; + + private boolean afterAdd; + private boolean afterRemove; + + /** + * Constructs life-cycle aware test handler. + * + * @param name Handler name to display in assertion messages. 
+ */ + private LifeCycleAwareTestHandler(String name) { + this.name = name; + } + + public void validate(boolean afterAdd, boolean afterRemove) { + assertEquals(afterAdd, this.afterAdd, name); + assertEquals(afterRemove, this.afterRemove, name); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + validate(false, false); + + afterAdd = true; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) { + validate(true, false); + + afterRemove = true; + } + } + + private static final class WrapperExecutor extends AbstractEventExecutor { + + private final ExecutorService wrapped = Executors.newSingleThreadExecutor(); + + @Override + public boolean isShuttingDown() { + return wrapped.isShutdown(); + } + + @Override + public Future shutdownGracefully(long l, long l2, TimeUnit timeUnit) { + throw new IllegalStateException(); + } + + @Override + public Future terminationFuture() { + throw new IllegalStateException(); + } + + @Override + public void shutdown() { + wrapped.shutdown(); + } + + @Override + public List shutdownNow() { + return wrapped.shutdownNow(); + } + + @Override + public boolean isShutdown() { + return wrapped.isShutdown(); + } + + @Override + public boolean isTerminated() { + return wrapped.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return wrapped.awaitTermination(timeout, unit); + } + + @Override + public EventExecutorGroup parent() { + return null; + } + + @Override + public boolean inEventLoop(Thread thread) { + return false; + } + + @Override + public void execute(Runnable command) { + wrapped.execute(command); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/DefaultChannelPromiseTest.java b/netty-channel/src/test/java/io/netty/channel/DefaultChannelPromiseTest.java new file mode 100644 index 0000000..36391d2 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/DefaultChannelPromiseTest.java @@ 
-0,0 +1,56 @@
+/*
+ * Copyright 2017 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.channel;
+
+import io.netty.channel.embedded.EmbeddedChannel;
+import io.netty.util.concurrent.ImmediateEventExecutor;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+
+import static org.junit.jupiter.api.Assertions.assertThrows;
+
+public class DefaultChannelPromiseTest { // verifies the null-argument contract of the DefaultChannelPromise constructors
+
+    @Test
+    public void testNullChannel() {
+        assertThrows(NullPointerException.class, new Executable() {
+            @Override
+            public void execute() {
+                new DefaultChannelPromise(null); // a null channel must be rejected eagerly
+            }
+        });
+    }
+
+    @Test
+    public void testChannelWithNullExecutor() {
+        assertThrows(NullPointerException.class, new Executable() {
+            @Override
+            public void execute() {
+                new DefaultChannelPromise(new EmbeddedChannel(), null); // a null executor must be rejected eagerly
+            }
+        });
+    }
+
+    @Test
+    public void testNullChannelWithExecutor() {
+        assertThrows(NullPointerException.class, new Executable() {
+            @Override
+            public void execute() {
+                new DefaultChannelPromise(null, ImmediateEventExecutor.INSTANCE); // channel is checked even when an executor is given
+            }
+        });
+    }
+}
diff --git a/netty-channel/src/test/java/io/netty/channel/DefaultFileRegionTest.java b/netty-channel/src/test/java/io/netty/channel/DefaultFileRegionTest.java
new file mode 100644
index 0000000..2a3aa4f
--- /dev/null
+++ b/netty-channel/src/test/java/io/netty/channel/DefaultFileRegionTest.java
@@ -0,0 +1,120 @@
+/*
+ * Copyright 2019 The
Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; +import java.nio.channels.WritableByteChannel; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class DefaultFileRegionTest { + + private static final byte[] data = new byte[1048576 * 10]; + + static { + PlatformDependent.threadLocalRandom().nextBytes(data); + } + + private static File newFile() throws IOException { + File file = PlatformDependent.createTempFile("netty-", ".tmp", null); + file.deleteOnExit(); + + final FileOutputStream out = new FileOutputStream(file); + out.write(data); + out.close(); + return file; + } + + @Test + public void testCreateFromFile() throws IOException { + File file = newFile(); + try { + testFileRegion(new DefaultFileRegion(file, 0, data.length)); + } finally { + file.delete(); + } + } + + @Test + public void testCreateFromFileChannel() throws IOException { + File file = newFile(); + RandomAccessFile randomAccessFile = new RandomAccessFile(file, "r"); + try { + 
testFileRegion(new DefaultFileRegion(randomAccessFile.getChannel(), 0, data.length)); + } finally { + randomAccessFile.close(); + file.delete(); + } + } + + private static void testFileRegion(FileRegion region) throws IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(outputStream); + + try { + assertEquals(data.length, region.count()); + assertEquals(0, region.transferred()); + assertEquals(data.length, region.transferTo(channel, 0)); + assertEquals(data.length, region.count()); + assertEquals(data.length, region.transferred()); + assertArrayEquals(data, outputStream.toByteArray()); + } finally { + channel.close(); + } + } + + @Test + public void testTruncated() throws IOException { + File file = newFile(); + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + WritableByteChannel channel = Channels.newChannel(outputStream); + RandomAccessFile randomAccessFile = new RandomAccessFile(file, "rw"); + + try { + FileRegion region = new DefaultFileRegion(randomAccessFile.getChannel(), 0, data.length); + + randomAccessFile.getChannel().truncate(data.length - 1024); + + assertEquals(data.length, region.count()); + assertEquals(0, region.transferred()); + + assertEquals(data.length - 1024, region.transferTo(channel, 0)); + assertEquals(data.length, region.count()); + assertEquals(data.length - 1024, region.transferred()); + try { + region.transferTo(channel, data.length - 1024); + fail(); + } catch (IOException expected) { + // expected + } + } finally { + channel.close(); + + randomAccessFile.close(); + file.delete(); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocatorTest.java b/netty-channel/src/test/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocatorTest.java new file mode 100644 index 0000000..4410446 --- /dev/null +++ 
b/netty-channel/src/test/java/io/netty/channel/DefaultMaxMessagesRecvByteBufAllocatorTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DefaultMaxMessagesRecvByteBufAllocatorTest { + + private DefaultMaxMessagesRecvByteBufAllocator newAllocator(boolean ignoreReadBytes) { + return new DefaultMaxMessagesRecvByteBufAllocator(2, ignoreReadBytes) { + @Override + public Handle newHandle() { + return new MaxMessageHandle() { + @Override + public int guess() { + return 0; + } + }; + } + }; + } + + @Test + public void testRespectReadBytes() { + DefaultMaxMessagesRecvByteBufAllocator allocator = newAllocator(false); + RecvByteBufAllocator.Handle handle = allocator.newHandle(); + + EmbeddedChannel channel = new EmbeddedChannel(); + handle.reset(channel.config()); + handle.incMessagesRead(1); + assertFalse(handle.continueReading()); + + handle.reset(channel.config()); + handle.incMessagesRead(1); + handle.attemptedBytesRead(1); + handle.lastBytesRead(1); + assertTrue(handle.continueReading()); + channel.finish(); + } + + @Test + public void testIgnoreReadBytes() { + DefaultMaxMessagesRecvByteBufAllocator allocator 
= newAllocator(true); + RecvByteBufAllocator.Handle handle = allocator.newHandle(); + + EmbeddedChannel channel = new EmbeddedChannel(); + handle.reset(channel.config()); + handle.incMessagesRead(1); + assertTrue(handle.continueReading()); + handle.incMessagesRead(1); + assertFalse(handle.continueReading()); + + handle.reset(channel.config()); + handle.attemptedBytesRead(0); + handle.lastBytesRead(0); + assertTrue(handle.continueReading()); + channel.finish(); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/DelegatingChannelPromiseNotifierTest.java b/netty-channel/src/test/java/io/netty/channel/DelegatingChannelPromiseNotifierTest.java new file mode 100644 index 0000000..272790d --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/DelegatingChannelPromiseNotifierTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class DelegatingChannelPromiseNotifierTest { + @Test + public void varargsNotifiersAllowed() { + ChannelPromise promise = Mockito.mock(ChannelPromise.class); + DelegatingChannelPromiseNotifier promiseNotifier = new DelegatingChannelPromiseNotifier(promise); + + GenericFutureListener> gfl = + (GenericFutureListener>) Mockito.mock(GenericFutureListener.class); + promiseNotifier.addListeners(gfl); + promiseNotifier.removeListeners(gfl); + + Mockito.verify(promise).addListeners(gfl); + Mockito.verify(promise).removeListeners(gfl); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/FailedChannelFutureTest.java b/netty-channel/src/test/java/io/netty/channel/FailedChannelFutureTest.java new file mode 100644 index 0000000..f3dc100 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/FailedChannelFutureTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mockito; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class FailedChannelFutureTest { + @Test + public void testConstantProperties() { + Channel channel = Mockito.mock(Channel.class); + Exception e = new Exception(); + FailedChannelFuture future = new FailedChannelFuture(channel, null, e); + + assertFalse(future.isSuccess()); + assertSame(e, future.cause()); + } + + @Test + public void shouldDisallowNullException() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new FailedChannelFuture(Mockito.mock(Channel.class), null, null); + } + }); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/LoggingTestHandler.java b/netty-channel/src/test/java/io/netty/channel/LoggingTestHandler.java new file mode 100644 index 0000000..b613e9f --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/LoggingTestHandler.java @@ -0,0 +1,171 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import java.net.SocketAddress; +import java.util.Collections; +import java.util.EnumSet; + +final class LoggingTestHandler implements ChannelInboundHandler, ChannelOutboundHandler { + + enum Event { WRITE, FLUSH, BIND, CONNECT, DISCONNECT, CLOSE, DEREGISTER, READ, WRITABILITY, + HANDLER_ADDED, HANDLER_REMOVED, EXCEPTION, READ_COMPLETE, REGISTERED, UNREGISTERED, ACTIVE, INACTIVE, + USER } + + private StringBuilder log = new StringBuilder(); + + private final EnumSet interest = EnumSet.allOf(Event.class); + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + log(Event.WRITE); + ctx.write(msg, promise); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + log(Event.FLUSH); + ctx.flush(); + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) + throws Exception { + log(Event.BIND, "localAddress=" + localAddress); + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + log(Event.CONNECT, "remoteAddress=" + remoteAddress + " localAddress=" + localAddress); + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + log(Event.DISCONNECT); + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + log(Event.CLOSE); + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + log(Event.DEREGISTER); + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + log(Event.READ); + ctx.read(); + } + + @Override + public void 
channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + log(Event.WRITABILITY, "writable=" + ctx.channel().isWritable()); + ctx.fireChannelWritabilityChanged(); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + log(Event.HANDLER_ADDED); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + log(Event.HANDLER_REMOVED); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + log(Event.EXCEPTION, cause.toString()); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + log(Event.REGISTERED); + ctx.fireChannelRegistered(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + log(Event.UNREGISTERED); + ctx.fireChannelUnregistered(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + log(Event.ACTIVE); + ctx.fireChannelActive(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + log(Event.INACTIVE); + ctx.fireChannelInactive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + log(Event.READ); + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + log(Event.READ_COMPLETE); + ctx.fireChannelReadComplete(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + log(Event.USER, evt.toString()); + ctx.fireUserEventTriggered(evt); + } + + String getLog() { + return log.toString(); + } + + void clear() { + log = new StringBuilder(); + } + + void setInterest(Event... 
events) { + interest.clear(); + Collections.addAll(interest, events); + } + + private void log(Event e) { + log(e, null); + } + + private void log(Event e, String msg) { + if (interest.contains(e)) { + log.append(e); + if (msg != null) { + log.append(": ").append(msg); + } + log.append('\n'); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/NativeImageHandlerMetadataTest.java b/netty-channel/src/test/java/io/netty/channel/NativeImageHandlerMetadataTest.java new file mode 100644 index 0000000..77e9ed2 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/NativeImageHandlerMetadataTest.java @@ -0,0 +1,28 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.nativeimage.ChannelHandlerMetadataUtil; +import org.junit.jupiter.api.Test; + +public class NativeImageHandlerMetadataTest { + + @Test + public void collectAndCompareMetadata() { + ChannelHandlerMetadataUtil.generateMetadata("io.netty.bootstrap", "io.netty.channel"); + } + +} diff --git a/netty-channel/src/test/java/io/netty/channel/PendingWriteQueueTest.java b/netty-channel/src/test/java/io/netty/channel/PendingWriteQueueTest.java new file mode 100644 index 0000000..3835448 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/PendingWriteQueueTest.java @@ -0,0 +1,415 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; + +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class PendingWriteQueueTest { + + @Test + public void testRemoveAndWrite() { + assertWrite(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + assertFalse(ctx.channel().isWritable(), "Should not be writable anymore"); + + ChannelFuture future = queue.removeAndWrite(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + assertQueueEmpty(queue); + } + }); + super.flush(ctx); + } + }, 1); + } + + @Test + public void testRemoveAndWriteAll() { + assertWrite(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + assertFalse(ctx.channel().isWritable(), "Should not be writable anymore"); + + ChannelFuture future = queue.removeAndWriteAll(); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + assertQueueEmpty(queue); + } + }); + super.flush(ctx); + } + }, 3); + } + + @Test + public void testRemoveAndFail() { + assertWriteFails(new TestHandler() { + + @Override + public 
void flush(ChannelHandlerContext ctx) throws Exception { + queue.removeAndFail(new TestException()); + super.flush(ctx); + } + }, 1); + } + + @Test + public void testRemoveAndFailAll() { + assertWriteFails(new TestHandler() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + queue.removeAndFailAll(new TestException()); + super.flush(ctx); + } + }, 3); + } + + @Test + public void shouldFireChannelWritabilityChangedAfterRemoval() { + final AtomicReference ctxRef = new AtomicReference(); + final AtomicReference queueRef = new AtomicReference(); + final ByteBuf msg = Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII); + + final EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ctxRef.set(ctx); + queueRef.set(new PendingWriteQueue(ctx)); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) { + final PendingWriteQueue queue = queueRef.get(); + + final ByteBuf msg = (ByteBuf) queue.current(); + if (msg == null) { + return; + } + + assertThat(msg.refCnt(), is(1)); + + // This call will trigger another channelWritabilityChanged() event because the number of + // pending bytes will go below the low watermark. + // + // If PendingWriteQueue.remove() did not remove the current entry before triggering + // channelWritabilityChanged() event, we will end up with attempting to remove the same + // element twice, resulting in the double release. + queue.remove(); + + assertThat(msg.refCnt(), is(0)); + } + }); + + channel.config().setWriteBufferLowWaterMark(1); + channel.config().setWriteBufferHighWaterMark(3); + + final PendingWriteQueue queue = queueRef.get(); + + // Trigger channelWritabilityChanged() by adding a message that's larger than the high watermark. 
+ queue.add(msg, channel.newPromise()); + + channel.finish(); + + assertThat(msg.refCnt(), is(0)); + } + + private static void assertWrite(ChannelHandler handler, int count) { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + channel.config().setWriteBufferLowWaterMark(1); + channel.config().setWriteBufferHighWaterMark(3); + + ByteBuf[] buffers = new ByteBuf[count]; + for (int i = 0; i < buffers.length; i++) { + buffers[i] = buffer.retainedDuplicate(); + } + assertTrue(channel.writeOutbound(buffers)); + assertTrue(channel.finish()); + channel.closeFuture().syncUninterruptibly(); + + for (int i = 0; i < buffers.length; i++) { + assertBuffer(channel, buffer); + } + buffer.release(); + assertNull(channel.readOutbound()); + } + + private static void assertBuffer(EmbeddedChannel channel, ByteBuf buffer) { + ByteBuf written = channel.readOutbound(); + assertEquals(buffer, written); + written.release(); + } + + private static void assertQueueEmpty(PendingWriteQueue queue) { + assertTrue(queue.isEmpty()); + assertEquals(0, queue.size()); + assertEquals(0, queue.bytes()); + assertNull(queue.current()); + assertNull(queue.removeAndWrite()); + assertNull(queue.removeAndWriteAll()); + } + + private static void assertWriteFails(ChannelHandler handler, int count) { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.US_ASCII); + final EmbeddedChannel channel = new EmbeddedChannel(handler); + ByteBuf[] buffers = new ByteBuf[count]; + for (int i = 0; i < buffers.length; i++) { + buffers[i] = buffer.retainedDuplicate(); + } + try { + assertFalse(channel.writeOutbound(buffers)); + fail(); + } catch (Exception e) { + assertTrue(e instanceof TestException); + } + assertFalse(channel.finish()); + channel.closeFuture().syncUninterruptibly(); + + buffer.release(); + assertNull(channel.readOutbound()); + } + + private static EmbeddedChannel newChannel() { + // Add a handler so we 
can access a ChannelHandlerContext via the ChannelPipeline. + return new EmbeddedChannel(new ChannelHandlerAdapter() { }); + } + + @Test + public void testRemoveAndFailAllReentrantFailAll() { + EmbeddedChannel channel = newChannel(); + final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext()); + + ChannelPromise promise = channel.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + queue.removeAndFailAll(new IllegalStateException()); + } + }); + queue.add(1L, promise); + + ChannelPromise promise2 = channel.newPromise(); + queue.add(2L, promise2); + queue.removeAndFailAll(new Exception()); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertTrue(promise2.isDone()); + assertFalse(promise2.isSuccess()); + assertFalse(channel.finish()); + } + + @Test + public void testRemoveAndWriteAllReentrantWrite() { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + // Convert to writeAndFlush(...) so the promise will be notified by the transport. 
+ ctx.writeAndFlush(msg, promise); + } + }, new ChannelOutboundHandlerAdapter()); + + final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().lastContext()); + + ChannelPromise promise = channel.newPromise(); + final ChannelPromise promise3 = channel.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + queue.add(3L, promise3); + } + }); + queue.add(1L, promise); + ChannelPromise promise2 = channel.newPromise(); + queue.add(2L, promise2); + queue.removeAndWriteAll(); + + assertTrue(promise.isDone()); + assertTrue(promise.isSuccess()); + assertTrue(promise2.isDone()); + assertTrue(promise2.isSuccess()); + assertTrue(promise3.isDone()); + assertTrue(promise3.isSuccess()); + assertTrue(channel.finish()); + assertEquals(1L, (Long) channel.readOutbound()); + assertEquals(2L, (Long) channel.readOutbound()); + assertEquals(3L, (Long) channel.readOutbound()); + } + + @Test + public void testRemoveAndWriteAllWithVoidPromise() { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + // Convert to writeAndFlush(...) so the promise will be notified by the transport. 
+ ctx.writeAndFlush(msg, promise); + } + }, new ChannelOutboundHandlerAdapter()); + + final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().lastContext()); + + ChannelPromise promise = channel.newPromise(); + queue.add(1L, promise); + queue.add(2L, channel.voidPromise()); + queue.removeAndWriteAll(); + + assertTrue(channel.finish()); + assertTrue(promise.isDone()); + assertTrue(promise.isSuccess()); + assertEquals(1L, (Long) channel.readOutbound()); + assertEquals(2L, (Long) channel.readOutbound()); + } + + @Test + public void testRemoveAndFailAllReentrantWrite() { + final List failOrder = Collections.synchronizedList(new ArrayList()); + EmbeddedChannel channel = newChannel(); + final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext()); + + ChannelPromise promise = channel.newPromise(); + final ChannelPromise promise3 = channel.newPromise(); + promise3.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + failOrder.add(3); + } + }); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + failOrder.add(1); + queue.add(3L, promise3); + } + }); + queue.add(1L, promise); + + ChannelPromise promise2 = channel.newPromise(); + promise2.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + failOrder.add(2); + } + }); + queue.add(2L, promise2); + queue.removeAndFailAll(new Exception()); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertTrue(promise2.isDone()); + assertFalse(promise2.isSuccess()); + assertTrue(promise3.isDone()); + assertFalse(promise3.isSuccess()); + assertFalse(channel.finish()); + assertEquals(1, (int) failOrder.get(0)); + assertEquals(2, (int) failOrder.get(1)); + assertEquals(3, (int) failOrder.get(2)); + } + + @Test + public void testRemoveAndWriteAllReentrance() { + EmbeddedChannel channel = 
newChannel(); + final PendingWriteQueue queue = new PendingWriteQueue(channel.pipeline().firstContext()); + + ChannelPromise promise = channel.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + queue.removeAndWriteAll(); + } + }); + queue.add(1L, promise); + + ChannelPromise promise2 = channel.newPromise(); + queue.add(2L, promise2); + queue.removeAndWriteAll(); + channel.flush(); + assertTrue(promise.isSuccess()); + assertTrue(promise2.isSuccess()); + assertTrue(channel.finish()); + + assertEquals(1L, (Long) channel.readOutbound()); + assertEquals(2L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + assertNull(channel.readInbound()); + } + + // See https://github.com/netty/netty/issues/3967 + @Test + public void testCloseChannelOnCreation() { + EmbeddedChannel channel = newChannel(); + ChannelHandlerContext context = channel.pipeline().firstContext(); + channel.close().syncUninterruptibly(); + + final PendingWriteQueue queue = new PendingWriteQueue(context); + + IllegalStateException ex = new IllegalStateException(); + ChannelPromise promise = channel.newPromise(); + queue.add(1L, promise); + queue.removeAndFailAll(ex); + assertSame(ex, promise.cause()); + } + + private static class TestHandler extends ChannelDuplexHandler { + protected PendingWriteQueue queue; + private int expectedSize; + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + super.channelActive(ctx); + assertQueueEmpty(queue); + assertTrue(ctx.channel().isWritable(), "Should be writable"); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + queue.add(msg, promise); + assertFalse(queue.isEmpty()); + assertEquals(++expectedSize, queue.size()); + assertNotNull(queue.current()); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + queue = new 
PendingWriteQueue(ctx); + } + } + + private static final class TestException extends Exception { + private static final long serialVersionUID = -9018570103039458401L; + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/ReentrantChannelTest.java b/netty-channel/src/test/java/io/netty/channel/ReentrantChannelTest.java new file mode 100644 index 0000000..153c5e9 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/ReentrantChannelTest.java @@ -0,0 +1,288 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.LoggingTestHandler.Event; +import io.netty.channel.local.LocalAddress; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.Test; + +import java.nio.channels.ClosedChannelException; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ReentrantChannelTest extends BaseChannelTest { + + @Test + public void testWritabilityChanged() throws Exception { + + LocalAddress addr = new LocalAddress("testWritabilityChanged"); + + ServerBootstrap sb = getLocalServerBootstrap(); + sb.bind(addr).sync().channel(); + + Bootstrap cb = getLocalClientBootstrap(); + + setInterest(Event.WRITE, Event.FLUSH, Event.WRITABILITY); + + Channel clientChannel = cb.connect(addr).sync().channel(); + clientChannel.config().setWriteBufferLowWaterMark(512); + clientChannel.config().setWriteBufferHighWaterMark(1024); + + // What is supposed to happen from this point: + // + // 1. Because this write attempt has been made from a non-I/O thread, + // ChannelOutboundBuffer.pendingWriteBytes will be increased before + // write() event is really evaluated. + // -> channelWritabilityChanged() will be triggered, + // because the Channel became unwritable. + // + // 2. The write() event is handled by the pipeline in an I/O thread. + // -> write() will be triggered. + // + // 3. Once the write() event is handled, ChannelOutboundBuffer.pendingWriteBytes + // will be decreased. + // -> channelWritabilityChanged() will be triggered, + // because the Channel became writable again. + // + // 4. The message is added to the ChannelOutboundBuffer and thus + // pendingWriteBytes will be increased again. 
+ // -> channelWritabilityChanged() will be triggered. + // + // 5. The flush() event causes the write request in theChannelOutboundBuffer + // to be removed. + // -> flush() and channelWritabilityChanged() will be triggered. + // + // Note that the channelWritabilityChanged() in the step 4 can occur between + // the flush() and the channelWritabilityChanged() in the step 5, because + // the flush() is invoked from a non-I/O thread while the other are from + // an I/O thread. + + ChannelFuture future = clientChannel.write(createTestBuf(2000)); + + clientChannel.flush(); + future.sync(); + + clientChannel.close().sync(); + + assertLog( + // Case 1: + "WRITABILITY: writable=false\n" + + "WRITE\n" + + "WRITABILITY: writable=false\n" + + "WRITABILITY: writable=false\n" + + "FLUSH\n" + + "WRITABILITY: writable=true\n", + // Case 2: + "WRITABILITY: writable=false\n" + + "WRITE\n" + + "WRITABILITY: writable=false\n" + + "FLUSH\n" + + "WRITABILITY: writable=true\n" + + "WRITABILITY: writable=true\n"); + } + + /** + * Similar to {@link #testWritabilityChanged()} with slight variation. 
+ */ + @Test + public void testFlushInWritabilityChanged() throws Exception { + + LocalAddress addr = new LocalAddress("testFlushInWritabilityChanged"); + + ServerBootstrap sb = getLocalServerBootstrap(); + sb.bind(addr).sync().channel(); + + Bootstrap cb = getLocalClientBootstrap(); + + setInterest(Event.WRITE, Event.FLUSH, Event.WRITABILITY); + + Channel clientChannel = cb.connect(addr).sync().channel(); + clientChannel.config().setWriteBufferLowWaterMark(512); + clientChannel.config().setWriteBufferHighWaterMark(1024); + + clientChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if (!ctx.channel().isWritable()) { + ctx.channel().flush(); + } + ctx.fireChannelWritabilityChanged(); + } + }); + + assertTrue(clientChannel.isWritable()); + + clientChannel.write(createTestBuf(2000)).sync(); + clientChannel.close().sync(); + + assertLog( + // Case 1: + "WRITABILITY: writable=false\n" + + "FLUSH\n" + + "WRITE\n" + + "WRITABILITY: writable=false\n" + + "WRITABILITY: writable=false\n" + + "FLUSH\n" + + "WRITABILITY: writable=true\n", + // Case 2: + "WRITABILITY: writable=false\n" + + "FLUSH\n" + + "WRITE\n" + + "WRITABILITY: writable=false\n" + + "FLUSH\n" + + "WRITABILITY: writable=true\n" + + "WRITABILITY: writable=true\n"); + } + + @Test + public void testWriteFlushPingPong() throws Exception { + + LocalAddress addr = new LocalAddress("testWriteFlushPingPong"); + + ServerBootstrap sb = getLocalServerBootstrap(); + sb.bind(addr).sync().channel(); + + Bootstrap cb = getLocalClientBootstrap(); + + setInterest(Event.WRITE, Event.FLUSH, Event.CLOSE, Event.EXCEPTION); + + Channel clientChannel = cb.connect(addr).sync().channel(); + + clientChannel.pipeline().addLast(new ChannelOutboundHandlerAdapter() { + + int writeCount; + int flushCount; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if 
(writeCount < 5) { + writeCount++; + ctx.channel().flush(); + } + super.write(ctx, msg, promise); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + if (flushCount < 5) { + flushCount++; + ctx.channel().write(createTestBuf(2000)); + } + super.flush(ctx); + } + }); + + clientChannel.writeAndFlush(createTestBuf(2000)); + clientChannel.close().sync(); + + assertLog( + "WRITE\n" + + "FLUSH\n" + + "WRITE\n" + + "FLUSH\n" + + "WRITE\n" + + "FLUSH\n" + + "WRITE\n" + + "FLUSH\n" + + "WRITE\n" + + "FLUSH\n" + + "WRITE\n" + + "FLUSH\n" + + "CLOSE\n"); + } + + @Test + public void testCloseInFlush() throws Exception { + + LocalAddress addr = new LocalAddress("testCloseInFlush"); + + ServerBootstrap sb = getLocalServerBootstrap(); + sb.bind(addr).sync().channel(); + + Bootstrap cb = getLocalClientBootstrap(); + + setInterest(Event.WRITE, Event.FLUSH, Event.CLOSE, Event.EXCEPTION); + + Channel clientChannel = cb.connect(addr).sync().channel(); + + clientChannel.pipeline().addLast(new ChannelOutboundHandlerAdapter() { + + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + promise.addListener(new GenericFutureListener>() { + @Override + public void operationComplete(Future future) throws Exception { + ctx.channel().close(); + } + }); + super.write(ctx, msg, promise); + ctx.channel().flush(); + } + }); + + clientChannel.write(createTestBuf(2000)).sync(); + clientChannel.closeFuture().sync(); + + assertLog("WRITE\nFLUSH\nCLOSE\n"); + } + + @Test + public void testFlushFailure() throws Exception { + + LocalAddress addr = new LocalAddress("testFlushFailure"); + + ServerBootstrap sb = getLocalServerBootstrap(); + sb.bind(addr).sync().channel(); + + Bootstrap cb = getLocalClientBootstrap(); + + setInterest(Event.WRITE, Event.FLUSH, Event.CLOSE, Event.EXCEPTION); + + Channel clientChannel = cb.connect(addr).sync().channel(); + + clientChannel.pipeline().addLast(new 
ChannelOutboundHandlerAdapter() { + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + throw new Exception("intentional failure"); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.close(); + } + }); + + try { + clientChannel.writeAndFlush(createTestBuf(2000)).sync(); + fail(); + } catch (Throwable cce) { + // FIXME: shouldn't this contain the "intentional failure" exception? + assertThat(cce, Matchers.instanceOf(ClosedChannelException.class)); + } + + clientChannel.closeFuture().sync(); + + assertLog("WRITE\nCLOSE\n"); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java b/netty-channel/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java new file mode 100644 index 0000000..595a9b7 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/SimpleUserEventChannelHandlerTest.java @@ -0,0 +1,103 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.buffer.DefaultByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SimpleUserEventChannelHandlerTest { + + private FooEventCatcher fooEventCatcher; + private AllEventCatcher allEventCatcher; + private EmbeddedChannel channel; + + @BeforeEach + public void setUp() { + fooEventCatcher = new FooEventCatcher(); + allEventCatcher = new AllEventCatcher(); + channel = new EmbeddedChannel(fooEventCatcher, allEventCatcher); + } + + @Test + public void testTypeMatch() { + FooEvent fooEvent = new FooEvent(); + channel.pipeline().fireUserEventTriggered(fooEvent); + assertEquals(1, fooEventCatcher.caughtEvents.size()); + assertEquals(0, allEventCatcher.caughtEvents.size()); + assertEquals(0, fooEvent.refCnt()); + assertFalse(channel.finish()); + } + + @Test + public void testTypeMismatch() { + BarEvent barEvent = new BarEvent(); + channel.pipeline().fireUserEventTriggered(barEvent); + assertEquals(0, fooEventCatcher.caughtEvents.size()); + assertEquals(1, allEventCatcher.caughtEvents.size()); + assertTrue(barEvent.release()); + assertFalse(channel.finish()); + } + + static final class FooEvent extends DefaultByteBufHolder { + FooEvent() { + super(Unpooled.buffer()); + } + } + + static final class BarEvent extends DefaultByteBufHolder { + BarEvent() { + super(Unpooled.buffer()); + } + } + + static final class FooEventCatcher extends SimpleUserEventChannelHandler { + + public List caughtEvents; + + FooEventCatcher() { + caughtEvents = new ArrayList(); + } + + @Override + protected void eventReceived(ChannelHandlerContext ctx, FooEvent evt) { + 
caughtEvents.add(evt); + } + } + + static final class AllEventCatcher extends ChannelInboundHandlerAdapter { + + public List caughtEvents; + + AllEventCatcher() { + caughtEvents = new ArrayList(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + caughtEvents.add(evt); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/SingleThreadEventLoopTest.java b/netty-channel/src/test/java/io/netty/channel/SingleThreadEventLoopTest.java new file mode 100644 index 0000000..ead6299 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/SingleThreadEventLoopTest.java @@ -0,0 +1,584 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import io.netty.channel.local.LocalChannel; +import io.netty.util.concurrent.EventExecutor; +import org.hamcrest.MatcherAssert; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ScheduledFuture; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class SingleThreadEventLoopTest { + + private static final Runnable NOOP = new Runnable() { + @Override + public void run() { } + }; + + private SingleThreadEventLoopA loopA; + private SingleThreadEventLoopB loopB; + private SingleThreadEventLoopC loopC; + + @BeforeEach + public void newEventLoop() { + loopA = new SingleThreadEventLoopA(); + loopB = new SingleThreadEventLoopB(); + loopC = new SingleThreadEventLoopC(); + } + + @AfterEach + public void stopEventLoop() { + if (!loopA.isShuttingDown()) { + loopA.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + if (!loopB.isShuttingDown()) { + loopB.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + if (!loopC.isShuttingDown()) { + loopC.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + + while (!loopA.isTerminated()) { + try { + 
loopA.awaitTermination(1, TimeUnit.DAYS); + } catch (InterruptedException e) { + // Ignore + } + } + assertEquals(1, loopA.cleanedUp.get()); + + while (!loopB.isTerminated()) { + try { + loopB.awaitTermination(1, TimeUnit.DAYS); + } catch (InterruptedException e) { + // Ignore + } + } + + while (!loopC.isTerminated()) { + try { + loopC.awaitTermination(1, TimeUnit.DAYS); + } catch (InterruptedException e) { + // Ignore + } + } + } + + @Test + @SuppressWarnings("deprecation") + public void shutdownBeforeStart() throws Exception { + loopA.shutdown(); + assertRejection(loopA); + } + + @Test + @SuppressWarnings("deprecation") + public void shutdownAfterStart() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + loopA.execute(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }); + + // Wait for the event loop thread to start. + latch.await(); + + // Request the event loop thread to stop. + loopA.shutdown(); + assertRejection(loopA); + + assertTrue(loopA.isShutdown()); + + // Wait until the event loop is terminated. 
+ while (!loopA.isTerminated()) { + loopA.awaitTermination(1, TimeUnit.DAYS); + } + } + + private static void assertRejection(EventExecutor loop) { + try { + loop.execute(NOOP); + fail("A task must be rejected after shutdown() is called."); + } catch (RejectedExecutionException e) { + // Expected + } + } + + @Test + public void scheduleTaskA() throws Exception { + testScheduleTask(loopA); + } + + @Test + public void scheduleTaskB() throws Exception { + testScheduleTask(loopB); + } + + @Test + public void scheduleTaskC() throws Exception { + testScheduleTask(loopC); + } + + private static void testScheduleTask(EventLoop loopA) throws InterruptedException, ExecutionException { + long startTime = System.nanoTime(); + final AtomicLong endTime = new AtomicLong(); + loopA.schedule(new Runnable() { + @Override + public void run() { + endTime.set(System.nanoTime()); + } + }, 500, TimeUnit.MILLISECONDS).get(); + assertThat(endTime.get() - startTime, + is(greaterThanOrEqualTo(TimeUnit.MILLISECONDS.toNanos(500)))); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void scheduleTaskAtFixedRateA() throws Exception { + testScheduleTaskAtFixedRate(loopA); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void scheduleTaskAtFixedRateB() throws Exception { + testScheduleTaskAtFixedRate(loopB); + } + + private static void testScheduleTaskAtFixedRate(EventLoop loopA) throws InterruptedException { + final Queue timestamps = new LinkedBlockingQueue(); + final int expectedTimeStamps = 5; + final CountDownLatch allTimeStampsLatch = new CountDownLatch(expectedTimeStamps); + ScheduledFuture f = loopA.scheduleAtFixedRate(new Runnable() { + @Override + public void run() { + timestamps.add(System.nanoTime()); + try { + Thread.sleep(50); + } catch (InterruptedException e) { + // Ignore + } + allTimeStampsLatch.countDown(); + } + }, 100, 100, TimeUnit.MILLISECONDS); + allTimeStampsLatch.await(); + assertTrue(f.cancel(true)); + 
Thread.sleep(300); + assertEquals(expectedTimeStamps, timestamps.size()); + + // Check if the task was run without a lag. + Long firstTimestamp = null; + int cnt = 0; + for (Long t: timestamps) { + if (firstTimestamp == null) { + firstTimestamp = t; + continue; + } + + long timepoint = t - firstTimestamp; + assertThat(timepoint, is(greaterThanOrEqualTo(TimeUnit.MILLISECONDS.toNanos(100 * cnt + 80)))); + assertThat(timepoint, is(lessThan(TimeUnit.MILLISECONDS.toNanos(100 * (cnt + 1) + 20)))); + + cnt ++; + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void scheduleLaggyTaskAtFixedRateA() throws Exception { + testScheduleLaggyTaskAtFixedRate(loopA); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void scheduleLaggyTaskAtFixedRateB() throws Exception { + testScheduleLaggyTaskAtFixedRate(loopB); + } + + private static void testScheduleLaggyTaskAtFixedRate(EventLoop loopA) throws InterruptedException { + final Queue timestamps = new LinkedBlockingQueue(); + final int expectedTimeStamps = 5; + final CountDownLatch allTimeStampsLatch = new CountDownLatch(expectedTimeStamps); + ScheduledFuture f = loopA.scheduleAtFixedRate(new Runnable() { + @Override + public void run() { + boolean empty = timestamps.isEmpty(); + timestamps.add(System.nanoTime()); + if (empty) { + try { + Thread.sleep(401); + } catch (InterruptedException e) { + // Ignore + } + } + allTimeStampsLatch.countDown(); + } + }, 100, 100, TimeUnit.MILLISECONDS); + allTimeStampsLatch.await(); + assertTrue(f.cancel(true)); + Thread.sleep(300); + assertEquals(expectedTimeStamps, timestamps.size()); + + // Check if the task was run with lag. 
+ int i = 0; + Long previousTimestamp = null; + for (Long t: timestamps) { + if (previousTimestamp == null) { + previousTimestamp = t; + continue; + } + + long diff = t.longValue() - previousTimestamp.longValue(); + if (i == 0) { + assertThat(diff, is(greaterThanOrEqualTo(TimeUnit.MILLISECONDS.toNanos(400)))); + } else { + assertThat(diff, is(lessThanOrEqualTo(TimeUnit.MILLISECONDS.toNanos(10)))); + } + previousTimestamp = t; + i ++; + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void scheduleTaskWithFixedDelayA() throws Exception { + testScheduleTaskWithFixedDelay(loopA); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void scheduleTaskWithFixedDelayB() throws Exception { + testScheduleTaskWithFixedDelay(loopB); + } + + private static void testScheduleTaskWithFixedDelay(EventLoop loopA) throws InterruptedException { + final Queue timestamps = new LinkedBlockingQueue(); + final int expectedTimeStamps = 3; + final CountDownLatch allTimeStampsLatch = new CountDownLatch(expectedTimeStamps); + ScheduledFuture f = loopA.scheduleWithFixedDelay(new Runnable() { + @Override + public void run() { + timestamps.add(System.nanoTime()); + try { + Thread.sleep(51); + } catch (InterruptedException e) { + // Ignore + } + allTimeStampsLatch.countDown(); + } + }, 100, 100, TimeUnit.MILLISECONDS); + allTimeStampsLatch.await(); + assertTrue(f.cancel(true)); + Thread.sleep(300); + assertEquals(expectedTimeStamps, timestamps.size()); + + // Check if the task was run without a lag. 
+ Long previousTimestamp = null; + for (Long t: timestamps) { + if (previousTimestamp == null) { + previousTimestamp = t; + continue; + } + + assertThat(t.longValue() - previousTimestamp.longValue(), + is(greaterThanOrEqualTo(TimeUnit.MILLISECONDS.toNanos(150)))); + previousTimestamp = t; + } + } + + @Test + @SuppressWarnings("deprecation") + public void shutdownWithPendingTasks() throws Exception { + final int NUM_TASKS = 3; + final AtomicInteger ranTasks = new AtomicInteger(); + final CountDownLatch latch = new CountDownLatch(1); + final Runnable task = new Runnable() { + @Override + public void run() { + ranTasks.incrementAndGet(); + while (latch.getCount() > 0) { + try { + latch.await(); + } catch (InterruptedException e) { + // Ignored + } + } + } + }; + + for (int i = 0; i < NUM_TASKS; i ++) { + loopA.execute(task); + } + + // At this point, the first task should be running and stuck at latch.await(). + while (ranTasks.get() == 0) { + Thread.yield(); + } + assertEquals(1, ranTasks.get()); + + // Shut down the event loop to test if the other tasks are run before termination. + loopA.shutdown(); + + // Let the other tasks run. + latch.countDown(); + + // Wait until the event loop is terminated. + while (!loopA.isTerminated()) { + loopA.awaitTermination(1, TimeUnit.DAYS); + } + + // Make sure loop.shutdown() above triggered wakeup(). 
+ assertEquals(NUM_TASKS, ranTasks.get()); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + @SuppressWarnings("deprecation") + public void testRegistrationAfterShutdown() throws Exception { + loopA.shutdown(); + ChannelFuture f = loopA.register(new LocalChannel()); + f.awaitUninterruptibly(); + assertFalse(f.isSuccess()); + assertThat(f.cause(), is(instanceOf(RejectedExecutionException.class))); + assertFalse(f.channel().isOpen()); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + @SuppressWarnings("deprecation") + public void testRegistrationAfterShutdown2() throws Exception { + loopA.shutdown(); + final CountDownLatch latch = new CountDownLatch(1); + Channel ch = new LocalChannel(); + ChannelPromise promise = ch.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + latch.countDown(); + } + }); + ChannelFuture f = loopA.register(promise); + f.awaitUninterruptibly(); + assertFalse(f.isSuccess()); + assertThat(f.cause(), is(instanceOf(RejectedExecutionException.class))); + + // Ensure the listener was notified. + assertFalse(latch.await(1, TimeUnit.SECONDS)); + assertFalse(ch.isOpen()); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testGracefulShutdownQuietPeriod() throws Exception { + loopA.shutdownGracefully(1, Integer.MAX_VALUE, TimeUnit.SECONDS); + // Keep Scheduling tasks for another 2 seconds. 
+ for (int i = 0; i < 20; i ++) { + Thread.sleep(100); + loopA.execute(NOOP); + } + + long startTime = System.nanoTime(); + + assertThat(loopA.isShuttingDown(), is(true)); + assertThat(loopA.isShutdown(), is(false)); + + while (!loopA.isTerminated()) { + loopA.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS); + } + + assertThat(System.nanoTime() - startTime, + is(greaterThanOrEqualTo(TimeUnit.SECONDS.toNanos(1)))); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testGracefulShutdownTimeout() throws Exception { + loopA.shutdownGracefully(2, 2, TimeUnit.SECONDS); + // Keep Scheduling tasks for another 3 seconds. + // Submitted tasks must be rejected after 2 second timeout. + for (int i = 0; i < 10; i ++) { + Thread.sleep(100); + loopA.execute(NOOP); + } + + try { + for (int i = 0; i < 20; i ++) { + Thread.sleep(100); + loopA.execute(NOOP); + } + fail("shutdownGracefully() must reject a task after timeout."); + } catch (RejectedExecutionException e) { + // Expected + } + + assertThat(loopA.isShuttingDown(), is(true)); + assertThat(loopA.isShutdown(), is(true)); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testOnEventLoopIteration() throws Exception { + CountingRunnable onIteration = new CountingRunnable(); + loopC.executeAfterEventLoopIteration(onIteration); + CountingRunnable noopTask = new CountingRunnable(); + loopC.submit(noopTask).sync(); + loopC.iterationEndSignal.take(); + MatcherAssert.assertThat("Unexpected invocation count for regular task.", + noopTask.getInvocationCount(), is(1)); + MatcherAssert.assertThat("Unexpected invocation count for on every eventloop iteration task.", + onIteration.getInvocationCount(), is(1)); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testRemoveOnEventLoopIteration() throws Exception { + CountingRunnable onIteration1 = new CountingRunnable(); + loopC.executeAfterEventLoopIteration(onIteration1); + 
CountingRunnable onIteration2 = new CountingRunnable(); + loopC.executeAfterEventLoopIteration(onIteration2); + loopC.removeAfterEventLoopIterationTask(onIteration1); + CountingRunnable noopTask = new CountingRunnable(); + loopC.submit(noopTask).sync(); + + loopC.iterationEndSignal.take(); + MatcherAssert.assertThat("Unexpected invocation count for regular task.", + noopTask.getInvocationCount(), is(1)); + MatcherAssert.assertThat("Unexpected invocation count for on every eventloop iteration task.", + onIteration2.getInvocationCount(), is(1)); + MatcherAssert.assertThat("Unexpected invocation count for on every eventloop iteration task.", + onIteration1.getInvocationCount(), is(0)); + } + + private static final class SingleThreadEventLoopA extends SingleThreadEventLoop { + + final AtomicInteger cleanedUp = new AtomicInteger(); + + SingleThreadEventLoopA() { + super(null, Executors.defaultThreadFactory(), true); + } + + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } + + @Override + protected void cleanup() { + cleanedUp.incrementAndGet(); + } + } + + private static class SingleThreadEventLoopB extends SingleThreadEventLoop { + + SingleThreadEventLoopB() { + super(null, Executors.defaultThreadFactory(), false); + } + + @Override + protected void run() { + for (;;) { + try { + Thread.sleep(TimeUnit.NANOSECONDS.toMillis(delayNanos(System.nanoTime()))); + } catch (InterruptedException e) { + // Waken up by interruptThread() + } + + runTasks0(); + + if (confirmShutdown()) { + break; + } + } + } + + protected void runTasks0() { + runAllTasks(); + } + + @Override + protected void wakeup(boolean inEventLoop) { + interruptThread(); + } + } + + private static final class SingleThreadEventLoopC extends SingleThreadEventLoopB { + + final LinkedBlockingQueue iterationEndSignal = new LinkedBlockingQueue(1); + + @Override + protected 
void afterRunningAllTasks() { + super.afterRunningAllTasks(); + iterationEndSignal.offer(true); + } + + @Override + protected void runTasks0() { + runAllTasks(TimeUnit.MINUTES.toNanos(1)); + } + } + + private static class CountingRunnable implements Runnable { + + private final AtomicInteger invocationCount = new AtomicInteger(); + + @Override + public void run() { + invocationCount.incrementAndGet(); + } + + public int getInvocationCount() { + return invocationCount.get(); + } + + public void resetInvocationCount() { + invocationCount.set(0); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/SucceededChannelFutureTest.java b/netty-channel/src/test/java/io/netty/channel/SucceededChannelFutureTest.java new file mode 100644 index 0000000..025609e --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/SucceededChannelFutureTest.java @@ -0,0 +1,33 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SucceededChannelFutureTest { + @Test + public void testConstantProperties() { + Channel channel = Mockito.mock(Channel.class); + SucceededChannelFuture future = new SucceededChannelFuture(channel, null); + + assertTrue(future.isSuccess()); + assertNull(future.cause()); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/ThreadPerChannelEventLoopGroupTest.java b/netty-channel/src/test/java/io/netty/channel/ThreadPerChannelEventLoopGroupTest.java new file mode 100644 index 0000000..c6f3642 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/ThreadPerChannelEventLoopGroupTest.java @@ -0,0 +1,117 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.group.ChannelGroup; +import io.netty.channel.group.DefaultChannelGroup; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.SingleThreadEventExecutor; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; + +import java.lang.reflect.Field; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@Disabled("Flaky test; See: https://github.com/netty/netty/issues/11551") +public class ThreadPerChannelEventLoopGroupTest { + + private static final ChannelHandler NOOP_HANDLER = new ChannelHandlerAdapter() { + @Override + public boolean isSharable() { + return true; + } + }; + + @Test + public void testTerminationFutureSuccessInLog() throws Exception { + for (int i = 0; i < 2; i++) { + ThreadPerChannelEventLoopGroup loopGroup = new ThreadPerChannelEventLoopGroup(64); + runTest(loopGroup); + } + } + + @Test + public void testTerminationFutureSuccessReflectively() throws Exception { + Field terminationFutureField = + ThreadPerChannelEventLoopGroup.class.getDeclaredField("terminationFuture"); + terminationFutureField.setAccessible(true); + final Exception[] exceptionHolder = new Exception[1]; + for (int i = 0; i < 2; i++) { + ThreadPerChannelEventLoopGroup loopGroup = new ThreadPerChannelEventLoopGroup(64); + Promise promise = new DefaultPromise(GlobalEventExecutor.INSTANCE) { + @Override + public Promise setSuccess(Void result) { + try { + return super.setSuccess(result); + } catch (IllegalStateException e) { + exceptionHolder[0] = e; + throw e; + } + } + }; + terminationFutureField.set(loopGroup, promise); + 
runTest(loopGroup); + } + // The global event executor will not terminate, but this will give the test a chance to fail. + GlobalEventExecutor.INSTANCE.awaitTermination(100, TimeUnit.MILLISECONDS); + assertNull(exceptionHolder[0]); + } + + private static void runTest(ThreadPerChannelEventLoopGroup loopGroup) throws InterruptedException { + int taskCount = 100; + EventExecutor testExecutor = new TestEventExecutor(); + ChannelGroup channelGroup = new DefaultChannelGroup(testExecutor); + while (taskCount-- > 0) { + Channel channel = new EmbeddedChannel(NOOP_HANDLER); + loopGroup.register(new DefaultChannelPromise(channel, testExecutor)); + channelGroup.add(channel); + } + channelGroup.close().sync(); + loopGroup.shutdownGracefully(100, 200, TimeUnit.MILLISECONDS).sync(); + assertTrue(loopGroup.isTerminated()); + } + + private static class TestEventExecutor extends SingleThreadEventExecutor { + + TestEventExecutor() { + super(null, new DefaultThreadFactory("test"), false); + } + + @Override + protected void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/embedded/CustomChannelId.java b/netty-channel/src/test/java/io/netty/channel/embedded/CustomChannelId.java new file mode 100644 index 0000000..3d9a7a7 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/embedded/CustomChannelId.java @@ -0,0 +1,65 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.embedded; + +import io.netty.channel.ChannelId; +import io.netty.util.internal.MathUtil; + +public class CustomChannelId implements ChannelId { + + private static final long serialVersionUID = 1L; + + private final int id; + + CustomChannelId(int id) { + this.id = id; + } + + @Override + public int compareTo(final ChannelId o) { + if (o instanceof CustomChannelId) { + return MathUtil.compare(id, ((CustomChannelId) o).id); + } + + return asLongText().compareTo(o.asLongText()); + } + + @Override + public int hashCode() { + return id; + } + + @Override + public boolean equals(Object obj) { + return obj instanceof CustomChannelId && id == ((CustomChannelId) obj).id; + } + + @Override + public String toString() { + return "CustomChannelId " + id; + } + + @Override + public String asShortText() { + return toString(); + } + + @Override + public String asLongText() { + return toString(); + } + +} diff --git a/netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelIdTest.java b/netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelIdTest.java new file mode 100644 index 0000000..3c4949e --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelIdTest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.embedded; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelId; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class EmbeddedChannelIdTest { + + @Test + public void testSerialization() throws IOException, ClassNotFoundException { + // test that a deserialized instance works the same as a normal instance (issue #2869) + ChannelId normalInstance = EmbeddedChannelId.INSTANCE; + + ByteBuf buf = Unpooled.buffer(); + ObjectOutputStream outStream = new ObjectOutputStream(new ByteBufOutputStream(buf)); + try { + outStream.writeObject(normalInstance); + } finally { + outStream.close(); + } + + ObjectInputStream inStream = new ObjectInputStream(new ByteBufInputStream(buf, true)); + final ChannelId deserializedInstance; + try { + deserializedInstance = (ChannelId) inStream.readObject(); + } finally { + inStream.close(); + } + + assertEquals(normalInstance, deserializedInstance); + assertEquals(normalInstance.hashCode(), deserializedInstance.hashCode()); + assertEquals(0, normalInstance.compareTo(deserializedInstance)); + } + +} diff --git a/netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java b/netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java new file mode 100644 index 
0000000..efdce97 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/embedded/EmbeddedChannelTest.java @@ -0,0 +1,796 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.embedded; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelId; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.ScheduledFuture; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import 
java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class EmbeddedChannelTest { + + @Test + public void testParent() { + EmbeddedChannel parent = new EmbeddedChannel(); + EmbeddedChannel channel = new EmbeddedChannel(parent, EmbeddedChannelId.INSTANCE, true, false); + assertSame(parent, channel.parent()); + assertNull(parent.parent()); + + assertFalse(channel.finish()); + assertFalse(parent.finish()); + } + + @Test + public void testNotRegistered() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(false, false); + assertFalse(channel.isRegistered()); + channel.register(); + assertTrue(channel.isRegistered()); + assertFalse(channel.finish()); + } + + @Test + public void testRegistered() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(true, false); + assertTrue(channel.isRegistered()); + try { + channel.register(); + fail(); + } catch (IllegalStateException expected) { + // This is expected the channel is registered already on an EventLoop. 
+ } + assertFalse(channel.finish()); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void promiseDoesNotInfiniteLoop() throws InterruptedException { + EmbeddedChannel channel = new EmbeddedChannel(); + channel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + future.channel().close(); + } + }); + + channel.close().syncUninterruptibly(); + } + + @Test + public void testConstructWithChannelInitializer() { + final Integer first = 1; + final Integer second = 2; + + final ChannelHandler handler = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.fireChannelRead(first); + ctx.fireChannelRead(second); + } + }; + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(handler); + } + }); + ChannelPipeline pipeline = channel.pipeline(); + assertSame(handler, pipeline.firstContext().handler()); + assertTrue(channel.writeInbound(3)); + assertTrue(channel.finish()); + assertSame(first, channel.readInbound()); + assertSame(second, channel.readInbound()); + assertNull(channel.readInbound()); + } + + @SuppressWarnings({ "rawtypes", "unchecked" }) + @Test + public void testScheduling() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + final CountDownLatch latch = new CountDownLatch(2); + Future future = ch.eventLoop().schedule(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }, 1, TimeUnit.SECONDS); + future.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + latch.countDown(); + } + }); + long next = ch.runScheduledPendingTasks(); + assertTrue(next > 0); + // Sleep for the nanoseconds but also give extra 50ms as the 
clock my not be very precise and so fail the test + // otherwise. + Thread.sleep(TimeUnit.NANOSECONDS.toMillis(next) + 50); + assertEquals(-1, ch.runScheduledPendingTasks()); + latch.await(); + } + + @Test + public void testScheduledCancelled() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + Future future = ch.eventLoop().schedule(new Runnable() { + @Override + public void run() { } + }, 1, TimeUnit.DAYS); + ch.finish(); + assertTrue(future.isCancelled()); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHandlerAddedExecutedInEventLoop() throws Throwable { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicReference error = new AtomicReference(); + final ChannelHandler handler = new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + try { + assertTrue(ctx.executor().inEventLoop()); + } catch (Throwable cause) { + error.set(cause); + } finally { + latch.countDown(); + } + } + }; + EmbeddedChannel channel = new EmbeddedChannel(handler); + assertFalse(channel.finish()); + latch.await(); + Throwable cause = error.get(); + if (cause != null) { + throw cause; + } + } + + @Test + public void testConstructWithOutHandler() { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(1)); + assertTrue(channel.writeOutbound(2)); + assertTrue(channel.finish()); + assertSame(1, channel.readInbound()); + assertNull(channel.readInbound()); + assertSame(2, channel.readOutbound()); + assertNull(channel.readOutbound()); + } + + @Test + public void testConstructWithChannelId() { + ChannelId channelId = new CustomChannelId(1); + EmbeddedChannel channel = new EmbeddedChannel(channelId); + assertSame(channelId, channel.id()); + } + + // See https://github.com/netty/netty/issues/4316. 
+ @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testFireChannelInactiveAndUnregisteredOnClose() throws InterruptedException { + testFireChannelInactiveAndUnregistered(new Action() { + @Override + public ChannelFuture doRun(Channel channel) { + return channel.close(); + } + }); + testFireChannelInactiveAndUnregistered(new Action() { + @Override + public ChannelFuture doRun(Channel channel) { + return channel.close(channel.newPromise()); + } + }); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testFireChannelInactiveAndUnregisteredOnDisconnect() throws InterruptedException { + testFireChannelInactiveAndUnregistered(new Action() { + @Override + public ChannelFuture doRun(Channel channel) { + return channel.disconnect(); + } + }); + + testFireChannelInactiveAndUnregistered(new Action() { + @Override + public ChannelFuture doRun(Channel channel) { + return channel.disconnect(channel.newPromise()); + } + }); + } + + private static void testFireChannelInactiveAndUnregistered(Action action) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(3); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + ctx.executor().execute(new Runnable() { + @Override + public void run() { + // Should be executed. 
+ latch.countDown(); + } + }); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + } + }); + action.doRun(channel).syncUninterruptibly(); + latch.await(); + } + + private interface Action { + ChannelFuture doRun(Channel channel); + } + + @Test + public void testHasDisconnect() { + EventOutboundHandler handler = new EventOutboundHandler(); + EmbeddedChannel channel = new EmbeddedChannel(true, handler); + assertTrue(channel.disconnect().isSuccess()); + assertTrue(channel.close().isSuccess()); + assertEquals(EventOutboundHandler.DISCONNECT, handler.pollEvent()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertNull(handler.pollEvent()); + } + + @Test + public void testHasNoDisconnect() { + EventOutboundHandler handler = new EventOutboundHandler(); + EmbeddedChannel channel = new EmbeddedChannel(false, handler); + assertTrue(channel.disconnect().isSuccess()); + assertTrue(channel.close().isSuccess()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertEquals(EventOutboundHandler.CLOSE, handler.pollEvent()); + assertNull(handler.pollEvent()); + } + + @Test + public void testHasNoDisconnectSkipDisconnect() throws InterruptedException { + EmbeddedChannel channel = new EmbeddedChannel(false, new ChannelOutboundHandlerAdapter() { + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + promise.tryFailure(new Throwable()); + } + }); + assertFalse(channel.disconnect().isSuccess()); + } + + @Test + public void testFinishAndReleaseAll() { + ByteBuf in = Unpooled.buffer(); + ByteBuf out = Unpooled.buffer(); + try { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(in)); + assertEquals(1, in.refCnt()); + + assertTrue(channel.writeOutbound(out)); + assertEquals(1, out.refCnt()); + + assertTrue(channel.finishAndReleaseAll()); + assertEquals(0, in.refCnt()); + assertEquals(0, 
out.refCnt()); + + assertNull(channel.readInbound()); + assertNull(channel.readOutbound()); + } finally { + release(in, out); + } + } + + @Test + public void testReleaseInbound() { + ByteBuf in = Unpooled.buffer(); + ByteBuf out = Unpooled.buffer(); + try { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(in)); + assertEquals(1, in.refCnt()); + + assertTrue(channel.writeOutbound(out)); + assertEquals(1, out.refCnt()); + + assertTrue(channel.releaseInbound()); + assertEquals(0, in.refCnt()); + assertEquals(1, out.refCnt()); + + assertTrue(channel.finish()); + assertNull(channel.readInbound()); + + ByteBuf buffer = channel.readOutbound(); + assertSame(out, buffer); + buffer.release(); + + assertNull(channel.readOutbound()); + } finally { + release(in, out); + } + } + + @Test + public void testReleaseOutbound() { + ByteBuf in = Unpooled.buffer(); + ByteBuf out = Unpooled.buffer(); + try { + EmbeddedChannel channel = new EmbeddedChannel(); + assertTrue(channel.writeInbound(in)); + assertEquals(1, in.refCnt()); + + assertTrue(channel.writeOutbound(out)); + assertEquals(1, out.refCnt()); + + assertTrue(channel.releaseOutbound()); + assertEquals(1, in.refCnt()); + assertEquals(0, out.refCnt()); + + assertTrue(channel.finish()); + assertNull(channel.readOutbound()); + + ByteBuf buffer = channel.readInbound(); + assertSame(in, buffer); + buffer.release(); + + assertNull(channel.readInbound()); + } finally { + release(in, out); + } + } + + @Test + public void testWriteLater() { + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) + throws Exception { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + ctx.write(msg, promise); + } + }); + } + }); + Object msg = new Object(); + + assertTrue(channel.writeOutbound(msg)); + assertTrue(channel.finish()); + assertSame(msg, 
channel.readOutbound()); + assertNull(channel.readOutbound()); + } + + @Test + public void testWriteScheduled() throws InterruptedException { + final int delay = 500; + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) + throws Exception { + ctx.executor().schedule(new Runnable() { + @Override + public void run() { + ctx.writeAndFlush(msg, promise); + } + }, delay, TimeUnit.MILLISECONDS); + } + }); + Object msg = new Object(); + + assertFalse(channel.writeOutbound(msg)); + Thread.sleep(delay * 2); + assertTrue(channel.finish()); + assertSame(msg, channel.readOutbound()); + assertNull(channel.readOutbound()); + } + + @Test + public void testFlushInbound() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + } + }); + + channel.flushInbound(); + + if (!latch.await(1L, TimeUnit.SECONDS)) { + fail("Nobody called #channelReadComplete() in time."); + } + } + + @Test + public void testWriteOneInbound() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger flushCount = new AtomicInteger(0); + + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ReferenceCountUtil.release(msg); + latch.countDown(); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + flushCount.incrementAndGet(); + } + }); + + channel.writeOneInbound("Hello, Netty!"); + + if (!latch.await(1L, TimeUnit.SECONDS)) { + fail("Nobody called #channelRead() in time."); + } + + channel.close().syncUninterruptibly(); + + 
// There was no #flushInbound() call so nobody should have called + // #channelReadComplete() + assertEquals(0, flushCount.get()); + } + + @Test + public void testFlushOutbound() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + latch.countDown(); + } + }); + + channel.flushOutbound(); + + if (!latch.await(1L, TimeUnit.SECONDS)) { + fail("Nobody called #flush() in time."); + } + } + + @Test + public void testWriteOneOutbound() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final AtomicInteger flushCount = new AtomicInteger(0); + + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ctx.write(msg, promise); + latch.countDown(); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + flushCount.incrementAndGet(); + } + }); + + // This shouldn't trigger a #flush() + channel.writeOneOutbound("Hello, Netty!"); + + if (!latch.await(1L, TimeUnit.SECONDS)) { + fail("Nobody called #write() in time."); + } + + channel.close().syncUninterruptibly(); + + // There was no #flushOutbound() call so nobody should have called #flush() + assertEquals(0, flushCount.get()); + } + + @Test + public void testEnsureOpen() throws InterruptedException { + EmbeddedChannel channel = new EmbeddedChannel(); + channel.close().syncUninterruptibly(); + + try { + channel.writeOutbound("Hello, Netty!"); + fail("This should have failed with a ClosedChannelException"); + } catch (Exception expected) { + assertTrue(expected instanceof ClosedChannelException); + } + + try { + channel.writeInbound("Hello, Netty!"); + fail("This should have failed with a ClosedChannelException"); + } catch 
(Exception expected) { + assertTrue(expected instanceof ClosedChannelException); + } + } + + @Test + public void testHandleInboundMessage() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + + EmbeddedChannel channel = new EmbeddedChannel() { + @Override + protected void handleInboundMessage(Object msg) { + latch.countDown(); + } + }; + + channel.writeOneInbound("Hello, Netty!"); + + if (!latch.await(1L, TimeUnit.SECONDS)) { + fail("Nobody called #handleInboundMessage() in time."); + } + } + + @Test + public void testHandleOutboundMessage() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + + EmbeddedChannel channel = new EmbeddedChannel() { + @Override + protected void handleOutboundMessage(Object msg) { + latch.countDown(); + } + }; + + channel.writeOneOutbound("Hello, Netty!"); + if (latch.await(50L, TimeUnit.MILLISECONDS)) { + fail("Somebody called unexpectedly #flush()"); + } + + channel.flushOutbound(); + if (!latch.await(1L, TimeUnit.SECONDS)) { + fail("Nobody called #handleOutboundMessage() in time."); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testChannelInactiveFired() throws InterruptedException { + final AtomicBoolean inactive = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.close(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + inactive.set(true); + } + }); + channel.pipeline().fireExceptionCaught(new IllegalStateException()); + + assertTrue(inactive.get()); + } + + @Test + public void testReRegisterEventLoop() throws Exception { + final CountDownLatch unregisteredLatch = new CountDownLatch(1); + final CountDownLatch registeredLatch = new CountDownLatch(2); + final EmbeddedChannel channel = new EmbeddedChannel(new 
ChannelInboundHandlerAdapter() {
            @Override
            public void channelUnregistered(ChannelHandlerContext ctx) {
                unregisteredLatch.countDown();
            }

            @Override
            public void channelRegistered(ChannelHandlerContext ctx) {
                registeredLatch.countDown();
            }
        });

        // Re-register the channel on a fresh loop as soon as the deregistration completes.
        final EmbeddedEventLoop embeddedEventLoop = new EmbeddedEventLoop();
        channel.deregister().addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) {
                embeddedEventLoop.register(channel);
            }
        });

        if (!unregisteredLatch.await(5, TimeUnit.SECONDS)) {
            fail("Channel was not unregistered in time.");
        }

        // Count of 2: the initial registration plus the re-registration above.
        if (!registeredLatch.await(5, TimeUnit.SECONDS)) {
            fail("Channel was not registered in time.");
        }

        // Tasks submitted after the re-registration must still be runnable via the channel.
        final CountDownLatch taskLatch = new CountDownLatch(1);
        channel.eventLoop().execute(new Runnable() {
            @Override
            public void run() {
                taskLatch.countDown();
            }
        });

        channel.runPendingTasks();
        if (!taskLatch.await(5, TimeUnit.SECONDS)) {
            fail("Task was not executed in time.");
        }
    }

    @Test
    void testRunPendingTasksForNotRegisteredChannel() {
        // A channel that was never registered must tolerate draining its task queues.
        final EmbeddedChannel channel = new EmbeddedChannel(false, false);
        long nextScheduledTaskTime = 0;
        try {
            nextScheduledTaskTime = channel.runScheduledPendingTasks();
            channel.checkException();
        } catch (Throwable t) {
            fail("Channel should not throw an exception for scheduled pending tasks if it is not registered", t);
        }

        // -1 signals that no scheduled task is pending.
        assertEquals(-1L, nextScheduledTaskTime);

        try {
            channel.runPendingTasks();
            channel.checkException();
        } catch (Throwable t) {
            fail("Channel should not throw an exception for pending tasks if it is not registered", t);
        }
    }

    @Test
    @Timeout(30) // generous timeout, just make sure we don't actually wait for the full 10 mins...
+ void testAdvanceTime() { + EmbeddedChannel channel = new EmbeddedChannel(); + Runnable runnable = new Runnable() { + @Override + public void run() { + } + }; + ScheduledFuture future10 = channel.eventLoop().schedule(runnable, 10, TimeUnit.MINUTES); + ScheduledFuture future20 = channel.eventLoop().schedule(runnable, 20, TimeUnit.MINUTES); + + channel.runPendingTasks(); + assertFalse(future10.isDone()); + assertFalse(future20.isDone()); + + channel.advanceTimeBy(10, TimeUnit.MINUTES); + channel.runPendingTasks(); + assertTrue(future10.isDone()); + assertFalse(future20.isDone()); + } + + @Test + @Timeout(30) // generous timeout, just make sure we don't actually wait for the full 10 mins... + void testFreezeTime() { + EmbeddedChannel channel = new EmbeddedChannel(); + Runnable runnable = new Runnable() { + @Override + public void run() { + } + }; + + channel.freezeTime(); + // this future will complete after 10min + ScheduledFuture future10 = channel.eventLoop().schedule(runnable, 10, TimeUnit.MINUTES); + // this future will complete after 10min + 1ns + ScheduledFuture future101 = channel.eventLoop().schedule(runnable, + TimeUnit.MINUTES.toNanos(10) + 1, TimeUnit.NANOSECONDS); + // this future will complete after 20min + ScheduledFuture future20 = channel.eventLoop().schedule(runnable, 20, TimeUnit.MINUTES); + + channel.runPendingTasks(); + assertFalse(future10.isDone()); + assertFalse(future101.isDone()); + assertFalse(future20.isDone()); + + channel.advanceTimeBy(10, TimeUnit.MINUTES); + channel.runPendingTasks(); + assertTrue(future10.isDone()); + assertFalse(future101.isDone()); + assertFalse(future20.isDone()); + + channel.unfreezeTime(); + channel.runPendingTasks(); + assertTrue(future101.isDone()); + assertFalse(future20.isDone()); + } + + @Test + void testHasPendingTasks() { + EmbeddedChannel channel = new EmbeddedChannel(); + channel.freezeTime(); + Runnable runnable = new Runnable() { + @Override + public void run() { + } + }; + + // simple execute + 
assertFalse(channel.hasPendingTasks()); + channel.eventLoop().execute(runnable); + assertTrue(channel.hasPendingTasks()); + channel.runPendingTasks(); + assertFalse(channel.hasPendingTasks()); + + // schedule in the future (note: time is frozen above) + channel.eventLoop().schedule(runnable, 1, TimeUnit.SECONDS); + assertFalse(channel.hasPendingTasks()); + channel.runPendingTasks(); + assertFalse(channel.hasPendingTasks()); + channel.advanceTimeBy(1, TimeUnit.SECONDS); + assertTrue(channel.hasPendingTasks()); + channel.runPendingTasks(); + assertFalse(channel.hasPendingTasks()); + } + + private static void release(ByteBuf... buffers) { + for (ByteBuf buffer : buffers) { + if (buffer.refCnt() > 0) { + buffer.release(); + } + } + } + + private static final class EventOutboundHandler extends ChannelOutboundHandlerAdapter { + static final Integer DISCONNECT = 0; + static final Integer CLOSE = 1; + + private final Queue queue = new ArrayDeque(); + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + queue.add(DISCONNECT); + promise.setSuccess(); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + queue.add(CLOSE); + promise.setSuccess(); + } + + Integer pollEvent() { + return queue.poll(); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/group/DefaultChannelGroupTest.java b/netty-channel/src/test/java/io/netty/channel/group/DefaultChannelGroupTest.java new file mode 100644 index 0000000..7ff3fff --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/group/DefaultChannelGroupTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.channel.group;

import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.junit.jupiter.api.Test;

public class DefaultChannelGroupTest {

    // Regression test for issue #1183: closing a DefaultChannelGroup from a caller
    // thread must not trigger a BlockingOperationException on the event loop.
    @Test
    public void testNotThrowBlockingOperationException() throws Exception {
        EventLoopGroup bossGroup = new NioEventLoopGroup();
        EventLoopGroup workerGroup = new NioEventLoopGroup();

        final ChannelGroup allChannels = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);

        ServerBootstrap bootstrap = new ServerBootstrap();
        bootstrap.group(bossGroup, workerGroup);
        bootstrap.channel(NioServerSocketChannel.class);
        // Track every accepted child channel in the group.
        bootstrap.childHandler(new ChannelInboundHandlerAdapter() {
            @Override
            public void channelActive(ChannelHandlerContext ctx) {
                allChannels.add(ctx.channel());
            }
        });

        ChannelFuture bindFuture = bootstrap.bind(0).syncUninterruptibly();

        if (bindFuture.isSuccess()) {
            allChannels.add(bindFuture.channel());
            allChannels.close().awaitUninterruptibly();
        }

        bossGroup.shutdownGracefully();
        workerGroup.shutdownGracefully();
        bossGroup.terminationFuture().sync();
        workerGroup.terminationFuture().sync();
    }
}
diff --git a/netty-channel/src/test/java/io/netty/channel/local/LocalChannelTest.java
b/netty-channel/src/test/java/io/netty/channel/local/LocalChannelTest.java new file mode 100644 index 0000000..44620c6 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/local/LocalChannelTest.java @@ -0,0 +1,1307 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.local; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.AbstractChannel; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.DefaultMaxMessagesRecvByteBufAllocator; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.SingleThreadEventLoop; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import 
org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;

import java.net.ConnectException;
import java.nio.channels.ClosedChannelException;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.Executor;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

import static java.util.concurrent.TimeUnit.SECONDS;
import static org.hamcrest.CoreMatchers.instanceOf;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

public class LocalChannelTest {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(LocalChannelTest.class);

    // Well-known local address every test binds/connects to.
    private static final LocalAddress TEST_ADDRESS = new LocalAddress("test.id");

    // Separate client/server groups plus one shared group for same-loop scenarios.
    private static EventLoopGroup group1;
    private static EventLoopGroup group2;
    private static EventLoopGroup sharedGroup;

    @BeforeAll
    public static void beforeClass() {
        group1 = new DefaultEventLoopGroup(2);
        group2 = new DefaultEventLoopGroup(2);
        sharedGroup = new DefaultEventLoopGroup(1);
    }

    @AfterAll
    public static void afterClass() throws InterruptedException {
        // Initiate all shutdowns first, then await them, so the groups wind down in parallel.
        Future group1Future = group1.shutdownGracefully(0, 0, SECONDS);
        Future group2Future = group2.shutdownGracefully(0, 0, SECONDS);
        Future sharedGroupFuture = sharedGroup.shutdownGracefully(0, 0, SECONDS);
        group1Future.await();
        group2Future.await();
        sharedGroupFuture.await();
    }

    @Test
    public void testLocalAddressReuse() throws Exception {
+ for (int i = 0; i < 2; i ++) { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + cb.group(group1) + .channel(LocalChannel.class) + .handler(new TestHandler()); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new TestHandler()); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + final CountDownLatch latch = new CountDownLatch(1); + // Connect to the server + cc = cb.connect(sc.localAddress()).sync().channel(); + final Channel ccCpy = cc; + cc.eventLoop().execute(new Runnable() { + @Override + public void run() { + // Send a message event up the pipeline. + ccCpy.pipeline().fireChannelRead("Hello, World"); + latch.countDown(); + } + }); + assertTrue(latch.await(5, SECONDS)); + + // Close the channel + closeChannel(cc); + closeChannel(sc); + sc.closeFuture().sync(); + + assertNull(LocalChannelRegistry.get(TEST_ADDRESS), String.format( + "Expected null, got channel '%s' for local address '%s'", + LocalChannelRegistry.get(TEST_ADDRESS), TEST_ADDRESS)); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + } + + @Test + public void testWriteFailsFastOnClosedChannel() throws Exception { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + cb.group(group1) + .channel(LocalChannel.class) + .handler(new TestHandler()); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new TestHandler()); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).sync().channel(); + + // Close the channel and 
write something. + cc.close().sync(); + try { + cc.writeAndFlush(new Object()).sync(); + fail("must raise a ClosedChannelException"); + } catch (Exception e) { + assertThat(e, is(instanceOf(ClosedChannelException.class))); + // Ensure that the actual write attempt on a closed channel was never made by asserting that + // the ClosedChannelException has been created by AbstractUnsafe rather than transport implementations. + if (e.getStackTrace().length > 0) { + assertThat( + e.getStackTrace()[0].getClassName(), is(AbstractChannel.class.getName() + + "$AbstractUnsafe")); + e.printStackTrace(); + } + } + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + + @Test + public void testServerCloseChannelSameEventLoop() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + ServerBootstrap sb = new ServerBootstrap() + .group(group2) + .channel(LocalServerChannel.class) + .childHandler(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.close(); + latch.countDown(); + } + }); + Channel sc = null; + Channel cc = null; + try { + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + Bootstrap b = new Bootstrap() + .group(group2) + .channel(LocalChannel.class) + .handler(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + // discard + } + }); + cc = b.connect(sc.localAddress()).sync().channel(); + cc.writeAndFlush(new Object()); + assertTrue(latch.await(5, SECONDS)); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + + @Test + public void localChannelRaceCondition() throws Exception { + final CountDownLatch closeLatch = new CountDownLatch(1); + final EventLoopGroup clientGroup = new DefaultEventLoopGroup(1) { + @Override + protected EventLoop newChild(Executor threadFactory, Object... 
args)
                    throws Exception {
                return new SingleThreadEventLoop(this, threadFactory, true) {
                    @Override
                    protected void run() {
                        for (;;) {
                            Runnable task = takeTask();
                            if (task != null) {
                                /* Only slow down the anonymous class in LocalChannel#doRegister() */
                                if (task.getClass().getEnclosingClass() == LocalChannel.class) {
                                    try {
                                        closeLatch.await();
                                    } catch (InterruptedException e) {
                                        throw new Error(e);
                                    }
                                }
                                task.run();
                                updateLastExecutionTime();
                            }

                            if (confirmShutdown()) {
                                break;
                            }
                        }
                    }
                };
            }
        };
        Channel sc = null;
        Channel cc = null;
        try {
            ServerBootstrap sb = new ServerBootstrap();
            // The server closes every accepted child immediately; the custom client loop
            // above stalls the client-side register task until that close has happened,
            // forcing the register/close race this test exercises.
            sc = sb.group(group2).
                    channel(LocalServerChannel.class).
                    childHandler(new ChannelInitializer() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            ch.close();
                            closeLatch.countDown();
                        }
                    }).
                    bind(TEST_ADDRESS).
                    sync().channel();
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(clientGroup).
                    channel(LocalChannel.class).
                    handler(new ChannelInitializer() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            /* Do nothing */
                        }
                    });
            ChannelFuture future = bootstrap.connect(sc.localAddress());
            assertTrue(future.await(2000), "Connection should finish, not time out");
            cc = future.channel();
        } finally {
            closeChannel(cc);
            closeChannel(sc);
            clientGroup.shutdownGracefully(0, 0, SECONDS).await();
        }
    }

    @Test
    public void testReRegister() {
        Bootstrap cb = new Bootstrap();
        ServerBootstrap sb = new ServerBootstrap();

        cb.group(group1)
                .channel(LocalChannel.class)
                .handler(new TestHandler());

        sb.group(group2)
                .channel(LocalServerChannel.class)
                .childHandler(new ChannelInitializer() {
                    @Override
                    public void initChannel(LocalChannel ch) throws Exception {
                        ch.pipeline().addLast(new TestHandler());
                    }
                });

        Channel sc = null;
        Channel cc = null;
        try {
            // Start server
            sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel();

            // Connect to the server
            cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel();

            // Deregistering a live channel must complete cleanly.
            cc.deregister().syncUninterruptibly();
        } finally {
            closeChannel(cc);
            closeChannel(sc);
        }
    }

    @Test
    public void testCloseInWritePromiseCompletePreservesOrder() throws InterruptedException {
        Bootstrap cb = new Bootstrap();
        ServerBootstrap sb = new ServerBootstrap();
        final CountDownLatch messageLatch = new CountDownLatch(2);
        final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]);

        try {
            cb.group(group1)
                    .channel(LocalChannel.class)
                    .handler(new TestHandler());

            sb.group(group2)
                    .channel(LocalServerChannel.class)
                    .childHandler(new ChannelInboundHandlerAdapter() {
                        @Override
                        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                            if (msg.equals(data)) {
                                ReferenceCountUtil.safeRelease(msg);
                                messageLatch.countDown();
                            } else {
                                super.channelRead(ctx, msg);
                            }
                        }
                        @Override
                        public void
channelInactive(ChannelHandlerContext ctx) throws Exception { + messageLatch.countDown(); + super.channelInactive(ctx); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + + final Channel ccCpy = cc; + // Make sure a write operation is executed in the eventloop + cc.pipeline().lastContext().executor().execute(new Runnable() { + @Override + public void run() { + ChannelPromise promise = ccCpy.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + ccCpy.pipeline().lastContext().close(); + } + }); + ccCpy.writeAndFlush(data.retainedDuplicate(), promise); + } + }); + + assertTrue(messageLatch.await(5, SECONDS)); + assertFalse(cc.isOpen()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } finally { + data.release(); + } + } + + @Test + public void testCloseAfterWriteInSameEventLoopPreservesOrder() throws InterruptedException { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final CountDownLatch messageLatch = new CountDownLatch(3); + final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]); + + try { + cb.group(sharedGroup) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(data.retainedDuplicate()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data.equals(msg)) { + ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + + sb.group(sharedGroup) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void 
channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data.equals(msg)) { + messageLatch.countDown(); + ctx.writeAndFlush(data.retainedDuplicate()); + ctx.close(); + } else { + super.channelRead(ctx, msg); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + messageLatch.countDown(); + super.channelInactive(ctx); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + assertTrue(messageLatch.await(5, SECONDS)); + assertFalse(cc.isOpen()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } finally { + data.release(); + } + } + + @Test + public void testWriteInWritePromiseCompletePreservesOrder() throws InterruptedException { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final CountDownLatch messageLatch = new CountDownLatch(2); + final ByteBuf data = Unpooled.buffer(); + final ByteBuf data2 = Unpooled.buffer(); + data.writeInt(Integer.BYTES).writeInt(2); + data2.writeInt(Integer.BYTES).writeInt(1); + + try { + cb.group(group1) + .channel(LocalChannel.class) + .handler(new TestHandler()); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) msg; + while (buf.isReadable()) { + int size = buf.readInt(); + ByteBuf slice = buf.readRetainedSlice(size); + try { + if (slice.readInt() == messageLatch.getCount()) { + messageLatch.countDown(); + } + } finally { + slice.release(); + } + } + buf.release(); + } else { + super.channelRead(ctx, msg); + } + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = 
sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + + final Channel ccCpy = cc; + // Make sure a write operation is executed in the eventloop + cc.pipeline().lastContext().executor().execute(new Runnable() { + @Override + public void run() { + ChannelPromise promise = ccCpy.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + ccCpy.writeAndFlush(data2.retainedDuplicate(), ccCpy.newPromise()); + } + }); + ccCpy.writeAndFlush(data.retainedDuplicate(), promise); + } + }); + + assertTrue(messageLatch.await(5, SECONDS)); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } finally { + data.release(); + data2.release(); + } + } + + @Test + public void testPeerWriteInWritePromiseCompleteDifferentEventLoopPreservesOrder() throws InterruptedException { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final CountDownLatch messageLatch = new CountDownLatch(2); + final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]); + final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]); + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + final AtomicReference serverChannelRef = new AtomicReference(); + + cb.group(group1) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data2.equals(msg)) { + ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void 
channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data.equals(msg)) { + ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + serverChannelRef.set(ch); + serverChannelLatch.countDown(); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + assertTrue(serverChannelLatch.await(5, SECONDS)); + + final Channel ccCpy = cc; + // Make sure a write operation is executed in the eventloop + cc.pipeline().lastContext().executor().execute(new Runnable() { + @Override + public void run() { + ChannelPromise promise = ccCpy.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Channel serverChannelCpy = serverChannelRef.get(); + serverChannelCpy.writeAndFlush(data2.retainedDuplicate(), serverChannelCpy.newPromise()); + } + }); + ccCpy.writeAndFlush(data.retainedDuplicate(), promise); + } + }); + + assertTrue(messageLatch.await(5, SECONDS)); + } finally { + closeChannel(cc); + closeChannel(sc); + data.release(); + data2.release(); + } + } + + @Test + public void testPeerWriteInWritePromiseCompleteSameEventLoopPreservesOrder() throws InterruptedException { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final CountDownLatch messageLatch = new CountDownLatch(2); + final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]); + final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]); + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + final AtomicReference serverChannelRef = new AtomicReference(); + + try { + cb.group(sharedGroup) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void 
channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data2.equals(msg) && messageLatch.getCount() == 1) { + ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + + sb.group(sharedGroup) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data.equals(msg) && messageLatch.getCount() == 2) { + ReferenceCountUtil.safeRelease(msg); + messageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + serverChannelRef.set(ch); + serverChannelLatch.countDown(); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + assertTrue(serverChannelLatch.await(5, SECONDS)); + + final Channel ccCpy = cc; + // Make sure a write operation is executed in the eventloop + cc.pipeline().lastContext().executor().execute(new Runnable() { + @Override + public void run() { + ChannelPromise promise = ccCpy.newPromise(); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Channel serverChannelCpy = serverChannelRef.get(); + serverChannelCpy.writeAndFlush( + data2.retainedDuplicate(), serverChannelCpy.newPromise()); + } + }); + ccCpy.writeAndFlush(data.retainedDuplicate(), promise); + } + }); + + assertTrue(messageLatch.await(5, SECONDS)); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } finally { + data.release(); + data2.release(); + } + } + + @Test + public void testWriteWhilePeerIsClosedReleaseObjectAndFailPromise() throws InterruptedException 
{ + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + final CountDownLatch serverMessageLatch = new CountDownLatch(1); + final LatchChannelFutureListener serverChannelCloseLatch = new LatchChannelFutureListener(1); + final LatchChannelFutureListener clientChannelCloseLatch = new LatchChannelFutureListener(1); + final CountDownLatch writeFailLatch = new CountDownLatch(1); + final ByteBuf data = Unpooled.wrappedBuffer(new byte[1024]); + final ByteBuf data2 = Unpooled.wrappedBuffer(new byte[512]); + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + final AtomicReference serverChannelRef = new AtomicReference(); + + try { + cb.group(group1) + .channel(LocalChannel.class) + .handler(new TestHandler()); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (data.equals(msg)) { + ReferenceCountUtil.safeRelease(msg); + serverMessageLatch.countDown(); + } else { + super.channelRead(ctx, msg); + } + } + }); + serverChannelRef.set(ch); + serverChannelLatch.countDown(); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).syncUninterruptibly().channel(); + + // Connect to the server + cc = cb.connect(sc.localAddress()).syncUninterruptibly().channel(); + assertTrue(serverChannelLatch.await(5, SECONDS)); + + final Channel ccCpy = cc; + final Channel serverChannelCpy = serverChannelRef.get(); + serverChannelCpy.closeFuture().addListener(serverChannelCloseLatch); + ccCpy.closeFuture().addListener(clientChannelCloseLatch); + + // Make sure a write operation is executed in the eventloop + cc.pipeline().lastContext().executor().execute(new Runnable() { + @Override + public void run() { + 
ccCpy.writeAndFlush(data.retainedDuplicate(), ccCpy.newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + serverChannelCpy.eventLoop().execute(new Runnable() { + @Override + public void run() { + // The point of this test is to write while the peer is closed, so we should + // ensure the peer is actually closed before we write. + int waitCount = 0; + while (ccCpy.isOpen()) { + try { + Thread.sleep(50); + } catch (InterruptedException ignored) { + // ignored + } + if (++waitCount > 5) { + fail(); + } + } + serverChannelCpy.writeAndFlush(data2.retainedDuplicate(), + serverChannelCpy.newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess() && + future.cause() instanceof ClosedChannelException) { + writeFailLatch.countDown(); + } + } + }); + } + }); + ccCpy.close(); + } + }); + } + }); + + assertTrue(serverMessageLatch.await(5, SECONDS)); + assertTrue(writeFailLatch.await(5, SECONDS)); + assertTrue(serverChannelCloseLatch.await(5, SECONDS)); + assertTrue(clientChannelCloseLatch.await(5, SECONDS)); + assertFalse(ccCpy.isOpen()); + assertFalse(serverChannelCpy.isOpen()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } finally { + data.release(); + data2.release(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testConnectFutureBeforeChannelActive() throws Exception { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + cb.group(group1) + .channel(LocalChannel.class) + .handler(new ChannelInboundHandlerAdapter()); + + sb.group(group2) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new TestHandler()); + } + }); + + Channel sc = null; + Channel cc = null; + try 
{ + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + cc = cb.register().sync().channel(); + + final ChannelPromise promise = cc.newPromise(); + final Promise assertPromise = cc.eventLoop().newPromise(); + + cc.pipeline().addLast(new TestHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // Ensure the promise was done before the handler method is triggered. + if (promise.isDone()) { + assertPromise.setSuccess(null); + } else { + assertPromise.setFailure(new AssertionError("connect promise should be done")); + } + } + }); + // Connect to the server + cc.connect(sc.localAddress(), promise).sync(); + + assertPromise.syncUninterruptibly(); + assertTrue(promise.isSuccess()); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + + @Test + public void testConnectionRefused() { + final Bootstrap sb = new Bootstrap(); + sb.group(group1) + .channel(LocalChannel.class) + .handler(new TestHandler()); + assertThrows(ConnectException.class, new Executable() { + @Override + public void execute() { + sb.connect(LocalAddress.ANY).syncUninterruptibly(); + } + }); + } + + private static final class LatchChannelFutureListener extends CountDownLatch implements ChannelFutureListener { + private LatchChannelFutureListener(int count) { + super(count); + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + countDown(); + } + } + + private static void closeChannel(Channel cc) { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + } + + static class TestHandler extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + logger.info(String.format("Received message: %s", msg)); + ReferenceCountUtil.safeRelease(msg); + } + } + + @Test + public void testNotLeakBuffersWhenCloseByRemotePeer() throws Exception { + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + 
cb.group(sharedGroup) + .channel(LocalChannel.class) + .handler(new SimpleChannelInboundHandler() { + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(ctx.alloc().buffer().writeZero(100)); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception { + // Just drop the buffer + } + }); + + sb.group(sharedGroup) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + + @Override + public void channelRead0(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception { + while (buffer.isReadable()) { + // Fill the ChannelOutboundBuffer with multiple buffers + ctx.write(buffer.readRetainedSlice(1)); + } + // Flush and so transfer the written buffers to the inboundBuffer of the remote peer. + // After this point the remote peer is responsible to release all the buffers. + ctx.flush(); + // This close call will trigger the remote peer close as well. 
+ ctx.close(); + } + }); + } + }); + + Channel sc = null; + LocalChannel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + + // Connect to the server + cc = (LocalChannel) cb.connect(sc.localAddress()).sync().channel(); + + // Close the channel + closeChannel(cc); + assertTrue(cc.inboundBuffer.isEmpty()); + closeChannel(sc); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + + private static void writeAndFlushReadOnSuccess(final ChannelHandlerContext ctx, Object msg) { + ctx.writeAndFlush(msg).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + ctx.read(); + } + } + }); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testAutoReadDisabledSharedGroup() throws Exception { + testAutoReadDisabled(sharedGroup, sharedGroup); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testAutoReadDisabledDifferentGroup() throws Exception { + testAutoReadDisabled(group1, group2); + } + + private static void testAutoReadDisabled(EventLoopGroup serverGroup, EventLoopGroup clientGroup) throws Exception { + final CountDownLatch latch = new CountDownLatch(100); + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + cb.group(serverGroup) + .channel(LocalChannel.class) + .option(ChannelOption.AUTO_READ, false) + .handler(new ChannelInboundHandlerAdapter() { + + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + writeAndFlushReadOnSuccess(ctx, "test"); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { + writeAndFlushReadOnSuccess(ctx, msg); + } + }); + + sb.group(clientGroup) + .channel(LocalServerChannel.class) + .childOption(ChannelOption.AUTO_READ, false) + .childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(final 
ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { + latch.countDown(); + if (latch.getCount() > 0) { + writeAndFlushReadOnSuccess(ctx, msg); + } + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + cc = cb.connect(TEST_ADDRESS).sync().channel(); + + latch.await(); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testMaxMessagesPerReadRespectedWithAutoReadSharedGroup() throws Exception { + testMaxMessagesPerReadRespected(sharedGroup, sharedGroup, true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testMaxMessagesPerReadRespectedWithoutAutoReadSharedGroup() throws Exception { + testMaxMessagesPerReadRespected(sharedGroup, sharedGroup, false); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testMaxMessagesPerReadRespectedWithAutoReadDifferentGroup() throws Exception { + testMaxMessagesPerReadRespected(group1, group2, true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testMaxMessagesPerReadRespectedWithoutAutoReadDifferentGroup() throws Exception { + testMaxMessagesPerReadRespected(group1, group2, false); + } + + private static void testMaxMessagesPerReadRespected( + EventLoopGroup serverGroup, EventLoopGroup clientGroup, final boolean autoRead) throws Exception { + final CountDownLatch countDownLatch = new CountDownLatch(5); + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + cb.group(serverGroup) + .channel(LocalChannel.class) + .option(ChannelOption.AUTO_READ, autoRead) + .option(ChannelOption.MAX_MESSAGES_PER_READ, 1) + .handler(new ChannelReadHandler(countDownLatch, autoRead)); + sb.group(clientGroup) + .channel(LocalServerChannel.class) + 
.childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(final ChannelHandlerContext ctx) { + for (int i = 0; i < 10; i++) { + ctx.write(i); + } + ctx.flush(); + } + }); + + Channel sc = null; + Channel cc = null; + try { + // Start server + sc = sb.bind(TEST_ADDRESS).sync().channel(); + cc = cb.connect(TEST_ADDRESS).sync().channel(); + + countDownLatch.await(); + } finally { + closeChannel(cc); + closeChannel(sc); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testServerMaxMessagesPerReadRespectedWithAutoReadSharedGroup() throws Exception { + testServerMaxMessagesPerReadRespected(sharedGroup, sharedGroup, true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testServerMaxMessagesPerReadRespectedWithoutAutoReadSharedGroup() throws Exception { + testServerMaxMessagesPerReadRespected(sharedGroup, sharedGroup, false); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testServerMaxMessagesPerReadRespectedWithAutoReadDifferentGroup() throws Exception { + testServerMaxMessagesPerReadRespected(group1, group2, true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testServerMaxMessagesPerReadRespectedWithoutAutoReadDifferentGroup() throws Exception { + testServerMaxMessagesPerReadRespected(group1, group2, false); + } + + private void testServerMaxMessagesPerReadRespected( + EventLoopGroup serverGroup, EventLoopGroup clientGroup, final boolean autoRead) throws Exception { + final CountDownLatch countDownLatch = new CountDownLatch(5); + Bootstrap cb = new Bootstrap(); + ServerBootstrap sb = new ServerBootstrap(); + + cb.group(serverGroup) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + // NOOP + } + }); + + sb.group(clientGroup) + .channel(LocalServerChannel.class) + .option(ChannelOption.AUTO_READ, autoRead) + 
/**
 * Handler used by the MAX_MESSAGES_PER_READ tests: callers install it on a channel
 * configured with {@code ChannelOption.MAX_MESSAGES_PER_READ = 1}, and this handler
 * asserts that at most one channelRead is delivered per read cycle (i.e. between a
 * read being issued and the following channelReadComplete).
 */
private static final class ChannelReadHandler extends ChannelInboundHandlerAdapter {

    // Counted down once per completed read batch; the test waits on it.
    private final CountDownLatch latch;
    // Whether the channel uses AUTO_READ; when false, reads are issued manually.
    private final boolean autoRead;
    // Number of channelRead calls seen in the current read cycle; must never exceed 1.
    private int read;

    ChannelReadHandler(CountDownLatch latch, boolean autoRead) {
        this.latch = latch;
        this.autoRead = autoRead;
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) {
        if (!autoRead) {
            // Kick off the first read manually when auto-read is disabled.
            ctx.read();
        }
        ctx.fireChannelActive();
    }

    @Override
    public void channelRead(final ChannelHandlerContext ctx, Object msg) {
        // Exactly one message is allowed per read cycle.
        assertEquals(0, read);
        read++;
        ctx.fireChannelRead(msg);
    }

    @Override
    public void channelReadComplete(final ChannelHandlerContext ctx) {
        assertEquals(1, read);
        latch.countDown();
        if (latch.getCount() > 0) {
            if (!autoRead) {
                // The read will be scheduled 100ms in the future to ensure we do not
                // receive any channelRead calls in the meantime.
                ctx.executor().schedule(new Runnable() {
                    @Override
                    public void run() {
                        read = 0;
                        ctx.read();
                    }
                }, 100, TimeUnit.MILLISECONDS);
            } else {
                read = 0;
            }
        } else {
            read = 0;
        }
        ctx.fireChannelReadComplete();
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
        // Propagate the failure and close the channel so the test fails fast.
        ctx.fireExceptionCaught(cause);
        ctx.close();
    }
}
a/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java b/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java new file mode 100644 index 0000000..49ff00b --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest.java @@ -0,0 +1,610 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.local; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.EventExecutorGroup; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.HashSet; +import java.util.Queue; +import java.util.Set; +import 
java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LocalTransportThreadModelTest { + + private static EventLoopGroup group; + private static LocalAddress localAddr; + + @BeforeAll + public static void init() { + // Configure a test server + group = new DefaultEventLoopGroup(); + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + // Discard + ReferenceCountUtil.release(msg); + } + }); + } + }); + + localAddr = (LocalAddress) sb.bind(LocalAddress.ANY).syncUninterruptibly().channel().localAddress(); + } + + @AfterAll + public static void destroy() throws Exception { + group.shutdownGracefully().sync(); + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + @Disabled("regression test") + public void testStagedExecutionMultiple() throws Throwable { + for (int i = 0; i < 10; i ++) { + testStagedExecution(); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testStagedExecution() throws Throwable { + EventLoopGroup l = new DefaultEventLoopGroup(4, new DefaultThreadFactory("l")); + EventExecutorGroup e1 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e1")); + EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e2")); + ThreadNameAuditor h1 = new ThreadNameAuditor(); + ThreadNameAuditor h2 = new ThreadNameAuditor(); + ThreadNameAuditor h3 
= new ThreadNameAuditor(true); + + Channel ch = new LocalChannel(); + // With no EventExecutor specified, h1 will be always invoked by EventLoop 'l'. + ch.pipeline().addLast(h1); + // h2 will be always invoked by EventExecutor 'e1'. + ch.pipeline().addLast(e1, h2); + // h3 will be always invoked by EventExecutor 'e2'. + ch.pipeline().addLast(e2, h3); + + l.register(ch).sync().channel().connect(localAddr).sync(); + + // Fire inbound events from all possible starting points. + ch.pipeline().fireChannelRead("1"); + ch.pipeline().context(h1).fireChannelRead("2"); + ch.pipeline().context(h2).fireChannelRead("3"); + ch.pipeline().context(h3).fireChannelRead("4"); + // Fire outbound events from all possible starting points. + ch.pipeline().write("5"); + ch.pipeline().context(h3).write("6"); + ch.pipeline().context(h2).write("7"); + ch.pipeline().context(h1).writeAndFlush("8").sync(); + + ch.close().sync(); + + // Wait until all events are handled completely. + while (h1.outboundThreadNames.size() < 3 || h3.inboundThreadNames.size() < 3 || + h1.removalThreadNames.size() < 1) { + if (h1.exception.get() != null) { + throw h1.exception.get(); + } + if (h2.exception.get() != null) { + throw h2.exception.get(); + } + if (h3.exception.get() != null) { + throw h3.exception.get(); + } + + Thread.sleep(10); + } + + String currentName = Thread.currentThread().getName(); + + try { + // Events should never be handled from the current thread. 
+ assertFalse(h1.inboundThreadNames.contains(currentName)); + assertFalse(h2.inboundThreadNames.contains(currentName)); + assertFalse(h3.inboundThreadNames.contains(currentName)); + assertFalse(h1.outboundThreadNames.contains(currentName)); + assertFalse(h2.outboundThreadNames.contains(currentName)); + assertFalse(h3.outboundThreadNames.contains(currentName)); + assertFalse(h1.removalThreadNames.contains(currentName)); + assertFalse(h2.removalThreadNames.contains(currentName)); + assertFalse(h3.removalThreadNames.contains(currentName)); + + // Assert that events were handled by the correct executor. + for (String name: h1.inboundThreadNames) { + assertTrue(name.startsWith("l-")); + } + for (String name: h2.inboundThreadNames) { + assertTrue(name.startsWith("e1-")); + } + for (String name: h3.inboundThreadNames) { + assertTrue(name.startsWith("e2-")); + } + for (String name: h1.outboundThreadNames) { + assertTrue(name.startsWith("l-")); + } + for (String name: h2.outboundThreadNames) { + assertTrue(name.startsWith("e1-")); + } + for (String name: h3.outboundThreadNames) { + assertTrue(name.startsWith("e2-")); + } + for (String name: h1.removalThreadNames) { + assertTrue(name.startsWith("l-")); + } + for (String name: h2.removalThreadNames) { + assertTrue(name.startsWith("e1-")); + } + for (String name: h3.removalThreadNames) { + assertTrue(name.startsWith("e2-")); + } + + // Assert that the events for the same handler were handled by the same thread. 
+ Set names = new HashSet(); + names.addAll(h1.inboundThreadNames); + names.addAll(h1.outboundThreadNames); + names.addAll(h1.removalThreadNames); + assertEquals(1, names.size()); + + names.clear(); + names.addAll(h2.inboundThreadNames); + names.addAll(h2.outboundThreadNames); + names.addAll(h2.removalThreadNames); + assertEquals(1, names.size()); + + names.clear(); + names.addAll(h3.inboundThreadNames); + names.addAll(h3.outboundThreadNames); + names.addAll(h3.removalThreadNames); + assertEquals(1, names.size()); + + // Count the number of events + assertEquals(1, h1.inboundThreadNames.size()); + assertEquals(2, h2.inboundThreadNames.size()); + assertEquals(3, h3.inboundThreadNames.size()); + assertEquals(3, h1.outboundThreadNames.size()); + assertEquals(2, h2.outboundThreadNames.size()); + assertEquals(1, h3.outboundThreadNames.size()); + assertEquals(1, h1.removalThreadNames.size()); + assertEquals(1, h2.removalThreadNames.size()); + assertEquals(1, h3.removalThreadNames.size()); + } catch (AssertionError e) { + System.out.println("H1I: " + h1.inboundThreadNames); + System.out.println("H2I: " + h2.inboundThreadNames); + System.out.println("H3I: " + h3.inboundThreadNames); + System.out.println("H1O: " + h1.outboundThreadNames); + System.out.println("H2O: " + h2.outboundThreadNames); + System.out.println("H3O: " + h3.outboundThreadNames); + System.out.println("H1R: " + h1.removalThreadNames); + System.out.println("H2R: " + h2.removalThreadNames); + System.out.println("H3R: " + h3.removalThreadNames); + throw e; + } finally { + l.shutdownGracefully(); + e1.shutdownGracefully(); + e2.shutdownGracefully(); + + l.terminationFuture().sync(); + e1.terminationFuture().sync(); + e2.terminationFuture().sync(); + } + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + @Disabled + public void testConcurrentMessageBufferAccess() throws Throwable { + EventLoopGroup l = new DefaultEventLoopGroup(4, new DefaultThreadFactory("l")); + EventExecutorGroup e1 = new 
DefaultEventExecutorGroup(4, new DefaultThreadFactory("e1")); + EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e2")); + EventExecutorGroup e3 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e3")); + EventExecutorGroup e4 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e4")); + EventExecutorGroup e5 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e5")); + + try { + final MessageForwarder1 h1 = new MessageForwarder1(); + final MessageForwarder2 h2 = new MessageForwarder2(); + final MessageForwarder3 h3 = new MessageForwarder3(); + final MessageForwarder1 h4 = new MessageForwarder1(); + final MessageForwarder2 h5 = new MessageForwarder2(); + final MessageDiscarder h6 = new MessageDiscarder(); + + final Channel ch = new LocalChannel(); + + // inbound: int -> byte[4] -> int -> int -> byte[4] -> int -> /dev/null + // outbound: int -> int -> byte[4] -> int -> int -> byte[4] -> /dev/null + ch.pipeline().addLast(h1) + .addLast(e1, h2) + .addLast(e2, h3) + .addLast(e3, h4) + .addLast(e4, h5) + .addLast(e5, h6); + + l.register(ch).sync().channel().connect(localAddr).sync(); + + final int ROUNDS = 1024; + final int ELEMS_PER_ROUNDS = 8192; + final int TOTAL_CNT = ROUNDS * ELEMS_PER_ROUNDS; + for (int i = 0; i < TOTAL_CNT;) { + final int start = i; + final int end = i + ELEMS_PER_ROUNDS; + i = end; + + ch.eventLoop().execute(new Runnable() { + @Override + public void run() { + for (int j = start; j < end; j ++) { + ch.pipeline().fireChannelRead(Integer.valueOf(j)); + } + } + }); + } + + while (h1.inCnt < TOTAL_CNT || h2.inCnt < TOTAL_CNT || h3.inCnt < TOTAL_CNT || + h4.inCnt < TOTAL_CNT || h5.inCnt < TOTAL_CNT || h6.inCnt < TOTAL_CNT) { + if (h1.exception.get() != null) { + throw h1.exception.get(); + } + if (h2.exception.get() != null) { + throw h2.exception.get(); + } + if (h3.exception.get() != null) { + throw h3.exception.get(); + } + if (h4.exception.get() != null) { + throw h4.exception.get(); + } 
/**
 * Duplex handler that records the name of the thread on which each inbound,
 * outbound and removal event is handled, so tests can assert executor affinity
 * (which EventLoop/EventExecutor handled which event).
 */
private static class ThreadNameAuditor extends ChannelDuplexHandler {

    // First Throwable observed via exceptionCaught. Generic parameters appear stripped
    // in this rendering; presumably AtomicReference<Throwable> — TODO confirm upstream.
    private final AtomicReference exception = new AtomicReference();

    // Thread names per event direction; concurrent queues because events may be
    // delivered from different executor threads.
    private final Queue inboundThreadNames = new ConcurrentLinkedQueue();
    private final Queue outboundThreadNames = new ConcurrentLinkedQueue();
    private final Queue removalThreadNames = new ConcurrentLinkedQueue();
    // When true, inbound messages are not forwarded to the next handler.
    private final boolean discard;

    ThreadNameAuditor() {
        this(false);
    }

    ThreadNameAuditor(boolean discard) {
        this.discard = discard;
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        // Record which thread performs the handler removal callback.
        removalThreadNames.add(Thread.currentThread().getName());
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        inboundThreadNames.add(Thread.currentThread().getName());
        if (!discard) {
            ctx.fireChannelRead(msg);
        }
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        outboundThreadNames.add(Thread.currentThread().getName());
        ctx.write(msg, promise);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Remember only the first failure; print for diagnostics and keep propagating.
        exception.compareAndSet(null, cause);
        System.err.print('[' + Thread.currentThread().getName() + "] ");
        cause.printStackTrace();
        super.exceptionCaught(ctx, cause);
    }
}
+ boolean swallow = this == ctx.pipeline().first(); + + ByteBuf m = (ByteBuf) msg; + int count = m.readableBytes() / 4; + for (int j = 0; j < count; j ++) { + int actual = m.readInt(); + int expected = outCnt ++; + assertEquals(expected, actual); + if (!swallow) { + ctx.write(actual); + } + } + ctx.writeAndFlush(Unpooled.EMPTY_BUFFER, promise); + m.release(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + //System.err.print("[" + Thread.currentThread().getName() + "] "); + //cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } + } + + /** + * Converts a binary stream into integers. + */ + private static class MessageForwarder2 extends ChannelDuplexHandler { + + private final AtomicReference exception = new AtomicReference(); + private volatile int inCnt; + private volatile int outCnt; + private volatile Thread t; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + Thread t = this.t; + if (t == null) { + this.t = Thread.currentThread(); + } else { + assertSame(t, Thread.currentThread()); + } + + ByteBuf m = (ByteBuf) msg; + int count = m.readableBytes() / 4; + for (int j = 0; j < count; j ++) { + int actual = m.readInt(); + int expected = inCnt ++; + assertEquals(expected, actual); + ctx.fireChannelRead(actual); + } + m.release(); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assertSame(t, Thread.currentThread()); + + ByteBuf out = ctx.alloc().buffer(4); + int m = (Integer) msg; + int expected = outCnt ++; + assertEquals(expected, m); + out.writeInt(m); + + ctx.write(out, promise); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + //System.err.print("[" + Thread.currentThread().getName() + "] "); + //cause.printStackTrace(); + 
super.exceptionCaught(ctx, cause); + } + } + + /** + * Simply forwards the received object to the next handler. + */ + private static class MessageForwarder3 extends ChannelDuplexHandler { + + private final AtomicReference exception = new AtomicReference(); + private volatile int inCnt; + private volatile int outCnt; + private volatile Thread t; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + Thread t = this.t; + if (t == null) { + this.t = Thread.currentThread(); + } else { + assertSame(t, Thread.currentThread()); + } + + int actual = (Integer) msg; + int expected = inCnt ++; + assertEquals(expected, actual); + + ctx.fireChannelRead(msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assertSame(t, Thread.currentThread()); + + int actual = (Integer) msg; + int expected = outCnt ++; + assertEquals(expected, actual); + + ctx.write(msg, promise); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + System.err.print('[' + Thread.currentThread().getName() + "] "); + cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } + } + + /** + * Discards all received messages. 
+ */ + private static class MessageDiscarder extends ChannelDuplexHandler { + + private final AtomicReference exception = new AtomicReference(); + private volatile int inCnt; + private volatile int outCnt; + private volatile Thread t; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + Thread t = this.t; + if (t == null) { + this.t = Thread.currentThread(); + } else { + assertSame(t, Thread.currentThread()); + } + + int actual = (Integer) msg; + int expected = inCnt ++; + assertEquals(expected, actual); + } + + @Override + public void write( + ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + assertSame(t, Thread.currentThread()); + + int actual = (Integer) msg; + int expected = outCnt ++; + assertEquals(expected, actual); + ctx.write(msg, promise); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + exception.compareAndSet(null, cause); + //System.err.print("[" + Thread.currentThread().getName() + "] "); + //cause.printStackTrace(); + super.exceptionCaught(ctx, cause); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest2.java b/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest2.java new file mode 100644 index 0000000..c357557 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest2.java @@ -0,0 +1,125 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.local; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class LocalTransportThreadModelTest2 { + + private static final String LOCAL_CHANNEL = LocalTransportThreadModelTest2.class.getName(); + + static final int messageCountPerRun = 4; + + @Test + @Timeout(value = 15000, unit = TimeUnit.MILLISECONDS) + public void testSocketReuse() throws InterruptedException { + ServerBootstrap serverBootstrap = new ServerBootstrap(); + LocalHandler serverHandler = new LocalHandler("SERVER"); + serverBootstrap + .group(new DefaultEventLoopGroup(), new DefaultEventLoopGroup()) + .channel(LocalServerChannel.class) + .childHandler(serverHandler); + + Bootstrap clientBootstrap = new Bootstrap(); + LocalHandler clientHandler = new LocalHandler("CLIENT"); + clientBootstrap + .group(new DefaultEventLoopGroup()) + .channel(LocalChannel.class) + .remoteAddress(new LocalAddress(LOCAL_CHANNEL)).handler(clientHandler); + + serverBootstrap.bind(new 
LocalAddress(LOCAL_CHANNEL)).sync(); + + int count = 100; + for (int i = 1; i < count + 1; i ++) { + Channel ch = clientBootstrap.connect().sync().channel(); + + // SPIN until we get what we are looking for. + int target = i * messageCountPerRun; + while (serverHandler.count.get() != target || clientHandler.count.get() != target) { + Thread.sleep(50); + } + close(ch, clientHandler); + } + + assertEquals(count * 2 * messageCountPerRun, serverHandler.count.get() + + clientHandler.count.get()); + } + + public void close(final Channel localChannel, final LocalHandler localRegistrationHandler) { + // we want to make sure we actually shutdown IN the event loop + if (localChannel.eventLoop().inEventLoop()) { + // Wait until all messages are flushed before closing the channel. + if (localRegistrationHandler.lastWriteFuture != null) { + localRegistrationHandler.lastWriteFuture.awaitUninterruptibly(); + } + + localChannel.close(); + return; + } + + localChannel.eventLoop().execute(new Runnable() { + @Override + public void run() { + close(localChannel, localRegistrationHandler); + } + }); + + // Wait until the connection is closed or the connection attempt fails. 
+ localChannel.closeFuture().awaitUninterruptibly(); + } + + @Sharable + static class LocalHandler extends ChannelInboundHandlerAdapter { + private final String name; + + public volatile ChannelFuture lastWriteFuture; + + public final AtomicInteger count = new AtomicInteger(0); + + LocalHandler(String name) { + this.name = name; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + for (int i = 0; i < messageCountPerRun; i ++) { + lastWriteFuture = ctx.channel().write(name + ' ' + i); + } + ctx.channel().flush(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + count.incrementAndGet(); + ReferenceCountUtil.release(msg); + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest3.java b/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest3.java new file mode 100644 index 0000000..9c229af --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/local/LocalTransportThreadModelTest3.java @@ -0,0 +1,336 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.local; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.DefaultEventExecutorGroup; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.EventExecutorGroup; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.Deque; +import java.util.LinkedList; +import java.util.Queue; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LocalTransportThreadModelTest3 { + + enum EventType { + EXCEPTION_CAUGHT, + USER_EVENT, + MESSAGE_RECEIVED_LAST, + INACTIVE, + ACTIVE, + UNREGISTERED, + REGISTERED, + MESSAGE_RECEIVED, + WRITE, + READ + } + + private static EventLoopGroup group; + private static LocalAddress localAddr; + + @BeforeAll + public static void init() { + // Configure a test server + group = new DefaultEventLoopGroup(); + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, 
Object msg) { + // Discard + ReferenceCountUtil.release(msg); + } + }); + } + }); + + localAddr = (LocalAddress) sb.bind(LocalAddress.ANY).syncUninterruptibly().channel().localAddress(); + } + + @AfterAll + public static void destroy() throws Exception { + group.shutdownGracefully().sync(); + } + + @Test + @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS) + @Disabled("regression test") + public void testConcurrentAddRemoveInboundEventsMultiple() throws Throwable { + for (int i = 0; i < 50; i ++) { + testConcurrentAddRemoveInboundEvents(); + } + } + + @Test + @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS) + @Disabled("regression test") + public void testConcurrentAddRemoveOutboundEventsMultiple() throws Throwable { + for (int i = 0; i < 50; i ++) { + testConcurrentAddRemoveOutboundEvents(); + } + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + @Disabled("needs a fix") + public void testConcurrentAddRemoveInboundEvents() throws Throwable { + testConcurrentAddRemove(true); + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + @Disabled("needs a fix") + public void testConcurrentAddRemoveOutboundEvents() throws Throwable { + testConcurrentAddRemove(false); + } + + private static void testConcurrentAddRemove(boolean inbound) throws Exception { + EventLoopGroup l = new DefaultEventLoopGroup(4, new DefaultThreadFactory("l")); + EventExecutorGroup e1 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e1")); + EventExecutorGroup e2 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e2")); + EventExecutorGroup e3 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e3")); + EventExecutorGroup e4 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e4")); + EventExecutorGroup e5 = new DefaultEventExecutorGroup(4, new DefaultThreadFactory("e5")); + + final EventExecutorGroup[] groups = {e1, e2, e3, e4, e5}; + try { + Deque events = new ConcurrentLinkedDeque(); + final EventForwarder h1 = new 
EventForwarder(); + final EventForwarder h2 = new EventForwarder(); + final EventForwarder h3 = new EventForwarder(); + final EventForwarder h4 = new EventForwarder(); + final EventForwarder h5 = new EventForwarder(); + final EventRecorder h6 = new EventRecorder(events, inbound); + + final Channel ch = new LocalChannel(); + if (!inbound) { + ch.config().setAutoRead(false); + } + ch.pipeline().addLast(e1, h1) + .addLast(e1, h2) + .addLast(e1, h3) + .addLast(e1, h4) + .addLast(e1, h5) + .addLast(e1, "recorder", h6); + + l.register(ch).sync().channel().connect(localAddr).sync(); + + final LinkedList expectedEvents = events(inbound, 8192); + + Throwable cause = new Throwable(); + + Thread pipelineModifier = new Thread(new Runnable() { + @Override + public void run() { + Random random = new Random(); + + while (true) { + try { + Thread.sleep(100); + } catch (InterruptedException e) { + return; + } + if (!ch.isRegistered()) { + continue; + } + //EventForwardHandler forwardHandler = forwarders[random.nextInt(forwarders.length)]; + ChannelHandler handler = ch.pipeline().removeFirst(); + ch.pipeline().addBefore(groups[random.nextInt(groups.length)], "recorder", + UUID.randomUUID().toString(), handler); + } + } + }); + pipelineModifier.setDaemon(true); + pipelineModifier.start(); + for (EventType event: expectedEvents) { + switch (event) { + case EXCEPTION_CAUGHT: + ch.pipeline().fireExceptionCaught(cause); + break; + case MESSAGE_RECEIVED: + ch.pipeline().fireChannelRead(""); + break; + case MESSAGE_RECEIVED_LAST: + ch.pipeline().fireChannelReadComplete(); + break; + case USER_EVENT: + ch.pipeline().fireUserEventTriggered(""); + break; + case WRITE: + ch.pipeline().write(""); + break; + case READ: + ch.pipeline().read(); + break; + } + } + + ch.close().sync(); + + while (events.peekLast() != EventType.UNREGISTERED) { + Thread.sleep(10); + } + + expectedEvents.addFirst(EventType.ACTIVE); + expectedEvents.addFirst(EventType.REGISTERED); + 
expectedEvents.addLast(EventType.INACTIVE); + expectedEvents.addLast(EventType.UNREGISTERED); + + for (;;) { + EventType event = events.poll(); + if (event == null) { + assertTrue(expectedEvents.isEmpty(), "Missing events:" + expectedEvents); + break; + } + assertEquals(event, expectedEvents.poll()); + } + } finally { + l.shutdownGracefully(); + e1.shutdownGracefully(); + e2.shutdownGracefully(); + e3.shutdownGracefully(); + e4.shutdownGracefully(); + e5.shutdownGracefully(); + + l.terminationFuture().sync(); + e1.terminationFuture().sync(); + e2.terminationFuture().sync(); + e3.terminationFuture().sync(); + e4.terminationFuture().sync(); + e5.terminationFuture().sync(); + } + } + + private static LinkedList events(boolean inbound, int size) { + EventType[] events; + if (inbound) { + events = new EventType[] { + EventType.USER_EVENT, EventType.MESSAGE_RECEIVED, EventType.MESSAGE_RECEIVED_LAST, + EventType.EXCEPTION_CAUGHT}; + } else { + events = new EventType[] { + EventType.READ, EventType.WRITE, EventType.EXCEPTION_CAUGHT }; + } + + Random random = new Random(); + LinkedList expectedEvents = new LinkedList(); + for (int i = 0; i < size; i++) { + expectedEvents.add(events[random.nextInt(events.length)]); + } + return expectedEvents; + } + + @ChannelHandler.Sharable + private static final class EventForwarder extends ChannelDuplexHandler { } + + private static final class EventRecorder extends ChannelDuplexHandler { + private final Queue events; + private final boolean inbound; + + EventRecorder(Queue events, boolean inbound) { + this.events = events; + this.inbound = inbound; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + events.add(EventType.EXCEPTION_CAUGHT); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (inbound) { + events.add(EventType.USER_EVENT); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext 
ctx) throws Exception { + if (inbound) { + events.add(EventType.MESSAGE_RECEIVED_LAST); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + events.add(EventType.INACTIVE); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + events.add(EventType.ACTIVE); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + events.add(EventType.UNREGISTERED); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + events.add(EventType.REGISTERED); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (inbound) { + events.add(EventType.MESSAGE_RECEIVED); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (!inbound) { + events.add(EventType.WRITE); + } + promise.setSuccess(); + } + + @Override + public void read(ChannelHandlerContext ctx) { + if (!inbound) { + events.add(EventType.READ); + } + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/nio/NioEventLoopTest.java b/netty-channel/src/test/java/io/netty/channel/nio/NioEventLoopTest.java new file mode 100644 index 0000000..72a92e7 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/nio/NioEventLoopTest.java @@ -0,0 +1,348 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.nio; + +import io.netty.channel.AbstractEventLoopTest; +import io.netty.channel.Channel; +import io.netty.channel.DefaultSelectStrategyFactory; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.EventLoopTaskQueueFactory; +import io.netty.channel.SelectStrategy; +import io.netty.channel.SelectStrategyFactory; +import io.netty.channel.SingleThreadEventLoop; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.util.IntSupplier; +import io.netty.util.concurrent.DefaultEventExecutorChooserFactory; +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.RejectedExecutionHandlers; +import io.netty.util.concurrent.ThreadPerTaskExecutor; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.IOException; +import java.net.InetSocketAddress; +import java.nio.channels.SelectionKey; +import java.nio.channels.Selector; +import java.nio.channels.SocketChannel; +import java.nio.channels.spi.SelectorProvider; +import java.util.Queue; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; 
+import static org.junit.jupiter.api.Assertions.assertTrue; + +public class NioEventLoopTest extends AbstractEventLoopTest { + + @Override + protected EventLoopGroup newEventLoopGroup() { + return new NioEventLoopGroup(); + } + + @Override + protected Class newChannel() { + return NioServerSocketChannel.class; + } + + @Test + public void testRebuildSelector() { + EventLoopGroup group = new NioEventLoopGroup(1); + final NioEventLoop loop = (NioEventLoop) group.next(); + try { + Channel channel = new NioServerSocketChannel(); + loop.register(channel).syncUninterruptibly(); + + Selector selector = loop.unwrappedSelector(); + assertSame(selector, ((NioEventLoop) channel.eventLoop()).unwrappedSelector()); + assertTrue(selector.isOpen()); + + // Submit to the EventLoop so we are sure its really executed in a non-async manner. + loop.submit(new Runnable() { + @Override + public void run() { + loop.rebuildSelector(); + } + }).syncUninterruptibly(); + + Selector newSelector = ((NioEventLoop) channel.eventLoop()).unwrappedSelector(); + assertTrue(newSelector.isOpen()); + assertNotSame(selector, newSelector); + assertFalse(selector.isOpen()); + + channel.close().syncUninterruptibly(); + } finally { + group.shutdownGracefully(); + } + } + + @Test + public void testScheduleBigDelayNotOverflow() { + EventLoopGroup group = new NioEventLoopGroup(1); + + final EventLoop el = group.next(); + Future future = el.schedule(new Runnable() { + @Override + public void run() { + // NOOP + } + }, Long.MAX_VALUE, TimeUnit.MILLISECONDS); + + assertFalse(future.awaitUninterruptibly(1000)); + assertTrue(future.cancel(true)); + group.shutdownGracefully(); + } + + @Test + public void testInterruptEventLoopThread() throws Exception { + EventLoopGroup group = new NioEventLoopGroup(1); + final NioEventLoop loop = (NioEventLoop) group.next(); + try { + Selector selector = loop.unwrappedSelector(); + assertTrue(selector.isOpen()); + + loop.submit(new Runnable() { + @Override + public void run() { + // 
Interrupt the thread which should not end up in a busy spin and + // so the selector should not have been rebuilt. + Thread.currentThread().interrupt(); + } + }).syncUninterruptibly(); + + assertTrue(selector.isOpen()); + + final CountDownLatch latch = new CountDownLatch(2); + loop.submit(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }).syncUninterruptibly(); + + loop.schedule(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }, 2, TimeUnit.SECONDS).syncUninterruptibly(); + + latch.await(); + + assertSame(selector, loop.unwrappedSelector()); + assertTrue(selector.isOpen()); + } finally { + group.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testSelectableChannel() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(1); + NioEventLoop loop = (NioEventLoop) group.next(); + + try { + Channel channel = new NioServerSocketChannel(); + loop.register(channel).syncUninterruptibly(); + channel.bind(new InetSocketAddress(0)).syncUninterruptibly(); + + SocketChannel selectableChannel = SocketChannel.open(); + selectableChannel.configureBlocking(false); + selectableChannel.connect(channel.localAddress()); + + final CountDownLatch latch = new CountDownLatch(1); + + loop.register(selectableChannel, SelectionKey.OP_CONNECT, new NioTask() { + @Override + public void channelReady(SocketChannel ch, SelectionKey key) { + latch.countDown(); + } + + @Override + public void channelUnregistered(SocketChannel ch, Throwable cause) { + } + }); + + latch.await(); + + selectableChannel.close(); + channel.close().syncUninterruptibly(); + } finally { + group.shutdownGracefully(); + } + } + + @SuppressWarnings("deprecation") + @Test + public void testTaskRemovalOnShutdownThrowsNoUnsupportedOperationException() throws Exception { + final AtomicReference error = new AtomicReference(); + final Runnable task = new Runnable() { + @Override + public void run() { + 
// NOOP + } + }; + // Just run often enough to trigger it normally. + for (int i = 0; i < 1000; i++) { + NioEventLoopGroup group = new NioEventLoopGroup(1); + final NioEventLoop loop = (NioEventLoop) group.next(); + + Thread t = new Thread(new Runnable() { + @Override + public void run() { + try { + for (;;) { + loop.execute(task); + } + } catch (Throwable cause) { + error.set(cause); + } + } + }); + t.start(); + group.shutdownNow(); + t.join(); + group.terminationFuture().syncUninterruptibly(); + assertThat(error.get(), instanceOf(RejectedExecutionException.class)); + error.set(null); + } + } + + @Test + public void testRebuildSelectorOnIOException() { + SelectStrategyFactory selectStrategyFactory = new SelectStrategyFactory() { + @Override + public SelectStrategy newSelectStrategy() { + return new SelectStrategy() { + + private boolean thrown; + + @Override + public int calculateStrategy(IntSupplier selectSupplier, boolean hasTasks) throws Exception { + if (!thrown) { + thrown = true; + throw new IOException(); + } + return -1; + } + }; + } + }; + + EventLoopGroup group = new NioEventLoopGroup(1, new DefaultThreadFactory("ioPool"), + SelectorProvider.provider(), selectStrategyFactory); + final NioEventLoop loop = (NioEventLoop) group.next(); + try { + Channel channel = new NioServerSocketChannel(); + Selector selector = loop.unwrappedSelector(); + + loop.register(channel).syncUninterruptibly(); + + Selector newSelector = ((NioEventLoop) channel.eventLoop()).unwrappedSelector(); + assertTrue(newSelector.isOpen()); + assertNotSame(selector, newSelector); + assertFalse(selector.isOpen()); + + channel.close().syncUninterruptibly(); + } finally { + group.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testChannelsRegistered() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(1); + final NioEventLoop loop = (NioEventLoop) group.next(); + + try { + final Channel ch1 = new 
NioServerSocketChannel(); + final Channel ch2 = new NioServerSocketChannel(); + + assertEquals(0, registeredChannels(loop)); + + assertTrue(loop.register(ch1).syncUninterruptibly().isSuccess()); + assertTrue(loop.register(ch2).syncUninterruptibly().isSuccess()); + assertEquals(2, registeredChannels(loop)); + + assertTrue(ch1.deregister().syncUninterruptibly().isSuccess()); + + int registered; + // As SelectionKeys are removed in a lazy fashion in the JDK implementation we may need to query a few + // times before we see the right number of registered channels. + while ((registered = registeredChannels(loop)) == 2) { + Thread.sleep(50); + } + assertEquals(1, registered); + } finally { + group.shutdownGracefully(); + } + } + + // Only reliable if run from event loop + private static int registeredChannels(final SingleThreadEventLoop loop) throws Exception { + return loop.submit(new Callable() { + @Override + public Integer call() { + return loop.registeredChannels(); + } + }).get(1, TimeUnit.SECONDS); + } + + @Test + public void testCustomQueue() { + final AtomicBoolean called = new AtomicBoolean(); + NioEventLoopGroup group = new NioEventLoopGroup(1, + new ThreadPerTaskExecutor(new DefaultThreadFactory(NioEventLoopGroup.class)), + DefaultEventExecutorChooserFactory.INSTANCE, SelectorProvider.provider(), + DefaultSelectStrategyFactory.INSTANCE, RejectedExecutionHandlers.reject(), + new EventLoopTaskQueueFactory() { + @Override + public Queue newTaskQueue(int maxCapacity) { + called.set(true); + return new LinkedBlockingQueue(maxCapacity); + } + }); + + final NioEventLoop loop = (NioEventLoop) group.next(); + + try { + loop.submit(new Runnable() { + @Override + public void run() { + // NOOP. 
+ } + }).syncUninterruptibly(); + assertTrue(called.get()); + } finally { + group.shutdownGracefully(); + } + } + +} diff --git a/netty-channel/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java b/netty-channel/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java new file mode 100644 index 0000000..9f5bb2c --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/nio/SelectedSelectionKeySetTest.java @@ -0,0 +1,117 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.nio; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.nio.channels.SelectionKey; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class SelectedSelectionKeySetTest { + @Mock + private SelectionKey mockKey; + @Mock + private SelectionKey mockKey2; + + @Mock + private SelectionKey mockKey3; + + @BeforeEach + public void setup() { + MockitoAnnotations.initMocks(this); + } + + @Test + public void addElements() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + final int expectedSize = 1000000; + for (int i = 0; i < expectedSize; ++i) { + assertTrue(set.add(mockKey)); + } + + assertEquals(expectedSize, set.size()); + assertFalse(set.isEmpty()); + } + + @Test + public void resetSet() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertTrue(set.add(mockKey2)); + set.reset(1); + + assertSame(mockKey, set.keys[0]); + assertNull(set.keys[1]); + assertEquals(0, set.size()); + assertTrue(set.isEmpty()); + } + + @Test + public void iterator() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertTrue(set.add(mockKey2)); + Iterator keys = set.iterator(); + assertTrue(keys.hasNext()); + assertSame(mockKey, keys.next()); + assertTrue(keys.hasNext()); + assertSame(mockKey2, keys.next()); + assertFalse(keys.hasNext()); + + try { + keys.next(); + fail(); + } catch (NoSuchElementException expected) { + // expected + } + + try { + keys.remove(); 
+ fail(); + } catch (UnsupportedOperationException expected) { + // expected + } + } + + @Test + public void contains() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertTrue(set.add(mockKey2)); + assertTrue(set.contains(mockKey)); + assertTrue(set.contains(mockKey2)); + assertFalse(set.contains(mockKey3)); + } + + @Test + public void remove() { + SelectedSelectionKeySet set = new SelectedSelectionKeySet(); + assertTrue(set.add(mockKey)); + assertFalse(set.remove(mockKey)); + assertFalse(set.remove(mockKey2)); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/oio/OioEventLoopTest.java b/netty-channel/src/test/java/io/netty/channel/oio/OioEventLoopTest.java new file mode 100644 index 0000000..8e748ae --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/oio/OioEventLoopTest.java @@ -0,0 +1,117 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel.oio; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.socket.oio.OioServerSocketChannel; +import io.netty.channel.socket.oio.OioSocketChannel; +import io.netty.util.NetUtil; +import org.junit.jupiter.api.Test; + +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.concurrent.CountDownLatch; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; + +public class OioEventLoopTest { + @Test + public void testTooManyServerChannels() throws Exception { + EventLoopGroup g = new OioEventLoopGroup(1); + ServerBootstrap b = new ServerBootstrap(); + b.channel(OioServerSocketChannel.class); + b.group(g); + b.childHandler(new ChannelInboundHandlerAdapter()); + ChannelFuture f1 = b.bind(0); + f1.sync(); + + ChannelFuture f2 = b.bind(0); + f2.await(); + + assertThat(f2.cause(), is(instanceOf(ChannelException.class))); + assertThat(f2.cause().getMessage().toLowerCase(), containsString("too many channels")); + + final CountDownLatch notified = new CountDownLatch(1); + f2.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + notified.countDown(); + } + }); + + notified.await(); + g.shutdownGracefully(); + } + + @Test + public void testTooManyClientChannels() throws Exception { + EventLoopGroup g = new OioEventLoopGroup(1); + ServerBootstrap sb = new ServerBootstrap(); + sb.channel(OioServerSocketChannel.class); + sb.group(g); + sb.childHandler(new ChannelInboundHandlerAdapter()); + ChannelFuture f1 = sb.bind(0); + 
f1.sync(); + + Bootstrap cb = new Bootstrap(); + cb.channel(OioSocketChannel.class); + cb.group(g); + cb.handler(new ChannelInboundHandlerAdapter()); + ChannelFuture f2 = cb.connect(NetUtil.LOCALHOST, ((InetSocketAddress) f1.channel().localAddress()).getPort()); + f2.await(); + + assertThat(f2.cause(), is(instanceOf(ChannelException.class))); + assertThat(f2.cause().getMessage().toLowerCase(), containsString("too many channels")); + + final CountDownLatch notified = new CountDownLatch(1); + f2.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + notified.countDown(); + } + }); + + notified.await(); + g.shutdownGracefully(); + } + + @Test + public void testTooManyAcceptedChannels() throws Exception { + EventLoopGroup g = new OioEventLoopGroup(1); + ServerBootstrap sb = new ServerBootstrap(); + sb.channel(OioServerSocketChannel.class); + sb.group(g); + sb.childHandler(new ChannelInboundHandlerAdapter()); + ChannelFuture f1 = sb.bind(0); + f1.sync(); + + Socket s = new Socket(NetUtil.LOCALHOST, ((InetSocketAddress) f1.channel().localAddress()).getPort()); + assertThat(s.getInputStream().read(), is(-1)); + s.close(); + + g.shutdownGracefully(); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/pool/AbstractChannelPoolMapTest.java b/netty-channel/src/test/java/io/netty/channel/pool/AbstractChannelPoolMapTest.java new file mode 100644 index 0000000..7a257ef --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/pool/AbstractChannelPoolMapTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.pool; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalEventLoopGroup; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.net.ConnectException; +import java.util.concurrent.TimeUnit; + +import static io.netty.channel.pool.ChannelPoolTestUtils.getLocalAddrId; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class AbstractChannelPoolMapTest { + @Test + public void testMap() throws Exception { + EventLoopGroup group = new LocalEventLoopGroup(); + LocalAddress addr = new LocalAddress(getLocalAddrId()); + final Bootstrap cb = new Bootstrap(); + cb.remoteAddress(addr); + cb.group(group) + .channel(LocalChannel.class); + + AbstractChannelPoolMap poolMap = + new AbstractChannelPoolMap() { + @Override + protected SimpleChannelPool newPool(EventLoop key) { + return new SimpleChannelPool(cb.clone(key), new 
TestChannelPoolHandler()); + } + }; + + EventLoop loop = group.next(); + + assertFalse(poolMap.iterator().hasNext()); + assertEquals(0, poolMap.size()); + + final SimpleChannelPool pool = poolMap.get(loop); + assertEquals(1, poolMap.size()); + assertTrue(poolMap.iterator().hasNext()); + + assertSame(pool, poolMap.get(loop)); + assertTrue(poolMap.remove(loop)); + assertFalse(poolMap.remove(loop)); + + assertFalse(poolMap.iterator().hasNext()); + assertEquals(0, poolMap.size()); + + assertThrows(ConnectException.class, new Executable() { + @Override + public void execute() throws Throwable { + pool.acquire().syncUninterruptibly(); + } + }); + poolMap.close(); + } + + @Test + public void testRemoveClosesChannelPool() { + EventLoopGroup group = new LocalEventLoopGroup(); + LocalAddress addr = new LocalAddress(getLocalAddrId()); + final Bootstrap cb = new Bootstrap(); + cb.remoteAddress(addr); + cb.group(group) + .channel(LocalChannel.class); + + AbstractChannelPoolMap poolMap = + new AbstractChannelPoolMap() { + @Override + protected TestPool newPool(EventLoop key) { + return new TestPool(cb.clone(key), new TestChannelPoolHandler()); + } + }; + + EventLoop loop = group.next(); + + TestPool pool = poolMap.get(loop); + assertTrue(poolMap.remove(loop)); + + // the pool should be closed eventually after remove + pool.closeFuture.awaitUninterruptibly(1, TimeUnit.SECONDS); + assertTrue(pool.closeFuture.isDone()); + poolMap.close(); + } + + @Test + public void testCloseClosesPoolsImmediately() { + EventLoopGroup group = new LocalEventLoopGroup(); + LocalAddress addr = new LocalAddress(getLocalAddrId()); + final Bootstrap cb = new Bootstrap(); + cb.remoteAddress(addr); + cb.group(group) + .channel(LocalChannel.class); + + AbstractChannelPoolMap poolMap = + new AbstractChannelPoolMap() { + @Override + protected TestPool newPool(EventLoop key) { + return new TestPool(cb.clone(key), new TestChannelPoolHandler()); + } + }; + + EventLoop loop = group.next(); + + TestPool pool = 
poolMap.get(loop); + assertFalse(pool.closeFuture.isDone()); + + // the pool should be closed immediately after remove + poolMap.close(); + assertTrue(pool.closeFuture.isDone()); + } + + private static final class TestChannelPoolHandler extends AbstractChannelPoolHandler { + @Override + public void channelCreated(Channel ch) throws Exception { + // NOOP + } + } + + private static final class TestPool extends SimpleChannelPool { + private final Promise closeFuture; + + TestPool(Bootstrap bootstrap, ChannelPoolHandler handler) { + super(bootstrap, handler); + EventExecutor executor = bootstrap.config().group().next(); + closeFuture = executor.newPromise(); + } + + @Override + public Future closeAsync() { + Future poolClose = super.closeAsync(); + poolClose.addListener(new GenericFutureListener>() { + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + closeFuture.setSuccess(null); + } else { + closeFuture.setFailure(future.cause()); + } + } + }); + return poolClose; + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/pool/ChannelPoolTestUtils.java b/netty-channel/src/test/java/io/netty/channel/pool/ChannelPoolTestUtils.java new file mode 100644 index 0000000..71ef173 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/pool/ChannelPoolTestUtils.java @@ -0,0 +1,29 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.channel.pool; + +import io.netty.util.internal.ThreadLocalRandom; + +final class ChannelPoolTestUtils { + private static final String LOCAL_ADDR_ID = "test.id"; + + private ChannelPoolTestUtils() { + } + + static String getLocalAddrId() { + return LOCAL_ADDR_ID + ThreadLocalRandom.current().nextInt(); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/pool/CountingChannelPoolHandler.java b/netty-channel/src/test/java/io/netty/channel/pool/CountingChannelPoolHandler.java new file mode 100644 index 0000000..b2401c0 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/pool/CountingChannelPoolHandler.java @@ -0,0 +1,53 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.channel.pool; + +import io.netty.channel.Channel; + +import java.util.concurrent.atomic.AtomicInteger; + +final class CountingChannelPoolHandler implements ChannelPoolHandler { + private final AtomicInteger channelCount = new AtomicInteger(0); + private final AtomicInteger acquiredCount = new AtomicInteger(0); + private final AtomicInteger releasedCount = new AtomicInteger(0); + + @Override + public void channelCreated(Channel ch) { + channelCount.incrementAndGet(); + } + + @Override + public void channelReleased(Channel ch) { + releasedCount.incrementAndGet(); + } + + @Override + public void channelAcquired(Channel ch) { + acquiredCount.incrementAndGet(); + } + + public int channelCount() { + return channelCount.get(); + } + + public int acquiredCount() { + return acquiredCount.get(); + } + + public int releasedCount() { + return releasedCount.get(); + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolMapDeadlockTest.java b/netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolMapDeadlockTest.java new file mode 100644 index 0000000..ef44ec0 --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolMapDeadlockTest.java @@ -0,0 +1,264 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.channel.pool; + +import io.netty.bootstrap.Bootstrap; +import io.netty.channel.Channel; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoop; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; + +import static org.junit.jupiter.api.Assertions.assertSame; + +/** + * This is a test case for the deadlock scenario described in https://github.com/netty/netty/issues/8238. + */ +public class FixedChannelPoolMapDeadlockTest { + + private static final NoopHandler NOOP_HANDLER = new NoopHandler(); + + @Test + public void testDeadlockOnAcquire() throws Exception { + + final EventLoop threadA1 = new DefaultEventLoop(); + final Bootstrap bootstrapA1 = new Bootstrap() + .channel(LocalChannel.class).group(threadA1).localAddress(new LocalAddress("A1")); + final EventLoop threadA2 = new DefaultEventLoop(); + final Bootstrap bootstrapA2 = new Bootstrap() + .channel(LocalChannel.class).group(threadA2).localAddress(new LocalAddress("A2")); + final EventLoop threadB1 = new DefaultEventLoop(); + final Bootstrap bootstrapB1 = new Bootstrap() + .channel(LocalChannel.class).group(threadB1).localAddress(new LocalAddress("B1")); + final EventLoop threadB2 = new DefaultEventLoop(); + final Bootstrap bootstrapB2 = new Bootstrap() + .channel(LocalChannel.class).group(threadB2).localAddress(new LocalAddress("B2")); + + final FixedChannelPool poolA1 = new FixedChannelPool(bootstrapA1, NOOP_HANDLER, 1); + final FixedChannelPool poolA2 = new FixedChannelPool(bootstrapB2, NOOP_HANDLER, 1); + final FixedChannelPool poolB1 = new FixedChannelPool(bootstrapB1, NOOP_HANDLER, 1); + final FixedChannelPool poolB2 = new FixedChannelPool(bootstrapA2, NOOP_HANDLER, 1); + + // 
Synchronize threads on these barriers to ensure order of execution, first wait until each thread is inside + // the newPool callback, then hold the two threads that should lose the match until the first two returns, then + // release them to test if they deadlock when trying to release their pools on each other's threads. + final CyclicBarrier arrivalBarrier = new CyclicBarrier(4); + final CyclicBarrier releaseBarrier = new CyclicBarrier(3); + + final AbstractChannelPoolMap channelPoolMap = + new AbstractChannelPoolMap() { + + @Override + protected FixedChannelPool newPool(String key) { + + // Thread A1 gets a new pool on eventexecutor thread A1 (anywhere but A2 or B2) + // Thread B1 gets a new pool on eventexecutor thread B1 (anywhere but A2 or B2) + // Thread A2 gets a new pool on eventexecutor thread B2 + // Thread B2 gets a new pool on eventexecutor thread A2 + + if ("A".equals(key)) { + if (threadA1.inEventLoop()) { + // Thread A1 gets pool A with thread A1 + await(arrivalBarrier); + return poolA1; + } else if (threadA2.inEventLoop()) { + // Thread A2 gets pool A with thread B2, but only after A1 won + await(arrivalBarrier); + await(releaseBarrier); + return poolA2; + } + } else if ("B".equals(key)) { + if (threadB1.inEventLoop()) { + // Thread B1 gets pool with thread B1 + await(arrivalBarrier); + return poolB1; + } else if (threadB2.inEventLoop()) { + // Thread B2 gets pool with thread A2 + await(arrivalBarrier); + await(releaseBarrier); + return poolB2; + } + } + throw new AssertionError("Unexpected key=" + key + " or thread=" + + Thread.currentThread().getName()); + } + }; + + // Thread A1 calls ChannelPoolMap.get(A) + // Thread A2 calls ChannelPoolMap.get(A) + // Thread B1 calls ChannelPoolMap.get(B) + // Thread B2 calls ChannelPoolMap.get(B) + + Future futureA1 = threadA1.submit(new Callable() { + @Override + public FixedChannelPool call() throws Exception { + return channelPoolMap.get("A"); + } + }); + + Future futureA2 = threadA2.submit(new Callable() 
{ + @Override + public FixedChannelPool call() throws Exception { + return channelPoolMap.get("A"); + } + }); + + Future futureB1 = threadB1.submit(new Callable() { + @Override + public FixedChannelPool call() throws Exception { + return channelPoolMap.get("B"); + } + }); + + Future futureB2 = threadB2.submit(new Callable() { + @Override + public FixedChannelPool call() throws Exception { + return channelPoolMap.get("B"); + } + }); + + // Thread A1 succeeds on updating the map and moves on + // Thread B1 succeeds on updating the map and moves on + // These should always succeed and return with new pools + try { + assertSame(poolA1, futureA1.get(1, TimeUnit.SECONDS)); + assertSame(poolB1, futureB1.get(1, TimeUnit.SECONDS)); + } catch (Exception e) { + shutdown(threadA1, threadA2, threadB1, threadB2); + throw e; + } + + // Now release the other two threads which at this point lost the race and will try to clean up the acquired + // pools. The expected scenario is that both pools close, in case of a deadlock they will hang. 
+ await(releaseBarrier); + + // Thread A2 fails to update the map and submits close to thread B2 + // Thread B2 fails to update the map and submits close to thread A2 + // If the close is blocking, then these calls will time out as the threads are waiting for each other + // If the close is not blocking, then the previously created pools will be returned + try { + assertSame(poolA1, futureA2.get(1, TimeUnit.SECONDS)); + assertSame(poolB1, futureB2.get(1, TimeUnit.SECONDS)); + } catch (TimeoutException e) { + // Fail the test on timeout to distinguish from other errors + throw new AssertionError(e); + } finally { + poolA1.close(); + poolA2.close(); + poolB1.close(); + poolB2.close(); + channelPoolMap.close(); + shutdown(threadA1, threadA2, threadB1, threadB2); + } + } + + @Test + public void testDeadlockOnRemove() throws Exception { + + final EventLoop thread1 = new DefaultEventLoop(); + final Bootstrap bootstrap1 = new Bootstrap() + .channel(LocalChannel.class).group(thread1).localAddress(new LocalAddress("#1")); + final EventLoop thread2 = new DefaultEventLoop(); + final Bootstrap bootstrap2 = new Bootstrap() + .channel(LocalChannel.class).group(thread2).localAddress(new LocalAddress("#2")); + + // pool1 runs on thread2, pool2 runs on thread1 + final FixedChannelPool pool1 = new FixedChannelPool(bootstrap2, NOOP_HANDLER, 1); + final FixedChannelPool pool2 = new FixedChannelPool(bootstrap1, NOOP_HANDLER, 1); + + final AbstractChannelPoolMap channelPoolMap = + new AbstractChannelPoolMap() { + + @Override + protected FixedChannelPool newPool(String key) { + if ("#1".equals(key)) { + return pool1; + } else if ("#2".equals(key)) { + return pool2; + } else { + throw new AssertionError("Unexpected key=" + key); + } + } + }; + + assertSame(pool1, channelPoolMap.get("#1")); + assertSame(pool2, channelPoolMap.get("#2")); + + // thread1 tries to remove pool1 which is running on thread2 + // thread2 tries to remove pool2 which is running on thread1 + + final CyclicBarrier 
barrier = new CyclicBarrier(2); + + Future future1 = thread1.submit(new Runnable() { + @Override + public void run() { + await(barrier); + channelPoolMap.remove("#1"); + } + }); + + Future future2 = thread2.submit(new Runnable() { + @Override + public void run() { + await(barrier); + channelPoolMap.remove("#2"); + } + }); + + // A blocking close on remove will cause a deadlock here and the test will time out + try { + future1.get(1, TimeUnit.SECONDS); + future2.get(1, TimeUnit.SECONDS); + } catch (TimeoutException e) { + // Fail the test on timeout to distinguish from other errors + throw new AssertionError(e); + } finally { + pool1.close(); + pool2.close(); + channelPoolMap.close(); + shutdown(thread1, thread2); + } + } + + private static void await(CyclicBarrier barrier) { + try { + barrier.await(1, TimeUnit.SECONDS); + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + private static void shutdown(EventLoop... eventLoops) { + for (EventLoop eventLoop : eventLoops) { + eventLoop.shutdownGracefully(0, 0, TimeUnit.SECONDS); + } + } + + private static class NoopHandler extends AbstractChannelPoolHandler { + @Override + public void channelCreated(Channel ch) throws Exception { + // noop + } + } +} diff --git a/netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolTest.java b/netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolTest.java new file mode 100644 index 0000000..82cf67f --- /dev/null +++ b/netty-channel/src/test/java/io/netty/channel/pool/FixedChannelPoolTest.java @@ -0,0 +1,459 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.channel.pool;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPromise;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.channel.pool.FixedChannelPool.AcquireTimeoutAction;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.GenericFutureListener;

import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import static io.netty.channel.pool.ChannelPoolTestUtils.getLocalAddrId;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@link FixedChannelPool}: bounded acquire, acquire timeouts,
 * release validation and asynchronous close behaviour, all against an
 * in-process local transport.
 */
public class FixedChannelPoolTest {
    // Shared event loop group for all tests in this class; created once and
    // shut down after the last test.
    private static EventLoopGroup group;

    @BeforeAll
    public static void createEventLoop() {
        group = new DefaultEventLoopGroup();
    }

    @AfterAll
    public static void destroyEventLoop() {
        if (group != null) {
            group.shutdownGracefully();
        }
    }

    /**
     * A pool limited to one connection must park a second acquire until the
     * first channel is released, and then hand out the very same channel.
     */
    @Test
    public void testAcquire() throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        CountingChannelPoolHandler handler = new CountingChannelPoolHandler();

        ChannelPool pool = new FixedChannelPool(cb, handler, 1, Integer.MAX_VALUE);

        Channel channel = pool.acquire().syncUninterruptibly().getNow();
        Future<Channel> future = pool.acquire();
        // Pool is exhausted (maxConnections == 1), so the second acquire is pending.
        assertFalse(future.isDone());

        pool.release(channel).syncUninterruptibly();
        assertTrue(future.await(1, TimeUnit.SECONDS));

        Channel channel2 = future.getNow();
        // The released channel must be reused instead of creating a new one.
        assertSame(channel, channel2);
        assertEquals(1, handler.channelCount());

        assertEquals(2, handler.acquiredCount());
        assertEquals(1, handler.releasedCount());

        sc.close().syncUninterruptibly();
        channel2.close().syncUninterruptibly();
        pool.close();
    }

    @Test
    public void testAcquireTimeout() throws Exception {
        testAcquireTimeout(500);
    }

    @Test
    public void testAcquireWithZeroTimeout() throws Exception {
        testAcquireTimeout(0);
    }

    /**
     * With {@link AcquireTimeoutAction#FAIL} a pending acquire on an exhausted
     * pool must fail with a {@link TimeoutException} once the timeout elapses
     * (a zero timeout fails immediately).
     */
    private static void testAcquireTimeout(long timeoutMillis) throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new TestChannelPoolHandler();
        ChannelPool pool = new FixedChannelPool(cb, handler, ChannelHealthChecker.ACTIVE,
                AcquireTimeoutAction.FAIL, timeoutMillis, 1, Integer.MAX_VALUE);

        Channel channel = pool.acquire().syncUninterruptibly().getNow();
        final Future<Channel> future = pool.acquire();
        assertThrows(TimeoutException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                future.syncUninterruptibly();
            }
        });
        sc.close().syncUninterruptibly();
        channel.close().syncUninterruptibly();
        pool.close();
    }

    /**
     * With {@link AcquireTimeoutAction#NEW} an acquire on an exhausted pool
     * must create a fresh connection instead of failing.
     */
    @Test
    public void testAcquireNewConnection() throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new TestChannelPoolHandler();
        ChannelPool pool = new FixedChannelPool(cb, handler, ChannelHealthChecker.ACTIVE,
                AcquireTimeoutAction.NEW, 500, 1, Integer.MAX_VALUE);

        Channel channel = pool.acquire().syncUninterruptibly().getNow();
        Channel channel2 = pool.acquire().syncUninterruptibly().getNow();
        assertNotSame(channel, channel2);
        sc.close().syncUninterruptibly();
        channel.close().syncUninterruptibly();
        channel2.close().syncUninterruptibly();
        pool.close();
    }

    /**
     * Tests that the acquiredChannelCount is not added up several times for
     * the same channel acquire request: after releasing an already-closed
     * (unhealthy) channel, a subsequent acquire must still succeed and return
     * a different channel.
     */
    @Test
    public void testAcquireNewConnectionWhen() throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new TestChannelPoolHandler();
        ChannelPool pool = new FixedChannelPool(cb, handler, 1);
        Channel channel1 = pool.acquire().syncUninterruptibly().getNow();
        // Close before releasing so the pooled channel is unhealthy.
        channel1.close().syncUninterruptibly();
        pool.release(channel1);

        Channel channel2 = pool.acquire().syncUninterruptibly().getNow();

        assertNotSame(channel1, channel2);
        sc.close().syncUninterruptibly();
        channel2.close().syncUninterruptibly();
        pool.close();
    }

    /**
     * When the pending-acquire queue is bounded (maxPendingAcquires == 1), a
     * third acquire must fail immediately with {@link IllegalStateException}.
     */
    @Test
    public void testAcquireBoundQueue() throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new TestChannelPoolHandler();
        final ChannelPool pool = new FixedChannelPool(cb, handler, 1, 1);

        Channel channel = pool.acquire().syncUninterruptibly().getNow();
        Future<Channel> future = pool.acquire();
        assertFalse(future.isDone());

        assertThrows(IllegalStateException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                pool.acquire().syncUninterruptibly();
            }
        });
        sc.close().syncUninterruptibly();
        channel.close().syncUninterruptibly();
        pool.close();
    }

    /**
     * Releasing a channel into a pool that did not acquire it must fail with
     * {@link IllegalArgumentException}.
     */
    @Test
    public void testReleaseDifferentPool() throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new TestChannelPoolHandler();
        ChannelPool pool = new FixedChannelPool(cb, handler, 1, 1);
        final ChannelPool pool2 = new FixedChannelPool(cb, handler, 1, 1);

        final Channel channel = pool.acquire().syncUninterruptibly().getNow();

        assertThrows(IllegalArgumentException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                pool2.release(channel).syncUninterruptibly();
            }
        });
        sc.close().syncUninterruptibly();
        channel.close().syncUninterruptibly();
        pool.close();
        pool2.close();
    }

    /**
     * Releasing into a closed pool must fail, and the pool's close must also
     * have closed the acquired channel.
     */
    @Test
    public void testReleaseAfterClosePool() throws Exception {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group).channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();

        final FixedChannelPool pool = new FixedChannelPool(cb, new TestChannelPoolHandler(), 2);
        final Future<Channel> acquire = pool.acquire();
        final Channel channel = acquire.get();
        pool.close();
        // Run a no-op task on the group to make sure the close has been
        // processed by the event loop before releasing.
        group.submit(new Runnable() {
            @Override
            public void run() {
                // NOOP
            }
        }).syncUninterruptibly();
        assertThrows(IllegalStateException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                pool.release(channel).syncUninterruptibly();
            }
        });
        // Since the pool is closed, the Channel should have been closed as well.
        channel.closeFuture().syncUninterruptibly();
        assertFalse(channel.isOpen());
        sc.close().syncUninterruptibly();
        pool.close();
    }

    /**
     * Releasing a channel that was closed while acquired must not blow up;
     * the pool simply drops the unhealthy channel.
     */
    @Test
    public void testReleaseClosed() {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group).channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();

        FixedChannelPool pool = new FixedChannelPool(cb, new TestChannelPoolHandler(), 2);
        Channel channel = pool.acquire().syncUninterruptibly().getNow();
        channel.close().syncUninterruptibly();
        pool.release(channel).syncUninterruptibly();

        sc.close().syncUninterruptibly();
        pool.close();
    }

    /**
     * {@link FixedChannelPool#closeAsync()} must complete with all acquired
     * channels accounted for (acquiredChannelCount back to zero).
     */
    @Test
    public void testCloseAsync() throws ExecutionException, InterruptedException {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group).channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        final Channel sc = sb.bind(addr).syncUninterruptibly().channel();

        final FixedChannelPool pool = new FixedChannelPool(cb, new TestChannelPoolHandler(), 2);

        pool.acquire().get();
        pool.acquire().get();

        final ChannelPromise closePromise = sc.newPromise();
        pool.closeAsync().addListener(new GenericFutureListener<Future<? super Void>>() {
            @Override
            public void operationComplete(Future<? super Void> future) throws Exception {
                assertEquals(0, pool.acquiredChannelCount());
                sc.close(closePromise).syncUninterruptibly();
            }
        }).awaitUninterruptibly();
        closePromise.awaitUninterruptibly();
    }

    /**
     * An exception thrown from {@link ChannelPoolHandler#channelAcquired(Channel)}
     * must fail the acquire future with exactly that exception.
     */
    @Test
    public void testChannelAcquiredException() throws InterruptedException {
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group).channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        final NullPointerException exception = new NullPointerException();
        FixedChannelPool pool = new FixedChannelPool(cb, new ChannelPoolHandler() {
            @Override
            public void channelReleased(Channel ch) {
            }
            @Override
            public void channelAcquired(Channel ch) {
                throw exception;
            }
            @Override
            public void channelCreated(Channel ch) {
            }
        }, 2);

        try {
            pool.acquire().sync();
        } catch (NullPointerException e) {
            // Must be the very instance thrown by the handler, not a wrapper.
            assertSame(e, exception);
        }

        sc.close().sync();
        pool.close();
    }

    /** Minimal no-op pool handler used by tests that don't count callbacks. */
    private static final class TestChannelPoolHandler extends AbstractChannelPoolHandler {
        @Override
        public void channelCreated(Channel ch) throws Exception {
            // NOOP
        }
    }
}
package io.netty.channel.pool;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.util.concurrent.Future;
import org.hamcrest.CoreMatchers;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.util.Queue;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;

import static io.netty.channel.pool.ChannelPoolTestUtils.getLocalAddrId;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNotSame;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@link SimpleChannelPool}: channel reuse, release validation,
 * health checking and asynchronous close, all against the local transport.
 */
public class SimpleChannelPoolTest {

    /**
     * A released channel must be handed out again on the next acquire, and
     * double-releasing the same channel must fail.
     */
    @Test
    public void testAcquire() throws Exception {
        EventLoopGroup group = new DefaultEventLoopGroup();
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).sync().channel();
        CountingChannelPoolHandler handler = new CountingChannelPoolHandler();

        final ChannelPool pool = new SimpleChannelPool(cb, handler);

        Channel channel = pool.acquire().sync().getNow();

        pool.release(channel).syncUninterruptibly();

        final Channel channel2 = pool.acquire().sync().getNow();
        // Re-acquire must reuse the pooled channel rather than connect again.
        assertSame(channel, channel2);
        assertEquals(1, handler.channelCount());
        pool.release(channel2).syncUninterruptibly();

        // Should fail on multiple release calls.
        assertThrows(IllegalArgumentException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                pool.release(channel2).syncUninterruptibly();
            }
        });
        // The failed double-release closes the channel.
        assertFalse(channel.isActive());

        assertEquals(2, handler.acquiredCount());
        assertEquals(2, handler.releasedCount());

        sc.close().sync();
        pool.close();
        group.shutdownGracefully();
    }

    /**
     * With a pool segment bounded to a single slot, releasing a second
     * channel must fail with {@link IllegalStateException}.
     */
    @Test
    public void testBoundedChannelPoolSegment() throws Exception {
        EventLoopGroup group = new DefaultEventLoopGroup();
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).sync().channel();
        CountingChannelPoolHandler handler = new CountingChannelPoolHandler();

        // Override the pool's storage with a queue that holds at most one channel.
        final ChannelPool pool = new SimpleChannelPool(cb, handler, ChannelHealthChecker.ACTIVE) {
            private final Queue<Channel> queue = new LinkedBlockingQueue<Channel>(1);

            @Override
            protected Channel pollChannel() {
                return queue.poll();
            }

            @Override
            protected boolean offerChannel(Channel ch) {
                return queue.offer(ch);
            }
        };

        Channel channel = pool.acquire().sync().getNow();
        final Channel channel2 = pool.acquire().sync().getNow();

        pool.release(channel).syncUninterruptibly().getNow();
        assertThrows(IllegalStateException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                pool.release(channel2).syncUninterruptibly();
            }
        });
        channel2.close().sync();

        assertEquals(2, handler.channelCount());
        assertEquals(2, handler.acquiredCount());
        assertEquals(1, handler.releasedCount());
        sc.close().sync();
        channel.close().sync();
        channel2.close().sync();
        pool.close();
        group.shutdownGracefully();
    }

    /**
     * Tests that if a channel was unhealthy it is not offered back to the pool.
     *
     * @throws Exception
     */
    @Test
    public void testUnhealthyChannelIsNotOffered() throws Exception {
        EventLoopGroup group = new DefaultEventLoopGroup();
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new CountingChannelPoolHandler();
        ChannelPool pool = new SimpleChannelPool(cb, handler);
        Channel channel1 = pool.acquire().syncUninterruptibly().getNow();
        pool.release(channel1).syncUninterruptibly();
        Channel channel2 = pool.acquire().syncUninterruptibly().getNow();
        // First check that when returned healthy it is actually offered back to the pool.
        assertSame(channel1, channel2);

        channel1.close().syncUninterruptibly();

        pool.release(channel1).syncUninterruptibly();
        Channel channel3 = pool.acquire().syncUninterruptibly().getNow();
        // channel1 was not healthy anymore so it should not get acquired anymore.
        assertNotSame(channel1, channel3);
        sc.close().syncUninterruptibly();
        channel3.close().syncUninterruptibly();
        pool.close();
        group.shutdownGracefully();
    }

    /**
     * Tests that an unhealthy channel is still offered back to the pool when
     * health validation on release was explicitly disabled.
     *
     * @throws Exception
     */
    @Test
    public void testUnhealthyChannelIsOfferedWhenNoHealthCheckRequested() throws Exception {
        EventLoopGroup group = new DefaultEventLoopGroup();
        LocalAddress addr = new LocalAddress(getLocalAddrId());
        Bootstrap cb = new Bootstrap();
        cb.remoteAddress(addr);
        cb.group(group)
          .channel(LocalChannel.class);

        ServerBootstrap sb = new ServerBootstrap();
        sb.group(group)
          .channel(LocalServerChannel.class)
          .childHandler(new ChannelInitializer<LocalChannel>() {
              @Override
              public void initChannel(LocalChannel ch) throws Exception {
                  ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
              }
          });

        // Start server
        Channel sc = sb.bind(addr).syncUninterruptibly().channel();
        ChannelPoolHandler handler = new CountingChannelPoolHandler();
        // releaseHealthCheck == false: release must succeed even for a dead channel.
        ChannelPool pool = new SimpleChannelPool(cb, handler, ChannelHealthChecker.ACTIVE, false);
        Channel channel1 = pool.acquire().syncUninterruptibly().getNow();
        channel1.close().syncUninterruptibly();
        Future<Void> releaseFuture =
                pool.release(channel1, channel1.eventLoop().<Void>newPromise()).syncUninterruptibly();
        assertThat(releaseFuture.isSuccess(), CoreMatchers.is(true));

        Channel channel2 = pool.acquire().syncUninterruptibly().getNow();
        // Verifying that channel2 is different means it was not pulled from the pool.
        assertNotSame(channel1, channel2);
        sc.close().syncUninterruptibly();
        channel2.close().syncUninterruptibly();
        pool.close();
        group.shutdownGracefully();
    }

    @Test
    public void testBootstrap() {
        final SimpleChannelPool pool = new SimpleChannelPool(new Bootstrap(), new CountingChannelPoolHandler());

        try {
            // Checking for the actual bootstrap object doesn't make sense here, since the pool uses a copy with a
            // modified channel handler.
            assertNotNull(pool.bootstrap());
        } finally {
            pool.close();
        }
    }

    @Test
    public void testHandler() {
        final ChannelPoolHandler handler = new CountingChannelPoolHandler();
        final SimpleChannelPool pool = new SimpleChannelPool(new Bootstrap(), handler);

        try {
            assertSame(handler, pool.handler());
        } finally {
            pool.close();
        }
    }

    @Test
    public void testHealthChecker() {
        final ChannelHealthChecker healthChecker = ChannelHealthChecker.ACTIVE;
        final SimpleChannelPool pool = new SimpleChannelPool(
                new Bootstrap(),
                new CountingChannelPoolHandler(),
                healthChecker);

        try {
            assertSame(healthChecker, pool.healthChecker());
        } finally {
            pool.close();
        }
    }

    @Test
    public void testReleaseHealthCheck() {
        final SimpleChannelPool healthCheckOnReleasePool = new SimpleChannelPool(
                new Bootstrap(),
                new CountingChannelPoolHandler(),
                ChannelHealthChecker.ACTIVE,
                true);

        try {
            assertTrue(healthCheckOnReleasePool.releaseHealthCheck());
        } finally {
            healthCheckOnReleasePool.close();
        }

        final SimpleChannelPool noHealthCheckOnReleasePool = new SimpleChannelPool(
                new Bootstrap(),
                new CountingChannelPoolHandler(),
                ChannelHealthChecker.ACTIVE,
                false);

        try {
            assertFalse(noHealthCheckOnReleasePool.releaseHealthCheck());
        } finally {
            noHealthCheckOnReleasePool.close();
        }
    }

    /**
     * {@link SimpleChannelPool#closeAsync()} must close every channel that was
     * returned to the pool.
     */
    @Test
    public void testCloseAsync() throws Exception {
        final LocalAddress addr = new LocalAddress(getLocalAddrId());
        final EventLoopGroup group = new DefaultEventLoopGroup();

        // Start server
        final ServerBootstrap sb = new ServerBootstrap()
                .group(group)
                .channel(LocalServerChannel.class)
                .childHandler(new ChannelInitializer<LocalChannel>() {
                    @Override
                    protected void initChannel(LocalChannel ch) throws Exception {
                        ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
                    }
                });
        final Channel sc = sb.bind(addr).syncUninterruptibly().channel();

        // Create pool, acquire and return channels
        final Bootstrap bootstrap = new Bootstrap()
                .channel(LocalChannel.class).group(group).remoteAddress(addr);
        final SimpleChannelPool pool = new SimpleChannelPool(bootstrap, new CountingChannelPoolHandler());
        Channel ch1 = pool.acquire().syncUninterruptibly().getNow();
        Channel ch2 = pool.acquire().syncUninterruptibly().getNow();
        pool.release(ch1).get(1, TimeUnit.SECONDS);
        pool.release(ch2).get(1, TimeUnit.SECONDS);

        // Assert that returned channels are open before close
        assertTrue(ch1.isOpen());
        assertTrue(ch2.isOpen());

        // Close asynchronously with timeout
        pool.closeAsync().get(1, TimeUnit.SECONDS);

        // Assert channels were indeed closed
        assertFalse(ch1.isOpen());
        assertFalse(ch2.isOpen());

        sc.close().sync();
        pool.close();
        group.shutdownGracefully();
    }

    /**
     * An exception thrown from {@link ChannelPoolHandler#channelAcquired(Channel)}
     * must fail the acquire future with exactly that exception.
     */
    @Test
    public void testChannelAcquiredException() throws InterruptedException {
        final LocalAddress addr = new LocalAddress(getLocalAddrId());
        final EventLoopGroup group = new DefaultEventLoopGroup();

        // Start server
        final ServerBootstrap sb = new ServerBootstrap()
                .group(group)
                .channel(LocalServerChannel.class)
                .childHandler(new ChannelInitializer<LocalChannel>() {
                    @Override
                    protected void initChannel(LocalChannel ch) throws Exception {
                        ch.pipeline().addLast(new ChannelInboundHandlerAdapter());
                    }
                });
        final Channel sc = sb.bind(addr).syncUninterruptibly().channel();

        // Create pool, acquire and return channels
        final Bootstrap bootstrap = new Bootstrap()
                .channel(LocalChannel.class).group(group).remoteAddress(addr);
        final NullPointerException exception = new NullPointerException();
        final SimpleChannelPool pool = new SimpleChannelPool(bootstrap, new ChannelPoolHandler() {
            @Override
            public void channelReleased(Channel ch) {
            }
            @Override
            public void channelAcquired(Channel ch) {
                throw exception;
            }
            @Override
            public void channelCreated(Channel ch) {
            }
        });

        try {
            pool.acquire().sync();
        } catch (NullPointerException e) {
            // Must be the very instance thrown by the handler, not a wrapper.
            assertSame(e, exception);
        }

        sc.close().sync();
        pool.close();
        group.shutdownGracefully();
    }
}
package io.netty.channel.socket;

import io.netty.util.NetUtil;
import org.junit.jupiter.api.Test;

import java.net.InetAddress;

import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;

/**
 * Verifies that each {@link InternetProtocolFamily} constant reports the
 * matching loopback address from {@link NetUtil}.
 */
public class InternetProtocolFamilyTest {

    @Test
    public void ipv4ShouldHaveLocalhostOfIpV4() {
        assertLocalhost(InternetProtocolFamily.IPv4, NetUtil.LOCALHOST4);
    }

    @Test
    public void ipv6ShouldHaveLocalhostOfIpV6() {
        assertLocalhost(InternetProtocolFamily.IPv6, NetUtil.LOCALHOST6);
    }

    // Shared assertion: the family's localhost() must equal the expected loopback.
    private static void assertLocalhost(InternetProtocolFamily family, InetAddress expected) {
        assertThat(family.localhost(), is(expected));
    }
}
package io.netty.channel.socket.nio;

import io.netty.channel.ChannelOption;
import io.netty.channel.nio.AbstractNioChannel;
import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.nio.channels.NetworkChannel;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

/**
 * Base test for NIO channel implementations: verifies that JDK
 * {@link SocketOption}s exposed via {@link NioChannelOption} round-trip
 * through the channel config.
 *
 * @param <T> the concrete channel type under test
 */
public abstract class AbstractNioChannelTest<T extends AbstractNioChannel> {

    /** Creates a fresh, unregistered channel instance to test. */
    protected abstract T newNioChannel();

    /** Returns the underlying JDK channel backing the given Netty channel. */
    protected abstract NetworkChannel jdkChannel(T channel);

    /** Returns a {@link SocketOption} that the channel type does not support. */
    protected abstract SocketOption<?> newInvalidOption();

    /**
     * Setting SO_REUSEADDR through the Netty config must be observable on the
     * JDK channel, and vice versa.
     */
    @Test
    public void testNioChannelOption() throws IOException {
        T channel = newNioChannel();
        try {
            NetworkChannel jdkChannel = jdkChannel(channel);
            ChannelOption<Boolean> option = NioChannelOption.of(StandardSocketOptions.SO_REUSEADDR);
            boolean value1 = jdkChannel.getOption(StandardSocketOptions.SO_REUSEADDR);
            boolean value2 = channel.config().getOption(option);

            assertEquals(value1, value2);

            // Flip the flag via the config and verify both views agree.
            channel.config().setOption(option, !value2);
            boolean value3 = jdkChannel.getOption(StandardSocketOptions.SO_REUSEADDR);
            boolean value4 = channel.config().getOption(option);
            assertEquals(value3, value4);
            assertNotEquals(value1, value4);
        } finally {
            channel.unsafe().closeForcibly();
        }
    }

    /**
     * An unsupported option must be rejected quietly: setOption returns false
     * and getOption returns null.
     */
    @Test
    public void testInvalidNioChannelOption() {
        T channel = newNioChannel();
        try {
            // Raw cast is safe: we only pass null and read back null.
            @SuppressWarnings({ "unchecked", "rawtypes" })
            ChannelOption<Object> option = (ChannelOption) NioChannelOption.of(newInvalidOption());
            assertFalse(channel.config().setOption(option, null));
            assertNull(channel.config().getOption(option));
        } finally {
            channel.unsafe().closeForcibly();
        }
    }

    /** getOptions() must not throw even when unsupported options are present. */
    @Test
    public void testGetOptions() {
        T channel = newNioChannel();
        try {
            channel.config().getOptions();
        } finally {
            channel.unsafe().closeForcibly();
        }
    }
}
package io.netty.channel.socket.nio;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOption;
import io.netty.channel.group.DefaultChannelGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.DatagramChannel;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.GlobalEventExecutor;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.nio.channels.NetworkChannel;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Tests for {@link NioDatagramChannel}, plus the inherited option round-trip
 * tests from {@link AbstractNioChannelTest}.
 */
public class NioDatagramChannelTest extends AbstractNioChannelTest<NioDatagramChannel> {

    /**
     * Test try to reproduce issue #1335: binding many UDP channels in a row
     * must not fail.
     */
    @Test
    public void testBindMultiple() throws Exception {
        DefaultChannelGroup channelGroup = new DefaultChannelGroup(GlobalEventExecutor.INSTANCE);
        NioEventLoopGroup group = new NioEventLoopGroup();
        try {
            for (int i = 0; i < 100; i++) {
                Bootstrap udpBootstrap = new Bootstrap();
                udpBootstrap.group(group).channel(NioDatagramChannel.class)
                        .option(ChannelOption.SO_BROADCAST, true)
                        .handler(new ChannelInboundHandlerAdapter() {
                            @Override
                            public void channelRead(ChannelHandlerContext ctx, Object msg) {
                                // Discard
                                ReferenceCountUtil.release(msg);
                            }
                        });
                // Port 0: let the OS pick a free port for each channel.
                DatagramChannel datagramChannel = (DatagramChannel) udpBootstrap
                        .bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
                channelGroup.add(datagramChannel);
            }
            assertEquals(100, channelGroup.size());
        } finally {
            channelGroup.close().sync();
            group.shutdownGracefully().sync();
        }
    }

    @Override
    protected NioDatagramChannel newNioChannel() {
        return new NioDatagramChannel();
    }

    @Override
    protected NetworkChannel jdkChannel(NioDatagramChannel channel) {
        return channel.javaChannel();
    }

    @Override
    protected SocketOption<?> newInvalidOption() {
        // TCP_NODELAY makes no sense for a datagram channel.
        return StandardSocketOptions.TCP_NODELAY;
    }
}
package io.netty.channel.socket.nio;

import io.netty.channel.Channel;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;

import org.junit.jupiter.api.Test;

import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketOption;
import java.net.StandardSocketOptions;
import java.nio.channels.NetworkChannel;
import java.nio.channels.ServerSocketChannel;

import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@link NioServerSocketChannel}, plus the inherited option
 * round-trip tests from {@link AbstractNioChannelTest}.
 */
public class NioServerSocketChannelTest extends AbstractNioChannelTest<NioServerSocketChannel> {

    /**
     * Read errors: an {@link IOException} on accept must not close the server
     * channel, but an unexpected runtime exception must.
     */
    @Test
    public void testCloseOnError() throws Exception {
        ServerSocketChannel jdkChannel = ServerSocketChannel.open();
        NioServerSocketChannel serverSocketChannel = new NioServerSocketChannel(jdkChannel);
        EventLoopGroup group = new NioEventLoopGroup(1);
        try {
            group.register(serverSocketChannel).syncUninterruptibly();
            serverSocketChannel.bind(new InetSocketAddress(0)).syncUninterruptibly();
            assertFalse(serverSocketChannel.closeOnReadError(new IOException()));
            assertTrue(serverSocketChannel.closeOnReadError(new IllegalArgumentException()));
            serverSocketChannel.close().syncUninterruptibly();
        } finally {
            group.shutdownGracefully();
        }
    }

    /** After close() the channel must report neither open nor active. */
    @Test
    public void testIsActiveFalseAfterClose() {
        NioServerSocketChannel serverSocketChannel = new NioServerSocketChannel();
        EventLoopGroup group = new NioEventLoopGroup(1);
        try {
            group.register(serverSocketChannel).syncUninterruptibly();
            Channel channel = serverSocketChannel.bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
            assertTrue(channel.isActive());
            assertTrue(channel.isOpen());
            channel.close().syncUninterruptibly();
            assertFalse(channel.isOpen());
            assertFalse(channel.isActive());
        } finally {
            group.shutdownGracefully();
        }
    }

    @Override
    protected NioServerSocketChannel newNioChannel() {
        return new NioServerSocketChannel();
    }

    @Override
    protected NetworkChannel jdkChannel(NioServerSocketChannel channel) {
        return channel.javaChannel();
    }

    @Override
    protected SocketOption<?> newInvalidOption() {
        // Multicast interface is meaningless for a server socket.
        return StandardSocketOptions.IP_MULTICAST_IF;
    }
}
+ */ +package io.netty.channel.socket.nio; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.io.DataInput; +import java.io.DataInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.net.InetSocketAddress; +import java.net.ServerSocket; +import java.net.Socket; +import java.net.SocketAddress; +import java.net.SocketOption; +import java.net.StandardSocketOptions; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.NetworkChannel; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertNotSame; + + + +public class NioSocketChannelTest extends AbstractNioChannelTest { + + /** + * Reproduces the issue #1600 + */ + @Test + public void testFlushCloseReentrance() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(1); + try { + final Queue futures = new LinkedBlockingQueue(); + + 
ServerBootstrap sb = new ServerBootstrap(); + sb.group(group).channel(NioServerSocketChannel.class); + sb.childOption(ChannelOption.SO_SNDBUF, 1024); + sb.childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // Write a large enough data so that it is split into two loops. + futures.add(ctx.write( + ctx.alloc().buffer().writeZero(1048576)).addListener(ChannelFutureListener.CLOSE)); + futures.add(ctx.write(ctx.alloc().buffer().writeZero(1048576))); + ctx.flush(); + futures.add(ctx.write(ctx.alloc().buffer().writeZero(1048576))); + ctx.flush(); + } + }); + + SocketAddress address = sb.bind(0).sync().channel().localAddress(); + + Socket s = new Socket(NetUtil.LOCALHOST, ((InetSocketAddress) address).getPort()); + + InputStream in = s.getInputStream(); + byte[] buf = new byte[8192]; + for (;;) { + if (in.read(buf) == -1) { + break; + } + + // Wait a little bit so that the write attempts are split into multiple flush attempts. 
+ Thread.sleep(10); + } + s.close(); + + assertThat(futures.size(), is(3)); + ChannelFuture f1 = futures.poll(); + ChannelFuture f2 = futures.poll(); + ChannelFuture f3 = futures.poll(); + assertThat(f1.isSuccess(), is(true)); + assertThat(f2.isDone(), is(true)); + assertThat(f2.isSuccess(), is(false)); + assertThat(f2.cause(), is(instanceOf(ClosedChannelException.class))); + assertThat(f3.isDone(), is(true)); + assertThat(f3.isSuccess(), is(false)); + assertThat(f3.cause(), is(instanceOf(ClosedChannelException.class))); + } finally { + group.shutdownGracefully().sync(); + } + } + + /** + * Reproduces the issue #1679 + */ + @Test + public void testFlushAfterGatheredFlush() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(1); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group).channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + // Trigger a gathering write by writing two buffers. 
+ ctx.write(Unpooled.wrappedBuffer(new byte[] { 'a' })); + ChannelFuture f = ctx.write(Unpooled.wrappedBuffer(new byte[] { 'b' })); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // This message must be flushed + ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[]{'c'})); + } + }); + ctx.flush(); + } + }); + + SocketAddress address = sb.bind(0).sync().channel().localAddress(); + + Socket s = new Socket(NetUtil.LOCALHOST, ((InetSocketAddress) address).getPort()); + + DataInput in = new DataInputStream(s.getInputStream()); + byte[] buf = new byte[3]; + in.readFully(buf); + + assertThat(new String(buf, CharsetUtil.US_ASCII), is("abc")); + + s.close(); + } finally { + group.shutdownGracefully().sync(); + } + } + + // Test for https://github.com/netty/netty/issues/4805 + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testChannelReRegisterReadSameEventLoop() throws Exception { + testChannelReRegisterRead(true); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testChannelReRegisterReadDifferentEventLoop() throws Exception { + testChannelReRegisterRead(false); + } + + private static void testChannelReRegisterRead(final boolean sameEventLoop) throws Exception { + final EventLoopGroup group = new NioEventLoopGroup(2); + final CountDownLatch latch = new CountDownLatch(1); + + // Just some random bytes + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + Channel sc = null; + Channel cc = null; + ServerBootstrap b = new ServerBootstrap(); + try { + b.group(group) + .channel(NioServerSocketChannel.class) + .childOption(ChannelOption.SO_KEEPALIVE, true) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(new SimpleChannelInboundHandler() { + @Override + protected void 
channelRead0(ChannelHandlerContext ctx, ByteBuf byteBuf) { + // We was able to read something from the Channel after reregister. + latch.countDown(); + } + + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + final EventLoop loop = group.next(); + if (sameEventLoop) { + deregister(ctx, loop); + } else { + loop.execute(new Runnable() { + @Override + public void run() { + deregister(ctx, loop); + } + }); + } + } + + private void deregister(ChannelHandlerContext ctx, final EventLoop loop) { + // As soon as the channel becomes active re-register it to another + // EventLoop. After this is done we should still receive the data that + // was written to the channel. + ctx.deregister().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture cf) { + Channel channel = cf.channel(); + assertNotSame(loop, channel.eventLoop()); + group.next().register(channel); + } + }); + } + }); + } + }); + + sc = b.bind(0).syncUninterruptibly().channel(); + + Bootstrap bootstrap = new Bootstrap(); + bootstrap.group(group).channel(NioSocketChannel.class); + bootstrap.handler(new ChannelInboundHandlerAdapter()); + cc = bootstrap.connect(sc.localAddress()).syncUninterruptibly().channel(); + cc.writeAndFlush(Unpooled.wrappedBuffer(bytes)).syncUninterruptibly(); + latch.await(); + } finally { + if (cc != null) { + cc.close(); + } + if (sc != null) { + sc.close(); + } + group.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testShutdownOutputAndClose() throws IOException { + NioEventLoopGroup group = new NioEventLoopGroup(1); + ServerSocket socket = new ServerSocket(); + socket.bind(new InetSocketAddress(0)); + Socket accepted = null; + try { + Bootstrap sb = new Bootstrap(); + sb.group(group).channel(NioSocketChannel.class); + sb.handler(new ChannelInboundHandlerAdapter()); + + SocketChannel channel = (SocketChannel) 
sb.connect(socket.getLocalSocketAddress()) + .syncUninterruptibly().channel(); + + accepted = socket.accept(); + channel.shutdownOutput().syncUninterruptibly(); + + channel.close().syncUninterruptibly(); + } finally { + if (accepted != null) { + try { + accepted.close(); + } catch (IOException ignore) { + // ignore + } + } + try { + socket.close(); + } catch (IOException ignore) { + // ignore + } + group.shutdownGracefully(); + } + } + + @Override + protected NioSocketChannel newNioChannel() { + return new NioSocketChannel(); + } + + @Override + protected NetworkChannel jdkChannel(NioSocketChannel channel) { + return channel.javaChannel(); + } + + @Override + protected SocketOption newInvalidOption() { + return StandardSocketOptions.IP_MULTICAST_IF; + } +} diff --git a/netty-channel/src/test/java/io/netty/nativeimage/ChannelHandlerMetadataUtil.java b/netty-channel/src/test/java/io/netty/nativeimage/ChannelHandlerMetadataUtil.java new file mode 100644 index 0000000..a630dce --- /dev/null +++ b/netty-channel/src/test/java/io/netty/nativeimage/ChannelHandlerMetadataUtil.java @@ -0,0 +1,246 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.nativeimage; + +import com.google.common.reflect.TypeToken; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import io.netty.channel.ChannelHandler; +import io.netty.channel.NativeImageHandlerMetadataTest; +import org.junit.jupiter.api.Assertions; +import org.reflections.Reflections; +import org.reflections.util.ConfigurationBuilder; + +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.lang.reflect.Type; +import java.net.URL; +import java.text.Collator; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Generates native-image reflection metadata for subtypes of {@link io.netty.channel.ChannelHandler}. + *

+ * To use, create a JUnit test in the desired Netty module and invoke {@link #generateMetadata(String...)} with a list + * of packages present in the target Netty module that may contain subtypes of the ChannelHandler. + *

+ * See {@link NativeImageHandlerMetadataTest} + */ +public final class ChannelHandlerMetadataUtil { + + @SuppressWarnings("UnstableApiUsage") + private static final Type HANDLER_METADATA_LIST_TYPE = new TypeToken>() { + }.getType(); + private static final Gson gson = new GsonBuilder().setPrettyPrinting().create(); + + private ChannelHandlerMetadataUtil() { + } + + public static void generateMetadata(String... packageNames) { + String projectGroupId = System.getProperty("nativeImage.handlerMetadataGroupId"); + String projectArtifactId = System.getProperty("nativeimage.handlerMetadataArtifactId"); + + Set> subtypes = findChannelHandlerSubclasses(packageNames); + + if (Arrays.asList(packageNames).contains("io.netty.channel")) { + // We want the metadata for the ChannelHandler itself too + subtypes.add(ChannelHandler.class); + } + + Set handlerMetadata = new HashSet(); + for (Class subtype : subtypes) { + handlerMetadata.add(new HandlerMetadata(subtype.getName(), new Condition(subtype.getName()), true)); + } + + String projectRelativeResourcePath = "src/main/resources/META-INF/native-image/" + projectGroupId + "/" + + projectArtifactId + "/generated/handlers/reflect-config.json"; + File existingMetadataFile = new File(projectRelativeResourcePath); + String existingMetadataPath = existingMetadataFile.getAbsolutePath(); + if (!existingMetadataFile.exists()) { + if (handlerMetadata.size() == 0) { + return; + } + + String message = "Native Image reflection metadata is required for handlers in this project. 
" + + "This metadata was not found under " + + existingMetadataPath + + "\nPlease create this file with the following content: \n" + + getMetadataJsonString(handlerMetadata) + + "\n"; + Assertions.fail(message); + } + + List existingMetadata = null; + try { + FileReader reader = new FileReader(existingMetadataFile); + existingMetadata = gson.fromJson(reader, HANDLER_METADATA_LIST_TYPE); + } catch (IOException e) { + Assertions.fail("Failed to open the native-image metadata file at: " + existingMetadataPath, e); + } + + Set newMetadata = new HashSet(handlerMetadata); + newMetadata.removeAll(existingMetadata); + + Set removedMetadata = new HashSet(existingMetadata); + removedMetadata.removeAll(handlerMetadata); + + if (!newMetadata.isEmpty() || !removedMetadata.isEmpty()) { + StringBuilder builder = new StringBuilder(); + builder.append("In the native-image handler metadata file at ") + .append(existingMetadataPath) + .append("\n"); + + if (!newMetadata.isEmpty()) { + builder.append("The following new metadata must be added:\n\n") + .append(getMetadataJsonString(newMetadata)) + .append("\n\n"); + } + if (!removedMetadata.isEmpty()) { + builder.append("The following metadata must be removed:\n\n") + .append(getMetadataJsonString(removedMetadata)) + .append("\n\n"); + } + + builder.append("Expected metadata file contents:\n\n") + .append(getMetadataJsonString(handlerMetadata)) + .append("\n"); + Assertions.fail(builder.toString()); + } + } + + private static Set> findChannelHandlerSubclasses(String... 
packageNames) { + Reflections reflections = new Reflections( + new ConfigurationBuilder() + .forPackages(packageNames)); + + Set> allSubtypes = reflections.getSubTypesOf(ChannelHandler.class); + Set> targetSubtypes = new HashSet>(); + + for (Class subtype : allSubtypes) { + if (isTestClass(subtype)) { + continue; + } + String className = subtype.getName(); + boolean shouldInclude = false; + for (String packageName : packageNames) { + if (className.startsWith(packageName)) { + shouldInclude = true; + break; + } + } + + if (shouldInclude) { + targetSubtypes.add(subtype); + } + } + + return targetSubtypes; + } + + private static boolean isTestClass(Class clazz) { + String[] parts = clazz.getName().split("\\."); + if (parts.length > 0) { + URL classFile = clazz.getResource(parts[parts.length - 1] + ".class"); + if (classFile != null) { + return classFile.toString().contains("Test"); + } + } + return false; + } + + private static String getMetadataJsonString(Set metadata) { + List metadataList = new ArrayList(metadata); + Collections.sort(metadataList, new Comparator() { + @Override + public int compare(HandlerMetadata h1, HandlerMetadata h2) { + return Collator.getInstance().compare(h1.name, h2.name); + } + }); + return gson.toJson(metadataList, HANDLER_METADATA_LIST_TYPE); + } + + private static final class Condition { + Condition(String typeReachable) { + this.typeReachable = typeReachable; + } + + final String typeReachable; + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Condition condition = (Condition) o; + return typeReachable != null && typeReachable.equals(condition.typeReachable); + } + + @Override + public int hashCode() { + return typeReachable.hashCode(); + } + } + + private static final class HandlerMetadata { + final String name; + + final Condition condition; + + final boolean queryAllPublicMethods; + + HandlerMetadata(String name, Condition 
condition, boolean queryAllPublicMethods) { + this.name = name; + this.condition = condition; + this.queryAllPublicMethods = queryAllPublicMethods; + } + + @Override + public String toString() { + return name; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + HandlerMetadata that = (HandlerMetadata) o; + return queryAllPublicMethods == that.queryAllPublicMethods + && (name != null && name.equals(that.name)) + && (condition != null && condition.equals(that.condition)); + } + + @Override + public int hashCode() { + return name.hashCode(); + } + } +} diff --git a/netty-channel/src/test/resources/META-INF/services/io.netty.bootstrap.ChannelInitializerExtension b/netty-channel/src/test/resources/META-INF/services/io.netty.bootstrap.ChannelInitializerExtension new file mode 100644 index 0000000..315cb7d --- /dev/null +++ b/netty-channel/src/test/resources/META-INF/services/io.netty.bootstrap.ChannelInitializerExtension @@ -0,0 +1 @@ +io.netty.bootstrap.StubChannelInitializerExtension diff --git a/netty-channel/src/test/resources/logging.properties b/netty-channel/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-channel/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-handler-codec-compression/build.gradle b/netty-handler-codec-compression/build.gradle new file mode 100644 index 0000000..72f8ac4 --- /dev/null +++ b/netty-handler-codec-compression/build.gradle @@ -0,0 +1,17 @@ +dependencies { + api 
project(':netty-handler-codec') + implementation project(':netty-bzip2') + implementation project(':netty-zlib') + implementation libs.brotli4j + implementation libs.jzlib + implementation libs.lz4 + implementation libs.lzf + implementation libs.zstd + testImplementation testLibs.commons.compress + testImplementation testLibs.mockito.core + testRuntimeOnly testLibs.brotli4j.native.linux.x8664 + testRuntimeOnly testLibs.brotli4j.native.linux.aarch64 + testRuntimeOnly testLibs.brotli4j.native.osx.x8664 + testRuntimeOnly testLibs.brotli4j.native.osx.aarch64 + testRuntimeOnly testLibs.brotli4j.native.windows.x8664 +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Brotli.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Brotli.java new file mode 100644 index 0000000..19935ee --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Brotli.java @@ -0,0 +1,83 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.compression; + +import com.aayushatharva.brotli4j.Brotli4jLoader; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +public final class Brotli { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Brotli.class); + private static final ClassNotFoundException CNFE; + private static Throwable cause; + + static { + ClassNotFoundException cnfe = null; + + try { + Class.forName("com.aayushatharva.brotli4j.Brotli4jLoader", false, + PlatformDependent.getClassLoader(Brotli.class)); + } catch (ClassNotFoundException t) { + cnfe = t; + logger.debug( + "brotli4j not in the classpath; Brotli support will be unavailable."); + } + + CNFE = cnfe; + + // If in the classpath, try to load the native library and initialize brotli4j. + if (cnfe == null) { + cause = Brotli4jLoader.getUnavailabilityCause(); + if (cause != null) { + logger.debug("Failed to load brotli4j; Brotli support will be unavailable.", cause); + } + } + } + + /** + * + * @return true when brotli4j is in the classpath + * and native library is available on this platform and could be loaded + */ + public static boolean isAvailable() { + return CNFE == null && Brotli4jLoader.isAvailable(); + } + + /** + * Throws when brotli support is missing from the classpath or is unavailable on this platform + * @throws Throwable a ClassNotFoundException if brotli4j is missing + * or a UnsatisfiedLinkError if brotli4j native lib can't be loaded + */ + public static void ensureAvailability() throws Throwable { + if (CNFE != null) { + throw CNFE; + } + Brotli4jLoader.ensureAvailability(); + } + + /** + * Returns {@link Throwable} of unavailability cause + */ + public static Throwable cause() { + return cause; + } + + private Brotli() { + } +} diff --git 
a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java new file mode 100644 index 0000000..6d009d3 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliDecoder.java @@ -0,0 +1,173 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.compression; + +import com.aayushatharva.brotli4j.decoder.DecoderJNI; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.util.internal.ObjectUtil; + +import java.nio.ByteBuffer; +import java.util.List; + +/** + * Decompresses a {@link ByteBuf} encoded with the brotli format. + * + * See brotli. 
+ */ +public final class BrotliDecoder extends ByteToMessageDecoder { + + private enum State { + DONE, NEEDS_MORE_INPUT, ERROR + } + + static { + try { + Brotli.ensureAvailability(); + } catch (Throwable throwable) { + throw new ExceptionInInitializerError(throwable); + } + } + + private final int inputBufferSize; + private DecoderJNI.Wrapper decoder; + private boolean destroyed; + + /** + * Creates a new BrotliDecoder with a default 8kB input buffer + */ + public BrotliDecoder() { + this(8 * 1024); + } + + /** + * Creates a new BrotliDecoder + * @param inputBufferSize desired size of the input buffer in bytes + */ + public BrotliDecoder(int inputBufferSize) { + this.inputBufferSize = ObjectUtil.checkPositive(inputBufferSize, "inputBufferSize"); + } + + private ByteBuf pull(ByteBufAllocator alloc) { + ByteBuffer nativeBuffer = decoder.pull(); + // nativeBuffer actually wraps brotli's internal buffer so we need to copy its content + ByteBuf copy = alloc.buffer(nativeBuffer.remaining()); + copy.writeBytes(nativeBuffer); + return copy; + } + + private State decompress(ByteBuf input, List output, ByteBufAllocator alloc) { + for (;;) { + switch (decoder.getStatus()) { + case DONE: + return State.DONE; + + case OK: + decoder.push(0); + break; + + case NEEDS_MORE_INPUT: + if (decoder.hasOutput()) { + output.add(pull(alloc)); + } + + if (!input.isReadable()) { + return State.NEEDS_MORE_INPUT; + } + + ByteBuffer decoderInputBuffer = decoder.getInputBuffer(); + decoderInputBuffer.clear(); + int readBytes = readBytes(input, decoderInputBuffer); + decoder.push(readBytes); + break; + + case NEEDS_MORE_OUTPUT: + output.add(pull(alloc)); + break; + + default: + return State.ERROR; + } + } + } + + private static int readBytes(ByteBuf in, ByteBuffer dest) { + int limit = Math.min(in.readableBytes(), dest.remaining()); + ByteBuffer slice = dest.slice(); + slice.limit(limit); + in.readBytes(slice); + dest.position(dest.position() + limit); + return limit; + } + + @Override + public 
void handlerAdded(ChannelHandlerContext ctx) throws Exception { + decoder = new DecoderJNI.Wrapper(inputBufferSize); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (destroyed) { + // Skip data received after finished. + in.skipBytes(in.readableBytes()); + return; + } + + if (!in.isReadable()) { + return; + } + + try { + State state = decompress(in, out, ctx.alloc()); + if (state == State.DONE) { + destroy(); + } else if (state == State.ERROR) { + throw new DecompressionException("Brotli stream corrupted"); + } + } catch (Exception e) { + destroy(); + throw e; + } + } + + private void destroy() { + if (!destroyed) { + destroyed = true; + decoder.destroy(); + } + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + try { + destroy(); + } finally { + super.handlerRemoved0(ctx); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + destroy(); + } finally { + super.channelInactive(ctx); + } + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliEncoder.java new file mode 100644 index 0000000..652dd8b --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliEncoder.java @@ -0,0 +1,278 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import com.aayushatharva.brotli4j.encoder.BrotliEncoderChannel; +import com.aayushatharva.brotli4j.encoder.Encoder; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.AttributeKey; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; + +import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.WritableByteChannel; + +/** + * Compress a {@link ByteBuf} with the Brotli compression. + *

+ * See brotli. + */ +@ChannelHandler.Sharable +public final class BrotliEncoder extends MessageToByteEncoder { + + private static final AttributeKey ATTR = AttributeKey.valueOf("BrotliEncoderWriter"); + + private final Encoder.Parameters parameters; + private final boolean isSharable; + private Writer writer; + + /** + * Create a new {@link BrotliEncoder} Instance with {@link BrotliOptions#DEFAULT} + * and {@link #isSharable()} set to {@code true} + */ + public BrotliEncoder() { + this(BrotliOptions.DEFAULT); + } + + /** + * Create a new {@link BrotliEncoder} Instance + * + * @param brotliOptions {@link BrotliOptions} to use and + * {@link #isSharable()} set to {@code true} + */ + public BrotliEncoder(BrotliOptions brotliOptions) { + this(brotliOptions.parameters()); + } + + /** + * Create a new {@link BrotliEncoder} Instance + * and {@link #isSharable()} set to {@code true} + * + * @param parameters {@link Encoder.Parameters} to use + */ + public BrotliEncoder(Encoder.Parameters parameters) { + this(parameters, true); + } + + /** + *

+ * Create a new {@link BrotliEncoder} Instance and specify + * whether this instance will be shared with multiple pipelines or not. + *

+ * + * If {@link #isSharable()} is true then on {@link #handlerAdded(ChannelHandlerContext)} call, + * a new {@link Writer} will create, and it will be mapped using {@link Channel#attr(AttributeKey)} + * so {@link BrotliEncoder} can be shared with multiple pipelines. This works fine but there on every + * {@link #encode(ChannelHandlerContext, ByteBuf, ByteBuf)} call, we have to get the {@link Writer} associated + * with the appropriate channel. And this will add a overhead. So it is recommended to set {@link #isSharable()} + * to {@code false} and create new {@link BrotliEncoder} instance for every pipeline. + * + * @param parameters {@link Encoder.Parameters} to use + * @param isSharable Set to {@code true} if this instance is shared else set to {@code false} + */ + public BrotliEncoder(Encoder.Parameters parameters, boolean isSharable) { + this.parameters = ObjectUtil.checkNotNull(parameters, "Parameters"); + this.isSharable = isSharable; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + Writer writer = new Writer(parameters, ctx); + if (isSharable) { + ctx.channel().attr(ATTR).set(writer); + } else { + this.writer = writer; + } + super.handlerAdded(ctx); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + finish(ctx); + super.handlerRemoved(ctx); + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception { + // NO-OP + } + + @Override + protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect) throws Exception { + if (!msg.isReadable()) { + return Unpooled.EMPTY_BUFFER; + } + + Writer writer; + if (isSharable) { + writer = ctx.channel().attr(ATTR).get(); + } else { + writer = this.writer; + } + + // If Writer is 'null' then Writer is not open. 
+ if (writer == null) { + return Unpooled.EMPTY_BUFFER; + } else { + writer.encode(msg, preferDirect); + return writer.writableBuffer; + } + } + + @Override + public boolean isSharable() { + return isSharable; + } + + /** + * Finish the encoding, close streams and write final {@link ByteBuf} to the channel. + * + * @param ctx {@link ChannelHandlerContext} which we want to close + * @throws IOException If an error occurred during closure + */ + public void finish(ChannelHandlerContext ctx) throws IOException { + finishEncode(ctx, ctx.newPromise()); + } + + private ChannelFuture finishEncode(ChannelHandlerContext ctx, ChannelPromise promise) throws IOException { + Writer writer; + + if (isSharable) { + writer = ctx.channel().attr(ATTR).getAndSet(null); + } else { + writer = this.writer; + } + + if (writer != null) { + writer.close(); + this.writer = null; + } + return promise; + } + + @Override + public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { + ChannelFuture f = finishEncode(ctx, ctx.newPromise()); + EncoderUtil.closeAfterFinishEncode(ctx, f, promise); + } + + /** + * {@link Writer} is the implementation of {@link WritableByteChannel} which encodes + * Brotli data and stores it into {@link ByteBuf}. + */ + private static final class Writer implements WritableByteChannel { + + private ByteBuf writableBuffer; + private final BrotliEncoderChannel brotliEncoderChannel; + private final ChannelHandlerContext ctx; + private boolean isClosed; + + private Writer(Encoder.Parameters parameters, ChannelHandlerContext ctx) throws IOException { + brotliEncoderChannel = new BrotliEncoderChannel(this, parameters); + this.ctx = ctx; + } + + private void encode(ByteBuf msg, boolean preferDirect) throws Exception { + try { + allocate(preferDirect); + + // Compress data and flush it into Buffer. + // + // As soon as we call flush, Encoder will be triggered to write encoded + // data into WritableByteChannel. 
+ // + // A race condition will not arise because one flush call to encoder will result + // in only 1 call at `write(ByteBuffer)`. + ByteBuffer nioBuffer = CompressionUtil.safeReadableNioBuffer(msg); + int position = nioBuffer.position(); + brotliEncoderChannel.write(nioBuffer); + msg.skipBytes(nioBuffer.position() - position); + brotliEncoderChannel.flush(); + } catch (Exception e) { + ReferenceCountUtil.release(msg); + throw e; + } + } + + private void allocate(boolean preferDirect) { + if (preferDirect) { + writableBuffer = ctx.alloc().ioBuffer(); + } else { + writableBuffer = ctx.alloc().buffer(); + } + } + + @Override + public int write(ByteBuffer src) throws IOException { + if (!isOpen()) { + throw new ClosedChannelException(); + } + + return writableBuffer.writeBytes(src).readableBytes(); + } + + @Override + public boolean isOpen() { + return !isClosed; + } + + @Override + public void close() { + final ChannelPromise promise = ctx.newPromise(); + + ctx.executor().execute(new Runnable() { + @Override + public void run() { + try { + finish(promise); + } catch (IOException ex) { + promise.setFailure(new IllegalStateException("Failed to finish encoding", ex)); + } + } + }); + } + + public void finish(final ChannelPromise promise) throws IOException { + if (!isClosed) { + // Allocate a buffer and write last pending data. + allocate(true); + + try { + brotliEncoderChannel.close(); + isClosed = true; + } catch (Exception ex) { + promise.setFailure(ex); + + // Since we have already allocated Buffer for close operation, + // we will release that buffer to prevent memory leak. 
+ ReferenceCountUtil.release(writableBuffer); + return; + } + + ctx.writeAndFlush(writableBuffer, promise); + } + } + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliOptions.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliOptions.java new file mode 100644 index 0000000..737f5d1 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/BrotliOptions.java @@ -0,0 +1,47 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import com.aayushatharva.brotli4j.encoder.Encoder; +import io.netty.util.internal.ObjectUtil; + +/** + * {@link BrotliOptions} holds {@link Encoder.Parameters} for + * Brotli compression. 
+ */ +public final class BrotliOptions implements CompressionOptions { + + private final Encoder.Parameters parameters; + + /** + * @see StandardCompressionOptions#brotli() + */ + static final BrotliOptions DEFAULT = new BrotliOptions( + new Encoder.Parameters().setQuality(4).setMode(Encoder.Mode.TEXT) + ); + + BrotliOptions(Encoder.Parameters parameters) { + if (!Brotli.isAvailable()) { + throw new IllegalStateException("Brotli is not available", Brotli.cause()); + } + + this.parameters = ObjectUtil.checkNotNull(parameters, "Parameters"); + } + + public Encoder.Parameters parameters() { + return parameters; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ByteBufChecksum.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ByteBufChecksum.java new file mode 100644 index 0000000..555a194 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ByteBufChecksum.java @@ -0,0 +1,142 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.util.ByteProcessor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; + +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.util.zip.Adler32; +import java.util.zip.CRC32; +import java.util.zip.Checksum; + +/** + * {@link Checksum} implementation which can directly act on a {@link ByteBuf}. + * + * Implementations may optimize access patterns depending on if the {@link ByteBuf} is backed by a + * byte array ({@link ByteBuf#hasArray()} is {@code true}) or not. + */ +abstract class ByteBufChecksum implements Checksum { + private static final Method ADLER32_UPDATE_METHOD; + private static final Method CRC32_UPDATE_METHOD; + + static { + // See if we can use fast-path when using ByteBuf that is not heap based as Adler32 and CRC32 added support + // for update(ByteBuffer) in JDK8. + ADLER32_UPDATE_METHOD = updateByteBuffer(new Adler32()); + CRC32_UPDATE_METHOD = updateByteBuffer(new CRC32()); + } + + private final ByteProcessor updateProcessor = new ByteProcessor() { + @Override + public boolean process(byte value) throws Exception { + update(value); + return true; + } + }; + + private static Method updateByteBuffer(Checksum checksum) { + if (PlatformDependent.javaVersion() >= 8) { + try { + Method method = checksum.getClass().getDeclaredMethod("update", ByteBuffer.class); + method.invoke(checksum, ByteBuffer.allocate(1)); + return method; + } catch (Throwable ignore) { + return null; + } + } + return null; + } + + static ByteBufChecksum wrapChecksum(Checksum checksum) { + ObjectUtil.checkNotNull(checksum, "checksum"); + if (checksum instanceof ByteBufChecksum) { + return (ByteBufChecksum) checksum; + } + if (checksum instanceof Adler32 && ADLER32_UPDATE_METHOD != null) { + return new ReflectiveByteBufChecksum(checksum, ADLER32_UPDATE_METHOD); + } + if (checksum instanceof CRC32 && CRC32_UPDATE_METHOD != 
null) { + return new ReflectiveByteBufChecksum(checksum, CRC32_UPDATE_METHOD); + } + return new SlowByteBufChecksum(checksum); + } + + /** + * @see #update(byte[], int, int) + */ + public void update(ByteBuf b, int off, int len) { + if (b.hasArray()) { + update(b.array(), b.arrayOffset() + off, len); + } else { + b.forEachByte(off, len, updateProcessor); + } + } + + private static final class ReflectiveByteBufChecksum extends SlowByteBufChecksum { + private final Method method; + + ReflectiveByteBufChecksum(Checksum checksum, Method method) { + super(checksum); + this.method = method; + } + + @Override + public void update(ByteBuf b, int off, int len) { + if (b.hasArray()) { + update(b.array(), b.arrayOffset() + off, len); + } else { + try { + method.invoke(checksum, CompressionUtil.safeNioBuffer(b, off, len)); + } catch (Throwable cause) { + throw new Error(); + } + } + } + } + + private static class SlowByteBufChecksum extends ByteBufChecksum { + + protected final Checksum checksum; + + SlowByteBufChecksum(Checksum checksum) { + this.checksum = checksum; + } + + @Override + public void update(int b) { + checksum.update(b); + } + + @Override + public void update(byte[] b, int off, int len) { + checksum.update(b, off, len); + } + + @Override + public long getValue() { + return checksum.getValue(); + } + + @Override + public void reset() { + checksum.reset(); + } + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Decoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Decoder.java new file mode 100644 index 0000000..2e96581 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Decoder.java @@ -0,0 +1,348 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.bzip2.Bzip2BitReader; +import io.netty.bzip2.Bzip2BlockDecompressor; +import io.netty.bzip2.Bzip2HuffmanStageDecoder; +import io.netty.bzip2.Bzip2MoveToFrontTable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; + +import java.util.List; + +import static io.netty.bzip2.Bzip2Constants.BASE_BLOCK_SIZE; +import static io.netty.bzip2.Bzip2Constants.BLOCK_HEADER_MAGIC_1; +import static io.netty.bzip2.Bzip2Constants.BLOCK_HEADER_MAGIC_2; +import static io.netty.bzip2.Bzip2Constants.END_OF_STREAM_MAGIC_1; +import static io.netty.bzip2.Bzip2Constants.END_OF_STREAM_MAGIC_2; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_MAXIMUM_TABLES; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_MAX_ALPHABET_SIZE; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_MINIMUM_TABLES; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SELECTOR_LIST_MAX_LENGTH; +import static io.netty.bzip2.Bzip2Constants.HUFFMAN_SYMBOL_RANGE_SIZE; +import static io.netty.bzip2.Bzip2Constants.MAGIC_NUMBER; +import static io.netty.bzip2.Bzip2Constants.MAX_BLOCK_SIZE; +import static io.netty.bzip2.Bzip2Constants.MAX_SELECTORS; +import static io.netty.bzip2.Bzip2Constants.MIN_BLOCK_SIZE; + +/** + * Uncompresses a {@link ByteBuf} encoded with the Bzip2 format. + * + * See Bzip2. + */ +public class Bzip2Decoder extends ByteToMessageDecoder { + /** + * Current state of stream. 
+ */ + private enum State { + INIT, + INIT_BLOCK, + INIT_BLOCK_PARAMS, + RECEIVE_HUFFMAN_USED_MAP, + RECEIVE_HUFFMAN_USED_BITMAPS, + RECEIVE_SELECTORS_NUMBER, + RECEIVE_SELECTORS, + RECEIVE_HUFFMAN_LENGTH, + DECODE_HUFFMAN_DATA, + EOF + } + private State currentState = State.INIT; + + /** + * A reader that provides bit-level reads. + */ + private final Bzip2BitReader reader = new Bzip2BitReader(); + + /** + * The decompressor for the current block. + */ + private Bzip2BlockDecompressor blockDecompressor; + + /** + * Bzip2 Huffman coding stage. + */ + private Bzip2HuffmanStageDecoder huffmanStageDecoder; + + /** + * Always: in the range 0 .. 9. The current block size is 100000 * this number. + */ + private int blockSize; + + /** + * The CRC of the current block as read from the block header. + */ + private int blockCRC; + + /** + * The merged CRC of all blocks decompressed so far. + */ + private int streamCRC; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (!in.isReadable()) { + return; + } + + final Bzip2BitReader reader = this.reader; + reader.setByteBuf(in); + + for (;;) { + switch (currentState) { + case INIT: + if (in.readableBytes() < 4) { + return; + } + int magicNumber = in.readUnsignedMedium(); + if (magicNumber != MAGIC_NUMBER) { + throw new DecompressionException("Unexpected stream identifier contents. Mismatched bzip2 " + + "protocol version?"); + } + int blockSize = in.readByte() - '0'; + if (blockSize < MIN_BLOCK_SIZE || blockSize > MAX_BLOCK_SIZE) { + throw new DecompressionException("block size is invalid"); + } + this.blockSize = blockSize * BASE_BLOCK_SIZE; + + streamCRC = 0; + currentState = State.INIT_BLOCK; + // fall through + case INIT_BLOCK: + if (!reader.hasReadableBytes(10)) { + return; + } + // Get the block magic bytes. 
+ final int magic1 = reader.readBits(24); + final int magic2 = reader.readBits(24); + if (magic1 == END_OF_STREAM_MAGIC_1 && magic2 == END_OF_STREAM_MAGIC_2) { + // End of stream was reached. Check the combined CRC. + final int storedCombinedCRC = reader.readInt(); + if (storedCombinedCRC != streamCRC) { + throw new DecompressionException("stream CRC error"); + } + currentState = State.EOF; + break; + } + if (magic1 != BLOCK_HEADER_MAGIC_1 || magic2 != BLOCK_HEADER_MAGIC_2) { + throw new DecompressionException("bad block header"); + } + blockCRC = reader.readInt(); + currentState = State.INIT_BLOCK_PARAMS; + // fall through + case INIT_BLOCK_PARAMS: + if (!reader.hasReadableBits(25)) { + return; + } + final boolean blockRandomised = reader.readBoolean(); + final int bwtStartPointer = reader.readBits(24); + + blockDecompressor = new Bzip2BlockDecompressor(this.blockSize, blockCRC, + blockRandomised, bwtStartPointer, reader); + currentState = State.RECEIVE_HUFFMAN_USED_MAP; + // fall through + case RECEIVE_HUFFMAN_USED_MAP: + if (!reader.hasReadableBits(16)) { + return; + } + blockDecompressor.huffmanInUse16 = reader.readBits(16); + currentState = State.RECEIVE_HUFFMAN_USED_BITMAPS; + // fall through + case RECEIVE_HUFFMAN_USED_BITMAPS: + Bzip2BlockDecompressor blockDecompressor = this.blockDecompressor; + final int inUse16 = blockDecompressor.huffmanInUse16; + final int bitNumber = Integer.bitCount(inUse16); + final byte[] huffmanSymbolMap = blockDecompressor.huffmanSymbolMap; + + if (!reader.hasReadableBits(bitNumber * HUFFMAN_SYMBOL_RANGE_SIZE + 3)) { + return; + } + + int huffmanSymbolCount = 0; + if (bitNumber > 0) { + for (int i = 0; i < 16; i++) { + if ((inUse16 & 1 << 15 >>> i) != 0) { + for (int j = 0, k = i << 4; j < HUFFMAN_SYMBOL_RANGE_SIZE; j++, k++) { + if (reader.readBoolean()) { + huffmanSymbolMap[huffmanSymbolCount++] = (byte) k; + } + } + } + } + } + blockDecompressor.huffmanEndOfBlockSymbol = huffmanSymbolCount + 1; + + int totalTables = 
reader.readBits(3); + if (totalTables < HUFFMAN_MINIMUM_TABLES || totalTables > HUFFMAN_MAXIMUM_TABLES) { + throw new DecompressionException("incorrect huffman groups number"); + } + int alphaSize = huffmanSymbolCount + 2; + if (alphaSize > HUFFMAN_MAX_ALPHABET_SIZE) { + throw new DecompressionException("incorrect alphabet size"); + } + huffmanStageDecoder = new Bzip2HuffmanStageDecoder(reader, totalTables, alphaSize); + currentState = State.RECEIVE_SELECTORS_NUMBER; + // fall through + case RECEIVE_SELECTORS_NUMBER: + if (!reader.hasReadableBits(15)) { + return; + } + int totalSelectors = reader.readBits(15); + if (totalSelectors < 1 || totalSelectors > MAX_SELECTORS) { + throw new DecompressionException("incorrect selectors number"); + } + huffmanStageDecoder.selectors = new byte[totalSelectors]; + + currentState = State.RECEIVE_SELECTORS; + // fall through + case RECEIVE_SELECTORS: + Bzip2HuffmanStageDecoder huffmanStageDecoder = this.huffmanStageDecoder; + byte[] selectors = huffmanStageDecoder.selectors; + totalSelectors = selectors.length; + final Bzip2MoveToFrontTable tableMtf = huffmanStageDecoder.tableMTF; + + int currSelector; + // Get zero-terminated bit runs (0..62) of MTF'ed Huffman table. 
length = 1..6 + for (currSelector = huffmanStageDecoder.currentSelector; + currSelector < totalSelectors; currSelector++) { + if (!reader.hasReadableBits(HUFFMAN_SELECTOR_LIST_MAX_LENGTH)) { + // Save state if end of current ByteBuf was reached + huffmanStageDecoder.currentSelector = currSelector; + return; + } + int index = 0; + while (reader.readBoolean()) { + index++; + } + selectors[currSelector] = tableMtf.indexToFront(index); + } + + currentState = State.RECEIVE_HUFFMAN_LENGTH; + // fall through + case RECEIVE_HUFFMAN_LENGTH: + huffmanStageDecoder = this.huffmanStageDecoder; + totalTables = huffmanStageDecoder.totalTables; + final byte[][] codeLength = huffmanStageDecoder.tableCodeLengths; + alphaSize = huffmanStageDecoder.alphabetSize; + + /* Now the coding tables */ + int currGroup; + int currLength = huffmanStageDecoder.currentLength; + int currAlpha = 0; + boolean modifyLength = huffmanStageDecoder.modifyLength; + boolean saveStateAndReturn = false; + loop: for (currGroup = huffmanStageDecoder.currentGroup; currGroup < totalTables; currGroup++) { + // start_huffman_length + if (!reader.hasReadableBits(5)) { + saveStateAndReturn = true; + break; + } + if (currLength < 0) { + currLength = reader.readBits(5); + } + for (currAlpha = huffmanStageDecoder.currentAlpha; currAlpha < alphaSize; currAlpha++) { + // delta_bit_length: 1..40 + if (!reader.isReadable()) { + saveStateAndReturn = true; + break loop; + } + while (modifyLength || reader.readBoolean()) { // 0=>next symbol; 1=>alter length + if (!reader.isReadable()) { + modifyLength = true; + saveStateAndReturn = true; + break loop; + } + // 1=>decrement length; 0=>increment length + currLength += reader.readBoolean() ? 
-1 : 1; + modifyLength = false; + if (!reader.isReadable()) { + saveStateAndReturn = true; + break loop; + } + } + codeLength[currGroup][currAlpha] = (byte) currLength; + } + currLength = -1; + currAlpha = huffmanStageDecoder.currentAlpha = 0; + modifyLength = false; + } + if (saveStateAndReturn) { + // Save state if end of current ByteBuf was reached + huffmanStageDecoder.currentGroup = currGroup; + huffmanStageDecoder.currentLength = currLength; + huffmanStageDecoder.currentAlpha = currAlpha; + huffmanStageDecoder.modifyLength = modifyLength; + return; + } + + // Finally create the Huffman tables + huffmanStageDecoder.createHuffmanDecodingTables(); + currentState = State.DECODE_HUFFMAN_DATA; + // fall through + case DECODE_HUFFMAN_DATA: + blockDecompressor = this.blockDecompressor; + final int oldReaderIndex = in.readerIndex(); + final boolean decoded = blockDecompressor.decodeHuffmanData(this.huffmanStageDecoder); + if (!decoded) { + return; + } + // It used to avoid "Bzip2Decoder.decode() did not read anything but decoded a message" exception. + // Because previous operation may read only a few bits from Bzip2BitReader.bitBuffer and + // don't read incoming ByteBuf. + if (in.readerIndex() == oldReaderIndex && in.isReadable()) { + reader.refill(); + } + + final int blockLength = blockDecompressor.blockLength(); + ByteBuf uncompressed = ctx.alloc().buffer(blockLength); + try { + int uncByte; + while ((uncByte = blockDecompressor.read()) >= 0) { + uncompressed.writeByte(uncByte); + } + // We did read all the data, lets reset the state and do the CRC check. 
+ currentState = State.INIT_BLOCK; + int currentBlockCRC = blockDecompressor.checkCRC(); + streamCRC = (streamCRC << 1 | streamCRC >>> 31) ^ currentBlockCRC; + + out.add(uncompressed); + uncompressed = null; + } finally { + if (uncompressed != null) { + uncompressed.release(); + } + } + // Return here so the ByteBuf that was put in the List will be forwarded to the user and so can be + // released as soon as possible. + return; + case EOF: + in.skipBytes(in.readableBytes()); + return; + default: + throw new IllegalStateException(); + } + } + } + + /** + * Returns {@code true} if and only if the end of the compressed stream + * has been reached. + */ + public boolean isClosed() { + return currentState == State.EOF; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Encoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Encoder.java new file mode 100644 index 0000000..e634ee3 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Bzip2Encoder.java @@ -0,0 +1,243 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.bzip2.Bzip2BitWriter; +import io.netty.bzip2.Bzip2BlockCompressor; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.PromiseNotifier; + +import static io.netty.bzip2.Bzip2Constants.BASE_BLOCK_SIZE; +import static io.netty.bzip2.Bzip2Constants.END_OF_STREAM_MAGIC_1; +import static io.netty.bzip2.Bzip2Constants.END_OF_STREAM_MAGIC_2; +import static io.netty.bzip2.Bzip2Constants.MAGIC_NUMBER; +import static io.netty.bzip2.Bzip2Constants.MAX_BLOCK_SIZE; +import static io.netty.bzip2.Bzip2Constants.MIN_BLOCK_SIZE; + +/** + * Compresses a {@link ByteBuf} using the Bzip2 algorithm. + * + * See Bzip2. + */ +public class Bzip2Encoder extends MessageToByteEncoder { + /** + * Current state of stream. + */ + private enum State { + INIT, + INIT_BLOCK, + WRITE_DATA, + CLOSE_BLOCK + } + + private State currentState = State.INIT; + + /** + * A writer that provides bit-level writes. + */ + private final Bzip2BitWriter writer = new Bzip2BitWriter(); + + /** + * The declared maximum block size of the stream (before final run-length decoding). + */ + private final int streamBlockSize; + + /** + * The merged CRC of all blocks compressed so far. + */ + private int streamCRC; + + /** + * The compressor for the current block. + */ + private Bzip2BlockCompressor blockCompressor; + + /** + * {@code true} if the compressed stream has been finished, otherwise {@code false}. + */ + private volatile boolean finished; + + /** + * Used to interact with its {@link ChannelPipeline} and other handlers. + */ + private volatile ChannelHandlerContext ctx; + + /** + * Creates a new bzip2 encoder with the maximum (900,000 byte) block size. 
+ */ + public Bzip2Encoder() { + this(MAX_BLOCK_SIZE); + } + + /** + * Creates a new bzip2 encoder with the specified {@code blockSizeMultiplier}. + * @param blockSizeMultiplier + * The Bzip2 block size as a multiple of 100,000 bytes (minimum {@code 1}, maximum {@code 9}). + * Larger block sizes require more memory for both compression and decompression, + * but give better compression ratios. {@code 9} will usually be the best value to use. + */ + public Bzip2Encoder(final int blockSizeMultiplier) { + if (blockSizeMultiplier < MIN_BLOCK_SIZE || blockSizeMultiplier > MAX_BLOCK_SIZE) { + throw new IllegalArgumentException( + "blockSizeMultiplier: " + blockSizeMultiplier + " (expected: 1-9)"); + } + streamBlockSize = blockSizeMultiplier * BASE_BLOCK_SIZE; + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { + if (finished) { + out.writeBytes(in); + return; + } + + for (;;) { + switch (currentState) { + case INIT: + out.ensureWritable(4); + out.writeMedium(MAGIC_NUMBER); + out.writeByte('0' + streamBlockSize / BASE_BLOCK_SIZE); + currentState = State.INIT_BLOCK; + // fall through + case INIT_BLOCK: + blockCompressor = new Bzip2BlockCompressor(writer, streamBlockSize); + currentState = State.WRITE_DATA; + // fall through + case WRITE_DATA: + if (!in.isReadable()) { + return; + } + Bzip2BlockCompressor blockCompressor = this.blockCompressor; + final int length = Math.min(in.readableBytes(), blockCompressor.availableSize()); + final int bytesWritten = blockCompressor.write(in, in.readerIndex(), length); + in.skipBytes(bytesWritten); + if (!blockCompressor.isFull()) { + if (in.isReadable()) { + break; + } else { + return; + } + } + currentState = State.CLOSE_BLOCK; + // fall through + case CLOSE_BLOCK: + closeBlock(out); + currentState = State.INIT_BLOCK; + break; + default: + throw new IllegalStateException(); + } + } + } + + /** + * Close current block and update {@link #streamCRC}. 
+ */ + private void closeBlock(ByteBuf out) { + final Bzip2BlockCompressor blockCompressor = this.blockCompressor; + if (!blockCompressor.isEmpty()) { + blockCompressor.close(out); + final int blockCRC = blockCompressor.crc(); + streamCRC = (streamCRC << 1 | streamCRC >>> 31) ^ blockCRC; + } + } + + /** + * Returns {@code true} if and only if the end of the compressed stream has been reached. + */ + public boolean isClosed() { + return finished; + } + + /** + * Close this {@link Bzip2Encoder} and so finish the encoding. + * + * The returned {@link ChannelFuture} will be notified once the operation completes. + */ + public ChannelFuture close() { + return close(ctx().newPromise()); + } + + /** + * Close this {@link Bzip2Encoder} and so finish the encoding. + * The given {@link ChannelFuture} will be notified once the operation + * completes and will also be returned. + */ + public ChannelFuture close(final ChannelPromise promise) { + ChannelHandlerContext ctx = ctx(); + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + return finishEncode(ctx, promise); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + ChannelFuture f = finishEncode(ctx(), promise); + PromiseNotifier.cascade(f, promise); + } + }); + return promise; + } + } + + @Override + public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { + ChannelFuture f = finishEncode(ctx, ctx.newPromise()); + EncoderUtil.closeAfterFinishEncode(ctx, f, promise); + } + + private ChannelFuture finishEncode(final ChannelHandlerContext ctx, ChannelPromise promise) { + if (finished) { + promise.setSuccess(); + return promise; + } + finished = true; + + final ByteBuf footer = ctx.alloc().buffer(); + closeBlock(footer); + + final int streamCRC = this.streamCRC; + final Bzip2BitWriter writer = this.writer; + try { + writer.writeBits(footer, 24, END_OF_STREAM_MAGIC_1); + writer.writeBits(footer, 24, END_OF_STREAM_MAGIC_2); + 
writer.writeInt(footer, streamCRC); + writer.flush(footer); + } finally { + blockCompressor = null; + } + return ctx.writeAndFlush(footer, promise); + } + + private ChannelHandlerContext ctx() { + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException("not added to a pipeline"); + } + return ctx; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionException.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionException.java new file mode 100644 index 0000000..ccadcf8 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.handler.codec.EncoderException; + +/** + * An {@link EncoderException} that is raised when compression failed. + */ +public class CompressionException extends EncoderException { + + private static final long serialVersionUID = 5603413481274811897L; + + /** + * Creates a new instance. + */ + public CompressionException() { + } + + /** + * Creates a new instance. 
+ */ + public CompressionException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public CompressionException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public CompressionException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionOptions.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionOptions.java new file mode 100644 index 0000000..9ee9646 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionOptions.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +/** + * {@link CompressionOptions} provides compression options for + * various types of compressor types, like Brotli. + * + * A {@link CompressionOptions} instance is thread-safe + * and should be shared between multiple instances of Compressor. 
+ */ +public interface CompressionOptions { + // Empty +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java new file mode 100644 index 0000000..d2a06f9 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/CompressionUtil.java @@ -0,0 +1,47 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; + +import java.nio.ByteBuffer; + +final class CompressionUtil { + + private CompressionUtil() { } + + static void checkChecksum(ByteBufChecksum checksum, ByteBuf uncompressed, int currentChecksum) { + checksum.reset(); + checksum.update(uncompressed, + uncompressed.readerIndex(), uncompressed.readableBytes()); + + final int checksumResult = (int) checksum.getValue(); + if (checksumResult != currentChecksum) { + throw new DecompressionException(String.format( + "stream corrupted: mismatching checksum: %d (expected: %d)", + checksumResult, currentChecksum)); + } + } + + static ByteBuffer safeReadableNioBuffer(ByteBuf buffer) { + return safeNioBuffer(buffer, buffer.readerIndex(), buffer.readableBytes()); + } + + static ByteBuffer safeNioBuffer(ByteBuf buffer, int index, int length) { + return buffer.nioBufferCount() == 1 ? buffer.internalNioBuffer(index, length) + : buffer.nioBuffer(index, length); + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Crc32c.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Crc32c.java new file mode 100644 index 0000000..f4b4a52 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Crc32c.java @@ -0,0 +1,125 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +/** + * Implements CRC32-C as defined in: + * "Optimization of Cyclic Redundancy-CHeck Codes with 24 and 32 Parity Bits", + * IEEE Transactions on Communications 41(6): 883-892 (1993). + * + * The implementation of this class has been sourced from the Appendix of RFC 3309, + * but with masking due to Java not being able to support unsigned types. + */ +class Crc32c extends ByteBufChecksum { + private static final int[] CRC_TABLE = { + 0x00000000, 0xF26B8303, 0xE13B70F7, 0x1350F3F4, + 0xC79A971F, 0x35F1141C, 0x26A1E7E8, 0xD4CA64EB, + 0x8AD958CF, 0x78B2DBCC, 0x6BE22838, 0x9989AB3B, + 0x4D43CFD0, 0xBF284CD3, 0xAC78BF27, 0x5E133C24, + 0x105EC76F, 0xE235446C, 0xF165B798, 0x030E349B, + 0xD7C45070, 0x25AFD373, 0x36FF2087, 0xC494A384, + 0x9A879FA0, 0x68EC1CA3, 0x7BBCEF57, 0x89D76C54, + 0x5D1D08BF, 0xAF768BBC, 0xBC267848, 0x4E4DFB4B, + 0x20BD8EDE, 0xD2D60DDD, 0xC186FE29, 0x33ED7D2A, + 0xE72719C1, 0x154C9AC2, 0x061C6936, 0xF477EA35, + 0xAA64D611, 0x580F5512, 0x4B5FA6E6, 0xB93425E5, + 0x6DFE410E, 0x9F95C20D, 0x8CC531F9, 0x7EAEB2FA, + 0x30E349B1, 0xC288CAB2, 0xD1D83946, 0x23B3BA45, + 0xF779DEAE, 0x05125DAD, 0x1642AE59, 0xE4292D5A, + 0xBA3A117E, 0x4851927D, 0x5B016189, 0xA96AE28A, + 0x7DA08661, 0x8FCB0562, 0x9C9BF696, 0x6EF07595, + 0x417B1DBC, 0xB3109EBF, 0xA0406D4B, 0x522BEE48, + 0x86E18AA3, 0x748A09A0, 0x67DAFA54, 0x95B17957, + 0xCBA24573, 0x39C9C670, 0x2A993584, 0xD8F2B687, + 0x0C38D26C, 0xFE53516F, 0xED03A29B, 0x1F682198, + 0x5125DAD3, 0xA34E59D0, 0xB01EAA24, 0x42752927, + 0x96BF4DCC, 0x64D4CECF, 0x77843D3B, 0x85EFBE38, + 0xDBFC821C, 0x2997011F, 0x3AC7F2EB, 0xC8AC71E8, + 0x1C661503, 0xEE0D9600, 0xFD5D65F4, 0x0F36E6F7, + 0x61C69362, 0x93AD1061, 0x80FDE395, 0x72966096, + 0xA65C047D, 0x5437877E, 0x4767748A, 0xB50CF789, + 0xEB1FCBAD, 0x197448AE, 0x0A24BB5A, 0xF84F3859, + 0x2C855CB2, 0xDEEEDFB1, 0xCDBE2C45, 0x3FD5AF46, + 
0x7198540D, 0x83F3D70E, 0x90A324FA, 0x62C8A7F9, + 0xB602C312, 0x44694011, 0x5739B3E5, 0xA55230E6, + 0xFB410CC2, 0x092A8FC1, 0x1A7A7C35, 0xE811FF36, + 0x3CDB9BDD, 0xCEB018DE, 0xDDE0EB2A, 0x2F8B6829, + 0x82F63B78, 0x709DB87B, 0x63CD4B8F, 0x91A6C88C, + 0x456CAC67, 0xB7072F64, 0xA457DC90, 0x563C5F93, + 0x082F63B7, 0xFA44E0B4, 0xE9141340, 0x1B7F9043, + 0xCFB5F4A8, 0x3DDE77AB, 0x2E8E845F, 0xDCE5075C, + 0x92A8FC17, 0x60C37F14, 0x73938CE0, 0x81F80FE3, + 0x55326B08, 0xA759E80B, 0xB4091BFF, 0x466298FC, + 0x1871A4D8, 0xEA1A27DB, 0xF94AD42F, 0x0B21572C, + 0xDFEB33C7, 0x2D80B0C4, 0x3ED04330, 0xCCBBC033, + 0xA24BB5A6, 0x502036A5, 0x4370C551, 0xB11B4652, + 0x65D122B9, 0x97BAA1BA, 0x84EA524E, 0x7681D14D, + 0x2892ED69, 0xDAF96E6A, 0xC9A99D9E, 0x3BC21E9D, + 0xEF087A76, 0x1D63F975, 0x0E330A81, 0xFC588982, + 0xB21572C9, 0x407EF1CA, 0x532E023E, 0xA145813D, + 0x758FE5D6, 0x87E466D5, 0x94B49521, 0x66DF1622, + 0x38CC2A06, 0xCAA7A905, 0xD9F75AF1, 0x2B9CD9F2, + 0xFF56BD19, 0x0D3D3E1A, 0x1E6DCDEE, 0xEC064EED, + 0xC38D26C4, 0x31E6A5C7, 0x22B65633, 0xD0DDD530, + 0x0417B1DB, 0xF67C32D8, 0xE52CC12C, 0x1747422F, + 0x49547E0B, 0xBB3FFD08, 0xA86F0EFC, 0x5A048DFF, + 0x8ECEE914, 0x7CA56A17, 0x6FF599E3, 0x9D9E1AE0, + 0xD3D3E1AB, 0x21B862A8, 0x32E8915C, 0xC083125F, + 0x144976B4, 0xE622F5B7, 0xF5720643, 0x07198540, + 0x590AB964, 0xAB613A67, 0xB831C993, 0x4A5A4A90, + 0x9E902E7B, 0x6CFBAD78, 0x7FAB5E8C, 0x8DC0DD8F, + 0xE330A81A, 0x115B2B19, 0x020BD8ED, 0xF0605BEE, + 0x24AA3F05, 0xD6C1BC06, 0xC5914FF2, 0x37FACCF1, + 0x69E9F0D5, 0x9B8273D6, 0x88D28022, 0x7AB90321, + 0xAE7367CA, 0x5C18E4C9, 0x4F48173D, 0xBD23943E, + 0xF36E6F75, 0x0105EC76, 0x12551F82, 0xE03E9C81, + 0x34F4F86A, 0xC69F7B69, 0xD5CF889D, 0x27A40B9E, + 0x79B737BA, 0x8BDCB4B9, 0x988C474D, 0x6AE7C44E, + 0xBE2DA0A5, 0x4C4623A6, 0x5F16D052, 0xAD7D5351, + }; + + private static final long LONG_MASK = 0xFFFFFFFFL; + private static final int BYTE_MASK = 0xFF; + + private int crc = ~0; + + @Override + public void update(int b) { + crc = crc32c(crc, b); + 
} + + @Override + public void update(byte[] buffer, int offset, int length) { + int end = offset + length; + for (int i = offset; i < end; i++) { + update(buffer[i]); + } + } + + @Override + public long getValue() { + return (crc ^ LONG_MASK) & LONG_MASK; + } + + @Override + public void reset() { + crc = ~0; + } + + private static int crc32c(int crc, int b) { + return crc >>> 8 ^ CRC_TABLE[(crc ^ b & BYTE_MASK) & BYTE_MASK]; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DecompressionException.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DecompressionException.java new file mode 100644 index 0000000..b9dae26 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DecompressionException.java @@ -0,0 +1,53 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.handler.codec.DecoderException; + +/** + * A {@link DecoderException} that is raised when decompression failed. + */ +public class DecompressionException extends DecoderException { + + private static final long serialVersionUID = 3546272712208105199L; + + /** + * Creates a new instance. + */ + public DecompressionException() { + } + + /** + * Creates a new instance. 
+ */ + public DecompressionException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public DecompressionException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public DecompressionException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DeflateOptions.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DeflateOptions.java new file mode 100644 index 0000000..85a67aa --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/DeflateOptions.java @@ -0,0 +1,57 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.util.internal.ObjectUtil; + +/** + * {@link DeflateOptions} holds {@link #compressionLevel()}, + * {@link #memLevel()} and {@link #windowBits()} for Deflate compression. 
+ */ +public class DeflateOptions implements CompressionOptions { + + private final int compressionLevel; + private final int windowBits; + private final int memLevel; + + /** + * @see StandardCompressionOptions#deflate() + */ + static final DeflateOptions DEFAULT = new DeflateOptions( + 6, 15, 8 + ); + + /** + * @see StandardCompressionOptions#deflate(int, int, int) + */ + DeflateOptions(int compressionLevel, int windowBits, int memLevel) { + this.compressionLevel = ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel"); + this.windowBits = ObjectUtil.checkInRange(windowBits, 9, 15, "windowBits"); + this.memLevel = ObjectUtil.checkInRange(memLevel, 1, 9, "memLevel"); + } + + public int compressionLevel() { + return compressionLevel; + } + + public int windowBits() { + return windowBits; + } + + public int memLevel() { + return memLevel; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/EncoderUtil.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/EncoderUtil.java new file mode 100644 index 0000000..477fdfd --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/EncoderUtil.java @@ -0,0 +1,57 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.Future; + +import java.util.concurrent.TimeUnit; + +final class EncoderUtil { + private static final int THREAD_POOL_DELAY_SECONDS = 10; + + static void closeAfterFinishEncode(final ChannelHandlerContext ctx, final ChannelFuture finishFuture, + final ChannelPromise promise) { + if (!finishFuture.isDone()) { + // Ensure the channel is closed even if the write operation completes in time. + final Future future = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + ctx.close(promise); + } + }, THREAD_POOL_DELAY_SECONDS, TimeUnit.SECONDS); + + finishFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture f) { + // Cancel the scheduled timeout. + future.cancel(true); + if (!promise.isDone()) { + ctx.close(promise); + } + } + }); + } else { + ctx.close(promise); + } + } + + private EncoderUtil() { } +} + diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLz.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLz.java new file mode 100644 index 0000000..bfb4f00 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLz.java @@ -0,0 +1,560 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; + +/** + * Core of FastLZ compression algorithm. + * + * This class provides methods for compression and decompression of buffers and saves + * constants which use by {@link FastLzFrameEncoder} and {@link FastLzFrameDecoder}. + * + * This is refactored code of jfastlz + * library written by William Kinney. + */ +final class FastLz { + + private static final int MAX_DISTANCE = 8191; + private static final int MAX_FARDISTANCE = 65535 + MAX_DISTANCE - 1; + + private static final int HASH_LOG = 13; + private static final int HASH_SIZE = 1 << HASH_LOG; // 8192 + private static final int HASH_MASK = HASH_SIZE - 1; + + private static final int MAX_COPY = 32; + private static final int MAX_LEN = 256 + 8; + + private static final int MIN_RECOMENDED_LENGTH_FOR_LEVEL_2 = 1024 * 64; + + static final int MAGIC_NUMBER = 'F' << 16 | 'L' << 8 | 'Z'; + + static final byte BLOCK_TYPE_NON_COMPRESSED = 0x00; + static final byte BLOCK_TYPE_COMPRESSED = 0x01; + static final byte BLOCK_WITHOUT_CHECKSUM = 0x00; + static final byte BLOCK_WITH_CHECKSUM = 0x10; + + static final int OPTIONS_OFFSET = 3; + static final int CHECKSUM_OFFSET = 4; + + static final int MAX_CHUNK_LENGTH = 0xFFFF; + + /** + * Do not call {@link #compress(ByteBuf, int, int, ByteBuf, int, int)} for input buffers + * which length less than this value. 
+ */ + static final int MIN_LENGTH_TO_COMPRESSION = 32; + + /** + * In this case {@link #compress(ByteBuf, int, int, ByteBuf, int, int)} will choose level + * automatically depending on the length of the input buffer. If length less than + * {@link #MIN_RECOMENDED_LENGTH_FOR_LEVEL_2} {@link #LEVEL_1} will be chosen, + * otherwise {@link #LEVEL_2}. + */ + static final int LEVEL_AUTO = 0; + + /** + * Level 1 is the fastest compression and generally useful for short data. + */ + static final int LEVEL_1 = 1; + + /** + * Level 2 is slightly slower but it gives better compression ratio. + */ + static final int LEVEL_2 = 2; + + /** + * The output buffer must be at least 6% larger than the input buffer and can not be smaller than 66 bytes. + * @param inputLength length of input buffer + * @return Maximum output buffer length + */ + static int calculateOutputBufferLength(int inputLength) { + final int outputLength = (int) (inputLength * 1.06); + return Math.max(outputLength, 66); + } + + /** + * Compress a block of data in the input buffer and returns the size of compressed block. + * The size of input buffer is specified by length. The minimum input buffer size is 32. + * + * If the input is not compressible, the return value might be larger than length (input buffer size). + */ + @SuppressWarnings("IdentityBinaryExpression") + static int compress(final ByteBuf input, final int inOffset, final int inLength, + final ByteBuf output, final int outOffset, final int proposedLevel) { + final int level; + if (proposedLevel == LEVEL_AUTO) { + level = inLength < MIN_RECOMENDED_LENGTH_FOR_LEVEL_2 ? 
LEVEL_1 : LEVEL_2; + } else { + level = proposedLevel; + } + + int ip = 0; + int ipBound = ip + inLength - 2; + int ipLimit = ip + inLength - 12; + + int op = 0; + + // const flzuint8* htab[HASH_SIZE]; + int[] htab = new int[HASH_SIZE]; + // const flzuint8** hslot; + int hslot; + // flzuint32 hval; + // int OK b/c address starting from 0 + int hval; + // flzuint32 copy; + // int OK b/c address starting from 0 + int copy; + + /* sanity check */ + if (inLength < 4) { + if (inLength != 0) { + // *op++ = length-1; + output.setByte(outOffset + op++, (byte) (inLength - 1)); + ipBound++; + while (ip <= ipBound) { + output.setByte(outOffset + op++, input.getByte(inOffset + ip++)); + } + return inLength + 1; + } + // else + return 0; + } + + /* initializes hash table */ + // for (hslot = htab; hslot < htab + HASH_SIZE; hslot++) + for (hslot = 0; hslot < HASH_SIZE; hslot++) { + //*hslot = ip; + htab[hslot] = ip; + } + + /* we start with literal copy */ + copy = 2; + output.setByte(outOffset + op++, MAX_COPY - 1); + output.setByte(outOffset + op++, input.getByte(inOffset + ip++)); + output.setByte(outOffset + op++, input.getByte(inOffset + ip++)); + + /* main loop */ + while (ip < ipLimit) { + int ref = 0; + + long distance = 0; + + /* minimum match length */ + // flzuint32 len = 3; + // int OK b/c len is 0 and octal based + int len = 3; + + /* comparison starting-point */ + int anchor = ip; + + boolean matchLabel = false; + + /* check for a run */ + if (level == LEVEL_2) { + //if(ip[0] == ip[-1] && FASTLZ_READU16(ip-1)==FASTLZ_READU16(ip+1)) + if (input.getByte(inOffset + ip) == input.getByte(inOffset + ip - 1) && + readU16(input, inOffset + ip - 1) == readU16(input, inOffset + ip + 1)) { + distance = 1; + ip += 3; + ref = anchor + (3 - 1); + + /* + * goto match; + */ + matchLabel = true; + } + } + if (!matchLabel) { + /* find potential match */ + // HASH_FUNCTION(hval,ip); + hval = hashFunction(input, inOffset + ip); + // hslot = htab + hval; + hslot = hval; + // ref = 
htab[hval]; + ref = htab[hval]; + + /* calculate distance to the match */ + distance = anchor - ref; + + /* update hash table */ + //*hslot = anchor; + htab[hslot] = anchor; + + /* is this a match? check the first 3 bytes */ + if (distance == 0 + || (level == LEVEL_1 ? distance >= MAX_DISTANCE : distance >= MAX_FARDISTANCE) + || input.getByte(inOffset + ref++) != input.getByte(inOffset + ip++) + || input.getByte(inOffset + ref++) != input.getByte(inOffset + ip++) + || input.getByte(inOffset + ref++) != input.getByte(inOffset + ip++)) { + /* + * goto literal; + */ + output.setByte(outOffset + op++, input.getByte(inOffset + anchor++)); + ip = anchor; + copy++; + if (copy == MAX_COPY) { + copy = 0; + output.setByte(outOffset + op++, MAX_COPY - 1); + } + continue; + } + + if (level == LEVEL_2) { + /* far, needs at least 5-byte match */ + if (distance >= MAX_DISTANCE) { + if (input.getByte(inOffset + ip++) != input.getByte(inOffset + ref++) + || input.getByte(inOffset + ip++) != input.getByte(inOffset + ref++)) { + /* + * goto literal; + */ + output.setByte(outOffset + op++, input.getByte(inOffset + anchor++)); + ip = anchor; + copy++; + if (copy == MAX_COPY) { + copy = 0; + output.setByte(outOffset + op++, MAX_COPY - 1); + } + continue; + } + len += 2; + } + } + } // end if(!matchLabel) + /* + * match: + */ + /* last matched byte */ + ip = anchor + len; + + /* distance is biased */ + distance--; + + if (distance == 0) { + /* zero distance means a run */ + //flzuint8 x = ip[-1]; + byte x = input.getByte(inOffset + ip - 1); + while (ip < ipBound) { + if (input.getByte(inOffset + ref++) != x) { + break; + } else { + ip++; + } + } + } else { + /* safe because the outer check against ip limit */ + boolean missMatch = false; + for (int i = 0; i < 8; i++) { + if (input.getByte(inOffset + ref++) != input.getByte(inOffset + ip++)) { + missMatch = true; + break; + } + } + if (!missMatch) { + while (ip < ipBound) { + if (input.getByte(inOffset + ref++) != input.getByte(inOffset + 
ip++)) { + break; + } + } + } + } + + /* if we have copied something, adjust the copy count */ + if (copy != 0) { + /* copy is biased, '0' means 1 byte copy */ + // *(op-copy-1) = copy-1; + output.setByte(outOffset + op - copy - 1, (byte) (copy - 1)); + } else { + /* back, to overwrite the copy count */ + op--; + } + + /* reset literal counter */ + copy = 0; + + /* length is biased, '1' means a match of 3 bytes */ + ip -= 3; + len = ip - anchor; + + /* encode the match */ + if (level == LEVEL_2) { + if (distance < MAX_DISTANCE) { + if (len < 7) { + output.setByte(outOffset + op++, (byte) ((len << 5) + (distance >>> 8))); + output.setByte(outOffset + op++, (byte) (distance & 255)); + } else { + output.setByte(outOffset + op++, (byte) ((7 << 5) + (distance >>> 8))); + for (len -= 7; len >= 255; len -= 255) { + output.setByte(outOffset + op++, (byte) 255); + } + output.setByte(outOffset + op++, (byte) len); + output.setByte(outOffset + op++, (byte) (distance & 255)); + } + } else { + /* far away, but not yet in the another galaxy... 
*/ + if (len < 7) { + distance -= MAX_DISTANCE; + output.setByte(outOffset + op++, (byte) ((len << 5) + 31)); + output.setByte(outOffset + op++, (byte) 255); + output.setByte(outOffset + op++, (byte) (distance >>> 8)); + output.setByte(outOffset + op++, (byte) (distance & 255)); + } else { + distance -= MAX_DISTANCE; + output.setByte(outOffset + op++, (byte) ((7 << 5) + 31)); + for (len -= 7; len >= 255; len -= 255) { + output.setByte(outOffset + op++, (byte) 255); + } + output.setByte(outOffset + op++, (byte) len); + output.setByte(outOffset + op++, (byte) 255); + output.setByte(outOffset + op++, (byte) (distance >>> 8)); + output.setByte(outOffset + op++, (byte) (distance & 255)); + } + } + } else { + if (len > MAX_LEN - 2) { + while (len > MAX_LEN - 2) { + output.setByte(outOffset + op++, (byte) ((7 << 5) + (distance >>> 8))); + output.setByte(outOffset + op++, (byte) (MAX_LEN - 2 - 7 - 2)); + output.setByte(outOffset + op++, (byte) (distance & 255)); + len -= MAX_LEN - 2; + } + } + + if (len < 7) { + output.setByte(outOffset + op++, (byte) ((len << 5) + (distance >>> 8))); + output.setByte(outOffset + op++, (byte) (distance & 255)); + } else { + output.setByte(outOffset + op++, (byte) ((7 << 5) + (distance >>> 8))); + output.setByte(outOffset + op++, (byte) (len - 7)); + output.setByte(outOffset + op++, (byte) (distance & 255)); + } + } + + /* update the hash at match boundary */ + //HASH_FUNCTION(hval,ip); + hval = hashFunction(input, inOffset + ip); + htab[hval] = ip++; + + //HASH_FUNCTION(hval,ip); + hval = hashFunction(input, inOffset + ip); + htab[hval] = ip++; + + /* assuming literal copy */ + output.setByte(outOffset + op++, MAX_COPY - 1); + + continue; + + // Moved to be inline, with a 'continue' + /* + * literal: + * + output[outOffset + op++] = input[inOffset + anchor++]; + ip = anchor; + copy++; + if(copy == MAX_COPY){ + copy = 0; + output[outOffset + op++] = MAX_COPY-1; + } + */ + } + + /* left-over as literal copy */ + ipBound++; + while (ip <= 
ipBound) { + output.setByte(outOffset + op++, input.getByte(inOffset + ip++)); + copy++; + if (copy == MAX_COPY) { + copy = 0; + output.setByte(outOffset + op++, MAX_COPY - 1); + } + } + + /* if we have copied something, adjust the copy length */ + if (copy != 0) { + //*(op-copy-1) = copy-1; + output.setByte(outOffset + op - copy - 1, (byte) (copy - 1)); + } else { + op--; + } + + if (level == LEVEL_2) { + /* marker for fastlz2 */ + output.setByte(outOffset, output.getByte(outOffset) | 1 << 5); + } + + return op; + } + + /** + * Decompress a block of compressed data and returns the size of the decompressed block. + * If error occurs, e.g. the compressed data is corrupted or the output buffer is not large + * enough, then 0 (zero) will be returned instead. + * + * Decompression is memory safe and guaranteed not to write the output buffer + * more than what is specified in outLength. + */ + static int decompress(final ByteBuf input, final int inOffset, final int inLength, + final ByteBuf output, final int outOffset, final int outLength) { + //int level = ((*(const flzuint8*)input) >> 5) + 1; + final int level = (input.getByte(inOffset) >> 5) + 1; + if (level != LEVEL_1 && level != LEVEL_2) { + throw new DecompressionException(String.format( + "invalid level: %d (expected: %d or %d)", level, LEVEL_1, LEVEL_2 + )); + } + + // const flzuint8* ip = (const flzuint8*) input; + int ip = 0; + // flzuint8* op = (flzuint8*) output; + int op = 0; + // flzuint32 ctrl = (*ip++) & 31; + long ctrl = input.getByte(inOffset + ip++) & 31; + + int loop = 1; + do { + // const flzuint8* ref = op; + int ref = op; + // flzuint32 len = ctrl >> 5; + long len = ctrl >> 5; + // flzuint32 ofs = (ctrl & 31) << 8; + long ofs = (ctrl & 31) << 8; + + if (ctrl >= 32) { + len--; + // ref -= ofs; + ref -= ofs; + + int code; + if (len == 6) { + if (level == LEVEL_1) { + // len += *ip++; + len += input.getUnsignedByte(inOffset + ip++); + } else { + do { + code = input.getUnsignedByte(inOffset + ip++); + 
len += code; + } while (code == 255); + } + } + if (level == LEVEL_1) { + // ref -= *ip++; + ref -= input.getUnsignedByte(inOffset + ip++); + } else { + code = input.getUnsignedByte(inOffset + ip++); + ref -= code; + + /* match from 16-bit distance */ + // if(FASTLZ_UNEXPECT_CONDITIONAL(code==255)) + // if(FASTLZ_EXPECT_CONDITIONAL(ofs==(31 << 8))) + if (code == 255 && ofs == 31 << 8) { + ofs = input.getUnsignedByte(inOffset + ip++) << 8; + ofs += input.getUnsignedByte(inOffset + ip++); + + ref = (int) (op - ofs - MAX_DISTANCE); + } + } + + // if the output index + length of block(?) + 3(?) is over the output limit? + if (op + len + 3 > outLength) { + return 0; + } + + // if (FASTLZ_UNEXPECT_CONDITIONAL(ref-1 < (flzuint8 *)output)) + // if the address space of ref-1 is < the address of output? + // if we are still at the beginning of the output address? + if (ref - 1 < 0) { + return 0; + } + + if (ip < inLength) { + ctrl = input.getUnsignedByte(inOffset + ip++); + } else { + loop = 0; + } + + if (ref == op) { + /* optimize copy for a run */ + // flzuint8 b = ref[-1]; + byte b = output.getByte(outOffset + ref - 1); + output.setByte(outOffset + op++, b); + output.setByte(outOffset + op++, b); + output.setByte(outOffset + op++, b); + while (len != 0) { + output.setByte(outOffset + op++, b); + --len; + } + } else { + /* copy from reference */ + ref--; + + // *op++ = *ref++; + output.setByte(outOffset + op++, output.getByte(outOffset + ref++)); + output.setByte(outOffset + op++, output.getByte(outOffset + ref++)); + output.setByte(outOffset + op++, output.getByte(outOffset + ref++)); + + while (len != 0) { + output.setByte(outOffset + op++, output.getByte(outOffset + ref++)); + --len; + } + } + } else { + ctrl++; + + if (op + ctrl > outLength) { + return 0; + } + if (ip + ctrl > inLength) { + return 0; + } + + //*op++ = *ip++; + output.setByte(outOffset + op++, input.getByte(inOffset + ip++)); + + for (--ctrl; ctrl != 0; ctrl--) { + // *op++ = *ip++; + 
output.setByte(outOffset + op++, input.getByte(inOffset + ip++)); + } + + loop = ip < inLength ? 1 : 0; + if (loop != 0) { + // ctrl = *ip++; + ctrl = input.getUnsignedByte(inOffset + ip++); + } + } + + // while(FASTLZ_EXPECT_CONDITIONAL(loop)); + } while (loop != 0); + + // return op - (flzuint8*)output; + return op; + } + + private static int hashFunction(ByteBuf p, int offset) { + int v = readU16(p, offset); + v ^= readU16(p, offset + 1) ^ v >> 16 - HASH_LOG; + v &= HASH_MASK; + return v; + } + + private static int readU16(ByteBuf data, int offset) { + if (offset + 1 >= data.readableBytes()) { + return data.getUnsignedByte(offset); + } + return data.getUnsignedByte(offset + 1) << 8 | data.getUnsignedByte(offset); + } + + private FastLz() { } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLzFrameDecoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLzFrameDecoder.java new file mode 100644 index 0000000..df36bd8 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/FastLzFrameDecoder.java @@ -0,0 +1,208 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.ByteToMessageDecoder;

import java.util.List;
import java.util.zip.Adler32;
import java.util.zip.Checksum;

import static io.netty.handler.codec.compression.FastLz.BLOCK_TYPE_COMPRESSED;
import static io.netty.handler.codec.compression.FastLz.BLOCK_WITH_CHECKSUM;
import static io.netty.handler.codec.compression.FastLz.MAGIC_NUMBER;
import static io.netty.handler.codec.compression.FastLz.decompress;

/**
 * Uncompresses a {@link ByteBuf} encoded by {@link FastLzFrameEncoder} using the FastLZ algorithm.
 *
 * See FastLZ format.
 */
public class FastLzFrameDecoder extends ByteToMessageDecoder {
    /**
     * Current state of decompression.
     */
    private enum State {
        INIT_BLOCK,
        INIT_BLOCK_PARAMS,
        DECOMPRESS_DATA,
        CORRUPTED
    }

    private State currentState = State.INIT_BLOCK;

    /**
     * Underlying checksum calculator in use, or {@code null} when checksum
     * validation is disabled.
     */
    private final ByteBufChecksum checksum;

    /**
     * Length of the current received chunk of data (as stored on the wire).
     */
    private int chunkLength;

    /**
     * Original (uncompressed) length of the current received chunk of data.
     * It is equal to {@link #chunkLength} for non-compressed chunks.
     */
    private int originalLength;

    /**
     * Indicates whether the current chunk is compressed.
     */
    private boolean isCompressed;

    /**
     * Indicates whether the current chunk carries a checksum.
     */
    private boolean hasChecksum;

    /**
     * Checksum value of the current received chunk, valid only when
     * {@link #hasChecksum} is {@code true}.
     */
    private int currentChecksum;

    /**
     * Creates the fastest FastLZ decoder without checksum calculation.
     */
    public FastLzFrameDecoder() {
        this(false);
    }

    /**
     * Creates a FastLZ decoder with calculation of checksums as specified.
     *
     * @param validateChecksums
     *        If true, the checksum field will be validated against the actual
     *        uncompressed data, and if the checksums do not match, a suitable
     *        {@link DecompressionException} will be thrown.
     *        Note, that in this case decoder will use {@link java.util.zip.Adler32}
     *        as a default checksum calculator.
     */
    public FastLzFrameDecoder(boolean validateChecksums) {
        this(validateChecksums ? new Adler32() : null);
    }

    /**
     * Creates a FastLZ decoder with specified checksum calculator.
     *
     * @param checksum
     *        the {@link Checksum} instance to use to check data for integrity.
     *        You may set {@code null} if you do not want to validate checksum of each block.
     */
    public FastLzFrameDecoder(Checksum checksum) {
        this.checksum = checksum == null ? null : ByteBufChecksum.wrapChecksum(checksum);
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        try {
            switch (currentState) {
            case INIT_BLOCK:
                // Need magic (3 bytes) + options (1 byte).
                if (in.readableBytes() < 4) {
                    break;
                }

                final int magic = in.readUnsignedMedium();
                if (magic != MAGIC_NUMBER) {
                    throw new DecompressionException("unexpected block identifier");
                }

                final byte options = in.readByte();
                isCompressed = (options & 0x01) == BLOCK_TYPE_COMPRESSED;
                hasChecksum = (options & 0x10) == BLOCK_WITH_CHECKSUM;

                currentState = State.INIT_BLOCK_PARAMS;
                // fall through
            case INIT_BLOCK_PARAMS:
                // chunkLength (2) + optional originalLength (2) + optional checksum (4).
                if (in.readableBytes() < 2 + (isCompressed ? 2 : 0) + (hasChecksum ? 4 : 0)) {
                    break;
                }
                currentChecksum = hasChecksum ? in.readInt() : 0;
                chunkLength = in.readUnsignedShort();
                originalLength = isCompressed ? in.readUnsignedShort() : chunkLength;

                currentState = State.DECOMPRESS_DATA;
                // fall through
            case DECOMPRESS_DATA:
                final int chunkLength = this.chunkLength;
                if (in.readableBytes() < chunkLength) {
                    break;
                }

                final int idx = in.readerIndex();
                final int originalLength = this.originalLength;

                ByteBuf output = null;

                try {
                    if (isCompressed) {
                        output = ctx.alloc().buffer(originalLength);
                        int outputOffset = output.writerIndex();
                        final int decompressedBytes = decompress(in, idx, chunkLength,
                                output, outputOffset, originalLength);
                        if (originalLength != decompressedBytes) {
                            throw new DecompressionException(String.format(
                                    "stream corrupted: originalLength(%d) and actual length(%d) mismatch",
                                    originalLength, decompressedBytes));
                        }
                        output.writerIndex(output.writerIndex() + decompressedBytes);
                    } else {
                        // Non-compressed chunk: expose the input bytes directly
                        // without copying.
                        output = in.retainedSlice(idx, chunkLength);
                    }

                    final ByteBufChecksum checksum = this.checksum;
                    if (hasChecksum && checksum != null) {
                        checksum.reset();
                        checksum.update(output, output.readerIndex(), output.readableBytes());
                        final int checksumResult = (int) checksum.getValue();
                        if (checksumResult != currentChecksum) {
                            throw new DecompressionException(String.format(
                                    "stream corrupted: mismatching checksum: %d (expected: %d)",
                                    checksumResult, currentChecksum));
                        }
                    }

                    if (output.readableBytes() > 0) {
                        out.add(output);
                    } else {
                        output.release();
                    }
                    // Ownership transferred (or released); prevent the finally
                    // block from double-releasing.
                    output = null;
                    in.skipBytes(chunkLength);

                    currentState = State.INIT_BLOCK;
                } finally {
                    if (output != null) {
                        output.release();
                    }
                }
                break;
            case CORRUPTED:
                // Once corrupted, drain and discard everything that arrives.
                in.skipBytes(in.readableBytes());
                break;
            default:
                throw new IllegalStateException();
            }
        } catch (Exception e) {
            currentState = State.CORRUPTED;
            throw e;
        }
    }
}
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

import java.util.zip.Adler32;
import java.util.zip.Checksum;

import static io.netty.handler.codec.compression.FastLz.BLOCK_TYPE_COMPRESSED;
import static io.netty.handler.codec.compression.FastLz.BLOCK_TYPE_NON_COMPRESSED;
import static io.netty.handler.codec.compression.FastLz.BLOCK_WITHOUT_CHECKSUM;
import static io.netty.handler.codec.compression.FastLz.BLOCK_WITH_CHECKSUM;
import static io.netty.handler.codec.compression.FastLz.CHECKSUM_OFFSET;
import static io.netty.handler.codec.compression.FastLz.LEVEL_1;
import static io.netty.handler.codec.compression.FastLz.LEVEL_2;
import static io.netty.handler.codec.compression.FastLz.LEVEL_AUTO;
import static io.netty.handler.codec.compression.FastLz.MAGIC_NUMBER;
import static io.netty.handler.codec.compression.FastLz.MAX_CHUNK_LENGTH;
import static io.netty.handler.codec.compression.FastLz.MIN_LENGTH_TO_COMPRESSION;
import static io.netty.handler.codec.compression.FastLz.OPTIONS_OFFSET;
import static io.netty.handler.codec.compression.FastLz.calculateOutputBufferLength;
import static io.netty.handler.codec.compression.FastLz.compress;

/**
 * Compresses a {@link ByteBuf} using the FastLZ algorithm.
 *
 * See FastLZ format.
 */
public class FastLzFrameEncoder extends MessageToByteEncoder<ByteBuf> {
    /**
     * Compression level.
     */
    private final int level;

    /**
     * Underlying checksum calculator in use, or {@code null} when checksum
     * generation is disabled.
     */
    private final ByteBufChecksum checksum;

    /**
     * Creates a FastLZ encoder without checksum calculator and with auto detection of compression level.
     */
    public FastLzFrameEncoder() {
        this(LEVEL_AUTO, null);
    }

    /**
     * Creates a FastLZ encoder with specified compression level and without checksum calculator.
     *
     * @param level supports only these values:
     *        0 - Encoder will choose level automatically depending on the length of the input buffer.
     *        1 - Level 1 is the fastest compression and generally useful for short data.
     *        2 - Level 2 is slightly slower but it gives better compression ratio.
     */
    public FastLzFrameEncoder(int level) {
        this(level, null);
    }

    /**
     * Creates a FastLZ encoder with auto detection of compression
     * level and calculation of checksums as specified.
     *
     * @param validateChecksums
     *        If true, the checksum of each block will be calculated and this value
     *        will be added to the header of block.
     *        By default {@link FastLzFrameEncoder} uses {@link java.util.zip.Adler32}
     *        for checksum calculation.
     */
    public FastLzFrameEncoder(boolean validateChecksums) {
        this(LEVEL_AUTO, validateChecksums ? new Adler32() : null);
    }

    /**
     * Creates a FastLZ encoder with specified compression level and checksum calculator.
     *
     * @param level supports only these values:
     *        0 - Encoder will choose level automatically depending on the length of the input buffer.
     *        1 - Level 1 is the fastest compression and generally useful for short data.
     *        2 - Level 2 is slightly slower but it gives better compression ratio.
     * @param checksum
     *        the {@link Checksum} instance to use to check data for integrity.
     *        You may set {@code null} if you don't want to validate checksum of each block.
     */
    public FastLzFrameEncoder(int level, Checksum checksum) {
        if (level != LEVEL_AUTO && level != LEVEL_1 && level != LEVEL_2) {
            throw new IllegalArgumentException(String.format(
                    "level: %d (expected: %d or %d or %d)", level, LEVEL_AUTO, LEVEL_1, LEVEL_2));
        }
        this.level = level;
        this.checksum = checksum == null ? null : ByteBufChecksum.wrapChecksum(checksum);
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception {
        final ByteBufChecksum checksum = this.checksum;

        // Emit one block per iteration until the input is drained.
        for (;;) {
            if (!in.isReadable()) {
                return;
            }
            final int idx = in.readerIndex();
            final int length = Math.min(in.readableBytes(), MAX_CHUNK_LENGTH);

            final int outputIdx = out.writerIndex();
            out.setMedium(outputIdx, MAGIC_NUMBER);
            // Payload starts after the header; the optional 4-byte checksum
            // field shifts everything that follows.
            int outputOffset = outputIdx + CHECKSUM_OFFSET + (checksum != null ? 4 : 0);

            final byte blockType;
            final int chunkLength;
            if (length < MIN_LENGTH_TO_COMPRESSION) {
                // Too short to bother compressing: store verbatim.
                blockType = BLOCK_TYPE_NON_COMPRESSED;

                out.ensureWritable(outputOffset + 2 + length);
                final int outputPtr = outputOffset + 2;

                if (checksum != null) {
                    checksum.reset();
                    checksum.update(in, idx, length);
                    out.setInt(outputIdx + CHECKSUM_OFFSET, (int) checksum.getValue());
                }
                out.setBytes(outputPtr, in, idx, length);
                chunkLength = length;
            } else {
                // try to compress
                if (checksum != null) {
                    checksum.reset();
                    checksum.update(in, idx, length);
                    out.setInt(outputIdx + CHECKSUM_OFFSET, (int) checksum.getValue());
                }

                final int maxOutputLength = calculateOutputBufferLength(length);
                out.ensureWritable(outputOffset + 4 + maxOutputLength);
                final int outputPtr = outputOffset + 4;
                final int compressedLength = compress(in, in.readerIndex(), length, out, outputPtr, level);

                if (compressedLength < length) {
                    blockType = BLOCK_TYPE_COMPRESSED;
                    chunkLength = compressedLength;

                    out.setShort(outputOffset, chunkLength);
                    outputOffset += 2;
                } else {
                    // Compression did not help: fall back to storing verbatim.
                    blockType = BLOCK_TYPE_NON_COMPRESSED;
                    out.setBytes(outputOffset + 2, in, idx, length);
                    chunkLength = length;
                }
            }
            out.setShort(outputOffset, length);

            out.setByte(outputIdx + OPTIONS_OFFSET,
                    blockType | (checksum != null ? BLOCK_WITH_CHECKSUM : BLOCK_WITHOUT_CHECKSUM));
            out.writerIndex(outputOffset + 2 + chunkLength);
            in.skipBytes(length);
        }
    }
}
/*
 * Copyright 2021 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

/**
 * {@link GzipOptions} holds {@link #compressionLevel()},
 * {@link #memLevel()} and {@link #windowBits()} for Gzip compression.
 * This class is an extension of {@link DeflateOptions}
 */
public final class GzipOptions extends DeflateOptions {

    /**
     * @see StandardCompressionOptions#gzip()
     */
    static final GzipOptions DEFAULT = new GzipOptions(
            6, 15, 8
    );

    /**
     * @see StandardCompressionOptions#gzip(int, int, int)
     */
    GzipOptions(int compressionLevel, int windowBits, int memLevel) {
        super(compressionLevel, windowBits, memLevel);
    }
}
// ---------------------------------------------------------------------------
// JZlibDecoder.java
// ---------------------------------------------------------------------------
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.zlib.Inflater;
import io.netty.zlib.JZlib;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.internal.ObjectUtil;

import java.util.List;

/**
 * Decompresses a {@link ByteBuf} using the inflate algorithm,
 * backed by the JZlib {@link Inflater}.
 */
public class JZlibDecoder extends ZlibDecoder {

    private final Inflater z = new Inflater();
    private byte[] dictionary;
    private volatile boolean finished;

    /**
     * Creates a new instance with the default wrapper ({@link ZlibWrapper#ZLIB}).
     *
     * @throws DecompressionException if failed to initialize zlib
     */
    public JZlibDecoder() {
        this(ZlibWrapper.ZLIB, 0);
    }

    /**
     * Creates a new instance with the default wrapper ({@link ZlibWrapper#ZLIB})
     * and specified maximum buffer allocation.
     *
     * @param maxAllocation
     *        Maximum size of the decompression buffer. Must be &gt;= 0.
     *        If zero, maximum size is decided by the {@link ByteBufAllocator}.
     *
     * @throws DecompressionException if failed to initialize zlib
     */
    public JZlibDecoder(int maxAllocation) {
        this(ZlibWrapper.ZLIB, maxAllocation);
    }

    /**
     * Creates a new instance with the specified wrapper.
     *
     * @throws DecompressionException if failed to initialize zlib
     */
    public JZlibDecoder(ZlibWrapper wrapper) {
        this(wrapper, 0);
    }

    /**
     * Creates a new instance with the specified wrapper and maximum buffer allocation.
     *
     * @param maxAllocation
     *        Maximum size of the decompression buffer. Must be &gt;= 0.
     *        If zero, maximum size is decided by the {@link ByteBufAllocator}.
     *
     * @throws DecompressionException if failed to initialize zlib
     */
    public JZlibDecoder(ZlibWrapper wrapper, int maxAllocation) {
        super(maxAllocation);

        ObjectUtil.checkNotNull(wrapper, "wrapper");

        int resultCode = z.init(ZlibUtil.convertWrapperType(wrapper));
        if (resultCode != JZlib.Z_OK) {
            ZlibUtil.fail(z, "initialization failure", resultCode);
        }
    }

    /**
     * Creates a new instance with the specified preset dictionary. The wrapper
     * is always {@link ZlibWrapper#ZLIB} because it is the only format that
     * supports the preset dictionary.
     *
     * @throws DecompressionException if failed to initialize zlib
     */
    public JZlibDecoder(byte[] dictionary) {
        this(dictionary, 0);
    }

    /**
     * Creates a new instance with the specified preset dictionary and maximum buffer allocation.
     * The wrapper is always {@link ZlibWrapper#ZLIB} because it is the only format that
     * supports the preset dictionary.
     *
     * @param maxAllocation
     *        Maximum size of the decompression buffer. Must be &gt;= 0.
     *        If zero, maximum size is decided by the {@link ByteBufAllocator}.
     *
     * @throws DecompressionException if failed to initialize zlib
     */
    public JZlibDecoder(byte[] dictionary, int maxAllocation) {
        super(maxAllocation);
        this.dictionary = ObjectUtil.checkNotNull(dictionary, "dictionary");
        int resultCode;
        resultCode = z.inflateInit(JZlib.W_ZLIB);
        if (resultCode != JZlib.Z_OK) {
            ZlibUtil.fail(z, "initialization failure", resultCode);
        }
    }

    /**
     * Returns {@code true} if and only if the end of the compressed stream
     * has been reached.
     */
    @Override
    public boolean isClosed() {
        return finished;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (finished) {
            // Skip data received after finished.
            in.skipBytes(in.readableBytes());
            return;
        }

        final int inputLength = in.readableBytes();
        if (inputLength == 0) {
            return;
        }

        try {
            // Configure input.
            z.avail_in = inputLength;
            if (in.hasArray()) {
                z.next_in = in.array();
                z.next_in_index = in.arrayOffset() + in.readerIndex();
            } else {
                // Direct (or composite) buffer: copy into a temporary array
                // because jzlib only consumes byte[].
                byte[] array = new byte[inputLength];
                in.getBytes(in.readerIndex(), array);
                z.next_in = array;
                z.next_in_index = 0;
            }
            final int oldNextInIndex = z.next_in_index;

            // Configure output.
            ByteBuf decompressed = prepareDecompressBuffer(ctx, null, inputLength << 1);

            try {
                loop: for (;;) {
                    decompressed = prepareDecompressBuffer(ctx, decompressed, z.avail_in << 1);
                    z.avail_out = decompressed.writableBytes();
                    z.next_out = decompressed.array();
                    z.next_out_index = decompressed.arrayOffset() + decompressed.writerIndex();
                    int oldNextOutIndex = z.next_out_index;

                    // Decompress 'in' into 'out'
                    int resultCode = z.inflate(JZlib.Z_SYNC_FLUSH);
                    int outputLength = z.next_out_index - oldNextOutIndex;
                    if (outputLength > 0) {
                        decompressed.writerIndex(decompressed.writerIndex() + outputLength);
                    }

                    switch (resultCode) {
                    case JZlib.Z_NEED_DICT:
                        if (dictionary == null) {
                            ZlibUtil.fail(z, "decompression failure", resultCode);
                        } else {
                            resultCode = z.inflateSetDictionary(dictionary, dictionary.length);
                            if (resultCode != JZlib.Z_OK) {
                                ZlibUtil.fail(z, "failed to set the dictionary", resultCode);
                            }
                        }
                        break;
                    case JZlib.Z_STREAM_END:
                        finished = true; // Do not decode anymore.
                        z.inflateEnd();
                        break loop;
                    case JZlib.Z_OK:
                        break;
                    case JZlib.Z_BUF_ERROR:
                        if (z.avail_in <= 0) {
                            break loop;
                        }
                        break;
                    default:
                        ZlibUtil.fail(z, "decompression failure", resultCode);
                    }
                }
            } finally {
                in.skipBytes(z.next_in_index - oldNextInIndex);
                if (decompressed.isReadable()) {
                    out.add(decompressed);
                } else {
                    decompressed.release();
                }
            }
        } finally {
            // Deference the external references explicitly to tell the VM that
            // the allocated byte arrays are temporary so that the call stack
            // can be utilized.
            // I'm not sure if the modern VMs do this optimization though.
            z.next_in = null;
            z.next_out = null;
        }
    }

    @Override
    protected void decompressionBufferExhausted(ByteBuf buffer) {
        finished = true;
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.zlib.Deflater;
import io.netty.zlib.JZlib;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.PromiseNotifier;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ObjectUtil;

import java.util.concurrent.TimeUnit;

/**
 * Compresses a {@link ByteBuf} using the deflate algorithm.
 */
public class JZlibEncoder extends ZlibEncoder {

    private final int wrapperOverhead;
    private final Deflater z = new Deflater();
    private volatile boolean finished;
    private volatile ChannelHandlerContext ctx;

    /** Safety timeout before force-closing the channel when the footer write stalls. */
    private static final int THREAD_POOL_DELAY_SECONDS = 10;

    /**
     * Creates a new zlib encoder with the default compression level ({@code 6}),
     * default window bits ({@code 15}), default memory level ({@code 8}),
     * and the default wrapper ({@link ZlibWrapper#ZLIB}).
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder() {
        this(6);
    }

    /**
     * Creates a new zlib encoder with the specified {@code compressionLevel},
     * default window bits ({@code 15}), default memory level ({@code 8}),
     * and the default wrapper ({@link ZlibWrapper#ZLIB}).
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(int compressionLevel) {
        this(ZlibWrapper.ZLIB, compressionLevel);
    }

    /**
     * Creates a new zlib encoder with the default compression level ({@code 6}),
     * default window bits ({@code 15}), default memory level ({@code 8}),
     * and the specified wrapper.
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(ZlibWrapper wrapper) {
        this(wrapper, 6);
    }

    /**
     * Creates a new zlib encoder with the specified {@code compressionLevel},
     * default window bits ({@code 15}), default memory level ({@code 8}),
     * and the specified wrapper.
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(ZlibWrapper wrapper, int compressionLevel) {
        this(wrapper, compressionLevel, 15, 8);
    }

    /**
     * Creates a new zlib encoder with the specified {@code compressionLevel},
     * the specified {@code windowBits}, the specified {@code memLevel}, and
     * the specified wrapper.
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     * @param windowBits
     *        The base two logarithm of the size of the history buffer. The
     *        value should be in the range {@code 9} to {@code 15} inclusive.
     *        Larger values result in better compression at the expense of
     *        memory usage. The default value is {@code 15}.
     * @param memLevel
     *        How much memory should be allocated for the internal compression
     *        state. {@code 1} uses minimum memory and {@code 9} uses maximum
     *        memory. Larger values result in better and faster compression
     *        at the expense of memory usage. The default value is {@code 8}
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(ZlibWrapper wrapper, int compressionLevel, int windowBits, int memLevel) {
        ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel");
        ObjectUtil.checkInRange(windowBits, 9, 15, "windowBits");
        ObjectUtil.checkInRange(memLevel, 1, 9, "memLevel");
        ObjectUtil.checkNotNull(wrapper, "wrapper");

        if (wrapper == ZlibWrapper.ZLIB_OR_NONE) {
            throw new IllegalArgumentException(
                    "wrapper '" + ZlibWrapper.ZLIB_OR_NONE + "' is not " +
                    "allowed for compression.");
        }

        int resultCode = z.init(
                compressionLevel, windowBits, memLevel,
                ZlibUtil.convertWrapperType(wrapper));
        if (resultCode != JZlib.Z_OK) {
            ZlibUtil.fail(z, "initialization failure", resultCode);
        }

        wrapperOverhead = ZlibUtil.wrapperOverhead(wrapper);
    }

    /**
     * Creates a new zlib encoder with the default compression level ({@code 6}),
     * default window bits ({@code 15}), default memory level ({@code 8}),
     * and the specified preset dictionary. The wrapper is always
     * {@link ZlibWrapper#ZLIB} because it is the only format that supports
     * the preset dictionary.
     *
     * @param dictionary the preset dictionary
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(byte[] dictionary) {
        this(6, dictionary);
    }

    /**
     * Creates a new zlib encoder with the specified {@code compressionLevel},
     * default window bits ({@code 15}), default memory level ({@code 8}),
     * and the specified preset dictionary. The wrapper is always
     * {@link ZlibWrapper#ZLIB} because it is the only format that supports
     * the preset dictionary.
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     * @param dictionary the preset dictionary
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(int compressionLevel, byte[] dictionary) {
        this(compressionLevel, 15, 8, dictionary);
    }

    /**
     * Creates a new zlib encoder with the specified {@code compressionLevel},
     * the specified {@code windowBits}, the specified {@code memLevel},
     * and the specified preset dictionary. The wrapper is always
     * {@link ZlibWrapper#ZLIB} because it is the only format that supports
     * the preset dictionary.
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     * @param windowBits
     *        The base two logarithm of the size of the history buffer. The
     *        value should be in the range {@code 9} to {@code 15} inclusive.
     *        Larger values result in better compression at the expense of
     *        memory usage. The default value is {@code 15}.
     * @param memLevel
     *        How much memory should be allocated for the internal compression
     *        state. {@code 1} uses minimum memory and {@code 9} uses maximum
     *        memory. Larger values result in better and faster compression
     *        at the expense of memory usage. The default value is {@code 8}
     * @param dictionary the preset dictionary
     *
     * @throws CompressionException if failed to initialize zlib
     */
    public JZlibEncoder(int compressionLevel, int windowBits, int memLevel, byte[] dictionary) {
        ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel");
        ObjectUtil.checkInRange(windowBits, 9, 15, "windowBits");
        ObjectUtil.checkInRange(memLevel, 1, 9, "memLevel");
        ObjectUtil.checkNotNull(dictionary, "dictionary");

        int resultCode;
        resultCode = z.deflateInit(
                compressionLevel, windowBits, memLevel,
                JZlib.W_ZLIB); // Default: ZLIB format
        if (resultCode != JZlib.Z_OK) {
            ZlibUtil.fail(z, "initialization failure", resultCode);
        } else {
            resultCode = z.deflateSetDictionary(dictionary, dictionary.length);
            if (resultCode != JZlib.Z_OK) {
                ZlibUtil.fail(z, "failed to set the dictionary", resultCode);
            }
        }

        wrapperOverhead = ZlibUtil.wrapperOverhead(ZlibWrapper.ZLIB);
    }

    @Override
    public ChannelFuture close() {
        return close(ctx().channel().newPromise());
    }

    @Override
    public ChannelFuture close(final ChannelPromise promise) {
        ChannelHandlerContext ctx = ctx();
        EventExecutor executor = ctx.executor();
        if (executor.inEventLoop()) {
            return finishEncode(ctx, promise);
        } else {
            final ChannelPromise p = ctx.newPromise();
            executor.execute(new Runnable() {
                @Override
                public void run() {
                    ChannelFuture f = finishEncode(ctx(), p);
                    PromiseNotifier.cascade(f, promise);
                }
            });
            return p;
        }
    }

    private ChannelHandlerContext ctx() {
        ChannelHandlerContext ctx = this.ctx;
        if (ctx == null) {
            throw new IllegalStateException("not added to a pipeline");
        }
        return ctx;
    }

    @Override
    public boolean isClosed() {
        return finished;
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception {
        if (finished) {
            // Stream already closed: pass the data through untouched.
            out.writeBytes(in);
            return;
        }

        int inputLength = in.readableBytes();
        if (inputLength == 0) {
            return;
        }

        try {
            // Configure input.
            boolean inHasArray = in.hasArray();
            z.avail_in = inputLength;
            if (inHasArray) {
                z.next_in = in.array();
                z.next_in_index = in.arrayOffset() + in.readerIndex();
            } else {
                // jzlib only consumes byte[]; copy out of a direct buffer.
                byte[] array = new byte[inputLength];
                in.getBytes(in.readerIndex(), array);
                z.next_in = array;
                z.next_in_index = 0;
            }
            int oldNextInIndex = z.next_in_index;

            // Configure output. Worst case per the zlib docs: input + 0.1% + 12,
            // plus the wrapper (zlib/gzip) header and trailer.
            int maxOutputLength = (int) Math.ceil(inputLength * 1.001) + 12 + wrapperOverhead;
            out.ensureWritable(maxOutputLength);
            z.avail_out = maxOutputLength;
            z.next_out = out.array();
            z.next_out_index = out.arrayOffset() + out.writerIndex();
            int oldNextOutIndex = z.next_out_index;

            // Note that Z_PARTIAL_FLUSH has been deprecated.
            int resultCode;
            try {
                resultCode = z.deflate(JZlib.Z_SYNC_FLUSH);
            } finally {
                in.skipBytes(z.next_in_index - oldNextInIndex);
            }

            if (resultCode != JZlib.Z_OK) {
                ZlibUtil.fail(z, "compression failure", resultCode);
            }

            int outputLength = z.next_out_index - oldNextOutIndex;
            if (outputLength > 0) {
                out.writerIndex(out.writerIndex() + outputLength);
            }
        } finally {
            // Deference the external references explicitly to tell the VM that
            // the allocated byte arrays are temporary so that the call stack
            // can be utilized.
            // I'm not sure if the modern VMs do this optimization though.
            z.next_in = null;
            z.next_out = null;
        }
    }

    @Override
    public void close(
            final ChannelHandlerContext ctx,
            final ChannelPromise promise) {
        ChannelFuture f = finishEncode(ctx, ctx.newPromise());

        if (!f.isDone()) {
            // Ensure the channel is closed even if the write operation completes in time.
            final Future<?> future = ctx.executor().schedule(new Runnable() {
                @Override
                public void run() {
                    if (!promise.isDone()) {
                        ctx.close(promise);
                    }
                }
            }, THREAD_POOL_DELAY_SECONDS, TimeUnit.SECONDS);

            f.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture f) {
                    // Cancel the scheduled timeout.
                    future.cancel(true);
                    if (!promise.isDone()) {
                        ctx.close(promise);
                    }
                }
            });
        } else {
            ctx.close(promise);
        }
    }

    private ChannelFuture finishEncode(ChannelHandlerContext ctx, ChannelPromise promise) {
        if (finished) {
            promise.setSuccess();
            return promise;
        }
        finished = true;

        ByteBuf footer;
        try {
            // Configure input.
            z.next_in = EmptyArrays.EMPTY_BYTES;
            z.next_in_index = 0;
            z.avail_in = 0;

            // Configure output.
            byte[] out = new byte[32]; // room for ADLER32 + ZLIB / CRC32 + GZIP header
            z.next_out = out;
            z.next_out_index = 0;
            z.avail_out = out.length;

            // Write the ADLER32 checksum (stream footer).
            int resultCode = z.deflate(JZlib.Z_FINISH);
            if (resultCode != JZlib.Z_OK && resultCode != JZlib.Z_STREAM_END) {
                promise.setFailure(ZlibUtil.deflaterException(z, "compression failure", resultCode));
                return promise;
            } else if (z.next_out_index != 0) {
                // Suppressed a warning above to be on the safe side
                // even if z.next_out_index seems to be always 0 here
                footer = Unpooled.wrappedBuffer(out, 0, z.next_out_index);
            } else {
                footer = Unpooled.EMPTY_BUFFER;
            }
        } finally {
            z.deflateEnd();

            // Deference the external references explicitly to tell the VM that
            // the allocated byte arrays are temporary so that the call stack
            // can be utilized.
            // I'm not sure if the modern VMs do this optimization though.
            z.next_in = null;
            z.next_out = null;
        }
        return ctx.writeAndFlush(footer, promise);
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        this.ctx = ctx;
    }
}
+ */ +public class JdkZlibDecoder extends ZlibDecoder { + private static final int FHCRC = 0x02; + private static final int FEXTRA = 0x04; + private static final int FNAME = 0x08; + private static final int FCOMMENT = 0x10; + private static final int FRESERVED = 0xE0; + + private Inflater inflater; + private final byte[] dictionary; + + // GZIP related + private final ByteBufChecksum crc; + private final boolean decompressConcatenated; + + private enum GzipState { + HEADER_START, + HEADER_END, + FLG_READ, + XLEN_READ, + SKIP_FNAME, + SKIP_COMMENT, + PROCESS_FHCRC, + FOOTER_START, + } + + private GzipState gzipState = GzipState.HEADER_START; + private int flags = -1; + private int xlen = -1; + + private volatile boolean finished; + + private boolean decideZlibOrNone; + + /** + * Creates a new instance with the default wrapper ({@link ZlibWrapper#ZLIB}). + */ + public JdkZlibDecoder() { + this(ZlibWrapper.ZLIB, null, false, 0); + } + + /** + * Creates a new instance with the default wrapper ({@link ZlibWrapper#ZLIB}) + * and the specified maximum buffer allocation. + * + * @param maxAllocation + * Maximum size of the decompression buffer. Must be >= 0. + * If zero, maximum size is decided by the {@link ByteBufAllocator}. + */ + public JdkZlibDecoder(int maxAllocation) { + this(ZlibWrapper.ZLIB, null, false, maxAllocation); + } + + /** + * Creates a new instance with the specified preset dictionary. The wrapper + * is always {@link ZlibWrapper#ZLIB} because it is the only format that + * supports the preset dictionary. + */ + public JdkZlibDecoder(byte[] dictionary) { + this(ZlibWrapper.ZLIB, dictionary, false, 0); + } + + /** + * Creates a new instance with the specified preset dictionary and maximum buffer allocation. + * The wrapper is always {@link ZlibWrapper#ZLIB} because it is the only format that + * supports the preset dictionary. + * + * @param maxAllocation + * Maximum size of the decompression buffer. Must be >= 0. 
+ * If zero, maximum size is decided by the {@link ByteBufAllocator}. + */ + public JdkZlibDecoder(byte[] dictionary, int maxAllocation) { + this(ZlibWrapper.ZLIB, dictionary, false, maxAllocation); + } + + /** + * Creates a new instance with the specified wrapper. + * Be aware that only {@link ZlibWrapper#GZIP}, {@link ZlibWrapper#ZLIB} and {@link ZlibWrapper#NONE} are + * supported atm. + */ + public JdkZlibDecoder(ZlibWrapper wrapper) { + this(wrapper, null, false, 0); + } + + /** + * Creates a new instance with the specified wrapper and maximum buffer allocation. + * Be aware that only {@link ZlibWrapper#GZIP}, {@link ZlibWrapper#ZLIB} and {@link ZlibWrapper#NONE} are + * supported atm. + * + * @param maxAllocation + * Maximum size of the decompression buffer. Must be >= 0. + * If zero, maximum size is decided by the {@link ByteBufAllocator}. + */ + public JdkZlibDecoder(ZlibWrapper wrapper, int maxAllocation) { + this(wrapper, null, false, maxAllocation); + } + + public JdkZlibDecoder(ZlibWrapper wrapper, boolean decompressConcatenated) { + this(wrapper, null, decompressConcatenated, 0); + } + + public JdkZlibDecoder(ZlibWrapper wrapper, boolean decompressConcatenated, int maxAllocation) { + this(wrapper, null, decompressConcatenated, maxAllocation); + } + + public JdkZlibDecoder(boolean decompressConcatenated) { + this(ZlibWrapper.GZIP, null, decompressConcatenated, 0); + } + + public JdkZlibDecoder(boolean decompressConcatenated, int maxAllocation) { + this(ZlibWrapper.GZIP, null, decompressConcatenated, maxAllocation); + } + + private JdkZlibDecoder(ZlibWrapper wrapper, byte[] dictionary, boolean decompressConcatenated, int maxAllocation) { + super(maxAllocation); + + ObjectUtil.checkNotNull(wrapper, "wrapper"); + + this.decompressConcatenated = decompressConcatenated; + switch (wrapper) { + case GZIP: + inflater = new Inflater(true); + crc = ByteBufChecksum.wrapChecksum(new CRC32()); + break; + case NONE: + inflater = new Inflater(true); + crc = null; + 
break; + case ZLIB: + inflater = new Inflater(); + crc = null; + break; + case ZLIB_OR_NONE: + // Postpone the decision until decode(...) is called. + decideZlibOrNone = true; + crc = null; + break; + default: + throw new IllegalArgumentException("Only GZIP or ZLIB is supported, but you used " + wrapper); + } + this.dictionary = dictionary; + } + + @Override + public boolean isClosed() { + return finished; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (finished) { + // Skip data received after finished. + in.skipBytes(in.readableBytes()); + return; + } + + int readableBytes = in.readableBytes(); + if (readableBytes == 0) { + return; + } + + if (decideZlibOrNone) { + // First two bytes are needed to decide if it's a ZLIB stream. + if (readableBytes < 2) { + return; + } + + boolean nowrap = !looksLikeZlib(in.getShort(in.readerIndex())); + inflater = new Inflater(nowrap); + decideZlibOrNone = false; + } + + if (crc != null) { + if (gzipState != GzipState.HEADER_END) { + if (gzipState == GzipState.FOOTER_START) { + if (!handleGzipFooter(in)) { + // Either there was not enough data or the input is finished. + return; + } + // If we consumed the footer we will start with the header again. + assert gzipState == GzipState.HEADER_START; + } + if (!readGZIPHeader(in)) { + // There was not enough data readable to read the GZIP header. + return; + } + // Some bytes may have been consumed, and so we must re-set the number of readable bytes. 
+ readableBytes = in.readableBytes(); + if (readableBytes == 0) { + return; + } + } + } + + if (inflater.needsInput()) { + if (in.hasArray()) { + inflater.setInput(in.array(), in.arrayOffset() + in.readerIndex(), readableBytes); + } else { + byte[] array = new byte[readableBytes]; + in.getBytes(in.readerIndex(), array); + inflater.setInput(array); + } + } + + ByteBuf decompressed = prepareDecompressBuffer(ctx, null, inflater.getRemaining() << 1); + try { + boolean readFooter = false; + while (!inflater.needsInput()) { + byte[] outArray = decompressed.array(); + int writerIndex = decompressed.writerIndex(); + int outIndex = decompressed.arrayOffset() + writerIndex; + int writable = decompressed.writableBytes(); + int outputLength = inflater.inflate(outArray, outIndex, writable); + if (outputLength > 0) { + decompressed.writerIndex(writerIndex + outputLength); + if (crc != null) { + crc.update(outArray, outIndex, outputLength); + } + } else if (inflater.needsDictionary()) { + if (dictionary == null) { + throw new DecompressionException( + "decompression failure, unable to set dictionary as non was specified"); + } + inflater.setDictionary(dictionary); + } + + if (inflater.finished()) { + if (crc == null) { + finished = true; // Do not decode anymore. 
+ } else { + readFooter = true; + } + break; + } else { + decompressed = prepareDecompressBuffer(ctx, decompressed, inflater.getRemaining() << 1); + } + } + + in.skipBytes(readableBytes - inflater.getRemaining()); + + if (readFooter) { + gzipState = GzipState.FOOTER_START; + handleGzipFooter(in); + } + } catch (DataFormatException e) { + throw new DecompressionException("decompression failure", e); + } finally { + if (decompressed.isReadable()) { + out.add(decompressed); + } else { + decompressed.release(); + } + } + } + + private boolean handleGzipFooter(ByteBuf in) { + if (readGZIPFooter(in)) { + finished = !decompressConcatenated; + + if (!finished) { + inflater.reset(); + crc.reset(); + gzipState = GzipState.HEADER_START; + return true; + } + } + return false; + } + + @Override + protected void decompressionBufferExhausted(ByteBuf buffer) { + finished = true; + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved0(ctx); + if (inflater != null) { + inflater.end(); + } + } + + private boolean readGZIPHeader(ByteBuf in) { + switch (gzipState) { + case HEADER_START: + if (in.readableBytes() < 10) { + return false; + } + // read magic numbers + int magic0 = in.readByte(); + int magic1 = in.readByte(); + + if (magic0 != 31) { + throw new DecompressionException("Input is not in the GZIP format"); + } + crc.update(magic0); + crc.update(magic1); + + int method = in.readUnsignedByte(); + if (method != Deflater.DEFLATED) { + throw new DecompressionException("Unsupported compression method " + + method + " in the GZIP header"); + } + crc.update(method); + + flags = in.readUnsignedByte(); + crc.update(flags); + + if ((flags & FRESERVED) != 0) { + throw new DecompressionException( + "Reserved flags are set in the GZIP header"); + } + + // mtime (int) + crc.update(in, in.readerIndex(), 4); + in.skipBytes(4); + + crc.update(in.readUnsignedByte()); // extra flags + crc.update(in.readUnsignedByte()); // operating 
system + + gzipState = GzipState.FLG_READ; + // fall through + case FLG_READ: + if ((flags & FEXTRA) != 0) { + if (in.readableBytes() < 2) { + return false; + } + int xlen1 = in.readUnsignedByte(); + int xlen2 = in.readUnsignedByte(); + crc.update(xlen1); + crc.update(xlen2); + + xlen |= xlen1 << 8 | xlen2; + } + gzipState = GzipState.XLEN_READ; + // fall through + case XLEN_READ: + if (xlen != -1) { + if (in.readableBytes() < xlen) { + return false; + } + crc.update(in, in.readerIndex(), xlen); + in.skipBytes(xlen); + } + gzipState = GzipState.SKIP_FNAME; + // fall through + case SKIP_FNAME: + if (!skipIfNeeded(in, FNAME)) { + return false; + } + gzipState = GzipState.SKIP_COMMENT; + // fall through + case SKIP_COMMENT: + if (!skipIfNeeded(in, FCOMMENT)) { + return false; + } + gzipState = GzipState.PROCESS_FHCRC; + // fall through + case PROCESS_FHCRC: + if ((flags & FHCRC) != 0) { + if (!verifyCrc16(in)) { + return false; + } + } + crc.reset(); + gzipState = GzipState.HEADER_END; + // fall through + case HEADER_END: + return true; + default: + throw new IllegalStateException(); + } + } + + /** + * Skip bytes in the input if needed until we find the end marker {@code 0x00}. + * @param in the input + * @param flagMask the mask that should be present in the {@code flags} when we need to skip bytes. + * @return {@code true} if the operation is complete and we can move to the next state, {@code false} if we need + * the retry again once we have more readable bytes. + */ + private boolean skipIfNeeded(ByteBuf in, int flagMask) { + if ((flags & flagMask) != 0) { + for (;;) { + if (!in.isReadable()) { + // We didnt find the end yet, need to retry again once more data is readable + return false; + } + int b = in.readUnsignedByte(); + crc.update(b); + if (b == 0x00) { + break; + } + } + } + // Skip is handled, we can move to the next processing state. + return true; + } + + /** + * Read the GZIP footer. + * + * @param in the input. 
+ * @return {@code true} if the footer could be read, {@code false} if the read could not be performed as + * the input {@link ByteBuf} doesn't have enough readable bytes (8 bytes). + */ + private boolean readGZIPFooter(ByteBuf in) { + if (in.readableBytes() < 8) { + return false; + } + + boolean enoughData = verifyCrc(in); + assert enoughData; + + // read ISIZE and verify + int dataLength = 0; + for (int i = 0; i < 4; ++i) { + dataLength |= in.readUnsignedByte() << i * 8; + } + int readLength = inflater.getTotalOut(); + if (dataLength != readLength) { + throw new DecompressionException( + "Number of bytes mismatch. Expected: " + dataLength + ", Got: " + readLength); + } + return true; + } + + /** + * Verifies CRC. + * + * @param in the input. + * @return {@code true} if verification could be performed, {@code false} if verification could not be performed as + * the input {@link ByteBuf} doesn't have enough readable bytes (4 bytes). + */ + private boolean verifyCrc(ByteBuf in) { + if (in.readableBytes() < 4) { + return false; + } + long crcValue = 0; + for (int i = 0; i < 4; ++i) { + crcValue |= (long) in.readUnsignedByte() << i * 8; + } + long readCrc = crc.getValue(); + if (crcValue != readCrc) { + throw new DecompressionException( + "CRC value mismatch. Expected: " + crcValue + ", Got: " + readCrc); + } + return true; + } + + private boolean verifyCrc16(ByteBuf in) { + if (in.readableBytes() < 2) { + return false; + } + long readCrc32 = crc.getValue(); + long crc16Value = 0; + long readCrc16 = 0; // the two least significant bytes from the CRC32 + for (int i = 0; i < 2; ++i) { + crc16Value |= (long) in.readUnsignedByte() << (i * 8); + readCrc16 |= ((readCrc32 >> (i * 8)) & 0xff) << (i * 8); + } + + if (crc16Value != readCrc16) { + throw new DecompressionException( + "CRC16 value mismatch. 
Expected: " + crc16Value + ", Got: " + readCrc16); + } + return true; + } + + /* + * Returns true if the cmf_flg parameter (think: first two bytes of a zlib stream) + * indicates that this is a zlib stream. + *

+ * You can lookup the details in the ZLIB RFC: + * RFC 1950. + */ + private static boolean looksLikeZlib(short cmf_flg) { + return (cmf_flg & 0x7800) == 0x7800 && + cmf_flg % 31 == 0; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java new file mode 100644 index 0000000..e43f6d5 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/JdkZlibEncoder.java @@ -0,0 +1,375 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.zip.CRC32; +import java.util.zip.Deflater; + +/** + * Compresses a {@link ByteBuf} using the deflate algorithm. + */ +public class JdkZlibEncoder extends ZlibEncoder { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(JdkZlibEncoder.class); + + /** + * Maximum initial size for temporary heap buffers used for the compressed output. Buffer may still grow beyond + * this if necessary. + */ + private static final int MAX_INITIAL_OUTPUT_BUFFER_SIZE; + /** + * Max size for temporary heap buffers used to copy input data to heap. 
+ */ + private static final int MAX_INPUT_BUFFER_SIZE; + + private final ZlibWrapper wrapper; + private final Deflater deflater; + private volatile boolean finished; + private volatile ChannelHandlerContext ctx; + + /* + * GZIP support + */ + private final CRC32 crc = new CRC32(); + private static final byte[] gzipHeader = {0x1f, (byte) 0x8b, Deflater.DEFLATED, 0, 0, 0, 0, 0, 0, 0}; + private boolean writeHeader = true; + + static { + MAX_INITIAL_OUTPUT_BUFFER_SIZE = SystemPropertyUtil.getInt( + "io.netty.jdkzlib.encoder.maxInitialOutputBufferSize", + 65536); + MAX_INPUT_BUFFER_SIZE = SystemPropertyUtil.getInt( + "io.netty.jdkzlib.encoder.maxInputBufferSize", + 65536); + + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.jdkzlib.encoder.maxInitialOutputBufferSize={}", MAX_INITIAL_OUTPUT_BUFFER_SIZE); + logger.debug("-Dio.netty.jdkzlib.encoder.maxInputBufferSize={}", MAX_INPUT_BUFFER_SIZE); + } + } + + /** + * Creates a new zlib encoder with the default compression level ({@code 6}) + * and the default wrapper ({@link ZlibWrapper#ZLIB}). + * + * @throws CompressionException if failed to initialize zlib + */ + public JdkZlibEncoder() { + this(6); + } + + /** + * Creates a new zlib encoder with the specified {@code compressionLevel} + * and the default wrapper ({@link ZlibWrapper#ZLIB}). + * + * @param compressionLevel + * {@code 1} yields the fastest compression and {@code 9} yields the + * best compression. {@code 0} means no compression. The default + * compression level is {@code 6}. + * + * @throws CompressionException if failed to initialize zlib + */ + public JdkZlibEncoder(int compressionLevel) { + this(ZlibWrapper.ZLIB, compressionLevel); + } + + /** + * Creates a new zlib encoder with the default compression level ({@code 6}) + * and the specified wrapper. 
+ * + * @throws CompressionException if failed to initialize zlib + */ + public JdkZlibEncoder(ZlibWrapper wrapper) { + this(wrapper, 6); + } + + /** + * Creates a new zlib encoder with the specified {@code compressionLevel} + * and the specified wrapper. + * + * @param compressionLevel + * {@code 1} yields the fastest compression and {@code 9} yields the + * best compression. {@code 0} means no compression. The default + * compression level is {@code 6}. + * + * @throws CompressionException if failed to initialize zlib + */ + public JdkZlibEncoder(ZlibWrapper wrapper, int compressionLevel) { + ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel"); + ObjectUtil.checkNotNull(wrapper, "wrapper"); + + if (wrapper == ZlibWrapper.ZLIB_OR_NONE) { + throw new IllegalArgumentException( + "wrapper '" + ZlibWrapper.ZLIB_OR_NONE + "' is not " + + "allowed for compression."); + } + + this.wrapper = wrapper; + deflater = new Deflater(compressionLevel, wrapper != ZlibWrapper.ZLIB); + } + + /** + * Creates a new zlib encoder with the default compression level ({@code 6}) + * and the specified preset dictionary. The wrapper is always + * {@link ZlibWrapper#ZLIB} because it is the only format that supports + * the preset dictionary. + * + * @param dictionary the preset dictionary + * + * @throws CompressionException if failed to initialize zlib + */ + public JdkZlibEncoder(byte[] dictionary) { + this(6, dictionary); + } + + /** + * Creates a new zlib encoder with the specified {@code compressionLevel} + * and the specified preset dictionary. The wrapper is always + * {@link ZlibWrapper#ZLIB} because it is the only format that supports + * the preset dictionary. + * + * @param compressionLevel + * {@code 1} yields the fastest compression and {@code 9} yields the + * best compression. {@code 0} means no compression. The default + * compression level is {@code 6}. 
+ * @param dictionary the preset dictionary + * + * @throws CompressionException if failed to initialize zlib + */ + public JdkZlibEncoder(int compressionLevel, byte[] dictionary) { + ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel"); + ObjectUtil.checkNotNull(dictionary, "dictionary"); + + wrapper = ZlibWrapper.ZLIB; + deflater = new Deflater(compressionLevel); + deflater.setDictionary(dictionary); + } + + @Override + public ChannelFuture close() { + return close(ctx().newPromise()); + } + + @Override + public ChannelFuture close(final ChannelPromise promise) { + ChannelHandlerContext ctx = ctx(); + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + return finishEncode(ctx, promise); + } else { + final ChannelPromise p = ctx.newPromise(); + executor.execute(new Runnable() { + @Override + public void run() { + ChannelFuture f = finishEncode(ctx(), p); + PromiseNotifier.cascade(f, promise); + } + }); + return p; + } + } + + private ChannelHandlerContext ctx() { + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException("not added to a pipeline"); + } + return ctx; + } + + @Override + public boolean isClosed() { + return finished; + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf uncompressed, ByteBuf out) throws Exception { + if (finished) { + out.writeBytes(uncompressed); + return; + } + + int len = uncompressed.readableBytes(); + if (len == 0) { + return; + } + + if (uncompressed.hasArray()) { + // if it is backed by an array we not need to do a copy at all + encodeSome(uncompressed, out); + } else { + int heapBufferSize = Math.min(len, MAX_INPUT_BUFFER_SIZE); + ByteBuf heapBuf = ctx.alloc().heapBuffer(heapBufferSize, heapBufferSize); + try { + while (uncompressed.isReadable()) { + uncompressed.readBytes(heapBuf, Math.min(heapBuf.writableBytes(), uncompressed.readableBytes())); + encodeSome(heapBuf, out); + heapBuf.clear(); + } + } finally { + 
heapBuf.release(); + } + } + // clear input so that we don't keep an unnecessary reference to the input array + deflater.setInput(EmptyArrays.EMPTY_BYTES); + } + + private void encodeSome(ByteBuf in, ByteBuf out) { + // both in and out are heap buffers, here + + byte[] inAry = in.array(); + int offset = in.arrayOffset() + in.readerIndex(); + + if (writeHeader) { + writeHeader = false; + if (wrapper == ZlibWrapper.GZIP) { + out.writeBytes(gzipHeader); + } + } + + int len = in.readableBytes(); + if (wrapper == ZlibWrapper.GZIP) { + crc.update(inAry, offset, len); + } + + deflater.setInput(inAry, offset, len); + for (;;) { + deflate(out); + if (!out.isWritable()) { + // The buffer is not writable anymore. Increase the capacity to make more room. + // Can't rely on needsInput here, it might return true even if there's still data to be written. + out.ensureWritable(out.writerIndex()); + } else if (deflater.needsInput()) { + // Consumed everything + break; + } + } + in.skipBytes(len); + } + + @Override + protected final ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, + boolean preferDirect) throws Exception { + int sizeEstimate = (int) Math.ceil(msg.readableBytes() * 1.001) + 12; + if (writeHeader) { + switch (wrapper) { + case GZIP: + sizeEstimate += gzipHeader.length; + break; + case ZLIB: + sizeEstimate += 2; // first two magic bytes + break; + default: + // no op + } + } + // sizeEstimate might overflow if close to 2G + if (sizeEstimate < 0 || sizeEstimate > MAX_INITIAL_OUTPUT_BUFFER_SIZE) { + // can always expand later + return ctx.alloc().heapBuffer(MAX_INITIAL_OUTPUT_BUFFER_SIZE); + } + return ctx.alloc().heapBuffer(sizeEstimate); + } + + @Override + public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { + ChannelFuture f = finishEncode(ctx, ctx.newPromise()); + EncoderUtil.closeAfterFinishEncode(ctx, f, promise); + } + + private ChannelFuture finishEncode(final ChannelHandlerContext ctx, ChannelPromise 
promise) { + if (finished) { + promise.setSuccess(); + return promise; + } + + finished = true; + ByteBuf footer = ctx.alloc().heapBuffer(); + if (writeHeader && wrapper == ZlibWrapper.GZIP) { + // Write the GZIP header first if not written yet. (i.e. user wrote nothing.) + writeHeader = false; + footer.writeBytes(gzipHeader); + } + + deflater.finish(); + + while (!deflater.finished()) { + deflate(footer); + if (!footer.isWritable()) { + // no more space so write it to the channel and continue + ctx.write(footer); + footer = ctx.alloc().heapBuffer(); + } + } + if (wrapper == ZlibWrapper.GZIP) { + int crcValue = (int) crc.getValue(); + int uncBytes = deflater.getTotalIn(); + footer.writeByte(crcValue); + footer.writeByte(crcValue >>> 8); + footer.writeByte(crcValue >>> 16); + footer.writeByte(crcValue >>> 24); + footer.writeByte(uncBytes); + footer.writeByte(uncBytes >>> 8); + footer.writeByte(uncBytes >>> 16); + footer.writeByte(uncBytes >>> 24); + } + deflater.end(); + return ctx.writeAndFlush(footer, promise); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private void deflate(ByteBuf out) { + if (PlatformDependent.javaVersion() < 7) { + deflateJdk6(out); + } + int numBytes; + do { + int writerIndex = out.writerIndex(); + numBytes = deflater.deflate( + out.array(), out.arrayOffset() + writerIndex, out.writableBytes(), Deflater.SYNC_FLUSH); + out.writerIndex(writerIndex + numBytes); + } while (numBytes > 0); + } + + private void deflateJdk6(ByteBuf out) { + int numBytes; + do { + int writerIndex = out.writerIndex(); + numBytes = deflater.deflate( + out.array(), out.arrayOffset() + writerIndex, out.writableBytes()); + out.writerIndex(writerIndex + numBytes); + } while (numBytes > 0); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4Constants.java 
b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4Constants.java new file mode 100644 index 0000000..3757cbc --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4Constants.java @@ -0,0 +1,73 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +final class Lz4Constants { + /** + * Magic number of LZ4 block. + */ + static final long MAGIC_NUMBER = (long) 'L' << 56 | + (long) 'Z' << 48 | + (long) '4' << 40 | + (long) 'B' << 32 | + 'l' << 24 | + 'o' << 16 | + 'c' << 8 | + 'k'; + + /** + * Full length of LZ4 block header. + */ + static final int HEADER_LENGTH = 8 + // magic number + 1 + // token + 4 + // compressed length + 4 + // decompressed length + 4; // checksum + + /** + * Offsets of header's parts. + */ + static final int TOKEN_OFFSET = 8; + + static final int COMPRESSED_LENGTH_OFFSET = TOKEN_OFFSET + 1; + static final int DECOMPRESSED_LENGTH_OFFSET = COMPRESSED_LENGTH_OFFSET + 4; + static final int CHECKSUM_OFFSET = DECOMPRESSED_LENGTH_OFFSET + 4; + + /** + * Base value for compression level. + */ + static final int COMPRESSION_LEVEL_BASE = 10; + + /** + * LZ4 block sizes. 
+ */ + static final int MIN_BLOCK_SIZE = 64; + static final int MAX_BLOCK_SIZE = 1 << COMPRESSION_LEVEL_BASE + 0x0F; // 32 M + static final int DEFAULT_BLOCK_SIZE = 1 << 16; // 64 KB + + /** + * LZ4 block types. + */ + static final int BLOCK_TYPE_NON_COMPRESSED = 0x10; + static final int BLOCK_TYPE_COMPRESSED = 0x20; + + /** + * Default seed value for xxhash. + */ + static final int DEFAULT_SEED = 0x9747b28c; + + private Lz4Constants() { } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java new file mode 100644 index 0000000..12a3ebd --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameDecoder.java @@ -0,0 +1,277 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.util.internal.ObjectUtil; +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; +import net.jpountz.lz4.LZ4FastDecompressor; + +import java.util.List; +import java.util.zip.Checksum; + +import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_COMPRESSED; +import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_NON_COMPRESSED; +import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSION_LEVEL_BASE; +import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED; +import static io.netty.handler.codec.compression.Lz4Constants.HEADER_LENGTH; +import static io.netty.handler.codec.compression.Lz4Constants.MAGIC_NUMBER; +import static io.netty.handler.codec.compression.Lz4Constants.MAX_BLOCK_SIZE; + +/** + * Uncompresses a {@link ByteBuf} encoded with the LZ4 format. + * + * See original LZ4 Github project + * and LZ4 block format + * for full description. + * + * Since the original LZ4 block format does not contains size of compressed block and size of original data + * this encoder uses format like LZ4 Java library + * written by Adrien Grand and approved by Yann Collet (author of original LZ4 library). + * + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * + * * Magic * Token * Compressed * Decompressed * Checksum * + * LZ4 compressed * + * * * * length * length * * * block * + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * + */ +public class Lz4FrameDecoder extends ByteToMessageDecoder { + /** + * Current state of stream. + */ + private enum State { + INIT_BLOCK, + DECOMPRESS_DATA, + FINISHED, + CORRUPTED + } + + private State currentState = State.INIT_BLOCK; + + /** + * Underlying decompressor in use. 
+ */ + private LZ4FastDecompressor decompressor; + + /** + * Underlying checksum calculator in use. + */ + private ByteBufChecksum checksum; + + /** + * Type of current block. + */ + private int blockType; + + /** + * Compressed length of current incoming block. + */ + private int compressedLength; + + /** + * Decompressed length of current incoming block. + */ + private int decompressedLength; + + /** + * Checksum value of current incoming block. + */ + private int currentChecksum; + + /** + * Creates the fastest LZ4 decoder. + * + * Note that by default, validation of the checksum header in each chunk is + * DISABLED for performance improvements. If performance is less of an issue, + * or if you would prefer the safety that checksum validation brings, please + * use the {@link #Lz4FrameDecoder(boolean)} constructor with the argument + * set to {@code true}. + */ + public Lz4FrameDecoder() { + this(false); + } + + /** + * Creates a LZ4 decoder with fastest decoder instance available on your machine. + * + * @param validateChecksums if {@code true}, the checksum field will be validated against the actual + * uncompressed data, and if the checksums do not match, a suitable + * {@link DecompressionException} will be thrown + */ + public Lz4FrameDecoder(boolean validateChecksums) { + this(LZ4Factory.fastestInstance(), validateChecksums); + } + + /** + * Creates a new LZ4 decoder with customizable implementation. + * + * @param factory user customizable {@link LZ4Factory} instance + * which may be JNI bindings to the original C implementation, a pure Java implementation + * or a Java implementation that uses the sun.misc.Unsafe + * @param validateChecksums if {@code true}, the checksum field will be validated against the actual + * uncompressed data, and if the checksums do not match, a suitable + * {@link DecompressionException} will be thrown. In this case encoder will use + * xxhash hashing for Java, based on Yann Collet's work available at + * Github. 
+ */ + public Lz4FrameDecoder(LZ4Factory factory, boolean validateChecksums) { + this(factory, validateChecksums ? new Lz4XXHash32(DEFAULT_SEED) : null); + } + + /** + * Creates a new customizable LZ4 decoder. + * + * @param factory user customizable {@link LZ4Factory} instance + * which may be JNI bindings to the original C implementation, a pure Java implementation + * or a Java implementation that uses the sun.misc.Unsafe + * @param checksum the {@link Checksum} instance to use to check data for integrity. + * You may set {@code null} if you do not want to validate checksum of each block + */ + public Lz4FrameDecoder(LZ4Factory factory, Checksum checksum) { + decompressor = ObjectUtil.checkNotNull(factory, "factory").fastDecompressor(); + this.checksum = checksum == null ? null : ByteBufChecksum.wrapChecksum(checksum); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + try { + switch (currentState) { + case INIT_BLOCK: + if (in.readableBytes() < HEADER_LENGTH) { + break; + } + final long magic = in.readLong(); + if (magic != MAGIC_NUMBER) { + throw new DecompressionException("unexpected block identifier"); + } + + final int token = in.readByte(); + final int compressionLevel = (token & 0x0F) + COMPRESSION_LEVEL_BASE; + int blockType = token & 0xF0; + + int compressedLength = Integer.reverseBytes(in.readInt()); + if (compressedLength < 0 || compressedLength > MAX_BLOCK_SIZE) { + throw new DecompressionException(String.format( + "invalid compressedLength: %d (expected: 0-%d)", + compressedLength, MAX_BLOCK_SIZE)); + } + + int decompressedLength = Integer.reverseBytes(in.readInt()); + final int maxDecompressedLength = 1 << compressionLevel; + if (decompressedLength < 0 || decompressedLength > maxDecompressedLength) { + throw new DecompressionException(String.format( + "invalid decompressedLength: %d (expected: 0-%d)", + decompressedLength, maxDecompressedLength)); + } + if (decompressedLength == 0 && 
compressedLength != 0 + || decompressedLength != 0 && compressedLength == 0 + || blockType == BLOCK_TYPE_NON_COMPRESSED && decompressedLength != compressedLength) { + throw new DecompressionException(String.format( + "stream corrupted: compressedLength(%d) and decompressedLength(%d) mismatch", + compressedLength, decompressedLength)); + } + + int currentChecksum = Integer.reverseBytes(in.readInt()); + if (decompressedLength == 0 && compressedLength == 0) { + if (currentChecksum != 0) { + throw new DecompressionException("stream corrupted: checksum error"); + } + currentState = State.FINISHED; + decompressor = null; + checksum = null; + break; + } + + this.blockType = blockType; + this.compressedLength = compressedLength; + this.decompressedLength = decompressedLength; + this.currentChecksum = currentChecksum; + + currentState = State.DECOMPRESS_DATA; + // fall through + case DECOMPRESS_DATA: + blockType = this.blockType; + compressedLength = this.compressedLength; + decompressedLength = this.decompressedLength; + currentChecksum = this.currentChecksum; + + if (in.readableBytes() < compressedLength) { + break; + } + + final ByteBufChecksum checksum = this.checksum; + ByteBuf uncompressed = null; + + try { + switch (blockType) { + case BLOCK_TYPE_NON_COMPRESSED: + // Just pass through, we not update the readerIndex yet as we do this outside of the + // switch statement. + uncompressed = in.retainedSlice(in.readerIndex(), decompressedLength); + break; + case BLOCK_TYPE_COMPRESSED: + uncompressed = ctx.alloc().buffer(decompressedLength, decompressedLength); + + decompressor.decompress(CompressionUtil.safeReadableNioBuffer(in), + uncompressed.internalNioBuffer(uncompressed.writerIndex(), decompressedLength)); + // Update the writerIndex now to reflect what we decompressed. 
+ uncompressed.writerIndex(uncompressed.writerIndex() + decompressedLength); + break; + default: + throw new DecompressionException(String.format( + "unexpected blockType: %d (expected: %d or %d)", + blockType, BLOCK_TYPE_NON_COMPRESSED, BLOCK_TYPE_COMPRESSED)); + } + // Skip inbound bytes after we processed them. + in.skipBytes(compressedLength); + + if (checksum != null) { + CompressionUtil.checkChecksum(checksum, uncompressed, currentChecksum); + } + out.add(uncompressed); + uncompressed = null; + currentState = State.INIT_BLOCK; + } catch (LZ4Exception e) { + throw new DecompressionException(e); + } finally { + if (uncompressed != null) { + uncompressed.release(); + } + } + break; + case FINISHED: + case CORRUPTED: + in.skipBytes(in.readableBytes()); + break; + default: + throw new IllegalStateException(); + } + } catch (Exception e) { + currentState = State.CORRUPTED; + throw e; + } + } + + /** + * Returns {@code true} if and only if the end of the compressed stream + * has been reached. + */ + public boolean isClosed() { + return currentState == State.FINISHED; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java new file mode 100644 index 0000000..792fdf4 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4FrameEncoder.java @@ -0,0 +1,402 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.internal.ObjectUtil; +import net.jpountz.lz4.LZ4Compressor; +import net.jpountz.lz4.LZ4Exception; +import net.jpountz.lz4.LZ4Factory; + +import java.nio.ByteBuffer; +import java.util.zip.Checksum; + +import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_COMPRESSED; +import static io.netty.handler.codec.compression.Lz4Constants.BLOCK_TYPE_NON_COMPRESSED; +import static io.netty.handler.codec.compression.Lz4Constants.CHECKSUM_OFFSET; +import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSED_LENGTH_OFFSET; +import static io.netty.handler.codec.compression.Lz4Constants.COMPRESSION_LEVEL_BASE; +import static io.netty.handler.codec.compression.Lz4Constants.DECOMPRESSED_LENGTH_OFFSET; +import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_BLOCK_SIZE; +import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED; +import static io.netty.handler.codec.compression.Lz4Constants.HEADER_LENGTH; +import static io.netty.handler.codec.compression.Lz4Constants.MAGIC_NUMBER; +import static 
io.netty.handler.codec.compression.Lz4Constants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.Lz4Constants.MIN_BLOCK_SIZE; +import static io.netty.handler.codec.compression.Lz4Constants.TOKEN_OFFSET; + +/** + * Compresses a {@link ByteBuf} using the LZ4 format. + * + * See original LZ4 Github project + * and LZ4 block format + * for full description. + * + * Since the original LZ4 block format does not contains size of compressed block and size of original data + * this encoder uses format like LZ4 Java library + * written by Adrien Grand and approved by Yann Collet (author of original LZ4 library). + * + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * + * * Magic * Token * Compressed * Decompressed * Checksum * + * LZ4 compressed * + * * * * length * length * * * block * + * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * * + */ +public class Lz4FrameEncoder extends MessageToByteEncoder { + static final int DEFAULT_MAX_ENCODE_SIZE = Integer.MAX_VALUE; + + private final int blockSize; + + /** + * Underlying compressor in use. + */ + private final LZ4Compressor compressor; + + /** + * Underlying checksum calculator in use. + */ + private final ByteBufChecksum checksum; + + /** + * Compression level of current LZ4 encoder (depends on {@link #blockSize}). + */ + private final int compressionLevel; + + /** + * Inner byte buffer for outgoing data. It's capacity will be {@link #blockSize}. + */ + private ByteBuf buffer; + + /** + * Maximum size for any buffer to write encoded (compressed) data into. + */ + private final int maxEncodeSize; + + /** + * Indicates if the compressed stream has been finished. + */ + private volatile boolean finished; + + /** + * Used to interact with its {@link ChannelPipeline} and other handlers. 
+ */ + private volatile ChannelHandlerContext ctx; + + /** + * Creates the fastest LZ4 encoder with default block size (64 KB) + * and xxhash hashing for Java, based on Yann Collet's work available at + * Github. + */ + public Lz4FrameEncoder() { + this(false); + } + + /** + * Creates a new LZ4 encoder with hight or fast compression, default block size (64 KB) + * and xxhash hashing for Java, based on Yann Collet's work available at + * Github. + * + * @param highCompressor if {@code true} codec will use compressor which requires more memory + * and is slower but compresses more efficiently + */ + public Lz4FrameEncoder(boolean highCompressor) { + this(LZ4Factory.fastestInstance(), highCompressor, DEFAULT_BLOCK_SIZE, new Lz4XXHash32(DEFAULT_SEED)); + } + + /** + * Creates a new customizable LZ4 encoder. + * + * @param factory user customizable {@link LZ4Factory} instance + * which may be JNI bindings to the original C implementation, a pure Java implementation + * or a Java implementation that uses the sun.misc.Unsafe + * @param highCompressor if {@code true} codec will use compressor which requires more memory + * and is slower but compresses more efficiently + * @param blockSize the maximum number of bytes to try to compress at once, + * must be >= 64 and <= 32 M + * @param checksum the {@link Checksum} instance to use to check data for integrity + */ + public Lz4FrameEncoder(LZ4Factory factory, boolean highCompressor, int blockSize, Checksum checksum) { + this(factory, highCompressor, blockSize, checksum, DEFAULT_MAX_ENCODE_SIZE); + } + + /** + * Creates a new customizable LZ4 encoder. 
+ * + * @param factory user customizable {@link LZ4Factory} instance + * which may be JNI bindings to the original C implementation, a pure Java implementation + * or a Java implementation that uses the sun.misc.Unsafe + * @param highCompressor if {@code true} codec will use compressor which requires more memory + * and is slower but compresses more efficiently + * @param blockSize the maximum number of bytes to try to compress at once, + * must be >= 64 and <= 32 M + * @param checksum the {@link Checksum} instance to use to check data for integrity + * @param maxEncodeSize the maximum size for an encode (compressed) buffer + */ + public Lz4FrameEncoder(LZ4Factory factory, boolean highCompressor, int blockSize, + Checksum checksum, int maxEncodeSize) { + ObjectUtil.checkNotNull(factory, "factory"); + ObjectUtil.checkNotNull(checksum, "checksum"); + + compressor = highCompressor ? factory.highCompressor() : factory.fastCompressor(); + this.checksum = ByteBufChecksum.wrapChecksum(checksum); + + compressionLevel = compressionLevel(blockSize); + this.blockSize = blockSize; + this.maxEncodeSize = ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize"); + finished = false; + } + + /** + * Calculates compression level on the basis of block size. 
+ */ + private static int compressionLevel(int blockSize) { + if (blockSize < MIN_BLOCK_SIZE || blockSize > MAX_BLOCK_SIZE) { + throw new IllegalArgumentException(String.format( + "blockSize: %d (expected: %d-%d)", blockSize, MIN_BLOCK_SIZE, MAX_BLOCK_SIZE)); + } + int compressionLevel = 32 - Integer.numberOfLeadingZeros(blockSize - 1); // ceil of log2 + compressionLevel = Math.max(0, compressionLevel - COMPRESSION_LEVEL_BASE); + return compressionLevel; + } + + @Override + protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect) { + return allocateBuffer(ctx, msg, preferDirect, true); + } + + private ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect, + boolean allowEmptyReturn) { + int targetBufSize = 0; + int remaining = msg.readableBytes() + buffer.readableBytes(); + + // quick overflow check + if (remaining < 0) { + throw new EncoderException("too much data to allocate a buffer for compression"); + } + + while (remaining > 0) { + int curSize = Math.min(blockSize, remaining); + remaining -= curSize; + // calculate the total compressed size of the current block (including header) and add to the total + targetBufSize += compressor.maxCompressedLength(curSize) + HEADER_LENGTH; + } + + // in addition to just the raw byte count, the headers (HEADER_LENGTH) per block (configured via + // #blockSize) will also add to the targetBufSize, and the combination of those would never wrap around + // again to be >= 0, this is a good check for the overflow case. 
+ if (targetBufSize > maxEncodeSize || 0 > targetBufSize) { + throw new EncoderException(String.format("requested encode buffer size (%d bytes) exceeds the maximum " + + "allowable size (%d bytes)", targetBufSize, maxEncodeSize)); + } + + if (allowEmptyReturn && targetBufSize < blockSize) { + return Unpooled.EMPTY_BUFFER; + } + + if (preferDirect) { + return ctx.alloc().ioBuffer(targetBufSize, targetBufSize); + } else { + return ctx.alloc().heapBuffer(targetBufSize, targetBufSize); + } + } + + /** + * {@inheritDoc} + * + * Encodes the input buffer into {@link #blockSize} chunks in the output buffer. Data is only compressed and + * written once we hit the {@link #blockSize}; else, it is copied into the backing {@link #buffer} to await + * more data. + */ + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { + if (finished) { + if (!out.isWritable(in.readableBytes())) { + // out should be EMPTY_BUFFER because we should have allocated enough space above in allocateBuffer. 
+ throw new IllegalStateException("encode finished and not enough space to write remaining data"); + } + out.writeBytes(in); + return; + } + + final ByteBuf buffer = this.buffer; + int length; + while ((length = in.readableBytes()) > 0) { + final int nextChunkSize = Math.min(length, buffer.writableBytes()); + in.readBytes(buffer, nextChunkSize); + + if (!buffer.isWritable()) { + flushBufferedData(out); + } + } + } + + private void flushBufferedData(ByteBuf out) { + int flushableBytes = buffer.readableBytes(); + if (flushableBytes == 0) { + return; + } + checksum.reset(); + checksum.update(buffer, buffer.readerIndex(), flushableBytes); + final int check = (int) checksum.getValue(); + + final int bufSize = compressor.maxCompressedLength(flushableBytes) + HEADER_LENGTH; + out.ensureWritable(bufSize); + final int idx = out.writerIndex(); + int compressedLength; + try { + ByteBuffer outNioBuffer = out.internalNioBuffer(idx + HEADER_LENGTH, out.writableBytes() - HEADER_LENGTH); + int pos = outNioBuffer.position(); + // We always want to start at position 0 as we take care of reusing the buffer in the encode(...) loop. 
+ compressor.compress(buffer.internalNioBuffer(buffer.readerIndex(), flushableBytes), outNioBuffer); + compressedLength = outNioBuffer.position() - pos; + } catch (LZ4Exception e) { + throw new CompressionException(e); + } + final int blockType; + if (compressedLength >= flushableBytes) { + blockType = BLOCK_TYPE_NON_COMPRESSED; + compressedLength = flushableBytes; + out.setBytes(idx + HEADER_LENGTH, buffer, buffer.readerIndex(), flushableBytes); + } else { + blockType = BLOCK_TYPE_COMPRESSED; + } + + out.setLong(idx, MAGIC_NUMBER); + out.setByte(idx + TOKEN_OFFSET, (byte) (blockType | compressionLevel)); + out.setIntLE(idx + COMPRESSED_LENGTH_OFFSET, compressedLength); + out.setIntLE(idx + DECOMPRESSED_LENGTH_OFFSET, flushableBytes); + out.setIntLE(idx + CHECKSUM_OFFSET, check); + out.writerIndex(idx + HEADER_LENGTH + compressedLength); + buffer.clear(); + } + + @Override + public void flush(final ChannelHandlerContext ctx) throws Exception { + if (buffer != null && buffer.isReadable()) { + final ByteBuf buf = allocateBuffer(ctx, Unpooled.EMPTY_BUFFER, isPreferDirect(), false); + flushBufferedData(buf); + ctx.write(buf); + } + ctx.flush(); + } + + private ChannelFuture finishEncode(final ChannelHandlerContext ctx, ChannelPromise promise) { + if (finished) { + promise.setSuccess(); + return promise; + } + finished = true; + + final ByteBuf footer = ctx.alloc().heapBuffer( + compressor.maxCompressedLength(buffer.readableBytes()) + HEADER_LENGTH); + flushBufferedData(footer); + + footer.ensureWritable(HEADER_LENGTH); + final int idx = footer.writerIndex(); + footer.setLong(idx, MAGIC_NUMBER); + footer.setByte(idx + TOKEN_OFFSET, (byte) (BLOCK_TYPE_NON_COMPRESSED | compressionLevel)); + footer.setInt(idx + COMPRESSED_LENGTH_OFFSET, 0); + footer.setInt(idx + DECOMPRESSED_LENGTH_OFFSET, 0); + footer.setInt(idx + CHECKSUM_OFFSET, 0); + + footer.writerIndex(idx + HEADER_LENGTH); + + return ctx.writeAndFlush(footer, promise); + } + + /** + * Returns {@code true} if and 
only if the compressed stream has been finished. + */ + public boolean isClosed() { + return finished; + } + + /** + * Close this {@link Lz4FrameEncoder} and so finish the encoding. + * + * The returned {@link ChannelFuture} will be notified once the operation completes. + */ + public ChannelFuture close() { + return close(ctx().newPromise()); + } + + /** + * Close this {@link Lz4FrameEncoder} and so finish the encoding. + * The given {@link ChannelFuture} will be notified once the operation + * completes and will also be returned. + */ + public ChannelFuture close(final ChannelPromise promise) { + ChannelHandlerContext ctx = ctx(); + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + return finishEncode(ctx, promise); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + ChannelFuture f = finishEncode(ctx(), promise); + PromiseNotifier.cascade(f, promise); + } + }); + return promise; + } + } + + @Override + public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { + ChannelFuture f = finishEncode(ctx, ctx.newPromise()); + + EncoderUtil.closeAfterFinishEncode(ctx, f, promise); + } + + private ChannelHandlerContext ctx() { + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException("not added to a pipeline"); + } + return ctx; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + this.ctx = ctx; + // Ensure we use a heap based ByteBuf. 
+ buffer = Unpooled.wrappedBuffer(new byte[blockSize]); + buffer.clear(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + if (buffer != null) { + buffer.release(); + buffer = null; + } + } + + final ByteBuf getBackingBuffer() { + return buffer; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4XXHash32.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4XXHash32.java new file mode 100644 index 0000000..8b92ecf --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Lz4XXHash32.java @@ -0,0 +1,107 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import net.jpountz.xxhash.StreamingXXHash32; +import net.jpountz.xxhash.XXHash32; +import net.jpountz.xxhash.XXHashFactory; + +import java.nio.ByteBuffer; +import java.util.zip.Checksum; + +/** + * A special-purpose {@link ByteBufChecksum} implementation for use with + * {@link Lz4FrameEncoder} and {@link Lz4FrameDecoder}. + * + * {@link StreamingXXHash32#asChecksum()} has a particularly nasty implementation + * of {@link Checksum#update(int)} that allocates a single-element byte array for + * every invocation. 
+ * + * In addition to that, it doesn't implement an overload that accepts a {@link ByteBuffer} + * as an argument. + * + * Combined, this means that we can't use {@code ReflectiveByteBufChecksum} at all, + * and can't use {@code SlowByteBufChecksum} because of its atrocious performance + * with direct byte buffers (allocating an array and making a JNI call for every byte + * checksummed might be considered sub-optimal by some). + * + * Block version of xxHash32 ({@link XXHash32}), however, does provide + * {@link XXHash32#hash(ByteBuffer, int)} method that is efficient and does exactly + * what we need, with a caveat that we can only invoke it once before having to reset. + * This, however, is fine for our purposes, given the way we use it in + * {@link Lz4FrameEncoder} and {@link Lz4FrameDecoder}: + * {@code reset()}, followed by one {@code update()}, followed by {@code getValue()}. + */ +public final class Lz4XXHash32 extends ByteBufChecksum { + + private static final XXHash32 XXHASH32 = XXHashFactory.fastestInstance().hash32(); + + private final int seed; + private boolean used; + private int value; + + @SuppressWarnings("WeakerAccess") + public Lz4XXHash32(int seed) { + this.seed = seed; + } + + @Override + public void update(int b) { + throw new UnsupportedOperationException(); + } + + @Override + public void update(byte[] b, int off, int len) { + if (used) { + throw new IllegalStateException(); + } + value = XXHASH32.hash(b, off, len, seed); + used = true; + } + + @Override + public void update(ByteBuf b, int off, int len) { + if (used) { + throw new IllegalStateException(); + } + if (b.hasArray()) { + value = XXHASH32.hash(b.array(), b.arrayOffset() + off, len, seed); + } else { + value = XXHASH32.hash(CompressionUtil.safeNioBuffer(b, off, len), seed); + } + used = true; + } + + @Override + public long getValue() { + if (!used) { + throw new IllegalStateException(); + } + /* + * If you look carefully, you'll notice that the most significant nibble + * is 
being discarded; we believe this to be a bug, but this is what + * StreamingXXHash32#asChecksum() implementation of getValue() does, + * so we have to retain this behaviour for compatibility reasons. + */ + return value & 0xFFFFFFFL; + } + + @Override + public void reset() { + used = false; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java new file mode 100644 index 0000000..329e966 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfDecoder.java @@ -0,0 +1,242 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import com.ning.compress.BufferRecycler; +import com.ning.compress.lzf.ChunkDecoder; +import com.ning.compress.lzf.LZFChunk; +import com.ning.compress.lzf.util.ChunkDecoderFactory; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; + +import java.util.List; + +import static com.ning.compress.lzf.LZFChunk.BLOCK_TYPE_COMPRESSED; +import static com.ning.compress.lzf.LZFChunk.BLOCK_TYPE_NON_COMPRESSED; +import static com.ning.compress.lzf.LZFChunk.BYTE_V; +import static com.ning.compress.lzf.LZFChunk.BYTE_Z; +import static com.ning.compress.lzf.LZFChunk.HEADER_LEN_NOT_COMPRESSED; + +/** + * Uncompresses a {@link ByteBuf} encoded with the LZF format. + * + * See original LZF package + * and LZF format for full description. + */ +public class LzfDecoder extends ByteToMessageDecoder { + /** + * Current state of decompression. + */ + private enum State { + INIT_BLOCK, + INIT_ORIGINAL_LENGTH, + DECOMPRESS_DATA, + CORRUPTED + } + + private State currentState = State.INIT_BLOCK; + + /** + * Magic number of LZF chunk. + */ + private static final short MAGIC_NUMBER = BYTE_Z << 8 | BYTE_V; + + /** + * Underlying decoder in use. + */ + private ChunkDecoder decoder; + + /** + * Object that handles details of buffer recycling. + */ + private BufferRecycler recycler; + + /** + * Length of current received chunk of data. + */ + private int chunkLength; + + /** + * Original length of current received chunk of data. + * It is equal to {@link #chunkLength} for non compressed chunks. + */ + private int originalLength; + + /** + * Indicates is this chunk compressed or not. + */ + private boolean isCompressed; + + /** + * Creates a new LZF decoder with the most optimal available methods for underlying data access. + * It will "unsafe" instance if one can be used on current JVM. 
+ * It should be safe to call this constructor as implementations are dynamically loaded; however, on some + * non-standard platforms it may be necessary to use {@link #LzfDecoder(boolean)} with {@code true} param. + */ + public LzfDecoder() { + this(false); + } + + /** + * Creates a new LZF decoder with specified decoding instance. + * + * @param safeInstance + * If {@code true} decoder will use {@link ChunkDecoder} that only uses standard JDK access methods, + * and should work on all Java platforms and JVMs. + * Otherwise decoder will try to use highly optimized {@link ChunkDecoder} implementation that uses + * Sun JDK's sun.misc.Unsafe class (which may be included by other JDK's as well). + */ + public LzfDecoder(boolean safeInstance) { + decoder = safeInstance ? + ChunkDecoderFactory.safeInstance() + : ChunkDecoderFactory.optimalInstance(); + + recycler = BufferRecycler.instance(); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + try { + switch (currentState) { + case INIT_BLOCK: + if (in.readableBytes() < HEADER_LEN_NOT_COMPRESSED) { + break; + } + final int magic = in.readUnsignedShort(); + if (magic != MAGIC_NUMBER) { + throw new DecompressionException("unexpected block identifier"); + } + + final int type = in.readByte(); + switch (type) { + case BLOCK_TYPE_NON_COMPRESSED: + isCompressed = false; + currentState = State.DECOMPRESS_DATA; + break; + case BLOCK_TYPE_COMPRESSED: + isCompressed = true; + currentState = State.INIT_ORIGINAL_LENGTH; + break; + default: + throw new DecompressionException(String.format( + "unknown type of chunk: %d (expected: %d or %d)", + type, BLOCK_TYPE_NON_COMPRESSED, BLOCK_TYPE_COMPRESSED)); + } + chunkLength = in.readUnsignedShort(); + + // chunkLength can never exceed MAX_CHUNK_LEN as MAX_CHUNK_LEN is 64kb and readUnsignedShort can + // never return anything bigger as well. 
Let's add some check any way to make things easier in terms + // of debugging if we ever hit this because of an bug. + if (chunkLength > LZFChunk.MAX_CHUNK_LEN) { + throw new DecompressionException(String.format( + "chunk length exceeds maximum: %d (expected: =< %d)", + chunkLength, LZFChunk.MAX_CHUNK_LEN)); + } + + if (type != BLOCK_TYPE_COMPRESSED) { + break; + } + // fall through + case INIT_ORIGINAL_LENGTH: + if (in.readableBytes() < 2) { + break; + } + originalLength = in.readUnsignedShort(); + + // originalLength can never exceed MAX_CHUNK_LEN as MAX_CHUNK_LEN is 64kb and readUnsignedShort can + // never return anything bigger as well. Let's add some check any way to make things easier in terms + // of debugging if we ever hit this because of an bug. + if (originalLength > LZFChunk.MAX_CHUNK_LEN) { + throw new DecompressionException(String.format( + "original length exceeds maximum: %d (expected: =< %d)", + chunkLength, LZFChunk.MAX_CHUNK_LEN)); + } + + currentState = State.DECOMPRESS_DATA; + // fall through + case DECOMPRESS_DATA: + final int chunkLength = this.chunkLength; + if (in.readableBytes() < chunkLength) { + break; + } + final int originalLength = this.originalLength; + + if (isCompressed) { + final int idx = in.readerIndex(); + + final byte[] inputArray; + final int inPos; + if (in.hasArray()) { + inputArray = in.array(); + inPos = in.arrayOffset() + idx; + } else { + inputArray = recycler.allocInputBuffer(chunkLength); + in.getBytes(idx, inputArray, 0, chunkLength); + inPos = 0; + } + + ByteBuf uncompressed = ctx.alloc().heapBuffer(originalLength, originalLength); + final byte[] outputArray; + final int outPos; + if (uncompressed.hasArray()) { + outputArray = uncompressed.array(); + outPos = uncompressed.arrayOffset() + uncompressed.writerIndex(); + } else { + outputArray = new byte[originalLength]; + outPos = 0; + } + + boolean success = false; + try { + decoder.decodeChunk(inputArray, inPos, outputArray, outPos, outPos + originalLength); + if 
(uncompressed.hasArray()) { + uncompressed.writerIndex(uncompressed.writerIndex() + originalLength); + } else { + uncompressed.writeBytes(outputArray); + } + out.add(uncompressed); + in.skipBytes(chunkLength); + success = true; + } finally { + if (!success) { + uncompressed.release(); + } + } + + if (!in.hasArray()) { + recycler.releaseInputBuffer(inputArray); + } + } else if (chunkLength > 0) { + out.add(in.readRetainedSlice(chunkLength)); + } + + currentState = State.INIT_BLOCK; + break; + case CORRUPTED: + in.skipBytes(in.readableBytes()); + break; + default: + throw new IllegalStateException(); + } + } catch (Exception e) { + currentState = State.CORRUPTED; + decoder = null; + recycler = null; + throw e; + } + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfEncoder.java new file mode 100644 index 0000000..c18d1a8 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/LzfEncoder.java @@ -0,0 +1,232 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import com.ning.compress.BufferRecycler; +import com.ning.compress.lzf.ChunkEncoder; +import com.ning.compress.lzf.LZFChunk; +import com.ning.compress.lzf.LZFEncoder; +import com.ning.compress.lzf.util.ChunkEncoderFactory; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; + +import static com.ning.compress.lzf.LZFChunk.MAX_CHUNK_LEN; + +/** + * Compresses a {@link ByteBuf} using the LZF format. + *

+ * See original LZF package + * and LZF format for full description. + */ +public class LzfEncoder extends MessageToByteEncoder { + + /** + * Minimum block size ready for compression. Blocks with length + * less than {@link #MIN_BLOCK_TO_COMPRESS} will write as uncompressed. + */ + private static final int MIN_BLOCK_TO_COMPRESS = 16; + + /** + * Compress threshold for LZF format. When the amount of input data is less than compressThreshold, + * we will construct an uncompressed output according to the LZF format. + *

+ * When the value is less than {@see ChunkEncoder#MIN_BLOCK_TO_COMPRESS}, since LZF will not compress data + * that is less than {@see ChunkEncoder#MIN_BLOCK_TO_COMPRESS}, compressThreshold will not work. + */ + private final int compressThreshold; + + /** + * Underlying decoder in use. + */ + private final ChunkEncoder encoder; + + /** + * Object that handles details of buffer recycling. + */ + private final BufferRecycler recycler; + + /** + * Creates a new LZF encoder with the most optimal available methods for underlying data access. + * It will "unsafe" instance if one can be used on current JVM. + * It should be safe to call this constructor as implementations are dynamically loaded; however, on some + * non-standard platforms it may be necessary to use {@link #LzfEncoder(boolean)} with {@code true} param. + */ + public LzfEncoder() { + this(false); + } + + /** + * Creates a new LZF encoder with specified encoding instance. + * + * @param safeInstance If {@code true} encoder will use {@link ChunkEncoder} that only uses + * standard JDK access methods, and should work on all Java platforms and JVMs. + * Otherwise encoder will try to use highly optimized {@link ChunkEncoder} + * implementation that uses Sun JDK's sun.misc.Unsafe + * class (which may be included by other JDK's as well). + */ + public LzfEncoder(boolean safeInstance) { + this(safeInstance, MAX_CHUNK_LEN); + } + + /** + * Creates a new LZF encoder with specified encoding instance and compressThreshold. + * + * @param safeInstance If {@code true} encoder will use {@link ChunkEncoder} that only uses standard + * JDK access methods, and should work on all Java platforms and JVMs. + * Otherwise encoder will try to use highly optimized {@link ChunkEncoder} + * implementation that uses Sun JDK's sun.misc.Unsafe + * class (which may be included by other JDK's as well). 
+ * @param totalLength Expected total length of content to compress; only matters for outgoing messages + * that is smaller than maximum chunk size (64k), to optimize encoding hash tables. + */ + public LzfEncoder(boolean safeInstance, int totalLength) { + this(safeInstance, totalLength, MIN_BLOCK_TO_COMPRESS); + } + + /** + * Creates a new LZF encoder with specified total length of encoded chunk. You can configure it to encode + * your data flow more efficient if you know the average size of messages that you send. + * + * @param totalLength Expected total length of content to compress; + * only matters for outgoing messages that is smaller than maximum chunk size (64k), + * to optimize encoding hash tables. + */ + public LzfEncoder(int totalLength) { + this(false, totalLength); + } + + /** + * Creates a new LZF encoder with specified settings. + * + * @param safeInstance If {@code true} encoder will use {@link ChunkEncoder} that only uses standard JDK + * access methods, and should work on all Java platforms and JVMs. + * Otherwise encoder will try to use highly optimized {@link ChunkEncoder} + * implementation that uses Sun JDK's sun.misc.Unsafe + * class (which may be included by other JDK's as well). + * @param totalLength Expected total length of content to compress; only matters for outgoing messages + * that is smaller than maximum chunk size (64k), to optimize encoding hash tables. + * @param compressThreshold Compress threshold for LZF format. When the amount of input data is less than + * compressThreshold, we will construct an uncompressed output according + * to the LZF format. 
+ */ + public LzfEncoder(boolean safeInstance, int totalLength, int compressThreshold) { + super(false); + if (totalLength < MIN_BLOCK_TO_COMPRESS || totalLength > MAX_CHUNK_LEN) { + throw new IllegalArgumentException("totalLength: " + totalLength + + " (expected: " + MIN_BLOCK_TO_COMPRESS + '-' + MAX_CHUNK_LEN + ')'); + } + + if (compressThreshold < MIN_BLOCK_TO_COMPRESS) { + // not a suitable value. + throw new IllegalArgumentException("compressThreshold:" + compressThreshold + + " expected >=" + MIN_BLOCK_TO_COMPRESS); + } + this.compressThreshold = compressThreshold; + + this.encoder = safeInstance ? + ChunkEncoderFactory.safeNonAllocatingInstance(totalLength) + : ChunkEncoderFactory.optimalNonAllocatingInstance(totalLength); + + this.recycler = BufferRecycler.instance(); + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { + final int length = in.readableBytes(); + final int idx = in.readerIndex(); + final byte[] input; + final int inputPtr; + if (in.hasArray()) { + input = in.array(); + inputPtr = in.arrayOffset() + idx; + } else { + input = recycler.allocInputBuffer(length); + in.getBytes(idx, input, 0, length); + inputPtr = 0; + } + + // Estimate may apparently under-count by one in some cases. + final int maxOutputLength = LZFEncoder.estimateMaxWorkspaceSize(length) + 1; + out.ensureWritable(maxOutputLength); + final byte[] output; + final int outputPtr; + if (out.hasArray()) { + output = out.array(); + outputPtr = out.arrayOffset() + out.writerIndex(); + } else { + output = new byte[maxOutputLength]; + outputPtr = 0; + } + + final int outputLength; + if (length >= compressThreshold) { + // compress. + outputLength = encodeCompress(input, inputPtr, length, output, outputPtr); + } else { + // not compress. 
+ outputLength = encodeNonCompress(input, inputPtr, length, output, outputPtr); + } + + if (out.hasArray()) { + out.writerIndex(out.writerIndex() + outputLength); + } else { + out.writeBytes(output, 0, outputLength); + } + + in.skipBytes(length); + + if (!in.hasArray()) { + recycler.releaseInputBuffer(input); + } + } + + private int encodeCompress(byte[] input, int inputPtr, int length, byte[] output, int outputPtr) { + return LZFEncoder.appendEncoded(encoder, + input, inputPtr, length, output, outputPtr) - outputPtr; + } + + private static int lzfEncodeNonCompress(byte[] input, int inputPtr, int length, byte[] output, int outputPtr) { + int left = length; + int chunkLen = Math.min(LZFChunk.MAX_CHUNK_LEN, left); + outputPtr = LZFChunk.appendNonCompressed(input, inputPtr, chunkLen, output, outputPtr); + left -= chunkLen; + if (left < 1) { + return outputPtr; + } + inputPtr += chunkLen; + do { + chunkLen = Math.min(left, LZFChunk.MAX_CHUNK_LEN); + outputPtr = LZFChunk.appendNonCompressed(input, inputPtr, chunkLen, output, outputPtr); + inputPtr += chunkLen; + left -= chunkLen; + } while (left > 0); + return outputPtr; + } + + /** + * Use lzf uncompressed format to encode a piece of input. 
+ */ + private static int encodeNonCompress(byte[] input, int inputPtr, int length, byte[] output, int outputPtr) { + return lzfEncodeNonCompress(input, inputPtr, length, output, outputPtr) - outputPtr; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + encoder.close(); + super.handlerRemoved(ctx); + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Snappy.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Snappy.java new file mode 100644 index 0000000..8dc8b04 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Snappy.java @@ -0,0 +1,677 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; + +/** + * Uncompresses an input {@link ByteBuf} encoded with Snappy compression into an + * output {@link ByteBuf}. + * + * See snappy format. 
/**
 * Uncompresses an input {@link ByteBuf} encoded with Snappy compression into an
 * output {@link ByteBuf}.
 *
 * See <a href="https://github.com/google/snappy/blob/master/format_description.txt">snappy format</a>.
 */
public final class Snappy {

    // Maximum number of entries in the encoder's hash table (2^14).
    private static final int MAX_HT_SIZE = 1 << 14;
    // Inputs shorter than this are emitted as a single literal, no match search.
    private static final int MIN_COMPRESSIBLE_BYTES = 15;

    // used as a return value to indicate that we haven't yet read our full preamble
    private static final int PREAMBLE_NOT_FULL = -1;
    private static final int NOT_ENOUGH_INPUT = -1;

    // constants for the tag types
    private static final int LITERAL = 0;
    private static final int COPY_1_BYTE_OFFSET = 1;
    private static final int COPY_2_BYTE_OFFSET = 2;
    private static final int COPY_4_BYTE_OFFSET = 3;

    // Decoder state carried across decode() calls, since input may arrive fragmented.
    private State state = State.READING_PREAMBLE;
    // Tag byte of the element currently being decoded (valid in READING_LITERAL/READING_COPY).
    private byte tag;
    // Number of uncompressed bytes produced so far for the current chunk; used to
    // validate that copy offsets do not reach before the start of the chunk.
    private int written;

    private enum State {
        READING_PREAMBLE,
        READING_TAG,
        READING_LITERAL,
        READING_COPY
    }

    /** Resets the streaming decoder so a fresh chunk can be processed. */
    public void reset() {
        state = State.READING_PREAMBLE;
        tag = 0;
        written = 0;
    }

    /**
     * Compresses {@code length} bytes from {@code in} into {@code out} in raw snappy format
     * (varint uncompressed-length preamble followed by literal/copy elements).
     */
    public void encode(final ByteBuf in, final ByteBuf out, final int length) {
        // Write the preamble length to the output buffer
        for (int i = 0;; i ++) {
            int b = length >>> i * 7;
            if ((b & 0xFFFFFF80) != 0) {
                out.writeByte(b & 0x7f | 0x80);
            } else {
                out.writeByte(b);
                break;
            }
        }

        int inIndex = in.readerIndex();
        final int baseIndex = inIndex;

        final short[] table = getHashTable(length);
        // Shift that maps a 32-bit hash into [0, table.length).
        final int shift = Integer.numberOfLeadingZeros(table.length) + 1;

        // Index of the first byte not yet emitted (as literal or copy).
        int nextEmit = inIndex;

        if (length - inIndex >= MIN_COMPRESSIBLE_BYTES) {
            int nextHash = hash(in, ++inIndex, shift);
            outer: while (true) {
                // Heuristic: step size grows as skip/32, so incompressible data is
                // scanned progressively faster.
                int skip = 32;

                int candidate;
                int nextIndex = inIndex;
                do {
                    inIndex = nextIndex;
                    int hash = nextHash;
                    int bytesBetweenHashLookups = skip++ >> 5;
                    nextIndex = inIndex + bytesBetweenHashLookups;

                    // We need at least 4 remaining bytes to read the hash
                    if (nextIndex > length - 4) {
                        break outer;
                    }

                    nextHash = hash(in, nextIndex, shift);

                    candidate = baseIndex + table[hash];

                    table[hash] = (short) (inIndex - baseIndex);
                }
                while (in.getInt(inIndex) != in.getInt(candidate));

                // Emit everything between the last emit point and the match start.
                encodeLiteral(in, out, inIndex - nextEmit);

                int insertTail;
                do {
                    int base = inIndex;
                    int matched = 4 + findMatchingLength(in, candidate + 4, inIndex + 4, length);
                    inIndex += matched;
                    int offset = base - candidate;
                    encodeCopy(out, offset, matched);
                    in.readerIndex(in.readerIndex() + matched);
                    insertTail = inIndex - 1;
                    nextEmit = inIndex;
                    if (inIndex >= length - 4) {
                        break outer;
                    }

                    // Refresh hash table entries around the end of the match so
                    // back-to-back copies can be found immediately.
                    int prevHash = hash(in, insertTail, shift);
                    table[prevHash] = (short) (inIndex - baseIndex - 1);
                    int currentHash = hash(in, insertTail + 1, shift);
                    candidate = baseIndex + table[currentHash];
                    table[currentHash] = (short) (inIndex - baseIndex);
                }
                while (in.getInt(insertTail + 1) == in.getInt(candidate));

                nextHash = hash(in, insertTail + 2, shift);
                ++inIndex;
            }
        }

        // If there are any remaining characters, write them out as a literal
        if (nextEmit < length) {
            encodeLiteral(in, out, length - nextEmit);
        }
    }

    /**
     * Hashes the 4 bytes located at index, shifting the resulting hash into
     * the appropriate range for our hash table.
     *
     * @param in The input buffer to read 4 bytes from
     * @param index The index to read at
     * @param shift The shift value, for ensuring that the resulting value is
     *     within the range of our hash table size
     * @return A 32-bit hash of 4 bytes located at index
     */
    private static int hash(ByteBuf in, int index, int shift) {
        return in.getInt(index) * 0x1e35a7bd >>> shift;
    }

    /**
     * Creates an appropriately sized hashtable for the given input size
     *
     * @param inputSize The size of our input, ie. the number of bytes we need to encode
     * @return An appropriately sized empty hashtable
     */
    private static short[] getHashTable(int inputSize) {
        // Smallest power of two >= inputSize, capped at MAX_HT_SIZE.
        int htSize = 256;
        while (htSize < MAX_HT_SIZE && htSize < inputSize) {
            htSize <<= 1;
        }
        return new short[htSize];
    }

    /**
     * Iterates over the supplied input buffer between the supplied minIndex and
     * maxIndex to find how long our matched copy overlaps with an already-written
     * literal value.
     *
     * @param in The input buffer to scan over
     * @param minIndex The index in the input buffer to start scanning from
     * @param inIndex The index of the start of our copy
     * @param maxIndex The length of our input buffer
     * @return The number of bytes for which our candidate copy is a repeat of
     */
    private static int findMatchingLength(ByteBuf in, int minIndex, int inIndex, int maxIndex) {
        int matched = 0;

        // Compare 4 bytes at a time while safe, then byte-by-byte for the tail.
        while (inIndex <= maxIndex - 4 &&
                in.getInt(inIndex) == in.getInt(minIndex + matched)) {
            inIndex += 4;
            matched += 4;
        }

        while (inIndex < maxIndex && in.getByte(minIndex + matched) == in.getByte(inIndex)) {
            ++inIndex;
            ++matched;
        }

        return matched;
    }

    /**
     * Calculates the minimum number of bits required to encode a value. This can
     * then in turn be used to calculate the number of septets or octets (as
     * appropriate) to use to encode a length parameter.
     *
     * @param value The value to calculate the minimum number of bits required to encode
     * @return The minimum number of bits required to encode the supplied value
     */
    private static int bitsToEncode(int value) {
        int highestOneBit = Integer.highestOneBit(value);
        int bitLength = 0;
        while ((highestOneBit >>= 1) != 0) {
            bitLength++;
        }

        return bitLength;
    }

    /**
     * Writes a literal to the supplied output buffer by directly copying from
     * the input buffer. The literal is taken from the current readerIndex
     * up to the supplied length.
     *
     * @param in The input buffer to copy from
     * @param out The output buffer to copy to
     * @param length The length of the literal to copy
     */
    static void encodeLiteral(ByteBuf in, ByteBuf out, int length) {
        if (length < 61) {
            // Short literal: length fits in the tag byte itself.
            out.writeByte(length - 1 << 2);
        } else {
            // Long literal: tag 60..63 selects how many extra length bytes follow.
            int bitLength = bitsToEncode(length - 1);
            int bytesToEncode = 1 + bitLength / 8;
            out.writeByte(59 + bytesToEncode << 2);
            for (int i = 0; i < bytesToEncode; i++) {
                out.writeByte(length - 1 >> i * 8 & 0x0ff);
            }
        }

        out.writeBytes(in, length);
    }

    // Emits a single copy element, choosing the 1-byte-offset form when both
    // length and offset are small enough, otherwise the 2-byte-offset form.
    private static void encodeCopyWithOffset(ByteBuf out, int offset, int length) {
        if (length < 12 && offset < 2048) {
            out.writeByte(COPY_1_BYTE_OFFSET | length - 4 << 2 | offset >> 8 << 5);
            out.writeByte(offset & 0x0ff);
        } else {
            out.writeByte(COPY_2_BYTE_OFFSET | length - 1 << 2);
            out.writeByte(offset & 0x0ff);
            out.writeByte(offset >> 8 & 0x0ff);
        }
    }

    /**
     * Encodes a series of copies, each at most 64 bytes in length.
     *
     * @param out The output buffer to write the copy pointer to
     * @param offset The offset at which the original instance lies
     * @param length The length of the original instance
     */
    private static void encodeCopy(ByteBuf out, int offset, int length) {
        while (length >= 68) {
            encodeCopyWithOffset(out, offset, 64);
            length -= 64;
        }

        // 65..67 remaining: emit 60 first so the final copy is >= 4 bytes.
        if (length > 64) {
            encodeCopyWithOffset(out, offset, 60);
            length -= 60;
        }

        encodeCopyWithOffset(out, offset, length);
    }

    /**
     * Incrementally decompresses raw snappy data from {@code in} into {@code out}.
     * Safe to call repeatedly as more input arrives; returns early when the
     * current element cannot yet be fully read.
     */
    public void decode(ByteBuf in, ByteBuf out) {
        while (in.isReadable()) {
            switch (state) {
            case READING_PREAMBLE:
                int uncompressedLength = readPreamble(in);
                if (uncompressedLength == PREAMBLE_NOT_FULL) {
                    // We've not yet read all of the preamble, so wait until we can
                    return;
                }
                if (uncompressedLength == 0) {
                    // Should never happen, but it does mean we have nothing further to do
                    return;
                }
                out.ensureWritable(uncompressedLength);
                state = State.READING_TAG;
                // fall through
            case READING_TAG:
                if (!in.isReadable()) {
                    return;
                }
                tag = in.readByte();
                switch (tag & 0x03) {
                case LITERAL:
                    state = State.READING_LITERAL;
                    break;
                case COPY_1_BYTE_OFFSET:
                case COPY_2_BYTE_OFFSET:
                case COPY_4_BYTE_OFFSET:
                    state = State.READING_COPY;
                    break;
                }
                break;
            case READING_LITERAL:
                int literalWritten = decodeLiteral(tag, in, out);
                if (literalWritten != NOT_ENOUGH_INPUT) {
                    state = State.READING_TAG;
                    written += literalWritten;
                } else {
                    // Need to wait for more data
                    return;
                }
                break;
            case READING_COPY:
                int decodeWritten;
                switch (tag & 0x03) {
                case COPY_1_BYTE_OFFSET:
                    decodeWritten = decodeCopyWith1ByteOffset(tag, in, out, written);
                    if (decodeWritten != NOT_ENOUGH_INPUT) {
                        state = State.READING_TAG;
                        written += decodeWritten;
                    } else {
                        // Need to wait for more data
                        return;
                    }
                    break;
                case COPY_2_BYTE_OFFSET:
                    decodeWritten = decodeCopyWith2ByteOffset(tag, in, out, written);
                    if (decodeWritten != NOT_ENOUGH_INPUT) {
                        state = State.READING_TAG;
                        written += decodeWritten;
                    } else {
                        // Need to wait for more data
                        return;
                    }
                    break;
                case COPY_4_BYTE_OFFSET:
                    decodeWritten = decodeCopyWith4ByteOffset(tag, in, out, written);
                    if (decodeWritten != NOT_ENOUGH_INPUT) {
                        state = State.READING_TAG;
                        written += decodeWritten;
                    } else {
                        // Need to wait for more data
                        return;
                    }
                    break;
                }
            }
        }
    }

    /**
     * Reads the length varint (a series of bytes, where the lower 7 bits
     * are data and the upper bit is a flag to indicate more bytes to be
     * read).
     *
     * @param in The input buffer to read the preamble from
     * @return The calculated length based on the input buffer, or 0 if
     *     no preamble is able to be calculated
     */
    private static int readPreamble(ByteBuf in) {
        int length = 0;
        int byteIndex = 0;
        while (in.isReadable()) {
            int current = in.readUnsignedByte();
            length |= (current & 0x7f) << byteIndex++ * 7;
            if ((current & 0x80) == 0) {
                return length;
            }

            if (byteIndex >= 4) {
                throw new DecompressionException("Preamble is greater than 4 bytes");
            }
        }

        return 0;
    }

    /**
     * Get the length varint (a series of bytes, where the lower 7 bits
     * are data and the upper bit is a flag to indicate more bytes to be
     * read).
     *
     * @param in The input buffer to get the preamble from
     * @return The calculated length based on the input buffer, or 0 if
     *     no preamble is able to be calculated
     */
    int getPreamble(ByteBuf in) {
        if (state == State.READING_PREAMBLE) {
            int readerIndex = in.readerIndex();
            try {
                return readPreamble(in);
            } finally {
                // Peek only: restore the reader index so decode() re-reads the preamble.
                in.readerIndex(readerIndex);
            }
        }
        return 0;
    }

    /**
     * Reads a literal from the input buffer directly to the output buffer.
     * A "literal" is an uncompressed segment of data stored directly in the
     * byte stream.
     *
     * @param tag The tag that identified this segment as a literal is also
     *            used to encode part of the length of the data
     * @param in The input buffer to read the literal from
     * @param out The output buffer to write the literal to
     * @return The number of bytes appended to the output buffer, or -1 to indicate "try again later"
     */
    static int decodeLiteral(byte tag, ByteBuf in, ByteBuf out) {
        in.markReaderIndex();
        int length;
        switch(tag >> 2 & 0x3F) {
        case 60:
            if (!in.isReadable()) {
                return NOT_ENOUGH_INPUT;
            }
            length = in.readUnsignedByte();
            break;
        case 61:
            if (in.readableBytes() < 2) {
                return NOT_ENOUGH_INPUT;
            }
            length = in.readUnsignedShortLE();
            break;
        case 62:
            if (in.readableBytes() < 3) {
                return NOT_ENOUGH_INPUT;
            }
            length = in.readUnsignedMediumLE();
            break;
        case 63:
            if (in.readableBytes() < 4) {
                return NOT_ENOUGH_INPUT;
            }
            length = in.readIntLE();
            break;
        default:
            // Tags 0..59 encode (length - 1) directly.
            length = tag >> 2 & 0x3F;
        }
        length += 1;

        if (in.readableBytes() < length) {
            // Roll back the length bytes consumed above and retry later.
            in.resetReaderIndex();
            return NOT_ENOUGH_INPUT;
        }

        out.writeBytes(in, length);
        return length;
    }

    /**
     * Reads a compressed reference offset and length from the supplied input
     * buffer, seeks back to the appropriate place in the input buffer and
     * writes the found data to the supplied output stream.
     *
     * @param tag The tag used to identify this as a copy is also used to encode
     *            the length and part of the offset
     * @param in The input buffer to read from
     * @param out The output buffer to write to
     * @return The number of bytes appended to the output buffer, or -1 to indicate
     *     "try again later"
     * @throws DecompressionException If the read offset is invalid
     */
    private static int decodeCopyWith1ByteOffset(byte tag, ByteBuf in, ByteBuf out, int writtenSoFar) {
        if (!in.isReadable()) {
            return NOT_ENOUGH_INPUT;
        }

        int initialIndex = out.writerIndex();
        // Tag bits 2-4 hold (length - 4); bits 5-7 are the offset's high bits.
        int length = 4 + ((tag & 0x01c) >> 2);
        int offset = (tag & 0x0e0) << 8 >> 5 | in.readUnsignedByte();

        validateOffset(offset, writtenSoFar);

        out.markReaderIndex();
        if (offset < length) {
            // Overlapping copy: replay the offset-sized window repeatedly.
            int copies = length / offset;
            for (; copies > 0; copies--) {
                out.readerIndex(initialIndex - offset);
                out.readBytes(out, offset);
            }
            if (length % offset != 0) {
                out.readerIndex(initialIndex - offset);
                out.readBytes(out, length % offset);
            }
        } else {
            out.readerIndex(initialIndex - offset);
            out.readBytes(out, length);
        }
        out.resetReaderIndex();

        return length;
    }

    /**
     * Reads a compressed reference offset and length from the supplied input
     * buffer, seeks back to the appropriate place in the input buffer and
     * writes the found data to the supplied output stream.
     *
     * @param tag The tag used to identify this as a copy is also used to encode
     *            the length and part of the offset
     * @param in The input buffer to read from
     * @param out The output buffer to write to
     * @throws DecompressionException If the read offset is invalid
     * @return The number of bytes appended to the output buffer, or -1 to indicate
     *     "try again later"
     */
    private static int decodeCopyWith2ByteOffset(byte tag, ByteBuf in, ByteBuf out, int writtenSoFar) {
        if (in.readableBytes() < 2) {
            return NOT_ENOUGH_INPUT;
        }

        int initialIndex = out.writerIndex();
        int length = 1 + (tag >> 2 & 0x03f);
        int offset = in.readUnsignedShortLE();

        validateOffset(offset, writtenSoFar);

        out.markReaderIndex();
        if (offset < length) {
            // Overlapping copy: replay the offset-sized window repeatedly.
            int copies = length / offset;
            for (; copies > 0; copies--) {
                out.readerIndex(initialIndex - offset);
                out.readBytes(out, offset);
            }
            if (length % offset != 0) {
                out.readerIndex(initialIndex - offset);
                out.readBytes(out, length % offset);
            }
        } else {
            out.readerIndex(initialIndex - offset);
            out.readBytes(out, length);
        }
        out.resetReaderIndex();

        return length;
    }

    /**
     * Reads a compressed reference offset and length from the supplied input
     * buffer, seeks back to the appropriate place in the input buffer and
     * writes the found data to the supplied output stream.
     *
     * @param tag The tag used to identify this as a copy is also used to encode
     *            the length and part of the offset
     * @param in The input buffer to read from
     * @param out The output buffer to write to
     * @return The number of bytes appended to the output buffer, or -1 to indicate
     *     "try again later"
     * @throws DecompressionException If the read offset is invalid
     */
    private static int decodeCopyWith4ByteOffset(byte tag, ByteBuf in, ByteBuf out, int writtenSoFar) {
        if (in.readableBytes() < 4) {
            return NOT_ENOUGH_INPUT;
        }

        int initialIndex = out.writerIndex();
        int length = 1 + (tag >> 2 & 0x03F);
        int offset = in.readIntLE();

        validateOffset(offset, writtenSoFar);

        out.markReaderIndex();
        if (offset < length) {
            // Overlapping copy: replay the offset-sized window repeatedly.
            int copies = length / offset;
            for (; copies > 0; copies--) {
                out.readerIndex(initialIndex - offset);
                out.readBytes(out, offset);
            }
            if (length % offset != 0) {
                out.readerIndex(initialIndex - offset);
                out.readBytes(out, length % offset);
            }
        } else {
            out.readerIndex(initialIndex - offset);
            out.readBytes(out, length);
        }
        out.resetReaderIndex();

        return length;
    }

    /**
     * Validates that the offset extracted from a compressed reference is within
     * the permissible bounds of an offset (0 < offset < Integer.MAX_VALUE), and does not
     * exceed the length of the chunk currently read so far.
     *
     * @param offset The offset extracted from the compressed reference
     * @param chunkSizeSoFar The number of bytes read so far from this chunk
     * @throws DecompressionException if the offset is invalid
     */
    private static void validateOffset(int offset, int chunkSizeSoFar) {
        if (offset == 0) {
            throw new DecompressionException("Offset is less than minimum permissible value");
        }

        if (offset < 0) {
            // Due to arithmetic overflow
            throw new DecompressionException("Offset is greater than maximum value supported by this implementation");
        }

        if (offset > chunkSizeSoFar) {
            throw new DecompressionException("Offset exceeds size of chunk");
        }
    }

    /**
     * Computes the CRC32C checksum of the supplied data and performs the "mask" operation
     * on the computed checksum
     *
     * @param data The input data to calculate the CRC32C checksum of
     */
    static int calculateChecksum(ByteBuf data) {
        return calculateChecksum(data, data.readerIndex(), data.readableBytes());
    }

    /**
     * Computes the CRC32C checksum of the supplied data and performs the "mask" operation
     * on the computed checksum
     *
     * @param data The input data to calculate the CRC32C checksum of
     */
    static int calculateChecksum(ByteBuf data, int offset, int length) {
        Crc32c crc32 = new Crc32c();
        try {
            crc32.update(data, offset, length);
            return maskChecksum(crc32.getValue());
        } finally {
            crc32.reset();
        }
    }

    /**
     * Computes the CRC32C checksum of the supplied data, performs the "mask" operation
     * on the computed checksum, and then compares the resulting masked checksum to the
     * supplied checksum.
     *
     * @param expectedChecksum The checksum decoded from the stream to compare against
     * @param data The input data to calculate the CRC32C checksum of
     * @throws DecompressionException If the calculated and supplied checksums do not match
     */
    static void validateChecksum(int expectedChecksum, ByteBuf data) {
        validateChecksum(expectedChecksum, data, data.readerIndex(), data.readableBytes());
    }

    /**
     * Computes the CRC32C checksum of the supplied data, performs the "mask" operation
     * on the computed checksum, and then compares the resulting masked checksum to the
     * supplied checksum.
     *
     * @param expectedChecksum The checksum decoded from the stream to compare against
     * @param data The input data to calculate the CRC32C checksum of
     * @throws DecompressionException If the calculated and supplied checksums do not match
     */
    static void validateChecksum(int expectedChecksum, ByteBuf data, int offset, int length) {
        final int actualChecksum = calculateChecksum(data, offset, length);
        if (actualChecksum != expectedChecksum) {
            throw new DecompressionException(
                    "mismatching checksum: " + Integer.toHexString(actualChecksum) +
                            " (expected: " + Integer.toHexString(expectedChecksum) + ')');
        }
    }

    /**
     * From the spec:
     *
     * "Checksums are not stored directly, but masked, as checksumming data and
     * then its own checksum can be problematic. The masking is the same as used
     * in Apache Hadoop: Rotate the checksum by 15 bits, then add the constant
     * 0xa282ead8 (using wraparound as normal for unsigned integers)."
     *
     * @param checksum The actual checksum of the data
     * @return The masked checksum
     */
    static int maskChecksum(long checksum) {
        return (int) ((checksum >> 15 | checksum << 17) + 0xa282ead8);
    }
}
/**
 * Uncompresses a {@link ByteBuf} encoded with the Snappy framing format.
 *
 * See <a href="https://github.com/google/snappy/blob/master/framing_format.txt">Snappy framing format</a>.
 *
 * Note that by default, validation of the checksum header in each chunk is
 * DISABLED for performance improvements. If performance is less of an issue,
 * or if you would prefer the safety that checksum validation brings, please
 * use the {@link #SnappyFrameDecoder(boolean)} constructor with the argument
 * set to {@code true}.
 */
public class SnappyFrameDecoder extends ByteToMessageDecoder {

    private enum ChunkType {
        STREAM_IDENTIFIER,
        COMPRESSED_DATA,
        UNCOMPRESSED_DATA,
        RESERVED_UNSKIPPABLE,
        RESERVED_SKIPPABLE
    }

    private static final int SNAPPY_IDENTIFIER_LEN = 6;
    // See https://github.com/google/snappy/blob/1.1.9/framing_format.txt#L95
    private static final int MAX_UNCOMPRESSED_DATA_SIZE = 65536 + 4;
    // See https://github.com/google/snappy/blob/1.1.9/framing_format.txt#L82
    private static final int MAX_DECOMPRESSED_DATA_SIZE = 65536;
    // See https://github.com/google/snappy/blob/1.1.9/framing_format.txt#L82
    private static final int MAX_COMPRESSED_CHUNK_SIZE = 16777216 - 1;

    private final Snappy snappy = new Snappy();
    private final boolean validateChecksums;

    // True once the STREAM_IDENTIFIER chunk has been seen; data chunks before it are invalid.
    private boolean started;
    // Once the stream is corrupted all further input is discarded.
    private boolean corrupted;
    // Remaining byte count of a partially-skipped RESERVED_SKIPPABLE chunk.
    private int numBytesToSkip;

    /**
     * Creates a new snappy-framed decoder with validation of checksums
     * turned OFF. To turn checksum validation on, please use the alternate
     * {@link #SnappyFrameDecoder(boolean)} constructor.
     */
    public SnappyFrameDecoder() {
        this(false);
    }

    /**
     * Creates a new snappy-framed decoder with validation of checksums
     * as specified.
     *
     * @param validateChecksums
     *        If true, the checksum field will be validated against the actual
     *        uncompressed data, and if the checksums do not match, a suitable
     *        {@link DecompressionException} will be thrown
     */
    public SnappyFrameDecoder(boolean validateChecksums) {
        this.validateChecksums = validateChecksums;
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        if (corrupted) {
            in.skipBytes(in.readableBytes());
            return;
        }

        if (numBytesToSkip != 0) {
            // The last chunkType we detected was RESERVED_SKIPPABLE and we still have some bytes to skip.
            int skipBytes = Math.min(numBytesToSkip, in.readableBytes());
            in.skipBytes(skipBytes);
            numBytesToSkip -= skipBytes;

            // Let's return and try again.
            return;
        }

        try {
            int idx = in.readerIndex();
            final int inSize = in.readableBytes();
            if (inSize < 4) {
                // We need to be at least able to read the chunk type identifier (one byte),
                // and the length of the chunk (3 bytes) in order to proceed
                return;
            }

            final int chunkTypeVal = in.getUnsignedByte(idx);
            final ChunkType chunkType = mapChunkType((byte) chunkTypeVal);
            final int chunkLength = in.getUnsignedMediumLE(idx + 1);

            switch (chunkType) {
            case STREAM_IDENTIFIER:
                if (chunkLength != SNAPPY_IDENTIFIER_LEN) {
                    throw new DecompressionException("Unexpected length of stream identifier: " + chunkLength);
                }

                if (inSize < 4 + SNAPPY_IDENTIFIER_LEN) {
                    break;
                }

                in.skipBytes(4);
                int offset = in.readerIndex();
                in.skipBytes(SNAPPY_IDENTIFIER_LEN);

                // The identifier payload must spell "sNaPpY".
                checkByte(in.getByte(offset++), (byte) 's');
                checkByte(in.getByte(offset++), (byte) 'N');
                checkByte(in.getByte(offset++), (byte) 'a');
                checkByte(in.getByte(offset++), (byte) 'P');
                checkByte(in.getByte(offset++), (byte) 'p');
                checkByte(in.getByte(offset), (byte) 'Y');

                started = true;
                break;
            case RESERVED_SKIPPABLE:
                if (!started) {
                    throw new DecompressionException("Received RESERVED_SKIPPABLE tag before STREAM_IDENTIFIER");
                }

                in.skipBytes(4);

                int skipBytes = Math.min(chunkLength, in.readableBytes());
                in.skipBytes(skipBytes);
                if (skipBytes != chunkLength) {
                    // We couldn't skip all bytes yet; store the remaining count so we can
                    // finish skipping once we receive more data.
                    numBytesToSkip = chunkLength - skipBytes;
                }
                break;
            case RESERVED_UNSKIPPABLE:
                // The spec mandates that reserved unskippable chunks must immediately
                // return an error, as we must assume that we cannot decode the stream
                // correctly
                throw new DecompressionException(
                        "Found reserved unskippable chunk type: 0x" + Integer.toHexString(chunkTypeVal));
            case UNCOMPRESSED_DATA:
                if (!started) {
                    throw new DecompressionException("Received UNCOMPRESSED_DATA tag before STREAM_IDENTIFIER");
                }
                if (chunkLength > MAX_UNCOMPRESSED_DATA_SIZE) {
                    throw new DecompressionException("Received UNCOMPRESSED_DATA larger than " +
                            MAX_UNCOMPRESSED_DATA_SIZE + " bytes");
                }

                if (inSize < 4 + chunkLength) {
                    return;
                }

                in.skipBytes(4);
                if (validateChecksums) {
                    int checksum = in.readIntLE();
                    validateChecksum(checksum, in, in.readerIndex(), chunkLength - 4);
                } else {
                    // Skip the 4-byte CRC32C header without validating it.
                    in.skipBytes(4);
                }
                out.add(in.readRetainedSlice(chunkLength - 4));
                break;
            case COMPRESSED_DATA:
                if (!started) {
                    throw new DecompressionException("Received COMPRESSED_DATA tag before STREAM_IDENTIFIER");
                }

                if (chunkLength > MAX_COMPRESSED_CHUNK_SIZE) {
                    throw new DecompressionException("Received COMPRESSED_DATA that contains" +
                            " chunk that exceeds " + MAX_COMPRESSED_CHUNK_SIZE + " bytes");
                }

                if (inSize < 4 + chunkLength) {
                    return;
                }

                in.skipBytes(4);
                int checksum = in.readIntLE();

                // Peek at the uncompressed-length preamble to bound the output allocation.
                int uncompressedSize = snappy.getPreamble(in);
                if (uncompressedSize > MAX_DECOMPRESSED_DATA_SIZE) {
                    throw new DecompressionException("Received COMPRESSED_DATA that contains" +
                            " uncompressed data that exceeds " + MAX_DECOMPRESSED_DATA_SIZE + " bytes");
                }

                ByteBuf uncompressed = ctx.alloc().buffer(uncompressedSize, MAX_DECOMPRESSED_DATA_SIZE);
                try {
                    if (validateChecksums) {
                        // Temporarily cap the writer index so the decoder only sees this chunk.
                        int oldWriterIndex = in.writerIndex();
                        try {
                            in.writerIndex(in.readerIndex() + chunkLength - 4);
                            snappy.decode(in, uncompressed);
                        } finally {
                            in.writerIndex(oldWriterIndex);
                        }
                        validateChecksum(checksum, uncompressed, 0, uncompressed.writerIndex());
                    } else {
                        snappy.decode(in.readSlice(chunkLength - 4), uncompressed);
                    }
                    out.add(uncompressed);
                    // Ownership transferred to the output list; prevent release in finally.
                    uncompressed = null;
                } finally {
                    if (uncompressed != null) {
                        uncompressed.release();
                    }
                }
                snappy.reset();
                break;
            }
        } catch (Exception e) {
            // Any decoding failure permanently poisons the stream.
            corrupted = true;
            throw e;
        }
    }

    // Throws if the stream identifier byte does not match the expected magic byte.
    private static void checkByte(byte actual, byte expect) {
        if (actual != expect) {
            throw new DecompressionException("Unexpected stream identifier contents. Mismatched snappy " +
                    "protocol version?");
        }
    }

    /**
     * Decodes the chunk type from the type tag byte.
     *
     * @param type The tag byte extracted from the stream
     * @return The appropriate {@link ChunkType}, defaulting to {@link ChunkType#RESERVED_UNSKIPPABLE}
     */
    private static ChunkType mapChunkType(byte type) {
        if (type == 0) {
            return ChunkType.COMPRESSED_DATA;
        } else if (type == 1) {
            return ChunkType.UNCOMPRESSED_DATA;
        } else if (type == (byte) 0xff) {
            return ChunkType.STREAM_IDENTIFIER;
        } else if ((type & 0x80) == 0x80) {
            return ChunkType.RESERVED_SKIPPABLE;
        } else {
            return ChunkType.RESERVED_UNSKIPPABLE;
        }
    }
}
licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; + +import static io.netty.handler.codec.compression.Snappy.calculateChecksum; + +/** + * Compresses a {@link ByteBuf} using the Snappy framing format. + * + * See Snappy framing format. + */ +public class SnappyFrameEncoder extends MessageToByteEncoder { + /** + * The minimum amount that we'll consider actually attempting to compress. + * This value is preamble + the minimum length our Snappy service will + * compress (instead of just emitting a literal). + */ + private static final int MIN_COMPRESSIBLE_LENGTH = 18; + + /** + * All streams should start with the "Stream identifier", containing chunk + * type 0xff, a length field of 0x6, and 'sNaPpY' in ASCII. 
+ */ + private static final byte[] STREAM_START = { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59 + }; + + private final Snappy snappy = new Snappy(); + private boolean started; + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { + if (!in.isReadable()) { + return; + } + + if (!started) { + started = true; + out.writeBytes(STREAM_START); + } + + int dataLength = in.readableBytes(); + if (dataLength > MIN_COMPRESSIBLE_LENGTH) { + for (;;) { + final int lengthIdx = out.writerIndex() + 1; + if (dataLength < MIN_COMPRESSIBLE_LENGTH) { + ByteBuf slice = in.readSlice(dataLength); + writeUnencodedChunk(slice, out, dataLength); + break; + } + + out.writeInt(0); + if (dataLength > Short.MAX_VALUE) { + ByteBuf slice = in.readSlice(Short.MAX_VALUE); + calculateAndWriteChecksum(slice, out); + snappy.encode(slice, out, Short.MAX_VALUE); + setChunkLength(out, lengthIdx); + dataLength -= Short.MAX_VALUE; + } else { + ByteBuf slice = in.readSlice(dataLength); + calculateAndWriteChecksum(slice, out); + snappy.encode(slice, out, dataLength); + setChunkLength(out, lengthIdx); + break; + } + } + } else { + writeUnencodedChunk(in, out, dataLength); + } + } + + private static void writeUnencodedChunk(ByteBuf in, ByteBuf out, int dataLength) { + out.writeByte(1); + writeChunkLength(out, dataLength + 4); + calculateAndWriteChecksum(in, out); + out.writeBytes(in, dataLength); + } + + private static void setChunkLength(ByteBuf out, int lengthIdx) { + int chunkLength = out.writerIndex() - lengthIdx - 3; + if (chunkLength >>> 24 != 0) { + throw new CompressionException("compressed data too large: " + chunkLength); + } + out.setMediumLE(lengthIdx, chunkLength); + } + + /** + * Writes the 2-byte chunk length to the output buffer. 
+ * + * @param out The buffer to write to + * @param chunkLength The length to write + */ + private static void writeChunkLength(ByteBuf out, int chunkLength) { + out.writeMediumLE(chunkLength); + } + + /** + * Calculates and writes the 4-byte checksum to the output buffer + * + * @param slice The data to calculate the checksum for + * @param out The output buffer to write the checksum to + */ + private static void calculateAndWriteChecksum(ByteBuf slice, ByteBuf out) { + out.writeIntLE(calculateChecksum(slice)); + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java new file mode 100644 index 0000000..88a244f --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedDecoder.java @@ -0,0 +1,25 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.compression; + +/** + * @deprecated Use {@link SnappyFrameDecoder} instead. + */ +@Deprecated +public class SnappyFramedDecoder extends SnappyFrameDecoder { + // Nothing new. Just staying here for backward compatibility. 
+} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedEncoder.java new file mode 100644 index 0000000..f755ff7 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyFramedEncoder.java @@ -0,0 +1,25 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.compression; + +/** + * @deprecated Use {@link SnappyFrameEncoder} instead. + */ +@Deprecated +public class SnappyFramedEncoder extends SnappyFrameEncoder { + // Nothing new. Just staying here for backward compatibility. +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyOptions.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyOptions.java new file mode 100644 index 0000000..aaee640 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/SnappyOptions.java @@ -0,0 +1,24 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +/** + * {@link SnappyOptions} holds config for + * Snappy compression. + */ +public final class SnappyOptions implements CompressionOptions { + // Will add config if Snappy supports this +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java new file mode 100644 index 0000000..ac25824 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/StandardCompressionOptions.java @@ -0,0 +1,136 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import com.aayushatharva.brotli4j.encoder.Encoder; + +/** + * Standard Compression Options for {@link BrotliOptions}, + * {@link GzipOptions} and {@link DeflateOptions} + */ +public final class StandardCompressionOptions { + + private StandardCompressionOptions() { + // Prevent outside initialization + } + + /** + * Default implementation of {@link BrotliOptions} with {@link Encoder.Parameters#setQuality(int)} set to 4 + * and {@link Encoder.Parameters#setMode(Encoder.Mode)} set to {@link Encoder.Mode#TEXT} + */ + public static BrotliOptions brotli() { + return BrotliOptions.DEFAULT; + } + + /** + * Create a new {@link BrotliOptions} + * + * @param parameters {@link Encoder.Parameters} Instance + * @throws NullPointerException If {@link Encoder.Parameters} is {@code null} + */ + public static BrotliOptions brotli(Encoder.Parameters parameters) { + return new BrotliOptions(parameters); + } + + /** + * Default implementation of {@link ZstdOptions} with{compressionLevel(int)} set to + * {@link ZstdConstants#DEFAULT_COMPRESSION_LEVEL},{@link ZstdConstants#DEFAULT_BLOCK_SIZE}, + * {@link ZstdConstants#MAX_BLOCK_SIZE} + */ + public static ZstdOptions zstd() { + return ZstdOptions.DEFAULT; + } + + /** + * Create a new {@link ZstdOptions} + * + * @param blockSize + * is used to calculate the compressionLevel + * @param maxEncodeSize + * specifies the size of the largest compressed object + * @param compressionLevel + * specifies the level of the compression + */ + public static ZstdOptions zstd(int compressionLevel, int blockSize, int maxEncodeSize) { + return new ZstdOptions(compressionLevel, blockSize, maxEncodeSize); + } + + /** + * Create a new {@link SnappyOptions} + * + */ + public static SnappyOptions snappy() { + return new SnappyOptions(); + } + + /** + * Default implementation of {@link GzipOptions} with + * {@code compressionLevel()} set to 6, {@code windowBits()} set to 15 and {@code memLevel()} set to 8. 
+ */ + public static GzipOptions gzip() { + return GzipOptions.DEFAULT; + } + + /** + * Create a new {@link GzipOptions} Instance + * + * @param compressionLevel {@code 1} yields the fastest compression and {@code 9} yields the + * best compression. {@code 0} means no compression. The default + * compression level is {@code 6}. + * + * @param windowBits The base two logarithm of the size of the history buffer. The + * value should be in the range {@code 9} to {@code 15} inclusive. + * Larger values result in better compression at the expense of + * memory usage. The default value is {@code 15}. + * + * @param memLevel How much memory should be allocated for the internal compression + * state. {@code 1} uses minimum memory and {@code 9} uses maximum + * memory. Larger values result in better and faster compression + * at the expense of memory usage. The default value is {@code 8} + */ + public static GzipOptions gzip(int compressionLevel, int windowBits, int memLevel) { + return new GzipOptions(compressionLevel, windowBits, memLevel); + } + + /** + * Default implementation of {@link DeflateOptions} with + * {@code compressionLevel} set to 6, {@code windowBits} set to 15 and {@code memLevel} set to 8. + */ + public static DeflateOptions deflate() { + return DeflateOptions.DEFAULT; + } + + /** + * Create a new {@link DeflateOptions} Instance + * + * @param compressionLevel {@code 1} yields the fastest compression and {@code 9} yields the + * best compression. {@code 0} means no compression. The default + * compression level is {@code 6}. + * + * @param windowBits The base two logarithm of the size of the history buffer. The + * value should be in the range {@code 9} to {@code 15} inclusive. + * Larger values result in better compression at the expense of + * memory usage. The default value is {@code 15}. + * + * @param memLevel How much memory should be allocated for the internal compression + * state. 
{@code 1} uses minimum memory and {@code 9} uses maximum + * memory. Larger values result in better and faster compression + * at the expense of memory usage. The default value is {@code 8} + */ + public static DeflateOptions deflate(int compressionLevel, int windowBits, int memLevel) { + return new DeflateOptions(compressionLevel, windowBits, memLevel); + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibCodecFactory.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibCodecFactory.java new file mode 100644 index 0000000..d41d329 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibCodecFactory.java @@ -0,0 +1,139 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * Creates a new {@link ZlibEncoder} and a new {@link ZlibDecoder}. 
+ */ +public final class ZlibCodecFactory { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ZlibCodecFactory.class); + + private static final int DEFAULT_JDK_WINDOW_SIZE = 15; + private static final int DEFAULT_JDK_MEM_LEVEL = 8; + + private static final boolean noJdkZlibDecoder; + private static final boolean noJdkZlibEncoder; + private static final boolean supportsWindowSizeAndMemLevel; + + static { + noJdkZlibDecoder = SystemPropertyUtil.getBoolean("io.netty.noJdkZlibDecoder", + PlatformDependent.javaVersion() < 7); + logger.debug("-Dio.netty.noJdkZlibDecoder: {}", noJdkZlibDecoder); + + noJdkZlibEncoder = SystemPropertyUtil.getBoolean("io.netty.noJdkZlibEncoder", false); + logger.debug("-Dio.netty.noJdkZlibEncoder: {}", noJdkZlibEncoder); + + supportsWindowSizeAndMemLevel = noJdkZlibDecoder || PlatformDependent.javaVersion() >= 7; + } + + /** + * Returns {@code true} if specify a custom window size and mem level is supported. + */ + public static boolean isSupportingWindowSizeAndMemLevel() { + return supportsWindowSizeAndMemLevel; + } + + public static ZlibEncoder newZlibEncoder(int compressionLevel) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder) { + return new JZlibEncoder(compressionLevel); + } else { + return new JdkZlibEncoder(compressionLevel); + } + } + + public static ZlibEncoder newZlibEncoder(ZlibWrapper wrapper) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder) { + return new JZlibEncoder(wrapper); + } else { + return new JdkZlibEncoder(wrapper); + } + } + + public static ZlibEncoder newZlibEncoder(ZlibWrapper wrapper, int compressionLevel) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder) { + return new JZlibEncoder(wrapper, compressionLevel); + } else { + return new JdkZlibEncoder(wrapper, compressionLevel); + } + } + + public static ZlibEncoder newZlibEncoder(ZlibWrapper wrapper, int compressionLevel, int windowBits, int memLevel) { + if 
(PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder || + windowBits != DEFAULT_JDK_WINDOW_SIZE || memLevel != DEFAULT_JDK_MEM_LEVEL) { + return new JZlibEncoder(wrapper, compressionLevel, windowBits, memLevel); + } else { + return new JdkZlibEncoder(wrapper, compressionLevel); + } + } + + public static ZlibEncoder newZlibEncoder(byte[] dictionary) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder) { + return new JZlibEncoder(dictionary); + } else { + return new JdkZlibEncoder(dictionary); + } + } + + public static ZlibEncoder newZlibEncoder(int compressionLevel, byte[] dictionary) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder) { + return new JZlibEncoder(compressionLevel, dictionary); + } else { + return new JdkZlibEncoder(compressionLevel, dictionary); + } + } + + public static ZlibEncoder newZlibEncoder(int compressionLevel, int windowBits, int memLevel, byte[] dictionary) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibEncoder || + windowBits != DEFAULT_JDK_WINDOW_SIZE || memLevel != DEFAULT_JDK_MEM_LEVEL) { + return new JZlibEncoder(compressionLevel, windowBits, memLevel, dictionary); + } else { + return new JdkZlibEncoder(compressionLevel, dictionary); + } + } + + public static ZlibDecoder newZlibDecoder() { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibDecoder) { + return new JZlibDecoder(); + } else { + return new JdkZlibDecoder(true); + } + } + + public static ZlibDecoder newZlibDecoder(ZlibWrapper wrapper) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibDecoder) { + return new JZlibDecoder(wrapper); + } else { + return new JdkZlibDecoder(wrapper, true); + } + } + + public static ZlibDecoder newZlibDecoder(byte[] dictionary) { + if (PlatformDependent.javaVersion() < 7 || noJdkZlibDecoder) { + return new JZlibDecoder(dictionary); + } else { + return new JdkZlibDecoder(dictionary); + } + } + + private ZlibCodecFactory() { + // Unused + } +} diff --git 
a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibDecoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibDecoder.java new file mode 100644 index 0000000..bc05424 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibDecoder.java @@ -0,0 +1,95 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; + +/** + * Decompresses a {@link ByteBuf} using the deflate algorithm. + */ +public abstract class ZlibDecoder extends ByteToMessageDecoder { + + /** + * Maximum allowed size of the decompression buffer. + */ + protected final int maxAllocation; + + /** + * Same as {@link #ZlibDecoder(int)} with maxAllocation = 0. + */ + public ZlibDecoder() { + this(0); + } + + /** + * Construct a new ZlibDecoder. + * @param maxAllocation + * Maximum size of the decompression buffer. Must be >= 0. + * If zero, maximum size is decided by the {@link ByteBufAllocator}. 
+ */ + public ZlibDecoder(int maxAllocation) { + this.maxAllocation = checkPositiveOrZero(maxAllocation, "maxAllocation"); + } + + /** + * Returns {@code true} if and only if the end of the compressed stream + * has been reached. + */ + public abstract boolean isClosed(); + + /** + * Allocate or expand the decompression buffer, without exceeding the maximum allocation. + * Calls {@link #decompressionBufferExhausted(ByteBuf)} if the buffer is full and cannot be expanded further. + */ + protected ByteBuf prepareDecompressBuffer(ChannelHandlerContext ctx, ByteBuf buffer, int preferredSize) { + if (buffer == null) { + if (maxAllocation == 0) { + return ctx.alloc().heapBuffer(preferredSize); + } + + return ctx.alloc().heapBuffer(Math.min(preferredSize, maxAllocation), maxAllocation); + } + + // this always expands the buffer if possible, even if the expansion is less than preferredSize + // we throw the exception only if the buffer could not be expanded at all + // this means that one final attempt to deserialize will always be made with the buffer at maxAllocation + if (buffer.ensureWritable(preferredSize, true) == 1) { + // buffer must be consumed so subclasses don't add it to output + // we therefore duplicate it when calling decompressionBufferExhausted() to guarantee non-interference + // but wait until after to consume it so the subclass can tell how much output is really in the buffer + decompressionBufferExhausted(buffer.duplicate()); + buffer.skipBytes(buffer.readableBytes()); + throw new DecompressionException("Decompression buffer has reached maximum size: " + buffer.maxCapacity()); + } + + return buffer; + } + + /** + * Called when the decompression buffer cannot be expanded further. + * Default implementation is a no-op, but subclasses can override in case they want to + * do something before the {@link DecompressionException} is thrown, such as log the + * data that was decompressed so far. 
+ */ + protected void decompressionBufferExhausted(ByteBuf buffer) { + } + +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibEncoder.java new file mode 100644 index 0000000..b558eaa --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibEncoder.java @@ -0,0 +1,53 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.MessageToByteEncoder; + +/** + * Compresses a {@link ByteBuf} using the deflate algorithm. + */ +public abstract class ZlibEncoder extends MessageToByteEncoder { + + protected ZlibEncoder() { + super(false); + } + + /** + * Returns {@code true} if and only if the end of the compressed stream + * has been reached. + */ + public abstract boolean isClosed(); + + /** + * Close this {@link ZlibEncoder} and so finish the encoding. + * + * The returned {@link ChannelFuture} will be notified once the + * operation completes. + */ + public abstract ChannelFuture close(); + + /** + * Close this {@link ZlibEncoder} and so finish the encoding. 
+ * The given {@link ChannelFuture} will be notified once the operation + * completes and will also be returned. + */ + public abstract ChannelFuture close(ChannelPromise promise); + +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibUtil.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibUtil.java new file mode 100644 index 0000000..f72adeb --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibUtil.java @@ -0,0 +1,85 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.zlib.Deflater; +import io.netty.zlib.Inflater; +import io.netty.zlib.JZlib; + +/** + * Utility methods used by {@link JZlibEncoder} and {@link JZlibDecoder}. + */ +final class ZlibUtil { + + static void fail(Inflater z, String message, int resultCode) { + throw inflaterException(z, message, resultCode); + } + + static void fail(Deflater z, String message, int resultCode) { + throw deflaterException(z, message, resultCode); + } + + static DecompressionException inflaterException(Inflater z, String message, int resultCode) { + return new DecompressionException(message + " (" + resultCode + ')' + (z.msg != null? 
": " + z.msg : "")); + } + + static CompressionException deflaterException(Deflater z, String message, int resultCode) { + return new CompressionException(message + " (" + resultCode + ')' + (z.msg != null? ": " + z.msg : "")); + } + + static JZlib.WrapperType convertWrapperType(ZlibWrapper wrapper) { + JZlib.WrapperType convertedWrapperType; + switch (wrapper) { + case NONE: + convertedWrapperType = JZlib.W_NONE; + break; + case ZLIB: + convertedWrapperType = JZlib.W_ZLIB; + break; + case GZIP: + convertedWrapperType = JZlib.W_GZIP; + break; + case ZLIB_OR_NONE: + convertedWrapperType = JZlib.W_ANY; + break; + default: + throw new Error(); + } + return convertedWrapperType; + } + + static int wrapperOverhead(ZlibWrapper wrapper) { + int overhead; + switch (wrapper) { + case NONE: + overhead = 0; + break; + case ZLIB: + case ZLIB_OR_NONE: + overhead = 2; + break; + case GZIP: + overhead = 10; + break; + default: + throw new Error(); + } + return overhead; + } + + private ZlibUtil() { + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibWrapper.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibWrapper.java new file mode 100644 index 0000000..bd64bef --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZlibWrapper.java @@ -0,0 +1,40 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +/** + * The container file formats that wrap the stream compressed by the DEFLATE + * algorithm. + */ +public enum ZlibWrapper { + /** + * The ZLIB wrapper as specified in RFC 1950. + */ + ZLIB, + /** + * The GZIP wrapper as specified in RFC 1952. + */ + GZIP, + /** + * Raw DEFLATE stream only (no header and no footer). + */ + NONE, + /** + * Try {@link #ZLIB} first and then {@link #NONE} if the first attempt fails. + * Please note that you can specify this wrapper type only when decompressing. + */ + ZLIB_OR_NONE +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Zstd.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Zstd.java new file mode 100644 index 0000000..3d03524 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/Zstd.java @@ -0,0 +1,75 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.compression; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +public final class Zstd { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Zstd.class); + private static final Throwable cause; + + static { + Throwable t = null; + + try { + Class.forName("com.github.luben.zstd.Zstd", false, + PlatformDependent.getClassLoader(Zstd.class)); + } catch (ClassNotFoundException e) { + t = e; + logger.debug( + "zstd-jni not in the classpath; Zstd support will be unavailable."); + } catch (Throwable e) { + t = e; + logger.debug("Failed to load zstd-jni; Zstd support will be unavailable.", t); + } + + cause = t; + } + + /** + * + * @return true when zstd-jni is in the classpath + * and native library is available on this platform and could be loaded + */ + public static boolean isAvailable() { + return cause == null; + } + + /** + * Throws when zstd support is missing from the classpath or is unavailable on this platform + * @throws Throwable a ClassNotFoundException if zstd-jni is missing + * or a ExceptionInInitializerError if zstd native lib can't be loaded + */ + public static void ensureAvailability() throws Throwable { + if (cause != null) { + throw cause; + } + } + + /** + * Returns {@link Throwable} of unavailability cause + */ + public static Throwable cause() { + return cause; + } + + private Zstd() { + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java new file mode 100644 index 0000000..1588ba5 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdConstants.java @@ -0,0 +1,40 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this 
file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +final class ZstdConstants { + + /** + * Default compression level + */ + static final int DEFAULT_COMPRESSION_LEVEL = 3; + + /** + * Max compression level + */ + static final int MAX_COMPRESSION_LEVEL = 22; + + /** + * Max block size + */ + static final int MAX_BLOCK_SIZE = 1 << (DEFAULT_COMPRESSION_LEVEL + 7) + 0x0F; // 32 M + /** + * Default block size + */ + static final int DEFAULT_BLOCK_SIZE = 1 << 16; // 64 KB + + private ZstdConstants() { } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java new file mode 100644 index 0000000..45badd2 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdEncoder.java @@ -0,0 +1,185 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import com.github.luben.zstd.Zstd; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.internal.ObjectUtil; +import java.nio.ByteBuffer; + +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL; + +/** + * Compresses a {@link ByteBuf} using the Zstandard algorithm. + * See Zstandard. + */ +public final class ZstdEncoder extends MessageToByteEncoder { + + private final int blockSize; + private final int compressionLevel; + private final int maxEncodeSize; + private ByteBuf buffer; + + /** + * Creates a new Zstd encoder. + * + * Please note that if you use the default constructor, the default BLOCK_SIZE and MAX_BLOCK_SIZE + * will be used. If you want to specify BLOCK_SIZE and MAX_BLOCK_SIZE yourself, + * please use {@link ZstdEncoder(int,int)} constructor + */ + public ZstdEncoder() { + this(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + } + + /** + * Creates a new Zstd encoder. + * @param compressionLevel + * specifies the level of the compression + */ + public ZstdEncoder(int compressionLevel) { + this(compressionLevel, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + } + + /** + * Creates a new Zstd encoder. 
+ * @param blockSize + * is used to calculate the compressionLevel + * @param maxEncodeSize + * specifies the size of the largest compressed object + */ + public ZstdEncoder(int blockSize, int maxEncodeSize) { + this(DEFAULT_COMPRESSION_LEVEL, blockSize, maxEncodeSize); + } + + /** + * @param blockSize + * is used to calculate the compressionLevel + * @param maxEncodeSize + * specifies the size of the largest compressed object + * @param compressionLevel + * specifies the level of the compression + */ + public ZstdEncoder(int compressionLevel, int blockSize, int maxEncodeSize) { + super(true); + this.compressionLevel = ObjectUtil.checkInRange(compressionLevel, 0, MAX_COMPRESSION_LEVEL, "compressionLevel"); + this.blockSize = ObjectUtil.checkPositive(blockSize, "blockSize"); + this.maxEncodeSize = ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize"); + } + + @Override + protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, ByteBuf msg, boolean preferDirect) { + if (buffer == null) { + throw new IllegalStateException("not added to a pipeline," + + "or has been removed,buffer is null"); + } + + int remaining = msg.readableBytes() + buffer.readableBytes(); + + // quick overflow check + if (remaining < 0) { + throw new EncoderException("too much data to allocate a buffer for compression"); + } + + long bufferSize = 0; + while (remaining > 0) { + int curSize = Math.min(blockSize, remaining); + remaining -= curSize; + bufferSize += Zstd.compressBound(curSize); + } + + if (bufferSize > maxEncodeSize || 0 > bufferSize) { + throw new EncoderException("requested encode buffer size (" + bufferSize + " bytes) exceeds " + + "the maximum allowable size (" + maxEncodeSize + " bytes)"); + } + + return ctx.alloc().directBuffer((int) bufferSize); + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) { + if (buffer == null) { + throw new IllegalStateException("not added to a pipeline," + + "or has been removed,buffer is null"); + } + + 
final ByteBuf buffer = this.buffer; + int length; + while ((length = in.readableBytes()) > 0) { + final int nextChunkSize = Math.min(length, buffer.writableBytes()); + in.readBytes(buffer, nextChunkSize); + + if (!buffer.isWritable()) { + flushBufferedData(out); + } + } + } + + private void flushBufferedData(ByteBuf out) { + final int flushableBytes = buffer.readableBytes(); + if (flushableBytes == 0) { + return; + } + + final int bufSize = (int) Zstd.compressBound(flushableBytes); + out.ensureWritable(bufSize); + final int idx = out.writerIndex(); + int compressedLength; + try { + ByteBuffer outNioBuffer = out.internalNioBuffer(idx, out.writableBytes()); + compressedLength = Zstd.compress( + outNioBuffer, + buffer.internalNioBuffer(buffer.readerIndex(), flushableBytes), + compressionLevel); + } catch (Exception e) { + throw new CompressionException(e); + } + + out.writerIndex(idx + compressedLength); + buffer.clear(); + } + + @Override + public void flush(final ChannelHandlerContext ctx) { + if (buffer != null && buffer.isReadable()) { + final ByteBuf buf = allocateBuffer(ctx, Unpooled.EMPTY_BUFFER, isPreferDirect()); + flushBufferedData(buf); + ctx.write(buf); + } + ctx.flush(); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + buffer = ctx.alloc().directBuffer(blockSize); + buffer.clear(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + if (buffer != null) { + buffer.release(); + buffer = null; + } + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java new file mode 100644 index 0000000..110a90f --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/ZstdOptions.java @@ -0,0 +1,73 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses 
this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.util.internal.ObjectUtil; + +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_COMPRESSION_LEVEL; +import static io.netty.handler.codec.compression.ZstdConstants.MAX_COMPRESSION_LEVEL; +import static io.netty.handler.codec.compression.ZstdConstants.DEFAULT_BLOCK_SIZE; +import static io.netty.handler.codec.compression.ZstdConstants.MAX_BLOCK_SIZE; + +/** + * {@link ZstdOptions} holds compressionLevel for + * Zstd compression. 
+ */ +public class ZstdOptions implements CompressionOptions { + + private final int blockSize; + private final int compressionLevel; + private final int maxEncodeSize; + + /** + * Default implementation of {@link ZstdOptions} with{compressionLevel(int)} set to + * {@link ZstdConstants#DEFAULT_COMPRESSION_LEVEL},{@link ZstdConstants#DEFAULT_BLOCK_SIZE}, + * {@link ZstdConstants#MAX_BLOCK_SIZE} + */ + static final ZstdOptions DEFAULT = new ZstdOptions(DEFAULT_COMPRESSION_LEVEL, DEFAULT_BLOCK_SIZE, MAX_BLOCK_SIZE); + + /** + * Create a new {@link ZstdOptions} + * + * @param blockSize + * is used to calculate the compressionLevel + * @param maxEncodeSize + * specifies the size of the largest compressed object + * @param compressionLevel + * specifies the level of the compression + */ + ZstdOptions(int compressionLevel, int blockSize, int maxEncodeSize) { + if (!Zstd.isAvailable()) { + throw new IllegalStateException("zstd-jni is not available", Zstd.cause()); + } + + this.compressionLevel = ObjectUtil.checkInRange(compressionLevel, 0, MAX_COMPRESSION_LEVEL, "compressionLevel"); + this.blockSize = ObjectUtil.checkPositive(blockSize, "blockSize"); + this.maxEncodeSize = ObjectUtil.checkPositive(maxEncodeSize, "maxEncodeSize"); + } + + public int compressionLevel() { + return compressionLevel; + } + + public int blockSize() { + return blockSize; + } + + public int maxEncodeSize() { + return maxEncodeSize; + } +} diff --git a/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/package-info.java b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/package-info.java new file mode 100644 index 0000000..9399aec --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/io/netty/handler/codec/compression/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this 
file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder and decoder which compresses and decompresses {@link io.netty.buffer.ByteBuf}s + * in a compression format such as zlib, + * gzip, and + * Snappy. + */ +package io.netty.handler.codec.compression; diff --git a/netty-handler-codec-compression/src/main/java/module-info.java b/netty-handler-codec-compression/src/main/java/module-info.java new file mode 100644 index 0000000..b364f41 --- /dev/null +++ b/netty-handler-codec-compression/src/main/java/module-info.java @@ -0,0 +1,13 @@ +module org.xbib.io.netty.handler.codec.compression { + exports io.netty.handler.codec.compression; + requires org.xbib.io.netty.handler.codec; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.bziptwo; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.util; + requires org.xbib.io.netty.zlib; + requires com.aayushatharva.brotli4j; + requires org.lz4.java; + requires com.ning.compress.lzf; + requires com.github.luben.zstd_jni; +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractCompressionTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractCompressionTest.java new file mode 100644 index 0000000..933c334 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractCompressionTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 
2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import java.util.Random; + +public abstract class AbstractCompressionTest { + + protected static final Random rand; + + protected static final byte[] BYTES_SMALL = new byte[256]; + protected static final byte[] BYTES_LARGE = new byte[256 * 1024]; + + static { + rand = new Random(); + fillArrayWithCompressibleData(BYTES_SMALL); + fillArrayWithCompressibleData(BYTES_LARGE); + } + + private static void fillArrayWithCompressibleData(byte[] array) { + for (int i = 0; i < array.length; i++) { + array[i] = i % 4 != 0 ? 0 : (byte) rand.nextInt(); + } + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractDecoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractDecoderTest.java new file mode 100644 index 0000000..2d708d8 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractDecoderTest.java @@ -0,0 +1,152 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public abstract class AbstractDecoderTest extends AbstractCompressionTest { + + protected static final ByteBuf WRAPPED_BYTES_SMALL = Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer(BYTES_SMALL)).asReadOnly(); + protected static final ByteBuf WRAPPED_BYTES_LARGE = Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer(BYTES_LARGE)).asReadOnly(); + + protected EmbeddedChannel channel; + + protected byte[] compressedBytesSmall; + protected byte[] compressedBytesLarge; + + protected AbstractDecoderTest() throws Exception { + compressedBytesSmall = compress(BYTES_SMALL); + compressedBytesLarge = compress(BYTES_LARGE); + } + + /** + * Compresses data with some external library. 
+ */ + protected abstract byte[] compress(byte[] data) throws Exception; + + @BeforeEach + public final void initChannel() { + channel = createChannel(); + } + + protected abstract EmbeddedChannel createChannel(); + + @AfterEach + public void destroyChannel() { + if (channel != null) { + channel.finishAndReleaseAll(); + channel = null; + } + } + + public ByteBuf[] smallData() { + ByteBuf heap = Unpooled.wrappedBuffer(compressedBytesSmall); + ByteBuf direct = Unpooled.directBuffer(compressedBytesSmall.length); + direct.writeBytes(compressedBytesSmall); + return new ByteBuf[] {heap, direct}; + } + + public ByteBuf[] largeData() { + ByteBuf heap = Unpooled.wrappedBuffer(compressedBytesLarge); + ByteBuf direct = Unpooled.directBuffer(compressedBytesLarge.length); + direct.writeBytes(compressedBytesLarge); + return new ByteBuf[] {heap, direct}; + } + + @ParameterizedTest + @MethodSource("smallData") + public void testDecompressionOfSmallChunkOfData(ByteBuf data) throws Exception { + testDecompression(WRAPPED_BYTES_SMALL.duplicate(), data); + } + + @ParameterizedTest + @MethodSource("largeData") + public void testDecompressionOfLargeChunkOfData(ByteBuf data) throws Exception { + testDecompression(WRAPPED_BYTES_LARGE.duplicate(), data); + } + + @ParameterizedTest + @MethodSource("largeData") + public void testDecompressionOfBatchedFlowOfData(ByteBuf data) throws Exception { + testDecompressionOfBatchedFlow(WRAPPED_BYTES_LARGE.duplicate(), data); + } + + protected void testDecompression(final ByteBuf expected, final ByteBuf data) throws Exception { + assertTrue(channel.writeInbound(data)); + + ByteBuf decompressed = readDecompressed(channel); + assertEquals(expected, decompressed); + + decompressed.release(); + } + + protected void testDecompressionOfBatchedFlow(final ByteBuf expected, final ByteBuf data) throws Exception { + final int compressedLength = data.readableBytes(); + int written = 0, length = rand.nextInt(100); + while (written + length < compressedLength) { + 
ByteBuf compressedBuf = data.retainedSlice(written, length); + channel.writeInbound(compressedBuf); + written += length; + length = rand.nextInt(100); + } + ByteBuf compressedBuf = data.slice(written, compressedLength - written); + assertTrue(channel.writeInbound(compressedBuf.retain())); + + ByteBuf decompressedBuf = readDecompressed(channel); + assertEquals(expected, decompressedBuf); + + decompressedBuf.release(); + data.release(); + } + + protected static ByteBuf readDecompressed(final EmbeddedChannel channel) { + CompositeByteBuf decompressed = Unpooled.compositeBuffer(); + ByteBuf msg; + while ((msg = channel.readInbound()) != null) { + decompressed.addComponent(true, msg); + } + return decompressed; + } + + protected static void tryDecodeAndCatchBufLeaks(final EmbeddedChannel channel, final ByteBuf data) { + try { + channel.writeInbound(data); + } finally { + for (;;) { + ByteBuf inflated = channel.readInbound(); + if (inflated == null) { + break; + } + inflated.release(); + } + channel.finish(); + } + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractEncoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractEncoderTest.java new file mode 100644 index 0000000..4d29a1b --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractEncoderTest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class AbstractEncoderTest extends AbstractCompressionTest { + + protected EmbeddedChannel channel; + + /** + * Decompresses data with some external library. + */ + protected abstract ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception; + + @BeforeEach + public final void initChannel() { + channel = createChannel(); + } + + protected abstract EmbeddedChannel createChannel(); + + @AfterEach + public void destroyChannel() { + if (channel != null) { + channel.finishAndReleaseAll(); + channel = null; + } + } + + public static ByteBuf[] smallData() { + ByteBuf heap = Unpooled.wrappedBuffer(BYTES_SMALL); + ByteBuf direct = Unpooled.directBuffer(BYTES_SMALL.length); + direct.writeBytes(BYTES_SMALL); + return new ByteBuf[] {heap, direct}; + } + + public static ByteBuf[] largeData() { + ByteBuf heap = Unpooled.wrappedBuffer(BYTES_LARGE); + ByteBuf direct = Unpooled.directBuffer(BYTES_LARGE.length); + direct.writeBytes(BYTES_LARGE); + return new ByteBuf[] {heap, direct}; + } + + @ParameterizedTest + @MethodSource("smallData") + public void testCompressionOfSmallChunkOfData(ByteBuf data) throws Exception { + testCompression(data); + } + + @ParameterizedTest + @MethodSource("largeData") + public void testCompressionOfLargeChunkOfData(ByteBuf data) throws Exception { + testCompression(data); + } + + 
@ParameterizedTest + @MethodSource("largeData") + public void testCompressionOfBatchedFlowOfData(ByteBuf data) throws Exception { + testCompressionOfBatchedFlow(data); + } + + protected void testCompression(final ByteBuf data) throws Exception { + final int dataLength = data.readableBytes(); + assertTrue(channel.writeOutbound(data.retain())); + assertTrue(channel.finish()); + assertEquals(0, data.readableBytes()); + + ByteBuf decompressed = readDecompressed(dataLength); + assertEquals(data.resetReaderIndex(), decompressed); + + decompressed.release(); + data.release(); + } + + protected void testCompressionOfBatchedFlow(final ByteBuf data) throws Exception { + final int dataLength = data.readableBytes(); + int written = 0, length = rand.nextInt(100); + while (written + length < dataLength) { + ByteBuf in = data.retainedSlice(written, length); + assertTrue(channel.writeOutbound(in)); + assertEquals(0, in.readableBytes()); + written += length; + length = rand.nextInt(100); + } + ByteBuf in = data.retainedSlice(written, dataLength - written); + assertTrue(channel.writeOutbound(in)); + assertTrue(channel.finish()); + assertEquals(0, in.readableBytes()); + + ByteBuf decompressed = readDecompressed(dataLength); + assertEquals(data, decompressed); + + decompressed.release(); + data.release(); + } + + protected ByteBuf readDecompressed(final int dataLength) throws Exception { + CompositeByteBuf compressed = Unpooled.compositeBuffer(); + ByteBuf msg; + while ((msg = channel.readOutbound()) != null) { + compressed.addComponent(true, msg); + } + return decompress(compressed, dataLength); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractIntegrationTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractIntegrationTest.java new file mode 100644 index 0000000..3863a3f --- /dev/null +++ 
b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/AbstractIntegrationTest.java @@ -0,0 +1,185 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Random; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class AbstractIntegrationTest { + + protected static final Random rand = new Random(); + + protected EmbeddedChannel encoder; + protected EmbeddedChannel decoder; + + protected abstract EmbeddedChannel createEncoder(); + protected abstract EmbeddedChannel createDecoder(); + + public void initChannels() { + encoder = createEncoder(); + decoder = createDecoder(); + } + + public void closeChannels() { + encoder.close(); + for (;;) { + Object msg = encoder.readOutbound(); + if (msg == null) { + break; + 
} + ReferenceCountUtil.release(msg); + } + + decoder.close(); + for (;;) { + Object msg = decoder.readInbound(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + } + + @Test + public void testEmpty() throws Exception { + testIdentity(EmptyArrays.EMPTY_BYTES, true); + testIdentity(EmptyArrays.EMPTY_BYTES, false); + } + + @Test + public void testOneByte() throws Exception { + final byte[] data = { 'A' }; + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testTwoBytes() throws Exception { + final byte[] data = { 'B', 'A' }; + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testRegular() throws Exception { + final byte[] data = ("Netty is a NIO client server framework which enables " + + "quick and easy development of network applications such as protocol " + + "servers and clients.").getBytes(CharsetUtil.UTF_8); + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testLargeRandom() throws Exception { + final byte[] data = new byte[1024 * 1024]; + rand.nextBytes(data); + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testPartRandom() throws Exception { + final byte[] data = new byte[10240]; + rand.nextBytes(data); + for (int i = 0; i < 1024; i++) { + data[i] = 2; + } + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testCompressible() throws Exception { + final byte[] data = new byte[10240]; + for (int i = 0; i < data.length; i++) { + data[i] = i % 4 != 0 ? 
0 : (byte) rand.nextInt(); + } + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testLongBlank() throws Exception { + final byte[] data = new byte[102400]; + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testLongSame() throws Exception { + final byte[] data = new byte[102400]; + Arrays.fill(data, (byte) 123); + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void testSequential() throws Exception { + final byte[] data = new byte[1024]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + testIdentity(data, true); + testIdentity(data, false); + } + + protected void testIdentity(final byte[] data, boolean heapBuffer) { + initChannels(); + final ByteBuf in = heapBuffer? Unpooled.wrappedBuffer(data) : + Unpooled.directBuffer(data.length).writeBytes(data); + final CompositeByteBuf compressed = Unpooled.compositeBuffer(); + final CompositeByteBuf decompressed = Unpooled.compositeBuffer(); + + try { + assertTrue(encoder.writeOutbound(in.retain())); + assertTrue(encoder.finish()); + + ByteBuf msg; + while ((msg = encoder.readOutbound()) != null) { + compressed.addComponent(true, msg); + } + assertThat(compressed, is(notNullValue())); + + decoder.writeInbound(compressed.retain()); + assertFalse(compressed.isReadable()); + while ((msg = decoder.readInbound()) != null) { + decompressed.addComponent(true, msg); + } + in.readerIndex(0); + assertEquals(in, decompressed); + } finally { + compressed.release(); + decompressed.release(); + in.release(); + closeChannels(); + } + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliDecoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliDecoderTest.java new file mode 100644 index 0000000..104a2d2 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliDecoderTest.java 
@@ -0,0 +1,162 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.compression; + +import com.aayushatharva.brotli4j.encoder.BrotliOutputStream; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Random; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class BrotliDecoderTest { + + private static Random RANDOM; + private static final byte[] BYTES_SMALL = new byte[256]; + private static final byte[] BYTES_LARGE = new byte[256 * 1024]; + private static byte[] COMPRESSED_BYTES_SMALL; + private static byte[] COMPRESSED_BYTES_LARGE; + + @BeforeAll + static void setUp() { + try { + Brotli.ensureAvailability(); + + RANDOM = new Random(); + fillArrayWithCompressibleData(BYTES_SMALL); + fillArrayWithCompressibleData(BYTES_LARGE); + COMPRESSED_BYTES_SMALL = compress(BYTES_SMALL); + COMPRESSED_BYTES_LARGE = compress(BYTES_LARGE); + } catch (Throwable 
throwable) { + throw new ExceptionInInitializerError(throwable); + } + } + + private static final ByteBuf WRAPPED_BYTES_SMALL = Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer(BYTES_SMALL)).asReadOnly(); + private static final ByteBuf WRAPPED_BYTES_LARGE = Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer(BYTES_LARGE)).asReadOnly(); + + private static void fillArrayWithCompressibleData(byte[] array) { + for (int i = 0; i < array.length; i++) { + array[i] = i % 4 != 0 ? 0 : (byte) RANDOM.nextInt(); + } + } + + private static byte[] compress(byte[] data) throws IOException { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + BrotliOutputStream brotliOs = new BrotliOutputStream(os); + brotliOs.write(data); + brotliOs.close(); + return os.toByteArray(); + } + + private EmbeddedChannel channel; + + @BeforeEach + public void initChannel() { + channel = new EmbeddedChannel(new BrotliDecoder()); + } + + @AfterEach + public void destroyChannel() { + if (channel != null) { + channel.finishAndReleaseAll(); + channel = null; + } + } + + public static ByteBuf[] smallData() { + ByteBuf heap = Unpooled.wrappedBuffer(COMPRESSED_BYTES_SMALL); + ByteBuf direct = Unpooled.directBuffer(COMPRESSED_BYTES_SMALL.length); + direct.writeBytes(COMPRESSED_BYTES_SMALL); + return new ByteBuf[]{heap, direct}; + } + + public static ByteBuf[] largeData() { + ByteBuf heap = Unpooled.wrappedBuffer(COMPRESSED_BYTES_LARGE); + ByteBuf direct = Unpooled.directBuffer(COMPRESSED_BYTES_LARGE.length); + direct.writeBytes(COMPRESSED_BYTES_LARGE); + return new ByteBuf[]{heap, direct}; + } + + @ParameterizedTest + @MethodSource("smallData") + public void testDecompressionOfSmallChunkOfData(ByteBuf data) { + testDecompression(WRAPPED_BYTES_SMALL.duplicate(), data); + } + + @ParameterizedTest + @MethodSource("largeData") + public void testDecompressionOfLargeChunkOfData(ByteBuf data) { + testDecompression(WRAPPED_BYTES_LARGE.duplicate(), data); + } + + @ParameterizedTest + 
@MethodSource("largeData") + public void testDecompressionOfBatchedFlowOfData(ByteBuf data) { + testDecompressionOfBatchedFlow(WRAPPED_BYTES_LARGE, data); + } + + private void testDecompression(final ByteBuf expected, final ByteBuf data) { + assertTrue(channel.writeInbound(data)); + + ByteBuf decompressed = readDecompressed(channel); + assertEquals(expected, decompressed); + + decompressed.release(); + } + + private void testDecompressionOfBatchedFlow(final ByteBuf expected, final ByteBuf data) { + final int compressedLength = data.readableBytes(); + int written = 0, length = RANDOM.nextInt(100); + while (written + length < compressedLength) { + ByteBuf compressedBuf = data.retainedSlice(written, length); + channel.writeInbound(compressedBuf); + written += length; + length = RANDOM.nextInt(100); + } + ByteBuf compressedBuf = data.slice(written, compressedLength - written); + assertTrue(channel.writeInbound(compressedBuf.retain())); + + ByteBuf decompressedBuf = readDecompressed(channel); + assertEquals(expected, decompressedBuf); + + decompressedBuf.release(); + data.release(); + } + + private static ByteBuf readDecompressed(final EmbeddedChannel channel) { + CompositeByteBuf decompressed = Unpooled.compositeBuffer(); + ByteBuf msg; + while ((msg = channel.readInbound()) != null) { + decompressed.addComponent(true, msg); + } + return decompressed; + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliEncoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliEncoderTest.java new file mode 100644 index 0000000..4678129 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/BrotliEncoderTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.BeforeAll; + +public class BrotliEncoderTest extends AbstractEncoderTest { + + private EmbeddedChannel ENCODER_CHANNEL; + private EmbeddedChannel DECODER_CHANNEL; + + @BeforeAll + static void setUp() { + try { + Brotli.ensureAvailability(); + } catch (Throwable throwable) { + throw new ExceptionInInitializerError(throwable); + } + } + + @Override + public EmbeddedChannel createChannel() { + // Setup Encoder and Decoder + ENCODER_CHANNEL = new EmbeddedChannel(new BrotliEncoder()); + DECODER_CHANNEL = new EmbeddedChannel(new BrotliDecoder()); + + // Return the main channel (Encoder) + return ENCODER_CHANNEL; + } + + @Override + public void destroyChannel() { + ENCODER_CHANNEL.finishAndReleaseAll(); + DECODER_CHANNEL.finishAndReleaseAll(); + } + + @Override + protected ByteBuf decompress(ByteBuf compressed, int originalLength) { + DECODER_CHANNEL.writeInbound(compressed); + + ByteBuf aggregatedBuffer = Unpooled.buffer(); + ByteBuf decompressed = DECODER_CHANNEL.readInbound(); + while (decompressed != null) { + aggregatedBuffer.writeBytes(decompressed); + + decompressed.release(); + decompressed = DECODER_CHANNEL.readInbound(); + } + + return aggregatedBuffer; + } + + @Override + protected ByteBuf readDecompressed(final int dataLength) throws Exception { + CompositeByteBuf decompressed = 
Unpooled.compositeBuffer(); + ByteBuf msg; + while ((msg = channel.readOutbound()) != null) { + if (msg.isReadable()) { + decompressed.addComponent(true, decompress(msg, -1)); + } else { + msg.release(); + } + } + return decompressed; + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ByteBufChecksumTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ByteBufChecksumTest.java new file mode 100644 index 0000000..bf3af61 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ByteBufChecksumTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import net.jpountz.xxhash.XXHashFactory; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.util.Random; +import java.util.zip.Adler32; +import java.util.zip.CRC32; +import java.util.zip.Checksum; + +import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class ByteBufChecksumTest { + + private static final byte[] BYTE_ARRAY = new byte[1024]; + + @BeforeAll + public static void setUp() { + new Random().nextBytes(BYTE_ARRAY); + } + + @Test + public void testHeapByteBufUpdate() { + testUpdate(Unpooled.wrappedBuffer(BYTE_ARRAY)); + } + + @Test + public void testDirectByteBufUpdate() { + ByteBuf buf = Unpooled.directBuffer(BYTE_ARRAY.length); + buf.writeBytes(BYTE_ARRAY); + testUpdate(buf); + } + + private static void testUpdate(ByteBuf buf) { + try { + // all variations of xxHash32: slow and naive, optimised, wrapped optimised; + // the last two should be literally identical, but it's best to guard against + // an accidental regression in ByteBufChecksum#wrapChecksum(Checksum) + testUpdate(xxHash32(DEFAULT_SEED), ByteBufChecksum.wrapChecksum(xxHash32(DEFAULT_SEED)), buf); + testUpdate(xxHash32(DEFAULT_SEED), new Lz4XXHash32(DEFAULT_SEED), buf); + testUpdate(xxHash32(DEFAULT_SEED), ByteBufChecksum.wrapChecksum(new Lz4XXHash32(DEFAULT_SEED)), buf); + + // CRC32 and Adler32, special-cased to use ReflectiveByteBufChecksum + testUpdate(new CRC32(), ByteBufChecksum.wrapChecksum(new CRC32()), buf); + testUpdate(new Adler32(), ByteBufChecksum.wrapChecksum(new Adler32()), buf); + } finally { + buf.release(); + } + } + + private static void testUpdate(Checksum checksum, ByteBufChecksum wrapped, ByteBuf buf) { + testUpdate(checksum, wrapped, buf, 0, BYTE_ARRAY.length); + testUpdate(checksum, wrapped, buf, 0, BYTE_ARRAY.length - 
1); + testUpdate(checksum, wrapped, buf, 1, BYTE_ARRAY.length - 1); + testUpdate(checksum, wrapped, buf, 1, BYTE_ARRAY.length - 2); + } + + private static void testUpdate(Checksum checksum, ByteBufChecksum wrapped, ByteBuf buf, int off, int len) { + checksum.reset(); + wrapped.reset(); + + checksum.update(BYTE_ARRAY, off, len); + wrapped.update(buf, off, len); + + assertEquals(checksum.getValue(), wrapped.getValue()); + } + + private static Checksum xxHash32(int seed) { + return XXHashFactory.fastestInstance().newStreamingHash32(seed).asChecksum(); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2DecoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2DecoderTest.java new file mode 100644 index 0000000..620e57a --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2DecoderTest.java @@ -0,0 +1,200 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.commons.compress.compressors.bzip2.BZip2CompressorOutputStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.ByteArrayOutputStream; +import java.util.Arrays; + +import static io.netty.bzip2.Bzip2Constants.*; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +public class Bzip2DecoderTest extends AbstractDecoderTest { + + private static final byte[] DATA = { 0x42, 0x5A, 0x68, 0x37, 0x31, 0x41, 0x59, 0x26, 0x53, + 0x59, 0x77, 0x7B, (byte) 0xCA, (byte) 0xC0, 0x00, 0x00, + 0x00, 0x05, (byte) 0x80, 0x00, 0x01, 0x02, 0x00, 0x04, + 0x20, 0x20, 0x00, 0x30, (byte) 0xCD, 0x34, 0x19, (byte) 0xA6, + (byte) 0x89, (byte) 0x99, (byte) 0xC5, (byte) 0xDC, (byte) 0x91, + 0x4E, 0x14, 0x24, 0x1D, (byte) 0xDE, (byte) 0xF2, (byte) 0xB0, 0x00 }; + + public Bzip2DecoderTest() throws Exception { + } + + @Override + protected EmbeddedChannel createChannel() { + return new EmbeddedChannel(new Bzip2Decoder()); + } + + private void writeInboundDestroyAndExpectDecompressionException(ByteBuf in) { + try { + channel.writeInbound(in); + } finally { + try { + destroyChannel(); + fail(); + } catch (io.netty.bzip2.DecompressionException ignored) { + // expected + } + } + } + + @Test + public void testUnexpectedStreamIdentifier() { + final ByteBuf in = Unpooled.buffer(); + in.writeLong(1823080128301928729L); //random value + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + writeInboundDestroyAndExpectDecompressionException(in); + } + }, "Unexpected stream identifier contents"); + } + + @Test + public void testInvalidBlockSize() { + final ByteBuf in = Unpooled.buffer(); + in.writeMedium(MAGIC_NUMBER); + in.writeByte('0'); //incorrect block 
size + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "block size is invalid"); + } + + @Test + public void testBadBlockHeader() { + final ByteBuf in = Unpooled.buffer(); + in.writeMedium(MAGIC_NUMBER); + in.writeByte('1'); //block size + in.writeMedium(11); //incorrect block header + in.writeMedium(11); //incorrect block header + in.writeInt(11111); //block CRC + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "bad block header"); + } + + @Test + public void testStreamCrcErrorOfEmptyBlock() { + final ByteBuf in = Unpooled.buffer(); + in.writeMedium(MAGIC_NUMBER); + in.writeByte('1'); //block size + in.writeMedium(END_OF_STREAM_MAGIC_1); + in.writeMedium(END_OF_STREAM_MAGIC_2); + in.writeInt(1); //wrong storedCombinedCRC + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "stream CRC error"); + } + + @Test + public void testStreamCrcError() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[41] = (byte) 0xDD; + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + tryDecodeAndCatchBufLeaks(channel, Unpooled.wrappedBuffer(data)); + } + }, "stream CRC error"); + } + + @Test + public void testIncorrectHuffmanGroupsNumber() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[25] = 0x70; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "incorrect huffman groups number"); + } + + @Test + public void testIncorrectSelectorsNumber() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[25] = 0x2F; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + 
assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "incorrect selectors number"); + } + + @Test + public void testBlockCrcError() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[11] = 0x77; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + writeInboundDestroyAndExpectDecompressionException(in); + } + }, "block CRC error"); + } + + @Test + public void testStartPointerInvalid() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[14] = (byte) 0xFF; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(io.netty.bzip2.DecompressionException.class, new Executable() { + @Override + public void execute() { + writeInboundDestroyAndExpectDecompressionException(in); + } + }, "start pointer invalid"); + } + + @Override + protected byte[] compress(byte[] data) throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + BZip2CompressorOutputStream bZip2Os = new BZip2CompressorOutputStream(os, MIN_BLOCK_SIZE); + bZip2Os.write(data); + bZip2Os.close(); + + return os.toByteArray(); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2EncoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2EncoderTest.java new file mode 100644 index 0000000..e23a2e1 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2EncoderTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.apache.commons.compress.compressors.bzip2.BZip2CompressorInputStream; + +import java.io.InputStream; + +import static io.netty.bzip2.Bzip2Constants.*; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class Bzip2EncoderTest extends AbstractEncoderTest { + + @Override + protected EmbeddedChannel createChannel() { + return new EmbeddedChannel(new Bzip2Encoder(MIN_BLOCK_SIZE)); + } + + @Override + protected ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception { + InputStream is = new ByteBufInputStream(compressed, true); + BZip2CompressorInputStream bzip2Is = null; + byte[] decompressed = new byte[originalLength]; + try { + bzip2Is = new BZip2CompressorInputStream(is); + int remaining = originalLength; + while (remaining > 0) { + int read = bzip2Is.read(decompressed, originalLength - remaining, remaining); + if (read > 0) { + remaining -= read; + } else { + break; + } + } + assertEquals(-1, bzip2Is.read()); + } finally { + if (bzip2Is != null) { + bzip2Is.close(); + } else { + is.close(); + } + } + + return Unpooled.wrappedBuffer(decompressed); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2IntegrationTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2IntegrationTest.java new file 
mode 100644 index 0000000..2b358eb --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Bzip2IntegrationTest.java @@ -0,0 +1,56 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +public class Bzip2IntegrationTest extends AbstractIntegrationTest { + + @Override + protected EmbeddedChannel createEncoder() { + return new EmbeddedChannel(new Bzip2Encoder()); + } + + @Override + protected EmbeddedChannel createDecoder() { + return new EmbeddedChannel(new Bzip2Decoder()); + } + + @Test + public void test3Tables() throws Exception { + byte[] data = new byte[500]; + rand.nextBytes(data); + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void test4Tables() throws Exception { + byte[] data = new byte[1100]; + rand.nextBytes(data); + testIdentity(data, true); + testIdentity(data, false); + } + + @Test + public void test5Tables() throws Exception { + byte[] data = new byte[2300]; + rand.nextBytes(data); + testIdentity(data, true); + testIdentity(data, false); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/FastLzIntegrationTest.java 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;

/**
 * Integration tests for the FastLZ frame codec. The outer class pairs an
 * encoder and decoder whose checksum flags are chosen at random and
 * independently of each other; the nested classes pin down the
 * checksum-enabled and randomly-mixed configurations explicitly.
 * The inherited identity test is overridden to feed data in randomly sized
 * fragments, exercising the decoder's buffering across chunk boundaries.
 */
public class FastLzIntegrationTest extends AbstractIntegrationTest {

    // Both sides always validate checksums.
    public static class TestWithChecksum extends AbstractIntegrationTest {

        @Override
        protected EmbeddedChannel createEncoder() {
            return new EmbeddedChannel(new FastLzFrameEncoder(true));
        }

        @Override
        protected EmbeddedChannel createDecoder() {
            return new EmbeddedChannel(new FastLzFrameDecoder(true));
        }
    }

    // Encoder and decoder each flip a coin for checksum validation — the
    // decoder must cope with either stream format regardless of its own flag.
    public static class TestRandomChecksum extends AbstractIntegrationTest {

        @Override
        protected EmbeddedChannel createEncoder() {
            return new EmbeddedChannel(new FastLzFrameEncoder(rand.nextBoolean()));
        }

        @Override
        protected EmbeddedChannel createDecoder() {
            return new EmbeddedChannel(new FastLzFrameDecoder(rand.nextBoolean()));
        }
    }

    @Override
    protected EmbeddedChannel createEncoder() {
        return new EmbeddedChannel(new FastLzFrameEncoder(rand.nextBoolean()));
    }

    @Override
    protected EmbeddedChannel createDecoder() {
        return new EmbeddedChannel(new FastLzFrameDecoder(rand.nextBoolean()));
    }

    // Overrides the straight-through identity check with a batched variant:
    // the payload is written to the encoder in random slices of 0-99 bytes,
    // and the compressed output is likewise re-fragmented before decoding.
    @Override // test batched flow of data
    protected void testIdentity(final byte[] data, boolean heapBuffer) {
        initChannels();
        final ByteBuf original = heapBuffer? Unpooled.wrappedBuffer(data) :
                Unpooled.directBuffer(data.length).writeBytes(data);
        final CompositeByteBuf compressed = Unpooled.compositeBuffer();
        final CompositeByteBuf decompressed = Unpooled.compositeBuffer();

        try {
            // Feed the payload in random-length slices; nextInt(100) may be 0,
            // so zero-length writes are exercised as well.
            int written = 0, length = rand.nextInt(100);
            while (written + length < data.length) {
                ByteBuf in = Unpooled.wrappedBuffer(data, written, length);
                encoder.writeOutbound(in);
                written += length;
                length = rand.nextInt(100);
            }
            // Flush whatever remains, then finish the stream.
            ByteBuf in = Unpooled.wrappedBuffer(data, written, data.length - written);
            encoder.writeOutbound(in);
            encoder.finish();

            // Collect all encoder output into one composite buffer.
            ByteBuf msg;
            while ((msg = encoder.readOutbound()) != null) {
                compressed.addComponent(true, msg);
            }
            assertThat(compressed, is(notNullValue()));

            // Re-fragment the compressed bytes with fresh random slice sizes
            // before handing them to the decoder.
            final byte[] compressedArray = new byte[compressed.readableBytes()];
            compressed.readBytes(compressedArray);
            written = 0;
            length = rand.nextInt(100);
            while (written + length < compressedArray.length) {
                in = Unpooled.wrappedBuffer(compressedArray, written, length);
                decoder.writeInbound(in);
                written += length;
                length = rand.nextInt(100);
            }
            in = Unpooled.wrappedBuffer(compressedArray, written, compressedArray.length - written);
            decoder.writeInbound(in);

            // The copy above consumed the composite; it must be fully read.
            assertFalse(compressed.isReadable());
            while ((msg = decoder.readInbound()) != null) {
                decompressed.addComponent(true, msg);
            }
            assertEquals(original, decompressed);
        } finally {
            // Release all buffers and close both channels even on failure.
            compressed.release();
            decompressed.release();
            original.release();
            closeChannels();
        }
    }
}
+ */ +package io.netty.handler.codec.compression; + +public class JZlibTest extends ZlibTest { + + @Override + protected ZlibEncoder createEncoder(ZlibWrapper wrapper) { + return new JZlibEncoder(wrapper); + } + + @Override + protected ZlibDecoder createDecoder(ZlibWrapper wrapper, int maxAllocation) { + return new JZlibDecoder(wrapper, maxAllocation); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/JdkZlibTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/JdkZlibTest.java new file mode 100644 index 0000000..d6d3626 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/JdkZlibTest.java @@ -0,0 +1,213 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.buffer.AbstractByteBufAllocator;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.CharsetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.PlatformDependent;
import org.apache.commons.compress.utils.IOUtils;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.util.Arrays;
import java.util.Queue;
import java.util.zip.GZIPOutputStream;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;


/**
 * Runs the shared {@link ZlibTest} suite against the JDK-based zlib codec and
 * adds JDK-decoder specific cases: concatenated gzip members, byte-by-byte
 * fragmentation, a header that directly follows a footer, and a large encode
 * under a capacity-limited allocator.
 *
 * NOTE(review): generic type parameters appear to have been stripped by the
 * extraction (e.g. {@code Queue} below is a raw type) — confirm against the
 * original Netty sources.
 */
public class JdkZlibTest extends ZlibTest {

    @Override
    protected ZlibEncoder createEncoder(ZlibWrapper wrapper) {
        return new JdkZlibEncoder(wrapper);
    }

    @Override
    protected ZlibDecoder createDecoder(ZlibWrapper wrapper, int maxAllocation) {
        return new JdkZlibDecoder(wrapper, maxAllocation);
    }

    // The base-class scenario is expected to FAIL with the JDK decoder:
    // the inherited test is wrapped so the DecompressionException it raises
    // becomes the asserted outcome. NOTE(review): presumably the JDK decoder
    // cannot handle this ZLIB_OR_NONE variant — confirm intent.
    @Test
    @Override
    public void testZLIB_OR_NONE3() throws Exception {
        assertThrows(DecompressionException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                JdkZlibTest.super.testZLIB_OR_NONE3();
            }
        });
    }

    @Test
    // verifies backward compatibility
    public void testConcatenatedStreamsReadFirstOnly() throws IOException {
        // Default decoder (decompressConcatenated == false): only the first
        // gzip member of /multiple.gz ("a") is produced.
        EmbeddedChannel chDecoderGZip = new EmbeddedChannel(createDecoder(ZlibWrapper.GZIP));

        try {
            byte[] bytes = IOUtils.toByteArray(getClass().getResourceAsStream("/multiple.gz"));

            assertTrue(chDecoderGZip.writeInbound(Unpooled.copiedBuffer(bytes)));
            Queue messages = chDecoderGZip.inboundMessages();
            assertEquals(1, messages.size());

            ByteBuf msg = (ByteBuf) messages.poll();
            assertEquals("a", msg.toString(CharsetUtil.UTF_8));
            ReferenceCountUtil.release(msg);
        } finally {
            // finish() must not surface any additional pending messages.
            assertFalse(chDecoderGZip.finish());
            chDecoderGZip.close();
        }
    }

    @Test
    public void testConcatenatedStreamsReadFully() throws IOException {
        // JdkZlibDecoder(true): both gzip members ("a", then "b") decoded.
        EmbeddedChannel chDecoderGZip = new EmbeddedChannel(new JdkZlibDecoder(true));

        try {
            byte[] bytes = IOUtils.toByteArray(getClass().getResourceAsStream("/multiple.gz"));

            assertTrue(chDecoderGZip.writeInbound(Unpooled.copiedBuffer(bytes)));
            Queue messages = chDecoderGZip.inboundMessages();
            assertEquals(2, messages.size());

            for (String s : Arrays.asList("a", "b")) {
                ByteBuf msg = (ByteBuf) messages.poll();
                assertEquals(s, msg.toString(CharsetUtil.UTF_8));
                ReferenceCountUtil.release(msg);
            }
        } finally {
            assertFalse(chDecoderGZip.finish());
            chDecoderGZip.close();
        }
    }

    @Test
    public void testConcatenatedStreamsReadFullyWhenFragmented() throws IOException {
        // Same as above, but the input arrives one byte at a time so the
        // decoder's internal state machine crosses every boundary.
        EmbeddedChannel chDecoderGZip = new EmbeddedChannel(new JdkZlibDecoder(true));

        try {
            byte[] bytes = IOUtils.toByteArray(getClass().getResourceAsStream("/multiple.gz"));

            // Let's feed the input byte by byte to simulate fragmentation.
            ByteBuf buf = Unpooled.copiedBuffer(bytes);
            boolean written = false;
            while (buf.isReadable()) {
                written |= chDecoderGZip.writeInbound(buf.readRetainedSlice(1));
            }
            buf.release();

            assertTrue(written);
            Queue messages = chDecoderGZip.inboundMessages();
            assertEquals(2, messages.size());

            for (String s : Arrays.asList("a", "b")) {
                ByteBuf msg = (ByteBuf) messages.poll();
                assertEquals(s, msg.toString(CharsetUtil.UTF_8));
                ReferenceCountUtil.release(msg);
            }
        } finally {
            assertFalse(chDecoderGZip.finish());
            chDecoderGZip.close();
        }
    }

    @Test
    public void testDecodeWithHeaderFollowingFooter() throws Exception {
        // Two identical gzip streams back to back; the split point is one
        // byte before the first footer ends, so the second header directly
        // follows the (partially delivered) first footer.
        byte[] bytes = new byte[1024];
        PlatformDependent.threadLocalRandom().nextBytes(bytes);
        ByteArrayOutputStream bytesOut = new ByteArrayOutputStream();
        GZIPOutputStream out = new GZIPOutputStream(bytesOut);
        out.write(bytes);
        out.close();

        byte[] compressed = bytesOut.toByteArray();
        ByteBuf buffer = Unpooled.buffer().writeBytes(compressed).writeBytes(compressed);
        EmbeddedChannel channel = new EmbeddedChannel(new JdkZlibDecoder(ZlibWrapper.GZIP, true));
        // Write it into the Channel in a way that we were able to decompress the first data completely but not the
        // whole footer.
        assertTrue(channel.writeInbound(buffer.readRetainedSlice(compressed.length - 1)));
        assertTrue(channel.writeInbound(buffer));
        assertTrue(channel.finish());

        // Both members must decode to the original payload.
        ByteBuf uncompressedBuffer = Unpooled.wrappedBuffer(bytes);
        ByteBuf read = channel.readInbound();
        assertEquals(uncompressedBuffer, read);
        read.release();

        read = channel.readInbound();
        assertEquals(uncompressedBuffer, read);
        read.release();

        assertNull(channel.readInbound());
        uncompressedBuffer.release();
    }

    @Test
    public void testLargeEncode() throws Exception {
        // construct a 128M buffer out of many times the same 1M buffer :)
        byte[] smallArray = new byte[1024 * 1024];
        byte[][] arrayOfArrays = new byte[128][];
        Arrays.fill(arrayOfArrays, smallArray);
        ByteBuf bigBuffer = Unpooled.wrappedBuffer(arrayOfArrays);

        // The limited allocator forces the encoder to cope with output
        // buffers capped at 1M while encoding 128M of input.
        EmbeddedChannel channel = new EmbeddedChannel(new JdkZlibEncoder(ZlibWrapper.NONE));
        channel.config().setAllocator(new LimitedByteBufAllocator(channel.alloc()));
        assertTrue(channel.writeOutbound(bigBuffer));
        assertTrue(channel.finish());
        channel.checkException();
        assertTrue(channel.releaseOutbound());
    }

    /**
     * Allocator that will limit buffer capacity to 1M.
     */
    private static final class LimitedByteBufAllocator extends AbstractByteBufAllocator {
        private static final int MAX = 1024 * 1024;

        private final ByteBufAllocator wrapped;

        LimitedByteBufAllocator(ByteBufAllocator wrapped) {
            this.wrapped = wrapped;
        }

        @Override
        public boolean isDirectBufferPooled() {
            return wrapped.isDirectBufferPooled();
        }

        @Override
        protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) {
            // Clamp only the maximum capacity; initial capacity is passed through.
            return wrapped.heapBuffer(initialCapacity, Math.min(maxCapacity, MAX));
        }

        @Override
        protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) {
            return wrapped.directBuffer(initialCapacity, Math.min(maxCapacity, MAX));
        }
    }
}
+ */ +package io.netty.handler.codec.compression; + +import io.netty.channel.embedded.EmbeddedChannel; + +import static com.ning.compress.lzf.LZFChunk.MAX_CHUNK_LEN; + +public class LengthAwareLzfIntegrationTest extends LzfIntegrationTest { + + @Override + protected EmbeddedChannel createEncoder() { + return new EmbeddedChannel(new LzfEncoder(false, MAX_CHUNK_LEN, 2 * 1024 * 1024)); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java new file mode 100644 index 0000000..01338f6 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameDecoderTest.java @@ -0,0 +1,160 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import net.jpountz.lz4.LZ4BlockOutputStream; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.ByteArrayOutputStream; +import java.util.Arrays; + +import static io.netty.handler.codec.compression.Lz4Constants.*; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class Lz4FrameDecoderTest extends AbstractDecoderTest { + + private static final byte[] DATA = { 0x4C, 0x5A, 0x34, 0x42, 0x6C, 0x6F, 0x63, 0x6B, // magic bytes + 0x16, // token + 0x05, 0x00, 0x00, 0x00, 0x05, 0x00, 0x00, 0x00, // compr. and decompr. length + (byte) 0x86, (byte) 0xE4, 0x79, 0x0F, // checksum + 0x4E, 0x65, 0x74, 0x74, 0x79, // data + 0x4C, 0x5A, 0x34, 0x42, 0x6C, 0x6F, 0x63, 0x6B, // magic bytes + 0x16, // token + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, // last empty block + 0x00, 0x00, 0x00, 0x00 }; + + public Lz4FrameDecoderTest() throws Exception { + } + + @Override + protected EmbeddedChannel createChannel() { + return new EmbeddedChannel(new Lz4FrameDecoder(true)); + } + + @Test + public void testUnexpectedBlockIdentifier() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[1] = 0x00; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "unexpected block identifier"); + } + + @Test + public void testInvalidCompressedLength() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[12] = (byte) 0xFF; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "invalid compressedLength"); + } + + @Test + public void testInvalidDecompressedLength() { + 
final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[16] = (byte) 0xFF; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "invalid decompressedLength"); + } + + @Test + public void testDecompressedAndCompressedLengthMismatch() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[13] = 0x01; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "mismatch"); + } + + @Test + public void testUnexpectedBlockType() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[8] = 0x36; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "unexpected blockType"); + } + + @Test + public void testMismatchingChecksum() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[17] = 0x01; + + final ByteBuf in = Unpooled.wrappedBuffer(data); + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "mismatching checksum"); + } + + @Test + public void testChecksumErrorOfLastBlock() { + final byte[] data = Arrays.copyOf(DATA, DATA.length); + data[44] = 0x01; + + assertThrows(DecompressionException.class, + new Executable() { + @Override + public void execute() { + tryDecodeAndCatchBufLeaks(channel, Unpooled.wrappedBuffer(data)); + } + }, "checksum error"); + } + + @Override + protected byte[] compress(byte[] data) throws Exception { + ByteArrayOutputStream os = new ByteArrayOutputStream(); + int size = MAX_BLOCK_SIZE + 1; + LZ4BlockOutputStream lz4Os = new LZ4BlockOutputStream(os, + rand.nextInt(size - MIN_BLOCK_SIZE) + MIN_BLOCK_SIZE); + lz4Os.write(data); 
+ lz4Os.close(); + + return os.toByteArray(); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java new file mode 100644 index 0000000..483e752 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/Lz4FrameEncoderTest.java @@ -0,0 +1,323 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.compression;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.ByteBufInputStream;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.EncoderException;
import java.util.concurrent.TimeUnit;
import net.jpountz.lz4.LZ4BlockInputStream;
import net.jpountz.lz4.LZ4Factory;
import net.jpountz.xxhash.XXHashFactory;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;

import java.io.InputStream;
import java.net.InetSocketAddress;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.atomic.AtomicReference;
import java.util.zip.Checksum;

import static io.netty.handler.codec.compression.Lz4Constants.DEFAULT_SEED;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.core.Is.is;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link Lz4FrameEncoder}: buffer allocation sizing, flush
 * semantics of the internal backing buffer, and write-after-close behavior.
 *
 * NOTE(review): generic type parameters appear to have been stripped by the
 * extraction (raw {@code ChannelInitializer} / {@code AtomicReference} below)
 * — confirm against the original Netty sources.
 */
public class Lz4FrameEncoderTest extends AbstractEncoderTest {
    /**
     * For the purposes of this test, if we pass this (very small) size of buffer into
     * {@link Lz4FrameEncoder#allocateBuffer(ChannelHandlerContext, ByteBuf, boolean)}, we should get back
     * an empty buffer.
     */
    private static final int NONALLOCATABLE_SIZE = 1;

    // Mocked context so allocateBuffer can be called outside a real pipeline.
    @Mock
    private ChannelHandlerContext ctx;

    /**
     * A {@link ByteBuf} for mocking purposes, largely because it's difficult to allocate to huge buffers.
     */
    @Mock
    private ByteBuf buffer;

    @BeforeEach
    public void setup() {
        MockitoAnnotations.initMocks(this);
        // Back the mocked context with the real default allocator.
        when(ctx.alloc()).thenReturn(ByteBufAllocator.DEFAULT);
    }

    @Override
    protected EmbeddedChannel createChannel() {
        return new EmbeddedChannel(new Lz4FrameEncoder());
    }

    // Decompresses with the reference LZ4 stream implementation, reading in a
    // loop until originalLength bytes were recovered (or EOF).
    @Override
    protected ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception {
        InputStream is = new ByteBufInputStream(compressed, true);
        LZ4BlockInputStream lz4Is = null;
        byte[] decompressed = new byte[originalLength];
        try {
            lz4Is = new LZ4BlockInputStream(is);
            int remaining = originalLength;
            while (remaining > 0) {
                int read = lz4Is.read(decompressed, originalLength - remaining, remaining);
                if (read > 0) {
                    remaining -= read;
                } else {
                    break;
                }
            }
            // The stream must be fully consumed.
            assertEquals(-1, lz4Is.read());
        } finally {
            // Close the LZ4 stream when it was opened, otherwise at least the
            // underlying ByteBuf stream.
            if (lz4Is != null) {
                lz4Is.close();
            } else {
                is.close();
            }
        }

        return Unpooled.wrappedBuffer(decompressed);
    }

    @Test
    public void testAllocateDirectBuffer() {
        final int blockSize = 100;
        testAllocateBuffer(blockSize, blockSize - 13, true);
        testAllocateBuffer(blockSize, blockSize * 5, true);
        testAllocateBuffer(blockSize, NONALLOCATABLE_SIZE, true);
    }

    @Test
    public void testAllocateHeapBuffer() {
        final int blockSize = 100;
        testAllocateBuffer(blockSize, blockSize - 13, false);
        testAllocateBuffer(blockSize, blockSize * 5, false);
        testAllocateBuffer(blockSize, NONALLOCATABLE_SIZE, false);
    }

    // Shared body of the two tests above: allocate an output buffer for an
    // input of bufSize and check its writability expectations.
    private void testAllocateBuffer(int blockSize, int bufSize, boolean preferDirect) {
        // allocate the input buffer to an arbitrary size less than the blockSize
        ByteBuf in = ByteBufAllocator.DEFAULT.buffer(bufSize, bufSize);
        in.writerIndex(in.capacity());

        ByteBuf out = null;
        try {
            Lz4FrameEncoder encoder = newEncoder(blockSize, Lz4FrameEncoder.DEFAULT_MAX_ENCODE_SIZE);
            out = encoder.allocateBuffer(ctx, in, preferDirect);
            assertNotNull(out);
            if (NONALLOCATABLE_SIZE == bufSize) {
                assertFalse(out.isWritable());
            } else {
                assertTrue(out.writableBytes() > 0);
                if (!preferDirect) {
                    // Only check if preferDirect is not true as if a direct buffer is returned or not depends on
                    // if sun.misc.Unsafe is present.
                    assertFalse(out.isDirect());
                }
            }
        } finally {
            in.release();
            if (out != null) {
                out.release();
            }
        }
    }

    @Test
    public void testAllocateDirectBufferExceedMaxEncodeSize() {
        // An input ten times larger than the configured maxEncodeSize must be
        // rejected with an EncoderException.
        final int maxEncodeSize = 1024;
        final Lz4FrameEncoder encoder = newEncoder(Lz4Constants.DEFAULT_BLOCK_SIZE, maxEncodeSize);
        int inputBufferSize = maxEncodeSize * 10;
        final ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(inputBufferSize, inputBufferSize);
        try {
            buf.writerIndex(inputBufferSize);
            assertThrows(EncoderException.class, new Executable() {
                @Override
                public void execute() {
                    encoder.allocateBuffer(ctx, buf, false);
                }
            });
        } finally {
            buf.release();
        }
    }

    // Builds a fully configured encoder (checksum + custom block/encode
    // sizes) and simulates its addition to a pipeline via handlerAdded.
    private Lz4FrameEncoder newEncoder(int blockSize, int maxEncodeSize) {
        Checksum checksum = XXHashFactory.fastestInstance().newStreamingHash32(DEFAULT_SEED).asChecksum();
        Lz4FrameEncoder encoder = new Lz4FrameEncoder(LZ4Factory.fastestInstance(), true,
                                                      blockSize,
                                                      checksum,
                                                      maxEncodeSize);
        encoder.handlerAdded(ctx);
        return encoder;
    }

    /**
     * This test might be invasive in terms of knowing what happens inside
     * {@link Lz4FrameEncoder#allocateBuffer(ChannelHandlerContext, ByteBuf, boolean)}, but this is the safest way
     * of testing the overflow conditions as allocating the huge buffers fails in many CI environments.
     */
    @Test
    public void testAllocateOnHeapBufferOverflowsOutputSize() {
        final int maxEncodeSize = Integer.MAX_VALUE;
        final Lz4FrameEncoder encoder = newEncoder(Lz4Constants.DEFAULT_BLOCK_SIZE, maxEncodeSize);
        // The mocked buffer pretends to hold Integer.MAX_VALUE readable bytes.
        when(buffer.readableBytes()).thenReturn(maxEncodeSize);
        buffer.writerIndex(maxEncodeSize);
        assertThrows(EncoderException.class, new Executable() {
            @Override
            public void execute() {
                encoder.allocateBuffer(ctx, buffer, false);
            }
        });
    }

    @Test
    public void testFlush() {
        // A write smaller than the block size is held in the encoder's
        // backing buffer until the channel is flushed.
        Lz4FrameEncoder encoder = new Lz4FrameEncoder();
        EmbeddedChannel channel = new EmbeddedChannel(encoder);
        int size = 27;
        ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(size, size);
        buf.writerIndex(size);
        assertEquals(0, encoder.getBackingBuffer().readableBytes());
        channel.write(buf);
        assertTrue(channel.outboundMessages().isEmpty());
        assertEquals(size, encoder.getBackingBuffer().readableBytes());
        channel.flush();
        assertTrue(channel.finish());
        assertTrue(channel.releaseOutbound());
        assertFalse(channel.releaseInbound());
    }

    @Test
    public void testAllocatingAroundBlockSize() {
        // Two writes that together cross one block boundary: the backing
        // buffer keeps exactly the overflow past the block size.
        int blockSize = 100;
        Lz4FrameEncoder encoder = newEncoder(blockSize, Lz4FrameEncoder.DEFAULT_MAX_ENCODE_SIZE);
        EmbeddedChannel channel = new EmbeddedChannel(encoder);

        int size = blockSize - 1;
        ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(size, size);
        buf.writerIndex(size);
        assertEquals(0, encoder.getBackingBuffer().readableBytes());
        channel.write(buf);
        assertEquals(size, encoder.getBackingBuffer().readableBytes());

        int nextSize = size - 1;
        buf = ByteBufAllocator.DEFAULT.buffer(nextSize, nextSize);
        buf.writerIndex(nextSize);
        channel.write(buf);
        assertEquals(size + nextSize - blockSize, encoder.getBackingBuffer().readableBytes());

        channel.flush();
        assertEquals(0, encoder.getBackingBuffer().readableBytes());
        assertTrue(channel.finish());
        assertTrue(channel.releaseOutbound());
        assertFalse(channel.releaseInbound());
    }

    @Test
    @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS)
    public void writingAfterClosedChannelDoesNotNPE() throws InterruptedException {
        // Regression test: writing through the encoder right after closing
        // the channel must fail gracefully, never with an NPE.
        EventLoopGroup group = new NioEventLoopGroup(2);
        Channel serverChannel = null;
        Channel clientChannel = null;
        final CountDownLatch latch = new CountDownLatch(1);
        final AtomicReference writeFailCauseRef = new AtomicReference();
        try {
            ServerBootstrap sb = new ServerBootstrap();
            sb.group(group);
            sb.channel(NioServerSocketChannel.class);
            sb.childHandler(new ChannelInitializer() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                }
            });

            Bootstrap bs = new Bootstrap();
            bs.group(group);
            bs.channel(NioSocketChannel.class);
            bs.handler(new ChannelInitializer() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    ch.pipeline().addLast(new Lz4FrameEncoder());
                }
            });

            serverChannel = sb.bind(new InetSocketAddress(0)).syncUninterruptibly().channel();
            clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel();

            final Channel finalClientChannel = clientChannel;
            // Close and write from the channel's own event loop so both
            // operations are serialized in the order shown.
            clientChannel.eventLoop().execute(new Runnable() {
                @Override
                public void run() {
                    finalClientChannel.close();
                    final int size = 27;
                    ByteBuf buf = ByteBufAllocator.DEFAULT.buffer(size, size);
                    finalClientChannel.writeAndFlush(buf.writerIndex(buf.writerIndex() + size))
                            .addListener(new ChannelFutureListener() {
                                @Override
                                public void operationComplete(ChannelFuture future) throws Exception {
                                    try {
                                        writeFailCauseRef.set(future.cause());
                                    } finally {
                                        latch.countDown();
                                    }
                                }
                            });
                }
            });
            latch.await();
            // The write must fail (channel closed), and neither the cause nor
            // its cause may be a NullPointerException.
            Throwable writeFailCause = writeFailCauseRef.get();
            assertNotNull(writeFailCause);
            Throwable writeFailCauseCause = writeFailCause.getCause();
            if (writeFailCauseCause != null) {
                assertThat(writeFailCauseCause, is(not(instanceOf(NullPointerException.class))));
            }
        } finally {
            if (serverChannel != null) {
                serverChannel.close();
            }
            if (clientChannel != null) {
                clientChannel.close();
            }
            group.shutdownGracefully();
        }
    }
}
+ */ +package io.netty.handler.codec.compression; + +import io.netty.channel.embedded.EmbeddedChannel; + +public class Lz4FrameIntegrationTest extends AbstractIntegrationTest { + + @Override + protected EmbeddedChannel createEncoder() { + return new EmbeddedChannel(new Lz4FrameEncoder()); + } + + @Override + protected EmbeddedChannel createDecoder() { + return new EmbeddedChannel(new Lz4FrameDecoder()); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfDecoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfDecoderTest.java new file mode 100644 index 0000000..551d39d --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfDecoderTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import com.ning.compress.lzf.LZFEncoder; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static com.ning.compress.lzf.LZFChunk.*; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class LzfDecoderTest extends AbstractDecoderTest { + + public LzfDecoderTest() throws Exception { + } + + @Override + protected EmbeddedChannel createChannel() { + return new EmbeddedChannel(new LzfDecoder()); + } + + @Test + public void testUnexpectedBlockIdentifier() { + final ByteBuf in = Unpooled.buffer(); + in.writeShort(0x1234); //random value + in.writeByte(BLOCK_TYPE_NON_COMPRESSED); + in.writeShort(0); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "unexpected block identifier"); + } + + @Test + public void testUnknownTypeOfChunk() { + final ByteBuf in = Unpooled.buffer(); + in.writeByte(BYTE_Z); + in.writeByte(BYTE_V); + in.writeByte(0xFF); //random value + in.writeInt(0); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }, "unknown type of chunk"); + } + + @Override + protected byte[] compress(byte[] data) throws Exception { + return LZFEncoder.encode(data); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfEncoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfEncoderTest.java new file mode 100644 index 0000000..e006a95 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfEncoderTest.java @@ -0,0 +1,39 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the 
Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import com.ning.compress.lzf.LZFDecoder; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; + +public class LzfEncoderTest extends AbstractEncoderTest { + + @Override + protected EmbeddedChannel createChannel() { + return new EmbeddedChannel(new LzfEncoder()); + } + + @Override + protected ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception { + byte[] compressedArray = new byte[compressed.readableBytes()]; + compressed.readBytes(compressedArray); + compressed.release(); + + byte[] decompressed = LZFDecoder.decode(compressedArray); + return Unpooled.wrappedBuffer(decompressed); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfIntegrationTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfIntegrationTest.java new file mode 100644 index 0000000..23cf3ba --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/LzfIntegrationTest.java @@ -0,0 +1,31 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.channel.embedded.EmbeddedChannel; + +public class LzfIntegrationTest extends AbstractIntegrationTest { + + @Override + protected EmbeddedChannel createEncoder() { + return new EmbeddedChannel(new LzfEncoder()); + } + + @Override + protected EmbeddedChannel createDecoder() { + return new EmbeddedChannel(new LzfDecoder()); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameDecoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameDecoderTest.java new file mode 100644 index 0000000..aee05aa --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameDecoderTest.java @@ -0,0 +1,225 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SnappyFrameDecoderTest { + private EmbeddedChannel channel; + + @BeforeEach + public void initChannel() { + channel = new EmbeddedChannel(new SnappyFrameDecoder()); + } + + @AfterEach + public void tearDown() { + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void testReservedUnskippableChunkTypeCausesError() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x03, 0x01, 0x00, 0x00, 0x00 + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }); + } + + @Test + public void testInvalidStreamIdentifierLength() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + -0x80, 0x05, 0x00, 0x00, 'n', 'e', 't', 't', 'y' + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }); + } + + @Test + public void testInvalidStreamIdentifierValue() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 's', 'n', 'e', 't', 't', 'y' + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }); + } + + @Test + public void testReservedSkippableBeforeStreamIdentifier() { + final ByteBuf in = 
Unpooled.wrappedBuffer(new byte[] { + -0x7f, 0x06, 0x00, 0x00, 's', 'n', 'e', 't', 't', 'y' + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }); + } + + @Test + public void testUncompressedDataBeforeStreamIdentifier() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x01, 0x05, 0x00, 0x00, 'n', 'e', 't', 't', 'y' + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeInbound(in); + } + }); + } + + @Test + public void testCompressedDataBeforeStreamIdentifier() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x00, 0x05, 0x00, 0x00, 'n', 'e', 't', 't', 'y' + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }); + } + + @Test + public void testReservedSkippableSkipsInput() { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + -0x7f, 0x05, 0x00, 0x00, 'n', 'e', 't', 't', 'y' + }); + + assertFalse(channel.writeInbound(in)); + assertNull(channel.readInbound()); + + assertFalse(in.isReadable()); + } + + @Test + public void testUncompressedDataAppendsToOut() { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x01, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 'n', 'e', 't', 't', 'y' + }); + + assertTrue(channel.writeInbound(in)); + + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { 'n', 'e', 't', 't', 'y' }); + ByteBuf actual = channel.readInbound(); + assertEquals(expected, actual); + + expected.release(); + actual.release(); + } + + @Test + public void testCompressedDataDecodesAndAppendsToOut() { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x00, 0x0B, 0x00, 
0x00, 0x00, 0x00, 0x00, 0x00, + 0x05, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79 // "netty" + }); + + assertTrue(channel.writeInbound(in)); + + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { 'n', 'e', 't', 't', 'y' }); + ByteBuf actual = channel.readInbound(); + + assertEquals(expected, actual); + + expected.release(); + actual.release(); + } + + // The following two tests differ in only the checksum provided for the literal + // uncompressed string "netty" + + @Test + public void testInvalidChecksumThrowsException() { + final EmbeddedChannel channel = new EmbeddedChannel(new SnappyFrameDecoder(true)); + try { + // checksum here is presented as 0 + final ByteBuf in = Unpooled.wrappedBuffer(new byte[]{ + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x01, 0x09, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 'n', 'e', 't', 't', 'y' + }); + + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(in); + } + }); + } finally { + channel.finishAndReleaseAll(); + } + } + + @Test + public void testInvalidChecksumDoesNotThrowException() { + EmbeddedChannel channel = new EmbeddedChannel(new SnappyFrameDecoder(true)); + try { + // checksum here is presented as a282986f (little endian) + ByteBuf in = Unpooled.wrappedBuffer(new byte[]{ + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x01, 0x09, 0x00, 0x00, 0x6f, -0x68, 0x2e, -0x47, 'n', 'e', 't', 't', 'y' + }); + + assertTrue(channel.writeInbound(in)); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { 'n', 'e', 't', 't', 'y' }); + ByteBuf actual = channel.readInbound(); + assertEquals(expected, actual); + + expected.release(); + actual.release(); + } finally { + channel.finishAndReleaseAll(); + } + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameEncoderTest.java 
b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameEncoderTest.java new file mode 100644 index 0000000..80ff0f7 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyFrameEncoderTest.java @@ -0,0 +1,156 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +public class SnappyFrameEncoderTest { + private EmbeddedChannel channel; + + @BeforeEach + public void setUp() { + channel = new EmbeddedChannel(new SnappyFrameEncoder()); + } + + @Test + public void testSmallAmountOfDataIsUncompressed() throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 'n', 'e', 't', 't', 'y' + }); + + channel.writeOutbound(in); + assertTrue(channel.finish()); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x01, 0x09, 0x00, 0x00, 0x6f, -0x68, 0x2e, -0x47, 'n', 'e', 't', 't', 'y' + }); + ByteBuf actual = channel.readOutbound(); + assertEquals(expected, actual); + + 
expected.release(); + actual.release(); + } + + @Test + public void testLargeAmountOfDataIsCompressed() throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 'n', 'e', 't', 't', 'y', 'n', 'e', 't', 't', 'y', + 'n', 'e', 't', 't', 'y', 'n', 'e', 't', 't', 'y' + }); + + channel.writeOutbound(in); + assertTrue(channel.finish()); + + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x00, 0x0E, 0x00, 0x00, 0x3b, 0x36, -0x7f, 0x37, + 0x14, 0x10, + 'n', 'e', 't', 't', 'y', + 0x3a, 0x05, 0x00 + }); + ByteBuf actual = channel.readOutbound(); + assertEquals(expected, actual); + + expected.release(); + actual.release(); + } + + @Test + public void testStreamStartIsOnlyWrittenOnce() throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 'n', 'e', 't', 't', 'y' + }); + + channel.writeOutbound(in.retain()); + in.resetReaderIndex(); // rewind the buffer to write the same data + channel.writeOutbound(in); + assertTrue(channel.finish()); + + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + (byte) 0xff, 0x06, 0x00, 0x00, 0x73, 0x4e, 0x61, 0x50, 0x70, 0x59, + 0x01, 0x09, 0x00, 0x00, 0x6f, -0x68, 0x2e, -0x47, 'n', 'e', 't', 't', 'y', + 0x01, 0x09, 0x00, 0x00, 0x6f, -0x68, 0x2e, -0x47, 'n', 'e', 't', 't', 'y', + }); + + CompositeByteBuf actual = Unpooled.compositeBuffer(); + for (;;) { + ByteBuf m = channel.readOutbound(); + if (m == null) { + break; + } + actual.addComponent(true, m); + } + assertEquals(expected, actual); + + expected.release(); + actual.release(); + } + + /** + * This test asserts that if we have a remainder after emitting a copy that + * is less than 4 bytes (ie. the minimum required for a copy), we should + * emit a literal rather than trying to see if we can emit another copy. 
+ */ + @Test + public void testInputBufferOverseek() throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 11, 0, // literal + 0, 0, 0, 0, // 1st copy + 16, 65, 96, 119, -22, 79, -43, 76, -75, -93, + 11, 104, 96, -99, 126, -98, 27, -36, 40, 117, + -65, -3, -57, -83, -58, 7, 114, -14, 68, -122, + 124, 88, 118, 54, 45, -26, 117, 13, -45, -9, + 60, -73, -53, -44, 53, 68, -77, -71, 109, 43, + -38, 59, 100, -12, -87, 44, -106, 123, -107, 38, + 13, -117, -23, -49, 29, 21, 26, 66, 1, -1, + -1, // literal + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, // 2nd copy + 66, 0, -104, -49, 16, -120, 22, 8, -52, -54, + -102, -52, -119, -124, -92, -71, 101, -120, -52, -48, + 45, -26, -24, 26, 41, -13, 36, 64, -47, 15, + -124, -7, -16, 91, 96, 0, -93, -42, 101, 20, + -74, 39, -124, 35, 43, -49, -21, -92, -20, -41, + 79, 41, 110, -105, 42, -96, 90, -9, -100, -22, + -62, 91, 2, 35, 113, 117, -71, 66, 1, // literal + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, // copy + -1, 1 // remainder + }); + + channel.writeOutbound(in); + assertTrue(channel.finish()); + ByteBuf out = channel.readOutbound(); + out.release(); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyIntegrationTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyIntegrationTest.java new file mode 100644 index 0000000..5e16178 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyIntegrationTest.java @@ -0,0 +1,117 @@ +/* + 
* Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +public class SnappyIntegrationTest extends AbstractIntegrationTest { + + /** + * The number of random regression tests run by testRandom() runs. Whenever testRandom() finds the case that + * the snappy codec can't encode/decode, it will print the generated source code of the offending test case. + * You can always reproduce the problem using it rather than relying on testRandom(). + * + * The default is 1, but you can increase it to increase the chance of finding any unusual cases. 
+ **/ + private static final int RANDOM_RUNS = 1; + + @Override + protected EmbeddedChannel createEncoder() { + return new EmbeddedChannel(new SnappyFrameEncoder()); + } + + @Override + protected EmbeddedChannel createDecoder() { + return new EmbeddedChannel(new SnappyFrameDecoder()); + } + + @Test + public void test1002() throws Throwable { + // Data from https://github.com/netty/netty/issues/1002 + final byte[] data = { + 11, 0, 0, 0, 0, 0, 16, 65, 96, 119, -22, 79, -43, 76, -75, -93, + 11, 104, 96, -99, 126, -98, 27, -36, 40, 117, -65, -3, -57, -83, -58, 7, + 114, -14, 68, -122, 124, 88, 118, 54, 45, -26, 117, 13, -45, -9, 60, -73, + -53, -44, 53, 68, -77, -71, 109, 43, -38, 59, 100, -12, -87, 44, -106, 123, + -107, 38, 13, -117, -23, -49, 29, 21, 26, 66, 1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, 66, 0, -104, -49, + 16, -120, 22, 8, -52, -54, -102, -52, -119, -124, -92, -71, 101, -120, -52, -48, + 45, -26, -24, 26, 41, -13, 36, 64, -47, 15, -124, -7, -16, 91, 96, 0, + -93, -42, 101, 20, -74, 39, -124, 35, 43, -49, -21, -92, -20, -41, 79, 41, + 110, -105, 42, -96, 90, -9, -100, -22, -62, 91, 2, 35, 113, 117, -71, 66, + 1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, -1, + -1, -1 + }; + testIdentity(data, true); + } + + // These tests were found using testRandom() with large RANDOM_RUNS. 
+ + // Tests that copies do not attempt to overrun into a previous frame chunk + @Test + public void test5323211032315942961() { + testWithSeed(5323211032315942961L); + } + + // Tests that when generating the hash lookup table for finding copies, we + // do not exceed the length of the input when there are no copies + @Test + public void test7088170877360183401() { + testWithSeed(7088170877360183401L); + } + + @Test + public void testRandom() throws Throwable { + for (int i = 0; i < RANDOM_RUNS; i++) { + long seed = rand.nextLong(); + if (seed < 0) { + // Use only positive seed to get prettier test name. :-) + continue; + } + + try { + testWithSeed(seed); + } catch (Throwable t) { + System.out.println("Failed with random seed " + seed + ". Here is a test for it:\n"); + printSeedAsTest(seed); + throw t; + } + } + } + + private void testWithSeed(long seed) { + byte[] data = new byte[16 * 1048576]; + new Random(seed).nextBytes(data); + testIdentity(data, true); + } + + private static void printSeedAsTest(long l) { + System.out.println("@Test"); + System.out.println("@Ignore"); + System.out.println("public void test" + l + "(){"); + System.out.println(" testWithSeed(" + l + "L);"); + System.out.println("}"); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyTest.java new file mode 100644 index 0000000..21c3abd --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/SnappyTest.java @@ -0,0 +1,359 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static io.netty.handler.codec.compression.Snappy.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.nio.CharBuffer; + +public class SnappyTest { + private final Snappy snappy = new Snappy(); + + @AfterEach + public void resetSnappy() { + snappy.reset(); + } + + @Test + public void testDecodeLiteral() throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x05, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79 // "netty" + }); + ByteBuf out = Unpooled.buffer(5); + snappy.decode(in, out); + + // "netty" + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + 0x6e, 0x65, 0x74, 0x74, 0x79 + }); + assertEquals(expected, out, "Literal was not decoded correctly"); + + in.release(); + out.release(); + expected.release(); + } + + @Test + public void testDecodeCopyWith1ByteOffset() throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x0a, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79, // "netty" + 0x01 << 2 | 0x01, // copy with 1-byte offset + length + 0x05 // offset + }); + ByteBuf out = Unpooled.buffer(10); + snappy.decode(in, out); + + // "nettynetty" - we saved a 
whole byte :) + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + 0x6e, 0x65, 0x74, 0x74, 0x79, 0x6e, 0x65, 0x74, 0x74, 0x79 + }); + assertEquals(expected, out, "Copy was not decoded correctly"); + + in.release(); + out.release(); + expected.release(); + } + + @Test + public void testDecodeCopyWithTinyOffset() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x0b, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79, // "netty" + 0x05 << 2 | 0x01, // copy with 1-byte offset + length + 0x00 // INVALID offset (< 1) + }); + final ByteBuf out = Unpooled.buffer(10); + try { + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + snappy.decode(in, out); + } + }); + } finally { + in.release(); + out.release(); + } + } + + @Test + public void testDecodeCopyWithOffsetBeforeChunk() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + 0x0a, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79, // "netty" + 0x05 << 2 | 0x01, // copy with 1-byte offset + length + 0x0b // INVALID offset (greater than chunk size) + }); + final ByteBuf out = Unpooled.buffer(10); + try { + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + snappy.decode(in, out); + } + }); + } finally { + in.release(); + out.release(); + } + } + + @Test + public void testDecodeWithOverlyLongPreamble() { + final ByteBuf in = Unpooled.wrappedBuffer(new byte[] { + -0x80, -0x80, -0x80, -0x80, 0x7f, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79, // "netty" + }); + final ByteBuf out = Unpooled.buffer(10); + try { + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + snappy.decode(in, out); + } + }); + } finally { + in.release(); + out.release(); + } + } + + @Test + public void encodeShortTextIsLiteral() throws Exception { + ByteBuf in = 
Unpooled.wrappedBuffer(new byte[] { + 0x6e, 0x65, 0x74, 0x74, 0x79 + }); + ByteBuf out = Unpooled.buffer(7); + snappy.encode(in, out, 5); + + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + 0x05, // preamble length + 0x04 << 2, // literal tag + length + 0x6e, 0x65, 0x74, 0x74, 0x79 // "netty" + }); + assertEquals(expected, out, "Encoded literal was invalid"); + + in.release(); + out.release(); + expected.release(); + } + + @Test + public void encodeAndDecodeLongTextUsesCopy() throws Exception { + String srcStr = "Netty has been designed carefully with the experiences " + + "earned from the implementation of a lot of protocols " + + "such as FTP, SMTP, HTTP, and various binary and " + + "text-based legacy protocols"; + ByteBuf in = Unpooled.wrappedBuffer(srcStr.getBytes("US-ASCII")); + ByteBuf out = Unpooled.buffer(180); + snappy.encode(in, out, in.readableBytes()); + + // The only compressibility in the above are the words: + // "the ", "rotocols", " of ", "TP, " and "and ". So this is a literal, + // followed by a copy followed by another literal, followed by another copy... 
+ ByteBuf expected = Unpooled.wrappedBuffer(new byte[] { + -0x49, 0x01, // preamble length + -0x10, 0x42, // literal tag + length + + // Literal + 0x4e, 0x65, 0x74, 0x74, 0x79, 0x20, 0x68, 0x61, 0x73, 0x20, + 0x62, 0x65, 0x65, 0x6e, 0x20, 0x64, 0x65, 0x73, 0x69, 0x67, + 0x6e, 0x65, 0x64, 0x20, 0x63, 0x61, 0x72, 0x65, 0x66, 0x75, + 0x6c, 0x6c, 0x79, 0x20, 0x77, 0x69, 0x74, 0x68, 0x20, 0x74, + 0x68, 0x65, 0x20, 0x65, 0x78, 0x70, 0x65, 0x72, 0x69, 0x65, + 0x6e, 0x63, 0x65, 0x73, 0x20, 0x65, 0x61, 0x72, 0x6e, 0x65, + 0x64, 0x20, 0x66, 0x72, 0x6f, 0x6d, 0x20, + + // copy of "the " + 0x01, 0x1c, 0x58, + + // Next literal + 0x69, 0x6d, 0x70, 0x6c, 0x65, 0x6d, 0x65, 0x6e, 0x74, 0x61, + 0x74, 0x69, 0x6f, 0x6e, 0x20, 0x6f, 0x66, 0x20, 0x61, 0x20, + 0x6c, 0x6f, 0x74, + + // copy of " of " + 0x01, 0x09, 0x60, + + // literal + 0x70, 0x72, 0x6f, 0x74, 0x6f, 0x63, 0x6f, 0x6c, 0x73, 0x20, + 0x73, 0x75, 0x63, 0x68, 0x20, 0x61, 0x73, 0x20, 0x46, 0x54, + 0x50, 0x2c, 0x20, 0x53, 0x4d, + + // copy of " TP, " + 0x01, 0x06, 0x04, + + // literal + 0x48, 0x54, + + // copy of " TP, " + 0x01, 0x06, 0x44, + + // literal + 0x61, 0x6e, 0x64, 0x20, 0x76, 0x61, 0x72, 0x69, 0x6f, 0x75, + 0x73, 0x20, 0x62, 0x69, 0x6e, 0x61, 0x72, 0x79, + + // copy of "and " + 0x05, 0x13, 0x48, + + // literal + 0x74, 0x65, 0x78, 0x74, 0x2d, 0x62, 0x61, 0x73, 0x65, + 0x64, 0x20, 0x6c, 0x65, 0x67, 0x61, 0x63, 0x79, 0x20, 0x70, + + // copy of "rotocols" + 0x11, 0x4c, + }); + + assertEquals(expected, out, "Encoded result was incorrect"); + + // Decode + ByteBuf outDecoded = Unpooled.buffer(); + snappy.decode(out, outDecoded); + assertEquals(CharBuffer.wrap(srcStr), + CharBuffer.wrap(outDecoded.getCharSequence(0, outDecoded.writerIndex(), CharsetUtil.US_ASCII))); + + in.release(); + out.release(); + outDecoded.release(); + } + + @Test + public void testCalculateChecksum() { + ByteBuf input = Unpooled.wrappedBuffer(new byte[] { + 'n', 'e', 't', 't', 'y' + }); + + assertEquals(maskChecksum(0xd6cb8b55L), 
calculateChecksum(input)); + input.release(); + } + + @Test + public void testMaskChecksum() { + ByteBuf input = Unpooled.wrappedBuffer(new byte[] { + 0x00, 0x00, 0x00, 0x0f, 0x00, 0x00, 0x00, 0x00, + 0x5f, 0x68, 0x65, 0x61, 0x72, 0x74, 0x62, 0x65, + 0x61, 0x74, 0x5f, + }); + assertEquals(0x44a4301f, calculateChecksum(input)); + input.release(); + } + + @Test + public void testValidateChecksumMatches() { + ByteBuf input = Unpooled.wrappedBuffer(new byte[] { + 'y', 't', 't', 'e', 'n' + }); + + validateChecksum(maskChecksum(0x2d4d3535), input); + input.release(); + } + + @Test + public void testValidateChecksumFails() { + final ByteBuf input = Unpooled.wrappedBuffer(new byte[] { + 'y', 't', 't', 'e', 'n' + }); + try { + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() { + validateChecksum(maskChecksum(0xd6cb8b55), input); + } + }); + } finally { + input.release(); + } + } + + @Test + public void testEncodeLiteralAndDecodeLiteral() { + int[] lengths = { + 0x11, // default + 0x100, // case 60 + 0x1000, // case 61 + 0x100000, // case 62 + 0x1000001 // case 63 + }; + for (int len : lengths) { + ByteBuf in = Unpooled.wrappedBuffer(new byte[len]); + ByteBuf encoded = Unpooled.buffer(10); + ByteBuf decoded = Unpooled.buffer(10); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[len]); + try { + encodeLiteral(in, encoded, len); + byte tag = encoded.readByte(); + decodeLiteral(tag, encoded, decoded); + assertEquals(expected, decoded, "Encoded or decoded literal was incorrect"); + } finally { + in.release(); + encoded.release(); + decoded.release(); + expected.release(); + } + } + } + + @Test + public void testLarge2ByteLiteralLengthAndCopyOffset() { + ByteBuf compressed = Unpooled.buffer(); + ByteBuf actualDecompressed = Unpooled.buffer(); + ByteBuf expectedDecompressed = Unpooled.buffer().writeByte(0x01).writeZero(0x8000).writeByte(0x01); + try { + // Generate a Snappy-encoded buffer that can only be decompressed correctly 
if + // the decoder treats 2-byte literal lengths and 2-byte copy offsets as unsigned values. + + // Write preamble, uncompressed content length (0x8002) encoded as varint. + compressed.writeByte(0x82).writeByte(0x80).writeByte(0x02); + + // Write a literal consisting of 0x01 followed by 0x8000 zeroes. + // The total length of this literal is 0x8001, which gets encoded as 0x8000 (length - 1). + // This length was selected because the encoded form is one larger than the maximum value + // representable using a signed 16-bit integer, and we want to assert the decoder is reading + // the length as an unsigned value. + compressed.writeByte(61 << 2); // tag for LITERAL with a 2-byte length + compressed.writeShortLE(0x8000); // length - 1 + compressed.writeByte(0x01).writeZero(0x8000); // literal content + + // Similarly, for a 2-byte copy operation we want to ensure the offset is treated as unsigned. + // Copy the initial 0x01 which was written 0x8001 bytes back in the stream. + compressed.writeByte(0x02); // tag for COPY with 2-byte offset, length = 1 + compressed.writeShortLE(0x8001); // offset + + snappy.decode(compressed, actualDecompressed); + assertEquals(expectedDecompressed, actualDecompressed); + } finally { + compressed.release(); + actualDecompressed.release(); + expectedDecompressed.release(); + } + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest1.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest1.java new file mode 100644 index 0000000..a7c051d --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest1.java @@ -0,0 +1,29 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.compression; + +public class ZlibCrossTest1 extends ZlibTest { + + @Override + protected ZlibEncoder createEncoder(ZlibWrapper wrapper) { + return new JdkZlibEncoder(wrapper); + } + + @Override + protected ZlibDecoder createDecoder(ZlibWrapper wrapper, int maxAllocation) { + return new JZlibDecoder(wrapper, maxAllocation); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest2.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest2.java new file mode 100644 index 0000000..306b652 --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibCrossTest2.java @@ -0,0 +1,45 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ZlibCrossTest2 extends ZlibTest { + + @Override + protected ZlibEncoder createEncoder(ZlibWrapper wrapper) { + return new JZlibEncoder(wrapper); + } + + @Override + protected ZlibDecoder createDecoder(ZlibWrapper wrapper, int maxAllocation) { + return new JdkZlibDecoder(wrapper, maxAllocation); + } + + @Test + @Override + public void testZLIB_OR_NONE3() throws Exception { + assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() throws Throwable { + ZlibCrossTest2.super.testZLIB_OR_NONE3(); + } + }); + } +} diff --git a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibTest.java new file mode 100644 index 0000000..ee7acdd --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZlibTest.java @@ -0,0 +1,479 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import io.netty.buffer.AbstractByteBufAllocator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.util.Random; +import java.util.zip.DeflaterOutputStream; +import java.util.zip.GZIPInputStream; +import java.util.zip.GZIPOutputStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public abstract class ZlibTest { + + private static final byte[] BYTES_SMALL = new byte[128]; + private static final byte[] BYTES_LARGE = new byte[1024 * 1024]; + private static final byte[] BYTES_LARGE2 = ("\n" + + "\n" + + "\n" + + " Apache Tomcat\n" + + "\n" + + '\n' + + "\n" + + "

It works !

\n" + + '\n' + + "

If you're seeing this page via a web browser, it means you've setup Tomcat successfully." + + " Congratulations!

\n" + + " \n" + + "

This is the default Tomcat home page." + + " It can be found on the local filesystem at: /var/lib/tomcat7/webapps/ROOT/index.html

\n" + + '\n' + + "

Tomcat7 veterans might be pleased to learn that this system instance of Tomcat is installed with" + + " CATALINA_HOME in /usr/share/tomcat7 and CATALINA_BASE in" + + " /var/lib/tomcat7, following the rules from" + + " /usr/share/doc/tomcat7-common/RUNNING.txt.gz.

\n" + + '\n' + + "

You might consider installing the following packages, if you haven't already done so:

\n" + + '\n' + + "

tomcat7-docs: This package installs a web application that allows to browse the Tomcat 7" + + " documentation locally. Once installed, you can access it by clicking here.

\n" + + '\n' + + "

tomcat7-examples: This package installs a web application that allows to access the Tomcat" + + " 7 Servlet and JSP examples. Once installed, you can access it by clicking" + + " here.

\n" + + '\n' + + "

tomcat7-admin: This package installs two web applications that can help managing this Tomcat" + + " instance. Once installed, you can access the manager webapp and" + + " the host-manager webapp.

\n" + + '\n' + + "

NOTE: For security reasons, using the manager webapp is restricted" + + " to users with role \"manager\"." + + " The host-manager webapp is restricted to users with role \"admin\". Users are " + + "defined in /etc/tomcat7/tomcat-users.xml.

\n" + + '\n' + + '\n' + + '\n' + + "").getBytes(CharsetUtil.UTF_8); + + static { + Random rand = PlatformDependent.threadLocalRandom(); + rand.nextBytes(BYTES_SMALL); + rand.nextBytes(BYTES_LARGE); + } + + protected ZlibDecoder createDecoder(ZlibWrapper wrapper) { + return createDecoder(wrapper, 0); + } + + protected abstract ZlibEncoder createEncoder(ZlibWrapper wrapper); + protected abstract ZlibDecoder createDecoder(ZlibWrapper wrapper, int maxAllocation); + + @Test + public void testGZIP2() throws Exception { + byte[] bytes = "message".getBytes(CharsetUtil.UTF_8); + ByteBuf data = Unpooled.wrappedBuffer(bytes); + ByteBuf deflatedData = Unpooled.wrappedBuffer(gzip(bytes)); + + EmbeddedChannel chDecoderGZip = new EmbeddedChannel(createDecoder(ZlibWrapper.GZIP)); + try { + while (deflatedData.isReadable()) { + chDecoderGZip.writeInbound(deflatedData.readRetainedSlice(1)); + } + deflatedData.release(); + assertTrue(chDecoderGZip.finish()); + ByteBuf buf = Unpooled.buffer(); + for (;;) { + ByteBuf b = chDecoderGZip.readInbound(); + if (b == null) { + break; + } + buf.writeBytes(b); + b.release(); + } + assertEquals(buf, data); + assertNull(chDecoderGZip.readInbound()); + data.release(); + buf.release(); + } finally { + dispose(chDecoderGZip); + } + } + + @Test + public void testGZIP3() throws Exception { + byte[] bytes = "Foo".getBytes(CharsetUtil.UTF_8); + ByteBuf data = Unpooled.wrappedBuffer(bytes); + ByteBuf deflatedData = Unpooled.wrappedBuffer( + new byte[]{ + 31, -117, // magic number + 8, // CM + 2, // FLG.FHCRC + 0, 0, 0, 0, // MTIME + 0, // XFL + 7, // OS + -66, -77, // CRC16 + 115, -53, -49, 7, 0, // compressed blocks + -63, 35, 62, -76, // CRC32 + 3, 0, 0, 0 // ISIZE + } + ); + + EmbeddedChannel chDecoderGZip = new EmbeddedChannel(createDecoder(ZlibWrapper.GZIP)); + try { + while (deflatedData.isReadable()) { + chDecoderGZip.writeInbound(deflatedData.readRetainedSlice(1)); + } + deflatedData.release(); + assertTrue(chDecoderGZip.finish()); + ByteBuf buf 
= Unpooled.buffer(); + for (;;) { + ByteBuf b = chDecoderGZip.readInbound(); + if (b == null) { + break; + } + buf.writeBytes(b); + b.release(); + } + assertEquals(buf, data); + assertNull(chDecoderGZip.readInbound()); + data.release(); + buf.release(); + } finally { + dispose(chDecoderGZip); + } + } + + private void testCompress0(ZlibWrapper encoderWrapper, ZlibWrapper decoderWrapper, ByteBuf data) throws Exception { + EmbeddedChannel chEncoder = new EmbeddedChannel(createEncoder(encoderWrapper)); + EmbeddedChannel chDecoderZlib = new EmbeddedChannel(createDecoder(decoderWrapper)); + + try { + chEncoder.writeOutbound(data.retain()); + chEncoder.flush(); + data.resetReaderIndex(); + + for (;;) { + ByteBuf deflatedData = chEncoder.readOutbound(); + if (deflatedData == null) { + break; + } + chDecoderZlib.writeInbound(deflatedData); + } + + byte[] decompressed = new byte[data.readableBytes()]; + int offset = 0; + for (;;) { + ByteBuf buf = chDecoderZlib.readInbound(); + if (buf == null) { + break; + } + int length = buf.readableBytes(); + buf.readBytes(decompressed, offset, length); + offset += length; + buf.release(); + if (offset == decompressed.length) { + break; + } + } + assertEquals(data, Unpooled.wrappedBuffer(decompressed)); + assertNull(chDecoderZlib.readInbound()); + + // Closing an encoder channel will generate a footer. + assertTrue(chEncoder.finish()); + for (;;) { + Object msg = chEncoder.readOutbound(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + // But, the footer will be decoded into nothing. It's only for validation. 
+ assertFalse(chDecoderZlib.finish()); + + data.release(); + } finally { + dispose(chEncoder); + dispose(chDecoderZlib); + } + } + + private void testCompressNone(ZlibWrapper encoderWrapper, ZlibWrapper decoderWrapper) throws Exception { + EmbeddedChannel chEncoder = new EmbeddedChannel(createEncoder(encoderWrapper)); + EmbeddedChannel chDecoderZlib = new EmbeddedChannel(createDecoder(decoderWrapper)); + + try { + // Closing an encoder channel without writing anything should generate both header and footer. + assertTrue(chEncoder.finish()); + + for (;;) { + ByteBuf deflatedData = chEncoder.readOutbound(); + if (deflatedData == null) { + break; + } + chDecoderZlib.writeInbound(deflatedData); + } + + // Decoder should not generate anything at all. + boolean decoded = false; + for (;;) { + ByteBuf buf = chDecoderZlib.readInbound(); + if (buf == null) { + break; + } + + buf.release(); + decoded = true; + } + assertFalse(decoded, "should decode nothing"); + + assertFalse(chDecoderZlib.finish()); + } finally { + dispose(chEncoder); + dispose(chDecoderZlib); + } + } + + private static void dispose(EmbeddedChannel ch) { + if (ch.finish()) { + for (;;) { + Object msg = ch.readInbound(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + for (;;) { + Object msg = ch.readOutbound(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + } + } + + // Test for https://github.com/netty/netty/issues/2572 + private void testDecompressOnly(ZlibWrapper decoderWrapper, byte[] compressed, byte[] data) throws Exception { + EmbeddedChannel chDecoder = new EmbeddedChannel(createDecoder(decoderWrapper)); + chDecoder.writeInbound(Unpooled.copiedBuffer(compressed)); + assertTrue(chDecoder.finish()); + + ByteBuf decoded = Unpooled.buffer(data.length); + + for (;;) { + ByteBuf buf = chDecoder.readInbound(); + if (buf == null) { + break; + } + decoded.writeBytes(buf); + buf.release(); + } + assertEquals(Unpooled.copiedBuffer(data), decoded); + 
decoded.release(); + } + + private void testCompressSmall(ZlibWrapper encoderWrapper, ZlibWrapper decoderWrapper) throws Exception { + testCompress0(encoderWrapper, decoderWrapper, Unpooled.wrappedBuffer(BYTES_SMALL)); + testCompress0(encoderWrapper, decoderWrapper, + Unpooled.directBuffer(BYTES_SMALL.length).writeBytes(BYTES_SMALL)); + } + + private void testCompressLarge(ZlibWrapper encoderWrapper, ZlibWrapper decoderWrapper) throws Exception { + testCompress0(encoderWrapper, decoderWrapper, Unpooled.wrappedBuffer(BYTES_LARGE)); + testCompress0(encoderWrapper, decoderWrapper, + Unpooled.directBuffer(BYTES_LARGE.length).writeBytes(BYTES_LARGE)); + } + + @Test + public void testZLIB() throws Exception { + testCompressNone(ZlibWrapper.ZLIB, ZlibWrapper.ZLIB); + testCompressSmall(ZlibWrapper.ZLIB, ZlibWrapper.ZLIB); + testCompressLarge(ZlibWrapper.ZLIB, ZlibWrapper.ZLIB); + testDecompressOnly(ZlibWrapper.ZLIB, deflate(BYTES_LARGE2), BYTES_LARGE2); + } + + @Test + public void testNONE() throws Exception { + testCompressNone(ZlibWrapper.NONE, ZlibWrapper.NONE); + testCompressSmall(ZlibWrapper.NONE, ZlibWrapper.NONE); + testCompressLarge(ZlibWrapper.NONE, ZlibWrapper.NONE); + } + + @Test + public void testGZIP() throws Exception { + testCompressNone(ZlibWrapper.GZIP, ZlibWrapper.GZIP); + testCompressSmall(ZlibWrapper.GZIP, ZlibWrapper.GZIP); + testCompressLarge(ZlibWrapper.GZIP, ZlibWrapper.GZIP); + testDecompressOnly(ZlibWrapper.GZIP, gzip(BYTES_LARGE2), BYTES_LARGE2); + } + + @Test + public void testGZIPCompressOnly() throws Exception { + testGZIPCompressOnly0(null); // Do not write anything; just finish the stream. + testGZIPCompressOnly0(EmptyArrays.EMPTY_BYTES); // Write an empty array. 
+ testGZIPCompressOnly0(BYTES_SMALL); + testGZIPCompressOnly0(BYTES_LARGE); + } + + private void testGZIPCompressOnly0(byte[] data) throws IOException { + EmbeddedChannel chEncoder = new EmbeddedChannel(createEncoder(ZlibWrapper.GZIP)); + if (data != null) { + chEncoder.writeOutbound(Unpooled.wrappedBuffer(data)); + } + assertTrue(chEncoder.finish()); + + ByteBuf encoded = Unpooled.buffer(); + for (;;) { + ByteBuf buf = chEncoder.readOutbound(); + if (buf == null) { + break; + } + encoded.writeBytes(buf); + buf.release(); + } + + ByteBuf decoded = Unpooled.buffer(); + GZIPInputStream stream = new GZIPInputStream(new ByteBufInputStream(encoded, true)); + try { + byte[] buf = new byte[8192]; + for (;;) { + int readBytes = stream.read(buf); + if (readBytes < 0) { + break; + } + decoded.writeBytes(buf, 0, readBytes); + } + } finally { + stream.close(); + } + + if (data != null) { + assertEquals(Unpooled.wrappedBuffer(data), decoded); + } else { + assertFalse(decoded.isReadable()); + } + + decoded.release(); + } + + @Test + public void testZLIB_OR_NONE() throws Exception { + testCompressNone(ZlibWrapper.NONE, ZlibWrapper.ZLIB_OR_NONE); + testCompressSmall(ZlibWrapper.NONE, ZlibWrapper.ZLIB_OR_NONE); + testCompressLarge(ZlibWrapper.NONE, ZlibWrapper.ZLIB_OR_NONE); + } + + @Test + public void testZLIB_OR_NONE2() throws Exception { + testCompressNone(ZlibWrapper.ZLIB, ZlibWrapper.ZLIB_OR_NONE); + testCompressSmall(ZlibWrapper.ZLIB, ZlibWrapper.ZLIB_OR_NONE); + testCompressLarge(ZlibWrapper.ZLIB, ZlibWrapper.ZLIB_OR_NONE); + } + + @Test + public void testZLIB_OR_NONE3() throws Exception { + testCompressNone(ZlibWrapper.GZIP, ZlibWrapper.ZLIB_OR_NONE); + testCompressSmall(ZlibWrapper.GZIP, ZlibWrapper.ZLIB_OR_NONE); + testCompressLarge(ZlibWrapper.GZIP, ZlibWrapper.ZLIB_OR_NONE); + } + + @Test + public void testMaxAllocation() throws Exception { + int maxAllocation = 1024; + ZlibDecoder decoder = createDecoder(ZlibWrapper.ZLIB, maxAllocation); + final EmbeddedChannel 
chDecoder = new EmbeddedChannel(decoder); + TestByteBufAllocator alloc = new TestByteBufAllocator(chDecoder.alloc()); + chDecoder.config().setAllocator(alloc); + + DecompressionException e = assertThrows(DecompressionException.class, new Executable() { + @Override + public void execute() throws Throwable { + chDecoder.writeInbound(Unpooled.wrappedBuffer(deflate(BYTES_LARGE))); + } + }); + assertTrue(e.getMessage().startsWith("Decompression buffer has reached maximum size")); + assertEquals(maxAllocation, alloc.getMaxAllocation()); + assertTrue(decoder.isClosed()); + assertFalse(chDecoder.finish()); + } + + private static byte[] gzip(byte[] bytes) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + GZIPOutputStream stream = new GZIPOutputStream(out); + stream.write(bytes); + stream.close(); + return out.toByteArray(); + } + + private static byte[] deflate(byte[] bytes) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + OutputStream stream = new DeflaterOutputStream(out); + stream.write(bytes); + stream.close(); + return out.toByteArray(); + } + + private static final class TestByteBufAllocator extends AbstractByteBufAllocator { + private final ByteBufAllocator wrapped; + private int maxAllocation; + + TestByteBufAllocator(ByteBufAllocator wrapped) { + this.wrapped = wrapped; + } + + public int getMaxAllocation() { + return maxAllocation; + } + + @Override + public boolean isDirectBufferPooled() { + return wrapped.isDirectBufferPooled(); + } + + @Override + protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) { + maxAllocation = Math.max(maxAllocation, maxCapacity); + return wrapped.heapBuffer(initialCapacity, maxCapacity); + } + + @Override + protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + maxAllocation = Math.max(maxAllocation, maxCapacity); + return wrapped.directBuffer(initialCapacity, maxCapacity); + } + } +} diff --git 
a/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java new file mode 100644 index 0000000..296e3da --- /dev/null +++ b/netty-handler-codec-compression/src/test/java/io/netty/handler/codec/compression/ZstdEncoderTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.compression; + +import com.github.luben.zstd.ZstdInputStream; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.io.InputStream; + + +import static org.mockito.Mockito.when; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ZstdEncoderTest extends AbstractEncoderTest { + + @Mock + private ChannelHandlerContext ctx; + + @BeforeEach + public void setup() { + MockitoAnnotations.initMocks(this); + when(ctx.alloc()).thenReturn(ByteBufAllocator.DEFAULT); + } + + @Override + public EmbeddedChannel createChannel() { + return new EmbeddedChannel(new ZstdEncoder()); + } + + @ParameterizedTest + @MethodSource("largeData") + public void testCompressionOfLargeBatchedFlow(final ByteBuf data) throws Exception { + final int dataLength = data.readableBytes(); + int written = 0; + + ByteBuf in = data.retainedSlice(written, 65535); + assertTrue(channel.writeOutbound(in)); + + ByteBuf in2 = data.retainedSlice(65535, dataLength - 65535); + assertTrue(channel.writeOutbound(in2)); + + assertTrue(channel.finish()); + + ByteBuf decompressed = readDecompressed(dataLength); + assertEquals(data, decompressed); + + decompressed.release(); + data.release(); + } + + @ParameterizedTest + @MethodSource("smallData") + public void testCompressionOfSmallBatchedFlow(final ByteBuf data) throws Exception { + testCompressionOfBatchedFlow(data); + } + + @Override + protected ByteBuf decompress(ByteBuf compressed, int originalLength) throws Exception { + 
InputStream is = new ByteBufInputStream(compressed, true); + ZstdInputStream zstdIs = null; + byte[] decompressed = new byte[originalLength]; + try { + zstdIs = new ZstdInputStream(is); + int remaining = originalLength; + while (remaining > 0) { + int read = zstdIs.read(decompressed, originalLength - remaining, remaining); + if (read > 0) { + remaining -= read; + } else { + break; + } + } + assertEquals(-1, zstdIs.read()); + } finally { + if (zstdIs != null) { + zstdIs.close(); + } else { + is.close(); + } + } + + return Unpooled.wrappedBuffer(decompressed); + } +} diff --git a/netty-handler-codec-compression/src/test/resources/logging.properties b/netty-handler-codec-compression/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-handler-codec-compression/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-handler-codec-compression/src/test/resources/multiple.gz b/netty-handler-codec-compression/src/test/resources/multiple.gz new file mode 100644 index 0000000000000000000000000000000000000000..f5fd0675ee17e4477cb8707b69453ea355812b19 GIT binary patch literal 46 vcmb2|=HSTpIOfN|oXFtK!r;7b`wK<}1_pVca3xqciNTwR;ph8(g&<)794!np literal 0 HcmV?d00001 diff --git a/netty-handler-codec-http/build.gradle b/netty-handler-codec-http/build.gradle new file mode 100644 index 0000000..ba6ee81 --- /dev/null +++ b/netty-handler-codec-http/build.gradle @@ -0,0 +1,14 @@ +dependencies { + api project(':netty-handler') + api project(':netty-handler-ssl') + api project(':netty-handler-codec-compression') + implementation 
project(':netty-zlib') + implementation libs.brotli4j // accessing com.aayushatharva.brotli4j.encoder.Encoder + testImplementation testLibs.assertj + testImplementation testLibs.mockito.core + testRuntimeOnly testLibs.brotli4j.native.linux.x8664 + testRuntimeOnly testLibs.brotli4j.native.linux.aarch64 + testRuntimeOnly testLibs.brotli4j.native.osx.x8664 + testRuntimeOnly testLibs.brotli4j.native.osx.aarch64 + testRuntimeOnly testLibs.brotli4j.native.windows.x8664 +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ClientCookieEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ClientCookieEncoder.java new file mode 100644 index 0000000..4091353 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ClientCookieEncoder.java @@ -0,0 +1,87 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.http.cookie.ClientCookieDecoder; + +/** + * A RFC6265 compliant cookie encoder to be used client side, + * so only name=value pairs are sent. + * + * User-Agents are not supposed to interpret cookies. {@link Cookie#value()} will be used unquoted. + * + * Note that multiple cookies are supposed to be sent at once in a single "Cookie" header. + * + *
+ * // Example
+ * {@link HttpRequest} req = ...;
+ * req.setHeader("Cookie", {@link ClientCookieEncoder}.encode("JSESSIONID", "1234"));
+ * 
+ * + * @see ClientCookieDecoder + */ +@Deprecated +public final class ClientCookieEncoder { + + /** + * Encodes the specified cookie into a Cookie header value. + * + * @param name the cookie name + * @param value the cookie value + * @return a Rfc6265 style Cookie header value + */ + @Deprecated + public static String encode(String name, String value) { + return io.netty.handler.codec.http.cookie.ClientCookieEncoder.LAX.encode(name, value); + } + + /** + * Encodes the specified cookie into a Cookie header value. + * + * @param cookie the specified cookie + * @return a Rfc6265 style Cookie header value + */ + @Deprecated + public static String encode(Cookie cookie) { + return io.netty.handler.codec.http.cookie.ClientCookieEncoder.LAX.encode(cookie); + } + + /** + * Encodes the specified cookies into a single Cookie header value. + * + * @param cookies some cookies + * @return a Rfc6265 style Cookie header value, null if no cookies are passed. + */ + @Deprecated + public static String encode(Cookie... cookies) { + return io.netty.handler.codec.http.cookie.ClientCookieEncoder.LAX.encode(cookies); + } + + /** + * Encodes the specified cookies into a single Cookie header value. + * + * @param cookies some cookies + * @return a Rfc6265 style Cookie header value, null if no cookies are passed. 
+ */ + @Deprecated + public static String encode(Iterable cookies) { + return io.netty.handler.codec.http.cookie.ClientCookieEncoder.LAX.encode(cookies); + } + + private ClientCookieEncoder() { + // unused + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java new file mode 100644 index 0000000..c09a010 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CombinedHttpHeaders.java @@ -0,0 +1,329 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.DefaultHeaders; +import io.netty.handler.codec.DefaultHeaders.NameValidator; +import io.netty.handler.codec.DefaultHeaders.ValueValidator; +import io.netty.handler.codec.Headers; +import io.netty.handler.codec.ValueConverter; +import io.netty.util.HashingStrategy; +import io.netty.util.internal.StringUtil; + +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static io.netty.handler.codec.http.HttpHeaderNames.SET_COOKIE; +import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.StringUtil.COMMA; +import static io.netty.util.internal.StringUtil.unescapeCsvFields; + +/** + * Will add multiple values for the same header as single header with a comma separated list of values. + *

+ * Please refer to section RFC 7230, 3.2.2. + */ +public class CombinedHttpHeaders extends DefaultHttpHeaders { + /** + * Create a combined HTTP header object, with optional validation. + * + * @param validate Should Netty validate header values to ensure they aren't malicious. + * @deprecated Prefer instead to configuring a {@link HttpHeadersFactory} + * by calling {@link DefaultHttpHeadersFactory#withCombiningHeaders(boolean) withCombiningHeaders(true)} + * on {@link DefaultHttpHeadersFactory#headersFactory()}. + */ + @Deprecated + public CombinedHttpHeaders(boolean validate) { + super(new CombinedHttpHeadersImpl(CASE_INSENSITIVE_HASHER, valueConverter(), nameValidator(validate), + valueValidator(validate))); + } + + CombinedHttpHeaders(NameValidator nameValidator, ValueValidator valueValidator) { + super(new CombinedHttpHeadersImpl( + CASE_INSENSITIVE_HASHER, + valueConverter(), + checkNotNull(nameValidator, "nameValidator"), + checkNotNull(valueValidator, "valueValidator"))); + } + + CombinedHttpHeaders( + NameValidator nameValidator, ValueValidator valueValidator, int sizeHint) { + super(new CombinedHttpHeadersImpl( + CASE_INSENSITIVE_HASHER, + valueConverter(), + checkNotNull(nameValidator, "nameValidator"), + checkNotNull(valueValidator, "valueValidator"), + sizeHint)); + } + + @Override + public boolean containsValue(CharSequence name, CharSequence value, boolean ignoreCase) { + return super.containsValue(name, StringUtil.trimOws(value), ignoreCase); + } + + private static final class CombinedHttpHeadersImpl + extends DefaultHeaders { + /** + * An estimate of the size of a header value. 
+ */ + private static final int VALUE_LENGTH_ESTIMATE = 10; + private CsvValueEscaper objectEscaper; + private CsvValueEscaper charSequenceEscaper; + + private CsvValueEscaper objectEscaper() { + if (objectEscaper == null) { + objectEscaper = new CsvValueEscaper() { + @Override + public CharSequence escape(CharSequence name, Object value) { + CharSequence converted; + try { + converted = valueConverter().convertObject(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Failed to convert object value for header '" + name + '\'', e); + } + return StringUtil.escapeCsv(converted, true); + } + }; + } + return objectEscaper; + } + + private CsvValueEscaper charSequenceEscaper() { + if (charSequenceEscaper == null) { + charSequenceEscaper = new CsvValueEscaper() { + @Override + public CharSequence escape(CharSequence name, CharSequence value) { + return StringUtil.escapeCsv(value, true); + } + }; + } + return charSequenceEscaper; + } + + CombinedHttpHeadersImpl(HashingStrategy nameHashingStrategy, + ValueConverter valueConverter, + NameValidator nameValidator, + ValueValidator valueValidator) { + this(nameHashingStrategy, valueConverter, nameValidator, valueValidator, 16); + } + + CombinedHttpHeadersImpl(HashingStrategy nameHashingStrategy, + ValueConverter valueConverter, + NameValidator nameValidator, + ValueValidator valueValidator, + int sizeHint) { + super(nameHashingStrategy, valueConverter, nameValidator, sizeHint, valueValidator); + } + + @Override + public Iterator valueIterator(CharSequence name) { + Iterator itr = super.valueIterator(name); + if (!itr.hasNext() || cannotBeCombined(name)) { + return itr; + } + Iterator unescapedItr = unescapeCsvFields(itr.next()).iterator(); + if (itr.hasNext()) { + throw new IllegalStateException("CombinedHttpHeaders should only have one value"); + } + return unescapedItr; + } + + @Override + public List getAll(CharSequence name) { + List values = super.getAll(name); + if (values.isEmpty() 
|| cannotBeCombined(name)) { + return values; + } + if (values.size() != 1) { + throw new IllegalStateException("CombinedHttpHeaders should only have one value"); + } + return unescapeCsvFields(values.get(0)); + } + + @Override + public CombinedHttpHeadersImpl add(Headers headers) { + // Override the fast-copy mechanism used by DefaultHeaders + if (headers == this) { + throw new IllegalArgumentException("can't add to itself."); + } + if (headers instanceof CombinedHttpHeadersImpl) { + if (isEmpty()) { + // Can use the fast underlying copy + addImpl(headers); + } else { + // Values are already escaped so don't escape again + for (Map.Entry header : headers) { + addEscapedValue(header.getKey(), header.getValue()); + } + } + } else { + for (Map.Entry header : headers) { + add(header.getKey(), header.getValue()); + } + } + return this; + } + + @Override + public CombinedHttpHeadersImpl set(Headers headers) { + if (headers == this) { + return this; + } + clear(); + return add(headers); + } + + @Override + public CombinedHttpHeadersImpl setAll(Headers headers) { + if (headers == this) { + return this; + } + for (CharSequence key : headers.names()) { + remove(key); + } + return add(headers); + } + + @Override + public CombinedHttpHeadersImpl add(CharSequence name, CharSequence value) { + return addEscapedValue(name, charSequenceEscaper().escape(name, value)); + } + + @Override + public CombinedHttpHeadersImpl add(CharSequence name, CharSequence... 
values) { + return addEscapedValue(name, commaSeparate(name, charSequenceEscaper(), values)); + } + + @Override + public CombinedHttpHeadersImpl add(CharSequence name, Iterable values) { + return addEscapedValue(name, commaSeparate(name, charSequenceEscaper(), values)); + } + + @Override + public CombinedHttpHeadersImpl addObject(CharSequence name, Object value) { + return addEscapedValue(name, commaSeparate(name, objectEscaper(), value)); + } + + @Override + public CombinedHttpHeadersImpl addObject(CharSequence name, Iterable values) { + return addEscapedValue(name, commaSeparate(name, objectEscaper(), values)); + } + + @Override + public CombinedHttpHeadersImpl addObject(CharSequence name, Object... values) { + return addEscapedValue(name, commaSeparate(name, objectEscaper(), values)); + } + + @Override + public CombinedHttpHeadersImpl set(CharSequence name, CharSequence... values) { + set(name, commaSeparate(name, charSequenceEscaper(), values)); + return this; + } + + @Override + public CombinedHttpHeadersImpl set(CharSequence name, Iterable values) { + set(name, commaSeparate(name, charSequenceEscaper(), values)); + return this; + } + + @Override + public CombinedHttpHeadersImpl setObject(CharSequence name, Object value) { + set(name, commaSeparate(name, objectEscaper(), value)); + return this; + } + + @Override + public CombinedHttpHeadersImpl setObject(CharSequence name, Object... 
values) { + set(name, commaSeparate(name, objectEscaper(), values)); + return this; + } + + @Override + public CombinedHttpHeadersImpl setObject(CharSequence name, Iterable values) { + set(name, commaSeparate(name, objectEscaper(), values)); + return this; + } + + private static boolean cannotBeCombined(CharSequence name) { + return SET_COOKIE.contentEqualsIgnoreCase(name); + } + + private CombinedHttpHeadersImpl addEscapedValue(CharSequence name, CharSequence escapedValue) { + CharSequence currentValue = get(name); + if (currentValue == null || cannotBeCombined(name)) { + super.add(name, escapedValue); + } else { + set(name, commaSeparateEscapedValues(currentValue, escapedValue)); + } + return this; + } + + private static CharSequence commaSeparate(CharSequence name, CsvValueEscaper escaper, T... values) { + StringBuilder sb = new StringBuilder(values.length * VALUE_LENGTH_ESTIMATE); + if (values.length > 0) { + int end = values.length - 1; + for (int i = 0; i < end; i++) { + sb.append(escaper.escape(name, values[i])).append(COMMA); + } + sb.append(escaper.escape(name, values[end])); + } + return sb; + } + + private static CharSequence commaSeparate(CharSequence name, CsvValueEscaper escaper, + Iterable values) { + @SuppressWarnings("rawtypes") + final StringBuilder sb = values instanceof Collection + ? new StringBuilder(((Collection) values).size() * VALUE_LENGTH_ESTIMATE) : new StringBuilder(); + Iterator iterator = values.iterator(); + if (iterator.hasNext()) { + T next = iterator.next(); + while (iterator.hasNext()) { + sb.append(escaper.escape(name, next)).append(COMMA); + next = iterator.next(); + } + sb.append(escaper.escape(name, next)); + } + return sb; + } + + private static CharSequence commaSeparateEscapedValues(CharSequence currentValue, CharSequence value) { + return new StringBuilder(currentValue.length() + 1 + value.length()) + .append(currentValue) + .append(COMMA) + .append(value); + } + + /** + * Escapes comma separated values (CSV). 
+ * + * @param The type that a concrete implementation handles + */ + private interface CsvValueEscaper { + /** + * Appends the value to the specified {@link StringBuilder}, escaping if necessary. + * + * @param name the name of the header for the value being escaped + * @param value the value to be appended, escaped if necessary + */ + CharSequence escape(CharSequence name, T value); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java new file mode 100644 index 0000000..8ceaf4b --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/ComposedLastHttpContent.java @@ -0,0 +1,119 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.DecoderResult; + + +final class ComposedLastHttpContent implements LastHttpContent { + private final HttpHeaders trailingHeaders; + private DecoderResult result; + + ComposedLastHttpContent(HttpHeaders trailingHeaders) { + this.trailingHeaders = trailingHeaders; + } + + ComposedLastHttpContent(HttpHeaders trailingHeaders, DecoderResult result) { + this(trailingHeaders); + this.result = result; + } + + @Override + public HttpHeaders trailingHeaders() { + return trailingHeaders; + } + + @Override + public LastHttpContent copy() { + LastHttpContent content = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER); + content.trailingHeaders().set(trailingHeaders()); + return content; + } + + @Override + public LastHttpContent duplicate() { + return copy(); + } + + @Override + public LastHttpContent retainedDuplicate() { + return copy(); + } + + @Override + public LastHttpContent replace(ByteBuf content) { + final LastHttpContent dup = new DefaultLastHttpContent(content); + dup.trailingHeaders().setAll(trailingHeaders()); + return dup; + } + + @Override + public LastHttpContent retain(int increment) { + return this; + } + + @Override + public LastHttpContent retain() { + return this; + } + + @Override + public LastHttpContent touch() { + return this; + } + + @Override + public LastHttpContent touch(Object hint) { + return this; + } + + @Override + public ByteBuf content() { + return Unpooled.EMPTY_BUFFER; + } + + @Override + public DecoderResult decoderResult() { + return result; + } + + @Override + public DecoderResult getDecoderResult() { + return decoderResult(); + } + + @Override + public void setDecoderResult(DecoderResult result) { + this.result = result; + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { 
+ return false; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CompressionEncoderFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CompressionEncoderFactory.java new file mode 100644 index 0000000..f3d361d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CompressionEncoderFactory.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.MessageToByteEncoder; + +/** + * Compression Encoder Factory for create {@link MessageToByteEncoder} + * used to compress http content + */ +interface CompressionEncoderFactory { + MessageToByteEncoder createEncoder(); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/Cookie.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/Cookie.java new file mode 100644 index 0000000..e192e57 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/Cookie.java @@ -0,0 +1,221 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import java.util.Set; + +/** + * An interface defining an + * HTTP cookie. + * @deprecated Use {@link io.netty.handler.codec.http.cookie.Cookie} instead. + */ +@Deprecated +public interface Cookie extends io.netty.handler.codec.http.cookie.Cookie { + + /** + * @deprecated Use {@link #name()} instead. + */ + @Deprecated + String getName(); + + /** + * @deprecated Use {@link #value()} instead. + */ + @Deprecated + String getValue(); + + /** + * @deprecated Use {@link #domain()} instead. + */ + @Deprecated + String getDomain(); + + /** + * @deprecated Use {@link #path()} instead. + */ + @Deprecated + String getPath(); + + /** + * @deprecated Use {@link #comment()} instead. + */ + @Deprecated + String getComment(); + + /** + * Returns the comment of this {@link Cookie}. + * + * @return The comment of this {@link Cookie} + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + String comment(); + + /** + * Sets the comment of this {@link Cookie}. + * + * @param comment The comment to use + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + void setComment(String comment); + + /** + * @deprecated Use {@link #maxAge()} instead. + */ + @Deprecated + long getMaxAge(); + + /** + * Returns the maximum age of this {@link Cookie} in seconds or {@link Long#MIN_VALUE} if unspecified + * + * @return The maximum age of this {@link Cookie} + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + @Override + long maxAge(); + + /** + * Sets the maximum age of this {@link Cookie} in seconds. 
+ * If an age of {@code 0} is specified, this {@link Cookie} will be + * automatically removed by browser because it will expire immediately. + * If {@link Long#MIN_VALUE} is specified, this {@link Cookie} will be removed when the + * browser is closed. + * + * @param maxAge The maximum age of this {@link Cookie} in seconds + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + @Override + void setMaxAge(long maxAge); + + /** + * @deprecated Use {@link #version()} instead. + */ + @Deprecated + int getVersion(); + + /** + * Returns the version of this {@link Cookie}. + * + * @return The version of this {@link Cookie} + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + int version(); + + /** + * Sets the version of this {@link Cookie}. + * + * @param version The new version to use + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + void setVersion(int version); + + /** + * @deprecated Use {@link #commentUrl()} instead. + */ + @Deprecated + String getCommentUrl(); + + /** + * Returns the comment URL of this {@link Cookie}. + * + * @return The comment URL of this {@link Cookie} + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + String commentUrl(); + + /** + * Sets the comment URL of this {@link Cookie}. + * + * @param commentUrl The comment URL to use + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + void setCommentUrl(String commentUrl); + + /** + * Checks to see if this {@link Cookie} is to be discarded by the browser + * at the end of the current session. + * + * @return True if this {@link Cookie} is to be discarded, otherwise false + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + boolean isDiscard(); + + /** + * Sets the discard flag of this {@link Cookie}. 
+ * If set to true, this {@link Cookie} will be discarded by the browser + * at the end of the current session + * + * @param discard True if the {@link Cookie} is to be discarded + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + void setDiscard(boolean discard); + + /** + * @deprecated Use {@link #ports()} instead. + */ + @Deprecated + Set getPorts(); + + /** + * Returns the ports that this {@link Cookie} can be accessed on. + * + * @return The {@link Set} of ports that this {@link Cookie} can use + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + Set ports(); + + /** + * Sets the ports that this {@link Cookie} can be accessed on. + * + * @param ports The ports that this {@link Cookie} can be accessed on + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + void setPorts(int... ports); + + /** + * Sets the ports that this {@link Cookie} can be accessed on. + * + * @param ports The {@link Iterable} collection of ports that this + * {@link Cookie} can be accessed on. + * + * @deprecated Not part of RFC6265 + */ + @Deprecated + void setPorts(Iterable ports); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieDecoder.java new file mode 100644 index 0000000..8a35318 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieDecoder.java @@ -0,0 +1,369 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import static io.netty.handler.codec.http.CookieUtil.firstInvalidCookieNameOctet; +import static io.netty.handler.codec.http.CookieUtil.firstInvalidCookieValueOctet; +import static io.netty.handler.codec.http.CookieUtil.unwrapValue; + +import io.netty.handler.codec.DateFormatter; +import io.netty.handler.codec.http.cookie.CookieHeaderNames; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +/** + * @deprecated Use {@link io.netty.handler.codec.http.cookie.ClientCookieDecoder} + * or {@link io.netty.handler.codec.http.cookie.ServerCookieDecoder} instead. + * + * Decodes an HTTP header value into {@link Cookie}s. This decoder can decode + * the HTTP cookie version 0, 1, and 2. + * + *
+ * {@link HttpRequest} req = ...;
+ * String value = req.getHeader("Cookie");
+ * Set<{@link Cookie}> cookies = {@link CookieDecoder}.decode(value);
+ * 
+ * + * @see io.netty.handler.codec.http.cookie.ClientCookieDecoder + * @see io.netty.handler.codec.http.cookie.ServerCookieDecoder + */ +@Deprecated +public final class CookieDecoder { + + private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass()); + + private static final CookieDecoder STRICT = new CookieDecoder(true); + + private static final CookieDecoder LAX = new CookieDecoder(false); + + private static final String COMMENT = "Comment"; + + private static final String COMMENTURL = "CommentURL"; + + private static final String DISCARD = "Discard"; + + private static final String PORT = "Port"; + + private static final String VERSION = "Version"; + + private final boolean strict; + + public static Set decode(String header) { + return decode(header, true); + } + + public static Set decode(String header, boolean strict) { + return (strict ? STRICT : LAX).doDecode(header); + } + + /** + * Decodes the specified HTTP header value into {@link Cookie}s. + * + * @return the decoded {@link Cookie}s + */ + private Set doDecode(String header) { + List names = new ArrayList(8); + List values = new ArrayList(8); + extractKeyValuePairs(header, names, values); + + if (names.isEmpty()) { + return Collections.emptySet(); + } + + int i; + int version = 0; + + // $Version is the only attribute that can appear before the actual + // cookie name-value pair. + if (names.get(0).equalsIgnoreCase(VERSION)) { + try { + version = Integer.parseInt(values.get(0)); + } catch (NumberFormatException e) { + // Ignore. + } + i = 1; + } else { + i = 0; + } + + if (names.size() <= i) { + // There's a version attribute, but nothing more. 
+ return Collections.emptySet(); + } + + Set cookies = new TreeSet(); + for (; i < names.size(); i ++) { + String name = names.get(i); + String value = values.get(i); + if (value == null) { + value = ""; + } + + Cookie c = initCookie(name, value); + + if (c == null) { + break; + } + + boolean discard = false; + boolean secure = false; + boolean httpOnly = false; + String comment = null; + String commentURL = null; + String domain = null; + String path = null; + long maxAge = Long.MIN_VALUE; + List ports = new ArrayList(2); + + for (int j = i + 1; j < names.size(); j++, i++) { + name = names.get(j); + value = values.get(j); + + if (DISCARD.equalsIgnoreCase(name)) { + discard = true; + } else if (CookieHeaderNames.SECURE.equalsIgnoreCase(name)) { + secure = true; + } else if (CookieHeaderNames.HTTPONLY.equalsIgnoreCase(name)) { + httpOnly = true; + } else if (COMMENT.equalsIgnoreCase(name)) { + comment = value; + } else if (COMMENTURL.equalsIgnoreCase(name)) { + commentURL = value; + } else if (CookieHeaderNames.DOMAIN.equalsIgnoreCase(name)) { + domain = value; + } else if (CookieHeaderNames.PATH.equalsIgnoreCase(name)) { + path = value; + } else if (CookieHeaderNames.EXPIRES.equalsIgnoreCase(name)) { + Date date = DateFormatter.parseHttpDate(value); + if (date != null) { + long maxAgeMillis = date.getTime() - System.currentTimeMillis(); + maxAge = maxAgeMillis / 1000 + (maxAgeMillis % 1000 != 0? 1 : 0); + } + } else if (CookieHeaderNames.MAX_AGE.equalsIgnoreCase(name)) { + maxAge = Integer.parseInt(value); + } else if (VERSION.equalsIgnoreCase(name)) { + version = Integer.parseInt(value); + } else if (PORT.equalsIgnoreCase(name)) { + String[] portList = value.split(","); + for (String s1: portList) { + try { + ports.add(Integer.valueOf(s1)); + } catch (NumberFormatException e) { + // Ignore. 
+ } + } + } else { + break; + } + } + + c.setVersion(version); + c.setMaxAge(maxAge); + c.setPath(path); + c.setDomain(domain); + c.setSecure(secure); + c.setHttpOnly(httpOnly); + if (version > 0) { + c.setComment(comment); + } + if (version > 1) { + c.setCommentUrl(commentURL); + c.setPorts(ports); + c.setDiscard(discard); + } + + cookies.add(c); + } + + return cookies; + } + + private static void extractKeyValuePairs( + final String header, final List names, final List values) { + final int headerLen = header.length(); + loop: for (int i = 0;;) { + + // Skip spaces and separators. + for (;;) { + if (i == headerLen) { + break loop; + } + switch (header.charAt(i)) { + case '\t': case '\n': case 0x0b: case '\f': case '\r': + case ' ': case ',': case ';': + i ++; + continue; + } + break; + } + + // Skip '$'. + for (;;) { + if (i == headerLen) { + break loop; + } + if (header.charAt(i) == '$') { + i ++; + continue; + } + break; + } + + String name; + String value; + + if (i == headerLen) { + name = null; + value = null; + } else { + int newNameStart = i; + keyValLoop: for (;;) { + switch (header.charAt(i)) { + case ';': + // NAME; (no value till ';') + name = header.substring(newNameStart, i); + value = null; + break keyValLoop; + case '=': + // NAME=VALUE + name = header.substring(newNameStart, i); + i ++; + if (i == headerLen) { + // NAME= (empty value, i.e. nothing after '=') + value = ""; + break keyValLoop; + } + + int newValueStart = i; + char c = header.charAt(i); + if (c == '"' || c == '\'') { + // NAME="VALUE" or NAME='VALUE' + StringBuilder newValueBuf = new StringBuilder(header.length() - i); + final char q = c; + boolean hadBackslash = false; + i ++; + for (;;) { + if (i == headerLen) { + value = newValueBuf.toString(); + break keyValLoop; + } + if (hadBackslash) { + hadBackslash = false; + c = header.charAt(i ++); + switch (c) { + case '\\': case '"': case '\'': + // Escape last backslash. 
+ newValueBuf.setCharAt(newValueBuf.length() - 1, c); + break; + default: + // Do not escape last backslash. + newValueBuf.append(c); + } + } else { + c = header.charAt(i ++); + if (c == q) { + value = newValueBuf.toString(); + break keyValLoop; + } + newValueBuf.append(c); + if (c == '\\') { + hadBackslash = true; + } + } + } + } else { + // NAME=VALUE; + int semiPos = header.indexOf(';', i); + if (semiPos > 0) { + value = header.substring(newValueStart, semiPos); + i = semiPos; + } else { + value = header.substring(newValueStart); + i = headerLen; + } + } + break keyValLoop; + default: + i ++; + } + + if (i == headerLen) { + // NAME (no value till the end of string) + name = header.substring(newNameStart); + value = null; + break; + } + } + } + + names.add(name); + values.add(value); + } + } + + private CookieDecoder(boolean strict) { + this.strict = strict; + } + + private DefaultCookie initCookie(String name, String value) { + if (name == null || name.length() == 0) { + logger.debug("Skipping cookie with null name"); + return null; + } + + if (value == null) { + logger.debug("Skipping cookie with null value"); + return null; + } + + CharSequence unwrappedValue = unwrapValue(value); + if (unwrappedValue == null) { + logger.debug("Skipping cookie because starting quotes are not properly balanced in '{}'", + unwrappedValue); + return null; + } + + int invalidOctetPos; + if (strict && (invalidOctetPos = firstInvalidCookieNameOctet(name)) >= 0) { + if (logger.isDebugEnabled()) { + logger.debug("Skipping cookie because name '{}' contains invalid char '{}'", + name, name.charAt(invalidOctetPos)); + } + return null; + } + + final boolean wrap = unwrappedValue.length() != value.length(); + + if (strict && (invalidOctetPos = firstInvalidCookieValueOctet(unwrappedValue)) >= 0) { + if (logger.isDebugEnabled()) { + logger.debug("Skipping cookie because value '{}' contains invalid char '{}'", + unwrappedValue, unwrappedValue.charAt(invalidOctetPos)); + } + return null; + } + 
+ DefaultCookie cookie = new DefaultCookie(name, unwrappedValue.toString()); + cookie.setWrap(wrap); + return cookie; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieUtil.java new file mode 100644 index 0000000..a58af9a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/CookieUtil.java @@ -0,0 +1,104 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import java.util.BitSet; + +/** + * @deprecated Duplicate of package private ${@link io.netty.handler.codec.http.cookie.CookieUtil} + */ +@Deprecated +final class CookieUtil { + + private static final BitSet VALID_COOKIE_VALUE_OCTETS = validCookieValueOctets(); + + private static final BitSet VALID_COOKIE_NAME_OCTETS = validCookieNameOctets(VALID_COOKIE_VALUE_OCTETS); + + // US-ASCII characters excluding CTLs, whitespace, DQUOTE, comma, semicolon, and backslash + private static BitSet validCookieValueOctets() { + BitSet bits = new BitSet(8); + for (int i = 35; i < 127; i++) { + // US-ASCII characters excluding CTLs (%x00-1F / %x7F) + bits.set(i); + } + bits.set('"', false); // exclude DQUOTE = %x22 + bits.set(',', false); // exclude comma = %x2C + bits.set(';', false); // exclude semicolon = %x3B + bits.set('\\', false); // exclude backslash = %x5C + return bits; + } + + // token = 1* + // separators = "(" | ")" | "<" | ">" | "@" + // | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" 
| "=" + // | "{" | "}" | SP | HT + private static BitSet validCookieNameOctets(BitSet validCookieValueOctets) { + BitSet bits = new BitSet(8); + bits.or(validCookieValueOctets); + bits.set('(', false); + bits.set(')', false); + bits.set('<', false); + bits.set('>', false); + bits.set('@', false); + bits.set(':', false); + bits.set('/', false); + bits.set('[', false); + bits.set(']', false); + bits.set('?', false); + bits.set('=', false); + bits.set('{', false); + bits.set('}', false); + bits.set(' ', false); + bits.set('\t', false); + return bits; + } + + static int firstInvalidCookieNameOctet(CharSequence cs) { + return firstInvalidOctet(cs, VALID_COOKIE_NAME_OCTETS); + } + + static int firstInvalidCookieValueOctet(CharSequence cs) { + return firstInvalidOctet(cs, VALID_COOKIE_VALUE_OCTETS); + } + + static int firstInvalidOctet(CharSequence cs, BitSet bits) { + for (int i = 0; i < cs.length(); i++) { + char c = cs.charAt(i); + if (!bits.get(c)) { + return i; + } + } + return -1; + } + + static CharSequence unwrapValue(CharSequence cs) { + final int len = cs.length(); + if (len > 0 && cs.charAt(0) == '"') { + if (len >= 2 && cs.charAt(len - 1) == '"') { + // properly balanced + return len == 2 ? "" : cs.subSequence(1, len - 1); + } else { + return null; + } + } + return cs; + } + + private CookieUtil() { + // Unused + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultCookie.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultCookie.java new file mode 100644 index 0000000..ac640e3 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultCookie.java @@ -0,0 +1,195 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.internal.ObjectUtil; + +import java.util.Collections; +import java.util.Set; +import java.util.TreeSet; + +/** + * The default {@link Cookie} implementation. + * + * @deprecated Use {@link io.netty.handler.codec.http.cookie.DefaultCookie} instead. + */ +@Deprecated +public class DefaultCookie extends io.netty.handler.codec.http.cookie.DefaultCookie implements Cookie { + + private String comment; + private String commentUrl; + private boolean discard; + private Set ports = Collections.emptySet(); + private Set unmodifiablePorts = ports; + private int version; + + /** + * Creates a new cookie with the specified name and value. 
+ */ + public DefaultCookie(String name, String value) { + super(name, value); + } + + @Override + @Deprecated + public String getName() { + return name(); + } + + @Override + @Deprecated + public String getValue() { + return value(); + } + + @Override + @Deprecated + public String getDomain() { + return domain(); + } + + @Override + @Deprecated + public String getPath() { + return path(); + } + + @Override + @Deprecated + public String getComment() { + return comment(); + } + + @Override + @Deprecated + public String comment() { + return comment; + } + + @Override + @Deprecated + public void setComment(String comment) { + this.comment = validateValue("comment", comment); + } + + @Override + @Deprecated + public String getCommentUrl() { + return commentUrl(); + } + + @Override + @Deprecated + public String commentUrl() { + return commentUrl; + } + + @Override + @Deprecated + public void setCommentUrl(String commentUrl) { + this.commentUrl = validateValue("commentUrl", commentUrl); + } + + @Override + @Deprecated + public boolean isDiscard() { + return discard; + } + + @Override + @Deprecated + public void setDiscard(boolean discard) { + this.discard = discard; + } + + @Override + @Deprecated + public Set getPorts() { + return ports(); + } + + @Override + @Deprecated + public Set ports() { + if (unmodifiablePorts == null) { + unmodifiablePorts = Collections.unmodifiableSet(ports); + } + return unmodifiablePorts; + } + + @Override + @Deprecated + public void setPorts(int... 
ports) { + ObjectUtil.checkNotNull(ports, "ports"); + + int[] portsCopy = ports.clone(); + if (portsCopy.length == 0) { + unmodifiablePorts = this.ports = Collections.emptySet(); + } else { + Set newPorts = new TreeSet(); + for (int p: portsCopy) { + if (p <= 0 || p > 65535) { + throw new IllegalArgumentException("port out of range: " + p); + } + newPorts.add(Integer.valueOf(p)); + } + this.ports = newPorts; + unmodifiablePorts = null; + } + } + + @Override + @Deprecated + public void setPorts(Iterable ports) { + Set newPorts = new TreeSet(); + for (int p: ports) { + if (p <= 0 || p > 65535) { + throw new IllegalArgumentException("port out of range: " + p); + } + newPorts.add(Integer.valueOf(p)); + } + if (newPorts.isEmpty()) { + unmodifiablePorts = this.ports = Collections.emptySet(); + } else { + this.ports = newPorts; + unmodifiablePorts = null; + } + } + + @Override + @Deprecated + public long getMaxAge() { + return maxAge(); + } + + @Override + @Deprecated + public int getVersion() { + return version(); + } + + @Override + @Deprecated + public int version() { + return version; + } + + @Override + @Deprecated + public void setVersion(int version) { + this.version = version; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java new file mode 100644 index 0000000..3cd8d0c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpRequest.java @@ -0,0 +1,228 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; + +import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.headersFactory; +import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.trailersFactory; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Default implementation of {@link FullHttpRequest}. + */ +public class DefaultFullHttpRequest extends DefaultHttpRequest implements FullHttpRequest { + private final ByteBuf content; + private final HttpHeaders trailingHeader; + + /** + * Used to cache the value of the hash code and avoid {@link IllegalReferenceCountException}. + */ + private int hash; + + /** + * Create a full HTTP response with the given HTTP version, method, and URI. + */ + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri) { + this(httpVersion, method, uri, Unpooled.buffer(0), headersFactory(), trailersFactory()); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, and contents. + */ + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, ByteBuf content) { + this(httpVersion, method, uri, content, headersFactory(), trailersFactory()); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, and optional validation. 
+ * @deprecated Use the {@link #DefaultFullHttpRequest(HttpVersion, HttpMethod, String, ByteBuf, + * HttpHeadersFactory, HttpHeadersFactory)} constructor instead. + */ + @Deprecated + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, boolean validateHeaders) { + this(httpVersion, method, uri, Unpooled.buffer(0), + headersFactory().withValidation(validateHeaders), + trailersFactory().withValidation(validateHeaders)); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, contents, and optional validation. + * @deprecated Use the {@link #DefaultFullHttpRequest(HttpVersion, HttpMethod, String, ByteBuf, + * HttpHeadersFactory, HttpHeadersFactory)} constructor instead. + */ + @Deprecated + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, boolean validateHeaders) { + this(httpVersion, method, uri, content, + headersFactory().withValidation(validateHeaders), + trailersFactory().withValidation(validateHeaders)); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, contents, + * and factories for creating headers and trailers. + *

+ * The recommended default header factory is {@link DefaultHttpHeadersFactory#headersFactory()}, + * and the recommended default trailer factory is {@link DefaultHttpHeadersFactory#trailersFactory()}. + */ + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) { + this(httpVersion, method, uri, content, headersFactory.newHeaders(), trailersFactory.newHeaders()); + } + + /** + * Create a full HTTP response with the given HTTP version, method, URI, contents, and header and trailer objects. + */ + public DefaultFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeader) { + super(httpVersion, method, uri, headers); + this.content = checkNotNull(content, "content"); + this.trailingHeader = checkNotNull(trailingHeader, "trailingHeader"); + } + + @Override + public HttpHeaders trailingHeaders() { + return trailingHeader; + } + + @Override + public ByteBuf content() { + return content; + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public FullHttpRequest retain() { + content.retain(); + return this; + } + + @Override + public FullHttpRequest retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public FullHttpRequest touch() { + content.touch(); + return this; + } + + @Override + public FullHttpRequest touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public FullHttpRequest setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public FullHttpRequest setMethod(HttpMethod method) { + super.setMethod(method); + return this; + } + + @Override + 
public FullHttpRequest setUri(String uri) { + super.setUri(uri); + return this; + } + + @Override + public FullHttpRequest copy() { + return replace(content().copy()); + } + + @Override + public FullHttpRequest duplicate() { + return replace(content().duplicate()); + } + + @Override + public FullHttpRequest retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public FullHttpRequest replace(ByteBuf content) { + FullHttpRequest request = new DefaultFullHttpRequest(protocolVersion(), method(), uri(), content, + headers().copy(), trailingHeaders().copy()); + request.setDecoderResult(decoderResult()); + return request; + } + + @Override + public int hashCode() { + int hash = this.hash; + if (hash == 0) { + if (ByteBufUtil.isAccessible(content())) { + try { + hash = 31 + content().hashCode(); + } catch (IllegalReferenceCountException ignored) { + // Handle race condition between checking refCnt() == 0 and using the object. + hash = 31; + } + } else { + hash = 31; + } + hash = 31 * hash + trailingHeaders().hashCode(); + hash = 31 * hash + super.hashCode(); + this.hash = hash; + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultFullHttpRequest)) { + return false; + } + + DefaultFullHttpRequest other = (DefaultFullHttpRequest) o; + + return super.equals(other) && + content().equals(other.content()) && + trailingHeaders().equals(other.trailingHeaders()); + } + + @Override + public String toString() { + return HttpMessageUtil.appendFullRequest(new StringBuilder(256), this).toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpResponse.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpResponse.java new file mode 100644 index 0000000..aa1e5fa --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultFullHttpResponse.java @@ -0,0 +1,255 @@ +/* + * Copyright 
2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.IllegalReferenceCountException; + +import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.headersFactory; +import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.trailersFactory; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Default implementation of a {@link FullHttpResponse}. + */ +public class DefaultFullHttpResponse extends DefaultHttpResponse implements FullHttpResponse { + + private final ByteBuf content; + private final HttpHeaders trailingHeaders; + + /** + * Used to cache the value of the hash code and avoid {@link IllegalReferenceCountException}. + */ + private int hash; + + /** + * Create an empty HTTP response with the given HTTP version and status. + */ + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status) { + this(version, status, Unpooled.buffer(0), headersFactory(), trailersFactory()); + } + + /** + * Create an HTTP response with the given HTTP version, status, and contents. 
+ */ + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, ByteBuf content) { + this(version, status, content, headersFactory(), trailersFactory()); + } + + /** + * Create an empty HTTP response with the given HTTP version, status, and optional header validation. + * + * @deprecated Prefer the {@link #DefaultFullHttpResponse(HttpVersion, HttpResponseStatus, ByteBuf, + * HttpHeadersFactory, HttpHeadersFactory)} constructor instead. + */ + @Deprecated + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, boolean validateHeaders) { + this(version, status, Unpooled.buffer(0), + headersFactory().withValidation(validateHeaders), + trailersFactory().withValidation(validateHeaders)); + } + + /** + * Create an empty HTTP response with the given HTTP version, status, optional header validation, + * and optional header combining. + * + * @deprecated Prefer the {@link #DefaultFullHttpResponse(HttpVersion, HttpResponseStatus, ByteBuf, + * HttpHeadersFactory, HttpHeadersFactory)} constructor instead. + */ + @Deprecated + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, boolean validateHeaders, + boolean singleFieldHeaders) { + this(version, status, Unpooled.buffer(0), + headersFactory().withValidation(validateHeaders).withCombiningHeaders(singleFieldHeaders), + trailersFactory().withValidation(validateHeaders).withCombiningHeaders(singleFieldHeaders)); + } + + /** + * Create an HTTP response with the given HTTP version, status, contents, and optional header validation. + * + * @deprecated Prefer the {@link #DefaultFullHttpResponse(HttpVersion, HttpResponseStatus, ByteBuf, + * HttpHeadersFactory, HttpHeadersFactory)} constructor instead. 
+ */ + @Deprecated + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, + ByteBuf content, boolean validateHeaders) { + this(version, status, content, + headersFactory().withValidation(validateHeaders), + trailersFactory().withValidation(validateHeaders)); + } + + /** + * Create an HTTP response with the given HTTP version, status, contents, optional header validation, + * and optional header combining. + * + * @deprecated Prefer the {@link #DefaultFullHttpResponse(HttpVersion, HttpResponseStatus, ByteBuf, + * HttpHeadersFactory, HttpHeadersFactory)} constructor instead. + */ + @Deprecated + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, + ByteBuf content, boolean validateHeaders, boolean singleFieldHeaders) { + this(version, status, content, + headersFactory().withValidation(validateHeaders).withCombiningHeaders(singleFieldHeaders), + trailersFactory().withValidation(validateHeaders).withCombiningHeaders(singleFieldHeaders)); + } + + /** + * Create an HTTP response with the given HTTP version, status, contents, + * and with headers and trailers created by the given header factories. + *
<p>
+ * The recommended header factory is {@link DefaultHttpHeadersFactory#headersFactory()}, + * and the recommended trailer factory is {@link DefaultHttpHeadersFactory#trailersFactory()}. + */ + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, ByteBuf content, + HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) { + this(version, status, content, headersFactory.newHeaders(), trailersFactory.newHeaders()); + } + + /** + * Create an HTTP response with the given HTTP version, status, contents, headers and trailers. + */ + public DefaultFullHttpResponse(HttpVersion version, HttpResponseStatus status, + ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeaders) { + super(version, status, headers); + this.content = checkNotNull(content, "content"); + this.trailingHeaders = checkNotNull(trailingHeaders, "trailingHeaders"); + } + + @Override + public HttpHeaders trailingHeaders() { + return trailingHeaders; + } + + @Override + public ByteBuf content() { + return content; + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public FullHttpResponse retain() { + content.retain(); + return this; + } + + @Override + public FullHttpResponse retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public FullHttpResponse touch() { + content.touch(); + return this; + } + + @Override + public FullHttpResponse touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public FullHttpResponse setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public FullHttpResponse setStatus(HttpResponseStatus status) { + super.setStatus(status); + return this; + } + + @Override + public FullHttpResponse copy() { + return 
replace(content().copy()); + } + + @Override + public FullHttpResponse duplicate() { + return replace(content().duplicate()); + } + + @Override + public FullHttpResponse retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public FullHttpResponse replace(ByteBuf content) { + FullHttpResponse response = new DefaultFullHttpResponse(protocolVersion(), status(), content, + headers().copy(), trailingHeaders().copy()); + response.setDecoderResult(decoderResult()); + return response; + } + + @Override + public int hashCode() { + int hash = this.hash; + if (hash == 0) { + if (ByteBufUtil.isAccessible(content())) { + try { + hash = 31 + content().hashCode(); + } catch (IllegalReferenceCountException ignored) { + // Handle race condition between checking refCnt() == 0 and using the object. + hash = 31; + } + } else { + hash = 31; + } + hash = 31 * hash + trailingHeaders().hashCode(); + hash = 31 * hash + super.hashCode(); + this.hash = hash; + } + return hash; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultFullHttpResponse)) { + return false; + } + + DefaultFullHttpResponse other = (DefaultFullHttpResponse) o; + + return super.equals(other) && + content().equals(other.content()) && + trailingHeaders().equals(other.trailingHeaders()); + } + + @Override + public String toString() { + return HttpMessageUtil.appendFullResponse(new StringBuilder(256), this).toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpContent.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpContent.java new file mode 100644 index 0000000..032a4cb --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpContent.java @@ -0,0 +1,105 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this 
file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +/** + * The default {@link HttpContent} implementation. + */ +public class DefaultHttpContent extends DefaultHttpObject implements HttpContent { + + private final ByteBuf content; + + /** + * Creates a new instance with the specified chunk content. + */ + public DefaultHttpContent(ByteBuf content) { + this.content = ObjectUtil.checkNotNull(content, "content"); + } + + @Override + public ByteBuf content() { + return content; + } + + @Override + public HttpContent copy() { + return replace(content.copy()); + } + + @Override + public HttpContent duplicate() { + return replace(content.duplicate()); + } + + @Override + public HttpContent retainedDuplicate() { + return replace(content.retainedDuplicate()); + } + + @Override + public HttpContent replace(ByteBuf content) { + return new DefaultHttpContent(content); + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public HttpContent retain() { + content.retain(); + return this; + } + + @Override + public HttpContent retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public HttpContent touch() { + content.touch(); + return this; + } + + @Override + public HttpContent touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public 
boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + + "(data: " + content() + ", decoderResult: " + decoderResult() + ')'; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java new file mode 100644 index 0000000..af8b334 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeaders.java @@ -0,0 +1,446 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.CharSequenceValueConverter; +import io.netty.handler.codec.DateFormatter; +import io.netty.handler.codec.DefaultHeaders; +import io.netty.handler.codec.DefaultHeaders.NameValidator; +import io.netty.handler.codec.DefaultHeaders.ValueValidator; +import io.netty.handler.codec.DefaultHeadersImpl; +import io.netty.handler.codec.HeadersUtils; +import io.netty.handler.codec.ValueConverter; + +import java.util.ArrayList; +import java.util.Calendar; +import java.util.Collections; +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; +import static io.netty.util.AsciiString.CASE_SENSITIVE_HASHER; + +/** + * Default implementation of {@link HttpHeaders}. + */ +public class DefaultHttpHeaders extends HttpHeaders { + private final DefaultHeaders headers; + + /** + * Create a new, empty HTTP headers object. + *
<p>
+ * Header names and values are validated as they are added, to ensure they are compliant with the HTTP protocol. + */ + public DefaultHttpHeaders() { + this(nameValidator(true), valueValidator(true)); + } + + /** + * Warning! Setting {@code validate} to {@code false} will mean that Netty won't + * validate & protect against user-supplied header values that are malicious. + * This can leave your server implementation vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validate Should Netty validate header values to ensure they aren't malicious. + * @deprecated Prefer using the {@link #DefaultHttpHeaders()} constructor instead, + * to always have validation enabled. + */ + @Deprecated + public DefaultHttpHeaders(boolean validate) { + this(nameValidator(validate), valueValidator(validate)); + } + + /** + * Create an HTTP headers object with the given name validator. + *
<p>
+ * Warning! It is strongly recommended that the name validator implement validation that is at least as + * strict as {@link HttpHeaderValidationUtil#validateToken(CharSequence)}. + * It is also strongly recommended that {@code validateValues} is enabled. + *
<p>
+ * Without these validations in place, your code can be susceptible to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * It is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validateValues Should Netty validate header values to ensure they aren't malicious. + * @param nameValidator The {@link NameValidator} to use, never {@code null}. + */ + protected DefaultHttpHeaders(boolean validateValues, NameValidator nameValidator) { + this(nameValidator, valueValidator(validateValues)); + } + + /** + * Create an HTTP headers object with the given name and value validators. + *
<p>
+ * Warning! It is strongly recommended that the name validator implement validation that is at least as + * strict as {@link HttpHeaderValidationUtil#validateToken(CharSequence)}. + * And that the value validator is at least as strict as + * {@link HttpHeaderValidationUtil#validateValidHeaderValue(CharSequence)}. + *
<p>
+ * Without these validations in place, your code can be susceptible to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * It is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param nameValidator The {@link NameValidator} to use, never {@code null}. + * @param valueValidator The {@link ValueValidator} to use, never {@code null}. + */ + protected DefaultHttpHeaders( + NameValidator nameValidator, + ValueValidator valueValidator) { + this(nameValidator, valueValidator, 16); + } + + /** + * Create an HTTP headers object with the given name and value validators. + *
<p>
+ * Warning! It is strongly recommended that the name validator implement validation that is at least as + * strict as {@link HttpHeaderValidationUtil#validateToken(CharSequence)}. + * And that the value validator is at least as strict as + * {@link HttpHeaderValidationUtil#validateValidHeaderValue(CharSequence)}. + *
<p>
+ * Without these validations in place, your code can be susceptible to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * It is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param nameValidator The {@link NameValidator} to use, never {@code null}. + * @param valueValidator The {@link ValueValidator} to use, never {@code null}. + * @param sizeHint A hint about the anticipated number of entries. + */ + protected DefaultHttpHeaders( + NameValidator nameValidator, + ValueValidator valueValidator, + int sizeHint) { + this(new DefaultHeadersImpl( + CASE_INSENSITIVE_HASHER, + HeaderValueConverter.INSTANCE, + nameValidator, + sizeHint, + valueValidator)); + } + + protected DefaultHttpHeaders(DefaultHeaders headers) { + this.headers = headers; + } + + @Override + public HttpHeaders add(HttpHeaders headers) { + if (headers instanceof DefaultHttpHeaders) { + this.headers.add(((DefaultHttpHeaders) headers).headers); + return this; + } else { + return super.add(headers); + } + } + + @Override + public HttpHeaders set(HttpHeaders headers) { + if (headers instanceof DefaultHttpHeaders) { + this.headers.set(((DefaultHttpHeaders) headers).headers); + return this; + } else { + return super.set(headers); + } + } + + @Override + public HttpHeaders add(String name, Object value) { + headers.addObject(name, value); + return this; + } + + @Override + public HttpHeaders add(CharSequence name, Object value) { + headers.addObject(name, value); + return this; + } + + @Override + public HttpHeaders add(String name, Iterable values) { + headers.addObject(name, values); + return this; + } + + @Override + public HttpHeaders add(CharSequence name, Iterable values) { + headers.addObject(name, values); + return this; + } + + @Override + public HttpHeaders addInt(CharSequence name, int value) { + headers.addInt(name, 
value); + return this; + } + + @Override + public HttpHeaders addShort(CharSequence name, short value) { + headers.addShort(name, value); + return this; + } + + @Override + public HttpHeaders remove(String name) { + headers.remove(name); + return this; + } + + @Override + public HttpHeaders remove(CharSequence name) { + headers.remove(name); + return this; + } + + @Override + public HttpHeaders set(String name, Object value) { + headers.setObject(name, value); + return this; + } + + @Override + public HttpHeaders set(CharSequence name, Object value) { + headers.setObject(name, value); + return this; + } + + @Override + public HttpHeaders set(String name, Iterable values) { + headers.setObject(name, values); + return this; + } + + @Override + public HttpHeaders set(CharSequence name, Iterable values) { + headers.setObject(name, values); + return this; + } + + @Override + public HttpHeaders setInt(CharSequence name, int value) { + headers.setInt(name, value); + return this; + } + + @Override + public HttpHeaders setShort(CharSequence name, short value) { + headers.setShort(name, value); + return this; + } + + @Override + public HttpHeaders clear() { + headers.clear(); + return this; + } + + @Override + public String get(String name) { + return get((CharSequence) name); + } + + @Override + public String get(CharSequence name) { + return HeadersUtils.getAsString(headers, name); + } + + @Override + public Integer getInt(CharSequence name) { + return headers.getInt(name); + } + + @Override + public int getInt(CharSequence name, int defaultValue) { + return headers.getInt(name, defaultValue); + } + + @Override + public Short getShort(CharSequence name) { + return headers.getShort(name); + } + + @Override + public short getShort(CharSequence name, short defaultValue) { + return headers.getShort(name, defaultValue); + } + + @Override + public Long getTimeMillis(CharSequence name) { + return headers.getTimeMillis(name); + } + + @Override + public long 
getTimeMillis(CharSequence name, long defaultValue) { + return headers.getTimeMillis(name, defaultValue); + } + + @Override + public List getAll(String name) { + return getAll((CharSequence) name); + } + + @Override + public List getAll(CharSequence name) { + return HeadersUtils.getAllAsString(headers, name); + } + + @Override + public List> entries() { + if (isEmpty()) { + return Collections.emptyList(); + } + List> entriesConverted = new ArrayList>( + headers.size()); + for (Entry entry : this) { + entriesConverted.add(entry); + } + return entriesConverted; + } + + @Deprecated + @Override + public Iterator> iterator() { + return HeadersUtils.iteratorAsString(headers); + } + + @Override + public Iterator> iteratorCharSequence() { + return headers.iterator(); + } + + @Override + public Iterator valueStringIterator(CharSequence name) { + final Iterator itr = valueCharSequenceIterator(name); + return new Iterator() { + @Override + public boolean hasNext() { + return itr.hasNext(); + } + + @Override + public String next() { + return itr.next().toString(); + } + + @Override + public void remove() { + itr.remove(); + } + }; + } + + @Override + public Iterator valueCharSequenceIterator(CharSequence name) { + return headers.valueIterator(name); + } + + @Override + public boolean contains(String name) { + return contains((CharSequence) name); + } + + @Override + public boolean contains(CharSequence name) { + return headers.contains(name); + } + + @Override + public boolean isEmpty() { + return headers.isEmpty(); + } + + @Override + public int size() { + return headers.size(); + } + + @Override + public boolean contains(String name, String value, boolean ignoreCase) { + return contains((CharSequence) name, (CharSequence) value, ignoreCase); + } + + @Override + public boolean contains(CharSequence name, CharSequence value, boolean ignoreCase) { + return headers.contains(name, value, ignoreCase ? 
CASE_INSENSITIVE_HASHER : CASE_SENSITIVE_HASHER); + } + + @Override + public Set names() { + return HeadersUtils.namesAsString(headers); + } + + @Override + public boolean equals(Object o) { + return o instanceof DefaultHttpHeaders + && headers.equals(((DefaultHttpHeaders) o).headers, CASE_SENSITIVE_HASHER); + } + + @Override + public int hashCode() { + return headers.hashCode(CASE_SENSITIVE_HASHER); + } + + @Override + public HttpHeaders copy() { + return new DefaultHttpHeaders(headers.copy()); + } + + static ValueConverter valueConverter() { + return HeaderValueConverter.INSTANCE; + } + + static ValueValidator valueValidator(boolean validate) { + return validate ? DefaultHttpHeadersFactory.headersFactory().getValueValidator() : + DefaultHttpHeadersFactory.headersFactory().withValidation(false).getValueValidator(); + } + + static NameValidator nameValidator(boolean validate) { + return validate ? DefaultHttpHeadersFactory.headersFactory().getNameValidator() : + DefaultHttpHeadersFactory.headersFactory().withNameValidation(false).getNameValidator(); + } + + private static class HeaderValueConverter extends CharSequenceValueConverter { + static final HeaderValueConverter INSTANCE = new HeaderValueConverter(); + + @Override + public CharSequence convertObject(Object value) { + if (value instanceof CharSequence) { + return (CharSequence) value; + } + if (value instanceof Date) { + return DateFormatter.format((Date) value); + } + if (value instanceof Calendar) { + return DateFormatter.format(((Calendar) value).getTime()); + } + return value.toString(); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeadersFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeadersFactory.java new file mode 100644 index 0000000..3c7e76a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpHeadersFactory.java @@ -0,0 +1,313 @@ +/* + * Copyright 2023 
The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.DefaultHeaders.NameValidator; +import io.netty.handler.codec.DefaultHeaders.ValueValidator; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A builder of {@link HttpHeadersFactory} instances, that itself implements {@link HttpHeadersFactory}. + * The builder is immutable, and every {@code with-} method produce a new, modified instance. + *
<p>
+ * The default builder you most likely want to start with is {@link DefaultHttpHeadersFactory#headersFactory()}. + */ +public final class DefaultHttpHeadersFactory implements HttpHeadersFactory { + private static final NameValidator DEFAULT_NAME_VALIDATOR = new NameValidator() { + @Override + public void validateName(CharSequence name) { + if (name == null || name.length() == 0) { + throw new IllegalArgumentException("empty headers are not allowed [" + name + ']'); + } + int index = HttpHeaderValidationUtil.validateToken(name); + if (index != -1) { + throw new IllegalArgumentException("a header name can only contain \"token\" characters, " + + "but found invalid character 0x" + Integer.toHexString(name.charAt(index)) + + " at index " + index + " of header '" + name + "'."); + } + } + }; + private static final ValueValidator DEFAULT_VALUE_VALIDATOR = new ValueValidator() { + @Override + public void validate(CharSequence value) { + int index = HttpHeaderValidationUtil.validateValidHeaderValue(value); + if (index != -1) { + throw new IllegalArgumentException("a header value contains prohibited character 0x" + + Integer.toHexString(value.charAt(index)) + " at index " + index + '.'); + } + } + }; + private static final NameValidator DEFAULT_TRAILER_NAME_VALIDATOR = + new NameValidator() { + @Override + public void validateName(CharSequence name) { + DEFAULT_NAME_VALIDATOR.validateName(name); + if (HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(name) + || HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(name) + || HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(name)) { + throw new IllegalArgumentException("prohibited trailing header: " + name); + } + } + }; + + @SuppressWarnings("unchecked") + private static final NameValidator NO_NAME_VALIDATOR = NameValidator.NOT_NULL; + @SuppressWarnings("unchecked") + private static final ValueValidator NO_VALUE_VALIDATOR = + (ValueValidator) ValueValidator.NO_VALIDATION; + + private static final 
DefaultHttpHeadersFactory DEFAULT = + new DefaultHttpHeadersFactory(DEFAULT_NAME_VALIDATOR, DEFAULT_VALUE_VALIDATOR, false); + private static final DefaultHttpHeadersFactory DEFAULT_TRAILER = + new DefaultHttpHeadersFactory(DEFAULT_TRAILER_NAME_VALIDATOR, DEFAULT_VALUE_VALIDATOR, false); + private static final DefaultHttpHeadersFactory DEFAULT_COMBINING = + new DefaultHttpHeadersFactory(DEFAULT.nameValidator, DEFAULT.valueValidator, true); + private static final DefaultHttpHeadersFactory DEFAULT_NO_VALIDATION = + new DefaultHttpHeadersFactory(NO_NAME_VALIDATOR, NO_VALUE_VALIDATOR, false); + + private final NameValidator nameValidator; + private final ValueValidator valueValidator; + private final boolean combiningHeaders; + + /** + * Create a header builder with the given settings. + * + * @param nameValidator The name validator to use, not null. + * @param valueValidator The value validator to use, not null. + * @param combiningHeaders {@code true} if multi-valued headers should be combined into single lines. + */ + private DefaultHttpHeadersFactory( + NameValidator nameValidator, + ValueValidator valueValidator, + boolean combiningHeaders) { + this.nameValidator = checkNotNull(nameValidator, "nameValidator"); + this.valueValidator = checkNotNull(valueValidator, "valueValidator"); + this.combiningHeaders = combiningHeaders; + } + + /** + * Get the default implementation of {@link HttpHeadersFactory} for creating headers. + *
<p>
+ * This {@link DefaultHttpHeadersFactory} creates {@link HttpHeaders} instances that has the + * recommended header validation enabled. + */ + public static DefaultHttpHeadersFactory headersFactory() { + return DEFAULT; + } + + /** + * Get the default implementation of {@link HttpHeadersFactory} for creating trailers. + *
<p>
+ * This {@link DefaultHttpHeadersFactory} creates {@link HttpHeaders} instances that has the + * validation enabled that is recommended for trailers. + */ + public static DefaultHttpHeadersFactory trailersFactory() { + return DEFAULT_TRAILER; + } + + @Override + public HttpHeaders newHeaders() { + if (isCombiningHeaders()) { + return new CombinedHttpHeaders(getNameValidator(), getValueValidator()); + } + return new DefaultHttpHeaders(getNameValidator(), getValueValidator()); + } + + @Override + public HttpHeaders newEmptyHeaders() { + if (isCombiningHeaders()) { + return new CombinedHttpHeaders(getNameValidator(), getValueValidator(), 2); + } + return new DefaultHttpHeaders(getNameValidator(), getValueValidator(), 2); + } + + /** + * Create a new builder that has HTTP header name validation enabled or disabled. + *

+ * Warning! Setting {@code validation} to {@code false} will mean that Netty won't + * validate & protect against user-supplied headers that are malicious. + * This can leave your server implementation vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validation If validation should be enabled or disabled. + * @return The new builder. + */ + public DefaultHttpHeadersFactory withNameValidation(boolean validation) { + return withNameValidator(validation ? DEFAULT_NAME_VALIDATOR : NO_NAME_VALIDATOR); + } + + /** + * Create a new builder that with the given {@link NameValidator}. + *

+ * Warning! If the given validator does not check that the header names are standards compliant, Netty won't + * validate & protect against user-supplied headers that are malicious. + * This can leave your server implementation vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validator The HTTP header name validator to use. + * @return The new builder. + */ + public DefaultHttpHeadersFactory withNameValidator(NameValidator validator) { + if (nameValidator == checkNotNull(validator, "validator")) { + return this; + } + if (validator == DEFAULT_NAME_VALIDATOR && valueValidator == DEFAULT_VALUE_VALIDATOR) { + return combiningHeaders ? DEFAULT_COMBINING : DEFAULT; + } + return new DefaultHttpHeadersFactory(validator, valueValidator, combiningHeaders); + } + + /** + * Create a new builder that has HTTP header value validation enabled or disabled. + *

+ * Warning! Setting {@code validation} to {@code false} will mean that Netty won't + * validate & protect against user-supplied headers that are malicious. + * This can leave your server implementation vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validation If validation should be enabled or disabled. + * @return The new builder. + */ + public DefaultHttpHeadersFactory withValueValidation(boolean validation) { + return withValueValidator(validation ? DEFAULT_VALUE_VALIDATOR : NO_VALUE_VALIDATOR); + } + + /** + * Create a new builder that with the given {@link ValueValidator}. + *

+ * Warning! If the given validator does not check that the header values are standards compliant, Netty won't + * validate & protect against user-supplied headers that are malicious. + * This can leave your server implementation vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validator The HTTP header name validator to use. + * @return The new builder. + */ + public DefaultHttpHeadersFactory withValueValidator(ValueValidator validator) { + if (valueValidator == checkNotNull(validator, "validator")) { + return this; + } + if (nameValidator == DEFAULT_NAME_VALIDATOR && validator == DEFAULT_VALUE_VALIDATOR) { + return combiningHeaders ? DEFAULT_COMBINING : DEFAULT; + } + return new DefaultHttpHeadersFactory(nameValidator, validator, combiningHeaders); + } + + /** + * Create a new builder that has HTTP header validation enabled or disabled. + *

+ * Warning! Setting {@code validation} to {@code false} will mean that Netty won't + * validate & protect against user-supplied headers that are malicious. + * This can leave your server implementation vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied + * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters. + * + * @param validation If validation should be enabled or disabled. + * @return The new builder. + */ + public DefaultHttpHeadersFactory withValidation(boolean validation) { + if (this == DEFAULT && !validation) { + return DEFAULT_NO_VALIDATION; + } + if (this == DEFAULT_NO_VALIDATION && validation) { + return DEFAULT; + } + return withNameValidation(validation).withValueValidation(validation); + } + + /** + * Create a new builder that will build {@link HttpHeaders} objects that either combine + * multi-valued headers, or not. + * + * @param combiningHeaders {@code true} if multi-valued headers should be combined, otherwise {@code false}. + * @return The new builder. + */ + public DefaultHttpHeadersFactory withCombiningHeaders(boolean combiningHeaders) { + if (this.combiningHeaders == combiningHeaders) { + return this; + } + return new DefaultHttpHeadersFactory(nameValidator, valueValidator, combiningHeaders); + } + + /** + * Get the currently configured {@link NameValidator}. + *

+ * This method will be used by the {@link #newHeaders()} method. + * + * @return The configured name validator. + */ + public NameValidator getNameValidator() { + return nameValidator; + } + + /** + * Get the currently configured {@link ValueValidator}. + *

+ * This method will be used by the {@link #newHeaders()} method. + * + * @return The configured value validator. + */ + public ValueValidator getValueValidator() { + return valueValidator; + } + + /** + * Check whether header combining is enabled or not. + * + * @return {@code true} if header value combining is enabled, otherwise {@code false}. + */ + public boolean isCombiningHeaders() { + return combiningHeaders; + } + + /** + * Check whether header name validation is enabled. + * + * @return {@code true} if header name validation is enabled, otherwise {@code false}. + */ + public boolean isValidatingHeaderNames() { + return nameValidator != NO_NAME_VALIDATOR; + } + + /** + * Check whether header value validation is enabled. + * + * @return {@code true} if header value validation is enabled, otherwise {@code false}. + */ + public boolean isValidatingHeaderValues() { + return valueValidator != NO_VALUE_VALIDATOR; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpMessage.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpMessage.java new file mode 100644 index 0000000..9e477c4 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpMessage.java @@ -0,0 +1,107 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * The default {@link HttpMessage} implementation. + */ +public abstract class DefaultHttpMessage extends DefaultHttpObject implements HttpMessage { + private static final int HASH_CODE_PRIME = 31; + private HttpVersion version; + private final HttpHeaders headers; + + /** + * Creates a new instance. + */ + protected DefaultHttpMessage(final HttpVersion version) { + this(version, DefaultHttpHeadersFactory.headersFactory()); + } + + /** + * Creates a new instance. + *

+ * @deprecated Use the {@link #DefaultHttpMessage(HttpVersion, HttpHeadersFactory)} constructor instead, + * ideally using the {@link DefaultHttpHeadersFactory#headersFactory()}, + * or a factory that otherwise has validation enabled. + */ + @Deprecated + protected DefaultHttpMessage(final HttpVersion version, boolean validateHeaders, boolean singleFieldHeaders) { + this(version, DefaultHttpHeadersFactory.headersFactory() + .withValidation(validateHeaders) + .withCombiningHeaders(singleFieldHeaders)); + } + + /** + * Creates a new instance. + */ + protected DefaultHttpMessage(HttpVersion version, HttpHeadersFactory headersFactory) { + this(version, headersFactory.newHeaders()); + } + + /** + * Creates a new instance. + */ + protected DefaultHttpMessage(final HttpVersion version, HttpHeaders headers) { + this.version = checkNotNull(version, "version"); + this.headers = checkNotNull(headers, "headers"); + } + + @Override + public HttpHeaders headers() { + return headers; + } + + @Override + @Deprecated + public HttpVersion getProtocolVersion() { + return protocolVersion(); + } + + @Override + public HttpVersion protocolVersion() { + return version; + } + + @Override + public int hashCode() { + int result = 1; + result = HASH_CODE_PRIME * result + headers.hashCode(); + result = HASH_CODE_PRIME * result + version.hashCode(); + result = HASH_CODE_PRIME * result + super.hashCode(); + return result; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttpMessage)) { + return false; + } + + DefaultHttpMessage other = (DefaultHttpMessage) o; + + return headers().equals(other.headers()) && + protocolVersion().equals(other.protocolVersion()) && + super.equals(o); + } + + @Override + public HttpMessage setProtocolVersion(HttpVersion version) { + this.version = checkNotNull(version, "version"); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpObject.java 
b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpObject.java new file mode 100644 index 0000000..df66016 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpObject.java @@ -0,0 +1,63 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.DecoderResult; +import io.netty.util.internal.ObjectUtil; + +public class DefaultHttpObject implements HttpObject { + + private static final int HASH_CODE_PRIME = 31; + private DecoderResult decoderResult = DecoderResult.SUCCESS; + + protected DefaultHttpObject() { + // Disallow direct instantiation + } + + @Override + public DecoderResult decoderResult() { + return decoderResult; + } + + @Override + @Deprecated + public DecoderResult getDecoderResult() { + return decoderResult(); + } + + @Override + public void setDecoderResult(DecoderResult decoderResult) { + this.decoderResult = ObjectUtil.checkNotNull(decoderResult, "decoderResult"); + } + + @Override + public int hashCode() { + int result = 1; + result = HASH_CODE_PRIME * result + decoderResult.hashCode(); + return result; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttpObject)) { + return false; + } + + DefaultHttpObject other = (DefaultHttpObject) o; + + return decoderResult().equals(other.decoderResult()); 
+ } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java new file mode 100644 index 0000000..271b606 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpRequest.java @@ -0,0 +1,149 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.headersFactory; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * The default {@link HttpRequest} implementation. + */ +public class DefaultHttpRequest extends DefaultHttpMessage implements HttpRequest { + private static final int HASH_CODE_PRIME = 31; + private HttpMethod method; + private String uri; + + /** + * Creates a new instance. + * + * @param httpVersion the HTTP version of the request + * @param method the HTTP method of the request + * @param uri the URI or path of the request + */ + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri) { + this(httpVersion, method, uri, headersFactory().newHeaders()); + } + + /** + * Creates a new instance. 
+ * + * @param httpVersion the HTTP version of the request + * @param method the HTTP method of the request + * @param uri the URI or path of the request + * @param validateHeaders validate the header names and values when adding them to the {@link HttpHeaders} + * @deprecated Prefer the {@link #DefaultHttpRequest(HttpVersion, HttpMethod, String)} constructor instead, + * to always have header validation enabled. + */ + @Deprecated + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, boolean validateHeaders) { + this(httpVersion, method, uri, headersFactory().withValidation(validateHeaders)); + } + + /** + * Creates a new instance. + * + * @param httpVersion the HTTP version of the request + * @param method the HTTP method of the request + * @param uri the URI or path of the request + * @param headersFactory the {@link HttpHeadersFactory} used to create the headers for this Request. + * The recommended default is {@link DefaultHttpHeadersFactory#headersFactory()}. + */ + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + HttpHeadersFactory headersFactory) { + this(httpVersion, method, uri, headersFactory.newHeaders()); + } + + /** + * Creates a new instance. 
+ * + * @param httpVersion the HTTP version of the request + * @param method the HTTP method of the request + * @param uri the URI or path of the request + * @param headers the Headers for this Request + */ + public DefaultHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, HttpHeaders headers) { + super(httpVersion, headers); + this.method = checkNotNull(method, "method"); + this.uri = checkNotNull(uri, "uri"); + } + + @Override + @Deprecated + public HttpMethod getMethod() { + return method(); + } + + @Override + public HttpMethod method() { + return method; + } + + @Override + @Deprecated + public String getUri() { + return uri(); + } + + @Override + public String uri() { + return uri; + } + + @Override + public HttpRequest setMethod(HttpMethod method) { + this.method = checkNotNull(method, "method"); + return this; + } + + @Override + public HttpRequest setUri(String uri) { + this.uri = checkNotNull(uri, "uri"); + return this; + } + + @Override + public HttpRequest setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public int hashCode() { + int result = 1; + result = HASH_CODE_PRIME * result + method.hashCode(); + result = HASH_CODE_PRIME * result + uri.hashCode(); + result = HASH_CODE_PRIME * result + super.hashCode(); + return result; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttpRequest)) { + return false; + } + + DefaultHttpRequest other = (DefaultHttpRequest) o; + + return method().equals(other.method()) && + uri().equalsIgnoreCase(other.uri()) && + super.equals(o); + } + + @Override + public String toString() { + return HttpMessageUtil.appendRequest(new StringBuilder(256), this).toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java new file mode 100644 index 0000000..6bc7d79 
--- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultHttpResponse.java @@ -0,0 +1,145 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.internal.ObjectUtil; + +import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.headersFactory; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * The default {@link HttpResponse} implementation. + */ +public class DefaultHttpResponse extends DefaultHttpMessage implements HttpResponse { + + private HttpResponseStatus status; + + /** + * Creates a new instance. + * + * @param version the HTTP version of this response + * @param status the status of this response + */ + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status) { + this(version, status, headersFactory()); + } + + /** + * Creates a new instance. + * + * @param version the HTTP version of this response + * @param status the status of this response + * @param validateHeaders validate the header names and values when adding them to the {@link HttpHeaders} + * @deprecated Use the {@link #DefaultHttpResponse(HttpVersion, HttpResponseStatus, HttpHeadersFactory)} constructor + * instead. 
+ */ + @Deprecated + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, boolean validateHeaders) { + this(version, status, headersFactory().withValidation(validateHeaders)); + } + + /** + * Creates a new instance. + * + * @param version the HTTP version of this response + * @param status the status of this response + * @param validateHeaders validate the header names and values when adding them to the {@link HttpHeaders} + * @param singleFieldHeaders {@code true} to check and enforce that headers with the same name are appended + * to the same entry and comma separated. + * See RFC 7230, 3.2.2. + * {@code false} to allow multiple header entries with the same name to + * coexist. + * @deprecated Use the {@link #DefaultHttpResponse(HttpVersion, HttpResponseStatus, HttpHeadersFactory)} constructor + * instead. + */ + @Deprecated + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, boolean validateHeaders, + boolean singleFieldHeaders) { + this(version, status, headersFactory().withValidation(validateHeaders) + .withCombiningHeaders(singleFieldHeaders)); + } + + /** + * Creates a new instance. + * + * @param version the HTTP version of this response + * @param status the status of this response + * @param headersFactory the {@link HttpHeadersFactory} used to create the headers for this HTTP Response. + * The recommended default is {@link DefaultHttpHeadersFactory#headersFactory()}. + */ + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, HttpHeadersFactory headersFactory) { + this(version, status, headersFactory.newHeaders()); + } + + /** + * Creates a new instance. 
+ * + * @param version the HTTP version of this response + * @param status the status of this response + * @param headers the headers for this HTTP Response + */ + public DefaultHttpResponse(HttpVersion version, HttpResponseStatus status, HttpHeaders headers) { + super(version, headers); + this.status = checkNotNull(status, "status"); + } + + @Override + @Deprecated + public HttpResponseStatus getStatus() { + return status(); + } + + @Override + public HttpResponseStatus status() { + return status; + } + + @Override + public HttpResponse setStatus(HttpResponseStatus status) { + this.status = ObjectUtil.checkNotNull(status, "status"); + return this; + } + + @Override + public HttpResponse setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public String toString() { + return HttpMessageUtil.appendResponse(new StringBuilder(256), this).toString(); + } + + @Override + public int hashCode() { + int result = 1; + result = 31 * result + status.hashCode(); + result = 31 * result + super.hashCode(); + return result; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttpResponse)) { + return false; + } + + DefaultHttpResponse other = (DefaultHttpResponse) o; + + return status.equals(other.status()) && super.equals(o); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultLastHttpContent.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultLastHttpContent.java new file mode 100644 index 0000000..9965eea --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/DefaultLastHttpContent.java @@ -0,0 +1,151 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.StringUtil;

import java.util.Map.Entry;

import static io.netty.handler.codec.http.DefaultHttpHeadersFactory.trailersFactory;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * The default {@link LastHttpContent} implementation.
 */
public class DefaultLastHttpContent extends DefaultHttpContent implements LastHttpContent {
    private final HttpHeaders trailingHeaders;

    /**
     * Create a new empty, last HTTP content message.
     */
    public DefaultLastHttpContent() {
        this(Unpooled.buffer(0));
    }

    /**
     * Create a new last HTTP content message with the given contents.
     */
    public DefaultLastHttpContent(ByteBuf content) {
        this(content, trailersFactory());
    }

    /**
     * Create a new last HTTP content message with the given contents, and optional trailing header validation.
     * <p>
     * <b>Warning!</b> Setting {@code validateHeaders} to {@code false} will mean that Netty won't
     * validate &amp; protect against user-supplied header values that are malicious.
     * This can leave your server implementation vulnerable to
     * <a href="https://cwe.mitre.org/data/definitions/113.html">
     * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting')
     * </a>.
     * When disabling this validation, it is the responsibility of the caller to ensure that the values supplied
     * do not contain a non-url-escaped carriage return (CR) and/or line feed (LF) characters.
     *
     * @deprecated Prefer the {@link #DefaultLastHttpContent(ByteBuf)} constructor instead, to always have header
     * validation enabled.
     */
    @Deprecated
    public DefaultLastHttpContent(ByteBuf content, boolean validateHeaders) {
        this(content, trailersFactory().withValidation(validateHeaders));
    }

    /**
     * Create a new last HTTP content message with the given contents, and trailing headers from the given factory.
     */
    public DefaultLastHttpContent(ByteBuf content, HttpHeadersFactory trailersFactory) {
        super(content);
        trailingHeaders = trailersFactory.newHeaders();
    }

    /**
     * Create a new last HTTP content message with the given contents, and trailing headers.
     */
    public DefaultLastHttpContent(ByteBuf content, HttpHeaders trailingHeaders) {
        super(content);
        this.trailingHeaders = checkNotNull(trailingHeaders, "trailingHeaders");
    }

    @Override
    public LastHttpContent copy() {
        return replace(content().copy());
    }

    @Override
    public LastHttpContent duplicate() {
        return replace(content().duplicate());
    }

    @Override
    public LastHttpContent retainedDuplicate() {
        return replace(content().retainedDuplicate());
    }

    @Override
    public LastHttpContent replace(ByteBuf content) {
        return new DefaultLastHttpContent(content, trailingHeaders().copy());
    }

    @Override
    public LastHttpContent retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public LastHttpContent retain() {
        super.retain();
        return this;
    }

    @Override
    public LastHttpContent touch() {
        super.touch();
        return this;
    }

    @Override
    public LastHttpContent touch(Object hint) {
        super.touch(hint);
        return this;
    }

    @Override
    public HttpHeaders trailingHeaders() {
        return trailingHeaders;
    }

    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder(super.toString());
        buf.append(StringUtil.NEWLINE);
        appendHeaders(buf);

        // Remove the last newline.
        buf.setLength(buf.length() - StringUtil.NEWLINE.length());
        return buf.toString();
    }

    // Appends one "name: value" line per trailing header; used by toString() only.
    private void appendHeaders(StringBuilder buf) {
        // HttpHeaders iterates String/String entries; the type parameters were
        // missing (raw Entry) in the original and are restored here.
        for (Entry<String, String> e : trailingHeaders()) {
            buf.append(e.getKey());
            buf.append(": ");
            buf.append(e.getValue());
            buf.append(StringUtil.NEWLINE);
        }
    }
}
+ buf.setLength(buf.length() - StringUtil.NEWLINE.length()); + return buf.toString(); + } + + private void appendHeaders(StringBuilder buf) { + for (Entry e : trailingHeaders()) { + buf.append(e.getKey()); + buf.append(": "); + buf.append(e.getValue()); + buf.append(StringUtil.NEWLINE); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/EmptyHttpHeaders.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/EmptyHttpHeaders.java new file mode 100644 index 0000000..4ca26a8 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/EmptyHttpHeaders.java @@ -0,0 +1,188 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Set; + +public class EmptyHttpHeaders extends HttpHeaders { + static final Iterator> EMPTY_CHARS_ITERATOR = + Collections.>emptyList().iterator(); + + public static final EmptyHttpHeaders INSTANCE = instance(); + + /** + * @see InstanceInitializer#EMPTY_HEADERS + * @deprecated Use {@link EmptyHttpHeaders#INSTANCE} + *

+ * This is needed to break a cyclic static initialization loop between {@link HttpHeaders} and {@link + * EmptyHttpHeaders}. + */ + @Deprecated + static EmptyHttpHeaders instance() { + return InstanceInitializer.EMPTY_HEADERS; + } + + protected EmptyHttpHeaders() { + } + + @Override + public String get(String name) { + return null; + } + + @Override + public Integer getInt(CharSequence name) { + return null; + } + + @Override + public int getInt(CharSequence name, int defaultValue) { + return defaultValue; + } + + @Override + public Short getShort(CharSequence name) { + return null; + } + + @Override + public short getShort(CharSequence name, short defaultValue) { + return defaultValue; + } + + @Override + public Long getTimeMillis(CharSequence name) { + return null; + } + + @Override + public long getTimeMillis(CharSequence name, long defaultValue) { + return defaultValue; + } + + @Override + public List getAll(String name) { + return Collections.emptyList(); + } + + @Override + public List> entries() { + return Collections.emptyList(); + } + + @Override + public boolean contains(String name) { + return false; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public int size() { + return 0; + } + + @Override + public Set names() { + return Collections.emptySet(); + } + + @Override + public HttpHeaders add(String name, Object value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders add(String name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders addInt(CharSequence name, int value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders addShort(CharSequence name, short value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders set(String name, Object value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + 
public HttpHeaders set(String name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders setInt(CharSequence name, int value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders setShort(CharSequence name, short value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders remove(String name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public HttpHeaders clear() { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Iterator> iterator() { + return entries().iterator(); + } + + @Override + public Iterator> iteratorCharSequence() { + return EMPTY_CHARS_ITERATOR; + } + + /** + * This class is needed to break a cyclic static initialization loop between {@link HttpHeaders} and + * {@link EmptyHttpHeaders}. + */ + @Deprecated + private static final class InstanceInitializer { + /** + * The instance is instantiated here to break the cyclic static initialization between {@link EmptyHttpHeaders} + * and {@link HttpHeaders}. The issue is that if someone accesses {@link EmptyHttpHeaders#INSTANCE} before + * {@link HttpHeaders#EMPTY_HEADERS} then {@link HttpHeaders#EMPTY_HEADERS} will be {@code null}. 
+ */ + @Deprecated + private static final EmptyHttpHeaders EMPTY_HEADERS = new EmptyHttpHeaders(); + + private InstanceInitializer() { + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpMessage.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpMessage.java new file mode 100644 index 0000000..735f4b6 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpMessage.java @@ -0,0 +1,48 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; + +/** + * Combines {@link HttpMessage} and {@link LastHttpContent} into one + * message. So it represent a complete http message. 
+ */ +public interface FullHttpMessage extends HttpMessage, LastHttpContent { + @Override + FullHttpMessage copy(); + + @Override + FullHttpMessage duplicate(); + + @Override + FullHttpMessage retainedDuplicate(); + + @Override + FullHttpMessage replace(ByteBuf content); + + @Override + FullHttpMessage retain(int increment); + + @Override + FullHttpMessage retain(); + + @Override + FullHttpMessage touch(); + + @Override + FullHttpMessage touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpRequest.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpRequest.java new file mode 100644 index 0000000..1db1cec --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpRequest.java @@ -0,0 +1,57 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; + +/** + * Combine the {@link HttpRequest} and {@link FullHttpMessage}, so the request is a complete HTTP + * request. 
+ */ +public interface FullHttpRequest extends HttpRequest, FullHttpMessage { + @Override + FullHttpRequest copy(); + + @Override + FullHttpRequest duplicate(); + + @Override + FullHttpRequest retainedDuplicate(); + + @Override + FullHttpRequest replace(ByteBuf content); + + @Override + FullHttpRequest retain(int increment); + + @Override + FullHttpRequest retain(); + + @Override + FullHttpRequest touch(); + + @Override + FullHttpRequest touch(Object hint); + + @Override + FullHttpRequest setProtocolVersion(HttpVersion version); + + @Override + FullHttpRequest setMethod(HttpMethod method); + + @Override + FullHttpRequest setUri(String uri); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpResponse.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpResponse.java new file mode 100644 index 0000000..9f9be63 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/FullHttpResponse.java @@ -0,0 +1,54 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; + +/** + * Combination of a {@link HttpResponse} and {@link FullHttpMessage}. + * So it represent a complete http response. 
+ */ +public interface FullHttpResponse extends HttpResponse, FullHttpMessage { + @Override + FullHttpResponse copy(); + + @Override + FullHttpResponse duplicate(); + + @Override + FullHttpResponse retainedDuplicate(); + + @Override + FullHttpResponse replace(ByteBuf content); + + @Override + FullHttpResponse retain(int increment); + + @Override + FullHttpResponse retain(); + + @Override + FullHttpResponse touch(); + + @Override + FullHttpResponse touch(Object hint); + + @Override + FullHttpResponse setProtocolVersion(HttpVersion version); + + @Override + FullHttpResponse setStatus(HttpResponseStatus status); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkedInput.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkedInput.java new file mode 100644 index 0000000..3c17b20 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpChunkedInput.java @@ -0,0 +1,119 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.stream.ChunkedInput; + +/** + * A {@link ChunkedInput} that fetches data chunk by chunk for use with HTTP chunked transfers. + *

+ * Each chunk from the input data will be wrapped within a {@link HttpContent}. At the end of the input data, + * {@link LastHttpContent} will be written. + *

+ * Ensure that your HTTP response header contains {@code Transfer-Encoding: chunked}. + *

+ * <pre>
+ * public void messageReceived(ChannelHandlerContext ctx, FullHttpRequest request) throws Exception {
+ *     HttpResponse response = new DefaultHttpResponse(HTTP_1_1, OK);
+ *     response.headers().set(TRANSFER_ENCODING, CHUNKED);
+ *     ctx.write(response);
+ *
+ *     HttpChunkedInput httpChunkWriter = new HttpChunkedInput(
+ *         new ChunkedFile("/tmp/myfile.txt"));
+ *     ChannelFuture sendFileFuture = ctx.write(httpChunkWriter);
+ * }
+ * </pre>
+ */ +public class HttpChunkedInput implements ChunkedInput { + + private final ChunkedInput input; + private final LastHttpContent lastHttpContent; + private boolean sentLastChunk; + + /** + * Creates a new instance using the specified input. + * @param input {@link ChunkedInput} containing data to write + */ + public HttpChunkedInput(ChunkedInput input) { + this.input = input; + lastHttpContent = LastHttpContent.EMPTY_LAST_CONTENT; + } + + /** + * Creates a new instance using the specified input. {@code lastHttpContent} will be written as the terminating + * chunk. + * @param input {@link ChunkedInput} containing data to write + * @param lastHttpContent {@link LastHttpContent} that will be written as the terminating chunk. Use this for + * training headers. + */ + public HttpChunkedInput(ChunkedInput input, LastHttpContent lastHttpContent) { + this.input = input; + this.lastHttpContent = lastHttpContent; + } + + @Override + public boolean isEndOfInput() throws Exception { + if (input.isEndOfInput()) { + // Only end of input after last HTTP chunk has been sent + return sentLastChunk; + } else { + return false; + } + } + + @Override + public void close() throws Exception { + input.close(); + } + + @Deprecated + @Override + public HttpContent readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public HttpContent readChunk(ByteBufAllocator allocator) throws Exception { + if (input.isEndOfInput()) { + if (sentLastChunk) { + return null; + } else { + // Send last chunk for this input + sentLastChunk = true; + return lastHttpContent; + } + } else { + ByteBuf buf = input.readChunk(allocator); + if (buf == null) { + return null; + } + return new DefaultHttpContent(buf); + } + } + + @Override + public long length() { + return input.length(); + } + + @Override + public long progress() { + return input.progress(); + } +} diff --git 
a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java new file mode 100644 index 0000000..b341227 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientCodec.java @@ -0,0 +1,422 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.CombinedChannelDuplexHandler; +import io.netty.handler.codec.PrematureChannelClosureException; + +import java.util.ArrayDeque; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicLong; + +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_ALLOW_PARTIAL_CHUNKS; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_VALIDATE_HEADERS; + +/** 
+ * A combination of {@link HttpRequestEncoder} and {@link HttpResponseDecoder} + * which enables easier client side HTTP implementation. {@link HttpClientCodec} + * provides additional state management for <tt>HEAD</tt> and <tt>CONNECT</tt> + * requests, which {@link HttpResponseDecoder} lacks. Please refer to + * {@link HttpResponseDecoder} to learn what additional state management needs + * to be done for <tt>HEAD</tt> and <tt>CONNECT</tt> and why + * {@link HttpResponseDecoder} cannot handle it by itself. + *

+ * If the {@link Channel} is closed and there are missing responses, + * a {@link PrematureChannelClosureException} is thrown. + * + *

<h3>Header Validation</h3>

+ * + * It is recommended to always enable header validation. + *

+ * Without header validation, your system can become vulnerable to + * <a href="https://cwe.mitre.org/data/definitions/113.html"> + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * </a>. + *

+ * This recommendation stands even when both peers in the HTTP exchange are trusted, + * as it helps with defence-in-depth. + * + * @see HttpServerCodec + */ +public final class HttpClientCodec extends CombinedChannelDuplexHandler + implements HttpClientUpgradeHandler.SourceCodec { + public static final boolean DEFAULT_FAIL_ON_MISSING_RESPONSE = false; + public static final boolean DEFAULT_PARSE_HTTP_AFTER_CONNECT_REQUEST = false; + + /** A queue that is used for correlating a request and a response. */ + private final Queue queue = new ArrayDeque(); + private final boolean parseHttpAfterConnectRequest; + + /** If true, decoding stops (i.e. pass-through) */ + private boolean done; + + private final AtomicLong requestResponseCounter = new AtomicLong(); + private final boolean failOnMissingResponse; + + /** + * Creates a new instance with the default decoder options + * ({@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxChunkSize (8192)}). + */ + public HttpClientCodec() { + this(new HttpDecoderConfig(), + DEFAULT_PARSE_HTTP_AFTER_CONNECT_REQUEST, + DEFAULT_FAIL_ON_MISSING_RESPONSE); + } + + /** + * Creates a new instance with the specified decoder options. + */ + public HttpClientCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize), + DEFAULT_PARSE_HTTP_AFTER_CONNECT_REQUEST, + DEFAULT_FAIL_ON_MISSING_RESPONSE); + } + + /** + * Creates a new instance with the specified decoder options. 
+ */ + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize), + DEFAULT_PARSE_HTTP_AFTER_CONNECT_REQUEST, + failOnMissingResponse); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpClientCodec(int, int, int, boolean)} constructor, + * to always enable header validation. + */ + @Deprecated + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse, + boolean validateHeaders) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders), + DEFAULT_PARSE_HTTP_AFTER_CONNECT_REQUEST, + failOnMissingResponse); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpClientCodec(HttpDecoderConfig, boolean, boolean)} constructor, + * to always enable header validation. + */ + @Deprecated + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse, + boolean validateHeaders, boolean parseHttpAfterConnectRequest) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders), + parseHttpAfterConnectRequest, + failOnMissingResponse); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpClientCodec(HttpDecoderConfig, boolean, boolean)} constructor, + * to always enable header validation. 
+ */ + @Deprecated + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse, + boolean validateHeaders, int initialBufferSize) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize), + DEFAULT_PARSE_HTTP_AFTER_CONNECT_REQUEST, + failOnMissingResponse); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpClientCodec(HttpDecoderConfig, boolean, boolean)} constructor, + * to always enable header validation. + */ + @Deprecated + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse, + boolean validateHeaders, int initialBufferSize, boolean parseHttpAfterConnectRequest) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize), + parseHttpAfterConnectRequest, + failOnMissingResponse); + } + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpClientCodec(HttpDecoderConfig, boolean, boolean)} constructor, + * to always enable header validation. 
+ */ + @Deprecated + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse, + boolean validateHeaders, int initialBufferSize, boolean parseHttpAfterConnectRequest, + boolean allowDuplicateContentLengths) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize) + .setAllowDuplicateContentLengths(allowDuplicateContentLengths), + parseHttpAfterConnectRequest, + failOnMissingResponse); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpClientCodec(HttpDecoderConfig, boolean, boolean)} + * constructor, to always enable header validation. + */ + @Deprecated + public HttpClientCodec( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean failOnMissingResponse, + boolean validateHeaders, int initialBufferSize, boolean parseHttpAfterConnectRequest, + boolean allowDuplicateContentLengths, boolean allowPartialChunks) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize) + .setAllowDuplicateContentLengths(allowDuplicateContentLengths) + .setAllowPartialChunks(allowPartialChunks), + parseHttpAfterConnectRequest, + failOnMissingResponse); + } + + /** + * Creates a new instance with the specified decoder options. + */ + public HttpClientCodec( + HttpDecoderConfig config, boolean parseHttpAfterConnectRequest, boolean failOnMissingResponse) { + init(new Decoder(config), new Encoder()); + this.parseHttpAfterConnectRequest = parseHttpAfterConnectRequest; + this.failOnMissingResponse = failOnMissingResponse; + } + + /** + * Prepares to upgrade to another protocol from HTTP. 
Disables the {@link Encoder}. + */ + @Override + public void prepareUpgradeFrom(ChannelHandlerContext ctx) { + ((Encoder) outboundHandler()).upgraded = true; + } + + /** + * Upgrades to another protocol from HTTP. Removes the {@link Decoder} and {@link Encoder} from + * the pipeline. + */ + @Override + public void upgradeFrom(ChannelHandlerContext ctx) { + final ChannelPipeline p = ctx.pipeline(); + p.remove(this); + } + + public void setSingleDecode(boolean singleDecode) { + inboundHandler().setSingleDecode(singleDecode); + } + + public boolean isSingleDecode() { + return inboundHandler().isSingleDecode(); + } + + private final class Encoder extends HttpRequestEncoder { + + boolean upgraded; + + @Override + protected void encode( + ChannelHandlerContext ctx, Object msg, List out) throws Exception { + + if (upgraded) { + // HttpObjectEncoder overrides .write and does not release msg, so we don't need to retain it here + out.add(msg); + return; + } + + if (msg instanceof HttpRequest) { + queue.offer(((HttpRequest) msg).method()); + } + + super.encode(ctx, msg, out); + + if (failOnMissingResponse && !done) { + // check if the request is chunked if so do not increment + if (msg instanceof LastHttpContent) { + // increment as its the last chunk + requestResponseCounter.incrementAndGet(); + } + } + } + } + + private final class Decoder extends HttpResponseDecoder { + Decoder(HttpDecoderConfig config) { + super(config); + } + + @Override + protected void decode( + ChannelHandlerContext ctx, ByteBuf buffer, List out) throws Exception { + if (done) { + int readable = actualReadableBytes(); + if (readable == 0) { + // if non is readable just return null + // https://github.com/netty/netty/issues/1159 + return; + } + out.add(buffer.readBytes(readable)); + } else { + int oldSize = out.size(); + super.decode(ctx, buffer, out); + if (failOnMissingResponse) { + int size = out.size(); + for (int i = oldSize; i < size; i++) { + decrement(out.get(i)); + } + } + } + } + + private 
void decrement(Object msg) { + if (msg == null) { + return; + } + + // check if it's an Header and its transfer encoding is not chunked. + if (msg instanceof LastHttpContent) { + requestResponseCounter.decrementAndGet(); + } + } + + @Override + protected boolean isContentAlwaysEmpty(HttpMessage msg) { + // Get the method of the HTTP request that corresponds to the + // current response. + // + // Even if we do not use the method to compare we still need to poll it to ensure we keep + // request / response pairs in sync. + HttpMethod method = queue.poll(); + + final HttpResponseStatus status = ((HttpResponse) msg).status(); + final HttpStatusClass statusClass = status.codeClass(); + final int statusCode = status.code(); + if (statusClass == HttpStatusClass.INFORMATIONAL) { + // An informational response should be excluded from paired comparison. + // Just delegate to super method which has all the needed handling. + return super.isContentAlwaysEmpty(msg); + } + + // If the remote peer did for example send multiple responses for one request (which is not allowed per + // spec but may still be possible) method will be null so guard against it. + if (method != null) { + char firstChar = method.name().charAt(0); + switch (firstChar) { + case 'H': + // According to 4.3, RFC2616: + // All responses to the HEAD request method MUST NOT include a + // message-body, even though the presence of entity-header fields + // might lead one to believe they do. + if (HttpMethod.HEAD.equals(method)) { + return true; + + // The following code was inserted to work around the servers + // that behave incorrectly. It has been commented out + // because it does not work with well behaving servers. + // Please note, even if the 'Transfer-Encoding: chunked' + // header exists in the HEAD response, the response should + // have absolutely no content. 
+ // + //// Interesting edge case: + //// Some poorly implemented servers will send a zero-byte + //// chunk if Transfer-Encoding of the response is 'chunked'. + //// + //// return !msg.isChunked(); + } + break; + case 'C': + // Successful CONNECT request results in a response with empty body. + if (statusCode == 200) { + if (HttpMethod.CONNECT.equals(method)) { + // Proxy connection established - Parse HTTP only if configured by + // parseHttpAfterConnectRequest, else pass through. + if (!parseHttpAfterConnectRequest) { + done = true; + queue.clear(); + } + return true; + } + } + break; + default: + break; + } + } + return super.isContentAlwaysEmpty(msg); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) + throws Exception { + super.channelInactive(ctx); + + if (failOnMissingResponse) { + long missingResponses = requestResponseCounter.get(); + if (missingResponses > 0) { + ctx.fireExceptionCaught(new PrematureChannelClosureException( + "channel gone inactive with " + missingResponses + + " missing response(s)")); + } + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java new file mode 100644 index 0000000..d3f0b5a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpClientUpgradeHandler.java @@ -0,0 +1,277 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.codec.http;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelOutboundHandler;
import io.netty.channel.ChannelPromise;
import io.netty.util.AsciiString;
import io.netty.util.internal.ObjectUtil;

import java.net.SocketAddress;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set;

import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
import static io.netty.util.ReferenceCountUtil.release;

/**
 * Client-side handler for handling an HTTP upgrade handshake to another protocol. When the first
 * HTTP request is sent, this handler will add all appropriate headers to perform an upgrade to the
 * new protocol. If the upgrade fails (i.e. response is not 101 Switching Protocols), this handler
 * simply removes itself from the pipeline. If the upgrade is successful, upgrades the pipeline to
 * the new protocol.
 *
 * <p>Extends {@link HttpObjectAggregator} so that the upgrade response is seen as a single
 * {@link FullHttpResponse} in {@link #decode}, regardless of how it was chunked on the wire.
 */
public class HttpClientUpgradeHandler extends HttpObjectAggregator implements ChannelOutboundHandler {

    /**
     * User events that are fired to notify about upgrade status.
     */
    public enum UpgradeEvent {
        /**
         * The Upgrade request was sent to the server.
         */
        UPGRADE_ISSUED,

        /**
         * The Upgrade to the new protocol was successful.
         */
        UPGRADE_SUCCESSFUL,

        /**
         * The Upgrade was unsuccessful due to the server not issuing
         * with a 101 Switching Protocols response.
         */
        UPGRADE_REJECTED
    }

    /**
     * The source codec that is used in the pipeline initially.
     */
    public interface SourceCodec {

        /**
         * Removes or disables the encoder of this codec so that the {@link UpgradeCodec} can send an initial greeting
         * (if any).
         */
        void prepareUpgradeFrom(ChannelHandlerContext ctx);

        /**
         * Removes this codec (i.e. all associated handlers) from the pipeline.
         */
        void upgradeFrom(ChannelHandlerContext ctx);
    }

    /**
     * A codec that the source can be upgraded to.
     */
    public interface UpgradeCodec {
        /**
         * Returns the name of the protocol supported by this codec, as indicated by the {@code 'UPGRADE'} header.
         */
        CharSequence protocol();

        /**
         * Sets any protocol-specific headers required to the upgrade request. Returns the names of
         * all headers that were added. These headers will be used to populate the CONNECTION header.
         */
        Collection<CharSequence> setUpgradeHeaders(ChannelHandlerContext ctx, HttpRequest upgradeRequest);

        /**
         * Performs an HTTP protocol upgrade from the source codec. This method is responsible for
         * adding all handlers required for the new protocol.
         *
         * @param ctx the context for the current handler.
         * @param upgradeResponse the 101 Switching Protocols response that indicates that the server
         *            has switched to this protocol.
         */
        void upgradeTo(ChannelHandlerContext ctx, FullHttpResponse upgradeResponse) throws Exception;
    }

    // The codec in use before the upgrade; disabled/removed once the server accepts the switch.
    private final SourceCodec sourceCodec;
    // The codec to install when the server answers 101 Switching Protocols.
    private final UpgradeCodec upgradeCodec;
    // Set once the first HTTP request has been decorated with upgrade headers; only one
    // request may be in flight while the handshake is pending.
    private boolean upgradeRequested;

    /**
     * Constructs the client upgrade handler.
     *
     * @param sourceCodec the codec that is being used initially.
     * @param upgradeCodec the codec that the client would like to upgrade to.
     * @param maxContentLength the maximum length of the aggregated content.
     */
    public HttpClientUpgradeHandler(SourceCodec sourceCodec, UpgradeCodec upgradeCodec,
                                    int maxContentLength) {
        super(maxContentLength);
        this.sourceCodec = ObjectUtil.checkNotNull(sourceCodec, "sourceCodec");
        this.upgradeCodec = ObjectUtil.checkNotNull(upgradeCodec, "upgradeCodec");
    }

    // The following outbound operations are pure pass-throughs: this handler only needs to
    // intercept write() to decorate the first request with upgrade headers.

    @Override
    public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
        ctx.bind(localAddress, promise);
    }

    @Override
    public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
                        ChannelPromise promise) throws Exception {
        ctx.connect(remoteAddress, localAddress, promise);
    }

    @Override
    public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        ctx.disconnect(promise);
    }

    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        ctx.close(promise);
    }

    @Override
    public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        ctx.deregister(promise);
    }

    @Override
    public void read(ChannelHandlerContext ctx) throws Exception {
        ctx.read();
    }

    /**
     * Intercepts the first outbound {@link HttpRequest}, adds the upgrade headers to it and fires
     * {@link UpgradeEvent#UPGRADE_ISSUED}. Any further request written while the handshake is
     * still pending fails its promise with an {@link IllegalStateException}.
     */
    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise)
            throws Exception {
        if (!(msg instanceof HttpRequest)) {
            ctx.write(msg, promise);
            return;
        }

        if (upgradeRequested) {
            promise.setFailure(new IllegalStateException(
                    "Attempting to write HTTP request with upgrade in progress"));
            return;
        }

        upgradeRequested = true;
        setUpgradeRequestHeaders(ctx, (HttpRequest) msg);

        // Continue writing the request.
        ctx.write(msg, promise);

        // Notify that the upgrade request was issued.
        ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_ISSUED);
        // Now we wait for the next HTTP response to see if we switch protocols.
    }

    @Override
    public void flush(ChannelHandlerContext ctx) throws Exception {
        ctx.flush();
    }

    /**
     * Handles the server's answer to the upgrade request. A non-101 response is propagated
     * unchanged (upgrade rejected); a 101 response is aggregated, validated against the expected
     * protocol, and handed to the {@link UpgradeCodec} to rebuild the pipeline.
     */
    @Override
    protected void decode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out)
            throws Exception {
        FullHttpResponse response = null;
        try {
            if (!upgradeRequested) {
                throw new IllegalStateException("Read HTTP response without requesting protocol switch");
            }

            if (msg instanceof HttpResponse) {
                HttpResponse rep = (HttpResponse) msg;
                if (!SWITCHING_PROTOCOLS.equals(rep.status())) {
                    // The server does not support the requested protocol, just remove this handler
                    // and continue processing HTTP.
                    // NOTE: not releasing the response since we're letting it propagate to the
                    // next handler.
                    ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_REJECTED);
                    removeThisHandler(ctx);
                    ctx.fireChannelRead(msg);
                    return;
                }
            }

            if (msg instanceof FullHttpResponse) {
                response = (FullHttpResponse) msg;
                // Need to retain since the base class will release after returning from this method.
                response.retain();
                out.add(response);
            } else {
                // Call the base class to handle the aggregation of the full request.
                super.decode(ctx, msg, out);
                if (out.isEmpty()) {
                    // The full request hasn't been created yet, still awaiting more data.
                    return;
                }

                assert out.size() == 1;
                response = (FullHttpResponse) out.get(0);
            }

            CharSequence upgradeHeader = response.headers().get(HttpHeaderNames.UPGRADE);
            if (upgradeHeader != null && !AsciiString.contentEqualsIgnoreCase(upgradeCodec.protocol(), upgradeHeader)) {
                throw new IllegalStateException(
                        "Switching Protocols response with unexpected UPGRADE protocol: " + upgradeHeader);
            }

            // Upgrade to the new protocol.
            sourceCodec.prepareUpgradeFrom(ctx);
            upgradeCodec.upgradeTo(ctx, response);

            // Notify that the upgrade to the new protocol completed successfully.
            ctx.fireUserEventTriggered(UpgradeEvent.UPGRADE_SUCCESSFUL);

            // We guarantee UPGRADE_SUCCESSFUL event will be arrived at the next handler
            // before http2 setting frame and http response.
            sourceCodec.upgradeFrom(ctx);

            // We switched protocols, so we're done with the upgrade response.
            // Release it and clear it from the output.
            response.release();
            out.clear();
            removeThisHandler(ctx);
        } catch (Throwable t) {
            // release() is null-safe; drops the ref taken above (if any) before propagating.
            release(response);
            ctx.fireExceptionCaught(t);
            removeThisHandler(ctx);
        }
    }

    private static void removeThisHandler(ChannelHandlerContext ctx) {
        ctx.pipeline().remove(ctx.name());
    }

    /**
     * Adds all upgrade request headers necessary for an upgrade to the supported protocols.
     */
    private void setUpgradeRequestHeaders(ChannelHandlerContext ctx, HttpRequest request) {
        // Set the UPGRADE header on the request.
        request.headers().set(HttpHeaderNames.UPGRADE, upgradeCodec.protocol());

        // Add all protocol-specific headers to the request.
        // LinkedHashSet keeps insertion order and de-duplicates header names.
        Set<CharSequence> connectionParts = new LinkedHashSet<CharSequence>(2);
        connectionParts.addAll(upgradeCodec.setUpgradeHeaders(ctx, request));

        // Set the CONNECTION header from the set of all protocol-specific headers that were added.
        StringBuilder builder = new StringBuilder();
        for (CharSequence part : connectionParts) {
            builder.append(part);
            builder.append(',');
        }
        builder.append(HttpHeaderValues.UPGRADE);
        request.headers().add(HttpHeaderNames.CONNECTION, builder.toString());
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.util.CharsetUtil;

import java.nio.charset.Charset;

/**
 * Commonly used constants for HTTP encoding and decoding: the byte values of
 * the separator characters used in the HTTP wire format, and the default
 * character set. This is a non-instantiable constant holder.
 */
public final class HttpConstants {

    /**
     * Horizontal space
     */
    public static final byte SP = 32;

    /**
     * Horizontal tab
     */
    public static final byte HT = 9;

    /**
     * Carriage return
     */
    public static final byte CR = 13;

    /**
     * Equals '='
     */
    public static final byte EQUALS = 61;

    /**
     * Line feed character
     */
    public static final byte LF = 10;

    /**
     * Colon ':'
     */
    public static final byte COLON = 58;

    /**
     * Semicolon ';'
     */
    public static final byte SEMICOLON = 59;

    /**
     * Comma ','
     */
    public static final byte COMMA = 44;

    /**
     * Double quote '"'
     */
    public static final byte DOUBLE_QUOTE = '"';

    /**
     * Default character set (UTF-8)
     */
    public static final Charset DEFAULT_CHARSET = CharsetUtil.UTF_8;

    /**
     * Horizontal space, as a {@code char} (same value as {@link #SP})
     */
    public static final char SP_CHAR = (char) SP;

    private HttpConstants() {
        // Unused
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufHolder;
import io.netty.channel.ChannelPipeline;

/**
 * An HTTP chunk which is used for HTTP chunked transfer-encoding.
 * {@link HttpObjectDecoder} generates {@link HttpContent} after
 * {@link HttpMessage} when the content is large or the encoding of the content
 * is 'chunked'. If you prefer not to receive {@link HttpContent} in your handler,
 * place {@link HttpObjectAggregator} after {@link HttpObjectDecoder} in the
 * {@link ChannelPipeline}.
 */
public interface HttpContent extends HttpObject, ByteBufHolder {

    // All of the following are covariant overrides: they narrow the
    // ByteBufHolder return types to HttpContent so that chained
    // copy/retain/touch calls keep the HTTP-typed view.

    @Override
    HttpContent copy();

    @Override
    HttpContent duplicate();

    @Override
    HttpContent retainedDuplicate();

    @Override
    HttpContent replace(ByteBuf content);

    @Override
    HttpContent retain();

    @Override
    HttpContent retain(int increment);

    @Override
    HttpContent touch();

    @Override
    HttpContent touch(Object hint);
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.handler.codec.compression.JdkZlibEncoder;
import java.util.HashMap;
import java.util.Map;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.MessageToByteEncoder;
import io.netty.handler.codec.compression.Brotli;
import io.netty.handler.codec.compression.BrotliEncoder;
import io.netty.handler.codec.compression.BrotliOptions;
import io.netty.handler.codec.compression.CompressionOptions;
import io.netty.handler.codec.compression.DeflateOptions;
import io.netty.handler.codec.compression.GzipOptions;
import io.netty.handler.codec.compression.StandardCompressionOptions;
import io.netty.handler.codec.compression.ZlibCodecFactory;
import io.netty.handler.codec.compression.ZlibEncoder;
import io.netty.handler.codec.compression.ZlibWrapper;
import io.netty.handler.codec.compression.Zstd;
import io.netty.handler.codec.compression.ZstdEncoder;
import io.netty.handler.codec.compression.ZstdOptions;
import io.netty.handler.codec.compression.SnappyFrameEncoder;
import io.netty.handler.codec.compression.SnappyOptions;
import io.netty.util.internal.ObjectUtil;

/**
 * Compresses an {@link HttpMessage} and an {@link HttpContent} in {@code gzip} or
 * {@code deflate} encoding while respecting the {@code "Accept-Encoding"} header.
 * If there is no matching encoding, no compression is done. For more
 * information on how this handler modifies the message, please refer to
 * {@link HttpContentEncoder}.
 *
 * <p>Two configuration modes exist: the deprecated zlib-parameter constructors
 * (gzip/deflate only), and the {@link CompressionOptions}-based constructors which
 * additionally support brotli, zstd and snappy when those codecs are on the
 * classpath. The mode is recorded in {@code supportsCompressionOptions} and
 * selects the code path taken in {@link #beginEncode}.
 */
public class HttpContentCompressor extends HttpContentEncoder {

    // true when constructed via the CompressionOptions constructors; selects
    // the encoding-negotiation strategy used in beginEncode().
    private final boolean supportsCompressionOptions;
    // Per-codec options; each is null when that codec is unavailable or not configured.
    private final BrotliOptions brotliOptions;
    private final GzipOptions gzipOptions;
    private final DeflateOptions deflateOptions;
    private final ZstdOptions zstdOptions;
    private final SnappyOptions snappyOptions;

    // Legacy zlib parameters (deprecated constructors only); -1 in options mode.
    private final int compressionLevel;
    private final int windowBits;
    private final int memLevel;
    // Responses with a body smaller than this many bytes are not compressed; 0 disables the check.
    private final int contentSizeThreshold;
    // Cached in handlerAdded(); needed to build the EmbeddedChannel in beginEncode().
    private ChannelHandlerContext ctx;
    // Maps content-encoding name ("gzip", "br", ...) to its encoder factory; null in legacy mode.
    private final Map<String, CompressionEncoderFactory> factories;

    /**
     * Creates a new handler with the default compression level (6),
     * default window size (15) and default memory level (8).
     */
    public HttpContentCompressor() {
        this(6);
    }

    /**
     * Creates a new handler with the specified compression level, default
     * window size (15) and default memory level (8).
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     */
    @Deprecated
    public HttpContentCompressor(int compressionLevel) {
        this(compressionLevel, 15, 8, 0);
    }

    /**
     * Creates a new handler with the specified compression level, window size,
     * and memory level.
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     * @param windowBits
     *        The base two logarithm of the size of the history buffer. The
     *        value should be in the range {@code 9} to {@code 15} inclusive.
     *        Larger values result in better compression at the expense of
     *        memory usage. The default value is {@code 15}.
     * @param memLevel
     *        How much memory should be allocated for the internal compression
     *        state. {@code 1} uses minimum memory and {@code 9} uses maximum
     *        memory. Larger values result in better and faster compression
     *        at the expense of memory usage. The default value is {@code 8}
     */
    @Deprecated
    public HttpContentCompressor(int compressionLevel, int windowBits, int memLevel) {
        this(compressionLevel, windowBits, memLevel, 0);
    }

    /**
     * Creates a new handler with the specified compression level, window size,
     * and memory level.
     *
     * @param compressionLevel
     *        {@code 1} yields the fastest compression and {@code 9} yields the
     *        best compression. {@code 0} means no compression. The default
     *        compression level is {@code 6}.
     * @param windowBits
     *        The base two logarithm of the size of the history buffer. The
     *        value should be in the range {@code 9} to {@code 15} inclusive.
     *        Larger values result in better compression at the expense of
     *        memory usage. The default value is {@code 15}.
     * @param memLevel
     *        How much memory should be allocated for the internal compression
     *        state. {@code 1} uses minimum memory and {@code 9} uses maximum
     *        memory. Larger values result in better and faster compression
     *        at the expense of memory usage. The default value is {@code 8}
     * @param contentSizeThreshold
     *        The response body is compressed when the size of the response
     *        body exceeds the threshold. The value should be a non negative
     *        number. {@code 0} will enable compression for all responses.
     */
    @Deprecated
    public HttpContentCompressor(int compressionLevel, int windowBits, int memLevel, int contentSizeThreshold) {
        this.compressionLevel = ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel");
        this.windowBits = ObjectUtil.checkInRange(windowBits, 9, 15, "windowBits");
        this.memLevel = ObjectUtil.checkInRange(memLevel, 1, 9, "memLevel");
        this.contentSizeThreshold = ObjectUtil.checkPositiveOrZero(contentSizeThreshold, "contentSizeThreshold");
        // Legacy mode: no CompressionOptions, gzip/deflate negotiated via determineWrapper().
        this.brotliOptions = null;
        this.gzipOptions = null;
        this.deflateOptions = null;
        this.zstdOptions = null;
        this.snappyOptions = null;
        this.factories = null;
        this.supportsCompressionOptions = false;
    }

    /**
     * Create a new {@link HttpContentCompressor} Instance with specified
     * {@link CompressionOptions}s and contentSizeThreshold set to {@code 0}
     *
     * @param compressionOptions {@link CompressionOptions} or {@code null} if the default
     *        should be used.
     */
    public HttpContentCompressor(CompressionOptions... compressionOptions) {
        this(0, compressionOptions);
    }

    /**
     * Create a new {@link HttpContentCompressor} instance with specified
     * {@link CompressionOptions}s
     *
     * @param contentSizeThreshold
     *        The response body is compressed when the size of the response
     *        body exceeds the threshold. The value should be a non negative
     *        number. {@code 0} will enable compression for all responses.
     * @param compressionOptions {@link CompressionOptions} or {@code null}
     *        if the default should be used.
     */
    public HttpContentCompressor(int contentSizeThreshold, CompressionOptions... compressionOptions) {
        this.contentSizeThreshold = ObjectUtil.checkPositiveOrZero(contentSizeThreshold, "contentSizeThreshold");
        BrotliOptions brotliOptions = null;
        GzipOptions gzipOptions = null;
        DeflateOptions deflateOptions = null;
        ZstdOptions zstdOptions = null;
        SnappyOptions snappyOptions = null;
        if (compressionOptions == null || compressionOptions.length == 0) {
            // No explicit options: enable every codec that is available at runtime.
            brotliOptions = Brotli.isAvailable() ? StandardCompressionOptions.brotli() : null;
            gzipOptions = StandardCompressionOptions.gzip();
            deflateOptions = StandardCompressionOptions.deflate();
            zstdOptions = Zstd.isAvailable() ? StandardCompressionOptions.zstd() : null;
            snappyOptions = StandardCompressionOptions.snappy();
        } else {
            ObjectUtil.deepCheckNotNull("compressionOptions", compressionOptions);
            for (CompressionOptions compressionOption : compressionOptions) {
                // BrotliOptions' class initialization depends on Brotli classes being on the classpath.
                // The Brotli.isAvailable check ensures that BrotliOptions will only get instantiated if Brotli is
                // on the classpath.
                // This results in the static analysis of native-image identifying the instanceof BrotliOptions check
                // and thus BrotliOptions itself as unreachable, enabling native-image to link all classes
                // at build time and not complain about the missing Brotli classes.
                if (Brotli.isAvailable() && compressionOption instanceof BrotliOptions) {
                    brotliOptions = (BrotliOptions) compressionOption;
                } else if (compressionOption instanceof GzipOptions) {
                    gzipOptions = (GzipOptions) compressionOption;
                } else if (compressionOption instanceof DeflateOptions) {
                    deflateOptions = (DeflateOptions) compressionOption;
                } else if (compressionOption instanceof ZstdOptions) {
                    zstdOptions = (ZstdOptions) compressionOption;
                } else if (compressionOption instanceof SnappyOptions) {
                    snappyOptions = (SnappyOptions) compressionOption;
                } else {
                    throw new IllegalArgumentException("Unsupported " + CompressionOptions.class.getSimpleName() +
                            ": " + compressionOption);
                }
            }
        }

        this.gzipOptions = gzipOptions;
        this.deflateOptions = deflateOptions;
        this.brotliOptions = brotliOptions;
        this.zstdOptions = zstdOptions;
        this.snappyOptions = snappyOptions;

        // Register a factory for each configured codec, keyed by its content-coding token.
        this.factories = new HashMap<String, CompressionEncoderFactory>();
        if (this.gzipOptions != null) {
            this.factories.put("gzip", new GzipEncoderFactory());
        }
        if (this.deflateOptions != null) {
            this.factories.put("deflate", new DeflateEncoderFactory());
        }
        if (Brotli.isAvailable() && this.brotliOptions != null) {
            this.factories.put("br", new BrEncoderFactory());
        }
        if (this.zstdOptions != null) {
            this.factories.put("zstd", new ZstdEncoderFactory());
        }
        if (this.snappyOptions != null) {
            this.factories.put("snappy", new SnappyEncoderFactory());
        }

        // Legacy zlib parameters are unused in options mode.
        this.compressionLevel = -1;
        this.windowBits = -1;
        this.memLevel = -1;
        supportsCompressionOptions = true;
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        // Cache the context; beginEncode() needs the channel's id/config to build
        // the EmbeddedChannel that hosts the compression encoder.
        this.ctx = ctx;
    }

    /**
     * Picks a content encoding accepted by the client and wraps the matching encoder
     * in an {@link EmbeddedChannel}. Returns {@code null} (no compression) when the
     * body is under the size threshold, the response already carries a
     * {@code Content-Encoding}, or no acceptable encoding is supported.
     */
    @Override
    protected Result beginEncode(HttpResponse httpResponse, String acceptEncoding) throws Exception {
        if (this.contentSizeThreshold > 0) {
            if (httpResponse instanceof HttpContent &&
                    ((HttpContent) httpResponse).content().readableBytes() < contentSizeThreshold) {
                return null;
            }
        }

        String contentEncoding = httpResponse.headers().get(HttpHeaderNames.CONTENT_ENCODING);
        if (contentEncoding != null) {
            // Content-Encoding was set, either as something specific or as the IDENTITY encoding
            // Therefore, we should NOT encode here
            return null;
        }

        if (supportsCompressionOptions) {
            String targetContentEncoding = determineEncoding(acceptEncoding);
            if (targetContentEncoding == null) {
                return null;
            }

            CompressionEncoderFactory encoderFactory = factories.get(targetContentEncoding);

            if (encoderFactory == null) {
                // Invariant: determineEncoding only returns encodings whose options are
                // non-null, and each of those has a registered factory.
                throw new Error();
            }

            return new Result(targetContentEncoding,
                    new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                            ctx.channel().config(), encoderFactory.createEncoder()));
        } else {
            ZlibWrapper wrapper = determineWrapper(acceptEncoding);
            if (wrapper == null) {
                return null;
            }

            String targetContentEncoding;
            switch (wrapper) {
                case GZIP:
                    targetContentEncoding = "gzip";
                    break;
                case ZLIB:
                    targetContentEncoding = "deflate";
                    break;
                default:
                    // determineWrapper only ever returns GZIP or ZLIB.
                    throw new Error();
            }

            return new Result(
                    targetContentEncoding,
                    new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                            ctx.channel().config(), ZlibCodecFactory.newZlibEncoder(
                            wrapper, compressionLevel, windowBits, memLevel)));
        }
    }

    /**
     * Determines the target content encoding from an {@code Accept-Encoding} header value,
     * using a simplified q-value parse (substring after {@code '='}; a malformed q-value
     * demotes that encoding to 0). Preference order among configured codecs is
     * br &gt; zstd &gt; snappy &gt; gzip &gt; deflate, with a trailing wildcard pass.
     * Returns {@code null} when nothing acceptable is configured.
     */
    @SuppressWarnings("FloatingPointEquality")
    protected String determineEncoding(String acceptEncoding) {
        float starQ = -1.0f;
        float brQ = -1.0f;
        float zstdQ = -1.0f;
        float snappyQ = -1.0f;
        float gzipQ = -1.0f;
        float deflateQ = -1.0f;
        for (String encoding : acceptEncoding.split(",")) {
            float q = 1.0f;
            int equalsPos = encoding.indexOf('=');
            if (equalsPos != -1) {
                try {
                    q = Float.parseFloat(encoding.substring(equalsPos + 1));
                } catch (NumberFormatException e) {
                    // Ignore encoding
                    q = 0.0f;
                }
            }
            if (encoding.contains("*")) {
                starQ = q;
            } else if (encoding.contains("br") && q > brQ) {
                brQ = q;
            } else if (encoding.contains("zstd") && q > zstdQ) {
                zstdQ = q;
            } else if (encoding.contains("snappy") && q > snappyQ) {
                snappyQ = q;
            } else if (encoding.contains("gzip") && q > gzipQ) {
                gzipQ = q;
            } else if (encoding.contains("deflate") && q > deflateQ) {
                deflateQ = q;
            }
        }
        if (brQ > 0.0f || zstdQ > 0.0f || snappyQ > 0.0f || gzipQ > 0.0f || deflateQ > 0.0f) {
            if (brQ != -1.0f && brQ >= zstdQ && this.brotliOptions != null) {
                return "br";
            } else if (zstdQ != -1.0f && zstdQ >= snappyQ && this.zstdOptions != null) {
                return "zstd";
            } else if (snappyQ != -1.0f && snappyQ >= gzipQ && this.snappyOptions != null) {
                return "snappy";
            } else if (gzipQ != -1.0f && gzipQ >= deflateQ && this.gzipOptions != null) {
                return "gzip";
            } else if (deflateQ != -1.0f && this.deflateOptions != null) {
                return "deflate";
            }
        }
        if (starQ > 0.0f) {
            // Wildcard: pick the first configured codec the client did not mention explicitly.
            if (brQ == -1.0f && this.brotliOptions != null) {
                return "br";
            }
            if (zstdQ == -1.0f && this.zstdOptions != null) {
                return "zstd";
            }
            if (snappyQ == -1.0f && this.snappyOptions != null) {
                return "snappy";
            }
            if (gzipQ == -1.0f && this.gzipOptions != null) {
                return "gzip";
            }
            if (deflateQ == -1.0f && this.deflateOptions != null) {
                return "deflate";
            }
        }
        return null;
    }

    /**
     * Legacy (pre-{@link CompressionOptions}) negotiation: chooses between gzip and
     * deflate only. Returns {@code null} when neither is acceptable.
     *
     * @deprecated only used when constructed via the deprecated zlib-parameter constructors.
     */
    @Deprecated
    @SuppressWarnings("FloatingPointEquality")
    protected ZlibWrapper determineWrapper(String acceptEncoding) {
        float starQ = -1.0f;
        float gzipQ = -1.0f;
        float deflateQ = -1.0f;
        for (String encoding : acceptEncoding.split(",")) {
            float q = 1.0f;
            int equalsPos = encoding.indexOf('=');
            if (equalsPos != -1) {
                try {
                    q = Float.parseFloat(encoding.substring(equalsPos + 1));
                } catch (NumberFormatException e) {
                    // Ignore encoding
                    q = 0.0f;
                }
            }
            if (encoding.contains("*")) {
                starQ = q;
            } else if (encoding.contains("gzip") && q > gzipQ) {
                gzipQ = q;
            } else if (encoding.contains("deflate") && q > deflateQ) {
                deflateQ = q;
            }
        }
        if (gzipQ > 0.0f || deflateQ > 0.0f) {
            if (gzipQ >= deflateQ) {
                return ZlibWrapper.GZIP;
            } else {
                return ZlibWrapper.ZLIB;
            }
        }
        if (starQ > 0.0f) {
            if (gzipQ == -1.0f) {
                return ZlibWrapper.GZIP;
            }
            if (deflateQ == -1.0f) {
                return ZlibWrapper.ZLIB;
            }
        }
        return null;
    }

    /**
     * Compression Encoder Factory that creates {@link ZlibEncoder}s
     * used to compress http content for gzip content encoding
     */
    private final class GzipEncoderFactory implements CompressionEncoderFactory {

        @Override
        public MessageToByteEncoder<ByteBuf> createEncoder() {
            return ZlibCodecFactory.newZlibEncoder(
                    ZlibWrapper.GZIP, gzipOptions.compressionLevel(),
                    gzipOptions.windowBits(), gzipOptions.memLevel());
        }
    }

    /**
     * Compression Encoder Factory that creates {@link ZlibEncoder}s
     * used to compress http content for deflate content encoding
     */
    private final class DeflateEncoderFactory implements CompressionEncoderFactory {

        @Override
        public MessageToByteEncoder<ByteBuf> createEncoder() {
            return ZlibCodecFactory.newZlibEncoder(
                    ZlibWrapper.ZLIB, deflateOptions.compressionLevel(),
                    deflateOptions.windowBits(), deflateOptions.memLevel());
        }
    }

    /**
     * Compression Encoder Factory that creates {@link BrotliEncoder}s
     * used to compress http content for br content encoding
     */
    private final class BrEncoderFactory implements CompressionEncoderFactory {

        @Override
        public MessageToByteEncoder<ByteBuf> createEncoder() {
            return new BrotliEncoder(brotliOptions.parameters());
        }
    }

    /**
     * Compression Encoder Factory for create {@link ZstdEncoder}
     * used to compress http content for zstd content encoding
     */
    private final class ZstdEncoderFactory implements CompressionEncoderFactory {

        @Override
        public MessageToByteEncoder<ByteBuf> createEncoder() {
            return new ZstdEncoder(zstdOptions.compressionLevel(),
                    zstdOptions.blockSize(), zstdOptions.maxEncodeSize());
        }
    }

    /**
     * Compression Encoder Factory for create {@link SnappyFrameEncoder}
     * used to compress http content for snappy content encoding
     */
    private static final class SnappyEncoderFactory implements CompressionEncoderFactory {

        @Override
        public MessageToByteEncoder<ByteBuf> createEncoder() {
            return new SnappyFrameEncoder();
        }
    }
}
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java
new file mode 100644
index 0000000..4208d2e
--- /dev/null
+++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecoder.java
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.MessageToMessageDecoder;
import io.netty.util.ReferenceCountUtil;

import java.util.List;

/**
 * Decodes the content of the received {@link HttpRequest} and {@link HttpContent}.
 * The original content is replaced with the new content decoded by the
 * {@link EmbeddedChannel}, which is created by {@link #newContentDecoder(String)}.
+ * Once decoding is finished, the value of the 'Content-Encoding' + * header is set to the target content encoding, as returned by {@link #getTargetContentEncoding(String)}. + * Also, the 'Content-Length' header is updated to the length of the + * decoded content. If the content encoding of the original is not supported + * by the decoder, {@link #newContentDecoder(String)} should return {@code null} + * so that no decoding occurs (i.e. pass-through). + *

+ * Please note that this is an abstract class. You have to extend this class + * and implement {@link #newContentDecoder(String)} properly to make this class + * functional. For example, refer to the source code of {@link HttpContentDecompressor}. + *

+ * This handler must be placed after {@link HttpObjectDecoder} in the pipeline + * so that this handler can intercept HTTP requests after {@link HttpObjectDecoder} + * converts {@link ByteBuf}s into HTTP requests. + */ +public abstract class HttpContentDecoder extends MessageToMessageDecoder { + + static final String IDENTITY = HttpHeaderValues.IDENTITY.toString(); + + protected ChannelHandlerContext ctx; + private EmbeddedChannel decoder; + private boolean continueResponse; + private boolean needRead = true; + + @Override + protected void decode(ChannelHandlerContext ctx, HttpObject msg, List out) throws Exception { + try { + if (msg instanceof HttpResponse && ((HttpResponse) msg).status().code() == 100) { + + if (!(msg instanceof LastHttpContent)) { + continueResponse = true; + } + // 100-continue response must be passed through. + out.add(ReferenceCountUtil.retain(msg)); + return; + } + + if (continueResponse) { + if (msg instanceof LastHttpContent) { + continueResponse = false; + } + // 100-continue response must be passed through. + out.add(ReferenceCountUtil.retain(msg)); + return; + } + + if (msg instanceof HttpMessage) { + cleanup(); + final HttpMessage message = (HttpMessage) msg; + final HttpHeaders headers = message.headers(); + + // Determine the content encoding. 
+ String contentEncoding = headers.get(HttpHeaderNames.CONTENT_ENCODING); + if (contentEncoding != null) { + contentEncoding = contentEncoding.trim(); + } else { + String transferEncoding = headers.get(HttpHeaderNames.TRANSFER_ENCODING); + if (transferEncoding != null) { + int idx = transferEncoding.indexOf(","); + if (idx != -1) { + contentEncoding = transferEncoding.substring(0, idx).trim(); + } else { + contentEncoding = transferEncoding.trim(); + } + } else { + contentEncoding = IDENTITY; + } + } + decoder = newContentDecoder(contentEncoding); + + if (decoder == null) { + if (message instanceof HttpContent) { + ((HttpContent) message).retain(); + } + out.add(message); + return; + } + + // Remove content-length header: + // the correct value can be set only after all chunks are processed/decoded. + // If buffering is not an issue, add HttpObjectAggregator down the chain, it will set the header. + // Otherwise, rely on LastHttpContent message. + if (headers.contains(HttpHeaderNames.CONTENT_LENGTH)) { + headers.remove(HttpHeaderNames.CONTENT_LENGTH); + headers.set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + } + // Either it is already chunked or EOF terminated. + // See https://github.com/netty/netty/issues/5892 + + // set new content encoding, + CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding); + if (HttpHeaderValues.IDENTITY.contentEquals(targetContentEncoding)) { + // Do NOT set the 'Content-Encoding' header if the target encoding is 'identity' + // as per: https://tools.ietf.org/html/rfc2616#section-14.11 + headers.remove(HttpHeaderNames.CONTENT_ENCODING); + } else { + headers.set(HttpHeaderNames.CONTENT_ENCODING, targetContentEncoding); + } + + if (message instanceof HttpContent) { + // If message is a full request or response object (headers + data), don't copy data part into out. + // Output headers only; data part will be decoded below. 
+ // Note: "copy" object must not be an instance of LastHttpContent class, + // as this would (erroneously) indicate the end of the HttpMessage to other handlers. + HttpMessage copy; + if (message instanceof HttpRequest) { + HttpRequest r = (HttpRequest) message; // HttpRequest or FullHttpRequest + copy = new DefaultHttpRequest(r.protocolVersion(), r.method(), r.uri()); + } else if (message instanceof HttpResponse) { + HttpResponse r = (HttpResponse) message; // HttpResponse or FullHttpResponse + copy = new DefaultHttpResponse(r.protocolVersion(), r.status()); + } else { + throw new CodecException("Object of class " + message.getClass().getName() + + " is not an HttpRequest or HttpResponse"); + } + copy.headers().set(message.headers()); + copy.setDecoderResult(message.decoderResult()); + out.add(copy); + } else { + out.add(message); + } + } + + if (msg instanceof HttpContent) { + final HttpContent c = (HttpContent) msg; + if (decoder == null) { + out.add(c.retain()); + } else { + decodeContent(c, out); + } + } + } finally { + needRead = out.isEmpty(); + } + } + + private void decodeContent(HttpContent c, List out) { + ByteBuf content = c.content(); + + decode(content, out); + + if (c instanceof LastHttpContent) { + finishDecode(out); + + LastHttpContent last = (LastHttpContent) c; + // Generate an additional chunk if the decoder produced + // the last product on closure, + HttpHeaders headers = last.trailingHeaders(); + if (headers.isEmpty()) { + out.add(LastHttpContent.EMPTY_LAST_CONTENT); + } else { + out.add(new ComposedLastHttpContent(headers, DecoderResult.SUCCESS)); + } + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + boolean needRead = this.needRead; + this.needRead = true; + + try { + ctx.fireChannelReadComplete(); + } finally { + if (needRead && !ctx.channel().config().isAutoRead()) { + ctx.read(); + } + } + } + + /** + * Returns a new {@link EmbeddedChannel} that decodes the HTTP message + * content 
encoded in the specified contentEncoding. + * + * @param contentEncoding the value of the {@code "Content-Encoding"} header + * @return a new {@link EmbeddedChannel} if the specified encoding is supported. + * {@code null} otherwise (alternatively, you can throw an exception + * to block unknown encoding). + */ + protected abstract EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception; + + /** + * Returns the expected content encoding of the decoded content. + * This getMethod returns {@code "identity"} by default, which is the case for + * most decoders. + * + * @param contentEncoding the value of the {@code "Content-Encoding"} header + * @return the expected content encoding of the new content + */ + protected String getTargetContentEncoding( + @SuppressWarnings("UnusedParameters") String contentEncoding) throws Exception { + return IDENTITY; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + cleanupSafely(ctx); + super.handlerRemoved(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + cleanupSafely(ctx); + super.channelInactive(ctx); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + super.handlerAdded(ctx); + } + + private void cleanup() { + if (decoder != null) { + // Clean-up the previous decoder if not cleaned up correctly. + decoder.finishAndReleaseAll(); + decoder = null; + } + } + + private void cleanupSafely(ChannelHandlerContext ctx) { + try { + cleanup(); + } catch (Throwable cause) { + // If cleanup throws any error we need to propagate it through the pipeline + // so we don't fail to propagate pipeline events. 
+ ctx.fireExceptionCaught(cause); + } + } + + private void decode(ByteBuf in, List out) { + // call retain here as it will call release after its written to the channel + decoder.writeInbound(in.retain()); + fetchDecoderOutput(out); + } + + private void finishDecode(List out) { + if (decoder.finish()) { + fetchDecoderOutput(out); + } + decoder = null; + } + + private void fetchDecoderOutput(List out) { + for (;;) { + ByteBuf buf = decoder.readInbound(); + if (buf == null) { + break; + } + if (!buf.isReadable()) { + buf.release(); + continue; + } + out.add(new DefaultHttpContent(buf)); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java new file mode 100644 index 0000000..58b3371 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentDecompressor.java @@ -0,0 +1,86 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import static io.netty.handler.codec.http.HttpHeaderValues.BR; +import static io.netty.handler.codec.http.HttpHeaderValues.DEFLATE; +import static io.netty.handler.codec.http.HttpHeaderValues.GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE; +import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.SNAPPY; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.BrotliDecoder; +import io.netty.handler.codec.compression.JdkZlibDecoder; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.compression.SnappyFrameDecoder; + +/** + * Decompresses an {@link HttpMessage} and an {@link HttpContent} compressed in + * {@code gzip}, {@code deflate}, {@code br} (when Brotli is available) or {@code snappy} encoding. For more information on how this + * handler modifies the message, please refer to {@link HttpContentDecoder}. + */ +public class HttpContentDecompressor extends HttpContentDecoder { + + private final boolean strict; + + /** + * Create a new {@link HttpContentDecompressor} in non-strict mode. + */ + public HttpContentDecompressor() { + this(false); + } + + /** + * Create a new {@link HttpContentDecompressor}. + * + * @param strict if {@code true} use strict handling of deflate if used, otherwise handle it in a + * more lenient fashion.
+ */ + public HttpContentDecompressor(boolean strict) { + this.strict = strict; + } + + @Override + protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { + if (GZIP.contentEqualsIgnoreCase(contentEncoding) || + X_GZIP.contentEqualsIgnoreCase(contentEncoding)) { + return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), + ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP)); + } + if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || + X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) { + final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE; + // To be strict, 'deflate' means ZLIB, but some servers were not implemented correctly. + return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), + ctx.channel().config(), new JdkZlibDecoder(wrapper)); + } + if (Brotli.isAvailable() && BR.contentEqualsIgnoreCase(contentEncoding)) { + return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), + ctx.channel().config(), new BrotliDecoder()); + } + + if (SNAPPY.contentEqualsIgnoreCase(contentEncoding)) { + return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(), + ctx.channel().config(), new SnappyFrameDecoder()); + } + + // 'identity' or unsupported + return null; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java new file mode 100644 index 0000000..3899c81 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpContentEncoder.java @@ -0,0 +1,383 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.MessageToMessageCodec; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.util.ArrayDeque; +import java.util.List; +import java.util.Queue; + +import static io.netty.handler.codec.http.HttpHeaderNames.*; + +/** + * Encodes the content of the outbound {@link HttpResponse} and {@link HttpContent}. + * The original content is replaced with the new content encoded by the + * {@link EmbeddedChannel}, which is created by {@link #beginEncode(HttpResponse, String)}. + * Once encoding is finished, the value of the 'Content-Encoding' header + * is set to the target content encoding, as returned by + * {@link #beginEncode(HttpResponse, String)}. + * Also, the 'Content-Length' header is updated to the length of the + * encoded content. If there is no supported or allowed encoding in the + * corresponding {@link HttpRequest}'s {@code "Accept-Encoding"} header, + * {@link #beginEncode(HttpResponse, String)} should return {@code null} so that + * no encoding occurs (i.e. pass-through). + *

+ * Please note that this is an abstract class. You have to extend this class + * and implement {@link #beginEncode(HttpResponse, String)} properly to make + * this class functional. For example, refer to the source code of + * {@link HttpContentCompressor}. + *

+ * This handler must be placed after {@link HttpObjectEncoder} in the pipeline + * so that this handler can intercept HTTP responses before {@link HttpObjectEncoder} + * converts them into {@link ByteBuf}s. + */ +public abstract class HttpContentEncoder extends MessageToMessageCodec { + + private enum State { + PASS_THROUGH, + AWAIT_HEADERS, + AWAIT_CONTENT + } + + private static final CharSequence ZERO_LENGTH_HEAD = "HEAD"; + private static final CharSequence ZERO_LENGTH_CONNECT = "CONNECT"; + + private final Queue acceptEncodingQueue = new ArrayDeque(); + private EmbeddedChannel encoder; + private State state = State.AWAIT_HEADERS; + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + return msg instanceof HttpContent || msg instanceof HttpResponse; + } + + @Override + protected void decode(ChannelHandlerContext ctx, HttpRequest msg, List out) throws Exception { + CharSequence acceptEncoding; + List acceptEncodingHeaders = msg.headers().getAll(ACCEPT_ENCODING); + switch (acceptEncodingHeaders.size()) { + case 0: + acceptEncoding = HttpContentDecoder.IDENTITY; + break; + case 1: + acceptEncoding = acceptEncodingHeaders.get(0); + break; + default: + // Multiple message-header fields https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html#sec4.2 + acceptEncoding = StringUtil.join(",", acceptEncodingHeaders); + break; + } + + HttpMethod method = msg.method(); + if (HttpMethod.HEAD.equals(method)) { + acceptEncoding = ZERO_LENGTH_HEAD; + } else if (HttpMethod.CONNECT.equals(method)) { + acceptEncoding = ZERO_LENGTH_CONNECT; + } + + acceptEncodingQueue.add(acceptEncoding); + out.add(ReferenceCountUtil.retain(msg)); + } + + @Override + protected void encode(ChannelHandlerContext ctx, HttpObject msg, List out) throws Exception { + final boolean isFull = msg instanceof HttpResponse && msg instanceof LastHttpContent; + switch (state) { + case AWAIT_HEADERS: { + ensureHeaders(msg); + assert encoder == null; + + final HttpResponse res = 
(HttpResponse) msg; + final int code = res.status().code(); + final HttpStatusClass codeClass = res.status().codeClass(); + final CharSequence acceptEncoding; + if (codeClass == HttpStatusClass.INFORMATIONAL) { + // We need to not poll the encoding when response with 1xx codes as another response will follow + // for the issued request. + // See https://github.com/netty/netty/issues/12904 and https://github.com/netty/netty/issues/4079 + acceptEncoding = null; + } else { + // Get the list of encodings accepted by the peer. + acceptEncoding = acceptEncodingQueue.poll(); + if (acceptEncoding == null) { + throw new IllegalStateException("cannot send more responses than requests"); + } + } + + /* + * per rfc2616 4.3 Message Body + * All 1xx (informational), 204 (no content), and 304 (not modified) responses MUST NOT include a + * message-body. All other responses do include a message-body, although it MAY be of zero length. + * + * 9.4 HEAD + * The HEAD method is identical to GET except that the server MUST NOT return a message-body + * in the response. + * + * Also we should pass through HTTP/1.0 as transfer-encoding: chunked is not supported. + * + * See https://github.com/netty/netty/issues/5382 + */ + if (isPassthru(res.protocolVersion(), code, acceptEncoding)) { + if (isFull) { + out.add(ReferenceCountUtil.retain(res)); + } else { + out.add(ReferenceCountUtil.retain(res)); + // Pass through all following contents. + state = State.PASS_THROUGH; + } + break; + } + + if (isFull) { + // Pass through the full response with empty content and continue waiting for the next resp. + if (!((ByteBufHolder) res).content().isReadable()) { + out.add(ReferenceCountUtil.retain(res)); + break; + } + } + + // Prepare to encode the content. + final Result result = beginEncode(res, acceptEncoding.toString()); + + // If unable to encode, pass through. 
+ if (result == null) { + if (isFull) { + out.add(ReferenceCountUtil.retain(res)); + } else { + out.add(ReferenceCountUtil.retain(res)); + // Pass through all following contents. + state = State.PASS_THROUGH; + } + break; + } + + encoder = result.contentEncoder(); + + // Encode the content and remove or replace the existing headers + // so that the message looks like a decoded message. + res.headers().set(HttpHeaderNames.CONTENT_ENCODING, result.targetContentEncoding()); + + // Output the rewritten response. + if (isFull) { + // Convert full message into unfull one. + HttpResponse newRes = new DefaultHttpResponse(res.protocolVersion(), res.status()); + newRes.headers().set(res.headers()); + out.add(newRes); + + ensureContent(res); + encodeFullResponse(newRes, (HttpContent) res, out); + break; + } else { + // Make the response chunked to simplify content transformation. + res.headers().remove(HttpHeaderNames.CONTENT_LENGTH); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + out.add(ReferenceCountUtil.retain(res)); + state = State.AWAIT_CONTENT; + if (!(msg instanceof HttpContent)) { + // only break out the switch statement if we have not content to process + // See https://github.com/netty/netty/issues/2006 + break; + } + // Fall through to encode the content + } + } + case AWAIT_CONTENT: { + ensureContent(msg); + if (encodeContent((HttpContent) msg, out)) { + state = State.AWAIT_HEADERS; + } else if (out.isEmpty()) { + // MessageToMessageCodec needs at least one output message + out.add(new DefaultHttpContent(Unpooled.EMPTY_BUFFER)); + } + break; + } + case PASS_THROUGH: { + ensureContent(msg); + out.add(ReferenceCountUtil.retain(msg)); + // Passed through all following contents of the current response. 
+ if (msg instanceof LastHttpContent) { + state = State.AWAIT_HEADERS; + } + break; + } + } + } + + private void encodeFullResponse(HttpResponse newRes, HttpContent content, List out) { + int existingMessages = out.size(); + encodeContent(content, out); + + if (HttpUtil.isContentLengthSet(newRes)) { + // adjust the content-length header + int messageSize = 0; + for (int i = existingMessages; i < out.size(); i++) { + Object item = out.get(i); + if (item instanceof HttpContent) { + messageSize += ((HttpContent) item).content().readableBytes(); + } + } + HttpUtil.setContentLength(newRes, messageSize); + } else { + newRes.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + } + } + + private static boolean isPassthru(HttpVersion version, int code, CharSequence httpMethod) { + return code < 200 || code == 204 || code == 304 || + (httpMethod == ZERO_LENGTH_HEAD || (httpMethod == ZERO_LENGTH_CONNECT && code == 200)) || + version == HttpVersion.HTTP_1_0; + } + + private static void ensureHeaders(HttpObject msg) { + if (!(msg instanceof HttpResponse)) { + throw new IllegalStateException( + "unexpected message type: " + + msg.getClass().getName() + " (expected: " + HttpResponse.class.getSimpleName() + ')'); + } + } + + private static void ensureContent(HttpObject msg) { + if (!(msg instanceof HttpContent)) { + throw new IllegalStateException( + "unexpected message type: " + + msg.getClass().getName() + " (expected: " + HttpContent.class.getSimpleName() + ')'); + } + } + + private boolean encodeContent(HttpContent c, List out) { + ByteBuf content = c.content(); + + encode(content, out); + + if (c instanceof LastHttpContent) { + finishEncode(out); + LastHttpContent last = (LastHttpContent) c; + + // Generate an additional chunk if the decoder produced + // the last product on closure, + HttpHeaders headers = last.trailingHeaders(); + if (headers.isEmpty()) { + out.add(LastHttpContent.EMPTY_LAST_CONTENT); + } else { + out.add(new 
ComposedLastHttpContent(headers, DecoderResult.SUCCESS)); + } + return true; + } + return false; + } + + /** + * Prepare to encode the HTTP message content. + * + * @param httpResponse + * the http response + * @param acceptEncoding + * the value of the {@code "Accept-Encoding"} header + * + * @return the result of preparation, which is composed of the determined + * target content encoding and a new {@link EmbeddedChannel} that + * encodes the content into the target content encoding. + * {@code null} if {@code acceptEncoding} is unsupported or rejected + * and thus the content should be handled as-is (i.e. no encoding). + */ + protected abstract Result beginEncode(HttpResponse httpResponse, String acceptEncoding) throws Exception; + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + cleanupSafely(ctx); + super.handlerRemoved(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + cleanupSafely(ctx); + super.channelInactive(ctx); + } + + private void cleanup() { + if (encoder != null) { + // Clean-up the previous encoder if not cleaned up correctly. + encoder.finishAndReleaseAll(); + encoder = null; + } + } + + private void cleanupSafely(ChannelHandlerContext ctx) { + try { + cleanup(); + } catch (Throwable cause) { + // If cleanup throws any error we need to propagate it through the pipeline + // so we don't fail to propagate pipeline events. 
+ ctx.fireExceptionCaught(cause); + } + } + + private void encode(ByteBuf in, List out) { + // call retain here as it will call release after its written to the channel + encoder.writeOutbound(in.retain()); + fetchEncoderOutput(out); + } + + private void finishEncode(List out) { + if (encoder.finish()) { + fetchEncoderOutput(out); + } + encoder = null; + } + + private void fetchEncoderOutput(List out) { + for (;;) { + ByteBuf buf = encoder.readOutbound(); + if (buf == null) { + break; + } + if (!buf.isReadable()) { + buf.release(); + continue; + } + out.add(new DefaultHttpContent(buf)); + } + } + + public static final class Result { + private final String targetContentEncoding; + private final EmbeddedChannel contentEncoder; + + public Result(String targetContentEncoding, EmbeddedChannel contentEncoder) { + this.targetContentEncoding = ObjectUtil.checkNotNull(targetContentEncoding, "targetContentEncoding"); + this.contentEncoder = ObjectUtil.checkNotNull(contentEncoder, "contentEncoder"); + } + + public String targetContentEncoding() { + return targetContentEncoding; + } + + public EmbeddedChannel contentEncoder() { + return contentEncoder; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java new file mode 100644 index 0000000..03f12b7 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpDecoderConfig.java @@ -0,0 +1,225 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * A configuration object for specifying the behaviour of {@link HttpObjectDecoder} and its subclasses. + *

+ * The {@link HttpDecoderConfig} objects are mutable to reduce allocation, + * but also {@link Cloneable} in case a defensive copy is needed. + */ +public final class HttpDecoderConfig implements Cloneable { + private int maxChunkSize = HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE; + private boolean chunkedSupported = HttpObjectDecoder.DEFAULT_CHUNKED_SUPPORTED; + private boolean allowPartialChunks = HttpObjectDecoder.DEFAULT_ALLOW_PARTIAL_CHUNKS; + private HttpHeadersFactory headersFactory = DefaultHttpHeadersFactory.headersFactory(); + private HttpHeadersFactory trailersFactory = DefaultHttpHeadersFactory.trailersFactory(); + private boolean allowDuplicateContentLengths = HttpObjectDecoder.DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS; + private int maxInitialLineLength = HttpObjectDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH; + private int maxHeaderSize = HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE; + private int initialBufferSize = HttpObjectDecoder.DEFAULT_INITIAL_BUFFER_SIZE; + + public int getInitialBufferSize() { + return initialBufferSize; + } + + /** + * Set the initial size of the temporary buffer used when parsing the lines of the HTTP headers. + * + * @param initialBufferSize The buffer size in bytes. + * @return This decoder config. + */ + public HttpDecoderConfig setInitialBufferSize(int initialBufferSize) { + checkPositive(initialBufferSize, "initialBufferSize"); + this.initialBufferSize = initialBufferSize; + return this; + } + + public int getMaxInitialLineLength() { + return maxInitialLineLength; + } + + /** + * Set the maximum length of the first line of the HTTP header. + * This limits how much memory Netty will use when parsed the initial HTTP header line. + * You would typically set this to the same value as {@link #setMaxHeaderSize(int)}. + * + * @param maxInitialLineLength The maximum length, in bytes. + * @return This decoder config. 
+ */ + public HttpDecoderConfig setMaxInitialLineLength(int maxInitialLineLength) { + checkPositive(maxInitialLineLength, "maxInitialLineLength"); + this.maxInitialLineLength = maxInitialLineLength; + return this; + } + + public int getMaxHeaderSize() { + return maxHeaderSize; + } + + /** + * Set the maximum line length of header lines. + * This limits how much memory Netty will use when parsing HTTP header key-value pairs. + * You would typically set this to the same value as {@link #setMaxInitialLineLength(int)}. + * + * @param maxHeaderSize The maximum length, in bytes. + * @return This decoder config. + */ + public HttpDecoderConfig setMaxHeaderSize(int maxHeaderSize) { + checkPositive(maxHeaderSize, "maxHeaderSize"); + this.maxHeaderSize = maxHeaderSize; + return this; + } + + public int getMaxChunkSize() { + return maxChunkSize; + } + + /** + * Set the maximum chunk size. + * HTTP requests and responses can be quite large, in which case it's better to process the data as a stream of + * chunks. + * This sets the limit, in bytes, at which Netty will send a chunk down the pipeline. + * + * @param maxChunkSize The maximum chunk size, in bytes. + * @return This decoder config. + */ + public HttpDecoderConfig setMaxChunkSize(int maxChunkSize) { + checkPositive(maxChunkSize, "maxChunkSize"); + this.maxChunkSize = maxChunkSize; + return this; + } + + public boolean isChunkedSupported() { + return chunkedSupported; + } + + /** + * Set whether {@code Transfer-Encoding: Chunked} should be supported. + * + * @param chunkedSupported if {@code false}, then a {@code Transfer-Encoding: Chunked} header will produce an error, + * instead of a stream of chunks. + * @return This decoder config. 
+ */ + public HttpDecoderConfig setChunkedSupported(boolean chunkedSupported) { + this.chunkedSupported = chunkedSupported; + return this; + } + + public boolean isAllowPartialChunks() { + return allowPartialChunks; + } + + /** + * Set whether chunks can be split into multiple messages, if their chunk size exceeds the size of the input buffer. + * + * @param allowPartialChunks set to {@code false} to only allow sending whole chunks down the pipeline. + * @return This decoder config. + */ + public HttpDecoderConfig setAllowPartialChunks(boolean allowPartialChunks) { + this.allowPartialChunks = allowPartialChunks; + return this; + } + + public HttpHeadersFactory getHeadersFactory() { + return headersFactory; + } + + /** + * Set the {@link HttpHeadersFactory} to use when creating new HTTP headers objects. + * The default headers factory is {@link DefaultHttpHeadersFactory#headersFactory()}. + *

+ * For the purpose of {@link #clone()}, it is assumed that the factory is either immutable, or can otherwise be + * shared across different decoders and decoder configs. + * + * @param headersFactory The header factory to use. + * @return This decoder config. + */ + public HttpDecoderConfig setHeadersFactory(HttpHeadersFactory headersFactory) { + checkNotNull(headersFactory, "headersFactory"); + this.headersFactory = headersFactory; + return this; + } + + public boolean isAllowDuplicateContentLengths() { + return allowDuplicateContentLengths; + } + + /** + * Set whether more than one {@code Content-Length} header is allowed. + * You usually want to disallow this (which is the default) as multiple {@code Content-Length} headers can indicate + * a request- or response-splitting attack. + * + * @param allowDuplicateContentLengths set to {@code true} to allow multiple content length headers. + * @return This decoder config. + */ + public HttpDecoderConfig setAllowDuplicateContentLengths(boolean allowDuplicateContentLengths) { + this.allowDuplicateContentLengths = allowDuplicateContentLengths; + return this; + } + + /** + * Set whether header validation should be enabled or not. + * This works by changing the configured {@linkplain #setHeadersFactory(HttpHeadersFactory) header factory} + * and {@linkplain #setTrailersFactory(HttpHeadersFactory) trailer factory}. + *

+ * You usually want header validation enabled (which is the default) in order to prevent request-/response-splitting + * attacks. + * + * @param validateHeaders set to {@code false} to disable header validation. + * @return This decoder config. + */ + public HttpDecoderConfig setValidateHeaders(boolean validateHeaders) { + DefaultHttpHeadersFactory noValidation = DefaultHttpHeadersFactory.headersFactory().withValidation(false); + headersFactory = validateHeaders ? DefaultHttpHeadersFactory.headersFactory() : noValidation; + trailersFactory = validateHeaders ? DefaultHttpHeadersFactory.trailersFactory() : noValidation; + return this; + } + + public HttpHeadersFactory getTrailersFactory() { + return trailersFactory; + } + + /** + * Set the {@link HttpHeadersFactory} used to create HTTP trailers. + * This differs from {@link #setHeadersFactory(HttpHeadersFactory)} in that trailers have different validation + * requirements. + * The default trailer factory is {@link DefaultHttpHeadersFactory#headersFactory()}. + *

+ * For the purpose of {@link #clone()}, it is assumed that the factory is either immutable, or can otherwise be + * shared across different decoders and decoder configs. + * + * @param trailersFactory The headers factory to use for creating trailers. + * @return This decoder config. + */ + public HttpDecoderConfig setTrailersFactory(HttpHeadersFactory trailersFactory) { + checkNotNull(trailersFactory, "trailersFactory"); + this.trailersFactory = trailersFactory; + return this; + } + + @Override + public HttpDecoderConfig clone() { + try { + return (HttpDecoderConfig) super.clone(); + } catch (CloneNotSupportedException e) { + throw new AssertionError(); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpExpectationFailedEvent.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpExpectationFailedEvent.java new file mode 100644 index 0000000..c9a5b19 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpExpectationFailedEvent.java @@ -0,0 +1,25 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +/** + * A user event designed to communicate that a expectation has failed and there should be no expectation that a + * body will follow. 
+ */ +public final class HttpExpectationFailedEvent { + public static final HttpExpectationFailedEvent INSTANCE = new HttpExpectationFailedEvent(); + private HttpExpectationFailedEvent() { } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderDateFormat.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderDateFormat.java new file mode 100644 index 0000000..a3d1532 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderDateFormat.java @@ -0,0 +1,104 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.handler.codec.DateFormatter; + +import java.text.ParsePosition; +import java.text.SimpleDateFormat; +import java.util.Date; +import java.util.Locale; +import java.util.TimeZone; + +/** + * This DateFormat decodes 3 formats of {@link Date}, but only encodes the one, + * the first: + *

    + *
  • Sun, 06 Nov 1994 08:49:37 GMT: standard specification, the only one with + * valid generation
  • + *
  • Sunday, 06-Nov-94 08:49:37 GMT: obsolete specification
  • + *
  • Sun Nov 6 08:49:37 1994: obsolete specification
  • + *
+ * @deprecated Use {@link DateFormatter} instead + */ +@Deprecated +public final class HttpHeaderDateFormat extends SimpleDateFormat { + private static final long serialVersionUID = -925286159755905325L; + + private final SimpleDateFormat format1 = new HttpHeaderDateFormatObsolete1(); + private final SimpleDateFormat format2 = new HttpHeaderDateFormatObsolete2(); + + private static final FastThreadLocal dateFormatThreadLocal = + new FastThreadLocal() { + @Override + protected HttpHeaderDateFormat initialValue() { + return new HttpHeaderDateFormat(); + } + }; + + public static HttpHeaderDateFormat get() { + return dateFormatThreadLocal.get(); + } + + /** + * Standard date format

+ * Sun, 06 Nov 1994 08:49:37 GMT -> E, d MMM yyyy HH:mm:ss z + */ + private HttpHeaderDateFormat() { + super("E, dd MMM yyyy HH:mm:ss z", Locale.ENGLISH); + setTimeZone(TimeZone.getTimeZone("GMT")); + } + + @Override + public Date parse(String text, ParsePosition pos) { + Date date = super.parse(text, pos); + if (date == null) { + date = format1.parse(text, pos); + } + if (date == null) { + date = format2.parse(text, pos); + } + return date; + } + + /** + * First obsolete format

+ * Sunday, 06-Nov-94 08:49:37 GMT -> E, d-MMM-y HH:mm:ss z + */ + private static final class HttpHeaderDateFormatObsolete1 extends SimpleDateFormat { + private static final long serialVersionUID = -3178072504225114298L; + + HttpHeaderDateFormatObsolete1() { + super("E, dd-MMM-yy HH:mm:ss z", Locale.ENGLISH); + setTimeZone(TimeZone.getTimeZone("GMT")); + } + } + + /** + * Second obsolete format + *

+ * Sun Nov 6 08:49:37 1994 -> EEE, MMM d HH:mm:ss yyyy + */ + private static final class HttpHeaderDateFormatObsolete2 extends SimpleDateFormat { + private static final long serialVersionUID = 3010674519968303714L; + + HttpHeaderDateFormatObsolete2() { + super("E MMM d HH:mm:ss yyyy", Locale.ENGLISH); + setTimeZone(TimeZone.getTimeZone("GMT")); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderNames.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderNames.java new file mode 100644 index 0000000..9b68ea3 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderNames.java @@ -0,0 +1,386 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; + +/** + * Standard HTTP header names. + *

+ * These are all defined as lowercase to support HTTP/2 requirements while also not + * violating HTTP/1.x requirements. New header names should always be lowercase. + */ +public final class HttpHeaderNames { + /** + * {@code "accept"} + */ + public static final AsciiString ACCEPT = AsciiString.cached("accept"); + /** + * {@code "accept-charset"} + */ + public static final AsciiString ACCEPT_CHARSET = AsciiString.cached("accept-charset"); + /** + * {@code "accept-encoding"} + */ + public static final AsciiString ACCEPT_ENCODING = AsciiString.cached("accept-encoding"); + /** + * {@code "accept-language"} + */ + public static final AsciiString ACCEPT_LANGUAGE = AsciiString.cached("accept-language"); + /** + * {@code "accept-ranges"} + */ + public static final AsciiString ACCEPT_RANGES = AsciiString.cached("accept-ranges"); + /** + * {@code "accept-patch"} + */ + public static final AsciiString ACCEPT_PATCH = AsciiString.cached("accept-patch"); + /** + * {@code "access-control-allow-credentials"} + */ + public static final AsciiString ACCESS_CONTROL_ALLOW_CREDENTIALS = + AsciiString.cached("access-control-allow-credentials"); + /** + * {@code "access-control-allow-headers"} + */ + public static final AsciiString ACCESS_CONTROL_ALLOW_HEADERS = + AsciiString.cached("access-control-allow-headers"); + /** + * {@code "access-control-allow-methods"} + */ + public static final AsciiString ACCESS_CONTROL_ALLOW_METHODS = + AsciiString.cached("access-control-allow-methods"); + /** + * {@code "access-control-allow-origin"} + */ + public static final AsciiString ACCESS_CONTROL_ALLOW_ORIGIN = + AsciiString.cached("access-control-allow-origin"); + /** + * {@code "access-control-allow-origin"} + */ + public static final AsciiString ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK = + AsciiString.cached("access-control-allow-private-network"); + /** + * {@code "access-control-expose-headers"} + */ + public static final AsciiString ACCESS_CONTROL_EXPOSE_HEADERS = + 
AsciiString.cached("access-control-expose-headers"); + /** + * {@code "access-control-max-age"} + */ + public static final AsciiString ACCESS_CONTROL_MAX_AGE = AsciiString.cached("access-control-max-age"); + /** + * {@code "access-control-request-headers"} + */ + public static final AsciiString ACCESS_CONTROL_REQUEST_HEADERS = + AsciiString.cached("access-control-request-headers"); + /** + * {@code "access-control-request-method"} + */ + public static final AsciiString ACCESS_CONTROL_REQUEST_METHOD = + AsciiString.cached("access-control-request-method"); + /** + * {@code "access-control-request-private-network"} + */ + public static final AsciiString ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK = + AsciiString.cached("access-control-request-private-network"); + /** + * {@code "age"} + */ + public static final AsciiString AGE = AsciiString.cached("age"); + /** + * {@code "allow"} + */ + public static final AsciiString ALLOW = AsciiString.cached("allow"); + /** + * {@code "authorization"} + */ + public static final AsciiString AUTHORIZATION = AsciiString.cached("authorization"); + /** + * {@code "cache-control"} + */ + public static final AsciiString CACHE_CONTROL = AsciiString.cached("cache-control"); + /** + * {@code "connection"} + */ + public static final AsciiString CONNECTION = AsciiString.cached("connection"); + /** + * {@code "content-base"} + */ + public static final AsciiString CONTENT_BASE = AsciiString.cached("content-base"); + /** + * {@code "content-encoding"} + */ + public static final AsciiString CONTENT_ENCODING = AsciiString.cached("content-encoding"); + /** + * {@code "content-language"} + */ + public static final AsciiString CONTENT_LANGUAGE = AsciiString.cached("content-language"); + /** + * {@code "content-length"} + */ + public static final AsciiString CONTENT_LENGTH = AsciiString.cached("content-length"); + /** + * {@code "content-location"} + */ + public static final AsciiString CONTENT_LOCATION = AsciiString.cached("content-location"); + /** + * 
{@code "content-transfer-encoding"} + */ + public static final AsciiString CONTENT_TRANSFER_ENCODING = AsciiString.cached("content-transfer-encoding"); + /** + * {@code "content-disposition"} + */ + public static final AsciiString CONTENT_DISPOSITION = AsciiString.cached("content-disposition"); + /** + * {@code "content-md5"} + */ + public static final AsciiString CONTENT_MD5 = AsciiString.cached("content-md5"); + /** + * {@code "content-range"} + */ + public static final AsciiString CONTENT_RANGE = AsciiString.cached("content-range"); + /** + * {@code "content-security-policy"} + */ + public static final AsciiString CONTENT_SECURITY_POLICY = AsciiString.cached("content-security-policy"); + /** + * {@code "content-type"} + */ + public static final AsciiString CONTENT_TYPE = AsciiString.cached("content-type"); + /** + * {@code "cookie"} + */ + public static final AsciiString COOKIE = AsciiString.cached("cookie"); + /** + * {@code "date"} + */ + public static final AsciiString DATE = AsciiString.cached("date"); + /** + * {@code "dnt"} + */ + public static final AsciiString DNT = AsciiString.cached("dnt"); + /** + * {@code "etag"} + */ + public static final AsciiString ETAG = AsciiString.cached("etag"); + /** + * {@code "expect"} + */ + public static final AsciiString EXPECT = AsciiString.cached("expect"); + /** + * {@code "expires"} + */ + public static final AsciiString EXPIRES = AsciiString.cached("expires"); + /** + * {@code "from"} + */ + public static final AsciiString FROM = AsciiString.cached("from"); + /** + * {@code "host"} + */ + public static final AsciiString HOST = AsciiString.cached("host"); + /** + * {@code "if-match"} + */ + public static final AsciiString IF_MATCH = AsciiString.cached("if-match"); + /** + * {@code "if-modified-since"} + */ + public static final AsciiString IF_MODIFIED_SINCE = AsciiString.cached("if-modified-since"); + /** + * {@code "if-none-match"} + */ + public static final AsciiString IF_NONE_MATCH = 
AsciiString.cached("if-none-match"); + /** + * {@code "if-range"} + */ + public static final AsciiString IF_RANGE = AsciiString.cached("if-range"); + /** + * {@code "if-unmodified-since"} + */ + public static final AsciiString IF_UNMODIFIED_SINCE = AsciiString.cached("if-unmodified-since"); + /** + * @deprecated use {@link #CONNECTION} + * + * {@code "keep-alive"} + */ + @Deprecated + public static final AsciiString KEEP_ALIVE = AsciiString.cached("keep-alive"); + /** + * {@code "last-modified"} + */ + public static final AsciiString LAST_MODIFIED = AsciiString.cached("last-modified"); + /** + * {@code "location"} + */ + public static final AsciiString LOCATION = AsciiString.cached("location"); + /** + * {@code "max-forwards"} + */ + public static final AsciiString MAX_FORWARDS = AsciiString.cached("max-forwards"); + /** + * {@code "origin"} + */ + public static final AsciiString ORIGIN = AsciiString.cached("origin"); + /** + * {@code "pragma"} + */ + public static final AsciiString PRAGMA = AsciiString.cached("pragma"); + /** + * {@code "proxy-authenticate"} + */ + public static final AsciiString PROXY_AUTHENTICATE = AsciiString.cached("proxy-authenticate"); + /** + * {@code "proxy-authorization"} + */ + public static final AsciiString PROXY_AUTHORIZATION = AsciiString.cached("proxy-authorization"); + /** + * @deprecated use {@link #CONNECTION} + * + * {@code "proxy-connection"} + */ + @Deprecated + public static final AsciiString PROXY_CONNECTION = AsciiString.cached("proxy-connection"); + /** + * {@code "range"} + */ + public static final AsciiString RANGE = AsciiString.cached("range"); + /** + * {@code "referer"} + */ + public static final AsciiString REFERER = AsciiString.cached("referer"); + /** + * {@code "retry-after"} + */ + public static final AsciiString RETRY_AFTER = AsciiString.cached("retry-after"); + /** + * {@code "sec-websocket-key1"} + */ + public static final AsciiString SEC_WEBSOCKET_KEY1 = AsciiString.cached("sec-websocket-key1"); + /** + * 
{@code "sec-websocket-key2"} + */ + public static final AsciiString SEC_WEBSOCKET_KEY2 = AsciiString.cached("sec-websocket-key2"); + /** + * {@code "sec-websocket-location"} + */ + public static final AsciiString SEC_WEBSOCKET_LOCATION = AsciiString.cached("sec-websocket-location"); + /** + * {@code "sec-websocket-origin"} + */ + public static final AsciiString SEC_WEBSOCKET_ORIGIN = AsciiString.cached("sec-websocket-origin"); + /** + * {@code "sec-websocket-protocol"} + */ + public static final AsciiString SEC_WEBSOCKET_PROTOCOL = AsciiString.cached("sec-websocket-protocol"); + /** + * {@code "sec-websocket-version"} + */ + public static final AsciiString SEC_WEBSOCKET_VERSION = AsciiString.cached("sec-websocket-version"); + /** + * {@code "sec-websocket-key"} + */ + public static final AsciiString SEC_WEBSOCKET_KEY = AsciiString.cached("sec-websocket-key"); + /** + * {@code "sec-websocket-accept"} + */ + public static final AsciiString SEC_WEBSOCKET_ACCEPT = AsciiString.cached("sec-websocket-accept"); + /** + * {@code "sec-websocket-protocol"} + */ + public static final AsciiString SEC_WEBSOCKET_EXTENSIONS = AsciiString.cached("sec-websocket-extensions"); + /** + * {@code "server"} + */ + public static final AsciiString SERVER = AsciiString.cached("server"); + /** + * {@code "set-cookie"} + */ + public static final AsciiString SET_COOKIE = AsciiString.cached("set-cookie"); + /** + * {@code "set-cookie2"} + */ + public static final AsciiString SET_COOKIE2 = AsciiString.cached("set-cookie2"); + /** + * {@code "te"} + */ + public static final AsciiString TE = AsciiString.cached("te"); + /** + * {@code "trailer"} + */ + public static final AsciiString TRAILER = AsciiString.cached("trailer"); + /** + * {@code "transfer-encoding"} + */ + public static final AsciiString TRANSFER_ENCODING = AsciiString.cached("transfer-encoding"); + /** + * {@code "upgrade"} + */ + public static final AsciiString UPGRADE = AsciiString.cached("upgrade"); + /** + * {@code 
"upgrade-insecure-requests"} + */ + public static final AsciiString UPGRADE_INSECURE_REQUESTS = AsciiString.cached("upgrade-insecure-requests"); + /** + * {@code "user-agent"} + */ + public static final AsciiString USER_AGENT = AsciiString.cached("user-agent"); + /** + * {@code "vary"} + */ + public static final AsciiString VARY = AsciiString.cached("vary"); + /** + * {@code "via"} + */ + public static final AsciiString VIA = AsciiString.cached("via"); + /** + * {@code "warning"} + */ + public static final AsciiString WARNING = AsciiString.cached("warning"); + /** + * {@code "websocket-location"} + */ + public static final AsciiString WEBSOCKET_LOCATION = AsciiString.cached("websocket-location"); + /** + * {@code "websocket-origin"} + */ + public static final AsciiString WEBSOCKET_ORIGIN = AsciiString.cached("websocket-origin"); + /** + * {@code "websocket-protocol"} + */ + public static final AsciiString WEBSOCKET_PROTOCOL = AsciiString.cached("websocket-protocol"); + /** + * {@code "www-authenticate"} + */ + public static final AsciiString WWW_AUTHENTICATE = AsciiString.cached("www-authenticate"); + /** + * {@code "x-frame-options"} + */ + public static final AsciiString X_FRAME_OPTIONS = AsciiString.cached("x-frame-options"); + /** + * {@code "x-requested-with"} + */ + public static final AsciiString X_REQUESTED_WITH = AsciiString.cached("x-requested-with"); + + /** + * {@code "alt-svc"} + */ + public static final AsciiString ALT_SVC = AsciiString.cached("alt-svc"); + + private HttpHeaderNames() { } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValidationUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValidationUtil.java new file mode 100644 index 0000000..6cbc0d1 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValidationUtil.java @@ -0,0 +1,300 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this 
file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.AsciiString.contentEqualsIgnoreCase; + +/** + * Functions used to perform various validations of HTTP header names and values. + */ +@UnstableApi +public final class HttpHeaderValidationUtil { + private HttpHeaderValidationUtil() { + } + + /** + * Check if a header name is "connection related". + *

+ * The RFC9110 only specify an incomplete + * list of the following headers: + * + *

    + *
  • Connection
  • + *
  • Proxy-Connection
  • + *
  • Keep-Alive
  • + *
  • TE
  • + *
  • Transfer-Encoding
  • + *
  • Upgrade
  • + *
+ * + * @param name the name of the header to check. The check is case-insensitive. + * @param ignoreTeHeader {@code true} if the TE header should be ignored by this check. + * This is relevant for HTTP/2 header validation, where the TE header has special rules. + * @return {@code true} if the given header name is one of the specified connection-related headers. + */ + @SuppressWarnings("deprecation") // We need to check for deprecated headers as well. + public static boolean isConnectionHeader(CharSequence name, boolean ignoreTeHeader) { + // These are the known standard and non-standard connection related headers: + // - upgrade (7 chars) + // - connection (10 chars) + // - keep-alive (10 chars) + // - proxy-connection (16 chars) + // - transfer-encoding (17 chars) + // + // See https://datatracker.ietf.org/doc/html/rfc9113#section-8.2.2 + // and https://datatracker.ietf.org/doc/html/rfc9110#section-7.6.1 + // for the list of connection related headers. + // + // We scan for these based on the length, then double-check any matching name. + int len = name.length(); + switch (len) { + case 2: return ignoreTeHeader? false : contentEqualsIgnoreCase(name, HttpHeaderNames.TE); + case 7: return contentEqualsIgnoreCase(name, HttpHeaderNames.UPGRADE); + case 10: return contentEqualsIgnoreCase(name, HttpHeaderNames.CONNECTION) || + contentEqualsIgnoreCase(name, HttpHeaderNames.KEEP_ALIVE); + case 16: return contentEqualsIgnoreCase(name, HttpHeaderNames.PROXY_CONNECTION); + case 17: return contentEqualsIgnoreCase(name, HttpHeaderNames.TRANSFER_ENCODING); + default: + return false; + } + } + + /** + * If the given header is {@link HttpHeaderNames#TE} and the given header value is not + * {@link HttpHeaderValues#TRAILERS}, then return {@code true}. Otherwie, {@code false}. + *

+ * The string comparisons are case-insensitive. + *

+ * This check is important for HTTP/2 header validation. + * + * @param name the header name to check if it is TE or not. + * @param value the header value to check if it is something other than TRAILERS. + * @return {@code true} only if the header name is TE, and the header value is not + * TRAILERS. Otherwise, {@code false}. + */ + public static boolean isTeNotTrailers(CharSequence name, CharSequence value) { + if (name.length() == 2) { + return contentEqualsIgnoreCase(name, HttpHeaderNames.TE) && + !contentEqualsIgnoreCase(value, HttpHeaderValues.TRAILERS); + } + return false; + } + + /** + * Validate the given HTTP header value by searching for any illegal characters. + * + * @param value the HTTP header value to validate. + * @return the index of the first illegal character found, or {@code -1} if there are none and the header value is + * valid. + */ + public static int validateValidHeaderValue(CharSequence value) { + int length = value.length(); + if (length == 0) { + return -1; + } + if (value instanceof AsciiString) { + return verifyValidHeaderValueAsciiString((AsciiString) value); + } + return verifyValidHeaderValueCharSequence(value); + } + + private static int verifyValidHeaderValueAsciiString(AsciiString value) { + // Validate value to field-content rule. 
+ // field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] + // field-vchar = VCHAR / obs-text + // VCHAR = %x21-7E ; visible (printing) characters + // obs-text = %x80-FF + // SP = %x20 + // HTAB = %x09 ; horizontal tab + // See: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + // And: https://datatracker.ietf.org/doc/html/rfc5234#appendix-B.1 + final byte[] array = value.array(); + final int start = value.arrayOffset(); + int b = array[start] & 0xFF; + if (b < 0x21 || b == 0x7F) { + return 0; + } + int length = value.length(); + for (int i = start + 1; i < length; i++) { + b = array[i] & 0xFF; + if (b < 0x20 && b != 0x09 || b == 0x7F) { + return i - start; + } + } + return -1; + } + + private static int verifyValidHeaderValueCharSequence(CharSequence value) { + // Validate value to field-content rule. + // field-content = field-vchar [ 1*( SP / HTAB ) field-vchar ] + // field-vchar = VCHAR / obs-text + // VCHAR = %x21-7E ; visible (printing) characters + // obs-text = %x80-FF + // SP = %x20 + // HTAB = %x09 ; horizontal tab + // See: https://datatracker.ietf.org/doc/html/rfc7230#section-3.2 + // And: https://datatracker.ietf.org/doc/html/rfc5234#appendix-B.1 + int b = value.charAt(0); + if (b < 0x21 || b == 0x7F) { + return 0; + } + int length = value.length(); + for (int i = 1; i < length; i++) { + b = value.charAt(i); + if (b < 0x20 && b != 0x09 || b == 0x7F) { + return i; + } + } + return -1; + } + + /** + * Validate a token contains only allowed + * characters. + *

+ * The token format is used for variety of HTTP + * components, like cookie-name, + * field-name of a + * header-field, or + * request method. + * + * @param token the token to validate. + * @return the index of the first invalid token character found, or {@code -1} if there are none. + */ + public static int validateToken(CharSequence token) { + if (token instanceof AsciiString) { + return validateAsciiStringToken((AsciiString) token); + } + return validateCharSequenceToken(token); + } + + /** + * Validate that an {@link AsciiString} contain onlu valid + * token characters. + * + * @param token the ascii string to validate. + */ + private static int validateAsciiStringToken(AsciiString token) { + byte[] array = token.array(); + for (int i = token.arrayOffset(), len = token.arrayOffset() + token.length(); i < len; i++) { + if (!BitSet128.contains(array[i], TOKEN_CHARS_HIGH, TOKEN_CHARS_LOW)) { + return i - token.arrayOffset(); + } + } + return -1; + } + + /** + * Validate that a {@link CharSequence} contain onlu valid + * token characters. + * + * @param token the character sequence to validate. + */ + private static int validateCharSequenceToken(CharSequence token) { + for (int i = 0, len = token.length(); i < len; i++) { + byte value = (byte) token.charAt(i); + if (!BitSet128.contains(value, TOKEN_CHARS_HIGH, TOKEN_CHARS_LOW)) { + return i; + } + } + return -1; + } + + private static final long TOKEN_CHARS_HIGH; + private static final long TOKEN_CHARS_LOW; + static { + // HEADER + // header-field = field-name ":" OWS field-value OWS + // + // field-name = token + // token = 1*tchar + // + // tchar = "!" / "#" / "$" / "%" / "&" / "'" / "*" + // / "+" / "-" / "." / "^" / "_" / "`" / "|" / "~" + // / DIGIT / ALPHA + // ; any VCHAR, except delimiters. 
+ // Delimiters are chosen + // from the set of US-ASCII visual characters not allowed in a token + // (DQUOTE and "(),/:;<=>?@[\]{}") + // + // COOKIE + // cookie-pair = cookie-name "=" cookie-value + // cookie-name = token + // token = 1* + // CTL = + // separators = "(" | ")" | "<" | ">" | "@" + // | "," | ";" | ":" | "\" | <"> + // | "/" | "[" | "]" | "?" | "=" + // | "{" | "}" | SP | HT + // + // field-name's token is equivalent to cookie-name's token, we can reuse the tchar mask for both: + BitSet128 tokenChars = new BitSet128() + .range('0', '9').range('a', 'z').range('A', 'Z') // Alphanumeric. + .bits('-', '.', '_', '~') // Unreserved characters. + .bits('!', '#', '$', '%', '&', '\'', '*', '+', '^', '`', '|'); // Token special characters. + TOKEN_CHARS_HIGH = tokenChars.high(); + TOKEN_CHARS_LOW = tokenChars.low(); + } + + private static final class BitSet128 { + private long high; + private long low; + + BitSet128 range(char fromInc, char toInc) { + for (int bit = fromInc; bit <= toInc; bit++) { + if (bit < 64) { + low |= 1L << bit; + } else { + high |= 1L << bit - 64; + } + } + return this; + } + + BitSet128 bits(char... 
bits) { + for (char bit : bits) { + if (bit < 64) { + low |= 1L << bit; + } else { + high |= 1L << bit - 64; + } + } + return this; + } + + long high() { + return high; + } + + long low() { + return low; + } + + static boolean contains(byte bit, long high, long low) { + if (bit < 0) { + return false; + } + if (bit < 64) { + return 0 != (low & 1L << bit); + } + return 0 != (high & 1L << bit - 64); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java new file mode 100644 index 0000000..6b09c16 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaderValues.java @@ -0,0 +1,255 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; + +/** + * Standard HTTP header values. 
+ */
+public final class HttpHeaderValues {
+    /**
+     * {@code "application/json"}
+     */
+    public static final AsciiString APPLICATION_JSON = AsciiString.cached("application/json");
+    /**
+     * {@code "application/x-www-form-urlencoded"}
+     */
+    public static final AsciiString APPLICATION_X_WWW_FORM_URLENCODED =
+            AsciiString.cached("application/x-www-form-urlencoded");
+    /**
+     * {@code "application/octet-stream"}
+     */
+    public static final AsciiString APPLICATION_OCTET_STREAM = AsciiString.cached("application/octet-stream");
+    /**
+     * {@code "application/xhtml+xml"}
+     */
+    public static final AsciiString APPLICATION_XHTML = AsciiString.cached("application/xhtml+xml");
+    /**
+     * {@code "application/xml"}
+     */
+    public static final AsciiString APPLICATION_XML = AsciiString.cached("application/xml");
+    /**
+     * {@code "application/zstd"}
+     */
+    public static final AsciiString APPLICATION_ZSTD = AsciiString.cached("application/zstd");
+    /**
+     * {@code "attachment"}
+     * See {@link HttpHeaderNames#CONTENT_DISPOSITION}
+     */
+    public static final AsciiString ATTACHMENT = AsciiString.cached("attachment");
+    /**
+     * {@code "base64"}
+     */
+    public static final AsciiString BASE64 = AsciiString.cached("base64");
+    /**
+     * {@code "binary"}
+     */
+    public static final AsciiString BINARY = AsciiString.cached("binary");
+    /**
+     * {@code "boundary"}
+     */
+    public static final AsciiString BOUNDARY = AsciiString.cached("boundary");
+    /**
+     * {@code "bytes"}
+     */
+    public static final AsciiString BYTES = AsciiString.cached("bytes");
+    /**
+     * {@code "charset"}
+     */
+    public static final AsciiString CHARSET = AsciiString.cached("charset");
+    /**
+     * {@code "chunked"}
+     */
+    public static final AsciiString CHUNKED = AsciiString.cached("chunked");
+    /**
+     * {@code "close"}
+     */
+    public static final AsciiString CLOSE = AsciiString.cached("close");
+    /**
+     * {@code "compress"}
+     */
+    public static final AsciiString COMPRESS = AsciiString.cached("compress");
+    /**
+     * {@code "100-continue"}
+     */
+    public static final AsciiString CONTINUE = AsciiString.cached("100-continue");
+    /**
+     * {@code "deflate"}
+     */
+    public static final AsciiString DEFLATE = AsciiString.cached("deflate");
+    /**
+     * {@code "x-deflate"}
+     */
+    public static final AsciiString X_DEFLATE = AsciiString.cached("x-deflate");
+    /**
+     * {@code "file"}
+     * See {@link HttpHeaderNames#CONTENT_DISPOSITION}
+     */
+    public static final AsciiString FILE = AsciiString.cached("file");
+    /**
+     * {@code "filename"}
+     * See {@link HttpHeaderNames#CONTENT_DISPOSITION}
+     */
+    public static final AsciiString FILENAME = AsciiString.cached("filename");
+    /**
+     * {@code "form-data"}
+     * See {@link HttpHeaderNames#CONTENT_DISPOSITION}
+     */
+    public static final AsciiString FORM_DATA = AsciiString.cached("form-data");
+    /**
+     * {@code "gzip"}
+     */
+    public static final AsciiString GZIP = AsciiString.cached("gzip");
+    /**
+     * {@code "br"}
+     */
+    public static final AsciiString BR = AsciiString.cached("br");
+
+    /**
+     * {@code "snappy"}
+     */
+    public static final AsciiString SNAPPY = AsciiString.cached("snappy");
+
+    /**
+     * {@code "zstd"}
+     */
+    public static final AsciiString ZSTD = AsciiString.cached("zstd");
+    /**
+     * {@code "gzip,deflate"}
+     */
+    public static final AsciiString GZIP_DEFLATE = AsciiString.cached("gzip,deflate");
+    /**
+     * {@code "x-gzip"}
+     */
+    public static final AsciiString X_GZIP = AsciiString.cached("x-gzip");
+    /**
+     * {@code "identity"}
+     */
+    public static final AsciiString IDENTITY = AsciiString.cached("identity");
+    /**
+     * {@code "keep-alive"}
+     */
+    public static final AsciiString KEEP_ALIVE = AsciiString.cached("keep-alive");
+    /**
+     * {@code "max-age"}
+     */
+    public static final AsciiString MAX_AGE = AsciiString.cached("max-age");
+    /**
+     * {@code "max-stale"}
+     */
+    public static final AsciiString MAX_STALE = AsciiString.cached("max-stale");
+    /**
+     * {@code "min-fresh"}
+     */
+    public static final AsciiString MIN_FRESH = AsciiString.cached("min-fresh");
+    /**
+     * {@code "multipart/form-data"}
+     */
+    public static final AsciiString MULTIPART_FORM_DATA = AsciiString.cached("multipart/form-data");
+    /**
+     * {@code "multipart/mixed"}
+     */
+    public static final AsciiString MULTIPART_MIXED = AsciiString.cached("multipart/mixed");
+    /**
+     * {@code "must-revalidate"}
+     */
+    public static final AsciiString MUST_REVALIDATE = AsciiString.cached("must-revalidate");
+    /**
+     * {@code "name"}
+     * See {@link HttpHeaderNames#CONTENT_DISPOSITION}
+     */
+    public static final AsciiString NAME = AsciiString.cached("name");
+    /**
+     * {@code "no-cache"}
+     */
+    public static final AsciiString NO_CACHE = AsciiString.cached("no-cache");
+    /**
+     * {@code "no-store"}
+     */
+    public static final AsciiString NO_STORE = AsciiString.cached("no-store");
+    /**
+     * {@code "no-transform"}
+     */
+    public static final AsciiString NO_TRANSFORM = AsciiString.cached("no-transform");
+    /**
+     * {@code "none"}
+     */
+    public static final AsciiString NONE = AsciiString.cached("none");
+    /**
+     * {@code "0"}
+     */
+    public static final AsciiString ZERO = AsciiString.cached("0");
+    /**
+     * {@code "only-if-cached"}
+     */
+    public static final AsciiString ONLY_IF_CACHED = AsciiString.cached("only-if-cached");
+    /**
+     * {@code "private"}
+     */
+    public static final AsciiString PRIVATE = AsciiString.cached("private");
+    /**
+     * {@code "proxy-revalidate"}
+     */
+    public static final AsciiString PROXY_REVALIDATE = AsciiString.cached("proxy-revalidate");
+    /**
+     * {@code "public"}
+     */
+    public static final AsciiString PUBLIC = AsciiString.cached("public");
+    /**
+     * {@code "quoted-printable"}
+     */
+    public static final AsciiString QUOTED_PRINTABLE = AsciiString.cached("quoted-printable");
+    /**
+     * {@code "s-maxage"}
+     */
+    public static final AsciiString S_MAXAGE = AsciiString.cached("s-maxage");
+    /**
+     * {@code "text/css"}
+     */
+    public static final AsciiString TEXT_CSS = AsciiString.cached("text/css");
+    /**
+     * {@code "text/html"}
+     */
+    public static final AsciiString TEXT_HTML = AsciiString.cached("text/html");
+    /**
+     * {@code "text/event-stream"}
+     */
+    public static final AsciiString TEXT_EVENT_STREAM = AsciiString.cached("text/event-stream");
+    /**
+     * {@code "text/plain"}
+     */
+    public static final AsciiString TEXT_PLAIN = AsciiString.cached("text/plain");
+    /**
+     * {@code "trailers"}
+     */
+    public static final AsciiString TRAILERS = AsciiString.cached("trailers");
+    /**
+     * {@code "upgrade"}
+     */
+    public static final AsciiString UPGRADE = AsciiString.cached("upgrade");
+    /**
+     * {@code "websocket"}
+     */
+    public static final AsciiString WEBSOCKET = AsciiString.cached("websocket");
+    /**
+     * {@code "XMLHttpRequest"}
+     */
+    public static final AsciiString XML_HTTP_REQUEST = AsciiString.cached("XMLHttpRequest");
+
+    private HttpHeaderValues() { }
+}
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java
new file mode 100644
index 0000000..ede9173
--- /dev/null
+++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeaders.java
@@ -0,0 +1,1705 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.handler.codec.DateFormatter; +import io.netty.handler.codec.Headers; +import io.netty.handler.codec.HeadersUtils; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; + +import java.text.ParseException; +import java.util.Calendar; +import java.util.Date; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +import static io.netty.util.AsciiString.contentEquals; +import static io.netty.util.AsciiString.contentEqualsIgnoreCase; +import static io.netty.util.AsciiString.trim; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Provides the constants for the standard HTTP header names and values and + * commonly used utility methods that accesses an {@link HttpMessage}. + *

+ * Concrete instances of this class are most easily obtained from its default factory:
+ * {@link DefaultHttpHeadersFactory#headersFactory()}.
+ */
+public abstract class HttpHeaders implements Iterable<Map.Entry<String, String>> {
+    /**
+     * @deprecated Use {@link EmptyHttpHeaders#INSTANCE}.
+     * <p>

+ * The instance is instantiated here to break the cyclic static initialization between {@link EmptyHttpHeaders} and + * {@link HttpHeaders}. The issue is that if someone accesses {@link EmptyHttpHeaders#INSTANCE} before + * {@link HttpHeaders#EMPTY_HEADERS} then {@link HttpHeaders#EMPTY_HEADERS} will be {@code null}. + */ + @Deprecated + public static final HttpHeaders EMPTY_HEADERS = EmptyHttpHeaders.instance(); + + /** + * @deprecated Use {@link HttpHeaderNames} instead. + * + * Standard HTTP header names. + */ + @Deprecated + public static final class Names { + /** + * {@code "Accept"} + */ + public static final String ACCEPT = "Accept"; + /** + * {@code "Accept-Charset"} + */ + public static final String ACCEPT_CHARSET = "Accept-Charset"; + /** + * {@code "Accept-Encoding"} + */ + public static final String ACCEPT_ENCODING = "Accept-Encoding"; + /** + * {@code "Accept-Language"} + */ + public static final String ACCEPT_LANGUAGE = "Accept-Language"; + /** + * {@code "Accept-Ranges"} + */ + public static final String ACCEPT_RANGES = "Accept-Ranges"; + /** + * {@code "Accept-Patch"} + */ + public static final String ACCEPT_PATCH = "Accept-Patch"; + /** + * {@code "Access-Control-Allow-Credentials"} + */ + public static final String ACCESS_CONTROL_ALLOW_CREDENTIALS = "Access-Control-Allow-Credentials"; + /** + * {@code "Access-Control-Allow-Headers"} + */ + public static final String ACCESS_CONTROL_ALLOW_HEADERS = "Access-Control-Allow-Headers"; + /** + * {@code "Access-Control-Allow-Methods"} + */ + public static final String ACCESS_CONTROL_ALLOW_METHODS = "Access-Control-Allow-Methods"; + /** + * {@code "Access-Control-Allow-Origin"} + */ + public static final String ACCESS_CONTROL_ALLOW_ORIGIN = "Access-Control-Allow-Origin"; + /** + * {@code "Access-Control-Expose-Headers"} + */ + public static final String ACCESS_CONTROL_EXPOSE_HEADERS = "Access-Control-Expose-Headers"; + /** + * {@code "Access-Control-Max-Age"} + */ + public static final String 
ACCESS_CONTROL_MAX_AGE = "Access-Control-Max-Age"; + /** + * {@code "Access-Control-Request-Headers"} + */ + public static final String ACCESS_CONTROL_REQUEST_HEADERS = "Access-Control-Request-Headers"; + /** + * {@code "Access-Control-Request-Method"} + */ + public static final String ACCESS_CONTROL_REQUEST_METHOD = "Access-Control-Request-Method"; + /** + * {@code "Age"} + */ + public static final String AGE = "Age"; + /** + * {@code "Allow"} + */ + public static final String ALLOW = "Allow"; + /** + * {@code "Authorization"} + */ + public static final String AUTHORIZATION = "Authorization"; + /** + * {@code "Cache-Control"} + */ + public static final String CACHE_CONTROL = "Cache-Control"; + /** + * {@code "Connection"} + */ + public static final String CONNECTION = "Connection"; + /** + * {@code "Content-Base"} + */ + public static final String CONTENT_BASE = "Content-Base"; + /** + * {@code "Content-Encoding"} + */ + public static final String CONTENT_ENCODING = "Content-Encoding"; + /** + * {@code "Content-Language"} + */ + public static final String CONTENT_LANGUAGE = "Content-Language"; + /** + * {@code "Content-Length"} + */ + public static final String CONTENT_LENGTH = "Content-Length"; + /** + * {@code "Content-Location"} + */ + public static final String CONTENT_LOCATION = "Content-Location"; + /** + * {@code "Content-Transfer-Encoding"} + */ + public static final String CONTENT_TRANSFER_ENCODING = "Content-Transfer-Encoding"; + /** + * {@code "Content-MD5"} + */ + public static final String CONTENT_MD5 = "Content-MD5"; + /** + * {@code "Content-Range"} + */ + public static final String CONTENT_RANGE = "Content-Range"; + /** + * {@code "Content-Type"} + */ + public static final String CONTENT_TYPE = "Content-Type"; + /** + * {@code "Cookie"} + */ + public static final String COOKIE = "Cookie"; + /** + * {@code "Date"} + */ + public static final String DATE = "Date"; + /** + * {@code "ETag"} + */ + public static final String ETAG = "ETag"; + /** + * 
{@code "Expect"} + */ + public static final String EXPECT = "Expect"; + /** + * {@code "Expires"} + */ + public static final String EXPIRES = "Expires"; + /** + * {@code "From"} + */ + public static final String FROM = "From"; + /** + * {@code "Host"} + */ + public static final String HOST = "Host"; + /** + * {@code "If-Match"} + */ + public static final String IF_MATCH = "If-Match"; + /** + * {@code "If-Modified-Since"} + */ + public static final String IF_MODIFIED_SINCE = "If-Modified-Since"; + /** + * {@code "If-None-Match"} + */ + public static final String IF_NONE_MATCH = "If-None-Match"; + /** + * {@code "If-Range"} + */ + public static final String IF_RANGE = "If-Range"; + /** + * {@code "If-Unmodified-Since"} + */ + public static final String IF_UNMODIFIED_SINCE = "If-Unmodified-Since"; + /** + * {@code "Last-Modified"} + */ + public static final String LAST_MODIFIED = "Last-Modified"; + /** + * {@code "Location"} + */ + public static final String LOCATION = "Location"; + /** + * {@code "Max-Forwards"} + */ + public static final String MAX_FORWARDS = "Max-Forwards"; + /** + * {@code "Origin"} + */ + public static final String ORIGIN = "Origin"; + /** + * {@code "Pragma"} + */ + public static final String PRAGMA = "Pragma"; + /** + * {@code "Proxy-Authenticate"} + */ + public static final String PROXY_AUTHENTICATE = "Proxy-Authenticate"; + /** + * {@code "Proxy-Authorization"} + */ + public static final String PROXY_AUTHORIZATION = "Proxy-Authorization"; + /** + * {@code "Range"} + */ + public static final String RANGE = "Range"; + /** + * {@code "Referer"} + */ + public static final String REFERER = "Referer"; + /** + * {@code "Retry-After"} + */ + public static final String RETRY_AFTER = "Retry-After"; + /** + * {@code "Sec-WebSocket-Key1"} + */ + public static final String SEC_WEBSOCKET_KEY1 = "Sec-WebSocket-Key1"; + /** + * {@code "Sec-WebSocket-Key2"} + */ + public static final String SEC_WEBSOCKET_KEY2 = "Sec-WebSocket-Key2"; + /** + * {@code 
"Sec-WebSocket-Location"} + */ + public static final String SEC_WEBSOCKET_LOCATION = "Sec-WebSocket-Location"; + /** + * {@code "Sec-WebSocket-Origin"} + */ + public static final String SEC_WEBSOCKET_ORIGIN = "Sec-WebSocket-Origin"; + /** + * {@code "Sec-WebSocket-Protocol"} + */ + public static final String SEC_WEBSOCKET_PROTOCOL = "Sec-WebSocket-Protocol"; + /** + * {@code "Sec-WebSocket-Version"} + */ + public static final String SEC_WEBSOCKET_VERSION = "Sec-WebSocket-Version"; + /** + * {@code "Sec-WebSocket-Key"} + */ + public static final String SEC_WEBSOCKET_KEY = "Sec-WebSocket-Key"; + /** + * {@code "Sec-WebSocket-Accept"} + */ + public static final String SEC_WEBSOCKET_ACCEPT = "Sec-WebSocket-Accept"; + /** + * {@code "Server"} + */ + public static final String SERVER = "Server"; + /** + * {@code "Set-Cookie"} + */ + public static final String SET_COOKIE = "Set-Cookie"; + /** + * {@code "Set-Cookie2"} + */ + public static final String SET_COOKIE2 = "Set-Cookie2"; + /** + * {@code "TE"} + */ + public static final String TE = "TE"; + /** + * {@code "Trailer"} + */ + public static final String TRAILER = "Trailer"; + /** + * {@code "Transfer-Encoding"} + */ + public static final String TRANSFER_ENCODING = "Transfer-Encoding"; + /** + * {@code "Upgrade"} + */ + public static final String UPGRADE = "Upgrade"; + /** + * {@code "User-Agent"} + */ + public static final String USER_AGENT = "User-Agent"; + /** + * {@code "Vary"} + */ + public static final String VARY = "Vary"; + /** + * {@code "Via"} + */ + public static final String VIA = "Via"; + /** + * {@code "Warning"} + */ + public static final String WARNING = "Warning"; + /** + * {@code "WebSocket-Location"} + */ + public static final String WEBSOCKET_LOCATION = "WebSocket-Location"; + /** + * {@code "WebSocket-Origin"} + */ + public static final String WEBSOCKET_ORIGIN = "WebSocket-Origin"; + /** + * {@code "WebSocket-Protocol"} + */ + public static final String WEBSOCKET_PROTOCOL = "WebSocket-Protocol"; + 
/** + * {@code "WWW-Authenticate"} + */ + public static final String WWW_AUTHENTICATE = "WWW-Authenticate"; + + private Names() { + } + } + + /** + * @deprecated Use {@link HttpHeaderValues} instead. + * + * Standard HTTP header values. + */ + @Deprecated + public static final class Values { + /** + * {@code "application/json"} + */ + public static final String APPLICATION_JSON = "application/json"; + /** + * {@code "application/x-www-form-urlencoded"} + */ + public static final String APPLICATION_X_WWW_FORM_URLENCODED = + "application/x-www-form-urlencoded"; + /** + * {@code "base64"} + */ + public static final String BASE64 = "base64"; + /** + * {@code "binary"} + */ + public static final String BINARY = "binary"; + /** + * {@code "boundary"} + */ + public static final String BOUNDARY = "boundary"; + /** + * {@code "bytes"} + */ + public static final String BYTES = "bytes"; + /** + * {@code "charset"} + */ + public static final String CHARSET = "charset"; + /** + * {@code "chunked"} + */ + public static final String CHUNKED = "chunked"; + /** + * {@code "close"} + */ + public static final String CLOSE = "close"; + /** + * {@code "compress"} + */ + public static final String COMPRESS = "compress"; + /** + * {@code "100-continue"} + */ + public static final String CONTINUE = "100-continue"; + /** + * {@code "deflate"} + */ + public static final String DEFLATE = "deflate"; + /** + * {@code "gzip"} + */ + public static final String GZIP = "gzip"; + /** + * {@code "gzip,deflate"} + */ + public static final String GZIP_DEFLATE = "gzip,deflate"; + /** + * {@code "identity"} + */ + public static final String IDENTITY = "identity"; + /** + * {@code "keep-alive"} + */ + public static final String KEEP_ALIVE = "keep-alive"; + /** + * {@code "max-age"} + */ + public static final String MAX_AGE = "max-age"; + /** + * {@code "max-stale"} + */ + public static final String MAX_STALE = "max-stale"; + /** + * {@code "min-fresh"} + */ + public static final String MIN_FRESH = 
"min-fresh"; + /** + * {@code "multipart/form-data"} + */ + public static final String MULTIPART_FORM_DATA = "multipart/form-data"; + /** + * {@code "must-revalidate"} + */ + public static final String MUST_REVALIDATE = "must-revalidate"; + /** + * {@code "no-cache"} + */ + public static final String NO_CACHE = "no-cache"; + /** + * {@code "no-store"} + */ + public static final String NO_STORE = "no-store"; + /** + * {@code "no-transform"} + */ + public static final String NO_TRANSFORM = "no-transform"; + /** + * {@code "none"} + */ + public static final String NONE = "none"; + /** + * {@code "only-if-cached"} + */ + public static final String ONLY_IF_CACHED = "only-if-cached"; + /** + * {@code "private"} + */ + public static final String PRIVATE = "private"; + /** + * {@code "proxy-revalidate"} + */ + public static final String PROXY_REVALIDATE = "proxy-revalidate"; + /** + * {@code "public"} + */ + public static final String PUBLIC = "public"; + /** + * {@code "quoted-printable"} + */ + public static final String QUOTED_PRINTABLE = "quoted-printable"; + /** + * {@code "s-maxage"} + */ + public static final String S_MAXAGE = "s-maxage"; + /** + * {@code "trailers"} + */ + public static final String TRAILERS = "trailers"; + /** + * {@code "Upgrade"} + */ + public static final String UPGRADE = "Upgrade"; + /** + * {@code "WebSocket"} + */ + public static final String WEBSOCKET = "WebSocket"; + + private Values() { + } + } + + /** + * @deprecated Use {@link HttpUtil#isKeepAlive(HttpMessage)} instead. + * + * Returns {@code true} if and only if the connection can remain open and + * thus 'kept alive'. This methods respects the value of the + * {@code "Connection"} header first and then the return value of + * {@link HttpVersion#isKeepAliveDefault()}. + */ + @Deprecated + public static boolean isKeepAlive(HttpMessage message) { + return HttpUtil.isKeepAlive(message); + } + + /** + * @deprecated Use {@link HttpUtil#setKeepAlive(HttpMessage, boolean)} instead. 
+ * + * Sets the value of the {@code "Connection"} header depending on the + * protocol version of the specified message. This getMethod sets or removes + * the {@code "Connection"} header depending on what the default keep alive + * mode of the message's protocol version is, as specified by + * {@link HttpVersion#isKeepAliveDefault()}. + *

+     * <ul>
+     * <li>If the connection is kept alive by default:
+     *     <ul>
+     *     <li>set to {@code "close"} if {@code keepAlive} is {@code false}.</li>
+     *     <li>remove otherwise.</li>
+     *     </ul></li>
+     * <li>If the connection is closed by default:
+     *     <ul>
+     *     <li>set to {@code "keep-alive"} if {@code keepAlive} is {@code true}.</li>
+     *     <li>remove otherwise.</li>
+     *     </ul></li>
+     * </ul>
+ */ + @Deprecated + public static void setKeepAlive(HttpMessage message, boolean keepAlive) { + HttpUtil.setKeepAlive(message, keepAlive); + } + + /** + * @deprecated Use {@link #get(CharSequence)} instead. + */ + @Deprecated + public static String getHeader(HttpMessage message, String name) { + return message.headers().get(name); + } + + /** + * @deprecated Use {@link #get(CharSequence)} instead. + * + * Returns the header value with the specified header name. If there are + * more than one header value for the specified header name, the first + * value is returned. + * + * @return the header value or {@code null} if there is no such header + */ + @Deprecated + public static String getHeader(HttpMessage message, CharSequence name) { + return message.headers().get(name); + } + + /** + * @deprecated Use {@link #get(CharSequence, String)} instead. + * + * @see #getHeader(HttpMessage, CharSequence, String) + */ + @Deprecated + public static String getHeader(HttpMessage message, String name, String defaultValue) { + return message.headers().get(name, defaultValue); + } + + /** + * @deprecated Use {@link #get(CharSequence, String)} instead. + * + * Returns the header value with the specified header name. If there are + * more than one header value for the specified header name, the first + * value is returned. + * + * @return the header value or the {@code defaultValue} if there is no such + * header + */ + @Deprecated + public static String getHeader(HttpMessage message, CharSequence name, String defaultValue) { + return message.headers().get(name, defaultValue); + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * @see #setHeader(HttpMessage, CharSequence, Object) + */ + @Deprecated + public static void setHeader(HttpMessage message, String name, Object value) { + message.headers().set(name, value); + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * Sets a new header with the specified name and value. 
If there is an + * existing header with the same name, the existing header is removed. + * If the specified value is not a {@link String}, it is converted into a + * {@link String} by {@link Object#toString()}, except for {@link Date} + * and {@link Calendar} which are formatted to the date format defined in + * RFC2616. + */ + @Deprecated + public static void setHeader(HttpMessage message, CharSequence name, Object value) { + message.headers().set(name, value); + } + + /** + * @deprecated Use {@link #set(CharSequence, Iterable)} instead. + * + * @see #setHeader(HttpMessage, CharSequence, Iterable) + */ + @Deprecated + public static void setHeader(HttpMessage message, String name, Iterable values) { + message.headers().set(name, values); + } + + /** + * @deprecated Use {@link #set(CharSequence, Iterable)} instead. + * + * Sets a new header with the specified name and values. If there is an + * existing header with the same name, the existing header is removed. + * This getMethod can be represented approximately as the following code: + *
+     * removeHeader(message, name);
+     * for (Object v: values) {
+     *     if (v == null) {
+     *         break;
+     *     }
+     *     addHeader(message, name, v);
+     * }
+     * 
+ */ + @Deprecated + public static void setHeader(HttpMessage message, CharSequence name, Iterable values) { + message.headers().set(name, values); + } + + /** + * @deprecated Use {@link #add(CharSequence, Object)} instead. + * + * @see #addHeader(HttpMessage, CharSequence, Object) + */ + @Deprecated + public static void addHeader(HttpMessage message, String name, Object value) { + message.headers().add(name, value); + } + + /** + * @deprecated Use {@link #add(CharSequence, Object)} instead. + * + * Adds a new header with the specified name and value. + * If the specified value is not a {@link String}, it is converted into a + * {@link String} by {@link Object#toString()}, except for {@link Date} + * and {@link Calendar} which are formatted to the date format defined in + * RFC2616. + */ + @Deprecated + public static void addHeader(HttpMessage message, CharSequence name, Object value) { + message.headers().add(name, value); + } + + /** + * @deprecated Use {@link #remove(CharSequence)} instead. + * + * @see #removeHeader(HttpMessage, CharSequence) + */ + @Deprecated + public static void removeHeader(HttpMessage message, String name) { + message.headers().remove(name); + } + + /** + * @deprecated Use {@link #remove(CharSequence)} instead. + * + * Removes the header with the specified name. + */ + @Deprecated + public static void removeHeader(HttpMessage message, CharSequence name) { + message.headers().remove(name); + } + + /** + * @deprecated Use {@link #clear()} instead. + * + * Removes all headers from the specified message. + */ + @Deprecated + public static void clearHeaders(HttpMessage message) { + message.headers().clear(); + } + + /** + * @deprecated Use {@link #getInt(CharSequence)} instead. + * + * @see #getIntHeader(HttpMessage, CharSequence) + */ + @Deprecated + public static int getIntHeader(HttpMessage message, String name) { + return getIntHeader(message, (CharSequence) name); + } + + /** + * @deprecated Use {@link #getInt(CharSequence)} instead. 
+ * + * Returns the integer header value with the specified header name. If + * there are more than one header value for the specified header name, the + * first value is returned. + * + * @return the header value + * @throws NumberFormatException + * if there is no such header or the header value is not a number + */ + @Deprecated + public static int getIntHeader(HttpMessage message, CharSequence name) { + String value = message.headers().get(name); + if (value == null) { + throw new NumberFormatException("header not found: " + name); + } + return Integer.parseInt(value); + } + + /** + * @deprecated Use {@link #getInt(CharSequence, int)} instead. + * + * @see #getIntHeader(HttpMessage, CharSequence, int) + */ + @Deprecated + public static int getIntHeader(HttpMessage message, String name, int defaultValue) { + return message.headers().getInt(name, defaultValue); + } + + /** + * @deprecated Use {@link #getInt(CharSequence, int)} instead. + * + * Returns the integer header value with the specified header name. If + * there are more than one header value for the specified header name, the + * first value is returned. + * + * @return the header value or the {@code defaultValue} if there is no such + * header or the header value is not a number + */ + @Deprecated + public static int getIntHeader(HttpMessage message, CharSequence name, int defaultValue) { + return message.headers().getInt(name, defaultValue); + } + + /** + * @deprecated Use {@link #setInt(CharSequence, int)} instead. + * + * @see #setIntHeader(HttpMessage, CharSequence, int) + */ + @Deprecated + public static void setIntHeader(HttpMessage message, String name, int value) { + message.headers().setInt(name, value); + } + + /** + * @deprecated Use {@link #setInt(CharSequence, int)} instead. + * + * Sets a new integer header with the specified name and value. If there + * is an existing header with the same name, the existing header is removed. 
+ */ + @Deprecated + public static void setIntHeader(HttpMessage message, CharSequence name, int value) { + message.headers().setInt(name, value); + } + + /** + * @deprecated Use {@link #set(CharSequence, Iterable)} instead. + * + * @see #setIntHeader(HttpMessage, CharSequence, Iterable) + */ + @Deprecated + public static void setIntHeader(HttpMessage message, String name, Iterable values) { + message.headers().set(name, values); + } + + /** + * @deprecated Use {@link #set(CharSequence, Iterable)} instead. + * + * Sets a new integer header with the specified name and values. If there + * is an existing header with the same name, the existing header is removed. + */ + @Deprecated + public static void setIntHeader(HttpMessage message, CharSequence name, Iterable values) { + message.headers().set(name, values); + } + + /** + * @deprecated Use {@link #add(CharSequence, Iterable)} instead. + * + * @see #addIntHeader(HttpMessage, CharSequence, int) + */ + @Deprecated + public static void addIntHeader(HttpMessage message, String name, int value) { + message.headers().add(name, value); + } + + /** + * @deprecated Use {@link #addInt(CharSequence, int)} instead. + * + * Adds a new integer header with the specified name and value. + */ + @Deprecated + public static void addIntHeader(HttpMessage message, CharSequence name, int value) { + message.headers().addInt(name, value); + } + + /** + * @deprecated Use {@link #getTimeMillis(CharSequence)} instead. + * + * @see #getDateHeader(HttpMessage, CharSequence) + */ + @Deprecated + public static Date getDateHeader(HttpMessage message, String name) throws ParseException { + return getDateHeader(message, (CharSequence) name); + } + + /** + * @deprecated Use {@link #getTimeMillis(CharSequence)} instead. + * + * Returns the date header value with the specified header name. If + * there are more than one header value for the specified header name, the + * first value is returned. 
+ * + * @return the header value + * @throws ParseException + * if there is no such header or the header value is not a formatted date + */ + @Deprecated + public static Date getDateHeader(HttpMessage message, CharSequence name) throws ParseException { + String value = message.headers().get(name); + if (value == null) { + throw new ParseException("header not found: " + name, 0); + } + Date date = DateFormatter.parseHttpDate(value); + if (date == null) { + throw new ParseException("header can't be parsed into a Date: " + value, 0); + } + return date; + } + + /** + * @deprecated Use {@link #getTimeMillis(CharSequence, long)} instead. + * + * @see #getDateHeader(HttpMessage, CharSequence, Date) + */ + @Deprecated + public static Date getDateHeader(HttpMessage message, String name, Date defaultValue) { + return getDateHeader(message, (CharSequence) name, defaultValue); + } + + /** + * @deprecated Use {@link #getTimeMillis(CharSequence, long)} instead. + * + * Returns the date header value with the specified header name. If + * there are more than one header value for the specified header name, the + * first value is returned. + * + * @return the header value or the {@code defaultValue} if there is no such + * header or the header value is not a formatted date + */ + @Deprecated + public static Date getDateHeader(HttpMessage message, CharSequence name, Date defaultValue) { + final String value = getHeader(message, name); + Date date = DateFormatter.parseHttpDate(value); + return date != null ? date : defaultValue; + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * @see #setDateHeader(HttpMessage, CharSequence, Date) + */ + @Deprecated + public static void setDateHeader(HttpMessage message, String name, Date value) { + setDateHeader(message, (CharSequence) name, value); + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * Sets a new date header with the specified name and value. 
If there + * is an existing header with the same name, the existing header is removed. + * The specified value is formatted as defined in + * RFC2616 + */ + @Deprecated + public static void setDateHeader(HttpMessage message, CharSequence name, Date value) { + if (value != null) { + message.headers().set(name, DateFormatter.format(value)); + } else { + message.headers().set(name, null); + } + } + + /** + * @deprecated Use {@link #set(CharSequence, Iterable)} instead. + * + * @see #setDateHeader(HttpMessage, CharSequence, Iterable) + */ + @Deprecated + public static void setDateHeader(HttpMessage message, String name, Iterable values) { + message.headers().set(name, values); + } + + /** + * @deprecated Use {@link #set(CharSequence, Iterable)} instead. + * + * Sets a new date header with the specified name and values. If there + * is an existing header with the same name, the existing header is removed. + * The specified values are formatted as defined in + * RFC2616 + */ + @Deprecated + public static void setDateHeader(HttpMessage message, CharSequence name, Iterable values) { + message.headers().set(name, values); + } + + /** + * @deprecated Use {@link #add(CharSequence, Object)} instead. + * + * @see #addDateHeader(HttpMessage, CharSequence, Date) + */ + @Deprecated + public static void addDateHeader(HttpMessage message, String name, Date value) { + message.headers().add(name, value); + } + + /** + * @deprecated Use {@link #add(CharSequence, Object)} instead. + * + * Adds a new date header with the specified name and value. The specified + * value is formatted as defined in + * RFC2616 + */ + @Deprecated + public static void addDateHeader(HttpMessage message, CharSequence name, Date value) { + message.headers().add(name, value); + } + + /** + * @deprecated Use {@link HttpUtil#getContentLength(HttpMessage)} instead. + * + * Returns the length of the content. 
Please note that this value is + * not retrieved from {@link HttpContent#content()} but from the + * {@code "Content-Length"} header, and thus they are independent from each + * other. + * + * @return the content length + * + * @throws NumberFormatException + * if the message does not have the {@code "Content-Length"} header + * or its value is not a number + */ + @Deprecated + public static long getContentLength(HttpMessage message) { + return HttpUtil.getContentLength(message); + } + + /** + * @deprecated Use {@link HttpUtil#getContentLength(HttpMessage, long)} instead. + * + * Returns the length of the content. Please note that this value is + * not retrieved from {@link HttpContent#content()} but from the + * {@code "Content-Length"} header, and thus they are independent from each + * other. + * + * @return the content length or {@code defaultValue} if this message does + * not have the {@code "Content-Length"} header or its value is not + * a number + */ + @Deprecated + public static long getContentLength(HttpMessage message, long defaultValue) { + return HttpUtil.getContentLength(message, defaultValue); + } + + /** + * @deprecated Use {@link HttpUtil#setContentLength(HttpMessage, long)} instead. + */ + @Deprecated + public static void setContentLength(HttpMessage message, long length) { + HttpUtil.setContentLength(message, length); + } + + /** + * @deprecated Use {@link #get(CharSequence)} instead. + * + * Returns the value of the {@code "Host"} header. + */ + @Deprecated + public static String getHost(HttpMessage message) { + return message.headers().get(HttpHeaderNames.HOST); + } + + /** + * @deprecated Use {@link #get(CharSequence, String)} instead. + * + * Returns the value of the {@code "Host"} header. If there is no such + * header, the {@code defaultValue} is returned. 
+ */ + @Deprecated + public static String getHost(HttpMessage message, String defaultValue) { + return message.headers().get(HttpHeaderNames.HOST, defaultValue); + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * @see #setHost(HttpMessage, CharSequence) + */ + @Deprecated + public static void setHost(HttpMessage message, String value) { + message.headers().set(HttpHeaderNames.HOST, value); + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * Sets the {@code "Host"} header. + */ + @Deprecated + public static void setHost(HttpMessage message, CharSequence value) { + message.headers().set(HttpHeaderNames.HOST, value); + } + + /** + * @deprecated Use {@link #getTimeMillis(CharSequence)} instead. + * + * Returns the value of the {@code "Date"} header. + * + * @throws ParseException + * if there is no such header or the header value is not a formatted date + */ + @Deprecated + public static Date getDate(HttpMessage message) throws ParseException { + return getDateHeader(message, HttpHeaderNames.DATE); + } + + /** + * @deprecated Use {@link #getTimeMillis(CharSequence, long)} instead. + * + * Returns the value of the {@code "Date"} header. If there is no such + * header or the header is not a formatted date, the {@code defaultValue} + * is returned. + */ + @Deprecated + public static Date getDate(HttpMessage message, Date defaultValue) { + return getDateHeader(message, HttpHeaderNames.DATE, defaultValue); + } + + /** + * @deprecated Use {@link #set(CharSequence, Object)} instead. + * + * Sets the {@code "Date"} header. + */ + @Deprecated + public static void setDate(HttpMessage message, Date value) { + message.headers().set(HttpHeaderNames.DATE, value); + } + + /** + * @deprecated Use {@link HttpUtil#is100ContinueExpected(HttpMessage)} instead. + * + * Returns {@code true} if and only if the specified message contains the + * {@code "Expect: 100-continue"} header. 
+ */ + @Deprecated + public static boolean is100ContinueExpected(HttpMessage message) { + return HttpUtil.is100ContinueExpected(message); + } + + /** + * @deprecated Use {@link HttpUtil#set100ContinueExpected(HttpMessage, boolean)} instead. + * + * Sets the {@code "Expect: 100-continue"} header to the specified message. + * If there is any existing {@code "Expect"} header, they are replaced with + * the new one. + */ + @Deprecated + public static void set100ContinueExpected(HttpMessage message) { + HttpUtil.set100ContinueExpected(message, true); + } + + /** + * @deprecated Use {@link HttpUtil#set100ContinueExpected(HttpMessage, boolean)} instead. + * + * Sets or removes the {@code "Expect: 100-continue"} header to / from the + * specified message. If {@code set} is {@code true}, + * the {@code "Expect: 100-continue"} header is set and all other previous + * {@code "Expect"} headers are removed. Otherwise, all {@code "Expect"} + * headers are removed completely. + */ + @Deprecated + public static void set100ContinueExpected(HttpMessage message, boolean set) { + HttpUtil.set100ContinueExpected(message, set); + } + + /** + * @deprecated Use {@link HttpUtil#isTransferEncodingChunked(HttpMessage)} instead. + * + * Checks to see if the transfer encoding in a specified {@link HttpMessage} is chunked + * + * @param message The message to check + * @return True if transfer encoding is chunked, otherwise false + */ + @Deprecated + public static boolean isTransferEncodingChunked(HttpMessage message) { + return HttpUtil.isTransferEncodingChunked(message); + } + + /** + * @deprecated Use {@link HttpUtil#setTransferEncodingChunked(HttpMessage, boolean)} instead. + */ + @Deprecated + public static void removeTransferEncodingChunked(HttpMessage m) { + HttpUtil.setTransferEncodingChunked(m, false); + } + + /** + * @deprecated Use {@link HttpUtil#setTransferEncodingChunked(HttpMessage, boolean)} instead. 
+ */ + @Deprecated + public static void setTransferEncodingChunked(HttpMessage m) { + HttpUtil.setTransferEncodingChunked(m, true); + } + + /** + * @deprecated Use {@link HttpUtil#isContentLengthSet(HttpMessage)} instead. + */ + @Deprecated + public static boolean isContentLengthSet(HttpMessage m) { + return HttpUtil.isContentLengthSet(m); + } + + /** + * @deprecated Use {@link AsciiString#contentEqualsIgnoreCase(CharSequence, CharSequence)} instead. + */ + @Deprecated + public static boolean equalsIgnoreCase(CharSequence name1, CharSequence name2) { + return contentEqualsIgnoreCase(name1, name2); + } + + @Deprecated + public static void encodeAscii(CharSequence seq, ByteBuf buf) { + if (seq instanceof AsciiString) { + ByteBufUtil.copy((AsciiString) seq, 0, buf, seq.length()); + } else { + buf.writeCharSequence(seq, CharsetUtil.US_ASCII); + } + } + + /** + * @deprecated Use {@link AsciiString} instead. + *
+     * <p>
+ * Create a new {@link CharSequence} which is optimized for reuse as {@link HttpHeaders} name or value. + * So if you have a Header name or value that you want to reuse you should make use of this. + */ + @Deprecated + public static CharSequence newEntity(String name) { + return new AsciiString(name); + } + + protected HttpHeaders() { } + + /** + * @see #get(CharSequence) + */ + public abstract String get(String name); + + /** + * Returns the value of a header with the specified name. If there are + * more than one values for the specified name, the first value is returned. + * + * @param name The name of the header to search + * @return The first header value or {@code null} if there is no such header + * @see #getAsString(CharSequence) + */ + public String get(CharSequence name) { + return get(name.toString()); + } + + /** + * Returns the value of a header with the specified name. If there are + * more than one values for the specified name, the first value is returned. + * + * @param name The name of the header to search + * @return The first header value or {@code defaultValue} if there is no such header + */ + public String get(CharSequence name, String defaultValue) { + String value = get(name); + if (value == null) { + return defaultValue; + } + return value; + } + + /** + * Returns the integer value of a header with the specified name. If there are more than one values for the + * specified name, the first value is returned. + * + * @param name the name of the header to search + * @return the first header value if the header is found and its value is an integer. {@code null} if there's no + * such header or its value is not an integer. + */ + public abstract Integer getInt(CharSequence name); + + /** + * Returns the integer value of a header with the specified name. If there are more than one values for the + * specified name, the first value is returned. 
+ * + * @param name the name of the header to search + * @param defaultValue the default value + * @return the first header value if the header is found and its value is an integer. {@code defaultValue} if + * there's no such header or its value is not an integer. + */ + public abstract int getInt(CharSequence name, int defaultValue); + + /** + * Returns the short value of a header with the specified name. If there are more than one values for the + * specified name, the first value is returned. + * + * @param name the name of the header to search + * @return the first header value if the header is found and its value is a short. {@code null} if there's no + * such header or its value is not a short. + */ + public abstract Short getShort(CharSequence name); + + /** + * Returns the short value of a header with the specified name. If there are more than one values for the + * specified name, the first value is returned. + * + * @param name the name of the header to search + * @param defaultValue the default value + * @return the first header value if the header is found and its value is a short. {@code defaultValue} if + * there's no such header or its value is not a short. + */ + public abstract short getShort(CharSequence name, short defaultValue); + + /** + * Returns the date value of a header with the specified name. If there are more than one values for the + * specified name, the first value is returned. + * + * @param name the name of the header to search + * @return the first header value if the header is found and its value is a date. {@code null} if there's no + * such header or its value is not a date. + */ + public abstract Long getTimeMillis(CharSequence name); + + /** + * Returns the date value of a header with the specified name. If there are more than one values for the + * specified name, the first value is returned. 
+ * + * @param name the name of the header to search + * @param defaultValue the default value + * @return the first header value if the header is found and its value is a date. {@code defaultValue} if + * there's no such header or its value is not a date. + */ + public abstract long getTimeMillis(CharSequence name, long defaultValue); + + /** + * @see #getAll(CharSequence) + */ + public abstract List getAll(String name); + + /** + * Returns the values of headers with the specified name + * + * @param name The name of the headers to search + * @return A {@link List} of header values which will be empty if no values + * are found + * @see #getAllAsString(CharSequence) + */ + public List getAll(CharSequence name) { + return getAll(name.toString()); + } + + /** + * Returns a new {@link List} that contains all headers in this object. Note that modifying the + * returned {@link List} will not affect the state of this object. If you intend to enumerate over the header + * entries only, use {@link #iterator()} instead, which has much less overhead. + * @see #iteratorCharSequence() + */ + public abstract List> entries(); + + /** + * @see #contains(CharSequence) + */ + public abstract boolean contains(String name); + + /** + * @deprecated It is preferred to use {@link #iteratorCharSequence()} unless you need {@link String}. + * If {@link String} is required then use {@link #iteratorAsString()}. + */ + @Deprecated + @Override + public abstract Iterator> iterator(); + + /** + * @return Iterator over the name/value header pairs. + */ + public abstract Iterator> iteratorCharSequence(); + + /** + * Equivalent to {@link #getAll(String)} but it is possible that no intermediate list is generated. + * @param name the name of the header to retrieve + * @return an {@link Iterator} of header values corresponding to {@code name}. 
+ */ + public Iterator valueStringIterator(CharSequence name) { + return getAll(name).iterator(); + } + + /** + * Equivalent to {@link #getAll(String)} but it is possible that no intermediate list is generated. + * @param name the name of the header to retrieve + * @return an {@link Iterator} of header values corresponding to {@code name}. + */ + public Iterator valueCharSequenceIterator(CharSequence name) { + return valueStringIterator(name); + } + + /** + * Checks to see if there is a header with the specified name + * + * @param name The name of the header to search for + * @return True if at least one header is found + */ + public boolean contains(CharSequence name) { + return contains(name.toString()); + } + + /** + * Checks if no header exists. + */ + public abstract boolean isEmpty(); + + /** + * Returns the number of headers in this object. + */ + public abstract int size(); + + /** + * Returns a new {@link Set} that contains the names of all headers in this object. Note that modifying the + * returned {@link Set} will not affect the state of this object. If you intend to enumerate over the header + * entries only, use {@link #iterator()} instead, which has much less overhead. + */ + public abstract Set names(); + + /** + * @see #add(CharSequence, Object) + */ + public abstract HttpHeaders add(String name, Object value); + + /** + * Adds a new header with the specified name and value. + * + * If the specified value is not a {@link String}, it is converted + * into a {@link String} by {@link Object#toString()}, except in the cases + * of {@link Date} and {@link Calendar}, which are formatted to the date + * format defined in RFC2616. 
+ * + * @param name The name of the header being added + * @param value The value of the header being added + * + * @return {@code this} + */ + public HttpHeaders add(CharSequence name, Object value) { + return add(name.toString(), value); + } + + /** + * @see #add(CharSequence, Iterable) + */ + public abstract HttpHeaders add(String name, Iterable values); + + /** + * Adds a new header with the specified name and values. + * + * This getMethod can be represented approximately as the following code: + *
+     * <pre>
+     * for (Object v: values) {
+     *     if (v == null) {
+     *         break;
+     *     }
+     *     headers.add(name, v);
+     * }
+     * </pre>
+ * + * @param name The name of the headers being set + * @param values The values of the headers being set + * @return {@code this} + */ + public HttpHeaders add(CharSequence name, Iterable values) { + return add(name.toString(), values); + } + + /** + * Adds all header entries of the specified {@code headers}. + * + * @return {@code this} + */ + public HttpHeaders add(HttpHeaders headers) { + ObjectUtil.checkNotNull(headers, "headers"); + for (Map.Entry e: headers) { + add(e.getKey(), e.getValue()); + } + return this; + } + + /** + * Add the {@code name} to {@code value}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + public abstract HttpHeaders addInt(CharSequence name, int value); + + /** + * Add the {@code name} to {@code value}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + public abstract HttpHeaders addShort(CharSequence name, short value); + + /** + * @see #set(CharSequence, Object) + */ + public abstract HttpHeaders set(String name, Object value); + + /** + * Sets a header with the specified name and value. + * + * If there is an existing header with the same name, it is removed. + * If the specified value is not a {@link String}, it is converted into a + * {@link String} by {@link Object#toString()}, except for {@link Date} + * and {@link Calendar}, which are formatted to the date format defined in + * RFC2616. + * + * @param name The name of the header being set + * @param value The value of the header being set + * @return {@code this} + */ + public HttpHeaders set(CharSequence name, Object value) { + return set(name.toString(), value); + } + + /** + * @see #set(CharSequence, Iterable) + */ + public abstract HttpHeaders set(String name, Iterable values); + + /** + * Sets a header with the specified name and values. + * + * If there is an existing header with the same name, it is removed. 
+     * This method can be represented approximately as the following code:
+     * <pre>
+     * headers.remove(name);
+     * for (Object v: values) {
+     *     if (v == null) {
+     *         break;
+     *     }
+     *     headers.add(name, v);
+     * }
+     * </pre>
+ * + * @param name The name of the headers being set + * @param values The values of the headers being set + * @return {@code this} + */ + public HttpHeaders set(CharSequence name, Iterable values) { + return set(name.toString(), values); + } + + /** + * Cleans the current header entries and copies all header entries of the specified {@code headers}. + * + * @return {@code this} + */ + public HttpHeaders set(HttpHeaders headers) { + checkNotNull(headers, "headers"); + + clear(); + + if (headers.isEmpty()) { + return this; + } + + for (Entry entry : headers) { + add(entry.getKey(), entry.getValue()); + } + return this; + } + + /** + * Retains all current headers but calls {@link #set(String, Object)} for each entry in {@code headers} + * + * @param headers The headers used to {@link #set(String, Object)} values in this instance + * @return {@code this} + */ + public HttpHeaders setAll(HttpHeaders headers) { + checkNotNull(headers, "headers"); + + if (headers.isEmpty()) { + return this; + } + + for (Entry entry : headers) { + set(entry.getKey(), entry.getValue()); + } + return this; + } + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + public abstract HttpHeaders setInt(CharSequence name, int value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + public abstract HttpHeaders setShort(CharSequence name, short value); + + /** + * @see #remove(CharSequence) + */ + public abstract HttpHeaders remove(String name); + + /** + * Removes the header with the specified name. 
+ * + * @param name The name of the header to remove + * @return {@code this} + */ + public HttpHeaders remove(CharSequence name) { + return remove(name.toString()); + } + + /** + * Removes all headers from this {@link HttpMessage}. + * + * @return {@code this} + */ + public abstract HttpHeaders clear(); + + /** + * @see #contains(CharSequence, CharSequence, boolean) + */ + public boolean contains(String name, String value, boolean ignoreCase) { + Iterator valueIterator = valueStringIterator(name); + if (ignoreCase) { + while (valueIterator.hasNext()) { + if (valueIterator.next().equalsIgnoreCase(value)) { + return true; + } + } + } else { + while (valueIterator.hasNext()) { + if (valueIterator.next().equals(value)) { + return true; + } + } + } + return false; + } + + /** + * Returns {@code true} if a header with the {@code name} and {@code value} exists, {@code false} otherwise. + * This also handles multiple values that are separated with a {@code ,}. + *
+     * <p>
+ * If {@code ignoreCase} is {@code true} then a case insensitive compare is done on the value. + * @param name the name of the header to find + * @param value the value of the header to find + * @param ignoreCase {@code true} then a case insensitive compare is run to compare values. + * otherwise a case sensitive compare is run to compare values. + */ + public boolean containsValue(CharSequence name, CharSequence value, boolean ignoreCase) { + Iterator itr = valueCharSequenceIterator(name); + while (itr.hasNext()) { + if (containsCommaSeparatedTrimmed(itr.next(), value, ignoreCase)) { + return true; + } + } + return false; + } + + private static boolean containsCommaSeparatedTrimmed(CharSequence rawNext, CharSequence expected, + boolean ignoreCase) { + int begin = 0; + int end; + if (ignoreCase) { + if ((end = AsciiString.indexOf(rawNext, ',', begin)) == -1) { + if (contentEqualsIgnoreCase(trim(rawNext), expected)) { + return true; + } + } else { + do { + if (contentEqualsIgnoreCase(trim(rawNext.subSequence(begin, end)), expected)) { + return true; + } + begin = end + 1; + } while ((end = AsciiString.indexOf(rawNext, ',', begin)) != -1); + + if (begin < rawNext.length()) { + if (contentEqualsIgnoreCase(trim(rawNext.subSequence(begin, rawNext.length())), expected)) { + return true; + } + } + } + } else { + if ((end = AsciiString.indexOf(rawNext, ',', begin)) == -1) { + if (contentEquals(trim(rawNext), expected)) { + return true; + } + } else { + do { + if (contentEquals(trim(rawNext.subSequence(begin, end)), expected)) { + return true; + } + begin = end + 1; + } while ((end = AsciiString.indexOf(rawNext, ',', begin)) != -1); + + if (begin < rawNext.length()) { + if (contentEquals(trim(rawNext.subSequence(begin, rawNext.length())), expected)) { + return true; + } + } + } + } + return false; + } + + /** + * {@link Headers#get(Object)} and convert the result to a {@link String}. 
+ * @param name the name of the header to retrieve + * @return the first header value if the header is found. {@code null} if there's no such header. + */ + public final String getAsString(CharSequence name) { + return get(name); + } + + /** + * {@link Headers#getAll(Object)} and convert each element of {@link List} to a {@link String}. + * @param name the name of the header to retrieve + * @return a {@link List} of header values or an empty {@link List} if no values are found. + */ + public final List getAllAsString(CharSequence name) { + return getAll(name); + } + + /** + * {@link Iterator} that converts each {@link Entry}'s key and value to a {@link String}. + */ + public final Iterator> iteratorAsString() { + return iterator(); + } + + /** + * Returns {@code true} if a header with the {@code name} and {@code value} exists, {@code false} otherwise. + *
+     * <p>
+ * If {@code ignoreCase} is {@code true} then a case insensitive compare is done on the value. + * @param name the name of the header to find + * @param value the value of the header to find + * @param ignoreCase {@code true} then a case insensitive compare is run to compare values. + * otherwise a case sensitive compare is run to compare values. + */ + public boolean contains(CharSequence name, CharSequence value, boolean ignoreCase) { + return contains(name.toString(), value.toString(), ignoreCase); + } + + @Override + public String toString() { + return HeadersUtils.toString(getClass(), iteratorCharSequence(), size()); + } + + /** + * Returns a deep copy of the passed in {@link HttpHeaders}. + */ + public HttpHeaders copy() { + return new DefaultHttpHeaders().set(this); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersEncoder.java new file mode 100644 index 0000000..2ba766a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersEncoder.java @@ -0,0 +1,57 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; + +import static io.netty.handler.codec.http.HttpConstants.*; +import static io.netty.handler.codec.http.HttpObjectEncoder.CRLF_SHORT; + +final class HttpHeadersEncoder { + private static final int COLON_AND_SPACE_SHORT = (COLON << 8) | SP; + + private HttpHeadersEncoder() { + } + + static void encoderHeader(CharSequence name, CharSequence value, ByteBuf buf) { + final int nameLen = name.length(); + final int valueLen = value.length(); + final int entryLen = nameLen + valueLen + 4; + buf.ensureWritable(entryLen); + int offset = buf.writerIndex(); + writeAscii(buf, offset, name); + offset += nameLen; + ByteBufUtil.setShortBE(buf, offset, COLON_AND_SPACE_SHORT); + offset += 2; + writeAscii(buf, offset, value); + offset += valueLen; + ByteBufUtil.setShortBE(buf, offset, CRLF_SHORT); + offset += 2; + buf.writerIndex(offset); + } + + private static void writeAscii(ByteBuf buf, int offset, CharSequence value) { + if (value instanceof AsciiString) { + ByteBufUtil.copy((AsciiString) value, 0, buf, offset, value.length()); + } else { + buf.setCharSequence(offset, value, CharsetUtil.US_ASCII); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersFactory.java new file mode 100644 index 0000000..2fffced --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpHeadersFactory.java @@ -0,0 +1,34 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +/** + * An interface for creating {@link HttpHeaders} instances. + *
+ * <p>
+ * The default implementation is {@link DefaultHttpHeadersFactory}, + * and the default instance is {@link DefaultHttpHeadersFactory#headersFactory()}. + */ +public interface HttpHeadersFactory { + /** + * Create a new {@link HttpHeaders} instance. + */ + HttpHeaders newHeaders(); + + /** + * Create a new {@link HttpHeaders} instance, but sized to be as small an object as possible. + */ + HttpHeaders newEmptyHeaders(); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessage.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessage.java new file mode 100644 index 0000000..d0307ce --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessage.java @@ -0,0 +1,49 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + + +/** + * An interface that defines an HTTP message, providing common properties for + * {@link HttpRequest} and {@link HttpResponse}. + * + * @see HttpResponse + * @see HttpRequest + * @see HttpHeaders + */ +public interface HttpMessage extends HttpObject { + + /** + * @deprecated Use {@link #protocolVersion()} instead. 
+ */ + @Deprecated + HttpVersion getProtocolVersion(); + + /** + * Returns the protocol version of this {@link HttpMessage} + */ + HttpVersion protocolVersion(); + + /** + * Set the protocol version of this {@link HttpMessage} + */ + HttpMessage setProtocolVersion(HttpVersion version); + + /** + * Returns the headers of this message. + */ + HttpHeaders headers(); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageDecoderResult.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageDecoderResult.java new file mode 100644 index 0000000..b89252b --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageDecoderResult.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.DecoderResult; + +/** + * A {@link DecoderResult} for {@link HttpMessage}s as produced by an {@link HttpObjectDecoder}. + *
+ * <p>
+ * Please note that there is no guarantee that a {@link HttpObjectDecoder} will produce a {@link + * HttpMessageDecoderResult}. It may simply produce a regular {@link DecoderResult}. This result is intended for + * successful {@link HttpMessage} decoder results. + */ +public final class HttpMessageDecoderResult extends DecoderResult { + + private final int initialLineLength; + private final int headerSize; + + HttpMessageDecoderResult(int initialLineLength, int headerSize) { + super(SIGNAL_SUCCESS); + this.initialLineLength = initialLineLength; + this.headerSize = headerSize; + } + + /** + * The decoded initial line length (in bytes), as controlled by {@code maxInitialLineLength}. + */ + public int initialLineLength() { + return initialLineLength; + } + + /** + * The decoded header size (in bytes), as controlled by {@code maxHeaderSize}. + */ + public int headerSize() { + return headerSize; + } + + /** + * The decoded initial line length plus the decoded header size (in bytes). + */ + public int totalSize() { + return initialLineLength + headerSize; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageUtil.java new file mode 100644 index 0000000..e339f51 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpMessageUtil.java @@ -0,0 +1,113 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
package io.netty.handler.codec.http;

import io.netty.util.internal.StringUtil;

import java.util.Map;

/**
 * Provides some utility methods for building the {@code toString()} output of
 * HTTP message implementations. All methods append a human-readable dump of
 * the message (class name, decoder result, protocol version, initial line and
 * headers) to the supplied {@link StringBuilder} and return it for chaining.
 */
final class HttpMessageUtil {

    /**
     * Appends a textual representation of the given request (common prefix,
     * request line and headers) to {@code buf}.
     */
    static StringBuilder appendRequest(StringBuilder buf, HttpRequest req) {
        appendCommon(buf, req);
        appendInitialLine(buf, req);
        appendHeaders(buf, req.headers());
        removeLastNewLine(buf);
        return buf;
    }

    /**
     * Appends a textual representation of the given response (common prefix,
     * status line and headers) to {@code buf}.
     */
    static StringBuilder appendResponse(StringBuilder buf, HttpResponse res) {
        appendCommon(buf, res);
        appendInitialLine(buf, res);
        appendHeaders(buf, res.headers());
        removeLastNewLine(buf);
        return buf;
    }

    // Shared prefix: "SimpleClassName(decodeResult: ..., version: ...)\n".
    private static void appendCommon(StringBuilder buf, HttpMessage msg) {
        buf.append(StringUtil.simpleClassName(msg));
        buf.append("(decodeResult: ");
        buf.append(msg.decoderResult());
        buf.append(", version: ");
        buf.append(msg.protocolVersion());
        buf.append(')');
        buf.append(StringUtil.NEWLINE);
    }

    /**
     * Appends a textual representation of the given full request, including
     * its trailing headers, to {@code buf}.
     */
    static StringBuilder appendFullRequest(StringBuilder buf, FullHttpRequest req) {
        appendFullCommon(buf, req);
        appendInitialLine(buf, req);
        appendHeaders(buf, req.headers());
        appendHeaders(buf, req.trailingHeaders());
        removeLastNewLine(buf);
        return buf;
    }

    /**
     * Appends a textual representation of the given full response, including
     * its trailing headers, to {@code buf}.
     */
    static StringBuilder appendFullResponse(StringBuilder buf, FullHttpResponse res) {
        appendFullCommon(buf, res);
        appendInitialLine(buf, res);
        appendHeaders(buf, res.headers());
        appendHeaders(buf, res.trailingHeaders());
        removeLastNewLine(buf);
        return buf;
    }

    // Like appendCommon(...) but also dumps the aggregated content buffer.
    private static void appendFullCommon(StringBuilder buf, FullHttpMessage msg) {
        buf.append(StringUtil.simpleClassName(msg));
        buf.append("(decodeResult: ");
        buf.append(msg.decoderResult());
        buf.append(", version: ");
        buf.append(msg.protocolVersion());
        buf.append(", content: ");
        buf.append(msg.content());
        buf.append(')');
        buf.append(StringUtil.NEWLINE);
    }

    // Request line: "METHOD URI VERSION\n".
    private static void appendInitialLine(StringBuilder buf, HttpRequest req) {
        buf.append(req.method());
        buf.append(' ');
        buf.append(req.uri());
        buf.append(' ');
        buf.append(req.protocolVersion());
        buf.append(StringUtil.NEWLINE);
    }

    // Status line: "VERSION STATUS\n".
    private static void appendInitialLine(StringBuilder buf, HttpResponse res) {
        buf.append(res.protocolVersion());
        buf.append(' ');
        buf.append(res.status());
        buf.append(StringUtil.NEWLINE);
    }

    // One "name: value\n" line per header entry.
    private static void appendHeaders(StringBuilder buf, HttpHeaders headers) {
        for (Map.Entry<String, String> e : headers) {
            buf.append(e.getKey());
            buf.append(": ");
            buf.append(e.getValue());
            buf.append(StringUtil.NEWLINE);
        }
    }

    // Every append* call above terminates with NEWLINE, so the buffer is
    // guaranteed to end with one; trim it so callers get a clean last line.
    private static void removeLastNewLine(StringBuilder buf) {
        buf.setLength(buf.length() - StringUtil.NEWLINE.length());
    }

    private HttpMessageUtil() { }
}
+ */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; + +import static io.netty.util.internal.MathUtil.findNextPositivePowerOfTwo; +import static io.netty.util.internal.ObjectUtil.checkNonEmptyAfterTrim; + +/** + * The request method of HTTP or its derived protocols, such as + * RTSP and + * ICAP. + */ +public class HttpMethod implements Comparable { + /** + * The OPTIONS method represents a request for information about the communication options + * available on the request/response chain identified by the Request-URI. This method allows + * the client to determine the options and/or requirements associated with a resource, or the + * capabilities of a server, without implying a resource action or initiating a resource + * retrieval. + */ + public static final HttpMethod OPTIONS = new HttpMethod("OPTIONS"); + + /** + * The GET method means retrieve whatever information (in the form of an entity) is identified + * by the Request-URI. If the Request-URI refers to a data-producing process, it is the + * produced data which shall be returned as the entity in the response and not the source text + * of the process, unless that text happens to be the output of the process. + */ + public static final HttpMethod GET = new HttpMethod("GET"); + + /** + * The HEAD method is identical to GET except that the server MUST NOT return a message-body + * in the response. + */ + public static final HttpMethod HEAD = new HttpMethod("HEAD"); + + /** + * The POST method is used to request that the origin server accept the entity enclosed in the + * request as a new subordinate of the resource identified by the Request-URI in the + * Request-Line. + */ + public static final HttpMethod POST = new HttpMethod("POST"); + + /** + * The PUT method requests that the enclosed entity be stored under the supplied Request-URI. 
+ */ + public static final HttpMethod PUT = new HttpMethod("PUT"); + + /** + * The PATCH method requests that a set of changes described in the + * request entity be applied to the resource identified by the Request-URI. + */ + public static final HttpMethod PATCH = new HttpMethod("PATCH"); + + /** + * The DELETE method requests that the origin server delete the resource identified by the + * Request-URI. + */ + public static final HttpMethod DELETE = new HttpMethod("DELETE"); + + /** + * The TRACE method is used to invoke a remote, application-layer loop- back of the request + * message. + */ + public static final HttpMethod TRACE = new HttpMethod("TRACE"); + + /** + * This specification reserves the method name CONNECT for use with a proxy that can dynamically + * switch to being a tunnel + */ + public static final HttpMethod CONNECT = new HttpMethod("CONNECT"); + + private static final EnumNameMap methodMap; + + static { + methodMap = new EnumNameMap( + new EnumNameMap.Node(OPTIONS.toString(), OPTIONS), + new EnumNameMap.Node(GET.toString(), GET), + new EnumNameMap.Node(HEAD.toString(), HEAD), + new EnumNameMap.Node(POST.toString(), POST), + new EnumNameMap.Node(PUT.toString(), PUT), + new EnumNameMap.Node(PATCH.toString(), PATCH), + new EnumNameMap.Node(DELETE.toString(), DELETE), + new EnumNameMap.Node(TRACE.toString(), TRACE), + new EnumNameMap.Node(CONNECT.toString(), CONNECT)); + } + + /** + * Returns the {@link HttpMethod} represented by the specified name. + * If the specified name is a standard HTTP method name, a cached instance + * will be returned. Otherwise, a new instance will be returned. + */ + public static HttpMethod valueOf(String name) { + // fast-path + if (name == HttpMethod.GET.name()) { + return HttpMethod.GET; + } + if (name == HttpMethod.POST.name()) { + return HttpMethod.POST; + } + // "slow"-path + HttpMethod result = methodMap.get(name); + return result != null ? 
result : new HttpMethod(name); + } + + private final AsciiString name; + + /** + * Creates a new HTTP method with the specified name. You will not need to + * create a new method unless you are implementing a protocol derived from + * HTTP, such as + * RTSP and + * ICAP + */ + public HttpMethod(String name) { + name = checkNonEmptyAfterTrim(name, "name"); + + for (int i = 0; i < name.length(); i ++) { + char c = name.charAt(i); + if (Character.isISOControl(c) || Character.isWhitespace(c)) { + throw new IllegalArgumentException("invalid character in name"); + } + } + + this.name = AsciiString.cached(name); + } + + /** + * Returns the name of this method. + */ + public String name() { + return name.toString(); + } + + /** + * Returns the name of this method. + */ + public AsciiString asciiName() { + return name; + } + + @Override + public int hashCode() { + return name().hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof HttpMethod)) { + return false; + } + + HttpMethod that = (HttpMethod) o; + return name().equals(that.name()); + } + + @Override + public String toString() { + return name.toString(); + } + + @Override + public int compareTo(HttpMethod o) { + if (o == this) { + return 0; + } + return name().compareTo(o.name()); + } + + private static final class EnumNameMap { + private final EnumNameMap.Node[] values; + private final int valuesMask; + + EnumNameMap(EnumNameMap.Node... 
nodes) { + values = (EnumNameMap.Node[]) new EnumNameMap.Node[findNextPositivePowerOfTwo(nodes.length)]; + valuesMask = values.length - 1; + for (EnumNameMap.Node node : nodes) { + int i = hashCode(node.key) & valuesMask; + if (values[i] != null) { + throw new IllegalArgumentException("index " + i + " collision between values: [" + + values[i].key + ", " + node.key + ']'); + } + values[i] = node; + } + } + + T get(String name) { + EnumNameMap.Node node = values[hashCode(name) & valuesMask]; + return node == null || !node.key.equals(name) ? null : node.value; + } + + private static int hashCode(String name) { + // This hash code needs to produce a unique index in the "values" array for each HttpMethod. If new + // HttpMethods are added this algorithm will need to be adjusted. The constructor will "fail fast" if there + // are duplicates detected. + // For example with the current set of HttpMethods it just so happens that the String hash code value + // shifted right by 6 bits modulo 16 is unique relative to all other HttpMethod values. + return name.hashCode() >>> 6; + } + + private static final class Node { + final String key; + final T value; + + Node(String key, T value) { + this.key = key; + this.value = value; + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObject.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObject.java new file mode 100644 index 0000000..62c3841 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObject.java @@ -0,0 +1,27 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.codec.http;

import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.DecoderResultProvider;

/**
 * Marker interface for all objects produced and consumed by the HTTP codec:
 * {@link HttpMessage}s and {@link HttpContent}s. Extends
 * {@link DecoderResultProvider} so every HTTP object carries the
 * {@link DecoderResult} of the decode operation that produced it.
 */
public interface HttpObject extends DecoderResultProvider {
    /**
     * @deprecated Use {@link #decoderResult()} instead.
     */
    @Deprecated
    DecoderResult getDecoderResult();
}
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.MessageAggregator; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderNames.EXPECT; +import static io.netty.handler.codec.http.HttpUtil.getContentLength; + +/** + * A {@link ChannelHandler} that aggregates an {@link HttpMessage} + * and its following {@link HttpContent}s into a single {@link FullHttpRequest} + * or {@link FullHttpResponse} (depending on if it used to handle requests or responses) + * with no following {@link HttpContent}s. It is useful when you don't want to take + * care of HTTP messages whose transfer encoding is 'chunked'. Insert this + * handler after {@link HttpResponseDecoder} in the {@link ChannelPipeline} if being used to handle + * responses, or after {@link HttpRequestDecoder} and {@link HttpResponseEncoder} in the + * {@link ChannelPipeline} if being used to handle requests. + *

+ *
+ *  {@link ChannelPipeline} p = ...;
+ *  ...
+ *  p.addLast("decoder", new {@link HttpRequestDecoder}());
+ *  p.addLast("encoder", new {@link HttpResponseEncoder}());
+ *  p.addLast("aggregator", new {@link HttpObjectAggregator}(1048576));
+ *  ...
+ *  p.addLast("handler", new HttpRequestHandler());
+ *  
+ *
+ *

+ * For convenience, consider putting a {@link HttpServerCodec} before the {@link HttpObjectAggregator} + * as it functions as both a {@link HttpRequestDecoder} and a {@link HttpResponseEncoder}. + *

+ * Be aware that {@link HttpObjectAggregator} may end up sending a {@link HttpResponse}: + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
Response StatusCondition When Sent
100 ContinueA '100-continue' expectation is received and the 'content-length' doesn't exceed maxContentLength
417 Expectation FailedA '100-continue' expectation is received and the 'content-length' exceeds maxContentLength
413 Request Entity Too LargeEither the 'content-length' or the bytes received so far exceed maxContentLength
+ * + * @see FullHttpRequest + * @see FullHttpResponse + * @see HttpResponseDecoder + * @see HttpServerCodec + */ +public class HttpObjectAggregator + extends MessageAggregator { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(HttpObjectAggregator.class); + private static final FullHttpResponse CONTINUE = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER); + private static final FullHttpResponse EXPECTATION_FAILED = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.EXPECTATION_FAILED, Unpooled.EMPTY_BUFFER); + private static final FullHttpResponse TOO_LARGE_CLOSE = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER); + private static final FullHttpResponse TOO_LARGE = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, Unpooled.EMPTY_BUFFER); + + static { + EXPECTATION_FAILED.headers().set(CONTENT_LENGTH, 0); + TOO_LARGE.headers().set(CONTENT_LENGTH, 0); + + TOO_LARGE_CLOSE.headers().set(CONTENT_LENGTH, 0); + TOO_LARGE_CLOSE.headers().set(CONNECTION, HttpHeaderValues.CLOSE); + } + + private final boolean closeOnExpectationFailed; + + /** + * Creates a new instance. + * @param maxContentLength the maximum length of the aggregated content in bytes. + * If the length of the aggregated content exceeds this value, + * {@link #handleOversizedMessage(ChannelHandlerContext, HttpMessage)} will be called. + */ + public HttpObjectAggregator(int maxContentLength) { + this(maxContentLength, false); + } + + /** + * Creates a new instance. + * @param maxContentLength the maximum length of the aggregated content in bytes. + * If the length of the aggregated content exceeds this value, + * {@link #handleOversizedMessage(ChannelHandlerContext, HttpMessage)} will be called. 
+ * @param closeOnExpectationFailed If a 100-continue response is detected but the content length is too large + * then {@code true} means close the connection. otherwise the connection will remain open and data will be + * consumed and discarded until the next request is received. + */ + public HttpObjectAggregator(int maxContentLength, boolean closeOnExpectationFailed) { + super(maxContentLength); + this.closeOnExpectationFailed = closeOnExpectationFailed; + } + + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + return msg instanceof HttpMessage; + } + + @Override + protected boolean isContentMessage(HttpObject msg) throws Exception { + return msg instanceof HttpContent; + } + + @Override + protected boolean isLastContentMessage(HttpContent msg) throws Exception { + return msg instanceof LastHttpContent; + } + + @Override + protected boolean isAggregated(HttpObject msg) throws Exception { + return msg instanceof FullHttpMessage; + } + + @Override + protected boolean isContentLengthInvalid(HttpMessage start, int maxContentLength) { + try { + return getContentLength(start, -1L) > maxContentLength; + } catch (final NumberFormatException e) { + return false; + } + } + + private static Object continueResponse(HttpMessage start, int maxContentLength, ChannelPipeline pipeline) { + if (HttpUtil.isUnsupportedExpectation(start)) { + // if the request contains an unsupported expectation, we return 417 + pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); + return EXPECTATION_FAILED.retainedDuplicate(); + } else if (HttpUtil.is100ContinueExpected(start)) { + // if the request contains 100-continue but the content-length is too large, we return 413 + if (getContentLength(start, -1L) <= maxContentLength) { + return CONTINUE.retainedDuplicate(); + } + pipeline.fireUserEventTriggered(HttpExpectationFailedEvent.INSTANCE); + return TOO_LARGE.retainedDuplicate(); + } + + return null; + } + + @Override + protected Object 
newContinueResponse(HttpMessage start, int maxContentLength, ChannelPipeline pipeline) { + Object response = continueResponse(start, maxContentLength, pipeline); + // we're going to respond based on the request expectation so there's no + // need to propagate the expectation further. + if (response != null) { + start.headers().remove(EXPECT); + } + return response; + } + + @Override + protected boolean closeAfterContinueResponse(Object msg) { + return closeOnExpectationFailed && ignoreContentAfterContinueResponse(msg); + } + + @Override + protected boolean ignoreContentAfterContinueResponse(Object msg) { + if (msg instanceof HttpResponse) { + final HttpResponse httpResponse = (HttpResponse) msg; + return httpResponse.status().codeClass().equals(HttpStatusClass.CLIENT_ERROR); + } + return false; + } + + @Override + protected FullHttpMessage beginAggregation(HttpMessage start, ByteBuf content) throws Exception { + assert !(start instanceof FullHttpMessage); + + HttpUtil.setTransferEncodingChunked(start, false); + + AggregatedFullHttpMessage ret; + if (start instanceof HttpRequest) { + ret = new AggregatedFullHttpRequest((HttpRequest) start, content, null); + } else if (start instanceof HttpResponse) { + ret = new AggregatedFullHttpResponse((HttpResponse) start, content, null); + } else { + throw new Error(); + } + return ret; + } + + @Override + protected void aggregate(FullHttpMessage aggregated, HttpContent content) throws Exception { + if (content instanceof LastHttpContent) { + // Merge trailing headers into the message. + ((AggregatedFullHttpMessage) aggregated).setTrailingHeaders(((LastHttpContent) content).trailingHeaders()); + } + } + + @Override + protected void finishAggregation(FullHttpMessage aggregated) throws Exception { + // Set the 'Content-Length' header. If one isn't already set. 
+ // This is important as HEAD responses will use a 'Content-Length' header which + // does not match the actual body, but the number of bytes that would be + // transmitted if a GET would have been used. + // + // See rfc2616 14.13 Content-Length + if (!HttpUtil.isContentLengthSet(aggregated)) { + aggregated.headers().set( + CONTENT_LENGTH, + String.valueOf(aggregated.content().readableBytes())); + } + } + + @Override + protected void handleOversizedMessage(final ChannelHandlerContext ctx, HttpMessage oversized) throws Exception { + if (oversized instanceof HttpRequest) { + // send back a 413 and close the connection + + // If the client started to send data already, close because it's impossible to recover. + // If keep-alive is off and 'Expect: 100-continue' is missing, no need to leave the connection open. + if (oversized instanceof FullHttpMessage || + !HttpUtil.is100ContinueExpected(oversized) && !HttpUtil.isKeepAlive(oversized)) { + ChannelFuture future = ctx.writeAndFlush(TOO_LARGE_CLOSE.retainedDuplicate()); + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + logger.debug("Failed to send a 413 Request Entity Too Large.", future.cause()); + } + ctx.close(); + } + }); + } else { + ctx.writeAndFlush(TOO_LARGE.retainedDuplicate()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + logger.debug("Failed to send a 413 Request Entity Too Large.", future.cause()); + ctx.close(); + } + } + }); + } + } else if (oversized instanceof HttpResponse) { + ctx.close(); + throw new TooLongHttpContentException("Response entity too large: " + oversized); + } else { + throw new IllegalStateException(); + } + } + + private abstract static class AggregatedFullHttpMessage implements FullHttpMessage { + protected final HttpMessage message; + private final ByteBuf 
content; + private HttpHeaders trailingHeaders; + + AggregatedFullHttpMessage(HttpMessage message, ByteBuf content, HttpHeaders trailingHeaders) { + this.message = message; + this.content = content; + this.trailingHeaders = trailingHeaders; + } + + @Override + public HttpHeaders trailingHeaders() { + HttpHeaders trailingHeaders = this.trailingHeaders; + if (trailingHeaders == null) { + return EmptyHttpHeaders.INSTANCE; + } else { + return trailingHeaders; + } + } + + void setTrailingHeaders(HttpHeaders trailingHeaders) { + this.trailingHeaders = trailingHeaders; + } + + @Override + public HttpVersion getProtocolVersion() { + return message.protocolVersion(); + } + + @Override + public HttpVersion protocolVersion() { + return message.protocolVersion(); + } + + @Override + public FullHttpMessage setProtocolVersion(HttpVersion version) { + message.setProtocolVersion(version); + return this; + } + + @Override + public HttpHeaders headers() { + return message.headers(); + } + + @Override + public DecoderResult decoderResult() { + return message.decoderResult(); + } + + @Override + public DecoderResult getDecoderResult() { + return message.decoderResult(); + } + + @Override + public void setDecoderResult(DecoderResult result) { + message.setDecoderResult(result); + } + + @Override + public ByteBuf content() { + return content; + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public FullHttpMessage retain() { + content.retain(); + return this; + } + + @Override + public FullHttpMessage retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public FullHttpMessage touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public FullHttpMessage touch() { + content.touch(); + return this; + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public 
abstract FullHttpMessage copy(); + + @Override + public abstract FullHttpMessage duplicate(); + + @Override + public abstract FullHttpMessage retainedDuplicate(); + } + + private static final class AggregatedFullHttpRequest extends AggregatedFullHttpMessage implements FullHttpRequest { + + AggregatedFullHttpRequest(HttpRequest request, ByteBuf content, HttpHeaders trailingHeaders) { + super(request, content, trailingHeaders); + } + + @Override + public FullHttpRequest copy() { + return replace(content().copy()); + } + + @Override + public FullHttpRequest duplicate() { + return replace(content().duplicate()); + } + + @Override + public FullHttpRequest retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public FullHttpRequest replace(ByteBuf content) { + DefaultFullHttpRequest dup = new DefaultFullHttpRequest(protocolVersion(), method(), uri(), content, + headers().copy(), trailingHeaders().copy()); + dup.setDecoderResult(decoderResult()); + return dup; + } + + @Override + public FullHttpRequest retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public FullHttpRequest retain() { + super.retain(); + return this; + } + + @Override + public FullHttpRequest touch() { + super.touch(); + return this; + } + + @Override + public FullHttpRequest touch(Object hint) { + super.touch(hint); + return this; + } + + @Override + public FullHttpRequest setMethod(HttpMethod method) { + ((HttpRequest) message).setMethod(method); + return this; + } + + @Override + public FullHttpRequest setUri(String uri) { + ((HttpRequest) message).setUri(uri); + return this; + } + + @Override + public HttpMethod getMethod() { + return ((HttpRequest) message).method(); + } + + @Override + public String getUri() { + return ((HttpRequest) message).uri(); + } + + @Override + public HttpMethod method() { + return getMethod(); + } + + @Override + public String uri() { + return getUri(); + } + + @Override + public FullHttpRequest 
setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public String toString() { + return HttpMessageUtil.appendFullRequest(new StringBuilder(256), this).toString(); + } + } + + private static final class AggregatedFullHttpResponse extends AggregatedFullHttpMessage + implements FullHttpResponse { + + AggregatedFullHttpResponse(HttpResponse message, ByteBuf content, HttpHeaders trailingHeaders) { + super(message, content, trailingHeaders); + } + + @Override + public FullHttpResponse copy() { + return replace(content().copy()); + } + + @Override + public FullHttpResponse duplicate() { + return replace(content().duplicate()); + } + + @Override + public FullHttpResponse retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public FullHttpResponse replace(ByteBuf content) { + DefaultFullHttpResponse dup = new DefaultFullHttpResponse(getProtocolVersion(), getStatus(), content, + headers().copy(), trailingHeaders().copy()); + dup.setDecoderResult(decoderResult()); + return dup; + } + + @Override + public FullHttpResponse setStatus(HttpResponseStatus status) { + ((HttpResponse) message).setStatus(status); + return this; + } + + @Override + public HttpResponseStatus getStatus() { + return ((HttpResponse) message).status(); + } + + @Override + public HttpResponseStatus status() { + return getStatus(); + } + + @Override + public FullHttpResponse setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public FullHttpResponse retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public FullHttpResponse retain() { + super.retain(); + return this; + } + + @Override + public FullHttpResponse touch(Object hint) { + super.touch(hint); + return this; + } + + @Override + public FullHttpResponse touch() { + super.touch(); + return this; + } + + @Override + public String toString() { + return 
HttpMessageUtil.appendFullResponse(new StringBuilder(256), this).toString(); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java new file mode 100644 index 0000000..18dfde4 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectDecoder.java @@ -0,0 +1,1233 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.PrematureChannelClosureException; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.AsciiString; +import io.netty.util.ByteProcessor; +import io.netty.util.internal.StringUtil; + +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Decodes {@link ByteBuf}s into {@link HttpMessage}s and + * {@link HttpContent}s. + * + *

Parameters that prevents excessive memory consumption

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameDefault valueMeaning
{@code maxInitialLineLength}{@value #DEFAULT_MAX_INITIAL_LINE_LENGTH}The maximum length of the initial line + * (e.g. {@code "GET / HTTP/1.0"} or {@code "HTTP/1.0 200 OK"}) + * If the length of the initial line exceeds this value, a + * {@link TooLongHttpLineException} will be raised.
{@code maxHeaderSize}{@value #DEFAULT_MAX_HEADER_SIZE}The maximum length of all headers. If the sum of the length of each + * header exceeds this value, a {@link TooLongHttpHeaderException} will be raised.
{@code maxChunkSize}{@value #DEFAULT_MAX_CHUNK_SIZE}The maximum length of the content or each chunk. If the content length + * (or the length of each chunk) exceeds this value, the content or chunk + * will be split into multiple {@link HttpContent}s whose length is + * {@code maxChunkSize} at maximum.
+ * + *

Parameters that control parsing behavior

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameDefault valueMeaning
{@code allowDuplicateContentLengths}{@value #DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS}When set to {@code false}, will reject any messages that contain multiple Content-Length header fields. + * When set to {@code true}, will allow multiple Content-Length headers only if they are all the same decimal value. + * The duplicated field-values will be replaced with a single valid Content-Length field. + * See RFC 7230, Section 3.3.2.
{@code allowPartialChunks}{@value #DEFAULT_ALLOW_PARTIAL_CHUNKS}If the length of a chunk exceeds the {@link ByteBuf}s readable bytes and {@code allowPartialChunks} + * is set to {@code true}, the chunk will be split into multiple {@link HttpContent}s. + * Otherwise, if the chunk size does not exceed {@code maxChunkSize} and {@code allowPartialChunks} + * is set to {@code false}, the {@link ByteBuf} is not decoded into an {@link HttpContent} until + * the readable bytes are greater or equal to the chunk size.
+ * + *

Chunked Content

+ * + * If the content of an HTTP message is greater than {@code maxChunkSize} or + * the transfer encoding of the HTTP message is 'chunked', this decoder + * generates one {@link HttpMessage} instance and its following + * {@link HttpContent}s per single HTTP message to avoid excessive memory + * consumption. For example, the following HTTP message: + *
+ * GET / HTTP/1.1
+ * Transfer-Encoding: chunked
+ *
+ * 1a
+ * abcdefghijklmnopqrstuvwxyz
+ * 10
+ * 1234567890abcdef
+ * 0
+ * Content-MD5: ...
+ * [blank line]
+ * 
+ * triggers {@link HttpRequestDecoder} to generate 3 objects: + *
    + *
  1. An {@link HttpRequest},
  2. + *
  3. The first {@link HttpContent} whose content is {@code 'abcdefghijklmnopqrstuvwxyz'},
  4. + *
  5. The second {@link LastHttpContent} whose content is {@code '1234567890abcdef'}, which marks + * the end of the content.
  6. + *
+ * + * If you prefer not to handle {@link HttpContent}s by yourself for your + * convenience, insert {@link HttpObjectAggregator} after this decoder in the + * {@link ChannelPipeline}. However, please note that your server might not + * be as memory efficient as without the aggregator. + * + *

Extensibility

+ * + * Please note that this decoder is designed to be extended to implement + * a protocol derived from HTTP, such as + * RTSP and + * ICAP. + * To implement the decoder of such a derived protocol, extend this class and + * implement all abstract methods properly. + * + *

Header Validation

+ * + * It is recommended to always enable header validation. + *

+ * Without header validation, your system can become vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + *

+ * This recommendation stands even when both peers in the HTTP exchange are trusted, + * as it helps with defence-in-depth. + */ +public abstract class HttpObjectDecoder extends ByteToMessageDecoder { + public static final int DEFAULT_MAX_INITIAL_LINE_LENGTH = 4096; + public static final int DEFAULT_MAX_HEADER_SIZE = 8192; + public static final boolean DEFAULT_CHUNKED_SUPPORTED = true; + public static final boolean DEFAULT_ALLOW_PARTIAL_CHUNKS = true; + public static final int DEFAULT_MAX_CHUNK_SIZE = 8192; + public static final boolean DEFAULT_VALIDATE_HEADERS = true; + public static final int DEFAULT_INITIAL_BUFFER_SIZE = 128; + public static final boolean DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS = false; + private final int maxChunkSize; + private final boolean chunkedSupported; + private final boolean allowPartialChunks; + /** + * This field is no longer used. It is only kept around for backwards compatibility purpose. + */ + @Deprecated + protected final boolean validateHeaders; + protected final HttpHeadersFactory headersFactory; + protected final HttpHeadersFactory trailersFactory; + private final boolean allowDuplicateContentLengths; + private final ByteBuf parserScratchBuffer; + private final HeaderParser headerParser; + private final LineParser lineParser; + + private HttpMessage message; + private long chunkSize; + private long contentLength = Long.MIN_VALUE; + private final AtomicBoolean resetRequested = new AtomicBoolean(); + + // These will be updated by splitHeader(...) + private AsciiString name; + private String value; + private LastHttpContent trailer; + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + try { + parserScratchBuffer.release(); + } finally { + super.handlerRemoved0(ctx); + } + } + + /** + * The internal state of {@link HttpObjectDecoder}. + * Internal use only. 
+ */ + private enum State { + SKIP_CONTROL_CHARS, + READ_INITIAL, + READ_HEADER, + READ_VARIABLE_LENGTH_CONTENT, + READ_FIXED_LENGTH_CONTENT, + READ_CHUNK_SIZE, + READ_CHUNKED_CONTENT, + READ_CHUNK_DELIMITER, + READ_CHUNK_FOOTER, + BAD_MESSAGE, + UPGRADED + } + + private State currentState = State.SKIP_CONTROL_CHARS; + + /** + * Creates a new instance with the default + * {@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxChunkSize (8192)}. + */ + protected HttpObjectDecoder() { + this(new HttpDecoderConfig()); + } + + /** + * Creates a new instance with the specified parameters. + * + * @deprecated Use {@link #HttpObjectDecoder(HttpDecoderConfig)} instead. + */ + @Deprecated + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean chunkedSupported) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setChunkedSupported(chunkedSupported)); + } + + /** + * Creates a new instance with the specified parameters. + * + * @deprecated Use {@link #HttpObjectDecoder(HttpDecoderConfig)} instead. + */ + @Deprecated + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + boolean chunkedSupported, boolean validateHeaders) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setChunkedSupported(chunkedSupported) + .setValidateHeaders(validateHeaders)); + } + + /** + * Creates a new instance with the specified parameters. + * + * @deprecated Use {@link #HttpObjectDecoder(HttpDecoderConfig)} instead. 
+ */ + @Deprecated + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + boolean chunkedSupported, boolean validateHeaders, int initialBufferSize) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setChunkedSupported(chunkedSupported) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize)); + } + + /** + * Creates a new instance with the specified parameters. + * + * @deprecated Use {@link #HttpObjectDecoder(HttpDecoderConfig)} instead. + */ + @Deprecated + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + boolean chunkedSupported, boolean validateHeaders, int initialBufferSize, + boolean allowDuplicateContentLengths) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setChunkedSupported(chunkedSupported) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize) + .setAllowDuplicateContentLengths(allowDuplicateContentLengths)); + } + + /** + * Creates a new instance with the specified parameters. + * + * @deprecated Use {@link #HttpObjectDecoder(HttpDecoderConfig)} instead. 
+ */ + @Deprecated + protected HttpObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, + boolean chunkedSupported, boolean validateHeaders, int initialBufferSize, + boolean allowDuplicateContentLengths, boolean allowPartialChunks) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setChunkedSupported(chunkedSupported) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize) + .setAllowDuplicateContentLengths(allowDuplicateContentLengths) + .setAllowPartialChunks(allowPartialChunks)); + } + + /** + * Creates a new instance with the specified configuration. + */ + protected HttpObjectDecoder(HttpDecoderConfig config) { + checkNotNull(config, "config"); + + parserScratchBuffer = Unpooled.buffer(config.getInitialBufferSize()); + lineParser = new LineParser(parserScratchBuffer, config.getMaxInitialLineLength()); + headerParser = new HeaderParser(parserScratchBuffer, config.getMaxHeaderSize()); + maxChunkSize = config.getMaxChunkSize(); + chunkedSupported = config.isChunkedSupported(); + headersFactory = config.getHeadersFactory(); + trailersFactory = config.getTrailersFactory(); + validateHeaders = isValidating(headersFactory); + allowDuplicateContentLengths = config.isAllowDuplicateContentLengths(); + allowPartialChunks = config.isAllowPartialChunks(); + } + + protected boolean isValidating(HttpHeadersFactory headersFactory) { + if (headersFactory instanceof DefaultHttpHeadersFactory) { + DefaultHttpHeadersFactory builder = (DefaultHttpHeadersFactory) headersFactory; + return builder.isValidatingHeaderNames() || builder.isValidatingHeaderValues(); + } + return true; // We can't actually tell in this case, but we assume some validation is taking place. 
+ } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List out) throws Exception { + if (resetRequested.get()) { + resetNow(); + } + + switch (currentState) { + case SKIP_CONTROL_CHARS: + // Fall-through + case READ_INITIAL: try { + ByteBuf line = lineParser.parse(buffer); + if (line == null) { + return; + } + final String[] initialLine = splitInitialLine(line); + assert initialLine.length == 3 : "initialLine::length must be 3"; + + message = createMessage(initialLine); + currentState = State.READ_HEADER; + // fall-through + } catch (Exception e) { + out.add(invalidMessage(buffer, e)); + return; + } + case READ_HEADER: try { + State nextState = readHeaders(buffer); + if (nextState == null) { + return; + } + currentState = nextState; + switch (nextState) { + case SKIP_CONTROL_CHARS: + // fast-path + // No content is expected. + out.add(message); + out.add(LastHttpContent.EMPTY_LAST_CONTENT); + resetNow(); + return; + case READ_CHUNK_SIZE: + if (!chunkedSupported) { + throw new IllegalArgumentException("Chunked messages not supported"); + } + // Chunked encoding - generate HttpMessage first. HttpChunks will follow. + out.add(message); + return; + default: + /* + * RFC 7230, 3.3.3 (https://tools.ietf.org/html/rfc7230#section-3.3.3) states that if a + * request does not have either a transfer-encoding or a content-length header then the message body + * length is 0. However, for a response the body length is the number of octets received prior to the + * server closing the connection. So we treat this as variable length chunked encoding. 
+ */ + long contentLength = contentLength(); + if (contentLength == 0 || contentLength == -1 && isDecodingRequest()) { + out.add(message); + out.add(LastHttpContent.EMPTY_LAST_CONTENT); + resetNow(); + return; + } + + assert nextState == State.READ_FIXED_LENGTH_CONTENT || + nextState == State.READ_VARIABLE_LENGTH_CONTENT; + + out.add(message); + + if (nextState == State.READ_FIXED_LENGTH_CONTENT) { + // chunkSize will be decreased as the READ_FIXED_LENGTH_CONTENT state reads data chunk by chunk. + chunkSize = contentLength; + } + + // We return here, this forces decode to be called again where we will decode the content + return; + } + } catch (Exception e) { + out.add(invalidMessage(buffer, e)); + return; + } + case READ_VARIABLE_LENGTH_CONTENT: { + // Keep reading data as a chunk until the end of connection is reached. + int toRead = Math.min(buffer.readableBytes(), maxChunkSize); + if (toRead > 0) { + ByteBuf content = buffer.readRetainedSlice(toRead); + out.add(new DefaultHttpContent(content)); + } + return; + } + case READ_FIXED_LENGTH_CONTENT: { + int readLimit = buffer.readableBytes(); + + // Check if the buffer is readable first as we use the readable byte count + // to create the HttpChunk. This is needed as otherwise we may end up with + // create an HttpChunk instance that contains an empty buffer and so is + // handled like it is the last HttpChunk. + // + // See https://github.com/netty/netty/issues/433 + if (readLimit == 0) { + return; + } + + int toRead = Math.min(readLimit, maxChunkSize); + if (toRead > chunkSize) { + toRead = (int) chunkSize; + } + ByteBuf content = buffer.readRetainedSlice(toRead); + chunkSize -= toRead; + + if (chunkSize == 0) { + // Read all content. + out.add(new DefaultLastHttpContent(content, trailersFactory)); + resetNow(); + } else { + out.add(new DefaultHttpContent(content)); + } + return; + } + /* + * everything else after this point takes care of reading chunked content. 
basically, read chunk size, + * read chunk, read and ignore the CRLF and repeat until 0 + */ + case READ_CHUNK_SIZE: try { + ByteBuf line = lineParser.parse(buffer); + if (line == null) { + return; + } + int chunkSize = getChunkSize(line.array(), line.arrayOffset() + line.readerIndex(), line.readableBytes()); + this.chunkSize = chunkSize; + if (chunkSize == 0) { + currentState = State.READ_CHUNK_FOOTER; + return; + } + currentState = State.READ_CHUNKED_CONTENT; + // fall-through + } catch (Exception e) { + out.add(invalidChunk(buffer, e)); + return; + } + case READ_CHUNKED_CONTENT: { + assert chunkSize <= Integer.MAX_VALUE; + int toRead = Math.min((int) chunkSize, maxChunkSize); + if (!allowPartialChunks && buffer.readableBytes() < toRead) { + return; + } + toRead = Math.min(toRead, buffer.readableBytes()); + if (toRead == 0) { + return; + } + HttpContent chunk = new DefaultHttpContent(buffer.readRetainedSlice(toRead)); + chunkSize -= toRead; + + out.add(chunk); + + if (chunkSize != 0) { + return; + } + currentState = State.READ_CHUNK_DELIMITER; + // fall-through + } + case READ_CHUNK_DELIMITER: { + final int wIdx = buffer.writerIndex(); + int rIdx = buffer.readerIndex(); + while (wIdx > rIdx) { + byte next = buffer.getByte(rIdx++); + if (next == HttpConstants.LF) { + currentState = State.READ_CHUNK_SIZE; + break; + } + } + buffer.readerIndex(rIdx); + return; + } + case READ_CHUNK_FOOTER: try { + LastHttpContent trailer = readTrailingHeaders(buffer); + if (trailer == null) { + return; + } + out.add(trailer); + resetNow(); + return; + } catch (Exception e) { + out.add(invalidChunk(buffer, e)); + return; + } + case BAD_MESSAGE: { + // Keep discarding until disconnection. 
+ buffer.skipBytes(buffer.readableBytes()); + break; + } + case UPGRADED: { + int readableBytes = buffer.readableBytes(); + if (readableBytes > 0) { + // Keep on consuming as otherwise we may trigger an DecoderException, + // other handler will replace this codec with the upgraded protocol codec to + // take the traffic over at some point then. + // See https://github.com/netty/netty/issues/2173 + out.add(buffer.readBytes(readableBytes)); + } + break; + } + default: + break; + } + } + + @Override + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + super.decodeLast(ctx, in, out); + + if (resetRequested.get()) { + // If a reset was requested by decodeLast() we need to do it now otherwise we may produce a + // LastHttpContent while there was already one. + resetNow(); + } + // Handle the last unfinished message. + if (message != null) { + boolean chunked = HttpUtil.isTransferEncodingChunked(message); + if (currentState == State.READ_VARIABLE_LENGTH_CONTENT && !in.isReadable() && !chunked) { + // End of connection. + out.add(LastHttpContent.EMPTY_LAST_CONTENT); + resetNow(); + return; + } + + if (currentState == State.READ_HEADER) { + // If we are still in the state of reading headers we need to create a new invalid message that + // signals that the connection was closed before we received the headers. + out.add(invalidMessage(Unpooled.EMPTY_BUFFER, + new PrematureChannelClosureException("Connection closed before received headers"))); + resetNow(); + return; + } + + // Check if the closure of the connection signifies the end of the content. + boolean prematureClosure; + if (isDecodingRequest() || chunked) { + // The last request did not wait for a response. + prematureClosure = true; + } else { + // Compare the length of the received content and the 'Content-Length' header. + // If the 'Content-Length' header is absent, the length of the content is determined by the end of the + // connection, so it is perfectly fine. 
+ prematureClosure = contentLength() > 0; + } + + if (!prematureClosure) { + out.add(LastHttpContent.EMPTY_LAST_CONTENT); + } + resetNow(); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof HttpExpectationFailedEvent) { + switch (currentState) { + case READ_FIXED_LENGTH_CONTENT: + case READ_VARIABLE_LENGTH_CONTENT: + case READ_CHUNK_SIZE: + reset(); + break; + default: + break; + } + } + super.userEventTriggered(ctx, evt); + } + + protected boolean isContentAlwaysEmpty(HttpMessage msg) { + if (msg instanceof HttpResponse) { + HttpResponse res = (HttpResponse) msg; + final HttpResponseStatus status = res.status(); + final int code = status.code(); + final HttpStatusClass statusClass = status.codeClass(); + + // Correctly handle return codes of 1xx. + // + // See: + // - https://www.w3.org/Protocols/rfc2616/rfc2616-sec4.html Section 4.4 + // - https://github.com/netty/netty/issues/222 + if (statusClass == HttpStatusClass.INFORMATIONAL) { + // One exception: Hixie 76 websocket handshake response + return !(code == 101 && !res.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT) + && res.headers().contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)); + } + + switch (code) { + case 204: case 304: + return true; + default: + return false; + } + } + return false; + } + + /** + * Returns true if the server switched to a different protocol than HTTP/1.0 or HTTP/1.1, e.g. HTTP/2 or Websocket. + * Returns false if the upgrade happened in a different layer, e.g. upgrade from HTTP/1.1 to HTTP/1.1 over TLS. 
+ */ + protected boolean isSwitchingToNonHttp1Protocol(HttpResponse msg) { + if (msg.status().code() != HttpResponseStatus.SWITCHING_PROTOCOLS.code()) { + return false; + } + String newProtocol = msg.headers().get(HttpHeaderNames.UPGRADE); + return newProtocol == null || + !newProtocol.contains(HttpVersion.HTTP_1_0.text()) && + !newProtocol.contains(HttpVersion.HTTP_1_1.text()); + } + + /** + * Resets the state of the decoder so that it is ready to decode a new message. + * This method is useful for handling a rejected request with {@code Expect: 100-continue} header. + */ + public void reset() { + resetRequested.lazySet(true); + } + + private void resetNow() { + HttpMessage message = this.message; + this.message = null; + name = null; + value = null; + contentLength = Long.MIN_VALUE; + lineParser.reset(); + headerParser.reset(); + trailer = null; + if (!isDecodingRequest()) { + HttpResponse res = (HttpResponse) message; + if (res != null && isSwitchingToNonHttp1Protocol(res)) { + currentState = State.UPGRADED; + return; + } + } + + resetRequested.lazySet(false); + currentState = State.SKIP_CONTROL_CHARS; + } + + private HttpMessage invalidMessage(ByteBuf in, Exception cause) { + currentState = State.BAD_MESSAGE; + + // Advance the readerIndex so that ByteToMessageDecoder does not complain + // when we produced an invalid message without consuming anything. + in.skipBytes(in.readableBytes()); + + if (message == null) { + message = createInvalidMessage(); + } + message.setDecoderResult(DecoderResult.failure(cause)); + + HttpMessage ret = message; + message = null; + return ret; + } + + private HttpContent invalidChunk(ByteBuf in, Exception cause) { + currentState = State.BAD_MESSAGE; + + // Advance the readerIndex so that ByteToMessageDecoder does not complain + // when we produced an invalid message without consuming anything. 
+ in.skipBytes(in.readableBytes()); + + HttpContent chunk = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER); + chunk.setDecoderResult(DecoderResult.failure(cause)); + message = null; + trailer = null; + return chunk; + } + + private State readHeaders(ByteBuf buffer) { + final HttpMessage message = this.message; + final HttpHeaders headers = message.headers(); + + final HeaderParser headerParser = this.headerParser; + + ByteBuf line = headerParser.parse(buffer); + if (line == null) { + return null; + } + int lineLength = line.readableBytes(); + while (lineLength > 0) { + final byte[] lineContent = line.array(); + final int startLine = line.arrayOffset() + line.readerIndex(); + final byte firstChar = lineContent[startLine]; + if (name != null && (firstChar == ' ' || firstChar == '\t')) { + //please do not make one line from below code + //as it breaks +XX:OptimizeStringConcat optimization + String trimmedLine = langAsciiString(lineContent, startLine, lineLength).trim(); + String valueStr = value; + value = valueStr + ' ' + trimmedLine; + } else { + if (name != null) { + headers.add(name, value); + } + splitHeader(lineContent, startLine, lineLength); + } + + line = headerParser.parse(buffer); + if (line == null) { + return null; + } + lineLength = line.readableBytes(); + } + + // Add the last header. + if (name != null) { + headers.add(name, value); + } + + // reset name and value fields + name = null; + value = null; + + // Done parsing initial line and headers. Set decoder result. 
+ HttpMessageDecoderResult decoderResult = new HttpMessageDecoderResult(lineParser.size, headerParser.size); + message.setDecoderResult(decoderResult); + + List contentLengthFields = headers.getAll(HttpHeaderNames.CONTENT_LENGTH); + if (!contentLengthFields.isEmpty()) { + HttpVersion version = message.protocolVersion(); + boolean isHttp10OrEarlier = version.majorVersion() < 1 || (version.majorVersion() == 1 + && version.minorVersion() == 0); + // Guard against multiple Content-Length headers as stated in + // https://tools.ietf.org/html/rfc7230#section-3.3.2: + contentLength = HttpUtil.normalizeAndGetContentLength(contentLengthFields, + isHttp10OrEarlier, allowDuplicateContentLengths); + if (contentLength != -1) { + String lengthValue = contentLengthFields.get(0).trim(); + if (contentLengthFields.size() > 1 || // don't unnecessarily re-order headers + !lengthValue.equals(Long.toString(contentLength))) { + headers.set(HttpHeaderNames.CONTENT_LENGTH, contentLength); + } + } + } + + if (isContentAlwaysEmpty(message)) { + HttpUtil.setTransferEncodingChunked(message, false); + return State.SKIP_CONTROL_CHARS; + } else if (HttpUtil.isTransferEncodingChunked(message)) { + if (!contentLengthFields.isEmpty() && message.protocolVersion() == HttpVersion.HTTP_1_1) { + handleTransferEncodingChunkedWithContentLength(message); + } + return State.READ_CHUNK_SIZE; + } else if (contentLength() >= 0) { + return State.READ_FIXED_LENGTH_CONTENT; + } else { + return State.READ_VARIABLE_LENGTH_CONTENT; + } + } + + /** + * Invoked when a message with both a "Transfer-Encoding: chunked" and a "Content-Length" header field is detected. + * The default behavior is to remove the Content-Length field, but this method could be overridden + * to change the behavior (to, e.g., throw an exception and produce an invalid message). + *

+ * See: https://tools.ietf.org/html/rfc7230#section-3.3.3 + *

+     *     If a message is received with both a Transfer-Encoding and a
+     *     Content-Length header field, the Transfer-Encoding overrides the
+     *     Content-Length.  Such a message might indicate an attempt to
+     *     perform request smuggling (Section 9.5) or response splitting
+     *     (Section 9.4) and ought to be handled as an error.  A sender MUST
+     *     remove the received Content-Length field prior to forwarding such
+     *     a message downstream.
+     * 
+ * Also see: + * https://github.com/apache/tomcat/blob/b693d7c1981fa7f51e58bc8c8e72e3fe80b7b773/ + * java/org/apache/coyote/http11/Http11Processor.java#L747-L755 + * https://github.com/nginx/nginx/blob/0ad4393e30c119d250415cb769e3d8bc8dce5186/ + * src/http/ngx_http_request.c#L1946-L1953 + */ + protected void handleTransferEncodingChunkedWithContentLength(HttpMessage message) { + message.headers().remove(HttpHeaderNames.CONTENT_LENGTH); + contentLength = Long.MIN_VALUE; + } + + private long contentLength() { + if (contentLength == Long.MIN_VALUE) { + contentLength = HttpUtil.getContentLength(message, -1L); + } + return contentLength; + } + + private LastHttpContent readTrailingHeaders(ByteBuf buffer) { + final HeaderParser headerParser = this.headerParser; + ByteBuf line = headerParser.parse(buffer); + if (line == null) { + return null; + } + LastHttpContent trailer = this.trailer; + int lineLength = line.readableBytes(); + if (lineLength == 0 && trailer == null) { + // We have received the empty line which signals the trailer is complete and did not parse any trailers + // before. Just return an empty last content to reduce allocations. 
+ return LastHttpContent.EMPTY_LAST_CONTENT; + } + + CharSequence lastHeader = null; + if (trailer == null) { + trailer = this.trailer = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, trailersFactory); + } + while (lineLength > 0) { + final byte[] lineContent = line.array(); + final int startLine = line.arrayOffset() + line.readerIndex(); + final byte firstChar = lineContent[startLine]; + if (lastHeader != null && (firstChar == ' ' || firstChar == '\t')) { + List current = trailer.trailingHeaders().getAll(lastHeader); + if (!current.isEmpty()) { + int lastPos = current.size() - 1; + //please do not make one line from below code + //as it breaks +XX:OptimizeStringConcat optimization + String lineTrimmed = langAsciiString(lineContent, startLine, line.readableBytes()).trim(); + String currentLastPos = current.get(lastPos); + current.set(lastPos, currentLastPos + lineTrimmed); + } + } else { + splitHeader(lineContent, startLine, lineLength); + AsciiString headerName = name; + if (!HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(headerName) && + !HttpHeaderNames.TRANSFER_ENCODING.contentEqualsIgnoreCase(headerName) && + !HttpHeaderNames.TRAILER.contentEqualsIgnoreCase(headerName)) { + trailer.trailingHeaders().add(headerName, value); + } + lastHeader = name; + // reset name and value fields + name = null; + value = null; + } + line = headerParser.parse(buffer); + if (line == null) { + return null; + } + lineLength = line.readableBytes(); + } + + this.trailer = null; + return trailer; + } + + protected abstract boolean isDecodingRequest(); + protected abstract HttpMessage createMessage(String[] initialLine) throws Exception; + protected abstract HttpMessage createInvalidMessage(); + + /** + * It skips any whitespace char and return the number of skipped bytes. 
+ */ + private static int skipWhiteSpaces(byte[] hex, int start, int length) { + for (int i = 0; i < length; i++) { + if (!isWhitespace(hex[start + i])) { + return i; + } + } + return length; + } + + private static int getChunkSize(byte[] hex, int start, int length) { + // trim the leading bytes if white spaces, if any + final int skipped = skipWhiteSpaces(hex, start, length); + if (skipped == length) { + // empty case + throw new NumberFormatException(); + } + start += skipped; + length -= skipped; + int result = 0; + for (int i = 0; i < length; i++) { + final int digit = StringUtil.decodeHexNibble(hex[start + i]); + if (digit == -1) { + // uncommon path + final byte b = hex[start + i]; + if (b == ';' || isControlOrWhitespaceAsciiChar(b)) { + if (i == 0) { + // empty case + throw new NumberFormatException(); + } + return result; + } + // non-hex char fail-fast path + throw new NumberFormatException(); + } + result *= 16; + result += digit; + } + return result; + } + + private String[] splitInitialLine(ByteBuf asciiBuffer) { + final byte[] asciiBytes = asciiBuffer.array(); + + final int arrayOffset = asciiBuffer.arrayOffset(); + + final int startContent = arrayOffset + asciiBuffer.readerIndex(); + + final int end = startContent + asciiBuffer.readableBytes(); + + final int aStart = findNonSPLenient(asciiBytes, startContent, end); + final int aEnd = findSPLenient(asciiBytes, aStart, end); + + final int bStart = findNonSPLenient(asciiBytes, aEnd, end); + final int bEnd = findSPLenient(asciiBytes, bStart, end); + + final int cStart = findNonSPLenient(asciiBytes, bEnd, end); + final int cEnd = findEndOfString(asciiBytes, Math.max(cStart - 1, startContent), end); + + return new String[]{ + splitFirstWordInitialLine(asciiBytes, aStart, aEnd - aStart), + splitSecondWordInitialLine(asciiBytes, bStart, bEnd - bStart), + cStart < cEnd ? 
splitThirdWordInitialLine(asciiBytes, cStart, cEnd - cStart) : StringUtil.EMPTY_STRING}; + } + + protected String splitFirstWordInitialLine(final byte[] asciiContent, int start, int length) { + return langAsciiString(asciiContent, start, length); + } + + protected String splitSecondWordInitialLine(final byte[] asciiContent, int start, int length) { + return langAsciiString(asciiContent, start, length); + } + + protected String splitThirdWordInitialLine(final byte[] asciiContent, int start, int length) { + return langAsciiString(asciiContent, start, length); + } + + /** + * This method shouldn't exist: look at https://bugs.openjdk.org/browse/JDK-8295496 for more context + */ + private static String langAsciiString(final byte[] asciiContent, int start, int length) { + if (length == 0) { + return StringUtil.EMPTY_STRING; + } + // DON'T REMOVE: it helps JIT to use a simpler intrinsic stub for System::arrayCopy based on the call-site + if (start == 0) { + if (length == asciiContent.length) { + return new String(asciiContent, 0, 0, asciiContent.length); + } + return new String(asciiContent, 0, 0, length); + } + return new String(asciiContent, 0, start, length); + } + + private void splitHeader(byte[] line, int start, int length) { + final int end = start + length; + int nameEnd; + final int nameStart = findNonWhitespace(line, start, end); + // hoist this load out of the loop, because it won't change! + final boolean isDecodingRequest = isDecodingRequest(); + for (nameEnd = nameStart; nameEnd < end; nameEnd ++) { + byte ch = line[nameEnd]; + // https://tools.ietf.org/html/rfc7230#section-3.2.4 + // + // No whitespace is allowed between the header field-name and colon. In + // the past, differences in the handling of such whitespace have led to + // security vulnerabilities in request routing and response handling. 
A + // server MUST reject any received request message that contains + // whitespace between a header field-name and colon with a response code + // of 400 (Bad Request). A proxy MUST remove any such whitespace from a + // response message before forwarding the message downstream. + if (ch == ':' || + // In case of decoding a request we will just continue processing and header validation + // is done in the DefaultHttpHeaders implementation. + // + // In the case of decoding a response we will "skip" the whitespace. + (!isDecodingRequest && isOWS(ch))) { + break; + } + } + + if (nameEnd == end) { + // There was no colon present at all. + throw new IllegalArgumentException("No colon found"); + } + int colonEnd; + for (colonEnd = nameEnd; colonEnd < end; colonEnd ++) { + if (line[colonEnd] == ':') { + colonEnd ++; + break; + } + } + name = splitHeaderName(line, nameStart, nameEnd - nameStart); + final int valueStart = findNonWhitespace(line, colonEnd, end); + if (valueStart == end) { + value = StringUtil.EMPTY_STRING; + } else { + final int valueEnd = findEndOfString(line, start, end); + // no need to make uses of the ByteBuf's toString ASCII method here, and risk to get JIT confused + value = langAsciiString(line, valueStart, valueEnd - valueStart); + } + } + + protected AsciiString splitHeaderName(byte[] sb, int start, int length) { + return new AsciiString(sb, start, length, true); + } + + private static int findNonSPLenient(byte[] sb, int offset, int end) { + for (int result = offset; result < end; ++result) { + byte c = sb[result]; + // See https://tools.ietf.org/html/rfc7230#section-3.5 + if (isSPLenient(c)) { + continue; + } + if (isWhitespace(c)) { + // Any other whitespace delimiter is invalid + throw new IllegalArgumentException("Invalid separator"); + } + return result; + } + return end; + } + + private static int findSPLenient(byte[] sb, int offset, int end) { + for (int result = offset; result < end; ++result) { + if (isSPLenient(sb[result])) { + return 
result; + } + } + return end; + } + + private static final boolean[] SP_LENIENT_BYTES; + private static final boolean[] LATIN_WHITESPACE; + + static { + // See https://tools.ietf.org/html/rfc7230#section-3.5 + SP_LENIENT_BYTES = new boolean[256]; + SP_LENIENT_BYTES[128 + ' '] = true; + SP_LENIENT_BYTES[128 + 0x09] = true; + SP_LENIENT_BYTES[128 + 0x0B] = true; + SP_LENIENT_BYTES[128 + 0x0C] = true; + SP_LENIENT_BYTES[128 + 0x0D] = true; + // TO SAVE PERFORMING Character::isWhitespace ceremony + LATIN_WHITESPACE = new boolean[256]; + for (byte b = Byte.MIN_VALUE; b < Byte.MAX_VALUE; b++) { + LATIN_WHITESPACE[128 + b] = Character.isWhitespace(b); + } + } + + private static boolean isSPLenient(byte c) { + // See https://tools.ietf.org/html/rfc7230#section-3.5 + return SP_LENIENT_BYTES[c + 128]; + } + + private static boolean isWhitespace(byte b) { + return LATIN_WHITESPACE[b + 128]; + } + + private static int findNonWhitespace(byte[] sb, int offset, int end) { + for (int result = offset; result < end; ++result) { + byte c = sb[result]; + if (!isWhitespace(c)) { + return result; + } else if (!isOWS(c)) { + // Only OWS is supported for whitespace + throw new IllegalArgumentException("Invalid separator, only a single space or horizontal tab allowed," + + " but received a '" + c + "' (0x" + Integer.toHexString(c) + ")"); + } + } + return end; + } + + private static int findEndOfString(byte[] sb, int start, int end) { + for (int result = end - 1; result > start; --result) { + if (!isWhitespace(sb[result])) { + return result + 1; + } + } + return 0; + } + + private static boolean isOWS(byte ch) { + return ch == ' ' || ch == 0x09; + } + + private static class HeaderParser { + protected final ByteBuf seq; + protected final int maxLength; + int size; + + HeaderParser(ByteBuf seq, int maxLength) { + this.seq = seq; + this.maxLength = maxLength; + } + + public ByteBuf parse(ByteBuf buffer) { + final int readableBytes = buffer.readableBytes(); + final int readerIndex = 
buffer.readerIndex(); + final int maxBodySize = maxLength - size; + assert maxBodySize >= 0; + // adding 2 to account for both CR (if present) and LF + // don't remove 2L: it's key to cover maxLength = Integer.MAX_VALUE + final long maxBodySizeWithCRLF = maxBodySize + 2L; + final int toProcess = (int) Math.min(maxBodySizeWithCRLF, readableBytes); + final int toIndexExclusive = readerIndex + toProcess; + assert toIndexExclusive >= readerIndex; + final int indexOfLf = buffer.indexOf(readerIndex, toIndexExclusive, HttpConstants.LF); + if (indexOfLf == -1) { + if (readableBytes > maxBodySize) { + // TODO: Respond with Bad Request and discard the traffic + // or close the connection. + // No need to notify the upstream handlers - just log. + // If decoding a response, just throw an exception. + throw newException(maxLength); + } + return null; + } + final int endOfSeqIncluded; + if (indexOfLf > readerIndex && buffer.getByte(indexOfLf - 1) == HttpConstants.CR) { + // Drop CR if we had a CRLF pair + endOfSeqIncluded = indexOfLf - 1; + } else { + endOfSeqIncluded = indexOfLf; + } + final int newSize = endOfSeqIncluded - readerIndex; + if (newSize == 0) { + seq.clear(); + buffer.readerIndex(indexOfLf + 1); + return seq; + } + int size = this.size + newSize; + if (size > maxLength) { + throw newException(maxLength); + } + this.size = size; + seq.clear(); + seq.writeBytes(buffer, readerIndex, newSize); + buffer.readerIndex(indexOfLf + 1); + return seq; + } + + public void reset() { + size = 0; + } + + protected TooLongFrameException newException(int maxLength) { + return new TooLongHttpHeaderException("HTTP header is larger than " + maxLength + " bytes."); + } + } + + private final class LineParser extends HeaderParser { + + LineParser(ByteBuf seq, int maxLength) { + super(seq, maxLength); + } + + @Override + public ByteBuf parse(ByteBuf buffer) { + // Suppress a warning because HeaderParser.reset() is supposed to be called + reset(); + final int readableBytes = 
buffer.readableBytes(); + if (readableBytes == 0) { + return null; + } + final int readerIndex = buffer.readerIndex(); + if (currentState == State.SKIP_CONTROL_CHARS && skipControlChars(buffer, readableBytes, readerIndex)) { + return null; + } + return super.parse(buffer); + } + + private boolean skipControlChars(ByteBuf buffer, int readableBytes, int readerIndex) { + assert currentState == State.SKIP_CONTROL_CHARS; + final int maxToSkip = Math.min(maxLength, readableBytes); + final int firstNonControlIndex = buffer.forEachByte(readerIndex, maxToSkip, SKIP_CONTROL_CHARS_BYTES); + if (firstNonControlIndex == -1) { + buffer.skipBytes(maxToSkip); + if (readableBytes > maxLength) { + throw newException(maxLength); + } + return true; + } + // from now on we don't care about control chars + buffer.readerIndex(firstNonControlIndex); + currentState = State.READ_INITIAL; + return false; + } + + @Override + protected TooLongFrameException newException(int maxLength) { + return new TooLongHttpLineException("An HTTP line is larger than " + maxLength + " bytes."); + } + } + + private static final boolean[] ISO_CONTROL_OR_WHITESPACE; + + static { + ISO_CONTROL_OR_WHITESPACE = new boolean[256]; + for (byte b = Byte.MIN_VALUE; b < Byte.MAX_VALUE; b++) { + ISO_CONTROL_OR_WHITESPACE[128 + b] = Character.isISOControl(b) || isWhitespace(b); + } + } + + private static final ByteProcessor SKIP_CONTROL_CHARS_BYTES = new ByteProcessor() { + + @Override + public boolean process(byte value) { + return ISO_CONTROL_OR_WHITESPACE[128 + value]; + } + }; + + private static boolean isControlOrWhitespaceAsciiChar(byte b) { + return ISO_CONTROL_OR_WHITESPACE[128 + b]; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java new file mode 100755 index 0000000..1008833 --- /dev/null +++ 
b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpObjectEncoder.java @@ -0,0 +1,597 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.FileRegion; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.PromiseCombiner; +import io.netty.util.internal.StringUtil; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; + +import static io.netty.buffer.Unpooled.directBuffer; +import static io.netty.buffer.Unpooled.unreleasableBuffer; +import static io.netty.handler.codec.http.HttpConstants.CR; +import static io.netty.handler.codec.http.HttpConstants.LF; + +/** + * Encodes an {@link HttpMessage} or an {@link HttpContent} into + * a {@link ByteBuf}. + * + *

Extensibility

+ * + * Please note that this encoder is designed to be extended to implement + * a protocol derived from HTTP, such as + * RTSP and + * ICAP. + * To implement the encoder of such a derived protocol, extend this class and + * implement all abstract methods properly. + */ +public abstract class HttpObjectEncoder extends MessageToMessageEncoder { + static final int CRLF_SHORT = (CR << 8) | LF; + private static final int ZERO_CRLF_MEDIUM = ('0' << 16) | CRLF_SHORT; + private static final byte[] ZERO_CRLF_CRLF = { '0', CR, LF, CR, LF }; + private static final ByteBuf CRLF_BUF = unreleasableBuffer( + directBuffer(2).writeByte(CR).writeByte(LF)).asReadOnly(); + private static final ByteBuf ZERO_CRLF_CRLF_BUF = unreleasableBuffer( + directBuffer(ZERO_CRLF_CRLF.length).writeBytes(ZERO_CRLF_CRLF)).asReadOnly(); + private static final float HEADERS_WEIGHT_NEW = 1 / 5f; + private static final float HEADERS_WEIGHT_HISTORICAL = 1 - HEADERS_WEIGHT_NEW; + private static final float TRAILERS_WEIGHT_NEW = HEADERS_WEIGHT_NEW; + private static final float TRAILERS_WEIGHT_HISTORICAL = HEADERS_WEIGHT_HISTORICAL; + + private static final int ST_INIT = 0; + private static final int ST_CONTENT_NON_CHUNK = 1; + private static final int ST_CONTENT_CHUNK = 2; + private static final int ST_CONTENT_ALWAYS_EMPTY = 3; + + @SuppressWarnings("RedundantFieldInitialization") + private int state = ST_INIT; + + /** + * Used to calculate an exponential moving average of the encoded size of the initial line and the headers for + * a guess for future buffer allocations. + */ + private float headersEncodedSizeAccumulator = 256; + + /** + * Used to calculate an exponential moving average of the encoded size of the trailers for + * a guess for future buffer allocations. 
+ */ + private float trailersEncodedSizeAccumulator = 256; + + private final List out = new ArrayList(); + + private static boolean checkContentState(int state) { + return state == ST_CONTENT_CHUNK || state == ST_CONTENT_NON_CHUNK || state == ST_CONTENT_ALWAYS_EMPTY; + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + try { + if (acceptOutboundMessage(msg)) { + encode(ctx, msg, out); + if (out.isEmpty()) { + throw new EncoderException( + StringUtil.simpleClassName(this) + " must produce at least one message."); + } + } else { + ctx.write(msg, promise); + } + } catch (EncoderException e) { + throw e; + } catch (Throwable t) { + throw new EncoderException(t); + } finally { + writeOutList(ctx, out, promise); + } + } + + private static void writeOutList(ChannelHandlerContext ctx, List out, ChannelPromise promise) { + final int size = out.size(); + try { + if (size == 1) { + ctx.write(out.get(0), promise); + } else if (size > 1) { + // Check if we can use a voidPromise for our extra writes to reduce GC-Pressure + // See https://github.com/netty/netty/issues/2525 + if (promise == ctx.voidPromise()) { + writeVoidPromise(ctx, out); + } else { + writePromiseCombiner(ctx, out, promise); + } + } + } finally { + out.clear(); + } + } + + private static void writeVoidPromise(ChannelHandlerContext ctx, List out) { + final ChannelPromise voidPromise = ctx.voidPromise(); + for (int i = 0; i < out.size(); i++) { + ctx.write(out.get(i), voidPromise); + } + } + + private static void writePromiseCombiner(ChannelHandlerContext ctx, List out, ChannelPromise promise) { + final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); + for (int i = 0; i < out.size(); i++) { + combiner.add(ctx.write(out.get(i))); + } + combiner.finish(promise); + } + + @Override + @SuppressWarnings("ConditionCoveredByFurtherCondition") + protected void encode(ChannelHandlerContext ctx, Object msg, List out) throws Exception { + // 
fast-path for common idiom that doesn't require class-checks + if (msg == Unpooled.EMPTY_BUFFER) { + out.add(Unpooled.EMPTY_BUFFER); + return; + } + // The reason why we perform instanceof checks in this order, + // by duplicating some code and without relying on ReferenceCountUtil::release as a generic release + // mechanism, is https://bugs.openjdk.org/browse/JDK-8180450. + // https://github.com/netty/netty/issues/12708 contains more detail re how the previous version of this + // code was interacting with the JIT instanceof optimizations. + if (msg instanceof FullHttpMessage) { + encodeFullHttpMessage(ctx, msg, out); + return; + } + if (msg instanceof HttpMessage) { + final H m; + try { + m = (H) msg; + } catch (Exception rethrow) { + ReferenceCountUtil.release(msg); + throw rethrow; + } + if (m instanceof LastHttpContent) { + encodeHttpMessageLastContent(ctx, m, out); + } else if (m instanceof HttpContent) { + encodeHttpMessageNotLastContent(ctx, m, out); + } else { + encodeJustHttpMessage(ctx, m, out); + } + } else { + encodeNotHttpMessageContentTypes(ctx, msg, out); + } + } + + private void encodeJustHttpMessage(ChannelHandlerContext ctx, H m, List out) throws Exception { + assert !(m instanceof HttpContent); + try { + if (state != ST_INIT) { + throwUnexpectedMessageTypeEx(m, state); + } + final ByteBuf buf = encodeInitHttpMessage(ctx, m); + + assert checkContentState(state); + + out.add(buf); + } finally { + ReferenceCountUtil.release(m); + } + } + + private void encodeByteBufHttpContent(int state, ChannelHandlerContext ctx, ByteBuf buf, ByteBuf content, + HttpHeaders trailingHeaders, List out) { + switch (state) { + case ST_CONTENT_NON_CHUNK: + if (encodeContentNonChunk(out, buf, content)) { + break; + } + // fall-through! + case ST_CONTENT_ALWAYS_EMPTY: + // We allocated a buffer so add it now. + out.add(buf); + break; + case ST_CONTENT_CHUNK: + // We allocated a buffer so add it now. 
+ out.add(buf); + encodeChunkedHttpContent(ctx, content, trailingHeaders, out); + break; + default: + throw new Error(); + } + } + + private void encodeHttpMessageNotLastContent(ChannelHandlerContext ctx, H m, List out) throws Exception { + assert m instanceof HttpContent; + assert !(m instanceof LastHttpContent); + final HttpContent httpContent = (HttpContent) m; + try { + if (state != ST_INIT) { + throwUnexpectedMessageTypeEx(m, state); + } + final ByteBuf buf = encodeInitHttpMessage(ctx, m); + + assert checkContentState(state); + + encodeByteBufHttpContent(state, ctx, buf, httpContent.content(), null, out); + } finally { + httpContent.release(); + } + } + + private void encodeHttpMessageLastContent(ChannelHandlerContext ctx, H m, List out) throws Exception { + assert m instanceof LastHttpContent; + final LastHttpContent httpContent = (LastHttpContent) m; + try { + if (state != ST_INIT) { + throwUnexpectedMessageTypeEx(m, state); + } + final ByteBuf buf = encodeInitHttpMessage(ctx, m); + + assert checkContentState(state); + + encodeByteBufHttpContent(state, ctx, buf, httpContent.content(), httpContent.trailingHeaders(), out); + + state = ST_INIT; + } finally { + httpContent.release(); + } + } + @SuppressWarnings("ConditionCoveredByFurtherCondition") + private void encodeNotHttpMessageContentTypes(ChannelHandlerContext ctx, Object msg, List out) { + assert !(msg instanceof HttpMessage); + if (state == ST_INIT) { + try { + if (msg instanceof ByteBuf && bypassEncoderIfEmpty((ByteBuf) msg, out)) { + return; + } + throwUnexpectedMessageTypeEx(msg, ST_INIT); + } finally { + ReferenceCountUtil.release(msg); + } + } + if (msg == LastHttpContent.EMPTY_LAST_CONTENT) { + state = encodeEmptyLastHttpContent(state, out); + return; + } + if (msg instanceof LastHttpContent) { + encodeLastHttpContent(ctx, (LastHttpContent) msg, out); + return; + } + if (msg instanceof HttpContent) { + encodeHttpContent(ctx, (HttpContent) msg, out); + return; + } + if (msg instanceof ByteBuf) { + 
encodeByteBufContent(ctx, (ByteBuf) msg, out); + return; + } + if (msg instanceof FileRegion) { + encodeFileRegionContent(ctx, (FileRegion) msg, out); + return; + } + try { + throwUnexpectedMessageTypeEx(msg, state); + } finally { + ReferenceCountUtil.release(msg); + } + } + + private void encodeFullHttpMessage(ChannelHandlerContext ctx, Object o, List out) + throws Exception { + assert o instanceof FullHttpMessage; + final FullHttpMessage msg = (FullHttpMessage) o; + try { + if (state != ST_INIT) { + throwUnexpectedMessageTypeEx(o, state); + } + + final H m = (H) o; + + final ByteBuf buf = ctx.alloc().buffer((int) headersEncodedSizeAccumulator); + + encodeInitialLine(buf, m); + + final int state = isContentAlwaysEmpty(m) ? ST_CONTENT_ALWAYS_EMPTY : + HttpUtil.isTransferEncodingChunked(m) ? ST_CONTENT_CHUNK : ST_CONTENT_NON_CHUNK; + + sanitizeHeadersBeforeEncode(m, state == ST_CONTENT_ALWAYS_EMPTY); + + encodeHeaders(m.headers(), buf); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + + headersEncodedSizeAccumulator = HEADERS_WEIGHT_NEW * padSizeForAccumulation(buf.readableBytes()) + + HEADERS_WEIGHT_HISTORICAL * headersEncodedSizeAccumulator; + + encodeByteBufHttpContent(state, ctx, buf, msg.content(), msg.trailingHeaders(), out); + } finally { + msg.release(); + } + } + + private static boolean encodeContentNonChunk(List out, ByteBuf buf, ByteBuf content) { + final int contentLength = content.readableBytes(); + if (contentLength > 0) { + if (buf.writableBytes() >= contentLength) { + // merge into other buffer for performance reasons + buf.writeBytes(content); + out.add(buf); + } else { + out.add(buf); + out.add(content.retain()); + } + return true; + } + return false; + } + + private static void throwUnexpectedMessageTypeEx(Object msg, int state) { + throw new IllegalStateException("unexpected message type: " + StringUtil.simpleClassName(msg) + + ", state: " + state); + } + + private void encodeFileRegionContent(ChannelHandlerContext ctx, FileRegion msg, List out) { 
+ try { + assert state != ST_INIT; + switch (state) { + case ST_CONTENT_NON_CHUNK: + if (msg.count() > 0) { + out.add(msg.retain()); + break; + } + + // fall-through! + case ST_CONTENT_ALWAYS_EMPTY: + // Need to produce some output otherwise an + // IllegalStateException will be thrown as we did not write anything + // Its ok to just write an EMPTY_BUFFER as if there are reference count issues these will be + // propagated as the caller of the encode(...) method will release the original + // buffer. + // Writing an empty buffer will not actually write anything on the wire, so if there is a user + // error with msg it will not be visible externally + out.add(Unpooled.EMPTY_BUFFER); + break; + case ST_CONTENT_CHUNK: + encodedChunkedFileRegionContent(ctx, msg, out); + break; + default: + throw new Error(); + } + } finally { + msg.release(); + } + } + + // Bypass the encoder in case of an empty buffer, so that the following idiom works: + // + // ch.write(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + // + // See https://github.com/netty/netty/issues/2983 for more information. 
+ private static boolean bypassEncoderIfEmpty(ByteBuf msg, List out) { + if (!msg.isReadable()) { + out.add(msg.retain()); + return true; + } + return false; + } + + private void encodeByteBufContent(ChannelHandlerContext ctx, ByteBuf content, List out) { + try { + assert state != ST_INIT; + if (bypassEncoderIfEmpty(content, out)) { + return; + } + encodeByteBufAndTrailers(state, ctx, out, content, null); + } finally { + content.release(); + } + } + + private static int encodeEmptyLastHttpContent(int state, List out) { + assert state != ST_INIT; + + switch (state) { + case ST_CONTENT_NON_CHUNK: + case ST_CONTENT_ALWAYS_EMPTY: + out.add(Unpooled.EMPTY_BUFFER); + break; + case ST_CONTENT_CHUNK: + out.add(ZERO_CRLF_CRLF_BUF.duplicate()); + break; + default: + throw new Error(); + } + return ST_INIT; + } + + private void encodeLastHttpContent(ChannelHandlerContext ctx, LastHttpContent msg, List out) { + assert state != ST_INIT; + assert !(msg instanceof HttpMessage); + try { + encodeByteBufAndTrailers(state, ctx, out, msg.content(), msg.trailingHeaders()); + state = ST_INIT; + } finally { + msg.release(); + } + } + + private void encodeHttpContent(ChannelHandlerContext ctx, HttpContent msg, List out) { + assert state != ST_INIT; + assert !(msg instanceof HttpMessage); + assert !(msg instanceof LastHttpContent); + try { + this.encodeByteBufAndTrailers(state, ctx, out, msg.content(), null); + } finally { + msg.release(); + } + } + + private void encodeByteBufAndTrailers(int state, ChannelHandlerContext ctx, List out, ByteBuf content, + HttpHeaders trailingHeaders) { + switch (state) { + case ST_CONTENT_NON_CHUNK: + if (content.isReadable()) { + out.add(content.retain()); + break; + } + // fall-through! 
+ case ST_CONTENT_ALWAYS_EMPTY: + out.add(Unpooled.EMPTY_BUFFER); + break; + case ST_CONTENT_CHUNK: + encodeChunkedHttpContent(ctx, content, trailingHeaders, out); + break; + default: + throw new Error(); + } + } + + private void encodeChunkedHttpContent(ChannelHandlerContext ctx, ByteBuf content, HttpHeaders trailingHeaders, + List out) { + final int contentLength = content.readableBytes(); + if (contentLength > 0) { + addEncodedLengthHex(ctx, contentLength, out); + out.add(content.retain()); + out.add(CRLF_BUF.duplicate()); + } + if (trailingHeaders != null) { + encodeTrailingHeaders(ctx, trailingHeaders, out); + } else if (contentLength == 0) { + // Need to produce some output otherwise an + // IllegalStateException will be thrown + out.add(content.retain()); + } + } + + private void encodeTrailingHeaders(ChannelHandlerContext ctx, HttpHeaders trailingHeaders, List out) { + if (trailingHeaders.isEmpty()) { + out.add(ZERO_CRLF_CRLF_BUF.duplicate()); + } else { + ByteBuf buf = ctx.alloc().buffer((int) trailersEncodedSizeAccumulator); + ByteBufUtil.writeMediumBE(buf, ZERO_CRLF_MEDIUM); + encodeHeaders(trailingHeaders, buf); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + trailersEncodedSizeAccumulator = TRAILERS_WEIGHT_NEW * padSizeForAccumulation(buf.readableBytes()) + + TRAILERS_WEIGHT_HISTORICAL * trailersEncodedSizeAccumulator; + out.add(buf); + } + } + + private ByteBuf encodeInitHttpMessage(ChannelHandlerContext ctx, H m) throws Exception { + assert state == ST_INIT; + + ByteBuf buf = ctx.alloc().buffer((int) headersEncodedSizeAccumulator); + // Encode the message. + encodeInitialLine(buf, m); + state = isContentAlwaysEmpty(m) ? ST_CONTENT_ALWAYS_EMPTY : + HttpUtil.isTransferEncodingChunked(m) ? 
ST_CONTENT_CHUNK : ST_CONTENT_NON_CHUNK; + + sanitizeHeadersBeforeEncode(m, state == ST_CONTENT_ALWAYS_EMPTY); + + encodeHeaders(m.headers(), buf); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + + headersEncodedSizeAccumulator = HEADERS_WEIGHT_NEW * padSizeForAccumulation(buf.readableBytes()) + + HEADERS_WEIGHT_HISTORICAL * headersEncodedSizeAccumulator; + return buf; + } + + /** + * Encode the {@link HttpHeaders} into a {@link ByteBuf}. + */ + protected void encodeHeaders(HttpHeaders headers, ByteBuf buf) { + Iterator> iter = headers.iteratorCharSequence(); + while (iter.hasNext()) { + Entry header = iter.next(); + HttpHeadersEncoder.encoderHeader(header.getKey(), header.getValue(), buf); + } + } + + private static void encodedChunkedFileRegionContent(ChannelHandlerContext ctx, FileRegion msg, List out) { + final long contentLength = msg.count(); + if (contentLength > 0) { + addEncodedLengthHex(ctx, contentLength, out); + out.add(msg.retain()); + out.add(CRLF_BUF.duplicate()); + } else if (contentLength == 0) { + // Need to produce some output otherwise an + // IllegalStateException will be thrown + out.add(msg.retain()); + } + } + + private static void addEncodedLengthHex(ChannelHandlerContext ctx, long contentLength, List out) { + String lengthHex = Long.toHexString(contentLength); + ByteBuf buf = ctx.alloc().buffer(lengthHex.length() + 2); + buf.writeCharSequence(lengthHex, CharsetUtil.US_ASCII); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + out.add(buf); + } + + /** + * Allows to sanitize headers of the message before encoding these. + */ + protected void sanitizeHeadersBeforeEncode(@SuppressWarnings("unused") H msg, boolean isAlwaysEmpty) { + // noop + } + + /** + * Determine whether a message has a content or not. Some message may have headers indicating + * a content without having an actual content, e.g the response to an HEAD or CONNECT request. 
+ * + * @param msg the message to test + * @return {@code true} to signal the message has no content + */ + protected boolean isContentAlwaysEmpty(@SuppressWarnings("unused") H msg) { + return false; + } + + @Override + @SuppressWarnings("ConditionCoveredByFurtherCondition") + public boolean acceptOutboundMessage(Object msg) throws Exception { + return msg == Unpooled.EMPTY_BUFFER || + msg == LastHttpContent.EMPTY_LAST_CONTENT || + msg instanceof FullHttpMessage || + msg instanceof HttpMessage || + msg instanceof LastHttpContent || + msg instanceof HttpContent || + msg instanceof ByteBuf || msg instanceof FileRegion; + } + + /** + * Add some additional overhead to the buffer. The rational is that it is better to slightly over allocate and waste + * some memory, rather than under allocate and require a resize/copy. + * + * @param readableBytes The readable bytes in the buffer. + * @return The {@code readableBytes} with some additional padding. + */ + private static int padSizeForAccumulation(int readableBytes) { + return (readableBytes << 2) / 3; + } + + @Deprecated + protected static void encodeAscii(String s, ByteBuf buf) { + buf.writeCharSequence(s, CharsetUtil.US_ASCII); + } + + protected abstract void encodeInitialLine(ByteBuf buf, H message) throws Exception; +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequest.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequest.java new file mode 100644 index 0000000..5484b4c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequest.java @@ -0,0 +1,78 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +/** + * An HTTP request. + * + *

Accessing Query Parameters and Cookie

+ *

+ * Unlike the Servlet API, a query string is constructed and decomposed by + * {@link QueryStringEncoder} and {@link QueryStringDecoder}. + * + * {@link io.netty.handler.codec.http.cookie.Cookie} support is also provided + * separately via {@link io.netty.handler.codec.http.cookie.ServerCookieDecoder}, + * {@link io.netty.handler.codec.http.cookie.ClientCookieDecoder}, + * {@link io.netty.handler.codec.http.cookie.ServerCookieEncoder}, + * and {@link io.netty.handler.codec.http.cookie.ClientCookieEncoder}. + * + * @see HttpResponse + * @see io.netty.handler.codec.http.cookie.ServerCookieDecoder + * @see io.netty.handler.codec.http.cookie.ClientCookieDecoder + * @see io.netty.handler.codec.http.cookie.ServerCookieEncoder + * @see io.netty.handler.codec.http.cookie.ClientCookieEncoder + */ +public interface HttpRequest extends HttpMessage { + + /** + * @deprecated Use {@link #method()} instead. + */ + @Deprecated + HttpMethod getMethod(); + + /** + * Returns the {@link HttpMethod} of this {@link HttpRequest}. + * + * @return The {@link HttpMethod} of this {@link HttpRequest} + */ + HttpMethod method(); + + /** + * Set the {@link HttpMethod} of this {@link HttpRequest}. + */ + HttpRequest setMethod(HttpMethod method); + + /** + * @deprecated Use {@link #uri()} instead. 
+ */ + @Deprecated + String getUri(); + + /** + * Returns the requested URI (or alternatively, path) + * + * @return The URI being requested + */ + String uri(); + + /** + * Set the requested URI (or alternatively, path) + */ + HttpRequest setUri(String uri); + + @Override + HttpRequest setProtocolVersion(HttpVersion version); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestDecoder.java new file mode 100644 index 0000000..59fdd8d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestDecoder.java @@ -0,0 +1,359 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelPipeline; +import io.netty.util.AsciiString; + +/** + * Decodes {@link ByteBuf}s into {@link HttpRequest}s and {@link HttpContent}s. + * + *

Parameters that prevents excessive memory consumption

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameMeaning
{@code maxInitialLineLength}The maximum length of the initial line (e.g. {@code "GET / HTTP/1.0"}) + * If the length of the initial line exceeds this value, a + * {@link TooLongHttpLineException} will be raised.
{@code maxHeaderSize}The maximum length of all headers. If the sum of the length of each + * header exceeds this value, a {@link TooLongHttpHeaderException} will be raised.
{@code maxChunkSize}The maximum length of the content or each chunk. If the content length + * exceeds this value, the transfer encoding of the decoded request will be + * converted to 'chunked' and the content will be split into multiple + * {@link HttpContent}s. If the transfer encoding of the HTTP request is + * 'chunked' already, each chunk will be split into smaller chunks if the + * length of the chunk exceeds this value. If you prefer not to handle + * {@link HttpContent}s in your handler, insert {@link HttpObjectAggregator} + * after this decoder in the {@link ChannelPipeline}.
+ * + *

Parameters that control parsing behavior

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameDefault valueMeaning
{@code allowDuplicateContentLengths}{@value #DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS}When set to {@code false}, will reject any messages that contain multiple Content-Length header fields. + * When set to {@code true}, will allow multiple Content-Length headers only if they are all the same decimal value. + * The duplicated field-values will be replaced with a single valid Content-Length field. + * See RFC 7230, Section 3.3.2.
{@code allowPartialChunks}{@value #DEFAULT_ALLOW_PARTIAL_CHUNKS}If the length of a chunk exceeds the {@link ByteBuf}s readable bytes and {@code allowPartialChunks} + * is set to {@code true}, the chunk will be split into multiple {@link HttpContent}s. + * Otherwise, if the chunk size does not exceed {@code maxChunkSize} and {@code allowPartialChunks} + * is set to {@code false}, the {@link ByteBuf} is not decoded into an {@link HttpContent} until + * the readable bytes are greater or equal to the chunk size.
+ * + *

Header Validation

+ * + * It is recommended to always enable header validation. + *

+ * Without header validation, your system can become vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + *

+ * This recommendation stands even when both peers in the HTTP exchange are trusted, + * as it helps with defence-in-depth. + */ +public class HttpRequestDecoder extends HttpObjectDecoder { + + private static final AsciiString Host = AsciiString.cached("Host"); + private static final AsciiString Connection = AsciiString.cached("Connection"); + private static final AsciiString ContentType = AsciiString.cached("Content-Type"); + private static final AsciiString ContentLength = AsciiString.cached("Content-Length"); + + private static final int GET_AS_INT = 'G' | 'E' << 8 | 'T' << 16; + private static final int POST_AS_INT = 'P' | 'O' << 8 | 'S' << 16 | 'T' << 24; + private static final long HTTP_1_1_AS_LONG = 'H' | 'T' << 8 | 'T' << 16 | 'P' << 24 | (long) '/' << 32 | + (long) '1' << 40 | (long) '.' << 48 | (long) '1' << 56; + + private static final long HTTP_1_0_AS_LONG = 'H' | 'T' << 8 | 'T' << 16 | 'P' << 24 | (long) '/' << 32 | + (long) '1' << 40 | (long) '.' << 48 | (long) '0' << 56; + + private static final int HOST_AS_INT = 'H' | 'o' << 8 | 's' << 16 | 't' << 24; + + private static final long CONNECTION_AS_LONG_0 = 'C' | 'o' << 8 | 'n' << 16 | 'n' << 24 | + (long) 'e' << 32 | (long) 'c' << 40 | (long) 't' << 48 | (long) 'i' << 56; + + private static final short CONNECTION_AS_SHORT_1 = 'o' | 'n' << 8; + + private static final long CONTENT_AS_LONG = 'C' | 'o' << 8 | 'n' << 16 | 't' << 24 | + (long) 'e' << 32 | (long) 'n' << 40 | (long) 't' << 48 | (long) '-' << 56; + + private static final int TYPE_AS_INT = 'T' | 'y' << 8 | 'p' << 16 | 'e' << 24; + + private static final long LENGTH_AS_LONG = 'L' | 'e' << 8 | 'n' << 16 | 'g' << 24 | + (long) 't' << 32 | (long) 'h' << 40; + + /** + * Creates a new instance with the default + * {@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxChunkSize (8192)}. + */ + public HttpRequestDecoder() { + } + + /** + * Creates a new instance with the specified parameters. 
+ */ + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize)); + } + + /** + * @deprecated Prefer the {@link #HttpRequestDecoder(HttpDecoderConfig)} constructor, + * to always have header validation enabled. + */ + @Deprecated + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders); + } + + /** + * @deprecated Prefer the {@link #HttpRequestDecoder(HttpDecoderConfig)} constructor, + * to always have header validation enabled. + */ + @Deprecated + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders, + initialBufferSize); + } + + /** + * @deprecated Prefer the {@link #HttpRequestDecoder(HttpDecoderConfig)} constructor, + * to always have header validation enabled. + */ + @Deprecated + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize, boolean allowDuplicateContentLengths) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders, + initialBufferSize, allowDuplicateContentLengths); + } + + /** + * @deprecated Prefer the {@link #HttpRequestDecoder(HttpDecoderConfig)} constructor, + * to always have header validation enabled. 
+ */ + @Deprecated + public HttpRequestDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize, boolean allowDuplicateContentLengths, boolean allowPartialChunks) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders, + initialBufferSize, allowDuplicateContentLengths, allowPartialChunks); + } + + /** + * Creates a new instance with the specified configuration. + */ + public HttpRequestDecoder(HttpDecoderConfig config) { + super(config); + } + + @Override + protected HttpMessage createMessage(String[] initialLine) throws Exception { + return new DefaultHttpRequest( + HttpVersion.valueOf(initialLine[2]), + HttpMethod.valueOf(initialLine[0]), initialLine[1], headersFactory); + } + + @Override + protected AsciiString splitHeaderName(final byte[] sb, final int start, final int length) { + final byte firstChar = sb[start]; + if (firstChar == 'H' && length == 4) { + if (isHost(sb, start)) { + return Host; + } + } else if (firstChar == 'C') { + if (length == 10) { + if (isConnection(sb, start)) { + return Connection; + } + } else if (length == 12) { + if (isContentType(sb, start)) { + return ContentType; + } + } else if (length == 14) { + if (isContentLength(sb, start)) { + return ContentLength; + } + } + } + return super.splitHeaderName(sb, start, length); + } + + private static boolean isHost(byte[] sb, int start) { + final int maybeHost = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16 | + sb[start + 3] << 24; + return maybeHost == HOST_AS_INT; + } + + private static boolean isConnection(byte[] sb, int start) { + final long maybeConnecti = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16 | + sb[start + 3] << 24 | + (long) sb[start + 4] << 32 | + (long) sb[start + 5] << 40 | + (long) sb[start + 6] << 48 | + (long) sb[start + 7] << 56; + if (maybeConnecti != CONNECTION_AS_LONG_0) { + return false; + } + final short maybeOn = (short) 
(sb[start + 8] | sb[start + 9] << 8); + return maybeOn == CONNECTION_AS_SHORT_1; + } + + private static boolean isContentType(byte[] sb, int start) { + final long maybeContent = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16 | + sb[start + 3] << 24 | + (long) sb[start + 4] << 32 | + (long) sb[start + 5] << 40 | + (long) sb[start + 6] << 48 | + (long) sb[start + 7] << 56; + if (maybeContent != CONTENT_AS_LONG) { + return false; + } + final int maybeType = sb[start + 8] | + sb[start + 9] << 8 | + sb[start + 10] << 16 | + sb[start + 11] << 24; + return maybeType == TYPE_AS_INT; + } + + private static boolean isContentLength(byte[] sb, int start) { + final long maybeContent = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16 | + sb[start + 3] << 24 | + (long) sb[start + 4] << 32 | + (long) sb[start + 5] << 40 | + (long) sb[start + 6] << 48 | + (long) sb[start + 7] << 56; + if (maybeContent != CONTENT_AS_LONG) { + return false; + } + final long maybeLength = sb[start + 8] | + sb[start + 9] << 8 | + sb[start + 10] << 16 | + sb[start + 11] << 24 | + (long) sb[start + 12] << 32 | + (long) sb[start + 13] << 40; + return maybeLength == LENGTH_AS_LONG; + } + + private static boolean isGetMethod(final byte[] sb, int start) { + final int maybeGet = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16; + return maybeGet == GET_AS_INT; + } + + private static boolean isPostMethod(final byte[] sb, int start) { + final int maybePost = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16 | + sb[start + 3] << 24; + return maybePost == POST_AS_INT; + } + + @Override + protected String splitFirstWordInitialLine(final byte[] sb, final int start, final int length) { + if (length == 3) { + if (isGetMethod(sb, start)) { + return HttpMethod.GET.name(); + } + } else if (length == 4) { + if (isPostMethod(sb, start)) { + return HttpMethod.POST.name(); + } + } + return super.splitFirstWordInitialLine(sb, start, length); + } + + @Override + protected String 
splitThirdWordInitialLine(final byte[] sb, final int start, final int length) { + if (length == 8) { + final long maybeHttp1_x = sb[start] | + sb[start + 1] << 8 | + sb[start + 2] << 16 | + sb[start + 3] << 24 | + (long) sb[start + 4] << 32 | + (long) sb[start + 5] << 40 | + (long) sb[start + 6] << 48 | + (long) sb[start + 7] << 56; + if (maybeHttp1_x == HTTP_1_1_AS_LONG) { + return HttpVersion.HTTP_1_1_STRING; + } else if (maybeHttp1_x == HTTP_1_0_AS_LONG) { + return HttpVersion.HTTP_1_0_STRING; + } + } + return super.splitThirdWordInitialLine(sb, start, length); + } + + @Override + protected HttpMessage createInvalidMessage() { + return new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/bad-request", + Unpooled.buffer(0), headersFactory, trailersFactory); + } + + @Override + protected boolean isDecodingRequest() { + return true; + } + + @Override + protected boolean isContentAlwaysEmpty(final HttpMessage msg) { + // fast-path to save expensive O(n) checks; users can override createMessage + // and extends DefaultHttpRequest making implementing HttpResponse: + // this is why we cannot use instanceof DefaultHttpRequest here :( + if (msg.getClass() == DefaultHttpRequest.class) { + return false; + } + return super.isContentAlwaysEmpty(msg); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestEncoder.java new file mode 100644 index 0000000..a741e03 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpRequestEncoder.java @@ -0,0 +1,80 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.CharsetUtil; + +import static io.netty.handler.codec.http.HttpConstants.SP; + +/** + * Encodes an {@link HttpRequest} or an {@link HttpContent} into + * a {@link ByteBuf}. + */ +public class HttpRequestEncoder extends HttpObjectEncoder { + private static final char SLASH = '/'; + private static final char QUESTION_MARK = '?'; + private static final int SLASH_AND_SPACE_SHORT = (SLASH << 8) | SP; + private static final int SPACE_SLASH_AND_SPACE_MEDIUM = (SP << 16) | SLASH_AND_SPACE_SHORT; + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + return super.acceptOutboundMessage(msg) && !(msg instanceof HttpResponse); + } + + @Override + protected void encodeInitialLine(ByteBuf buf, HttpRequest request) throws Exception { + ByteBufUtil.copy(request.method().asciiName(), buf); + + String uri = request.uri(); + + if (uri.isEmpty()) { + // Add " / " as absolute path if uri is not present. + // See https://tools.ietf.org/html/rfc2616#section-5.1.2 + ByteBufUtil.writeMediumBE(buf, SPACE_SLASH_AND_SPACE_MEDIUM); + } else { + CharSequence uriCharSequence = uri; + boolean needSlash = false; + int start = uri.indexOf("://"); + if (start != -1 && uri.charAt(0) != SLASH) { + start += 3; + // Correctly handle query params. 
+ // See https://github.com/netty/netty/issues/2732 + int index = uri.indexOf(QUESTION_MARK, start); + if (index == -1) { + if (uri.lastIndexOf(SLASH) < start) { + needSlash = true; + } + } else { + if (uri.lastIndexOf(SLASH, index) < start) { + uriCharSequence = new StringBuilder(uri).insert(index, SLASH); + } + } + } + buf.writeByte(SP).writeCharSequence(uriCharSequence, CharsetUtil.UTF_8); + if (needSlash) { + // write "/ " after uri + ByteBufUtil.writeShortBE(buf, SLASH_AND_SPACE_SHORT); + } else { + buf.writeByte(SP); + } + } + + request.protocolVersion().encode(buf); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponse.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponse.java new file mode 100644 index 0000000..b0ba345 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponse.java @@ -0,0 +1,57 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +/** + * An HTTP response. + * + *

Accessing Cookies

+ *

+ * Unlike the Servlet API, {@link io.netty.handler.codec.http.cookie.Cookie} support is provided + * separately via {@link io.netty.handler.codec.http.cookie.ServerCookieDecoder}, + * {@link io.netty.handler.codec.http.cookie.ClientCookieDecoder}, + * {@link io.netty.handler.codec.http.cookie.ServerCookieEncoder}, + * and {@link io.netty.handler.codec.http.cookie.ClientCookieEncoder}. + * + * @see HttpRequest + * @see io.netty.handler.codec.http.cookie.ServerCookieDecoder + * @see io.netty.handler.codec.http.cookie.ClientCookieDecoder + * @see io.netty.handler.codec.http.cookie.ServerCookieEncoder + * @see io.netty.handler.codec.http.cookie.ClientCookieEncoder + */ +public interface HttpResponse extends HttpMessage { + + /** + * @deprecated Use {@link #status()} instead. + */ + @Deprecated + HttpResponseStatus getStatus(); + + /** + * Returns the status of this {@link HttpResponse}. + * + * @return The {@link HttpResponseStatus} of this {@link HttpResponse} + */ + HttpResponseStatus status(); + + /** + * Set the status of this {@link HttpResponse}. + */ + HttpResponse setStatus(HttpResponseStatus status); + + @Override + HttpResponse setProtocolVersion(HttpVersion version); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseDecoder.java new file mode 100644 index 0000000..38be391 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseDecoder.java @@ -0,0 +1,208 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelPipeline; + +/** + * Decodes {@link ByteBuf}s into {@link HttpResponse}s and + * {@link HttpContent}s. + * + *

Parameters that prevents excessive memory consumption

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameMeaning
{@code maxInitialLineLength}The maximum length of the initial line (e.g. {@code "HTTP/1.0 200 OK"}) + * If the length of the initial line exceeds this value, a + * {@link TooLongHttpLineException} will be raised.
{@code maxHeaderSize}The maximum length of all headers. If the sum of the length of each + * header exceeds this value, a {@link TooLongHttpHeaderException} will be raised.
{@code maxChunkSize}The maximum length of the content or each chunk. If the content length + * exceeds this value, the transfer encoding of the decoded response will be + * converted to 'chunked' and the content will be split into multiple + * {@link HttpContent}s. If the transfer encoding of the HTTP response is + * 'chunked' already, each chunk will be split into smaller chunks if the + * length of the chunk exceeds this value. If you prefer not to handle + * {@link HttpContent}s in your handler, insert {@link HttpObjectAggregator} + * after this decoder in the {@link ChannelPipeline}.
+ * + *

Parameters that control parsing behavior

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameDefault valueMeaning
{@code allowDuplicateContentLengths}{@value #DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS}When set to {@code false}, will reject any messages that contain multiple Content-Length header fields. + * When set to {@code true}, will allow multiple Content-Length headers only if they are all the same decimal value. + * The duplicated field-values will be replaced with a single valid Content-Length field. + * See RFC 7230, Section 3.3.2.
{@code allowPartialChunks}{@value #DEFAULT_ALLOW_PARTIAL_CHUNKS}If the length of a chunk exceeds the {@link ByteBuf}s readable bytes and {@code allowPartialChunks} + * is set to {@code true}, the chunk will be split into multiple {@link HttpContent}s. + * Otherwise, if the chunk size does not exceed {@code maxChunkSize} and {@code allowPartialChunks} + * is set to {@code false}, the {@link ByteBuf} is not decoded into an {@link HttpContent} until + * the readable bytes are greater or equal to the chunk size.
+ * + *

Decoding a response for a HEAD request

+ *

+ * Unlike other HTTP requests, the successful response of a HEAD + * request does not have any content even if there is Content-Length + * header. Because {@link HttpResponseDecoder} is not able to determine if the + * response currently being decoded is associated with a HEAD request, + * you must override {@link #isContentAlwaysEmpty(HttpMessage)} to return + * true for the response of the HEAD request. + *

+ * If you are writing an HTTP client that issues a HEAD request, + * please use {@link HttpClientCodec} instead of this decoder. It will perform + * additional state management to handle the responses for HEAD + * requests correctly. + *

+ * + *

Decoding a response for a CONNECT request

+ *

+ * You also need to do additional state management to handle the response of a + * CONNECT request properly, like you did for HEAD. One + * difference is that the decoder should stop decoding completely after decoding + * the successful 200 response since the connection is not an HTTP connection + * anymore. + *

+ * {@link HttpClientCodec} also handles this edge case correctly, so you have to + * use {@link HttpClientCodec} if you are writing an HTTP client that issues a + * CONNECT request. + *

+ * + *

Header Validation

+ * + * It is recommended to always enable header validation. + *

+ * Without header validation, your system can become vulnerable to + * + * CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting') + * . + *

+ * This recommendation stands even when both peers in the HTTP exchange are trusted, + * as it helps with defence-in-depth. + */ +public class HttpResponseDecoder extends HttpObjectDecoder { + + private static final HttpResponseStatus UNKNOWN_STATUS = new HttpResponseStatus(999, "Unknown"); + + /** + * Creates a new instance with the default + * {@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxChunkSize (8192)}. + */ + public HttpResponseDecoder() { + } + + /** + * Creates a new instance with the specified parameters. + */ + public HttpResponseDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) { + super(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize)); + } + + /** + * @deprecated Prefer the {@link #HttpResponseDecoder(HttpDecoderConfig)} constructor. + */ + @Deprecated + public HttpResponseDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders); + } + + /** + * @deprecated Prefer the {@link #HttpResponseDecoder(HttpDecoderConfig)} constructor. + */ + @Deprecated + public HttpResponseDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders, + initialBufferSize); + } + + /** + * @deprecated Prefer the {@link #HttpResponseDecoder(HttpDecoderConfig)} constructor. 
+ */ + @Deprecated + public HttpResponseDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize, boolean allowDuplicateContentLengths) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders, + initialBufferSize, allowDuplicateContentLengths); + } + + /** + * @deprecated Prefer the {@link #HttpResponseDecoder(HttpDecoderConfig)} constructor. + */ + @Deprecated + public HttpResponseDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize, boolean allowDuplicateContentLengths, boolean allowPartialChunks) { + super(maxInitialLineLength, maxHeaderSize, maxChunkSize, DEFAULT_CHUNKED_SUPPORTED, validateHeaders, + initialBufferSize, allowDuplicateContentLengths, allowPartialChunks); + } + + /** + * Creates a new instance with the specified configuration. + */ + public HttpResponseDecoder(HttpDecoderConfig config) { + super(config); + } + + @Override + protected HttpMessage createMessage(String[] initialLine) { + return new DefaultHttpResponse( + HttpVersion.valueOf(initialLine[0]), + HttpResponseStatus.valueOf(Integer.parseInt(initialLine[1]), initialLine[2]), headersFactory); + } + + @Override + protected HttpMessage createInvalidMessage() { + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_0, UNKNOWN_STATUS, Unpooled.buffer(0), + headersFactory, trailersFactory); + } + + @Override + protected boolean isDecodingRequest() { + return false; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseEncoder.java new file mode 100644 index 0000000..fd52993 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseEncoder.java @@ -0,0 +1,98 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses 
this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; + +import static io.netty.handler.codec.http.HttpConstants.*; + +/** + * Encodes an {@link HttpResponse} or an {@link HttpContent} into + * a {@link ByteBuf}. + */ +public class HttpResponseEncoder extends HttpObjectEncoder { + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + // JDK type checks vs non-implemented interfaces costs O(N), where + // N is the number of interfaces already implemented by the concrete type that's being tested. + // !(msg instanceof HttpRequest) is supposed to always be true (and meaning that msg isn't a HttpRequest), + // but sadly was part of the original behaviour of this method and cannot be removed. + // We place here exact checks vs DefaultHttpResponse and DefaultFullHttpResponse because bad users can + // extends such types and make them to implement HttpRequest (non-sense, but still possible). 
+ final Class msgClass = msg.getClass(); + if (msgClass == DefaultFullHttpResponse.class || msgClass == DefaultHttpResponse.class) { + return true; + } + return super.acceptOutboundMessage(msg) && !(msg instanceof HttpRequest); + } + + @Override + protected void encodeInitialLine(ByteBuf buf, HttpResponse response) throws Exception { + response.protocolVersion().encode(buf); + buf.writeByte(SP); + response.status().encode(buf); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + } + + @Override + protected void sanitizeHeadersBeforeEncode(HttpResponse msg, boolean isAlwaysEmpty) { + if (isAlwaysEmpty) { + HttpResponseStatus status = msg.status(); + if (status.codeClass() == HttpStatusClass.INFORMATIONAL || + status.code() == HttpResponseStatus.NO_CONTENT.code()) { + + // Stripping Content-Length: + // See https://tools.ietf.org/html/rfc7230#section-3.3.2 + msg.headers().remove(HttpHeaderNames.CONTENT_LENGTH); + + // Stripping Transfer-Encoding: + // See https://tools.ietf.org/html/rfc7230#section-3.3.1 + msg.headers().remove(HttpHeaderNames.TRANSFER_ENCODING); + } else if (status.code() == HttpResponseStatus.RESET_CONTENT.code()) { + + // Stripping Transfer-Encoding: + msg.headers().remove(HttpHeaderNames.TRANSFER_ENCODING); + + // Set Content-Length: 0 + // https://httpstatuses.com/205 + msg.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + } + } + } + + @Override + protected boolean isContentAlwaysEmpty(HttpResponse msg) { + // Correctly handle special cases as stated in: + // https://tools.ietf.org/html/rfc7230#section-3.3.3 + HttpResponseStatus status = msg.status(); + + if (status.codeClass() == HttpStatusClass.INFORMATIONAL) { + + if (status.code() == HttpResponseStatus.SWITCHING_PROTOCOLS.code()) { + // We need special handling for WebSockets version 00 as it will include an body. + // Fortunally this version should not really be used in the wild very often. 
+ // See https://tools.ietf.org/html/draft-ietf-hybi-thewebsocketprotocol-00#section-1.2 + return msg.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_VERSION); + } + return true; + } + return status.code() == HttpResponseStatus.NO_CONTENT.code() || + status.code() == HttpResponseStatus.NOT_MODIFIED.code() || + status.code() == HttpResponseStatus.RESET_CONTENT.code(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java new file mode 100644 index 0000000..7b20a48 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpResponseStatus.java @@ -0,0 +1,649 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; + +import static io.netty.handler.codec.http.HttpConstants.SP; +import static io.netty.util.ByteProcessor.FIND_ASCII_SPACE; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Integer.parseInt; + +/** + * The response code and its description of HTTP or its derived protocols, such as + * RTSP and + * ICAP. 
+ */ +public class HttpResponseStatus implements Comparable { + + /** + * 100 Continue + */ + public static final HttpResponseStatus CONTINUE = newStatus(100, "Continue"); + + /** + * 101 Switching Protocols + */ + public static final HttpResponseStatus SWITCHING_PROTOCOLS = newStatus(101, "Switching Protocols"); + + /** + * 102 Processing (WebDAV, RFC2518) + */ + public static final HttpResponseStatus PROCESSING = newStatus(102, "Processing"); + + /** + * 103 Early Hints (RFC 8297) + */ + public static final HttpResponseStatus EARLY_HINTS = newStatus(103, "Early Hints"); + + /** + * 200 OK + */ + public static final HttpResponseStatus OK = newStatus(200, "OK"); + + /** + * 201 Created + */ + public static final HttpResponseStatus CREATED = newStatus(201, "Created"); + + /** + * 202 Accepted + */ + public static final HttpResponseStatus ACCEPTED = newStatus(202, "Accepted"); + + /** + * 203 Non-Authoritative Information (since HTTP/1.1) + */ + public static final HttpResponseStatus NON_AUTHORITATIVE_INFORMATION = + newStatus(203, "Non-Authoritative Information"); + + /** + * 204 No Content + */ + public static final HttpResponseStatus NO_CONTENT = newStatus(204, "No Content"); + + /** + * 205 Reset Content + */ + public static final HttpResponseStatus RESET_CONTENT = newStatus(205, "Reset Content"); + + /** + * 206 Partial Content + */ + public static final HttpResponseStatus PARTIAL_CONTENT = newStatus(206, "Partial Content"); + + /** + * 207 Multi-Status (WebDAV, RFC2518) + */ + public static final HttpResponseStatus MULTI_STATUS = newStatus(207, "Multi-Status"); + + /** + * 300 Multiple Choices + */ + public static final HttpResponseStatus MULTIPLE_CHOICES = newStatus(300, "Multiple Choices"); + + /** + * 301 Moved Permanently + */ + public static final HttpResponseStatus MOVED_PERMANENTLY = newStatus(301, "Moved Permanently"); + + /** + * 302 Found + */ + public static final HttpResponseStatus FOUND = newStatus(302, "Found"); + + /** + * 303 See Other (since 
HTTP/1.1) + */ + public static final HttpResponseStatus SEE_OTHER = newStatus(303, "See Other"); + + /** + * 304 Not Modified + */ + public static final HttpResponseStatus NOT_MODIFIED = newStatus(304, "Not Modified"); + + /** + * 305 Use Proxy (since HTTP/1.1) + */ + public static final HttpResponseStatus USE_PROXY = newStatus(305, "Use Proxy"); + + /** + * 307 Temporary Redirect (since HTTP/1.1) + */ + public static final HttpResponseStatus TEMPORARY_REDIRECT = newStatus(307, "Temporary Redirect"); + + /** + * 308 Permanent Redirect (RFC7538) + */ + public static final HttpResponseStatus PERMANENT_REDIRECT = newStatus(308, "Permanent Redirect"); + + /** + * 400 Bad Request + */ + public static final HttpResponseStatus BAD_REQUEST = newStatus(400, "Bad Request"); + + /** + * 401 Unauthorized + */ + public static final HttpResponseStatus UNAUTHORIZED = newStatus(401, "Unauthorized"); + + /** + * 402 Payment Required + */ + public static final HttpResponseStatus PAYMENT_REQUIRED = newStatus(402, "Payment Required"); + + /** + * 403 Forbidden + */ + public static final HttpResponseStatus FORBIDDEN = newStatus(403, "Forbidden"); + + /** + * 404 Not Found + */ + public static final HttpResponseStatus NOT_FOUND = newStatus(404, "Not Found"); + + /** + * 405 Method Not Allowed + */ + public static final HttpResponseStatus METHOD_NOT_ALLOWED = newStatus(405, "Method Not Allowed"); + + /** + * 406 Not Acceptable + */ + public static final HttpResponseStatus NOT_ACCEPTABLE = newStatus(406, "Not Acceptable"); + + /** + * 407 Proxy Authentication Required + */ + public static final HttpResponseStatus PROXY_AUTHENTICATION_REQUIRED = + newStatus(407, "Proxy Authentication Required"); + + /** + * 408 Request Timeout + */ + public static final HttpResponseStatus REQUEST_TIMEOUT = newStatus(408, "Request Timeout"); + + /** + * 409 Conflict + */ + public static final HttpResponseStatus CONFLICT = newStatus(409, "Conflict"); + + /** + * 410 Gone + */ + public static final 
HttpResponseStatus GONE = newStatus(410, "Gone"); + + /** + * 411 Length Required + */ + public static final HttpResponseStatus LENGTH_REQUIRED = newStatus(411, "Length Required"); + + /** + * 412 Precondition Failed + */ + public static final HttpResponseStatus PRECONDITION_FAILED = newStatus(412, "Precondition Failed"); + + /** + * 413 Request Entity Too Large + */ + public static final HttpResponseStatus REQUEST_ENTITY_TOO_LARGE = + newStatus(413, "Request Entity Too Large"); + + /** + * 414 Request-URI Too Long + */ + public static final HttpResponseStatus REQUEST_URI_TOO_LONG = newStatus(414, "Request-URI Too Long"); + + /** + * 415 Unsupported Media Type + */ + public static final HttpResponseStatus UNSUPPORTED_MEDIA_TYPE = newStatus(415, "Unsupported Media Type"); + + /** + * 416 Requested Range Not Satisfiable + */ + public static final HttpResponseStatus REQUESTED_RANGE_NOT_SATISFIABLE = + newStatus(416, "Requested Range Not Satisfiable"); + + /** + * 417 Expectation Failed + */ + public static final HttpResponseStatus EXPECTATION_FAILED = newStatus(417, "Expectation Failed"); + + /** + * 421 Misdirected Request + * + * @see 421 (Misdirected Request) Status Code + */ + public static final HttpResponseStatus MISDIRECTED_REQUEST = newStatus(421, "Misdirected Request"); + + /** + * 422 Unprocessable Entity (WebDAV, RFC4918) + */ + public static final HttpResponseStatus UNPROCESSABLE_ENTITY = newStatus(422, "Unprocessable Entity"); + + /** + * 423 Locked (WebDAV, RFC4918) + */ + public static final HttpResponseStatus LOCKED = newStatus(423, "Locked"); + + /** + * 424 Failed Dependency (WebDAV, RFC4918) + */ + public static final HttpResponseStatus FAILED_DEPENDENCY = newStatus(424, "Failed Dependency"); + + /** + * 425 Unordered Collection (WebDAV, RFC3648) + */ + public static final HttpResponseStatus UNORDERED_COLLECTION = newStatus(425, "Unordered Collection"); + + /** + * 426 Upgrade Required (RFC2817) + */ + public static final HttpResponseStatus 
UPGRADE_REQUIRED = newStatus(426, "Upgrade Required"); + + /** + * 428 Precondition Required (RFC6585) + */ + public static final HttpResponseStatus PRECONDITION_REQUIRED = newStatus(428, "Precondition Required"); + + /** + * 429 Too Many Requests (RFC6585) + */ + public static final HttpResponseStatus TOO_MANY_REQUESTS = newStatus(429, "Too Many Requests"); + + /** + * 431 Request Header Fields Too Large (RFC6585) + */ + public static final HttpResponseStatus REQUEST_HEADER_FIELDS_TOO_LARGE = + newStatus(431, "Request Header Fields Too Large"); + + /** + * 500 Internal Server Error + */ + public static final HttpResponseStatus INTERNAL_SERVER_ERROR = newStatus(500, "Internal Server Error"); + + /** + * 501 Not Implemented + */ + public static final HttpResponseStatus NOT_IMPLEMENTED = newStatus(501, "Not Implemented"); + + /** + * 502 Bad Gateway + */ + public static final HttpResponseStatus BAD_GATEWAY = newStatus(502, "Bad Gateway"); + + /** + * 503 Service Unavailable + */ + public static final HttpResponseStatus SERVICE_UNAVAILABLE = newStatus(503, "Service Unavailable"); + + /** + * 504 Gateway Timeout + */ + public static final HttpResponseStatus GATEWAY_TIMEOUT = newStatus(504, "Gateway Timeout"); + + /** + * 505 HTTP Version Not Supported + */ + public static final HttpResponseStatus HTTP_VERSION_NOT_SUPPORTED = + newStatus(505, "HTTP Version Not Supported"); + + /** + * 506 Variant Also Negotiates (RFC2295) + */ + public static final HttpResponseStatus VARIANT_ALSO_NEGOTIATES = newStatus(506, "Variant Also Negotiates"); + + /** + * 507 Insufficient Storage (WebDAV, RFC4918) + */ + public static final HttpResponseStatus INSUFFICIENT_STORAGE = newStatus(507, "Insufficient Storage"); + + /** + * 510 Not Extended (RFC2774) + */ + public static final HttpResponseStatus NOT_EXTENDED = newStatus(510, "Not Extended"); + + /** + * 511 Network Authentication Required (RFC6585) + */ + public static final HttpResponseStatus NETWORK_AUTHENTICATION_REQUIRED = + 
newStatus(511, "Network Authentication Required");

    /**
     * Creates a cached standard-status instance; the {@code true} flag makes the
     * constructor pre-encode the "{@code <code> <reasonPhrase>}" ASCII bytes so that
     * {@code encode(ByteBuf)} can write them in a single call.
     */
    private static HttpResponseStatus newStatus(int statusCode, String reasonPhrase) {
        return new HttpResponseStatus(statusCode, reasonPhrase, true);
    }

    /**
     * Returns the {@link HttpResponseStatus} represented by the specified code.
     * If the specified code is a standard HTTP status code, a cached instance
     * will be returned. Otherwise, a new instance will be returned.
     */
    public static HttpResponseStatus valueOf(int code) {
        HttpResponseStatus status = valueOf0(code);
        return status != null ? status : new HttpResponseStatus(code);
    }

    /**
     * Returns the cached instance for a standard status code, or {@code null} if the
     * code is not one of the constants declared above.
     */
    private static HttpResponseStatus valueOf0(int code) {
        switch (code) {
        case 100:
            return CONTINUE;
        case 101:
            return SWITCHING_PROTOCOLS;
        case 102:
            return PROCESSING;
        case 103:
            return EARLY_HINTS;
        case 200:
            return OK;
        case 201:
            return CREATED;
        case 202:
            return ACCEPTED;
        case 203:
            return NON_AUTHORITATIVE_INFORMATION;
        case 204:
            return NO_CONTENT;
        case 205:
            return RESET_CONTENT;
        case 206:
            return PARTIAL_CONTENT;
        case 207:
            return MULTI_STATUS;
        case 300:
            return MULTIPLE_CHOICES;
        case 301:
            return MOVED_PERMANENTLY;
        case 302:
            return FOUND;
        case 303:
            return SEE_OTHER;
        case 304:
            return NOT_MODIFIED;
        case 305:
            return USE_PROXY;
        case 307:
            return TEMPORARY_REDIRECT;
        case 308:
            return PERMANENT_REDIRECT;
        case 400:
            return BAD_REQUEST;
        case 401:
            return UNAUTHORIZED;
        case 402:
            return PAYMENT_REQUIRED;
        case 403:
            return FORBIDDEN;
        case 404:
            return NOT_FOUND;
        case 405:
            return METHOD_NOT_ALLOWED;
        case 406:
            return NOT_ACCEPTABLE;
        case 407:
            return PROXY_AUTHENTICATION_REQUIRED;
        case 408:
            return REQUEST_TIMEOUT;
        case 409:
            return CONFLICT;
        case 410:
            return GONE;
        case 411:
            return LENGTH_REQUIRED;
        case 412:
            return PRECONDITION_FAILED;
        case 413:
            return REQUEST_ENTITY_TOO_LARGE;
        case 414:
            return REQUEST_URI_TOO_LONG;
        case 415:
            return UNSUPPORTED_MEDIA_TYPE;
        case 416:
            return REQUESTED_RANGE_NOT_SATISFIABLE;
        case 417:
            return EXPECTATION_FAILED;
        case 421:
            return MISDIRECTED_REQUEST;
        case 422:
            return UNPROCESSABLE_ENTITY;
        case 423:
            return LOCKED;
        case 424:
            return FAILED_DEPENDENCY;
        case 425:
            return UNORDERED_COLLECTION;
        case 426:
            return UPGRADE_REQUIRED;
        case 428:
            return PRECONDITION_REQUIRED;
        case 429:
            return TOO_MANY_REQUESTS;
        case 431:
            return REQUEST_HEADER_FIELDS_TOO_LARGE;
        case 500:
            return INTERNAL_SERVER_ERROR;
        case 501:
            return NOT_IMPLEMENTED;
        case 502:
            return BAD_GATEWAY;
        case 503:
            return SERVICE_UNAVAILABLE;
        case 504:
            return GATEWAY_TIMEOUT;
        case 505:
            return HTTP_VERSION_NOT_SUPPORTED;
        case 506:
            return VARIANT_ALSO_NEGOTIATES;
        case 507:
            return INSUFFICIENT_STORAGE;
        case 510:
            return NOT_EXTENDED;
        case 511:
            return NETWORK_AUTHENTICATION_REQUIRED;
        }
        return null;
    }

    /**
     * Returns the {@link HttpResponseStatus} represented by the specified {@code code} and {@code reasonPhrase}.
     * If the specified code is a standard HTTP status {@code code} and {@code reasonPhrase}, a cached instance
     * will be returned. Otherwise, a new instance will be returned.
     * @param code The response code value.
     * @param reasonPhrase The response code reason phrase.
     * @return the {@link HttpResponseStatus} represented by the specified {@code code} and {@code reasonPhrase}.
     */
    public static HttpResponseStatus valueOf(int code, String reasonPhrase) {
        HttpResponseStatus responseStatus = valueOf0(code);
        return responseStatus != null && responseStatus.reasonPhrase().contentEquals(reasonPhrase) ? responseStatus :
                new HttpResponseStatus(code, reasonPhrase);
    }

    /**
     * Parses the specified HTTP status line into a {@link HttpResponseStatus}. The expected formats of the line are:
     * <ul>
     * <li>{@code statusCode} (e.g. 200)</li>
     * <li>{@code statusCode} {@code reasonPhrase} (e.g. 404 Not Found)</li>
     * </ul>
+ * + * @throws IllegalArgumentException if the specified status line is malformed + */ + public static HttpResponseStatus parseLine(CharSequence line) { + return (line instanceof AsciiString) ? parseLine((AsciiString) line) : parseLine(line.toString()); + } + + /** + * Parses the specified HTTP status line into a {@link HttpResponseStatus}. The expected formats of the line are: + *
 * <ul>
 * <li>{@code statusCode} (e.g. 200)</li>
 * <li>{@code statusCode} {@code reasonPhrase} (e.g. 404 Not Found)</li>
 * </ul>
+ * + * @throws IllegalArgumentException if the specified status line is malformed + */ + public static HttpResponseStatus parseLine(String line) { + try { + int space = line.indexOf(' '); + return space == -1 ? valueOf(parseInt(line)) : + valueOf(parseInt(line.substring(0, space)), line.substring(space + 1)); + } catch (Exception e) { + throw new IllegalArgumentException("malformed status line: " + line, e); + } + } + + /** + * Parses the specified HTTP status line into a {@link HttpResponseStatus}. The expected formats of the line are: + *
 * <ul>
 * <li>{@code statusCode} (e.g. 200)</li>
 * <li>{@code statusCode} {@code reasonPhrase} (e.g. 404 Not Found)</li>
 * </ul>
 *
 * @throws IllegalArgumentException if the specified status line is malformed
 */
public static HttpResponseStatus parseLine(AsciiString line) {
    try {
        int space = line.forEachByte(FIND_ASCII_SPACE);
        return space == -1 ? valueOf(line.parseInt()) : valueOf(line.parseInt(0, space), line.toString(space + 1));
    } catch (Exception e) {
        throw new IllegalArgumentException("malformed status line: " + line, e);
    }
}

private final int code;
// Decimal form of code, cached for header encoding.
private final AsciiString codeAsText;
private final HttpStatusClass codeClass;

private final String reasonPhrase;
// Pre-encoded "<code> <reasonPhrase>" ASCII bytes; non-null only for cached standard
// statuses (constructed via newStatus with bytes == true), used by encode(ByteBuf).
private final byte[] bytes;

/**
 * Creates a new instance with the specified {@code code} and the auto-generated default reason phrase.
 */
private HttpResponseStatus(int code) {
    this(code, HttpStatusClass.valueOf(code).defaultReasonPhrase() + " (" + code + ')', false);
}

/**
 * Creates a new instance with the specified {@code code} and its {@code reasonPhrase}.
 */
public HttpResponseStatus(int code, String reasonPhrase) {
    this(code, reasonPhrase, false);
}

/**
 * @param bytes whether to pre-encode the status line as ASCII bytes (used only for
 *              the cached standard-status constants).
 * @throws IllegalArgumentException if {@code code} is negative or {@code reasonPhrase}
 *                                  contains CR/LF (which would allow response splitting).
 */
private HttpResponseStatus(int code, String reasonPhrase, boolean bytes) {
    checkPositiveOrZero(code, "code");
    ObjectUtil.checkNotNull(reasonPhrase, "reasonPhrase");

    for (int i = 0; i < reasonPhrase.length(); i ++) {
        char c = reasonPhrase.charAt(i);
        // Check prohibited characters.
        switch (c) {
            case '\n': case '\r':
                throw new IllegalArgumentException(
                        "reasonPhrase contains one of the following prohibited characters: " +
                                "\\r\\n: " + reasonPhrase);
        }
    }

    this.code = code;
    this.codeClass = HttpStatusClass.valueOf(code);
    String codeString = Integer.toString(code);
    codeAsText = new AsciiString(codeString);
    this.reasonPhrase = reasonPhrase;
    if (bytes) {
        this.bytes = (codeString + ' ' + reasonPhrase).getBytes(CharsetUtil.US_ASCII);
    } else {
        this.bytes = null;
    }
}

/**
 * Returns the code of this {@link HttpResponseStatus}.
+ */ + public int code() { + return code; + } + + /** + * Returns the status code as {@link AsciiString}. + */ + public AsciiString codeAsText() { + return codeAsText; + } + + /** + * Returns the reason phrase of this {@link HttpResponseStatus}. + */ + public String reasonPhrase() { + return reasonPhrase; + } + + /** + * Returns the class of this {@link HttpResponseStatus} + */ + public HttpStatusClass codeClass() { + return this.codeClass; + } + + @Override + public int hashCode() { + return code(); + } + + /** + * Equality of {@link HttpResponseStatus} only depends on {@link #code()}. The + * reason phrase is not considered for equality. + */ + @Override + public boolean equals(Object o) { + if (!(o instanceof HttpResponseStatus)) { + return false; + } + + return code() == ((HttpResponseStatus) o).code(); + } + + /** + * Equality of {@link HttpResponseStatus} only depends on {@link #code()}. The + * reason phrase is not considered for equality. + */ + @Override + public int compareTo(HttpResponseStatus o) { + return code() - o.code(); + } + + @Override + public String toString() { + return new StringBuilder(reasonPhrase.length() + 4) + .append(codeAsText) + .append(' ') + .append(reasonPhrase) + .toString(); + } + + void encode(ByteBuf buf) { + if (bytes == null) { + ByteBufUtil.copy(codeAsText, buf); + buf.writeByte(SP); + buf.writeCharSequence(reasonPhrase, CharsetUtil.US_ASCII); + } else { + buf.writeBytes(bytes); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpScheme.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpScheme.java new file mode 100644 index 0000000..97a6c6d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpScheme.java @@ -0,0 +1,70 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + 
* with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; + +/** + * Defines the common schemes used for the HTTP protocol as defined by + * rfc7230. + */ +public final class HttpScheme { + + /** + * Scheme for non-secure HTTP connection. + */ + public static final HttpScheme HTTP = new HttpScheme(80, "http"); + + /** + * Scheme for secure HTTP connection. + */ + public static final HttpScheme HTTPS = new HttpScheme(443, "https"); + + private final int port; + private final AsciiString name; + + private HttpScheme(int port, String name) { + this.port = port; + this.name = AsciiString.cached(name); + } + + public AsciiString name() { + return name; + } + + public int port() { + return port; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof HttpScheme)) { + return false; + } + HttpScheme other = (HttpScheme) o; + return other.port() == port && other.name().equals(name); + } + + @Override + public int hashCode() { + return port * 31 + name.hashCode(); + } + + @Override + public String toString() { + return name.toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java new file mode 100644 index 0000000..55a4e3d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerCodec.java @@ -0,0 +1,201 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the 
Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.CombinedChannelDuplexHandler; + +import java.util.ArrayDeque; +import java.util.List; +import java.util.Queue; + +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_VALIDATE_HEADERS; + +/** + * A combination of {@link HttpRequestDecoder} and {@link HttpResponseEncoder} + * which enables easier server side HTTP implementation. + * + *
 * <h3>Header Validation</h3>
+ * + * It is recommended to always enable header validation. + *
 * <p>
 * Without header validation, your system can become vulnerable to
 * <a href="https://cwe.mitre.org/data/definitions/113.html">
 *     CWE-113: Improper Neutralization of CRLF Sequences in HTTP Headers ('HTTP Response Splitting')
 * </a>.
 * <p>
+ * This recommendation stands even when both peers in the HTTP exchange are trusted, + * as it helps with defence-in-depth. + * + * @see HttpClientCodec + */ +public final class HttpServerCodec extends CombinedChannelDuplexHandler + implements HttpServerUpgradeHandler.SourceCodec { + + /** A queue that is used for correlating a request and a response. */ + private final Queue queue = new ArrayDeque(); + + /** + * Creates a new instance with the default decoder options + * ({@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxChunkSize (8192)}). + */ + public HttpServerCodec() { + this(DEFAULT_MAX_INITIAL_LINE_LENGTH, DEFAULT_MAX_HEADER_SIZE, DEFAULT_MAX_CHUNK_SIZE); + } + + /** + * Creates a new instance with the specified decoder options. + */ + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize)); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpServerCodec(HttpDecoderConfig)} constructor, + * to always enable header validation. + */ + @Deprecated + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders)); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpServerCodec(HttpDecoderConfig)} constructor, to always enable header + * validation. 
+ */ + @Deprecated + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize)); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpServerCodec(HttpDecoderConfig)} constructor, + * to always enable header validation. + */ + @Deprecated + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize, boolean allowDuplicateContentLengths) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize) + .setAllowDuplicateContentLengths(allowDuplicateContentLengths)); + } + + /** + * Creates a new instance with the specified decoder options. + * + * @deprecated Prefer the {@link #HttpServerCodec(HttpDecoderConfig)} constructor, + * to always enable header validation. + */ + @Deprecated + public HttpServerCodec(int maxInitialLineLength, int maxHeaderSize, int maxChunkSize, boolean validateHeaders, + int initialBufferSize, boolean allowDuplicateContentLengths, boolean allowPartialChunks) { + this(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxChunkSize) + .setValidateHeaders(validateHeaders) + .setInitialBufferSize(initialBufferSize) + .setAllowDuplicateContentLengths(allowDuplicateContentLengths) + .setAllowPartialChunks(allowPartialChunks)); + } + + /** + * Creates a new instance with the specified decoder configuration. 
+ */ + public HttpServerCodec(HttpDecoderConfig config) { + init(new HttpServerRequestDecoder(config), new HttpServerResponseEncoder()); + } + + /** + * Upgrades to another protocol from HTTP. Removes the {@link HttpRequestDecoder} and + * {@link HttpResponseEncoder} from the pipeline. + */ + @Override + public void upgradeFrom(ChannelHandlerContext ctx) { + ctx.pipeline().remove(this); + } + + private final class HttpServerRequestDecoder extends HttpRequestDecoder { + HttpServerRequestDecoder(HttpDecoderConfig config) { + super(config); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf buffer, List out) throws Exception { + int oldSize = out.size(); + super.decode(ctx, buffer, out); + int size = out.size(); + for (int i = oldSize; i < size; i++) { + Object obj = out.get(i); + if (obj instanceof HttpRequest) { + queue.add(((HttpRequest) obj).method()); + } + } + } + } + + private final class HttpServerResponseEncoder extends HttpResponseEncoder { + + private HttpMethod method; + + @Override + protected void sanitizeHeadersBeforeEncode(HttpResponse msg, boolean isAlwaysEmpty) { + if (!isAlwaysEmpty && HttpMethod.CONNECT.equals(method) + && msg.status().codeClass() == HttpStatusClass.SUCCESS) { + // Stripping Transfer-Encoding: + // See https://tools.ietf.org/html/rfc7230#section-3.3.1 + msg.headers().remove(HttpHeaderNames.TRANSFER_ENCODING); + return; + } + + super.sanitizeHeadersBeforeEncode(msg, isAlwaysEmpty); + } + + @Override + protected boolean isContentAlwaysEmpty(@SuppressWarnings("unused") HttpResponse msg) { + method = queue.poll(); + return HttpMethod.HEAD.equals(method) || super.isContentAlwaysEmpty(msg); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerExpectContinueHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerExpectContinueHandler.java new file mode 100644 index 0000000..ec91f59 --- /dev/null +++ 
b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerExpectContinueHandler.java @@ -0,0 +1,97 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.ReferenceCountUtil; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpResponseStatus.CONTINUE; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; + +/** + * Sends a 100 CONTINUE + * {@link HttpResponse} to {@link HttpRequest}s which contain a 'expect: 100-continue' header. It + * should only be used for applications which do not install the {@link HttpObjectAggregator}. + *
 * <p>
+ * By default it accepts all expectations. + *
 * <p>
+ * Since {@link HttpServerExpectContinueHandler} expects {@link HttpRequest}s it should be added after {@link + * HttpServerCodec} but before any other handlers that might send a {@link HttpResponse}.
 * <pre>
+ *  {@link io.netty.channel.ChannelPipeline} p = ...;
+ *  ...
+ *  p.addLast("serverCodec", new {@link HttpServerCodec}());
+ *  p.addLast("respondExpectContinue", new {@link HttpServerExpectContinueHandler}());
+ *  ...
+ *  p.addLast("handler", new HttpRequestHandler());
 *  </pre>
+ *
+ */ +public class HttpServerExpectContinueHandler extends ChannelInboundHandlerAdapter { + + private static final FullHttpResponse EXPECTATION_FAILED = new DefaultFullHttpResponse( + HTTP_1_1, HttpResponseStatus.EXPECTATION_FAILED, Unpooled.EMPTY_BUFFER); + + private static final FullHttpResponse ACCEPT = new DefaultFullHttpResponse( + HTTP_1_1, CONTINUE, Unpooled.EMPTY_BUFFER); + + static { + EXPECTATION_FAILED.headers().set(CONTENT_LENGTH, 0); + ACCEPT.headers().set(CONTENT_LENGTH, 0); + } + + /** + * Produces a {@link HttpResponse} for {@link HttpRequest}s which define an expectation. Returns {@code null} if the + * request should be rejected. See {@link #rejectResponse(HttpRequest)}. + */ + protected HttpResponse acceptMessage(@SuppressWarnings("unused") HttpRequest request) { + return ACCEPT.retainedDuplicate(); + } + + /** + * Returns the appropriate 4XX {@link HttpResponse} for the given {@link HttpRequest}. + */ + protected HttpResponse rejectResponse(@SuppressWarnings("unused") HttpRequest request) { + return EXPECTATION_FAILED.retainedDuplicate(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof HttpRequest) { + HttpRequest req = (HttpRequest) msg; + + if (HttpUtil.is100ContinueExpected(req)) { + HttpResponse accept = acceptMessage(req); + + if (accept == null) { + // the expectation failed so we refuse the request. 
+ HttpResponse rejection = rejectResponse(req); + ReferenceCountUtil.release(msg); + ctx.writeAndFlush(rejection).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + return; + } + + ctx.writeAndFlush(accept).addListener(ChannelFutureListener.CLOSE_ON_FAILURE); + req.headers().remove(HttpHeaderNames.EXPECT); + } + } + super.channelRead(ctx, msg); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerKeepAliveHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerKeepAliveHandler.java new file mode 100644 index 0000000..25b39e0 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerKeepAliveHandler.java @@ -0,0 +1,128 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; + +import static io.netty.handler.codec.http.HttpUtil.*; + +/** + * HttpServerKeepAliveHandler helps close persistent connections when appropriate. + *
 * <p>
+ * The server channel is expected to set the proper 'Connection' header if it can handle persistent connections. {@link + * HttpServerKeepAliveHandler} will automatically close the channel for any LastHttpContent that corresponds to a client + * request for closing the connection, or if the HttpResponse associated with that LastHttpContent requested closing the + * connection or didn't have a self defined message length. + *
 * <p>
+ * Since {@link HttpServerKeepAliveHandler} expects {@link HttpObject}s it should be added after {@link HttpServerCodec} + * but before any other handlers that might send a {@link HttpResponse}.
 * <pre>
+ *  {@link ChannelPipeline} p = ...;
+ *  ...
+ *  p.addLast("serverCodec", new {@link HttpServerCodec}());
+ *  p.addLast("httpKeepAlive", new {@link HttpServerKeepAliveHandler}());
+ *  p.addLast("aggregator", new {@link HttpObjectAggregator}(1048576));
+ *  ...
+ *  p.addLast("handler", new HttpRequestHandler());
 *  </pre>
+ *
 */
public class HttpServerKeepAliveHandler extends ChannelDuplexHandler {
    private static final String MULTIPART_PREFIX = "multipart";

    // True while the connection may stay open after the current exchange; cleared once
    // either peer signals close or a response has no self-defined message length.
    private boolean persistentConnection = true;
    // Track pending responses to support client pipelining: https://tools.ietf.org/html/rfc7230#section-6.3.2
    private int pendingResponses;

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        // read message and track if it was keepAlive
        if (msg instanceof HttpRequest) {
            final HttpRequest request = (HttpRequest) msg;
            if (persistentConnection) {
                pendingResponses += 1;
                persistentConnection = isKeepAlive(request);
            }
        }
        super.channelRead(ctx, msg);
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        // modify message on way out to add headers if needed
        if (msg instanceof HttpResponse) {
            final HttpResponse response = (HttpResponse) msg;
            trackResponse(response);
            // Assume the response writer knows if they can persist or not and sets isKeepAlive on the response
            if (!isKeepAlive(response) || !isSelfDefinedMessageLength(response)) {
                // No longer keep alive as the client can't tell when the message is done unless we close connection
                pendingResponses = 0;
                persistentConnection = false;
            }
            // Server might think it can keep connection alive, but we should fix response header if we know better
            if (!shouldKeepAlive()) {
                setKeepAlive(response, false);
            }
        }
        if (msg instanceof LastHttpContent && !shouldKeepAlive()) {
            // Close once the final content of the last non-persistent response is flushed.
            promise = promise.unvoid().addListener(ChannelFutureListener.CLOSE);
        }
        super.write(ctx, msg, promise);
    }

    // Informational (1xx) responses do not complete an exchange, so they are not counted.
    private void trackResponse(HttpResponse response) {
        if (!isInformational(response)) {
            pendingResponses -= 1;
        }
    }

    // Keep the connection open while pipelined responses are still outstanding.
    private boolean shouldKeepAlive() {
        return pendingResponses != 0 || persistentConnection;
    }

    /**
     * Keep-alive only works if the client can detect when the message has ended without relying on the connection
     * being closed.
     *
 * <p>
+ * + * @param response The HttpResponse to check + * + * @return true if the response has a self defined message length. + */ + private static boolean isSelfDefinedMessageLength(HttpResponse response) { + return isContentLengthSet(response) || isTransferEncodingChunked(response) || isMultipart(response) || + isInformational(response) || response.status().code() == HttpResponseStatus.NO_CONTENT.code(); + } + + private static boolean isInformational(HttpResponse response) { + return response.status().codeClass() == HttpStatusClass.INFORMATIONAL; + } + + private static boolean isMultipart(HttpResponse response) { + String contentType = response.headers().get(HttpHeaderNames.CONTENT_TYPE); + return contentType != null && + contentType.regionMatches(true, 0, MULTIPART_PREFIX, 0, MULTIPART_PREFIX.length()); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java new file mode 100644 index 0000000..9778f43 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpServerUpgradeHandler.java @@ -0,0 +1,453 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
/*
 * Copyright The Netty Project — licensed under the Apache License, Version 2.0
 * (license header continues from the previous chunk of this patch).
 */
package io.netty.handler.codec.http;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.ReferenceCounted;

import java.util.ArrayList;
import java.util.Collection;
import java.util.List;

import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS;
import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1;
import static io.netty.util.AsciiString.containsAllContentEqualsIgnoreCase;
import static io.netty.util.AsciiString.containsContentEqualsIgnoreCase;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static io.netty.util.internal.StringUtil.COMMA;

/**
 * A server-side handler that receives HTTP requests and optionally performs a protocol switch if
 * the requested protocol is supported. Once an upgrade is performed, this handler removes itself
 * from the pipeline.
 *
 * NOTE(review): generic type parameters appear to have been stripped from this extract
 * (e.g. raw {@code List} / {@code Collection}); confirm against the original file.
 */
public class HttpServerUpgradeHandler extends HttpObjectAggregator {

    /**
     * The source codec that is used in the pipeline initially.
     */
    public interface SourceCodec {
        /**
         * Removes this codec (i.e. all associated handlers) from the pipeline.
         */
        void upgradeFrom(ChannelHandlerContext ctx);
    }

    /**
     * A codec that the source can be upgraded to.
     */
    public interface UpgradeCodec {
        /**
         * Gets all protocol-specific headers required by this protocol for a successful upgrade.
         * Any supplied header will be required to appear in the {@link HttpHeaderNames#CONNECTION} header as well.
         */
        Collection requiredUpgradeHeaders();

        /**
         * Prepares the {@code upgradeHeaders} for a protocol update based upon the contents of
         * {@code upgradeRequest}. This method returns a boolean value to proceed or abort the upgrade
         * in progress. If {@code false} is returned, the upgrade is aborted and the {@code upgradeRequest}
         * will be passed through the inbound pipeline as if no upgrade was performed. If {@code true} is
         * returned, the upgrade will proceed to the next step which invokes {@link #upgradeTo}. When
         * returning {@code true}, you can add headers to the {@code upgradeHeaders} so that they are
         * added to the 101 Switching protocols response.
         */
        boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest,
                                       HttpHeaders upgradeHeaders);

        /**
         * Performs an HTTP protocol upgrade from the source codec. This method is responsible for
         * adding all handlers required for the new protocol.
         *
         * @param ctx the context for the current handler.
         * @param upgradeRequest the request that triggered the upgrade to this protocol.
         */
        void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest);
    }

    /**
     * Creates a new {@link UpgradeCodec} for the requested protocol name.
     */
    public interface UpgradeCodecFactory {
        /**
         * Invoked by {@link HttpServerUpgradeHandler} for all the requested protocol names in the order of
         * the client preference. The first non-{@code null} {@link UpgradeCodec} returned by this method
         * will be selected.
         *
         * @return a new {@link UpgradeCodec}, or {@code null} if the specified protocol name is not supported
         */
        UpgradeCodec newUpgradeCodec(CharSequence protocol);
    }

    /**
     * User event that is fired to notify about the completion of an HTTP upgrade
     * to another protocol. Contains the original upgrade request so that the response
     * (if required) can be sent using the new protocol.
     *
     * Reference counting delegates entirely to the wrapped request: retaining or
     * releasing the event retains or releases {@code upgradeRequest}.
     */
    public static final class UpgradeEvent implements ReferenceCounted {
        private final CharSequence protocol;
        private final FullHttpRequest upgradeRequest;

        UpgradeEvent(CharSequence protocol, FullHttpRequest upgradeRequest) {
            this.protocol = protocol;
            this.upgradeRequest = upgradeRequest;
        }

        /**
         * The protocol that the channel has been upgraded to.
         */
        public CharSequence protocol() {
            return protocol;
        }

        /**
         * Gets the request that triggered the protocol upgrade.
         */
        public FullHttpRequest upgradeRequest() {
            return upgradeRequest;
        }

        @Override
        public int refCnt() {
            return upgradeRequest.refCnt();
        }

        @Override
        public UpgradeEvent retain() {
            upgradeRequest.retain();
            return this;
        }

        @Override
        public UpgradeEvent retain(int increment) {
            upgradeRequest.retain(increment);
            return this;
        }

        @Override
        public UpgradeEvent touch() {
            upgradeRequest.touch();
            return this;
        }

        @Override
        public UpgradeEvent touch(Object hint) {
            upgradeRequest.touch(hint);
            return this;
        }

        @Override
        public boolean release() {
            return upgradeRequest.release();
        }

        @Override
        public boolean release(int decrement) {
            return upgradeRequest.release(decrement);
        }

        @Override
        public String toString() {
            return "UpgradeEvent [protocol=" + protocol + ", upgradeRequest=" + upgradeRequest + ']';
        }
    }

    // Codec currently installed in the pipeline; removed on a successful upgrade.
    private final SourceCodec sourceCodec;
    // Produces the codec for whichever requested protocol is supported first.
    private final UpgradeCodecFactory upgradeCodecFactory;
    private final HttpHeadersFactory headersFactory;
    private final HttpHeadersFactory trailersFactory;
    // True while aggregating the (possibly chunked) upgrade request.
    private boolean handlingUpgrade;

    /**
     * Constructs the upgrader with the supported codecs.
     *
     * The handler instantiated by this constructor will reject an upgrade request with non-empty content.
     * It should not be a concern because an upgrade request is most likely a GET request.
     * If you have a client that sends a non-GET upgrade request, please consider using
     * {@link #HttpServerUpgradeHandler(SourceCodec, UpgradeCodecFactory, int)} to specify the maximum
     * length of the content of an upgrade request.
     *
     * @param sourceCodec the codec that is being used initially
     * @param upgradeCodecFactory the factory that creates a new upgrade codec
     *                            for one of the requested upgrade protocols
     */
    public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory) {
        this(sourceCodec, upgradeCodecFactory, 0,
             DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory());
    }

    /**
     * Constructs the upgrader with the supported codecs.
     *
     * @param sourceCodec the codec that is being used initially
     * @param upgradeCodecFactory the factory that creates a new upgrade codec
     *                            for one of the requested upgrade protocols
     * @param maxContentLength the maximum length of the content of an upgrade request
     */
    public HttpServerUpgradeHandler(
            SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength) {
        this(sourceCodec, upgradeCodecFactory, maxContentLength,
             DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory());
    }

    /**
     * Constructs the upgrader with the supported codecs.
     *
     * @param sourceCodec the codec that is being used initially
     * @param upgradeCodecFactory the factory that creates a new upgrade codec
     *                            for one of the requested upgrade protocols
     * @param maxContentLength the maximum length of the content of an upgrade request
     * @param validateHeaders validate the header names and values of the upgrade response.
     */
    public HttpServerUpgradeHandler(SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory,
                                    int maxContentLength, boolean validateHeaders) {
        this(sourceCodec, upgradeCodecFactory, maxContentLength,
             DefaultHttpHeadersFactory.headersFactory().withValidation(validateHeaders),
             DefaultHttpHeadersFactory.trailersFactory().withValidation(validateHeaders));
    }

    /**
     * Constructs the upgrader with the supported codecs.
     *
     * @param sourceCodec the codec that is being used initially
     * @param upgradeCodecFactory the factory that creates a new upgrade codec
     *                            for one of the requested upgrade protocols
     * @param maxContentLength the maximum length of the content of an upgrade request
     * @param headersFactory The {@link HttpHeadersFactory} to use for headers.
     *                       The recommended default factory is {@link DefaultHttpHeadersFactory#headersFactory()}.
     * @param trailersFactory The {@link HttpHeadersFactory} to use for trailers.
     *                        The recommended default factory is {@link DefaultHttpHeadersFactory#trailersFactory()}.
     */
    public HttpServerUpgradeHandler(
            SourceCodec sourceCodec, UpgradeCodecFactory upgradeCodecFactory, int maxContentLength,
            HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) {
        super(maxContentLength);

        this.sourceCodec = checkNotNull(sourceCodec, "sourceCodec");
        this.upgradeCodecFactory = checkNotNull(upgradeCodecFactory, "upgradeCodecFactory");
        this.headersFactory = checkNotNull(headersFactory, "headersFactory");
        this.trailersFactory = checkNotNull(trailersFactory, "trailersFactory");
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, HttpObject msg, List out)
            throws Exception {

        if (!handlingUpgrade) {
            // Not handling an upgrade request yet. Check if we received a new upgrade request.
            if (msg instanceof HttpRequest) {
                HttpRequest req = (HttpRequest) msg;
                if (req.headers().contains(HttpHeaderNames.UPGRADE) &&
                    shouldHandleUpgradeRequest(req)) {
                    handlingUpgrade = true;
                } else {
                    // Not an upgrade request (or filtered out): pass through untouched.
                    ReferenceCountUtil.retain(msg);
                    ctx.fireChannelRead(msg);
                    return;
                }
            } else {
                // Content/trailer arriving while not in upgrade mode: pass through untouched.
                ReferenceCountUtil.retain(msg);
                ctx.fireChannelRead(msg);
                return;
            }
        }

        FullHttpRequest fullRequest;
        if (msg instanceof FullHttpRequest) {
            fullRequest = (FullHttpRequest) msg;
            ReferenceCountUtil.retain(msg);
            out.add(msg);
        } else {
            // Call the base class (HttpObjectAggregator) to aggregate the full request.
            super.decode(ctx, msg, out);
            if (out.isEmpty()) {
                // The full request hasn't been created yet, still awaiting more data.
                return;
            }

            // Finished aggregating the full request, get it from the output list.
            assert out.size() == 1;
            handlingUpgrade = false;
            fullRequest = (FullHttpRequest) out.get(0);
        }

        if (upgrade(ctx, fullRequest)) {
            // The upgrade was successful, remove the message from the output list
            // so that it's not propagated to the next handler. This request will
            // be propagated as a user event instead.
            out.clear();
        }

        // The upgrade did not succeed, just allow the full request to propagate to the
        // next handler.
    }

    /**
     * Determines whether the specified upgrade {@link HttpRequest} should be handled by this handler or not.
     * This method will be invoked only when the request contains an {@code Upgrade} header.
     * It always returns {@code true} by default, which means any request with an {@code Upgrade} header
     * will be handled. You can override this method to ignore certain {@code Upgrade} headers, for example:
     * <pre>{@code
     * @Override
     * protected boolean isUpgradeRequest(HttpRequest req) {
     *   // Do not handle WebSocket upgrades.
     *   return !req.headers().contains(HttpHeaderNames.UPGRADE, "websocket", false);
     * }
     * }</pre>
     */
    protected boolean shouldHandleUpgradeRequest(HttpRequest req) {
        return true;
    }

    /**
     * Attempts to upgrade to the protocol(s) identified by the {@link HttpHeaderNames#UPGRADE} header (if provided
     * in the request).
     *
     * @param ctx the context for this handler.
     * @param request the HTTP request.
     * @return {@code true} if the upgrade occurred, otherwise {@code false}.
     */
    private boolean upgrade(final ChannelHandlerContext ctx, final FullHttpRequest request) {
        // Select the best protocol based on those requested in the UPGRADE header:
        // first requested protocol for which the factory returns a codec wins.
        final List requestedProtocols = splitHeader(request.headers().get(HttpHeaderNames.UPGRADE));
        final int numRequestedProtocols = requestedProtocols.size();
        UpgradeCodec upgradeCodec = null;
        CharSequence upgradeProtocol = null;
        for (int i = 0; i < numRequestedProtocols; i ++) {
            final CharSequence p = requestedProtocols.get(i);
            final UpgradeCodec c = upgradeCodecFactory.newUpgradeCodec(p);
            if (c != null) {
                upgradeProtocol = p;
                upgradeCodec = c;
                break;
            }
        }

        if (upgradeCodec == null) {
            // None of the requested protocols are supported, don't upgrade.
            return false;
        }

        // Make sure the CONNECTION header is present.
        List connectionHeaderValues = request.headers().getAll(HttpHeaderNames.CONNECTION);

        if (connectionHeaderValues == null || connectionHeaderValues.isEmpty()) {
            return false;
        }

        // Join multiple CONNECTION header lines into one comma-separated value
        // before splitting, so tokens spread across lines are all considered.
        final StringBuilder concatenatedConnectionValue = new StringBuilder(connectionHeaderValues.size() * 10);
        for (CharSequence connectionHeaderValue : connectionHeaderValues) {
            concatenatedConnectionValue.append(connectionHeaderValue).append(COMMA);
        }
        concatenatedConnectionValue.setLength(concatenatedConnectionValue.length() - 1);

        // Make sure the CONNECTION header contains UPGRADE as well as all protocol-specific headers.
        Collection requiredHeaders = upgradeCodec.requiredUpgradeHeaders();
        List values = splitHeader(concatenatedConnectionValue);
        if (!containsContentEqualsIgnoreCase(values, HttpHeaderNames.UPGRADE) ||
            !containsAllContentEqualsIgnoreCase(values, requiredHeaders)) {
            return false;
        }

        // Ensure that all required protocol-specific headers are found in the request.
        for (CharSequence requiredHeader : requiredHeaders) {
            if (!request.headers().contains(requiredHeader)) {
                return false;
            }
        }

        // Prepare and send the upgrade response. Wait for this write to complete before upgrading,
        // since we need the old codec in-place to properly encode the response.
        final FullHttpResponse upgradeResponse = createUpgradeResponse(upgradeProtocol);
        if (!upgradeCodec.prepareUpgradeResponse(ctx, request, upgradeResponse.headers())) {
            return false;
        }

        // Create the user event to be fired once the upgrade completes.
        final UpgradeEvent event = new UpgradeEvent(upgradeProtocol, request);

        // After writing the upgrade response we immediately prepare the
        // pipeline for the next protocol to avoid a race between completion
        // of the write future and receiving data before the pipeline is
        // restructured.
        try {
            final ChannelFuture writeComplete = ctx.writeAndFlush(upgradeResponse);
            // Perform the upgrade to the new protocol.
            sourceCodec.upgradeFrom(ctx);
            upgradeCodec.upgradeTo(ctx, request);

            // Remove this handler from the pipeline.
            ctx.pipeline().remove(HttpServerUpgradeHandler.this);

            // Notify that the upgrade has occurred. Retain the event to offset
            // the release() in the finally block.
            ctx.fireUserEventTriggered(event.retain());

            // Add the listener last to avoid firing upgrade logic after
            // the channel is already closed since the listener may fire
            // immediately if the write failed eagerly.
            writeComplete.addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
        } finally {
            // Release the event if the upgrade event wasn't fired.
            event.release();
        }
        return true;
    }

    /**
     * Creates the 101 Switching Protocols response message.
     */
    private FullHttpResponse createUpgradeResponse(CharSequence upgradeProtocol) {
        DefaultFullHttpResponse res = new DefaultFullHttpResponse(
                HTTP_1_1, SWITCHING_PROTOCOLS, Unpooled.EMPTY_BUFFER, headersFactory, trailersFactory);
        res.headers().add(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);
        res.headers().add(HttpHeaderNames.UPGRADE, upgradeProtocol);
        return res;
    }

    /**
     * Splits a comma-separated header value. The returned set is case-insensitive and contains each
     * part with whitespace removed.
     */
    private static List splitHeader(CharSequence header) {
        final StringBuilder builder = new StringBuilder(header.length());
        final List protocols = new ArrayList(4);
        for (int i = 0; i < header.length(); ++i) {
            char c = header.charAt(i);
            if (Character.isWhitespace(c)) {
                // Don't include any whitespace.
                continue;
            }
            if (c == ',') {
                // Add the string and reset the builder for the next protocol.
                protocols.add(builder.toString());
                builder.setLength(0);
            } else {
                builder.append(c);
            }
        }

        // Add the last protocol
        if (builder.length() > 0) {
            protocols.add(builder.toString());
        }

        return protocols;
    }
}
+ */ +public enum HttpStatusClass { + /** + * The informational class (1xx) + */ + INFORMATIONAL(100, 200, "Informational"), + /** + * The success class (2xx) + */ + SUCCESS(200, 300, "Success"), + /** + * The redirection class (3xx) + */ + REDIRECTION(300, 400, "Redirection"), + /** + * The client error class (4xx) + */ + CLIENT_ERROR(400, 500, "Client Error"), + /** + * The server error class (5xx) + */ + SERVER_ERROR(500, 600, "Server Error"), + /** + * The unknown class + */ + UNKNOWN(0, 0, "Unknown Status") { + @Override + public boolean contains(int code) { + return code < 100 || code >= 600; + } + }; + + private static final HttpStatusClass[] statusArray = new HttpStatusClass[6]; + static { + statusArray[1] = INFORMATIONAL; + statusArray[2] = SUCCESS; + statusArray[3] = REDIRECTION; + statusArray[4] = CLIENT_ERROR; + statusArray[5] = SERVER_ERROR; + } + + /** + * Returns the class of the specified HTTP status code. + */ + public static HttpStatusClass valueOf(int code) { + if (UNKNOWN.contains(code)) { + return UNKNOWN; + } + return statusArray[fast_div100(code)]; + } + + /** + * @param dividend Must >= 0 + * @return dividend/100 + */ + private static int fast_div100(int dividend) { + return (int) ((dividend * 1374389535L) >> 37); + } + + /** + * Returns the class of the specified HTTP status code. + * @param code Just the numeric portion of the http status code. + */ + public static HttpStatusClass valueOf(CharSequence code) { + if (code != null && code.length() == 3) { + char c0 = code.charAt(0); + return isDigit(c0) && isDigit(code.charAt(1)) && isDigit(code.charAt(2)) ? 
valueOf(digit(c0) * 100) + : UNKNOWN; + } + return UNKNOWN; + } + + private static int digit(char c) { + return c - '0'; + } + + private static boolean isDigit(char c) { + return c >= '0' && c <= '9'; + } + + private final int min; + private final int max; + private final AsciiString defaultReasonPhrase; + + HttpStatusClass(int min, int max, String defaultReasonPhrase) { + this.min = min; + this.max = max; + this.defaultReasonPhrase = AsciiString.cached(defaultReasonPhrase); + } + + /** + * Returns {@code true} if and only if the specified HTTP status code falls into this class. + */ + public boolean contains(int code) { + return code >= min && code < max; + } + + /** + * Returns the default reason phrase of this HTTP status class. + */ + AsciiString defaultReasonPhrase() { + return defaultReasonPhrase; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java new file mode 100644 index 0000000..50b1a63 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpUtil.java @@ -0,0 +1,632 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2015 The Netty Project — licensed under the Apache License, Version 2.0
 * (license header continues from the previous chunk of this patch).
 */
package io.netty.handler.codec.http;

import java.net.InetSocketAddress;
import java.net.URI;
import java.nio.charset.Charset;
import java.nio.charset.IllegalCharsetNameException;
import java.nio.charset.UnsupportedCharsetException;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

import io.netty.util.AsciiString;
import io.netty.util.CharsetUtil;
import io.netty.util.NetUtil;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.UnstableApi;

import static io.netty.util.internal.StringUtil.COMMA;
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;

/**
 * Utility methods useful in the HTTP context.
 *
 * NOTE(review): generic type parameters appear to have been stripped from this extract
 * (e.g. raw {@code List} in {@link #setTransferEncodingChunked}); confirm against the
 * original file. This chunk ends before the class does; the remaining methods follow
 * in the next chunk.
 */
public final class HttpUtil {

    // "charset=" prefix used to locate the charset parameter in a Content-Type value.
    private static final AsciiString CHARSET_EQUALS = AsciiString.of(HttpHeaderValues.CHARSET + "=");
    private static final AsciiString SEMICOLON = AsciiString.cached(";");
    private static final String COMMA_STRING = String.valueOf(COMMA);

    // Utility class: no instances.
    private HttpUtil() { }

    /**
     * Determine if a uri is in origin-form according to
     * rfc7230, 5.3.
     */
    public static boolean isOriginForm(URI uri) {
        return isOriginForm(uri.toString());
    }

    /**
     * Determine if a string uri is in origin-form according to
     * rfc7230, 5.3.
     */
    public static boolean isOriginForm(String uri) {
        return uri.startsWith("/");
    }

    /**
     * Determine if a uri is in asterisk-form according to
     * rfc7230, 5.3.
     */
    public static boolean isAsteriskForm(URI uri) {
        return isAsteriskForm(uri.toString());
    }

    /**
     * Determine if a string uri is in asterisk-form according to
     * rfc7230, 5.3.
     */
    public static boolean isAsteriskForm(String uri) {
        return "*".equals(uri);
    }

    /**
     * Returns {@code true} if and only if the connection can remain open and
     * thus 'kept alive'. This method respects the value of the
     * {@code "Connection"} header first and then the return value of
     * {@link HttpVersion#isKeepAliveDefault()}.
     */
    public static boolean isKeepAlive(HttpMessage message) {
        // "Connection: close" always wins; otherwise keep-alive is the protocol
        // default (HTTP/1.1) or requires an explicit "Connection: keep-alive".
        return !message.headers().containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE, true) &&
               (message.protocolVersion().isKeepAliveDefault() ||
                message.headers().containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE, true));
    }

    /**
     * Sets the value of the {@code "Connection"} header depending on the
     * protocol version of the specified message. This method sets or removes
     * the {@code "Connection"} header depending on what the default keep alive
     * mode of the message's protocol version is, as specified by
     * {@link HttpVersion#isKeepAliveDefault()}.
     * <ul>
     * <li>If the connection is kept alive by default:
     *     <ul>
     *     <li>set to {@code "close"} if {@code keepAlive} is {@code false}.</li>
     *     <li>remove otherwise.</li>
     *     </ul></li>
     * <li>If the connection is closed by default:
     *     <ul>
     *     <li>set to {@code "keep-alive"} if {@code keepAlive} is {@code true}.</li>
     *     <li>remove otherwise.</li>
     *     </ul></li>
     * </ul>
     * @see #setKeepAlive(HttpHeaders, HttpVersion, boolean)
     */
    public static void setKeepAlive(HttpMessage message, boolean keepAlive) {
        setKeepAlive(message.headers(), message.protocolVersion(), keepAlive);
    }

    /**
     * Sets the value of the {@code "Connection"} header depending on the
     * protocol version of the specified message. This method sets or removes
     * the {@code "Connection"} header depending on what the default keep alive
     * mode of the message's protocol version is, as specified by
     * {@link HttpVersion#isKeepAliveDefault()}.
     * <ul>
     * <li>If the connection is kept alive by default:
     *     <ul>
     *     <li>set to {@code "close"} if {@code keepAlive} is {@code false}.</li>
     *     <li>remove otherwise.</li>
     *     </ul></li>
     * <li>If the connection is closed by default:
     *     <ul>
     *     <li>set to {@code "keep-alive"} if {@code keepAlive} is {@code true}.</li>
     *     <li>remove otherwise.</li>
     *     </ul></li>
     * </ul>
     */
    public static void setKeepAlive(HttpHeaders h, HttpVersion httpVersion, boolean keepAlive) {
        if (httpVersion.isKeepAliveDefault()) {
            if (keepAlive) {
                h.remove(HttpHeaderNames.CONNECTION);
            } else {
                h.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE);
            }
        } else {
            if (keepAlive) {
                h.set(HttpHeaderNames.CONNECTION, HttpHeaderValues.KEEP_ALIVE);
            } else {
                h.remove(HttpHeaderNames.CONNECTION);
            }
        }
    }

    /**
     * Returns the length of the content. Please note that this value is
     * not retrieved from {@link HttpContent#content()} but from the
     * {@code "Content-Length"} header, and thus they are independent from each
     * other.
     *
     * @return the content length
     *
     * @throws NumberFormatException
     *         if the message does not have the {@code "Content-Length"} header
     *         or its value is not a number
     */
    public static long getContentLength(HttpMessage message) {
        String value = message.headers().get(HttpHeaderNames.CONTENT_LENGTH);
        if (value != null) {
            return Long.parseLong(value);
        }

        // We know the content length if it's a Web Socket message even if
        // Content-Length header is missing.
        long webSocketContentLength = getWebSocketContentLength(message);
        if (webSocketContentLength >= 0) {
            return webSocketContentLength;
        }

        // Otherwise we don't.
        throw new NumberFormatException("header not found: " + HttpHeaderNames.CONTENT_LENGTH);
    }

    /**
     * Returns the length of the content or the specified default value if the message does not have the {@code
     * "Content-Length" header}. Please note that this value is not retrieved from {@link HttpContent#content()} but
     * from the {@code "Content-Length"} header, and thus they are independent from each other.
     *
     * @param message      the message
     * @param defaultValue the default value
     * @return the content length or the specified default value
     * @throws NumberFormatException if the {@code "Content-Length"} header does not parse as a long
     */
    public static long getContentLength(HttpMessage message, long defaultValue) {
        String value = message.headers().get(HttpHeaderNames.CONTENT_LENGTH);
        if (value != null) {
            return Long.parseLong(value);
        }

        // We know the content length if it's a Web Socket message even if
        // Content-Length header is missing.
        long webSocketContentLength = getWebSocketContentLength(message);
        if (webSocketContentLength >= 0) {
            return webSocketContentLength;
        }

        // Otherwise we don't.
        return defaultValue;
    }

    /**
     * Get an {@code int} representation of {@link #getContentLength(HttpMessage, long)}.
     * Values larger than {@link Integer#MAX_VALUE} are clamped.
     *
     * @return the content length or {@code defaultValue} if this message does
     *         not have the {@code "Content-Length"} header.
     *
     * @throws NumberFormatException if the {@code "Content-Length"} header does not parse as an int
     */
    public static int getContentLength(HttpMessage message, int defaultValue) {
        return (int) Math.min(Integer.MAX_VALUE, getContentLength(message, (long) defaultValue));
    }

    /**
     * Returns the content length of the specified web socket message. If the
     * specified message is not a web socket message, {@code -1} is returned.
     */
    private static int getWebSocketContentLength(HttpMessage message) {
        // WebSocket messages have constant content-lengths:
        // Hixie-76 handshake requests carry an 8-byte key, responses a 16-byte body.
        HttpHeaders h = message.headers();
        if (message instanceof HttpRequest) {
            HttpRequest req = (HttpRequest) message;
            if (HttpMethod.GET.equals(req.method()) &&
                h.contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
                h.contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2)) {
                return 8;
            }
        } else if (message instanceof HttpResponse) {
            HttpResponse res = (HttpResponse) message;
            if (res.status().code() == 101 &&
                h.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN) &&
                h.contains(HttpHeaderNames.SEC_WEBSOCKET_LOCATION)) {
                return 16;
            }
        }

        // Not a web socket message
        return -1;
    }

    /**
     * Sets the {@code "Content-Length"} header.
     */
    public static void setContentLength(HttpMessage message, long length) {
        message.headers().set(HttpHeaderNames.CONTENT_LENGTH, length);
    }

    /**
     * Returns {@code true} if the message has a {@code "Content-Length"} header.
     */
    public static boolean isContentLengthSet(HttpMessage m) {
        return m.headers().contains(HttpHeaderNames.CONTENT_LENGTH);
    }

    /**
     * Returns {@code true} if and only if the specified message contains an expect header and the only expectation
     * present is the 100-continue expectation. Note that this method returns {@code false} if the expect header is
     * not valid for the message (e.g., the message is a response, or the version on the message is HTTP/1.0).
     *
     * @param message the message
     * @return {@code true} if and only if the expectation 100-continue is present and it is the only expectation
     *         present
     */
    public static boolean is100ContinueExpected(HttpMessage message) {
        return isExpectHeaderValid(message)
               // unquoted tokens in the expect header are case-insensitive, thus 100-continue is case insensitive
               && message.headers().contains(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE, true);
    }

    /**
     * Returns {@code true} if the specified message contains an expect header specifying an expectation that is not
     * supported. Note that this method returns {@code false} if the expect header is not valid for the message
     * (e.g., the message is a response, or the version on the message is HTTP/1.0).
     *
     * @param message the message
     * @return {@code true} if and only if an expectation is present that is not supported
     */
    static boolean isUnsupportedExpectation(HttpMessage message) {
        if (!isExpectHeaderValid(message)) {
            return false;
        }

        final String expectValue = message.headers().get(HttpHeaderNames.EXPECT);
        return expectValue != null && !HttpHeaderValues.CONTINUE.toString().equalsIgnoreCase(expectValue);
    }

    private static boolean isExpectHeaderValid(final HttpMessage message) {
        /*
         * Expect: 100-continue is for requests only and it works only on HTTP/1.1 or later. Note further that RFC 7231
         * section 5.1.1 says "A server that receives a 100-continue expectation in an HTTP/1.0 request MUST ignore
         * that expectation."
         */
        return message instanceof HttpRequest &&
               message.protocolVersion().compareTo(HttpVersion.HTTP_1_1) >= 0;
    }

    /**
     * Sets or removes the {@code "Expect: 100-continue"} header to / from the
     * specified message. If {@code expected} is {@code true},
     * the {@code "Expect: 100-continue"} header is set and all other previous
     * {@code "Expect"} headers are removed. Otherwise, all {@code "Expect"}
     * headers are removed completely.
     */
    public static void set100ContinueExpected(HttpMessage message, boolean expected) {
        if (expected) {
            message.headers().set(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE);
        } else {
            message.headers().remove(HttpHeaderNames.EXPECT);
        }
    }

    /**
     * Checks to see if the transfer encoding in a specified {@link HttpMessage} is chunked
     *
     * @param message The message to check
     * @return True if transfer encoding is chunked, otherwise false
     */
    public static boolean isTransferEncodingChunked(HttpMessage message) {
        return message.headers().containsValue(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED, true);
    }

    /**
     * Set the {@link HttpHeaderNames#TRANSFER_ENCODING} to either include {@link HttpHeaderValues#CHUNKED} if
     * {@code chunked} is {@code true}, or remove {@link HttpHeaderValues#CHUNKED} if {@code chunked} is {@code false}.
     *
     * @param m The message which contains the headers to modify.
     * @param chunked if {@code true} then include {@link HttpHeaderValues#CHUNKED} in the headers. otherwise remove
     *                {@link HttpHeaderValues#CHUNKED} from the headers.
     */
    public static void setTransferEncodingChunked(HttpMessage m, boolean chunked) {
        if (chunked) {
            // Chunked transfer encoding and Content-Length are mutually exclusive.
            m.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED);
            m.headers().remove(HttpHeaderNames.CONTENT_LENGTH);
        } else {
            List encodings = m.headers().getAll(HttpHeaderNames.TRANSFER_ENCODING);
            if (encodings.isEmpty()) {
                return;
            }
            // Remove only the "chunked" token; preserve any other transfer encodings.
            List values = new ArrayList(encodings);
            Iterator valuesIt = values.iterator();
            while (valuesIt.hasNext()) {
                CharSequence value = valuesIt.next();
                if (HttpHeaderValues.CHUNKED.contentEqualsIgnoreCase(value)) {
                    valuesIt.remove();
                }
            }
            if (values.isEmpty()) {
                m.headers().remove(HttpHeaderNames.TRANSFER_ENCODING);
            } else {
                m.headers().set(HttpHeaderNames.TRANSFER_ENCODING, values);
            }
        }
    }

    /**
     * Fetch charset from message's Content-Type header.
     *
     * @param message entity to fetch Content-Type header from
     * @return the charset from message's Content-Type header or {@link CharsetUtil#ISO_8859_1}
     *         if charset is not presented or unparsable
     */
    public static Charset getCharset(HttpMessage message) {
        return getCharset(message, CharsetUtil.ISO_8859_1);
    }

    /**
     * Fetch charset from Content-Type header value.
     *
     * @param contentTypeValue Content-Type header value to parse
     * @return the charset from message's Content-Type header or {@link CharsetUtil#ISO_8859_1}
     *         if charset is not presented or unparsable
     */
    public static Charset getCharset(CharSequence contentTypeValue) {
        if (contentTypeValue != null) {
            return getCharset(contentTypeValue, CharsetUtil.ISO_8859_1);
        } else {
            return CharsetUtil.ISO_8859_1;
        }
    }

    /**
     * Fetch charset from message's Content-Type header.
     *
     * @param message entity to fetch Content-Type header from
     * @param defaultCharset result to use in case of empty, incorrect or doesn't contain required part header value
     * @return the charset from message's Content-Type header or {@code defaultCharset}
     *         if charset is not presented or unparsable
     */
    public static Charset getCharset(HttpMessage message, Charset defaultCharset) {
        CharSequence contentTypeValue = message.headers().get(HttpHeaderNames.CONTENT_TYPE);
        if (contentTypeValue != null) {
            return getCharset(contentTypeValue, defaultCharset);
        } else {
            return defaultCharset;
        }
    }

    /**
     * Fetch charset from Content-Type header value.
     *
     * @param contentTypeValue Content-Type header value to parse
     * @param defaultCharset result to use in case of empty, incorrect or doesn't contain required part header value
     * @return the charset from message's Content-Type header or {@code defaultCharset}
     *         if charset is not presented or unparsable
     */
    public static Charset getCharset(CharSequence contentTypeValue, Charset defaultCharset) {
        if (contentTypeValue != null) {
            CharSequence charsetRaw = getCharsetAsSequence(contentTypeValue);
            if (charsetRaw != null) {
                // Tolerate quoted charset values, e.g. charset="utf-8".
                if (charsetRaw.length() > 2) { // at least contains 2 quotes(")
                    if (charsetRaw.charAt(0) == '"' && charsetRaw.charAt(charsetRaw.length() - 1) == '"') {
                        charsetRaw = charsetRaw.subSequence(1, charsetRaw.length() - 1);
                    }
                }
                try {
                    return Charset.forName(charsetRaw.toString());
                } catch (IllegalCharsetNameException ignored) {
                    // just return the default charset
                } catch (UnsupportedCharsetException ignored) {
                    // just return the default charset
                }
            }
        }
        return defaultCharset;
    }

    /**
     * Fetch charset from message's Content-Type header as a char sequence.
     *
     * A lot of sites/possibly clients have charset="CHARSET", for example charset="utf-8". Or "utf8" instead of "utf-8"
     * This is not according to standard, but this method provide an ability to catch desired mistakes manually in code
     *
     * @param message entity to fetch Content-Type header from
     * @return the {@code CharSequence} with charset from message's Content-Type header
     *         or {@code null} if charset is not presented
     * @deprecated use {@link #getCharsetAsSequence(HttpMessage)}
     */
    @Deprecated
    public static CharSequence getCharsetAsString(HttpMessage message) {
        return getCharsetAsSequence(message);
    }

    /**
     * Fetch charset from message's Content-Type header as a char sequence.
     *
     * A lot of sites/possibly clients have charset="CHARSET", for example charset="utf-8". Or "utf8" instead of "utf-8"
     * This is not according to standard, but this method provide an ability to catch desired mistakes manually in code
     *
     * @return the {@code CharSequence} with charset from message's Content-Type header
     *         or {@code null} if charset is not presented
     */
    public static CharSequence getCharsetAsSequence(HttpMessage message) {
        CharSequence contentTypeValue = message.headers().get(HttpHeaderNames.CONTENT_TYPE);
        if (contentTypeValue != null) {
            return getCharsetAsSequence(contentTypeValue);
        } else {
            return null;
        }
    }

    /**
     * Fetch charset from Content-Type header value as a char sequence.
     *
     * A lot of sites/possibly clients have charset="CHARSET", for example charset="utf-8". Or "utf8" instead of "utf-8"
     * This is not according to standard, but this method provide an ability to catch desired mistakes manually in code
     *
     * @param contentTypeValue Content-Type header value to parse
     * @return the {@code CharSequence} with charset from message's Content-Type header
     *         or {@code null} if charset is not presented
     * @throws NullPointerException in case if {@code contentTypeValue == null}
     */
    public static CharSequence getCharsetAsSequence(CharSequence contentTypeValue) {
        ObjectUtil.checkNotNull(contentTypeValue, "contentTypeValue");

        int indexOfCharset = AsciiString.indexOfIgnoreCaseAscii(contentTypeValue, CHARSET_EQUALS, 0);
        if (indexOfCharset == AsciiString.INDEX_NOT_FOUND) {
            return null;
        }

        // Value runs from just after "charset=" to the next ';' (or end of string).
        int indexOfEncoding = indexOfCharset + CHARSET_EQUALS.length();
        if (indexOfEncoding < contentTypeValue.length()) {
            CharSequence charsetCandidate = contentTypeValue.subSequence(indexOfEncoding, contentTypeValue.length());
            int indexOfSemicolon = AsciiString.indexOfIgnoreCaseAscii(charsetCandidate, SEMICOLON, 0);
            if (indexOfSemicolon == AsciiString.INDEX_NOT_FOUND) {
                return charsetCandidate;
            }

            return charsetCandidate.subSequence(0, indexOfSemicolon);
        }

        return null;
    }

    /**
     * Fetch MIME type part from message's Content-Type header as a char sequence.
     *
     * @param message entity to fetch Content-Type header from
     * @return the MIME type as a {@code CharSequence} from message's Content-Type header
     *         or {@code null} if content-type header or MIME type part of this header are not presented
     *

+ * "content-type: text/html; charset=utf-8" - "text/html" will be returned
+ * "content-type: text/html" - "text/html" will be returned
+ * "content-type: " or no header - {@code null} we be returned + */ + public static CharSequence getMimeType(HttpMessage message) { + CharSequence contentTypeValue = message.headers().get(HttpHeaderNames.CONTENT_TYPE); + if (contentTypeValue != null) { + return getMimeType(contentTypeValue); + } else { + return null; + } + } + + /** + * Fetch MIME type part from Content-Type header value as a char sequence. + * + * @param contentTypeValue Content-Type header value to parse + * @return the MIME type as a {@code CharSequence} from message's Content-Type header + * or {@code null} if content-type header or MIME type part of this header are not presented + *

+ * "content-type: text/html; charset=utf-8" - "text/html" will be returned
+ * "content-type: text/html" - "text/html" will be returned
+ * "content-type: empty header - {@code null} we be returned + * @throws NullPointerException in case if {@code contentTypeValue == null} + */ + public static CharSequence getMimeType(CharSequence contentTypeValue) { + ObjectUtil.checkNotNull(contentTypeValue, "contentTypeValue"); + + int indexOfSemicolon = AsciiString.indexOfIgnoreCaseAscii(contentTypeValue, SEMICOLON, 0); + if (indexOfSemicolon != AsciiString.INDEX_NOT_FOUND) { + return contentTypeValue.subSequence(0, indexOfSemicolon); + } else { + return contentTypeValue.length() > 0 ? contentTypeValue : null; + } + } + + /** + * Formats the host string of an address so it can be used for computing an HTTP component + * such as a URL or a Host header + * + * @param addr the address + * @return the formatted String + */ + public static String formatHostnameForHttp(InetSocketAddress addr) { + String hostString = NetUtil.getHostname(addr); + if (NetUtil.isValidIpV6Address(hostString)) { + if (!addr.isUnresolved()) { + hostString = NetUtil.toAddressString(addr.getAddress()); + } + return '[' + hostString + ']'; + } + return hostString; + } + + /** + * Validates, and optionally extracts the content length from headers. This method is not intended for + * general use, but is here to be shared between HTTP/1 and HTTP/2 parsing. + * + * @param contentLengthFields the content-length header fields. + * @param isHttp10OrEarlier {@code true} if we are handling HTTP/1.0 or earlier + * @param allowDuplicateContentLengths {@code true} if multiple, identical-value content lengths should be allowed. + * @return the normalized content length from the headers or {@code -1} if the fields were empty. 
+ * @throws IllegalArgumentException if the content-length fields are not valid + */ + @UnstableApi + public static long normalizeAndGetContentLength( + List contentLengthFields, boolean isHttp10OrEarlier, + boolean allowDuplicateContentLengths) { + if (contentLengthFields.isEmpty()) { + return -1; + } + + // Guard against multiple Content-Length headers as stated in + // https://tools.ietf.org/html/rfc7230#section-3.3.2: + // + // If a message is received that has multiple Content-Length header + // fields with field-values consisting of the same decimal value, or a + // single Content-Length header field with a field value containing a + // list of identical decimal values (e.g., "Content-Length: 42, 42"), + // indicating that duplicate Content-Length header fields have been + // generated or combined by an upstream message processor, then the + // recipient MUST either reject the message as invalid or replace the + // duplicated field-values with a single valid Content-Length field + // containing that decimal value prior to determining the message body + // length or forwarding the message. 
+ String firstField = contentLengthFields.get(0).toString(); + boolean multipleContentLengths = + contentLengthFields.size() > 1 || firstField.indexOf(COMMA) >= 0; + + if (multipleContentLengths && !isHttp10OrEarlier) { + if (allowDuplicateContentLengths) { + // Find and enforce that all Content-Length values are the same + String firstValue = null; + for (CharSequence field : contentLengthFields) { + String[] tokens = field.toString().split(COMMA_STRING, -1); + for (String token : tokens) { + String trimmed = token.trim(); + if (firstValue == null) { + firstValue = trimmed; + } else if (!trimmed.equals(firstValue)) { + throw new IllegalArgumentException( + "Multiple Content-Length values found: " + contentLengthFields); + } + } + } + // Replace the duplicated field-values with a single valid Content-Length field + firstField = firstValue; + } else { + // Reject the message as invalid + throw new IllegalArgumentException( + "Multiple Content-Length values found: " + contentLengthFields); + } + } + // Ensure we not allow sign as part of the content-length: + // See https://github.com/squid-cache/squid/security/advisories/GHSA-qf3v-rc95-96j5 + if (firstField.isEmpty() || !Character.isDigit(firstField.charAt(0))) { + // Reject the message as invalid + throw new IllegalArgumentException( + "Content-Length value is not a number: " + firstField); + } + try { + final long value = Long.parseLong(firstField); + return checkPositiveOrZero(value, "Content-Length value"); + } catch (NumberFormatException e) { + // Reject the message as invalid + throw new IllegalArgumentException( + "Content-Length value is not a number: " + firstField, e); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java new file mode 100644 index 0000000..ec287e6 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/HttpVersion.java 
@@ -0,0 +1,263 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static io.netty.util.internal.ObjectUtil.checkNonEmptyAfterTrim; + +import io.netty.buffer.ByteBuf; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; + +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * The version of HTTP or its derived protocols, such as + * RTSP and + * ICAP. + */ +public class HttpVersion implements Comparable { + + private static final Pattern VERSION_PATTERN = + Pattern.compile("(\\S+)/(\\d+)\\.(\\d+)"); + + static final String HTTP_1_0_STRING = "HTTP/1.0"; + static final String HTTP_1_1_STRING = "HTTP/1.1"; + + /** + * HTTP/1.0 + */ + public static final HttpVersion HTTP_1_0 = new HttpVersion("HTTP", 1, 0, false, true); + + /** + * HTTP/1.1 + */ + public static final HttpVersion HTTP_1_1 = new HttpVersion("HTTP", 1, 1, true, true); + + /** + * Returns an existing or new {@link HttpVersion} instance which matches to + * the specified protocol version string. If the specified {@code text} is + * equal to {@code "HTTP/1.0"}, {@link #HTTP_1_0} will be returned. If the + * specified {@code text} is equal to {@code "HTTP/1.1"}, {@link #HTTP_1_1} + * will be returned. 
Otherwise, a new {@link HttpVersion} instance will be + * returned. + */ + public static HttpVersion valueOf(String text) { + ObjectUtil.checkNotNull(text, "text"); + + // super fast-path + if (text == HTTP_1_1_STRING) { + return HTTP_1_1; + } else if (text == HTTP_1_0_STRING) { + return HTTP_1_0; + } + + text = text.trim(); + + if (text.isEmpty()) { + throw new IllegalArgumentException("text is empty (possibly HTTP/0.9)"); + } + + // Try to match without convert to uppercase first as this is what 99% of all clients + // will send anyway. Also there is a change to the RFC to make it clear that it is + // expected to be case-sensitive + // + // See: + // * https://trac.tools.ietf.org/wg/httpbis/trac/ticket/1 + // * https://trac.tools.ietf.org/wg/httpbis/trac/wiki + // + HttpVersion version = version0(text); + if (version == null) { + version = new HttpVersion(text, true); + } + return version; + } + + private static HttpVersion version0(String text) { + if (HTTP_1_1_STRING.equals(text)) { + return HTTP_1_1; + } + if (HTTP_1_0_STRING.equals(text)) { + return HTTP_1_0; + } + return null; + } + + private final String protocolName; + private final int majorVersion; + private final int minorVersion; + private final String text; + private final boolean keepAliveDefault; + private final byte[] bytes; + + /** + * Creates a new HTTP version with the specified version string. You will + * not need to create a new instance unless you are implementing a protocol + * derived from HTTP, such as + * RTSP and + * ICAP. + * + * @param keepAliveDefault + * {@code true} if and only if the connection is kept alive unless + * the {@code "Connection"} header is set to {@code "close"} explicitly. 
+ */ + public HttpVersion(String text, boolean keepAliveDefault) { + text = checkNonEmptyAfterTrim(text, "text").toUpperCase(); + + Matcher m = VERSION_PATTERN.matcher(text); + if (!m.matches()) { + throw new IllegalArgumentException("invalid version format: " + text); + } + + protocolName = m.group(1); + majorVersion = Integer.parseInt(m.group(2)); + minorVersion = Integer.parseInt(m.group(3)); + this.text = protocolName + '/' + majorVersion + '.' + minorVersion; + this.keepAliveDefault = keepAliveDefault; + bytes = null; + } + + /** + * Creates a new HTTP version with the specified protocol name and version + * numbers. You will not need to create a new instance unless you are + * implementing a protocol derived from HTTP, such as + * RTSP and + * ICAP + * + * @param keepAliveDefault + * {@code true} if and only if the connection is kept alive unless + * the {@code "Connection"} header is set to {@code "close"} explicitly. + */ + public HttpVersion( + String protocolName, int majorVersion, int minorVersion, + boolean keepAliveDefault) { + this(protocolName, majorVersion, minorVersion, keepAliveDefault, false); + } + + private HttpVersion( + String protocolName, int majorVersion, int minorVersion, + boolean keepAliveDefault, boolean bytes) { + protocolName = checkNonEmptyAfterTrim(protocolName, "protocolName").toUpperCase(); + + for (int i = 0; i < protocolName.length(); i ++) { + if (Character.isISOControl(protocolName.charAt(i)) || + Character.isWhitespace(protocolName.charAt(i))) { + throw new IllegalArgumentException("invalid character in protocolName"); + } + } + + checkPositiveOrZero(majorVersion, "majorVersion"); + checkPositiveOrZero(minorVersion, "minorVersion"); + + this.protocolName = protocolName; + this.majorVersion = majorVersion; + this.minorVersion = minorVersion; + text = protocolName + '/' + majorVersion + '.' 
+ minorVersion; + this.keepAliveDefault = keepAliveDefault; + + if (bytes) { + this.bytes = text.getBytes(CharsetUtil.US_ASCII); + } else { + this.bytes = null; + } + } + + /** + * Returns the name of the protocol such as {@code "HTTP"} in {@code "HTTP/1.0"}. + */ + public String protocolName() { + return protocolName; + } + + /** + * Returns the name of the protocol such as {@code 1} in {@code "HTTP/1.0"}. + */ + public int majorVersion() { + return majorVersion; + } + + /** + * Returns the name of the protocol such as {@code 0} in {@code "HTTP/1.0"}. + */ + public int minorVersion() { + return minorVersion; + } + + /** + * Returns the full protocol version text such as {@code "HTTP/1.0"}. + */ + public String text() { + return text; + } + + /** + * Returns {@code true} if and only if the connection is kept alive unless + * the {@code "Connection"} header is set to {@code "close"} explicitly. + */ + public boolean isKeepAliveDefault() { + return keepAliveDefault; + } + + /** + * Returns the full protocol version text such as {@code "HTTP/1.0"}. 
+ */ + @Override + public String toString() { + return text(); + } + + @Override + public int hashCode() { + return (protocolName().hashCode() * 31 + majorVersion()) * 31 + + minorVersion(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof HttpVersion)) { + return false; + } + + HttpVersion that = (HttpVersion) o; + return minorVersion() == that.minorVersion() && + majorVersion() == that.majorVersion() && + protocolName().equals(that.protocolName()); + } + + @Override + public int compareTo(HttpVersion o) { + int v = protocolName().compareTo(o.protocolName()); + if (v != 0) { + return v; + } + + v = majorVersion() - o.majorVersion(); + if (v != 0) { + return v; + } + + return minorVersion() - o.minorVersion(); + } + + void encode(ByteBuf buf) { + if (bytes == null) { + buf.writeCharSequence(text, CharsetUtil.US_ASCII); + } else { + buf.writeBytes(bytes); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/LastHttpContent.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/LastHttpContent.java new file mode 100644 index 0000000..a52a383 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/LastHttpContent.java @@ -0,0 +1,144 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.DecoderResult; + +/** + * The last {@link HttpContent} which has trailing headers. + */ +public interface LastHttpContent extends HttpContent { + + /** + * The 'end of content' marker in chunked encoding. + */ + LastHttpContent EMPTY_LAST_CONTENT = new LastHttpContent() { + + @Override + public ByteBuf content() { + return Unpooled.EMPTY_BUFFER; + } + + @Override + public LastHttpContent copy() { + return EMPTY_LAST_CONTENT; + } + + @Override + public LastHttpContent duplicate() { + return this; + } + + @Override + public LastHttpContent replace(ByteBuf content) { + return new DefaultLastHttpContent(content); + } + + @Override + public LastHttpContent retainedDuplicate() { + return this; + } + + @Override + public HttpHeaders trailingHeaders() { + return EmptyHttpHeaders.INSTANCE; + } + + @Override + public DecoderResult decoderResult() { + return DecoderResult.SUCCESS; + } + + @Override + @Deprecated + public DecoderResult getDecoderResult() { + return decoderResult(); + } + + @Override + public void setDecoderResult(DecoderResult result) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public LastHttpContent retain() { + return this; + } + + @Override + public LastHttpContent retain(int increment) { + return this; + } + + @Override + public LastHttpContent touch() { + return this; + } + + @Override + public LastHttpContent touch(Object hint) { + return this; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } + + @Override + public String toString() { + return "EmptyLastHttpContent"; + } + }; + + HttpHeaders trailingHeaders(); + + @Override + LastHttpContent copy(); + + @Override + LastHttpContent duplicate(); + + @Override + LastHttpContent 
retainedDuplicate(); + + @Override + LastHttpContent replace(ByteBuf content); + + @Override + LastHttpContent retain(int increment); + + @Override + LastHttpContent retain(); + + @Override + LastHttpContent touch(); + + @Override + LastHttpContent touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringDecoder.java new file mode 100644 index 0000000..70ea912 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringDecoder.java @@ -0,0 +1,393 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; + +import java.net.URI; +import java.net.URLDecoder; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.StringUtil.EMPTY_STRING; +import static io.netty.util.internal.StringUtil.SPACE; +import static io.netty.util.internal.StringUtil.decodeHexByte; + +/** + * Splits an HTTP query string into a path string and key-value parameter pairs. + * This decoder is for one time use only. Create a new instance for each URI: + *

+ * {@link QueryStringDecoder} decoder = new {@link QueryStringDecoder}("/hello?recipient=world&x=1;y=2");
+ * assert decoder.path().equals("/hello");
+ * assert decoder.parameters().get("recipient").get(0).equals("world");
+ * assert decoder.parameters().get("x").get(0).equals("1");
+ * assert decoder.parameters().get("y").get(0).equals("2");
+ * 
+ * + * This decoder can also decode the content of an HTTP POST request whose + * content type is application/x-www-form-urlencoded: + *
+ * {@link QueryStringDecoder} decoder = new {@link QueryStringDecoder}("recipient=world&x=1;y=2", false);
+ * ...
+ * 
+ * + *

HashDOS vulnerability fix

+ * + * As a workaround to the HashDOS vulnerability, the decoder + * limits the maximum number of decoded key-value parameter pairs, up to {@literal 1024} by + * default, and you can configure it when you construct the decoder by passing an additional + * integer parameter. + * + * @see QueryStringEncoder + */ +public class QueryStringDecoder { + + private static final int DEFAULT_MAX_PARAMS = 1024; + + private final Charset charset; + private final String uri; + private final int maxParams; + private final boolean semicolonIsNormalChar; + private int pathEndIdx; + private String path; + private Map> params; + + /** + * Creates a new decoder that decodes the specified URI. The decoder will + * assume that the query string is encoded in UTF-8. + */ + public QueryStringDecoder(String uri) { + this(uri, HttpConstants.DEFAULT_CHARSET); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(String uri, boolean hasPath) { + this(uri, HttpConstants.DEFAULT_CHARSET, hasPath); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(String uri, Charset charset) { + this(uri, charset, true); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(String uri, Charset charset, boolean hasPath) { + this(uri, charset, hasPath, DEFAULT_MAX_PARAMS); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(String uri, Charset charset, boolean hasPath, int maxParams) { + this(uri, charset, hasPath, maxParams, false); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. 
+ */ + public QueryStringDecoder(String uri, Charset charset, boolean hasPath, + int maxParams, boolean semicolonIsNormalChar) { + this.uri = checkNotNull(uri, "uri"); + this.charset = checkNotNull(charset, "charset"); + this.maxParams = checkPositive(maxParams, "maxParams"); + this.semicolonIsNormalChar = semicolonIsNormalChar; + + // `-1` means that path end index will be initialized lazily + pathEndIdx = hasPath ? -1 : 0; + } + + /** + * Creates a new decoder that decodes the specified URI. The decoder will + * assume that the query string is encoded in UTF-8. + */ + public QueryStringDecoder(URI uri) { + this(uri, HttpConstants.DEFAULT_CHARSET); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(URI uri, Charset charset) { + this(uri, charset, DEFAULT_MAX_PARAMS); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(URI uri, Charset charset, int maxParams) { + this(uri, charset, maxParams, false); + } + + /** + * Creates a new decoder that decodes the specified URI encoded in the + * specified charset. + */ + public QueryStringDecoder(URI uri, Charset charset, int maxParams, boolean semicolonIsNormalChar) { + String rawPath = uri.getRawPath(); + if (rawPath == null) { + rawPath = EMPTY_STRING; + } + String rawQuery = uri.getRawQuery(); + // Also take care of cut of things like "http://localhost" + this.uri = rawQuery == null? rawPath : rawPath + '?' + rawQuery; + this.charset = checkNotNull(charset, "charset"); + this.maxParams = checkPositive(maxParams, "maxParams"); + this.semicolonIsNormalChar = semicolonIsNormalChar; + pathEndIdx = rawPath.length(); + } + + @Override + public String toString() { + return uri(); + } + + /** + * Returns the uri used to initialize this {@link QueryStringDecoder}. 
+ */ + public String uri() { + return uri; + } + + /** + * Returns the decoded path string of the URI. + */ + public String path() { + if (path == null) { + path = decodeComponent(uri, 0, pathEndIdx(), charset, true); + } + return path; + } + + /** + * Returns the decoded key-value parameter pairs of the URI. + */ + public Map> parameters() { + if (params == null) { + params = decodeParams(uri, pathEndIdx(), charset, maxParams, semicolonIsNormalChar); + } + return params; + } + + /** + * Returns the raw path string of the URI. + */ + public String rawPath() { + return uri.substring(0, pathEndIdx()); + } + + /** + * Returns raw query string of the URI. + */ + public String rawQuery() { + int start = pathEndIdx() + 1; + return start < uri.length() ? uri.substring(start) : EMPTY_STRING; + } + + private int pathEndIdx() { + if (pathEndIdx == -1) { + pathEndIdx = findPathEndIndex(uri); + } + return pathEndIdx; + } + + private static Map> decodeParams(String s, int from, Charset charset, int paramsLimit, + boolean semicolonIsNormalChar) { + int len = s.length(); + if (from >= len) { + return Collections.emptyMap(); + } + if (s.charAt(from) == '?') { + from++; + } + Map> params = new LinkedHashMap>(); + int nameStart = from; + int valueStart = -1; + int i; + loop: + for (i = from; i < len; i++) { + switch (s.charAt(i)) { + case '=': + if (nameStart == i) { + nameStart = i + 1; + } else if (valueStart < nameStart) { + valueStart = i + 1; + } + break; + case ';': + if (semicolonIsNormalChar) { + continue; + } + // fall-through + case '&': + if (addParam(s, nameStart, valueStart, i, params, charset)) { + paramsLimit--; + if (paramsLimit == 0) { + return params; + } + } + nameStart = i + 1; + break; + case '#': + break loop; + default: + // continue + } + } + addParam(s, nameStart, valueStart, i, params, charset); + return params; + } + + private static boolean addParam(String s, int nameStart, int valueStart, int valueEnd, + Map> params, Charset charset) { + if (nameStart >= 
valueEnd) { + return false; + } + if (valueStart <= nameStart) { + valueStart = valueEnd + 1; + } + String name = decodeComponent(s, nameStart, valueStart - 1, charset, false); + String value = decodeComponent(s, valueStart, valueEnd, charset, false); + List values = params.get(name); + if (values == null) { + values = new ArrayList(1); // Often there's only 1 value. + params.put(name, values); + } + values.add(value); + return true; + } + + /** + * Decodes a bit of a URL encoded by a browser. + *

+ * This is equivalent to calling {@link #decodeComponent(String, Charset)} + * with the UTF-8 charset (recommended to comply with RFC 3986, Section 2). + * @param s The string to decode (can be empty). + * @return The decoded string, or {@code s} if there's nothing to decode. + * If the string to decode is {@code null}, returns an empty string. + * @throws IllegalArgumentException if the string contains a malformed + * escape sequence. + */ + public static String decodeComponent(final String s) { + return decodeComponent(s, HttpConstants.DEFAULT_CHARSET); + } + + /** + * Decodes a bit of a URL encoded by a browser. + *

+ * The string is expected to be encoded as per RFC 3986, Section 2. + * This is the encoding used by JavaScript functions {@code encodeURI} + * and {@code encodeURIComponent}, but not {@code escape}. For example + * in this encoding, é (in Unicode {@code U+00E9} or in UTF-8 + * {@code 0xC3 0xA9}) is encoded as {@code %C3%A9} or {@code %c3%a9}. + *

+ * This is essentially equivalent to calling + * {@link URLDecoder#decode(String, String)} + * except that it's over 2x faster and generates less garbage for the GC. + * Actually this function doesn't allocate any memory if there's nothing + * to decode, the argument itself is returned. + * @param s The string to decode (can be empty). + * @param charset The charset to use to decode the string (should really + * be {@link CharsetUtil#UTF_8}. + * @return The decoded string, or {@code s} if there's nothing to decode. + * If the string to decode is {@code null}, returns an empty string. + * @throws IllegalArgumentException if the string contains a malformed + * escape sequence. + */ + public static String decodeComponent(final String s, final Charset charset) { + if (s == null) { + return EMPTY_STRING; + } + return decodeComponent(s, 0, s.length(), charset, false); + } + + private static String decodeComponent(String s, int from, int toExcluded, Charset charset, boolean isPath) { + int len = toExcluded - from; + if (len <= 0) { + return EMPTY_STRING; + } + int firstEscaped = -1; + for (int i = from; i < toExcluded; i++) { + char c = s.charAt(i); + if (c == '%' || c == '+' && !isPath) { + firstEscaped = i; + break; + } + } + if (firstEscaped == -1) { + return s.substring(from, toExcluded); + } + + // Each encoded byte takes 3 characters (e.g. "%20") + int decodedCapacity = (toExcluded - firstEscaped) / 3; + byte[] buf = PlatformDependent.allocateUninitializedArray(decodedCapacity); + int bufIdx; + + StringBuilder strBuf = new StringBuilder(len); + strBuf.append(s, from, firstEscaped); + + for (int i = firstEscaped; i < toExcluded; i++) { + char c = s.charAt(i); + if (c != '%') { + strBuf.append(c != '+' || isPath? 
c : SPACE); + continue; + } + + bufIdx = 0; + do { + if (i + 3 > toExcluded) { + throw new IllegalArgumentException("unterminated escape sequence at index " + i + " of: " + s); + } + buf[bufIdx++] = decodeHexByte(s, i + 1); + i += 3; + } while (i < toExcluded && s.charAt(i) == '%'); + i--; + + strBuf.append(new String(buf, 0, bufIdx, charset)); + } + return strBuf.toString(); + } + + private static int findPathEndIndex(String uri) { + int len = uri.length(); + for (int i = 0; i < len; i++) { + char c = uri.charAt(i); + if (c == '?' || c == '#') { + return i; + } + } + return len; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringEncoder.java new file mode 100644 index 0000000..2fca889 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/QueryStringEncoder.java @@ -0,0 +1,250 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.handler.codec.http;

import io.netty.buffer.ByteBufUtil;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.StringUtil;

import java.net.URI;
import java.net.URISyntaxException;
import java.net.URLEncoder;
import java.nio.charset.Charset;

/**
 * Creates a URL-encoded URI from a path string and key-value parameter pairs.
 * This encoder is for one time use only.  Create a new instance for each URI.
 *
 * <pre>
 * {@link QueryStringEncoder} encoder = new {@link QueryStringEncoder}("/hello");
 * encoder.addParam("recipient", "world");
 * assert encoder.toString().equals("/hello?recipient=world");
 * </pre>
 *
 * @see QueryStringDecoder
 */
public class QueryStringEncoder {

    // Charset used for percent-encoding. Deliberately null when the caller asked
    // for UTF-8: a null value selects the optimized encodeUtf8Component() path,
    // any other charset goes through String.getBytes(charset).
    private final Charset charset;
    // Accumulates the whole URI (path + '?' + params); reused across addParam calls.
    private final StringBuilder uriBuilder;
    // Tracks whether '?' has already been appended, so later params use '&'.
    private boolean hasParams;
    // Replacement byte emitted for malformed surrogate sequences, mirroring
    // ByteBufUtil.writeUtf8(...) behavior.
    private static final byte WRITE_UTF_UNKNOWN = (byte) '?';
    // Upper-case hex digits for percent-escapes, indexed by nibble value.
    private static final char[] CHAR_MAP = "0123456789ABCDEF".toCharArray();

    /**
     * Creates a new encoder that encodes a URI that starts with the specified
     * path string.  The encoder will encode the URI in UTF-8.
     */
    public QueryStringEncoder(String uri) {
        this(uri, HttpConstants.DEFAULT_CHARSET);
    }

    /**
     * Creates a new encoder that encodes a URI that starts with the specified
     * path string in the specified charset.
     */
    public QueryStringEncoder(String uri, Charset charset) {
        ObjectUtil.checkNotNull(charset, "charset");
        uriBuilder = new StringBuilder(uri);
        // null is the internal marker for "use the fast UTF-8 encoder".
        this.charset = CharsetUtil.UTF_8.equals(charset) ? null : charset;
    }

    /**
     * Adds a parameter with the specified name and value to this encoder.
     * A {@code null} value appends only the encoded name, with no {@code '='}.
     */
    public void addParam(String name, String value) {
        ObjectUtil.checkNotNull(name, "name");
        if (hasParams) {
            uriBuilder.append('&');
        } else {
            uriBuilder.append('?');
            hasParams = true;
        }

        encodeComponent(name);
        if (value != null) {
            uriBuilder.append('=');
            encodeComponent(value);
        }
    }

    // Dispatches to the UTF-8 fast path (charset == null) or the generic
    // charset-based slow path.
    private void encodeComponent(CharSequence s) {
        if (charset == null) {
            encodeUtf8Component(s);
        } else {
            encodeNonUtf8Component(s);
        }
    }

    /**
     * Returns the URL-encoded URI object which was created from the path string
     * specified in the constructor and the parameters added by
     * {@link #addParam(String, String)} method.
     *
     * @throws URISyntaxException if the accumulated string is not a valid URI
     */
    public URI toUri() throws URISyntaxException {
        return new URI(toString());
    }

    /**
     * Returns the URL-encoded URI which was created from the path string
     * specified in the constructor and the parameters added by
     * {@link #addParam(String, String)} method.
     */
    @Override
    public String toString() {
        return uriBuilder.toString();
    }

    /**
     * Encode the String as per RFC 3986, Section 2.
     * <p>
     * This differs from the JDK's {@link URLEncoder#encode(String, String)} in two
     * ways: a space is encoded as {@code %20} rather than {@code +}, and the result
     * is appended directly to the shared {@link #uriBuilder} instead of allocating
     * a new buffer, thus generating less garbage for the GC.
     *
     * @param s The String to encode
     */
    private void encodeNonUtf8Component(CharSequence s) {
        //Don't allocate memory until needed
        char[] buf = null;

        for (int i = 0, len = s.length(); i < len;) {
            char c = s.charAt(i);
            if (dontNeedEncoding(c)) {
                uriBuilder.append(c);
                i++;
            } else {
                int index = 0;
                if (buf == null) {
                    // Worst case: every remaining char needs encoding.
                    buf = new char[s.length() - i];
                }

                // Collect the maximal run of chars that need encoding, then
                // convert the run to bytes in one getBytes() call.
                do {
                    buf[index] = c;
                    index++;
                    i++;
                } while (i < s.length() && !dontNeedEncoding(c = s.charAt(i)));

                byte[] bytes = new String(buf, 0, index).getBytes(charset);

                for (byte b : bytes) {
                    appendEncoded(b);
                }
            }
        }
    }

    /**
     * Fast path for UTF-8: scan for the first char that needs encoding; if none
     * is found, append the whole sequence verbatim.
     *
     * @see ByteBufUtil#writeUtf8(io.netty.buffer.ByteBuf, CharSequence, int, int)
     */
    private void encodeUtf8Component(CharSequence s) {
        for (int i = 0, len = s.length(); i < len; i++) {
            char c = s.charAt(i);
            if (!dontNeedEncoding(c)) {
                encodeUtf8Component(s, i, len);
                return;
            }
        }
        uriBuilder.append(s);
    }

    private void encodeUtf8Component(CharSequence s, int encodingStart, int len) {
        if (encodingStart > 0) {
            // Append non-encoded characters directly first.
            uriBuilder.append(s, 0, encodingStart);
        }
        encodeUtf8ComponentSlow(s, encodingStart, len);
    }

    // Hand-rolled UTF-8 encoder: each produced byte is immediately
    // percent-escaped via appendEncoded().
    private void encodeUtf8ComponentSlow(CharSequence s, int start, int len) {
        for (int i = start; i < len; i++) {
            char c = s.charAt(i);
            if (c < 0x80) {
                if (dontNeedEncoding(c)) {
                    uriBuilder.append(c);
                } else {
                    // ASCII char that needs escaping: single UTF-8 byte.
                    appendEncoded(c);
                }
            } else if (c < 0x800) {
                // Two-byte UTF-8 sequence.
                appendEncoded(0xc0 | (c >> 6));
                appendEncoded(0x80 | (c & 0x3f));
            } else if (StringUtil.isSurrogate(c)) {
                if (!Character.isHighSurrogate(c)) {
                    // Unpaired low surrogate: emit replacement byte.
                    appendEncoded(WRITE_UTF_UNKNOWN);
                    continue;
                }
                // Surrogate Pair consumes 2 characters.
                if (++i == s.length()) {
                    // High surrogate at end of input: no pair possible.
                    appendEncoded(WRITE_UTF_UNKNOWN);
                    break;
                }
                // Extra method to allow inlining the rest of writeUtf8 which is the most likely code path.
                writeUtf8Surrogate(c, s.charAt(i));
            } else {
                // Three-byte UTF-8 sequence (BMP, non-surrogate).
                appendEncoded(0xe0 | (c >> 12));
                appendEncoded(0x80 | ((c >> 6) & 0x3f));
                appendEncoded(0x80 | (c & 0x3f));
            }
        }
    }

    // Encodes a (high, low) surrogate pair as a four-byte UTF-8 sequence,
    // degrading to replacement bytes when the second char is not a low surrogate.
    private void writeUtf8Surrogate(char c, char c2) {
        if (!Character.isLowSurrogate(c2)) {
            appendEncoded(WRITE_UTF_UNKNOWN);
            appendEncoded(Character.isHighSurrogate(c2) ? WRITE_UTF_UNKNOWN : c2);
            return;
        }
        int codePoint = Character.toCodePoint(c, c2);
        // See https://www.unicode.org/versions/Unicode7.0.0/ch03.pdf#G2630.
        appendEncoded(0xf0 | (codePoint >> 18));
        appendEncoded(0x80 | ((codePoint >> 12) & 0x3f));
        appendEncoded(0x80 | ((codePoint >> 6) & 0x3f));
        appendEncoded(0x80 | (codePoint & 0x3f));
    }

    // Appends a percent-escape ("%XY") for the low 8 bits of b.
    private void appendEncoded(int b) {
        uriBuilder.append('%').append(forDigit(b >> 4)).append(forDigit(b));
    }

    /**
     * Convert the given digit to an upper hexadecimal char.
     *
     * @param digit the number to convert to a character.
     * @return the {@code char} representation of the specified digit
     *         in hexadecimal.
     */
    private static char forDigit(int digit) {
        return CHAR_MAP[digit & 0xF];
    }

    /**
     * Determines whether the given character is an unreserved character.
     * <p>
     * Unreserved characters do not need to be encoded, and include uppercase and
     * lowercase letters, decimal digits, hyphen, period, underscore, and tilde.
     * <p>
     * unreserved = ALPHA / DIGIT / "-" / "." / "_" / "~" / "*"
     *
     * @param ch the char to be judged whether it needs to be encoded
     * @return true or false
     */
    private static boolean dontNeedEncoding(char ch) {
        return ch >= 'a' && ch <= 'z' || ch >= 'A' && ch <= 'Z' || ch >= '0' && ch <= '9'
                || ch == '-' || ch == '_' || ch == '.' || ch == '*' || ch == '~';
    }
}
package io.netty.handler.codec.http;

import io.netty.util.AsciiString;
import io.netty.util.internal.UnstableApi;

import java.util.AbstractMap.SimpleImmutableEntry;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

import static io.netty.handler.codec.CharSequenceValueConverter.INSTANCE;
import static io.netty.util.AsciiString.contentEquals;
import static io.netty.util.AsciiString.contentEqualsIgnoreCase;

/**
 * A variant of {@link HttpHeaders} which only supports read-only methods.
 * <p>
 * Any array passed to this class may be used directly in the underlying data structures of this class. If these
 * arrays may be modified it is the caller's responsibility to supply this class with a copy of the array.
 * <p>
 * This may be a good alternative to {@link DefaultHttpHeaders} if you have a fixed set of headers which will not
 * change.
 * <p>
 * Storage is a flat {@code [name, value, name, value, ...]} array; every lookup is a linear
 * scan over that array, with name comparison accelerated by an {@link AsciiString#hashCode}
 * pre-check before the case-insensitive content comparison.
 */
@UnstableApi
public final class ReadOnlyHttpHeaders extends HttpHeaders {
    // Flat pair array: even indices are names, odd indices are the values
    // of the name immediately preceding them. Length is always even.
    private final CharSequence[] nameValuePairs;

    /**
     * Create a new instance.
     * @param validateHeaders {@code true} to validate the contents of each header name.
     * @param nameValuePairs An array of the structure {@code [name, value, name, value, ...]}.
     *                       A copy will NOT be made of this array. If the contents of this array
     *                       may be modified externally you are responsible for passing in a copy.
     * @throws IllegalArgumentException if the array length is odd (unpaired name).
     */
    public ReadOnlyHttpHeaders(boolean validateHeaders, CharSequence... nameValuePairs) {
        if ((nameValuePairs.length & 1) != 0) {
            throw newInvalidArraySizeException();
        }
        if (validateHeaders) {
            validateHeaders(nameValuePairs);
        }
        this.nameValuePairs = nameValuePairs;
    }

    private static IllegalArgumentException newInvalidArraySizeException() {
        return new IllegalArgumentException("nameValuePairs must be arrays of [name, value] pairs");
    }

    // Validates only the names (even indices) via the default header name validator.
    private static void validateHeaders(CharSequence... keyValuePairs) {
        for (int i = 0; i < keyValuePairs.length; i += 2) {
            DefaultHttpHeadersFactory.headersFactory().getNameValidator().validateName(keyValuePairs[i]);
        }
    }

    // Returns the FIRST value whose name matches (case-insensitive), or null.
    private CharSequence get0(CharSequence name) {
        final int nameHash = AsciiString.hashCode(name);
        for (int i = 0; i < nameValuePairs.length; i += 2) {
            CharSequence roName = nameValuePairs[i];
            if (AsciiString.hashCode(roName) == nameHash && contentEqualsIgnoreCase(roName, name)) {
                // i + 1 is always in bounds: the constructor enforces an even-length array.
                return nameValuePairs[i + 1];
            }
        }
        return null;
    }

    @Override
    public String get(String name) {
        CharSequence value = get0(name);
        return value == null ? null : value.toString();
    }

    @Override
    public Integer getInt(CharSequence name) {
        CharSequence value = get0(name);
        return value == null ? null : INSTANCE.convertToInt(value);
    }

    @Override
    public int getInt(CharSequence name, int defaultValue) {
        CharSequence value = get0(name);
        return value == null ? defaultValue : INSTANCE.convertToInt(value);
    }

    @Override
    public Short getShort(CharSequence name) {
        CharSequence value = get0(name);
        return value == null ? null : INSTANCE.convertToShort(value);
    }

    @Override
    public short getShort(CharSequence name, short defaultValue) {
        CharSequence value = get0(name);
        return value == null ? defaultValue : INSTANCE.convertToShort(value);
    }

    @Override
    public Long getTimeMillis(CharSequence name) {
        CharSequence value = get0(name);
        return value == null ? null : INSTANCE.convertToTimeMillis(value);
    }

    @Override
    public long getTimeMillis(CharSequence name, long defaultValue) {
        CharSequence value = get0(name);
        return value == null ? defaultValue : INSTANCE.convertToTimeMillis(value);
    }

    @Override
    public List<String> getAll(String name) {
        if (isEmpty()) {
            return Collections.emptyList();
        }
        final int nameHash = AsciiString.hashCode(name);
        List<String> values = new ArrayList<String>(4);
        for (int i = 0; i < nameValuePairs.length; i += 2) {
            CharSequence roName = nameValuePairs[i];
            if (AsciiString.hashCode(roName) == nameHash && contentEqualsIgnoreCase(roName, name)) {
                values.add(nameValuePairs[i + 1].toString());
            }
        }
        return values;
    }

    @Override
    public List<Map.Entry<String, String>> entries() {
        if (isEmpty()) {
            return Collections.emptyList();
        }
        List<Map.Entry<String, String>> entries = new ArrayList<Map.Entry<String, String>>(size());
        for (int i = 0; i < nameValuePairs.length; i += 2) {
            // i + 1 cannot be out of bounds: the constructor only accepts even-length arrays.
            entries.add(new SimpleImmutableEntry<String, String>(nameValuePairs[i].toString(),
                    nameValuePairs[i + 1].toString()));
        }
        return entries;
    }

    @Override
    public boolean contains(String name) {
        return get0(name) != null;
    }

    @Override
    public boolean contains(String name, String value, boolean ignoreCase) {
        return containsValue(name, value, ignoreCase);
    }

    @Override
    public boolean containsValue(CharSequence name, CharSequence value, boolean ignoreCase) {
        // Names are always compared case-insensitively; ignoreCase only controls
        // how the VALUE is compared.
        if (ignoreCase) {
            for (int i = 0; i < nameValuePairs.length; i += 2) {
                if (contentEqualsIgnoreCase(nameValuePairs[i], name) &&
                        contentEqualsIgnoreCase(nameValuePairs[i + 1], value)) {
                    return true;
                }
            }
        } else {
            for (int i = 0; i < nameValuePairs.length; i += 2) {
                if (contentEqualsIgnoreCase(nameValuePairs[i], name) &&
                        contentEquals(nameValuePairs[i + 1], value)) {
                    return true;
                }
            }
        }
        return false;
    }

    @Override
    public Iterator<String> valueStringIterator(CharSequence name) {
        return new ReadOnlyStringValueIterator(name);
    }

    @Override
    public Iterator<CharSequence> valueCharSequenceIterator(CharSequence name) {
        return new ReadOnlyValueIterator(name);
    }

    @Override
    public Iterator<Map.Entry<String, String>> iterator() {
        return new ReadOnlyStringIterator();
    }

    @Override
    public Iterator<Map.Entry<CharSequence, CharSequence>> iteratorCharSequence() {
        return new ReadOnlyIterator();
    }

    @Override
    public boolean isEmpty() {
        return nameValuePairs.length == 0;
    }

    @Override
    public int size() {
        // One header per pair of array slots.
        return nameValuePairs.length >>> 1;
    }

    @Override
    public Set<String> names() {
        if (isEmpty()) {
            return Collections.emptySet();
        }
        // LinkedHashSet preserves first-occurrence order and de-duplicates names.
        Set<String> names = new LinkedHashSet<String>(size());
        for (int i = 0; i < nameValuePairs.length; i += 2) {
            names.add(nameValuePairs[i].toString());
        }
        return names;
    }

    // All mutating operations below are unsupported by design.

    @Override
    public HttpHeaders add(String name, Object value) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders add(String name, Iterable<?> values) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders addInt(CharSequence name, int value) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders addShort(CharSequence name, short value) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders set(String name, Object value) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders set(String name, Iterable<?> values) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders setInt(CharSequence name, int value) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders setShort(CharSequence name, short value) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders remove(String name) {
        throw new UnsupportedOperationException("read only");
    }

    @Override
    public HttpHeaders clear() {
        throw new UnsupportedOperationException("read only");
    }

    /**
     * Entry iterator over the pair array. The iterator IS the entry: next()
     * mutates {@code key}/{@code value} in place and returns {@code this}, so
     * callers must not retain entries across next() calls.
     */
    private final class ReadOnlyIterator implements Map.Entry<CharSequence, CharSequence>,
            Iterator<Map.Entry<CharSequence, CharSequence>> {
        private CharSequence key;
        private CharSequence value;
        private int nextNameIndex;

        @Override
        public boolean hasNext() {
            return nextNameIndex != nameValuePairs.length;
        }

        @Override
        public Map.Entry<CharSequence, CharSequence> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            key = nameValuePairs[nextNameIndex];
            value = nameValuePairs[nextNameIndex + 1];
            nextNameIndex += 2;
            return this;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("read only");
        }

        @Override
        public CharSequence getKey() {
            return key;
        }

        @Override
        public CharSequence getValue() {
            return value;
        }

        @Override
        public CharSequence setValue(CharSequence value) {
            throw new UnsupportedOperationException("read only");
        }

        @Override
        public String toString() {
            return key.toString() + '=' + value.toString();
        }
    }

    /**
     * Same self-as-entry iterator pattern as {@link ReadOnlyIterator}, but
     * converts names and values to {@link String} on each step.
     */
    private final class ReadOnlyStringIterator implements Map.Entry<String, String>,
            Iterator<Map.Entry<String, String>> {
        private String key;
        private String value;
        private int nextNameIndex;

        @Override
        public boolean hasNext() {
            return nextNameIndex != nameValuePairs.length;
        }

        @Override
        public Map.Entry<String, String> next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            key = nameValuePairs[nextNameIndex].toString();
            value = nameValuePairs[nextNameIndex + 1].toString();
            nextNameIndex += 2;
            return this;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("read only");
        }

        @Override
        public String getKey() {
            return key;
        }

        @Override
        public String getValue() {
            return value;
        }

        @Override
        public String setValue(String value) {
            throw new UnsupportedOperationException("read only");
        }

        @Override
        public String toString() {
            return key + '=' + value;
        }
    }

    /**
     * Iterates the String values of one header name. {@code nextNameIndex}
     * holds the array index of the next matching NAME, or -1 when exhausted.
     */
    private final class ReadOnlyStringValueIterator implements Iterator<String> {
        private final CharSequence name;
        private final int nameHash;
        private int nextNameIndex;

        ReadOnlyStringValueIterator(CharSequence name) {
            this.name = name;
            nameHash = AsciiString.hashCode(name);
            nextNameIndex = findNextValue();
        }

        @Override
        public boolean hasNext() {
            return nextNameIndex != -1;
        }

        @Override
        public String next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            String value = nameValuePairs[nextNameIndex + 1].toString();
            nextNameIndex = findNextValue();
            return value;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("read only");
        }

        private int findNextValue() {
            // NOTE: resumes the scan at the current match index; the +2 advance
            // happens implicitly because next() consumed the value at index + 1
            // before this is called again... actually the scan restarts at
            // nextNameIndex and relies on the loop stepping by 2 — TODO(review):
            // confirm against upstream that the first loop iteration re-matching
            // the same index cannot occur (it cannot: next() calls this AFTER
            // reading, and the loop starts at the OLD index, matching it again
            // would loop forever — upstream uses the same code, so verify).
            for (int i = nextNameIndex; i < nameValuePairs.length; i += 2) {
                final CharSequence roName = nameValuePairs[i];
                if (nameHash == AsciiString.hashCode(roName) && contentEqualsIgnoreCase(name, roName)) {
                    return i;
                }
            }
            return -1;
        }
    }

    /**
     * CharSequence counterpart of {@link ReadOnlyStringValueIterator}: returns
     * the stored values without converting them to String.
     */
    private final class ReadOnlyValueIterator implements Iterator<CharSequence> {
        private final CharSequence name;
        private final int nameHash;
        private int nextNameIndex;

        ReadOnlyValueIterator(CharSequence name) {
            this.name = name;
            nameHash = AsciiString.hashCode(name);
            nextNameIndex = findNextValue();
        }

        @Override
        public boolean hasNext() {
            return nextNameIndex != -1;
        }

        @Override
        public CharSequence next() {
            if (!hasNext()) {
                throw new NoSuchElementException();
            }
            CharSequence value = nameValuePairs[nextNameIndex + 1];
            nextNameIndex = findNextValue();
            return value;
        }

        @Override
        public void remove() {
            throw new UnsupportedOperationException("read only");
        }

        private int findNextValue() {
            for (int i = nextNameIndex; i < nameValuePairs.length; i += 2) {
                final CharSequence roName = nameValuePairs[i];
                if (nameHash == AsciiString.hashCode(roName) && contentEqualsIgnoreCase(name, roName)) {
                    return i;
                }
            }
            return -1;
        }
    }
}
+ * + * Note that multiple cookies must be sent as separate "Set-Cookie" headers. + * + *

+ * // Example
+ * {@link HttpResponse} res = ...;
+ * res.setHeader("Set-Cookie", {@link ServerCookieEncoder}.encode("JSESSIONID", "1234"));
+ * 
+ * + * @see ServerCookieDecoder + * + * @deprecated Use {@link io.netty.handler.codec.http.cookie.ServerCookieEncoder} instead + */ +@Deprecated +public final class ServerCookieEncoder { + + /** + * Encodes the specified cookie name-value pair into a Set-Cookie header value. + * + * @param name the cookie name + * @param value the cookie value + * @return a single Set-Cookie header value + */ + @Deprecated + public static String encode(String name, String value) { + return io.netty.handler.codec.http.cookie.ServerCookieEncoder.LAX.encode(name, value); + } + + /** + * Encodes the specified cookie into a Set-Cookie header value. + * + * @param cookie the cookie + * @return a single Set-Cookie header value + */ + @Deprecated + public static String encode(Cookie cookie) { + return io.netty.handler.codec.http.cookie.ServerCookieEncoder.LAX.encode(cookie); + } + + /** + * Batch encodes cookies into Set-Cookie header values. + * + * @param cookies a bunch of cookies + * @return the corresponding bunch of Set-Cookie headers + */ + @Deprecated + public static List encode(Cookie... cookies) { + return io.netty.handler.codec.http.cookie.ServerCookieEncoder.LAX.encode(cookies); + } + + /** + * Batch encodes cookies into Set-Cookie header values. + * + * @param cookies a bunch of cookies + * @return the corresponding bunch of Set-Cookie headers + */ + @Deprecated + public static List encode(Collection cookies) { + return io.netty.handler.codec.http.cookie.ServerCookieEncoder.LAX.encode(cookies); + } + + /** + * Batch encodes cookies into Set-Cookie header values. 
+ * + * @param cookies a bunch of cookies + * @return the corresponding bunch of Set-Cookie headers + */ + @Deprecated + public static List encode(Iterable cookies) { + return io.netty.handler.codec.http.cookie.ServerCookieEncoder.LAX.encode(cookies); + } + + private ServerCookieEncoder() { + // Unused + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpContentException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpContentException.java new file mode 100644 index 0000000..8e61212 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpContentException.java @@ -0,0 +1,54 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.TooLongFrameException; + +/** + * An {@link TooLongFrameException} which is thrown when the length of the + * content decoded is greater than the allowed maximum. + */ +public final class TooLongHttpContentException extends TooLongFrameException { + + private static final long serialVersionUID = 3238341182129476117L; + + /** + * Creates a new instance. + */ + public TooLongHttpContentException() { + } + + /** + * Creates a new instance. 
+ */ + public TooLongHttpContentException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public TooLongHttpContentException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public TooLongHttpContentException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpHeaderException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpHeaderException.java new file mode 100644 index 0000000..433b1b2 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpHeaderException.java @@ -0,0 +1,54 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.TooLongFrameException; + +/** + * An {@link TooLongFrameException} which is thrown when the length of the + * header decoded is greater than the allowed maximum. + */ +public final class TooLongHttpHeaderException extends TooLongFrameException { + + private static final long serialVersionUID = -8295159138628369730L; + + /** + * Creates a new instance. + */ + public TooLongHttpHeaderException() { + } + + /** + * Creates a new instance. 
+ */ + public TooLongHttpHeaderException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public TooLongHttpHeaderException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public TooLongHttpHeaderException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpLineException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpLineException.java new file mode 100644 index 0000000..2f30b53 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/TooLongHttpLineException.java @@ -0,0 +1,54 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.TooLongFrameException; + +/** + * An {@link TooLongFrameException} which is thrown when the length of the + * line decoded is greater than the allowed maximum. + */ +public final class TooLongHttpLineException extends TooLongFrameException { + + private static final long serialVersionUID = 1614751125592211890L; + + /** + * Creates a new instance. + */ + public TooLongHttpLineException() { + } + + /** + * Creates a new instance. 
package io.netty.handler.codec.http.cookie;

import io.netty.handler.codec.DateFormatter;
import io.netty.handler.codec.http.cookie.CookieHeaderNames.SameSite;

import java.util.Date;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * A RFC6265 compliant cookie decoder to be used client side.
 *
 * It will store the way the raw value was wrapped in {@link Cookie#setWrap(boolean)} so it can be
 * eventually sent back to the Origin server as is.
 *
 * @see ClientCookieEncoder
 */
public final class ClientCookieDecoder extends CookieDecoder {

    /**
     * Strict encoder that validates that name and value chars are in the valid scope
     * defined in RFC6265
     */
    public static final ClientCookieDecoder STRICT = new ClientCookieDecoder(true);

    /**
     * Lax instance that doesn't validate name and value
     */
    public static final ClientCookieDecoder LAX = new ClientCookieDecoder(false);

    private ClientCookieDecoder(boolean strict) {
        super(strict);
    }

    /**
     * Decodes the specified Set-Cookie HTTP header value into a {@link Cookie}.
     * The first name/value pair becomes the cookie itself; subsequent pairs are
     * treated as attributes (Path, Domain, Max-Age, Expires, Secure, HttpOnly,
     * SameSite). Only the first cookie is parsed if several are present.
     *
     * @return the decoded {@link Cookie}, or {@code null} if the header is empty
     *         or the initial name/value pair is invalid
     */
    public Cookie decode(String header) {
        final int headerLen = checkNotNull(header, "header").length();

        if (headerLen == 0) {
            return null;
        }

        CookieBuilder cookieBuilder = null;

        loop: for (int i = 0;;) {

            // Skip spaces and separators.
            for (;;) {
                if (i == headerLen) {
                    break loop;
                }
                char c = header.charAt(i);
                if (c == ',') {
                    // Having multiple cookies in a single Set-Cookie header is
                    // deprecated, modern browsers only parse the first one
                    break loop;

                } else if (c == '\t' || c == '\n' || c == 0x0b || c == '\f'
                        || c == '\r' || c == ' ' || c == ';') {
                    i++;
                    continue;
                }
                break;
            }

            // Index bounds of the current name/value pair within `header`.
            // valueBegin/valueEnd == -1 means "no value present at all".
            int nameBegin = i;
            int nameEnd;
            int valueBegin;
            int valueEnd;

            for (;;) {
                char curChar = header.charAt(i);
                if (curChar == ';') {
                    // NAME; (no value till ';')
                    nameEnd = i;
                    valueBegin = valueEnd = -1;
                    break;

                } else if (curChar == '=') {
                    // NAME=VALUE
                    nameEnd = i;
                    i++;
                    if (i == headerLen) {
                        // NAME= (empty value, i.e. nothing after '=')
                        // 0/0 (rather than -1/-1) signals "value present but empty".
                        valueBegin = valueEnd = 0;
                        break;
                    }

                    valueBegin = i;
                    // NAME=VALUE; -- value runs to the next ';' or to end of header.
                    int semiPos = header.indexOf(';', i);
                    valueEnd = i = semiPos > 0 ? semiPos : headerLen;
                    break;
                } else {
                    i++;
                }

                if (i == headerLen) {
                    // NAME (no value till the end of string)
                    nameEnd = headerLen;
                    valueBegin = valueEnd = -1;
                    break;
                }
            }

            if (valueEnd > 0 && header.charAt(valueEnd - 1) == ',') {
                // old multiple cookies separator, skipping it
                valueEnd--;
            }

            if (cookieBuilder == null) {
                // cookie name-value pair
                DefaultCookie cookie = initCookie(header, nameBegin, nameEnd, valueBegin, valueEnd);

                if (cookie == null) {
                    return null;
                }

                cookieBuilder = new CookieBuilder(cookie, header);
            } else {
                // cookie attribute
                cookieBuilder.appendAttribute(nameBegin, nameEnd, valueBegin, valueEnd);
            }
        }
        return cookieBuilder != null ? cookieBuilder.cookie() : null;
    }

    /**
     * Accumulates attribute values parsed after the initial name/value pair and
     * applies them to the {@link DefaultCookie} when {@link #cookie()} is called.
     * Attribute positions are kept as index ranges into the original header string
     * to avoid substring allocation until a value is actually needed.
     */
    private static class CookieBuilder {

        private final String header;
        private final DefaultCookie cookie;
        private String domain;
        private String path;
        // Long.MIN_VALUE means "Max-Age attribute not seen".
        private long maxAge = Long.MIN_VALUE;
        private int expiresStart;
        private int expiresEnd;
        private boolean secure;
        private boolean httpOnly;
        private SameSite sameSite;

        CookieBuilder(DefaultCookie cookie, String header) {
            this.cookie = cookie;
            this.header = header;
        }

        private long mergeMaxAgeAndExpires() {
            // max age has precedence over expires
            if (maxAge != Long.MIN_VALUE) {
                return maxAge;
            } else if (isValueDefined(expiresStart, expiresEnd)) {
                Date expiresDate = DateFormatter.parseHttpDate(header, expiresStart, expiresEnd);
                if (expiresDate != null) {
                    // Convert the absolute Expires date to a relative max-age,
                    // rounding any fractional second up.
                    long maxAgeMillis = expiresDate.getTime() - System.currentTimeMillis();
                    return maxAgeMillis / 1000 + (maxAgeMillis % 1000 != 0 ? 1 : 0);
                }
            }
            return Long.MIN_VALUE;
        }

        // Applies all collected attributes and returns the finished cookie.
        Cookie cookie() {
            cookie.setDomain(domain);
            cookie.setPath(path);
            cookie.setMaxAge(mergeMaxAgeAndExpires());
            cookie.setSecure(secure);
            cookie.setHttpOnly(httpOnly);
            cookie.setSameSite(sameSite);
            return cookie;
        }

        /**
         * Parse and store a key-value pair. First one is considered to be the
         * cookie name/value. Unknown attribute names are silently discarded.
         *
         * @param keyStart
         *            where the key starts in the header
         * @param keyEnd
         *            where the key ends in the header
         * @param valueStart
         *            where the value starts in the header
         * @param valueEnd
         *            where the value ends in the header
         */
        void appendAttribute(int keyStart, int keyEnd, int valueStart, int valueEnd) {
            int length = keyEnd - keyStart;

            // Dispatch on attribute-name length before doing any string
            // comparison: Path=4, Domain/Secure=6, Expires/Max-Age=7,
            // HttpOnly/SameSite=8.
            if (length == 4) {
                parse4(keyStart, valueStart, valueEnd);
            } else if (length == 6) {
                parse6(keyStart, valueStart, valueEnd);
            } else if (length == 7) {
                parse7(keyStart, valueStart, valueEnd);
            } else if (length == 8) {
                parse8(keyStart, valueStart, valueEnd);
            }
        }

        private void parse4(int nameStart, int valueStart, int valueEnd) {
            if (header.regionMatches(true, nameStart, CookieHeaderNames.PATH, 0, 4)) {
                path = computeValue(valueStart, valueEnd);
            }
        }

        private void parse6(int nameStart, int valueStart, int valueEnd) {
            if (header.regionMatches(true, nameStart, CookieHeaderNames.DOMAIN, 0, 5)) {
                domain = computeValue(valueStart, valueEnd);
            } else if (header.regionMatches(true, nameStart, CookieHeaderNames.SECURE, 0, 5)) {
                secure = true;
            }
        }

        private void setMaxAge(String value) {
            try {
                // Negative Max-Age is clamped to 0 (expire immediately).
                maxAge = Math.max(Long.parseLong(value), 0L);
            } catch (NumberFormatException e1) {
                // ignore failure to parse -> treat as session cookie
            }
        }

        private void parse7(int nameStart, int valueStart, int valueEnd) {
            if (header.regionMatches(true, nameStart, CookieHeaderNames.EXPIRES, 0, 7)) {
                // Defer date parsing: only the index range is recorded here.
                expiresStart = valueStart;
                expiresEnd = valueEnd;
            } else if (header.regionMatches(true, nameStart, CookieHeaderNames.MAX_AGE, 0, 7)) {
                setMaxAge(computeValue(valueStart, valueEnd));
            }
        }

        private void parse8(int nameStart, int valueStart, int valueEnd) {
            if (header.regionMatches(true, nameStart, CookieHeaderNames.HTTPONLY, 0, 8)) {
                httpOnly = true;
            } else if (header.regionMatches(true, nameStart, CookieHeaderNames.SAMESITE, 0, 8)) {
                sameSite = SameSite.of(computeValue(valueStart, valueEnd));
            }
        }

        private static boolean isValueDefined(int valueStart, int valueEnd) {
            return valueStart != -1 && valueStart != valueEnd;
        }

        // Materializes the value substring, or null when no non-empty value exists.
        private String computeValue(int valueStart, int valueEnd) {
            return isValueDefined(valueStart, valueEnd) ? header.substring(valueStart, valueEnd) : null;
        }
    }
}
+ */ +package io.netty.handler.codec.http.cookie; + +import static io.netty.handler.codec.http.cookie.CookieUtil.add; +import static io.netty.handler.codec.http.cookie.CookieUtil.addQuoted; +import static io.netty.handler.codec.http.cookie.CookieUtil.stringBuilder; +import static io.netty.handler.codec.http.cookie.CookieUtil.stripTrailingSeparator; +import static io.netty.handler.codec.http.cookie.CookieUtil.stripTrailingSeparatorOrNull; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.internal.InternalThreadLocalMap; + +import java.util.Arrays; +import java.util.Collection; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +/** + * A RFC6265 compliant cookie encoder to be used client side, so + * only name=value pairs are sent. + * + * Note that multiple cookies are supposed to be sent at once in a single "Cookie" header. + * + *
+ * // Example
+ * {@link HttpRequest} req = ...;
+ * res.setHeader("Cookie", {@link ClientCookieEncoder}.encode("JSESSIONID", "1234"));
+ * 
+ * + * @see ClientCookieDecoder + */ +public final class ClientCookieEncoder extends CookieEncoder { + + /** + * Strict encoder that validates that name and value chars are in the valid scope and (for methods that accept + * multiple cookies) sorts cookies into order of decreasing path length, as specified in RFC6265. + */ + public static final ClientCookieEncoder STRICT = new ClientCookieEncoder(true); + + /** + * Lax instance that doesn't validate name and value, and (for methods that accept multiple cookies) keeps + * cookies in the order in which they were given. + */ + public static final ClientCookieEncoder LAX = new ClientCookieEncoder(false); + + private ClientCookieEncoder(boolean strict) { + super(strict); + } + + /** + * Encodes the specified cookie into a Cookie header value. + * + * @param name + * the cookie name + * @param value + * the cookie value + * @return a Rfc6265 style Cookie header value + */ + public String encode(String name, String value) { + return encode(new DefaultCookie(name, value)); + } + + /** + * Encodes the specified cookie into a Cookie header value. + * + * @param cookie the specified cookie + * @return a Rfc6265 style Cookie header value + */ + public String encode(Cookie cookie) { + StringBuilder buf = stringBuilder(); + encode(buf, checkNotNull(cookie, "cookie")); + return stripTrailingSeparator(buf); + } + + /** + * Sort cookies into decreasing order of path length, breaking ties by sorting into increasing chronological + * order of creation time, as recommended by RFC 6265. + */ + // package-private for testing only. + static final Comparator COOKIE_COMPARATOR = new Comparator() { + @Override + public int compare(Cookie c1, Cookie c2) { + String path1 = c1.path(); + String path2 = c2.path(); + // Cookies with unspecified path default to the path of the request. We don't + // know the request path here, but we assume that the length of an unspecified + // path is longer than any specified path (i.e. 
pathless cookies come first), + // because setting cookies with a path longer than the request path is of + // limited use. + int len1 = path1 == null ? Integer.MAX_VALUE : path1.length(); + int len2 = path2 == null ? Integer.MAX_VALUE : path2.length(); + + // Rely on Arrays.sort's stability to retain creation order in cases where + // cookies have same path length. + return len2 - len1; + } + }; + + /** + * Encodes the specified cookies into a single Cookie header value. + * + * @param cookies + * some cookies + * @return a Rfc6265 style Cookie header value, null if no cookies are passed. + */ + public String encode(Cookie... cookies) { + if (checkNotNull(cookies, "cookies").length == 0) { + return null; + } + + StringBuilder buf = stringBuilder(); + if (strict) { + if (cookies.length == 1) { + encode(buf, cookies[0]); + } else { + Cookie[] cookiesSorted = Arrays.copyOf(cookies, cookies.length); + Arrays.sort(cookiesSorted, COOKIE_COMPARATOR); + for (Cookie c : cookiesSorted) { + encode(buf, c); + } + } + } else { + for (Cookie c : cookies) { + encode(buf, c); + } + } + return stripTrailingSeparatorOrNull(buf); + } + + /** + * Encodes the specified cookies into a single Cookie header value. + * + * @param cookies + * some cookies + * @return a Rfc6265 style Cookie header value, null if no cookies are passed. + */ + public String encode(Collection cookies) { + if (checkNotNull(cookies, "cookies").isEmpty()) { + return null; + } + + StringBuilder buf = stringBuilder(); + if (strict) { + if (cookies.size() == 1) { + encode(buf, cookies.iterator().next()); + } else { + Cookie[] cookiesSorted = cookies.toArray(new Cookie[0]); + Arrays.sort(cookiesSorted, COOKIE_COMPARATOR); + for (Cookie c : cookiesSorted) { + encode(buf, c); + } + } + } else { + for (Cookie c : cookies) { + encode(buf, c); + } + } + return stripTrailingSeparatorOrNull(buf); + } + + /** + * Encodes the specified cookies into a single Cookie header value. 
+ * + * @param cookies some cookies + * @return a Rfc6265 style Cookie header value, null if no cookies are passed. + */ + public String encode(Iterable cookies) { + Iterator cookiesIt = checkNotNull(cookies, "cookies").iterator(); + if (!cookiesIt.hasNext()) { + return null; + } + + StringBuilder buf = stringBuilder(); + if (strict) { + Cookie firstCookie = cookiesIt.next(); + if (!cookiesIt.hasNext()) { + encode(buf, firstCookie); + } else { + List cookiesList = InternalThreadLocalMap.get().arrayList(); + cookiesList.add(firstCookie); + while (cookiesIt.hasNext()) { + cookiesList.add(cookiesIt.next()); + } + Cookie[] cookiesSorted = cookiesList.toArray(new Cookie[0]); + Arrays.sort(cookiesSorted, COOKIE_COMPARATOR); + for (Cookie c : cookiesSorted) { + encode(buf, c); + } + } + } else { + while (cookiesIt.hasNext()) { + encode(buf, cookiesIt.next()); + } + } + return stripTrailingSeparatorOrNull(buf); + } + + private void encode(StringBuilder buf, Cookie c) { + final String name = c.name(); + final String value = c.value() != null ? c.value() : ""; + + validateCookie(name, value); + + if (c.wrap()) { + addQuoted(buf, name, value); + } else { + add(buf, name, value); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/Cookie.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/Cookie.java new file mode 100644 index 0000000..16bfb0f --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/Cookie.java @@ -0,0 +1,146 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * An interface defining an
 * <a href="https://datatracker.ietf.org/doc/html/rfc6265">HTTP cookie</a>.
 */
public interface Cookie extends Comparable<Cookie> {

    /**
     * Constant for undefined MaxAge attribute value.
     */
    long UNDEFINED_MAX_AGE = Long.MIN_VALUE;

    /**
     * Returns the name of this {@link Cookie}.
     *
     * @return The name of this {@link Cookie}
     */
    String name();

    /**
     * Returns the value of this {@link Cookie}.
     *
     * @return The value of this {@link Cookie}
     */
    String value();

    /**
     * Sets the value of this {@link Cookie}.
     *
     * @param value The value to set
     */
    void setValue(String value);

    /**
     * Returns true if the raw value of this {@link Cookie},
     * was wrapped with double quotes in original Set-Cookie header.
     *
     * @return If the value of this {@link Cookie} is to be wrapped
     */
    boolean wrap();

    /**
     * Sets true if the value of this {@link Cookie}
     * is to be wrapped with double quotes.
     *
     * @param wrap true if wrap
     */
    void setWrap(boolean wrap);

    /**
     * Returns the domain of this {@link Cookie}.
     *
     * @return The domain of this {@link Cookie}
     */
    String domain();

    /**
     * Sets the domain of this {@link Cookie}.
     *
     * @param domain The domain to use
     */
    void setDomain(String domain);

    /**
     * Returns the path of this {@link Cookie}.
     *
     * @return The {@link Cookie}'s path
     */
    String path();

    /**
     * Sets the path of this {@link Cookie}.
     *
     * @param path The path to use for this {@link Cookie}
     */
    void setPath(String path);

    /**
     * Returns the maximum age of this {@link Cookie} in seconds or {@link Cookie#UNDEFINED_MAX_AGE} if unspecified
     *
     * @return The maximum age of this {@link Cookie}
     */
    long maxAge();

    /**
     * Sets the maximum age of this {@link Cookie} in seconds.
     * If an age of {@code 0} is specified, this {@link Cookie} will be
     * automatically removed by browser because it will expire immediately.
     * If {@link Cookie#UNDEFINED_MAX_AGE} is specified, this {@link Cookie} will be removed when the
     * browser is closed.
     *
     * @param maxAge The maximum age of this {@link Cookie} in seconds
     */
    void setMaxAge(long maxAge);

    /**
     * Checks to see if this {@link Cookie} is secure
     *
     * @return True if this {@link Cookie} is secure, otherwise false
     */
    boolean isSecure();

    /**
     * Sets the security status of this {@link Cookie}
     *
     * @param secure True if this {@link Cookie} is to be secure, otherwise false
     */
    void setSecure(boolean secure);

    /**
     * Checks to see if this {@link Cookie} can only be accessed via HTTP.
     * If this returns true, the {@link Cookie} cannot be accessed through
     * client side script - But only if the browser supports it.
     * For more information, please look
     * <a href="https://owasp.org/www-community/HttpOnly">here</a>
     *
     * @return True if this {@link Cookie} is HTTP-only or false if it isn't
     */
    boolean isHttpOnly();

    /**
     * Determines if this {@link Cookie} is HTTP only.
     * If set to true, this {@link Cookie} cannot be accessed by a client
     * side script. However, this works only if the browser supports it.
     * For information, please look
     * <a href="https://owasp.org/www-community/HttpOnly">here</a>.
     *
     * @param httpOnly True if the {@link Cookie} is HTTP only, otherwise false.
     */
    void setHttpOnly(boolean httpOnly);
}
+ */ +package io.netty.handler.codec.http.cookie; + +import static io.netty.handler.codec.http.cookie.CookieUtil.firstInvalidCookieNameOctet; +import static io.netty.handler.codec.http.cookie.CookieUtil.firstInvalidCookieValueOctet; +import static io.netty.handler.codec.http.cookie.CookieUtil.unwrapValue; + +import java.nio.CharBuffer; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * Parent of Client and Server side cookie decoders + */ +public abstract class CookieDecoder { + + private final InternalLogger logger = InternalLoggerFactory.getInstance(getClass()); + + private final boolean strict; + + protected CookieDecoder(boolean strict) { + this.strict = strict; + } + + protected DefaultCookie initCookie(String header, int nameBegin, int nameEnd, int valueBegin, int valueEnd) { + if (nameBegin == -1 || nameBegin == nameEnd) { + logger.debug("Skipping cookie with null name"); + return null; + } + + if (valueBegin == -1) { + logger.debug("Skipping cookie with null value"); + return null; + } + + CharSequence wrappedValue = CharBuffer.wrap(header, valueBegin, valueEnd); + CharSequence unwrappedValue = unwrapValue(wrappedValue); + if (unwrappedValue == null) { + logger.debug("Skipping cookie because starting quotes are not properly balanced in '{}'", + wrappedValue); + return null; + } + + final String name = header.substring(nameBegin, nameEnd); + + int invalidOctetPos; + if (strict && (invalidOctetPos = firstInvalidCookieNameOctet(name)) >= 0) { + if (logger.isDebugEnabled()) { + logger.debug("Skipping cookie because name '{}' contains invalid char '{}'", + name, name.charAt(invalidOctetPos)); + } + return null; + } + + final boolean wrap = unwrappedValue.length() != valueEnd - valueBegin; + + if (strict && (invalidOctetPos = firstInvalidCookieValueOctet(unwrappedValue)) >= 0) { + if (logger.isDebugEnabled()) { + logger.debug("Skipping cookie because value '{}' contains invalid char 
'{}'", + unwrappedValue, unwrappedValue.charAt(invalidOctetPos)); + } + return null; + } + + DefaultCookie cookie = new DefaultCookie(name, unwrappedValue.toString()); + cookie.setWrap(wrap); + return cookie; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java new file mode 100644 index 0000000..a1f20f1 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieEncoder.java @@ -0,0 +1,52 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.cookie; + +import static io.netty.handler.codec.http.cookie.CookieUtil.firstInvalidCookieNameOctet; +import static io.netty.handler.codec.http.cookie.CookieUtil.firstInvalidCookieValueOctet; +import static io.netty.handler.codec.http.cookie.CookieUtil.unwrapValue; + +/** + * Parent of Client and Server side cookie encoders + */ +public abstract class CookieEncoder { + + protected final boolean strict; + + protected CookieEncoder(boolean strict) { + this.strict = strict; + } + + protected void validateCookie(String name, String value) { + if (strict) { + int pos; + + if ((pos = firstInvalidCookieNameOctet(name)) >= 0) { + throw new IllegalArgumentException("Cookie name contains an invalid char: " + name.charAt(pos)); + } + + CharSequence unwrappedValue = unwrapValue(value); + if (unwrappedValue == null) { + throw new IllegalArgumentException("Cookie value wrapping quotes are not balanced: " + value); + } + + if ((pos = firstInvalidCookieValueOctet(unwrappedValue)) >= 0) { + throw new IllegalArgumentException("Cookie value contains an invalid char: " + + unwrappedValue.charAt(pos)); + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieHeaderNames.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieHeaderNames.java new file mode 100644 index 0000000..7e3881e --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/CookieHeaderNames.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * Names of the standard cookie attributes shared by the cookie encoders and
 * decoders.
 */
public final class CookieHeaderNames {

    public static final String PATH = "Path";

    public static final String EXPIRES = "Expires";

    public static final String MAX_AGE = "Max-Age";

    public static final String DOMAIN = "Domain";

    public static final String SECURE = "Secure";

    public static final String HTTPONLY = "HTTPOnly";

    public static final String SAMESITE = "SameSite";

    /**
     * Possible values for the SameSite attribute.
     * See changes to RFC6265bis
     */
    public enum SameSite {
        Lax,
        Strict,
        None;

        /**
         * Return the enum value corresponding to the passed in same-site-flag, using a case insensitive comparison.
         *
         * @param name value for the SameSite Attribute
         * @return enum value for the provided name or null
         */
        static SameSite of(String name) {
            if (name == null) {
                return null;
            }
            for (SameSite candidate : values()) {
                if (candidate.name().equalsIgnoreCase(name)) {
                    return candidate;
                }
            }
            return null;
        }
    }

    private CookieHeaderNames() {
        // Unused.
    }
}
package io.netty.handler.codec.http.cookie;

import io.netty.handler.codec.http.HttpConstants;
import io.netty.util.internal.InternalThreadLocalMap;

import java.util.BitSet;

/**
 * Package-private helpers shared by the cookie encoders and decoders:
 * RFC 6265 octet validation, quote unwrapping and header-value assembly.
 */
final class CookieUtil {

    private static final BitSet VALID_COOKIE_NAME_OCTETS = validCookieNameOctets();

    private static final BitSet VALID_COOKIE_VALUE_OCTETS = validCookieValueOctets();

    private static final BitSet VALID_COOKIE_ATTRIBUTE_VALUE_OCTETS = validCookieAttributeValueOctets();

    // token      = 1*<any CHAR except CTLs or separators>
    // separators = "(" | ")" | "<" | ">" | "@"
    //            | "," | ";" | ":" | "\" | <">
    //            | "/" | "[" | "]" | "?" | "="
    //            | "{" | "}" | SP | HT
    private static BitSet validCookieNameOctets() {
        BitSet bits = new BitSet();
        // Start from every visible US-ASCII char...
        for (int i = 32; i < 127; i++) {
            bits.set(i);
        }
        // ...then knock out the separators listed above.
        int[] separators = new int[]
                { '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', '?', '=', '{', '}', ' ', '\t' };
        for (int separator : separators) {
            bits.set(separator, false);
        }
        return bits;
    }

    // cookie-octet = %x21 / %x23-2B / %x2D-3A / %x3C-5B / %x5D-7E,
    // i.e. US-ASCII characters excluding CTLs, whitespace, DQUOTE, comma, semicolon, and backslash
    private static BitSet validCookieValueOctets() {
        BitSet bits = new BitSet();
        bits.set(0x21);
        for (int i = 0x23; i <= 0x2B; i++) {
            bits.set(i);
        }
        for (int i = 0x2D; i <= 0x3A; i++) {
            bits.set(i);
        }
        for (int i = 0x3C; i <= 0x5B; i++) {
            bits.set(i);
        }
        for (int i = 0x5D; i <= 0x7E; i++) {
            bits.set(i);
        }
        return bits;
    }

    // path-value = <any CHAR except CTLs or ";">
    private static BitSet validCookieAttributeValueOctets() {
        BitSet bits = new BitSet();
        for (int i = 32; i < 127; i++) {
            bits.set(i);
        }
        bits.set(';', false);
        return bits;
    }

    // Reuses a cached per-thread StringBuilder instead of allocating one per
    // encoded header value.
    static StringBuilder stringBuilder() {
        return InternalThreadLocalMap.get().stringBuilder();
    }

    /**
     * @param buf a buffer where some cookies were maybe encoded
     * @return the buffer String without the trailing separator, or null if no cookie was appended.
     */
    static String stripTrailingSeparatorOrNull(StringBuilder buf) {
        return buf.length() == 0 ? null : stripTrailingSeparator(buf);
    }

    // Drops the trailing two-char separator (';' + SP_CHAR) that the add*()
    // methods always append after each cookie; a non-empty buffer therefore
    // always ends with that pair.
    static String stripTrailingSeparator(StringBuilder buf) {
        if (buf.length() > 0) {
            buf.setLength(buf.length() - 2);
        }
        return buf.toString();
    }

    // Appends "name=val" followed by the separator, with a numeric value.
    static void add(StringBuilder sb, String name, long val) {
        sb.append(name);
        sb.append('=');
        sb.append(val);
        sb.append(';');
        sb.append(HttpConstants.SP_CHAR);
    }

    // Appends "name=val" followed by the separator.
    static void add(StringBuilder sb, String name, String val) {
        sb.append(name);
        sb.append('=');
        sb.append(val);
        sb.append(';');
        sb.append(HttpConstants.SP_CHAR);
    }

    // Appends a value-less attribute (e.g. Secure, HTTPOnly) followed by the
    // separator.
    static void add(StringBuilder sb, String name) {
        sb.append(name);
        sb.append(';');
        sb.append(HttpConstants.SP_CHAR);
    }

    // Appends name="val" with the value double-quoted; a null value is
    // encoded as an empty quoted string.
    static void addQuoted(StringBuilder sb, String name, String val) {
        if (val == null) {
            val = "";
        }

        sb.append(name);
        sb.append('=');
        sb.append('"');
        sb.append(val);
        sb.append('"');
        sb.append(';');
        sb.append(HttpConstants.SP_CHAR);
    }

    // Index of the first char of cs that is not a valid cookie-name octet,
    // or -1 when every char is valid.
    static int firstInvalidCookieNameOctet(CharSequence cs) {
        return firstInvalidOctet(cs, VALID_COOKIE_NAME_OCTETS);
    }

    // Index of the first char of cs that is not a valid cookie-value octet,
    // or -1 when every char is valid.
    static int firstInvalidCookieValueOctet(CharSequence cs) {
        return firstInvalidOctet(cs, VALID_COOKIE_VALUE_OCTETS);
    }

    // Index of the first char of cs whose code point is not set in bits,
    // or -1 when every char is allowed.
    static int firstInvalidOctet(CharSequence cs, BitSet bits) {
        for (int i = 0; i < cs.length(); i++) {
            char c = cs.charAt(i);
            if (!bits.get(c)) {
                return i;
            }
        }
        return -1;
    }

    // Strips one pair of wrapping double quotes. Returns null when a leading
    // quote has no matching trailing quote, and cs unchanged when it is not
    // quoted at all.
    static CharSequence unwrapValue(CharSequence cs) {
        final int len = cs.length();
        if (len > 0 && cs.charAt(0) == '"') {
            if (len >= 2 && cs.charAt(len - 1) == '"') {
                // properly balanced
                return len == 2 ? "" : cs.subSequence(1, len - 1);
            } else {
                return null;
            }
        }
        return cs;
    }

    // Trims the attribute value and rejects prohibited octets; returns null
    // for null or blank input so absent attributes stay absent.
    static String validateAttributeValue(String name, String value) {
        if (value == null) {
            return null;
        }
        value = value.trim();
        if (value.isEmpty()) {
            return null;
        }
        int i = firstInvalidOctet(value, VALID_COOKIE_ATTRIBUTE_VALUE_OCTETS);
        if (i != -1) {
            throw new IllegalArgumentException(name + " contains the prohibited characters: " + value.charAt(i));
        }
        return value;
    }

    private CookieUtil() {
        // Unused
    }
}
package io.netty.handler.codec.http.cookie;

import io.netty.handler.codec.http.cookie.CookieHeaderNames.SameSite;

import static io.netty.handler.codec.http.cookie.CookieUtil.stringBuilder;
import static io.netty.handler.codec.http.cookie.CookieUtil.validateAttributeValue;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static io.netty.util.internal.ObjectUtil.checkNonEmptyAfterTrim;

/**
 * The default {@link Cookie} implementation.
 */
public class DefaultCookie implements Cookie {

    private final String name;
    private String value;
    private boolean wrap;
    private String domain;
    private String path;
    // UNDEFINED_MAX_AGE (Long.MIN_VALUE) means "no Max-Age set", i.e. a session cookie.
    private long maxAge = UNDEFINED_MAX_AGE;
    private boolean secure;
    private boolean httpOnly;
    private SameSite sameSite;

    /**
     * Creates a new cookie with the specified name and value.
     */
    public DefaultCookie(String name, String value) {
        this.name = checkNonEmptyAfterTrim(name, "name");
        setValue(value);
    }

    @Override
    public String name() {
        return name;
    }

    @Override
    public String value() {
        return value;
    }

    @Override
    public void setValue(String value) {
        this.value = checkNotNull(value, "value");
    }

    @Override
    public boolean wrap() {
        return wrap;
    }

    @Override
    public void setWrap(boolean wrap) {
        this.wrap = wrap;
    }

    @Override
    public String domain() {
        return domain;
    }

    @Override
    public void setDomain(String domain) {
        // Trims and rejects prohibited octets; null/blank becomes null.
        this.domain = validateAttributeValue("domain", domain);
    }

    @Override
    public String path() {
        return path;
    }

    @Override
    public void setPath(String path) {
        // Trims and rejects prohibited octets; null/blank becomes null.
        this.path = validateAttributeValue("path", path);
    }

    @Override
    public long maxAge() {
        return maxAge;
    }

    @Override
    public void setMaxAge(long maxAge) {
        this.maxAge = maxAge;
    }

    @Override
    public boolean isSecure() {
        return secure;
    }

    @Override
    public void setSecure(boolean secure) {
        this.secure = secure;
    }

    @Override
    public boolean isHttpOnly() {
        return httpOnly;
    }

    @Override
    public void setHttpOnly(boolean httpOnly) {
        this.httpOnly = httpOnly;
    }

    /**
     * Checks to see if this {@link Cookie} can be sent along cross-site requests.
     * For more information, please look
     * <a href="https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-rfc6265bis">here</a>
     * @return same-site-flag value
     */
    public SameSite sameSite() {
        return sameSite;
    }

    /**
     * Determines if this {@link Cookie} can be sent along cross-site requests.
     * For more information, please look
     * <a href="https://datatracker.ietf.org/doc/html/draft-ietf-httpbis-rfc6265bis">here</a>
     * @param sameSite same-site-flag value
     */
    public void setSameSite(SameSite sameSite) {
        this.sameSite = sameSite;
    }

    // Hash by name only; this is consistent with equals() because equal
    // cookies always share the same name.
    @Override
    public int hashCode() {
        return name().hashCode();
    }

    // Two cookies are equal when name, path and domain match; name and path
    // are compared case-sensitively, domain case-insensitively.
    @Override
    public boolean equals(Object o) {
        if (this == o) {
            return true;
        }

        if (!(o instanceof Cookie)) {
            return false;
        }

        Cookie that = (Cookie) o;
        if (!name().equals(that.name())) {
            return false;
        }

        if (path() == null) {
            if (that.path() != null) {
                return false;
            }
        } else if (that.path() == null) {
            return false;
        } else if (!path().equals(that.path())) {
            return false;
        }

        if (domain() == null) {
            if (that.domain() != null) {
                return false;
            }
        } else {
            // equalsIgnoreCase() returns false for a null argument, so a
            // non-null vs null domain correctly compares unequal here.
            return domain().equalsIgnoreCase(that.domain());
        }

        return true;
    }

    // Orders by name, then path (nulls first), then domain (nulls first,
    // case-insensitive) -- consistent with equals().
    @Override
    public int compareTo(Cookie c) {
        int v = name().compareTo(c.name());
        if (v != 0) {
            return v;
        }

        if (path() == null) {
            if (c.path() != null) {
                return -1;
            }
        } else if (c.path() == null) {
            return 1;
        } else {
            v = path().compareTo(c.path());
            if (v != 0) {
                return v;
            }
        }

        if (domain() == null) {
            if (c.domain() != null) {
                return -1;
            }
        } else if (c.domain() == null) {
            return 1;
        } else {
            v = domain().compareToIgnoreCase(c.domain());
            return v;
        }

        // Reached only when both domains are null and everything else is equal.
        return 0;
    }

    /**
     * Validate a cookie attribute value, throws a {@link IllegalArgumentException} otherwise.
     * Only intended to be used by {@link io.netty.handler.codec.http.DefaultCookie}.
     * @param name attribute name
     * @param value attribute value
     * @return the trimmed, validated attribute value
     * @deprecated CookieUtil is package private, will be removed once old Cookie API is dropped
     */
    @Deprecated
    protected String validateValue(String name, String value) {
        return validateAttributeValue(name, value);
    }

    // Note: maxAge is only printed when non-negative, so UNDEFINED_MAX_AGE
    // (Long.MIN_VALUE) is omitted from the string form.
    @Override
    public String toString() {
        StringBuilder buf = stringBuilder()
                .append(name())
                .append('=')
                .append(value());
        if (domain() != null) {
            buf.append(", domain=")
                    .append(domain());
        }
        if (path() != null) {
            buf.append(", path=")
                    .append(path());
        }
        if (maxAge() >= 0) {
            buf.append(", maxAge=")
                    .append(maxAge())
                    .append('s');
        }
        if (isSecure()) {
            buf.append(", secure");
        }
        if (isHttpOnly()) {
            buf.append(", HTTPOnly");
        }
        if (sameSite() != null) {
            buf.append(", SameSite=").append(sameSite());
        }
        return buf.toString();
    }
}
+ */ +package io.netty.handler.codec.http.cookie; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Set; +import java.util.TreeSet; + +/** + * A RFC6265 compliant cookie decoder to be used server side. + * + * Only name and value fields are expected, so old fields are not populated (path, domain, etc). + * + * Old RFC2965 cookies are still supported, + * old fields will simply be ignored. + * + * @see ServerCookieEncoder + */ +public final class ServerCookieDecoder extends CookieDecoder { + + private static final String RFC2965_VERSION = "$Version"; + + private static final String RFC2965_PATH = "$" + CookieHeaderNames.PATH; + + private static final String RFC2965_DOMAIN = "$" + CookieHeaderNames.DOMAIN; + + private static final String RFC2965_PORT = "$Port"; + + /** + * Strict encoder that validates that name and value chars are in the valid scope + * defined in RFC6265 + */ + public static final ServerCookieDecoder STRICT = new ServerCookieDecoder(true); + + /** + * Lax instance that doesn't validate name and value + */ + public static final ServerCookieDecoder LAX = new ServerCookieDecoder(false); + + private ServerCookieDecoder(boolean strict) { + super(strict); + } + + /** + * Decodes the specified {@code Cookie} HTTP header value into a {@link Cookie}. Unlike {@link #decode(String)}, + * this includes all cookie values present, even if they have the same name. + * + * @return the decoded {@link Cookie} + */ + public List decodeAll(String header) { + List cookies = new ArrayList(); + decode(cookies, header); + return Collections.unmodifiableList(cookies); + } + + /** + * Decodes the specified {@code Cookie} HTTP header value into a {@link Cookie}. 
+ * + * @return the decoded {@link Cookie} + */ + public Set decode(String header) { + Set cookies = new TreeSet(); + decode(cookies, header); + return cookies; + } + + /** + * Decodes the specified {@code Cookie} HTTP header value into a {@link Cookie}. + */ + private void decode(Collection cookies, String header) { + final int headerLen = checkNotNull(header, "header").length(); + + if (headerLen == 0) { + return; + } + + int i = 0; + + boolean rfc2965Style = false; + if (header.regionMatches(true, 0, RFC2965_VERSION, 0, RFC2965_VERSION.length())) { + // RFC 2965 style cookie, move to after version value + i = header.indexOf(';') + 1; + rfc2965Style = true; + } + + loop: for (;;) { + + // Skip spaces and separators. + for (;;) { + if (i == headerLen) { + break loop; + } + char c = header.charAt(i); + if (c == '\t' || c == '\n' || c == 0x0b || c == '\f' + || c == '\r' || c == ' ' || c == ',' || c == ';') { + i++; + continue; + } + break; + } + + int nameBegin = i; + int nameEnd; + int valueBegin; + int valueEnd; + + for (;;) { + + char curChar = header.charAt(i); + if (curChar == ';') { + // NAME; (no value till ';') + nameEnd = i; + valueBegin = valueEnd = -1; + break; + + } else if (curChar == '=') { + // NAME=VALUE + nameEnd = i; + i++; + if (i == headerLen) { + // NAME= (empty value, i.e. nothing after '=') + valueBegin = valueEnd = 0; + break; + } + + valueBegin = i; + // NAME=VALUE; + int semiPos = header.indexOf(';', i); + valueEnd = i = semiPos > 0 ? 
semiPos : headerLen; + break; + } else { + i++; + } + + if (i == headerLen) { + // NAME (no value till the end of string) + nameEnd = headerLen; + valueBegin = valueEnd = -1; + break; + } + } + + if (rfc2965Style && (header.regionMatches(nameBegin, RFC2965_PATH, 0, RFC2965_PATH.length()) || + header.regionMatches(nameBegin, RFC2965_DOMAIN, 0, RFC2965_DOMAIN.length()) || + header.regionMatches(nameBegin, RFC2965_PORT, 0, RFC2965_PORT.length()))) { + + // skip obsolete RFC2965 fields + continue; + } + + DefaultCookie cookie = initCookie(header, nameBegin, nameEnd, valueBegin, valueEnd); + if (cookie != null) { + cookies.add(cookie); + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java new file mode 100644 index 0000000..b33d923 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/ServerCookieEncoder.java @@ -0,0 +1,232 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.cookie; + +import io.netty.handler.codec.DateFormatter; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpResponse; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; + +import static io.netty.handler.codec.http.cookie.CookieUtil.add; +import static io.netty.handler.codec.http.cookie.CookieUtil.addQuoted; +import static io.netty.handler.codec.http.cookie.CookieUtil.stringBuilder; +import static io.netty.handler.codec.http.cookie.CookieUtil.stripTrailingSeparator; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A RFC6265 compliant cookie encoder to be used server side, + * so some fields are sent (Version is typically ignored). + * + * As Netty's Cookie merges Expires and MaxAge into one single field, only Max-Age field is sent. + * + * Note that multiple cookies must be sent as separate "Set-Cookie" headers. + * + *
+ * // Example
+ * {@link HttpResponse} res = ...;
+ * res.setHeader("Set-Cookie", {@link ServerCookieEncoder}.encode("JSESSIONID", "1234"));
+ * 
+ * + * @see ServerCookieDecoder + */ +public final class ServerCookieEncoder extends CookieEncoder { + + /** + * Strict encoder that validates that name and value chars are in the valid scope + * defined in RFC6265, and (for methods that accept multiple cookies) that only + * one cookie is encoded with any given name. (If multiple cookies have the same + * name, the last one is the one that is encoded.) + */ + public static final ServerCookieEncoder STRICT = new ServerCookieEncoder(true); + + /** + * Lax instance that doesn't validate name and value, and that allows multiple + * cookies with the same name. + */ + public static final ServerCookieEncoder LAX = new ServerCookieEncoder(false); + + private ServerCookieEncoder(boolean strict) { + super(strict); + } + + /** + * Encodes the specified cookie name-value pair into a Set-Cookie header value. + * + * @param name the cookie name + * @param value the cookie value + * @return a single Set-Cookie header value + */ + public String encode(String name, String value) { + return encode(new DefaultCookie(name, value)); + } + + /** + * Encodes the specified cookie into a Set-Cookie header value. + * + * @param cookie the cookie + * @return a single Set-Cookie header value + */ + public String encode(Cookie cookie) { + final String name = checkNotNull(cookie, "cookie").name(); + final String value = cookie.value() != null ? 
cookie.value() : ""; + + validateCookie(name, value); + + StringBuilder buf = stringBuilder(); + + if (cookie.wrap()) { + addQuoted(buf, name, value); + } else { + add(buf, name, value); + } + + if (cookie.maxAge() != Long.MIN_VALUE) { + add(buf, CookieHeaderNames.MAX_AGE, cookie.maxAge()); + Date expires = new Date(cookie.maxAge() * 1000 + System.currentTimeMillis()); + buf.append(CookieHeaderNames.EXPIRES); + buf.append('='); + DateFormatter.append(expires, buf); + buf.append(';'); + buf.append(HttpConstants.SP_CHAR); + } + + if (cookie.path() != null) { + add(buf, CookieHeaderNames.PATH, cookie.path()); + } + + if (cookie.domain() != null) { + add(buf, CookieHeaderNames.DOMAIN, cookie.domain()); + } + if (cookie.isSecure()) { + add(buf, CookieHeaderNames.SECURE); + } + if (cookie.isHttpOnly()) { + add(buf, CookieHeaderNames.HTTPONLY); + } + if (cookie instanceof DefaultCookie) { + DefaultCookie c = (DefaultCookie) cookie; + if (c.sameSite() != null) { + add(buf, CookieHeaderNames.SAMESITE, c.sameSite().name()); + } + } + + return stripTrailingSeparator(buf); + } + + /** Deduplicate a list of encoded cookies by keeping only the last instance with a given name. + * + * @param encoded The list of encoded cookies. + * @param nameToLastIndex A map from cookie name to index of last cookie instance. + * @return The encoded list with all but the last instance of a named cookie. + */ + private static List dedup(List encoded, Map nameToLastIndex) { + boolean[] isLastInstance = new boolean[encoded.size()]; + for (int idx : nameToLastIndex.values()) { + isLastInstance[idx] = true; + } + List dedupd = new ArrayList(nameToLastIndex.size()); + for (int i = 0, n = encoded.size(); i < n; i++) { + if (isLastInstance[i]) { + dedupd.add(encoded.get(i)); + } + } + return dedupd; + } + + /** + * Batch encodes cookies into Set-Cookie header values. + * + * @param cookies a bunch of cookies + * @return the corresponding bunch of Set-Cookie headers + */ + public List encode(Cookie... 
cookies) { + if (checkNotNull(cookies, "cookies").length == 0) { + return Collections.emptyList(); + } + + List encoded = new ArrayList(cookies.length); + Map nameToIndex = strict && cookies.length > 1 ? new HashMap() : null; + boolean hasDupdName = false; + for (int i = 0; i < cookies.length; i++) { + Cookie c = cookies[i]; + encoded.add(encode(c)); + if (nameToIndex != null) { + hasDupdName |= nameToIndex.put(c.name(), i) != null; + } + } + return hasDupdName ? dedup(encoded, nameToIndex) : encoded; + } + + /** + * Batch encodes cookies into Set-Cookie header values. + * + * @param cookies a bunch of cookies + * @return the corresponding bunch of Set-Cookie headers + */ + public List encode(Collection cookies) { + if (checkNotNull(cookies, "cookies").isEmpty()) { + return Collections.emptyList(); + } + + List encoded = new ArrayList(cookies.size()); + Map nameToIndex = strict && cookies.size() > 1 ? new HashMap() : null; + int i = 0; + boolean hasDupdName = false; + for (Cookie c : cookies) { + encoded.add(encode(c)); + if (nameToIndex != null) { + hasDupdName |= nameToIndex.put(c.name(), i++) != null; + } + } + return hasDupdName ? dedup(encoded, nameToIndex) : encoded; + } + + /** + * Batch encodes cookies into Set-Cookie header values. + * + * @param cookies a bunch of cookies + * @return the corresponding bunch of Set-Cookie headers + */ + public List encode(Iterable cookies) { + Iterator cookiesIt = checkNotNull(cookies, "cookies").iterator(); + if (!cookiesIt.hasNext()) { + return Collections.emptyList(); + } + + List encoded = new ArrayList(); + Cookie firstCookie = cookiesIt.next(); + Map nameToIndex = strict && cookiesIt.hasNext() ? 
new HashMap() : null; + int i = 0; + encoded.add(encode(firstCookie)); + boolean hasDupdName = nameToIndex != null && nameToIndex.put(firstCookie.name(), i++) != null; + while (cookiesIt.hasNext()) { + Cookie c = cookiesIt.next(); + encoded.add(encode(c)); + if (nameToIndex != null) { + hasDupdName |= nameToIndex.put(c.name(), i++) != null; + } + } + return hasDupdName ? dedup(encoded, nameToIndex) : encoded; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/package-info.java new file mode 100644 index 0000000..d67e10b --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cookie/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * This package contains Cookie related classes. 
+ */ +package io.netty.handler.codec.http.cookie; diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java new file mode 100644 index 0000000..0cd7b97 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfig.java @@ -0,0 +1,455 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http.cors; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.util.internal.StringUtil; + +import java.util.Collections; +import java.util.Date; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; +import java.util.concurrent.Callable; + +/** + * Configuration for Cross-Origin Resource Sharing (CORS). 
+ */ +public final class CorsConfig { + + private final Set origins; + private final boolean anyOrigin; + private final boolean enabled; + private final Set exposeHeaders; + private final boolean allowCredentials; + private final long maxAge; + private final Set allowedRequestMethods; + private final Set allowedRequestHeaders; + private final boolean allowNullOrigin; + private final Map> preflightHeaders; + private final boolean shortCircuit; + private final boolean allowPrivateNetwork; + + CorsConfig(final CorsConfigBuilder builder) { + origins = new LinkedHashSet(builder.origins); + anyOrigin = builder.anyOrigin; + enabled = builder.enabled; + exposeHeaders = builder.exposeHeaders; + allowCredentials = builder.allowCredentials; + maxAge = builder.maxAge; + allowedRequestMethods = builder.requestMethods; + allowedRequestHeaders = builder.requestHeaders; + allowNullOrigin = builder.allowNullOrigin; + preflightHeaders = builder.preflightHeaders; + shortCircuit = builder.shortCircuit; + allowPrivateNetwork = builder.allowPrivateNetwork; + } + + /** + * Determines if support for CORS is enabled. + * + * @return {@code true} if support for CORS is enabled, false otherwise. + */ + public boolean isCorsSupportEnabled() { + return enabled; + } + + /** + * Determines whether a wildcard origin, '*', is supported. + * + * @return {@code boolean} true if any origin is allowed. + */ + public boolean isAnyOriginSupported() { + return anyOrigin; + } + + /** + * Returns the allowed origin. This can either be a wildcard or an origin value. + * + * @return the value that will be used for the CORS response header 'Access-Control-Allow-Origin' + */ + public String origin() { + return origins.isEmpty() ? "*" : origins.iterator().next(); + } + + /** + * Returns the set of allowed origins. + * + * @return {@code Set} the allowed origins. 
+ */ + public Set origins() { + return origins; + } + + /** + * Web browsers may set the 'Origin' request header to 'null' if a resource is loaded + * from the local file system. + * + * If isNullOriginAllowed is true then the server will response with the wildcard for + * the CORS response header 'Access-Control-Allow-Origin'. + * + * @return {@code true} if a 'null' origin should be supported. + */ + public boolean isNullOriginAllowed() { + return allowNullOrigin; + } + + /** + * Web browsers may set the 'Access-Control-Request-Private-Network' request header if a resource is loaded + * from a local network. + * By default direct access to private network endpoints from public websites is not allowed. + * + * If isPrivateNetworkAllowed is true the server will response with the CORS response header + * 'Access-Control-Request-Private-Network'. + * + * @return {@code true} if private network access should be allowed. + */ + public boolean isPrivateNetworkAllowed() { + return allowPrivateNetwork; + } + + /** + * Returns a set of headers to be exposed to calling clients. + * + * During a simple CORS request only certain response headers are made available by the + * browser, for example using: + *
+     * xhr.getResponseHeader("Content-Type");
+     * 
+ * The headers that are available by default are: + *
    + *
  • Cache-Control
  • + *
  • Content-Language
  • + *
  • Content-Type
  • + *
  • Expires
  • + *
  • Last-Modified
  • + *
  • Pragma
  • + *
+ * To expose other headers they need to be specified, which is what this method enables by + * adding the headers names to the CORS 'Access-Control-Expose-Headers' response header. + * + * @return {@code List} a list of the headers to expose. + */ + public Set exposedHeaders() { + return Collections.unmodifiableSet(exposeHeaders); + } + + /** + * Determines if cookies are supported for CORS requests. + * + * By default cookies are not included in CORS requests but if isCredentialsAllowed returns + * true cookies will be added to CORS requests. Setting this value to true will set the + * CORS 'Access-Control-Allow-Credentials' response header to true. + * + * Please note that cookie support needs to be enabled on the client side as well. + * The client needs to opt-in to send cookies by calling: + *
+     * xhr.withCredentials = true;
+     * 
+ * The default value for 'withCredentials' is false in which case no cookies are sent. + * Setting this to true will included cookies in cross origin requests. + * + * @return {@code true} if cookies are supported. + */ + public boolean isCredentialsAllowed() { + return allowCredentials; + } + + /** + * Gets the maxAge setting. + * + * When making a preflight request the client has to perform two request with can be inefficient. + * This setting will set the CORS 'Access-Control-Max-Age' response header and enables the + * caching of the preflight response for the specified time. During this time no preflight + * request will be made. + * + * @return {@code long} the time in seconds that a preflight request may be cached. + */ + public long maxAge() { + return maxAge; + } + + /** + * Returns the allowed set of Request Methods. The Http methods that should be returned in the + * CORS 'Access-Control-Request-Method' response header. + * + * @return {@code Set} of {@link HttpMethod}s that represent the allowed Request Methods. + */ + public Set allowedRequestMethods() { + return Collections.unmodifiableSet(allowedRequestMethods); + } + + /** + * Returns the allowed set of Request Headers. + * + * The header names returned from this method will be used to set the CORS + * 'Access-Control-Allow-Headers' response header. + * + * @return {@code Set} of strings that represent the allowed Request Headers. + */ + public Set allowedRequestHeaders() { + return Collections.unmodifiableSet(allowedRequestHeaders); + } + + /** + * Returns HTTP response headers that should be added to a CORS preflight response. + * + * @return {@link HttpHeaders} the HTTP response headers to be added. 
+ */ + public HttpHeaders preflightResponseHeaders() { + if (preflightHeaders.isEmpty()) { + return EmptyHttpHeaders.INSTANCE; + } + final HttpHeaders preflightHeaders = new DefaultHttpHeaders(); + for (Entry> entry : this.preflightHeaders.entrySet()) { + final Object value = getValue(entry.getValue()); + if (value instanceof Iterable) { + preflightHeaders.add(entry.getKey(), (Iterable) value); + } else { + preflightHeaders.add(entry.getKey(), value); + } + } + return preflightHeaders; + } + + /** + * Determines whether a CORS request should be rejected if it's invalid before being + * further processing. + * + * CORS headers are set after a request is processed. This may not always be desired + * and this setting will check that the Origin is valid and if it is not valid no + * further processing will take place, and an error will be returned to the calling client. + * + * @return {@code true} if a CORS request should short-circuit upon receiving an invalid Origin header. + */ + public boolean isShortCircuit() { + return shortCircuit; + } + + /** + * @deprecated Use {@link #isShortCircuit()} instead. 
+ */ + @Deprecated + public boolean isShortCurcuit() { + return isShortCircuit(); + } + + private static T getValue(final Callable callable) { + try { + return callable.call(); + } catch (final Exception e) { + throw new IllegalStateException("Could not generate value for callable [" + callable + ']', e); + } + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "[enabled=" + enabled + + ", origins=" + origins + + ", anyOrigin=" + anyOrigin + + ", exposedHeaders=" + exposeHeaders + + ", isCredentialsAllowed=" + allowCredentials + + ", maxAge=" + maxAge + + ", allowedRequestMethods=" + allowedRequestMethods + + ", allowedRequestHeaders=" + allowedRequestHeaders + + ", preflightHeaders=" + preflightHeaders + + ", isPrivateNetworkAllowed=" + allowPrivateNetwork + ']'; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#forAnyOrigin()} instead. + */ + @Deprecated + public static Builder withAnyOrigin() { + return new Builder(); + } + + /** + * @deprecated Use {@link CorsConfigBuilder#forOrigin(String)} instead. + */ + @Deprecated + public static Builder withOrigin(final String origin) { + if ("*".equals(origin)) { + return new Builder(); + } + return new Builder(origin); + } + + /** + * @deprecated Use {@link CorsConfigBuilder#forOrigins(String...)} instead. + */ + @Deprecated + public static Builder withOrigins(final String... origins) { + return new Builder(origins); + } + + /** + * @deprecated Use {@link CorsConfigBuilder} instead. + */ + @Deprecated + public static class Builder { + + private final CorsConfigBuilder builder; + + /** + * @deprecated Use {@link CorsConfigBuilder} instead. + */ + @Deprecated + public Builder(final String... origins) { + builder = new CorsConfigBuilder(origins); + } + + /** + * @deprecated Use {@link CorsConfigBuilder} instead. + */ + @Deprecated + public Builder() { + builder = new CorsConfigBuilder(); + } + + /** + * @deprecated Use {@link CorsConfigBuilder#allowNullOrigin()} instead. 
+ */ + @Deprecated + public Builder allowNullOrigin() { + builder.allowNullOrigin(); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#disable()} instead. + */ + @Deprecated + public Builder disable() { + builder.disable(); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#exposeHeaders(String...)} instead. + */ + @Deprecated + public Builder exposeHeaders(final String... headers) { + builder.exposeHeaders(headers); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#allowCredentials()} instead. + */ + @Deprecated + public Builder allowCredentials() { + builder.allowCredentials(); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#maxAge(long)} instead. + */ + @Deprecated + public Builder maxAge(final long max) { + builder.maxAge(max); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#allowedRequestMethods(HttpMethod...)} instead. + */ + @Deprecated + public Builder allowedRequestMethods(final HttpMethod... methods) { + builder.allowedRequestMethods(methods); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#allowedRequestHeaders(String...)} instead. + */ + @Deprecated + public Builder allowedRequestHeaders(final String... headers) { + builder.allowedRequestHeaders(headers); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#preflightResponseHeader(CharSequence, Object...)} instead. + */ + @Deprecated + public Builder preflightResponseHeader(final CharSequence name, final Object... values) { + builder.preflightResponseHeader(name, values); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#preflightResponseHeader(CharSequence, Iterable)} instead. 
+ */ + @Deprecated + public Builder preflightResponseHeader(final CharSequence name, final Iterable value) { + builder.preflightResponseHeader(name, value); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#preflightResponseHeader(CharSequence, Callable)} instead. + */ + @Deprecated + public Builder preflightResponseHeader(final String name, final Callable valueGenerator) { + builder.preflightResponseHeader(name, valueGenerator); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#noPreflightResponseHeaders()} instead. + */ + @Deprecated + public Builder noPreflightResponseHeaders() { + builder.noPreflightResponseHeaders(); + return this; + } + + /** + * @deprecated Use {@link CorsConfigBuilder#build()} instead. + */ + @Deprecated + public CorsConfig build() { + return builder.build(); + } + + /** + * @deprecated Use {@link CorsConfigBuilder#shortCircuit()} instead. + */ + @Deprecated + public Builder shortCurcuit() { + builder.shortCircuit(); + return this; + } + } + + /** + * @deprecated Removed without alternatives. + */ + @Deprecated + public static final class DateValueGenerator implements Callable { + + @Override + public Date call() throws Exception { + return new Date(); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfigBuilder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfigBuilder.java new file mode 100644 index 0000000..c41408d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsConfigBuilder.java @@ -0,0 +1,420 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http.cors; + +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Date; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.Callable; + +/** + * Builder used to configure and build a {@link CorsConfig} instance. + */ +public final class CorsConfigBuilder { + + /** + * Creates a Builder instance with it's origin set to '*'. + * + * @return Builder to support method chaining. + */ + public static CorsConfigBuilder forAnyOrigin() { + return new CorsConfigBuilder(); + } + + /** + * Creates a {@link CorsConfigBuilder} instance with the specified origin. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public static CorsConfigBuilder forOrigin(final String origin) { + if ("*".equals(origin)) { + return new CorsConfigBuilder(); + } + return new CorsConfigBuilder(origin); + } + + /** + * Creates a {@link CorsConfigBuilder} instance with the specified origins. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public static CorsConfigBuilder forOrigins(final String... 
origins) { + return new CorsConfigBuilder(origins); + } + + final Set origins; + final boolean anyOrigin; + boolean allowNullOrigin; + boolean enabled = true; + boolean allowCredentials; + final Set exposeHeaders = new HashSet(); + long maxAge; + final Set requestMethods = new HashSet(); + final Set requestHeaders = new HashSet(); + final Map> preflightHeaders = new HashMap>(); + private boolean noPreflightHeaders; + boolean shortCircuit; + boolean allowPrivateNetwork; + + /** + * Creates a new Builder instance with the origin passed in. + * + * @param origins the origin to be used for this builder. + */ + CorsConfigBuilder(final String... origins) { + this.origins = new LinkedHashSet(Arrays.asList(origins)); + anyOrigin = false; + } + + /** + * Creates a new Builder instance allowing any origin, "*" which is the + * wildcard origin. + * + */ + CorsConfigBuilder() { + anyOrigin = true; + origins = Collections.emptySet(); + } + + /** + * Web browsers may set the 'Origin' request header to 'null' if a resource is loaded + * from the local file system. Calling this method will enable a successful CORS response + * with a {@code "null"} value for the CORS response header 'Access-Control-Allow-Origin'. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder allowNullOrigin() { + allowNullOrigin = true; + return this; + } + + /** + * Disables CORS support. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder disable() { + enabled = false; + return this; + } + + /** + * Specifies the headers to be exposed to calling clients. + * + * During a simple CORS request, only certain response headers are made available by the + * browser, for example using: + *
+     * xhr.getResponseHeader("Content-Type");
+     * 
+ * + * The headers that are available by default are: + *
    + *
  • Cache-Control
  • + *
  • Content-Language
  • + *
  • Content-Type
  • + *
  • Expires
  • + *
  • Last-Modified
  • + *
  • Pragma
  • + *
+ * + * To expose other headers they need to be specified which is what this method enables by + * adding the headers to the CORS 'Access-Control-Expose-Headers' response header. + * + * @param headers the values to be added to the 'Access-Control-Expose-Headers' response header + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder exposeHeaders(final String... headers) { + exposeHeaders.addAll(Arrays.asList(headers)); + return this; + } + + /** + * Specifies the headers to be exposed to calling clients. + * + * During a simple CORS request, only certain response headers are made available by the + * browser, for example using: + *
+     * xhr.getResponseHeader(HttpHeaderNames.CONTENT_TYPE);
+     * 
+ * + * The headers that are available by default are: + *
    + *
  • Cache-Control
  • + *
  • Content-Language
  • + *
  • Content-Type
  • + *
  • Expires
  • + *
  • Last-Modified
  • + *
  • Pragma
  • + *
+ * + * To expose other headers they need to be specified which is what this method enables by + * adding the headers to the CORS 'Access-Control-Expose-Headers' response header. + * + * @param headers the values to be added to the 'Access-Control-Expose-Headers' response header + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder exposeHeaders(final CharSequence... headers) { + for (CharSequence header: headers) { + exposeHeaders.add(header.toString()); + } + return this; + } + + /** + * By default cookies are not included in CORS requests, but this method will enable cookies to + * be added to CORS requests. Calling this method will set the CORS 'Access-Control-Allow-Credentials' + * response header to true. + * + * Please note, that cookie support needs to be enabled on the client side as well. + * The client needs to opt-in to send cookies by calling: + *
+     * xhr.withCredentials = true;
+     * 
+ * The default value for 'withCredentials' is false in which case no cookies are sent. + * Setting this to true will include cookies in cross origin requests. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder allowCredentials() { + allowCredentials = true; + return this; + } + + /** + * When making a preflight request the client has to perform two requests which can be inefficient. + * This setting will set the CORS 'Access-Control-Max-Age' response header and enables the + * caching of the preflight response for the specified time. During this time no preflight + * request will be made. + * + * @param max the maximum time, in seconds, that the preflight response may be cached. + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder maxAge(final long max) { + maxAge = max; + return this; + } + + /** + * Specifies the allowed set of HTTP Request Methods that should be returned in the + * CORS 'Access-Control-Request-Method' response header. + * + * @param methods the {@link HttpMethod}s that should be allowed. + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder allowedRequestMethods(final HttpMethod... methods) { + requestMethods.addAll(Arrays.asList(methods)); + return this; + } + + /** + * Specifies the headers that should be returned in the CORS 'Access-Control-Allow-Headers' + * response header. + * + * If a client specifies headers on the request, for example by calling: + *
+     * xhr.setRequestHeader('My-Custom-Header', "SomeValue");
+     * 
+ * the server will receive the above header name in the 'Access-Control-Request-Headers' of the + * preflight request. The server will then decide if it allows this header to be sent for the + * real request (remember that a preflight is not the real request but a request asking the server + * if it allows a request). + * + * @param headers the headers to be added to the preflight 'Access-Control-Allow-Headers' response header. + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder allowedRequestHeaders(final String... headers) { + requestHeaders.addAll(Arrays.asList(headers)); + return this; + } + + /** + * Specifies the headers that should be returned in the CORS 'Access-Control-Allow-Headers' + * response header. + * + * If a client specifies headers on the request, for example by calling: + *
+     * xhr.setRequestHeader('My-Custom-Header', "SomeValue");
+     * 
+ * the server will receive the above header name in the 'Access-Control-Request-Headers' of the + * preflight request. The server will then decide if it allows this header to be sent for the + * real request (remember that a preflight is not the real request but a request asking the server + * if it allows a request). + * + * @param headers the headers to be added to the preflight 'Access-Control-Allow-Headers' response header. + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder allowedRequestHeaders(final CharSequence... headers) { + for (CharSequence header: headers) { + requestHeaders.add(header.toString()); + } + return this; + } + + /** + * Returns HTTP response headers that should be added to a CORS preflight response. + * + * An intermediary like a load balancer might require that a CORS preflight request + * have certain headers set. This enables such headers to be added. + * + * @param name the name of the HTTP header. + * @param values the values for the HTTP header. + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder preflightResponseHeader(final CharSequence name, final Object... values) { + if (values.length == 1) { + preflightHeaders.put(name, new ConstantValueGenerator(values[0])); + } else { + preflightResponseHeader(name, Arrays.asList(values)); + } + return this; + } + + /** + * Returns HTTP response headers that should be added to a CORS preflight response. + * + * An intermediary like a load balancer might require that a CORS preflight request + * have certain headers set. This enables such headers to be added. + * + * @param name the name of the HTTP header. + * @param value the values for the HTTP header. + * @param <T> the type of values that the Iterable contains. + * @return {@link CorsConfigBuilder} to support method chaining. 
+ */ + public CorsConfigBuilder preflightResponseHeader(final CharSequence name, final Iterable value) { + preflightHeaders.put(name, new ConstantValueGenerator(value)); + return this; + } + + /** + * Returns HTTP response headers that should be added to a CORS preflight response. + * + * An intermediary like a load balancer might require that a CORS preflight request + * have certain headers set. This enables such headers to be added. + * + * Some values must be dynamically created when the HTTP response is created, for + * example the 'Date' response header. This can be accomplished by using a Callable + * which will have its 'call' method invoked when the HTTP response is created. + * + * @param name the name of the HTTP header. + * @param valueGenerator a Callable which will be invoked at HTTP response creation. + * @param the type of the value that the Callable can return. + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder preflightResponseHeader(final CharSequence name, final Callable valueGenerator) { + preflightHeaders.put(name, valueGenerator); + return this; + } + + /** + * Specifies that no preflight response headers should be added to a preflight response. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder noPreflightResponseHeaders() { + noPreflightHeaders = true; + return this; + } + + /** + * Specifies that a CORS request should be rejected if it's invalid before being + * further processing. + * + * CORS headers are set after a request is processed. This may not always be desired + * and this setting will check that the Origin is valid and if it is not valid no + * further processing will take place, and an error will be returned to the calling client. + * + * @return {@link CorsConfigBuilder} to support method chaining. 
+ */ + public CorsConfigBuilder shortCircuit() { + shortCircuit = true; + return this; + } + + /** + * Web browsers may set the 'Access-Control-Request-Private-Network' request header if a resource is loaded + * from a local network. + * By default direct access to private network endpoints from public websites is not allowed. + * Calling this method will set the CORS 'Access-Control-Request-Private-Network' response header to true. + * + * @return {@link CorsConfigBuilder} to support method chaining. + */ + public CorsConfigBuilder allowPrivateNetwork() { + allowPrivateNetwork = true; + return this; + } + + /** + * Builds a {@link CorsConfig} with settings specified by previous method calls. + * + * @return {@link CorsConfig} the configured CorsConfig instance. + */ + public CorsConfig build() { + if (preflightHeaders.isEmpty() && !noPreflightHeaders) { + preflightHeaders.put(HttpHeaderNames.DATE, DateValueGenerator.INSTANCE); + preflightHeaders.put(HttpHeaderNames.CONTENT_LENGTH, new ConstantValueGenerator("0")); + } + return new CorsConfig(this); + } + + /** + * This class is used for preflight HTTP response values that do not need to be + * generated, but instead the value is "static" in that the same value will be returned + * for each call. + */ + private static final class ConstantValueGenerator implements Callable<Object> { + + private final Object value; + + /** + * Sole constructor. + * + * @param value the value that will be returned when the call method is invoked. + */ + private ConstantValueGenerator(final Object value) { + this.value = checkNotNullWithIAE(value, "value"); + } + + @Override + public Object call() { + return value; + } + } + + /** + * This callable is used for the DATE preflight HTTP response HTTP header. + * Its value must be generated when the response is generated, hence will be + * different for every call. 
+ */ + private static final class DateValueGenerator implements Callable { + + static final DateValueGenerator INSTANCE = new DateValueGenerator(); + + @Override + public Date call() throws Exception { + return new Date(); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java new file mode 100644 index 0000000..75e958c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/CorsHandler.java @@ -0,0 +1,270 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http.cors; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.DefaultHttpHeadersFactory; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.Collections; +import java.util.List; + +import static io.netty.handler.codec.http.HttpMethod.OPTIONS; +import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static io.netty.util.ReferenceCountUtil.release; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Handles Cross Origin Resource Sharing (CORS) requests. + *

+ * This handler can be configured using one or more {@link CorsConfig}, please + * refer to this class for details about the configuration options available. + */ +public class CorsHandler extends ChannelDuplexHandler { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(CorsHandler.class); + private static final String ANY_ORIGIN = "*"; + private static final String NULL_ORIGIN = "null"; + private CorsConfig config; + + private HttpRequest request; + private final List configList; + private final boolean isShortCircuit; + + /** + * Creates a new instance with a single {@link CorsConfig}. + */ + public CorsHandler(final CorsConfig config) { + this(Collections.singletonList(checkNotNull(config, "config")), config.isShortCircuit()); + } + + /** + * Creates a new instance with the specified config list. If more than one + * config matches a certain origin, the first in the List will be used. + * + * @param configList List of {@link CorsConfig} + * @param isShortCircuit Same as {@link CorsConfig#isShortCircuit} but applicable to all supplied configs. 
+ */ + public CorsHandler(final List configList, boolean isShortCircuit) { + checkNonEmpty(configList, "configList"); + this.configList = configList; + this.isShortCircuit = isShortCircuit; + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + if (msg instanceof HttpRequest) { + request = (HttpRequest) msg; + final String origin = request.headers().get(HttpHeaderNames.ORIGIN); + config = getForOrigin(origin); + if (isPreflightRequest(request)) { + handlePreflight(ctx, request); + return; + } + if (isShortCircuit && !(origin == null || config != null)) { + forbidden(ctx, request); + return; + } + } + ctx.fireChannelRead(msg); + } + + private void handlePreflight(final ChannelHandlerContext ctx, final HttpRequest request) { + final HttpResponse response = new DefaultFullHttpResponse( + request.protocolVersion(), + OK, + Unpooled.buffer(0), + DefaultHttpHeadersFactory.headersFactory().withCombiningHeaders(true), + DefaultHttpHeadersFactory.trailersFactory().withCombiningHeaders(true)); + if (setOrigin(response)) { + setAllowMethods(response); + setAllowHeaders(response); + setAllowCredentials(response); + setMaxAge(response); + setPreflightHeaders(response); + setAllowPrivateNetwork(response); + } + if (!response.headers().contains(HttpHeaderNames.CONTENT_LENGTH)) { + response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO); + } + release(request); + respond(ctx, request, response); + } + + /** + * This is a non CORS specification feature which enables the setting of preflight + * response headers that might be required by intermediaries. + * + * @param response the HttpResponse to which the preflight response headers should be added. 
+ */ + private void setPreflightHeaders(final HttpResponse response) { + response.headers().add(config.preflightResponseHeaders()); + } + + private CorsConfig getForOrigin(String requestOrigin) { + for (CorsConfig corsConfig : configList) { + if (corsConfig.isAnyOriginSupported()) { + return corsConfig; + } + if (corsConfig.origins().contains(requestOrigin)) { + return corsConfig; + } + if (corsConfig.isNullOriginAllowed() || NULL_ORIGIN.equals(requestOrigin)) { + return corsConfig; + } + } + return null; + } + + private boolean setOrigin(final HttpResponse response) { + final String origin = request.headers().get(HttpHeaderNames.ORIGIN); + if (origin != null && config != null) { + if (NULL_ORIGIN.equals(origin) && config.isNullOriginAllowed()) { + setNullOrigin(response); + return true; + } + if (config.isAnyOriginSupported()) { + if (config.isCredentialsAllowed()) { + echoRequestOrigin(response); + setVaryHeader(response); + } else { + setAnyOrigin(response); + } + return true; + } + if (config.origins().contains(origin)) { + setOrigin(response, origin); + setVaryHeader(response); + return true; + } + logger.debug("Request origin [{}] was not among the configured origins [{}]", origin, config.origins()); + } + return false; + } + + private void echoRequestOrigin(final HttpResponse response) { + setOrigin(response, request.headers().get(HttpHeaderNames.ORIGIN)); + } + + private static void setVaryHeader(final HttpResponse response) { + response.headers().set(HttpHeaderNames.VARY, HttpHeaderNames.ORIGIN); + } + + private static void setAnyOrigin(final HttpResponse response) { + setOrigin(response, ANY_ORIGIN); + } + + private static void setNullOrigin(final HttpResponse response) { + setOrigin(response, NULL_ORIGIN); + } + + private static void setOrigin(final HttpResponse response, final String origin) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN, origin); + } + + private void setAllowCredentials(final HttpResponse response) { + if
(config.isCredentialsAllowed() + && !response.headers().get(HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN).equals(ANY_ORIGIN)) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS, "true"); + } + } + + private static boolean isPreflightRequest(final HttpRequest request) { + final HttpHeaders headers = request.headers(); + return OPTIONS.equals(request.method()) && + headers.contains(HttpHeaderNames.ORIGIN) && + headers.contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD); + } + + private void setExposeHeaders(final HttpResponse response) { + if (!config.exposedHeaders().isEmpty()) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS, config.exposedHeaders()); + } + } + + private void setAllowMethods(final HttpResponse response) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS, config.allowedRequestMethods()); + } + + private void setAllowHeaders(final HttpResponse response) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS, config.allowedRequestHeaders()); + } + + private void setMaxAge(final HttpResponse response) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_MAX_AGE, config.maxAge()); + } + + private void setAllowPrivateNetwork(final HttpResponse response) { + if (request.headers().contains(HttpHeaderNames.ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK)) { + if (config.isPrivateNetworkAllowed()) { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "true"); + } else { + response.headers().set(HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK, "false"); + } + } + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) + throws Exception { + if (config != null && config.isCorsSupportEnabled() && msg instanceof HttpResponse) { + final HttpResponse response = (HttpResponse) msg; + if (setOrigin(response)) { + setAllowCredentials(response); + setExposeHeaders(response); + } + } + 
ctx.write(msg, promise); + } + + private static void forbidden(final ChannelHandlerContext ctx, final HttpRequest request) { + HttpResponse response = new DefaultFullHttpResponse( + request.protocolVersion(), FORBIDDEN, ctx.alloc().buffer(0)); + response.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO); + release(request); + respond(ctx, request, response); + } + + private static void respond( + final ChannelHandlerContext ctx, + final HttpRequest request, + final HttpResponse response) { + + final boolean keepAlive = HttpUtil.isKeepAlive(request); + + HttpUtil.setKeepAlive(response, keepAlive); + + final ChannelFuture future = ctx.writeAndFlush(response); + if (!keepAlive) { + future.addListener(ChannelFutureListener.CLOSE); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/package-info.java new file mode 100644 index 0000000..6c1570f --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/cors/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * This package contains Cross Origin Resource Sharing (CORS) related classes. 
+ */ +package io.netty.handler.codec.http.cors; diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpData.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpData.java new file mode 100644 index 0000000..39d4602 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpData.java @@ -0,0 +1,484 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.io.RandomAccessFile; +import java.nio.ByteBuffer; +import java.nio.channels.FileChannel; +import java.nio.charset.Charset; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.wrappedBuffer; + +/** + * Abstract Disk HttpData implementation + */ +public abstract class AbstractDiskHttpData extends AbstractHttpData { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractDiskHttpData.class); + + private File file; + private boolean isRenamed; + private FileChannel fileChannel; + + protected AbstractDiskHttpData(String name, Charset charset, long size) { + super(name, charset, size); + } + + /** + * + * @return the real DiskFilename (basename) + */ + protected abstract String getDiskFilename(); + /** + * + * @return the default prefix + */ + protected abstract String getPrefix(); + /** + * + * @return the default base Directory + */ + protected abstract String getBaseDirectory(); + /** + * + * @return the default postfix + */ + protected abstract String getPostfix(); + /** + * + * @return True if the file should be deleted on Exit by default + */ + protected abstract boolean deleteOnExit(); + + /** + * @return a new Temp File from getDiskFilename(), default prefix, postfix and baseDirectory + */ + private File tempFile() throws IOException { + String newpostfix; + String diskFilename = getDiskFilename(); + if (diskFilename != null) { + newpostfix = '_' + Integer.toString(diskFilename.hashCode()); + } else { + 
newpostfix = getPostfix(); + } + File tmpFile; + if (getBaseDirectory() == null) { + // create a temporary file + tmpFile = PlatformDependent.createTempFile(getPrefix(), newpostfix, null); + } else { + tmpFile = PlatformDependent.createTempFile(getPrefix(), newpostfix, new File( + getBaseDirectory())); + } + if (deleteOnExit()) { + // See https://github.com/netty/netty/issues/10351 + DeleteFileOnExitHook.add(tmpFile.getPath()); + } + return tmpFile; + } + + @Override + public void setContent(ByteBuf buffer) throws IOException { + ObjectUtil.checkNotNull(buffer, "buffer"); + try { + size = buffer.readableBytes(); + checkSize(size); + if (definedSize > 0 && definedSize < size) { + throw new IOException("Out of size: " + size + " > " + definedSize); + } + if (file == null) { + file = tempFile(); + } + if (buffer.readableBytes() == 0) { + // empty file + if (!file.createNewFile()) { + if (file.length() == 0) { + return; + } else { + if (!file.delete() || !file.createNewFile()) { + throw new IOException("file exists already: " + file); + } + } + } + return; + } + RandomAccessFile accessFile = new RandomAccessFile(file, "rw"); + try { + accessFile.setLength(0); + FileChannel localfileChannel = accessFile.getChannel(); + ByteBuffer byteBuffer = buffer.nioBuffer(); + int written = 0; + while (written < size) { + written += localfileChannel.write(byteBuffer); + } + buffer.readerIndex(buffer.readerIndex() + written); + localfileChannel.force(false); + } finally { + accessFile.close(); + } + setCompleted(); + } finally { + // Release the buffer as it was retained before and we not need a reference to it at all + // See https://github.com/netty/netty/issues/1516 + buffer.release(); + } + } + + @Override + public void addContent(ByteBuf buffer, boolean last) + throws IOException { + if (buffer != null) { + try { + int localsize = buffer.readableBytes(); + checkSize(size + localsize); + if (definedSize > 0 && definedSize < size + localsize) { + throw new IOException("Out of 
size: " + (size + localsize) + + " > " + definedSize); + } + if (file == null) { + file = tempFile(); + } + if (fileChannel == null) { + RandomAccessFile accessFile = new RandomAccessFile(file, "rw"); + fileChannel = accessFile.getChannel(); + } + int remaining = localsize; + long position = fileChannel.position(); + int index = buffer.readerIndex(); + while (remaining > 0) { + int written = buffer.getBytes(index, fileChannel, position, remaining); + if (written < 0) { + break; + } + remaining -= written; + position += written; + index += written; + } + fileChannel.position(position); + buffer.readerIndex(index); + size += localsize - remaining; + } finally { + // Release the buffer as it was retained before and we not need a reference to it at all + // See https://github.com/netty/netty/issues/1516 + buffer.release(); + } + } + if (last) { + if (file == null) { + file = tempFile(); + } + if (fileChannel == null) { + RandomAccessFile accessFile = new RandomAccessFile(file, "rw"); + fileChannel = accessFile.getChannel(); + } + try { + fileChannel.force(false); + } finally { + fileChannel.close(); + } + fileChannel = null; + setCompleted(); + } else { + ObjectUtil.checkNotNull(buffer, "buffer"); + } + } + + @Override + public void setContent(File file) throws IOException { + long size = file.length(); + checkSize(size); + this.size = size; + if (this.file != null) { + delete(); + } + this.file = file; + isRenamed = true; + setCompleted(); + } + + @Override + public void setContent(InputStream inputStream) throws IOException { + ObjectUtil.checkNotNull(inputStream, "inputStream"); + if (file != null) { + delete(); + } + file = tempFile(); + RandomAccessFile accessFile = new RandomAccessFile(file, "rw"); + int written = 0; + try { + accessFile.setLength(0); + FileChannel localfileChannel = accessFile.getChannel(); + byte[] bytes = new byte[4096 * 4]; + ByteBuffer byteBuffer = ByteBuffer.wrap(bytes); + int read = inputStream.read(bytes); + while (read > 0) { + 
byteBuffer.position(read).flip(); + written += localfileChannel.write(byteBuffer); + checkSize(written); + byteBuffer.clear(); + read = inputStream.read(bytes); + } + localfileChannel.force(false); + } finally { + accessFile.close(); + } + size = written; + if (definedSize > 0 && definedSize < size) { + if (!file.delete()) { + logger.warn("Failed to delete: {}", file); + } + file = null; + throw new IOException("Out of size: " + size + " > " + definedSize); + } + isRenamed = true; + setCompleted(); + } + + @Override + public void delete() { + if (fileChannel != null) { + try { + fileChannel.force(false); + } catch (IOException e) { + logger.warn("Failed to force.", e); + } finally { + try { + fileChannel.close(); + } catch (IOException e) { + logger.warn("Failed to close a file.", e); + } + } + fileChannel = null; + } + if (!isRenamed) { + String filePath = null; + + if (file != null && file.exists()) { + filePath = file.getPath(); + if (!file.delete()) { + filePath = null; + logger.warn("Failed to delete: {}", file); + } + } + + // If you turn on deleteOnExit make sure it is executed. 
+ if (deleteOnExit() && filePath != null) { + DeleteFileOnExitHook.remove(filePath); + } + file = null; + } + } + + @Override + public byte[] get() throws IOException { + if (file == null) { + return EmptyArrays.EMPTY_BYTES; + } + return readFrom(file); + } + + @Override + public ByteBuf getByteBuf() throws IOException { + if (file == null) { + return EMPTY_BUFFER; + } + byte[] array = readFrom(file); + return wrappedBuffer(array); + } + + @Override + public ByteBuf getChunk(int length) throws IOException { + if (file == null || length == 0) { + return EMPTY_BUFFER; + } + if (fileChannel == null) { + RandomAccessFile accessFile = new RandomAccessFile(file, "r"); + fileChannel = accessFile.getChannel(); + } + int read = 0; + ByteBuffer byteBuffer = ByteBuffer.allocate(length); + try { + while (read < length) { + int readnow = fileChannel.read(byteBuffer); + if (readnow == -1) { + fileChannel.close(); + fileChannel = null; + break; + } + read += readnow; + } + } catch (IOException e) { + fileChannel.close(); + fileChannel = null; + throw e; + } + if (read == 0) { + return EMPTY_BUFFER; + } + byteBuffer.flip(); + ByteBuf buffer = wrappedBuffer(byteBuffer); + buffer.readerIndex(0); + buffer.writerIndex(read); + return buffer; + } + + @Override + public String getString() throws IOException { + return getString(HttpConstants.DEFAULT_CHARSET); + } + + @Override + public String getString(Charset encoding) throws IOException { + if (file == null) { + return ""; + } + if (encoding == null) { + byte[] array = readFrom(file); + return new String(array, HttpConstants.DEFAULT_CHARSET.name()); + } + byte[] array = readFrom(file); + return new String(array, encoding.name()); + } + + @Override + public boolean isInMemory() { + return false; + } + + @Override + public boolean renameTo(File dest) throws IOException { + ObjectUtil.checkNotNull(dest, "dest"); + if (file == null) { + throw new IOException("No file defined so cannot be renamed"); + } + if (!file.renameTo(dest)) { + // 
must copy + IOException exception = null; + RandomAccessFile inputAccessFile = null; + RandomAccessFile outputAccessFile = null; + long chunkSize = 8196; + long position = 0; + try { + inputAccessFile = new RandomAccessFile(file, "r"); + outputAccessFile = new RandomAccessFile(dest, "rw"); + FileChannel in = inputAccessFile.getChannel(); + FileChannel out = outputAccessFile.getChannel(); + while (position < size) { + if (chunkSize < size - position) { + chunkSize = size - position; + } + position += in.transferTo(position, chunkSize, out); + } + } catch (IOException e) { + exception = e; + } finally { + if (inputAccessFile != null) { + try { + inputAccessFile.close(); + } catch (IOException e) { + if (exception == null) { // Choose to report the first exception + exception = e; + } else { + logger.warn("Multiple exceptions detected, the following will be suppressed {}", e); + } + } + } + if (outputAccessFile != null) { + try { + outputAccessFile.close(); + } catch (IOException e) { + if (exception == null) { // Choose to report the first exception + exception = e; + } else { + logger.warn("Multiple exceptions detected, the following will be suppressed {}", e); + } + } + } + } + if (exception != null) { + throw exception; + } + if (position == size) { + if (!file.delete()) { + logger.warn("Failed to delete: {}", file); + } + file = dest; + isRenamed = true; + return true; + } else { + if (!dest.delete()) { + logger.warn("Failed to delete: {}", dest); + } + return false; + } + } + file = dest; + isRenamed = true; + return true; + } + + /** + * Utility function + * + * @return the array of bytes + */ + private static byte[] readFrom(File src) throws IOException { + long srcsize = src.length(); + if (srcsize > Integer.MAX_VALUE) { + throw new IllegalArgumentException( + "File too big to be loaded in memory"); + } + RandomAccessFile accessFile = new RandomAccessFile(src, "r"); + byte[] array = new byte[(int) srcsize]; + try { + FileChannel fileChannel = 
accessFile.getChannel(); + ByteBuffer byteBuffer = ByteBuffer.wrap(array); + int read = 0; + while (read < srcsize) { + read += fileChannel.read(byteBuffer); + } + } finally { + accessFile.close(); + } + return array; + } + + @Override + public File getFile() throws IOException { + return file; + } + + @Override + public HttpData touch() { + return this; + } + + @Override + public HttpData touch(Object hint) { + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java new file mode 100644 index 0000000..406fc08 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractHttpData.java @@ -0,0 +1,144 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelException; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.internal.ObjectUtil; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.regex.Pattern; + +/** + * Abstract HttpData implementation + */ +public abstract class AbstractHttpData extends AbstractReferenceCounted implements HttpData { + + private static final Pattern STRIP_PATTERN = Pattern.compile("(?:^\\s+|\\s+$|\\n)"); + private static final Pattern REPLACE_PATTERN = Pattern.compile("[\\r\\t]"); + + private final String name; + protected long definedSize; + protected long size; + private Charset charset = HttpConstants.DEFAULT_CHARSET; + private boolean completed; + private long maxSize = DefaultHttpDataFactory.MAXSIZE; + + protected AbstractHttpData(String name, Charset charset, long size) { + ObjectUtil.checkNotNull(name, "name"); + + name = REPLACE_PATTERN.matcher(name).replaceAll(" "); + name = STRIP_PATTERN.matcher(name).replaceAll(""); + + this.name = checkNonEmpty(name, "name"); + if (charset != null) { + setCharset(charset); + } + definedSize = size; + } + + @Override + public long getMaxSize() { + return maxSize; + } + + @Override + public void setMaxSize(long maxSize) { + this.maxSize = maxSize; + } + + @Override + public void checkSize(long newSize) throws IOException { + if (maxSize >= 0 && newSize > maxSize) { + throw new IOException("Size exceed allowed maximum capacity"); + } + } + + @Override + public String getName() { + return name; + } + + @Override + public boolean isCompleted() { + return completed; + } + + protected void setCompleted() { + setCompleted(true); + } + + protected void setCompleted(boolean completed) { + this.completed = completed; + } + + @Override + public Charset getCharset() { + 
return charset; + } + + @Override + public void setCharset(Charset charset) { + this.charset = ObjectUtil.checkNotNull(charset, "charset"); + } + + @Override + public long length() { + return size; + } + + @Override + public long definedLength() { + return definedSize; + } + + @Override + public ByteBuf content() { + try { + return getByteBuf(); + } catch (IOException e) { + throw new ChannelException(e); + } + } + + @Override + protected void deallocate() { + delete(); + } + + @Override + public HttpData retain() { + super.retain(); + return this; + } + + @Override + public HttpData retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public abstract HttpData touch(); + + @Override + public abstract HttpData touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java new file mode 100644 index 0000000..f25801b --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpData.java @@ -0,0 +1,303 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * Abstract Memory HttpData implementation
 *
 * Keeps the whole content in a single {@link ByteBuf} (possibly a
 * {@link CompositeByteBuf} when content arrives in several chunks).
 * Reference counting is delicate here: every setter either takes ownership
 * of the incoming buffer or releases it before rethrowing.
 */
public abstract class AbstractMemoryHttpData extends AbstractHttpData {

    // Backing buffer for the whole content. EMPTY_BUFFER until data arrives;
    // null once delete() has released it.
    private ByteBuf byteBuf;
    // Read offset maintained across successive getChunk() calls.
    private int chunkPosition;

    protected AbstractMemoryHttpData(String name, Charset charset, long size) {
        super(name, charset, size);
        byteBuf = EMPTY_BUFFER;
    }

    @Override
    public void setContent(ByteBuf buffer) throws IOException {
        ObjectUtil.checkNotNull(buffer, "buffer");
        long localsize = buffer.readableBytes();
        try {
            checkSize(localsize);
        } catch (IOException e) {
            // We own the buffer on failure, so release before rethrowing.
            buffer.release();
            throw e;
        }
        if (definedSize > 0 && definedSize < localsize) {
            buffer.release();
            throw new IOException("Out of size: " + localsize + " > " +
                    definedSize);
        }
        // Replace any previous content; take ownership of the new buffer.
        if (byteBuf != null) {
            byteBuf.release();
        }
        byteBuf = buffer;
        size = localsize;
        setCompleted();
    }

    @Override
    public void setContent(InputStream inputStream) throws IOException {
        ObjectUtil.checkNotNull(inputStream, "inputStream");

        // Read in 16 KiB slabs; checkSize is enforced after every read so an
        // over-limit stream fails as early as possible.
        byte[] bytes = new byte[4096 * 4];
        ByteBuf buffer = buffer();
        int written = 0;
        try {
            int read = inputStream.read(bytes);
            while (read > 0) {
                buffer.writeBytes(bytes, 0, read);
                written += read;
                checkSize(written);
                read = inputStream.read(bytes);
            }
        } catch (IOException e) {
            buffer.release();
            throw e;
        }
        size = written;
        if (definedSize > 0 && definedSize < size) {
            buffer.release();
            throw new IOException("Out of size: " + size + " > " + definedSize);
        }
        if (byteBuf != null) {
            byteBuf.release();
        }
        byteBuf = buffer;
        setCompleted();
    }

    @Override
    public void addContent(ByteBuf buffer, boolean last)
            throws IOException {
        if (buffer != null) {
            long localsize = buffer.readableBytes();
            try {
                checkSize(size + localsize);
            } catch (IOException e) {
                buffer.release();
                throw e;
            }
            if (definedSize > 0 && definedSize < size + localsize) {
                buffer.release();
                throw new IOException("Out of size: " + (size + localsize) +
                        " > " + definedSize);
            }
            size += localsize;
            if (byteBuf == null) {
                byteBuf = buffer;
            } else if (localsize == 0) {
                // Nothing to add and byteBuf already exists
                buffer.release();
            } else if (byteBuf.readableBytes() == 0) {
                // Previous buffer is empty, so just replace it
                byteBuf.release();
                byteBuf = buffer;
            } else if (byteBuf instanceof CompositeByteBuf) {
                // Already composite: append the chunk without copying.
                CompositeByteBuf cbb = (CompositeByteBuf) byteBuf;
                cbb.addComponent(true, buffer);
            } else {
                // Second chunk: promote to a composite of (old, new).
                CompositeByteBuf cbb = compositeBuffer(Integer.MAX_VALUE);
                cbb.addComponents(true, byteBuf, buffer);
                byteBuf = cbb;
            }
        }
        if (last) {
            setCompleted();
        } else {
            // null buffer is only legal for the final (last == true) call.
            ObjectUtil.checkNotNull(buffer, "buffer");
        }
    }

    @Override
    public void setContent(File file) throws IOException {
        ObjectUtil.checkNotNull(file, "file");

        long newsize = file.length();
        if (newsize > Integer.MAX_VALUE) {
            // Content must fit into a single byte[] / ByteBuf.
            throw new IllegalArgumentException("File too big to be loaded in memory");
        }
        checkSize(newsize);
        RandomAccessFile accessFile = new RandomAccessFile(file, "r");
        ByteBuffer byteBuffer;
        try {
            FileChannel fileChannel = accessFile.getChannel();
            try {
                byte[] array = new byte[(int) newsize];
                byteBuffer = ByteBuffer.wrap(array);
                // FileChannel.read may return short counts; loop to fill.
                int read = 0;
                while (read < newsize) {
                    read += fileChannel.read(byteBuffer);
                }
            } finally {
                fileChannel.close();
            }
        } finally {
            accessFile.close();
        }
        byteBuffer.flip();
        if (byteBuf != null) {
            byteBuf.release();
        }
        byteBuf = wrappedBuffer(Integer.MAX_VALUE, byteBuffer);
        size = newsize;
        setCompleted();
    }

    @Override
    public void delete() {
        // Releases the backing buffer; idempotent (byteBuf set to null).
        if (byteBuf != null) {
            byteBuf.release();
            byteBuf = null;
        }
    }

    @Override
    public byte[] get() {
        if (byteBuf == null) {
            return EMPTY_BUFFER.array();
        }
        // Copy out without disturbing the reader index.
        byte[] array = new byte[byteBuf.readableBytes()];
        byteBuf.getBytes(byteBuf.readerIndex(), array);
        return array;
    }

    @Override
    public String getString() {
        return getString(HttpConstants.DEFAULT_CHARSET);
    }

    @Override
    public String getString(Charset encoding) {
        if (byteBuf == null) {
            return "";
        }
        if (encoding == null) {
            encoding = HttpConstants.DEFAULT_CHARSET;
        }
        return byteBuf.toString(encoding);
    }

    /**
     * Utility to go from a In Memory FileUpload
     * to a Disk (or another implementation) FileUpload
     * @return the attached ByteBuf containing the actual bytes
     */
    @Override
    public ByteBuf getByteBuf() {
        return byteBuf;
    }

    @Override
    public ByteBuf getChunk(int length) throws IOException {
        // Serves successive retained slices of at most `length` bytes;
        // resets the cursor and returns EMPTY_BUFFER once exhausted.
        if (byteBuf == null || length == 0 || byteBuf.readableBytes() == 0) {
            chunkPosition = 0;
            return EMPTY_BUFFER;
        }
        int sizeLeft = byteBuf.readableBytes() - chunkPosition;
        if (sizeLeft == 0) {
            chunkPosition = 0;
            return EMPTY_BUFFER;
        }
        int sliceLength = length;
        if (sizeLeft < length) {
            sliceLength = sizeLeft;
        }
        ByteBuf chunk = byteBuf.retainedSlice(chunkPosition, sliceLength);
        chunkPosition += sliceLength;
        return chunk;
    }

    @Override
    public boolean isInMemory() {
        return true;
    }

    @Override
    public boolean renameTo(File dest) throws IOException {
        ObjectUtil.checkNotNull(dest, "dest");
        if (byteBuf == null) {
            // empty file
            if (!dest.createNewFile()) {
                throw new IOException("file exists already: " + dest);
            }
            return true;
        }
        int length = byteBuf.readableBytes();
        long written = 0;
        RandomAccessFile accessFile = new RandomAccessFile(dest, "rw");
        try {
            FileChannel fileChannel = accessFile.getChannel();
            try {
                if (byteBuf.nioBufferCount() == 1) {
                    ByteBuffer byteBuffer = byteBuf.nioBuffer();
                    // write() may be partial; loop until everything is flushed.
                    while (written < length) {
                        written += fileChannel.write(byteBuffer);
                    }
                } else {
                    // Composite buffer: gatherinng write over all components.
                    ByteBuffer[] byteBuffers = byteBuf.nioBuffers();
                    while (written < length) {
                        written += fileChannel.write(byteBuffers);
                    }
                }
                fileChannel.force(false);
            } finally {
                fileChannel.close();
            }
        } finally {
            accessFile.close();
        }
        return written == length;
    }

    @Override
    public File getFile() throws IOException {
        // In-memory data has no file representation by definition.
        throw new IOException("Not represented by a file");
    }

    @Override
    public HttpData touch() {
        return touch(null);
    }

    @Override
    public HttpData touch(Object hint) {
        if (byteBuf != null) {
            byteBuf.touch(hint);
        }
        return this;
    }
}
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.util.AbstractReferenceCounted; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; + +abstract class AbstractMixedHttpData extends AbstractReferenceCounted implements HttpData { + final String baseDir; + final boolean deleteOnExit; + D wrapped; + + private final long limitSize; + + AbstractMixedHttpData(long limitSize, String baseDir, boolean deleteOnExit, D initial) { + this.limitSize = limitSize; + this.wrapped = initial; + this.baseDir = baseDir; + this.deleteOnExit = deleteOnExit; + } + + abstract D makeDiskData(); + + @Override + public long getMaxSize() { + return wrapped.getMaxSize(); + } + + @Override + public void setMaxSize(long maxSize) { + wrapped.setMaxSize(maxSize); + } + + @Override + public ByteBuf content() { + return wrapped.content(); + } + + @Override + public void checkSize(long newSize) throws IOException { + wrapped.checkSize(newSize); + } + + @Override + public long definedLength() { + return wrapped.definedLength(); + } + + @Override + public Charset getCharset() { + return wrapped.getCharset(); + } + + @Override + public String getName() { + return wrapped.getName(); + } + + @Override + public void addContent(ByteBuf buffer, boolean last) throws IOException { + if (wrapped instanceof AbstractMemoryHttpData) { + try { + checkSize(wrapped.length() + buffer.readableBytes()); + if (wrapped.length() + buffer.readableBytes() > limitSize) { + D diskData = makeDiskData(); + ByteBuf data = ((AbstractMemoryHttpData) wrapped).getByteBuf(); + if (data != null && data.isReadable()) { + diskData.addContent(data.retain(), false); + } + wrapped.release(); + wrapped = diskData; + } + } catch (IOException e) { + buffer.release(); + throw e; + } + } + wrapped.addContent(buffer, last); + } + + @Override + protected void deallocate() { + delete(); + } + + @Override + public void delete() { + 
wrapped.delete(); + } + + @Override + public byte[] get() throws IOException { + return wrapped.get(); + } + + @Override + public ByteBuf getByteBuf() throws IOException { + return wrapped.getByteBuf(); + } + + @Override + public String getString() throws IOException { + return wrapped.getString(); + } + + @Override + public String getString(Charset encoding) throws IOException { + return wrapped.getString(encoding); + } + + @Override + public boolean isInMemory() { + return wrapped.isInMemory(); + } + + @Override + public long length() { + return wrapped.length(); + } + + @Override + public boolean renameTo(File dest) throws IOException { + return wrapped.renameTo(dest); + } + + @Override + public void setCharset(Charset charset) { + wrapped.setCharset(charset); + } + + @Override + public void setContent(ByteBuf buffer) throws IOException { + try { + checkSize(buffer.readableBytes()); + } catch (IOException e) { + buffer.release(); + throw e; + } + if (buffer.readableBytes() > limitSize) { + if (wrapped instanceof AbstractMemoryHttpData) { + // change to Disk + wrapped.release(); + wrapped = makeDiskData(); + } + } + wrapped.setContent(buffer); + } + + @Override + public void setContent(File file) throws IOException { + checkSize(file.length()); + if (file.length() > limitSize) { + if (wrapped instanceof AbstractMemoryHttpData) { + // change to Disk + wrapped.release(); + wrapped = makeDiskData(); + } + } + wrapped.setContent(file); + } + + @Override + public void setContent(InputStream inputStream) throws IOException { + if (wrapped instanceof AbstractMemoryHttpData) { + // change to Disk even if we don't know the size + wrapped.release(); + wrapped = makeDiskData(); + } + wrapped.setContent(inputStream); + } + + @Override + public boolean isCompleted() { + return wrapped.isCompleted(); + } + + @Override + public HttpDataType getHttpDataType() { + return wrapped.getHttpDataType(); + } + + @Override + public int hashCode() { + return wrapped.hashCode(); + } + + 
@Override + public boolean equals(Object obj) { + return wrapped.equals(obj); + } + + @Override + public int compareTo(InterfaceHttpData o) { + return wrapped.compareTo(o); + } + + @Override + public String toString() { + return "Mixed: " + wrapped; + } + + @Override + public ByteBuf getChunk(int length) throws IOException { + return wrapped.getChunk(length); + } + + @Override + public File getFile() throws IOException { + return wrapped.getFile(); + } + + @SuppressWarnings("unchecked") + @Override + public D copy() { + return (D) wrapped.copy(); + } + + @SuppressWarnings("unchecked") + @Override + public D duplicate() { + return (D) wrapped.duplicate(); + } + + @SuppressWarnings("unchecked") + @Override + public D retainedDuplicate() { + return (D) wrapped.retainedDuplicate(); + } + + @SuppressWarnings("unchecked") + @Override + public D replace(ByteBuf content) { + return (D) wrapped.replace(content); + } + + @SuppressWarnings("unchecked") + @Override + public D touch() { + wrapped.touch(); + return (D) this; + } + + @SuppressWarnings("unchecked") + @Override + public D touch(Object hint) { + wrapped.touch(hint); + return (D) this; + } + + @SuppressWarnings("unchecked") + @Override + public D retain() { + return (D) super.retain(); + } + + @SuppressWarnings("unchecked") + @Override + public D retain(int increment) { + return (D) super.retain(increment); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/Attribute.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/Attribute.java new file mode 100644 index 0000000..5250cc5 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/Attribute.java @@ -0,0 +1,59 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; + +import java.io.IOException; + +/** + * Attribute interface + */ +public interface Attribute extends HttpData { + /** + * Returns the value of this HttpData. + */ + String getValue() throws IOException; + + /** + * Sets the value of this HttpData. + */ + void setValue(String value) throws IOException; + + @Override + Attribute copy(); + + @Override + Attribute duplicate(); + + @Override + Attribute retainedDuplicate(); + + @Override + Attribute replace(ByteBuf content); + + @Override + Attribute retain(); + + @Override + Attribute retain(int increment); + + @Override + Attribute touch(); + + @Override + Attribute touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/CaseIgnoringComparator.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/CaseIgnoringComparator.java new file mode 100644 index 0000000..034b74b --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/CaseIgnoringComparator.java @@ -0,0 +1,56 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import java.io.Serializable; +import java.util.Comparator; + +final class CaseIgnoringComparator implements Comparator, Serializable { + + private static final long serialVersionUID = 4582133183775373862L; + + static final CaseIgnoringComparator INSTANCE = new CaseIgnoringComparator(); + + private CaseIgnoringComparator() { + } + + @Override + public int compare(CharSequence o1, CharSequence o2) { + int o1Length = o1.length(); + int o2Length = o2.length(); + int min = Math.min(o1Length, o2Length); + for (int i = 0; i < min; i++) { + char c1 = o1.charAt(i); + char c2 = o2.charAt(i); + if (c1 != c2) { + c1 = Character.toUpperCase(c1); + c2 = Character.toUpperCase(c2); + if (c1 != c2) { + c1 = Character.toLowerCase(c1); + c2 = Character.toLowerCase(c2); + if (c1 != c2) { + return c1 - c2; + } + } + } + } + return o1Length - o2Length; + } + + private Object readResolve() { + return INSTANCE; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactory.java new file mode 100644 index 0000000..254ff2f --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactory.java @@ -0,0 +1,348 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + 
* with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpRequest; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; + +/** + * Default factory giving {@link Attribute} and {@link FileUpload} according to constructor. + * + *

According to the constructor, {@link Attribute} and {@link FileUpload} can be:

+ *
    + *
  • MemoryAttribute, DiskAttribute or MixedAttribute
  • + *
  • MemoryFileUpload, DiskFileUpload or MixedFileUpload
  • + *
+ * A good example of releasing HttpData once all work is done is as follow:
+ *
{@code
+ *   for (InterfaceHttpData httpData: decoder.getBodyHttpDatas()) {
+ *     httpData.release();
+ *     factory.removeHttpDataFromClean(request, httpData);
+ *   }
+ *   factory.cleanAllHttpData();
+ *   decoder.destroy();
+ *  }
+ */ +public class DefaultHttpDataFactory implements HttpDataFactory { + + /** + * Proposed default MINSIZE as 16 KB. + */ + public static final long MINSIZE = 0x4000; + /** + * Proposed default MAXSIZE = -1 as UNLIMITED + */ + public static final long MAXSIZE = -1; + + private final boolean useDisk; + + private final boolean checkSize; + + private long minSize; + + private long maxSize = MAXSIZE; + + private Charset charset = HttpConstants.DEFAULT_CHARSET; + + private String baseDir; + + private boolean deleteOnExit; // false is a good default cause true leaks + + /** + * Keep all {@link HttpData}s until cleaning methods are called. + * We need to use {@link IdentityHashMap} because different requests may be equal. + * See {@link DefaultHttpRequest#hashCode} and {@link DefaultHttpRequest#equals}. + * Similarly, when removing data items, we need to check their identities because + * different data items may be equal. + */ + private final Map> requestFileDeleteMap = + Collections.synchronizedMap(new IdentityHashMap>()); + + /** + * HttpData will be in memory if less than default size (16KB). + * The type will be Mixed. + */ + public DefaultHttpDataFactory() { + useDisk = false; + checkSize = true; + minSize = MINSIZE; + } + + public DefaultHttpDataFactory(Charset charset) { + this(); + this.charset = charset; + } + + /** + * HttpData will be always on Disk if useDisk is True, else always in Memory if False + */ + public DefaultHttpDataFactory(boolean useDisk) { + this.useDisk = useDisk; + checkSize = false; + } + + public DefaultHttpDataFactory(boolean useDisk, Charset charset) { + this(useDisk); + this.charset = charset; + } + /** + * HttpData will be on Disk if the size of the file is greater than minSize, else it + * will be in memory. The type will be Mixed. 
+ */ + public DefaultHttpDataFactory(long minSize) { + useDisk = false; + checkSize = true; + this.minSize = minSize; + } + + public DefaultHttpDataFactory(long minSize, Charset charset) { + this(minSize); + this.charset = charset; + } + + /** + * Override global {@link DiskAttribute#baseDirectory} and {@link DiskFileUpload#baseDirectory} values. + * + * @param baseDir directory path where to store disk attributes and file uploads. + */ + public void setBaseDir(String baseDir) { + this.baseDir = baseDir; + } + + /** + * Override global {@link DiskAttribute#deleteOnExitTemporaryFile} and + * {@link DiskFileUpload#deleteOnExitTemporaryFile} values. + * + * @param deleteOnExit true if temporary files should be deleted with the JVM, false otherwise. + */ + public void setDeleteOnExit(boolean deleteOnExit) { + this.deleteOnExit = deleteOnExit; + } + + @Override + public void setMaxLimit(long maxSize) { + this.maxSize = maxSize; + } + + /** + * @return the associated list of {@link HttpData} for the request + */ + private List getList(HttpRequest request) { + List list = requestFileDeleteMap.get(request); + if (list == null) { + list = new ArrayList(); + requestFileDeleteMap.put(request, list); + } + return list; + } + + @Override + public Attribute createAttribute(HttpRequest request, String name) { + if (useDisk) { + Attribute attribute = new DiskAttribute(name, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + List list = getList(request); + list.add(attribute); + return attribute; + } + if (checkSize) { + Attribute attribute = new MixedAttribute(name, minSize, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + List list = getList(request); + list.add(attribute); + return attribute; + } + MemoryAttribute attribute = new MemoryAttribute(name); + attribute.setMaxSize(maxSize); + return attribute; + } + + @Override + public Attribute createAttribute(HttpRequest request, String name, long definedSize) { + if (useDisk) { + Attribute 
attribute = new DiskAttribute(name, definedSize, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + List list = getList(request); + list.add(attribute); + return attribute; + } + if (checkSize) { + Attribute attribute = new MixedAttribute(name, definedSize, minSize, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + List list = getList(request); + list.add(attribute); + return attribute; + } + MemoryAttribute attribute = new MemoryAttribute(name, definedSize); + attribute.setMaxSize(maxSize); + return attribute; + } + + /** + * Utility method + */ + private static void checkHttpDataSize(HttpData data) { + try { + data.checkSize(data.length()); + } catch (IOException ignored) { + throw new IllegalArgumentException("Attribute bigger than maxSize allowed"); + } + } + + @Override + public Attribute createAttribute(HttpRequest request, String name, String value) { + if (useDisk) { + Attribute attribute; + try { + attribute = new DiskAttribute(name, value, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + } catch (IOException e) { + // revert to Mixed mode + attribute = new MixedAttribute(name, value, minSize, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + } + checkHttpDataSize(attribute); + List list = getList(request); + list.add(attribute); + return attribute; + } + if (checkSize) { + Attribute attribute = new MixedAttribute(name, value, minSize, charset, baseDir, deleteOnExit); + attribute.setMaxSize(maxSize); + checkHttpDataSize(attribute); + List list = getList(request); + list.add(attribute); + return attribute; + } + try { + MemoryAttribute attribute = new MemoryAttribute(name, value, charset); + attribute.setMaxSize(maxSize); + checkHttpDataSize(attribute); + return attribute; + } catch (IOException e) { + throw new IllegalArgumentException(e); + } + } + + @Override + public FileUpload createFileUpload(HttpRequest request, String name, String filename, + String contentType, String 
contentTransferEncoding, Charset charset, + long size) { + if (useDisk) { + FileUpload fileUpload = new DiskFileUpload(name, filename, contentType, + contentTransferEncoding, charset, size, baseDir, deleteOnExit); + fileUpload.setMaxSize(maxSize); + checkHttpDataSize(fileUpload); + List list = getList(request); + list.add(fileUpload); + return fileUpload; + } + if (checkSize) { + FileUpload fileUpload = new MixedFileUpload(name, filename, contentType, + contentTransferEncoding, charset, size, minSize, baseDir, deleteOnExit); + fileUpload.setMaxSize(maxSize); + checkHttpDataSize(fileUpload); + List list = getList(request); + list.add(fileUpload); + return fileUpload; + } + MemoryFileUpload fileUpload = new MemoryFileUpload(name, filename, contentType, + contentTransferEncoding, charset, size); + fileUpload.setMaxSize(maxSize); + checkHttpDataSize(fileUpload); + return fileUpload; + } + + @Override + public void removeHttpDataFromClean(HttpRequest request, InterfaceHttpData data) { + if (!(data instanceof HttpData)) { + return; + } + + // Do not use getList because it adds empty list to requestFileDeleteMap + // if request is not found + List list = requestFileDeleteMap.get(request); + if (list == null) { + return; + } + + // Can't simply call list.remove(data), because different data items may be equal. + // Need to check identity. 
+ Iterator i = list.iterator(); + while (i.hasNext()) { + HttpData n = i.next(); + if (n == data) { + i.remove(); + + // Remove empty list to avoid memory leak + if (list.isEmpty()) { + requestFileDeleteMap.remove(request); + } + + return; + } + } + } + + @Override + public void cleanRequestHttpData(HttpRequest request) { + List list = requestFileDeleteMap.remove(request); + if (list != null) { + for (HttpData data : list) { + data.release(); + } + } + } + + @Override + public void cleanAllHttpData() { + Iterator>> i = requestFileDeleteMap.entrySet().iterator(); + while (i.hasNext()) { + Entry> e = i.next(); + + // Calling i.remove() here will cause "java.lang.IllegalStateException: Entry was removed" + // at e.getValue() below + + List list = e.getValue(); + for (HttpData data : list) { + data.release(); + } + + i.remove(); + } + } + + @Override + public void cleanRequestHttpDatas(HttpRequest request) { + cleanRequestHttpData(request); + } + + @Override + public void cleanAllHttpDatas() { + cleanAllHttpData(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DeleteFileOnExitHook.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DeleteFileOnExitHook.java new file mode 100644 index 0000000..f93208d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/DeleteFileOnExitHook.java @@ -0,0 +1,82 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.multipart;

import java.io.File;
import java.util.Collections;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

/**
 * Registry of temporary file paths that must be deleted when the JVM exits.
 * A single JVM shutdown hook walks the registered paths and removes them.
 * Unlike {@link File#deleteOnExit()}, entries can be withdrawn again via
 * {@link #remove(String)}, so the registry does not grow unboundedly.
 */
final class DeleteFileOnExitHook {

    /** Concurrent set of file paths pending deletion at JVM exit. */
    private static final Set FILES = Collections.newSetFromMap(new ConcurrentHashMap());

    static {
        // DeleteOnExitHook must be the last shutdown hook to be invoked.
        // Application shutdown hooks may add the first file to the
        // delete-on-exit list and cause the DeleteOnExitHook to be
        // registered while shutdown is already in progress; the concurrent
        // set tolerates such late additions.
        Runtime.getRuntime().addShutdownHook(new Thread() {
            @Override
            public void run() {
                runHook();
            }
        });
    }

    private DeleteFileOnExitHook() {
        // Static utility; never instantiated.
    }

    /**
     * Withdraw a path from the registry to reduce the space footprint;
     * the file will no longer be deleted at exit.
     *
     * @param file tmp file path
     */
    public static void remove(String file) {
        FILES.remove(file);
    }

    /**
     * Register a path for deletion when the program exits.
     *
     * @param file tmp file path
     */
    public static void add(String file) {
        FILES.add(file);
    }

    /**
     * Tell whether the given path is currently registered in the hook.
     *
     * @param file target file
     * @return {@code true} if the path is scheduled for deletion
     */
    public static boolean checkFileExist(String file) {
        return FILES.contains(file);
    }

    /**
     * Delete every registered file. Invoked by the shutdown hook.
     */
    static void runHook() {
        for (String path : FILES) {
            // Best-effort deletion during shutdown; failures are ignored.
            new File(path).delete();
        }
    }
}
 */
package io.netty.handler.codec.http.multipart;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.handler.codec.http.HttpConstants;
import io.netty.util.internal.ObjectUtil;

import java.io.IOException;
import java.nio.charset.Charset;

import static io.netty.buffer.Unpooled.wrappedBuffer;

/**
 * Disk implementation of Attributes: the attribute value is stored in a
 * temporary file on disk (backing behaviour provided by AbstractDiskHttpData).
 */
public class DiskAttribute extends AbstractDiskHttpData implements Attribute {
    // Global default base directory for temporary attribute files; null means the
    // platform temp directory (resolved by the superclass).
    public static String baseDirectory;

    // Global default for whether temporary files are deleted on JVM exit.
    public static boolean deleteOnExitTemporaryFile = true;

    // Filename prefix used for the temporary files backing attributes.
    public static final String prefix = "Attr_";

    // Filename suffix used for the temporary files backing attributes.
    public static final String postfix = ".att";

    // Per-instance base directory; falls back to the static baseDirectory when null.
    private String baseDir;

    // Per-instance delete-on-exit flag.
    private boolean deleteOnExit;

    /**
     * Constructor used for huge Attribute.
     */
    public DiskAttribute(String name) {
        this(name, HttpConstants.DEFAULT_CHARSET);
    }

    public DiskAttribute(String name, String baseDir, boolean deleteOnExit) {
        this(name, HttpConstants.DEFAULT_CHARSET);
        this.baseDir = baseDir == null ? baseDirectory : baseDir;
        this.deleteOnExit = deleteOnExit;
    }

    public DiskAttribute(String name, long definedSize) {
        this(name, definedSize, HttpConstants.DEFAULT_CHARSET, baseDirectory, deleteOnExitTemporaryFile);
    }

    public DiskAttribute(String name, long definedSize, String baseDir, boolean deleteOnExit) {
        this(name, definedSize, HttpConstants.DEFAULT_CHARSET);
        this.baseDir = baseDir == null ? baseDirectory : baseDir;
        this.deleteOnExit = deleteOnExit;
    }

    public DiskAttribute(String name, Charset charset) {
        this(name, charset, baseDirectory, deleteOnExitTemporaryFile);
    }

    public DiskAttribute(String name, Charset charset, String baseDir, boolean deleteOnExit) {
        super(name, charset, 0);
        this.baseDir = baseDir == null ? baseDirectory : baseDir;
        this.deleteOnExit = deleteOnExit;
    }

    public DiskAttribute(String name, long definedSize, Charset charset) {
        this(name, definedSize, charset, baseDirectory, deleteOnExitTemporaryFile);
    }

    public DiskAttribute(String name, long definedSize, Charset charset, String baseDir, boolean deleteOnExit) {
        super(name, charset, definedSize);
        this.baseDir = baseDir == null ? baseDirectory : baseDir;
        this.deleteOnExit = deleteOnExit;
    }

    public DiskAttribute(String name, String value) throws IOException {
        this(name, value, HttpConstants.DEFAULT_CHARSET);
    }

    public DiskAttribute(String name, String value, Charset charset) throws IOException {
        this(name, value, charset, baseDirectory, deleteOnExitTemporaryFile);
    }

    public DiskAttribute(String name, String value, Charset charset,
            String baseDir, boolean deleteOnExit) throws IOException {
        super(name, charset, 0); // Attribute have no default size
        setValue(value);
        this.baseDir = baseDir == null ? baseDirectory : baseDir;
        this.deleteOnExit = deleteOnExit;
    }

    @Override
    public HttpDataType getHttpDataType() {
        return HttpDataType.Attribute;
    }

    /**
     * Returns the value, decoding the on-disk bytes with the current charset.
     * Loads the whole value into memory.
     */
    @Override
    public String getValue() throws IOException {
        byte [] bytes = get();
        return new String(bytes, getCharset());
    }

    /**
     * Replaces the value (erases any previous data), enforcing the max-size limit.
     */
    @Override
    public void setValue(String value) throws IOException {
        ObjectUtil.checkNotNull(value, "value");
        byte [] bytes = value.getBytes(getCharset());
        checkSize(bytes.length);
        ByteBuf buffer = wrappedBuffer(bytes);
        // Keep definedSize in sync with the actual value when a size was declared.
        if (definedSize > 0) {
            definedSize = buffer.readableBytes();
        }
        setContent(buffer);
    }

    /**
     * Appends a chunk; releases the buffer if the size limit would be exceeded,
     * then rethrows so the caller does not leak the buffer.
     */
    @Override
    public void addContent(ByteBuf buffer, boolean last) throws IOException {
        final long newDefinedSize = size + buffer.readableBytes();
        try {
            checkSize(newDefinedSize);
        } catch (IOException e) {
            buffer.release();
            throw e;
        }
        // Grow the declared size if the actual content exceeds it.
        if (definedSize > 0 && definedSize < newDefinedSize) {
            definedSize = newDefinedSize;
        }
        super.addContent(buffer, last);
    }

    // NOTE(review): equals() below compares names case-insensitively while this
    // hash is case-sensitive, so two equal attributes whose names differ only in
    // case get different hash codes (equals/hashCode contract violation). Left
    // unchanged here because sibling Attribute implementations (Memory/Mixed,
    // not visible in this chunk) appear to use the same hash — changing only
    // this class would desynchronize them; fix across all implementations.
    @Override
    public int hashCode() {
        return getName().hashCode();
    }

    /**
     * Attributes are equal when their names match, case-insensitively.
     */
    @Override
    public boolean equals(Object o) {
        if (!(o instanceof Attribute)) {
            return false;
        }
        Attribute attribute = (Attribute) o;
        return getName().equalsIgnoreCase(attribute.getName());
    }

    @Override
    public int compareTo(InterfaceHttpData o) {
        if (!(o instanceof Attribute)) {
            throw new ClassCastException("Cannot compare " + getHttpDataType() +
                    " with " + o.getHttpDataType());
        }
        return compareTo((Attribute) o);
    }

    public int compareTo(Attribute o) {
        return getName().compareToIgnoreCase(o.getName());
    }

    @Override
    public String toString() {
        try {
            return getName() + '=' + getValue();
        } catch (IOException e) {
            // Reading the on-disk value failed; surface the exception in the text.
            return getName() + '=' + e;
        }
    }

    @Override
    protected boolean deleteOnExit() {
        return deleteOnExit;
    }

    @Override
    protected String getBaseDirectory() {
        return baseDir;
    }

    @Override
    protected String getDiskFilename() {
        return getName() + postfix;
    }

    @Override
    protected String getPostfix() {
        return postfix;
    }

    @Override
    protected String getPrefix() {
        return prefix;
    }

    /**
     * Deep copy: duplicates the content (unretained source view is copied).
     */
    @Override
    public Attribute copy() {
        final ByteBuf content = content();
        return replace(content != null ? content.copy() : null);
    }

    @Override
    public Attribute duplicate() {
        final ByteBuf content = content();
        return replace(content != null ? content.duplicate() : null);
    }

    @Override
    public Attribute retainedDuplicate() {
        ByteBuf content = content();
        if (content != null) {
            content = content.retainedDuplicate();
            boolean success = false;
            try {
                Attribute duplicate = replace(content);
                success = true;
                return duplicate;
            } finally {
                // Avoid leaking the retained duplicate if replace() throws.
                if (!success) {
                    content.release();
                }
            }
        } else {
            return replace(null);
        }
    }

    /**
     * Creates a new DiskAttribute with the same name/charset/flags and the
     * given content (ownership of {@code content} is transferred).
     */
    @Override
    public Attribute replace(ByteBuf content) {
        DiskAttribute attr = new DiskAttribute(getName(), baseDir, deleteOnExit);
        attr.setCharset(getCharset());
        if (content != null) {
            try {
                attr.setContent(content);
            } catch (IOException e) {
                throw new ChannelException(e);
            }
        }
        attr.setCompleted(isCompleted());
        return attr;
    }

    @Override
    public Attribute retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public Attribute retain() {
        super.retain();
        return this;
    }

    @Override
    public Attribute touch() {
        super.touch();
        return this;
    }

    @Override
    public Attribute touch(Object hint) {
        super.touch(hint);
        return this;
    }
}
 licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.multipart;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.util.internal.ObjectUtil;

import java.io.File;
import java.io.IOException;
import java.nio.charset.Charset;

/**
 * Disk FileUpload implementation that stores file into real files
 * (backing file management provided by AbstractDiskHttpData).
 */
public class DiskFileUpload extends AbstractDiskHttpData implements FileUpload {
    // Global default base directory for temporary upload files; null means the
    // platform temp directory (resolved by the superclass).
    public static String baseDirectory;

    // Global default for whether temporary files are deleted on JVM exit.
    public static boolean deleteOnExitTemporaryFile = true;

    // Filename prefix used for the temporary files backing uploads.
    public static final String prefix = "FUp_";

    // Filename suffix used for the temporary files backing uploads.
    public static final String postfix = ".tmp";

    // Per-instance base directory; falls back to the static baseDirectory when null.
    private final String baseDir;

    // Per-instance delete-on-exit flag.
    private final boolean deleteOnExit;

    // Original client-side filename, as sent in the multipart headers.
    private String filename;

    // Content-Type header value supplied by the client; never null after construction.
    private String contentType;

    // Content-Transfer-Encoding header value; may be null.
    private String contentTransferEncoding;

    /**
     * Creates a disk-backed file upload.
     *
     * @param name    form field name - must be not null
     * @param filename client-provided filename - must be not null
     * @param contentType Content-Type - must be not null
     * @param contentTransferEncoding may be null
     * @param charset may be null (no charset recorded)
     * @param size    declared size of the upload in bytes
     * @param baseDir directory for the backing file, or null for the static default
     * @param deleteOnExit whether the backing file is removed at JVM exit
     */
    public DiskFileUpload(String name, String filename, String contentType,
            String contentTransferEncoding, Charset charset, long size, String baseDir, boolean deleteOnExit) {
        super(name, charset, size);
        setFilename(filename);
        setContentType(contentType);
        setContentTransferEncoding(contentTransferEncoding);
        this.baseDir = baseDir == null ? baseDirectory : baseDir;
        this.deleteOnExit = deleteOnExit;
    }

    /**
     * Convenience constructor using the static default base directory and
     * delete-on-exit policy.
     */
    public DiskFileUpload(String name, String filename, String contentType,
            String contentTransferEncoding, Charset charset, long size) {
        this(name, filename, contentType, contentTransferEncoding,
                charset, size, baseDirectory, deleteOnExitTemporaryFile);
    }

    @Override
    public HttpDataType getHttpDataType() {
        return HttpDataType.FileUpload;
    }

    @Override
    public String getFilename() {
        return filename;
    }

    @Override
    public void setFilename(String filename) {
        this.filename = ObjectUtil.checkNotNull(filename, "filename");
    }

    // equals/hashCode/compareTo delegate to FileUploadUtil so every FileUpload
    // implementation shares the same name-based identity semantics.
    @Override
    public int hashCode() {
        return FileUploadUtil.hashCode(this);
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof FileUpload && FileUploadUtil.equals(this, (FileUpload) o);
    }

    @Override
    public int compareTo(InterfaceHttpData o) {
        if (!(o instanceof FileUpload)) {
            throw new ClassCastException("Cannot compare " + getHttpDataType() +
                    " with " + o.getHttpDataType());
        }
        return compareTo((FileUpload) o);
    }

    public int compareTo(FileUpload o) {
        return FileUploadUtil.compareTo(this, o);
    }

    @Override
    public void setContentType(String contentType) {
        this.contentType = ObjectUtil.checkNotNull(contentType, "contentType");
    }

    @Override
    public String getContentType() {
        return contentType;
    }

    @Override
    public String getContentTransferEncoding() {
        return contentTransferEncoding;
    }

    @Override
    public void setContentTransferEncoding(String contentTransferEncoding) {
        this.contentTransferEncoding = contentTransferEncoding;
    }

    /**
     * Renders the upload as a multipart-style header block for debugging.
     */
    @Override
    public String toString() {
        File file = null;
        try {
            file = getFile();
        } catch (IOException e) {
            // Should not occur.
        }

        return HttpHeaderNames.CONTENT_DISPOSITION + ": " +
                HttpHeaderValues.FORM_DATA + "; " + HttpHeaderValues.NAME + "=\"" + getName() +
                "\"; " + HttpHeaderValues.FILENAME + "=\"" + filename + "\"\r\n" +
                HttpHeaderNames.CONTENT_TYPE + ": " + contentType +
                (getCharset() != null? "; " + HttpHeaderValues.CHARSET + '=' + getCharset().name() + "\r\n" : "\r\n") +
                HttpHeaderNames.CONTENT_LENGTH + ": " + length() + "\r\n" +
                "Completed: " + isCompleted() +
                "\r\nIsInMemory: " + isInMemory() + "\r\nRealFile: " +
                (file != null ? file.getAbsolutePath() : "null") + " DeleteAfter: " + deleteOnExit;
    }

    @Override
    protected boolean deleteOnExit() {
        return deleteOnExit;
    }

    @Override
    protected String getBaseDirectory() {
        return baseDir;
    }

    @Override
    protected String getDiskFilename() {
        // Fixed on-disk name; the client filename is kept only as metadata.
        return "upload";
    }

    @Override
    protected String getPostfix() {
        return postfix;
    }

    @Override
    protected String getPrefix() {
        return prefix;
    }

    /**
     * Deep copy: duplicates the content (unretained source view is copied).
     */
    @Override
    public FileUpload copy() {
        final ByteBuf content = content();
        return replace(content != null ? content.copy() : null);
    }

    @Override
    public FileUpload duplicate() {
        final ByteBuf content = content();
        return replace(content != null ? content.duplicate() : null);
    }

    @Override
    public FileUpload retainedDuplicate() {
        ByteBuf content = content();
        if (content != null) {
            content = content.retainedDuplicate();
            boolean success = false;
            try {
                FileUpload duplicate = replace(content);
                success = true;
                return duplicate;
            } finally {
                // Avoid leaking the retained duplicate if replace() throws.
                if (!success) {
                    content.release();
                }
            }
        } else {
            return replace(null);
        }
    }

    /**
     * Creates a new DiskFileUpload with the same metadata and the given content
     * (ownership of {@code content} is transferred).
     */
    @Override
    public FileUpload replace(ByteBuf content) {
        DiskFileUpload upload = new DiskFileUpload(
                getName(), getFilename(), getContentType(), getContentTransferEncoding(), getCharset(), size,
                baseDir, deleteOnExit);
        if (content != null) {
            try {
                upload.setContent(content);
            } catch (IOException e) {
                throw new ChannelException(e);
            }
        }
        upload.setCompleted(isCompleted());
        return upload;
    }

    @Override
    public FileUpload retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public FileUpload retain() {
        super.retain();
        return this;
    }

    @Override
    public FileUpload touch() {
        super.touch();
        return this;
    }

    @Override
    public FileUpload touch(Object hint) {
        super.touch(hint);
        return this;
    }
}
You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.multipart;

import io.netty.buffer.ByteBuf;

/**
 * FileUpload interface that could be in memory, on temporary file or any other implementations.
 *
 * Most methods are inspired from java.io.File API.
 */
public interface FileUpload extends HttpData {
    /**
     * Returns the original filename in the client's filesystem,
     * as provided by the browser (or other client software).
     * @return the original filename
     */
    String getFilename();

    /**
     * Set the original filename.
     * @param filename the client-side filename - must be not null
     */
    void setFilename(String filename);

    /**
     * Set the Content Type passed by the browser if defined.
     * @param contentType Content Type to set - must be not null
     */
    void setContentType(String contentType);

    /**
     * Returns the content type passed by the browser or null if not defined.
     * @return the content type passed by the browser or null if not defined.
     */
    String getContentType();

    /**
     * Set the Content-Transfer-Encoding type from String as 7bit, 8bit or binary.
     * @param contentTransferEncoding the transfer encoding - may be null
     */
    void setContentTransferEncoding(String contentTransferEncoding);

    /**
     * Returns the Content-Transfer-Encoding.
     * @return the Content-Transfer-Encoding, or null if none was set
     */
    String getContentTransferEncoding();

    // Covariant overrides: narrow the HttpData return types to FileUpload so
    // callers do not need to cast after copy/duplicate/retain/touch chains.

    @Override
    FileUpload copy();

    @Override
    FileUpload duplicate();

    @Override
    FileUpload retainedDuplicate();

    @Override
    FileUpload replace(ByteBuf content);

    @Override
    FileUpload retain();

    @Override
    FileUpload retain(int increment);

    @Override
    FileUpload touch();

    @Override
    FileUpload touch(Object hint);
}
+ */ +package io.netty.handler.codec.http.multipart; + +final class FileUploadUtil { + + private FileUploadUtil() { } + + static int hashCode(FileUpload upload) { + return upload.getName().hashCode(); + } + + static boolean equals(FileUpload upload1, FileUpload upload2) { + return upload1.getName().equalsIgnoreCase(upload2.getName()); + } + + static int compareTo(FileUpload upload1, FileUpload upload2) { + return upload1.getName().compareToIgnoreCase(upload2.getName()); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpData.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpData.java new file mode 100644 index 0000000..72ac59c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpData.java @@ -0,0 +1,243 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.nio.charset.Charset; + +/** + * Extended interface for InterfaceHttpData + */ +public interface HttpData extends InterfaceHttpData, ByteBufHolder { + + /** + * Returns the maxSize for this HttpData. + */ + long getMaxSize(); + + /** + * Set the maxSize for this HttpData. 
When limit will be reached, an exception will be raised. + * Setting it to (-1) means no limitation. + * + * By default, to be set from the HttpDataFactory. + */ + void setMaxSize(long maxSize); + + /** + * Check if the new size is not reaching the max limit allowed. + * The limit is always computed in terms of bytes. + */ + void checkSize(long newSize) throws IOException; + + /** + * Set the content from the ChannelBuffer (erase any previous data) + *

{@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link HttpData}. + * + * @param buffer + * must be not null + * @throws IOException + */ + void setContent(ByteBuf buffer) throws IOException; + + /** + * Add the content from the ChannelBuffer + *

{@link ByteBuf#release()} ownership of {@code buffer} is transferred to this {@link HttpData}. + * + * @param buffer + * must be not null except if last is set to False + * @param last + * True of the buffer is the last one + * @throws IOException + */ + void addContent(ByteBuf buffer, boolean last) throws IOException; + + /** + * Set the content from the file (erase any previous data) + * + * @param file + * must be not null + * @throws IOException + */ + void setContent(File file) throws IOException; + + /** + * Set the content from the inputStream (erase any previous data) + * + * @param inputStream + * must be not null + * @throws IOException + */ + void setContent(InputStream inputStream) throws IOException; + + /** + * + * @return True if the InterfaceHttpData is completed (all data are stored) + */ + boolean isCompleted(); + + /** + * Returns the size in byte of the InterfaceHttpData + * + * @return the size of the InterfaceHttpData + */ + long length(); + + /** + * Returns the defined length of the HttpData. + * + * If no Content-Length is provided in the request, the defined length is + * always 0 (whatever during decoding or in final state). + * + * If Content-Length is provided in the request, this is this given defined length. + * This value does not change, whatever during decoding or in the final state. + * + * This method could be used for instance to know the amount of bytes transmitted for + * one particular HttpData, for example one {@link FileUpload} or any known big {@link Attribute}. + * + * @return the defined length of the HttpData + */ + long definedLength(); + + /** + * Deletes the underlying storage for a file item, including deleting any + * associated temporary disk file. + */ + void delete(); + + /** + * Returns the contents of the file item as an array of bytes.
+ * Note: this method will allocate a lot of memory, if the data is currently stored on the file system. + * + * @return the contents of the file item as an array of bytes. + * @throws IOException + */ + byte[] get() throws IOException; + + /** + * Returns the content of the file item as a ByteBuf.
+ * Note: this method will allocate a lot of memory, if the data is currently stored on the file system. + * + * @return the content of the file item as a ByteBuf + * @throws IOException + */ + ByteBuf getByteBuf() throws IOException; + + /** + * Returns a ChannelBuffer for the content from the current position with at + * most length read bytes, increasing the current position of the Bytes + * read. Once it arrives at the end, it returns an EMPTY_BUFFER and it + * resets the current position to 0. + * + * @return a ChannelBuffer for the content from the current position or an + * EMPTY_BUFFER if there is no more data to return + */ + ByteBuf getChunk(int length) throws IOException; + + /** + * Returns the contents of the file item as a String, using the default + * character encoding. + * + * @return the contents of the file item as a String, using the default + * character encoding. + * @throws IOException + */ + String getString() throws IOException; + + /** + * Returns the contents of the file item as a String, using the specified + * charset. + * + * @param encoding + * the charset to use + * @return the contents of the file item as a String, using the specified + * charset. + * @throws IOException + */ + String getString(Charset encoding) throws IOException; + + /** + * Set the Charset passed by the browser if defined + * + * @param charset + * Charset to set - must be not null + */ + void setCharset(Charset charset); + + /** + * Returns the Charset passed by the browser or null if not defined. + * + * @return the Charset passed by the browser or null if not defined. + */ + Charset getCharset(); + + /** + * A convenience getMethod to write an uploaded item to disk. If a previous one + * exists, it will be deleted. Once this getMethod is called, if successful, + * the new file will be out of the cleaner of the factory that creates the + * original InterfaceHttpData object. 
+ * + * @param dest + * destination file - must be not null + * @return True if the write is successful + * @throws IOException + */ + boolean renameTo(File dest) throws IOException; + + /** + * Provides a hint as to whether or not the file contents will be read from + * memory. + * + * @return True if the file contents is in memory. + */ + boolean isInMemory(); + + /** + * + * @return the associated File if this data is represented in a file + * @exception IOException + * if this data is not represented by a file + */ + File getFile() throws IOException; + + @Override + HttpData copy(); + + @Override + HttpData duplicate(); + + @Override + HttpData retainedDuplicate(); + + @Override + HttpData replace(ByteBuf content); + + @Override + HttpData retain(); + + @Override + HttpData retain(int increment); + + @Override + HttpData touch(); + + @Override + HttpData touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpDataFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpDataFactory.java new file mode 100644 index 0000000..630c6d1 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpDataFactory.java @@ -0,0 +1,93 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.http.multipart;

import io.netty.handler.codec.http.HttpRequest;

import java.nio.charset.Charset;

/**
 * Interface to enable creation of InterfaceHttpData objects.
 * Implementations also track the data they create per request so that the
 * associated storage (e.g. temporary files) can be cleaned afterwards.
 */
public interface HttpDataFactory {

    /**
     * To set a max size limitation on fields. Exceeding it will generate an ErrorDataDecoderException.
     * A value of -1 means no limitation (default).
     *
     * @param max maximum per-field size in bytes, or -1 for unlimited
     */
    void setMaxLimit(long max);

    /**
     * @param request associated request
     * @param name name of the attribute
     * @return a new Attribute with no value
     */
    Attribute createAttribute(HttpRequest request, String name);

    /**
     * @param request associated request
     * @param name name of the attribute
     * @param definedSize defined size from request for this attribute
     * @return a new Attribute with no value
     */
    Attribute createAttribute(HttpRequest request, String name, long definedSize);

    /**
     * @param request associated request
     * @param name name of the attribute
     * @param value initial value of the attribute
     * @return a new Attribute
     */
    Attribute createAttribute(HttpRequest request, String name, String value);

    /**
     * @param request associated request
     * @param name name of the form field
     * @param filename client-side filename
     * @param contentType Content-Type of the upload
     * @param contentTransferEncoding transfer encoding, may be null
     * @param charset charset of the upload, may be null
     * @param size the size of the Uploaded file
     * @return a new FileUpload
     */
    FileUpload createFileUpload(HttpRequest request, String name, String filename,
            String contentType, String contentTransferEncoding, Charset charset,
            long size);

    /**
     * Remove the given InterfaceHttpData from clean list (will not delete the file, except if the file
     * is still a temporary one as setup at construction).
     *
     * @param request associated request
     * @param data the data item to stop tracking
     */
    void removeHttpDataFromClean(HttpRequest request, InterfaceHttpData data);

    /**
     * Remove all InterfaceHttpData from virtual File storage from clean list for the request.
     *
     * @param request associated request
     */
    void cleanRequestHttpData(HttpRequest request);

    /**
     * Remove all InterfaceHttpData from virtual File storage from clean list for all requests.
     */
    void cleanAllHttpData();

    /**
     * @deprecated Use {@link #cleanRequestHttpData(HttpRequest)} instead.
     */
    @Deprecated
    void cleanRequestHttpDatas(HttpRequest request);

    /**
     * @deprecated Use {@link #cleanAllHttpData()} instead.
     */
    @Deprecated
    void cleanAllHttpDatas();
}
 */
package io.netty.handler.codec.http.multipart;

import io.netty.buffer.ByteBuf;
import io.netty.handler.codec.http.HttpConstants;

/**
 * Shared Static object between HttpMessageDecoder, HttpPostRequestDecoder and HttpPostRequestEncoder
 */
final class HttpPostBodyUtil {

    // Chunk size (in bytes) used when reading/writing body data.
    public static final int chunkSize = 8096;

    /**
     * Default Content-Type in binary form
     */
    public static final String DEFAULT_BINARY_CONTENT_TYPE = "application/octet-stream";

    /**
     * Default Content-Type in Text form
     */
    public static final String DEFAULT_TEXT_CONTENT_TYPE = "text/plain";

    /**
     * Allowed mechanism for multipart
     * mechanism := "7bit" / "8bit" / "binary"
     * Not allowed: "quoted-printable" / "base64"
     */
    public enum TransferEncodingMechanism {
        /**
         * Default encoding
         */
        BIT7("7bit"),
        /**
         * Short lines but not in ASCII - no encoding
         */
        BIT8("8bit"),
        /**
         * Could be long text not in ASCII - no encoding
         */
        BINARY("binary");

        // Wire-format token for this mechanism (e.g. "7bit").
        private final String value;

        TransferEncodingMechanism(String value) {
            this.value = value;
        }

        /**
         * @return the wire-format token for this mechanism
         */
        public String value() {
            return value;
        }

        @Override
        public String toString() {
            return value;
        }
    }

    // Utility class: no instances.
    private HttpPostBodyUtil() {
    }

    /**
     * This class intends to decrease the CPU in seeking ahead some bytes in
     * HttpPostRequestDecoder
     */
    static class SeekAheadOptimize {
        byte[] bytes;     // backing array of the wrapped buffer
        int readerIndex;  // buffer readerIndex captured at construction
        int pos;          // current absolute position within bytes
        int origPos;      // absolute position corresponding to readerIndex
        int limit;        // absolute position corresponding to writerIndex
        ByteBuf buffer;   // the wrapped buffer

        /**
         * @param buffer buffer with a backing byte array
         */
        SeekAheadOptimize(ByteBuf buffer) {
            if (!buffer.hasArray()) {
                throw new IllegalArgumentException("buffer hasn't backing byte array");
            }
            this.buffer = buffer;
            bytes = buffer.array();
            readerIndex = buffer.readerIndex();
            origPos = pos = buffer.arrayOffset() + readerIndex;
            limit = buffer.arrayOffset() + buffer.writerIndex();
        }

        /**
         * Rewinds the absolute position by {@code minus} bytes and writes the
         * equivalent readerIndex back into the wrapped buffer.
         *
         * @param minus this value will be used as (currentPos - minus) to set
         *              the current readerIndex in the buffer.
         */
        void setReadPosition(int minus) {
            pos -= minus;
            readerIndex = getReadPosition(pos);
            buffer.readerIndex(readerIndex);
        }

        /**
         * @param index raw index of the array (pos in general)
         * @return the value equivalent of raw index to be used in readerIndex(value)
         */
        int getReadPosition(int index) {
            return index - origPos + readerIndex;
        }
    }

    /**
     * Find the first non whitespace
     *
     * @param sb the string to scan
     * @param offset index to start scanning from
     * @return the rank of the first non whitespace (sb.length() if none found)
     */
    static int findNonWhitespace(String sb, int offset) {
        int result;
        for (result = offset; result < sb.length(); result ++) {
            if (!Character.isWhitespace(sb.charAt(result))) {
                break;
            }
        }
        return result;
    }

    /**
     * Find the end of String (ignoring trailing whitespace)
     *
     * @param sb the string to scan
     * @return the rank of the end of string
     */
    static int findEndOfString(String sb) {
        int result;
        for (result = sb.length(); result > 0; result --) {
            if (!Character.isWhitespace(sb.charAt(result - 1))) {
                break;
            }
        }
        return result;
    }

    /**
     * Try to find first LF or CRLF as Line Breaking
     *
     * @param buffer the buffer to search in
     * @param index the index to start from in the buffer
     * @return a relative position from index > 0 if LF or CRLF is found
     *         or < 0 if not found
     */
    static int findLineBreak(ByteBuf buffer, int index) {
        int toRead = buffer.readableBytes() - (index - buffer.readerIndex());
        int posFirstChar = buffer.bytesBefore(index, toRead, HttpConstants.LF);
        if (posFirstChar == -1) {
            // No LF, so neither CRLF
            return -1;
        }
        if (posFirstChar > 0 && buffer.getByte(index + posFirstChar - 1) == HttpConstants.CR) {
            // Step back onto the CR so the returned position covers the full CRLF
            posFirstChar--;
        }
        return posFirstChar;
    }

    /**
     * Try to find last LF or CRLF as Line Breaking
     *
     * @param buffer the buffer to search in
     * @param index the index to start from in the buffer
     * @return a relative position from index > 0 if LF or CRLF is found
     *         or < 0 if not found
     */
    static int findLastLineBreak(ByteBuf buffer, int index) {
        int candidate = findLineBreak(buffer, index);
        int findCRLF = 0;
        if (candidate >= 0) {
            // findCRLF holds the width of the line break just passed (2 for CRLF, 1 for LF)
            if (buffer.getByte(index + candidate) == HttpConstants.CR) {
                findCRLF = 2;
            } else {
                findCRLF = 1;
            }
            candidate += findCRLF;
        }
        int next;
        // Keep advancing past successive line breaks until none remain
        while (candidate > 0 && (next = findLineBreak(buffer, index + candidate)) >= 0) {
            candidate += next;
            if (buffer.getByte(index + candidate) == HttpConstants.CR) {
                findCRLF = 2;
            } else {
                findCRLF = 1;
            }
            candidate += findCRLF;
        }
        // Back off the width of the last break so the result points at its start
        return candidate - findCRLF;
    }

    /**
     * Try to find the delimiter, with LF or CRLF in front of it (added as delimiters) if needed
     *
     * @param buffer the buffer to search in
     * @param index the index to start from in the buffer
     * @param delimiter the delimiter as byte array
     * @param precededByLineBreak true if it must be preceded by LF or CRLF, else false
     * @return a relative position from index > 0 if delimiter found designing the start of it
     *         (including LF or CRLF is asked)
     *         or a number < 0 if delimiter is not found
     * @throws IndexOutOfBoundsException
     *             if {@code offset + delimiter.length} is greater than {@code buffer.capacity}
     */
    static int findDelimiter(ByteBuf buffer, int index, byte[] delimiter, boolean precededByLineBreak) {
        final int delimiterLength = delimiter.length;
        final int readerIndex = buffer.readerIndex();
        final int writerIndex = buffer.writerIndex();
        int toRead = writerIndex - index;
        int newOffset = index;
        boolean delimiterNotFound = true;
        while (delimiterNotFound && delimiterLength <= toRead) {
            // Find first position: delimiter
            int posDelimiter = buffer.bytesBefore(newOffset, toRead, delimiter[0]);
            if (posDelimiter < 0) {
                return -1;
            }
            newOffset += posDelimiter;
            toRead -= posDelimiter;
            // Now check for delimiter
            if (toRead >= delimiterLength) {
                delimiterNotFound = false;
                for (int i = 0; i < delimiterLength; i++) {
                    if (buffer.getByte(newOffset + i) != delimiter[i]) {
                        newOffset++;
                        toRead--;
                        delimiterNotFound = true;
                        break;
                    }
                }
            }
            if (!delimiterNotFound) {
                // Delimiter found, find if necessary: LF or CRLF
                if (precededByLineBreak && newOffset > readerIndex) {
                    if (buffer.getByte(newOffset - 1) == HttpConstants.LF) {
                        newOffset--;
                        // Check if CR before: not mandatory to be there
                        if (newOffset > readerIndex && buffer.getByte(newOffset - 1) == HttpConstants.CR) {
                            newOffset--;
                        }
                    } else {
                        // Delimiter with Line Break could be further: iterate after first char of delimiter
                        newOffset++;
                        toRead--;
                        delimiterNotFound = true;
                        continue;
                    }
                }
                return newOffset - readerIndex;
            }
        }
        return -1;
    }
}
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java
new file mode 100644
index 0000000..09bd880
--- /dev/null
+++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostMultipartRequestDecoder.java
@@ -0,0 +1,1394 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http.QueryStringDecoder; +import io.netty.handler.codec.http.multipart.HttpPostBodyUtil.SeekAheadOptimize; +import io.netty.handler.codec.http.multipart.HttpPostBodyUtil.TransferEncodingMechanism; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.EndOfDataDecoderException; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.ErrorDataDecoderException; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.MultiPartStatus; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.nio.charset.IllegalCharsetNameException; +import java.nio.charset.UnsupportedCharsetException; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * This decoder will decode Body and can handle POST BODY. + * + * You MUST call {@link #destroy()} after completion to release all resources. 
+ * + */ +public class HttpPostMultipartRequestDecoder implements InterfaceHttpPostRequestDecoder { + + /** + * Factory used to create InterfaceHttpData + */ + private final HttpDataFactory factory; + + /** + * Request to decode + */ + private final HttpRequest request; + + /** + * Default charset to use + */ + private Charset charset; + + /** + * Does the last chunk already received + */ + private boolean isLastChunk; + + /** + * HttpDatas from Body + */ + private final List bodyListHttpData = new ArrayList(); + + /** + * HttpDatas as Map from Body + */ + private final Map> bodyMapHttpData = new TreeMap>( + CaseIgnoringComparator.INSTANCE); + + /** + * The current channelBuffer + */ + private ByteBuf undecodedChunk; + + /** + * Body HttpDatas current position + */ + private int bodyListHttpDataRank; + + /** + * If multipart, this is the boundary for the global multipart + */ + private final String multipartDataBoundary; + + /** + * If multipart, there could be internal multiparts (mixed) to the global + * multipart. Only one level is allowed. 
+ */ + private String multipartMixedBoundary; + + /** + * Current getStatus + */ + private MultiPartStatus currentStatus = MultiPartStatus.NOTSTARTED; + + /** + * Used in Multipart + */ + private Map currentFieldAttributes; + + /** + * The current FileUpload that is currently in decode process + */ + private FileUpload currentFileUpload; + + /** + * The current Attribute that is currently in decode process + */ + private Attribute currentAttribute; + + private boolean destroyed; + + private int discardThreshold = HttpPostRequestDecoder.DEFAULT_DISCARD_THRESHOLD; + + /** + * + * @param request + * the request to decode + * @throws NullPointerException + * for request + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostMultipartRequestDecoder(HttpRequest request) { + this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE), request, HttpConstants.DEFAULT_CHARSET); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to decode + * @throws NullPointerException + * for request or factory + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostMultipartRequestDecoder(HttpDataFactory factory, HttpRequest request) { + this(factory, request, HttpConstants.DEFAULT_CHARSET); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to decode + * @param charset + * the charset to use as default + * @throws NullPointerException + * for request or charset or factory + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostMultipartRequestDecoder(HttpDataFactory factory, HttpRequest request, Charset charset) { + this.request = checkNotNull(request, "request"); + this.charset = checkNotNull(charset, "charset"); + this.factory 
= checkNotNull(factory, "factory"); + // Fill default values + + String contentTypeValue = this.request.headers().get(HttpHeaderNames.CONTENT_TYPE); + if (contentTypeValue == null) { + throw new ErrorDataDecoderException("No '" + HttpHeaderNames.CONTENT_TYPE + "' header present."); + } + + String[] dataBoundary = HttpPostRequestDecoder.getMultipartDataBoundary(contentTypeValue); + if (dataBoundary != null) { + multipartDataBoundary = dataBoundary[0]; + if (dataBoundary.length > 1 && dataBoundary[1] != null) { + try { + this.charset = Charset.forName(dataBoundary[1]); + } catch (IllegalCharsetNameException e) { + throw new ErrorDataDecoderException(e); + } + } + } else { + multipartDataBoundary = null; + } + currentStatus = MultiPartStatus.HEADERDELIMITER; + + try { + if (request instanceof HttpContent) { + // Offer automatically if the given request is als type of HttpContent + // See #1089 + offer((HttpContent) request); + } else { + parseBody(); + } + } catch (Throwable e) { + destroy(); + PlatformDependent.throwException(e); + } + } + + private void checkDestroyed() { + if (destroyed) { + throw new IllegalStateException(HttpPostMultipartRequestDecoder.class.getSimpleName() + + " was destroyed already"); + } + } + + /** + * True if this request is a Multipart request + * + * @return True if this request is a Multipart request + */ + @Override + public boolean isMultipart() { + checkDestroyed(); + return true; + } + + /** + * Set the amount of bytes after which read bytes in the buffer should be discarded. + * Setting this lower gives lower memory usage but with the overhead of more memory copies. + * Use {@code 0} to disable it. + */ + @Override + public void setDiscardThreshold(int discardThreshold) { + this.discardThreshold = checkPositiveOrZero(discardThreshold, "discardThreshold"); + } + + /** + * Return the threshold in bytes after which read data in the buffer should be discarded. 
+ */ + @Override + public int getDiscardThreshold() { + return discardThreshold; + } + + /** + * This getMethod returns a List of all HttpDatas from body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return the list of HttpDatas from Body part for POST getMethod + * @throws NotEnoughDataDecoderException + * Need more chunks + */ + @Override + public List getBodyHttpDatas() { + checkDestroyed(); + + if (!isLastChunk) { + throw new NotEnoughDataDecoderException(); + } + return bodyListHttpData; + } + + /** + * This getMethod returns a List of all HttpDatas with the given name from + * body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return All Body HttpDatas with the given name (ignore case) + * @throws NotEnoughDataDecoderException + * need more chunks + */ + @Override + public List getBodyHttpDatas(String name) { + checkDestroyed(); + + if (!isLastChunk) { + throw new NotEnoughDataDecoderException(); + } + return bodyMapHttpData.get(name); + } + + /** + * This getMethod returns the first InterfaceHttpData with the given name from + * body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return The first Body InterfaceHttpData with the given name (ignore + * case) + * @throws NotEnoughDataDecoderException + * need more chunks + */ + @Override + public InterfaceHttpData getBodyHttpData(String name) { + checkDestroyed(); + + if (!isLastChunk) { + throw new NotEnoughDataDecoderException(); + } + List list = bodyMapHttpData.get(name); + if (list != null) { + return list.get(0); + } + return null; + } + + /** + * Initialized the internals from a new chunk + * + * @param content + * the new received chunk + * @throws ErrorDataDecoderException + * if there is a problem with the charset decoding or other + * errors + */ + @Override + public HttpPostMultipartRequestDecoder offer(HttpContent content) { + checkDestroyed(); + + if (content instanceof LastHttpContent) { + isLastChunk = true; + } + + ByteBuf buf = content.content(); + if (undecodedChunk == null) { + undecodedChunk = + // Since the Handler will release the incoming later on, we need to copy it + // + // We are explicit allocate a buffer and NOT calling copy() as otherwise it may set a maxCapacity + // which is not really usable for us as we may exceed it once we add more bytes. + buf.alloc().buffer(buf.readableBytes()).writeBytes(buf); + } else { + undecodedChunk.writeBytes(buf); + } + parseBody(); + if (undecodedChunk != null && undecodedChunk.writerIndex() > discardThreshold) { + if (undecodedChunk.refCnt() == 1) { + // It's safe to call discardBytes() as we are the only owner of the buffer. + undecodedChunk.discardReadBytes(); + } else { + // There seems to be multiple references of the buffer. Let's copy the data and release the buffer to + // ensure we can give back memory to the system. 
+ ByteBuf buffer = undecodedChunk.alloc().buffer(undecodedChunk.readableBytes()); + buffer.writeBytes(undecodedChunk); + undecodedChunk.release(); + undecodedChunk = buffer; + } + } + return this; + } + + /** + * True if at current getStatus, there is an available decoded + * InterfaceHttpData from the Body. + * + * This getMethod works for chunked and not chunked request. + * + * @return True if at current getStatus, there is a decoded InterfaceHttpData + * @throws EndOfDataDecoderException + * No more data will be available + */ + @Override + public boolean hasNext() { + checkDestroyed(); + + if (currentStatus == MultiPartStatus.EPILOGUE) { + // OK except if end of list + if (bodyListHttpDataRank >= bodyListHttpData.size()) { + throw new EndOfDataDecoderException(); + } + } + return !bodyListHttpData.isEmpty() && bodyListHttpDataRank < bodyListHttpData.size(); + } + + /** + * Returns the next available InterfaceHttpData or null if, at the time it + * is called, there is no more available InterfaceHttpData. A subsequent + * call to offer(httpChunk) could enable more data. 
+ * + * Be sure to call {@link InterfaceHttpData#release()} after you are done + * with processing to make sure to not leak any resources + * + * @return the next available InterfaceHttpData or null if none + * @throws EndOfDataDecoderException + * No more data will be available + */ + @Override + public InterfaceHttpData next() { + checkDestroyed(); + + if (hasNext()) { + return bodyListHttpData.get(bodyListHttpDataRank++); + } + return null; + } + + @Override + public InterfaceHttpData currentPartialHttpData() { + if (currentFileUpload != null) { + return currentFileUpload; + } else { + return currentAttribute; + } + } + + /** + * This getMethod will parse as much as possible data and fill the list and map + * + * @throws ErrorDataDecoderException + * if there is a problem with the charset decoding or other + * errors + */ + private void parseBody() { + if (currentStatus == MultiPartStatus.PREEPILOGUE || currentStatus == MultiPartStatus.EPILOGUE) { + if (isLastChunk) { + currentStatus = MultiPartStatus.EPILOGUE; + } + return; + } + parseBodyMultipart(); + } + + /** + * Utility function to add a new decoded data + */ + protected void addHttpData(InterfaceHttpData data) { + if (data == null) { + return; + } + List datas = bodyMapHttpData.get(data.getName()); + if (datas == null) { + datas = new ArrayList(1); + bodyMapHttpData.put(data.getName(), datas); + } + datas.add(data); + bodyListHttpData.add(data); + } + + /** + * Parse the Body for multipart + * + * @throws ErrorDataDecoderException + * if there is a problem with the charset decoding or other + * errors + */ + private void parseBodyMultipart() { + if (undecodedChunk == null || undecodedChunk.readableBytes() == 0) { + // nothing to decode + return; + } + InterfaceHttpData data = decodeMultipart(currentStatus); + while (data != null) { + addHttpData(data); + if (currentStatus == MultiPartStatus.PREEPILOGUE || currentStatus == MultiPartStatus.EPILOGUE) { + break; + } + data = decodeMultipart(currentStatus); + 
        }
    }

    /**
     * Decode a multipart request by pieces.
     * <p>
     * Grammar handled (one state per non-terminal):
     * <pre>
     * NOTSTARTED PREAMBLE (
     *   (HEADERDELIMITER DISPOSITION (FIELD | FILEUPLOAD))*
     *   (HEADERDELIMITER DISPOSITION MIXEDPREAMBLE
     *     (MIXEDDELIMITER MIXEDDISPOSITION MIXEDFILEUPLOAD)+
     *   MIXEDCLOSEDELIMITER)*
     * CLOSEDELIMITER)+ EPILOGUE
     * </pre>
     *
     * Inspired from HttpMessageDecoder
     *
     * @param state the multipart state to decode from
     * @return the next decoded InterfaceHttpData or null if none until now.
     * @throws ErrorDataDecoderException
     *             if an error occurs
     */
    private InterfaceHttpData decodeMultipart(MultiPartStatus state) {
        switch (state) {
        case NOTSTARTED:
            throw new ErrorDataDecoderException("Should not be called with the current getStatus");
        case PREAMBLE:
            // Content-type: multipart/form-data, boundary=AaB03x
            throw new ErrorDataDecoderException("Should not be called with the current getStatus");
        case HEADERDELIMITER: {
            // --AaB03x or --AaB03x--
            return findMultipartDelimiter(multipartDataBoundary, MultiPartStatus.DISPOSITION,
                    MultiPartStatus.PREEPILOGUE);
        }
        case DISPOSITION: {
            // content-disposition: form-data; name="field1"
            // content-disposition: form-data; name="pics"; filename="file1.txt"
            // and other immediate values like
            // Content-type: image/gif
            // Content-Type: text/plain
            // Content-Type: text/plain; charset=ISO-8859-1
            // Content-Transfer-Encoding: binary
            // The following line implies a change of mode (mixed mode)
            // Content-type: multipart/mixed, boundary=BbC04y
            return findMultipartDisposition();
        }
        case FIELD: {
            // Now get value according to Content-Type and Charset
            Charset localCharset = null;
            Attribute charsetAttribute = currentFieldAttributes.get(HttpHeaderValues.CHARSET);
            if (charsetAttribute != null) {
                try {
                    localCharset = Charset.forName(charsetAttribute.getValue());
                } catch (IOException e) {
                    throw new ErrorDataDecoderException(e);
                } catch (UnsupportedCharsetException e) {
                    throw new ErrorDataDecoderException(e);
                }
            }
            Attribute nameAttribute = currentFieldAttributes.get(HttpHeaderValues.NAME);
            if (currentAttribute == null) {
                Attribute lengthAttribute = currentFieldAttributes
                        .get(HttpHeaderNames.CONTENT_LENGTH);
                long size;
                try {
                    // A malformed Content-Length silently falls back to 0 (definedSize unknown)
                    size = lengthAttribute != null? Long.parseLong(lengthAttribute
                            .getValue()) : 0L;
                } catch (IOException e) {
                    throw new ErrorDataDecoderException(e);
                } catch (NumberFormatException ignored) {
                    size = 0;
                }
                try {
                    if (size > 0) {
                        currentAttribute = factory.createAttribute(request,
                                cleanString(nameAttribute.getValue()), size);
                    } else {
                        currentAttribute = factory.createAttribute(request,
                                cleanString(nameAttribute.getValue()));
                    }
                } catch (NullPointerException e) {
                    throw new ErrorDataDecoderException(e);
                } catch (IllegalArgumentException e) {
                    throw new ErrorDataDecoderException(e);
                } catch (IOException e) {
                    throw new ErrorDataDecoderException(e);
                }
                if (localCharset != null) {
                    currentAttribute.setCharset(localCharset);
                }
            }
            // load data
            if (!loadDataMultipartOptimized(undecodedChunk, multipartDataBoundary, currentAttribute)) {
                // Delimiter is not found. Need more chunks.
                return null;
            }
            Attribute finalAttribute = currentAttribute;
            currentAttribute = null;
            currentFieldAttributes = null;
            // ready to load the next one
            currentStatus = MultiPartStatus.HEADERDELIMITER;
            return finalAttribute;
        }
        case FILEUPLOAD: {
            // eventually restart from existing FileUpload
            return getFileUpload(multipartDataBoundary);
        }
        case MIXEDDELIMITER: {
            // --AaB03x or --AaB03x--
            // Note that currentFieldAttributes exists
            return findMultipartDelimiter(multipartMixedBoundary, MultiPartStatus.MIXEDDISPOSITION,
                    MultiPartStatus.HEADERDELIMITER);
        }
        case MIXEDDISPOSITION: {
            return findMultipartDisposition();
        }
        case MIXEDFILEUPLOAD: {
            // eventually restart from existing FileUpload
            return getFileUpload(multipartMixedBoundary);
        }
        case PREEPILOGUE:
            return null;
        case EPILOGUE:
            return null;
        default:
            throw new ErrorDataDecoderException("Shouldn't reach here.");
        }
    }

    /**
     * Skip control Characters
     *
     * @throws NotEnoughDataDecoderException
     */
    private static void skipControlCharacters(ByteBuf undecodedChunk) {
        if
(!undecodedChunk.hasArray()) { + try { + skipControlCharactersStandard(undecodedChunk); + } catch (IndexOutOfBoundsException e1) { + throw new NotEnoughDataDecoderException(e1); + } + return; + } + SeekAheadOptimize sao = new SeekAheadOptimize(undecodedChunk); + while (sao.pos < sao.limit) { + char c = (char) (sao.bytes[sao.pos++] & 0xFF); + if (!Character.isISOControl(c) && !Character.isWhitespace(c)) { + sao.setReadPosition(1); + return; + } + } + throw new NotEnoughDataDecoderException("Access out of bounds"); + } + + private static void skipControlCharactersStandard(ByteBuf undecodedChunk) { + for (;;) { + char c = (char) undecodedChunk.readUnsignedByte(); + if (!Character.isISOControl(c) && !Character.isWhitespace(c)) { + undecodedChunk.readerIndex(undecodedChunk.readerIndex() - 1); + break; + } + } + } + + /** + * Find the next Multipart Delimiter + * + * @param delimiter + * delimiter to find + * @param dispositionStatus + * the next getStatus if the delimiter is a start + * @param closeDelimiterStatus + * the next getStatus if the delimiter is a close delimiter + * @return the next InterfaceHttpData if any + * @throws ErrorDataDecoderException + */ + private InterfaceHttpData findMultipartDelimiter(String delimiter, MultiPartStatus dispositionStatus, + MultiPartStatus closeDelimiterStatus) { + // --AaB03x or --AaB03x-- + int readerIndex = undecodedChunk.readerIndex(); + try { + skipControlCharacters(undecodedChunk); + } catch (NotEnoughDataDecoderException ignored) { + undecodedChunk.readerIndex(readerIndex); + return null; + } + skipOneLine(); + String newline; + try { + newline = readDelimiterOptimized(undecodedChunk, delimiter, charset); + } catch (NotEnoughDataDecoderException ignored) { + undecodedChunk.readerIndex(readerIndex); + return null; + } + if (newline.equals(delimiter)) { + currentStatus = dispositionStatus; + return decodeMultipart(dispositionStatus); + } + if (newline.equals(delimiter + "--")) { + // CLOSEDELIMITER or MIXED CLOSEDELIMITER 
found + currentStatus = closeDelimiterStatus; + if (currentStatus == MultiPartStatus.HEADERDELIMITER) { + // MIXEDCLOSEDELIMITER + // end of the Mixed part + currentFieldAttributes = null; + return decodeMultipart(MultiPartStatus.HEADERDELIMITER); + } + return null; + } + undecodedChunk.readerIndex(readerIndex); + throw new ErrorDataDecoderException("No Multipart delimiter found"); + } + + /** + * Find the next Disposition + * + * @return the next InterfaceHttpData if any + * @throws ErrorDataDecoderException + */ + private InterfaceHttpData findMultipartDisposition() { + int readerIndex = undecodedChunk.readerIndex(); + if (currentStatus == MultiPartStatus.DISPOSITION) { + currentFieldAttributes = new TreeMap(CaseIgnoringComparator.INSTANCE); + } + // read many lines until empty line with newline found! Store all data + while (!skipOneLine()) { + String newline; + try { + skipControlCharacters(undecodedChunk); + newline = readLineOptimized(undecodedChunk, charset); + } catch (NotEnoughDataDecoderException ignored) { + undecodedChunk.readerIndex(readerIndex); + return null; + } + String[] contents = splitMultipartHeader(newline); + if (HttpHeaderNames.CONTENT_DISPOSITION.contentEqualsIgnoreCase(contents[0])) { + boolean checkSecondArg; + if (currentStatus == MultiPartStatus.DISPOSITION) { + checkSecondArg = HttpHeaderValues.FORM_DATA.contentEqualsIgnoreCase(contents[1]); + } else { + checkSecondArg = HttpHeaderValues.ATTACHMENT.contentEqualsIgnoreCase(contents[1]) + || HttpHeaderValues.FILE.contentEqualsIgnoreCase(contents[1]); + } + if (checkSecondArg) { + // read next values and store them in the map as Attribute + for (int i = 2; i < contents.length; i++) { + String[] values = contents[i].split("=", 2); + Attribute attribute; + try { + attribute = getContentDispositionAttribute(values); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } + 
currentFieldAttributes.put(attribute.getName(), attribute); + } + } + } else if (HttpHeaderNames.CONTENT_TRANSFER_ENCODING.contentEqualsIgnoreCase(contents[0])) { + Attribute attribute; + try { + attribute = factory.createAttribute(request, HttpHeaderNames.CONTENT_TRANSFER_ENCODING.toString(), + cleanString(contents[1])); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } + + currentFieldAttributes.put(HttpHeaderNames.CONTENT_TRANSFER_ENCODING, attribute); + } else if (HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(contents[0])) { + Attribute attribute; + try { + attribute = factory.createAttribute(request, HttpHeaderNames.CONTENT_LENGTH.toString(), + cleanString(contents[1])); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } + + currentFieldAttributes.put(HttpHeaderNames.CONTENT_LENGTH, attribute); + } else if (HttpHeaderNames.CONTENT_TYPE.contentEqualsIgnoreCase(contents[0])) { + // Take care of possible "multipart/mixed" + if (HttpHeaderValues.MULTIPART_MIXED.contentEqualsIgnoreCase(contents[1])) { + if (currentStatus == MultiPartStatus.DISPOSITION) { + String values = StringUtil.substringAfter(contents[2], '='); + multipartMixedBoundary = "--" + values; + currentStatus = MultiPartStatus.MIXEDDELIMITER; + return decodeMultipart(MultiPartStatus.MIXEDDELIMITER); + } else { + throw new ErrorDataDecoderException("Mixed Multipart found in a previous Mixed Multipart"); + } + } else { + for (int i = 1; i < contents.length; i++) { + final String charsetHeader = HttpHeaderValues.CHARSET.toString(); + if (contents[i].regionMatches(true, 0, charsetHeader, 0, charsetHeader.length())) { + String values = StringUtil.substringAfter(contents[i], '='); + Attribute attribute; + try { + attribute = factory.createAttribute(request, charsetHeader, 
cleanString(values)); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } + currentFieldAttributes.put(HttpHeaderValues.CHARSET, attribute); + } else if (contents[i].contains("=")) { + String name = StringUtil.substringBefore(contents[i], '='); + String values = StringUtil.substringAfter(contents[i], '='); + Attribute attribute; + try { + attribute = factory.createAttribute(request, cleanString(name), values); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } + currentFieldAttributes.put(name, attribute); + } else { + Attribute attribute; + try { + attribute = factory.createAttribute(request, + cleanString(contents[0]), contents[i]); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } + currentFieldAttributes.put(attribute.getName(), attribute); + } + } + } + } + } + // Is it a FileUpload + Attribute filenameAttribute = currentFieldAttributes.get(HttpHeaderValues.FILENAME); + if (currentStatus == MultiPartStatus.DISPOSITION) { + if (filenameAttribute != null) { + // FileUpload + currentStatus = MultiPartStatus.FILEUPLOAD; + // do not change the buffer position + return decodeMultipart(MultiPartStatus.FILEUPLOAD); + } else { + // Field + currentStatus = MultiPartStatus.FIELD; + // do not change the buffer position + return decodeMultipart(MultiPartStatus.FIELD); + } + } else { + if (filenameAttribute != null) { + // FileUpload + currentStatus = MultiPartStatus.MIXEDFILEUPLOAD; + // do not change the buffer position + return decodeMultipart(MultiPartStatus.MIXEDFILEUPLOAD); + } else { + // Field is not supported in MIXED mode + throw new ErrorDataDecoderException("Filename not found"); + } + } + } + + private static final String 
FILENAME_ENCODED = HttpHeaderValues.FILENAME.toString() + '*'; + + private Attribute getContentDispositionAttribute(String... values) { + String name = cleanString(values[0]); + String value = values[1]; + + // Filename can be token, quoted or encoded. See https://tools.ietf.org/html/rfc5987 + if (HttpHeaderValues.FILENAME.contentEquals(name)) { + // Value is quoted or token. Strip if quoted: + int last = value.length() - 1; + if (last > 0 && + value.charAt(0) == HttpConstants.DOUBLE_QUOTE && + value.charAt(last) == HttpConstants.DOUBLE_QUOTE) { + value = value.substring(1, last); + } + } else if (FILENAME_ENCODED.equals(name)) { + try { + name = HttpHeaderValues.FILENAME.toString(); + String[] split = cleanString(value).split("'", 3); + value = QueryStringDecoder.decodeComponent(split[2], Charset.forName(split[0])); + } catch (ArrayIndexOutOfBoundsException e) { + throw new ErrorDataDecoderException(e); + } catch (UnsupportedCharsetException e) { + throw new ErrorDataDecoderException(e); + } + } else { + // otherwise we need to clean the value + value = cleanString(value); + } + return factory.createAttribute(request, name, value); + } + + /** + * Get the FileUpload (new one or current one) + * + * @param delimiter + * the delimiter to use + * @return the InterfaceHttpData if any + * @throws ErrorDataDecoderException + */ + protected InterfaceHttpData getFileUpload(String delimiter) { + // eventually restart from existing FileUpload + // Now get value according to Content-Type and Charset + Attribute encoding = currentFieldAttributes.get(HttpHeaderNames.CONTENT_TRANSFER_ENCODING); + Charset localCharset = charset; + // Default + TransferEncodingMechanism mechanism = TransferEncodingMechanism.BIT7; + if (encoding != null) { + String code; + try { + code = encoding.getValue().toLowerCase(); + } catch (IOException e) { + throw new ErrorDataDecoderException(e); + } + if (code.equals(HttpPostBodyUtil.TransferEncodingMechanism.BIT7.value())) { + localCharset = 
CharsetUtil.US_ASCII; + } else if (code.equals(HttpPostBodyUtil.TransferEncodingMechanism.BIT8.value())) { + localCharset = CharsetUtil.ISO_8859_1; + mechanism = TransferEncodingMechanism.BIT8; + } else if (code.equals(HttpPostBodyUtil.TransferEncodingMechanism.BINARY.value())) { + // no real charset, so let the default + mechanism = TransferEncodingMechanism.BINARY; + } else { + throw new ErrorDataDecoderException("TransferEncoding Unknown: " + code); + } + } + Attribute charsetAttribute = currentFieldAttributes.get(HttpHeaderValues.CHARSET); + if (charsetAttribute != null) { + try { + localCharset = Charset.forName(charsetAttribute.getValue()); + } catch (IOException e) { + throw new ErrorDataDecoderException(e); + } catch (UnsupportedCharsetException e) { + throw new ErrorDataDecoderException(e); + } + } + if (currentFileUpload == null) { + Attribute filenameAttribute = currentFieldAttributes.get(HttpHeaderValues.FILENAME); + Attribute nameAttribute = currentFieldAttributes.get(HttpHeaderValues.NAME); + Attribute contentTypeAttribute = currentFieldAttributes.get(HttpHeaderNames.CONTENT_TYPE); + Attribute lengthAttribute = currentFieldAttributes.get(HttpHeaderNames.CONTENT_LENGTH); + long size; + try { + size = lengthAttribute != null ? 
Long.parseLong(lengthAttribute.getValue()) : 0L; + } catch (IOException e) { + throw new ErrorDataDecoderException(e); + } catch (NumberFormatException ignored) { + size = 0; + } + try { + String contentType; + if (contentTypeAttribute != null) { + contentType = contentTypeAttribute.getValue(); + } else { + contentType = HttpPostBodyUtil.DEFAULT_BINARY_CONTENT_TYPE; + } + currentFileUpload = factory.createFileUpload(request, + cleanString(nameAttribute.getValue()), cleanString(filenameAttribute.getValue()), + contentType, mechanism.value(), localCharset, + size); + } catch (NullPointerException e) { + throw new ErrorDataDecoderException(e); + } catch (IllegalArgumentException e) { + throw new ErrorDataDecoderException(e); + } catch (IOException e) { + throw new ErrorDataDecoderException(e); + } + } + // load data as much as possible + if (!loadDataMultipartOptimized(undecodedChunk, delimiter, currentFileUpload)) { + // Delimiter is not found. Need more chunks. + return null; + } + if (currentFileUpload.isCompleted()) { + // ready to load the next one + if (currentStatus == MultiPartStatus.FILEUPLOAD) { + currentStatus = MultiPartStatus.HEADERDELIMITER; + currentFieldAttributes = null; + } else { + currentStatus = MultiPartStatus.MIXEDDELIMITER; + cleanMixedAttributes(); + } + FileUpload fileUpload = currentFileUpload; + currentFileUpload = null; + return fileUpload; + } + // do not change the buffer position + // since some can be already saved into FileUpload + // So do not change the currentStatus + return null; + } + + /** + * Destroy the {@link HttpPostMultipartRequestDecoder} and release all it resources. After this method + * was called it is not possible to operate on it anymore. 
     */
    @Override
    public void destroy() {
        // Release all data items, including those not yet pulled, only file based items
        cleanFiles();
        // Clean Memory based data
        for (InterfaceHttpData httpData : bodyListHttpData) {
            // Might have been already released by the user
            if (httpData.refCnt() > 0) {
                httpData.release();
            }
        }

        destroyed = true;

        // Drop any buffered, not-yet-decoded bytes as well
        if (undecodedChunk != null && undecodedChunk.refCnt() > 0) {
            undecodedChunk.release();
            undecodedChunk = null;
        }
    }

    /**
     * Clean all HttpDatas (on Disk) for the current request.
     */
    @Override
    public void cleanFiles() {
        checkDestroyed();

        factory.cleanRequestHttpData(request);
    }

    /**
     * Remove the given FileUpload from the list of FileUploads to clean
     */
    @Override
    public void removeHttpDataFromClean(InterfaceHttpData data) {
        checkDestroyed();

        factory.removeHttpDataFromClean(request, data);
    }

    /**
     * Remove all Attributes that should be cleaned between two FileUpload in
     * Mixed mode
     */
    private void cleanMixedAttributes() {
        currentFieldAttributes.remove(HttpHeaderValues.CHARSET);
        currentFieldAttributes.remove(HttpHeaderNames.CONTENT_LENGTH);
        currentFieldAttributes.remove(HttpHeaderNames.CONTENT_TRANSFER_ENCODING);
        currentFieldAttributes.remove(HttpHeaderNames.CONTENT_TYPE);
        currentFieldAttributes.remove(HttpHeaderValues.FILENAME);
    }

    /**
     * Read one line up to the CRLF or LF
     *
     * @return the String from one line
     * @throws NotEnoughDataDecoderException
     *             Need more chunks and reset the {@code readerIndex} to the previous
     *             value
     */
    private static String readLineOptimized(ByteBuf undecodedChunk, Charset charset) {
        int readerIndex = undecodedChunk.readerIndex();
        ByteBuf line = null;
        try {
            if (undecodedChunk.isReadable()) {
                int posLfOrCrLf = HttpPostBodyUtil.findLineBreak(undecodedChunk, undecodedChunk.readerIndex());
                if (posLfOrCrLf <= 0) {
                    // no line break in the buffered bytes yet
                    throw new NotEnoughDataDecoderException();
                }
                try {
                    // copy the line content (without the terminator) into a temporary buffer
                    line = undecodedChunk.alloc().heapBuffer(posLfOrCrLf);
                    line.writeBytes(undecodedChunk, posLfOrCrLf);

                    byte nextByte = undecodedChunk.readByte();
                    if (nextByte == HttpConstants.CR) {
                        // force read next byte since LF is the following one
                        undecodedChunk.readByte();
                    }
                    return line.toString(charset);
                } finally {
                    line.release();
                }
            }
        } catch (IndexOutOfBoundsException e) {
            undecodedChunk.readerIndex(readerIndex);
            throw new NotEnoughDataDecoderException(e);
        }
        undecodedChunk.readerIndex(readerIndex);
        throw new NotEnoughDataDecoderException();
    }

    /**
     * Read one line up to --delimiter or --delimiter-- and if existing the CRLF
     * or LF Read one line up to --delimiter or --delimiter-- and if existing
     * the CRLF or LF. Note that CRLF or LF are mandatory for opening delimiter
     * (--delimiter) but not for closing delimiter (--delimiter--) since some
     * clients does not include CRLF in this case.
     *
     * @param delimiter
     *            of the form --string, such that '--' is already included
     * @return the String from one line as the delimiter searched (opening or
     *         closing)
     * @throws NotEnoughDataDecoderException
     *             Need more chunks and reset the {@code readerIndex} to the previous
     *             value
     */
    private static String readDelimiterOptimized(ByteBuf undecodedChunk, String delimiter, Charset charset) {
        final int readerIndex = undecodedChunk.readerIndex();
        final byte[] bdelimiter = delimiter.getBytes(charset);
        final int delimiterLength = bdelimiter.length;
        try {
            int delimiterPos = HttpPostBodyUtil.findDelimiter(undecodedChunk, readerIndex, bdelimiter, false);
            if (delimiterPos < 0) {
                // delimiter not found so break here !
                undecodedChunk.readerIndex(readerIndex);
                throw new NotEnoughDataDecoderException();
            }
            StringBuilder sb = new StringBuilder(delimiter);
            undecodedChunk.readerIndex(readerIndex + delimiterPos + delimiterLength);
            // Now check if either opening delimiter or closing delimiter
            if (undecodedChunk.isReadable()) {
                byte nextByte = undecodedChunk.readByte();
                // first check for opening delimiter
                if (nextByte == HttpConstants.CR) {
                    nextByte = undecodedChunk.readByte();
                    if (nextByte == HttpConstants.LF) {
                        return sb.toString();
                    } else {
                        // error since CR must be followed by LF
                        // delimiter not found so break here !
                        undecodedChunk.readerIndex(readerIndex);
                        throw new NotEnoughDataDecoderException();
                    }
                } else if (nextByte == HttpConstants.LF) {
                    return sb.toString();
                } else if (nextByte == '-') {
                    sb.append('-');
                    // second check for closing delimiter
                    nextByte = undecodedChunk.readByte();
                    if (nextByte == '-') {
                        sb.append('-');
                        // now try to find if CRLF or LF there
                        if (undecodedChunk.isReadable()) {
                            nextByte = undecodedChunk.readByte();
                            if (nextByte == HttpConstants.CR) {
                                nextByte = undecodedChunk.readByte();
                                if (nextByte == HttpConstants.LF) {
                                    return sb.toString();
                                } else {
                                    // error CR without LF
                                    // delimiter not found so break here !
                                    undecodedChunk.readerIndex(readerIndex);
                                    throw new NotEnoughDataDecoderException();
                                }
                            } else if (nextByte == HttpConstants.LF) {
                                return sb.toString();
                            } else {
                                // No CRLF but ok however (Adobe Flash uploader)
                                // minus 1 since we read one char ahead but
                                // should not
                                undecodedChunk.readerIndex(undecodedChunk.readerIndex() - 1);
                                return sb.toString();
                            }
                        }
                        // FIXME what do we do here?
                        // either considering it is fine, either waiting for
                        // more data to come?
                        // lets try considering it is fine...
                        return sb.toString();
                    }
                    // only one '-' => not enough
                    // whatever now => error since incomplete
                }
            }
        } catch (IndexOutOfBoundsException e) {
            undecodedChunk.readerIndex(readerIndex);
            throw new NotEnoughDataDecoderException(e);
        }
        undecodedChunk.readerIndex(readerIndex);
        throw new NotEnoughDataDecoderException();
    }

    /**
     * Rewrite buffer in order to skip lengthToSkip bytes from current readerIndex,
     * such that any readable bytes available after readerIndex + lengthToSkip (so before writerIndex)
     * are moved at readerIndex position,
     * therefore decreasing writerIndex of lengthToSkip at the end of the process.
     *
     * @param buffer the buffer to rewrite from current readerIndex
     * @param lengthToSkip the size to skip from readerIndex
     */
    private static void rewriteCurrentBuffer(ByteBuf buffer, int lengthToSkip) {
        if (lengthToSkip == 0) {
            return;
        }
        final int readerIndex = buffer.readerIndex();
        final int readableBytes = buffer.readableBytes();
        if (readableBytes == lengthToSkip) {
            // everything readable is skipped: just empty the buffer, no copy needed
            buffer.readerIndex(readerIndex);
            buffer.writerIndex(readerIndex);
            return;
        }
        // shift the remaining bytes down over the skipped region
        buffer.setBytes(readerIndex, buffer, readerIndex + lengthToSkip, readableBytes - lengthToSkip);
        buffer.readerIndex(readerIndex);
        buffer.writerIndex(readerIndex + readableBytes - lengthToSkip);
    }

    /**
     * Load the field value or file data from a Multipart request
     *
     * @return {@code true} if the last chunk is loaded (boundary delimiter found), {@code false} if need more chunks
     * @throws ErrorDataDecoderException
     */
    private static boolean loadDataMultipartOptimized(ByteBuf undecodedChunk, String delimiter, HttpData httpData) {
        if (!undecodedChunk.isReadable()) {
            return false;
        }
        final int startReaderIndex = undecodedChunk.readerIndex();
        final byte[] bdelimiter = delimiter.getBytes(httpData.getCharset());
        int posDelimiter = HttpPostBodyUtil.findDelimiter(undecodedChunk, startReaderIndex, bdelimiter, true);
        if (posDelimiter < 0) {
            // Not found but however perhaps because incomplete so search LF or CRLF from the end.
            // Possible last bytes contain partially delimiter
            // (delimiter is possibly partially there, at least 1 missing byte),
            // therefore searching last delimiter.length +1 (+1 for CRLF instead of LF)
            int readableBytes = undecodedChunk.readableBytes();
            int lastPosition = readableBytes - bdelimiter.length - 1;
            if (lastPosition < 0) {
                // Not enough bytes, but at most delimiter.length bytes available so can still try to find CRLF there
                lastPosition = 0;
            }
            posDelimiter = HttpPostBodyUtil.findLastLineBreak(undecodedChunk, startReaderIndex + lastPosition);
            // No LineBreak, however CR can be at the end of the buffer, LF not yet there (issue #11668)
            // Check if last CR (if any) shall not be in the content (definedLength vs actual length + buffer - 1)
            if (posDelimiter < 0 &&
                httpData.definedLength() == httpData.length() + readableBytes - 1 &&
                undecodedChunk.getByte(readableBytes + startReaderIndex - 1) == HttpConstants.CR) {
                // Last CR shall precede a future LF
                lastPosition = 0;
                posDelimiter = readableBytes - 1;
            }
            if (posDelimiter < 0) {
                // not found so this chunk can be fully added
                ByteBuf content = undecodedChunk.copy();
                try {
                    httpData.addContent(content, false);
                } catch (IOException e) {
                    throw new ErrorDataDecoderException(e);
                }
                undecodedChunk.readerIndex(startReaderIndex);
                undecodedChunk.writerIndex(startReaderIndex);
                return false;
            }
            // posDelimiter is not from startReaderIndex but from startReaderIndex + lastPosition
            posDelimiter += lastPosition;
            if (posDelimiter == 0) {
                // Nothing to add
                return false;
            }
            // Not fully but still some bytes to provide: httpData is not yet finished since delimiter not found
            ByteBuf content = undecodedChunk.copy(startReaderIndex, posDelimiter);
            try {
                httpData.addContent(content, false);
            } catch (IOException e) {
                throw new ErrorDataDecoderException(e);
            }
            rewriteCurrentBuffer(undecodedChunk, posDelimiter);
            return false;
        }
        // Delimiter found at posDelimiter, including LF or CRLF, so httpData has its last chunk
        ByteBuf content = undecodedChunk.copy(startReaderIndex, posDelimiter);
        try {
            httpData.addContent(content, true);
        } catch (IOException e) {
            throw new ErrorDataDecoderException(e);
        }
        rewriteCurrentBuffer(undecodedChunk, posDelimiter);
        return true;
    }

    /**
     * Clean the String from any unallowed character
     *
     * @return the cleaned String
     */
    private static String cleanString(String field) {
        int size = field.length();
        StringBuilder sb = new StringBuilder(size);
        for (int i = 0; i < size; i++) {
            char nextChar = field.charAt(i);
            switch (nextChar) {
            case HttpConstants.COLON:
            case HttpConstants.COMMA:
            case HttpConstants.EQUALS:
            case HttpConstants.SEMICOLON:
            case HttpConstants.HT:
                // header-structural characters are neutralized to a space
                sb.append(HttpConstants.SP_CHAR);
                break;
            case HttpConstants.DOUBLE_QUOTE:
                // nothing added, just removes it
                break;
            default:
                sb.append(nextChar);
                break;
            }
        }
        return sb.toString().trim();
    }

    /**
     * Skip one empty line
     *
     * @return True if one empty line was skipped
     */
    private boolean skipOneLine() {
        if (!undecodedChunk.isReadable()) {
            return false;
        }
        byte nextByte = undecodedChunk.readByte();
        if (nextByte == HttpConstants.CR) {
            if (!undecodedChunk.isReadable()) {
                // lone CR at buffer end: rewind, need more data to decide
                undecodedChunk.readerIndex(undecodedChunk.readerIndex() - 1);
                return false;
            }
            nextByte = undecodedChunk.readByte();
            if (nextByte == HttpConstants.LF) {
                return true;
            }
            // CR not followed by LF: rewind both bytes
            undecodedChunk.readerIndex(undecodedChunk.readerIndex() - 2);
            return false;
        }
        if (nextByte == HttpConstants.LF) {
            return true;
        }
        undecodedChunk.readerIndex(undecodedChunk.readerIndex() - 1);
        return false;
    }

    /**
     * Split one header in Multipart
     *
     * @return an array of String where rank 0 is the name of the header,
     *         follows by several values that were separated by ';' or ','
     */

private static String[] splitMultipartHeader(String sb) { + ArrayList headers = new ArrayList(1); + int nameStart; + int nameEnd; + int colonEnd; + int valueStart; + int valueEnd; + nameStart = HttpPostBodyUtil.findNonWhitespace(sb, 0); + for (nameEnd = nameStart; nameEnd < sb.length(); nameEnd++) { + char ch = sb.charAt(nameEnd); + if (ch == ':' || Character.isWhitespace(ch)) { + break; + } + } + for (colonEnd = nameEnd; colonEnd < sb.length(); colonEnd++) { + if (sb.charAt(colonEnd) == ':') { + colonEnd++; + break; + } + } + valueStart = HttpPostBodyUtil.findNonWhitespace(sb, colonEnd); + valueEnd = HttpPostBodyUtil.findEndOfString(sb); + headers.add(sb.substring(nameStart, nameEnd)); + String svalue = (valueStart >= valueEnd) ? StringUtil.EMPTY_STRING : sb.substring(valueStart, valueEnd); + String[] values; + if (svalue.indexOf(';') >= 0) { + values = splitMultipartHeaderValues(svalue); + } else { + values = svalue.split(","); + } + for (String value : values) { + headers.add(value.trim()); + } + String[] array = new String[headers.size()]; + for (int i = 0; i < headers.size(); i++) { + array[i] = headers.get(i); + } + return array; + } + + /** + * Split one header value in Multipart + * @return an array of String where values that were separated by ';' or ',' + */ + private static String[] splitMultipartHeaderValues(String svalue) { + List values = InternalThreadLocalMap.get().arrayList(1); + boolean inQuote = false; + boolean escapeNext = false; + int start = 0; + for (int i = 0; i < svalue.length(); i++) { + char c = svalue.charAt(i); + if (inQuote) { + if (escapeNext) { + escapeNext = false; + } else { + if (c == '\\') { + escapeNext = true; + } else if (c == '"') { + inQuote = false; + } + } + } else { + if (c == '"') { + inQuote = true; + } else if (c == ';') { + values.add(svalue.substring(start, i)); + start = i + 1; + } + } + } + values.add(svalue.substring(start)); + return values.toArray(EmptyArrays.EMPTY_STRINGS); + } + + /** + * This method is 
package private intentionally in order to allow during tests + * to access to the amount of memory allocated (capacity) within the private + * ByteBuf undecodedChunk + * + * @return the number of bytes the internal buffer can contain + */ + int getCurrentAllocatedCapacity() { + return undecodedChunk.capacity(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoder.java new file mode 100644 index 0000000..eb6bd79 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoder.java @@ -0,0 +1,341 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.nio.charset.Charset; +import java.util.List; + +/** + * This decoder will decode Body and can handle POST BODY. 
+ * + * You MUST call {@link #destroy()} after completion to release all resources. + * + */ +public class HttpPostRequestDecoder implements InterfaceHttpPostRequestDecoder { + + static final int DEFAULT_DISCARD_THRESHOLD = 10 * 1024 * 1024; + + private final InterfaceHttpPostRequestDecoder decoder; + + /** + * + * @param request + * the request to decode + * @throws NullPointerException + * for request + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostRequestDecoder(HttpRequest request) { + this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE), request, HttpConstants.DEFAULT_CHARSET); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to decode + * @throws NullPointerException + * for request or factory + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostRequestDecoder(HttpDataFactory factory, HttpRequest request) { + this(factory, request, HttpConstants.DEFAULT_CHARSET); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to decode + * @param charset + * the charset to use as default + * @throws NullPointerException + * for request or charset or factory + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostRequestDecoder(HttpDataFactory factory, HttpRequest request, Charset charset) { + ObjectUtil.checkNotNull(factory, "factory"); + ObjectUtil.checkNotNull(request, "request"); + ObjectUtil.checkNotNull(charset, "charset"); + + // Fill default values + if (isMultipart(request)) { + decoder = new HttpPostMultipartRequestDecoder(factory, request, charset); + } else { + decoder = new HttpPostStandardRequestDecoder(factory, request, charset); + } + } + + /** + * states follow NOTSTARTED 
PREAMBLE ( (HEADERDELIMITER DISPOSITION (FIELD | + * FILEUPLOAD))* (HEADERDELIMITER DISPOSITION MIXEDPREAMBLE (MIXEDDELIMITER + * MIXEDDISPOSITION MIXEDFILEUPLOAD)+ MIXEDCLOSEDELIMITER)* CLOSEDELIMITER)+ + * EPILOGUE + * + * First getStatus is: NOSTARTED + * + * Content-type: multipart/form-data, boundary=AaB03x => PREAMBLE in Header + * + * --AaB03x => HEADERDELIMITER content-disposition: form-data; name="field1" + * => DISPOSITION + * + * Joe Blow => FIELD --AaB03x => HEADERDELIMITER content-disposition: + * form-data; name="pics" => DISPOSITION Content-type: multipart/mixed, + * boundary=BbC04y + * + * --BbC04y => MIXEDDELIMITER Content-disposition: attachment; + * filename="file1.txt" => MIXEDDISPOSITION Content-Type: text/plain + * + * ... contents of file1.txt ... => MIXEDFILEUPLOAD --BbC04y => + * MIXEDDELIMITER Content-disposition: file; filename="file2.gif" => + * MIXEDDISPOSITION Content-type: image/gif Content-Transfer-Encoding: + * binary + * + * ...contents of file2.gif... => MIXEDFILEUPLOAD --BbC04y-- => + * MIXEDCLOSEDELIMITER --AaB03x-- => CLOSEDELIMITER + * + * Once CLOSEDELIMITER is found, last getStatus is EPILOGUE + */ + protected enum MultiPartStatus { + NOTSTARTED, PREAMBLE, HEADERDELIMITER, DISPOSITION, FIELD, FILEUPLOAD, MIXEDPREAMBLE, MIXEDDELIMITER, + MIXEDDISPOSITION, MIXEDFILEUPLOAD, MIXEDCLOSEDELIMITER, CLOSEDELIMITER, PREEPILOGUE, EPILOGUE + } + + /** + * Check if the given request is a multipart request + * @return True if the request is a Multipart request + */ + public static boolean isMultipart(HttpRequest request) { + String mimeType = request.headers().get(HttpHeaderNames.CONTENT_TYPE); + if (mimeType != null && mimeType.startsWith(HttpHeaderValues.MULTIPART_FORM_DATA.toString())) { + return getMultipartDataBoundary(mimeType) != null; + } + return false; + } + + /** + * Check from the request ContentType if this request is a Multipart request. 
+ * @return an array of String if multipartDataBoundary exists with the multipartDataBoundary + * as first element, charset if any as second (missing if not set), else null + */ + protected static String[] getMultipartDataBoundary(String contentType) { + // Check if Post using "multipart/form-data; boundary=--89421926422648 [; charset=xxx]" + String[] headerContentType = splitHeaderContentType(contentType); + final String multiPartHeader = HttpHeaderValues.MULTIPART_FORM_DATA.toString(); + if (headerContentType[0].regionMatches(true, 0, multiPartHeader, 0 , multiPartHeader.length())) { + int mrank; + int crank; + final String boundaryHeader = HttpHeaderValues.BOUNDARY.toString(); + if (headerContentType[1].regionMatches(true, 0, boundaryHeader, 0, boundaryHeader.length())) { + mrank = 1; + crank = 2; + } else if (headerContentType[2].regionMatches(true, 0, boundaryHeader, 0, boundaryHeader.length())) { + mrank = 2; + crank = 1; + } else { + return null; + } + String boundary = StringUtil.substringAfter(headerContentType[mrank], '='); + if (boundary == null) { + throw new ErrorDataDecoderException("Needs a boundary value"); + } + if (boundary.charAt(0) == '"') { + String bound = boundary.trim(); + int index = bound.length() - 1; + if (bound.charAt(index) == '"') { + boundary = bound.substring(1, index); + } + } + final String charsetHeader = HttpHeaderValues.CHARSET.toString(); + if (headerContentType[crank].regionMatches(true, 0, charsetHeader, 0, charsetHeader.length())) { + String charset = StringUtil.substringAfter(headerContentType[crank], '='); + if (charset != null) { + return new String[] {"--" + boundary, charset}; + } + } + return new String[] {"--" + boundary}; + } + return null; + } + + @Override + public boolean isMultipart() { + return decoder.isMultipart(); + } + + @Override + public void setDiscardThreshold(int discardThreshold) { + decoder.setDiscardThreshold(discardThreshold); + } + + @Override + public int getDiscardThreshold() { + return 
decoder.getDiscardThreshold(); + } + + @Override + public List getBodyHttpDatas() { + return decoder.getBodyHttpDatas(); + } + + @Override + public List getBodyHttpDatas(String name) { + return decoder.getBodyHttpDatas(name); + } + + @Override + public InterfaceHttpData getBodyHttpData(String name) { + return decoder.getBodyHttpData(name); + } + + @Override + public InterfaceHttpPostRequestDecoder offer(HttpContent content) { + return decoder.offer(content); + } + + @Override + public boolean hasNext() { + return decoder.hasNext(); + } + + @Override + public InterfaceHttpData next() { + return decoder.next(); + } + + @Override + public InterfaceHttpData currentPartialHttpData() { + return decoder.currentPartialHttpData(); + } + + @Override + public void destroy() { + decoder.destroy(); + } + + @Override + public void cleanFiles() { + decoder.cleanFiles(); + } + + @Override + public void removeHttpDataFromClean(InterfaceHttpData data) { + decoder.removeHttpDataFromClean(data); + } + + /** + * Split the very first line (Content-Type value) in 3 Strings + * + * @return the array of 3 Strings + */ + private static String[] splitHeaderContentType(String sb) { + int aStart; + int aEnd; + int bStart; + int bEnd; + int cStart; + int cEnd; + aStart = HttpPostBodyUtil.findNonWhitespace(sb, 0); + aEnd = sb.indexOf(';'); + if (aEnd == -1) { + return new String[] { sb, "", "" }; + } + bStart = HttpPostBodyUtil.findNonWhitespace(sb, aEnd + 1); + if (sb.charAt(aEnd - 1) == ' ') { + aEnd--; + } + bEnd = sb.indexOf(';', bStart); + if (bEnd == -1) { + bEnd = HttpPostBodyUtil.findEndOfString(sb); + return new String[] { sb.substring(aStart, aEnd), sb.substring(bStart, bEnd), "" }; + } + cStart = HttpPostBodyUtil.findNonWhitespace(sb, bEnd + 1); + if (sb.charAt(bEnd - 1) == ' ') { + bEnd--; + } + cEnd = HttpPostBodyUtil.findEndOfString(sb); + return new String[] { sb.substring(aStart, aEnd), sb.substring(bStart, bEnd), sb.substring(cStart, cEnd) }; + } + + /** + * Exception when try 
reading data from request in chunked format, and not + * enough data are available (need more chunks) + */ + public static class NotEnoughDataDecoderException extends DecoderException { + private static final long serialVersionUID = -7846841864603865638L; + + public NotEnoughDataDecoderException() { + } + + public NotEnoughDataDecoderException(String msg) { + super(msg); + } + + public NotEnoughDataDecoderException(Throwable cause) { + super(cause); + } + + public NotEnoughDataDecoderException(String msg, Throwable cause) { + super(msg, cause); + } + } + + /** + * Exception when the body is fully decoded, even if there is still data + */ + public static class EndOfDataDecoderException extends DecoderException { + private static final long serialVersionUID = 1336267941020800769L; + } + + /** + * Exception when an error occurs while decoding + */ + public static class ErrorDataDecoderException extends DecoderException { + private static final long serialVersionUID = 5020247425493164465L; + + public ErrorDataDecoderException() { + } + + public ErrorDataDecoderException(String msg) { + super(msg); + } + + public ErrorDataDecoderException(Throwable cause) { + super(cause); + } + + public ErrorDataDecoderException(String msg, Throwable cause) { + super(msg, cause); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java new file mode 100755 index 0000000..8921fca --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoder.java @@ -0,0 +1,1347 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.stream.ChunkedInput; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +import java.io.File; +import java.io.IOException; +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import java.util.ListIterator; +import java.util.Map; +import java.util.regex.Pattern; + +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static 
java.util.AbstractMap.SimpleImmutableEntry; + +/** + * This encoder will help to encode Request for a FORM as POST. + * + *

According to RFC 7231, POST, PUT and OPTIONS allow a request body. + * This encoder supports nearly all methods, except TRACE, since the RFC notes + * for GET, DELETE, HEAD and CONNECT: (replace XXX with one of these methods)

+ *

"A payload within a XXX request message has no defined semantics; + * sending a payload body on a XXX request might cause some existing + * implementations to reject the request."

+ *

On the contrary, for the TRACE method, the RFC says:

+ *

"A client MUST NOT send a message body in a TRACE request."

+ */ +public class HttpPostRequestEncoder implements ChunkedInput { + + /** + * Different modes to use to encode form data. + */ + public enum EncoderMode { + /** + * Legacy mode which should work for most. It is known to not work with OAUTH. For OAUTH use + * {@link EncoderMode#RFC3986}. The W3C form recommendations this for submitting post form data. + */ + RFC1738, + + /** + * Mode which is more new and is used for OAUTH + */ + RFC3986, + + /** + * The HTML5 spec disallows mixed mode in multipart/form-data + * requests. More concretely this means that more files submitted + * under the same name will not be encoded using mixed mode, but + * will be treated as distinct fields. + * + * Reference: + * https://www.w3.org/TR/html5/forms.html#multipart-form-data + */ + HTML5 + } + + @SuppressWarnings("rawtypes") + private static final Map.Entry[] percentEncodings; + + static { + percentEncodings = new Map.Entry[] { + new SimpleImmutableEntry(Pattern.compile("\\*"), "%2A"), + new SimpleImmutableEntry(Pattern.compile("\\+"), "%20"), + new SimpleImmutableEntry(Pattern.compile("~"), "%7E") + }; + } + + /** + * Factory used to create InterfaceHttpData + */ + private final HttpDataFactory factory; + + /** + * Request to encode + */ + private final HttpRequest request; + + /** + * Default charset to use + */ + private final Charset charset; + + /** + * Chunked false by default + */ + private boolean isChunked; + + /** + * InterfaceHttpData for Body (without encoding) + */ + private final List bodyListDatas; + /** + * The final Multipart List of InterfaceHttpData including encoding + */ + final List multipartHttpDatas; + + /** + * Does this request is a Multipart request + */ + private final boolean isMultipart; + + /** + * If multipart, this is the boundary for the flobal multipart + */ + String multipartDataBoundary; + + /** + * If multipart, there could be internal multiparts (mixed) to the global multipart. Only one level is allowed. 
+ */ + String multipartMixedBoundary; + /** + * To check if the header has been finalized + */ + private boolean headerFinalized; + + private final EncoderMode encoderMode; + + /** + * + * @param request + * the request to encode + * @param multipart + * True if the FORM is a ENCTYPE="multipart/form-data" + * @throws NullPointerException + * for request + * @throws ErrorDataEncoderException + * if the request is a TRACE + */ + public HttpPostRequestEncoder(HttpRequest request, boolean multipart) throws ErrorDataEncoderException { + this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE), request, multipart, + HttpConstants.DEFAULT_CHARSET, EncoderMode.RFC1738); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to encode + * @param multipart + * True if the FORM is a ENCTYPE="multipart/form-data" + * @throws NullPointerException + * for request and factory + * @throws ErrorDataEncoderException + * if the request is a TRACE + */ + public HttpPostRequestEncoder(HttpDataFactory factory, HttpRequest request, boolean multipart) + throws ErrorDataEncoderException { + this(factory, request, multipart, HttpConstants.DEFAULT_CHARSET, EncoderMode.RFC1738); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to encode + * @param multipart + * True if the FORM is a ENCTYPE="multipart/form-data" + * @param charset + * the charset to use as default + * @param encoderMode + * the mode for the encoder to use. See {@link EncoderMode} for the details. 
+ * @throws NullPointerException + * for request or charset or factory + * @throws ErrorDataEncoderException + * if the request is a TRACE + */ + public HttpPostRequestEncoder( + HttpDataFactory factory, HttpRequest request, boolean multipart, Charset charset, + EncoderMode encoderMode) + throws ErrorDataEncoderException { + this.request = checkNotNull(request, "request"); + this.charset = checkNotNull(charset, "charset"); + this.factory = checkNotNull(factory, "factory"); + if (HttpMethod.TRACE.equals(request.method())) { + throw new ErrorDataEncoderException("Cannot create a Encoder if request is a TRACE"); + } + // Fill default values + bodyListDatas = new ArrayList(); + // default mode + isLastChunk = false; + isLastChunkSent = false; + isMultipart = multipart; + multipartHttpDatas = new ArrayList(); + this.encoderMode = encoderMode; + if (isMultipart) { + initDataMultipart(); + } + } + + /** + * Clean all HttpDatas (on Disk) for the current request. + */ + public void cleanFiles() { + factory.cleanRequestHttpData(request); + } + + /** + * Does the last non empty chunk already encoded so that next chunk will be empty (last chunk) + */ + private boolean isLastChunk; + /** + * Last chunk already sent + */ + private boolean isLastChunkSent; + /** + * The current FileUpload that is currently in encode process + */ + private FileUpload currentFileUpload; + /** + * While adding a FileUpload, is the multipart currently in Mixed Mode + */ + private boolean duringMixedMode; + /** + * Global Body size + */ + private long globalBodySize; + /** + * Global Transfer progress + */ + private long globalProgress; + + /** + * True if this request is a Multipart request + * + * @return True if this request is a Multipart request + */ + public boolean isMultipart() { + return isMultipart; + } + + /** + * Init the delimiter for Global Part (Data). 
+ */ + private void initDataMultipart() { + multipartDataBoundary = getNewMultipartDelimiter(); + } + + /** + * Init the delimiter for Mixed Part (Mixed). + */ + private void initMixedMultipart() { + multipartMixedBoundary = getNewMultipartDelimiter(); + } + + /** + * + * @return a newly generated Delimiter (either for DATA or MIXED) + */ + private static String getNewMultipartDelimiter() { + // construct a generated delimiter + return Long.toHexString(PlatformDependent.threadLocalRandom().nextLong()); + } + + /** + * This getMethod returns a List of all InterfaceHttpData from body part.
+ + * @return the list of InterfaceHttpData from Body part + */ + public List getBodyListAttributes() { + return bodyListDatas; + } + + /** + * Set the Body HttpDatas list + * + * @throws NullPointerException + * for datas + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public void setBodyHttpDatas(List datas) throws ErrorDataEncoderException { + ObjectUtil.checkNotNull(datas, "datas"); + globalBodySize = 0; + bodyListDatas.clear(); + currentFileUpload = null; + duringMixedMode = false; + multipartHttpDatas.clear(); + for (InterfaceHttpData data : datas) { + addBodyHttpData(data); + } + } + + /** + * Add a simple attribute in the body as Name=Value + * + * @param name + * name of the parameter + * @param value + * the value of the parameter + * @throws NullPointerException + * for name + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public void addBodyAttribute(String name, String value) throws ErrorDataEncoderException { + String svalue = value != null? 
value : StringUtil.EMPTY_STRING; + Attribute data = factory.createAttribute(request, checkNotNull(name, "name"), svalue); + addBodyHttpData(data); + } + + /** + * Add a file as a FileUpload + * + * @param name + * the name of the parameter + * @param file + * the file to be uploaded (if not Multipart mode, only the filename will be included) + * @param contentType + * the associated contentType for the File + * @param isText + * True if this file should be transmitted in Text format (else binary) + * @throws NullPointerException + * for name and file + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public void addBodyFileUpload(String name, File file, String contentType, boolean isText) + throws ErrorDataEncoderException { + addBodyFileUpload(name, file.getName(), file, contentType, isText); + } + + /** + * Add a file as a FileUpload + * + * @param name + * the name of the parameter + * @param file + * the file to be uploaded (if not Multipart mode, only the filename will be included) + * @param filename + * the filename to use for this File part, empty String will be ignored by + * the encoder + * @param contentType + * the associated contentType for the File + * @param isText + * True if this file should be transmitted in Text format (else binary) + * @throws NullPointerException + * for name and file + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public void addBodyFileUpload(String name, String filename, File file, String contentType, boolean isText) + throws ErrorDataEncoderException { + checkNotNull(name, "name"); + checkNotNull(file, "file"); + if (filename == null) { + filename = StringUtil.EMPTY_STRING; + } + String scontentType = contentType; + String contentTransferEncoding = null; + if (contentType == null) { + if (isText) { + scontentType = HttpPostBodyUtil.DEFAULT_TEXT_CONTENT_TYPE; + } else { + scontentType = 
HttpPostBodyUtil.DEFAULT_BINARY_CONTENT_TYPE; + } + } + if (!isText) { + contentTransferEncoding = HttpPostBodyUtil.TransferEncodingMechanism.BINARY.value(); + } + FileUpload fileUpload = factory.createFileUpload(request, name, filename, scontentType, + contentTransferEncoding, null, file.length()); + try { + fileUpload.setContent(file); + } catch (IOException e) { + throw new ErrorDataEncoderException(e); + } + addBodyHttpData(fileUpload); + } + + /** + * Add a series of Files associated with one File parameter + * + * @param name + * the name of the parameter + * @param file + * the array of files + * @param contentType + * the array of content Types associated with each file + * @param isText + * the array of isText attribute (False meaning binary mode) for each file + * @throws IllegalArgumentException + * also throws if array have different sizes + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public void addBodyFileUploads(String name, File[] file, String[] contentType, boolean[] isText) + throws ErrorDataEncoderException { + if (file.length != contentType.length && file.length != isText.length) { + throw new IllegalArgumentException("Different array length"); + } + for (int i = 0; i < file.length; i++) { + addBodyFileUpload(name, file[i], contentType[i], isText[i]); + } + } + + /** + * Add the InterfaceHttpData to the Body list + * + * @throws NullPointerException + * for data + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public void addBodyHttpData(InterfaceHttpData data) throws ErrorDataEncoderException { + if (headerFinalized) { + throw new ErrorDataEncoderException("Cannot add value once finalized"); + } + bodyListDatas.add(checkNotNull(data, "data")); + if (!isMultipart) { + if (data instanceof Attribute) { + Attribute attribute = (Attribute) data; + try { + // name=value& with encoded name and attribute + String key = 
encodeAttribute(attribute.getName(), charset); + String value = encodeAttribute(attribute.getValue(), charset); + Attribute newattribute = factory.createAttribute(request, key, value); + multipartHttpDatas.add(newattribute); + globalBodySize += newattribute.getName().length() + 1 + newattribute.length() + 1; + } catch (IOException e) { + throw new ErrorDataEncoderException(e); + } + } else if (data instanceof FileUpload) { + // since not Multipart, only name=filename => Attribute + FileUpload fileUpload = (FileUpload) data; + // name=filename& with encoded name and filename + String key = encodeAttribute(fileUpload.getName(), charset); + String value = encodeAttribute(fileUpload.getFilename(), charset); + Attribute newattribute = factory.createAttribute(request, key, value); + multipartHttpDatas.add(newattribute); + globalBodySize += newattribute.getName().length() + 1 + newattribute.length() + 1; + } + return; + } + /* + * Logic: + * if not Attribute: + * add Data to body list + * if (duringMixedMode) + * add endmixedmultipart delimiter + * currentFileUpload = null + * duringMixedMode = false; + * add multipart delimiter, multipart body header and Data to multipart list + * reset currentFileUpload, duringMixedMode + * if FileUpload: take care of multiple file for one field => mixed mode + * if (duringMixedMode) + * if (currentFileUpload.name == data.name) + * add mixedmultipart delimiter, mixedmultipart body header and Data to multipart list + * else + * add endmixedmultipart delimiter, multipart body header and Data to multipart list + * currentFileUpload = data + * duringMixedMode = false; + * else + * if (currentFileUpload.name == data.name) + * change multipart body header of previous file into multipart list to + * mixedmultipart start, mixedmultipart body header + * add mixedmultipart delimiter, mixedmultipart body header and Data to multipart list + * duringMixedMode = true + * else + * add multipart delimiter, multipart body header and Data to multipart 
list + * currentFileUpload = data + * duringMixedMode = false; + * Do not add last delimiter! Could be: + * if duringmixedmode: endmixedmultipart + endmultipart + * else only endmultipart + */ + if (data instanceof Attribute) { + if (duringMixedMode) { + InternalAttribute internal = new InternalAttribute(charset); + internal.addValue("\r\n--" + multipartMixedBoundary + "--"); + multipartHttpDatas.add(internal); + multipartMixedBoundary = null; + currentFileUpload = null; + duringMixedMode = false; + } + InternalAttribute internal = new InternalAttribute(charset); + if (!multipartHttpDatas.isEmpty()) { + // previously a data field so CRLF + internal.addValue("\r\n"); + } + internal.addValue("--" + multipartDataBoundary + "\r\n"); + // content-disposition: form-data; name="field1" + Attribute attribute = (Attribute) data; + internal.addValue(HttpHeaderNames.CONTENT_DISPOSITION + ": " + HttpHeaderValues.FORM_DATA + "; " + + HttpHeaderValues.NAME + "=\"" + attribute.getName() + "\"\r\n"); + // Add Content-Length: xxx + internal.addValue(HttpHeaderNames.CONTENT_LENGTH + ": " + + attribute.length() + "\r\n"); + Charset localcharset = attribute.getCharset(); + if (localcharset != null) { + // Content-Type: text/plain; charset=charset + internal.addValue(HttpHeaderNames.CONTENT_TYPE + ": " + + HttpPostBodyUtil.DEFAULT_TEXT_CONTENT_TYPE + "; " + + HttpHeaderValues.CHARSET + '=' + + localcharset.name() + "\r\n"); + } + // CRLF between body header and data + internal.addValue("\r\n"); + multipartHttpDatas.add(internal); + multipartHttpDatas.add(data); + globalBodySize += attribute.length() + internal.size(); + } else if (data instanceof FileUpload) { + FileUpload fileUpload = (FileUpload) data; + InternalAttribute internal = new InternalAttribute(charset); + if (!multipartHttpDatas.isEmpty()) { + // previously a data field so CRLF + internal.addValue("\r\n"); + } + boolean localMixed; + if (duringMixedMode) { + if (currentFileUpload != null && 
currentFileUpload.getName().equals(fileUpload.getName())) { + // continue a mixed mode + + localMixed = true; + } else { + // end a mixed mode + + // add endmixedmultipart delimiter, multipart body header + // and + // Data to multipart list + internal.addValue("--" + multipartMixedBoundary + "--"); + multipartHttpDatas.add(internal); + multipartMixedBoundary = null; + // start a new one (could be replaced if mixed start again + // from here + internal = new InternalAttribute(charset); + internal.addValue("\r\n"); + localMixed = false; + // new currentFileUpload and no more in Mixed mode + currentFileUpload = fileUpload; + duringMixedMode = false; + } + } else { + if (encoderMode != EncoderMode.HTML5 && currentFileUpload != null + && currentFileUpload.getName().equals(fileUpload.getName())) { + // create a new mixed mode (from previous file) + + // change multipart body header of previous file into + // multipart list to + // mixedmultipart start, mixedmultipart body header + + // change Internal (size()-2 position in multipartHttpDatas) + // from (line starting with *) + // --AaB03x + // * Content-Disposition: form-data; name="files"; + // filename="file1.txt" + // Content-Type: text/plain + // to (lines starting with *) + // --AaB03x + // * Content-Disposition: form-data; name="files" + // * Content-Type: multipart/mixed; boundary=BbC04y + // * + // * --BbC04y + // * Content-Disposition: attachment; filename="file1.txt" + // Content-Type: text/plain + initMixedMultipart(); + InternalAttribute pastAttribute = (InternalAttribute) multipartHttpDatas.get(multipartHttpDatas + .size() - 2); + // remove past size + globalBodySize -= pastAttribute.size(); + StringBuilder replacement = new StringBuilder( + 139 + multipartDataBoundary.length() + multipartMixedBoundary.length() * 2 + + fileUpload.getFilename().length() + fileUpload.getName().length()) + + .append("--") + .append(multipartDataBoundary) + .append("\r\n") + + .append(HttpHeaderNames.CONTENT_DISPOSITION) + 
.append(": ") + .append(HttpHeaderValues.FORM_DATA) + .append("; ") + .append(HttpHeaderValues.NAME) + .append("=\"") + .append(fileUpload.getName()) + .append("\"\r\n") + + .append(HttpHeaderNames.CONTENT_TYPE) + .append(": ") + .append(HttpHeaderValues.MULTIPART_MIXED) + .append("; ") + .append(HttpHeaderValues.BOUNDARY) + .append('=') + .append(multipartMixedBoundary) + .append("\r\n\r\n") + + .append("--") + .append(multipartMixedBoundary) + .append("\r\n") + + .append(HttpHeaderNames.CONTENT_DISPOSITION) + .append(": ") + .append(HttpHeaderValues.ATTACHMENT); + + if (!fileUpload.getFilename().isEmpty()) { + replacement.append("; ") + .append(HttpHeaderValues.FILENAME) + .append("=\"") + .append(currentFileUpload.getFilename()) + .append('"'); + } + + replacement.append("\r\n"); + + pastAttribute.setValue(replacement.toString(), 1); + pastAttribute.setValue("", 2); + + // update past size + globalBodySize += pastAttribute.size(); + + // now continue + // add mixedmultipart delimiter, mixedmultipart body header + // and + // Data to multipart list + localMixed = true; + duringMixedMode = true; + } else { + // a simple new multipart + // add multipart delimiter, multipart body header and Data + // to multipart list + localMixed = false; + currentFileUpload = fileUpload; + duringMixedMode = false; + } + } + + if (localMixed) { + // add mixedmultipart delimiter, mixedmultipart body header and + // Data to multipart list + internal.addValue("--" + multipartMixedBoundary + "\r\n"); + + if (fileUpload.getFilename().isEmpty()) { + // Content-Disposition: attachment + internal.addValue(HttpHeaderNames.CONTENT_DISPOSITION + ": " + + HttpHeaderValues.ATTACHMENT + "\r\n"); + } else { + // Content-Disposition: attachment; filename="file1.txt" + internal.addValue(HttpHeaderNames.CONTENT_DISPOSITION + ": " + + HttpHeaderValues.ATTACHMENT + "; " + + HttpHeaderValues.FILENAME + "=\"" + fileUpload.getFilename() + "\"\r\n"); + } + } else { + internal.addValue("--" + 
multipartDataBoundary + "\r\n"); + + if (fileUpload.getFilename().isEmpty()) { + // Content-Disposition: form-data; name="files"; + internal.addValue(HttpHeaderNames.CONTENT_DISPOSITION + ": " + HttpHeaderValues.FORM_DATA + "; " + + HttpHeaderValues.NAME + "=\"" + fileUpload.getName() + "\"\r\n"); + } else { + // Content-Disposition: form-data; name="files"; + // filename="file1.txt" + internal.addValue(HttpHeaderNames.CONTENT_DISPOSITION + ": " + HttpHeaderValues.FORM_DATA + "; " + + HttpHeaderValues.NAME + "=\"" + fileUpload.getName() + "\"; " + + HttpHeaderValues.FILENAME + "=\"" + fileUpload.getFilename() + "\"\r\n"); + } + } + // Add Content-Length: xxx + internal.addValue(HttpHeaderNames.CONTENT_LENGTH + ": " + + fileUpload.length() + "\r\n"); + // Content-Type: image/gif + // Content-Type: text/plain; charset=ISO-8859-1 + // Content-Transfer-Encoding: binary + internal.addValue(HttpHeaderNames.CONTENT_TYPE + ": " + fileUpload.getContentType()); + String contentTransferEncoding = fileUpload.getContentTransferEncoding(); + if (contentTransferEncoding != null + && contentTransferEncoding.equals(HttpPostBodyUtil.TransferEncodingMechanism.BINARY.value())) { + internal.addValue("\r\n" + HttpHeaderNames.CONTENT_TRANSFER_ENCODING + ": " + + HttpPostBodyUtil.TransferEncodingMechanism.BINARY.value() + "\r\n\r\n"); + } else if (fileUpload.getCharset() != null) { + internal.addValue("; " + HttpHeaderValues.CHARSET + '=' + fileUpload.getCharset().name() + "\r\n\r\n"); + } else { + internal.addValue("\r\n\r\n"); + } + multipartHttpDatas.add(internal); + multipartHttpDatas.add(data); + globalBodySize += fileUpload.length() + internal.size(); + } + } + + /** + * Iterator to be used when encoding will be called chunk after chunk + */ + private ListIterator iterator; + + /** + * Finalize the request by preparing the Header in the request and returns the request ready to be sent.
+ * Once finalized, no data must be added.
+ * If the request does not need chunk (isChunked() == false), this request is the only object to send to the remote + * server. + * + * @return the request object (chunked or not according to size of body) + * @throws ErrorDataEncoderException + * if the encoding is in error or if the finalize were already done + */ + public HttpRequest finalizeRequest() throws ErrorDataEncoderException { + // Finalize the multipartHttpDatas + if (!headerFinalized) { + if (isMultipart) { + InternalAttribute internal = new InternalAttribute(charset); + if (duringMixedMode) { + internal.addValue("\r\n--" + multipartMixedBoundary + "--"); + } + internal.addValue("\r\n--" + multipartDataBoundary + "--\r\n"); + multipartHttpDatas.add(internal); + multipartMixedBoundary = null; + currentFileUpload = null; + duringMixedMode = false; + globalBodySize += internal.size(); + } + headerFinalized = true; + } else { + throw new ErrorDataEncoderException("Header already encoded"); + } + + HttpHeaders headers = request.headers(); + List contentTypes = headers.getAll(HttpHeaderNames.CONTENT_TYPE); + List transferEncoding = headers.getAll(HttpHeaderNames.TRANSFER_ENCODING); + if (contentTypes != null) { + headers.remove(HttpHeaderNames.CONTENT_TYPE); + for (String contentType : contentTypes) { + // "multipart/form-data; boundary=--89421926422648" + String lowercased = contentType.toLowerCase(); + if (lowercased.startsWith(HttpHeaderValues.MULTIPART_FORM_DATA.toString()) || + lowercased.startsWith(HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED.toString())) { + // ignore + } else { + headers.add(HttpHeaderNames.CONTENT_TYPE, contentType); + } + } + } + if (isMultipart) { + String value = HttpHeaderValues.MULTIPART_FORM_DATA + "; " + HttpHeaderValues.BOUNDARY + '=' + + multipartDataBoundary; + headers.add(HttpHeaderNames.CONTENT_TYPE, value); + } else { + // Not multipart + headers.add(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_X_WWW_FORM_URLENCODED); + } + // Now consider size for 
chunk or not + long realSize = globalBodySize; + if (!isMultipart) { + realSize -= 1; // last '&' removed + } + iterator = multipartHttpDatas.listIterator(); + + headers.set(HttpHeaderNames.CONTENT_LENGTH, String.valueOf(realSize)); + if (realSize > HttpPostBodyUtil.chunkSize || isMultipart) { + isChunked = true; + if (transferEncoding != null) { + headers.remove(HttpHeaderNames.TRANSFER_ENCODING); + for (CharSequence v : transferEncoding) { + if (HttpHeaderValues.CHUNKED.contentEqualsIgnoreCase(v)) { + // ignore + } else { + headers.add(HttpHeaderNames.TRANSFER_ENCODING, v); + } + } + } + HttpUtil.setTransferEncodingChunked(request, true); + + // wrap to hide the possible content + return new WrappedHttpRequest(request); + } else { + // get the only one body and set it to the request + HttpContent chunk = nextChunk(); + if (request instanceof FullHttpRequest) { + FullHttpRequest fullRequest = (FullHttpRequest) request; + ByteBuf chunkContent = chunk.content(); + if (fullRequest.content() != chunkContent) { + fullRequest.content().clear().writeBytes(chunkContent); + chunkContent.release(); + } + return fullRequest; + } else { + return new WrappedFullHttpRequest(request, chunk); + } + } + } + + /** + * @return True if the request is by Chunk + */ + public boolean isChunked() { + return isChunked; + } + + /** + * Encode one attribute + * + * @return the encoded attribute + * @throws ErrorDataEncoderException + * if the encoding is in error + */ + @SuppressWarnings("unchecked") + private String encodeAttribute(String s, Charset charset) throws ErrorDataEncoderException { + if (s == null) { + return ""; + } + try { + String encoded = URLEncoder.encode(s, charset.name()); + if (encoderMode == EncoderMode.RFC3986) { + for (Map.Entry entry : percentEncodings) { + String replacement = entry.getValue(); + encoded = entry.getKey().matcher(encoded).replaceAll(replacement); + } + } + return encoded; + } catch (UnsupportedEncodingException e) { + throw new 
ErrorDataEncoderException(charset.name(), e); + } + } + + /** + * The ByteBuf currently used by the encoder + */ + private ByteBuf currentBuffer; + /** + * The current InterfaceHttpData to encode (used if more chunks are available) + */ + private InterfaceHttpData currentData; + /** + * If not multipart, does the currentBuffer stands for the Key or for the Value + */ + private boolean isKey = true; + + /** + * + * @return the next ByteBuf to send as an HttpChunk and modifying currentBuffer accordingly + */ + private ByteBuf fillByteBuf() { + int length = currentBuffer.readableBytes(); + if (length > HttpPostBodyUtil.chunkSize) { + return currentBuffer.readRetainedSlice(HttpPostBodyUtil.chunkSize); + } else { + // to continue + ByteBuf slice = currentBuffer; + currentBuffer = null; + return slice; + } + } + + /** + * From the current context (currentBuffer and currentData), returns the next HttpChunk (if possible) trying to get + * sizeleft bytes more into the currentBuffer. This is the Multipart version. 
     *
     * @param sizeleft
     *            the number of bytes to try to get from currentData
     * @return the next HttpChunk or null if not enough bytes were found
     * @throws ErrorDataEncoderException
     *             if the encoding is in error
     */
    private HttpContent encodeNextChunkMultipart(int sizeleft) throws ErrorDataEncoderException {
        if (currentData == null) {
            // nothing pending: caller must advance the iterator first
            return null;
        }
        ByteBuf buffer;
        if (currentData instanceof InternalAttribute) {
            // internal attributes (boundaries / part headers) are emitted whole
            buffer = ((InternalAttribute) currentData).toByteBuf();
            currentData = null;
        } else {
            try {
                buffer = ((HttpData) currentData).getChunk(sizeleft);
            } catch (IOException e) {
                throw new ErrorDataEncoderException(e);
            }
            if (buffer.capacity() == 0) {
                // end for current InterfaceHttpData, need more data
                currentData = null;
                return null;
            }
        }
        // accumulate into currentBuffer until a full chunk is available
        if (currentBuffer == null) {
            currentBuffer = buffer;
        } else {
            currentBuffer = wrappedBuffer(currentBuffer, buffer);
        }
        if (currentBuffer.readableBytes() < HttpPostBodyUtil.chunkSize) {
            // not enough buffered for a chunk yet; move on to the next data item
            currentData = null;
            return null;
        }
        buffer = fillByteBuf();
        return new DefaultHttpContent(buffer);
    }

    /**
     * From the current context (currentBuffer and currentData), returns the next HttpChunk (if possible) trying to get
     * sizeleft bytes more into the currentBuffer. This is the UrlEncoded version.
     *
     * @param sizeleft
     *            the number of bytes to try to get from currentData
     * @return the next HttpChunk or null if not enough bytes were found
     * @throws ErrorDataEncoderException
     *             if the encoding is in error
     */
    private HttpContent encodeNextChunkUrlEncoded(int sizeleft) throws ErrorDataEncoderException {
        if (currentData == null) {
            return null;
        }
        int size = sizeleft;
        ByteBuf buffer;

        // Set name=
        if (isKey) {
            // emit the "name=" prefix for the attribute currently being encoded
            String key = currentData.getName();
            buffer = wrappedBuffer(key.getBytes(charset));
            isKey = false;
            if (currentBuffer == null) {
                currentBuffer = wrappedBuffer(buffer, wrappedBuffer("=".getBytes(charset)));
            } else {
                currentBuffer = wrappedBuffer(currentBuffer, buffer, wrappedBuffer("=".getBytes(charset)));
            }
            // continue: account for the key and the '=' just appended
            size -= buffer.readableBytes() + 1;
            if (currentBuffer.readableBytes() >= HttpPostBodyUtil.chunkSize) {
                buffer = fillByteBuf();
                return new DefaultHttpContent(buffer);
            }
        }

        // Put value into buffer
        try {
            buffer = ((HttpData) currentData).getChunk(size);
        } catch (IOException e) {
            throw new ErrorDataEncoderException(e);
        }

        // Figure out delimiter: '&' between pairs, nothing after the last pair
        ByteBuf delimiter = null;
        if (buffer.readableBytes() < size) {
            // value fully consumed: next call starts with a new key
            isKey = true;
            delimiter = iterator.hasNext() ? wrappedBuffer("&".getBytes(charset)) : null;
        }

        // End for current InterfaceHttpData, need potentially more data
        if (buffer.capacity() == 0) {
            currentData = null;
            if (currentBuffer == null) {
                if (delimiter == null) {
                    return null;
                } else {
                    currentBuffer = delimiter;
                }
            } else {
                if (delimiter != null) {
                    currentBuffer = wrappedBuffer(currentBuffer, delimiter);
                }
            }
            if (currentBuffer.readableBytes() >= HttpPostBodyUtil.chunkSize) {
                buffer = fillByteBuf();
                return new DefaultHttpContent(buffer);
            }
            return null;
        }

        // Put it all together: name=value&
        if (currentBuffer == null) {
            if (delimiter != null) {
                currentBuffer = wrappedBuffer(buffer, delimiter);
            } else {
                currentBuffer = buffer;
            }
        } else {
            if (delimiter != null) {
                currentBuffer = wrappedBuffer(currentBuffer, buffer, delimiter);
            } else {
                currentBuffer = wrappedBuffer(currentBuffer, buffer);
            }
        }

        // end for current InterfaceHttpData, need more data
        if (currentBuffer.readableBytes() < HttpPostBodyUtil.chunkSize) {
            currentData = null;
            isKey = true;
            return null;
        }

        buffer = fillByteBuf();
        return new DefaultHttpContent(buffer);
    }

    @Override
    public void close() throws Exception {
        // NO since the user can want to reuse (broadcast for instance)
        // cleanFiles();
    }

    @Deprecated
    @Override
    public HttpContent readChunk(ChannelHandlerContext ctx) throws Exception {
        return readChunk(ctx.alloc());
    }

    /**
     * Returns the next available HttpChunk. The caller is responsible to test if this chunk is the last one (isLast()),
     * in order to stop calling this method.
     *
     * @return the next available HttpChunk
     * @throws ErrorDataEncoderException
     *             if the encoding is in error
     */
    @Override
    public HttpContent readChunk(ByteBufAllocator allocator) throws Exception {
        if (isLastChunkSent) {
            return null;
        } else {
            HttpContent nextChunk = nextChunk();
            // track bytes handed out so progress() can report them
            globalProgress += nextChunk.content().readableBytes();
            return nextChunk;
        }
    }

    /**
     * Returns the next available HttpChunk. The caller is responsible to test if this chunk is the last one (isLast()),
     * in order to stop calling this method.
     *
     * @return the next available HttpChunk
     * @throws ErrorDataEncoderException
     *             if the encoding is in error
     */
    private HttpContent nextChunk() throws ErrorDataEncoderException {
        if (isLastChunk) {
            isLastChunkSent = true;
            return LastHttpContent.EMPTY_LAST_CONTENT;
        }
        // first test if previous buffer is not empty
        int size = calculateRemainingSize();
        if (size <= 0) {
            // NextChunk from buffer: a full chunk is already buffered
            ByteBuf buffer = fillByteBuf();
            return new DefaultHttpContent(buffer);
        }
        // size > 0
        if (currentData != null) {
            // continue to read data from the item left over by the previous call
            HttpContent chunk;
            if (isMultipart) {
                chunk = encodeNextChunkMultipart(size);
            } else {
                chunk = encodeNextChunkUrlEncoded(size);
            }
            if (chunk != null) {
                // NextChunk from data
                return chunk;
            }
            size = calculateRemainingSize();
        }
        if (!iterator.hasNext()) {
            return lastChunk();
        }
        // pull items until a full chunk accumulates or the data runs out
        while (size > 0 && iterator.hasNext()) {
            currentData = iterator.next();
            HttpContent chunk;
            if (isMultipart) {
                chunk = encodeNextChunkMultipart(size);
            } else {
                chunk = encodeNextChunkUrlEncoded(size);
            }
            if (chunk == null) {
                // not enough
                size = calculateRemainingSize();
                continue;
            }
            // NextChunk from data
            return chunk;
        }
        // end since no more data
        return lastChunk();
    }

    // Remaining capacity of the chunk being built: chunkSize minus what is already buffered.
    private int calculateRemainingSize() {
        int size = HttpPostBodyUtil.chunkSize;
        if (currentBuffer != null) {
            size -= currentBuffer.readableBytes();
        }
        return size;
    }

    // Emits the final non-empty buffer, or the terminating empty last-content marker.
    private HttpContent lastChunk() {
        isLastChunk = true;
        if (currentBuffer == null) {
            isLastChunkSent = true;
            // LastChunk with no more data
            return LastHttpContent.EMPTY_LAST_CONTENT;
        }
        // NextChunk as last non empty from buffer
        ByteBuf buffer = currentBuffer;
        currentBuffer = null;
        return new DefaultHttpContent(buffer);
    }

    @Override
    public boolean isEndOfInput() throws Exception {
        return isLastChunkSent;
    }

    @Override
    public long length() {
        // url-encoded bodies are accumulated with a trailing '&' that is never sent
        return isMultipart? globalBodySize : globalBodySize - 1;
    }

    @Override
    public long progress() {
        return globalProgress;
    }

    /**
     * Exception when an error occurs while encoding
     */
    public static class ErrorDataEncoderException extends Exception {
        private static final long serialVersionUID = 5020247425493164465L;

        public ErrorDataEncoderException() {
        }

        public ErrorDataEncoderException(String msg) {
            super(msg);
        }

        public ErrorDataEncoderException(Throwable cause) {
            super(cause);
        }

        public ErrorDataEncoderException(String msg, Throwable cause) {
            super(msg, cause);
        }
    }

    /**
     * Delegating wrapper that exposes the original request's head while the body is streamed in chunks.
     */
    private static class WrappedHttpRequest implements HttpRequest {
        private final HttpRequest request;
        WrappedHttpRequest(HttpRequest request) {
            this.request = request;
        }

        @Override
        public HttpRequest setProtocolVersion(HttpVersion version) {
            request.setProtocolVersion(version);
            return this;
        }

        @Override
        public HttpRequest setMethod(HttpMethod method) {
            request.setMethod(method);
            return this;
        }

        @Override
        public HttpRequest setUri(String uri) {
            request.setUri(uri);
            return this;
        }

        @Override
        public HttpMethod getMethod() {
            return request.method();
        }

        @Override
        public HttpMethod method() {
            return request.method();
        }

        @Override
        public String getUri() {
            return request.uri();
        }

        @Override
        public String uri() {
            return request.uri();
        }

        @Override
        public HttpVersion getProtocolVersion() {
            return request.protocolVersion();
        }

        @Override
        public HttpVersion protocolVersion() {
            return request.protocolVersion();
        }

        @Override
        public HttpHeaders headers() {
            return request.headers();
        }

        @Override
        public DecoderResult decoderResult() {
            return request.decoderResult();
        }

        @Override
        @Deprecated
        public DecoderResult getDecoderResult() {
            return request.getDecoderResult();
        }

        @Override
        public void setDecoderResult(DecoderResult result) {
            request.setDecoderResult(result);
        }
    }

    /**
     * Full-request wrapper pairing the original request head with the single encoded body chunk;
     * reference counting is delegated to that content.
     */
    private static final class WrappedFullHttpRequest extends WrappedHttpRequest implements FullHttpRequest {
        private final HttpContent content;

        private WrappedFullHttpRequest(HttpRequest request, HttpContent content) {
            super(request);
            this.content = content;
        }

        @Override
        public FullHttpRequest setProtocolVersion(HttpVersion version) {
            super.setProtocolVersion(version);
            return this;
        }

        @Override
        public FullHttpRequest setMethod(HttpMethod method) {
            super.setMethod(method);
            return this;
        }

        @Override
        public FullHttpRequest setUri(String uri) {
            super.setUri(uri);
            return this;
        }

        @Override
        public FullHttpRequest copy() {
            return replace(content().copy());
        }

        @Override
        public FullHttpRequest duplicate() {
            return replace(content().duplicate());
        }

        @Override
        public FullHttpRequest retainedDuplicate() {
            return replace(content().retainedDuplicate());
        }

        @Override
        public FullHttpRequest replace(ByteBuf content) {
            DefaultFullHttpRequest duplicate = new DefaultFullHttpRequest(protocolVersion(), method(), uri(), content);
            duplicate.headers().set(headers());
            duplicate.trailingHeaders().set(trailingHeaders());
            return duplicate;
        }

        @Override
        public FullHttpRequest retain(int increment) {
            content.retain(increment);
            return this;
        }

        @Override
        public FullHttpRequest retain() {
            content.retain();
            return this;
        }

        @Override
        public FullHttpRequest touch() {
+ content.touch(); + return this; + } + + @Override + public FullHttpRequest touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public ByteBuf content() { + return content.content(); + } + + @Override + public HttpHeaders trailingHeaders() { + if (content instanceof LastHttpContent) { + return ((LastHttpContent) content).trailingHeaders(); + } else { + return EmptyHttpHeaders.INSTANCE; + } + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoder.java new file mode 100644 index 0000000..b8e7058 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoder.java @@ -0,0 +1,784 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http.QueryStringDecoder; +import io.netty.handler.codec.http.multipart.HttpPostBodyUtil.SeekAheadOptimize; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.EndOfDataDecoderException; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.ErrorDataDecoderException; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.MultiPartStatus; +import io.netty.handler.codec.http.multipart.HttpPostRequestDecoder.NotEnoughDataDecoderException; +import io.netty.util.ByteProcessor; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; + +import java.io.IOException; +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.TreeMap; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * This decoder will decode Body and can handle POST BODY. + * + * You MUST call {@link #destroy()} after completion to release all resources. 
+ * + */ +public class HttpPostStandardRequestDecoder implements InterfaceHttpPostRequestDecoder { + + /** + * Factory used to create InterfaceHttpData + */ + private final HttpDataFactory factory; + + /** + * Request to decode + */ + private final HttpRequest request; + + /** + * Default charset to use + */ + private final Charset charset; + + /** + * Does the last chunk already received + */ + private boolean isLastChunk; + + /** + * HttpDatas from Body + */ + private final List bodyListHttpData = new ArrayList(); + + /** + * HttpDatas as Map from Body + */ + private final Map> bodyMapHttpData = new TreeMap>( + CaseIgnoringComparator.INSTANCE); + + /** + * The current channelBuffer + */ + private ByteBuf undecodedChunk; + + /** + * Body HttpDatas current position + */ + private int bodyListHttpDataRank; + + /** + * Current getStatus + */ + private MultiPartStatus currentStatus = MultiPartStatus.NOTSTARTED; + + /** + * The current Attribute that is currently in decode process + */ + private Attribute currentAttribute; + + private boolean destroyed; + + private int discardThreshold = HttpPostRequestDecoder.DEFAULT_DISCARD_THRESHOLD; + + /** + * + * @param request + * the request to decode + * @throws NullPointerException + * for request + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostStandardRequestDecoder(HttpRequest request) { + this(new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE), request, HttpConstants.DEFAULT_CHARSET); + } + + /** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to decode + * @throws NullPointerException + * for request or factory + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostStandardRequestDecoder(HttpDataFactory factory, HttpRequest request) { + this(factory, request, HttpConstants.DEFAULT_CHARSET); + } + + 
/** + * + * @param factory + * the factory used to create InterfaceHttpData + * @param request + * the request to decode + * @param charset + * the charset to use as default + * @throws NullPointerException + * for request or charset or factory + * @throws ErrorDataDecoderException + * if the default charset was wrong when decoding or other + * errors + */ + public HttpPostStandardRequestDecoder(HttpDataFactory factory, HttpRequest request, Charset charset) { + this.request = checkNotNull(request, "request"); + this.charset = checkNotNull(charset, "charset"); + this.factory = checkNotNull(factory, "factory"); + try { + if (request instanceof HttpContent) { + // Offer automatically if the given request is as type of HttpContent + // See #1089 + offer((HttpContent) request); + } else { + parseBody(); + } + } catch (Throwable e) { + destroy(); + PlatformDependent.throwException(e); + } + } + + private void checkDestroyed() { + if (destroyed) { + throw new IllegalStateException(HttpPostStandardRequestDecoder.class.getSimpleName() + + " was destroyed already"); + } + } + + /** + * True if this request is a Multipart request + * + * @return True if this request is a Multipart request + */ + @Override + public boolean isMultipart() { + checkDestroyed(); + return false; + } + + /** + * Set the amount of bytes after which read bytes in the buffer should be discarded. + * Setting this lower gives lower memory usage but with the overhead of more memory copies. + * Use {@code 0} to disable it. + */ + @Override + public void setDiscardThreshold(int discardThreshold) { + this.discardThreshold = checkPositiveOrZero(discardThreshold, "discardThreshold"); + } + + /** + * Return the threshold in bytes after which read data in the buffer should be discarded. + */ + @Override + public int getDiscardThreshold() { + return discardThreshold; + } + + /** + * This getMethod returns a List of all HttpDatas from body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return the list of HttpDatas from Body part for POST getMethod + * @throws NotEnoughDataDecoderException + * Need more chunks + */ + @Override + public List getBodyHttpDatas() { + checkDestroyed(); + + if (!isLastChunk) { + throw new NotEnoughDataDecoderException(); + } + return bodyListHttpData; + } + + /** + * This getMethod returns a List of all HttpDatas with the given name from + * body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return All Body HttpDatas with the given name (ignore case) + * @throws NotEnoughDataDecoderException + * need more chunks + */ + @Override + public List getBodyHttpDatas(String name) { + checkDestroyed(); + + if (!isLastChunk) { + throw new NotEnoughDataDecoderException(); + } + return bodyMapHttpData.get(name); + } + + /** + * This getMethod returns the first InterfaceHttpData with the given name from + * body.
     *
     * If chunked, all chunks must have been offered using offer() method. If
     * not, NotEnoughDataDecoderException will be raised.
     *
     * @return The first Body InterfaceHttpData with the given name (ignore
     *         case)
     * @throws NotEnoughDataDecoderException
     *             need more chunks
     */
    @Override
    public InterfaceHttpData getBodyHttpData(String name) {
        checkDestroyed();

        if (!isLastChunk) {
            throw new NotEnoughDataDecoderException();
        }
        List list = bodyMapHttpData.get(name);
        if (list != null) {
            return list.get(0);
        }
        return null;
    }

    /**
     * Initialized the internals from a new chunk
     *
     * @param content
     *            the new received chunk
     * @throws ErrorDataDecoderException
     *             if there is a problem with the charset decoding or other
     *             errors
     */
    @Override
    public HttpPostStandardRequestDecoder offer(HttpContent content) {
        checkDestroyed();

        if (content instanceof LastHttpContent) {
            isLastChunk = true;
        }

        ByteBuf buf = content.content();
        if (undecodedChunk == null) {
            undecodedChunk =
                    // Since the Handler will release the incoming later on, we need to copy it
                    //
                    // We are explicit allocate a buffer and NOT calling copy() as otherwise it may set a maxCapacity
                    // which is not really usable for us as we may exceed it once we add more bytes.
                    buf.alloc().buffer(buf.readableBytes()).writeBytes(buf);
        } else {
            undecodedChunk.writeBytes(buf);
        }
        parseBody();
        // Reclaim memory once the accumulated buffer grows past the configured threshold.
        if (undecodedChunk != null && undecodedChunk.writerIndex() > discardThreshold) {
            if (undecodedChunk.refCnt() == 1) {
                // It's safe to call discardBytes() as we are the only owner of the buffer.
                undecodedChunk.discardReadBytes();
            } else {
                // There seems to be multiple references of the buffer. Let's copy the data and release the buffer to
                // ensure we can give back memory to the system.
                ByteBuf buffer = undecodedChunk.alloc().buffer(undecodedChunk.readableBytes());
                buffer.writeBytes(undecodedChunk);
                undecodedChunk.release();
                undecodedChunk = buffer;
            }
        }
        return this;
    }

    /**
     * True if at current status, there is an available decoded
     * InterfaceHttpData from the Body.
     *
     * This method works for chunked and not chunked request.
     *
     * @return True if at current status, there is a decoded InterfaceHttpData
     * @throws EndOfDataDecoderException
     *             No more data will be available
     */
    @Override
    public boolean hasNext() {
        checkDestroyed();

        if (currentStatus == MultiPartStatus.EPILOGUE) {
            // OK except if end of list
            if (bodyListHttpDataRank >= bodyListHttpData.size()) {
                throw new EndOfDataDecoderException();
            }
        }
        return !bodyListHttpData.isEmpty() && bodyListHttpDataRank < bodyListHttpData.size();
    }

    /**
     * Returns the next available InterfaceHttpData or null if, at the time it
     * is called, there is no more available InterfaceHttpData. A subsequent
     * call to offer(httpChunk) could enable more data.
     *
     * Be sure to call {@link InterfaceHttpData#release()} after you are done
     * with processing to make sure to not leak any resources
     *
     * @return the next available InterfaceHttpData or null if none
     * @throws EndOfDataDecoderException
     *             No more data will be available
     */
    @Override
    public InterfaceHttpData next() {
        checkDestroyed();

        if (hasNext()) {
            return bodyListHttpData.get(bodyListHttpDataRank++);
        }
        return null;
    }

    @Override
    public InterfaceHttpData currentPartialHttpData() {
        return currentAttribute;
    }

    /**
     * This method will parse as much as possible data and fill the list and map
     *
     * @throws ErrorDataDecoderException
     *             if there is a problem with the charset decoding or other
     *             errors
     */
    private void parseBody() {
        if (currentStatus == MultiPartStatus.PREEPILOGUE || currentStatus == MultiPartStatus.EPILOGUE) {
            if (isLastChunk) {
                currentStatus = MultiPartStatus.EPILOGUE;
            }
            return;
        }
        parseBodyAttributes();
    }

    /**
     * Utility function to add a new decoded data
     */
    protected void addHttpData(InterfaceHttpData data) {
        if (data == null) {
            return;
        }
        List datas = bodyMapHttpData.get(data.getName());
        if (datas == null) {
            datas = new ArrayList(1);
            bodyMapHttpData.put(data.getName(), datas);
        }
        datas.add(data);
        bodyListHttpData.add(data);
    }

    /**
     * This method fill the map and list with as much Attribute as possible from
     * Body in not Multipart mode.
     *
     * @throws ErrorDataDecoderException
     *             if there is a problem with the charset decoding or other
     *             errors
     */
    private void parseBodyAttributesStandard() {
        // byte-wise state machine over key=value&key=value pairs
        int firstpos = undecodedChunk.readerIndex();
        int currentpos = firstpos;
        int equalpos;
        int ampersandpos;
        if (currentStatus == MultiPartStatus.NOTSTARTED) {
            currentStatus = MultiPartStatus.DISPOSITION;
        }
        boolean contRead = true;
        try {
            while (undecodedChunk.isReadable() && contRead) {
                char read = (char) undecodedChunk.readUnsignedByte();
                currentpos++;
                switch (currentStatus) {
                case DISPOSITION:// search '='
                    if (read == '=') {
                        currentStatus = MultiPartStatus.FIELD;
                        equalpos = currentpos - 1;
                        String key = decodeAttribute(undecodedChunk.toString(firstpos, equalpos - firstpos, charset),
                                charset);
                        currentAttribute = factory.createAttribute(request, key);
                        firstpos = currentpos;
                    } else if (read == '&') { // special empty FIELD
                        currentStatus = MultiPartStatus.DISPOSITION;
                        ampersandpos = currentpos - 1;
                        String key = decodeAttribute(
                                undecodedChunk.toString(firstpos, ampersandpos - firstpos, charset), charset);
                        // Some weird request bodies start with an '&' character, eg: &name=J&age=17.
                        // In that case, key would be "", will get exception:
                        // java.lang.IllegalArgumentException: Param 'name' must not be empty;
                        // Just check and skip empty key.
                        if (!key.isEmpty()) {
                            currentAttribute = factory.createAttribute(request, key);
                            currentAttribute.setValue(""); // empty
                            addHttpData(currentAttribute);
                        }
                        currentAttribute = null;
                        firstpos = currentpos;
                        contRead = true;
                    }
                    break;
                case FIELD:// search '&' or end of line
                    if (read == '&') {
                        currentStatus = MultiPartStatus.DISPOSITION;
                        ampersandpos = currentpos - 1;
                        setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                        firstpos = currentpos;
                        contRead = true;
                    } else if (read == HttpConstants.CR) {
                        if (undecodedChunk.isReadable()) {
                            read = (char) undecodedChunk.readUnsignedByte();
                            currentpos++;
                            if (read == HttpConstants.LF) {
                                currentStatus = MultiPartStatus.PREEPILOGUE;
                                ampersandpos = currentpos - 2;
                                setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                                firstpos = currentpos;
                                contRead = false;
                            } else {
                                // Error
                                throw new ErrorDataDecoderException("Bad end of line");
                            }
                        } else {
                            // lone CR at buffer end: rewind and wait for the next chunk
                            currentpos--;
                        }
                    } else if (read == HttpConstants.LF) {
                        currentStatus = MultiPartStatus.PREEPILOGUE;
                        ampersandpos = currentpos - 1;
                        setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                        firstpos = currentpos;
                        contRead = false;
                    }
                    break;
                default:
                    // just stop
                    contRead = false;
                }
            }
            if (isLastChunk && currentAttribute != null) {
                // special case: flush the trailing attribute of the final chunk
                ampersandpos = currentpos;
                if (ampersandpos > firstpos) {
                    setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                } else if (!currentAttribute.isCompleted()) {
                    setFinalBuffer(Unpooled.EMPTY_BUFFER);
                }
                firstpos = currentpos;
                currentStatus = MultiPartStatus.EPILOGUE;
            } else if (contRead && currentAttribute != null && currentStatus == MultiPartStatus.FIELD) {
                // reset index except if to continue in case of FIELD status
                currentAttribute.addContent(undecodedChunk.retainedSlice(firstpos, currentpos - firstpos),
                        false);
                firstpos = currentpos;
            }
            undecodedChunk.readerIndex(firstpos);
        } catch (ErrorDataDecoderException e) {
            // error while decoding
            undecodedChunk.readerIndex(firstpos);
            throw e;
        } catch (IOException e) {
            // error while decoding
            undecodedChunk.readerIndex(firstpos);
            throw new ErrorDataDecoderException(e);
        } catch (IllegalArgumentException e) {
            // error while decoding
            undecodedChunk.readerIndex(firstpos);
            throw new ErrorDataDecoderException(e);
        }
    }

    /**
     * This getMethod fill the map and list with as much Attribute as possible from
     * Body in not Multipart mode.
     *
     * Fast path used when {@code undecodedChunk} is backed by an accessible array;
     * otherwise delegates to {@link #parseBodyAttributesStandard()}.
     *
     * @throws ErrorDataDecoderException
     *             if there is a problem with the charset decoding or other
     *             errors
     */
    private void parseBodyAttributes() {
        if (undecodedChunk == null) {
            return;
        }
        if (!undecodedChunk.hasArray()) {
            // no backing array: fall back to the byte-by-byte implementation
            parseBodyAttributesStandard();
            return;
        }
        SeekAheadOptimize sao = new SeekAheadOptimize(undecodedChunk);
        int firstpos = undecodedChunk.readerIndex();
        int currentpos = firstpos;
        int equalpos;
        int ampersandpos;
        if (currentStatus == MultiPartStatus.NOTSTARTED) {
            currentStatus = MultiPartStatus.DISPOSITION;
        }
        boolean contRead = true;
        try {
            loop: while (sao.pos < sao.limit) {
                char read = (char) (sao.bytes[sao.pos++] & 0xFF);
                currentpos++;
                switch (currentStatus) {
                case DISPOSITION:// search '='
                    if (read == '=') {
                        currentStatus = MultiPartStatus.FIELD;
                        equalpos = currentpos - 1;
                        String key = decodeAttribute(undecodedChunk.toString(firstpos, equalpos - firstpos, charset),
                                charset);
                        currentAttribute = factory.createAttribute(request, key);
                        firstpos = currentpos;
                    } else if (read == '&') { // special empty FIELD
                        currentStatus = MultiPartStatus.DISPOSITION;
                        ampersandpos = currentpos - 1;
                        String key = decodeAttribute(
                                undecodedChunk.toString(firstpos, ampersandpos - firstpos, charset), charset);
                        // Some weird request bodies start with an '&' char, eg: &name=J&age=17.
                        // In that case, key would be "", will get exception:
                        // java.lang.IllegalArgumentException: Param 'name' must not be empty;
                        // Just check and skip empty key.
                        if (!key.isEmpty()) {
                            currentAttribute = factory.createAttribute(request, key);
                            currentAttribute.setValue(""); // empty
                            addHttpData(currentAttribute);
                        }
                        currentAttribute = null;
                        firstpos = currentpos;
                        contRead = true;
                    }
                    break;
                case FIELD:// search '&' or end of line
                    if (read == '&') {
                        currentStatus = MultiPartStatus.DISPOSITION;
                        ampersandpos = currentpos - 1;
                        setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                        firstpos = currentpos;
                        contRead = true;
                    } else if (read == HttpConstants.CR) {
                        if (sao.pos < sao.limit) {
                            read = (char) (sao.bytes[sao.pos++] & 0xFF);
                            currentpos++;
                            if (read == HttpConstants.LF) {
                                currentStatus = MultiPartStatus.PREEPILOGUE;
                                // value ends before the CRLF pair
                                ampersandpos = currentpos - 2;
                                sao.setReadPosition(0);
                                setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                                firstpos = currentpos;
                                contRead = false;
                                break loop;
                            } else {
                                // Error
                                sao.setReadPosition(0);
                                throw new ErrorDataDecoderException("Bad end of line");
                            }
                        } else {
                            // lone CR at the end of the chunk: push it back and wait for more data
                            if (sao.limit > 0) {
                                currentpos--;
                            }
                        }
                    } else if (read == HttpConstants.LF) {
                        currentStatus = MultiPartStatus.PREEPILOGUE;
                        ampersandpos = currentpos - 1;
                        sao.setReadPosition(0);
                        setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                        firstpos = currentpos;
                        contRead = false;
                        break loop;
                    }
                    break;
                default:
                    // just stop
                    sao.setReadPosition(0);
                    contRead = false;
                    break loop;
                }
            }
            if (isLastChunk && currentAttribute != null) {
                // special case: last chunk with a pending attribute — flush whatever remains
                ampersandpos = currentpos;
                if (ampersandpos > firstpos) {
                    setFinalBuffer(undecodedChunk.retainedSlice(firstpos, ampersandpos - firstpos));
                } else if (!currentAttribute.isCompleted()) {
                    setFinalBuffer(Unpooled.EMPTY_BUFFER);
                }
                firstpos = currentpos;
                currentStatus = MultiPartStatus.EPILOGUE;
            } else if (contRead && currentAttribute != null && currentStatus == MultiPartStatus.FIELD) {
                // reset index except if to continue in case of FIELD getStatus
                currentAttribute.addContent(undecodedChunk.retainedSlice(firstpos, currentpos - firstpos),
                        false);
                firstpos = currentpos;
            }
            undecodedChunk.readerIndex(firstpos);
        } catch (ErrorDataDecoderException e) {
            // error while decoding
            undecodedChunk.readerIndex(firstpos);
            throw e;
        } catch (IOException e) {
            // error while decoding
            undecodedChunk.readerIndex(firstpos);
            throw new ErrorDataDecoderException(e);
        } catch (IllegalArgumentException e) {
            // error while decoding
            undecodedChunk.readerIndex(firstpos);
            throw new ErrorDataDecoderException(e);
        }
    }

    /**
     * Completes {@code currentAttribute} with {@code buffer} as its last content,
     * URL-decodes it when needed, registers it, and clears {@code currentAttribute}.
     */
    private void setFinalBuffer(ByteBuf buffer) throws IOException {
        currentAttribute.addContent(buffer, true);
        ByteBuf decodedBuf = decodeAttribute(currentAttribute.getByteBuf(), charset);
        if (decodedBuf != null) { // override content only when ByteBuf needed decoding
            currentAttribute.setContent(decodedBuf);
        }
        addHttpData(currentAttribute);
        currentAttribute = null;
    }

    /**
     * Decode component
     *
     * @return the decoded component
     */
    private static String decodeAttribute(String s, Charset charset) {
        try {
            return QueryStringDecoder.decodeComponent(s, charset);
        } catch (IllegalArgumentException e) {
            throw new ErrorDataDecoderException("Bad string: '" + s + '\'', e);
        }
    }

    /**
     * URL-decodes {@code b} in place-allocating style.
     *
     * @return a freshly allocated decoded buffer, or {@code null} when {@code b}
     *         contains no '%' or '+' and needs no decoding
     */
    private static ByteBuf decodeAttribute(ByteBuf b, Charset charset) {
        int firstEscaped = b.forEachByte(new UrlEncodedDetector());
        if (firstEscaped == -1) {
            return null; // nothing to decode
        }

        ByteBuf buf = b.alloc().buffer(b.readableBytes());
        UrlDecoder urlDecode = new UrlDecoder(buf);
        int idx = b.forEachByte(urlDecode);
        if (urlDecode.nextEscapedIdx != 0) { // incomplete hex byte
            if (idx == -1) {
                idx = b.readableBytes() - 1;
            }
            idx -= urlDecode.nextEscapedIdx - 1;
            // release the partially written output before failing
            buf.release();
            throw new ErrorDataDecoderException(
                    String.format("Invalid hex byte at index '%d' in string: '%s'", idx, b.toString(charset)));
        }

        return buf;
    }

    /**
     * Destroy the {@link HttpPostStandardRequestDecoder} and release all it resources. After this method
     * was called it is not possible to operate on it anymore.
     */
    @Override
    public void destroy() {
        // Release all data items, including those not yet pulled, only file based items
        cleanFiles();
        // Clean Memory based data
        for (InterfaceHttpData httpData : bodyListHttpData) {
            // Might have been already released by the user
            if (httpData.refCnt() > 0) {
                httpData.release();
            }
        }

        destroyed = true;

        if (undecodedChunk != null && undecodedChunk.refCnt() > 0) {
            undecodedChunk.release();
            undecodedChunk = null;
        }
    }

    /**
     * Clean all {@link HttpData}s for the current request.
     */
    @Override
    public void cleanFiles() {
        checkDestroyed();

        factory.cleanRequestHttpData(request);
    }

    /**
     * Remove the given FileUpload from the list of FileUploads to clean
     */
    @Override
    public void removeHttpDataFromClean(InterfaceHttpData data) {
        checkDestroyed();

        factory.removeHttpDataFromClean(request, data);
    }

    /**
     * Stops iteration at the first byte that would need URL-decoding ('%' or '+').
     */
    private static final class UrlEncodedDetector implements ByteProcessor {
        @Override
        public boolean process(byte value) throws Exception {
            return value != '%' && value != '+';
        }
    }

    /**
     * Streaming application/x-www-form-urlencoded decoder: writes decoded bytes to
     * {@code output}; {@code nextEscapedIdx} tracks progress inside a "%XY" escape
     * (0 = outside, 1 = after '%', 2 = after the high nibble).
     */
    private static final class UrlDecoder implements ByteProcessor {

        private final ByteBuf output;
        private int nextEscapedIdx;
        private byte hiByte;

        UrlDecoder(ByteBuf output) {
            this.output = output;
        }

        @Override
        public boolean process(byte value) {
            if (nextEscapedIdx != 0) {
                if (nextEscapedIdx == 1) {
                    hiByte = value;
                    ++nextEscapedIdx;
                } else {
                    int hi = StringUtil.decodeHexNibble((char) hiByte);
                    int lo = StringUtil.decodeHexNibble((char) value);
                    if (hi == -1 || lo == -1) {
                        // invalid hex digit: abort iteration; caller reports the error
                        ++nextEscapedIdx;
                        return false;
                    }
output.writeByte((hi << 4) + lo); + nextEscapedIdx = 0; + } + } else if (value == '%') { + nextEscapedIdx = 1; + } else if (value == '+') { + output.writeByte(' '); + } else { + output.writeByte(value); + } + return true; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpData.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpData.java new file mode 100644 index 0000000..8c15329 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpData.java @@ -0,0 +1,50 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.util.ReferenceCounted; + +/** + * Interface for all Objects that could be encoded/decoded using HttpPostRequestEncoder/Decoder + */ +public interface InterfaceHttpData extends Comparable, ReferenceCounted { + enum HttpDataType { + Attribute, FileUpload, InternalAttribute + } + + /** + * Returns the name of this InterfaceHttpData. 
+ */ + String getName(); + + /** + * + * @return The HttpDataType + */ + HttpDataType getHttpDataType(); + + @Override + InterfaceHttpData retain(); + + @Override + InterfaceHttpData retain(int increment); + + @Override + InterfaceHttpData touch(); + + @Override + InterfaceHttpData touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpPostRequestDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpPostRequestDecoder.java new file mode 100755 index 0000000..21ac13c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InterfaceHttpPostRequestDecoder.java @@ -0,0 +1,148 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.handler.codec.http.HttpContent; + +import java.util.List; + +/** + * This decoder will decode Body and can handle POST BODY. + * + * You MUST call {@link #destroy()} after completion to release all resources. + */ +public interface InterfaceHttpPostRequestDecoder { + /** + * True if this request is a Multipart request + * + * @return True if this request is a Multipart request + */ + boolean isMultipart(); + + /** + * Set the amount of bytes after which read bytes in the buffer should be discarded. 
+ * Setting this lower gives lower memory usage but with the overhead of more memory copies. + * Use {@code 0} to disable it. + */ + void setDiscardThreshold(int discardThreshold); + + /** + * Return the threshold in bytes after which read data in the buffer should be discarded. + */ + int getDiscardThreshold(); + + /** + * This getMethod returns a List of all HttpDatas from body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return the list of HttpDatas from Body part for POST getMethod + * @throws HttpPostRequestDecoder.NotEnoughDataDecoderException + * Need more chunks + */ + List getBodyHttpDatas(); + + /** + * This getMethod returns a List of all HttpDatas with the given name from + * body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return All Body HttpDatas with the given name (ignore case) + * @throws HttpPostRequestDecoder.NotEnoughDataDecoderException + * need more chunks + */ + List getBodyHttpDatas(String name); + + /** + * This getMethod returns the first InterfaceHttpData with the given name from + * body.
+ * + * If chunked, all chunks must have been offered using offer() getMethod. If + * not, NotEnoughDataDecoderException will be raised. + * + * @return The first Body InterfaceHttpData with the given name (ignore + * case) + * @throws HttpPostRequestDecoder.NotEnoughDataDecoderException + * need more chunks + */ + InterfaceHttpData getBodyHttpData(String name); + + /** + * Initialized the internals from a new chunk + * + * @param content + * the new received chunk + * @throws HttpPostRequestDecoder.ErrorDataDecoderException + * if there is a problem with the charset decoding or other + * errors + */ + InterfaceHttpPostRequestDecoder offer(HttpContent content); + + /** + * True if at current getStatus, there is an available decoded + * InterfaceHttpData from the Body. + * + * This getMethod works for chunked and not chunked request. + * + * @return True if at current getStatus, there is a decoded InterfaceHttpData + * @throws HttpPostRequestDecoder.EndOfDataDecoderException + * No more data will be available + */ + boolean hasNext(); + + /** + * Returns the next available InterfaceHttpData or null if, at the time it + * is called, there is no more available InterfaceHttpData. A subsequent + * call to offer(httpChunk) could enable more data. + * + * Be sure to call {@link InterfaceHttpData#release()} after you are done + * with processing to make sure to not leak any resources + * + * @return the next available InterfaceHttpData or null if none + * @throws HttpPostRequestDecoder.EndOfDataDecoderException + * No more data will be available + */ + InterfaceHttpData next(); + + /** + * Returns the current InterfaceHttpData if currently in decoding status, + * meaning all data are not yet within, or null if there is no InterfaceHttpData + * currently in decoding status (either because none yet decoded or none currently partially + * decoded). Full decoded ones are accessible through hasNext() and next() methods. 
+ * + * @return the current InterfaceHttpData if currently in decoding status or null if none. + */ + InterfaceHttpData currentPartialHttpData(); + + /** + * Destroy the {@link InterfaceHttpPostRequestDecoder} and release all it resources. After this method + * was called it is not possible to operate on it anymore. + */ + void destroy(); + + /** + * Clean all HttpDatas (on Disk) for the current request. + */ + void cleanFiles(); + + /** + * Remove the given FileUpload from the list of FileUploads to clean + */ + void removeHttpDataFromClean(InterfaceHttpData data); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InternalAttribute.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InternalAttribute.java new file mode 100644 index 0000000..438b20a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/InternalAttribute.java @@ -0,0 +1,155 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.internal.ObjectUtil; + +import java.nio.charset.Charset; +import java.util.ArrayList; +import java.util.List; + +/** + * This Attribute is only for Encoder use to insert special command between object if needed + * (like Multipart Mixed mode) + */ +final class InternalAttribute extends AbstractReferenceCounted implements InterfaceHttpData { + private final List value = new ArrayList(); + private final Charset charset; + private int size; + + InternalAttribute(Charset charset) { + this.charset = charset; + } + + @Override + public HttpDataType getHttpDataType() { + return HttpDataType.InternalAttribute; + } + + public void addValue(String value) { + ObjectUtil.checkNotNull(value, "value"); + ByteBuf buf = Unpooled.copiedBuffer(value, charset); + this.value.add(buf); + size += buf.readableBytes(); + } + + public void addValue(String value, int rank) { + ObjectUtil.checkNotNull(value, "value"); + ByteBuf buf = Unpooled.copiedBuffer(value, charset); + this.value.add(rank, buf); + size += buf.readableBytes(); + } + + public void setValue(String value, int rank) { + ObjectUtil.checkNotNull(value, "value"); + ByteBuf buf = Unpooled.copiedBuffer(value, charset); + ByteBuf old = this.value.set(rank, buf); + if (old != null) { + size -= old.readableBytes(); + old.release(); + } + size += buf.readableBytes(); + } + + @Override + public int hashCode() { + return getName().hashCode(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof InternalAttribute)) { + return false; + } + InternalAttribute attribute = (InternalAttribute) o; + return getName().equalsIgnoreCase(attribute.getName()); + } + + @Override + public int compareTo(InterfaceHttpData o) { + if (!(o instanceof InternalAttribute)) { + throw new ClassCastException("Cannot compare " + getHttpDataType() + + " 
with " + o.getHttpDataType()); + } + return compareTo((InternalAttribute) o); + } + + public int compareTo(InternalAttribute o) { + return getName().compareToIgnoreCase(o.getName()); + } + + @Override + public String toString() { + StringBuilder result = new StringBuilder(); + for (ByteBuf elt : value) { + result.append(elt.toString(charset)); + } + return result.toString(); + } + + public int size() { + return size; + } + + public ByteBuf toByteBuf() { + return Unpooled.compositeBuffer().addComponents(value).writerIndex(size()).readerIndex(0); + } + + @Override + public String getName() { + return "InternalAttribute"; + } + + @Override + protected void deallocate() { + // Do nothing + } + + @Override + public InterfaceHttpData retain() { + for (ByteBuf buf: value) { + buf.retain(); + } + return this; + } + + @Override + public InterfaceHttpData retain(int increment) { + for (ByteBuf buf: value) { + buf.retain(increment); + } + return this; + } + + @Override + public InterfaceHttpData touch() { + for (ByteBuf buf: value) { + buf.touch(); + } + return this; + } + + @Override + public InterfaceHttpData touch(Object hint) { + for (ByteBuf buf: value) { + buf.touch(hint); + } + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryAttribute.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryAttribute.java new file mode 100644 index 0000000..25a7bbf --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryAttribute.java @@ -0,0 +1,197 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
 * You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.multipart;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.handler.codec.http.HttpConstants;
import io.netty.util.internal.ObjectUtil;

import java.io.IOException;
import java.nio.charset.Charset;

import static io.netty.buffer.Unpooled.*;

/**
 * Memory implementation of Attributes
 */
public class MemoryAttribute extends AbstractMemoryHttpData implements Attribute {

    public MemoryAttribute(String name) {
        this(name, HttpConstants.DEFAULT_CHARSET);
    }

    public MemoryAttribute(String name, long definedSize) {
        this(name, definedSize, HttpConstants.DEFAULT_CHARSET);
    }

    public MemoryAttribute(String name, Charset charset) {
        super(name, charset, 0);
    }

    public MemoryAttribute(String name, long definedSize, Charset charset) {
        super(name, charset, definedSize);
    }

    public MemoryAttribute(String name, String value) throws IOException {
        this(name, value, HttpConstants.DEFAULT_CHARSET); // Attribute have no default size
    }

    public MemoryAttribute(String name, String value, Charset charset) throws IOException {
        super(name, charset, 0); // Attribute have no default size
        setValue(value);
    }

    @Override
    public HttpDataType getHttpDataType() {
        return HttpDataType.Attribute;
    }

    @Override
    public String getValue() {
        // renders the whole in-memory buffer with this attribute's charset
        return getByteBuf().toString(getCharset());
    }

    @Override
    public void setValue(String value) throws IOException {
        ObjectUtil.checkNotNull(value, "value");
        byte[] bytes = value.getBytes(getCharset());
        // checkSize may throw before any buffer is allocated
        checkSize(bytes.length);
        ByteBuf buffer = wrappedBuffer(bytes);
        if (definedSize > 0) {
            definedSize = buffer.readableBytes();
        }
        setContent(buffer);
    }

    @Override
    public void addContent(ByteBuf buffer, boolean last) throws IOException {
        int localsize = buffer.readableBytes();
        try {
            checkSize(size + localsize);
        } catch (IOException e) {
            // ownership of buffer was passed to us: release it before propagating
            buffer.release();
            throw e;
        }
        if (definedSize > 0 && definedSize < size + localsize) {
            definedSize = size + localsize;
        }
        super.addContent(buffer, last);
    }

    @Override
    public int hashCode() {
        return getName().hashCode();
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof Attribute)) {
            return false;
        }
        Attribute attribute = (Attribute) o;
        return getName().equalsIgnoreCase(attribute.getName());
    }

    @Override
    public int compareTo(InterfaceHttpData other) {
        if (!(other instanceof Attribute)) {
            throw new ClassCastException("Cannot compare " + getHttpDataType() +
                    " with " + other.getHttpDataType());
        }
        return compareTo((Attribute) other);
    }

    public int compareTo(Attribute o) {
        return getName().compareToIgnoreCase(o.getName());
    }

    @Override
    public String toString() {
        return getName() + '=' + getValue();
    }

    @Override
    public Attribute copy() {
        final ByteBuf content = content();
        return replace(content != null ? content.copy() : null);
    }

    @Override
    public Attribute duplicate() {
        final ByteBuf content = content();
        return replace(content != null ? content.duplicate() : null);
    }

    @Override
    public Attribute retainedDuplicate() {
        ByteBuf content = content();
        if (content != null) {
            content = content.retainedDuplicate();
            boolean success = false;
            try {
                Attribute duplicate = replace(content);
                success = true;
                return duplicate;
            } finally {
                // release the retained duplicate if replace() threw
                if (!success) {
                    content.release();
                }
            }
        } else {
            return replace(null);
        }
    }

    @Override
    public Attribute replace(ByteBuf content) {
        MemoryAttribute attr = new MemoryAttribute(getName());
        attr.setCharset(getCharset());
        if (content != null) {
            try {
                attr.setContent(content);
            } catch (IOException e) {
                throw new ChannelException(e);
            }
        }
        attr.setCompleted(isCompleted());
        return attr;
    }

    @Override
    public Attribute retain() {
        super.retain();
        return this;
    }

    @Override
    public Attribute retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public Attribute touch() {
        super.touch();
        return this;
    }

    @Override
    public Attribute touch(Object hint) {
        super.touch(hint);
        return this;
    }
}
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryFileUpload.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryFileUpload.java
new file mode 100644
index 0000000..de1cbe2
--- /dev/null
+++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MemoryFileUpload.java
@@ -0,0 +1,188 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
 * You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.multipart;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelException;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.util.internal.ObjectUtil;

import java.io.IOException;
import java.nio.charset.Charset;

/**
 * Default FileUpload implementation that stores file into memory.
 *
 * Warning: be aware of the memory limitation.
 */
public class MemoryFileUpload extends AbstractMemoryHttpData implements FileUpload {

    private String filename;

    private String contentType;

    // may be null: setContentTransferEncoding accepts null, unlike filename/contentType
    private String contentTransferEncoding;

    public MemoryFileUpload(String name, String filename, String contentType,
            String contentTransferEncoding, Charset charset, long size) {
        super(name, charset, size);
        setFilename(filename);
        setContentType(contentType);
        setContentTransferEncoding(contentTransferEncoding);
    }

    @Override
    public HttpDataType getHttpDataType() {
        return HttpDataType.FileUpload;
    }

    @Override
    public String getFilename() {
        return filename;
    }

    @Override
    public void setFilename(String filename) {
        this.filename = ObjectUtil.checkNotNull(filename, "filename");
    }

    @Override
    public int hashCode() {
        return FileUploadUtil.hashCode(this);
    }

    @Override
    public boolean equals(Object o) {
        return o instanceof FileUpload && FileUploadUtil.equals(this, (FileUpload) o);
    }

    @Override
    public int compareTo(InterfaceHttpData o) {
        if (!(o instanceof FileUpload)) {
            throw new ClassCastException("Cannot compare " + getHttpDataType() +
                    " with " + o.getHttpDataType());
        }
        return compareTo((FileUpload) o);
    }

    public int compareTo(FileUpload o) {
        return FileUploadUtil.compareTo(this, o);
    }

    @Override
    public void setContentType(String contentType) {
        this.contentType = ObjectUtil.checkNotNull(contentType, "contentType");
    }

    @Override
    public String getContentType() {
        return contentType;
    }

    @Override
    public String getContentTransferEncoding() {
        return contentTransferEncoding;
    }

    @Override
    public void setContentTransferEncoding(String contentTransferEncoding) {
        this.contentTransferEncoding = contentTransferEncoding;
    }

    @Override
    public String toString() {
        // renders a multipart-style header block describing this upload
        return HttpHeaderNames.CONTENT_DISPOSITION + ": " +
                HttpHeaderValues.FORM_DATA + "; " + HttpHeaderValues.NAME + "=\"" + getName() +
                "\"; " + HttpHeaderValues.FILENAME + "=\"" + filename + "\"\r\n" +
                HttpHeaderNames.CONTENT_TYPE + ": " + contentType +
                (getCharset() != null? "; " + HttpHeaderValues.CHARSET + '=' + getCharset().name() + "\r\n" : "\r\n") +
                HttpHeaderNames.CONTENT_LENGTH + ": " + length() + "\r\n" +
                "Completed: " + isCompleted() +
                "\r\nIsInMemory: " + isInMemory();
    }

    @Override
    public FileUpload copy() {
        final ByteBuf content = content();
        return replace(content != null ? content.copy() : content);
    }

    @Override
    public FileUpload duplicate() {
        final ByteBuf content = content();
        return replace(content != null ? content.duplicate() : content);
    }

    @Override
    public FileUpload retainedDuplicate() {
        ByteBuf content = content();
        if (content != null) {
            content = content.retainedDuplicate();
            boolean success = false;
            try {
                FileUpload duplicate = replace(content);
                success = true;
                return duplicate;
            } finally {
                // release the retained duplicate if replace() threw
                if (!success) {
                    content.release();
                }
            }
        } else {
            return replace(null);
        }
    }

    @Override
    public FileUpload replace(ByteBuf content) {
        MemoryFileUpload upload = new MemoryFileUpload(
                getName(), getFilename(), getContentType(), getContentTransferEncoding(), getCharset(), size);
        if (content != null) {
            try {
                upload.setContent(content);
            } catch (IOException e) {
                throw new ChannelException(e);
            }
        }
        upload.setCompleted(isCompleted());
        return upload;
    }

    @Override
    public FileUpload retain() {
        super.retain();
        return this;
    }

    @Override
    public FileUpload retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public FileUpload touch() {
        super.touch();
        return this;
    }

    @Override
    public FileUpload touch(Object hint) {
        super.touch(hint);
        return this;
    }
}
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedAttribute.java
b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedAttribute.java new file mode 100644 index 0000000..3fd2bfc --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedAttribute.java @@ -0,0 +1,157 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpConstants; + +import java.io.IOException; +import java.nio.charset.Charset; + +/** + * Mixed implementation using both in Memory and in File with a limit of size + */ +public class MixedAttribute extends AbstractMixedHttpData implements Attribute { + public MixedAttribute(String name, long limitSize) { + this(name, limitSize, HttpConstants.DEFAULT_CHARSET); + } + + public MixedAttribute(String name, long definedSize, long limitSize) { + this(name, definedSize, limitSize, HttpConstants.DEFAULT_CHARSET); + } + + public MixedAttribute(String name, long limitSize, Charset charset) { + this(name, limitSize, charset, DiskAttribute.baseDirectory, DiskAttribute.deleteOnExitTemporaryFile); + } + + public MixedAttribute(String name, long limitSize, Charset charset, String baseDir, boolean deleteOnExit) { + this(name, 0, limitSize, charset, baseDir, deleteOnExit); + } + + public MixedAttribute(String name, long definedSize, long limitSize, 
Charset charset) { + this(name, definedSize, limitSize, charset, + DiskAttribute.baseDirectory, DiskAttribute.deleteOnExitTemporaryFile); + } + + public MixedAttribute(String name, long definedSize, long limitSize, Charset charset, + String baseDir, boolean deleteOnExit) { + super(limitSize, baseDir, deleteOnExit, + new MemoryAttribute(name, definedSize, charset)); + } + + public MixedAttribute(String name, String value, long limitSize) { + this(name, value, limitSize, HttpConstants.DEFAULT_CHARSET, + DiskAttribute.baseDirectory, DiskFileUpload.deleteOnExitTemporaryFile); + } + + public MixedAttribute(String name, String value, long limitSize, Charset charset) { + this(name, value, limitSize, charset, + DiskAttribute.baseDirectory, DiskFileUpload.deleteOnExitTemporaryFile); + } + + private static Attribute makeInitialAttributeFromValue(String name, String value, long limitSize, Charset charset, + String baseDir, boolean deleteOnExit) { + if (value.length() > limitSize) { + try { + return new DiskAttribute(name, value, charset, baseDir, deleteOnExit); + } catch (IOException e) { + // revert to Memory mode + try { + return new MemoryAttribute(name, value, charset); + } catch (IOException ignore) { + throw new IllegalArgumentException(e); + } + } + } else { + try { + return new MemoryAttribute(name, value, charset); + } catch (IOException e) { + throw new IllegalArgumentException(e); + } + } + } + + public MixedAttribute(String name, String value, long limitSize, Charset charset, + String baseDir, boolean deleteOnExit) { + super(limitSize, baseDir, deleteOnExit, + makeInitialAttributeFromValue(name, value, limitSize, charset, baseDir, deleteOnExit)); + } + + @Override + public String getValue() throws IOException { + return wrapped.getValue(); + } + + @Override + public void setValue(String value) throws IOException { + wrapped.setValue(value); + } + + @Override + Attribute makeDiskData() { + DiskAttribute diskAttribute = new DiskAttribute(getName(), definedLength(), 
baseDir, deleteOnExit); + diskAttribute.setMaxSize(getMaxSize()); + return diskAttribute; + } + + @Override + public Attribute copy() { + // for binary compatibility + return super.copy(); + } + + @Override + public Attribute duplicate() { + // for binary compatibility + return super.duplicate(); + } + + @Override + public Attribute replace(ByteBuf content) { + // for binary compatibility + return super.replace(content); + } + + @Override + public Attribute retain() { + // for binary compatibility + return super.retain(); + } + + @Override + public Attribute retain(int increment) { + // for binary compatibility + return super.retain(increment); + } + + @Override + public Attribute retainedDuplicate() { + // for binary compatibility + return super.retainedDuplicate(); + } + + @Override + public Attribute touch() { + // for binary compatibility + return super.touch(); + } + + @Override + public Attribute touch(Object hint) { + // for binary compatibility + return super.touch(hint); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedFileUpload.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedFileUpload.java new file mode 100644 index 0000000..ceb9890 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/MixedFileUpload.java @@ -0,0 +1,131 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; + +import java.nio.charset.Charset; + +/** + * Mixed implementation using both in Memory and in File with a limit of size + */ +public class MixedFileUpload extends AbstractMixedHttpData implements FileUpload { + + public MixedFileUpload(String name, String filename, String contentType, + String contentTransferEncoding, Charset charset, long size, + long limitSize) { + this(name, filename, contentType, contentTransferEncoding, + charset, size, limitSize, DiskFileUpload.baseDirectory, DiskFileUpload.deleteOnExitTemporaryFile); + } + + public MixedFileUpload(String name, String filename, String contentType, + String contentTransferEncoding, Charset charset, long size, + long limitSize, String baseDir, boolean deleteOnExit) { + super(limitSize, baseDir, deleteOnExit, + size > limitSize? + new DiskFileUpload(name, filename, contentType, contentTransferEncoding, charset, size, baseDir, + deleteOnExit) : + new MemoryFileUpload(name, filename, contentType, contentTransferEncoding, charset, size) + ); + } + + @Override + public String getContentTransferEncoding() { + return wrapped.getContentTransferEncoding(); + } + + @Override + public String getFilename() { + return wrapped.getFilename(); + } + + @Override + public void setContentTransferEncoding(String contentTransferEncoding) { + wrapped.setContentTransferEncoding(contentTransferEncoding); + } + + @Override + public void setFilename(String filename) { + wrapped.setFilename(filename); + } + + @Override + public void setContentType(String contentType) { + wrapped.setContentType(contentType); + } + + @Override + public String getContentType() { + return wrapped.getContentType(); + } + + @Override + FileUpload makeDiskData() { + DiskFileUpload diskFileUpload = new DiskFileUpload( + getName(), getFilename(), getContentType(), 
getContentTransferEncoding(), getCharset(), definedLength(), + baseDir, deleteOnExit); + diskFileUpload.setMaxSize(getMaxSize()); + return diskFileUpload; + } + + @Override + public FileUpload copy() { + // for binary compatibility + return super.copy(); + } + + @Override + public FileUpload duplicate() { + // for binary compatibility + return super.duplicate(); + } + + @Override + public FileUpload retainedDuplicate() { + // for binary compatibility + return super.retainedDuplicate(); + } + + @Override + public FileUpload replace(ByteBuf content) { + // for binary compatibility + return super.replace(content); + } + + @Override + public FileUpload touch() { + // for binary compatibility + return super.touch(); + } + + @Override + public FileUpload touch(Object hint) { + // for binary compatibility + return super.touch(hint); + } + + @Override + public FileUpload retain() { + // for binary compatibility + return super.retain(); + } + + @Override + public FileUpload retain(int increment) { + // for binary compatibility + return super.retain(increment); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/package-info.java new file mode 100644 index 0000000..9575df5 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/multipart/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * HTTP multipart support. + */ +package io.netty.handler.codec.http.multipart; diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/package-info.java new file mode 100644 index 0000000..305e125 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder, decoder and their related message types for HTTP. + */ +package io.netty.handler.codec.http; diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/BinaryWebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/BinaryWebSocketFrame.java new file mode 100644 index 0000000..9ea4288 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/BinaryWebSocketFrame.java @@ -0,0 +1,100 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** + * Web Socket frame containing binary data. + */ +public class BinaryWebSocketFrame extends WebSocketFrame { + + /** + * Creates a new empty binary frame. + */ + public BinaryWebSocketFrame() { + super(Unpooled.buffer(0)); + } + + /** + * Creates a new binary frame with the specified binary data. The final fragment flag is set to true. + * + * @param binaryData + * the content of the frame. + */ + public BinaryWebSocketFrame(ByteBuf binaryData) { + super(binaryData); + } + + /** + * Creates a new binary frame with the specified binary data and the final fragment flag. + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param binaryData + * the content of the frame. 
+ */ + public BinaryWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(finalFragment, rsv, binaryData); + } + + @Override + public BinaryWebSocketFrame copy() { + return (BinaryWebSocketFrame) super.copy(); + } + + @Override + public BinaryWebSocketFrame duplicate() { + return (BinaryWebSocketFrame) super.duplicate(); + } + + @Override + public BinaryWebSocketFrame retainedDuplicate() { + return (BinaryWebSocketFrame) super.retainedDuplicate(); + } + + @Override + public BinaryWebSocketFrame replace(ByteBuf content) { + return new BinaryWebSocketFrame(isFinalFragment(), rsv(), content); + } + + @Override + public BinaryWebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public BinaryWebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public BinaryWebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public BinaryWebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java new file mode 100644 index 0000000..ce8b430 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrame.java @@ -0,0 +1,206 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.StringUtil; + +/** + * Web Socket Frame for closing the connection. + */ +public class CloseWebSocketFrame extends WebSocketFrame { + + /** + * Creates a new empty close frame. + */ + public CloseWebSocketFrame() { + super(Unpooled.buffer(0)); + } + + /** + * Creates a new empty close frame with closing status code and reason text + * + * @param status + * Status code as per RFC 6455. For + * example, 1000 indicates normal closure. + */ + public CloseWebSocketFrame(WebSocketCloseStatus status) { + this(requireValidStatusCode(status.code()), status.reasonText()); + } + + /** + * Creates a new empty close frame with closing status code and reason text + * + * @param status + * Status code as per RFC 6455. For + * example, 1000 indicates normal closure. + * @param reasonText + * Reason text. Set to null if no text. + */ + public CloseWebSocketFrame(WebSocketCloseStatus status, String reasonText) { + this(requireValidStatusCode(status.code()), reasonText); + } + + /** + * Creates a new empty close frame with closing status code and reason text + * + * @param statusCode + * Integer status code as per RFC 6455. For + * example, 1000 indicates normal closure. + * @param reasonText + * Reason text. Set to null if no text. + */ + public CloseWebSocketFrame(int statusCode, String reasonText) { + this(true, 0, requireValidStatusCode(statusCode), reasonText); + } + + /** + * Creates a new close frame with no losing status code and no reason text + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions. 
+ */ + public CloseWebSocketFrame(boolean finalFragment, int rsv) { + this(finalFragment, rsv, Unpooled.buffer(0)); + } + + /** + * Creates a new close frame with closing status code and reason text + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param statusCode + * Integer status code as per RFC 6455. For + * example, 1000 indicates normal closure. + * @param reasonText + * Reason text. Set to null if no text. + */ + public CloseWebSocketFrame(boolean finalFragment, int rsv, int statusCode, String reasonText) { + super(finalFragment, rsv, newBinaryData(requireValidStatusCode(statusCode), reasonText)); + } + + private static ByteBuf newBinaryData(int statusCode, String reasonText) { + if (reasonText == null) { + reasonText = StringUtil.EMPTY_STRING; + } + + ByteBuf binaryData = Unpooled.buffer(2 + reasonText.length()); + binaryData.writeShort(statusCode); + if (!reasonText.isEmpty()) { + binaryData.writeCharSequence(reasonText, CharsetUtil.UTF_8); + } + return binaryData; + } + + /** + * Creates a new close frame + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param binaryData + * the content of the frame. Must be 2 byte integer followed by optional UTF-8 encoded string. + */ + public CloseWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(finalFragment, rsv, binaryData); + } + + /** + * Returns the closing status code as per RFC 6455. If + * a status code is set, -1 is returned. + */ + public int statusCode() { + ByteBuf binaryData = content(); + if (binaryData == null || binaryData.readableBytes() < 2) { + return -1; + } + + return binaryData.getUnsignedShort(binaryData.readerIndex()); + } + + /** + * Returns the reason text as per RFC 6455 If a reason + * text is not supplied, an empty string is returned. 
+ */ + public String reasonText() { + ByteBuf binaryData = content(); + if (binaryData == null || binaryData.readableBytes() <= 2) { + return ""; + } + + return binaryData.toString(binaryData.readerIndex() + 2, binaryData.readableBytes() - 2, CharsetUtil.UTF_8); + } + + @Override + public CloseWebSocketFrame copy() { + return (CloseWebSocketFrame) super.copy(); + } + + @Override + public CloseWebSocketFrame duplicate() { + return (CloseWebSocketFrame) super.duplicate(); + } + + @Override + public CloseWebSocketFrame retainedDuplicate() { + return (CloseWebSocketFrame) super.retainedDuplicate(); + } + + @Override + public CloseWebSocketFrame replace(ByteBuf content) { + return new CloseWebSocketFrame(isFinalFragment(), rsv(), content); + } + + @Override + public CloseWebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public CloseWebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public CloseWebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public CloseWebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } + + static int requireValidStatusCode(int statusCode) { + if (WebSocketCloseStatus.isValidStatusCode(statusCode)) { + return statusCode; + } else { + throw new IllegalArgumentException("WebSocket close status code does NOT comply with RFC-6455: " + + statusCode); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/ContinuationWebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/ContinuationWebSocketFrame.java new file mode 100644 index 0000000..166bda0 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/ContinuationWebSocketFrame.java @@ -0,0 +1,137 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); 
you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; + +/** + * Web Socket continuation frame containing continuation text or binary data. This is used for + * fragmented messages where the contents of a messages is contained more than 1 frame. + */ +public class ContinuationWebSocketFrame extends WebSocketFrame { + + /** + * Creates a new empty continuation frame. + */ + public ContinuationWebSocketFrame() { + this(Unpooled.buffer(0)); + } + + /** + * Creates a new continuation frame with the specified binary data. The final fragment flag is + * set to true. + * + * @param binaryData the content of the frame. + */ + public ContinuationWebSocketFrame(ByteBuf binaryData) { + super(binaryData); + } + + /** + * Creates a new continuation frame with the specified binary data. + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param binaryData + * the content of the frame. + */ + public ContinuationWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(finalFragment, rsv, binaryData); + } + + /** + * Creates a new continuation frame with the specified text data + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param text + * text content of the frame. 
+ */ + public ContinuationWebSocketFrame(boolean finalFragment, int rsv, String text) { + this(finalFragment, rsv, fromText(text)); + } + + /** + * Returns the text data in this frame. + */ + public String text() { + return content().toString(CharsetUtil.UTF_8); + } + + /** + * Sets the string for this frame. + * + * @param text + * text to store. + */ + private static ByteBuf fromText(String text) { + if (text == null || text.isEmpty()) { + return Unpooled.EMPTY_BUFFER; + } else { + return Unpooled.copiedBuffer(text, CharsetUtil.UTF_8); + } + } + + @Override + public ContinuationWebSocketFrame copy() { + return (ContinuationWebSocketFrame) super.copy(); + } + + @Override + public ContinuationWebSocketFrame duplicate() { + return (ContinuationWebSocketFrame) super.duplicate(); + } + + @Override + public ContinuationWebSocketFrame retainedDuplicate() { + return (ContinuationWebSocketFrame) super.retainedDuplicate(); + } + + @Override + public ContinuationWebSocketFrame replace(ByteBuf content) { + return new ContinuationWebSocketFrame(isFinalFragment(), rsv(), content); + } + + @Override + public ContinuationWebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public ContinuationWebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public ContinuationWebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public ContinuationWebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CorruptedWebSocketFrameException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CorruptedWebSocketFrameException.java new file mode 100644 index 0000000..92022df --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/CorruptedWebSocketFrameException.java @@ -0,0 +1,64 @@ +/* + * Copyright 2019 The Netty 
Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.DecoderException; + +/** + * An {@link DecoderException} which is thrown when the received {@link WebSocketFrame} data could not be decoded by + * an inbound handler. + */ +public final class CorruptedWebSocketFrameException extends CorruptedFrameException { + + private static final long serialVersionUID = 3918055132492988338L; + + private final WebSocketCloseStatus closeStatus; + + /** + * Creates a new instance. + */ + public CorruptedWebSocketFrameException() { + this(WebSocketCloseStatus.PROTOCOL_ERROR, null, null); + } + + /** + * Creates a new instance. + */ + public CorruptedWebSocketFrameException(WebSocketCloseStatus status, String message, Throwable cause) { + super(message == null ? status.reasonText() : message, cause); + closeStatus = status; + } + + /** + * Creates a new instance. + */ + public CorruptedWebSocketFrameException(WebSocketCloseStatus status, String message) { + this(status, message, null); + } + + /** + * Creates a new instance. 
+ */ + public CorruptedWebSocketFrameException(WebSocketCloseStatus status, Throwable cause) { + this(status, null, cause); + } + + public WebSocketCloseStatus closeStatus() { + return closeStatus; + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PingWebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PingWebSocketFrame.java new file mode 100644 index 0000000..7208bb8 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PingWebSocketFrame.java @@ -0,0 +1,100 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** + * Web Socket frame containing binary data. + */ +public class PingWebSocketFrame extends WebSocketFrame { + + /** + * Creates a new empty ping frame. + */ + public PingWebSocketFrame() { + super(true, 0, Unpooled.buffer(0)); + } + + /** + * Creates a new ping frame with the specified binary data. + * + * @param binaryData + * the content of the frame. + */ + public PingWebSocketFrame(ByteBuf binaryData) { + super(binaryData); + } + + /** + * Creates a new ping frame with the specified binary data. 
+ * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param binaryData + * the content of the frame. + */ + public PingWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(finalFragment, rsv, binaryData); + } + + @Override + public PingWebSocketFrame copy() { + return (PingWebSocketFrame) super.copy(); + } + + @Override + public PingWebSocketFrame duplicate() { + return (PingWebSocketFrame) super.duplicate(); + } + + @Override + public PingWebSocketFrame retainedDuplicate() { + return (PingWebSocketFrame) super.retainedDuplicate(); + } + + @Override + public PingWebSocketFrame replace(ByteBuf content) { + return new PingWebSocketFrame(isFinalFragment(), rsv(), content); + } + + @Override + public PingWebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public PingWebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public PingWebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public PingWebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PongWebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PongWebSocketFrame.java new file mode 100644 index 0000000..79cb9a7 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/PongWebSocketFrame.java @@ -0,0 +1,100 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** + * Web Socket frame containing binary data. + */ +public class PongWebSocketFrame extends WebSocketFrame { + + /** + * Creates a new empty pong frame. + */ + public PongWebSocketFrame() { + super(Unpooled.buffer(0)); + } + + /** + * Creates a new pong frame with the specified binary data. + * + * @param binaryData + * the content of the frame. + */ + public PongWebSocketFrame(ByteBuf binaryData) { + super(binaryData); + } + + /** + * Creates a new pong frame with the specified binary data + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param binaryData + * the content of the frame. 
+ */ + public PongWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(finalFragment, rsv, binaryData); + } + + @Override + public PongWebSocketFrame copy() { + return (PongWebSocketFrame) super.copy(); + } + + @Override + public PongWebSocketFrame duplicate() { + return (PongWebSocketFrame) super.duplicate(); + } + + @Override + public PongWebSocketFrame retainedDuplicate() { + return (PongWebSocketFrame) super.retainedDuplicate(); + } + + @Override + public PongWebSocketFrame replace(ByteBuf content) { + return new PongWebSocketFrame(isFinalFragment(), rsv(), content); + } + + @Override + public PongWebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public PongWebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public PongWebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public PongWebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/TextWebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/TextWebSocketFrame.java new file mode 100644 index 0000000..f520cf6 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/TextWebSocketFrame.java @@ -0,0 +1,140 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; + +/** + * Web Socket text frame. + */ +public class TextWebSocketFrame extends WebSocketFrame { + + /** + * Creates a new empty text frame. + */ + public TextWebSocketFrame() { + super(Unpooled.buffer(0)); + } + + /** + * Creates a new text frame with the specified text string. The final fragment flag is set to true. + * + * @param text + * String to put in the frame. + */ + public TextWebSocketFrame(String text) { + super(fromText(text)); + } + + /** + * Creates a new text frame with the specified binary data. The final fragment flag is set to true. + * + * @param binaryData + * the content of the frame. + */ + public TextWebSocketFrame(ByteBuf binaryData) { + super(binaryData); + } + + /** + * Creates a new text frame with the specified text string. The final fragment flag is set to true. + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param text + * String to put in the frame. + */ + public TextWebSocketFrame(boolean finalFragment, int rsv, String text) { + super(finalFragment, rsv, fromText(text)); + } + + private static ByteBuf fromText(String text) { + if (text == null || text.isEmpty()) { + return Unpooled.EMPTY_BUFFER; + } else { + return Unpooled.copiedBuffer(text, CharsetUtil.UTF_8); + } + } + + /** + * Creates a new text frame with the specified binary data and the final fragment flag. + * + * @param finalFragment + * flag indicating if this frame is the final fragment + * @param rsv + * reserved bits used for protocol extensions + * @param binaryData + * the content of the frame. 
+ */ + public TextWebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(finalFragment, rsv, binaryData); + } + + /** + * Returns the text data in this frame. + */ + public String text() { + return content().toString(CharsetUtil.UTF_8); + } + + @Override + public TextWebSocketFrame copy() { + return (TextWebSocketFrame) super.copy(); + } + + @Override + public TextWebSocketFrame duplicate() { + return (TextWebSocketFrame) super.duplicate(); + } + + @Override + public TextWebSocketFrame retainedDuplicate() { + return (TextWebSocketFrame) super.retainedDuplicate(); + } + + @Override + public TextWebSocketFrame replace(ByteBuf content) { + return new TextWebSocketFrame(isFinalFragment(), rsv(), content); + } + + @Override + public TextWebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public TextWebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public TextWebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public TextWebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java new file mode 100644 index 0000000..435e6da --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8FrameValidator.java @@ -0,0 +1,120 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +/** + * + */ +public class Utf8FrameValidator extends ChannelInboundHandlerAdapter { + + private final boolean closeOnProtocolViolation; + + private int fragmentedFramesCount; + private Utf8Validator utf8Validator; + + public Utf8FrameValidator() { + this(true); + } + + public Utf8FrameValidator(boolean closeOnProtocolViolation) { + this.closeOnProtocolViolation = closeOnProtocolViolation; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof WebSocketFrame) { + WebSocketFrame frame = (WebSocketFrame) msg; + + try { + // Processing for possible fragmented messages for text and binary + // frames + if (((WebSocketFrame) msg).isFinalFragment()) { + // Final frame of the sequence. 
Apparently ping frames are + // allowed in the middle of a fragmented message + if (!(frame instanceof PingWebSocketFrame)) { + fragmentedFramesCount = 0; + + // Check text for UTF8 correctness + if ((frame instanceof TextWebSocketFrame) || + (utf8Validator != null && utf8Validator.isChecking())) { + // Check UTF-8 correctness for this payload + checkUTF8String(frame.content()); + + // This does a second check to make sure UTF-8 + // correctness for entire text message + utf8Validator.finish(); + } + } + } else { + // Not final frame so we can expect more frames in the + // fragmented sequence + if (fragmentedFramesCount == 0) { + // First text or binary frame for a fragmented set + if (frame instanceof TextWebSocketFrame) { + checkUTF8String(frame.content()); + } + } else { + // Subsequent frames - only check if init frame is text + if (utf8Validator != null && utf8Validator.isChecking()) { + checkUTF8String(frame.content()); + } + } + + // Increment counter + fragmentedFramesCount++; + } + } catch (CorruptedWebSocketFrameException e) { + protocolViolation(ctx, frame, e); + } + } + + super.channelRead(ctx, msg); + } + + private void checkUTF8String(ByteBuf buffer) { + if (utf8Validator == null) { + utf8Validator = new Utf8Validator(); + } + utf8Validator.check(buffer); + } + + private void protocolViolation(ChannelHandlerContext ctx, WebSocketFrame frame, + CorruptedWebSocketFrameException ex) { + frame.release(); + if (closeOnProtocolViolation && ctx.channel().isOpen()) { + WebSocketCloseStatus closeStatus = ex.closeStatus(); + String reasonText = ex.getMessage(); + if (reasonText == null) { + reasonText = closeStatus.reasonText(); + } + + CloseWebSocketFrame closeFrame = new CloseWebSocketFrame(closeStatus.code(), reasonText); + ctx.writeAndFlush(closeFrame).addListener(ChannelFutureListener.CLOSE); + } + + throw ex; + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause); 
+ } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8Validator.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8Validator.java new file mode 100644 index 0000000..d415c59 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/Utf8Validator.java @@ -0,0 +1,110 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Adaptation of https://bjoern.hoehrmann.de/utf-8/decoder/dfa/ + * + * Copyright (c) 2008-2009 Bjoern Hoehrmann + * + * Permission is hereby granted, free of charge, to any person obtaining a copy of this software + * and associated documentation files (the "Software"), to deal in the Software without restriction, + * including without limitation the rights to use, copy, modify, merge, publish, distribute, + * sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is + * furnished to do so, subject to the following conditions: + * + * The above copyright notice and this permission notice shall be included in all copies or + * substantial portions of the Software. + * + * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING + * BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. 
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, + * DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, + * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.util.ByteProcessor; + +/** + * Checks UTF8 bytes for validity + */ +final class Utf8Validator implements ByteProcessor { + private static final int UTF8_ACCEPT = 0; + private static final int UTF8_REJECT = 12; + + private static final byte[] TYPES = { 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1, + 1, 1, 1, 1, 1, 1, 1, 1, 1, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 9, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 8, + 8, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, 2, + 2, 2, 10, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 3, 4, 3, 3, 11, 6, 6, 6, 5, 8, 8, 8, 8, 8, + 8, 8, 8, 8, 8, 8 }; + + private static final byte[] STATES = { 0, 12, 24, 36, 60, 96, 84, 12, 12, 12, 48, 72, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 0, 12, 12, 12, 12, 12, 0, 12, 0, 12, 12, + 12, 24, 12, 12, 12, 12, 12, 24, 12, 24, 12, 12, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, + 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 24, 12, 12, 12, 12, 12, 12, 12, 12, 12, 36, + 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, 12, 36, 12, 36, 12, 12, 12, 36, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12 }; + + @SuppressWarnings("RedundantFieldInitialization") + private int state = UTF8_ACCEPT; + private 
int codep; + private boolean checking; + + public void check(ByteBuf buffer) { + checking = true; + buffer.forEachByte(this); + } + + void check(ByteBuf buffer, int index, int length) { + checking = true; + buffer.forEachByte(index, length, this); + } + + public void finish() { + checking = false; + codep = 0; + if (state != UTF8_ACCEPT) { + state = UTF8_ACCEPT; + throw new CorruptedWebSocketFrameException( + WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "bytes are not UTF-8"); + } + } + + @Override + public boolean process(byte b) throws Exception { + byte type = TYPES[b & 0xFF]; + + codep = state != UTF8_ACCEPT ? b & 0x3f | codep << 6 : 0xff >> type & b; + + state = STATES[state + type]; + + if (state == UTF8_REJECT) { + checking = false; + throw new CorruptedWebSocketFrameException( + WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "bytes are not UTF-8"); + } + return true; + } + + public boolean isChecking() { + return checking; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameDecoder.java new file mode 100644 index 0000000..edbec96 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameDecoder.java @@ -0,0 +1,148 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ReplayingDecoder; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.internal.ObjectUtil; + +import java.util.List; + +import static io.netty.buffer.ByteBufUtil.readBytes; + +/** + * Decodes {@link ByteBuf}s into {@link WebSocketFrame}s. + *

+ * For the detailed instruction on adding add Web Socket support to your HTTP server, take a look into the + * WebSocketServer example located in the {@code io.netty.example.http.websocket} package. + */ +public class WebSocket00FrameDecoder extends ReplayingDecoder implements WebSocketFrameDecoder { + + static final int DEFAULT_MAX_FRAME_SIZE = 16384; + + private final long maxFrameSize; + private boolean receivedClosingHandshake; + + public WebSocket00FrameDecoder() { + this(DEFAULT_MAX_FRAME_SIZE); + } + + /** + * Creates a new instance of {@code WebSocketFrameDecoder} with the specified {@code maxFrameSize}. If the client + * sends a frame size larger than {@code maxFrameSize}, the channel will be closed. + * + * @param maxFrameSize + * the maximum frame size to decode + */ + public WebSocket00FrameDecoder(int maxFrameSize) { + this.maxFrameSize = maxFrameSize; + } + + /** + * Creates a new instance of {@code WebSocketFrameDecoder} with the specified {@code maxFrameSize}. If the client + * sends a frame size larger than {@code maxFrameSize}, the channel will be closed. + * + * @param decoderConfig + * Frames decoder configuration. + */ + public WebSocket00FrameDecoder(WebSocketDecoderConfig decoderConfig) { + this.maxFrameSize = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig").maxFramePayloadLength(); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + // Discard all data received if closing handshake was received before. + if (receivedClosingHandshake) { + in.skipBytes(actualReadableBytes()); + return; + } + + // Decode a frame otherwise. 
+ byte type = in.readByte(); + WebSocketFrame frame; + if ((type & 0x80) == 0x80) { + // If the MSB on type is set, decode the frame length + frame = decodeBinaryFrame(ctx, type, in); + } else { + // Decode a 0xff terminated UTF-8 string + frame = decodeTextFrame(ctx, in); + } + + if (frame != null) { + out.add(frame); + } + } + + private WebSocketFrame decodeBinaryFrame(ChannelHandlerContext ctx, byte type, ByteBuf buffer) { + long frameSize = 0; + int lengthFieldSize = 0; + byte b; + do { + b = buffer.readByte(); + frameSize <<= 7; + frameSize |= b & 0x7f; + if (frameSize > maxFrameSize) { + throw new TooLongFrameException(); + } + lengthFieldSize++; + if (lengthFieldSize > 8) { + // Perhaps a malicious peer? + throw new TooLongFrameException(); + } + } while ((b & 0x80) == 0x80); + + if (type == (byte) 0xFF && frameSize == 0) { + receivedClosingHandshake = true; + return new CloseWebSocketFrame(true, 0, ctx.alloc().buffer(0)); + } + ByteBuf payload = readBytes(ctx.alloc(), buffer, (int) frameSize); + return new BinaryWebSocketFrame(payload); + } + + private WebSocketFrame decodeTextFrame(ChannelHandlerContext ctx, ByteBuf buffer) { + int ridx = buffer.readerIndex(); + int rbytes = actualReadableBytes(); + int delimPos = buffer.indexOf(ridx, ridx + rbytes, (byte) 0xFF); + if (delimPos == -1) { + // Frame delimiter (0xFF) not found + if (rbytes > maxFrameSize) { + // Frame length exceeded the maximum + throw new TooLongFrameException(); + } else { + // Wait until more data is received + return null; + } + } + + int frameSize = delimPos - ridx; + if (frameSize > maxFrameSize) { + throw new TooLongFrameException(); + } + + ByteBuf binaryData = readBytes(ctx.alloc(), buffer, frameSize); + buffer.skipBytes(1); + + int ffDelimPos = binaryData.indexOf(binaryData.readerIndex(), binaryData.writerIndex(), (byte) 0xFF); + if (ffDelimPos >= 0) { + binaryData.release(); + throw new IllegalArgumentException("a text frame should not contain 0xFF."); + } + + return new 
TextWebSocketFrame(binaryData); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoder.java new file mode 100644 index 0000000..6b5d29e --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoder.java @@ -0,0 +1,99 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; + +import java.util.List; + +/** + * Encodes a {@link WebSocketFrame} into a {@link ByteBuf}. + *

+ * For the detailed instruction on adding add Web Socket support to your HTTP server, take a look into the + * WebSocketServer example located in the {@code io.netty.example.http.websocket} package. + */ +@Sharable +public class WebSocket00FrameEncoder extends MessageToMessageEncoder implements WebSocketFrameEncoder { + private static final ByteBuf _0X00 = Unpooled.unreleasableBuffer( + Unpooled.directBuffer(1, 1).writeByte(0x00)).asReadOnly(); + private static final ByteBuf _0XFF = Unpooled.unreleasableBuffer( + Unpooled.directBuffer(1, 1).writeByte((byte) 0xFF)).asReadOnly(); + private static final ByteBuf _0XFF_0X00 = Unpooled.unreleasableBuffer( + Unpooled.directBuffer(2, 2).writeByte((byte) 0xFF).writeByte(0x00)).asReadOnly(); + + @Override + protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List out) throws Exception { + if (msg instanceof TextWebSocketFrame) { + // Text frame + ByteBuf data = msg.content(); + + out.add(_0X00.duplicate()); + out.add(data.retain()); + out.add(_0XFF.duplicate()); + } else if (msg instanceof CloseWebSocketFrame) { + // Close frame, needs to call duplicate to allow multiple writes. + // See https://github.com/netty/netty/issues/2768 + out.add(_0XFF_0X00.duplicate()); + } else { + // Binary frame + ByteBuf data = msg.content(); + int dataLen = data.readableBytes(); + + ByteBuf buf = ctx.alloc().buffer(5); + boolean release = true; + try { + // Encode type. + buf.writeByte((byte) 0x80); + + // Encode length. + int b1 = dataLen >>> 28 & 0x7F; + int b2 = dataLen >>> 14 & 0x7F; + int b3 = dataLen >>> 7 & 0x7F; + int b4 = dataLen & 0x7F; + if (b1 == 0) { + if (b2 == 0) { + if (b3 != 0) { + buf.writeByte(b3 | 0x80); + } + buf.writeByte(b4); + } else { + buf.writeByte(b2 | 0x80); + buf.writeByte(b3 | 0x80); + buf.writeByte(b4); + } + } else { + buf.writeByte(b1 | 0x80); + buf.writeByte(b2 | 0x80); + buf.writeByte(b3 | 0x80); + buf.writeByte(b4); + } + + // Encode binary data. 
+ out.add(buf); + out.add(data.retain()); + release = false; + } finally { + if (release) { + buf.release(); + } + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameDecoder.java new file mode 100644 index 0000000..466004f --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameDecoder.java @@ -0,0 +1,115 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// (BSD License: https://www.opensource.org/licenses/bsd-license) +// +// Copyright (c) 2011, Joe Walnes and contributors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or +// without modification, are permitted provided that the +// following conditions are met: +// +// * Redistributions of source code must retain the above +// copyright notice, this list of conditions and the +// following disclaimer. +// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other +// materials provided with the distribution. 
+//
+// * Neither the name of the Webbit nor the names of
+// its contributors may be used to endorse or promote products
+// derived from this software without specific prior written
+// permission.
+//
+// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND
+// CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES,
+// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF
+// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR
+// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
+// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
+// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE
+// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR
+// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF
+// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
+// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT
+// OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
+// POSSIBILITY OF SUCH DAMAGE.
+
+package io.netty.handler.codec.http.websocketx;
+
+/**
+ * Decodes a web socket frame from wire protocol version 7 format. V7 is essentially the same as V8.
+ */
+public class WebSocket07FrameDecoder extends WebSocket08FrameDecoder {
+
+    /**
+     * Constructor
+     *
+     * @param expectMaskedFrames
+     *            Web socket servers must set this to true to process incoming masked payload. Client implementations
+     *            must set this to false.
+     * @param allowExtensions
+     *            Flag to allow reserved extension bits to be used or not
+     * @param maxFramePayloadLength
+     *            Maximum length of a frame's payload. Setting this to an appropriate value for your application
+     *            helps check for denial of service attacks.
+     */
+    public WebSocket07FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
+        this(WebSocketDecoderConfig.newBuilder()
+            .expectMaskedFrames(expectMaskedFrames)
+            .allowExtensions(allowExtensions)
+            .maxFramePayloadLength(maxFramePayloadLength)
+            .build());
+    }
+
+    /**
+     * Constructor
+     *
+     * @param expectMaskedFrames
+     *            Web socket servers must set this to true to process incoming masked payload. Client implementations
+     *            must set this to false.
+     * @param allowExtensions
+     *            Flag to allow reserved extension bits to be used or not
+     * @param maxFramePayloadLength
+     *            Maximum length of a frame's payload. Setting this to an appropriate value for your application
+     *            helps check for denial of service attacks.
+     * @param allowMaskMismatch
+     *            When set to true, frames which are not masked properly according to the standard will still be
+     *            accepted.
+     */
+    public WebSocket07FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength,
+                                   boolean allowMaskMismatch) {
+        this(WebSocketDecoderConfig.newBuilder()
+            .expectMaskedFrames(expectMaskedFrames)
+            .allowExtensions(allowExtensions)
+            .maxFramePayloadLength(maxFramePayloadLength)
+            .allowMaskMismatch(allowMaskMismatch)
+            .build());
+    }
+
+    /**
+     * Constructor
+     *
+     * @param decoderConfig
+     *            Frames decoder configuration.
+ */ + public WebSocket07FrameDecoder(WebSocketDecoderConfig decoderConfig) { + super(decoderConfig); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameEncoder.java new file mode 100644 index 0000000..f2604c9 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket07FrameEncoder.java @@ -0,0 +1,73 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// (BSD License: https://www.opensource.org/licenses/bsd-license) +// +// Copyright (c) 2011, Joe Walnes and contributors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or +// without modification, are permitted provided that the +// following conditions are met: +// +// * Redistributions of source code must retain the above +// copyright notice, this list of conditions and the +// following disclaimer. +// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other +// materials provided with the distribution. 
+// +// * Neither the name of the Webbit nor the names of +// its contributors may be used to endorse or promote products +// derived from this software without specific prior written +// permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +// CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +// OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +package io.netty.handler.codec.http.websocketx; + +/** + *

+ * Encodes a web socket frame into wire protocol version 7 format. V7 is essentially the same as V8. + *

+ */ +public class WebSocket07FrameEncoder extends WebSocket08FrameEncoder { + + /** + * Constructor + * + * @param maskPayload + * Web socket clients must set this to true to mask payload. Server implementations must set this to + * false. + */ + public WebSocket07FrameEncoder(boolean maskPayload) { + super(maskPayload); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java new file mode 100644 index 0000000..29d7694 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoder.java @@ -0,0 +1,488 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// (BSD License: https://www.opensource.org/licenses/bsd-license) +// +// Copyright (c) 2011, Joe Walnes and contributors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or +// without modification, are permitted provided that the +// following conditions are met: +// +// * Redistributions of source code must retain the above +// copyright notice, this list of conditions and the +// following disclaimer. 
+// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// * Neither the name of the Webbit nor the names of +// its contributors may be used to endorse or promote products +// derived from this software without specific prior written +// permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +// CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +// OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteOrder; +import java.util.List; + +import static io.netty.buffer.ByteBufUtil.readBytes; + +/** + * Decodes a web socket frame from wire protocol version 8 format. This code was forked from webbit and modified. 
+ */
+public class WebSocket08FrameDecoder extends ByteToMessageDecoder
+        implements WebSocketFrameDecoder {
+
+    enum State {
+        READING_FIRST,
+        READING_SECOND,
+        READING_SIZE,
+        MASKING_KEY,
+        PAYLOAD,
+        CORRUPT
+    }
+
+    private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameDecoder.class);
+
+    private static final byte OPCODE_CONT = 0x0;
+    private static final byte OPCODE_TEXT = 0x1;
+    private static final byte OPCODE_BINARY = 0x2;
+    private static final byte OPCODE_CLOSE = 0x8;
+    private static final byte OPCODE_PING = 0x9;
+    private static final byte OPCODE_PONG = 0xA;
+
+    private final WebSocketDecoderConfig config;
+
+    private int fragmentedFramesCount;
+    private boolean frameFinalFlag;
+    private boolean frameMasked;
+    private int frameRsv;
+    private int frameOpcode;
+    private long framePayloadLength;
+    private int mask;
+    private int framePayloadLen1;
+    private boolean receivedClosingHandshake;
+    private State state = State.READING_FIRST;
+
+    /**
+     * Constructor
+     *
+     * @param expectMaskedFrames
+     *            Web socket servers must set this to true to process incoming masked payload. Client implementations
+     *            must set this to false.
+     * @param allowExtensions
+     *            Flag to allow reserved extension bits to be used or not
+     * @param maxFramePayloadLength
+     *            Maximum length of a frame's payload. Setting this to an appropriate value for your application
+     *            helps check for denial of service attacks.
+     */
+    public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) {
+        this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false);
+    }
+
+    /**
+     * Constructor
+     *
+     * @param expectMaskedFrames
+     *            Web socket servers must set this to true to process incoming masked payload. Client implementations
+     *            must set this to false.
+ * @param allowExtensions + * Flag to allow reserved extension bits to be used or not + * @param maxFramePayloadLength + * Maximum length of a frame's payload. Setting this to an appropriate value for you application + * helps check for denial of services attacks. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + */ + public WebSocket08FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength, + boolean allowMaskMismatch) { + this(WebSocketDecoderConfig.newBuilder() + .expectMaskedFrames(expectMaskedFrames) + .allowExtensions(allowExtensions) + .maxFramePayloadLength(maxFramePayloadLength) + .allowMaskMismatch(allowMaskMismatch) + .build()); + } + + /** + * Constructor + * + * @param decoderConfig + * Frames decoder configuration. + */ + public WebSocket08FrameDecoder(WebSocketDecoderConfig decoderConfig) { + this.config = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig"); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + // Discard all data received if closing handshake was received before. 
+ if (receivedClosingHandshake) { + in.skipBytes(actualReadableBytes()); + return; + } + + switch (state) { + case READING_FIRST: + if (!in.isReadable()) { + return; + } + + framePayloadLength = 0; + + // FIN, RSV, OPCODE + byte b = in.readByte(); + frameFinalFlag = (b & 0x80) != 0; + frameRsv = (b & 0x70) >> 4; + frameOpcode = b & 0x0F; + + if (logger.isTraceEnabled()) { + logger.trace("Decoding WebSocket Frame opCode={}", frameOpcode); + } + + state = State.READING_SECOND; + case READING_SECOND: + if (!in.isReadable()) { + return; + } + // MASK, PAYLOAD LEN 1 + b = in.readByte(); + frameMasked = (b & 0x80) != 0; + framePayloadLen1 = b & 0x7F; + + if (frameRsv != 0 && !config.allowExtensions()) { + protocolViolation(ctx, in, "RSV != 0 and no extension negotiated, RSV:" + frameRsv); + return; + } + + if (!config.allowMaskMismatch() && config.expectMaskedFrames() != frameMasked) { + protocolViolation(ctx, in, "received a frame that is not masked as expected"); + return; + } + + if (frameOpcode > 7) { // control frame (have MSB in opcode set) + + // control frames MUST NOT be fragmented + if (!frameFinalFlag) { + protocolViolation(ctx, in, "fragmented control frame"); + return; + } + + // control frames MUST have payload 125 octets or less + if (framePayloadLen1 > 125) { + protocolViolation(ctx, in, "control frame with payload length > 125 octets"); + return; + } + + // check for reserved control frame opcodes + if (!(frameOpcode == OPCODE_CLOSE || frameOpcode == OPCODE_PING + || frameOpcode == OPCODE_PONG)) { + protocolViolation(ctx, in, "control frame using reserved opcode " + frameOpcode); + return; + } + + // close frame : if there is a body, the first two bytes of the + // body MUST be a 2-byte unsigned integer (in network byte + // order) representing a getStatus code + if (frameOpcode == 8 && framePayloadLen1 == 1) { + protocolViolation(ctx, in, "received close control frame with payload len 1"); + return; + } + } else { // data frame + // check for reserved 
data frame opcodes + if (!(frameOpcode == OPCODE_CONT || frameOpcode == OPCODE_TEXT + || frameOpcode == OPCODE_BINARY)) { + protocolViolation(ctx, in, "data frame using reserved opcode " + frameOpcode); + return; + } + + // check opcode vs message fragmentation state 1/2 + if (fragmentedFramesCount == 0 && frameOpcode == OPCODE_CONT) { + protocolViolation(ctx, in, "received continuation data frame outside fragmented message"); + return; + } + + // check opcode vs message fragmentation state 2/2 + if (fragmentedFramesCount != 0 && frameOpcode != OPCODE_CONT) { + protocolViolation(ctx, in, + "received non-continuation data frame while inside fragmented message"); + return; + } + } + + state = State.READING_SIZE; + case READING_SIZE: + + // Read frame payload length + if (framePayloadLen1 == 126) { + if (in.readableBytes() < 2) { + return; + } + framePayloadLength = in.readUnsignedShort(); + if (framePayloadLength < 126) { + protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)"); + return; + } + } else if (framePayloadLen1 == 127) { + if (in.readableBytes() < 8) { + return; + } + framePayloadLength = in.readLong(); + if (framePayloadLength < 0) { + protocolViolation(ctx, in, "invalid data frame length (negative length)"); + return; + } + + if (framePayloadLength < 65536) { + protocolViolation(ctx, in, "invalid data frame length (not using minimal length encoding)"); + return; + } + } else { + framePayloadLength = framePayloadLen1; + } + + if (framePayloadLength > config.maxFramePayloadLength()) { + protocolViolation(ctx, in, WebSocketCloseStatus.MESSAGE_TOO_BIG, + "Max frame length of " + config.maxFramePayloadLength() + " has been exceeded."); + return; + } + + if (logger.isTraceEnabled()) { + logger.trace("Decoding WebSocket Frame length={}", framePayloadLength); + } + + state = State.MASKING_KEY; + case MASKING_KEY: + if (frameMasked) { + if (in.readableBytes() < 4) { + return; + } + mask = in.readInt(); + } + state = 
State.PAYLOAD; + case PAYLOAD: + if (in.readableBytes() < framePayloadLength) { + return; + } + + ByteBuf payloadBuffer = Unpooled.EMPTY_BUFFER; + try { + if (framePayloadLength > 0) { + payloadBuffer = readBytes(ctx.alloc(), in, toFrameLength(framePayloadLength)); + } + + // Now we have all the data, the next checkpoint must be the next + // frame + state = State.READING_FIRST; + + // Unmask data if needed + if (frameMasked & framePayloadLength > 0) { + unmask(payloadBuffer); + } + + // Processing ping/pong/close frames because they cannot be + // fragmented + if (frameOpcode == OPCODE_PING) { + out.add(new PingWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + if (frameOpcode == OPCODE_PONG) { + out.add(new PongWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + if (frameOpcode == OPCODE_CLOSE) { + receivedClosingHandshake = true; + checkCloseFrameBody(ctx, payloadBuffer); + out.add(new CloseWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } + + // Processing for possible fragmented messages for text and binary + // frames + if (frameFinalFlag) { + // Final frame of the sequence. 
Apparently ping frames are + // allowed in the middle of a fragmented message + fragmentedFramesCount = 0; + } else { + // Increment counter + fragmentedFramesCount++; + } + + // Return the frame + if (frameOpcode == OPCODE_TEXT) { + out.add(new TextWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } else if (frameOpcode == OPCODE_BINARY) { + out.add(new BinaryWebSocketFrame(frameFinalFlag, frameRsv, payloadBuffer)); + payloadBuffer = null; + return; + } else if (frameOpcode == OPCODE_CONT) { + out.add(new ContinuationWebSocketFrame(frameFinalFlag, frameRsv, + payloadBuffer)); + payloadBuffer = null; + return; + } else { + throw new UnsupportedOperationException("Cannot decode web socket frame with opcode: " + + frameOpcode); + } + } finally { + if (payloadBuffer != null) { + payloadBuffer.release(); + } + } + case CORRUPT: + if (in.isReadable()) { + // If we don't keep reading Netty will throw an exception saying + // we can't return null if no bytes read and state not changed. 
+ in.readByte(); + } + return; + default: + throw new Error("Shouldn't reach here."); + } + } + + private void unmask(ByteBuf frame) { + int i = frame.readerIndex(); + int end = frame.writerIndex(); + + ByteOrder order = frame.order(); + + int intMask = mask; + // Avoid sign extension on widening primitive conversion + long longMask = intMask & 0xFFFFFFFFL; + longMask |= longMask << 32; + + for (int lim = end - 7; i < lim; i += 8) { + frame.setLong(i, frame.getLong(i) ^ longMask); + } + + if (i < end - 3) { + frame.setInt(i, frame.getInt(i) ^ (int) longMask); + i += 4; + } + + if (order == ByteOrder.LITTLE_ENDIAN) { + intMask = Integer.reverseBytes(intMask); + } + + int maskOffset = 0; + for (; i < end; i++) { + frame.setByte(i, frame.getByte(i) ^ WebSocketUtil.byteAtIndex(intMask, maskOffset++ & 3)); + } + } + + private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, String reason) { + protocolViolation(ctx, in, WebSocketCloseStatus.PROTOCOL_ERROR, reason); + } + + private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, WebSocketCloseStatus status, String reason) { + protocolViolation(ctx, in, new CorruptedWebSocketFrameException(status, reason)); + } + + private void protocolViolation(ChannelHandlerContext ctx, ByteBuf in, CorruptedWebSocketFrameException ex) { + state = State.CORRUPT; + int readableBytes = in.readableBytes(); + if (readableBytes > 0) { + // Fix for memory leak, caused by ByteToMessageDecoder#channelRead: + // buffer 'cumulation' is released ONLY when no more readable bytes available. 
+ in.skipBytes(readableBytes); + } + if (ctx.channel().isActive() && config.closeOnProtocolViolation()) { + Object closeMessage; + if (receivedClosingHandshake) { + closeMessage = Unpooled.EMPTY_BUFFER; + } else { + WebSocketCloseStatus closeStatus = ex.closeStatus(); + String reasonText = ex.getMessage(); + if (reasonText == null) { + reasonText = closeStatus.reasonText(); + } + closeMessage = new CloseWebSocketFrame(closeStatus, reasonText); + } + ctx.writeAndFlush(closeMessage).addListener(ChannelFutureListener.CLOSE); + } + throw ex; + } + + private static int toFrameLength(long l) { + if (l > Integer.MAX_VALUE) { + throw new TooLongFrameException("Length:" + l); + } else { + return (int) l; + } + } + + /** */ + protected void checkCloseFrameBody( + ChannelHandlerContext ctx, ByteBuf buffer) { + if (buffer == null || !buffer.isReadable()) { + return; + } + if (buffer.readableBytes() < 2) { + protocolViolation(ctx, buffer, WebSocketCloseStatus.INVALID_PAYLOAD_DATA, "Invalid close frame body"); + } + + // Must have 2 byte integer within the valid range + int statusCode = buffer.getShort(buffer.readerIndex()); + if (!WebSocketCloseStatus.isValidStatusCode(statusCode)) { + protocolViolation(ctx, buffer, "Invalid close frame getStatus code: " + statusCode); + } + + // May have UTF-8 message + if (buffer.readableBytes() > 2) { + try { + new Utf8Validator().check(buffer, buffer.readerIndex() + 2, buffer.readableBytes() - 2); + } catch (CorruptedWebSocketFrameException ex) { + protocolViolation(ctx, buffer, ex); + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java new file mode 100644 index 0000000..1795ae1 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameEncoder.java @@ -0,0 +1,232 @@ +/* + * Copyright 2012 The Netty 
Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// (BSD License: https://www.opensource.org/licenses/bsd-license) +// +// Copyright (c) 2011, Joe Walnes and contributors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or +// without modification, are permitted provided that the +// following conditions are met: +// +// * Redistributions of source code must retain the above +// copyright notice, this list of conditions and the +// following disclaimer. +// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other +// materials provided with the distribution. +// +// * Neither the name of the Webbit nor the names of +// its contributors may be used to endorse or promote products +// derived from this software without specific prior written +// permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +// CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +// OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteOrder; +import java.util.List; + +/** + *

+ * Encodes a web socket frame into wire protocol version 8 format. This code was forked from webbit and modified. + *

+ */ +public class WebSocket08FrameEncoder extends MessageToMessageEncoder implements WebSocketFrameEncoder { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocket08FrameEncoder.class); + + private static final byte OPCODE_CONT = 0x0; + private static final byte OPCODE_TEXT = 0x1; + private static final byte OPCODE_BINARY = 0x2; + private static final byte OPCODE_CLOSE = 0x8; + private static final byte OPCODE_PING = 0x9; + private static final byte OPCODE_PONG = 0xA; + + /** + * The size threshold for gathering writes. Non-Masked messages bigger than this size will be sent fragmented as + * a header and a content ByteBuf whereas messages smaller than the size will be merged into a single buffer and + * sent at once.
+ * Masked messages will always be sent at once. + */ + private static final int GATHERING_WRITE_THRESHOLD = 1024; + + private final boolean maskPayload; + + /** + * Constructor + * + * @param maskPayload + * Web socket clients must set this to true to mask payload. Server implementations must set this to + * false. + */ + public WebSocket08FrameEncoder(boolean maskPayload) { + this.maskPayload = maskPayload; + } + + @Override + protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List out) throws Exception { + final ByteBuf data = msg.content(); + + byte opcode; + if (msg instanceof TextWebSocketFrame) { + opcode = OPCODE_TEXT; + } else if (msg instanceof PingWebSocketFrame) { + opcode = OPCODE_PING; + } else if (msg instanceof PongWebSocketFrame) { + opcode = OPCODE_PONG; + } else if (msg instanceof CloseWebSocketFrame) { + opcode = OPCODE_CLOSE; + } else if (msg instanceof BinaryWebSocketFrame) { + opcode = OPCODE_BINARY; + } else if (msg instanceof ContinuationWebSocketFrame) { + opcode = OPCODE_CONT; + } else { + throw new UnsupportedOperationException("Cannot encode frame of type: " + msg.getClass().getName()); + } + + int length = data.readableBytes(); + + if (logger.isTraceEnabled()) { + logger.trace("Encoding WebSocket Frame opCode={} length={}", opcode, length); + } + + int b0 = 0; + if (msg.isFinalFragment()) { + b0 |= 1 << 7; + } + b0 |= msg.rsv() % 8 << 4; + b0 |= opcode % 128; + + if (opcode == OPCODE_PING && length > 125) { + throw new TooLongFrameException("invalid payload for PING (payload length must be <= 125, was " + length); + } + + boolean release = true; + ByteBuf buf = null; + try { + int maskLength = maskPayload ? 4 : 0; + if (length <= 125) { + int size = 2 + maskLength + length; + buf = ctx.alloc().buffer(size); + buf.writeByte(b0); + byte b = (byte) (maskPayload ? 
0x80 | (byte) length : (byte) length); + buf.writeByte(b); + } else if (length <= 0xFFFF) { + int size = 4 + maskLength; + if (maskPayload || length <= GATHERING_WRITE_THRESHOLD) { + size += length; + } + buf = ctx.alloc().buffer(size); + buf.writeByte(b0); + buf.writeByte(maskPayload ? 0xFE : 126); + buf.writeByte(length >>> 8 & 0xFF); + buf.writeByte(length & 0xFF); + } else { + int size = 10 + maskLength; + if (maskPayload) { + size += length; + } + buf = ctx.alloc().buffer(size); + buf.writeByte(b0); + buf.writeByte(maskPayload ? 0xFF : 127); + buf.writeLong(length); + } + + // Write payload + if (maskPayload) { + int mask = PlatformDependent.threadLocalRandom().nextInt(Integer.MAX_VALUE); + buf.writeInt(mask); + + if (data.isReadable()) { + + ByteOrder srcOrder = data.order(); + ByteOrder dstOrder = buf.order(); + + int i = data.readerIndex(); + int end = data.writerIndex(); + + if (srcOrder == dstOrder) { + // Use the optimized path only when byte orders match. + // Avoid sign extension on widening primitive conversion + long longMask = mask & 0xFFFFFFFFL; + longMask |= longMask << 32; + + // If the byte order of our buffers it little endian we have to bring our mask + // into the same format, because getInt() and writeInt() will use a reversed byte order + if (srcOrder == ByteOrder.LITTLE_ENDIAN) { + longMask = Long.reverseBytes(longMask); + } + + for (int lim = end - 7; i < lim; i += 8) { + buf.writeLong(data.getLong(i) ^ longMask); + } + + if (i < end - 3) { + buf.writeInt(data.getInt(i) ^ (int) longMask); + i += 4; + } + } + int maskOffset = 0; + for (; i < end; i++) { + byte byteData = data.getByte(i); + buf.writeByte(byteData ^ WebSocketUtil.byteAtIndex(mask, maskOffset++ & 3)); + } + } + out.add(buf); + } else { + if (buf.writableBytes() >= data.readableBytes()) { + // merge buffers as this is cheaper then a gathering write if the payload is small enough + buf.writeBytes(data); + out.add(buf); + } else { + out.add(buf); + out.add(data.retain()); + } + 
} + release = false; + } finally { + if (release && buf != null) { + buf.release(); + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameDecoder.java new file mode 100644 index 0000000..bb8363c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameDecoder.java @@ -0,0 +1,111 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// (BSD License: https://www.opensource.org/licenses/bsd-license) +// +// Copyright (c) 2011, Joe Walnes and contributors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or +// without modification, are permitted provided that the +// following conditions are met: +// +// * Redistributions of source code must retain the above +// copyright notice, this list of conditions and the +// following disclaimer. +// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other +// materials provided with the distribution. 
+// +// * Neither the name of the Webbit nor the names of +// its contributors may be used to endorse or promote products +// derived from this software without specific prior written +// permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +// CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +// OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +package io.netty.handler.codec.http.websocketx; + +/** + * Decodes a web socket frame from wire protocol version 13 format. V13 is essentially the same as V8. + */ +public class WebSocket13FrameDecoder extends WebSocket08FrameDecoder { + + /** + * Constructor + * + * @param expectMaskedFrames + * Web socket servers must set this to true processed incoming masked payload. Client implementations + * must set this to false. + * @param allowExtensions + * Flag to allow reserved extension bits to be used or not + * @param maxFramePayloadLength + * Maximum length of a frame's payload. Setting this to an appropriate value for you application + * helps check for denial of services attacks. 
+ */ + public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength) { + this(expectMaskedFrames, allowExtensions, maxFramePayloadLength, false); + } + + /** + * Constructor + * + * @param expectMaskedFrames + * Web socket servers must set this to true processed incoming masked payload. Client implementations + * must set this to false. + * @param allowExtensions + * Flag to allow reserved extension bits to be used or not + * @param maxFramePayloadLength + * Maximum length of a frame's payload. Setting this to an appropriate value for you application + * helps check for denial of services attacks. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + */ + public WebSocket13FrameDecoder(boolean expectMaskedFrames, boolean allowExtensions, int maxFramePayloadLength, + boolean allowMaskMismatch) { + this(WebSocketDecoderConfig.newBuilder() + .expectMaskedFrames(expectMaskedFrames) + .allowExtensions(allowExtensions) + .maxFramePayloadLength(maxFramePayloadLength) + .allowMaskMismatch(allowMaskMismatch) + .build()); + } + + /** + * Constructor + * + * @param decoderConfig + * Frames decoder configuration. 
+ */ + public WebSocket13FrameDecoder(WebSocketDecoderConfig decoderConfig) { + super(decoderConfig); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameEncoder.java new file mode 100644 index 0000000..3586b83 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocket13FrameEncoder.java @@ -0,0 +1,73 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +// (BSD License: https://www.opensource.org/licenses/bsd-license) +// +// Copyright (c) 2011, Joe Walnes and contributors +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or +// without modification, are permitted provided that the +// following conditions are met: +// +// * Redistributions of source code must retain the above +// copyright notice, this list of conditions and the +// following disclaimer. +// +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the +// following disclaimer in the documentation and/or other +// materials provided with the distribution. 
+// +// * Neither the name of the Webbit nor the names of +// its contributors may be used to endorse or promote products +// derived from this software without specific prior written +// permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND +// CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, +// INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF +// MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR +// CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, +// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE +// GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR +// BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF +// LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT +// OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +// POSSIBILITY OF SUCH DAMAGE. + +package io.netty.handler.codec.http.websocketx; + +/** + *

+ * Encodes a web socket frame into wire protocol version 13 format. V13 is essentially the same as V8. + *

 */
public class WebSocket13FrameEncoder extends WebSocket08FrameEncoder {

    /**
     * Constructor
     *
     * @param maskPayload
     *            Web socket clients must set this to true to mask the payload. Server implementations must
     *            set this to false.
     */
    public WebSocket13FrameEncoder(boolean maskPayload) {
        super(maskPayload);
    }
}

package io.netty.handler.codec.http.websocketx;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.stream.ChunkedInput;
import io.netty.util.internal.ObjectUtil;

/**
 * A {@link ChunkedInput} that fetches data chunk by chunk for use with WebSocket chunked transfers.
 *

+ * Each chunk from the input data will be wrapped within a {@link ContinuationWebSocketFrame}. + * At the end of the input data, {@link ContinuationWebSocketFrame} with finalFragment will be written. + *

+ */ +public final class WebSocketChunkedInput implements ChunkedInput { + private final ChunkedInput input; + private final int rsv; + + /** + * Creates a new instance using the specified input. + * @param input {@link ChunkedInput} containing data to write + */ + public WebSocketChunkedInput(ChunkedInput input) { + this(input, 0); + } + + /** + * Creates a new instance using the specified input. + * @param input {@link ChunkedInput} containing data to write + * @param rsv RSV1, RSV2, RSV3 used for extensions + * + * @throws NullPointerException if {@code input} is null + */ + public WebSocketChunkedInput(ChunkedInput input, int rsv) { + this.input = ObjectUtil.checkNotNull(input, "input"); + this.rsv = rsv; + } + + /** + * @return {@code true} if and only if there is no data left in the stream + * and the stream has reached at its end. + */ + @Override + public boolean isEndOfInput() throws Exception { + return input.isEndOfInput(); + } + + /** + * Releases the resources associated with the input. + */ + @Override + public void close() throws Exception { + input.close(); + } + + /** + * @deprecated Use {@link #readChunk(ByteBufAllocator)}. + * + * Fetches a chunked data from the stream. Once this method returns the last chunk + * and thus the stream has reached at its end, any subsequent {@link #isEndOfInput()} + * call must return {@code true}. + * + * @param ctx {@link ChannelHandlerContext} context of channelHandler + * @return {@link WebSocketFrame} contain chunk of data + */ + @Deprecated + @Override + public WebSocketFrame readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + /** + * Fetches a chunked data from the stream. Once this method returns the last chunk + * and thus the stream has reached at its end, any subsequent {@link #isEndOfInput()} + * call must return {@code true}. 
+ * + * @param allocator {@link ByteBufAllocator} + * @return {@link WebSocketFrame} contain chunk of data + */ + @Override + public WebSocketFrame readChunk(ByteBufAllocator allocator) throws Exception { + ByteBuf buf = input.readChunk(allocator); + if (buf == null) { + return null; + } + return new ContinuationWebSocketFrame(input.isEndOfInput(), rsv, buf); + } + + @Override + public long length() { + return input.length(); + } + + @Override + public long progress() { + return input.progress(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakeException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakeException.java new file mode 100644 index 0000000..69f1839 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakeException.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.util.ReferenceCounted; + +/** + * Client exception during handshaking process. + * + *

IMPORTANT: This exception does not contain any {@link ReferenceCounted} fields + * e.g. {@link FullHttpResponse}, so no special treatment is needed. + */ +public final class WebSocketClientHandshakeException extends WebSocketHandshakeException { + + private static final long serialVersionUID = 1L; + + private final HttpResponse response; + + public WebSocketClientHandshakeException(String message) { + this(message, null); + } + + public WebSocketClientHandshakeException(String message, HttpResponse httpResponse) { + super(message); + if (httpResponse != null) { + response = new DefaultHttpResponse(httpResponse.protocolVersion(), + httpResponse.status(), httpResponse.headers()); + } else { + response = null; + } + } + + /** + * Returns a {@link HttpResponse response} if exception occurs during response validation otherwise {@code null}. + */ + public HttpResponse response() { + return response; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java new file mode 100644 index 0000000..a4b647d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker.java @@ -0,0 +1,784 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx;

import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundInvoker;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpClientCodec;
import io.netty.handler.codec.http.HttpContentDecompressor;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpObjectAggregator;
import io.netty.handler.codec.http.HttpRequestEncoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpScheme;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.NetUtil;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ObjectUtil;

import java.net.URI;
import java.nio.channels.ClosedChannelException;
import java.util.Locale;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
 * Base class for web socket client handshake implementations
 */
public abstract class WebSocketClientHandshaker {

    private static final String HTTP_SCHEME_PREFIX = HttpScheme.HTTP + "://";
    private static final String HTTPS_SCHEME_PREFIX = HttpScheme.HTTPS + "://";
    protected static final int DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS = 10000;

    private final URI uri;

    private final WebSocketVersion version;

    private volatile boolean handshakeComplete;

    private volatile long forceCloseTimeoutMillis = DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS;

    // 0 = close not yet initiated, 1 = initiated; updated atomically via the field updater
    // so the forced-close timer is scheduled at most once even under concurrent close calls.
    private volatile int forceCloseInit;

    private static final AtomicIntegerFieldUpdater<WebSocketClientHandshaker> FORCE_CLOSE_INIT_UPDATER =
            AtomicIntegerFieldUpdater.newUpdater(WebSocketClientHandshaker.class, "forceCloseInit");

    private volatile boolean forceCloseComplete;

    private final String expectedSubprotocol;

    private volatile String actualSubprotocol;

    protected final HttpHeaders customHeaders;

    private final int maxFramePayloadLength;

    private final boolean absoluteUpgradeUrl;

    protected final boolean generateOriginHeader;

    /**
     * Base constructor
     *
     * @param uri
     *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
     *            sent to this URL.
     * @param version
     *            Version of web socket specification to use to connect to the server
     * @param subprotocol
     *            Sub protocol request sent to the server.
     * @param customHeaders
     *            Map of custom headers to add to the client request
     * @param maxFramePayloadLength
     *            Maximum length of a frame's payload
     */
    protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
                                        HttpHeaders customHeaders, int maxFramePayloadLength) {
        this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor
     *
     * @param uri
     *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
     *            sent to this URL.
     * @param version
     *            Version of web socket specification to use to connect to the server
     * @param subprotocol
     *            Sub protocol request sent to the server.
     * @param customHeaders
     *            Map of custom headers to add to the client request
     * @param maxFramePayloadLength
     *            Maximum length of a frame's payload
     * @param forceCloseTimeoutMillis
     *            Close the connection if it was not closed by the server after timeout specified
     */
    protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
                                        HttpHeaders customHeaders, int maxFramePayloadLength,
                                        long forceCloseTimeoutMillis) {
        this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false);
    }

    /**
     * Base constructor
     *
     * @param uri
     *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
     *            sent to this URL.
     * @param version
     *            Version of web socket specification to use to connect to the server
     * @param subprotocol
     *            Sub protocol request sent to the server.
     * @param customHeaders
     *            Map of custom headers to add to the client request
     * @param maxFramePayloadLength
     *            Maximum length of a frame's payload
     * @param forceCloseTimeoutMillis
     *            Close the connection if it was not closed by the server after timeout specified
     * @param absoluteUpgradeUrl
     *            Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
     *            clear HTTP
     */
    protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
                                        HttpHeaders customHeaders, int maxFramePayloadLength,
                                        long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) {
        this(uri, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis,
                absoluteUpgradeUrl, true);
    }

    /**
     * Base constructor
     *
     * @param uri
     *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
     *            sent to this URL.
     * @param version
     *            Version of web socket specification to use to connect to the server
     * @param subprotocol
     *            Sub protocol request sent to the server.
     * @param customHeaders
     *            Map of custom headers to add to the client request
     * @param maxFramePayloadLength
     *            Maximum length of a frame's payload
     * @param forceCloseTimeoutMillis
     *            Close the connection if it was not closed by the server after timeout specified
     * @param absoluteUpgradeUrl
     *            Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over
     *            clear HTTP
     * @param generateOriginHeader
     *            Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request
     *            according to the given webSocketURL
     */
    protected WebSocketClientHandshaker(URI uri, WebSocketVersion version, String subprotocol,
                                        HttpHeaders customHeaders, int maxFramePayloadLength,
                                        long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl,
                                        boolean generateOriginHeader) {
        this.uri = uri;
        this.version = version;
        expectedSubprotocol = subprotocol;
        this.customHeaders = customHeaders;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
        this.absoluteUpgradeUrl = absoluteUpgradeUrl;
        this.generateOriginHeader = generateOriginHeader;
    }

    /**
     * Returns the URI to the web socket. e.g. "ws://myhost.com/path"
     */
    public URI uri() {
        return uri;
    }

    /**
     * Version of the web socket specification that is being used
     */
    public WebSocketVersion version() {
        return version;
    }

    /**
     * Returns the max length for any frame's payload
     */
    public int maxFramePayloadLength() {
        return maxFramePayloadLength;
    }

    /**
     * Flag to indicate if the opening handshake is complete
     */
    public boolean isHandshakeComplete() {
        return handshakeComplete;
    }

    private void setHandshakeComplete() {
        handshakeComplete = true;
    }

    /**
     * Returns the CSV of requested subprotocol(s) sent to the server as specified in the constructor
     */
    public String expectedSubprotocol() {
        return expectedSubprotocol;
    }

    /**
     * Returns the subprotocol response sent by the server. Only available after end of handshake.
     * Null if no subprotocol was requested or confirmed by the server.
     */
    public String actualSubprotocol() {
        return actualSubprotocol;
    }

    private void setActualSubprotocol(String actualSubprotocol) {
        this.actualSubprotocol = actualSubprotocol;
    }

    public long forceCloseTimeoutMillis() {
        return forceCloseTimeoutMillis;
    }

    /**
     * Flag to indicate if the closing handshake was initiated because of timeout.
     * For testing only.
     */
    protected boolean isForceCloseComplete() {
        return forceCloseComplete;
    }

    /**
     * Sets timeout to close the connection if it was not closed by the server.
     *
     * @param forceCloseTimeoutMillis
     *            Close the connection if it was not closed by the server after timeout specified
     */
    public WebSocketClientHandshaker setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
        this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
        return this;
    }

    /**
     * Begins the opening handshake
     *
     * @param channel
     *            Channel
     */
    public ChannelFuture handshake(Channel channel) {
        ObjectUtil.checkNotNull(channel, "channel");
        return handshake(channel, channel.newPromise());
    }

    /**
     * Begins the opening handshake
     *
     * @param channel
     *            Channel
     * @param promise
     *            the {@link ChannelPromise} to be notified when the opening handshake is sent
     */
    public final ChannelFuture handshake(Channel channel, final ChannelPromise promise) {
        ChannelPipeline pipeline = channel.pipeline();
        HttpResponseDecoder decoder = pipeline.get(HttpResponseDecoder.class);
        if (decoder == null) {
            HttpClientCodec codec = pipeline.get(HttpClientCodec.class);
            if (codec == null) {
                promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
                        "an HttpResponseDecoder or HttpClientCodec"));
                return promise;
            }
        }

        if (uri.getHost() == null) {
            // Without a host in the URI the Host (and possibly Origin) header cannot be
            // generated; the caller must have supplied them via customHeaders.
            if (customHeaders == null || !customHeaders.contains(HttpHeaderNames.HOST)) {
                promise.setFailure(new IllegalArgumentException("Cannot generate the 'host' header value," +
                        " webSocketURI should contain host or passed through customHeaders"));
                return promise;
            }

            if (generateOriginHeader && !customHeaders.contains(HttpHeaderNames.ORIGIN)) {
                final String originName;
                if (version == WebSocketVersion.V07 || version == WebSocketVersion.V08) {
                    originName = HttpHeaderNames.SEC_WEBSOCKET_ORIGIN.toString();
                } else {
                    originName = HttpHeaderNames.ORIGIN.toString();
                }

                promise.setFailure(new IllegalArgumentException("Cannot generate the '" + originName + "' header" +
                        " value, webSocketURI should contain host or disable generateOriginHeader or pass value" +
                        " through customHeaders"));
                return promise;
            }
        }

        FullHttpRequest request = newHandshakeRequest();

        channel.writeAndFlush(request).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) {
                if (future.isSuccess()) {
                    ChannelPipeline p = future.channel().pipeline();
                    ChannelHandlerContext ctx = p.context(HttpRequestEncoder.class);
                    if (ctx == null) {
                        ctx = p.context(HttpClientCodec.class);
                    }
                    if (ctx == null) {
                        promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
                                "an HttpRequestEncoder or HttpClientCodec"));
                        return;
                    }
                    p.addAfter(ctx.name(), "ws-encoder", newWebSocketEncoder());

                    promise.setSuccess();
                } else {
                    promise.setFailure(future.cause());
                }
            }
        });
        return promise;
    }

    /**
     * @return a new FullHttpRequest which will be used for the handshake.
     */
    protected abstract FullHttpRequest newHandshakeRequest();

    /**
     * Validates and finishes the opening handshake initiated by {@link #handshake}}.
     *
     * @param channel
     *            Channel
     * @param response
     *            HTTP response containing the closing handshake details
     */
    public final void finishHandshake(Channel channel, FullHttpResponse response) {
        verify(response);

        // Verify the subprotocol that we received from the server.
        // This must be one of our expected subprotocols - or null/empty if we didn't want to speak a subprotocol
        String receivedProtocol = response.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        receivedProtocol = receivedProtocol != null ? receivedProtocol.trim() : null;
        String expectedProtocol = expectedSubprotocol != null ? expectedSubprotocol : "";
        boolean protocolValid = false;

        if (expectedProtocol.isEmpty() && receivedProtocol == null) {
            // No subprotocol required and none received
            protocolValid = true;
            setActualSubprotocol(expectedSubprotocol); // null or "" - we echo what the user requested
        } else if (!expectedProtocol.isEmpty() && receivedProtocol != null && !receivedProtocol.isEmpty()) {
            // We require a subprotocol and received one -> verify it
            for (String protocol : expectedProtocol.split(",")) {
                if (protocol.trim().equals(receivedProtocol)) {
                    protocolValid = true;
                    setActualSubprotocol(receivedProtocol);
                    break;
                }
            }
        } // else mixed cases - which are all errors

        if (!protocolValid) {
            throw new WebSocketClientHandshakeException(String.format(
                    "Invalid subprotocol. Actual: %s. Expected one of: %s",
                    receivedProtocol, expectedSubprotocol), response);
        }

        setHandshakeComplete();

        final ChannelPipeline p = channel.pipeline();
        // Remove decompressor from pipeline if its in use
        HttpContentDecompressor decompressor = p.get(HttpContentDecompressor.class);
        if (decompressor != null) {
            p.remove(decompressor);
        }

        // Remove aggregator if present before
        HttpObjectAggregator aggregator = p.get(HttpObjectAggregator.class);
        if (aggregator != null) {
            p.remove(aggregator);
        }

        ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
        if (ctx == null) {
            ctx = p.context(HttpClientCodec.class);
            if (ctx == null) {
                throw new IllegalStateException("ChannelPipeline does not contain " +
                        "an HttpRequestEncoder or HttpClientCodec");
            }
            final HttpClientCodec codec = (HttpClientCodec) ctx.handler();
            // Remove the encoder part of the codec as the user may start writing frames after this method returns.
            codec.removeOutboundHandler();

            p.addAfter(ctx.name(), "ws-decoder", newWebsocketDecoder());

            // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
            // WebSocketFrame messages.
            // See https://github.com/netty/netty/issues/4533
            channel.eventLoop().execute(new Runnable() {
                @Override
                public void run() {
                    p.remove(codec);
                }
            });
        } else {
            if (p.get(HttpRequestEncoder.class) != null) {
                // Remove the encoder part of the codec as the user may start writing frames after this method returns.
                p.remove(HttpRequestEncoder.class);
            }
            final ChannelHandlerContext context = ctx;
            p.addAfter(context.name(), "ws-decoder", newWebsocketDecoder());

            // Delay the removal of the decoder so the user can setup the pipeline if needed to handle
            // WebSocketFrame messages.
            // See https://github.com/netty/netty/issues/4533
            channel.eventLoop().execute(new Runnable() {
                @Override
                public void run() {
                    p.remove(context.handler());
                }
            });
        }
    }

    /**
     * Process the opening handshake initiated by {@link #handshake}}.
     *
     * @param channel
     *            Channel
     * @param response
     *            HTTP response containing the closing handshake details
     * @return future
     *            the {@link ChannelFuture} which is notified once the handshake completes.
     */
    public final ChannelFuture processHandshake(final Channel channel, HttpResponse response) {
        return processHandshake(channel, response, channel.newPromise());
    }

    /**
     * Process the opening handshake initiated by {@link #handshake}}.
     *
     * @param channel
     *            Channel
     * @param response
     *            HTTP response containing the closing handshake details
     * @param promise
     *            the {@link ChannelPromise} to notify once the handshake completes.
     * @return future
     *            the {@link ChannelFuture} which is notified once the handshake completes.
     */
    public final ChannelFuture processHandshake(final Channel channel, HttpResponse response,
                                                final ChannelPromise promise) {
        if (response instanceof FullHttpResponse) {
            try {
                finishHandshake(channel, (FullHttpResponse) response);
                promise.setSuccess();
            } catch (Throwable cause) {
                promise.setFailure(cause);
            }
        } else {
            ChannelPipeline p = channel.pipeline();
            ChannelHandlerContext ctx = p.context(HttpResponseDecoder.class);
            if (ctx == null) {
                ctx = p.context(HttpClientCodec.class);
                if (ctx == null) {
                    return promise.setFailure(new IllegalStateException("ChannelPipeline does not contain " +
                            "an HttpResponseDecoder or HttpClientCodec"));
                }
            }

            String aggregatorCtx = ctx.name();
            // Content-Length and Transfer-Encoding must not be sent in any response with a status code of 1xx or 204.
            if (version == WebSocketVersion.V00) {
                // Add aggregator and ensure we feed the HttpResponse so it is aggregated. A limit of 8192 should be
                // more then enough for the websockets handshake payload.
                aggregatorCtx = "httpAggregator";
                p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192));
            }

            p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() {

                // Accumulates the response head until LastHttpContent arrives; released on
                // channelInactive/handlerRemoved so it never leaks.
                private FullHttpResponse fullHttpResponse;

                @Override
                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                    if (msg instanceof HttpObject) {
                        try {
                            handleHandshakeResponse(ctx, (HttpObject) msg);
                        } finally {
                            ReferenceCountUtil.release(msg);
                        }
                    } else {
                        super.channelRead(ctx, msg);
                    }
                }

                @Override
                public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                    // Remove ourself and fail the handshake promise.
                    ctx.pipeline().remove(this);
                    promise.setFailure(cause);
                }

                @Override
                public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                    try {
                        // Fail promise if Channel was closed
                        if (!promise.isDone()) {
                            promise.tryFailure(new ClosedChannelException());
                        }
                        ctx.fireChannelInactive();
                    } finally {
                        releaseFullHttpResponse();
                    }
                }

                @Override
                public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
                    releaseFullHttpResponse();
                }

                private void handleHandshakeResponse(ChannelHandlerContext ctx, HttpObject response) {
                    if (response instanceof FullHttpResponse) {
                        ctx.pipeline().remove(this);
                        tryFinishHandshake((FullHttpResponse) response);
                        return;
                    }

                    if (response instanceof LastHttpContent) {
                        assert fullHttpResponse != null;
                        FullHttpResponse handshakeResponse = fullHttpResponse;
                        fullHttpResponse = null;
                        try {
                            ctx.pipeline().remove(this);
                            tryFinishHandshake(handshakeResponse);
                        } finally {
                            handshakeResponse.release();
                        }
                        return;
                    }

                    if (response instanceof HttpResponse) {
                        HttpResponse httpResponse = (HttpResponse) response;
                        fullHttpResponse = new DefaultFullHttpResponse(httpResponse.protocolVersion(),
                                httpResponse.status(), Unpooled.EMPTY_BUFFER, httpResponse.headers(),
                                EmptyHttpHeaders.INSTANCE);
                        if (httpResponse.decoderResult().isFailure()) {
                            fullHttpResponse.setDecoderResult(httpResponse.decoderResult());
                        }
                    }
                }

                private void tryFinishHandshake(FullHttpResponse fullHttpResponse) {
                    try {
                        finishHandshake(channel, fullHttpResponse);
                        promise.setSuccess();
                    } catch (Throwable cause) {
                        promise.setFailure(cause);
                    }
                }

                private void releaseFullHttpResponse() {
                    if (fullHttpResponse != null) {
                        fullHttpResponse.release();
                        fullHttpResponse = null;
                    }
                }
            });
            try {
                ctx.fireChannelRead(ReferenceCountUtil.retain(response));
            } catch (Throwable cause) {
                promise.setFailure(cause);
            }
        }
        return promise;
    }

    /**
     * Verify the {@link FullHttpResponse} and throws a {@link WebSocketHandshakeException} if something is wrong.
     */
    protected abstract void verify(FullHttpResponse response);

    /**
     * Returns the decoder to use after handshake is complete.
     */
    protected abstract WebSocketFrameDecoder newWebsocketDecoder();

    /**
     * Returns the encoder to use after the handshake is complete.
     */
    protected abstract WebSocketFrameEncoder newWebSocketEncoder();

    /**
     * Performs the closing handshake.
     *
     * When called from within a {@link ChannelHandler} you most likely want to use
     * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}.
     *
     * @param channel
     *            Channel
     * @param frame
     *            Closing Frame that was received
     */
    public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) {
        ObjectUtil.checkNotNull(channel, "channel");
        return close(channel, frame, channel.newPromise());
    }

    /**
     * Performs the closing handshake
     *
     * When called from within a {@link ChannelHandler} you most likely want to use
     * {@link #close(ChannelHandlerContext, CloseWebSocketFrame, ChannelPromise)}.
     *
     * @param channel
     *            Channel
     * @param frame
     *            Closing Frame that was received
     * @param promise
     *            the {@link ChannelPromise} to be notified when the closing handshake is done
     */
    public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
        ObjectUtil.checkNotNull(channel, "channel");
        return close0(channel, channel, frame, promise);
    }

    /**
     * Performs the closing handshake
     *
     * @param ctx
     *            the {@link ChannelHandlerContext} to use.
     * @param frame
     *            Closing Frame that was received
     */
    public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) {
        ObjectUtil.checkNotNull(ctx, "ctx");
        return close(ctx, frame, ctx.newPromise());
    }

    /**
     * Performs the closing handshake
     *
     * @param ctx
     *            the {@link ChannelHandlerContext} to use.
     * @param frame
     *            Closing Frame that was received
     * @param promise
     *            the {@link ChannelPromise} to be notified when the closing handshake is done
     */
    public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) {
        ObjectUtil.checkNotNull(ctx, "ctx");
        return close0(ctx, ctx.channel(), frame, promise);
    }

    private ChannelFuture close0(final ChannelOutboundInvoker invoker, final Channel channel,
                                 CloseWebSocketFrame frame, ChannelPromise promise) {
        invoker.writeAndFlush(frame, promise);
        final long forceCloseTimeoutMillis = this.forceCloseTimeoutMillis;
        final WebSocketClientHandshaker handshaker = this;
        if (forceCloseTimeoutMillis <= 0 || !channel.isActive() || forceCloseInit != 0) {
            return promise;
        }

        promise.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) {
                // If flush operation failed, there is no reason to expect
                // a server to receive CloseFrame. Thus this should be handled
                // by the application separately.
                // Also, close might be called twice from different threads.
                if (future.isSuccess() && channel.isActive() &&
                        FORCE_CLOSE_INIT_UPDATER.compareAndSet(handshaker, 0, 1)) {
                    final Future<?> forceCloseFuture = channel.eventLoop().schedule(new Runnable() {
                        @Override
                        public void run() {
                            if (channel.isActive()) {
                                invoker.close();
                                forceCloseComplete = true;
                            }
                        }
                    }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS);

                    channel.closeFuture().addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture future) throws Exception {
                            forceCloseFuture.cancel(false);
                        }
                    });
                }
            }
        });
        return promise;
    }

    /**
     * Return the constructed raw path for the give {@link URI}.
     */
    protected String upgradeUrl(URI wsURL) {
        if (absoluteUpgradeUrl) {
            return wsURL.toString();
        }

        String path = wsURL.getRawPath();
        path = path == null || path.isEmpty() ? "/" : path;
        String query = wsURL.getRawQuery();
        return query != null && !query.isEmpty() ? path + '?' + query : path;
    }

    static CharSequence websocketHostValue(URI wsURL) {
        int port = wsURL.getPort();
        if (port == -1) {
            return wsURL.getHost();
        }
        String host = wsURL.getHost();
        String scheme = wsURL.getScheme();
        if (port == HttpScheme.HTTP.port()) {
            return HttpScheme.HTTP.name().contentEquals(scheme)
                    || WebSocketScheme.WS.name().contentEquals(scheme) ?
                    host : NetUtil.toSocketAddressString(host, port);
        }
        if (port == HttpScheme.HTTPS.port()) {
            return HttpScheme.HTTPS.name().contentEquals(scheme)
                    || WebSocketScheme.WSS.name().contentEquals(scheme) ?
                    host : NetUtil.toSocketAddressString(host, port);
        }

        // if the port is not standard (80/443) its needed to add the port to the header.
        // See https://tools.ietf.org/html/rfc6454#section-6.2
        return NetUtil.toSocketAddressString(host, port);
    }

    static CharSequence websocketOriginValue(URI wsURL) {
        String scheme = wsURL.getScheme();
        final String schemePrefix;
        int port = wsURL.getPort();
        final int defaultPort;
        if (WebSocketScheme.WSS.name().contentEquals(scheme)
                || HttpScheme.HTTPS.name().contentEquals(scheme)
                || (scheme == null && port == WebSocketScheme.WSS.port())) {

            schemePrefix = HTTPS_SCHEME_PREFIX;
            defaultPort = WebSocketScheme.WSS.port();
        } else {
            schemePrefix = HTTP_SCHEME_PREFIX;
            defaultPort = WebSocketScheme.WS.port();
        }

        // Convert uri-host to lower case (by RFC 6454, chapter 4 "Origin of a URI")
        String host = wsURL.getHost().toLowerCase(Locale.US);

        if (port != defaultPort && port != -1) {
            // if the port is not standard (80/443) its needed to add the port to the header.
            // See https://tools.ietf.org/html/rfc6454#section-6.2
            return schemePrefix + NetUtil.toSocketAddressString(host, port);
        }
        return schemePrefix + host;
    }
}

+ * Performs client side opening and closing handshakes for web socket specification version draft-ietf-hybi-thewebsocketprotocol- + * 00 + *

+ *

+ * A very large portion of this code was taken from the Netty 3.2 HTTP example. + *

+ */ +public class WebSocketClientHandshaker00 extends WebSocketClientHandshaker { + + private ByteBuf expectedChallengeResponseBytes; + + /** + * Creates a new instance with the specified destination WebSocket location and version to initiate. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + */ + public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength) { + this(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, + DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS); + } + + /** + * Creates a new instance with the specified destination WebSocket location and version to initiate. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. 
+ * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + */ + public WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength, + long forceCloseTimeoutMillis) { + this(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, false); + } + + /** + * Creates a new instance with the specified destination WebSocket location and version to initiate. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + */ + WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength, + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { + this(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, + absoluteUpgradeUrl, true); + } + + /** + * Creates a new instance with the specified destination WebSocket location and version to initiate. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". 
Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate the `Origin` header value for handshake request + * according to the given webSocketURL + */ + WebSocketClientHandshaker00(URI webSocketURL, WebSocketVersion version, String subprotocol, + HttpHeaders customHeaders, int maxFramePayloadLength, + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, + boolean generateOriginHeader) { + super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + } + + /** + *

+ * Sends the opening request to the server: + *

+ * + *
+     * GET /demo HTTP/1.1
+     * Upgrade: WebSocket
+     * Connection: Upgrade
+     * Host: example.com
+     * Origin: http://example.com
+     * Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5
+     * Sec-WebSocket-Key2: 12998 5 Y3 1  .P00
+     *
+     * ^n:ds[4U
+     * 
+ * + */ + @Override + protected FullHttpRequest newHandshakeRequest() { + // Make keys + int spaces1 = WebSocketUtil.randomNumber(1, 12); + int spaces2 = WebSocketUtil.randomNumber(1, 12); + + int max1 = Integer.MAX_VALUE / spaces1; + int max2 = Integer.MAX_VALUE / spaces2; + + int number1 = WebSocketUtil.randomNumber(0, max1); + int number2 = WebSocketUtil.randomNumber(0, max2); + + int product1 = number1 * spaces1; + int product2 = number2 * spaces2; + + String key1 = Integer.toString(product1); + String key2 = Integer.toString(product2); + + key1 = insertRandomCharacters(key1); + key2 = insertRandomCharacters(key2); + + key1 = insertSpaces(key1, spaces1); + key2 = insertSpaces(key2, spaces2); + + byte[] key3 = WebSocketUtil.randomBytes(8); + + ByteBuffer buffer = ByteBuffer.allocate(4); + buffer.putInt(number1); + byte[] number1Array = buffer.array(); + buffer = ByteBuffer.allocate(4); + buffer.putInt(number2); + byte[] number2Array = buffer.array(); + + byte[] challenge = new byte[16]; + System.arraycopy(number1Array, 0, challenge, 0, 4); + System.arraycopy(number2Array, 0, challenge, 4, 4); + System.arraycopy(key3, 0, challenge, 8, 8); + expectedChallengeResponseBytes = Unpooled.wrappedBuffer(WebSocketUtil.md5(challenge)); + + URI wsURL = uri(); + + // Format request + FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL), + Unpooled.wrappedBuffer(key3)); + HttpHeaders headers = request.headers(); + + if (customHeaders != null) { + headers.add(customHeaders); + if (!headers.contains(HttpHeaderNames.HOST)) { + // Only add HOST header if customHeaders did not contain it. 
+ // + // See https://github.com/netty/netty/issues/10101 + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + } else { + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, key1) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, key2); + + if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) { + headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)); + } + + String expectedSubprotocol = expectedSubprotocol(); + if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + } + + // Set Content-Length to workaround some known defect. + // See also: https://www.ietf.org/mail-archive/web/hybi/current/msg02149.html + headers.set(HttpHeaderNames.CONTENT_LENGTH, key3.length); + return request; + } + + /** + *

+ * Process server response: + *

+ * + *
+     * HTTP/1.1 101 WebSocket Protocol Handshake
+     * Upgrade: WebSocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Origin: http://example.com
+     * Sec-WebSocket-Location: ws://example.com/demo
+     * Sec-WebSocket-Protocol: sample
+     *
+     * 8jKS'y:G*Co,Wxa-
+     * 
+ * + * @param response + * HTTP response returned from the server for the request sent by beginOpeningHandshake00(). + * @throws WebSocketHandshakeException + */ + @Override + protected void verify(FullHttpResponse response) { + HttpResponseStatus status = response.status(); + if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) { + throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response); + } + + HttpHeaders headers = response.headers(); + CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); + if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { + throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response); + } + + if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { + throw new WebSocketClientHandshakeException("Invalid handshake response connection: " + + headers.get(HttpHeaderNames.CONNECTION), response); + } + + ByteBuf challenge = response.content(); + if (!challenge.equals(expectedChallengeResponseBytes)) { + throw new WebSocketClientHandshakeException("Invalid challenge", response); + } + } + + private static String insertRandomCharacters(String key) { + int count = WebSocketUtil.randomNumber(1, 12); + + char[] randomChars = new char[count]; + int randCount = 0; + while (randCount < count) { + int rand = PlatformDependent.threadLocalRandom().nextInt(0x7e) + 0x21; + if (0x21 < rand && rand < 0x2f || 0x3a < rand && rand < 0x7e) { + randomChars[randCount] = (char) rand; + randCount += 1; + } + } + + for (int i = 0; i < count; i++) { + int split = WebSocketUtil.randomNumber(0, key.length()); + String part1 = key.substring(0, split); + String part2 = key.substring(split); + key = part1 + randomChars[i] + part2; + } + + return key; + } + + private static String insertSpaces(String key, int spaces) { + for (int i = 0; i < spaces; i++) { + int split = WebSocketUtil.randomNumber(1, key.length() - 1); + String 
part1 = key.substring(0, split); + String part2 = key.substring(split); + key = part1 + ' ' + part2; + } + + return key; + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return new WebSocket00FrameDecoder(maxFramePayloadLength()); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return new WebSocket00FrameEncoder(); + } + + @Override + public WebSocketClientHandshaker00 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) { + super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis); + return this; + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java new file mode 100644 index 0000000..b0d1ace --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07.java @@ -0,0 +1,346 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.URI; + +/** + *

+ * Performs client side opening and closing handshakes for web socket specification version draft-ietf-hybi-thewebsocketprotocol- + * 07 + *

+ */ +public class WebSocketClientHandshaker07 extends WebSocketClientHandshaker { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker07.class); + public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + private String expectedChallengeResponseString; + + private final boolean allowExtensions; + private final boolean performMasking; + private final boolean allowMaskMismatch; + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + */ + public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, false); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. 
+ * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + */ + public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. 
Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + */ + public WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, false); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. 
+ * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + */ + WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, + boolean absoluteUpgradeUrl) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl, true); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. 
+ * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate a `Sec-WebSocket-Origin` header value for handshake request + * according to the given webSocketURL + */ + WebSocketClientHandshaker07(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + this.allowExtensions = allowExtensions; + this.performMasking = performMasking; + this.allowMaskMismatch = allowMaskMismatch; + } + + /** + * /** + *

+ * Sends the opening request to the server: + *

+ * + *
+     * GET /chat HTTP/1.1
+     * Host: server.example.com
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+     * Sec-WebSocket-Origin: http://example.com
+     * Sec-WebSocket-Protocol: chat, superchat
+     * Sec-WebSocket-Version: 7
+     * 
+ * + */ + @Override + protected FullHttpRequest newHandshakeRequest() { + URI wsURL = uri(); + + // Get 16 bit nonce and base 64 encode it + byte[] nonce = WebSocketUtil.randomBytes(16); + String key = WebSocketUtil.base64(nonce); + + String acceptSeed = key + MAGIC_GUID; + byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); + expectedChallengeResponseString = WebSocketUtil.base64(sha1); + + if (logger.isDebugEnabled()) { + logger.debug( + "WebSocket version 07 client handshake key: {}, expected response: {}", + key, expectedChallengeResponseString); + } + + // Format request + FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL), + Unpooled.EMPTY_BUFFER); + HttpHeaders headers = request.headers(); + + if (customHeaders != null) { + headers.add(customHeaders); + if (!headers.contains(HttpHeaderNames.HOST)) { + // Only add HOST header if customHeaders did not contain it. + // + // See https://github.com/netty/netty/issues/10101 + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + } else { + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key); + + if (generateOriginHeader && !headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + } + + String expectedSubprotocol = expectedSubprotocol(); + if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + } + + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version().toAsciiString()); + return request; + } + + /** + *

+ * Process server response: + *

+ * + *
+     * HTTP/1.1 101 Switching Protocols
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+     * Sec-WebSocket-Protocol: chat
+     * 
+ * + * @param response + * HTTP response returned from the server for the request sent by beginOpeningHandshake00(). + * @throws WebSocketHandshakeException + */ + @Override + protected void verify(FullHttpResponse response) { + HttpResponseStatus status = response.status(); + if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) { + throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response); + } + + HttpHeaders headers = response.headers(); + CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); + if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { + throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response); + } + + if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { + throw new WebSocketClientHandshakeException("Invalid handshake response connection: " + + headers.get(HttpHeaderNames.CONNECTION), response); + } + + CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); + if (accept == null || !accept.equals(expectedChallengeResponseString)) { + throw new WebSocketClientHandshakeException(String.format( + "Invalid challenge. Actual: %s. 
Expected: %s", accept, expectedChallengeResponseString), response); + } + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return new WebSocket07FrameDecoder(false, allowExtensions, maxFramePayloadLength(), allowMaskMismatch); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return new WebSocket07FrameEncoder(performMasking); + } + + @Override + public WebSocketClientHandshaker07 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) { + super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis); + return this; + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java new file mode 100644 index 0000000..8f2f7b7 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08.java @@ -0,0 +1,348 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.URI; + +/** + *

+ * Performs client side opening and closing handshakes for web socket specification version draft-ietf-hybi-thewebsocketprotocol- + * 10 + *

+ */ +public class WebSocketClientHandshaker08 extends WebSocketClientHandshaker { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker08.class); + + public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + private String expectedChallengeResponseString; + + private final boolean allowExtensions; + private final boolean performMasking; + private final boolean allowMaskMismatch; + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + */ + public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, true, + false, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. 
+ * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + */ + public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. 
Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + */ + public WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, false, true); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. 
+ * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + */ + WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, + boolean absoluteUpgradeUrl) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl, true); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. 
+ * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate a `Sec-WebSocket-Origin` header value for handshake request + * according to the given webSocketURL + */ + WebSocketClientHandshaker08(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + this.allowExtensions = allowExtensions; + this.performMasking = performMasking; + this.allowMaskMismatch = allowMaskMismatch; + } + + /** + * /** + *

+ * Sends the opening request to the server: + *

+ * + *
+     * GET /chat HTTP/1.1
+     * Host: server.example.com
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+     * Sec-WebSocket-Origin: http://example.com
+     * Sec-WebSocket-Protocol: chat, superchat
+     * Sec-WebSocket-Version: 8
+     * 
+ * + */ + @Override + protected FullHttpRequest newHandshakeRequest() { + URI wsURL = uri(); + + // Get 16 bit nonce and base 64 encode it + byte[] nonce = WebSocketUtil.randomBytes(16); + String key = WebSocketUtil.base64(nonce); + + String acceptSeed = key + MAGIC_GUID; + byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); + expectedChallengeResponseString = WebSocketUtil.base64(sha1); + + if (logger.isDebugEnabled()) { + logger.debug( + "WebSocket version 08 client handshake key: {}, expected response: {}", + key, expectedChallengeResponseString); + } + + // Format request + FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL), + Unpooled.EMPTY_BUFFER); + HttpHeaders headers = request.headers(); + + if (customHeaders != null) { + headers.add(customHeaders); + if (!headers.contains(HttpHeaderNames.HOST)) { + // Only add HOST header if customHeaders did not contain it. + // + // See https://github.com/netty/netty/issues/10101 + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + } else { + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key); + + if (generateOriginHeader && !headers.contains(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, websocketOriginValue(wsURL)); + } + + String expectedSubprotocol = expectedSubprotocol(); + if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + } + + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version().toAsciiString()); + return request; + } + + /** + *

+ * Process server response: + *

+ * + *
+     * HTTP/1.1 101 Switching Protocols
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+     * Sec-WebSocket-Protocol: chat
+     * 
+ * + * @param response + * HTTP response returned from the server for the request sent by beginOpeningHandshake00(). + * @throws WebSocketHandshakeException + */ + @Override + protected void verify(FullHttpResponse response) { + HttpResponseStatus status = response.status(); + if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) { + throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response); + } + + HttpHeaders headers = response.headers(); + CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); + if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { + throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response); + } + + if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { + throw new WebSocketClientHandshakeException("Invalid handshake response connection: " + + headers.get(HttpHeaderNames.CONNECTION), response); + } + + CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); + if (accept == null || !accept.equals(expectedChallengeResponseString)) { + throw new WebSocketClientHandshakeException(String.format( + "Invalid challenge. Actual: %s. 
Expected: %s", accept, expectedChallengeResponseString), response); + } + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return new WebSocket08FrameDecoder(false, allowExtensions, maxFramePayloadLength(), allowMaskMismatch); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return new WebSocket08FrameEncoder(performMasking); + } + + @Override + public WebSocketClientHandshaker08 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) { + super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis); + return this; + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java new file mode 100644 index 0000000..277a547 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13.java @@ -0,0 +1,360 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.URI; + +/** + *

+ * Performs client side opening and closing handshakes for web socket specification version draft-ietf-hybi-thewebsocketprotocol- + * 17 + *

+ */ +public class WebSocketClientHandshaker13 extends WebSocketClientHandshaker { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketClientHandshaker13.class); + + public static final String MAGIC_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + private String expectedChallengeResponseString; + + private final boolean allowExtensions; + private final boolean performMasking; + private final boolean allowMaskMismatch; + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + */ + public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + true, false); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. 
+ * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + */ + public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, + performMasking, allowMaskMismatch, DEFAULT_FORCE_CLOSE_TIMEOUT_MILLIS); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. 
Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + */ + public WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, + long forceCloseTimeoutMillis) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, false); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. 
+ * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + */ + WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { + this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength, performMasking, + allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl, true); + } + + /** + * Creates a new instance. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Map of custom headers to add to the client request + * @param maxFramePayloadLength + * Maximum length of a frame's payload + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. 
+ * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified. + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate the `Origin` header value for handshake request + * according to the given webSocketURL + */ + WebSocketClientHandshaker13(URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, + long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl, + boolean generateOriginHeader) { + super(webSocketURL, version, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + this.allowExtensions = allowExtensions; + this.performMasking = performMasking; + this.allowMaskMismatch = allowMaskMismatch; + } + + /** + *

+ * Sends the opening request to the server: + *

+ * + *
+     * GET /chat HTTP/1.1
+     * Host: server.example.com
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+     * Sec-WebSocket-Protocol: chat, superchat
+     * Sec-WebSocket-Version: 13
+     * 
+ * + */ + @Override + protected FullHttpRequest newHandshakeRequest() { + URI wsURL = uri(); + + // Get 16 bit nonce and base 64 encode it + byte[] nonce = WebSocketUtil.randomBytes(16); + String key = WebSocketUtil.base64(nonce); + + String acceptSeed = key + MAGIC_GUID; + byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); + expectedChallengeResponseString = WebSocketUtil.base64(sha1); + + if (logger.isDebugEnabled()) { + logger.debug( + "WebSocket version 13 client handshake key: {}, expected response: {}", + key, expectedChallengeResponseString); + } + + // Format request + FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, upgradeUrl(wsURL), + Unpooled.EMPTY_BUFFER); + HttpHeaders headers = request.headers(); + + if (customHeaders != null) { + headers.add(customHeaders); + if (!headers.contains(HttpHeaderNames.HOST)) { + // Only add HOST header if customHeaders did not contain it. + // + // See https://github.com/netty/netty/issues/10101 + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + } else { + headers.set(HttpHeaderNames.HOST, websocketHostValue(wsURL)); + } + + headers.set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key); + + if (generateOriginHeader && !headers.contains(HttpHeaderNames.ORIGIN)) { + headers.set(HttpHeaderNames.ORIGIN, websocketOriginValue(wsURL)); + } + + String expectedSubprotocol = expectedSubprotocol(); + if (expectedSubprotocol != null && !expectedSubprotocol.isEmpty()) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, expectedSubprotocol); + } + + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version().toAsciiString()); + return request; + } + + /** + *

+ * Process server response: + *

+ * + *
+     * HTTP/1.1 101 Switching Protocols
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+     * Sec-WebSocket-Protocol: chat
+     * 
+ * + * @param response + * HTTP response returned from the server for the request sent by beginOpeningHandshake00(). + * @throws WebSocketHandshakeException if handshake response is invalid. + */ + @Override + protected void verify(FullHttpResponse response) { + HttpResponseStatus status = response.status(); + if (!HttpResponseStatus.SWITCHING_PROTOCOLS.equals(status)) { + throw new WebSocketClientHandshakeException("Invalid handshake response getStatus: " + status, response); + } + + HttpHeaders headers = response.headers(); + CharSequence upgrade = headers.get(HttpHeaderNames.UPGRADE); + if (!HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(upgrade)) { + throw new WebSocketClientHandshakeException("Invalid handshake response upgrade: " + upgrade, response); + } + + if (!headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { + throw new WebSocketClientHandshakeException("Invalid handshake response connection: " + + headers.get(HttpHeaderNames.CONNECTION), response); + } + + CharSequence accept = headers.get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT); + if (accept == null || !accept.equals(expectedChallengeResponseString)) { + throw new WebSocketClientHandshakeException(String.format( + "Invalid challenge. Actual: %s. 
Expected: %s", accept, expectedChallengeResponseString), response); + } + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return new WebSocket13FrameDecoder(false, allowExtensions, maxFramePayloadLength(), allowMaskMismatch); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return new WebSocket13FrameEncoder(performMasking); + } + + @Override + public WebSocketClientHandshaker13 setForceCloseTimeoutMillis(long forceCloseTimeoutMillis) { + super.setForceCloseTimeoutMillis(forceCloseTimeoutMillis); + return this; + } + + public boolean isAllowExtensions() { + return allowExtensions; + } + + public boolean isPerformMasking() { + return performMasking; + } + + public boolean isAllowMaskMismatch() { + return allowMaskMismatch; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java new file mode 100644 index 0000000..ce915eb --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerFactory.java @@ -0,0 +1,290 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.HttpHeaders; + +import java.net.URI; + +import static io.netty.handler.codec.http.websocketx.WebSocketVersion.*; + +/** + * Creates a new {@link WebSocketClientHandshaker} of desired protocol version. + */ +public final class WebSocketClientHandshakerFactory { + + /** + * Private constructor so this static class cannot be instanced. + */ + private WebSocketClientHandshakerFactory() { + } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders) { + return newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, 65536); + } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. 
+ * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength) { + return newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, true, false); + } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. 
+ */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch) { + return newHandshaker(webSocketURL, version, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, -1); + } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. 
+ * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis) { + if (version == V13) { + return new WebSocketClientHandshaker13( + webSocketURL, V13, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis); + } + if (version == V08) { + return new WebSocketClientHandshaker08( + webSocketURL, V08, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis); + } + if (version == V07) { + return new WebSocketClientHandshaker07( + webSocketURL, V07, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis); + } + if (version == V00) { + return new WebSocketClientHandshaker00( + webSocketURL, V00, subprotocol, customHeaders, maxFramePayloadLength, forceCloseTimeoutMillis); + } + + throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported."); + } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. 
+ * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. + * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, boolean absoluteUpgradeUrl) { + if (version == V13) { + return new WebSocketClientHandshaker13( + webSocketURL, V13, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl); + } + if (version == V08) { + return new WebSocketClientHandshaker08( + webSocketURL, V08, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl); + } + if (version == V07) { + return new WebSocketClientHandshaker07( + webSocketURL, 
V07, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, absoluteUpgradeUrl); + } + if (version == V00) { + return new WebSocketClientHandshaker00( + webSocketURL, V00, subprotocol, customHeaders, + maxFramePayloadLength, forceCloseTimeoutMillis, absoluteUpgradeUrl); + } + + throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported."); + } + + /** + * Creates a new handshaker. + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". + * Subsequent web socket frames will be sent to this URL. + * @param version + * Version of web socket specification to use to connect to the server + * @param subprotocol + * Sub protocol request sent to the server. Null if no sub-protocol support is required. + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param customHeaders + * Custom HTTP headers to send during the handshake + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + * @param performMasking + * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible + * with the websocket specifications. Client applications that communicate with a non-standard server + * which doesn't require masking might set this to false to achieve a higher performance. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. 
+ * @param forceCloseTimeoutMillis + * Close the connection if it was not closed by the server after timeout specified + * @param absoluteUpgradeUrl + * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over + * clear HTTP + * @param generateOriginHeader + * Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request + * according to the given webSocketURL + */ + public static WebSocketClientHandshaker newHandshaker( + URI webSocketURL, WebSocketVersion version, String subprotocol, + boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength, + boolean performMasking, boolean allowMaskMismatch, long forceCloseTimeoutMillis, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + if (version == V13) { + return new WebSocketClientHandshaker13( + webSocketURL, V13, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + } + if (version == V08) { + return new WebSocketClientHandshaker08( + webSocketURL, V08, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + } + if (version == V07) { + return new WebSocketClientHandshaker07( + webSocketURL, V07, subprotocol, allowExtensions, customHeaders, + maxFramePayloadLength, performMasking, allowMaskMismatch, forceCloseTimeoutMillis, + absoluteUpgradeUrl, generateOriginHeader); + } + if (version == V00) { + return new WebSocketClientHandshaker00( + webSocketURL, V00, subprotocol, customHeaders, + maxFramePayloadLength, forceCloseTimeoutMillis, absoluteUpgradeUrl, generateOriginHeader); + } + + throw new WebSocketClientHandshakeException("Protocol version " + version + " not supported."); + } +} diff --git 
/*
 * Copyright 2019 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx;

import io.netty.handler.codec.http.EmptyHttpHeaders;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.util.internal.ObjectUtil;

import java.net.URI;

import static io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig.DEFAULT_HANDSHAKE_TIMEOUT_MILLIS;
import static io.netty.util.internal.ObjectUtil.checkPositive;

/**
 * WebSocket client configuration.
 *
 * <p>Immutable value object consumed by {@link WebSocketClientProtocolHandler}. Instances are
 * created through {@link #newBuilder()} (sensible defaults) or {@link #toBuilder()} (copy).
 */
public final class WebSocketClientProtocolConfig {

    static final boolean DEFAULT_PERFORM_MASKING = true;
    static final boolean DEFAULT_ALLOW_MASK_MISMATCH = false;
    static final boolean DEFAULT_HANDLE_CLOSE_FRAMES = true;
    static final boolean DEFAULT_DROP_PONG_FRAMES = true;
    static final boolean DEFAULT_GENERATE_ORIGIN_HEADER = true;

    private final URI webSocketUri;
    private final String subprotocol;
    private final WebSocketVersion version;
    private final boolean allowExtensions;
    private final HttpHeaders customHeaders;
    private final int maxFramePayloadLength;
    private final boolean performMasking;
    private final boolean allowMaskMismatch;
    private final boolean handleCloseFrames;
    private final WebSocketCloseStatus sendCloseFrame;
    private final boolean dropPongFrames;
    private final long handshakeTimeoutMillis;
    private final long forceCloseTimeoutMillis;
    private final boolean absoluteUpgradeUrl;
    private final boolean generateOriginHeader;
    private final boolean withUTF8Validator;

    private WebSocketClientProtocolConfig(
        URI webSocketUri,
        String subprotocol,
        WebSocketVersion version,
        boolean allowExtensions,
        HttpHeaders customHeaders,
        int maxFramePayloadLength,
        boolean performMasking,
        boolean allowMaskMismatch,
        boolean handleCloseFrames,
        WebSocketCloseStatus sendCloseFrame,
        boolean dropPongFrames,
        long handshakeTimeoutMillis,
        long forceCloseTimeoutMillis,
        boolean absoluteUpgradeUrl,
        boolean generateOriginHeader,
        boolean withUTF8Validator
    ) {
        this.webSocketUri = webSocketUri;
        this.subprotocol = subprotocol;
        this.version = version;
        this.allowExtensions = allowExtensions;
        this.customHeaders = customHeaders;
        this.maxFramePayloadLength = maxFramePayloadLength;
        this.performMasking = performMasking;
        this.allowMaskMismatch = allowMaskMismatch;
        this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
        this.handleCloseFrames = handleCloseFrames;
        this.sendCloseFrame = sendCloseFrame;
        this.dropPongFrames = dropPongFrames;
        // The handshake timeout is the only validated parameter; everything else is taken as-is.
        this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
        this.absoluteUpgradeUrl = absoluteUpgradeUrl;
        this.generateOriginHeader = generateOriginHeader;
        this.withUTF8Validator = withUTF8Validator;
    }

    /** URL for web socket communications, e.g. "ws://myhost.com/mypath". */
    public URI webSocketUri() {
        return webSocketUri;
    }

    /** Sub protocol request sent to the server, or {@code null}. */
    public String subprotocol() {
        return subprotocol;
    }

    /** Version of the web socket specification used to connect to the server. */
    public WebSocketVersion version() {
        return version;
    }

    /** Whether extensions are allowed in the reserved bits of the web socket frame. */
    public boolean allowExtensions() {
        return allowExtensions;
    }

    /** Custom headers added to the client handshake request. */
    public HttpHeaders customHeaders() {
        return customHeaders;
    }

    /** Maximum length of a frame's payload. */
    public int maxFramePayloadLength() {
        return maxFramePayloadLength;
    }

    /** Whether all written websocket frames are masked. */
    public boolean performMasking() {
        return performMasking;
    }

    /** Whether improperly masked frames are still accepted. */
    public boolean allowMaskMismatch() {
        return allowMaskMismatch;
    }

    /** {@code true} if close frames are not forwarded and just close the channel. */
    public boolean handleCloseFrames() {
        return handleCloseFrames;
    }

    /** Close frame to send when no close frame was sent manually, or {@code null}. */
    public WebSocketCloseStatus sendCloseFrame() {
        return sendCloseFrame;
    }

    /** {@code true} if pong frames are not forwarded. */
    public boolean dropPongFrames() {
        return dropPongFrames;
    }

    /** Handshake timeout in millis; always positive. */
    public long handshakeTimeoutMillis() {
        return handshakeTimeoutMillis;
    }

    /** Timeout after which the connection is closed if the server did not close it. */
    public long forceCloseTimeoutMillis() {
        return forceCloseTimeoutMillis;
    }

    /** Whether an absolute url is used for the Upgrade request. */
    public boolean absoluteUpgradeUrl() {
        return absoluteUpgradeUrl;
    }

    /** Whether the {@code Origin}/{@code Sec-WebSocket-Origin} header is generated. */
    public boolean generateOriginHeader() {
        return generateOriginHeader;
    }

    /** Whether UTF8 validation of text frame payloads is enabled. */
    public boolean withUTF8Validator() {
        return withUTF8Validator;
    }

    @Override
    public String toString() {
        return "WebSocketClientProtocolConfig" +
            " {webSocketUri=" + webSocketUri +
            ", subprotocol=" + subprotocol +
            ", version=" + version +
            ", allowExtensions=" + allowExtensions +
            ", customHeaders=" + customHeaders +
            ", maxFramePayloadLength=" + maxFramePayloadLength +
            ", performMasking=" + performMasking +
            ", allowMaskMismatch=" + allowMaskMismatch +
            ", handleCloseFrames=" + handleCloseFrames +
            ", sendCloseFrame=" + sendCloseFrame +
            ", dropPongFrames=" + dropPongFrames +
            ", handshakeTimeoutMillis=" + handshakeTimeoutMillis +
            ", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis +
            ", absoluteUpgradeUrl=" + absoluteUpgradeUrl +
            ", generateOriginHeader=" + generateOriginHeader +
            ", withUTF8Validator=" + withUTF8Validator +
            "}";
    }

    /**
     * Returns a builder pre-populated with this configuration's values.
     */
    public Builder toBuilder() {
        return new Builder(this);
    }

    /**
     * Returns a builder initialized with default values.
     */
    public static Builder newBuilder() {
        return new Builder(
            URI.create("https://localhost/"),
            null,
            WebSocketVersion.V13,
            false,
            EmptyHttpHeaders.INSTANCE,
            65536,
            DEFAULT_PERFORM_MASKING,
            DEFAULT_ALLOW_MASK_MISMATCH,
            DEFAULT_HANDLE_CLOSE_FRAMES,
            WebSocketCloseStatus.NORMAL_CLOSURE,
            DEFAULT_DROP_PONG_FRAMES,
            DEFAULT_HANDSHAKE_TIMEOUT_MILLIS,
            -1,
            false,
            DEFAULT_GENERATE_ORIGIN_HEADER,
            true);
    }

    /**
     * Mutable builder for {@link WebSocketClientProtocolConfig}.
     */
    public static final class Builder {
        private URI webSocketUri;
        private String subprotocol;
        private WebSocketVersion version;
        private boolean allowExtensions;
        private HttpHeaders customHeaders;
        private int maxFramePayloadLength;
        private boolean performMasking;
        private boolean allowMaskMismatch;
        private boolean handleCloseFrames;
        private WebSocketCloseStatus sendCloseFrame;
        private boolean dropPongFrames;
        private long handshakeTimeoutMillis;
        private long forceCloseTimeoutMillis;
        private boolean absoluteUpgradeUrl;
        private boolean generateOriginHeader;
        private boolean withUTF8Validator;

        private Builder(WebSocketClientProtocolConfig clientConfig) {
            this(ObjectUtil.checkNotNull(clientConfig, "clientConfig").webSocketUri(),
                 clientConfig.subprotocol(),
                 clientConfig.version(),
                 clientConfig.allowExtensions(),
                 clientConfig.customHeaders(),
                 clientConfig.maxFramePayloadLength(),
                 clientConfig.performMasking(),
                 clientConfig.allowMaskMismatch(),
                 clientConfig.handleCloseFrames(),
                 clientConfig.sendCloseFrame(),
                 clientConfig.dropPongFrames(),
                 clientConfig.handshakeTimeoutMillis(),
                 clientConfig.forceCloseTimeoutMillis(),
                 clientConfig.absoluteUpgradeUrl(),
                 clientConfig.generateOriginHeader(),
                 clientConfig.withUTF8Validator());
        }

        private Builder(URI webSocketUri,
                        String subprotocol,
                        WebSocketVersion version,
                        boolean allowExtensions,
                        HttpHeaders customHeaders,
                        int maxFramePayloadLength,
                        boolean performMasking,
                        boolean allowMaskMismatch,
                        boolean handleCloseFrames,
                        WebSocketCloseStatus sendCloseFrame,
                        boolean dropPongFrames,
                        long handshakeTimeoutMillis,
                        long forceCloseTimeoutMillis,
                        boolean absoluteUpgradeUrl,
                        boolean generateOriginHeader,
                        boolean withUTF8Validator) {
            this.webSocketUri = webSocketUri;
            this.subprotocol = subprotocol;
            this.version = version;
            this.allowExtensions = allowExtensions;
            this.customHeaders = customHeaders;
            this.maxFramePayloadLength = maxFramePayloadLength;
            this.performMasking = performMasking;
            this.allowMaskMismatch = allowMaskMismatch;
            this.handleCloseFrames = handleCloseFrames;
            this.sendCloseFrame = sendCloseFrame;
            this.dropPongFrames = dropPongFrames;
            this.handshakeTimeoutMillis = handshakeTimeoutMillis;
            this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
            this.absoluteUpgradeUrl = absoluteUpgradeUrl;
            this.generateOriginHeader = generateOriginHeader;
            this.withUTF8Validator = withUTF8Validator;
        }

        /**
         * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
         * sent to this URL.
         */
        public Builder webSocketUri(String webSocketUri) {
            return webSocketUri(URI.create(webSocketUri));
        }

        /**
         * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
         * sent to this URL.
         */
        public Builder webSocketUri(URI webSocketUri) {
            this.webSocketUri = webSocketUri;
            return this;
        }

        /**
         * Sub protocol request sent to the server.
         */
        public Builder subprotocol(String subprotocol) {
            this.subprotocol = subprotocol;
            return this;
        }

        /**
         * Version of web socket specification to use to connect to the server
         */
        public Builder version(WebSocketVersion version) {
            this.version = version;
            return this;
        }

        /**
         * Allow extensions to be used in the reserved bits of the web socket frame
         */
        public Builder allowExtensions(boolean allowExtensions) {
            this.allowExtensions = allowExtensions;
            return this;
        }

        /**
         * Map of custom headers to add to the client request
         */
        public Builder customHeaders(HttpHeaders customHeaders) {
            this.customHeaders = customHeaders;
            return this;
        }

        /**
         * Maximum length of a frame's payload
         */
        public Builder maxFramePayloadLength(int maxFramePayloadLength) {
            this.maxFramePayloadLength = maxFramePayloadLength;
            return this;
        }

        /**
         * Whether to mask all written websocket frames. This must be set to true in order to be fully compatible
         * with the websocket specifications. Client applications that communicate with a non-standard server
         * which doesn't require masking might set this to false to achieve a higher performance.
         */
        public Builder performMasking(boolean performMasking) {
            this.performMasking = performMasking;
            return this;
        }

        /**
         * When set to true, frames which are not masked properly according to the standard will still be accepted.
         */
        public Builder allowMaskMismatch(boolean allowMaskMismatch) {
            this.allowMaskMismatch = allowMaskMismatch;
            return this;
        }

        /**
         * {@code true} if close frames should not be forwarded and just close the channel
         */
        public Builder handleCloseFrames(boolean handleCloseFrames) {
            this.handleCloseFrames = handleCloseFrames;
            return this;
        }

        /**
         * Close frame to send, when close frame was not sent manually. Or {@code null} to disable proper close.
         */
        public Builder sendCloseFrame(WebSocketCloseStatus sendCloseFrame) {
            this.sendCloseFrame = sendCloseFrame;
            return this;
        }

        /**
         * {@code true} if pong frames should not be forwarded
         */
        public Builder dropPongFrames(boolean dropPongFrames) {
            this.dropPongFrames = dropPongFrames;
            return this;
        }

        /**
         * Handshake timeout in mills, when handshake timeout, will trigger user
         * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT}
         */
        public Builder handshakeTimeoutMillis(long handshakeTimeoutMillis) {
            this.handshakeTimeoutMillis = handshakeTimeoutMillis;
            return this;
        }

        /**
         * Close the connection if it was not closed by the server after timeout specified
         */
        public Builder forceCloseTimeoutMillis(long forceCloseTimeoutMillis) {
            this.forceCloseTimeoutMillis = forceCloseTimeoutMillis;
            return this;
        }

        /**
         * Use an absolute url for the Upgrade request, typically when connecting through an HTTP proxy over clear HTTP
         */
        public Builder absoluteUpgradeUrl(boolean absoluteUpgradeUrl) {
            this.absoluteUpgradeUrl = absoluteUpgradeUrl;
            return this;
        }

        /**
         * Allows to generate the `Origin`|`Sec-WebSocket-Origin` header value for handshake request
         * according to the given webSocketURI. Usually it's not necessary and can be disabled,
         * but for backward compatibility is set to {@code true} as default.
         */
        public Builder generateOriginHeader(boolean generateOriginHeader) {
            this.generateOriginHeader = generateOriginHeader;
            return this;
        }

        /**
         * Toggles UTF8 validation for payload of text websocket frames. By default validation is enabled.
         */
        public Builder withUTF8Validator(boolean withUTF8Validator) {
            this.withUTF8Validator = withUTF8Validator;
            return this;
        }

        /**
         * Build unmodifiable client protocol configuration.
         */
        public WebSocketClientProtocolConfig build() {
            return new WebSocketClientProtocolConfig(
                webSocketUri,
                subprotocol,
                version,
                allowExtensions,
                customHeaders,
                maxFramePayloadLength,
                performMasking,
                allowMaskMismatch,
                handleCloseFrames,
                sendCloseFrame,
                dropPongFrames,
                handshakeTimeoutMillis,
                forceCloseTimeoutMillis,
                absoluteUpgradeUrl,
                generateOriginHeader,
                withUTF8Validator
            );
        }
    }
}
package io.netty.handler.codec.http.websocketx;

import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandler;
import io.netty.channel.ChannelPipeline;
import io.netty.handler.codec.http.HttpHeaders;

import java.net.URI;
import java.util.List;

import static io.netty.handler.codec.http.websocketx.WebSocketClientProtocolConfig.DEFAULT_ALLOW_MASK_MISMATCH;
import static io.netty.handler.codec.http.websocketx.WebSocketClientProtocolConfig.DEFAULT_DROP_PONG_FRAMES;
import static io.netty.handler.codec.http.websocketx.WebSocketClientProtocolConfig.DEFAULT_HANDLE_CLOSE_FRAMES;
import static io.netty.handler.codec.http.websocketx.WebSocketClientProtocolConfig.DEFAULT_PERFORM_MASKING;
import static io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig.DEFAULT_HANDSHAKE_TIMEOUT_MILLIS;
import static io.netty.util.internal.ObjectUtil.*;

/**
 * This handler does all the heavy lifting for you to run a websocket client.
 *
 * It takes care of websocket handshaking as well as processing of Ping, Pong frames. Text and Binary
 * data frames are passed to the next handler in the pipeline (implemented by you) for processing.
 * Also the close frame is passed to the next handler, as you may want to inspect it before closing
 * the connection, if {@code handleCloseFrames} is {@code false}; the default is {@code true}.
 *
 * This implementation will establish the websocket connection once the connection to the remote server is complete.
 *
 * To know when the handshake is done you can intercept
 * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} and check if the event is of type
 * {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED} or {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}.
 */
public class WebSocketClientProtocolHandler extends WebSocketProtocolHandler {
    private final WebSocketClientHandshaker handshaker;
    private final WebSocketClientProtocolConfig clientConfig;

    /**
     * Returns the used handshaker
     */
    public WebSocketClientHandshaker handshaker() {
        return handshaker;
    }

    /**
     * Events that are fired to notify about handshake status
     */
    public enum ClientHandshakeStateEvent {
        /**
         * The handshake timed out
         */
        HANDSHAKE_TIMEOUT,

        /**
         * The handshake was started but the server has not yet responded to the request
         */
        HANDSHAKE_ISSUED,

        /**
         * The handshake completed successfully and the channel was upgraded to websockets
         */
        HANDSHAKE_COMPLETE
    }

    /**
     * Base constructor
     *
     * @param clientConfig
     *            Client protocol configuration.
     */
    public WebSocketClientProtocolHandler(WebSocketClientProtocolConfig clientConfig) {
        super(checkNotNull(clientConfig, "clientConfig").dropPongFrames(),
              clientConfig.sendCloseFrame(), clientConfig.forceCloseTimeoutMillis());
        this.handshaker = WebSocketClientHandshakerFactory.newHandshaker(
            clientConfig.webSocketUri(),
            clientConfig.version(),
            clientConfig.subprotocol(),
            clientConfig.allowExtensions(),
            clientConfig.customHeaders(),
            clientConfig.maxFramePayloadLength(),
            clientConfig.performMasking(),
            clientConfig.allowMaskMismatch(),
            clientConfig.forceCloseTimeoutMillis(),
            clientConfig.absoluteUpgradeUrl(),
            clientConfig.generateOriginHeader()
        );
        this.clientConfig = clientConfig;
    }

    /**
     * Base constructor. Uses the default handshake timeout.
     *
     * @param webSocketURL URL for web socket communications, e.g. "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param version Version of web socket specification to use to connect to the server
     * @param subprotocol Sub protocol request sent to the server.
     * @param allowExtensions Allow extensions to be used in the reserved bits of the web socket frame
     * @param customHeaders Map of custom headers to add to the client request
     * @param maxFramePayloadLength Maximum length of a frame's payload
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     * @param performMasking Whether to mask all written websocket frames. This must be set to true in order
     *            to be fully compatible with the websocket specifications. Client applications that
     *            communicate with a non-standard server which doesn't require masking might set this to
     *            false to achieve a higher performance.
     * @param allowMaskMismatch When set to true, frames which are not masked properly according to the
     *            standard will still be accepted.
     */
    public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
                                          boolean allowExtensions, HttpHeaders customHeaders,
                                          int maxFramePayloadLength, boolean handleCloseFrames,
                                          boolean performMasking, boolean allowMaskMismatch) {
        this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
             handleCloseFrames, performMasking, allowMaskMismatch, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor
     *
     * @param webSocketURL URL for web socket communications, e.g. "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param version Version of web socket specification to use to connect to the server
     * @param subprotocol Sub protocol request sent to the server.
     * @param allowExtensions Allow extensions to be used in the reserved bits of the web socket frame
     * @param customHeaders Map of custom headers to add to the client request
     * @param maxFramePayloadLength Maximum length of a frame's payload
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     * @param performMasking Whether to mask all written websocket frames. This must be set to true in order
     *            to be fully compatible with the websocket specifications. Client applications that
     *            communicate with a non-standard server which doesn't require masking might set this to
     *            false to achieve a higher performance.
     * @param allowMaskMismatch When set to true, frames which are not masked properly according to the
     *            standard will still be accepted.
     * @param handshakeTimeoutMillis Handshake timeout in millis; on timeout the user event
     *            {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is fired
     */
    public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
                                          boolean allowExtensions, HttpHeaders customHeaders,
                                          int maxFramePayloadLength, boolean handleCloseFrames, boolean performMasking,
                                          boolean allowMaskMismatch, long handshakeTimeoutMillis) {
        this(WebSocketClientHandshakerFactory.newHandshaker(webSocketURL, version, subprotocol,
                                                            allowExtensions, customHeaders, maxFramePayloadLength,
                                                            performMasking, allowMaskMismatch),
             handleCloseFrames, handshakeTimeoutMillis);
    }

    /**
     * Base constructor. Uses the default handshake timeout.
     *
     * @param webSocketURL URL for web socket communications, e.g. "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param version Version of web socket specification to use to connect to the server
     * @param subprotocol Sub protocol request sent to the server.
     * @param allowExtensions Allow extensions to be used in the reserved bits of the web socket frame
     * @param customHeaders Map of custom headers to add to the client request
     * @param maxFramePayloadLength Maximum length of a frame's payload
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     */
    public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
                                          boolean allowExtensions, HttpHeaders customHeaders,
                                          int maxFramePayloadLength, boolean handleCloseFrames) {
        this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
             handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor. Uses default masking settings.
     *
     * @param webSocketURL URL for web socket communications, e.g. "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param version Version of web socket specification to use to connect to the server
     * @param subprotocol Sub protocol request sent to the server.
     * @param allowExtensions Allow extensions to be used in the reserved bits of the web socket frame
     * @param customHeaders Map of custom headers to add to the client request
     * @param maxFramePayloadLength Maximum length of a frame's payload
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     * @param handshakeTimeoutMillis Handshake timeout in millis; on timeout the user event
     *            {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is fired
     */
    public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
                                          boolean allowExtensions, HttpHeaders customHeaders, int maxFramePayloadLength,
                                          boolean handleCloseFrames, long handshakeTimeoutMillis) {
        this(webSocketURL, version, subprotocol, allowExtensions, customHeaders, maxFramePayloadLength,
             handleCloseFrames, DEFAULT_PERFORM_MASKING, DEFAULT_ALLOW_MASK_MISMATCH, handshakeTimeoutMillis);
    }

    /**
     * Base constructor. Close frames are handled by default; uses the default handshake timeout.
     *
     * @param webSocketURL URL for web socket communications, e.g. "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param version Version of web socket specification to use to connect to the server
     * @param subprotocol Sub protocol request sent to the server.
     * @param allowExtensions Allow extensions to be used in the reserved bits of the web socket frame
     * @param customHeaders Map of custom headers to add to the client request
     * @param maxFramePayloadLength Maximum length of a frame's payload
     */
    public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
                                          boolean allowExtensions, HttpHeaders customHeaders,
                                          int maxFramePayloadLength) {
        this(webSocketURL, version, subprotocol, allowExtensions,
             customHeaders, maxFramePayloadLength, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor. Close frames are handled by default.
     *
     * @param webSocketURL URL for web socket communications, e.g. "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param version Version of web socket specification to use to connect to the server
     * @param subprotocol Sub protocol request sent to the server.
     * @param allowExtensions Allow extensions to be used in the reserved bits of the web socket frame
     * @param customHeaders Map of custom headers to add to the client request
     * @param maxFramePayloadLength Maximum length of a frame's payload
     * @param handshakeTimeoutMillis Handshake timeout in millis; on timeout the user event
     *            {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is fired
     */
    public WebSocketClientProtocolHandler(URI webSocketURL, WebSocketVersion version, String subprotocol,
                                          boolean allowExtensions, HttpHeaders customHeaders,
                                          int maxFramePayloadLength, long handshakeTimeoutMillis) {
        this(webSocketURL, version, subprotocol, allowExtensions, customHeaders,
             maxFramePayloadLength, DEFAULT_HANDLE_CLOSE_FRAMES, handshakeTimeoutMillis);
    }

    /**
     * Base constructor. Uses the default handshake timeout.
     *
     * @param handshaker The {@link WebSocketClientHandshaker} which will be used to issue the handshake
     *            once the connection was established to the remote peer.
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     */
    public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames) {
        this(handshaker, handleCloseFrames, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor. Pong frames are dropped by default.
     *
     * @param handshaker The {@link WebSocketClientHandshaker} which will be used to issue the handshake
     *            once the connection was established to the remote peer.
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     * @param handshakeTimeoutMillis Handshake timeout in millis; on timeout the user event
     *            {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is fired
     */
    public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
                                          long handshakeTimeoutMillis) {
        this(handshaker, handleCloseFrames, DEFAULT_DROP_PONG_FRAMES, handshakeTimeoutMillis);
    }

    /**
     * Base constructor. Uses the default handshake timeout.
     *
     * @param handshaker The {@link WebSocketClientHandshaker} which will be used to issue the handshake
     *            once the connection was established to the remote peer.
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     * @param dropPongFrames {@code true} if pong frames should not be forwarded
     */
    public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
                                          boolean dropPongFrames) {
        this(handshaker, handleCloseFrames, dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor.
     *
     * @param handshaker The {@link WebSocketClientHandshaker} which will be used to issue the handshake
     *            once the connection was established to the remote peer.
     * @param handleCloseFrames {@code true} if close frames should not be forwarded and just close the channel
     * @param dropPongFrames {@code true} if pong frames should not be forwarded
     * @param handshakeTimeoutMillis Handshake timeout in millis; on timeout the user event
     *            {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is fired
     */
    public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, boolean handleCloseFrames,
                                          boolean dropPongFrames, long handshakeTimeoutMillis) {
        super(dropPongFrames);
        this.handshaker = handshaker;
        // The remaining config values fall back to the builder defaults.
        this.clientConfig = WebSocketClientProtocolConfig.newBuilder()
            .handleCloseFrames(handleCloseFrames)
            .handshakeTimeoutMillis(handshakeTimeoutMillis)
            .build();
    }

    /**
     * Base constructor. Close frames are handled by default; uses the default handshake timeout.
     *
     * @param handshaker The {@link WebSocketClientHandshaker} which will be used to issue the handshake
     *            once the connection was established to the remote peer.
     */
    public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker) {
        this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS);
    }

    /**
     * Base constructor. Close frames are handled by default.
     *
     * @param handshaker The {@link WebSocketClientHandshaker} which will be used to issue the handshake
     *            once the connection was established to the remote peer.
     * @param handshakeTimeoutMillis Handshake timeout in millis; on timeout the user event
     *            {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is fired
     */
    public WebSocketClientProtocolHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
        this(handshaker, DEFAULT_HANDLE_CLOSE_FRAMES, handshakeTimeoutMillis);
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List<Object> out) throws Exception {
        // Optionally consume close frames here instead of passing them downstream.
        if (clientConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) {
            ctx.close();
            return;
        }
        super.decode(ctx, frame, out);
    }

    @Override
    protected WebSocketClientHandshakeException buildHandshakeException(String message) {
        return new WebSocketClientHandshakeException(message);
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(WebSocketClientProtocolHandshakeHandler.class) == null) {
            // Add the WebSocketClientProtocolHandshakeHandler before this one.
            ctx.pipeline().addBefore(ctx.name(), WebSocketClientProtocolHandshakeHandler.class.getName(),
                                     new WebSocketClientProtocolHandshakeHandler(
                                         handshaker, clientConfig.handshakeTimeoutMillis()));
        }
        if (clientConfig.withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) {
            // Add the UTF-8 checking before this one.
            ctx.pipeline().addBefore(ctx.name(), Utf8FrameValidator.class.getName(),
                                     new Utf8FrameValidator());
        }
    }
}
package io.netty.handler.codec.http.websocketx;

import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;

import java.util.concurrent.TimeUnit;

import static io.netty.util.internal.ObjectUtil.*;

/**
 * Issues the client-side websocket handshake when the channel becomes active and completes it when
 * the server's {@link FullHttpResponse} arrives. Fires
 * {@link ClientHandshakeStateEvent#HANDSHAKE_ISSUED}, {@link ClientHandshakeStateEvent#HANDSHAKE_COMPLETE}
 * and {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} user events, and removes itself from the
 * pipeline once the handshake is finished.
 */
class WebSocketClientProtocolHandshakeHandler extends ChannelInboundHandlerAdapter {
    private static final long DEFAULT_HANDSHAKE_TIMEOUT_MS = 10000L;

    private final WebSocketClientHandshaker handshaker;
    private final long handshakeTimeoutMillis;
    private ChannelHandlerContext ctx;
    private ChannelPromise handshakePromise;

    WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker) {
        this(handshaker, DEFAULT_HANDSHAKE_TIMEOUT_MS);
    }

    WebSocketClientProtocolHandshakeHandler(WebSocketClientHandshaker handshaker, long handshakeTimeoutMillis) {
        this.handshaker = handshaker;
        this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis");
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        this.ctx = ctx;
        handshakePromise = ctx.newPromise();
    }

    @Override
    public void channelActive(final ChannelHandlerContext ctx) throws Exception {
        super.channelActive(ctx);
        handshaker.handshake(ctx.channel()).addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) throws Exception {
                if (!future.isSuccess()) {
                    handshakePromise.tryFailure(future.cause());
                    ctx.fireExceptionCaught(future.cause());
                } else {
                    ctx.fireUserEventTriggered(
                            WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_ISSUED);
                }
            }
        });
        applyHandshakeTimeout();
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Fail a pending handshake so waiters are not left hanging on a dead channel.
        if (!handshakePromise.isDone()) {
            handshakePromise.tryFailure(new WebSocketClientHandshakeException("channel closed with handshake " +
                                                                              "in progress"));
        }

        super.channelInactive(ctx);
    }

    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if (!(msg instanceof FullHttpResponse)) {
            ctx.fireChannelRead(msg);
            return;
        }

        FullHttpResponse response = (FullHttpResponse) msg;
        try {
            if (!handshaker.isHandshakeComplete()) {
                handshaker.finishHandshake(ctx.channel(), response);
                handshakePromise.trySuccess();
                ctx.fireUserEventTriggered(
                        WebSocketClientProtocolHandler.ClientHandshakeStateEvent.HANDSHAKE_COMPLETE);
                // This handler's job is done once the handshake completed.
                ctx.pipeline().remove(this);
                return;
            }
            throw new IllegalStateException(
                    "Unexpected FullHttpResponse: websocket handshake already completed");
        } finally {
            response.release();
        }
    }

    /**
     * Schedules a task that fails the handshake promise, fires
     * {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} and closes the channel if the handshake
     * does not finish within {@code handshakeTimeoutMillis}. A non-positive timeout disables this.
     */
    private void applyHandshakeTimeout() {
        final ChannelPromise localHandshakePromise = handshakePromise;
        if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) {
            return;
        }

        final Future<?> timeoutFuture = ctx.executor().schedule(new Runnable() {
            @Override
            public void run() {
                if (localHandshakePromise.isDone()) {
                    return;
                }

                if (localHandshakePromise.tryFailure(
                        new WebSocketClientHandshakeException("handshake timed out"))) {
                    ctx.flush()
                       .fireUserEventTriggered(ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT)
                       .close();
                }
            }
        }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS);

        // Cancel the handshake timeout when handshake is finished.
        localHandshakePromise.addListener(new FutureListener<Void>() {
            @Override
            public void operationComplete(Future<Void> f) throws Exception {
                timeoutFuture.cancel(false);
            }
        });
    }

    /**
     * This method is visible for testing.
     *
     * @return current handshake future
     */
    ChannelFuture getHandshakeFuture() {
        return handshakePromise;
    }
}
+ *
+ * RFC-6455 The WebSocket Protocol, December 2011:
+ * https://tools.ietf.org/html/rfc6455#section-7.4.1
+ *
+ * WebSocket Protocol Registries, April 2019:
+ * https://www.iana.org/assignments/websocket/websocket.xhtml
+ *
+ * 7.4.1.  Defined Status Codes
+ *
+ * Endpoints MAY use the following pre-defined status codes when sending
+ * a Close frame.
+ *
+ * 1000
+ *
+ *    1000 indicates a normal closure, meaning that the purpose for
+ *    which the connection was established has been fulfilled.
+ *
+ * 1001
+ *
+ *    1001 indicates that an endpoint is "going away", such as a server
+ *    going down or a browser having navigated away from a page.
+ *
+ * 1002
+ *
+ *    1002 indicates that an endpoint is terminating the connection due
+ *    to a protocol error.
+ *
+ * 1003
+ *
+ *    1003 indicates that an endpoint is terminating the connection
+ *    because it has received a type of data it cannot accept (e.g., an
+ *    endpoint that understands only text data MAY send this if it
+ *    receives a binary message).
+ *
+ * 1004
+ *
+ *    Reserved. The specific meaning might be defined in the future.
+ *
+ * 1005
+ *
+ *    1005 is a reserved value and MUST NOT be set as a status code in a
+ *    Close control frame by an endpoint. It is designated for use in
+ *    applications expecting a status code to indicate that no status
+ *    code was actually present.
+ *
+ * 1006
+ *
+ *    1006 is a reserved value and MUST NOT be set as a status code in a
+ *    Close control frame by an endpoint. It is designated for use in
+ *    applications expecting a status code to indicate that the
+ *    connection was closed abnormally, e.g., without sending or
+ *    receiving a Close control frame.
+ *
+ * 1007
+ *
+ *    1007 indicates that an endpoint is terminating the connection
+ *    because it has received data within a message that was not
+ *    consistent with the type of the message (e.g., non-UTF-8 [RFC3629]
+ *    data within a text message).
+ *
+ * 1008
+ *
+ *    1008 indicates that an endpoint is terminating the connection
+ *    because it has received a message that violates its policy. This
+ *    is a generic status code that can be returned when there is no
+ *    other more suitable status code (e.g., 1003 or 1009) or if there
+ *    is a need to hide specific details about the policy.
+ *
+ * 1009
+ *
+ *    1009 indicates that an endpoint is terminating the connection
+ *    because it has received a message that is too big for it to
+ *    process.
+ *
+ * 1010
+ *
+ *    1010 indicates that an endpoint (client) is terminating the
+ *    connection because it has expected the server to negotiate one or
+ *    more extension, but the server didn't return them in the response
+ *    message of the WebSocket handshake. The list of extensions that
+ *    are needed SHOULD appear in the /reason/ part of the Close frame.
+ *    Note that this status code is not used by the server, because it
+ *    can fail the WebSocket handshake instead.
+ *
+ * 1011
+ *
+ *    1011 indicates that a server is terminating the connection because
+ *    it encountered an unexpected condition that prevented it from
+ *    fulfilling the request.
+ *
+ * 1012 (IANA Registry, Non RFC-6455)
+ *
+ *    1012 indicates that the service is restarted. a client may reconnect,
+ *    and if it choses to do, should reconnect using a randomized delay
+ *    of 5 - 30 seconds.
+ *
+ * 1013 (IANA Registry, Non RFC-6455)
+ *
+ *    1013 indicates that the service is experiencing overload. a client
+ *    should only connect to a different IP (when there are multiple for the
+ *    target) or reconnect to the same IP upon user action.
+ *
+ * 1014 (IANA Registry, Non RFC-6455)
+ *
+ *    The server was acting as a gateway or proxy and received an invalid
+ *    response from the upstream server. This is similar to 502 HTTP Status Code.
+ *
+ * 1015
+ *
+ *    1015 is a reserved value and MUST NOT be set as a status code in a
+ *    Close control frame by an endpoint. It is designated for use in
+ *    applications expecting a status code to indicate that the
+ *    connection was closed due to a failure to perform a TLS handshake
+ *    (e.g., the server certificate can't be verified).
+ *
+ *
+ * 7.4.2. Reserved Status Code Ranges
+ *
+ * 0-999
+ *
+ *    Status codes in the range 0-999 are not used.
+ *
+ * 1000-2999
+ *
+ *    Status codes in the range 1000-2999 are reserved for definition by
+ *    this protocol, its future revisions, and extensions specified in a
+ *    permanent and readily available public specification.
+ *
+ * 3000-3999
+ *
+ *    Status codes in the range 3000-3999 are reserved for use by
+ *    libraries, frameworks, and applications. These status codes are
+ *    registered directly with IANA. The interpretation of these codes
+ *    is undefined by this protocol.
+ *
+ * 4000-4999
+ *
+ *    Status codes in the range 4000-4999 are reserved for private use
+ *    and thus can't be registered. Such codes can be used by prior
+ *    agreements between WebSocket applications. The interpretation of
+ *    these codes is undefined by this protocol.
+ * 
+ *

+ * While {@link WebSocketCloseStatus} is enum-like structure, its instances should NOT be compared by reference. + * Instead, either {@link #equals(Object)} should be used or direct comparison of {@link #code()} value. + */ +public final class WebSocketCloseStatus implements Comparable { + + public static final WebSocketCloseStatus NORMAL_CLOSURE = + new WebSocketCloseStatus(1000, "Bye"); + + public static final WebSocketCloseStatus ENDPOINT_UNAVAILABLE = + new WebSocketCloseStatus(1001, "Endpoint unavailable"); + + public static final WebSocketCloseStatus PROTOCOL_ERROR = + new WebSocketCloseStatus(1002, "Protocol error"); + + public static final WebSocketCloseStatus INVALID_MESSAGE_TYPE = + new WebSocketCloseStatus(1003, "Invalid message type"); + + public static final WebSocketCloseStatus INVALID_PAYLOAD_DATA = + new WebSocketCloseStatus(1007, "Invalid payload data"); + + public static final WebSocketCloseStatus POLICY_VIOLATION = + new WebSocketCloseStatus(1008, "Policy violation"); + + public static final WebSocketCloseStatus MESSAGE_TOO_BIG = + new WebSocketCloseStatus(1009, "Message too big"); + + public static final WebSocketCloseStatus MANDATORY_EXTENSION = + new WebSocketCloseStatus(1010, "Mandatory extension"); + + public static final WebSocketCloseStatus INTERNAL_SERVER_ERROR = + new WebSocketCloseStatus(1011, "Internal server error"); + + public static final WebSocketCloseStatus SERVICE_RESTART = + new WebSocketCloseStatus(1012, "Service Restart"); + + public static final WebSocketCloseStatus TRY_AGAIN_LATER = + new WebSocketCloseStatus(1013, "Try Again Later"); + + public static final WebSocketCloseStatus BAD_GATEWAY = + new WebSocketCloseStatus(1014, "Bad Gateway"); + + // 1004, 1005, 1006, 1015 are reserved and should never be used by user + //public static final WebSocketCloseStatus SPECIFIC_MEANING = register(1004, "..."); + + public static final WebSocketCloseStatus EMPTY = + new WebSocketCloseStatus(1005, "Empty", false); + + public static 
final WebSocketCloseStatus ABNORMAL_CLOSURE = + new WebSocketCloseStatus(1006, "Abnormal closure", false); + + public static final WebSocketCloseStatus TLS_HANDSHAKE_FAILED = + new WebSocketCloseStatus(1015, "TLS handshake failed", false); + + private final int statusCode; + private final String reasonText; + private String text; + + public WebSocketCloseStatus(int statusCode, String reasonText) { + this(statusCode, reasonText, true); + } + + public WebSocketCloseStatus(int statusCode, String reasonText, boolean validate) { + if (validate && !isValidStatusCode(statusCode)) { + throw new IllegalArgumentException( + "WebSocket close status code does NOT comply with RFC-6455: " + statusCode); + } + this.statusCode = statusCode; + this.reasonText = checkNotNull(reasonText, "reasonText"); + } + + public int code() { + return statusCode; + } + + public String reasonText() { + return reasonText; + } + + /** + * Order of {@link WebSocketCloseStatus} only depends on {@link #code()}. + */ + @Override + public int compareTo(WebSocketCloseStatus o) { + return code() - o.code(); + } + + /** + * Equality of {@link WebSocketCloseStatus} only depends on {@link #code()}. 
+ */ + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (null == o || getClass() != o.getClass()) { + return false; + } + + WebSocketCloseStatus that = (WebSocketCloseStatus) o; + + return statusCode == that.statusCode; + } + + @Override + public int hashCode() { + return statusCode; + } + + @Override + public String toString() { + String text = this.text; + if (text == null) { + // E.g.: "1000 Bye", "1009 Message too big" + this.text = text = code() + " " + reasonText(); + } + return text; + } + + public static boolean isValidStatusCode(int code) { + return code < 0 || + 1000 <= code && code <= 1003 || + 1007 <= code && code <= 1014 || + 3000 <= code; + } + + public static WebSocketCloseStatus valueOf(int code) { + switch (code) { + case 1000: + return NORMAL_CLOSURE; + case 1001: + return ENDPOINT_UNAVAILABLE; + case 1002: + return PROTOCOL_ERROR; + case 1003: + return INVALID_MESSAGE_TYPE; + case 1005: + return EMPTY; + case 1006: + return ABNORMAL_CLOSURE; + case 1007: + return INVALID_PAYLOAD_DATA; + case 1008: + return POLICY_VIOLATION; + case 1009: + return MESSAGE_TOO_BIG; + case 1010: + return MANDATORY_EXTENSION; + case 1011: + return INTERNAL_SERVER_ERROR; + case 1012: + return SERVICE_RESTART; + case 1013: + return TRY_AGAIN_LATER; + case 1014: + return BAD_GATEWAY; + case 1015: + return TLS_HANDSHAKE_FAILED; + default: + return new WebSocketCloseStatus(code, "Close status #" + code); + } + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketDecoderConfig.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketDecoderConfig.java new file mode 100644 index 0000000..9367b5a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketDecoderConfig.java @@ -0,0 +1,165 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache 
License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.util.internal.ObjectUtil; + +/** + * Frames decoder configuration. + */ +public final class WebSocketDecoderConfig { + + static final WebSocketDecoderConfig DEFAULT = + new WebSocketDecoderConfig(65536, true, false, false, true, true); + + private final int maxFramePayloadLength; + private final boolean expectMaskedFrames; + private final boolean allowMaskMismatch; + private final boolean allowExtensions; + private final boolean closeOnProtocolViolation; + private final boolean withUTF8Validator; + + /** + * Constructor + * + * @param maxFramePayloadLength + * Maximum length of a frame's payload. Setting this to an appropriate value for you application + * helps check for denial of services attacks. + * @param expectMaskedFrames + * Web socket servers must set this to true processed incoming masked payload. Client implementations + * must set this to false. + * @param allowMaskMismatch + * Allows to loosen the masking requirement on received frames. When this is set to false then also + * frames which are not masked properly according to the standard will still be accepted. + * @param allowExtensions + * Flag to allow reserved extension bits to be used or not + * @param closeOnProtocolViolation + * Flag to send close frame immediately on any protocol violation.ion. 
+ * @param withUTF8Validator + * Allows you to avoid adding of Utf8FrameValidator to the pipeline on the + * WebSocketServerProtocolHandler creation. This is useful (less overhead) + * when you use only BinaryWebSocketFrame within your web socket connection. + */ + private WebSocketDecoderConfig(int maxFramePayloadLength, boolean expectMaskedFrames, boolean allowMaskMismatch, + boolean allowExtensions, boolean closeOnProtocolViolation, + boolean withUTF8Validator) { + this.maxFramePayloadLength = maxFramePayloadLength; + this.expectMaskedFrames = expectMaskedFrames; + this.allowMaskMismatch = allowMaskMismatch; + this.allowExtensions = allowExtensions; + this.closeOnProtocolViolation = closeOnProtocolViolation; + this.withUTF8Validator = withUTF8Validator; + } + + public int maxFramePayloadLength() { + return maxFramePayloadLength; + } + + public boolean expectMaskedFrames() { + return expectMaskedFrames; + } + + public boolean allowMaskMismatch() { + return allowMaskMismatch; + } + + public boolean allowExtensions() { + return allowExtensions; + } + + public boolean closeOnProtocolViolation() { + return closeOnProtocolViolation; + } + + public boolean withUTF8Validator() { + return withUTF8Validator; + } + + @Override + public String toString() { + return "WebSocketDecoderConfig" + + " [maxFramePayloadLength=" + maxFramePayloadLength + + ", expectMaskedFrames=" + expectMaskedFrames + + ", allowMaskMismatch=" + allowMaskMismatch + + ", allowExtensions=" + allowExtensions + + ", closeOnProtocolViolation=" + closeOnProtocolViolation + + ", withUTF8Validator=" + withUTF8Validator + + "]"; + } + + public Builder toBuilder() { + return new Builder(this); + } + + public static Builder newBuilder() { + return new Builder(DEFAULT); + } + + public static final class Builder { + private int maxFramePayloadLength; + private boolean expectMaskedFrames; + private boolean allowMaskMismatch; + private boolean allowExtensions; + private boolean closeOnProtocolViolation; + private 
boolean withUTF8Validator; + + private Builder(WebSocketDecoderConfig decoderConfig) { + ObjectUtil.checkNotNull(decoderConfig, "decoderConfig"); + maxFramePayloadLength = decoderConfig.maxFramePayloadLength(); + expectMaskedFrames = decoderConfig.expectMaskedFrames(); + allowMaskMismatch = decoderConfig.allowMaskMismatch(); + allowExtensions = decoderConfig.allowExtensions(); + closeOnProtocolViolation = decoderConfig.closeOnProtocolViolation(); + withUTF8Validator = decoderConfig.withUTF8Validator(); + } + + public Builder maxFramePayloadLength(int maxFramePayloadLength) { + this.maxFramePayloadLength = maxFramePayloadLength; + return this; + } + + public Builder expectMaskedFrames(boolean expectMaskedFrames) { + this.expectMaskedFrames = expectMaskedFrames; + return this; + } + + public Builder allowMaskMismatch(boolean allowMaskMismatch) { + this.allowMaskMismatch = allowMaskMismatch; + return this; + } + + public Builder allowExtensions(boolean allowExtensions) { + this.allowExtensions = allowExtensions; + return this; + } + + public Builder closeOnProtocolViolation(boolean closeOnProtocolViolation) { + this.closeOnProtocolViolation = closeOnProtocolViolation; + return this; + } + + public Builder withUTF8Validator(boolean withUTF8Validator) { + this.withUTF8Validator = withUTF8Validator; + return this; + } + + public WebSocketDecoderConfig build() { + return new WebSocketDecoderConfig( + maxFramePayloadLength, expectMaskedFrames, allowMaskMismatch, + allowExtensions, closeOnProtocolViolation, withUTF8Validator); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrame.java new file mode 100644 index 0000000..3cd0d18 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrame.java @@ -0,0 +1,109 @@ +/* + * Copyright 2012 The Netty Project + * + * The 
Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.DefaultByteBufHolder; +import io.netty.util.internal.StringUtil; + +/** + * Base class for web socket frames. + */ +public abstract class WebSocketFrame extends DefaultByteBufHolder { + + /** + * Flag to indicate if this frame is the final fragment in a message. The first fragment (frame) may also be the + * final fragment. + */ + private final boolean finalFragment; + + /** + * RSV1, RSV2, RSV3 used for extensions + */ + private final int rsv; + + protected WebSocketFrame(ByteBuf binaryData) { + this(true, 0, binaryData); + } + + protected WebSocketFrame(boolean finalFragment, int rsv, ByteBuf binaryData) { + super(binaryData); + this.finalFragment = finalFragment; + this.rsv = rsv; + } + + /** + * Flag to indicate if this frame is the final fragment in a message. The first fragment (frame) may also be the + * final fragment. + */ + public boolean isFinalFragment() { + return finalFragment; + } + + /** + * Bits used for extensions to the standard. 
+ */ + public int rsv() { + return rsv; + } + + @Override + public WebSocketFrame copy() { + return (WebSocketFrame) super.copy(); + } + + @Override + public WebSocketFrame duplicate() { + return (WebSocketFrame) super.duplicate(); + } + + @Override + public WebSocketFrame retainedDuplicate() { + return (WebSocketFrame) super.retainedDuplicate(); + } + + @Override + public abstract WebSocketFrame replace(ByteBuf content); + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(data: " + contentToString() + ')'; + } + + @Override + public WebSocketFrame retain() { + super.retain(); + return this; + } + + @Override + public WebSocketFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public WebSocketFrame touch() { + super.touch(); + return this; + } + + @Override + public WebSocketFrame touch(Object hint) { + super.touch(hint); + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregator.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregator.java new file mode 100644 index 0000000..ef500c8 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregator.java @@ -0,0 +1,99 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.MessageAggregator; +import io.netty.handler.codec.TooLongFrameException; + +/** + * Handler that aggregate fragmented WebSocketFrame's. + * + * Be aware if PING/PONG/CLOSE frames are send in the middle of a fragmented {@link WebSocketFrame} they will + * just get forwarded to the next handler in the pipeline. + */ +public class WebSocketFrameAggregator + extends MessageAggregator { + + /** + * Creates a new instance + * + * @param maxContentLength If the size of the aggregated frame exceeds this value, + * a {@link TooLongFrameException} is thrown. + */ + public WebSocketFrameAggregator(int maxContentLength) { + super(maxContentLength); + } + + @Override + protected boolean isStartMessage(WebSocketFrame msg) throws Exception { + return msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame; + } + + @Override + protected boolean isContentMessage(WebSocketFrame msg) throws Exception { + return msg instanceof ContinuationWebSocketFrame; + } + + @Override + protected boolean isLastContentMessage(ContinuationWebSocketFrame msg) throws Exception { + return isContentMessage(msg) && msg.isFinalFragment(); + } + + @Override + protected boolean isAggregated(WebSocketFrame msg) throws Exception { + if (msg.isFinalFragment()) { + return !isContentMessage(msg); + } + + return !isStartMessage(msg) && !isContentMessage(msg); + } + + @Override + protected boolean isContentLengthInvalid(WebSocketFrame start, int maxContentLength) { + return false; + } + + @Override + protected Object newContinueResponse(WebSocketFrame start, int maxContentLength, ChannelPipeline pipeline) { + return null; + } + + @Override + protected boolean closeAfterContinueResponse(Object msg) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + protected boolean ignoreContentAfterContinueResponse(Object 
msg) throws Exception { + throw new UnsupportedOperationException(); + } + + @Override + protected WebSocketFrame beginAggregation(WebSocketFrame start, ByteBuf content) throws Exception { + if (start instanceof TextWebSocketFrame) { + return new TextWebSocketFrame(true, start.rsv(), content); + } + + if (start instanceof BinaryWebSocketFrame) { + return new BinaryWebSocketFrame(true, start.rsv(), content); + } + + // Should not reach here. + throw new Error(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameDecoder.java new file mode 100644 index 0000000..e65cc38 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameDecoder.java @@ -0,0 +1,27 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelPipeline; + +/** + * Marker interface which all WebSocketFrame decoders need to implement. + * + * This makes it easier to access the added encoder later in the {@link ChannelPipeline}. 
+ */ +public interface WebSocketFrameDecoder extends ChannelInboundHandler { +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameEncoder.java new file mode 100644 index 0000000..1bcaf5c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketFrameEncoder.java @@ -0,0 +1,27 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPipeline; + +/** + * Marker interface which all WebSocketFrame encoders need to implement. + * + * This makes it easier to access the added encoder later in the {@link ChannelPipeline}. 
+ */ +public interface WebSocketFrameEncoder extends ChannelOutboundHandler { +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeException.java new file mode 100644 index 0000000..dd0bf61 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeException.java @@ -0,0 +1,32 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +/** + * Exception during handshaking process + */ +public class WebSocketHandshakeException extends RuntimeException { + + private static final long serialVersionUID = 1L; + + public WebSocketHandshakeException(String s) { + super(s); + } + + public WebSocketHandshakeException(String s, Throwable throwable) { + super(s, throwable); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java new file mode 100644 index 0000000..ae7b4c5 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandler.java @@ -0,0 +1,194 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.PromiseNotifier; + +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.List; +import java.util.concurrent.TimeUnit; + +abstract class WebSocketProtocolHandler extends MessageToMessageDecoder + implements ChannelOutboundHandler { + + private final boolean dropPongFrames; + private final WebSocketCloseStatus closeStatus; + private final long forceCloseTimeoutMillis; + private ChannelPromise closeSent; + + /** + * Creates a new {@link WebSocketProtocolHandler} that will drop {@link PongWebSocketFrame}s. + */ + WebSocketProtocolHandler() { + this(true); + } + + /** + * Creates a new {@link WebSocketProtocolHandler}, given a parameter that determines whether or not to drop {@link + * PongWebSocketFrame}s. 
+ * + * @param dropPongFrames + * {@code true} if {@link PongWebSocketFrame}s should be dropped + */ + WebSocketProtocolHandler(boolean dropPongFrames) { + this(dropPongFrames, null, 0L); + } + + WebSocketProtocolHandler(boolean dropPongFrames, + WebSocketCloseStatus closeStatus, + long forceCloseTimeoutMillis) { + this.dropPongFrames = dropPongFrames; + this.closeStatus = closeStatus; + this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; + } + + @Override + protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List out) throws Exception { + if (frame instanceof PingWebSocketFrame) { + frame.content().retain(); + ctx.writeAndFlush(new PongWebSocketFrame(frame.content())); + readIfNeeded(ctx); + return; + } + if (frame instanceof PongWebSocketFrame && dropPongFrames) { + readIfNeeded(ctx); + return; + } + + out.add(frame.retain()); + } + + private static void readIfNeeded(ChannelHandlerContext ctx) { + if (!ctx.channel().config().isAutoRead()) { + ctx.read(); + } + } + + @Override + public void close(final ChannelHandlerContext ctx, final ChannelPromise promise) throws Exception { + if (closeStatus == null || !ctx.channel().isActive()) { + ctx.close(promise); + } else { + if (closeSent == null) { + write(ctx, new CloseWebSocketFrame(closeStatus), ctx.newPromise()); + } + flush(ctx); + applyCloseSentTimeout(ctx); + closeSent.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + ctx.close(promise); + } + }); + } + } + + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (closeSent != null) { + ReferenceCountUtil.release(msg); + promise.setFailure(new ClosedChannelException()); + } else if (msg instanceof CloseWebSocketFrame) { + closeSent(promise.unvoid()); + ctx.write(msg).addListener(new PromiseNotifier(false, closeSent)); + } else { + ctx.write(msg, promise); + } + } + + void closeSent(ChannelPromise promise) { + 
closeSent = promise; + } + + private void applyCloseSentTimeout(ChannelHandlerContext ctx) { + if (closeSent.isDone() || forceCloseTimeoutMillis < 0) { + return; + } + + final Future timeoutTask = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (!closeSent.isDone()) { + closeSent.tryFailure(buildHandshakeException("send close frame timed out")); + } + } + }, forceCloseTimeoutMillis, TimeUnit.MILLISECONDS); + + closeSent.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + timeoutTask.cancel(false); + } + }); + } + + /** + * Returns a {@link WebSocketHandshakeException} that depends on which client or server pipeline + * this handler belongs. Should be overridden in implementation otherwise a default exception is used. + */ + protected WebSocketHandshakeException buildHandshakeException(String message) { + return new WebSocketHandshakeException(message); + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) + throws Exception { + ctx.disconnect(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + ctx.flush(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + ctx.fireExceptionCaught(cause); + ctx.close(); + } +} 
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketScheme.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketScheme.java new file mode 100644 index 0000000..056d6b0 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketScheme.java @@ -0,0 +1,69 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.util.AsciiString; + +/** + * Defines the common schemes used for the WebSocket protocol as defined by + * rfc6455. + */ +public final class WebSocketScheme { + /** + * Scheme for non-secure WebSocket connection. + */ + public static final WebSocketScheme WS = new WebSocketScheme(80, "ws"); + + /** + * Scheme for secure WebSocket connection. 
+ */ + public static final WebSocketScheme WSS = new WebSocketScheme(443, "wss"); + + private final int port; + private final AsciiString name; + + private WebSocketScheme(int port, String name) { + this.port = port; + this.name = AsciiString.cached(name); + } + + public AsciiString name() { + return name; + } + + public int port() { + return port; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof WebSocketScheme)) { + return false; + } + WebSocketScheme other = (WebSocketScheme) o; + return other.port() == port && other.name().equals(name); + } + + @Override + public int hashCode() { + return port * 31 + name.hashCode(); + } + + @Override + public String toString() { + return name.toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakeException.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakeException.java new file mode 100644 index 0000000..456cf19 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakeException.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.ReferenceCounted; + +/** + * Server exception during handshaking process. + * + *

IMPORTANT: This exception does not contain any {@link ReferenceCounted} fields + * e.g. {@link FullHttpRequest}, so no special treatment is needed. + */ +public final class WebSocketServerHandshakeException extends WebSocketHandshakeException { + + private static final long serialVersionUID = 1L; + + private final HttpRequest request; + + public WebSocketServerHandshakeException(String message) { + this(message, null); + } + + public WebSocketServerHandshakeException(String message, HttpRequest httpRequest) { + super(message); + if (httpRequest != null) { + request = new DefaultHttpRequest(httpRequest.protocolVersion(), httpRequest.method(), + httpRequest.uri(), httpRequest.headers()); + } else { + request = null; + } + } + + /** + * Returns a {@link HttpRequest request} if exception occurs during request validation otherwise {@code null}. + */ + public HttpRequest request() { + return request; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java new file mode 100644 index 0000000..229eac8 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker.java @@ -0,0 +1,514 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import java.nio.channels.ClosedChannelException; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.Set; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundInvoker; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContentCompressor; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * Base class for server side web socket opening and closing handshakes + */ +public abstract class WebSocketServerHandshaker { + protected static final InternalLogger logger = InternalLoggerFactory.getInstance(WebSocketServerHandshaker.class); + + private final String uri; + + private final String[] subprotocols; + + private final WebSocketVersion version; + + 
private final WebSocketDecoderConfig decoderConfig; + + private String selectedSubprotocol; + + /** + * Use this as wildcard to support all requested sub-protocols + */ + public static final String SUB_PROTOCOL_WILDCARD = "*"; + + /** + * Constructor specifying the destination web socket location + * + * @param version + * the protocol version + * @param uri + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param subprotocols + * CSV of supported protocols. Null if sub protocols not supported. + * @param maxFramePayloadLength + * Maximum length of a frame's payload + */ + protected WebSocketServerHandshaker( + WebSocketVersion version, String uri, String subprotocols, + int maxFramePayloadLength) { + this(version, uri, subprotocols, WebSocketDecoderConfig.newBuilder() + .maxFramePayloadLength(maxFramePayloadLength) + .build()); + } + + /** + * Constructor specifying the destination web socket location + * + * @param version + * the protocol version + * @param uri + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be + * sent to this URL. + * @param subprotocols + * CSV of supported protocols. Null if sub protocols not supported. + * @param decoderConfig + * Frames decoder configuration. 
+ */ + protected WebSocketServerHandshaker( + WebSocketVersion version, String uri, String subprotocols, WebSocketDecoderConfig decoderConfig) { + this.version = version; + this.uri = uri; + if (subprotocols != null) { + String[] subprotocolArray = subprotocols.split(","); + for (int i = 0; i < subprotocolArray.length; i++) { + subprotocolArray[i] = subprotocolArray[i].trim(); + } + this.subprotocols = subprotocolArray; + } else { + this.subprotocols = EmptyArrays.EMPTY_STRINGS; + } + this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig"); + } + + /** + * Returns the URL of the web socket + */ + public String uri() { + return uri; + } + + /** + * Returns the CSV of supported sub protocols + */ + public Set subprotocols() { + Set ret = new LinkedHashSet(); + Collections.addAll(ret, subprotocols); + return ret; + } + + /** + * Returns the version of the specification being supported + */ + public WebSocketVersion version() { + return version; + } + + /** + * Gets the maximum length for any frame's payload. + * + * @return The maximum length for a frame's payload + */ + public int maxFramePayloadLength() { + return decoderConfig.maxFramePayloadLength(); + } + + /** + * Gets this decoder configuration. + * + * @return This decoder configuration. + */ + public WebSocketDecoderConfig decoderConfig() { + return decoderConfig; + } + + /** + * Performs the opening handshake. When call this method you MUST NOT retain the + * {@link FullHttpRequest} which is passed in. + * + * @param channel + * Channel + * @param req + * HTTP Request + * @return future + * The {@link ChannelFuture} which is notified once the opening handshake completes + */ + public ChannelFuture handshake(Channel channel, FullHttpRequest req) { + return handshake(channel, req, null, channel.newPromise()); + } + + /** + * Performs the opening handshake + * + * When call this method you MUST NOT retain the {@link FullHttpRequest} which is passed in. 
+ * + * @param channel + * Channel + * @param req + * HTTP Request + * @param responseHeaders + * Extra headers to add to the handshake response or {@code null} if no extra headers should be added + * @param promise + * the {@link ChannelPromise} to be notified when the opening handshake is done + * @return future + * the {@link ChannelFuture} which is notified when the opening handshake is done + */ + public final ChannelFuture handshake(Channel channel, FullHttpRequest req, + HttpHeaders responseHeaders, final ChannelPromise promise) { + + if (logger.isDebugEnabled()) { + logger.debug("{} WebSocket version {} server handshake", channel, version()); + } + FullHttpResponse response = newHandshakeResponse(req, responseHeaders); + ChannelPipeline p = channel.pipeline(); + if (p.get(HttpObjectAggregator.class) != null) { + p.remove(HttpObjectAggregator.class); + } + if (p.get(HttpContentCompressor.class) != null) { + p.remove(HttpContentCompressor.class); + } + ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class); + final String encoderName; + if (ctx == null) { + // this means the user use an HttpServerCodec + ctx = p.context(HttpServerCodec.class); + if (ctx == null) { + promise.setFailure( + new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline")); + response.release(); + return promise; + } + p.addBefore(ctx.name(), "wsencoder", newWebSocketEncoder()); + p.addBefore(ctx.name(), "wsdecoder", newWebsocketDecoder()); + encoderName = ctx.name(); + } else { + p.replace(ctx.name(), "wsdecoder", newWebsocketDecoder()); + + encoderName = p.context(HttpResponseEncoder.class).name(); + p.addBefore(encoderName, "wsencoder", newWebSocketEncoder()); + } + channel.writeAndFlush(response).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (future.isSuccess()) { + ChannelPipeline p = future.channel().pipeline(); + p.remove(encoderName); + promise.setSuccess(); + 
} else { + promise.setFailure(future.cause()); + } + } + }); + return promise; + } + + /** + * Performs the opening handshake. When call this method you MUST NOT retain the + * {@link FullHttpRequest} which is passed in. + * + * @param channel + * Channel + * @param req + * HTTP Request + * @return future + * The {@link ChannelFuture} which is notified once the opening handshake completes + */ + public ChannelFuture handshake(Channel channel, HttpRequest req) { + return handshake(channel, req, null, channel.newPromise()); + } + + /** + * Performs the opening handshake + * + * When call this method you MUST NOT retain the {@link HttpRequest} which is passed in. + * + * @param channel + * Channel + * @param req + * HTTP Request + * @param responseHeaders + * Extra headers to add to the handshake response or {@code null} if no extra headers should be added + * @param promise + * the {@link ChannelPromise} to be notified when the opening handshake is done + * @return future + * the {@link ChannelFuture} which is notified when the opening handshake is done + */ + public final ChannelFuture handshake(final Channel channel, HttpRequest req, + final HttpHeaders responseHeaders, final ChannelPromise promise) { + if (req instanceof FullHttpRequest) { + return handshake(channel, (FullHttpRequest) req, responseHeaders, promise); + } + + if (logger.isDebugEnabled()) { + logger.debug("{} WebSocket version {} server handshake", channel, version()); + } + + ChannelPipeline p = channel.pipeline(); + ChannelHandlerContext ctx = p.context(HttpRequestDecoder.class); + if (ctx == null) { + // this means the user use an HttpServerCodec + ctx = p.context(HttpServerCodec.class); + if (ctx == null) { + promise.setFailure( + new IllegalStateException("No HttpDecoder and no HttpServerCodec in the pipeline")); + return promise; + } + } + + String aggregatorCtx = ctx.name(); + if (HttpUtil.isContentLengthSet(req) || HttpUtil.isTransferEncodingChunked(req) || + version == WebSocketVersion.V00) 
{ + // Add aggregator and ensure we feed the HttpRequest so it is aggregated. A limit of 8192 should be + // more then enough for the websockets handshake payload. + aggregatorCtx = "httpAggregator"; + p.addAfter(ctx.name(), aggregatorCtx, new HttpObjectAggregator(8192)); + } + + p.addAfter(aggregatorCtx, "handshaker", new ChannelInboundHandlerAdapter() { + + private FullHttpRequest fullHttpRequest; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof HttpObject) { + try { + handleHandshakeRequest(ctx, (HttpObject) msg); + } finally { + ReferenceCountUtil.release(msg); + } + } else { + super.channelRead(ctx, msg); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + // Remove ourself and fail the handshake promise. + ctx.pipeline().remove(this); + promise.tryFailure(cause); + ctx.fireExceptionCaught(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + try { + // Fail promise if Channel was closed + if (!promise.isDone()) { + promise.tryFailure(new ClosedChannelException()); + } + ctx.fireChannelInactive(); + } finally { + releaseFullHttpRequest(); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + releaseFullHttpRequest(); + } + + private void handleHandshakeRequest(ChannelHandlerContext ctx, HttpObject httpObject) { + if (httpObject instanceof FullHttpRequest) { + ctx.pipeline().remove(this); + handshake(channel, (FullHttpRequest) httpObject, responseHeaders, promise); + return; + } + + if (httpObject instanceof LastHttpContent) { + assert fullHttpRequest != null; + FullHttpRequest handshakeRequest = fullHttpRequest; + fullHttpRequest = null; + try { + ctx.pipeline().remove(this); + handshake(channel, handshakeRequest, responseHeaders, promise); + } finally { + handshakeRequest.release(); + } + return; + } + + if (httpObject instanceof 
HttpRequest) { + HttpRequest httpRequest = (HttpRequest) httpObject; + fullHttpRequest = new DefaultFullHttpRequest(httpRequest.protocolVersion(), httpRequest.method(), + httpRequest.uri(), Unpooled.EMPTY_BUFFER, httpRequest.headers(), EmptyHttpHeaders.INSTANCE); + if (httpRequest.decoderResult().isFailure()) { + fullHttpRequest.setDecoderResult(httpRequest.decoderResult()); + } + } + } + + private void releaseFullHttpRequest() { + if (fullHttpRequest != null) { + fullHttpRequest.release(); + fullHttpRequest = null; + } + } + }); + try { + ctx.fireChannelRead(ReferenceCountUtil.retain(req)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + + /** + * @return a new FullHttpResponse which will be used for as response to the handshake request. + */ + protected abstract FullHttpResponse newHandshakeResponse(FullHttpRequest req, + HttpHeaders responseHeaders); + /** + * Performs the closing handshake. + * + * When called from within a {@link ChannelHandler} you most likely want to use + * {@link #close(ChannelHandlerContext, CloseWebSocketFrame)}. + * + * @param channel + * the {@link Channel} to use. + * @param frame + * Closing Frame that was received. + */ + public ChannelFuture close(Channel channel, CloseWebSocketFrame frame) { + ObjectUtil.checkNotNull(channel, "channel"); + return close(channel, frame, channel.newPromise()); + } + + /** + * Performs the closing handshake. + * + * When called from within a {@link ChannelHandler} you most likely want to use + * {@link #close(ChannelHandlerContext, CloseWebSocketFrame, ChannelPromise)}. + * + * @param channel + * the {@link Channel} to use. + * @param frame + * Closing Frame that was received. 
+ * @param promise + * the {@link ChannelPromise} to be notified when the closing handshake is done + */ + public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) { + return close0(channel, frame, promise); + } + + /** + * Performs the closing handshake. + * + * @param ctx + * the {@link ChannelHandlerContext} to use. + * @param frame + * Closing Frame that was received. + */ + public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame) { + ObjectUtil.checkNotNull(ctx, "ctx"); + return close(ctx, frame, ctx.newPromise()); + } + + /** + * Performs the closing handshake. + * + * @param ctx + * the {@link ChannelHandlerContext} to use. + * @param frame + * Closing Frame that was received. + * @param promise + * the {@link ChannelPromise} to be notified when the closing handshake is done. + */ + public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame, ChannelPromise promise) { + ObjectUtil.checkNotNull(ctx, "ctx"); + return close0(ctx, frame, promise).addListener(ChannelFutureListener.CLOSE); + } + + private ChannelFuture close0(ChannelOutboundInvoker invoker, CloseWebSocketFrame frame, ChannelPromise promise) { + return invoker.writeAndFlush(frame, promise).addListener(ChannelFutureListener.CLOSE); + } + + /** + * Selects the first matching supported sub protocol + * + * @param requestedSubprotocols + * CSV of protocols to be supported. e.g. "chat, superchat" + * @return First matching supported sub protocol. Null if not found. 
+ */ + protected String selectSubprotocol(String requestedSubprotocols) { + if (requestedSubprotocols == null || subprotocols.length == 0) { + return null; + } + + String[] requestedSubprotocolArray = requestedSubprotocols.split(","); + for (String p: requestedSubprotocolArray) { + String requestedSubprotocol = p.trim(); + + for (String supportedSubprotocol: subprotocols) { + if (SUB_PROTOCOL_WILDCARD.equals(supportedSubprotocol) + || requestedSubprotocol.equals(supportedSubprotocol)) { + selectedSubprotocol = requestedSubprotocol; + return requestedSubprotocol; + } + } + } + + // No match found + return null; + } + + /** + * Returns the selected subprotocol. Null if no subprotocol has been selected. + *

+ * This is only available AFTER handshake() has been called. + *

+ */ + public String selectedSubprotocol() { + return selectedSubprotocol; + } + + /** + * Returns the decoder to use after handshake is complete. + */ + protected abstract WebSocketFrameDecoder newWebsocketDecoder(); + + /** + * Returns the encoder to use after the handshake is complete. + */ + protected abstract WebSocketFrameEncoder newWebSocketEncoder(); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java new file mode 100644 index 0000000..961be56 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00.java @@ -0,0 +1,245 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * <p>
 * Performs server side opening and closing handshakes for web socket specification version
 * draft-ietf-hybi-thewebsocketprotocol-00
 * </p>
 * <p>
 * A very large portion of this code was taken from the Netty 3.2 HTTP example.
 * </p>
 */
public class WebSocketServerHandshaker00 extends WebSocketServerHandshaker {

    // Removes everything that is NOT a digit / NOT a space; used to decode the
    // hixie-76 challenge keys (digits divided by the number of spaces).
    private static final Pattern BEGINNING_DIGIT = Pattern.compile("[^0-9]");
    private static final Pattern BEGINNING_SPACE = Pattern.compile("[^ ]");

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
     *            sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param maxFramePayloadLength
     *            Maximum allowable frame payload length. Setting this value to your application's requirement may
     *            reduce denial of service attacks using long data frames.
     */
    public WebSocketServerHandshaker00(String webSocketURL, String subprotocols, int maxFramePayloadLength) {
        this(webSocketURL, subprotocols, WebSocketDecoderConfig.newBuilder()
            .maxFramePayloadLength(maxFramePayloadLength)
            .build());
    }

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web socket frames will be
     *            sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param decoderConfig
     *            Frames decoder configuration.
     */
    public WebSocketServerHandshaker00(String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        super(WebSocketVersion.V00, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * <p>
     * Handle the web socket handshake for the web socket specification HyBi version 0 and lower. This standard
     * is really a rehash of hixie-76 and
     * hixie-75.
     * </p>
     *
     * <p>
     * Browser request to the server:
     * </p>
     *
     * <pre>
     * GET /demo HTTP/1.1
     * Upgrade: WebSocket
     * Connection: Upgrade
     * Host: example.com
     * Origin: http://example.com
     * Sec-WebSocket-Protocol: chat, sample
     * Sec-WebSocket-Key1: 4 @1  46546xW%0l 1 5
     * Sec-WebSocket-Key2: 12998 5 Y3 1  .P00
     *
     * ^n:ds[4U
     * </pre>
     *
     * <p>
     * Server response:
     * </p>
     *
     * <pre>
     * HTTP/1.1 101 WebSocket Protocol Handshake
     * Upgrade: WebSocket
     * Connection: Upgrade
     * Sec-WebSocket-Origin: http://example.com
     * Sec-WebSocket-Location: ws://example.com/demo
     * Sec-WebSocket-Protocol: sample
     *
     * 8jKS'y:G*Co,Wxa-
     * </pre>
     */
    @Override
    protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
        HttpMethod method = req.method();
        if (!GET.equals(method)) {
            throw new WebSocketServerHandshakeException("Invalid WebSocket handshake method: " + method, req);
        }

        // Serve the WebSocket handshake request.
        if (!req.headers().containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)
                || !HttpHeaderValues.WEBSOCKET.contentEqualsIgnoreCase(req.headers().get(HttpHeaderNames.UPGRADE))) {
            throw new WebSocketServerHandshakeException("not a WebSocket handshake request: missing upgrade", req);
        }

        // Hixie 75 does not contain these headers while Hixie 76 does
        boolean isHixie76 = req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY1) &&
                            req.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_KEY2);

        String origin = req.headers().get(HttpHeaderNames.ORIGIN);
        //throw before allocating FullHttpResponse
        if (origin == null && !isHixie76) {
            throw new WebSocketServerHandshakeException("Missing origin header, got only " + req.headers().names(),
                                                        req);
        }

        // Create the WebSocket handshake response.
        FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, new HttpResponseStatus(101,
                isHixie76 ? "WebSocket Protocol Handshake" : "Web Socket Protocol Handshake"),
                req.content().alloc().buffer(0));
        if (headers != null) {
            res.headers().add(headers);
        }

        res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
           .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE);

        // Fill in the headers and contents depending on handshake method.
        if (isHixie76) {
            // New handshake method with a challenge:
            res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin);
            res.headers().add(HttpHeaderNames.SEC_WEBSOCKET_LOCATION, uri());

            String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
            if (subprotocols != null) {
                String selectedSubprotocol = selectSubprotocol(subprotocols);
                if (selectedSubprotocol == null) {
                    if (logger.isDebugEnabled()) {
                        logger.debug("Requested subprotocol(s) not supported: {}", subprotocols);
                    }
                } else {
                    res.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol);
                }
            }

            // Calculate the answer of the challenge.
            // Per hixie-76: each key is decoded by concatenating its digits and dividing
            // by the number of spaces it contains. NOTE(review): a malformed key (no digits
            // or no spaces) throws NumberFormatException/ArithmeticException here — same
            // behavior as upstream; confirm callers treat that as a failed handshake.
            String key1 = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY1);
            String key2 = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY2);
            int a = (int) (Long.parseLong(BEGINNING_DIGIT.matcher(key1).replaceAll("")) /
                           BEGINNING_SPACE.matcher(key1).replaceAll("").length());
            int b = (int) (Long.parseLong(BEGINNING_DIGIT.matcher(key2).replaceAll("")) /
                           BEGINNING_SPACE.matcher(key2).replaceAll("").length());
            // The 8-byte challenge token follows the headers in the request body.
            long c = req.content().readLong();
            // 16-byte buffer: key1 answer (4) + key2 answer (4) + challenge token (8),
            // MD5'd to produce the 16-byte response body.
            ByteBuf input = Unpooled.wrappedBuffer(new byte[16]).setIndex(0, 0);
            input.writeInt(a);
            input.writeInt(b);
            input.writeLong(c);
            res.content().writeBytes(WebSocketUtil.md5(input.array()));
        } else {
            // Old Hixie 75 handshake method with no challenge:
            res.headers().add(HttpHeaderNames.WEBSOCKET_ORIGIN, origin);
            res.headers().add(HttpHeaderNames.WEBSOCKET_LOCATION, uri());

            String protocol = req.headers().get(HttpHeaderNames.WEBSOCKET_PROTOCOL);
            if (protocol != null) {
                res.headers().set(HttpHeaderNames.WEBSOCKET_PROTOCOL, selectSubprotocol(protocol));
            }
        }
        return res;
    }

    /**
     * Echo back the closing frame
     *
     * @param channel
     *            the {@link Channel} to use.
     * @param frame
     *            Web Socket frame that was received.
     * @param promise
     *            the {@link ChannelPromise} to be notified when the closing handshake is done.
     */
    @Override
    public ChannelFuture close(Channel channel, CloseWebSocketFrame frame, ChannelPromise promise) {
        return channel.writeAndFlush(frame, promise);
    }

    /**
     * Echo back the closing frame
     *
     * @param ctx
     *            the {@link ChannelHandlerContext} to use.
     * @param frame
     *            Closing Frame that was received.
     * @param promise
     *            the {@link ChannelPromise} to be notified when the closing handshake is done.
     */
    @Override
    public ChannelFuture close(ChannelHandlerContext ctx, CloseWebSocketFrame frame,
                               ChannelPromise promise) {
        return ctx.writeAndFlush(frame, promise);
    }

    @Override
    protected WebSocketFrameDecoder newWebsocketDecoder() {
        return new WebSocket00FrameDecoder(decoderConfig());
    }

    @Override
    protected WebSocketFrameEncoder newWebSocketEncoder() {
        return new WebSocket00FrameEncoder();
    }
}
 */
package io.netty.handler.codec.http.websocketx;

import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.CharsetUtil;

import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpVersion.*;

/**
 * <p>
 * Performs server side opening and closing handshakes for web socket specification version
 * draft-ietf-hybi-thewebsocketprotocol-07 (wire protocol version 7).
 * </p>
 */
public class WebSocketServerHandshaker07 extends WebSocketServerHandshaker {

    // GUID fixed by the HyBi drafts / RFC 6455 for computing Sec-WebSocket-Accept.
    public static final String WEBSOCKET_07_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param allowExtensions
     *            Allow extensions to be used in the reserved bits of the web socket frame
     * @param maxFramePayloadLength
     *            Maximum allowable frame payload length. Setting this value to your application's
     *            requirement may reduce denial of service attacks using long data frames.
     */
    public WebSocketServerHandshaker07(
            String webSocketURL, String subprotocols, boolean allowExtensions, int maxFramePayloadLength) {
        this(webSocketURL, subprotocols, allowExtensions, maxFramePayloadLength, false);
    }

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param allowExtensions
     *            Allow extensions to be used in the reserved bits of the web socket frame
     * @param maxFramePayloadLength
     *            Maximum allowable frame payload length. Setting this value to your application's
     *            requirement may reduce denial of service attacks using long data frames.
     * @param allowMaskMismatch
     *            When set to true, frames which are not masked properly according to the standard will still be
     *            accepted.
     */
    public WebSocketServerHandshaker07(
            String webSocketURL, String subprotocols, boolean allowExtensions, int maxFramePayloadLength,
            boolean allowMaskMismatch) {
        this(webSocketURL, subprotocols, WebSocketDecoderConfig.newBuilder()
            .allowExtensions(allowExtensions)
            .maxFramePayloadLength(maxFramePayloadLength)
            .allowMaskMismatch(allowMaskMismatch)
            .build());
    }

    /**
     * Constructor specifying the destination web socket location
     *
     * @param decoderConfig
     *            Frames decoder configuration.
     */
    public WebSocketServerHandshaker07(String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        super(WebSocketVersion.V07, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * <p>
     * Handle the web socket handshake for the web socket specification HyBi version 7.
     * </p>
     *
     * <p>
     * Browser request to the server:
     * </p>
     *
     * <pre>
     * GET /chat HTTP/1.1
     * Host: server.example.com
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
     * Sec-WebSocket-Origin: http://example.com
     * Sec-WebSocket-Protocol: chat, superchat
     * Sec-WebSocket-Version: 7
     * </pre>
     *
     * <p>
     * Server response:
     * </p>
     *
     * <pre>
     * HTTP/1.1 101 Switching Protocols
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
     * Sec-WebSocket-Protocol: chat
     * </pre>
     */
    @Override
    protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
        // Only GET may initiate a WebSocket upgrade.
        HttpMethod method = req.method();
        if (!GET.equals(method)) {
            throw new WebSocketServerHandshakeException("Invalid WebSocket handshake method: " + method, req);
        }

        // Validate before allocating the response so nothing leaks on failure.
        CharSequence key = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
        if (key == null) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
        }

        FullHttpResponse res =
                new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS,
                req.content().alloc().buffer(0));

        if (headers != null) {
            res.headers().add(headers);
        }

        // Sec-WebSocket-Accept = base64(SHA-1(key + fixed GUID)), per the HyBi drafts.
        String acceptSeed = key + WEBSOCKET_07_ACCEPT_GUID;
        byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII));
        String accept = WebSocketUtil.base64(sha1);

        if (logger.isDebugEnabled()) {
            logger.debug("WebSocket version 07 server handshake key: {}, response: {}.", key, accept);
        }

        res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
           .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
           .set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept);

        String subprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        if (subprotocols != null) {
            String selectedSubprotocol = selectSubprotocol(subprotocols);
            if (selectedSubprotocol == null) {
                // Unsupported subprotocol: proceed without the header rather than failing.
                if (logger.isDebugEnabled()) {
                    logger.debug("Requested subprotocol(s) not supported: {}", subprotocols);
                }
            } else {
                res.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol);
            }
        }
        return res;
    }

    /**
     * Returns the frame decoder for the version 07 wire format.
     */
    @Override
    protected WebSocketFrameDecoder newWebsocketDecoder() {
        return new WebSocket07FrameDecoder(decoderConfig());
    }

    /**
     * Returns the frame encoder for the version 07 wire format (server side: unmasked).
     */
    @Override
    protected WebSocketFrameEncoder newWebSocketEncoder() {
        return new WebSocket07FrameEncoder(false);
    }
}
diff --git
a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java
new file mode 100644
index 0000000..3c2a9bd
--- /dev/null
+++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08.java
@@ -0,0 +1,192 @@
/*
 * Copyright 2019 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx;

import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.CharsetUtil;

import static io.netty.handler.codec.http.HttpMethod.GET;
import static io.netty.handler.codec.http.HttpVersion.*;

/**
 * <p>
 * Performs server side opening and closing handshakes for web socket specification version
 * draft-ietf-hybi-thewebsocketprotocol-10 (wire protocol version 8).
 * </p>
 */
public class WebSocketServerHandshaker08 extends WebSocketServerHandshaker {

    // GUID fixed by the HyBi drafts / RFC 6455 for computing Sec-WebSocket-Accept.
    public static final String WEBSOCKET_08_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11";

    /**
     * Creates a handshaker for the given destination web socket location.
     *
     * @param webSocketURL
     *            URL for web socket communications, e.g. "ws://myhost.com/mypath";
     *            subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param allowExtensions
     *            whether extensions may use the reserved bits of the web socket frame
     * @param maxFramePayloadLength
     *            maximum allowable frame payload length; bounding this to your application's
     *            needs may reduce denial-of-service attacks using long data frames.
     */
    public WebSocketServerHandshaker08(
            String webSocketURL, String subprotocols, boolean allowExtensions, int maxFramePayloadLength) {
        this(webSocketURL, subprotocols, allowExtensions, maxFramePayloadLength, false);
    }

    /**
     * Creates a handshaker for the given destination web socket location.
     *
     * @param webSocketURL
     *            URL for web socket communications, e.g. "ws://myhost.com/mypath";
     *            subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param allowExtensions
     *            whether extensions may use the reserved bits of the web socket frame
     * @param maxFramePayloadLength
     *            maximum allowable frame payload length; bounding this to your application's
     *            needs may reduce denial-of-service attacks using long data frames.
     * @param allowMaskMismatch
     *            when {@code true}, frames which are not masked properly according to the
     *            standard will still be accepted.
     */
    public WebSocketServerHandshaker08(
            String webSocketURL, String subprotocols, boolean allowExtensions, int maxFramePayloadLength,
            boolean allowMaskMismatch) {
        this(webSocketURL, subprotocols, WebSocketDecoderConfig.newBuilder()
            .allowExtensions(allowExtensions)
            .maxFramePayloadLength(maxFramePayloadLength)
            .allowMaskMismatch(allowMaskMismatch)
            .build());
    }

    /**
     * Creates a handshaker for the given destination web socket location.
     *
     * @param webSocketURL
     *            URL for web socket communications, e.g. "ws://myhost.com/mypath";
     *            subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols
     * @param decoderConfig
     *            frames decoder configuration
     */
    public WebSocketServerHandshaker08(
            String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        super(WebSocketVersion.V08, webSocketURL, subprotocols, decoderConfig);
    }

    /**
     * <p>
     * Handles the web socket handshake for the web socket specification HyBi versions 8 to 10;
     * versions 8, 9 and 10 share the same wire protocol.
     * </p>
     *
     * <p>
     * Browser request to the server:
     * </p>
     *
     * <pre>
     * GET /chat HTTP/1.1
     * Host: server.example.com
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
     * Sec-WebSocket-Origin: http://example.com
     * Sec-WebSocket-Protocol: chat, superchat
     * Sec-WebSocket-Version: 8
     * </pre>
     *
     * <p>
     * Server response:
     * </p>
     *
     * <pre>
     * HTTP/1.1 101 Switching Protocols
     * Upgrade: websocket
     * Connection: Upgrade
     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
     * Sec-WebSocket-Protocol: chat
     * </pre>
     */
    @Override
    protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) {
        // Guard clauses: only a GET carrying a Sec-WebSocket-Key can be upgraded.
        HttpMethod method = req.method();
        if (!GET.equals(method)) {
            throw new WebSocketServerHandshakeException("Invalid WebSocket handshake method: " + method, req);
        }
        CharSequence nonce = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY);
        if (nonce == null) {
            throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req);
        }

        // Sec-WebSocket-Accept = base64(SHA-1(key + fixed GUID)), per the HyBi drafts.
        String challenge = nonce + WEBSOCKET_08_ACCEPT_GUID;
        byte[] digest = WebSocketUtil.sha1(challenge.getBytes(CharsetUtil.US_ASCII));
        String acceptValue = WebSocketUtil.base64(digest);
        if (logger.isDebugEnabled()) {
            logger.debug("WebSocket version 08 server handshake key: {}, response: {}", nonce, acceptValue);
        }

        FullHttpResponse response = new DefaultFullHttpResponse(
                HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS, req.content().alloc().buffer(0));
        if (headers != null) {
            response.headers().add(headers);
        }
        response.headers()
                .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
                .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
                .set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, acceptValue);

        // Negotiate an application subprotocol if the client asked for one.
        String requestedSubprotocols = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL);
        if (requestedSubprotocols != null) {
            String chosen = selectSubprotocol(requestedSubprotocols);
            if (chosen != null) {
                response.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, chosen);
            } else if (logger.isDebugEnabled()) {
                // Unsupported subprotocol: proceed without the header rather than failing.
                logger.debug("Requested subprotocol(s) not supported: {}", requestedSubprotocols);
            }
        }
        return response;
    }

    /**
     * Returns the frame decoder for the version 08 wire format.
     */
    @Override
    protected WebSocketFrameDecoder newWebsocketDecoder() {
        return new WebSocket08FrameDecoder(decoderConfig());
    }

    /**
     * Returns the frame encoder for the version 08 wire format (server side: unmasked).
     */
    @Override
    protected WebSocketFrameEncoder newWebSocketEncoder() {
        return new WebSocket08FrameEncoder(false);
    }
}
diff --git
a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java new file mode 100644 index 0000000..153bc64 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13.java @@ -0,0 +1,202 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.util.CharsetUtil; + +import static io.netty.handler.codec.http.HttpMethod.GET; +import static io.netty.handler.codec.http.HttpVersion.*; + +/** + *

+ * Performs server side opening and closing handshakes for RFC 6455 + * (originally web socket specification draft-ietf-hybi-thewebsocketprotocol-17). + *

+ */ +public class WebSocketServerHandshaker13 extends WebSocketServerHandshaker { + + public static final String WEBSOCKET_13_ACCEPT_GUID = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + + /** + * Constructor specifying the destination web socket location + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web + * socket frames will be sent to this URL. + * @param subprotocols + * CSV of supported protocols + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + */ + public WebSocketServerHandshaker13( + String webSocketURL, String subprotocols, boolean allowExtensions, int maxFramePayloadLength) { + this(webSocketURL, subprotocols, allowExtensions, maxFramePayloadLength, false); + } + + /** + * Constructor specifying the destination web socket location + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web + * socket frames will be sent to this URL. + * @param subprotocols + * CSV of supported protocols + * @param allowExtensions + * Allow extensions to be used in the reserved bits of the web socket frame + * @param maxFramePayloadLength + * Maximum allowable frame payload length. Setting this value to your application's + * requirement may reduce denial of service attacks using long data frames. + * @param allowMaskMismatch + * When set to true, frames which are not masked properly according to the standard will still be + * accepted. 
+ */ + public WebSocketServerHandshaker13( + String webSocketURL, String subprotocols, boolean allowExtensions, int maxFramePayloadLength, + boolean allowMaskMismatch) { + this(webSocketURL, subprotocols, WebSocketDecoderConfig.newBuilder() + .allowExtensions(allowExtensions) + .maxFramePayloadLength(maxFramePayloadLength) + .allowMaskMismatch(allowMaskMismatch) + .build()); + } + + /** + * Constructor specifying the destination web socket location + * + * @param webSocketURL + * URL for web socket communications. e.g "ws://myhost.com/mypath". Subsequent web + * socket frames will be sent to this URL. + * @param subprotocols + * CSV of supported protocols + * @param decoderConfig + * Frames decoder configuration. + */ + public WebSocketServerHandshaker13( + String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) { + super(WebSocketVersion.V13, webSocketURL, subprotocols, decoderConfig); + } + + /** + *

+ * Handle the web socket handshake for the web socket specification HyBi versions 13-17. Versions 13-17 + * share the same wire protocol. + *

+ * + *

+ * Browser request to the server: + *

+ * + *
+     * GET /chat HTTP/1.1
+     * Host: server.example.com
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Key: dGhlIHNhbXBsZSBub25jZQ==
+     * Origin: http://example.com
+     * Sec-WebSocket-Protocol: chat, superchat
+     * Sec-WebSocket-Version: 13
+     * 
+ * + *

+ * Server response: + *

+ * + *
+     * HTTP/1.1 101 Switching Protocols
+     * Upgrade: websocket
+     * Connection: Upgrade
+     * Sec-WebSocket-Accept: s3pPLMBiTxaQ9kYGzzhZRbK+xOo=
+     * Sec-WebSocket-Protocol: chat
+     * 
+ */ + @Override + protected FullHttpResponse newHandshakeResponse(FullHttpRequest req, HttpHeaders headers) { + HttpMethod method = req.method(); + if (!GET.equals(method)) { + throw new WebSocketServerHandshakeException("Invalid WebSocket handshake method: " + method, req); + } + + HttpHeaders reqHeaders = req.headers(); + if (!reqHeaders.contains(HttpHeaderNames.CONNECTION) || + !reqHeaders.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true)) { + throw new WebSocketServerHandshakeException( + "not a WebSocket request: a |Connection| header must includes a token 'Upgrade'", req); + } + + if (!reqHeaders.contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true)) { + throw new WebSocketServerHandshakeException( + "not a WebSocket request: a |Upgrade| header must containing the value 'websocket'", req); + } + + CharSequence key = reqHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_KEY); + if (key == null) { + throw new WebSocketServerHandshakeException("not a WebSocket request: missing key", req); + } + + FullHttpResponse res = new DefaultFullHttpResponse(HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS, + req.content().alloc().buffer(0)); + if (headers != null) { + res.headers().add(headers); + } + + String acceptSeed = key + WEBSOCKET_13_ACCEPT_GUID; + byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); + String accept = WebSocketUtil.base64(sha1); + + if (logger.isDebugEnabled()) { + logger.debug("WebSocket version 13 server handshake key: {}, response: {}", key, accept); + } + + res.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); + + String subprotocols = reqHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL); + if (subprotocols != null) { + String selectedSubprotocol = selectSubprotocol(subprotocols); + if (selectedSubprotocol == null) { + if (logger.isDebugEnabled()) { + 
logger.debug("Requested subprotocol(s) not supported: {}", subprotocols); + } + } else { + res.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, selectedSubprotocol); + } + } + return res; + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return new WebSocket13FrameDecoder(decoderConfig()); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return new WebSocket13FrameEncoder(false); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactory.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactory.java new file mode 100644 index 0000000..266622d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactory.java @@ -0,0 +1,180 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.http.websocketx;

import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelPromise;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.internal.ObjectUtil;

/**
 * Auto-detects the version of the Web Socket protocol in use and creates a new proper
 * {@link WebSocketServerHandshaker}.
 */
public class WebSocketServerHandshakerFactory {

    // Destination URL the handshakers are created for.
    private final String webSocketURL;

    // CSV of supported subprotocols; null if subprotocols are not supported.
    private final String subprotocols;

    // Frame decoder configuration shared by all created handshakers; never null.
    private final WebSocketDecoderConfig decoderConfig;

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols. Null if sub protocols not supported.
     * @param allowExtensions
     *            Allow extensions to be used in the reserved bits of the web socket frame
     */
    public WebSocketServerHandshakerFactory(
            String webSocketURL, String subprotocols, boolean allowExtensions) {
        // 65536 is the historical default maximum frame payload length.
        this(webSocketURL, subprotocols, allowExtensions, 65536);
    }

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols. Null if sub protocols not supported.
     * @param allowExtensions
     *            Allow extensions to be used in the reserved bits of the web socket frame
     * @param maxFramePayloadLength
     *            Maximum allowable frame payload length. Setting this value to your application's
     *            requirement may reduce denial of service attacks using long data frames.
     */
    public WebSocketServerHandshakerFactory(
            String webSocketURL, String subprotocols, boolean allowExtensions,
            int maxFramePayloadLength) {
        this(webSocketURL, subprotocols, allowExtensions, maxFramePayloadLength, false);
    }

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols. Null if sub protocols not supported.
     * @param allowExtensions
     *            Allow extensions to be used in the reserved bits of the web socket frame
     * @param maxFramePayloadLength
     *            Maximum allowable frame payload length. Setting this value to your application's
     *            requirement may reduce denial of service attacks using long data frames.
     * @param allowMaskMismatch
     *            When set to true, frames which are not masked properly according to the standard will still be
     *            accepted.
     */
    public WebSocketServerHandshakerFactory(
            String webSocketURL, String subprotocols, boolean allowExtensions,
            int maxFramePayloadLength, boolean allowMaskMismatch) {
        this(webSocketURL, subprotocols, WebSocketDecoderConfig.newBuilder()
            .allowExtensions(allowExtensions)
            .maxFramePayloadLength(maxFramePayloadLength)
            .allowMaskMismatch(allowMaskMismatch)
            .build());
    }

    /**
     * Constructor specifying the destination web socket location
     *
     * @param webSocketURL
     *            URL for web socket communications. e.g "ws://myhost.com/mypath".
     *            Subsequent web socket frames will be sent to this URL.
     * @param subprotocols
     *            CSV of supported protocols. Null if sub protocols not supported.
     * @param decoderConfig
     *            Frames decoder options.
     */
    public WebSocketServerHandshakerFactory(
            String webSocketURL, String subprotocols, WebSocketDecoderConfig decoderConfig) {
        this.webSocketURL = webSocketURL;
        this.subprotocols = subprotocols;
        this.decoderConfig = ObjectUtil.checkNotNull(decoderConfig, "decoderConfig");
    }

    /**
     * Creates a new handshaker matching the Sec-WebSocket-Version header of the request.
     *
     * @return A new WebSocketServerHandshaker for the requested web socket version. Null if web
     *         socket version is not supported.
     */
    public WebSocketServerHandshaker newHandshaker(HttpRequest req) {

        CharSequence version = req.headers().get(HttpHeaderNames.SEC_WEBSOCKET_VERSION);
        if (version != null) {
            if (version.equals(WebSocketVersion.V13.toHttpHeaderValue())) {
                // Version 13 of the wire protocol - RFC 6455 (version 17 of the draft hybi specification).
                return new WebSocketServerHandshaker13(
                        webSocketURL, subprotocols, decoderConfig);
            } else if (version.equals(WebSocketVersion.V08.toHttpHeaderValue())) {
                // Version 8 of the wire protocol - version 10 of the draft hybi specification.
                return new WebSocketServerHandshaker08(
                        webSocketURL, subprotocols, decoderConfig);
            } else if (version.equals(WebSocketVersion.V07.toHttpHeaderValue())) {
                // Version 7 of the wire protocol - version 07 of the draft hybi specification.
                return new WebSocketServerHandshaker07(
                        webSocketURL, subprotocols, decoderConfig);
            } else {
                // Unknown/unsupported version; caller should send an "upgrade required" response.
                return null;
            }
        } else {
            // Assume version 00 where version header was not specified
            return new WebSocketServerHandshaker00(webSocketURL, subprotocols, decoderConfig);
        }
    }

    /**
     * @deprecated use {@link #sendUnsupportedVersionResponse(Channel)}
     */
    @Deprecated
    public static void sendUnsupportedWebSocketVersionResponse(Channel channel) {
        sendUnsupportedVersionResponse(channel);
    }

    /**
     * Return that we cannot support the web socket version
     */
    public static ChannelFuture sendUnsupportedVersionResponse(Channel channel) {
        return sendUnsupportedVersionResponse(channel, channel.newPromise());
    }

    /**
     * Return that we cannot support the web socket version
     */
    public static ChannelFuture sendUnsupportedVersionResponse(Channel channel, ChannelPromise promise) {
        // 426 Upgrade Required plus the highest version we do support, per RFC 6455 section 4.4.
        HttpResponse res = new DefaultFullHttpResponse(
                HttpVersion.HTTP_1_1,
                HttpResponseStatus.UPGRADE_REQUIRED, channel.alloc().buffer(0));
        res.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, WebSocketVersion.V13.toHttpHeaderValue());
        HttpUtil.setContentLength(res, 0);
        return channel.writeAndFlush(res, promise);
    }
}
diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig.java
new file mode 100644
index 0000000..97954f8
--- /dev/null
+++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolConfig.java
@@ -0,0 +1,296 @@
/*
 * Copyright 2019 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent; +import io.netty.util.internal.ObjectUtil; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * WebSocket server configuration. + */ +public final class WebSocketServerProtocolConfig { + + static final long DEFAULT_HANDSHAKE_TIMEOUT_MILLIS = 10000L; + + private final String websocketPath; + private final String subprotocols; + private final boolean checkStartsWith; + private final long handshakeTimeoutMillis; + private final long forceCloseTimeoutMillis; + private final boolean handleCloseFrames; + private final WebSocketCloseStatus sendCloseFrame; + private final boolean dropPongFrames; + private final WebSocketDecoderConfig decoderConfig; + + private WebSocketServerProtocolConfig( + String websocketPath, + String subprotocols, + boolean checkStartsWith, + long handshakeTimeoutMillis, + long forceCloseTimeoutMillis, + boolean handleCloseFrames, + WebSocketCloseStatus sendCloseFrame, + boolean dropPongFrames, + WebSocketDecoderConfig decoderConfig + ) { + this.websocketPath = websocketPath; + this.subprotocols = subprotocols; + this.checkStartsWith = checkStartsWith; + this.handshakeTimeoutMillis = checkPositive(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; + this.handleCloseFrames = handleCloseFrames; + this.sendCloseFrame = sendCloseFrame; + this.dropPongFrames = dropPongFrames; + this.decoderConfig = 
decoderConfig == null ? WebSocketDecoderConfig.DEFAULT : decoderConfig; + } + + public String websocketPath() { + return websocketPath; + } + + public String subprotocols() { + return subprotocols; + } + + public boolean checkStartsWith() { + return checkStartsWith; + } + + public long handshakeTimeoutMillis() { + return handshakeTimeoutMillis; + } + + public long forceCloseTimeoutMillis() { + return forceCloseTimeoutMillis; + } + + public boolean handleCloseFrames() { + return handleCloseFrames; + } + + public WebSocketCloseStatus sendCloseFrame() { + return sendCloseFrame; + } + + public boolean dropPongFrames() { + return dropPongFrames; + } + + public WebSocketDecoderConfig decoderConfig() { + return decoderConfig; + } + + @Override + public String toString() { + return "WebSocketServerProtocolConfig" + + " {websocketPath=" + websocketPath + + ", subprotocols=" + subprotocols + + ", checkStartsWith=" + checkStartsWith + + ", handshakeTimeoutMillis=" + handshakeTimeoutMillis + + ", forceCloseTimeoutMillis=" + forceCloseTimeoutMillis + + ", handleCloseFrames=" + handleCloseFrames + + ", sendCloseFrame=" + sendCloseFrame + + ", dropPongFrames=" + dropPongFrames + + ", decoderConfig=" + decoderConfig + + "}"; + } + + public Builder toBuilder() { + return new Builder(this); + } + + public static Builder newBuilder() { + return new Builder("/", null, false, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS, 0L, + true, WebSocketCloseStatus.NORMAL_CLOSURE, true, WebSocketDecoderConfig.DEFAULT); + } + + public static final class Builder { + private String websocketPath; + private String subprotocols; + private boolean checkStartsWith; + private long handshakeTimeoutMillis; + private long forceCloseTimeoutMillis; + private boolean handleCloseFrames; + private WebSocketCloseStatus sendCloseFrame; + private boolean dropPongFrames; + private WebSocketDecoderConfig decoderConfig; + private WebSocketDecoderConfig.Builder decoderConfigBuilder; + + private Builder(WebSocketServerProtocolConfig 
serverConfig) { + this(ObjectUtil.checkNotNull(serverConfig, "serverConfig").websocketPath(), + serverConfig.subprotocols(), + serverConfig.checkStartsWith(), + serverConfig.handshakeTimeoutMillis(), + serverConfig.forceCloseTimeoutMillis(), + serverConfig.handleCloseFrames(), + serverConfig.sendCloseFrame(), + serverConfig.dropPongFrames(), + serverConfig.decoderConfig() + ); + } + + private Builder(String websocketPath, + String subprotocols, + boolean checkStartsWith, + long handshakeTimeoutMillis, + long forceCloseTimeoutMillis, + boolean handleCloseFrames, + WebSocketCloseStatus sendCloseFrame, + boolean dropPongFrames, + WebSocketDecoderConfig decoderConfig) { + this.websocketPath = websocketPath; + this.subprotocols = subprotocols; + this.checkStartsWith = checkStartsWith; + this.handshakeTimeoutMillis = handshakeTimeoutMillis; + this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; + this.handleCloseFrames = handleCloseFrames; + this.sendCloseFrame = sendCloseFrame; + this.dropPongFrames = dropPongFrames; + this.decoderConfig = decoderConfig; + } + + /** + * URI path component to handle websocket upgrade requests on. + */ + public Builder websocketPath(String websocketPath) { + this.websocketPath = websocketPath; + return this; + } + + /** + * CSV of supported protocols + */ + public Builder subprotocols(String subprotocols) { + this.subprotocols = subprotocols; + return this; + } + + /** + * {@code true} to handle all requests, where URI path component starts from + * {@link WebSocketServerProtocolConfig#websocketPath()}, {@code false} for exact match (default). 
+ */ + public Builder checkStartsWith(boolean checkStartsWith) { + this.checkStartsWith = checkStartsWith; + return this; + } + + /** + * Handshake timeout in millis; when the handshake times out, the user + * event {@link ClientHandshakeStateEvent#HANDSHAKE_TIMEOUT} is triggered + */ + public Builder handshakeTimeoutMillis(long handshakeTimeoutMillis) { + this.handshakeTimeoutMillis = handshakeTimeoutMillis; + return this; + } + + /** + * Close the connection if it was not closed by the client after timeout specified + */ + public Builder forceCloseTimeoutMillis(long forceCloseTimeoutMillis) { + this.forceCloseTimeoutMillis = forceCloseTimeoutMillis; + return this; + } + + /** + * {@code true} if close frames should not be forwarded and just close the channel + */ + public Builder handleCloseFrames(boolean handleCloseFrames) { + this.handleCloseFrames = handleCloseFrames; + return this; + } + + /** + * Close frame to send, when close frame was not sent manually. Or {@code null} to disable proper close. + */ + public Builder sendCloseFrame(WebSocketCloseStatus sendCloseFrame) { + this.sendCloseFrame = sendCloseFrame; + return this; + } + + /** + * {@code true} if pong frames should not be forwarded + */ + public Builder dropPongFrames(boolean dropPongFrames) { + this.dropPongFrames = dropPongFrames; + return this; + } + + /** + * Frames decoder configuration. + */ + public Builder decoderConfig(WebSocketDecoderConfig decoderConfig) { + this.decoderConfig = decoderConfig == null ?
WebSocketDecoderConfig.DEFAULT : decoderConfig; + this.decoderConfigBuilder = null; + return this; + } + + private WebSocketDecoderConfig.Builder decoderConfigBuilder() { + if (decoderConfigBuilder == null) { + decoderConfigBuilder = decoderConfig.toBuilder(); + } + return decoderConfigBuilder; + } + + public Builder maxFramePayloadLength(int maxFramePayloadLength) { + decoderConfigBuilder().maxFramePayloadLength(maxFramePayloadLength); + return this; + } + + public Builder expectMaskedFrames(boolean expectMaskedFrames) { + decoderConfigBuilder().expectMaskedFrames(expectMaskedFrames); + return this; + } + + public Builder allowMaskMismatch(boolean allowMaskMismatch) { + decoderConfigBuilder().allowMaskMismatch(allowMaskMismatch); + return this; + } + + public Builder allowExtensions(boolean allowExtensions) { + decoderConfigBuilder().allowExtensions(allowExtensions); + return this; + } + + public Builder closeOnProtocolViolation(boolean closeOnProtocolViolation) { + decoderConfigBuilder().closeOnProtocolViolation(closeOnProtocolViolation); + return this; + } + + public Builder withUTF8Validator(boolean withUTF8Validator) { + decoderConfigBuilder().withUTF8Validator(withUTF8Validator); + return this; + } + + /** + * Build unmodifiable server protocol configuration. + */ + public WebSocketServerProtocolConfig build() { + return new WebSocketServerProtocolConfig( + websocketPath, + subprotocols, + checkStartsWith, + handshakeTimeoutMillis, + forceCloseTimeoutMillis, + handleCloseFrames, + sendCloseFrame, + dropPongFrames, + decoderConfigBuilder == null ? 
decoderConfig : decoderConfigBuilder.build() + ); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java new file mode 100644 index 0000000..d872880 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandler.java @@ -0,0 +1,276 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.util.AttributeKey; + +import java.util.List; + +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static io.netty.handler.codec.http.websocketx.WebSocketServerProtocolConfig.DEFAULT_HANDSHAKE_TIMEOUT_MILLIS; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * This handler does all the heavy lifting for you to run a websocket server. + * + * It takes care of websocket handshaking as well as processing of control frames (Close, Ping, Pong). Text and Binary + * data frames are passed to the next handler in the pipeline (implemented by you) for processing. + * + * See io.netty.example.http.websocketx.html5.WebSocketServer for usage. + * + * The implementation of this handler assumes that you just want to run a websocket server and not process other types + * HTTP requests (like GET and POST). If you wish to support both HTTP requests and websockets in the one server, refer + * to the io.netty.example.http.websocketx.server.WebSocketServer example. + * + * To know once a handshake was done you can intercept the + * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} and check if the event was instance + * of {@link HandshakeComplete}, the event will contain extra information about the handshake such as the request and + * selected subprotocol. 
+ */ +public class WebSocketServerProtocolHandler extends WebSocketProtocolHandler { + + /** + * Events that are fired to notify about handshake status + */ + public enum ServerHandshakeStateEvent { + /** + * The Handshake was completed successfully and the channel was upgraded to websockets. + * + * @deprecated in favor of {@link HandshakeComplete} class, + * it provides extra information about the handshake + */ + @Deprecated + HANDSHAKE_COMPLETE, + + /** + * The Handshake was timed out + */ + HANDSHAKE_TIMEOUT + } + + /** + * The Handshake was completed successfully and the channel was upgraded to websockets. + */ + public static final class HandshakeComplete { + private final String requestUri; + private final HttpHeaders requestHeaders; + private final String selectedSubprotocol; + + public HandshakeComplete(String requestUri, HttpHeaders requestHeaders, String selectedSubprotocol) { + this.requestUri = requestUri; + this.requestHeaders = requestHeaders; + this.selectedSubprotocol = selectedSubprotocol; + } + + public String requestUri() { + return requestUri; + } + + public HttpHeaders requestHeaders() { + return requestHeaders; + } + + public String selectedSubprotocol() { + return selectedSubprotocol; + } + } + + private static final AttributeKey HANDSHAKER_ATTR_KEY = + AttributeKey.valueOf(WebSocketServerHandshaker.class, "HANDSHAKER"); + + private final WebSocketServerProtocolConfig serverConfig; + + /** + * Base constructor + * + * @param serverConfig + * Server protocol configuration. 
+ */ + public WebSocketServerProtocolHandler(WebSocketServerProtocolConfig serverConfig) { + super(checkNotNull(serverConfig, "serverConfig").dropPongFrames(), + serverConfig.sendCloseFrame(), + serverConfig.forceCloseTimeoutMillis() + ); + this.serverConfig = serverConfig; + } + + public WebSocketServerProtocolHandler(String websocketPath) { + this(websocketPath, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, long handshakeTimeoutMillis) { + this(websocketPath, false, handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith) { + this(websocketPath, checkStartsWith, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, boolean checkStartsWith, long handshakeTimeoutMillis) { + this(websocketPath, null, false, 65536, false, checkStartsWith, handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols) { + this(websocketPath, subprotocols, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, false, handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions) { + this(websocketPath, subprotocols, allowExtensions, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, 65536, handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public 
WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, false, handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, + DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean allowExtensions, + int maxFrameSize, boolean allowMaskMismatch, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, false, + handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, + DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, + boolean checkStartsWith, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, true, + handshakeTimeoutMillis); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, + boolean allowExtensions, int maxFrameSize, boolean allowMaskMismatch, + boolean checkStartsWith, boolean dropPongFrames) { + this(websocketPath, subprotocols, allowExtensions, maxFrameSize, allowMaskMismatch, checkStartsWith, + dropPongFrames, DEFAULT_HANDSHAKE_TIMEOUT_MILLIS); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean 
allowExtensions, + int maxFrameSize, boolean allowMaskMismatch, boolean checkStartsWith, + boolean dropPongFrames, long handshakeTimeoutMillis) { + this(websocketPath, subprotocols, checkStartsWith, dropPongFrames, handshakeTimeoutMillis, + WebSocketDecoderConfig.newBuilder() + .maxFramePayloadLength(maxFrameSize) + .allowMaskMismatch(allowMaskMismatch) + .allowExtensions(allowExtensions) + .build()); + } + + public WebSocketServerProtocolHandler(String websocketPath, String subprotocols, boolean checkStartsWith, + boolean dropPongFrames, long handshakeTimeoutMillis, + WebSocketDecoderConfig decoderConfig) { + this(WebSocketServerProtocolConfig.newBuilder() + .websocketPath(websocketPath) + .subprotocols(subprotocols) + .checkStartsWith(checkStartsWith) + .handshakeTimeoutMillis(handshakeTimeoutMillis) + .dropPongFrames(dropPongFrames) + .decoderConfig(decoderConfig) + .build()); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + ChannelPipeline cp = ctx.pipeline(); + if (cp.get(WebSocketServerProtocolHandshakeHandler.class) == null) { + // Add the WebSocketHandshakeHandler before this one. + cp.addBefore(ctx.name(), WebSocketServerProtocolHandshakeHandler.class.getName(), + new WebSocketServerProtocolHandshakeHandler(serverConfig)); + } + if (serverConfig.decoderConfig().withUTF8Validator() && cp.get(Utf8FrameValidator.class) == null) { + // Add the UTF8 checking before this one.
+ cp.addBefore(ctx.name(), Utf8FrameValidator.class.getName(), + new Utf8FrameValidator(serverConfig.decoderConfig().closeOnProtocolViolation())); + } + } + + @Override + protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List out) throws Exception { + if (serverConfig.handleCloseFrames() && frame instanceof CloseWebSocketFrame) { + WebSocketServerHandshaker handshaker = getHandshaker(ctx.channel()); + if (handshaker != null) { + frame.retain(); + ChannelPromise promise = ctx.newPromise(); + closeSent(promise); + handshaker.close(ctx, (CloseWebSocketFrame) frame, promise); + } else { + ctx.writeAndFlush(Unpooled.EMPTY_BUFFER).addListener(ChannelFutureListener.CLOSE); + } + return; + } + super.decode(ctx, frame, out); + } + + @Override + protected WebSocketServerHandshakeException buildHandshakeException(String message) { + return new WebSocketServerHandshakeException(message); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof WebSocketHandshakeException) { + FullHttpResponse response = new DefaultFullHttpResponse( + HTTP_1_1, HttpResponseStatus.BAD_REQUEST, Unpooled.wrappedBuffer(cause.getMessage().getBytes())); + ctx.channel().writeAndFlush(response).addListener(ChannelFutureListener.CLOSE); + } else { + ctx.fireExceptionCaught(cause); + ctx.close(); + } + } + + static WebSocketServerHandshaker getHandshaker(Channel channel) { + return channel.attr(HANDSHAKER_ATTR_KEY).get(); + } + + static void setHandshaker(Channel channel, WebSocketServerHandshaker handshaker) { + channel.attr(HANDSHAKER_ATTR_KEY).set(handshaker); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java new file mode 100644 index 0000000..fad3246 --- /dev/null +++ 
b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketServerProtocolHandshakeHandler.java @@ -0,0 +1,179 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; + +import java.util.concurrent.TimeUnit; + +import static io.netty.handler.codec.http.HttpUtil.*; +import static io.netty.util.internal.ObjectUtil.*; + +/** + * Handles the HTTP handshake (the HTTP Upgrade request) for {@link WebSocketServerProtocolHandler}. 
+ */ +class WebSocketServerProtocolHandshakeHandler extends ChannelInboundHandlerAdapter { + + private final WebSocketServerProtocolConfig serverConfig; + private ChannelHandlerContext ctx; + private ChannelPromise handshakePromise; + private boolean isWebSocketPath; + + WebSocketServerProtocolHandshakeHandler(WebSocketServerProtocolConfig serverConfig) { + this.serverConfig = checkNotNull(serverConfig, "serverConfig"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + this.ctx = ctx; + handshakePromise = ctx.newPromise(); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { + final HttpObject httpObject = (HttpObject) msg; + + if (httpObject instanceof HttpRequest) { + final HttpRequest req = (HttpRequest) httpObject; + isWebSocketPath = isWebSocketPath(req); + if (!isWebSocketPath) { + ctx.fireChannelRead(msg); + return; + } + + try { + final WebSocketServerHandshakerFactory wsFactory = new WebSocketServerHandshakerFactory( + getWebSocketLocation(ctx.pipeline(), req, serverConfig.websocketPath()), + serverConfig.subprotocols(), serverConfig.decoderConfig()); + final WebSocketServerHandshaker handshaker = wsFactory.newHandshaker(req); + final ChannelPromise localHandshakePromise = handshakePromise; + if (handshaker == null) { + WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(ctx.channel()); + } else { + // Ensure we set the handshaker and replace this handler before we + // trigger the actual handshake. Otherwise we may receive websocket bytes in this handler + // before we had a chance to replace it. + // + // See https://github.com/netty/netty/issues/9471. 
+ WebSocketServerProtocolHandler.setHandshaker(ctx.channel(), handshaker); + ctx.pipeline().remove(this); + + final ChannelFuture handshakeFuture = handshaker.handshake(ctx.channel(), req); + handshakeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (!future.isSuccess()) { + localHandshakePromise.tryFailure(future.cause()); + ctx.fireExceptionCaught(future.cause()); + } else { + localHandshakePromise.trySuccess(); + // Kept for compatibility + ctx.fireUserEventTriggered( + WebSocketServerProtocolHandler.ServerHandshakeStateEvent.HANDSHAKE_COMPLETE); + ctx.fireUserEventTriggered( + new WebSocketServerProtocolHandler.HandshakeComplete( + req.uri(), req.headers(), handshaker.selectedSubprotocol())); + } + } + }); + applyHandshakeTimeout(); + } + } finally { + ReferenceCountUtil.release(req); + } + } else if (!isWebSocketPath) { + ctx.fireChannelRead(msg); + } else { + ReferenceCountUtil.release(msg); + } + } + + private boolean isWebSocketPath(HttpRequest req) { + String websocketPath = serverConfig.websocketPath(); + String uri = req.uri(); + boolean checkStartUri = uri.startsWith(websocketPath); + boolean checkNextUri = "/".equals(websocketPath) || checkNextUri(uri, websocketPath); + return serverConfig.checkStartsWith() ? 
(checkStartUri && checkNextUri) : uri.equals(websocketPath); + } + + private boolean checkNextUri(String uri, String websocketPath) { + int len = websocketPath.length(); + if (uri.length() > len) { + char nextUri = uri.charAt(len); + return nextUri == '/' || nextUri == '?'; + } + return true; + } + + private static void sendHttpResponse(ChannelHandlerContext ctx, HttpRequest req, HttpResponse res) { + ChannelFuture f = ctx.writeAndFlush(res); + if (!isKeepAlive(req) || res.status().code() != 200) { + f.addListener(ChannelFutureListener.CLOSE); + } + } + + private static String getWebSocketLocation(ChannelPipeline cp, HttpRequest req, String path) { + String protocol = "ws"; + if (cp.get(SslHandler.class) != null) { + // SSL in use so use Secure WebSockets + protocol = "wss"; + } + String host = req.headers().get(HttpHeaderNames.HOST); + return protocol + "://" + host + path; + } + + private void applyHandshakeTimeout() { + final ChannelPromise localHandshakePromise = handshakePromise; + final long handshakeTimeoutMillis = serverConfig.handshakeTimeoutMillis(); + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { + return; + } + + final Future timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (!localHandshakePromise.isDone() && + localHandshakePromise.tryFailure(new WebSocketServerHandshakeException("handshake timed out"))) { + ctx.flush() + .fireUserEventTriggered(ServerHandshakeStateEvent.HANDSHAKE_TIMEOUT) + .close(); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. 
+ localHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) { + timeoutFuture.cancel(false); + } + }); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketUtil.java new file mode 100644 index 0000000..f35efaa --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketUtil.java @@ -0,0 +1,173 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; + +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; + +/** + * A utility class mainly for use by web sockets + */ +final class WebSocketUtil { + + private static final FastThreadLocal MD5 = new FastThreadLocal() { + @Override + protected MessageDigest initialValue() throws Exception { + try { + //Try to get a MessageDigest that uses MD5 + //Suppress a warning about weak hash algorithm + //since it's defined in draft-ietf-hybi-thewebsocketprotocol-00 + return MessageDigest.getInstance("MD5"); + } catch (NoSuchAlgorithmException e) { + //This shouldn't happen! How old is the computer? + throw new InternalError("MD5 not supported on this platform - Outdated?"); + } + } + }; + + private static final FastThreadLocal SHA1 = new FastThreadLocal() { + @Override + protected MessageDigest initialValue() throws Exception { + try { + //Try to get a MessageDigest that uses SHA1 + //Suppress a warning about weak hash algorithm + //since it's defined in draft-ietf-hybi-thewebsocketprotocol-00 + return MessageDigest.getInstance("SHA1"); + } catch (NoSuchAlgorithmException e) { + //This shouldn't happen! How old is the computer? + throw new InternalError("SHA-1 not supported on this platform - Outdated?"); + } + } + }; + + /** + * Performs a MD5 hash on the specified data + * + * @param data The data to hash + * @return The hashed data + */ + static byte[] md5(byte[] data) { + // TODO(normanmaurer): Create md5 method that not need MessageDigest. 
+ return digest(MD5, data); + } + + /** + * Performs a SHA-1 hash on the specified data + * + * @param data The data to hash + * @return The hashed data + */ + static byte[] sha1(byte[] data) { + // TODO(normanmaurer): Create sha1 method that not need MessageDigest. + return digest(SHA1, data); + } + + private static byte[] digest(FastThreadLocal digestFastThreadLocal, byte[] data) { + MessageDigest digest = digestFastThreadLocal.get(); + digest.reset(); + return digest.digest(data); + } + + /** + * Performs base64 encoding on the specified data + * + * @param data The data to encode + * @return An encoded string containing the data + */ + @SuppressJava6Requirement(reason = "Guarded with java version check") + static String base64(byte[] data) { + if (PlatformDependent.javaVersion() >= 8) { + return java.util.Base64.getEncoder().encodeToString(data); + } + String encodedString; + ByteBuf encodedData = Unpooled.wrappedBuffer(data); + try { + ByteBuf encoded = Base64.encode(encodedData); + try { + encodedString = encoded.toString(CharsetUtil.UTF_8); + } finally { + encoded.release(); + } + } finally { + encodedData.release(); + } + return encodedString; + } + + /** + * Creates an arbitrary number of random bytes + * + * @param size the number of random bytes to create + * @return An array of random bytes + */ + static byte[] randomBytes(int size) { + byte[] bytes = new byte[size]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + return bytes; + } + + /** + * Generates a pseudo-random number + * + * @param minimum The minimum allowable value + * @param maximum The maximum allowable value + * @return A pseudo-random number + */ + static int randomNumber(int minimum, int maximum) { + assert minimum < maximum; + double fraction = PlatformDependent.threadLocalRandom().nextDouble(); + + // the idea here is that nextDouble gives us a random value + // + // 0 <= fraction <= 1 + // + // the distance from min to max declared as + // + // dist = max - min + // + // 
satisfies the following + // + // min + dist = max + // + // taking into account + // + // 0 <= fraction * dist <= dist + // + // we've got + // + // min <= min + fraction * dist <= max + return (int) (minimum + fraction * (maximum - minimum)); + } + + static int byteAtIndex(int mask, int index) { + return (mask >> 8 * (3 - index)) & 0xFF; + } + + /** + * A private constructor to ensure that instances of this class cannot be made + */ + private WebSocketUtil() { + // Unused + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketVersion.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketVersion.java new file mode 100644 index 0000000..c8c78c1 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/WebSocketVersion.java @@ -0,0 +1,77 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.util.AsciiString; +import io.netty.util.internal.StringUtil; + +/** + *

<p>
+ * Versions of the web socket specification.
+ * </p>
+ * <p>
+ * A specification is tied to one wire protocol version but a protocol version may have use by more than 1 version of
+ * the specification.
+ * </p>

+ */ +public enum WebSocketVersion { + UNKNOWN(AsciiString.cached(StringUtil.EMPTY_STRING)), + + /** + * draft-ietf-hybi-thewebsocketprotocol- 00. + */ + V00(AsciiString.cached("0")), + + /** + * draft-ietf-hybi-thewebsocketprotocol- 07 + */ + V07(AsciiString.cached("7")), + + /** + * draft-ietf-hybi-thewebsocketprotocol- 10 + */ + V08(AsciiString.cached("8")), + + /** + * RFC 6455. This was originally draft-ietf-hybi-thewebsocketprotocol- + * 17 + */ + V13(AsciiString.cached("13")); + + private final AsciiString headerValue; + + WebSocketVersion(AsciiString headerValue) { + this.headerValue = headerValue; + } + /** + * @return Value for HTTP Header 'Sec-WebSocket-Version' + */ + public String toHttpHeaderValue() { + return toAsciiString().toString(); + } + + AsciiString toAsciiString() { + if (this == UNKNOWN) { + // Let's special case this to preserve behaviour + throw new IllegalStateException("Unknown web socket version: " + this); + } + return headerValue; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtension.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtension.java new file mode 100644 index 0000000..1d4a089 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtension.java @@ -0,0 +1,23 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +/** + * Created once the handshake phase is done. + */ +public interface WebSocketClientExtension extends WebSocketExtension { + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java new file mode 100644 index 0000000..84ead3d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandler.java @@ -0,0 +1,128 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions; + +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; + +/** + * This handler negotiates and initializes the WebSocket Extensions. + * + * This implementation negotiates the extension with the server in a defined order, + * ensures that the successfully negotiated extensions are consistent between them, + * and initializes the channel pipeline with the extension decoder and encoder. + * + * Find a basic implementation for compression extensions at + * io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientCompressionHandler. + */ +public class WebSocketClientExtensionHandler extends ChannelDuplexHandler { + + private final List extensionHandshakers; + + /** + * Constructor + * + * @param extensionHandshakers + * The extension handshaker in priority order. A handshaker could be repeated many times + * with fallback configuration. + */ + public WebSocketClientExtensionHandler(WebSocketClientExtensionHandshaker... 
extensionHandshakers) { + this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers")); + } + + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof HttpRequest && WebSocketExtensionUtil.isWebsocketUpgrade(((HttpRequest) msg).headers())) { + HttpRequest request = (HttpRequest) msg; + String headerValue = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); + List extraExtensions = + new ArrayList(extensionHandshakers.size()); + for (WebSocketClientExtensionHandshaker extensionHandshaker : extensionHandshakers) { + extraExtensions.add(extensionHandshaker.newRequestData()); + } + String newHeaderValue = WebSocketExtensionUtil + .computeMergeExtensionsHeaderValue(headerValue, extraExtensions); + + request.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, newHeaderValue); + } + + super.write(ctx, msg, promise); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) + throws Exception { + if (msg instanceof HttpResponse) { + HttpResponse response = (HttpResponse) msg; + + if (WebSocketExtensionUtil.isWebsocketUpgrade(response.headers())) { + String extensionsHeader = response.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); + + if (extensionsHeader != null) { + List extensions = + WebSocketExtensionUtil.extractExtensions(extensionsHeader); + List validExtensions = + new ArrayList(extensions.size()); + int rsv = 0; + + for (WebSocketExtensionData extensionData : extensions) { + Iterator extensionHandshakersIterator = + extensionHandshakers.iterator(); + WebSocketClientExtension validExtension = null; + + while (validExtension == null && extensionHandshakersIterator.hasNext()) { + WebSocketClientExtensionHandshaker extensionHandshaker = + extensionHandshakersIterator.next(); + validExtension = extensionHandshaker.handshakeExtension(extensionData); + } + + if (validExtension != 
null && ((validExtension.rsv() & rsv) == 0)) { + rsv = rsv | validExtension.rsv(); + validExtensions.add(validExtension); + } else { + throw new CodecException( + "invalid WebSocket Extension handshake for \"" + extensionsHeader + '"'); + } + } + + for (WebSocketClientExtension validExtension : validExtensions) { + WebSocketExtensionDecoder decoder = validExtension.newExtensionDecoder(); + WebSocketExtensionEncoder encoder = validExtension.newExtensionEncoder(); + ctx.pipeline().addAfter(ctx.name(), decoder.getClass().getName(), decoder); + ctx.pipeline().addAfter(ctx.name(), encoder.getClass().getName(), encoder); + } + } + + ctx.pipeline().remove(ctx.name()); + } + } + + super.channelRead(ctx, msg); + } +} + diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandshaker.java new file mode 100644 index 0000000..4812966 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandshaker.java @@ -0,0 +1,41 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + + +/** + * Handshakes a client extension with the server. 
+ */ +public interface WebSocketClientExtensionHandshaker { + + /** + * Return extension configuration to submit to the server. + * + * @return the desired extension configuration. + */ + WebSocketExtensionData newRequestData(); + + /** + * Handshake based on server response. It should always succeed because the server response + * should be an acknowledgement of the request. + * + * @param extensionData + * the extension configuration sent by the server. + * @return an initialized extension if the handshake phase succeeds or null if failed. + */ + WebSocketClientExtension handshakeExtension(WebSocketExtensionData extensionData); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtension.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtension.java new file mode 100644 index 0000000..30bfb4c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtension.java @@ -0,0 +1,42 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +/** + * Created once the handshake phase is done. 
+ */ +public interface WebSocketExtension { + + int RSV1 = 0x04; + int RSV2 = 0x02; + int RSV3 = 0x01; + + /** + * @return the reserved bit value to ensure that no other extension should interfere. + */ + int rsv(); + + /** + * @return create the extension encoder. + */ + WebSocketExtensionEncoder newExtensionEncoder(); + + /** + * @return create the extension decoder. + */ + WebSocketExtensionDecoder newExtensionDecoder(); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionData.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionData.java new file mode 100644 index 0000000..f718a8c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionData.java @@ -0,0 +1,52 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.util.internal.ObjectUtil; + +import java.util.Collections; +import java.util.Map; + +/** + * A WebSocket Extension data from the Sec-WebSocket-Extensions header. + * + * See io.netty.handler.codec.http.HttpHeaders.Names.SEC_WEBSOCKET_EXTENSIONS. 
+ */ +public final class WebSocketExtensionData { + + private final String name; + private final Map parameters; + + public WebSocketExtensionData(String name, Map parameters) { + this.name = ObjectUtil.checkNotNull(name, "name"); + this.parameters = Collections.unmodifiableMap( + ObjectUtil.checkNotNull(parameters, "parameters")); + } + + /** + * @return the extension name. + */ + public String name() { + return name; + } + + /** + * @return the extension optional parameters. + */ + public Map parameters() { + return parameters; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionDecoder.java new file mode 100644 index 0000000..d86b1ad --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionDecoder.java @@ -0,0 +1,26 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +/** + * Convenient class for io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension decoder. 
+ */ +public abstract class WebSocketExtensionDecoder extends MessageToMessageDecoder { + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionEncoder.java new file mode 100644 index 0000000..96a84f8 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionEncoder.java @@ -0,0 +1,26 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +/** + * Convenient class for io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension encoder. 
+ */ +public abstract class WebSocketExtensionEncoder extends MessageToMessageEncoder { + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilter.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilter.java new file mode 100644 index 0000000..3fd3d12 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilter.java @@ -0,0 +1,54 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.handler.codec.http.websocketx.WebSocketFrame; + +/** + * Filter that is responsible for skipping the evaluation of a certain extension + * according to the standard. + */ +public interface WebSocketExtensionFilter { + + /** + * A {@link WebSocketExtensionFilter} that never skips the evaluation of + * any given {@link WebSocketExtension} extension. + */ + WebSocketExtensionFilter NEVER_SKIP = new WebSocketExtensionFilter() { + @Override + public boolean mustSkip(WebSocketFrame frame) { + return false; + } + }; + + /** + * A {@link WebSocketExtensionFilter} that always skips the evaluation of + * any given {@link WebSocketExtension} extension. 
+ */ + WebSocketExtensionFilter ALWAYS_SKIP = new WebSocketExtensionFilter() { + @Override + public boolean mustSkip(WebSocketFrame frame) { + return true; + } + }; + + /** + * Returns {@code true} if the evaluation of the extension must be skipped + * for the given frame, otherwise {@code false}. + */ + boolean mustSkip(WebSocketFrame frame); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProvider.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProvider.java new file mode 100644 index 0000000..4633e76 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProvider.java @@ -0,0 +1,45 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +/** + * Extension filter provider that is responsible for providing filters for a certain {@link WebSocketExtension} extension. 
+ */ +public interface WebSocketExtensionFilterProvider { + + WebSocketExtensionFilterProvider DEFAULT = new WebSocketExtensionFilterProvider() { + @Override + public WebSocketExtensionFilter encoderFilter() { + return WebSocketExtensionFilter.NEVER_SKIP; + } + + @Override + public WebSocketExtensionFilter decoderFilter() { + return WebSocketExtensionFilter.NEVER_SKIP; + } + }; + + /** + * Returns the extension filter for {@link WebSocketExtensionEncoder} encoder. + */ + WebSocketExtensionFilter encoderFilter(); + + /** + * Returns the extension filter for {@link WebSocketExtensionDecoder} decoder. + */ + WebSocketExtensionFilter decoderFilter(); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java new file mode 100644 index 0000000..01f1c00 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtil.java @@ -0,0 +1,127 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +public final class WebSocketExtensionUtil { + + private static final String EXTENSION_SEPARATOR = ","; + private static final String PARAMETER_SEPARATOR = ";"; + private static final char PARAMETER_EQUAL = '='; + + private static final Pattern PARAMETER = Pattern.compile("^([^=]+)(=[\\\"]?([^\\\"]+)[\\\"]?)?$"); + + static boolean isWebsocketUpgrade(HttpHeaders headers) { + //this contains check does not allocate an iterator, and most requests are not upgrades + //so we do the contains check first before checking for specific values + return headers.contains(HttpHeaderNames.UPGRADE) && + headers.containsValue(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE, true) && + headers.contains(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true); + } + + public static List extractExtensions(String extensionHeader) { + String[] rawExtensions = extensionHeader.split(EXTENSION_SEPARATOR); + if (rawExtensions.length > 0) { + List extensions = new ArrayList(rawExtensions.length); + for (String rawExtension : rawExtensions) { + String[] extensionParameters = rawExtension.split(PARAMETER_SEPARATOR); + String name = extensionParameters[0].trim(); + Map parameters; + if (extensionParameters.length > 1) { + parameters = new HashMap(extensionParameters.length - 1); + for (int i = 1; i < extensionParameters.length; i++) { + String parameter = extensionParameters[i].trim(); + Matcher parameterMatcher = PARAMETER.matcher(parameter); + if (parameterMatcher.matches() && parameterMatcher.group(1) != null) { + 
parameters.put(parameterMatcher.group(1), parameterMatcher.group(3)); + } + } + } else { + parameters = Collections.emptyMap(); + } + extensions.add(new WebSocketExtensionData(name, parameters)); + } + return extensions; + } else { + return Collections.emptyList(); + } + } + + static String computeMergeExtensionsHeaderValue(String userDefinedHeaderValue, + List extraExtensions) { + List userDefinedExtensions = + userDefinedHeaderValue != null ? + extractExtensions(userDefinedHeaderValue) : + Collections.emptyList(); + + for (WebSocketExtensionData userDefined: userDefinedExtensions) { + WebSocketExtensionData matchingExtra = null; + int i; + for (i = 0; i < extraExtensions.size(); i ++) { + WebSocketExtensionData extra = extraExtensions.get(i); + if (extra.name().equals(userDefined.name())) { + matchingExtra = extra; + break; + } + } + if (matchingExtra == null) { + extraExtensions.add(userDefined); + } else { + // merge with higher precedence to user defined parameters + Map mergedParameters = new HashMap(matchingExtra.parameters()); + mergedParameters.putAll(userDefined.parameters()); + extraExtensions.set(i, new WebSocketExtensionData(matchingExtra.name(), mergedParameters)); + } + } + + StringBuilder sb = new StringBuilder(150); + + for (WebSocketExtensionData data: extraExtensions) { + sb.append(data.name()); + for (Entry parameter : data.parameters().entrySet()) { + sb.append(PARAMETER_SEPARATOR); + sb.append(parameter.getKey()); + if (parameter.getValue() != null) { + sb.append(PARAMETER_EQUAL); + sb.append(parameter.getValue()); + } + } + sb.append(EXTENSION_SEPARATOR); + } + + if (!extraExtensions.isEmpty()) { + sb.setLength(sb.length() - EXTENSION_SEPARATOR.length()); + } + + return sb.toString(); + } + + private WebSocketExtensionUtil() { + // Unused + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtension.java 
b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtension.java new file mode 100644 index 0000000..9cb6237 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtension.java @@ -0,0 +1,32 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + + +/** + * Created once the handshake phase is done. + */ +public interface WebSocketServerExtension extends WebSocketExtension { + + /** + * Return an extension configuration to submit to the client as an acknowledge. + * + * @return the acknowledged extension configuration. 
+ */ + //TODO: after migrating to JDK 8 rename this to 'newResponseData()' and mark old as deprecated with default method + WebSocketExtensionData newReponseData(); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java new file mode 100644 index 0000000..b4e1f7a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandler.java @@ -0,0 +1,263 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions; + +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; + +/** + * This handler negotiates and initializes the WebSocket Extensions. + * + * It negotiates the extensions based on the client desired order, + * ensures that the successfully negotiated extensions are consistent between them, + * and initializes the channel pipeline with the extension decoder and encoder. + * + * Find a basic implementation for compression extensions at + * io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerCompressionHandler. + */ +public class WebSocketServerExtensionHandler extends ChannelDuplexHandler { + + private final List extensionHandshakers; + + private final Queue> validExtensions = + new ArrayDeque>(4); + + /** + * Constructor + * + * @param extensionHandshakers + * The extension handshaker in priority order. A handshaker could be repeated many times + * with fallback configuration. 
+ */ + public WebSocketServerExtensionHandler(WebSocketServerExtensionHandshaker... extensionHandshakers) { + this.extensionHandshakers = Arrays.asList(checkNonEmpty(extensionHandshakers, "extensionHandshakers")); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // JDK type checks vs non-implemented interfaces costs O(N), where + // N is the number of interfaces already implemented by the concrete type that's being tested. + // The only requirement for this call is to make HttpRequest(s) implementors to call onHttpRequestChannelRead + // and super.channelRead the others, but due to the O(n) cost we perform few fast-path for commonly met + // singleton and/or concrete types, to save performing such slow type checks. + if (msg != LastHttpContent.EMPTY_LAST_CONTENT) { + if (msg instanceof DefaultHttpRequest) { + // fast-path + onHttpRequestChannelRead(ctx, (DefaultHttpRequest) msg); + } else if (msg instanceof HttpRequest) { + // slow path + onHttpRequestChannelRead(ctx, (HttpRequest) msg); + } else { + super.channelRead(ctx, msg); + } + } else { + super.channelRead(ctx, msg); + } + } + + /** + * This is a method exposed to perform fail-fast checks of user-defined http types.

+ * eg:
+ * If the user has defined a specific {@link HttpRequest} type i.e.{@code CustomHttpRequest} and + * {@link #channelRead} can receive {@link LastHttpContent#EMPTY_LAST_CONTENT} {@code msg} + * types too, can override it like this: + *

+     *     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+     *         if (msg != LastHttpContent.EMPTY_LAST_CONTENT) {
+     *             if (msg instanceof CustomHttpRequest) {
+     *                 onHttpRequestChannelRead(ctx, (CustomHttpRequest) msg);
+     *             } else {
+     *                 // if it's handling other HttpRequest types it MUST use onHttpRequestChannelRead again
+     *                 // or have to delegate it to super.channelRead (that can perform redundant checks).
+     *                 // If msg is not implementing HttpRequest, it can call ctx.fireChannelRead(msg) on it
+     *                 // ...
+     *                 super.channelRead(ctx, msg);
+     *             }
+     *         } else {
+     *             // given that msg isn't an HttpRequest type we can just skip calling super.channelRead
+     *             ctx.fireChannelRead(msg);
+     *         }
+     *     }
+     * 
+ * IMPORTANT: + * It already call {@code super.channelRead(ctx, request)} before returning. + */ + @UnstableApi + protected void onHttpRequestChannelRead(ChannelHandlerContext ctx, HttpRequest request) throws Exception { + List validExtensionsList = null; + + if (WebSocketExtensionUtil.isWebsocketUpgrade(request.headers())) { + String extensionsHeader = request.headers().getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); + + if (extensionsHeader != null) { + List extensions = + WebSocketExtensionUtil.extractExtensions(extensionsHeader); + int rsv = 0; + + for (WebSocketExtensionData extensionData : extensions) { + Iterator extensionHandshakersIterator = + extensionHandshakers.iterator(); + WebSocketServerExtension validExtension = null; + + while (validExtension == null && extensionHandshakersIterator.hasNext()) { + WebSocketServerExtensionHandshaker extensionHandshaker = + extensionHandshakersIterator.next(); + validExtension = extensionHandshaker.handshakeExtension(extensionData); + } + + if (validExtension != null && ((validExtension.rsv() & rsv) == 0)) { + if (validExtensionsList == null) { + validExtensionsList = new ArrayList(1); + } + rsv = rsv | validExtension.rsv(); + validExtensionsList.add(validExtension); + } + } + } + } + + if (validExtensionsList == null) { + validExtensionsList = Collections.emptyList(); + } + validExtensions.offer(validExtensionsList); + super.channelRead(ctx, request); + } + + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) { + if (msg instanceof DefaultHttpResponse) { + onHttpResponseWrite(ctx, (DefaultHttpResponse) msg, promise); + } else if (msg instanceof HttpResponse) { + onHttpResponseWrite(ctx, (HttpResponse) msg, promise); + } else { + super.write(ctx, msg, promise); + } + } else { + super.write(ctx, msg, promise); + } + } + + /** + * This is a method exposed to perform fail-fast checks 
of user-defined http types.

+ * e.g.:
+ * If the user has defined a specific {@link HttpResponse} type i.e.{@code CustomHttpResponse} and + * {@link #write} can receive {@link ByteBuf} {@code msg} types too, it can be overridden like this: + *

+     *     public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
+     *         if (msg != Unpooled.EMPTY_BUFFER && !(msg instanceof ByteBuf)) {
+     *             if (msg instanceof CustomHttpResponse) {
+     *                 onHttpResponseWrite(ctx, (CustomHttpResponse) msg, promise);
+     *             } else {
+     *                 // if it's handling other HttpResponse types it MUST use onHttpResponseWrite again
+     *                 // or have to delegate it to super.write (that can perform redundant checks).
+     *                 // If msg is not implementing HttpResponse, it can call ctx.write(msg, promise) on it
+     *                 // ...
+     *                 super.write(ctx, msg, promise);
+     *             }
+     *         } else {
+     *             // given that msg isn't an HttpResponse type we can just skip calling super.write
+     *             ctx.write(msg, promise);
+     *         }
+     *     }
+     * 
+ * IMPORTANT: + * It already call {@code super.write(ctx, response, promise)} before returning. + */ + @UnstableApi + protected void onHttpResponseWrite(ChannelHandlerContext ctx, HttpResponse response, ChannelPromise promise) + throws Exception { + List validExtensionsList = validExtensions.poll(); + HttpResponse httpResponse = response; + //checking the status is faster than looking at headers + //so we do this first + if (HttpResponseStatus.SWITCHING_PROTOCOLS.equals(httpResponse.status())) { + handlePotentialUpgrade(ctx, promise, httpResponse, validExtensionsList); + } + super.write(ctx, response, promise); + } + + private void handlePotentialUpgrade(final ChannelHandlerContext ctx, + ChannelPromise promise, HttpResponse httpResponse, + final List validExtensionsList) { + HttpHeaders headers = httpResponse.headers(); + + if (WebSocketExtensionUtil.isWebsocketUpgrade(headers)) { + if (validExtensionsList != null && !validExtensionsList.isEmpty()) { + String headerValue = headers.getAsString(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS); + List extraExtensions = + new ArrayList(extensionHandshakers.size()); + for (WebSocketServerExtension extension : validExtensionsList) { + extraExtensions.add(extension.newReponseData()); + } + String newHeaderValue = WebSocketExtensionUtil + .computeMergeExtensionsHeaderValue(headerValue, extraExtensions); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + for (WebSocketServerExtension extension : validExtensionsList) { + WebSocketExtensionDecoder decoder = extension.newExtensionDecoder(); + WebSocketExtensionEncoder encoder = extension.newExtensionEncoder(); + String name = ctx.name(); + ctx.pipeline() + .addAfter(name, decoder.getClass().getName(), decoder) + .addAfter(name, encoder.getClass().getName(), encoder); + } + } + } + }); + + if (newHeaderValue != null) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, 
newHeaderValue); + } + } + + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + ctx.pipeline().remove(WebSocketServerExtensionHandler.this); + } + } + }); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandshaker.java new file mode 100644 index 0000000..599b1b4 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketServerExtensionHandshaker.java @@ -0,0 +1,33 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + + +/** + * Handshakes a client extension based on this server capabilities. + */ +public interface WebSocketServerExtensionHandshaker { + + /** + * Handshake based on client request. It must failed with null if server cannot handle it. + * + * @param extensionData + * the extension configuration sent by the client. + * @return an initialized extension if handshake phase succeed or null if failed. 
+ */ + WebSocketServerExtension handshakeExtension(WebSocketExtensionData extensionData); + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateDecoder.java new file mode 100644 index 0000000..6706348 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateDecoder.java @@ -0,0 +1,146 @@ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter; + +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * Deflate implementation of a payload decompressor for + * io.netty.handler.codec.http.websocketx.WebSocketFrame. 
+ */ +abstract class DeflateDecoder extends WebSocketExtensionDecoder { + + static final ByteBuf FRAME_TAIL = Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer(new byte[] {0x00, 0x00, (byte) 0xff, (byte) 0xff})) + .asReadOnly(); + + static final ByteBuf EMPTY_DEFLATE_BLOCK = Unpooled.unreleasableBuffer( + Unpooled.wrappedBuffer(new byte[] { 0x00 })) + .asReadOnly(); + + private final boolean noContext; + private final WebSocketExtensionFilter extensionDecoderFilter; + + private EmbeddedChannel decoder; + + /** + * Constructor + * + * @param noContext true to disable context takeover. + * @param extensionDecoderFilter extension decoder filter. + */ + DeflateDecoder(boolean noContext, WebSocketExtensionFilter extensionDecoderFilter) { + this.noContext = noContext; + this.extensionDecoderFilter = checkNotNull(extensionDecoderFilter, "extensionDecoderFilter"); + } + + /** + * Returns the extension decoder filter. + */ + protected WebSocketExtensionFilter extensionDecoderFilter() { + return extensionDecoderFilter; + } + + protected abstract boolean appendFrameTail(WebSocketFrame msg); + + protected abstract int newRsv(WebSocketFrame msg); + + @Override + protected void decode(ChannelHandlerContext ctx, WebSocketFrame msg, List out) throws Exception { + final ByteBuf decompressedContent = decompressContent(ctx, msg); + + final WebSocketFrame outMsg; + if (msg instanceof TextWebSocketFrame) { + outMsg = new TextWebSocketFrame(msg.isFinalFragment(), newRsv(msg), decompressedContent); + } else if (msg instanceof BinaryWebSocketFrame) { + outMsg = new BinaryWebSocketFrame(msg.isFinalFragment(), newRsv(msg), decompressedContent); + } else if (msg instanceof ContinuationWebSocketFrame) { + outMsg = new ContinuationWebSocketFrame(msg.isFinalFragment(), newRsv(msg), decompressedContent); + } else { + throw new CodecException("unexpected frame type: " + msg.getClass().getName()); + } + + out.add(outMsg); + } + + @Override + public void handlerRemoved(ChannelHandlerContext 
ctx) throws Exception { + cleanup(); + super.handlerRemoved(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + cleanup(); + super.channelInactive(ctx); + } + + private ByteBuf decompressContent(ChannelHandlerContext ctx, WebSocketFrame msg) { + if (decoder == null) { + if (!(msg instanceof TextWebSocketFrame) && !(msg instanceof BinaryWebSocketFrame)) { + throw new CodecException("unexpected initial frame type: " + msg.getClass().getName()); + } + decoder = new EmbeddedChannel(ZlibCodecFactory.newZlibDecoder(ZlibWrapper.NONE)); + } + + boolean readable = msg.content().isReadable(); + boolean emptyDeflateBlock = EMPTY_DEFLATE_BLOCK.equals(msg.content()); + + decoder.writeInbound(msg.content().retain()); + if (appendFrameTail(msg)) { + decoder.writeInbound(FRAME_TAIL.duplicate()); + } + + CompositeByteBuf compositeDecompressedContent = ctx.alloc().compositeBuffer(); + for (;;) { + ByteBuf partUncompressedContent = decoder.readInbound(); + if (partUncompressedContent == null) { + break; + } + if (!partUncompressedContent.isReadable()) { + partUncompressedContent.release(); + continue; + } + compositeDecompressedContent.addComponent(true, partUncompressedContent); + } + // Correctly handle empty frames + // See https://github.com/netty/netty/issues/4348 + if (!emptyDeflateBlock && readable && compositeDecompressedContent.numComponents() <= 0) { + // Sometimes after fragmentation the last frame + // May contain left-over data that doesn't affect decompression + if (!(msg instanceof ContinuationWebSocketFrame)) { + compositeDecompressedContent.release(); + throw new CodecException("cannot read uncompressed buffer"); + } + } + + if (msg.isFinalFragment() && noContext) { + cleanup(); + } + + return compositeDecompressedContent; + } + + private void cleanup() { + if (decoder != null) { + // Clean-up the previous encoder if not cleaned up correctly. 
+ decoder.finishAndReleaseAll(); + decoder = null; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateEncoder.java new file mode 100644 index 0000000..31d11fa --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateEncoder.java @@ -0,0 +1,165 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter; + +import java.util.List; + +import static io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateDecoder.*; +import static io.netty.util.internal.ObjectUtil.*; + +/** + * Deflate implementation of a payload compressor for + * io.netty.handler.codec.http.websocketx.WebSocketFrame. + */ +abstract class DeflateEncoder extends WebSocketExtensionEncoder { + + private final int compressionLevel; + private final int windowSize; + private final boolean noContext; + private final WebSocketExtensionFilter extensionEncoderFilter; + + private EmbeddedChannel encoder; + + /** + * Constructor + * @param compressionLevel compression level of the compressor. + * @param windowSize maximum size of the window compressor buffer. + * @param noContext true to disable context takeover. + * @param extensionEncoderFilter extension encoder filter. 
+ */ + DeflateEncoder(int compressionLevel, int windowSize, boolean noContext, + WebSocketExtensionFilter extensionEncoderFilter) { + this.compressionLevel = compressionLevel; + this.windowSize = windowSize; + this.noContext = noContext; + this.extensionEncoderFilter = checkNotNull(extensionEncoderFilter, "extensionEncoderFilter"); + } + + /** + * Returns the extension encoder filter. + */ + protected WebSocketExtensionFilter extensionEncoderFilter() { + return extensionEncoderFilter; + } + + /** + * @param msg the current frame. + * @return the rsv bits to set in the compressed frame. + */ + protected abstract int rsv(WebSocketFrame msg); + + /** + * @param msg the current frame. + * @return true if compressed payload tail needs to be removed. + */ + protected abstract boolean removeFrameTail(WebSocketFrame msg); + + @Override + protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg, List out) throws Exception { + final ByteBuf compressedContent; + if (msg.content().isReadable()) { + compressedContent = compressContent(ctx, msg); + } else if (msg.isFinalFragment()) { + // Set empty DEFLATE block manually for unknown buffer size + // https://tools.ietf.org/html/rfc7692#section-7.2.3.6 + compressedContent = EMPTY_DEFLATE_BLOCK.duplicate(); + } else { + throw new CodecException("cannot compress content buffer"); + } + + final WebSocketFrame outMsg; + if (msg instanceof TextWebSocketFrame) { + outMsg = new TextWebSocketFrame(msg.isFinalFragment(), rsv(msg), compressedContent); + } else if (msg instanceof BinaryWebSocketFrame) { + outMsg = new BinaryWebSocketFrame(msg.isFinalFragment(), rsv(msg), compressedContent); + } else if (msg instanceof ContinuationWebSocketFrame) { + outMsg = new ContinuationWebSocketFrame(msg.isFinalFragment(), rsv(msg), compressedContent); + } else { + throw new CodecException("unexpected frame type: " + msg.getClass().getName()); + } + + out.add(outMsg); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) 
throws Exception { + cleanup(); + super.handlerRemoved(ctx); + } + + private ByteBuf compressContent(ChannelHandlerContext ctx, WebSocketFrame msg) { + if (encoder == null) { + encoder = new EmbeddedChannel(ZlibCodecFactory.newZlibEncoder( + ZlibWrapper.NONE, compressionLevel, windowSize, 8)); + } + + encoder.writeOutbound(msg.content().retain()); + + CompositeByteBuf fullCompressedContent = ctx.alloc().compositeBuffer(); + for (;;) { + ByteBuf partCompressedContent = encoder.readOutbound(); + if (partCompressedContent == null) { + break; + } + if (!partCompressedContent.isReadable()) { + partCompressedContent.release(); + continue; + } + fullCompressedContent.addComponent(true, partCompressedContent); + } + + if (fullCompressedContent.numComponents() <= 0) { + fullCompressedContent.release(); + throw new CodecException("cannot read compressed buffer"); + } + + if (msg.isFinalFragment() && noContext) { + cleanup(); + } + + ByteBuf compressedContent; + if (removeFrameTail(msg)) { + int realLength = fullCompressedContent.readableBytes() - FRAME_TAIL.readableBytes(); + compressedContent = fullCompressedContent.slice(0, realLength); + } else { + compressedContent = fullCompressedContent; + } + + return compressedContent; + } + + private void cleanup() { + if (encoder != null) { + // Clean-up the previous encoder if not cleaned up correctly. 
+ encoder.finishAndReleaseAll(); + encoder = null; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshaker.java new file mode 100644 index 0000000..c7360a9 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshaker.java @@ -0,0 +1,124 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtensionHandshaker; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilterProvider; + +import java.util.Collections; + +import static io.netty.handler.codec.http.websocketx.extensions.compression.DeflateFrameServerExtensionHandshaker.*; +import static io.netty.util.internal.ObjectUtil.*; + +/** + * perframe-deflate + * handshake implementation. + */ +public final class DeflateFrameClientExtensionHandshaker implements WebSocketClientExtensionHandshaker { + + private final int compressionLevel; + private final boolean useWebkitExtensionName; + private final WebSocketExtensionFilterProvider extensionFilterProvider; + + /** + * Constructor with default configuration. + */ + public DeflateFrameClientExtensionHandshaker(boolean useWebkitExtensionName) { + this(6, useWebkitExtensionName); + } + + /** + * Constructor with custom configuration. + * + * @param compressionLevel + * Compression level between 0 and 9 (default is 6). + */ + public DeflateFrameClientExtensionHandshaker(int compressionLevel, boolean useWebkitExtensionName) { + this(compressionLevel, useWebkitExtensionName, WebSocketExtensionFilterProvider.DEFAULT); + } + + /** + * Constructor with custom configuration. + * + * @param compressionLevel + * Compression level between 0 and 9 (default is 6). + * @param extensionFilterProvider + * provides client extension filters for per frame deflate encoder and decoder. 
+ */ + public DeflateFrameClientExtensionHandshaker(int compressionLevel, boolean useWebkitExtensionName, + WebSocketExtensionFilterProvider extensionFilterProvider) { + if (compressionLevel < 0 || compressionLevel > 9) { + throw new IllegalArgumentException( + "compressionLevel: " + compressionLevel + " (expected: 0-9)"); + } + this.compressionLevel = compressionLevel; + this.useWebkitExtensionName = useWebkitExtensionName; + this.extensionFilterProvider = checkNotNull(extensionFilterProvider, "extensionFilterProvider"); + } + + @Override + public WebSocketExtensionData newRequestData() { + return new WebSocketExtensionData( + useWebkitExtensionName ? X_WEBKIT_DEFLATE_FRAME_EXTENSION : DEFLATE_FRAME_EXTENSION, + Collections.emptyMap()); + } + + @Override + public WebSocketClientExtension handshakeExtension(WebSocketExtensionData extensionData) { + if (!X_WEBKIT_DEFLATE_FRAME_EXTENSION.equals(extensionData.name()) && + !DEFLATE_FRAME_EXTENSION.equals(extensionData.name())) { + return null; + } + + if (extensionData.parameters().isEmpty()) { + return new DeflateFrameClientExtension(compressionLevel, extensionFilterProvider); + } else { + return null; + } + } + + private static class DeflateFrameClientExtension implements WebSocketClientExtension { + + private final int compressionLevel; + private final WebSocketExtensionFilterProvider extensionFilterProvider; + + DeflateFrameClientExtension(int compressionLevel, WebSocketExtensionFilterProvider extensionFilterProvider) { + this.compressionLevel = compressionLevel; + this.extensionFilterProvider = extensionFilterProvider; + } + + @Override + public int rsv() { + return RSV1; + } + + @Override + public WebSocketExtensionEncoder newExtensionEncoder() { + return new PerFrameDeflateEncoder(compressionLevel, 15, false, + extensionFilterProvider.encoderFilter()); + } + + @Override + public WebSocketExtensionDecoder newExtensionDecoder() { + return new PerFrameDeflateDecoder(false, extensionFilterProvider.decoderFilter()); 
+ } + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshaker.java new file mode 100644 index 0000000..4ca15f2 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshaker.java @@ -0,0 +1,125 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilterProvider; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandshaker; + +import java.util.Collections; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * perframe-deflate + * handshake implementation. 
+ */ +public final class DeflateFrameServerExtensionHandshaker implements WebSocketServerExtensionHandshaker { + + static final String X_WEBKIT_DEFLATE_FRAME_EXTENSION = "x-webkit-deflate-frame"; + static final String DEFLATE_FRAME_EXTENSION = "deflate-frame"; + + private final int compressionLevel; + private final WebSocketExtensionFilterProvider extensionFilterProvider; + + /** + * Constructor with default configuration. + */ + public DeflateFrameServerExtensionHandshaker() { + this(6); + } + + /** + * Constructor with custom configuration. + * + * @param compressionLevel + * Compression level between 0 and 9 (default is 6). + */ + public DeflateFrameServerExtensionHandshaker(int compressionLevel) { + this(compressionLevel, WebSocketExtensionFilterProvider.DEFAULT); + } + + /** + * Constructor with custom configuration. + * + * @param compressionLevel + * Compression level between 0 and 9 (default is 6). + * @param extensionFilterProvider + * provides server extension filters for per frame deflate encoder and decoder. 
+ */ + public DeflateFrameServerExtensionHandshaker(int compressionLevel, + WebSocketExtensionFilterProvider extensionFilterProvider) { + if (compressionLevel < 0 || compressionLevel > 9) { + throw new IllegalArgumentException( + "compressionLevel: " + compressionLevel + " (expected: 0-9)"); + } + this.compressionLevel = compressionLevel; + this.extensionFilterProvider = checkNotNull(extensionFilterProvider, "extensionFilterProvider"); + } + + @Override + public WebSocketServerExtension handshakeExtension(WebSocketExtensionData extensionData) { + if (!X_WEBKIT_DEFLATE_FRAME_EXTENSION.equals(extensionData.name()) && + !DEFLATE_FRAME_EXTENSION.equals(extensionData.name())) { + return null; + } + + if (extensionData.parameters().isEmpty()) { + return new DeflateFrameServerExtension(compressionLevel, extensionData.name(), extensionFilterProvider); + } else { + return null; + } + } + + private static class DeflateFrameServerExtension implements WebSocketServerExtension { + + private final String extensionName; + private final int compressionLevel; + private final WebSocketExtensionFilterProvider extensionFilterProvider; + + DeflateFrameServerExtension(int compressionLevel, String extensionName, + WebSocketExtensionFilterProvider extensionFilterProvider) { + this.extensionName = extensionName; + this.compressionLevel = compressionLevel; + this.extensionFilterProvider = extensionFilterProvider; + } + + @Override + public int rsv() { + return RSV1; + } + + @Override + public WebSocketExtensionEncoder newExtensionEncoder() { + return new PerFrameDeflateEncoder(compressionLevel, 15, false, + extensionFilterProvider.encoderFilter()); + } + + @Override + public WebSocketExtensionDecoder newExtensionDecoder() { + return new PerFrameDeflateDecoder(false, extensionFilterProvider.decoderFilter()); + } + + @Override + public WebSocketExtensionData newReponseData() { + return new WebSocketExtensionData(extensionName, Collections.emptyMap()); + } + } + +} diff --git 
a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoder.java new file mode 100644 index 0000000..a987e2c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoder.java @@ -0,0 +1,75 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter; + +/** + * Per-frame implementation of deflate decompressor. + */ +class PerFrameDeflateDecoder extends DeflateDecoder { + + /** + * Constructor + * + * @param noContext true to disable context takeover. 
+ */ + PerFrameDeflateDecoder(boolean noContext) { + super(noContext, WebSocketExtensionFilter.NEVER_SKIP); + } + + /** + * Constructor + * + * @param noContext true to disable context takeover. + * @param extensionDecoderFilter extension decoder filter for per frame deflate decoder. + */ + PerFrameDeflateDecoder(boolean noContext, WebSocketExtensionFilter extensionDecoderFilter) { + super(noContext, extensionDecoderFilter); + } + + @Override + public boolean acceptInboundMessage(Object msg) throws Exception { + if (!super.acceptInboundMessage(msg)) { + return false; + } + + WebSocketFrame wsFrame = (WebSocketFrame) msg; + if (extensionDecoderFilter().mustSkip(wsFrame)) { + return false; + } + + return (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame || + msg instanceof ContinuationWebSocketFrame) && + (wsFrame.rsv() & WebSocketExtension.RSV1) > 0; + } + + @Override + protected int newRsv(WebSocketFrame msg) { + return msg.rsv() ^ WebSocketExtension.RSV1; + } + + @Override + protected boolean appendFrameTail(WebSocketFrame msg) { + return true; + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoder.java new file mode 100644 index 0000000..9cd3ede --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoder.java @@ -0,0 +1,81 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter; + +/** + * Per-frame implementation of deflate compressor. + */ +class PerFrameDeflateEncoder extends DeflateEncoder { + + /** + * Constructor + * + * @param compressionLevel compression level of the compressor. + * @param windowSize maximum size of the window compressor buffer. + * @param noContext true to disable context takeover. + */ + PerFrameDeflateEncoder(int compressionLevel, int windowSize, boolean noContext) { + super(compressionLevel, windowSize, noContext, WebSocketExtensionFilter.NEVER_SKIP); + } + + /** + * Constructor + * + * @param compressionLevel compression level of the compressor. + * @param windowSize maximum size of the window compressor buffer. + * @param noContext true to disable context takeover. + * @param extensionEncoderFilter extension encoder filter for per frame deflate encoder. 
+ */ + PerFrameDeflateEncoder(int compressionLevel, int windowSize, boolean noContext, + WebSocketExtensionFilter extensionEncoderFilter) { + super(compressionLevel, windowSize, noContext, extensionEncoderFilter); + } + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + if (!super.acceptOutboundMessage(msg)) { + return false; + } + + WebSocketFrame wsFrame = (WebSocketFrame) msg; + if (extensionEncoderFilter().mustSkip(wsFrame)) { + return false; + } + + return (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame || + msg instanceof ContinuationWebSocketFrame) && + wsFrame.content().readableBytes() > 0 && + (wsFrame.rsv() & WebSocketExtension.RSV1) == 0; + } + + @Override + protected int rsv(WebSocketFrame msg) { + return msg.rsv() | WebSocketExtension.RSV1; + } + + @Override + protected boolean removeFrameTail(WebSocketFrame msg) { + return true; + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java new file mode 100644 index 0000000..e691d6f --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshaker.java @@ -0,0 +1,231 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx.extensions.compression;

import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtension;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtensionHandshaker;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilterProvider;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

import static io.netty.handler.codec.http.websocketx.extensions.compression.PerMessageDeflateServerExtensionHandshaker.*;
import static io.netty.util.internal.ObjectUtil.*;

/**
 * Client-side permessage-deflate (RFC 7692) handshake implementation.
 */
public final class PerMessageDeflateClientExtensionHandshaker implements WebSocketClientExtensionHandshaker {

    private final int compressionLevel;
    private final boolean allowClientWindowSize;
    private final int requestedServerWindowSize;
    private final boolean allowClientNoContext;
    private final boolean requestedServerNoContext;
    private final WebSocketExtensionFilterProvider extensionFilterProvider;

    /**
     * Constructor with default configuration.
     */
    public PerMessageDeflateClientExtensionHandshaker() {
        this(6, false, MAX_WINDOW_SIZE, false, false);
    }

    /**
     * Constructor with custom configuration.
     *
     * @param compressionLevel
     *            Compression level between 0 and 9 (default is 6).
     * @param allowClientWindowSize
     *            allows WebSocket server to customize the client inflater window size
     *            (default is false).
     * @param requestedServerWindowSize
     *            indicates the requested server window size to use if server inflater is customizable.
     * @param allowClientNoContext
     *            allows WebSocket server to activate client_no_context_takeover
     *            (default is false).
     * @param requestedServerNoContext
     *            indicates if client needs to activate server_no_context_takeover
     *            if server is compatible with (default is false).
     */
    public PerMessageDeflateClientExtensionHandshaker(int compressionLevel,
            boolean allowClientWindowSize, int requestedServerWindowSize,
            boolean allowClientNoContext, boolean requestedServerNoContext) {
        this(compressionLevel, allowClientWindowSize, requestedServerWindowSize,
             allowClientNoContext, requestedServerNoContext, WebSocketExtensionFilterProvider.DEFAULT);
    }

    /**
     * Constructor with custom configuration.
     *
     * @param compressionLevel
     *            Compression level between 0 and 9 (default is 6).
     * @param allowClientWindowSize
     *            allows WebSocket server to customize the client inflater window size
     *            (default is false).
     * @param requestedServerWindowSize
     *            indicates the requested server window size to use if server inflater is customizable.
     * @param allowClientNoContext
     *            allows WebSocket server to activate client_no_context_takeover
     *            (default is false).
     * @param requestedServerNoContext
     *            indicates if client needs to activate server_no_context_takeover
     *            if server is compatible with (default is false).
     * @param extensionFilterProvider
     *            provides client extension filters for per message deflate encoder and decoder.
     * @throws IllegalArgumentException if the window size or compression level is out of range.
     */
    public PerMessageDeflateClientExtensionHandshaker(int compressionLevel,
            boolean allowClientWindowSize, int requestedServerWindowSize,
            boolean allowClientNoContext, boolean requestedServerNoContext,
            WebSocketExtensionFilterProvider extensionFilterProvider) {

        if (requestedServerWindowSize > MAX_WINDOW_SIZE || requestedServerWindowSize < MIN_WINDOW_SIZE) {
            throw new IllegalArgumentException(
                    "requestedServerWindowSize: " + requestedServerWindowSize + " (expected: 8-15)");
        }
        if (compressionLevel < 0 || compressionLevel > 9) {
            throw new IllegalArgumentException(
                    "compressionLevel: " + compressionLevel + " (expected: 0-9)");
        }
        this.compressionLevel = compressionLevel;
        this.allowClientWindowSize = allowClientWindowSize;
        this.requestedServerWindowSize = requestedServerWindowSize;
        this.allowClientNoContext = allowClientNoContext;
        this.requestedServerNoContext = requestedServerNoContext;
        this.extensionFilterProvider = checkNotNull(extensionFilterProvider, "extensionFilterProvider");
    }

    @Override
    public WebSocketExtensionData newRequestData() {
        // Parameters with a null value are emitted as bare tokens per RFC 7692.
        HashMap<String, String> parameters = new HashMap<String, String>(4);
        if (requestedServerNoContext) {
            parameters.put(SERVER_NO_CONTEXT, null);
        }
        if (allowClientNoContext) {
            parameters.put(CLIENT_NO_CONTEXT, null);
        }
        if (requestedServerWindowSize != MAX_WINDOW_SIZE) {
            parameters.put(SERVER_MAX_WINDOW, Integer.toString(requestedServerWindowSize));
        }
        if (allowClientWindowSize) {
            parameters.put(CLIENT_MAX_WINDOW, null);
        }
        return new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters);
    }

    @Override
    public WebSocketClientExtension handshakeExtension(WebSocketExtensionData extensionData) {
        if (!PERMESSAGE_DEFLATE_EXTENSION.equals(extensionData.name())) {
            return null;
        }

        boolean succeed = true;
        int clientWindowSize = MAX_WINDOW_SIZE;
        int serverWindowSize = MAX_WINDOW_SIZE;
        boolean serverNoContext = false;
        boolean clientNoContext = false;

        Iterator<Entry<String, String>> parametersIterator =
                extensionData.parameters().entrySet().iterator();
        while (succeed && parametersIterator.hasNext()) {
            Entry<String, String> parameter = parametersIterator.next();

            if (CLIENT_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) {
                // allowed client_window_size_bits
                if (allowClientWindowSize) {
                    clientWindowSize = Integer.parseInt(parameter.getValue());
                    if (clientWindowSize > MAX_WINDOW_SIZE || clientWindowSize < MIN_WINDOW_SIZE) {
                        succeed = false;
                    }
                } else {
                    succeed = false;
                }
            } else if (SERVER_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) {
                // acknowledged server_window_size_bits
                serverWindowSize = Integer.parseInt(parameter.getValue());
                if (serverWindowSize > MAX_WINDOW_SIZE || serverWindowSize < MIN_WINDOW_SIZE) {
                    succeed = false;
                }
            } else if (CLIENT_NO_CONTEXT.equalsIgnoreCase(parameter.getKey())) {
                // allowed client_no_context_takeover
                if (allowClientNoContext) {
                    clientNoContext = true;
                } else {
                    succeed = false;
                }
            } else if (SERVER_NO_CONTEXT.equalsIgnoreCase(parameter.getKey())) {
                // acknowledged server_no_context_takeover
                serverNoContext = true;
            } else {
                // unknown parameter
                succeed = false;
            }
        }

        // The server must honor what the client requested: no-context takeover
        // and a window no larger than requested.
        if ((requestedServerNoContext && !serverNoContext) ||
            requestedServerWindowSize < serverWindowSize) {
            succeed = false;
        }

        if (succeed) {
            return new PermessageDeflateExtension(serverNoContext, serverWindowSize,
                    clientNoContext, clientWindowSize, extensionFilterProvider);
        } else {
            return null;
        }
    }

    private final class PermessageDeflateExtension implements WebSocketClientExtension {

        private final boolean serverNoContext;
        private final int serverWindowSize;
        private final boolean clientNoContext;
        private final int clientWindowSize;
        private final WebSocketExtensionFilterProvider extensionFilterProvider;

        PermessageDeflateExtension(boolean serverNoContext, int serverWindowSize,
                boolean clientNoContext, int clientWindowSize,
                WebSocketExtensionFilterProvider extensionFilterProvider) {
            this.serverNoContext = serverNoContext;
            this.serverWindowSize = serverWindowSize;
            this.clientNoContext = clientNoContext;
            this.clientWindowSize = clientWindowSize;
            this.extensionFilterProvider = extensionFilterProvider;
        }

        @Override
        public int rsv() {
            return RSV1;
        }

        @Override
        public WebSocketExtensionEncoder newExtensionEncoder() {
            return new PerMessageDeflateEncoder(compressionLevel, clientWindowSize, clientNoContext,
                                                extensionFilterProvider.encoderFilter());
        }

        @Override
        public WebSocketExtensionDecoder newExtensionDecoder() {
            return new PerMessageDeflateDecoder(serverNoContext, extensionFilterProvider.decoderFilter());
        }
    }
}
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx.extensions.compression;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter;

import java.util.List;

/**
 * Per-message deflate decompressor: a compressed message may span a first
 * text/binary frame (RSV1 set) plus any number of continuation frames.
 */
class PerMessageDeflateDecoder extends DeflateDecoder {

    // True while we are in the middle of a fragmented compressed message,
    // so that continuation frames (which carry no RSV1 bit) are accepted.
    private boolean compressing;

    /**
     * Creates a decoder that never skips any frame.
     *
     * @param noContext true to disable context takeover.
     */
    PerMessageDeflateDecoder(boolean noContext) {
        super(noContext, WebSocketExtensionFilter.NEVER_SKIP);
    }

    /**
     * Creates a decoder with a custom skip filter.
     *
     * @param noContext true to disable context takeover.
     * @param extensionDecoderFilter extension decoder for per message deflate decoder.
     */
    PerMessageDeflateDecoder(boolean noContext, WebSocketExtensionFilter extensionDecoderFilter) {
        super(noContext, extensionDecoderFilter);
    }

    @Override
    public boolean acceptInboundMessage(Object msg) throws Exception {
        if (!super.acceptInboundMessage(msg)) {
            return false;
        }

        WebSocketFrame frame = (WebSocketFrame) msg;
        if (extensionDecoderFilter().mustSkip(frame)) {
            // Skipping mid-message would desynchronize the inflater.
            if (compressing) {
                throw new IllegalStateException("Cannot skip per message deflate decoder, compression in progress");
            }
            return false;
        }

        // Either the first frame of a compressed message, or a continuation
        // of one already in progress.
        boolean startsCompressedMessage =
                (frame instanceof TextWebSocketFrame || frame instanceof BinaryWebSocketFrame)
                && (frame.rsv() & WebSocketExtension.RSV1) != 0;
        return startsCompressedMessage
                || (frame instanceof ContinuationWebSocketFrame && compressing);
    }

    @Override
    protected int newRsv(WebSocketFrame msg) {
        // Clear RSV1 if it was set; continuation frames are left untouched.
        if ((msg.rsv() & WebSocketExtension.RSV1) != 0) {
            return msg.rsv() ^ WebSocketExtension.RSV1;
        }
        return msg.rsv();
    }

    @Override
    protected boolean appendFrameTail(WebSocketFrame msg) {
        // The deflate tail is only appended once, at the end of the message.
        return msg.isFinalFragment();
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, WebSocketFrame msg,
                          List<Object> out) throws Exception {
        super.decode(ctx, msg, out);

        // Track fragmentation state across frames of the same message.
        if (msg.isFinalFragment()) {
            compressing = false;
        } else if (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame) {
            compressing = true;
        }
    }
}
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx.extensions.compression;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame;
import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame;
import io.netty.handler.codec.http.websocketx.TextWebSocketFrame;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter;

import java.util.List;

/**
 * Per-message deflate compressor: a whole message (first frame plus any
 * continuations) is compressed as one deflate stream; only the first frame
 * gets the RSV1 bit.
 */
class PerMessageDeflateEncoder extends DeflateEncoder {

    // True while a fragmented message is being compressed, so continuation
    // frames are routed through the same deflater.
    private boolean compressing;

    /**
     * Creates an encoder that never skips any frame.
     *
     * @param compressionLevel compression level of the compressor.
     * @param windowSize maximum size of the window compressor buffer.
     * @param noContext true to disable context takeover.
     */
    PerMessageDeflateEncoder(int compressionLevel, int windowSize, boolean noContext) {
        super(compressionLevel, windowSize, noContext, WebSocketExtensionFilter.NEVER_SKIP);
    }

    /**
     * Creates an encoder with a custom skip filter.
     *
     * @param compressionLevel compression level of the compressor.
     * @param windowSize maximum size of the window compressor buffer.
     * @param noContext true to disable context takeover.
     * @param extensionEncoderFilter extension filter for per message deflate encoder.
     */
    PerMessageDeflateEncoder(int compressionLevel, int windowSize, boolean noContext,
                             WebSocketExtensionFilter extensionEncoderFilter) {
        super(compressionLevel, windowSize, noContext, extensionEncoderFilter);
    }

    @Override
    public boolean acceptOutboundMessage(Object msg) throws Exception {
        if (!super.acceptOutboundMessage(msg)) {
            return false;
        }

        WebSocketFrame frame = (WebSocketFrame) msg;
        if (extensionEncoderFilter().mustSkip(frame)) {
            // Skipping mid-message would corrupt the deflate stream.
            if (compressing) {
                throw new IllegalStateException("Cannot skip per message deflate encoder, compression in progress");
            }
            return false;
        }

        // Either the start of a new, not-yet-compressed message, or a
        // continuation of one already being compressed.
        boolean startsNewMessage =
                (frame instanceof TextWebSocketFrame || frame instanceof BinaryWebSocketFrame)
                && (frame.rsv() & WebSocketExtension.RSV1) == 0;
        return startsNewMessage
                || (frame instanceof ContinuationWebSocketFrame && compressing);
    }

    @Override
    protected int rsv(WebSocketFrame msg) {
        // Only the first frame of a message carries RSV1.
        if (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame) {
            return msg.rsv() | WebSocketExtension.RSV1;
        }
        return msg.rsv();
    }

    @Override
    protected boolean removeFrameTail(WebSocketFrame msg) {
        // The deflate tail is only removed once, at the end of the message.
        return msg.isFinalFragment();
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg,
                          List<Object> out) throws Exception {
        super.encode(ctx, msg, out);

        // Track fragmentation state across frames of the same message.
        if (msg.isFinalFragment()) {
            compressing = false;
        } else if (msg instanceof TextWebSocketFrame || msg instanceof BinaryWebSocketFrame) {
            compressing = true;
        }
    }
}
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx.extensions.compression;

import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionDecoder;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionEncoder;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilterProvider;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtension;
import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandshaker;

import java.util.HashMap;
import java.util.Iterator;
import java.util.Map.Entry;

import static io.netty.util.internal.ObjectUtil.*;

/**
 * Server-side permessage-deflate (RFC 7692) handshake implementation.
 */
public final class PerMessageDeflateServerExtensionHandshaker implements WebSocketServerExtensionHandshaker {

    public static final int MIN_WINDOW_SIZE = 8;
    public static final int MAX_WINDOW_SIZE = 15;

    static final String PERMESSAGE_DEFLATE_EXTENSION = "permessage-deflate";
    static final String CLIENT_MAX_WINDOW = "client_max_window_bits";
    static final String SERVER_MAX_WINDOW = "server_max_window_bits";
    static final String CLIENT_NO_CONTEXT = "client_no_context_takeover";
    static final String SERVER_NO_CONTEXT = "server_no_context_takeover";

    private final int compressionLevel;
    private final boolean allowServerWindowSize;
    private final int preferredClientWindowSize;
    private final boolean allowServerNoContext;
    private final boolean preferredClientNoContext;
    private final WebSocketExtensionFilterProvider extensionFilterProvider;

    /**
     * Constructor with default configuration.
     */
    public PerMessageDeflateServerExtensionHandshaker() {
        this(6, false, MAX_WINDOW_SIZE, false, false);
    }

    /**
     * Constructor with custom configuration.
     *
     * @param compressionLevel
     *            Compression level between 0 and 9 (default is 6).
     * @param allowServerWindowSize
     *            allows WebSocket client to customize the server inflater window size
     *            (default is false).
     * @param preferredClientWindowSize
     *            indicates the preferred client window size to use if client inflater is customizable.
     * @param allowServerNoContext
     *            allows WebSocket client to activate server_no_context_takeover
     *            (default is false).
     * @param preferredClientNoContext
     *            indicates if server prefers to activate client_no_context_takeover
     *            if client is compatible with (default is false).
     */
    public PerMessageDeflateServerExtensionHandshaker(int compressionLevel, boolean allowServerWindowSize,
                                                      int preferredClientWindowSize,
                                                      boolean allowServerNoContext, boolean preferredClientNoContext) {
        this(compressionLevel, allowServerWindowSize, preferredClientWindowSize, allowServerNoContext,
             preferredClientNoContext, WebSocketExtensionFilterProvider.DEFAULT);
    }

    /**
     * Constructor with custom configuration.
     *
     * @param compressionLevel
     *            Compression level between 0 and 9 (default is 6).
     * @param allowServerWindowSize
     *            allows WebSocket client to customize the server inflater window size
     *            (default is false).
     * @param preferredClientWindowSize
     *            indicates the preferred client window size to use if client inflater is customizable.
     * @param allowServerNoContext
     *            allows WebSocket client to activate server_no_context_takeover
     *            (default is false).
     * @param preferredClientNoContext
     *            indicates if server prefers to activate client_no_context_takeover
     *            if client is compatible with (default is false).
     * @param extensionFilterProvider
     *            provides server extension filters for per message deflate encoder and decoder.
     * @throws IllegalArgumentException if the window size or compression level is out of range.
     */
    public PerMessageDeflateServerExtensionHandshaker(int compressionLevel, boolean allowServerWindowSize,
                                                      int preferredClientWindowSize,
                                                      boolean allowServerNoContext, boolean preferredClientNoContext,
                                                      WebSocketExtensionFilterProvider extensionFilterProvider) {
        if (preferredClientWindowSize > MAX_WINDOW_SIZE || preferredClientWindowSize < MIN_WINDOW_SIZE) {
            throw new IllegalArgumentException(
                    "preferredServerWindowSize: " + preferredClientWindowSize + " (expected: 8-15)");
        }
        if (compressionLevel < 0 || compressionLevel > 9) {
            throw new IllegalArgumentException(
                    "compressionLevel: " + compressionLevel + " (expected: 0-9)");
        }
        this.compressionLevel = compressionLevel;
        this.allowServerWindowSize = allowServerWindowSize;
        this.preferredClientWindowSize = preferredClientWindowSize;
        this.allowServerNoContext = allowServerNoContext;
        this.preferredClientNoContext = preferredClientNoContext;
        this.extensionFilterProvider = checkNotNull(extensionFilterProvider, "extensionFilterProvider");
    }

    @Override
    public WebSocketServerExtension handshakeExtension(WebSocketExtensionData extensionData) {
        if (!PERMESSAGE_DEFLATE_EXTENSION.equals(extensionData.name())) {
            return null;
        }

        boolean deflateEnabled = true;
        int clientWindowSize = MAX_WINDOW_SIZE;
        int serverWindowSize = MAX_WINDOW_SIZE;
        boolean serverNoContext = false;
        boolean clientNoContext = false;

        Iterator<Entry<String, String>> parametersIterator =
                extensionData.parameters().entrySet().iterator();
        while (deflateEnabled && parametersIterator.hasNext()) {
            Entry<String, String> parameter = parametersIterator.next();

            if (CLIENT_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) {
                // use preferred clientWindowSize because client is compatible with customization
                clientWindowSize = preferredClientWindowSize;
            } else if (SERVER_MAX_WINDOW.equalsIgnoreCase(parameter.getKey())) {
                // use provided windowSize if it is allowed
                if (allowServerWindowSize) {
                    serverWindowSize = Integer.parseInt(parameter.getValue());
                    if (serverWindowSize > MAX_WINDOW_SIZE || serverWindowSize < MIN_WINDOW_SIZE) {
                        deflateEnabled = false;
                    }
                } else {
                    deflateEnabled = false;
                }
            } else if (CLIENT_NO_CONTEXT.equalsIgnoreCase(parameter.getKey())) {
                // use preferred clientNoContext because client is compatible with customization
                clientNoContext = preferredClientNoContext;
            } else if (SERVER_NO_CONTEXT.equalsIgnoreCase(parameter.getKey())) {
                // use server no context if allowed
                if (allowServerNoContext) {
                    serverNoContext = true;
                } else {
                    deflateEnabled = false;
                }
            } else {
                // unknown parameter
                deflateEnabled = false;
            }
        }

        if (deflateEnabled) {
            return new PermessageDeflateExtension(compressionLevel, serverNoContext,
                    serverWindowSize, clientNoContext, clientWindowSize, extensionFilterProvider);
        } else {
            return null;
        }
    }

    private static class PermessageDeflateExtension implements WebSocketServerExtension {

        private final int compressionLevel;
        private final boolean serverNoContext;
        private final int serverWindowSize;
        private final boolean clientNoContext;
        private final int clientWindowSize;
        private final WebSocketExtensionFilterProvider extensionFilterProvider;

        PermessageDeflateExtension(int compressionLevel, boolean serverNoContext,
                int serverWindowSize, boolean clientNoContext, int clientWindowSize,
                WebSocketExtensionFilterProvider extensionFilterProvider) {
            this.compressionLevel = compressionLevel;
            this.serverNoContext = serverNoContext;
            this.serverWindowSize = serverWindowSize;
            this.clientNoContext = clientNoContext;
            this.clientWindowSize = clientWindowSize;
            this.extensionFilterProvider = extensionFilterProvider;
        }

        @Override
        public int rsv() {
            return RSV1;
        }

        @Override
        public WebSocketExtensionEncoder newExtensionEncoder() {
            return new PerMessageDeflateEncoder(compressionLevel, serverWindowSize, serverNoContext,
                                                extensionFilterProvider.encoderFilter());
        }

        @Override
        public WebSocketExtensionDecoder newExtensionDecoder() {
            return new PerMessageDeflateDecoder(clientNoContext, extensionFilterProvider.decoderFilter());
        }

        // NOTE: "newReponseData" (sic) is the method name declared on the
        // WebSocketServerExtension interface, so the spelling must match.
        @Override
        public WebSocketExtensionData newReponseData() {
            HashMap<String, String> parameters = new HashMap<String, String>(4);
            if (serverNoContext) {
                parameters.put(SERVER_NO_CONTEXT, null);
            }
            if (clientNoContext) {
                parameters.put(CLIENT_NO_CONTEXT, null);
            }
            if (serverWindowSize != MAX_WINDOW_SIZE) {
                parameters.put(SERVER_MAX_WINDOW, Integer.toString(serverWindowSize));
            }
            if (clientWindowSize != MAX_WINDOW_SIZE) {
                parameters.put(CLIENT_MAX_WINDOW, Integer.toString(clientWindowSize));
            }
            return new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters);
        }
    }
}
+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.channel.ChannelHandler; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtensionHandler; + +/** + * Extends io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketClientExtensionHandler + * to handle the most common WebSocket Compression Extensions. + * + * See io.netty.example.http.websocketx.client.WebSocketClient for usage. + */ +@ChannelHandler.Sharable +public final class WebSocketClientCompressionHandler extends WebSocketClientExtensionHandler { + + public static final WebSocketClientCompressionHandler INSTANCE = new WebSocketClientCompressionHandler(); + + private WebSocketClientCompressionHandler() { + super(new PerMessageDeflateClientExtensionHandshaker(), + new DeflateFrameClientExtensionHandshaker(false), + new DeflateFrameClientExtensionHandshaker(true)); + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandler.java new file mode 100644 index 0000000..ff07d88 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandler.java @@ -0,0 +1,36 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandler; + +/** + * Extends io.netty.handler.codec.http.websocketx.extensions.compression.WebSocketServerExtensionHandler + * to handle the most common WebSocket Compression Extensions. + * + * See io.netty.example.http.websocketx.html5.WebSocketServer for usage. + */ +public class WebSocketServerCompressionHandler extends WebSocketServerExtensionHandler { + + /** + * Constructor with default configuration. + */ + public WebSocketServerCompressionHandler() { + super(new PerMessageDeflateServerExtensionHandshaker(), + new DeflateFrameServerExtensionHandshaker()); + } + +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/package-info.java new file mode 100644 index 0000000..ca029c1 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/compression/package-info.java @@ -0,0 +1,33 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +/** + * Encoder, decoder, handshakers to handle most common WebSocket Compression Extensions. + *

+ * This package supports different web socket extensions. + * The specification currently supported are: + *

+ *

+ *

+ * See io.netty.example.http.websocketx.client.WebSocketClient and + * io.netty.example.http.websocketx.html5.WebSocketServer for usage. + *

+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/package-info.java new file mode 100644 index 0000000..89cef75 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/extensions/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder, decoder, handshakers to handle + * WebSocket Extensions. + * + * See WebSocketServerExtensionHandler for more details. + */ +package io.netty.handler.codec.http.websocketx.extensions; diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/package-info.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/package-info.java new file mode 100644 index 0000000..6a20336 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/http/websocketx/package-info.java @@ -0,0 +1,39 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder, decoder, handshakers and their related message types for + * Web Socket data frames. + *

+ * This package supports different web socket specification versions (hence the X suffix). + * The specification current supported are: + *

+ *

+ *

+ * For the detailed instruction on adding add Web Socket support to your HTTP + * server, take a look into the WebSocketServerX example located in the + * {@code io.netty.example.http.websocket} package. + *

+ */ +package io.netty.handler.codec.http.websocketx; + diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspDecoder.java new file mode 100644 index 0000000..8feb531 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspDecoder.java @@ -0,0 +1,181 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import java.util.regex.Pattern; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.HttpDecoderConfig; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpObjectDecoder; +import io.netty.handler.codec.http.HttpResponseStatus; + +/** + * Decodes {@link io.netty.buffer.ByteBuf}s into RTSP messages represented in + * {@link HttpMessage}s. + *

+ *

 * <h3>Parameters that prevent excessive memory consumption</h3>
 * <table border="1">
 * <tr>
 * <th>Name</th><th>Meaning</th>
 * </tr>
 * <tr>
 * <td>{@code maxInitialLineLength}</td>
 * <td>The maximum length of the initial line
 *     (e.g. {@code "SETUP / RTSP/1.0"} or {@code "RTSP/1.0 200 OK"}).
 *     If the length of the initial line exceeds this value, a
 *     {@link io.netty.handler.codec.TooLongFrameException} will be raised.</td>
 * </tr>
 * <tr>
 * <td>{@code maxHeaderSize}</td>
 * <td>The maximum length of all headers. If the sum of the length of each
 *     header exceeds this value, a
 *     {@link io.netty.handler.codec.TooLongFrameException} will be raised.</td>
 * </tr>
 * <tr>
 * <td>{@code maxContentLength}</td>
 * <td>The maximum length of the content. If the content length exceeds this
 *     value, a {@link io.netty.handler.codec.TooLongFrameException} will be
 *     raised.</td>
 * </tr>
 * </table>
+ */ +public class RtspDecoder extends HttpObjectDecoder { + /** + * Status code for unknown responses. + */ + private static final HttpResponseStatus UNKNOWN_STATUS = + new HttpResponseStatus(999, "Unknown"); + /** + * True if the message to decode is a request. + * False if the message to decode is a response. + */ + private boolean isDecodingRequest; + + /** + * Regex used on first line in message to detect if it is a response. + */ + private static final Pattern versionPattern = Pattern.compile("RTSP/\\d\\.\\d"); + + /** + * Constant for default max content length. + */ + public static final int DEFAULT_MAX_CONTENT_LENGTH = 8192; + + /** + * Creates a new instance with the default + * {@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxContentLength (8192)}. + */ + public RtspDecoder() { + this(DEFAULT_MAX_INITIAL_LINE_LENGTH, + DEFAULT_MAX_HEADER_SIZE, + DEFAULT_MAX_CONTENT_LENGTH); + } + + /** + * Creates a new instance with the specified parameters. + * @param maxInitialLineLength The max allowed length of initial line + * @param maxHeaderSize The max allowed size of header + * @param maxContentLength The max allowed content length + */ + public RtspDecoder(final int maxInitialLineLength, + final int maxHeaderSize, + final int maxContentLength) { + super(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxContentLength * 2) + .setChunkedSupported(false)); + } + + /** + * Creates a new instance with the specified parameters. + * @param maxInitialLineLength The max allowed length of initial line + * @param maxHeaderSize The max allowed size of header + * @param maxContentLength The max allowed content length + * @param validateHeaders Set to true if headers should be validated + * @deprecated Use the {@link #RtspDecoder(HttpDecoderConfig)} constructor instead, + * or the {@link #RtspDecoder(int, int, int)} to always enable header validation. 
+ */ + @Deprecated + public RtspDecoder(final int maxInitialLineLength, + final int maxHeaderSize, + final int maxContentLength, + final boolean validateHeaders) { + super(new HttpDecoderConfig() + .setMaxInitialLineLength(maxInitialLineLength) + .setMaxHeaderSize(maxHeaderSize) + .setMaxChunkSize(maxContentLength * 2) + .setChunkedSupported(false) + .setValidateHeaders(validateHeaders)); + } + + /** + * Creates a new instance with the specified configuration. + */ + public RtspDecoder(HttpDecoderConfig config) { + super(config.clone() + .setMaxChunkSize(2 * config.getMaxChunkSize()) + .setChunkedSupported(false)); + } + + @Override + protected HttpMessage createMessage(final String[] initialLine) + throws Exception { + // If the first element of the initial line is a version string then + // this is a response + if (versionPattern.matcher(initialLine[0]).matches()) { + isDecodingRequest = false; + return new DefaultHttpResponse(RtspVersions.valueOf(initialLine[0]), + new HttpResponseStatus(Integer.parseInt(initialLine[1]), + initialLine[2]), + headersFactory); + } else { + isDecodingRequest = true; + return new DefaultHttpRequest(RtspVersions.valueOf(initialLine[2]), + RtspMethods.valueOf(initialLine[0]), + initialLine[1], + headersFactory); + } + } + + @Override + protected boolean isContentAlwaysEmpty(final HttpMessage msg) { + // Unlike HTTP, RTSP always assumes zero-length body if Content-Length + // header is absent. 
+ return super.isContentAlwaysEmpty(msg) || !msg.headers().contains(RtspHeaderNames.CONTENT_LENGTH); + } + + @Override + protected HttpMessage createInvalidMessage() { + if (isDecodingRequest) { + return new DefaultFullHttpRequest(RtspVersions.RTSP_1_0, + RtspMethods.OPTIONS, "/bad-request", Unpooled.buffer(0), headersFactory, trailersFactory); + } else { + return new DefaultFullHttpResponse( + RtspVersions.RTSP_1_0, UNKNOWN_STATUS, Unpooled.buffer(0), headersFactory, trailersFactory); + } + } + + @Override + protected boolean isDecodingRequest() { + return isDecodingRequest; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspEncoder.java new file mode 100644 index 0000000..20987be --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspEncoder.java @@ -0,0 +1,68 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.rtsp; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.handler.codec.UnsupportedMessageTypeException; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpObjectEncoder; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.StringUtil; + +import static io.netty.handler.codec.http.HttpConstants.*; + +/** + * Encodes an RTSP message represented in {@link HttpMessage} or an {@link HttpContent} into + * a {@link ByteBuf}. + */ +public class RtspEncoder extends HttpObjectEncoder { + private static final int CRLF_SHORT = (CR << 8) | LF; + + @Override + public boolean acceptOutboundMessage(final Object msg) + throws Exception { + return super.acceptOutboundMessage(msg) && ((msg instanceof HttpRequest) || (msg instanceof HttpResponse)); + } + + @Override + protected void encodeInitialLine(final ByteBuf buf, final HttpMessage message) + throws Exception { + if (message instanceof HttpRequest) { + HttpRequest request = (HttpRequest) message; + ByteBufUtil.copy(request.method().asciiName(), buf); + buf.writeByte(SP); + buf.writeCharSequence(request.uri(), CharsetUtil.UTF_8); + buf.writeByte(SP); + buf.writeCharSequence(request.protocolVersion().toString(), CharsetUtil.US_ASCII); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + } else if (message instanceof HttpResponse) { + HttpResponse response = (HttpResponse) message; + buf.writeCharSequence(response.protocolVersion().toString(), CharsetUtil.US_ASCII); + buf.writeByte(SP); + ByteBufUtil.copy(response.status().codeAsText(), buf); + buf.writeByte(SP); + buf.writeCharSequence(response.status().reasonPhrase(), CharsetUtil.US_ASCII); + ByteBufUtil.writeShortBE(buf, CRLF_SHORT); + } else { + throw new UnsupportedMessageTypeException("Unsupported type " + + 
StringUtil.simpleClassName(message)); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderNames.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderNames.java new file mode 100644 index 0000000..42154c6 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderNames.java @@ -0,0 +1,207 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.rtsp; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.util.AsciiString; + +/** + * Standard RTSP header names. + *

+ * These are all defined as lowercase to support HTTP/2 requirements while also not + * violating RTSP/1.x requirements. New header names should always be lowercase. + */ +public final class RtspHeaderNames { + /** + * {@code "accept"} + */ + public static final AsciiString ACCEPT = HttpHeaderNames.ACCEPT; + /** + * {@code "accept-encoding"} + */ + public static final AsciiString ACCEPT_ENCODING = HttpHeaderNames.ACCEPT_ENCODING; + /** + * {@code "accept-language"} + */ + public static final AsciiString ACCEPT_LANGUAGE = HttpHeaderNames.ACCEPT_LANGUAGE; + /** + * {@code "allow"} + */ + public static final AsciiString ALLOW = AsciiString.cached("allow"); + /** + * {@code "authorization"} + */ + public static final AsciiString AUTHORIZATION = HttpHeaderNames.AUTHORIZATION; + /** + * {@code "bandwidth"} + */ + public static final AsciiString BANDWIDTH = AsciiString.cached("bandwidth"); + /** + * {@code "blocksize"} + */ + public static final AsciiString BLOCKSIZE = AsciiString.cached("blocksize"); + /** + * {@code "cache-control"} + */ + public static final AsciiString CACHE_CONTROL = HttpHeaderNames.CACHE_CONTROL; + /** + * {@code "conference"} + */ + public static final AsciiString CONFERENCE = AsciiString.cached("conference"); + /** + * {@code "connection"} + */ + public static final AsciiString CONNECTION = HttpHeaderNames.CONNECTION; + /** + * {@code "content-base"} + */ + public static final AsciiString CONTENT_BASE = HttpHeaderNames.CONTENT_BASE; + /** + * {@code "content-encoding"} + */ + public static final AsciiString CONTENT_ENCODING = HttpHeaderNames.CONTENT_ENCODING; + /** + * {@code "content-language"} + */ + public static final AsciiString CONTENT_LANGUAGE = HttpHeaderNames.CONTENT_LANGUAGE; + /** + * {@code "content-length"} + */ + public static final AsciiString CONTENT_LENGTH = HttpHeaderNames.CONTENT_LENGTH; + /** + * {@code "content-location"} + */ + public static final AsciiString CONTENT_LOCATION = HttpHeaderNames.CONTENT_LOCATION; + /** + * 
{@code "content-type"} + */ + public static final AsciiString CONTENT_TYPE = HttpHeaderNames.CONTENT_TYPE; + /** + * {@code "cseq"} + */ + public static final AsciiString CSEQ = AsciiString.cached("cseq"); + /** + * {@code "date"} + */ + public static final AsciiString DATE = HttpHeaderNames.DATE; + /** + * {@code "expires"} + */ + public static final AsciiString EXPIRES = HttpHeaderNames.EXPIRES; + /** + * {@code "from"} + */ + public static final AsciiString FROM = HttpHeaderNames.FROM; + /** + * {@code "host"} + */ + public static final AsciiString HOST = HttpHeaderNames.HOST; + /** + * {@code "if-match"} + */ + public static final AsciiString IF_MATCH = HttpHeaderNames.IF_MATCH; + /** + * {@code "if-modified-since"} + */ + public static final AsciiString IF_MODIFIED_SINCE = HttpHeaderNames.IF_MODIFIED_SINCE; + /** + * {@code "keymgmt"} + */ + public static final AsciiString KEYMGMT = AsciiString.cached("keymgmt"); + /** + * {@code "last-modified"} + */ + public static final AsciiString LAST_MODIFIED = HttpHeaderNames.LAST_MODIFIED; + /** + * {@code "proxy-authenticate"} + */ + public static final AsciiString PROXY_AUTHENTICATE = HttpHeaderNames.PROXY_AUTHENTICATE; + /** + * {@code "proxy-require"} + */ + public static final AsciiString PROXY_REQUIRE = AsciiString.cached("proxy-require"); + /** + * {@code "public"} + */ + public static final AsciiString PUBLIC = AsciiString.cached("public"); + /** + * {@code "range"} + */ + public static final AsciiString RANGE = HttpHeaderNames.RANGE; + /** + * {@code "referer"} + */ + public static final AsciiString REFERER = HttpHeaderNames.REFERER; + /** + * {@code "require"} + */ + public static final AsciiString REQUIRE = AsciiString.cached("require"); + /** + * {@code "retry-after"} + */ + public static final AsciiString RETRT_AFTER = HttpHeaderNames.RETRY_AFTER; + /** + * {@code "rtp-info"} + */ + public static final AsciiString RTP_INFO = AsciiString.cached("rtp-info"); + /** + * {@code "scale"} + */ + public static 
final AsciiString SCALE = AsciiString.cached("scale"); + /** + * {@code "session"} + */ + public static final AsciiString SESSION = AsciiString.cached("session"); + /** + * {@code "server"} + */ + public static final AsciiString SERVER = HttpHeaderNames.SERVER; + /** + * {@code "speed"} + */ + public static final AsciiString SPEED = AsciiString.cached("speed"); + /** + * {@code "timestamp"} + */ + public static final AsciiString TIMESTAMP = AsciiString.cached("timestamp"); + /** + * {@code "transport"} + */ + public static final AsciiString TRANSPORT = AsciiString.cached("transport"); + /** + * {@code "unsupported"} + */ + public static final AsciiString UNSUPPORTED = AsciiString.cached("unsupported"); + /** + * {@code "user-agent"} + */ + public static final AsciiString USER_AGENT = HttpHeaderNames.USER_AGENT; + /** + * {@code "vary"} + */ + public static final AsciiString VARY = HttpHeaderNames.VARY; + /** + * {@code "via"} + */ + public static final AsciiString VIA = HttpHeaderNames.VIA; + /** + * {@code "www-authenticate"} + */ + public static final AsciiString WWW_AUTHENTICATE = HttpHeaderNames.WWW_AUTHENTICATE; + + private RtspHeaderNames() { } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderValues.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderValues.java new file mode 100644 index 0000000..8c50276 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaderValues.java @@ -0,0 +1,196 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.rtsp; + +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.util.AsciiString; + +/** + * Standard RTSP header names. + */ +public final class RtspHeaderValues { + /** + * {@code "append"} + */ + public static final AsciiString APPEND = AsciiString.cached("append"); + /** + * {@code "AVP"} + */ + public static final AsciiString AVP = AsciiString.cached("AVP"); + /** + * {@code "bytes"} + */ + public static final AsciiString BYTES = HttpHeaderValues.BYTES; + /** + * {@code "charset"} + */ + public static final AsciiString CHARSET = HttpHeaderValues.CHARSET; + /** + * {@code "client_port"} + */ + public static final AsciiString CLIENT_PORT = AsciiString.cached("client_port"); + /** + * {@code "clock"} + */ + public static final AsciiString CLOCK = AsciiString.cached("clock"); + /** + * {@code "close"} + */ + public static final AsciiString CLOSE = HttpHeaderValues.CLOSE; + /** + * {@code "compress"} + */ + public static final AsciiString COMPRESS = HttpHeaderValues.COMPRESS; + /** + * {@code "100-continue"} + */ + public static final AsciiString CONTINUE = HttpHeaderValues.CONTINUE; + /** + * {@code "deflate"} + */ + public static final AsciiString DEFLATE = HttpHeaderValues.DEFLATE; + /** + * {@code "destination"} + */ + public static final AsciiString DESTINATION = AsciiString.cached("destination"); + /** + * {@code "gzip"} + */ + public static final AsciiString GZIP = HttpHeaderValues.GZIP; + /** + * {@code "identity"} + */ + public static final AsciiString IDENTITY = 
HttpHeaderValues.IDENTITY; + /** + * {@code "interleaved"} + */ + public static final AsciiString INTERLEAVED = AsciiString.cached("interleaved"); + /** + * {@code "keep-alive"} + */ + public static final AsciiString KEEP_ALIVE = HttpHeaderValues.KEEP_ALIVE; + /** + * {@code "layers"} + */ + public static final AsciiString LAYERS = AsciiString.cached("layers"); + /** + * {@code "max-age"} + */ + public static final AsciiString MAX_AGE = HttpHeaderValues.MAX_AGE; + /** + * {@code "max-stale"} + */ + public static final AsciiString MAX_STALE = HttpHeaderValues.MAX_STALE; + /** + * {@code "min-fresh"} + */ + public static final AsciiString MIN_FRESH = HttpHeaderValues.MIN_FRESH; + /** + * {@code "mode"} + */ + public static final AsciiString MODE = AsciiString.cached("mode"); + /** + * {@code "multicast"} + */ + public static final AsciiString MULTICAST = AsciiString.cached("multicast"); + /** + * {@code "must-revalidate"} + */ + public static final AsciiString MUST_REVALIDATE = HttpHeaderValues.MUST_REVALIDATE; + /** + * {@code "none"} + */ + public static final AsciiString NONE = HttpHeaderValues.NONE; + /** + * {@code "no-cache"} + */ + public static final AsciiString NO_CACHE = HttpHeaderValues.NO_CACHE; + /** + * {@code "no-transform"} + */ + public static final AsciiString NO_TRANSFORM = HttpHeaderValues.NO_TRANSFORM; + /** + * {@code "only-if-cached"} + */ + public static final AsciiString ONLY_IF_CACHED = HttpHeaderValues.ONLY_IF_CACHED; + /** + * {@code "port"} + */ + public static final AsciiString PORT = AsciiString.cached("port"); + /** + * {@code "private"} + */ + public static final AsciiString PRIVATE = HttpHeaderValues.PRIVATE; + /** + * {@code "proxy-revalidate"} + */ + public static final AsciiString PROXY_REVALIDATE = HttpHeaderValues.PROXY_REVALIDATE; + /** + * {@code "public"} + */ + public static final AsciiString PUBLIC = HttpHeaderValues.PUBLIC; + /** + * {@code "RTP"} + */ + public static final AsciiString RTP = AsciiString.cached("RTP"); + 
/** + * {@code "rtptime"} + */ + public static final AsciiString RTPTIME = AsciiString.cached("rtptime"); + /** + * {@code "seq"} + */ + public static final AsciiString SEQ = AsciiString.cached("seq"); + /** + * {@code "server_port"} + */ + public static final AsciiString SERVER_PORT = AsciiString.cached("server_port"); + /** + * {@code "ssrc"} + */ + public static final AsciiString SSRC = AsciiString.cached("ssrc"); + /** + * {@code "TCP"} + */ + public static final AsciiString TCP = AsciiString.cached("TCP"); + /** + * {@code "time"} + */ + public static final AsciiString TIME = AsciiString.cached("time"); + /** + * {@code "timeout"} + */ + public static final AsciiString TIMEOUT = AsciiString.cached("timeout"); + /** + * {@code "ttl"} + */ + public static final AsciiString TTL = AsciiString.cached("ttl"); + /** + * {@code "UDP"} + */ + public static final AsciiString UDP = AsciiString.cached("UDP"); + /** + * {@code "unicast"} + */ + public static final AsciiString UNICAST = AsciiString.cached("unicast"); + /** + * {@code "url"} + */ + public static final AsciiString URL = AsciiString.cached("url"); + + private RtspHeaderValues() { } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaders.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaders.java new file mode 100644 index 0000000..3abc012 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspHeaders.java @@ -0,0 +1,398 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import io.netty.handler.codec.http.HttpHeaders; + + +/** + * @deprecated Use {@link RtspHeaderNames} or {@link RtspHeaderValues} instead. + + * Standard RTSP header names and values. + */ +@Deprecated +@SuppressWarnings("deprecation") +public final class RtspHeaders { + + /** + * @deprecated Use {@link RtspHeaderNames} instead. + * + * Standard RTSP header names. + */ + @Deprecated + public static final class Names { + /** + * {@code "Accept"} + */ + public static final String ACCEPT = HttpHeaders.Names.ACCEPT; + /** + * {@code "Accept-Encoding"} + */ + public static final String ACCEPT_ENCODING = HttpHeaders.Names.ACCEPT_ENCODING; + /** + * {@code "Accept-Language"} + */ + public static final String ACCEPT_LANGUAGE = HttpHeaders.Names.ACCEPT_LANGUAGE; + /** + * {@code "Allow"} + */ + public static final String ALLOW = "Allow"; + /** + * {@code "Authorization"} + */ + public static final String AUTHORIZATION = HttpHeaders.Names.AUTHORIZATION; + /** + * {@code "Bandwidth"} + */ + public static final String BANDWIDTH = "Bandwidth"; + /** + * {@code "Blocksize"} + */ + public static final String BLOCKSIZE = "Blocksize"; + /** + * {@code "Cache-Control"} + */ + public static final String CACHE_CONTROL = HttpHeaders.Names.CACHE_CONTROL; + /** + * {@code "Conference"} + */ + public static final String CONFERENCE = "Conference"; + /** + * {@code "Connection"} + */ + public static final String CONNECTION = HttpHeaders.Names.CONNECTION; + /** + * {@code "Content-Base"} + */ + public static final String 
CONTENT_BASE = HttpHeaders.Names.CONTENT_BASE; + /** + * {@code "Content-Encoding"} + */ + public static final String CONTENT_ENCODING = HttpHeaders.Names.CONTENT_ENCODING; + /** + * {@code "Content-Language"} + */ + public static final String CONTENT_LANGUAGE = HttpHeaders.Names.CONTENT_LANGUAGE; + /** + * {@code "Content-Length"} + */ + public static final String CONTENT_LENGTH = HttpHeaders.Names.CONTENT_LENGTH; + /** + * {@code "Content-Location"} + */ + public static final String CONTENT_LOCATION = HttpHeaders.Names.CONTENT_LOCATION; + /** + * {@code "Content-Type"} + */ + public static final String CONTENT_TYPE = HttpHeaders.Names.CONTENT_TYPE; + /** + * {@code "CSeq"} + */ + public static final String CSEQ = "CSeq"; + /** + * {@code "Date"} + */ + public static final String DATE = HttpHeaders.Names.DATE; + /** + * {@code "Expires"} + */ + public static final String EXPIRES = HttpHeaders.Names.EXPIRES; + /** + * {@code "From"} + */ + public static final String FROM = HttpHeaders.Names.FROM; + /** + * {@code "Host"} + */ + public static final String HOST = HttpHeaders.Names.HOST; + /** + * {@code "If-Match"} + */ + public static final String IF_MATCH = HttpHeaders.Names.IF_MATCH; + /** + * {@code "If-Modified-Since"} + */ + public static final String IF_MODIFIED_SINCE = HttpHeaders.Names.IF_MODIFIED_SINCE; + /** + * {@code "KeyMgmt"} + */ + public static final String KEYMGMT = "KeyMgmt"; + /** + * {@code "Last-Modified"} + */ + public static final String LAST_MODIFIED = HttpHeaders.Names.LAST_MODIFIED; + /** + * {@code "Proxy-Authenticate"} + */ + public static final String PROXY_AUTHENTICATE = HttpHeaders.Names.PROXY_AUTHENTICATE; + /** + * {@code "Proxy-Require"} + */ + public static final String PROXY_REQUIRE = "Proxy-Require"; + /** + * {@code "Public"} + */ + public static final String PUBLIC = "Public"; + /** + * {@code "Range"} + */ + public static final String RANGE = HttpHeaders.Names.RANGE; + /** + * {@code "Referer"} + */ + public static final 
String REFERER = HttpHeaders.Names.REFERER; + /** + * {@code "Require"} + */ + public static final String REQUIRE = "Require"; + /** + * {@code "Retry-After"} + */ + public static final String RETRT_AFTER = HttpHeaders.Names.RETRY_AFTER; + /** + * {@code "RTP-Info"} + */ + public static final String RTP_INFO = "RTP-Info"; + /** + * {@code "Scale"} + */ + public static final String SCALE = "Scale"; + /** + * {@code "Session"} + */ + public static final String SESSION = "Session"; + /** + * {@code "Server"} + */ + public static final String SERVER = HttpHeaders.Names.SERVER; + /** + * {@code "Speed"} + */ + public static final String SPEED = "Speed"; + /** + * {@code "Timestamp"} + */ + public static final String TIMESTAMP = "Timestamp"; + /** + * {@code "Transport"} + */ + public static final String TRANSPORT = "Transport"; + /** + * {@code "Unsupported"} + */ + public static final String UNSUPPORTED = "Unsupported"; + /** + * {@code "User-Agent"} + */ + public static final String USER_AGENT = HttpHeaders.Names.USER_AGENT; + /** + * {@code "Vary"} + */ + public static final String VARY = HttpHeaders.Names.VARY; + /** + * {@code "Via"} + */ + public static final String VIA = HttpHeaders.Names.VIA; + /** + * {@code "WWW-Authenticate"} + */ + public static final String WWW_AUTHENTICATE = HttpHeaders.Names.WWW_AUTHENTICATE; + + private Names() { + } + } + + /** + * @deprecated Use {@link RtspHeaderValues} instead. + * + * Standard RTSP header values. 
+ */ + @Deprecated + public static final class Values { + /** + * {@code "append"} + */ + public static final String APPEND = "append"; + /** + * {@code "AVP"} + */ + public static final String AVP = "AVP"; + /** + * {@code "bytes"} + */ + public static final String BYTES = HttpHeaders.Values.BYTES; + /** + * {@code "charset"} + */ + public static final String CHARSET = HttpHeaders.Values.CHARSET; + /** + * {@code "client_port"} + */ + public static final String CLIENT_PORT = "client_port"; + /** + * {@code "clock"} + */ + public static final String CLOCK = "clock"; + /** + * {@code "close"} + */ + public static final String CLOSE = HttpHeaders.Values.CLOSE; + /** + * {@code "compress"} + */ + public static final String COMPRESS = HttpHeaders.Values.COMPRESS; + /** + * {@code "100-continue"} + */ + public static final String CONTINUE = HttpHeaders.Values.CONTINUE; + /** + * {@code "deflate"} + */ + public static final String DEFLATE = HttpHeaders.Values.DEFLATE; + /** + * {@code "destination"} + */ + public static final String DESTINATION = "destination"; + /** + * {@code "gzip"} + */ + public static final String GZIP = HttpHeaders.Values.GZIP; + /** + * {@code "identity"} + */ + public static final String IDENTITY = HttpHeaders.Values.IDENTITY; + /** + * {@code "interleaved"} + */ + public static final String INTERLEAVED = "interleaved"; + /** + * {@code "keep-alive"} + */ + public static final String KEEP_ALIVE = HttpHeaders.Values.KEEP_ALIVE; + /** + * {@code "layers"} + */ + public static final String LAYERS = "layers"; + /** + * {@code "max-age"} + */ + public static final String MAX_AGE = HttpHeaders.Values.MAX_AGE; + /** + * {@code "max-stale"} + */ + public static final String MAX_STALE = HttpHeaders.Values.MAX_STALE; + /** + * {@code "min-fresh"} + */ + public static final String MIN_FRESH = HttpHeaders.Values.MIN_FRESH; + /** + * {@code "mode"} + */ + public static final String MODE = "mode"; + /** + * {@code "multicast"} + */ + public static final String 
MULTICAST = "multicast"; + /** + * {@code "must-revalidate"} + */ + public static final String MUST_REVALIDATE = HttpHeaders.Values.MUST_REVALIDATE; + /** + * {@code "none"} + */ + public static final String NONE = HttpHeaders.Values.NONE; + /** + * {@code "no-cache"} + */ + public static final String NO_CACHE = HttpHeaders.Values.NO_CACHE; + /** + * {@code "no-transform"} + */ + public static final String NO_TRANSFORM = HttpHeaders.Values.NO_TRANSFORM; + /** + * {@code "only-if-cached"} + */ + public static final String ONLY_IF_CACHED = HttpHeaders.Values.ONLY_IF_CACHED; + /** + * {@code "port"} + */ + public static final String PORT = "port"; + /** + * {@code "private"} + */ + public static final String PRIVATE = HttpHeaders.Values.PRIVATE; + /** + * {@code "proxy-revalidate"} + */ + public static final String PROXY_REVALIDATE = HttpHeaders.Values.PROXY_REVALIDATE; + /** + * {@code "public"} + */ + public static final String PUBLIC = HttpHeaders.Values.PUBLIC; + /** + * {@code "RTP"} + */ + public static final String RTP = "RTP"; + /** + * {@code "rtptime"} + */ + public static final String RTPTIME = "rtptime"; + /** + * {@code "seq"} + */ + public static final String SEQ = "seq"; + /** + * {@code "server_port"} + */ + public static final String SERVER_PORT = "server_port"; + /** + * {@code "ssrc"} + */ + public static final String SSRC = "ssrc"; + /** + * {@code "TCP"} + */ + public static final String TCP = "TCP"; + /** + * {@code "time"} + */ + public static final String TIME = "time"; + /** + * {@code "timeout"} + */ + public static final String TIMEOUT = "timeout"; + /** + * {@code "ttl"} + */ + public static final String TTL = "ttl"; + /** + * {@code "UDP"} + */ + public static final String UDP = "UDP"; + /** + * {@code "unicast"} + */ + public static final String UNICAST = "unicast"; + /** + * {@code "url"} + */ + public static final String URL = "url"; + + private Values() { } + } + + private RtspHeaders() { } +} diff --git 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.rtsp;

import static io.netty.util.internal.ObjectUtil.checkNonEmptyAfterTrim;

import io.netty.handler.codec.http.HttpMethod;

import java.util.HashMap;
import java.util.Locale;
import java.util.Map;

/**
 * The request methods of RTSP.
 */
public final class RtspMethods {

    /**
     * The OPTIONS method represents a request for information about the communication options
     * available on the request/response chain identified by the Request-URI. This method allows
     * the client to determine the options and/or requirements associated with a resource, or the
     * capabilities of a server, without implying a resource action or initiating a resource
     * retrieval.
     */
    public static final HttpMethod OPTIONS = HttpMethod.OPTIONS;

    /**
     * The DESCRIBE method retrieves the description of a presentation or
     * media object identified by the request URL from a server.
     */
    public static final HttpMethod DESCRIBE = HttpMethod.valueOf("DESCRIBE");

    /**
     * The ANNOUNCE posts the description of a presentation or media object
     * identified by the request URL to a server, or updates the client-side
     * session description in real-time.
     */
    public static final HttpMethod ANNOUNCE = HttpMethod.valueOf("ANNOUNCE");

    /**
     * The SETUP request for a URI specifies the transport mechanism to be
     * used for the streamed media.
     */
    public static final HttpMethod SETUP = HttpMethod.valueOf("SETUP");

    /**
     * The PLAY method tells the server to start sending data via the
     * mechanism specified in SETUP.
     */
    public static final HttpMethod PLAY = HttpMethod.valueOf("PLAY");

    /**
     * The PAUSE request causes the stream delivery to be interrupted
     * (halted) temporarily.
     */
    public static final HttpMethod PAUSE = HttpMethod.valueOf("PAUSE");

    /**
     * The TEARDOWN request stops the stream delivery for the given URI,
     * freeing the resources associated with it.
     */
    public static final HttpMethod TEARDOWN = HttpMethod.valueOf("TEARDOWN");

    /**
     * The GET_PARAMETER request retrieves the value of a parameter of a
     * presentation or stream specified in the URI.
     */
    public static final HttpMethod GET_PARAMETER = HttpMethod.valueOf("GET_PARAMETER");

    /**
     * The SET_PARAMETER requests to set the value of a parameter for a
     * presentation or stream specified by the URI.
     */
    public static final HttpMethod SET_PARAMETER = HttpMethod.valueOf("SET_PARAMETER");

    /**
     * The REDIRECT request informs the client that it must connect to another
     * server location.
     */
    public static final HttpMethod REDIRECT = HttpMethod.valueOf("REDIRECT");

    /**
     * The RECORD method initiates recording a range of media data according to
     * the presentation description.
     */
    public static final HttpMethod RECORD = HttpMethod.valueOf("RECORD");

    // Cache of the standard RTSP methods, keyed by their upper-case name.
    private static final Map<String, HttpMethod> methodMap = new HashMap<String, HttpMethod>();

    static {
        methodMap.put(DESCRIBE.toString(), DESCRIBE);
        methodMap.put(ANNOUNCE.toString(), ANNOUNCE);
        methodMap.put(GET_PARAMETER.toString(), GET_PARAMETER);
        methodMap.put(OPTIONS.toString(), OPTIONS);
        methodMap.put(PAUSE.toString(), PAUSE);
        methodMap.put(PLAY.toString(), PLAY);
        methodMap.put(RECORD.toString(), RECORD);
        methodMap.put(REDIRECT.toString(), REDIRECT);
        methodMap.put(SETUP.toString(), SETUP);
        methodMap.put(SET_PARAMETER.toString(), SET_PARAMETER);
        methodMap.put(TEARDOWN.toString(), TEARDOWN);
    }

    /**
     * Returns the {@link HttpMethod} represented by the specified name.
     * If the specified name is a standard RTSP method name, a cached instance
     * will be returned. Otherwise, a new instance will be returned.
     *
     * @param name the method name; must be non-null and non-empty after trimming
     * @throws IllegalArgumentException if {@code name} is empty after trimming
     * @throws NullPointerException if {@code name} is {@code null}
     */
    public static HttpMethod valueOf(String name) {
        // Locale.ROOT keeps the lookup locale-independent (e.g. "describe"
        // would otherwise upper-case to a dotted capital I under tr-TR).
        name = checkNonEmptyAfterTrim(name, "name").toUpperCase(Locale.ROOT);
        HttpMethod result = methodMap.get(name);
        if (result != null) {
            return result;
        } else {
            return HttpMethod.valueOf(name);
        }
    }

    private RtspMethods() {
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpObjectDecoder; + +import static io.netty.handler.codec.rtsp.RtspDecoder.DEFAULT_MAX_CONTENT_LENGTH; + +/** + * Decodes {@link ByteBuf}s into RTSP messages represented in + * {@link HttpMessage}s. + *

+ *

Parameters that prevents excessive memory consumption

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
NameMeaning
{@code maxInitialLineLength}The maximum length of the initial line + * (e.g. {@code "SETUP / RTSP/1.0"} or {@code "RTSP/1.0 200 OK"}) + * If the length of the initial line exceeds this value, a + * {@link TooLongFrameException} will be raised.
{@code maxHeaderSize}The maximum length of all headers. If the sum of the length of each + * header exceeds this value, a {@link TooLongFrameException} will be raised.
{@code maxContentLength}The maximum length of the content. If the content length exceeds this + * value, a {@link TooLongFrameException} will be raised.
+ * + * @deprecated Use {@link RtspDecoder} instead. + */ +@Deprecated +public abstract class RtspObjectDecoder extends HttpObjectDecoder { + + /** + * Creates a new instance with the default + * {@code maxInitialLineLength (4096)}, {@code maxHeaderSize (8192)}, and + * {@code maxContentLength (8192)}. + */ + protected RtspObjectDecoder() { + this(DEFAULT_MAX_INITIAL_LINE_LENGTH, DEFAULT_MAX_HEADER_SIZE, DEFAULT_MAX_CONTENT_LENGTH); + } + + /** + * Creates a new instance with the specified parameters. + */ + protected RtspObjectDecoder(int maxInitialLineLength, int maxHeaderSize, int maxContentLength) { + super(maxInitialLineLength, maxHeaderSize, maxContentLength * 2, false); + } + + protected RtspObjectDecoder( + int maxInitialLineLength, int maxHeaderSize, int maxContentLength, boolean validateHeaders) { + super(maxInitialLineLength, maxHeaderSize, maxContentLength * 2, false, validateHeaders); + } + + @Override + protected boolean isContentAlwaysEmpty(HttpMessage msg) { + // Unlike HTTP, RTSP always assumes zero-length body if Content-Length + // header is absent. + boolean empty = super.isContentAlwaysEmpty(msg); + if (empty) { + return true; + } + if (!msg.headers().contains(RtspHeaderNames.CONTENT_LENGTH)) { + return true; + } + return empty; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspObjectEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspObjectEncoder.java new file mode 100644 index 0000000..752020e --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspObjectEncoder.java @@ -0,0 +1,44 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpObjectEncoder; + +/** + * Encodes an RTSP message represented in {@link FullHttpMessage} into + * a {@link ByteBuf}. + * + * @deprecated Use {@link RtspEncoder} instead. + */ +@Sharable +@Deprecated +public abstract class RtspObjectEncoder extends HttpObjectEncoder { + + /** + * Creates a new instance. + */ + protected RtspObjectEncoder() { + } + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + return msg instanceof FullHttpMessage; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestDecoder.java new file mode 100644 index 0000000..76b186c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestDecoder.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +/** + * @deprecated Use {@link RtspDecoder} directly instead + */ +@Deprecated +public class RtspRequestDecoder extends RtspDecoder { +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestEncoder.java new file mode 100644 index 0000000..356a3c4 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspRequestEncoder.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.rtsp; + +/** + * @deprecated Use {@link RtspEncoder} directly instead + */ +@Deprecated +public class RtspRequestEncoder extends RtspEncoder { +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseDecoder.java new file mode 100644 index 0000000..cff0403 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseDecoder.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +/** + * @deprecated Use {@link RtspDecoder} directly instead + */ +@Deprecated +public class RtspResponseDecoder extends RtspDecoder { +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseEncoder.java new file mode 100644 index 0000000..24fd1e9 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseEncoder.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +/** + * @deprecated Use {@link RtspEncoder} directly instead + */ +@Deprecated +public class RtspResponseEncoder extends RtspEncoder { +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseStatuses.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseStatuses.java new file mode 100644 index 0000000..8a2ad15 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/rtsp/RtspResponseStatuses.java @@ -0,0 +1,292 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import io.netty.handler.codec.http.HttpResponseStatus; + +/** + * The getStatus code and its description of a RTSP response. 
/**
 * The status code and its description of an RTSP response.
 */
public final class RtspResponseStatuses {

    /**
     * 100 Continue
     */
    public static final HttpResponseStatus CONTINUE = HttpResponseStatus.CONTINUE;

    /**
     * 200 OK
     */
    public static final HttpResponseStatus OK = HttpResponseStatus.OK;

    /**
     * 201 Created
     */
    public static final HttpResponseStatus CREATED = HttpResponseStatus.CREATED;

    /**
     * 250 Low on Storage Space (RTSP-specific; no HTTP equivalent)
     */
    public static final HttpResponseStatus LOW_STORAGE_SPACE = new HttpResponseStatus(
            250, "Low on Storage Space");

    /**
     * 300 Multiple Choices
     */
    public static final HttpResponseStatus MULTIPLE_CHOICES = HttpResponseStatus.MULTIPLE_CHOICES;

    /**
     * 301 Moved Permanently
     */
    public static final HttpResponseStatus MOVED_PERMANENTLY = HttpResponseStatus.MOVED_PERMANENTLY;

    /**
     * 302 Moved Temporarily (RTSP uses this reason phrase, not HTTP's "Found")
     */
    public static final HttpResponseStatus MOVED_TEMPORARILY = new HttpResponseStatus(
            302, "Moved Temporarily");
    /**
     * 304 Not Modified
     */
    public static final HttpResponseStatus NOT_MODIFIED = HttpResponseStatus.NOT_MODIFIED;

    /**
     * 305 Use Proxy
     */
    public static final HttpResponseStatus USE_PROXY = HttpResponseStatus.USE_PROXY;

    /**
     * 400 Bad Request
     */
    public static final HttpResponseStatus BAD_REQUEST = HttpResponseStatus.BAD_REQUEST;

    /**
     * 401 Unauthorized
     */
    public static final HttpResponseStatus UNAUTHORIZED = HttpResponseStatus.UNAUTHORIZED;

    /**
     * 402 Payment Required
     */
    public static final HttpResponseStatus PAYMENT_REQUIRED = HttpResponseStatus.PAYMENT_REQUIRED;

    /**
     * 403 Forbidden
     */
    public static final HttpResponseStatus FORBIDDEN = HttpResponseStatus.FORBIDDEN;

    /**
     * 404 Not Found
     */
    public static final HttpResponseStatus NOT_FOUND = HttpResponseStatus.NOT_FOUND;

    /**
     * 405 Method Not Allowed
     */
    public static final HttpResponseStatus METHOD_NOT_ALLOWED = HttpResponseStatus.METHOD_NOT_ALLOWED;

    /**
     * 406 Not Acceptable
     */
    public static final HttpResponseStatus NOT_ACCEPTABLE = HttpResponseStatus.NOT_ACCEPTABLE;

    /**
     * 407 Proxy Authentication Required
     */
    public static final HttpResponseStatus PROXY_AUTHENTICATION_REQUIRED =
            HttpResponseStatus.PROXY_AUTHENTICATION_REQUIRED;

    /**
     * 408 Request Timeout
     */
    public static final HttpResponseStatus REQUEST_TIMEOUT = HttpResponseStatus.REQUEST_TIMEOUT;

    /**
     * 410 Gone
     */
    public static final HttpResponseStatus GONE = HttpResponseStatus.GONE;

    /**
     * 411 Length Required
     */
    public static final HttpResponseStatus LENGTH_REQUIRED = HttpResponseStatus.LENGTH_REQUIRED;

    /**
     * 412 Precondition Failed
     */
    public static final HttpResponseStatus PRECONDITION_FAILED = HttpResponseStatus.PRECONDITION_FAILED;

    /**
     * 413 Request Entity Too Large
     */
    public static final HttpResponseStatus REQUEST_ENTITY_TOO_LARGE = HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE;

    /**
     * 414 Request-URI Too Long
     */
    public static final HttpResponseStatus REQUEST_URI_TOO_LONG = HttpResponseStatus.REQUEST_URI_TOO_LONG;

    /**
     * 415 Unsupported Media Type
     */
    public static final HttpResponseStatus UNSUPPORTED_MEDIA_TYPE = HttpResponseStatus.UNSUPPORTED_MEDIA_TYPE;

    /**
     * 451 Parameter Not Understood (RTSP-specific)
     */
    public static final HttpResponseStatus PARAMETER_NOT_UNDERSTOOD = new HttpResponseStatus(
            451, "Parameter Not Understood");

    /**
     * 452 Conference Not Found (RTSP-specific)
     */
    public static final HttpResponseStatus CONFERENCE_NOT_FOUND = new HttpResponseStatus(
            452, "Conference Not Found");

    /**
     * 453 Not Enough Bandwidth (RTSP-specific)
     */
    public static final HttpResponseStatus NOT_ENOUGH_BANDWIDTH = new HttpResponseStatus(
            453, "Not Enough Bandwidth");

    /**
     * 454 Session Not Found (RTSP-specific)
     */
    public static final HttpResponseStatus SESSION_NOT_FOUND = new HttpResponseStatus(
            454, "Session Not Found");

    /**
     * 455 Method Not Valid in This State (RTSP-specific)
     */
    public static final HttpResponseStatus METHOD_NOT_VALID = new HttpResponseStatus(
            455, "Method Not Valid in This State");

    /**
     * 456 Header Field Not Valid for Resource (RTSP-specific)
     */
    public static final HttpResponseStatus HEADER_FIELD_NOT_VALID = new HttpResponseStatus(
            456, "Header Field Not Valid for Resource");

    /**
     * 457 Invalid Range (RTSP-specific)
     */
    public static final HttpResponseStatus INVALID_RANGE = new HttpResponseStatus(
            457, "Invalid Range");

    /**
     * 458 Parameter Is Read-Only (RTSP-specific)
     */
    public static final HttpResponseStatus PARAMETER_IS_READONLY = new HttpResponseStatus(
            458, "Parameter Is Read-Only");

    /**
     * 459 Aggregate operation not allowed (RTSP-specific)
     */
    public static final HttpResponseStatus AGGREGATE_OPERATION_NOT_ALLOWED = new HttpResponseStatus(
            459, "Aggregate operation not allowed");

    /**
     * 460 Only Aggregate operation allowed (RTSP-specific)
     */
    public static final HttpResponseStatus ONLY_AGGREGATE_OPERATION_ALLOWED = new HttpResponseStatus(
            460, "Only Aggregate operation allowed");

    /**
     * 461 Unsupported transport (RTSP-specific)
     */
    public static final HttpResponseStatus UNSUPPORTED_TRANSPORT = new HttpResponseStatus(
            461, "Unsupported transport");

    /**
     * 462 Destination unreachable (RTSP-specific)
     */
    public static final HttpResponseStatus DESTINATION_UNREACHABLE = new HttpResponseStatus(
            462, "Destination unreachable");

    /**
     * 463 Key management failure (RTSP-specific)
     */
    public static final HttpResponseStatus KEY_MANAGEMENT_FAILURE = new HttpResponseStatus(
            463, "Key management failure");

    /**
     * 500 Internal Server Error
     */
    public static final HttpResponseStatus INTERNAL_SERVER_ERROR = HttpResponseStatus.INTERNAL_SERVER_ERROR;

    /**
     * 501 Not Implemented
     */
    public static final HttpResponseStatus NOT_IMPLEMENTED = HttpResponseStatus.NOT_IMPLEMENTED;

    /**
     * 502 Bad Gateway
     */
    public static final HttpResponseStatus BAD_GATEWAY = HttpResponseStatus.BAD_GATEWAY;

    /**
     * 503 Service Unavailable
     */
    public static final HttpResponseStatus SERVICE_UNAVAILABLE = HttpResponseStatus.SERVICE_UNAVAILABLE;

    /**
     * 504 Gateway Timeout
     */
    public static final HttpResponseStatus GATEWAY_TIMEOUT = HttpResponseStatus.GATEWAY_TIMEOUT;

    /**
     * 505 RTSP Version not supported (RTSP reason phrase for code 505)
     */
    public static final HttpResponseStatus RTSP_VERSION_NOT_SUPPORTED = new HttpResponseStatus(
            505, "RTSP Version not supported");

    /**
     * 551 Option not supported (RTSP-specific)
     */
    public static final HttpResponseStatus OPTION_NOT_SUPPORTED = new HttpResponseStatus(
            551, "Option not supported");

    /**
     * Returns the {@link HttpResponseStatus} represented by the specified code.
     * If the specified code is a standard RTSP status code, a cached instance
     * will be returned. Otherwise, a new instance will be returned.
     *
     * @param code the numeric status code
     */
    public static HttpResponseStatus valueOf(int code) {
        // RTSP-only codes (and the two whose reason phrases differ from HTTP)
        // map to the cached constants above; everything else falls through to
        // the generic HTTP lookup.
        switch (code) {
            case 250: return LOW_STORAGE_SPACE;
            case 302: return MOVED_TEMPORARILY;
            case 451: return PARAMETER_NOT_UNDERSTOOD;
            case 452: return CONFERENCE_NOT_FOUND;
            case 453: return NOT_ENOUGH_BANDWIDTH;
            case 454: return SESSION_NOT_FOUND;
            case 455: return METHOD_NOT_VALID;
            case 456: return HEADER_FIELD_NOT_VALID;
            case 457: return INVALID_RANGE;
            case 458: return PARAMETER_IS_READONLY;
            case 459: return AGGREGATE_OPERATION_NOT_ALLOWED;
            case 460: return ONLY_AGGREGATE_OPERATION_ALLOWED;
            case 461: return UNSUPPORTED_TRANSPORT;
            case 462: return DESTINATION_UNREACHABLE;
            case 463: return KEY_MANAGEMENT_FAILURE;
            case 505: return RTSP_VERSION_NOT_SUPPORTED;
            case 551: return OPTION_NOT_SUPPORTED;
            default: return HttpResponseStatus.valueOf(code);
        }
    }

    private RtspResponseStatuses() {
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.rtsp;

import io.netty.handler.codec.http.HttpVersion;
import io.netty.util.internal.ObjectUtil;

import java.util.Locale;

/**
 * The version of RTSP.
 */
public final class RtspVersions {

    /**
     * RTSP/1.0
     */
    public static final HttpVersion RTSP_1_0 = new HttpVersion("RTSP", 1, 0, true);

    /**
     * Returns an existing or new {@link HttpVersion} instance which matches to
     * the specified RTSP version string. If the specified {@code text} is
     * equal to {@code "RTSP/1.0"}, {@link #RTSP_1_0} will be returned.
     * Otherwise, a new {@link HttpVersion} instance will be returned.
     *
     * @param text the version string, e.g. {@code "RTSP/1.0"}
     * @throws NullPointerException if {@code text} is {@code null}
     */
    public static HttpVersion valueOf(String text) {
        ObjectUtil.checkNotNull(text, "text");

        // Locale.ROOT makes the upper-casing locale-independent so that
        // version strings containing 'i' are not mangled under tr-TR.
        text = text.trim().toUpperCase(Locale.ROOT);
        if ("RTSP/1.0".equals(text)) {
            return RTSP_1_0;
        }

        return new HttpVersion(text, true);
    }

    private RtspVersions() {
    }
}
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufUtil;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.StringUtil;

/**
 * The default {@link SpdyDataFrame} implementation.
 *
 * <p>Reference counting is delegated entirely to the wrapped {@code data}
 * buffer: every {@code retain}/{@code release}/{@code touch} call forwards
 * to it, so the frame is freed exactly when its payload is.
 */
public class DefaultSpdyDataFrame extends DefaultSpdyStreamFrame implements SpdyDataFrame {

    // The frame payload; never null (validated in the constructor).
    private final ByteBuf data;

    /**
     * Creates a new instance.
     *
     * @param streamId the Stream-ID of this frame
     */
    public DefaultSpdyDataFrame(int streamId) {
        this(streamId, Unpooled.buffer(0));
    }

    /**
     * Creates a new instance.
     *
     * @param streamId the Stream-ID of this frame
     * @param data the payload of the frame. Can not exceed {@link SpdyCodecUtil#SPDY_MAX_LENGTH}
     * @throws NullPointerException if {@code data} is {@code null}
     * @throws IllegalArgumentException if the payload exceeds the SPDY maximum length
     */
    public DefaultSpdyDataFrame(int streamId, ByteBuf data) {
        super(streamId);
        this.data = validate(
                ObjectUtil.checkNotNull(data, "data"));
    }

    // Rejects payloads larger than the SPDY frame length limit.
    private static ByteBuf validate(ByteBuf data) {
        if (data.readableBytes() > SpdyCodecUtil.SPDY_MAX_LENGTH) {
            throw new IllegalArgumentException("data payload cannot exceed "
                    + SpdyCodecUtil.SPDY_MAX_LENGTH + " bytes");
        }
        return data;
    }

    @Override
    public SpdyDataFrame setStreamId(int streamId) {
        super.setStreamId(streamId);
        return this;
    }

    @Override
    public SpdyDataFrame setLast(boolean last) {
        super.setLast(last);
        return this;
    }

    @Override
    public ByteBuf content() {
        // ensureAccessible throws if the buffer was already released.
        return ByteBufUtil.ensureAccessible(data);
    }

    @Override
    public SpdyDataFrame copy() {
        return replace(content().copy());
    }

    @Override
    public SpdyDataFrame duplicate() {
        return replace(content().duplicate());
    }

    @Override
    public SpdyDataFrame retainedDuplicate() {
        return replace(content().retainedDuplicate());
    }

    @Override
    public SpdyDataFrame replace(ByteBuf content) {
        // New frame with the same stream id and last-flag, different payload.
        SpdyDataFrame frame = new DefaultSpdyDataFrame(streamId(), content);
        frame.setLast(isLast());
        return frame;
    }

    @Override
    public int refCnt() {
        return data.refCnt();
    }

    @Override
    public SpdyDataFrame retain() {
        data.retain();
        return this;
    }

    @Override
    public SpdyDataFrame retain(int increment) {
        data.retain(increment);
        return this;
    }

    @Override
    public SpdyDataFrame touch() {
        data.touch();
        return this;
    }

    @Override
    public SpdyDataFrame touch(Object hint) {
        data.touch(hint);
        return this;
    }

    @Override
    public boolean release() {
        return data.release();
    }

    @Override
    public boolean release(int decrement) {
        return data.release(decrement);
    }

    @Override
    public String toString() {
        StringBuilder buf = new StringBuilder()
            .append(StringUtil.simpleClassName(this))
            .append("(last: ")
            .append(isLast())
            .append(')')
            .append(StringUtil.NEWLINE)
            .append("--> Stream-ID = ")
            .append(streamId())
            .append(StringUtil.NEWLINE)
            .append("--> Size = ");
        // Avoid content()'s accessibility check so toString never throws on a
        // released frame.
        if (refCnt() == 0) {
            buf.append("(freed)");
        } else {
            buf.append(content().readableBytes());
        }
        return buf.toString();
    }
}
+ * + * @param lastGoodStreamId the Last-good-stream-ID of this frame + * @param statusCode the Status code of this frame + */ + public DefaultSpdyGoAwayFrame(int lastGoodStreamId, int statusCode) { + this(lastGoodStreamId, SpdySessionStatus.valueOf(statusCode)); + } + + /** + * Creates a new instance. + * + * @param lastGoodStreamId the Last-good-stream-ID of this frame + * @param status the status of this frame + */ + public DefaultSpdyGoAwayFrame(int lastGoodStreamId, SpdySessionStatus status) { + setLastGoodStreamId(lastGoodStreamId); + setStatus(status); + } + + @Override + public int lastGoodStreamId() { + return lastGoodStreamId; + } + + @Override + public SpdyGoAwayFrame setLastGoodStreamId(int lastGoodStreamId) { + checkPositiveOrZero(lastGoodStreamId, "lastGoodStreamId"); + this.lastGoodStreamId = lastGoodStreamId; + return this; + } + + @Override + public SpdySessionStatus status() { + return status; + } + + @Override + public SpdyGoAwayFrame setStatus(SpdySessionStatus status) { + this.status = status; + return this; + } + + @Override + public String toString() { + return new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append(StringUtil.NEWLINE) + .append("--> Last-good-stream-ID = ") + .append(lastGoodStreamId()) + .append(StringUtil.NEWLINE) + .append("--> Status: ") + .append(status()) + .toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeaders.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeaders.java new file mode 100644 index 0000000..d6316ce --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeaders.java @@ -0,0 +1,84 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.handler.codec.CharSequenceValueConverter; +import io.netty.handler.codec.DefaultHeaders; +import io.netty.handler.codec.HeadersUtils; + +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; + +import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; +import static io.netty.util.AsciiString.CASE_SENSITIVE_HASHER; + +public class DefaultSpdyHeaders extends DefaultHeaders implements SpdyHeaders { + private static final NameValidator SpdyNameValidator = new NameValidator() { + @Override + public void validateName(CharSequence name) { + SpdyCodecUtil.validateHeaderName(name); + } + }; + + public DefaultSpdyHeaders() { + this(true); + } + + @SuppressWarnings("unchecked") + public DefaultSpdyHeaders(boolean validate) { + super(CASE_INSENSITIVE_HASHER, + validate ? HeaderValueConverterAndValidator.INSTANCE : CharSequenceValueConverter.INSTANCE, + validate ? 
SpdyNameValidator : NameValidator.NOT_NULL); + } + + @Override + public String getAsString(CharSequence name) { + return HeadersUtils.getAsString(this, name); + } + + @Override + public List getAllAsString(CharSequence name) { + return HeadersUtils.getAllAsString(this, name); + } + + @Override + public Iterator> iteratorAsString() { + return HeadersUtils.iteratorAsString(this); + } + + @Override + public boolean contains(CharSequence name, CharSequence value) { + return contains(name, value, false); + } + + @Override + public boolean contains(CharSequence name, CharSequence value, boolean ignoreCase) { + return contains(name, value, + ignoreCase ? CASE_INSENSITIVE_HASHER : CASE_SENSITIVE_HASHER); + } + + private static final class HeaderValueConverterAndValidator extends CharSequenceValueConverter { + public static final HeaderValueConverterAndValidator INSTANCE = new HeaderValueConverterAndValidator(); + + @Override + public CharSequence convertObject(Object value) { + final CharSequence seq = super.convertObject(value); + SpdyCodecUtil.validateHeaderValue(seq); + return seq; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersFrame.java new file mode 100644 index 0000000..28471b2 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersFrame.java @@ -0,0 +1,120 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.StringUtil; + +import java.util.Map; + +/** + * The default {@link SpdyHeadersFrame} implementation. + */ +public class DefaultSpdyHeadersFrame extends DefaultSpdyStreamFrame + implements SpdyHeadersFrame { + + private boolean invalid; + private boolean truncated; + private final SpdyHeaders headers; + + /** + * Creates a new instance. + * + * @param streamId the Stream-ID of this frame + */ + public DefaultSpdyHeadersFrame(int streamId) { + this(streamId, true); + } + + /** + * Creates a new instance. 
+ * + * @param streamId the Stream-ID of this frame + * @param validate validate the header names and values when adding them to the {@link SpdyHeaders} + */ + public DefaultSpdyHeadersFrame(int streamId, boolean validate) { + super(streamId); + headers = new DefaultSpdyHeaders(validate); + } + + @Override + public SpdyHeadersFrame setStreamId(int streamId) { + super.setStreamId(streamId); + return this; + } + + @Override + public SpdyHeadersFrame setLast(boolean last) { + super.setLast(last); + return this; + } + + @Override + public boolean isInvalid() { + return invalid; + } + + @Override + public SpdyHeadersFrame setInvalid() { + invalid = true; + return this; + } + + @Override + public boolean isTruncated() { + return truncated; + } + + @Override + public SpdyHeadersFrame setTruncated() { + truncated = true; + return this; + } + + @Override + public SpdyHeaders headers() { + return headers; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append("(last: ") + .append(isLast()) + .append(')') + .append(StringUtil.NEWLINE) + .append("--> Stream-ID = ") + .append(streamId()) + .append(StringUtil.NEWLINE) + .append("--> Headers:") + .append(StringUtil.NEWLINE); + appendHeaders(buf); + + // Remove the last newline. 
+ buf.setLength(buf.length() - StringUtil.NEWLINE.length()); + return buf.toString(); + } + + protected void appendHeaders(StringBuilder buf) { + for (Map.Entry e: headers()) { + buf.append(" "); + buf.append(e.getKey()); + buf.append(": "); + buf.append(e.getValue()); + buf.append(StringUtil.NEWLINE); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyPingFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyPingFrame.java new file mode 100644 index 0000000..8a04d35 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyPingFrame.java @@ -0,0 +1,56 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.StringUtil; + +/** + * The default {@link SpdyPingFrame} implementation. + */ +public class DefaultSpdyPingFrame implements SpdyPingFrame { + + private int id; + + /** + * Creates a new instance. 
+ * + * @param id the unique ID of this frame + */ + public DefaultSpdyPingFrame(int id) { + setId(id); + } + + @Override + public int id() { + return id; + } + + @Override + public SpdyPingFrame setId(int id) { + this.id = id; + return this; + } + + @Override + public String toString() { + return new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append(StringUtil.NEWLINE) + .append("--> ID = ") + .append(id()) + .toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyRstStreamFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyRstStreamFrame.java new file mode 100644 index 0000000..939e424 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyRstStreamFrame.java @@ -0,0 +1,84 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.StringUtil; + +/** + * The default {@link SpdyRstStreamFrame} implementation. + */ +public class DefaultSpdyRstStreamFrame extends DefaultSpdyStreamFrame + implements SpdyRstStreamFrame { + + private SpdyStreamStatus status; + + /** + * Creates a new instance. 
+ * + * @param streamId the Stream-ID of this frame + * @param statusCode the Status code of this frame + */ + public DefaultSpdyRstStreamFrame(int streamId, int statusCode) { + this(streamId, SpdyStreamStatus.valueOf(statusCode)); + } + + /** + * Creates a new instance. + * + * @param streamId the Stream-ID of this frame + * @param status the status of this frame + */ + public DefaultSpdyRstStreamFrame(int streamId, SpdyStreamStatus status) { + super(streamId); + setStatus(status); + } + + @Override + public SpdyRstStreamFrame setStreamId(int streamId) { + super.setStreamId(streamId); + return this; + } + + @Override + public SpdyRstStreamFrame setLast(boolean last) { + super.setLast(last); + return this; + } + + @Override + public SpdyStreamStatus status() { + return status; + } + + @Override + public SpdyRstStreamFrame setStatus(SpdyStreamStatus status) { + this.status = status; + return this; + } + + @Override + public String toString() { + return new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append(StringUtil.NEWLINE) + .append("--> Stream-ID = ") + .append(streamId()) + .append(StringUtil.NEWLINE) + .append("--> Status: ") + .append(status()) + .toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySettingsFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySettingsFrame.java new file mode 100644 index 0000000..59b1b4d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySettingsFrame.java @@ -0,0 +1,184 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.StringUtil; + +import java.util.Map; +import java.util.Set; +import java.util.TreeMap; + +/** + * The default {@link SpdySettingsFrame} implementation. + */ +public class DefaultSpdySettingsFrame implements SpdySettingsFrame { + + private boolean clear; + private final Map settingsMap = new TreeMap(); + + @Override + public Set ids() { + return settingsMap.keySet(); + } + + @Override + public boolean isSet(int id) { + return settingsMap.containsKey(id); + } + + @Override + public int getValue(int id) { + final Setting setting = settingsMap.get(id); + return setting != null ? 
setting.getValue() : -1; + } + + @Override + public SpdySettingsFrame setValue(int id, int value) { + return setValue(id, value, false, false); + } + + @Override + public SpdySettingsFrame setValue(int id, int value, boolean persistValue, boolean persisted) { + if (id < 0 || id > SpdyCodecUtil.SPDY_SETTINGS_MAX_ID) { + throw new IllegalArgumentException("Setting ID is not valid: " + id); + } + final Integer key = Integer.valueOf(id); + final Setting setting = settingsMap.get(key); + if (setting != null) { + setting.setValue(value); + setting.setPersist(persistValue); + setting.setPersisted(persisted); + } else { + settingsMap.put(key, new Setting(value, persistValue, persisted)); + } + return this; + } + + @Override + public SpdySettingsFrame removeValue(int id) { + settingsMap.remove(id); + return this; + } + + @Override + public boolean isPersistValue(int id) { + final Setting setting = settingsMap.get(id); + return setting != null && setting.isPersist(); + } + + @Override + public SpdySettingsFrame setPersistValue(int id, boolean persistValue) { + final Setting setting = settingsMap.get(id); + if (setting != null) { + setting.setPersist(persistValue); + } + return this; + } + + @Override + public boolean isPersisted(int id) { + final Setting setting = settingsMap.get(id); + return setting != null && setting.isPersisted(); + } + + @Override + public SpdySettingsFrame setPersisted(int id, boolean persisted) { + final Setting setting = settingsMap.get(id); + if (setting != null) { + setting.setPersisted(persisted); + } + return this; + } + + @Override + public boolean clearPreviouslyPersistedSettings() { + return clear; + } + + @Override + public SpdySettingsFrame setClearPreviouslyPersistedSettings(boolean clear) { + this.clear = clear; + return this; + } + + private Set> getSettings() { + return settingsMap.entrySet(); + } + + private void appendSettings(StringBuilder buf) { + for (Map.Entry e: getSettings()) { + Setting setting = e.getValue(); + buf.append("--> 
"); + buf.append(e.getKey()); + buf.append(':'); + buf.append(setting.getValue()); + buf.append(" (persist value: "); + buf.append(setting.isPersist()); + buf.append("; persisted: "); + buf.append(setting.isPersisted()); + buf.append(')'); + buf.append(StringUtil.NEWLINE); + } + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append(StringUtil.NEWLINE); + appendSettings(buf); + + buf.setLength(buf.length() - StringUtil.NEWLINE.length()); + return buf.toString(); + } + + private static final class Setting { + + private int value; + private boolean persist; + private boolean persisted; + + Setting(int value, boolean persist, boolean persisted) { + this.value = value; + this.persist = persist; + this.persisted = persisted; + } + + int getValue() { + return value; + } + + void setValue(int value) { + this.value = value; + } + + boolean isPersist() { + return persist; + } + + void setPersist(boolean persist) { + this.persist = persist; + } + + boolean isPersisted() { + return persisted; + } + + void setPersisted(boolean persisted) { + this.persisted = persisted; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java new file mode 100644 index 0000000..c9d9bf7 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyStreamFrame.java @@ -0,0 +1,59 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * The default {@link SpdyStreamFrame} implementation. + */ +public abstract class DefaultSpdyStreamFrame implements SpdyStreamFrame { + + private int streamId; + private boolean last; + + /** + * Creates a new instance. + * + * @param streamId the Stream-ID of this frame + */ + protected DefaultSpdyStreamFrame(int streamId) { + setStreamId(streamId); + } + + @Override + public int streamId() { + return streamId; + } + + @Override + public SpdyStreamFrame setStreamId(int streamId) { + checkPositive(streamId, "streamId"); + this.streamId = streamId; + return this; + } + + @Override + public boolean isLast() { + return last; + } + + @Override + public SpdyStreamFrame setLast(boolean last) { + this.last = last; + return this; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java new file mode 100644 index 0000000..689e72d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynReplyFrame.java @@ -0,0 +1,81 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.StringUtil; + +/** + * The default {@link SpdySynReplyFrame} implementation. + */ +public class DefaultSpdySynReplyFrame extends DefaultSpdyHeadersFrame implements SpdySynReplyFrame { + + /** + * Creates a new instance. + * + * @param streamId the Stream-ID of this frame + */ + public DefaultSpdySynReplyFrame(int streamId) { + super(streamId); + } + + /** + * Creates a new instance. + * + * @param streamId the Stream-ID of this frame + * @param validateHeaders validate the header names and values when adding them to the {@link SpdyHeaders} + */ + public DefaultSpdySynReplyFrame(int streamId, boolean validateHeaders) { + super(streamId, validateHeaders); + } + + @Override + public SpdySynReplyFrame setStreamId(int streamId) { + super.setStreamId(streamId); + return this; + } + + @Override + public SpdySynReplyFrame setLast(boolean last) { + super.setLast(last); + return this; + } + + @Override + public SpdySynReplyFrame setInvalid() { + super.setInvalid(); + return this; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append("(last: ") + .append(isLast()) + .append(')') + .append(StringUtil.NEWLINE) + .append("--> Stream-ID = ") + .append(streamId()) + .append(StringUtil.NEWLINE) + .append("--> Headers:") + .append(StringUtil.NEWLINE); + appendHeaders(buf); + + // Remove the last newline. 
+ buf.setLength(buf.length() - StringUtil.NEWLINE.length()); + return buf.toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java new file mode 100644 index 0000000..149d155 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdySynStreamFrame.java @@ -0,0 +1,142 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.util.internal.StringUtil; + +/** + * The default {@link SpdySynStreamFrame} implementation. + */ +public class DefaultSpdySynStreamFrame extends DefaultSpdyHeadersFrame + implements SpdySynStreamFrame { + + private int associatedStreamId; + private byte priority; + private boolean unidirectional; + + /** + * Creates a new instance. + * + * @param streamId the Stream-ID of this frame + * @param associatedStreamId the Associated-To-Stream-ID of this frame + * @param priority the priority of the stream + */ + public DefaultSpdySynStreamFrame(int streamId, int associatedStreamId, byte priority) { + this(streamId, associatedStreamId, priority, true); + } + + /** + * Creates a new instance. 
+ * + * @param streamId the Stream-ID of this frame + * @param associatedStreamId the Associated-To-Stream-ID of this frame + * @param priority the priority of the stream + * @param validateHeaders validate the header names and values when adding them to the {@link SpdyHeaders} + */ + public DefaultSpdySynStreamFrame(int streamId, int associatedStreamId, byte priority, boolean validateHeaders) { + super(streamId, validateHeaders); + setAssociatedStreamId(associatedStreamId); + setPriority(priority); + } + + @Override + public SpdySynStreamFrame setStreamId(int streamId) { + super.setStreamId(streamId); + return this; + } + + @Override + public SpdySynStreamFrame setLast(boolean last) { + super.setLast(last); + return this; + } + + @Override + public SpdySynStreamFrame setInvalid() { + super.setInvalid(); + return this; + } + + @Override + public int associatedStreamId() { + return associatedStreamId; + } + + @Override + public SpdySynStreamFrame setAssociatedStreamId(int associatedStreamId) { + checkPositiveOrZero(associatedStreamId, "associatedStreamId"); + this.associatedStreamId = associatedStreamId; + return this; + } + + @Override + public byte priority() { + return priority; + } + + @Override + public SpdySynStreamFrame setPriority(byte priority) { + if (priority < 0 || priority > 7) { + throw new IllegalArgumentException( + "Priority must be between 0 and 7 inclusive: " + priority); + } + this.priority = priority; + return this; + } + + @Override + public boolean isUnidirectional() { + return unidirectional; + } + + @Override + public SpdySynStreamFrame setUnidirectional(boolean unidirectional) { + this.unidirectional = unidirectional; + return this; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append("(last: ") + .append(isLast()) + .append("; unidirectional: ") + .append(isUnidirectional()) + .append(')') + .append(StringUtil.NEWLINE) + .append("--> Stream-ID = ") + 
.append(streamId()) + .append(StringUtil.NEWLINE); + if (associatedStreamId != 0) { + buf.append("--> Associated-To-Stream-ID = ") + .append(associatedStreamId()) + .append(StringUtil.NEWLINE); + } + buf.append("--> Priority = ") + .append(priority()) + .append(StringUtil.NEWLINE) + .append("--> Headers:") + .append(StringUtil.NEWLINE); + appendHeaders(buf); + + // Remove the last newline. + buf.setLength(buf.length() - StringUtil.NEWLINE.length()); + return buf.toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java new file mode 100644 index 0000000..d4b88de --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/DefaultSpdyWindowUpdateFrame.java @@ -0,0 +1,78 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.util.internal.StringUtil; + +/** + * The default {@link SpdyWindowUpdateFrame} implementation. + */ +public class DefaultSpdyWindowUpdateFrame implements SpdyWindowUpdateFrame { + + private int streamId; + private int deltaWindowSize; + + /** + * Creates a new instance. 
+ * + * @param streamId the Stream-ID of this frame + * @param deltaWindowSize the Delta-Window-Size of this frame + */ + public DefaultSpdyWindowUpdateFrame(int streamId, int deltaWindowSize) { + setStreamId(streamId); + setDeltaWindowSize(deltaWindowSize); + } + + @Override + public int streamId() { + return streamId; + } + + @Override + public SpdyWindowUpdateFrame setStreamId(int streamId) { + checkPositiveOrZero(streamId, "streamId"); + this.streamId = streamId; + return this; + } + + @Override + public int deltaWindowSize() { + return deltaWindowSize; + } + + @Override + public SpdyWindowUpdateFrame setDeltaWindowSize(int deltaWindowSize) { + checkPositive(deltaWindowSize, "deltaWindowSize"); + this.deltaWindowSize = deltaWindowSize; + return this; + } + + @Override + public String toString() { + return new StringBuilder() + .append(StringUtil.simpleClassName(this)) + .append(StringUtil.NEWLINE) + .append("--> Stream-ID = ") + .append(streamId()) + .append(StringUtil.NEWLINE) + .append("--> Delta-Window-Size = ") + .append(deltaWindowSize()) + .toString(); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyCodecUtil.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyCodecUtil.java new file mode 100644 index 0000000..4f241aa --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyCodecUtil.java @@ -0,0 +1,328 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +import io.netty.buffer.ByteBuf; + +final class SpdyCodecUtil { + + static final int SPDY_SESSION_STREAM_ID = 0; + + static final int SPDY_HEADER_TYPE_OFFSET = 2; + static final int SPDY_HEADER_FLAGS_OFFSET = 4; + static final int SPDY_HEADER_LENGTH_OFFSET = 5; + static final int SPDY_HEADER_SIZE = 8; + + static final int SPDY_MAX_LENGTH = 0xFFFFFF; // Length is a 24-bit field + + static final byte SPDY_DATA_FLAG_FIN = 0x01; + + static final int SPDY_DATA_FRAME = 0; + static final int SPDY_SYN_STREAM_FRAME = 1; + static final int SPDY_SYN_REPLY_FRAME = 2; + static final int SPDY_RST_STREAM_FRAME = 3; + static final int SPDY_SETTINGS_FRAME = 4; + static final int SPDY_PUSH_PROMISE_FRAME = 5; + static final int SPDY_PING_FRAME = 6; + static final int SPDY_GOAWAY_FRAME = 7; + static final int SPDY_HEADERS_FRAME = 8; + static final int SPDY_WINDOW_UPDATE_FRAME = 9; + + static final byte SPDY_FLAG_FIN = 0x01; + static final byte SPDY_FLAG_UNIDIRECTIONAL = 0x02; + + static final byte SPDY_SETTINGS_CLEAR = 0x01; + static final byte SPDY_SETTINGS_PERSIST_VALUE = 0x01; + static final byte SPDY_SETTINGS_PERSISTED = 0x02; + + static final int SPDY_SETTINGS_MAX_ID = 0xFFFFFF; // ID is a 24-bit field + + static final int SPDY_MAX_NV_LENGTH = 0xFFFF; // Length is a 16-bit field + + // Zlib Dictionary + static final byte[] SPDY_DICT = { + 0x00, 0x00, 0x00, 0x07, 0x6f, 0x70, 0x74, 0x69, // - - - - o p t i + 0x6f, 0x6e, 0x73, 0x00, 0x00, 0x00, 0x04, 0x68, // o n s - - - - h + 0x65, 0x61, 0x64, 0x00, 0x00, 0x00, 0x04, 0x70, // e a d - - - - p + 0x6f, 0x73, 0x74, 0x00, 0x00, 0x00, 0x03, 0x70, // o s t - - - - p + 0x75, 0x74, 0x00, 0x00, 0x00, 0x06, 0x64, 0x65, // u t - - - - d e + 0x6c, 0x65, 0x74, 0x65, 
0x00, 0x00, 0x00, 0x05, // l e t e - - - - + 0x74, 0x72, 0x61, 0x63, 0x65, 0x00, 0x00, 0x00, // t r a c e - - - + 0x06, 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x00, // - a c c e p t - + 0x00, 0x00, 0x0e, 0x61, 0x63, 0x63, 0x65, 0x70, // - - - a c c e p + 0x74, 0x2d, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, // t - c h a r s e + 0x74, 0x00, 0x00, 0x00, 0x0f, 0x61, 0x63, 0x63, // t - - - - a c c + 0x65, 0x70, 0x74, 0x2d, 0x65, 0x6e, 0x63, 0x6f, // e p t - e n c o + 0x64, 0x69, 0x6e, 0x67, 0x00, 0x00, 0x00, 0x0f, // d i n g - - - - + 0x61, 0x63, 0x63, 0x65, 0x70, 0x74, 0x2d, 0x6c, // a c c e p t - l + 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, 0x00, // a n g u a g e - + 0x00, 0x00, 0x0d, 0x61, 0x63, 0x63, 0x65, 0x70, // - - - a c c e p + 0x74, 0x2d, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x73, // t - r a n g e s + 0x00, 0x00, 0x00, 0x03, 0x61, 0x67, 0x65, 0x00, // - - - - a g e - + 0x00, 0x00, 0x05, 0x61, 0x6c, 0x6c, 0x6f, 0x77, // - - - a l l o w + 0x00, 0x00, 0x00, 0x0d, 0x61, 0x75, 0x74, 0x68, // - - - - a u t h + 0x6f, 0x72, 0x69, 0x7a, 0x61, 0x74, 0x69, 0x6f, // o r i z a t i o + 0x6e, 0x00, 0x00, 0x00, 0x0d, 0x63, 0x61, 0x63, // n - - - - c a c + 0x68, 0x65, 0x2d, 0x63, 0x6f, 0x6e, 0x74, 0x72, // h e - c o n t r + 0x6f, 0x6c, 0x00, 0x00, 0x00, 0x0a, 0x63, 0x6f, // o l - - - - c o + 0x6e, 0x6e, 0x65, 0x63, 0x74, 0x69, 0x6f, 0x6e, // n n e c t i o n + 0x00, 0x00, 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, // - - - - c o n t + 0x65, 0x6e, 0x74, 0x2d, 0x62, 0x61, 0x73, 0x65, // e n t - b a s e + 0x00, 0x00, 0x00, 0x10, 0x63, 0x6f, 0x6e, 0x74, // - - - - c o n t + 0x65, 0x6e, 0x74, 0x2d, 0x65, 0x6e, 0x63, 0x6f, // e n t - e n c o + 0x64, 0x69, 0x6e, 0x67, 0x00, 0x00, 0x00, 0x10, // d i n g - - - - + 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, // c o n t e n t - + 0x6c, 0x61, 0x6e, 0x67, 0x75, 0x61, 0x67, 0x65, // l a n g u a g e + 0x00, 0x00, 0x00, 0x0e, 0x63, 0x6f, 0x6e, 0x74, // - - - - c o n t + 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x65, 0x6e, 0x67, // e n t - l e n g + 0x74, 0x68, 0x00, 
0x00, 0x00, 0x10, 0x63, 0x6f, // t h - - - - c o + 0x6e, 0x74, 0x65, 0x6e, 0x74, 0x2d, 0x6c, 0x6f, // n t e n t - l o + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x00, 0x00, // c a t i o n - - + 0x00, 0x0b, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, // - - c o n t e n + 0x74, 0x2d, 0x6d, 0x64, 0x35, 0x00, 0x00, 0x00, // t - m d 5 - - - + 0x0d, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, 0x74, // - c o n t e n t + 0x2d, 0x72, 0x61, 0x6e, 0x67, 0x65, 0x00, 0x00, // - r a n g e - - + 0x00, 0x0c, 0x63, 0x6f, 0x6e, 0x74, 0x65, 0x6e, // - - c o n t e n + 0x74, 0x2d, 0x74, 0x79, 0x70, 0x65, 0x00, 0x00, // t - t y p e - - + 0x00, 0x04, 0x64, 0x61, 0x74, 0x65, 0x00, 0x00, // - - d a t e - - + 0x00, 0x04, 0x65, 0x74, 0x61, 0x67, 0x00, 0x00, // - - e t a g - - + 0x00, 0x06, 0x65, 0x78, 0x70, 0x65, 0x63, 0x74, // - - e x p e c t + 0x00, 0x00, 0x00, 0x07, 0x65, 0x78, 0x70, 0x69, // - - - - e x p i + 0x72, 0x65, 0x73, 0x00, 0x00, 0x00, 0x04, 0x66, // r e s - - - - f + 0x72, 0x6f, 0x6d, 0x00, 0x00, 0x00, 0x04, 0x68, // r o m - - - - h + 0x6f, 0x73, 0x74, 0x00, 0x00, 0x00, 0x08, 0x69, // o s t - - - - i + 0x66, 0x2d, 0x6d, 0x61, 0x74, 0x63, 0x68, 0x00, // f - m a t c h - + 0x00, 0x00, 0x11, 0x69, 0x66, 0x2d, 0x6d, 0x6f, // - - - i f - m o + 0x64, 0x69, 0x66, 0x69, 0x65, 0x64, 0x2d, 0x73, // d i f i e d - s + 0x69, 0x6e, 0x63, 0x65, 0x00, 0x00, 0x00, 0x0d, // i n c e - - - - + 0x69, 0x66, 0x2d, 0x6e, 0x6f, 0x6e, 0x65, 0x2d, // i f - n o n e - + 0x6d, 0x61, 0x74, 0x63, 0x68, 0x00, 0x00, 0x00, // m a t c h - - - + 0x08, 0x69, 0x66, 0x2d, 0x72, 0x61, 0x6e, 0x67, // - i f - r a n g + 0x65, 0x00, 0x00, 0x00, 0x13, 0x69, 0x66, 0x2d, // e - - - - i f - + 0x75, 0x6e, 0x6d, 0x6f, 0x64, 0x69, 0x66, 0x69, // u n m o d i f i + 0x65, 0x64, 0x2d, 0x73, 0x69, 0x6e, 0x63, 0x65, // e d - s i n c e + 0x00, 0x00, 0x00, 0x0d, 0x6c, 0x61, 0x73, 0x74, // - - - - l a s t + 0x2d, 0x6d, 0x6f, 0x64, 0x69, 0x66, 0x69, 0x65, // - m o d i f i e + 0x64, 0x00, 0x00, 0x00, 0x08, 0x6c, 0x6f, 0x63, // d - - - - l o c + 0x61, 0x74, 
0x69, 0x6f, 0x6e, 0x00, 0x00, 0x00, // a t i o n - - - + 0x0c, 0x6d, 0x61, 0x78, 0x2d, 0x66, 0x6f, 0x72, // - m a x - f o r + 0x77, 0x61, 0x72, 0x64, 0x73, 0x00, 0x00, 0x00, // w a r d s - - - + 0x06, 0x70, 0x72, 0x61, 0x67, 0x6d, 0x61, 0x00, // - p r a g m a - + 0x00, 0x00, 0x12, 0x70, 0x72, 0x6f, 0x78, 0x79, // - - - p r o x y + 0x2d, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, 0x74, // - a u t h e n t + 0x69, 0x63, 0x61, 0x74, 0x65, 0x00, 0x00, 0x00, // i c a t e - - - + 0x13, 0x70, 0x72, 0x6f, 0x78, 0x79, 0x2d, 0x61, // - p r o x y - a + 0x75, 0x74, 0x68, 0x6f, 0x72, 0x69, 0x7a, 0x61, // u t h o r i z a + 0x74, 0x69, 0x6f, 0x6e, 0x00, 0x00, 0x00, 0x05, // t i o n - - - - + 0x72, 0x61, 0x6e, 0x67, 0x65, 0x00, 0x00, 0x00, // r a n g e - - - + 0x07, 0x72, 0x65, 0x66, 0x65, 0x72, 0x65, 0x72, // - r e f e r e r + 0x00, 0x00, 0x00, 0x0b, 0x72, 0x65, 0x74, 0x72, // - - - - r e t r + 0x79, 0x2d, 0x61, 0x66, 0x74, 0x65, 0x72, 0x00, // y - a f t e r - + 0x00, 0x00, 0x06, 0x73, 0x65, 0x72, 0x76, 0x65, // - - - s e r v e + 0x72, 0x00, 0x00, 0x00, 0x02, 0x74, 0x65, 0x00, // r - - - - t e - + 0x00, 0x00, 0x07, 0x74, 0x72, 0x61, 0x69, 0x6c, // - - - t r a i l + 0x65, 0x72, 0x00, 0x00, 0x00, 0x11, 0x74, 0x72, // e r - - - - t r + 0x61, 0x6e, 0x73, 0x66, 0x65, 0x72, 0x2d, 0x65, // a n s f e r - e + 0x6e, 0x63, 0x6f, 0x64, 0x69, 0x6e, 0x67, 0x00, // n c o d i n g - + 0x00, 0x00, 0x07, 0x75, 0x70, 0x67, 0x72, 0x61, // - - - u p g r a + 0x64, 0x65, 0x00, 0x00, 0x00, 0x0a, 0x75, 0x73, // d e - - - - u s + 0x65, 0x72, 0x2d, 0x61, 0x67, 0x65, 0x6e, 0x74, // e r - a g e n t + 0x00, 0x00, 0x00, 0x04, 0x76, 0x61, 0x72, 0x79, // - - - - v a r y + 0x00, 0x00, 0x00, 0x03, 0x76, 0x69, 0x61, 0x00, // - - - - v i a - + 0x00, 0x00, 0x07, 0x77, 0x61, 0x72, 0x6e, 0x69, // - - - w a r n i + 0x6e, 0x67, 0x00, 0x00, 0x00, 0x10, 0x77, 0x77, // n g - - - - w w + 0x77, 0x2d, 0x61, 0x75, 0x74, 0x68, 0x65, 0x6e, // w - a u t h e n + 0x74, 0x69, 0x63, 0x61, 0x74, 0x65, 0x00, 0x00, // t i c a t e - - + 0x00, 
0x06, 0x6d, 0x65, 0x74, 0x68, 0x6f, 0x64, // - - m e t h o d + 0x00, 0x00, 0x00, 0x03, 0x67, 0x65, 0x74, 0x00, // - - - - g e t - + 0x00, 0x00, 0x06, 0x73, 0x74, 0x61, 0x74, 0x75, // - - - s t a t u + 0x73, 0x00, 0x00, 0x00, 0x06, 0x32, 0x30, 0x30, // s - - - - 2 0 0 + 0x20, 0x4f, 0x4b, 0x00, 0x00, 0x00, 0x07, 0x76, // - O K - - - - v + 0x65, 0x72, 0x73, 0x69, 0x6f, 0x6e, 0x00, 0x00, // e r s i o n - - + 0x00, 0x08, 0x48, 0x54, 0x54, 0x50, 0x2f, 0x31, // - - H T T P - 1 + 0x2e, 0x31, 0x00, 0x00, 0x00, 0x03, 0x75, 0x72, // - 1 - - - - u r + 0x6c, 0x00, 0x00, 0x00, 0x06, 0x70, 0x75, 0x62, // l - - - - p u b + 0x6c, 0x69, 0x63, 0x00, 0x00, 0x00, 0x0a, 0x73, // l i c - - - - s + 0x65, 0x74, 0x2d, 0x63, 0x6f, 0x6f, 0x6b, 0x69, // e t - c o o k i + 0x65, 0x00, 0x00, 0x00, 0x0a, 0x6b, 0x65, 0x65, // e - - - - k e e + 0x70, 0x2d, 0x61, 0x6c, 0x69, 0x76, 0x65, 0x00, // p - a l i v e - + 0x00, 0x00, 0x06, 0x6f, 0x72, 0x69, 0x67, 0x69, // - - - o r i g i + 0x6e, 0x31, 0x30, 0x30, 0x31, 0x30, 0x31, 0x32, // n 1 0 0 1 0 1 2 + 0x30, 0x31, 0x32, 0x30, 0x32, 0x32, 0x30, 0x35, // 0 1 2 0 2 2 0 5 + 0x32, 0x30, 0x36, 0x33, 0x30, 0x30, 0x33, 0x30, // 2 0 6 3 0 0 3 0 + 0x32, 0x33, 0x30, 0x33, 0x33, 0x30, 0x34, 0x33, // 2 3 0 3 3 0 4 3 + 0x30, 0x35, 0x33, 0x30, 0x36, 0x33, 0x30, 0x37, // 0 5 3 0 6 3 0 7 + 0x34, 0x30, 0x32, 0x34, 0x30, 0x35, 0x34, 0x30, // 4 0 2 4 0 5 4 0 + 0x36, 0x34, 0x30, 0x37, 0x34, 0x30, 0x38, 0x34, // 6 4 0 7 4 0 8 4 + 0x30, 0x39, 0x34, 0x31, 0x30, 0x34, 0x31, 0x31, // 0 9 4 1 0 4 1 1 + 0x34, 0x31, 0x32, 0x34, 0x31, 0x33, 0x34, 0x31, // 4 1 2 4 1 3 4 1 + 0x34, 0x34, 0x31, 0x35, 0x34, 0x31, 0x36, 0x34, // 4 4 1 5 4 1 6 4 + 0x31, 0x37, 0x35, 0x30, 0x32, 0x35, 0x30, 0x34, // 1 7 5 0 2 5 0 4 + 0x35, 0x30, 0x35, 0x32, 0x30, 0x33, 0x20, 0x4e, // 5 0 5 2 0 3 - N + 0x6f, 0x6e, 0x2d, 0x41, 0x75, 0x74, 0x68, 0x6f, // o n - A u t h o + 0x72, 0x69, 0x74, 0x61, 0x74, 0x69, 0x76, 0x65, // r i t a t i v e + 0x20, 0x49, 0x6e, 0x66, 0x6f, 0x72, 0x6d, 0x61, // - I n f o r m a + 
0x74, 0x69, 0x6f, 0x6e, 0x32, 0x30, 0x34, 0x20, // t i o n 2 0 4 - + 0x4e, 0x6f, 0x20, 0x43, 0x6f, 0x6e, 0x74, 0x65, // N o - C o n t e + 0x6e, 0x74, 0x33, 0x30, 0x31, 0x20, 0x4d, 0x6f, // n t 3 0 1 - M o + 0x76, 0x65, 0x64, 0x20, 0x50, 0x65, 0x72, 0x6d, // v e d - P e r m + 0x61, 0x6e, 0x65, 0x6e, 0x74, 0x6c, 0x79, 0x34, // a n e n t l y 4 + 0x30, 0x30, 0x20, 0x42, 0x61, 0x64, 0x20, 0x52, // 0 0 - B a d - R + 0x65, 0x71, 0x75, 0x65, 0x73, 0x74, 0x34, 0x30, // e q u e s t 4 0 + 0x31, 0x20, 0x55, 0x6e, 0x61, 0x75, 0x74, 0x68, // 1 - U n a u t h + 0x6f, 0x72, 0x69, 0x7a, 0x65, 0x64, 0x34, 0x30, // o r i z e d 4 0 + 0x33, 0x20, 0x46, 0x6f, 0x72, 0x62, 0x69, 0x64, // 3 - F o r b i d + 0x64, 0x65, 0x6e, 0x34, 0x30, 0x34, 0x20, 0x4e, // d e n 4 0 4 - N + 0x6f, 0x74, 0x20, 0x46, 0x6f, 0x75, 0x6e, 0x64, // o t - F o u n d + 0x35, 0x30, 0x30, 0x20, 0x49, 0x6e, 0x74, 0x65, // 5 0 0 - I n t e + 0x72, 0x6e, 0x61, 0x6c, 0x20, 0x53, 0x65, 0x72, // r n a l - S e r + 0x76, 0x65, 0x72, 0x20, 0x45, 0x72, 0x72, 0x6f, // v e r - E r r o + 0x72, 0x35, 0x30, 0x31, 0x20, 0x4e, 0x6f, 0x74, // r 5 0 1 - N o t + 0x20, 0x49, 0x6d, 0x70, 0x6c, 0x65, 0x6d, 0x65, // - I m p l e m e + 0x6e, 0x74, 0x65, 0x64, 0x35, 0x30, 0x33, 0x20, // n t e d 5 0 3 - + 0x53, 0x65, 0x72, 0x76, 0x69, 0x63, 0x65, 0x20, // S e r v i c e - + 0x55, 0x6e, 0x61, 0x76, 0x61, 0x69, 0x6c, 0x61, // U n a v a i l a + 0x62, 0x6c, 0x65, 0x4a, 0x61, 0x6e, 0x20, 0x46, // b l e J a n - F + 0x65, 0x62, 0x20, 0x4d, 0x61, 0x72, 0x20, 0x41, // e b - M a r - A + 0x70, 0x72, 0x20, 0x4d, 0x61, 0x79, 0x20, 0x4a, // p r - M a y - J + 0x75, 0x6e, 0x20, 0x4a, 0x75, 0x6c, 0x20, 0x41, // u n - J u l - A + 0x75, 0x67, 0x20, 0x53, 0x65, 0x70, 0x74, 0x20, // u g - S e p t - + 0x4f, 0x63, 0x74, 0x20, 0x4e, 0x6f, 0x76, 0x20, // O c t - N o v - + 0x44, 0x65, 0x63, 0x20, 0x30, 0x30, 0x3a, 0x30, // D e c - 0 0 - 0 + 0x30, 0x3a, 0x30, 0x30, 0x20, 0x4d, 0x6f, 0x6e, // 0 - 0 0 - M o n + 0x2c, 0x20, 0x54, 0x75, 0x65, 0x2c, 0x20, 0x57, // - - T u e - - W 
+ 0x65, 0x64, 0x2c, 0x20, 0x54, 0x68, 0x75, 0x2c, // e d - - T h u - + 0x20, 0x46, 0x72, 0x69, 0x2c, 0x20, 0x53, 0x61, // - F r i - - S a + 0x74, 0x2c, 0x20, 0x53, 0x75, 0x6e, 0x2c, 0x20, // t - - S u n - - + 0x47, 0x4d, 0x54, 0x63, 0x68, 0x75, 0x6e, 0x6b, // G M T c h u n k + 0x65, 0x64, 0x2c, 0x74, 0x65, 0x78, 0x74, 0x2f, // e d - t e x t - + 0x68, 0x74, 0x6d, 0x6c, 0x2c, 0x69, 0x6d, 0x61, // h t m l - i m a + 0x67, 0x65, 0x2f, 0x70, 0x6e, 0x67, 0x2c, 0x69, // g e - p n g - i + 0x6d, 0x61, 0x67, 0x65, 0x2f, 0x6a, 0x70, 0x67, // m a g e - j p g + 0x2c, 0x69, 0x6d, 0x61, 0x67, 0x65, 0x2f, 0x67, // - i m a g e - g + 0x69, 0x66, 0x2c, 0x61, 0x70, 0x70, 0x6c, 0x69, // i f - a p p l i + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, // c a t i o n - x + 0x6d, 0x6c, 0x2c, 0x61, 0x70, 0x70, 0x6c, 0x69, // m l - a p p l i + 0x63, 0x61, 0x74, 0x69, 0x6f, 0x6e, 0x2f, 0x78, // c a t i o n - x + 0x68, 0x74, 0x6d, 0x6c, 0x2b, 0x78, 0x6d, 0x6c, // h t m l - x m l + 0x2c, 0x74, 0x65, 0x78, 0x74, 0x2f, 0x70, 0x6c, // - t e x t - p l + 0x61, 0x69, 0x6e, 0x2c, 0x74, 0x65, 0x78, 0x74, // a i n - t e x t + 0x2f, 0x6a, 0x61, 0x76, 0x61, 0x73, 0x63, 0x72, // - j a v a s c r + 0x69, 0x70, 0x74, 0x2c, 0x70, 0x75, 0x62, 0x6c, // i p t - p u b l + 0x69, 0x63, 0x70, 0x72, 0x69, 0x76, 0x61, 0x74, // i c p r i v a t + 0x65, 0x6d, 0x61, 0x78, 0x2d, 0x61, 0x67, 0x65, // e m a x - a g e + 0x3d, 0x67, 0x7a, 0x69, 0x70, 0x2c, 0x64, 0x65, // - g z i p - d e + 0x66, 0x6c, 0x61, 0x74, 0x65, 0x2c, 0x73, 0x64, // f l a t e - s d + 0x63, 0x68, 0x63, 0x68, 0x61, 0x72, 0x73, 0x65, // c h c h a r s e + 0x74, 0x3d, 0x75, 0x74, 0x66, 0x2d, 0x38, 0x63, // t - u t f - 8 c + 0x68, 0x61, 0x72, 0x73, 0x65, 0x74, 0x3d, 0x69, // h a r s e t - i + 0x73, 0x6f, 0x2d, 0x38, 0x38, 0x35, 0x39, 0x2d, // s o - 8 8 5 9 - + 0x31, 0x2c, 0x75, 0x74, 0x66, 0x2d, 0x2c, 0x2a, // 1 - u t f - - - + 0x2c, 0x65, 0x6e, 0x71, 0x3d, 0x30, 0x2e // - e n q - 0 - + }; + + private SpdyCodecUtil() { + } + + /** + * Reads a big-endian 
unsigned short integer from the buffer. + */ + static int getUnsignedShort(ByteBuf buf, int offset) { + return (buf.getByte(offset) & 0xFF) << 8 | + buf.getByte(offset + 1) & 0xFF; + } + + /** + * Reads a big-endian unsigned medium integer from the buffer. + */ + static int getUnsignedMedium(ByteBuf buf, int offset) { + return (buf.getByte(offset) & 0xFF) << 16 | + (buf.getByte(offset + 1) & 0xFF) << 8 | + buf.getByte(offset + 2) & 0xFF; + } + + /** + * Reads a big-endian (31-bit) integer from the buffer. + */ + static int getUnsignedInt(ByteBuf buf, int offset) { + return (buf.getByte(offset) & 0x7F) << 24 | + (buf.getByte(offset + 1) & 0xFF) << 16 | + (buf.getByte(offset + 2) & 0xFF) << 8 | + buf.getByte(offset + 3) & 0xFF; + } + + /** + * Reads a big-endian signed integer from the buffer. + */ + static int getSignedInt(ByteBuf buf, int offset) { + return (buf.getByte(offset) & 0xFF) << 24 | + (buf.getByte(offset + 1) & 0xFF) << 16 | + (buf.getByte(offset + 2) & 0xFF) << 8 | + buf.getByte(offset + 3) & 0xFF; + } + + /** + * Returns {@code true} if ID is for a server initiated stream or ping. + */ + static boolean isServerId(int id) { + // Server initiated streams and pings have even IDs + return id % 2 == 0; + } + + /** + * Validate a SPDY header name. + */ + static void validateHeaderName(CharSequence name) { + checkNonEmpty(name, "name"); + // Since name may only contain ascii characters, for valid names + // name.length() returns the number of bytes when UTF-8 encoded. 
+ if (name.length() > SPDY_MAX_NV_LENGTH) { + throw new IllegalArgumentException( + "name exceeds allowable length: " + name); + } + for (int i = 0; i < name.length(); i ++) { + char c = name.charAt(i); + if (c == 0) { + throw new IllegalArgumentException( + "name contains null character: " + name); + } + if (c >= 'A' && c <= 'Z') { + throw new IllegalArgumentException("name must be all lower case."); + } + if (c > 127) { + throw new IllegalArgumentException( + "name contains non-ascii character: " + name); + } + } + } + + /** + * Validate a SPDY header value. Does not validate max length. + */ + static void validateHeaderValue(CharSequence value) { + checkNotNull(value, "value"); + for (int i = 0; i < value.length(); i ++) { + char c = value.charAt(i); + if (c == 0) { + throw new IllegalArgumentException( + "value contains null character"); + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyDataFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyDataFrame.java new file mode 100644 index 0000000..56ec79d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyDataFrame.java @@ -0,0 +1,65 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.Unpooled; + +/** + * A SPDY Protocol DATA Frame + */ +public interface SpdyDataFrame extends ByteBufHolder, SpdyStreamFrame { + + @Override + SpdyDataFrame setStreamId(int streamID); + + @Override + SpdyDataFrame setLast(boolean last); + + /** + * Returns the data payload of this frame. If there is no data payload + * {@link Unpooled#EMPTY_BUFFER} is returned. + * + * The data payload cannot exceed 16777215 bytes. + */ + @Override + ByteBuf content(); + + @Override + SpdyDataFrame copy(); + + @Override + SpdyDataFrame duplicate(); + + @Override + SpdyDataFrame retainedDuplicate(); + + @Override + SpdyDataFrame replace(ByteBuf content); + + @Override + SpdyDataFrame retain(); + + @Override + SpdyDataFrame retain(int increment); + + @Override + SpdyDataFrame touch(); + + @Override + SpdyDataFrame touch(Object hint); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrame.java new file mode 100644 index 0000000..254756a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrame.java @@ -0,0 +1,23 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +/** + * A SPDY Protocol Frame + */ +public interface SpdyFrame { + // Tag interface +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameCodec.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameCodec.java new file mode 100644 index 0000000..c67992f --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameCodec.java @@ -0,0 +1,410 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.UnsupportedMessageTypeException; + +import java.net.SocketAddress; +import java.util.List; + +/** + * A {@link ChannelHandler} that encodes and decodes SPDY Frames. 
+ */ +public class SpdyFrameCodec extends ByteToMessageDecoder + implements SpdyFrameDecoderDelegate, ChannelOutboundHandler { + + private static final SpdyProtocolException INVALID_FRAME = + new SpdyProtocolException("Received invalid frame"); + + private final SpdyFrameDecoder spdyFrameDecoder; + private final SpdyFrameEncoder spdyFrameEncoder; + private final SpdyHeaderBlockDecoder spdyHeaderBlockDecoder; + private final SpdyHeaderBlockEncoder spdyHeaderBlockEncoder; + + private SpdyHeadersFrame spdyHeadersFrame; + private SpdySettingsFrame spdySettingsFrame; + + private ChannelHandlerContext ctx; + private boolean read; + private final boolean validateHeaders; + + /** + * Creates a new instance with the specified {@code version}, + * {@code validateHeaders (true)}, and + * the default decoder and encoder options + * ({@code maxChunkSize (8192)}, {@code maxHeaderSize (16384)}, + * {@code compressionLevel (6)}, {@code windowBits (15)}, + * and {@code memLevel (8)}). + */ + public SpdyFrameCodec(SpdyVersion version) { + this(version, true); + } + + /** + * Creates a new instance with the specified {@code version}, + * {@code validateHeaders}, and + * the default decoder and encoder options + * ({@code maxChunkSize (8192)}, {@code maxHeaderSize (16384)}, + * {@code compressionLevel (6)}, {@code windowBits (15)}, + * and {@code memLevel (8)}). + */ + public SpdyFrameCodec(SpdyVersion version, boolean validateHeaders) { + this(version, 8192, 16384, 6, 15, 8, validateHeaders); + } + + /** + * Creates a new instance with the specified {@code version}, {@code validateHeaders (true)}, + * decoder and encoder options. 
+ */ + public SpdyFrameCodec( + SpdyVersion version, int maxChunkSize, int maxHeaderSize, + int compressionLevel, int windowBits, int memLevel) { + this(version, maxChunkSize, maxHeaderSize, compressionLevel, windowBits, memLevel, true); + } + + /** + * Creates a new instance with the specified {@code version}, {@code validateHeaders}, + * decoder and encoder options. + */ + public SpdyFrameCodec( + SpdyVersion version, int maxChunkSize, int maxHeaderSize, + int compressionLevel, int windowBits, int memLevel, boolean validateHeaders) { + this(version, maxChunkSize, + SpdyHeaderBlockDecoder.newInstance(version, maxHeaderSize), + SpdyHeaderBlockEncoder.newInstance(version, compressionLevel, windowBits, memLevel), validateHeaders); + } + + protected SpdyFrameCodec(SpdyVersion version, int maxChunkSize, + SpdyHeaderBlockDecoder spdyHeaderBlockDecoder, SpdyHeaderBlockEncoder spdyHeaderBlockEncoder, + boolean validateHeaders) { + spdyFrameDecoder = new SpdyFrameDecoder(version, this, maxChunkSize); + spdyFrameEncoder = new SpdyFrameEncoder(version); + this.spdyHeaderBlockDecoder = spdyHeaderBlockDecoder; + this.spdyHeaderBlockEncoder = spdyHeaderBlockEncoder; + this.validateHeaders = validateHeaders; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + this.ctx = ctx; + ctx.channel().closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + spdyHeaderBlockDecoder.end(); + spdyHeaderBlockEncoder.end(); + } + }); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + spdyFrameDecoder.decode(in); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + if (!read) { + if (!ctx.channel().config().isAutoRead()) { + ctx.read(); + } + } + read = false; + super.channelReadComplete(ctx); + } + + @Override + public void 
bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + ctx.read(); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + ctx.flush(); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ByteBuf frame; + + if (msg instanceof SpdyDataFrame) { + + SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg; + frame = spdyFrameEncoder.encodeDataFrame( + ctx.alloc(), + spdyDataFrame.streamId(), + spdyDataFrame.isLast(), + spdyDataFrame.content() + ); + spdyDataFrame.release(); + ctx.write(frame, promise); + + } else if (msg instanceof SpdySynStreamFrame) { + + SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg; + ByteBuf headerBlock = spdyHeaderBlockEncoder.encode(ctx.alloc(), spdySynStreamFrame); + try { + frame = spdyFrameEncoder.encodeSynStreamFrame( + ctx.alloc(), + spdySynStreamFrame.streamId(), + spdySynStreamFrame.associatedStreamId(), + spdySynStreamFrame.priority(), + spdySynStreamFrame.isLast(), + spdySynStreamFrame.isUnidirectional(), + headerBlock + ); + } finally { + headerBlock.release(); + } + ctx.write(frame, promise); + + } else if (msg instanceof 
SpdySynReplyFrame) { + + SpdySynReplyFrame spdySynReplyFrame = (SpdySynReplyFrame) msg; + ByteBuf headerBlock = spdyHeaderBlockEncoder.encode(ctx.alloc(), spdySynReplyFrame); + try { + frame = spdyFrameEncoder.encodeSynReplyFrame( + ctx.alloc(), + spdySynReplyFrame.streamId(), + spdySynReplyFrame.isLast(), + headerBlock + ); + } finally { + headerBlock.release(); + } + ctx.write(frame, promise); + + } else if (msg instanceof SpdyRstStreamFrame) { + + SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg; + frame = spdyFrameEncoder.encodeRstStreamFrame( + ctx.alloc(), + spdyRstStreamFrame.streamId(), + spdyRstStreamFrame.status().code() + ); + ctx.write(frame, promise); + + } else if (msg instanceof SpdySettingsFrame) { + + SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg; + frame = spdyFrameEncoder.encodeSettingsFrame( + ctx.alloc(), + spdySettingsFrame + ); + ctx.write(frame, promise); + + } else if (msg instanceof SpdyPingFrame) { + + SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg; + frame = spdyFrameEncoder.encodePingFrame( + ctx.alloc(), + spdyPingFrame.id() + ); + ctx.write(frame, promise); + + } else if (msg instanceof SpdyGoAwayFrame) { + + SpdyGoAwayFrame spdyGoAwayFrame = (SpdyGoAwayFrame) msg; + frame = spdyFrameEncoder.encodeGoAwayFrame( + ctx.alloc(), + spdyGoAwayFrame.lastGoodStreamId(), + spdyGoAwayFrame.status().code() + ); + ctx.write(frame, promise); + + } else if (msg instanceof SpdyHeadersFrame) { + + SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg; + ByteBuf headerBlock = spdyHeaderBlockEncoder.encode(ctx.alloc(), spdyHeadersFrame); + try { + frame = spdyFrameEncoder.encodeHeadersFrame( + ctx.alloc(), + spdyHeadersFrame.streamId(), + spdyHeadersFrame.isLast(), + headerBlock + ); + } finally { + headerBlock.release(); + } + ctx.write(frame, promise); + + } else if (msg instanceof SpdyWindowUpdateFrame) { + + SpdyWindowUpdateFrame spdyWindowUpdateFrame = (SpdyWindowUpdateFrame) msg; + frame = 
spdyFrameEncoder.encodeWindowUpdateFrame( + ctx.alloc(), + spdyWindowUpdateFrame.streamId(), + spdyWindowUpdateFrame.deltaWindowSize() + ); + ctx.write(frame, promise); + } else { + throw new UnsupportedMessageTypeException(msg); + } + } + + @Override + public void readDataFrame(int streamId, boolean last, ByteBuf data) { + read = true; + + SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(streamId, data); + spdyDataFrame.setLast(last); + ctx.fireChannelRead(spdyDataFrame); + } + + @Override + public void readSynStreamFrame( + int streamId, int associatedToStreamId, byte priority, boolean last, boolean unidirectional) { + SpdySynStreamFrame spdySynStreamFrame = + new DefaultSpdySynStreamFrame(streamId, associatedToStreamId, priority, validateHeaders); + spdySynStreamFrame.setLast(last); + spdySynStreamFrame.setUnidirectional(unidirectional); + spdyHeadersFrame = spdySynStreamFrame; + } + + @Override + public void readSynReplyFrame(int streamId, boolean last) { + SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId, validateHeaders); + spdySynReplyFrame.setLast(last); + spdyHeadersFrame = spdySynReplyFrame; + } + + @Override + public void readRstStreamFrame(int streamId, int statusCode) { + read = true; + + SpdyRstStreamFrame spdyRstStreamFrame = new DefaultSpdyRstStreamFrame(streamId, statusCode); + ctx.fireChannelRead(spdyRstStreamFrame); + } + + @Override + public void readSettingsFrame(boolean clearPersisted) { + read = true; + + spdySettingsFrame = new DefaultSpdySettingsFrame(); + spdySettingsFrame.setClearPreviouslyPersistedSettings(clearPersisted); + } + + @Override + public void readSetting(int id, int value, boolean persistValue, boolean persisted) { + spdySettingsFrame.setValue(id, value, persistValue, persisted); + } + + @Override + public void readSettingsEnd() { + read = true; + + Object frame = spdySettingsFrame; + spdySettingsFrame = null; + ctx.fireChannelRead(frame); + } + + @Override + public void readPingFrame(int id) 
{ + read = true; + + SpdyPingFrame spdyPingFrame = new DefaultSpdyPingFrame(id); + ctx.fireChannelRead(spdyPingFrame); + } + + @Override + public void readGoAwayFrame(int lastGoodStreamId, int statusCode) { + read = true; + + SpdyGoAwayFrame spdyGoAwayFrame = new DefaultSpdyGoAwayFrame(lastGoodStreamId, statusCode); + ctx.fireChannelRead(spdyGoAwayFrame); + } + + @Override + public void readHeadersFrame(int streamId, boolean last) { + spdyHeadersFrame = new DefaultSpdyHeadersFrame(streamId, validateHeaders); + spdyHeadersFrame.setLast(last); + } + + @Override + public void readWindowUpdateFrame(int streamId, int deltaWindowSize) { + read = true; + + SpdyWindowUpdateFrame spdyWindowUpdateFrame = new DefaultSpdyWindowUpdateFrame(streamId, deltaWindowSize); + ctx.fireChannelRead(spdyWindowUpdateFrame); + } + + @Override + public void readHeaderBlock(ByteBuf headerBlock) { + try { + spdyHeaderBlockDecoder.decode(ctx.alloc(), headerBlock, spdyHeadersFrame); + } catch (Exception e) { + ctx.fireExceptionCaught(e); + } finally { + headerBlock.release(); + } + } + + @Override + public void readHeaderBlockEnd() { + Object frame = null; + try { + spdyHeaderBlockDecoder.endHeaderBlock(spdyHeadersFrame); + frame = spdyHeadersFrame; + spdyHeadersFrame = null; + } catch (Exception e) { + ctx.fireExceptionCaught(e); + } + if (frame != null) { + read = true; + + ctx.fireChannelRead(frame); + } + } + + @Override + public void readFrameError(String message) { + ctx.fireExceptionCaught(INVALID_FRAME); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java new file mode 100644 index 0000000..afc222c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyFrameDecoder.java @@ -0,0 +1,457 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + 
* version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_DATA_FLAG_FIN; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_DATA_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_FLAG_FIN; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_FLAG_UNIDIRECTIONAL; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_GOAWAY_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_HEADERS_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_HEADER_FLAGS_OFFSET; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_HEADER_LENGTH_OFFSET; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_HEADER_SIZE; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_HEADER_TYPE_OFFSET; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_PING_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_RST_STREAM_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SETTINGS_CLEAR; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SETTINGS_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SETTINGS_PERSISTED; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SETTINGS_PERSIST_VALUE; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SYN_REPLY_FRAME; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SYN_STREAM_FRAME; +import 
static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_WINDOW_UPDATE_FRAME;
import static io.netty.handler.codec.spdy.SpdyCodecUtil.getSignedInt;
import static io.netty.handler.codec.spdy.SpdyCodecUtil.getUnsignedInt;
import static io.netty.handler.codec.spdy.SpdyCodecUtil.getUnsignedMedium;
import static io.netty.handler.codec.spdy.SpdyCodecUtil.getUnsignedShort;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.ObjectUtil;

/**
 * Decodes {@link ByteBuf}s into SPDY Frames.
 *
 * <p>This is an incremental, push-based decoder: {@link #decode(ByteBuf)} consumes as
 * many complete frames (or frame fragments) as the buffer holds and reports them to the
 * {@link SpdyFrameDecoderDelegate}. When insufficient bytes are available for the
 * current state, it returns and resumes from the same state on the next call.
 * This class is stateful and therefore not thread-safe.
 */
public class SpdyFrameDecoder {

    // SPDY protocol version this decoder accepts (frames with other versions are errors).
    private final int spdyVersion;
    // DATA frame payloads are emitted in chunks of at most this many bytes.
    private final int maxChunkSize;

    private final SpdyFrameDecoderDelegate delegate;

    // Current position in the decoding state machine; persists across decode() calls.
    private State state;

    // SPDY common header fields, captured in READ_COMMON_HEADER and consumed by the
    // per-frame states that follow.
    private byte flags;
    private int length;   // remaining payload bytes of the frame being decoded
    private int streamId;

    // Remaining ID/Value entries of the SETTINGS frame being decoded.
    private int numSettings;

    private enum State {
        READ_COMMON_HEADER,
        READ_DATA_FRAME,
        READ_SYN_STREAM_FRAME,
        READ_SYN_REPLY_FRAME,
        READ_RST_STREAM_FRAME,
        READ_SETTINGS_FRAME,
        READ_SETTING,
        READ_PING_FRAME,
        READ_GOAWAY_FRAME,
        READ_HEADERS_FRAME,
        READ_WINDOW_UPDATE_FRAME,
        READ_HEADER_BLOCK,
        DISCARD_FRAME,
        FRAME_ERROR
    }

    /**
     * Creates a new instance with the specified {@code version}
     * and the default {@code maxChunkSize (8192)}.
     */
    public SpdyFrameDecoder(SpdyVersion spdyVersion, SpdyFrameDecoderDelegate delegate) {
        this(spdyVersion, delegate, 8192);
    }

    /**
     * Creates a new instance with the specified parameters.
     *
     * @param spdyVersion  the SPDY protocol version to decode; must not be {@code null}
     * @param delegate     receives decoded frames and error notifications; must not be {@code null}
     * @param maxChunkSize maximum DATA chunk size emitted per callback; must be positive
     */
    public SpdyFrameDecoder(SpdyVersion spdyVersion, SpdyFrameDecoderDelegate delegate, int maxChunkSize) {
        this.spdyVersion = ObjectUtil.checkNotNull(spdyVersion, "spdyVersion").getVersion();
        this.delegate = ObjectUtil.checkNotNull(delegate, "delegate");
        this.maxChunkSize = ObjectUtil.checkPositive(maxChunkSize, "maxChunkSize");
        state = State.READ_COMMON_HEADER;
    }

    /**
     * Decodes frames from {@code buffer}, invoking the delegate for each decoded frame,
     * setting, header block fragment, or frame error. Returns when the buffer does not
     * contain enough bytes to make further progress; decoding resumes where it left off
     * on the next call.
     */
    public void decode(ByteBuf buffer) {
        boolean last;
        int statusCode;

        while (true) {
            switch(state) {
                case READ_COMMON_HEADER:
                    if (buffer.readableBytes() < SPDY_HEADER_SIZE) {
                        return;
                    }

                    int frameOffset  = buffer.readerIndex();
                    int flagsOffset  = frameOffset + SPDY_HEADER_FLAGS_OFFSET;
                    int lengthOffset = frameOffset + SPDY_HEADER_LENGTH_OFFSET;
                    buffer.skipBytes(SPDY_HEADER_SIZE);

                    // High bit of the first byte distinguishes control from data frames.
                    boolean control = (buffer.getByte(frameOffset) & 0x80) != 0;

                    int version;
                    int type;
                    if (control) {
                        // Decode control frame common header
                        version = getUnsignedShort(buffer, frameOffset) & 0x7FFF;
                        type = getUnsignedShort(buffer, frameOffset + SPDY_HEADER_TYPE_OFFSET);
                        streamId = 0; // Default to session Stream-ID
                    } else {
                        // Decode data frame common header
                        version = spdyVersion; // Default to expected version
                        type = SPDY_DATA_FRAME;
                        streamId = getUnsignedInt(buffer, frameOffset);
                    }

                    flags  = buffer.getByte(flagsOffset);
                    length = getUnsignedMedium(buffer, lengthOffset);

                    // Check version first then validity
                    if (version != spdyVersion) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid SPDY Version");
                    } else if (!isValidFrameHeader(streamId, type, flags, length)) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid Frame Error");
                    } else {
                        state = getNextState(type, length);
                    }
                    break;

                case READ_DATA_FRAME:
                    if (length == 0) {
                        state = State.READ_COMMON_HEADER;
                        delegate.readDataFrame(streamId, hasFlag(flags, SPDY_DATA_FLAG_FIN), Unpooled.buffer(0));
                        break;
                    }

                    // Generate data frames that do not exceed maxChunkSize
                    int dataLength = Math.min(maxChunkSize, length);

                    // Wait until entire frame is readable
                    if (buffer.readableBytes() < dataLength) {
                        return;
                    }

                    ByteBuf data = buffer.alloc().buffer(dataLength);
                    data.writeBytes(buffer, dataLength);
                    length -= dataLength;

                    if (length == 0) {
                        state = State.READ_COMMON_HEADER;
                    }

                    // FIN is reported only with the final chunk of the frame.
                    last = length == 0 && hasFlag(flags, SPDY_DATA_FLAG_FIN);

                    delegate.readDataFrame(streamId, last, data);
                    break;

                case READ_SYN_STREAM_FRAME:
                    if (buffer.readableBytes() < 10) {
                        return;
                    }

                    int offset = buffer.readerIndex();
                    streamId = getUnsignedInt(buffer, offset);
                    int associatedToStreamId = getUnsignedInt(buffer, offset + 4);
                    // Priority is the top 3 bits of the 9th byte.
                    byte priority = (byte) (buffer.getByte(offset + 8) >> 5 & 0x07);
                    last = hasFlag(flags, SPDY_FLAG_FIN);
                    boolean unidirectional = hasFlag(flags, SPDY_FLAG_UNIDIRECTIONAL);
                    buffer.skipBytes(10);
                    length -= 10;

                    if (streamId == 0) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid SYN_STREAM Frame");
                    } else {
                        state = State.READ_HEADER_BLOCK;
                        delegate.readSynStreamFrame(streamId, associatedToStreamId, priority, last, unidirectional);
                    }
                    break;

                case READ_SYN_REPLY_FRAME:
                    if (buffer.readableBytes() < 4) {
                        return;
                    }

                    streamId = getUnsignedInt(buffer, buffer.readerIndex());
                    last = hasFlag(flags, SPDY_FLAG_FIN);

                    buffer.skipBytes(4);
                    length -= 4;

                    if (streamId == 0) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid SYN_REPLY Frame");
                    } else {
                        state = State.READ_HEADER_BLOCK;
                        delegate.readSynReplyFrame(streamId, last);
                    }
                    break;

                case READ_RST_STREAM_FRAME:
                    if (buffer.readableBytes() < 8) {
                        return;
                    }

                    streamId = getUnsignedInt(buffer, buffer.readerIndex());
                    statusCode = getSignedInt(buffer, buffer.readerIndex() + 4);
                    buffer.skipBytes(8);

                    if (streamId == 0 || statusCode == 0) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid RST_STREAM Frame");
                    } else {
                        state = State.READ_COMMON_HEADER;
                        delegate.readRstStreamFrame(streamId, statusCode);
                    }
                    break;

                case READ_SETTINGS_FRAME:
                    if (buffer.readableBytes() < 4) {
                        return;
                    }

                    boolean clear = hasFlag(flags, SPDY_SETTINGS_CLEAR);

                    numSettings = getUnsignedInt(buffer, buffer.readerIndex());
                    buffer.skipBytes(4);
                    length -= 4;

                    // Validate frame length against number of entries. Each ID/Value entry is 8 bytes.
                    if ((length & 0x07) != 0 || length >> 3 != numSettings) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid SETTINGS Frame");
                    } else {
                        state = State.READ_SETTING;
                        delegate.readSettingsFrame(clear);
                    }
                    break;

                case READ_SETTING:
                    if (numSettings == 0) {
                        state = State.READ_COMMON_HEADER;
                        delegate.readSettingsEnd();
                        break;
                    }

                    if (buffer.readableBytes() < 8) {
                        return;
                    }

                    // Entry layout: 1 byte flags, 3 byte ID, 4 byte value.
                    byte settingsFlags = buffer.getByte(buffer.readerIndex());
                    int id = getUnsignedMedium(buffer, buffer.readerIndex() + 1);
                    int value = getSignedInt(buffer, buffer.readerIndex() + 4);
                    boolean persistValue = hasFlag(settingsFlags, SPDY_SETTINGS_PERSIST_VALUE);
                    boolean persisted = hasFlag(settingsFlags, SPDY_SETTINGS_PERSISTED);
                    buffer.skipBytes(8);

                    --numSettings;

                    delegate.readSetting(id, value, persistValue, persisted);
                    break;

                case READ_PING_FRAME:
                    if (buffer.readableBytes() < 4) {
                        return;
                    }

                    int pingId = getSignedInt(buffer, buffer.readerIndex());
                    buffer.skipBytes(4);

                    state = State.READ_COMMON_HEADER;
                    delegate.readPingFrame(pingId);
                    break;

                case READ_GOAWAY_FRAME:
                    if (buffer.readableBytes() < 8) {
                        return;
                    }

                    int lastGoodStreamId = getUnsignedInt(buffer, buffer.readerIndex());
                    statusCode = getSignedInt(buffer, buffer.readerIndex() + 4);
                    buffer.skipBytes(8);

                    state = State.READ_COMMON_HEADER;
                    delegate.readGoAwayFrame(lastGoodStreamId, statusCode);
                    break;

                case READ_HEADERS_FRAME:
                    if (buffer.readableBytes() < 4) {
                        return;
                    }

                    streamId = getUnsignedInt(buffer, buffer.readerIndex());
                    last = hasFlag(flags, SPDY_FLAG_FIN);

                    buffer.skipBytes(4);
                    length -= 4;

                    if (streamId == 0) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid HEADERS Frame");
                    } else {
                        state = State.READ_HEADER_BLOCK;
                        delegate.readHeadersFrame(streamId, last);
                    }
                    break;

                case READ_WINDOW_UPDATE_FRAME:
                    if (buffer.readableBytes() < 8) {
                        return;
                    }

                    streamId = getUnsignedInt(buffer, buffer.readerIndex());
                    int deltaWindowSize = getUnsignedInt(buffer, buffer.readerIndex() + 4);
                    buffer.skipBytes(8);

                    if (deltaWindowSize == 0) {
                        state = State.FRAME_ERROR;
                        delegate.readFrameError("Invalid WINDOW_UPDATE Frame");
                    } else {
                        state = State.READ_COMMON_HEADER;
                        delegate.readWindowUpdateFrame(streamId, deltaWindowSize);
                    }
                    break;

                case READ_HEADER_BLOCK:
                    if (length == 0) {
                        state = State.READ_COMMON_HEADER;
                        delegate.readHeaderBlockEnd();
                        break;
                    }

                    if (!buffer.isReadable()) {
                        return;
                    }

                    // Forward compressed header-block bytes incrementally; no need to
                    // wait for the whole frame.
                    int compressedBytes = Math.min(buffer.readableBytes(), length);
                    ByteBuf headerBlock = buffer.alloc().buffer(compressedBytes);
                    headerBlock.writeBytes(buffer, compressedBytes);
                    length -= compressedBytes;

                    delegate.readHeaderBlock(headerBlock);
                    break;

                case DISCARD_FRAME:
                    // Unknown control frame type: silently skip its payload.
                    int numBytes = Math.min(buffer.readableBytes(), length);
                    buffer.skipBytes(numBytes);
                    length -= numBytes;
                    if (length == 0) {
                        state = State.READ_COMMON_HEADER;
                        break;
                    }
                    return;

                case FRAME_ERROR:
                    // Terminal state: the delegate was already notified; drop all input.
                    buffer.skipBytes(buffer.readableBytes());
                    return;

                default:
                    throw new Error("Shouldn't reach here.");
            }
        }
    }

    // True if the given flag bit is set in flags.
    private static boolean hasFlag(byte flags, byte flag) {
        return (flags & flag) != 0;
    }

    // Maps a (validated) frame type to the state that decodes its payload.
    // Unknown types are discarded (or skipped outright when they carry no payload).
    private static State getNextState(int type, int length) {
        switch (type) {
            case SPDY_DATA_FRAME:
                return State.READ_DATA_FRAME;

            case SPDY_SYN_STREAM_FRAME:
                return State.READ_SYN_STREAM_FRAME;

            case SPDY_SYN_REPLY_FRAME:
                return State.READ_SYN_REPLY_FRAME;

            case SPDY_RST_STREAM_FRAME:
                return State.READ_RST_STREAM_FRAME;

            case SPDY_SETTINGS_FRAME:
                return State.READ_SETTINGS_FRAME;

            case SPDY_PING_FRAME:
                return State.READ_PING_FRAME;

            case SPDY_GOAWAY_FRAME:
                return State.READ_GOAWAY_FRAME;

            case SPDY_HEADERS_FRAME:
                return State.READ_HEADERS_FRAME;

            case SPDY_WINDOW_UPDATE_FRAME:
                return State.READ_WINDOW_UPDATE_FRAME;

            default:
                if (length != 0) {
                    return State.DISCARD_FRAME;
                } else {
                    return State.READ_COMMON_HEADER;
                }
        }
    }

    // Per-type sanity checks on the common header (stream-id presence, exact or
    // minimum payload lengths, flag restrictions). Unknown types always pass.
    private static boolean isValidFrameHeader(int streamId, int type, byte flags, int length) {
        switch (type) {
            case SPDY_DATA_FRAME:
                return streamId != 0;

            case SPDY_SYN_STREAM_FRAME:
                return length >= 10;

            case SPDY_SYN_REPLY_FRAME:
                return length >= 4;

            case SPDY_RST_STREAM_FRAME:
                return flags == 0 && length == 8;

            case SPDY_SETTINGS_FRAME:
                return length >= 4;

            case SPDY_PING_FRAME:
                return length == 4;

            case SPDY_GOAWAY_FRAME:
                return length == 8;

            case SPDY_HEADERS_FRAME:
                return length >= 4;

            case SPDY_WINDOW_UPDATE_FRAME:
                return length == 8;

            default:
                return true;
        }
    }
}
package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;

/**
 * Callback interface for {@link SpdyFrameDecoder}.
 *
 * <p>Callbacks are invoked in decode order. Frames that carry a Name/Value header
 * block (SYN_STREAM, SYN_REPLY, HEADERS) are reported in pieces: the frame callback
 * first, then one or more {@link #readHeaderBlock(ByteBuf)} calls, then
 * {@link #readHeaderBlockEnd()}.
 */
public interface SpdyFrameDecoderDelegate {

    /**
     * Called when a DATA frame is received.
     * The receiver takes ownership of (and must release) {@code data}.
     */
    void readDataFrame(int streamId, boolean last, ByteBuf data);

    /**
     * Called when a SYN_STREAM frame is received.
     * The Name/Value Header Block is not included. See readHeaderBlock().
     */
    void readSynStreamFrame(
            int streamId, int associatedToStreamId, byte priority, boolean last, boolean unidirectional);

    /**
     * Called when a SYN_REPLY frame is received.
     * The Name/Value Header Block is not included. See readHeaderBlock().
     */
    void readSynReplyFrame(int streamId, boolean last);

    /**
     * Called when a RST_STREAM frame is received.
     */
    void readRstStreamFrame(int streamId, int statusCode);

    /**
     * Called when a SETTINGS frame is received.
     * Settings are not included. See readSetting().
     */
    void readSettingsFrame(boolean clearPersisted);

    /**
     * Called when an individual setting within a SETTINGS frame is received.
     */
    void readSetting(int id, int value, boolean persistValue, boolean persisted);

    /**
     * Called when the entire SETTINGS frame has been received.
     */
    void readSettingsEnd();

    /**
     * Called when a PING frame is received.
     */
    void readPingFrame(int id);

    /**
     * Called when a GOAWAY frame is received.
     */
    void readGoAwayFrame(int lastGoodStreamId, int statusCode);

    /**
     * Called when a HEADERS frame is received.
     * The Name/Value Header Block is not included. See readHeaderBlock().
     */
    void readHeadersFrame(int streamId, boolean last);

    /**
     * Called when a WINDOW_UPDATE frame is received.
     */
    void readWindowUpdateFrame(int streamId, int deltaWindowSize);

    /**
     * Called when the header block within a SYN_STREAM, SYN_REPLY, or HEADERS frame is received.
     * The receiver takes ownership of (and must release) {@code headerBlock}.
     */
    void readHeaderBlock(ByteBuf headerBlock);

    /**
     * Called when an entire header block has been received.
     */
    void readHeaderBlockEnd();

    /**
     * Called when an unrecoverable session error has occurred.
     */
    void readFrameError(String message);
}

package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.internal.ObjectUtil;

import java.nio.ByteOrder;
import java.util.Set;

import static io.netty.handler.codec.spdy.SpdyCodecUtil.*;
+ */ +public class SpdyFrameEncoder { + + private final int version; + + /** + * Creates a new instance with the specified {@code spdyVersion}. + */ + public SpdyFrameEncoder(SpdyVersion spdyVersion) { + version = ObjectUtil.checkNotNull(spdyVersion, "spdyVersion").getVersion(); + } + + private void writeControlFrameHeader(ByteBuf buffer, int type, byte flags, int length) { + buffer.writeShort(version | 0x8000); + buffer.writeShort(type); + buffer.writeByte(flags); + buffer.writeMedium(length); + } + + public ByteBuf encodeDataFrame(ByteBufAllocator allocator, int streamId, boolean last, ByteBuf data) { + byte flags = last ? SPDY_DATA_FLAG_FIN : 0; + int length = data.readableBytes(); + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + frame.writeInt(streamId & 0x7FFFFFFF); + frame.writeByte(flags); + frame.writeMedium(length); + frame.writeBytes(data, data.readerIndex(), length); + return frame; + } + + public ByteBuf encodeSynStreamFrame(ByteBufAllocator allocator, int streamId, int associatedToStreamId, + byte priority, boolean last, boolean unidirectional, ByteBuf headerBlock) { + int headerBlockLength = headerBlock.readableBytes(); + byte flags = last ? SPDY_FLAG_FIN : 0; + if (unidirectional) { + flags |= SPDY_FLAG_UNIDIRECTIONAL; + } + int length = 10 + headerBlockLength; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_SYN_STREAM_FRAME, flags, length); + frame.writeInt(streamId); + frame.writeInt(associatedToStreamId); + frame.writeShort((priority & 0xFF) << 13); + frame.writeBytes(headerBlock, headerBlock.readerIndex(), headerBlockLength); + return frame; + } + + public ByteBuf encodeSynReplyFrame(ByteBufAllocator allocator, int streamId, boolean last, ByteBuf headerBlock) { + int headerBlockLength = headerBlock.readableBytes(); + byte flags = last ? 
SPDY_FLAG_FIN : 0; + int length = 4 + headerBlockLength; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_SYN_REPLY_FRAME, flags, length); + frame.writeInt(streamId); + frame.writeBytes(headerBlock, headerBlock.readerIndex(), headerBlockLength); + return frame; + } + + public ByteBuf encodeRstStreamFrame(ByteBufAllocator allocator, int streamId, int statusCode) { + byte flags = 0; + int length = 8; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_RST_STREAM_FRAME, flags, length); + frame.writeInt(streamId); + frame.writeInt(statusCode); + return frame; + } + + public ByteBuf encodeSettingsFrame(ByteBufAllocator allocator, SpdySettingsFrame spdySettingsFrame) { + Set ids = spdySettingsFrame.ids(); + int numSettings = ids.size(); + + byte flags = spdySettingsFrame.clearPreviouslyPersistedSettings() ? + SPDY_SETTINGS_CLEAR : 0; + int length = 4 + 8 * numSettings; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_SETTINGS_FRAME, flags, length); + frame.writeInt(numSettings); + for (Integer id : ids) { + flags = 0; + if (spdySettingsFrame.isPersistValue(id)) { + flags |= SPDY_SETTINGS_PERSIST_VALUE; + } + if (spdySettingsFrame.isPersisted(id)) { + flags |= SPDY_SETTINGS_PERSISTED; + } + frame.writeByte(flags); + frame.writeMedium(id); + frame.writeInt(spdySettingsFrame.getValue(id)); + } + return frame; + } + + public ByteBuf encodePingFrame(ByteBufAllocator allocator, int id) { + byte flags = 0; + int length = 4; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_PING_FRAME, flags, length); + frame.writeInt(id); + return frame; + } + + public ByteBuf encodeGoAwayFrame(ByteBufAllocator allocator, int lastGoodStreamId, int statusCode) { + byte flags = 
0; + int length = 8; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_GOAWAY_FRAME, flags, length); + frame.writeInt(lastGoodStreamId); + frame.writeInt(statusCode); + return frame; + } + + public ByteBuf encodeHeadersFrame(ByteBufAllocator allocator, int streamId, boolean last, ByteBuf headerBlock) { + int headerBlockLength = headerBlock.readableBytes(); + byte flags = last ? SPDY_FLAG_FIN : 0; + int length = 4 + headerBlockLength; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_HEADERS_FRAME, flags, length); + frame.writeInt(streamId); + frame.writeBytes(headerBlock, headerBlock.readerIndex(), headerBlockLength); + return frame; + } + + public ByteBuf encodeWindowUpdateFrame(ByteBufAllocator allocator, int streamId, int deltaWindowSize) { + byte flags = 0; + int length = 8; + ByteBuf frame = allocator.ioBuffer(SPDY_HEADER_SIZE + length).order(ByteOrder.BIG_ENDIAN); + writeControlFrameHeader(frame, SPDY_WINDOW_UPDATE_FRAME, flags, length); + frame.writeInt(streamId); + frame.writeInt(deltaWindowSize); + return frame; + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyGoAwayFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyGoAwayFrame.java new file mode 100644 index 0000000..a098a71 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyGoAwayFrame.java @@ -0,0 +1,43 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.codec.spdy;

/**
 * A SPDY Protocol GOAWAY Frame.
 *
 * <p>Informs the remote side of the last stream this endpoint accepted and why the
 * session is being torn down.
 */
public interface SpdyGoAwayFrame extends SpdyFrame {

    /**
     * Returns the Last-good-stream-ID of this frame.
     */
    int lastGoodStreamId();

    /**
     * Sets the Last-good-stream-ID of this frame. The Last-good-stream-ID
     * cannot be negative.
     *
     * @return this frame, for method chaining
     */
    SpdyGoAwayFrame setLastGoodStreamId(int lastGoodStreamId);

    /**
     * Returns the status of this frame.
     */
    SpdySessionStatus status();

    /**
     * Sets the status of this frame.
     *
     * @return this frame, for method chaining
     */
    SpdyGoAwayFrame setStatus(SpdySessionStatus status);
}
package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;

/**
 * Super-class for SPDY header-block decoders.
 *
 * @see SpdyHeaderBlockRawDecoder
 * @see SpdyHeaderBlockZlibDecoder
 */
public abstract class SpdyHeaderBlockDecoder {

    // Factory for the default decoder implementation (zlib-compressed header blocks).
    static SpdyHeaderBlockDecoder newInstance(SpdyVersion spdyVersion, int maxHeaderSize) {
        return new SpdyHeaderBlockZlibDecoder(spdyVersion, maxHeaderSize);
    }

    /**
     * Decodes a SPDY Header Block, adding the Name/Value pairs to the given Headers frame.
     * If the header block is malformed, the Headers frame will be marked as invalid.
     * A stream error with status code PROTOCOL_ERROR must be issued in response to an invalid frame.
     *
     * @param alloc the {@link ByteBufAllocator} which can be used to allocate new {@link ByteBuf}s
     * @param headerBlock the HeaderBlock to decode
     * @param frame the Headers frame that receives the Name/Value pairs
     * @throws Exception If the header block is malformed in a way that prevents any future
     *         decoding of any other header blocks, an exception will be thrown.
     *         A session error with status code PROTOCOL_ERROR must be issued.
     */
    abstract void decode(ByteBufAllocator alloc, ByteBuf headerBlock, SpdyHeadersFrame frame) throws Exception;

    /**
     * Signals that the header block for {@code frame} is complete; implementations
     * finalize decoding and may mark the frame invalid or truncated.
     */
    abstract void endHeaderBlock(SpdyHeadersFrame frame) throws Exception;

    /**
     * Releases any resources held by the decoder (e.g. cumulation buffers or
     * decompression state). The decoder must not be used afterwards.
     */
    abstract void end();
}

package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.internal.PlatformDependent;
/**
 * Super-class for SPDY header-block encoders.
 *
 * @see SpdyHeaderBlockZlibEncoder
 * @see SpdyHeaderBlockJZlibEncoder
 * @see SpdyHeaderBlockRawEncoder
 */
public abstract class SpdyHeaderBlockEncoder {

    // Picks the JDK-zlib based encoder on Java 7+, otherwise the JZlib fallback
    // (only JZlib exposes windowBits/memLevel).
    static SpdyHeaderBlockEncoder newInstance(
            SpdyVersion version, int compressionLevel, int windowBits, int memLevel) {

        if (PlatformDependent.javaVersion() >= 7) {
            return new SpdyHeaderBlockZlibEncoder(
                    version, compressionLevel);
        } else {
            return new SpdyHeaderBlockJZlibEncoder(
                    version, compressionLevel, windowBits, memLevel);
        }
    }

    // Encodes the frame's Name/Value pairs into a (possibly compressed) header block.
    abstract ByteBuf encode(ByteBufAllocator alloc, SpdyHeadersFrame frame) throws Exception;

    // Releases encoder resources; the encoder must not be used afterwards.
    abstract void end();
}

/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.spdy;

import io.netty.zlib.Deflater;
import io.netty.zlib.JZlib;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.handler.codec.compression.CompressionException;

import static io.netty.handler.codec.spdy.SpdyCodecUtil.*;
import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE;

/**
 * {@link SpdyHeaderBlockRawEncoder} subclass that compresses the raw-encoded
 * header block with JZlib, primed with the SPDY dictionary.
 *
 * <p>Stateful (shares one {@link Deflater} across calls) and not thread-safe.
 */
class SpdyHeaderBlockJZlibEncoder extends SpdyHeaderBlockRawEncoder {

    private final Deflater z = new Deflater();

    // Set by end(); once finished, encode() returns an empty buffer.
    private boolean finished;

    SpdyHeaderBlockJZlibEncoder(
            SpdyVersion version, int compressionLevel, int windowBits, int memLevel) {
        super(version);
        if (compressionLevel < 0 || compressionLevel > 9) {
            throw new IllegalArgumentException(
                    "compressionLevel: " + compressionLevel + " (expected: 0-9)");
        }
        if (windowBits < 9 || windowBits > 15) {
            throw new IllegalArgumentException(
                    "windowBits: " + windowBits + " (expected: 9-15)");
        }
        if (memLevel < 1 || memLevel > 9) {
            throw new IllegalArgumentException(
                    "memLevel: " + memLevel + " (expected: 1-9)");
        }

        int resultCode = z.deflateInit(
                compressionLevel, windowBits, memLevel, JZlib.W_ZLIB);
        if (resultCode != JZlib.Z_OK) {
            throw new CompressionException(
                    "failed to initialize an SPDY header block deflater: " + resultCode);
        } else {
            // Prime the deflater with the shared SPDY dictionary so common header
            // strings compress well from the first block.
            resultCode = z.deflateSetDictionary(SPDY_DICT, SPDY_DICT.length);
            if (resultCode != JZlib.Z_OK) {
                throw new CompressionException(
                        "failed to set the SPDY dictionary: " + resultCode);
            }
        }
    }

    // Points the deflater's input at the decompressed block, copying only when the
    // buffer has no backing array.
    private void setInput(ByteBuf decompressed) {
        int len = decompressed.readableBytes();

        byte[] in;
        int offset;
        if (decompressed.hasArray()) {
            in = decompressed.array();
            offset = decompressed.arrayOffset() + decompressed.readerIndex();
        } else {
            in = new byte[len];
            decompressed.getBytes(decompressed.readerIndex(), in);
            offset = 0;
        }
        z.next_in = in;
        z.next_in_index = offset;
        z.avail_in = len;
    }

    // Runs one Z_SYNC_FLUSH deflate pass over the input set by setInput() and
    // returns the compressed output; releases the output buffer on failure.
    private ByteBuf encode(ByteBufAllocator alloc) {
        boolean release = true;
        ByteBuf out = null;
        try {
            int oldNextInIndex = z.next_in_index;
            int oldNextOutIndex = z.next_out_index;

            // zlib worst-case bound: input * 1.001 + 12 bytes.
            int maxOutputLength = (int) Math.ceil(z.next_in.length * 1.001) + 12;
            out = alloc.heapBuffer(maxOutputLength);
            z.next_out = out.array();
            z.next_out_index = out.arrayOffset() + out.writerIndex();
            z.avail_out = maxOutputLength;

            int resultCode;
            try {
                resultCode = z.deflate(JZlib.Z_SYNC_FLUSH);
            } finally {
                out.skipBytes(z.next_in_index - oldNextInIndex);
            }
            if (resultCode != JZlib.Z_OK) {
                throw new CompressionException("compression failure: " + resultCode);
            }

            int outputLength = z.next_out_index - oldNextOutIndex;
            if (outputLength > 0) {
                out.writerIndex(out.writerIndex() + outputLength);
            }
            release = false;
            return out;
        } finally {
            // Deference the external references explicitly to tell the VM that
            // the allocated byte arrays are temporary so that the call stack
            // can be utilized.
            // I'm not sure if the modern VMs do this optimization though.
            z.next_in = null;
            z.next_out = null;
            if (release && out != null) {
                out.release();
            }
        }
    }

    @Override
    public ByteBuf encode(ByteBufAllocator alloc, SpdyHeadersFrame frame) throws Exception {
        checkNotNullWithIAE(alloc, "alloc");
        checkNotNullWithIAE(frame, "frame");

        if (finished) {
            return Unpooled.EMPTY_BUFFER;
        }

        // First raw-encode via the superclass, then compress that intermediate buffer.
        ByteBuf decompressed = super.encode(alloc, frame);
        try {
            if (!decompressed.isReadable()) {
                return Unpooled.EMPTY_BUFFER;
            }

            setInput(decompressed);
            return encode(alloc);
        } finally {
            decompressed.release();
        }
    }

    @Override
    public void end() {
        if (finished) {
            return;
        }
        finished = true;
        z.deflateEnd();
        z.next_in = null;
        z.next_out = null;
    }
}
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.internal.ObjectUtil; + +import static io.netty.handler.codec.spdy.SpdyCodecUtil.getSignedInt; + +public class SpdyHeaderBlockRawDecoder extends SpdyHeaderBlockDecoder { + + private static final int LENGTH_FIELD_SIZE = 4; + + private final int maxHeaderSize; + + private State state; + + private ByteBuf cumulation; + + private int headerSize; + private int numHeaders; + private int length; + private String name; + + private enum State { + READ_NUM_HEADERS, + READ_NAME_LENGTH, + READ_NAME, + SKIP_NAME, + READ_VALUE_LENGTH, + READ_VALUE, + SKIP_VALUE, + END_HEADER_BLOCK, + ERROR + } + + public SpdyHeaderBlockRawDecoder(SpdyVersion spdyVersion, int maxHeaderSize) { + ObjectUtil.checkNotNull(spdyVersion, "spdyVersion"); + this.maxHeaderSize = maxHeaderSize; + state = State.READ_NUM_HEADERS; + } + + private static int readLengthField(ByteBuf buffer) { + int length = getSignedInt(buffer, buffer.readerIndex()); + buffer.skipBytes(LENGTH_FIELD_SIZE); + return length; + } + + @Override + void decode(ByteBufAllocator alloc, ByteBuf headerBlock, SpdyHeadersFrame frame) throws Exception { + ObjectUtil.checkNotNull(headerBlock, "headerBlock"); + ObjectUtil.checkNotNull(frame, "frame"); + + if (cumulation == null) { + decodeHeaderBlock(headerBlock, frame); + if (headerBlock.isReadable()) { + cumulation = alloc.buffer(headerBlock.readableBytes()); + cumulation.writeBytes(headerBlock); + } + } else { + cumulation.writeBytes(headerBlock); + decodeHeaderBlock(cumulation, frame); + if (cumulation.isReadable()) { + cumulation.discardReadBytes(); + } else { + releaseBuffer(); + } + } + } + + protected void decodeHeaderBlock(ByteBuf headerBlock, SpdyHeadersFrame frame) throws Exception { + int skipLength; + while (headerBlock.isReadable()) { + switch(state) { + case READ_NUM_HEADERS: + if (headerBlock.readableBytes() < LENGTH_FIELD_SIZE) { + return; 
+ } + + numHeaders = readLengthField(headerBlock); + + if (numHeaders < 0) { + state = State.ERROR; + frame.setInvalid(); + } else if (numHeaders == 0) { + state = State.END_HEADER_BLOCK; + } else { + state = State.READ_NAME_LENGTH; + } + break; + + case READ_NAME_LENGTH: + if (headerBlock.readableBytes() < LENGTH_FIELD_SIZE) { + return; + } + + length = readLengthField(headerBlock); + + // Recipients of a zero-length name must issue a stream error + if (length <= 0) { + state = State.ERROR; + frame.setInvalid(); + } else if (length > maxHeaderSize || headerSize > maxHeaderSize - length) { + headerSize = maxHeaderSize + 1; + state = State.SKIP_NAME; + frame.setTruncated(); + } else { + headerSize += length; + state = State.READ_NAME; + } + break; + + case READ_NAME: + if (headerBlock.readableBytes() < length) { + return; + } + + byte[] nameBytes = new byte[length]; + headerBlock.readBytes(nameBytes); + name = new String(nameBytes, "UTF-8"); + + // Check for identically named headers + if (frame.headers().contains(name)) { + state = State.ERROR; + frame.setInvalid(); + } else { + state = State.READ_VALUE_LENGTH; + } + break; + + case SKIP_NAME: + skipLength = Math.min(headerBlock.readableBytes(), length); + headerBlock.skipBytes(skipLength); + length -= skipLength; + + if (length == 0) { + state = State.READ_VALUE_LENGTH; + } + break; + + case READ_VALUE_LENGTH: + if (headerBlock.readableBytes() < LENGTH_FIELD_SIZE) { + return; + } + + length = readLengthField(headerBlock); + + // Recipients of illegal value fields must issue a stream error + if (length < 0) { + state = State.ERROR; + frame.setInvalid(); + } else if (length == 0) { + if (!frame.isTruncated()) { + // SPDY/3 allows zero-length (empty) header values + frame.headers().add(name, ""); + } + + name = null; + if (--numHeaders == 0) { + state = State.END_HEADER_BLOCK; + } else { + state = State.READ_NAME_LENGTH; + } + + } else if (length > maxHeaderSize || headerSize > maxHeaderSize - length) { + headerSize 
= maxHeaderSize + 1; + name = null; + state = State.SKIP_VALUE; + frame.setTruncated(); + } else { + headerSize += length; + state = State.READ_VALUE; + } + break; + + case READ_VALUE: + if (headerBlock.readableBytes() < length) { + return; + } + + byte[] valueBytes = new byte[length]; + headerBlock.readBytes(valueBytes); + + // Add Name/Value pair to headers + int index = 0; + int offset = 0; + + // Value must not start with a NULL character + if (valueBytes[0] == (byte) 0) { + state = State.ERROR; + frame.setInvalid(); + break; + } + + while (index < length) { + while (index < valueBytes.length && valueBytes[index] != (byte) 0) { + index ++; + } + if (index < valueBytes.length) { + // Received NULL character + if (index + 1 == valueBytes.length || valueBytes[index + 1] == (byte) 0) { + // Value field ended with a NULL character or + // received multiple, in-sequence NULL characters. + // Recipients of illegal value fields must issue a stream error + state = State.ERROR; + frame.setInvalid(); + break; + } + } + String value = new String(valueBytes, offset, index - offset, "UTF-8"); + + try { + frame.headers().add(name, value); + } catch (IllegalArgumentException e) { + // Name contains NULL or non-ascii characters + state = State.ERROR; + frame.setInvalid(); + break; + } + index ++; + offset = index; + } + + name = null; + + // If we broke out of the add header loop, break here + if (state == State.ERROR) { + break; + } + + if (--numHeaders == 0) { + state = State.END_HEADER_BLOCK; + } else { + state = State.READ_NAME_LENGTH; + } + break; + + case SKIP_VALUE: + skipLength = Math.min(headerBlock.readableBytes(), length); + headerBlock.skipBytes(skipLength); + length -= skipLength; + + if (length == 0) { + if (--numHeaders == 0) { + state = State.END_HEADER_BLOCK; + } else { + state = State.READ_NAME_LENGTH; + } + } + break; + + case END_HEADER_BLOCK: + state = State.ERROR; + frame.setInvalid(); + break; + + case ERROR: + 
headerBlock.skipBytes(headerBlock.readableBytes()); + return; + + default: + throw new Error("Shouldn't reach here."); + } + } + } + + @Override + void endHeaderBlock(SpdyHeadersFrame frame) throws Exception { + if (state != State.END_HEADER_BLOCK) { + frame.setInvalid(); + } + + releaseBuffer(); + + // Initialize header block decoding fields + headerSize = 0; + name = null; + state = State.READ_NUM_HEADERS; + } + + @Override + void end() { + releaseBuffer(); + } + + private void releaseBuffer() { + if (cumulation != null) { + cumulation.release(); + cumulation = null; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawEncoder.java new file mode 100644 index 0000000..7d3d91d --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawEncoder.java @@ -0,0 +1,89 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.ObjectUtil; + +import java.util.Set; + +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_MAX_NV_LENGTH; + +public class SpdyHeaderBlockRawEncoder extends SpdyHeaderBlockEncoder { + + private final int version; + + public SpdyHeaderBlockRawEncoder(SpdyVersion version) { + this.version = ObjectUtil.checkNotNull(version, "version").getVersion(); + } + + private static void setLengthField(ByteBuf buffer, int writerIndex, int length) { + buffer.setInt(writerIndex, length); + } + + private static void writeLengthField(ByteBuf buffer, int length) { + buffer.writeInt(length); + } + + @Override + public ByteBuf encode(ByteBufAllocator alloc, SpdyHeadersFrame frame) throws Exception { + Set names = frame.headers().names(); + int numHeaders = names.size(); + if (numHeaders == 0) { + return Unpooled.EMPTY_BUFFER; + } + if (numHeaders > SPDY_MAX_NV_LENGTH) { + throw new IllegalArgumentException( + "header block contains too many headers"); + } + ByteBuf headerBlock = alloc.heapBuffer(); + writeLengthField(headerBlock, numHeaders); + for (CharSequence name: names) { + writeLengthField(headerBlock, name.length()); + ByteBufUtil.writeAscii(headerBlock, name); + int savedIndex = headerBlock.writerIndex(); + int valueLength = 0; + writeLengthField(headerBlock, valueLength); + for (CharSequence value: frame.headers().getAll(name)) { + int length = value.length(); + if (length > 0) { + ByteBufUtil.writeAscii(headerBlock, value); + headerBlock.writeByte(0); + valueLength += length + 1; + } + } + if (valueLength != 0) { + valueLength --; + } + if (valueLength > SPDY_MAX_NV_LENGTH) { + throw new IllegalArgumentException( + "header exceeds allowable length: " + name); + } + if (valueLength > 0) { + setLengthField(headerBlock, savedIndex, valueLength); + 
headerBlock.writerIndex(headerBlock.writerIndex() - 1); + } + } + return headerBlock; + } + + @Override + void end() { + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java new file mode 100644 index 0000000..2db45ed --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoder.java @@ -0,0 +1,125 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;

import java.util.zip.DataFormatException;
import java.util.zip.Inflater;

import static io.netty.handler.codec.spdy.SpdyCodecUtil.*;

/**
 * Header-block decoder for zlib-compressed SPDY header blocks: inflates the
 * input (setting the SPDY dictionary when the inflater requests one) and
 * feeds the plaintext through {@link SpdyHeaderBlockRawDecoder#decodeHeaderBlock}.
 */
final class SpdyHeaderBlockZlibDecoder extends SpdyHeaderBlockRawDecoder {

    private static final int DEFAULT_BUFFER_CAPACITY = 4096;
    // Shared, pre-allocated exception for corrupt input.
    // NOTE(review): as a static singleton its stack trace reflects class
    // initialization, not the failing call site.
    private static final SpdyProtocolException INVALID_HEADER_BLOCK =
            new SpdyProtocolException("Invalid Header Block");

    private final Inflater decompressor = new Inflater();

    // Holds inflated header data between decode() calls; lazily allocated.
    private ByteBuf decompressed;

    SpdyHeaderBlockZlibDecoder(SpdyVersion spdyVersion, int maxHeaderSize) {
        super(spdyVersion, maxHeaderSize);
    }

    @Override
    void decode(ByteBufAllocator alloc, ByteBuf headerBlock, SpdyHeadersFrame frame) throws Exception {
        int len = setInput(headerBlock);

        // Inflate until no more output is produced for this input.
        int numBytes;
        do {
            numBytes = decompress(alloc, frame);
        } while (numBytes > 0);

        // z_stream has an internal 64-bit hold buffer
        // it is always capable of consuming the entire input
        if (decompressor.getRemaining() != 0) {
            // the inflater could not consume all of the input, so the
            // compressed header block is corrupt
            throw INVALID_HEADER_BLOCK;
        }

        // All input was consumed via the inflater; advance the reader index.
        headerBlock.skipBytes(len);
    }

    /**
     * Hands the readable bytes of {@code compressed} to the inflater,
     * copying to a temporary array when the buffer is not array-backed.
     *
     * @return the number of input bytes supplied
     */
    private int setInput(ByteBuf compressed) {
        int len = compressed.readableBytes();

        if (compressed.hasArray()) {
            decompressor.setInput(compressed.array(), compressed.arrayOffset() + compressed.readerIndex(), len);
        } else {
            byte[] in = new byte[len];
            compressed.getBytes(compressed.readerIndex(), in);
            decompressor.setInput(in, 0, in.length);
        }

        return len;
    }

    /**
     * Inflates one chunk into {@link #decompressed} and parses it.
     *
     * @return the number of bytes produced by the inflater (0 when it needs
     *         more input or has finished)
     * @throws SpdyProtocolException if the compressed data is malformed
     */
    private int decompress(ByteBufAllocator alloc, SpdyHeadersFrame frame) throws Exception {
        ensureBuffer(alloc);
        byte[] out = decompressed.array();
        int off = decompressed.arrayOffset() + decompressed.writerIndex();
        try {
            int numBytes = decompressor.inflate(out, off, decompressed.writableBytes());
            if (numBytes == 0 && decompressor.needsDictionary()) {
                // SPDY streams are deflated against a shared dictionary;
                // supply it and retry.
                try {
                    decompressor.setDictionary(SPDY_DICT);
                } catch (IllegalArgumentException ignored) {
                    throw INVALID_HEADER_BLOCK;
                }
                numBytes = decompressor.inflate(out, off, decompressed.writableBytes());
            }
            if (frame != null) {
                decompressed.writerIndex(decompressed.writerIndex() + numBytes);
                decodeHeaderBlock(decompressed, frame);
                decompressed.discardReadBytes();
            }

            return numBytes;
        } catch (DataFormatException e) {
            throw new SpdyProtocolException("Received invalid header block", e);
        }
    }

    // Lazily allocates the output buffer and guarantees at least one
    // writable byte so inflate() can make progress.
    private void ensureBuffer(ByteBufAllocator alloc) {
        if (decompressed == null) {
            decompressed = alloc.heapBuffer(DEFAULT_BUFFER_CAPACITY);
        }
        decompressed.ensureWritable(1);
    }

    @Override
    void endHeaderBlock(SpdyHeadersFrame frame) throws Exception {
        super.endHeaderBlock(frame);
        releaseBuffer();
    }

    @Override
    public void end() {
        super.end();
        releaseBuffer();
        decompressor.end();
    }

    private void releaseBuffer() {
        if (decompressed != null) {
            decompressed.release();
            decompressed = null;
        }
    }
}
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.spdy;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SuppressJava6Requirement;

import java.util.zip.Deflater;

import static io.netty.handler.codec.spdy.SpdyCodecUtil.*;
import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE;

/**
 * Header-block encoder that zlib-compresses the raw header block produced by
 * {@link SpdyHeaderBlockRawEncoder}, using the shared SPDY dictionary.
 */
class SpdyHeaderBlockZlibEncoder extends SpdyHeaderBlockRawEncoder {

    private final Deflater compressor;

    // Set once end() has been called; further encodes return an empty buffer.
    private boolean finished;

    SpdyHeaderBlockZlibEncoder(SpdyVersion spdyVersion, int compressionLevel) {
        super(spdyVersion);
        if (compressionLevel < 0 || compressionLevel > 9) {
            throw new IllegalArgumentException(
                    "compressionLevel: " + compressionLevel + " (expected: 0-9)");
        }
        compressor = new Deflater(compressionLevel);
        compressor.setDictionary(SPDY_DICT);
    }

    /**
     * Hands the readable bytes of {@code decompressed} to the deflater,
     * copying to a temporary array when the buffer is not array-backed.
     *
     * @return the number of input bytes supplied
     */
    private int setInput(ByteBuf decompressed) {
        int len = decompressed.readableBytes();

        if (decompressed.hasArray()) {
            compressor.setInput(decompressed.array(), decompressed.arrayOffset() + decompressed.readerIndex(), len);
        } else {
            byte[] in = new byte[len];
            decompressed.getBytes(decompressed.readerIndex(), in);
            compressor.setInput(in, 0, in.length);
        }

        return len;
    }

    /**
     * Deflates the previously supplied input into a new heap buffer,
     * growing the buffer until the deflater stops filling it completely.
     */
    private ByteBuf encode(ByteBufAllocator alloc, int len) {
        ByteBuf compressed = alloc.heapBuffer(len);
        boolean release = true;
        try {
            while (compressInto(compressed)) {
                // Although unlikely, it's possible that the compressed size is larger than the decompressed size
                compressed.ensureWritable(compressed.capacity() << 1);
            }
            release = false;
            return compressed;
        } finally {
            // Release on any failure path so the buffer does not leak.
            if (release) {
                compressed.release();
            }
        }
    }

    /**
     * Runs one deflate pass into {@code compressed}.
     *
     * @return {@code true} if the writable space was filled completely,
     *         i.e. more output may be pending
     */
    @SuppressJava6Requirement(reason = "Guarded by java version check")
    private boolean compressInto(ByteBuf compressed) {
        byte[] out = compressed.array();
        int off = compressed.arrayOffset() + compressed.writerIndex();
        int toWrite = compressed.writableBytes();
        final int numBytes;
        if (PlatformDependent.javaVersion() >= 7) {
            // SYNC_FLUSH (Java 7+) emits all pending output for this block.
            numBytes = compressor.deflate(out, off, toWrite, Deflater.SYNC_FLUSH);
        } else {
            numBytes = compressor.deflate(out, off, toWrite);
        }
        compressed.writerIndex(compressed.writerIndex() + numBytes);
        return numBytes == toWrite;
    }

    @Override
    public ByteBuf encode(ByteBufAllocator alloc, SpdyHeadersFrame frame) throws Exception {
        checkNotNullWithIAE(alloc, "alloc");
        checkNotNullWithIAE(frame, "frame");

        if (finished) {
            return Unpooled.EMPTY_BUFFER;
        }

        // First build the uncompressed block, then deflate it.
        ByteBuf decompressed = super.encode(alloc, frame);
        try {
            if (!decompressed.isReadable()) {
                return Unpooled.EMPTY_BUFFER;
            }

            int len = setInput(decompressed);
            return encode(alloc, len);
        } finally {
            decompressed.release();
        }
    }

    @Override
    public void end() {
        if (finished) {
            return;
        }
        finished = true;
        compressor.end();
        super.end();
    }
}
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.handler.codec.Headers; +import io.netty.util.AsciiString; + +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; + +/** + * Provides the constants for the standard SPDY HTTP header names and commonly + * used utility methods that access a {@link SpdyHeadersFrame}. + */ +public interface SpdyHeaders extends Headers { + + /** + * SPDY HTTP header names + */ + final class HttpNames { + /** + * {@code ":host"} + */ + public static final AsciiString HOST = AsciiString.cached(":host"); + /** + * {@code ":method"} + */ + public static final AsciiString METHOD = AsciiString.cached(":method"); + /** + * {@code ":path"} + */ + public static final AsciiString PATH = AsciiString.cached(":path"); + /** + * {@code ":scheme"} + */ + public static final AsciiString SCHEME = AsciiString.cached(":scheme"); + /** + * {@code ":status"} + */ + public static final AsciiString STATUS = AsciiString.cached(":status"); + /** + * {@code ":version"} + */ + public static final AsciiString VERSION = AsciiString.cached(":version"); + + private HttpNames() { } + } + + /** + * {@link Headers#get(Object)} and convert the result to a {@link String}. + * @param name the name of the header to retrieve + * @return the first header value if the header is found. {@code null} if there's no such header. + */ + String getAsString(CharSequence name); + + /** + * {@link Headers#getAll(Object)} and convert each element of {@link List} to a {@link String}. + * @param name the name of the header to retrieve + * @return a {@link List} of header values or an empty {@link List} if no values are found. + */ + List getAllAsString(CharSequence name); + + /** + * {@link #iterator()} that converts each {@link Entry}'s key and value to a {@link String}. 
+ */ + Iterator> iteratorAsString(); + + /** + * Returns {@code true} if a header with the {@code name} and {@code value} exists, {@code false} otherwise. + *

+ * If {@code ignoreCase} is {@code true} then a case insensitive compare is done on the value. + * @param name the name of the header to find + * @param value the value of the header to find + * @param ignoreCase {@code true} then a case insensitive compare is run to compare values. + * otherwise a case sensitive compare is run to compare values. + */ + boolean contains(CharSequence name, CharSequence value, boolean ignoreCase); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeadersFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeadersFrame.java new file mode 100644 index 0000000..f2524bb --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHeadersFrame.java @@ -0,0 +1,55 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +/** + * A SPDY Protocol HEADERS Frame + */ +public interface SpdyHeadersFrame extends SpdyStreamFrame { + + /** + * Returns {@code true} if this header block is invalid. + * A RST_STREAM frame with code PROTOCOL_ERROR should be sent. + */ + boolean isInvalid(); + + /** + * Marks this header block as invalid. + */ + SpdyHeadersFrame setInvalid(); + + /** + * Returns {@code true} if this header block has been truncated due to + * length restrictions. 
+ */ + boolean isTruncated(); + + /** + * Mark this header block as truncated. + */ + SpdyHeadersFrame setTruncated(); + + /** + * Returns the {@link SpdyHeaders}. + */ + SpdyHeaders headers(); + + @Override + SpdyHeadersFrame setStreamId(int streamID); + + @Override + SpdyHeadersFrame setLast(boolean last); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpCodec.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpCodec.java new file mode 100644 index 0000000..50ac28c --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpCodec.java @@ -0,0 +1,51 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.channel.CombinedChannelDuplexHandler; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.HttpHeadersFactory; + +import java.util.HashMap; + +/** + * A combination of {@link SpdyHttpDecoder} and {@link SpdyHttpEncoder} + */ +public final class SpdyHttpCodec extends CombinedChannelDuplexHandler { + /** + * Creates a new instance with the specified decoder options. 
+ */ + public SpdyHttpCodec(SpdyVersion version, int maxContentLength) { + super(new SpdyHttpDecoder(version, maxContentLength), new SpdyHttpEncoder(version)); + } + + /** + * Creates a new instance with the specified decoder options. + */ + @Deprecated + public SpdyHttpCodec(SpdyVersion version, int maxContentLength, boolean validateHttpHeaders) { + super(new SpdyHttpDecoder(version, maxContentLength, validateHttpHeaders), new SpdyHttpEncoder(version)); + } + + /** + * Creates a new instance with the specified decoder options. + */ + public SpdyHttpCodec(SpdyVersion version, int maxContentLength, + HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) { + super(new SpdyHttpDecoder(version, maxContentLength, new HashMap(), + headersFactory, trailersFactory), new SpdyHttpEncoder(version)); + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java new file mode 100644 index 0000000..e07ff7a --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpDecoder.java @@ -0,0 +1,463 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.DefaultHttpHeadersFactory; +import io.netty.handler.codec.http.HttpHeadersFactory; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.spdy.SpdyHttpHeaders.Names; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static io.netty.handler.codec.spdy.SpdyHeaders.HttpNames.*; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * Decodes {@link SpdySynStreamFrame}s, {@link SpdySynReplyFrame}s, + * and {@link SpdyDataFrame}s into {@link FullHttpRequest}s and {@link FullHttpResponse}s. + */ +public class SpdyHttpDecoder extends MessageToMessageDecoder { + + private final int spdyVersion; + private final int maxContentLength; + private final Map messageMap; + private final HttpHeadersFactory headersFactory; + private final HttpHeadersFactory trailersFactory; + + /** + * Creates a new instance. + * + * @param version the protocol version + * @param maxContentLength the maximum length of the message content. + * If the length of the message content exceeds this value, + * a {@link TooLongFrameException} will be raised. 
+ */ + public SpdyHttpDecoder(SpdyVersion version, int maxContentLength) { + this(version, maxContentLength, new HashMap(), + DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory()); + } + + /** + * Creates a new instance. + * + * @param version the protocol version + * @param maxContentLength the maximum length of the message content. + * If the length of the message content exceeds this value, + * a {@link TooLongFrameException} will be raised. + * @param validateHeaders {@code true} if http headers should be validated + * @deprecated Use the {@link #SpdyHttpDecoder(SpdyVersion, int, Map, HttpHeadersFactory, HttpHeadersFactory)} + * constructor instead. + */ + @Deprecated + public SpdyHttpDecoder(SpdyVersion version, int maxContentLength, boolean validateHeaders) { + this(version, maxContentLength, new HashMap(), validateHeaders); + } + + /** + * Creates a new instance with the specified parameters. + * + * @param version the protocol version + * @param maxContentLength the maximum length of the message content. + * If the length of the message content exceeds this value, + * a {@link TooLongFrameException} will be raised. + * @param messageMap the {@link Map} used to hold partially received messages. + */ + protected SpdyHttpDecoder(SpdyVersion version, int maxContentLength, Map messageMap) { + this(version, maxContentLength, messageMap, + DefaultHttpHeadersFactory.headersFactory(), DefaultHttpHeadersFactory.trailersFactory()); + } + + /** + * Creates a new instance with the specified parameters. + * + * @param version the protocol version + * @param maxContentLength the maximum length of the message content. + * If the length of the message content exceeds this value, + * a {@link TooLongFrameException} will be raised. + * @param messageMap the {@link Map} used to hold partially received messages. 
+ * @param validateHeaders {@code true} if http headers should be validated + * @deprecated Use the {@link #SpdyHttpDecoder(SpdyVersion, int, Map, HttpHeadersFactory, HttpHeadersFactory)} + * constructor instead. + */ + @Deprecated + protected SpdyHttpDecoder(SpdyVersion version, int maxContentLength, Map messageMap, boolean validateHeaders) { + this(version, maxContentLength, messageMap, + DefaultHttpHeadersFactory.headersFactory().withValidation(validateHeaders), + DefaultHttpHeadersFactory.trailersFactory().withValidation(validateHeaders)); + } + + /** + * Creates a new instance with the specified parameters. + * + * @param version the protocol version + * @param maxContentLength the maximum length of the message content. + * If the length of the message content exceeds this value, + * a {@link TooLongFrameException} will be raised. + * @param messageMap the {@link Map} used to hold partially received messages. + * @param headersFactory The factory used for creating HTTP headers + * @param trailersFactory The factory used for creating HTTP trailers. 
+ */ + protected SpdyHttpDecoder(SpdyVersion version, int maxContentLength, Map messageMap, HttpHeadersFactory headersFactory, HttpHeadersFactory trailersFactory) { + spdyVersion = ObjectUtil.checkNotNull(version, "version").getVersion(); + this.maxContentLength = checkPositive(maxContentLength, "maxContentLength"); + this.messageMap = messageMap; + this.headersFactory = headersFactory; + this.trailersFactory = trailersFactory; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + // Release any outstanding messages from the map + for (Map.Entry entry : messageMap.entrySet()) { + ReferenceCountUtil.safeRelease(entry.getValue()); + } + messageMap.clear(); + super.channelInactive(ctx); + } + + protected FullHttpMessage putMessage(int streamId, FullHttpMessage message) { + return messageMap.put(streamId, message); + } + + protected FullHttpMessage getMessage(int streamId) { + return messageMap.get(streamId); + } + + protected FullHttpMessage removeMessage(int streamId) { + return messageMap.remove(streamId); + } + + @Override + protected void decode(ChannelHandlerContext ctx, SpdyFrame msg, List out) + throws Exception { + if (msg instanceof SpdySynStreamFrame) { + + // HTTP requests/responses are mapped one-to-one to SPDY streams. + SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg; + int streamId = spdySynStreamFrame.streamId(); + + if (SpdyCodecUtil.isServerId(streamId)) { + // SYN_STREAM frames initiated by the server are pushed resources + int associatedToStreamId = spdySynStreamFrame.associatedStreamId(); + + // If a client receives a SYN_STREAM with an Associated-To-Stream-ID of 0 + // it must reply with a RST_STREAM with error code INVALID_STREAM. 
+ if (associatedToStreamId == 0) { + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.INVALID_STREAM); + ctx.writeAndFlush(spdyRstStreamFrame); + return; + } + + // If a client receives a SYN_STREAM with isLast set, + // reply with a RST_STREAM with error code PROTOCOL_ERROR + // (we only support pushed resources divided into two header blocks). + if (spdySynStreamFrame.isLast()) { + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.PROTOCOL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + return; + } + + // If a client receives a response with a truncated header block, + // reply with a RST_STREAM with error code INTERNAL_ERROR. + if (spdySynStreamFrame.isTruncated()) { + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.INTERNAL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + return; + } + + try { + FullHttpRequest httpRequestWithEntity = createHttpRequest(spdySynStreamFrame, ctx.alloc()); + + // Set the Stream-ID, Associated-To-Stream-ID, and Priority as headers + httpRequestWithEntity.headers().setInt(Names.STREAM_ID, streamId); + httpRequestWithEntity.headers().setInt(Names.ASSOCIATED_TO_STREAM_ID, associatedToStreamId); + httpRequestWithEntity.headers().setInt(Names.PRIORITY, spdySynStreamFrame.priority()); + + out.add(httpRequestWithEntity); + + } catch (Throwable ignored) { + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.PROTOCOL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + } + } else { + // SYN_STREAM frames initiated by the client are HTTP requests + + // If a client sends a request with a truncated header block, the server must + // reply with an HTTP 431 REQUEST HEADER FIELDS TOO LARGE reply. 
+ if (spdySynStreamFrame.isTruncated()) { + SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId); + spdySynReplyFrame.setLast(true); + SpdyHeaders frameHeaders = spdySynReplyFrame.headers(); + frameHeaders.setInt(STATUS, HttpResponseStatus.REQUEST_HEADER_FIELDS_TOO_LARGE.code()); + frameHeaders.setObject(VERSION, HttpVersion.HTTP_1_0); + ctx.writeAndFlush(spdySynReplyFrame); + return; + } + + try { + FullHttpRequest httpRequestWithEntity = createHttpRequest(spdySynStreamFrame, ctx.alloc()); + + // Set the Stream-ID as a header + httpRequestWithEntity.headers().setInt(Names.STREAM_ID, streamId); + + if (spdySynStreamFrame.isLast()) { + out.add(httpRequestWithEntity); + } else { + // Request body will follow in a series of Data Frames + putMessage(streamId, httpRequestWithEntity); + } + } catch (Throwable t) { + // If a client sends a SYN_STREAM without all of the getMethod, url (host and path), + // scheme, and version headers the server must reply with an HTTP 400 BAD REQUEST reply. + // Also sends HTTP 400 BAD REQUEST reply if header name/value pairs are invalid + SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId); + spdySynReplyFrame.setLast(true); + SpdyHeaders frameHeaders = spdySynReplyFrame.headers(); + frameHeaders.setInt(STATUS, HttpResponseStatus.BAD_REQUEST.code()); + frameHeaders.setObject(VERSION, HttpVersion.HTTP_1_0); + ctx.writeAndFlush(spdySynReplyFrame); + } + } + + } else if (msg instanceof SpdySynReplyFrame) { + + SpdySynReplyFrame spdySynReplyFrame = (SpdySynReplyFrame) msg; + int streamId = spdySynReplyFrame.streamId(); + + // If a client receives a SYN_REPLY with a truncated header block, + // reply with a RST_STREAM frame with error code INTERNAL_ERROR. 
+ if (spdySynReplyFrame.isTruncated()) { + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.INTERNAL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + return; + } + + try { + FullHttpResponse httpResponseWithEntity = + createHttpResponse(spdySynReplyFrame, ctx.alloc()); + + // Set the Stream-ID as a header + httpResponseWithEntity.headers().setInt(Names.STREAM_ID, streamId); + + if (spdySynReplyFrame.isLast()) { + HttpUtil.setContentLength(httpResponseWithEntity, 0); + out.add(httpResponseWithEntity); + } else { + // Response body will follow in a series of Data Frames + putMessage(streamId, httpResponseWithEntity); + } + } catch (Throwable t) { + // If a client receives a SYN_REPLY without valid getStatus and version headers + // the client must reply with a RST_STREAM frame indicating a PROTOCOL_ERROR + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.PROTOCOL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + } + + } else if (msg instanceof SpdyHeadersFrame) { + + SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg; + int streamId = spdyHeadersFrame.streamId(); + FullHttpMessage fullHttpMessage = getMessage(streamId); + + if (fullHttpMessage == null) { + // HEADERS frames may initiate a pushed response + if (SpdyCodecUtil.isServerId(streamId)) { + + // If a client receives a HEADERS with a truncated header block, + // reply with a RST_STREAM frame with error code INTERNAL_ERROR. 
+ if (spdyHeadersFrame.isTruncated()) { + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.INTERNAL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + return; + } + + try { + fullHttpMessage = createHttpResponse(spdyHeadersFrame, ctx.alloc()); + + // Set the Stream-ID as a header + fullHttpMessage.headers().setInt(Names.STREAM_ID, streamId); + + if (spdyHeadersFrame.isLast()) { + HttpUtil.setContentLength(fullHttpMessage, 0); + out.add(fullHttpMessage); + } else { + // Response body will follow in a series of Data Frames + putMessage(streamId, fullHttpMessage); + } + } catch (Throwable t) { + // If a client receives a SYN_REPLY without valid getStatus and version headers + // the client must reply with a RST_STREAM frame indicating a PROTOCOL_ERROR + SpdyRstStreamFrame spdyRstStreamFrame = + new DefaultSpdyRstStreamFrame(streamId, SpdyStreamStatus.PROTOCOL_ERROR); + ctx.writeAndFlush(spdyRstStreamFrame); + } + } + return; + } + + // Ignore trailers in a truncated HEADERS frame. + if (!spdyHeadersFrame.isTruncated()) { + for (Map.Entry e: spdyHeadersFrame.headers()) { + fullHttpMessage.headers().add(e.getKey(), e.getValue()); + } + } + + if (spdyHeadersFrame.isLast()) { + HttpUtil.setContentLength(fullHttpMessage, fullHttpMessage.content().readableBytes()); + removeMessage(streamId); + out.add(fullHttpMessage); + } + + } else if (msg instanceof SpdyDataFrame) { + + SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg; + int streamId = spdyDataFrame.streamId(); + FullHttpMessage fullHttpMessage = getMessage(streamId); + + // If message is not in map discard Data Frame. 
+ if (fullHttpMessage == null) { + return; + } + + ByteBuf content = fullHttpMessage.content(); + if (content.readableBytes() > maxContentLength - spdyDataFrame.content().readableBytes()) { + removeMessage(streamId); + throw new TooLongFrameException( + "HTTP content length exceeded " + maxContentLength + " bytes."); + } + + ByteBuf spdyDataFrameData = spdyDataFrame.content(); + int spdyDataFrameDataLen = spdyDataFrameData.readableBytes(); + content.writeBytes(spdyDataFrameData, spdyDataFrameData.readerIndex(), spdyDataFrameDataLen); + + if (spdyDataFrame.isLast()) { + HttpUtil.setContentLength(fullHttpMessage, content.readableBytes()); + removeMessage(streamId); + out.add(fullHttpMessage); + } + + } else if (msg instanceof SpdyRstStreamFrame) { + + SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg; + int streamId = spdyRstStreamFrame.streamId(); + removeMessage(streamId); + } + } + + private static FullHttpRequest createHttpRequest(SpdyHeadersFrame requestFrame, ByteBufAllocator alloc) + throws Exception { + // Create the first line of the request from the name/value pairs + SpdyHeaders headers = requestFrame.headers(); + HttpMethod method = HttpMethod.valueOf(headers.getAsString(METHOD)); + String url = headers.getAsString(PATH); + HttpVersion httpVersion = HttpVersion.valueOf(headers.getAsString(VERSION)); + headers.remove(METHOD); + headers.remove(PATH); + headers.remove(VERSION); + + boolean release = true; + ByteBuf buffer = alloc.buffer(); + try { + FullHttpRequest req = new DefaultFullHttpRequest(httpVersion, method, url, buffer); + + // Remove the scheme header + headers.remove(SCHEME); + + // Replace the SPDY host header with the HTTP host header + CharSequence host = headers.get(HOST); + headers.remove(HOST); + req.headers().set(HttpHeaderNames.HOST, host); + + for (Map.Entry e : requestFrame.headers()) { + req.headers().add(e.getKey(), e.getValue()); + } + + // The Connection and Keep-Alive headers are no longer valid + 
HttpUtil.setKeepAlive(req, true); + + // Transfer-Encoding header is not valid + req.headers().remove(HttpHeaderNames.TRANSFER_ENCODING); + release = false; + return req; + } finally { + if (release) { + buffer.release(); + } + } + } + + private FullHttpResponse createHttpResponse(SpdyHeadersFrame responseFrame, ByteBufAllocator alloc) + throws Exception { + + // Create the first line of the response from the name/value pairs + SpdyHeaders headers = responseFrame.headers(); + HttpResponseStatus status = HttpResponseStatus.parseLine(headers.get(STATUS)); + HttpVersion version = HttpVersion.valueOf(headers.getAsString(VERSION)); + headers.remove(STATUS); + headers.remove(VERSION); + + boolean release = true; + ByteBuf buffer = alloc.buffer(); + try { + FullHttpResponse res = new DefaultFullHttpResponse( + version, status, buffer, headersFactory, trailersFactory); + for (Map.Entry e: responseFrame.headers()) { + res.headers().add(e.getKey(), e.getValue()); + } + + // The Connection and Keep-Alive headers are no longer valid + HttpUtil.setKeepAlive(res, true); + + // Transfer-Encoding header is not valid + res.headers().remove(HttpHeaderNames.TRANSFER_ENCODING); + res.headers().remove(HttpHeaderNames.TRAILER); + + release = false; + return res; + } finally { + if (release) { + buffer.release(); + } + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpEncoder.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpEncoder.java new file mode 100644 index 0000000..27017b9 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyHttpEncoder.java @@ -0,0 +1,331 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.codec.spdy;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageEncoder;
import io.netty.handler.codec.UnsupportedMessageTypeException;
import io.netty.handler.codec.http.FullHttpMessage;
import io.netty.handler.codec.http.HttpContent;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpObject;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.AsciiString;
import io.netty.util.internal.ObjectUtil;

import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

/**
 * Encodes {@link HttpRequest}s, {@link HttpResponse}s, and {@link HttpContent}s
 * into {@link SpdySynStreamFrame}s and {@link SpdySynReplyFrame}s.
 *
 * <h3>Request Annotations</h3>
 *
 * SPDY specific headers must be added to {@link HttpRequest}s:
 * <ul>
 * <li>{@code "X-SPDY-Stream-ID"} - the Stream-ID for this request.
 *     Stream-IDs must be odd, positive integers, and must increase monotonically.</li>
 * <li>{@code "X-SPDY-Priority"} - the priority value for this request,
 *     between 0 (highest) and 7 (lowest) inclusive. Optional; defaults to 0.</li>
 * </ul>
 *
 * <h3>Response Annotations</h3>
 *
 * SPDY specific headers must be added to {@link HttpResponse}s:
 * <ul>
 * <li>{@code "X-SPDY-Stream-ID"} - the Stream-ID of the request corresponding
 *     to this response.</li>
 * </ul>
 *
 * <h3>Pushed Resource Annotations</h3>
 *
 * SPDY specific headers must be added to pushed {@link HttpRequest}s:
 * <ul>
 * <li>{@code "X-SPDY-Stream-ID"} - the Stream-ID for this resource.
 *     Stream-IDs must be even, positive integers, and must increase monotonically.</li>
 * <li>{@code "X-SPDY-Associated-To-Stream-ID"} - the Stream-ID of the request
 *     that initiated this pushed resource.</li>
 * <li>{@code "X-SPDY-Priority"} - the priority value for this resource,
 *     between 0 (highest) and 7 (lowest) inclusive. Optional; defaults to 0.</li>
 * </ul>
 *
 * <h3>Required Annotations</h3>
 *
 * SPDY requires that all Requests and Pushed Resources contain
 * an HTTP "Host" header.
 *
 * <h3>Optional Annotations</h3>
 *
 * Requests and Pushed Resources must contain a SPDY scheme header.
 * This can be set via the {@code "X-SPDY-Scheme"} header but otherwise
 * defaults to "https" as that is the most common SPDY deployment.
 *
 * <h3>Chunked Content</h3>
 *
 * This encoder associates all {@link HttpContent}s that it receives
 * with the most recently received 'chunked' {@link HttpRequest}
 * or {@link HttpResponse}.
 *
 * <h3>Pushed Resources</h3>
 *
 * All pushed resources should be sent before sending the response
 * that corresponds to the initial request.
 */
public class SpdyHttpEncoder extends MessageToMessageEncoder<HttpObject> {

    // Stream id of the most recently encoded request/response; subsequent
    // HttpContent chunks are emitted as DATA frames on this stream.
    private int currentStreamId;

    private final boolean validateHeaders;
    private final boolean headersToLowerCase;

    /**
     * Creates a new instance.
     *
     * @param version the protocol version
     */
    public SpdyHttpEncoder(SpdyVersion version) {
        this(version, true, true);
    }

    /**
     * Creates a new instance.
     *
     * @param version the protocol version
     * @param headersToLowerCase convert header names to lowercase. In a controlled environment,
     *        one can disable the conversion.
     * @param validateHeaders validate the header names and values when adding them to the
     *        {@link SpdyHeaders}
     */
    public SpdyHttpEncoder(SpdyVersion version, boolean headersToLowerCase, boolean validateHeaders) {
        // The version is validated but not stored; encoding is version-independent here.
        ObjectUtil.checkNotNull(version, "version");
        this.headersToLowerCase = headersToLowerCase;
        this.validateHeaders = validateHeaders;
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, HttpObject msg, List<Object> out) throws Exception {

        boolean valid = false;
        boolean last = false;

        if (msg instanceof HttpRequest) {

            HttpRequest httpRequest = (HttpRequest) msg;
            SpdySynStreamFrame spdySynStreamFrame = createSynStreamFrame(httpRequest);
            out.add(spdySynStreamFrame);

            last = spdySynStreamFrame.isLast() || spdySynStreamFrame.isUnidirectional();
            valid = true;
        }
        if (msg instanceof HttpResponse) {

            HttpResponse httpResponse = (HttpResponse) msg;
            SpdyHeadersFrame spdyHeadersFrame = createHeadersFrame(httpResponse);
            out.add(spdyHeadersFrame);

            last = spdyHeadersFrame.isLast();
            valid = true;
        }
        // A FullHttpMessage is both a message and content; only emit a DATA
        // frame when the headers frame above did not already close the stream.
        if (msg instanceof HttpContent && !last) {

            HttpContent chunk = (HttpContent) msg;

            // Retain: the DATA frame takes its own reference to the content.
            chunk.content().retain();
            SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(currentStreamId, chunk.content());
            if (chunk instanceof LastHttpContent) {
                LastHttpContent trailer = (LastHttpContent) chunk;
                HttpHeaders trailers = trailer.trailingHeaders();
                if (trailers.isEmpty()) {
                    spdyDataFrame.setLast(true);
                    out.add(spdyDataFrame);
                } else {
                    // Create SPDY HEADERS frame out of trailers
                    SpdyHeadersFrame spdyHeadersFrame =
                            new DefaultSpdyHeadersFrame(currentStreamId, validateHeaders);
                    spdyHeadersFrame.setLast(true);
                    Iterator<Entry<CharSequence, CharSequence>> itr = trailers.iteratorCharSequence();
                    while (itr.hasNext()) {
                        Map.Entry<CharSequence, CharSequence> entry = itr.next();
                        final CharSequence headerName =
                                headersToLowerCase ? AsciiString.of(entry.getKey()).toLowerCase() : entry.getKey();
                        spdyHeadersFrame.headers().add(headerName, entry.getValue());
                    }

                    // Write DATA frame and append HEADERS frame
                    out.add(spdyDataFrame);
                    out.add(spdyHeadersFrame);
                }
            } else {
                out.add(spdyDataFrame);
            }

            valid = true;
        }

        if (!valid) {
            throw new UnsupportedMessageTypeException(msg);
        }
    }

    @SuppressWarnings("deprecation")
    private SpdySynStreamFrame createSynStreamFrame(HttpRequest httpRequest) throws Exception {
        // Get the Stream-ID, Associated-To-Stream-ID, Priority, and scheme from the headers.
        // Note: the X-SPDY-Stream-ID header is mandatory; a missing header unboxes to an NPE.
        final HttpHeaders httpHeaders = httpRequest.headers();
        int streamId = httpHeaders.getInt(SpdyHttpHeaders.Names.STREAM_ID);
        int associatedToStreamId = httpHeaders.getInt(SpdyHttpHeaders.Names.ASSOCIATED_TO_STREAM_ID, 0);
        byte priority = (byte) httpHeaders.getInt(SpdyHttpHeaders.Names.PRIORITY, 0);
        CharSequence scheme = httpHeaders.get(SpdyHttpHeaders.Names.SCHEME);
        httpHeaders.remove(SpdyHttpHeaders.Names.STREAM_ID);
        httpHeaders.remove(SpdyHttpHeaders.Names.ASSOCIATED_TO_STREAM_ID);
        httpHeaders.remove(SpdyHttpHeaders.Names.PRIORITY);
        httpHeaders.remove(SpdyHttpHeaders.Names.SCHEME);

        // The Connection, Keep-Alive, Proxy-Connection, and Transfer-Encoding
        // headers are not valid and MUST not be sent.
        httpHeaders.remove(HttpHeaderNames.CONNECTION);
        httpHeaders.remove("Keep-Alive");
        httpHeaders.remove("Proxy-Connection");
        httpHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING);

        SpdySynStreamFrame spdySynStreamFrame =
                new DefaultSpdySynStreamFrame(streamId, associatedToStreamId, priority, validateHeaders);

        // Unfold the first line of the message into name/value pairs
        SpdyHeaders frameHeaders = spdySynStreamFrame.headers();
        frameHeaders.set(SpdyHeaders.HttpNames.METHOD, httpRequest.method().name());
        frameHeaders.set(SpdyHeaders.HttpNames.PATH, httpRequest.uri());
        frameHeaders.set(SpdyHeaders.HttpNames.VERSION, httpRequest.protocolVersion().text());

        // Replace the HTTP host header with the SPDY host header
        CharSequence host = httpHeaders.get(HttpHeaderNames.HOST);
        httpHeaders.remove(HttpHeaderNames.HOST);
        frameHeaders.set(SpdyHeaders.HttpNames.HOST, host);

        // Set the SPDY scheme header
        if (scheme == null) {
            scheme = "https";
        }
        frameHeaders.set(SpdyHeaders.HttpNames.SCHEME, scheme);

        // Transfer the remaining HTTP headers
        Iterator<Entry<CharSequence, CharSequence>> itr = httpHeaders.iteratorCharSequence();
        while (itr.hasNext()) {
            Map.Entry<CharSequence, CharSequence> entry = itr.next();
            final CharSequence headerName =
                    headersToLowerCase ? AsciiString.of(entry.getKey()).toLowerCase() : entry.getKey();
            frameHeaders.add(headerName, entry.getValue());
        }
        currentStreamId = spdySynStreamFrame.streamId();
        if (associatedToStreamId == 0) {
            spdySynStreamFrame.setLast(isLast(httpRequest));
        } else {
            // Pushed resources are unidirectional server-to-client streams.
            spdySynStreamFrame.setUnidirectional(true);
        }

        return spdySynStreamFrame;
    }

    @SuppressWarnings("deprecation")
    private SpdyHeadersFrame createHeadersFrame(HttpResponse httpResponse) throws Exception {
        // Get the Stream-ID from the headers.
        // Note: the X-SPDY-Stream-ID header is mandatory; a missing header unboxes to an NPE.
        final HttpHeaders httpHeaders = httpResponse.headers();
        int streamId = httpHeaders.getInt(SpdyHttpHeaders.Names.STREAM_ID);
        httpHeaders.remove(SpdyHttpHeaders.Names.STREAM_ID);

        // The Connection, Keep-Alive, Proxy-Connection, and Transfer-Encoding
        // headers are not valid and MUST not be sent.
        httpHeaders.remove(HttpHeaderNames.CONNECTION);
        httpHeaders.remove("Keep-Alive");
        httpHeaders.remove("Proxy-Connection");
        httpHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING);

        SpdyHeadersFrame spdyHeadersFrame;
        if (SpdyCodecUtil.isServerId(streamId)) {
            // Server-initiated (pushed) streams get a HEADERS frame ...
            spdyHeadersFrame = new DefaultSpdyHeadersFrame(streamId, validateHeaders);
        } else {
            // ... while replies to client-initiated streams use SYN_REPLY.
            spdyHeadersFrame = new DefaultSpdySynReplyFrame(streamId, validateHeaders);
        }
        SpdyHeaders frameHeaders = spdyHeadersFrame.headers();
        // Unfold the first line of the response into name/value pairs
        frameHeaders.set(SpdyHeaders.HttpNames.STATUS, httpResponse.status().codeAsText());
        frameHeaders.set(SpdyHeaders.HttpNames.VERSION, httpResponse.protocolVersion().text());

        // Transfer the remaining HTTP headers
        Iterator<Entry<CharSequence, CharSequence>> itr = httpHeaders.iteratorCharSequence();
        while (itr.hasNext()) {
            Map.Entry<CharSequence, CharSequence> entry = itr.next();
            final CharSequence headerName =
                    headersToLowerCase ? AsciiString.of(entry.getKey()).toLowerCase() : entry.getKey();
            spdyHeadersFrame.headers().add(headerName, entry.getValue());
        }

        currentStreamId = streamId;
        spdyHeadersFrame.setLast(isLast(httpResponse));

        return spdyHeadersFrame;
    }

    /**
     * Checks if the given HTTP message should be considered as a last SPDY frame.
     *
     * @param httpMessage check this HTTP message
     * @return whether the given HTTP message should generate a <em>last</em> SPDY frame
     */
    private static boolean isLast(HttpMessage httpMessage) {
        if (httpMessage instanceof FullHttpMessage) {
            FullHttpMessage fullMessage = (FullHttpMessage) httpMessage;
            if (fullMessage.trailingHeaders().isEmpty() && !fullMessage.content().isReadable()) {
                return true;
            }
        }

        return false;
    }
}
package io.netty.handler.codec.spdy;

import io.netty.util.AsciiString;

/**
 * Provides the constants for the header names and the utility methods
 * used by the {@link SpdyHttpDecoder} and {@link SpdyHttpEncoder}.
 */
public final class SpdyHttpHeaders {

    /**
     * SPDY HTTP header names
     */
    public static final class Names {
        /**
         * {@code "x-spdy-stream-id"} - the SPDY stream id a message belongs to.
         */
        public static final AsciiString STREAM_ID = AsciiString.cached("x-spdy-stream-id");
        /**
         * {@code "x-spdy-associated-to-stream-id"} - the stream id of the request
         * that initiated a pushed resource.
         */
        public static final AsciiString ASSOCIATED_TO_STREAM_ID = AsciiString.cached("x-spdy-associated-to-stream-id");
        /**
         * {@code "x-spdy-priority"} - the SPDY priority (0 highest .. 7 lowest).
         */
        public static final AsciiString PRIORITY = AsciiString.cached("x-spdy-priority");
        /**
         * {@code "x-spdy-scheme"} - the URI scheme to use on the SPDY stream.
         */
        public static final AsciiString SCHEME = AsciiString.cached("x-spdy-scheme");

        // Constants holder - not instantiable.
        private Names() { }
    }

    // Constants holder - not instantiable.
    private SpdyHttpHeaders() { }
}
package io.netty.handler.codec.spdy;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToMessageCodec;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.spdy.SpdyHttpHeaders.Names;
import io.netty.util.ReferenceCountUtil;

import java.util.ArrayDeque;
import java.util.List;
import java.util.Queue;

/**
 * {@link MessageToMessageCodec} that takes care of adding the right {@link SpdyHttpHeaders.Names#STREAM_ID} to the
 * {@link HttpMessage} if one is not present. This makes it possible to just re-use plain handlers currently used
 * for HTTP.
 */
public class SpdyHttpResponseStreamIdHandler extends
        MessageToMessageCodec<Object, HttpMessage> {
    // Sentinel queued for inbound messages that carried no stream id, so the
    // outbound side knows not to set one on the corresponding response.
    private static final Integer NO_ID = -1;
    // Pending stream ids, in the order the corresponding requests were received.
    private final Queue<Integer> ids = new ArrayDeque<Integer>();

    @Override
    public boolean acceptInboundMessage(Object msg) throws Exception {
        return msg instanceof HttpMessage || msg instanceof SpdyRstStreamFrame;
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, HttpMessage msg, List<Object> out) throws Exception {
        // Pair this outbound response with the oldest pending request's stream id,
        // unless the message already carries one.
        Integer id = ids.poll();
        if (id != null && id.intValue() != NO_ID && !msg.headers().contains(SpdyHttpHeaders.Names.STREAM_ID)) {
            msg.headers().setInt(Names.STREAM_ID, id);
        }

        out.add(ReferenceCountUtil.retain(msg));
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, Object msg, List<Object> out) throws Exception {
        if (msg instanceof HttpMessage) {
            // Remember the request's stream id (or NO_ID) for the matching response.
            boolean contains = ((HttpMessage) msg).headers().contains(SpdyHttpHeaders.Names.STREAM_ID);
            if (!contains) {
                ids.add(NO_ID);
            } else {
                ids.add(((HttpMessage) msg).headers().getInt(Names.STREAM_ID));
            }
        } else if (msg instanceof SpdyRstStreamFrame) {
            // The stream was reset: no response will be written for it.
            ids.remove(((SpdyRstStreamFrame) msg).streamId());
        }

        out.add(ReferenceCountUtil.retain(msg));
    }
}
package io.netty.handler.codec.spdy;

/**
 * A SPDY Protocol PING Frame
 */
public interface SpdyPingFrame extends SpdyFrame {

    /**
     * Returns the ID of this frame.
     */
    int id();

    /**
     * Sets the ID of this frame.
     *
     * @return this frame, for method chaining
     */
    SpdyPingFrame setId(int id);
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.ThrowableUtil; + +public class SpdyProtocolException extends Exception { + + private static final long serialVersionUID = 7870000537743847264L; + + /** + * Creates a new instance. + */ + public SpdyProtocolException() { } + + /** + * Creates a new instance. + */ + public SpdyProtocolException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public SpdyProtocolException(String message) { + super(message); + } + + /** + * Creates a new instance. 
+ */ + public SpdyProtocolException(Throwable cause) { + super(cause); + } + + static SpdyProtocolException newStatic(String message, Class clazz, String method) { + final SpdyProtocolException exception; + if (PlatformDependent.javaVersion() >= 7) { + exception = new StacklessSpdyProtocolException(message, true); + } else { + exception = new StacklessSpdyProtocolException(message); + } + return ThrowableUtil.unknownStackTrace(exception, clazz, method); + } + + @SuppressJava6Requirement(reason = "uses Java 7+ Exception.(String, Throwable, boolean, boolean)" + + " but is guarded by version checks") + private SpdyProtocolException(String message, boolean shared) { + super(message, null, false, true); + assert shared; + } + + private static final class StacklessSpdyProtocolException extends SpdyProtocolException { + private static final long serialVersionUID = -6302754207557485099L; + + StacklessSpdyProtocolException(String message) { + super(message); + } + + StacklessSpdyProtocolException(String message, boolean shared) { + super(message, shared); + } + + // Override fillInStackTrace() so we not populate the backtrace via a native call and so leak the + // Classloader. + @Override + public Throwable fillInStackTrace() { + return this; + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyRstStreamFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyRstStreamFrame.java new file mode 100644 index 0000000..d86ec60 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyRstStreamFrame.java @@ -0,0 +1,38 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.codec.spdy;

/**
 * A SPDY Protocol RST_STREAM Frame
 */
public interface SpdyRstStreamFrame extends SpdyStreamFrame {

    /**
     * Returns the status of this frame.
     */
    SpdyStreamStatus status();

    /**
     * Sets the status of this frame.
     *
     * @return this frame, for method chaining
     */
    SpdyRstStreamFrame setStatus(SpdyStreamStatus status);

    // Covariant overrides narrow the return type for fluent chaining.
    @Override
    SpdyRstStreamFrame setStreamId(int streamId);

    @Override
    SpdyRstStreamFrame setLast(boolean last);
}
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.PlatformDependent; + +import java.util.Comparator; +import java.util.Map; +import java.util.Queue; +import java.util.TreeMap; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.atomic.AtomicInteger; + +import static io.netty.handler.codec.spdy.SpdyCodecUtil.*; + +final class SpdySession { + + private final AtomicInteger activeLocalStreams = new AtomicInteger(); + private final AtomicInteger activeRemoteStreams = new AtomicInteger(); + private final Map activeStreams = PlatformDependent.newConcurrentHashMap(); + private final StreamComparator streamComparator = new StreamComparator(); + private final AtomicInteger sendWindowSize; + private final AtomicInteger receiveWindowSize; + + SpdySession(int sendWindowSize, int receiveWindowSize) { + this.sendWindowSize = new AtomicInteger(sendWindowSize); + this.receiveWindowSize = new AtomicInteger(receiveWindowSize); + } + + int numActiveStreams(boolean remote) { + if (remote) { + return activeRemoteStreams.get(); + } else { + return activeLocalStreams.get(); + } + } + + boolean noActiveStreams() { + return activeStreams.isEmpty(); + } + + boolean isActiveStream(int streamId) { + return activeStreams.containsKey(streamId); + } + + // Stream-IDs should be iterated in priority order + Map activeStreams() { + Map streams = new TreeMap(streamComparator); + streams.putAll(activeStreams); + return streams; + } + + void acceptStream( + int streamId, byte priority, boolean remoteSideClosed, boolean localSideClosed, + int sendWindowSize, int receiveWindowSize, boolean remote) { + if (!remoteSideClosed || !localSideClosed) { + StreamState state = activeStreams.put(streamId, new StreamState( + priority, remoteSideClosed, localSideClosed, sendWindowSize, receiveWindowSize)); + if (state == null) { + if (remote) { + activeRemoteStreams.incrementAndGet(); + } else { + 
activeLocalStreams.incrementAndGet(); + } + } + } + } + + private StreamState removeActiveStream(int streamId, boolean remote) { + StreamState state = activeStreams.remove(streamId); + if (state != null) { + if (remote) { + activeRemoteStreams.decrementAndGet(); + } else { + activeLocalStreams.decrementAndGet(); + } + } + return state; + } + + void removeStream(int streamId, Throwable cause, boolean remote) { + StreamState state = removeActiveStream(streamId, remote); + if (state != null) { + state.clearPendingWrites(cause); + } + } + + boolean isRemoteSideClosed(int streamId) { + StreamState state = activeStreams.get(streamId); + return state == null || state.isRemoteSideClosed(); + } + + void closeRemoteSide(int streamId, boolean remote) { + StreamState state = activeStreams.get(streamId); + if (state != null) { + state.closeRemoteSide(); + if (state.isLocalSideClosed()) { + removeActiveStream(streamId, remote); + } + } + } + + boolean isLocalSideClosed(int streamId) { + StreamState state = activeStreams.get(streamId); + return state == null || state.isLocalSideClosed(); + } + + void closeLocalSide(int streamId, boolean remote) { + StreamState state = activeStreams.get(streamId); + if (state != null) { + state.closeLocalSide(); + if (state.isRemoteSideClosed()) { + removeActiveStream(streamId, remote); + } + } + } + + /* + * hasReceivedReply and receivedReply are only called from channelRead() + * no need to synchronize access to the StreamState + */ + boolean hasReceivedReply(int streamId) { + StreamState state = activeStreams.get(streamId); + return state != null && state.hasReceivedReply(); + } + + void receivedReply(int streamId) { + StreamState state = activeStreams.get(streamId); + if (state != null) { + state.receivedReply(); + } + } + + int getSendWindowSize(int streamId) { + if (streamId == SPDY_SESSION_STREAM_ID) { + return sendWindowSize.get(); + } + + StreamState state = activeStreams.get(streamId); + return state != null ? 
state.getSendWindowSize() : -1; + } + + int updateSendWindowSize(int streamId, int deltaWindowSize) { + if (streamId == SPDY_SESSION_STREAM_ID) { + return sendWindowSize.addAndGet(deltaWindowSize); + } + + StreamState state = activeStreams.get(streamId); + return state != null ? state.updateSendWindowSize(deltaWindowSize) : -1; + } + + int updateReceiveWindowSize(int streamId, int deltaWindowSize) { + if (streamId == SPDY_SESSION_STREAM_ID) { + return receiveWindowSize.addAndGet(deltaWindowSize); + } + + StreamState state = activeStreams.get(streamId); + if (state == null) { + return -1; + } + if (deltaWindowSize > 0) { + state.setReceiveWindowSizeLowerBound(0); + } + return state.updateReceiveWindowSize(deltaWindowSize); + } + + int getReceiveWindowSizeLowerBound(int streamId) { + if (streamId == SPDY_SESSION_STREAM_ID) { + return 0; + } + + StreamState state = activeStreams.get(streamId); + return state != null ? state.getReceiveWindowSizeLowerBound() : 0; + } + + void updateAllSendWindowSizes(int deltaWindowSize) { + for (StreamState state: activeStreams.values()) { + state.updateSendWindowSize(deltaWindowSize); + } + } + + void updateAllReceiveWindowSizes(int deltaWindowSize) { + for (StreamState state: activeStreams.values()) { + state.updateReceiveWindowSize(deltaWindowSize); + if (deltaWindowSize < 0) { + state.setReceiveWindowSizeLowerBound(deltaWindowSize); + } + } + } + + boolean putPendingWrite(int streamId, PendingWrite pendingWrite) { + StreamState state = activeStreams.get(streamId); + return state != null && state.putPendingWrite(pendingWrite); + } + + PendingWrite getPendingWrite(int streamId) { + if (streamId == SPDY_SESSION_STREAM_ID) { + for (Map.Entry e: activeStreams().entrySet()) { + StreamState state = e.getValue(); + if (state.getSendWindowSize() > 0) { + PendingWrite pendingWrite = state.getPendingWrite(); + if (pendingWrite != null) { + return pendingWrite; + } + } + } + return null; + } + + StreamState state = activeStreams.get(streamId); 
+ return state != null ? state.getPendingWrite() : null; + } + + PendingWrite removePendingWrite(int streamId) { + StreamState state = activeStreams.get(streamId); + return state != null ? state.removePendingWrite() : null; + } + + private static final class StreamState { + + private final byte priority; + private boolean remoteSideClosed; + private boolean localSideClosed; + private boolean receivedReply; + private final AtomicInteger sendWindowSize; + private final AtomicInteger receiveWindowSize; + private int receiveWindowSizeLowerBound; + private final Queue pendingWriteQueue = new ConcurrentLinkedQueue(); + + StreamState( + byte priority, boolean remoteSideClosed, boolean localSideClosed, + int sendWindowSize, int receiveWindowSize) { + this.priority = priority; + this.remoteSideClosed = remoteSideClosed; + this.localSideClosed = localSideClosed; + this.sendWindowSize = new AtomicInteger(sendWindowSize); + this.receiveWindowSize = new AtomicInteger(receiveWindowSize); + } + + byte getPriority() { + return priority; + } + + boolean isRemoteSideClosed() { + return remoteSideClosed; + } + + void closeRemoteSide() { + remoteSideClosed = true; + } + + boolean isLocalSideClosed() { + return localSideClosed; + } + + void closeLocalSide() { + localSideClosed = true; + } + + boolean hasReceivedReply() { + return receivedReply; + } + + void receivedReply() { + receivedReply = true; + } + + int getSendWindowSize() { + return sendWindowSize.get(); + } + + int updateSendWindowSize(int deltaWindowSize) { + return sendWindowSize.addAndGet(deltaWindowSize); + } + + int updateReceiveWindowSize(int deltaWindowSize) { + return receiveWindowSize.addAndGet(deltaWindowSize); + } + + int getReceiveWindowSizeLowerBound() { + return receiveWindowSizeLowerBound; + } + + void setReceiveWindowSizeLowerBound(int receiveWindowSizeLowerBound) { + this.receiveWindowSizeLowerBound = receiveWindowSizeLowerBound; + } + + boolean putPendingWrite(PendingWrite msg) { + return 
pendingWriteQueue.offer(msg); + } + + PendingWrite getPendingWrite() { + return pendingWriteQueue.peek(); + } + + PendingWrite removePendingWrite() { + return pendingWriteQueue.poll(); + } + + void clearPendingWrites(Throwable cause) { + for (;;) { + PendingWrite pendingWrite = pendingWriteQueue.poll(); + if (pendingWrite == null) { + break; + } + pendingWrite.fail(cause); + } + } + } + + private final class StreamComparator implements Comparator { + + StreamComparator() { } + + @Override + public int compare(Integer id1, Integer id2) { + StreamState state1 = activeStreams.get(id1); + StreamState state2 = activeStreams.get(id2); + + int result = state1.getPriority() - state2.getPriority(); + if (result != 0) { + return result; + } + + return id1 - id2; + } + } + + public static final class PendingWrite { + final SpdyDataFrame spdyDataFrame; + final ChannelPromise promise; + + PendingWrite(SpdyDataFrame spdyDataFrame, ChannelPromise promise) { + this.spdyDataFrame = spdyDataFrame; + this.promise = promise; + } + + void fail(Throwable cause) { + spdyDataFrame.release(); + promise.setFailure(cause); + } + } +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySessionHandler.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySessionHandler.java new file mode 100644 index 0000000..0ad90e6 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdySessionHandler.java @@ -0,0 +1,854 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.ObjectUtil; + +import java.util.concurrent.atomic.AtomicInteger; + +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_SESSION_STREAM_ID; +import static io.netty.handler.codec.spdy.SpdyCodecUtil.isServerId; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * Manages streams within a SPDY session. 
+ */ +public class SpdySessionHandler extends ChannelDuplexHandler { + + private static final SpdyProtocolException PROTOCOL_EXCEPTION = + SpdyProtocolException.newStatic(null, SpdySessionHandler.class, "handleOutboundMessage(...)"); + private static final SpdyProtocolException STREAM_CLOSED = + SpdyProtocolException.newStatic("Stream closed", SpdySessionHandler.class, "removeStream(...)"); + + private static final int DEFAULT_WINDOW_SIZE = 64 * 1024; // 64 KB default initial window size + private int initialSendWindowSize = DEFAULT_WINDOW_SIZE; + private int initialReceiveWindowSize = DEFAULT_WINDOW_SIZE; + private volatile int initialSessionReceiveWindowSize = DEFAULT_WINDOW_SIZE; + + private final SpdySession spdySession = new SpdySession(initialSendWindowSize, initialReceiveWindowSize); + private int lastGoodStreamId; + + private static final int DEFAULT_MAX_CONCURRENT_STREAMS = Integer.MAX_VALUE; + private int remoteConcurrentStreams = DEFAULT_MAX_CONCURRENT_STREAMS; + private int localConcurrentStreams = DEFAULT_MAX_CONCURRENT_STREAMS; + + private final AtomicInteger pings = new AtomicInteger(); + + private boolean sentGoAwayFrame; + private boolean receivedGoAwayFrame; + + private ChannelFutureListener closeSessionFutureListener; + + private final boolean server; + private final int minorVersion; + + /** + * Creates a new session handler. + * + * @param version the protocol version + * @param server {@code true} if and only if this session handler should + * handle the server endpoint of the connection. + * {@code false} if and only if this session handler should + * handle the client endpoint of the connection. 
+ */ + public SpdySessionHandler(SpdyVersion version, boolean server) { + this.minorVersion = ObjectUtil.checkNotNull(version, "version").getMinorVersion(); + this.server = server; + } + + public void setSessionReceiveWindowSize(int sessionReceiveWindowSize) { + checkPositiveOrZero(sessionReceiveWindowSize, "sessionReceiveWindowSize"); + // This will not send a window update frame immediately. + // If this value increases the allowed receive window size, + // a WINDOW_UPDATE frame will be sent when only half of the + // session window size remains during data frame processing. + // If this value decreases the allowed receive window size, + // the window will be reduced as data frames are processed. + initialSessionReceiveWindowSize = sessionReceiveWindowSize; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof SpdyDataFrame) { + + /* + * SPDY Data frame processing requirements: + * + * If an endpoint receives a data frame for a Stream-ID which is not open + * and the endpoint has not sent a GOAWAY frame, it must issue a stream error + * with the error code INVALID_STREAM for the Stream-ID. + * + * If an endpoint which created the stream receives a data frame before receiving + * a SYN_REPLY on that stream, it is a protocol error, and the recipient must + * issue a stream error with the getStatus code PROTOCOL_ERROR for the Stream-ID. + * + * If an endpoint receives multiple data frames for invalid Stream-IDs, + * it may close the session. + * + * If an endpoint refuses a stream it must ignore any data frames for that stream. + * + * If an endpoint receives a data frame after the stream is half-closed from the + * sender, it must send a RST_STREAM frame with the getStatus STREAM_ALREADY_CLOSED. + * + * If an endpoint receives a data frame after the stream is closed, it must send + * a RST_STREAM frame with the getStatus PROTOCOL_ERROR. 
+ */ + SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg; + int streamId = spdyDataFrame.streamId(); + + int deltaWindowSize = -1 * spdyDataFrame.content().readableBytes(); + int newSessionWindowSize = + spdySession.updateReceiveWindowSize(SPDY_SESSION_STREAM_ID, deltaWindowSize); + + // Check if session window size is reduced beyond allowable lower bound + if (newSessionWindowSize < 0) { + issueSessionError(ctx, SpdySessionStatus.PROTOCOL_ERROR); + return; + } + + // Send a WINDOW_UPDATE frame if less than half the session window size remains + if (newSessionWindowSize <= initialSessionReceiveWindowSize / 2) { + int sessionDeltaWindowSize = initialSessionReceiveWindowSize - newSessionWindowSize; + spdySession.updateReceiveWindowSize(SPDY_SESSION_STREAM_ID, sessionDeltaWindowSize); + SpdyWindowUpdateFrame spdyWindowUpdateFrame = + new DefaultSpdyWindowUpdateFrame(SPDY_SESSION_STREAM_ID, sessionDeltaWindowSize); + ctx.writeAndFlush(spdyWindowUpdateFrame); + } + + // Check if we received a data frame for a Stream-ID which is not open + + if (!spdySession.isActiveStream(streamId)) { + spdyDataFrame.release(); + if (streamId <= lastGoodStreamId) { + issueStreamError(ctx, streamId, SpdyStreamStatus.PROTOCOL_ERROR); + } else if (!sentGoAwayFrame) { + issueStreamError(ctx, streamId, SpdyStreamStatus.INVALID_STREAM); + } + return; + } + + // Check if we received a data frame for a stream which is half-closed + + if (spdySession.isRemoteSideClosed(streamId)) { + spdyDataFrame.release(); + issueStreamError(ctx, streamId, SpdyStreamStatus.STREAM_ALREADY_CLOSED); + return; + } + + // Check if we received a data frame before receiving a SYN_REPLY + if (!isRemoteInitiatedId(streamId) && !spdySession.hasReceivedReply(streamId)) { + spdyDataFrame.release(); + issueStreamError(ctx, streamId, SpdyStreamStatus.PROTOCOL_ERROR); + return; + } + + /* + * SPDY Data frame flow control processing requirements: + * + * Recipient should not send a WINDOW_UPDATE frame as it consumes the last 
data frame. + */ + + // Update receive window size + int newWindowSize = spdySession.updateReceiveWindowSize(streamId, deltaWindowSize); + + // Window size can become negative if we sent a SETTINGS frame that reduces the + // size of the transfer window after the peer has written data frames. + // The value is bounded by the length that SETTINGS frame decrease the window. + // This difference is stored for the session when writing the SETTINGS frame + // and is cleared once we send a WINDOW_UPDATE frame. + if (newWindowSize < spdySession.getReceiveWindowSizeLowerBound(streamId)) { + spdyDataFrame.release(); + issueStreamError(ctx, streamId, SpdyStreamStatus.FLOW_CONTROL_ERROR); + return; + } + + // Window size became negative due to sender writing frame before receiving SETTINGS + // Send data frames upstream in initialReceiveWindowSize chunks + if (newWindowSize < 0) { + while (spdyDataFrame.content().readableBytes() > initialReceiveWindowSize) { + SpdyDataFrame partialDataFrame = new DefaultSpdyDataFrame( + streamId, spdyDataFrame.content().readRetainedSlice(initialReceiveWindowSize)); + ctx.writeAndFlush(partialDataFrame); + } + } + + // Send a WINDOW_UPDATE frame if less than half the stream window size remains + if (newWindowSize <= initialReceiveWindowSize / 2 && !spdyDataFrame.isLast()) { + int streamDeltaWindowSize = initialReceiveWindowSize - newWindowSize; + spdySession.updateReceiveWindowSize(streamId, streamDeltaWindowSize); + SpdyWindowUpdateFrame spdyWindowUpdateFrame = + new DefaultSpdyWindowUpdateFrame(streamId, streamDeltaWindowSize); + ctx.writeAndFlush(spdyWindowUpdateFrame); + } + + // Close the remote side of the stream if this is the last frame + if (spdyDataFrame.isLast()) { + halfCloseStream(streamId, true, ctx.newSucceededFuture()); + } + + } else if (msg instanceof SpdySynStreamFrame) { + + /* + * SPDY SYN_STREAM frame processing requirements: + * + * If an endpoint receives a SYN_STREAM with a Stream-ID that is less than + * any 
previously received SYN_STREAM, it must issue a session error with + * the getStatus PROTOCOL_ERROR. + * + * If an endpoint receives multiple SYN_STREAM frames with the same active + * Stream-ID, it must issue a stream error with the getStatus code PROTOCOL_ERROR. + * + * The recipient can reject a stream by sending a stream error with the + * getStatus code REFUSED_STREAM. + */ + + SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg; + int streamId = spdySynStreamFrame.streamId(); + + // Check if we received a valid SYN_STREAM frame + if (spdySynStreamFrame.isInvalid() || + !isRemoteInitiatedId(streamId) || + spdySession.isActiveStream(streamId)) { + issueStreamError(ctx, streamId, SpdyStreamStatus.PROTOCOL_ERROR); + return; + } + + // Stream-IDs must be monotonically increasing + if (streamId <= lastGoodStreamId) { + issueSessionError(ctx, SpdySessionStatus.PROTOCOL_ERROR); + return; + } + + // Try to accept the stream + byte priority = spdySynStreamFrame.priority(); + boolean remoteSideClosed = spdySynStreamFrame.isLast(); + boolean localSideClosed = spdySynStreamFrame.isUnidirectional(); + if (!acceptStream(streamId, priority, remoteSideClosed, localSideClosed)) { + issueStreamError(ctx, streamId, SpdyStreamStatus.REFUSED_STREAM); + return; + } + + } else if (msg instanceof SpdySynReplyFrame) { + + /* + * SPDY SYN_REPLY frame processing requirements: + * + * If an endpoint receives multiple SYN_REPLY frames for the same active Stream-ID + * it must issue a stream error with the getStatus code STREAM_IN_USE. 
+ */ + + SpdySynReplyFrame spdySynReplyFrame = (SpdySynReplyFrame) msg; + int streamId = spdySynReplyFrame.streamId(); + + // Check if we received a valid SYN_REPLY frame + if (spdySynReplyFrame.isInvalid() || + isRemoteInitiatedId(streamId) || + spdySession.isRemoteSideClosed(streamId)) { + issueStreamError(ctx, streamId, SpdyStreamStatus.INVALID_STREAM); + return; + } + + // Check if we have received multiple frames for the same Stream-ID + if (spdySession.hasReceivedReply(streamId)) { + issueStreamError(ctx, streamId, SpdyStreamStatus.STREAM_IN_USE); + return; + } + + spdySession.receivedReply(streamId); + + // Close the remote side of the stream if this is the last frame + if (spdySynReplyFrame.isLast()) { + halfCloseStream(streamId, true, ctx.newSucceededFuture()); + } + + } else if (msg instanceof SpdyRstStreamFrame) { + + /* + * SPDY RST_STREAM frame processing requirements: + * + * After receiving a RST_STREAM on a stream, the receiver must not send + * additional frames on that stream. + * + * An endpoint must not send a RST_STREAM in response to a RST_STREAM. + */ + + SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg; + removeStream(spdyRstStreamFrame.streamId(), ctx.newSucceededFuture()); + + } else if (msg instanceof SpdySettingsFrame) { + + SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg; + + int settingsMinorVersion = spdySettingsFrame.getValue(SpdySettingsFrame.SETTINGS_MINOR_VERSION); + if (settingsMinorVersion >= 0 && settingsMinorVersion != minorVersion) { + // Settings frame had the wrong minor version + issueSessionError(ctx, SpdySessionStatus.PROTOCOL_ERROR); + return; + } + + int newConcurrentStreams = + spdySettingsFrame.getValue(SpdySettingsFrame.SETTINGS_MAX_CONCURRENT_STREAMS); + if (newConcurrentStreams >= 0) { + remoteConcurrentStreams = newConcurrentStreams; + } + + // Persistence flag are inconsistent with the use of SETTINGS to communicate + // the initial window size. 
Remove flags from the sender requesting that the + // value be persisted. Remove values that the sender indicates are persisted. + if (spdySettingsFrame.isPersisted(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE)) { + spdySettingsFrame.removeValue(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE); + } + spdySettingsFrame.setPersistValue(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE, false); + + int newInitialWindowSize = + spdySettingsFrame.getValue(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE); + if (newInitialWindowSize >= 0) { + updateInitialSendWindowSize(newInitialWindowSize); + } + + } else if (msg instanceof SpdyPingFrame) { + + /* + * SPDY PING frame processing requirements: + * + * Receivers of a PING frame should send an identical frame to the sender + * as soon as possible. + * + * Receivers of a PING frame must ignore frames that it did not initiate + */ + + SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg; + + if (isRemoteInitiatedId(spdyPingFrame.id())) { + ctx.writeAndFlush(spdyPingFrame); + return; + } + + // Note: only checks that there are outstanding pings since uniqueness is not enforced + if (pings.get() == 0) { + return; + } + pings.getAndDecrement(); + + } else if (msg instanceof SpdyGoAwayFrame) { + + receivedGoAwayFrame = true; + + } else if (msg instanceof SpdyHeadersFrame) { + + SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg; + int streamId = spdyHeadersFrame.streamId(); + + // Check if we received a valid HEADERS frame + if (spdyHeadersFrame.isInvalid()) { + issueStreamError(ctx, streamId, SpdyStreamStatus.PROTOCOL_ERROR); + return; + } + + if (spdySession.isRemoteSideClosed(streamId)) { + issueStreamError(ctx, streamId, SpdyStreamStatus.INVALID_STREAM); + return; + } + + // Close the remote side of the stream if this is the last frame + if (spdyHeadersFrame.isLast()) { + halfCloseStream(streamId, true, ctx.newSucceededFuture()); + } + + } else if (msg instanceof SpdyWindowUpdateFrame) { + + /* + * SPDY WINDOW_UPDATE frame 
processing requirements: + * + * Receivers of a WINDOW_UPDATE that cause the window size to exceed 2^31 + * must send a RST_STREAM with the getStatus code FLOW_CONTROL_ERROR. + * + * Sender should ignore all WINDOW_UPDATE frames associated with a stream + * after sending the last frame for the stream. + */ + + SpdyWindowUpdateFrame spdyWindowUpdateFrame = (SpdyWindowUpdateFrame) msg; + int streamId = spdyWindowUpdateFrame.streamId(); + int deltaWindowSize = spdyWindowUpdateFrame.deltaWindowSize(); + + // Ignore frames for half-closed streams + if (streamId != SPDY_SESSION_STREAM_ID && spdySession.isLocalSideClosed(streamId)) { + return; + } + + // Check for numerical overflow + if (spdySession.getSendWindowSize(streamId) > Integer.MAX_VALUE - deltaWindowSize) { + if (streamId == SPDY_SESSION_STREAM_ID) { + issueSessionError(ctx, SpdySessionStatus.PROTOCOL_ERROR); + } else { + issueStreamError(ctx, streamId, SpdyStreamStatus.FLOW_CONTROL_ERROR); + } + return; + } + + updateSendWindowSize(ctx, streamId, deltaWindowSize); + } + + ctx.fireChannelRead(msg); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + for (Integer streamId: spdySession.activeStreams().keySet()) { + removeStream(streamId, ctx.newSucceededFuture()); + } + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause instanceof SpdyProtocolException) { + issueSessionError(ctx, SpdySessionStatus.PROTOCOL_ERROR); + } + + ctx.fireExceptionCaught(cause); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + sendGoAwayFrame(ctx, promise); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof SpdyDataFrame || + msg instanceof SpdySynStreamFrame || + msg instanceof SpdySynReplyFrame || + msg instanceof SpdyRstStreamFrame || + msg 
instanceof SpdySettingsFrame || + msg instanceof SpdyPingFrame || + msg instanceof SpdyGoAwayFrame || + msg instanceof SpdyHeadersFrame || + msg instanceof SpdyWindowUpdateFrame) { + + handleOutboundMessage(ctx, msg, promise); + } else { + ctx.write(msg, promise); + } + } + + private void handleOutboundMessage(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof SpdyDataFrame) { + + SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg; + int streamId = spdyDataFrame.streamId(); + + // Frames must not be sent on half-closed streams + if (spdySession.isLocalSideClosed(streamId)) { + spdyDataFrame.release(); + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + /* + * SPDY Data frame flow control processing requirements: + * + * Sender must not send a data frame with data length greater + * than the transfer window size. + * + * After sending each data frame, the sender decrements its + * transfer window size by the amount of data transmitted. + * + * When the window size becomes less than or equal to 0, the + * sender must pause transmitting data frames. 
+ */ + + int dataLength = spdyDataFrame.content().readableBytes(); + int sendWindowSize = spdySession.getSendWindowSize(streamId); + int sessionSendWindowSize = spdySession.getSendWindowSize(SPDY_SESSION_STREAM_ID); + sendWindowSize = Math.min(sendWindowSize, sessionSendWindowSize); + + if (sendWindowSize <= 0) { + // Stream is stalled -- enqueue Data frame and return + spdySession.putPendingWrite(streamId, new SpdySession.PendingWrite(spdyDataFrame, promise)); + return; + } else if (sendWindowSize < dataLength) { + // Stream is not stalled but we cannot send the entire frame + spdySession.updateSendWindowSize(streamId, -1 * sendWindowSize); + spdySession.updateSendWindowSize(SPDY_SESSION_STREAM_ID, -1 * sendWindowSize); + + // Create a partial data frame whose length is the current window size + SpdyDataFrame partialDataFrame = new DefaultSpdyDataFrame( + streamId, spdyDataFrame.content().readRetainedSlice(sendWindowSize)); + + // Enqueue the remaining data (will be the first frame queued) + spdySession.putPendingWrite(streamId, new SpdySession.PendingWrite(spdyDataFrame, promise)); + + // The transfer window size is pre-decremented when sending a data frame downstream. + // Close the session on write failures that leave the transfer window in a corrupt state. + final ChannelHandlerContext context = ctx; + ctx.write(partialDataFrame).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + issueSessionError(context, SpdySessionStatus.INTERNAL_ERROR); + } + } + }); + return; + } else { + // Window size is large enough to send entire data frame + spdySession.updateSendWindowSize(streamId, -1 * dataLength); + spdySession.updateSendWindowSize(SPDY_SESSION_STREAM_ID, -1 * dataLength); + + // The transfer window size is pre-decremented when sending a data frame downstream. + // Close the session on write failures that leave the transfer window in a corrupt state. 
+ final ChannelHandlerContext context = ctx; + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + issueSessionError(context, SpdySessionStatus.INTERNAL_ERROR); + } + } + }); + } + + // Close the local side of the stream if this is the last frame + if (spdyDataFrame.isLast()) { + halfCloseStream(streamId, false, promise); + } + + } else if (msg instanceof SpdySynStreamFrame) { + + SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg; + int streamId = spdySynStreamFrame.streamId(); + + if (isRemoteInitiatedId(streamId)) { + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + byte priority = spdySynStreamFrame.priority(); + boolean remoteSideClosed = spdySynStreamFrame.isUnidirectional(); + boolean localSideClosed = spdySynStreamFrame.isLast(); + if (!acceptStream(streamId, priority, remoteSideClosed, localSideClosed)) { + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + } else if (msg instanceof SpdySynReplyFrame) { + + SpdySynReplyFrame spdySynReplyFrame = (SpdySynReplyFrame) msg; + int streamId = spdySynReplyFrame.streamId(); + + // Frames must not be sent on half-closed streams + if (!isRemoteInitiatedId(streamId) || spdySession.isLocalSideClosed(streamId)) { + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + // Close the local side of the stream if this is the last frame + if (spdySynReplyFrame.isLast()) { + halfCloseStream(streamId, false, promise); + } + + } else if (msg instanceof SpdyRstStreamFrame) { + + SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg; + removeStream(spdyRstStreamFrame.streamId(), promise); + + } else if (msg instanceof SpdySettingsFrame) { + + SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg; + + int settingsMinorVersion = spdySettingsFrame.getValue(SpdySettingsFrame.SETTINGS_MINOR_VERSION); + if (settingsMinorVersion >= 0 && settingsMinorVersion != minorVersion) 
{ + // Settings frame had the wrong minor version + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + int newConcurrentStreams = + spdySettingsFrame.getValue(SpdySettingsFrame.SETTINGS_MAX_CONCURRENT_STREAMS); + if (newConcurrentStreams >= 0) { + localConcurrentStreams = newConcurrentStreams; + } + + // Persistence flag are inconsistent with the use of SETTINGS to communicate + // the initial window size. Remove flags from the sender requesting that the + // value be persisted. Remove values that the sender indicates are persisted. + if (spdySettingsFrame.isPersisted(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE)) { + spdySettingsFrame.removeValue(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE); + } + spdySettingsFrame.setPersistValue(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE, false); + + int newInitialWindowSize = + spdySettingsFrame.getValue(SpdySettingsFrame.SETTINGS_INITIAL_WINDOW_SIZE); + if (newInitialWindowSize >= 0) { + updateInitialReceiveWindowSize(newInitialWindowSize); + } + + } else if (msg instanceof SpdyPingFrame) { + + SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg; + if (isRemoteInitiatedId(spdyPingFrame.id())) { + ctx.fireExceptionCaught(new IllegalArgumentException( + "invalid PING ID: " + spdyPingFrame.id())); + return; + } + pings.getAndIncrement(); + + } else if (msg instanceof SpdyGoAwayFrame) { + + // Why is this being sent? Intercept it and fail the write. 
+ // Should have sent a CLOSE ChannelStateEvent + promise.setFailure(PROTOCOL_EXCEPTION); + return; + + } else if (msg instanceof SpdyHeadersFrame) { + + SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg; + int streamId = spdyHeadersFrame.streamId(); + + // Frames must not be sent on half-closed streams + if (spdySession.isLocalSideClosed(streamId)) { + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + // Close the local side of the stream if this is the last frame + if (spdyHeadersFrame.isLast()) { + halfCloseStream(streamId, false, promise); + } + + } else if (msg instanceof SpdyWindowUpdateFrame) { + + // Why is this being sent? Intercept it and fail the write. + promise.setFailure(PROTOCOL_EXCEPTION); + return; + } + + ctx.write(msg, promise); + } + + /* + * SPDY Session Error Handling: + * + * When a session error occurs, the endpoint encountering the error must first + * send a GOAWAY frame with the Stream-ID of the most recently received stream + * from the remote endpoint, and the error code for why the session is terminating. + * + * After sending the GOAWAY frame, the endpoint must close the TCP connection. + */ + private void issueSessionError( + ChannelHandlerContext ctx, SpdySessionStatus status) { + + sendGoAwayFrame(ctx, status).addListener(new ClosingChannelFutureListener(ctx, ctx.newPromise())); + } + + /* + * SPDY Stream Error Handling: + * + * Upon a stream error, the endpoint must send a RST_STREAM frame which contains + * the Stream-ID for the stream where the error occurred and the error getStatus which + * caused the error. + * + * After sending the RST_STREAM, the stream is closed to the sending endpoint. 
     * Note: this is only called by the worker thread
     */
    private void issueStreamError(ChannelHandlerContext ctx, int streamId, SpdyStreamStatus status) {
        // If the remote side of the stream is still open, replay the RST_STREAM
        // inbound after writing it so the pipeline learns the stream was reset.
        boolean fireChannelRead = !spdySession.isRemoteSideClosed(streamId);
        ChannelPromise promise = ctx.newPromise();
        removeStream(streamId, promise);

        SpdyRstStreamFrame spdyRstStreamFrame = new DefaultSpdyRstStreamFrame(streamId, status);
        ctx.writeAndFlush(spdyRstStreamFrame, promise);
        if (fireChannelRead) {
            ctx.fireChannelRead(spdyRstStreamFrame);
        }
    }

    /*
     * Helper functions
     */

    // True if the stream-id was created by the remote peer: a non-server id when
    // we act as server, or a server id when we act as client.
    // NOTE(review): relies on isServerId(id) — presumably keyed off stream-id
    // parity per the SPDY spec; confirm against its definition elsewhere in the file.
    private boolean isRemoteInitiatedId(int id) {
        boolean serverId = isServerId(id);
        return server && !serverId || !server && serverId;
    }

    // need to synchronize to prevent new streams from being created while updating active streams
    // Applies a new SETTINGS_INITIAL_WINDOW_SIZE from the peer by shifting every
    // existing stream's send window by the delta (may be negative).
    private void updateInitialSendWindowSize(int newInitialWindowSize) {
        int deltaWindowSize = newInitialWindowSize - initialSendWindowSize;
        initialSendWindowSize = newInitialWindowSize;
        spdySession.updateAllSendWindowSizes(deltaWindowSize);
    }

    // need to synchronize to prevent new streams from being created while updating active streams
    // Mirror of updateInitialSendWindowSize for the receive direction.
    private void updateInitialReceiveWindowSize(int newInitialWindowSize) {
        int deltaWindowSize = newInitialWindowSize - initialReceiveWindowSize;
        initialReceiveWindowSize = newInitialWindowSize;
        spdySession.updateAllReceiveWindowSizes(deltaWindowSize);
    }

    // need to synchronize accesses to sentGoAwayFrame, lastGoodStreamId, and initial window sizes
    /**
     * Registers a new active stream, enforcing the GOAWAY and
     * SETTINGS_MAX_CONCURRENT_STREAMS constraints.
     *
     * @return {@code true} if the stream was accepted, {@code false} if it must be refused
     */
    private boolean acceptStream(
            int streamId, byte priority, boolean remoteSideClosed, boolean localSideClosed) {
        // Cannot initiate any new streams after receiving or sending GOAWAY
        if (receivedGoAwayFrame || sentGoAwayFrame) {
            return false;
        }

        boolean remote = isRemoteInitiatedId(streamId);
        // Remote-initiated streams are bounded by our advertised limit,
        // locally initiated ones by the peer's advertised limit.
        int maxConcurrentStreams = remote ?
                localConcurrentStreams : remoteConcurrentStreams;
        if (spdySession.numActiveStreams(remote) >= maxConcurrentStreams) {
            return false;
        }
        spdySession.acceptStream(
                streamId, priority, remoteSideClosed, localSideClosed,
                initialSendWindowSize, initialReceiveWindowSize, remote);
        if (remote) {
            // Track the highest remote stream-id accepted so far, reported in GOAWAY.
            lastGoodStreamId = streamId;
        }
        return true;
    }

    // Closes one direction of the stream; once no streams remain active and a
    // session close has been requested, the pending close listener fires.
    private void halfCloseStream(int streamId, boolean remote, ChannelFuture future) {
        if (remote) {
            spdySession.closeRemoteSide(streamId, isRemoteInitiatedId(streamId));
        } else {
            spdySession.closeLocalSide(streamId, isRemoteInitiatedId(streamId));
        }
        if (closeSessionFutureListener != null && spdySession.noActiveStreams()) {
            future.addListener(closeSessionFutureListener);
        }
    }

    // Fully removes the stream (failing its pending writes with STREAM_CLOSED)
    // and fires the deferred session close once no streams remain.
    private void removeStream(int streamId, ChannelFuture future) {
        spdySession.removeStream(streamId, STREAM_CLOSED, isRemoteInitiatedId(streamId));

        if (closeSessionFutureListener != null && spdySession.noActiveStreams()) {
            future.addListener(closeSessionFutureListener);
        }
    }

    /**
     * Credits the stream's send window and drains as many pending (stalled)
     * data frames as the per-stream and session windows now allow, splitting
     * the head frame when only part of it fits.
     */
    private void updateSendWindowSize(final ChannelHandlerContext ctx, int streamId, int deltaWindowSize) {
        spdySession.updateSendWindowSize(streamId, deltaWindowSize);

        while (true) {
            // Check if we have unblocked a stalled stream
            SpdySession.PendingWrite pendingWrite = spdySession.getPendingWrite(streamId);
            if (pendingWrite == null) {
                return;
            }

            SpdyDataFrame spdyDataFrame = pendingWrite.spdyDataFrame;
            int dataFrameSize = spdyDataFrame.content().readableBytes();
            int writeStreamId = spdyDataFrame.streamId();
            // Effective window is the minimum of the stream window and the session window.
            int sendWindowSize = spdySession.getSendWindowSize(writeStreamId);
            int sessionSendWindowSize = spdySession.getSendWindowSize(SPDY_SESSION_STREAM_ID);
            sendWindowSize = Math.min(sendWindowSize, sessionSendWindowSize);

            if (sendWindowSize <= 0) {
                return;
            } else if (sendWindowSize < dataFrameSize) {
                // We can send a partial frame
                spdySession.updateSendWindowSize(writeStreamId, -1 * sendWindowSize);
                spdySession.updateSendWindowSize(SPDY_SESSION_STREAM_ID, -1 * sendWindowSize);

                // Create a partial data frame whose length is the current window size
                SpdyDataFrame partialDataFrame = new DefaultSpdyDataFrame(
                        writeStreamId, spdyDataFrame.content().readRetainedSlice(sendWindowSize));

                // The transfer window size is pre-decremented when sending a data frame downstream.
                // Close the session on write failures that leave the transfer window in a corrupt state.
                ctx.writeAndFlush(partialDataFrame).addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            issueSessionError(ctx, SpdySessionStatus.INTERNAL_ERROR);
                        }
                    }
                });
            } else {
                // Window size is large enough to send entire data frame
                spdySession.removePendingWrite(writeStreamId);
                spdySession.updateSendWindowSize(writeStreamId, -1 * dataFrameSize);
                spdySession.updateSendWindowSize(SPDY_SESSION_STREAM_ID, -1 * dataFrameSize);

                // Close the local side of the stream if this is the last frame
                if (spdyDataFrame.isLast()) {
                    halfCloseStream(writeStreamId, false, pendingWrite.promise);
                }

                // The transfer window size is pre-decremented when sending a data frame downstream.
                // Close the session on write failures that leave the transfer window in a corrupt state.
                ctx.writeAndFlush(spdyDataFrame, pendingWrite.promise).addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        if (!future.isSuccess()) {
                            issueSessionError(ctx, SpdySessionStatus.INTERNAL_ERROR);
                        }
                    }
                });
            }
        }
    }

    /**
     * Initiates a graceful session shutdown: sends GOAWAY(OK) and closes the
     * channel once all active streams have finished.
     */
    private void sendGoAwayFrame(ChannelHandlerContext ctx, ChannelPromise future) {
        // Avoid NotYetConnectedException
        if (!ctx.channel().isActive()) {
            ctx.close(future);
            return;
        }

        ChannelFuture f = sendGoAwayFrame(ctx, SpdySessionStatus.OK);
        if (spdySession.noActiveStreams()) {
            f.addListener(new ClosingChannelFutureListener(ctx, future));
        } else {
            // Defer the close until the last active stream completes.
            closeSessionFutureListener = new ClosingChannelFutureListener(ctx, future);
        }
        // FIXME: Close the connection forcibly after timeout.
    }

    // Writes a single GOAWAY with the given status; at most one GOAWAY is ever
    // sent per session — subsequent calls return an already-succeeded future.
    private ChannelFuture sendGoAwayFrame(
            ChannelHandlerContext ctx, SpdySessionStatus status) {
        if (!sentGoAwayFrame) {
            sentGoAwayFrame = true;
            SpdyGoAwayFrame spdyGoAwayFrame = new DefaultSpdyGoAwayFrame(lastGoodStreamId, status);
            return ctx.writeAndFlush(spdyGoAwayFrame);
        } else {
            return ctx.newSucceededFuture();
        }
    }

    // Listener that closes the channel (completing the supplied promise) once
    // the watched write — typically the GOAWAY frame — has completed.
    private static final class ClosingChannelFutureListener implements ChannelFutureListener {
        private final ChannelHandlerContext ctx;
        private final ChannelPromise promise;

        ClosingChannelFutureListener(ChannelHandlerContext ctx, ChannelPromise promise) {
            this.ctx = ctx;
            this.promise = promise;
        }

        @Override
        public void operationComplete(ChannelFuture sentGoAwayFuture) throws Exception {
            ctx.close(promise);
        }
    }
}
/**
 * The SPDY session status code and its description, as carried in a GOAWAY frame.
 *
 * <p>Instances are immutable; equality and ordering are defined by {@link #code()} alone.
 */
public class SpdySessionStatus implements Comparable<SpdySessionStatus> {

    /**
     * 0 OK
     */
    public static final SpdySessionStatus OK =
            new SpdySessionStatus(0, "OK");

    /**
     * 1 Protocol Error
     */
    public static final SpdySessionStatus PROTOCOL_ERROR =
            new SpdySessionStatus(1, "PROTOCOL_ERROR");

    /**
     * 2 Internal Error
     */
    public static final SpdySessionStatus INTERNAL_ERROR =
            new SpdySessionStatus(2, "INTERNAL_ERROR");

    /**
     * Returns the {@link SpdySessionStatus} represented by the specified code.
     * If the specified code is a defined SPDY status code, a cached instance
     * will be returned. Otherwise, a new instance will be returned.
     */
    public static SpdySessionStatus valueOf(int code) {
        switch (code) {
        case 0:
            return OK;
        case 1:
            return PROTOCOL_ERROR;
        case 2:
            return INTERNAL_ERROR;
        default:
            return new SpdySessionStatus(code, "UNKNOWN (" + code + ')');
        }
    }

    private final int code;

    private final String statusPhrase;

    /**
     * Creates a new instance with the specified {@code code} and its
     * {@code statusPhrase}.
     *
     * @throws NullPointerException if {@code statusPhrase} is {@code null}
     */
    public SpdySessionStatus(int code, String statusPhrase) {
        this.statusPhrase = java.util.Objects.requireNonNull(statusPhrase, "statusPhrase");
        this.code = code;
    }

    /**
     * Returns the code of this status.
     */
    public int code() {
        return code;
    }

    /**
     * Returns the status phrase of this status.
     */
    public String statusPhrase() {
        return statusPhrase;
    }

    @Override
    public int hashCode() {
        return code();
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof SpdySessionStatus)) {
            return false;
        }

        // Only the numeric code participates in equality, matching hashCode().
        return code() == ((SpdySessionStatus) o).code();
    }

    @Override
    public String toString() {
        return statusPhrase();
    }

    @Override
    public int compareTo(SpdySessionStatus o) {
        // Integer.compare avoids overflow that plain subtraction would suffer
        // for extreme code values.
        return Integer.compare(code(), o.code());
    }
}
+ */ +package io.netty.handler.codec.spdy; + +import java.util.Set; + +/** + * A SPDY Protocol SETTINGS Frame + */ +public interface SpdySettingsFrame extends SpdyFrame { + + int SETTINGS_MINOR_VERSION = 0; + int SETTINGS_UPLOAD_BANDWIDTH = 1; + int SETTINGS_DOWNLOAD_BANDWIDTH = 2; + int SETTINGS_ROUND_TRIP_TIME = 3; + int SETTINGS_MAX_CONCURRENT_STREAMS = 4; + int SETTINGS_CURRENT_CWND = 5; + int SETTINGS_DOWNLOAD_RETRANS_RATE = 6; + int SETTINGS_INITIAL_WINDOW_SIZE = 7; + int SETTINGS_CLIENT_CERTIFICATE_VECTOR_SIZE = 8; + + /** + * Returns a {@code Set} of the setting IDs. + * The set's iterator will return the IDs in ascending order. + */ + Set ids(); + + /** + * Returns {@code true} if the setting ID has a value. + */ + boolean isSet(int id); + + /** + * Returns the value of the setting ID. + * Returns -1 if the setting ID is not set. + */ + int getValue(int id); + + /** + * Sets the value of the setting ID. + * The ID cannot be negative and cannot exceed 16777215. + */ + SpdySettingsFrame setValue(int id, int value); + + /** + * Sets the value of the setting ID. + * Sets if the setting should be persisted (should only be set by the server). + * Sets if the setting is persisted (should only be set by the client). + * The ID cannot be negative and cannot exceed 16777215. + */ + SpdySettingsFrame setValue(int id, int value, boolean persistVal, boolean persisted); + + /** + * Removes the value of the setting ID. + * Removes all persistence information for the setting. + */ + SpdySettingsFrame removeValue(int id); + + /** + * Returns {@code true} if this setting should be persisted. + * Returns {@code false} if this setting should not be persisted + * or if the setting ID has no value. + */ + boolean isPersistValue(int id); + + /** + * Sets if this setting should be persisted. + * Has no effect if the setting ID has no value. + */ + SpdySettingsFrame setPersistValue(int id, boolean persistValue); + + /** + * Returns {@code true} if this setting is persisted. 
+ * Returns {@code false} if this setting should not be persisted + * or if the setting ID has no value. + */ + boolean isPersisted(int id); + + /** + * Sets if this setting is persisted. + * Has no effect if the setting ID has no value. + */ + SpdySettingsFrame setPersisted(int id, boolean persisted); + + /** + * Returns {@code true} if previously persisted settings should be cleared. + */ + boolean clearPreviouslyPersistedSettings(); + + /** + * Sets if previously persisted settings should be cleared. + */ + SpdySettingsFrame setClearPreviouslyPersistedSettings(boolean clear); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamFrame.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamFrame.java new file mode 100644 index 0000000..6ee78a9 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamFrame.java @@ -0,0 +1,43 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +/** + * A SPDY Protocol Frame that is associated with an individual SPDY Stream + */ +public interface SpdyStreamFrame extends SpdyFrame { + + /** + * Returns the Stream-ID of this frame. + */ + int streamId(); + + /** + * Sets the Stream-ID of this frame. The Stream-ID must be positive. 
+ */ + SpdyStreamFrame setStreamId(int streamID); + + /** + * Returns {@code true} if this frame is the last frame to be transmitted + * on the stream. + */ + boolean isLast(); + + /** + * Sets if this frame is the last frame to be transmitted on the stream. + */ + SpdyStreamFrame setLast(boolean last); +} diff --git a/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamStatus.java b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamStatus.java new file mode 100644 index 0000000..3ffafaa --- /dev/null +++ b/netty-handler-codec-http/src/main/java/io/netty/handler/codec/spdy/SpdyStreamStatus.java @@ -0,0 +1,185 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.util.internal.ObjectUtil; + +/** + * The SPDY stream status code and its description. 
/**
 * The SPDY stream status code and its description, as carried in a RST_STREAM frame.
 *
 * <p>Instances are immutable; equality and ordering are defined by {@link #code()} alone.
 * Code {@code 0} is rejected because it is not a valid RST_STREAM status.
 */
public class SpdyStreamStatus implements Comparable<SpdyStreamStatus> {

    /**
     * 1 Protocol Error
     */
    public static final SpdyStreamStatus PROTOCOL_ERROR =
            new SpdyStreamStatus(1, "PROTOCOL_ERROR");

    /**
     * 2 Invalid Stream
     */
    public static final SpdyStreamStatus INVALID_STREAM =
            new SpdyStreamStatus(2, "INVALID_STREAM");

    /**
     * 3 Refused Stream
     */
    public static final SpdyStreamStatus REFUSED_STREAM =
            new SpdyStreamStatus(3, "REFUSED_STREAM");

    /**
     * 4 Unsupported Version
     */
    public static final SpdyStreamStatus UNSUPPORTED_VERSION =
            new SpdyStreamStatus(4, "UNSUPPORTED_VERSION");

    /**
     * 5 Cancel
     */
    public static final SpdyStreamStatus CANCEL =
            new SpdyStreamStatus(5, "CANCEL");

    /**
     * 6 Internal Error
     */
    public static final SpdyStreamStatus INTERNAL_ERROR =
            new SpdyStreamStatus(6, "INTERNAL_ERROR");

    /**
     * 7 Flow Control Error
     */
    public static final SpdyStreamStatus FLOW_CONTROL_ERROR =
            new SpdyStreamStatus(7, "FLOW_CONTROL_ERROR");

    /**
     * 8 Stream In Use
     */
    public static final SpdyStreamStatus STREAM_IN_USE =
            new SpdyStreamStatus(8, "STREAM_IN_USE");

    /**
     * 9 Stream Already Closed
     */
    public static final SpdyStreamStatus STREAM_ALREADY_CLOSED =
            new SpdyStreamStatus(9, "STREAM_ALREADY_CLOSED");

    /**
     * 10 Invalid Credentials
     */
    public static final SpdyStreamStatus INVALID_CREDENTIALS =
            new SpdyStreamStatus(10, "INVALID_CREDENTIALS");

    /**
     * 11 Frame Too Large
     */
    public static final SpdyStreamStatus FRAME_TOO_LARGE =
            new SpdyStreamStatus(11, "FRAME_TOO_LARGE");

    /**
     * Returns the {@link SpdyStreamStatus} represented by the specified code.
     * If the specified code is a defined SPDY status code, a cached instance
     * will be returned. Otherwise, a new instance will be returned.
     *
     * @throws IllegalArgumentException if {@code code} is {@code 0}
     */
    public static SpdyStreamStatus valueOf(int code) {
        if (code == 0) {
            throw new IllegalArgumentException(
                    "0 is not a valid status code for a RST_STREAM");
        }

        switch (code) {
        case 1:
            return PROTOCOL_ERROR;
        case 2:
            return INVALID_STREAM;
        case 3:
            return REFUSED_STREAM;
        case 4:
            return UNSUPPORTED_VERSION;
        case 5:
            return CANCEL;
        case 6:
            return INTERNAL_ERROR;
        case 7:
            return FLOW_CONTROL_ERROR;
        case 8:
            return STREAM_IN_USE;
        case 9:
            return STREAM_ALREADY_CLOSED;
        case 10:
            return INVALID_CREDENTIALS;
        case 11:
            return FRAME_TOO_LARGE;
        default:
            return new SpdyStreamStatus(code, "UNKNOWN (" + code + ')');
        }
    }

    private final int code;

    private final String statusPhrase;

    /**
     * Creates a new instance with the specified {@code code} and its
     * {@code statusPhrase}.
     *
     * @throws IllegalArgumentException if {@code code} is {@code 0}
     * @throws NullPointerException if {@code statusPhrase} is {@code null}
     */
    public SpdyStreamStatus(int code, String statusPhrase) {
        if (code == 0) {
            throw new IllegalArgumentException(
                    "0 is not a valid status code for a RST_STREAM");
        }

        this.statusPhrase = java.util.Objects.requireNonNull(statusPhrase, "statusPhrase");
        this.code = code;
    }

    /**
     * Returns the code of this status.
     */
    public int code() {
        return code;
    }

    /**
     * Returns the status phrase of this status.
     */
    public String statusPhrase() {
        return statusPhrase;
    }

    @Override
    public int hashCode() {
        return code();
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof SpdyStreamStatus)) {
            return false;
        }

        // Only the numeric code participates in equality, matching hashCode().
        return code() == ((SpdyStreamStatus) o).code();
    }

    @Override
    public String toString() {
        return statusPhrase();
    }

    @Override
    public int compareTo(SpdyStreamStatus o) {
        // Integer.compare avoids overflow that plain subtraction would suffer
        // for extreme code values.
        return Integer.compare(code(), o.code());
    }
}
/**
 * A SPDY Protocol SYN_REPLY Frame
 *
 * <p>Covariant overrides narrow the return types of the inherited
 * {@link SpdyHeadersFrame} setters so calls can be chained on this type.
 */
public interface SpdySynReplyFrame extends SpdyHeadersFrame {

    @Override
    SpdySynReplyFrame setStreamId(int streamID);

    @Override
    SpdySynReplyFrame setLast(boolean last);

    @Override
    SpdySynReplyFrame setInvalid();
}

/**
 * A SPDY Protocol SYN_STREAM Frame
 *
 * <p>Carries the stream-creation attributes (associated stream, priority,
 * unidirectional flag) on top of the inherited {@link SpdyHeadersFrame} contract.
 */
public interface SpdySynStreamFrame extends SpdyHeadersFrame {

    /**
     * Returns the Associated-To-Stream-ID of this frame.
     */
    int associatedStreamId();

    /**
     * Sets the Associated-To-Stream-ID of this frame.
     * The Associated-To-Stream-ID cannot be negative.
     */
    SpdySynStreamFrame setAssociatedStreamId(int associatedStreamId);

    /**
     * Returns the priority of the stream.
     */
    byte priority();

    /**
     * Sets the priority of the stream.
     * The priority must be between 0 and 7 inclusive.
     */
    SpdySynStreamFrame setPriority(byte priority);

    /**
     * Returns {@code true} if the stream created with this frame is to be
     * considered half-closed to the receiver.
     */
    boolean isUnidirectional();

    /**
     * Sets if the stream created with this frame is to be considered
     * half-closed to the receiver.
     */
    SpdySynStreamFrame setUnidirectional(boolean unidirectional);

    @Override
    SpdySynStreamFrame setStreamId(int streamID);

    @Override
    SpdySynStreamFrame setLast(boolean last);

    @Override
    SpdySynStreamFrame setInvalid();
}
/**
 * The SPDY protocol versions supported by this codec.
 */
public enum SpdyVersion {
    SPDY_3_1 (3, 1);

    // Major SPDY version number (e.g. 3 for SPDY/3.1).
    private final int version;
    // Minor SPDY version number (e.g. 1 for SPDY/3.1).
    private final int minorVersion;

    SpdyVersion(int version, int minorVersion) {
        this.version = version;
        this.minorVersion = minorVersion;
    }

    /**
     * Returns the SPDY version number. Package-private: only the codec
     * internals need the raw numbers.
     */
    int getVersion() {
        return version;
    }

    /**
     * Returns the SPDY minor version number.
     */
    int getMinorVersion() {
        return minorVersion;
    }
}

/**
 * A SPDY Protocol WINDOW_UPDATE Frame
 */
public interface SpdyWindowUpdateFrame extends SpdyFrame {

    /**
     * Returns the Stream-ID of this frame.
     */
    int streamId();

    /**
     * Sets the Stream-ID of this frame. The Stream-ID cannot be negative.
     */
    SpdyWindowUpdateFrame setStreamId(int streamID);

    /**
     * Returns the Delta-Window-Size of this frame.
     */
    int deltaWindowSize();

    /**
     * Sets the Delta-Window-Size of this frame.
     * The Delta-Window-Size must be positive.
     */
    SpdyWindowUpdateFrame setDeltaWindowSize(int deltaWindowSize);
}
+ */ +package io.netty.handler.codec.spdy; diff --git a/netty-handler-codec-http/src/main/java/module-info.java b/netty-handler-codec-http/src/main/java/module-info.java new file mode 100644 index 0000000..c5fb298 --- /dev/null +++ b/netty-handler-codec-http/src/main/java/module-info.java @@ -0,0 +1,18 @@ +module org.xbib.io.netty.handler.codec.http { + exports io.netty.handler.codec.http; + exports io.netty.handler.codec.http.cookie; + exports io.netty.handler.codec.http.cors; + exports io.netty.handler.codec.http.multipart; + exports io.netty.handler.codec.http.websocketx; + exports io.netty.handler.codec.http.websocketx.extensions; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.handler.codec.compression; + requires org.xbib.io.netty.handler; + requires org.xbib.io.netty.handler.codec; + requires org.xbib.io.netty.handler.ssl; + requires org.xbib.io.netty.util; + requires org.xbib.io.netty.zlib; + requires com.aayushatharva.brotli4j; + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/CombinedHttpHeadersTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/CombinedHttpHeadersTest.java new file mode 100644 index 0000000..b675435 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/CombinedHttpHeadersTest.java @@ -0,0 +1,387 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.http.HttpHeadersTestUtils.HeaderValue; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; + +import static io.netty.handler.codec.http.HttpHeaderNames.SET_COOKIE; +import static io.netty.util.AsciiString.contentEquals; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasSize; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CombinedHttpHeadersTest { + private static final CharSequence HEADER_NAME = "testHeader"; + + @Test + public void addCharSequencesCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void addCharSequencesCsvWithExistingHeader() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + headers.add(HEADER_NAME, HeaderValue.FIVE.subset(4)); + assertCsvValues(headers, HeaderValue.FIVE); + } + + @Test + public void addCombinedHeadersWhenEmpty() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + final CombinedHttpHeaders otherHeaders = newCombinedHttpHeaders(); + otherHeaders.add(HEADER_NAME, "a"); + otherHeaders.add(HEADER_NAME, "b"); + headers.add(otherHeaders); + assertEquals("a,b", headers.get(HEADER_NAME)); + } + + @Test + public void addCombinedHeadersWhenNotEmpty() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, "a"); + final 
CombinedHttpHeaders otherHeaders = newCombinedHttpHeaders(); + otherHeaders.add(HEADER_NAME, "b"); + otherHeaders.add(HEADER_NAME, "c"); + headers.add(otherHeaders); + assertEquals("a,b,c", headers.get(HEADER_NAME)); + } + + @Test + public void dontCombineSetCookieHeaders() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(SET_COOKIE, "a"); + final CombinedHttpHeaders otherHeaders = newCombinedHttpHeaders(); + otherHeaders.add(SET_COOKIE, "b"); + otherHeaders.add(SET_COOKIE, "c"); + headers.add(otherHeaders); + assertThat(headers.getAll(SET_COOKIE), hasSize(3)); + } + + @Test + public void dontCombineSetCookieHeadersRegardlessOfCase() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add("Set-Cookie", "a"); + final CombinedHttpHeaders otherHeaders = newCombinedHttpHeaders(); + otherHeaders.add("set-cookie", "b"); + otherHeaders.add("SET-COOKIE", "c"); + headers.add(otherHeaders); + assertThat(headers.getAll(SET_COOKIE), hasSize(3)); + } + + @Test + public void setCombinedHeadersWhenNotEmpty() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, "a"); + final CombinedHttpHeaders otherHeaders = newCombinedHttpHeaders(); + otherHeaders.add(HEADER_NAME, "b"); + otherHeaders.add(HEADER_NAME, "c"); + headers.set(otherHeaders); + assertEquals("b,c", headers.get(HEADER_NAME)); + } + + @Test + public void addUncombinedHeaders() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, "a"); + final DefaultHttpHeaders otherHeaders = new DefaultHttpHeaders(); + otherHeaders.add(HEADER_NAME, "b"); + otherHeaders.add(HEADER_NAME, "c"); + headers.add(otherHeaders); + assertEquals("a,b,c", headers.get(HEADER_NAME)); + } + + @Test + public void setUncombinedHeaders() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, "a"); + final DefaultHttpHeaders otherHeaders = new DefaultHttpHeaders(); + 
otherHeaders.add(HEADER_NAME, "b"); + otherHeaders.add(HEADER_NAME, "c"); + headers.set(otherHeaders); + assertEquals("b,c", headers.get(HEADER_NAME)); + } + + @Test + public void addCharSequencesCsvWithValueContainingComma() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.SIX_QUOTED.subset(4)); + assertTrue(contentEquals(HeaderValue.SIX_QUOTED.subsetAsCsvString(4), headers.get(HEADER_NAME))); + assertEquals(HeaderValue.SIX_QUOTED.subset(4), headers.getAll(HEADER_NAME)); + } + + @Test + public void addCharSequencesCsvWithValueContainingCommas() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.EIGHT.subset(6)); + assertTrue(contentEquals(HeaderValue.EIGHT.subsetAsCsvString(6), headers.get(HEADER_NAME))); + assertEquals(HeaderValue.EIGHT.subset(6), headers.getAll(HEADER_NAME)); + } + + @Test + public void addCharSequencesCsvNullValue() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + final String value = null; + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + headers.add(HEADER_NAME, value); + } + }); + } + + @Test + public void addCharSequencesCsvMultipleTimes() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + for (int i = 0; i < 5; ++i) { + headers.add(HEADER_NAME, "value"); + } + assertTrue(contentEquals("value,value,value,value,value", headers.get(HEADER_NAME))); + } + + @Test + public void addCharSequenceCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + addValues(headers, HeaderValue.ONE, HeaderValue.TWO, HeaderValue.THREE); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void addCharSequenceCsvSingleValue() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + addValues(headers, HeaderValue.ONE); + assertCsvValue(headers, HeaderValue.ONE); + } + + @Test + public void addIterableCsv() { + final 
CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void addIterableCsvWithExistingHeader() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + headers.add(HEADER_NAME, HeaderValue.FIVE.subset(4)); + assertCsvValues(headers, HeaderValue.FIVE); + } + + @Test + public void addIterableCsvSingleValue() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.ONE.asList()); + assertCsvValue(headers, HeaderValue.ONE); + } + + @Test + public void addIterableCsvEmpty() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, Collections.emptyList()); + assertEquals(Collections.singletonList(""), headers.getAll(HEADER_NAME)); + } + + @Test + public void addObjectCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + addObjectValues(headers, HeaderValue.ONE, HeaderValue.TWO, HeaderValue.THREE); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void addObjectsCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void addObjectsIterableCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void addObjectsCsvWithExistingHeader() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + headers.add(HEADER_NAME, HeaderValue.FIVE.subset(4)); + assertCsvValues(headers, HeaderValue.FIVE); + } + + @Test + public void setCharSequenceCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + 
headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void setIterableCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void setObjectObjectsCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + @Test + public void setObjectIterableCsv() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertCsvValues(headers, HeaderValue.THREE); + } + + private static CombinedHttpHeaders newCombinedHttpHeaders() { + return new CombinedHttpHeaders(true); + } + + private static void assertCsvValues(final CombinedHttpHeaders headers, final HeaderValue headerValue) { + assertTrue(contentEquals(headerValue.asCsv(), headers.get(HEADER_NAME))); + assertEquals(headerValue.asList(), headers.getAll(HEADER_NAME)); + } + + private static void assertCsvValue(final CombinedHttpHeaders headers, final HeaderValue headerValue) { + assertTrue(contentEquals(headerValue.toString(), headers.get(HEADER_NAME))); + assertTrue(contentEquals(headerValue.toString(), headers.getAll(HEADER_NAME).get(0))); + } + + private static void addValues(final CombinedHttpHeaders headers, HeaderValue... headerValues) { + for (HeaderValue v: headerValues) { + headers.add(HEADER_NAME, v.toString()); + } + } + + private static void addObjectValues(final CombinedHttpHeaders headers, HeaderValue... 
headerValues) { + for (HeaderValue v: headerValues) { + headers.add(HEADER_NAME, v.toString()); + } + } + + @Test + public void testGetAll() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.set(HEADER_NAME, Arrays.asList("a", "b", "c")); + assertEquals(Arrays.asList("a", "b", "c"), headers.getAll(HEADER_NAME)); + headers.set(HEADER_NAME, Arrays.asList("a,", "b,", "c,")); + assertEquals(Arrays.asList("a,", "b,", "c,"), headers.getAll(HEADER_NAME)); + headers.set(HEADER_NAME, Arrays.asList("a\"", "b\"", "c\"")); + assertEquals(Arrays.asList("a\"", "b\"", "c\""), headers.getAll(HEADER_NAME)); + headers.set(HEADER_NAME, Arrays.asList("\"a\"", "\"b\"", "\"c\"")); + assertEquals(Arrays.asList("a", "b", "c"), headers.getAll(HEADER_NAME)); + headers.set(HEADER_NAME, "a,b,c"); + assertEquals(Collections.singletonList("a,b,c"), headers.getAll(HEADER_NAME)); + headers.set(HEADER_NAME, "\"a,b,c\""); + assertEquals(Collections.singletonList("a,b,c"), headers.getAll(HEADER_NAME)); + } + + @Test + public void getAllDontCombineSetCookie() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(SET_COOKIE, "a"); + headers.add(SET_COOKIE, "b"); + assertThat(headers.getAll(SET_COOKIE), hasSize(2)); + assertEquals(Arrays.asList("a", "b"), headers.getAll(SET_COOKIE)); + } + + @Test + public void owsTrimming() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.set(HEADER_NAME, Arrays.asList("\ta", " ", " b ", "\t \t")); + headers.add(HEADER_NAME, " c, d \t"); + + assertEquals(Arrays.asList("a", "", "b", "", "c, d"), headers.getAll(HEADER_NAME)); + assertEquals("a,,b,,\"c, d\"", headers.get(HEADER_NAME)); + + assertTrue(headers.containsValue(HEADER_NAME, "a", true)); + assertTrue(headers.containsValue(HEADER_NAME, " a ", true)); + assertTrue(headers.containsValue(HEADER_NAME, "a", true)); + assertFalse(headers.containsValue(HEADER_NAME, "a,b", true)); + + assertFalse(headers.containsValue(HEADER_NAME, " c, d ", 
true)); + assertFalse(headers.containsValue(HEADER_NAME, "c, d", true)); + assertTrue(headers.containsValue(HEADER_NAME, " c ", true)); + assertTrue(headers.containsValue(HEADER_NAME, "d", true)); + + assertTrue(headers.containsValue(HEADER_NAME, "\t", true)); + assertTrue(headers.containsValue(HEADER_NAME, "", true)); + + assertFalse(headers.containsValue(HEADER_NAME, "e", true)); + + HttpHeaders copiedHeaders = newCombinedHttpHeaders().add(headers); + assertEquals(Arrays.asList("a", "", "b", "", "c, d"), copiedHeaders.getAll(HEADER_NAME)); + } + + @Test + public void valueIterator() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.set(HEADER_NAME, Arrays.asList("\ta", " ", " b ", "\t \t")); + headers.add(HEADER_NAME, " c, d \t"); + + assertFalse(headers.valueStringIterator("foo").hasNext()); + assertValueIterator(headers.valueStringIterator(HEADER_NAME)); + assertFalse(headers.valueCharSequenceIterator("foo").hasNext()); + assertValueIterator(headers.valueCharSequenceIterator(HEADER_NAME)); + } + + @Test + public void nonCombinableHeaderIterator() { + final CombinedHttpHeaders headers = newCombinedHttpHeaders(); + headers.add(SET_COOKIE, "c"); + headers.add(SET_COOKIE, "b"); + headers.add(SET_COOKIE, "a"); + + final Iterator strItr = headers.valueStringIterator(SET_COOKIE); + assertTrue(strItr.hasNext()); + assertEquals("a", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("b", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("c", strItr.next()); + } + + private static void assertValueIterator(Iterator strItr) { + assertTrue(strItr.hasNext()); + assertEquals("a", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("b", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("", strItr.next()); + assertTrue(strItr.hasNext()); + assertEquals("c, d", strItr.next()); + assertFalse(strItr.hasNext()); + } +} diff --git 
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java new file mode 100644 index 0000000..7e7efdb --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpHeadersTest.java @@ -0,0 +1,344 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.handler.codec.http.HttpHeadersTestUtils.HeaderValue; +import io.netty.util.AsciiString; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Set; + +import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT; +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON; +import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE; +import static io.netty.handler.codec.http.HttpHeaderValues.ZERO; +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static io.netty.util.AsciiString.contentEquals; +import static java.util.Arrays.asList; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DefaultHttpHeadersTest { + private static final CharSequence HEADER_NAME = "testHeader"; + private static final CharSequence ILLEGAL_VALUE = "testHeader\r\nContent-Length:45\r\n\r\n"; + + @Test + public void nullHeaderNameNotAllowed() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new DefaultHttpHeaders().add(null, "foo"); + } + }); + } + + @Test + public void emptyHeaderNameNotAllowed() { + 
assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new DefaultHttpHeaders().add(StringUtil.EMPTY_STRING, "foo"); + } + }); + } + + @Test + public void keysShouldBeCaseInsensitive() { + DefaultHttpHeaders headers = new DefaultHttpHeaders(); + headers.add(of("Name"), of("value1")); + headers.add(of("name"), of("value2")); + headers.add(of("NAME"), of("value3")); + assertEquals(3, headers.size()); + + List values = asList("value1", "value2", "value3"); + + assertEquals(values, headers.getAll(of("NAME"))); + assertEquals(values, headers.getAll(of("name"))); + assertEquals(values, headers.getAll(of("Name"))); + assertEquals(values, headers.getAll(of("nAmE"))); + } + + @Test + public void keysShouldBeCaseInsensitiveInHeadersEquals() { + DefaultHttpHeaders headers1 = new DefaultHttpHeaders(); + headers1.add(of("name1"), asList("value1", "value2", "value3")); + headers1.add(of("nAmE2"), of("value4")); + + DefaultHttpHeaders headers2 = new DefaultHttpHeaders(); + headers2.add(of("naMe1"), asList("value1", "value2", "value3")); + headers2.add(of("NAME2"), of("value4")); + + assertEquals(headers1, headers1); + assertEquals(headers2, headers2); + assertEquals(headers1, headers2); + assertEquals(headers2, headers1); + assertEquals(headers1.hashCode(), headers2.hashCode()); + } + + @Test + public void testStringKeyRetrievedAsAsciiString() { + final HttpHeaders headers = new DefaultHttpHeaders(false); + + // Test adding String key and retrieving it using a AsciiString key + final String connection = "keep-alive"; + headers.add(of("Connection"), connection); + + // Passes + final String value = headers.getAsString(HttpHeaderNames.CONNECTION.toString()); + assertNotNull(value); + assertEquals(connection, value); + + // Passes + final String value2 = headers.getAsString(HttpHeaderNames.CONNECTION); + assertNotNull(value2); + assertEquals(connection, value2); + } + + @Test + public void testAsciiStringKeyRetrievedAsString() { + 
final HttpHeaders headers = new DefaultHttpHeaders(false); + + // Test adding AsciiString key and retrieving it using a String key + final String cacheControl = "no-cache"; + headers.add(HttpHeaderNames.CACHE_CONTROL, cacheControl); + + final String value = headers.getAsString(HttpHeaderNames.CACHE_CONTROL); + assertNotNull(value); + assertEquals(cacheControl, value); + + final String value2 = headers.getAsString(HttpHeaderNames.CACHE_CONTROL.toString()); + assertNotNull(value2); + assertEquals(cacheControl, value2); + } + + @Test + public void testRemoveTransferEncodingIgnoreCase() { + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.TRANSFER_ENCODING, "Chunked"); + assertFalse(message.headers().isEmpty()); + HttpUtil.setTransferEncodingChunked(message, false); + assertTrue(message.headers().isEmpty()); + } + + // Test for https://github.com/netty/netty/issues/1690 + @Test + public void testGetOperations() { + HttpHeaders headers = new DefaultHttpHeaders(); + headers.add(of("Foo"), of("1")); + headers.add(of("Foo"), of("2")); + + assertEquals("1", headers.get(of("Foo"))); + + List values = headers.getAll(of("Foo")); + assertEquals(2, values.size()); + assertEquals("1", values.get(0)); + assertEquals("2", values.get(1)); + } + + @Test + public void testEqualsIgnoreCase() { + assertThat(AsciiString.contentEqualsIgnoreCase(null, null), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase(null, "foo"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("bar", null), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", "fOo"), is(true)); + } + + @Test + public void testSetNullHeaderValueValidate() { + final HttpHeaders headers = new DefaultHttpHeaders(true); + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + headers.set(of("test"), (CharSequence) null); + } + }); + } + + @Test + public void 
testSetNullHeaderValueNotValidate() { + final HttpHeaders headers = new DefaultHttpHeaders(false); + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + headers.set(of("test"), (CharSequence) null); + } + }); + } + + @Test + public void addCharSequences() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void addIterable() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void addObjects() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.add(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void setCharSequences() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void setIterable() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void setObjectObjects() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void setObjectIterable() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + headers.set(HEADER_NAME, HeaderValue.THREE.asList()); + assertDefaultValues(headers, HeaderValue.THREE); + } + + @Test + public void setCharSequenceValidatesValue() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + IllegalArgumentException exception = 
assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + headers.set(HEADER_NAME, ILLEGAL_VALUE); + } + }); + assertTrue(exception.getMessage().contains(HEADER_NAME)); + } + + @Test + public void setIterableValidatesValue() { + final DefaultHttpHeaders headers = newDefaultDefaultHttpHeaders(); + IllegalArgumentException exception = assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + headers.set(HEADER_NAME, Collections.singleton(ILLEGAL_VALUE)); + } + }); + assertTrue(exception.getMessage().contains(HEADER_NAME)); + } + + @Test + public void toStringOnEmptyHeaders() { + assertEquals("DefaultHttpHeaders[]", newDefaultDefaultHttpHeaders().toString()); + } + + @Test + public void toStringOnSingleHeader() { + assertEquals("DefaultHttpHeaders[foo: bar]", newDefaultDefaultHttpHeaders() + .add("foo", "bar") + .toString()); + } + + @Test + public void toStringOnMultipleHeaders() { + assertEquals("DefaultHttpHeaders[foo: bar, baz: qix]", newDefaultDefaultHttpHeaders() + .add("foo", "bar") + .add("baz", "qix") + .toString()); + } + + @Test + public void providesHeaderNamesAsArray() throws Exception { + Set nettyHeaders = new DefaultHttpHeaders() + .add(HttpHeaderNames.CONTENT_LENGTH, 10) + .names(); + + String[] namesArray = nettyHeaders.toArray(EmptyArrays.EMPTY_STRINGS); + assertArrayEquals(namesArray, new String[] { HttpHeaderNames.CONTENT_LENGTH.toString() }); + } + + @Test + public void names() { + HttpHeaders headers = new DefaultHttpHeaders(true) + .add(ACCEPT, APPLICATION_JSON) + .add(CONTENT_LENGTH, ZERO) + .add(CONNECTION, CLOSE); + assertFalse(headers.isEmpty()); + assertEquals(3, headers.size()); + Set names = headers.names(); + assertEquals(3, names.size()); + assertTrue(names.contains(ACCEPT.toString())); + assertTrue(names.contains(CONTENT_LENGTH.toString())); + assertTrue(names.contains(CONNECTION.toString())); + } + + @Test + 
public void testContainsName() { + HttpHeaders headers = new DefaultHttpHeaders(true) + .add(CONTENT_LENGTH, "36"); + assertTrue(headers.contains("Content-Length")); + assertTrue(headers.contains("content-length")); + assertTrue(headers.contains(CONTENT_LENGTH)); + headers.remove(CONTENT_LENGTH); + assertFalse(headers.contains("Content-Length")); + assertFalse(headers.contains("content-length")); + assertFalse(headers.contains(CONTENT_LENGTH)); + + assertFalse(headers.contains("non-existent-name")); + assertFalse(headers.contains(new AsciiString("non-existent-name"))); + } + + private static void assertDefaultValues(final DefaultHttpHeaders headers, final HeaderValue headerValue) { + assertTrue(contentEquals(headerValue.asList().get(0), headers.get(HEADER_NAME))); + List expected = headerValue.asList(); + List actual = headers.getAll(HEADER_NAME); + assertEquals(expected.size(), actual.size()); + Iterator eItr = expected.iterator(); + Iterator aItr = actual.iterator(); + while (eItr.hasNext()) { + assertTrue(contentEquals(eItr.next(), aItr.next())); + } + } + + private static DefaultHttpHeaders newDefaultDefaultHttpHeaders() { + return new DefaultHttpHeaders(true); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java new file mode 100644 index 0000000..9ddb597 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpRequestTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DefaultHttpRequestTest { + + @Test + public void testHeaderRemoval() { + HttpMessage m = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + HttpHeaders h = m.headers(); + + // Insert sample keys. + for (int i = 0; i < 1000; i ++) { + h.set(of(String.valueOf(i)), AsciiString.EMPTY_STRING); + } + + // Remove in reversed order. + for (int i = 999; i >= 0; i --) { + h.remove(of(String.valueOf(i))); + } + + // Check if random access returns nothing. + for (int i = 0; i < 1000; i ++) { + assertNull(h.get(of(String.valueOf(i)))); + } + + // Check if sequential access returns nothing. + assertTrue(h.isEmpty()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java new file mode 100644 index 0000000..a713059 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/DefaultHttpResponseTest.java @@ -0,0 +1,40 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +public class DefaultHttpResponseTest { + + @Test + public void testNotEquals() { + HttpResponse ok = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpResponse notFound = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_FOUND); + assertNotEquals(ok, notFound); + assertNotEquals(ok.hashCode(), notFound.hashCode()); + } + + @Test + public void testEquals() { + HttpResponse ok = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpResponse ok2 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + assertEquals(ok, ok2); + assertEquals(ok.hashCode(), ok2.hashCode()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/EmptyHttpHeadersInitializationTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/EmptyHttpHeadersInitializationTest.java new file mode 100644 index 0000000..22c7d34 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/EmptyHttpHeadersInitializationTest.java @@ -0,0 +1,43 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertNotNull;

/**
 * A test to validate that either order of initialization of the {@link EmptyHttpHeaders#INSTANCE} and
 * {@link HttpHeaders#EMPTY_HEADERS} field results in both fields being non-null.
 *
 * Since this is testing static initialization, the tests might not actually test anything, except
 * when run in isolation.
 */
public class EmptyHttpHeadersInitializationTest {

    // NOTE: the order of field accesses inside each test is the point of the test —
    // each test triggers class initialization in a different order. Do not reorder.

    @Test
    public void testEmptyHttpHeadersFirst() {
        // Touch EmptyHttpHeaders first, then HttpHeaders.
        assertNotNull(EmptyHttpHeaders.INSTANCE);
        assertNotNull(HttpHeaders.EMPTY_HEADERS);
    }

    @Test
    public void testHttpHeadersFirst() {
        // Touch HttpHeaders first, then EmptyHttpHeaders.
        assertNotNull(HttpHeaders.EMPTY_HEADERS);
        assertNotNull(EmptyHttpHeaders.INSTANCE);
    }

}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.stream.ChunkedFile; +import io.netty.handler.stream.ChunkedInput; +import io.netty.handler.stream.ChunkedNioFile; +import io.netty.handler.stream.ChunkedNioStream; +import io.netty.handler.stream.ChunkedStream; +import io.netty.handler.stream.ChunkedWriteHandler; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpChunkedInputTest { + private static final byte[] BYTES = new byte[1024 * 64]; + private static final File TMP; + + static { + for (int i = 0; i < BYTES.length; i++) { + BYTES[i] = (byte) i; + } + + FileOutputStream out = null; + try { + TMP = PlatformDependent.createTempFile("netty-chunk-", ".tmp", null); + TMP.deleteOnExit(); + out = new FileOutputStream(TMP); + out.write(BYTES); + out.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + if (out != null) { + try { + out.close(); + } catch (IOException e) { + // ignore 
+ } + } + } + } + + @Test + public void testChunkedStream() { + check(new HttpChunkedInput(new ChunkedStream(new ByteArrayInputStream(BYTES)))); + } + + @Test + public void testChunkedNioStream() { + check(new HttpChunkedInput(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES))))); + } + + @Test + public void testChunkedFile() throws IOException { + check(new HttpChunkedInput(new ChunkedFile(TMP))); + } + + @Test + public void testChunkedNioFile() throws IOException { + check(new HttpChunkedInput(new ChunkedNioFile(TMP))); + } + + @Test + public void testWrappedReturnNull() throws Exception { + HttpChunkedInput input = new HttpChunkedInput(new ChunkedInput() { + @Override + public boolean isEndOfInput() throws Exception { + return false; + } + + @Override + public void close() throws Exception { + // NOOP + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return null; + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + return null; + } + + @Override + public long length() { + return 0; + } + + @Override + public long progress() { + return 0; + } + }); + assertNull(input.readChunk(ByteBufAllocator.DEFAULT)); + } + + private static void check(ChunkedInput... 
inputs) { + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + for (ChunkedInput input : inputs) { + ch.writeOutbound(input); + } + + assertTrue(ch.finish()); + + int i = 0; + int read = 0; + HttpContent lastHttpContent = null; + for (;;) { + HttpContent httpContent = ch.readOutbound(); + if (httpContent == null) { + break; + } + if (lastHttpContent != null) { + assertTrue(lastHttpContent instanceof DefaultHttpContent, "Chunk must be DefaultHttpContent"); + } + + ByteBuf buffer = httpContent.content(); + while (buffer.isReadable()) { + assertEquals(BYTES[i++], buffer.readByte()); + read++; + if (i == BYTES.length) { + i = 0; + } + } + buffer.release(); + + // Save last chunk + lastHttpContent = httpContent; + } + + assertEquals(BYTES.length * inputs.length, read); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, lastHttpContent, + "Last chunk must be LastHttpContent.EMPTY_LAST_CONTENT"); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java new file mode 100644 index 0000000..d2faab8 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientCodecTest.java @@ -0,0 +1,440 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.PrematureChannelClosureException; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import org.junit.jupiter.api.Test; + +import java.net.InetSocketAddress; +import java.util.concurrent.CountDownLatch; + +import static io.netty.util.ReferenceCountUtil.release; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.not; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpClientCodecTest { + + private static final String EMPTY_RESPONSE = "HTTP/1.0 200 OK\r\nContent-Length: 0\r\n\r\n"; + private static final String RESPONSE = "HTTP/1.0 200 OK\r\n" + "Date: Fri, 31 Dec 1999 23:59:59 GMT\r\n" + + "Content-Type: text/html\r\n" + "Content-Length: 28\r\n" + "\r\n" + + "\r\n"; + private static final String INCOMPLETE_CHUNKED_RESPONSE = 
"HTTP/1.1 200 OK\r\n" + "Content-Type: text/plain\r\n" + + "Transfer-Encoding: chunked\r\n" + "\r\n" + + "5\r\n" + "first\r\n" + "6\r\n" + "second\r\n" + "0\r\n"; + private static final String CHUNKED_RESPONSE = INCOMPLETE_CHUNKED_RESPONSE + "\r\n"; + + @Test + public void testConnectWithResponseContent() { + HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true); + EmbeddedChannel ch = new EmbeddedChannel(codec); + + sendRequestAndReadResponse(ch, HttpMethod.CONNECT, RESPONSE); + ch.finish(); + } + + @Test + public void testFailsNotOnRequestResponseChunked() { + HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true); + EmbeddedChannel ch = new EmbeddedChannel(codec); + + sendRequestAndReadResponse(ch, HttpMethod.GET, CHUNKED_RESPONSE); + ch.finish(); + } + + @Test + public void testFailsOnMissingResponse() { + HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true); + EmbeddedChannel ch = new EmbeddedChannel(codec); + + assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http://localhost/"))); + ByteBuf buffer = ch.readOutbound(); + assertNotNull(buffer); + buffer.release(); + try { + ch.finish(); + fail(); + } catch (CodecException e) { + assertTrue(e instanceof PrematureChannelClosureException); + } + } + + @Test + public void testFailsOnIncompleteChunkedResponse() { + HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true); + EmbeddedChannel ch = new EmbeddedChannel(codec); + + ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost/")); + ByteBuf buffer = ch.readOutbound(); + assertNotNull(buffer); + buffer.release(); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.copiedBuffer(INCOMPLETE_CHUNKED_RESPONSE, CharsetUtil.ISO_8859_1)); + assertThat(ch.readInbound(), instanceOf(HttpResponse.class)); + ((HttpContent) ch.readInbound()).release(); // Chunk 'first' + ((HttpContent) ch.readInbound()).release(); // Chunk 
'second' + assertNull(ch.readInbound()); + + try { + ch.finish(); + fail(); + } catch (CodecException e) { + assertTrue(e instanceof PrematureChannelClosureException); + } + } + + @Test + public void testServerCloseSocketInputProvidesData() throws InterruptedException { + ServerBootstrap sb = new ServerBootstrap(); + Bootstrap cb = new Bootstrap(); + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + final CountDownLatch responseReceivedLatch = new CountDownLatch(1); + try { + sb.group(new NioEventLoopGroup(2)); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + // Don't use the HttpServerCodec, because we don't want to have content-length or anything added. + ch.pipeline().addLast(new HttpRequestDecoder(4096, 8192, 8192, true)); + ch.pipeline().addLast(new HttpObjectAggregator(4096)); + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpRequest msg) { + // This is just a simple demo...don't block in IO + assertTrue(ctx.channel() instanceof SocketChannel); + final SocketChannel sChannel = (SocketChannel) ctx.channel(); + /** + * The point of this test is to not add any content-length or content-encoding headers + * and the client should still handle this. + * See RFC 7230, 3.3.3. 
+ */ + sChannel.writeAndFlush(Unpooled.wrappedBuffer(("HTTP/1.0 200 OK\r\n" + + "Date: Fri, 31 Dec 1999 23:59:59 GMT\r\n" + + "Content-Type: text/html\r\n\r\n").getBytes(CharsetUtil.ISO_8859_1))) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertTrue(future.isSuccess()); + sChannel.writeAndFlush(Unpooled.wrappedBuffer( + "hello half closed!\r\n" + .getBytes(CharsetUtil.ISO_8859_1))) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertTrue(future.isSuccess()); + sChannel.shutdownOutput(); + } + }); + } + }); + } + }); + serverChannelLatch.countDown(); + } + }); + + cb.group(new NioEventLoopGroup(1)); + cb.channel(NioSocketChannel.class); + cb.option(ChannelOption.ALLOW_HALF_CLOSURE, true); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new HttpClientCodec(4096, 8192, 8192, true, true)); + ch.pipeline().addLast(new HttpObjectAggregator(4096)); + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) { + responseReceivedLatch.countDown(); + } + }); + } + }); + + Channel serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); + + ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port)); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + Channel clientChannel = ccf.channel(); + assertTrue(serverChannelLatch.await(5, SECONDS)); + clientChannel.writeAndFlush(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + assertTrue(responseReceivedLatch.await(5, SECONDS)); + } finally { + sb.config().group().shutdownGracefully(); + sb.config().childGroup().shutdownGracefully(); + 
cb.config().group().shutdownGracefully(); + } + } + + @Test + public void testContinueParsingAfterConnect() throws Exception { + testAfterConnect(true); + } + + @Test + public void testPassThroughAfterConnect() throws Exception { + testAfterConnect(false); + } + + private static void testAfterConnect(final boolean parseAfterConnect) throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec(4096, 8192, 8192, true, true, parseAfterConnect)); + + Consumer connectResponseConsumer = new Consumer(); + sendRequestAndReadResponse(ch, HttpMethod.CONNECT, EMPTY_RESPONSE, connectResponseConsumer); + assertTrue(connectResponseConsumer.getReceivedCount() > 0, "No connect response messages received."); + Consumer responseConsumer = new Consumer() { + @Override + void accept(Object object) { + if (parseAfterConnect) { + assertThat("Unexpected response message type.", object, instanceOf(HttpObject.class)); + } else { + assertThat("Unexpected response message type.", object, not(instanceOf(HttpObject.class))); + } + } + }; + sendRequestAndReadResponse(ch, HttpMethod.GET, RESPONSE, responseConsumer); + assertTrue(responseConsumer.getReceivedCount() > 0, "No response messages received."); + assertFalse(ch.finish(), "Channel finish failed."); + } + + private static void sendRequestAndReadResponse(EmbeddedChannel ch, HttpMethod httpMethod, String response) { + sendRequestAndReadResponse(ch, httpMethod, response, new Consumer()); + } + + private static void sendRequestAndReadResponse(EmbeddedChannel ch, HttpMethod httpMethod, String response, + Consumer responseConsumer) { + assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, httpMethod, "http://localhost/")), + "Channel outbound write failed."); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.ISO_8859_1)), + "Channel inbound write failed."); + + for (;;) { + Object msg = ch.readOutbound(); + if (msg == null) { + break; + } + release(msg); + } + for (;;) { + 
Object msg = ch.readInbound(); + if (msg == null) { + break; + } + responseConsumer.onResponse(msg); + release(msg); + } + } + + private static class Consumer { + + private int receivedCount; + + final void onResponse(Object object) { + receivedCount++; + accept(object); + } + + void accept(Object object) { + // Default noop. + } + + int getReceivedCount() { + return receivedCount; + } + } + + @Test + public void testDecodesFinalResponseAfterSwitchingProtocols() { + String SWITCHING_PROTOCOLS_RESPONSE = "HTTP/1.1 101 Switching Protocols\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: TLS/1.2, HTTP/1.1\r\n\r\n"; + + HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true); + EmbeddedChannel ch = new EmbeddedChannel(codec, new HttpObjectAggregator(1024)); + + HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost/"); + request.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE); + request.headers().set(HttpHeaderNames.UPGRADE, "TLS/1.2"); + assertTrue(ch.writeOutbound(request), "Channel outbound write failed."); + + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(SWITCHING_PROTOCOLS_RESPONSE, CharsetUtil.ISO_8859_1)), + "Channel inbound write failed."); + Object switchingProtocolsResponse = ch.readInbound(); + assertNotNull(switchingProtocolsResponse, "No response received"); + assertThat("Response was not decoded", switchingProtocolsResponse, instanceOf(FullHttpResponse.class)); + ((FullHttpResponse) switchingProtocolsResponse).release(); + + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(RESPONSE, CharsetUtil.ISO_8859_1)), + "Channel inbound write failed"); + Object finalResponse = ch.readInbound(); + assertNotNull(finalResponse, "No response received"); + assertThat("Response was not decoded", finalResponse, instanceOf(FullHttpResponse.class)); + ((FullHttpResponse) finalResponse).release(); + assertTrue(ch.finishAndReleaseAll(), "Channel finish failed"); + } + + @Test + public void 
testWebSocket00Response() { + byte[] data = ("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Origin: http://localhost:8080\r\n" + + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" + + "\r\n" + + "1234567812345678").getBytes(); + EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec()); + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data))); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.SWITCHING_PROTOCOLS)); + HttpContent content = ch.readInbound(); + assertThat(content.content().readableBytes(), is(16)); + content.release(); + + assertThat(ch.finish(), is(false)); + + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testWebDavResponse() { + byte[] data = ("HTTP/1.1 102 Processing\r\n" + + "Status-URI: Status-URI:http://status.com; 404\r\n" + + "\r\n" + + "1234567812345678").getBytes(); + EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec()); + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data))); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.PROCESSING)); + HttpContent content = ch.readInbound(); + // HTTP 102 is not allowed to have content. 
+ assertThat(content.content().readableBytes(), is(0)); + content.release(); + + assertThat(ch.finish(), is(false)); + } + + @Test + public void testInformationalResponseKeepsPairsInSync() { + byte[] data = ("HTTP/1.1 102 Processing\r\n" + + "Status-URI: Status-URI:http://status.com; 404\r\n" + + "\r\n").getBytes(); + byte[] data2 = ("HTTP/1.1 200 OK\r\n" + + "Content-Length: 8\r\n" + + "\r\n" + + "12345678").getBytes(); + EmbeddedChannel ch = new EmbeddedChannel(new HttpClientCodec()); + assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.HEAD, "/"))); + ByteBuf buffer = ch.readOutbound(); + buffer.release(); + assertNull(ch.readOutbound()); + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data))); + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.PROCESSING)); + HttpContent content = ch.readInbound(); + // HTTP 102 is not allowed to have content. + assertThat(content.content().readableBytes(), is(0)); + assertThat(content, instanceOf(LastHttpContent.class)); + content.release(); + + assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); + buffer = ch.readOutbound(); + buffer.release(); + assertNull(ch.readOutbound()); + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(data2))); + + res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + content = ch.readInbound(); + // HTTP 200 has content. 
+ assertThat(content.content().readableBytes(), is(8)); + assertThat(content, instanceOf(LastHttpContent.class)); + content.release(); + + assertThat(ch.finish(), is(false)); + } + + @Test + public void testMultipleResponses() { + String response = "HTTP/1.1 200 OK\r\n" + + "Content-Length: 0\r\n\r\n"; + + HttpClientCodec codec = new HttpClientCodec(4096, 8192, 8192, true); + EmbeddedChannel ch = new EmbeddedChannel(codec, new HttpObjectAggregator(1024)); + + HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost/"); + assertTrue(ch.writeOutbound(request)); + + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.UTF_8))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(response, CharsetUtil.UTF_8))); + FullHttpResponse resp = ch.readInbound(); + assertTrue(resp.decoderResult().isSuccess()); + resp.release(); + + resp = ch.readInbound(); + assertTrue(resp.decoderResult().isSuccess()); + resp.release(); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testWriteThroughAfterUpgrade() { + HttpClientCodec codec = new HttpClientCodec(); + EmbeddedChannel ch = new EmbeddedChannel(codec); + codec.prepareUpgradeFrom(null); + + ByteBuf buffer = ch.alloc().buffer(); + assertThat(buffer.refCnt(), is(1)); + assertTrue(ch.writeOutbound(buffer)); + // buffer should pass through unchanged + assertThat(ch.readOutbound(), sameInstance(buffer)); + assertThat(buffer.refCnt(), is(1)); + + buffer.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientUpgradeHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientUpgradeHandlerTest.java new file mode 100644 index 0000000..a119118 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpClientUpgradeHandlerTest.java @@ -0,0 +1,199 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you 
under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpClientUpgradeHandlerTest { + + private static final class FakeSourceCodec implements HttpClientUpgradeHandler.SourceCodec { + @Override + public void prepareUpgradeFrom(ChannelHandlerContext ctx) { + } + + @Override + public void upgradeFrom(ChannelHandlerContext ctx) { + } + } + + private static final class FakeUpgradeCodec implements HttpClientUpgradeHandler.UpgradeCodec { + @Override + public CharSequence protocol() { + return "fancyhttp"; + } + + @Override + public Collection setUpgradeHeaders(ChannelHandlerContext ctx, HttpRequest upgradeRequest) { + return Collections.emptyList(); + } + + @Override + public void upgradeTo(ChannelHandlerContext ctx, FullHttpResponse upgradeResponse) throws Exception { + } + } + + private static final class UserEventCatcher extends ChannelInboundHandlerAdapter { + private Object evt; + + public Object 
getUserEvent() { + return evt; + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + this.evt = evt; + } + } + + @Test + public void testSuccessfulUpgrade() { + HttpClientUpgradeHandler.SourceCodec sourceCodec = new FakeSourceCodec(); + HttpClientUpgradeHandler.UpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + HttpClientUpgradeHandler handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + UserEventCatcher catcher = new UserEventCatcher(); + EmbeddedChannel channel = new EmbeddedChannel(catcher); + channel.pipeline().addFirst("upgrade", handler); + + assertTrue( + channel.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "netty.io"))); + FullHttpRequest request = channel.readOutbound(); + + assertEquals(2, request.headers().size()); + assertTrue(request.headers().contains(HttpHeaderNames.UPGRADE, "fancyhttp", false)); + assertTrue(request.headers().contains("connection", "upgrade", false)); + assertTrue(request.release()); + assertEquals(HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_ISSUED, catcher.getUserEvent()); + + HttpResponse upgradeResponse = + new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); + + upgradeResponse.headers().add(HttpHeaderNames.UPGRADE, "fancyhttp"); + assertFalse(channel.writeInbound(upgradeResponse)); + assertFalse(channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + assertEquals(HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_SUCCESSFUL, catcher.getUserEvent()); + assertNull(channel.pipeline().get("upgrade")); + + assertTrue(channel.writeInbound(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))); + FullHttpResponse response = channel.readInbound(); + assertEquals(HttpResponseStatus.OK, response.status()); + assertTrue(response.release()); + assertFalse(channel.finish()); + } + + @Test + public void testUpgradeRejected() { + HttpClientUpgradeHandler.SourceCodec sourceCodec = new 
FakeSourceCodec(); + HttpClientUpgradeHandler.UpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + HttpClientUpgradeHandler handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + UserEventCatcher catcher = new UserEventCatcher(); + EmbeddedChannel channel = new EmbeddedChannel(catcher); + channel.pipeline().addFirst("upgrade", handler); + + assertTrue( + channel.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "netty.io"))); + FullHttpRequest request = channel.readOutbound(); + + assertEquals(2, request.headers().size()); + assertTrue(request.headers().contains(HttpHeaderNames.UPGRADE, "fancyhttp", false)); + assertTrue(request.headers().contains("connection", "upgrade", false)); + assertTrue(request.release()); + assertEquals(HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_ISSUED, catcher.getUserEvent()); + + HttpResponse upgradeResponse = + new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); + upgradeResponse.headers().add(HttpHeaderNames.UPGRADE, "fancyhttp"); + assertTrue(channel.writeInbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))); + assertTrue(channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + assertEquals(HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_REJECTED, catcher.getUserEvent()); + assertNull(channel.pipeline().get("upgrade")); + + HttpResponse response = channel.readInbound(); + assertEquals(HttpResponseStatus.OK, response.status()); + + LastHttpContent last = channel.readInbound(); + assertEquals(LastHttpContent.EMPTY_LAST_CONTENT, last); + assertFalse(last.release()); + assertFalse(channel.finish()); + } + + @Test + public void testEarlyBailout() { + HttpClientUpgradeHandler.SourceCodec sourceCodec = new FakeSourceCodec(); + HttpClientUpgradeHandler.UpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + HttpClientUpgradeHandler handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + UserEventCatcher catcher 
= new UserEventCatcher(); + EmbeddedChannel channel = new EmbeddedChannel(catcher); + channel.pipeline().addFirst("upgrade", handler); + + assertTrue( + channel.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "netty.io"))); + FullHttpRequest request = channel.readOutbound(); + + assertEquals(2, request.headers().size()); + assertTrue(request.headers().contains(HttpHeaderNames.UPGRADE, "fancyhttp", false)); + assertTrue(request.headers().contains("connection", "upgrade", false)); + assertTrue(request.release()); + assertEquals(HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_ISSUED, catcher.getUserEvent()); + + HttpResponse upgradeResponse = + new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS); + upgradeResponse.headers().add(HttpHeaderNames.UPGRADE, "fancyhttp"); + assertTrue(channel.writeInbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))); + + assertEquals(HttpClientUpgradeHandler.UpgradeEvent.UPGRADE_REJECTED, catcher.getUserEvent()); + assertNull(channel.pipeline().get("upgrade")); + + HttpResponse response = channel.readInbound(); + assertEquals(HttpResponseStatus.OK, response.status()); + assertFalse(channel.finish()); + } + + @Test + public void dontStripConnectionHeaders() { + HttpClientUpgradeHandler.SourceCodec sourceCodec = new FakeSourceCodec(); + HttpClientUpgradeHandler.UpgradeCodec upgradeCodec = new FakeUpgradeCodec(); + HttpClientUpgradeHandler handler = new HttpClientUpgradeHandler(sourceCodec, upgradeCodec, 1024); + UserEventCatcher catcher = new UserEventCatcher(); + EmbeddedChannel channel = new EmbeddedChannel(catcher); + channel.pipeline().addFirst("upgrade", handler); + + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "netty.io"); + request.headers().add("connection", "extra"); + request.headers().add("extra", "value"); + assertTrue(channel.writeOutbound(request)); + FullHttpRequest readRequest = 
channel.readOutbound(); + + List connectionHeaders = readRequest.headers().getAll("connection"); + assertTrue(connectionHeaders.contains("extra")); + assertTrue(readRequest.release()); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorOptionsTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorOptionsTest.java new file mode 100644 index 0000000..dfc1d64 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorOptionsTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.StandardCompressionOptions; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +@EnabledIf("isBrotiAvailable") +class HttpContentCompressorOptionsTest { + + static boolean isBrotiAvailable() { + return Brotli.isAvailable(); + } + + @Test + void testGetBrTargetContentEncoding() { + HttpContentCompressor compressor = new HttpContentCompressor( + StandardCompressionOptions.gzip(), + StandardCompressionOptions.deflate(), + StandardCompressionOptions.brotli(), + StandardCompressionOptions.zstd(), + StandardCompressionOptions.snappy() + ); + + String[] tests = { + // Accept-Encoding -> Content-Encoding + "", null, + "*", "br", + "*;q=0.0", null, + "br", "br", + "compress, br;q=0.5", "br", + "br; q=0.5, identity", "br", + "br; q=0, deflate", "br", + }; + for (int i = 0; i < tests.length; i += 2) { + String acceptEncoding = tests[i]; + String contentEncoding = tests[i + 1]; + String targetEncoding = compressor.determineEncoding(acceptEncoding); + assertEquals(contentEncoding, targetEncoding); + } + } + + @Test + void testGetZstdTargetContentEncoding() { + HttpContentCompressor compressor = new HttpContentCompressor( + StandardCompressionOptions.gzip(), + StandardCompressionOptions.deflate(), + StandardCompressionOptions.brotli(), + StandardCompressionOptions.zstd(), + StandardCompressionOptions.snappy() + ); + + String[] tests = { + // Accept-Encoding -> Content-Encoding + "", 
null, + "*;q=0.0", null, + "zstd", "zstd", + "compress, zstd;q=0.5", "zstd", + "zstd; q=0.5, identity", "zstd", + "zstd; q=0, deflate", "zstd", + }; + for (int i = 0; i < tests.length; i += 2) { + String acceptEncoding = tests[i]; + String contentEncoding = tests[i + 1]; + String targetEncoding = compressor.determineEncoding(acceptEncoding); + assertEquals(contentEncoding, targetEncoding); + } + } + + @Test + void testGetSnappyTargetContentEncoding() { + HttpContentCompressor compressor = new HttpContentCompressor( + StandardCompressionOptions.gzip(), + StandardCompressionOptions.deflate(), + StandardCompressionOptions.brotli(), + StandardCompressionOptions.zstd(), + StandardCompressionOptions.snappy() + ); + + String[] tests = { + // Accept-Encoding -> Content-Encoding + "", null, + "*;q=0.0", null, + "snappy", "snappy", + "compress, snappy;q=0.5", "snappy", + "snappy; q=0.5, identity", "snappy", + "snappy; q=0, deflate", "snappy", + }; + for (int i = 0; i < tests.length; i += 2) { + String acceptEncoding = tests[i]; + String contentEncoding = tests[i + 1]; + String targetEncoding = compressor.determineEncoding(acceptEncoding); + assertEquals(contentEncoding, targetEncoding); + } + } + + @Test + void testAcceptEncodingHttpRequest() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor(null)); + ch.writeInbound(newRequest()); + FullHttpRequest fullHttpRequest = ch.readInbound(); + fullHttpRequest.release(); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + + assertTrue(ch.close().isSuccess()); + } + + private static void assertEncodedResponse(EmbeddedChannel ch) { + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(HttpResponse.class))); + + assertEncodedResponse((HttpResponse) o); + } + + private static void assertEncodedResponse(HttpResponse res) { + 
assertThat(res, is(not(instanceOf(HttpContent.class)))); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue())); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("br")); + } + + private static FullHttpRequest newRequest() { + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + req.headers().set(HttpHeaderNames.ACCEPT_ENCODING, "br, zstd, snappy, gzip, deflate"); + return req; + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java new file mode 100644 index 0000000..c71f689 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentCompressorTest.java @@ -0,0 +1,1105 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import com.aayushatharva.brotli4j.decoder.DecoderJNI; +import com.aayushatharva.brotli4j.decoder.DirectDecompress; +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.CompressionOptions; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; + +import java.util.UUID; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; + +import java.nio.charset.StandardCharsets; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpContentCompressorTest { + + @Test + public void testGetTargetContentEncoding() throws Exception { + HttpContentCompressor compressor = new HttpContentCompressor(); + + String[] tests = { + // Accept-Encoding -> Content-Encoding + "", null, + "*", "gzip", + "*;q=0.0", null, + "gzip", "gzip", + "compress, gzip;q=0.5", "gzip", + "gzip; q=0.5, identity", "gzip", + "gzip ; q=0.1", "gzip", + "gzip; q=0, deflate", "deflate", + " deflate ; q=0 , *;q=0.5", "gzip", + }; + for (int i = 0; i < tests.length; i += 2) { + String acceptEncoding = tests[i]; + String contentEncoding = tests[i + 1]; + ZlibWrapper targetWrapper = compressor.determineWrapper(acceptEncoding); + String targetEncoding = null; + if (targetWrapper != null) { + switch (targetWrapper) { + case GZIP: + targetEncoding = "gzip"; + break; + case ZLIB: + targetEncoding = "deflate"; + break; + default: + fail(); + } + } + assertEquals(contentEncoding, targetEncoding); + } + } + + @Test + public void testDetermineEncoding() throws Exception { + HttpContentCompressor compressor = new HttpContentCompressor((CompressionOptions[]) null); + + String[] tests = { + // Accept-Encoding -> Content-Encoding + "", null, + ",", null, + "identity", null, + "unknown", null, + "*", "br", + "br", "br", + "br ; q=0.1", "br", + "unknown, br", "br", + "br, gzip", "br", + "gzip, br", "br", + "identity, br", "br", + "gzip", "gzip", + "gzip ; q=0.1", "gzip", + }; + for (int i = 0; i < tests.length; i += 2) { + final String acceptEncoding = tests[i]; + final String expectedEncoding = tests[i + 1]; + final String targetEncoding = compressor.determineEncoding(acceptEncoding); + 
assertEquals(expectedEncoding, targetEncoding); + } + } + + @Test + public void testSplitContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + ch.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)); + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o, w", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("orld", CharsetUtil.US_ASCII))); + + assertEncodedResponse(ch); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("1f8b0800000000000000f248cdc901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("cad7512807000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("ca2fca4901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("0300c2a99ae70c000000")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testChunkedContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII))); + 
ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o, w", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("orld", CharsetUtil.US_ASCII))); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("1f8b0800000000000000f248cdc901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("cad7512807000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("ca2fca4901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("0300c2a99ae70c000000")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testChunkedContentWithAssembledResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + HttpResponse res = new AssembledHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII)); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + assertAssembledEncodedResponse(ch); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o, w", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("orld", CharsetUtil.US_ASCII))); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("1f8b0800000000000000f248cdc901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), 
is("cad7512807000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("ca2fca4901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("0300c2a99ae70c000000")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testChunkedContentWithAssembledResponseIdentityEncoding() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + HttpResponse res = new AssembledHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII)); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o, w", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("orld", CharsetUtil.US_ASCII))); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(StandardCharsets.UTF_8), is("Hell")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(StandardCharsets.UTF_8), is("o, w")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(StandardCharsets.UTF_8), is("orld")); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testContentWithAssembledResponseIdentityEncodingHttp10() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + 
ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/")); + + HttpResponse res = new AssembledHttpResponse(HttpVersion.HTTP_1_0, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII)); + ch.writeOutbound(res); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o, w", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("orld", CharsetUtil.US_ASCII))); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(StandardCharsets.UTF_8), is("Hell")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(StandardCharsets.UTF_8), is("o, w")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(StandardCharsets.UTF_8), is("orld")); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testChunkedContentWithTrailingHeader() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o, w", CharsetUtil.US_ASCII))); + LastHttpContent content = new DefaultLastHttpContent(Unpooled.copiedBuffer("orld", CharsetUtil.US_ASCII)); + content.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(content); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("1f8b0800000000000000f248cdc901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + 
assertThat(ByteBufUtil.hexDump(chunk.content()), is("cad7512807000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("ca2fca4901000000ffff")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("0300c2a99ae70c000000")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + assertEquals("Netty", ((LastHttpContent) chunk).trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, chunk.decoderResult()); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testFullContentWithContentLength() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + FullHttpResponse fullRes = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hello, World", CharsetUtil.US_ASCII)); + fullRes.headers().set(HttpHeaderNames.CONTENT_LENGTH, fullRes.content().readableBytes()); + ch.writeOutbound(fullRes); + + HttpResponse res = ch.readOutbound(); + assertThat(res, is(not(instanceOf(HttpContent.class)))); + + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("gzip")); + + long contentLengthHeaderValue = HttpUtil.getContentLength(res); + long observedLength = 0; + + HttpContent c = ch.readOutbound(); + observedLength += c.content().readableBytes(); + assertThat(ByteBufUtil.hexDump(c.content()), is("1f8b0800000000000000f248cdc9c9d75108cf2fca4901000000ffff")); + c.release(); + + c = ch.readOutbound(); + observedLength += c.content().readableBytes(); + assertThat(ByteBufUtil.hexDump(c.content()), is("0300c6865b260c000000")); + 
c.release(); + + LastHttpContent last = ch.readOutbound(); + assertThat(last.content().readableBytes(), is(0)); + last.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + assertEquals(contentLengthHeaderValue, observedLength); + } + + @Test + public void testFullContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hello, World", CharsetUtil.US_ASCII)); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + HttpContent c = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(c.content()), is("1f8b0800000000000000f248cdc9c9d75108cf2fca4901000000ffff")); + c.release(); + + c = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(c.content()), is("0300c6865b260c000000")); + c.release(); + + LastHttpContent last = ch.readOutbound(); + assertThat(last.content().readableBytes(), is(0)); + last.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testExecutorPreserveOrdering() throws Exception { + final EventLoopGroup compressorGroup = new DefaultEventLoopGroup(1); + EventLoopGroup localGroup = new DefaultEventLoopGroup(1); + Channel server = null; + Channel client = null; + try { + ServerBootstrap bootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(localGroup) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) throws Exception { + ch.pipeline() + .addLast(new HttpServerCodec()) + .addLast(new HttpObjectAggregator(1024)) + .addLast(compressorGroup, new HttpContentCompressor()) + .addLast(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) + throws Exception { + super.write(ctx, msg, promise); + } + }) + .addLast(new ChannelInboundHandlerAdapter() { + @Override 
+ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof FullHttpRequest) { + FullHttpResponse res = + new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hello, World", CharsetUtil.US_ASCII)); + ctx.writeAndFlush(res); + ReferenceCountUtil.release(msg); + return; + } + super.channelRead(ctx, msg); + } + }); + } + }); + + LocalAddress address = new LocalAddress(UUID.randomUUID().toString()); + server = bootstrap.bind(address).sync().channel(); + + final BlockingQueue responses = new LinkedBlockingQueue(); + + client = new Bootstrap() + .channel(LocalChannel.class) + .remoteAddress(address) + .group(localGroup) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(LocalChannel ch) throws Exception { + ch.pipeline().addLast(new HttpClientCodec()).addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof HttpObject) { + responses.put((HttpObject) msg); + return; + } + super.channelRead(ctx, msg); + } + }); + } + }).connect().sync().channel(); + + client.writeAndFlush(newRequest()).sync(); + + assertEncodedResponse((HttpResponse) responses.poll(1, TimeUnit.SECONDS)); + HttpContent c = (HttpContent) responses.poll(1, TimeUnit.SECONDS); + assertNotNull(c); + assertThat(ByteBufUtil.hexDump(c.content()), + is("1f8b0800000000000000f248cdc9c9d75108cf2fca4901000000ffff")); + c.release(); + + c = (HttpContent) responses.poll(1, TimeUnit.SECONDS); + assertNotNull(c); + assertThat(ByteBufUtil.hexDump(c.content()), is("0300c6865b260c000000")); + c.release(); + + LastHttpContent last = (LastHttpContent) responses.poll(1, TimeUnit.SECONDS); + assertNotNull(last); + assertThat(last.content().readableBytes(), is(0)); + last.release(); + + assertNull(responses.poll(1, TimeUnit.SECONDS)); + } finally { + if (client != null) { + client.close().sync(); + } + if (server 
!= null) { + server.close().sync(); + } + compressorGroup.shutdownGracefully(); + localGroup.shutdownGracefully(); + } + } + + /** + * If the length of the content is unknown, {@link HttpContentEncoder} should not skip encoding the content + * even if the actual length is turned out to be 0. + */ + @Test + public void testEmptySplitContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + ch.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)); + assertEncodedResponse(ch); + + ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT); + HttpContent chunk = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(chunk.content()), is("1f8b080000000000000003000000000000000000")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + /** + * If the length of the content is 0 for sure, {@link HttpContentEncoder} should skip encoding. + */ + @Test + public void testEmptyFullContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. 
+ assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + res.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testEmptyFullContentWithTrailer() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(newRequest()); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + res.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void test100Continue() throws Exception { + FullHttpRequest request = newRequest(); + HttpUtil.set100ContinueExpected(request, true); + + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(request); + + FullHttpResponse continueResponse = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER); + + ch.writeOutbound(continueResponse); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + res.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + 
assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertSame(continueResponse, res); + res.release(); + + o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testMultiple1xxInformationalResponse() throws Exception { + FullHttpRequest request = newRequest(); + HttpUtil.set100ContinueExpected(request, true); + + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(request); + + FullHttpResponse continueResponse = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE, Unpooled.EMPTY_BUFFER); + ch.writeOutbound(continueResponse); + + FullHttpResponse earlyHintsResponse = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.EARLY_HINTS, Unpooled.EMPTY_BUFFER); + earlyHintsResponse.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(earlyHintsResponse); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + res.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertSame(continueResponse, res); + res.release(); + + o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + 
assertSame(earlyHintsResponse, res); + res.release(); + + o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); + assertThat(ch.readOutbound(), is(nullValue())); + + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void test103EarlyHintsResponse() throws Exception { + FullHttpRequest request = newRequest(); + + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(request); + + FullHttpResponse earlyHintsResponse = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.EARLY_HINTS, Unpooled.EMPTY_BUFFER); + earlyHintsResponse.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(earlyHintsResponse); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + res.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertSame(earlyHintsResponse, res); + res.release(); + + o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. 
+ assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); + assertThat(ch.readOutbound(), is(nullValue())); + + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testTooManyResponses() throws Exception { + FullHttpRequest request = newRequest(); + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + ch.writeInbound(request); + + ch.writeOutbound(new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER)); + + try { + ch.writeOutbound(new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER)); + fail(); + } catch (EncoderException e) { + assertTrue(e.getCause() instanceof IllegalStateException); + } + assertTrue(ch.finish()); + for (;;) { + Object message = ch.readOutbound(); + if (message == null) { + break; + } + ReferenceCountUtil.release(message); + } + for (;;) { + Object message = ch.readInbound(); + if (message == null) { + break; + } + ReferenceCountUtil.release(message); + } + } + + @Test + public void testIdentity() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + assertTrue(ch.writeInbound(newRequest())); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hello, World", CharsetUtil.US_ASCII)); + int len = res.content().readableBytes(); + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, len); + res.headers().set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.IDENTITY); + assertTrue(ch.writeOutbound(res)); + + FullHttpResponse response = ch.readOutbound(); + assertEquals(String.valueOf(len), response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + 
assertEquals(HttpHeaderValues.IDENTITY.toString(), response.headers().get(HttpHeaderNames.CONTENT_ENCODING)); + assertEquals("Hello, World", response.content().toString(CharsetUtil.US_ASCII)); + response.release(); + + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testCustomEncoding() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + assertTrue(ch.writeInbound(newRequest())); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hello, World", CharsetUtil.US_ASCII)); + int len = res.content().readableBytes(); + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, len); + res.headers().set(HttpHeaderNames.CONTENT_ENCODING, "ascii"); + assertTrue(ch.writeOutbound(res)); + + FullHttpResponse response = ch.readOutbound(); + assertEquals(String.valueOf(len), response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + assertEquals("ascii", response.headers().get(HttpHeaderNames.CONTENT_ENCODING)); + assertEquals("Hello, World", response.content().toString(CharsetUtil.US_ASCII)); + response.release(); + + assertTrue(ch.finishAndReleaseAll()); + } + + static boolean isBrotliAvailable() { + return Brotli.isAvailable(); + } + + @Test + @EnabledIf("isBrotliAvailable") + public void testBrotliFullHttpResponse() throws Exception { + HttpContentCompressor compressor = new HttpContentCompressor((CompressionOptions[]) null); + EmbeddedChannel ch = new EmbeddedChannel(compressor); + assertTrue(ch.writeInbound(newBrotliRequest())); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Hello Hello Hello Hello Hello", CharsetUtil.US_ASCII)); + int len = res.content().readableBytes(); + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, len); + res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain"); + assertTrue(ch.writeOutbound(res)); + + DefaultHttpResponse response = 
ch.readOutbound(); + assertEquals(String.valueOf(19), response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + assertEquals("text/plain", response.headers().get(HttpHeaderNames.CONTENT_TYPE)); + assertEquals("br", response.headers().get(HttpHeaderNames.CONTENT_ENCODING)); + + CompositeByteBuf contentBuf = Unpooled.compositeBuffer(); + HttpContent content; + while ((content = ch.readOutbound()) != null) { + if (content.content().isReadable()) { + contentBuf.addComponent(true, content.content()); + } else { + content.content().release(); + } + } + + DirectDecompress decompressResult = DirectDecompress.decompress(ByteBufUtil.getBytes(contentBuf)); + assertEquals(DecoderJNI.Status.DONE, decompressResult.getResultStatus()); + assertEquals("Hello Hello Hello Hello Hello", + new String(decompressResult.getDecompressedData(), CharsetUtil.US_ASCII)); + + assertTrue(ch.finishAndReleaseAll()); + contentBuf.release(); + } + + @Test + @EnabledIf("isBrotliAvailable") + public void testBrotliChunkedContent() throws Exception { + HttpContentCompressor compressor = new HttpContentCompressor((CompressionOptions[]) null); + EmbeddedChannel ch = new EmbeddedChannel(compressor); + assertTrue(ch.writeInbound(newBrotliRequest())); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + res.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/plain"); + ch.writeOutbound(res); + + HttpResponse outboundRes = ch.readOutbound(); + assertThat(outboundRes, is(not(instanceOf(HttpContent.class)))); + assertThat(outboundRes.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(outboundRes.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue())); + assertThat(outboundRes.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("br")); + assertThat(outboundRes.headers().get(HttpHeaderNames.CONTENT_TYPE), is("text/plain")); + + ch.writeOutbound(new 
DefaultHttpContent(Unpooled.copiedBuffer("Hell", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("o world. Hello w", CharsetUtil.US_ASCII))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.copiedBuffer("orld.", CharsetUtil.US_ASCII))); + + CompositeByteBuf contentBuf = Unpooled.compositeBuffer(); + HttpContent content; + while ((content = ch.readOutbound()) != null) { + if (content.content().isReadable()) { + contentBuf.addComponent(true, content.content()); + } else { + content.content().release(); + } + } + + DirectDecompress decompressResult = DirectDecompress.decompress(ByteBufUtil.getBytes(contentBuf)); + assertEquals(DecoderJNI.Status.DONE, decompressResult.getResultStatus()); + assertEquals("Hello world. Hello world.", + new String(decompressResult.getDecompressedData(), CharsetUtil.US_ASCII)); + + assertTrue(ch.finishAndReleaseAll()); + contentBuf.release(); + } + + @Test + public void testCompressThresholdAllCompress() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + assertTrue(ch.writeInbound(newRequest())); + + FullHttpResponse res1023 = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.wrappedBuffer(new byte[1023])); + assertTrue(ch.writeOutbound(res1023)); + DefaultHttpResponse response1023 = ch.readOutbound(); + assertThat(response1023.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("gzip")); + ch.releaseOutbound(); + + assertTrue(ch.writeInbound(newRequest())); + FullHttpResponse res1024 = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.wrappedBuffer(new byte[1024])); + assertTrue(ch.writeOutbound(res1024)); + DefaultHttpResponse response1024 = ch.readOutbound(); + assertThat(response1024.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("gzip")); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testCompressThresholdNotCompress() throws Exception { + 
EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor(6, 15, 8, 1024)); + assertTrue(ch.writeInbound(newRequest())); + + FullHttpResponse res1023 = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.wrappedBuffer(new byte[1023])); + assertTrue(ch.writeOutbound(res1023)); + DefaultHttpResponse response1023 = ch.readOutbound(); + assertFalse(response1023.headers().contains(HttpHeaderNames.CONTENT_ENCODING)); + ch.releaseOutbound(); + + assertTrue(ch.writeInbound(newRequest())); + FullHttpResponse res1024 = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.wrappedBuffer(new byte[1024])); + assertTrue(ch.writeOutbound(res1024)); + DefaultHttpResponse response1024 = ch.readOutbound(); + assertThat(response1024.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("gzip")); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testMultipleAcceptEncodingHeaders() { + FullHttpRequest request = newRequest(); + request.headers().set(HttpHeaderNames.ACCEPT_ENCODING, "unknown; q=1.0") + .add(HttpHeaderNames.ACCEPT_ENCODING, "gzip; q=0.5") + .add(HttpHeaderNames.ACCEPT_ENCODING, "deflate; q=0"); + + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + + assertTrue(ch.writeInbound(request)); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + Unpooled.copiedBuffer("Gzip Win", CharsetUtil.US_ASCII)); + assertTrue(ch.writeOutbound(res)); + + assertEncodedResponse(ch); + HttpContent c = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(c.content()), is("1f8b080000000000000072afca2c5008cfcc03000000ffff")); + c.release(); + + c = ch.readOutbound(); + assertThat(ByteBufUtil.hexDump(c.content()), is("03001f2ebf0f08000000")); + c.release(); + + LastHttpContent last = ch.readOutbound(); + assertThat(last.content().readableBytes(), is(0)); + last.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + 
assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testEmpty() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpContentCompressor()); + assertTrue(ch.writeInbound(newRequest())); + + DefaultHttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + response.headers().add(HttpHeaderNames.CONTENT_LENGTH, 0); + assertTrue(ch.writeOutbound(response)); + assertTrue(ch.writeOutbound(new DefaultHttpContent(Unpooled.EMPTY_BUFFER))); + assertTrue(ch.writeOutbound(DefaultLastHttpContent.EMPTY_LAST_CONTENT)); + + ch.checkException(); + ch.finishAndReleaseAll(); + } + + private static FullHttpRequest newRequest() { + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + req.headers().set(HttpHeaderNames.ACCEPT_ENCODING, "gzip"); + return req; + } + + private static FullHttpRequest newBrotliRequest() { + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + req.headers().set(HttpHeaderNames.ACCEPT_ENCODING, "br"); + return req; + } + + private static void assertEncodedResponse(EmbeddedChannel ch) { + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(HttpResponse.class))); + + assertEncodedResponse((HttpResponse) o); + } + + private static void assertEncodedResponse(HttpResponse res) { + assertThat(res, is(not(instanceOf(HttpContent.class)))); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue())); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("gzip")); + } + + private static void assertAssembledEncodedResponse(EmbeddedChannel ch) { + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(AssembledHttpResponse.class))); + + AssembledHttpResponse res = (AssembledHttpResponse) o; + try { + assertThat(res, is(instanceOf(HttpContent.class))); + 
assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue())); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("gzip")); + } finally { + res.release(); + } + } + + static class AssembledHttpResponse extends DefaultHttpResponse implements HttpContent { + + private final ByteBuf content; + + AssembledHttpResponse(HttpVersion version, HttpResponseStatus status, ByteBuf content) { + super(version, status); + this.content = content; + } + + @Override + public HttpContent copy() { + throw new UnsupportedOperationException(); + } + + @Override + public HttpContent duplicate() { + throw new UnsupportedOperationException(); + } + + @Override + public HttpContent retainedDuplicate() { + throw new UnsupportedOperationException(); + } + + @Override + public HttpContent replace(ByteBuf content) { + throw new UnsupportedOperationException(); + } + + @Override + public AssembledHttpResponse retain() { + content.retain(); + return this; + } + + @Override + public AssembledHttpResponse retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public ByteBuf content() { + return content; + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public AssembledHttpResponse touch() { + content.touch(); + return this; + } + + @Override + public AssembledHttpResponse touch(Object hint) { + content.touch(hint); + return this; + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java new file mode 100644 index 0000000..aeca8dc --- /dev/null +++ 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecoderTest.java @@ -0,0 +1,874 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibDecoder; +import io.netty.handler.codec.compression.ZlibEncoder; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledIf; + +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static 
org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpContentDecoderTest { + private static final String HELLO_WORLD = "hello, world"; + private static final byte[] GZ_HELLO_WORLD = { + 31, -117, 8, 8, 12, 3, -74, 84, 0, 3, 50, 0, -53, 72, -51, -55, -55, + -41, 81, 40, -49, 47, -54, 73, 1, 0, 58, 114, -85, -1, 12, 0, 0, 0 + }; + private static final byte[] SNAPPY_HELLO_WORLD = { + -1, 6, 0, 0, 115, 78, 97, 80, 112, 89, 1, 16, 0, 0, 11, -66, -63, + -22, 104, 101, 108, 108, 111, 44, 32, 119, 111, 114, 108, 100 + }; + private static final String SAMPLE_STRING = "Hello, I am Meow!. A small kitten. :)" + + "I sleep all day, and meow all night."; + private static final byte[] SAMPLE_BZ_BYTES = new byte[]{27, 72, 0, 0, -60, -102, 91, -86, 103, 20, + -28, -23, 54, -101, 11, -106, -16, -32, -95, -61, -37, 94, -16, 97, -40, -93, -56, 18, 21, 86, + -110, 82, -41, 102, -89, 20, 11, 10, -68, -31, 96, -116, -55, -80, -31, -91, 96, -64, 83, 51, + -39, 13, -21, 92, -16, -119, 124, -31, 18, 78, -1, 91, 82, 105, -116, -95, -22, -11, -70, -45, 0}; + + @Test + public void testBinaryDecompression() throws Exception { + // baseline test: zlib library and test helpers work correctly. 
+ byte[] helloWorld = gzDecompress(GZ_HELLO_WORLD); + assertEquals(HELLO_WORLD.length(), helloWorld.length); + assertEquals(HELLO_WORLD, new String(helloWorld, CharsetUtil.US_ASCII)); + + String fullCycleTest = "full cycle test"; + byte[] compressed = gzCompress(fullCycleTest.getBytes(CharsetUtil.US_ASCII)); + byte[] decompressed = gzDecompress(compressed); + assertEquals(decompressed.length, fullCycleTest.length()); + assertEquals(fullCycleTest, new String(decompressed, CharsetUtil.US_ASCII)); + } + + @Test + public void testRequestDecompression() { + // baseline test: request decoder, content decompressor && request aggregator work as expected + HttpRequestDecoder decoder = new HttpRequestDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + + String headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpRequest.class))); + FullHttpRequest req = (FullHttpRequest) o; + assertEquals(HELLO_WORLD.length(), req.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).intValue()); + assertEquals(HELLO_WORLD, req.content().toString(CharsetUtil.US_ASCII)); + req.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); // assert that no messages are left in channel + } + + @Test + public void testChunkedRequestDecompression() { + HttpResponseDecoder decoder = new HttpResponseDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, null); + + String 
headers = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Trailer: My-Trailer\r\n" + + "Content-Encoding: gzip\r\n\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII))); + + String chunkLength = Integer.toHexString(GZ_HELLO_WORLD.length); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(chunkLength + "\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n".getBytes(CharsetUtil.US_ASCII)))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("0\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("My-Trailer: 42\r\n\r\n\r\n", CharsetUtil.US_ASCII))); + + Object ob1 = channel.readInbound(); + assertThat(ob1, is(instanceOf(DefaultHttpResponse.class))); + + Object ob2 = channel.readInbound(); + assertThat(ob2, is(instanceOf(HttpContent.class))); + HttpContent content = (HttpContent) ob2; + assertEquals(HELLO_WORLD, content.content().toString(CharsetUtil.US_ASCII)); + content.release(); + + Object ob3 = channel.readInbound(); + assertThat(ob3, is(instanceOf(LastHttpContent.class))); + LastHttpContent lastContent = (LastHttpContent) ob3; + assertNotNull(lastContent.decoderResult()); + assertTrue(lastContent.decoderResult().isSuccess()); + assertFalse(lastContent.trailingHeaders().isEmpty()); + assertEquals("42", lastContent.trailingHeaders().get("My-Trailer")); + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testSnappyResponseDecompression() { + // baseline test: response decoder, content decompressor && request aggregator work as expected + HttpResponseDecoder decoder = new HttpResponseDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = 
new EmbeddedChannel(decoder, decompressor, aggregator); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + SNAPPY_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: snappy\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.UTF_8), SNAPPY_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse resp = (FullHttpResponse) o; + assertEquals(HELLO_WORLD.length(), resp.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).intValue()); + assertEquals(HELLO_WORLD, resp.content().toString(CharsetUtil.UTF_8)); + resp.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); // assert that no messages are left in channel + } + + @Test + public void testResponseDecompression() { + // baseline test: response decoder, content decompressor && request aggregator work as expected + HttpResponseDecoder decoder = new HttpResponseDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse resp = (FullHttpResponse) o; + assertEquals(HELLO_WORLD.length(), resp.headers().getInt(HttpHeaderNames.CONTENT_LENGTH).intValue()); + assertEquals(HELLO_WORLD, resp.content().toString(CharsetUtil.US_ASCII)); + resp.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); 
+ assertFalse(channel.finish()); // assert that no messages are left in channel + } + + @DisabledIf(value = "isNotSupported", disabledReason = "Brotli is not supported on this platform") + @Test + public void testResponseBrotliDecompression() throws Throwable { + Brotli.ensureAvailability(); + + HttpResponseDecoder decoder = new HttpResponseDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(Integer.MAX_VALUE); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + SAMPLE_BZ_BYTES.length + "\r\n" + + "Content-Encoding: br\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.wrappedBuffer(headers.getBytes(CharsetUtil.US_ASCII), SAMPLE_BZ_BYTES); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse resp = (FullHttpResponse) o; + assertNull(resp.headers().get(HttpHeaderNames.CONTENT_ENCODING), "Content-Encoding header should be removed"); + assertEquals(SAMPLE_STRING, resp.content().toString(CharsetUtil.UTF_8), + "Response body should match uncompressed string"); + resp.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); // assert that no messages are left in channel + } + + @DisabledIf(value = "isNotSupported", disabledReason = "Brotli is not supported on this platform") + @Test + public void testResponseChunksBrotliDecompression() throws Throwable { + Brotli.ensureAvailability(); + + HttpResponseDecoder decoder = new HttpResponseDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(Integer.MAX_VALUE); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + + String headers = "HTTP/1.1 200 OK\r\n" + + 
"Content-Length: " + SAMPLE_BZ_BYTES.length + "\r\n" + + "Content-Encoding: br\r\n" + + "\r\n"; + + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(headers.getBytes(CharsetUtil.US_ASCII)))); + + int offset = 0; + while (offset < SAMPLE_BZ_BYTES.length) { + int len = Math.min(1500, SAMPLE_BZ_BYTES.length - offset); + boolean available = channel.writeInbound(Unpooled.wrappedBuffer(SAMPLE_BZ_BYTES, offset, len)); + offset += 1500; + if (offset < SAMPLE_BZ_BYTES.length) { + assertFalse(available); + } else { + assertTrue(available); + } + } + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse resp = (FullHttpResponse) o; + assertEquals(SAMPLE_STRING, resp.content().toString(CharsetUtil.UTF_8), + "Response body should match uncompressed string"); + resp.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); // assert that no messages are left in channel + } + + @Test + public void testExpectContinueResponse1() { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 1: no ContentDecoder in chain at all (baseline test) + HttpRequestDecoder decoder = new HttpRequestDecoder(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator); + String req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + // note: the following writeInbound() returns false as there is no message is inbound buffer + // until HttpObjectAggregator caches composes a complete message. 
+ // however, http response "100 continue" must be sent as soon as headers are received + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(req.getBytes()))); + + Object o = channel.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse r = (FullHttpResponse) o; + assertEquals(100, r.status().code()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + r.release(); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testExpectContinueResponse2() { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 2: contentDecoder is in chain, but the content is not encoded, should be no-op + HttpRequestDecoder decoder = new HttpRequestDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + String req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(req.getBytes()))); + + Object o = channel.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse r = (FullHttpResponse) o; + assertEquals(100, r.status().code()); + r.release(); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testExpectContinueResponse3() { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 3: ContentDecoder is in chain and content is encoded + HttpRequestDecoder decoder = new HttpRequestDecoder(); + HttpContentDecoder decompressor 
= new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + String req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Expect: 100-continue\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(req.getBytes()))); + + Object o = channel.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse r = (FullHttpResponse) o; + assertEquals(100, r.status().code()); + r.release(); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testExpectContinueResponse4() { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 4: ObjectAggregator is up in chain + HttpRequestDecoder decoder = new HttpRequestDecoder(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator, decompressor); + String req = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Expect: 100-continue\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(req.getBytes()))); + + Object o = channel.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse r = (FullHttpResponse) o; + assertEquals(100, r.status().code()); + r.release(); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void 
testExpectContinueResetHttpObjectDecoder() { + // request with header "Expect: 100-continue" must be replied with one "100 Continue" response + // case 5: Test that HttpObjectDecoder correctly resets its internal state after a failed expectation. + HttpRequestDecoder decoder = new HttpRequestDecoder(); + final int maxBytes = 10; + HttpObjectAggregator aggregator = new HttpObjectAggregator(maxBytes); + final AtomicReference secondRequestRef = new AtomicReference(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof FullHttpRequest) { + if (!secondRequestRef.compareAndSet(null, (FullHttpRequest) msg)) { + ((FullHttpRequest) msg).release(); + } + } else { + ReferenceCountUtil.release(msg); + } + } + }); + String req1 = "POST /1 HTTP/1.1\r\n" + + "Content-Length: " + (maxBytes + 1) + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(req1.getBytes(CharsetUtil.US_ASCII)))); + + FullHttpResponse resp = channel.readOutbound(); + assertEquals(HttpStatusClass.CLIENT_ERROR, resp.status().codeClass()); + resp.release(); + + String req2 = "POST /2 HTTP/1.1\r\n" + + "Content-Length: " + maxBytes + "\r\n" + + "Expect: 100-continue\r\n" + + "\r\n"; + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(req2.getBytes(CharsetUtil.US_ASCII)))); + + resp = channel.readOutbound(); + assertEquals(100, resp.status().code()); + resp.release(); + + byte[] content = new byte[maxBytes]; + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(content))); + + FullHttpRequest req = secondRequestRef.get(); + assertNotNull(req); + assertEquals("/2", req.uri()); + assertEquals(10, req.content().readableBytes()); + req.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + 
public void testRequestContentLength1() { + // case 1: test that ContentDecompressor either sets the correct Content-Length header + // or removes it completely (handlers down the chain must rely on LastHttpContent object) + + // force content to be in more than one chunk (5 bytes/chunk) + HttpRequestDecoder decoder = new HttpRequestDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor); + String headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Queue req = channel.inboundMessages(); + assertTrue(req.size() >= 1); + Object o = req.peek(); + assertThat(o, is(instanceOf(HttpRequest.class))); + HttpRequest r = (HttpRequest) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? 
null : Long.parseLong(v); + assertTrue(value == null || value.longValue() == HELLO_WORLD.length()); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testRequestContentLength2() { + // case 2: if HttpObjectAggregator is down the chain, then correct Content-Length header must be set + + // force content to be in more than one chunk (5 bytes/chunk) + HttpRequestDecoder decoder = new HttpRequestDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + String headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpRequest.class))); + FullHttpRequest r = (FullHttpRequest) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? 
null : Long.parseLong(v); + + r.release(); + assertNotNull(value); + assertEquals(HELLO_WORLD.length(), value.longValue()); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testResponseContentLength1() { + // case 1: test that ContentDecompressor either sets the correct Content-Length header + // or removes it completely (handlers down the chain must rely on LastHttpContent object) + + // force content to be in more than one chunk (5 bytes/chunk) + HttpResponseDecoder decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Queue resp = channel.inboundMessages(); + assertTrue(resp.size() >= 1); + Object o = resp.peek(); + assertThat(o, is(instanceOf(HttpResponse.class))); + HttpResponse r = (HttpResponse) o; + + assertFalse(r.headers().contains(HttpHeaderNames.CONTENT_LENGTH), "Content-Length header not removed."); + + String transferEncoding = r.headers().get(HttpHeaderNames.TRANSFER_ENCODING); + assertNotNull(transferEncoding, "Content-length as well as transfer-encoding not set."); + assertEquals(HttpHeaderValues.CHUNKED.toString(), transferEncoding, "Unexpected transfer-encoding value."); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testResponseContentLength2() { + // case 2: if HttpObjectAggregator is down the chain, then correct Content-Length header must be set + + // force content to be in more than one chunk (5 bytes/chunk) + HttpResponseDecoder 
decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor, aggregator); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + ByteBuf buf = Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII), GZ_HELLO_WORLD); + assertTrue(channel.writeInbound(buf)); + + Object o = channel.readInbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + FullHttpResponse r = (FullHttpResponse) o; + String v = r.headers().get(HttpHeaderNames.CONTENT_LENGTH); + Long value = v == null ? null : Long.parseLong(v); + assertNotNull(value); + assertEquals(HELLO_WORLD.length(), value.longValue()); + r.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testFullHttpRequest() { + // test that ContentDecoder can be used after the ObjectAggregator + HttpRequestDecoder decoder = new HttpRequestDecoder(4096, 4096, 5); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator, decompressor); + String headers = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(), GZ_HELLO_WORLD))); + + Queue req = channel.inboundMessages(); + assertTrue(req.size() > 1); + int contentLength = 0; + contentLength = calculateContentLength(req, contentLength); + + byte[] receivedContent = readContent(req, contentLength, true); + + assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII)); + + 
assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testFullHttpResponse() { + // test that ContentDecoder can be used after the ObjectAggregator + HttpResponseDecoder decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpObjectAggregator aggregator = new HttpObjectAggregator(1024); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, aggregator, decompressor); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(), GZ_HELLO_WORLD))); + + Queue resp = channel.inboundMessages(); + assertTrue(resp.size() > 1); + int contentLength = 0; + contentLength = calculateContentLength(resp, contentLength); + + byte[] receivedContent = readContent(resp, contentLength, true); + + assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII)); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + // See https://github.com/netty/netty/issues/5892 + @Test + public void testFullHttpResponseEOF() { + // test that ContentDecoder can be used after the ObjectAggregator + HttpResponseDecoder decoder = new HttpResponseDecoder(4096, 4096, 5); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor); + String headers = "HTTP/1.1 200 OK\r\n" + + "Content-Encoding: gzip\r\n" + + "\r\n"; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(), GZ_HELLO_WORLD))); + // This should terminate it. 
+ assertTrue(channel.finish()); + + Queue resp = channel.inboundMessages(); + assertTrue(resp.size() > 1); + int contentLength = 0; + contentLength = calculateContentLength(resp, contentLength); + + byte[] receivedContent = readContent(resp, contentLength, false); + + assertEquals(HELLO_WORLD, new String(receivedContent, CharsetUtil.US_ASCII)); + + assertHasInboundMessages(channel, true); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + } + + @Test + public void testCleanupThrows() { + HttpContentDecoder decoder = new HttpContentDecoder() { + @Override + protected EmbeddedChannel newContentDecoder(String contentEncoding) throws Exception { + return new EmbeddedChannel(new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + ctx.fireExceptionCaught(new DecoderException()); + ctx.fireChannelInactive(); + } + }); + } + }; + + final AtomicBoolean channelInactiveCalled = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + assertTrue(channelInactiveCalled.compareAndSet(false, true)); + super.channelInactive(ctx); + } + }); + assertTrue(channel.writeInbound(new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); + HttpContent content = new DefaultHttpContent(Unpooled.buffer().writeZero(10)); + assertTrue(channel.writeInbound(content)); + assertEquals(1, content.refCnt()); + try { + channel.finishAndReleaseAll(); + fail(); + } catch (CodecException expected) { + // expected + } + assertTrue(channelInactiveCalled.get()); + assertEquals(0, content.refCnt()); + } + + @Test + public void testTransferCodingGZIP() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Content-Length: " + GZ_HELLO_WORLD.length + "\r\n" + + "Transfer-Encoding: gzip\r\n" + + "\r\n"; + HttpRequestDecoder decoder = new HttpRequestDecoder(); + 
HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor); + + channel.writeInbound(Unpooled.copiedBuffer(requestStr.getBytes())); + channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD)); + + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isSuccess()); + assertFalse(request.headers().contains(HttpHeaderNames.CONTENT_LENGTH)); + + HttpContent content = channel.readInbound(); + assertTrue(content.decoderResult().isSuccess()); + assertEquals(HELLO_WORLD, content.content().toString(CharsetUtil.US_ASCII)); + content.release(); + + LastHttpContent lastHttpContent = channel.readInbound(); + assertTrue(lastHttpContent.decoderResult().isSuccess()); + lastHttpContent.release(); + + assertHasInboundMessages(channel, false); + assertHasOutboundMessages(channel, false); + assertFalse(channel.finish()); + channel.releaseInbound(); + } + + @Test + public void testTransferCodingGZIPAndChunked() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Content-Type: application/x-www-form-urlencoded\r\n" + + "Trailer: My-Trailer\r\n" + + "Transfer-Encoding: gzip, chunked\r\n" + + "\r\n"; + HttpRequestDecoder decoder = new HttpRequestDecoder(); + HttpContentDecoder decompressor = new HttpContentDecompressor(); + EmbeddedChannel channel = new EmbeddedChannel(decoder, decompressor); + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + + String chunkLength = Integer.toHexString(GZ_HELLO_WORLD.length); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(chunkLength + "\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(GZ_HELLO_WORLD))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n".getBytes(CharsetUtil.US_ASCII)))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("0\r\n", CharsetUtil.US_ASCII))); + 
assertTrue(channel.writeInbound(Unpooled.copiedBuffer("My-Trailer: 42\r\n\r\n", CharsetUtil.US_ASCII))); + + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isSuccess()); + assertTrue(request.headers().containsValue(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED, true)); + assertFalse(request.headers().contains(HttpHeaderNames.CONTENT_LENGTH)); + + HttpContent chunk1 = channel.readInbound(); + assertTrue(chunk1.decoderResult().isSuccess()); + assertEquals(HELLO_WORLD, chunk1.content().toString(CharsetUtil.US_ASCII)); + chunk1.release(); + + LastHttpContent chunk2 = channel.readInbound(); + assertTrue(chunk2.decoderResult().isSuccess()); + assertEquals("42", chunk2.trailingHeaders().get("My-Trailer")); + chunk2.release(); + + assertFalse(channel.finish()); + channel.releaseInbound(); + } + + private static byte[] gzDecompress(byte[] input) { + ZlibDecoder decoder = ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(input))); + assertTrue(channel.finish()); // close the channel to indicate end-of-data + + int outputSize = 0; + ByteBuf o; + List inbound = new ArrayList(); + while ((o = channel.readInbound()) != null) { + inbound.add(o); + outputSize += o.readableBytes(); + } + + byte[] output = new byte[outputSize]; + int readCount = 0; + for (ByteBuf b : inbound) { + int readableBytes = b.readableBytes(); + b.readBytes(output, readCount, readableBytes); + b.release(); + readCount += readableBytes; + } + assertTrue(channel.inboundMessages().isEmpty() && channel.outboundMessages().isEmpty()); + return output; + } + + private static byte[] readContent(Queue req, int contentLength, boolean hasTransferEncoding) { + byte[] receivedContent = new byte[contentLength]; + int readCount = 0; + for (Object o : req) { + if (o instanceof HttpContent) { + ByteBuf b = ((HttpContent) o).content(); + int readableBytes = 
b.readableBytes(); + b.readBytes(receivedContent, readCount, readableBytes); + readCount += readableBytes; + } + if (o instanceof HttpMessage) { + assertEquals(hasTransferEncoding, + ((HttpMessage) o).headers().contains(HttpHeaderNames.TRANSFER_ENCODING)); + } + } + return receivedContent; + } + + private static int calculateContentLength(Queue req, int contentLength) { + for (Object o : req) { + if (o instanceof HttpContent) { + assertTrue(((HttpContent) o).refCnt() > 0); + ByteBuf b = ((HttpContent) o).content(); + contentLength += b.readableBytes(); + } + } + return contentLength; + } + + private static byte[] gzCompress(byte[] input) { + ZlibEncoder encoder = ZlibCodecFactory.newZlibEncoder(ZlibWrapper.GZIP); + EmbeddedChannel channel = new EmbeddedChannel(encoder); + assertTrue(channel.writeOutbound(Unpooled.wrappedBuffer(input))); + assertTrue(channel.finish()); // close the channel to indicate end-of-data + + int outputSize = 0; + ByteBuf o; + List outbound = new ArrayList(); + while ((o = channel.readOutbound()) != null) { + outbound.add(o); + outputSize += o.readableBytes(); + } + + byte[] output = new byte[outputSize]; + int readCount = 0; + for (ByteBuf b : outbound) { + int readableBytes = b.readableBytes(); + b.readBytes(output, readCount, readableBytes); + b.release(); + readCount += readableBytes; + } + assertTrue(channel.inboundMessages().isEmpty() && channel.outboundMessages().isEmpty()); + return output; + } + + private static void assertHasInboundMessages(EmbeddedChannel channel, boolean hasMessages) { + Object o; + if (hasMessages) { + while (true) { + o = channel.readInbound(); + assertNotNull(o); + ReferenceCountUtil.release(o); + if (o instanceof LastHttpContent) { + break; + } + } + } else { + o = channel.readInbound(); + assertNull(o); + } + } + + private static void assertHasOutboundMessages(EmbeddedChannel channel, boolean hasMessages) { + Object o; + if (hasMessages) { + while (true) { + o = channel.readOutbound(); + assertNotNull(o); + 
ReferenceCountUtil.release(o); + if (o instanceof LastHttpContent) { + break; + } + } + } else { + o = channel.readOutbound(); + assertNull(o); + } + } + + static boolean isNotSupported() { + return PlatformDependent.isOsx() && "aarch_64".equals(PlatformDependent.normalizedArch()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java new file mode 100644 index 0000000..d9f5cd5 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentDecompressorTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpContentDecompressorTest { + + // See https://github.com/netty/netty/issues/8915. 
+ @Test + public void testInvokeReadWhenNotProduceMessage() { + final AtomicInteger readCalled = new AtomicInteger(); + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void read(ChannelHandlerContext ctx) { + readCalled.incrementAndGet(); + ctx.read(); + } + }, new HttpContentDecompressor(), new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + ctx.read(); + } + }); + + channel.config().setAutoRead(false); + + readCalled.set(0); + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + response.headers().set(HttpHeaderNames.CONTENT_ENCODING, "gzip"); + response.headers().set(HttpHeaderNames.CONTENT_TYPE, "application/json;charset=UTF-8"); + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + assertTrue(channel.writeInbound(response)); + + // we triggered read explicitly + assertEquals(1, readCalled.get()); + + assertTrue(channel.readInbound() instanceof HttpResponse); + + assertFalse(channel.writeInbound(new DefaultHttpContent(Unpooled.EMPTY_BUFFER))); + + // read was triggered by the HttpContentDecompressor itself as it did not produce any message to the next + // inbound handler. 
+ assertEquals(2, readCalled.get()); + assertFalse(channel.finishAndReleaseAll()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java new file mode 100644 index 0000000..e1dda9d --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpContentEncoderTest.java @@ -0,0 +1,465 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpContentEncoderTest { + + private static final class TestEncoder extends HttpContentEncoder { + @Override + protected Result beginEncode(HttpResponse httpResponse, String acceptEncoding) { + return new Result("test", new EmbeddedChannel(new MessageToByteEncoder() { + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf in, ByteBuf out) throws Exception { + out.writeBytes(String.valueOf(in.readableBytes()).getBytes(CharsetUtil.US_ASCII)); + in.skipBytes(in.readableBytes()); + } + })); + } + } + + @Test + public void testSplitContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, 
HttpMethod.GET, "/")); + + ch.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)); + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(new byte[3]))); + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(new byte[2]))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.wrappedBuffer(new byte[1]))); + + assertEncodedResponse(ch); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("3")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("2")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("1")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testChunkedContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(new byte[3]))); + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(new byte[2]))); + ch.writeOutbound(new DefaultLastHttpContent(Unpooled.wrappedBuffer(new byte[1]))); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("3")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("2")); + chunk.release(); + + chunk = ch.readOutbound(); + 
assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("1")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testChunkedContentWithTrailingHeader() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(new byte[3]))); + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(new byte[2]))); + LastHttpContent content = new DefaultLastHttpContent(Unpooled.wrappedBuffer(new byte[1])); + content.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(content); + + HttpContent chunk; + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("3")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("2")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("1")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + assertEquals("Netty", ((LastHttpContent) chunk).trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + 
public void testFullContentWithContentLength() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + FullHttpResponse fullRes = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.wrappedBuffer(new byte[42])); + fullRes.headers().set(HttpHeaderNames.CONTENT_LENGTH, 42); + ch.writeOutbound(fullRes); + + HttpResponse res = ch.readOutbound(); + assertThat(res, is(not(instanceOf(HttpContent.class)))); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is("2")); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("test")); + + HttpContent c = ch.readOutbound(); + assertThat(c.content().readableBytes(), is(2)); + assertThat(c.content().toString(CharsetUtil.US_ASCII), is("42")); + c.release(); + + LastHttpContent last = ch.readOutbound(); + assertThat(last.content().readableBytes(), is(0)); + last.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testFullContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.wrappedBuffer(new byte[42])); + ch.writeOutbound(res); + + assertEncodedResponse(ch); + HttpContent c = ch.readOutbound(); + assertThat(c.content().readableBytes(), is(2)); + assertThat(c.content().toString(CharsetUtil.US_ASCII), is("42")); + c.release(); + + LastHttpContent last = ch.readOutbound(); + assertThat(last.content().readableBytes(), is(0)); + last.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + /** + * If the length of the content is unknown, {@link HttpContentEncoder} should not skip encoding 
the content + * even if the actual length is turned out to be 0. + */ + @Test + public void testEmptySplitContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + ch.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK)); + assertEncodedResponse(ch); + + ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT); + HttpContent chunk = ch.readOutbound(); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("0")); + assertThat(chunk, is(instanceOf(HttpContent.class))); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(false)); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + /** + * If the length of the content is 0 for sure, {@link HttpContentEncoder} should skip encoding. + */ + @Test + public void testEmptyFullContent() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. 
+ assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + res.release(); + + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testEmptyFullContentWithTrailer() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + ch.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + + FullHttpResponse res = new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); + res.trailingHeaders().set(of("X-Test"), of("Netty")); + ch.writeOutbound(res); + + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(FullHttpResponse.class))); + + res = (FullHttpResponse) o; + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is(nullValue())); + + // Content encoding shouldn't be modified. + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is(nullValue())); + assertThat(res.content().readableBytes(), is(0)); + assertThat(res.content().toString(CharsetUtil.US_ASCII), is("")); + assertEquals("Netty", res.trailingHeaders().get(of("X-Test"))); + assertEquals(DecoderResult.SUCCESS, res.decoderResult()); + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testEmptyHeadResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.HEAD, "/"); + ch.writeInbound(req); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT); + + assertEmptyResponse(ch); + } + + @Test + public void testHttp304Response() throws Exception { + EmbeddedChannel ch = new 
EmbeddedChannel(new TestEncoder()); + HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + req.headers().set(HttpHeaderNames.ACCEPT_ENCODING, HttpHeaderValues.GZIP); + ch.writeInbound(req); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.NOT_MODIFIED); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT); + + assertEmptyResponse(ch); + } + + @Test + public void testConnect200Response() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.CONNECT, "google.com:80"); + ch.writeInbound(req); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT); + + assertEmptyResponse(ch); + } + + @Test + public void testConnectFailureResponse() throws Exception { + String content = "Not allowed by configuration"; + + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + HttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.CONNECT, "google.com:80"); + ch.writeInbound(req); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.METHOD_NOT_ALLOWED); + res.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + ch.writeOutbound(res); + ch.writeOutbound(new DefaultHttpContent(Unpooled.wrappedBuffer(content.getBytes(CharsetUtil.UTF_8)))); + ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT); + + assertEncodedResponse(ch); + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(HttpContent.class))); + HttpContent chunk = (HttpContent) o; + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("28")); 
+ chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk.content().isReadable(), is(true)); + assertThat(chunk.content().toString(CharsetUtil.US_ASCII), is("0")); + chunk.release(); + + chunk = ch.readOutbound(); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + assertThat(ch.readOutbound(), is(nullValue())); + } + + @Test + public void testHttp1_0() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new TestEncoder()); + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, "/"); + assertTrue(ch.writeInbound(req)); + + HttpResponse res = new DefaultHttpResponse(HttpVersion.HTTP_1_0, HttpResponseStatus.OK); + res.headers().set(HttpHeaderNames.CONTENT_LENGTH, HttpHeaderValues.ZERO); + assertTrue(ch.writeOutbound(res)); + assertTrue(ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + assertTrue(ch.finish()); + + FullHttpRequest request = ch.readInbound(); + assertTrue(request.release()); + assertNull(ch.readInbound()); + + HttpResponse response = ch.readOutbound(); + assertSame(res, response); + + LastHttpContent content = ch.readOutbound(); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, content); + content.release(); + assertNull(ch.readOutbound()); + } + + @Test + public void testCleanupThrows() { + HttpContentEncoder encoder = new HttpContentEncoder() { + @Override + protected Result beginEncode(HttpResponse httpResponse, String acceptEncoding) throws Exception { + return new Result("myencoding", new EmbeddedChannel( + new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + ctx.fireExceptionCaught(new EncoderException()); + ctx.fireChannelInactive(); + } + })); + } + }; + + final AtomicBoolean channelInactiveCalled = new AtomicBoolean(); + final EmbeddedChannel channel = new EmbeddedChannel(encoder, new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext 
ctx) throws Exception { + assertTrue(channelInactiveCalled.compareAndSet(false, true)); + super.channelInactive(ctx); + } + }); + assertTrue(channel.writeInbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"))); + assertTrue(channel.writeOutbound(new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))); + HttpContent content = new DefaultHttpContent(Unpooled.buffer().writeZero(10)); + assertTrue(channel.writeOutbound(content)); + assertEquals(1, content.refCnt()); + assertThrows(CodecException.class, new Executable() { + @Override + public void execute() { + channel.finishAndReleaseAll(); + } + }); + + assertTrue(channelInactiveCalled.get()); + assertEquals(0, content.refCnt()); + } + + private static void assertEmptyResponse(EmbeddedChannel ch) { + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(HttpResponse.class))); + + HttpResponse res = (HttpResponse) o; + assertThat(res, is(not(instanceOf(HttpContent.class)))); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue())); + + HttpContent chunk = ch.readOutbound(); + assertThat(chunk, is(instanceOf(LastHttpContent.class))); + chunk.release(); + assertThat(ch.readOutbound(), is(nullValue())); + } + + private static void assertEncodedResponse(EmbeddedChannel ch) { + Object o = ch.readOutbound(); + assertThat(o, is(instanceOf(HttpResponse.class))); + + HttpResponse res = (HttpResponse) o; + assertThat(res, is(not(instanceOf(HttpContent.class)))); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_LENGTH), is(nullValue())); + assertThat(res.headers().get(HttpHeaderNames.CONTENT_ENCODING), is("test")); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderDateFormatTest.java 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderDateFormatTest.java new file mode 100644 index 0000000..b8daba3 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderDateFormatTest.java @@ -0,0 +1,68 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import org.junit.jupiter.api.Test; + +import java.text.ParseException; +import java.util.Date; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class HttpHeaderDateFormatTest { + /** + * This date is set at "06 Nov 1994 08:49:37 GMT" (same used in example in + * RFC documentation) + *
+     * <p>
+ * https://www.w3.org/Protocols/rfc2616/rfc2616-sec3.html + */ + private static final Date DATE = new Date(784111777000L); + + @Test + public void testParse() throws ParseException { + HttpHeaderDateFormat format = HttpHeaderDateFormat.get(); + + final Date parsedDateWithSingleDigitDay = format.parse("Sun, 6 Nov 1994 08:49:37 GMT"); + assertNotNull(parsedDateWithSingleDigitDay); + assertEquals(DATE, parsedDateWithSingleDigitDay); + + final Date parsedDateWithDoubleDigitDay = format.parse("Sun, 06 Nov 1994 08:49:37 GMT"); + assertNotNull(parsedDateWithDoubleDigitDay); + assertEquals(DATE, parsedDateWithDoubleDigitDay); + + final Date parsedDateWithDashSeparatorSingleDigitDay = format.parse("Sunday, 06-Nov-94 08:49:37 GMT"); + assertNotNull(parsedDateWithDashSeparatorSingleDigitDay); + assertEquals(DATE, parsedDateWithDashSeparatorSingleDigitDay); + + final Date parsedDateWithSingleDoubleDigitDay = format.parse("Sunday, 6-Nov-94 08:49:37 GMT"); + assertNotNull(parsedDateWithSingleDoubleDigitDay); + assertEquals(DATE, parsedDateWithSingleDoubleDigitDay); + + final Date parsedDateWithoutGMT = format.parse("Sun Nov 6 08:49:37 1994"); + assertNotNull(parsedDateWithoutGMT); + assertEquals(DATE, parsedDateWithoutGMT); + } + + @Test + public void testFormat() { + HttpHeaderDateFormat format = HttpHeaderDateFormat.get(); + + final String formatted = format.format(DATE); + assertNotNull(formatted); + assertEquals("Sun, 06 Nov 1994 08:49:37 GMT", formatted); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderValidationUtilTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderValidationUtilTest.java new file mode 100644 index 0000000..149d160 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeaderValidationUtilTest.java @@ -0,0 +1,584 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + 
* version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.DisabledForJreRange; +import org.junit.jupiter.api.condition.JRE; +import org.junit.jupiter.api.parallel.Execution; +import org.junit.jupiter.api.parallel.ExecutionMode; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.Arguments; +import org.junit.jupiter.params.provider.MethodSource; + +import java.nio.ByteBuffer; +import java.util.ArrayList; +import java.util.List; +import java.util.function.Supplier; + +import static io.netty.handler.codec.http.HttpHeaderValidationUtil.validateToken; +import static io.netty.handler.codec.http.HttpHeaderValidationUtil.validateValidHeaderValue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; + +@Execution(ExecutionMode.CONCURRENT) // We have a couple of fairly slow tests here. Better to run them in parallel. +public class HttpHeaderValidationUtilTest { + @SuppressWarnings("deprecation") // We need to check for deprecated headers as well. 
+ public static List connectionRelatedHeaders() { + List list = new ArrayList(); + + list.add(header(false, HttpHeaderNames.ACCEPT)); + list.add(header(false, HttpHeaderNames.ACCEPT_CHARSET)); + list.add(header(false, HttpHeaderNames.ACCEPT_ENCODING)); + list.add(header(false, HttpHeaderNames.ACCEPT_LANGUAGE)); + list.add(header(false, HttpHeaderNames.ACCEPT_RANGES)); + list.add(header(false, HttpHeaderNames.ACCEPT_PATCH)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_ALLOW_CREDENTIALS)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_ALLOW_HEADERS)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_ALLOW_METHODS)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_ALLOW_ORIGIN)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_EXPOSE_HEADERS)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_MAX_AGE)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_REQUEST_HEADERS)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_REQUEST_METHOD)); + list.add(header(false, HttpHeaderNames.ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK)); + list.add(header(false, HttpHeaderNames.AGE)); + list.add(header(false, HttpHeaderNames.ALLOW)); + list.add(header(false, HttpHeaderNames.AUTHORIZATION)); + list.add(header(false, HttpHeaderNames.CACHE_CONTROL)); + list.add(header(true, HttpHeaderNames.CONNECTION)); + list.add(header(false, HttpHeaderNames.CONTENT_BASE)); + list.add(header(false, HttpHeaderNames.CONTENT_ENCODING)); + list.add(header(false, HttpHeaderNames.CONTENT_LANGUAGE)); + list.add(header(false, HttpHeaderNames.CONTENT_LENGTH)); + list.add(header(false, HttpHeaderNames.CONTENT_LOCATION)); + list.add(header(false, HttpHeaderNames.CONTENT_TRANSFER_ENCODING)); + list.add(header(false, HttpHeaderNames.CONTENT_DISPOSITION)); + list.add(header(false, HttpHeaderNames.CONTENT_MD5)); + list.add(header(false, HttpHeaderNames.CONTENT_RANGE)); + 
list.add(header(false, HttpHeaderNames.CONTENT_SECURITY_POLICY)); + list.add(header(false, HttpHeaderNames.CONTENT_TYPE)); + list.add(header(false, HttpHeaderNames.COOKIE)); + list.add(header(false, HttpHeaderNames.DATE)); + list.add(header(false, HttpHeaderNames.DNT)); + list.add(header(false, HttpHeaderNames.ETAG)); + list.add(header(false, HttpHeaderNames.EXPECT)); + list.add(header(false, HttpHeaderNames.EXPIRES)); + list.add(header(false, HttpHeaderNames.FROM)); + list.add(header(false, HttpHeaderNames.HOST)); + list.add(header(false, HttpHeaderNames.IF_MATCH)); + list.add(header(false, HttpHeaderNames.IF_MODIFIED_SINCE)); + list.add(header(false, HttpHeaderNames.IF_NONE_MATCH)); + list.add(header(false, HttpHeaderNames.IF_RANGE)); + list.add(header(false, HttpHeaderNames.IF_UNMODIFIED_SINCE)); + list.add(header(true, HttpHeaderNames.KEEP_ALIVE)); + list.add(header(false, HttpHeaderNames.LAST_MODIFIED)); + list.add(header(false, HttpHeaderNames.LOCATION)); + list.add(header(false, HttpHeaderNames.MAX_FORWARDS)); + list.add(header(false, HttpHeaderNames.ORIGIN)); + list.add(header(false, HttpHeaderNames.PRAGMA)); + list.add(header(false, HttpHeaderNames.PROXY_AUTHENTICATE)); + list.add(header(false, HttpHeaderNames.PROXY_AUTHORIZATION)); + list.add(header(true, HttpHeaderNames.PROXY_CONNECTION)); + list.add(header(false, HttpHeaderNames.RANGE)); + list.add(header(false, HttpHeaderNames.REFERER)); + list.add(header(false, HttpHeaderNames.RETRY_AFTER)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_KEY1)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_KEY2)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_LOCATION)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_ORIGIN)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_VERSION)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_KEY)); + list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_ACCEPT)); + 
list.add(header(false, HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + list.add(header(false, HttpHeaderNames.SERVER)); + list.add(header(false, HttpHeaderNames.SET_COOKIE)); + list.add(header(false, HttpHeaderNames.SET_COOKIE2)); + list.add(header(true, HttpHeaderNames.TE)); + list.add(header(false, HttpHeaderNames.TRAILER)); + list.add(header(true, HttpHeaderNames.TRANSFER_ENCODING)); + list.add(header(true, HttpHeaderNames.UPGRADE)); + list.add(header(false, HttpHeaderNames.UPGRADE_INSECURE_REQUESTS)); + list.add(header(false, HttpHeaderNames.USER_AGENT)); + list.add(header(false, HttpHeaderNames.VARY)); + list.add(header(false, HttpHeaderNames.VIA)); + list.add(header(false, HttpHeaderNames.WARNING)); + list.add(header(false, HttpHeaderNames.WEBSOCKET_LOCATION)); + list.add(header(false, HttpHeaderNames.WEBSOCKET_ORIGIN)); + list.add(header(false, HttpHeaderNames.WEBSOCKET_PROTOCOL)); + list.add(header(false, HttpHeaderNames.WWW_AUTHENTICATE)); + list.add(header(false, HttpHeaderNames.X_FRAME_OPTIONS)); + list.add(header(false, HttpHeaderNames.X_REQUESTED_WITH)); + + return list; + } + + private static Arguments header(final boolean isConnectionRelated, final AsciiString headerName) { + return new Arguments() { + @Override + public Object[] get() { + return new Object[]{headerName, isConnectionRelated}; + } + }; + } + + @ParameterizedTest + @MethodSource("connectionRelatedHeaders") + void mustIdentifyConnectionRelatedHeadersAsciiString(AsciiString headerName, boolean isConnectionRelated) { + assertEquals(isConnectionRelated, HttpHeaderValidationUtil.isConnectionHeader(headerName, false)); + } + + @ParameterizedTest + @MethodSource("connectionRelatedHeaders") + void mustIdentifyConnectionRelatedHeadersString(AsciiString headerName, boolean isConnectionRelated) { + assertEquals(isConnectionRelated, HttpHeaderValidationUtil.isConnectionHeader(headerName.toString(), false)); + } + + @Test + void teHeaderIsNotConnectionRelatedWhenIgnoredAsciiString() { + 
assertFalse(HttpHeaderValidationUtil.isConnectionHeader(HttpHeaderNames.TE, true)); + } + + @Test + void teHeaderIsNotConnectionRelatedWhenIgnoredString() { + assertFalse(HttpHeaderValidationUtil.isConnectionHeader(HttpHeaderNames.TE.toString(), true)); + } + + public static List teIsTrailersTruthTable() { + List list = new ArrayList(); + + list.add(teIsTrailter(HttpHeaderNames.TE, HttpHeaderValues.TRAILERS, false)); + list.add(teIsTrailter(HttpHeaderNames.TE, HttpHeaderValues.CHUNKED, true)); + list.add(teIsTrailter(HttpHeaderNames.COOKIE, HttpHeaderValues.CHUNKED, false)); + list.add(teIsTrailter(HttpHeaderNames.COOKIE, HttpHeaderValues.TRAILERS, false)); + list.add(teIsTrailter(HttpHeaderNames.TRAILER, HttpHeaderValues.TRAILERS, false)); + list.add(teIsTrailter(HttpHeaderNames.TRAILER, HttpHeaderValues.CHUNKED, false)); + + return list; + } + + private static Arguments teIsTrailter( + final AsciiString headerName, final AsciiString headerValue, final boolean result) { + return new Arguments() { + @Override + public Object[] get() { + return new Object[]{headerName, headerValue, result}; + } + }; + } + + @ParameterizedTest + @MethodSource("teIsTrailersTruthTable") + void whenTeIsNotTrailerOrNotWithNameAndValueAsciiString( + AsciiString headerName, AsciiString headerValue, boolean result) { + assertEquals(result, HttpHeaderValidationUtil.isTeNotTrailers(headerName, headerValue)); + } + + @ParameterizedTest + @MethodSource("teIsTrailersTruthTable") + void whenTeIsNotTrailerOrNotSWithNameAndValueString( + AsciiString headerName, AsciiString headerValue, boolean result) { + assertEquals(result, HttpHeaderValidationUtil.isTeNotTrailers(headerName.toString(), headerValue.toString())); + } + + @ParameterizedTest + @MethodSource("teIsTrailersTruthTable") + void whenTeIsNotTrailerOrNotSWithNameAsciiStringAndValueString( + AsciiString headerName, AsciiString headerValue, boolean result) { + assertEquals(result, HttpHeaderValidationUtil.isTeNotTrailers(headerName, 
headerValue.toString())); + } + + @ParameterizedTest + @MethodSource("teIsTrailersTruthTable") + void whenTeIsNotTrailerOrNotSWithNametringAndValueAsciiString( + AsciiString headerName, AsciiString headerValue, boolean result) { + assertEquals(result, HttpHeaderValidationUtil.isTeNotTrailers(headerName.toString(), headerValue)); + } + + public static List illegalFirstChar() { + List list = new ArrayList(); + + for (byte i = 0; i < 0x21; i++) { + list.add(new AsciiString(new byte[]{i, 'a'})); + } + list.add(new AsciiString(new byte[]{0x7F, 'a'})); + + return list; + } + + @ParameterizedTest + @MethodSource("illegalFirstChar") + void decodingInvalidHeaderValuesMustFailIfFirstCharIsIllegalAsciiString(AsciiString value) { + assertEquals(0, validateValidHeaderValue(value)); + } + + @ParameterizedTest + @MethodSource("illegalFirstChar") + void decodingInvalidHeaderValuesMustFailIfFirstCharIsIllegalCharSequence(AsciiString value) { + assertEquals(0, validateValidHeaderValue(asCharSequence(value))); + } + + public static List legalFirstChar() { + List list = new ArrayList(); + + for (int i = 0x21; i <= 0xFF; i++) { + if (i == 0x7F) { + continue; + } + list.add(new AsciiString(new byte[]{(byte) i, 'a'})); + } + + return list; + } + + @ParameterizedTest + @MethodSource("legalFirstChar") + void allOtherCharsAreLegalFirstCharsAsciiString(AsciiString value) { + assertEquals(-1, validateValidHeaderValue(value)); + } + + @ParameterizedTest + @MethodSource("legalFirstChar") + void allOtherCharsAreLegalFirstCharsCharSequence(AsciiString value) { + assertEquals(-1, validateValidHeaderValue(value)); + } + + public static List illegalNotFirstChar() { + ArrayList list = new ArrayList(); + + for (byte i = 0; i < 0x21; i++) { + if (i == ' ' || i == '\t') { + continue; // Space and horizontal tab are only illegal as first chars. 
+ } + list.add(new AsciiString(new byte[]{'a', i})); + } + list.add(new AsciiString(new byte[]{'a', 0x7F})); + + return list; + } + + @ParameterizedTest + @MethodSource("illegalNotFirstChar") + void decodingInvalidHeaderValuesMustFailIfNotFirstCharIsIllegalAsciiString(AsciiString value) { + assertEquals(1, validateValidHeaderValue(value)); + } + + @ParameterizedTest + @MethodSource("illegalNotFirstChar") + void decodingInvalidHeaderValuesMustFailIfNotFirstCharIsIllegalCharSequence(AsciiString value) { + assertEquals(1, validateValidHeaderValue(asCharSequence(value))); + } + + public static List legalNotFirstChar() { + List list = new ArrayList(); + + for (int i = 0; i < 0xFF; i++) { + if (i == 0x7F || i < 0x21 && (i != ' ' || i != '\t')) { + continue; + } + list.add(new AsciiString(new byte[] {'a', (byte) i})); + } + + return list; + } + + @ParameterizedTest + @MethodSource("legalNotFirstChar") + void allOtherCharsArgLegalNotFirstCharsAsciiString(AsciiString value) { + assertEquals(-1, validateValidHeaderValue(value)); + } + + @ParameterizedTest + @MethodSource("legalNotFirstChar") + void allOtherCharsArgLegalNotFirstCharsCharSequence(AsciiString value) { + assertEquals(-1, validateValidHeaderValue(asCharSequence(value))); + } + + @Test + void emptyValuesHaveNoIllegalCharsAsciiString() { + assertEquals(-1, validateValidHeaderValue(AsciiString.EMPTY_STRING)); + } + + @Test + void emptyValuesHaveNoIllegalCharsCharSequence() { + assertEquals(-1, validateValidHeaderValue(asCharSequence(AsciiString.EMPTY_STRING))); + } + + @Test + void headerValuesCannotEndWithNewlinesAsciiString() { + assertEquals(1, validateValidHeaderValue(AsciiString.of("a\n"))); + assertEquals(1, validateValidHeaderValue(AsciiString.of("a\r"))); + } + + @Test + void headerValuesCannotEndWithNewlinesCharSequence() { + assertEquals(1, validateValidHeaderValue("a\n")); + assertEquals(1, validateValidHeaderValue("a\r")); + } + + /** + * This method returns a {@link CharSequence} instance that has the 
same contents as the given {@link AsciiString}, + * but which is, critically, not itself an {@link AsciiString}. + *
+     * <p>
+ * Some methods specialise on {@link AsciiString}, while having a {@link CharSequence} based fallback. + *
+     * <p>
+ * This method exist to test those fallback methods. + * + * @param value The {@link AsciiString} instance to wrap. + * @return A new {@link CharSequence} instance which backed by the given {@link AsciiString}, + * but which is itself not an {@link AsciiString}. + */ + private static CharSequence asCharSequence(final AsciiString value) { + return new CharSequence() { + @Override + public int length() { + return value.length(); + } + + @Override + public char charAt(int index) { + return value.charAt(index); + } + + @Override + public CharSequence subSequence(int start, int end) { + return asCharSequence(value.subSequence(start, end)); + } + }; + } + + private static final IllegalArgumentException VALIDATION_EXCEPTION = new IllegalArgumentException() { + private static final long serialVersionUID = -8857428534361331089L; + + @Override + public synchronized Throwable fillInStackTrace() { + return this; + } + }; + + @DisabledForJreRange(max = JRE.JAVA_17) // This test is much too slow on older Java versions. 
+ @Test + void headerValueValidationMustRejectAllValuesRejectedByOldAlgorithm() { + byte[] array = new byte[4]; + final ByteBuffer buffer = ByteBuffer.wrap(array); + final AsciiString asciiString = new AsciiString(buffer, false); + CharSequence charSequence = asCharSequence(asciiString); + int i = Integer.MIN_VALUE; + Supplier failureMessageSupplier = new Supplier() { + @Override + public String get() { + return "validation mismatch on string '" + asciiString + "', iteration " + buffer.getInt(0); + } + }; + + do { + buffer.putInt(0, i); + try { + oldHeaderValueValidationAlgorithm(asciiString); + } catch (IllegalArgumentException ignore) { + assertNotEquals(-1, validateValidHeaderValue(asciiString), failureMessageSupplier); + assertNotEquals(-1, validateValidHeaderValue(charSequence), failureMessageSupplier); + } + i++; + } while (i != Integer.MIN_VALUE); + } + + private static void oldHeaderValueValidationAlgorithm(CharSequence seq) { + int state = 0; + // Start looping through each of the character + for (int index = 0; index < seq.length(); index++) { + state = oldValidationAlgorithmValidateValueChar(state, seq.charAt(index)); + } + + if (state != 0) { + throw VALIDATION_EXCEPTION; + } + } + + private static int oldValidationAlgorithmValidateValueChar(int state, char character) { + /* + * State: + * 0: Previous character was neither CR nor LF + * 1: The previous character was CR + * 2: The previous character was LF + */ + if ((character & ~15) == 0) { + // Check the absolutely prohibited characters. 
+ switch (character) { + case 0x0: // NULL + throw VALIDATION_EXCEPTION; + case 0x0b: // Vertical tab + throw VALIDATION_EXCEPTION; + case '\f': + throw VALIDATION_EXCEPTION; + default: + break; + } + } + + // Check the CRLF (HT | SP) pattern + switch (state) { + case 0: + switch (character) { + case '\r': + return 1; + case '\n': + return 2; + default: + break; + } + break; + case 1: + if (character == '\n') { + return 2; + } + throw VALIDATION_EXCEPTION; + case 2: + switch (character) { + case '\t': + case ' ': + return 0; + default: + throw VALIDATION_EXCEPTION; + } + default: + break; + } + return state; + } + + @DisabledForJreRange(max = JRE.JAVA_17) // This test is much too slow on older Java versions. + @Test + void headerNameValidationMustRejectAllNamesRejectedByOldAlgorithm() throws Exception { + byte[] array = new byte[4]; + final ByteBuffer buffer = ByteBuffer.wrap(array); + final AsciiString asciiString = new AsciiString(buffer, false); + CharSequence charSequence = asCharSequence(asciiString); + int i = Integer.MIN_VALUE; + Supplier failureMessageSupplier = new Supplier() { + @Override + public String get() { + return "validation mismatch on string '" + asciiString + "', iteration " + buffer.getInt(0); + } + }; + + do { + buffer.putInt(0, i); + try { + oldHeaderNameValidationAlgorithmAsciiString(asciiString); + } catch (IllegalArgumentException ignore) { + assertNotEquals(-1, validateToken(asciiString), failureMessageSupplier); + assertNotEquals(-1, validateToken(charSequence), failureMessageSupplier); + } + i++; + } while (i != Integer.MIN_VALUE); + } + + private static void oldHeaderNameValidationAlgorithmAsciiString(AsciiString name) throws Exception { + byte[] array = name.array(); + for (int i = name.arrayOffset(), len = name.arrayOffset() + name.length(); i < len; i++) { + validateHeaderNameElement(array[i]); + } + } + + private static void validateHeaderNameElement(byte value) { + switch (value) { + case 0x1c: + case 0x1d: + case 0x1e: + case 
0x1f: + case 0x00: + case '\t': + case '\n': + case 0x0b: + case '\f': + case '\r': + case ' ': + case ',': + case ':': + case ';': + case '=': + throw VALIDATION_EXCEPTION; + default: + // Check to see if the character is not an ASCII character, or invalid + if (value < 0) { + throw VALIDATION_EXCEPTION; + } + } + } + + public static List validTokenChars() { + List list = new ArrayList(); + for (char c = '0'; c <= '9'; c++) { + list.add(c); + } + for (char c = 'a'; c <= 'z'; c++) { + list.add(c); + } + for (char c = 'A'; c <= 'Z'; c++) { + list.add(c); + } + + // Unreserved characters: + list.add('-'); + list.add('.'); + list.add('_'); + list.add('~'); + + // Token special characters: + list.add('!'); + list.add('#'); + list.add('$'); + list.add('%'); + list.add('&'); + list.add('\''); + list.add('*'); + list.add('+'); + list.add('^'); + list.add('`'); + list.add('|'); + + return list; + } + + @ParameterizedTest + @MethodSource("validTokenChars") + void allTokenCharsAreValidFirstCharHeaderName(char tokenChar) { + AsciiString asciiString = new AsciiString(new byte[] {(byte) tokenChar, 'a'}); + CharSequence charSequence = asCharSequence(asciiString); + String string = tokenChar + "a"; + + assertEquals(-1, validateToken(asciiString)); + assertEquals(-1, validateToken(charSequence)); + assertEquals(-1, validateToken(string)); + } + + @ParameterizedTest + @MethodSource("validTokenChars") + void allTokenCharsAreValidSecondCharHeaderName(char tokenChar) { + AsciiString asciiString = new AsciiString(new byte[] {'a', (byte) tokenChar}); + CharSequence charSequence = asCharSequence(asciiString); + String string = "a" + tokenChar; + + assertEquals(-1, validateToken(asciiString)); + assertEquals(-1, validateToken(charSequence)); + assertEquals(-1, validateToken(string)); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTest.java new file 
mode 100644 index 0000000..dd46ec8 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTest.java @@ -0,0 +1,106 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.List; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpHeadersTest { + + @Test + public void testRemoveTransferEncodingIgnoreCase() { + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.TRANSFER_ENCODING, "Chunked"); + assertFalse(message.headers().isEmpty()); + HttpUtil.setTransferEncodingChunked(message, false); + assertTrue(message.headers().isEmpty()); + } + + // Test for https://github.com/netty/netty/issues/1690 + @Test + public void testGetOperations() { + HttpHeaders headers = new DefaultHttpHeaders(); + headers.add(of("Foo"), 
of("1")); + headers.add(of("Foo"), of("2")); + + assertEquals("1", headers.get(of("Foo"))); + + List values = headers.getAll(of("Foo")); + assertEquals(2, values.size()); + assertEquals("1", values.get(0)); + assertEquals("2", values.get(1)); + } + + @Test + public void testEqualsIgnoreCase() { + assertThat(AsciiString.contentEqualsIgnoreCase(null, null), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase(null, "foo"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("bar", null), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", "fOo"), is(true)); + } + + @Test + public void testSetNullHeaderValueValidate() { + final HttpHeaders headers = new DefaultHttpHeaders(true); + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + headers.set(of("test"), (CharSequence) null); + } + }); + } + + @Test + public void testSetNullHeaderValueNotValidate() { + final HttpHeaders headers = new DefaultHttpHeaders(false); + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + headers.set(of("test"), (CharSequence) null); + } + }); + } + + @Test + public void testAddSelf() { + final HttpHeaders headers = new DefaultHttpHeaders(false); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + headers.add(headers); + } + }); + } + + @Test + public void testSetSelfIsNoOp() { + HttpHeaders headers = new DefaultHttpHeaders(false); + headers.add("name", "value"); + headers.set(headers); + assertEquals(1, headers.size()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTestUtils.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTestUtils.java new file mode 100644 index 0000000..03ff957 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpHeadersTestUtils.java @@ -0,0 +1,139 @@ +/* + * Copyright 
2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static io.netty.util.internal.StringUtil.COMMA; +import static io.netty.util.internal.StringUtil.DOUBLE_QUOTE; + +/** + * Utility methods for {@link HttpHeaders} related unit tests. + */ +public final class HttpHeadersTestUtils { + enum HeaderValue { + UNKNOWN("Unknown", 0), + ONE("One", 1), + TWO("Two", 2), + THREE("Three", 3), + FOUR("Four", 4), + FIVE("Five", 5), + SIX_QUOTED("Six,", 6), + SEVEN_QUOTED("Seven; , GMT", 7), + EIGHT("Eight", 8); + + private final int nr; + private final String value; + private List array; + + HeaderValue(final String value, final int nr) { + this.nr = nr; + this.value = value; + } + + @Override + public String toString() { + return value; + } + + public List asList() { + if (array == null) { + List list = new ArrayList(nr); + for (int i = 1; i <= nr; i++) { + list.add(of(i).toString()); + } + array = list; + } + return array; + } + + public List subset(int from) { + assert from > 0; + --from; + final int size = nr - from; + final int end = from + size; + List list = new ArrayList(size); + List fullList = asList(); + for (int i = from; i < end; ++i) { + list.add(fullList.get(i)); + } + return list; + } + + public String subsetAsCsvString(final int from) { + final 
List subset = subset(from); + return asCsv(subset); + } + + public String asCsv(final List arr) { + if (arr == null || arr.isEmpty()) { + return ""; + } + final StringBuilder sb = new StringBuilder(arr.size() * 10); + final int end = arr.size() - 1; + for (int i = 0; i < end; ++i) { + quoted(sb, arr.get(i)).append(COMMA); + } + quoted(sb, arr.get(end)); + return sb.toString(); + } + + public CharSequence asCsv() { + return asCsv(asList()); + } + + private static StringBuilder quoted(final StringBuilder sb, final CharSequence value) { + if (contains(value, COMMA) && !contains(value, DOUBLE_QUOTE)) { + return sb.append(DOUBLE_QUOTE).append(value).append(DOUBLE_QUOTE); + } + return sb.append(value); + } + + private static boolean contains(CharSequence value, char c) { + for (int i = 0; i < value.length(); ++i) { + if (value.charAt(i) == c) { + return true; + } + } + return false; + } + + private static final Map MAP; + + static { + final Map map = new HashMap(); + for (HeaderValue v : values()) { + final int nr = v.nr; + map.put(Integer.valueOf(nr), v); + } + MAP = map; + } + + public static HeaderValue of(final int nr) { + final HeaderValue v = MAP.get(Integer.valueOf(nr)); + return v == null ? UNKNOWN : v; + } + } + + public static CharSequence of(String s) { + return s; + } + + private HttpHeadersTestUtils() { } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java new file mode 100644 index 0000000..d5e96ee --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpInvalidMessageTest.java @@ -0,0 +1,122 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderResult; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpInvalidMessageTest { + + private final Random rnd = new Random(); + + @Test + public void testRequestWithBadInitialLine() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpRequestDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("GET / HTTP/1.0 with extra\r\n", CharsetUtil.UTF_8)); + HttpRequest req = ch.readInbound(); + DecoderResult dr = req.decoderResult(); + assertFalse(dr.isSuccess()); + assertTrue(dr.isFailure()); + ensureInboundTrafficDiscarded(ch); + } + + @Test + public void testRequestWithBadHeader() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpRequestDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("GET /maybe-something HTTP/1.0\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("Good_Name: Good Value\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("Bad=Name: Bad Value\r\n", CharsetUtil.UTF_8)); + 
ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.UTF_8)); + HttpRequest req = ch.readInbound(); + DecoderResult dr = req.decoderResult(); + assertFalse(dr.isSuccess()); + assertTrue(dr.isFailure()); + assertEquals("Good Value", req.headers().get(of("Good_Name"))); + assertEquals("/maybe-something", req.uri()); + ensureInboundTrafficDiscarded(ch); + } + + @Test + public void testResponseWithBadInitialLine() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.0 BAD_CODE Bad Server\r\n", CharsetUtil.UTF_8)); + HttpResponse res = ch.readInbound(); + DecoderResult dr = res.decoderResult(); + assertFalse(dr.isSuccess()); + assertTrue(dr.isFailure()); + ensureInboundTrafficDiscarded(ch); + } + + @Test + public void testResponseWithBadHeader() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.0 200 Maybe OK\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("Good_Name: Good Value\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("Bad=Name: Bad Value\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.UTF_8)); + HttpResponse res = ch.readInbound(); + DecoderResult dr = res.decoderResult(); + assertFalse(dr.isSuccess()); + assertTrue(dr.isFailure()); + assertEquals("Maybe OK", res.status().reasonPhrase()); + assertEquals("Good Value", res.headers().get(of("Good_Name"))); + ensureInboundTrafficDiscarded(ch); + } + + @Test + public void testBadChunk() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpRequestDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("GET / HTTP/1.0\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("Transfer-Encoding: chunked\r\n\r\n", CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer("BAD_LENGTH\r\n", CharsetUtil.UTF_8)); + + HttpRequest req = 
ch.readInbound(); + assertTrue(req.decoderResult().isSuccess()); + + LastHttpContent chunk = ch.readInbound(); + DecoderResult dr = chunk.decoderResult(); + assertFalse(dr.isSuccess()); + assertTrue(dr.isFailure()); + ensureInboundTrafficDiscarded(ch); + } + + private void ensureInboundTrafficDiscarded(EmbeddedChannel ch) { + // Generate a lot of random traffic to ensure that it's discarded silently. + byte[] data = new byte[1048576]; + rnd.nextBytes(data); + + ByteBuf buf = Unpooled.wrappedBuffer(data); + for (int i = 0; i < 4096; i ++) { + buf.setIndex(0, data.length); + ch.writeInbound(buf.retain()); + ch.checkException(); + assertNull(ch.readInbound()); + } + buf.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java new file mode 100644 index 0000000..ff0894b --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpObjectAggregatorTest.java @@ -0,0 +1,747 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.DecoderResultProvider; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mockito; + +import java.nio.channels.ClosedChannelException; +import java.util.List; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpObjectAggregatorTest { + + @Test + public void testAggregate() { + HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024); + EmbeddedChannel embedder = new EmbeddedChannel(aggr); + + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost"); + message.headers().set(of("X-Test"), true); + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)); + HttpContent chunk3 = new 
DefaultLastHttpContent(Unpooled.EMPTY_BUFFER); + assertFalse(embedder.writeInbound(message)); + assertFalse(embedder.writeInbound(chunk1)); + assertFalse(embedder.writeInbound(chunk2)); + + // this should trigger a channelRead event so return true + assertTrue(embedder.writeInbound(chunk3)); + assertTrue(embedder.finish()); + FullHttpRequest aggregatedMessage = embedder.readInbound(); + assertNotNull(aggregatedMessage); + + assertEquals(chunk1.content().readableBytes() + chunk2.content().readableBytes(), + HttpUtil.getContentLength(aggregatedMessage)); + assertEquals(Boolean.TRUE.toString(), aggregatedMessage.headers().get(of("X-Test"))); + checkContentBuffer(aggregatedMessage); + assertNull(embedder.readInbound()); + } + + private static void checkContentBuffer(FullHttpRequest aggregatedMessage) { + CompositeByteBuf buffer = (CompositeByteBuf) aggregatedMessage.content(); + assertEquals(2, buffer.numComponents()); + List buffers = buffer.decompose(0, buffer.capacity()); + assertEquals(2, buffers.size()); + for (ByteBuf buf: buffers) { + // This should be false as we decompose the buffer before to not have deep hierarchy + assertFalse(buf instanceof CompositeByteBuf); + } + aggregatedMessage.release(); + } + + @Test + public void testAggregateWithTrailer() { + HttpObjectAggregator aggr = new HttpObjectAggregator(1024 * 1024); + EmbeddedChannel embedder = new EmbeddedChannel(aggr); + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost"); + message.headers().set(of("X-Test"), true); + HttpUtil.setTransferEncodingChunked(message, true); + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)); + LastHttpContent trailer = new DefaultLastHttpContent(); + trailer.trailingHeaders().set(of("X-Trailer"), true); + + assertFalse(embedder.writeInbound(message)); + 
assertFalse(embedder.writeInbound(chunk1)); + assertFalse(embedder.writeInbound(chunk2)); + + // this should trigger a channelRead event so return true + assertTrue(embedder.writeInbound(trailer)); + assertTrue(embedder.finish()); + FullHttpRequest aggregatedMessage = embedder.readInbound(); + assertNotNull(aggregatedMessage); + + assertEquals(chunk1.content().readableBytes() + chunk2.content().readableBytes(), + HttpUtil.getContentLength(aggregatedMessage)); + assertEquals(Boolean.TRUE.toString(), aggregatedMessage.headers().get(of("X-Test"))); + assertEquals(Boolean.TRUE.toString(), aggregatedMessage.trailingHeaders().get(of("X-Trailer"))); + checkContentBuffer(aggregatedMessage); + assertNull(embedder.readInbound()); + } + + @Test + public void testOversizedRequest() { + final EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(4)); + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)); + final HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT; + + assertFalse(embedder.writeInbound(message)); + assertFalse(embedder.writeInbound(chunk1)); + assertFalse(embedder.writeInbound(chunk2)); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + assertFalse(embedder.isOpen()); + + assertThrows(ClosedChannelException.class, new Executable() { + @Override + public void execute() { + embedder.writeInbound(chunk3); + } + }); + + assertFalse(embedder.finish()); + } + + @Test + public void testOversizedRequestWithContentLengthAndDecoder() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4, 
false)); + assertFalse(embedder.writeInbound(Unpooled.copiedBuffer( + "PUT /upload HTTP/1.1\r\n" + + "Content-Length: 5\r\n\r\n", CharsetUtil.US_ASCII))); + + assertNull(embedder.readInbound()); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + assertTrue(embedder.isOpen()); + + assertFalse(embedder.writeInbound(Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 }))); + assertFalse(embedder.writeInbound(Unpooled.wrappedBuffer(new byte[] { 5 }))); + + assertNull(embedder.readOutbound()); + + assertFalse(embedder.writeInbound(Unpooled.copiedBuffer( + "PUT /upload HTTP/1.1\r\n" + + "Content-Length: 2\r\n\r\n", CharsetUtil.US_ASCII))); + + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + assertThat(response, instanceOf(LastHttpContent.class)); + ReferenceCountUtil.release(response); + + assertTrue(embedder.isOpen()); + + assertFalse(embedder.writeInbound(Unpooled.copiedBuffer(new byte[] { 1 }))); + assertNull(embedder.readOutbound()); + assertTrue(embedder.writeInbound(Unpooled.copiedBuffer(new byte[] { 2 }))); + assertNull(embedder.readOutbound()); + + FullHttpRequest request = embedder.readInbound(); + assertEquals(HttpVersion.HTTP_1_1, request.protocolVersion()); + assertEquals(HttpMethod.PUT, request.method()); + assertEquals("/upload", request.uri()); + assertEquals(2, HttpUtil.getContentLength(request)); + + byte[] actual = new byte[request.content().readableBytes()]; + request.content().readBytes(actual); + assertArrayEquals(new byte[] { 1, 2 }, actual); + request.release(); + + assertFalse(embedder.finish()); + } + + @Test + public void testOversizedRequestWithoutKeepAlive() { + // send an HTTP/1.0 request with no keep-alive header + HttpRequest message = new 
DefaultHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.PUT, "http://localhost"); + HttpUtil.setContentLength(message, 5); + checkOversizedRequest(message); + } + + @Test + public void testOversizedRequestWithContentLength() { + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + HttpUtil.setContentLength(message, 5); + checkOversizedRequest(message); + } + + private static void checkOversizedRequest(HttpRequest message) { + final EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(4)); + + assertFalse(embedder.writeInbound(message)); + HttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + assertThat(response, instanceOf(LastHttpContent.class)); + ReferenceCountUtil.release(response); + + if (serverShouldCloseConnection(message, response)) { + assertFalse(embedder.isOpen()); + + assertThrows(ClosedChannelException.class, new Executable() { + @Override + public void execute() { + embedder.writeInbound(new DefaultHttpContent(Unpooled.EMPTY_BUFFER)); + } + }); + + assertFalse(embedder.finish()); + } else { + assertTrue(embedder.isOpen()); + assertFalse(embedder.writeInbound(new DefaultHttpContent(Unpooled.copiedBuffer(new byte[8])))); + assertFalse(embedder.writeInbound(new DefaultHttpContent(Unpooled.copiedBuffer(new byte[8])))); + + // Now start a new message and ensure we will not reject it again. 
+ HttpRequest message2 = new DefaultHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.PUT, "http://localhost"); + HttpUtil.setContentLength(message, 2); + + assertFalse(embedder.writeInbound(message2)); + assertNull(embedder.readOutbound()); + assertFalse(embedder.writeInbound(new DefaultHttpContent(Unpooled.copiedBuffer(new byte[] { 1 })))); + assertNull(embedder.readOutbound()); + assertTrue(embedder.writeInbound(new DefaultLastHttpContent(Unpooled.copiedBuffer(new byte[] { 2 })))); + assertNull(embedder.readOutbound()); + + FullHttpRequest request = embedder.readInbound(); + assertEquals(message2.protocolVersion(), request.protocolVersion()); + assertEquals(message2.method(), request.method()); + assertEquals(message2.uri(), request.uri()); + assertEquals(2, HttpUtil.getContentLength(request)); + + byte[] actual = new byte[request.content().readableBytes()]; + request.content().readBytes(actual); + assertArrayEquals(new byte[] { 1, 2 }, actual); + request.release(); + + assertFalse(embedder.finish()); + } + } + + private static boolean serverShouldCloseConnection(HttpRequest message, HttpResponse response) { + // If the response wasn't keep-alive, the server should close the connection. + if (!HttpUtil.isKeepAlive(response)) { + return true; + } + // The connection should only be kept open if Expect: 100-continue is set, + // or if keep-alive is on. 
+ if (HttpUtil.is100ContinueExpected(message)) { + return false; + } + if (HttpUtil.isKeepAlive(message)) { + return false; + } + return true; + } + + @Test + public void testOversizedResponse() { + final EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(4)); + HttpResponse message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + final HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)); + + assertFalse(embedder.writeInbound(message)); + assertFalse(embedder.writeInbound(chunk1)); + + assertThrows(TooLongHttpContentException.class, new Executable() { + @Override + public void execute() { + embedder.writeInbound(chunk2); + } + }); + + assertFalse(embedder.isOpen()); + assertFalse(embedder.finish()); + } + + @Test + public void testInvalidConstructorUsage() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new HttpObjectAggregator(-1); + } + }); + } + + @Test + public void testInvalidMaxCumulationBufferComponents() { + final HttpObjectAggregator aggr = new HttpObjectAggregator(Integer.MAX_VALUE); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + aggr.setMaxCumulationBufferComponents(1); + } + }); + } + + @Test + public void testSetMaxCumulationBufferComponentsAfterInit() throws Exception { + final HttpObjectAggregator aggr = new HttpObjectAggregator(Integer.MAX_VALUE); + ChannelHandlerContext ctx = Mockito.mock(ChannelHandlerContext.class); + aggr.handlerAdded(ctx); + Mockito.verifyNoMoreInteractions(ctx); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + aggr.setMaxCumulationBufferComponents(10); + } + }); + } + + @Test + public void testAggregateTransferEncodingChunked() { + HttpObjectAggregator aggr = new 
HttpObjectAggregator(1024 * 1024); + EmbeddedChannel embedder = new EmbeddedChannel(aggr); + + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + message.headers().set(of("X-Test"), true); + message.headers().set(of("Transfer-Encoding"), of("Chunked")); + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test2", CharsetUtil.US_ASCII)); + HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT; + assertFalse(embedder.writeInbound(message)); + assertFalse(embedder.writeInbound(chunk1)); + assertFalse(embedder.writeInbound(chunk2)); + + // this should trigger a channelRead event so return true + assertTrue(embedder.writeInbound(chunk3)); + assertTrue(embedder.finish()); + FullHttpRequest aggregatedMessage = embedder.readInbound(); + assertNotNull(aggregatedMessage); + + assertEquals(chunk1.content().readableBytes() + chunk2.content().readableBytes(), + HttpUtil.getContentLength(aggregatedMessage)); + assertEquals(Boolean.TRUE.toString(), aggregatedMessage.headers().get(of("X-Test"))); + checkContentBuffer(aggregatedMessage); + assertNull(embedder.readInbound()); + } + + @Test + public void testBadRequest() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(1024 * 1024)); + ch.writeInbound(Unpooled.copiedBuffer("GET / HTTP/1.0 with extra\r\n", CharsetUtil.UTF_8)); + Object inbound = ch.readInbound(); + assertThat(inbound, is(instanceOf(FullHttpRequest.class))); + assertTrue(((DecoderResultProvider) inbound).decoderResult().isFailure()); + assertNull(ch.readInbound()); + ch.finish(); + } + + @Test + public void testBadResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder(), new HttpObjectAggregator(1024 * 1024)); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.0 BAD_CODE Bad Server\r\n", CharsetUtil.UTF_8)); + 
Object inbound = ch.readInbound(); + assertThat(inbound, is(instanceOf(FullHttpResponse.class))); + assertTrue(((DecoderResultProvider) inbound).decoderResult().isFailure()); + assertNull(ch.readInbound()); + ch.finish(); + } + + @Test + public void testOversizedRequestWith100Continue() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(8)); + + // Send an oversized request with 100 continue. + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + HttpUtil.set100ContinueExpected(message, true); + HttpUtil.setContentLength(message, 16); + + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("some", CharsetUtil.US_ASCII)); + HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT; + + // Send a request with 100-continue + large Content-Length header value. + assertFalse(embedder.writeInbound(message)); + + // The aggregator should respond with '413.' + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + // An ill-behaving client could continue to send data without a respect, and such data should be discarded. + assertFalse(embedder.writeInbound(chunk1)); + + // The aggregator should not close the connection because keep-alive is on. + assertTrue(embedder.isOpen()); + + // Now send a valid request. 
+ HttpRequest message2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + + assertFalse(embedder.writeInbound(message2)); + assertFalse(embedder.writeInbound(chunk2)); + assertTrue(embedder.writeInbound(chunk3)); + + FullHttpRequest fullMsg = embedder.readInbound(); + assertNotNull(fullMsg); + + assertEquals( + chunk2.content().readableBytes() + chunk3.content().readableBytes(), + HttpUtil.getContentLength(fullMsg)); + + assertEquals(HttpUtil.getContentLength(fullMsg), fullMsg.content().readableBytes()); + + fullMsg.release(); + assertFalse(embedder.finish()); + } + + @Test + public void testUnsupportedExpectHeaderExpectation() { + runUnsupportedExceptHeaderExceptionTest(true); + runUnsupportedExceptHeaderExceptionTest(false); + } + + private static void runUnsupportedExceptHeaderExceptionTest(final boolean close) { + final HttpObjectAggregator aggregator; + final int maxContentLength = 4; + if (close) { + aggregator = new HttpObjectAggregator(maxContentLength, true); + } else { + aggregator = new HttpObjectAggregator(maxContentLength); + } + final EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), aggregator); + + assertFalse(embedder.writeInbound(Unpooled.copiedBuffer( + "GET / HTTP/1.1\r\n" + + "Expect: chocolate=yummy\r\n" + + "Content-Length: 100\r\n\r\n", CharsetUtil.US_ASCII))); + assertNull(embedder.readInbound()); + + final FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.EXPECTATION_FAILED, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + response.release(); + + if (close) { + assertFalse(embedder.isOpen()); + } else { + // keep-alive is on by default in HTTP/1.1, so the connection should be still alive + assertTrue(embedder.isOpen()); + + // the decoder should be reset by the aggregator at this point and be able to decode the next request + assertTrue(embedder.writeInbound(Unpooled.copiedBuffer("GET / 
HTTP/1.1\r\n\r\n", CharsetUtil.US_ASCII))); + + final FullHttpRequest request = embedder.readInbound(); + assertThat(request.method(), is(HttpMethod.GET)); + assertThat(request.uri(), is("/")); + assertThat(request.content().readableBytes(), is(0)); + request.release(); + } + + assertFalse(embedder.finish()); + } + + @Test + public void testValidRequestWith100ContinueAndDecoder() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(100)); + embedder.writeInbound(Unpooled.copiedBuffer( + "GET /upload HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 0\r\n\r\n", CharsetUtil.US_ASCII)); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.CONTINUE, response.status()); + FullHttpRequest request = embedder.readInbound(); + assertFalse(request.headers().contains(HttpHeaderNames.EXPECT)); + request.release(); + response.release(); + assertFalse(embedder.finish()); + } + + @Test + public void testOversizedRequestWith100ContinueAndDecoder() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4)); + embedder.writeInbound(Unpooled.copiedBuffer( + "PUT /upload HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 100\r\n\r\n", CharsetUtil.US_ASCII)); + + assertNull(embedder.readInbound()); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + // Keep-alive is on by default in HTTP/1.1, so the connection should be still alive. + assertTrue(embedder.isOpen()); + + // The decoder should be reset by the aggregator at this point and be able to decode the next request. 
+ embedder.writeInbound(Unpooled.copiedBuffer("GET /max-upload-size HTTP/1.1\r\n\r\n", CharsetUtil.US_ASCII)); + + FullHttpRequest request = embedder.readInbound(); + assertThat(request.method(), is(HttpMethod.GET)); + assertThat(request.uri(), is("/max-upload-size")); + assertThat(request.content().readableBytes(), is(0)); + request.release(); + + assertFalse(embedder.finish()); + } + + @Test + public void testOversizedRequestWith100ContinueAndDecoderCloseConnection() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(4, true)); + embedder.writeInbound(Unpooled.copiedBuffer( + "PUT /upload HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 100\r\n\r\n", CharsetUtil.US_ASCII)); + + assertNull(embedder.readInbound()); + + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + // We are forcing the connection closed if an expectation is exceeded. + assertFalse(embedder.isOpen()); + assertFalse(embedder.finish()); + } + + @Test + public void testRequestAfterOversized100ContinueAndDecoder() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpRequestDecoder(), new HttpObjectAggregator(15)); + + // Write first request with Expect: 100-continue. + HttpRequest message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + HttpUtil.set100ContinueExpected(message, true); + HttpUtil.setContentLength(message, 16); + + HttpContent chunk1 = new DefaultHttpContent(Unpooled.copiedBuffer("some", CharsetUtil.US_ASCII)); + HttpContent chunk2 = new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)); + HttpContent chunk3 = LastHttpContent.EMPTY_LAST_CONTENT; + + // Send a request with 100-continue + large Content-Length header value. 
+ assertFalse(embedder.writeInbound(message)); + + // The aggregator should respond with '413'. + FullHttpResponse response = embedder.readOutbound(); + assertEquals(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE, response.status()); + assertEquals("0", response.headers().get(HttpHeaderNames.CONTENT_LENGTH)); + + // An ill-behaving client could continue to send data without a respect, and such data should be discarded. + assertFalse(embedder.writeInbound(chunk1)); + + // The aggregator should not close the connection because keep-alive is on. + assertTrue(embedder.isOpen()); + + // Now send a valid request. + HttpRequest message2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "http://localhost"); + + assertFalse(embedder.writeInbound(message2)); + assertFalse(embedder.writeInbound(chunk2)); + assertTrue(embedder.writeInbound(chunk3)); + + FullHttpRequest fullMsg = embedder.readInbound(); + assertNotNull(fullMsg); + + assertEquals( + chunk2.content().readableBytes() + chunk3.content().readableBytes(), + HttpUtil.getContentLength(fullMsg)); + + assertEquals(HttpUtil.getContentLength(fullMsg), fullMsg.content().readableBytes()); + + fullMsg.release(); + assertFalse(embedder.finish()); + } + + @Test + public void testReplaceAggregatedRequest() { + EmbeddedChannel embedder = new EmbeddedChannel(new HttpObjectAggregator(1024 * 1024)); + + Exception boom = new Exception("boom"); + HttpRequest req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "http://localhost"); + req.setDecoderResult(DecoderResult.failure(boom)); + + assertTrue(embedder.writeInbound(req) && embedder.finish()); + + FullHttpRequest aggregatedReq = embedder.readInbound(); + FullHttpRequest replacedReq = aggregatedReq.replace(Unpooled.EMPTY_BUFFER); + + assertEquals(replacedReq.decoderResult(), aggregatedReq.decoderResult()); + aggregatedReq.release(); + replacedReq.release(); + } + + @Test + public void testReplaceAggregatedResponse() { + EmbeddedChannel embedder = new 
EmbeddedChannel(new HttpObjectAggregator(1024 * 1024)); + + Exception boom = new Exception("boom"); + HttpResponse rep = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + rep.setDecoderResult(DecoderResult.failure(boom)); + + assertTrue(embedder.writeInbound(rep) && embedder.finish()); + + FullHttpResponse aggregatedRep = embedder.readInbound(); + FullHttpResponse replacedRep = aggregatedRep.replace(Unpooled.EMPTY_BUFFER); + + assertEquals(replacedRep.decoderResult(), aggregatedRep.decoderResult()); + aggregatedRep.release(); + replacedRep.release(); + } + + @Test + public void testSelectiveRequestAggregation() { + HttpObjectAggregator myPostAggregator = new HttpObjectAggregator(1024 * 1024) { + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + if (msg instanceof HttpRequest) { + HttpRequest request = (HttpRequest) msg; + HttpMethod method = request.method(); + + if (method.equals(HttpMethod.POST)) { + return true; + } + } + + return false; + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(myPostAggregator); + + try { + // Aggregate: POST + HttpRequest request1 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + request1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + assertTrue(channel.writeInbound(request1, content1, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting an aggregated response out + Object msg1 = channel.readInbound(); + try { + assertTrue(msg1 instanceof FullHttpRequest); + } finally { + ReferenceCountUtil.release(msg1); + } + + // Don't aggregate: non-POST + HttpRequest request2 = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/"); + HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + request2.headers().set(HttpHeaderNames.CONTENT_TYPE, 
HttpHeaderValues.TEXT_PLAIN); + + try { + assertTrue(channel.writeInbound(request2, content2, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting the same response objects out + assertSame(request2, channel.readInbound()); + assertSame(content2, channel.readInbound()); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound()); + } finally { + ReferenceCountUtil.release(request2); + ReferenceCountUtil.release(content2); + } + + assertFalse(channel.finish()); + } finally { + channel.close(); + } + } + + @Test + public void testSelectiveResponseAggregation() { + HttpObjectAggregator myTextAggregator = new HttpObjectAggregator(1024 * 1024) { + @Override + protected boolean isStartMessage(HttpObject msg) throws Exception { + if (msg instanceof HttpResponse) { + HttpResponse response = (HttpResponse) msg; + HttpHeaders headers = response.headers(); + + String contentType = headers.get(HttpHeaderNames.CONTENT_TYPE); + if (AsciiString.contentEqualsIgnoreCase(contentType, HttpHeaderValues.TEXT_PLAIN)) { + return true; + } + } + + return false; + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(myTextAggregator); + + try { + // Aggregate: text/plain + HttpResponse response1 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent content1 = new DefaultHttpContent(Unpooled.copiedBuffer("Hello, World!", CharsetUtil.UTF_8)); + response1.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.TEXT_PLAIN); + + assertTrue(channel.writeInbound(response1, content1, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting an aggregated response out + Object msg1 = channel.readInbound(); + try { + assertTrue(msg1 instanceof FullHttpResponse); + } finally { + ReferenceCountUtil.release(msg1); + } + + // Don't aggregate: application/json + HttpResponse response2 = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpContent content2 = new DefaultHttpContent(Unpooled.copiedBuffer("{key: 'value'}", 
CharsetUtil.UTF_8)); + response2.headers().set(HttpHeaderNames.CONTENT_TYPE, HttpHeaderValues.APPLICATION_JSON); + + try { + assertTrue(channel.writeInbound(response2, content2, LastHttpContent.EMPTY_LAST_CONTENT)); + + // Getting the same response objects out + assertSame(response2, channel.readInbound()); + assertSame(content2, channel.readInbound()); + assertSame(LastHttpContent.EMPTY_LAST_CONTENT, channel.readInbound()); + } finally { + ReferenceCountUtil.release(response2); + ReferenceCountUtil.release(content2); + } + + assertFalse(channel.finish()); + } finally { + channel.close(); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java new file mode 100644 index 0000000..35a0def --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestDecoderTest.java @@ -0,0 +1,654 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + +import static io.netty.handler.codec.http.HttpHeaderNames.*; +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpRequestDecoderTest { + private static final byte[] CONTENT_CRLF_DELIMITERS = createContent("\r\n"); + private static final byte[] CONTENT_LF_DELIMITERS = createContent("\n"); + private static final byte[] CONTENT_MIXED_DELIMITERS = createContent("\r\n", "\n"); + private static final int CONTENT_LENGTH = 8; + + private static byte[] createContent(String... 
lineDelimiters) { + String lineDelimiter; + String lineDelimiter2; + if (lineDelimiters.length == 2) { + lineDelimiter = lineDelimiters[0]; + lineDelimiter2 = lineDelimiters[1]; + } else { + lineDelimiter = lineDelimiters[0]; + lineDelimiter2 = lineDelimiters[0]; + } + return ("GET /some/path?foo=bar&wibble=eek HTTP/1.1" + "\r\n" + + "Upgrade: WebSocket" + lineDelimiter2 + + "Connection: Upgrade" + lineDelimiter + + "Host: localhost" + lineDelimiter2 + + "Origin: http://localhost:8080" + lineDelimiter + + "Sec-WebSocket-Key1: 10 28 8V7 8 48 0" + lineDelimiter2 + + "Sec-WebSocket-Key2: 8 Xt754O3Q3QW 0 _60" + lineDelimiter + + "Content-Length: " + CONTENT_LENGTH + lineDelimiter2 + + "\r\n" + + "12345678").getBytes(CharsetUtil.US_ASCII); + } + + @Test + public void testDecodeWholeRequestAtOnceCRLFDelimiters() { + testDecodeWholeRequestAtOnce(CONTENT_CRLF_DELIMITERS); + } + + @Test + public void testDecodeWholeRequestAtOnceLFDelimiters() { + testDecodeWholeRequestAtOnce(CONTENT_LF_DELIMITERS); + } + + @Test + public void testDecodeWholeRequestAtOnceMixedDelimiters() { + testDecodeWholeRequestAtOnce(CONTENT_MIXED_DELIMITERS); + } + + @Test + public void testDecodeWholeRequestAtOnceMixedDelimitersWithIntegerOverflowOnMaxBodySize() { + testDecodeWholeRequestAtOnce(CONTENT_MIXED_DELIMITERS, Integer.MAX_VALUE); + testDecodeWholeRequestAtOnce(CONTENT_MIXED_DELIMITERS, Integer.MAX_VALUE - 1); + } + + private static void testDecodeWholeRequestAtOnce(byte[] content) { + testDecodeWholeRequestAtOnce(content, HttpRequestDecoder.DEFAULT_MAX_HEADER_SIZE); + } + + private static void testDecodeWholeRequestAtOnce(byte[] content, int maxHeaderSize) { + EmbeddedChannel channel = + new EmbeddedChannel(new HttpRequestDecoder(HttpObjectDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH, + maxHeaderSize, + HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE)); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(content))); + HttpRequest req = channel.readInbound(); + assertNotNull(req); + 
checkHeaders(req.headers()); + LastHttpContent c = channel.readInbound(); + assertEquals(CONTENT_LENGTH, c.content().readableBytes()); + assertEquals( + Unpooled.wrappedBuffer(content, content.length - CONTENT_LENGTH, CONTENT_LENGTH), + c.content().readSlice(CONTENT_LENGTH)); + c.release(); + + assertFalse(channel.finish()); + assertNull(channel.readInbound()); + } + + private static void checkHeaders(HttpHeaders headers) { + assertEquals(7, headers.names().size()); + checkHeader(headers, "Upgrade", "WebSocket"); + checkHeader(headers, "Connection", "Upgrade"); + checkHeader(headers, "Host", "localhost"); + checkHeader(headers, "Origin", "http://localhost:8080"); + checkHeader(headers, "Sec-WebSocket-Key1", "10 28 8V7 8 48 0"); + checkHeader(headers, "Sec-WebSocket-Key2", "8 Xt754O3Q3QW 0 _60"); + checkHeader(headers, "Content-Length", String.valueOf(CONTENT_LENGTH)); + } + + private static void checkHeader(HttpHeaders headers, String name, String value) { + List header1 = headers.getAll(of(name)); + assertEquals(1, header1.size()); + assertEquals(value, header1.get(0)); + } + + @Test + public void testDecodeWholeRequestInMultipleStepsCRLFDelimiters() { + testDecodeWholeRequestInMultipleSteps(CONTENT_CRLF_DELIMITERS); + } + + @Test + public void testDecodeWholeRequestInMultipleStepsLFDelimiters() { + testDecodeWholeRequestInMultipleSteps(CONTENT_LF_DELIMITERS); + } + + @Test + public void testDecodeWholeRequestInMultipleStepsMixedDelimiters() { + testDecodeWholeRequestInMultipleSteps(CONTENT_MIXED_DELIMITERS); + } + + private static void testDecodeWholeRequestInMultipleSteps(byte[] content) { + for (int i = 1; i < content.length; i++) { + testDecodeWholeRequestInMultipleSteps(content, i); + } + } + + private static void testDecodeWholeRequestInMultipleSteps(byte[] content, int fragmentSize) { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + int headerLength = content.length - CONTENT_LENGTH; + + // split up the header + for (int a = 0; a 
< headerLength;) { + int amount = fragmentSize; + if (a + amount > headerLength) { + amount = headerLength - a; + } + + // if header is done it should produce an HttpRequest + channel.writeInbound(Unpooled.copiedBuffer(content, a, amount)); + a += amount; + } + + for (int i = CONTENT_LENGTH; i > 0; i --) { + // Should produce HttpContent + channel.writeInbound(Unpooled.copiedBuffer(content, content.length - i, 1)); + } + + HttpRequest req = channel.readInbound(); + assertNotNull(req); + checkHeaders(req.headers()); + + for (int i = CONTENT_LENGTH; i > 1; i --) { + HttpContent c = channel.readInbound(); + assertEquals(1, c.content().readableBytes()); + assertEquals(content[content.length - i], c.content().readByte()); + c.release(); + } + + LastHttpContent c = channel.readInbound(); + assertEquals(1, c.content().readableBytes()); + assertEquals(content[content.length - 1], c.content().readByte()); + c.release(); + + assertFalse(channel.finish()); + assertNull(channel.readInbound()); + } + + @Test + public void testMultiLineHeader() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String crlf = "\r\n"; + String request = "GET /some/path HTTP/1.1" + crlf + + "Host: localhost" + crlf + + "MyTestHeader: part1" + crlf + + " newLinePart2" + crlf + + "MyTestHeader2: part21" + crlf + + "\t newLinePart22" + + crlf + crlf; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertEquals("part1 newLinePart2", req.headers().get(of("MyTestHeader"))); + assertEquals("part21 newLinePart22", req.headers().get(of("MyTestHeader2"))); + + LastHttpContent c = channel.readInbound(); + c.release(); + + assertFalse(channel.finish()); + assertNull(channel.readInbound()); + } + + @Test + public void testEmptyHeaderValue() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String crlf = "\r\n"; + String request = "GET /some/path HTTP/1.1" + crlf + + 
"Host: localhost" + crlf + + "EmptyHeader:" + crlf + crlf; + channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII)); + HttpRequest req = channel.readInbound(); + assertEquals("", req.headers().get(of("EmptyHeader"))); + } + + @Test + public void test100Continue() { + HttpRequestDecoder decoder = new HttpRequestDecoder(); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + String oversized = + "PUT /file HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 1048576000\r\n\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(oversized, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpRequest.class))); + + // At this point, we assume that we sent '413 Entity Too Large' to the peer without closing the connection + // so that the client can try again. + decoder.reset(); + + String query = "GET /max-file-size HTTP/1.1\r\n\r\n"; + channel.writeInbound(Unpooled.copiedBuffer(query, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpRequest.class))); + assertThat(channel.readInbound(), is(instanceOf(LastHttpContent.class))); + + assertThat(channel.finish(), is(false)); + } + + @Test + public void test100ContinueWithBadClient() { + HttpRequestDecoder decoder = new HttpRequestDecoder(); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + String oversized = + "PUT /file HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 1048576000\r\n\r\n" + + "WAY_TOO_LARGE_DATA_BEGINS"; + + channel.writeInbound(Unpooled.copiedBuffer(oversized, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpRequest.class))); + + HttpContent prematureData = channel.readInbound(); + prematureData.release(); + + assertThat(channel.readInbound(), is(nullValue())); + + // At this point, we assume that we sent '413 Entity Too Large' to the peer without closing the connection + // so that the client can try again. 
+ decoder.reset(); + + String query = "GET /max-file-size HTTP/1.1\r\n\r\n"; + channel.writeInbound(Unpooled.copiedBuffer(query, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpRequest.class))); + assertThat(channel.readInbound(), is(instanceOf(LastHttpContent.class))); + + assertThat(channel.finish(), is(false)); + } + + @Test + public void testMessagesSplitBetweenMultipleBuffers() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String crlf = "\r\n"; + String str1 = "GET /some/path HTTP/1.1" + crlf + + "Host: localhost1" + crlf + crlf + + "GET /some/other/path HTTP/1.0" + crlf + + "Hos"; + String str2 = "t: localhost2" + crlf + + "content-length: 0" + crlf + crlf; + channel.writeInbound(Unpooled.copiedBuffer(str1, CharsetUtil.US_ASCII)); + HttpRequest req = channel.readInbound(); + assertEquals(HttpVersion.HTTP_1_1, req.protocolVersion()); + assertEquals("/some/path", req.uri()); + assertEquals(1, req.headers().size()); + assertTrue(AsciiString.contentEqualsIgnoreCase("localhost1", req.headers().get(HOST))); + LastHttpContent cnt = channel.readInbound(); + cnt.release(); + + channel.writeInbound(Unpooled.copiedBuffer(str2, CharsetUtil.US_ASCII)); + req = channel.readInbound(); + assertEquals(HttpVersion.HTTP_1_0, req.protocolVersion()); + assertEquals("/some/other/path", req.uri()); + assertEquals(2, req.headers().size()); + assertTrue(AsciiString.contentEqualsIgnoreCase("localhost2", req.headers().get(HOST))); + assertTrue(AsciiString.contentEqualsIgnoreCase("0", req.headers().get(HttpHeaderNames.CONTENT_LENGTH))); + cnt = channel.readInbound(); + cnt.release(); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void testTooLargeInitialLine() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder(10, 1024, 1024)); + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Host: localhost1\r\n\r\n"; + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, 
CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertThat(request.decoderResult().cause(), instanceOf(TooLongHttpLineException.class)); + assertFalse(channel.finish()); + } + + @Test + public void testTooLargeInitialLineWithWSOnly() { + testTooLargeInitialLineWithControlCharsOnly(" "); + } + + @Test + public void testTooLargeInitialLineWithCRLFOnly() { + testTooLargeInitialLineWithControlCharsOnly("\r\n\r\n\r\n\r\n\r\n\r\n\r\n\r\n"); + } + + private static void testTooLargeInitialLineWithControlCharsOnly(String controlChars) { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder(15, 1024, 1024)); + String requestStr = controlChars + "GET / HTTP/1.1\r\n" + + "Host: localhost1\r\n\r\n"; + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertTrue(request.decoderResult().cause() instanceof TooLongHttpLineException); + assertFalse(channel.finish()); + } + + @Test + public void testInitialLineWithLeadingControlChars() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + String crlf = "\r\n"; + String request = crlf + "GET /some/path HTTP/1.1" + crlf + + "Host: localhost" + crlf + crlf; + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(request, CharsetUtil.US_ASCII))); + HttpRequest req = channel.readInbound(); + assertEquals(HttpMethod.GET, req.method()); + assertEquals("/some/path", req.uri()); + assertEquals(HttpVersion.HTTP_1_1, req.protocolVersion()); + assertTrue(channel.finishAndReleaseAll()); + } + + @Test + public void testTooLargeHeaders() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder(1024, 10, 1024)); + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Host: localhost1\r\n\r\n"; + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, 
CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isFailure()); + assertTrue(request.decoderResult().cause() instanceof TooLongHttpHeaderException); + assertFalse(channel.finish()); + } + + @Test + public void testHeaderNameStartsWithControlChar1c() { + testHeaderNameStartsWithControlChar(0x1c); + } + + @Test + public void testHeaderNameStartsWithControlChar1d() { + testHeaderNameStartsWithControlChar(0x1d); + } + + @Test + public void testHeaderNameStartsWithControlChar1e() { + testHeaderNameStartsWithControlChar(0x1e); + } + + @Test + public void testHeaderNameStartsWithControlChar1f() { + testHeaderNameStartsWithControlChar(0x1f); + } + + @Test + public void testHeaderNameStartsWithControlChar0c() { + testHeaderNameStartsWithControlChar(0x0c); + } + + private void testHeaderNameStartsWithControlChar(int controlChar) { + ByteBuf requestBuffer = Unpooled.buffer(); + requestBuffer.writeCharSequence("GET /some/path HTTP/1.1\r\n" + + "Host: netty.io\r\n", CharsetUtil.US_ASCII); + requestBuffer.writeByte(controlChar); + requestBuffer.writeCharSequence("Transfer-Encoding: chunked\r\n\r\n", CharsetUtil.US_ASCII); + testInvalidHeaders0(requestBuffer); + } + + @Test + public void testHeaderNameEndsWithControlChar1c() { + testHeaderNameEndsWithControlChar(0x1c); + } + + @Test + public void testHeaderNameEndsWithControlChar1d() { + testHeaderNameEndsWithControlChar(0x1d); + } + + @Test + public void testHeaderNameEndsWithControlChar1e() { + testHeaderNameEndsWithControlChar(0x1e); + } + + @Test + public void testHeaderNameEndsWithControlChar1f() { + testHeaderNameEndsWithControlChar(0x1f); + } + + @Test + public void testHeaderNameEndsWithControlChar0c() { + testHeaderNameEndsWithControlChar(0x0c); + } + + private void testHeaderNameEndsWithControlChar(int controlChar) { + ByteBuf requestBuffer = Unpooled.buffer(); + requestBuffer.writeCharSequence("GET /some/path HTTP/1.1\r\n" + + "Host: netty.io\r\n", 
CharsetUtil.US_ASCII); + requestBuffer.writeCharSequence("Transfer-Encoding", CharsetUtil.US_ASCII); + requestBuffer.writeByte(controlChar); + requestBuffer.writeCharSequence(": chunked\r\n\r\n", CharsetUtil.US_ASCII); + testInvalidHeaders0(requestBuffer); + } + + @Test + public void testWhitespace() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Transfer-Encoding : chunked\r\n" + + "Host: netty.io\r\n\r\n"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testWhitespaceInTransferEncoding01() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Transfer-Encoding : chunked\r\n" + + "Content-Length: 1\r\n" + + "Host: netty.io\r\n\r\n" + + "a"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testWhitespaceInTransferEncoding02() { + String requestStr = "POST / HTTP/1.1" + + "Transfer-Encoding : chunked\r\n" + + "Host: target.com" + + "Content-Length: 65\r\n\r\n" + + "0\r\n\r\n" + + "GET /maliciousRequest HTTP/1.1\r\n" + + "Host: evilServer.com\r\n" + + "Foo: x"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testHeaderWithNoValueAndMissingColon() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 0\r\n" + + "Host:\r\n" + + "netty.io\r\n\r\n"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testMultipleContentLengthHeaders() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1\r\n" + + "Content-Length: 0\r\n\r\n" + + "b"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testMultipleContentLengthHeaders2() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1\r\n" + + "Connection: close\r\n" + + "Content-Length: 0\r\n\r\n" + + "b"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testContentLengthHeaderWithCommaValue() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1,1\r\n\r\n" + + "b"; + testInvalidHeaders0(requestStr); + } + + @Test + public void 
testMultipleContentLengthHeadersWithFolding() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n" + + "Content-Length: 5\r\n" + + "Content-Length:\r\n" + + "\t6\r\n\r\n" + + "123456"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testContentLengthAndTransferEncodingHeadersWithVerticalTab() { + testContentLengthAndTransferEncodingHeadersWithInvalidSeparator((char) 0x0b, false); + testContentLengthAndTransferEncodingHeadersWithInvalidSeparator((char) 0x0b, true); + } + + @Test + public void testContentLengthAndTransferEncodingHeadersWithCR() { + testContentLengthAndTransferEncodingHeadersWithInvalidSeparator((char) 0x0d, false); + testContentLengthAndTransferEncodingHeadersWithInvalidSeparator((char) 0x0d, true); + } + + private static void testContentLengthAndTransferEncodingHeadersWithInvalidSeparator( + char separator, boolean extraLine) { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n" + + "Content-Length: 9\r\n" + + "Transfer-Encoding:" + separator + "chunked\r\n\r\n" + + (extraLine ? 
"0\r\n\r\n" : "") + + "something\r\n\r\n"; + testInvalidHeaders0(requestStr); + } + + @Test + public void testContentLengthHeaderAndChunked() { + String requestStr = "POST / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: close\r\n" + + "Content-Length: 5\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "0\r\n\r\n"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertFalse(request.decoderResult().isFailure()); + assertTrue(request.headers().names().contains("Transfer-Encoding")); + assertTrue(request.headers().contains("Transfer-Encoding", "chunked", false)); + assertFalse(request.headers().contains("Content-Length")); + LastHttpContent c = channel.readInbound(); + c.release(); + assertFalse(channel.finish()); + } + + @Test + public void testOrderOfHeadersWithContentLength() { + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Content-Length: 5\r\n" + + "Connection: close\r\n\r\n" + + "hello"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + List headers = new ArrayList(); + for (Map.Entry header : request.headers()) { + headers.add(header.getKey()); + } + assertEquals(Arrays.asList("Host", "Content-Length", "Connection"), headers, "ordered headers"); + } + + @Test + public void testHttpMessageDecoderResult() { + String requestStr = "PUT /some/path HTTP/1.1\r\n" + + "Content-Length: 11\r\n" + + "Connection: close\r\n\r\n" + + "Lorem ipsum"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + 
assertTrue(request.decoderResult().isSuccess()); + assertThat(request.decoderResult(), instanceOf(HttpMessageDecoderResult.class)); + HttpMessageDecoderResult decoderResult = (HttpMessageDecoderResult) request.decoderResult(); + assertThat(decoderResult.initialLineLength(), is(23)); + assertThat(decoderResult.headerSize(), is(35)); + assertThat(decoderResult.totalSize(), is(58)); + HttpContent c = channel.readInbound(); + c.release(); + assertFalse(channel.finish()); + } + + /** + * RFC 9112 define the header field + * syntax thusly, where the field value is bracketed by optional whitespace: + *

+     *     field-line   = field-name ":" OWS field-value OWS
+     * 
+ * Meanwhile, RFC 9110 says that + * "optional whitespace" (OWS) is defined as "zero or more linear whitespace octets". + * And a "linear whitespace octet" is defined in the ABNF as either a space or a tab character. + */ + @Test + void headerValuesMayBeBracketedByZeroOrMoreWhitespace() throws Exception { + String requestStr = "GET / HTTP/1.1\r\n" + + "Host:example.com\r\n" + // zero whitespace + "X-0-Header: x0\r\n" + // two whitespace + "X-1-Header:\tx1\r\n" + // tab whitespace + "X-2-Header: \t x2\r\n" + // mixed whitespace + "X-3-Header:x3\t \r\n" + // whitespace after the value + "\r\n"; + HttpRequestDecoder decoder = new HttpRequestDecoder(); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpRequest request = channel.readInbound(); + assertTrue(request.decoderResult().isSuccess()); + HttpHeaders headers = request.headers(); + assertEquals("example.com", headers.get("Host")); + assertEquals("x0", headers.get("X-0-Header")); + assertEquals("x1", headers.get("X-1-Header")); + assertEquals("x2", headers.get("X-2-Header")); + assertEquals("x3", headers.get("X-3-Header")); + LastHttpContent last = channel.readInbound(); + assertEquals(LastHttpContent.EMPTY_LAST_CONTENT, last); + last.release(); + assertFalse(channel.finish()); + } + + private static void testInvalidHeaders0(String requestStr) { + testInvalidHeaders0(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII)); + } + + private static void testInvalidHeaders0(ByteBuf requestBuffer) { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder()); + assertTrue(channel.writeInbound(requestBuffer)); + HttpRequest request = channel.readInbound(); + assertThat(request.decoderResult().cause(), instanceOf(IllegalArgumentException.class)); + assertTrue(request.decoderResult().isFailure()); + assertFalse(channel.finish()); + } +} diff --git 
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java new file mode 100644 index 0000000..bcbcd49 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpRequestEncoderTest.java @@ -0,0 +1,432 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderResult; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.charset.Charset; +import java.util.concurrent.ExecutionException; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + */ +public class HttpRequestEncoderTest { + + @SuppressWarnings("deprecation") + private static ByteBuf[] getBuffers() { + return new ByteBuf[]{ + Unpooled.buffer(128).order(ByteOrder.BIG_ENDIAN), + Unpooled.buffer(128).order(ByteOrder.LITTLE_ENDIAN), + Unpooled.wrappedBuffer(ByteBuffer.allocate(128).order(ByteOrder.BIG_ENDIAN)).resetWriterIndex(), + Unpooled.wrappedBuffer(ByteBuffer.allocate(128).order(ByteOrder.LITTLE_ENDIAN)).resetWriterIndex() + }; + } + + @Test + public void testUriWithoutPath() throws Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.GET, "http://localhost")); + String req = buffer.toString(Charset.forName("US-ASCII")); + assertEquals("GET http://localhost/ HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testUriWithoutPath2() throws 
Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http://localhost:9999?p1=v1")); + String req = buffer.toString(Charset.forName("US-ASCII")); + assertEquals("GET http://localhost:9999/?p1=v1 HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testUriWithEmptyPath() throws Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http://localhost:9999/?p1=v1")); + String req = buffer.toString(Charset.forName("US-ASCII")); + assertEquals("GET http://localhost:9999/?p1=v1 HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testUriWithPath() throws Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.GET, "http://localhost/")); + String req = buffer.toString(Charset.forName("US-ASCII")); + assertEquals("GET http://localhost/ HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testAbsPath() throws Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.GET, "/")); + String req = buffer.toString(Charset.forName("US-ASCII")); + assertEquals("GET / HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testEmptyAbsPath() throws Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.GET, "")); + String req = buffer.toString(Charset.forName("US-ASCII")); 
+ assertEquals("GET / HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testQueryStringPath() throws Exception { + for (ByteBuf buffer : getBuffers()) { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + encoder.encodeInitialLine(buffer, new DefaultHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.GET, "/?url=http://example.com")); + String req = buffer.toString(Charset.forName("US-ASCII")); + assertEquals("GET /?url=http://example.com HTTP/1.1\r\n", req); + buffer.release(); + } + } + + @Test + public void testEmptyReleasedBufferShouldNotWriteEmptyBufferToChannel() throws Exception { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + final EmbeddedChannel channel = new EmbeddedChannel(encoder); + final ByteBuf buf = Unpooled.buffer(); + buf.release(); + ExecutionException e = assertThrows(ExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeAndFlush(buf).get(); + } + }); + assertThat(e.getCause().getCause(), is(instanceOf(IllegalReferenceCountException.class))); + + channel.finishAndReleaseAll(); + } + + @Test + public void testEmptyBufferShouldPassThrough() throws Exception { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + EmbeddedChannel channel = new EmbeddedChannel(encoder); + ByteBuf buffer = Unpooled.buffer(); + channel.writeAndFlush(buffer).get(); + channel.finishAndReleaseAll(); + assertEquals(0, buffer.refCnt()); + } + + @Test + public void testEmptyContentsChunked() throws Exception { + testEmptyContents(true, false); + } + + @Test + public void testEmptyContentsChunkedWithTrailers() throws Exception { + testEmptyContents(true, true); + } + + @Test + public void testEmptyContentsNotChunked() throws Exception { + testEmptyContents(false, false); + } + + @Test + public void testEmptyContentNotsChunkedWithTrailers() throws Exception { + testEmptyContents(false, true); + } + + // this is not using Full types on purpose!!! 
+ private static class CustomFullHttpRequest extends DefaultHttpRequest implements LastHttpContent { + private final ByteBuf content; + private final HttpHeaders trailingHeader; + + CustomFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, ByteBuf content) { + this(httpVersion, method, uri, content, true); + } + + CustomFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, boolean validateHeaders) { + super(httpVersion, method, uri, validateHeaders); + this.content = checkNotNull(content, "content"); + trailingHeader = new DefaultHttpHeaders(validateHeaders); + } + + private CustomFullHttpRequest(HttpVersion httpVersion, HttpMethod method, String uri, + ByteBuf content, HttpHeaders headers, HttpHeaders trailingHeader) { + super(httpVersion, method, uri, headers); + this.content = checkNotNull(content, "content"); + this.trailingHeader = checkNotNull(trailingHeader, "trailingHeader"); + } + + @Override + public HttpHeaders trailingHeaders() { + return trailingHeader; + } + + @Override + public ByteBuf content() { + return content; + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public CustomFullHttpRequest retain() { + content.retain(); + return this; + } + + @Override + public CustomFullHttpRequest retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public CustomFullHttpRequest touch() { + content.touch(); + return this; + } + + @Override + public CustomFullHttpRequest touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public CustomFullHttpRequest setProtocolVersion(HttpVersion version) { + super.setProtocolVersion(version); + return this; + } + + @Override + public CustomFullHttpRequest setMethod(HttpMethod method) { + 
super.setMethod(method); + return this; + } + + @Override + public CustomFullHttpRequest setUri(String uri) { + super.setUri(uri); + return this; + } + + @Override + public CustomFullHttpRequest copy() { + return replace(content().copy()); + } + + @Override + public CustomFullHttpRequest duplicate() { + return replace(content().duplicate()); + } + + @Override + public CustomFullHttpRequest retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public CustomFullHttpRequest replace(ByteBuf content) { + CustomFullHttpRequest request = new CustomFullHttpRequest(protocolVersion(), method(), uri(), content, + headers().copy(), trailingHeaders().copy()); + request.setDecoderResult(decoderResult()); + return request; + } + } + + @Test + public void testCustomMessageEmptyLastContent() { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + EmbeddedChannel channel = new EmbeddedChannel(encoder); + HttpRequest customMsg = new CustomFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "/", Unpooled.EMPTY_BUFFER); + assertTrue(channel.writeOutbound(customMsg)); + // Ensure we only produce ByteBuf instances. 
+ ByteBuf head = channel.readOutbound(); + assertTrue(head.release()); + assertNull(channel.readOutbound()); + assertFalse(channel.finish()); + } + + private void testEmptyContents(boolean chunked, boolean trailers) throws Exception { + HttpRequestEncoder encoder = new HttpRequestEncoder(); + EmbeddedChannel channel = new EmbeddedChannel(encoder); + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + if (chunked) { + HttpUtil.setTransferEncodingChunked(request, true); + } + assertTrue(channel.writeOutbound(request)); + + ByteBuf contentBuffer = Unpooled.buffer(); + assertTrue(channel.writeOutbound(new DefaultHttpContent(contentBuffer))); + + ByteBuf lastContentBuffer = Unpooled.buffer(); + LastHttpContent last = new DefaultLastHttpContent(lastContentBuffer); + if (trailers) { + last.trailingHeaders().set("X-Netty-Test", "true"); + } + assertTrue(channel.writeOutbound(last)); + + // Ensure we only produce ByteBuf instances. + ByteBuf head = channel.readOutbound(); + assertTrue(head.release()); + + ByteBuf content = channel.readOutbound(); + content.release(); + + ByteBuf lastContent = channel.readOutbound(); + lastContent.release(); + assertFalse(channel.finish()); + } + + /** + * A test that checks for an NPE that would occur when processing {@link LastHttpContent#EMPTY_LAST_CONTENT} + * if a certain initialization order of {@link EmptyHttpHeaders} would occur. + */ + @Test + public void testForChunkedRequestNpe() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestEncoder()); + assertTrue(channel.writeOutbound(new CustomHttpRequest())); + assertTrue(channel.writeOutbound(new DefaultHttpContent(Unpooled.copiedBuffer("test", CharsetUtil.US_ASCII)))); + assertTrue(channel.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + assertTrue(channel.finishAndReleaseAll()); + } + + /** + * This class is required to trigger the desired initialization order of {@link EmptyHttpHeaders}. 
+ * If {@link DefaultHttpRequest} is used, the {@link HttpHeaders} class will be initialized before {@link HttpUtil} + * and the test won't trigger the original issue. + */ + private static final class CustomHttpRequest implements HttpRequest { + + @Override + public DecoderResult decoderResult() { + return DecoderResult.SUCCESS; + } + + @Override + public void setDecoderResult(DecoderResult result) { + } + + @Override + public DecoderResult getDecoderResult() { + return decoderResult(); + } + + @Override + public HttpVersion getProtocolVersion() { + return HttpVersion.HTTP_1_1; + } + + @Override + public HttpVersion protocolVersion() { + return getProtocolVersion(); + } + + @Override + public HttpHeaders headers() { + DefaultHttpHeaders headers = new DefaultHttpHeaders(); + headers.add("Transfer-Encoding", "chunked"); + return headers; + } + + @Override + public HttpMethod getMethod() { + return HttpMethod.POST; + } + + @Override + public HttpMethod method() { + return getMethod(); + } + + @Override + public HttpRequest setMethod(HttpMethod method) { + return this; + } + + @Override + public String getUri() { + return "/"; + } + + @Override + public String uri() { + return "/"; + } + + @Override + public HttpRequest setUri(String uri) { + return this; + } + + @Override + public HttpRequest setProtocolVersion(HttpVersion version) { + return this; + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java new file mode 100644 index 0000000..93d847d --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseDecoderTest.java @@ -0,0 +1,1125 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.PrematureChannelClosureException; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.ArrayList; +import java.util.List; +import java.util.Map; +import java.util.Random; +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpResponseDecoderTest { + + /** + * The size of headers should be calculated correctly even if a single header is split into multiple fragments. 
+ * @see #3445 + */ + @Test + public void testMaxHeaderSize1() { + final int maxHeaderSize = 8192; + + final EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder(4096, maxHeaderSize, 8192)); + final char[] bytes = new char[maxHeaderSize / 2 - 4]; + Arrays.fill(bytes, 'a'); + + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\n", CharsetUtil.US_ASCII)); + + // Write two 4096-byte headers (= 8192 bytes) + ch.writeInbound(Unpooled.copiedBuffer("A:", CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer(bytes, CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.copiedBuffer("B:", CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer(bytes, CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertNull(res.decoderResult().cause()); + assertTrue(res.decoderResult().isSuccess()); + + assertNull(ch.readInbound()); + assertTrue(ch.finish()); + assertThat(ch.readInbound(), instanceOf(LastHttpContent.class)); + } + + /** + * Complementary test case of {@link #testMaxHeaderSize1()} When it actually exceeds the maximum, it should fail. 
+ */ + @Test + public void testMaxHeaderSize2() { + final int maxHeaderSize = 8192; + + final EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder(4096, maxHeaderSize, 8192)); + final char[] bytes = new char[maxHeaderSize / 2 - 2]; + Arrays.fill(bytes, 'a'); + + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\n", CharsetUtil.US_ASCII)); + + // Write a 4096-byte header and a 4097-byte header to test an off-by-one case (= 8193 bytes) + ch.writeInbound(Unpooled.copiedBuffer("A:", CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer(bytes, CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.copiedBuffer("B: ", CharsetUtil.US_ASCII)); // Note an extra space. + ch.writeInbound(Unpooled.copiedBuffer(bytes, CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)); + ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertTrue(res.decoderResult().cause() instanceof TooLongHttpHeaderException); + + assertFalse(ch.finish()); + assertNull(ch.readInbound()); + } + + @Test + public void testResponseChunked() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", + CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + byte[] data = new byte[64]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + + for (int i = 0; i < 10; i++) { + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(data.length) + "\r\n", + CharsetUtil.US_ASCII))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); + HttpContent content = ch.readInbound(); + assertEquals(data.length, 
content.content().readableBytes()); + + byte[] decodedData = new byte[data.length]; + content.content().readBytes(decodedData); + assertArrayEquals(data, decodedData); + content.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + } + + // Write the last chunk. + ch.writeInbound(Unpooled.copiedBuffer("0\r\n\r\n", CharsetUtil.US_ASCII)); + + // Ensure the last chunk was decoded. + LastHttpContent content = ch.readInbound(); + assertFalse(content.content().isReadable()); + content.release(); + + ch.finish(); + assertNull(ch.readInbound()); + } + + @Test + public void testResponseChunkedWithValidUncommonPatterns() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", + CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + byte[] data = new byte[1]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + + // leading whitespace, trailing whitespace + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + " \r\n", + CharsetUtil.US_ASCII))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); + HttpContent content = ch.readInbound(); + assertEquals(data.length, content.content().readableBytes()); + + byte[] decodedData = new byte[data.length]; + content.content().readBytes(decodedData); + assertArrayEquals(data, decodedData); + content.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + // leading whitespace, trailing control char + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + "\0\r\n", + CharsetUtil.US_ASCII))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); + content = ch.readInbound(); + 
assertEquals(data.length, content.content().readableBytes()); + + decodedData = new byte[data.length]; + content.content().readBytes(decodedData); + assertArrayEquals(data, decodedData); + content.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + // leading whitespace, trailing semicolon + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + ";\r\n", + CharsetUtil.US_ASCII))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); + content = ch.readInbound(); + assertEquals(data.length, content.content().readableBytes()); + + decodedData = new byte[data.length]; + content.content().readBytes(decodedData); + assertArrayEquals(data, decodedData); + content.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + // Write the last chunk. + ch.writeInbound(Unpooled.copiedBuffer("0\r\n\r\n", CharsetUtil.US_ASCII)); + + // Ensure the last chunk was decoded. 
+ LastHttpContent lastContent = ch.readInbound(); + assertFalse(lastContent.content().isReadable()); + lastContent.release(); + + ch.finish(); + assertNull(ch.readInbound()); + } + + @Test + public void testResponseChunkedWithControlChars() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", + CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + byte[] data = new byte[1]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(" " + Integer.toHexString(data.length) + " \r\n", + CharsetUtil.US_ASCII))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); + HttpContent content = ch.readInbound(); + assertEquals(data.length, content.content().readableBytes()); + + byte[] decodedData = new byte[data.length]; + content.content().readBytes(decodedData); + assertArrayEquals(data, decodedData); + content.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + // Write the last chunk. + ch.writeInbound(Unpooled.copiedBuffer("0\r\n\r\n", CharsetUtil.US_ASCII)); + + // Ensure the last chunk was decoded. 
+ LastHttpContent lastContent = ch.readInbound(); + assertFalse(lastContent.content().isReadable()); + lastContent.release(); + + assertFalse(ch.finish()); + assertNull(ch.readInbound()); + } + + @Test + public void testResponseDisallowPartialChunks() { + HttpResponseDecoder decoder = new HttpResponseDecoder( + HttpObjectDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH, + HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE, + HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE, + HttpObjectDecoder.DEFAULT_VALIDATE_HEADERS, + HttpObjectDecoder.DEFAULT_INITIAL_BUFFER_SIZE, + HttpObjectDecoder.DEFAULT_ALLOW_DUPLICATE_CONTENT_LENGTHS, + false); + EmbeddedChannel ch = new EmbeddedChannel(decoder); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n"; + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(headers, CharsetUtil.US_ASCII))); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + byte[] chunkBytes = new byte[10]; + Random random = new Random(); + random.nextBytes(chunkBytes); + final ByteBuf chunk = ch.alloc().buffer().writeBytes(chunkBytes); + final int chunkSize = chunk.readableBytes(); + ByteBuf partialChunk1 = chunk.retainedSlice(0, 5); + ByteBuf partialChunk2 = chunk.retainedSlice(5, 5); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(chunkSize) + + "\r\n", CharsetUtil.US_ASCII))); + assertFalse(ch.writeInbound(partialChunk1)); + assertTrue(ch.writeInbound(partialChunk2)); + + HttpContent content = ch.readInbound(); + assertEquals(chunk, content.content()); + content.release(); + chunk.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + // Write the last chunk. + assertTrue(ch.writeInbound(Unpooled.copiedBuffer("0\r\n\r\n", CharsetUtil.US_ASCII))); + + // Ensure the last chunk was decoded. 
+ HttpContent lastContent = ch.readInbound(); + assertFalse(lastContent.content().isReadable()); + lastContent.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testResponseChunkedExceedMaxChunkSize() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder(4096, 8192, 32)); + ch.writeInbound( + Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + byte[] data = new byte[64]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + + for (int i = 0; i < 10; i++) { + assertFalse(ch.writeInbound(Unpooled.copiedBuffer(Integer.toHexString(data.length) + "\r\n", + CharsetUtil.US_ASCII))); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer(data))); + + byte[] decodedData = new byte[data.length]; + HttpContent content = ch.readInbound(); + assertEquals(32, content.content().readableBytes()); + content.content().readBytes(decodedData, 0, 32); + content.release(); + + content = ch.readInbound(); + assertEquals(32, content.content().readableBytes()); + + content.content().readBytes(decodedData, 32, 32); + + assertArrayEquals(data, decodedData); + content.release(); + + assertFalse(ch.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + } + + // Write the last chunk. + ch.writeInbound(Unpooled.copiedBuffer("0\r\n\r\n", CharsetUtil.US_ASCII)); + + // Ensure the last chunk was decoded. 
+ LastHttpContent content = ch.readInbound(); + assertFalse(content.content().isReadable()); + content.release(); + + ch.finish(); + assertNull(ch.readInbound()); + } + + @Test + public void testClosureWithoutContentLength1() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\n\r\n", CharsetUtil.US_ASCII)); + + // Read the response headers. + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + assertThat(ch.readInbound(), is(nullValue())); + + // Close the connection without sending anything. + assertTrue(ch.finish()); + + // The decoder should still produce the last content. + LastHttpContent content = ch.readInbound(); + assertThat(content.content().isReadable(), is(false)); + content.release(); + + // But nothing more. + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testClosureWithoutContentLength2() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + + // Write the partial response. + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\n\r\n12345678", CharsetUtil.US_ASCII)); + + // Read the response headers. + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + // Read the partial content. + HttpContent content = ch.readInbound(); + assertThat(content.content().toString(CharsetUtil.US_ASCII), is("12345678")); + assertThat(content, is(not(instanceOf(LastHttpContent.class)))); + content.release(); + + assertThat(ch.readInbound(), is(nullValue())); + + // Close the connection. + assertTrue(ch.finish()); + + // The decoder should still produce the last content. 
+ LastHttpContent lastContent = ch.readInbound(); + assertThat(lastContent.content().isReadable(), is(false)); + lastContent.release(); + + // But nothing more. + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testPrematureClosureWithChunkedEncoding1() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound( + Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", CharsetUtil.US_ASCII)); + + // Read the response headers. + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + assertThat(ch.readInbound(), is(nullValue())); + + // Close the connection without sending anything. + ch.finish(); + // The decoder should not generate the last chunk because it's closed prematurely. + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testPrematureClosureWithChunkedEncoding2() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + + // Write the partial response. + ch.writeInbound(Unpooled.copiedBuffer( + "HTTP/1.1 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n8\r\n12345678", CharsetUtil.US_ASCII)); + + // Read the response headers. + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + assertThat(res.headers().get(HttpHeaderNames.TRANSFER_ENCODING), is("chunked")); + + // Read the partial content. + HttpContent content = ch.readInbound(); + assertThat(content.content().toString(CharsetUtil.US_ASCII), is("12345678")); + assertThat(content, is(not(instanceOf(LastHttpContent.class)))); + content.release(); + + assertThat(ch.readInbound(), is(nullValue())); + + // Close the connection. 
+ ch.finish(); + + // The decoder should not generate the last chunk because it's closed prematurely. + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testLastResponseWithEmptyHeaderAndEmptyContent() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\n\r\n", CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + assertThat(ch.readInbound(), is(nullValue())); + + assertThat(ch.finish(), is(true)); + + LastHttpContent content = ch.readInbound(); + assertThat(content.content().isReadable(), is(false)); + content.release(); + + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testLastResponseWithoutContentLengthHeader() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer("HTTP/1.1 200 OK\r\n\r\n", CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + assertThat(ch.readInbound(), is(nullValue())); + + ch.writeInbound(Unpooled.wrappedBuffer(new byte[1024])); + HttpContent content = ch.readInbound(); + assertThat(content.content().readableBytes(), is(1024)); + content.release(); + + assertThat(ch.finish(), is(true)); + + LastHttpContent lastContent = ch.readInbound(); + assertThat(lastContent.content().isReadable(), is(false)); + lastContent.release(); + + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testLastResponseWithHeaderRemoveTrailingSpaces() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer( + "HTTP/1.1 200 OK\r\nX-Header: h2=h2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT \r\n\r\n", + CharsetUtil.US_ASCII)); 
+ + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + assertThat(res.headers().get(of("X-Header")), is("h2=h2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT")); + assertThat(ch.readInbound(), is(nullValue())); + + ch.writeInbound(Unpooled.wrappedBuffer(new byte[1024])); + HttpContent content = ch.readInbound(); + assertThat(content.content().readableBytes(), is(1024)); + content.release(); + + assertThat(ch.finish(), is(true)); + + LastHttpContent lastContent = ch.readInbound(); + assertThat(lastContent.content().isReadable(), is(false)); + lastContent.release(); + + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testResetContentResponseWithTransferEncoding() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(ch.writeInbound(Unpooled.copiedBuffer( + "HTTP/1.1 205 Reset Content\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "\r\n", + CharsetUtil.US_ASCII))); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.RESET_CONTENT)); + + LastHttpContent lastContent = ch.readInbound(); + assertThat(lastContent.content().isReadable(), is(false)); + lastContent.release(); + + assertThat(ch.finish(), is(false)); + } + + @Test + public void testLastResponseWithTrailingHeader() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer( + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "Set-Cookie: t1=t1v1\r\n" + + "Set-Cookie: t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT\r\n" + + "\r\n", + CharsetUtil.US_ASCII)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + 
LastHttpContent lastContent = ch.readInbound(); + assertThat(lastContent.content().isReadable(), is(false)); + HttpHeaders headers = lastContent.trailingHeaders(); + assertEquals(1, headers.names().size()); + List values = headers.getAll(of("Set-Cookie")); + assertEquals(2, values.size()); + assertTrue(values.contains("t1=t1v1")); + assertTrue(values.contains("t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT")); + lastContent.release(); + + assertThat(ch.finish(), is(false)); + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testLastResponseWithTrailingHeaderFragmented() { + byte[] data = ("HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "\r\n" + + "0\r\n" + + "Set-Cookie: t1=t1v1\r\n" + + "Set-Cookie: t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT\r\n" + + "\r\n").getBytes(CharsetUtil.US_ASCII); + + for (int i = 1; i < data.length; i++) { + testLastResponseWithTrailingHeaderFragmented(data, i); + } + } + + private static void testLastResponseWithTrailingHeaderFragmented(byte[] content, int fragmentSize) { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + int headerLength = 47; + // split up the header + for (int a = 0; a < headerLength;) { + int amount = fragmentSize; + if (a + amount > headerLength) { + amount = headerLength - a; + } + + // if header is done it should produce an HttpRequest + boolean headerDone = a + amount == headerLength; + assertEquals(headerDone, ch.writeInbound(Unpooled.copiedBuffer(content, a, amount))); + a += amount; + } + + ch.writeInbound(Unpooled.copiedBuffer(content, headerLength, content.length - headerLength)); + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + LastHttpContent lastContent = ch.readInbound(); + assertThat(lastContent.content().isReadable(), is(false)); + HttpHeaders headers = lastContent.trailingHeaders(); + assertEquals(1, 
headers.names().size()); + List values = headers.getAll(of("Set-Cookie")); + assertEquals(2, values.size()); + assertTrue(values.contains("t1=t1v1")); + assertTrue(values.contains("t2=t2v2; Expires=Wed, 09-Jun-2021 10:18:14 GMT")); + lastContent.release(); + + assertThat(ch.finish(), is(false)); + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testResponseWithContentLength() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer( + "HTTP/1.1 200 OK\r\n" + + "Content-Length: 10\r\n" + + "\r\n", CharsetUtil.US_ASCII)); + + byte[] data = new byte[10]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + ch.writeInbound(Unpooled.copiedBuffer(data, 0, data.length / 2)); + ch.writeInbound(Unpooled.copiedBuffer(data, 5, data.length / 2)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + HttpContent firstContent = ch.readInbound(); + assertThat(firstContent.content().readableBytes(), is(5)); + assertEquals(Unpooled.copiedBuffer(data, 0, 5), firstContent.content()); + firstContent.release(); + + LastHttpContent lastContent = ch.readInbound(); + assertEquals(5, lastContent.content().readableBytes()); + assertEquals(Unpooled.copiedBuffer(data, 5, 5), lastContent.content()); + lastContent.release(); + + assertThat(ch.finish(), is(false)); + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testResponseWithContentLengthFragmented() { + byte[] data = ("HTTP/1.1 200 OK\r\n" + + "Content-Length: 10\r\n" + + "\r\n").getBytes(CharsetUtil.US_ASCII); + + for (int i = 1; i < data.length; i++) { + testResponseWithContentLengthFragmented(data, i); + } + } + + private static void testResponseWithContentLengthFragmented(byte[] header, int fragmentSize) { + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + // split up the 
header + for (int a = 0; a < header.length;) { + int amount = fragmentSize; + if (a + amount > header.length) { + amount = header.length - a; + } + + ch.writeInbound(Unpooled.copiedBuffer(header, a, amount)); + a += amount; + } + byte[] data = new byte[10]; + for (int i = 0; i < data.length; i++) { + data[i] = (byte) i; + } + ch.writeInbound(Unpooled.copiedBuffer(data, 0, data.length / 2)); + ch.writeInbound(Unpooled.copiedBuffer(data, 5, data.length / 2)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.OK)); + + HttpContent firstContent = ch.readInbound(); + assertThat(firstContent.content().readableBytes(), is(5)); + assertEquals(Unpooled.wrappedBuffer(data, 0, 5), firstContent.content()); + firstContent.release(); + + LastHttpContent lastContent = ch.readInbound(); + assertEquals(5, lastContent.content().readableBytes()); + assertEquals(Unpooled.wrappedBuffer(data, 5, 5), lastContent.content()); + lastContent.release(); + + assertThat(ch.finish(), is(false)); + assertThat(ch.readInbound(), is(nullValue())); + } + + @Test + public void testOrderOfHeadersWithContentLength() { + String requestStr = "HTTP/1.1 200 OK\r\n" + + "Content-Type: text/plain; charset=UTF-8\r\n" + + "Content-Length: 5\r\n" + + "Connection: close\r\n\r\n" + + "hello"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + List headers = new ArrayList(); + for (Map.Entry header : response.headers()) { + headers.add(header.getKey()); + } + assertEquals(Arrays.asList("Content-Type", "Content-Length", "Connection"), headers, "ordered headers"); + } + + @Test + public void testWebSocketResponse() { + byte[] data = ("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + 
+ "Sec-WebSocket-Origin: http://localhost:8080\r\n" + + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" + + "\r\n" + + "1234567812345678").getBytes(); + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.wrappedBuffer(data)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.SWITCHING_PROTOCOLS)); + HttpContent content = ch.readInbound(); + assertThat(content.content().readableBytes(), is(16)); + content.release(); + + assertThat(ch.finish(), is(false)); + + assertThat(ch.readInbound(), is(nullValue())); + } + + // See https://github.com/netty/netty/issues/2173 + @Test + public void testWebSocketResponseWithDataFollowing() { + byte[] data = ("HTTP/1.1 101 WebSocket Protocol Handshake\r\n" + + "Upgrade: WebSocket\r\n" + + "Connection: Upgrade\r\n" + + "Sec-WebSocket-Origin: http://localhost:8080\r\n" + + "Sec-WebSocket-Location: ws://localhost/some/path\r\n" + + "\r\n" + + "1234567812345678").getBytes(); + byte[] otherData = {1, 2, 3, 4}; + + EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder()); + ch.writeInbound(Unpooled.copiedBuffer(data, otherData)); + + HttpResponse res = ch.readInbound(); + assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_1)); + assertThat(res.status(), is(HttpResponseStatus.SWITCHING_PROTOCOLS)); + HttpContent content = ch.readInbound(); + assertThat(content.content().readableBytes(), is(16)); + content.release(); + + assertThat(ch.finish(), is(true)); + + ByteBuf expected = Unpooled.wrappedBuffer(otherData); + ByteBuf buffer = ch.readInbound(); + try { + assertEquals(expected, buffer); + } finally { + expected.release(); + if (buffer != null) { + buffer.release(); + } + } + } + + @Test + public void testGarbageHeaders() { + // A response without headers - from https://github.com/netty/netty/issues/2103 + byte[] data = ("\r\n" + + "400 Bad 
Request\r\n" +
+                // NOTE(review): the original test data was an nginx HTML error page; the HTML tags
+                // were lost in extraction. Restored here as plain-text string literals so the code
+                // parses — the test only requires non-HTTP garbage bytes, which this preserves.
+                "\r\n" +
+                "400 Bad Request\r\n" +
+                "nginx/1.1.19\r\n" +
+                "\r\n" +
+                "\r\n").getBytes();
+
+        EmbeddedChannel ch = new EmbeddedChannel(new HttpResponseDecoder());
+
+        ch.writeInbound(Unpooled.copiedBuffer(data));
+
+        // Garbage input should generate the 999 Unknown response.
+        HttpResponse res = ch.readInbound();
+        assertThat(res.protocolVersion(), sameInstance(HttpVersion.HTTP_1_0));
+        assertThat(res.status().code(), is(999));
+        assertThat(res.decoderResult().isFailure(), is(true));
+        assertThat(res.decoderResult().isFinished(), is(true));
+        assertThat(ch.readInbound(), is(nullValue()));
+
+        // More garbage should not generate anything (i.e. the decoder discards anything beyond this point.)
+        ch.writeInbound(Unpooled.copiedBuffer(data));
+        assertThat(ch.readInbound(), is(nullValue()));
+
+        // Closing the connection should not generate anything since the protocol has been violated.
+        ch.finish();
+        assertThat(ch.readInbound(), is(nullValue()));
+    }
+
+    /**
+     * Tests if the decoder produces one and only {@link LastHttpContent} when an invalid chunk is received and
+     * the connection is closed.
+     */
+    @Test
+    public void testGarbageChunk() {
+        EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder());
+        String responseWithIllegalChunk =
+                "HTTP/1.1 200 OK\r\n" +
+                "Transfer-Encoding: chunked\r\n\r\n" +
+                "NOT_A_CHUNK_LENGTH\r\n";
+
+        channel.writeInbound(Unpooled.copiedBuffer(responseWithIllegalChunk, CharsetUtil.US_ASCII));
+        assertThat(channel.readInbound(), is(instanceOf(HttpResponse.class)));
+
+        // Ensure that the decoder generates the last chunk with correct decoder result.
+        LastHttpContent invalidChunk = channel.readInbound();
+        assertThat(invalidChunk.decoderResult().isFailure(), is(true));
+        invalidChunk.release();
+
+        // And no more messages should be produced by the decoder.
+        assertThat(channel.readInbound(), is(nullValue()));
+
+        // .. even after the connection is closed.
+ assertThat(channel.finish(), is(false)); + } + + @Test + public void testWhiteSpaceGarbageChunk() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String responseWithIllegalChunk = + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + " \r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(responseWithIllegalChunk, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpResponse.class))); + + // Ensure that the decoder generates the last chunk with correct decoder result. + LastHttpContent invalidChunk = channel.readInbound(); + assertThat(invalidChunk.decoderResult().isFailure(), is(true)); + invalidChunk.release(); + + // And no more messages should be produced by the decoder. + assertThat(channel.readInbound(), is(nullValue())); + + // .. even after the connection is closed. + assertThat(channel.finish(), is(false)); + } + + @Test + public void testLeadingWhiteSpacesSemiColonGarbageChunk() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String responseWithIllegalChunk = + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + " ;\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(responseWithIllegalChunk, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpResponse.class))); + + // Ensure that the decoder generates the last chunk with correct decoder result. + LastHttpContent invalidChunk = channel.readInbound(); + assertThat(invalidChunk.decoderResult().isFailure(), is(true)); + invalidChunk.release(); + + // And no more messages should be produced by the decoder. + assertThat(channel.readInbound(), is(nullValue())); + + // .. even after the connection is closed. 
+ assertThat(channel.finish(), is(false)); + } + + @Test + public void testControlCharGarbageChunk() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String responseWithIllegalChunk = + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + "\0\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(responseWithIllegalChunk, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpResponse.class))); + + // Ensure that the decoder generates the last chunk with correct decoder result. + LastHttpContent invalidChunk = channel.readInbound(); + assertThat(invalidChunk.decoderResult().isFailure(), is(true)); + invalidChunk.release(); + + // And no more messages should be produced by the decoder. + assertThat(channel.readInbound(), is(nullValue())); + + // .. even after the connection is closed. + assertThat(channel.finish(), is(false)); + } + + @Test + public void testLeadingWhiteSpacesControlCharGarbageChunk() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String responseWithIllegalChunk = + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + " \0\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(responseWithIllegalChunk, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpResponse.class))); + + // Ensure that the decoder generates the last chunk with correct decoder result. + LastHttpContent invalidChunk = channel.readInbound(); + assertThat(invalidChunk.decoderResult().isFailure(), is(true)); + invalidChunk.release(); + + // And no more messages should be produced by the decoder. + assertThat(channel.readInbound(), is(nullValue())); + + // .. even after the connection is closed. 
+ assertThat(channel.finish(), is(false)); + } + + @Test + public void testGarbageChunkAfterWhiteSpaces() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String responseWithIllegalChunk = + "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n\r\n" + + " 12345N1 ;\r\n"; + + channel.writeInbound(Unpooled.copiedBuffer(responseWithIllegalChunk, CharsetUtil.US_ASCII)); + assertThat(channel.readInbound(), is(instanceOf(HttpResponse.class))); + + // Ensure that the decoder generates the last chunk with correct decoder result. + LastHttpContent invalidChunk = channel.readInbound(); + assertThat(invalidChunk.decoderResult().isFailure(), is(true)); + invalidChunk.release(); + + // And no more messages should be produced by the decoder. + assertThat(channel.readInbound(), is(nullValue())); + + // .. even after the connection is closed. + assertThat(channel.finish(), is(false)); + } + + @Test + public void testConnectionClosedBeforeHeadersReceived() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String responseInitialLine = + "HTTP/1.1 200 OK\r\n"; + assertFalse(channel.writeInbound(Unpooled.copiedBuffer(responseInitialLine, CharsetUtil.US_ASCII))); + assertTrue(channel.finish()); + HttpMessage message = channel.readInbound(); + assertTrue(message.decoderResult().isFailure()); + assertThat(message.decoderResult().cause(), instanceOf(PrematureChannelClosureException.class)); + assertNull(channel.readInbound()); + } + + @Test + public void testTrailerWithEmptyLineInSeparateBuffer() { + HttpResponseDecoder decoder = new HttpResponseDecoder(); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + + String headers = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding: chunked\r\n" + + "Trailer: My-Trailer\r\n"; + assertFalse(channel.writeInbound(Unpooled.copiedBuffer(headers.getBytes(CharsetUtil.US_ASCII)))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n".getBytes(CharsetUtil.US_ASCII)))); + + 
assertTrue(channel.writeInbound(Unpooled.copiedBuffer("0\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("My-Trailer: 42\r\n", CharsetUtil.US_ASCII))); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer("\r\n", CharsetUtil.US_ASCII))); + + HttpResponse response = channel.readInbound(); + assertEquals(2, response.headers().size()); + assertEquals("chunked", response.headers().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertEquals("My-Trailer", response.headers().get(HttpHeaderNames.TRAILER)); + + LastHttpContent lastContent = channel.readInbound(); + assertEquals(1, lastContent.trailingHeaders().size()); + assertEquals("42", lastContent.trailingHeaders().get("My-Trailer")); + assertEquals(0, lastContent.content().readableBytes()); + lastContent.release(); + + assertFalse(channel.finish()); + } + + @Test + public void testWhitespace() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + String requestStr = "HTTP/1.1 200 OK\r\n" + + "Transfer-Encoding : chunked\r\n" + + "Host: netty.io\n\r\n"; + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertFalse(response.decoderResult().isFailure()); + assertEquals(HttpHeaderValues.CHUNKED.toString(), response.headers().get(HttpHeaderNames.TRANSFER_ENCODING)); + assertEquals("netty.io", response.headers().get(HttpHeaderNames.HOST)); + assertFalse(channel.finish()); + } + + @Test + public void testHttpMessageDecoderResult() { + String responseStr = "HTTP/1.1 200 OK\r\n" + + "Content-Length: 11\r\n" + + "Connection: close\r\n\r\n" + + "Lorem ipsum"; + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(responseStr, CharsetUtil.US_ASCII))); + HttpResponse response = channel.readInbound(); + assertTrue(response.decoderResult().isSuccess()); + assertThat(response.decoderResult(), 
instanceOf(HttpMessageDecoderResult.class)); + HttpMessageDecoderResult decoderResult = (HttpMessageDecoderResult) response.decoderResult(); + assertThat(decoderResult.initialLineLength(), is(15)); + assertThat(decoderResult.headerSize(), is(35)); + assertThat(decoderResult.totalSize(), is(50)); + HttpContent c = channel.readInbound(); + c.release(); + assertFalse(channel.finish()); + } + + @Test + public void testHeaderNameStartsWithControlChar1c() { + testHeaderNameStartsWithControlChar(0x1c); + } + + @Test + public void testHeaderNameStartsWithControlChar1d() { + testHeaderNameStartsWithControlChar(0x1d); + } + + @Test + public void testHeaderNameStartsWithControlChar1e() { + testHeaderNameStartsWithControlChar(0x1e); + } + + @Test + public void testHeaderNameStartsWithControlChar1f() { + testHeaderNameStartsWithControlChar(0x1f); + } + + @Test + public void testHeaderNameStartsWithControlChar0c() { + testHeaderNameStartsWithControlChar(0x0c); + } + + private void testHeaderNameStartsWithControlChar(int controlChar) { + ByteBuf responseBuffer = Unpooled.buffer(); + responseBuffer.writeCharSequence("HTTP/1.1 200 OK\r\n" + + "Host: netty.io\r\n", CharsetUtil.US_ASCII); + responseBuffer.writeByte(controlChar); + responseBuffer.writeCharSequence("Transfer-Encoding: chunked\r\n\r\n", CharsetUtil.US_ASCII); + testInvalidHeaders0(responseBuffer); + } + + @Test + public void testHeaderNameEndsWithControlChar1c() { + testHeaderNameEndsWithControlChar(0x1c); + } + + @Test + public void testHeaderNameEndsWithControlChar1d() { + testHeaderNameEndsWithControlChar(0x1d); + } + + @Test + public void testHeaderNameEndsWithControlChar1e() { + testHeaderNameEndsWithControlChar(0x1e); + } + + @Test + public void testHeaderNameEndsWithControlChar1f() { + testHeaderNameEndsWithControlChar(0x1f); + } + + @Test + public void testHeaderNameEndsWithControlChar0c() { + testHeaderNameEndsWithControlChar(0x0c); + } + + private void testHeaderNameEndsWithControlChar(int controlChar) { + 
ByteBuf responseBuffer = Unpooled.buffer(); + responseBuffer.writeCharSequence("HTTP/1.1 200 OK\r\n" + + "Host: netty.io\r\n", CharsetUtil.US_ASCII); + responseBuffer.writeCharSequence("Transfer-Encoding", CharsetUtil.US_ASCII); + responseBuffer.writeByte(controlChar); + responseBuffer.writeCharSequence(": chunked\r\n\r\n", CharsetUtil.US_ASCII); + testInvalidHeaders0(responseBuffer); + } + + private static void testInvalidHeaders0(ByteBuf responseBuffer) { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseDecoder()); + assertTrue(channel.writeInbound(responseBuffer)); + HttpResponse response = channel.readInbound(); + assertThat(response.decoderResult().cause(), instanceOf(IllegalArgumentException.class)); + assertTrue(response.decoderResult().isFailure()); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseEncoderTest.java new file mode 100644 index 0000000..3ecd918 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseEncoderTest.java @@ -0,0 +1,404 @@ +/* +* Copyright 2014 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.FileRegion; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.nio.channels.WritableByteChannel; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpResponseEncoderTest { + private static final long INTEGER_OVERFLOW = (long) Integer.MAX_VALUE + 1; + private static final FileRegion FILE_REGION = new DummyLongFileRegion(); + + @Test + public void testLargeFileRegionChunked() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + assertTrue(channel.writeOutbound(response)); + + ByteBuf buffer = channel.readOutbound(); + + assertEquals("HTTP/1.1 200 OK\r\n" + HttpHeaderNames.TRANSFER_ENCODING + ": " + + HttpHeaderValues.CHUNKED + "\r\n\r\n", buffer.toString(CharsetUtil.US_ASCII)); + buffer.release(); + assertTrue(channel.writeOutbound(FILE_REGION)); + buffer = channel.readOutbound(); + assertEquals("80000000\r\n", buffer.toString(CharsetUtil.US_ASCII)); + buffer.release(); + + FileRegion region = channel.readOutbound(); + assertSame(FILE_REGION, region); + region.release(); + buffer = channel.readOutbound(); + assertEquals("\r\n", buffer.toString(CharsetUtil.US_ASCII)); + buffer.release(); + + assertTrue(channel.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + buffer = channel.readOutbound(); + 
assertEquals("0\r\n\r\n", buffer.toString(CharsetUtil.US_ASCII)); + buffer.release(); + + assertFalse(channel.finish()); + } + + private static class DummyLongFileRegion implements FileRegion { + + @Override + public long position() { + return 0; + } + + @Override + public long transfered() { + return 0; + } + + @Override + public long transferred() { + return 0; + } + + @Override + public long count() { + return INTEGER_OVERFLOW; + } + + @Override + public long transferTo(WritableByteChannel target, long position) throws IOException { + throw new UnsupportedOperationException(); + } + + @Override + public FileRegion touch(Object hint) { + return this; + } + + @Override + public FileRegion touch() { + return this; + } + + @Override + public FileRegion retain() { + return this; + } + + @Override + public FileRegion retain(int increment) { + return this; + } + + @Override + public int refCnt() { + return 1; + } + + @Override + public boolean release() { + return false; + } + + @Override + public boolean release(int decrement) { + return false; + } + } + + @Test + public void testEmptyBufferBypass() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + + // Test writing an empty buffer works when the encoder is at ST_INIT. + channel.writeOutbound(Unpooled.EMPTY_BUFFER); + ByteBuf buffer = channel.readOutbound(); + assertThat(buffer, is(sameInstance(Unpooled.EMPTY_BUFFER))); + + // Leave the ST_INIT state. + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + assertTrue(channel.writeOutbound(response)); + buffer = channel.readOutbound(); + assertEquals("HTTP/1.1 200 OK\r\n\r\n", buffer.toString(CharsetUtil.US_ASCII)); + buffer.release(); + + // Test writing an empty buffer works when the encoder is not at ST_INIT. 
+ channel.writeOutbound(Unpooled.EMPTY_BUFFER); + buffer = channel.readOutbound(); + assertThat(buffer, is(sameInstance(Unpooled.EMPTY_BUFFER))); + + assertFalse(channel.finish()); + } + + @Test + public void testEmptyContentChunked() throws Exception { + testEmptyContent(true); + } + + @Test + public void testEmptyContentNotChunked() throws Exception { + testEmptyContent(false); + } + + private static void testEmptyContent(boolean chunked) throws Exception { + String content = "netty rocks"; + ByteBuf contentBuffer = Unpooled.copiedBuffer(content, CharsetUtil.US_ASCII); + int length = contentBuffer.readableBytes(); + + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + if (!chunked) { + HttpUtil.setContentLength(response, length); + } + assertTrue(channel.writeOutbound(response)); + assertTrue(channel.writeOutbound(new DefaultHttpContent(Unpooled.EMPTY_BUFFER))); + assertTrue(channel.writeOutbound(new DefaultLastHttpContent(contentBuffer))); + + ByteBuf buffer = channel.readOutbound(); + if (!chunked) { + assertEquals("HTTP/1.1 200 OK\r\ncontent-length: " + length + "\r\n\r\n", + buffer.toString(CharsetUtil.US_ASCII)); + } else { + assertEquals("HTTP/1.1 200 OK\r\n\r\n", buffer.toString(CharsetUtil.US_ASCII)); + } + buffer.release(); + + // Test writing an empty buffer works when the encoder is not at ST_INIT. 
+ buffer = channel.readOutbound(); + assertEquals(0, buffer.readableBytes()); + buffer.release(); + + buffer = channel.readOutbound(); + assertEquals(length, buffer.readableBytes()); + buffer.release(); + + assertFalse(channel.finish()); + } + + @Test + public void testStatusNoContent() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + assertEmptyResponse(channel, HttpResponseStatus.NO_CONTENT, null, false); + assertFalse(channel.finish()); + } + + @Test + public void testStatusNoContentContentLength() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + assertEmptyResponse(channel, HttpResponseStatus.NO_CONTENT, HttpHeaderNames.CONTENT_LENGTH, true); + assertFalse(channel.finish()); + } + + @Test + public void testStatusNoContentTransferEncoding() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + assertEmptyResponse(channel, HttpResponseStatus.NO_CONTENT, HttpHeaderNames.TRANSFER_ENCODING, true); + assertFalse(channel.finish()); + } + + @Test + public void testStatusNotModified() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + assertEmptyResponse(channel, HttpResponseStatus.NOT_MODIFIED, null, false); + assertFalse(channel.finish()); + } + + @Test + public void testStatusNotModifiedContentLength() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + assertEmptyResponse(channel, HttpResponseStatus.NOT_MODIFIED, HttpHeaderNames.CONTENT_LENGTH, false); + assertFalse(channel.finish()); + } + + @Test + public void testStatusNotModifiedTransferEncoding() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + assertEmptyResponse(channel, HttpResponseStatus.NOT_MODIFIED, HttpHeaderNames.TRANSFER_ENCODING, false); + assertFalse(channel.finish()); + } + + @Test + public void testStatusInformational() 
throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + for (int code = 100; code < 200; code++) { + HttpResponseStatus status = HttpResponseStatus.valueOf(code); + assertEmptyResponse(channel, status, null, false); + } + assertFalse(channel.finish()); + } + + @Test + public void testStatusInformationalContentLength() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + for (int code = 100; code < 200; code++) { + HttpResponseStatus status = HttpResponseStatus.valueOf(code); + assertEmptyResponse(channel, status, HttpHeaderNames.CONTENT_LENGTH, code != 101); + } + assertFalse(channel.finish()); + } + + @Test + public void testStatusInformationalTransferEncoding() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + for (int code = 100; code < 200; code++) { + HttpResponseStatus status = HttpResponseStatus.valueOf(code); + assertEmptyResponse(channel, status, HttpHeaderNames.TRANSFER_ENCODING, code != 101); + } + assertFalse(channel.finish()); + } + + private static void assertEmptyResponse(EmbeddedChannel channel, HttpResponseStatus status, + CharSequence headerName, boolean headerStripped) { + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status); + if (HttpHeaderNames.CONTENT_LENGTH.contentEquals(headerName)) { + response.headers().set(headerName, "0"); + } else if (HttpHeaderNames.TRANSFER_ENCODING.contentEquals(headerName)) { + response.headers().set(headerName, HttpHeaderValues.CHUNKED); + } + + assertTrue(channel.writeOutbound(response)); + assertTrue(channel.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + ByteBuf buffer = channel.readOutbound(); + StringBuilder responseText = new StringBuilder(); + responseText.append(HttpVersion.HTTP_1_1).append(' ').append(status.toString()).append("\r\n"); + if (!headerStripped && headerName != null) { + responseText.append(headerName).append(": "); + + if 
(HttpHeaderNames.CONTENT_LENGTH.contentEquals(headerName)) { + responseText.append('0'); + } else { + responseText.append(HttpHeaderValues.CHUNKED); + } + responseText.append("\r\n"); + } + responseText.append("\r\n"); + + assertEquals(responseText.toString(), buffer.toString(CharsetUtil.US_ASCII)); + + buffer.release(); + + buffer = channel.readOutbound(); + buffer.release(); + } + + @Test + public void testEmptyContentsChunked() throws Exception { + testEmptyContents(true, false); + } + + @Test + public void testEmptyContentsChunkedWithTrailers() throws Exception { + testEmptyContents(true, true); + } + + @Test + public void testEmptyContentsNotChunked() throws Exception { + testEmptyContents(false, false); + } + + @Test + public void testEmptyContentNotsChunkedWithTrailers() throws Exception { + testEmptyContents(false, true); + } + + private void testEmptyContents(boolean chunked, boolean trailers) throws Exception { + HttpResponseEncoder encoder = new HttpResponseEncoder(); + EmbeddedChannel channel = new EmbeddedChannel(encoder); + HttpResponse request = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + if (chunked) { + HttpUtil.setTransferEncodingChunked(request, true); + } + assertTrue(channel.writeOutbound(request)); + + ByteBuf contentBuffer = Unpooled.buffer(); + assertTrue(channel.writeOutbound(new DefaultHttpContent(contentBuffer))); + + ByteBuf lastContentBuffer = Unpooled.buffer(); + LastHttpContent last = new DefaultLastHttpContent(lastContentBuffer); + if (trailers) { + last.trailingHeaders().set("X-Netty-Test", "true"); + } + assertTrue(channel.writeOutbound(last)); + + // Ensure we only produce ByteBuf instances. 
+ ByteBuf head = channel.readOutbound(); + assertTrue(head.release()); + + ByteBuf content = channel.readOutbound(); + content.release(); + + ByteBuf lastContent = channel.readOutbound(); + lastContent.release(); + assertFalse(channel.finish()); + } + + @Test + public void testStatusResetContentTransferContentLength() { + testStatusResetContentTransferContentLength0(HttpHeaderNames.CONTENT_LENGTH, Unpooled.buffer().writeLong(8)); + } + + @Test + public void testStatusResetContentTransferEncoding() { + testStatusResetContentTransferContentLength0(HttpHeaderNames.TRANSFER_ENCODING, Unpooled.buffer().writeLong(8)); + } + + private static void testStatusResetContentTransferContentLength0(CharSequence headerName, ByteBuf content) { + EmbeddedChannel channel = new EmbeddedChannel(new HttpResponseEncoder()); + + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.RESET_CONTENT); + if (HttpHeaderNames.CONTENT_LENGTH.contentEqualsIgnoreCase(headerName)) { + response.headers().set(HttpHeaderNames.CONTENT_LENGTH, content.readableBytes()); + } else { + response.headers().set(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + } + + assertTrue(channel.writeOutbound(response)); + assertTrue(channel.writeOutbound(new DefaultHttpContent(content))); + assertTrue(channel.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + StringBuilder responseText = new StringBuilder(); + responseText.append(HttpVersion.HTTP_1_1).append(' ') + .append(HttpResponseStatus.RESET_CONTENT).append("\r\n"); + responseText.append(HttpHeaderNames.CONTENT_LENGTH).append(": 0\r\n"); + responseText.append("\r\n"); + + StringBuilder written = new StringBuilder(); + for (;;) { + ByteBuf buffer = channel.readOutbound(); + if (buffer == null) { + break; + } + written.append(buffer.toString(CharsetUtil.US_ASCII)); + buffer.release(); + } + + assertEquals(responseText.toString(), written.toString()); + assertFalse(channel.finish()); + } +} diff --git 
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseStatusTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseStatusTest.java
new file mode 100644
index 0000000..067ece0
--- /dev/null
+++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpResponseStatusTest.java
@@ -0,0 +1,147 @@
+/*
+ * Copyright 2018 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static io.netty.handler.codec.http.HttpResponseStatus.parseLine; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpResponseStatusTest { + @Test + public void parseLineStringJustCode() { + assertSame(HttpResponseStatus.OK, parseLine("200")); + } + + @Test + public void parseLineStringCodeAndPhrase() { + assertSame(HttpResponseStatus.OK, parseLine("200 OK")); + } + + @Test + public void parseLineStringCustomCode() { + HttpResponseStatus customStatus = parseLine("612"); + assertEquals(612, customStatus.code()); + } + + @Test + public void parseLineStringCustomCodeAndPhrase() { + HttpResponseStatus customStatus = parseLine("612 FOO"); + assertEquals(612, customStatus.code()); + assertEquals("FOO", customStatus.reasonPhrase()); + } + + @Test + public void parseLineStringMalformedCode() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseLine("200a"); + } + }); + } + + @Test + public void parseLineStringMalformedCodeWithPhrase() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseLine("200a foo"); + } + }); + } + + @Test + public void parseLineAsciiStringJustCode() { + assertSame(HttpResponseStatus.OK, parseLine(new AsciiString("200"))); + } + + @Test + public void parseLineAsciiStringCodeAndPhrase() { + assertSame(HttpResponseStatus.OK, parseLine(new AsciiString("200 OK"))); + } + + @Test + public void parseLineAsciiStringCustomCode() { + HttpResponseStatus customStatus = parseLine(new AsciiString("612")); + 
assertEquals(612, customStatus.code()); + } + + @Test + public void parseLineAsciiStringCustomCodeAndPhrase() { + HttpResponseStatus customStatus = parseLine(new AsciiString("612 FOO")); + assertEquals(612, customStatus.code()); + assertEquals("FOO", customStatus.reasonPhrase()); + } + + @Test + public void parseLineAsciiStringMalformedCode() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseLine(new AsciiString("200a")); + } + }); + } + + @Test + public void parseLineAsciiStringMalformedCodeWithPhrase() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseLine(new AsciiString("200a foo")); + } + }); + } + + @Test + public void testHttpStatusClassValueOf() { + // status scope: [100, 600). + for (int code = 100; code < 600; code ++) { + HttpStatusClass httpStatusClass = HttpStatusClass.valueOf(code); + assertNotSame(HttpStatusClass.UNKNOWN, httpStatusClass); + if (HttpStatusClass.INFORMATIONAL.contains(code)) { + assertEquals(HttpStatusClass.INFORMATIONAL, httpStatusClass); + } else if (HttpStatusClass.SUCCESS.contains(code)) { + assertEquals(HttpStatusClass.SUCCESS, httpStatusClass); + } else if (HttpStatusClass.REDIRECTION.contains(code)) { + assertEquals(HttpStatusClass.REDIRECTION, httpStatusClass); + } else if (HttpStatusClass.CLIENT_ERROR.contains(code)) { + assertEquals(HttpStatusClass.CLIENT_ERROR, httpStatusClass); + } else if (HttpStatusClass.SERVER_ERROR.contains(code)) { + assertEquals(HttpStatusClass.SERVER_ERROR, httpStatusClass); + } else { + fail("At least one of the if-branches above must be true"); + } + } + // status scope: [Integer.MIN_VALUE, 100). + for (int code = Integer.MIN_VALUE; code < 100; code ++) { + HttpStatusClass httpStatusClass = HttpStatusClass.valueOf(code); + assertEquals(HttpStatusClass.UNKNOWN, httpStatusClass); + } + // status scope: [600, Integer.MAX_VALUE]. 
+ for (int code = 600; code > 0; code ++) { + HttpStatusClass httpStatusClass = HttpStatusClass.valueOf(code); + assertEquals(HttpStatusClass.UNKNOWN, httpStatusClass); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java new file mode 100644 index 0000000..8589fb2 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerCodecTest.java @@ -0,0 +1,185 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpServerCodecTest { + + /** + * Testcase for https://github.com/netty/netty/issues/433 + */ + @Test + public void testUnfinishedChunkedHttpRequestIsLastFlag() throws Exception { + + int maxChunkSize = 2000; + HttpServerCodec httpServerCodec = new HttpServerCodec(1000, 1000, maxChunkSize); + EmbeddedChannel decoderEmbedder = new EmbeddedChannel(httpServerCodec); + + int totalContentLength = maxChunkSize * 5; + decoderEmbedder.writeInbound(Unpooled.copiedBuffer( + "PUT /test HTTP/1.1\r\n" + + "Content-Length: " + totalContentLength + "\r\n" + + "\r\n", CharsetUtil.UTF_8)); + + int offeredContentLength = (int) (maxChunkSize * 2.5); + decoderEmbedder.writeInbound(prepareDataChunk(offeredContentLength)); + decoderEmbedder.finish(); + + HttpMessage httpMessage = decoderEmbedder.readInbound(); + assertNotNull(httpMessage); + + boolean empty = true; + int totalBytesPolled = 0; + for (;;) { + HttpContent httpChunk = decoderEmbedder.readInbound(); + if (httpChunk == null) { + break; + } + empty = false; + totalBytesPolled += httpChunk.content().readableBytes(); + assertFalse(httpChunk instanceof LastHttpContent); + httpChunk.release(); + } + assertFalse(empty); + assertEquals(offeredContentLength, totalBytesPolled); + } + + @Test + public void test100Continue() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new HttpServerCodec(), new HttpObjectAggregator(1024)); + + // Send 
the request headers. + ch.writeInbound(Unpooled.copiedBuffer( + "PUT /upload-large HTTP/1.1\r\n" + + "Expect: 100-continue\r\n" + + "Content-Length: 1\r\n\r\n", CharsetUtil.UTF_8)); + + // Ensure the aggregator generates nothing. + assertThat(ch.readInbound(), is(nullValue())); + + // Ensure the aggregator writes a 100 Continue response. + ByteBuf continueResponse = ch.readOutbound(); + assertThat(continueResponse.toString(CharsetUtil.UTF_8), is("HTTP/1.1 100 Continue\r\n\r\n")); + continueResponse.release(); + + // But nothing more. + assertThat(ch.readOutbound(), is(nullValue())); + + // Send the content of the request. + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 42 })); + + // Ensure the aggregator generates a full request. + FullHttpRequest req = ch.readInbound(); + assertThat(req.headers().get(HttpHeaderNames.CONTENT_LENGTH), is("1")); + assertThat(req.content().readableBytes(), is(1)); + assertThat(req.content().readByte(), is((byte) 42)); + req.release(); + + // But nothing more. + assertThat(ch.readInbound(), is(nullValue())); + + // Send the actual response. + FullHttpResponse res = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CREATED); + res.content().writeBytes("OK".getBytes(CharsetUtil.UTF_8)); + res.headers().setInt(HttpHeaderNames.CONTENT_LENGTH, 2); + ch.writeOutbound(res); + + // Ensure the encoder handles the response after handling 100 Continue. + ByteBuf encodedRes = ch.readOutbound(); + assertThat(encodedRes.toString(CharsetUtil.UTF_8), + is("HTTP/1.1 201 Created\r\n" + HttpHeaderNames.CONTENT_LENGTH + ": 2\r\n\r\nOK")); + encodedRes.release(); + + ch.finish(); + } + + @Test + public void testChunkedHeadResponse() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpServerCodec()); + + // Send the request headers. 
+ assertTrue(ch.writeInbound(Unpooled.copiedBuffer( + "HEAD / HTTP/1.1\r\n\r\n", CharsetUtil.UTF_8))); + + HttpRequest request = ch.readInbound(); + assertEquals(HttpMethod.HEAD, request.method()); + LastHttpContent content = ch.readInbound(); + assertFalse(content.content().isReadable()); + content.release(); + + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpUtil.setTransferEncodingChunked(response, true); + assertTrue(ch.writeOutbound(response)); + assertTrue(ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + assertTrue(ch.finish()); + + ByteBuf buf = ch.readOutbound(); + assertEquals("HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\n\r\n", buf.toString(CharsetUtil.US_ASCII)); + buf.release(); + + buf = ch.readOutbound(); + assertFalse(buf.isReadable()); + buf.release(); + + assertFalse(ch.finishAndReleaseAll()); + } + + @Test + public void testChunkedHeadFullHttpResponse() { + EmbeddedChannel ch = new EmbeddedChannel(new HttpServerCodec()); + + // Send the request headers. 
+ assertTrue(ch.writeInbound(Unpooled.copiedBuffer( + "HEAD / HTTP/1.1\r\n\r\n", CharsetUtil.UTF_8))); + + HttpRequest request = ch.readInbound(); + assertEquals(HttpMethod.HEAD, request.method()); + LastHttpContent content = ch.readInbound(); + assertFalse(content.content().isReadable()); + content.release(); + + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpUtil.setTransferEncodingChunked(response, true); + assertTrue(ch.writeOutbound(response)); + assertTrue(ch.finish()); + + ByteBuf buf = ch.readOutbound(); + assertEquals("HTTP/1.1 200 OK\r\ntransfer-encoding: chunked\r\n\r\n", buf.toString(CharsetUtil.US_ASCII)); + buf.release(); + + assertFalse(ch.finishAndReleaseAll()); + } + + private static ByteBuf prepareDataChunk(int size) { + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < size; ++i) { + sb.append('a'); + } + return Unpooled.copiedBuffer(sb.toString(), CharsetUtil.UTF_8); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerExpectContinueHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerExpectContinueHandlerTest.java new file mode 100644 index 0000000..d908c8d --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerExpectContinueHandlerTest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpServerExpectContinueHandlerTest { + + @Test + public void shouldRespondToExpectedHeader() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpServerExpectContinueHandler() { + @Override + protected HttpResponse acceptMessage(HttpRequest request) { + HttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE); + response.headers().set("foo", "bar"); + return response; + } + }); + HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + HttpUtil.set100ContinueExpected(request, true); + + channel.writeInbound(request); + HttpResponse response = channel.readOutbound(); + + assertThat(response.status(), is(HttpResponseStatus.CONTINUE)); + assertThat(response.headers().get("foo"), is("bar")); + ReferenceCountUtil.release(response); + + HttpRequest processedRequest = channel.readInbound(); + assertFalse(processedRequest.headers().contains(HttpHeaderNames.EXPECT)); + ReferenceCountUtil.release(processedRequest); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void shouldAllowCustomResponses() { + EmbeddedChannel channel = new EmbeddedChannel( + new HttpServerExpectContinueHandler() { + @Override + protected HttpResponse acceptMessage(HttpRequest request) { + return null; + } + + @Override + protected HttpResponse rejectResponse(HttpRequest request) { + return new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, + HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE); 
+ } + } + ); + + HttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/"); + HttpUtil.set100ContinueExpected(request, true); + + channel.writeInbound(request); + HttpResponse response = channel.readOutbound(); + + assertThat(response.status(), is(HttpResponseStatus.REQUEST_ENTITY_TOO_LARGE)); + ReferenceCountUtil.release(response); + + // request was swallowed + assertTrue(channel.inboundMessages().isEmpty()); + assertFalse(channel.finishAndReleaseAll()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java new file mode 100644 index 0000000..0ef78e0 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerKeepAliveHandlerTest.java @@ -0,0 +1,235 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Arrays; +import java.util.Collection; + +import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE; +import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE; +import static io.netty.handler.codec.http.HttpHeaderValues.MULTIPART_MIXED; +import static io.netty.handler.codec.http.HttpResponseStatus.NO_CONTENT; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static io.netty.handler.codec.http.HttpUtil.isContentLengthSet; +import static io.netty.handler.codec.http.HttpUtil.isKeepAlive; +import static io.netty.handler.codec.http.HttpUtil.setContentLength; +import static io.netty.handler.codec.http.HttpUtil.setKeepAlive; +import static io.netty.handler.codec.http.HttpUtil.setTransferEncodingChunked; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpServerKeepAliveHandlerTest { + private static final String REQUEST_KEEP_ALIVE = "REQUEST_KEEP_ALIVE"; + private static final int NOT_SELF_DEFINED_MSG_LENGTH = 0; + private static final int SET_RESPONSE_LENGTH = 1; + private static final int SET_MULTIPART = 2; + private static final int SET_CHUNKED = 4; + + private EmbeddedChannel channel; + + @BeforeEach + public void setUp() { + channel = new EmbeddedChannel(new HttpServerKeepAliveHandler()); + } + + static Collection keepAliveProvider() { + return Arrays.asList(new Object[][] { + { true, HttpVersion.HTTP_1_0, OK, REQUEST_KEEP_ALIVE, SET_RESPONSE_LENGTH, KEEP_ALIVE }, // 0 + { true, HttpVersion.HTTP_1_0, OK, REQUEST_KEEP_ALIVE, SET_MULTIPART, 
KEEP_ALIVE }, // 1 + { false, HttpVersion.HTTP_1_0, OK, null, SET_RESPONSE_LENGTH, null }, // 2 + { true, HttpVersion.HTTP_1_1, OK, REQUEST_KEEP_ALIVE, SET_RESPONSE_LENGTH, null }, // 3 + { false, HttpVersion.HTTP_1_1, OK, REQUEST_KEEP_ALIVE, SET_RESPONSE_LENGTH, CLOSE }, // 4 + { true, HttpVersion.HTTP_1_1, OK, REQUEST_KEEP_ALIVE, SET_MULTIPART, null }, // 5 + { true, HttpVersion.HTTP_1_1, OK, REQUEST_KEEP_ALIVE, SET_CHUNKED, null }, // 6 + { false, HttpVersion.HTTP_1_1, OK, null, SET_RESPONSE_LENGTH, null }, // 7 + { false, HttpVersion.HTTP_1_0, OK, REQUEST_KEEP_ALIVE, NOT_SELF_DEFINED_MSG_LENGTH, null }, // 8 + { false, HttpVersion.HTTP_1_0, OK, null, NOT_SELF_DEFINED_MSG_LENGTH, null }, // 9 + { false, HttpVersion.HTTP_1_1, OK, REQUEST_KEEP_ALIVE, NOT_SELF_DEFINED_MSG_LENGTH, null }, // 10 + { false, HttpVersion.HTTP_1_1, OK, null, NOT_SELF_DEFINED_MSG_LENGTH, null }, // 11 + { false, HttpVersion.HTTP_1_0, OK, REQUEST_KEEP_ALIVE, SET_RESPONSE_LENGTH, null }, // 12 + { true, HttpVersion.HTTP_1_1, NO_CONTENT, REQUEST_KEEP_ALIVE, NOT_SELF_DEFINED_MSG_LENGTH, null}, // 13 + { false, HttpVersion.HTTP_1_0, NO_CONTENT, null, NOT_SELF_DEFINED_MSG_LENGTH, null} // 14 + }); + } + + @ParameterizedTest + @MethodSource("keepAliveProvider") + public void test_KeepAlive(boolean isKeepAliveResponseExpected, HttpVersion httpVersion, + HttpResponseStatus responseStatus, + String sendKeepAlive, int setSelfDefinedMessageLength, + AsciiString setResponseConnection) throws Exception { + FullHttpRequest request = new DefaultFullHttpRequest(httpVersion, HttpMethod.GET, "/v1/foo/bar"); + setKeepAlive(request, REQUEST_KEEP_ALIVE.equals(sendKeepAlive)); + HttpResponse response = new DefaultFullHttpResponse(httpVersion, responseStatus); + if (setResponseConnection != null) { + response.headers().set(HttpHeaderNames.CONNECTION, setResponseConnection); + } + setupMessageLength(response, setSelfDefinedMessageLength); + + assertTrue(channel.writeInbound(request)); + Object requestForwarded = 
channel.readInbound(); + assertEquals(request, requestForwarded); + ReferenceCountUtil.release(requestForwarded); + channel.writeAndFlush(response); + HttpResponse writtenResponse = channel.readOutbound(); + + assertEquals(isKeepAliveResponseExpected, channel.isOpen(), "channel.isOpen"); + assertEquals(isKeepAliveResponseExpected, isKeepAlive(writtenResponse), "response keep-alive"); + ReferenceCountUtil.release(writtenResponse); + assertFalse(channel.finishAndReleaseAll()); + } + + static Collection connectionCloseProvider() { + return Arrays.asList(new Object[][] { + { HttpVersion.HTTP_1_0, OK, SET_RESPONSE_LENGTH }, + { HttpVersion.HTTP_1_0, OK, SET_MULTIPART }, + { HttpVersion.HTTP_1_0, OK, NOT_SELF_DEFINED_MSG_LENGTH }, + { HttpVersion.HTTP_1_0, NO_CONTENT, NOT_SELF_DEFINED_MSG_LENGTH }, + { HttpVersion.HTTP_1_1, OK, SET_RESPONSE_LENGTH }, + { HttpVersion.HTTP_1_1, OK, SET_MULTIPART }, + { HttpVersion.HTTP_1_1, OK, NOT_SELF_DEFINED_MSG_LENGTH }, + { HttpVersion.HTTP_1_1, OK, SET_CHUNKED }, + { HttpVersion.HTTP_1_1, NO_CONTENT, NOT_SELF_DEFINED_MSG_LENGTH } + }); + } + + @ParameterizedTest + @MethodSource("connectionCloseProvider") + public void testConnectionCloseHeaderHandledCorrectly( + HttpVersion httpVersion, HttpResponseStatus responseStatus, int setSelfDefinedMessageLength) { + HttpResponse response = new DefaultFullHttpResponse(httpVersion, responseStatus); + response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); + setupMessageLength(response, setSelfDefinedMessageLength); + + channel.writeAndFlush(response); + HttpResponse writtenResponse = channel.readOutbound(); + + assertFalse(channel.isOpen()); + ReferenceCountUtil.release(writtenResponse); + assertFalse(channel.finishAndReleaseAll()); + } + + @ParameterizedTest + @MethodSource("connectionCloseProvider") + public void testConnectionCloseHeaderHandledCorrectlyForVoidPromise( + HttpVersion httpVersion, HttpResponseStatus responseStatus, int setSelfDefinedMessageLength) { + 
HttpResponse response = new DefaultFullHttpResponse(httpVersion, responseStatus); + response.headers().set(HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE); + setupMessageLength(response, setSelfDefinedMessageLength); + + channel.writeAndFlush(response, channel.voidPromise()); + HttpResponse writtenResponse = channel.readOutbound(); + + assertFalse(channel.isOpen()); + ReferenceCountUtil.release(writtenResponse); + assertFalse(channel.finishAndReleaseAll()); + } + + @ParameterizedTest + @MethodSource("keepAliveProvider") + public void testPipelineKeepAlive(boolean isKeepAliveResponseExpected, HttpVersion httpVersion, + HttpResponseStatus responseStatus, + String sendKeepAlive, int setSelfDefinedMessageLength, + AsciiString setResponseConnection) { + FullHttpRequest firstRequest = new DefaultFullHttpRequest(httpVersion, HttpMethod.GET, "/v1/foo/bar"); + setKeepAlive(firstRequest, true); + FullHttpRequest secondRequest = new DefaultFullHttpRequest(httpVersion, HttpMethod.GET, "/v1/foo/bar"); + setKeepAlive(secondRequest, REQUEST_KEEP_ALIVE.equals(sendKeepAlive)); + FullHttpRequest finalRequest = new DefaultFullHttpRequest(httpVersion, HttpMethod.GET, "/v1/foo/bar"); + setKeepAlive(finalRequest, false); + FullHttpResponse response = new DefaultFullHttpResponse(httpVersion, responseStatus); + FullHttpResponse informationalResp = new DefaultFullHttpResponse(httpVersion, HttpResponseStatus.PROCESSING); + setKeepAlive(response, true); + setContentLength(response, 0); + setKeepAlive(informationalResp, true); + + assertTrue(channel.writeInbound(firstRequest, secondRequest, finalRequest)); + + Object requestForwarded = channel.readInbound(); + assertEquals(firstRequest, requestForwarded); + ReferenceCountUtil.release(requestForwarded); + + channel.writeAndFlush(response.retainedDuplicate()); + HttpResponse firstResponse = channel.readOutbound(); + assertTrue(channel.isOpen(), "channel.isOpen"); + assertTrue(isKeepAlive(firstResponse), "response keep-alive"); + 
ReferenceCountUtil.release(firstResponse); + + requestForwarded = channel.readInbound(); + assertEquals(secondRequest, requestForwarded); + ReferenceCountUtil.release(requestForwarded); + + channel.writeAndFlush(informationalResp); + HttpResponse writtenInfoResp = channel.readOutbound(); + assertTrue(channel.isOpen(), "channel.isOpen"); + assertTrue(isKeepAlive(writtenInfoResp), "response keep-alive"); + ReferenceCountUtil.release(writtenInfoResp); + + if (setResponseConnection != null) { + response.headers().set(HttpHeaderNames.CONNECTION, setResponseConnection); + } else { + response.headers().remove(HttpHeaderNames.CONNECTION); + } + setupMessageLength(response, setSelfDefinedMessageLength); + channel.writeAndFlush(response.retainedDuplicate()); + HttpResponse secondResponse = channel.readOutbound(); + assertEquals(isKeepAliveResponseExpected, channel.isOpen(), "channel.isOpen"); + assertEquals(isKeepAliveResponseExpected, isKeepAlive(secondResponse), "response keep-alive"); + ReferenceCountUtil.release(secondResponse); + + requestForwarded = channel.readInbound(); + assertEquals(finalRequest, requestForwarded); + ReferenceCountUtil.release(requestForwarded); + + if (isKeepAliveResponseExpected) { + channel.writeAndFlush(response); + HttpResponse finalResponse = channel.readOutbound(); + assertFalse(channel.isOpen(), "channel.isOpen"); + assertFalse(isKeepAlive(finalResponse), "response keep-alive"); + } + ReferenceCountUtil.release(response); + assertFalse(channel.finishAndReleaseAll()); + } + + private static void setupMessageLength(HttpResponse response, int setSelfDefinedMessageLength) { + switch (setSelfDefinedMessageLength) { + case NOT_SELF_DEFINED_MSG_LENGTH: + if (isContentLengthSet(response)) { + response.headers().remove(HttpHeaderNames.CONTENT_LENGTH); + } + break; + case SET_RESPONSE_LENGTH: + setContentLength(response, 0); + break; + case SET_CHUNKED: + setTransferEncodingChunked(response, true); + break; + case SET_MULTIPART: + 
response.headers().set(HttpHeaderNames.CONTENT_TYPE, MULTIPART_MIXED.toUpperCase()); + break; + default: + throw new IllegalArgumentException("selfDefinedMessageLength: " + setSelfDefinedMessageLength); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerUpgradeHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerUpgradeHandlerTest.java new file mode 100644 index 0000000..e0f45c7 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpServerUpgradeHandlerTest.java @@ -0,0 +1,232 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import java.util.Collection; +import java.util.Collections; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodecFactory; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpServerUpgradeHandlerTest { + + private static class TestUpgradeCodec implements UpgradeCodec { + @Override + public Collection requiredUpgradeHeaders() { + return Collections.emptyList(); + } + + @Override + public boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest, + HttpHeaders upgradeHeaders) { + return true; + } + + @Override + public void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest) { + // Ensure that the HttpServerUpgradeHandler is still installed when this is called + assertEquals(ctx.pipeline().context(HttpServerUpgradeHandler.class), ctx); + assertNotNull(ctx.pipeline().get(HttpServerUpgradeHandler.class)); + + // Add a marker handler to signal that the upgrade 
has happened + ctx.pipeline().addAfter(ctx.name(), "marker", new ChannelInboundHandlerAdapter()); + } + } + + @Test + public void upgradesPipelineInSameMethodInvocation() { + final HttpServerCodec httpServerCodec = new HttpServerCodec(); + final UpgradeCodecFactory factory = new UpgradeCodecFactory() { + @Override + public UpgradeCodec newUpgradeCodec(CharSequence protocol) { + return new TestUpgradeCodec(); + } + }; + + ChannelHandler testInStackFrame = new ChannelDuplexHandler() { + // marker boolean to signal that we're in the `channelRead` method + private boolean inReadCall; + private boolean writeUpgradeMessage; + private boolean writeFlushed; + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + assertFalse(inReadCall); + assertFalse(writeUpgradeMessage); + + inReadCall = true; + try { + super.channelRead(ctx, msg); + // All in the same call stack, the upgrade codec should receive the message, + // written the upgrade response, and upgraded the pipeline. + assertTrue(writeUpgradeMessage); + assertFalse(writeFlushed); + assertNull(ctx.pipeline().get(HttpServerCodec.class)); + assertNotNull(ctx.pipeline().get("marker")); + } finally { + inReadCall = false; + } + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) { + // We ensure that we're in the read call and defer the write so we can + // make sure the pipeline was reformed irrespective of the flush completing. 
+ assertTrue(inReadCall); + writeUpgradeMessage = true; + ctx.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + ctx.write(msg, promise); + } + }); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + writeFlushed = true; + } + }); + } + }; + + HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, factory); + + EmbeddedChannel channel = new EmbeddedChannel(testInStackFrame, httpServerCodec, upgradeHandler); + + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: Upgrade, HTTP2-Settings\r\n" + + "Upgrade: nextprotocol\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII); + + assertFalse(channel.writeInbound(upgrade)); + assertNull(channel.pipeline().get(HttpServerCodec.class)); + assertNotNull(channel.pipeline().get("marker")); + + channel.flushOutbound(); + ByteBuf upgradeMessage = channel.readOutbound(); + String expectedHttpResponse = "HTTP/1.1 101 Switching Protocols\r\n" + + "connection: upgrade\r\n" + + "upgrade: nextprotocol\r\n\r\n"; + assertEquals(expectedHttpResponse, upgradeMessage.toString(CharsetUtil.US_ASCII)); + assertTrue(upgradeMessage.release()); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void skippedUpgrade() { + final HttpServerCodec httpServerCodec = new HttpServerCodec(); + final UpgradeCodecFactory factory = new UpgradeCodecFactory() { + @Override + public UpgradeCodec newUpgradeCodec(CharSequence protocol) { + fail("Should never be invoked"); + return null; + } + }; + + HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, factory) { + @Override + protected boolean shouldHandleUpgradeRequest(HttpRequest req) { + return !req.headers().contains(HttpHeaderNames.UPGRADE, "do-not-upgrade", false); + } + }; + + EmbeddedChannel channel = new 
EmbeddedChannel(httpServerCodec, upgradeHandler); + + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: do-not-upgrade\r\n\r\n"; + ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII); + + // The upgrade request should not be passed to the next handler without any processing. + assertTrue(channel.writeInbound(upgrade)); + assertNotNull(channel.pipeline().get(HttpServerCodec.class)); + assertNull(channel.pipeline().get("marker")); + + HttpRequest req = channel.readInbound(); + assertThat(req).isNotInstanceOf(FullHttpRequest.class); // Should not be aggregated. + assertTrue(req.headers().contains(HttpHeaderNames.CONNECTION, "Upgrade", false)); + assertTrue(req.headers().contains(HttpHeaderNames.UPGRADE, "do-not-upgrade", false)); + assertTrue(channel.readInbound() instanceof LastHttpContent); + assertNull(channel.readInbound()); + + // No response should be written because we're just passing through. + channel.flushOutbound(); + assertNull(channel.readOutbound()); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void upgradeFail() { + final HttpServerCodec httpServerCodec = new HttpServerCodec(); + final UpgradeCodecFactory factory = new UpgradeCodecFactory() { + @Override + public UpgradeCodec newUpgradeCodec(CharSequence protocol) { + return new TestUpgradeCodec(); + } + }; + + HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, factory); + + EmbeddedChannel channel = new EmbeddedChannel(httpServerCodec, upgradeHandler); + + // Build a h2c upgrade request, but without connection header. 
+ String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Upgrade: h2c\r\n\r\n"; + ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII); + + assertTrue(channel.writeInbound(upgrade)); + assertNotNull(channel.pipeline().get(HttpServerCodec.class)); + assertNotNull(channel.pipeline().get(HttpServerUpgradeHandler.class)); // Should not be removed. + assertNull(channel.pipeline().get("marker")); + + HttpRequest req = channel.readInbound(); + assertEquals(HttpVersion.HTTP_1_1, req.protocolVersion()); + assertTrue(req.headers().contains(HttpHeaderNames.UPGRADE, "h2c", false)); + assertFalse(req.headers().contains(HttpHeaderNames.CONNECTION)); + ReferenceCountUtil.release(req); + assertNull(channel.readInbound()); + + // No response should be written because we're just passing through. + channel.flushOutbound(); + assertNull(channel.readOutbound()); + assertFalse(channel.finishAndReleaseAll()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java new file mode 100644 index 0000000..b9b72e0 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/HttpUtilTest.java @@ -0,0 +1,449 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.URI; +import java.nio.charset.StandardCharsets; +import java.util.ArrayList; +import java.util.List; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static io.netty.handler.codec.http.HttpUtil.normalizeAndGetContentLength; +import static java.util.Collections.singletonList; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpUtilTest { + + @Test + public void testRecognizesOriginForm() { + // Origin form: https://tools.ietf.org/html/rfc7230#section-5.3.1 + assertTrue(HttpUtil.isOriginForm(URI.create("/where?q=now"))); + // Absolute form: https://tools.ietf.org/html/rfc7230#section-5.3.2 + assertFalse(HttpUtil.isOriginForm(URI.create("http://www.example.org/pub/WWW/TheProject.html"))); + // Authority form: https://tools.ietf.org/html/rfc7230#section-5.3.3 + assertFalse(HttpUtil.isOriginForm(URI.create("www.example.com:80"))); + // Asterisk form: https://tools.ietf.org/html/rfc7230#section-5.3.4 + assertFalse(HttpUtil.isOriginForm(URI.create("*"))); + } + + @Test public void testRecognizesAsteriskForm() { + // Asterisk form: https://tools.ietf.org/html/rfc7230#section-5.3.4 + assertTrue(HttpUtil.isAsteriskForm(URI.create("*"))); + // Origin form: https://tools.ietf.org/html/rfc7230#section-5.3.1 + assertFalse(HttpUtil.isAsteriskForm(URI.create("/where?q=now"))); + // Absolute form: 
https://tools.ietf.org/html/rfc7230#section-5.3.2 + assertFalse(HttpUtil.isAsteriskForm(URI.create("http://www.example.org/pub/WWW/TheProject.html"))); + // Authority form: https://tools.ietf.org/html/rfc7230#section-5.3.3 + assertFalse(HttpUtil.isAsteriskForm(URI.create("www.example.com:80"))); + } + + @Test + public void testRemoveTransferEncodingIgnoreCase() { + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.TRANSFER_ENCODING, "Chunked"); + assertFalse(message.headers().isEmpty()); + HttpUtil.setTransferEncodingChunked(message, false); + assertTrue(message.headers().isEmpty()); + } + + // Test for https://github.com/netty/netty/issues/1690 + @Test + public void testGetOperations() { + HttpHeaders headers = new DefaultHttpHeaders(); + headers.add(of("Foo"), of("1")); + headers.add(of("Foo"), of("2")); + + assertEquals("1", headers.get(of("Foo"))); + + List values = headers.getAll(of("Foo")); + assertEquals(2, values.size()); + assertEquals("1", values.get(0)); + assertEquals("2", values.get(1)); + } + + @Test + public void testGetCharsetAsRawCharSequence() { + String QUOTES_CHARSET_CONTENT_TYPE = "text/html; charset=\"utf8\""; + String SIMPLE_CONTENT_TYPE = "text/html"; + + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, QUOTES_CHARSET_CONTENT_TYPE); + assertEquals("\"utf8\"", HttpUtil.getCharsetAsSequence(message)); + assertEquals("\"utf8\"", HttpUtil.getCharsetAsSequence(QUOTES_CHARSET_CONTENT_TYPE)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, "text/html"); + assertNull(HttpUtil.getCharsetAsSequence(message)); + assertNull(HttpUtil.getCharsetAsSequence(SIMPLE_CONTENT_TYPE)); + } + + @Test + public void testGetCharset() { + testGetCharsetUtf8("text/html; charset=utf-8"); + } + + @Test + public void testGetCharsetNoSpace() { + 
testGetCharsetUtf8("text/html;charset=utf-8"); + } + + @Test + public void testGetCharsetQuoted() { + testGetCharsetUtf8("text/html; charset=\"utf-8\""); + } + + @Test + public void testGetCharsetNoSpaceQuoted() { + testGetCharsetUtf8("text/html;charset=\"utf-8\""); + } + + private void testGetCharsetUtf8(String contentType) { + String UPPER_CASE_NORMAL_CONTENT_TYPE = contentType.toUpperCase(); + + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(contentType)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, UPPER_CASE_NORMAL_CONTENT_TYPE); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(UPPER_CASE_NORMAL_CONTENT_TYPE)); + } + + @Test + public void testGetCharsetNoLeadingQuotes() { + testGetCharsetInvalidQuotes("text/html;charset=utf-8\""); + } + + @Test + public void testGetCharsetNoTrailingQuotes() { + testGetCharsetInvalidQuotes("text/html;charset=\"utf-8"); + } + + @Test + public void testGetCharsetOnlyQuotes() { + testGetCharsetInvalidQuotes("text/html;charset=\"\""); + } + + private static void testGetCharsetInvalidQuotes(String contentType) { + String UPPER_CASE_NORMAL_CONTENT_TYPE = contentType.toUpperCase(); + + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, contentType); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(message, CharsetUtil.ISO_8859_1)); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(contentType, CharsetUtil.ISO_8859_1)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, UPPER_CASE_NORMAL_CONTENT_TYPE); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(message, CharsetUtil.ISO_8859_1)); + 
assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(UPPER_CASE_NORMAL_CONTENT_TYPE, + CharsetUtil.ISO_8859_1)); + } + + @Test + public void testGetCharsetIfNotLastParameter() { + String NORMAL_CONTENT_TYPE_WITH_PARAMETERS = "application/soap-xml; charset=utf-8; " + + "action=\"http://www.soap-service.by/foo/add\""; + + HttpMessage message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost:7788/foo"); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, NORMAL_CONTENT_TYPE_WITH_PARAMETERS); + + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(NORMAL_CONTENT_TYPE_WITH_PARAMETERS)); + + assertEquals("utf-8", HttpUtil.getCharsetAsSequence(message)); + assertEquals("utf-8", HttpUtil.getCharsetAsSequence(NORMAL_CONTENT_TYPE_WITH_PARAMETERS)); + } + + @Test + public void testGetCharset_defaultValue() { + final String SIMPLE_CONTENT_TYPE = "text/html"; + final String CONTENT_TYPE_WITH_INCORRECT_CHARSET = "text/html; charset=UTFFF"; + final String CONTENT_TYPE_WITH_ILLEGAL_CHARSET_NAME = "text/html; charset=!illegal!"; + + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, SIMPLE_CONTENT_TYPE); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(SIMPLE_CONTENT_TYPE)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, SIMPLE_CONTENT_TYPE); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message, StandardCharsets.UTF_8)); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(SIMPLE_CONTENT_TYPE, StandardCharsets.UTF_8)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, CONTENT_TYPE_WITH_INCORRECT_CHARSET); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(CONTENT_TYPE_WITH_INCORRECT_CHARSET)); + + 
message.headers().set(HttpHeaderNames.CONTENT_TYPE, CONTENT_TYPE_WITH_INCORRECT_CHARSET); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message, StandardCharsets.UTF_8)); + assertEquals(CharsetUtil.UTF_8, + HttpUtil.getCharset(CONTENT_TYPE_WITH_INCORRECT_CHARSET, StandardCharsets.UTF_8)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, CONTENT_TYPE_WITH_ILLEGAL_CHARSET_NAME); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(message)); + assertEquals(CharsetUtil.ISO_8859_1, HttpUtil.getCharset(CONTENT_TYPE_WITH_ILLEGAL_CHARSET_NAME)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, CONTENT_TYPE_WITH_ILLEGAL_CHARSET_NAME); + assertEquals(CharsetUtil.UTF_8, HttpUtil.getCharset(message, StandardCharsets.UTF_8)); + assertEquals(CharsetUtil.UTF_8, + HttpUtil.getCharset(CONTENT_TYPE_WITH_ILLEGAL_CHARSET_NAME, StandardCharsets.UTF_8)); + } + + @Test + public void testGetMimeType() { + final String SIMPLE_CONTENT_TYPE = "text/html"; + final String NORMAL_CONTENT_TYPE = "text/html; charset=utf-8"; + + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + assertNull(HttpUtil.getMimeType(message)); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, ""); + assertNull(HttpUtil.getMimeType(message)); + assertNull(HttpUtil.getMimeType("")); + message.headers().set(HttpHeaderNames.CONTENT_TYPE, SIMPLE_CONTENT_TYPE); + assertEquals("text/html", HttpUtil.getMimeType(message)); + assertEquals("text/html", HttpUtil.getMimeType(SIMPLE_CONTENT_TYPE)); + + message.headers().set(HttpHeaderNames.CONTENT_TYPE, NORMAL_CONTENT_TYPE); + assertEquals("text/html", HttpUtil.getMimeType(message)); + assertEquals("text/html", HttpUtil.getMimeType(NORMAL_CONTENT_TYPE)); + } + + @Test + public void testGetContentLengthThrowsNumberFormatException() { + final HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_LENGTH, "bar"); + try { + 
HttpUtil.getContentLength(message); + fail(); + } catch (final NumberFormatException e) { + // a number format exception is expected here + } + } + + @Test + public void testGetContentLengthIntDefaultValueThrowsNumberFormatException() { + final HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_LENGTH, "bar"); + try { + HttpUtil.getContentLength(message, 1); + fail(); + } catch (final NumberFormatException e) { + // a number format exception is expected here + } + } + + @Test + public void testGetContentLengthLongDefaultValueThrowsNumberFormatException() { + final HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.CONTENT_LENGTH, "bar"); + try { + HttpUtil.getContentLength(message, 1L); + fail(); + } catch (final NumberFormatException e) { + // a number format exception is expected here + } + } + + @Test + public void testDoubleChunkedHeader() { + HttpMessage message = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().add(HttpHeaderNames.TRANSFER_ENCODING, "chunked"); + HttpUtil.setTransferEncodingChunked(message, true); + List expected = singletonList("chunked"); + assertEquals(expected, message.headers().getAll(HttpHeaderNames.TRANSFER_ENCODING)); + } + + private static List allPossibleCasesOfContinue() { + final List cases = new ArrayList(); + final String c = "continue"; + for (int i = 0; i < Math.pow(2, c.length()); i++) { + final StringBuilder sb = new StringBuilder(c.length()); + int j = i; + int k = 0; + while (j > 0) { + if ((j & 1) == 1) { + sb.append(Character.toUpperCase(c.charAt(k++))); + } else { + sb.append(c.charAt(k++)); + } + j >>= 1; + } + for (; k < c.length(); k++) { + sb.append(c.charAt(k)); + } + cases.add(sb.toString()); + } + return cases; + } + + @Test + public void testIs100Continue() { + // test all possible cases of 100-continue + 
for (final String continueCase : allPossibleCasesOfContinue()) { + run100ContinueTest(HttpVersion.HTTP_1_1, "100-" + continueCase, true); + } + run100ContinueTest(HttpVersion.HTTP_1_1, null, false); + run100ContinueTest(HttpVersion.HTTP_1_1, "chocolate=yummy", false); + run100ContinueTest(HttpVersion.HTTP_1_0, "100-continue", false); + final HttpMessage message = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set(HttpHeaderNames.EXPECT, "100-continue"); + run100ContinueTest(message, false); + } + + private static void run100ContinueTest(final HttpVersion version, final String expectations, boolean expect) { + final HttpMessage message = new DefaultFullHttpRequest(version, HttpMethod.GET, "/"); + if (expectations != null) { + message.headers().set(HttpHeaderNames.EXPECT, expectations); + } + run100ContinueTest(message, expect); + } + + private static void run100ContinueTest(final HttpMessage message, final boolean expected) { + assertEquals(expected, HttpUtil.is100ContinueExpected(message)); + ReferenceCountUtil.release(message); + } + + @Test + public void testContainsUnsupportedExpectation() { + // test all possible cases of 100-continue + for (final String continueCase : allPossibleCasesOfContinue()) { + runUnsupportedExpectationTest(HttpVersion.HTTP_1_1, "100-" + continueCase, false); + } + runUnsupportedExpectationTest(HttpVersion.HTTP_1_1, null, false); + runUnsupportedExpectationTest(HttpVersion.HTTP_1_1, "chocolate=yummy", true); + runUnsupportedExpectationTest(HttpVersion.HTTP_1_0, "100-continue", false); + final HttpMessage message = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + message.headers().set("Expect", "100-continue"); + runUnsupportedExpectationTest(message, false); + } + + private static void runUnsupportedExpectationTest(final HttpVersion version, + final String expectations, boolean expect) { + final HttpMessage message = new DefaultFullHttpRequest(version, 
HttpMethod.GET, "/"); + if (expectations != null) { + message.headers().set("Expect", expectations); + } + runUnsupportedExpectationTest(message, expect); + } + + private static void runUnsupportedExpectationTest(final HttpMessage message, final boolean expected) { + assertEquals(expected, HttpUtil.isUnsupportedExpectation(message)); + ReferenceCountUtil.release(message); + } + + @Test + public void testFormatHostnameForHttpFromResolvedAddressWithHostname() throws Exception { + InetSocketAddress socketAddress = new InetSocketAddress(InetAddress.getByName("localhost"), 8080); + assertEquals("localhost", HttpUtil.formatHostnameForHttp(socketAddress)); + } + + @Test + public void testFormatHostnameForHttpFromUnesolvedAddressWithHostname() { + InetSocketAddress socketAddress = InetSocketAddress.createUnresolved("localhost", 80); + assertEquals("localhost", HttpUtil.formatHostnameForHttp(socketAddress)); + } + + @Test + public void testIpv6() throws Exception { + InetSocketAddress socketAddress = new InetSocketAddress(InetAddress.getByName("::1"), 8080); + assertEquals("[::1]", HttpUtil.formatHostnameForHttp(socketAddress)); + } + + @Test + public void testIpv6Unresolved() { + InetSocketAddress socketAddress = InetSocketAddress.createUnresolved("::1", 8080); + assertEquals("[::1]", HttpUtil.formatHostnameForHttp(socketAddress)); + } + + @Test + public void testIpv4() throws Exception { + InetSocketAddress socketAddress = new InetSocketAddress(InetAddress.getByName("10.0.0.1"), 8080); + assertEquals("10.0.0.1", HttpUtil.formatHostnameForHttp(socketAddress)); + } + + @Test + public void testIpv4Unresolved() { + InetSocketAddress socketAddress = InetSocketAddress.createUnresolved("10.0.0.1", 8080); + assertEquals("10.0.0.1", HttpUtil.formatHostnameForHttp(socketAddress)); + } + + @Test + public void testKeepAliveIfConnectionHeaderAbsent() { + HttpMessage http11Message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http:localhost/http_1_1"); + 
assertTrue(HttpUtil.isKeepAlive(http11Message)); + + HttpMessage http10Message = new DefaultHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.GET, + "http:localhost/http_1_0"); + assertFalse(HttpUtil.isKeepAlive(http10Message)); + } + + @Test + public void testKeepAliveIfConnectionHeaderMultipleValues() { + HttpMessage http11Message = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "http:localhost/http_1_1"); + http11Message.headers().set( + HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE + ", " + HttpHeaderValues.CLOSE); + assertFalse(HttpUtil.isKeepAlive(http11Message)); + + http11Message.headers().set( + HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE + ", Close"); + assertFalse(HttpUtil.isKeepAlive(http11Message)); + + http11Message.headers().set( + HttpHeaderNames.CONNECTION, HttpHeaderValues.CLOSE + ", " + HttpHeaderValues.UPGRADE); + assertFalse(HttpUtil.isKeepAlive(http11Message)); + + http11Message.headers().set( + HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE + ", " + HttpHeaderValues.KEEP_ALIVE); + assertTrue(HttpUtil.isKeepAlive(http11Message)); + } + + @Test + public void normalizeAndGetContentLengthEmpty() { + testNormalizeAndGetContentLengthInvalidContentLength(""); + } + + @Test + public void normalizeAndGetContentLengthNotANumber() { + testNormalizeAndGetContentLengthInvalidContentLength("foo"); + } + + @Test + public void normalizeAndGetContentLengthNegative() { + testNormalizeAndGetContentLengthInvalidContentLength("-1"); + } + + private static void testNormalizeAndGetContentLengthInvalidContentLength(final String contentLengthField) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + normalizeAndGetContentLength(singletonList(contentLengthField), false, false); + } + }); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/MultipleContentLengthHeadersTest.java 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/MultipleContentLengthHeadersTest.java new file mode 100644 index 0000000..793864e --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/MultipleContentLengthHeadersTest.java @@ -0,0 +1,130 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.Arrays; +import java.util.Collection; +import java.util.List; + +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_INITIAL_BUFFER_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_CHUNK_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_HEADER_SIZE; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_MAX_INITIAL_LINE_LENGTH; +import static io.netty.handler.codec.http.HttpObjectDecoder.DEFAULT_VALIDATE_HEADERS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.contains; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.is; +import static 
org.hamcrest.core.IsInstanceOf.instanceOf; + +public class MultipleContentLengthHeadersTest { + + static Collection parameters() { + return Arrays.asList(new Object[][] { + { false, false, false }, + { false, false, true }, + { false, true, false }, + { false, true, true }, + { true, false, false }, + { true, false, true }, + { true, true, false }, + { true, true, true } + }); + } + + private static EmbeddedChannel newChannel(boolean allowDuplicateContentLengths) { + HttpRequestDecoder decoder = new HttpRequestDecoder( + DEFAULT_MAX_INITIAL_LINE_LENGTH, + DEFAULT_MAX_HEADER_SIZE, + DEFAULT_MAX_CHUNK_SIZE, + DEFAULT_VALIDATE_HEADERS, + DEFAULT_INITIAL_BUFFER_SIZE, + allowDuplicateContentLengths); + return new EmbeddedChannel(decoder); + } + + @ParameterizedTest + @MethodSource("parameters") + public void testMultipleContentLengthHeadersBehavior(boolean allowDuplicateContentLengths, + boolean sameValue, boolean singleField) { + EmbeddedChannel channel = newChannel(allowDuplicateContentLengths); + String requestStr = setupRequestString(sameValue, singleField); + assertThat(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII)), is(true)); + HttpRequest request = channel.readInbound(); + + if (allowDuplicateContentLengths) { + if (sameValue) { + assertValid(request); + List contentLengths = request.headers().getAll(HttpHeaderNames.CONTENT_LENGTH); + assertThat(contentLengths, contains("1")); + LastHttpContent body = channel.readInbound(); + assertThat(body.content().readableBytes(), is(1)); + assertThat(body.content().readCharSequence(1, CharsetUtil.US_ASCII).toString(), is("a")); + } else { + assertInvalid(request); + } + } else { + assertInvalid(request); + } + assertThat(channel.finish(), is(false)); + } + + private static String setupRequestString(boolean sameValue, boolean singleField) { + String firstValue = "1"; + String secondValue = sameValue ? 
firstValue : "2"; + String contentLength; + if (singleField) { + contentLength = "Content-Length: " + firstValue + ", " + secondValue + "\r\n\r\n"; + } else { + contentLength = "Content-Length: " + firstValue + "\r\n" + + "Content-Length: " + secondValue + "\r\n\r\n"; + } + return "PUT /some/path HTTP/1.1\r\n" + + contentLength + + "ab"; + } + + @Test + public void testDanglingComma() { + EmbeddedChannel channel = newChannel(false); + String requestStr = "GET /some/path HTTP/1.1\r\n" + + "Content-Length: 1,\r\n" + + "Connection: close\n\n" + + "ab"; + assertThat(channel.writeInbound(Unpooled.copiedBuffer(requestStr, CharsetUtil.US_ASCII)), is(true)); + HttpRequest request = channel.readInbound(); + assertInvalid(request); + assertThat(channel.finish(), is(false)); + } + + private static void assertValid(HttpRequest request) { + assertThat(request.decoderResult().isFailure(), is(false)); + } + + private static void assertInvalid(HttpRequest request) { + assertThat(request.decoderResult().isFailure(), is(true)); + assertThat(request.decoderResult().cause(), instanceOf(IllegalArgumentException.class)); + assertThat(request.decoderResult().cause().getMessage(), + containsString("Multiple Content-Length values found")); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringDecoderTest.java new file mode 100644 index 0000000..b6adc4e --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringDecoderTest.java @@ -0,0 +1,384 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http;

import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.Test;

import java.net.URI;
import java.net.URISyntaxException;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@link QueryStringDecoder}: path/query splitting, percent-decoding,
 * repeated parameters, fragment handling and the semicolon/plus-sign corner cases.
 */
public class QueryStringDecoderTest {

    @Test
    public void testBasicUris() throws URISyntaxException {
        QueryStringDecoder d = new QueryStringDecoder(new URI("http://localhost/path"));
        assertEquals(0, d.parameters().size());
    }

    @Test
    public void testBasic() {
        QueryStringDecoder d;

        d = new QueryStringDecoder("/foo");
        assertEquals("/foo", d.path());
        assertEquals(0, d.parameters().size());

        d = new QueryStringDecoder("/foo%20bar");
        assertEquals("/foo bar", d.path());
        assertEquals(0, d.parameters().size());

        // Only the first '=' separates name and value.
        d = new QueryStringDecoder("/foo?a=b=c");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(1, d.parameters().get("a").size());
        assertEquals("b=c", d.parameters().get("a").get(0));

        d = new QueryStringDecoder("/foo?a=1&a=2");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(2, d.parameters().get("a").size());
        assertEquals("1", d.parameters().get("a").get(0));
        assertEquals("2", d.parameters().get("a").get(1));

        d = new QueryStringDecoder("/foo%20bar?a=1&a=2");
        assertEquals("/foo bar", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(2, d.parameters().get("a").size());
        assertEquals("1", d.parameters().get("a").get(0));
        assertEquals("2", d.parameters().get("a").get(1));

        d = new QueryStringDecoder("/foo?a=&a=2");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(2, d.parameters().get("a").size());
        assertEquals("", d.parameters().get("a").get(0));
        assertEquals("2", d.parameters().get("a").get(1));

        d = new QueryStringDecoder("/foo?a=1&a=");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(2, d.parameters().get("a").size());
        assertEquals("1", d.parameters().get("a").get(0));
        assertEquals("", d.parameters().get("a").get(1));

        d = new QueryStringDecoder("/foo?a=1&a=&a=");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(3, d.parameters().get("a").size());
        assertEquals("1", d.parameters().get("a").get(0));
        assertEquals("", d.parameters().get("a").get(1));
        assertEquals("", d.parameters().get("a").get(2));

        d = new QueryStringDecoder("/foo?a=1=&a==2");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(2, d.parameters().get("a").size());
        assertEquals("1=", d.parameters().get("a").get(0));
        assertEquals("=2", d.parameters().get("a").get(1));

        d = new QueryStringDecoder("/foo?abc=1%2023&abc=124%20");
        assertEquals("/foo", d.path());
        assertEquals(1, d.parameters().size());
        assertEquals(2, d.parameters().get("abc").size());
        assertEquals("1 23", d.parameters().get("abc").get(0));
        assertEquals("124 ", d.parameters().get("abc").get(1));

        d = new QueryStringDecoder("/foo?abc=%7E");
        assertEquals("~", d.parameters().get("abc").get(0));
    }

    @Test
    public void testExotic() {
        assertQueryString("", "");
        assertQueryString("foo", "foo");
        assertQueryString("foo", "foo?");
        assertQueryString("/foo", "/foo?");
        assertQueryString("/foo", "/foo");
        assertQueryString("?a=", "?a");
        assertQueryString("foo?a=", "foo?a");
        assertQueryString("/foo?a=", "/foo?a");
        assertQueryString("/foo?a=", "/foo?a&");
        assertQueryString("/foo?a=", "/foo?&a");
        assertQueryString("/foo?a=", "/foo?&a&");
        assertQueryString("/foo?a=", "/foo?&=a");
        assertQueryString("/foo?a=", "/foo?=a&");
        assertQueryString("/foo?a=", "/foo?a=&");
        assertQueryString("/foo?a=b&c=d", "/foo?a=b&&c=d");
        assertQueryString("/foo?a=b&c=d", "/foo?a=b&=&c=d");
        assertQueryString("/foo?a=b&c=d", "/foo?a=b&==&c=d");
        assertQueryString("/foo?a=b&c=&x=y", "/foo?a=b&c&x=y");
        assertQueryString("/foo?a=", "/foo?a=");
        assertQueryString("/foo?a=", "/foo?&a=");
        assertQueryString("/foo?a=b&c=d", "/foo?a=b&c=d");
        assertQueryString("/foo?a=1&a=&a=", "/foo?a=1&a&a=");
    }

    @Test
    public void testSemicolon() {
        assertQueryString("/foo?a=1;2", "/foo?a=1;2", false);
        // ";" should be treated as a normal character, see #8855
        assertQueryString("/foo?a=1;2", "/foo?a=1%3B2", true);
    }

    @Test
    public void testPathSpecific() {
        // decode escaped characters
        assertEquals("/foo bar/", new QueryStringDecoder("/foo%20bar/?").path());
        assertEquals("/foo\r\n\\bar/", new QueryStringDecoder("/foo%0D%0A\\bar/?").path());

        // a 'fragment' after '#' should be cut off (see RFC 3986)
        assertEquals("", new QueryStringDecoder("#123").path());
        assertEquals("foo", new QueryStringDecoder("foo?bar#anchor").path());
        assertEquals("/foo-bar", new QueryStringDecoder("/foo-bar#anchor").path());
        assertEquals("/foo-bar", new QueryStringDecoder("/foo-bar#a#b?c=d").path());

        // '+' is not decoded to ' ' in the path component
        assertEquals("+", new QueryStringDecoder("+").path());
        assertEquals("/foo+bar/", new QueryStringDecoder("/foo+bar/?").path());
        assertEquals("/foo++", new QueryStringDecoder("/foo++?index.php").path());
        assertEquals("/foo +", new QueryStringDecoder("/foo%20+?index.php").path());
        assertEquals("/foo+ ", new QueryStringDecoder("/foo+%20").path());
    }

    @Test
    public void testExcludeFragment() {
        // a 'fragment' after '#' should be cut off (see RFC 3986)
        assertEquals("a", new QueryStringDecoder("?a#anchor").parameters().keySet().iterator().next());
        assertEquals("b", new QueryStringDecoder("?a=b#anchor").parameters().get("a").get(0));
        assertTrue(new QueryStringDecoder("?#").parameters().isEmpty());
        assertTrue(new QueryStringDecoder("?#anchor").parameters().isEmpty());
        assertTrue(new QueryStringDecoder("#?a=b#anchor").parameters().isEmpty());
        assertTrue(new QueryStringDecoder("?#a=b#anchor").parameters().isEmpty());
    }

    @Test
    public void testHashDos() {
        // The decoder caps the number of decoded parameters (1024 by default) to
        // mitigate hash-collision DoS attacks.
        StringBuilder buf = new StringBuilder();
        buf.append('?');
        for (int i = 0; i < 65536; i++) {
            buf.append('k');
            buf.append(i);
            buf.append("=v");
            buf.append(i);
            buf.append('&');
        }
        assertEquals(1024, new QueryStringDecoder(buf.toString()).parameters().size());
    }

    @Test
    public void testHasPath() {
        QueryStringDecoder decoder = new QueryStringDecoder("1=2", false);
        assertEquals("", decoder.path());
        Map<String, List<String>> params = decoder.parameters();
        assertEquals(1, params.size());
        assertTrue(params.containsKey("1"));
        List<String> param = params.get("1");
        assertNotNull(param);
        assertEquals(1, param.size());
        assertEquals("2", param.get(0));
    }

    @Test
    public void testUrlDecoding() throws Exception {
        final String caffe = new String(
                // "Caffé" but instead of putting the literal E-acute in the
                // source file, we directly use the UTF-8 encoding so as to
                // not rely on the platform's default encoding (not portable).
                new byte[] {'C', 'a', 'f', 'f', (byte) 0xC3, (byte) 0xA9},
                "UTF-8");
        final String[] tests = {
            // Encoded -> Decoded or error message substring
            "", "",
            "foo", "foo",
            "f+o", "f o",
            "f++", "f  ",
            "fo%", "unterminated escape sequence at index 2 of: fo%",
            "%42", "B",
            "%5f", "_",
            "f%4", "unterminated escape sequence at index 1 of: f%4",
            "%x2", "invalid hex byte 'x2' at index 1 of '%x2'",
            "%4x", "invalid hex byte '4x' at index 1 of '%4x'",
            "Caff%C3%A9", caffe,
            "случайный праздник", "случайный праздник",
            "случайный%20праздник", "случайный праздник",
            "случайный%20праздник%20%E2%98%BA", "случайный праздник ☺",
        };
        for (int i = 0; i < tests.length; i += 2) {
            final String encoded = tests[i];
            final String expected = tests[i + 1];
            try {
                final String decoded = QueryStringDecoder.decodeComponent(encoded);
                assertEquals(expected, decoded);
            } catch (IllegalArgumentException e) {
                // Malformed escapes are expected to fail; the message substring is pinned above.
                assertEquals(expected, e.getMessage());
            }
        }
    }

    private static void assertQueryString(String expected, String actual) {
        assertQueryString(expected, actual, false);
    }

    /**
     * Decodes both strings and asserts they produce the same path and parameter map.
     */
    private static void assertQueryString(String expected, String actual, boolean semicolonIsNormalChar) {
        QueryStringDecoder ed = new QueryStringDecoder(expected, CharsetUtil.UTF_8, true,
                1024, semicolonIsNormalChar);
        QueryStringDecoder ad = new QueryStringDecoder(actual, CharsetUtil.UTF_8, true,
                1024, semicolonIsNormalChar);
        assertEquals(ed.path(), ad.path());
        assertEquals(ed.parameters(), ad.parameters());
    }

    /**
     * Asserts that the iterator's next entry is {@code expectedKey} mapped to exactly
     * one value, {@code expectedValue}.
     */
    private static void assertNextEntry(Iterator<Entry<String, List<String>>> entries,
                                        String expectedKey, String expectedValue) {
        Entry<String, List<String>> entry = entries.next();
        assertEquals(expectedKey, entry.getKey());
        assertEquals(1, entry.getValue().size());
        assertEquals(expectedValue, entry.getValue().get(0));
    }

    // See #189
    @Test
    public void testURI() {
        URI uri = URI.create("http://localhost:8080/foo?param1=value1&param2=value2&param3=value3");
        QueryStringDecoder decoder = new QueryStringDecoder(uri);
        assertEquals("/foo", decoder.path());
        assertEquals("/foo", decoder.rawPath());
        assertEquals("param1=value1&param2=value2&param3=value3", decoder.rawQuery());
        Map<String, List<String>> params = decoder.parameters();
        assertEquals(3, params.size());
        Iterator<Entry<String, List<String>>> entries = params.entrySet().iterator();
        assertNextEntry(entries, "param1", "value1");
        assertNextEntry(entries, "param2", "value2");
        assertNextEntry(entries, "param3", "value3");
        assertFalse(entries.hasNext());
    }

    // See #189
    @Test
    public void testURISlashPath() {
        URI uri = URI.create("http://localhost:8080/?param1=value1&param2=value2&param3=value3");
        QueryStringDecoder decoder = new QueryStringDecoder(uri);
        assertEquals("/", decoder.path());
        assertEquals("/", decoder.rawPath());
        assertEquals("param1=value1&param2=value2&param3=value3", decoder.rawQuery());

        Map<String, List<String>> params = decoder.parameters();
        assertEquals(3, params.size());
        Iterator<Entry<String, List<String>>> entries = params.entrySet().iterator();
        assertNextEntry(entries, "param1", "value1");
        assertNextEntry(entries, "param2", "value2");
        assertNextEntry(entries, "param3", "value3");
        assertFalse(entries.hasNext());
    }

    // See #189
    @Test
    public void testURINoPath() {
        URI uri = URI.create("http://localhost:8080?param1=value1&param2=value2&param3=value3");
        QueryStringDecoder decoder = new QueryStringDecoder(uri);
        assertEquals("", decoder.path());
        assertEquals("", decoder.rawPath());
        assertEquals("param1=value1&param2=value2&param3=value3", decoder.rawQuery());

        Map<String, List<String>> params = decoder.parameters();
        assertEquals(3, params.size());
        Iterator<Entry<String, List<String>>> entries = params.entrySet().iterator();
        assertNextEntry(entries, "param1", "value1");
        assertNextEntry(entries, "param2", "value2");
        assertNextEntry(entries, "param3", "value3");
        assertFalse(entries.hasNext());
    }

    // See https://github.com/netty/netty/issues/1833
    @Test
    public void testURI2() {
        URI uri = URI.create("http://foo.com/images;num=10?query=name;value=123");
        QueryStringDecoder decoder = new QueryStringDecoder(uri);
        assertEquals("/images;num=10", decoder.path());
        assertEquals("/images;num=10", decoder.rawPath());
        assertEquals("query=name;value=123", decoder.rawQuery());

        Map<String, List<String>> params = decoder.parameters();
        assertEquals(2, params.size());
        Iterator<Entry<String, List<String>>> entries = params.entrySet().iterator();
        assertNextEntry(entries, "query", "name");
        assertNextEntry(entries, "value", "123");
        assertFalse(entries.hasNext());
    }

    @Test
    public void testEmptyStrings() {
        QueryStringDecoder pathSlash = new QueryStringDecoder("path/");
        assertEquals("path/", pathSlash.rawPath());
        assertEquals("", pathSlash.rawQuery());
        QueryStringDecoder pathQuestion = new QueryStringDecoder("path?");
        assertEquals("path", pathQuestion.rawPath());
        assertEquals("", pathQuestion.rawQuery());
        QueryStringDecoder empty = new QueryStringDecoder("");
        assertEquals("", empty.rawPath());
        assertEquals("", empty.rawQuery());
    }
}
diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringEncoderTest.java new file mode 100644 index 0000000..e30459a --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/QueryStringEncoderTest.java @@ -0,0 +1,81 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.nio.charset.Charset; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class QueryStringEncoderTest { + + @Test + public void testDefaultEncoding() throws Exception { + QueryStringEncoder e; + + e = new QueryStringEncoder("/foo"); + e.addParam("a", "b=c"); + assertEquals("/foo?a=b%3Dc", e.toString()); + assertEquals(new URI("/foo?a=b%3Dc"), e.toUri()); + + e = new QueryStringEncoder("/foo/\u00A5"); + e.addParam("a", "\u00A5"); + assertEquals("/foo/\u00A5?a=%C2%A5", e.toString()); + assertEquals(new URI("/foo/\u00A5?a=%C2%A5"), e.toUri()); + + e = new QueryStringEncoder("/foo/\u00A5"); + e.addParam("a", "abc\u00A5"); + assertEquals("/foo/\u00A5?a=abc%C2%A5", e.toString()); + assertEquals(new URI("/foo/\u00A5?a=abc%C2%A5"), e.toUri()); + + e = new QueryStringEncoder("/foo"); + e.addParam("a", "1"); + e.addParam("b", "2"); + assertEquals("/foo?a=1&b=2", e.toString()); + assertEquals(new URI("/foo?a=1&b=2"), e.toUri()); + + e = new QueryStringEncoder("/foo"); + e.addParam("a", "1"); + e.addParam("b", ""); + e.addParam("c", null); + e.addParam("d", null); + assertEquals("/foo?a=1&b=&c&d", e.toString()); + assertEquals(new URI("/foo?a=1&b=&c&d"), e.toUri()); + + e = new QueryStringEncoder("/foo"); + e.addParam("test", "a~b"); + assertEquals("/foo?test=a~b", e.toString()); + assertEquals(new URI("/foo?test=a~b"), e.toUri()); + } + + @Test + public void testNonDefaultEncoding() throws Exception { + QueryStringEncoder e = new QueryStringEncoder("/foo/\u00A5", Charset.forName("UTF-16")); + e.addParam("a", "\u00A5"); + assertEquals("/foo/\u00A5?a=%FE%FF%00%A5", e.toString()); + assertEquals(new URI("/foo/\u00A5?a=%FE%FF%00%A5"), e.toUri()); + } + + @Test + public void testWhitespaceEncoding() throws Exception { + QueryStringEncoder e = new QueryStringEncoder("/foo"); + e.addParam("a", "b c"); + assertEquals("/foo?a=b%20c", 
e.toString()); + assertEquals(new URI("/foo?a=b%20c"), e.toUri()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/ReadOnlyHttpHeadersTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/ReadOnlyHttpHeadersTest.java new file mode 100644 index 0000000..395f3b2 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/ReadOnlyHttpHeadersTest.java @@ -0,0 +1,166 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Set; + +import static io.netty.handler.codec.http.HttpHeaderNames.ACCEPT; +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_JSON; +import static io.netty.handler.codec.http.HttpHeaderValues.APPLICATION_OCTET_STREAM; +import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE; +import static io.netty.handler.codec.http.HttpHeaderValues.ZERO; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ReadOnlyHttpHeadersTest { + @Test + public void getValue() { + ReadOnlyHttpHeaders headers = new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON); + assertFalse(headers.isEmpty()); + assertEquals(1, headers.size()); + assertTrue(APPLICATION_JSON.contentEquals(headers.get(ACCEPT))); + assertTrue(headers.contains(ACCEPT)); + assertNull(headers.get(CONTENT_LENGTH)); + assertFalse(headers.contains(CONTENT_LENGTH)); + } + + @Test + public void charSequenceIterator() { + ReadOnlyHttpHeaders headers = new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON, CONTENT_LENGTH, ZERO, CONNECTION, CLOSE); + assertFalse(headers.isEmpty()); + assertEquals(3, headers.size()); + Iterator> itr = headers.iteratorCharSequence(); + assertTrue(itr.hasNext()); + Entry next = itr.next(); + assertTrue(ACCEPT.contentEqualsIgnoreCase(next.getKey())); + 
assertTrue(APPLICATION_JSON.contentEqualsIgnoreCase(next.getValue())); + assertTrue(itr.hasNext()); + next = itr.next(); + assertTrue(CONTENT_LENGTH.contentEqualsIgnoreCase(next.getKey())); + assertTrue(ZERO.contentEqualsIgnoreCase(next.getValue())); + assertTrue(itr.hasNext()); + next = itr.next(); + assertTrue(CONNECTION.contentEqualsIgnoreCase(next.getKey())); + assertTrue(CLOSE.contentEqualsIgnoreCase(next.getValue())); + assertFalse(itr.hasNext()); + } + + @Test + public void stringIterator() { + ReadOnlyHttpHeaders headers = new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON, CONTENT_LENGTH, ZERO, CONNECTION, CLOSE); + assertFalse(headers.isEmpty()); + assertEquals(3, headers.size()); + assert3ParisEquals(headers.iterator()); + } + + @Test + public void entries() { + ReadOnlyHttpHeaders headers = new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON, CONTENT_LENGTH, ZERO, CONNECTION, CLOSE); + assertFalse(headers.isEmpty()); + assertEquals(3, headers.size()); + assert3ParisEquals(headers.entries().iterator()); + } + + @Test + public void names() { + ReadOnlyHttpHeaders headers = new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON, CONTENT_LENGTH, ZERO, CONNECTION, CLOSE); + assertFalse(headers.isEmpty()); + assertEquals(3, headers.size()); + Set names = headers.names(); + assertEquals(3, names.size()); + assertTrue(names.contains(ACCEPT.toString())); + assertTrue(names.contains(CONTENT_LENGTH.toString())); + assertTrue(names.contains(CONNECTION.toString())); + } + + @Test + public void getAll() { + ReadOnlyHttpHeaders headers = new ReadOnlyHttpHeaders(false, + ACCEPT, APPLICATION_JSON, CONTENT_LENGTH, ZERO, ACCEPT, APPLICATION_OCTET_STREAM); + assertFalse(headers.isEmpty()); + assertEquals(3, headers.size()); + List names = headers.getAll(ACCEPT); + assertEquals(2, names.size()); + assertTrue(APPLICATION_JSON.contentEqualsIgnoreCase(names.get(0))); + assertTrue(APPLICATION_OCTET_STREAM.contentEqualsIgnoreCase(names.get(1))); + } + + @Test + 
public void validateNamesFail() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON, AsciiString.cached(" ")); + } + }); + } + + @Test + public void emptyHeaderName() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new ReadOnlyHttpHeaders(true, + ACCEPT, APPLICATION_JSON, AsciiString.cached(" "), ZERO); + } + }); + } + + @Test + public void headerWithoutValue() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new ReadOnlyHttpHeaders(false, + ACCEPT, APPLICATION_JSON, CONTENT_LENGTH); + } + }); + } + + private static void assert3ParisEquals(Iterator> itr) { + assertTrue(itr.hasNext()); + Entry next = itr.next(); + assertTrue(ACCEPT.contentEqualsIgnoreCase(next.getKey())); + assertTrue(APPLICATION_JSON.contentEqualsIgnoreCase(next.getValue())); + assertTrue(itr.hasNext()); + next = itr.next(); + assertTrue(CONTENT_LENGTH.contentEqualsIgnoreCase(next.getKey())); + assertTrue(ZERO.contentEqualsIgnoreCase(next.getValue())); + assertTrue(itr.hasNext()); + next = itr.next(); + assertTrue(CONNECTION.contentEqualsIgnoreCase(next.getKey())); + assertTrue(CLOSE.contentEqualsIgnoreCase(next.getValue())); + assertFalse(itr.hasNext()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ClientCookieDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ClientCookieDecoderTest.java new file mode 100644 index 0000000..deb0edb --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ClientCookieDecoderTest.java @@ -0,0 +1,292 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.cookie; + +import io.netty.handler.codec.DateFormatter; +import io.netty.handler.codec.http.cookie.CookieHeaderNames.SameSite; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.Calendar; +import java.util.Collection; +import java.util.Date; +import java.util.Iterator; +import java.util.TimeZone; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ClientCookieDecoderTest { + @Test + public void testDecodingSingleCookieV0() { + String cookieString = "myCookie=myValue;expires=" + + DateFormatter.format(new Date(System.currentTimeMillis() + 50000)) + + ";path=/apathsomewhere;domain=.adomainsomewhere;secure;SameSite=None"; + + Cookie cookie = ClientCookieDecoder.STRICT.decode(cookieString); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + assertEquals(".adomainsomewhere", cookie.domain()); + assertNotEquals(Long.MIN_VALUE, cookie.maxAge(), + "maxAge should be defined when parsing cookie " + cookieString); + assertTrue(cookie.maxAge() >= 40 && cookie.maxAge() <= 60, + 
"maxAge should be about 50ms when parsing cookie " + cookieString); + assertEquals("/apathsomewhere", cookie.path()); + assertTrue(cookie.isSecure()); + + assertThat(cookie, is(instanceOf(DefaultCookie.class))); + assertEquals(SameSite.None, ((DefaultCookie) cookie).sameSite()); + } + + @Test + public void testDecodingSingleCookieV0ExtraParamsIgnored() { + String cookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=0;" + + "commentURL=http://aurl.com;port=\"80,8080\";discard;"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(cookieString); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + assertEquals(".adomainsomewhere", cookie.domain()); + assertEquals(50, cookie.maxAge()); + assertEquals("/apathsomewhere", cookie.path()); + assertTrue(cookie.isSecure()); + } + + @Test + public void testDecodingSingleCookieV1() { + String cookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;domain=.adomainsomewhere" + + ";secure;comment=this is a comment;version=1;"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(cookieString); + assertEquals("myValue", cookie.value()); + assertNotNull(cookie); + assertEquals(".adomainsomewhere", cookie.domain()); + assertEquals(50, cookie.maxAge()); + assertEquals("/apathsomewhere", cookie.path()); + assertTrue(cookie.isSecure()); + } + + @Test + public void testDecodingSingleCookieV1ExtraParamsIgnored() { + String cookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=1;" + + "commentURL=http://aurl.com;port='80,8080';discard;"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(cookieString); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + assertEquals(".adomainsomewhere", cookie.domain()); + assertEquals(50, cookie.maxAge()); + assertEquals("/apathsomewhere", cookie.path()); + assertTrue(cookie.isSecure()); + } + + 
@Test + public void testDecodingSingleCookieV2() { + String cookieString = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=2;" + + "commentURL=http://aurl.com;port=\"80,8080\";discard;"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(cookieString); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + assertEquals(".adomainsomewhere", cookie.domain()); + assertEquals(50, cookie.maxAge()); + assertEquals("/apathsomewhere", cookie.path()); + assertTrue(cookie.isSecure()); + } + + @Test + public void testDecodingComplexCookie() { + String c1 = "myCookie=myValue;max-age=50;path=/apathsomewhere;" + + "domain=.adomainsomewhere;secure;comment=this is a comment;version=2;" + + "commentURL=\"http://aurl.com\";port='80,8080';discard;"; + + Cookie cookie = ClientCookieDecoder.STRICT.decode(c1); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + assertEquals(".adomainsomewhere", cookie.domain()); + assertEquals(50, cookie.maxAge()); + assertEquals("/apathsomewhere", cookie.path()); + assertTrue(cookie.isSecure()); + } + + @Test + public void testDecodingQuotedCookie() { + Collection sources = new ArrayList(); + sources.add("a=\"\","); + sources.add("b=\"1\","); + + Collection cookies = new ArrayList(); + for (String source : sources) { + cookies.add(ClientCookieDecoder.STRICT.decode(source)); + } + + Iterator it = cookies.iterator(); + Cookie c; + + c = it.next(); + assertEquals("a", c.name()); + assertEquals("", c.value()); + + c = it.next(); + assertEquals("b", c.name()); + assertEquals("1", c.value()); + + assertFalse(it.hasNext()); + } + + @Test + public void testDecodingGoogleAnalyticsCookie() { + String source = "ARPT=LWUKQPSWRTUN04CKKJI; " + + "kw-2E343B92-B097-442c-BFA5-BE371E0325A2=unfinished furniture; " + + "__utma=48461872.1094088325.1258140131.1258140131.1258140131.1; " + + "__utmb=48461872.13.10.1258140131; __utmc=48461872; " + + 
"__utmz=48461872.1258140131.1.1.utmcsr=overstock.com|utmccn=(referral)|" + + "utmcmd=referral|utmcct=/Home-Garden/Furniture/Clearance,/clearance,/32/dept.html"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(source); + + assertEquals("ARPT", cookie.name()); + assertEquals("LWUKQPSWRTUN04CKKJI", cookie.value()); + } + + @Test + public void testDecodingLongDates() { + Calendar cookieDate = Calendar.getInstance(TimeZone.getTimeZone("UTC")); + cookieDate.set(9999, Calendar.DECEMBER, 31, 23, 59, 59); + long expectedMaxAge = (cookieDate.getTimeInMillis() - System + .currentTimeMillis()) / 1000; + + String source = "Format=EU; expires=Fri, 31-Dec-9999 23:59:59 GMT; path=/"; + + Cookie cookie = ClientCookieDecoder.STRICT.decode(source); + + assertTrue(Math.abs(expectedMaxAge - cookie.maxAge()) < 2); + } + + @Test + public void testDecodingValueWithCommaFails() { + String source = "UserCookie=timeZoneName=(GMT+04:00) Moscow, St. Petersburg, Volgograd&promocode=®ion=BE;" + + " expires=Sat, 01-Dec-2012 10:53:31 GMT; path=/"; + + Cookie cookie = ClientCookieDecoder.STRICT.decode(source); + + assertNull(cookie); + } + + @Test + public void testDecodingWeirdNames1() { + String src = "path=; expires=Mon, 01-Jan-1990 00:00:00 GMT; path=/; domain=.www.google.com"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(src); + assertEquals("path", cookie.name()); + assertEquals("", cookie.value()); + assertEquals("/", cookie.path()); + } + + @Test + public void testDecodingWeirdNames2() { + String src = "HTTPOnly="; + Cookie cookie = ClientCookieDecoder.STRICT.decode(src); + assertEquals("HTTPOnly", cookie.name()); + assertEquals("", cookie.value()); + } + + @Test + public void testDecodingValuesWithCommasAndEqualsFails() { + String src = "A=v=1&lg=en-US,it-IT,it&intl=it&np=1;T=z=E"; + Cookie cookie = ClientCookieDecoder.STRICT.decode(src); + assertNull(cookie); + } + + @Test + public void testDecodingInvalidValuesWithCommaAtStart() { + 
assertNull(ClientCookieDecoder.STRICT.decode(",")); + assertNull(ClientCookieDecoder.STRICT.decode(",a")); + assertNull(ClientCookieDecoder.STRICT.decode(",a=a")); + } + + @Test + public void testDecodingLongValue() { + String longValue = + "b___$Q__$ha______" + + "%=J^wI__3iD____$=HbQW__3iF____#=J^wI__3iH____%=J^wI__3iM____%=J^wI__3iS____" + + "#=J^wI__3iU____%=J^wI__3iZ____#=J^wI__3i]____%=J^wI__3ig____%=J^wI__3ij____" + + "%=J^wI__3ik____#=J^wI__3il____$=HbQW__3in____%=J^wI__3ip____$=HbQW__3iq____" + + "$=HbQW__3it____%=J^wI__3ix____#=J^wI__3j_____$=HbQW__3j%____$=HbQW__3j'____" + + "%=J^wI__3j(____%=J^wI__9mJ____'=KqtH__=SE__M____" + + "'=KqtH__s1X____$=MMyc__s1_____#=MN#O__ypn____'=KqtH__ypr____'=KqtH_#%h_____" + + "%=KqtH_#%o_____'=KqtH_#)H6______'=KqtH_#]9R____$=H/Lt_#]I6____#=KqtH_#]Z#____%=KqtH_#^*N____" + + "#=KqtH_#^:m____#=KqtH_#_*_____%=J^wI_#`-7____#=KqtH_#`T>____'=KqtH_#`T?____" + + "'=KqtH_#`TA____'=KqtH_#`TB____'=KqtH_#`TG____'=KqtH_#`TP____#=KqtH_#`U_____" + + "'=KqtH_#`U/____'=KqtH_#`U0____#=KqtH_#`U9____'=KqtH_#aEQ____%=KqtH_#b<)____" + + "'=KqtH_#c9-____%=KqtH_#dxC____%=KqtH_#dxE____%=KqtH_#ev$____'=KqtH_#fBi____" + + "#=KqtH_#fBj____'=KqtH_#fG)____'=KqtH_#fG+____'=KqtH_#g*B____'=KqtH_$>hD____+=J^x0_$?lW____'=KqtH_$?ll____'=KqtH_$?lm____" + + "%=KqtH_$?mi____'=KqtH_$?mx____'=KqtH_$D7]____#=J_#p_$D@T____#=J_#p_$V cookies = ServerCookieDecoder.STRICT.decode(cookieString); + assertEquals(1, cookies.size()); + Cookie cookie = cookies.iterator().next(); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + } + + @Test + public void testDecodingMultipleCookies() { + String c1 = "myCookie=myValue;"; + String c2 = "myCookie2=myValue2;"; + String c3 = "myCookie3=myValue3;"; + + Set cookies = ServerCookieDecoder.STRICT.decode(c1 + c2 + c3); + assertEquals(3, cookies.size()); + Iterator it = cookies.iterator(); + Cookie cookie = it.next(); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + cookie = it.next(); + 
assertNotNull(cookie); + assertEquals("myValue2", cookie.value()); + cookie = it.next(); + assertNotNull(cookie); + assertEquals("myValue3", cookie.value()); + } + + @Test + public void testDecodingAllMultipleCookies() { + String c1 = "myCookie=myValue;"; + String c2 = "myCookie=myValue2;"; + String c3 = "myCookie=myValue3;"; + + List cookies = ServerCookieDecoder.STRICT.decodeAll(c1 + c2 + c3); + assertEquals(3, cookies.size()); + Iterator it = cookies.iterator(); + Cookie cookie = it.next(); + assertNotNull(cookie); + assertEquals("myValue", cookie.value()); + cookie = it.next(); + assertNotNull(cookie); + assertEquals("myValue2", cookie.value()); + cookie = it.next(); + assertNotNull(cookie); + assertEquals("myValue3", cookie.value()); + } + + @Test + public void testDecodingGoogleAnalyticsCookie() { + String source = + "ARPT=LWUKQPSWRTUN04CKKJI; " + + "kw-2E343B92-B097-442c-BFA5-BE371E0325A2=unfinished_furniture; " + + "__utma=48461872.1094088325.1258140131.1258140131.1258140131.1; " + + "__utmb=48461872.13.10.1258140131; __utmc=48461872; " + + "__utmz=48461872.1258140131.1.1.utmcsr=overstock.com|utmccn=(referral)|" + + "utmcmd=referral|utmcct=/Home-Garden/Furniture/Clearance/clearance/32/dept.html"; + Set cookies = ServerCookieDecoder.STRICT.decode(source); + Iterator it = cookies.iterator(); + Cookie c; + + c = it.next(); + assertEquals("ARPT", c.name()); + assertEquals("LWUKQPSWRTUN04CKKJI", c.value()); + + c = it.next(); + assertEquals("__utma", c.name()); + assertEquals("48461872.1094088325.1258140131.1258140131.1258140131.1", c.value()); + + c = it.next(); + assertEquals("__utmb", c.name()); + assertEquals("48461872.13.10.1258140131", c.value()); + + c = it.next(); + assertEquals("__utmc", c.name()); + assertEquals("48461872", c.value()); + + c = it.next(); + assertEquals("__utmz", c.name()); + assertEquals("48461872.1258140131.1.1.utmcsr=overstock.com|" + + 
"utmccn=(referral)|utmcmd=referral|utmcct=/Home-Garden/Furniture/Clearance/clearance/32/dept.html", + c.value()); + + c = it.next(); + assertEquals("kw-2E343B92-B097-442c-BFA5-BE371E0325A2", c.name()); + assertEquals("unfinished_furniture", c.value()); + + assertFalse(it.hasNext()); + } + + @Test + public void testDecodingLongValue() { + String longValue = + "b___$Q__$ha______" + + "%=J^wI__3iD____$=HbQW__3iF____#=J^wI__3iH____%=J^wI__3iM____%=J^wI__3iS____" + + "#=J^wI__3iU____%=J^wI__3iZ____#=J^wI__3i]____%=J^wI__3ig____%=J^wI__3ij____" + + "%=J^wI__3ik____#=J^wI__3il____$=HbQW__3in____%=J^wI__3ip____$=HbQW__3iq____" + + "$=HbQW__3it____%=J^wI__3ix____#=J^wI__3j_____$=HbQW__3j%____$=HbQW__3j'____" + + "%=J^wI__3j(____%=J^wI__9mJ____'=KqtH__=SE__M____" + + "'=KqtH__s1X____$=MMyc__s1_____#=MN#O__ypn____'=KqtH__ypr____'=KqtH_#%h_____" + + "%=KqtH_#%o_____'=KqtH_#)H6______'=KqtH_#]9R____$=H/Lt_#]I6____#=KqtH_#]Z#____%=KqtH_#^*N____" + + "#=KqtH_#^:m____#=KqtH_#_*_____%=J^wI_#`-7____#=KqtH_#`T>____'=KqtH_#`T?____" + + "'=KqtH_#`TA____'=KqtH_#`TB____'=KqtH_#`TG____'=KqtH_#`TP____#=KqtH_#`U_____" + + "'=KqtH_#`U/____'=KqtH_#`U0____#=KqtH_#`U9____'=KqtH_#aEQ____%=KqtH_#b<)____" + + "'=KqtH_#c9-____%=KqtH_#dxC____%=KqtH_#dxE____%=KqtH_#ev$____'=KqtH_#fBi____" + + "#=KqtH_#fBj____'=KqtH_#fG)____'=KqtH_#fG+____'=KqtH_#g*B____'=KqtH_$>hD____+=J^x0_$?lW____'=KqtH_$?ll____'=KqtH_$?lm____" + + "%=KqtH_$?mi____'=KqtH_$?mx____'=KqtH_$D7]____#=J_#p_$D@T____#=J_#p_$V cookies = ServerCookieDecoder.STRICT.decode("bh=\"" + longValue + "\";"); + assertEquals(1, cookies.size()); + Cookie c = cookies.iterator().next(); + assertEquals("bh", c.name()); + assertEquals(longValue, c.value()); + } + + @Test + public void testDecodingOldRFC2965Cookies() { + String source = "$Version=\"1\"; " + + "Part_Number1=\"Riding_Rocket_0023\"; $Path=\"/acme/ammo\"; " + + "Part_Number2=\"Rocket_Launcher_0001\"; $Path=\"/acme\""; + + Set cookies = ServerCookieDecoder.STRICT.decode(source); + Iterator it = 
cookies.iterator(); + Cookie c; + + c = it.next(); + assertEquals("Part_Number1", c.name()); + assertEquals("Riding_Rocket_0023", c.value()); + + c = it.next(); + assertEquals("Part_Number2", c.name()); + assertEquals("Rocket_Launcher_0001", c.value()); + + assertFalse(it.hasNext()); + } + + @Test + public void testRejectCookieValueWithSemicolon() { + Set cookies = ServerCookieDecoder.STRICT.decode("name=\"foo;bar\";"); + assertTrue(cookies.isEmpty()); + } + + @Test + public void testCaseSensitiveNames() { + Set cookies = ServerCookieDecoder.STRICT.decode("session_id=a; Session_id=b;"); + Iterator it = cookies.iterator(); + Cookie c; + + c = it.next(); + assertEquals("Session_id", c.name()); + assertEquals("b", c.value()); + + c = it.next(); + assertEquals("session_id", c.name()); + assertEquals("a", c.value()); + + assertFalse(it.hasNext()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java new file mode 100644 index 0000000..6c86f65 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cookie/ServerCookieEncoderTest.java @@ -0,0 +1,160 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.cookie; + +import static org.hamcrest.CoreMatchers.containsString; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.handler.codec.DateFormatter; + +import java.text.ParseException; +import java.util.ArrayList; +import java.util.Date; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import io.netty.handler.codec.http.cookie.CookieHeaderNames.SameSite; +import org.junit.jupiter.api.Test; + +public class ServerCookieEncoderTest { + + @Test + public void testEncodingSingleCookieV0() throws ParseException { + + int maxAge = 50; + + String result = "myCookie=myValue; Max-Age=50; Expires=(.+?); Path=/apathsomewhere;" + + " Domain=.adomainsomewhere; Secure; SameSite=Lax"; + DefaultCookie cookie = new DefaultCookie("myCookie", "myValue"); + cookie.setDomain(".adomainsomewhere"); + cookie.setMaxAge(maxAge); + cookie.setPath("/apathsomewhere"); + cookie.setSecure(true); + cookie.setSameSite(SameSite.Lax); + + String encodedCookie = ServerCookieEncoder.STRICT.encode(cookie); + + Matcher matcher = Pattern.compile(result).matcher(encodedCookie); + assertTrue(matcher.find()); + Date expiresDate = DateFormatter.parseHttpDate(matcher.group(1)); + long diff = (expiresDate.getTime() - System.currentTimeMillis()) / 1000; + // 2 secs should be fine + assertTrue(Math.abs(diff - maxAge) <= 2); + } + + @Test + public void testEncodingWithNoCookies() { + String encodedCookie1 = ClientCookieEncoder.STRICT.encode(); + List encodedCookie2 = ServerCookieEncoder.STRICT.encode(); + assertNull(encodedCookie1); + assertNotNull(encodedCookie2); + assertTrue(encodedCookie2.isEmpty()); + } + + @Test + 
public void testEncodingMultipleCookiesStrict() { + List result = new ArrayList(); + result.add("cookie2=value2"); + result.add("cookie1=value3"); + Cookie cookie1 = new DefaultCookie("cookie1", "value1"); + Cookie cookie2 = new DefaultCookie("cookie2", "value2"); + Cookie cookie3 = new DefaultCookie("cookie1", "value3"); + List encodedCookies = ServerCookieEncoder.STRICT.encode(cookie1, cookie2, cookie3); + assertEquals(result, encodedCookies); + } + + @Test + public void illegalCharInCookieNameMakesStrictEncoderThrowsException() { + Set illegalChars = new HashSet(); + // CTLs + for (int i = 0x00; i <= 0x1F; i++) { + illegalChars.add((char) i); + } + illegalChars.add((char) 0x7F); + // separators + for (char c : new char[] { '(', ')', '<', '>', '@', ',', ';', ':', '\\', '"', '/', '[', ']', + '?', '=', '{', '}', ' ', '\t' }) { + illegalChars.add(c); + } + + int exceptions = 0; + + for (char c : illegalChars) { + try { + ServerCookieEncoder.STRICT.encode(new DefaultCookie("foo" + c + "bar", "value")); + } catch (IllegalArgumentException e) { + exceptions++; + } + } + + assertEquals(illegalChars.size(), exceptions); + } + + @Test + public void illegalCharInCookieValueMakesStrictEncoderThrowsException() { + Set illegalChars = new HashSet(); + // CTLs + for (int i = 0x00; i <= 0x1F; i++) { + illegalChars.add((char) i); + } + illegalChars.add((char) 0x7F); + // whitespace, DQUOTE, comma, semicolon, and backslash + for (char c : new char[] { ' ', '"', ',', ';', '\\' }) { + illegalChars.add(c); + } + + int exceptions = 0; + + for (char c : illegalChars) { + try { + ServerCookieEncoder.STRICT.encode(new DefaultCookie("name", "value" + c)); + } catch (IllegalArgumentException e) { + exceptions++; + } + } + + assertEquals(illegalChars.size(), exceptions); + } + + @Test + public void illegalCharInWrappedValueAppearsInException() { + try { + ServerCookieEncoder.STRICT.encode(new DefaultCookie("name", "\"value,\"")); + } catch (IllegalArgumentException e) { + 
assertThat(e.getMessage().toLowerCase(), containsString("cookie value contains an invalid char: ,")); + } + } + + @Test + public void testEncodingMultipleCookiesLax() { + List result = new ArrayList(); + result.add("cookie1=value1"); + result.add("cookie2=value2"); + result.add("cookie1=value3"); + Cookie cookie1 = new DefaultCookie("cookie1", "value1"); + Cookie cookie2 = new DefaultCookie("cookie2", "value2"); + Cookie cookie3 = new DefaultCookie("cookie1", "value3"); + List encodedCookies = ServerCookieEncoder.LAX.encode(cookie1, cookie2, cookie3); + assertEquals(result, encodedCookies); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java new file mode 100644 index 0000000..6b9f852 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsConfigTest.java @@ -0,0 +1,146 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http.cors; + +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forAnyOrigin; +import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigin; +import static io.netty.handler.codec.http.cors.CorsConfigBuilder.forOrigins; +import static org.hamcrest.CoreMatchers.equalTo; +import static org.hamcrest.CoreMatchers.hasItems; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.notNullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class CorsConfigTest { + + @Test + public void disabled() { + final CorsConfig cors = forAnyOrigin().disable().build(); + assertThat(cors.isCorsSupportEnabled(), is(false)); + } + + @Test + public void anyOrigin() { + final CorsConfig cors = forAnyOrigin().build(); + assertThat(cors.isAnyOriginSupported(), is(true)); + assertThat(cors.origin(), is("*")); + assertThat(cors.origins().isEmpty(), is(true)); + } + + @Test + public void wildcardOrigin() { + final CorsConfig cors = forOrigin("*").build(); + assertThat(cors.isAnyOriginSupported(), is(true)); + assertThat(cors.origin(), equalTo("*")); + assertThat(cors.origins().isEmpty(), is(true)); + } + + @Test + public void origin() { + final CorsConfig cors = forOrigin("http://localhost:7888").build(); + assertThat(cors.origin(), is(equalTo("http://localhost:7888"))); + assertThat(cors.isAnyOriginSupported(), is(false)); + } + + @Test + public void origins() { + final String[] origins = {"http://localhost:7888", "https://localhost:7888"}; + final CorsConfig cors = 
forOrigins(origins).build(); + assertThat(cors.origins(), hasItems(origins)); + assertThat(cors.isAnyOriginSupported(), is(false)); + } + + @Test + public void exposeHeaders() { + final CorsConfig cors = forAnyOrigin().exposeHeaders("custom-header1", "custom-header2").build(); + assertThat(cors.exposedHeaders(), hasItems("custom-header1", "custom-header2")); + } + + @Test + public void allowCredentials() { + final CorsConfig cors = forAnyOrigin().allowCredentials().build(); + assertThat(cors.isCredentialsAllowed(), is(true)); + } + + @Test + public void maxAge() { + final CorsConfig cors = forAnyOrigin().maxAge(3000).build(); + assertThat(cors.maxAge(), is(3000L)); + } + + @Test + public void requestMethods() { + final CorsConfig cors = forAnyOrigin().allowedRequestMethods(HttpMethod.POST, HttpMethod.GET).build(); + assertThat(cors.allowedRequestMethods(), hasItems(HttpMethod.POST, HttpMethod.GET)); + } + + @Test + public void requestHeaders() { + final CorsConfig cors = forAnyOrigin().allowedRequestHeaders("preflight-header1", "preflight-header2").build(); + assertThat(cors.allowedRequestHeaders(), hasItems("preflight-header1", "preflight-header2")); + } + + @Test + public void preflightResponseHeadersSingleValue() { + final CorsConfig cors = forAnyOrigin().preflightResponseHeader("SingleValue", "value").build(); + assertThat(cors.preflightResponseHeaders().get(of("SingleValue")), equalTo("value")); + } + + @Test + public void preflightResponseHeadersMultipleValues() { + final CorsConfig cors = forAnyOrigin().preflightResponseHeader("MultipleValues", "value1", "value2").build(); + assertThat(cors.preflightResponseHeaders().getAll(of("MultipleValues")), hasItems("value1", "value2")); + } + + @Test + public void defaultPreflightResponseHeaders() { + final CorsConfig cors = forAnyOrigin().build(); + assertThat(cors.preflightResponseHeaders().get(HttpHeaderNames.DATE), is(notNullValue())); + 
assertThat(cors.preflightResponseHeaders().get(HttpHeaderNames.CONTENT_LENGTH), is("0")); + } + + @Test + public void emptyPreflightResponseHeaders() { + final CorsConfig cors = forAnyOrigin().noPreflightResponseHeaders().build(); + assertThat(cors.preflightResponseHeaders(), equalTo((HttpHeaders) EmptyHttpHeaders.INSTANCE)); + } + + @Test + public void shouldThrowIfValueIsNull() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + forOrigin("*").preflightResponseHeader("HeaderName", new Object[]{null}).build(); + } + }); + } + + @Test + public void shortCircuit() { + final CorsConfig cors = forOrigin("http://localhost:8080").shortCircuit().build(); + assertThat(cors.isShortCircuit(), is(true)); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java new file mode 100644 index 0000000..e484bff --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/cors/CorsHandlerTest.java @@ -0,0 +1,591 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http.cors; + +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeadersFactory; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; +import org.hamcrest.core.IsEqual; +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; + +import static io.netty.handler.codec.http.HttpHeaderNames.*; +import static io.netty.handler.codec.http.HttpHeaderValues.CLOSE; +import static io.netty.handler.codec.http.HttpHeaderValues.KEEP_ALIVE; +import static io.netty.handler.codec.http.HttpHeadersTestUtils.of; +import static io.netty.handler.codec.http.HttpMethod.*; +import static io.netty.handler.codec.http.HttpResponseStatus.FORBIDDEN; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static io.netty.handler.codec.http.cors.CorsConfigBuilder.*; +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.core.IsEqual.equalTo; + +public class CorsHandlerTest { + + @Test + public void nonCorsRequest() { + final HttpResponse response = simpleRequest(forAnyOrigin().build(), null); + assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_ORIGIN), is(false)); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestWithAnyOrigin() { + 
final HttpResponse response = simpleRequest(forAnyOrigin().build(), "http://localhost:7777"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("*")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), is(nullValue())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestWithNullOrigin() { + final HttpResponse response = simpleRequest(forOrigin("http://test.com").allowNullOrigin() + .allowCredentials() + .build(), "null"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("null")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(equalTo("true"))); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), is(nullValue())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestWithOrigin() { + final String origin = "http://localhost:8888"; + final HttpResponse response = simpleRequest(forOrigin(origin).build(), origin); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(origin)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), is(nullValue())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestWithOrigins() { + final String origin1 = "http://localhost:8888"; + final String origin2 = "https://localhost:8888"; + final String[] origins = {origin1, origin2}; + final HttpResponse response1 = simpleRequest(forOrigins(origins).build(), origin1); + assertThat(response1.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(origin1)); + assertThat(response1.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), is(nullValue())); + assertThat(ReferenceCountUtil.release(response1), is(true)); + + final HttpResponse response2 = simpleRequest(forOrigins(origins).build(), origin2); + assertThat(response2.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(origin2)); + 
assertThat(response2.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), is(nullValue())); + assertThat(ReferenceCountUtil.release(response2), is(true)); + } + + @Test + public void simpleRequestWithNoMatchingOrigin() { + final String origin = "http://localhost:8888"; + final HttpResponse response = simpleRequest( + forOrigins("https://localhost:8888").build(), origin); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), is(nullValue())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightDeleteRequestWithCustomHeaders() { + final CorsConfig config = forOrigin("http://localhost:8888") + .allowedRequestMethods(GET, DELETE) + .build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), containsString("GET")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), containsString("DELETE")); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightGetRequestWithCustomHeaders() { + final CorsConfig config = forOrigin("http://localhost:8888") + .allowedRequestMethods(OPTIONS, GET, DELETE) + .allowedRequestHeaders("content-type", "xheader1") + .build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is("http://localhost:8888")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), containsString("OPTIONS")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), containsString("GET")); + 
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), containsString("content-type")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_HEADERS), containsString("xheader1")); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithDefaultHeaders() { + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertThat(response.headers().get(CONTENT_LENGTH), is("0")); + assertThat(response.headers().get(DATE), is(notNullValue())); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithCustomHeader() { + final CorsConfig config = forOrigin("http://localhost:8888") + .preflightResponseHeader("CustomHeader", "somevalue") + .build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertThat(response.headers().get(of("CustomHeader")), equalTo("somevalue")); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(response.headers().get(CONTENT_LENGTH), is("0")); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithUnauthorizedOrigin() { + final String origin = "http://host"; + final CorsConfig config = forOrigin("http://localhost").build(); + final HttpResponse response = preflightRequest(config, origin, "xheader1"); + assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_ORIGIN), is(false)); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithCustomHeaders() { + final String headerName = "CustomHeader"; + final String value1 = "value1"; + final String value2 = "value2"; + 
final CorsConfig config = forOrigin("http://localhost:8888") + .preflightResponseHeader(headerName, value1, value2) + .build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertValues(response, headerName, value1, value2); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithCustomHeadersIterable() { + final String headerName = "CustomHeader"; + final String value1 = "value1"; + final String value2 = "value2"; + final CorsConfig config = forOrigin("http://localhost:8888") + .preflightResponseHeader(headerName, Arrays.asList(value1, value2)) + .build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertValues(response, headerName, value1, value2); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithValueGenerator() { + final CorsConfig config = forOrigin("http://localhost:8888") + .preflightResponseHeader("GenHeader", new Callable() { + @Override + public String call() throws Exception { + return "generatedValue"; + } + }).build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", "content-type, xheader1"); + assertThat(response.headers().get(of("GenHeader")), equalTo("generatedValue")); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestWithNullOrigin() { + final String origin = "null"; + final CorsConfig config = forOrigin(origin) + .allowNullOrigin() + .allowCredentials() + .build(); + final HttpResponse response = preflightRequest(config, origin, "content-type, xheader1"); + 
assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(equalTo("null"))); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(equalTo("true"))); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestAllowCredentials() { + final String origin = "null"; + final CorsConfig config = forOrigin(origin).allowCredentials().build(); + final HttpResponse response = preflightRequest(config, origin, "content-type, xheader1"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(equalTo("true"))); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void preflightRequestDoNotAllowCredentials() { + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final HttpResponse response = preflightRequest(config, "http://localhost:8888", ""); + // the only valid value for Access-Control-Allow-Credentials is true. + assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false)); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestCustomHeaders() { + final CorsConfig config = forAnyOrigin().exposeHeaders("custom1", "custom2").build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*")); + assertThat(response.headers().get(ACCESS_CONTROL_EXPOSE_HEADERS), containsString("custom1")); + assertThat(response.headers().get(ACCESS_CONTROL_EXPOSE_HEADERS), containsString("custom2")); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestAllowCredentials() { + final CorsConfig config = forAnyOrigin().allowCredentials().build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + 
assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestDoNotAllowCredentials() { + final CorsConfig config = forAnyOrigin().build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.headers().contains(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(false)); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void anyOriginAndAllowCredentialsShouldEchoRequestOrigin() { + final CorsConfig config = forAnyOrigin().allowCredentials().build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_CREDENTIALS), equalTo("true")); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("http://localhost:7777")); + assertThat(response.headers().get(VARY), equalTo(ORIGIN.toString())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestExposeHeaders() { + final CorsConfig config = forAnyOrigin().exposeHeaders("one", "two").build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.headers().get(ACCESS_CONTROL_EXPOSE_HEADERS), containsString("one")); + assertThat(response.headers().get(ACCESS_CONTROL_EXPOSE_HEADERS), containsString("two")); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestShortCircuit() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final HttpResponse response = simpleRequest(config, "http://localhost:7777"); + assertThat(response.status(), is(FORBIDDEN)); + assertThat(response.headers().get(CONTENT_LENGTH), is("0")); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestNoShortCircuit() { + final CorsConfig config = forOrigin("http://localhost:8080").build(); + final HttpResponse response = 
simpleRequest(config, "http://localhost:7777"); + assertThat(response.status(), is(OK)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void shortCircuitNonCorsRequest() { + final CorsConfig config = forOrigin("https://localhost").shortCircuit().build(); + final HttpResponse response = simpleRequest(config, null); + assertThat(response.status(), is(OK)); + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_ORIGIN), is(nullValue())); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void shortCircuitWithConnectionKeepAliveShouldStayOpen() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, "http://localhost:8888"); + request.headers().set(CONNECTION, KEEP_ALIVE); + + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(FORBIDDEN)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void shortCircuitWithoutConnectionShouldStayOpen() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, "http://localhost:8888"); + + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + 
assertThat(response.status(), is(FORBIDDEN)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void shortCircuitWithConnectionCloseShouldClose() { + final CorsConfig config = forOrigin("http://localhost:8080").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, "http://localhost:8888"); + request.headers().set(CONNECTION, CLOSE); + + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(false)); + + assertThat(channel.isOpen(), is(false)); + assertThat(response.status(), is(FORBIDDEN)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestShouldReleaseRequest() { + final CorsConfig config = forOrigin("http://localhost:8888") + .preflightResponseHeader("CustomHeader", Arrays.asList("value1", "value2")) + .build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "content-type, xheader1", null); + assertThat(channel.writeInbound(request), is(false)); + assertThat(request.refCnt(), is(0)); + assertThat(ReferenceCountUtil.release(channel.readOutbound()), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestWithConnectionKeepAliveShouldStayOpen() throws Exception { + + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", KEEP_ALIVE); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = 
channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(OK)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestWithoutConnectionShouldStayOpen() throws Exception { + + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(true)); + + assertThat(channel.isOpen(), is(true)); + assertThat(response.status(), is(OK)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void preflightRequestWithConnectionCloseShouldClose() throws Exception { + + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", CLOSE); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + assertThat(HttpUtil.isKeepAlive(response), is(false)); + + assertThat(channel.isOpen(), is(false)); + assertThat(response.status(), is(OK)); + assertThat(ReferenceCountUtil.release(response), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void forbiddenShouldReleaseRequest() { + final CorsConfig config = forOrigin("https://localhost").shortCircuit().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler()); + final FullHttpRequest request = createHttpRequest(GET); + request.headers().set(ORIGIN, 
"http://localhost:8888"); + assertThat(channel.writeInbound(request), is(false)); + assertThat(request.refCnt(), is(0)); + assertThat(ReferenceCountUtil.release(channel.readOutbound()), is(true)); + assertThat(channel.finish(), is(false)); + } + + @Test + public void differentConfigsPerOrigin() { + String host1 = "http://host1:80"; + String host2 = "http://host2"; + CorsConfig rule1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).build(); + CorsConfig rule2 = forOrigin(host2).allowedRequestMethods(HttpMethod.GET, HttpMethod.POST) + .allowCredentials().build(); + + List<CorsConfig> corsConfigs = Arrays.asList(rule1, rule2); + + final HttpResponse preFlightHost1 = preflightRequest(corsConfigs, host1, "", false); + assertThat(preFlightHost1.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET")); + assertThat(preFlightHost1.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), is(nullValue())); + + final HttpResponse preFlightHost2 = preflightRequest(corsConfigs, host2, "", false); + assertValues(preFlightHost2, ACCESS_CONTROL_ALLOW_METHODS.toString(), "GET", "POST"); + assertThat(preFlightHost2.headers().getAsString(ACCESS_CONTROL_ALLOW_CREDENTIALS), IsEqual.equalTo("true")); + } + + @Test + public void specificConfigPrecedenceOverGeneric() { + String host1 = "http://host1"; + String host2 = "http://host2"; + + CorsConfig forHost1 = forOrigin(host1).allowedRequestMethods(HttpMethod.GET).maxAge(3600L).build(); + CorsConfig allowAll = forAnyOrigin().allowedRequestMethods(HttpMethod.POST, HttpMethod.GET, HttpMethod.OPTIONS) + .maxAge(1800).build(); + + List<CorsConfig> rules = Arrays.asList(forHost1, allowAll); + + final HttpResponse host1Response = preflightRequest(rules, host1, "", false); + assertThat(host1Response.headers().get(ACCESS_CONTROL_ALLOW_METHODS), is("GET")); + assertThat(host1Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("3600")); + + final HttpResponse host2Response = preflightRequest(rules, host2, "", false); + assertValues(host2Response, 
ACCESS_CONTROL_ALLOW_METHODS.toString(), "POST", "GET", "OPTIONS"); + assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_ALLOW_ORIGIN), equalTo("*")); + assertThat(host2Response.headers().getAsString(ACCESS_CONTROL_MAX_AGE), equalTo("1800")); + } + + @Test + public void simpleRequestAllowPrivateNetwork() { + final CorsConfig config = forOrigin("http://localhost:8888").allowPrivateNetwork().build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null); + request.headers().set(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK, "true"); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK), equalTo("true")); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + @Test + public void simpleRequestDoNotAllowPrivateNetwork() { + final CorsConfig config = forOrigin("http://localhost:8888").build(); + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config)); + final FullHttpRequest request = optionsRequest("http://localhost:8888", "", null); + request.headers().set(ACCESS_CONTROL_REQUEST_PRIVATE_NETWORK, "true"); + assertThat(channel.writeInbound(request), is(false)); + final HttpResponse response = channel.readOutbound(); + + assertThat(response.headers().get(ACCESS_CONTROL_ALLOW_PRIVATE_NETWORK), equalTo("false")); + assertThat(ReferenceCountUtil.release(response), is(true)); + } + + private static HttpResponse simpleRequest(final CorsConfig config, final String origin) { + return simpleRequest(config, origin, null); + } + + private static HttpResponse simpleRequest(final CorsConfig config, + final String origin, + final String requestHeaders) { + return simpleRequest(config, origin, requestHeaders, GET); + } + + private static HttpResponse simpleRequest(final CorsConfig config, + final String 
origin, + final String requestHeaders, + final HttpMethod method) { + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(config), new EchoHandler()); + final FullHttpRequest httpRequest = createHttpRequest(method); + if (origin != null) { + httpRequest.headers().set(ORIGIN, origin); + } + if (requestHeaders != null) { + httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders); + } + assertThat(channel.writeInbound(httpRequest), is(false)); + HttpResponse response = channel.readOutbound(); + assertThat(channel.finish(), is(false)); + return response; + } + + private static HttpResponse preflightRequest(final CorsConfig config, + final String origin, + final String requestHeaders) { + return preflightRequest(Collections.singletonList(config), origin, requestHeaders, config.isShortCircuit()); + } + + private static HttpResponse preflightRequest(final List<CorsConfig> configs, + final String origin, + final String requestHeaders, + final boolean isShortCircuit) { + final EmbeddedChannel channel = new EmbeddedChannel(new CorsHandler(configs, isShortCircuit)); + assertThat(channel.writeInbound(optionsRequest(origin, requestHeaders, null)), is(false)); + HttpResponse response = channel.readOutbound(); + assertThat(channel.finish(), is(false)); + return response; + } + + private static FullHttpRequest optionsRequest(final String origin, + final String requestHeaders, + final AsciiString connection) { + final FullHttpRequest httpRequest = createHttpRequest(OPTIONS); + httpRequest.headers().set(ORIGIN, origin); + httpRequest.headers().set(ACCESS_CONTROL_REQUEST_METHOD, httpRequest.method().toString()); + httpRequest.headers().set(ACCESS_CONTROL_REQUEST_HEADERS, requestHeaders); + if (connection != null) { + httpRequest.headers().set(CONNECTION, connection); + } + + return httpRequest; + } + + private static FullHttpRequest createHttpRequest(HttpMethod method) { + return new DefaultFullHttpRequest(HTTP_1_1, method, "/info"); + } + + private static class 
EchoHandler extends SimpleChannelInboundHandler { + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.writeAndFlush(new DefaultFullHttpResponse(HTTP_1_1, OK, Unpooled.buffer(0), + DefaultHttpHeadersFactory.headersFactory().withCombiningHeaders(true), + DefaultHttpHeadersFactory.trailersFactory().withCombiningHeaders(true))); + } + } + + private static void assertValues(final HttpResponse response, final String headerName, final String... values) { + final String header = response.headers().get(of(headerName)); + for (String value : values) { + assertThat(header, containsString(value)); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpDataTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpDataTest.java new file mode 100644 index 0000000..a194284 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractDiskHttpDataTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.FileOutputStream; +import java.nio.charset.Charset; +import java.util.Arrays; +import java.util.UUID; + +import static io.netty.util.CharsetUtil.UTF_8; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +/** + * {@link AbstractDiskHttpData} test cases + */ +public class AbstractDiskHttpDataTest { + + @Test + public void testGetChunk() throws Exception { + TestHttpData test = new TestHttpData("test", UTF_8, 0); + try { + File tmpFile = PlatformDependent.createTempFile(UUID.randomUUID().toString(), ".tmp", null); + tmpFile.deleteOnExit(); + FileOutputStream fos = new FileOutputStream(tmpFile); + byte[] bytes = new byte[4096]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + try { + fos.write(bytes); + fos.flush(); + } finally { + fos.close(); + } + test.setContent(tmpFile); + ByteBuf buf1 = test.getChunk(1024); + assertEquals(buf1.readerIndex(), 0); + assertEquals(buf1.writerIndex(), 1024); + ByteBuf buf2 = test.getChunk(1024); + assertEquals(buf2.readerIndex(), 0); + assertEquals(buf2.writerIndex(), 1024); + assertFalse(Arrays.equals(ByteBufUtil.getBytes(buf1), ByteBufUtil.getBytes(buf2)), + "Arrays should not be equal"); + } finally { + test.delete(); + } + } + + private static final class TestHttpData extends AbstractDiskHttpData { + + private TestHttpData(String name, Charset charset, long size) { + super(name, charset, size); + } + + @Override + protected String getDiskFilename() { + return null; + } + + @Override + protected String getPrefix() { + return null; + } + + @Override + protected String getBaseDirectory() { + return null; + } + + @Override + protected String getPostfix() { + return null; + } + + @Override + protected 
boolean deleteOnExit() { + return false; + } + + @Override + public HttpData copy() { + return null; + } + + @Override + public HttpData duplicate() { + return null; + } + + @Override + public HttpData retainedDuplicate() { + return null; + } + + @Override + public HttpData replace(ByteBuf content) { + return null; + } + + @Override + public HttpDataType getHttpDataType() { + return null; + } + + @Override + public int compareTo(InterfaceHttpData o) { + return 0; + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpDataTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpDataTest.java new file mode 100644 index 0000000..3e75aeb --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/AbstractMemoryHttpDataTest.java @@ -0,0 +1,212 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.PlatformDependent; + +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.nio.charset.Charset; +import java.security.SecureRandom; +import java.util.Arrays; +import java.util.Random; +import java.util.UUID; + +import static io.netty.util.CharsetUtil.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** {@link AbstractMemoryHttpData} test cases. */ +public class AbstractMemoryHttpDataTest { + + @Test + public void testSetContentFromFile() throws Exception { + TestHttpData test = new TestHttpData("test", UTF_8, 0); + try { + File tmpFile = PlatformDependent.createTempFile(UUID.randomUUID().toString(), ".tmp", null); + tmpFile.deleteOnExit(); + FileOutputStream fos = new FileOutputStream(tmpFile); + byte[] bytes = new byte[4096]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + try { + fos.write(bytes); + fos.flush(); + } finally { + fos.close(); + } + test.setContent(tmpFile); + ByteBuf buf = test.getByteBuf(); + assertEquals(buf.readerIndex(), 0); + assertEquals(buf.writerIndex(), bytes.length); + assertArrayEquals(bytes, test.get()); + assertArrayEquals(bytes, ByteBufUtil.getBytes(buf)); + } finally { + //release the ByteBuf + test.delete(); + } + } + + @Test + public void testRenameTo() throws Exception { + TestHttpData test = new TestHttpData("test", UTF_8, 0); + try { + File tmpFile = PlatformDependent.createTempFile(UUID.randomUUID().toString(), ".tmp", null); + tmpFile.deleteOnExit(); + final 
int totalByteCount = 4096; + byte[] bytes = new byte[totalByteCount]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + ByteBuf content = Unpooled.wrappedBuffer(bytes); + test.setContent(content); + boolean succ = test.renameTo(tmpFile); + assertTrue(succ); + FileInputStream fis = new FileInputStream(tmpFile); + try { + byte[] buf = new byte[totalByteCount]; + int count = 0; + int offset = 0; + int size = totalByteCount; + while ((count = fis.read(buf, offset, size)) > 0) { + offset += count; + size -= count; + if (offset >= totalByteCount || size <= 0) { + break; + } + } + assertArrayEquals(bytes, buf); + assertEquals(0, fis.available()); + } finally { + fis.close(); + } + } finally { + //release the ByteBuf in AbstractMemoryHttpData + test.delete(); + } + } + /** + * Provide content into HTTP data with input stream. + * + * @throws Exception In case of any exception. + */ + @Test + public void testSetContentFromStream() throws Exception { + // definedSize=0 + TestHttpData test = new TestHttpData("test", UTF_8, 0); + String contentStr = "foo_test"; + ByteBuf buf = Unpooled.wrappedBuffer(contentStr.getBytes(UTF_8)); + buf.markReaderIndex(); + ByteBufInputStream is = new ByteBufInputStream(buf); + try { + test.setContent(is); + assertFalse(buf.isReadable()); + assertEquals(test.getString(UTF_8), contentStr); + buf.resetReaderIndex(); + assertTrue(ByteBufUtil.equals(buf, test.getByteBuf())); + } finally { + is.close(); + } + + Random random = new SecureRandom(); + + for (int i = 0; i < 20; i++) { + // Generate input data bytes. + int size = random.nextInt(Short.MAX_VALUE); + byte[] bytes = new byte[size]; + + random.nextBytes(bytes); + + // Generate parsed HTTP data block. + TestHttpData data = new TestHttpData("name", UTF_8, 0); + + data.setContent(new ByteArrayInputStream(bytes)); + + // Validate stored data. 
+ ByteBuf buffer = data.getByteBuf(); + + assertEquals(0, buffer.readerIndex()); + assertEquals(bytes.length, buffer.writerIndex()); + assertArrayEquals(bytes, Arrays.copyOf(buffer.array(), bytes.length)); + assertArrayEquals(bytes, data.get()); + } + } + + /** Memory-based HTTP data implementation for test purposes. */ + private static final class TestHttpData extends AbstractMemoryHttpData { + /** + * Constructs HTTP data for tests. + * + * @param name Name of parsed data block. + * @param charset Used charset for data decoding. + * @param size Expected data block size. + */ + private TestHttpData(String name, Charset charset, long size) { + super(name, charset, size); + } + + @Override + public InterfaceHttpData.HttpDataType getHttpDataType() { + throw reject(); + } + + @Override + public HttpData copy() { + throw reject(); + } + + @Override + public HttpData duplicate() { + throw reject(); + } + + @Override + public HttpData retainedDuplicate() { + throw reject(); + } + + @Override + public HttpData replace(ByteBuf content) { + return null; + } + + @Override + public int compareTo(InterfaceHttpData o) { + throw reject(); + } + + @Override + public int hashCode() { + return super.hashCode(); + } + + @Override + public boolean equals(Object obj) { + return super.equals(obj); + } + + private static UnsupportedOperationException reject() { + throw new UnsupportedOperationException("Should never be called."); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactoryTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactoryTest.java new file mode 100644 index 0000000..9e6c2e1 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DefaultHttpDataFactoryTest.java @@ -0,0 +1,164 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 
/**
 * Tests for {@link DefaultHttpDataFactory}. Verifies that the factory tracks
 * created data per request by object identity (not by equals/hashCode — the
 * two static requests below are deliberately equal), and that cleanup releases
 * exactly the tracked items.
 */
public class DefaultHttpDataFactoryTest {
    // req1 equals req2
    private static final HttpRequest req1 = new DefaultHttpRequest(HTTP_1_1, POST, "/form");
    private static final HttpRequest req2 = new DefaultHttpRequest(HTTP_1_1, POST, "/form");

    // Fresh factory per test; see setupFactory/cleanupFactory.
    private DefaultHttpDataFactory factory;

    @BeforeAll
    public static void assertReq1EqualsReq2() {
        // Before doing anything, assert that the requests are equal
        assertEquals(req1.hashCode(), req2.hashCode());
        assertTrue(req1.equals(req2));
    }

    @BeforeEach
    public void setupFactory() {
        factory = new DefaultHttpDataFactory();
    }

    @AfterEach
    public void cleanupFactory() {
        // Release anything a test left behind so tests don't leak buffers.
        factory.cleanAllHttpData();
    }

    /**
     * setBaseDir/setDeleteOnExit must propagate to the disk-backed
     * Attribute and FileUpload instances the factory creates.
     */
    @Test
    public void customBaseDirAndDeleteOnExit() {
        final DefaultHttpDataFactory defaultHttpDataFactory = new DefaultHttpDataFactory(true);
        final String dir = "target/DefaultHttpDataFactoryTest/customBaseDirAndDeleteOnExit";
        defaultHttpDataFactory.setBaseDir(dir);
        defaultHttpDataFactory.setDeleteOnExit(true);
        final Attribute attr = defaultHttpDataFactory.createAttribute(req1, "attribute1");
        final FileUpload fu = defaultHttpDataFactory.createFileUpload(
                req1, "attribute1", "f.txt", "text/plain", null, null, 0);
        assertEquals(dir, DiskAttribute.class.cast(attr).getBaseDirectory());
        assertEquals(dir, DiskFileUpload.class.cast(fu).getBaseDirectory());
        assertTrue(DiskAttribute.class.cast(attr).deleteOnExit());
        assertTrue(DiskFileUpload.class.cast(fu).deleteOnExit());
    }

    /**
     * cleanRequestHttpData(req1) must release only the data created for req1,
     * even though req1 and req2 are equal — cleanup keys on identity.
     */
    @Test
    public void cleanRequestHttpDataShouldIdentifiesRequestsByTheirIdentities() throws Exception {
        // Create some data belonging to req1 and req2
        Attribute attribute1 = factory.createAttribute(req1, "attribute1", "value1");
        Attribute attribute2 = factory.createAttribute(req2, "attribute2", "value2");
        FileUpload file1 = factory.createFileUpload(
                req1, "file1", "file1.txt",
                DEFAULT_TEXT_CONTENT_TYPE, IDENTITY.toString(), UTF_8, 123
        );
        FileUpload file2 = factory.createFileUpload(
                req2, "file2", "file2.txt",
                DEFAULT_TEXT_CONTENT_TYPE, IDENTITY.toString(), UTF_8, 123
        );
        file1.setContent(Unpooled.copiedBuffer("file1 content", UTF_8));
        file2.setContent(Unpooled.copiedBuffer("file2 content", UTF_8));

        // Assert that they are not deleted
        assertNotNull(attribute1.getByteBuf());
        assertNotNull(attribute2.getByteBuf());
        assertNotNull(file1.getByteBuf());
        assertNotNull(file2.getByteBuf());
        assertEquals(1, attribute1.refCnt());
        assertEquals(1, attribute2.refCnt());
        assertEquals(1, file1.refCnt());
        assertEquals(1, file2.refCnt());

        // Clean up by req1
        factory.cleanRequestHttpData(req1);

        // Assert that data belonging to req1 has been cleaned up
        assertNull(attribute1.getByteBuf());
        assertNull(file1.getByteBuf());
        assertEquals(0, attribute1.refCnt());
        assertEquals(0, file1.refCnt());

        // But not req2
        assertNotNull(attribute2.getByteBuf());
        assertNotNull(file2.getByteBuf());
        assertEquals(1, attribute2.refCnt());
        assertEquals(1, file2.refCnt());
    }

    /**
     * removeHttpDataFromClean must remove exactly the passed instance from
     * the cleanup list, even when an equal-but-distinct item is also tracked.
     */
    @Test
    public void removeHttpDataFromCleanShouldIdentifiesDataByTheirIdentities() throws Exception {
        // Create some equal data items belonging to the same request
        Attribute attribute1 = factory.createAttribute(req1, "attribute", "value");
        Attribute attribute2 = factory.createAttribute(req1, "attribute", "value");
        FileUpload file1 = factory.createFileUpload(
                req1, "file", "file.txt",
                DEFAULT_TEXT_CONTENT_TYPE, IDENTITY.toString(), UTF_8, 123
        );
        FileUpload file2 = factory.createFileUpload(
                req1, "file", "file.txt",
                DEFAULT_TEXT_CONTENT_TYPE, IDENTITY.toString(), UTF_8, 123
        );
        file1.setContent(Unpooled.copiedBuffer("file content", UTF_8));
        file2.setContent(Unpooled.copiedBuffer("file content", UTF_8));

        // Before doing anything, assert that the data items are equal
        assertEquals(attribute1.hashCode(), attribute2.hashCode());
        assertTrue(attribute1.equals(attribute2));
        assertEquals(file1.hashCode(), file2.hashCode());
        assertTrue(file1.equals(file2));

        // Remove attribute2 and file2 from being cleaned up by factory
        factory.removeHttpDataFromClean(req1, attribute2);
        factory.removeHttpDataFromClean(req1, file2);

        // Clean up by req1
        factory.cleanRequestHttpData(req1);

        // Assert that attribute1 and file1 have been cleaned up
        assertNull(attribute1.getByteBuf());
        assertNull(file1.getByteBuf());
        assertEquals(0, attribute1.refCnt());
        assertEquals(0, file1.refCnt());

        // But not attribute2 and file2
        assertNotNull(attribute2.getByteBuf());
        assertNotNull(file2.getByteBuf());
        assertEquals(1, attribute2.refCnt());
        assertEquals(1, file2.refCnt());

        // Cleanup attribute2 and file2 manually to avoid memory leak, not via factory
        attribute2.release();
        file2.release();
        assertEquals(0, attribute2.refCnt());
        assertEquals(0, file2.refCnt());
    }
}
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpRequest; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; + +import java.io.File; +import java.io.FilenameFilter; +import java.io.IOException; +import java.util.UUID; + +import static io.netty.handler.codec.http.HttpMethod.POST; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Test DeleteFileOnExitHook + */ +@Isolated("The DeleteFileOnExitHook has static shared mutable, " + + "and can interferre with other tests that use DiskAttribute") +public class DeleteFileOnExitHookTest { + private static final HttpRequest REQUEST = new DefaultHttpRequest(HTTP_1_1, POST, "/form"); + private static final String HOOK_TEST_TMP = "target/DeleteFileOnExitHookTest-" + UUID.randomUUID() + "/tmp"; + private FileUpload fu; + + @BeforeEach + public void setUp() throws IOException { + DefaultHttpDataFactory defaultHttpDataFactory = new DefaultHttpDataFactory(true); + defaultHttpDataFactory.setBaseDir(HOOK_TEST_TMP); + defaultHttpDataFactory.setDeleteOnExit(true); + + File baseDir = new File(HOOK_TEST_TMP); + baseDir.mkdirs(); // we don't need to clean it since it is in volatile files anyway + + fu = defaultHttpDataFactory.createFileUpload( + REQUEST, "attribute1", "tmp_f.txt", "text/plain", null, null, 0); + fu.setContent(Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4})); + + assertTrue(fu.getFile().exists()); + } + + @Test + public void testSimulateTriggerDeleteFileOnExitHook() { + + // simulate app exit + DeleteFileOnExitHook.runHook(); + + File[] files = new File(HOOK_TEST_TMP).listFiles(new FilenameFilter() { + @Override 
+ public boolean accept(File dir, String name) { + return name.startsWith(DiskFileUpload.prefix); + } + }); + + assertEquals(0, files.length); + } + + @Test + public void testAfterHttpDataReleaseCheckFileExist() throws IOException { + + String filePath = fu.getFile().getPath(); + assertTrue(DeleteFileOnExitHook.checkFileExist(filePath)); + + fu.release(); + assertFalse(DeleteFileOnExitHook.checkFileExist(filePath)); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DiskFileUploadTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DiskFileUploadTest.java new file mode 100644 index 0000000..5508919 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/DiskFileUploadTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; + +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.InputStream; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class DiskFileUploadTest { + @Test + public void testSpecificCustomBaseDir() throws IOException { + File baseDir = new File("target/DiskFileUploadTest/testSpecificCustomBaseDir"); + baseDir.mkdirs(); // we don't need to clean it since it is in volatile files anyway + DiskFileUpload f = + new DiskFileUpload("d1", "d1", "application/json", null, null, 100, + baseDir.getAbsolutePath(), false); + + f.setContent(Unpooled.EMPTY_BUFFER); + + assertTrue(f.getFile().getAbsolutePath().startsWith(baseDir.getAbsolutePath())); + assertTrue(f.getFile().exists()); + assertEquals(0, f.getFile().length()); + f.delete(); + } + + @Test + public final void testDiskFileUploadEquals() { + DiskFileUpload f2 = + new DiskFileUpload("d1", "d1", "application/json", null, null, 100); + assertEquals(f2, f2); + f2.delete(); + } + + @Test + public void testEmptyBufferSetMultipleTimes() throws IOException { + DiskFileUpload f = + new DiskFileUpload("d1", "d1", "application/json", null, null, 100); + + f.setContent(Unpooled.EMPTY_BUFFER); + + assertTrue(f.getFile().exists()); + 
assertEquals(0, f.getFile().length()); + f.setContent(Unpooled.EMPTY_BUFFER); + assertTrue(f.getFile().exists()); + assertEquals(0, f.getFile().length()); + f.delete(); + } + + @Test + public void testEmptyBufferSetAfterNonEmptyBuffer() throws IOException { + DiskFileUpload f = + new DiskFileUpload("d1", "d1", "application/json", null, null, 100); + + f.setContent(Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 })); + + assertTrue(f.getFile().exists()); + assertEquals(4, f.getFile().length()); + f.setContent(Unpooled.EMPTY_BUFFER); + assertTrue(f.getFile().exists()); + assertEquals(0, f.getFile().length()); + f.delete(); + } + + @Test + public void testNonEmptyBufferSetMultipleTimes() throws IOException { + DiskFileUpload f = + new DiskFileUpload("d1", "d1", "application/json", null, null, 100); + + f.setContent(Unpooled.wrappedBuffer(new byte[] { 1, 2, 3, 4 })); + + assertTrue(f.getFile().exists()); + assertEquals(4, f.getFile().length()); + f.setContent(Unpooled.wrappedBuffer(new byte[] { 1, 2})); + assertTrue(f.getFile().exists()); + assertEquals(2, f.getFile().length()); + f.delete(); + } + + @Test + public void testAddContents() throws Exception { + DiskFileUpload f1 = new DiskFileUpload("file1", "file1", "application/json", null, null, 0); + try { + byte[] jsonBytes = new byte[4096]; + PlatformDependent.threadLocalRandom().nextBytes(jsonBytes); + + f1.addContent(Unpooled.wrappedBuffer(jsonBytes, 0, 1024), false); + f1.addContent(Unpooled.wrappedBuffer(jsonBytes, 1024, jsonBytes.length - 1024), true); + assertArrayEquals(jsonBytes, f1.get()); + + File file = f1.getFile(); + assertEquals(jsonBytes.length, file.length()); + + FileInputStream fis = new FileInputStream(file); + try { + byte[] buf = new byte[jsonBytes.length]; + int offset = 0; + int read = 0; + int len = buf.length; + while ((read = fis.read(buf, offset, len)) > 0) { + len -= read; + offset += read; + if (len <= 0 || offset >= buf.length) { + break; + } + } + assertArrayEquals(jsonBytes, buf); + } 
finally { + fis.close(); + } + } finally { + f1.delete(); + } + } + + @Test + public void testSetContentFromByteBuf() throws Exception { + DiskFileUpload f1 = new DiskFileUpload("file2", "file2", "application/json", null, null, 0); + try { + String json = "{\"hello\":\"world\"}"; + byte[] bytes = json.getBytes(CharsetUtil.UTF_8); + f1.setContent(Unpooled.wrappedBuffer(bytes)); + assertEquals(json, f1.getString()); + assertArrayEquals(bytes, f1.get()); + File file = f1.getFile(); + assertEquals((long) bytes.length, file.length()); + assertArrayEquals(bytes, doReadFile(file, bytes.length)); + } finally { + f1.delete(); + } + } + + @Test + public void testSetContentFromInputStream() throws Exception { + String json = "{\"hello\":\"world\",\"foo\":\"bar\"}"; + DiskFileUpload f1 = new DiskFileUpload("file3", "file3", "application/json", null, null, 0); + try { + byte[] bytes = json.getBytes(CharsetUtil.UTF_8); + ByteBuf buf = Unpooled.wrappedBuffer(bytes); + InputStream is = new ByteBufInputStream(buf); + try { + f1.setContent(is); + assertEquals(json, f1.getString()); + assertArrayEquals(bytes, f1.get()); + File file = f1.getFile(); + assertEquals((long) bytes.length, file.length()); + assertArrayEquals(bytes, doReadFile(file, bytes.length)); + } finally { + buf.release(); + is.close(); + } + } finally { + f1.delete(); + } + } + + @Test + public void testAddContentFromByteBuf() throws Exception { + testAddContentFromByteBuf0(false); + } + + @Test + public void testAddContentFromCompositeByteBuf() throws Exception { + testAddContentFromByteBuf0(true); + } + + private static void testAddContentFromByteBuf0(boolean composite) throws Exception { + DiskFileUpload f1 = new DiskFileUpload("file3", "file3", "application/json", null, null, 0); + try { + byte[] bytes = new byte[4096]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + final ByteBuf buffer; + + if (composite) { + buffer = Unpooled.compositeBuffer() + .addComponent(true, Unpooled.wrappedBuffer(bytes, 0 
, bytes.length / 2)) + .addComponent(true, Unpooled.wrappedBuffer(bytes, bytes.length / 2, bytes.length / 2)); + } else { + buffer = Unpooled.wrappedBuffer(bytes); + } + f1.addContent(buffer, true); + ByteBuf buf = f1.getByteBuf(); + assertEquals(buf.readerIndex(), 0); + assertEquals(buf.writerIndex(), bytes.length); + assertArrayEquals(bytes, ByteBufUtil.getBytes(buf)); + } finally { + //release the ByteBuf + f1.delete(); + } + } + + private static byte[] doReadFile(File file, int maxRead) throws Exception { + FileInputStream fis = new FileInputStream(file); + try { + byte[] buf = new byte[maxRead]; + int offset = 0; + int read = 0; + int len = buf.length; + while ((read = fis.read(buf, offset, len)) > 0) { + len -= read; + offset += read; + if (len <= 0 || offset >= buf.length) { + break; + } + } + return buf; + } finally { + fis.close(); + } + } + + @Test + public void testDelete() throws Exception { + String json = "{\"foo\":\"bar\"}"; + byte[] bytes = json.getBytes(CharsetUtil.UTF_8); + File tmpFile = null; + DiskFileUpload f1 = new DiskFileUpload("file4", "file4", "application/json", null, null, 0); + try { + assertNull(f1.getFile()); + f1.setContent(Unpooled.wrappedBuffer(bytes)); + assertNotNull(tmpFile = f1.getFile()); + } finally { + f1.delete(); + assertNull(f1.getFile()); + assertNotNull(tmpFile); + assertFalse(tmpFile.exists()); + } + } + + @Test + public void setSetContentFromFileExceptionally() throws Exception { + final long maxSize = 4; + DiskFileUpload f1 = new DiskFileUpload("file5", "file5", "application/json", null, null, 0); + f1.setMaxSize(maxSize); + try { + f1.setContent(Unpooled.wrappedBuffer(new byte[(int) maxSize])); + File originalFile = f1.getFile(); + assertNotNull(originalFile); + assertEquals(maxSize, originalFile.length()); + assertEquals(maxSize, f1.length()); + byte[] bytes = new byte[8]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + File tmpFile = PlatformDependent.createTempFile(UUID.randomUUID().toString(), 
".tmp", null); + tmpFile.deleteOnExit(); + FileOutputStream fos = new FileOutputStream(tmpFile); + try { + fos.write(bytes); + fos.flush(); + } finally { + fos.close(); + } + try { + f1.setContent(tmpFile); + fail("should not reach here!"); + } catch (IOException e) { + assertNotNull(f1.getFile()); + assertEquals(originalFile, f1.getFile()); + assertEquals(maxSize, f1.length()); + } + } finally { + f1.delete(); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpDataTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpDataTest.java new file mode 100644 index 0000000..7c8355b --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpDataTest.java @@ -0,0 +1,150 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * Parameterized tests run against every concrete {@link HttpData}
 * implementation (memory, disk and mixed attributes/file-uploads),
 * covering buffer release on addContent and size-limit enforcement.
 */
class HttpDataTest {
    // Shared random payload, filled once in setUp(); length 64 exceeds the
    // definedSize/maxSize of 10 used by the size-limit tests below.
    private static final byte[] BYTES = new byte[64];

    /**
     * Meta-annotation: a parameterized test fed with one instance of each
     * HttpData implementation from {@link #data()}.
     */
    @Retention(RetentionPolicy.RUNTIME)
    @Target(ElementType.METHOD)
    @ParameterizedTest(name = "{displayName}({0})")
    @MethodSource("data")
    @interface ParameterizedHttpDataTest {
    }

    // One instance per implementation under test; definedSize is 10 for all.
    static HttpData[] data() {
        return new HttpData[]{
                new MemoryAttribute("test", 10),
                new MemoryFileUpload("test", "", "text/plain", null, CharsetUtil.UTF_8, 10),
                new MixedAttribute("test", 10, -1),
                new MixedFileUpload("test", "", "text/plain", null, CharsetUtil.UTF_8, 10, -1),
                new DiskAttribute("test", 10),
                new DiskFileUpload("test", "", "text/plain", null, CharsetUtil.UTF_8, 10)
        };
    }

    @BeforeAll
    static void setUp() {
        Random rndm = new Random();
        rndm.nextBytes(BYTES);
    }

    /** addContent must release the input buffer even when it is empty. */
    @ParameterizedHttpDataTest
    void testAddContentEmptyBuffer(HttpData httpData) throws IOException {
        ByteBuf content = PooledByteBufAllocator.DEFAULT.buffer();
        httpData.addContent(content, false);
        assertThat(content.refCnt()).isEqualTo(0);
    }

    /** retainedDuplicate must carry over the completed flag in both states. */
    @ParameterizedHttpDataTest
    void testCompletedFlagPreservedAfterRetainDuplicate(HttpData httpData) throws IOException {
        httpData.addContent(Unpooled.wrappedBuffer("foo".getBytes(CharsetUtil.UTF_8)), false);
        assertThat(httpData.isCompleted()).isFalse();
        HttpData duplicate = httpData.retainedDuplicate();
        assertThat(duplicate.isCompleted()).isFalse();
        assertThat(duplicate.release()).isTrue();
        httpData.addContent(Unpooled.wrappedBuffer("bar".getBytes(CharsetUtil.UTF_8)), true);
        assertThat(httpData.isCompleted()).isTrue();
        duplicate = httpData.retainedDuplicate();
        assertThat(duplicate.isCompleted()).isTrue();
        assertThat(duplicate.release()).isTrue();
    }

    @Test
    void testAddContentExceedsDefinedSizeDiskFileUpload() {
        doTestAddContentExceedsSize(
                new DiskFileUpload("test", "", "application/json", null, CharsetUtil.UTF_8, 10),
                "Out of size: 64 > 10");
    }

    @Test
    void testAddContentExceedsDefinedSizeMemoryFileUpload() {
        doTestAddContentExceedsSize(
                new MemoryFileUpload("test", "", "application/json", null, CharsetUtil.UTF_8, 10),
                "Out of size: 64 > 10");
    }

    @ParameterizedHttpDataTest
    void testAddContentExceedsMaxSize(final HttpData httpData) {
        httpData.setMaxSize(10);
        doTestAddContentExceedsSize(httpData, "Size exceed allowed maximum capacity");
    }

    @ParameterizedHttpDataTest
    void testSetContentExceedsDefinedSize(final HttpData httpData) {
        doTestSetContentExceedsSize(httpData, "Out of size: 64 > 10");
    }

    @ParameterizedHttpDataTest
    void testSetContentExceedsMaxSize(final HttpData httpData) {
        httpData.setMaxSize(10);
        doTestSetContentExceedsSize(httpData, "Size exceed allowed maximum capacity");
    }

    // Over-limit addContent must throw IOException with the given message and
    // still release the input buffer.
    private static void doTestAddContentExceedsSize(final HttpData httpData, String expectedMessage) {
        final ByteBuf content = PooledByteBufAllocator.DEFAULT.buffer();
        content.writeBytes(BYTES);

        assertThatExceptionOfType(IOException.class)
                .isThrownBy(new ThrowableAssert.ThrowingCallable() {

                    @Override
                    public void call() throws Throwable {
                        httpData.addContent(content, false);
                    }
                })
                .withMessage(expectedMessage);

        assertThat(content.refCnt()).isEqualTo(0);
    }

    // Same contract as above, but via setContent.
    private static void doTestSetContentExceedsSize(final HttpData httpData, String expectedMessage) {
        final ByteBuf content = PooledByteBufAllocator.DEFAULT.buffer();
        content.writeBytes(BYTES);

        assertThatExceptionOfType(IOException.class)
                .isThrownBy(new ThrowableAssert.ThrowingCallable() {

                    @Override
                    public void call() throws Throwable {
                        httpData.setContent(content);
                    }
                })
                .withMessage(expectedMessage);

        assertThat(content.refCnt()).isEqualTo(0);
    }
}
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HttpPostMultiPartRequestDecoderTest { + + @Test + public void testDecodeFullHttpRequestWithNoContentTypeHeader() { + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + try { + new HttpPostMultipartRequestDecoder(req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException expected) { + // expected + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDecodeFullHttpRequestWithInvalidCharset() { + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + req.headers().set(HttpHeaderNames.CONTENT_TYPE, + "multipart/form-data; boundary=--89421926422648 [; charset=UTF-8]"); + + try { + new HttpPostMultipartRequestDecoder(req); + fail("Was expecting an 
ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException expected) { + // expected + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDecodeFullHttpRequestWithInvalidPayloadReleaseBuffer() { + String content = "\n--861fbeab-cd20-470c-9609-d40a0f704466\n" + + "Content-Disposition: form-data; name=\"image1\"; filename*=\"'some.jpeg\"\n" + + "Content-Type: image/jpeg\n" + + "Content-Length: 1\n" + + "x\n" + + "--861fbeab-cd20-470c-9609-d40a0f704466--\n"; + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload", + Unpooled.copiedBuffer(content, CharsetUtil.US_ASCII)); + req.headers().set("content-type", "multipart/form-data; boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + req.headers().set("content-length", content.length()); + + try { + new HttpPostMultipartRequestDecoder(req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException expected) { + // expected + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDelimiterExceedLeftSpaceInCurrentBuffer() { + String delimiter = "--861fbeab-cd20-470c-9609-d40a0f704466"; + String suffix = '\n' + delimiter + "--\n"; + byte[] bsuffix = suffix.getBytes(CharsetUtil.UTF_8); + int partOfDelimiter = bsuffix.length / 2; + int bytesLastChunk = 355 - partOfDelimiter; // to try to have an out of bound since content is > delimiter + byte[] bsuffix1 = Arrays.copyOf(bsuffix, partOfDelimiter); + byte[] bsuffix2 = Arrays.copyOfRange(bsuffix, partOfDelimiter, bsuffix.length); + String prefix = delimiter + "\n" + + "Content-Disposition: form-data; name=\"image\"; filename=\"guangzhou.jpeg\"\n" + + "Content-Type: image/jpeg\n" + + "Content-Length: " + bytesLastChunk + "\n\n"; + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + request.headers().set("content-type", "multipart/form-data; 
boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + request.headers().set("content-length", prefix.length() + bytesLastChunk + suffix.length()); + + // Factory using Memory mode + HttpDataFactory factory = new DefaultHttpDataFactory(false); + HttpPostMultipartRequestDecoder decoder = new HttpPostMultipartRequestDecoder(factory, request); + ByteBuf buf = Unpooled.wrappedBuffer(prefix.getBytes(CharsetUtil.UTF_8)); + DefaultHttpContent httpContent = new DefaultHttpContent(buf); + decoder.offer(httpContent); + assertNotNull((HttpData) decoder.currentPartialHttpData()); + httpContent.release(); + // Chunk less than Delimiter size but containing part of delimiter + byte[] body = new byte[bytesLastChunk + bsuffix1.length]; + Arrays.fill(body, (byte) 2); + for (int i = 0; i < bsuffix1.length; i++) { + body[bytesLastChunk + i] = bsuffix1[i]; + } + ByteBuf content = Unpooled.wrappedBuffer(body); + httpContent = new DefaultHttpContent(content); + decoder.offer(httpContent); // Ouf of range before here + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + content = Unpooled.wrappedBuffer(bsuffix2); + httpContent = new DefaultHttpContent(content); + decoder.offer(httpContent); + assertNull((HttpData) decoder.currentPartialHttpData()); + httpContent.release(); + decoder.offer(new DefaultLastHttpContent()); + FileUpload data = (FileUpload) decoder.getBodyHttpDatas().get(0); + assertEquals(data.length(), bytesLastChunk); + assertEquals(true, data.isInMemory()); + + InterfaceHttpData[] httpDatas = decoder.getBodyHttpDatas().toArray(new InterfaceHttpData[0]); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(1, httpData.refCnt(), "Before cleanAllHttpData should be 1"); + } + factory.cleanAllHttpData(); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(1, httpData.refCnt(), "After cleanAllHttpData should be 1 if in Memory"); + } + decoder.destroy(); + for (InterfaceHttpData httpData : httpDatas) { + 
assertEquals(0, httpData.refCnt(), "RefCnt should be 0"); + } + } + + private void commonTestBigFileDelimiterInMiddleChunk(HttpDataFactory factory, boolean inMemory) + throws IOException { + int nbChunks = 100; + int bytesPerChunk = 100000; + int bytesLastChunk = 10000; + int fileSize = bytesPerChunk * nbChunks + bytesLastChunk; // set Xmx to a number lower than this and it crashes + + String delimiter = "--861fbeab-cd20-470c-9609-d40a0f704466"; + String prefix = delimiter + "\n" + + "Content-Disposition: form-data; name=\"image\"; filename=\"guangzhou.jpeg\"\n" + + "Content-Type: image/jpeg\n" + + "Content-Length: " + fileSize + "\n" + + "\n"; + + String suffix1 = "\n" + + "--861fbeab-"; + String suffix2 = "cd20-470c-9609-d40a0f704466--\n"; + String suffix = suffix1 + suffix2; + + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + request.headers().set("content-type", "multipart/form-data; boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + request.headers().set("content-length", prefix.length() + fileSize + suffix.length()); + + HttpPostMultipartRequestDecoder decoder = new HttpPostMultipartRequestDecoder(factory, request); + ByteBuf buf = Unpooled.wrappedBuffer(prefix.getBytes(CharsetUtil.UTF_8)); + DefaultHttpContent httpContent = new DefaultHttpContent(buf); + decoder.offer(httpContent); + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + + byte[] body = new byte[bytesPerChunk]; + Arrays.fill(body, (byte) 1); + // Set first bytes as CRLF to ensure it is correctly getting the last CRLF + body[0] = HttpConstants.CR; + body[1] = HttpConstants.LF; + for (int i = 0; i < nbChunks; i++) { + ByteBuf content = Unpooled.wrappedBuffer(body, 0, bytesPerChunk); + httpContent = new DefaultHttpContent(content); + decoder.offer(httpContent); // **OutOfMemory previously here** + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + } + + 
byte[] bsuffix1 = suffix1.getBytes(CharsetUtil.UTF_8); + byte[] previousLastbody = new byte[bytesLastChunk - bsuffix1.length]; + byte[] bdelimiter = delimiter.getBytes(CharsetUtil.UTF_8); + byte[] lastbody = new byte[2 * bsuffix1.length]; + Arrays.fill(previousLastbody, (byte) 1); + previousLastbody[0] = HttpConstants.CR; + previousLastbody[1] = HttpConstants.LF; + Arrays.fill(lastbody, (byte) 1); + // put somewhere a not valid delimiter + for (int i = 0; i < bdelimiter.length; i++) { + previousLastbody[i + 10] = bdelimiter[i]; + } + lastbody[0] = HttpConstants.CR; + lastbody[1] = HttpConstants.LF; + for (int i = 0; i < bsuffix1.length; i++) { + lastbody[bsuffix1.length + i] = bsuffix1[i]; + } + + ByteBuf content2 = Unpooled.wrappedBuffer(previousLastbody, 0, previousLastbody.length); + httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + content2 = Unpooled.wrappedBuffer(lastbody, 0, lastbody.length); + httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + content2 = Unpooled.wrappedBuffer(suffix2.getBytes(CharsetUtil.UTF_8)); + httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + assertNull(decoder.currentPartialHttpData()); + httpContent.release(); + decoder.offer(new DefaultLastHttpContent()); + + FileUpload data = (FileUpload) decoder.getBodyHttpDatas().get(0); + assertEquals(data.length(), fileSize); + assertEquals(inMemory, data.isInMemory()); + if (data.isInMemory()) { + // To be done only if not inMemory: assertEquals(data.get().length, fileSize); + assertFalse(data.getByteBuf().capacity() < 1024 * 1024, + "Capacity should be higher than 1M"); + } + assertTrue(decoder.getCurrentAllocatedCapacity() < 1024 * 1024, + "Capacity should be less than 1M"); + InterfaceHttpData[] httpDatas = 
decoder.getBodyHttpDatas().toArray(new InterfaceHttpData[0]); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(1, httpData.refCnt(), "Before cleanAllHttpData should be 1"); + } + factory.cleanAllHttpData(); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(inMemory? 1 : 0, httpData.refCnt(), "After cleanAllHttpData should be 1 if in Memory"); + } + decoder.destroy(); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(0, httpData.refCnt(), "RefCnt should be 0"); + } + } + + @Test + public void testBIgFileUploadDelimiterInMiddleChunkDecoderDiskFactory() throws IOException { + // Factory using Disk mode + HttpDataFactory factory = new DefaultHttpDataFactory(true); + + commonTestBigFileDelimiterInMiddleChunk(factory, false); + } + + @Test + public void testBIgFileUploadDelimiterInMiddleChunkDecoderMemoryFactory() throws IOException { + // Factory using Memory mode + HttpDataFactory factory = new DefaultHttpDataFactory(false); + + commonTestBigFileDelimiterInMiddleChunk(factory, true); + } + + @Test + public void testBIgFileUploadDelimiterInMiddleChunkDecoderMixedFactory() throws IOException { + // Factory using Mixed mode, where file shall be on Disk + HttpDataFactory factory = new DefaultHttpDataFactory(10000); + + commonTestBigFileDelimiterInMiddleChunk(factory, false); + } + + @Test + public void testNotBadReleaseBuffersDuringDecodingDiskFactory() throws IOException { + // Using Disk Factory + HttpDataFactory factory = new DefaultHttpDataFactory(true); + commonNotBadReleaseBuffersDuringDecoding(factory, false); + } + @Test + public void testNotBadReleaseBuffersDuringDecodingMemoryFactory() throws IOException { + // Using Memory Factory + HttpDataFactory factory = new DefaultHttpDataFactory(false); + commonNotBadReleaseBuffersDuringDecoding(factory, true); + } + @Test + public void testNotBadReleaseBuffersDuringDecodingMixedFactory() throws IOException { + // Using Mixed Factory + HttpDataFactory factory = new 
DefaultHttpDataFactory(100); + commonNotBadReleaseBuffersDuringDecoding(factory, false); + } + + @Test + public void testDecodeFullHttpRequestWithOptionalParameters() { + String content = "\n--861fbeab-cd20-470c-9609-d40a0f704466\r\n" + + "content-disposition: form-data; " + + "name=\"file\"; filename=\"myfile.ogg\"\r\n" + + "content-type: audio/ogg; codecs=opus; charset=UTF8\r\ncontent-transfer-encoding: binary\r\n" + + "\r\n\u0001\u0002\u0003\u0004\r\n--861fbeab-cd20-470c-9609-d40a0f704466--\r\n\",\n"; + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload", + Unpooled.copiedBuffer(content, CharsetUtil.US_ASCII)); + req.headers().set("content-type", "multipart/form-data; boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + req.headers().set("content-length", content.length()); + + HttpPostMultipartRequestDecoder test = new HttpPostMultipartRequestDecoder(req); + FileUpload httpData = (FileUpload) test.getBodyHttpDatas("file").get(0); + assertEquals("audio/ogg", httpData.getContentType()); + test.destroy(); + } + + private static void commonNotBadReleaseBuffersDuringDecoding(HttpDataFactory factory, boolean inMemory) + throws IOException { + int nbItems = 20; + int bytesPerItem = 1000; + int maxMemory = 500; + + String prefix1 = "\n--861fbeab-cd20-470c-9609-d40a0f704466\n" + + "Content-Disposition: form-data; name=\"image"; + String prefix2 = + "\"; filename=\"guangzhou.jpeg\"\n" + + "Content-Type: image/jpeg\n" + + "Content-Length: " + bytesPerItem + "\n" + "\n"; + + String suffix = "\n--861fbeab-cd20-470c-9609-d40a0f704466--\n"; + + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + request.headers().set("content-type", "multipart/form-data; boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + request.headers().set("content-length", nbItems * (prefix1.length() + prefix2.length() + 2 + bytesPerItem) + + suffix.length()); + HttpPostMultipartRequestDecoder decoder = new 
HttpPostMultipartRequestDecoder(factory, request); + decoder.setDiscardThreshold(maxMemory); + for (int rank = 0; rank < nbItems; rank++) { + byte[] bp1 = prefix1.getBytes(CharsetUtil.UTF_8); + byte[] bp2 = prefix2.getBytes(CharsetUtil.UTF_8); + byte[] prefix = new byte[bp1.length + 2 + bp2.length]; + for (int i = 0; i < bp1.length; i++) { + prefix[i] = bp1[i]; + } + byte[] brank = Integer.toString(10 + rank).getBytes(CharsetUtil.UTF_8); + prefix[bp1.length] = brank[0]; + prefix[bp1.length + 1] = brank[1]; + for (int i = 0; i < bp2.length; i++) { + prefix[bp1.length + 2 + i] = bp2[i]; + } + ByteBuf buf = Unpooled.wrappedBuffer(prefix); + DefaultHttpContent httpContent = new DefaultHttpContent(buf); + decoder.offer(httpContent); + httpContent.release(); + byte[] body = new byte[bytesPerItem]; + Arrays.fill(body, (byte) rank); + ByteBuf content = Unpooled.wrappedBuffer(body, 0, bytesPerItem); + httpContent = new DefaultHttpContent(content); + decoder.offer(httpContent); + httpContent.release(); + } + byte[] lastbody = suffix.getBytes(CharsetUtil.UTF_8); + ByteBuf content2 = Unpooled.wrappedBuffer(lastbody, 0, lastbody.length); + DefaultHttpContent httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + httpContent.release(); + decoder.offer(new DefaultLastHttpContent()); + + for (int rank = 0; rank < nbItems; rank++) { + FileUpload data = (FileUpload) decoder.getBodyHttpData("image" + (10 + rank)); + assertEquals(bytesPerItem, data.length()); + assertEquals(inMemory, data.isInMemory()); + byte[] body = new byte[bytesPerItem]; + Arrays.fill(body, (byte) rank); + assertTrue(Arrays.equals(body, data.get())); + } + // To not be done since will load full file on memory: assertEquals(data.get().length, fileSize); + // Not mandatory since implicitly called during destroy of decoder + for (InterfaceHttpData httpData: decoder.getBodyHttpDatas()) { + httpData.release(); + factory.removeHttpDataFromClean(request, httpData); + } + 
factory.cleanAllHttpData(); + decoder.destroy(); + } + + // Issue #11668 + private static void commonTestFileDelimiterLFLastChunk(HttpDataFactory factory, boolean inMemory) + throws IOException { + int nbChunks = 2; + int bytesPerChunk = 100000; + int bytesLastChunk = 10000; + int fileSize = bytesPerChunk * nbChunks + bytesLastChunk; // set Xmx to a number lower than this and it crashes + + String delimiter = "--861fbeab-cd20-470c-9609-d40a0f704466"; + String prefix = delimiter + "\n" + + "Content-Disposition: form-data; name=\"image\"; filename=\"guangzhou.jpeg\"\n" + + "Content-Type: image/jpeg\n" + + "Content-Length: " + fileSize + "\n" + + "\n"; + + String suffix = "--861fbeab-cd20-470c-9609-d40a0f704466--"; + byte[] bsuffix = suffix.getBytes(CharsetUtil.UTF_8); + byte[] bsuffixReal = new byte[bsuffix.length + 2]; + for (int i = 0; i < bsuffix.length; i++) { + bsuffixReal[1 + i] = bsuffix[i]; + } + bsuffixReal[0] = HttpConstants.LF; + bsuffixReal[bsuffixReal.length - 1] = HttpConstants.CR; + byte[] lastbody = {HttpConstants.LF}; + + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + request.headers().set("content-type", "multipart/form-data; boundary=861fbeab-cd20-470c-9609-d40a0f704466"); + // +4 => 2xCRLF (beginning, end) + request.headers().set("content-length", prefix.length() + fileSize + suffix.length() + 4); + + HttpPostMultipartRequestDecoder decoder = new HttpPostMultipartRequestDecoder(factory, request); + ByteBuf buf = Unpooled.wrappedBuffer(prefix.getBytes(CharsetUtil.UTF_8)); + DefaultHttpContent httpContent = new DefaultHttpContent(buf); + decoder.offer(httpContent); + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + + byte[] body = new byte[bytesPerChunk]; + Arrays.fill(body, (byte) 1); + // Set first bytes as CRLF to ensure it is correctly getting the last CRLF + body[0] = HttpConstants.CR; + body[1] = HttpConstants.LF; + for (int i = 0; i < 
nbChunks; i++) { + ByteBuf content = Unpooled.wrappedBuffer(body, 0, bytesPerChunk); + httpContent = new DefaultHttpContent(content); + decoder.offer(httpContent); // **OutOfMemory previously here** + assertNotNull(((HttpData) decoder.currentPartialHttpData()).content()); + httpContent.release(); + } + // Last -2 body = content + CR but no delimiter + byte[] previousLastbody = new byte[bytesLastChunk + 1]; + Arrays.fill(previousLastbody, (byte) 1); + previousLastbody[bytesLastChunk] = HttpConstants.CR; + ByteBuf content2 = Unpooled.wrappedBuffer(previousLastbody, 0, previousLastbody.length); + httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + assertNotNull(decoder.currentPartialHttpData()); + httpContent.release(); + // Last -1 body = LF+delimiter+CR but no LF + content2 = Unpooled.wrappedBuffer(bsuffixReal, 0, bsuffixReal.length); + httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + assertNull(decoder.currentPartialHttpData()); + httpContent.release(); + // Last (LF) + content2 = Unpooled.wrappedBuffer(lastbody, 0, lastbody.length); + httpContent = new DefaultHttpContent(content2); + decoder.offer(httpContent); + assertNull(decoder.currentPartialHttpData()); + httpContent.release(); + // End + decoder.offer(new DefaultLastHttpContent()); + + FileUpload data = (FileUpload) decoder.getBodyHttpDatas().get(0); + assertEquals(data.length(), fileSize); + assertEquals(inMemory, data.isInMemory()); + if (data.isInMemory()) { + // To be done only if not inMemory: assertEquals(data.get().length, fileSize); + assertFalse(data.getByteBuf().capacity() < fileSize, + "Capacity should be at least file size"); + } + assertTrue(decoder.getCurrentAllocatedCapacity() < fileSize, + "Capacity should be less than 1M"); + InterfaceHttpData[] httpDatas = decoder.getBodyHttpDatas().toArray(new InterfaceHttpData[0]); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(1, httpData.refCnt(), "Before cleanAllHttpData should 
be 1"); + } + factory.cleanAllHttpData(); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(inMemory? 1 : 0, httpData.refCnt(), "After cleanAllHttpData should be 1 if in Memory"); + } + decoder.destroy(); + for (InterfaceHttpData httpData : httpDatas) { + assertEquals(0, httpData.refCnt(), "RefCnt should be 0"); + } + } + + @Test + public void testFileDelimiterLFLastChunkDecoderDiskFactory() throws IOException { + // Factory using Disk mode + HttpDataFactory factory = new DefaultHttpDataFactory(true); + + commonTestFileDelimiterLFLastChunk(factory, false); + } + + @Test + public void testFileDelimiterLFLastChunkDecoderMemoryFactory() throws IOException { + // Factory using Memory mode + HttpDataFactory factory = new DefaultHttpDataFactory(false); + + commonTestFileDelimiterLFLastChunk(factory, true); + } + + @Test + public void testFileDelimiterLFLastChunkDecoderMixedFactory() throws IOException { + // Factory using Mixed mode, where file shall be on Disk + HttpDataFactory factory = new DefaultHttpDataFactory(10000); + + commonTestFileDelimiterLFLastChunk(factory, false); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoderTest.java new file mode 100644 index 0000000..6099062 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestDecoderTest.java @@ -0,0 +1,1043 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.handler.codec.DecoderResult; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.UnsupportedEncodingException; +import java.net.URLEncoder; +import java.nio.charset.UnsupportedCharsetException; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** + * {@link HttpPostRequestDecoder} test case. 
+ */ +public class HttpPostRequestDecoderTest { + + @Test + public void testBinaryStreamUploadWithSpace() throws Exception { + testBinaryStreamUpload(true); + } + + // https://github.com/netty/netty/issues/1575 + @Test + public void testBinaryStreamUploadWithoutSpace() throws Exception { + testBinaryStreamUpload(false); + } + + private static void testBinaryStreamUpload(boolean withSpace) throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + final String contentTypeValue; + if (withSpace) { + contentTypeValue = "multipart/form-data; boundary=" + boundary; + } else { + contentTypeValue = "multipart/form-data;boundary=" + boundary; + } + final DefaultHttpRequest req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + + req.setDecoderResult(DecoderResult.SUCCESS); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, contentTypeValue); + req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + + for (String data : Arrays.asList("", "\r", "\r\r", "\r\r\r")) { + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + // Create decoder instance to test. + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + + ByteBuf buf = Unpooled.copiedBuffer(body, CharsetUtil.UTF_8); + decoder.offer(new DefaultHttpContent(buf)); + decoder.offer(new DefaultHttpContent(Unpooled.EMPTY_BUFFER)); + + // Validate it's enough chunks to decode upload. + assertTrue(decoder.hasNext()); + + // Decode binary upload. + MemoryFileUpload upload = (MemoryFileUpload) decoder.next(); + + // Validate data has been parsed correctly as it was passed into request. 
+ assertEquals(data, upload.getString(CharsetUtil.UTF_8), + "Invalid decoded data [data=" + data.replaceAll("\r", "\\\\r") + ", upload=" + upload + ']'); + upload.release(); + decoder.destroy(); + buf.release(); + } + } + + // See https://github.com/netty/netty/issues/1089 + @Test + public void testFullHttpRequestUpload() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + + req.setDecoderResult(DecoderResult.SUCCESS); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + + for (String data : Arrays.asList("", "\r", "\r\r", "\r\r\r")) { + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8)); + } + // Create decoder instance to test. + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + decoder.destroy(); + assertTrue(req.release()); + } + + // See https://github.com/netty/netty/issues/2544 + @Test + public void testMultipartCodecWithCRasEndOfAttribute() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + // Force to use memory-based data. 
+ final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + // Build test case + String extradata = "aaaa"; + String[] datas = new String[5]; + for (int i = 0; i < 4; i++) { + datas[i] = extradata; + for (int j = 0; j < i; j++) { + datas[i] += '\r'; + } + } + + for (int i = 0; i < 4; i++) { + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + req.setDecoderResult(DecoderResult.SUCCESS); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file" + i + "\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + datas[i] + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8)); + // Create decoder instance to test. + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + // Check correctness: data size + InterfaceHttpData httpdata = decoder.getBodyHttpData("file" + i); + assertNotNull(httpdata); + Attribute attribute = (Attribute) httpdata; + byte[] datar = attribute.get(); + assertNotNull(datar); + assertEquals(datas[i].getBytes(CharsetUtil.UTF_8).length, datar.length); + + decoder.destroy(); + assertTrue(req.release()); + } + } + + // See https://github.com/netty/netty/issues/2542 + @Test + public void testQuotedBoundary() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + + req.setDecoderResult(DecoderResult.SUCCESS); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=\"" + boundary + '"'); + 
req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + + for (String data : Arrays.asList("", "\r", "\r\r", "\r\r\r")) { + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8)); + } + // Create decoder instance to test. + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + decoder.destroy(); + assertTrue(req.release()); + } + + // See https://github.com/netty/netty/issues/1848 + @Test + public void testNoZeroOut() throws Exception { + final String boundary = "E832jQp_Rq2ErFmAduHSR8YlMSm0FCY"; + + final DefaultHttpDataFactory aMemFactory = new DefaultHttpDataFactory(false); + + DefaultHttpRequest aRequest = new DefaultHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost"); + aRequest.headers().set(HttpHeaderNames.CONTENT_TYPE, + "multipart/form-data; boundary=" + boundary); + aRequest.headers().set(HttpHeaderNames.TRANSFER_ENCODING, + HttpHeaderValues.CHUNKED); + + HttpPostRequestDecoder aDecoder = new HttpPostRequestDecoder(aMemFactory, aRequest); + + final String aData = "some data would be here. the data should be long enough that it " + + "will be longer than the original buffer length of 256 bytes in " + + "the HttpPostRequestDecoder in order to trigger the issue. 
Some more " + + "data just to be on the safe side."; + + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"root\"\r\n" + + "Content-Type: text/plain\r\n" + + "\r\n" + + aData + + "\r\n" + + "--" + boundary + "--\r\n"; + + byte[] aBytes = body.getBytes(); + + int split = 125; + + ByteBufAllocator aAlloc = new UnpooledByteBufAllocator(true); + ByteBuf aSmallBuf = aAlloc.heapBuffer(split, split); + ByteBuf aLargeBuf = aAlloc.heapBuffer(aBytes.length - split, aBytes.length - split); + + aSmallBuf.writeBytes(aBytes, 0, split); + aLargeBuf.writeBytes(aBytes, split, aBytes.length - split); + + aDecoder.offer(new DefaultHttpContent(aSmallBuf)); + aDecoder.offer(new DefaultHttpContent(aLargeBuf)); + + aDecoder.offer(LastHttpContent.EMPTY_LAST_CONTENT); + + assertTrue(aDecoder.hasNext(), "Should have a piece of data"); + + InterfaceHttpData aDecodedData = aDecoder.next(); + assertEquals(InterfaceHttpData.HttpDataType.Attribute, aDecodedData.getHttpDataType()); + + Attribute aAttr = (Attribute) aDecodedData; + assertEquals(aData, aAttr.getValue()); + + aDecodedData.release(); + aDecoder.destroy(); + aSmallBuf.release(); + aLargeBuf.release(); + } + + // See https://github.com/netty/netty/issues/2305 + @Test + public void testChunkCorrect() throws Exception { + String payload = "town=794649819&town=784444184&town=794649672&town=794657800&town=" + + "794655734&town=794649377&town=794652136&town=789936338&town=789948986&town=" + + "789949643&town=786358677&town=794655880&town=786398977&town=789901165&town=" + + "789913325&town=789903418&town=789903579&town=794645251&town=794694126&town=" + + "794694831&town=794655274&town=789913656&town=794653956&town=794665634&town=" + + "789936598&town=789904658&town=789899210&town=799696252&town=794657521&town=" + + "789904837&town=789961286&town=789958704&town=789948839&town=789933899&town=" + + "793060398&town=794659180&town=794659365&town=799724096&town=794696332&town=" + + 
"789953438&town=786398499&town=794693372&town=789935439&town=794658041&town=" + + "789917595&town=794655427&town=791930372&town=794652891&town=794656365&town=" + + "789960339&town=794645586&town=794657688&town=794697211&town=789937427&town=" + + "789902813&town=789941130&town=794696907&town=789904328&town=789955151&town=" + + "789911570&town=794655074&town=789939531&town=789935242&town=789903835&town=" + + "789953800&town=794649962&town=789939841&town=789934819&town=789959672&town=" + + "794659043&town=794657035&town=794658938&town=794651746&town=794653732&town=" + + "794653881&town=786397909&town=794695736&town=799724044&town=794695926&town=" + + "789912270&town=794649030&town=794657946&town=794655370&town=794659660&town=" + + "794694617&town=799149862&town=789953234&town=789900476&town=794654995&town=" + + "794671126&town=789908868&town=794652942&town=789955605&town=789901934&town=" + + "789950015&town=789937922&town=789962576&town=786360170&town=789954264&town=" + + "789911738&town=789955416&town=799724187&town=789911879&town=794657462&town=" + + "789912561&town=789913167&town=794655195&town=789938266&town=789952099&town=" + + "794657160&town=789949414&town=794691293&town=794698153&town=789935636&town=" + + "789956374&town=789934635&town=789935475&town=789935085&town=794651425&town=" + + "794654936&town=794655680&town=789908669&town=794652031&town=789951298&town=" + + "789938382&town=794651503&town=794653330&town=817675037&town=789951623&town=" + + "789958999&town=789961555&town=794694050&town=794650241&town=794656286&town=" + + "794692081&town=794660090&town=794665227&town=794665136&town=794669931"; + DefaultHttpRequest defaultHttpRequest = + new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + + HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(defaultHttpRequest); + + int firstChunk = 10; + int middleChunk = 1024; + + byte[] payload1 = payload.substring(0, firstChunk).getBytes(); + byte[] payload2 = payload.substring(firstChunk, 
firstChunk + middleChunk).getBytes(); + byte[] payload3 = payload.substring(firstChunk + middleChunk, firstChunk + middleChunk * 2).getBytes(); + byte[] payload4 = payload.substring(firstChunk + middleChunk * 2).getBytes(); + + ByteBuf buf1 = Unpooled.directBuffer(payload1.length); + ByteBuf buf2 = Unpooled.directBuffer(payload2.length); + ByteBuf buf3 = Unpooled.directBuffer(payload3.length); + ByteBuf buf4 = Unpooled.directBuffer(payload4.length); + + buf1.writeBytes(payload1); + buf2.writeBytes(payload2); + buf3.writeBytes(payload3); + buf4.writeBytes(payload4); + + decoder.offer(new DefaultHttpContent(buf1)); + decoder.offer(new DefaultHttpContent(buf2)); + decoder.offer(new DefaultHttpContent(buf3)); + decoder.offer(new DefaultLastHttpContent(buf4)); + + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + assertEquals(139, decoder.getBodyHttpDatas().size()); + + Attribute attr = (Attribute) decoder.getBodyHttpData("town"); + assertEquals("794649819", attr.getValue()); + + decoder.destroy(); + buf1.release(); + buf2.release(); + buf3.release(); + buf4.release(); + } + + // See https://github.com/netty/netty/issues/3326 + @Test + public void testFilenameContainingSemicolon() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + // Force to use memory-based data. 
+ final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final String data = "asdf"; + final String filename = "tmp;0.txt"; + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + filename + "\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8.name())); + // Create decoder instance to test. + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testFilenameContainingSemicolon2() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final String data = "asdf"; + final String filename = "tmp;0.txt"; + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + filename + "\"\r\n" + + "Content-Type: image/gif\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8.name())); + // Create decoder instance to test. 
+ final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + InterfaceHttpData part1 = decoder.getBodyHttpDatas().get(0); + assertTrue(part1 instanceof FileUpload); + FileUpload fileUpload = (FileUpload) part1; + assertEquals("tmp 0.txt", fileUpload.getFilename()); + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testMultipartRequestWithoutContentTypeBody() { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + + req.setDecoderResult(DecoderResult.SUCCESS); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + + for (String data : Arrays.asList("", "\r", "\r\r", "\r\r\r")) { + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"tmp-0.txt\"\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8)); + } + // Create decoder instance to test without any exception. 
+ final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testDecodeOtherMimeHeaderFields() throws Exception { + final String boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + String filecontent = "123456"; + + final String body = "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=" + "\"" + "attached.txt" + "\"" + + "\r\n" + + "Content-Type: application/octet-stream" + "\r\n" + + "Content-Encoding: gzip" + "\r\n" + + "\r\n" + + filecontent + + "\r\n" + + "--" + boundary + "--"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost", + Unpooled.wrappedBuffer(body.getBytes())); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + InterfaceHttpData part1 = decoder.getBodyHttpDatas().get(0); + assertTrue(part1 instanceof FileUpload, "the item should be a FileUpload"); + FileUpload fileUpload = (FileUpload) part1; + byte[] fileBytes = fileUpload.get(); + assertTrue(filecontent.equals(new String(fileBytes)), "the filecontent should not be decoded"); + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testMultipartRequestWithFileInvalidCharset() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + 
boundary); + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final String data = "asdf"; + final String filename = "tmp;0.txt"; + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + filename + "\"\r\n" + + "Content-Type: image/gif; charset=ABCD\r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8)); + // Create decoder instance to test. + try { + new HttpPostRequestDecoder(inMemoryFactory, req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertTrue(e.getCause() instanceof UnsupportedCharsetException); + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testMultipartRequestWithFieldInvalidCharset() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final String aData = "some data would be here. the data should be long enough that it " + + "will be longer than the original buffer length of 256 bytes in " + + "the HttpPostRequestDecoder in order to trigger the issue. Some more " + + "data just to be on the safe side."; + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"root\"\r\n" + + "Content-Type: text/plain; charset=ABCD\r\n" + + "\r\n" + + aData + + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8)); + // Create decoder instance to test. 
+ try { + new HttpPostRequestDecoder(inMemoryFactory, req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertTrue(e.getCause() instanceof UnsupportedCharsetException); + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testFormEncodeIncorrect() throws Exception { + LastHttpContent content = new DefaultLastHttpContent( + Unpooled.copiedBuffer("project=netty&=netty&project=netty", CharsetUtil.US_ASCII)); + DefaultHttpRequest req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(req); + try { + decoder.offer(content); + fail(); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertTrue(e.getCause() instanceof IllegalArgumentException); + } finally { + content.release(); + decoder.destroy(); + } + } + + // https://github.com/netty/netty/pull/7265 + @Test + public void testDecodeContentDispositionFieldParameters() throws Exception { + + final String boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + + String encoding = "utf-8"; + String filename = "attached_файл.txt"; + String filenameEncoded = URLEncoder.encode(filename, encoding); + + final String body = "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=" + encoding + "''" + filenameEncoded + + "\r\n\r\n" + + "foo\r\n" + + "\r\n" + + "--" + boundary + "--"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost", + Unpooled.wrappedBuffer(body.getBytes())); + + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + InterfaceHttpData part1 = 
decoder.getBodyHttpDatas().get(0); + assertTrue(part1 instanceof FileUpload, "the item should be a FileUpload"); + FileUpload fileUpload = (FileUpload) part1; + assertEquals(filename, fileUpload.getFilename(), "the filename should be decoded"); + decoder.destroy(); + assertTrue(req.release()); + } + + // https://github.com/netty/netty/pull/7265 + @Test + public void testDecodeWithLanguageContentDispositionFieldParameters() throws Exception { + + final String boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + + String encoding = "utf-8"; + String filename = "attached_файл.txt"; + String language = "anything"; + String filenameEncoded = URLEncoder.encode(filename, encoding); + + final String body = "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=" + + encoding + "'" + language + "'" + filenameEncoded + "\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + boundary + "--"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost", + Unpooled.wrappedBuffer(body.getBytes())); + + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + InterfaceHttpData part1 = decoder.getBodyHttpDatas().get(0); + assertTrue(part1 instanceof FileUpload, "the item should be a FileUpload"); + FileUpload fileUpload = (FileUpload) part1; + assertEquals(filename, fileUpload.getFilename(), "the filename should be decoded"); + decoder.destroy(); + assertTrue(req.release()); + } + + // https://github.com/netty/netty/pull/7265 + @Test + public void testDecodeMalformedNotEncodedContentDispositionFieldParameters() throws Exception { + + final String boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + + final String body = "--" + 
boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=not-encoded\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + boundary + "--"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost", + Unpooled.wrappedBuffer(body.getBytes())); + + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + + try { + new HttpPostRequestDecoder(inMemoryFactory, req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertTrue(e.getCause() instanceof ArrayIndexOutOfBoundsException); + } finally { + assertTrue(req.release()); + } + } + + // https://github.com/netty/netty/pull/7265 + @Test + public void testDecodeMalformedBadCharsetContentDispositionFieldParameters() throws Exception { + + final String boundary = "74e78d11b0214bdcbc2f86491eeb4902"; + + final String body = "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=not-a-charset''filename\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + boundary + "--"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost", + Unpooled.wrappedBuffer(body.getBytes())); + + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + + try { + new HttpPostRequestDecoder(inMemoryFactory, req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertTrue(e.getCause() instanceof UnsupportedCharsetException); + } finally { + assertTrue(req.release()); + } + } + + // https://github.com/netty/netty/issues/7620 + @Test + public void 
testDecodeMalformedEmptyContentTypeFieldParameters() throws Exception { + final String boundary = "dLV9Wyq26L_-JQxk6ferf-RT153LhOO"; + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, + "http://localhost"); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + // Force to use memory-based data. + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final String data = "asdf"; + final String filename = "tmp-0.txt"; + final String body = + "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename=\"" + filename + "\"\r\n" + + "Content-Type: \r\n" + + "\r\n" + + data + "\r\n" + + "--" + boundary + "--\r\n"; + + req.content().writeBytes(body.getBytes(CharsetUtil.UTF_8.name())); + // Create decoder instance to test. + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + InterfaceHttpData part1 = decoder.getBodyHttpDatas().get(0); + assertTrue(part1 instanceof FileUpload); + FileUpload fileUpload = (FileUpload) part1; + assertEquals("tmp-0.txt", fileUpload.getFilename()); + decoder.destroy(); + assertTrue(req.release()); + } + + // https://github.com/netty/netty/issues/8575 + @Test + public void testMultipartRequest() throws Exception { + String BOUNDARY = "01f136d9282f"; + + byte[] bodyBytes = ("--" + BOUNDARY + "\n" + + "Content-Disposition: form-data; name=\"msg_id\"\n" + + "\n" + + "15200\n" + + "--" + BOUNDARY + "\n" + + "Content-Disposition: form-data; name=\"msg\"\n" + + "\n" + + "test message\n" + + "--" + BOUNDARY + "--").getBytes(); + ByteBuf byteBuf = Unpooled.directBuffer(bodyBytes.length); + byteBuf.writeBytes(bodyBytes); + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_0, HttpMethod.POST, "/up", byteBuf); + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + 
BOUNDARY); + + HttpPostRequestDecoder decoder = + new HttpPostRequestDecoder(new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE), + req, + CharsetUtil.UTF_8); + + assertTrue(decoder.isMultipart()); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + assertEquals(2, decoder.getBodyHttpDatas().size()); + + Attribute attrMsg = (Attribute) decoder.getBodyHttpData("msg"); + assertTrue(attrMsg.getByteBuf().isDirect()); + assertEquals("test message", attrMsg.getValue()); + Attribute attrMsgId = (Attribute) decoder.getBodyHttpData("msg_id"); + assertTrue(attrMsgId.getByteBuf().isDirect()); + assertEquals("15200", attrMsgId.getValue()); + + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testNotLeak() { + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", + Unpooled.copiedBuffer("a=1&=2&b=3", CharsetUtil.US_ASCII)); + try { + assertThrows(HttpPostRequestDecoder.ErrorDataDecoderException.class, new Executable() { + @Override + public void execute() { + new HttpPostStandardRequestDecoder(request).destroy(); + } + }); + } finally { + assertTrue(request.release()); + } + } + + @Test + public void testNotLeakDirectBufferWhenWrapIllegalArgumentException() { + assertThrows(HttpPostRequestDecoder.ErrorDataDecoderException.class, new Executable() { + @Override + public void execute() { + testNotLeakWhenWrapIllegalArgumentException(Unpooled.directBuffer()); + } + }); + } + + @Test + public void testNotLeakHeapBufferWhenWrapIllegalArgumentException() { + assertThrows(HttpPostRequestDecoder.ErrorDataDecoderException.class, new Executable() { + @Override + public void execute() throws Throwable { + testNotLeakWhenWrapIllegalArgumentException(Unpooled.buffer()); + } + }); + } + + private static void testNotLeakWhenWrapIllegalArgumentException(ByteBuf buf) { + buf.writeCharSequence("a=b&foo=%22bar%22&==", CharsetUtil.US_ASCII); + FullHttpRequest request = new 
DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", buf); + try { + new HttpPostStandardRequestDecoder(request).destroy(); + } finally { + assertTrue(request.release()); + } + } + + @Test + public void testMultipartFormDataContentType() { + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + assertFalse(HttpPostRequestDecoder.isMultipart(request)); + + String multipartDataValue = HttpHeaderValues.MULTIPART_FORM_DATA + ";" + "boundary=gc0p4Jq0M2Yt08jU534c0p"; + request.headers().set(HttpHeaderNames.CONTENT_TYPE, ";" + multipartDataValue); + assertFalse(HttpPostRequestDecoder.isMultipart(request)); + + request.headers().set(HttpHeaderNames.CONTENT_TYPE, multipartDataValue); + assertTrue(HttpPostRequestDecoder.isMultipart(request)); + } + + // see https://github.com/netty/netty/issues/10087 + @Test + public void testDecodeWithLanguageContentDispositionFieldParametersForFix() throws Exception { + + final String boundary = "952178786863262625034234"; + + String encoding = "UTF-8"; + String filename = "测试test.txt"; + String filenameEncoded = URLEncoder.encode(filename, encoding); + + final String body = "--" + boundary + "\r\n" + + "Content-Disposition: form-data; name=\"file\"; filename*=\"" + + encoding + "''" + filenameEncoded + "\"\r\n" + + "\r\n" + + "foo\r\n" + + "\r\n" + + "--" + boundary + "--"; + + final DefaultFullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, + "http://localhost", + Unpooled.wrappedBuffer(body.getBytes())); + + req.headers().add(HttpHeaderNames.CONTENT_TYPE, "multipart/form-data; boundary=" + boundary); + final DefaultHttpDataFactory inMemoryFactory = new DefaultHttpDataFactory(false); + final HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(inMemoryFactory, req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + InterfaceHttpData part1 = decoder.getBodyHttpDatas().get(0); + assertTrue(part1 instanceof FileUpload, "the item should be a 
FileUpload"); + FileUpload fileUpload = (FileUpload) part1; + assertEquals(filename, fileUpload.getFilename(), "the filename should be decoded"); + + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testDecodeFullHttpRequestWithUrlEncodedBody() throws Exception { + byte[] bodyBytes = "foo=bar&a=b&empty=&city=%3c%22new%22%20york%20city%3e&other_city=los+angeles".getBytes(); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", content); + HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(req); + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + + assertFalse(decoder.getBodyHttpDatas().isEmpty()); + assertEquals(5, decoder.getBodyHttpDatas().size()); + + Attribute attr = (Attribute) decoder.getBodyHttpData("foo"); + assertTrue(attr.getByteBuf().isDirect()); + assertEquals("bar", attr.getValue()); + + attr = (Attribute) decoder.getBodyHttpData("a"); + assertTrue(attr.getByteBuf().isDirect()); + assertEquals("b", attr.getValue()); + + attr = (Attribute) decoder.getBodyHttpData("empty"); + assertTrue(attr.getByteBuf().isDirect()); + assertEquals("", attr.getValue()); + + attr = (Attribute) decoder.getBodyHttpData("city"); + assertTrue(attr.getByteBuf().isDirect()); + assertEquals("<\"new\" york city>", attr.getValue()); + + attr = (Attribute) decoder.getBodyHttpData("other_city"); + assertTrue(attr.getByteBuf().isDirect()); + assertEquals("los angeles", attr.getValue()); + + decoder.destroy(); + assertTrue(req.release()); + } + + @Test + public void testDecodeFullHttpRequestWithUrlEncodedBodyWithBrokenHexByte0() { + byte[] bodyBytes = "foo=bar&a=b&empty=%&city=paris".getBytes(); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", content); + try { + new 
HttpPostRequestDecoder(req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertEquals("Invalid hex byte at index '0' in string: '%'", e.getMessage()); + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDecodeFullHttpRequestWithUrlEncodedBodyWithBrokenHexByte1() { + byte[] bodyBytes = "foo=bar&a=b&empty=%2&city=london".getBytes(); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", content); + try { + new HttpPostRequestDecoder(req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertEquals("Invalid hex byte at index '0' in string: '%2'", e.getMessage()); + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDecodeFullHttpRequestWithUrlEncodedBodyWithInvalidHexNibbleHi() { + byte[] bodyBytes = "foo=bar&a=b&empty=%Zc&city=london".getBytes(); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", content); + try { + new HttpPostRequestDecoder(req); + fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertEquals("Invalid hex byte at index '0' in string: '%Zc'", e.getMessage()); + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDecodeFullHttpRequestWithUrlEncodedBodyWithInvalidHexNibbleLo() { + byte[] bodyBytes = "foo=bar&a=b&empty=%2g&city=london".getBytes(); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", content); + try { + new HttpPostRequestDecoder(req); + 
fail("Was expecting an ErrorDataDecoderException"); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + assertEquals("Invalid hex byte at index '0' in string: '%2g'", e.getMessage()); + } finally { + assertTrue(req.release()); + } + } + + @Test + public void testDecodeMultipartRequest() { + byte[] bodyBytes = ("--be38b42a9ad2713f\n" + + "content-disposition: form-data; name=\"title\"\n" + + "content-length: 10\n" + + "content-type: text/plain; charset=UTF-8\n" + + "\n" + + "bar-stream\n" + + "--be38b42a9ad2713f\n" + + "content-disposition: form-data; name=\"data\"; filename=\"data.json\"\n" + + "content-length: 16\n" + + "content-type: application/json; charset=UTF-8\n" + + "\n" + + "{\"title\":\"Test\"}\n" + + "--be38b42a9ad2713f--").getBytes(); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/", content); + req.headers().add("Content-Type", "multipart/form-data;boundary=be38b42a9ad2713f"); + + try { + HttpPostRequestDecoder decoder = new HttpPostRequestDecoder(new DefaultHttpDataFactory(false), req); + assertEquals(2, decoder.getBodyHttpDatas().size()); + InterfaceHttpData data = decoder.getBodyHttpData("title"); + assertTrue(data instanceof MemoryAttribute); + assertEquals("bar-stream", ((MemoryAttribute) data).getString()); + assertTrue(data.release()); + data = decoder.getBodyHttpData("data"); + assertTrue(data instanceof MemoryFileUpload); + assertEquals("{\"title\":\"Test\"}", ((MemoryFileUpload) data).getString()); + assertTrue(data.release()); + decoder.destroy(); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + fail("Was not expecting an exception"); + } finally { + assertTrue(req.release()); + } + } + + /** + * when diskFilename contain "\" create temp file error + */ + @Test + void testHttpPostStandardRequestDecoderToDiskNameContainingUnauthorizedChar() throws 
UnsupportedEncodingException { + StringBuffer sb = new StringBuffer(); + byte[] bodyBytes = ("aaaa/bbbb=aaaaaaaaaa" + + "aaaaaaaaaaaaaaaaaaaaaaaaaa" + + "aaaaaaaaaaaaaaaaaaaaaaaaaa" + + "aaaaaaaaaaaaaaaaaaa").getBytes(CharsetUtil.US_ASCII); + ByteBuf content = Unpooled.directBuffer(bodyBytes.length); + content.writeBytes(bodyBytes); + + FullHttpRequest req = + new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, + HttpMethod.POST, + "/", + content); + HttpPostStandardRequestDecoder decoder = null; + try { + decoder = new HttpPostStandardRequestDecoder( + new DefaultHttpDataFactory(true), + req + ); + decoder.destroy(); + } catch (HttpPostRequestDecoder.ErrorDataDecoderException e) { + if (null != decoder) { + decoder.destroy(); + } + fail("Was not expecting an exception"); + } finally { + assertTrue(req.release()); + } + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoderTest.java new file mode 100755 index 0000000..d1810cc --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostRequestEncoderTest.java @@ -0,0 +1,471 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.HttpConstants; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder.EncoderMode; +import io.netty.handler.codec.http.multipart.HttpPostRequestEncoder.ErrorDataEncoderException; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.util.Arrays; +import java.util.List; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_DISPOSITION; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TRANSFER_ENCODING; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_TYPE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +/** {@link HttpPostRequestEncoder} test case. 
*/ +public class HttpPostRequestEncoderTest { + + @Test + public void testAllowedMethods() throws Exception { + shouldThrowExceptionIfNotAllowed(HttpMethod.CONNECT); + shouldThrowExceptionIfNotAllowed(HttpMethod.PUT); + shouldThrowExceptionIfNotAllowed(HttpMethod.POST); + shouldThrowExceptionIfNotAllowed(HttpMethod.PATCH); + shouldThrowExceptionIfNotAllowed(HttpMethod.DELETE); + shouldThrowExceptionIfNotAllowed(HttpMethod.GET); + shouldThrowExceptionIfNotAllowed(HttpMethod.HEAD); + shouldThrowExceptionIfNotAllowed(HttpMethod.OPTIONS); + try { + shouldThrowExceptionIfNotAllowed(HttpMethod.TRACE); + fail("Should raised an exception with TRACE method"); + } catch (ErrorDataEncoderException e) { + // Exception is willing + } + } + + private void shouldThrowExceptionIfNotAllowed(HttpMethod method) throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + method, "http://localhost"); + + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + encoder.addBodyAttribute("foo", "bar"); + encoder.addBodyFileUpload("quux", file1, "text/plain", false); + + String multipartDataBoundary = encoder.multipartDataBoundary; + String content = getRequestBody(encoder); + + String expected = "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"foo\"" + "\r\n" + + CONTENT_LENGTH + ": 3" + "\r\n" + + CONTENT_TYPE + ": text/plain; charset=UTF-8" + "\r\n" + + "\r\n" + + "bar" + + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"; filename=\"file-01.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file1.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + assertEquals(expected, content); + } + + @Test + 
public void testSingleFileUploadNoName() throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + encoder.addBodyAttribute("foo", "bar"); + encoder.addBodyFileUpload("quux", "", file1, "text/plain", false); + + String multipartDataBoundary = encoder.multipartDataBoundary; + String content = getRequestBody(encoder); + + String expected = "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"foo\"" + "\r\n" + + CONTENT_LENGTH + ": 3" + "\r\n" + + CONTENT_TYPE + ": text/plain; charset=UTF-8" + "\r\n" + + "\r\n" + + "bar" + + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"\r\n" + + CONTENT_LENGTH + ": " + file1.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + assertEquals(expected, content); + } + + @Test + public void testMultiFileUploadInMixedMode() throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + File file2 = new File(getClass().getResource("/file-02.txt").toURI()); + File file3 = new File(getClass().getResource("/file-03.txt").toURI()); + encoder.addBodyAttribute("foo", "bar"); + encoder.addBodyFileUpload("quux", file1, "text/plain", false); + encoder.addBodyFileUpload("quux", file2, "text/plain", false); + encoder.addBodyFileUpload("quux", file3, "text/plain", false); + + // We have to query the value of these two fields before 
finalizing + // the request, which unsets one of them. + String multipartDataBoundary = encoder.multipartDataBoundary; + String multipartMixedBoundary = encoder.multipartMixedBoundary; + String content = getRequestBody(encoder); + + String expected = "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"foo\"" + "\r\n" + + CONTENT_LENGTH + ": 3" + "\r\n" + + CONTENT_TYPE + ": text/plain; charset=UTF-8" + "\r\n" + + "\r\n" + + "bar" + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"" + "\r\n" + + CONTENT_TYPE + ": multipart/mixed; boundary=" + multipartMixedBoundary + "\r\n" + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + CONTENT_DISPOSITION + ": attachment; filename=\"file-01.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file1.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + CONTENT_DISPOSITION + ": attachment; filename=\"file-02.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file2.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 02" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + CONTENT_DISPOSITION + ": attachment; filename=\"file-03.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file3.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 03" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartMixedBoundary + "--" + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + assertEquals(expected, content); + } + + @Test + public void testMultiFileUploadInMixedModeNoName() throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + + 
HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + File file2 = new File(getClass().getResource("/file-02.txt").toURI()); + encoder.addBodyAttribute("foo", "bar"); + encoder.addBodyFileUpload("quux", "", file1, "text/plain", false); + encoder.addBodyFileUpload("quux", "", file2, "text/plain", false); + + // We have to query the value of these two fields before finalizing + // the request, which unsets one of them. + String multipartDataBoundary = encoder.multipartDataBoundary; + String multipartMixedBoundary = encoder.multipartMixedBoundary; + String content = getRequestBody(encoder); + + String expected = "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"foo\"" + "\r\n" + + CONTENT_LENGTH + ": 3" + "\r\n" + + CONTENT_TYPE + ": text/plain; charset=UTF-8" + "\r\n" + + "\r\n" + + "bar" + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"" + "\r\n" + + CONTENT_TYPE + ": multipart/mixed; boundary=" + multipartMixedBoundary + "\r\n" + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + CONTENT_DISPOSITION + ": attachment\r\n" + + CONTENT_LENGTH + ": " + file1.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartMixedBoundary + "\r\n" + + CONTENT_DISPOSITION + ": attachment\r\n" + + CONTENT_LENGTH + ": " + file2.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 02" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartMixedBoundary + "--" + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + assertEquals(expected, content); + } + + @Test + public void testSingleFileUploadInHtml5Mode() throws Exception { + DefaultFullHttpRequest request = new 
DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + + DefaultHttpDataFactory factory = new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE); + + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(factory, + request, true, CharsetUtil.UTF_8, EncoderMode.HTML5); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + File file2 = new File(getClass().getResource("/file-02.txt").toURI()); + encoder.addBodyAttribute("foo", "bar"); + encoder.addBodyFileUpload("quux", file1, "text/plain", false); + encoder.addBodyFileUpload("quux", file2, "text/plain", false); + + String multipartDataBoundary = encoder.multipartDataBoundary; + String content = getRequestBody(encoder); + + String expected = "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"foo\"" + "\r\n" + + CONTENT_LENGTH + ": 3" + "\r\n" + + CONTENT_TYPE + ": text/plain; charset=UTF-8" + "\r\n" + + "\r\n" + + "bar" + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"; filename=\"file-01.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file1.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.NEWLINE + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"; filename=\"file-02.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file2.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 02" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + assertEquals(expected, content); + } + + @Test + public void testMultiFileUploadInHtml5Mode() throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + + DefaultHttpDataFactory factory = new 
DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE); + + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(factory, + request, true, CharsetUtil.UTF_8, EncoderMode.HTML5); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + encoder.addBodyAttribute("foo", "bar"); + encoder.addBodyFileUpload("quux", file1, "text/plain", false); + + String multipartDataBoundary = encoder.multipartDataBoundary; + String content = getRequestBody(encoder); + + String expected = "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"foo\"" + "\r\n" + + CONTENT_LENGTH + ": 3" + "\r\n" + + CONTENT_TYPE + ": text/plain; charset=UTF-8" + "\r\n" + + "\r\n" + + "bar" + + "\r\n" + + "--" + multipartDataBoundary + "\r\n" + + CONTENT_DISPOSITION + ": form-data; name=\"quux\"; filename=\"file-01.txt\"" + "\r\n" + + CONTENT_LENGTH + ": " + file1.length() + "\r\n" + + CONTENT_TYPE + ": text/plain" + "\r\n" + + CONTENT_TRANSFER_ENCODING + ": binary" + "\r\n" + + "\r\n" + + "File 01" + StringUtil.NEWLINE + + "\r\n" + + "--" + multipartDataBoundary + "--" + "\r\n"; + + assertEquals(expected, content); + } + + @Test + public void testHttpPostRequestEncoderSlicedBuffer() throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + // add Form attribute + encoder.addBodyAttribute("getform", "POST"); + encoder.addBodyAttribute("info", "first value"); + encoder.addBodyAttribute("secondinfo", "secondvalue a&"); + encoder.addBodyAttribute("thirdinfo", "short text"); + int length = 100000; + char[] array = new char[length]; + Arrays.fill(array, 'a'); + String longText = new String(array); + encoder.addBodyAttribute("fourthinfo", longText.substring(0, 7470)); + File file1 = new File(getClass().getResource("/file-01.txt").toURI()); + encoder.addBodyFileUpload("myfile", file1, 
"application/x-zip-compressed", false); + encoder.finalizeRequest(); + while (! encoder.isEndOfInput()) { + HttpContent httpContent = encoder.readChunk((ByteBufAllocator) null); + ByteBuf content = httpContent.content(); + int refCnt = content.refCnt(); + assertTrue((content.unwrap() == content || content.unwrap() == null) && refCnt == 1 || + content.unwrap() != content && refCnt == 2, + "content: " + content + " content.unwrap(): " + content.unwrap() + " refCnt: " + refCnt); + httpContent.release(); + } + encoder.cleanFiles(); + encoder.close(); + } + + private static String getRequestBody(HttpPostRequestEncoder encoder) throws Exception { + encoder.finalizeRequest(); + + List chunks = encoder.multipartHttpDatas; + ByteBuf[] buffers = new ByteBuf[chunks.size()]; + + for (int i = 0; i < buffers.length; i++) { + InterfaceHttpData data = chunks.get(i); + if (data instanceof InternalAttribute) { + buffers[i] = ((InternalAttribute) data).toByteBuf(); + } else if (data instanceof HttpData) { + buffers[i] = ((HttpData) data).getByteBuf(); + } + } + + ByteBuf content = Unpooled.wrappedBuffer(buffers); + String contentStr = content.toString(CharsetUtil.UTF_8); + content.release(); + return contentStr; + } + + @Test + public void testDataIsMultipleOfChunkSize1() throws Exception { + DefaultHttpDataFactory factory = new DefaultHttpDataFactory(DefaultHttpDataFactory.MINSIZE); + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(factory, request, true, + HttpConstants.DEFAULT_CHARSET, HttpPostRequestEncoder.EncoderMode.RFC1738); + + MemoryFileUpload first = new MemoryFileUpload("resources", "", "application/json", null, + CharsetUtil.UTF_8, -1); + first.setMaxSize(-1); + first.setContent(new ByteArrayInputStream(new byte[7955])); + encoder.addBodyHttpData(first); + + MemoryFileUpload second = new MemoryFileUpload("resources2", "", 
"application/json", null, + CharsetUtil.UTF_8, -1); + second.setMaxSize(-1); + second.setContent(new ByteArrayInputStream(new byte[7928])); + encoder.addBodyHttpData(second); + + assertNotNull(encoder.finalizeRequest()); + + checkNextChunkSize(encoder, 8080); + checkNextChunkSize(encoder, 8080); + + HttpContent httpContent = encoder.readChunk((ByteBufAllocator) null); + assertTrue(httpContent instanceof LastHttpContent, "Expected LastHttpContent is not received"); + httpContent.release(); + + assertTrue(encoder.isEndOfInput(), "Expected end of input is not receive"); + } + + @Test + public void testDataIsMultipleOfChunkSize2() throws Exception { + DefaultFullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, + HttpMethod.POST, "http://localhost"); + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(request, true); + int length = 7943; + char[] array = new char[length]; + Arrays.fill(array, 'a'); + String longText = new String(array); + encoder.addBodyAttribute("foo", longText); + + assertNotNull(encoder.finalizeRequest()); + + checkNextChunkSize(encoder, 8080); + + HttpContent httpContent = encoder.readChunk((ByteBufAllocator) null); + assertTrue(httpContent instanceof LastHttpContent, "Expected LastHttpContent is not received"); + httpContent.release(); + + assertTrue(encoder.isEndOfInput(), "Expected end of input is not receive"); + } + + private static void checkNextChunkSize(HttpPostRequestEncoder encoder, int sizeWithoutDelimiter) throws Exception { + // 16 bytes as HttpPostRequestEncoder uses Long.toHexString(...) to generate a hex-string which will be between + // 2 and 16 bytes. 
+ // See https://github.com/netty/netty/blob/4.1/codec-http/src/main/java/io/netty/handler/ + // codec/http/multipart/HttpPostRequestEncoder.java#L291 + int expectedSizeMin = sizeWithoutDelimiter + 2; + int expectedSizeMax = sizeWithoutDelimiter + 16; + + HttpContent httpContent = encoder.readChunk((ByteBufAllocator) null); + + int readable = httpContent.content().readableBytes(); + boolean expectedSize = readable >= expectedSizeMin && readable <= expectedSizeMax; + assertTrue(expectedSize, "Chunk size is not in expected range (" + expectedSizeMin + " - " + + expectedSizeMax + "), was: " + readable); + httpContent.release(); + } + + @Test + public void testEncodeChunkedContent() throws Exception { + HttpRequest req = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/"); + HttpPostRequestEncoder encoder = new HttpPostRequestEncoder(req, false); + + int length = 8077 + 8096; + char[] array = new char[length]; + Arrays.fill(array, 'a'); + String longText = new String(array); + + encoder.addBodyAttribute("data", longText); + encoder.addBodyAttribute("moreData", "abcd"); + + assertNotNull(encoder.finalizeRequest()); + + while (!encoder.isEndOfInput()) { + encoder.readChunk((ByteBufAllocator) null).release(); + } + + assertTrue(encoder.isEndOfInput()); + encoder.cleanFiles(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoderTest.java new file mode 100644 index 0000000..454d125 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/HttpPostStandardRequestDecoderTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class HttpPostStandardRequestDecoderTest { + + @Test + void testDecodeAttributes() { + String requestBody = "key1=value1&key2=value2"; + + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + + HttpPostStandardRequestDecoder decoder = new HttpPostStandardRequestDecoder(httpDiskDataFactory(), request); + ByteBuf buf = Unpooled.wrappedBuffer(requestBody.getBytes(CharsetUtil.UTF_8)); + DefaultHttpContent httpContent = new DefaultLastHttpContent(buf); + decoder.offer(httpContent); + + assertEquals(2, decoder.getBodyHttpDatas().size()); + assertMemoryAttribute(decoder.getBodyHttpData("key1"), "value1"); + assertMemoryAttribute(decoder.getBodyHttpData("key2"), "value2"); + decoder.destroy(); + } + + @Test + void testDecodeAttributesWithAmpersandPrefixSkipsNullAttribute() { + String requestBody = "&key1=value1"; + + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + + HttpPostStandardRequestDecoder decoder = 
new HttpPostStandardRequestDecoder(httpDiskDataFactory(), request); + ByteBuf buf = Unpooled.wrappedBuffer(requestBody.getBytes(CharsetUtil.UTF_8)); + DefaultHttpContent httpContent = new DefaultLastHttpContent(buf); + decoder.offer(httpContent); + + assertEquals(1, decoder.getBodyHttpDatas().size()); + assertMemoryAttribute(decoder.getBodyHttpData("key1"), "value1"); + decoder.destroy(); + } + + @Test + void testDecodeZeroAttributesWithAmpersandPrefix() { + String requestBody = "&"; + + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.POST, "/upload"); + + HttpPostStandardRequestDecoder decoder = new HttpPostStandardRequestDecoder(httpDiskDataFactory(), request); + ByteBuf buf = Unpooled.wrappedBuffer(requestBody.getBytes(CharsetUtil.UTF_8)); + DefaultHttpContent httpContent = new DefaultLastHttpContent(buf); + decoder.offer(httpContent); + + assertEquals(0, decoder.getBodyHttpDatas().size()); + decoder.destroy(); + } + + private static DefaultHttpDataFactory httpDiskDataFactory() { + return new DefaultHttpDataFactory(false); + } + + private static void assertMemoryAttribute(InterfaceHttpData data, String expectedValue) { + assertEquals(InterfaceHttpData.HttpDataType.Attribute, data.getHttpDataType()); + assertEquals(((MemoryAttribute) data).getValue(), expectedValue); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MemoryFileUploadTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MemoryFileUploadTest.java new file mode 100644 index 0000000..167c8c3 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MemoryFileUploadTest.java @@ -0,0 +1,30 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.multipart; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class MemoryFileUploadTest { + + @Test + public final void testMemoryFileUploadEquals() { + MemoryFileUpload f1 = + new MemoryFileUpload("m1", "m1", "application/json", null, null, 100); + assertEquals(f1, f1); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MixedTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MixedTest.java new file mode 100644 index 0000000..2e6e6a8 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/multipart/MixedTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.multipart; + +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.io.File; +import java.io.IOException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class MixedTest { + @Test + public void mixedAttributeRefCnt() throws IOException { + MixedAttribute attribute = new MixedAttribute("foo", 100); + Assertions.assertEquals(1, attribute.refCnt()); + attribute.retain(); + Assertions.assertEquals(2, attribute.refCnt()); + + attribute.addContent(Unpooled.wrappedBuffer(new byte[90]), false); + Assertions.assertEquals(2, attribute.refCnt()); + + attribute.addContent(Unpooled.wrappedBuffer(new byte[90]), true); + Assertions.assertEquals(2, attribute.refCnt()); + + attribute.release(2); + } + + @Test + public void mixedFileUploadRefCnt() throws IOException { + MixedFileUpload upload = new MixedFileUpload("foo", "foo", "foo", "UTF-8", CharsetUtil.UTF_8, 0, 100); + Assertions.assertEquals(1, upload.refCnt()); + upload.retain(); + Assertions.assertEquals(2, upload.refCnt()); + + upload.addContent(Unpooled.wrappedBuffer(new byte[90]), false); + Assertions.assertEquals(2, upload.refCnt()); + + upload.addContent(Unpooled.wrappedBuffer(new byte[90]), true); + Assertions.assertEquals(2, upload.refCnt()); + + upload.release(2); + } + + @Test + public void testSpecificCustomBaseDir() throws IOException { + File baseDir = new File("target/MixedTest/testSpecificCustomBaseDir"); + baseDir.mkdirs(); // we don't need to clean it since it is in volatile files anyway + MixedFileUpload upload = new MixedFileUpload("foo", "foo", "foo", "UTF-8", CharsetUtil.UTF_8, 1000, 100, + baseDir.getAbsolutePath(), true); + + upload.addContent(Unpooled.wrappedBuffer(new byte[1000]), true); + + 
assertTrue(upload.getFile().getAbsolutePath().startsWith(baseDir.getAbsolutePath())); + assertTrue(upload.getFile().exists()); + assertEquals(1000, upload.getFile().length()); + upload.delete(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrameTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrameTest.java new file mode 100644 index 0000000..dcb6fbb --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/CloseWebSocketFrameTest.java @@ -0,0 +1,104 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is + * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import org.assertj.core.api.ThrowableAssert; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; + +class CloseWebSocketFrameTest { + + @Test + void testInvalidCode() { + doTestInvalidCode(new ThrowableAssert.ThrowingCallable() { + + @Override + public void call() throws RuntimeException { + new CloseWebSocketFrame(WebSocketCloseStatus.ABNORMAL_CLOSURE); + } + }); + + doTestInvalidCode(new ThrowableAssert.ThrowingCallable() { + + @Override + public void call() throws RuntimeException { + new CloseWebSocketFrame(WebSocketCloseStatus.ABNORMAL_CLOSURE, "invalid code"); + } + }); + + doTestInvalidCode(new ThrowableAssert.ThrowingCallable() { + + @Override + public void call() throws RuntimeException { + new CloseWebSocketFrame(1006, "invalid code"); + } + }); + + doTestInvalidCode(new ThrowableAssert.ThrowingCallable() { + + @Override + public void call() throws RuntimeException { + new CloseWebSocketFrame(true, 0, 1006, "invalid code"); + } + }); + } + + @Test + void testValidCode() { + doTestValidCode(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE), + WebSocketCloseStatus.NORMAL_CLOSURE.code(), WebSocketCloseStatus.NORMAL_CLOSURE.reasonText()); + + doTestValidCode(new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE, "valid code"), + WebSocketCloseStatus.NORMAL_CLOSURE.code(), "valid code"); + + doTestValidCode(new CloseWebSocketFrame(1000, "valid code"), 1000, "valid code"); + + doTestValidCode(new CloseWebSocketFrame(true, 0, 1000, "valid code"), 1000, "valid code"); + } + + @Test + void testNonZeroReaderIndex() { + ByteBuf buffer = Unpooled.buffer().writeZero(1); + buffer.writeShort(WebSocketCloseStatus.NORMAL_CLOSURE.code()) + 
.writeCharSequence(WebSocketCloseStatus.NORMAL_CLOSURE.reasonText(), CharsetUtil.US_ASCII); + doTestValidCode(new CloseWebSocketFrame(true, 0, buffer.skipBytes(1)), + WebSocketCloseStatus.NORMAL_CLOSURE.code(), WebSocketCloseStatus.NORMAL_CLOSURE.reasonText()); + } + + @Test + void testCustomCloseCode() { + ByteBuf buffer = Unpooled.buffer().writeZero(1); + buffer.writeShort(60000) + .writeCharSequence("Custom close code", CharsetUtil.US_ASCII); + doTestValidCode(new CloseWebSocketFrame(true, 0, buffer.skipBytes(1)), + 60000, "Custom close code"); + } + + private static void doTestInvalidCode(ThrowableAssert.ThrowingCallable callable) { + assertThatExceptionOfType(IllegalArgumentException.class).isThrownBy(callable); + } + + private static void doTestValidCode(CloseWebSocketFrame frame, int expectedCode, String expectedReason) { + try { + assertThat(frame.statusCode()).isEqualTo(expectedCode); + assertThat(frame.reasonText()).isEqualTo(expectedReason); + } finally { + frame.release(); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoderTest.java new file mode 100644 index 0000000..c428e31 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket00FrameEncoderTest.java @@ -0,0 +1,47 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class WebSocket00FrameEncoderTest { + + // Test for https://github.com/netty/netty/issues/2768 + @Test + public void testMultipleWebSocketCloseFrames() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocket00FrameEncoder()); + assertTrue(channel.writeOutbound(new CloseWebSocketFrame())); + assertTrue(channel.writeOutbound(new CloseWebSocketFrame())); + assertTrue(channel.finish()); + assertCloseWebSocketFrame(channel); + assertCloseWebSocketFrame(channel); + assertNull(channel.readOutbound()); + } + + private static void assertCloseWebSocketFrame(EmbeddedChannel channel) { + ByteBuf buf = channel.readOutbound(); + assertEquals(2, buf.readableBytes()); + assertEquals((byte) 0xFF, buf.readByte()); + assertEquals((byte) 0x00, buf.readByte()); + buf.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08EncoderDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08EncoderDecoderTest.java new file mode 100644 index 0000000..2edc7bd --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08EncoderDecoderTest.java @@ -0,0 +1,222 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests the WebSocket08FrameEncoder and Decoder implementation.
+ * Checks whether the combination of encoding and decoding yields the original data.
+ * Thereby also the masking behavior is checked. + */ +public class WebSocket08EncoderDecoderTest { + + private ByteBuf binTestData; + private String strTestData; + + private static final int MAX_TESTDATA_LENGTH = 100 * 1024; + + private void initTestData() { + binTestData = Unpooled.buffer(MAX_TESTDATA_LENGTH); + byte j = 0; + for (int i = 0; i < MAX_TESTDATA_LENGTH; i++) { + binTestData.array()[i] = j; + j++; + } + + StringBuilder s = new StringBuilder(); + char c = 'A'; + for (int i = 0; i < MAX_TESTDATA_LENGTH; i++) { + s.append(c); + c++; + if (c == 'Z') { + c = 'A'; + } + } + strTestData = s.toString(); + } + + @Test + public void testWebSocketProtocolViolation() { + // Given + initTestData(); + + int maxPayloadLength = 255; + String errorMessage = "Max frame length of " + maxPayloadLength + " has been exceeded."; + WebSocketCloseStatus expectedStatus = WebSocketCloseStatus.MESSAGE_TOO_BIG; + + // With auto-close + WebSocketDecoderConfig config = WebSocketDecoderConfig.newBuilder() + .maxFramePayloadLength(maxPayloadLength) + .closeOnProtocolViolation(true) + .build(); + EmbeddedChannel inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(config)); + EmbeddedChannel outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + + executeProtocolViolationTest(outChannel, inChannel, maxPayloadLength + 1, expectedStatus, errorMessage); + + CloseWebSocketFrame response = inChannel.readOutbound(); + assertNotNull(response); + assertEquals(expectedStatus.code(), response.statusCode()); + assertEquals(errorMessage, response.reasonText()); + response.release(); + + assertFalse(inChannel.finish()); + assertFalse(outChannel.finish()); + + // Without auto-close + config = WebSocketDecoderConfig.newBuilder() + .maxFramePayloadLength(maxPayloadLength) + .closeOnProtocolViolation(false) + .build(); + inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(config)); + outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + + 
executeProtocolViolationTest(outChannel, inChannel, maxPayloadLength + 1, expectedStatus, errorMessage); + + response = inChannel.readOutbound(); + assertNull(response); + + assertFalse(inChannel.finish()); + assertFalse(outChannel.finish()); + + // Release test data + binTestData.release(); + } + + private void executeProtocolViolationTest(EmbeddedChannel outChannel, EmbeddedChannel inChannel, + int testDataLength, WebSocketCloseStatus expectedStatus, String errorMessage) { + CorruptedWebSocketFrameException corrupted = null; + + try { + testBinaryWithLen(outChannel, inChannel, testDataLength); + } catch (CorruptedWebSocketFrameException e) { + corrupted = e; + } + + BinaryWebSocketFrame exceedingFrame = inChannel.readInbound(); + assertNull(exceedingFrame); + + assertNotNull(corrupted); + assertEquals(expectedStatus, corrupted.closeStatus()); + assertEquals(errorMessage, corrupted.getMessage()); + } + + @Test + public void testWebSocketEncodingAndDecoding() { + initTestData(); + + // Test without masking + EmbeddedChannel outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(false)); + EmbeddedChannel inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(false, false, 1024 * 1024, false)); + executeTests(outChannel, inChannel); + + // Test with activated masking + outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(true, false, 1024 * 1024, false)); + executeTests(outChannel, inChannel); + + // Test with activated masking and an unmasked expecting but forgiving decoder + outChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + inChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(false, false, 1024 * 1024, true)); + executeTests(outChannel, inChannel); + + // Release test data + binTestData.release(); + } + + private void executeTests(EmbeddedChannel outChannel, EmbeddedChannel inChannel) { + // Test at the boundaries of each message type, because this 
shifts the position of the mask field + // Test min. 4 lengths to check for problems related to an uneven frame length + executeTests(outChannel, inChannel, 0); + executeTests(outChannel, inChannel, 1); + executeTests(outChannel, inChannel, 2); + executeTests(outChannel, inChannel, 3); + executeTests(outChannel, inChannel, 4); + executeTests(outChannel, inChannel, 5); + + executeTests(outChannel, inChannel, 125); + executeTests(outChannel, inChannel, 126); + executeTests(outChannel, inChannel, 127); + executeTests(outChannel, inChannel, 128); + executeTests(outChannel, inChannel, 129); + + executeTests(outChannel, inChannel, 65535); + executeTests(outChannel, inChannel, 65536); + executeTests(outChannel, inChannel, 65537); + executeTests(outChannel, inChannel, 65538); + executeTests(outChannel, inChannel, 65539); + } + + private void executeTests(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) { + testTextWithLen(outChannel, inChannel, testDataLength); + testBinaryWithLen(outChannel, inChannel, testDataLength); + } + + private void testTextWithLen(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) { + String testStr = strTestData.substring(0, testDataLength); + outChannel.writeOutbound(new TextWebSocketFrame(testStr)); + + transfer(outChannel, inChannel); + + Object decoded = inChannel.readInbound(); + assertNotNull(decoded); + assertTrue(decoded instanceof TextWebSocketFrame); + TextWebSocketFrame txt = (TextWebSocketFrame) decoded; + assertEquals(txt.text(), testStr); + txt.release(); + } + + private void testBinaryWithLen(EmbeddedChannel outChannel, EmbeddedChannel inChannel, int testDataLength) { + binTestData.retain(); // need to retain for sending and still keeping it + binTestData.setIndex(0, testDataLength); // Send only len bytes + outChannel.writeOutbound(new BinaryWebSocketFrame(binTestData)); + + transfer(outChannel, inChannel); + + Object decoded = inChannel.readInbound(); + assertNotNull(decoded); + 
assertTrue(decoded instanceof BinaryWebSocketFrame); + BinaryWebSocketFrame binFrame = (BinaryWebSocketFrame) decoded; + int readable = binFrame.content().readableBytes(); + assertEquals(readable, testDataLength); + for (int i = 0; i < testDataLength; i++) { + assertEquals(binTestData.getByte(i), binFrame.content().getByte(i)); + } + binFrame.release(); + } + + private void transfer(EmbeddedChannel outChannel, EmbeddedChannel inChannel) { + // Transfer encoded data into decoder + // Loop because there might be multiple frames (gathering write) + for (;;) { + ByteBuf encoded = outChannel.readOutbound(); + if (encoded == null) { + return; + } + inChannel.writeInbound(encoded); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java new file mode 100644 index 0000000..fd2a35d --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocket08FrameDecoderTest.java @@ -0,0 +1,97 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is + * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.HashSet; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +public class WebSocket08FrameDecoderTest { + + @Test + public void channelInactive() throws Exception { + final WebSocket08FrameDecoder decoder = new WebSocket08FrameDecoder(true, true, 65535, false); + final ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + decoder.channelInactive(ctx); + verify(ctx).fireChannelInactive(); + } + + @Test + public void supportIanaStatusCodes() throws Exception { + Set forbiddenIanaCodes = new HashSet(); + forbiddenIanaCodes.add(1004); + forbiddenIanaCodes.add(1005); + forbiddenIanaCodes.add(1006); + Set validIanaCodes = new HashSet(); + for (int i = 1000; i < 1015; i++) { + validIanaCodes.add(i); + } + validIanaCodes.removeAll(forbiddenIanaCodes); + + for (int statusCode: validIanaCodes) { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new WebSocket08FrameEncoder(true)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new WebSocket08FrameDecoder(true, true, 65535, false)); + + assertTrue(encoderChannel.writeOutbound(new CloseWebSocketFrame(statusCode, "Bye"))); + assertTrue(encoderChannel.finish()); + ByteBuf serializedCloseFrame = encoderChannel.readOutbound(); + assertNull(encoderChannel.readOutbound()); + + 
assertTrue(decoderChannel.writeInbound(serializedCloseFrame)); + assertTrue(decoderChannel.finish()); + + CloseWebSocketFrame outputFrame = decoderChannel.readInbound(); + assertNull(decoderChannel.readOutbound()); + try { + assertEquals(statusCode, outputFrame.statusCode()); + } finally { + outputFrame.release(); + } + } + } + + @Test + void protocolViolationWhenNegativeFrameLength() { + WebSocket08FrameDecoder decoder = new WebSocket08FrameDecoder(true, true, 65535, false); + final EmbeddedChannel channel = new EmbeddedChannel(decoder); + final ByteBuf invalidFrame = Unpooled.buffer(10).writeByte(0x81) + .writeByte(0xFF).writeLong(-1L); + + Throwable exception = assertThrows(CorruptedWebSocketFrameException.class, new Executable() { + @Override + public void execute() { + channel.writeInbound(invalidFrame); + } + }); + assertEquals("invalid data frame length (negative length)", exception.getMessage()); + + CloseWebSocketFrame closeFrame = channel.readOutbound(); + assertEquals("invalid data frame length (negative length)", closeFrame.reasonText()); + assertTrue(closeFrame.release()); + assertFalse(channel.isActive()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java new file mode 100644 index 0000000..d84c903 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker00Test.java @@ -0,0 +1,51 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; + +import java.net.URI; + +public class WebSocketClientHandshaker00Test extends WebSocketClientHandshakerTest { + @Override + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + return new WebSocketClientHandshaker00(uri, WebSocketVersion.V00, subprotocol, headers, + 1024, 10000, absoluteUpgradeUrl, generateOriginHeader); + } + + @Override + protected CharSequence getOriginHeaderName() { + return HttpHeaderNames.ORIGIN; + } + + @Override + protected CharSequence getProtocolHeaderName() { + return HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL; + } + + @Override + protected CharSequence[] getHandshakeRequiredHeaderNames() { + return new CharSequence[] { + HttpHeaderNames.CONNECTION, + HttpHeaderNames.UPGRADE, + HttpHeaderNames.HOST, + HttpHeaderNames.SEC_WEBSOCKET_KEY1, + HttpHeaderNames.SEC_WEBSOCKET_KEY2, + }; + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java new file mode 100644 index 0000000..e3f04da --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker07Test.java @@ -0,0 +1,73 @@ +/* + * Copyright 2015 The Netty Project + 
* + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import org.junit.jupiter.api.Test; + +import java.net.URI; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class WebSocketClientHandshaker07Test extends WebSocketClientHandshakerTest { + + @Test + public void testHostHeaderPreserved() { + URI uri = URI.create("ws://localhost:9999"); + WebSocketClientHandshaker handshaker = newHandshaker(uri, null, + new DefaultHttpHeaders().set(HttpHeaderNames.HOST, "test.netty.io"), false, true); + + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("/", request.uri()); + assertEquals("test.netty.io", request.headers().get(HttpHeaderNames.HOST)); + } finally { + request.release(); + } + } + + @Override + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + return new WebSocketClientHandshaker07(uri, WebSocketVersion.V07, subprotocol, false, headers, + 1024, true, false, 10000, + absoluteUpgradeUrl, generateOriginHeader); + } + + @Override + protected CharSequence getOriginHeaderName() { + return 
HttpHeaderNames.SEC_WEBSOCKET_ORIGIN; + } + + @Override + protected CharSequence getProtocolHeaderName() { + return HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL; + } + + @Override + protected CharSequence[] getHandshakeRequiredHeaderNames() { + return new CharSequence[] { + HttpHeaderNames.UPGRADE, + HttpHeaderNames.CONNECTION, + HttpHeaderNames.SEC_WEBSOCKET_KEY, + HttpHeaderNames.HOST, + HttpHeaderNames.SEC_WEBSOCKET_VERSION, + }; + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java new file mode 100644 index 0000000..aa8e11d --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker08Test.java @@ -0,0 +1,30 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.HttpHeaders; + +import java.net.URI; + +public class WebSocketClientHandshaker08Test extends WebSocketClientHandshaker07Test { + @Override + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + return new WebSocketClientHandshaker08(uri, WebSocketVersion.V08, subprotocol, false, headers, + 1024, true, true, 10000, + absoluteUpgradeUrl, generateOriginHeader); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java new file mode 100644 index 0000000..cd82f21 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshaker13Test.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; + +import java.net.URI; + +public class WebSocketClientHandshaker13Test extends WebSocketClientHandshaker07Test { + + @Override + protected WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, + boolean absoluteUpgradeUrl, boolean generateOriginHeader) { + return new WebSocketClientHandshaker13(uri, WebSocketVersion.V13, subprotocol, false, headers, + 1024, true, true, 10000, + absoluteUpgradeUrl, generateOriginHeader); + } + + @Override + protected CharSequence getOriginHeaderName() { + return HttpHeaderNames.ORIGIN; + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java new file mode 100644 index 0000000..411e8ae --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketClientHandshakerTest.java @@ -0,0 +1,517 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestEncoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.net.URI; +import java.util.concurrent.TimeUnit; + +import static io.netty.handler.codec.http.HttpResponseStatus.SWITCHING_PROTOCOLS; +import static io.netty.handler.codec.http.websocketx.WebSocketServerHandshaker13.WEBSOCKET_13_ACCEPT_GUID; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static 
org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public abstract class WebSocketClientHandshakerTest { + protected abstract WebSocketClientHandshaker newHandshaker(URI uri, String subprotocol, HttpHeaders headers, + boolean absoluteUpgradeUrl, + boolean generateOriginHeader); + + protected WebSocketClientHandshaker newHandshaker(URI uri) { + return newHandshaker(uri, null, null, false, true); + } + + protected abstract CharSequence getOriginHeaderName(); + + protected abstract CharSequence getProtocolHeaderName(); + + protected abstract CharSequence[] getHandshakeRequiredHeaderNames(); + + @Test + public void hostHeaderWs() { + for (String scheme : new String[]{"ws://", "http://"}) { + for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) { + String enter = scheme + host; + + testHostHeader(enter, host); + testHostHeader(enter + '/', host); + testHostHeader(enter + ":80", host); + testHostHeader(enter + ":443", host + ":443"); + testHostHeader(enter + ":9999", host + ":9999"); + testHostHeader(enter + "/path", host); + testHostHeader(enter + ":80/path", host); + testHostHeader(enter + ":443/path", host + ":443"); + testHostHeader(enter + ":9999/path", host + ":9999"); + } + } + } + + @Test + public void hostHeaderWss() { + for (String scheme : new String[]{"wss://", "https://"}) { + for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "Netty.io"}) { + String enter = scheme + host; + + testHostHeader(enter, host); + testHostHeader(enter + '/', host); + testHostHeader(enter + ":80", host + ":80"); + testHostHeader(enter + ":443", host); + testHostHeader(enter + ":9999", host + ":9999"); + testHostHeader(enter + "/path", host); + testHostHeader(enter + ":80/path", host + ":80"); + testHostHeader(enter + ":443/path", host); + testHostHeader(enter + 
":9999/path", host + ":9999"); + } + } + } + + @Test + public void hostHeaderWithoutScheme() { + testHostHeader("//localhost/", "localhost"); + testHostHeader("//localhost/path", "localhost"); + testHostHeader("//localhost:80/", "localhost:80"); + testHostHeader("//localhost:443/", "localhost:443"); + testHostHeader("//localhost:9999/", "localhost:9999"); + } + + @Test + public void originHeaderWs() { + for (String scheme : new String[]{"ws://", "http://"}) { + for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) { + String enter = scheme + host; + String expect = "http://" + host.toLowerCase(); + + testOriginHeader(enter, expect); + testOriginHeader(enter + '/', expect); + testOriginHeader(enter + ":80", expect); + testOriginHeader(enter + ":443", expect + ":443"); + testOriginHeader(enter + ":9999", expect + ":9999"); + testOriginHeader(enter + "/path%20with%20ws", expect); + testOriginHeader(enter + ":80/path%20with%20ws", expect); + testOriginHeader(enter + ":443/path%20with%20ws", expect + ":443"); + testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999"); + } + } + } + + @Test + public void originHeaderWss() { + for (String scheme : new String[]{"wss://", "https://"}) { + for (String host : new String[]{"localhost", "127.0.0.1", "[::1]", "NETTY.IO"}) { + String enter = scheme + host; + String expect = "https://" + host.toLowerCase(); + + testOriginHeader(enter, expect); + testOriginHeader(enter + '/', expect); + testOriginHeader(enter + ":80", expect + ":80"); + testOriginHeader(enter + ":443", expect); + testOriginHeader(enter + ":9999", expect + ":9999"); + testOriginHeader(enter + "/path%20with%20ws", expect); + testOriginHeader(enter + ":80/path%20with%20ws", expect + ":80"); + testOriginHeader(enter + ":443/path%20with%20ws", expect); + testOriginHeader(enter + ":9999/path%20with%20ws", expect + ":9999"); + } + } + } + + @Test + public void originHeaderWithoutScheme() { + testOriginHeader("//localhost/", 
"http://localhost"); + testOriginHeader("//localhost/path", "http://localhost"); + + // http scheme by port + testOriginHeader("//localhost:80/", "http://localhost"); + testOriginHeader("//localhost:80/path", "http://localhost"); + + // https scheme by port + testOriginHeader("//localhost:443/", "https://localhost"); + testOriginHeader("//localhost:443/path", "https://localhost"); + + // http scheme for non standard port + testOriginHeader("//localhost:9999/", "http://localhost:9999"); + testOriginHeader("//localhost:9999/path", "http://localhost:9999"); + + // convert host to lower case + testOriginHeader("//LOCALHOST/", "http://localhost"); + } + + @Test + public void testSetOriginFromCustomHeaders() { + HttpHeaders customHeaders = new DefaultHttpHeaders().set(getOriginHeaderName(), "http://example.com"); + WebSocketClientHandshaker handshaker = newHandshaker(URI.create("ws://server.example.com/chat"), null, + customHeaders, false, true); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("http://example.com", request.headers().get(getOriginHeaderName())); + } finally { + request.release(); + } + } + + @Test + public void testOriginHeaderIsAbsentWhenGeneratingDisable() { + URI uri = URI.create("http://example.com/ws"); + WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false, false); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertFalse(request.headers().contains(getOriginHeaderName())); + assertEquals("/ws", request.uri()); + } finally { + request.release(); + } + } + + @Test + public void testInvalidHostWhenIncorrectWebSocketURI() { + URI uri = URI.create("/ws"); + EmbeddedChannel channel = new EmbeddedChannel(new HttpClientCodec()); + final WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false, true); + final ChannelFuture handshakeFuture = handshaker.handshake(channel); + + assertFalse(handshakeFuture.isSuccess()); + 
assertInstanceOf(IllegalArgumentException.class, handshakeFuture.cause()); + assertEquals("Cannot generate the 'host' header value, webSocketURI should contain host" + + " or passed through customHeaders", handshakeFuture.cause().getMessage()); + assertFalse(channel.finish()); + } + + @Test + public void testInvalidOriginWhenIncorrectWebSocketURI() { + URI uri = URI.create("/ws"); + EmbeddedChannel channel = new EmbeddedChannel(new HttpClientCodec()); + HttpHeaders headers = new DefaultHttpHeaders(); + headers.set(HttpHeaderNames.HOST, "localhost:80"); + final WebSocketClientHandshaker handshaker = newHandshaker(uri, null, headers, false, true); + final ChannelFuture handshakeFuture = handshaker.handshake(channel); + + assertFalse(handshakeFuture.isSuccess()); + assertInstanceOf(IllegalArgumentException.class, handshakeFuture.cause()); + assertEquals("Cannot generate the '" + getOriginHeaderName() + "' header value," + + " webSocketURI should contain host or disable generateOriginHeader" + + " or pass value through customHeaders", handshakeFuture.cause().getMessage()); + assertFalse(channel.finish()); + } + + private void testHostHeader(String uri, String expected) { + testHeaderDefaultHttp(uri, HttpHeaderNames.HOST, expected); + } + + private void testOriginHeader(String uri, String expected) { + testHeaderDefaultHttp(uri, getOriginHeaderName(), expected); + } + + protected void testHeaderDefaultHttp(String uri, CharSequence header, String expectedValue) { + WebSocketClientHandshaker handshaker = newHandshaker(URI.create(uri)); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals(expectedValue, request.headers().get(header)); + } finally { + request.release(); + } + } + + @Test + @SuppressWarnings("deprecation") + public void testUpgradeUrl() { + URI uri = URI.create("ws://localhost:9999/path%20with%20ws"); + WebSocketClientHandshaker handshaker = newHandshaker(uri); + FullHttpRequest request = handshaker.newHandshakeRequest(); + 
try { + assertEquals("/path%20with%20ws", request.getUri()); + } finally { + request.release(); + } + } + + @Test + public void testUpgradeUrlWithQuery() { + URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c"); + WebSocketClientHandshaker handshaker = newHandshaker(uri); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("/path%20with%20ws?a=b%20c", request.uri()); + } finally { + request.release(); + } + } + + @Test + public void testUpgradeUrlWithoutPath() { + URI uri = URI.create("ws://localhost:9999"); + WebSocketClientHandshaker handshaker = newHandshaker(uri); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("/", request.uri()); + } finally { + request.release(); + } + } + + @Test + public void testUpgradeUrlWithoutPathWithQuery() { + URI uri = URI.create("ws://localhost:9999?a=b%20c"); + WebSocketClientHandshaker handshaker = newHandshaker(uri); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("/?a=b%20c", request.uri()); + } finally { + request.release(); + } + } + + @Test + public void testAbsoluteUpgradeUrlWithQuery() { + URI uri = URI.create("ws://localhost:9999/path%20with%20ws?a=b%20c"); + WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, true, true); + FullHttpRequest request = handshaker.newHandshakeRequest(); + try { + assertEquals("ws://localhost:9999/path%20with%20ws?a=b%20c", request.uri()); + } finally { + request.release(); + } + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHttpResponseAndFrameInSameBuffer() { + testHttpResponseAndFrameInSameBuffer(false); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testHttpResponseAndFrameInSameBufferCodec() { + testHttpResponseAndFrameInSameBuffer(true); + } + + private void testHttpResponseAndFrameInSameBuffer(boolean codec) { + String url = "ws://localhost:9999/ws"; + final 
WebSocketClientHandshaker shaker = newHandshaker(URI.create(url)); + final WebSocketClientHandshaker handshaker = new WebSocketClientHandshaker( + shaker.uri(), shaker.version(), null, EmptyHttpHeaders.INSTANCE, Integer.MAX_VALUE, -1) { + @Override + protected FullHttpRequest newHandshakeRequest() { + return shaker.newHandshakeRequest(); + } + + @Override + protected void verify(FullHttpResponse response) { + // Not do any verification, so we not need to care sending the correct headers etc in the test, + // which would just make things more complicated. + } + + @Override + protected WebSocketFrameDecoder newWebsocketDecoder() { + return shaker.newWebsocketDecoder(); + } + + @Override + protected WebSocketFrameEncoder newWebSocketEncoder() { + return shaker.newWebSocketEncoder(); + } + }; + + // use randomBytes helper from utils to check that it functions properly + byte[] data = WebSocketUtil.randomBytes(24); + + // Create a EmbeddedChannel which we will use to encode a BinaryWebsocketFrame to bytes and so use these + // to test the actual handshaker. 
+ WebSocketServerHandshakerFactory factory = new WebSocketServerHandshakerFactory(url, null, false); + FullHttpRequest request = shaker.newHandshakeRequest(); + WebSocketServerHandshaker socketServerHandshaker = factory.newHandshaker(request); + request.release(); + EmbeddedChannel websocketChannel = new EmbeddedChannel(socketServerHandshaker.newWebSocketEncoder(), + socketServerHandshaker.newWebsocketDecoder()); + assertTrue(websocketChannel.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(data)))); + + byte[] bytes = ("HTTP/1.1 101 Switching Protocols\r\nSec-Websocket-Accept: not-verify\r\n" + + "Upgrade: websocket\r\n\r\n").getBytes(CharsetUtil.US_ASCII); + + CompositeByteBuf compositeByteBuf = Unpooled.compositeBuffer(); + compositeByteBuf.addComponent(true, Unpooled.wrappedBuffer(bytes)); + for (;;) { + ByteBuf frameBytes = websocketChannel.readOutbound(); + if (frameBytes == null) { + break; + } + compositeByteBuf.addComponent(true, frameBytes); + } + + EmbeddedChannel ch = new EmbeddedChannel(new HttpObjectAggregator(Integer.MAX_VALUE), + new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, FullHttpResponse msg) throws Exception { + handshaker.finishHandshake(ctx.channel(), msg); + ctx.pipeline().remove(this); + } + }); + if (codec) { + ch.pipeline().addFirst(new HttpClientCodec()); + } else { + ch.pipeline().addFirst(new HttpRequestEncoder(), new HttpResponseDecoder()); + } + // We need to first write the request as HttpClientCodec will fail if we receive a response before a request + // was written. + shaker.handshake(ch).syncUninterruptibly(); + for (;;) { + // Just consume the bytes, we are not interested in these. 
+ ByteBuf buf = ch.readOutbound(); + if (buf == null) { + break; + } + buf.release(); + } + assertTrue(ch.writeInbound(compositeByteBuf)); + assertTrue(ch.finish()); + + BinaryWebSocketFrame frame = ch.readInbound(); + ByteBuf expect = Unpooled.wrappedBuffer(data); + try { + assertEquals(expect, frame.content()); + assertTrue(frame.isFinalFragment()); + assertEquals(0, frame.rsv()); + } finally { + expect.release(); + frame.release(); + } + } + + @Test + public void testDuplicateWebsocketHandshakeHeaders() { + URI uri = URI.create("ws://localhost:9999/foo"); + + HttpHeaders inputHeaders = new DefaultHttpHeaders(); + String bogusSubProtocol = "bogusSubProtocol"; + String bogusHeaderValue = "bogusHeaderValue"; + + // add values for the headers that are reserved for use in the websockets handshake + for (CharSequence header : getHandshakeRequiredHeaderNames()) { + if (!HttpHeaderNames.HOST.equals(header)) { + inputHeaders.add(header, bogusHeaderValue); + } + } + inputHeaders.add(getProtocolHeaderName(), bogusSubProtocol); + + String realSubProtocol = "realSubProtocol"; + WebSocketClientHandshaker handshaker = newHandshaker(uri, realSubProtocol, inputHeaders, false, true); + FullHttpRequest request = handshaker.newHandshakeRequest(); + HttpHeaders outputHeaders = request.headers(); + + // the header values passed in originally have been replaced with values generated by the Handshaker + for (CharSequence header : getHandshakeRequiredHeaderNames()) { + assertEquals(1, outputHeaders.getAll(header).size()); + assertNotEquals(bogusHeaderValue, outputHeaders.get(header)); + } + + // the subprotocol header value is that of the subprotocol string passed into the Handshaker + assertEquals(1, outputHeaders.getAll(getProtocolHeaderName()).size()); + assertEquals(realSubProtocol, outputHeaders.get(getProtocolHeaderName())); + + request.release(); + } + + @Test + public void testWebSocketClientHandshakeException() { + URI uri = URI.create("ws://localhost:9999/exception"); + 
WebSocketClientHandshaker handshaker = newHandshaker(uri, null, null, false, true); + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.UNAUTHORIZED); + response.headers().set(HttpHeaderNames.WWW_AUTHENTICATE, "realm = access token required"); + + try { + handshaker.finishHandshake(null, response); + fail("Expected WebSocketClientHandshakeException"); + } catch (WebSocketClientHandshakeException exception) { + assertEquals("Invalid handshake response getStatus: 401 Unauthorized", exception.getMessage()); + assertEquals(HttpResponseStatus.UNAUTHORIZED, exception.response().status()); + assertTrue(exception.response().headers().contains(HttpHeaderNames.WWW_AUTHENTICATE, + "realm = access token required", false)); + } finally { + response.release(); + } + } + + @Test + public void testHandshakeForHttpResponseWithoutAggregator() { + EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestEncoder(), new HttpResponseDecoder()); + URI uri = URI.create("ws://localhost:9999/chat"); + WebSocketClientHandshaker clientHandshaker = newHandshaker(uri); + FullHttpRequest handshakeRequest = clientHandshaker.newHandshakeRequest(); + handshakeRequest.release(); + + String accept = ""; + if (clientHandshaker.version() != WebSocketVersion.V00) { + String acceptSeed = handshakeRequest.headers().get(HttpHeaderNames.SEC_WEBSOCKET_KEY) + + WEBSOCKET_13_ACCEPT_GUID; + byte[] sha1 = WebSocketUtil.sha1(acceptSeed.getBytes(CharsetUtil.US_ASCII)); + accept = WebSocketUtil.base64(sha1); + } + + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, SWITCHING_PROTOCOLS); + response.headers() + .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, accept); + + ChannelFuture handshakeFuture = clientHandshaker.processHandshake(channel, response); + assertFalse(handshakeFuture.isDone()); + 
assertNotNull(channel.pipeline().get("handshaker")); + + if (clientHandshaker.version() != WebSocketVersion.V00) { + assertNull(channel.pipeline().get("httpAggregator")); + channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT); + } else { + assertNotNull(channel.pipeline().get("httpAggregator")); + channel.writeInbound(new DefaultLastHttpContent( + Unpooled.copiedBuffer("8jKS'y:G*Co,Wxa-", CharsetUtil.US_ASCII))); + } + + assertTrue(handshakeFuture.isDone()); + assertNull(channel.pipeline().get("handshaker")); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketCloseStatusTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketCloseStatusTest.java new file mode 100644 index 0000000..9f9e43e --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketCloseStatusTest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the "License"); + * you may not use this file except in compliance with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is + * distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and limitations under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import java.util.Arrays; +import java.util.Collections; +import java.util.List; +import java.util.SortedSet; +import java.util.TreeSet; + +import org.assertj.core.api.ThrowableAssert; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.Test; + +import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertSame; + +import static io.netty.handler.codec.http.websocketx.WebSocketCloseStatus.*; + +public class WebSocketCloseStatusTest { + + private final List validCodes = Arrays.asList( + NORMAL_CLOSURE, + ENDPOINT_UNAVAILABLE, + PROTOCOL_ERROR, + INVALID_MESSAGE_TYPE, + INVALID_PAYLOAD_DATA, + POLICY_VIOLATION, + MESSAGE_TOO_BIG, + MANDATORY_EXTENSION, + INTERNAL_SERVER_ERROR, + SERVICE_RESTART, + TRY_AGAIN_LATER, + BAD_GATEWAY + ); + + @Test + public void testToString() { + assertEquals("1000 Bye", NORMAL_CLOSURE.toString()); + } + + @Test + public void testKnownStatuses() { + assertSame(NORMAL_CLOSURE, valueOf(1000)); + assertSame(ENDPOINT_UNAVAILABLE, valueOf(1001)); + assertSame(PROTOCOL_ERROR, valueOf(1002)); + assertSame(INVALID_MESSAGE_TYPE, valueOf(1003)); + assertSame(EMPTY, valueOf(1005)); + assertSame(ABNORMAL_CLOSURE, valueOf(1006)); + assertSame(INVALID_PAYLOAD_DATA, valueOf(1007)); + assertSame(POLICY_VIOLATION, valueOf(1008)); + assertSame(MESSAGE_TOO_BIG, valueOf(1009)); + assertSame(MANDATORY_EXTENSION, valueOf(1010)); + assertSame(INTERNAL_SERVER_ERROR, valueOf(1011)); + assertSame(SERVICE_RESTART, valueOf(1012)); + assertSame(TRY_AGAIN_LATER, valueOf(1013)); + assertSame(BAD_GATEWAY, valueOf(1014)); + assertSame(TLS_HANDSHAKE_FAILED, valueOf(1015)); + } + + @Test + public void testNaturalOrder() { + assertThat(PROTOCOL_ERROR, 
Matchers.greaterThan(NORMAL_CLOSURE)); + assertThat(PROTOCOL_ERROR, Matchers.greaterThan(valueOf(1001))); + assertThat(PROTOCOL_ERROR, Matchers.comparesEqualTo(PROTOCOL_ERROR)); + assertThat(PROTOCOL_ERROR, Matchers.comparesEqualTo(valueOf(1002))); + assertThat(PROTOCOL_ERROR, Matchers.lessThan(INVALID_MESSAGE_TYPE)); + assertThat(PROTOCOL_ERROR, Matchers.lessThan(valueOf(1007))); + } + + @Test + public void testUserDefinedStatuses() { + // Given, when + WebSocketCloseStatus feedTimeot = new WebSocketCloseStatus(6033, "Feed timed out"); + WebSocketCloseStatus untradablePrice = new WebSocketCloseStatus(6034, "Untradable price"); + + // Then + assertNotSame(feedTimeot, valueOf(6033)); + assertEquals(feedTimeot.code(), 6033); + assertEquals(feedTimeot.reasonText(), "Feed timed out"); + + assertNotSame(untradablePrice, valueOf(6034)); + assertEquals(untradablePrice.code(), 6034); + assertEquals(untradablePrice.reasonText(), "Untradable price"); + } + + @Test + public void testRfc6455CodeValidation() { + // Given + List knownCodes = Arrays.asList( + NORMAL_CLOSURE.code(), + ENDPOINT_UNAVAILABLE.code(), + PROTOCOL_ERROR.code(), + INVALID_MESSAGE_TYPE.code(), + INVALID_PAYLOAD_DATA.code(), + POLICY_VIOLATION.code(), + MESSAGE_TOO_BIG.code(), + MANDATORY_EXTENSION.code(), + INTERNAL_SERVER_ERROR.code(), + SERVICE_RESTART.code(), + TRY_AGAIN_LATER.code(), + BAD_GATEWAY.code() + ); + + SortedSet invalidCodes = new TreeSet(); + + // When + for (int statusCode = Short.MIN_VALUE; statusCode < Short.MAX_VALUE; statusCode++) { + if (!isValidStatusCode(statusCode)) { + invalidCodes.add(statusCode); + } + } + + // Then + assertEquals(0, invalidCodes.first().intValue()); + assertEquals(2999, invalidCodes.last().intValue()); + assertEquals(3000 - validCodes.size(), invalidCodes.size()); + + invalidCodes.retainAll(knownCodes); + assertEquals(invalidCodes, Collections.emptySet()); + } + + @Test + public void testValidationEnabled() { + 
assertThatExceptionOfType(IllegalArgumentException.class) + .isThrownBy(new ThrowableAssert.ThrowingCallable() { + + @Override + public void call() throws RuntimeException { + new WebSocketCloseStatus(1006, "validation disabled"); + } + }); + } + + @Test + public void testValidationDisabled() { + WebSocketCloseStatus status = new WebSocketCloseStatus(1006, "validation disabled", false); + assertEquals(1006, status.code()); + assertEquals("validation disabled", status.reasonText()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregatorTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregatorTest.java new file mode 100644 index 0000000..2339db6 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketFrameAggregatorTest.java @@ -0,0 +1,154 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + + +public class WebSocketFrameAggregatorTest { + private static final byte[] content1 = "Content1".getBytes(CharsetUtil.UTF_8); + private static final byte[] content2 = "Content2".getBytes(CharsetUtil.UTF_8); + private static final byte[] content3 = "Content3".getBytes(CharsetUtil.UTF_8); + private static final byte[] aggregatedContent = new byte[content1.length + content2.length + content3.length]; + static { + System.arraycopy(content1, 0, aggregatedContent, 0, content1.length); + System.arraycopy(content2, 0, aggregatedContent, content1.length, content2.length); + System.arraycopy(content3, 0, aggregatedContent, content1.length + content2.length, content3.length); + } + + @Test + public void testAggregationBinary() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketFrameAggregator(Integer.MAX_VALUE)); + channel.writeInbound(new BinaryWebSocketFrame(true, 1, Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new BinaryWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content2))); + channel.writeInbound(new PingWebSocketFrame(Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new PongWebSocketFrame(Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new ContinuationWebSocketFrame(true, 0, 
Unpooled.wrappedBuffer(content3))); + + assertTrue(channel.finish()); + + BinaryWebSocketFrame frame = channel.readInbound(); + assertTrue(frame.isFinalFragment()); + assertEquals(1, frame.rsv()); + assertArrayEquals(content1, toBytes(frame.content())); + + PingWebSocketFrame frame2 = channel.readInbound(); + assertTrue(frame2.isFinalFragment()); + assertEquals(0, frame2.rsv()); + assertArrayEquals(content1, toBytes(frame2.content())); + + PongWebSocketFrame frame3 = channel.readInbound(); + assertTrue(frame3.isFinalFragment()); + assertEquals(0, frame3.rsv()); + assertArrayEquals(content1, toBytes(frame3.content())); + + BinaryWebSocketFrame frame4 = channel.readInbound(); + assertTrue(frame4.isFinalFragment()); + assertEquals(0, frame4.rsv()); + assertArrayEquals(aggregatedContent, toBytes(frame4.content())); + + assertNull(channel.readInbound()); + } + + @Test + public void testAggregationText() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketFrameAggregator(Integer.MAX_VALUE)); + channel.writeInbound(new TextWebSocketFrame(true, 1, Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new TextWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content2))); + channel.writeInbound(new PingWebSocketFrame(Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new PongWebSocketFrame(Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.wrappedBuffer(content3))); + + assertTrue(channel.finish()); + + TextWebSocketFrame frame = channel.readInbound(); + assertTrue(frame.isFinalFragment()); + assertEquals(1, frame.rsv()); + assertArrayEquals(content1, toBytes(frame.content())); + + PingWebSocketFrame frame2 = channel.readInbound(); + assertTrue(frame2.isFinalFragment()); + assertEquals(0, frame2.rsv()); + assertArrayEquals(content1, toBytes(frame2.content())); + + PongWebSocketFrame frame3 = 
channel.readInbound(); + assertTrue(frame3.isFinalFragment()); + assertEquals(0, frame3.rsv()); + assertArrayEquals(content1, toBytes(frame3.content())); + + TextWebSocketFrame frame4 = channel.readInbound(); + assertTrue(frame4.isFinalFragment()); + assertEquals(0, frame4.rsv()); + assertArrayEquals(aggregatedContent, toBytes(frame4.content())); + + assertNull(channel.readInbound()); + } + + @Test + public void textFrameTooBig() throws Exception { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketFrameAggregator(8)); + channel.writeInbound(new BinaryWebSocketFrame(true, 1, Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new BinaryWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content1))); + try { + channel.writeInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content2))); + fail(); + } catch (TooLongFrameException e) { + // expected + } + channel.writeInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content2))); + channel.writeInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.wrappedBuffer(content2))); + + channel.writeInbound(new BinaryWebSocketFrame(true, 1, Unpooled.wrappedBuffer(content1))); + channel.writeInbound(new BinaryWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content1))); + try { + channel.writeInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content2))); + fail(); + } catch (TooLongFrameException e) { + // expected + } + channel.writeInbound(new ContinuationWebSocketFrame(false, 0, Unpooled.wrappedBuffer(content2))); + channel.writeInbound(new ContinuationWebSocketFrame(true, 0, Unpooled.wrappedBuffer(content2))); + for (;;) { + Object msg = channel.readInbound(); + if (msg == null) { + break; + } + ReferenceCountUtil.release(msg); + } + channel.finish(); + } + + private static byte[] toBytes(ByteBuf buf) { + byte[] bytes = new byte[buf.readableBytes()]; + buf.readBytes(bytes); + buf.release(); + return bytes; + } +} diff --git 
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeExceptionTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeExceptionTest.java new file mode 100644 index 0000000..e9ec9d7 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeExceptionTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class WebSocketHandshakeExceptionTest { + + @Test + public void testClientExceptionWithoutResponse() { + WebSocketClientHandshakeException clientException = new WebSocketClientHandshakeException("client message"); + + assertNull(clientException.response()); + assertEquals("client message", clientException.getMessage()); + } + + @Test + public void testClientExceptionWithResponse() { + HttpResponse httpResponse = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.BAD_REQUEST); + httpResponse.headers().set("x-header", "x-value"); + WebSocketClientHandshakeException clientException = new WebSocketClientHandshakeException("client message", + httpResponse); + + assertNotNull(clientException.response()); + assertEquals("client message", clientException.getMessage()); + assertEquals(HttpResponseStatus.BAD_REQUEST, clientException.response().status()); + assertEquals(httpResponse.headers(), clientException.response().headers()); + } + + @Test + public void testServerExceptionWithoutRequest() { + WebSocketServerHandshakeException serverException = new WebSocketServerHandshakeException("server message"); + + assertNull(serverException.request()); + assertEquals("server message", serverException.getMessage()); + } + + @Test + public void testClientExceptionWithRequest() { + HttpRequest httpRequest = new 
DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "ws://localhost:9999/ws"); + httpRequest.headers().set("x-header", "x-value"); + WebSocketServerHandshakeException serverException = new WebSocketServerHandshakeException("server message", + httpRequest); + + assertNotNull(serverException.request()); + assertEquals("server message", serverException.getMessage()); + assertEquals(HttpMethod.GET, serverException.request().method()); + assertEquals(httpRequest.headers(), serverException.request().headers()); + assertEquals(httpRequest.uri(), serverException.request().uri()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java new file mode 100644 index 0000000..f3d2fd8 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketHandshakeHandOverTest.java @@ -0,0 +1,371 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.websocketx.WebSocketClientProtocolHandler.ClientHandshakeStateEvent; +import io.netty.handler.codec.http.websocketx.WebSocketServerProtocolHandler.ServerHandshakeStateEvent; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.URI; +import java.util.List; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class WebSocketHandshakeHandOverTest { + + private boolean serverReceivedHandshake; + private WebSocketServerProtocolHandler.HandshakeComplete serverHandshakeComplete; + private boolean clientReceivedHandshake; + private boolean clientReceivedMessage; + private boolean serverReceivedCloseHandshake; + private boolean clientForceClosed; + private boolean clientHandshakeTimeout; + + private final class CloseNoOpServerProtocolHandler extends WebSocketServerProtocolHandler { + CloseNoOpServerProtocolHandler(String websocketPath) { + super(WebSocketServerProtocolConfig.newBuilder() + 
.websocketPath(websocketPath) + .allowExtensions(false) + .sendCloseFrame(null) + .build()); + } + + @Override + protected void decode(ChannelHandlerContext ctx, WebSocketFrame frame, List out) throws Exception { + if (frame instanceof CloseWebSocketFrame) { + serverReceivedCloseHandshake = true; + return; + } + super.decode(ctx, frame, out); + } + } + + @BeforeEach + public void setUp() { + serverReceivedHandshake = false; + serverHandshakeComplete = null; + clientReceivedHandshake = false; + clientReceivedMessage = false; + serverReceivedCloseHandshake = false; + clientForceClosed = false; + clientHandshakeTimeout = false; + } + + @Test + public void testHandover() throws Exception { + EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { + serverReceivedHandshake = true; + // immediately send a message to the client on connect + ctx.writeAndFlush(new TextWebSocketFrame("abc")); + } else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt; + } + } + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + } + }); + + EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + clientReceivedHandshake = true; + } + } + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof TextWebSocketFrame) { + clientReceivedMessage = true; + } + } + }); + + // Transfer the handshake from the client to the server + transferAllDataWithMerge(clientChannel, serverChannel); + assertTrue(serverReceivedHandshake); + 
assertNotNull(serverHandshakeComplete); + assertEquals("/test", serverHandshakeComplete.requestUri()); + assertEquals(8, serverHandshakeComplete.requestHeaders().size()); + assertEquals("test-proto-2", serverHandshakeComplete.selectedSubprotocol()); + + // Transfer the handshake response and the websocket message to the client + transferAllDataWithMerge(serverChannel, clientChannel); + assertTrue(clientReceivedHandshake); + assertTrue(clientReceivedMessage); + } + + @Test + public void testClientHandshakeTimeout() throws Throwable { + EmbeddedChannel serverChannel = createServerChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ServerHandshakeStateEvent.HANDSHAKE_COMPLETE) { + serverReceivedHandshake = true; + // immediately send a message to the client on connect + ctx.writeAndFlush(new TextWebSocketFrame("abc")); + } else if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) { + serverHandshakeComplete = (WebSocketServerProtocolHandler.HandshakeComplete) evt; + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + } + }); + + EmbeddedChannel clientChannel = createClientChannel(new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + clientReceivedHandshake = true; + } else if (evt == ClientHandshakeStateEvent.HANDSHAKE_TIMEOUT) { + clientHandshakeTimeout = true; + } + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof TextWebSocketFrame) { + clientReceivedMessage = true; + } + } + }, 100); + // Client sends the handshake request to the server + transferAllDataWithMerge(clientChannel, serverChannel); + // Server does not send the response back + // transferAllDataWithMerge(serverChannel, clientChannel); + final 
WebSocketClientProtocolHandshakeHandler handshakeHandler = + (WebSocketClientProtocolHandshakeHandler) clientChannel + .pipeline().get(WebSocketClientProtocolHandshakeHandler.class.getName()); + + while (!handshakeHandler.getHandshakeFuture().isDone()) { + Thread.sleep(10); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + clientChannel.runScheduledPendingTasks(); + } + assertTrue(clientHandshakeTimeout); + assertFalse(clientReceivedHandshake); + assertFalse(clientReceivedMessage); + // Should throw WebSocketHandshakeException + try { + assertThrows(WebSocketHandshakeException.class, new Executable() { + @Override + public void execute() { + handshakeHandler.getHandshakeFuture().syncUninterruptibly(); + } + }); + } finally { + serverChannel.finishAndReleaseAll(); + } + } + + /** + * Tests a scenario when channel is closed while the handshake is in progress. Validates that the handshake + * future is notified in such cases. + */ + @Test + public void testHandshakeFutureIsNotifiedOnChannelClose() throws Exception { + EmbeddedChannel clientChannel = createClientChannel(null); + EmbeddedChannel serverChannel = createServerChannel(null); + + try { + // Start handshake from client to server but don't complete the handshake for the purpose of this test. + transferAllDataWithMerge(clientChannel, serverChannel); + + final WebSocketClientProtocolHandler clientWsHandler = + clientChannel.pipeline().get(WebSocketClientProtocolHandler.class); + final WebSocketClientProtocolHandshakeHandler clientWsHandshakeHandler = + clientChannel.pipeline().get(WebSocketClientProtocolHandshakeHandler.class); + + final ChannelHandlerContext ctx = clientChannel.pipeline().context(WebSocketClientProtocolHandler.class); + + // Close the channel while the handshake is in progress. The channel could be closed before the handshake is + // complete due to a number of varied reasons. 
To reproduce the test scenario for this test case, + // we would manually close the channel. + clientWsHandler.close(ctx, ctx.newPromise()); + + // At this stage handshake is incomplete but the handshake future should be completed exceptionally since + // channel is closed. + assertTrue(clientWsHandshakeHandler.getHandshakeFuture().isDone()); + } finally { + serverChannel.finishAndReleaseAll(); + clientChannel.finishAndReleaseAll(); + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testClientHandshakerForceClose() throws Exception { + final WebSocketClientHandshaker handshaker = WebSocketClientHandshakerFactory.newHandshaker( + new URI("ws://localhost:1234/test"), WebSocketVersion.V13, null, true, + EmptyHttpHeaders.INSTANCE, Integer.MAX_VALUE, true, false, 20); + + EmbeddedChannel serverChannel = createServerChannel( + new CloseNoOpServerProtocolHandler("/test"), + new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + } + }); + + EmbeddedChannel clientChannel = createClientChannel(handshaker, new SimpleChannelInboundHandler() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt == ClientHandshakeStateEvent.HANDSHAKE_COMPLETE) { + ctx.channel().closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientForceClosed = true; + } + }); + handshaker.close(ctx.channel(), new CloseWebSocketFrame()); + } + } + @Override + protected void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception { + } + }); + + // Transfer the handshake from the client to the server + transferAllDataWithMerge(clientChannel, serverChannel); + // Transfer the handshake from the server to client + transferAllDataWithMerge(serverChannel, clientChannel); + + // Transfer closing handshake + transferAllDataWithMerge(clientChannel, 
serverChannel); + assertTrue(serverReceivedCloseHandshake); + // Should not be closed yet as we disabled closing the connection on the server + assertFalse(clientForceClosed); + + while (!clientForceClosed) { + Thread.sleep(10); + // We need to run all pending tasks as the force close timeout is scheduled on the EventLoop. + clientChannel.runPendingTasks(); + } + + // clientForceClosed would be set to TRUE after any close, + // so check here that force close timeout was actually fired + assertTrue(handshaker.isForceCloseComplete()); + + // Both should be empty + assertFalse(serverChannel.finishAndReleaseAll()); + assertFalse(clientChannel.finishAndReleaseAll()); + } + + /** + * Transfers all pending data from the source channel into the destination channel.
+ * Merges all data into a single buffer before transmission into the destination. + * @param srcChannel The source channel + * @param dstChannel The destination channel + */ + private static void transferAllDataWithMerge(EmbeddedChannel srcChannel, EmbeddedChannel dstChannel) { + ByteBuf mergedBuffer = null; + for (;;) { + Object srcData = srcChannel.readOutbound(); + + if (srcData != null) { + assertTrue(srcData instanceof ByteBuf); + ByteBuf srcBuf = (ByteBuf) srcData; + try { + if (mergedBuffer == null) { + mergedBuffer = Unpooled.buffer(); + } + mergedBuffer.writeBytes(srcBuf); + } finally { + srcBuf.release(); + } + } else { + break; + } + } + + if (mergedBuffer != null) { + dstChannel.writeInbound(mergedBuffer); + } + } + + private static EmbeddedChannel createClientChannel(ChannelHandler handler) throws Exception { + return createClientChannel(handler, WebSocketClientProtocolConfig.newBuilder() + .webSocketUri("ws://localhost:1234/test") + .subprotocol("test-proto-2") + .build()); + } + + private static EmbeddedChannel createClientChannel(ChannelHandler handler, long timeoutMillis) throws Exception { + return createClientChannel(handler, WebSocketClientProtocolConfig.newBuilder() + .webSocketUri("ws://localhost:1234/test") + .subprotocol("test-proto-2") + .handshakeTimeoutMillis(timeoutMillis) + .build()); + } + + private static EmbeddedChannel createClientChannel(ChannelHandler handler, WebSocketClientProtocolConfig config) { + return new EmbeddedChannel( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + new WebSocketClientProtocolHandler(config), + handler); + } + + private static EmbeddedChannel createClientChannel(WebSocketClientHandshaker handshaker, + ChannelHandler handler) throws Exception { + return new EmbeddedChannel( + new HttpClientCodec(), + new HttpObjectAggregator(8192), + // Note that we're switching off close frames handling on purpose to test forced close on timeout. 
+ new WebSocketClientProtocolHandler(handshaker, false, false), + handler); + } + + private static EmbeddedChannel createServerChannel(ChannelHandler handler) { + return createServerChannel( + new WebSocketServerProtocolHandler("/test", "test-proto-1, test-proto-2", false), + handler); + } + + private static EmbeddedChannel createServerChannel(WebSocketServerProtocolHandler webSocketHandler, + ChannelHandler handler) { + return new EmbeddedChannel( + new HttpServerCodec(), + new HttpObjectAggregator(8192), + webSocketHandler, + handler); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java new file mode 100644 index 0000000..f30d792 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketProtocolHandlerTest.java @@ -0,0 +1,188 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.flow.FlowControlHandler; +import io.netty.util.ReferenceCountUtil; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.atomic.AtomicReference; + +import static io.netty.util.CharsetUtil.UTF_8; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests common, abstract class functionality in {@link WebSocketClientProtocolHandler}. 
+ */ +public class WebSocketProtocolHandlerTest { + + @Test + public void testPingFrame() { + ByteBuf pingData = Unpooled.copiedBuffer("Hello, world", UTF_8); + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler() { }); + + PingWebSocketFrame inputMessage = new PingWebSocketFrame(pingData); + assertFalse(channel.writeInbound(inputMessage)); // the message was not propagated inbound + + // a Pong frame was written to the channel + PongWebSocketFrame response = channel.readOutbound(); + assertEquals(pingData, response.content()); + + pingData.release(); + assertFalse(channel.finish()); + } + + @Test + public void testPingPongFlowControlWhenAutoReadIsDisabled() { + String text1 = "Hello, world #1"; + String text2 = "Hello, world #2"; + String text3 = "Hello, world #3"; + String text4 = "Hello, world #4"; + + EmbeddedChannel channel = new EmbeddedChannel(); + channel.config().setAutoRead(false); + channel.pipeline().addLast(new FlowControlHandler()); + channel.pipeline().addLast(new WebSocketProtocolHandler() { }); + + // When + assertFalse(channel.writeInbound( + new PingWebSocketFrame(Unpooled.copiedBuffer(text1, UTF_8)), + new TextWebSocketFrame(text2), + new TextWebSocketFrame(text3), + new PingWebSocketFrame(Unpooled.copiedBuffer(text4, UTF_8)) + )); + + // Then - no messages were handled or propagated + assertNull(channel.readInbound()); + assertNull(channel.readOutbound()); + + // When + channel.read(); + + // Then - pong frame was written to the outbound + PongWebSocketFrame response1 = channel.readOutbound(); + assertEquals(text1, response1.content().toString(UTF_8)); + + // And - one requested message was handled and propagated inbound + TextWebSocketFrame message2 = channel.readInbound(); + assertEquals(text2, message2.text()); + + // And - no more messages were handled or propagated + assertNull(channel.readInbound()); + assertNull(channel.readOutbound()); + + // When + channel.read(); + + // Then - one requested message was handled 
and propagated inbound + TextWebSocketFrame message3 = channel.readInbound(); + assertEquals(text3, message3.text()); + + // And - no more messages were handled or propagated + // Precisely, ping frame 'text4' was NOT read or handled. + // It would be handled ONLY on the next 'channel.read()' call. + assertNull(channel.readInbound()); + assertNull(channel.readOutbound()); + + // Cleanup + response1.release(); + message2.release(); + message3.release(); + assertFalse(channel.finish()); + } + + @Test + public void testPongFrameDropFrameFalse() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler(false) { }); + + PongWebSocketFrame pingResponse = new PongWebSocketFrame(); + assertTrue(channel.writeInbound(pingResponse)); + + assertPropagatedInbound(pingResponse, channel); + + pingResponse.release(); + assertFalse(channel.finish()); + } + + @Test + public void testPongFrameDropFrameTrue() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler(true) { }); + + PongWebSocketFrame pingResponse = new PongWebSocketFrame(); + assertFalse(channel.writeInbound(pingResponse)); // message was not propagated inbound + } + + @Test + public void testTextFrame() { + EmbeddedChannel channel = new EmbeddedChannel(new WebSocketProtocolHandler() { }); + + TextWebSocketFrame textFrame = new TextWebSocketFrame(); + assertTrue(channel.writeInbound(textFrame)); + + assertPropagatedInbound(textFrame, channel); + + textFrame.release(); + assertFalse(channel.finish()); + } + + @Test + public void testTimeout() throws Exception { + final AtomicReference ref = new AtomicReference(); + WebSocketProtocolHandler handler = new WebSocketProtocolHandler( + false, WebSocketCloseStatus.NORMAL_CLOSURE, 1) { }; + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ref.set(promise); + ReferenceCountUtil.release(msg); + } + },
handler); + + ChannelFuture future = channel.writeAndFlush(new CloseWebSocketFrame()); + ChannelHandlerContext ctx = channel.pipeline().context(WebSocketProtocolHandler.class); + handler.close(ctx, ctx.newPromise()); + + do { + Thread.sleep(10); + channel.runPendingTasks(); + } while (!future.isDone()); + + assertThat(future.cause(), Matchers.instanceOf(WebSocketHandshakeException.class)); + assertFalse(ref.get().isDone()); + assertFalse(channel.finish()); + } + + /** + * Asserts that a message was propagated inbound through the channel. + */ + private static void assertPropagatedInbound(T message, EmbeddedChannel channel) { + T propagatedResponse = channel.readInbound(); + assertEquals(message, propagatedResponse); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java new file mode 100644 index 0000000..685083c --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketRequestBuilder.java @@ -0,0 +1,165 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpVersion; + +import static io.netty.handler.codec.http.HttpVersion.*; + +public class WebSocketRequestBuilder { + + private HttpVersion httpVersion; + private HttpMethod method; + private String uri; + private String host; + private String upgrade; + private String connection; + private String key; + private String origin; + private WebSocketVersion version; + + public WebSocketRequestBuilder httpVersion(HttpVersion httpVersion) { + this.httpVersion = httpVersion; + return this; + } + + public WebSocketRequestBuilder method(HttpMethod method) { + this.method = method; + return this; + } + + public WebSocketRequestBuilder uri(CharSequence uri) { + if (uri == null) { + this.uri = null; + } else { + this.uri = uri.toString(); + } + return this; + } + + public WebSocketRequestBuilder host(CharSequence host) { + if (host == null) { + this.host = null; + } else { + this.host = host.toString(); + } + return this; + } + + public WebSocketRequestBuilder upgrade(CharSequence upgrade) { + if (upgrade == null) { + this.upgrade = null; + } else { + this.upgrade = upgrade.toString(); + } + return this; + } + + public WebSocketRequestBuilder connection(CharSequence connection) { + if (connection == null) { + this.connection = null; + } else { + this.connection = connection.toString(); + } + return this; + } + + public WebSocketRequestBuilder key(CharSequence key) { + if (key == null) { + this.key = null; + } else { + this.key = key.toString(); + } + return this; + } + + public WebSocketRequestBuilder origin(CharSequence origin) 
{ + if (origin == null) { + this.origin = null; + } else { + this.origin = origin.toString(); + } + return this; + } + + public WebSocketRequestBuilder version13() { + version = WebSocketVersion.V13; + return this; + } + + public WebSocketRequestBuilder version8() { + version = WebSocketVersion.V08; + return this; + } + + public WebSocketRequestBuilder version00() { + version = null; + return this; + } + + public WebSocketRequestBuilder noVersion() { + return this; + } + + public FullHttpRequest build() { + FullHttpRequest req = new DefaultFullHttpRequest(httpVersion, method, uri); + HttpHeaders headers = req.headers(); + + if (host != null) { + headers.set(HttpHeaderNames.HOST, host); + } + if (upgrade != null) { + headers.set(HttpHeaderNames.UPGRADE, upgrade); + } + if (connection != null) { + headers.set(HttpHeaderNames.CONNECTION, connection); + } + if (key != null) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_KEY, key); + } + if (origin != null) { + if (version == WebSocketVersion.V13 || version == WebSocketVersion.V00) { + headers.set(HttpHeaderNames.ORIGIN, origin); + } else { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, origin); + } + } + if (version != null) { + headers.set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, version.toHttpHeaderValue()); + } + return req; + } + + public static HttpRequest successful() { + return new WebSocketRequestBuilder().httpVersion(HTTP_1_1) + .method(HttpMethod.GET) + .uri("/test") + .host("server.example.com") + .upgrade(HttpHeaderValues.WEBSOCKET) + .connection(HttpHeaderValues.UPGRADE) + .key("dGhlIHNhbXBsZSBub25jZQ==") + .origin("http://example.com") + .version13() + .build(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java new file mode 100644 index 0000000..e6afcd5 --- /dev/null +++ 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker00Test.java @@ -0,0 +1,131 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.fail; + +public class WebSocketServerHandshaker00Test extends WebSocketServerHandshakerTest { + + @Override + protected 
WebSocketServerHandshaker newHandshaker(String webSocketURL, String subprotocols, + WebSocketDecoderConfig decoderConfig) { + return new WebSocketServerHandshaker00(webSocketURL, subprotocols, decoderConfig); + } + + @Override + protected WebSocketVersion webSocketVersion() { + return WebSocketVersion.V00; + } + + @Test + public void testPerformOpeningHandshake() { + testPerformOpeningHandshake0(true); + } + + @Test + public void testPerformOpeningHandshakeSubProtocolNotSupported() { + testPerformOpeningHandshake0(false); + } + + @Test + public void testPerformHandshakeWithoutOriginHeader() { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + FullHttpRequest req = new DefaultFullHttpRequest( + HTTP_1_1, HttpMethod.GET, "/chat", Unpooled.copiedBuffer("^n:ds[4U", CharsetUtil.US_ASCII)); + + req.headers().set(HttpHeaderNames.HOST, "server.example.com"); + req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); + req.headers().set(HttpHeaderNames.CONNECTION, "Upgrade"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, "4 @1 46546xW%0l 1 5"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); + + WebSocketServerHandshaker00 handshaker00 = new WebSocketServerHandshaker00( + "ws://example.com/chat", "chat", Integer.MAX_VALUE); + try { + handshaker00.handshake(ch, req); + fail("Expecting WebSocketHandshakeException"); + } catch (WebSocketHandshakeException e) { + assertEquals("Missing origin header, got only " + + "[host, upgrade, connection, sec-websocket-key1, sec-websocket-protocol]", + e.getMessage()); + } finally { + req.release(); + } + } + + private static void testPerformOpeningHandshake0(boolean subProtocol) { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + FullHttpRequest req = new DefaultFullHttpRequest( + HTTP_1_1, HttpMethod.GET, "/chat", 
Unpooled.copiedBuffer("^n:ds[4U", CharsetUtil.US_ASCII)); + + req.headers().set(HttpHeaderNames.HOST, "server.example.com"); + req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); + req.headers().set(HttpHeaderNames.CONNECTION, "Upgrade"); + req.headers().set(HttpHeaderNames.ORIGIN, "http://example.com"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, "4 @1 46546xW%0l 1 5"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, "12998 5 Y3 1 .P00"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); + + if (subProtocol) { + new WebSocketServerHandshaker00( + "ws://example.com/chat", "chat", Integer.MAX_VALUE).handshake(ch, req); + } else { + new WebSocketServerHandshaker00( + "ws://example.com/chat", null, Integer.MAX_VALUE).handshake(ch, req); + } + + EmbeddedChannel ch2 = new EmbeddedChannel(new HttpResponseDecoder()); + ch2.writeInbound(ch.readOutbound()); + HttpResponse res = ch2.readInbound(); + + assertEquals("ws://example.com/chat", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_LOCATION)); + + if (subProtocol) { + assertEquals("chat", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + } else { + assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + } + LastHttpContent content = ch2.readInbound(); + + assertEquals("8jKS'y:G*Co,Wxa-", content.content().toString(CharsetUtil.US_ASCII)); + content.release(); + req.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07Test.java new file mode 100644 index 0000000..0a13f42 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker07Test.java @@ -0,0 +1,30 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache 
License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +public class WebSocketServerHandshaker07Test extends WebSocketServerHandshakerTest { + + @Override + protected WebSocketServerHandshaker newHandshaker(String webSocketURL, String subprotocols, + WebSocketDecoderConfig decoderConfig) { + return new WebSocketServerHandshaker07(webSocketURL, subprotocols, decoderConfig); + } + + @Override + protected WebSocketVersion webSocketVersion() { + return WebSocketVersion.V07; + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08Test.java new file mode 100644 index 0000000..153d918 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker08Test.java @@ -0,0 +1,97 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import static io.netty.handler.codec.http.HttpVersion.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class WebSocketServerHandshaker08Test extends WebSocketServerHandshakerTest { + + @Override + protected WebSocketServerHandshaker newHandshaker(String webSocketURL, String subprotocols, + WebSocketDecoderConfig decoderConfig) { + return new WebSocketServerHandshaker08(webSocketURL, subprotocols, decoderConfig); + } + + @Override + protected WebSocketVersion webSocketVersion() { + return WebSocketVersion.V08; + } + + @Test + public void testPerformOpeningHandshake() { + testPerformOpeningHandshake0(true); + } + + @Test + public void testPerformOpeningHandshakeSubProtocolNotSupported() { + testPerformOpeningHandshake0(false); + } + + private static void testPerformOpeningHandshake0(boolean subProtocol) { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpRequestDecoder(), new HttpResponseEncoder()); + + FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/chat"); + 
req.headers().set(HttpHeaderNames.HOST, "server.example.com"); + req.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET); + req.headers().set(HttpHeaderNames.CONNECTION, "Upgrade"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ=="); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "8"); + + if (subProtocol) { + new WebSocketServerHandshaker08( + "ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false).handshake(ch, req); + } else { + new WebSocketServerHandshaker08( + "ws://example.com/chat", null, false, Integer.MAX_VALUE, false).handshake(ch, req); + } + + ByteBuf resBuf = ch.readOutbound(); + + EmbeddedChannel ch2 = new EmbeddedChannel(new HttpResponseDecoder()); + ch2.writeInbound(resBuf); + HttpResponse res = ch2.readInbound(); + + assertEquals( + "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT)); + if (subProtocol) { + assertEquals("chat", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + } else { + assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + } + ReferenceCountUtil.release(res); + req.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13Test.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13Test.java new file mode 100644 index 0000000..4fc8ce9 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshaker13Test.java @@ -0,0 +1,226 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseDecoder; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Iterator; + +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static 
org.junit.jupiter.api.Assertions.fail; + +public class WebSocketServerHandshaker13Test extends WebSocketServerHandshakerTest { + + @Override + protected WebSocketServerHandshaker newHandshaker(String webSocketURL, String subprotocols, + WebSocketDecoderConfig decoderConfig) { + return new WebSocketServerHandshaker13(webSocketURL, subprotocols, decoderConfig); + } + + @Override + protected WebSocketVersion webSocketVersion() { + return WebSocketVersion.V13; + } + + @Test + public void testPerformOpeningHandshake() { + testPerformOpeningHandshake0(true); + } + + @Test + public void testPerformOpeningHandshakeSubProtocolNotSupported() { + testPerformOpeningHandshake0(false); + } + + private static void testPerformOpeningHandshake0(boolean subProtocol) { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpObjectAggregator(42), new HttpResponseEncoder(), new HttpRequestDecoder()); + + if (subProtocol) { + testUpgrade0(ch, new WebSocketServerHandshaker13( + "ws://example.com/chat", "chat", false, Integer.MAX_VALUE, false)); + } else { + testUpgrade0(ch, new WebSocketServerHandshaker13( + "ws://example.com/chat", null, false, Integer.MAX_VALUE, false)); + } + assertFalse(ch.finish()); + } + + @Test + public void testCloseReasonWithEncoderAndDecoder() { + testCloseReason0(new HttpResponseEncoder(), new HttpRequestDecoder()); + } + + @Test + public void testCloseReasonWithCodec() { + testCloseReason0(new HttpServerCodec()); + } + + @Test + public void testHandshakeExceptionWhenConnectionHeaderIsAbsent() { + final WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat", + "chat", WebSocketDecoderConfig.DEFAULT); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "ws://example.com/chat"); + request.headers() + .set(HttpHeaderNames.HOST, "server.example.com") + .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==") + 
.set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com") + .set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat") + .set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); + Throwable exception = assertThrows(WebSocketServerHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + serverHandshaker.handshake(null, request, null, null); + } + }); + + assertEquals("not a WebSocket request: a |Connection| header must includes a token 'Upgrade'", + exception.getMessage()); + assertTrue(request.release()); + } + + @Test + public void testHandshakeExceptionWhenInvalidConnectionHeader() { + final WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat", + "chat", WebSocketDecoderConfig.DEFAULT); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "ws://example.com/chat"); + request.headers() + .set(HttpHeaderNames.HOST, "server.example.com") + .set(HttpHeaderNames.CONNECTION, "close") + .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET) + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==") + .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com") + .set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat") + .set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); + Throwable exception = assertThrows(WebSocketServerHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + serverHandshaker.handshake(null, request, null, null); + } + }); + + assertEquals("not a WebSocket request: a |Connection| header must includes a token 'Upgrade'", + exception.getMessage()); + assertTrue(request.release()); + } + + @Test + public void testHandshakeExceptionWhenInvalidUpgradeHeader() { + final WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat", + "chat", WebSocketDecoderConfig.DEFAULT); + final FullHttpRequest request = new 
DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "ws://example.com/chat"); + request.headers() + .set(HttpHeaderNames.HOST, "server.example.com") + .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE) + .set(HttpHeaderNames.UPGRADE, "my_websocket") + .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==") + .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com") + .set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat") + .set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); + Throwable exception = assertThrows(WebSocketServerHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + serverHandshaker.handshake(null, request, null, null); + } + }); + + assertEquals("not a WebSocket request: a |Upgrade| header must containing the value 'websocket'", + exception.getMessage()); + assertTrue(request.release()); + } + + private static void testCloseReason0(ChannelHandler... handlers) { + EmbeddedChannel ch = new EmbeddedChannel( + new HttpObjectAggregator(42)); + ch.pipeline().addLast(handlers); + testUpgrade0(ch, new WebSocketServerHandshaker13("ws://example.com/chat", "chat", + WebSocketDecoderConfig.newBuilder().maxFramePayloadLength(4).closeOnProtocolViolation(true).build())); + + ch.writeOutbound(new BinaryWebSocketFrame(Unpooled.wrappedBuffer(new byte[8]))); + ByteBuf buffer = ch.readOutbound(); + try { + ch.writeInbound(buffer); + fail(); + } catch (CorruptedWebSocketFrameException expected) { + // expected + } + ReferenceCounted closeMessage = ch.readOutbound(); + assertThat(closeMessage, instanceOf(ByteBuf.class)); + closeMessage.release(); + assertFalse(ch.finish()); + } + + private static void testUpgrade0(EmbeddedChannel ch, WebSocketServerHandshaker13 handshaker) { + FullHttpRequest req = new DefaultFullHttpRequest(HTTP_1_1, HttpMethod.GET, "/chat"); + req.headers().set(HttpHeaderNames.HOST, "server.example.com"); + req.headers().set(HttpHeaderNames.UPGRADE, 
HttpHeaderValues.WEBSOCKET); + req.headers().set(HttpHeaderNames.CONNECTION, "Upgrade"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ=="); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat"); + req.headers().set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, "13"); + + handshaker.handshake(ch, req); + + ByteBuf resBuf = ch.readOutbound(); + + EmbeddedChannel ch2 = new EmbeddedChannel(new HttpResponseDecoder()); + ch2.writeInbound(resBuf); + HttpResponse res = ch2.readInbound(); + + assertEquals( + "s3pPLMBiTxaQ9kYGzzhZRbK+xOo=", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT)); + Iterator subProtocols = handshaker.subprotocols().iterator(); + if (subProtocols.hasNext()) { + assertEquals(subProtocols.next(), + res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + } else { + assertNull(res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL)); + } + ReferenceCountUtil.release(res); + req.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactoryTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactoryTest.java new file mode 100644 index 0000000..f978e92 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerFactoryTest.java @@ -0,0 +1,55 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.netty.handler.codec.http.websocketx;

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Tests {@link WebSocketServerHandshakerFactory#sendUnsupportedVersionResponse(io.netty.channel.Channel)}.
 */
public class WebSocketServerHandshakerFactoryTest {

    @Test
    public void testUnsupportedVersion() throws Exception {
        EmbeddedChannel channel = new EmbeddedChannel();
        WebSocketServerHandshakerFactory.sendUnsupportedVersionResponse(channel);
        channel.runPendingTasks();

        // The factory must emit exactly one full HTTP response.
        Object written = channel.readOutbound();
        if (!(written instanceof FullHttpResponse)) {
            fail("Got wrong response " + written);
        }
        FullHttpResponse res = (FullHttpResponse) written;

        // A 426 that advertises the highest supported protocol version and
        // carries an explicit zero Content-Length.
        assertEquals(HttpResponseStatus.UPGRADE_REQUIRED, res.status());
        assertEquals(WebSocketVersion.V13.toHttpHeaderValue(),
                res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_VERSION));
        assertTrue(HttpUtil.isContentLengthSet(res));
        assertEquals(0, HttpUtil.getContentLength(res));

        ReferenceCountUtil.release(res);
        assertFalse(channel.finish());
    }
}
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerTest.java new file mode 100644 index 0000000..1d50da9 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketServerHandshakerTest.java @@ -0,0 +1,180 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.http.websocketx;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelFuture;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpRequestDecoder;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseDecoder;
import io.netty.handler.codec.http.HttpResponseEncoder;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.LastHttpContent;
import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.Test;

import static io.netty.handler.codec.http.HttpResponseStatus.*;
import static org.junit.jupiter.api.Assertions.*;

/**
 * Version-independent test base for server-side WebSocket handshakers.
 * Concrete subclasses (one per protocol version) supply the handshaker via
 * {@link #newHandshaker} and their version via {@link #webSocketVersion};
 * version-specific expectations branch on {@code webSocketVersion() != V00}.
 */
public abstract class WebSocketServerHandshakerTest {

    /** Creates the version-specific handshaker under test. */
    protected abstract WebSocketServerHandshaker newHandshaker(String webSocketURL, String subprotocols,
                                                               WebSocketDecoderConfig decoderConfig);

    /** The WebSocket protocol version exercised by the concrete subclass. */
    protected abstract WebSocketVersion webSocketVersion();

    /**
     * Caller-supplied response headers that duplicate mandatory handshake
     * headers must not produce duplicates in the response; custom headers
     * must pass through unchanged.
     */
    @Test
    public void testDuplicateHandshakeResponseHeaders() {
        WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat",
                "chat", WebSocketDecoderConfig.DEFAULT);
        FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/chat");
        request.headers()
                .set(HttpHeaderNames.HOST, "example.com")
                .set(HttpHeaderNames.ORIGIN, "example.com")
                .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
                .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
                .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")
                .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com")
                .set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat")
                .set(HttpHeaderNames.WEBSOCKET_PROTOCOL, "chat, superchat")
                .set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, webSocketVersion().toAsciiString());
        HttpHeaders customResponseHeaders = new DefaultHttpHeaders();
        // set duplicate required headers and one custom
        customResponseHeaders
                .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
                .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
                .set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "superchat")
                .set(HttpHeaderNames.WEBSOCKET_PROTOCOL, "superchat")
                .set("custom", "header");

        if (webSocketVersion() != WebSocketVersion.V00) {
            // A bogus accept value supplied by the caller must be overwritten
            // by the computed one.
            customResponseHeaders.set(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, "12345");
        }

        FullHttpResponse response = null;
        try {
            response = serverHandshaker.newHandshakeResponse(request, customResponseHeaders);
            HttpHeaders responseHeaders = response.headers();

            assertEquals(1, responseHeaders.getAll(HttpHeaderNames.CONNECTION).size());
            assertEquals(1, responseHeaders.getAll(HttpHeaderNames.UPGRADE).size());
            assertTrue(responseHeaders.containsValue("custom", "header", true));

            if (webSocketVersion() != WebSocketVersion.V00) {
                assertFalse(responseHeaders.containsValue(HttpHeaderNames.SEC_WEBSOCKET_ACCEPT, "12345", false));
                assertEquals(1, responseHeaders.getAll(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL).size());
                // Server-selected protocol ("chat") wins over the caller's value.
                assertEquals("chat", responseHeaders.get(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL));
            } else {
                assertEquals(1, responseHeaders.getAll(HttpHeaderNames.WEBSOCKET_PROTOCOL).size());
                assertEquals("chat", responseHeaders.get(HttpHeaderNames.WEBSOCKET_PROTOCOL));
            }
        } finally {
            // Release both messages even when an assertion fails.
            request.release();
            if (response != null) {
                response.release();
            }
        }
    }

    /**
     * A handshake failure must surface as {@link WebSocketServerHandshakeException}
     * that still carries the offending request (headers and method).
     */
    @Test
    public void testWebSocketServerHandshakeException() {
        WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat",
                "chat", WebSocketDecoderConfig.DEFAULT);

        // Request lacks all mandatory upgrade headers, so the handshake fails.
        FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET,
                "ws://example.com/chat");
        request.headers().set("x-client-header", "value");
        try {
            serverHandshaker.handshake(null, request, null, null);
        } catch (WebSocketServerHandshakeException exception) {
            assertNotNull(exception.getMessage());
            assertEquals(request.headers(), exception.request().headers());
            assertEquals(HttpMethod.GET, exception.request().method());
        } finally {
            request.release();
        }
    }

    /**
     * Handshake with a streamed (non-aggregated) {@link HttpRequest}: the
     * handshaker installs a temporary "handshaker" handler (and, for V00,
     * an "httpAggregator" to collect the challenge body), completes once the
     * trailing content arrives, and removes its helpers again.
     */
    @Test
    public void testHandshakeForHttpRequestWithoutAggregator() {
        EmbeddedChannel channel = new EmbeddedChannel(new HttpRequestDecoder(), new HttpResponseEncoder());
        WebSocketServerHandshaker serverHandshaker = newHandshaker("ws://example.com/chat",
                "chat", WebSocketDecoderConfig.DEFAULT);

        HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/chat");
        request.headers()
                .set(HttpHeaderNames.HOST, "example.com")
                .set(HttpHeaderNames.ORIGIN, "example.com")
                .set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET)
                .set(HttpHeaderNames.CONNECTION, HttpHeaderValues.UPGRADE)
                .set(HttpHeaderNames.SEC_WEBSOCKET_KEY, "dGhlIHNhbXBsZSBub25jZQ==")
                .set(HttpHeaderNames.SEC_WEBSOCKET_ORIGIN, "http://example.com")
                .set(HttpHeaderNames.SEC_WEBSOCKET_PROTOCOL, "chat, superchat")
                .set(HttpHeaderNames.SEC_WEBSOCKET_KEY1, "4 @1 46546xW%0l 1 5")
                .set(HttpHeaderNames.SEC_WEBSOCKET_KEY2, "12998 5 Y3 1 .P00")
                .set(HttpHeaderNames.WEBSOCKET_PROTOCOL, "chat, superchat")
                .set(HttpHeaderNames.SEC_WEBSOCKET_VERSION, webSocketVersion().toAsciiString());

        ChannelFuture future = serverHandshaker.handshake(channel, request);
        // Not done yet: the trailing HTTP content has not been received.
        assertFalse(future.isDone());
        assertNotNull(channel.pipeline().get("handshaker"));

        if (webSocketVersion() != WebSocketVersion.V00) {
            assertNull(channel.pipeline().get("httpAggregator"));
            channel.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT);
        } else {
            // V00 carries an 8-byte challenge in the request body.
            assertNotNull(channel.pipeline().get("httpAggregator"));
            channel.writeInbound(new DefaultLastHttpContent(
                    Unpooled.copiedBuffer("^n:ds[4U", CharsetUtil.US_ASCII)));
        }

        assertTrue(future.isDone());
        assertNull(channel.pipeline().get("handshaker"));

        // Re-decode the raw response bytes on a fresh channel to inspect them.
        ByteBuf byteBuf = channel.readOutbound();
        assertFalse(channel.finish());

        channel = new EmbeddedChannel(new HttpResponseDecoder());
        assertTrue(channel.writeInbound(byteBuf));

        HttpResponse response = channel.readInbound();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        assertTrue(response.headers().containsValue(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET, true));

        LastHttpContent lastHttpContent = channel.readInbound();
        if (webSocketVersion() != WebSocketVersion.V00) {
            assertEquals(LastHttpContent.EMPTY_LAST_CONTENT, lastHttpContent);
        } else {
            // V00 answers the challenge with a 16-byte body.
            assertEquals("8jKS'y:G*Co,Wxa-", lastHttpContent.content().toString(CharsetUtil.US_ASCII));
            assertTrue(lastHttpContent.release());
        }

        assertFalse(channel.finish());
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; + +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpClientCodec; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpRequestDecoder; +import io.netty.handler.codec.http.HttpResponseEncoder; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpObjectAggregator; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayDeque; +import java.util.Queue; + +import static io.netty.handler.codec.http.HttpResponseStatus.*; +import static io.netty.handler.codec.http.HttpVersion.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static 
org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests {@link WebSocketServerProtocolHandler}: upgrade handling (full and
 * streamed requests), rejection of malformed upgrades, optional UTF-8
 * validation, path matching, and close-frame exchange on channel shutdown.
 *
 * <p>Responses written by the server pipelines are captured by
 * {@link MockOutboundHandler} into {@link #responses} for inspection.
 */
public class WebSocketServerProtocolHandlerTest {

    // Responses captured by MockOutboundHandler, in write order.
    private final Queue<FullHttpResponse> responses = new ArrayDeque<FullHttpResponse>();

    @BeforeEach
    public void setUp() {
        responses.clear();
    }

    @Test
    public void testHttpUpgradeRequestFull() {
        testHttpUpgradeRequest0(true);
    }

    @Test
    public void testHttpUpgradeRequestNonFull() {
        testHttpUpgradeRequest0(false);
    }

    /**
     * @param full whether the upgrade request arrives as one FullHttpRequest
     *             or as request + content + last-content pieces
     */
    private void testHttpUpgradeRequest0(boolean full) {
        EmbeddedChannel ch = createChannel(new MockOutboundHandler());
        ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
        writeUpgradeRequest(ch, full);

        FullHttpResponse response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();
        assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
        assertFalse(ch.finish());
    }

    @Test
    public void testWebSocketServerProtocolHandshakeHandlerReplacedBeforeHandshake() {
        EmbeddedChannel ch = createChannel(new MockOutboundHandler());
        ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
        ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                if (evt instanceof WebSocketServerProtocolHandler.HandshakeComplete) {
                    // We should have removed the handler already.
                    assertNull(ctx.pipeline().context(WebSocketServerProtocolHandshakeHandler.class));
                }
            }
        });
        writeUpgradeRequest(ch);

        FullHttpResponse response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();
        assertNotNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
        assertFalse(ch.finish());
    }

    @Test
    public void testHttpUpgradeRequestInvalidUpgradeHeader() {
        EmbeddedChannel ch = createChannel();
        FullHttpRequest httpRequestWithEntity = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
                .method(HttpMethod.GET)
                .uri("/test")
                .connection("Upgrade")
                .version00()
                .upgrade("BogusSocket")
                .build();

        ch.writeInbound(httpRequestWithEntity);

        FullHttpResponse response = responses.remove();
        assertEquals(BAD_REQUEST, response.status());
        assertEquals("not a WebSocket handshake request: missing upgrade", getResponseMessage(response));
        response.release();
        assertFalse(ch.finish());
    }

    @Test
    public void testHttpUpgradeRequestMissingWSKeyHeader() {
        EmbeddedChannel ch = createChannel();
        HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
                .method(HttpMethod.GET)
                .uri("/test")
                .key(null)
                .connection("Upgrade")
                .upgrade(HttpHeaderValues.WEBSOCKET)
                .version13()
                .build();

        ch.writeInbound(httpRequest);

        FullHttpResponse response = responses.remove();
        assertEquals(BAD_REQUEST, response.status());
        assertEquals("not a WebSocket request: missing key", getResponseMessage(response));
        response.release();
        assertFalse(ch.finish());
    }

    @Test
    public void testCreateUTF8Validator() {
        WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/test")
                .withUTF8Validator(true)
                .build();

        EmbeddedChannel ch = new EmbeddedChannel(
                new WebSocketServerProtocolHandler(config),
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new MockOutboundHandler());
        writeUpgradeRequest(ch);

        FullHttpResponse response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();

        assertNotNull(ch.pipeline().get(Utf8FrameValidator.class));
    }

    @Test
    public void testDoNotCreateUTF8Validator() {
        WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/test")
                .withUTF8Validator(false)
                .build();

        EmbeddedChannel ch = new EmbeddedChannel(
                new WebSocketServerProtocolHandler(config),
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new MockOutboundHandler());
        writeUpgradeRequest(ch);

        FullHttpResponse response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();

        assertNull(ch.pipeline().get(Utf8FrameValidator.class));
    }

    @Test
    public void testHandleTextFrame() {
        CustomTextFrameHandler customTextFrameHandler = new CustomTextFrameHandler();
        EmbeddedChannel ch = createChannel(customTextFrameHandler);
        writeUpgradeRequest(ch);

        FullHttpResponse response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();

        if (ch.pipeline().context(HttpRequestDecoder.class) != null) {
            // Removing the HttpRequestDecoder because we are writing a TextWebSocketFrame and thus
            // decoding is not necessary.
            ch.pipeline().remove(HttpRequestDecoder.class);
        }

        ch.writeInbound(new TextWebSocketFrame("payload"));

        assertEquals("processed: payload", customTextFrameHandler.getContent());
        assertFalse(ch.finish());
    }

    @Test
    public void testCheckWebSocketPathStartWithSlash() {
        WebSocketRequestBuilder builder = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
                .method(HttpMethod.GET)
                .key(HttpHeaderNames.SEC_WEBSOCKET_KEY)
                .connection("Upgrade")
                .upgrade(HttpHeaderValues.WEBSOCKET)
                .version13();

        WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/")
                .checkStartsWith(true)
                .build();

        FullHttpResponse response;

        // All URIs starting with "/" must be accepted in checkStartsWith mode.
        createChannel(config, null).writeInbound(builder.uri("/test").build());
        response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();

        createChannel(config, null).writeInbound(builder.uri("/?q=v").build());
        response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();

        createChannel(config, null).writeInbound(builder.uri("/").build());
        response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();
    }

    @Test
    public void testCheckValidWebSocketPath() {
        HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
                .method(HttpMethod.GET)
                .uri("/test")
                .key(HttpHeaderNames.SEC_WEBSOCKET_KEY)
                .connection("Upgrade")
                .upgrade(HttpHeaderValues.WEBSOCKET)
                .version13()
                .build();

        WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/test")
                .checkStartsWith(true)
                .build();

        EmbeddedChannel ch = new EmbeddedChannel(
                new WebSocketServerProtocolHandler(config),
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new MockOutboundHandler());
        ch.writeInbound(httpRequest);

        FullHttpResponse response = responses.remove();
        assertEquals(SWITCHING_PROTOCOLS, response.status());
        response.release();
    }

    @Test
    public void testCheckInvalidWebSocketPath() {
        HttpRequest httpRequest = new WebSocketRequestBuilder().httpVersion(HTTP_1_1)
                .method(HttpMethod.GET)
                .uri("/testabc")
                .key(HttpHeaderNames.SEC_WEBSOCKET_KEY)
                .connection("Upgrade")
                .upgrade(HttpHeaderValues.WEBSOCKET)
                .version13()
                .build();

        WebSocketServerProtocolConfig config = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/test")
                .checkStartsWith(true)
                .build();

        EmbeddedChannel ch = new EmbeddedChannel(
                new WebSocketServerProtocolHandler(config),
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new MockOutboundHandler());
        ch.writeInbound(httpRequest);

        // "/testabc" does not match "/test" as a path prefix, so no handshake
        // may have taken place.
        ChannelHandlerContext handshakerCtx = ch.pipeline().context(WebSocketServerProtocolHandshakeHandler.class);
        assertNull(WebSocketServerProtocolHandler.getHandshaker(handshakerCtx.channel()));
    }

    @Test
    public void testExplicitCloseFrameSentWhenServerChannelClosed() throws Exception {
        WebSocketCloseStatus closeStatus = WebSocketCloseStatus.ENDPOINT_UNAVAILABLE;
        EmbeddedChannel client = createClient();
        EmbeddedChannel server = createServer();

        // Opening handshake round-trip.
        assertFalse(server.writeInbound(client.readOutbound()));
        assertFalse(client.writeInbound(server.readOutbound()));

        // When server channel closed with explicit close-frame
        assertTrue(server.writeOutbound(new CloseWebSocketFrame(closeStatus)));
        server.close();

        // Then client receives provided close-frame
        assertTrue(client.writeInbound(server.readOutbound()));
        assertFalse(server.isOpen());

        CloseWebSocketFrame closeMessage = client.readInbound();
        assertEquals(closeMessage.statusCode(), closeStatus.code());
        closeMessage.release();

        client.close();
        assertTrue(ReferenceCountUtil.release(client.readOutbound()));
        assertFalse(client.finishAndReleaseAll());
        assertFalse(server.finishAndReleaseAll());
    }

    @Test
    public void testCloseFrameSentWhenServerChannelClosedSilently() throws Exception {
        EmbeddedChannel client = createClient();
        EmbeddedChannel server = createServer();

        assertFalse(server.writeInbound(client.readOutbound()));
        assertFalse(client.writeInbound(server.readOutbound()));

        // When server channel closed without explicit close-frame
        server.close();

        // Then client receives NORMAL_CLOSURE close-frame
        assertTrue(client.writeInbound(server.readOutbound()));
        assertFalse(server.isOpen());

        CloseWebSocketFrame closeMessage = client.readInbound();
        assertEquals(closeMessage.statusCode(), WebSocketCloseStatus.NORMAL_CLOSURE.code());
        closeMessage.release();

        client.close();
        assertTrue(ReferenceCountUtil.release(client.readOutbound()));
        assertFalse(client.finishAndReleaseAll());
        assertFalse(server.finishAndReleaseAll());
    }

    @Test
    public void testExplicitCloseFrameSentWhenClientChannelClosed() throws Exception {
        WebSocketCloseStatus closeStatus = WebSocketCloseStatus.INVALID_PAYLOAD_DATA;
        EmbeddedChannel client = createClient();
        EmbeddedChannel server = createServer();

        assertFalse(server.writeInbound(client.readOutbound()));
        assertFalse(client.writeInbound(server.readOutbound()));

        // When client channel closed with explicit close-frame
        assertTrue(client.writeOutbound(new CloseWebSocketFrame(closeStatus)));
        client.close();

        // Then client receives provided close-frame
        assertFalse(server.writeInbound(client.readOutbound()));
        assertFalse(client.isOpen());
        assertFalse(server.isOpen());

        CloseWebSocketFrame closeMessage = decode(server.readOutbound(), CloseWebSocketFrame.class);
        assertEquals(closeMessage.statusCode(), closeStatus.code());
        closeMessage.release();

        assertFalse(client.finishAndReleaseAll());
        assertFalse(server.finishAndReleaseAll());
    }

    @Test
    public void testCloseFrameSentWhenClientChannelClosedSilently() throws Exception {
        EmbeddedChannel client = createClient();
        EmbeddedChannel server = createServer();

        assertFalse(server.writeInbound(client.readOutbound()));
        assertFalse(client.writeInbound(server.readOutbound()));

        // When client channel closed without explicit close-frame
        client.close();

        // Then server receives NORMAL_CLOSURE close-frame
        assertFalse(server.writeInbound(client.readOutbound()));
        assertFalse(client.isOpen());
        assertFalse(server.isOpen());

        CloseWebSocketFrame closeMessage = decode(server.readOutbound(), CloseWebSocketFrame.class);
        assertEquals(closeMessage, new CloseWebSocketFrame(WebSocketCloseStatus.NORMAL_CLOSURE));
        closeMessage.release();

        assertFalse(client.finishAndReleaseAll());
        assertFalse(server.finishAndReleaseAll());
    }

    /** Builds an unregistered client pipeline; caller-supplied handlers go last. */
    private EmbeddedChannel createClient(ChannelHandler... handlers) throws Exception {
        WebSocketClientProtocolConfig clientConfig = WebSocketClientProtocolConfig.newBuilder()
                .webSocketUri("http://test/test")
                .dropPongFrames(false)
                .handleCloseFrames(false)
                .build();
        EmbeddedChannel ch = new EmbeddedChannel(false, false,
                new HttpClientCodec(),
                new HttpObjectAggregator(8192),
                new WebSocketClientProtocolHandler(clientConfig)
        );
        ch.pipeline().addLast(handlers);
        ch.register();
        return ch;
    }

    /** Builds an unregistered server pipeline; caller-supplied handlers go last. */
    private EmbeddedChannel createServer(ChannelHandler... handlers) throws Exception {
        WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/test")
                .dropPongFrames(false)
                .build();
        EmbeddedChannel ch = new EmbeddedChannel(false, false,
                new HttpServerCodec(),
                new HttpObjectAggregator(8192),
                new WebSocketServerProtocolHandler(serverConfig)
        );
        ch.pipeline().addLast(handlers);
        ch.register();
        return ch;
    }

    /**
     * Decodes raw outbound bytes into a single WebSocket frame of the
     * expected type via a throwaway v13 frame decoder.
     */
    @SuppressWarnings("SameParameterValue")
    private <T> T decode(ByteBuf input, Class<T> clazz) {
        EmbeddedChannel ch = new EmbeddedChannel(new WebSocket13FrameDecoder(true, false, 65536, true));
        assertTrue(ch.writeInbound(input));
        Object decoded = ch.readInbound();
        assertNotNull(decoded);
        assertFalse(ch.finish());
        return clazz.cast(decoded);
    }

    private EmbeddedChannel createChannel() {
        return createChannel(null);
    }

    private EmbeddedChannel createChannel(ChannelHandler handler) {
        WebSocketServerProtocolConfig serverConfig = WebSocketServerProtocolConfig.newBuilder()
                .websocketPath("/test")
                .sendCloseFrame(null)
                .build();
        return createChannel(serverConfig, handler);
    }

    private EmbeddedChannel createChannel(WebSocketServerProtocolConfig serverConfig, ChannelHandler handler) {
        return new EmbeddedChannel(
                new WebSocketServerProtocolHandler(serverConfig),
                new HttpRequestDecoder(),
                new HttpResponseEncoder(),
                new MockOutboundHandler(),
                handler);
    }

    private static void writeUpgradeRequest(EmbeddedChannel ch) {
        writeUpgradeRequest(ch, true);
    }

    /**
     * Sends a valid upgrade request, either aggregated ({@code full}) or
     * split into request + content + last-content messages.
     */
    private static void writeUpgradeRequest(EmbeddedChannel ch, boolean full) {
        HttpRequest request = WebSocketRequestBuilder.successful();
        if (full) {
            ch.writeInbound(request);
        } else {
            if (request instanceof FullHttpRequest) {
                FullHttpRequest fullHttpRequest = (FullHttpRequest) request;
                HttpRequest req = new DefaultHttpRequest(fullHttpRequest.protocolVersion(), fullHttpRequest.method(),
                        fullHttpRequest.uri(), fullHttpRequest.headers().copy());
                ch.writeInbound(req);
                ch.writeInbound(new DefaultHttpContent(fullHttpRequest.content().copy()));
                ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT);
                fullHttpRequest.release();
            } else {
                ch.writeInbound(request);
            }
        }
    }

    private static String getResponseMessage(FullHttpResponse response) {
        return response.content().toString(CharsetUtil.UTF_8);
    }

    /** Captures outbound responses into {@link #responses} instead of encoding them. */
    private class MockOutboundHandler extends ChannelOutboundHandlerAdapter {

        @Override
        public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) {
            responses.add((FullHttpResponse) msg);
            promise.setSuccess();
        }

        @Override
        public void flush(ChannelHandlerContext ctx) {
        }
    }

    /** Records the payload of the single text frame it receives. */
    private static class CustomTextFrameHandler extends ChannelInboundHandlerAdapter {
        private String content;

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) {
            assertNull(content);
            content = "processed: " + ((TextWebSocketFrame) msg).text();
            ReferenceCountUtil.release(msg);
        }

        String getContent() {
            return content;
        }
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class WebSocketUtf8FrameValidatorTest { + + @Test + public void testCorruptedFrameExceptionInFinish() { + assertCorruptedFrameExceptionHandling(new byte[]{-50}); + } + + @Test + public void testCorruptedFrameExceptionInCheck() { + assertCorruptedFrameExceptionHandling(new byte[]{-8, -120, -128, -128, -128}); + } + + @Test + void testNotCloseOnProtocolViolation() { + final EmbeddedChannel channel = new EmbeddedChannel(new Utf8FrameValidator(false)); + final TextWebSocketFrame frame = new TextWebSocketFrame(Unpooled.copiedBuffer(new byte[] { -50 })); + assertThrows(CorruptedWebSocketFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeInbound(frame); + } + }, "bytes are not UTF-8"); + + assertTrue(channel.isActive()); + assertFalse(channel.finish()); + assertEquals(0, frame.refCnt()); + } + + private void assertCorruptedFrameExceptionHandling(byte[] data) { + final EmbeddedChannel channel = new EmbeddedChannel(new Utf8FrameValidator()); + final 
TextWebSocketFrame frame = new TextWebSocketFrame(Unpooled.copiedBuffer(data)); + assertThrows(CorruptedWebSocketFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.writeInbound(frame); + } + }, "bytes are not UTF-8"); + + assertFalse(channel.isActive()); + + CloseWebSocketFrame closeFrame = channel.readOutbound(); + assertNotNull(closeFrame); + assertEquals("bytes are not UTF-8", closeFrame.reasonText()); + assertEquals(1007, closeFrame.statusCode()); + assertTrue(closeFrame.release()); + + assertEquals(0, frame.refCnt()); + assertFalse(channel.finish()); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java new file mode 100644 index 0000000..25f979e --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/WebSocketUtilTest.java @@ -0,0 +1,74 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx; + +import org.junit.jupiter.api.Test; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.EmptyArrays; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class WebSocketUtilTest { + + // how many times do we want to run each random variable checker + private static final int NUM_ITERATIONS = 1000; + + private static void assertRandomWithinBoundaries(int min, int max) { + int r = WebSocketUtil.randomNumber(min, max); + assertTrue(min <= r && r <= max); + } + + @Test + public void testRandomNumberGenerator() { + int iteration = 0; + while (++iteration < NUM_ITERATIONS) { + assertRandomWithinBoundaries(0, 1); + assertRandomWithinBoundaries(0, 1); + assertRandomWithinBoundaries(-1, 1); + assertRandomWithinBoundaries(-1, 0); + } + } + + @Test + public void testBase64() { + String base64 = WebSocketUtil.base64(EmptyArrays.EMPTY_BYTES); + assertNotNull(base64); + assertTrue(base64.isEmpty()); + + base64 = WebSocketUtil.base64("foo".getBytes(CharsetUtil.UTF_8)); + assertEquals(base64, "Zm9v"); + + base64 = WebSocketUtil.base64("bar".getBytes(CharsetUtil.UTF_8)); + ByteBuf src = Unpooled.wrappedBuffer(base64.getBytes(CharsetUtil.UTF_8)); + try { + ByteBuf dst = Base64.decode(src); + try { + assertEquals(new String(ByteBufUtil.getBytes(dst), CharsetUtil.UTF_8), "bar"); + } finally { + dst.release(); + } + } finally { + src.release(); + } + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketClientExtensionHandlerTest.java 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx.extensions;

import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.handler.codec.CodecException;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;

import java.util.Collections;
import java.util.List;

import org.junit.jupiter.api.Test;

import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionTestUtil.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link WebSocketClientExtensionHandler}: advertising a "main" and a
 * "fallback" extension in the upgrade request and installing the codecs of
 * whichever extensions the server accepts.
 */
public class WebSocketClientExtensionHandlerTest {

    WebSocketClientExtensionHandshaker mainHandshakerMock =
            mock(WebSocketClientExtensionHandshaker.class, "mainHandshaker");
    WebSocketClientExtensionHandshaker fallbackHandshakerMock =
            mock(WebSocketClientExtensionHandshaker.class, "fallbackHandshaker");
    WebSocketClientExtension mainExtensionMock =
            mock(WebSocketClientExtension.class, "mainExtension");
    WebSocketClientExtension fallbackExtensionMock =
            mock(WebSocketClientExtension.class, "fallbackExtension");

    @Test
    public void testMainSuccess() {
        // initialize: the server accepts only the "main" extension.
        when(mainHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("main", Collections.emptyMap()));
        when(mainHandshakerMock.handshakeExtension(any(WebSocketExtensionData.class))).thenReturn(mainExtensionMock);
        when(fallbackHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("fallback", Collections.emptyMap()));
        when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
        when(mainExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder());
        when(mainExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder());

        // execute
        EmbeddedChannel ch = new EmbeddedChannel(new WebSocketClientExtensionHandler(
                mainHandshakerMock, fallbackHandshakerMock));

        HttpRequest req = newUpgradeRequest(null);
        ch.writeOutbound(req);

        HttpRequest req2 = ch.readOutbound();
        List<WebSocketExtensionData> reqExts = WebSocketExtensionUtil.extractExtensions(
                req2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        HttpResponse res = newUpgradeResponse("main");
        ch.writeInbound(res);

        HttpResponse res2 = ch.readInbound();
        List<WebSocketExtensionData> resExts = WebSocketExtensionUtil.extractExtensions(
                res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        // test: both extensions were advertised, only "main" was negotiated and
        // its codecs were installed into the pipeline.
        assertEquals(2, reqExts.size());
        assertEquals("main", reqExts.get(0).name());
        assertEquals("fallback", reqExts.get(1).name());

        assertEquals(1, resExts.size());
        assertEquals("main", resExts.get(0).name());
        assertTrue(resExts.get(0).parameters().isEmpty());
        assertNotNull(ch.pipeline().get(DummyDecoder.class));
        assertNotNull(ch.pipeline().get(DummyEncoder.class));

        verify(mainHandshakerMock).newRequestData();
        verify(mainHandshakerMock).handshakeExtension(any(WebSocketExtensionData.class));
        verify(fallbackHandshakerMock).newRequestData();
        verify(mainExtensionMock, atLeastOnce()).rsv();
        verify(mainExtensionMock).newExtensionEncoder();
        verify(mainExtensionMock).newExtensionDecoder();
    }

    @Test
    public void testFallbackSuccess() {
        // initialize: the main handshaker rejects the negotiated extension, so the
        // fallback handshaker must take over.
        when(mainHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("main", Collections.emptyMap()));
        when(mainHandshakerMock.handshakeExtension(any(WebSocketExtensionData.class))).thenReturn(null);
        when(fallbackHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("fallback", Collections.emptyMap()));
        when(fallbackHandshakerMock.handshakeExtension(
                any(WebSocketExtensionData.class))).thenReturn(fallbackExtensionMock);
        when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
        when(fallbackExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder());
        when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder());

        // execute
        EmbeddedChannel ch = new EmbeddedChannel(new WebSocketClientExtensionHandler(
                mainHandshakerMock, fallbackHandshakerMock));

        HttpRequest req = newUpgradeRequest(null);
        ch.writeOutbound(req);

        HttpRequest req2 = ch.readOutbound();
        List<WebSocketExtensionData> reqExts = WebSocketExtensionUtil.extractExtensions(
                req2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        HttpResponse res = newUpgradeResponse("fallback");
        ch.writeInbound(res);

        HttpResponse res2 = ch.readInbound();
        List<WebSocketExtensionData> resExts = WebSocketExtensionUtil.extractExtensions(
                res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        // test: the fallback extension was negotiated and its codecs installed.
        assertEquals(2, reqExts.size());
        assertEquals("main", reqExts.get(0).name());
        assertEquals("fallback", reqExts.get(1).name());

        assertEquals(1, resExts.size());
        assertEquals("fallback", resExts.get(0).name());
        assertTrue(resExts.get(0).parameters().isEmpty());
        assertNotNull(ch.pipeline().get(DummyDecoder.class));
        assertNotNull(ch.pipeline().get(DummyEncoder.class));

        verify(mainHandshakerMock).newRequestData();
        verify(mainHandshakerMock).handshakeExtension(any(WebSocketExtensionData.class));
        verify(fallbackHandshakerMock).newRequestData();
        verify(fallbackHandshakerMock).handshakeExtension(any(WebSocketExtensionData.class));
        verify(fallbackExtensionMock, atLeastOnce()).rsv();
        verify(fallbackExtensionMock).newExtensionEncoder();
        verify(fallbackExtensionMock).newExtensionDecoder();
    }

    @Test
    public void testAllSuccess() {
        // initialize: the server accepts both extensions; they use distinct RSV bits.
        when(mainHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("main", Collections.emptyMap()));
        when(mainHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("main"))).thenReturn(mainExtensionMock);
        when(mainHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("fallback"))).thenReturn(null);
        when(fallbackHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("fallback", Collections.emptyMap()));
        when(fallbackHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("main"))).thenReturn(null);
        when(fallbackHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("fallback"))).thenReturn(fallbackExtensionMock);

        DummyEncoder mainEncoder = new DummyEncoder();
        DummyDecoder mainDecoder = new DummyDecoder();
        when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
        when(mainExtensionMock.newExtensionEncoder()).thenReturn(mainEncoder);
        when(mainExtensionMock.newExtensionDecoder()).thenReturn(mainDecoder);

        Dummy2Encoder fallbackEncoder = new Dummy2Encoder();
        Dummy2Decoder fallbackDecoder = new Dummy2Decoder();
        when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV2);
        when(fallbackExtensionMock.newExtensionEncoder()).thenReturn(fallbackEncoder);
        when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(fallbackDecoder);

        // execute
        EmbeddedChannel ch = new EmbeddedChannel(new WebSocketClientExtensionHandler(
                mainHandshakerMock, fallbackHandshakerMock));

        HttpRequest req = newUpgradeRequest(null);
        ch.writeOutbound(req);

        HttpRequest req2 = ch.readOutbound();
        List<WebSocketExtensionData> reqExts = WebSocketExtensionUtil.extractExtensions(
                req2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        HttpResponse res = newUpgradeResponse("main, fallback");
        ch.writeInbound(res);

        HttpResponse res2 = ch.readInbound();
        List<WebSocketExtensionData> resExts = WebSocketExtensionUtil.extractExtensions(
                res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        // test: both extensions negotiated; all four codecs installed.
        assertEquals(2, reqExts.size());
        assertEquals("main", reqExts.get(0).name());
        assertEquals("fallback", reqExts.get(1).name());

        assertEquals(2, resExts.size());
        assertEquals("main", resExts.get(0).name());
        assertEquals("fallback", resExts.get(1).name());
        assertNotNull(ch.pipeline().context(mainEncoder));
        assertNotNull(ch.pipeline().context(mainDecoder));
        assertNotNull(ch.pipeline().context(fallbackEncoder));
        assertNotNull(ch.pipeline().context(fallbackDecoder));

        verify(mainHandshakerMock).newRequestData();
        verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("main"));
        verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("fallback"));
        verify(fallbackHandshakerMock).newRequestData();
        verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("fallback"));
        verify(mainExtensionMock, atLeastOnce()).rsv();
        verify(mainExtensionMock).newExtensionEncoder();
        verify(mainExtensionMock).newExtensionDecoder();
        verify(fallbackExtensionMock, atLeastOnce()).rsv();
        verify(fallbackExtensionMock).newExtensionEncoder();
        verify(fallbackExtensionMock).newExtensionDecoder();
    }

    @Test
    public void testIfMainAndFallbackUseRSV1WillFail() {
        // initialize: both negotiated extensions claim the same RSV1 bit, which
        // the handler must reject when processing the upgrade response.
        when(mainHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("main", Collections.emptyMap()));
        when(mainHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("main"))).thenReturn(mainExtensionMock);
        when(mainHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("fallback"))).thenReturn(null);
        when(fallbackHandshakerMock.newRequestData()).
                thenReturn(new WebSocketExtensionData("fallback", Collections.emptyMap()));
        when(fallbackHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("main"))).thenReturn(null);
        when(fallbackHandshakerMock.handshakeExtension(
                webSocketExtensionDataMatcher("fallback"))).thenReturn(fallbackExtensionMock);
        when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);
        when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1);

        // execute
        final EmbeddedChannel ch = new EmbeddedChannel(new WebSocketClientExtensionHandler(
                mainHandshakerMock, fallbackHandshakerMock));

        HttpRequest req = newUpgradeRequest(null);
        ch.writeOutbound(req);

        HttpRequest req2 = ch.readOutbound();
        List<WebSocketExtensionData> reqExts = WebSocketExtensionUtil.extractExtensions(
                req2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS));

        // The RSV bit collision must surface as a CodecException. The original
        // try/catch returned early here (and fail() would throw), so every
        // assertion and verification below was unreachable dead code.
        final HttpResponse res = newUpgradeResponse("main, fallback");
        assertThrows(CodecException.class, () -> ch.writeInbound(res));

        // test
        assertEquals(2, reqExts.size());
        assertEquals("main", reqExts.get(0).name());
        assertEquals("fallback", reqExts.get(1).name());

        verify(mainHandshakerMock).newRequestData();
        verify(mainHandshakerMock, atLeastOnce()).handshakeExtension(webSocketExtensionDataMatcher("main"));
        verify(mainHandshakerMock, atLeastOnce()).handshakeExtension(webSocketExtensionDataMatcher("fallback"));

        verify(fallbackHandshakerMock).newRequestData();
        // NOTE: fallbackHandshakerMock is never consulted for "main" — the main
        // handshaker already claimed it — so no such verification is made here
        // (the original, unreachable version asserted it and would have failed).
        verify(fallbackHandshakerMock, atLeastOnce()).handshakeExtension(webSocketExtensionDataMatcher("fallback"));

        verify(mainExtensionMock, atLeastOnce()).rsv();
        verify(fallbackExtensionMock, atLeastOnce()).rsv();
    }
}
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProviderTest.java new file mode 100644 index 0000000..ef80950 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterProviderTest.java @@ -0,0 +1,33 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class WebSocketExtensionFilterProviderTest { + + @Test + public void testDefaultExtensionFilterProvider() { + WebSocketExtensionFilterProvider defaultProvider = WebSocketExtensionFilterProvider.DEFAULT; + assertNotNull(defaultProvider); + + assertEquals(WebSocketExtensionFilter.NEVER_SKIP, defaultProvider.decoderFilter()); + assertEquals(WebSocketExtensionFilter.NEVER_SKIP, defaultProvider.encoderFilter()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterTest.java new file mode 100644 index 0000000..7eced82 --- /dev/null +++ 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionFilterTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.CloseWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PingWebSocketFrame; +import io.netty.handler.codec.http.websocketx.PongWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class WebSocketExtensionFilterTest { + + @Test + public void testNeverSkip() { + WebSocketExtensionFilter neverSkip = WebSocketExtensionFilter.NEVER_SKIP; + + BinaryWebSocketFrame binaryFrame = new BinaryWebSocketFrame(); + assertFalse(neverSkip.mustSkip(binaryFrame)); + assertTrue(binaryFrame.release()); + + TextWebSocketFrame textFrame = new TextWebSocketFrame(); + assertFalse(neverSkip.mustSkip(textFrame)); + assertTrue(textFrame.release()); + + PingWebSocketFrame pingFrame = new PingWebSocketFrame(); + assertFalse(neverSkip.mustSkip(pingFrame)); + 
assertTrue(pingFrame.release()); + + PongWebSocketFrame pongFrame = new PongWebSocketFrame(); + assertFalse(neverSkip.mustSkip(pongFrame)); + assertTrue(pongFrame.release()); + + CloseWebSocketFrame closeFrame = new CloseWebSocketFrame(); + assertFalse(neverSkip.mustSkip(closeFrame)); + assertTrue(closeFrame.release()); + + ContinuationWebSocketFrame continuationFrame = new ContinuationWebSocketFrame(); + assertFalse(neverSkip.mustSkip(continuationFrame)); + assertTrue(continuationFrame.release()); + } + + @Test + public void testAlwaysSkip() { + WebSocketExtensionFilter neverSkip = WebSocketExtensionFilter.ALWAYS_SKIP; + + BinaryWebSocketFrame binaryFrame = new BinaryWebSocketFrame(); + assertTrue(neverSkip.mustSkip(binaryFrame)); + assertTrue(binaryFrame.release()); + + TextWebSocketFrame textFrame = new TextWebSocketFrame(); + assertTrue(neverSkip.mustSkip(textFrame)); + assertTrue(textFrame.release()); + + PingWebSocketFrame pingFrame = new PingWebSocketFrame(); + assertTrue(neverSkip.mustSkip(pingFrame)); + assertTrue(pingFrame.release()); + + PongWebSocketFrame pongFrame = new PongWebSocketFrame(); + assertTrue(neverSkip.mustSkip(pongFrame)); + assertTrue(pongFrame.release()); + + CloseWebSocketFrame closeFrame = new CloseWebSocketFrame(); + assertTrue(neverSkip.mustSkip(closeFrame)); + assertTrue(closeFrame.release()); + + ContinuationWebSocketFrame continuationFrame = new ContinuationWebSocketFrame(); + assertTrue(neverSkip.mustSkip(continuationFrame)); + assertTrue(continuationFrame.release()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java new file mode 100644 index 0000000..7102bde --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionTestUtil.java @@ -0,0 +1,121 @@ +/* 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version
 * 2.0 (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.codec.http.websocketx.extensions;

import java.util.List;

import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.http.DefaultHttpRequest;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpMessage;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpRequest;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.codec.http.websocketx.WebSocketFrame;
import org.mockito.ArgumentMatcher;

import static org.mockito.Mockito.argThat;

/**
 * Shared fixtures for websocket extension handler tests: canonical upgrade
 * request/response builders, a Mockito matcher keyed on the extension name,
 * and no-op extension codecs used to observe pipeline (de)installation.
 */
public final class WebSocketExtensionTestUtil {

    private WebSocketExtensionTestUtil() {
        // utility class, no instances
    }

    /** Builds a canonical websocket upgrade request, optionally advertising {@code ext}. */
    public static HttpRequest newUpgradeRequest(String ext) {
        HttpRequest req = new DefaultHttpRequest(
                HttpVersion.HTTP_1_1, HttpMethod.GET, "/chat");
        populateUpgradeHeaders(req, ext);
        return req;
    }

    /** Builds a canonical 101 Switching Protocols response, optionally confirming {@code ext}. */
    public static HttpResponse newUpgradeResponse(String ext) {
        HttpResponse res = new DefaultHttpResponse(
                HttpVersion.HTTP_1_1, HttpResponseStatus.SWITCHING_PROTOCOLS);
        populateUpgradeHeaders(res, ext);
        return res;
    }

    // Sets the headers common to both the upgrade request and response.
    private static void populateUpgradeHeaders(HttpMessage message, String ext) {
        message.headers().set(HttpHeaderNames.HOST, "server.example.com");
        message.headers().set(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET.toString().toLowerCase());
        message.headers().set(HttpHeaderNames.CONNECTION, "Upgrade");
        message.headers().set(HttpHeaderNames.ORIGIN, "http://example.com");
        if (ext != null) {
            message.headers().set(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS, ext);
        }
    }

    /** Matches a {@link WebSocketExtensionData} solely by its extension name. */
    static final class WebSocketExtensionDataMatcher implements ArgumentMatcher<WebSocketExtensionData> {

        private final String name;

        WebSocketExtensionDataMatcher(String name) {
            this.name = name;
        }

        @Override
        public boolean matches(WebSocketExtensionData data) {
            return data != null && name.equals(data.name());
        }
    }

    /** Mockito argument matcher for extension data with the given name. */
    static WebSocketExtensionData webSocketExtensionDataMatcher(String text) {
        return argThat(new WebSocketExtensionDataMatcher(text));
    }

    static class DummyEncoder extends WebSocketExtensionEncoder {
        @Override
        protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg,
                List<Object> out) throws Exception {
            // no-op: presence in the pipeline is all the tests check
        }
    }

    static class DummyDecoder extends WebSocketExtensionDecoder {
        @Override
        protected void decode(ChannelHandlerContext ctx, WebSocketFrame msg,
                List<Object> out) throws Exception {
            // no-op: presence in the pipeline is all the tests check
        }
    }

    static class Dummy2Encoder extends WebSocketExtensionEncoder {
        @Override
        protected void encode(ChannelHandlerContext ctx, WebSocketFrame msg,
                List<Object> out) throws Exception {
            // no-op: a second distinct encoder type for multi-extension tests
        }
    }

    static class Dummy2Decoder extends WebSocketExtensionDecoder {
        @Override
        protected void decode(ChannelHandlerContext ctx, WebSocketFrame msg,
                List<Object> out) throws Exception {
            // no-op: a second distinct decoder type for multi-extension tests
        }
    }
}
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtilTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtilTest.java new file mode 100644 index 0000000..4134566 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/WebSocketExtensionUtilTest.java @@ -0,0 +1,85 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec.http.websocketx.extensions;

import io.netty.handler.codec.http.DefaultHttpHeaders;
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http.HttpHeaders;
import org.junit.jupiter.api.Test;

import java.util.List;

import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionUtil.*;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@link WebSocketExtensionUtil}: upgrade detection and merging of
 * negotiated extensions into a (possibly user-supplied) extensions header.
 */
public class WebSocketExtensionUtilTest {

    // Raw Sec-WebSocket-Extensions header used by all merge tests below
    // (note the space after each ';', which merging normalizes away).
    private static final String NEGOTIATED_EXTENSIONS_HEADER =
            "permessage-deflate; client_max_window_bits," +
            "permessage-deflate; client_no_context_takeover; client_max_window_bits," +
            "deflate-frame," +
            "x-webkit-deflate-frame";

    @Test
    public void testIsWebsocketUpgrade() {
        HttpHeaders headers = new DefaultHttpHeaders();
        assertFalse(isWebsocketUpgrade(headers));

        // An Upgrade header alone is not enough ...
        headers.add(HttpHeaderNames.UPGRADE, HttpHeaderValues.WEBSOCKET);
        assertFalse(isWebsocketUpgrade(headers));

        // ... "Connection" must also carry the Upgrade token.
        headers.add(HttpHeaderNames.CONNECTION, "Keep-Alive, Upgrade");
        assertTrue(isWebsocketUpgrade(headers));
    }

    @Test
    public void computeMergeExtensionsHeaderValueWhenNoUserDefinedHeader() {
        List<WebSocketExtensionData> extras = extractExtensions(NEGOTIATED_EXTENSIONS_HEADER);
        String newHeaderValue = computeMergeExtensionsHeaderValue(null, extras);
        assertEquals("permessage-deflate;client_max_window_bits," +
                "permessage-deflate;client_no_context_takeover;client_max_window_bits," +
                "deflate-frame," +
                "x-webkit-deflate-frame", newHeaderValue);
    }

    @Test
    public void computeMergeExtensionsHeaderValueWhenNoConflictingUserDefinedHeader() {
        List<WebSocketExtensionData> extras = extractExtensions(NEGOTIATED_EXTENSIONS_HEADER);
        // Non-conflicting user extensions are appended after the negotiated ones.
        String newHeaderValue = computeMergeExtensionsHeaderValue("foo, bar", extras);
        assertEquals("permessage-deflate;client_max_window_bits," +
                "permessage-deflate;client_no_context_takeover;client_max_window_bits," +
                "deflate-frame," +
                "x-webkit-deflate-frame," +
                "foo," +
                "bar", newHeaderValue);
    }

    @Test
    public void computeMergeExtensionsHeaderValueWhenConflictingUserDefinedHeader() {
        List<WebSocketExtensionData> extras = extractExtensions(NEGOTIATED_EXTENSIONS_HEADER);
        // A user extension whose name collides with a negotiated one is dropped.
        String newHeaderValue = computeMergeExtensionsHeaderValue(
                "permessage-deflate; client_max_window_bits", extras);
        assertEquals("permessage-deflate;client_max_window_bits," +
                "permessage-deflate;client_no_context_takeover;client_max_window_bits," +
                "deflate-frame," +
                "x-webkit-deflate-frame", newHeaderValue);
    }
}
+ */ +package io.netty.handler.codec.http.websocketx.extensions; + +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +import io.netty.handler.codec.http.LastHttpContent; +import org.junit.jupiter.api.Test; + +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionTestUtil.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +public class WebSocketServerExtensionHandlerTest { + + WebSocketServerExtensionHandshaker mainHandshakerMock = + mock(WebSocketServerExtensionHandshaker.class, "mainHandshaker"); + WebSocketServerExtensionHandshaker fallbackHandshakerMock = + mock(WebSocketServerExtensionHandshaker.class, "fallbackHandshaker"); + + WebSocketServerExtensionHandshaker main2HandshakerMock = + mock(WebSocketServerExtensionHandshaker.class, "main2Handshaker"); + WebSocketServerExtension mainExtensionMock = + mock(WebSocketServerExtension.class, "mainExtension"); + + WebSocketServerExtension fallbackExtensionMock = + mock(WebSocketServerExtension.class, "fallbackExtension"); + + WebSocketServerExtension main2ExtensionMock = + mock(WebSocketServerExtension.class, "main2Extension"); + + @Test + public void testMainSuccess() { + // initialize + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))). + thenReturn(mainExtensionMock); + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("fallback"))). 
+ thenReturn(null); + + when(fallbackHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("fallback"))). + thenReturn(fallbackExtensionMock); + when(fallbackHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))). + thenReturn(null); + + when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); + when(mainExtensionMock.newReponseData()).thenReturn( + new WebSocketExtensionData("main", Collections.emptyMap())); + when(mainExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder()); + when(mainExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder()); + + when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); + + // execute + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); + + HttpRequest req = newUpgradeRequest("main, fallback"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List resExts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + // test + assertNull(ch.pipeline().context(extensionHandler)); + assertEquals(1, resExts.size()); + assertEquals("main", resExts.get(0).name()); + assertTrue(resExts.get(0).parameters().isEmpty()); + assertNotNull(ch.pipeline().get(DummyDecoder.class)); + assertNotNull(ch.pipeline().get(DummyEncoder.class)); + + verify(mainHandshakerMock, atLeastOnce()).handshakeExtension(webSocketExtensionDataMatcher("main")); + verify(mainHandshakerMock, atLeastOnce()).handshakeExtension(webSocketExtensionDataMatcher("fallback")); + verify(fallbackHandshakerMock, atLeastOnce()).handshakeExtension(webSocketExtensionDataMatcher("fallback")); + + verify(mainExtensionMock, atLeastOnce()).rsv(); + verify(mainExtensionMock).newReponseData(); + 
verify(mainExtensionMock).newExtensionEncoder(); + verify(mainExtensionMock).newExtensionDecoder(); + verify(fallbackExtensionMock, atLeastOnce()).rsv(); + } + + @Test + public void testCompatibleExtensionTogetherSuccess() { + // initialize + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))). + thenReturn(mainExtensionMock); + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("fallback"))). + thenReturn(null); + + when(fallbackHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("fallback"))). + thenReturn(fallbackExtensionMock); + when(fallbackHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))). + thenReturn(null); + + when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); + when(mainExtensionMock.newReponseData()).thenReturn( + new WebSocketExtensionData("main", Collections.emptyMap())); + when(mainExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder()); + when(mainExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder()); + + when(fallbackExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV2); + when(fallbackExtensionMock.newReponseData()).thenReturn( + new WebSocketExtensionData("fallback", Collections.emptyMap())); + when(fallbackExtensionMock.newExtensionEncoder()).thenReturn(new Dummy2Encoder()); + when(fallbackExtensionMock.newExtensionDecoder()).thenReturn(new Dummy2Decoder()); + + // execute + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); + + HttpRequest req = newUpgradeRequest("main, fallback"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List resExts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + // test + 
assertNull(ch.pipeline().context(extensionHandler)); + assertEquals(2, resExts.size()); + assertEquals("main", resExts.get(0).name()); + assertEquals("fallback", resExts.get(1).name()); + assertNotNull(ch.pipeline().get(DummyDecoder.class)); + assertNotNull(ch.pipeline().get(DummyEncoder.class)); + assertNotNull(ch.pipeline().get(Dummy2Decoder.class)); + assertNotNull(ch.pipeline().get(Dummy2Encoder.class)); + + verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("main")); + verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("fallback")); + verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("fallback")); + verify(mainExtensionMock, times(2)).rsv(); + verify(mainExtensionMock).newReponseData(); + verify(mainExtensionMock).newExtensionEncoder(); + verify(mainExtensionMock).newExtensionDecoder(); + + verify(fallbackExtensionMock, times(2)).rsv(); + + verify(fallbackExtensionMock).newReponseData(); + verify(fallbackExtensionMock).newExtensionEncoder(); + verify(fallbackExtensionMock).newExtensionDecoder(); + } + + @Test + public void testNoneExtensionMatchingSuccess() { + // initialize + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("unknown"))). + thenReturn(null); + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("unknown2"))). + thenReturn(null); + + when(fallbackHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("unknown"))). + thenReturn(null); + when(fallbackHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("unknown2"))). 
+ thenReturn(null); + + // execute + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, fallbackHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); + + HttpRequest req = newUpgradeRequest("unknown, unknown2"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + + // test + assertNull(ch.pipeline().context(extensionHandler)); + assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown")); + verify(mainHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2")); + + verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown")); + verify(fallbackHandshakerMock).handshakeExtension(webSocketExtensionDataMatcher("unknown2")); + } + + @Test + public void testExtensionHandlerNotRemovedByFailureWritePromise() { + // initialize + when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))) + .thenReturn(mainExtensionMock); + when(mainExtensionMock.newReponseData()).thenReturn( + new WebSocketExtensionData("main", Collections.emptyMap())); + + // execute + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); + + HttpRequest req = newUpgradeRequest("main"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ChannelPromise failurePromise = ch.newPromise(); + ch.writeOneOutbound(res, failurePromise); + failurePromise.setFailure(new IOException("Cannot write response")); + + // test + assertNull(ch.readOutbound()); + assertNotNull(ch.pipeline().context(extensionHandler)); + assertTrue(ch.finish()); + } + + @Test + public void testExtensionMultipleRequests() { + // initialize + 
when(mainHandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main"))) + .thenReturn(mainExtensionMock); + + when(mainExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); + when(mainExtensionMock.newReponseData()).thenReturn( + new WebSocketExtensionData("main", Collections.emptyMap())); + when(mainExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder()); + when(mainExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder()); + + when(main2HandshakerMock.handshakeExtension(webSocketExtensionDataMatcher("main2"))) + .thenReturn(main2ExtensionMock); + + when(main2ExtensionMock.rsv()).thenReturn(WebSocketExtension.RSV1); + when(main2ExtensionMock.newReponseData()).thenReturn( + new WebSocketExtensionData("main2", Collections.emptyMap())); + when(main2ExtensionMock.newExtensionEncoder()).thenReturn(new DummyEncoder()); + when(main2ExtensionMock.newExtensionDecoder()).thenReturn(new DummyDecoder()); + + // execute + WebSocketServerExtensionHandler extensionHandler = + new WebSocketServerExtensionHandler(mainHandshakerMock, main2HandshakerMock); + EmbeddedChannel ch = new EmbeddedChannel(extensionHandler); + + HttpRequest req = newUpgradeRequest("main"); + assertTrue(ch.writeInbound(req)); + assertTrue(ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + HttpRequest req2 = newUpgradeRequest("main2"); + assertTrue(ch.writeInbound(req2)); + assertTrue(ch.writeInbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + HttpResponse res = newUpgradeResponse(null); + assertTrue(ch.writeOutbound(res)); + assertTrue(ch.writeOutbound(LastHttpContent.EMPTY_LAST_CONTENT)); + + res = ch.readOutbound(); + assertEquals("main", res.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + LastHttpContent content = ch.readOutbound(); + content.release(); + + assertNull(ch.pipeline().context(extensionHandler)); + assertTrue(ch.finishAndReleaseAll()); + } +} diff --git 
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshakerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshakerTest.java new file mode 100644 index 0000000..88c3ecb --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameClientExtensionHandshakerTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import static io.netty.handler.codec.http.websocketx.extensions.compression. 
+ DeflateFrameServerExtensionHandshaker.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +public class DeflateFrameClientExtensionHandshakerTest { + + @Test + public void testWebkitDeflateFrameData() { + DeflateFrameClientExtensionHandshaker handshaker = + new DeflateFrameClientExtensionHandshaker(true); + + WebSocketExtensionData data = handshaker.newRequestData(); + + assertEquals(X_WEBKIT_DEFLATE_FRAME_EXTENSION, data.name()); + assertTrue(data.parameters().isEmpty()); + } + + @Test + public void testDeflateFrameData() { + DeflateFrameClientExtensionHandshaker handshaker = + new DeflateFrameClientExtensionHandshaker(false); + + WebSocketExtensionData data = handshaker.newRequestData(); + + assertEquals(DEFLATE_FRAME_EXTENSION, data.name()); + assertTrue(data.parameters().isEmpty()); + } + + @Test + public void testNormalHandshake() { + DeflateFrameClientExtensionHandshaker handshaker = + new DeflateFrameClientExtensionHandshaker(false); + + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(DEFLATE_FRAME_EXTENSION, Collections.emptyMap())); + + assertNotNull(extension); + assertEquals(WebSocketClientExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerFrameDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerFrameDeflateEncoder); + } + + @Test + public void testFailedHandshake() { + // initialize + DeflateFrameClientExtensionHandshaker handshaker = + new 
DeflateFrameClientExtensionHandshaker(false); + + Map parameters = new HashMap(); + parameters.put("invalid", "12"); + + // execute + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(DEFLATE_FRAME_EXTENSION, parameters)); + + // test + assertNull(extension); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshakerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshakerTest.java new file mode 100644 index 0000000..ecc6dcd --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/DeflateFrameServerExtensionHandshakerTest.java @@ -0,0 +1,88 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import static io.netty.handler.codec.http.websocketx.extensions.compression. 
+ DeflateFrameServerExtensionHandshaker.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +public class DeflateFrameServerExtensionHandshakerTest { + + @Test + public void testNormalHandshake() { + // initialize + DeflateFrameServerExtensionHandshaker handshaker = + new DeflateFrameServerExtensionHandshaker(); + + // execute + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(DEFLATE_FRAME_EXTENSION, Collections.emptyMap())); + + // test + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerFrameDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerFrameDeflateEncoder); + } + + @Test + public void testWebkitHandshake() { + // initialize + DeflateFrameServerExtensionHandshaker handshaker = + new DeflateFrameServerExtensionHandshaker(); + + // execute + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(X_WEBKIT_DEFLATE_FRAME_EXTENSION, Collections.emptyMap())); + + // test + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerFrameDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerFrameDeflateEncoder); + } + + @Test + public void testFailedHandshake() { + // initialize + DeflateFrameServerExtensionHandshaker handshaker = + new 
DeflateFrameServerExtensionHandshaker(); + + Map parameters; + parameters = new HashMap(); + parameters.put("unknown", "11"); + + // execute + WebSocketServerExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(DEFLATE_FRAME_EXTENSION, parameters)); + + // test + assertNull(extension); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoderTest.java new file mode 100644 index 0000000..121fd20 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateDecoderTest.java @@ -0,0 +1,155 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension.*; +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PerFrameDeflateDecoderTest { + + private static final Random random = new Random(); + + @Test + public void testCompressedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false)); + + // initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload))); + ByteBuf compressedPayload = encoderChannel.readOutbound(); + + BinaryWebSocketFrame compressedFrame = new BinaryWebSocketFrame(true, + RSV1 | RSV3, + compressedPayload.slice(0, compressedPayload.readableBytes() - 4)); + + // execute + assertTrue(decoderChannel.writeInbound(compressedFrame)); + BinaryWebSocketFrame uncompressedFrame = decoderChannel.readInbound(); + + // test + assertNotNull(uncompressedFrame); + 
assertNotNull(uncompressedFrame.content()); + assertEquals(RSV3, uncompressedFrame.rsv()); + assertEquals(300, uncompressedFrame.content().readableBytes()); + + byte[] finalPayload = new byte[300]; + uncompressedFrame.content().readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + uncompressedFrame.release(); + } + + @Test + public void testNormalFrame() { + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false)); + + // initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(true, + RSV3, Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(decoderChannel.writeInbound(frame)); + BinaryWebSocketFrame newFrame = decoderChannel.readInbound(); + + // test + assertNotNull(newFrame); + assertNotNull(newFrame.content()); + assertEquals(RSV3, newFrame.rsv()); + assertEquals(300, newFrame.content().readableBytes()); + + byte[] finalPayload = new byte[300]; + newFrame.content().readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + newFrame.release(); + } + + // See https://github.com/netty/netty/issues/4348 + @Test + public void testCompressedEmptyFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false)); + + assertTrue(encoderChannel.writeOutbound(Unpooled.EMPTY_BUFFER)); + ByteBuf compressedPayload = encoderChannel.readOutbound(); + BinaryWebSocketFrame compressedFrame = + new BinaryWebSocketFrame(true, RSV1 | RSV3, compressedPayload); + + // execute + assertTrue(decoderChannel.writeInbound(compressedFrame)); + BinaryWebSocketFrame uncompressedFrame = decoderChannel.readInbound(); + + // test + assertNotNull(uncompressedFrame); + assertNotNull(uncompressedFrame.content()); + assertEquals(RSV3, uncompressedFrame.rsv()); + assertEquals(0, 
uncompressedFrame.content().readableBytes()); + uncompressedFrame.release(); + } + + @Test + public void testDecompressionSkip() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerFrameDeflateDecoder(false, ALWAYS_SKIP)); + + byte[] payload = new byte[300]; + random.nextBytes(payload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload))); + ByteBuf compressedPayload = encoderChannel.readOutbound(); + + BinaryWebSocketFrame compressedBinaryFrame = new BinaryWebSocketFrame( + true, WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedPayload); + + assertTrue(decoderChannel.writeInbound(compressedBinaryFrame)); + + BinaryWebSocketFrame inboundBinaryFrame = decoderChannel.readInbound(); + + assertNotNull(inboundBinaryFrame); + assertNotNull(inboundBinaryFrame.content()); + assertEquals(compressedPayload, inboundBinaryFrame.content()); + assertEquals(5, inboundBinaryFrame.rsv()); + + assertTrue(inboundBinaryFrame.release()); + + assertTrue(encoderChannel.finishAndReleaseAll()); + assertFalse(decoderChannel.finish()); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoderTest.java new file mode 100644 index 0000000..ceba96f --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerFrameDeflateEncoderTest.java @@ -0,0 +1,189 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PerFrameDeflateEncoderTest { + + private static final Random random = new Random(); + + @Test + public void testCompressedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerFrameDeflateEncoder(9, 15, false)); + EmbeddedChannel decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibDecoder(ZlibWrapper.NONE)); + + // initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV3, 
Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(encoderChannel.writeOutbound(frame)); + BinaryWebSocketFrame compressedFrame = encoderChannel.readOutbound(); + + // test + assertNotNull(compressedFrame); + assertNotNull(compressedFrame.content()); + assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame.rsv()); + + assertTrue(decoderChannel.writeInbound(compressedFrame.content())); + assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate())); + ByteBuf uncompressedPayload = decoderChannel.readInbound(); + assertEquals(300, uncompressedPayload.readableBytes()); + + byte[] finalPayload = new byte[300]; + uncompressedPayload.readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + uncompressedPayload.release(); + } + + @Test + public void testAlreadyCompressedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerFrameDeflateEncoder(9, 15, false)); + + // initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV3 | WebSocketExtension.RSV1, Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(encoderChannel.writeOutbound(frame)); + BinaryWebSocketFrame newFrame = encoderChannel.readOutbound(); + + // test + assertNotNull(newFrame); + assertNotNull(newFrame.content()); + assertEquals(WebSocketExtension.RSV3 | WebSocketExtension.RSV1, newFrame.rsv()); + assertEquals(300, newFrame.content().readableBytes()); + + byte[] finalPayload = new byte[300]; + newFrame.content().readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + newFrame.release(); + } + + @Test + public void testFramementedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerFrameDeflateEncoder(9, 15, false)); + EmbeddedChannel decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibDecoder(ZlibWrapper.NONE)); + + // initialize + byte[] payload1 = new byte[100]; + 
random.nextBytes(payload1); + byte[] payload2 = new byte[100]; + random.nextBytes(payload2); + byte[] payload3 = new byte[100]; + random.nextBytes(payload3); + + BinaryWebSocketFrame frame1 = new BinaryWebSocketFrame(false, + WebSocketExtension.RSV3, Unpooled.wrappedBuffer(payload1)); + ContinuationWebSocketFrame frame2 = new ContinuationWebSocketFrame(false, + WebSocketExtension.RSV3, Unpooled.wrappedBuffer(payload2)); + ContinuationWebSocketFrame frame3 = new ContinuationWebSocketFrame(true, + WebSocketExtension.RSV3, Unpooled.wrappedBuffer(payload3)); + + // execute + assertTrue(encoderChannel.writeOutbound(frame1)); + assertTrue(encoderChannel.writeOutbound(frame2)); + assertTrue(encoderChannel.writeOutbound(frame3)); + BinaryWebSocketFrame compressedFrame1 = encoderChannel.readOutbound(); + ContinuationWebSocketFrame compressedFrame2 = encoderChannel.readOutbound(); + ContinuationWebSocketFrame compressedFrame3 = encoderChannel.readOutbound(); + + // test + assertNotNull(compressedFrame1); + assertNotNull(compressedFrame2); + assertNotNull(compressedFrame3); + assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame1.rsv()); + assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame2.rsv()); + assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame3.rsv()); + assertFalse(compressedFrame1.isFinalFragment()); + assertFalse(compressedFrame2.isFinalFragment()); + assertTrue(compressedFrame3.isFinalFragment()); + + assertTrue(decoderChannel.writeInbound(compressedFrame1.content())); + assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate())); + ByteBuf uncompressedPayload1 = decoderChannel.readInbound(); + byte[] finalPayload1 = new byte[100]; + uncompressedPayload1.readBytes(finalPayload1); + assertArrayEquals(finalPayload1, payload1); + uncompressedPayload1.release(); + + assertTrue(decoderChannel.writeInbound(compressedFrame2.content())); + 
assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate())); + ByteBuf uncompressedPayload2 = decoderChannel.readInbound(); + byte[] finalPayload2 = new byte[100]; + uncompressedPayload2.readBytes(finalPayload2); + assertArrayEquals(finalPayload2, payload2); + uncompressedPayload2.release(); + + assertTrue(decoderChannel.writeInbound(compressedFrame3.content())); + assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate())); + ByteBuf uncompressedPayload3 = decoderChannel.readInbound(); + byte[] finalPayload3 = new byte[100]; + uncompressedPayload3.readBytes(finalPayload3); + assertArrayEquals(finalPayload3, payload3); + uncompressedPayload3.release(); + } + + @Test + public void testCompressionSkip() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + new PerFrameDeflateEncoder(9, 15, false, ALWAYS_SKIP)); + byte[] payload = new byte[300]; + random.nextBytes(payload); + BinaryWebSocketFrame binaryFrame = new BinaryWebSocketFrame(true, + 0, Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(encoderChannel.writeOutbound(binaryFrame.copy())); + BinaryWebSocketFrame outboundFrame = encoderChannel.readOutbound(); + + // test + assertNotNull(outboundFrame); + assertNotNull(outboundFrame.content()); + assertArrayEquals(payload, ByteBufUtil.getBytes(outboundFrame.content())); + assertEquals(0, outboundFrame.rsv()); + assertTrue(outboundFrame.release()); + + assertFalse(encoderChannel.finish()); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java new file mode 100644 index 0000000..d649d21 --- /dev/null +++ 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateClientExtensionHandshakerTest.java @@ -0,0 +1,248 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension.RSV1; +import static io.netty.handler.codec.http.websocketx.extensions.compression. 
+ PerMessageDeflateServerExtensionHandshaker.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketClientExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +public class PerMessageDeflateClientExtensionHandshakerTest { + + @Test + public void testNormalData() { + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(); + + WebSocketExtensionData data = handshaker.newRequestData(); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertEquals(ZlibCodecFactory.isSupportingWindowSizeAndMemLevel() ? 1 : 0, data.parameters().size()); +
} + + @Test + public void testCustomData() { + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 10, true, true); + + WebSocketExtensionData data = handshaker.newRequestData(); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertTrue(data.parameters().containsKey(CLIENT_MAX_WINDOW)); + assertTrue(data.parameters().containsKey(SERVER_MAX_WINDOW)); + assertEquals("10", data.parameters().get(SERVER_MAX_WINDOW)); + assertTrue(data.parameters().containsKey(CLIENT_NO_CONTEXT)); + assertTrue(data.parameters().containsKey(SERVER_NO_CONTEXT)); + } + + @Test + public void testNormalHandshake() { + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(); + + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, Collections.emptyMap())); + + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + } + + @Test + public void testCustomHandshake() { + WebSocketClientExtension extension; + Map parameters; + + // initialize + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 10, true, true); + + parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "12"); + parameters.put(SERVER_MAX_WINDOW, "8"); + parameters.put(CLIENT_NO_CONTEXT, null); + parameters.put(SERVER_NO_CONTEXT, null); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // test + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + 
assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + // initialize + parameters = new HashMap(); + parameters.put(SERVER_MAX_WINDOW, "10"); + parameters.put(SERVER_NO_CONTEXT, null); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // test + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + // initialize + parameters = new HashMap(); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // test + assertNull(extension); + } + + @Test + public void testParameterValidation() { + WebSocketClientExtension extension; + Map parameters; + + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, false); + + parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "15"); + parameters.put(SERVER_MAX_WINDOW, "8"); + extension = handshaker.handshakeExtension(new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Test that handshake succeeds when parameters are valid + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, "15"); + parameters.put(SERVER_MAX_WINDOW, "7"); + + extension = handshaker.handshakeExtension(new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Test that handshake fails when parameters are invalid + assertNull(extension); + } + + @Test + public void testServerNoContextTakeover() { + WebSocketClientExtension 
extension; + Map parameters; + + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, false); + + parameters = new HashMap(); + parameters.put(SERVER_NO_CONTEXT, null); + extension = handshaker.handshakeExtension(new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Test that handshake succeeds when server responds with `server_no_context_takeover` that we didn't offer + assertNotNull(extension); + assertEquals(RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + // initialize + handshaker = new PerMessageDeflateClientExtensionHandshaker(6, true, 15, true, true); + + parameters = new HashMap(); + extension = handshaker.handshakeExtension(new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // Test that handshake fails when client offers `server_no_context_takeover` but server doesn't support it + assertNull(extension); + } + + @Test + public void testDecoderNoClientContext() { + PerMessageDeflateClientExtensionHandshaker handshaker = + new PerMessageDeflateClientExtensionHandshaker(6, true, MAX_WINDOW_SIZE, true, false); + + byte[] firstPayload = new byte[] { + 76, -50, -53, 10, -62, 48, 20, 4, -48, 95, 41, 89, -37, 36, 77, 90, 31, -39, 41, -72, 112, 33, -120, 20, + 20, 119, -79, 70, 123, -95, 121, -48, 92, -116, 80, -6, -17, -58, -99, -37, -31, 12, 51, 19, 1, -9, -12, + 68, -111, -117, 25, 58, 111, 77, -127, -66, -64, -34, 20, 59, -64, -29, -2, 90, -100, -115, 30, 16, 114, + -68, 61, 29, 40, 89, -112, -73, 25, 35, 120, -105, -67, -32, -43, -70, -84, 120, -55, 69, 43, -124, 106, + -92, 18, -110, 114, -50, 111, 25, -3, 10, 17, -75, 13, 127, -84, 106, 90, -66, 84, -75, 84, 53, -89, + -75, 92, -3, -40, -61, 119, 49, -117, 30, 49, 68, -59, 88, 74, -119, -34, 1, -83, -7, -48, 124, -124, + -23, 16, 88, -118, 121, 54, 
-53, 1, 44, 32, 81, 19, 25, -115, -43, -32, -64, -67, -120, -110, -101, 121, + -2, 2 + }; + + byte[] secondPayload = new byte[] { + -86, 86, 42, 46, 77, 78, 78, 45, 6, 26, 83, 82, 84, -102, -86, 3, -28, 38, 21, 39, 23, 101, 38, -91, 2, + -51, -51, 47, 74, 73, 45, 114, -54, -49, -49, -10, 49, -78, -118, 112, 10, 9, 13, 118, 1, -102, 84, + -108, 90, 88, 10, 116, 27, -56, -84, 124, -112, -13, 16, 26, 116, -108, 18, -117, -46, -127, 6, 69, 99, + -45, 24, 91, 91, 11, 0 + }; + + Map parameters = Collections.singletonMap(CLIENT_NO_CONTEXT, null); + + WebSocketClientExtension extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + assertNotNull(extension); + + EmbeddedChannel decoderChannel = new EmbeddedChannel(extension.newExtensionDecoder()); + assertTrue( + decoderChannel.writeInbound(new TextWebSocketFrame(true, RSV1, Unpooled.copiedBuffer(firstPayload)))); + TextWebSocketFrame firstFrameDecompressed = decoderChannel.readInbound(); + assertTrue( + decoderChannel.writeInbound(new TextWebSocketFrame(true, RSV1, Unpooled.copiedBuffer(secondPayload)))); + TextWebSocketFrame secondFrameDecompressed = decoderChannel.readInbound(); + + assertNotNull(firstFrameDecompressed); + assertNotNull(firstFrameDecompressed.content()); + assertTrue(firstFrameDecompressed instanceof TextWebSocketFrame); + assertEquals(firstFrameDecompressed.text(), + "{\"info\":\"Welcome to the BitMEX Realtime API.\",\"version\"" + + ":\"2018-10-02T22:53:23.000Z\",\"timestamp\":\"2018-10-15T06:43:40.437Z\"," + + "\"docs\":\"https://www.bitmex.com/app/wsAPI\",\"limit\":{\"remaining\":39}}"); + assertTrue(firstFrameDecompressed.release()); + + assertNotNull(secondFrameDecompressed); + assertNotNull(secondFrameDecompressed.content()); + assertTrue(secondFrameDecompressed instanceof TextWebSocketFrame); + assertEquals(secondFrameDecompressed.text(), + "{\"success\":true,\"subscribe\":\"orderBookL2:XBTUSD\"," + + 
"\"request\":{\"op\":\"subscribe\",\"args\":[\"orderBookL2:XBTUSD\"]}}"); + assertTrue(secondFrameDecompressed.release()); + + assertFalse(decoderChannel.finish()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateDecoderTest.java new file mode 100644 index 0000000..667a580 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateDecoderTest.java @@ -0,0 +1,400 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Random; + +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter.*; +import static io.netty.handler.codec.http.websocketx.extensions.compression.DeflateDecoder.*; +import static io.netty.util.CharsetUtil.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PerMessageDeflateDecoderTest { + + private static final Random random = new Random(); + + @Test + public void testCompressedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + // initialize + byte[] payload = new byte[300]; + 
random.nextBytes(payload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload))); + ByteBuf compressedPayload = encoderChannel.readOutbound(); + + BinaryWebSocketFrame compressedFrame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV1 | WebSocketExtension.RSV3, + compressedPayload.slice(0, compressedPayload.readableBytes() - 4)); + + // execute + assertTrue(decoderChannel.writeInbound(compressedFrame)); + BinaryWebSocketFrame uncompressedFrame = decoderChannel.readInbound(); + + // test + assertNotNull(uncompressedFrame); + assertNotNull(uncompressedFrame.content()); + assertEquals(WebSocketExtension.RSV3, uncompressedFrame.rsv()); + assertEquals(300, uncompressedFrame.content().readableBytes()); + + byte[] finalPayload = new byte[300]; + uncompressedFrame.content().readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + uncompressedFrame.release(); + } + + @Test + public void testNormalFrame() { + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + // initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV3, Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(decoderChannel.writeInbound(frame)); + BinaryWebSocketFrame newFrame = decoderChannel.readInbound(); + + // test + assertNotNull(newFrame); + assertNotNull(newFrame.content()); + assertEquals(WebSocketExtension.RSV3, newFrame.rsv()); + assertEquals(300, newFrame.content().readableBytes()); + + byte[] finalPayload = new byte[300]; + newFrame.content().readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + newFrame.release(); + } + + @Test + public void testFragmentedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + // 
initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload))); + ByteBuf compressedPayload = encoderChannel.readOutbound(); + compressedPayload = compressedPayload.slice(0, compressedPayload.readableBytes() - 4); + + int oneThird = compressedPayload.readableBytes() / 3; + BinaryWebSocketFrame compressedFrame1 = new BinaryWebSocketFrame(false, + WebSocketExtension.RSV1 | WebSocketExtension.RSV3, + compressedPayload.slice(0, oneThird)); + ContinuationWebSocketFrame compressedFrame2 = new ContinuationWebSocketFrame(false, + WebSocketExtension.RSV3, compressedPayload.slice(oneThird, oneThird)); + ContinuationWebSocketFrame compressedFrame3 = new ContinuationWebSocketFrame(true, + WebSocketExtension.RSV3, compressedPayload.slice(oneThird * 2, + compressedPayload.readableBytes() - oneThird * 2)); + + // execute + assertTrue(decoderChannel.writeInbound(compressedFrame1.retain())); + assertTrue(decoderChannel.writeInbound(compressedFrame2.retain())); + assertTrue(decoderChannel.writeInbound(compressedFrame3)); + BinaryWebSocketFrame uncompressedFrame1 = decoderChannel.readInbound(); + ContinuationWebSocketFrame uncompressedFrame2 = decoderChannel.readInbound(); + ContinuationWebSocketFrame uncompressedFrame3 = decoderChannel.readInbound(); + + // test + assertNotNull(uncompressedFrame1); + assertNotNull(uncompressedFrame2); + assertNotNull(uncompressedFrame3); + assertEquals(WebSocketExtension.RSV3, uncompressedFrame1.rsv()); + assertEquals(WebSocketExtension.RSV3, uncompressedFrame2.rsv()); + assertEquals(WebSocketExtension.RSV3, uncompressedFrame3.rsv()); + + ByteBuf finalPayloadWrapped = Unpooled.wrappedBuffer(uncompressedFrame1.content(), + uncompressedFrame2.content(), uncompressedFrame3.content()); + assertEquals(300, finalPayloadWrapped.readableBytes()); + + byte[] finalPayload = new byte[300]; + finalPayloadWrapped.readBytes(finalPayload); + assertArrayEquals(finalPayload, 
payload); + finalPayloadWrapped.release(); + } + + @Test + public void testMultiCompressedPayloadWithinFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + // initialize + byte[] payload1 = new byte[100]; + random.nextBytes(payload1); + byte[] payload2 = new byte[100]; + random.nextBytes(payload2); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload1))); + ByteBuf compressedPayload1 = encoderChannel.readOutbound(); + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload2))); + ByteBuf compressedPayload2 = encoderChannel.readOutbound(); + + BinaryWebSocketFrame compressedFrame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV1 | WebSocketExtension.RSV3, + Unpooled.wrappedBuffer( + compressedPayload1, + compressedPayload2.slice(0, compressedPayload2.readableBytes() - 4))); + + // execute + assertTrue(decoderChannel.writeInbound(compressedFrame)); + BinaryWebSocketFrame uncompressedFrame = decoderChannel.readInbound(); + + // test + assertNotNull(uncompressedFrame); + assertNotNull(uncompressedFrame.content()); + assertEquals(WebSocketExtension.RSV3, uncompressedFrame.rsv()); + assertEquals(200, uncompressedFrame.content().readableBytes()); + + byte[] finalPayload1 = new byte[100]; + uncompressedFrame.content().readBytes(finalPayload1); + assertArrayEquals(finalPayload1, payload1); + byte[] finalPayload2 = new byte[100]; + uncompressedFrame.content().readBytes(finalPayload2); + assertArrayEquals(finalPayload2, payload2); + uncompressedFrame.release(); + } + + @Test + public void testDecompressionSkipForBinaryFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false, ALWAYS_SKIP)); + + byte[] 
payload = new byte[300]; + random.nextBytes(payload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(payload))); + ByteBuf compressedPayload = encoderChannel.readOutbound(); + + BinaryWebSocketFrame compressedBinaryFrame = new BinaryWebSocketFrame(true, WebSocketExtension.RSV1, + compressedPayload); + assertTrue(decoderChannel.writeInbound(compressedBinaryFrame)); + + WebSocketFrame inboundFrame = decoderChannel.readInbound(); + + assertEquals(WebSocketExtension.RSV1, inboundFrame.rsv()); + assertEquals(compressedPayload, inboundFrame.content()); + assertTrue(inboundFrame.release()); + + assertTrue(encoderChannel.finishAndReleaseAll()); + assertFalse(decoderChannel.finish()); + } + + @Test + public void testSelectivityDecompressionSkip() { + WebSocketExtensionFilter selectivityDecompressionFilter = new WebSocketExtensionFilter() { + @Override + public boolean mustSkip(WebSocketFrame frame) { + return frame instanceof TextWebSocketFrame && frame.content().readableBytes() < 100; + } + }; + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel( + new PerMessageDeflateDecoder(false, selectivityDecompressionFilter)); + + String textPayload = "compressed payload"; + byte[] binaryPayload = new byte[300]; + random.nextBytes(binaryPayload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(textPayload.getBytes(UTF_8)))); + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(binaryPayload))); + ByteBuf compressedTextPayload = encoderChannel.readOutbound(); + ByteBuf compressedBinaryPayload = encoderChannel.readOutbound(); + + TextWebSocketFrame compressedTextFrame = new TextWebSocketFrame(true, WebSocketExtension.RSV1, + compressedTextPayload); + BinaryWebSocketFrame compressedBinaryFrame = new BinaryWebSocketFrame(true, WebSocketExtension.RSV1, + compressedBinaryPayload); + + 
assertTrue(decoderChannel.writeInbound(compressedTextFrame)); + assertTrue(decoderChannel.writeInbound(compressedBinaryFrame)); + + TextWebSocketFrame inboundTextFrame = decoderChannel.readInbound(); + BinaryWebSocketFrame inboundBinaryFrame = decoderChannel.readInbound(); + + assertEquals(WebSocketExtension.RSV1, inboundTextFrame.rsv()); + assertEquals(compressedTextPayload, inboundTextFrame.content()); + assertTrue(inboundTextFrame.release()); + + assertEquals(0, inboundBinaryFrame.rsv()); + assertArrayEquals(binaryPayload, ByteBufUtil.getBytes(inboundBinaryFrame.content())); + assertTrue(inboundBinaryFrame.release()); + + assertTrue(encoderChannel.finishAndReleaseAll()); + assertFalse(decoderChannel.finish()); + } + + @Test + public void testIllegalStateWhenDecompressionInProgress() { + WebSocketExtensionFilter selectivityDecompressionFilter = new WebSocketExtensionFilter() { + @Override + public boolean mustSkip(WebSocketFrame frame) { + return frame.content().readableBytes() < 100; + } + }; + + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 9, 15, 8)); + final EmbeddedChannel decoderChannel = new EmbeddedChannel( + new PerMessageDeflateDecoder(false, selectivityDecompressionFilter)); + + byte[] firstPayload = new byte[200]; + random.nextBytes(firstPayload); + + byte[] finalPayload = new byte[50]; + random.nextBytes(finalPayload); + + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(firstPayload))); + assertTrue(encoderChannel.writeOutbound(Unpooled.wrappedBuffer(finalPayload))); + ByteBuf compressedFirstPayload = encoderChannel.readOutbound(); + ByteBuf compressedFinalPayload = encoderChannel.readOutbound(); + assertTrue(encoderChannel.finishAndReleaseAll()); + + BinaryWebSocketFrame firstPart = new BinaryWebSocketFrame(false, WebSocketExtension.RSV1, + compressedFirstPayload); + final ContinuationWebSocketFrame finalPart = new ContinuationWebSocketFrame(true, WebSocketExtension.RSV1, + 
compressedFinalPayload); + assertTrue(decoderChannel.writeInbound(firstPart)); + + BinaryWebSocketFrame outboundFirstPart = decoderChannel.readInbound(); + //first part is decompressed + assertEquals(0, outboundFirstPart.rsv()); + assertArrayEquals(firstPayload, ByteBufUtil.getBytes(outboundFirstPart.content())); + assertTrue(outboundFirstPart.release()); + + //final part throwing exception + try { + assertThrows(DecoderException.class, new Executable() { + @Override + public void execute() { + decoderChannel.writeInbound(finalPart); + } + }); + } finally { + assertTrue(finalPart.release()); + assertFalse(encoderChannel.finishAndReleaseAll()); + } + } + + @Test + public void testEmptyFrameDecompression() { + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + TextWebSocketFrame emptyDeflateBlockFrame = new TextWebSocketFrame(true, WebSocketExtension.RSV1, + EMPTY_DEFLATE_BLOCK); + + assertTrue(decoderChannel.writeInbound(emptyDeflateBlockFrame)); + TextWebSocketFrame emptyBufferFrame = decoderChannel.readInbound(); + + assertFalse(emptyBufferFrame.content().isReadable()); + + // Composite empty buffer + assertTrue(emptyBufferFrame.release()); + assertFalse(decoderChannel.finish()); + } + + @Test + public void testFragmentedFrameWithLeftOverInLastFragment() { + String hexDump = "677170647a777a737574656b707a787a6f6a7561756578756f6b7868616371716c657a6d64697479766d726f6" + + "269746c6376777464776f6f72767a726f64667278676764687775786f6762766d776d706b76697773777a7072" + + "6a6a737279707a7078697a6c69616d7461656d646278626d786f66666e686e776a7a7461746d7a776668776b6" + + "f6f736e73746575637a6d727a7175707a6e74627578687871767771697a71766c64626d78726d6d7675756877" + + "62667963626b687a726d676e646263776e67797264706d6c6863626577616967706a78636a72697464756e627" + + "977616f79736475676f76736f7178746a7a7479626c64636b6b6778637768746c62"; + EmbeddedChannel encoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibEncoder(ZlibWrapper.NONE, 
9, 15, 8)); + EmbeddedChannel decoderChannel = new EmbeddedChannel(new PerMessageDeflateDecoder(false)); + + ByteBuf originPayload = Unpooled.wrappedBuffer(ByteBufUtil.decodeHexDump(hexDump)); + assertTrue(encoderChannel.writeOutbound(originPayload.duplicate().retain())); + + ByteBuf compressedPayload = encoderChannel.readOutbound(); + compressedPayload = compressedPayload.slice(0, compressedPayload.readableBytes() - 4); + + int oneThird = compressedPayload.readableBytes() / 3; + + TextWebSocketFrame compressedFrame1 = new TextWebSocketFrame( + false, WebSocketExtension.RSV1, compressedPayload.slice(0, oneThird)); + ContinuationWebSocketFrame compressedFrame2 = new ContinuationWebSocketFrame( + false, WebSocketExtension.RSV3, compressedPayload.slice(oneThird, oneThird)); + ContinuationWebSocketFrame compressedFrame3 = new ContinuationWebSocketFrame( + false, WebSocketExtension.RSV3, compressedPayload.slice(oneThird * 2, oneThird)); + int offset = oneThird * 3; + ContinuationWebSocketFrame compressedFrameWithExtraData = new ContinuationWebSocketFrame( + true, WebSocketExtension.RSV3, compressedPayload.slice(offset, + compressedPayload.readableBytes() - offset)); + + // check that last fragment contains only one extra byte + assertEquals(1, compressedFrameWithExtraData.content().readableBytes()); + assertEquals(1, compressedFrameWithExtraData.content().getByte(0)); + + // write compressed frames + assertTrue(decoderChannel.writeInbound(compressedFrame1.retain())); + assertTrue(decoderChannel.writeInbound(compressedFrame2.retain())); + assertTrue(decoderChannel.writeInbound(compressedFrame3.retain())); + assertTrue(decoderChannel.writeInbound(compressedFrameWithExtraData)); + + // read uncompressed frames + TextWebSocketFrame uncompressedFrame1 = decoderChannel.readInbound(); + ContinuationWebSocketFrame uncompressedFrame2 = decoderChannel.readInbound(); + ContinuationWebSocketFrame uncompressedFrame3 = decoderChannel.readInbound(); + ContinuationWebSocketFrame 
uncompressedExtraData = decoderChannel.readInbound(); + assertFalse(uncompressedExtraData.content().isReadable()); + + ByteBuf uncompressedPayload = Unpooled.wrappedBuffer(uncompressedFrame1.content(), uncompressedFrame2.content(), + uncompressedFrame3.content(), uncompressedExtraData.content()); + assertEquals(originPayload, uncompressedPayload); + + assertTrue(originPayload.release()); + assertTrue(uncompressedPayload.release()); + + assertTrue(encoderChannel.finishAndReleaseAll()); + assertFalse(decoderChannel.finish()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateEncoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateEncoderTest.java new file mode 100644 index 0000000..1cd8c94 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateEncoderTest.java @@ -0,0 +1,324 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.http.websocketx.BinaryWebSocketFrame; +import io.netty.handler.codec.http.websocketx.ContinuationWebSocketFrame; +import io.netty.handler.codec.http.websocketx.TextWebSocketFrame; +import io.netty.handler.codec.http.websocketx.WebSocketFrame; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Arrays; +import java.util.Random; + +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionFilter.*; +import static io.netty.handler.codec.http.websocketx.extensions.compression.DeflateDecoder.*; +import static io.netty.util.CharsetUtil.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class PerMessageDeflateEncoderTest { + + private static final Random random = new Random(); + + @Test + public void testCompressedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + EmbeddedChannel decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibDecoder(ZlibWrapper.NONE)); + + // initialize + byte[] payload = new byte[300]; + 
random.nextBytes(payload); + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV3, Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(encoderChannel.writeOutbound(frame)); + BinaryWebSocketFrame compressedFrame = encoderChannel.readOutbound(); + + // test + assertNotNull(compressedFrame); + assertNotNull(compressedFrame.content()); + assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame.rsv()); + + assertTrue(decoderChannel.writeInbound(compressedFrame.content())); + assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate())); + ByteBuf uncompressedPayload = decoderChannel.readInbound(); + assertEquals(300, uncompressedPayload.readableBytes()); + + byte[] finalPayload = new byte[300]; + uncompressedPayload.readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + uncompressedPayload.release(); + } + + @Test + public void testAlreadyCompressedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + + // initialize + byte[] payload = new byte[300]; + random.nextBytes(payload); + + BinaryWebSocketFrame frame = new BinaryWebSocketFrame(true, + WebSocketExtension.RSV3 | WebSocketExtension.RSV1, + Unpooled.wrappedBuffer(payload)); + + // execute + assertTrue(encoderChannel.writeOutbound(frame)); + BinaryWebSocketFrame newFrame = encoderChannel.readOutbound(); + + // test + assertNotNull(newFrame); + assertNotNull(newFrame.content()); + assertEquals(WebSocketExtension.RSV3 | WebSocketExtension.RSV1, newFrame.rsv()); + assertEquals(300, newFrame.content().readableBytes()); + + byte[] finalPayload = new byte[300]; + newFrame.content().readBytes(finalPayload); + assertArrayEquals(finalPayload, payload); + newFrame.release(); + } + + @Test + public void testFragmentedFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false, + NEVER_SKIP)); + EmbeddedChannel decoderChannel 
= new EmbeddedChannel( + ZlibCodecFactory.newZlibDecoder(ZlibWrapper.NONE)); + + // initialize + byte[] payload1 = new byte[100]; + random.nextBytes(payload1); + byte[] payload2 = new byte[100]; + random.nextBytes(payload2); + byte[] payload3 = new byte[100]; + random.nextBytes(payload3); + + BinaryWebSocketFrame frame1 = new BinaryWebSocketFrame(false, + WebSocketExtension.RSV3, + Unpooled.wrappedBuffer(payload1)); + ContinuationWebSocketFrame frame2 = new ContinuationWebSocketFrame(false, + WebSocketExtension.RSV3, + Unpooled.wrappedBuffer(payload2)); + ContinuationWebSocketFrame frame3 = new ContinuationWebSocketFrame(true, + WebSocketExtension.RSV3, + Unpooled.wrappedBuffer(payload3)); + + // execute + assertTrue(encoderChannel.writeOutbound(frame1)); + assertTrue(encoderChannel.writeOutbound(frame2)); + assertTrue(encoderChannel.writeOutbound(frame3)); + BinaryWebSocketFrame compressedFrame1 = encoderChannel.readOutbound(); + ContinuationWebSocketFrame compressedFrame2 = encoderChannel.readOutbound(); + ContinuationWebSocketFrame compressedFrame3 = encoderChannel.readOutbound(); + + // test + assertNotNull(compressedFrame1); + assertNotNull(compressedFrame2); + assertNotNull(compressedFrame3); + assertEquals(WebSocketExtension.RSV1 | WebSocketExtension.RSV3, compressedFrame1.rsv()); + assertEquals(WebSocketExtension.RSV3, compressedFrame2.rsv()); + assertEquals(WebSocketExtension.RSV3, compressedFrame3.rsv()); + assertFalse(compressedFrame1.isFinalFragment()); + assertFalse(compressedFrame2.isFinalFragment()); + assertTrue(compressedFrame3.isFinalFragment()); + + assertTrue(decoderChannel.writeInbound(compressedFrame1.content())); + ByteBuf uncompressedPayload1 = decoderChannel.readInbound(); + byte[] finalPayload1 = new byte[100]; + uncompressedPayload1.readBytes(finalPayload1); + assertArrayEquals(finalPayload1, payload1); + uncompressedPayload1.release(); + + assertTrue(decoderChannel.writeInbound(compressedFrame2.content())); + ByteBuf uncompressedPayload2 
= decoderChannel.readInbound(); + byte[] finalPayload2 = new byte[100]; + uncompressedPayload2.readBytes(finalPayload2); + assertArrayEquals(finalPayload2, payload2); + uncompressedPayload2.release(); + + assertTrue(decoderChannel.writeInbound(compressedFrame3.content())); + assertTrue(decoderChannel.writeInbound(DeflateDecoder.FRAME_TAIL.duplicate())); + ByteBuf uncompressedPayload3 = decoderChannel.readInbound(); + byte[] finalPayload3 = new byte[100]; + uncompressedPayload3.readBytes(finalPayload3); + assertArrayEquals(finalPayload3, payload3); + uncompressedPayload3.release(); + } + + @Test + public void testCompressionSkipForBinaryFrame() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false, + ALWAYS_SKIP)); + byte[] payload = new byte[300]; + random.nextBytes(payload); + + WebSocketFrame binaryFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(payload)); + + assertTrue(encoderChannel.writeOutbound(binaryFrame.copy())); + WebSocketFrame outboundFrame = encoderChannel.readOutbound(); + + assertEquals(0, outboundFrame.rsv()); + assertArrayEquals(payload, ByteBufUtil.getBytes(outboundFrame.content())); + assertTrue(outboundFrame.release()); + + assertFalse(encoderChannel.finish()); + } + + @Test + public void testSelectivityCompressionSkip() { + WebSocketExtensionFilter selectivityCompressionFilter = new WebSocketExtensionFilter() { + @Override + public boolean mustSkip(WebSocketFrame frame) { + return (frame instanceof TextWebSocketFrame || frame instanceof BinaryWebSocketFrame) + && frame.content().readableBytes() < 100; + } + }; + EmbeddedChannel encoderChannel = new EmbeddedChannel( + new PerMessageDeflateEncoder(9, 15, false, selectivityCompressionFilter)); + EmbeddedChannel decoderChannel = new EmbeddedChannel( + ZlibCodecFactory.newZlibDecoder(ZlibWrapper.NONE)); + + String textPayload = "not compressed payload"; + byte[] binaryPayload = new byte[101]; + random.nextBytes(binaryPayload); + + WebSocketFrame 
textFrame = new TextWebSocketFrame(textPayload); + BinaryWebSocketFrame binaryFrame = new BinaryWebSocketFrame(Unpooled.wrappedBuffer(binaryPayload)); + + assertTrue(encoderChannel.writeOutbound(textFrame)); + assertTrue(encoderChannel.writeOutbound(binaryFrame)); + + WebSocketFrame outboundTextFrame = encoderChannel.readOutbound(); + + //compression skipped for textFrame + assertEquals(0, outboundTextFrame.rsv()); + assertEquals(textPayload, outboundTextFrame.content().toString(UTF_8)); + assertTrue(outboundTextFrame.release()); + + WebSocketFrame outboundBinaryFrame = encoderChannel.readOutbound(); + + //compression not skipped for binaryFrame + assertEquals(WebSocketExtension.RSV1, outboundBinaryFrame.rsv()); + + assertTrue(decoderChannel.writeInbound(outboundBinaryFrame.content().retain())); + ByteBuf uncompressedBinaryPayload = decoderChannel.readInbound(); + + assertArrayEquals(binaryPayload, ByteBufUtil.getBytes(uncompressedBinaryPayload)); + + assertTrue(outboundBinaryFrame.release()); + assertTrue(uncompressedBinaryPayload.release()); + + assertFalse(encoderChannel.finish()); + assertFalse(decoderChannel.finish()); + } + + @Test + public void testIllegalStateWhenCompressionInProgress() { + WebSocketExtensionFilter selectivityCompressionFilter = new WebSocketExtensionFilter() { + @Override + public boolean mustSkip(WebSocketFrame frame) { + return frame.content().readableBytes() < 100; + } + }; + final EmbeddedChannel encoderChannel = new EmbeddedChannel( + new PerMessageDeflateEncoder(9, 15, false, selectivityCompressionFilter)); + + byte[] firstPayload = new byte[200]; + random.nextBytes(firstPayload); + + byte[] finalPayload = new byte[90]; + random.nextBytes(finalPayload); + + BinaryWebSocketFrame firstPart = new BinaryWebSocketFrame(false, 0, Unpooled.wrappedBuffer(firstPayload)); + final ContinuationWebSocketFrame finalPart = new ContinuationWebSocketFrame(true, 0, + Unpooled.wrappedBuffer(finalPayload)); + 
assertTrue(encoderChannel.writeOutbound(firstPart)); + + BinaryWebSocketFrame outboundFirstPart = encoderChannel.readOutbound(); + //first part is compressed + assertEquals(WebSocketExtension.RSV1, outboundFirstPart.rsv()); + assertFalse(Arrays.equals(firstPayload, ByteBufUtil.getBytes(outboundFirstPart.content()))); + assertTrue(outboundFirstPart.release()); + + //final part throwing exception + try { + assertThrows(EncoderException.class, new Executable() { + @Override + public void execute() throws Throwable { + encoderChannel.writeOutbound(finalPart); + } + }); + } finally { + assertTrue(finalPart.release()); + assertFalse(encoderChannel.finishAndReleaseAll()); + } + } + + @Test + public void testEmptyFrameCompression() { + EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + + TextWebSocketFrame emptyFrame = new TextWebSocketFrame(""); + + assertTrue(encoderChannel.writeOutbound(emptyFrame)); + TextWebSocketFrame emptyDeflateFrame = encoderChannel.readOutbound(); + + assertEquals(WebSocketExtension.RSV1, emptyDeflateFrame.rsv()); + assertTrue(ByteBufUtil.equals(EMPTY_DEFLATE_BLOCK, emptyDeflateFrame.content())); + // Unreleasable buffer + assertFalse(emptyDeflateFrame.release()); + + assertFalse(encoderChannel.finish()); + } + + @Test + public void testCodecExceptionForNotFinEmptyFrame() { + final EmbeddedChannel encoderChannel = new EmbeddedChannel(new PerMessageDeflateEncoder(9, 15, false)); + + final TextWebSocketFrame emptyNotFinFrame = new TextWebSocketFrame(false, 0, ""); + + try { + assertThrows(EncoderException.class, new Executable() { + @Override + public void execute() { + encoderChannel.writeOutbound(emptyNotFinFrame); + } + }); + } finally { + // EmptyByteBuf buffer + assertFalse(emptyNotFinFrame.release()); + assertFalse(encoderChannel.finish()); + } + } + +} diff --git 
a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java new file mode 100644 index 0000000..7f6a87a --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/PerMessageDeflateServerExtensionHandshakerTest.java @@ -0,0 +1,176 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import static io.netty.handler.codec.http.websocketx.extensions.compression. 
+ PerMessageDeflateServerExtensionHandshaker.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtension; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +public class PerMessageDeflateServerExtensionHandshakerTest { + + @Test + public void testNormalHandshake() { + WebSocketServerExtension extension; + WebSocketExtensionData data; + Map parameters; + + // initialize + PerMessageDeflateServerExtensionHandshaker handshaker = + new PerMessageDeflateServerExtensionHandshaker(); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, Collections.emptyMap())); + + // test + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + // execute + data = extension.newReponseData(); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertTrue(data.parameters().isEmpty()); + + // initialize + parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, null); + parameters.put(CLIENT_NO_CONTEXT, null); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, Collections.emptyMap())); + + // test + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() 
instanceof PerMessageDeflateEncoder); + + // execute + data = extension.newReponseData(); + + // test + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertTrue(data.parameters().isEmpty()); + + // initialize + parameters = new HashMap(); + parameters.put(SERVER_MAX_WINDOW, "12"); + parameters.put(SERVER_NO_CONTEXT, null); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // test + assertNull(extension); + } + + @Test + public void testCustomHandshake() { + WebSocketServerExtension extension; + Map parameters; + WebSocketExtensionData data; + + // initialize + PerMessageDeflateServerExtensionHandshaker handshaker = + new PerMessageDeflateServerExtensionHandshaker(6, true, 10, true, true); + + parameters = new HashMap(); + parameters.put(CLIENT_MAX_WINDOW, null); + parameters.put(SERVER_MAX_WINDOW, "12"); + parameters.put(CLIENT_NO_CONTEXT, null); + parameters.put(SERVER_NO_CONTEXT, null); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // test + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + // execute + data = extension.newReponseData(); + + // test + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertTrue(data.parameters().containsKey(CLIENT_MAX_WINDOW)); + assertEquals("10", data.parameters().get(CLIENT_MAX_WINDOW)); + assertTrue(data.parameters().containsKey(SERVER_MAX_WINDOW)); + assertEquals("12", data.parameters().get(SERVER_MAX_WINDOW)); + assertTrue(data.parameters().containsKey(CLIENT_MAX_WINDOW)); + assertTrue(data.parameters().containsKey(SERVER_MAX_WINDOW)); + + // initialize + parameters = new HashMap(); + parameters.put(SERVER_MAX_WINDOW, 
"12"); + parameters.put(SERVER_NO_CONTEXT, null); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + + // test + assertNotNull(extension); + assertEquals(WebSocketServerExtension.RSV1, extension.rsv()); + assertTrue(extension.newExtensionDecoder() instanceof PerMessageDeflateDecoder); + assertTrue(extension.newExtensionEncoder() instanceof PerMessageDeflateEncoder); + + // execute + data = extension.newReponseData(); + + // test + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertEquals(2, data.parameters().size()); + assertTrue(data.parameters().containsKey(SERVER_MAX_WINDOW)); + assertEquals("12", data.parameters().get(SERVER_MAX_WINDOW)); + assertTrue(data.parameters().containsKey(SERVER_NO_CONTEXT)); + + // initialize + parameters = new HashMap(); + + // execute + extension = handshaker.handshakeExtension( + new WebSocketExtensionData(PERMESSAGE_DEFLATE_EXTENSION, parameters)); + // test + assertNotNull(extension); + + // execute + data = extension.newReponseData(); + + // test + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, data.name()); + assertTrue(data.parameters().isEmpty()); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandlerTest.java new file mode 100644 index 0000000..ad3005f --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/http/websocketx/extensions/compression/WebSocketServerCompressionHandlerTest.java @@ -0,0 +1,201 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http.websocketx.extensions.compression; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionData; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionUtil; +import io.netty.handler.codec.http.websocketx.extensions.WebSocketServerExtensionHandler; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import static io.netty.handler.codec.http.websocketx.extensions.compression. 
+ PerMessageDeflateServerExtensionHandshaker.*; +import static io.netty.handler.codec.http.websocketx.extensions.WebSocketExtensionTestUtil.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class WebSocketServerCompressionHandlerTest { + + @Test + public void testNormalSuccess() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerCompressionHandler()); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List exts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, exts.get(0).name()); + assertTrue(exts.get(0).parameters().isEmpty()); + assertNotNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNotNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testClientWindowSizeSuccess() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, false, 10, false, false))); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION + "; " + CLIENT_MAX_WINDOW); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List exts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, exts.get(0).name()); + assertEquals("10", exts.get(0).parameters().get(CLIENT_MAX_WINDOW)); + 
assertNotNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNotNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testClientWindowSizeUnavailable() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, false, 10, false, false))); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List exts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, exts.get(0).name()); + assertTrue(exts.get(0).parameters().isEmpty()); + assertNotNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNotNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testServerWindowSizeSuccess() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, true, 15, false, false))); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION + "; " + SERVER_MAX_WINDOW + "=10"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List exts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, exts.get(0).name()); + assertEquals("10", exts.get(0).parameters().get(SERVER_MAX_WINDOW)); + assertNotNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNotNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testServerWindowSizeDisable() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + new PerMessageDeflateServerExtensionHandshaker(6, 
false, 15, false, false))); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION + "; " + SERVER_MAX_WINDOW + "=10"); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + + assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + assertNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testServerNoContext() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerCompressionHandler()); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION + "; " + SERVER_NO_CONTEXT); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + + assertFalse(res2.headers().contains(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + assertNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testClientNoContext() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerCompressionHandler()); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION + "; " + CLIENT_NO_CONTEXT); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List exts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, exts.get(0).name()); + assertTrue(exts.get(0).parameters().isEmpty()); + assertNotNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNotNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + + @Test + public void testServerWindowSizeDisableThenFallback() { + EmbeddedChannel ch = new EmbeddedChannel(new WebSocketServerExtensionHandler( + new 
PerMessageDeflateServerExtensionHandshaker(6, false, 15, false, false))); + + HttpRequest req = newUpgradeRequest(PERMESSAGE_DEFLATE_EXTENSION + "; " + SERVER_MAX_WINDOW + "=10, " + + PERMESSAGE_DEFLATE_EXTENSION); + ch.writeInbound(req); + + HttpResponse res = newUpgradeResponse(null); + ch.writeOutbound(res); + + HttpResponse res2 = ch.readOutbound(); + List exts = WebSocketExtensionUtil.extractExtensions( + res2.headers().get(HttpHeaderNames.SEC_WEBSOCKET_EXTENSIONS)); + + assertEquals(PERMESSAGE_DEFLATE_EXTENSION, exts.get(0).name()); + assertTrue(exts.get(0).parameters().isEmpty()); + assertNotNull(ch.pipeline().get(PerMessageDeflateDecoder.class)); + assertNotNull(ch.pipeline().get(PerMessageDeflateEncoder.class)); + } + +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspDecoderTest.java new file mode 100644 index 0000000..4894912 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspDecoderTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.rtsp; + +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpObjectAggregator; +import org.junit.jupiter.api.Test; + + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Test cases for RTSP decoder. + */ +public class RtspDecoderTest { + + /** + * There was a problem when an ANNOUNCE request was issued by the server, + * i.e. entered through the response decoder. First the decoder failed to + * parse the ANNOUNCE request, then it stopped receiving any more + * responses. This test verifies that the issue is solved. + */ + @Test + public void testReceiveAnnounce() { + byte[] data1 = ("ANNOUNCE rtsp://172.20.184.218:554/d3abaaa7-65f2-" + + "42b4-8d6b-379f492fcf0f RTSP/1.0\r\n" + + "CSeq: 2\r\n" + + "Session: 2777476816092819869\r\n" + + "x-notice: 5402 \"Session Terminated by Server\" " + + "event-date=20150514T075303Z\r\n" + + "Range: npt=0\r\n\r\n").getBytes(); + + byte[] data2 = ("RTSP/1.0 200 OK\r\n" + + "Server: Orbit2x\r\n" + + "CSeq: 172\r\n" + + "Session: 2547019973447939919\r\n" + + "\r\n").getBytes(); + + EmbeddedChannel ch = new EmbeddedChannel(new RtspDecoder(), + new HttpObjectAggregator(1048576)); + ch.writeInbound(Unpooled.wrappedBuffer(data1), + Unpooled.wrappedBuffer(data2)); + + HttpObject res1 = ch.readInbound(); + assertNotNull(res1); + assertTrue(res1 instanceof FullHttpRequest); + ((FullHttpRequest) res1).release(); + + HttpObject res2 = ch.readInbound(); + assertNotNull(res2); + assertTrue(res2 instanceof FullHttpResponse); + ((FullHttpResponse) res2).release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspEncoderTest.java 
b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspEncoderTest.java new file mode 100644 index 0000000..4af61a1 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/rtsp/RtspEncoderTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.rtsp; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Test cases for RTSP encoder. + */ +public class RtspEncoderTest { + + /** + * Test of a SETUP request, with no body. 
+ */ + @Test + public void testSendSetupRequest() { + String expected = "SETUP rtsp://172.10.20.30:554/d3abaaa7-65f2-42b4-" + + "8d6b-379f492fcf0f RTSP/1.0\r\n" + + "transport: MP2T/DVBC/UDP;unicast;client=01234567;" + + "source=172.10.20.30;" + + "destination=1.1.1.1;client_port=6922\r\n" + + "cseq: 1\r\n" + + "\r\n"; + + HttpRequest request = new DefaultHttpRequest(RtspVersions.RTSP_1_0, + RtspMethods.SETUP, + "rtsp://172.10.20.30:554/d3abaaa7-65f2-42b4-8d6b-379f492fcf0f"); + request.headers().add(RtspHeaderNames.TRANSPORT, + "MP2T/DVBC/UDP;unicast;client=01234567;source=172.10.20.30;" + + "destination=1.1.1.1;client_port=6922"); + request.headers().add(RtspHeaderNames.CSEQ, "1"); + + EmbeddedChannel ch = new EmbeddedChannel(new RtspEncoder()); + ch.writeOutbound(request); + + ByteBuf buf = ch.readOutbound(); + String actual = buf.toString(CharsetUtil.UTF_8); + buf.release(); + assertEquals(expected, actual); + } + + /** + * Test of a GET_PARAMETER request, with body. + */ + @Test + public void testSendGetParameterRequest() { + String expected = "GET_PARAMETER rtsp://172.10.20.30:554 RTSP/1.0\r\n" + + "session: 2547019973447939919\r\n" + + "cseq: 3\r\n" + + "content-length: 31\r\n" + + "content-type: text/parameters\r\n" + + "\r\n" + + "stream_state\r\n" + + "position\r\n" + + "scale\r\n"; + + byte[] content = ("stream_state\r\n" + + "position\r\n" + + "scale\r\n").getBytes(CharsetUtil.UTF_8); + + FullHttpRequest request = new DefaultFullHttpRequest( + RtspVersions.RTSP_1_0, + RtspMethods.GET_PARAMETER, + "rtsp://172.10.20.30:554"); + request.headers().add(RtspHeaderNames.SESSION, "2547019973447939919"); + request.headers().add(RtspHeaderNames.CSEQ, "3"); + request.headers().add(RtspHeaderNames.CONTENT_LENGTH, + "" + content.length); + request.headers().add(RtspHeaderNames.CONTENT_TYPE, "text/parameters"); + request.content().writeBytes(content); + + EmbeddedChannel ch = new EmbeddedChannel(new RtspEncoder()); + ch.writeOutbound(request); + + ByteBuf buf = 
ch.readOutbound(); + String actual = buf.toString(CharsetUtil.UTF_8); + buf.release(); + assertEquals(expected, actual); + } + + /** + * Test of a 200 OK response, without body. + */ + @Test + public void testSend200OkResponseWithoutBody() { + String expected = "RTSP/1.0 200 OK\r\n" + + "server: Testserver\r\n" + + "cseq: 1\r\n" + + "session: 2547019973447939919\r\n" + + "\r\n"; + + HttpResponse response = new DefaultHttpResponse(RtspVersions.RTSP_1_0, + RtspResponseStatuses.OK); + response.headers().add(RtspHeaderNames.SERVER, "Testserver"); + response.headers().add(RtspHeaderNames.CSEQ, "1"); + response.headers().add(RtspHeaderNames.SESSION, "2547019973447939919"); + + EmbeddedChannel ch = new EmbeddedChannel(new RtspEncoder()); + ch.writeOutbound(response); + + ByteBuf buf = ch.readOutbound(); + String actual = buf.toString(CharsetUtil.UTF_8); + buf.release(); + assertEquals(expected, actual); + } + + /** + * Test of a 200 OK response, with body. + */ + @Test + public void testSend200OkResponseWithBody() { + String expected = "RTSP/1.0 200 OK\r\n" + + "server: Testserver\r\n" + + "session: 2547019973447939919\r\n" + + "content-type: text/parameters\r\n" + + "content-length: 50\r\n" + + "cseq: 3\r\n" + + "\r\n" + + "position: 24\r\n" + + "stream_state: playing\r\n" + + "scale: 1.00\r\n"; + + byte[] content = ("position: 24\r\n" + + "stream_state: playing\r\n" + + "scale: 1.00\r\n").getBytes(CharsetUtil.UTF_8); + + FullHttpResponse response = + new DefaultFullHttpResponse(RtspVersions.RTSP_1_0, + RtspResponseStatuses.OK); + response.headers().add(RtspHeaderNames.SERVER, "Testserver"); + response.headers().add(RtspHeaderNames.SESSION, "2547019973447939919"); + response.headers().add(RtspHeaderNames.CONTENT_TYPE, + "text/parameters"); + response.headers().add(RtspHeaderNames.CONTENT_LENGTH, + "" + content.length); + response.headers().add(RtspHeaderNames.CSEQ, "3"); + response.content().writeBytes(content); + + EmbeddedChannel ch = new EmbeddedChannel(new 
RtspEncoder()); + ch.writeOutbound(response); + + ByteBuf buf = ch.readOutbound(); + String actual = buf.toString(CharsetUtil.UTF_8); + buf.release(); + assertEquals(expected, actual); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersTest.java new file mode 100644 index 0000000..58f39bd --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/DefaultSpdyHeadersTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class DefaultSpdyHeadersTest { + + @Test + public void testStringKeyRetrievedAsAsciiString() { + final SpdyHeaders headers = new DefaultSpdyHeaders(); + + // Test adding String key and retrieving it using a AsciiString key + final String method = "GET"; + headers.add(":method", method); + + final String value = headers.getAsString(SpdyHeaders.HttpNames.METHOD.toString()); + assertNotNull(value); + assertEquals(method, value); + + final String value2 = headers.getAsString(SpdyHeaders.HttpNames.METHOD); + assertNotNull(value2); + assertEquals(method, value2); + } + + @Test + public void testAsciiStringKeyRetrievedAsString() { + final SpdyHeaders headers = new DefaultSpdyHeaders(); + + // Test adding AsciiString key and retrieving it using a String key + final String path = "/"; + headers.add(SpdyHeaders.HttpNames.PATH, path); + + final String value = headers.getAsString(SpdyHeaders.HttpNames.PATH); + assertNotNull(value); + assertEquals(path, value); + + final String value2 = headers.getAsString(SpdyHeaders.HttpNames.PATH.toString()); + assertNotNull(value2); + assertEquals(path, value2); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java new file mode 100644 index 0000000..79970b4 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyFrameDecoderTest.java @@ -0,0 +1,1330 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.Random; + +import static io.netty.handler.codec.spdy.SpdyCodecUtil.SPDY_HEADER_SIZE; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyString; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyZeroInteractions; + +public class SpdyFrameDecoderTest { + + private static final Random RANDOM = new Random(); + + private final SpdyFrameDecoderDelegate delegate = mock(SpdyFrameDecoderDelegate.class); + private final TestSpdyFrameDecoderDelegate testDelegate = new TestSpdyFrameDecoderDelegate(); + private SpdyFrameDecoder decoder; + + @BeforeEach + public void createDecoder() { + decoder = new SpdyFrameDecoder(SpdyVersion.SPDY_3_1, testDelegate); + } + + @AfterEach + public void releaseBuffers() { + testDelegate.releaseAll(); + } + + private final class TestSpdyFrameDecoderDelegate implements SpdyFrameDecoderDelegate { + private final Queue buffers = new ArrayDeque(); + + @Override + public void readDataFrame(int streamId, boolean last, ByteBuf data) { + delegate.readDataFrame(streamId, last, data); + buffers.add(data); + } + + 
@Override + public void readSynStreamFrame(int streamId, int associatedToStreamId, + byte priority, boolean last, boolean unidirectional) { + delegate.readSynStreamFrame(streamId, associatedToStreamId, priority, last, unidirectional); + } + + @Override + public void readSynReplyFrame(int streamId, boolean last) { + delegate.readSynReplyFrame(streamId, last); + } + + @Override + public void readRstStreamFrame(int streamId, int statusCode) { + delegate.readRstStreamFrame(streamId, statusCode); + } + + @Override + public void readSettingsFrame(boolean clearPersisted) { + delegate.readSettingsFrame(clearPersisted); + } + + @Override + public void readSetting(int id, int value, boolean persistValue, boolean persisted) { + delegate.readSetting(id, value, persistValue, persisted); + } + + @Override + public void readSettingsEnd() { + delegate.readSettingsEnd(); + } + + @Override + public void readPingFrame(int id) { + delegate.readPingFrame(id); + } + + @Override + public void readGoAwayFrame(int lastGoodStreamId, int statusCode) { + delegate.readGoAwayFrame(lastGoodStreamId, statusCode); + } + + @Override + public void readHeadersFrame(int streamId, boolean last) { + delegate.readHeadersFrame(streamId, last); + } + + @Override + public void readWindowUpdateFrame(int streamId, int deltaWindowSize) { + delegate.readWindowUpdateFrame(streamId, deltaWindowSize); + } + + @Override + public void readHeaderBlock(ByteBuf headerBlock) { + delegate.readHeaderBlock(headerBlock); + buffers.add(headerBlock); + } + + @Override + public void readHeaderBlockEnd() { + delegate.readHeaderBlockEnd(); + } + + @Override + public void readFrameError(String message) { + delegate.readFrameError(message); + } + + void releaseAll() { + for (;;) { + ByteBuf buf = buffers.poll(); + if (buf == null) { + return; + } + buf.release(); + } + } + } + + private static void encodeDataFrameHeader(ByteBuf buffer, int streamId, byte flags, int length) { + buffer.writeInt(streamId & 0x7FFFFFFF); + 
buffer.writeByte(flags); + buffer.writeMedium(length); + } + + private static void encodeControlFrameHeader(ByteBuf buffer, short type, byte flags, int length) { + buffer.writeShort(0x8000 | SpdyVersion.SPDY_3_1.getVersion()); + buffer.writeShort(type); + buffer.writeByte(flags); + buffer.writeMedium(length); + } + + @Test + public void testSpdyDataFrame() throws Exception { + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + byte flags = 0; + int length = 1024; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeDataFrameHeader(buf, streamId, flags, length); + for (int i = 0; i < 256; i ++) { + buf.writeInt(RANDOM.nextInt()); + } + decoder.decode(buf); + verify(delegate).readDataFrame(streamId, false, buf.slice(SPDY_HEADER_SIZE, length)); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testEmptySpdyDataFrame() throws Exception { + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + byte flags = 0; + int length = 0; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeDataFrameHeader(buf, streamId, flags, length); + + decoder.decode(buf); + verify(delegate).readDataFrame(streamId, false, Unpooled.EMPTY_BUFFER); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testLastSpdyDataFrame() throws Exception { + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + byte flags = 0x01; // FLAG_FIN + int length = 0; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeDataFrameHeader(buf, streamId, flags, length); + + decoder.decode(buf); + verify(delegate).readDataFrame(streamId, true, Unpooled.EMPTY_BUFFER); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdyDataFrameFlags() throws Exception { + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + byte flags = (byte) 0xFE; // should ignore any unknown flags + int length = 0; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeDataFrameHeader(buf, 
streamId, flags, length); + + decoder.decode(buf); + verify(delegate).readDataFrame(streamId, false, Unpooled.EMPTY_BUFFER); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIllegalSpdyDataFrameStreamId() throws Exception { + int streamId = 0; // illegal stream identifier + byte flags = 0; + int length = 0; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeDataFrameHeader(buf, streamId, flags, length); + + decoder.decode(buf); + verify(delegate).readFrameError((String) any()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testPipelinedSpdyDataFrames() throws Exception { + int streamId1 = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int streamId2 = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + byte flags = 0; + int length = 0; + + ByteBuf buf = Unpooled.buffer(2 * (SPDY_HEADER_SIZE + length)); + encodeDataFrameHeader(buf, streamId1, flags, length); + encodeDataFrameHeader(buf, streamId2, flags, length); + + decoder.decode(buf); + verify(delegate).readDataFrame(streamId1, false, Unpooled.EMPTY_BUFFER); + verify(delegate).readDataFrame(streamId2, false, Unpooled.EMPTY_BUFFER); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySynStreamFrame() throws Exception { + short type = 1; + byte flags = 0; + int length = 10; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + buf.writeByte(priority << 5); + buf.writeByte(0); + + decoder.decode(buf); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, false, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void 
testLastSpdySynStreamFrame() throws Exception { + short type = 1; + byte flags = 0x01; // FLAG_FIN + int length = 10; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + buf.writeByte(priority << 5); + buf.writeByte(0); + + decoder.decode(buf); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, true, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnidirectionalSpdySynStreamFrame() throws Exception { + short type = 1; + byte flags = 0x02; // FLAG_UNIDIRECTIONAL + int length = 10; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + buf.writeByte(priority << 5); + buf.writeByte(0); + + decoder.decode(buf); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, false, true); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIndependentSpdySynStreamFrame() throws Exception { + short type = 1; + byte flags = 0; + int length = 10; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = 0; // independent of all other streams + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + 
buf.writeByte(priority << 5); + buf.writeByte(0); + + decoder.decode(buf); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, false, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdySynStreamFrameFlags() throws Exception { + short type = 1; + byte flags = (byte) 0xFC; // undefined flags + int length = 10; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + buf.writeByte(priority << 5); + buf.writeByte(0); + + decoder.decode(buf); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, false, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testReservedSpdySynStreamFrameBits() throws Exception { + short type = 1; + byte flags = 0; + int length = 10; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId | 0x80000000); // should ignore reserved bit + buf.writeInt(associatedToStreamId | 0x80000000); // should ignore reserved bit + buf.writeByte(priority << 5 | 0x1F); // should ignore reserved bits + buf.writeByte(0xFF); // should ignore reserved bits + + decoder.decode(buf); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, false, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void 
testInvalidSpdySynStreamFrameLength() throws Exception { + short type = 1; + byte flags = 0; + int length = 8; // invalid length + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIllegalSpdySynStreamFrameStreamId() throws Exception { + short type = 1; + byte flags = 0; + int length = 10; + int streamId = 0; // invalid stream identifier + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + buf.writeByte(priority << 5); + buf.writeByte(0); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySynStreamFrameHeaderBlock() throws Exception { + short type = 1; + byte flags = 0; + int length = 10; + int headerBlockLength = 1024; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int associatedToStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + byte priority = (byte) (RANDOM.nextInt() & 0x07); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length + headerBlockLength); + encodeControlFrameHeader(buf, type, flags, length + headerBlockLength); + buf.writeInt(streamId); + buf.writeInt(associatedToStreamId); + buf.writeByte(priority << 5); + buf.writeByte(0); + + ByteBuf headerBlock = Unpooled.buffer(headerBlockLength); + for (int i = 0; i < 256; i ++) { + headerBlock.writeInt(RANDOM.nextInt()); + } + + decoder.decode(buf); + 
decoder.decode(headerBlock); + verify(delegate).readSynStreamFrame(streamId, associatedToStreamId, priority, false, false); + verify(delegate).readHeaderBlock(headerBlock.slice(0, headerBlock.writerIndex())); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + assertFalse(headerBlock.isReadable()); + buf.release(); + headerBlock.release(); + } + + @Test + public void testSpdySynReplyFrame() throws Exception { + short type = 2; + byte flags = 0; + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readSynReplyFrame(streamId, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testLastSpdySynReplyFrame() throws Exception { + short type = 2; + byte flags = 0x01; // FLAG_FIN + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readSynReplyFrame(streamId, true); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdySynReplyFrameFlags() throws Exception { + short type = 2; + byte flags = (byte) 0xFE; // undefined flags + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readSynReplyFrame(streamId, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testReservedSpdySynReplyFrameBits() throws Exception { + short type 
= 2; + byte flags = 0; + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId | 0x80000000); // should ignore reserved bit + + decoder.decode(buf); + verify(delegate).readSynReplyFrame(streamId, false); + verify(delegate).readHeaderBlockEnd(); + + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdySynReplyFrameLength() throws Exception { + short type = 2; + byte flags = 0; + int length = 0; // invalid length + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIllegalSpdySynReplyFrameStreamId() throws Exception { + short type = 2; + byte flags = 0; + int length = 4; + int streamId = 0; // invalid stream identifier + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySynReplyFrameHeaderBlock() throws Exception { + short type = 2; + byte flags = 0; + int length = 4; + int headerBlockLength = 1024; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length + headerBlockLength); + encodeControlFrameHeader(buf, type, flags, length + headerBlockLength); + buf.writeInt(streamId); + + ByteBuf headerBlock = Unpooled.buffer(headerBlockLength); + for (int i = 0; i < 256; i ++) { + headerBlock.writeInt(RANDOM.nextInt()); + } + + decoder.decode(buf); + decoder.decode(headerBlock); + verify(delegate).readSynReplyFrame(streamId, false); + 
verify(delegate).readHeaderBlock(headerBlock.slice(0, headerBlock.writerIndex())); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + assertFalse(headerBlock.isReadable()); + buf.release(); + headerBlock.release(); + } + + @Test + public void testSpdyRstStreamFrame() throws Exception { + short type = 3; + byte flags = 0; + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readRstStreamFrame(streamId, statusCode); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testReservedSpdyRstStreamFrameBits() throws Exception { + short type = 3; + byte flags = 0; + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId | 0x80000000); // should ignore reserved bit + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readRstStreamFrame(streamId, statusCode); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdyRstStreamFrameFlags() throws Exception { + short type = 3; + byte flags = (byte) 0xFF; // invalid flags + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdyRstStreamFrameLength() throws Exception { 
+ short type = 3; + byte flags = 0; + int length = 12; // invalid length + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIllegalSpdyRstStreamFrameStreamId() throws Exception { + short type = 3; + byte flags = 0; + int length = 8; + int streamId = 0; // invalid stream identifier + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIllegalSpdyRstStreamFrameStatusCode() throws Exception { + short type = 3; + byte flags = 0; + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + int statusCode = 0; // invalid status code + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySettingsFrame() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 2; + int length = 8 * numSettings + 4; + byte idFlags = 0; + int id = RANDOM.nextInt() & 0x00FFFFFF; + int value = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + for (int i = 0; i < numSettings; i++) { + 
buf.writeByte(idFlags); + buf.writeMedium(id); + buf.writeInt(value); + } + + delegate.readSettingsEnd(); + decoder.decode(buf); + verify(delegate).readSettingsFrame(false); + verify(delegate, times(numSettings)).readSetting(id, value, false, false); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testEmptySpdySettingsFrame() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 0; + int length = 8 * numSettings + 4; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + + decoder.decode(buf); + verify(delegate).readSettingsFrame(false); + verify(delegate).readSettingsEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySettingsFrameClearFlag() throws Exception { + short type = 4; + byte flags = 0x01; // FLAG_SETTINGS_CLEAR_SETTINGS + int numSettings = 0; + int length = 8 * numSettings + 4; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + + decoder.decode(buf); + verify(delegate).readSettingsFrame(true); + verify(delegate).readSettingsEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySettingsPersistValues() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 1; + int length = 8 * numSettings + 4; + byte idFlags = 0x01; // FLAG_SETTINGS_PERSIST_VALUE + int id = RANDOM.nextInt() & 0x00FFFFFF; + int value = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + for (int i = 0; i < numSettings; i++) { + buf.writeByte(idFlags); + buf.writeMedium(id); + buf.writeInt(value); + } + + delegate.readSettingsEnd(); + decoder.decode(buf); + verify(delegate).readSettingsFrame(false); + verify(delegate, 
times(numSettings)).readSetting(id, value, true, false); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdySettingsPersistedValues() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 1; + int length = 8 * numSettings + 4; + byte idFlags = 0x02; // FLAG_SETTINGS_PERSISTED + int id = RANDOM.nextInt() & 0x00FFFFFF; + int value = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + for (int i = 0; i < numSettings; i++) { + buf.writeByte(idFlags); + buf.writeMedium(id); + buf.writeInt(value); + } + + delegate.readSettingsEnd(); + decoder.decode(buf); + verify(delegate).readSettingsFrame(false); + verify(delegate, times(numSettings)).readSetting(id, value, false, true); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdySettingsFrameFlags() throws Exception { + short type = 4; + byte flags = (byte) 0xFE; // undefined flags + int numSettings = 0; + int length = 8 * numSettings + 4; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + + decoder.decode(buf); + verify(delegate).readSettingsFrame(false); + verify(delegate).readSettingsEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdySettingsFlags() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 1; + int length = 8 * numSettings + 4; + byte idFlags = (byte) 0xFC; // undefined flags + int id = RANDOM.nextInt() & 0x00FFFFFF; + int value = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + for (int i = 0; i < numSettings; i++) { + buf.writeByte(idFlags); + buf.writeMedium(id); + buf.writeInt(value); + } + + delegate.readSettingsEnd(); + 
decoder.decode(buf); + verify(delegate).readSettingsFrame(false); + verify(delegate, times(numSettings)).readSetting(id, value, false, false); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdySettingsFrameLength() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 2; + int length = 8 * numSettings + 8; // invalid length + byte idFlags = 0; + int id = RANDOM.nextInt() & 0x00FFFFFF; + int value = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(numSettings); + for (int i = 0; i < numSettings; i++) { + buf.writeByte(idFlags); + buf.writeMedium(id); + buf.writeInt(value); + } + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdySettingsFrameNumSettings() throws Exception { + short type = 4; + byte flags = 0; + int numSettings = 2; + int length = 8 * numSettings + 4; + byte idFlags = 0; + int id = RANDOM.nextInt() & 0x00FFFFFF; + int value = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(0); // invalid num_settings + for (int i = 0; i < numSettings; i++) { + buf.writeByte(idFlags); + buf.writeMedium(id); + buf.writeInt(value); + } + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testDiscardUnknownFrame() throws Exception { + short type = 5; + byte flags = (byte) 0xFF; + int length = 8; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeLong(RANDOM.nextLong()); + + decoder.decode(buf); + verifyZeroInteractions(delegate); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void 
testDiscardUnknownEmptyFrame() throws Exception { + short type = 5; + byte flags = (byte) 0xFF; + int length = 0; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + + decoder.decode(buf); + verifyZeroInteractions(delegate); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testProgressivelyDiscardUnknownEmptyFrame() throws Exception { + short type = 5; + byte flags = (byte) 0xFF; + int segment = 4; + int length = 2 * segment; + + ByteBuf header = Unpooled.buffer(SPDY_HEADER_SIZE); + ByteBuf segment1 = Unpooled.buffer(segment); + ByteBuf segment2 = Unpooled.buffer(segment); + encodeControlFrameHeader(header, type, flags, length); + segment1.writeInt(RANDOM.nextInt()); + segment2.writeInt(RANDOM.nextInt()); + + decoder.decode(header); + decoder.decode(segment1); + decoder.decode(segment2); + verifyZeroInteractions(delegate); + assertFalse(header.isReadable()); + assertFalse(segment1.isReadable()); + assertFalse(segment2.isReadable()); + header.release(); + segment1.release(); + segment2.release(); + } + + @Test + public void testSpdyPingFrame() throws Exception { + short type = 6; + byte flags = 0; + int length = 4; + int id = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(id); + + decoder.decode(buf); + verify(delegate).readPingFrame(id); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdyPingFrameFlags() throws Exception { + short type = 6; + byte flags = (byte) 0xFF; // undefined flags + int length = 4; + int id = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(id); + + decoder.decode(buf); + verify(delegate).readPingFrame(id); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void 
testInvalidSpdyPingFrameLength() throws Exception { + short type = 6; + byte flags = 0; + int length = 8; // invalid length + int id = RANDOM.nextInt(); + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(id); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdyGoAwayFrame() throws Exception { + short type = 7; + byte flags = 0; + int length = 8; + int lastGoodStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(lastGoodStreamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readGoAwayFrame(lastGoodStreamId, statusCode); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdyGoAwayFrameFlags() throws Exception { + short type = 7; + byte flags = (byte) 0xFF; // undefined flags + int length = 8; + int lastGoodStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(lastGoodStreamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readGoAwayFrame(lastGoodStreamId, statusCode); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testReservedSpdyGoAwayFrameBits() throws Exception { + short type = 7; + byte flags = 0; + int length = 8; + int lastGoodStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(lastGoodStreamId | 0x80000000); // should ignore reserved bit + buf.writeInt(statusCode); + + 
decoder.decode(buf); + verify(delegate).readGoAwayFrame(lastGoodStreamId, statusCode); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdyGoAwayFrameLength() throws Exception { + short type = 7; + byte flags = 0; + int length = 12; // invalid length + int lastGoodStreamId = RANDOM.nextInt() & 0x7FFFFFFF; + int statusCode = RANDOM.nextInt() | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(lastGoodStreamId); + buf.writeInt(statusCode); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdyHeadersFrame() throws Exception { + short type = 8; + byte flags = 0; + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readHeadersFrame(streamId, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testLastSpdyHeadersFrame() throws Exception { + short type = 8; + byte flags = 0x01; // FLAG_FIN + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readHeadersFrame(streamId, true); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testUnknownSpdyHeadersFrameFlags() throws Exception { + short type = 8; + byte flags = (byte) 0xFE; // undefined flags + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + 
encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readHeadersFrame(streamId, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testReservedSpdyHeadersFrameBits() throws Exception { + short type = 8; + byte flags = 0; + int length = 4; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId | 0x80000000); // should ignore reserved bit + + decoder.decode(buf); + verify(delegate).readHeadersFrame(streamId, false); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdyHeadersFrameLength() throws Exception { + short type = 8; + byte flags = 0; + int length = 0; // invalid length + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdyHeadersFrameStreamId() throws Exception { + short type = 8; + byte flags = 0; + int length = 4; + int streamId = 0; // invalid stream identifier + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testSpdyHeadersFrameHeaderBlock() throws Exception { + short type = 8; + byte flags = 0; + int length = 4; + int headerBlockLength = 1024; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length + headerBlockLength); 
+ buf.writeInt(streamId); + + ByteBuf headerBlock = Unpooled.buffer(headerBlockLength); + for (int i = 0; i < 256; i ++) { + headerBlock.writeInt(RANDOM.nextInt()); + } + decoder.decode(buf); + decoder.decode(headerBlock); + verify(delegate).readHeadersFrame(streamId, false); + verify(delegate).readHeaderBlock(headerBlock.slice(0, headerBlock.writerIndex())); + verify(delegate).readHeaderBlockEnd(); + assertFalse(buf.isReadable()); + assertFalse(headerBlock.isReadable()); + buf.release(); + headerBlock.release(); + } + + @Test + public void testSpdyWindowUpdateFrame() throws Exception { + short type = 9; + byte flags = 0; + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF; + int deltaWindowSize = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(deltaWindowSize); + + decoder.decode(buf); + verify(delegate).readWindowUpdateFrame(streamId, deltaWindowSize); + assertFalse(buf.isReadable()); + } + + @Test + public void testUnknownSpdyWindowUpdateFrameFlags() throws Exception { + short type = 9; + byte flags = (byte) 0xFF; // undefined flags + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF; + int deltaWindowSize = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(deltaWindowSize); + + decoder.decode(buf); + verify(delegate).readWindowUpdateFrame(streamId, deltaWindowSize); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testReservedSpdyWindowUpdateFrameBits() throws Exception { + short type = 9; + byte flags = 0; + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF; + int deltaWindowSize = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + 
encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId | 0x80000000); // should ignore reserved bit + buf.writeInt(deltaWindowSize | 0x80000000); // should ignore reserved bit + + decoder.decode(buf); + verify(delegate).readWindowUpdateFrame(streamId, deltaWindowSize); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testInvalidSpdyWindowUpdateFrameLength() throws Exception { + short type = 9; + byte flags = 0; + int length = 12; // invalid length + int streamId = RANDOM.nextInt() & 0x7FFFFFFF; + int deltaWindowSize = RANDOM.nextInt() & 0x7FFFFFFF | 0x01; + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(deltaWindowSize); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } + + @Test + public void testIllegalSpdyWindowUpdateFrameDeltaWindowSize() throws Exception { + short type = 9; + byte flags = 0; + int length = 8; + int streamId = RANDOM.nextInt() & 0x7FFFFFFF; + int deltaWindowSize = 0; // invalid delta window size + + ByteBuf buf = Unpooled.buffer(SPDY_HEADER_SIZE + length); + encodeControlFrameHeader(buf, type, flags, length); + buf.writeInt(streamId); + buf.writeInt(deltaWindowSize); + + decoder.decode(buf); + verify(delegate).readFrameError(anyString()); + assertFalse(buf.isReadable()); + buf.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawDecoderTest.java new file mode 100644 index 0000000..737bf2f --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockRawDecoderTest.java @@ -0,0 +1,516 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * 
version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SpdyHeaderBlockRawDecoderTest { + + private static final int maxHeaderSize = 16; + + private static final String name = "name"; + private static final String value = "value"; + private static final byte[] nameBytes = name.getBytes(); + private static final byte[] valueBytes = value.getBytes(); + + private SpdyHeaderBlockRawDecoder decoder; + private SpdyHeadersFrame frame; + + @BeforeEach + public void setUp() { + decoder = new SpdyHeaderBlockRawDecoder(SpdyVersion.SPDY_3_1, maxHeaderSize); + frame = new DefaultSpdyHeadersFrame(1); + } + + @AfterEach + public void tearDown() { + decoder.end(); + } + + @Test + public void testEmptyHeaderBlock() throws Exception { + ByteBuf headerBlock = Unpooled.EMPTY_BUFFER; + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void 
testZeroNameValuePairs() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(4); + headerBlock.writeInt(0); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testNegativeNameValuePairs() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(4); + headerBlock.writeInt(-1); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testOneNameValuePair() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(21); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testMissingNameLength() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(4); + headerBlock.writeInt(1); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testZeroNameLength() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(8); + headerBlock.writeInt(1); + headerBlock.writeInt(0); + 
decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testNegativeNameLength() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(8); + headerBlock.writeInt(1); + headerBlock.writeInt(-1); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testMissingName() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(8); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testIllegalNameOnlyNull() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(18); + headerBlock.writeInt(1); + headerBlock.writeInt(1); + headerBlock.writeByte(0); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testMissingValueLength() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(12); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void 
testZeroValueLength() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(16); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(0); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals("", frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testNegativeValueLength() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(16); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(-1); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testMissingValue() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(16); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testIllegalValueOnlyNull() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(17); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(1); + headerBlock.writeByte(0); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + 
headerBlock.release(); + } + + @Test + public void testIllegalValueStartsWithNull() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(22); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(6); + headerBlock.writeByte(0); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testIllegalValueEndsWithNull() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(22); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(6); + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testMultipleValues() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(27); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(11); + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(2, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().getAll(name).get(0)); + assertEquals(value, frame.headers().getAll(name).get(1)); + headerBlock.release(); + } + + @Test + public void testMultipleValuesEndsWithNull() throws Exception { + ByteBuf headerBlock = 
Unpooled.buffer(28); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(12); + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0); + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testIllegalValueMultipleNulls() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(28); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(12); + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0); + headerBlock.writeByte(0); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testMissingNextNameValuePair() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(21); + headerBlock.writeInt(2); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testMultipleNames() throws 
Exception { + ByteBuf headerBlock = Unpooled.buffer(38); + headerBlock.writeInt(2); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testExtraData() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(22); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testMultipleDecodes() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(21); + headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + + int readableBytes = headerBlock.readableBytes(); + for (int i = 0; i < readableBytes; i++) { + ByteBuf headerBlockSegment = headerBlock.slice(i, 1); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlockSegment, frame); + assertFalse(headerBlockSegment.isReadable()); + } + decoder.endHeaderBlock(frame); + + assertFalse(frame.isInvalid()); + 
assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + headerBlock.release(); + } + + @Test + public void testContinueAfterInvalidHeaders() throws Exception { + ByteBuf numHeaders = Unpooled.buffer(4); + numHeaders.writeInt(1); + + ByteBuf nameBlock = Unpooled.buffer(8); + nameBlock.writeInt(4); + nameBlock.writeBytes(nameBytes); + + ByteBuf valueBlock = Unpooled.buffer(9); + valueBlock.writeInt(5); + valueBlock.writeBytes(valueBytes); + + decoder.decode(ByteBufAllocator.DEFAULT, numHeaders, frame); + decoder.decode(ByteBufAllocator.DEFAULT, nameBlock, frame); + frame.setInvalid(); + decoder.decode(ByteBufAllocator.DEFAULT, valueBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(numHeaders.isReadable()); + assertFalse(nameBlock.isReadable()); + assertFalse(valueBlock.isReadable()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + numHeaders.release(); + nameBlock.release(); + valueBlock.release(); + } + + @Test + public void testTruncatedHeaderName() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(maxHeaderSize + 18); + headerBlock.writeInt(1); + headerBlock.writeInt(maxHeaderSize + 1); + for (int i = 0; i < maxHeaderSize + 1; i++) { + headerBlock.writeByte('a'); + } + headerBlock.writeInt(5); + headerBlock.writeBytes(valueBytes); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isTruncated()); + assertFalse(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } + + @Test + public void testTruncatedHeaderValue() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(maxHeaderSize + 13); 
+ headerBlock.writeInt(1); + headerBlock.writeInt(4); + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(13); + for (int i = 0; i < maxHeaderSize - 3; i++) { + headerBlock.writeByte('a'); + } + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertTrue(frame.isTruncated()); + assertFalse(frame.isInvalid()); + assertEquals(0, frame.headers().names().size()); + headerBlock.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoderTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoderTest.java new file mode 100644 index 0000000..379ff7a --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdyHeaderBlockZlibDecoderTest.java @@ -0,0 +1,245 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SpdyHeaderBlockZlibDecoderTest { + + // zlib header indicating 32K window size fastest deflate algorithm with SPDY dictionary + private static final byte[] zlibHeader = {0x78, 0x3f, (byte) 0xe3, (byte) 0xc6, (byte) 0xa7, (byte) 0xc2}; + private static final byte[] zlibSyncFlush = {0x00, 0x00, 0x00, (byte) 0xff, (byte) 0xff}; + + private static final int maxHeaderSize = 8192; + + private static final String name = "name"; + private static final String value = "value"; + private static final byte[] nameBytes = name.getBytes(); + private static final byte[] valueBytes = value.getBytes(); + + private SpdyHeaderBlockZlibDecoder decoder; + private SpdyHeadersFrame frame; + + @BeforeEach + public void setUp() { + decoder = new SpdyHeaderBlockZlibDecoder(SpdyVersion.SPDY_3_1, maxHeaderSize); + frame = new DefaultSpdyHeadersFrame(1); + } + + @AfterEach + public void tearDown() { + decoder.end(); + } + + @Test + public void testHeaderBlock() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(37); + headerBlock.writeBytes(zlibHeader); + headerBlock.writeByte(0); // Non-compressed block + headerBlock.writeByte(0x15); // little-endian length (21) + headerBlock.writeByte(0x00); // little-endian length (21) + headerBlock.writeByte(0xea); // one's compliment of length + headerBlock.writeByte(0xff); // one's compliment of length + headerBlock.writeInt(1); // number of Name/Value pairs + 
headerBlock.writeInt(4); // length of name + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); // length of value + headerBlock.writeBytes(valueBytes); + headerBlock.writeBytes(zlibSyncFlush); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + + headerBlock.release(); + } + + @Test + public void testHeaderBlockMultipleDecodes() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(37); + headerBlock.writeBytes(zlibHeader); + headerBlock.writeByte(0); // Non-compressed block + headerBlock.writeByte(0x15); // little-endian length (21) + headerBlock.writeByte(0x00); // little-endian length (21) + headerBlock.writeByte(0xea); // one's compliment of length + headerBlock.writeByte(0xff); // one's compliment of length + headerBlock.writeInt(1); // number of Name/Value pairs + headerBlock.writeInt(4); // length of name + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); // length of value + headerBlock.writeBytes(valueBytes); + headerBlock.writeBytes(zlibSyncFlush); + + int readableBytes = headerBlock.readableBytes(); + for (int i = 0; i < readableBytes; i++) { + ByteBuf headerBlockSegment = headerBlock.slice(i, 1); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlockSegment, frame); + assertFalse(headerBlockSegment.isReadable()); + } + decoder.endHeaderBlock(frame); + + assertFalse(frame.isInvalid()); + assertEquals(1, frame.headers().names().size()); + assertTrue(frame.headers().contains(name)); + assertEquals(1, frame.headers().getAll(name).size()); + assertEquals(value, frame.headers().get(name)); + + headerBlock.release(); + } + + @Test + public void testLargeHeaderName() throws Exception { + ByteBuf headerBlock = 
Unpooled.buffer(8220); + headerBlock.writeBytes(zlibHeader); + headerBlock.writeByte(0); // Non-compressed block + headerBlock.writeByte(0x0c); // little-endian length (8204) + headerBlock.writeByte(0x20); // little-endian length (8204) + headerBlock.writeByte(0xf3); // one's compliment of length + headerBlock.writeByte(0xdf); // one's compliment of length + headerBlock.writeInt(1); // number of Name/Value pairs + headerBlock.writeInt(8192); // length of name + for (int i = 0; i < 8192; i++) { + headerBlock.writeByte('n'); + } + headerBlock.writeInt(0); // length of value + headerBlock.writeBytes(zlibSyncFlush); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertFalse(frame.isTruncated()); + assertEquals(1, frame.headers().names().size()); + + headerBlock.release(); + } + + @Test + public void testLargeHeaderValue() throws Exception { + ByteBuf headerBlock = Unpooled.buffer(8220); + headerBlock.writeBytes(zlibHeader); + headerBlock.writeByte(0); // Non-compressed block + headerBlock.writeByte(0x0c); // little-endian length (8204) + headerBlock.writeByte(0x20); // little-endian length (8204) + headerBlock.writeByte(0xf3); // one's compliment of length + headerBlock.writeByte(0xdf); // one's compliment of length + headerBlock.writeInt(1); // number of Name/Value pairs + headerBlock.writeInt(1); // length of name + headerBlock.writeByte('n'); + headerBlock.writeInt(8191); // length of value + for (int i = 0; i < 8191; i++) { + headerBlock.writeByte('v'); + } + headerBlock.writeBytes(zlibSyncFlush); + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + decoder.endHeaderBlock(frame); + + assertFalse(headerBlock.isReadable()); + assertFalse(frame.isInvalid()); + assertFalse(frame.isTruncated()); + assertEquals(1, frame.headers().names().size()); + assertEquals(8191, frame.headers().get("n").length()); + + headerBlock.release(); + 
} + + @Test + public void testHeaderBlockExtraData() throws Exception { + final ByteBuf headerBlock = Unpooled.buffer(37); + headerBlock.writeBytes(zlibHeader); + headerBlock.writeByte(0); // Non-compressed block + headerBlock.writeByte(0x15); // little-endian length (21) + headerBlock.writeByte(0x00); // little-endian length (21) + headerBlock.writeByte(0xea); // one's compliment of length + headerBlock.writeByte(0xff); // one's compliment of length + headerBlock.writeInt(1); // number of Name/Value pairs + headerBlock.writeInt(4); // length of name + headerBlock.writeBytes(nameBytes); + headerBlock.writeInt(5); // length of value + headerBlock.writeBytes(valueBytes); + headerBlock.writeByte(0x19); // adler-32 checksum + headerBlock.writeByte(0xa5); // adler-32 checksum + headerBlock.writeByte(0x03); // adler-32 checksum + headerBlock.writeByte(0xc9); // adler-32 checksum + headerBlock.writeByte(0); // Data following zlib stream + + assertThrows(SpdyProtocolException.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + } + }); + + headerBlock.release(); + } + + @Test + public void testHeaderBlockInvalidDictionary() throws Exception { + final ByteBuf headerBlock = Unpooled.buffer(7); + headerBlock.writeByte(0x78); + headerBlock.writeByte(0x3f); + headerBlock.writeByte(0x01); // Unknown dictionary + headerBlock.writeByte(0x02); // Unknown dictionary + headerBlock.writeByte(0x03); // Unknown dictionary + headerBlock.writeByte(0x04); // Unknown dictionary + headerBlock.writeByte(0); // Non-compressed block + + assertThrows(SpdyProtocolException.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + } + }); + + headerBlock.release(); + } + + @Test + public void testHeaderBlockInvalidDeflateBlock() throws Exception { + final ByteBuf headerBlock = Unpooled.buffer(11); + 
headerBlock.writeBytes(zlibHeader); + headerBlock.writeByte(0); // Non-compressed block + headerBlock.writeByte(0x00); // little-endian length (0) + headerBlock.writeByte(0x00); // little-endian length (0) + headerBlock.writeByte(0x00); // invalid one's compliment + headerBlock.writeByte(0x00); // invalid one's compliment + + assertThrows(SpdyProtocolException.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decode(ByteBufAllocator.DEFAULT, headerBlock, frame); + } + }); + + headerBlock.release(); + } +} diff --git a/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdySessionHandlerTest.java b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdySessionHandlerTest.java new file mode 100644 index 0000000..45e5271 --- /dev/null +++ b/netty-handler-codec-http/src/test/java/io/netty/handler/codec/spdy/SpdySessionHandlerTest.java @@ -0,0 +1,392 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.spdy; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SpdySessionHandlerTest { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(SpdySessionHandlerTest.class); + + private static final int closeSignal = SpdyCodecUtil.SPDY_SETTINGS_MAX_ID; + private static final SpdySettingsFrame closeMessage = new DefaultSpdySettingsFrame(); + + static { + closeMessage.setValue(closeSignal, 0); + } + + private static void assertDataFrame(Object msg, int streamId, boolean last) { + assertNotNull(msg); + assertTrue(msg instanceof SpdyDataFrame); + SpdyDataFrame spdyDataFrame = (SpdyDataFrame) msg; + assertEquals(streamId, spdyDataFrame.streamId()); + assertEquals(last, spdyDataFrame.isLast()); + } + + private static void assertSynReply(Object msg, int streamId, boolean last, SpdyHeaders headers) { + assertNotNull(msg); + assertTrue(msg instanceof SpdySynReplyFrame); + assertHeaders(msg, streamId, last, headers); + } + + private static void assertRstStream(Object msg, int streamId, SpdyStreamStatus status) { + assertNotNull(msg); + assertTrue(msg instanceof SpdyRstStreamFrame); + SpdyRstStreamFrame spdyRstStreamFrame = (SpdyRstStreamFrame) msg; + assertEquals(streamId, spdyRstStreamFrame.streamId()); + assertEquals(status, spdyRstStreamFrame.status()); + } + + private static void assertPing(Object msg, int id) { + assertNotNull(msg); + assertTrue(msg 
instanceof SpdyPingFrame); + SpdyPingFrame spdyPingFrame = (SpdyPingFrame) msg; + assertEquals(id, spdyPingFrame.id()); + } + + private static void assertGoAway(Object msg, int lastGoodStreamId) { + assertNotNull(msg); + assertTrue(msg instanceof SpdyGoAwayFrame); + SpdyGoAwayFrame spdyGoAwayFrame = (SpdyGoAwayFrame) msg; + assertEquals(lastGoodStreamId, spdyGoAwayFrame.lastGoodStreamId()); + } + + private static void assertHeaders(Object msg, int streamId, boolean last, SpdyHeaders headers) { + assertNotNull(msg); + assertTrue(msg instanceof SpdyHeadersFrame); + SpdyHeadersFrame spdyHeadersFrame = (SpdyHeadersFrame) msg; + assertEquals(streamId, spdyHeadersFrame.streamId()); + assertEquals(last, spdyHeadersFrame.isLast()); + for (CharSequence name: headers.names()) { + List expectedValues = headers.getAll(name); + List receivedValues = spdyHeadersFrame.headers().getAll(name); + assertTrue(receivedValues.containsAll(expectedValues)); + receivedValues.removeAll(expectedValues); + assertTrue(receivedValues.isEmpty()); + spdyHeadersFrame.headers().remove(name); + } + assertTrue(spdyHeadersFrame.headers().isEmpty()); + } + + private static void testSpdySessionHandler(SpdyVersion version, boolean server) { + EmbeddedChannel sessionHandler = new EmbeddedChannel( + new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server)); + + while (sessionHandler.readOutbound() != null) { + continue; + } + + int localStreamId = server ? 1 : 2; + int remoteStreamId = server ? 
2 : 1; + + SpdySynStreamFrame spdySynStreamFrame = + new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0); + spdySynStreamFrame.headers().set("compression", "test"); + + SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId); + spdyDataFrame.setLast(true); + + // Check if session handler returns INVALID_STREAM if it receives + // a data frame for a Stream-ID that is not open + sessionHandler.writeInbound(new DefaultSpdyDataFrame(localStreamId)); + assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.INVALID_STREAM); + assertNull(sessionHandler.readOutbound()); + + // Check if session handler returns PROTOCOL_ERROR if it receives + // a data frame for a Stream-ID before receiving a SYN_REPLY frame + sessionHandler.writeInbound(new DefaultSpdyDataFrame(remoteStreamId)); + assertRstStream(sessionHandler.readOutbound(), remoteStreamId, SpdyStreamStatus.PROTOCOL_ERROR); + assertNull(sessionHandler.readOutbound()); + remoteStreamId += 2; + + // Check if session handler returns PROTOCOL_ERROR if it receives + // multiple SYN_REPLY frames for the same active Stream-ID + sessionHandler.writeInbound(new DefaultSpdySynReplyFrame(remoteStreamId)); + assertNull(sessionHandler.readOutbound()); + sessionHandler.writeInbound(new DefaultSpdySynReplyFrame(remoteStreamId)); + assertRstStream(sessionHandler.readOutbound(), remoteStreamId, SpdyStreamStatus.STREAM_IN_USE); + assertNull(sessionHandler.readOutbound()); + remoteStreamId += 2; + + // Check if frame codec correctly compresses/uncompresses headers + sessionHandler.writeInbound(spdySynStreamFrame); + assertSynReply(sessionHandler.readOutbound(), localStreamId, false, spdySynStreamFrame.headers()); + assertNull(sessionHandler.readOutbound()); + SpdyHeadersFrame spdyHeadersFrame = new DefaultSpdyHeadersFrame(localStreamId); + + spdyHeadersFrame.headers().add("header", "test1"); + spdyHeadersFrame.headers().add("header", "test2"); + + sessionHandler.writeInbound(spdyHeadersFrame); + 
assertHeaders(sessionHandler.readOutbound(), localStreamId, false, spdyHeadersFrame.headers()); + assertNull(sessionHandler.readOutbound()); + localStreamId += 2; + + // Check if session handler closed the streams using the number + // of concurrent streams and that it returns REFUSED_STREAM + // if it receives a SYN_STREAM frame it does not wish to accept + spdySynStreamFrame.setStreamId(localStreamId); + spdySynStreamFrame.setLast(true); + spdySynStreamFrame.setUnidirectional(true); + + sessionHandler.writeInbound(spdySynStreamFrame); + assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.REFUSED_STREAM); + assertNull(sessionHandler.readOutbound()); + + // Check if session handler rejects HEADERS for closed streams + int testStreamId = spdyDataFrame.streamId(); + sessionHandler.writeInbound(spdyDataFrame); + assertDataFrame(sessionHandler.readOutbound(), testStreamId, spdyDataFrame.isLast()); + assertNull(sessionHandler.readOutbound()); + spdyHeadersFrame.setStreamId(testStreamId); + + sessionHandler.writeInbound(spdyHeadersFrame); + assertRstStream(sessionHandler.readOutbound(), testStreamId, SpdyStreamStatus.INVALID_STREAM); + assertNull(sessionHandler.readOutbound()); + + // Check if session handler drops active streams if it receives + // a RST_STREAM frame for that Stream-ID + sessionHandler.writeInbound(new DefaultSpdyRstStreamFrame(remoteStreamId, 3)); + assertNull(sessionHandler.readOutbound()); + //remoteStreamId += 2; + + // Check if session handler honors UNIDIRECTIONAL streams + spdySynStreamFrame.setLast(false); + sessionHandler.writeInbound(spdySynStreamFrame); + assertNull(sessionHandler.readOutbound()); + spdySynStreamFrame.setUnidirectional(false); + + // Check if session handler returns PROTOCOL_ERROR if it receives + // multiple SYN_STREAM frames for the same active Stream-ID + sessionHandler.writeInbound(spdySynStreamFrame); + assertRstStream(sessionHandler.readOutbound(), localStreamId, 
SpdyStreamStatus.PROTOCOL_ERROR); + assertNull(sessionHandler.readOutbound()); + localStreamId += 2; + + // Check if session handler returns PROTOCOL_ERROR if it receives + // a SYN_STREAM frame with an invalid Stream-ID + spdySynStreamFrame.setStreamId(localStreamId - 1); + sessionHandler.writeInbound(spdySynStreamFrame); + assertRstStream(sessionHandler.readOutbound(), localStreamId - 1, SpdyStreamStatus.PROTOCOL_ERROR); + assertNull(sessionHandler.readOutbound()); + spdySynStreamFrame.setStreamId(localStreamId); + + // Check if session handler returns PROTOCOL_ERROR if it receives + // an invalid HEADERS frame + spdyHeadersFrame.setStreamId(localStreamId); + + spdyHeadersFrame.setInvalid(); + sessionHandler.writeInbound(spdyHeadersFrame); + assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.PROTOCOL_ERROR); + assertNull(sessionHandler.readOutbound()); + + sessionHandler.finish(); + } + + private static void testSpdySessionHandlerPing(SpdyVersion version, boolean server) { + EmbeddedChannel sessionHandler = new EmbeddedChannel( + new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server)); + + while (sessionHandler.readOutbound() != null) { + continue; + } + + int localStreamId = server ? 1 : 2; + int remoteStreamId = server ? 
2 : 1; + + SpdyPingFrame localPingFrame = new DefaultSpdyPingFrame(localStreamId); + SpdyPingFrame remotePingFrame = new DefaultSpdyPingFrame(remoteStreamId); + + // Check if session handler returns identical local PINGs + sessionHandler.writeInbound(localPingFrame); + assertPing(sessionHandler.readOutbound(), localPingFrame.id()); + assertNull(sessionHandler.readOutbound()); + + // Check if session handler ignores un-initiated remote PINGs + sessionHandler.writeInbound(remotePingFrame); + assertNull(sessionHandler.readOutbound()); + + sessionHandler.finish(); + } + + private static void testSpdySessionHandlerGoAway(SpdyVersion version, boolean server) { + EmbeddedChannel sessionHandler = new EmbeddedChannel( + new SpdySessionHandler(version, server), new EchoHandler(closeSignal, server)); + + while (sessionHandler.readOutbound() != null) { + continue; + } + + int localStreamId = server ? 1 : 2; + + SpdySynStreamFrame spdySynStreamFrame = + new DefaultSpdySynStreamFrame(localStreamId, 0, (byte) 0); + spdySynStreamFrame.headers().set("compression", "test"); + + SpdyDataFrame spdyDataFrame = new DefaultSpdyDataFrame(localStreamId); + spdyDataFrame.setLast(true); + + // Send an initial request + sessionHandler.writeInbound(spdySynStreamFrame); + assertSynReply(sessionHandler.readOutbound(), localStreamId, false, spdySynStreamFrame.headers()); + assertNull(sessionHandler.readOutbound()); + sessionHandler.writeInbound(spdyDataFrame); + assertDataFrame(sessionHandler.readOutbound(), localStreamId, true); + assertNull(sessionHandler.readOutbound()); + + // Check if session handler sends a GOAWAY frame when closing + sessionHandler.writeInbound(closeMessage); + assertGoAway(sessionHandler.readOutbound(), localStreamId); + assertNull(sessionHandler.readOutbound()); + localStreamId += 2; + + // Check if session handler returns REFUSED_STREAM if it receives + // SYN_STREAM frames after sending a GOAWAY frame + spdySynStreamFrame.setStreamId(localStreamId); + 
sessionHandler.writeInbound(spdySynStreamFrame); + assertRstStream(sessionHandler.readOutbound(), localStreamId, SpdyStreamStatus.REFUSED_STREAM); + assertNull(sessionHandler.readOutbound()); + + // Check if session handler ignores Data frames after sending + // a GOAWAY frame + spdyDataFrame.setStreamId(localStreamId); + sessionHandler.writeInbound(spdyDataFrame); + assertNull(sessionHandler.readOutbound()); + + sessionHandler.finish(); + } + + @Test + public void testSpdyClientSessionHandler() { + logger.info("Running: testSpdyClientSessionHandler v3.1"); + testSpdySessionHandler(SpdyVersion.SPDY_3_1, false); + } + + @Test + public void testSpdyClientSessionHandlerPing() { + logger.info("Running: testSpdyClientSessionHandlerPing v3.1"); + testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, false); + } + + @Test + public void testSpdyClientSessionHandlerGoAway() { + logger.info("Running: testSpdyClientSessionHandlerGoAway v3.1"); + testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, false); + } + + @Test + public void testSpdyServerSessionHandler() { + logger.info("Running: testSpdyServerSessionHandler v3.1"); + testSpdySessionHandler(SpdyVersion.SPDY_3_1, true); + } + + @Test + public void testSpdyServerSessionHandlerPing() { + logger.info("Running: testSpdyServerSessionHandlerPing v3.1"); + testSpdySessionHandlerPing(SpdyVersion.SPDY_3_1, true); + } + + @Test + public void testSpdyServerSessionHandlerGoAway() { + logger.info("Running: testSpdyServerSessionHandlerGoAway v3.1"); + testSpdySessionHandlerGoAway(SpdyVersion.SPDY_3_1, true); + } + + // Echo Handler opens 4 half-closed streams on session connection + // and then sets the number of concurrent streams to 1 + private static class EchoHandler extends ChannelInboundHandlerAdapter { + private final int closeSignal; + private final boolean server; + + EchoHandler(int closeSignal, boolean server) { + this.closeSignal = closeSignal; + this.server = server; + } + + @Override + public void 
channelActive(ChannelHandlerContext ctx) throws Exception { + // Initiate 4 new streams + int streamId = server ? 2 : 1; + SpdySynStreamFrame spdySynStreamFrame = + new DefaultSpdySynStreamFrame(streamId, 0, (byte) 0); + spdySynStreamFrame.setLast(true); + ctx.writeAndFlush(spdySynStreamFrame); + spdySynStreamFrame.setStreamId(spdySynStreamFrame.streamId() + 2); + ctx.writeAndFlush(spdySynStreamFrame); + spdySynStreamFrame.setStreamId(spdySynStreamFrame.streamId() + 2); + ctx.writeAndFlush(spdySynStreamFrame); + spdySynStreamFrame.setStreamId(spdySynStreamFrame.streamId() + 2); + ctx.writeAndFlush(spdySynStreamFrame); + + // Limit the number of concurrent streams to 1 + SpdySettingsFrame spdySettingsFrame = new DefaultSpdySettingsFrame(); + spdySettingsFrame.setValue(SpdySettingsFrame.SETTINGS_MAX_CONCURRENT_STREAMS, 1); + ctx.writeAndFlush(spdySettingsFrame); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof SpdySynStreamFrame) { + + SpdySynStreamFrame spdySynStreamFrame = (SpdySynStreamFrame) msg; + if (!spdySynStreamFrame.isUnidirectional()) { + int streamId = spdySynStreamFrame.streamId(); + SpdySynReplyFrame spdySynReplyFrame = new DefaultSpdySynReplyFrame(streamId); + spdySynReplyFrame.setLast(spdySynStreamFrame.isLast()); + for (Map.Entry entry: spdySynStreamFrame.headers()) { + spdySynReplyFrame.headers().add(entry.getKey(), entry.getValue()); + } + + ctx.writeAndFlush(spdySynReplyFrame); + } + return; + } + + if (msg instanceof SpdySynReplyFrame) { + return; + } + + if (msg instanceof SpdyDataFrame || + msg instanceof SpdyPingFrame || + msg instanceof SpdyHeadersFrame) { + + ctx.writeAndFlush(msg); + return; + } + + if (msg instanceof SpdySettingsFrame) { + SpdySettingsFrame spdySettingsFrame = (SpdySettingsFrame) msg; + if (spdySettingsFrame.isSet(closeSignal)) { + ctx.close(); + } + } + } + } +} diff --git a/netty-handler-codec-http/src/test/resources/file-01.txt 
b/netty-handler-codec-http/src/test/resources/file-01.txt new file mode 100644 index 0000000..a94c45f --- /dev/null +++ b/netty-handler-codec-http/src/test/resources/file-01.txt @@ -0,0 +1 @@ +File 01 diff --git a/netty-handler-codec-http/src/test/resources/file-02.txt b/netty-handler-codec-http/src/test/resources/file-02.txt new file mode 100644 index 0000000..e2e0c12 --- /dev/null +++ b/netty-handler-codec-http/src/test/resources/file-02.txt @@ -0,0 +1 @@ +File 02 diff --git a/netty-handler-codec-http/src/test/resources/file-03.txt b/netty-handler-codec-http/src/test/resources/file-03.txt new file mode 100644 index 0000000..b545f1b --- /dev/null +++ b/netty-handler-codec-http/src/test/resources/file-03.txt @@ -0,0 +1 @@ +File 03 diff --git a/netty-handler-codec-http/src/test/resources/junit-platform.properties b/netty-handler-codec-http/src/test/resources/junit-platform.properties new file mode 100644 index 0000000..4bcf35e --- /dev/null +++ b/netty-handler-codec-http/src/test/resources/junit-platform.properties @@ -0,0 +1,16 @@ +# Copyright 2022 The Netty Project +# +# The Netty Project licenses this file to you under the Apache License, +# version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+ +junit.jupiter.execution.parallel.enabled = true +junit.jupiter.execution.parallel.mode.default = concurrent diff --git a/netty-handler-codec-http/src/test/resources/logging.properties b/netty-handler-codec-http/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-handler-codec-http/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-handler-codec-http2/build.gradle b/netty-handler-codec-http2/build.gradle new file mode 100644 index 0000000..c09f846 --- /dev/null +++ b/netty-handler-codec-http2/build.gradle @@ -0,0 +1,15 @@ +dependencies { + api project(':netty-handler-codec-http') + implementation libs.brotli4j // accessing com.aayushatharva.brotli4j.encoder.Encoder + testImplementation testLibs.gson + testImplementation testLibs.assertj + testImplementation testLibs.mockito.core + testRuntimeOnly(variantOf(testLibs.netty.tcnative.boringssl.static) { + classifier('linux-x86_64') + }) + testRuntimeOnly testLibs.brotli4j.native.linux.x8664 + testRuntimeOnly testLibs.brotli4j.native.linux.aarch64 + testRuntimeOnly testLibs.brotli4j.native.osx.x8664 + testRuntimeOnly testLibs.brotli4j.native.osx.aarch64 + testRuntimeOnly testLibs.brotli4j.native.windows.x8664 +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java new file mode 100644 index 0000000..c135a74 --- /dev/null +++ 
b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2ConnectionHandlerBuilder.java @@ -0,0 +1,660 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.channel.Channel; +import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_RESERVED_STREAMS; +import static io.netty.handler.codec.http2.Http2PromisedRequestVerifier.ALWAYS_VERIFY; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * Abstract base class which defines commonly used features required to build {@link Http2ConnectionHandler} instances. + * + *

Three ways to build a {@link Http2ConnectionHandler}

+ *

Let the builder create a {@link Http2ConnectionHandler}

+ * Simply call all the necessary setter methods, and then use {@link #build()} to build a new + * {@link Http2ConnectionHandler}. Setting the following properties are prohibited because they are used for + * other ways of building a {@link Http2ConnectionHandler}. + * conflicts with this option: + *
    + *
  • {@link #connection(Http2Connection)}
  • + *
  • {@link #codec(Http2ConnectionDecoder, Http2ConnectionEncoder)}
  • + *
+ * + * + *

Let the builder use the {@link Http2ConnectionHandler} you specified

+ * Call {@link #connection(Http2Connection)} to tell the builder that you want to build the handler from the + * {@link Http2Connection} you specified. Setting the following properties are prohibited and thus will trigger + * an {@link IllegalStateException} because they conflict with this option. + *
    + *
  • {@link #server(boolean)}
  • + *
  • {@link #codec(Http2ConnectionDecoder, Http2ConnectionEncoder)}
  • + *
+ * + *

Let the builder use the {@link Http2ConnectionDecoder} and {@link Http2ConnectionEncoder} you specified

+ * Call {@link #codec(Http2ConnectionDecoder, Http2ConnectionEncoder)} to tell the builder that you want to built the + * handler from the {@link Http2ConnectionDecoder} and {@link Http2ConnectionEncoder} you specified. Setting the + * following properties are prohibited and thus will trigger an {@link IllegalStateException} because they conflict + * with this option: + *
    + *
  • {@link #server(boolean)}
  • + *
  • {@link #connection(Http2Connection)}
  • + *
  • {@link #frameLogger(Http2FrameLogger)}
  • + *
  • {@link #headerSensitivityDetector(SensitivityDetector)}
  • + *
  • {@link #encoderEnforceMaxConcurrentStreams(boolean)}
  • + *
  • {@link #encoderIgnoreMaxHeaderListSize(boolean)}
  • + *
+ * + *

Exposing necessary methods in a subclass

+ * {@link #build()} method and all property access methods are {@code protected}. Choose the methods to expose to the + * users of your builder implementation and make them {@code public}. + * + * @param The type of handler created by this builder. + * @param The concrete type of this builder. + */ +@UnstableApi +public abstract class AbstractHttp2ConnectionHandlerBuilder> { + + private static final SensitivityDetector DEFAULT_HEADER_SENSITIVITY_DETECTOR = Http2HeadersEncoder.NEVER_SENSITIVE; + + private static final int DEFAULT_MAX_RST_FRAMES_PER_CONNECTION_FOR_SERVER = 200; + + // The properties that can always be set. + private Http2Settings initialSettings = Http2Settings.defaultSettings(); + private Http2FrameListener frameListener; + private long gracefulShutdownTimeoutMillis = Http2CodecUtil.DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_MILLIS; + private boolean decoupleCloseAndGoAway; + private boolean flushPreface = true; + + // The property that will prohibit connection() and codec() if set by server(), + // because this property is used only when this builder creates an Http2Connection. + private Boolean isServer; + private Integer maxReservedStreams; + + // The property that will prohibit server() and codec() if set by connection(). + private Http2Connection connection; + + // The properties that will prohibit server() and connection() if set by codec(). 
+ private Http2ConnectionDecoder decoder; + private Http2ConnectionEncoder encoder; + + // The properties that are: + // * mutually exclusive against codec() and + // * OK to use with server() and connection() + private Boolean validateHeaders; + private Http2FrameLogger frameLogger; + private SensitivityDetector headerSensitivityDetector; + private Boolean encoderEnforceMaxConcurrentStreams; + private Boolean encoderIgnoreMaxHeaderListSize; + private Http2PromisedRequestVerifier promisedRequestVerifier = ALWAYS_VERIFY; + private boolean autoAckSettingsFrame = true; + private boolean autoAckPingFrame = true; + private int maxQueuedControlFrames = Http2CodecUtil.DEFAULT_MAX_QUEUED_CONTROL_FRAMES; + private int maxConsecutiveEmptyFrames = 2; + private Integer maxRstFramesPerWindow; + private int secondsPerWindow = 30; + + /** + * Sets the {@link Http2Settings} to use for the initial connection settings exchange. + */ + protected Http2Settings initialSettings() { + return initialSettings; + } + + /** + * Sets the {@link Http2Settings} to use for the initial connection settings exchange. + */ + protected B initialSettings(Http2Settings settings) { + initialSettings = checkNotNull(settings, "settings"); + return self(); + } + + /** + * Returns the listener of inbound frames. + * + * @return {@link Http2FrameListener} if set, or {@code null} if not set. + */ + protected Http2FrameListener frameListener() { + return frameListener; + } + + /** + * Sets the listener of inbound frames. + * This listener will only be set if the decoder's listener is {@code null}. + */ + protected B frameListener(Http2FrameListener frameListener) { + this.frameListener = checkNotNull(frameListener, "frameListener"); + return self(); + } + + /** + * Returns the graceful shutdown timeout of the {@link Http2Connection} in milliseconds. Returns -1 if the + * timeout is indefinite. 
+ */ + protected long gracefulShutdownTimeoutMillis() { + return gracefulShutdownTimeoutMillis; + } + + /** + * Sets the graceful shutdown timeout of the {@link Http2Connection} in milliseconds. + */ + protected B gracefulShutdownTimeoutMillis(long gracefulShutdownTimeoutMillis) { + if (gracefulShutdownTimeoutMillis < -1) { + throw new IllegalArgumentException("gracefulShutdownTimeoutMillis: " + gracefulShutdownTimeoutMillis + + " (expected: -1 for indefinite or >= 0)"); + } + this.gracefulShutdownTimeoutMillis = gracefulShutdownTimeoutMillis; + return self(); + } + + /** + * Returns if {@link #build()} will to create a {@link Http2Connection} in server mode ({@code true}) + * or client mode ({@code false}). + */ + protected boolean isServer() { + return isServer != null ? isServer : true; + } + + /** + * Sets if {@link #build()} will to create a {@link Http2Connection} in server mode ({@code true}) + * or client mode ({@code false}). + */ + protected B server(boolean isServer) { + enforceConstraint("server", "connection", connection); + enforceConstraint("server", "codec", decoder); + enforceConstraint("server", "codec", encoder); + + this.isServer = isServer; + return self(); + } + + /** + * Get the maximum number of streams which can be in the reserved state at any given time. + *

+ * By default this value will be ignored on the server for local endpoint. This is because the RFC provides + * no way to explicitly communicate a limit to how many states can be in the reserved state, and instead relies + * on the peer to send RST_STREAM frames when they will be rejected. + */ + protected int maxReservedStreams() { + return maxReservedStreams != null ? maxReservedStreams : DEFAULT_MAX_RESERVED_STREAMS; + } + + /** + * Set the maximum number of streams which can be in the reserved state at any given time. + */ + protected B maxReservedStreams(int maxReservedStreams) { + enforceConstraint("server", "connection", connection); + enforceConstraint("server", "codec", decoder); + enforceConstraint("server", "codec", encoder); + + this.maxReservedStreams = checkPositiveOrZero(maxReservedStreams, "maxReservedStreams"); + return self(); + } + + /** + * Returns the {@link Http2Connection} to use. + * + * @return {@link Http2Connection} if set, or {@code null} if not set. + */ + protected Http2Connection connection() { + return connection; + } + + /** + * Sets the {@link Http2Connection} to use. + */ + protected B connection(Http2Connection connection) { + enforceConstraint("connection", "maxReservedStreams", maxReservedStreams); + enforceConstraint("connection", "server", isServer); + enforceConstraint("connection", "codec", decoder); + enforceConstraint("connection", "codec", encoder); + + this.connection = checkNotNull(connection, "connection"); + + return self(); + } + + /** + * Returns the {@link Http2ConnectionDecoder} to use. + * + * @return {@link Http2ConnectionDecoder} if set, or {@code null} if not set. + */ + protected Http2ConnectionDecoder decoder() { + return decoder; + } + + /** + * Returns the {@link Http2ConnectionEncoder} to use. + * + * @return {@link Http2ConnectionEncoder} if set, or {@code null} if not set. 
+ */ + protected Http2ConnectionEncoder encoder() { + return encoder; + } + + /** + * Sets the {@link Http2ConnectionDecoder} and {@link Http2ConnectionEncoder} to use. + */ + protected B codec(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder) { + enforceConstraint("codec", "server", isServer); + enforceConstraint("codec", "maxReservedStreams", maxReservedStreams); + enforceConstraint("codec", "connection", connection); + enforceConstraint("codec", "frameLogger", frameLogger); + enforceConstraint("codec", "validateHeaders", validateHeaders); + enforceConstraint("codec", "headerSensitivityDetector", headerSensitivityDetector); + enforceConstraint("codec", "encoderEnforceMaxConcurrentStreams", encoderEnforceMaxConcurrentStreams); + + checkNotNull(decoder, "decoder"); + checkNotNull(encoder, "encoder"); + + if (decoder.connection() != encoder.connection()) { + throw new IllegalArgumentException("The specified encoder and decoder have different connections."); + } + + this.decoder = decoder; + this.encoder = encoder; + + return self(); + } + + /** + * Returns if HTTP headers should be validated according to + * RFC 7540, 8.1.2.6. + */ + protected boolean isValidateHeaders() { + return validateHeaders != null ? validateHeaders : true; + } + + /** + * Sets if HTTP headers should be validated according to + * RFC 7540, 8.1.2.6. + */ + protected B validateHeaders(boolean validateHeaders) { + enforceNonCodecConstraints("validateHeaders"); + this.validateHeaders = validateHeaders; + return self(); + } + + /** + * Returns the logger that is used for the encoder and decoder. + * + * @return {@link Http2FrameLogger} if set, or {@code null} if not set. + */ + protected Http2FrameLogger frameLogger() { + return frameLogger; + } + + /** + * Sets the logger that is used for the encoder and decoder. 
+ */ + protected B frameLogger(Http2FrameLogger frameLogger) { + enforceNonCodecConstraints("frameLogger"); + this.frameLogger = checkNotNull(frameLogger, "frameLogger"); + return self(); + } + + /** + * Returns if the encoder should queue frames if the maximum number of concurrent streams + * would otherwise be exceeded. + */ + protected boolean encoderEnforceMaxConcurrentStreams() { + return encoderEnforceMaxConcurrentStreams != null ? encoderEnforceMaxConcurrentStreams : false; + } + + /** + * Sets if the encoder should queue frames if the maximum number of concurrent streams + * would otherwise be exceeded. + */ + protected B encoderEnforceMaxConcurrentStreams(boolean encoderEnforceMaxConcurrentStreams) { + enforceNonCodecConstraints("encoderEnforceMaxConcurrentStreams"); + this.encoderEnforceMaxConcurrentStreams = encoderEnforceMaxConcurrentStreams; + return self(); + } + + /** + * Returns the maximum number of queued control frames that are allowed before the connection is closed. + * This allows to protected against various attacks that can lead to high CPU / memory usage if the remote-peer + * floods us with frames that would have us produce control frames, but stops to read from the underlying socket. + * + * {@code 0} means no protection is in place. + */ + protected int encoderEnforceMaxQueuedControlFrames() { + return maxQueuedControlFrames; + } + + /** + * Sets the maximum number of queued control frames that are allowed before the connection is closed. + * This allows to protected against various attacks that can lead to high CPU / memory usage if the remote-peer + * floods us with frames that would have us produce control frames, but stops to read from the underlying socket. + * + * {@code 0} means no protection should be applied. 
+ */ + protected B encoderEnforceMaxQueuedControlFrames(int maxQueuedControlFrames) { + enforceNonCodecConstraints("encoderEnforceMaxQueuedControlFrames"); + this.maxQueuedControlFrames = checkPositiveOrZero(maxQueuedControlFrames, "maxQueuedControlFrames"); + return self(); + } + + /** + * Returns the {@link SensitivityDetector} to use. + */ + protected SensitivityDetector headerSensitivityDetector() { + return headerSensitivityDetector != null ? headerSensitivityDetector : DEFAULT_HEADER_SENSITIVITY_DETECTOR; + } + + /** + * Sets the {@link SensitivityDetector} to use. + */ + protected B headerSensitivityDetector(SensitivityDetector headerSensitivityDetector) { + enforceNonCodecConstraints("headerSensitivityDetector"); + this.headerSensitivityDetector = checkNotNull(headerSensitivityDetector, "headerSensitivityDetector"); + return self(); + } + + /** + * Sets if the SETTINGS_MAX_HEADER_LIST_SIZE + * should be ignored when encoding headers. + * @param ignoreMaxHeaderListSize {@code true} to ignore + * SETTINGS_MAX_HEADER_LIST_SIZE. + * @return this. + */ + protected B encoderIgnoreMaxHeaderListSize(boolean ignoreMaxHeaderListSize) { + enforceNonCodecConstraints("encoderIgnoreMaxHeaderListSize"); + encoderIgnoreMaxHeaderListSize = ignoreMaxHeaderListSize; + return self(); + } + + /** + * Does nothing, do not call. + * + * @deprecated Huffman decoding no longer depends on having a decode capacity. + */ + @Deprecated + protected B initialHuffmanDecodeCapacity(int initialHuffmanDecodeCapacity) { + return self(); + } + + /** + * Set the {@link Http2PromisedRequestVerifier} to use. + * @return this. + */ + protected B promisedRequestVerifier(Http2PromisedRequestVerifier promisedRequestVerifier) { + enforceNonCodecConstraints("promisedRequestVerifier"); + this.promisedRequestVerifier = checkNotNull(promisedRequestVerifier, "promisedRequestVerifier"); + return self(); + } + + /** + * Get the {@link Http2PromisedRequestVerifier} to use. 
+ * @return the {@link Http2PromisedRequestVerifier} to use. + */ + protected Http2PromisedRequestVerifier promisedRequestVerifier() { + return promisedRequestVerifier; + } + + /** + * Returns the maximum number of consecutive empty DATA frames (without end_of_stream flag) that are allowed before + * the connection is closed. This allows to protect against the remote peer flooding us with such frames and + * so use up a lot of CPU. There is no valid use-case for empty DATA frames without end_of_stream flag. + * + * {@code 0} means no protection is in place. + */ + protected int decoderEnforceMaxConsecutiveEmptyDataFrames() { + return maxConsecutiveEmptyFrames; + } + + /** + * Sets the maximum number of consecutive empty DATA frames (without end_of_stream flag) that are allowed before + * the connection is closed. This allows to protect against the remote peer flooding us with such frames and + * so use up a lot of CPU. There is no valid use-case for empty DATA frames without end_of_stream flag. + * + * {@code 0} means no protection should be applied. + */ + protected B decoderEnforceMaxConsecutiveEmptyDataFrames(int maxConsecutiveEmptyFrames) { + enforceNonCodecConstraints("maxConsecutiveEmptyFrames"); + this.maxConsecutiveEmptyFrames = checkPositiveOrZero( + maxConsecutiveEmptyFrames, "maxConsecutiveEmptyFrames"); + return self(); + } + + /** + * Sets the maximum number RST frames that are allowed per window before + * the connection is closed. This allows to protect against the remote peer flooding us with such frames and + * so use up a lot of CPU. + * + * {@code 0} for any of the parameters means no protection should be applied. 
+ */ + protected B decoderEnforceMaxRstFramesPerWindow(int maxRstFramesPerWindow, int secondsPerWindow) { + enforceNonCodecConstraints("decoderEnforceMaxRstFramesPerWindow"); + this.maxRstFramesPerWindow = checkPositiveOrZero( + maxRstFramesPerWindow, "maxRstFramesPerWindow"); + this.secondsPerWindow = checkPositiveOrZero(secondsPerWindow, "secondsPerWindow"); + return self(); + } + + /** + * Determine if settings frame should automatically be acknowledged and applied. + * @return this. + */ + protected B autoAckSettingsFrame(boolean autoAckSettings) { + enforceNonCodecConstraints("autoAckSettingsFrame"); + autoAckSettingsFrame = autoAckSettings; + return self(); + } + + /** + * Determine if the SETTINGS frames should be automatically acknowledged and applied. + * @return {@code true} if the SETTINGS frames should be automatically acknowledged and applied. + */ + protected boolean isAutoAckSettingsFrame() { + return autoAckSettingsFrame; + } + + /** + * Determine if PING frame should automatically be acknowledged or not. + * @return this. + */ + protected B autoAckPingFrame(boolean autoAckPingFrame) { + enforceNonCodecConstraints("autoAckPingFrame"); + this.autoAckPingFrame = autoAckPingFrame; + return self(); + } + + /** + * Determine if the PING frames should be automatically acknowledged or not. + * @return {@code true} if the PING frames should be automatically acknowledged. + */ + protected boolean isAutoAckPingFrame() { + return autoAckPingFrame; + } + + /** + * Determine if the {@link Channel#close()} should be coupled with goaway and graceful close. + * @param decoupleCloseAndGoAway {@code true} to make {@link Channel#close()} directly close the underlying + * transport, and not attempt graceful closure via GOAWAY. + * @return {@code this}. 
+ */ + protected B decoupleCloseAndGoAway(boolean decoupleCloseAndGoAway) { + this.decoupleCloseAndGoAway = decoupleCloseAndGoAway; + return self(); + } + + /** + * Determine if the {@link Channel#close()} should be coupled with goaway and graceful close. + */ + protected boolean decoupleCloseAndGoAway() { + return decoupleCloseAndGoAway; + } + + /** + * Determine if the Preface + * should be automatically flushed when the {@link Channel} becomes active or not. + *

+ * Client may choose to opt-out from this automatic behavior and manage flush manually if it's ready to send + * request frames immediately after the preface. It may help to avoid unnecessary latency. + * + * @param flushPreface {@code true} to automatically flush, {@code false otherwise}. + * @return {@code this}. + * @see HTTP/2 Connection Preface + */ + protected B flushPreface(boolean flushPreface) { + this.flushPreface = flushPreface; + return self(); + } + + /** + * Determine if the Preface + * should be automatically flushed when the {@link Channel} becomes active or not. + *

+ * Client may choose to opt-out from this automatic behavior and manage flush manually if it's ready to send + * request frames immediately after the preface. It may help to avoid unnecessary latency. + * + * @return {@code true} if automatically flushed. + * @see HTTP/2 Connection Preface + */ + protected boolean flushPreface() { + return flushPreface; + } + + /** + * Create a new {@link Http2ConnectionHandler}. + */ + protected T build() { + if (encoder != null) { + assert decoder != null; + return buildFromCodec(decoder, encoder); + } + + Http2Connection connection = this.connection; + if (connection == null) { + connection = new DefaultHttp2Connection(isServer(), maxReservedStreams()); + } + + return buildFromConnection(connection); + } + + private T buildFromConnection(Http2Connection connection) { + Long maxHeaderListSize = initialSettings.maxHeaderListSize(); + Http2FrameReader reader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(isValidateHeaders(), + maxHeaderListSize == null ? DEFAULT_HEADER_LIST_SIZE : maxHeaderListSize, + /* initialHuffmanDecodeCapacity= */ -1)); + Http2FrameWriter writer = encoderIgnoreMaxHeaderListSize == null ? 
+ new DefaultHttp2FrameWriter(headerSensitivityDetector()) : + new DefaultHttp2FrameWriter(headerSensitivityDetector(), encoderIgnoreMaxHeaderListSize); + + if (frameLogger != null) { + reader = new Http2InboundFrameLogger(reader, frameLogger); + writer = new Http2OutboundFrameLogger(writer, frameLogger); + } + + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, writer); + boolean encoderEnforceMaxConcurrentStreams = encoderEnforceMaxConcurrentStreams(); + + if (maxQueuedControlFrames != 0) { + encoder = new Http2ControlFrameLimitEncoder(encoder, maxQueuedControlFrames); + } + if (encoderEnforceMaxConcurrentStreams) { + if (connection.isServer()) { + encoder.close(); + reader.close(); + throw new IllegalArgumentException( + "encoderEnforceMaxConcurrentStreams: " + encoderEnforceMaxConcurrentStreams + + " not supported for server"); + } + encoder = new StreamBufferingEncoder(encoder); + } + + DefaultHttp2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, reader, + promisedRequestVerifier(), isAutoAckSettingsFrame(), isAutoAckPingFrame(), isValidateHeaders()); + return buildFromCodec(decoder, encoder); + } + + private T buildFromCodec(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder) { + int maxConsecutiveEmptyDataFrames = decoderEnforceMaxConsecutiveEmptyDataFrames(); + if (maxConsecutiveEmptyDataFrames > 0) { + decoder = new Http2EmptyDataFrameConnectionDecoder(decoder, maxConsecutiveEmptyDataFrames); + } + final int maxRstFrames; + if (maxRstFramesPerWindow == null) { + // Only enable by default on the server. 
+ if (isServer()) { + maxRstFrames = DEFAULT_MAX_RST_FRAMES_PER_CONNECTION_FOR_SERVER; + } else { + maxRstFrames = 0; + } + } else { + maxRstFrames = maxRstFramesPerWindow; + } + if (maxRstFrames > 0 && secondsPerWindow > 0) { + decoder = new Http2MaxRstFrameDecoder(decoder, maxRstFrames, secondsPerWindow); + } + final T handler; + try { + // Call the abstract build method + handler = build(decoder, encoder, initialSettings); + } catch (Throwable t) { + encoder.close(); + decoder.close(); + throw new IllegalStateException("failed to build an Http2ConnectionHandler", t); + } + + // Setup post build options + handler.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis); + if (handler.decoder().frameListener() == null) { + handler.decoder().frameListener(frameListener); + } + return handler; + } + + /** + * Implement this method to create a new {@link Http2ConnectionHandler} or its subtype instance. + *

+ * The return of this method will be subject to the following: + *

    + *
  • {@link #frameListener(Http2FrameListener)} will be set if not already set in the decoder
  • + *
  • {@link #gracefulShutdownTimeoutMillis(long)} will always be set
  • + *
+ */ + protected abstract T build(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings) throws Exception; + + /** + * Returns {@code this}. + */ + @SuppressWarnings("unchecked") + protected final B self() { + return (B) this; + } + + private void enforceNonCodecConstraints(String rejected) { + enforceConstraint(rejected, "server/connection", decoder); + enforceConstraint(rejected, "server/connection", encoder); + } + + private static void enforceConstraint(String methodName, String rejectorName, Object value) { + if (value != null) { + throw new IllegalStateException( + methodName + "() cannot be called because " + rejectorName + "() has been called already."); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java new file mode 100644 index 0000000..49882ac --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamChannel.java @@ -0,0 +1,1160 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelId; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelProgressivePromise; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.channel.MessageSizeEstimator; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.channel.VoidChannelPromise; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.handler.codec.http2.Http2FrameCodec.DefaultHttp2FrameStream; +import io.netty.handler.ssl.SslCloseCompletionEvent; +import io.netty.util.DefaultAttributeMap; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; + +import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Math.min; + +abstract class AbstractHttp2StreamChannel extends 
DefaultAttributeMap implements Http2StreamChannel { + + static final Http2FrameStreamVisitor WRITABLE_VISITOR = new Http2FrameStreamVisitor() { + @Override + public boolean visit(Http2FrameStream stream) { + final AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + childChannel.trySetWritable(); + return true; + } + }; + + static final Http2FrameStreamVisitor CHANNEL_INPUT_SHUTDOWN_READ_COMPLETE_VISITOR = + new UserEventStreamVisitor(ChannelInputShutdownReadComplete.INSTANCE); + + static final Http2FrameStreamVisitor CHANNEL_OUTPUT_SHUTDOWN_EVENT_VISITOR = + new UserEventStreamVisitor(ChannelOutputShutdownEvent.INSTANCE); + + static final Http2FrameStreamVisitor SSL_CLOSE_COMPLETION_EVENT_VISITOR = + new UserEventStreamVisitor(SslCloseCompletionEvent.SUCCESS); + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractHttp2StreamChannel.class); + + private static final ChannelMetadata METADATA = new ChannelMetadata(false, 16); + + /** + * Number of bytes to consider non-payload messages. 9 is arbitrary, but also the minimum size of an HTTP/2 frame. + * Primarily is non-zero. + */ + private static final int MIN_HTTP2_FRAME_SIZE = 9; + + /** + * {@link Http2FrameStreamVisitor} that fires the user event for every active stream pipeline. + */ + private static final class UserEventStreamVisitor implements Http2FrameStreamVisitor { + + private final Object event; + + UserEventStreamVisitor(Object event) { + this.event = checkNotNull(event, "event"); + } + + @Override + public boolean visit(Http2FrameStream stream) { + final AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + childChannel.pipeline().fireUserEventTriggered(event); + return true; + } + } + + /** + * Returns the flow-control size for DATA frames, and {@value MIN_HTTP2_FRAME_SIZE} for all other frames. 
+ */ + private static final class FlowControlledFrameSizeEstimator implements MessageSizeEstimator { + + static final FlowControlledFrameSizeEstimator INSTANCE = new FlowControlledFrameSizeEstimator(); + + private static final Handle HANDLE_INSTANCE = new Handle() { + @Override + public int size(Object msg) { + return msg instanceof Http2DataFrame ? + // Guard against overflow. + (int) min(Integer.MAX_VALUE, ((Http2DataFrame) msg).initialFlowControlledBytes() + + (long) MIN_HTTP2_FRAME_SIZE) : MIN_HTTP2_FRAME_SIZE; + } + }; + + @Override + public Handle newHandle() { + return HANDLE_INSTANCE; + } + } + + private static final AtomicLongFieldUpdater TOTAL_PENDING_SIZE_UPDATER = + AtomicLongFieldUpdater.newUpdater(AbstractHttp2StreamChannel.class, "totalPendingSize"); + + private static final AtomicIntegerFieldUpdater UNWRITABLE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(AbstractHttp2StreamChannel.class, "unwritable"); + + private static void windowUpdateFrameWriteComplete(ChannelFuture future, Channel streamChannel) { + Throwable cause = future.cause(); + if (cause != null) { + Throwable unwrappedCause; + // Unwrap if needed + if (cause instanceof Http2FrameStreamException && (unwrappedCause = cause.getCause()) != null) { + cause = unwrappedCause; + } + + // Notify the child-channel and close it. + streamChannel.pipeline().fireExceptionCaught(cause); + streamChannel.unsafe().close(streamChannel.unsafe().voidPromise()); + } + } + + private final ChannelFutureListener windowUpdateFrameWriteListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + windowUpdateFrameWriteComplete(future, AbstractHttp2StreamChannel.this); + } + }; + + /** + * The current status of the read-processing for a {@link AbstractHttp2StreamChannel}. 
+ */ + private enum ReadStatus { + /** + * No read in progress and no read was requested (yet) + */ + IDLE, + + /** + * Reading in progress + */ + IN_PROGRESS, + + /** + * A read operation was requested. + */ + REQUESTED + } + + private final Http2StreamChannelConfig config = new Http2StreamChannelConfig(this); + private final Http2ChannelUnsafe unsafe = new Http2ChannelUnsafe(); + private final ChannelId channelId; + private final ChannelPipeline pipeline; + private final DefaultHttp2FrameStream stream; + private final ChannelPromise closePromise; + + private volatile boolean registered; + + private volatile long totalPendingSize; + private volatile int unwritable; + + // Cached to reduce GC + private Runnable fireChannelWritabilityChangedTask; + + private boolean outboundClosed; + private int flowControlledBytes; + + /** + * This variable represents if a read is in progress for the current channel or was requested. + * Note that depending upon the {@link RecvByteBufAllocator} behavior a read may extend beyond the + * {@link Http2ChannelUnsafe#beginRead()} method scope. The {@link Http2ChannelUnsafe#beginRead()} loop may + * drain all pending data, and then if the parent channel is reading this channel may still accept frames. 
+ */ + private ReadStatus readStatus = ReadStatus.IDLE; + + private Queue inboundBuffer; + + /** {@code true} after the first HEADERS frame has been written **/ + private boolean firstFrameWritten; + private boolean readCompletePending; + + AbstractHttp2StreamChannel(DefaultHttp2FrameStream stream, int id, ChannelHandler inboundHandler) { + this.stream = stream; + stream.attachment = this; + pipeline = new DefaultChannelPipeline(this) { + @Override + protected void incrementPendingOutboundBytes(long size) { + AbstractHttp2StreamChannel.this.incrementPendingOutboundBytes(size, true); + } + + @Override + protected void decrementPendingOutboundBytes(long size) { + AbstractHttp2StreamChannel.this.decrementPendingOutboundBytes(size, true); + } + + @Override + protected void onUnhandledInboundException(Throwable cause) { + // Ensure we use the correct Http2Error to close the channel. + if (cause instanceof Http2FrameStreamException) { + closeWithError(((Http2FrameStreamException) cause).error()); + return; + } else { + Http2Exception exception = Http2CodecUtil.getEmbeddedHttp2Exception(cause); + if (exception != null) { + closeWithError(exception.error()); + return; + } + } + super.onUnhandledInboundException(cause); + } + }; + + closePromise = pipeline.newPromise(); + channelId = new Http2StreamChannelId(parent().id(), id); + + if (inboundHandler != null) { + // Add the handler to the pipeline now that we are registered. 
+ pipeline.addLast(inboundHandler); + } + } + + private void incrementPendingOutboundBytes(long size, boolean invokeLater) { + if (size == 0) { + return; + } + + long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, size); + if (newWriteBufferSize > config().getWriteBufferHighWaterMark()) { + setUnwritable(invokeLater); + } + } + + private void decrementPendingOutboundBytes(long size, boolean invokeLater) { + if (size == 0) { + return; + } + + long newWriteBufferSize = TOTAL_PENDING_SIZE_UPDATER.addAndGet(this, -size); + // Once the totalPendingSize dropped below the low water-mark we can mark the child channel + // as writable again. Before doing so we also need to ensure the parent channel is writable to + // prevent excessive buffering in the parent outbound buffer. If the parent is not writable + // we will mark the child channel as writable once the parent becomes writable by calling + // trySetWritable() later. + if (newWriteBufferSize < config().getWriteBufferLowWaterMark() && parent().isWritable()) { + setWritable(invokeLater); + } + } + + final void trySetWritable() { + // The parent is writable again but the child channel itself may still not be writable. + // Lets try to set the child channel writable to match the state of the parent channel + // if (and only if) the totalPendingSize is smaller then the low water-mark. + // If this is not the case we will try again later once we drop under it. 
+ if (totalPendingSize < config().getWriteBufferLowWaterMark()) { + setWritable(false); + } + } + + private void setWritable(boolean invokeLater) { + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue & ~1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue != 0 && newValue == 0) { + fireChannelWritabilityChanged(invokeLater); + } + break; + } + } + } + + private void setUnwritable(boolean invokeLater) { + for (;;) { + final int oldValue = unwritable; + final int newValue = oldValue | 1; + if (UNWRITABLE_UPDATER.compareAndSet(this, oldValue, newValue)) { + if (oldValue == 0) { + fireChannelWritabilityChanged(invokeLater); + } + break; + } + } + } + + private void fireChannelWritabilityChanged(boolean invokeLater) { + final ChannelPipeline pipeline = pipeline(); + if (invokeLater) { + Runnable task = fireChannelWritabilityChangedTask; + if (task == null) { + fireChannelWritabilityChangedTask = task = new Runnable() { + @Override + public void run() { + pipeline.fireChannelWritabilityChanged(); + } + }; + } + eventLoop().execute(task); + } else { + pipeline.fireChannelWritabilityChanged(); + } + } + @Override + public Http2FrameStream stream() { + return stream; + } + + void closeOutbound() { + outboundClosed = true; + } + + void streamClosed() { + unsafe.readEOS(); + // Attempt to drain any queued data from the queue and deliver it to the application before closing this + // channel. 
+ unsafe.doBeginRead(); + } + + @Override + public ChannelMetadata metadata() { + return METADATA; + } + + @Override + public ChannelConfig config() { + return config; + } + + @Override + public boolean isOpen() { + return !closePromise.isDone(); + } + + @Override + public boolean isActive() { + return isOpen(); + } + + @Override + public boolean isWritable() { + return unwritable == 0; + } + + @Override + public ChannelId id() { + return channelId; + } + + @Override + public EventLoop eventLoop() { + return parent().eventLoop(); + } + + @Override + public Channel parent() { + return parentContext().channel(); + } + + @Override + public boolean isRegistered() { + return registered; + } + + @Override + public SocketAddress localAddress() { + return parent().localAddress(); + } + + @Override + public SocketAddress remoteAddress() { + return parent().remoteAddress(); + } + + @Override + public ChannelFuture closeFuture() { + return closePromise; + } + + @Override + public long bytesBeforeUnwritable() { + // +1 because writability doesn't change until the threshold is crossed (not equal to). + long bytes = config().getWriteBufferHighWaterMark() - totalPendingSize + 1; + // If bytes is negative we know we are not writable, but if bytes is non-negative we have to check + // writability. Note that totalPendingSize and isWritable() use different volatile variables that are not + // synchronized together. totalPendingSize will be updated before isWritable(). + return bytes > 0 && isWritable() ? bytes : 0; + } + + @Override + public long bytesBeforeWritable() { + // +1 because writability doesn't change until the threshold is crossed (not equal to). + long bytes = totalPendingSize - config().getWriteBufferLowWaterMark() + 1; + // If bytes is negative we know we are writable, but if bytes is non-negative we have to check writability. + // Note that totalPendingSize and isWritable() use different volatile variables that are not synchronized + // together. 
totalPendingSize will be updated before isWritable(). + return bytes <= 0 || isWritable() ? 0 : bytes; + } + + @Override + public Unsafe unsafe() { + return unsafe; + } + + @Override + public ChannelPipeline pipeline() { + return pipeline; + } + + @Override + public ByteBufAllocator alloc() { + return config().getAllocator(); + } + + @Override + public Channel read() { + pipeline().read(); + return this; + } + + @Override + public Channel flush() { + pipeline().flush(); + return this; + } + + @Override + public ChannelFuture bind(SocketAddress localAddress) { + return pipeline().bind(localAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress) { + return pipeline().connect(remoteAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return pipeline().connect(remoteAddress, localAddress); + } + + @Override + public ChannelFuture disconnect() { + return pipeline().disconnect(); + } + + @Override + public ChannelFuture close() { + return pipeline().close(); + } + + @Override + public ChannelFuture deregister() { + return pipeline().deregister(); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return pipeline().bind(localAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return pipeline().connect(remoteAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return pipeline().connect(remoteAddress, localAddress, promise); + } + + @Override + public ChannelFuture disconnect(ChannelPromise promise) { + return pipeline().disconnect(promise); + } + + @Override + public ChannelFuture close(ChannelPromise promise) { + return pipeline().close(promise); + } + + @Override + public ChannelFuture deregister(ChannelPromise promise) { + return 
pipeline().deregister(promise); + } + + @Override + public ChannelFuture write(Object msg) { + return pipeline().write(msg); + } + + @Override + public ChannelFuture write(Object msg, ChannelPromise promise) { + return pipeline().write(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + return pipeline().writeAndFlush(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg) { + return pipeline().writeAndFlush(msg); + } + + @Override + public ChannelPromise newPromise() { + return pipeline().newPromise(); + } + + @Override + public ChannelProgressivePromise newProgressivePromise() { + return pipeline().newProgressivePromise(); + } + + @Override + public ChannelFuture newSucceededFuture() { + return pipeline().newSucceededFuture(); + } + + @Override + public ChannelFuture newFailedFuture(Throwable cause) { + return pipeline().newFailedFuture(cause); + } + + @Override + public ChannelPromise voidPromise() { + return pipeline().voidPromise(); + } + + @Override + public int hashCode() { + return id().hashCode(); + } + + @Override + public boolean equals(Object o) { + return this == o; + } + + @Override + public int compareTo(Channel o) { + if (this == o) { + return 0; + } + + return id().compareTo(o.id()); + } + + @Override + public String toString() { + return parent().toString() + "(H2 - " + stream + ')'; + } + + /** + * Receive a read message. This does not notify handlers unless a read is in progress on the + * channel. + */ + void fireChildRead(Http2Frame frame) { + assert eventLoop().inEventLoop(); + if (!isActive()) { + ReferenceCountUtil.release(frame); + } else if (readStatus != ReadStatus.IDLE) { + // If a read is in progress or has been requested, there cannot be anything in the queue, + // otherwise we would have drained it from the queue and processed it during the read cycle. 
            // A read is in progress or was requested, so the queue must already have been drained
            // by the read cycle — deliver the frame straight to the pipeline.
            assert inboundBuffer == null || inboundBuffer.isEmpty();
            final RecvByteBufAllocator.Handle allocHandle = unsafe.recvBufAllocHandle();
            unsafe.doRead0(frame, allocHandle);
            // We currently don't need to check for readEOS because the parent channel and child channel are limited
            // to the same EventLoop thread. There are a limited number of frame types that may come after EOS is
            // read (unknown, reset) and the trade off is less conditionals for the hot path (headers/data) at the
            // cost of additional readComplete notifications on the rare path.
            if (allocHandle.continueReading()) {
                maybeAddChannelToReadCompletePendingQueue();
            } else {
                unsafe.notifyReadComplete(allocHandle, true);
            }
        } else {
            // No read in progress: buffer the frame until the child channel asks for data.
            if (inboundBuffer == null) {
                inboundBuffer = new ArrayDeque(4);
            }
            inboundBuffer.add(frame);
        }
    }

    // Called by the parent channel when its read loop completes; forwards the readComplete
    // notification to this child channel (non-forced — only fires if one is pending).
    void fireChildReadComplete() {
        assert eventLoop().inEventLoop();
        assert readStatus != ReadStatus.IDLE || !readCompletePending;
        unsafe.notifyReadComplete(unsafe.recvBufAllocHandle(), false);
    }

    // Closes this stream channel, sending a RST_STREAM with the given error code when applicable
    // (see close(ChannelPromise, Http2Error) below for the conditions).
    final void closeWithError(Http2Error error) {
        assert eventLoop().inEventLoop();
        unsafe.close(unsafe.voidPromise(), error);
    }

    /**
     * {@link Unsafe} implementation for the child stream channel. Parent and child channel are
     * limited to the same {@link EventLoop} thread (see the readEOS comment in fireChildRead),
     * so no extra synchronization is performed here.
     */
    private final class Http2ChannelUnsafe implements Unsafe {
        private final VoidChannelPromise unsafeVoidPromise =
                new VoidChannelPromise(AbstractHttp2StreamChannel.this, false);
        @SuppressWarnings("deprecation")
        private RecvByteBufAllocator.Handle recvHandle;
        // true once a write completed without a subsequent flush; used by flush() to aggregate flushes.
        private boolean writeDoneAndNoFlush;
        // Guards against running the close logic more than once.
        private boolean closeInitiated;
        // Set once the inbound side saw end-of-stream.
        private boolean readEOS;

        @Override
        public void connect(final SocketAddress remoteAddress,
                SocketAddress localAddress, final ChannelPromise promise) {
            // connect() has no meaning for an HTTP/2 stream child channel.
            if (!promise.setUncancellable()) {
                return;
            }
            promise.setFailure(new UnsupportedOperationException());
        }

        @Override
        public RecvByteBufAllocator.Handle recvBufAllocHandle() {
            // Lazily create the handle and prime it with the current channel config.
            if (recvHandle == null) {
                recvHandle = config().getRecvByteBufAllocator().newHandle();
                recvHandle.reset(config());
            }
            return recvHandle;
        }

        @Override
        public SocketAddress localAddress() {
            // Delegates to the parent (connection) channel.
            return parent().unsafe().localAddress();
        }

        @Override
        public SocketAddress remoteAddress() {
            // Delegates to the parent (connection) channel.
            return parent().unsafe().remoteAddress();
        }

        @Override
        public void register(EventLoop eventLoop, ChannelPromise promise) {
            if (!promise.setUncancellable()) {
                return;
            }
            if (registered) {
                promise.setFailure(new UnsupportedOperationException("Re-register is not supported"));
                return;
            }

            registered = true;

            promise.setSuccess();

            pipeline().fireChannelRegistered();
            if (isActive()) {
                pipeline().fireChannelActive();
            }
        }

        @Override
        public void bind(SocketAddress localAddress, ChannelPromise promise) {
            // bind() has no meaning for an HTTP/2 stream child channel.
            if (!promise.setUncancellable()) {
                return;
            }
            promise.setFailure(new UnsupportedOperationException());
        }

        @Override
        public void disconnect(ChannelPromise promise) {
            close(promise);
        }

        @Override
        public void close(final ChannelPromise promise) {
            close(promise, Http2Error.CANCEL);
        }

        /**
         * Closes the stream channel; {@code error} is the RST_STREAM error code used when a reset
         * frame still needs to be sent to the remote peer.
         */
        void close(final ChannelPromise promise, Http2Error error) {
            if (!promise.setUncancellable()) {
                return;
            }
            if (closeInitiated) {
                if (closePromise.isDone()) {
                    // Closed already.
                    promise.setSuccess();
                } else if (!(promise instanceof VoidChannelPromise)) { // Only needed if no VoidChannelPromise.
                    // This means close() was called before so we just register a listener and return
                    closePromise.addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture future) {
                            promise.setSuccess();
                        }
                    });
                }
                return;
            }
            closeInitiated = true;
            // Just set to false as removing from an underlying queue would even be more expensive.
            readCompletePending = false;

            final boolean wasActive = isActive();

            // There is no need to update the local window as once the stream is closed all the pending bytes will be
            // given back to the connection window by the controller itself.

            // Only ever send a reset frame if the connection is still alive and if the stream was created before
            // as otherwise we may send a RST on a stream in an invalid state and cause a connection error.
            if (parent().isActive() && !readEOS && isStreamIdValid(stream.id())) {
                Http2StreamFrame resetFrame = new DefaultHttp2ResetFrame(error).stream(stream());
                write(resetFrame, unsafe().voidPromise());
                flush();
            }

            // Release any frames that were buffered but never delivered to the pipeline.
            if (inboundBuffer != null) {
                for (;;) {
                    Object msg = inboundBuffer.poll();
                    if (msg == null) {
                        break;
                    }
                    ReferenceCountUtil.release(msg);
                }
                inboundBuffer = null;
            }

            // The promise should be notified before we call fireChannelInactive().
            outboundClosed = true;
            closePromise.setSuccess();
            promise.setSuccess();

            fireChannelInactiveAndDeregister(voidPromise(), wasActive);
        }

        @Override
        public void closeForcibly() {
            close(unsafe().voidPromise());
        }

        @Override
        public void deregister(ChannelPromise promise) {
            fireChannelInactiveAndDeregister(promise, false);
        }

        private void fireChannelInactiveAndDeregister(final ChannelPromise promise,
                                                      final boolean fireChannelInactive) {
            if (!promise.setUncancellable()) {
                return;
            }

            if (!registered) {
                promise.setSuccess();
                return;
            }

            // As a user may call deregister() from within any method while doing processing in the ChannelPipeline,
            // we need to ensure we do the actual deregister operation later. This is necessary to preserve the
            // behavior of the AbstractChannel, which always invokes channelUnregistered and channelInactive
            // events 'later' to ensure the current events in the handler are completed before these events.
            //
            // See:
            // https://github.com/netty/netty/issues/4435
            invokeLater(new Runnable() {
                @Override
                public void run() {
                    if (fireChannelInactive) {
                        pipeline.fireChannelInactive();
                    }
                    // The user can fire `deregister` events multiple times but we only want to fire the pipeline
                    // event if the channel was actually registered.
                    if (registered) {
                        registered = false;
                        pipeline.fireChannelUnregistered();
                    }
                    safeSetSuccess(promise);
                }
            });
        }

        private void safeSetSuccess(ChannelPromise promise) {
            if (!(promise instanceof VoidChannelPromise) && !promise.trySuccess()) {
                logger.warn("Failed to mark a promise as success because it is done already: {}", promise);
            }
        }

        private void invokeLater(Runnable task) {
            try {
                // This method is used by outbound operation implementations to trigger an inbound event later.
                // They do not trigger an inbound event immediately because an outbound operation might have been
                // triggered by another inbound event handler method. If fired immediately, the call stack
                // will look like this for example:
                //
                // handlerA.inboundBufferUpdated() - (1) an inbound handler method closes a connection.
                // -> handlerA.ctx.close()
                // -> channel.unsafe.close()
                // -> handlerA.channelInactive() - (2) another inbound handler method called while in (1) yet
                //
                // which means the execution of two inbound handler methods of the same handler overlap undesirably.
                eventLoop().execute(task);
            } catch (RejectedExecutionException e) {
                logger.warn("Can't invoke task later as EventLoop rejected it", e);
            }
        }

        @Override
        public void beginRead() {
            if (!isActive()) {
                return;
            }
            updateLocalWindowIfNeeded();

            switch (readStatus) {
                case IDLE:
                    readStatus = ReadStatus.IN_PROGRESS;
                    doBeginRead();
                    break;
                case IN_PROGRESS:
                    // A read is already running; remember that another one was requested.
                    readStatus = ReadStatus.REQUESTED;
                    break;
                default:
                    break;
            }
        }

        private Object pollQueuedMessage() {
            return inboundBuffer == null ? null : inboundBuffer.poll();
        }

        void doBeginRead() {
            // Process messages until there are none left (or the user stopped requesting) and also handle EOS.
            while (readStatus != ReadStatus.IDLE) {
                Object message = pollQueuedMessage();
                if (message == null) {
                    if (readEOS) {
                        unsafe.closeForcibly();
                    }
                    // We need to double check that there is nothing left to flush such as a
                    // window update frame.
                    flush();
                    break;
                }
                final RecvByteBufAllocator.Handle allocHandle = recvBufAllocHandle();
                allocHandle.reset(config());
                boolean continueReading = false;
                do {
                    doRead0((Http2Frame) message, allocHandle);
                } while ((readEOS || (continueReading = allocHandle.continueReading()))
                        && (message = pollQueuedMessage()) != null);

                if (continueReading && isParentReadInProgress() && !readEOS) {
                    // Currently the parent and child channel are on the same EventLoop thread. If the parent is
                    // currently reading it is possible that more frames will be delivered to this child channel. In
                    // the case that this child channel still wants to read we delay the channelReadComplete on this
                    // child channel until the parent is done reading.
                    maybeAddChannelToReadCompletePendingQueue();
                } else {
                    notifyReadComplete(allocHandle, true);
                }
            }
        }

        void readEOS() {
            readEOS = true;
        }

        private void updateLocalWindowIfNeeded() {
            if (flowControlledBytes != 0) {
                int bytes = flowControlledBytes;
                flowControlledBytes = 0;
                ChannelFuture future = write0(parentContext(), new DefaultHttp2WindowUpdateFrame(bytes).stream(stream));
                // window update frames are commonly swallowed by the Http2FrameCodec and the promise is synchronously
                // completed but the flow controller _may_ have generated a wire level WINDOW_UPDATE. Therefore we need,
                // to assume there was a write done that needs to be flushed or we risk flow control starvation.
                writeDoneAndNoFlush = true;
                // Add a listener which will notify and teardown the stream
                // when a window update fails if needed or check the result of the future directly if it was completed
                // already.
                // See https://github.com/netty/netty/issues/9663
                if (future.isDone()) {
                    windowUpdateFrameWriteComplete(future, AbstractHttp2StreamChannel.this);
                } else {
                    future.addListener(windowUpdateFrameWriteListener);
                }
            }
        }

        /**
         * Fires channelReadComplete on the pipeline and resets the read state.
         * {@code forceReadComplete} fires the event even when none is pending.
         */
        void notifyReadComplete(RecvByteBufAllocator.Handle allocHandle, boolean forceReadComplete) {
            if (!readCompletePending && !forceReadComplete) {
                return;
            }
            // Set to false just in case we added the channel multiple times before.
            readCompletePending = false;

            if (readStatus == ReadStatus.REQUESTED) {
                // Another read was requested while this one ran; keep reading.
                readStatus = ReadStatus.IN_PROGRESS;
            } else {
                readStatus = ReadStatus.IDLE;
            }

            allocHandle.readComplete();
            pipeline().fireChannelReadComplete();
            // Reading data may result in frames being written (e.g. WINDOW_UPDATE, RST, etc..). If the parent
            // channel is not currently reading we need to force a flush at the child channel, because we cannot
            // rely upon flush occurring in channelReadComplete on the parent channel.
            flush();
            if (readEOS) {
                unsafe.closeForcibly();
            }
        }

        /**
         * Delivers a single frame to the pipeline, updating the flow-control byte accounting
         * and the allocator handle statistics first.
         */
        @SuppressWarnings("deprecation")
        void doRead0(Http2Frame frame, RecvByteBufAllocator.Handle allocHandle) {
            final int bytes;
            if (frame instanceof Http2DataFrame) {
                bytes = ((Http2DataFrame) frame).initialFlowControlledBytes();

                // It is important that we increment the flowControlledBytes before we call fireChannelRead(...)
                // as it may cause a read() that will call updateLocalWindowIfNeeded() and we need to ensure
                // in this case that we accounted for it.
                //
                // See https://github.com/netty/netty/issues/9663
                flowControlledBytes += bytes;
            } else {
                bytes = MIN_HTTP2_FRAME_SIZE;
            }
            // Update before firing event through the pipeline to be consistent with other Channel implementation.
            allocHandle.attemptedBytesRead(bytes);
            allocHandle.lastBytesRead(bytes);
            allocHandle.incMessagesRead(1);

            pipeline().fireChannelRead(frame);
        }

        @Override
        public void write(Object msg, final ChannelPromise promise) {
            // After this point its not possible to cancel a write anymore.
            if (!promise.setUncancellable()) {
                ReferenceCountUtil.release(msg);
                return;
            }

            if (!isActive() ||
                    // Once the outbound side was closed we should not allow header / data frames
                    outboundClosed && (msg instanceof Http2HeadersFrame || msg instanceof Http2DataFrame)) {
                ReferenceCountUtil.release(msg);
                promise.setFailure(new ClosedChannelException());
                return;
            }

            try {
                if (msg instanceof Http2StreamFrame) {
                    Http2StreamFrame frame = validateStreamFrame((Http2StreamFrame) msg).stream(stream());
                    writeHttp2StreamFrame(frame, promise);
                } else {
                    String msgStr = msg.toString();
                    ReferenceCountUtil.release(msg);
                    promise.setFailure(new IllegalArgumentException(
                            "Message must be an " + StringUtil.simpleClassName(Http2StreamFrame.class) +
                                    ": " + msgStr));
                }
            } catch (Throwable t) {
                promise.tryFailure(t);
            }
        }

        private void writeHttp2StreamFrame(Http2StreamFrame frame, final ChannelPromise promise) {
            // The very first frame on a not-yet-created stream must be a headers frame.
            if (!firstFrameWritten && !isStreamIdValid(stream().id()) && !(frame instanceof Http2HeadersFrame)) {
                ReferenceCountUtil.release(frame);
                promise.setFailure(
                        new IllegalArgumentException("The first frame must be a headers frame. Was: " +
                                frame.name()));
                return;
            }

            final boolean firstWrite;
            if (firstFrameWritten) {
                firstWrite = false;
            } else {
                firstWrite = firstFrameWritten = true;
            }

            ChannelFuture f = write0(parentContext(), frame);
            if (f.isDone()) {
                if (firstWrite) {
                    firstWriteComplete(f, promise);
                } else {
                    writeComplete(f, promise);
                }
            } else {
                // Account for pending outbound bytes until the write completes.
                final long bytes = FlowControlledFrameSizeEstimator.HANDLE_INSTANCE.size(frame);
                incrementPendingOutboundBytes(bytes, false);
                f.addListener(new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) {
                        if (firstWrite) {
                            firstWriteComplete(future, promise);
                        } else {
                            writeComplete(future, promise);
                        }
                        decrementPendingOutboundBytes(bytes, false);
                    }
                });
                writeDoneAndNoFlush = true;
            }
        }

        private void firstWriteComplete(ChannelFuture future, ChannelPromise promise) {
            Throwable cause = future.cause();
            if (cause == null) {
                promise.setSuccess();
            } else {
                // If the first write fails there is not much we can do, just close
                closeForcibly();
                promise.setFailure(wrapStreamClosedError(cause));
            }
        }

        private void writeComplete(ChannelFuture future, ChannelPromise promise) {
            Throwable cause = future.cause();
            if (cause == null) {
                promise.setSuccess();
            } else {
                Throwable error = wrapStreamClosedError(cause);
                // To make it more consistent with AbstractChannel we handle all IOExceptions here.
                if (error instanceof IOException) {
                    if (config.isAutoClose()) {
                        // Close channel if needed.
                        closeForcibly();
                    } else {
                        // TODO: Once Http2StreamChannel extends DuplexChannel we should call shutdownOutput(...)
                        outboundClosed = true;
                    }
                }
                promise.setFailure(error);
            }
        }

        private Throwable wrapStreamClosedError(Throwable cause) {
            // If the error was caused by STREAM_CLOSED we should use a ClosedChannelException to better
            // mimic other transports and make it easier to reason about what exceptions to expect.
            if (cause instanceof Http2Exception && ((Http2Exception) cause).error() == Http2Error.STREAM_CLOSED) {
                return new ClosedChannelException().initCause(cause);
            }
            return cause;
        }

        private Http2StreamFrame validateStreamFrame(Http2StreamFrame frame) {
            // A frame may only carry this channel's stream (or none, in which case it is assigned).
            if (frame.stream() != null && frame.stream() != stream) {
                String msgString = frame.toString();
                ReferenceCountUtil.release(frame);
                throw new IllegalArgumentException(
                        "Stream " + frame.stream() + " must not be set on the frame: " + msgString);
            }
            return frame;
        }

        @Override
        public void flush() {
            // If we are currently in the parent channel's read loop we should just ignore the flush.
            // We will ensure we trigger ctx.flush() after we processed all Channels later on and
            // so aggregate the flushes. This is done as ctx.flush() is expensive when as it may trigger an
            // write(...) or writev(...) operation on the socket.
            if (!writeDoneAndNoFlush || isParentReadInProgress()) {
                // There is nothing to flush so this is a NOOP.
                return;
            }
            // We need to set this to false before we call flush0(...) as ChannelFutureListener may produce more data
            // that are explicit flushed.
            writeDoneAndNoFlush = false;
            flush0(parentContext());
        }

        @Override
        public ChannelPromise voidPromise() {
            return unsafeVoidPromise;
        }

        @Override
        public ChannelOutboundBuffer outboundBuffer() {
            // Always return null as we not use the ChannelOutboundBuffer and not even support it.
            return null;
        }
    }

    /**
     * {@link ChannelConfig} so that the high and low writebuffer watermarks can reflect the outbound flow control
     * window, without having to create a new {@link WriteBufferWaterMark} object whenever the flow control window
     * changes.
+ */ + private static final class Http2StreamChannelConfig extends DefaultChannelConfig { + Http2StreamChannelConfig(Channel channel) { + super(channel); + } + + @Override + public MessageSizeEstimator getMessageSizeEstimator() { + return FlowControlledFrameSizeEstimator.INSTANCE; + } + + @Override + public ChannelConfig setMessageSizeEstimator(MessageSizeEstimator estimator) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelConfig setRecvByteBufAllocator(RecvByteBufAllocator allocator) { + if (!(allocator.newHandle() instanceof RecvByteBufAllocator.ExtendedHandle)) { + throw new IllegalArgumentException("allocator.newHandle() must return an object of type: " + + RecvByteBufAllocator.ExtendedHandle.class); + } + super.setRecvByteBufAllocator(allocator); + return this; + } + } + + private void maybeAddChannelToReadCompletePendingQueue() { + if (!readCompletePending) { + readCompletePending = true; + addChannelToReadCompletePendingQueue(); + } + } + + protected void flush0(ChannelHandlerContext ctx) { + ctx.flush(); + } + + protected ChannelFuture write0(ChannelHandlerContext ctx, Object msg) { + ChannelPromise promise = ctx.newPromise(); + ctx.write(msg, promise); + return promise; + } + + protected abstract boolean isParentReadInProgress(); + protected abstract void addChannelToReadCompletePendingQueue(); + protected abstract ChannelHandlerContext parentContext(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamFrame.java new file mode 100644 index 0000000..1d5c1d0 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractHttp2StreamFrame.java @@ -0,0 +1,59 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except 
in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Abstract implementation of {@link Http2StreamFrame}. + */ +@UnstableApi +public abstract class AbstractHttp2StreamFrame implements Http2StreamFrame { + + private Http2FrameStream stream; + + @Override + public AbstractHttp2StreamFrame stream(Http2FrameStream stream) { + this.stream = stream; + return this; + } + + @Override + public Http2FrameStream stream() { + return stream; + } + + /** + * Returns {@code true} if {@code o} has equal {@code stream} to this object. 
+ */ + @Override + public boolean equals(Object o) { + if (!(o instanceof Http2StreamFrame)) { + return false; + } + Http2StreamFrame other = (Http2StreamFrame) o; + return stream == other.stream() || stream != null && stream.equals(other.stream()); + } + + @Override + public int hashCode() { + Http2FrameStream stream = this.stream; + if (stream == null) { + return super.hashCode(); + } + return stream.hashCode(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractInboundHttp2ToHttpAdapterBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractInboundHttp2ToHttpAdapterBuilder.java new file mode 100644 index 0000000..ecf29b3 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/AbstractInboundHttp2ToHttpAdapterBuilder.java @@ -0,0 +1,136 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A skeletal builder implementation of {@link InboundHttp2ToHttpAdapter} and its subtypes. 
+ */ +@UnstableApi +public abstract class AbstractInboundHttp2ToHttpAdapterBuilder< + T extends InboundHttp2ToHttpAdapter, B extends AbstractInboundHttp2ToHttpAdapterBuilder> { + + private final Http2Connection connection; + private int maxContentLength; + private boolean validateHttpHeaders; + private boolean propagateSettings; + + /** + * Creates a new {@link InboundHttp2ToHttpAdapter} builder for the specified {@link Http2Connection}. + * + * @param connection the object which will provide connection notification events + * for the current connection + */ + protected AbstractInboundHttp2ToHttpAdapterBuilder(Http2Connection connection) { + this.connection = checkNotNull(connection, "connection"); + } + + @SuppressWarnings("unchecked") + protected final B self() { + return (B) this; + } + + /** + * Returns the {@link Http2Connection}. + */ + protected Http2Connection connection() { + return connection; + } + + /** + * Returns the maximum length of the message content. + */ + protected int maxContentLength() { + return maxContentLength; + } + + /** + * Specifies the maximum length of the message content. + * + * @param maxContentLength the maximum length of the message content. If the length of the message content + * exceeds this value, a {@link TooLongFrameException} will be raised + * @return {@link AbstractInboundHttp2ToHttpAdapterBuilder} the builder for the {@link InboundHttp2ToHttpAdapter} + */ + protected B maxContentLength(int maxContentLength) { + this.maxContentLength = maxContentLength; + return self(); + } + + /** + * Return {@code true} if HTTP header validation should be performed. + */ + protected boolean isValidateHttpHeaders() { + return validateHttpHeaders; + } + + /** + * Specifies whether validation of HTTP headers should be performed. + * + * @param validate + *
    + *
  • {@code true} to validate HTTP headers in the http-codec
  • + *
  • {@code false} not to validate HTTP headers in the http-codec
  • + *
+ * @return {@link AbstractInboundHttp2ToHttpAdapterBuilder} the builder for the {@link InboundHttp2ToHttpAdapter} + */ + protected B validateHttpHeaders(boolean validate) { + validateHttpHeaders = validate; + return self(); + } + + /** + * Returns {@code true} if a read settings frame should be propagated along the channel pipeline. + */ + protected boolean isPropagateSettings() { + return propagateSettings; + } + + /** + * Specifies whether a read settings frame should be propagated along the channel pipeline. + * + * @param propagate if {@code true} read settings will be passed along the pipeline. This can be useful + * to clients that need hold off sending data until they have received the settings. + * @return {@link AbstractInboundHttp2ToHttpAdapterBuilder} the builder for the {@link InboundHttp2ToHttpAdapter} + */ + protected B propagateSettings(boolean propagate) { + propagateSettings = propagate; + return self(); + } + + /** + * Builds/creates a new {@link InboundHttp2ToHttpAdapter} instance using this builder's current settings. + */ + protected T build() { + final T instance; + try { + instance = build(connection(), maxContentLength(), + isValidateHttpHeaders(), isPropagateSettings()); + } catch (Throwable t) { + throw new IllegalStateException("failed to create a new InboundHttp2ToHttpAdapter", t); + } + connection.addListener(instance); + return instance; + } + + /** + * Creates a new {@link InboundHttp2ToHttpAdapter} with the specified properties. 
+ */ + protected abstract T build(Http2Connection connection, int maxContentLength, + boolean validateHttpHeaders, boolean propagateSettings) throws Exception; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CharSequenceMap.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CharSequenceMap.java new file mode 100644 index 0000000..b9fbc44 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CharSequenceMap.java @@ -0,0 +1,48 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.DefaultHeaders; +import io.netty.handler.codec.UnsupportedValueConverter; +import io.netty.handler.codec.ValueConverter; +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; +import static io.netty.util.AsciiString.CASE_SENSITIVE_HASHER; + +/** + * Internal use only! + */ +@UnstableApi +public final class CharSequenceMap extends DefaultHeaders> { + public CharSequenceMap() { + this(true); + } + + public CharSequenceMap(boolean caseSensitive) { + this(caseSensitive, UnsupportedValueConverter.instance()); + } + + public CharSequenceMap(boolean caseSensitive, ValueConverter valueConverter) { + super(caseSensitive ? 
CASE_SENSITIVE_HASHER : CASE_INSENSITIVE_HASHER, valueConverter); + } + + @SuppressWarnings("unchecked") + public CharSequenceMap(boolean caseSensitive, ValueConverter valueConverter, int arraySizeHint) { + super(caseSensitive ? CASE_SENSITIVE_HASHER : CASE_INSENSITIVE_HASHER, valueConverter, + NameValidator.NOT_NULL, arraySizeHint); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java new file mode 100644 index 0000000..95ed241 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandler.java @@ -0,0 +1,107 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.util.internal.UnstableApi; + +import java.util.List; + +import static io.netty.buffer.Unpooled.unreleasableBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Performing cleartext upgrade, by h2c HTTP upgrade or Prior Knowledge. + * This handler config pipeline for h2c upgrade when handler added. + * And will update pipeline once it detect the connection is starting HTTP/2 by + * prior knowledge or not. + */ +@UnstableApi +public final class CleartextHttp2ServerUpgradeHandler extends ByteToMessageDecoder { + private static final ByteBuf CONNECTION_PREFACE = unreleasableBuffer(connectionPrefaceBuf()).asReadOnly(); + + private final HttpServerCodec httpServerCodec; + private final HttpServerUpgradeHandler httpServerUpgradeHandler; + private final ChannelHandler http2ServerHandler; + + /** + * Creates the channel handler provide cleartext HTTP/2 upgrade from HTTP + * upgrade or prior knowledge + * + * @param httpServerCodec the http server codec + * @param httpServerUpgradeHandler the http server upgrade handler for HTTP/2 + * @param http2ServerHandler the http2 server handler, will be added into pipeline + * when starting HTTP/2 by prior knowledge + */ + public CleartextHttp2ServerUpgradeHandler(HttpServerCodec httpServerCodec, + HttpServerUpgradeHandler httpServerUpgradeHandler, + ChannelHandler http2ServerHandler) { + this.httpServerCodec = checkNotNull(httpServerCodec, "httpServerCodec"); + this.httpServerUpgradeHandler = checkNotNull(httpServerUpgradeHandler, 
"httpServerUpgradeHandler"); + this.http2ServerHandler = checkNotNull(http2ServerHandler, "http2ServerHandler"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + ctx.pipeline() + .addAfter(ctx.name(), null, httpServerUpgradeHandler) + .addAfter(ctx.name(), null, httpServerCodec); + } + + /** + * Peek inbound message to determine current connection wants to start HTTP/2 + * by HTTP upgrade or prior knowledge + */ + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + int prefaceLength = CONNECTION_PREFACE.readableBytes(); + int bytesRead = Math.min(in.readableBytes(), prefaceLength); + + if (!ByteBufUtil.equals(CONNECTION_PREFACE, CONNECTION_PREFACE.readerIndex(), + in, in.readerIndex(), bytesRead)) { + ctx.pipeline().remove(this); + } else if (bytesRead == prefaceLength) { + // Full h2 preface match, removed source codec, using http2 codec to handle + // following network traffic + ctx.pipeline() + .remove(httpServerCodec) + .remove(httpServerUpgradeHandler); + + ctx.pipeline().addAfter(ctx.name(), null, http2ServerHandler); + ctx.pipeline().remove(this); + + ctx.fireUserEventTriggered(PriorKnowledgeUpgradeEvent.INSTANCE); + } + } + + /** + * User event that is fired to notify about HTTP/2 protocol is started. 
+ */ + public static final class PriorKnowledgeUpgradeEvent { + private static final PriorKnowledgeUpgradeEvent INSTANCE = new PriorKnowledgeUpgradeEvent(); + + private PriorKnowledgeUpgradeEvent() { + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java new file mode 100644 index 0000000..262c8ad --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/CompressorHttp2ConnectionEncoder.java @@ -0,0 +1,423 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.compression.BrotliEncoder; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.BrotliOptions; +import io.netty.handler.codec.compression.CompressionOptions; +import io.netty.handler.codec.compression.DeflateOptions; +import io.netty.handler.codec.compression.GzipOptions; +import io.netty.handler.codec.compression.StandardCompressionOptions; +import io.netty.handler.codec.compression.ZstdEncoder; +import io.netty.handler.codec.compression.ZstdOptions; +import io.netty.handler.codec.compression.SnappyFrameEncoder; +import io.netty.handler.codec.compression.SnappyOptions; +import io.netty.util.concurrent.PromiseCombiner; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_ENCODING; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderValues.BR; +import static io.netty.handler.codec.http.HttpHeaderValues.DEFLATE; +import static io.netty.handler.codec.http.HttpHeaderValues.GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.IDENTITY; +import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE; +import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.ZSTD; +import static io.netty.handler.codec.http.HttpHeaderValues.SNAPPY; + +/** + * A decorating HTTP2 encoder that will 
compress data frames according to the {@code content-encoding} header for each
 * stream. The compression provided by this class will be applied to the data for the entire stream.
 */
@UnstableApi
public class CompressorHttp2ConnectionEncoder extends DecoratingHttp2ConnectionEncoder {
    // We cannot remove this because it'll be breaking change
    public static final int DEFAULT_COMPRESSION_LEVEL = 6;
    public static final int DEFAULT_WINDOW_BITS = 15;
    public static final int DEFAULT_MEM_LEVEL = 8;

    // Legacy zlib parameters; only assigned by the deprecated constructor
    // (i.e. when supportsCompressionOptions == false).
    private int compressionLevel;
    private int windowBits;
    private int memLevel;
    // Key under which the per-stream compressor EmbeddedChannel is stored on each Http2Stream.
    private final Http2Connection.PropertyKey propertyKey;

    // true when this encoder was configured via CompressionOptions, false for the deprecated
    // zlib-parameter constructor. Selects which branch newCompressionChannel() takes.
    private final boolean supportsCompressionOptions;

    // Per-algorithm options; any of these may be null when the corresponding
    // CompressionOptions was not supplied.
    private BrotliOptions brotliOptions;
    private GzipOptions gzipCompressionOptions;
    private DeflateOptions deflateOptions;
    private ZstdOptions zstdOptions;
    private SnappyOptions snappyOptions;

    /**
     * Create a new {@link CompressorHttp2ConnectionEncoder} instance
     * with default implementation of {@link StandardCompressionOptions}
     */
    public CompressorHttp2ConnectionEncoder(Http2ConnectionEncoder delegate) {
        this(delegate, defaultCompressionOptions());
    }

    // Brotli is offered only when Brotli.isAvailable() reports its classes are on the classpath;
    // snappy/gzip/deflate are always offered.
    private static CompressionOptions[] defaultCompressionOptions() {
        if (Brotli.isAvailable()) {
            return new CompressionOptions[] {
                    StandardCompressionOptions.brotli(),
                    StandardCompressionOptions.snappy(),
                    StandardCompressionOptions.gzip(),
                    StandardCompressionOptions.deflate() };
        }
        return new CompressionOptions[] { StandardCompressionOptions.snappy(),
                StandardCompressionOptions.gzip(), StandardCompressionOptions.deflate() };
    }

    /**
     * Create a new {@link CompressorHttp2ConnectionEncoder} instance
     */
    @Deprecated
    public CompressorHttp2ConnectionEncoder(Http2ConnectionEncoder delegate, int compressionLevel, int windowBits,
                                            int memLevel) {
        super(delegate);
        this.compressionLevel = ObjectUtil.checkInRange(compressionLevel, 0, 9, "compressionLevel");
        this.windowBits = ObjectUtil.checkInRange(windowBits, 9, 15, "windowBits");
        this.memLevel = ObjectUtil.checkInRange(memLevel, 1, 9, "memLevel");

        propertyKey = connection().newKey();
        // Release any per-stream compressor when its stream goes away so buffered
        // compressed data cannot leak.
        connection().addListener(new Http2ConnectionAdapter() {
            @Override
            public void onStreamRemoved(Http2Stream stream) {
                final EmbeddedChannel compressor = stream.getProperty(propertyKey);
                if (compressor != null) {
                    cleanup(stream, compressor);
                }
            }
        });

        supportsCompressionOptions = false;
    }

    /**
     * Create a new {@link CompressorHttp2ConnectionEncoder} with
     * specified {@link StandardCompressionOptions}
     */
    public CompressorHttp2ConnectionEncoder(Http2ConnectionEncoder delegate,
                                            CompressionOptions... compressionOptionsArgs) {
        super(delegate);
        ObjectUtil.checkNotNull(compressionOptionsArgs, "CompressionOptions");
        ObjectUtil.deepCheckNotNull("CompressionOptions", compressionOptionsArgs);

        for (CompressionOptions compressionOptions : compressionOptionsArgs) {
            // BrotliOptions' class initialization depends on Brotli classes being on the classpath.
            // The Brotli.isAvailable check ensures that BrotliOptions will only get instantiated if Brotli is on
            // the classpath.
            // This results in the static analysis of native-image identifying the instanceof BrotliOptions check
            // and thus BrotliOptions itself as unreachable, enabling native-image to link all classes at build time
            // and not complain about the missing Brotli classes.
            if (Brotli.isAvailable() && compressionOptions instanceof BrotliOptions) {
                brotliOptions = (BrotliOptions) compressionOptions;
            } else if (compressionOptions instanceof GzipOptions) {
                gzipCompressionOptions = (GzipOptions) compressionOptions;
            } else if (compressionOptions instanceof DeflateOptions) {
                deflateOptions = (DeflateOptions) compressionOptions;
            } else if (compressionOptions instanceof ZstdOptions) {
                zstdOptions = (ZstdOptions) compressionOptions;
            } else if (compressionOptions instanceof SnappyOptions) {
                snappyOptions = (SnappyOptions) compressionOptions;
            } else {
                throw new IllegalArgumentException("Unsupported " + CompressionOptions.class.getSimpleName() +
                        ": " + compressionOptions);
            }
        }

        supportsCompressionOptions = true;

        propertyKey = connection().newKey();
        // Same stream-removal cleanup hook as the deprecated constructor.
        connection().addListener(new Http2ConnectionAdapter() {
            @Override
            public void onStreamRemoved(Http2Stream stream) {
                final EmbeddedChannel compressor = stream.getProperty(propertyKey);
                if (compressor != null) {
                    cleanup(stream, compressor);
                }
            }
        });
    }

    /**
     * Writes {@code data} for {@code streamId}, compressing it through the stream's
     * {@link EmbeddedChannel} compressor when one was bound by {@link #newCompressor}.
     * Uncompressed pass-through when no compressor is attached.
     */
    @Override
    public ChannelFuture writeData(final ChannelHandlerContext ctx, final int streamId, ByteBuf data, int padding,
                                   final boolean endOfStream, ChannelPromise promise) {
        final Http2Stream stream = connection().stream(streamId);
        final EmbeddedChannel channel = stream == null ? null : (EmbeddedChannel) stream.getProperty(propertyKey);
        if (channel == null) {
            // The compressor may be null if no compatible encoding type was found in this stream's headers
            return super.writeData(ctx, streamId, data, padding, endOfStream, promise);
        }

        try {
            // The channel will release the buffer after being written
            channel.writeOutbound(data);
            ByteBuf buf = nextReadableBuf(channel);
            if (buf == null) {
                // Nothing was emitted yet (the compressor may be buffering input).
                if (endOfStream) {
                    // finish() flushes any remaining compressed bytes before END_STREAM.
                    if (channel.finish()) {
                        buf = nextReadableBuf(channel);
                    }
                    return super.writeData(ctx, streamId, buf == null ? Unpooled.EMPTY_BUFFER : buf, padding,
                            true, promise);
                }
                // END_STREAM is not set and the assumption is data is still forthcoming.
                promise.setSuccess();
                return promise;
            }

            // One compressed input frame may produce several output buffers; aggregate the
            // write results of all emitted DATA frames into the caller's promise.
            PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
            for (;;) {
                ByteBuf nextBuf = nextReadableBuf(channel);
                boolean compressedEndOfStream = nextBuf == null && endOfStream;
                if (compressedEndOfStream && channel.finish()) {
                    nextBuf = nextReadableBuf(channel);
                    compressedEndOfStream = nextBuf == null;
                }

                ChannelPromise bufPromise = ctx.newPromise();
                combiner.add(bufPromise);
                super.writeData(ctx, streamId, buf, padding, compressedEndOfStream, bufPromise);
                if (nextBuf == null) {
                    break;
                }

                padding = 0; // Padding is only communicated once on the first iteration
                buf = nextBuf;
            }
            combiner.finish(promise);
        } catch (Throwable cause) {
            promise.tryFailure(cause);
        } finally {
            if (endOfStream) {
                cleanup(stream, channel);
            }
        }
        return promise;
    }

    @Override
    public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
                                      boolean endStream, ChannelPromise promise) {
        try {
            // Determine if compression is required and sanitize the headers.
            EmbeddedChannel compressor = newCompressor(ctx, headers, endStream);

            // Write the headers and create the stream object.
            ChannelFuture future = super.writeHeaders(ctx, streamId, headers, padding, endStream, promise);

            // After the stream object has been created, then attach the compressor as a property for data compression.
            bindCompressorToStream(compressor, streamId);

            return future;
        } catch (Throwable e) {
            promise.tryFailure(e);
        }
        return promise;
    }

    @Override
    public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId, final Http2Headers headers,
                                      final int streamDependency, final short weight, final boolean exclusive, final int padding,
                                      final boolean endOfStream, final ChannelPromise promise) {
        try {
            // Determine if compression is required and sanitize the headers.
            EmbeddedChannel compressor = newCompressor(ctx, headers, endOfStream);

            // Write the headers and create the stream object.
            ChannelFuture future = super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive,
                    padding, endOfStream, promise);

            // After the stream object has been created, then attach the compressor as a property for data compression.
            bindCompressorToStream(compressor, streamId);

            return future;
        } catch (Throwable e) {
            promise.tryFailure(e);
        }
        return promise;
    }

    /**
     * Returns a new {@link EmbeddedChannel} that encodes the HTTP2 message content encoded in the specified
     * {@code contentEncoding}.
     *
     * @param ctx the context.
     * @param contentEncoding the value of the {@code content-encoding} header
     * @return a new {@link ByteToMessageDecoder} if the specified encoding is supported. {@code null} otherwise
     * (alternatively, you can throw a {@link Http2Exception} to block unknown encoding).
     * @throws Http2Exception If the specified encoding is not supported and warrants an exception
     */
    protected EmbeddedChannel newContentCompressor(ChannelHandlerContext ctx, CharSequence contentEncoding)
            throws Http2Exception {
        if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
            return newCompressionChannel(ctx, ZlibWrapper.GZIP);
        }
        if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
            return newCompressionChannel(ctx, ZlibWrapper.ZLIB);
        }
        if (Brotli.isAvailable() && brotliOptions != null && BR.contentEqualsIgnoreCase(contentEncoding)) {
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), new BrotliEncoder(brotliOptions.parameters()));
        }
        if (zstdOptions != null && ZSTD.contentEqualsIgnoreCase(contentEncoding)) {
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), new ZstdEncoder(zstdOptions.compressionLevel(),
                    zstdOptions.blockSize(), zstdOptions.maxEncodeSize()));
        }
        if (snappyOptions != null && SNAPPY.contentEqualsIgnoreCase(contentEncoding)) {
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), new SnappyFrameEncoder());
        }
        // 'identity' or unsupported
        return null;
    }

    /**
     * Returns the expected content encoding of the decoded content. Returning {@code contentEncoding} is the default
     * behavior, which is the case for most compressors.
     *
     * @param contentEncoding the value of the {@code content-encoding} header
     * @return the expected content encoding of the new content.
     * @throws Http2Exception if the {@code contentEncoding} is not supported and warrants an exception
     */
    protected CharSequence getTargetContentEncoding(CharSequence contentEncoding) throws Http2Exception {
        return contentEncoding;
    }

    /**
     * Generate a new instance of an {@link EmbeddedChannel} capable of compressing data
     * @param ctx the context.
     * @param wrapper Defines what type of encoder should be used
     */
    private EmbeddedChannel newCompressionChannel(final ChannelHandlerContext ctx, ZlibWrapper wrapper) {
        if (supportsCompressionOptions) {
            if (wrapper == ZlibWrapper.GZIP && gzipCompressionOptions != null) {
                return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                        ctx.channel().config(), ZlibCodecFactory.newZlibEncoder(wrapper,
                        gzipCompressionOptions.compressionLevel(), gzipCompressionOptions.windowBits(),
                        gzipCompressionOptions.memLevel()));
            } else if (wrapper == ZlibWrapper.ZLIB && deflateOptions != null) {
                return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                        ctx.channel().config(), ZlibCodecFactory.newZlibEncoder(wrapper,
                        deflateOptions.compressionLevel(), deflateOptions.windowBits(),
                        deflateOptions.memLevel()));
            } else {
                throw new IllegalArgumentException("Unsupported ZlibWrapper: " + wrapper);
            }
        } else {
            // Deprecated-constructor path: use the stored legacy zlib parameters.
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), ZlibCodecFactory.newZlibEncoder(wrapper, compressionLevel, windowBits,
                    memLevel));
        }
    }

    /**
     * Checks if a new compressor object is needed for the stream identified by {@code streamId}. This method will
     * modify the {@code content-encoding} header contained in {@code headers}.
     *
     * @param ctx the context.
     * @param headers Object representing headers which are to be written
     * @param endOfStream Indicates if the stream has ended
     * @return The channel used to compress data.
     * @throws Http2Exception if any problems occur during initialization.
     */
    private EmbeddedChannel newCompressor(ChannelHandlerContext ctx, Http2Headers headers, boolean endOfStream)
            throws Http2Exception {
        if (endOfStream) {
            // No data will follow these headers, so nothing to compress.
            return null;
        }

        CharSequence encoding = headers.get(CONTENT_ENCODING);
        if (encoding == null) {
            encoding = IDENTITY;
        }
        final EmbeddedChannel compressor = newContentCompressor(ctx, encoding);
        if (compressor != null) {
            CharSequence targetContentEncoding = getTargetContentEncoding(encoding);
            if (IDENTITY.contentEqualsIgnoreCase(targetContentEncoding)) {
                headers.remove(CONTENT_ENCODING);
            } else {
                headers.set(CONTENT_ENCODING, targetContentEncoding);
            }

            // The content length will be for the decompressed data. Since we will compress the data
            // this content-length will not be correct. Instead of queuing messages or delaying sending
            // header frames...just remove the content-length header
            headers.remove(CONTENT_LENGTH);
        }

        return compressor;
    }

    /**
     * Called after the super class has written the headers and created any associated stream objects.
     * @param compressor The compressor associated with the stream identified by {@code streamId}.
     * @param streamId The stream id for which the headers were written.
     */
    private void bindCompressorToStream(EmbeddedChannel compressor, int streamId) {
        if (compressor != null) {
            Http2Stream stream = connection().stream(streamId);
            if (stream != null) {
                stream.setProperty(propertyKey, compressor);
            }
        }
    }

    /**
     * Release remaining content from {@link EmbeddedChannel} and remove the compressor from the {@link Http2Stream}.
     *
     * @param stream The stream for which {@code compressor} is the compressor for
     * @param compressor The compressor for {@code stream}
     */
    void cleanup(Http2Stream stream, EmbeddedChannel compressor) {
        compressor.finishAndReleaseAll();
        stream.removeProperty(propertyKey);
    }

    /**
     * Read the next compressed {@link ByteBuf} from the {@link EmbeddedChannel} or {@code null} if one does not exist.
     *
     * @param compressor The channel to read from
     * @return The next decoded {@link ByteBuf} from the {@link EmbeddedChannel} or {@code null} if one does not exist
     */
    private static ByteBuf nextReadableBuf(EmbeddedChannel compressor) {
        for (;;) {
            final ByteBuf buf = compressor.readOutbound();
            if (buf == null) {
                return null;
            }
            // Skip (and release) zero-length buffers the compressor may emit.
            if (!buf.isReadable()) {
                buf.release();
                continue;
            }
            return buf;
        }
    }
}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionDecoder.java
new file mode 100644
index 0000000..148fc80
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionDecoder.java
@@ -0,0 +1,80 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */ +package io.netty.handler.codec.http2; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +import java.util.List; + +/** + * Decorator around another {@link Http2ConnectionDecoder} instance. + */ +@UnstableApi +public class DecoratingHttp2ConnectionDecoder implements Http2ConnectionDecoder { + private final Http2ConnectionDecoder delegate; + + public DecoratingHttp2ConnectionDecoder(Http2ConnectionDecoder delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + delegate.lifecycleManager(lifecycleManager); + } + + @Override + public Http2Connection connection() { + return delegate.connection(); + } + + @Override + public Http2LocalFlowController flowController() { + return delegate.flowController(); + } + + @Override + public void frameListener(Http2FrameListener listener) { + delegate.frameListener(listener); + } + + @Override + public Http2FrameListener frameListener() { + return delegate.frameListener(); + } + + @Override + public void decodeFrame(ChannelHandlerContext ctx, ByteBuf in, List out) throws Http2Exception { + delegate.decodeFrame(ctx, in, out); + } + + @Override + public Http2Settings localSettings() { + return delegate.localSettings(); + } + + @Override + public boolean prefaceReceived() { + return delegate.prefaceReceived(); + } + + @Override + public void close() { + delegate.close(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoder.java new file mode 100644 index 0000000..f47c05f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoder.java @@ -0,0 +1,73 @@ 
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.codec.http2;

import io.netty.util.internal.UnstableApi;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * A decorator around another {@link Http2ConnectionEncoder} instance.
 * All encoder-level methods forward to the wrapped {@code delegate}; frame-writer
 * methods are inherited from {@link DecoratingHttp2FrameWriter}.
 */
@UnstableApi
public class DecoratingHttp2ConnectionEncoder extends DecoratingHttp2FrameWriter implements Http2ConnectionEncoder,
        Http2SettingsReceivedConsumer {
    private final Http2ConnectionEncoder delegate;

    public DecoratingHttp2ConnectionEncoder(Http2ConnectionEncoder delegate) {
        // The delegate serves both as the frame writer (super) and the encoder target.
        super(delegate);
        this.delegate = checkNotNull(delegate, "delegate");
    }

    @Override
    public void lifecycleManager(Http2LifecycleManager lifecycleManager) {
        delegate.lifecycleManager(lifecycleManager);
    }

    @Override
    public Http2Connection connection() {
        return delegate.connection();
    }

    @Override
    public Http2RemoteFlowController flowController() {
        return delegate.flowController();
    }

    @Override
    public Http2FrameWriter frameWriter() {
        return delegate.frameWriter();
    }

    @Override
    public Http2Settings pollSentSettings() {
        return delegate.pollSentSettings();
    }

    @Override
    public void remoteSettings(Http2Settings settings) throws Http2Exception {
        delegate.remoteSettings(settings);
    }

    // Only forwards when the delegate itself implements Http2SettingsReceivedConsumer;
    // otherwise the decoration cannot honor the contract and fails loudly.
    @Override
    public void consumeReceivedSettings(Http2Settings settings) {
        if (delegate instanceof Http2SettingsReceivedConsumer) {
            ((Http2SettingsReceivedConsumer) delegate).consumeReceivedSettings(settings);
        } else {
            throw new IllegalStateException("delegate " + delegate + " is not an instance of " +
                    Http2SettingsReceivedConsumer.class);
        }
    }
}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2FrameWriter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2FrameWriter.java
new file mode 100644
index 0000000..034ffaf
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DecoratingHttp2FrameWriter.java
@@ -0,0 +1,116 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.channel.ChannelFuture;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelPromise;
+import io.netty.util.internal.UnstableApi;
+
+import static io.netty.util.internal.ObjectUtil.checkNotNull;
+
+/**
+ * Decorator around another {@link Http2FrameWriter} instance.
 */
@UnstableApi
public class DecoratingHttp2FrameWriter implements Http2FrameWriter {
    // All write*() methods forward unchanged to this wrapped writer.
    private final Http2FrameWriter delegate;

    public DecoratingHttp2FrameWriter(Http2FrameWriter delegate) {
        this.delegate = checkNotNull(delegate, "delegate");
    }

    @Override
    public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding,
                                   boolean endStream, ChannelPromise promise) {
        return delegate.writeData(ctx, streamId, data, padding, endStream, promise);
    }

    @Override
    public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
                                      boolean endStream, ChannelPromise promise) {
        return delegate.writeHeaders(ctx, streamId, headers, padding, endStream, promise);
    }

    @Override
    public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers,
                                      int streamDependency, short weight, boolean exclusive, int padding,
                                      boolean endStream, ChannelPromise promise) {
        return delegate
                .writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream, promise);
    }

    @Override
    public ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight,
                                       boolean exclusive, ChannelPromise promise) {
        return delegate.writePriority(ctx, streamId, streamDependency, weight, exclusive, promise);
    }

    @Override
    public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode,
                                        ChannelPromise promise) {
        return delegate.writeRstStream(ctx, streamId, errorCode, promise);
    }

    @Override
    public ChannelFuture writeSettings(ChannelHandlerContext ctx, Http2Settings settings, ChannelPromise promise) {
        return delegate.writeSettings(ctx, settings, promise);
    }

    @Override
    public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) {
        return delegate.writeSettingsAck(ctx, promise);
    }

    @Override
    public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, long data, ChannelPromise promise) {
        return delegate.writePing(ctx, ack, data, promise);
    }

    @Override
    public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int promisedStreamId,
                                          Http2Headers headers, int padding, ChannelPromise promise) {
        return delegate.writePushPromise(ctx, streamId, promisedStreamId, headers, padding, promise);
    }

    @Override
    public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData,
                                     ChannelPromise promise) {
        return delegate.writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);
    }

    @Override
    public ChannelFuture writeWindowUpdate(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement,
                                           ChannelPromise promise) {
        return delegate.writeWindowUpdate(ctx, streamId, windowSizeIncrement, promise);
    }

    @Override
    public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags,
                                    ByteBuf payload, ChannelPromise promise) {
        return delegate.writeFrame(ctx, frameType, streamId, flags, payload, promise);
    }

    @Override
    public Configuration configuration() {
        return delegate.configuration();
    }

    @Override
    public void close() {
        delegate.close();
    }
}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java
new file mode 100644
index 0000000..90b23ae
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Connection.java
@@ -0,0 +1,1080 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.netty.util.collection.IntObjectMap.PrimitiveEntry; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; + +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_RESERVED_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.REFUSED_STREAM; +import static io.netty.handler.codec.http2.Http2Exception.closedStreamError; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static 
io.netty.handler.codec.http2.Http2Stream.State.CLOSED; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_REMOTE; +import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; +import static io.netty.handler.codec.http2.Http2Stream.State.OPEN; +import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; +import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_REMOTE; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Integer.MAX_VALUE; + +/** + * Simple implementation of {@link Http2Connection}. + */ +@UnstableApi +public class DefaultHttp2Connection implements Http2Connection { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultHttp2Connection.class); + // Fields accessed by inner classes + final IntObjectMap streamMap = new IntObjectHashMap(); + final PropertyKeyRegistry propertyKeyRegistry = new PropertyKeyRegistry(); + final ConnectionStream connectionStream = new ConnectionStream(); + final DefaultEndpoint localEndpoint; + final DefaultEndpoint remoteEndpoint; + + /** + * We chose a {@link List} over a {@link Set} to avoid allocating an {@link Iterator} objects when iterating over + * the listeners. + *

+ * Initial size of 4 because the default configuration currently has 3 listeners + * (local/remote flow controller and {@link StreamByteDistributor}) and we leave room for 1 extra. + * We could be more aggressive but the ArrayList resize will double the size if we are too small. + */ + final List listeners = new ArrayList(4); + final ActiveStreams activeStreams; + Promise closePromise; + + /** + * Creates a new connection with the given settings. + * @param server whether or not this end-point is the server-side of the HTTP/2 connection. + */ + public DefaultHttp2Connection(boolean server) { + this(server, DEFAULT_MAX_RESERVED_STREAMS); + } + + /** + * Creates a new connection with the given settings. + * @param server whether or not this end-point is the server-side of the HTTP/2 connection. + * @param maxReservedStreams The maximum amount of streams which can exist in the reserved state for each endpoint. + */ + public DefaultHttp2Connection(boolean server, int maxReservedStreams) { + activeStreams = new ActiveStreams(listeners); + // Reserved streams are excluded from the SETTINGS_MAX_CONCURRENT_STREAMS limit according to [1] and the RFC + // doesn't define a way to communicate the limit on reserved streams. We rely upon the peer to send RST_STREAM + // in response to any locally enforced limits being exceeded [2]. + // [1] https://tools.ietf.org/html/rfc7540#section-5.1.2 + // [2] https://tools.ietf.org/html/rfc7540#section-8.2.2 + localEndpoint = new DefaultEndpoint(server, server ? MAX_VALUE : maxReservedStreams); + remoteEndpoint = new DefaultEndpoint(!server, maxReservedStreams); + + // Add the connection stream to the map. + streamMap.put(connectionStream.id(), connectionStream); + } + + /** + * Determine if {@link #close(Promise)} has been called and no more streams are allowed to be created. 
+ */ + final boolean isClosed() { + return closePromise != null; + } + + @Override + public Future close(final Promise promise) { + checkNotNull(promise, "promise"); + // Since we allow this method to be called multiple times, we must make sure that all the promises are notified + // when all streams are removed and the close operation completes. + if (closePromise != null) { + if (closePromise == promise) { + // Do nothing + } else if (promise instanceof ChannelPromise && ((ChannelFuture) closePromise).isVoid()) { + closePromise = promise; + } else { + PromiseNotifier.cascade(closePromise, promise); + } + } else { + closePromise = promise; + } + if (isStreamMapEmpty()) { + promise.trySuccess(null); + return promise; + } + + Iterator> itr = streamMap.entries().iterator(); + // We must take care while iterating the streamMap as to not modify while iterating in case there are other code + // paths iterating over the active streams. + if (activeStreams.allowModifications()) { + activeStreams.incrementPendingIterations(); + try { + while (itr.hasNext()) { + DefaultStream stream = (DefaultStream) itr.next().value(); + if (stream.id() != CONNECTION_STREAM_ID) { + // If modifications of the activeStream map is allowed, then a stream close operation will also + // modify the streamMap. Pass the iterator in so that remove will be called to prevent + // concurrent modification exceptions. + stream.close(itr); + } + } + } finally { + activeStreams.decrementPendingIterations(); + } + } else { + while (itr.hasNext()) { + Http2Stream stream = itr.next().value(); + if (stream.id() != CONNECTION_STREAM_ID) { + // We are not allowed to make modifications, so the close calls will be executed after this + // iteration completes. 
+ stream.close(); + } + } + } + return closePromise; + } + + @Override + public void addListener(Listener listener) { + listeners.add(listener); + } + + @Override + public void removeListener(Listener listener) { + listeners.remove(listener); + } + + @Override + public boolean isServer() { + return localEndpoint.isServer(); + } + + @Override + public Http2Stream connectionStream() { + return connectionStream; + } + + @Override + public Http2Stream stream(int streamId) { + return streamMap.get(streamId); + } + + @Override + public boolean streamMayHaveExisted(int streamId) { + return remoteEndpoint.mayHaveCreatedStream(streamId) || localEndpoint.mayHaveCreatedStream(streamId); + } + + @Override + public int numActiveStreams() { + return activeStreams.size(); + } + + @Override + public Http2Stream forEachActiveStream(Http2StreamVisitor visitor) throws Http2Exception { + return activeStreams.forEachActiveStream(visitor); + } + + @Override + public Endpoint local() { + return localEndpoint; + } + + @Override + public Endpoint remote() { + return remoteEndpoint; + } + + @Override + public boolean goAwayReceived() { + return localEndpoint.lastStreamKnownByPeer >= 0; + } + + @Override + public void goAwayReceived(final int lastKnownStream, long errorCode, ByteBuf debugData) throws Http2Exception { + if (localEndpoint.lastStreamKnownByPeer() >= 0 && localEndpoint.lastStreamKnownByPeer() < lastKnownStream) { + throw connectionError(PROTOCOL_ERROR, "lastStreamId MUST NOT increase. 
Current value: %d new value: %d", + localEndpoint.lastStreamKnownByPeer(), lastKnownStream); + } + + localEndpoint.lastStreamKnownByPeer(lastKnownStream); + for (int i = 0; i < listeners.size(); ++i) { + try { + listeners.get(i).onGoAwayReceived(lastKnownStream, errorCode, debugData); + } catch (Throwable cause) { + logger.error("Caught Throwable from listener onGoAwayReceived.", cause); + } + } + + closeStreamsGreaterThanLastKnownStreamId(lastKnownStream, localEndpoint); + } + + @Override + public boolean goAwaySent() { + return remoteEndpoint.lastStreamKnownByPeer >= 0; + } + + @Override + public boolean goAwaySent(final int lastKnownStream, long errorCode, ByteBuf debugData) throws Http2Exception { + if (remoteEndpoint.lastStreamKnownByPeer() >= 0) { + // Protect against re-entrancy. Could happen if writing the frame fails, and error handling + // treating this is a connection handler and doing a graceful shutdown... + if (lastKnownStream == remoteEndpoint.lastStreamKnownByPeer()) { + return false; + } + if (lastKnownStream > remoteEndpoint.lastStreamKnownByPeer()) { + throw connectionError(PROTOCOL_ERROR, "Last stream identifier must not increase between " + + "sending multiple GOAWAY frames (was '%d', is '%d').", + remoteEndpoint.lastStreamKnownByPeer(), lastKnownStream); + } + } + + remoteEndpoint.lastStreamKnownByPeer(lastKnownStream); + for (int i = 0; i < listeners.size(); ++i) { + try { + listeners.get(i).onGoAwaySent(lastKnownStream, errorCode, debugData); + } catch (Throwable cause) { + logger.error("Caught Throwable from listener onGoAwaySent.", cause); + } + } + + closeStreamsGreaterThanLastKnownStreamId(lastKnownStream, remoteEndpoint); + return true; + } + + private void closeStreamsGreaterThanLastKnownStreamId(final int lastKnownStream, + final DefaultEndpoint endpoint) throws Http2Exception { + forEachActiveStream(new Http2StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + if (stream.id() > lastKnownStream && 
                        endpoint.isValidStreamId(stream.id())) {
                    stream.close();
                }
                return true;
            }
        });
    }

    /**
     * Determine if {@link #streamMap} only contains the connection stream.
     */
    private boolean isStreamMapEmpty() {
        return streamMap.size() == 1;
    }

    /**
     * Remove a stream from the {@link #streamMap}.
     * @param stream the stream to remove.
     * @param itr an iterator that may be pointing to the stream during iteration and {@link Iterator#remove()} will be
     * used if non-{@code null}.
     */
    void removeStream(DefaultStream stream, Iterator itr) {
        final boolean removed;
        if (itr == null) {
            removed = streamMap.remove(stream.id()) != null;
        } else {
            // Remove through the iterator to avoid ConcurrentModificationException during iteration.
            itr.remove();
            removed = true;
        }

        if (removed) {
            // Notify listeners; a throwing listener must not prevent the others from running.
            for (int i = 0; i < listeners.size(); i++) {
                try {
                    listeners.get(i).onStreamRemoved(stream);
                } catch (Throwable cause) {
                    logger.error("Caught Throwable from listener onStreamRemoved.", cause);
                }
            }

            // If a close is pending and this was the last non-connection stream, complete the close.
            if (closePromise != null && isStreamMapEmpty()) {
                closePromise.trySuccess(null);
            }
        }
    }

    /**
     * Computes the state a stream transitions to when it is opened from {@code initialState}.
     *
     * @param streamId the stream being opened (used only for error reporting).
     * @param initialState the stream's state before opening; only IDLE and the two RESERVED states are legal.
     * @param isLocal {@code true} if the stream was created by the local endpoint.
     * @param halfClosed {@code true} to open the stream in a half-closed state.
     * @throws Http2Exception (stream error, PROTOCOL_ERROR) if the stream cannot be opened from
     *         {@code initialState}.
     */
    static State activeState(int streamId, State initialState, boolean isLocal, boolean halfClosed)
            throws Http2Exception {
        switch (initialState) {
        case IDLE:
            return halfClosed ? isLocal ? HALF_CLOSED_LOCAL : HALF_CLOSED_REMOTE : OPEN;
        case RESERVED_LOCAL:
            return HALF_CLOSED_REMOTE;
        case RESERVED_REMOTE:
            return HALF_CLOSED_LOCAL;
        default:
            throw streamError(streamId, PROTOCOL_ERROR, "Attempting to open a stream in an invalid state: "
                    + initialState);
        }
    }

    /**
     * Notifies all listeners that {@code stream} transitioned to a half-closed state.
     * Listener exceptions are logged and swallowed.
     */
    void notifyHalfClosed(Http2Stream stream) {
        for (int i = 0; i < listeners.size(); i++) {
            try {
                listeners.get(i).onStreamHalfClosed(stream);
            } catch (Throwable cause) {
                logger.error("Caught Throwable from listener onStreamHalfClosed.", cause);
            }
        }
    }

    /**
     * Notifies all listeners that {@code stream} was closed.
     * Listener exceptions are logged and swallowed.
     */
    void notifyClosed(Http2Stream stream) {
        for (int i = 0; i < listeners.size(); i++) {
            try {
                listeners.get(i).onStreamClosed(stream);
            } catch (Throwable cause) {
                logger.error("Caught Throwable from listener onStreamClosed.", cause);
            }
        }
    }

    @Override
    public PropertyKey newKey() {
        return propertyKeyRegistry.newKey();
    }

    /**
     * Verifies that the key is valid and returns it as the internal {@link DefaultPropertyKey} type.
     *
     * @throws NullPointerException if the key is {@code null}.
     * @throws ClassCastException if the key is not of type {@link DefaultPropertyKey}.
     * @throws IllegalArgumentException if the key was not created by this connection.
     */
    final DefaultPropertyKey verifyKey(PropertyKey key) {
        return checkNotNull((DefaultPropertyKey) key, "key").verifyConnection(this);
    }

    /**
     * Simple stream implementation. Streams can be compared to each other by priority.
+ */ + private class DefaultStream implements Http2Stream { + private static final byte META_STATE_SENT_RST = 1; + private static final byte META_STATE_SENT_HEADERS = 1 << 1; + private static final byte META_STATE_SENT_TRAILERS = 1 << 2; + private static final byte META_STATE_SENT_PUSHPROMISE = 1 << 3; + private static final byte META_STATE_RECV_HEADERS = 1 << 4; + private static final byte META_STATE_RECV_TRAILERS = 1 << 5; + private final int id; + private final PropertyMap properties = new PropertyMap(); + private State state; + private byte metaState; + + DefaultStream(int id, State state) { + this.id = id; + this.state = state; + } + + @Override + public final int id() { + return id; + } + + @Override + public final State state() { + return state; + } + + @Override + public boolean isResetSent() { + return (metaState & META_STATE_SENT_RST) != 0; + } + + @Override + public Http2Stream resetSent() { + metaState |= META_STATE_SENT_RST; + return this; + } + + @Override + public Http2Stream headersSent(boolean isInformational) { + if (!isInformational) { + metaState |= isHeadersSent() ? META_STATE_SENT_TRAILERS : META_STATE_SENT_HEADERS; + } + return this; + } + + @Override + public boolean isHeadersSent() { + return (metaState & META_STATE_SENT_HEADERS) != 0; + } + + @Override + public boolean isTrailersSent() { + return (metaState & META_STATE_SENT_TRAILERS) != 0; + } + + @Override + public Http2Stream headersReceived(boolean isInformational) { + if (!isInformational) { + metaState |= isHeadersReceived() ? 
META_STATE_RECV_TRAILERS : META_STATE_RECV_HEADERS; + } + return this; + } + + @Override + public boolean isHeadersReceived() { + return (metaState & META_STATE_RECV_HEADERS) != 0; + } + + @Override + public boolean isTrailersReceived() { + return (metaState & META_STATE_RECV_TRAILERS) != 0; + } + + @Override + public Http2Stream pushPromiseSent() { + metaState |= META_STATE_SENT_PUSHPROMISE; + return this; + } + + @Override + public boolean isPushPromiseSent() { + return (metaState & META_STATE_SENT_PUSHPROMISE) != 0; + } + + @Override + public final V setProperty(PropertyKey key, V value) { + return properties.add(verifyKey(key), value); + } + + @Override + public final V getProperty(PropertyKey key) { + return properties.get(verifyKey(key)); + } + + @Override + public final V removeProperty(PropertyKey key) { + return properties.remove(verifyKey(key)); + } + + @Override + public Http2Stream open(boolean halfClosed) throws Http2Exception { + state = activeState(id, state, isLocal(), halfClosed); + final DefaultEndpoint endpoint = createdBy(); + if (!endpoint.canOpenStream()) { + throw connectionError(PROTOCOL_ERROR, "Maximum active streams violated for this endpoint: " + + endpoint.maxActiveStreams()); + } + + activate(); + return this; + } + + void activate() { + // If the stream is opened in a half-closed state, the headers must have either + // been sent if this is a local stream, or received if it is a remote stream. 
+ if (state == HALF_CLOSED_LOCAL) { + headersSent(/*isInformational*/ false); + } else if (state == HALF_CLOSED_REMOTE) { + headersReceived(/*isInformational*/ false); + } + activeStreams.activate(this); + } + + Http2Stream close(Iterator itr) { + if (state == CLOSED) { + return this; + } + + state = CLOSED; + + --createdBy().numStreams; + activeStreams.deactivate(this, itr); + return this; + } + + @Override + public Http2Stream close() { + return close(null); + } + + @Override + public Http2Stream closeLocalSide() { + switch (state) { + case OPEN: + state = HALF_CLOSED_LOCAL; + notifyHalfClosed(this); + break; + case HALF_CLOSED_LOCAL: + break; + default: + close(); + break; + } + return this; + } + + @Override + public Http2Stream closeRemoteSide() { + switch (state) { + case OPEN: + state = HALF_CLOSED_REMOTE; + notifyHalfClosed(this); + break; + case HALF_CLOSED_REMOTE: + break; + default: + close(); + break; + } + return this; + } + + DefaultEndpoint createdBy() { + return localEndpoint.isValidStreamId(id) ? localEndpoint : remoteEndpoint; + } + + final boolean isLocal() { + return localEndpoint.isValidStreamId(id); + } + + /** + * Provides the lazy initialization for the {@link DefaultStream} data map. 
+ */ + private class PropertyMap { + Object[] values = EmptyArrays.EMPTY_OBJECTS; + + V add(DefaultPropertyKey key, V value) { + resizeIfNecessary(key.index); + @SuppressWarnings("unchecked") + V prevValue = (V) values[key.index]; + values[key.index] = value; + return prevValue; + } + + @SuppressWarnings("unchecked") + V get(DefaultPropertyKey key) { + if (key.index >= values.length) { + return null; + } + return (V) values[key.index]; + } + + @SuppressWarnings("unchecked") + V remove(DefaultPropertyKey key) { + V prevValue = null; + if (key.index < values.length) { + prevValue = (V) values[key.index]; + values[key.index] = null; + } + return prevValue; + } + + void resizeIfNecessary(int index) { + if (index >= values.length) { + values = Arrays.copyOf(values, propertyKeyRegistry.size()); + } + } + } + } + + /** + * Stream class representing the connection, itself. + */ + private final class ConnectionStream extends DefaultStream { + ConnectionStream() { + super(CONNECTION_STREAM_ID, IDLE); + } + + @Override + public boolean isResetSent() { + return false; + } + + @Override + DefaultEndpoint createdBy() { + return null; + } + + @Override + public Http2Stream resetSent() { + throw new UnsupportedOperationException(); + } + + @Override + public Http2Stream open(boolean halfClosed) { + throw new UnsupportedOperationException(); + } + + @Override + public Http2Stream close() { + throw new UnsupportedOperationException(); + } + + @Override + public Http2Stream closeLocalSide() { + throw new UnsupportedOperationException(); + } + + @Override + public Http2Stream closeRemoteSide() { + throw new UnsupportedOperationException(); + } + + @Override + public Http2Stream headersSent(boolean isInformational) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isHeadersSent() { + throw new UnsupportedOperationException(); + } + + @Override + public Http2Stream pushPromiseSent() { + throw new UnsupportedOperationException(); + } + + @Override + public 
boolean isPushPromiseSent() { + throw new UnsupportedOperationException(); + } + } + + /** + * Simple endpoint implementation. + */ + private final class DefaultEndpoint implements Endpoint { + private final boolean server; + /** + * The minimum stream ID allowed when creating the next stream. This only applies at the time the stream is + * created. If the ID of the stream being created is less than this value, stream creation will fail. Upon + * successful creation of a stream, this value is incremented to the next valid stream ID. + */ + private int nextStreamIdToCreate; + /** + * Used for reservation of stream IDs. Stream IDs can be reserved in advance by applications before the streams + * are actually created. For example, applications may choose to buffer stream creation attempts as a way of + * working around {@code SETTINGS_MAX_CONCURRENT_STREAMS}, in which case they will reserve stream IDs for each + * buffered stream. + */ + private int nextReservationStreamId; + private int lastStreamKnownByPeer = -1; + private boolean pushToAllowed; + private F flowController; + private int maxStreams; + private int maxActiveStreams; + private final int maxReservedStreams; + // Fields accessed by inner classes + int numActiveStreams; + int numStreams; + + DefaultEndpoint(boolean server, int maxReservedStreams) { + this.server = server; + + // Determine the starting stream ID for this endpoint. Client-initiated streams + // are odd and server-initiated streams are even. Zero is reserved for the + // connection. Stream 1 is reserved client-initiated stream for responding to an + // upgrade from HTTP 1.1. + if (server) { + nextStreamIdToCreate = 2; + nextReservationStreamId = 0; + } else { + nextStreamIdToCreate = 1; + // For manually created client-side streams, 1 is reserved for HTTP upgrade, so start at 3. + nextReservationStreamId = 1; + } + + // Push is disallowed by default for servers and allowed for clients. 
+ pushToAllowed = !server; + maxActiveStreams = MAX_VALUE; + this.maxReservedStreams = checkPositiveOrZero(maxReservedStreams, "maxReservedStreams"); + updateMaxStreams(); + } + + @Override + public int incrementAndGetNextStreamId() { + return nextReservationStreamId >= 0 ? nextReservationStreamId += 2 : nextReservationStreamId; + } + + private void incrementExpectedStreamId(int streamId) { + if (streamId > nextReservationStreamId && nextReservationStreamId >= 0) { + nextReservationStreamId = streamId; + } + nextStreamIdToCreate = streamId + 2; + ++numStreams; + } + + @Override + public boolean isValidStreamId(int streamId) { + return streamId > 0 && server == ((streamId & 1) == 0); + } + + @Override + public boolean mayHaveCreatedStream(int streamId) { + return isValidStreamId(streamId) && streamId <= lastStreamCreated(); + } + + @Override + public boolean canOpenStream() { + return numActiveStreams < maxActiveStreams; + } + + @Override + public DefaultStream createStream(int streamId, boolean halfClosed) throws Http2Exception { + State state = activeState(streamId, IDLE, isLocal(), halfClosed); + + checkNewStreamAllowed(streamId, state); + + // Create and initialize the stream. + DefaultStream stream = new DefaultStream(streamId, state); + + incrementExpectedStreamId(streamId); + + addStream(stream); + + stream.activate(); + return stream; + } + + @Override + public boolean created(Http2Stream stream) { + return stream instanceof DefaultStream && ((DefaultStream) stream).createdBy() == this; + } + + @Override + public boolean isServer() { + return server; + } + + @Override + public DefaultStream reservePushStream(int streamId, Http2Stream parent) throws Http2Exception { + if (parent == null) { + throw connectionError(PROTOCOL_ERROR, "Parent stream missing"); + } + if (isLocal() ? 
!parent.state().localSideOpen() : !parent.state().remoteSideOpen()) { + throw connectionError(PROTOCOL_ERROR, "Stream %d is not open for sending push promise", parent.id()); + } + if (!opposite().allowPushTo()) { + throw connectionError(PROTOCOL_ERROR, "Server push not allowed to opposite endpoint"); + } + State state = isLocal() ? RESERVED_LOCAL : RESERVED_REMOTE; + checkNewStreamAllowed(streamId, state); + + // Create and initialize the stream. + DefaultStream stream = new DefaultStream(streamId, state); + + incrementExpectedStreamId(streamId); + + // Register the stream. + addStream(stream); + return stream; + } + + private void addStream(DefaultStream stream) { + // Add the stream to the map and priority tree. + streamMap.put(stream.id(), stream); + + // Notify the listeners of the event. + for (int i = 0; i < listeners.size(); i++) { + try { + listeners.get(i).onStreamAdded(stream); + } catch (Throwable cause) { + logger.error("Caught Throwable from listener onStreamAdded.", cause); + } + } + } + + @Override + public void allowPushTo(boolean allow) { + if (allow && server) { + throw new IllegalArgumentException("Servers do not allow push"); + } + pushToAllowed = allow; + } + + @Override + public boolean allowPushTo() { + return pushToAllowed; + } + + @Override + public int numActiveStreams() { + return numActiveStreams; + } + + @Override + public int maxActiveStreams() { + return maxActiveStreams; + } + + @Override + public void maxActiveStreams(int maxActiveStreams) { + this.maxActiveStreams = maxActiveStreams; + updateMaxStreams(); + } + + @Override + public int lastStreamCreated() { + return nextStreamIdToCreate > 1 ? 
nextStreamIdToCreate - 2 : 0; + } + + @Override + public int lastStreamKnownByPeer() { + return lastStreamKnownByPeer; + } + + private void lastStreamKnownByPeer(int lastKnownStream) { + lastStreamKnownByPeer = lastKnownStream; + } + + @Override + public F flowController() { + return flowController; + } + + @Override + public void flowController(F flowController) { + this.flowController = checkNotNull(flowController, "flowController"); + } + + @Override + public Endpoint opposite() { + return isLocal() ? remoteEndpoint : localEndpoint; + } + + private void updateMaxStreams() { + maxStreams = (int) Math.min(MAX_VALUE, (long) maxActiveStreams + maxReservedStreams); + } + + private void checkNewStreamAllowed(int streamId, State state) throws Http2Exception { + assert state != IDLE; + if (lastStreamKnownByPeer >= 0 && streamId > lastStreamKnownByPeer) { + throw streamError(streamId, REFUSED_STREAM, + "Cannot create stream %d greater than Last-Stream-ID %d from GOAWAY.", + streamId, lastStreamKnownByPeer); + } + if (!isValidStreamId(streamId)) { + if (streamId < 0) { + throw new Http2NoMoreStreamIdsException(); + } + throw connectionError(PROTOCOL_ERROR, "Request stream %d is not correct for %s connection", streamId, + server ? "server" : "client"); + } + // This check must be after all id validated checks, but before the max streams check because it may be + // recoverable to some degree for handling frames which can be sent on closed streams. + if (streamId < nextStreamIdToCreate) { + throw closedStreamError(PROTOCOL_ERROR, "Request stream %d is behind the next expected stream %d", + streamId, nextStreamIdToCreate); + } + if (nextStreamIdToCreate <= 0) { + // We exhausted the stream id space that we can use. Let's signal this back but also signal that + // we still may want to process active streams. 
+ throw new Http2Exception(REFUSED_STREAM, "Stream IDs are exhausted for this endpoint.", + Http2Exception.ShutdownHint.GRACEFUL_SHUTDOWN); + } + boolean isReserved = state == RESERVED_LOCAL || state == RESERVED_REMOTE; + if (!isReserved && !canOpenStream() || isReserved && numStreams >= maxStreams) { + throw streamError(streamId, REFUSED_STREAM, "Maximum active streams violated for this endpoint: " + + (isReserved ? maxStreams : maxActiveStreams)); + } + if (isClosed()) { + throw connectionError(INTERNAL_ERROR, "Attempted to create stream id %d after connection was closed", + streamId); + } + } + + private boolean isLocal() { + return this == localEndpoint; + } + } + + /** + * Allows events which would modify the collection of active streams to be queued while iterating via {@link + * #forEachActiveStream(Http2StreamVisitor)}. + */ + interface Event { + /** + * Trigger the original intention of this event. Expect to modify the active streams list. + *

+ * If a {@link RuntimeException} object is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + */ + void process(); + } + + /** + * Manages the list of currently active streams. Queues any {@link Event}s that would modify the list of + * active streams in order to prevent modification while iterating. + */ + private final class ActiveStreams { + private final List listeners; + private final Queue pendingEvents = new ArrayDeque(4); + private final Set streams = new LinkedHashSet(); + private int pendingIterations; + + ActiveStreams(List listeners) { + this.listeners = listeners; + } + + public int size() { + return streams.size(); + } + + public void activate(final DefaultStream stream) { + if (allowModifications()) { + addToActiveStreams(stream); + } else { + pendingEvents.add(new Event() { + @Override + public void process() { + addToActiveStreams(stream); + } + }); + } + } + + public void deactivate(final DefaultStream stream, final Iterator itr) { + if (allowModifications() || itr != null) { + removeFromActiveStreams(stream, itr); + } else { + pendingEvents.add(new Event() { + @Override + public void process() { + removeFromActiveStreams(stream, itr); + } + }); + } + } + + public Http2Stream forEachActiveStream(Http2StreamVisitor visitor) throws Http2Exception { + incrementPendingIterations(); + try { + for (Http2Stream stream : streams) { + if (!visitor.visit(stream)) { + return stream; + } + } + return null; + } finally { + decrementPendingIterations(); + } + } + + void addToActiveStreams(DefaultStream stream) { + if (streams.add(stream)) { + // Update the number of active streams initiated by the endpoint. 
+ stream.createdBy().numActiveStreams++; + + for (int i = 0; i < listeners.size(); i++) { + try { + listeners.get(i).onStreamActive(stream); + } catch (Throwable cause) { + logger.error("Caught Throwable from listener onStreamActive.", cause); + } + } + } + } + + void removeFromActiveStreams(DefaultStream stream, Iterator itr) { + if (streams.remove(stream)) { + // Update the number of active streams initiated by the endpoint. + stream.createdBy().numActiveStreams--; + notifyClosed(stream); + } + removeStream(stream, itr); + } + + boolean allowModifications() { + return pendingIterations == 0; + } + + void incrementPendingIterations() { + ++pendingIterations; + } + + void decrementPendingIterations() { + --pendingIterations; + if (allowModifications()) { + for (;;) { + Event event = pendingEvents.poll(); + if (event == null) { + break; + } + try { + event.process(); + } catch (Throwable cause) { + logger.error("Caught Throwable while processing pending ActiveStreams$Event.", cause); + } + } + } + } + } + + /** + * Implementation of {@link PropertyKey} that specifies the index position of the property. + */ + final class DefaultPropertyKey implements PropertyKey { + final int index; + + DefaultPropertyKey(int index) { + this.index = index; + } + + DefaultPropertyKey verifyConnection(Http2Connection connection) { + if (connection != DefaultHttp2Connection.this) { + throw new IllegalArgumentException("Using a key that was not created by this connection"); + } + return this; + } + } + + /** + * A registry of all stream property keys known by this connection. + */ + private final class PropertyKeyRegistry { + /** + * Initial size of 4 because the default configuration currently has 3 listeners + * (local/remote flow controller and {@link StreamByteDistributor}) and we leave room for 1 extra. + * We could be more aggressive but the ArrayList resize will double the size if we are too small. 
+ */ + final List keys = new ArrayList(4); + + /** + * Registers a new property key. + */ + DefaultPropertyKey newKey() { + DefaultPropertyKey key = new DefaultPropertyKey(keys.size()); + keys.add(key); + return key; + } + + int size() { + return keys.size(); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java new file mode 100644 index 0000000..933e56f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoder.java @@ -0,0 +1,843 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpStatusClass; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http2.Http2Connection.Endpoint; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; + +import static io.netty.handler.codec.http.HttpStatusClass.INFORMATIONAL; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static io.netty.handler.codec.http2.Http2PromisedRequestVerifier.ALWAYS_VERIFY; +import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_REMOTE; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Integer.MAX_VALUE; +import static java.lang.Math.min; + +/** + * Provides the default implementation for processing inbound frame events and delegates to a + * {@link Http2FrameListener} + *

+ * This class will read HTTP/2 frames and delegate the events to a {@link Http2FrameListener} + *

+ * This interface enforces inbound flow control functionality through + * {@link Http2LocalFlowController} + */ +@UnstableApi +public class DefaultHttp2ConnectionDecoder implements Http2ConnectionDecoder { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultHttp2ConnectionDecoder.class); + private Http2FrameListener internalFrameListener = new PrefaceFrameListener(); + private final Http2Connection connection; + private Http2LifecycleManager lifecycleManager; + private final Http2ConnectionEncoder encoder; + private final Http2FrameReader frameReader; + private Http2FrameListener listener; + private final Http2PromisedRequestVerifier requestVerifier; + private final Http2SettingsReceivedConsumer settingsReceivedConsumer; + private final boolean autoAckPing; + private final Http2Connection.PropertyKey contentLengthKey; + private final boolean validateHeaders; + + public DefaultHttp2ConnectionDecoder(Http2Connection connection, + Http2ConnectionEncoder encoder, + Http2FrameReader frameReader) { + this(connection, encoder, frameReader, ALWAYS_VERIFY); + } + + public DefaultHttp2ConnectionDecoder(Http2Connection connection, + Http2ConnectionEncoder encoder, + Http2FrameReader frameReader, + Http2PromisedRequestVerifier requestVerifier) { + this(connection, encoder, frameReader, requestVerifier, true); + } + + /** + * Create a new instance. + * @param connection The {@link Http2Connection} associated with this decoder. + * @param encoder The {@link Http2ConnectionEncoder} associated with this decoder. + * @param frameReader Responsible for reading/parsing the raw frames. As opposed to this object which applies + * h2 semantics on top of the frames. + * @param requestVerifier Determines if push promised streams are valid. + * @param autoAckSettings {@code false} to disable automatically applying and sending settings acknowledge frame. 
+ * The {@code Http2ConnectionEncoder} is expected to be an instance of {@link Http2SettingsReceivedConsumer} and + * will apply the earliest received but not yet ACKed SETTINGS when writing the SETTINGS ACKs. + * {@code true} to enable automatically applying and sending settings acknowledge frame. + */ + public DefaultHttp2ConnectionDecoder(Http2Connection connection, + Http2ConnectionEncoder encoder, + Http2FrameReader frameReader, + Http2PromisedRequestVerifier requestVerifier, + boolean autoAckSettings) { + this(connection, encoder, frameReader, requestVerifier, autoAckSettings, true); + } + + @Deprecated + public DefaultHttp2ConnectionDecoder(Http2Connection connection, + Http2ConnectionEncoder encoder, + Http2FrameReader frameReader, + Http2PromisedRequestVerifier requestVerifier, + boolean autoAckSettings, + boolean autoAckPing) { + this(connection, encoder, frameReader, requestVerifier, autoAckSettings, true, true); + } + + /** + * Create a new instance. + * @param connection The {@link Http2Connection} associated with this decoder. + * @param encoder The {@link Http2ConnectionEncoder} associated with this decoder. + * @param frameReader Responsible for reading/parsing the raw frames. As opposed to this object which applies + * h2 semantics on top of the frames. + * @param requestVerifier Determines if push promised streams are valid. + * @param autoAckSettings {@code false} to disable automatically applying and sending settings acknowledge frame. + * The {@code Http2ConnectionEncoder} is expected to be an instance of + * {@link Http2SettingsReceivedConsumer} and will apply the earliest received but not yet + * ACKed SETTINGS when writing the SETTINGS ACKs. {@code true} to enable automatically + * applying and sending settings acknowledge frame. + * @param autoAckPing {@code false} to disable automatically sending ping acknowledge frame. {@code true} to enable + * automatically sending ping ack frame. 
+ */ + public DefaultHttp2ConnectionDecoder(Http2Connection connection, + Http2ConnectionEncoder encoder, + Http2FrameReader frameReader, + Http2PromisedRequestVerifier requestVerifier, + boolean autoAckSettings, + boolean autoAckPing, + boolean validateHeaders) { + this.validateHeaders = validateHeaders; + this.autoAckPing = autoAckPing; + if (autoAckSettings) { + settingsReceivedConsumer = null; + } else { + if (!(encoder instanceof Http2SettingsReceivedConsumer)) { + throw new IllegalArgumentException("disabling autoAckSettings requires the encoder to be a " + + Http2SettingsReceivedConsumer.class); + } + settingsReceivedConsumer = (Http2SettingsReceivedConsumer) encoder; + } + this.connection = checkNotNull(connection, "connection"); + contentLengthKey = this.connection.newKey(); + this.frameReader = checkNotNull(frameReader, "frameReader"); + this.encoder = checkNotNull(encoder, "encoder"); + this.requestVerifier = checkNotNull(requestVerifier, "requestVerifier"); + if (connection.local().flowController() == null) { + connection.local().flowController(new DefaultHttp2LocalFlowController(connection)); + } + connection.local().flowController().frameWriter(encoder.frameWriter()); + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + this.lifecycleManager = checkNotNull(lifecycleManager, "lifecycleManager"); + } + + @Override + public Http2Connection connection() { + return connection; + } + + @Override + public final Http2LocalFlowController flowController() { + return connection.local().flowController(); + } + + @Override + public void frameListener(Http2FrameListener listener) { + this.listener = checkNotNull(listener, "listener"); + } + + @Override + public Http2FrameListener frameListener() { + return listener; + } + + @Override + public boolean prefaceReceived() { + return FrameReadListener.class == internalFrameListener.getClass(); + } + + @Override + public void decodeFrame(ChannelHandlerContext ctx, ByteBuf in, List 
out) throws Http2Exception { + frameReader.readFrame(ctx, in, internalFrameListener); + } + + @Override + public Http2Settings localSettings() { + Http2Settings settings = new Http2Settings(); + Http2FrameReader.Configuration config = frameReader.configuration(); + Http2HeadersDecoder.Configuration headersConfig = config.headersConfiguration(); + Http2FrameSizePolicy frameSizePolicy = config.frameSizePolicy(); + settings.initialWindowSize(flowController().initialWindowSize()); + settings.maxConcurrentStreams(connection.remote().maxActiveStreams()); + settings.headerTableSize(headersConfig.maxHeaderTableSize()); + settings.maxFrameSize(frameSizePolicy.maxFrameSize()); + settings.maxHeaderListSize(headersConfig.maxHeaderListSize()); + if (!connection.isServer()) { + // Only set the pushEnabled flag if this is a client endpoint. + settings.pushEnabled(connection.local().allowPushTo()); + } + return settings; + } + + @Override + public void close() { + frameReader.close(); + } + + /** + * Calculate the threshold in bytes which should trigger a {@code GO_AWAY} if a set of headers exceeds this amount. + * @param maxHeaderListSize + * SETTINGS_MAX_HEADER_LIST_SIZE for the local + * endpoint. + * @return the threshold in bytes which should trigger a {@code GO_AWAY} if a set of headers exceeds this amount. 
+ */ + protected long calculateMaxHeaderListSizeGoAway(long maxHeaderListSize) { + return Http2CodecUtil.calculateMaxHeaderListSizeGoAway(maxHeaderListSize); + } + + private int unconsumedBytes(Http2Stream stream) { + return flowController().unconsumedBytes(stream); + } + + void onGoAwayRead0(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + connection.goAwayReceived(lastStreamId, errorCode, debugData); + } + + void onUnknownFrame0(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); + } + + // See https://tools.ietf.org/html/rfc7540#section-8.1.2.6 + private void verifyContentLength(Http2Stream stream, int data, boolean isEnd) throws Http2Exception { + ContentLength contentLength = stream.getProperty(contentLengthKey); + if (contentLength != null) { + try { + contentLength.increaseReceivedBytes(connection.isServer(), stream.id(), data, isEnd); + } finally { + if (isEnd) { + stream.removeProperty(contentLengthKey); + } + } + } + } + + /** + * Handles all inbound frames from the network. + */ + private final class FrameReadListener implements Http2FrameListener { + @Override + public int onDataRead(final ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, + boolean endOfStream) throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + Http2LocalFlowController flowController = flowController(); + int readable = data.readableBytes(); + int bytesToReturn = readable + padding; + + final boolean shouldIgnore; + try { + shouldIgnore = shouldIgnoreHeadersOrDataFrame(ctx, streamId, stream, endOfStream, "DATA"); + } catch (Http2Exception e) { + // Ignoring this frame. 
We still need to count the frame towards the connection flow control + // window, but we immediately mark all bytes as consumed. + flowController.receiveFlowControlledFrame(stream, data, padding, endOfStream); + flowController.consumeBytes(stream, bytesToReturn); + throw e; + } catch (Throwable t) { + throw connectionError(INTERNAL_ERROR, t, "Unhandled error on data stream id %d", streamId); + } + + if (shouldIgnore) { + // Ignoring this frame. We still need to count the frame towards the connection flow control + // window, but we immediately mark all bytes as consumed. + flowController.receiveFlowControlledFrame(stream, data, padding, endOfStream); + flowController.consumeBytes(stream, bytesToReturn); + + // Verify that the stream may have existed after we apply flow control. + verifyStreamMayHaveExisted(streamId, endOfStream, "DATA"); + + // All bytes have been consumed. + return bytesToReturn; + } + Http2Exception error = null; + switch (stream.state()) { + case OPEN: + case HALF_CLOSED_LOCAL: + break; + case HALF_CLOSED_REMOTE: + case CLOSED: + error = streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s", + stream.id(), stream.state()); + break; + default: + error = streamError(stream.id(), PROTOCOL_ERROR, + "Stream %d in unexpected state: %s", stream.id(), stream.state()); + break; + } + + int unconsumedBytes = unconsumedBytes(stream); + try { + flowController.receiveFlowControlledFrame(stream, data, padding, endOfStream); + // Update the unconsumed bytes after flow control is applied. + unconsumedBytes = unconsumedBytes(stream); + + // If the stream is in an invalid state to receive the frame, throw the error. + if (error != null) { + throw error; + } + + verifyContentLength(stream, readable, endOfStream); + + // Call back the application and retrieve the number of bytes that have been + // immediately processed. 
+ bytesToReturn = listener.onDataRead(ctx, streamId, data, padding, endOfStream); + + if (endOfStream) { + lifecycleManager.closeStreamRemote(stream, ctx.newSucceededFuture()); + } + + return bytesToReturn; + } catch (Http2Exception e) { + // If an exception happened during delivery, the listener may have returned part + // of the bytes before the error occurred. If that's the case, subtract that from + // the total processed bytes so that we don't return too many bytes. + int delta = unconsumedBytes - unconsumedBytes(stream); + bytesToReturn -= delta; + throw e; + } catch (RuntimeException e) { + // If an exception happened during delivery, the listener may have returned part + // of the bytes before the error occurred. If that's the case, subtract that from + // the total processed bytes so that we don't return too many bytes. + int delta = unconsumedBytes - unconsumedBytes(stream); + bytesToReturn -= delta; + throw e; + } finally { + // If appropriate, return the processed bytes to the flow controller. + flowController.consumeBytes(stream, bytesToReturn); + } + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception { + onHeadersRead(ctx, streamId, headers, 0, DEFAULT_PRIORITY_WEIGHT, false, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + boolean allowHalfClosedRemote = false; + boolean isTrailers = false; + if (stream == null && !connection.streamMayHaveExisted(streamId)) { + stream = connection.remote().createStream(streamId, endOfStream); + // Allow the state to be HALF_CLOSE_REMOTE if we're creating it in that state. 
+ allowHalfClosedRemote = stream.state() == HALF_CLOSED_REMOTE; + } else if (stream != null) { + isTrailers = stream.isHeadersReceived(); + } + + if (shouldIgnoreHeadersOrDataFrame(ctx, streamId, stream, endOfStream, "HEADERS")) { + return; + } + + boolean isInformational = !connection.isServer() && + HttpStatusClass.valueOf(headers.status()) == INFORMATIONAL; + if ((isInformational || !endOfStream) && stream.isHeadersReceived() || stream.isTrailersReceived()) { + throw streamError(streamId, PROTOCOL_ERROR, + "Stream %d received too many headers EOS: %s state: %s", + streamId, endOfStream, stream.state()); + } + + switch (stream.state()) { + case RESERVED_REMOTE: + stream.open(endOfStream); + break; + case OPEN: + case HALF_CLOSED_LOCAL: + // Allowed to receive headers in these states. + break; + case HALF_CLOSED_REMOTE: + if (!allowHalfClosedRemote) { + throw streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s", + stream.id(), stream.state()); + } + break; + case CLOSED: + throw streamError(stream.id(), STREAM_CLOSED, "Stream %d in unexpected state: %s", + stream.id(), stream.state()); + default: + // Connection error. 
+ throw connectionError(PROTOCOL_ERROR, "Stream %d in unexpected state: %s", stream.id(), + stream.state()); + } + + if (!isTrailers) { + // extract the content-length header + List contentLength = headers.getAll(HttpHeaderNames.CONTENT_LENGTH); + if (contentLength != null && !contentLength.isEmpty()) { + try { + long cLength = HttpUtil.normalizeAndGetContentLength(contentLength, false, true); + if (cLength != -1) { + headers.setLong(HttpHeaderNames.CONTENT_LENGTH, cLength); + stream.setProperty(contentLengthKey, new ContentLength(cLength)); + } + } catch (IllegalArgumentException e) { + throw streamError(stream.id(), PROTOCOL_ERROR, e, + "Multiple content-length headers received"); + } + } + // Use size() instead of isEmpty() for backward compatibility with grpc-java prior to 1.59.1, + // see https://github.com/grpc/grpc-java/issues/10665 + } else if (validateHeaders && headers.size() > 0) { + // Need to check trailers don't contain pseudo headers. According to RFC 9113 + // Trailers MUST NOT include pseudo-header fields (Section 8.3). + for (Iterator> iterator = + headers.iterator(); iterator.hasNext();) { + CharSequence name = iterator.next().getKey(); + if (Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat(name)) { + throw streamError(stream.id(), PROTOCOL_ERROR, + "Found invalid Pseudo-Header in trailers: %s", name); + } + } + } + + stream.headersReceived(isInformational); + verifyContentLength(stream, 0, endOfStream); + encoder.flowController().updateDependencyTree(streamId, streamDependency, weight, exclusive); + listener.onHeadersRead(ctx, streamId, headers, streamDependency, + weight, exclusive, padding, endOfStream); + // If the headers completes this stream, close it. 
+ if (endOfStream) { + lifecycleManager.closeStreamRemote(stream, ctx.newSucceededFuture()); + } + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + encoder.flowController().updateDependencyTree(streamId, streamDependency, weight, exclusive); + + listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + if (stream == null) { + verifyStreamMayHaveExisted(streamId, false, "RST_STREAM"); + return; + } + + switch(stream.state()) { + case IDLE: + throw connectionError(PROTOCOL_ERROR, "RST_STREAM received for IDLE stream %d", streamId); + case CLOSED: + return; // RST_STREAM frames must be ignored for closed streams. + default: + break; + } + + listener.onRstStreamRead(ctx, streamId, errorCode); + + lifecycleManager.closeStream(stream, ctx.newSucceededFuture()); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + // Apply oldest outstanding local settings here. This is a synchronization point between endpoints. + Http2Settings settings = encoder.pollSentSettings(); + + if (settings != null) { + applyLocalSettings(settings); + } + + listener.onSettingsAckRead(ctx); + } + + /** + * Applies settings sent from the local endpoint. + *
+         * <p>
+ * This method is only called after the local settings have been acknowledged from the remote endpoint. + */ + private void applyLocalSettings(Http2Settings settings) throws Http2Exception { + Boolean pushEnabled = settings.pushEnabled(); + final Http2FrameReader.Configuration config = frameReader.configuration(); + final Http2HeadersDecoder.Configuration headerConfig = config.headersConfiguration(); + final Http2FrameSizePolicy frameSizePolicy = config.frameSizePolicy(); + if (pushEnabled != null) { + if (connection.isServer()) { + throw connectionError(PROTOCOL_ERROR, "Server sending SETTINGS frame with ENABLE_PUSH specified"); + } + connection.local().allowPushTo(pushEnabled); + } + + Long maxConcurrentStreams = settings.maxConcurrentStreams(); + if (maxConcurrentStreams != null) { + connection.remote().maxActiveStreams((int) min(maxConcurrentStreams, MAX_VALUE)); + } + + Long headerTableSize = settings.headerTableSize(); + if (headerTableSize != null) { + headerConfig.maxHeaderTableSize(headerTableSize); + } + + Long maxHeaderListSize = settings.maxHeaderListSize(); + if (maxHeaderListSize != null) { + headerConfig.maxHeaderListSize(maxHeaderListSize, calculateMaxHeaderListSizeGoAway(maxHeaderListSize)); + } + + Integer maxFrameSize = settings.maxFrameSize(); + if (maxFrameSize != null) { + frameSizePolicy.maxFrameSize(maxFrameSize); + } + + Integer initialWindowSize = settings.initialWindowSize(); + if (initialWindowSize != null) { + flowController().initialWindowSize(initialWindowSize); + } + } + + @Override + public void onSettingsRead(final ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + if (settingsReceivedConsumer == null) { + // Acknowledge receipt of the settings. We should do this before we process the settings to ensure our + // remote peer applies these settings before any subsequent frames that we may send which depend upon + // these new settings. See https://github.com/netty/netty/issues/6520. 
+ encoder.writeSettingsAck(ctx, ctx.newPromise()); + + encoder.remoteSettings(settings); + } else { + settingsReceivedConsumer.consumeReceivedSettings(settings); + } + + listener.onSettingsRead(ctx, settings); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + if (autoAckPing) { + // Send an ack back to the remote client. + encoder.writePing(ctx, true, data, ctx.newPromise()); + } + listener.onPingRead(ctx, data); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + listener.onPingAckRead(ctx, data); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + // A client cannot push. + if (connection().isServer()) { + throw connectionError(PROTOCOL_ERROR, "A client cannot push."); + } + + Http2Stream parentStream = connection.stream(streamId); + + if (shouldIgnoreHeadersOrDataFrame(ctx, streamId, parentStream, false, "PUSH_PROMISE")) { + return; + } + + switch (parentStream.state()) { + case OPEN: + case HALF_CLOSED_LOCAL: + // Allowed to receive push promise in these states. + break; + default: + // Connection error. 
+ throw connectionError(PROTOCOL_ERROR, + "Stream %d in unexpected state for receiving push promise: %s", + parentStream.id(), parentStream.state()); + } + + if (!requestVerifier.isAuthoritative(ctx, headers)) { + throw streamError(promisedStreamId, PROTOCOL_ERROR, + "Promised request on stream %d for promised stream %d is not authoritative", + streamId, promisedStreamId); + } + if (!requestVerifier.isCacheable(headers)) { + throw streamError(promisedStreamId, PROTOCOL_ERROR, + "Promised request on stream %d for promised stream %d is not known to be cacheable", + streamId, promisedStreamId); + } + if (!requestVerifier.isSafe(headers)) { + throw streamError(promisedStreamId, PROTOCOL_ERROR, + "Promised request on stream %d for promised stream %d is not known to be safe", + streamId, promisedStreamId); + } + + // Reserve the push stream based with a priority based on the current stream's priority. + connection.remote().reservePushStream(promisedStreamId, parentStream); + + listener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + onGoAwayRead0(ctx, lastStreamId, errorCode, debugData); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + if (stream == null || stream.state() == CLOSED || streamCreatedAfterGoAwaySent(streamId)) { + // Ignore this frame. + verifyStreamMayHaveExisted(streamId, false, "WINDOW_UPDATE"); + return; + } + + // Update the outbound flow control window. 
+ encoder.flowController().incrementWindowSize(stream, windowSizeIncrement); + + listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + onUnknownFrame0(ctx, frameType, streamId, flags, payload); + } + + /** + * Helper method to determine if a frame that has the semantics of headers or data should be ignored for the + * {@code stream} (which may be {@code null}) associated with {@code streamId}. + */ + private boolean shouldIgnoreHeadersOrDataFrame(ChannelHandlerContext ctx, int streamId, Http2Stream stream, + boolean endOfStream, String frameName) throws Http2Exception { + if (stream == null) { + if (streamCreatedAfterGoAwaySent(streamId)) { + logger.info("{} ignoring {} frame for stream {}. Stream sent after GOAWAY sent", + ctx.channel(), frameName, streamId); + return true; + } + + // Make sure it's not an out-of-order frame, like a rogue DATA frame, for a stream that could + // never have existed. + verifyStreamMayHaveExisted(streamId, endOfStream, frameName); + + // Its possible that this frame would result in stream ID out of order creation (PROTOCOL ERROR) and its + // also possible that this frame is received on a CLOSED stream (STREAM_CLOSED after a RST_STREAM is + // sent). We don't have enough information to know for sure, so we choose the lesser of the two errors. + throw streamError(streamId, STREAM_CLOSED, "Received %s frame for an unknown stream %d", + frameName, streamId); + } + if (stream.isResetSent() || streamCreatedAfterGoAwaySent(streamId)) { + // If we have sent a reset stream it is assumed the stream will be closed after the write completes. 
+ // If we have not sent a reset, but the stream was created after a GoAway this is not supported by + // DefaultHttp2Connection and if a custom Http2Connection is used it is assumed the lifetime is managed + // elsewhere so we don't close the stream or otherwise modify the stream's state. + + if (logger.isInfoEnabled()) { + logger.info("{} ignoring {} frame for stream {}", ctx.channel(), frameName, + stream.isResetSent() ? "RST_STREAM sent." : + "Stream created after GOAWAY sent. Last known stream by peer " + + connection.remote().lastStreamKnownByPeer()); + } + + return true; + } + return false; + } + + /** + * Helper method for determining whether or not to ignore inbound frames. A stream is considered to be created + * after a {@code GOAWAY} is sent if the following conditions hold: + *

+         * <ul>
+         *     <li>A {@code GOAWAY} must have been sent by the local endpoint</li>
+         *     <li>The {@code streamId} must identify a legitimate stream id for the remote endpoint to be creating</li>
+         *     <li>{@code streamId} is greater than the Last Known Stream ID which was sent by the local endpoint
+         *     in the last {@code GOAWAY} frame</li>
+         * </ul>
+ */ + private boolean streamCreatedAfterGoAwaySent(int streamId) { + Endpoint remote = connection.remote(); + return connection.goAwaySent() && remote.isValidStreamId(streamId) && + streamId > remote.lastStreamKnownByPeer(); + } + + private void verifyStreamMayHaveExisted(int streamId, boolean endOfStream, String frameName) + throws Http2Exception { + if (!connection.streamMayHaveExisted(streamId)) { + throw connectionError(PROTOCOL_ERROR, + "Stream %d does not exist for inbound frame %s, endOfStream = %b", + streamId, frameName, endOfStream); + } + } + } + + private final class PrefaceFrameListener implements Http2FrameListener { + /** + * Verifies that the HTTP/2 connection preface has been received from the remote endpoint. + * It is possible that the current call to + * {@link Http2FrameReader#readFrame(ChannelHandlerContext, ByteBuf, Http2FrameListener)} will have multiple + * frames to dispatch. So it may be OK for this class to get legitimate frames for the first readFrame. + */ + private void verifyPrefaceReceived() throws Http2Exception { + if (!prefaceReceived()) { + throw connectionError(PROTOCOL_ERROR, "Received non-SETTINGS as first frame."); + } + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + verifyPrefaceReceived(); + return internalFrameListener.onDataRead(ctx, streamId, data, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onHeadersRead(ctx, streamId, headers, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endOfStream) throws Http2Exception { + verifyPrefaceReceived(); + 
internalFrameListener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream); + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onRstStreamRead(ctx, streamId, errorCode); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onSettingsAckRead(ctx); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + // The first settings should change the internalFrameListener to the "real" listener + // that expects the preface to be verified. 
+ if (!prefaceReceived()) { + internalFrameListener = new FrameReadListener(); + } + internalFrameListener.onSettingsRead(ctx, settings); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPingRead(ctx, data); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPingAckRead(ctx, data); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + onGoAwayRead0(ctx, lastStreamId, errorCode, debugData); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + verifyPrefaceReceived(); + internalFrameListener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + onUnknownFrame0(ctx, frameType, streamId, flags, payload); + } + } + + private static final class ContentLength { + private final long expected; + private long seen; + + ContentLength(long expected) { + this.expected = expected; + } + + void increaseReceivedBytes(boolean server, int streamId, int bytes, boolean isEnd) throws Http2Exception { + seen += bytes; + // Check for overflow + if (seen < 0) { + throw streamError(streamId, PROTOCOL_ERROR, + "Received amount of data did overflow and so not match content-length header %d", expected); + } + // Check if we received 
more data then what was advertised via the content-length header. + if (seen > expected) { + throw streamError(streamId, PROTOCOL_ERROR, + "Received amount of data %d does not match content-length header %d", seen, expected); + } + + if (isEnd) { + if (seen == 0 && !server) { + // This may be a response to a HEAD request, let's just allow it. + return; + } + + // Check that we really saw what was told via the content-length header. + if (expected > seen) { + throw streamError(streamId, PROTOCOL_ERROR, + "Received amount of data %d does not match content-length header %d", seen, expected); + } + } + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java new file mode 100644 index 0000000..08993d9 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoder.java @@ -0,0 +1,634 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.CoalescingBufferQueue; +import io.netty.handler.codec.http.HttpStatusClass; +import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayDeque; +import java.util.Queue; + +import static io.netty.handler.codec.http.HttpStatusClass.INFORMATIONAL; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Integer.MAX_VALUE; +import static java.lang.Math.min; + +/** + * Default implementation of {@link Http2ConnectionEncoder}. + */ +@UnstableApi +public class DefaultHttp2ConnectionEncoder implements Http2ConnectionEncoder, Http2SettingsReceivedConsumer { + private final Http2FrameWriter frameWriter; + private final Http2Connection connection; + private Http2LifecycleManager lifecycleManager; + // We prefer ArrayDeque to LinkedList because later will produce more GC. + // This initial capacity is plenty for SETTINGS traffic. 
+ private final Queue outstandingLocalSettingsQueue = new ArrayDeque(4); + private Queue outstandingRemoteSettingsQueue; + + public DefaultHttp2ConnectionEncoder(Http2Connection connection, Http2FrameWriter frameWriter) { + this.connection = checkNotNull(connection, "connection"); + this.frameWriter = checkNotNull(frameWriter, "frameWriter"); + if (connection.remote().flowController() == null) { + connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); + } + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + this.lifecycleManager = checkNotNull(lifecycleManager, "lifecycleManager"); + } + + @Override + public Http2FrameWriter frameWriter() { + return frameWriter; + } + + @Override + public Http2Connection connection() { + return connection; + } + + @Override + public final Http2RemoteFlowController flowController() { + return connection().remote().flowController(); + } + + @Override + public void remoteSettings(Http2Settings settings) throws Http2Exception { + Boolean pushEnabled = settings.pushEnabled(); + Http2FrameWriter.Configuration config = configuration(); + Http2HeadersEncoder.Configuration outboundHeaderConfig = config.headersConfiguration(); + Http2FrameSizePolicy outboundFrameSizePolicy = config.frameSizePolicy(); + if (pushEnabled != null) { + if (!connection.isServer() && pushEnabled) { + throw connectionError(PROTOCOL_ERROR, + "Client received a value of ENABLE_PUSH specified to other than 0"); + } + connection.remote().allowPushTo(pushEnabled); + } + + Long maxConcurrentStreams = settings.maxConcurrentStreams(); + if (maxConcurrentStreams != null) { + connection.local().maxActiveStreams((int) min(maxConcurrentStreams, MAX_VALUE)); + } + + Long headerTableSize = settings.headerTableSize(); + if (headerTableSize != null) { + outboundHeaderConfig.maxHeaderTableSize(headerTableSize); + } + + Long maxHeaderListSize = settings.maxHeaderListSize(); + if (maxHeaderListSize != null) { + 
outboundHeaderConfig.maxHeaderListSize(maxHeaderListSize); + } + + Integer maxFrameSize = settings.maxFrameSize(); + if (maxFrameSize != null) { + outboundFrameSizePolicy.maxFrameSize(maxFrameSize); + } + + Integer initialWindowSize = settings.initialWindowSize(); + if (initialWindowSize != null) { + flowController().initialWindowSize(initialWindowSize); + } + } + + @Override + public ChannelFuture writeData(final ChannelHandlerContext ctx, final int streamId, ByteBuf data, int padding, + final boolean endOfStream, ChannelPromise promise) { + promise = promise.unvoid(); + final Http2Stream stream; + try { + stream = requireStream(streamId); + + // Verify that the stream is in the appropriate state for sending DATA frames. + switch (stream.state()) { + case OPEN: + case HALF_CLOSED_REMOTE: + // Allowed sending DATA frames in these states. + break; + default: + throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + stream.state()); + } + } catch (Throwable e) { + data.release(); + return promise.setFailure(e); + } + + // Hand control of the frame to the flow controller. 
+ flowController().addFlowControlled(stream, + new FlowControlledData(stream, data, padding, endOfStream, promise)); + return promise; + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endStream, ChannelPromise promise) { + return writeHeaders0(ctx, streamId, headers, false, 0, (short) 0, false, padding, endStream, promise); + } + + private static boolean validateHeadersSentState(Http2Stream stream, Http2Headers headers, boolean isServer, + boolean endOfStream) { + boolean isInformational = isServer && HttpStatusClass.valueOf(headers.status()) == INFORMATIONAL; + if ((isInformational || !endOfStream) && stream.isHeadersSent() || stream.isTrailersSent()) { + throw new IllegalStateException("Stream " + stream.id() + " sent too many headers EOS: " + endOfStream); + } + return isInformational; + } + + @Override + public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId, + final Http2Headers headers, final int streamDependency, final short weight, + final boolean exclusive, final int padding, final boolean endOfStream, ChannelPromise promise) { + return writeHeaders0(ctx, streamId, headers, true, streamDependency, + weight, exclusive, padding, endOfStream, promise); + } + + /** + * Write headers via {@link Http2FrameWriter}. If {@code hasPriority} is {@code false} it will ignore the + * {@code streamDependency}, {@code weight} and {@code exclusive} parameters. 
+ */ + private static ChannelFuture sendHeaders(Http2FrameWriter frameWriter, ChannelHandlerContext ctx, int streamId, + Http2Headers headers, final boolean hasPriority, + int streamDependency, final short weight, + boolean exclusive, final int padding, + boolean endOfStream, ChannelPromise promise) { + if (hasPriority) { + return frameWriter.writeHeaders(ctx, streamId, headers, streamDependency, + weight, exclusive, padding, endOfStream, promise); + } + return frameWriter.writeHeaders(ctx, streamId, headers, padding, endOfStream, promise); + } + + private ChannelFuture writeHeaders0(final ChannelHandlerContext ctx, final int streamId, + final Http2Headers headers, final boolean hasPriority, + final int streamDependency, final short weight, + final boolean exclusive, final int padding, + final boolean endOfStream, ChannelPromise promise) { + try { + Http2Stream stream = connection.stream(streamId); + if (stream == null) { + try { + // We don't create the stream in a `halfClosed` state because if this is an initial + // HEADERS frame we don't want the connection state to signify that the HEADERS have + // been sent until after they have been encoded and placed in the outbound buffer. + // Therefore, we let the `LifeCycleManager` will take care of transitioning the state + // as appropriate. + stream = connection.local().createStream(streamId, /*endOfStream*/ false); + } catch (Http2Exception cause) { + if (connection.remote().mayHaveCreatedStream(streamId)) { + promise.tryFailure(new IllegalStateException("Stream no longer exists: " + streamId, cause)); + return promise; + } + throw cause; + } + } else { + switch (stream.state()) { + case RESERVED_LOCAL: + stream.open(endOfStream); + break; + case OPEN: + case HALF_CLOSED_REMOTE: + // Allowed sending headers in these states. 
+ break; + default: + throw new IllegalStateException("Stream " + stream.id() + " in unexpected state " + + stream.state()); + } + } + + // Trailing headers must go through flow control if there are other frames queued in flow control + // for this stream. + Http2RemoteFlowController flowController = flowController(); + if (!endOfStream || !flowController.hasFlowControlled(stream)) { + // The behavior here should mirror that in FlowControlledHeaders + + promise = promise.unvoid(); + boolean isInformational = validateHeadersSentState(stream, headers, connection.isServer(), endOfStream); + + ChannelFuture future = sendHeaders(frameWriter, ctx, streamId, headers, hasPriority, streamDependency, + weight, exclusive, padding, endOfStream, promise); + + // Writing headers may fail during the encode state if they violate HPACK limits. + Throwable failureCause = future.cause(); + if (failureCause == null) { + // Synchronously set the headersSent flag to ensure that we do not subsequently write + // other headers containing pseudo-header fields. + // + // This just sets internal stream state which is used elsewhere in the codec and doesn't + // necessarily mean the write will complete successfully. + stream.headersSent(isInformational); + + if (!future.isSuccess()) { + // Either the future is not done or failed in the meantime. + notifyLifecycleManagerOnError(future, ctx); + } + } else { + lifecycleManager.onError(ctx, true, failureCause); + } + + if (endOfStream) { + // Must handle calling onError before calling closeStreamLocal, otherwise the error handler will + // incorrectly think the stream no longer exists and so may not send RST_STREAM or perform similar + // appropriate action. + lifecycleManager.closeStreamLocal(stream, future); + } + + return future; + } else { + // Pass headers to the flow-controller so it can maintain their sequence relative to DATA frames. 
+ flowController.addFlowControlled(stream, + new FlowControlledHeaders(stream, headers, hasPriority, streamDependency, + weight, exclusive, padding, true, promise)); + return promise; + } + } catch (Throwable t) { + lifecycleManager.onError(ctx, true, t); + promise.tryFailure(t); + return promise; + } + } + + @Override + public ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive, ChannelPromise promise) { + return frameWriter.writePriority(ctx, streamId, streamDependency, weight, exclusive, promise); + } + + @Override + public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, + ChannelPromise promise) { + // Delegate to the lifecycle manager for proper updating of connection state. + return lifecycleManager.resetStream(ctx, streamId, errorCode, promise); + } + + @Override + public ChannelFuture writeSettings(ChannelHandlerContext ctx, Http2Settings settings, + ChannelPromise promise) { + outstandingLocalSettingsQueue.add(settings); + try { + Boolean pushEnabled = settings.pushEnabled(); + if (pushEnabled != null && connection.isServer()) { + throw connectionError(PROTOCOL_ERROR, "Server sending SETTINGS frame with ENABLE_PUSH specified"); + } + } catch (Throwable e) { + return promise.setFailure(e); + } + + return frameWriter.writeSettings(ctx, settings, promise); + } + + @Override + public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) { + if (outstandingRemoteSettingsQueue == null) { + return frameWriter.writeSettingsAck(ctx, promise); + } + Http2Settings settings = outstandingRemoteSettingsQueue.poll(); + if (settings == null) { + return promise.setFailure(new Http2Exception(INTERNAL_ERROR, "attempted to write a SETTINGS ACK with no " + + " pending SETTINGS")); + } + SimpleChannelPromiseAggregator aggregator = new SimpleChannelPromiseAggregator(promise, ctx.channel(), + ctx.executor()); + // Acknowledge receipt of the 
settings. We should do this before we process the settings to ensure our + // remote peer applies these settings before any subsequent frames that we may send which depend upon + // these new settings. See https://github.com/netty/netty/issues/6520. + frameWriter.writeSettingsAck(ctx, aggregator.newPromise()); + + // We create a "new promise" to make sure that status from both the write and the application are taken into + // account independently. + ChannelPromise applySettingsPromise = aggregator.newPromise(); + try { + remoteSettings(settings); + applySettingsPromise.setSuccess(); + } catch (Throwable e) { + applySettingsPromise.setFailure(e); + lifecycleManager.onError(ctx, true, e); + } + return aggregator.doneAllocatingPromises(); + } + + @Override + public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, long data, ChannelPromise promise) { + return frameWriter.writePing(ctx, ack, data, promise); + } + + @Override + public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding, ChannelPromise promise) { + try { + if (connection.goAwayReceived()) { + throw connectionError(PROTOCOL_ERROR, "Sending PUSH_PROMISE after GO_AWAY received."); + } + + Http2Stream stream = requireStream(streamId); + // Reserve the promised stream. + connection.local().reservePushStream(promisedStreamId, stream); + + promise = promise.unvoid(); + ChannelFuture future = frameWriter.writePushPromise(ctx, streamId, promisedStreamId, headers, padding, + promise); + // Writing headers may fail during the encode state if they violate HPACK limits. + Throwable failureCause = future.cause(); + if (failureCause == null) { + // This just sets internal stream state which is used elsewhere in the codec and doesn't + // necessarily mean the write will complete successfully. + stream.pushPromiseSent(); + + if (!future.isSuccess()) { + // Either the future is not done or failed in the meantime. 
+ notifyLifecycleManagerOnError(future, ctx); + } + } else { + lifecycleManager.onError(ctx, true, failureCause); + } + return future; + } catch (Throwable t) { + lifecycleManager.onError(ctx, true, t); + promise.tryFailure(t); + return promise; + } + } + + @Override + public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData, + ChannelPromise promise) { + return lifecycleManager.goAway(ctx, lastStreamId, errorCode, debugData, promise); + } + + @Override + public ChannelFuture writeWindowUpdate(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement, + ChannelPromise promise) { + return promise.setFailure(new UnsupportedOperationException("Use the Http2[Inbound|Outbound]FlowController" + + " objects to control window sizes")); + } + + @Override + public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload, ChannelPromise promise) { + return frameWriter.writeFrame(ctx, frameType, streamId, flags, payload, promise); + } + + @Override + public void close() { + frameWriter.close(); + } + + @Override + public Http2Settings pollSentSettings() { + return outstandingLocalSettingsQueue.poll(); + } + + @Override + public Configuration configuration() { + return frameWriter.configuration(); + } + + private Http2Stream requireStream(int streamId) { + Http2Stream stream = connection.stream(streamId); + if (stream == null) { + final String message; + if (connection.streamMayHaveExisted(streamId)) { + message = "Stream no longer exists: " + streamId; + } else { + message = "Stream does not exist: " + streamId; + } + throw new IllegalArgumentException(message); + } + return stream; + } + + @Override + public void consumeReceivedSettings(Http2Settings settings) { + if (outstandingRemoteSettingsQueue == null) { + outstandingRemoteSettingsQueue = new ArrayDeque(2); + } + outstandingRemoteSettingsQueue.add(settings); + } + + /** + * Wrap a DATA frame so it 
can be written subject to flow-control. Note that this implementation assumes it + * only writes padding once for the entire payload as opposed to writing it once per-frame. This makes the + * {@link #size} calculation deterministic thereby greatly simplifying the implementation. + *

+ * If frame-splitting is required to fit within max-frame-size and flow-control constraints we ensure that + * the passed promise is not completed until last frame write. + *

+ */ + private final class FlowControlledData extends FlowControlledBase { + private final CoalescingBufferQueue queue; + private int dataSize; + + FlowControlledData(Http2Stream stream, ByteBuf buf, int padding, boolean endOfStream, + ChannelPromise promise) { + super(stream, padding, endOfStream, promise); + queue = new CoalescingBufferQueue(promise.channel()); + queue.add(buf, promise); + dataSize = queue.readableBytes(); + } + + @Override + public int size() { + return dataSize + padding; + } + + @Override + public void error(ChannelHandlerContext ctx, Throwable cause) { + queue.releaseAndFailAll(cause); + // Don't update dataSize because we need to ensure the size() method returns a consistent size even after + // error so we don't invalidate flow control when returning bytes to flow control. + // + // That said we will set dataSize and padding to 0 in the write(...) method if we cleared the queue + // because of an error. + lifecycleManager.onError(ctx, true, cause); + } + + @Override + public void write(ChannelHandlerContext ctx, int allowedBytes) { + int queuedData = queue.readableBytes(); + if (!endOfStream) { + if (queuedData == 0) { + if (queue.isEmpty()) { + // When the queue is empty it means we did clear it because of an error(...) call + // (as otherwise we will have at least 1 entry in there), which will happen either when called + // explicit or when the write itself fails. In this case just set dataSize and padding to 0 + // which will signal back that the whole frame was consumed. + // + // See https://github.com/netty/netty/issues/8707. + padding = dataSize = 0; + } else { + // There's no need to write any data frames because there are only empty data frames in the + // queue and it is not end of stream yet. Just complete their promises by getting the buffer + // corresponding to 0 bytes and writing it to the channel (to preserve notification order). 
+ ChannelPromise writePromise = ctx.newPromise().addListener(this); + ctx.write(queue.remove(0, writePromise), writePromise); + } + return; + } + + if (allowedBytes == 0) { + return; + } + } + + // Determine how much data to write. + int writableData = min(queuedData, allowedBytes); + ChannelPromise writePromise = ctx.newPromise().addListener(this); + ByteBuf toWrite = queue.remove(writableData, writePromise); + dataSize = queue.readableBytes(); + + // Determine how much padding to write. + int writablePadding = min(allowedBytes - writableData, padding); + padding -= writablePadding; + + // Write the frame(s). + frameWriter().writeData(ctx, stream.id(), toWrite, writablePadding, + endOfStream && size() == 0, writePromise); + } + + @Override + public boolean merge(ChannelHandlerContext ctx, Http2RemoteFlowController.FlowControlled next) { + FlowControlledData nextData; + if (FlowControlledData.class != next.getClass() || + MAX_VALUE - (nextData = (FlowControlledData) next).size() < size()) { + return false; + } + nextData.queue.copyTo(queue); + dataSize = queue.readableBytes(); + // Given that we're merging data into a frame it doesn't really make sense to accumulate padding. + padding = Math.max(padding, nextData.padding); + endOfStream = nextData.endOfStream; + return true; + } + } + + private void notifyLifecycleManagerOnError(ChannelFuture future, final ChannelHandlerContext ctx) { + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Throwable cause = future.cause(); + if (cause != null) { + lifecycleManager.onError(ctx, true, cause); + } + } + }); + } + + /** + * Wrap headers so they can be written subject to flow-control. While headers do not have cost against the + * flow-control window their order with respect to other frames must be maintained, hence if a DATA frame is + * blocked on flow-control a HEADER frame must wait until this frame has been written. 
+ */ + private final class FlowControlledHeaders extends FlowControlledBase { + private final Http2Headers headers; + private final boolean hasPriority; + private final int streamDependency; + private final short weight; + private final boolean exclusive; + + FlowControlledHeaders(Http2Stream stream, Http2Headers headers, boolean hasPriority, + int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + super(stream, padding, endOfStream, promise.unvoid()); + this.headers = headers; + this.hasPriority = hasPriority; + this.streamDependency = streamDependency; + this.weight = weight; + this.exclusive = exclusive; + } + + @Override + public int size() { + return 0; + } + + @Override + public void error(ChannelHandlerContext ctx, Throwable cause) { + if (ctx != null) { + lifecycleManager.onError(ctx, true, cause); + } + promise.tryFailure(cause); + } + + @Override + public void write(ChannelHandlerContext ctx, int allowedBytes) { + boolean isInformational = validateHeadersSentState(stream, headers, connection.isServer(), endOfStream); + // The code is currently requiring adding this listener before writing, in order to call onError() before + // closeStreamLocal(). + promise.addListener(this); + + ChannelFuture f = sendHeaders(frameWriter, ctx, stream.id(), headers, hasPriority, streamDependency, + weight, exclusive, padding, endOfStream, promise); + // Writing headers may fail during the encode state if they violate HPACK limits. + Throwable failureCause = f.cause(); + if (failureCause == null) { + // This just sets internal stream state which is used elsewhere in the codec and doesn't + // necessarily mean the write will complete successfully. + stream.headersSent(isInformational); + } + } + + @Override + public boolean merge(ChannelHandlerContext ctx, Http2RemoteFlowController.FlowControlled next) { + return false; + } + } + + /** + * Common base type for payloads to deliver via flow-control. 
+ */ + public abstract class FlowControlledBase implements Http2RemoteFlowController.FlowControlled, + ChannelFutureListener { + protected final Http2Stream stream; + protected ChannelPromise promise; + protected boolean endOfStream; + protected int padding; + + FlowControlledBase(final Http2Stream stream, int padding, boolean endOfStream, + final ChannelPromise promise) { + checkPositiveOrZero(padding, "padding"); + this.padding = padding; + this.endOfStream = endOfStream; + this.stream = stream; + this.promise = promise; + } + + @Override + public void writeComplete() { + if (endOfStream) { + lifecycleManager.closeStreamLocal(stream, promise); + } + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + error(flowController().channelHandlerContext(), future.cause()); + } + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2DataFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2DataFrame.java new file mode 100644 index 0000000..742f3f1 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2DataFrame.java @@ -0,0 +1,198 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2CodecUtil.verifyPadding; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * The default {@link Http2DataFrame} implementation. + */ +@UnstableApi +public final class DefaultHttp2DataFrame extends AbstractHttp2StreamFrame implements Http2DataFrame { + private final ByteBuf content; + private final boolean endStream; + private final int padding; + private final int initialFlowControlledBytes; + + /** + * Equivalent to {@code new DefaultHttp2DataFrame(content, false)}. + * + * @param content non-{@code null} payload + */ + public DefaultHttp2DataFrame(ByteBuf content) { + this(content, false); + } + + /** + * Equivalent to {@code new DefaultHttp2DataFrame(Unpooled.EMPTY_BUFFER, endStream)}. + * + * @param endStream whether this data should terminate the stream + */ + public DefaultHttp2DataFrame(boolean endStream) { + this(Unpooled.EMPTY_BUFFER, endStream); + } + + /** + * Equivalent to {@code new DefaultHttp2DataFrame(content, endStream, 0)}. + * + * @param content non-{@code null} payload + * @param endStream whether this data should terminate the stream + */ + public DefaultHttp2DataFrame(ByteBuf content, boolean endStream) { + this(content, endStream, 0); + } + + /** + * Construct a new data message. + * + * @param content non-{@code null} payload + * @param endStream whether this data should terminate the stream + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). 
+ */ + public DefaultHttp2DataFrame(ByteBuf content, boolean endStream, int padding) { + this.content = checkNotNull(content, "content"); + this.endStream = endStream; + verifyPadding(padding); + this.padding = padding; + if (content().readableBytes() + (long) padding > Integer.MAX_VALUE) { + throw new IllegalArgumentException("content + padding must be <= Integer.MAX_VALUE"); + } + initialFlowControlledBytes = content().readableBytes() + padding; + } + + @Override + public DefaultHttp2DataFrame stream(Http2FrameStream stream) { + super.stream(stream); + return this; + } + + @Override + public String name() { + return "DATA"; + } + + @Override + public boolean isEndStream() { + return endStream; + } + + @Override + public int padding() { + return padding; + } + + @Override + public ByteBuf content() { + return ByteBufUtil.ensureAccessible(content); + } + + @Override + public int initialFlowControlledBytes() { + return initialFlowControlledBytes; + } + + @Override + public DefaultHttp2DataFrame copy() { + return replace(content().copy()); + } + + @Override + public DefaultHttp2DataFrame duplicate() { + return replace(content().duplicate()); + } + + @Override + public DefaultHttp2DataFrame retainedDuplicate() { + return replace(content().retainedDuplicate()); + } + + @Override + public DefaultHttp2DataFrame replace(ByteBuf content) { + return new DefaultHttp2DataFrame(content, endStream, padding); + } + + @Override + public int refCnt() { + return content.refCnt(); + } + + @Override + public boolean release() { + return content.release(); + } + + @Override + public boolean release(int decrement) { + return content.release(decrement); + } + + @Override + public DefaultHttp2DataFrame retain() { + content.retain(); + return this; + } + + @Override + public DefaultHttp2DataFrame retain(int increment) { + content.retain(increment); + return this; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(stream=" + stream() + ", content=" 
+ content + + ", endStream=" + endStream + ", padding=" + padding + ')'; + } + + @Override + public DefaultHttp2DataFrame touch() { + content.touch(); + return this; + } + + @Override + public DefaultHttp2DataFrame touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttp2DataFrame)) { + return false; + } + DefaultHttp2DataFrame other = (DefaultHttp2DataFrame) o; + return super.equals(other) && content.equals(other.content()) + && endStream == other.endStream && padding == other.padding; + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = hash * 31 + content.hashCode(); + hash = hash * 31 + (endStream ? 0 : 1); + hash = hash * 31 + padding; + return hash; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java new file mode 100644 index 0000000..a6d6afb --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameReader.java @@ -0,0 +1,775 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2FrameReader.Configuration; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PING_FRAME_PAYLOAD_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded; +import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid; +import static io.netty.handler.codec.http2.Http2CodecUtil.readUnsignedInt; +import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION; +import static io.netty.handler.codec.http2.Http2FrameTypes.DATA; +import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY; +import static io.netty.handler.codec.http2.Http2FrameTypes.HEADERS; +import static io.netty.handler.codec.http2.Http2FrameTypes.PING; +import static io.netty.handler.codec.http2.Http2FrameTypes.PRIORITY; +import static 
io.netty.handler.codec.http2.Http2FrameTypes.PUSH_PROMISE; +import static io.netty.handler.codec.http2.Http2FrameTypes.RST_STREAM; +import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; +import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE; + +/** + * A {@link Http2FrameReader} that supports all frame types defined by the HTTP/2 specification. + */ +@UnstableApi +public class DefaultHttp2FrameReader implements Http2FrameReader, Http2FrameSizePolicy, Configuration { + private final Http2HeadersDecoder headersDecoder; + + /** + * {@code true} = reading headers, {@code false} = reading payload. + */ + private boolean readingHeaders = true; + /** + * Once set to {@code true} the value will never change. This is set to {@code true} if an unrecoverable error which + * renders the connection unusable. + */ + private boolean readError; + private byte frameType; + private int streamId; + private Http2Flags flags; + private int payloadLength; + private HeadersContinuation headersContinuation; + private int maxFrameSize; + + /** + * Create a new instance. + *

+ * Header names will be validated. + */ + public DefaultHttp2FrameReader() { + this(true); + } + + /** + * Create a new instance. + * @param validateHeaders {@code true} to validate headers. {@code false} to not validate headers. + */ + public DefaultHttp2FrameReader(boolean validateHeaders) { + this(new DefaultHttp2HeadersDecoder(validateHeaders)); + } + + public DefaultHttp2FrameReader(Http2HeadersDecoder headersDecoder) { + this.headersDecoder = headersDecoder; + maxFrameSize = DEFAULT_MAX_FRAME_SIZE; + } + + @Override + public Http2HeadersDecoder.Configuration headersConfiguration() { + return headersDecoder.configuration(); + } + + @Override + public Configuration configuration() { + return this; + } + + @Override + public Http2FrameSizePolicy frameSizePolicy() { + return this; + } + + @Override + public void maxFrameSize(int max) throws Http2Exception { + if (!isMaxFrameSizeValid(max)) { + throw streamError(streamId, FRAME_SIZE_ERROR, + "Invalid MAX_FRAME_SIZE specified in sent settings: %d", max); + } + maxFrameSize = max; + } + + @Override + public int maxFrameSize() { + return maxFrameSize; + } + + @Override + public void close() { + closeHeadersContinuation(); + } + + private void closeHeadersContinuation() { + if (headersContinuation != null) { + headersContinuation.close(); + headersContinuation = null; + } + } + + @Override + public void readFrame(ChannelHandlerContext ctx, ByteBuf input, Http2FrameListener listener) + throws Http2Exception { + if (readError) { + input.skipBytes(input.readableBytes()); + return; + } + try { + do { + if (readingHeaders) { + processHeaderState(input); + if (readingHeaders) { + // Wait until the entire header has arrived. + return; + } + } + + // The header is complete, fall into the next case to process the payload. + // This is to ensure the proper handling of zero-length payloads. In this + // case, we don't want to loop around because there may be no more data + // available, causing us to exit the loop. 
Instead, we just want to perform + // the first pass at payload processing now. + processPayloadState(ctx, input, listener); + if (!readingHeaders) { + // Wait until the entire payload has arrived. + return; + } + } while (input.isReadable()); + } catch (Http2Exception e) { + readError = !Http2Exception.isStreamError(e); + throw e; + } catch (RuntimeException e) { + readError = true; + throw e; + } catch (Throwable cause) { + readError = true; + PlatformDependent.throwException(cause); + } + } + + private void processHeaderState(ByteBuf in) throws Http2Exception { + if (in.readableBytes() < FRAME_HEADER_LENGTH) { + // Wait until the entire frame header has been read. + return; + } + + // Read the header and prepare the unmarshaller to read the frame. + payloadLength = in.readUnsignedMedium(); + if (payloadLength > maxFrameSize) { + throw connectionError(FRAME_SIZE_ERROR, "Frame length: %d exceeds maximum: %d", payloadLength, + maxFrameSize); + } + frameType = in.readByte(); + flags = new Http2Flags(in.readUnsignedByte()); + streamId = readUnsignedInt(in); + + // We have consumed the data, next time we read we will be expecting to read the frame payload. + readingHeaders = false; + + switch (frameType) { + case DATA: + verifyDataFrame(); + break; + case HEADERS: + verifyHeadersFrame(); + break; + case PRIORITY: + verifyPriorityFrame(); + break; + case RST_STREAM: + verifyRstStreamFrame(); + break; + case SETTINGS: + verifySettingsFrame(); + break; + case PUSH_PROMISE: + verifyPushPromiseFrame(); + break; + case PING: + verifyPingFrame(); + break; + case GO_AWAY: + verifyGoAwayFrame(); + break; + case WINDOW_UPDATE: + verifyWindowUpdateFrame(); + break; + case CONTINUATION: + verifyContinuationFrame(); + break; + default: + // Unknown frame type, could be an extension. 
+ verifyUnknownFrame(); + break; + } + } + + private void processPayloadState(ChannelHandlerContext ctx, ByteBuf in, Http2FrameListener listener) + throws Http2Exception { + if (in.readableBytes() < payloadLength) { + // Wait until the entire payload has been read. + return; + } + + // Only process up to payloadLength bytes. + int payloadEndIndex = in.readerIndex() + payloadLength; + + // We have consumed the data, next time we read we will be expecting to read a frame header. + readingHeaders = true; + + // Read the payload and fire the frame event to the listener. + switch (frameType) { + case DATA: + readDataFrame(ctx, in, payloadEndIndex, listener); + break; + case HEADERS: + readHeadersFrame(ctx, in, payloadEndIndex, listener); + break; + case PRIORITY: + readPriorityFrame(ctx, in, listener); + break; + case RST_STREAM: + readRstStreamFrame(ctx, in, listener); + break; + case SETTINGS: + readSettingsFrame(ctx, in, listener); + break; + case PUSH_PROMISE: + readPushPromiseFrame(ctx, in, payloadEndIndex, listener); + break; + case PING: + readPingFrame(ctx, in.readLong(), listener); + break; + case GO_AWAY: + readGoAwayFrame(ctx, in, payloadEndIndex, listener); + break; + case WINDOW_UPDATE: + readWindowUpdateFrame(ctx, in, listener); + break; + case CONTINUATION: + readContinuationFrame(in, payloadEndIndex, listener); + break; + default: + readUnknownFrame(ctx, in, payloadEndIndex, listener); + break; + } + in.readerIndex(payloadEndIndex); + } + + private void verifyDataFrame() throws Http2Exception { + verifyAssociatedWithAStream(); + verifyNotProcessingHeaders(); + + if (payloadLength < flags.getPaddingPresenceFieldLength()) { + throw streamError(streamId, FRAME_SIZE_ERROR, + "Frame length %d too small.", payloadLength); + } + } + + private void verifyHeadersFrame() throws Http2Exception { + verifyAssociatedWithAStream(); + verifyNotProcessingHeaders(); + + int requiredLength = flags.getPaddingPresenceFieldLength() + flags.getNumPriorityBytes(); + if 
(payloadLength < requiredLength) {
            // Use the %d format argument like the sibling verify* methods; the original
            // concatenated payloadLength onto the format string ("Frame length too small.16389").
            throw streamError(streamId, FRAME_SIZE_ERROR,
                    "Frame length %d too small.", payloadLength);
        }
    }

    // A PRIORITY frame payload is exactly one priority entry (5 octets).
    private void verifyPriorityFrame() throws Http2Exception {
        verifyAssociatedWithAStream();
        verifyNotProcessingHeaders();

        if (payloadLength != PRIORITY_ENTRY_LENGTH) {
            throw streamError(streamId, FRAME_SIZE_ERROR,
                    "Invalid frame length %d.", payloadLength);
        }
    }

    // A RST_STREAM payload is exactly the 4-octet error code.
    private void verifyRstStreamFrame() throws Http2Exception {
        verifyAssociatedWithAStream();
        verifyNotProcessingHeaders();

        if (payloadLength != INT_FIELD_LENGTH) {
            throw connectionError(FRAME_SIZE_ERROR, "Invalid frame length %d.", payloadLength);
        }
    }

    // SETTINGS must be on stream 0; an ACK carries no payload; otherwise the payload is
    // a whole number of settings entries.
    private void verifySettingsFrame() throws Http2Exception {
        verifyNotProcessingHeaders();
        if (streamId != 0) {
            throw connectionError(PROTOCOL_ERROR, "A stream ID must be zero.");
        }
        if (flags.ack() && payloadLength > 0) {
            throw connectionError(FRAME_SIZE_ERROR, "Ack settings frame must have an empty payload.");
        }
        if (payloadLength % SETTING_ENTRY_LENGTH > 0) {
            throw connectionError(FRAME_SIZE_ERROR, "Frame length %d invalid.", payloadLength);
        }
    }

    private void verifyPushPromiseFrame() throws Http2Exception {
        verifyNotProcessingHeaders();

        // Subtract the length of the promised stream ID field, to determine the length of the
        // rest of the payload (header block fragment + payload).
        // NOTE(review): this is the tail of a verify* method whose signature precedes this chunk —
        // presumably the PUSH_PROMISE length check (pad-length field + 4-byte promised-stream id); confirm upstream.
        int minLength = flags.getPaddingPresenceFieldLength() + INT_FIELD_LENGTH;
        if (payloadLength < minLength) {
            throw streamError(streamId, FRAME_SIZE_ERROR,
                    "Frame length %d too small.", payloadLength);
        }
    }

    /**
     * Validates the current header as a PING frame: must not interrupt an in-progress header
     * block, must target the connection (stream 0), and must carry exactly the fixed 8-byte
     * opaque payload ({@code PING_FRAME_PAYLOAD_LENGTH}).
     *
     * @throws Http2Exception a connection error on any violation.
     */
    private void verifyPingFrame() throws Http2Exception {
        verifyNotProcessingHeaders();
        if (streamId != 0) {
            throw connectionError(PROTOCOL_ERROR, "A stream ID must be zero.");
        }
        if (payloadLength != PING_FRAME_PAYLOAD_LENGTH) {
            throw connectionError(FRAME_SIZE_ERROR,
                    "Frame length %d incorrect size for ping.", payloadLength);
        }
    }

    /**
     * Validates the current header as a GO_AWAY frame: connection-level only (stream 0) and at
     * least 8 bytes (4-byte last-stream-id + 4-byte error code); any extra bytes are debug data.
     *
     * @throws Http2Exception a connection error on any violation.
     */
    private void verifyGoAwayFrame() throws Http2Exception {
        verifyNotProcessingHeaders();

        if (streamId != 0) {
            throw connectionError(PROTOCOL_ERROR, "A stream ID must be zero.");
        }
        if (payloadLength < 8) {
            throw connectionError(FRAME_SIZE_ERROR, "Frame length %d too small.", payloadLength);
        }
    }

    /**
     * Validates the current header as a WINDOW_UPDATE frame. Unlike PING/GO_AWAY, stream 0 is
     * allowed here (connection-level flow control), so only a non-negative stream ID and the
     * exact 4-byte payload are enforced.
     *
     * @throws Http2Exception a connection error on any violation.
     */
    private void verifyWindowUpdateFrame() throws Http2Exception {
        verifyNotProcessingHeaders();
        verifyStreamOrConnectionId(streamId, "Stream ID");

        if (payloadLength != INT_FIELD_LENGTH) {
            throw connectionError(FRAME_SIZE_ERROR, "Invalid frame length %d.", payloadLength);
        }
    }

    /**
     * Validates the current header as a CONTINUATION frame: it must be associated with a stream,
     * a header block must currently be in progress, and its stream ID must match the stream of
     * the pending HEADERS/PUSH_PROMISE.
     *
     * @throws Http2Exception a connection error (or stream error for the size check) on violation.
     */
    private void verifyContinuationFrame() throws Http2Exception {
        verifyAssociatedWithAStream();

        if (headersContinuation == null) {
            throw connectionError(PROTOCOL_ERROR, "Received %s frame but not currently processing headers.",
                    frameType);
        }

        if (streamId != headersContinuation.getStreamId()) {
            throw connectionError(PROTOCOL_ERROR, "Continuation stream ID does not match pending headers. " +
                    "Expected %d, but received %d.", headersContinuation.getStreamId(), streamId);
        }

        if (payloadLength < flags.getPaddingPresenceFieldLength()) {
            throw streamError(streamId, FRAME_SIZE_ERROR,
                    "Frame length %d too small for padding.", payloadLength);
        }
    }

    /**
     * Unknown frame types are tolerated (and later surfaced via {@code onUnknownFrame}) but, like
     * every other frame, may not interleave with an in-progress header block.
     */
    private void verifyUnknownFrame() throws Http2Exception {
        verifyNotProcessingHeaders();
    }

    /**
     * Reads a DATA frame payload and delivers it to the listener. The delivered {@link ByteBuf}
     * is a slice of {@code payload} (no copy), excluding any trailing padding.
     *
     * @param payloadEndIndex index one past the last byte of this frame's payload within {@code payload}.
     */
    private void readDataFrame(ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex,
            Http2FrameListener listener) throws Http2Exception {
        int padding = readPadding(payload);
        verifyPadding(padding);

        // Determine how much data there is to read by removing the trailing
        // padding.
        int dataLength = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding);

        ByteBuf data = payload.readSlice(dataLength);
        listener.onDataRead(ctx, streamId, data, padding, flags.endOfStream());
    }

    /**
     * Reads a HEADERS frame and installs a {@link HeadersContinuation} that accumulates header
     * block fragments (this frame plus any CONTINUATION frames) until END_HEADERS, at which point
     * the listener is notified. Two continuation variants exist because the listener callback
     * signature differs depending on whether the optional priority section is present.
     * Frame-header fields ({@code streamId}, {@code flags}) are captured into locals because the
     * continuation outlives this frame's processing.
     */
    private void readHeadersFrame(final ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex,
            Http2FrameListener listener) throws Http2Exception {
        final int headersStreamId = streamId;
        final Http2Flags headersFlags = flags;
        final int padding = readPadding(payload);
        verifyPadding(padding);

        // The callback that is invoked is different depending on whether priority information
        // is present in the headers frame.
        if (flags.priorityPresent()) {
            long word1 = payload.readUnsignedInt();
            // Top bit of the 4-byte dependency word is the "exclusive" flag; lower 31 bits are the
            // stream this one depends on.
            final boolean exclusive = (word1 & 0x80000000L) != 0;
            final int streamDependency = (int) (word1 & 0x7FFFFFFFL);
            if (streamDependency == streamId) {
                throw streamError(streamId, PROTOCOL_ERROR, "A stream cannot depend on itself.");
            }
            // Wire weight is 0..255 meaning 1..256.
            final short weight = (short) (payload.readUnsignedByte() + 1);
            final int lenToRead = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding);

            // Create a handler that invokes the listener when the header block is complete.
            headersContinuation = new HeadersContinuation() {
                @Override
                public int getStreamId() {
                    return headersStreamId;
                }

                @Override
                public void processFragment(boolean endOfHeaders, ByteBuf fragment, int len,
                        Http2FrameListener listener) throws Http2Exception {
                    final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder();
                    hdrBlockBuilder.addFragment(fragment, len, ctx.alloc(), endOfHeaders);
                    if (endOfHeaders) {
                        listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), streamDependency,
                                weight, exclusive, padding, headersFlags.endOfStream());
                    }
                }
            };

            // Process the initial fragment, invoking the listener's callback if end of headers.
            headersContinuation.processFragment(flags.endOfHeaders(), payload, lenToRead, listener);
            resetHeadersContinuationIfEnd(flags.endOfHeaders());
            return;
        }

        // The priority fields are not present in the frame. Prepare a continuation that invokes
        // the listener callback without priority information.
        headersContinuation = new HeadersContinuation() {
            @Override
            public int getStreamId() {
                return headersStreamId;
            }

            @Override
            public void processFragment(boolean endOfHeaders, ByteBuf fragment, int len,
                    Http2FrameListener listener) throws Http2Exception {
                final HeadersBlockBuilder hdrBlockBuilder = headersBlockBuilder();
                hdrBlockBuilder.addFragment(fragment, len, ctx.alloc(), endOfHeaders);
                if (endOfHeaders) {
                    listener.onHeadersRead(ctx, headersStreamId, hdrBlockBuilder.headers(), padding,
                            headersFlags.endOfStream());
                }
            }
        };

        // Process the initial fragment, invoking the listener's callback if end of headers.
        int len = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding);
        headersContinuation.processFragment(flags.endOfHeaders(), payload, len, listener);
        resetHeadersContinuationIfEnd(flags.endOfHeaders());
    }

    /**
     * Clears the pending continuation (releasing its buffer) once END_HEADERS has been seen.
     */
    private void resetHeadersContinuationIfEnd(boolean endOfHeaders) {
        if (endOfHeaders) {
            closeHeadersContinuation();
        }
    }

    /**
     * Reads a PRIORITY frame (4-byte exclusive/dependency word + 1-byte weight) and notifies the
     * listener. A self-dependency is a stream error per the HTTP/2 spec.
     */
    private void readPriorityFrame(ChannelHandlerContext ctx, ByteBuf payload,
            Http2FrameListener listener) throws Http2Exception {
        long word1 = payload.readUnsignedInt();
        boolean exclusive = (word1 & 0x80000000L) != 0;
        int streamDependency = (int) (word1 & 0x7FFFFFFFL);
        if (streamDependency == streamId) {
            throw streamError(streamId, PROTOCOL_ERROR, "A stream cannot depend on itself.");
        }
        // Wire weight is 0..255 meaning 1..256.
        short weight = (short) (payload.readUnsignedByte() + 1);
        listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive);
    }

    /**
     * Reads a RST_STREAM frame (single 4-byte error code) and notifies the listener.
     */
    private void readRstStreamFrame(ChannelHandlerContext ctx, ByteBuf payload,
            Http2FrameListener listener) throws Http2Exception {
        long errorCode = payload.readUnsignedInt();
        listener.onRstStreamRead(ctx, streamId, errorCode);
    }

    /**
     * Reads a SETTINGS frame. An ACK carries no payload; otherwise the payload is a sequence of
     * 6-byte (id, value) entries. Invalid values are translated to connection errors —
     * FLOW_CONTROL_ERROR for a bad initial window size, PROTOCOL_ERROR otherwise — with the
     * triggering {@link IllegalArgumentException} preserved as the cause.
     */
    private void readSettingsFrame(ChannelHandlerContext ctx, ByteBuf payload,
            Http2FrameListener listener) throws Http2Exception {
        if (flags.ack()) {
            listener.onSettingsAckRead(ctx);
        } else {
            int numSettings = payloadLength / SETTING_ENTRY_LENGTH;
            Http2Settings settings = new Http2Settings();
            for (int index = 0; index < numSettings; ++index) {
                char id = (char) payload.readUnsignedShort();
                long value = payload.readUnsignedInt();
                try {
                    settings.put(id, Long.valueOf(value));
                } catch (IllegalArgumentException e) {
                    if (id == SETTINGS_INITIAL_WINDOW_SIZE) {
                        throw connectionError(FLOW_CONTROL_ERROR, e,
                                "Failed setting initial window size: %s", e.getMessage());
                    }
                    throw connectionError(PROTOCOL_ERROR, e, "Protocol error: %s", e.getMessage());
                }
            }
            listener.onSettingsRead(ctx, settings);
        }
    }

    /**
     * Reads a PUSH_PROMISE frame and installs a {@link HeadersContinuation} that accumulates the
     * promised stream's header block (this frame plus any CONTINUATION frames) until END_HEADERS,
     * then notifies the listener. The stream ID is captured into a local because the continuation
     * outlives this frame's processing.
     */
    private void readPushPromiseFrame(final ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex,
            Http2FrameListener listener) throws Http2Exception {
        final int pushPromiseStreamId = streamId;
        final int padding = readPadding(payload);
        verifyPadding(padding);
        final int promisedStreamId = readUnsignedInt(payload);

        // Create a handler that invokes the listener when the header block is complete.
        headersContinuation = new HeadersContinuation() {
            @Override
            public int getStreamId() {
                return pushPromiseStreamId;
            }

            @Override
            public void processFragment(boolean endOfHeaders, ByteBuf fragment, int len,
                    Http2FrameListener listener) throws Http2Exception {
                headersBlockBuilder().addFragment(fragment, len, ctx.alloc(), endOfHeaders);
                if (endOfHeaders) {
                    listener.onPushPromiseRead(ctx, pushPromiseStreamId, promisedStreamId,
                            headersBlockBuilder().headers(), padding);
                }
            }
        };

        // Process the initial fragment, invoking the listener's callback if end of headers.
        int len = lengthWithoutTrailingPadding(payloadEndIndex - payload.readerIndex(), padding);
        headersContinuation.processFragment(flags.endOfHeaders(), payload, len, listener);
        resetHeadersContinuationIfEnd(flags.endOfHeaders());
    }

    /**
     * Dispatches a PING frame to the ACK or non-ACK listener callback based on the ACK flag.
     *
     * @param data the 8-byte opaque ping payload, already read from the frame.
     */
    private void readPingFrame(ChannelHandlerContext ctx, long data,
            Http2FrameListener listener) throws Http2Exception {
        if (flags.ack()) {
            listener.onPingAckRead(ctx, data);
        } else {
            listener.onPingRead(ctx, data);
        }
    }

    /**
     * Reads a GO_AWAY frame: last-stream-id, error code, and the remainder of the payload as
     * opaque debug data. The debug data is a slice of {@code payload} (no copy). Static because
     * it touches no frame-header state.
     */
    private static void readGoAwayFrame(ChannelHandlerContext ctx, ByteBuf payload, int payloadEndIndex,
            Http2FrameListener listener) throws Http2Exception {
        int lastStreamId = readUnsignedInt(payload);
        long errorCode = payload.readUnsignedInt();
        ByteBuf debugData = payload.readSlice(payloadEndIndex - payload.readerIndex());
        listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData);
    }

    /**
     * Reads a WINDOW_UPDATE frame. A zero delta is a stream error per the HTTP/2 spec.
     */
    private void readWindowUpdateFrame(ChannelHandlerContext ctx, ByteBuf payload,
            Http2FrameListener listener) throws Http2Exception {
        int windowSizeIncrement = readUnsignedInt(payload);
        if (windowSizeIncrement == 0) {
            throw streamError(streamId, PROTOCOL_ERROR,
                    "Received WINDOW_UPDATE with delta 0 for stream: %d", streamId);
        }
        listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement);
    }

    /**
     * Feeds a CONTINUATION frame's payload into the pending header-block continuation.
     * {@code headersContinuation} is non-null here — verifyContinuationFrame() ran first.
     */
    private void readContinuationFrame(ByteBuf payload, int payloadEndIndex, Http2FrameListener listener)
            throws Http2Exception {
        // Process the initial fragment, invoking the listener's callback if end of headers.
        headersContinuation.processFragment(flags.endOfHeaders(), payload,
                payloadEndIndex - payload.readerIndex(), listener);
        resetHeadersContinuationIfEnd(flags.endOfHeaders());
    }

    /**
     * Delivers an unrecognized frame type's payload (as a slice, no copy) to the listener.
     */
    private void readUnknownFrame(ChannelHandlerContext ctx, ByteBuf payload,
            int payloadEndIndex, Http2FrameListener listener) throws Http2Exception {
        payload = payload.readSlice(payloadEndIndex - payload.readerIndex());
        listener.onUnknownFrame(ctx, frameType, streamId, flags, payload);
    }

    /**
     * If padding is present in the payload, reads the next byte as padding. The padding also includes the one byte
     * width of the pad length field. Otherwise, returns zero.
     */
    private int readPadding(ByteBuf payload) {
        if (!flags.paddingPresent()) {
            return 0;
        }
        return payload.readUnsignedByte() + 1;
    }

    /**
     * Rejects frames whose declared padding exceeds the entire payload.
     *
     * @throws Http2Exception a connection error when the padding does not fit.
     */
    private void verifyPadding(int padding) throws Http2Exception {
        int len = lengthWithoutTrailingPadding(payloadLength, padding);
        if (len < 0) {
            throw connectionError(PROTOCOL_ERROR, "Frame payload too small for padding.");
        }
    }

    /**
     * The padding parameter consists of the 1 byte pad length field and the trailing padding bytes. This method
     * returns the number of readable bytes without the trailing padding.
     */
    private static int lengthWithoutTrailingPadding(int readableBytes, int padding) {
        // padding - 1 because the pad-length byte itself has already been consumed by readPadding().
        return padding == 0
                ? readableBytes
                : readableBytes - (padding - 1);
    }

    /**
     * Base class for processing of HEADERS and PUSH_PROMISE header blocks that potentially span
     * multiple frames. The implementation of this interface will perform the final callback to the
     * {@link Http2FrameListener} once the end of headers is reached.
     */
    private abstract class HeadersContinuation {
        // Accumulates raw header-block bytes across frames until END_HEADERS.
        private final HeadersBlockBuilder builder = new HeadersBlockBuilder();

        /**
         * Returns the stream for which headers are currently being processed.
         */
        abstract int getStreamId();

        /**
         * Processes the next fragment for the current header block.
         *
         * @param endOfHeaders whether the fragment is the last in the header block.
         * @param fragment the fragment of the header block to be added.
         * @param len the number of bytes to consume from {@code fragment}.
         * @param listener the listener to be notified if the header block is completed.
         */
        abstract void processFragment(boolean endOfHeaders, ByteBuf fragment, int len,
                Http2FrameListener listener) throws Http2Exception;

        final HeadersBlockBuilder headersBlockBuilder() {
            return builder;
        }

        /**
         * Free any allocated resources.
         */
        final void close() {
            builder.close();
        }
    }

    /**
     * Utility class to help with construction of the headers block that may potentially span
     * multiple frames.
     */
    protected class HeadersBlockBuilder {
        // Accumulated header-block bytes; null until the first fragment arrives and after close().
        private ByteBuf headerBlock;

        /**
         * The local header size maximum has been exceeded while accumulating bytes.
         * @throws Http2Exception A connection error indicating too much data has been received.
         */
        private void headerSizeExceeded() throws Http2Exception {
            close();
            headerListSizeExceeded(headersDecoder.configuration().maxHeaderListSizeGoAway());
        }

        /**
         * Adds a fragment to the block.
         *
         * @param fragment the fragment of the headers block to be added.
         * @param len the number of bytes to consume from {@code fragment}.
         * @param alloc allocator for new blocks if needed.
         * @param endOfHeaders flag indicating whether the current frame is the end of the headers.
         *            This is used for an optimization for when the first fragment is the full
         *            block. In that case, the buffer is used directly without copying.
         * @throws Http2Exception if the accumulated size exceeds the local header-list limit.
         */
        final void addFragment(ByteBuf fragment, int len, ByteBufAllocator alloc,
                boolean endOfHeaders) throws Http2Exception {
            if (headerBlock == null) {
                if (len > headersDecoder.configuration().maxHeaderListSizeGoAway()) {
                    headerSizeExceeded();
                }
                if (endOfHeaders) {
                    // Optimization - don't bother copying, just use the buffer as-is. Need
                    // to retain since we release when the header block is built.
                    headerBlock = fragment.readRetainedSlice(len);
                } else {
                    headerBlock = alloc.buffer(len).writeBytes(fragment, len);
                }
                return;
            }
            // Subtraction on the limit side avoids int overflow that "accumulated + len > max" could hit.
            if (headersDecoder.configuration().maxHeaderListSizeGoAway() - len <
                    headerBlock.readableBytes()) {
                headerSizeExceeded();
            }
            if (headerBlock.isWritable(len)) {
                // The buffer can hold the requested bytes, just write it directly.
                headerBlock.writeBytes(fragment, len);
            } else {
                // Allocate a new buffer that is big enough to hold the entire header block so far.
                ByteBuf buf = alloc.buffer(headerBlock.readableBytes() + len);
                buf.writeBytes(headerBlock).writeBytes(fragment, len);
                headerBlock.release();
                headerBlock = buf;
            }
        }

        /**
         * Builds the headers from the completed headers block. After this is called, this builder
         * should not be called again.
         */
        Http2Headers headers() throws Http2Exception {
            try {
                return headersDecoder.decodeHeaders(streamId, headerBlock);
            } finally {
                // Always release the accumulated block, even if decoding throws.
                close();
            }
        }

        /**
         * Closes this builder and frees any resources.
         */
        void close() {
            if (headerBlock != null) {
                headerBlock.release();
                headerBlock = null;
            }

            // Clear the member variable pointing at this instance.
            headersContinuation = null;
        }
    }

    /**
     * Verify that current state is not processing on header block
     * @throws Http2Exception thrown if {@link #headersContinuation} is not null
     */
    private void verifyNotProcessingHeaders() throws Http2Exception {
        if (headersContinuation != null) {
            throw connectionError(PROTOCOL_ERROR, "Received frame of type %s while processing headers on stream %d.",
                    frameType, headersContinuation.getStreamId());
        }
    }

    /**
     * Rejects frames that must carry a non-zero stream ID but arrived on the connection stream.
     */
    private void verifyAssociatedWithAStream() throws Http2Exception {
        if (streamId == 0) {
            throw connectionError(PROTOCOL_ERROR, "Frame of type %s must be associated with a stream.", frameType);
        }
    }

    /**
     * Validates an ID that may legally be either a stream ID or 0 (the connection): it only has
     * to be non-negative.
     */
    private static void verifyStreamOrConnectionId(int streamId, String argumentName)
            throws Http2Exception {
        if (streamId < 0) {
            throw connectionError(PROTOCOL_ERROR, "%s must be >= 0", argumentName);
        }
    }
}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java
new file mode 100644
index 0000000..9b60892
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriter.java
@@ -0,0 +1,627 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; +import io.netty.handler.codec.http2.Http2FrameWriter.Configuration; +import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; + +import static io.netty.buffer.Unpooled.directBuffer; +import static io.netty.buffer.Unpooled.unreleasableBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.CONTINUATION_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.DATA_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.GO_AWAY_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.HEADERS_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.INT_FIELD_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_BYTE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.PING_FRAME_PAYLOAD_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_ENTRY_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PRIORITY_FRAME_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.PUSH_PROMISE_FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.RST_STREAM_FRAME_LENGTH; +import static 
io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.WINDOW_UPDATE_FRAME_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid; +import static io.netty.handler.codec.http2.Http2CodecUtil.verifyPadding; +import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeaderInternal; +import static io.netty.handler.codec.http2.Http2Error.FRAME_SIZE_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2FrameTypes.CONTINUATION; +import static io.netty.handler.codec.http2.Http2FrameTypes.DATA; +import static io.netty.handler.codec.http2.Http2FrameTypes.GO_AWAY; +import static io.netty.handler.codec.http2.Http2FrameTypes.HEADERS; +import static io.netty.handler.codec.http2.Http2FrameTypes.PING; +import static io.netty.handler.codec.http2.Http2FrameTypes.PRIORITY; +import static io.netty.handler.codec.http2.Http2FrameTypes.PUSH_PROMISE; +import static io.netty.handler.codec.http2.Http2FrameTypes.RST_STREAM; +import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; +import static io.netty.handler.codec.http2.Http2FrameTypes.WINDOW_UPDATE; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Math.max; +import static java.lang.Math.min; + +/** + * A {@link Http2FrameWriter} that supports all frame types defined by the HTTP/2 specification. + */ +@UnstableApi +public class DefaultHttp2FrameWriter implements Http2FrameWriter, Http2FrameSizePolicy, Configuration { + private static final String STREAM_ID = "Stream ID"; + private static final String STREAM_DEPENDENCY = "Stream Dependency"; + /** + * This buffer is allocated to the maximum size of the padding field, and filled with zeros. 
+ * When padding is needed it can be taken as a slice of this buffer. Users should call {@link ByteBuf#retain()} + * before using their slice. + */ + private static final ByteBuf ZERO_BUFFER = + unreleasableBuffer(directBuffer(MAX_UNSIGNED_BYTE).writeZero(MAX_UNSIGNED_BYTE)).asReadOnly(); + + private final Http2HeadersEncoder headersEncoder; + private int maxFrameSize; + + public DefaultHttp2FrameWriter() { + this(new DefaultHttp2HeadersEncoder()); + } + + public DefaultHttp2FrameWriter(SensitivityDetector headersSensitivityDetector) { + this(new DefaultHttp2HeadersEncoder(headersSensitivityDetector)); + } + + public DefaultHttp2FrameWriter(SensitivityDetector headersSensitivityDetector, boolean ignoreMaxHeaderListSize) { + this(new DefaultHttp2HeadersEncoder(headersSensitivityDetector, ignoreMaxHeaderListSize)); + } + + public DefaultHttp2FrameWriter(Http2HeadersEncoder headersEncoder) { + this.headersEncoder = headersEncoder; + maxFrameSize = DEFAULT_MAX_FRAME_SIZE; + } + + @Override + public Configuration configuration() { + return this; + } + + @Override + public Http2HeadersEncoder.Configuration headersConfiguration() { + return headersEncoder.configuration(); + } + + @Override + public Http2FrameSizePolicy frameSizePolicy() { + return this; + } + + @Override + public void maxFrameSize(int max) throws Http2Exception { + if (!isMaxFrameSizeValid(max)) { + throw connectionError(FRAME_SIZE_ERROR, "Invalid MAX_FRAME_SIZE specified in sent settings: %d", max); + } + maxFrameSize = max; + } + + @Override + public int maxFrameSize() { + return maxFrameSize; + } + + @Override + public void close() { } + + @Override + public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endStream, ChannelPromise promise) { + final SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + ByteBuf frameHeader = null; + try { + verifyStreamId(streamId, 
STREAM_ID); + verifyPadding(padding); + + int remainingData = data.readableBytes(); + Http2Flags flags = new Http2Flags(); + flags.endOfStream(false); + flags.paddingPresent(false); + // Fast path to write frames of payload size maxFrameSize first. + if (remainingData > maxFrameSize) { + frameHeader = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(frameHeader, maxFrameSize, DATA, flags, streamId); + do { + // Write the header. + ctx.write(frameHeader.retainedSlice(), promiseAggregator.newPromise()); + + // Write the payload. + ctx.write(data.readRetainedSlice(maxFrameSize), promiseAggregator.newPromise()); + + remainingData -= maxFrameSize; + // Stop iterating if remainingData == maxFrameSize so we can take care of reference counts below. + } while (remainingData > maxFrameSize); + } + + if (padding == 0) { + // Write the header. + if (frameHeader != null) { + frameHeader.release(); + frameHeader = null; + } + ByteBuf frameHeader2 = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + flags.endOfStream(endStream); + writeFrameHeaderInternal(frameHeader2, remainingData, DATA, flags, streamId); + ctx.write(frameHeader2, promiseAggregator.newPromise()); + + // Write the payload. + ByteBuf lastFrame = data.readSlice(remainingData); + data = null; + ctx.write(lastFrame, promiseAggregator.newPromise()); + } else { + if (remainingData != maxFrameSize) { + if (frameHeader != null) { + frameHeader.release(); + frameHeader = null; + } + } else { + remainingData -= maxFrameSize; + // Write the header. + ByteBuf lastFrame; + if (frameHeader == null) { + lastFrame = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(lastFrame, maxFrameSize, DATA, flags, streamId); + } else { + lastFrame = frameHeader.slice(); + frameHeader = null; + } + ctx.write(lastFrame, promiseAggregator.newPromise()); + + // Write the payload. + lastFrame = data.readableBytes() != maxFrameSize ? 
data.readSlice(maxFrameSize) : data; + data = null; + ctx.write(lastFrame, promiseAggregator.newPromise()); + } + + do { + int frameDataBytes = min(remainingData, maxFrameSize); + int framePaddingBytes = min(padding, max(0, maxFrameSize - 1 - frameDataBytes)); + + // Decrement the remaining counters. + padding -= framePaddingBytes; + remainingData -= frameDataBytes; + + // Write the header. + ByteBuf frameHeader2 = ctx.alloc().buffer(DATA_FRAME_HEADER_LENGTH); + flags.endOfStream(endStream && remainingData == 0 && padding == 0); + flags.paddingPresent(framePaddingBytes > 0); + writeFrameHeaderInternal(frameHeader2, framePaddingBytes + frameDataBytes, DATA, flags, streamId); + writePaddingLength(frameHeader2, framePaddingBytes); + ctx.write(frameHeader2, promiseAggregator.newPromise()); + + // Write the payload. + if (data != null) { // Make sure Data is not null + if (remainingData == 0) { + ByteBuf lastFrame = data.readSlice(frameDataBytes); + data = null; + ctx.write(lastFrame, promiseAggregator.newPromise()); + } else { + ctx.write(data.readRetainedSlice(frameDataBytes), promiseAggregator.newPromise()); + } + } + // Write the frame padding. + if (paddingBytes(framePaddingBytes) > 0) { + ctx.write(ZERO_BUFFER.slice(0, paddingBytes(framePaddingBytes)), + promiseAggregator.newPromise()); + } + } while (remainingData != 0 || padding != 0); + } + } catch (Throwable cause) { + if (frameHeader != null) { + frameHeader.release(); + } + // Use a try/finally here in case the data has been released before calling this method. This is not + // necessary above because we internally allocate frameHeader. 
+ try { + if (data != null) { + data.release(); + } + } finally { + promiseAggregator.setFailure(cause); + promiseAggregator.doneAllocatingPromises(); + } + return promiseAggregator; + } + return promiseAggregator.doneAllocatingPromises(); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int padding, boolean endStream, ChannelPromise promise) { + return writeHeadersInternal(ctx, streamId, headers, padding, endStream, + false, 0, (short) 0, false, promise); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int streamDependency, short weight, boolean exclusive, + int padding, boolean endStream, ChannelPromise promise) { + return writeHeadersInternal(ctx, streamId, headers, padding, endStream, + true, streamDependency, weight, exclusive, promise); + } + + @Override + public ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, + int streamDependency, short weight, boolean exclusive, ChannelPromise promise) { + try { + verifyStreamId(streamId, STREAM_ID); + verifyStreamOrConnectionId(streamDependency, STREAM_DEPENDENCY); + verifyWeight(weight); + + ByteBuf buf = ctx.alloc().buffer(PRIORITY_FRAME_LENGTH); + writeFrameHeaderInternal(buf, PRIORITY_ENTRY_LENGTH, PRIORITY, new Http2Flags(), streamId); + buf.writeInt(exclusive ? (int) (0x80000000L | streamDependency) : streamDependency); + // Adjust the weight so that it fits into a single byte on the wire. 
+ buf.writeByte(weight - 1); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); + } + } + + @Override + public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, + ChannelPromise promise) { + try { + verifyStreamId(streamId, STREAM_ID); + verifyErrorCode(errorCode); + + ByteBuf buf = ctx.alloc().buffer(RST_STREAM_FRAME_LENGTH); + writeFrameHeaderInternal(buf, INT_FIELD_LENGTH, RST_STREAM, new Http2Flags(), streamId); + buf.writeInt((int) errorCode); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); + } + } + + @Override + public ChannelFuture writeSettings(ChannelHandlerContext ctx, Http2Settings settings, + ChannelPromise promise) { + try { + checkNotNull(settings, "settings"); + int payloadLength = SETTING_ENTRY_LENGTH * settings.size(); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payloadLength); + writeFrameHeaderInternal(buf, payloadLength, SETTINGS, new Http2Flags(), 0); + for (Http2Settings.PrimitiveEntry entry : settings.entries()) { + buf.writeChar(entry.key()); + buf.writeInt(entry.value().intValue()); + } + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); + } + } + + @Override + public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) { + try { + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, 0, SETTINGS, new Http2Flags().ack(true), 0); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); + } + } + + @Override + public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, long data, ChannelPromise promise) { + Http2Flags flags = ack ? new Http2Flags().ack(true) : new Http2Flags(); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH + PING_FRAME_PAYLOAD_LENGTH); + // Assume nothing below will throw until buf is written. 
That way we don't have to take care of ownership + // in the catch block. + writeFrameHeaderInternal(buf, PING_FRAME_PAYLOAD_LENGTH, PING, flags, 0); + buf.writeLong(data); + return ctx.write(buf, promise); + } + + @Override + public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, + int promisedStreamId, Http2Headers headers, int padding, ChannelPromise promise) { + ByteBuf headerBlock = null; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + try { + verifyStreamId(streamId, STREAM_ID); + verifyStreamId(promisedStreamId, "Promised Stream ID"); + verifyPadding(padding); + + // Encode the entire header block into an intermediate buffer. + headerBlock = ctx.alloc().buffer(); + headersEncoder.encodeHeaders(streamId, headers, headerBlock); + + // Read the first fragment (possibly everything). + Http2Flags flags = new Http2Flags().paddingPresent(padding > 0); + // INT_FIELD_LENGTH is for the length of the promisedStreamId + int nonFragmentLength = INT_FIELD_LENGTH + padding; + int maxFragmentLength = maxFrameSize - nonFragmentLength; + ByteBuf fragment = headerBlock.readRetainedSlice(min(headerBlock.readableBytes(), maxFragmentLength)); + + flags.endOfHeaders(!headerBlock.isReadable()); + + int payloadLength = fragment.readableBytes() + nonFragmentLength; + ByteBuf buf = ctx.alloc().buffer(PUSH_PROMISE_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, PUSH_PROMISE, flags, streamId); + writePaddingLength(buf, padding); + + // Write out the promised stream ID. + buf.writeInt(promisedStreamId); + ctx.write(buf, promiseAggregator.newPromise()); + + // Write the first fragment. + ctx.write(fragment, promiseAggregator.newPromise()); + + // Write out the padding, if any. 
+ if (paddingBytes(padding) > 0) { + ctx.write(ZERO_BUFFER.slice(0, paddingBytes(padding)), promiseAggregator.newPromise()); + } + + if (!flags.endOfHeaders()) { + writeContinuationFrames(ctx, streamId, headerBlock, promiseAggregator); + } + } catch (Http2Exception e) { + promiseAggregator.setFailure(e); + } catch (Throwable t) { + promiseAggregator.setFailure(t); + promiseAggregator.doneAllocatingPromises(); + PlatformDependent.throwException(t); + } finally { + if (headerBlock != null) { + headerBlock.release(); + } + } + return promiseAggregator.doneAllocatingPromises(); + } + + @Override + public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, + ByteBuf debugData, ChannelPromise promise) { + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + try { + verifyStreamOrConnectionId(lastStreamId, "Last Stream ID"); + verifyErrorCode(errorCode); + + int payloadLength = 8 + debugData.readableBytes(); + ByteBuf buf = ctx.alloc().buffer(GO_AWAY_FRAME_HEADER_LENGTH); + // Assume nothing below will throw until buf is written. That way we don't have to take care of ownership + // in the catch block. 
+ writeFrameHeaderInternal(buf, payloadLength, GO_AWAY, new Http2Flags(), 0); + buf.writeInt(lastStreamId); + buf.writeInt((int) errorCode); + ctx.write(buf, promiseAggregator.newPromise()); + } catch (Throwable t) { + try { + debugData.release(); + } finally { + promiseAggregator.setFailure(t); + promiseAggregator.doneAllocatingPromises(); + } + return promiseAggregator; + } + + try { + ctx.write(debugData, promiseAggregator.newPromise()); + } catch (Throwable t) { + promiseAggregator.setFailure(t); + } + return promiseAggregator.doneAllocatingPromises(); + } + + @Override + public ChannelFuture writeWindowUpdate(ChannelHandlerContext ctx, int streamId, + int windowSizeIncrement, ChannelPromise promise) { + try { + verifyStreamOrConnectionId(streamId, STREAM_ID); + verifyWindowSizeIncrement(windowSizeIncrement); + + ByteBuf buf = ctx.alloc().buffer(WINDOW_UPDATE_FRAME_LENGTH); + writeFrameHeaderInternal(buf, INT_FIELD_LENGTH, WINDOW_UPDATE, new Http2Flags(), streamId); + buf.writeInt(windowSizeIncrement); + return ctx.write(buf, promise); + } catch (Throwable t) { + return promise.setFailure(t); + } + } + + @Override + public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, + Http2Flags flags, ByteBuf payload, ChannelPromise promise) { + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + try { + verifyStreamOrConnectionId(streamId, STREAM_ID); + ByteBuf buf = ctx.alloc().buffer(FRAME_HEADER_LENGTH); + // Assume nothing below will throw until buf is written. That way we don't have to take care of ownership + // in the catch block. 
+ writeFrameHeaderInternal(buf, payload.readableBytes(), frameType, flags, streamId); + ctx.write(buf, promiseAggregator.newPromise()); + } catch (Throwable t) { + try { + payload.release(); + } finally { + promiseAggregator.setFailure(t); + promiseAggregator.doneAllocatingPromises(); + } + return promiseAggregator; + } + try { + ctx.write(payload, promiseAggregator.newPromise()); + } catch (Throwable t) { + promiseAggregator.setFailure(t); + } + return promiseAggregator.doneAllocatingPromises(); + } + + private ChannelFuture writeHeadersInternal(ChannelHandlerContext ctx, + int streamId, Http2Headers headers, int padding, boolean endStream, + boolean hasPriority, int streamDependency, short weight, boolean exclusive, ChannelPromise promise) { + ByteBuf headerBlock = null; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + try { + verifyStreamId(streamId, STREAM_ID); + if (hasPriority) { + verifyStreamOrConnectionId(streamDependency, STREAM_DEPENDENCY); + verifyPadding(padding); + verifyWeight(weight); + } + + // Encode the entire header block. + headerBlock = ctx.alloc().buffer(); + headersEncoder.encodeHeaders(streamId, headers, headerBlock); + + Http2Flags flags = + new Http2Flags().endOfStream(endStream).priorityPresent(hasPriority).paddingPresent(padding > 0); + + // Read the first fragment (possibly everything). + int nonFragmentBytes = padding + flags.getNumPriorityBytes(); + int maxFragmentLength = maxFrameSize - nonFragmentBytes; + ByteBuf fragment = headerBlock.readRetainedSlice(min(headerBlock.readableBytes(), maxFragmentLength)); + + // Set the end of headers flag for the first frame. 
+ flags.endOfHeaders(!headerBlock.isReadable()); + + int payloadLength = fragment.readableBytes() + nonFragmentBytes; + ByteBuf buf = ctx.alloc().buffer(HEADERS_FRAME_HEADER_LENGTH); + writeFrameHeaderInternal(buf, payloadLength, HEADERS, flags, streamId); + writePaddingLength(buf, padding); + + if (hasPriority) { + buf.writeInt(exclusive ? (int) (0x80000000L | streamDependency) : streamDependency); + + // Adjust the weight so that it fits into a single byte on the wire. + buf.writeByte(weight - 1); + } + ctx.write(buf, promiseAggregator.newPromise()); + + // Write the first fragment. + ctx.write(fragment, promiseAggregator.newPromise()); + + // Write out the padding, if any. + if (paddingBytes(padding) > 0) { + ctx.write(ZERO_BUFFER.slice(0, paddingBytes(padding)), promiseAggregator.newPromise()); + } + + if (!flags.endOfHeaders()) { + writeContinuationFrames(ctx, streamId, headerBlock, promiseAggregator); + } + } catch (Http2Exception e) { + promiseAggregator.setFailure(e); + } catch (Throwable t) { + promiseAggregator.setFailure(t); + promiseAggregator.doneAllocatingPromises(); + PlatformDependent.throwException(t); + } finally { + if (headerBlock != null) { + headerBlock.release(); + } + } + return promiseAggregator.doneAllocatingPromises(); + } + + /** + * Writes as many continuation frames as needed until {@code padding} and {@code headerBlock} are consumed. 
+     */
+    private ChannelFuture writeContinuationFrames(ChannelHandlerContext ctx, int streamId,
+            ByteBuf headerBlock, SimpleChannelPromiseAggregator promiseAggregator) {
+        Http2Flags flags = new Http2Flags();
+
+        if (headerBlock.isReadable()) {
+            // The frame header (and padding) only changes on the last frame, so allocate it once and re-use
+            int fragmentReadableBytes = min(headerBlock.readableBytes(), maxFrameSize);
+            ByteBuf buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH);
+            writeFrameHeaderInternal(buf, fragmentReadableBytes, CONTINUATION, flags, streamId);
+
+            do {
+                // Each CONTINUATION frame carries at most maxFrameSize bytes of the remaining header block.
+                fragmentReadableBytes = min(headerBlock.readableBytes(), maxFrameSize);
+                ByteBuf fragment = headerBlock.readRetainedSlice(fragmentReadableBytes);
+
+                if (headerBlock.isReadable()) {
+                    // Not the last frame: retain() so the shared header buffer survives this write
+                    // and can be re-used on the next loop iteration.
+                    ctx.write(buf.retain(), promiseAggregator.newPromise());
+                } else {
+                    // The frame header is different for the last frame, so re-allocate and release the old buffer
+                    flags = flags.endOfHeaders(true);
+                    buf.release();
+                    buf = ctx.alloc().buffer(CONTINUATION_FRAME_HEADER_LENGTH);
+                    writeFrameHeaderInternal(buf, fragmentReadableBytes, CONTINUATION, flags, streamId);
+                    ctx.write(buf, promiseAggregator.newPromise());
+                }
+
+                ctx.write(fragment, promiseAggregator.newPromise());
+
+            } while (headerBlock.isReadable());
+        }
+        return promiseAggregator;
+    }
+
+    /**
+     * Returns the number of padding bytes that should be appended to the end of a frame.
+     */
+    private static int paddingBytes(int padding) {
+        // The padding parameter contains the 1 byte pad length field as well as the trailing padding bytes.
+        // Subtract 1, so to only get the number of padding bytes that need to be appended to the end of a frame.
+        return padding - 1;
+    }
+
+    // Writes the one-byte pad length field when padding is present; a padding of 0 writes nothing.
+    private static void writePaddingLength(ByteBuf buf, int padding) {
+        if (padding > 0) {
+            // It is assumed that the padding length has been bounds checked before this
+            // Minus 1, as the pad length field is included in the padding parameter and is 1 byte wide.
+            buf.writeByte(padding - 1);
+        }
+    }
+
+    // A frame targeting a specific stream must carry a stream id > 0.
+    private static void verifyStreamId(int streamId, String argumentName) {
+        checkPositive(streamId, argumentName);
+    }
+
+    // Stream id 0 (the connection itself) is also permitted here, so only negatives are rejected.
+    private static void verifyStreamOrConnectionId(int streamId, String argumentName) {
+        checkPositiveOrZero(streamId, argumentName);
+    }
+
+    // Priority weight must lie within [MIN_WEIGHT, MAX_WEIGHT].
+    private static void verifyWeight(short weight) {
+        if (weight < MIN_WEIGHT || weight > MAX_WEIGHT) {
+            throw new IllegalArgumentException("Invalid weight: " + weight);
+        }
+    }
+
+    // Error codes are carried on the wire as an unsigned 32-bit value; reject anything outside that range.
+    private static void verifyErrorCode(long errorCode) {
+        if (errorCode < 0 || errorCode > MAX_UNSIGNED_INT) {
+            throw new IllegalArgumentException("Invalid errorCode: " + errorCode);
+        }
+    }
+
+    // NOTE(review): an increment of 0 is accepted here; callers decide whether zero is meaningful.
+    private static void verifyWindowSizeIncrement(int windowSizeIncrement) {
+        checkPositiveOrZero(windowSizeIncrement, "windowSizeIncrement");
+    }
+}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java
new file mode 100644
index 0000000..8940d06
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2GoAwayFrame.java
@@ -0,0 +1,179 @@
+/*
+ * Copyright 2016 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */ +package io.netty.handler.codec.http2; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.DefaultByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +/** + * The default {@link Http2GoAwayFrame} implementation. + */ +@UnstableApi +public final class DefaultHttp2GoAwayFrame extends DefaultByteBufHolder implements Http2GoAwayFrame { + + private final long errorCode; + private final int lastStreamId; + private int extraStreamIds; + + /** + * Equivalent to {@code new DefaultHttp2GoAwayFrame(error.code())}. + * + * @param error non-{@code null} reason for the go away + */ + public DefaultHttp2GoAwayFrame(Http2Error error) { + this(error.code()); + } + + /** + * Equivalent to {@code new DefaultHttp2GoAwayFrame(content, Unpooled.EMPTY_BUFFER)}. + * + * @param errorCode reason for the go away + */ + public DefaultHttp2GoAwayFrame(long errorCode) { + this(errorCode, Unpooled.EMPTY_BUFFER); + } + + /** + * + * + * @param error non-{@code null} reason for the go away + * @param content non-{@code null} debug data + */ + public DefaultHttp2GoAwayFrame(Http2Error error, ByteBuf content) { + this(error.code(), content); + } + + /** + * Construct a new GOAWAY message. + * + * @param errorCode reason for the go away + * @param content non-{@code null} debug data + */ + public DefaultHttp2GoAwayFrame(long errorCode, ByteBuf content) { + this(-1, errorCode, content); + } + + /** + * Construct a new GOAWAY message. + * + * This constructor is for internal use only. A user should not have to specify a specific last stream identifier, + * but use {@link #setExtraStreamIds(int)} instead. 
+ */ + DefaultHttp2GoAwayFrame(int lastStreamId, long errorCode, ByteBuf content) { + super(content); + this.errorCode = errorCode; + this.lastStreamId = lastStreamId; + } + + @Override + public String name() { + return "GOAWAY"; + } + + @Override + public long errorCode() { + return errorCode; + } + + @Override + public int extraStreamIds() { + return extraStreamIds; + } + + @Override + public Http2GoAwayFrame setExtraStreamIds(int extraStreamIds) { + checkPositiveOrZero(extraStreamIds, "extraStreamIds"); + this.extraStreamIds = extraStreamIds; + return this; + } + + @Override + public int lastStreamId() { + return lastStreamId; + } + + @Override + public Http2GoAwayFrame copy() { + return new DefaultHttp2GoAwayFrame(lastStreamId, errorCode, content().copy()); + } + + @Override + public Http2GoAwayFrame duplicate() { + return (Http2GoAwayFrame) super.duplicate(); + } + + @Override + public Http2GoAwayFrame retainedDuplicate() { + return (Http2GoAwayFrame) super.retainedDuplicate(); + } + + @Override + public Http2GoAwayFrame replace(ByteBuf content) { + return new DefaultHttp2GoAwayFrame(errorCode, content).setExtraStreamIds(extraStreamIds); + } + + @Override + public Http2GoAwayFrame retain() { + super.retain(); + return this; + } + + @Override + public Http2GoAwayFrame retain(int increment) { + super.retain(increment); + return this; + } + + @Override + public Http2GoAwayFrame touch() { + super.touch(); + return this; + } + + @Override + public Http2GoAwayFrame touch(Object hint) { + super.touch(hint); + return this; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttp2GoAwayFrame)) { + return false; + } + DefaultHttp2GoAwayFrame other = (DefaultHttp2GoAwayFrame) o; + return errorCode == other.errorCode && extraStreamIds == other.extraStreamIds && super.equals(other); + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = hash * 31 + (int) (errorCode ^ errorCode >>> 32); + hash = hash * 31 + 
extraStreamIds; + return hash; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(errorCode=" + errorCode + ", content=" + content() + + ", extraStreamIds=" + extraStreamIds + ", lastStreamId=" + lastStreamId + ')'; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Headers.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Headers.java new file mode 100644 index 0000000..60a2925 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2Headers.java @@ -0,0 +1,303 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.CharSequenceValueConverter; +import io.netty.handler.codec.DefaultHeaders; +import io.netty.handler.codec.http.HttpHeaderValidationUtil; +import io.netty.util.AsciiString; +import io.netty.util.ByteProcessor; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.getPseudoHeader; +import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat; +import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; +import static io.netty.util.AsciiString.CASE_SENSITIVE_HASHER; +import static io.netty.util.AsciiString.isUpperCase; + +@UnstableApi +public class DefaultHttp2Headers + extends DefaultHeaders implements Http2Headers { + private static final ByteProcessor HTTP2_NAME_VALIDATOR_PROCESSOR = new ByteProcessor() { + @Override + public boolean process(byte value) { + return !isUpperCase(value); + } + }; + static final NameValidator HTTP2_NAME_VALIDATOR = new NameValidator() { + @Override + public void validateName(CharSequence name) { + if (name == null || name.length() == 0) { + PlatformDependent.throwException(connectionError(PROTOCOL_ERROR, + "empty headers are not allowed [%s]", name)); + } + + if (name instanceof AsciiString) { + final int index; + try { + index = ((AsciiString) name).forEachByte(HTTP2_NAME_VALIDATOR_PROCESSOR); + } catch (Http2Exception e) { + PlatformDependent.throwException(e); + return; + } catch (Throwable t) { + PlatformDependent.throwException(connectionError(PROTOCOL_ERROR, t, + "unexpected error. 
invalid header name [%s]", name)); + return; + } + + if (index != -1) { + PlatformDependent.throwException(connectionError(PROTOCOL_ERROR, + "invalid header name [%s]", name)); + } + } else { + for (int i = 0; i < name.length(); ++i) { + if (isUpperCase(name.charAt(i))) { + PlatformDependent.throwException(connectionError(PROTOCOL_ERROR, + "invalid header name [%s]", name)); + } + } + } + + if (hasPseudoHeaderFormat(name)) { + final Http2Headers.PseudoHeaderName pseudoHeader = getPseudoHeader(name); + if (pseudoHeader == null) { + PlatformDependent.throwException(connectionError( + PROTOCOL_ERROR, "Invalid HTTP/2 pseudo-header '%s' encountered.", name)); + } + } + } + }; + + private static final ValueValidator VALUE_VALIDATOR = new ValueValidator() { + @Override + public void validate(CharSequence value) { + int index = HttpHeaderValidationUtil.validateValidHeaderValue(value); + if (index != -1) { + throw new IllegalArgumentException("a header value contains prohibited character 0x" + + Integer.toHexString(value.charAt(index)) + " at index " + index + '.'); + } + } + }; + + private HeaderEntry firstNonPseudo = head; + + /** + * Create a new instance. + *

+ * Header names will be validated according to + * rfc7540. + */ + public DefaultHttp2Headers() { + this(true); + } + + /** + * Create a new instance. + * @param validate {@code true} to validate header names according to + * rfc7540. {@code false} to not validate header names. + */ + @SuppressWarnings("unchecked") + public DefaultHttp2Headers(boolean validate) { + // Case sensitive compare is used because it is cheaper, and header validation can be used to catch invalid + // headers. + super(CASE_SENSITIVE_HASHER, + CharSequenceValueConverter.INSTANCE, + validate ? HTTP2_NAME_VALIDATOR : NameValidator.NOT_NULL); + } + + /** + * Create a new instance. + * @param validate {@code true} to validate header names according to + * rfc7540. {@code false} to not validate header names. + * @param arraySizeHint A hint as to how large the hash data structure should be. + * The next positive power of two will be used. An upper bound may be enforced. + * @see DefaultHttp2Headers#DefaultHttp2Headers(boolean, boolean, int) + */ + @SuppressWarnings("unchecked") + public DefaultHttp2Headers(boolean validate, int arraySizeHint) { + // Case sensitive compare is used because it is cheaper, and header validation can be used to catch invalid + // headers. + super(CASE_SENSITIVE_HASHER, + CharSequenceValueConverter.INSTANCE, + validate ? HTTP2_NAME_VALIDATOR : NameValidator.NOT_NULL, + arraySizeHint); + } + + /** + * Create a new instance. + * @param validate {@code true} to validate header names according to + * rfc7540. {@code false} to not validate header names. + * @param validateValues {@code true} to validate header values according to + * rfc7230 and + * rfc5234. Otherwise, {@code false} + * (the default) to not validate values. + * @param arraySizeHint A hint as to how large the hash data structure should be. + * The next positive power of two will be used. An upper bound may be enforced. 
+ */ + @SuppressWarnings("unchecked") + public DefaultHttp2Headers(boolean validate, boolean validateValues, int arraySizeHint) { + // Case sensitive compare is used because it is cheaper, and header validation can be used to catch invalid + // headers. + super(CASE_SENSITIVE_HASHER, + CharSequenceValueConverter.INSTANCE, + validate ? HTTP2_NAME_VALIDATOR : NameValidator.NOT_NULL, + arraySizeHint, + validateValues ? VALUE_VALIDATOR : (ValueValidator) ValueValidator.NO_VALIDATION); + } + + @Override + protected void validateName(NameValidator validator, boolean forAdd, CharSequence name) { + super.validateName(validator, forAdd, name); + if (nameValidator() == HTTP2_NAME_VALIDATOR && forAdd && hasPseudoHeaderFormat(name)) { + if (contains(name)) { + PlatformDependent.throwException(connectionError( + PROTOCOL_ERROR, "Duplicate HTTP/2 pseudo-header '%s' encountered.", name)); + } + } + } + + @Override + protected void validateValue(ValueValidator validator, CharSequence name, CharSequence value) { + // This method has a noop override for backward compatibility, see https://github.com/netty/netty/pull/12975 + super.validateValue(validator, name, value); + // https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1 + // pseudo headers must not be empty + if (nameValidator() == HTTP2_NAME_VALIDATOR && (value == null || value.length() == 0) && + hasPseudoHeaderFormat(name)) { + PlatformDependent.throwException(connectionError( + PROTOCOL_ERROR, "HTTP/2 pseudo-header '%s' must not be empty.", name)); + } + } + + @Override + public Http2Headers clear() { + firstNonPseudo = head; + return super.clear(); + } + + @Override + public boolean equals(Object o) { + return o instanceof Http2Headers && equals((Http2Headers) o, CASE_SENSITIVE_HASHER); + } + + @Override + public int hashCode() { + return hashCode(CASE_SENSITIVE_HASHER); + } + + @Override + public Http2Headers method(CharSequence value) { + set(PseudoHeaderName.METHOD.value(), value); + return this; + } + + 
@Override + public Http2Headers scheme(CharSequence value) { + set(PseudoHeaderName.SCHEME.value(), value); + return this; + } + + @Override + public Http2Headers authority(CharSequence value) { + set(PseudoHeaderName.AUTHORITY.value(), value); + return this; + } + + @Override + public Http2Headers path(CharSequence value) { + set(PseudoHeaderName.PATH.value(), value); + return this; + } + + @Override + public Http2Headers status(CharSequence value) { + set(PseudoHeaderName.STATUS.value(), value); + return this; + } + + @Override + public CharSequence method() { + return get(PseudoHeaderName.METHOD.value()); + } + + @Override + public CharSequence scheme() { + return get(PseudoHeaderName.SCHEME.value()); + } + + @Override + public CharSequence authority() { + return get(PseudoHeaderName.AUTHORITY.value()); + } + + @Override + public CharSequence path() { + return get(PseudoHeaderName.PATH.value()); + } + + @Override + public CharSequence status() { + return get(PseudoHeaderName.STATUS.value()); + } + + @Override + public boolean contains(CharSequence name, CharSequence value) { + return contains(name, value, false); + } + + @Override + public boolean contains(CharSequence name, CharSequence value, boolean caseInsensitive) { + return contains(name, value, caseInsensitive ? 
CASE_INSENSITIVE_HASHER : CASE_SENSITIVE_HASHER); + } + + @Override + protected final HeaderEntry newHeaderEntry(int h, CharSequence name, CharSequence value, + HeaderEntry next) { + return new Http2HeaderEntry(h, name, value, next); + } + + private final class Http2HeaderEntry extends HeaderEntry { + Http2HeaderEntry(int hash, CharSequence key, CharSequence value, + HeaderEntry next) { + super(hash, key); + this.value = value; + this.next = next; + + // Make sure the pseudo headers fields are first in iteration order + if (hasPseudoHeaderFormat(key)) { + after = firstNonPseudo; + before = firstNonPseudo.before(); + } else { + after = head; + before = head.before(); + if (firstNonPseudo == head) { + firstNonPseudo = this; + } + } + pointNeighborsToThis(); + } + + @Override + protected void remove() { + if (this == firstNonPseudo) { + firstNonPseudo = firstNonPseudo.after(); + } + super.remove(); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java new file mode 100644 index 0000000..9647927 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoder.java @@ -0,0 +1,213 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; + +@UnstableApi +public class DefaultHttp2HeadersDecoder implements Http2HeadersDecoder, Http2HeadersDecoder.Configuration { + private static final float HEADERS_COUNT_WEIGHT_NEW = 1 / 5f; + private static final float HEADERS_COUNT_WEIGHT_HISTORICAL = 1 - HEADERS_COUNT_WEIGHT_NEW; + + private final HpackDecoder hpackDecoder; + private final boolean validateHeaders; + private final boolean validateHeaderValues; + private long maxHeaderListSizeGoAway; + + /** + * Used to calculate an exponential moving average of header sizes to get an estimate of how large the data + * structure for storing headers should be. + */ + private float headerArraySizeAccumulator = 8; + + public DefaultHttp2HeadersDecoder() { + this(true); + } + + /** + * Create a new instance. + * @param validateHeaders {@code true} to validate headers are valid according to the RFC. + */ + public DefaultHttp2HeadersDecoder(boolean validateHeaders) { + this(validateHeaders, DEFAULT_HEADER_LIST_SIZE); + } + + /** + * Create a new instance. + * + * @param validateHeaders {@code true} to validate headers are valid according to the RFC. + * This validates everything except header values. + * @param validateHeaderValues {@code true} to validate that header values are valid according to the RFC. + * Since this is potentially expensive, it can be enabled separately from {@code validateHeaders}. 
+ */ + public DefaultHttp2HeadersDecoder(boolean validateHeaders, boolean validateHeaderValues) { + this(validateHeaders, validateHeaderValues, DEFAULT_HEADER_LIST_SIZE); + } + + /** + * Create a new instance. + * @param validateHeaders {@code true} to validate headers are valid according to the RFC. + * @param maxHeaderListSize This is the only setting that can be configured before notifying the peer. + * This is because SETTINGS_MAX_HEADER_LIST_SIZE + * allows a lower than advertised limit from being enforced, and the default limit is unlimited + * (which is dangerous). + */ + public DefaultHttp2HeadersDecoder(boolean validateHeaders, long maxHeaderListSize) { + this(validateHeaders, false, new HpackDecoder(maxHeaderListSize)); + } + + /** + * Create a new instance. + * @param validateHeaders {@code true} to validate headers are valid according to the RFC. + * This validates everything except header values. + * @param validateHeaderValues {@code true} to validate that header values are valid according to the RFC. + * Since this is potentially expensive, it can be enabled separately from {@code validateHeaders}. + * @param maxHeaderListSize This is the only setting that can be configured before notifying the peer. + * This is because SETTINGS_MAX_HEADER_LIST_SIZE + * allows a lower than advertised limit from being enforced, and the default limit is unlimited + * (which is dangerous). + */ + public DefaultHttp2HeadersDecoder(boolean validateHeaders, boolean validateHeaderValues, long maxHeaderListSize) { + this(validateHeaders, validateHeaderValues, new HpackDecoder(maxHeaderListSize)); + } + + /** + * Create a new instance. + * @param validateHeaders {@code true} to validate headers are valid according to the RFC. + * This validates everything except header values. + * @param maxHeaderListSize This is the only setting that can be configured before notifying the peer. 
+ * This is because SETTINGS_MAX_HEADER_LIST_SIZE + * allows a lower than advertised limit from being enforced, and the default limit is unlimited + * (which is dangerous). + * @param initialHuffmanDecodeCapacity Does nothing, do not use. + */ + public DefaultHttp2HeadersDecoder(boolean validateHeaders, long maxHeaderListSize, + @Deprecated int initialHuffmanDecodeCapacity) { + this(validateHeaders, false, new HpackDecoder(maxHeaderListSize)); + } + + /** + * Exposed for testing only! Default values used in the initial settings frame are overridden intentionally + * for testing but violate the RFC if used outside the scope of testing. + */ + DefaultHttp2HeadersDecoder(boolean validateHeaders, boolean validateHeaderValues, HpackDecoder hpackDecoder) { + this.hpackDecoder = ObjectUtil.checkNotNull(hpackDecoder, "hpackDecoder"); + this.validateHeaders = validateHeaders; + this.validateHeaderValues = validateHeaderValues; + maxHeaderListSizeGoAway = + Http2CodecUtil.calculateMaxHeaderListSizeGoAway(hpackDecoder.getMaxHeaderListSize()); + } + + @Override + public void maxHeaderTableSize(long max) throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(max); + } + + @Override + public long maxHeaderTableSize() { + return hpackDecoder.getMaxHeaderTableSize(); + } + + @Override + public void maxHeaderListSize(long max, long goAwayMax) throws Http2Exception { + if (goAwayMax < max || goAwayMax < 0) { + throw connectionError(INTERNAL_ERROR, "Header List Size GO_AWAY %d must be non-negative and >= %d", + goAwayMax, max); + } + hpackDecoder.setMaxHeaderListSize(max); + maxHeaderListSizeGoAway = goAwayMax; + } + + @Override + public long maxHeaderListSize() { + return hpackDecoder.getMaxHeaderListSize(); + } + + @Override + public long maxHeaderListSizeGoAway() { + return maxHeaderListSizeGoAway; + } + + @Override + public Configuration configuration() { + return this; + } + + @Override + public Http2Headers decodeHeaders(int streamId, ByteBuf headerBlock) throws 
Http2Exception {
+        try {
+            final Http2Headers headers = newHeaders();
+            hpackDecoder.decode(streamId, headerBlock, headers, validateHeaders);
+            // Fold this header count into the exponential moving average used to size future
+            // header data structures (see headerArraySizeAccumulator and numberOfHeadersGuess()).
+            headerArraySizeAccumulator = HEADERS_COUNT_WEIGHT_NEW * headers.size() +
+                    HEADERS_COUNT_WEIGHT_HISTORICAL * headerArraySizeAccumulator;
+            return headers;
+        } catch (Http2Exception e) {
+            // Http2Exceptions already carry the appropriate HTTP/2 error code; propagate unchanged.
+            throw e;
+        } catch (Throwable e) {
+            // Default handler for any other types of errors that may have occurred. For example,
+            // the Header builder throws IllegalArgumentException if the key or value was invalid
+            // for any reason (e.g. the key was an invalid pseudo-header).
+            throw connectionError(COMPRESSION_ERROR, e, "Error decoding headers: %s", e.getMessage());
+        }
+    }
+
+    /**
+     * A weighted moving average estimating how many headers are expected during the decode process.
+     * @return an estimate of how many headers are expected during the decode process.
+     */
+    protected final int numberOfHeadersGuess() {
+        // Truncation to int is fine: the accumulator is only a sizing hint.
+        return (int) headerArraySizeAccumulator;
+    }
+
+    /**
+     * Determines if the headers should be validated as a result of the decode operation.
+     *

+ * Note: This does not include validation of header values, since that is potentially + * expensive to do. Value validation is instead {@linkplain #validateHeaderValues() enabled separately}. + * + * @return {@code true} if the headers should be validated as a result of the decode operation. + */ + protected final boolean validateHeaders() { + return validateHeaders; + } + + /** + * Determines if the header values should be validated as a result of the decode operation. + *

+ * Note: This only validates the values of headers. All other header validations are + * instead {@linkplain #validateHeaders() enabled separately}. + * + * @return {@code true} if the header values should be validated as a result of the decode operation. + */ + protected boolean validateHeaderValues() { // Not 'final' due to backwards compatibility. + return validateHeaderValues; + } + + /** + * Create a new {@link Http2Headers} object which will store the results of the decode operation. + * @return a new {@link Http2Headers} object which will store the results of the decode operation. + */ + protected Http2Headers newHeaders() { + return new DefaultHttp2Headers(validateHeaders, validateHeaderValues, (int) headerArraySizeAccumulator); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoder.java new file mode 100644 index 0000000..2338bb3 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoder.java @@ -0,0 +1,109 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +@UnstableApi +public class DefaultHttp2HeadersEncoder implements Http2HeadersEncoder, Http2HeadersEncoder.Configuration { + private final HpackEncoder hpackEncoder; + private final SensitivityDetector sensitivityDetector; + private ByteBuf tableSizeChangeOutput; + + public DefaultHttp2HeadersEncoder() { + this(NEVER_SENSITIVE); + } + + public DefaultHttp2HeadersEncoder(SensitivityDetector sensitivityDetector) { + this(sensitivityDetector, new HpackEncoder()); + } + + public DefaultHttp2HeadersEncoder(SensitivityDetector sensitivityDetector, boolean ignoreMaxHeaderListSize) { + this(sensitivityDetector, new HpackEncoder(ignoreMaxHeaderListSize)); + } + + public DefaultHttp2HeadersEncoder(SensitivityDetector sensitivityDetector, boolean ignoreMaxHeaderListSize, + int dynamicTableArraySizeHint) { + this(sensitivityDetector, ignoreMaxHeaderListSize, dynamicTableArraySizeHint, HpackEncoder.HUFF_CODE_THRESHOLD); + } + + public DefaultHttp2HeadersEncoder(SensitivityDetector sensitivityDetector, boolean ignoreMaxHeaderListSize, + int dynamicTableArraySizeHint, int huffCodeThreshold) { + this(sensitivityDetector, + new HpackEncoder(ignoreMaxHeaderListSize, dynamicTableArraySizeHint, huffCodeThreshold)); + } + + /** + * Exposed Used for testing only! Default values used in the initial settings frame are overridden intentionally + * for testing but violate the RFC if used outside the scope of testing. 
+ */ + DefaultHttp2HeadersEncoder(SensitivityDetector sensitivityDetector, HpackEncoder hpackEncoder) { + this.sensitivityDetector = checkNotNull(sensitivityDetector, "sensitiveDetector"); + this.hpackEncoder = checkNotNull(hpackEncoder, "hpackEncoder"); + } + + @Override + public void encodeHeaders(int streamId, Http2Headers headers, ByteBuf buffer) throws Http2Exception { + try { + // If there was a change in the table size, serialize the output from the hpackEncoder + // resulting from that change. + if (tableSizeChangeOutput != null && tableSizeChangeOutput.isReadable()) { + buffer.writeBytes(tableSizeChangeOutput); + tableSizeChangeOutput.clear(); + } + + hpackEncoder.encodeHeaders(streamId, buffer, headers, sensitivityDetector); + } catch (Http2Exception e) { + throw e; + } catch (Throwable t) { + throw connectionError(COMPRESSION_ERROR, t, "Failed encoding headers block: %s", t.getMessage()); + } + } + + @Override + public void maxHeaderTableSize(long max) throws Http2Exception { + if (tableSizeChangeOutput == null) { + tableSizeChangeOutput = Unpooled.buffer(); + } + hpackEncoder.setMaxHeaderTableSize(tableSizeChangeOutput, max); + } + + @Override + public long maxHeaderTableSize() { + return hpackEncoder.getMaxHeaderTableSize(); + } + + @Override + public void maxHeaderListSize(long max) throws Http2Exception { + hpackEncoder.setMaxHeaderListSize(max); + } + + @Override + public long maxHeaderListSize() { + return hpackEncoder.getMaxHeaderListSize(); + } + + @Override + public Configuration configuration() { + return this; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersFrame.java new file mode 100644 index 0000000..c06d8ad --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2HeadersFrame.java @@ -0,0 +1,116 @@ +/* + * Copyright 2016 The Netty Project 
+ * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2CodecUtil.verifyPadding; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * The default {@link Http2HeadersFrame} implementation. + */ +@UnstableApi +public final class DefaultHttp2HeadersFrame extends AbstractHttp2StreamFrame implements Http2HeadersFrame { + private final Http2Headers headers; + private final boolean endStream; + private final int padding; + + /** + * Equivalent to {@code new DefaultHttp2HeadersFrame(headers, false)}. + * + * @param headers the non-{@code null} headers to send + */ + public DefaultHttp2HeadersFrame(Http2Headers headers) { + this(headers, false); + } + + /** + * Equivalent to {@code new DefaultHttp2HeadersFrame(headers, endStream, 0)}. + * + * @param headers the non-{@code null} headers to send + */ + public DefaultHttp2HeadersFrame(Http2Headers headers, boolean endStream) { + this(headers, endStream, 0); + } + + /** + * Construct a new headers message. + * + * @param headers the non-{@code null} headers to send + * @param endStream whether these headers should terminate the stream + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). 
+ */ + public DefaultHttp2HeadersFrame(Http2Headers headers, boolean endStream, int padding) { + this.headers = checkNotNull(headers, "headers"); + this.endStream = endStream; + verifyPadding(padding); + this.padding = padding; + } + + @Override + public DefaultHttp2HeadersFrame stream(Http2FrameStream stream) { + super.stream(stream); + return this; + } + + @Override + public String name() { + return "HEADERS"; + } + + @Override + public Http2Headers headers() { + return headers; + } + + @Override + public boolean isEndStream() { + return endStream; + } + + @Override + public int padding() { + return padding; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(stream=" + stream() + ", headers=" + headers + + ", endStream=" + endStream + ", padding=" + padding + ')'; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttp2HeadersFrame)) { + return false; + } + DefaultHttp2HeadersFrame other = (DefaultHttp2HeadersFrame) o; + return super.equals(other) && headers.equals(other.headers) + && endStream == other.endStream && padding == other.padding; + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = hash * 31 + headers.hashCode(); + hash = hash * 31 + (endStream ? 0 : 1); + hash = hash * 31 + padding; + return hash; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java new file mode 100644 index 0000000..b1fb150 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowController.java @@ -0,0 +1,648 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Math.max; +import static java.lang.Math.min; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; +import io.netty.handler.codec.http2.Http2Exception.StreamException; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; + +/** + * Basic implementation of {@link Http2LocalFlowController}. + *

+ * This class is NOT thread safe. The assumption is all methods must be invoked from a single thread. + * Typically this thread is the event loop thread for the {@link ChannelHandlerContext} managed by this class. + */ +@UnstableApi +public class DefaultHttp2LocalFlowController implements Http2LocalFlowController { + /** + * The default ratio of window size to initial window size below which a {@code WINDOW_UPDATE} + * is sent to expand the window. + */ + public static final float DEFAULT_WINDOW_UPDATE_RATIO = 0.5f; + + private final Http2Connection connection; + private final Http2Connection.PropertyKey stateKey; + private Http2FrameWriter frameWriter; + private ChannelHandlerContext ctx; + private float windowUpdateRatio; + private int initialWindowSize = DEFAULT_WINDOW_SIZE; + + public DefaultHttp2LocalFlowController(Http2Connection connection) { + this(connection, DEFAULT_WINDOW_UPDATE_RATIO, false); + } + + /** + * Constructs a controller with the given settings. + * + * @param connection the connection state. + * @param windowUpdateRatio the window percentage below which to send a {@code WINDOW_UPDATE}. + * @param autoRefillConnectionWindow if {@code true}, effectively disables the connection window + * in the flow control algorithm as they will always refill automatically without requiring the + * application to consume the bytes. When enabled, the maximum bytes you must be prepared to + * queue is proportional to {@code maximum number of concurrent streams * the initial window + * size per stream} + * (SETTINGS_MAX_CONCURRENT_STREAMS + * SETTINGS_INITIAL_WINDOW_SIZE). + */ + public DefaultHttp2LocalFlowController(Http2Connection connection, + float windowUpdateRatio, + boolean autoRefillConnectionWindow) { + this.connection = checkNotNull(connection, "connection"); + windowUpdateRatio(windowUpdateRatio); + + // Add a flow state for the connection. + stateKey = connection.newKey(); + FlowState connectionState = autoRefillConnectionWindow ? 
+ new AutoRefillState(connection.connectionStream(), initialWindowSize) : + new DefaultState(connection.connectionStream(), initialWindowSize); + connection.connectionStream().setProperty(stateKey, connectionState); + + // Register for notification of new streams. + connection.addListener(new Http2ConnectionAdapter() { + @Override + public void onStreamAdded(Http2Stream stream) { + // Unconditionally used the reduced flow control state because it requires no object allocation + // and the DefaultFlowState will be allocated in onStreamActive. + stream.setProperty(stateKey, REDUCED_FLOW_STATE); + } + + @Override + public void onStreamActive(Http2Stream stream) { + // Need to be sure the stream's initial window is adjusted for SETTINGS + // frames which may have been exchanged while it was in IDLE + stream.setProperty(stateKey, new DefaultState(stream, initialWindowSize)); + } + + @Override + public void onStreamClosed(Http2Stream stream) { + try { + // When a stream is closed, consume any remaining bytes so that they + // are restored to the connection window. + FlowState state = state(stream); + int unconsumedBytes = state.unconsumedBytes(); + if (ctx != null && unconsumedBytes > 0) { + if (consumeAllBytes(state, unconsumedBytes)) { + // As the user has no real control on when this callback is used we should better + // call flush() if we produced any window update to ensure we not stale. + ctx.flush(); + } + } + } catch (Http2Exception e) { + PlatformDependent.throwException(e); + } finally { + // Unconditionally reduce the amount of memory required for flow control because there is no + // object allocation costs associated with doing so and the stream will not have any more + // local flow control state to keep track of anymore. 
+ stream.setProperty(stateKey, REDUCED_FLOW_STATE); + } + } + }); + } + + @Override + public DefaultHttp2LocalFlowController frameWriter(Http2FrameWriter frameWriter) { + this.frameWriter = checkNotNull(frameWriter, "frameWriter"); + return this; + } + + @Override + public void channelHandlerContext(ChannelHandlerContext ctx) { + this.ctx = checkNotNull(ctx, "ctx"); + } + + @Override + public void initialWindowSize(int newWindowSize) throws Http2Exception { + assert ctx == null || ctx.executor().inEventLoop(); + int delta = newWindowSize - initialWindowSize; + initialWindowSize = newWindowSize; + + WindowUpdateVisitor visitor = new WindowUpdateVisitor(delta); + connection.forEachActiveStream(visitor); + visitor.throwIfError(); + } + + @Override + public int initialWindowSize() { + return initialWindowSize; + } + + @Override + public int windowSize(Http2Stream stream) { + return state(stream).windowSize(); + } + + @Override + public int initialWindowSize(Http2Stream stream) { + return state(stream).initialWindowSize(); + } + + @Override + public void incrementWindowSize(Http2Stream stream, int delta) throws Http2Exception { + assert ctx != null && ctx.executor().inEventLoop(); + FlowState state = state(stream); + // Just add the delta to the stream-specific initial window size so that the next time the window + // expands it will grow to the new initial size. + state.incrementInitialStreamWindow(delta); + state.writeWindowUpdateIfNeeded(); + } + + @Override + public boolean consumeBytes(Http2Stream stream, int numBytes) throws Http2Exception { + assert ctx != null && ctx.executor().inEventLoop(); + checkPositiveOrZero(numBytes, "numBytes"); + if (numBytes == 0) { + return false; + } + + // Streams automatically consume all remaining bytes when they are closed, so just ignore + // if already closed. 
+ if (stream != null && !isClosed(stream)) { + if (stream.id() == CONNECTION_STREAM_ID) { + throw new UnsupportedOperationException("Returning bytes for the connection window is not supported"); + } + + return consumeAllBytes(state(stream), numBytes); + } + return false; + } + + private boolean consumeAllBytes(FlowState state, int numBytes) throws Http2Exception { + return connectionState().consumeBytes(numBytes) | state.consumeBytes(numBytes); + } + + @Override + public int unconsumedBytes(Http2Stream stream) { + return state(stream).unconsumedBytes(); + } + + private static void checkValidRatio(float ratio) { + if (Double.compare(ratio, 0.0) <= 0 || Double.compare(ratio, 1.0) >= 0) { + throw new IllegalArgumentException("Invalid ratio: " + ratio); + } + } + + /** + * The window update ratio is used to determine when a window update must be sent. If the ratio + * of bytes processed since the last update has meet or exceeded this ratio then a window update will + * be sent. This is the global window update ratio that will be used for new streams. + * @param ratio the ratio to use when checking if a {@code WINDOW_UPDATE} is determined necessary for new streams. + * @throws IllegalArgumentException If the ratio is out of bounds (0, 1). + */ + public void windowUpdateRatio(float ratio) { + assert ctx == null || ctx.executor().inEventLoop(); + checkValidRatio(ratio); + windowUpdateRatio = ratio; + } + + /** + * The window update ratio is used to determine when a window update must be sent. If the ratio + * of bytes processed since the last update has meet or exceeded this ratio then a window update will + * be sent. This is the global window update ratio that will be used for new streams. + */ + public float windowUpdateRatio() { + return windowUpdateRatio; + } + + /** + * The window update ratio is used to determine when a window update must be sent. 
If the ratio + * of bytes processed since the last update has meet or exceeded this ratio then a window update will + * be sent. This window update ratio will only be applied to {@code streamId}. + *

+ * Note it is the responsibly of the caller to ensure that the + * initial {@code SETTINGS} frame is sent before this is called. It would + * be considered a {@link Http2Error#PROTOCOL_ERROR} if a {@code WINDOW_UPDATE} + * was generated by this method before the initial {@code SETTINGS} frame is sent. + * @param stream the stream for which {@code ratio} applies to. + * @param ratio the ratio to use when checking if a {@code WINDOW_UPDATE} is determined necessary. + * @throws Http2Exception If a protocol-error occurs while generating {@code WINDOW_UPDATE} frames + */ + public void windowUpdateRatio(Http2Stream stream, float ratio) throws Http2Exception { + assert ctx != null && ctx.executor().inEventLoop(); + checkValidRatio(ratio); + FlowState state = state(stream); + state.windowUpdateRatio(ratio); + state.writeWindowUpdateIfNeeded(); + } + + /** + * The window update ratio is used to determine when a window update must be sent. If the ratio + * of bytes processed since the last update has meet or exceeded this ratio then a window update will + * be sent. This window update ratio will only be applied to {@code streamId}. + * @throws Http2Exception If no stream corresponding to {@code stream} could be found. 
+ */ + public float windowUpdateRatio(Http2Stream stream) throws Http2Exception { + return state(stream).windowUpdateRatio(); + } + + @Override + public void receiveFlowControlledFrame(Http2Stream stream, ByteBuf data, int padding, + boolean endOfStream) throws Http2Exception { + assert ctx != null && ctx.executor().inEventLoop(); + int dataLength = data.readableBytes() + padding; + + // Apply the connection-level flow control + FlowState connectionState = connectionState(); + connectionState.receiveFlowControlledFrame(dataLength); + + if (stream != null && !isClosed(stream)) { + // Apply the stream-level flow control + FlowState state = state(stream); + state.endOfStream(endOfStream); + state.receiveFlowControlledFrame(dataLength); + } else if (dataLength > 0) { + // Immediately consume the bytes for the connection window. + connectionState.consumeBytes(dataLength); + } + } + + private FlowState connectionState() { + return connection.connectionStream().getProperty(stateKey); + } + + private FlowState state(Http2Stream stream) { + return stream.getProperty(stateKey); + } + + private static boolean isClosed(Http2Stream stream) { + return stream.state() == Http2Stream.State.CLOSED; + } + + /** + * Flow control state that does autorefill of the flow control window when the data is + * received. + */ + private final class AutoRefillState extends DefaultState { + AutoRefillState(Http2Stream stream, int initialWindowSize) { + super(stream, initialWindowSize); + } + + @Override + public void receiveFlowControlledFrame(int dataLength) throws Http2Exception { + super.receiveFlowControlledFrame(dataLength); + // Need to call the super to consume the bytes, since this.consumeBytes does nothing. + super.consumeBytes(dataLength); + } + + @Override + public boolean consumeBytes(int numBytes) throws Http2Exception { + // Do nothing, since the bytes are already consumed upon receiving the data. 
+ return false; + } + } + + /** + * Flow control window state for an individual stream. + */ + private class DefaultState implements FlowState { + private final Http2Stream stream; + + /** + * The actual flow control window that is decremented as soon as {@code DATA} arrives. + */ + private int window; + + /** + * A view of {@link #window} that is used to determine when to send {@code WINDOW_UPDATE} + * frames. Decrementing this window for received {@code DATA} frames is delayed until the + * application has indicated that the data has been fully processed. This prevents sending + * a {@code WINDOW_UPDATE} until the number of processed bytes drops below the threshold. + */ + private int processedWindow; + + /** + * This is what is used to determine how many bytes need to be returned relative to {@link #processedWindow}. + * Each stream has their own initial window size. + */ + private int initialStreamWindowSize; + + /** + * This is used to determine when {@link #processedWindow} is sufficiently far away from + * {@link #initialStreamWindowSize} such that a {@code WINDOW_UPDATE} should be sent. + * Each stream has their own window update ratio. 
+ */ + private float streamWindowUpdateRatio; + + private int lowerBound; + private boolean endOfStream; + + DefaultState(Http2Stream stream, int initialWindowSize) { + this.stream = stream; + window(initialWindowSize); + streamWindowUpdateRatio = windowUpdateRatio; + } + + @Override + public void window(int initialWindowSize) { + assert ctx == null || ctx.executor().inEventLoop(); + window = processedWindow = initialStreamWindowSize = initialWindowSize; + } + + @Override + public int windowSize() { + return window; + } + + @Override + public int initialWindowSize() { + return initialStreamWindowSize; + } + + @Override + public void endOfStream(boolean endOfStream) { + this.endOfStream = endOfStream; + } + + @Override + public float windowUpdateRatio() { + return streamWindowUpdateRatio; + } + + @Override + public void windowUpdateRatio(float ratio) { + assert ctx == null || ctx.executor().inEventLoop(); + streamWindowUpdateRatio = ratio; + } + + @Override + public void incrementInitialStreamWindow(int delta) { + // Clip the delta so that the resulting initialStreamWindowSize falls within the allowed range. + int newValue = (int) min(MAX_INITIAL_WINDOW_SIZE, + max(MIN_INITIAL_WINDOW_SIZE, initialStreamWindowSize + (long) delta)); + delta = newValue - initialStreamWindowSize; + + initialStreamWindowSize += delta; + } + + @Override + public void incrementFlowControlWindows(int delta) throws Http2Exception { + if (delta > 0 && window > MAX_INITIAL_WINDOW_SIZE - delta) { + throw streamError(stream.id(), FLOW_CONTROL_ERROR, + "Flow control window overflowed for stream: %d", stream.id()); + } + + window += delta; + processedWindow += delta; + lowerBound = min(delta, 0); + } + + @Override + public void receiveFlowControlledFrame(int dataLength) throws Http2Exception { + assert dataLength >= 0; + + // Apply the delta. Even if we throw an exception we want to have taken this delta into account. 
+ window -= dataLength; + + // Window size can become negative if we sent a SETTINGS frame that reduces the + // size of the transfer window after the peer has written data frames. + // The value is bounded by the length that SETTINGS frame decrease the window. + // This difference is stored for the connection when writing the SETTINGS frame + // and is cleared once we send a WINDOW_UPDATE frame. + if (window < lowerBound) { + throw streamError(stream.id(), FLOW_CONTROL_ERROR, + "Flow control window exceeded for stream: %d", stream.id()); + } + } + + private void returnProcessedBytes(int delta) throws Http2Exception { + if (processedWindow - delta < window) { + throw streamError(stream.id(), INTERNAL_ERROR, + "Attempting to return too many bytes for stream %d", stream.id()); + } + processedWindow -= delta; + } + + @Override + public boolean consumeBytes(int numBytes) throws Http2Exception { + // Return the bytes processed and update the window. + returnProcessedBytes(numBytes); + return writeWindowUpdateIfNeeded(); + } + + @Override + public int unconsumedBytes() { + return processedWindow - window; + } + + @Override + public boolean writeWindowUpdateIfNeeded() throws Http2Exception { + if (endOfStream || initialStreamWindowSize <= 0 || + // If the stream is already closed there is no need to try to write a window update for it. + isClosed(stream)) { + return false; + } + + int threshold = (int) (initialStreamWindowSize * streamWindowUpdateRatio); + if (processedWindow <= threshold) { + writeWindowUpdate(); + return true; + } + return false; + } + + /** + * Called to perform a window update for this stream (or connection). Updates the window size back + * to the size of the initial window and sends a window update frame to the remote endpoint. + */ + private void writeWindowUpdate() throws Http2Exception { + // Expand the window for this stream back to the size of the initial window. 
+ int deltaWindowSize = initialStreamWindowSize - processedWindow; + try { + incrementFlowControlWindows(deltaWindowSize); + } catch (Throwable t) { + throw connectionError(INTERNAL_ERROR, t, + "Attempting to return too many bytes for stream %d", stream.id()); + } + + // Send a window update for the stream/connection. + frameWriter.writeWindowUpdate(ctx, stream.id(), deltaWindowSize, ctx.newPromise()); + } + } + + /** + * The local flow control state for a single stream that is not in a state where flow controlled frames cannot + * be exchanged. + */ + private static final FlowState REDUCED_FLOW_STATE = new FlowState() { + + @Override + public int windowSize() { + return 0; + } + + @Override + public int initialWindowSize() { + return 0; + } + + @Override + public void window(int initialWindowSize) { + throw new UnsupportedOperationException(); + } + + @Override + public void incrementInitialStreamWindow(int delta) { + // This operation needs to be supported during the initial settings exchange when + // the peer has not yet acknowledged this peer being activated. + } + + @Override + public boolean writeWindowUpdateIfNeeded() throws Http2Exception { + throw new UnsupportedOperationException(); + } + + @Override + public boolean consumeBytes(int numBytes) throws Http2Exception { + return false; + } + + @Override + public int unconsumedBytes() { + return 0; + } + + @Override + public float windowUpdateRatio() { + throw new UnsupportedOperationException(); + } + + @Override + public void windowUpdateRatio(float ratio) { + throw new UnsupportedOperationException(); + } + + @Override + public void receiveFlowControlledFrame(int dataLength) throws Http2Exception { + throw new UnsupportedOperationException(); + } + + @Override + public void incrementFlowControlWindows(int delta) throws Http2Exception { + // This operation needs to be supported during the initial settings exchange when + // the peer has not yet acknowledged this peer being activated. 
+ } + + @Override + public void endOfStream(boolean endOfStream) { + throw new UnsupportedOperationException(); + } + }; + + /** + * An abstraction which provides specific extensions used by local flow control. + */ + private interface FlowState { + + int windowSize(); + + int initialWindowSize(); + + void window(int initialWindowSize); + + /** + * Increment the initial window size for this stream. + * @param delta The amount to increase the initial window size by. + */ + void incrementInitialStreamWindow(int delta); + + /** + * Updates the flow control window for this stream if it is appropriate. + * + * @return true if {@code WINDOW_UPDATE} was written, false otherwise. + */ + boolean writeWindowUpdateIfNeeded() throws Http2Exception; + + /** + * Indicates that the application has consumed {@code numBytes} from the connection or stream and is + * ready to receive more data. + * + * @param numBytes the number of bytes to be returned to the flow control window. + * @return true if {@code WINDOW_UPDATE} was written, false otherwise. + * @throws Http2Exception + */ + boolean consumeBytes(int numBytes) throws Http2Exception; + + int unconsumedBytes(); + + float windowUpdateRatio(); + + void windowUpdateRatio(float ratio); + + /** + * A flow control event has occurred and we should decrement the amount of available bytes for this stream. + * @param dataLength The amount of data to for which this stream is no longer eligible to use for flow control. + * @throws Http2Exception If too much data is used relative to how much is available. + */ + void receiveFlowControlledFrame(int dataLength) throws Http2Exception; + + /** + * Increment the windows which are used to determine many bytes have been processed. + * @param delta The amount to increment the window by. + * @throws Http2Exception if integer overflow occurs on the window. 
+ */ + void incrementFlowControlWindows(int delta) throws Http2Exception; + + void endOfStream(boolean endOfStream); + } + + /** + * Provides a means to iterate over all active streams and increment the flow control windows. + */ + private final class WindowUpdateVisitor implements Http2StreamVisitor { + private CompositeStreamException compositeException; + private final int delta; + + WindowUpdateVisitor(int delta) { + this.delta = delta; + } + + @Override + public boolean visit(Http2Stream stream) throws Http2Exception { + try { + // Increment flow control window first so state will be consistent if overflow is detected. + FlowState state = state(stream); + state.incrementFlowControlWindows(delta); + state.incrementInitialStreamWindow(delta); + } catch (StreamException e) { + if (compositeException == null) { + compositeException = new CompositeStreamException(e.error(), 4); + } + compositeException.add(e); + } + return true; + } + + public void throwIfError() throws CompositeStreamException { + if (compositeException != null) { + throw compositeException; + } + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PingFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PingFrame.java new file mode 100644 index 0000000..f102bd1 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PingFrame.java @@ -0,0 +1,75 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +/** + * The default {@link Http2PingFrame} implementation. + */ +@UnstableApi +public class DefaultHttp2PingFrame implements Http2PingFrame { + + private final long content; + private final boolean ack; + + public DefaultHttp2PingFrame(long content) { + this(content, false); + } + + public DefaultHttp2PingFrame(long content, boolean ack) { + this.content = content; + this.ack = ack; + } + + @Override + public boolean ack() { + return ack; + } + + @Override + public String name() { + return "PING"; + } + + @Override + public long content() { + return content; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Http2PingFrame)) { + return false; + } + Http2PingFrame other = (Http2PingFrame) o; + return ack == other.ack() && content == other.content(); + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = hash * 31 + (ack ? 
1 : 0); + return hash; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(content=" + content + ", ack=" + ack + ')'; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PriorityFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PriorityFrame.java new file mode 100644 index 0000000..5e073d5 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PriorityFrame.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Default implementation of {@linkplain Http2PriorityFrame} + */ +@UnstableApi +public final class DefaultHttp2PriorityFrame extends AbstractHttp2StreamFrame implements Http2PriorityFrame { + + private final int streamDependency; + private final short weight; + private final boolean exclusive; + + public DefaultHttp2PriorityFrame(int streamDependency, short weight, boolean exclusive) { + this.streamDependency = streamDependency; + this.weight = weight; + this.exclusive = exclusive; + } + + @Override + public int streamDependency() { + return streamDependency; + } + + @Override + public short weight() { + return weight; + } + + @Override + public boolean exclusive() { + return exclusive; + } + + @Override + public DefaultHttp2PriorityFrame stream(Http2FrameStream stream) { + super.stream(stream); + return this; + } + + @Override + public String name() { + return "PRIORITY_FRAME"; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttp2PriorityFrame)) { + return false; + } + DefaultHttp2PriorityFrame other = (DefaultHttp2PriorityFrame) o; + boolean same = super.equals(other); + return same && streamDependency == other.streamDependency + && weight == other.weight && exclusive == other.exclusive; + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = hash * 31 + streamDependency; + hash = hash * 31 + weight; + hash = hash * 31 + (exclusive ? 
1 : 0); + return hash; + } + + @Override + public String toString() { + return "DefaultHttp2PriorityFrame(" + + "stream=" + stream() + + ", streamDependency=" + streamDependency + + ", weight=" + weight + + ", exclusive=" + exclusive + + ')'; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrame.java new file mode 100644 index 0000000..f9fd987 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrame.java @@ -0,0 +1,101 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Default implementation of {@link Http2PushPromiseFrame} + */ +@UnstableApi +public final class DefaultHttp2PushPromiseFrame implements Http2PushPromiseFrame { + + private Http2FrameStream pushStreamFrame; + private final Http2Headers http2Headers; + private Http2FrameStream streamFrame; + private final int padding; + private final int promisedStreamId; + + public DefaultHttp2PushPromiseFrame(Http2Headers http2Headers) { + this(http2Headers, 0); + } + + public DefaultHttp2PushPromiseFrame(Http2Headers http2Headers, int padding) { + this(http2Headers, padding, -1); + } + + DefaultHttp2PushPromiseFrame(Http2Headers http2Headers, int padding, int promisedStreamId) { + this.http2Headers = http2Headers; + this.padding = padding; + this.promisedStreamId = promisedStreamId; + } + + @Override + public Http2StreamFrame pushStream(Http2FrameStream stream) { + pushStreamFrame = stream; + return this; + } + + @Override + public Http2FrameStream pushStream() { + return pushStreamFrame; + } + + @Override + public Http2Headers http2Headers() { + return http2Headers; + } + + @Override + public int padding() { + return padding; + } + + @Override + public int promisedStreamId() { + if (pushStreamFrame != null) { + return pushStreamFrame.id(); + } else { + return promisedStreamId; + } + } + + @Override + public Http2PushPromiseFrame stream(Http2FrameStream stream) { + streamFrame = stream; + return this; + } + + @Override + public Http2FrameStream stream() { + return streamFrame; + } + + @Override + public String name() { + return "PUSH_PROMISE_FRAME"; + } + + @Override + public String toString() { + return "DefaultHttp2PushPromiseFrame{" + + "pushStreamFrame=" + pushStreamFrame + + ", http2Headers=" + http2Headers + + ", streamFrame=" + streamFrame + + ", padding=" + padding + + '}'; + } +} diff --git 
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.codec.http2;

import io.netty.channel.ChannelHandlerContext;
import io.netty.util.internal.UnstableApi;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.util.ArrayDeque;
import java.util.Deque;

import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT;
import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT;
import static io.netty.handler.codec.http2.Http2Error.FLOW_CONTROL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR;
import static io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED;
import static io.netty.handler.codec.http2.Http2Exception.streamError;
import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL;
import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;
import static java.lang.Math.max;
import static java.lang.Math.min;

/**
 * Basic implementation of {@link Http2RemoteFlowController}.
 * <p>
 * This class is NOT thread safe. The assumption is all methods must be invoked from a single thread.
 * Typically this thread is the event loop thread for the {@link ChannelHandlerContext} managed by this class.
 */
@UnstableApi
public class DefaultHttp2RemoteFlowController implements Http2RemoteFlowController {
    private static final InternalLogger logger =
            InternalLoggerFactory.getInstance(DefaultHttp2RemoteFlowController.class);
    private static final int MIN_WRITABLE_CHUNK = 32 * 1024;
    private final Http2Connection connection;
    // Property key under which each stream stores its per-stream FlowState.
    private final Http2Connection.PropertyKey stateKey;
    private final StreamByteDistributor streamByteDistributor;
    // Flow state for the connection itself (stream 0), i.e. the connection-level window.
    private final FlowState connectionState;
    private int initialWindowSize = DEFAULT_WINDOW_SIZE;
    // Writability monitor; replaced by listener(..) - never null after the constructor runs.
    private WritabilityMonitor monitor;
    // May be null until channelHandlerContext(..) is invoked; frames queued before that are held.
    private ChannelHandlerContext ctx;

    public DefaultHttp2RemoteFlowController(Http2Connection connection) {
        this(connection, (Listener) null);
    }

    public DefaultHttp2RemoteFlowController(Http2Connection connection,
                                            StreamByteDistributor streamByteDistributor) {
        this(connection, streamByteDistributor, null);
    }

    public DefaultHttp2RemoteFlowController(Http2Connection connection, final Listener listener) {
        this(connection, new WeightedFairQueueByteDistributor(connection), listener);
    }

    public DefaultHttp2RemoteFlowController(Http2Connection connection,
                                            StreamByteDistributor streamByteDistributor,
                                            final Listener listener) {
        this.connection = checkNotNull(connection, "connection");
        this.streamByteDistributor = checkNotNull(streamByteDistributor, "streamWriteDistributor");

        // Add a flow state for the connection.
        stateKey = connection.newKey();
        connectionState = new FlowState(connection.connectionStream());
        connection.connectionStream().setProperty(stateKey, connectionState);

        // Monitor may depend upon connectionState, and so initialize after connectionState
        listener(listener);
        monitor.windowSize(connectionState, initialWindowSize);

        // Register for notification of new streams.
        connection.addListener(new Http2ConnectionAdapter() {
            @Override
            public void onStreamAdded(Http2Stream stream) {
                // If the stream state is not open then the stream is not yet eligible for flow controlled frames and
                // only requires the ReducedFlowState. Otherwise the full amount of memory is required.
                stream.setProperty(stateKey, new FlowState(stream));
            }

            @Override
            public void onStreamActive(Http2Stream stream) {
                // If the object was previously created, but later activated then we have to ensure the proper
                // initialWindowSize is used.
                monitor.windowSize(state(stream), initialWindowSize);
            }

            @Override
            public void onStreamClosed(Http2Stream stream) {
                // Any pending frames can never be written, cancel and
                // write errors for any pending frames.
                state(stream).cancel(STREAM_CLOSED, null);
            }

            @Override
            public void onStreamHalfClosed(Http2Stream stream) {
                if (HALF_CLOSED_LOCAL == stream.state()) {
                    /*
                     * When this method is called there should not be any
                     * pending frames left if the API is used correctly. However,
                     * it is possible that a erroneous application can sneak
                     * in a frame even after having already written a frame with the
                     * END_STREAM flag set, as the stream state might not transition
                     * immediately to HALF_CLOSED_LOCAL / CLOSED due to flow control
                     * delaying the write.
                     *
                     * This is to cancel any such illegal writes.
                     */
                    state(stream).cancel(STREAM_CLOSED, null);
                }
            }
        });
    }

    /**
     * {@inheritDoc}
     * <p>
     * Any queued {@link FlowControlled} objects will be sent.
     */
    @Override
    public void channelHandlerContext(ChannelHandlerContext ctx) throws Http2Exception {
        this.ctx = checkNotNull(ctx, "ctx");

        // Writing the pending bytes will not check writability change and instead a writability change notification
        // to be provided by an explicit call.
        channelWritabilityChanged();

        // Don't worry about cleaning up queued frames here if ctx is null. It is expected that all streams will be
        // closed and the queue cleanup will occur when the stream state transitions occur.

        // If any frames have been queued up, we should send them now that we have a channel context.
        if (isChannelWritable()) {
            writePendingBytes();
        }
    }

    @Override
    public ChannelHandlerContext channelHandlerContext() {
        return ctx;
    }

    @Override
    public void initialWindowSize(int newWindowSize) throws Http2Exception {
        assert ctx == null || ctx.executor().inEventLoop();
        monitor.initialWindowSize(newWindowSize);
    }

    @Override
    public int initialWindowSize() {
        return initialWindowSize;
    }

    @Override
    public int windowSize(Http2Stream stream) {
        return state(stream).windowSize();
    }

    @Override
    public boolean isWritable(Http2Stream stream) {
        return monitor.isWritable(state(stream));
    }

    @Override
    public void channelWritabilityChanged() throws Http2Exception {
        monitor.channelWritabilityChange();
    }

    @Override
    public void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive) {
        // It is assumed there are all validated at a higher level. For example in the Http2FrameReader.
        assert weight >= MIN_WEIGHT && weight <= MAX_WEIGHT : "Invalid weight";
        assert childStreamId != parentStreamId : "A stream cannot depend on itself";
        assert childStreamId > 0 && parentStreamId >= 0 : "childStreamId must be > 0. parentStreamId must be >= 0.";

        streamByteDistributor.updateDependencyTree(childStreamId, parentStreamId, weight, exclusive);
    }

    private boolean isChannelWritable() {
        return ctx != null && isChannelWritable0();
    }

    private boolean isChannelWritable0() {
        return ctx.channel().isWritable();
    }

    // Installs the writability monitor. A null listener selects the no-op base monitor.
    @Override
    public void listener(Listener listener) {
        monitor = listener == null ? new WritabilityMonitor() : new ListenerWritabilityMonitor(listener);
    }

    @Override
    public void incrementWindowSize(Http2Stream stream, int delta) throws Http2Exception {
        assert ctx == null || ctx.executor().inEventLoop();
        monitor.incrementWindowSize(state(stream), delta);
    }

    @Override
    public void addFlowControlled(Http2Stream stream, FlowControlled frame) {
        // The context can be null assuming the frame will be queued and send later when the context is set.
        assert ctx == null || ctx.executor().inEventLoop();
        checkNotNull(frame, "frame");
        try {
            monitor.enqueueFrame(state(stream), frame);
        } catch (Throwable t) {
            frame.error(ctx, t);
        }
    }

    @Override
    public boolean hasFlowControlled(Http2Stream stream) {
        return state(stream).hasFrame();
    }

    private FlowState state(Http2Stream stream) {
        return (FlowState) stream.getProperty(stateKey);
    }

    /**
     * Returns the flow control window for the entire connection.
     */
    private int connectionWindowSize() {
        return connectionState.windowSize();
    }

    private int minUsableChannelBytes() {
        // The current allocation algorithm values "fairness" and doesn't give any consideration to "goodput". It
        // is possible that 1 byte will be allocated to many streams. In an effort to try to make "goodput"
        // reasonable with the current allocation algorithm we have this "cheap" check up front to ensure there is
        // an "adequate" amount of connection window before allocation is attempted. This is not foolproof as if the
        // number of streams is >= this minimal number then we may still have the issue, but the idea is to narrow the
        // circumstances in which this can happen without rewriting the allocation algorithm.
        return max(ctx.channel().config().getWriteBufferLowWaterMark(), MIN_WRITABLE_CHUNK);
    }

    private int maxUsableChannelBytes() {
        // If the channel isWritable, allow at least minUsableChannelBytes.
        int channelWritableBytes = (int) min(Integer.MAX_VALUE, ctx.channel().bytesBeforeUnwritable());
        int usableBytes = channelWritableBytes > 0 ? max(channelWritableBytes, minUsableChannelBytes()) : 0;

        // Clip the usable bytes by the connection window.
        return min(connectionState.windowSize(), usableBytes);
    }

    /**
     * The amount of bytes that can be supported by underlying {@link io.netty.channel.Channel} without
     * queuing "too-much".
     */
    private int writableBytes() {
        return min(connectionWindowSize(), maxUsableChannelBytes());
    }

    @Override
    public void writePendingBytes() throws Http2Exception {
        monitor.writePendingBytes();
    }

    /**
     * The remote flow control state for a single stream.
     */
    private final class FlowState implements StreamByteDistributor.StreamState {
        private final Http2Stream stream;
        private final Deque<FlowControlled> pendingWriteQueue;
        private int window;
        private long pendingBytes;
        private boolean markedWritable;

        /**
         * Set to true while a frame is being written, false otherwise.
         */
        private boolean writing;
        /**
         * Set to true if cancel() was called.
         */
        private boolean cancelled;

        FlowState(Http2Stream stream) {
            this.stream = stream;
            pendingWriteQueue = new ArrayDeque<FlowControlled>(2);
        }

        /**
         * Determine if the stream associated with this object is writable.
         * @return {@code true} if the stream associated with this object is writable.
         */
        boolean isWritable() {
            return windowSize() > pendingBytes() && !cancelled;
        }

        /**
         * The stream this state is associated with.
         */
        @Override
        public Http2Stream stream() {
            return stream;
        }

        /**
         * Returns the parameter from the last call to {@link #markedWritability(boolean)}.
         */
        boolean markedWritability() {
            return markedWritable;
        }

        /**
         * Save the state of writability.
         */
        void markedWritability(boolean isWritable) {
            this.markedWritable = isWritable;
        }

        @Override
        public int windowSize() {
            return window;
        }

        /**
         * Reset the window size for this stream.
         */
        void windowSize(int initialWindowSize) {
            window = initialWindowSize;
        }

        /**
         * Write the allocated bytes for this stream.
         * @return the number of bytes written for a stream or {@code -1} if no write occurred.
         */
        int writeAllocatedBytes(int allocated) {
            final int initialAllocated = allocated;
            int writtenBytes;
            // In case an exception is thrown we want to remember it and pass it to cancel(Throwable).
            Throwable cause = null;
            FlowControlled frame;
            try {
                assert !writing;
                writing = true;

                // Write the remainder of frames that we are allowed to
                boolean writeOccurred = false;
                while (!cancelled && (frame = peek()) != null) {
                    int maxBytes = min(allocated, writableWindow());
                    if (maxBytes <= 0 && frame.size() > 0) {
                        // The frame still has data, but the amount of allocated bytes has been exhausted.
                        // Don't write needless empty frames.
                        break;
                    }
                    writeOccurred = true;
                    int initialFrameSize = frame.size();
                    try {
                        frame.write(ctx, max(0, maxBytes));
                        if (frame.size() == 0) {
                            // This frame has been fully written, remove this frame and notify it.
                            // Since we remove this frame first, we're guaranteed that its error
                            // method will not be called when we call cancel.
                            pendingWriteQueue.remove();
                            frame.writeComplete();
                        }
                    } finally {
                        // Decrement allocated by how much was actually written.
                        allocated -= initialFrameSize - frame.size();
                    }
                }

                if (!writeOccurred) {
                    // Either there was no frame, or the amount of allocated bytes has been exhausted.
                    return -1;
                }

            } catch (Throwable t) {
                // Mark the state as cancelled, we'll clear the pending queue via cancel() below.
                cancelled = true;
                cause = t;
            } finally {
                writing = false;
                // Make sure we always decrement the flow control windows
                // by the bytes written.
                writtenBytes = initialAllocated - allocated;

                decrementPendingBytes(writtenBytes, false);
                decrementFlowControlWindow(writtenBytes);

                // If a cancellation occurred while writing, call cancel again to
                // clear and error all of the pending writes.
                if (cancelled) {
                    cancel(INTERNAL_ERROR, cause);
                }
            }
            return writtenBytes;
        }

        /**
         * Increments the flow control window for this stream by the given delta and returns the new value.
         */
        int incrementStreamWindow(int delta) throws Http2Exception {
            if (delta > 0 && Integer.MAX_VALUE - delta < window) {
                throw streamError(stream.id(), FLOW_CONTROL_ERROR,
                        "Window size overflow for stream: %d", stream.id());
            }
            window += delta;

            streamByteDistributor.updateStreamableBytes(this);
            return window;
        }

        /**
         * Returns the maximum writable window (minimum of the stream and connection windows).
         */
        private int writableWindow() {
            return min(window, connectionWindowSize());
        }

        @Override
        public long pendingBytes() {
            return pendingBytes;
        }

        /**
         * Adds the {@code frame} to the pending queue and increments the pending byte count.
         */
        void enqueueFrame(FlowControlled frame) {
            FlowControlled last = pendingWriteQueue.peekLast();
            if (last == null) {
                enqueueFrameWithoutMerge(frame);
                return;
            }

            // Try to coalesce the new frame into the tail of the queue before enqueuing it separately.
            int lastSize = last.size();
            if (last.merge(ctx, frame)) {
                incrementPendingBytes(last.size() - lastSize, true);
                return;
            }
            enqueueFrameWithoutMerge(frame);
        }

        private void enqueueFrameWithoutMerge(FlowControlled frame) {
            pendingWriteQueue.offer(frame);
            // This must be called after adding to the queue in order so that hasFrame() is
            // updated before updating the stream state.
            incrementPendingBytes(frame.size(), true);
        }

        @Override
        public boolean hasFrame() {
            return !pendingWriteQueue.isEmpty();
        }

        /**
         * Returns the head of the pending queue, or {@code null} if empty.
         */
        private FlowControlled peek() {
            return pendingWriteQueue.peek();
        }

        /**
         * Clears the pending queue and writes errors for each remaining frame.
         * @param error the {@link Http2Error} to use.
         * @param cause the {@link Throwable} that caused this method to be invoked.
         */
        void cancel(Http2Error error, Throwable cause) {
            cancelled = true;
            // Ensure that the queue can't be modified while we are writing.
            if (writing) {
                return;
            }

            FlowControlled frame = pendingWriteQueue.poll();
            if (frame != null) {
                // Only create exception once and reuse to reduce overhead of filling in the stacktrace.
                final Http2Exception exception = streamError(stream.id(), error, cause,
                        "Stream closed before write could take place");
                do {
                    writeError(frame, exception);
                    frame = pendingWriteQueue.poll();
                } while (frame != null);
            }

            streamByteDistributor.updateStreamableBytes(this);

            monitor.stateCancelled(this);
        }

        /**
         * Increments the number of pending bytes for this node and optionally updates the
         * {@link StreamByteDistributor}.
         */
        private void incrementPendingBytes(int numBytes, boolean updateStreamableBytes) {
            pendingBytes += numBytes;
            monitor.incrementPendingBytes(numBytes);
            if (updateStreamableBytes) {
                streamByteDistributor.updateStreamableBytes(this);
            }
        }

        /**
         * If this frame is in the pending queue, decrements the number of pending bytes for the stream.
         */
        private void decrementPendingBytes(int bytes, boolean updateStreamableBytes) {
            incrementPendingBytes(-bytes, updateStreamableBytes);
        }

        /**
         * Decrement the per stream and connection flow control window by {@code bytes}.
         */
        private void decrementFlowControlWindow(int bytes) {
            try {
                int negativeBytes = -bytes;
                connectionState.incrementStreamWindow(negativeBytes);
                incrementStreamWindow(negativeBytes);
            } catch (Http2Exception e) {
                // Should never get here since we're decrementing.
                throw new IllegalStateException("Invalid window state when writing frame: " + e.getMessage(), e);
            }
        }

        /**
         * Discards this {@link FlowControlled}, writing an error. If this frame is in the pending queue,
         * the unwritten bytes are removed from this branch of the priority tree.
         */
        private void writeError(FlowControlled frame, Http2Exception cause) {
            assert ctx != null;
            decrementPendingBytes(frame.size(), true);
            frame.error(ctx, cause);
        }
    }

    /**
     * Abstract class which provides common functionality for writability monitor implementations.
     */
    private class WritabilityMonitor implements StreamByteDistributor.Writer {
        // Guards against reentrant calls to writePendingBytes() during byte distribution.
        private boolean inWritePendingBytes;
        private long totalPendingBytes;

        @Override
        public final void write(Http2Stream stream, int numBytes) {
            state(stream).writeAllocatedBytes(numBytes);
        }

        /**
         * Called when the writability of the underlying channel changes.
         * @throws Http2Exception If a write occurs and an exception happens in the write operation.
         */
        void channelWritabilityChange() throws Http2Exception { }

        /**
         * Called when the state is cancelled.
         * @param state the state that was cancelled.
         */
        void stateCancelled(FlowState state) { }

        /**
         * Set the initial window size for {@code state}.
         * @param state the state to change the initial window size for.
         * @param initialWindowSize the size of the window in bytes.
         */
        void windowSize(FlowState state, int initialWindowSize) {
            state.windowSize(initialWindowSize);
        }

        /**
         * Increment the window size for a particular stream.
         * @param state the state associated with the stream whose window is being incremented.
         * @param delta The amount to increment by.
         * @throws Http2Exception If this operation overflows the window for {@code state}.
         */
        void incrementWindowSize(FlowState state, int delta) throws Http2Exception {
            state.incrementStreamWindow(delta);
        }

        /**
         * Add a frame to be sent via flow control.
         * @param state The state associated with the stream which the {@code frame} is associated with.
         * @param frame the frame to enqueue.
         * @throws Http2Exception If a writability error occurs.
         */
        void enqueueFrame(FlowState state, FlowControlled frame) throws Http2Exception {
            state.enqueueFrame(frame);
        }

        /**
         * Increment the total amount of pending bytes for all streams. When any stream's pending bytes changes
         * method should be called.
         * @param delta The amount to increment by.
         */
        final void incrementPendingBytes(int delta) {
            totalPendingBytes += delta;

            // Notification of writibilty change should be delayed until the end of the top level event.
            // This is to ensure the flow controller is more consistent state before calling external listener methods.
        }

        /**
         * Determine if the stream associated with {@code state} is writable.
         * @param state The state which is associated with the stream to test writability for.
         * @return {@code true} if {@link FlowState#stream()} is writable. {@code false} otherwise.
         */
        final boolean isWritable(FlowState state) {
            return isWritableConnection() && state.isWritable();
        }

        final void writePendingBytes() throws Http2Exception {
            // Reentry is not permitted during the byte distribution process. It may lead to undesirable distribution of
            // bytes and even infinite loops. We protect against reentry and make sure each call has an opportunity to
            // cause a distribution to occur. This may be useful for example if the channel's writability changes from
            // Writable -> Not Writable (because we are writing) -> Writable (because the user flushed to make more room
            // in the channel outbound buffer).
            if (inWritePendingBytes) {
                return;
            }
            inWritePendingBytes = true;
            try {
                int bytesToWrite = writableBytes();
                // Make sure we always write at least once, regardless if we have bytesToWrite or not.
                // This ensures that zero-length frames will always be written.
                for (;;) {
                    if (!streamByteDistributor.distribute(bytesToWrite, this) ||
                        (bytesToWrite = writableBytes()) <= 0 ||
                        !isChannelWritable0()) {
                        break;
                    }
                }
            } finally {
                inWritePendingBytes = false;
            }
        }

        void initialWindowSize(int newWindowSize) throws Http2Exception {
            checkPositiveOrZero(newWindowSize, "newWindowSize");

            final int delta = newWindowSize - initialWindowSize;
            initialWindowSize = newWindowSize;
            connection.forEachActiveStream(new Http2StreamVisitor() {
                @Override
                public boolean visit(Http2Stream stream) throws Http2Exception {
                    state(stream).incrementStreamWindow(delta);
                    return true;
                }
            });

            if (delta > 0 && isChannelWritable()) {
                // The window size increased, send any pending frames for all streams.
                writePendingBytes();
            }
        }

        final boolean isWritableConnection() {
            return connectionState.windowSize() - totalPendingBytes > 0 && isChannelWritable();
        }
    }

    /**
     * Writability of a {@code stream} is calculated using the following:
     * <pre>
     * Connection Window - Total Queued Bytes > 0 &&
     * Stream Window - Bytes Queued for Stream > 0 &&
     * isChannelWritable()
     * </pre>
     */
    private final class ListenerWritabilityMonitor extends WritabilityMonitor implements Http2StreamVisitor {
        private final Listener listener;

        ListenerWritabilityMonitor(Listener listener) {
            this.listener = listener;
        }

        @Override
        public boolean visit(Http2Stream stream) throws Http2Exception {
            FlowState state = state(stream);
            if (isWritable(state) != state.markedWritability()) {
                notifyWritabilityChanged(state);
            }
            return true;
        }

        @Override
        void windowSize(FlowState state, int initialWindowSize) {
            super.windowSize(state, initialWindowSize);
            try {
                checkStateWritability(state);
            } catch (Http2Exception e) {
                throw new RuntimeException("Caught unexpected exception from window", e);
            }
        }

        @Override
        void incrementWindowSize(FlowState state, int delta) throws Http2Exception {
            super.incrementWindowSize(state, delta);
            checkStateWritability(state);
        }

        @Override
        void initialWindowSize(int newWindowSize) throws Http2Exception {
            super.initialWindowSize(newWindowSize);
            if (isWritableConnection()) {
                // If the write operation does not occur we still need to check all streams because they
                // may have transitioned from writable to not writable.
                checkAllWritabilityChanged();
            }
        }

        @Override
        void enqueueFrame(FlowState state, FlowControlled frame) throws Http2Exception {
            super.enqueueFrame(state, frame);
            checkConnectionThenStreamWritabilityChanged(state);
        }

        @Override
        void stateCancelled(FlowState state) {
            try {
                checkConnectionThenStreamWritabilityChanged(state);
            } catch (Http2Exception e) {
                throw new RuntimeException("Caught unexpected exception from checkAllWritabilityChanged", e);
            }
        }

        @Override
        void channelWritabilityChange() throws Http2Exception {
            if (connectionState.markedWritability() != isChannelWritable()) {
                checkAllWritabilityChanged();
            }
        }

        private void checkStateWritability(FlowState state) throws Http2Exception {
            if (isWritable(state) != state.markedWritability()) {
                if (state == connectionState) {
                    checkAllWritabilityChanged();
                } else {
                    notifyWritabilityChanged(state);
                }
            }
        }

        private void notifyWritabilityChanged(FlowState state) {
            state.markedWritability(!state.markedWritability());
            try {
                listener.writabilityChanged(state.stream);
            } catch (Throwable cause) {
                logger.error("Caught Throwable from listener.writabilityChanged", cause);
            }
        }

        private void checkConnectionThenStreamWritabilityChanged(FlowState state) throws Http2Exception {
            // It is possible that the connection window and/or the individual stream writability could change.
            if (isWritableConnection() != connectionState.markedWritability()) {
                checkAllWritabilityChanged();
            } else if (isWritable(state) != state.markedWritability()) {
                notifyWritabilityChanged(state);
            }
        }

        private void checkAllWritabilityChanged() throws Http2Exception {
            // Make sure we mark that we have notified as a result of this change.
            connectionState.markedWritability(isWritableConnection());
            connection.forEachActiveStream(this);
        }
    }
}
+ * + * @param errorCode the reason for reset + */ + public DefaultHttp2ResetFrame(long errorCode) { + this.errorCode = errorCode; + } + + @Override + public DefaultHttp2ResetFrame stream(Http2FrameStream stream) { + super.stream(stream); + return this; + } + + @Override + public String name() { + return "RST_STREAM"; + } + + @Override + public long errorCode() { + return errorCode; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(stream=" + stream() + ", errorCode=" + errorCode + ')'; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof DefaultHttp2ResetFrame)) { + return false; + } + DefaultHttp2ResetFrame other = (DefaultHttp2ResetFrame) o; + return super.equals(o) && errorCode == other.errorCode; + } + + @Override + public int hashCode() { + int hash = super.hashCode(); + hash = hash * 31 + (int) (errorCode ^ errorCode >>> 32); + return hash; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2SettingsAckFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2SettingsAckFrame.java new file mode 100644 index 0000000..ea1c9ba --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2SettingsAckFrame.java @@ -0,0 +1,33 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.http2;

import io.netty.util.internal.StringUtil;

/**
 * The default {@link Http2SettingsAckFrame} implementation.
 */
final class DefaultHttp2SettingsAckFrame implements Http2SettingsAckFrame {
    // Stateless: a SETTINGS ack carries no payload, so every instance is interchangeable.
    @Override
    public String name() {
        return "SETTINGS(ACK)";
    }

    @Override
    public String toString() {
        return StringUtil.simpleClassName(this);
    }
}
+ */ +@UnstableApi +public class DefaultHttp2SettingsFrame implements Http2SettingsFrame { + + private final Http2Settings settings; + + public DefaultHttp2SettingsFrame(Http2Settings settings) { + this.settings = ObjectUtil.checkNotNull(settings, "settings"); + } + + @Override + public Http2Settings settings() { + return settings; + } + + @Override + public String name() { + return "SETTINGS"; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Http2SettingsFrame)) { + return false; + } + Http2SettingsFrame other = (Http2SettingsFrame) o; + return settings.equals(other.settings()); + } + + @Override + public int hashCode() { + return settings.hashCode(); + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(settings=" + settings + ')'; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2UnknownFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2UnknownFrame.java new file mode 100644 index 0000000..66a3c6d --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DefaultHttp2UnknownFrame.java @@ -0,0 +1,140 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.http2;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.DefaultByteBufHolder;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.StringUtil;
import io.netty.util.internal.UnstableApi;

/**
 * Default {@link Http2UnknownFrame} implementation: carries the raw type code,
 * flags and payload of a frame whose type the codec did not recognise.
 * Payload reference counting is delegated to {@link DefaultByteBufHolder}.
 */
@UnstableApi
public final class DefaultHttp2UnknownFrame extends DefaultByteBufHolder implements Http2UnknownFrame {
    // 8-bit HTTP/2 frame type code that was not recognised by the codec.
    private final byte frameType;
    private final Http2Flags flags;
    // Stream this frame belongs to; null until assigned via stream(...).
    private Http2FrameStream stream;

    /**
     * Creates a frame with an empty payload.
     *
     * @param frameType the unrecognised frame type code
     * @param flags the flags carried by the frame
     */
    public DefaultHttp2UnknownFrame(byte frameType, Http2Flags flags) {
        this(frameType, flags, Unpooled.EMPTY_BUFFER);
    }

    /**
     * Creates a frame wrapping the given payload.
     *
     * @param frameType the unrecognised frame type code
     * @param flags the flags carried by the frame
     * @param data the frame payload; ownership is transferred to this holder
     */
    public DefaultHttp2UnknownFrame(byte frameType, Http2Flags flags, ByteBuf data) {
        super(data);
        this.frameType = frameType;
        this.flags = flags;
    }

    @Override
    public Http2FrameStream stream() {
        return stream;
    }

    @Override
    public DefaultHttp2UnknownFrame stream(Http2FrameStream stream) {
        this.stream = stream;
        return this;
    }

    @Override
    public byte frameType() {
        return frameType;
    }

    @Override
    public Http2Flags flags() {
        return flags;
    }

    @Override
    public String name() {
        return "UNKNOWN";
    }

    // The copy/duplicate/replace family all funnel through replace(ByteBuf) so
    // frameType, flags and stream are propagated to the new instance.
    @Override
    public DefaultHttp2UnknownFrame copy() {
        return replace(content().copy());
    }

    @Override
    public DefaultHttp2UnknownFrame duplicate() {
        return replace(content().duplicate());
    }

    @Override
    public DefaultHttp2UnknownFrame retainedDuplicate() {
        return replace(content().retainedDuplicate());
    }

    @Override
    public DefaultHttp2UnknownFrame replace(ByteBuf content) {
        return new DefaultHttp2UnknownFrame(frameType, flags, content).stream(stream);
    }

    @Override
    public DefaultHttp2UnknownFrame retain() {
        super.retain();
        return this;
    }

    @Override
    public DefaultHttp2UnknownFrame retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public String toString() {
        return StringUtil.simpleClassName(this) + "(frameType=" + frameType + ", stream=" + stream +
                ", flags=" + flags + ", content=" + contentToString() + ')';
    }

    @Override
    public DefaultHttp2UnknownFrame touch() {
        super.touch();
        return this;
    }

    @Override
    public DefaultHttp2UnknownFrame touch(Object hint) {
        super.touch(hint);
        return this;
    }

    @Override
    public boolean equals(Object o) {
        if (!(o instanceof DefaultHttp2UnknownFrame)) {
            return false;
        }
        DefaultHttp2UnknownFrame other = (DefaultHttp2UnknownFrame) o;
        Http2FrameStream otherStream = other.stream();
        // Null-safe stream comparison: equal when both are null or equals() holds;
        // super.equals compares the payload content.
        return (stream == otherStream || otherStream != null && otherStream.equals(stream))
                && flags.equals(other.flags())
                && frameType == other.frameType()
                && super.equals(other);
    }

    @Override
    public int hashCode() {
        int hash = super.hashCode();
        hash = hash * 31 + frameType;
        hash = hash * 31 + flags.hashCode();
        // stream may be unset; only mix it in when present.
        if (stream != null) {
            hash = hash * 31 + stream.hashCode();
        }

        return hash;
    }
}
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +/** + * The default {@link Http2WindowUpdateFrame} implementation. + */ +@UnstableApi +public class DefaultHttp2WindowUpdateFrame extends AbstractHttp2StreamFrame implements Http2WindowUpdateFrame { + + private final int windowUpdateIncrement; + + public DefaultHttp2WindowUpdateFrame(int windowUpdateIncrement) { + this.windowUpdateIncrement = windowUpdateIncrement; + } + + @Override + public DefaultHttp2WindowUpdateFrame stream(Http2FrameStream stream) { + super.stream(stream); + return this; + } + + @Override + public String name() { + return "WINDOW_UPDATE"; + } + + @Override + public int windowSizeIncrement() { + return windowUpdateIncrement; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + + "(stream=" + stream() + ", windowUpdateIncrement=" + windowUpdateIncrement + ')'; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java new file mode 100644 index 0000000..6c59189 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/DelegatingDecompressorFrameListener.java @@ -0,0 +1,435 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.compression.BrotliDecoder; +import io.netty.handler.codec.compression.ZlibCodecFactory; +import io.netty.handler.codec.compression.ZlibWrapper; +import io.netty.handler.codec.compression.SnappyFrameDecoder; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_ENCODING; +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.handler.codec.http.HttpHeaderValues.BR; +import static io.netty.handler.codec.http.HttpHeaderValues.DEFLATE; +import static io.netty.handler.codec.http.HttpHeaderValues.GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.IDENTITY; +import static io.netty.handler.codec.http.HttpHeaderValues.X_DEFLATE; +import static io.netty.handler.codec.http.HttpHeaderValues.X_GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.SNAPPY; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * An HTTP2 frame listener that will decompress data frames according to the {@code content-encoding} header for each + * stream. The decompression provided by this class will be applied to the data for the entire stream. 
 */
@UnstableApi
public class DelegatingDecompressorFrameListener extends Http2FrameListenerDecorator {

    private final Http2Connection connection;
    // When true, 'deflate' is treated strictly as ZLIB; otherwise raw deflate is tolerated too.
    private final boolean strict;
    // Set once the local flow controller has been wrapped in ConsumedBytesConverter.
    private boolean flowControllerInitialized;
    // Per-stream property slot that stores this listener's Http2Decompressor state.
    private final Http2Connection.PropertyKey propertyKey;

    /**
     * Creates a listener that strictly interprets the {@code deflate} encoding as ZLIB.
     *
     * @param connection the HTTP/2 connection whose streams are decompressed
     * @param listener the listener that receives the decompressed frames
     */
    public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener) {
        this(connection, listener, true);
    }

    /**
     * Creates a listener.
     *
     * @param connection the HTTP/2 connection whose streams are decompressed
     * @param listener the listener that receives the decompressed frames
     * @param strict if {@code true}, 'deflate' means ZLIB only; if {@code false},
     *               non-wrapped deflate data is also accepted
     */
    public DelegatingDecompressorFrameListener(Http2Connection connection, Http2FrameListener listener,
            boolean strict) {
        super(listener);
        this.connection = connection;
        this.strict = strict;

        propertyKey = connection.newKey();
        // Release decompressor resources when a stream goes away.
        connection.addListener(new Http2ConnectionAdapter() {
            @Override
            public void onStreamRemoved(Http2Stream stream) {
                final Http2Decompressor decompressor = decompressor(stream);
                if (decompressor != null) {
                    cleanup(decompressor);
                }
            }
        });
    }

    /**
     * Decompresses the DATA frame (if a decompressor was set up for the stream) and
     * forwards the decompressed payload(s) to the delegate listener.
     *
     * @return the number of compressed bytes already returned to flow control here,
     *         or 0 when consumption was forwarded through the flow controller
     * @throws Http2Exception if decompression or the delegate fails
     */
    @Override
    public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream)
            throws Http2Exception {
        final Http2Stream stream = connection.stream(streamId);
        final Http2Decompressor decompressor = decompressor(stream);
        if (decompressor == null) {
            // The decompressor may be null if no compatible encoding type was found in this stream's headers
            return listener.onDataRead(ctx, streamId, data, padding, endOfStream);
        }

        final EmbeddedChannel channel = decompressor.decompressor();
        final int compressedBytes = data.readableBytes() + padding;
        decompressor.incrementCompressedBytes(compressedBytes);
        try {
            // call retain here as it will call release after its written to the channel
            channel.writeInbound(data.retain());
            ByteBuf buf = nextReadableBuf(channel);
            if (buf == null && endOfStream && channel.finish()) {
                buf = nextReadableBuf(channel);
            }
            if (buf == null) {
                if (endOfStream) {
                    listener.onDataRead(ctx, streamId, Unpooled.EMPTY_BUFFER, padding, true);
                }
                // No new decompressed data was extracted from the compressed data. This means the application could
                // not be provided with data and thus could not return how many bytes were processed. We will assume
                // there is more data coming which will complete the decompression block. To allow for more data we
                // return all bytes to the flow control window (so the peer can send more data).
                decompressor.incrementDecompressedBytes(compressedBytes);
                return compressedBytes;
            }
            try {
                Http2LocalFlowController flowController = connection.local().flowController();
                decompressor.incrementDecompressedBytes(padding);
                for (;;) {
                    ByteBuf nextBuf = nextReadableBuf(channel);
                    boolean decompressedEndOfStream = nextBuf == null && endOfStream;
                    if (decompressedEndOfStream && channel.finish()) {
                        nextBuf = nextReadableBuf(channel);
                        decompressedEndOfStream = nextBuf == null;
                    }

                    decompressor.incrementDecompressedBytes(buf.readableBytes());
                    // Immediately return the bytes back to the flow controller. ConsumedBytesConverter will convert
                    // from the decompressed amount which the user knows about to the compressed amount which flow
                    // control knows about.
                    flowController.consumeBytes(stream,
                            listener.onDataRead(ctx, streamId, buf, padding, decompressedEndOfStream));
                    if (nextBuf == null) {
                        break;
                    }

                    padding = 0; // Padding is only communicated once on the first iteration.
                    buf.release();
                    buf = nextBuf;
                }
                // We consume bytes each time we call the listener to ensure if multiple frames are decompressed
                // that the bytes are accounted for immediately. Otherwise the user may see an inconsistent state of
                // flow control.
                return 0;
            } finally {
                buf.release();
            }
        } catch (Http2Exception e) {
            throw e;
        } catch (Throwable t) {
            throw streamError(stream.id(), INTERNAL_ERROR, t,
                    "Decompressor error detected while delegating data read on streamId %d", stream.id());
        }
    }

    @Override
    public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
            boolean endStream) throws Http2Exception {
        // Set up (or skip) decompression for the stream before delegating.
        initDecompressor(ctx, streamId, headers, endStream);
        listener.onHeadersRead(ctx, streamId, headers, padding, endStream);
    }

    @Override
    public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency,
            short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception {
        // Set up (or skip) decompression for the stream before delegating.
        initDecompressor(ctx, streamId, headers, endStream);
        listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream);
    }

    /**
     * Returns a new {@link EmbeddedChannel} that decodes the HTTP2 message content encoded in the specified
     * {@code contentEncoding}.
     *
     * @param contentEncoding the value of the {@code content-encoding} header
     * @return a new {@link ByteToMessageDecoder} if the specified encoding is supported. {@code null} otherwise
     * (alternatively, you can throw a {@link Http2Exception} to block unknown encoding).
     * @throws Http2Exception If the specified encoding is not supported and warrants an exception
     */
    protected EmbeddedChannel newContentDecompressor(final ChannelHandlerContext ctx, CharSequence contentEncoding)
            throws Http2Exception {
        if (GZIP.contentEqualsIgnoreCase(contentEncoding) || X_GZIP.contentEqualsIgnoreCase(contentEncoding)) {
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(ZlibWrapper.GZIP));
        }
        if (DEFLATE.contentEqualsIgnoreCase(contentEncoding) || X_DEFLATE.contentEqualsIgnoreCase(contentEncoding)) {
            final ZlibWrapper wrapper = strict ? ZlibWrapper.ZLIB : ZlibWrapper.ZLIB_OR_NONE;
            // To be strict, 'deflate' means ZLIB, but some servers were not implemented correctly.
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), ZlibCodecFactory.newZlibDecoder(wrapper));
        }
        if (Brotli.isAvailable() && BR.contentEqualsIgnoreCase(contentEncoding)) {
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), new BrotliDecoder());
        }
        if (SNAPPY.contentEqualsIgnoreCase(contentEncoding)) {
            return new EmbeddedChannel(ctx.channel().id(), ctx.channel().metadata().hasDisconnect(),
                    ctx.channel().config(), new SnappyFrameDecoder());
        }
        // 'identity' or unsupported
        return null;
    }

    /**
     * Returns the expected content encoding of the decoded content. This getMethod returns {@code "identity"} by
     * default, which is the case for most decompressors.
     *
     * @param contentEncoding the value of the {@code content-encoding} header
     * @return the expected content encoding of the new content.
     * @throws Http2Exception if the {@code contentEncoding} is not supported and warrants an exception
     */
    protected CharSequence getTargetContentEncoding(@SuppressWarnings("UnusedParameters") CharSequence contentEncoding)
            throws Http2Exception {
        return IDENTITY;
    }

    /**
     * Checks if a new decompressor object is needed for the stream identified by {@code streamId}.
     * This method will modify the {@code content-encoding} header contained in {@code headers}.
     *
     * @param ctx The context
     * @param streamId The identifier for the headers inside {@code headers}
     * @param headers Object representing headers which have been read
     * @param endOfStream Indicates if the stream has ended
     * @throws Http2Exception If the {@code content-encoding} is not supported
     */
    private void initDecompressor(ChannelHandlerContext ctx, int streamId, Http2Headers headers, boolean endOfStream)
            throws Http2Exception {
        final Http2Stream stream = connection.stream(streamId);
        if (stream == null) {
            return;
        }

        Http2Decompressor decompressor = decompressor(stream);
        if (decompressor == null && !endOfStream) {
            // Determine the content encoding.
            CharSequence contentEncoding = headers.get(CONTENT_ENCODING);
            if (contentEncoding == null) {
                contentEncoding = IDENTITY;
            }
            final EmbeddedChannel channel = newContentDecompressor(ctx, contentEncoding);
            if (channel != null) {
                decompressor = new Http2Decompressor(channel);
                stream.setProperty(propertyKey, decompressor);
                // Decode the content and remove or replace the existing headers
                // so that the message looks like a decoded message.
                CharSequence targetContentEncoding = getTargetContentEncoding(contentEncoding);
                if (IDENTITY.contentEqualsIgnoreCase(targetContentEncoding)) {
                    headers.remove(CONTENT_ENCODING);
                } else {
                    headers.set(CONTENT_ENCODING, targetContentEncoding);
                }
            }
        }

        if (decompressor != null) {
            // The content length will be for the compressed data. Since we will decompress the data
            // this content-length will not be correct. Instead of queuing messages or delaying sending
            // header frames just remove the content-length header.
            headers.remove(CONTENT_LENGTH);

            // The first time that we initialize a decompressor, decorate the local flow controller to
            // properly convert consumed bytes.
            if (!flowControllerInitialized) {
                flowControllerInitialized = true;
                connection.local().flowController(new ConsumedBytesConverter(connection.local().flowController()));
            }
        }
    }

    // Looks up the per-stream decompressor state, or null when the stream is
    // unknown or no decompressor was installed for it.
    Http2Decompressor decompressor(Http2Stream stream) {
        return stream == null ? null : (Http2Decompressor) stream.getProperty(propertyKey);
    }

    /**
     * Release remaining content from the {@link EmbeddedChannel}.
     *
     * @param decompressor The decompressor for {@code stream}
     */
    private static void cleanup(Http2Decompressor decompressor) {
        decompressor.decompressor().finishAndReleaseAll();
    }

    /**
     * Read the next decompressed {@link ByteBuf} from the {@link EmbeddedChannel}
     * or {@code null} if one does not exist.
     *
     * @param decompressor The channel to read from
     * @return The next decoded {@link ByteBuf} from the {@link EmbeddedChannel} or {@code null} if one does not exist
     */
    private static ByteBuf nextReadableBuf(EmbeddedChannel decompressor) {
        for (;;) {
            final ByteBuf buf = decompressor.readInbound();
            if (buf == null) {
                return null;
            }
            // Skip (and release) empty buffers so callers only ever see readable data.
            if (!buf.isReadable()) {
                buf.release();
                continue;
            }
            return buf;
        }
    }

    /**
     * A decorator around the local flow controller that converts consumed bytes from uncompressed to compressed.
     */
    private final class ConsumedBytesConverter implements Http2LocalFlowController {
        private final Http2LocalFlowController flowController;

        ConsumedBytesConverter(Http2LocalFlowController flowController) {
            this.flowController = checkNotNull(flowController, "flowController");
        }

        @Override
        public Http2LocalFlowController frameWriter(Http2FrameWriter frameWriter) {
            return flowController.frameWriter(frameWriter);
        }

        @Override
        public void channelHandlerContext(ChannelHandlerContext ctx) throws Http2Exception {
            flowController.channelHandlerContext(ctx);
        }

        @Override
        public void initialWindowSize(int newWindowSize) throws Http2Exception {
            flowController.initialWindowSize(newWindowSize);
        }

        @Override
        public int initialWindowSize() {
            return flowController.initialWindowSize();
        }

        @Override
        public int windowSize(Http2Stream stream) {
            return flowController.windowSize(stream);
        }

        @Override
        public void incrementWindowSize(Http2Stream stream, int delta) throws Http2Exception {
            flowController.incrementWindowSize(stream, delta);
        }

        @Override
        public void receiveFlowControlledFrame(Http2Stream stream, ByteBuf data, int padding,
                boolean endOfStream) throws Http2Exception {
            flowController.receiveFlowControlledFrame(stream, data, padding, endOfStream);
        }

        /**
         * Converts the (decompressed) byte count the application reports into the
         * compressed on-the-wire count before handing it to the real flow controller.
         */
        @Override
        public boolean consumeBytes(Http2Stream stream, int numBytes) throws Http2Exception {
            Http2Decompressor decompressor = decompressor(stream);
            if (decompressor != null) {
                // Convert the decompressed bytes to compressed (on the wire) bytes.
                numBytes = decompressor.consumeBytes(stream.id(), numBytes);
            }
            try {
                return flowController.consumeBytes(stream, numBytes);
            } catch (Http2Exception e) {
                throw e;
            } catch (Throwable t) {
                // The stream should be closed at this point. We have already changed our state tracking the compressed
                // bytes, and there is no guarantee we can recover if the underlying flow controller throws.
                throw streamError(stream.id(), INTERNAL_ERROR, t, "Error while returning bytes to flow control window");
            }
        }

        @Override
        public int unconsumedBytes(Http2Stream stream) {
            return flowController.unconsumedBytes(stream);
        }

        @Override
        public int initialWindowSize(Http2Stream stream) {
            return flowController.initialWindowSize(stream);
        }
    }

    /**
     * Provides the state for stream {@code DATA} frame decompression.
     */
    private static final class Http2Decompressor {
        private final EmbeddedChannel decompressor;
        // Compressed bytes received but not yet returned to flow control.
        private int compressed;
        // Decompressed bytes handed to the application but not yet consumed.
        private int decompressed;

        Http2Decompressor(EmbeddedChannel decompressor) {
            this.decompressor = decompressor;
        }

        /**
         * Responsible for taking compressed bytes in and producing decompressed bytes.
         */
        EmbeddedChannel decompressor() {
            return decompressor;
        }

        /**
         * Increment the number of bytes received prior to doing any decompression.
         */
        void incrementCompressedBytes(int delta) {
            assert delta >= 0;
            compressed += delta;
        }

        /**
         * Increment the number of bytes after the decompression process.
         */
        void incrementDecompressedBytes(int delta) {
            assert delta >= 0;
            decompressed += delta;
        }

        /**
         * Determines the ratio between {@code numBytes} and {@link Http2Decompressor#decompressed}.
         * This ratio is used to decrement {@link Http2Decompressor#decompressed} and
         * {@link Http2Decompressor#compressed}.
         * @param streamId the stream ID
         * @param decompressedBytes The number of post-decompressed bytes to return to flow control
         * @return The number of pre-decompressed bytes that have been consumed.
         */
        int consumeBytes(int streamId, int decompressedBytes) throws Http2Exception {
            checkPositiveOrZero(decompressedBytes, "decompressedBytes");
            if (decompressed - decompressedBytes < 0) {
                throw streamError(streamId, INTERNAL_ERROR,
                        "Attempting to return too many bytes for stream %d. decompressed: %d " +
                                "decompressedBytes: %d", streamId, decompressed, decompressedBytes);
            }
            // Scale consumption proportionally: consume the same fraction of the
            // compressed count as was consumed of the decompressed count.
            double consumedRatio = decompressedBytes / (double) decompressed;
            int consumedCompressed = Math.min(compressed, (int) Math.ceil(compressed * consumedRatio));
            if (compressed - consumedCompressed < 0) {
                throw streamError(streamId, INTERNAL_ERROR,
                        "overflow when converting decompressed bytes to compressed bytes for stream %d." +
                                "decompressedBytes: %d decompressed: %d compressed: %d consumedCompressed: %d",
                        streamId, decompressedBytes, decompressed, compressed, consumedCompressed);
            }
            decompressed -= decompressedBytes;
            compressed -= consumedCompressed;

            return consumedCompressed;
        }
    }
}
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.EmptyHeaders; +import io.netty.util.internal.UnstableApi; + +@UnstableApi +public final class EmptyHttp2Headers + extends EmptyHeaders implements Http2Headers { + public static final EmptyHttp2Headers INSTANCE = new EmptyHttp2Headers(); + + private EmptyHttp2Headers() { + } + + @Override + public EmptyHttp2Headers method(CharSequence method) { + throw new UnsupportedOperationException(); + } + + @Override + public EmptyHttp2Headers scheme(CharSequence status) { + throw new UnsupportedOperationException(); + } + + @Override + public EmptyHttp2Headers authority(CharSequence authority) { + throw new UnsupportedOperationException(); + } + + @Override + public EmptyHttp2Headers path(CharSequence path) { + throw new UnsupportedOperationException(); + } + + @Override + public EmptyHttp2Headers status(CharSequence status) { + throw new UnsupportedOperationException(); + } + + @Override + public CharSequence method() { + return get(PseudoHeaderName.METHOD.value()); + } + + @Override + public CharSequence scheme() { + return get(PseudoHeaderName.SCHEME.value()); + } + + @Override + public CharSequence authority() { + return get(PseudoHeaderName.AUTHORITY.value()); + } + + @Override + public CharSequence path() { + return get(PseudoHeaderName.PATH.value()); + } + + @Override + public CharSequence status() { + return get(PseudoHeaderName.STATUS.value()); + } + + @Override + public boolean contains(CharSequence name, CharSequence value, boolean caseInsensitive) { + return false; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java new file mode 100644 index 0000000..8fb6a26 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDecoder.java @@ -0,0 +1,571 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty 
Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http.HttpHeaderValidationUtil; +import io.netty.handler.codec.http2.HpackUtil.IndexType; +import io.netty.util.AsciiString; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded; +import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.getPseudoHeader; +import static io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName.hasPseudoHeaderFormat; +import static io.netty.util.AsciiString.EMPTY_STRING; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +final class HpackDecoder { + private static final Http2Exception DECODE_ULE_128_DECOMPRESSION_EXCEPTION = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - decompression failure", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, + "decodeULE128(..)"); + private static final Http2Exception DECODE_ULE_128_TO_LONG_DECOMPRESSION_EXCEPTION = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - long overflow", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, "decodeULE128(..)"); + private static final Http2Exception DECODE_ULE_128_TO_INT_DECOMPRESSION_EXCEPTION = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - int 
overflow", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, "decodeULE128ToInt(..)"); + private static final Http2Exception DECODE_ILLEGAL_INDEX_VALUE = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - illegal index value", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, "decode(..)"); + private static final Http2Exception INDEX_HEADER_ILLEGAL_INDEX_VALUE = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - illegal index value", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, "indexHeader(..)"); + private static final Http2Exception READ_NAME_ILLEGAL_INDEX_VALUE = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - illegal index value", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, "readName(..)"); + private static final Http2Exception INVALID_MAX_DYNAMIC_TABLE_SIZE = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - invalid max dynamic table size", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, + "setDynamicTableSize(..)"); + private static final Http2Exception MAX_DYNAMIC_TABLE_SIZE_CHANGE_REQUIRED = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - max dynamic table size change required", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackDecoder.class, "decode(..)"); + private static final byte READ_HEADER_REPRESENTATION = 0; + private static final byte READ_INDEXED_HEADER = 1; + private static final byte READ_INDEXED_HEADER_NAME = 2; + private static final byte READ_LITERAL_HEADER_NAME_LENGTH_PREFIX = 3; + private static final byte READ_LITERAL_HEADER_NAME_LENGTH = 4; + private static final byte READ_LITERAL_HEADER_NAME = 5; + private static final byte READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX = 6; + private static final byte READ_LITERAL_HEADER_VALUE_LENGTH = 7; + private static final byte READ_LITERAL_HEADER_VALUE = 8; + + private final HpackHuffmanDecoder huffmanDecoder = new HpackHuffmanDecoder(); + private final HpackDynamicTable hpackDynamicTable; + 
    // Maximum allowed size of a decoded header list (SETTINGS_MAX_HEADER_LIST_SIZE); enforced via Http2HeadersSink.
    private long maxHeaderListSize;
    // Upper bound this decoder allows for the dynamic table (our advertised SETTINGS_HEADER_TABLE_SIZE).
    private long maxDynamicTableSize;
    // Dynamic table size most recently signalled by the remote encoder via a table-size update.
    private long encoderMaxDynamicTableSize;
    // Set when our limit shrinks below the encoder's current size; the encoder MUST then signal a
    // table-size change at the start of the next header block (checked in decode()).
    private boolean maxDynamicTableSizeChangeRequired;

    /**
     * Create a new instance.
     * @param maxHeaderListSize This is the only setting that can be configured before notifying the peer.
     *     This is because SETTINGS_MAX_HEADER_LIST_SIZE
     *     allows a lower than advertised limit from being enforced, and the default limit is unlimited
     *     (which is dangerous).
     */
    HpackDecoder(long maxHeaderListSize) {
        this(maxHeaderListSize, DEFAULT_HEADER_TABLE_SIZE);
    }

    /**
     * Exposed for testing only! Default values used in the initial settings frame are overridden intentionally
     * for testing but violate the RFC if used outside the scope of testing.
     */
    HpackDecoder(long maxHeaderListSize, int maxHeaderTableSize) {
        this.maxHeaderListSize = checkPositive(maxHeaderListSize, "maxHeaderListSize");

        // Both sides start from the same table size until the peer signals otherwise.
        maxDynamicTableSize = encoderMaxDynamicTableSize = maxHeaderTableSize;
        maxDynamicTableSizeChangeRequired = false;
        hpackDynamicTable = new HpackDynamicTable(maxHeaderTableSize);
    }

    /**
     * Decode the header block into header fields.
     *
<p>
     * This method assumes the entire header block is contained in {@code in}.
     */
    void decode(int streamId, ByteBuf in, Http2Headers headers, boolean validateHeaders) throws Http2Exception {
        Http2HeadersSink sink = new Http2HeadersSink(
                streamId, headers, maxHeaderListSize, validateHeaders);
        // Check for dynamic table size updates, which must occur at the beginning:
        // https://www.rfc-editor.org/rfc/rfc7541.html#section-4.2
        decodeDynamicTableSizeUpdates(in);
        decode(in, sink);

        // Now that we've read all of our headers we can perform the validation steps. We must
        // delay throwing until this point to prevent dynamic table corruption.
        sink.finish();
    }

    /**
     * Consume any dynamic-table-size-update instructions (prefix pattern 001xxxxx) at the start of
     * the header block. Peeks the byte first so a non-update byte is left for the main decode loop.
     */
    private void decodeDynamicTableSizeUpdates(ByteBuf in) throws Http2Exception {
        byte b;
        while (in.isReadable() && ((b = in.getByte(in.readerIndex())) & 0x20) == 0x20 && ((b & 0xC0) == 0x00)) {
            in.readByte();
            int index = b & 0x1F;
            if (index == 0x1F) {
                // 5-bit prefix saturated: remaining value is ULE128-encoded.
                setDynamicTableSize(decodeULE128(in, (long) index));
            } else {
                setDynamicTableSize(index);
            }
        }
    }

    /**
     * Main HPACK decode state machine (RFC 7541 section 6). Each representation is parsed across
     * one or more states; decoded fields are pushed into {@code sink} rather than thrown on
     * immediately so the dynamic table stays consistent even when validation fails.
     */
    private void decode(ByteBuf in, Http2HeadersSink sink) throws Http2Exception {
        int index = 0;
        int nameLength = 0;
        int valueLength = 0;
        byte state = READ_HEADER_REPRESENTATION;
        boolean huffmanEncoded = false;
        AsciiString name = null;
        IndexType indexType = IndexType.NONE;
        while (in.isReadable()) {
            switch (state) {
                case READ_HEADER_REPRESENTATION:
                    byte b = in.readByte();
                    if (maxDynamicTableSizeChangeRequired && (b & 0xE0) != 0x20) {
                        // HpackEncoder MUST signal maximum dynamic table size change
                        throw MAX_DYNAMIC_TABLE_SIZE_CHANGE_REQUIRED;
                    }
                    if (b < 0) {
                        // Indexed Header Field (high bit set: 1xxxxxxx)
                        index = b & 0x7F;
                        switch (index) {
                            case 0:
                                // Index 0 is not usable (RFC 7541 section 6.1).
                                throw DECODE_ILLEGAL_INDEX_VALUE;
                            case 0x7F:
                                // 7-bit prefix saturated: continue with ULE128 continuation bytes.
                                state = READ_INDEXED_HEADER;
                                break;
                            default:
                                HpackHeaderField indexedHeader = getIndexedHeader(index);
                                sink.appendToHeaderList(
                                        (AsciiString) indexedHeader.name,
                                        (AsciiString) indexedHeader.value);
                        }
                    } else if ((b & 0x40) == 0x40) {
                        // Literal Header Field with Incremental Indexing (01xxxxxx)
                        indexType = IndexType.INCREMENTAL;
                        index = b & 0x3F;
                        switch (index) {
                            case 0:
                                state = READ_LITERAL_HEADER_NAME_LENGTH_PREFIX;
                                break;
                            case 0x3F:
                                state = READ_INDEXED_HEADER_NAME;
                                break;
                            default:
                                // Index was stored as the prefix
                                name = readName(index);
                                nameLength = name.length();
                                state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
                        }
                    } else if ((b & 0x20) == 0x20) {
                        // Dynamic Table Size Update
                        // See https://www.rfc-editor.org/rfc/rfc7541.html#section-4.2
                        throw connectionError(COMPRESSION_ERROR, "Dynamic table size update must happen " +
                                "at the beginning of the header block");
                    } else {
                        // Literal Header Field without Indexing / never Indexed (0000xxxx / 0001xxxx)
                        indexType = (b & 0x10) == 0x10 ? IndexType.NEVER : IndexType.NONE;
                        index = b & 0x0F;
                        switch (index) {
                            case 0:
                                state = READ_LITERAL_HEADER_NAME_LENGTH_PREFIX;
                                break;
                            case 0x0F:
                                state = READ_INDEXED_HEADER_NAME;
                                break;
                            default:
                                // Index was stored as the prefix
                                name = readName(index);
                                nameLength = name.length();
                                state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
                        }
                    }
                    break;

                case READ_INDEXED_HEADER:
                    HpackHeaderField indexedHeader = getIndexedHeader(decodeULE128(in, index));
                    sink.appendToHeaderList(
                            (AsciiString) indexedHeader.name,
                            (AsciiString) indexedHeader.value);
                    state = READ_HEADER_REPRESENTATION;
                    break;

                case READ_INDEXED_HEADER_NAME:
                    // Header Name matches an entry in the Header Table
                    name = readName(decodeULE128(in, index));
                    nameLength = name.length();
                    state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
                    break;

                case READ_LITERAL_HEADER_NAME_LENGTH_PREFIX:
                    b = in.readByte();
                    huffmanEncoded = (b & 0x80) == 0x80;
                    index = b & 0x7F;
                    if (index == 0x7f) {
                        state = READ_LITERAL_HEADER_NAME_LENGTH;
                    } else {
                        nameLength = index;
                        state = READ_LITERAL_HEADER_NAME;
                    }
                    break;

                case READ_LITERAL_HEADER_NAME_LENGTH:
                    // Header Name is a Literal String
                    nameLength = decodeULE128(in, index);

                    state = READ_LITERAL_HEADER_NAME;
                    break;

                case READ_LITERAL_HEADER_NAME:
                    // Wait until entire name is readable
                    if (in.readableBytes() < nameLength) {
                        throw notEnoughDataException(in);
                    }

                    name = readStringLiteral(in, nameLength, huffmanEncoded);

                    state = READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX;
                    break;

                case READ_LITERAL_HEADER_VALUE_LENGTH_PREFIX:
                    b = in.readByte();
                    huffmanEncoded = (b & 0x80) == 0x80;
                    index = b & 0x7F;
                    switch (index) {
                        case 0x7f:
                            state = READ_LITERAL_HEADER_VALUE_LENGTH;
                            break;
                        case 0:
                            // Zero-length value: emit immediately without a value state.
                            insertHeader(sink, name, EMPTY_STRING, indexType);
                            state = READ_HEADER_REPRESENTATION;
                            break;
                        default:
                            valueLength = index;
                            state = READ_LITERAL_HEADER_VALUE;
                    }

                    break;

                case READ_LITERAL_HEADER_VALUE_LENGTH:
                    // Header Value is a Literal String
                    valueLength = decodeULE128(in, index);

                    state = READ_LITERAL_HEADER_VALUE;
                    break;

                case READ_LITERAL_HEADER_VALUE:
                    // Wait until entire value is readable
                    if (in.readableBytes() < valueLength) {
                        throw notEnoughDataException(in);
                    }

                    AsciiString value = readStringLiteral(in, valueLength, huffmanEncoded);
                    insertHeader(sink, name, value, indexType);
                    state = READ_HEADER_REPRESENTATION;
                    break;

                default:
                    throw new Error("should not reach here state: " + state);
            }
        }

        // A header block that ends mid-representation is malformed.
        if (state != READ_HEADER_REPRESENTATION) {
            throw connectionError(COMPRESSION_ERROR, "Incomplete header block fragment.");
        }
    }

    /**
     * Set the maximum table size. If this is below the maximum size of the dynamic table used by
     * the encoder, the beginning of the next header block MUST signal this change.
+ */ + void setMaxHeaderTableSize(long maxHeaderTableSize) throws Http2Exception { + if (maxHeaderTableSize < MIN_HEADER_TABLE_SIZE || maxHeaderTableSize > MAX_HEADER_TABLE_SIZE) { + throw connectionError(PROTOCOL_ERROR, "Header Table Size must be >= %d and <= %d but was %d", + MIN_HEADER_TABLE_SIZE, MAX_HEADER_TABLE_SIZE, maxHeaderTableSize); + } + maxDynamicTableSize = maxHeaderTableSize; + if (maxDynamicTableSize < encoderMaxDynamicTableSize) { + // decoder requires less space than encoder + // encoder MUST signal this change + maxDynamicTableSizeChangeRequired = true; + hpackDynamicTable.setCapacity(maxDynamicTableSize); + } + } + + void setMaxHeaderListSize(long maxHeaderListSize) throws Http2Exception { + if (maxHeaderListSize < MIN_HEADER_LIST_SIZE || maxHeaderListSize > MAX_HEADER_LIST_SIZE) { + throw connectionError(PROTOCOL_ERROR, "Header List Size must be >= %d and <= %d but was %d", + MIN_HEADER_TABLE_SIZE, MAX_HEADER_TABLE_SIZE, maxHeaderListSize); + } + this.maxHeaderListSize = maxHeaderListSize; + } + + long getMaxHeaderListSize() { + return maxHeaderListSize; + } + + /** + * Return the maximum table size. This is the maximum size allowed by both the encoder and the + * decoder. + */ + long getMaxHeaderTableSize() { + return hpackDynamicTable.capacity(); + } + + /** + * Return the number of header fields in the dynamic table. Exposed for testing. + */ + int length() { + return hpackDynamicTable.length(); + } + + /** + * Return the size of the dynamic table. Exposed for testing. + */ + long size() { + return hpackDynamicTable.size(); + } + + /** + * Return the header field at the given index. Exposed for testing. 
     */
    HpackHeaderField getHeaderField(int index) {
        // Dynamic table entries are 1-based; translate from the 0-based test index.
        return hpackDynamicTable.getEntry(index + 1);
    }

    /**
     * Apply a dynamic-table-size update received from the encoder. Rejects sizes above our
     * advertised limit and clears the pending change-required flag.
     */
    private void setDynamicTableSize(long dynamicTableSize) throws Http2Exception {
        if (dynamicTableSize > maxDynamicTableSize) {
            throw INVALID_MAX_DYNAMIC_TABLE_SIZE;
        }
        encoderMaxDynamicTableSize = dynamicTableSize;
        maxDynamicTableSizeChangeRequired = false;
        hpackDynamicTable.setCapacity(dynamicTableSize);
    }

    /**
     * Validate ordering and legality of a single header: pseudo-headers must precede regular
     * headers, request/response pseudo-headers must not mix, and connection-specific headers
     * are rejected. Returns the type of this header for use as {@code previousHeaderType} on
     * the next call.
     */
    private static HeaderType validateHeader(int streamId, AsciiString name, CharSequence value,
                                             HeaderType previousHeaderType) throws Http2Exception {
        if (hasPseudoHeaderFormat(name)) {
            if (previousHeaderType == HeaderType.REGULAR_HEADER) {
                throw streamError(streamId, PROTOCOL_ERROR,
                        "Pseudo-header field '%s' found after regular header.", name);
            }
            final Http2Headers.PseudoHeaderName pseudoHeader = getPseudoHeader(name);
            final HeaderType currentHeaderType = pseudoHeader.isRequestOnly() ?
                    HeaderType.REQUEST_PSEUDO_HEADER : HeaderType.RESPONSE_PSEUDO_HEADER;
            if (previousHeaderType != null && currentHeaderType != previousHeaderType) {
                throw streamError(streamId, PROTOCOL_ERROR, "Mix of request and response pseudo-headers.");
            }
            return currentHeaderType;
        }
        if (HttpHeaderValidationUtil.isConnectionHeader(name, true)) {
            throw streamError(streamId, PROTOCOL_ERROR, "Illegal connection-specific header '%s' encountered.", name);
        }
        if (HttpHeaderValidationUtil.isTeNotTrailers(name, value)) {
            throw streamError(streamId, PROTOCOL_ERROR,
                    "Illegal value specified for the 'TE' header (only 'trailers' is allowed).");
        }

        return HeaderType.REGULAR_HEADER;
    }

    /**
     * Resolve a header name by HPACK index: static table first, then the dynamic table
     * (dynamic indices start immediately after the static table).
     */
    private AsciiString readName(int index) throws Http2Exception {
        if (index <= HpackStaticTable.length) {
            HpackHeaderField hpackHeaderField = HpackStaticTable.getEntry(index);
            return (AsciiString) hpackHeaderField.name;
        }
        if (index - HpackStaticTable.length <= hpackDynamicTable.length()) {
            HpackHeaderField hpackHeaderField = hpackDynamicTable.getEntry(index - HpackStaticTable.length);
            return (AsciiString) hpackHeaderField.name;
        }
        throw READ_NAME_ILLEGAL_INDEX_VALUE;
    }

    /**
     * Resolve a full name/value entry by HPACK index using the same static-then-dynamic lookup
     * order as {@link #readName(int)}.
     */
    private HpackHeaderField getIndexedHeader(int index) throws Http2Exception {
        if (index <= HpackStaticTable.length) {
            return HpackStaticTable.getEntry(index);
        }
        if (index - HpackStaticTable.length <= hpackDynamicTable.length()) {
            return hpackDynamicTable.getEntry(index - HpackStaticTable.length);
        }
        throw INDEX_HEADER_ILLEGAL_INDEX_VALUE;
    }

    /**
     * Deliver a decoded header to the sink and, for INCREMENTAL representations, insert it into
     * the dynamic table as required by RFC 7541 section 6.2.1.
     */
    private void insertHeader(Http2HeadersSink sink, AsciiString name, AsciiString value, IndexType indexType) {
        sink.appendToHeaderList(name, value);

        switch (indexType) {
            case NONE:
            case NEVER:
                break;

            case INCREMENTAL:
                hpackDynamicTable.add(new HpackHeaderField(name, value));
                break;

            default:
                throw new Error("should not reach here");
        }
    }

    /**
     * Read a string literal of {@code length} bytes, Huffman-decoding it when the H bit was set.
     * The raw-bytes path wraps the copied array without re-copying (AsciiString(buf, false)).
     */
    private AsciiString readStringLiteral(ByteBuf in, int length, boolean huffmanEncoded) throws Http2Exception {
        if (huffmanEncoded) {
            return huffmanDecoder.decode(in, length);
        }
        byte[] buf = new byte[length];
        in.readBytes(buf);
        return new AsciiString(buf, false);
    }

    // Raised when a literal's declared length exceeds the remaining readable bytes — the caller
    // is required to pass a complete header block.
    private static IllegalArgumentException notEnoughDataException(ByteBuf in) {
        return new IllegalArgumentException("decode only works with an entire header block! " + in);
    }

    /**
     * Unsigned Little Endian Base 128 Variable-Length Integer Encoding
     *
<p>
+ * Visible for testing only! + */ + static int decodeULE128(ByteBuf in, int result) throws Http2Exception { + final int readerIndex = in.readerIndex(); + final long v = decodeULE128(in, (long) result); + if (v > Integer.MAX_VALUE) { + // the maximum value that can be represented by a signed 32 bit number is: + // [0x1,0x7f] + 0x7f + (0x7f << 7) + (0x7f << 14) + (0x7f << 21) + (0x6 << 28) + // OR + // 0x0 + 0x7f + (0x7f << 7) + (0x7f << 14) + (0x7f << 21) + (0x7 << 28) + // we should reset the readerIndex if we overflowed the int type. + in.readerIndex(readerIndex); + throw DECODE_ULE_128_TO_INT_DECOMPRESSION_EXCEPTION; + } + return (int) v; + } + + /** + * Unsigned Little Endian Base 128 Variable-Length Integer Encoding + *
<p>
     * Visible for testing only!
     *
     * <p>{@code result} is the already-consumed prefix value (0..0x7f); continuation bytes are
     * accumulated 7 bits at a time. On success the reader index is advanced past the final byte
     * (high bit clear); on overflow or truncation the reader index is left unchanged.
     */
    static long decodeULE128(ByteBuf in, long result) throws Http2Exception {
        assert result <= 0x7f && result >= 0;
        final boolean resultStartedAtZero = result == 0;
        final int writerIndex = in.writerIndex();
        // Peek with getByte() so the reader index only moves once a terminating byte is found.
        for (int readerIndex = in.readerIndex(), shift = 0; readerIndex < writerIndex; ++readerIndex, shift += 7) {
            byte b = in.getByte(readerIndex);
            if (shift == 56 && ((b & 0x80) != 0 || b == 0x7F && !resultStartedAtZero)) {
                // the maximum value that can be represented by a signed 64 bit number is:
                // [0x01L, 0x7fL] + 0x7fL + (0x7fL << 7) + (0x7fL << 14) + (0x7fL << 21) + (0x7fL << 28) + (0x7fL << 35)
                // + (0x7fL << 42) + (0x7fL << 49) + (0x7eL << 56)
                // OR
                // 0x0L + 0x7fL + (0x7fL << 7) + (0x7fL << 14) + (0x7fL << 21) + (0x7fL << 28) + (0x7fL << 35) +
                // (0x7fL << 42) + (0x7fL << 49) + (0x7fL << 56)
                // this means any more shifts will result in overflow so we should break out and throw an error.
                throw DECODE_ULE_128_TO_LONG_DECOMPRESSION_EXCEPTION;
            }

            if ((b & 0x80) == 0) {
                // High bit clear: this is the last byte — commit the reader index and return.
                in.readerIndex(readerIndex + 1);
                return result + ((b & 0x7FL) << shift);
            }
            result += (b & 0x7FL) << shift;
        }

        // Ran out of readable bytes before a terminating byte: truncated encoding.
        throw DECODE_ULE_128_DECOMPRESSION_EXCEPTION;
    }

    /**
     * HTTP/2 header types.
+ */ + private enum HeaderType { + REGULAR_HEADER, + REQUEST_PSEUDO_HEADER, + RESPONSE_PSEUDO_HEADER + } + + private static final class Http2HeadersSink { + private final Http2Headers headers; + private final long maxHeaderListSize; + private final int streamId; + private final boolean validateHeaders; + private long headersLength; + private boolean exceededMaxLength; + private HeaderType previousType; + private Http2Exception validationException; + + Http2HeadersSink(int streamId, Http2Headers headers, long maxHeaderListSize, boolean validateHeaders) { + this.headers = headers; + this.maxHeaderListSize = maxHeaderListSize; + this.streamId = streamId; + this.validateHeaders = validateHeaders; + } + + void finish() throws Http2Exception { + if (exceededMaxLength) { + headerListSizeExceeded(streamId, maxHeaderListSize, true); + } else if (validationException != null) { + throw validationException; + } + } + + void appendToHeaderList(AsciiString name, AsciiString value) { + headersLength += HpackHeaderField.sizeOf(name, value); + exceededMaxLength |= headersLength > maxHeaderListSize; + + if (exceededMaxLength || validationException != null) { + // We don't store the header since we've already failed validation requirements. 
+ return; + } + + try { + headers.add(name, value); + if (validateHeaders) { + previousType = validateHeader(streamId, name, value, previousType); + } + } catch (IllegalArgumentException ex) { + validationException = streamError(streamId, PROTOCOL_ERROR, ex, + "Validation failed for header '%s': %s", name, ex.getMessage()); + } catch (Http2Exception ex) { + validationException = streamError(streamId, PROTOCOL_ERROR, ex, ex.getMessage()); + } + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDynamicTable.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDynamicTable.java new file mode 100644 index 0000000..4a712ee --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackDynamicTable.java @@ -0,0 +1,201 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.netty.handler.codec.http2;

import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_TABLE_SIZE;

/**
 * The HPACK dynamic table (RFC 7541 section 2.3.2), backed by a circular array.
 * New entries are written at {@code head}; the oldest entry lives at {@code tail}.
 */
final class HpackDynamicTable {

    // a circular queue of header fields
    HpackHeaderField[] hpackHeaderFields;
    // Index where the NEXT (newest) entry will be written.
    int head;
    // Index of the oldest entry; head == tail means the table is empty.
    int tail;
    // Sum of HpackHeaderField.size() over all live entries (HPACK octets, not array slots).
    private long size;
    private long capacity = -1; // ensure setCapacity creates the array

    /**
     * Creates a new dynamic table with the specified initial capacity.
     */
    HpackDynamicTable(long initialCapacity) {
        setCapacity(initialCapacity);
    }

    /**
     * Return the number of header fields in the dynamic table.
     */
    public int length() {
        int length;
        if (head < tail) {
            // Window wraps around the end of the array.
            length = hpackHeaderFields.length - tail + head;
        } else {
            length = head - tail;
        }
        return length;
    }

    /**
     * Return the current size of the dynamic table. This is the sum of the size of the entries.
     */
    public long size() {
        return size;
    }

    /**
     * Return the maximum allowable size of the dynamic table.
     */
    public long capacity() {
        return capacity;
    }

    /**
     * Return the header field at the given index. The first and newest entry is always at index 1,
     * and the oldest entry is at the index length().
     */
    public HpackHeaderField getEntry(int index) {
        if (index <= 0 || index > length()) {
            throw new IndexOutOfBoundsException("Index " + index + " out of bounds for length " + length());
        }
        // Walk backwards from head, wrapping if the subtraction goes negative.
        int i = head - index;
        if (i < 0) {
            return hpackHeaderFields[i + hpackHeaderFields.length];
        } else {
            return hpackHeaderFields[i];
        }
    }

    /**
     * Add the header field to the dynamic table. Entries are evicted from the dynamic table until
     * the size of the table and the new header field is less than or equal to the table's capacity.
     * If the size of the new entry is larger than the table's capacity, the dynamic table will be
     * cleared.
     */
    public void add(HpackHeaderField header) {
        int headerSize = header.size();
        // RFC 7541 section 4.4: an entry larger than the whole table empties the table.
        if (headerSize > capacity) {
            clear();
            return;
        }
        // Evict oldest entries until the new entry fits.
        while (capacity - size < headerSize) {
            remove();
        }
        hpackHeaderFields[head++] = header;
        size += headerSize;
        if (head == hpackHeaderFields.length) {
            head = 0;
        }
    }

    /**
     * Remove and return the oldest header field from the dynamic table.
     */
    public HpackHeaderField remove() {
        HpackHeaderField removed = hpackHeaderFields[tail];
        if (removed == null) {
            // Table is empty.
            return null;
        }
        size -= removed.size();
        hpackHeaderFields[tail++] = null;
        if (tail == hpackHeaderFields.length) {
            tail = 0;
        }
        return removed;
    }

    /**
     * Remove all entries from the dynamic table.
     */
    public void clear() {
        // Null out slots so entries become collectable.
        while (tail != head) {
            hpackHeaderFields[tail++] = null;
            if (tail == hpackHeaderFields.length) {
                tail = 0;
            }
        }
        head = 0;
        tail = 0;
        size = 0;
    }

    /**
     * Set the maximum size of the dynamic table. Entries are evicted from the dynamic table until
     * the size of the table is less than or equal to the maximum size.
     */
    public void setCapacity(long capacity) {
        if (capacity < MIN_HEADER_TABLE_SIZE || capacity > MAX_HEADER_TABLE_SIZE) {
            throw new IllegalArgumentException("capacity is invalid: " + capacity);
        }
        // initially capacity will be -1 so init won't return here
        if (this.capacity == capacity) {
            return;
        }
        this.capacity = capacity;

        if (capacity == 0) {
            clear();
        } else {
            // initially size will be 0 so remove won't be called
            while (size > capacity) {
                remove();
            }
        }

        // Smallest array that can hold the max possible entry count; round up on remainder.
        int maxEntries = (int) (capacity / HpackHeaderField.HEADER_ENTRY_OVERHEAD);
        if (capacity % HpackHeaderField.HEADER_ENTRY_OVERHEAD != 0) {
            maxEntries++;
        }

        // check if capacity change requires us to reallocate the array
        if (hpackHeaderFields != null && hpackHeaderFields.length == maxEntries) {
            return;
        }

        HpackHeaderField[] tmp = new HpackHeaderField[maxEntries];

        // initially length will be 0 so there will be no copy
        int len = length();
        if (hpackHeaderFields != null) {
            // Unroll the circular window into the front of the new array.
            int cursor = tail;
            for (int i = 0; i < len; i++) {
                HpackHeaderField entry = hpackHeaderFields[cursor++];
                tmp[i] = entry;
                if (cursor == hpackHeaderFields.length) {
                    cursor = 0;
                }
            }
        }

        tail = 0;
        head = tail + len;
        hpackHeaderFields = tmp;
    }
}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java
new file mode 100644
index 0000000..8f2d5a4
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackEncoder.java
@@ -0,0 +1,555 @@
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.handler.codec.http2.HpackUtil.IndexType; +import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; + +import java.util.Map; + +import static io.netty.handler.codec.http2.HpackUtil.equalsConstantTime; +import static io.netty.handler.codec.http2.HpackUtil.equalsVariableTime; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.headerListSizeExceeded; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.MathUtil.findNextPositivePowerOfTwo; +import static java.lang.Math.max; +import static java.lang.Math.min; + +/** + * An HPACK encoder. + * + *
<p>
 * Implementation note: This class is security sensitive, and depends on users correctly identifying their headers
 * as security sensitive or not. If a header is considered not sensitive, methods names "insensitive" are used which
 * are fast, but don't provide any security guarantees.
 */
final class HpackEncoder {
    static final int NOT_FOUND = -1;
    static final int HUFF_CODE_THRESHOLD = 512;
    // a hash map of header fields keyed by header name
    private final NameEntry[] nameEntries;

    // a hash map of header fields keyed by header name and value
    private final NameValueEntry[] nameValueEntries;

    // Sentinel node of the insertion-ordered entry list; head.after is the oldest live entry.
    private final NameValueEntry head = new NameValueEntry(-1, AsciiString.EMPTY_STRING,
            AsciiString.EMPTY_STRING, Integer.MAX_VALUE, null);

    // Most recently added entry (newest end of the list).
    private NameValueEntry latest = head;

    private final HpackHuffmanEncoder hpackHuffmanEncoder = new HpackHuffmanEncoder();
    // Mask for reducing a hash to a bucket index; valid because table lengths are powers of two.
    private final byte hashMask;
    private final boolean ignoreMaxHeaderListSize;
    // Strings at least this long are considered for Huffman coding (see encodeStringLiteral).
    private final int huffCodeThreshold;
    // Current HPACK size of the dynamic table contents.
    private long size;
    private long maxHeaderTableSize;
    private long maxHeaderListSize;

    /**
     * Creates a new encoder.
     */
    HpackEncoder() {
        this(false);
    }

    /**
     * Creates a new encoder.
     */
    HpackEncoder(boolean ignoreMaxHeaderListSize) {
        this(ignoreMaxHeaderListSize, 64, HUFF_CODE_THRESHOLD);
    }

    /**
     * Creates a new encoder.
     */
    HpackEncoder(boolean ignoreMaxHeaderListSize, int arraySizeHint, int huffCodeThreshold) {
        this.ignoreMaxHeaderListSize = ignoreMaxHeaderListSize;
        maxHeaderTableSize = DEFAULT_HEADER_TABLE_SIZE;
        maxHeaderListSize = MAX_HEADER_LIST_SIZE;
        // Enforce a bound of [2, 128] because hashMask is a byte. The max possible value of hashMask is one less
        // than the length of this array, and we want the mask to be > 0.
        nameEntries = new NameEntry[findNextPositivePowerOfTwo(max(2, min(arraySizeHint, 128)))];
        nameValueEntries = new NameValueEntry[nameEntries.length];
        hashMask = (byte) (nameEntries.length - 1);
        this.huffCodeThreshold = huffCodeThreshold;
    }

    /**
     * Encode the header field into the header block.
     *
<p>
+ * The given {@link CharSequence}s must be immutable! + */ + public void encodeHeaders(int streamId, ByteBuf out, Http2Headers headers, SensitivityDetector sensitivityDetector) + throws Http2Exception { + if (ignoreMaxHeaderListSize) { + encodeHeadersIgnoreMaxHeaderListSize(out, headers, sensitivityDetector); + } else { + encodeHeadersEnforceMaxHeaderListSize(streamId, out, headers, sensitivityDetector); + } + } + + private void encodeHeadersEnforceMaxHeaderListSize(int streamId, ByteBuf out, Http2Headers headers, + SensitivityDetector sensitivityDetector) + throws Http2Exception { + long headerSize = 0; + // To ensure we stay consistent with our peer check the size is valid before we potentially modify HPACK state. + for (Map.Entry header : headers) { + CharSequence name = header.getKey(); + CharSequence value = header.getValue(); + // OK to increment now and check for bounds after because this value is limited to unsigned int and will not + // overflow. + headerSize += HpackHeaderField.sizeOf(name, value); + if (headerSize > maxHeaderListSize) { + headerListSizeExceeded(streamId, maxHeaderListSize, false); + } + } + encodeHeadersIgnoreMaxHeaderListSize(out, headers, sensitivityDetector); + } + + private void encodeHeadersIgnoreMaxHeaderListSize(ByteBuf out, Http2Headers headers, + SensitivityDetector sensitivityDetector) { + for (Map.Entry header : headers) { + CharSequence name = header.getKey(); + CharSequence value = header.getValue(); + encodeHeader(out, name, value, sensitivityDetector.isSensitive(name, value), + HpackHeaderField.sizeOf(name, value)); + } + } + + /** + * Encode the header field into the header block. + *
<p>
     * The given {@link CharSequence}s must be immutable!
     */
    private void encodeHeader(ByteBuf out, CharSequence name, CharSequence value, boolean sensitive, long headerSize) {
        // If the header value is sensitive then it must never be indexed
        if (sensitive) {
            int nameIndex = getNameIndex(name);
            encodeLiteral(out, name, value, IndexType.NEVER, nameIndex);
            return;
        }

        // If the peer will only use the static table
        if (maxHeaderTableSize == 0) {
            int staticTableIndex = HpackStaticTable.getIndexInsensitive(name, value);
            if (staticTableIndex == HpackStaticTable.NOT_FOUND) {
                int nameIndex = HpackStaticTable.getIndex(name);
                encodeLiteral(out, name, value, IndexType.NONE, nameIndex);
            } else {
                encodeInteger(out, 0x80, 7, staticTableIndex);
            }
            return;
        }

        // If the headerSize is greater than the max table size then it must be encoded literally
        if (headerSize > maxHeaderTableSize) {
            int nameIndex = getNameIndex(name);
            encodeLiteral(out, name, value, IndexType.NONE, nameIndex);
            return;
        }

        int nameHash = AsciiString.hashCode(name);
        int valueHash = AsciiString.hashCode(value);
        NameValueEntry headerField = getEntryInsensitive(name, nameHash, value, valueHash);
        if (headerField != null) {
            // Section 6.1. Indexed Header Field Representation
            encodeInteger(out, 0x80, 7, getIndexPlusOffset(headerField.counter));
        } else {
            int staticTableIndex = HpackStaticTable.getIndexInsensitive(name, value);
            if (staticTableIndex != HpackStaticTable.NOT_FOUND) {
                // Section 6.1. Indexed Header Field Representation
                encodeInteger(out, 0x80, 7, staticTableIndex);
            } else {
                // Not indexed anywhere yet: evict until it fits, emit a literal with incremental
                // indexing, and account for the new entry's size.
                ensureCapacity(headerSize);
                encodeAndAddEntries(out, name, nameHash, value, valueHash);
                size += headerSize;
            }
        }
    }

    /**
     * Emit a literal-with-incremental-indexing representation and register the new entry in the
     * name and name/value hash tables, reusing an existing name reference when possible.
     */
    private void encodeAndAddEntries(ByteBuf out, CharSequence name, int nameHash, CharSequence value, int valueHash) {
        int staticTableIndex = HpackStaticTable.getIndex(name);
        int nextCounter = latestCounter() - 1;
        if (staticTableIndex == HpackStaticTable.NOT_FOUND) {
            NameEntry e = getEntry(name, nameHash);
            if (e == null) {
                encodeLiteral(out, name, value, IndexType.INCREMENTAL, NOT_FOUND);
                addNameEntry(name, nameHash, nextCounter);
                addNameValueEntry(name, value, nameHash, valueHash, nextCounter);
            } else {
                encodeLiteral(out, name, value, IndexType.INCREMENTAL, getIndexPlusOffset(e.counter));
                addNameValueEntry(e.name, value, nameHash, valueHash, nextCounter);

                // The name entry should always point to the latest counter.
                e.counter = nextCounter;
            }
        } else {
            encodeLiteral(out, name, value, IndexType.INCREMENTAL, staticTableIndex);
            // use the name from the static table to optimize memory usage.
            addNameValueEntry(
                    HpackStaticTable.getEntry(staticTableIndex).name, value, nameHash, valueHash, nextCounter);
        }
    }

    /**
     * Set the maximum table size.
     */
    public void setMaxHeaderTableSize(ByteBuf out, long maxHeaderTableSize) throws Http2Exception {
        if (maxHeaderTableSize < MIN_HEADER_TABLE_SIZE || maxHeaderTableSize > MAX_HEADER_TABLE_SIZE) {
            throw connectionError(PROTOCOL_ERROR, "Header Table Size must be >= %d and <= %d but was %d",
                    MIN_HEADER_TABLE_SIZE, MAX_HEADER_TABLE_SIZE, maxHeaderTableSize);
        }
        if (this.maxHeaderTableSize == maxHeaderTableSize) {
            return;
        }
        this.maxHeaderTableSize = maxHeaderTableSize;
        ensureCapacity(0);
        // Casting to integer is safe as we verified the maxHeaderTableSize is a valid unsigned int.
        encodeInteger(out, 0x20, 5, maxHeaderTableSize);
    }

    /**
     * Return the maximum table size.
     */
    public long getMaxHeaderTableSize() {
        return maxHeaderTableSize;
    }

    public void setMaxHeaderListSize(long maxHeaderListSize) throws Http2Exception {
        if (maxHeaderListSize < MIN_HEADER_LIST_SIZE || maxHeaderListSize > MAX_HEADER_LIST_SIZE) {
            throw connectionError(PROTOCOL_ERROR, "Header List Size must be >= %d and <= %d but was %d",
                    MIN_HEADER_LIST_SIZE, MAX_HEADER_LIST_SIZE, maxHeaderListSize);
        }
        this.maxHeaderListSize = maxHeaderListSize;
    }

    public long getMaxHeaderListSize() {
        return maxHeaderListSize;
    }

    /**
     * Encode integer according to Section 5.1.
     */
    private static void encodeInteger(ByteBuf out, int mask, int n, int i) {
        encodeInteger(out, mask, n, (long) i);
    }

    /**
     * Encode integer according to Section 5.1.
     */
    private static void encodeInteger(ByteBuf out, int mask, int n, long i) {
        assert n >= 0 && n <= 8 : "N: " + n;
        // All-ones n-bit prefix; values below it fit in the prefix byte itself.
        int nbits = 0xFF >>> 8 - n;
        if (i < nbits) {
            out.writeByte((int) (mask | i));
        } else {
            // Saturate the prefix, then emit the remainder as ULE128 continuation bytes.
            out.writeByte(mask | nbits);
            long length = i - nbits;
            for (; (length & ~0x7F) != 0; length >>>= 7) {
                out.writeByte((int) (length & 0x7F | 0x80));
            }
            out.writeByte((int) length);
        }
    }

    /**
     * Encode string literal according to Section 5.2. Huffman coding is only used when the string
     * reaches huffCodeThreshold AND the coded form is actually shorter.
     */
    private void encodeStringLiteral(ByteBuf out, CharSequence string) {
        int huffmanLength;
        if (string.length() >= huffCodeThreshold
                && (huffmanLength = hpackHuffmanEncoder.getEncodedLength(string)) < string.length()) {
            encodeInteger(out, 0x80, 7, huffmanLength);
            hpackHuffmanEncoder.encode(out, string);
        } else {
            encodeInteger(out, 0x00, 7, string.length());
            if (string instanceof AsciiString) {
                // Fast-path
                AsciiString asciiString = (AsciiString) string;
                out.writeBytes(asciiString.array(), asciiString.arrayOffset(), asciiString.length());
            } else {
                // Only ASCII is allowed in http2 headers, so it is fine to use this.
                // https://tools.ietf.org/html/rfc7540#section-8.1.2
                out.writeCharSequence(string, CharsetUtil.ISO_8859_1);
            }
        }
    }

    /**
     * Encode literal header field according to Section 6.2. A valid {@code nameIndex} replaces the
     * literal name; otherwise the name is written as a string literal.
     */
    private void encodeLiteral(ByteBuf out, CharSequence name, CharSequence value, IndexType indexType,
                               int nameIndex) {
        boolean nameIndexValid = nameIndex != NOT_FOUND;
        switch (indexType) {
            case INCREMENTAL:
                encodeInteger(out, 0x40, 6, nameIndexValid ? nameIndex : 0);
                break;
            case NONE:
                encodeInteger(out, 0x00, 4, nameIndexValid ? nameIndex : 0);
                break;
            case NEVER:
                encodeInteger(out, 0x10, 4, nameIndexValid ? nameIndex : 0);
                break;
            default:
                throw new Error("should not reach here");
        }
        if (!nameIndexValid) {
            encodeStringLiteral(out, name);
        }
        encodeStringLiteral(out, value);
    }

    // Look up a name index: static table first, then our dynamic-table name entries.
    private int getNameIndex(CharSequence name) {
        int index = HpackStaticTable.getIndex(name);
        if (index != HpackStaticTable.NOT_FOUND) {
            return index;
        }
        NameEntry e = getEntry(name, AsciiString.hashCode(name));
        return e == null ? NOT_FOUND : getIndexPlusOffset(e.counter);
    }

    /**
     * Ensure that the dynamic table has enough room to hold 'headerSize' more bytes. Removes the
     * oldest entry from the dynamic table until sufficient space is available.
     */
    private void ensureCapacity(long headerSize) {
        while (maxHeaderTableSize - size < headerSize) {
            remove();
        }
    }

    /**
     * Return the number of header fields in the dynamic table. Exposed for testing.
     */
    int length() {
        return isEmpty() ? 0 : getIndex(head.after.counter);
    }

    /**
     * Return the size of the dynamic table. Exposed for testing.
     */
    long size() {
        return size;
    }

    /**
     * Return the header field at the given index. Exposed for testing.
+ */ + HpackHeaderField getHeaderField(int index) { + NameValueEntry entry = head; + while (index++ < length()) { + entry = entry.after; + } + return entry; + } + + /** + * Returns the header entry with the lowest index value for the header field. Returns null if + * header field is not in the dynamic table. + */ + private NameValueEntry getEntryInsensitive(CharSequence name, int nameHash, CharSequence value, int valueHash) { + int h = hash(nameHash, valueHash); + for (NameValueEntry e = nameValueEntries[bucket(h)]; e != null; e = e.next) { + if (e.hash == h && equalsVariableTime(value, e.value) && equalsVariableTime(name, e.name)) { + return e; + } + } + return null; + } + + /** + * Returns the lowest index value for the header field name in the dynamic table. Returns -1 if + * the header field name is not in the dynamic table. + */ + private NameEntry getEntry(CharSequence name, int nameHash) { + for (NameEntry e = nameEntries[bucket(nameHash)]; e != null; e = e.next) { + if (e.hash == nameHash && equalsConstantTime(name, e.name) != 0) { + return e; + } + } + return null; + } + + private int getIndexPlusOffset(int counter) { + return getIndex(counter) + HpackStaticTable.length; + } + + /** + * Compute the index into the dynamic table given the counter in the header entry. 
+ */ + private int getIndex(int counter) { + return counter - latestCounter() + 1; + } + + private int latestCounter() { + return latest.counter; + } + + private void addNameEntry(CharSequence name, int nameHash, int nextCounter) { + int bucket = bucket(nameHash); + nameEntries[bucket] = new NameEntry(nameHash, name, nextCounter, nameEntries[bucket]); + } + + private void addNameValueEntry(CharSequence name, CharSequence value, + int nameHash, int valueHash, int nextCounter) { + int hash = hash(nameHash, valueHash); + int bucket = bucket(hash); + NameValueEntry e = new NameValueEntry(hash, name, value, nextCounter, nameValueEntries[bucket]); + nameValueEntries[bucket] = e; + latest.after = e; + latest = e; + } + + /** + * Remove the oldest header field from the dynamic table. + */ + private void remove() { + NameValueEntry eldest = head.after; + removeNameValueEntry(eldest); + removeNameEntryMatchingCounter(eldest.name, eldest.counter); + head.after = eldest.after; + eldest.unlink(); + size -= eldest.size(); + if (isEmpty()) { + latest = head; + } + } + + private boolean isEmpty() { + return size == 0; + } + + private void removeNameValueEntry(NameValueEntry eldest) { + int bucket = bucket(eldest.hash); + NameValueEntry e = nameValueEntries[bucket]; + if (e == eldest) { + nameValueEntries[bucket] = eldest.next; + } else { + while (e.next != eldest) { + e = e.next; + } + e.next = eldest.next; + } + } + + private void removeNameEntryMatchingCounter(CharSequence name, int counter) { + int hash = AsciiString.hashCode(name); + int bucket = bucket(hash); + NameEntry e = nameEntries[bucket]; + if (e == null) { + return; + } + if (counter == e.counter) { + nameEntries[bucket] = e.next; + e.unlink(); + } else { + NameEntry prev = e; + e = e.next; + while (e != null) { + if (counter == e.counter) { + prev.next = e.next; + e.unlink(); + break; + } + prev = e; + e = e.next; + } + } + } + + /** + * Returns the bucket of the hash table for the hash code h. 
+     */
+    private int bucket(int h) {
+        return h & hashMask;
+    }
+
+    // Combines the name and value hashes with the conventional 31-multiplier mix.
+    private static int hash(int nameHash, int valueHash) {
+        return 31 * nameHash + valueHash;
+    }
+
+    // Node in the name-only hash index (nameEntries); chained via 'next'.
+    private static final class NameEntry {
+        NameEntry next;
+
+        final CharSequence name;
+
+        final int hash;
+
+        // This is used to compute the index in the dynamic table.
+        int counter;
+
+        NameEntry(int hash, CharSequence name, int counter, NameEntry next) {
+            this.hash = hash;
+            this.name = name;
+            this.counter = counter;
+            this.next = next;
+        }
+
+        void unlink() {
+            next = null; // null references to prevent nepotism in generational GC.
+        }
+    }
+
+    // Node in the name/value hash index (nameValueEntries) and in the eviction-order list.
+    private static final class NameValueEntry extends HpackHeaderField {
+        // This field comprises the linked list used for implementing the eviction policy.
+        NameValueEntry after;
+
+        NameValueEntry next;
+
+        // hash of both name and value
+        final int hash;
+
+        // This is used to compute the index in the dynamic table.
+        final int counter;
+
+        NameValueEntry(int hash, CharSequence name, CharSequence value, int counter, NameValueEntry next) {
+            super(name, value);
+            this.next = next;
+            this.hash = hash;
+            this.counter = counter;
+        }
+
+        void unlink() {
+            after = null; // null references to prevent nepotism in generational GC.
+            next = null;
+        }
+    }
+}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHeaderField.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHeaderField.java
new file mode 100644
index 0000000..5bc7d03
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHeaderField.java
@@ -0,0 +1,69 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * Copyright 2014 Twitter, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.netty.handler.codec.http2;
+
+import static io.netty.handler.codec.http2.HpackUtil.equalsVariableTime;
+import static io.netty.util.internal.ObjectUtil.checkNotNull;
+
+// A single HPACK name/value pair, with the table-size accounting defined by the spec.
+class HpackHeaderField {
+
+    // Section 4.1. Calculating Table Size
+    // The additional 32 octets account for an estimated
+    // overhead associated with the structure.
+    static final int HEADER_ENTRY_OVERHEAD = 32;
+
+    static long sizeOf(CharSequence name, CharSequence value) {
+        return name.length() + value.length() + HEADER_ENTRY_OVERHEAD;
+    }
+
+    final CharSequence name;
+    final CharSequence value;
+
+    // This constructor can only be used if name and value are ISO-8859-1 encoded.
+    HpackHeaderField(CharSequence name, CharSequence value) {
+        this.name = checkNotNull(name, "name");
+        this.value = checkNotNull(value, "value");
+    }
+
+    final int size() {
+        return name.length() + value.length() + HEADER_ENTRY_OVERHEAD;
+    }
+
+    // Test helper: content equality using the non-constant-time comparison.
+    public final boolean equalsForTest(HpackHeaderField other) {
+        return equalsVariableTime(name, other.name) && equalsVariableTime(value, other.value);
+    }
+
+    @Override
+    public String toString() {
+        return name + ": " + value;
+    }
+}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanDecoder.java
new file mode 100644
index 0000000..eac86e2
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanDecoder.java
@@ -0,0 +1,4736 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*
+ * Copyright 2014 Twitter, Inc.
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.util.AsciiString;
+import io.netty.util.ByteProcessor;
+
+import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR;
+
+// Table-driven HPACK Huffman decoder; processes input bytes via the ByteProcessor callback.
+final class HpackHuffmanDecoder implements ByteProcessor {
+
+    /* Scroll to the bottom! */
+
+    // Per-transition flag bits stored in the packed state table below.
+    private static final byte HUFFMAN_COMPLETE = 1;
+    private static final byte HUFFMAN_EMIT_SYMBOL = 1 << 1;
+    private static final byte HUFFMAN_FAIL = 1 << 2;
+
+    // The same flags pre-shifted into the packed-int flag byte position (flags << 8),
+    // so they can be tested against a packed table entry without unpacking it.
+    private static final int HUFFMAN_COMPLETE_SHIFT = HUFFMAN_COMPLETE << 8;
+    private static final int HUFFMAN_EMIT_SYMBOL_SHIFT = HUFFMAN_EMIT_SYMBOL << 8;
+    private static final int HUFFMAN_FAIL_SHIFT = HUFFMAN_FAIL << 8;
+
+    /**
+     * A table of byte tuples (state, flags, output). They are packed together as:
+     *

+ * state<<16 + flags<<8 + output + */ + private static final int[] HUFFS = new int[] { + // Node 0 (Root Node, never emits symbols.) + 4 << 16, + 5 << 16, + 7 << 16, + 8 << 16, + 11 << 16, + 12 << 16, + 16 << 16, + 19 << 16, + 25 << 16, + 28 << 16, + 32 << 16, + 35 << 16, + 42 << 16, + 49 << 16, + 57 << 16, + (64 << 16) + (HUFFMAN_COMPLETE << 8), + + // Node 1 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 48, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 49, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 50, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 97, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 99, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 101, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 105, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 111, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 115, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 116, + 13 << 16, + 14 << 16, + 17 << 16, + 18 << 16, + 20 << 16, + 21 << 16, + + // Node 2 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 48, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 49, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 50, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 97, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 99, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 101, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 105, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 111, + + // Node 3 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL 
<< 8) + 48, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 48, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 49, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 50, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 97, + + // Node 4 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 48, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 48, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 49, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 49, + + // Node 5 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 50, + (56 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 50, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 97, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 97, + + // Node 6 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 99, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 101, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 105, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 111, + + // Node 7 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 99, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 99, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (31 << 16) 
+ (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 101, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 101, + + // Node 8 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 105, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 105, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 111, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 111, + + // Node 9 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 115, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 116, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 32, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 37, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 45, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 46, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 47, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 51, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 52, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 53, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 54, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 55, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 56, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 57, + + // Node 10 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (9 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 115, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 115, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 116, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 32, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 37, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 45, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 46, + + // Node 11 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 115, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 115, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 116, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 116, + + // Node 12 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 32, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (23 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 37, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 45, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 46, + + // Node 13 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 32, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 32, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 37, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 37, + + // Node 14 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 45, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 45, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 
8) + 46, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 46, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 46, + + // Node 15 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 47, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 51, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 52, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 53, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 54, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 55, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 56, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 57, + + // Node 16 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 47, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 51, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 52, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 53, + + // Node 17 + (3 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 47, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 47, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 51, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 51, + + // Node 18 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 52, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 52, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 53, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 53, + + // Node 19 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 54, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (40 << 16) 
+ ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 55, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 56, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 57, + + // Node 20 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 54, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 54, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 55, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 55, + + // Node 21 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 56, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 56, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
57, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 57, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 57, + + // Node 22 + 26 << 16, + 27 << 16, + 29 << 16, + 30 << 16, + 33 << 16, + 34 << 16, + 36 << 16, + 37 << 16, + 43 << 16, + 46 << 16, + 50 << 16, + 53 << 16, + 58 << 16, + 61 << 16, + 65 << 16, + (68 << 16) + (HUFFMAN_COMPLETE << 8), + + // Node 23 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 61, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 65, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 95, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 98, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 100, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 102, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 103, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 104, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 108, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 109, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 110, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 112, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 114, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 117, + 38 << 16, + 39 << 16, + + // Node 24 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 61, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 65, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 95, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 98, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 100, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 102, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (22 << 16) + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 103, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 104, + + // Node 25 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 61, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 65, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 95, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 98, + + // Node 26 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 61, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 61, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 65, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 65, + + // Node 27 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (15 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 95, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 95, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 98, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 98, + + // Node 28 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 100, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 102, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 103, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 104, + + // Node 29 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 100, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 100, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
102, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 102, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 102, + + // Node 30 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 103, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 103, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 104, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 104, + + // Node 31 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 108, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 109, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 110, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 112, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 114, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 117, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 58, + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 66, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 67, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 68, + + // Node 32 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 108, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 109, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 110, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 112, + + // Node 33 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 108, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 108, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 109, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 109, + + // Node 34 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (10 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 110, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 110, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 112, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 112, + + // Node 35 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 114, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 117, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 58, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 66, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 67, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 68, + + // Node 36 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 114, + (56 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 114, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 117, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 117, + + // Node 37 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 58, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 66, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 67, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 68, + + // Node 38 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 58, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 58, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (31 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 66, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 66, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 66, + + // Node 39 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 67, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 67, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 68, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 68, + + // Node 40 + 44 << 16, + 45 << 16, + 47 << 16, + 48 << 16, + 51 << 16, + 52 << 16, + 54 << 16, + 55 << 16, + 59 << 16, + 60 << 16, + 62 << 16, + 63 << 16, + 66 << 16, + 67 << 16, + 69 << 16, + (72 << 16) + (HUFFMAN_COMPLETE << 8), + + // Node 41 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 69, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 70, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 71, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 72, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 73, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 74, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 75, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 76, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 77, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 78, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 79, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 80, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 81, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) 
+ 82, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 83, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 84, + + // Node 42 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 69, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 70, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 71, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 72, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 73, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 74, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 75, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 76, + + // Node 43 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 69, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 70, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 71, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 72, + + // Node 44 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (6 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 69, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 69, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 69, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 70, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 70, + + // Node 45 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 71, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 71, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 72, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 72, + + // Node 46 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 73, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 
74, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 75, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 76, + + // Node 47 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 73, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 73, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 74, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 74, + + // Node 48 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 75, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 75, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 76, 
+ (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 76, + + // Node 49 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 77, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 78, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 79, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 80, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 81, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 82, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 83, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 84, + + // Node 50 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 77, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 78, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 79, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 80, + + // Node 51 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (10 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 77, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 77, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 77, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 78, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 78, + + // Node 52 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 79, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 79, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 80, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 80, + + // Node 53 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 81, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 82, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
83, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 83, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 84, + + // Node 54 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 81, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 81, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 82, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 82, + + // Node 55 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 83, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 83, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 84, + (56 << 16) + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 84, + + // Node 56 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 85, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 86, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 87, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 89, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 106, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 107, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 113, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 118, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 119, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 120, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 121, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 122, + 70 << 16, + 71 << 16, + 73 << 16, + (74 << 16) + (HUFFMAN_COMPLETE << 8), + + // Node 57 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 85, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 86, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 87, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 89, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 106, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 107, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 113, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 118, + + // Node 58 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 85, + (2 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 86, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 86, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 87, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 89, + + // Node 59 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 85, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 85, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 86, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 86, + + // Node 60 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 87, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 87, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, 
+ (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 89, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 89, + + // Node 61 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 106, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 107, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 113, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 118, + + // Node 62 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 106, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 106, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 107, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 107, + + // Node 63 + (3 << 16) 
+ (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 113, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 113, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 118, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 118, + + // Node 64 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 119, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 120, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 121, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 122, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 38, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 42, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 44, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 59, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 88, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 90, + 75 << 16, + 78 << 16, + + // Node 65 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 119, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (23 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 120, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 120, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 121, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 122, + + // Node 66 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 119, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 119, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 120, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 120, + + // Node 67 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 121, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 121, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (24 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 122, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 122, + + // Node 68 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 38, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 42, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 44, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 59, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 88, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 90, + 76 << 16, + 77 << 16, + 79 << 16, + 81 << 16, + + // Node 69 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 38, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 42, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 44, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 59, + + // Node 70 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (15 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 38, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 38, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 38, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 42, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 42, + + // Node 71 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 44, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 44, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 59, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 59, + + // Node 72 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 88, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 90, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 33, + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 34, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 40, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 41, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 63, + 80 << 16, + 82 << 16, + 84 << 16, + + // Node 73 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 88, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 88, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 90, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 90, + + // Node 74 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 33, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 34, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 40, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 41, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 63, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 39, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 43, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 124, + 83 << 16, + 85 << 16, + 88 << 16, + + // Node 75 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (23 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 33, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 33, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 34, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 40, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 41, + + // Node 76 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 33, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 33, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 34, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 34, + + // Node 77 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 40, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 40, + (3 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 41, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 41, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 41, + + // Node 78 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 63, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 39, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 43, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 124, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 35, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 62, + 86 << 16, + 87 << 16, + 89 << 16, + 90 << 16, + + // Node 79 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 63, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 63, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 39, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 43, + + // Node 80 + (3 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 39, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 39, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 39, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 43, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 43, + + // Node 81 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 124, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 35, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 62, + (HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 36, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 64, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 91, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 93, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 126, + 91 << 16, + 92 << 16, + + // Node 82 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 124, + (56 << 16) + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 124, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 35, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 62, + + // Node 83 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 35, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 35, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 62, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 62, + + // Node 84 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8), + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 36, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 64, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 91, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 93, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 126, + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 94, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 125, + 93 << 16, + 94 << 16, + + // Node 85 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8), + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 36, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 64, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 91, + + // Node 86 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8), + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8), + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 36, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 36, + + // Node 87 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 
8) + 64, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 64, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 64, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 91, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 91, + + // Node 88 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 93, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 126, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 94, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 125, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 60, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 96, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 123, + 95 << 16, + + // Node 89 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 93, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 93, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + 
(15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 126, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 126, + + // Node 90 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 94, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 125, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 60, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 96, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 123, + 96 << 16, + 110 << 16, + + // Node 91 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 94, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 94, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 125, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 125, + + // Node 92 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (9 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 60, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 60, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 96, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 123, + 97 << 16, + 101 << 16, + 111 << 16, + 133 << 16, + + // Node 93 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 60, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 60, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 96, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 96, + + // Node 94 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 123, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 123, + 98 << 16, + 99 << 16, + 102 << 16, + 105 << 16, + 112 << 16, + 119 << 16, + 134 << 16, + 153 << 16, + + // Node 95 + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 92, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 195, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 208, + 100 << 16, + 103 << 16, + 104 << 16, + 106 << 16, + 107 << 16, + 113 << 16, + 116 << 16, + 120 << 16, + 126 << 16, + 135 << 16, + 142 << 16, + 154 << 16, + 169 << 16, + + // Node 96 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 92, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 195, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 208, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 128, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 130, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 131, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 162, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 184, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 194, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 224, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 226, + 108 << 16, + 109 << 16, + + // Node 97 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 92, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 195, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 208, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 128, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 
8) + 130, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 130, + + // Node 98 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 92, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 92, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 195, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 195, + + // Node 99 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 208, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 208, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 128, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 130, + + // Node 100 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (24 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 128, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 128, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 128, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 130, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 130, + + // Node 101 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 131, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 162, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 184, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 194, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 224, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 226, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 153, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 161, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 167, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 172, + + // Node 102 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 131, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) 
<< 8) + 162, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 184, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 194, + + // Node 103 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 131, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 131, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 162, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 162, + + // Node 104 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 184, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 184, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, 
+ (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 194, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 194, + + // Node 105 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 224, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 226, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 153, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 161, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 167, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 172, + + // Node 106 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 224, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 224, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 226, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 226, + + // Node 107 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (23 << 16) 
+ (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 153, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 161, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 167, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 172, + + // Node 108 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 153, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 153, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 161, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 161, + + // Node 109 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 167, + (56 << 16) + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 167, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 172, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 172, + + // Node 110 + 114 << 16, + 115 << 16, + 117 << 16, + 118 << 16, + 121 << 16, + 123 << 16, + 127 << 16, + 130 << 16, + 136 << 16, + 139 << 16, + 143 << 16, + 146 << 16, + 155 << 16, + 162 << 16, + 170 << 16, + 180 << 16, + + // Node 111 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 176, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 177, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 179, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 209, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 216, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 217, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 227, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 229, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 230, + 122 << 16, + 124 << 16, + 125 << 16, + 128 << 16, + 129 << 16, + 131 << 16, + 132 << 16, + + // Node 112 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 176, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 177, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 179, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 209, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 216, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) 
<< 8) + 217, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 227, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 229, + + // Node 113 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 176, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 177, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 179, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 209, + + // Node 114 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 176, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 176, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 177, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 177, + + // Node 115 + (3 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 179, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 179, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 179, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 209, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 209, + + // Node 116 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 216, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 217, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 227, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 229, + + // Node 117 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (41 
<< 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 216, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 216, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 217, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 217, + + // Node 118 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 227, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 227, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 229, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 229, + + // Node 119 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 230, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 129, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 132, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 133, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 134, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 136, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 146, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 154, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 156, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 
8) + 160, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 163, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 164, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 169, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 170, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 173, + + // Node 120 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 230, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 129, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 132, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 133, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 134, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 136, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 146, + + // Node 121 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 230, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 230, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 129, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (23 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 132, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 132, + + // Node 122 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 129, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 129, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 132, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 132, + + // Node 123 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 133, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 134, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 136, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 146, + + // Node 124 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (15 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 133, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 133, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 133, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 134, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 134, + + // Node 125 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 136, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 136, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 146, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 146, + + // Node 126 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 154, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 156, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 160, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 163, + (1 << 16) 
+ (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 164, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 169, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 170, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 173, + + // Node 127 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 154, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 156, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 160, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 163, + + // Node 128 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 154, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 154, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) 
+ 156, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 156, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 156, + + // Node 129 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 160, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 160, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 163, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 163, + + // Node 130 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 164, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 169, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 170, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 173, + + // Node 131 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
164, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 164, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 164, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 169, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 169, + + // Node 132 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 170, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 170, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 173, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 173, + + // Node 133 + 137 << 16, + 138 << 16, + 140 << 16, + 141 << 16, + 144 << 16, + 145 << 16, + 147 << 16, + 150 << 16, + 156 << 16, + 159 << 16, + 163 << 16, + 166 << 16, + 171 << 16, + 174 << 16, + 181 << 16, + 190 << 16, + + // Node 134 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 178, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 181, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 185, + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 186, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 187, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 189, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 190, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 196, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 198, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 228, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 232, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 233, + 148 << 16, + 149 << 16, + 151 << 16, + 152 << 16, + + // Node 135 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 178, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 181, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 185, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 186, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 187, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 189, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 190, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 196, + + // Node 136 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 178, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 181, + (2 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 185, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 185, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 186, + + // Node 137 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 178, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 178, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 181, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 181, + + // Node 138 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 185, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 185, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 186, + (41 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 186, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 186, + + // Node 139 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 187, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 189, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 190, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 196, + + // Node 140 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 187, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 187, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 189, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 189, + + // Node 141 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (15 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 190, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 190, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 190, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 196, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 196, + + // Node 142 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 198, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 228, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 232, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 233, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 1, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 135, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 137, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 138, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 139, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 140, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 141, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 143, + + // Node 143 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 198, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (40 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 228, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 232, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 233, + + // Node 144 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 198, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 198, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 228, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 228, + + // Node 145 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 232, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 232, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (31 
<< 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 233, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 233, + + // Node 146 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 1, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 135, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 137, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 138, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 139, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 140, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 141, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 143, + + // Node 147 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 1, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 135, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 137, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 138, + + // Node 
148 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 1, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 1, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 135, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 135, + + // Node 149 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 137, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 137, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 138, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 138, + + // Node 150 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 139, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (23 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 140, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 140, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 141, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 143, + + // Node 151 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 139, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 139, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 140, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 140, + + // Node 152 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 141, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 141, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (24 
<< 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 143, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 143, + + // Node 153 + 157 << 16, + 158 << 16, + 160 << 16, + 161 << 16, + 164 << 16, + 165 << 16, + 167 << 16, + 168 << 16, + 172 << 16, + 173 << 16, + 175 << 16, + 177 << 16, + 182 << 16, + 185 << 16, + 191 << 16, + 207 << 16, + + // Node 154 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 147, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 149, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 150, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 151, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 152, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 155, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 157, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 158, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 165, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 166, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 168, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 174, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 175, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 180, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 182, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 183, + + // Node 155 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 147, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 149, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 150, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 151, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 152, + (1 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 155, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 155, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 157, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 158, + + // Node 156 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 147, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 149, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 150, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 151, + + // Node 157 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 147, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 147, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 149, + (56 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 149, + + // Node 158 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 150, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 150, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 151, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 151, + + // Node 159 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 152, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 155, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 157, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 158, + + // Node 160 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (24 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 152, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 152, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 152, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 155, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 155, + + // Node 161 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 157, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 157, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 158, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 158, + + // Node 162 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 165, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 166, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 168, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 174, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (22 << 16) 
+ ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 175, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 180, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 182, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 183, + + // Node 163 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 165, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 166, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 168, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 174, + + // Node 164 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 165, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 165, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) 
+ 166, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 166, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 166, + + // Node 165 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 168, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 168, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 174, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 174, + + // Node 166 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 175, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 180, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 182, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 183, + + // Node 167 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
175, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 175, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 175, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 180, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 180, + + // Node 168 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 182, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 182, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 183, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 183, + + // Node 169 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 188, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 191, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 197, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 231, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 239, + 176 << 16, + 178 << 16, + 179 << 16, + 183 << 16, + 184 << 16, + 186 << 16, + 187 << 16, + 192 << 16, + 199 << 16, + 208 << 16, + 223 << 16, + + // Node 170 + (1 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 188, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 188, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 191, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 197, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 231, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 239, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 9, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 142, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 144, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 145, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 148, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 159, + + // Node 171 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 188, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 191, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 197, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 231, + + // Node 172 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
188, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 188, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 188, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 191, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 191, + + // Node 173 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 197, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 197, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 231, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 231, + + // Node 174 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 239, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 9, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 142, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (22 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 144, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 145, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 148, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 159, + + // Node 175 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 239, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 239, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 9, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 142, + + // Node 176 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 9, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 9, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 142, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 
8) + 142, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 142, + + // Node 177 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 144, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 145, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 148, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 159, + + // Node 178 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 144, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 144, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 145, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 145, + + // Node 179 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 
148, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 148, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 148, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 159, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 159, + + // Node 180 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 171, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 206, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 215, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 225, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 236, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 237, + 188 << 16, + 189 << 16, + 193 << 16, + 196 << 16, + 200 << 16, + 203 << 16, + 209 << 16, + 216 << 16, + 224 << 16, + 238 << 16, + + // Node 181 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 171, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 206, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 215, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 225, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 236, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 237, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 199, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 207, + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 234, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 235, + + // Node 182 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 171, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 206, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 215, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 225, + + // Node 183 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 171, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 171, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 206, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 206, + + // Node 184 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (15 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 215, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 215, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 215, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 225, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 225, + + // Node 185 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 236, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 237, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 199, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 207, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 234, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 235, + + // Node 186 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 236, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) 
+ 236, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 237, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 237, + + // Node 187 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 199, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 207, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 234, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 235, + + // Node 188 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 199, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 199, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL 
<< 8) + 207, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 207, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 207, + + // Node 189 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 234, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 234, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 235, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 235, + + // Node 190 + 194 << 16, + 195 << 16, + 197 << 16, + 198 << 16, + 201 << 16, + 202 << 16, + 204 << 16, + 205 << 16, + 210 << 16, + 213 << 16, + 217 << 16, + 220 << 16, + 225 << 16, + 231 << 16, + 239 << 16, + 246 << 16, + + // Node 191 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 192, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 193, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 200, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 201, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 202, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 205, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 210, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 213, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 218, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 219, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 238, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 240, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 242, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 
8) + 243, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 255, + 206 << 16, + + // Node 192 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 192, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 193, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 200, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 201, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 202, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 205, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 210, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 213, + + // Node 193 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 192, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 193, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 200, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 201, + + // Node 194 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (6 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 192, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 192, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 192, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 193, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 193, + + // Node 195 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 200, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 200, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 201, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 201, + + // Node 196 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 202, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (40 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 205, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 210, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 213, + + // Node 197 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 202, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 202, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 205, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 205, + + // Node 198 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 210, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 210, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (31 
<< 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 213, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 213, + + // Node 199 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 218, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 219, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 238, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 240, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 242, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 243, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 255, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 203, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 204, + + // Node 200 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 218, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 219, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 238, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 240, + + 
// Node 201 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 218, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 218, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 219, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 219, + + // Node 202 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 238, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 238, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 240, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 240, + + // Node 203 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 242, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, 
+ (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 243, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 255, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 203, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 204, + + // Node 204 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 242, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 242, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 243, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 243, + + // Node 205 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 255, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 255, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (40 << 16) + 
((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 203, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 204, + + // Node 206 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 203, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 203, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 204, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 204, + + // Node 207 + 211 << 16, + 212 << 16, + 214 << 16, + 215 << 16, + 218 << 16, + 219 << 16, + 221 << 16, + 222 << 16, + 226 << 16, + 228 << 16, + 232 << 16, + 235 << 16, + 240 << 16, + 243 << 16, + 247 << 16, + 250 << 16, + + // Node 208 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 211, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 212, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 214, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 221, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 222, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 223, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 241, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 244, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 245, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 246, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 247, + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 248, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 250, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 251, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 252, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 253, + + // Node 209 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 211, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 212, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 214, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 221, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 222, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 223, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 241, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 244, + + // Node 210 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 211, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 212, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 214, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (23 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 221, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 221, + + // Node 211 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 211, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 211, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 212, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 212, + + // Node 212 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 214, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 214, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 221, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 221, + + // Node 213 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) 
<< 8) + 222, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 223, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 241, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 244, + + // Node 214 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 222, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 222, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 223, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 223, + + // Node 215 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 241, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 241, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (6 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 244, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 244, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 244, + + // Node 216 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 245, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 246, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 247, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 248, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 250, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 251, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 252, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 253, + + // Node 217 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 245, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 246, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 247, + (2 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 248, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 248, + + // Node 218 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 245, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 245, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 246, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 246, + + // Node 219 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 247, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 247, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 248, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 248, + + // Node 220 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (23 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 250, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 251, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 252, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 253, + + // Node 221 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 250, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 250, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 251, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 251, + + // Node 222 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 252, + (56 << 16) + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 252, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 253, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 253, + + // Node 223 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 254, + 227 << 16, + 229 << 16, + 230 << 16, + 233 << 16, + 234 << 16, + 236 << 16, + 237 << 16, + 241 << 16, + 242 << 16, + 244 << 16, + 245 << 16, + 248 << 16, + 249 << 16, + 251 << 16, + 252 << 16, + + // Node 224 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 254, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 2, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 3, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 4, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 5, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 6, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 7, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 8, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 11, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 12, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 14, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 15, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 16, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 17, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 18, + + // Node 225 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 254, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 2, + (1 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 3, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 3, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 4, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 5, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 6, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 7, + + // Node 226 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 254, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 254, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 2, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 3, + + // Node 227 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 2, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 2, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (24 
<< 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 3, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 3, + + // Node 228 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 4, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 5, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 6, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 7, + + // Node 229 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 4, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 4, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 5, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 5, + + // Node 230 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 
8) + 6, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 6, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 6, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 7, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 7, + + // Node 231 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 8, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 11, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 12, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 14, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 15, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 16, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 17, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 18, + + // Node 232 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 8, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (40 << 16) + ((HUFFMAN_COMPLETE 
| HUFFMAN_EMIT_SYMBOL) << 8) + 11, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 12, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 14, + + // Node 233 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 8, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 8, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 11, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 11, + + // Node 234 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 12, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 12, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 14, + (41 << 16) + 
(HUFFMAN_EMIT_SYMBOL << 8) + 14, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 14, + + // Node 235 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 15, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 16, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 17, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 18, + + // Node 236 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 15, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 15, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 16, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 16, + + // Node 237 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (24 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 17, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 17, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 18, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 18, + + // Node 238 + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 19, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 20, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 21, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 23, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 24, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 25, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 26, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 27, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 28, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 29, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 30, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 31, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 127, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 220, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 249, + 253 << 16, + + // Node 239 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 19, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 20, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 21, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 23, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) 
+ 24, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 24, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 25, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 26, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 27, + + // Node 240 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 19, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 20, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 21, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 23, + + // Node 241 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 19, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 19, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (41 << 
16) + (HUFFMAN_EMIT_SYMBOL << 8) + 20, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 20, + + // Node 242 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 21, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 21, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 23, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 23, + + // Node 243 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 24, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 25, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 26, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 27, + + // Node 244 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + 
(24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 24, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 24, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 25, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 25, + + // Node 245 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 26, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 26, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 27, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 27, + + // Node 246 + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 28, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 29, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 30, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 31, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (22 << 16) + ((HUFFMAN_COMPLETE | 
HUFFMAN_EMIT_SYMBOL) << 8) + 127, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 220, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 249, + 254 << 16, + 255 << 16, + + // Node 247 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 28, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 29, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 30, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 31, + + // Node 248 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 28, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 28, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 29, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 29, + + // Node 249 + (3 
<< 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 30, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 30, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 31, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 31, + + // Node 250 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 127, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 220, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 249, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 10, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 13, + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 22, + HUFFMAN_FAIL << 8, + + // Node 251 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 127, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 
8) + 127, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 127, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 220, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 220, + + // Node 252 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 249, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 249, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 10, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 13, + (1 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (22 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 22, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + + // Node 253 + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 10, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 13, + (2 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (9 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (23 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (40 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 22, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, 
+ HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + + // Node 254 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 10, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 10, + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 13, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 13, + + // Node 255 + (3 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (6 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (10 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (15 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (24 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (31 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (41 << 16) + (HUFFMAN_EMIT_SYMBOL << 8) + 22, + (56 << 16) + ((HUFFMAN_COMPLETE | HUFFMAN_EMIT_SYMBOL) << 8) + 22, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + HUFFMAN_FAIL << 8, + }; + + private static final Http2Exception BAD_ENCODING = + Http2Exception.newStatic(COMPRESSION_ERROR, "HPACK - Bad Encoding", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, HpackHuffmanDecoder.class, "decode(..)"); + + private byte[] dest; + private int k; + private int state; + + HpackHuffmanDecoder() { } + + /** + * Decompresses the given Huffman coded string literal. 
+ * + * @param buf the string literal to be decoded + * @return the output stream for the compressed data + * @throws Http2Exception EOS Decoded + */ + public AsciiString decode(ByteBuf buf, int length) throws Http2Exception { + if (length == 0) { + return AsciiString.EMPTY_STRING; + } + dest = new byte[length * 8 / 5]; + try { + int readerIndex = buf.readerIndex(); + // Using ByteProcessor to reduce bounds-checking and reference-count checking during byte-by-byte + // processing of the ByteBuf. + int endIndex = buf.forEachByte(readerIndex, length, this); + if (endIndex == -1) { + // We did consume the requested length + buf.readerIndex(readerIndex + length); + if ((state & HUFFMAN_COMPLETE_SHIFT) != HUFFMAN_COMPLETE_SHIFT) { + throw BAD_ENCODING; + } + return new AsciiString(dest, 0, k, false); + } + + // The process(...) method returned before the requested length was requested. This means there + // was a bad encoding detected. + buf.readerIndex(endIndex); + throw BAD_ENCODING; + } finally { + dest = null; + k = 0; + state = 0; + } + } + + /** + * This should never be called from anything but this class itself! 
+ */ + @Override + public boolean process(byte input) { + return processNibble(input >> 4) && processNibble(input); + } + + private boolean processNibble(int input) { + // The high nibble of the flags byte of each row is always zero + // (low nibble after shifting row by 12), since there are only 3 flag bits + int index = state >> 12 | (input & 0x0F); + state = HUFFS[index]; + if ((state & HUFFMAN_FAIL_SHIFT) != 0) { + return false; + } + if ((state & HUFFMAN_EMIT_SYMBOL_SHIFT) != 0) { + // state is always positive so can cast without mask here + dest[k++] = (byte) state; + } + return true; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanEncoder.java new file mode 100644 index 0000000..d3641ce --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackHuffmanEncoder.java @@ -0,0 +1,194 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.AsciiString; +import io.netty.util.ByteProcessor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; + +final class HpackHuffmanEncoder { + + private final int[] codes; + private final byte[] lengths; + private final EncodedLengthProcessor encodedLengthProcessor = new EncodedLengthProcessor(); + private final EncodeProcessor encodeProcessor = new EncodeProcessor(); + + HpackHuffmanEncoder() { + this(HpackUtil.HUFFMAN_CODES, HpackUtil.HUFFMAN_CODE_LENGTHS); + } + + /** + * Creates a new Huffman encoder with the specified Huffman coding. + * + * @param codes the Huffman codes indexed by symbol + * @param lengths the length of each Huffman code + */ + private HpackHuffmanEncoder(int[] codes, byte[] lengths) { + this.codes = codes; + this.lengths = lengths; + } + + /** + * Compresses the input string literal using the Huffman coding. 
+ * + * @param out the output stream for the compressed data + * @param data the string literal to be Huffman encoded + */ + public void encode(ByteBuf out, CharSequence data) { + ObjectUtil.checkNotNull(out, "out"); + if (data instanceof AsciiString) { + AsciiString string = (AsciiString) data; + try { + encodeProcessor.out = out; + string.forEachByte(encodeProcessor); + } catch (Exception e) { + PlatformDependent.throwException(e); + } finally { + encodeProcessor.end(); + } + } else { + encodeSlowPath(out, data); + } + } + + private void encodeSlowPath(ByteBuf out, CharSequence data) { + long current = 0; + int n = 0; + + for (int i = 0; i < data.length(); i++) { + int b = AsciiString.c2b(data.charAt(i)) & 0xFF; + int code = codes[b]; + int nbits = lengths[b]; + + current <<= nbits; + current |= code; + n += nbits; + + while (n >= 8) { + n -= 8; + out.writeByte((int) (current >> n)); + } + } + + if (n > 0) { + current <<= 8 - n; + current |= 0xFF >>> n; // this should be EOS symbol + out.writeByte((int) current); + } + } + + /** + * Returns the number of bytes required to Huffman encode the input string literal. 
+ * + * @param data the string literal to be Huffman encoded + * @return the number of bytes required to Huffman encode {@code data} + */ + int getEncodedLength(CharSequence data) { + if (data instanceof AsciiString) { + AsciiString string = (AsciiString) data; + try { + encodedLengthProcessor.reset(); + string.forEachByte(encodedLengthProcessor); + return encodedLengthProcessor.length(); + } catch (Exception e) { + PlatformDependent.throwException(e); + return -1; + } + } else { + return getEncodedLengthSlowPath(data); + } + } + + private int getEncodedLengthSlowPath(CharSequence data) { + long len = 0; + for (int i = 0; i < data.length(); i++) { + len += lengths[AsciiString.c2b(data.charAt(i)) & 0xFF]; + } + return (int) (len + 7 >> 3); + } + + private final class EncodeProcessor implements ByteProcessor { + ByteBuf out; + private long current; + private int n; + + @Override + public boolean process(byte value) { + int b = value & 0xFF; + int nbits = lengths[b]; + + current <<= nbits; + current |= codes[b]; + n += nbits; + + while (n >= 8) { + n -= 8; + out.writeByte((int) (current >> n)); + } + return true; + } + + void end() { + try { + if (n > 0) { + current <<= 8 - n; + current |= 0xFF >>> n; // this should be EOS symbol + out.writeByte((int) current); + } + } finally { + out = null; + current = 0; + n = 0; + } + } + } + + private final class EncodedLengthProcessor implements ByteProcessor { + private long len; + + @Override + public boolean process(byte value) { + len += lengths[value & 0xFF]; + return true; + } + + void reset() { + len = 0; + } + + int length() { + return (int) ((len + 7) >> 3); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackStaticTable.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackStaticTable.java new file mode 100644 index 0000000..666c14c --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackStaticTable.java @@ 
-0,0 +1,257 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.AsciiString; +import io.netty.util.internal.PlatformDependent; + +import java.util.Arrays; +import java.util.List; + +import static io.netty.handler.codec.http2.HpackUtil.equalsVariableTime; + +final class HpackStaticTable { + + static final int NOT_FOUND = -1; + + // Appendix A: Static Table + // https://tools.ietf.org/html/rfc7541#appendix-A + private static final List STATIC_TABLE = Arrays.asList( + /* 1 */ newEmptyHeaderField(":authority"), + /* 2 */ newHeaderField(":method", "GET"), + /* 3 */ newHeaderField(":method", "POST"), + /* 4 */ newHeaderField(":path", "/"), + /* 5 */ newHeaderField(":path", "/index.html"), + /* 6 */ newHeaderField(":scheme", "http"), + /* 7 */ newHeaderField(":scheme", "https"), + /* 8 */ newHeaderField(":status", "200"), + /* 9 */ newHeaderField(":status", "204"), + /* 10 */ newHeaderField(":status", "206"), + /* 11 */ newHeaderField(":status", "304"), + /* 12 */ newHeaderField(":status", "400"), + /* 13 */ newHeaderField(":status", "404"), + /* 14 */ newHeaderField(":status", "500"), + /* 15 */ newEmptyHeaderField("accept-charset"), + /* 16 */ newHeaderField("accept-encoding", "gzip, deflate"), + /* 17 */ newEmptyHeaderField("accept-language"), + /* 18 */ newEmptyHeaderField("accept-ranges"), + /* 19 */ newEmptyHeaderField("accept"), + /* 20 */ newEmptyHeaderField("access-control-allow-origin"), + /* 21 */ newEmptyHeaderField("age"), + /* 22 */ newEmptyHeaderField("allow"), + /* 23 */ newEmptyHeaderField("authorization"), + /* 24 */ newEmptyHeaderField("cache-control"), + /* 25 */ newEmptyHeaderField("content-disposition"), + /* 26 */ newEmptyHeaderField("content-encoding"), + /* 27 */ newEmptyHeaderField("content-language"), + /* 28 */ newEmptyHeaderField("content-length"), + /* 29 */ newEmptyHeaderField("content-location"), + /* 30 */ newEmptyHeaderField("content-range"), + /* 31 */ newEmptyHeaderField("content-type"), + /* 32 */ newEmptyHeaderField("cookie"), + 
/* 33 */ newEmptyHeaderField("date"), + /* 34 */ newEmptyHeaderField("etag"), + /* 35 */ newEmptyHeaderField("expect"), + /* 36 */ newEmptyHeaderField("expires"), + /* 37 */ newEmptyHeaderField("from"), + /* 38 */ newEmptyHeaderField("host"), + /* 39 */ newEmptyHeaderField("if-match"), + /* 40 */ newEmptyHeaderField("if-modified-since"), + /* 41 */ newEmptyHeaderField("if-none-match"), + /* 42 */ newEmptyHeaderField("if-range"), + /* 43 */ newEmptyHeaderField("if-unmodified-since"), + /* 44 */ newEmptyHeaderField("last-modified"), + /* 45 */ newEmptyHeaderField("link"), + /* 46 */ newEmptyHeaderField("location"), + /* 47 */ newEmptyHeaderField("max-forwards"), + /* 48 */ newEmptyHeaderField("proxy-authenticate"), + /* 49 */ newEmptyHeaderField("proxy-authorization"), + /* 50 */ newEmptyHeaderField("range"), + /* 51 */ newEmptyHeaderField("referer"), + /* 52 */ newEmptyHeaderField("refresh"), + /* 53 */ newEmptyHeaderField("retry-after"), + /* 54 */ newEmptyHeaderField("server"), + /* 55 */ newEmptyHeaderField("set-cookie"), + /* 56 */ newEmptyHeaderField("strict-transport-security"), + /* 57 */ newEmptyHeaderField("transfer-encoding"), + /* 58 */ newEmptyHeaderField("user-agent"), + /* 59 */ newEmptyHeaderField("vary"), + /* 60 */ newEmptyHeaderField("via"), + /* 61 */ newEmptyHeaderField("www-authenticate") + ); + + private static HpackHeaderField newEmptyHeaderField(String name) { + return new HpackHeaderField(AsciiString.cached(name), AsciiString.EMPTY_STRING); + } + + private static HpackHeaderField newHeaderField(String name, String value) { + return new HpackHeaderField(AsciiString.cached(name), AsciiString.cached(value)); + } + + // The table size and bit shift are chosen so that each hash bucket contains a single header name. + private static final int HEADER_NAMES_TABLE_SIZE = 1 << 9; + + private static final int HEADER_NAMES_TABLE_SHIFT = PlatformDependent.BIG_ENDIAN_NATIVE_ORDER ? 22 : 18; + + // A table mapping header names to their associated indexes. 
+ private static final HeaderNameIndex[] HEADER_NAMES = new HeaderNameIndex[HEADER_NAMES_TABLE_SIZE]; + static { + // Iterate through the static table in reverse order to + // save the smallest index for a given name in the table. + for (int index = STATIC_TABLE.size(); index > 0; index--) { + HpackHeaderField entry = getEntry(index); + int bucket = headerNameBucket(entry.name); + HeaderNameIndex tableEntry = HEADER_NAMES[bucket]; + if (tableEntry != null && !equalsVariableTime(tableEntry.name, entry.name)) { + // Can happen if AsciiString.hashCode changes + throw new IllegalStateException("Hash bucket collision between " + + tableEntry.name + " and " + entry.name); + } + HEADER_NAMES[bucket] = new HeaderNameIndex(entry.name, index, entry.value.length() == 0); + } + } + + // The table size and bit shift are chosen so that each hash bucket contains a single header. + private static final int HEADERS_WITH_NON_EMPTY_VALUES_TABLE_SIZE = 1 << 6; + + private static final int HEADERS_WITH_NON_EMPTY_VALUES_TABLE_SHIFT = + PlatformDependent.BIG_ENDIAN_NATIVE_ORDER ? 0 : 6; + + // A table mapping headers with non-empty values to their associated indexes. + private static final HeaderIndex[] HEADERS_WITH_NON_EMPTY_VALUES = + new HeaderIndex[HEADERS_WITH_NON_EMPTY_VALUES_TABLE_SIZE]; + static { + for (int index = STATIC_TABLE.size(); index > 0; index--) { + HpackHeaderField entry = getEntry(index); + if (entry.value.length() > 0) { + int bucket = headerBucket(entry.value); + HeaderIndex tableEntry = HEADERS_WITH_NON_EMPTY_VALUES[bucket]; + if (tableEntry != null) { + // Can happen if AsciiString.hashCode changes + throw new IllegalStateException("Hash bucket collision between " + + tableEntry.value + " and " + entry.value); + } + HEADERS_WITH_NON_EMPTY_VALUES[bucket] = new HeaderIndex(entry.name, entry.value, index); + } + } + } + + /** + * The number of header fields in the static table. 
+ */ + static final int length = STATIC_TABLE.size(); + + /** + * Return the header field at the given index value. + */ + static HpackHeaderField getEntry(int index) { + return STATIC_TABLE.get(index - 1); + } + + /** + * Returns the lowest index value for the given header field name in the static table. Returns + * -1 if the header field name is not in the static table. + */ + static int getIndex(CharSequence name) { + HeaderNameIndex entry = getEntry(name); + return entry == null ? NOT_FOUND : entry.index; + } + + /** + * Returns the index value for the given header field in the static table. Returns -1 if the + * header field is not in the static table. + */ + static int getIndexInsensitive(CharSequence name, CharSequence value) { + if (value.length() == 0) { + HeaderNameIndex entry = getEntry(name); + return entry == null || !entry.emptyValue ? NOT_FOUND : entry.index; + } + int bucket = headerBucket(value); + HeaderIndex header = HEADERS_WITH_NON_EMPTY_VALUES[bucket]; + if (header == null) { + return NOT_FOUND; + } + if (equalsVariableTime(header.name, name) && equalsVariableTime(header.value, value)) { + return header.index; + } + return NOT_FOUND; + } + + private static HeaderNameIndex getEntry(CharSequence name) { + int bucket = headerNameBucket(name); + HeaderNameIndex entry = HEADER_NAMES[bucket]; + if (entry == null) { + return null; + } + return equalsVariableTime(entry.name, name) ? 
entry : null; + } + + private static int headerNameBucket(CharSequence name) { + return bucket(name, HEADER_NAMES_TABLE_SHIFT, HEADER_NAMES_TABLE_SIZE - 1); + } + + private static int headerBucket(CharSequence value) { + return bucket(value, HEADERS_WITH_NON_EMPTY_VALUES_TABLE_SHIFT, HEADERS_WITH_NON_EMPTY_VALUES_TABLE_SIZE - 1); + } + + private static int bucket(CharSequence s, int shift, int mask) { + return (AsciiString.hashCode(s) >> shift) & mask; + } + + private static final class HeaderNameIndex { + final CharSequence name; + final int index; + final boolean emptyValue; + + HeaderNameIndex(CharSequence name, int index, boolean emptyValue) { + this.name = name; + this.index = index; + this.emptyValue = emptyValue; + } + } + + private static final class HeaderIndex { + final CharSequence name; + final CharSequence value; + final int index; + + HeaderIndex(CharSequence name, CharSequence value, int index) { + this.name = name; + this.value = value; + this.index = index; + } + } + + // singleton + private HpackStaticTable() { + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackUtil.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackUtil.java new file mode 100644 index 0000000..5c51c52 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HpackUtil.java @@ -0,0 +1,372 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.AsciiString; +import io.netty.util.internal.ConstantTimeUtils; +import io.netty.util.internal.PlatformDependent; + +final class HpackUtil { + /** + * Compare two {@link CharSequence} objects without leaking timing information. + *

+ * The {@code int} return type is intentional and is designed to allow cascading of constant time operations: + *

+     *     String s1 = "foo";
+     *     String s2 = "foo";
+     *     String s3 = "foo";
+     *     String s4 = "goo";
+     *     boolean equals = (equalsConstantTime(s1, s2) & equalsConstantTime(s3, s4)) != 0;
+     * 
+ * @param s1 the first value. + * @param s2 the second value. + * @return {@code 0} if not equal. {@code 1} if equal. + */ + static int equalsConstantTime(CharSequence s1, CharSequence s2) { + if (s1 instanceof AsciiString && s2 instanceof AsciiString) { + if (s1.length() != s2.length()) { + return 0; + } + AsciiString s1Ascii = (AsciiString) s1; + AsciiString s2Ascii = (AsciiString) s2; + return PlatformDependent.equalsConstantTime(s1Ascii.array(), s1Ascii.arrayOffset(), + s2Ascii.array(), s2Ascii.arrayOffset(), s1.length()); + } + + return ConstantTimeUtils.equalsConstantTime(s1, s2); + } + + /** + * Compare two {@link CharSequence}s. + * @param s1 the first value. + * @param s2 the second value. + * @return {@code false} if not equal. {@code true} if equal. + */ + static boolean equalsVariableTime(CharSequence s1, CharSequence s2) { + return AsciiString.contentEquals(s1, s2); + } + + // Section 6.2. Literal Header Field Representation + enum IndexType { + INCREMENTAL, // Section 6.2.1. Literal Header Field with Incremental Indexing + NONE, // Section 6.2.2. Literal Header Field without Indexing + NEVER // Section 6.2.3. 
Literal Header Field never Indexed + } + + // Appendix B: Huffman Codes + // https://tools.ietf.org/html/rfc7541#appendix-B + static final int[] HUFFMAN_CODES = { + 0x1ff8, + 0x7fffd8, + 0xfffffe2, + 0xfffffe3, + 0xfffffe4, + 0xfffffe5, + 0xfffffe6, + 0xfffffe7, + 0xfffffe8, + 0xffffea, + 0x3ffffffc, + 0xfffffe9, + 0xfffffea, + 0x3ffffffd, + 0xfffffeb, + 0xfffffec, + 0xfffffed, + 0xfffffee, + 0xfffffef, + 0xffffff0, + 0xffffff1, + 0xffffff2, + 0x3ffffffe, + 0xffffff3, + 0xffffff4, + 0xffffff5, + 0xffffff6, + 0xffffff7, + 0xffffff8, + 0xffffff9, + 0xffffffa, + 0xffffffb, + 0x14, + 0x3f8, + 0x3f9, + 0xffa, + 0x1ff9, + 0x15, + 0xf8, + 0x7fa, + 0x3fa, + 0x3fb, + 0xf9, + 0x7fb, + 0xfa, + 0x16, + 0x17, + 0x18, + 0x0, + 0x1, + 0x2, + 0x19, + 0x1a, + 0x1b, + 0x1c, + 0x1d, + 0x1e, + 0x1f, + 0x5c, + 0xfb, + 0x7ffc, + 0x20, + 0xffb, + 0x3fc, + 0x1ffa, + 0x21, + 0x5d, + 0x5e, + 0x5f, + 0x60, + 0x61, + 0x62, + 0x63, + 0x64, + 0x65, + 0x66, + 0x67, + 0x68, + 0x69, + 0x6a, + 0x6b, + 0x6c, + 0x6d, + 0x6e, + 0x6f, + 0x70, + 0x71, + 0x72, + 0xfc, + 0x73, + 0xfd, + 0x1ffb, + 0x7fff0, + 0x1ffc, + 0x3ffc, + 0x22, + 0x7ffd, + 0x3, + 0x23, + 0x4, + 0x24, + 0x5, + 0x25, + 0x26, + 0x27, + 0x6, + 0x74, + 0x75, + 0x28, + 0x29, + 0x2a, + 0x7, + 0x2b, + 0x76, + 0x2c, + 0x8, + 0x9, + 0x2d, + 0x77, + 0x78, + 0x79, + 0x7a, + 0x7b, + 0x7ffe, + 0x7fc, + 0x3ffd, + 0x1ffd, + 0xffffffc, + 0xfffe6, + 0x3fffd2, + 0xfffe7, + 0xfffe8, + 0x3fffd3, + 0x3fffd4, + 0x3fffd5, + 0x7fffd9, + 0x3fffd6, + 0x7fffda, + 0x7fffdb, + 0x7fffdc, + 0x7fffdd, + 0x7fffde, + 0xffffeb, + 0x7fffdf, + 0xffffec, + 0xffffed, + 0x3fffd7, + 0x7fffe0, + 0xffffee, + 0x7fffe1, + 0x7fffe2, + 0x7fffe3, + 0x7fffe4, + 0x1fffdc, + 0x3fffd8, + 0x7fffe5, + 0x3fffd9, + 0x7fffe6, + 0x7fffe7, + 0xffffef, + 0x3fffda, + 0x1fffdd, + 0xfffe9, + 0x3fffdb, + 0x3fffdc, + 0x7fffe8, + 0x7fffe9, + 0x1fffde, + 0x7fffea, + 0x3fffdd, + 0x3fffde, + 0xfffff0, + 0x1fffdf, + 0x3fffdf, + 0x7fffeb, + 0x7fffec, + 0x1fffe0, + 0x1fffe1, + 0x3fffe0, + 0x1fffe2, + 
0x7fffed, + 0x3fffe1, + 0x7fffee, + 0x7fffef, + 0xfffea, + 0x3fffe2, + 0x3fffe3, + 0x3fffe4, + 0x7ffff0, + 0x3fffe5, + 0x3fffe6, + 0x7ffff1, + 0x3ffffe0, + 0x3ffffe1, + 0xfffeb, + 0x7fff1, + 0x3fffe7, + 0x7ffff2, + 0x3fffe8, + 0x1ffffec, + 0x3ffffe2, + 0x3ffffe3, + 0x3ffffe4, + 0x7ffffde, + 0x7ffffdf, + 0x3ffffe5, + 0xfffff1, + 0x1ffffed, + 0x7fff2, + 0x1fffe3, + 0x3ffffe6, + 0x7ffffe0, + 0x7ffffe1, + 0x3ffffe7, + 0x7ffffe2, + 0xfffff2, + 0x1fffe4, + 0x1fffe5, + 0x3ffffe8, + 0x3ffffe9, + 0xffffffd, + 0x7ffffe3, + 0x7ffffe4, + 0x7ffffe5, + 0xfffec, + 0xfffff3, + 0xfffed, + 0x1fffe6, + 0x3fffe9, + 0x1fffe7, + 0x1fffe8, + 0x7ffff3, + 0x3fffea, + 0x3fffeb, + 0x1ffffee, + 0x1ffffef, + 0xfffff4, + 0xfffff5, + 0x3ffffea, + 0x7ffff4, + 0x3ffffeb, + 0x7ffffe6, + 0x3ffffec, + 0x3ffffed, + 0x7ffffe7, + 0x7ffffe8, + 0x7ffffe9, + 0x7ffffea, + 0x7ffffeb, + 0xffffffe, + 0x7ffffec, + 0x7ffffed, + 0x7ffffee, + 0x7ffffef, + 0x7fffff0, + 0x3ffffee, + 0x3fffffff // EOS + }; + + static final byte[] HUFFMAN_CODE_LENGTHS = { + 13, 23, 28, 28, 28, 28, 28, 28, 28, 24, 30, 28, 28, 30, 28, 28, + 28, 28, 28, 28, 28, 28, 30, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 6, 10, 10, 12, 13, 6, 8, 11, 10, 10, 8, 11, 8, 6, 6, 6, + 5, 5, 5, 6, 6, 6, 6, 6, 6, 6, 7, 8, 15, 6, 12, 10, + 13, 6, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, 7, + 7, 7, 7, 7, 7, 7, 7, 7, 8, 7, 8, 13, 19, 13, 14, 6, + 15, 5, 6, 5, 6, 5, 6, 6, 6, 5, 7, 7, 6, 6, 6, 5, + 6, 7, 6, 5, 5, 6, 7, 7, 7, 7, 7, 15, 11, 14, 13, 28, + 20, 22, 20, 20, 22, 22, 22, 23, 22, 23, 23, 23, 23, 23, 24, 23, + 24, 24, 22, 23, 24, 23, 23, 23, 23, 21, 22, 23, 22, 23, 23, 24, + 22, 21, 20, 22, 22, 23, 23, 21, 23, 22, 22, 24, 21, 22, 23, 23, + 21, 21, 22, 21, 23, 22, 23, 23, 20, 22, 22, 22, 23, 22, 22, 23, + 26, 26, 20, 19, 22, 23, 22, 25, 26, 26, 26, 27, 27, 26, 24, 25, + 19, 21, 26, 27, 27, 26, 27, 24, 21, 21, 26, 26, 28, 27, 27, 27, + 20, 24, 20, 21, 22, 21, 21, 23, 22, 22, 25, 25, 24, 24, 26, 23, + 26, 27, 26, 26, 27, 27, 27, 27, 27, 28, 27, 27, 27, 27, 27, 26, 
+ 30 // EOS + }; + + static final int HUFFMAN_EOS = 256; + + private HpackUtil() { + // utility class + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ChannelDuplexHandler.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ChannelDuplexHandler.java new file mode 100644 index 0000000..f999e20 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ChannelDuplexHandler.java @@ -0,0 +1,94 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.UnstableApi; + +/** + * A {@link ChannelDuplexHandler} providing additional functionality for HTTP/2. Specifically it allows to: + *
    + *
  • Create new outbound streams using {@link #newStream()}.
  • + *
  • Iterate over all active streams using {@link #forEachActiveStream(Http2FrameStreamVisitor)}.
  • + *
+ * + *

The {@link Http2FrameCodec} is required to be part of the {@link ChannelPipeline} before this handler is added, + * or else an {@link IllegalStateException} will be thrown. + */ +@UnstableApi +public abstract class Http2ChannelDuplexHandler extends ChannelDuplexHandler { + + private volatile Http2FrameCodec frameCodec; + + @Override + public final void handlerAdded(ChannelHandlerContext ctx) throws Exception { + frameCodec = requireHttp2FrameCodec(ctx); + handlerAdded0(ctx); + } + + protected void handlerAdded0(@SuppressWarnings("unused") ChannelHandlerContext ctx) throws Exception { + // NOOP + } + + @Override + public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + handlerRemoved0(ctx); + } finally { + frameCodec = null; + } + } + + protected void handlerRemoved0(@SuppressWarnings("unused") ChannelHandlerContext ctx) throws Exception { + // NOOP + } + + /** + * Creates a new {@link Http2FrameStream} object. + * + *

This method is thread-safe. + */ + public final Http2FrameStream newStream() { + Http2FrameCodec codec = frameCodec; + if (codec == null) { + throw new IllegalStateException(StringUtil.simpleClassName(Http2FrameCodec.class) + " not found." + + " Has the handler been added to a pipeline?"); + } + return codec.newStream(); + } + + /** + * Allows to iterate over all currently active streams. + * + *

This method may only be called from the eventloop thread. + */ + protected final void forEachActiveStream(Http2FrameStreamVisitor streamVisitor) throws Http2Exception { + frameCodec.forEachActiveStream(streamVisitor); + } + + private static Http2FrameCodec requireHttp2FrameCodec(ChannelHandlerContext ctx) { + ChannelHandlerContext frameCodecCtx = ctx.pipeline().context(Http2FrameCodec.class); + if (frameCodecCtx == null) { + throw new IllegalArgumentException(Http2FrameCodec.class.getSimpleName() + + " was not found in the channel pipeline."); + } + return (Http2FrameCodec) frameCodecCtx.handler(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodec.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodec.java new file mode 100644 index 0000000..a6b6fb2 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodec.java @@ -0,0 +1,175 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.base64.Base64; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpClientUpgradeHandler; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.util.collection.CharObjectMap; +import io.netty.util.internal.UnstableApi; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static io.netty.handler.codec.base64.Base64Dialect.URL_SAFE; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_PROTOCOL_NAME; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_SETTINGS_HEADER; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTING_ENTRY_LENGTH; +import static io.netty.util.CharsetUtil.UTF_8; +import static io.netty.util.ReferenceCountUtil.release; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Client-side cleartext upgrade codec from HTTP to HTTP/2. + */ +@UnstableApi +public class Http2ClientUpgradeCodec implements HttpClientUpgradeHandler.UpgradeCodec { + + private static final List UPGRADE_HEADERS = Collections.singletonList(HTTP_UPGRADE_SETTINGS_HEADER); + + private final String handlerName; + private final Http2ConnectionHandler connectionHandler; + private final ChannelHandler upgradeToHandler; + private final ChannelHandler http2MultiplexHandler; + + public Http2ClientUpgradeCodec(Http2FrameCodec frameCodec, ChannelHandler upgradeToHandler) { + this(null, frameCodec, upgradeToHandler); + } + + public Http2ClientUpgradeCodec(String handlerName, Http2FrameCodec frameCodec, ChannelHandler upgradeToHandler) { + this(handlerName, (Http2ConnectionHandler) frameCodec, upgradeToHandler, null); + } + + /** + * Creates the codec using a default name for the connection handler when adding to the + * pipeline. 
+ * + * @param connectionHandler the HTTP/2 connection handler + */ + public Http2ClientUpgradeCodec(Http2ConnectionHandler connectionHandler) { + this((String) null, connectionHandler); + } + + /** + * Creates the codec using a default name for the connection handler when adding to the + * pipeline. + * + * @param connectionHandler the HTTP/2 connection handler + * @param http2MultiplexHandler the Http2 Multiplexer handler to work with Http2FrameCodec + */ + public Http2ClientUpgradeCodec(Http2ConnectionHandler connectionHandler, + Http2MultiplexHandler http2MultiplexHandler) { + this((String) null, connectionHandler, http2MultiplexHandler); + } + + /** + * Creates the codec providing an upgrade to the given handler for HTTP/2. + * + * @param handlerName the name of the HTTP/2 connection handler to be used in the pipeline, + * or {@code null} to auto-generate the name + * @param connectionHandler the HTTP/2 connection handler + */ + public Http2ClientUpgradeCodec(String handlerName, Http2ConnectionHandler connectionHandler) { + this(handlerName, connectionHandler, connectionHandler, null); + } + + /** + * Creates the codec providing an upgrade to the given handler for HTTP/2. 
+ * + * @param handlerName the name of the HTTP/2 connection handler to be used in the pipeline, + * or {@code null} to auto-generate the name + * @param connectionHandler the HTTP/2 connection handler + */ + public Http2ClientUpgradeCodec(String handlerName, Http2ConnectionHandler connectionHandler, + Http2MultiplexHandler http2MultiplexHandler) { + this(handlerName, connectionHandler, connectionHandler, http2MultiplexHandler); + } + + private Http2ClientUpgradeCodec(String handlerName, Http2ConnectionHandler connectionHandler, ChannelHandler + upgradeToHandler, Http2MultiplexHandler http2MultiplexHandler) { + this.handlerName = handlerName; + this.connectionHandler = checkNotNull(connectionHandler, "connectionHandler"); + this.upgradeToHandler = checkNotNull(upgradeToHandler, "upgradeToHandler"); + this.http2MultiplexHandler = http2MultiplexHandler; + } + + @Override + public CharSequence protocol() { + return HTTP_UPGRADE_PROTOCOL_NAME; + } + + @Override + public Collection setUpgradeHeaders(ChannelHandlerContext ctx, + HttpRequest upgradeRequest) { + CharSequence settingsValue = getSettingsHeaderValue(ctx); + upgradeRequest.headers().set(HTTP_UPGRADE_SETTINGS_HEADER, settingsValue); + return UPGRADE_HEADERS; + } + + @Override + public void upgradeTo(ChannelHandlerContext ctx, FullHttpResponse upgradeResponse) + throws Exception { + try { + // Add the handler to the pipeline. + ctx.pipeline().addAfter(ctx.name(), handlerName, upgradeToHandler); + + // Add the Http2 Multiplex handler as this handler handle events produced by the connectionHandler. + // See https://github.com/netty/netty/issues/9495 + if (http2MultiplexHandler != null) { + final String name = ctx.pipeline().context(connectionHandler).name(); + ctx.pipeline().addAfter(name, null, http2MultiplexHandler); + } + + // Reserve local stream 1 for the response. 
+ connectionHandler.onHttpClientUpgrade(); + } catch (Http2Exception e) { + ctx.fireExceptionCaught(e); + ctx.close(); + } + } + + /** + * Converts the current settings for the handler to the Base64-encoded representation used in + * the HTTP2-Settings upgrade header. + */ + private CharSequence getSettingsHeaderValue(ChannelHandlerContext ctx) { + ByteBuf buf = null; + ByteBuf encodedBuf = null; + try { + // Get the local settings for the handler. + Http2Settings settings = connectionHandler.decoder().localSettings(); + + // Serialize the payload of the SETTINGS frame. + int payloadLength = SETTING_ENTRY_LENGTH * settings.size(); + buf = ctx.alloc().buffer(payloadLength); + for (CharObjectMap.PrimitiveEntry entry : settings.entries()) { + buf.writeChar(entry.key()); + buf.writeInt(entry.value().intValue()); + } + + // Base64 encode the payload and then convert to a string for the header. + encodedBuf = Base64.encode(buf, URL_SAFE); + return encodedBuf.toString(UTF_8); + } finally { + release(buf); + release(encodedBuf); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java new file mode 100644 index 0000000..0dead35 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2CodecUtil.java @@ -0,0 +1,404 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.util.AsciiString; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.UnstableApi; + +import static io.netty.buffer.Unpooled.directBuffer; +import static io.netty.buffer.Unpooled.unreleasableBuffer; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.headerListSizeError; +import static io.netty.util.CharsetUtil.UTF_8; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; + +/** + * Constants and utility method used for encoding/decoding HTTP2 frames. + */ +@UnstableApi +public final class Http2CodecUtil { + public static final int CONNECTION_STREAM_ID = 0; + public static final int HTTP_UPGRADE_STREAM_ID = 1; + public static final CharSequence HTTP_UPGRADE_SETTINGS_HEADER = AsciiString.cached("HTTP2-Settings"); + public static final CharSequence HTTP_UPGRADE_PROTOCOL_NAME = "h2c"; + public static final CharSequence TLS_UPGRADE_PROTOCOL_NAME = ApplicationProtocolNames.HTTP_2; + + public static final int PING_FRAME_PAYLOAD_LENGTH = 8; + public static final short MAX_UNSIGNED_BYTE = 0xff; + /** + * The maximum number of padding bytes. That is the 255 padding bytes appended to the end of a frame and the 1 byte + * pad length field. 
+ */ + public static final int MAX_PADDING = 256; + public static final long MAX_UNSIGNED_INT = 0xffffffffL; + public static final int FRAME_HEADER_LENGTH = 9; + public static final int SETTING_ENTRY_LENGTH = 6; + public static final int PRIORITY_ENTRY_LENGTH = 5; + public static final int INT_FIELD_LENGTH = 4; + public static final short MAX_WEIGHT = 256; + public static final short MIN_WEIGHT = 1; + + private static final ByteBuf CONNECTION_PREFACE = + unreleasableBuffer(directBuffer(24).writeBytes("PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n".getBytes(UTF_8))) + .asReadOnly(); + + private static final int MAX_PADDING_LENGTH_LENGTH = 1; + public static final int DATA_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH; + public static final int HEADERS_FRAME_HEADER_LENGTH = + FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH + 1; + public static final int PRIORITY_FRAME_LENGTH = FRAME_HEADER_LENGTH + PRIORITY_ENTRY_LENGTH; + public static final int RST_STREAM_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH; + public static final int PUSH_PROMISE_FRAME_HEADER_LENGTH = + FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH + INT_FIELD_LENGTH; + public static final int GO_AWAY_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + 2 * INT_FIELD_LENGTH; + public static final int WINDOW_UPDATE_FRAME_LENGTH = FRAME_HEADER_LENGTH + INT_FIELD_LENGTH; + public static final int CONTINUATION_FRAME_HEADER_LENGTH = FRAME_HEADER_LENGTH + MAX_PADDING_LENGTH_LENGTH; + + public static final char SETTINGS_HEADER_TABLE_SIZE = 1; + public static final char SETTINGS_ENABLE_PUSH = 2; + public static final char SETTINGS_MAX_CONCURRENT_STREAMS = 3; + public static final char SETTINGS_INITIAL_WINDOW_SIZE = 4; + public static final char SETTINGS_MAX_FRAME_SIZE = 5; + public static final char SETTINGS_MAX_HEADER_LIST_SIZE = 6; + public static final int NUM_STANDARD_SETTINGS = 6; + + public static final long MAX_HEADER_TABLE_SIZE = MAX_UNSIGNED_INT; + public static final long 
MAX_CONCURRENT_STREAMS = MAX_UNSIGNED_INT; + public static final int MAX_INITIAL_WINDOW_SIZE = Integer.MAX_VALUE; + public static final int MAX_FRAME_SIZE_LOWER_BOUND = 0x4000; + public static final int MAX_FRAME_SIZE_UPPER_BOUND = 0xffffff; + public static final long MAX_HEADER_LIST_SIZE = MAX_UNSIGNED_INT; + + public static final long MIN_HEADER_TABLE_SIZE = 0; + public static final long MIN_CONCURRENT_STREAMS = 0; + public static final int MIN_INITIAL_WINDOW_SIZE = 0; + public static final long MIN_HEADER_LIST_SIZE = 0; + + public static final int DEFAULT_WINDOW_SIZE = 65535; + public static final short DEFAULT_PRIORITY_WEIGHT = 16; + public static final int DEFAULT_HEADER_TABLE_SIZE = 4096; + /** + * The initial value of this setting is unlimited. + * However in practice we don't want to allow our peers to use unlimited memory by default. So we take advantage + * of the For any given request, a lower limit than what is advertised MAY be enforced. loophole. + */ + public static final long DEFAULT_HEADER_LIST_SIZE = 8192; + public static final int DEFAULT_MAX_FRAME_SIZE = MAX_FRAME_SIZE_LOWER_BOUND; + /** + * The assumed minimum value for {@code SETTINGS_MAX_CONCURRENT_STREAMS} as + * recommended by the HTTP/2 spec. + */ + public static final int SMALLEST_MAX_CONCURRENT_STREAMS = 100; + static final int DEFAULT_MAX_RESERVED_STREAMS = SMALLEST_MAX_CONCURRENT_STREAMS; + static final int DEFAULT_MIN_ALLOCATION_CHUNK = 1024; + + /** + * Calculate the threshold in bytes which should trigger a {@code GO_AWAY} if a set of headers exceeds this amount. + * @param maxHeaderListSize + * SETTINGS_MAX_HEADER_LIST_SIZE for the local + * endpoint. + * @return the threshold in bytes which should trigger a {@code GO_AWAY} if a set of headers exceeds this amount. + */ + public static long calculateMaxHeaderListSizeGoAway(long maxHeaderListSize) { + // This is equivalent to `maxHeaderListSize * 1.25` but we avoid floating point multiplication. 
+ return maxHeaderListSize + (maxHeaderListSize >>> 2); + } + + public static final long DEFAULT_GRACEFUL_SHUTDOWN_TIMEOUT_MILLIS = MILLISECONDS.convert(30, SECONDS); + + public static final int DEFAULT_MAX_QUEUED_CONTROL_FRAMES = 10000; + + /** + * Returns {@code true} if the stream is an outbound stream. + * + * @param server {@code true} if the endpoint is a server, {@code false} otherwise. + * @param streamId the stream identifier + */ + public static boolean isOutboundStream(boolean server, int streamId) { + boolean even = (streamId & 1) == 0; + return streamId > 0 && server == even; + } + + /** + * Returns true if the {@code streamId} is a valid HTTP/2 stream identifier. + */ + public static boolean isStreamIdValid(int streamId) { + return streamId >= 0; + } + + static boolean isStreamIdValid(int streamId, boolean server) { + return isStreamIdValid(streamId) && server == ((streamId & 1) == 0); + } + + /** + * Indicates whether or not the given value for max frame size falls within the valid range. + */ + public static boolean isMaxFrameSizeValid(int maxFrameSize) { + return maxFrameSize >= MAX_FRAME_SIZE_LOWER_BOUND && maxFrameSize <= MAX_FRAME_SIZE_UPPER_BOUND; + } + + /** + * Returns a buffer containing the {@link #CONNECTION_PREFACE}. + */ + public static ByteBuf connectionPrefaceBuf() { + // Return a duplicate so that modifications to the reader index will not affect the original buffer. + return CONNECTION_PREFACE.retainedDuplicate(); + } + + /** + * Iteratively looks through the causality chain for the given exception and returns the first + * {@link Http2Exception} or {@code null} if none. + */ + public static Http2Exception getEmbeddedHttp2Exception(Throwable cause) { + while (cause != null) { + if (cause instanceof Http2Exception) { + return (Http2Exception) cause; + } + cause = cause.getCause(); + } + return null; + } + + /** + * Creates a buffer containing the error message from the given exception. 
If the cause is + * {@code null} returns an empty buffer. + */ + public static ByteBuf toByteBuf(ChannelHandlerContext ctx, Throwable cause) { + if (cause == null || cause.getMessage() == null) { + return Unpooled.EMPTY_BUFFER; + } + + return ByteBufUtil.writeUtf8(ctx.alloc(), cause.getMessage()); + } + + /** + * Reads a big-endian (31-bit) integer from the buffer. + */ + public static int readUnsignedInt(ByteBuf buf) { + return buf.readInt() & 0x7fffffff; + } + + /** + * Writes an HTTP/2 frame header to the output buffer. + */ + public static void writeFrameHeader(ByteBuf out, int payloadLength, byte type, + Http2Flags flags, int streamId) { + out.ensureWritable(FRAME_HEADER_LENGTH + payloadLength); + writeFrameHeaderInternal(out, payloadLength, type, flags, streamId); + } + + /** + * Calculate the amount of bytes that can be sent by {@code state}. The lower bound is {@code 0}. + */ + public static int streamableBytes(StreamByteDistributor.StreamState state) { + return max(0, (int) min(state.pendingBytes(), state.windowSize())); + } + + /** + * Results in a RST_STREAM being sent for {@code streamId} due to violating + * SETTINGS_MAX_HEADER_LIST_SIZE. + * @param streamId The stream ID that was being processed when the exceptional condition occurred. + * @param maxHeaderListSize The max allowed size for a list of headers in bytes which was exceeded. + * @param onDecode {@code true} if the exception was encountered during decoder. {@code false} for encode. + * @throws Http2Exception a stream error. + */ + public static void headerListSizeExceeded(int streamId, long maxHeaderListSize, + boolean onDecode) throws Http2Exception { + throw headerListSizeError(streamId, PROTOCOL_ERROR, onDecode, "Header size exceeded max " + + "allowed size (%d)", maxHeaderListSize); + } + + /** + * Results in a GO_AWAY being sent due to violating + * SETTINGS_MAX_HEADER_LIST_SIZE in an unrecoverable + * manner. 
+ * @param maxHeaderListSize The max allowed size for a list of headers in bytes which was exceeded. + * @throws Http2Exception a connection error. + */ + public static void headerListSizeExceeded(long maxHeaderListSize) throws Http2Exception { + throw connectionError(PROTOCOL_ERROR, "Header size exceeded max " + + "allowed size (%d)", maxHeaderListSize); + } + + static void writeFrameHeaderInternal(ByteBuf out, int payloadLength, byte type, + Http2Flags flags, int streamId) { + out.writeMedium(payloadLength); + out.writeByte(type); + out.writeByte(flags.value()); + out.writeInt(streamId); + } + + /** + * Provides the ability to associate the outcome of multiple {@link ChannelPromise} + * objects into a single {@link ChannelPromise} object. + */ + static final class SimpleChannelPromiseAggregator extends DefaultChannelPromise { + private final ChannelPromise promise; + private int expectedCount; + private int doneCount; + private Throwable aggregateFailure; + private boolean doneAllocating; + + SimpleChannelPromiseAggregator(ChannelPromise promise, Channel c, EventExecutor e) { + super(c, e); + assert promise != null && !promise.isDone(); + this.promise = promise; + } + + /** + * Allocate a new promise which will be used to aggregate the overall success of this promise aggregator. + * @return A new promise which will be aggregated. + * {@code null} if {@link #doneAllocatingPromises()} was previously called. + */ + public ChannelPromise newPromise() { + assert !doneAllocating : "Done allocating. No more promises can be allocated."; + ++expectedCount; + return this; + } + + /** + * Signify that no more {@link #newPromise()} allocations will be made. + * The aggregation can not be successful until this method is called. + * @return The promise that is the aggregation of all promises allocated with {@link #newPromise()}. 
+ */ + public ChannelPromise doneAllocatingPromises() { + if (!doneAllocating) { + doneAllocating = true; + if (doneCount == expectedCount || expectedCount == 0) { + return setPromise(); + } + } + return this; + } + + @Override + public boolean tryFailure(Throwable cause) { + if (allowFailure()) { + ++doneCount; + setAggregateFailure(cause); + if (allPromisesDone()) { + return tryPromise(); + } + // TODO: We break the interface a bit here. + // Multiple failure events can be processed without issue because this is an aggregation. + return true; + } + return false; + } + + /** + * Fail this object if it has not already been failed. + *

+ * This method will NOT throw an {@link IllegalStateException} if called multiple times + * because that may be expected. + */ + @Override + public ChannelPromise setFailure(Throwable cause) { + if (allowFailure()) { + ++doneCount; + setAggregateFailure(cause); + if (allPromisesDone()) { + return setPromise(); + } + } + return this; + } + + @Override + public ChannelPromise setSuccess(Void result) { + if (awaitingPromises()) { + ++doneCount; + if (allPromisesDone()) { + setPromise(); + } + } + return this; + } + + @Override + public boolean trySuccess(Void result) { + if (awaitingPromises()) { + ++doneCount; + if (allPromisesDone()) { + return tryPromise(); + } + // TODO: We break the interface a bit here. + // Multiple success events can be processed without issue because this is an aggregation. + return true; + } + return false; + } + + private boolean allowFailure() { + return awaitingPromises() || expectedCount == 0; + } + + private boolean awaitingPromises() { + return doneCount < expectedCount; + } + + private boolean allPromisesDone() { + return doneCount == expectedCount && doneAllocating; + } + + private ChannelPromise setPromise() { + if (aggregateFailure == null) { + promise.setSuccess(); + return super.setSuccess(null); + } else { + promise.setFailure(aggregateFailure); + return super.setFailure(aggregateFailure); + } + } + + private boolean tryPromise() { + if (aggregateFailure == null) { + promise.trySuccess(); + return super.trySuccess(null); + } else { + promise.tryFailure(aggregateFailure); + return super.tryFailure(aggregateFailure); + } + } + + private void setAggregateFailure(Throwable cause) { + if (aggregateFailure == null) { + aggregateFailure = cause; + } + } + } + + public static void verifyPadding(int padding) { + if (padding < 0 || padding > MAX_PADDING) { + throw new IllegalArgumentException(String.format("Invalid padding '%d'. 
Padding must be between 0 and " + + "%d (inclusive).", padding, MAX_PADDING)); + } + } + private Http2CodecUtil() { } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java new file mode 100644 index 0000000..96f3013 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Connection.java @@ -0,0 +1,356 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.UnstableApi; + +/** + * Manager for the state of an HTTP/2 connection with the remote end-point. + */ +@UnstableApi +public interface Http2Connection { + /** + * Listener for life-cycle events for streams in this connection. + */ + interface Listener { + /** + * Notifies the listener that the given stream was added to the connection. This stream may + * not yet be active (i.e. {@code OPEN} or {@code HALF CLOSED}). + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + */ + void onStreamAdded(Http2Stream stream); + + /** + * Notifies the listener that the given stream was made active (i.e. {@code OPEN} or {@code HALF CLOSED}). + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + */ + void onStreamActive(Http2Stream stream); + + /** + * Notifies the listener that the given stream has transitioned from {@code OPEN} to {@code HALF CLOSED}. + * This method will not be called until a state transition occurs from when + * {@link #onStreamActive(Http2Stream)} was called. + * The stream can be inspected to determine which side is {@code HALF CLOSED}. + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + */ + void onStreamHalfClosed(Http2Stream stream); + + /** + * Notifies the listener that the given stream is now {@code CLOSED} in both directions and will no longer + * be accessible via {@link #forEachActiveStream(Http2StreamVisitor)}. + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + */ + void onStreamClosed(Http2Stream stream); + + /** + * Notifies the listener that the given stream has now been removed from the connection and + * will no longer be returned via {@link Http2Connection#stream(int)}. The connection may + * maintain inactive streams for some time before removing them. + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + */ + void onStreamRemoved(Http2Stream stream); + + /** + * Called when a {@code GOAWAY} frame was sent for the connection. + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + * @param lastStreamId the last known stream of the remote endpoint. + * @param errorCode the error code, if abnormal closure. + * @param debugData application-defined debug data. + */ + void onGoAwaySent(int lastStreamId, long errorCode, ByteBuf debugData); + + /** + * Called when a {@code GOAWAY} was received from the remote endpoint. This event handler duplicates {@link + * Http2FrameListener#onGoAwayRead(io.netty.channel.ChannelHandlerContext, int, long, ByteBuf)} + * but is added here in order to simplify application logic for handling {@code GOAWAY} in a uniform way. An + * application should generally not handle both events, but if it does this method is called second, after + * notifying the {@link Http2FrameListener}. + *

+ * If a {@link RuntimeException} is thrown it will be logged and not propagated. + * Throwing from this method is not supported and is considered a programming error. + * @param lastStreamId the last known stream of the remote endpoint. + * @param errorCode the error code, if abnormal closure. + * @param debugData application-defined debug data. + */ + void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData); + } + + /** + * A view of the connection from one endpoint (local or remote). + */ + interface Endpoint { + /** + * Increment and get the next generated stream id this endpoint. If negative, the stream IDs are + * exhausted for this endpoint an no further streams may be created. + */ + int incrementAndGetNextStreamId(); + + /** + * Indicates whether the given streamId is from the set of IDs used by this endpoint to + * create new streams. + */ + boolean isValidStreamId(int streamId); + + /** + * Indicates whether or not this endpoint may have created the given stream. This is {@code true} if + * {@link #isValidStreamId(int)} and {@code streamId} <= {@link #lastStreamCreated()}. + */ + boolean mayHaveCreatedStream(int streamId); + + /** + * Indicates whether or not this endpoint created the given stream. + */ + boolean created(Http2Stream stream); + + /** + * Indicates whether or a stream created by this endpoint can be opened without violating + * {@link #maxActiveStreams()}. + */ + boolean canOpenStream(); + + /** + * Creates a stream initiated by this endpoint. This could fail for the following reasons: + *

    + *
  • The requested stream ID is not the next sequential ID for this endpoint.
  • + *
  • The stream already exists.
  • + *
  • {@link #canOpenStream()} is {@code false}.
  • + *
  • The connection is marked as going away.
  • + *
+ *

+ * The initial state of the stream will be immediately set before notifying {@link Listener}s. The state + * transition is sensitive to {@code halfClosed} and is defined by {@link Http2Stream#open(boolean)}. + * @param streamId The ID of the stream + * @param halfClosed see {@link Http2Stream#open(boolean)}. + * @see Http2Stream#open(boolean) + */ + Http2Stream createStream(int streamId, boolean halfClosed) throws Http2Exception; + + /** + * Creates a push stream in the reserved state for this endpoint and notifies all listeners. + * This could fail for the following reasons: + *

    + *
  • Server push is not allowed to the opposite endpoint.
  • + *
  • The requested stream ID is not the next sequential stream ID for this endpoint.
  • + *
  • The number of concurrent streams is above the allowed threshold for this endpoint.
  • + *
  • The connection is marked as going away.
  • + *
  • The parent stream ID does not exist or is not {@code OPEN} from the side sending the push + * promise.
  • + *
  • Could not set a valid priority for the new stream.
  • + *
+ * + * @param streamId the ID of the push stream + * @param parent the parent stream used to initiate the push stream. + */ + Http2Stream reservePushStream(int streamId, Http2Stream parent) throws Http2Exception; + + /** + * Indicates whether or not this endpoint is the server-side of the connection. + */ + boolean isServer(); + + /** + * This is the SETTINGS_ENABLE_PUSH value sent + * from the opposite endpoint. This method should only be called by Netty (not users) as a result of a + * receiving a {@code SETTINGS} frame. + */ + void allowPushTo(boolean allow); + + /** + * This is the SETTINGS_ENABLE_PUSH value sent + * from the opposite endpoint. The initial value must be {@code true} for the client endpoint and always false + * for a server endpoint. + */ + boolean allowPushTo(); + + /** + * Gets the number of active streams (i.e. {@code OPEN} or {@code HALF CLOSED}) that were created by this + * endpoint. + */ + int numActiveStreams(); + + /** + * Gets the maximum number of streams (created by this endpoint) that are allowed to be active at + * the same time. This is the + * SETTINGS_MAX_CONCURRENT_STREAMS + * value sent from the opposite endpoint to restrict stream creation by this endpoint. + *

+ * The default value returned by this method must be "unlimited". + */ + int maxActiveStreams(); + + /** + * Sets the limit for {@code SETTINGS_MAX_CONCURRENT_STREAMS}. + * @param maxActiveStreams The maximum number of streams (created by this endpoint) that are allowed to be + * active at once. This is the + * SETTINGS_MAX_CONCURRENT_STREAMS value sent + * from the opposite endpoint to restrict stream creation by this endpoint. + */ + void maxActiveStreams(int maxActiveStreams); + + /** + * Gets the ID of the stream last successfully created by this endpoint. + */ + int lastStreamCreated(); + + /** + * If a GOAWAY was received for this endpoint, this will be the last stream ID from the + * GOAWAY frame. Otherwise, this will be {@code -1}. + */ + int lastStreamKnownByPeer(); + + /** + * Gets the flow controller for this endpoint. + */ + F flowController(); + + /** + * Sets the flow controller for this endpoint. + */ + void flowController(F flowController); + + /** + * Gets the {@link Endpoint} opposite this one. + */ + Endpoint opposite(); + } + + /** + * A key to be used for associating application-defined properties with streams within this connection. + */ + interface PropertyKey { + } + + /** + * Close this connection. No more new streams can be created after this point and + * all streams that exists (active or otherwise) will be closed and removed. + *

Note if iterating active streams via {@link #forEachActiveStream(Http2StreamVisitor)} and an exception is + * thrown it is necessary to call this method again to ensure the close completes. + * @param promise Will be completed when all streams have been removed, and listeners have been notified. + * @return A future that will be completed when all streams have been removed, and listeners have been notified. + */ + Future close(Promise promise); + + /** + * Creates a new key that is unique within this {@link Http2Connection}. + */ + PropertyKey newKey(); + + /** + * Adds a listener of stream life-cycle events. + */ + void addListener(Listener listener); + + /** + * Removes a listener of stream life-cycle events. If the same listener was added multiple times + * then only the first occurrence gets removed. + */ + void removeListener(Listener listener); + + /** + * Gets the stream if it exists. If not, returns {@code null}. + */ + Http2Stream stream(int streamId); + + /** + * Indicates whether or not the given stream may have existed within this connection. This is a short form + * for calling {@link Endpoint#mayHaveCreatedStream(int)} on both endpoints. + */ + boolean streamMayHaveExisted(int streamId); + + /** + * Gets the stream object representing the connection, itself (i.e. stream zero). This object + * always exists. + */ + Http2Stream connectionStream(); + + /** + * Gets the number of streams that are actively in use (i.e. {@code OPEN} or {@code HALF CLOSED}). + */ + int numActiveStreams(); + + /** + * Provide a means of iterating over the collection of active streams. + * + * @param visitor The visitor which will visit each active stream. + * @return The stream before iteration stopped or {@code null} if iteration went past the end. + */ + Http2Stream forEachActiveStream(Http2StreamVisitor visitor) throws Http2Exception; + + /** + * Indicates whether or not the local endpoint for this connection is the server. 
+ */ + boolean isServer(); + + /** + * Gets a view of this connection from the local {@link Endpoint}. + */ + Endpoint local(); + + /** + * Gets a view of this connection from the remote {@link Endpoint}. + */ + Endpoint remote(); + + /** + * Indicates whether or not a {@code GOAWAY} was received from the remote endpoint. + */ + boolean goAwayReceived(); + + /** + * Indicates that a {@code GOAWAY} was received from the remote endpoint and sets the last known stream. + * @param lastKnownStream The Last-Stream-ID in the + * GOAWAY frame. + * @param errorCode the Error Code in the + * GOAWAY frame. + * @param message The Additional Debug Data in the + * GOAWAY frame. Note that reference count ownership + * belongs to the caller (ownership is not transferred to this method). + */ + void goAwayReceived(int lastKnownStream, long errorCode, ByteBuf message) throws Http2Exception; + + /** + * Indicates whether or not a {@code GOAWAY} was sent to the remote endpoint. + */ + boolean goAwaySent(); + + /** + * Updates the local state of this {@link Http2Connection} as a result of a {@code GOAWAY} to send to the remote + * endpoint. + * @param lastKnownStream The Last-Stream-ID in the + * GOAWAY frame. + * @param errorCode the Error Code in the + * GOAWAY frame. + * GOAWAY frame. Note that reference count ownership + * belongs to the caller (ownership is not transferred to this method). + * @return {@code true} if the corresponding {@code GOAWAY} frame should be sent to the remote endpoint. 
+ */ + boolean goAwaySent(int lastKnownStream, long errorCode, ByteBuf message) throws Http2Exception; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionAdapter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionAdapter.java new file mode 100644 index 0000000..0ba4d94 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionAdapter.java @@ -0,0 +1,52 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.UnstableApi; + +/** + * Provides empty implementations of all {@link Http2Connection.Listener} methods. 
+ */ +@UnstableApi +public class Http2ConnectionAdapter implements Http2Connection.Listener { + @Override + public void onStreamAdded(Http2Stream stream) { + } + + @Override + public void onStreamActive(Http2Stream stream) { + } + + @Override + public void onStreamHalfClosed(Http2Stream stream) { + } + + @Override + public void onStreamClosed(Http2Stream stream) { + } + + @Override + public void onStreamRemoved(Http2Stream stream) { + } + + @Override + public void onGoAwaySent(int lastStreamId, long errorCode, ByteBuf debugData) { + } + + @Override + public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionDecoder.java new file mode 100644 index 0000000..71df7f9 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionDecoder.java @@ -0,0 +1,77 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +import java.io.Closeable; +import java.util.List; + +/** + * Handler for inbound traffic on behalf of {@link Http2ConnectionHandler}. 
Performs basic protocol + * conformance on inbound frames before calling the delegate {@link Http2FrameListener} for + * application-specific processing. Note that frames of an unknown type (i.e. HTTP/2 extensions) + * will skip all protocol checks and be given directly to the listener for processing. + */ +@UnstableApi +public interface Http2ConnectionDecoder extends Closeable { + + /** + * Sets the lifecycle manager. Must be called as part of initialization before the decoder is used. + */ + void lifecycleManager(Http2LifecycleManager lifecycleManager); + + /** + * Provides direct access to the underlying connection. + */ + Http2Connection connection(); + + /** + * Provides the local flow controller for managing inbound traffic. + */ + Http2LocalFlowController flowController(); + + /** + * Set the {@link Http2FrameListener} which will be notified when frames are decoded. + *

+ * This must be set before frames are decoded. + */ + void frameListener(Http2FrameListener listener); + + /** + * Get the {@link Http2FrameListener} which will be notified when frames are decoded. + */ + Http2FrameListener frameListener(); + + /** + * Called by the {@link Http2ConnectionHandler} to decode the next frame from the input buffer. + */ + void decodeFrame(ChannelHandlerContext ctx, ByteBuf in, List out) throws Http2Exception; + + /** + * Gets the local settings for this endpoint of the HTTP/2 connection. + */ + Http2Settings localSettings(); + + /** + * Indicates whether or not the first initial {@code SETTINGS} frame was received from the remote endpoint. + */ + boolean prefaceReceived(); + + @Override + void close(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionEncoder.java new file mode 100644 index 0000000..b521e9f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionEncoder.java @@ -0,0 +1,68 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.UnstableApi; + + +/** + * Handler for outbound HTTP/2 traffic. + */ +@UnstableApi +public interface Http2ConnectionEncoder extends Http2FrameWriter { + + /** + * Sets the lifecycle manager. Must be called as part of initialization before the encoder is used. + */ + void lifecycleManager(Http2LifecycleManager lifecycleManager); + + /** + * Provides direct access to the underlying connection. + */ + Http2Connection connection(); + + /** + * Provides the remote flow controller for managing outbound traffic. + */ + Http2RemoteFlowController flowController(); + + /** + * Provides direct access to the underlying frame writer object. + */ + Http2FrameWriter frameWriter(); + + /** + * Gets the local settings on the top of the queue that has been sent but not ACKed. This may + * return {@code null}. + */ + Http2Settings pollSentSettings(); + + /** + * Sets the settings for the remote endpoint of the HTTP/2 connection. + */ + void remoteSettings(Http2Settings settings) throws Http2Exception; + + /** + * Writes the given data to the internal {@link Http2FrameWriter} without performing any + * state checks on the connection/stream. 
+ */ + @Override + ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, + Http2Flags flags, ByteBuf payload, ChannelPromise promise); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java new file mode 100644 index 0000000..badd159 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandler.java @@ -0,0 +1,1009 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http2.Http2Exception.CompositeStreamException; +import io.netty.handler.codec.http2.Http2Exception.StreamException; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.SocketAddress; +import java.util.List; +import java.util.concurrent.TimeUnit; + +import static io.netty.buffer.ByteBufUtil.hexDump; +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; +import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.isStreamError; +import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; +import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; +import static io.netty.util.CharsetUtil.UTF_8; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Math.min; +import static 
java.util.concurrent.TimeUnit.MILLISECONDS; + +/** + * Provides the default implementation for processing inbound frame events and delegates to a + * {@link Http2FrameListener} + *

+ * This class will read HTTP/2 frames and delegate the events to a {@link Http2FrameListener} + *

 * This interface enforces inbound flow control functionality through
 * {@link Http2LocalFlowController}
 */
@UnstableApi
public class Http2ConnectionHandler extends ByteToMessageDecoder implements Http2LifecycleManager,
        ChannelOutboundHandler {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2ConnectionHandler.class);

    // Canned 431 response headers, sent when a server receives a header block exceeding the configured limit.
    private static final Http2Headers HEADERS_TOO_LARGE_HEADERS = ReadOnlyHttp2Headers.serverHeaders(false,
            HttpResponseStatus.REQUEST_HEADER_FIELDS_TOO_LARGE.codeAsText());
    // "HTTP/1." marker used to detect a plaintext HTTP/1.x request arriving where the HTTP/2 preface was expected.
    private static final ByteBuf HTTP_1_X_BUF = Unpooled.unreleasableBuffer(
            Unpooled.wrappedBuffer(new byte[] {'H', 'T', 'T', 'P', '/', '1', '.'})).asReadOnly();

    private final Http2ConnectionDecoder decoder;
    private final Http2ConnectionEncoder encoder;
    private final Http2Settings initialSettings;
    private final boolean decoupleCloseAndGoAway;
    private final boolean flushPreface;
    // Listener that performs the actual channel close once graceful shutdown completes; null when no close pending.
    private ChannelFutureListener closeListener;
    // State machine for inbound bytes: PrefaceDecoder until the preface/first SETTINGS is handled, then FrameDecoder.
    private BaseDecoder byteDecoder;
    private long gracefulShutdownTimeoutMillis;

    protected Http2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder,
                                     Http2Settings initialSettings) {
        this(decoder, encoder, initialSettings, false);
    }

    protected Http2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder,
                                     Http2Settings initialSettings, boolean decoupleCloseAndGoAway) {
        this(decoder, encoder, initialSettings, decoupleCloseAndGoAway, true);
    }

    protected Http2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder,
                                     Http2Settings initialSettings, boolean decoupleCloseAndGoAway,
                                     boolean flushPreface) {
        this.initialSettings = checkNotNull(initialSettings, "initialSettings");
        this.decoder = checkNotNull(decoder, "decoder");
        this.encoder = checkNotNull(encoder, "encoder");
        this.decoupleCloseAndGoAway = decoupleCloseAndGoAway;
        this.flushPreface = flushPreface;
        // Encoder and decoder must operate on the same Http2Connection or their stream state would diverge.
        if (encoder.connection() != decoder.connection()) {
            throw new IllegalArgumentException("Encoder and Decoder do not share the same connection object");
        }
    }

    /**
     * Get the amount of time (in milliseconds) this endpoint will wait for all streams to be closed before closing
     * the connection during the graceful shutdown process. Returns -1 if this connection is configured to wait
     * indefinitely for all streams to close.
     */
    public long gracefulShutdownTimeoutMillis() {
        return gracefulShutdownTimeoutMillis;
    }

    /**
     * Set the amount of time (in milliseconds) this endpoint will wait for all streams to be closed before closing
     * the connection during the graceful shutdown process.
     * @param gracefulShutdownTimeoutMillis the amount of time (in milliseconds) this endpoint will wait for all
     *                                      streams to be closed before closing the connection during the graceful
     *                                      shutdown process.
     */
    public void gracefulShutdownTimeoutMillis(long gracefulShutdownTimeoutMillis) {
        if (gracefulShutdownTimeoutMillis < -1) {
            throw new IllegalArgumentException("gracefulShutdownTimeoutMillis: " + gracefulShutdownTimeoutMillis +
                                               " (expected: -1 for indefinite or >= 0)");
        }
        this.gracefulShutdownTimeoutMillis = gracefulShutdownTimeoutMillis;
    }

    public Http2Connection connection() {
        return encoder.connection();
    }

    public Http2ConnectionDecoder decoder() {
        return decoder;
    }

    public Http2ConnectionEncoder encoder() {
        return encoder;
    }

    // True once the local connection preface (and initial SETTINGS) has been written.
    private boolean prefaceSent() {
        return byteDecoder != null && byteDecoder.prefaceSent();
    }

    /**
     * Handles the client-side (cleartext) upgrade from HTTP to HTTP/2.
     * Reserves local stream 1 for the HTTP/2 response.
     */
    public void onHttpClientUpgrade() throws Http2Exception {
        if (connection().isServer()) {
            throw connectionError(PROTOCOL_ERROR, "Client-side HTTP upgrade requested for a server");
        }
        if (!prefaceSent()) {
            // If the preface was not sent yet it most likely means the handler was not added to the pipeline before
            // calling this method.
            throw connectionError(INTERNAL_ERROR, "HTTP upgrade must occur after preface was sent");
        }
        if (decoder.prefaceReceived()) {
            throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is received");
        }

        // Create a local stream used for the HTTP cleartext upgrade.
        connection().local().createStream(HTTP_UPGRADE_STREAM_ID, true);
    }

    /**
     * Handles the server-side (cleartext) upgrade from HTTP to HTTP/2.
     * @param settings the settings for the remote endpoint.
     */
    public void onHttpServerUpgrade(Http2Settings settings) throws Http2Exception {
        if (!connection().isServer()) {
            throw connectionError(PROTOCOL_ERROR, "Server-side HTTP upgrade requested for a client");
        }
        if (!prefaceSent()) {
            // If the preface was not sent yet it most likely means the handler was not added to the pipeline before
            // calling this method.
            throw connectionError(INTERNAL_ERROR, "HTTP upgrade must occur after preface was sent");
        }
        if (decoder.prefaceReceived()) {
            throw connectionError(PROTOCOL_ERROR, "HTTP upgrade must occur before HTTP/2 preface is received");
        }

        // Apply the settings but no ACK is necessary.
        encoder.remoteSettings(settings);

        // Create a stream in the half-closed state.
        connection().remote().createStream(HTTP_UPGRADE_STREAM_ID, true);
    }

    @Override
    public void flush(ChannelHandlerContext ctx) {
        try {
            // Trigger pending writes in the remote flow controller.
            encoder.flowController().writePendingBytes();
            ctx.flush();
        } catch (Http2Exception e) {
            onError(ctx, true, e);
        } catch (Throwable cause) {
            onError(ctx, true, connectionError(INTERNAL_ERROR, cause, "Error flushing"));
        }
    }

    /**
     * Base of the two inbound-byte decoding states. PrefaceDecoder handles bytes until the connection preface
     * exchange is complete; FrameDecoder handles regular frame decoding afterwards.
     */
    private abstract class BaseDecoder {
        public abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception;
        public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { }
        public void channelActive(ChannelHandlerContext ctx) throws Exception { }

        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            // Connection has terminated, close the encoder and decoder.
            encoder().close();
            decoder().close();

            // We need to remove all streams (not just the active ones).
            // See https://github.com/netty/netty/issues/4838.
            connection().close(ctx.voidPromise());
        }

        /**
         * Determine if the HTTP/2 connection preface been sent.
         */
        public boolean prefaceSent() {
            return true;
        }
    }

    private final class PrefaceDecoder extends BaseDecoder {
        // Remaining bytes of the expected client preface; null for clients (nothing to read) or once fully consumed.
        private ByteBuf clientPrefaceString;
        private boolean prefaceSent;

        PrefaceDecoder(ChannelHandlerContext ctx) throws Exception {
            clientPrefaceString = clientPrefaceString(encoder.connection());
            // This handler was just added to the context. In case it was handled after
            // the connection became active, send the connection preface now.
            sendPreface(ctx);
        }

        @Override
        public boolean prefaceSent() {
            return prefaceSent;
        }

        @Override
        public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
            try {
                if (ctx.channel().isActive() && readClientPrefaceString(in) && verifyFirstFrameIsSettings(in)) {
                    // After the preface is read, it is time to hand over control to the post initialized decoder.
                    byteDecoder = new FrameDecoder();
                    byteDecoder.decode(ctx, in, out);
                }
            } catch (Throwable e) {
                onError(ctx, false, e);
            }
        }

        @Override
        public void channelActive(ChannelHandlerContext ctx) throws Exception {
            // The channel just became active - send the connection preface to the remote endpoint.
            sendPreface(ctx);

            if (flushPreface) {
                // As we don't know if any channelReadComplete() events will be triggered at all we need to ensure we
                // also flush. Otherwise the remote peer might never see the preface / settings frame.
                // See https://github.com/netty/netty/issues/12089
                ctx.flush();
            }
        }

        @Override
        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
            cleanup();
            super.channelInactive(ctx);
        }

        /**
         * Releases the {@code clientPrefaceString}. Any active streams will be left in the open.
         */
        @Override
        public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
            cleanup();
        }

        /**
         * Releases the {@code clientPrefaceString}. Any active streams will be left in the open.
         */
        private void cleanup() {
            if (clientPrefaceString != null) {
                clientPrefaceString.release();
                clientPrefaceString = null;
            }
        }

        /**
         * Decodes the client connection preface string from the input buffer.
         *
         * @return {@code true} if processing of the client preface string is complete. Since client preface strings
         * can only be received by servers, returns true immediately for client endpoints.
         */
        private boolean readClientPrefaceString(ByteBuf in) throws Http2Exception {
            if (clientPrefaceString == null) {
                return true;
            }

            int prefaceRemaining = clientPrefaceString.readableBytes();
            int bytesRead = min(in.readableBytes(), prefaceRemaining);

            // If the input so far doesn't match the preface, break the connection.
            if (bytesRead == 0 || !ByteBufUtil.equals(in, in.readerIndex(),
                                                      clientPrefaceString, clientPrefaceString.readerIndex(),
                                                      bytesRead)) {
                int maxSearch = 1024; // picked because 512 is too little, and 2048 too much
                // Look for an HTTP/1.x version marker so we can produce a clearer error for plaintext HTTP requests.
                int http1Index =
                    ByteBufUtil.indexOf(HTTP_1_X_BUF, in.slice(in.readerIndex(), min(in.readableBytes(), maxSearch)));
                if (http1Index != -1) {
                    String chunk = in.toString(in.readerIndex(), http1Index - in.readerIndex(), CharsetUtil.US_ASCII);
                    throw connectionError(PROTOCOL_ERROR, "Unexpected HTTP/1.x request: %s", chunk);
                }
                String receivedBytes = hexDump(in, in.readerIndex(),
                                               min(in.readableBytes(), clientPrefaceString.readableBytes()));
                throw connectionError(PROTOCOL_ERROR, "HTTP/2 client preface string missing or corrupt. " +
                                                      "Hex dump for received bytes: %s", receivedBytes);
            }
            in.skipBytes(bytesRead);
            clientPrefaceString.skipBytes(bytesRead);

            if (!clientPrefaceString.isReadable()) {
                // Entire preface has been read.
                clientPrefaceString.release();
                clientPrefaceString = null;
                return true;
            }
            return false;
        }

        /**
         * Peeks at that the next frame in the buffer and verifies that it is a non-ack {@code SETTINGS} frame.
         *
         * @param in the inbound buffer.
         * @return {@code true} if the next frame is a non-ack {@code SETTINGS} frame, {@code false} if more
         * data is required before we can determine the next frame type.
         * @throws Http2Exception thrown if the next frame is NOT a non-ack {@code SETTINGS} frame.
         */
        private boolean verifyFirstFrameIsSettings(ByteBuf in) throws Http2Exception {
            if (in.readableBytes() < 5) {
                // Need more data before we can see the frame type for the first frame.
                return false;
            }

            // Frame type is byte 3 and flags are byte 4 of the 9-byte frame header.
            short frameType = in.getUnsignedByte(in.readerIndex() + 3);
            short flags = in.getUnsignedByte(in.readerIndex() + 4);
            if (frameType != SETTINGS || (flags & Http2Flags.ACK) != 0) {
                throw connectionError(PROTOCOL_ERROR, "First received frame was not SETTINGS. " +
                                                      "Hex dump for first 5 bytes: %s",
                                      hexDump(in, in.readerIndex(), 5));
            }
            return true;
        }

        /**
         * Sends the HTTP/2 connection preface upon establishment of the connection, if not already sent.
         */
        private void sendPreface(ChannelHandlerContext ctx) throws Exception {
            if (prefaceSent || !ctx.channel().isActive()) {
                return;
            }

            prefaceSent = true;

            final boolean isClient = !connection().isServer();
            if (isClient) {
                // Clients must send the preface string as the first bytes on the connection.
                ctx.write(connectionPrefaceBuf()).addListener(ChannelFutureListener.CLOSE_ON_FAILURE);
            }

            // Both client and server must send their initial settings.
            encoder.writeSettings(ctx, initialSettings, ctx.newPromise()).addListener(
                    ChannelFutureListener.CLOSE_ON_FAILURE);

            if (isClient) {
                // If this handler is extended by the user and we directly fire the userEvent from this context then
                // the user will not see the event. We should fire the event starting with this handler so this class
                // (and extending classes) have a chance to process the event.
                userEventTriggered(ctx, Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE);
            }
        }
    }

    private final class FrameDecoder extends BaseDecoder {
        @Override
        public void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
            try {
                decoder.decodeFrame(ctx, in, out);
            } catch (Throwable e) {
                onError(ctx, false, e);
            }
        }
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        // Initialize the encoder, decoder, flow controllers, and internal state.
        encoder.lifecycleManager(this);
        decoder.lifecycleManager(this);
        encoder.flowController().channelHandlerContext(ctx);
        decoder.flowController().channelHandlerContext(ctx);
        byteDecoder = new PrefaceDecoder(ctx);
    }

    @Override
    protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception {
        if (byteDecoder != null) {
            byteDecoder.handlerRemoved(ctx);
            byteDecoder = null;
        }
    }

    @Override
    public void channelActive(ChannelHandlerContext ctx) throws Exception {
        if (byteDecoder == null) {
            byteDecoder = new PrefaceDecoder(ctx);
        }
        byteDecoder.channelActive(ctx);
        super.channelActive(ctx);
    }

    @Override
    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
        // Call super class first, as this may result in decode being called.
        super.channelInactive(ctx);
        if (byteDecoder != null) {
            byteDecoder.channelInactive(ctx);
            byteDecoder = null;
        }
    }

    @Override
    public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
        // Writability is expected to change while we are writing. We cannot allow this event to trigger reentering
        // the allocation and write loop. Reentering the event loop will lead to over or illegal allocation.
        try {
            if (ctx.channel().isWritable()) {
                flush(ctx);
            }
            encoder.flowController().channelWritabilityChanged();
        } finally {
            super.channelWritabilityChanged(ctx);
        }
    }

    @Override
    protected void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        byteDecoder.decode(ctx, in, out);
    }

    @Override
    public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
        ctx.bind(localAddress, promise);
    }

    @Override
    public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress,
                        ChannelPromise promise) throws Exception {
        ctx.connect(remoteAddress, localAddress, promise);
    }

    @Override
    public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        ctx.disconnect(promise);
    }

    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        if (decoupleCloseAndGoAway) {
            ctx.close(promise);
            return;
        }
        promise = promise.unvoid();
        // Avoid NotYetConnectedException and avoid sending before connection preface
        if (!ctx.channel().isActive() || !prefaceSent()) {
            ctx.close(promise);
            return;
        }

        // If the user has already sent a GO_AWAY frame they may be attempting to do a graceful shutdown which requires
        // sending multiple GO_AWAY frames. We should only send a GO_AWAY here if one has not already been sent. If
        // a GO_AWAY has been sent we send a empty buffer just so we can wait to close until all other data has been
        // flushed to the OS.
        // https://github.com/netty/netty/issues/5307
        ChannelFuture f = connection().goAwaySent() ? ctx.write(EMPTY_BUFFER) : goAway(ctx, null, ctx.newPromise());
        ctx.flush();
        doGracefulShutdown(ctx, f, promise);
    }

    // Builds a listener that closes the channel, bounded by the configured graceful-shutdown timeout
    // (or unbounded when the timeout is negative, i.e. "wait indefinitely").
    private ChannelFutureListener newClosingChannelFutureListener(
            ChannelHandlerContext ctx, ChannelPromise promise) {
        long gracefulShutdownTimeoutMillis = this.gracefulShutdownTimeoutMillis;
        return gracefulShutdownTimeoutMillis < 0 ?
                new ClosingChannelFutureListener(ctx, promise) :
                new ClosingChannelFutureListener(ctx, promise, gracefulShutdownTimeoutMillis, MILLISECONDS);
    }

    private void doGracefulShutdown(ChannelHandlerContext ctx, ChannelFuture future, final ChannelPromise promise) {
        final ChannelFutureListener listener = newClosingChannelFutureListener(ctx, promise);
        if (isGracefulShutdownComplete()) {
            // If there are no active streams, close immediately after the GO_AWAY write completes or the timeout
            // elapsed.
            future.addListener(listener);
        } else {
            // If there are active streams we should wait until they are all closed before closing the connection.

            // The ClosingChannelFutureListener will cascade promise completion. We need to always notify the
            // new ClosingChannelFutureListener when the graceful close completes if the promise is not null.
            if (closeListener == null) {
                closeListener = listener;
            } else if (promise != null) {
                // A close is already pending: chain the new listener after the existing one so both promises
                // are completed when the graceful shutdown finishes.
                final ChannelFutureListener oldCloseListener = closeListener;
                closeListener = new ChannelFutureListener() {
                    @Override
                    public void operationComplete(ChannelFuture future) throws Exception {
                        try {
                            oldCloseListener.operationComplete(future);
                        } finally {
                            listener.operationComplete(future);
                        }
                    }
                };
            }
        }
    }

    @Override
    public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        ctx.deregister(promise);
    }

    @Override
    public void read(ChannelHandlerContext ctx) throws Exception {
        ctx.read();
    }

    @Override
    public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception {
        ctx.write(msg, promise);
    }

    @Override
    public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
        // Trigger flush after read on the assumption that flush is cheap if there is nothing to write and that
        // for flow-control the read may release window that causes data to be written that can now be flushed.
        try {
            // First call channelReadComplete0(...) as this may produce more data that we want to flush
            channelReadComplete0(ctx);
        } finally {
            flush(ctx);
        }
    }

    final void channelReadComplete0(ChannelHandlerContext ctx) {
        // Discard bytes of the cumulation buffer if needed.
        discardSomeReadBytes();

        // Ensure we never stale the HTTP/2 Channel. Flow-control is enforced by HTTP/2.
        //
        // See https://tools.ietf.org/html/rfc7540#section-5.2.2
        if (!ctx.channel().config().isAutoRead()) {
            ctx.read();
        }

        ctx.fireChannelReadComplete();
    }

    /**
     * Handles {@link Http2Exception} objects that were thrown from other handlers. Ignores all other exceptions.
     */
    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        if (getEmbeddedHttp2Exception(cause) != null) {
            // Some exception in the causality chain is an Http2Exception - handle it.
            onError(ctx, false, cause);
        } else {
            super.exceptionCaught(ctx, cause);
        }
    }

    /**
     * Closes the local side of the given stream. If this causes the stream to be closed, adds a
     * hook to close the channel after the given future completes.
     *
     * @param stream the stream to be half closed.
     * @param future If closing, the future after which to close the channel.
     */
    @Override
    public void closeStreamLocal(Http2Stream stream, ChannelFuture future) {
        switch (stream.state()) {
            case HALF_CLOSED_LOCAL:
            case OPEN:
                stream.closeLocalSide();
                break;
            default:
                closeStream(stream, future);
                break;
        }
    }

    /**
     * Closes the remote side of the given stream. If this causes the stream to be closed, adds a
     * hook to close the channel after the given future completes.
     *
     * @param stream the stream to be half closed.
     * @param future If closing, the future after which to close the channel.
     */
    @Override
    public void closeStreamRemote(Http2Stream stream, ChannelFuture future) {
        switch (stream.state()) {
            case HALF_CLOSED_REMOTE:
            case OPEN:
                stream.closeRemoteSide();
                break;
            default:
                closeStream(stream, future);
                break;
        }
    }

    @Override
    public void closeStream(final Http2Stream stream, ChannelFuture future) {
        if (future.isDone()) {
            doCloseStream(stream, future);
        } else {
            future.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) {
                    doCloseStream(stream, future);
                }
            });
        }
    }

    /**
     * Central handler for all exceptions caught during HTTP/2 processing.
     */
    @Override
    public void onError(ChannelHandlerContext ctx, boolean outbound, Throwable cause) {
        Http2Exception embedded = getEmbeddedHttp2Exception(cause);
        if (isStreamError(embedded)) {
            onStreamError(ctx, outbound, cause, (StreamException) embedded);
        } else if (embedded instanceof CompositeStreamException) {
            // A composite carries multiple per-stream errors; handle each individually.
            CompositeStreamException compositException = (CompositeStreamException) embedded;
            for (StreamException streamException : compositException) {
                onStreamError(ctx, outbound, cause, streamException);
            }
        } else {
            onConnectionError(ctx, outbound, cause, embedded);
        }
        ctx.flush();
    }

    /**
     * Called by the graceful shutdown logic to determine when it is safe to close the connection. Returns {@code true}
     * if the graceful shutdown has completed and the connection can be safely closed. This implementation just
     * guarantees that there are no active streams. Subclasses may override to provide additional checks.
     */
    protected boolean isGracefulShutdownComplete() {
        return connection().numActiveStreams() == 0;
    }

    /**
     * Handler for a connection error. Sends a GO_AWAY frame to the remote endpoint. Once all
     * streams are closed, the connection is shut down.
     *
     * @param ctx the channel context
     * @param outbound {@code true} if the error was caused by an outbound operation.
     * @param cause the exception that was caught
     * @param http2Ex the {@link Http2Exception} that is embedded in the causality chain. This may
     *                be {@code null} if it's an unknown exception.
     */
    protected void onConnectionError(ChannelHandlerContext ctx, boolean outbound,
                                     Throwable cause, Http2Exception http2Ex) {
        if (http2Ex == null) {
            http2Ex = new Http2Exception(INTERNAL_ERROR, cause.getMessage(), cause);
        }

        ChannelPromise promise = ctx.newPromise();
        ChannelFuture future = goAway(ctx, http2Ex, ctx.newPromise());
        if (http2Ex.shutdownHint() == Http2Exception.ShutdownHint.GRACEFUL_SHUTDOWN) {
            doGracefulShutdown(ctx, future, promise);
        } else {
            future.addListener(newClosingChannelFutureListener(ctx, promise));
        }
    }

    /**
     * Handler for a stream error. Sends a {@code RST_STREAM} frame to the remote endpoint and closes the
     * stream.
     *
     * @param ctx the channel context
     * @param outbound {@code true} if the error was caused by an outbound operation.
     * @param cause the exception that was caught
     * @param http2Ex the {@link StreamException} that is embedded in the causality chain.
     */
    protected void onStreamError(ChannelHandlerContext ctx, boolean outbound,
                                 @SuppressWarnings("unused") Throwable cause, StreamException http2Ex) {
        final int streamId = http2Ex.streamId();
        Http2Stream stream = connection().stream(streamId);

        //if this is caused by reading headers that are too large, send a header with status 431
        if (http2Ex instanceof Http2Exception.HeaderListSizeException &&
            ((Http2Exception.HeaderListSizeException) http2Ex).duringDecode() &&
            connection().isServer()) {

            // NOTE We have to check to make sure that a stream exists before we send our reply.
            // We likely always create the stream below as the stream isn't created until the
            // header block is completely processed.

            // The case of a streamId referring to a stream which was already closed is handled
            // by createStream and will land us in the catch block below
            if (stream == null) {
                try {
                    stream = encoder.connection().remote().createStream(streamId, true);
                } catch (Http2Exception e) {
                    resetUnknownStream(ctx, streamId, http2Ex.error().code(), ctx.newPromise());
                    return;
                }
            }

            // ensure that we have not already sent headers on this stream
            if (stream != null && !stream.isHeadersSent()) {
                try {
                    handleServerHeaderDecodeSizeError(ctx, stream);
                } catch (Throwable cause2) {
                    onError(ctx, outbound, connectionError(INTERNAL_ERROR, cause2, "Error DecodeSizeError"));
                }
            }
        }

        if (stream == null) {
            if (!outbound || connection().local().mayHaveCreatedStream(streamId)) {
                resetUnknownStream(ctx, streamId, http2Ex.error().code(), ctx.newPromise());
            }
        } else {
            resetStream(ctx, stream, http2Ex.error().code(), ctx.newPromise());
        }
    }

    /**
     * Notifies client that this server has received headers that are larger than what it is
     * willing to accept. Override to change behavior.
     *
     * @param ctx the channel context
     * @param stream the Http2Stream on which the header was received
     */
    protected void handleServerHeaderDecodeSizeError(ChannelHandlerContext ctx, Http2Stream stream) {
        encoder().writeHeaders(ctx, stream.id(), HEADERS_TOO_LARGE_HEADERS, 0, true, ctx.newPromise());
    }

    protected Http2FrameWriter frameWriter() {
        return encoder().frameWriter();
    }

    /**
     * Sends a {@code RST_STREAM} frame even if we don't know about the stream. This error condition is most likely
     * triggered by the first frame of a stream being invalid. That is, there was an error reading the frame before
     * we could create a new stream.
     */
    private ChannelFuture resetUnknownStream(final ChannelHandlerContext ctx, int streamId, long errorCode,
                                             ChannelPromise promise) {
        ChannelFuture future = frameWriter().writeRstStream(ctx, streamId, errorCode, promise);
        if (future.isDone()) {
            closeConnectionOnError(ctx, future);
        } else {
            future.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    closeConnectionOnError(ctx, future);
                }
            });
        }
        return future;
    }

    @Override
    public ChannelFuture resetStream(final ChannelHandlerContext ctx, int streamId, long errorCode,
                                     ChannelPromise promise) {
        final Http2Stream stream = connection().stream(streamId);
        if (stream == null) {
            return resetUnknownStream(ctx, streamId, errorCode, promise.unvoid());
        }

        return resetStream(ctx, stream, errorCode, promise);
    }

    private ChannelFuture resetStream(final ChannelHandlerContext ctx, final Http2Stream stream,
                                      long errorCode, ChannelPromise promise) {
        promise = promise.unvoid();
        if (stream.isResetSent()) {
            // Don't write a RST_STREAM frame if we have already written one.
            return promise.setSuccess();
        }
        // Synchronously set the resetSent flag to prevent any subsequent calls
        // from resulting in multiple reset frames being sent.
        //
        // This needs to be done before we notify the promise as the promise may have a listener attached that
        // call resetStream(...) again.
        stream.resetSent();

        final ChannelFuture future;
        // If the remote peer is not aware of the steam, then we are not allowed to send a RST_STREAM
        // https://tools.ietf.org/html/rfc7540#section-6.4.
        if (stream.state() == IDLE ||
            connection().local().created(stream) && !stream.isHeadersSent() && !stream.isPushPromiseSent()) {
            future = promise.setSuccess();
        } else {
            future = frameWriter().writeRstStream(ctx, stream.id(), errorCode, promise);
        }
        if (future.isDone()) {
            processRstStreamWriteResult(ctx, stream, future);
        } else {
            future.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    processRstStreamWriteResult(ctx, stream, future);
                }
            });
        }

        return future;
    }

    @Override
    public ChannelFuture goAway(final ChannelHandlerContext ctx, final int lastStreamId, final long errorCode,
                                final ByteBuf debugData, ChannelPromise promise) {
        promise = promise.unvoid();
        final Http2Connection connection = connection();
        try {
            if (!connection.goAwaySent(lastStreamId, errorCode, debugData)) {
                // A GO_AWAY with an equal-or-higher lastStreamId was already sent; nothing further to write.
                debugData.release();
                promise.trySuccess();
                return promise;
            }
        } catch (Throwable cause) {
            debugData.release();
            promise.tryFailure(cause);
            return promise;
        }

        // Need to retain before we write the buffer because if we do it after the refCnt could already be 0 and
        // result in an IllegalRefCountException.
        debugData.retain();
        ChannelFuture future = frameWriter().writeGoAway(ctx, lastStreamId, errorCode, debugData, promise);

        if (future.isDone()) {
            processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
        } else {
            future.addListener(new ChannelFutureListener() {
                @Override
                public void operationComplete(ChannelFuture future) throws Exception {
                    processGoAwayWriteResult(ctx, lastStreamId, errorCode, debugData, future);
                }
            });
        }

        return future;
    }

    /**
     * Closes the connection if the graceful shutdown process has completed.
     * @param future Represents the status that will be passed to the {@link #closeListener}.
     */
    private void checkCloseConnection(ChannelFuture future) {
        // If this connection is closing and the graceful shutdown has completed, close the connection
        // once this operation completes.
        if (closeListener != null && isGracefulShutdownComplete()) {
            ChannelFutureListener closeListener = this.closeListener;
            // This method could be called multiple times
            // and we don't want to notify the closeListener multiple times.
            this.closeListener = null;
            try {
                closeListener.operationComplete(future);
            } catch (Exception e) {
                throw new IllegalStateException("Close listener threw an unexpected exception", e);
            }
        }
    }

    /**
     * Close the remote endpoint with a {@code GO_AWAY} frame. Does not flush
     * immediately, this is the responsibility of the caller.
     */
    private ChannelFuture goAway(ChannelHandlerContext ctx, Http2Exception cause, ChannelPromise promise) {
        long errorCode = cause != null ? cause.error().code() : NO_ERROR.code();
        int lastKnownStream;
        if (cause != null && cause.shutdownHint() == Http2Exception.ShutdownHint.HARD_SHUTDOWN) {
            // The hard shutdown could have been triggered during header processing, before updating
            // lastStreamCreated(). Specifically, any connection errors encountered by Http2FrameReader or HPACK
            // decoding will fail to update the last known stream. So we must be pessimistic.
            // https://github.com/netty/netty/issues/10670
            lastKnownStream = Integer.MAX_VALUE;
        } else {
            lastKnownStream = connection().remote().lastStreamCreated();
        }
        return goAway(ctx, lastKnownStream, errorCode, Http2CodecUtil.toByteBuf(ctx, cause), promise);
    }

    private void processRstStreamWriteResult(ChannelHandlerContext ctx, Http2Stream stream, ChannelFuture future) {
        if (future.isSuccess()) {
            closeStream(stream, future);
        } else {
            // The connection will be closed and so no need to change the resetSent flag to false.
            onConnectionError(ctx, true, future.cause(), null);
        }
    }

    private void closeConnectionOnError(ChannelHandlerContext ctx, ChannelFuture future) {
        if (!future.isSuccess()) {
            onConnectionError(ctx, true, future.cause(), null);
        }
    }

    private void doCloseStream(final Http2Stream stream, ChannelFuture future) {
        stream.close();
        checkCloseConnection(future);
    }

    /**
     * Returns the client preface string if this is a client connection, otherwise returns {@code null}.
     */
    private static ByteBuf clientPrefaceString(Http2Connection connection) {
        return connection.isServer() ? connectionPrefaceBuf() : null;
    }

    private static void processGoAwayWriteResult(final ChannelHandlerContext ctx, final int lastStreamId,
                                                 final long errorCode, final ByteBuf debugData, ChannelFuture future) {
        try {
            if (future.isSuccess()) {
                if (errorCode != NO_ERROR.code()) {
                    if (logger.isDebugEnabled()) {
                        logger.debug("{} Sent GOAWAY: lastStreamId '{}', errorCode '{}', " +
                                     "debugData '{}'. Forcing shutdown of the connection.",
                                     ctx.channel(), lastStreamId, errorCode, debugData.toString(UTF_8), future.cause());
                    }
                    ctx.close();
                }
            } else {
                if (logger.isDebugEnabled()) {
                    logger.debug("{} Sending GOAWAY failed: lastStreamId '{}', errorCode '{}', " +
                                 "debugData '{}'. Forcing shutdown of the connection.",
                                 ctx.channel(), lastStreamId, errorCode, debugData.toString(UTF_8), future.cause());
                }
                ctx.close();
            }
        } finally {
            // We're done with the debug data now.
            debugData.release();
        }
    }

    /**
     * Closes the channel when the future completes.
     */
    private static final class ClosingChannelFutureListener implements ChannelFutureListener {
        private final ChannelHandlerContext ctx;
        private final ChannelPromise promise;
        // Scheduled fallback close task; null when no timeout was configured.
        private final Future<?> timeoutTask;
        private boolean closed;

        ClosingChannelFutureListener(ChannelHandlerContext ctx, ChannelPromise promise) {
            this.ctx = ctx;
            this.promise = promise;
            timeoutTask = null;
        }

        ClosingChannelFutureListener(final ChannelHandlerContext ctx, final ChannelPromise promise,
                                     long timeout, TimeUnit unit) {
            this.ctx = ctx;
            this.promise = promise;
            timeoutTask = ctx.executor().schedule(new Runnable() {
                @Override
                public void run() {
                    doClose();
                }
            }, timeout, unit);
        }

        @Override
        public void operationComplete(ChannelFuture sentGoAwayFuture) {
            if (timeoutTask != null) {
                timeoutTask.cancel(false);
            }
            doClose();
        }

        private void doClose() {
            // We need to guard against multiple calls as the timeout may trigger close() first and then it will be
            // triggered again because of operationComplete(...) is called.
            if (closed) {
                // This only happens if we also scheduled a timeout task.
                assert timeoutTask != null;
                return;
            }
            closed = true;
            if (promise == null) {
                ctx.close();
            } else {
                ctx.close(promise);
            }
        }
    }
}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandlerBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandlerBuilder.java
new file mode 100644
index 0000000..447a6fb
--- /dev/null
+++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionHandlerBuilder.java
@@ -0,0 +1,121 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License.
 You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.netty.handler.codec.http2;
+
+import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector;
+import io.netty.util.internal.UnstableApi;
+
+/**
+ * Builder which builds {@link Http2ConnectionHandler} objects.
+ *
+ * <p>All configuration methods simply delegate to {@link AbstractHttp2ConnectionHandlerBuilder}; the overrides
+ * exist to re-expose the inherited methods with this concrete builder type (presumably also widening their
+ * visibility from the abstract base — confirm against {@code AbstractHttp2ConnectionHandlerBuilder}).
+ */
+@UnstableApi
+public final class Http2ConnectionHandlerBuilder
+        extends AbstractHttp2ConnectionHandlerBuilder<Http2ConnectionHandler, Http2ConnectionHandlerBuilder> {
+
+    @Override
+    public Http2ConnectionHandlerBuilder validateHeaders(boolean validateHeaders) {
+        return super.validateHeaders(validateHeaders);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder initialSettings(Http2Settings settings) {
+        return super.initialSettings(settings);
+    }
+
+    @Override
+    public Http2Settings initialSettings() {
+        return super.initialSettings();
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder frameListener(Http2FrameListener frameListener) {
+        return super.frameListener(frameListener);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder gracefulShutdownTimeoutMillis(long gracefulShutdownTimeoutMillis) {
+        return super.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder server(boolean isServer) {
+        return super.server(isServer);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder connection(Http2Connection connection) {
+        return super.connection(connection);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder maxReservedStreams(int maxReservedStreams) {
+        return super.maxReservedStreams(maxReservedStreams);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder codec(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder) {
+        return super.codec(decoder, encoder);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder frameLogger(Http2FrameLogger frameLogger) {
+        return super.frameLogger(frameLogger);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder encoderEnforceMaxConcurrentStreams(
+            boolean encoderEnforceMaxConcurrentStreams) {
+        return super.encoderEnforceMaxConcurrentStreams(encoderEnforceMaxConcurrentStreams);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder encoderIgnoreMaxHeaderListSize(boolean encoderIgnoreMaxHeaderListSize) {
+        return super.encoderIgnoreMaxHeaderListSize(encoderIgnoreMaxHeaderListSize);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder headerSensitivityDetector(SensitivityDetector headerSensitivityDetector) {
+        return super.headerSensitivityDetector(headerSensitivityDetector);
+    }
+
+    @Override
+    @Deprecated
+    public Http2ConnectionHandlerBuilder initialHuffmanDecodeCapacity(int initialHuffmanDecodeCapacity) {
+        return super.initialHuffmanDecodeCapacity(initialHuffmanDecodeCapacity);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder decoupleCloseAndGoAway(boolean decoupleCloseAndGoAway) {
+        return super.decoupleCloseAndGoAway(decoupleCloseAndGoAway);
+    }
+
+    @Override
+    public Http2ConnectionHandlerBuilder flushPreface(boolean flushPreface) {
+        return super.flushPreface(flushPreface);
+    }
+
+    @Override
+    public Http2ConnectionHandler build() {
+        return super.build();
+    }
+
+    // Concrete factory hook invoked by the abstract builder's build() with the configured codec and settings.
+    @Override
+    protected Http2ConnectionHandler build(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder,
+                                           Http2Settings initialSettings) {
+        return new Http2ConnectionHandler(decoder, encoder, initialSettings, decoupleCloseAndGoAway(), flushPreface());
+    }
+}
diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.java
b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.java new file mode 100644 index 0000000..6e975a0 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.java @@ -0,0 +1,31 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Signifies that the connection preface and + * the initial SETTINGS frame have been sent. The client sends the preface, and the server receives the preface. + * The client shouldn't write any data until this event has been processed. 
+ */ +@UnstableApi +public final class Http2ConnectionPrefaceAndSettingsFrameWrittenEvent { + static final Http2ConnectionPrefaceAndSettingsFrameWrittenEvent INSTANCE = + new Http2ConnectionPrefaceAndSettingsFrameWrittenEvent(); + + private Http2ConnectionPrefaceAndSettingsFrameWrittenEvent() { + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoder.java new file mode 100644 index 0000000..d5e7d66 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoder.java @@ -0,0 +1,113 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * {@link DecoratingHttp2ConnectionEncoder} which guards against a remote peer that will trigger a massive amount + * of control frames but will not consume our responses to these. 
+ * This encoder will tear-down the connection once we reached the configured limit to reduce the risk of DDOS. + */ +final class Http2ControlFrameLimitEncoder extends DecoratingHttp2ConnectionEncoder { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2ControlFrameLimitEncoder.class); + + private final int maxOutstandingControlFrames; + private final ChannelFutureListener outstandingControlFramesListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + outstandingControlFrames--; + } + }; + private Http2LifecycleManager lifecycleManager; + private int outstandingControlFrames; + private boolean limitReached; + + Http2ControlFrameLimitEncoder(Http2ConnectionEncoder delegate, int maxOutstandingControlFrames) { + super(delegate); + this.maxOutstandingControlFrames = ObjectUtil.checkPositive(maxOutstandingControlFrames, + "maxOutstandingControlFrames"); + } + + @Override + public void lifecycleManager(Http2LifecycleManager lifecycleManager) { + this.lifecycleManager = lifecycleManager; + super.lifecycleManager(lifecycleManager); + } + + @Override + public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) { + ChannelPromise newPromise = handleOutstandingControlFrames(ctx, promise); + if (newPromise == null) { + return promise; + } + return super.writeSettingsAck(ctx, newPromise); + } + + @Override + public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, long data, ChannelPromise promise) { + // Only apply the limit to ping acks. 
+ if (ack) { + ChannelPromise newPromise = handleOutstandingControlFrames(ctx, promise); + if (newPromise == null) { + return promise; + } + return super.writePing(ctx, ack, data, newPromise); + } + return super.writePing(ctx, ack, data, promise); + } + + @Override + public ChannelFuture writeRstStream( + ChannelHandlerContext ctx, int streamId, long errorCode, ChannelPromise promise) { + ChannelPromise newPromise = handleOutstandingControlFrames(ctx, promise); + if (newPromise == null) { + return promise; + } + return super.writeRstStream(ctx, streamId, errorCode, newPromise); + } + + private ChannelPromise handleOutstandingControlFrames(ChannelHandlerContext ctx, ChannelPromise promise) { + if (!limitReached) { + if (outstandingControlFrames == maxOutstandingControlFrames) { + // Let's try to flush once as we may be able to flush some of the control frames. + ctx.flush(); + } + if (outstandingControlFrames == maxOutstandingControlFrames) { + limitReached = true; + Http2Exception exception = Http2Exception.connectionError(Http2Error.ENHANCE_YOUR_CALM, + "Maximum number %d of outstanding control frames reached", maxOutstandingControlFrames); + logger.info("Maximum number {} of outstanding control frames reached. Closing channel {}", + maxOutstandingControlFrames, ctx.channel(), exception); + + // First notify the Http2LifecycleManager and then close the connection. 
+ lifecycleManager.onError(ctx, true, exception); + ctx.close(); + } + outstandingControlFrames++; + + // We did not reach the limit yet, add the listener to decrement the number of outstanding control frames + // once the promise was completed + return promise.unvoid().addListener(outstandingControlFramesListener); + } + return promise; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataChunkedInput.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataChunkedInput.java new file mode 100644 index 0000000..ac382ff --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataChunkedInput.java @@ -0,0 +1,116 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.stream.ChunkedInput; +import io.netty.util.internal.ObjectUtil; + +/** + * A {@link ChunkedInput} that fetches data chunk by chunk for use with HTTP/2 Data Frames. + *

+ * Each chunk from the input data will be wrapped within a {@link Http2DataFrame}. At the end of the input data, + * {@link Http2DataFrame#isEndStream()} will be set to true and will be written. + *

+ *

+ *

+ *
+ *     public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
+ *         if (msg instanceof Http2HeadersFrame) {
+ *             Http2HeadersFrame http2HeadersFrame = (Http2HeadersFrame) msg;
+ *
+ *             Http2HeadersFrame response = new DefaultHttp2HeadersFrame(new DefaultHttp2Headers().status("200"));
+ *             response.stream(http2HeadersFrame.stream());
+ *             ctx.write(response);
+ *
+ *             ChannelFuture sendFileFuture = ctx.writeAndFlush(new Http2DataChunkedInput(
+ *                     new ChunkedFile(new File(("/home/meow/cats.mp4"))), http2HeadersFrame.stream()));
+ *         }
+ *     }
+ * 
+ */ +public final class Http2DataChunkedInput implements ChunkedInput { + + private final ChunkedInput input; + private final Http2FrameStream stream; + private boolean endStreamSent; + + /** + * Creates a new instance using the specified input. + * + * @param input {@link ChunkedInput} containing data to write + * @param stream {@link Http2FrameStream} holding stream info + */ + public Http2DataChunkedInput(ChunkedInput input, Http2FrameStream stream) { + this.input = ObjectUtil.checkNotNull(input, "input"); + this.stream = ObjectUtil.checkNotNull(stream, "stream"); + } + + @Override + public boolean isEndOfInput() throws Exception { + if (input.isEndOfInput()) { + // Only end of input after last HTTP chunk has been sent + return endStreamSent; + } + return false; + } + + @Override + public void close() throws Exception { + input.close(); + } + + @Deprecated + @Override + public Http2DataFrame readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public Http2DataFrame readChunk(ByteBufAllocator allocator) throws Exception { + if (endStreamSent) { + return null; + } + + if (input.isEndOfInput()) { + endStreamSent = true; + return new DefaultHttp2DataFrame(true).stream(stream); + } + + ByteBuf buf = input.readChunk(allocator); + if (buf == null) { + return null; + } + + final Http2DataFrame dataFrame = new DefaultHttp2DataFrame(buf, input.isEndOfInput()).stream(stream); + if (dataFrame.isEndStream()) { + endStreamSent = true; + } + + return dataFrame; + } + + @Override + public long length() { + return input.length(); + } + + @Override + public long progress() { + return input.progress(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java new file mode 100644 index 0000000..15097b0 --- /dev/null +++ 
b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataFrame.java @@ -0,0 +1,73 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 DATA frame. + */ +@UnstableApi +public interface Http2DataFrame extends Http2StreamFrame, ByteBufHolder { + + /** + * Frame padding to use. Will be non-negative and less than 256. + */ + int padding(); + + /** + * Payload of DATA frame. Will not be {@code null}. + */ + @Override + ByteBuf content(); + + /** + * Returns the number of bytes that are flow-controlled initially, so even if the {@link #content()} is consumed + * this will not change. + */ + int initialFlowControlledBytes(); + + /** + * Returns {@code true} if the END_STREAM flag is set. 
+ */ + boolean isEndStream(); + + @Override + Http2DataFrame copy(); + + @Override + Http2DataFrame duplicate(); + + @Override + Http2DataFrame retainedDuplicate(); + + @Override + Http2DataFrame replace(ByteBuf content); + + @Override + Http2DataFrame retain(); + + @Override + Http2DataFrame retain(int increment); + + @Override + Http2DataFrame touch(); + + @Override + Http2DataFrame touch(Object hint); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java new file mode 100644 index 0000000..ad46b85 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2DataWriter.java @@ -0,0 +1,45 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.UnstableApi; + +/** + * Interface that defines an object capable of producing HTTP/2 data frames. + */ +@UnstableApi +public interface Http2DataWriter { + /** + * Writes a {@code DATA} frame to the remote endpoint. This will result in one or more + * frames being written to the context. + * + * @param ctx the context to use for writing. 
+ * @param streamId the stream for which to send the frame. + * @param data the payload of the frame. This will be released by this method. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). A 1 byte padding is encoded as just the pad length field with value 0. + * A 256 byte padding is encoded as the pad length field with value 255 and 255 padding bytes + * appended to the end of the frame. + * @param endStream indicates if this is the last frame to be sent for the stream. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, + ByteBuf data, int padding, boolean endStream, ChannelPromise promise); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoder.java new file mode 100644 index 0000000..69f2f2f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoder.java @@ -0,0 +1,56 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.ObjectUtil; + +/** + * Enforce a limit on the maximum number of consecutive empty DATA frames (without end_of_stream flag) that are allowed + * before the connection will be closed. + */ +final class Http2EmptyDataFrameConnectionDecoder extends DecoratingHttp2ConnectionDecoder { + + private final int maxConsecutiveEmptyFrames; + + Http2EmptyDataFrameConnectionDecoder(Http2ConnectionDecoder delegate, int maxConsecutiveEmptyFrames) { + super(delegate); + this.maxConsecutiveEmptyFrames = ObjectUtil.checkPositive( + maxConsecutiveEmptyFrames, "maxConsecutiveEmptyFrames"); + } + + @Override + public void frameListener(Http2FrameListener listener) { + if (listener != null) { + super.frameListener(new Http2EmptyDataFrameListener(listener, maxConsecutiveEmptyFrames)); + } else { + super.frameListener(null); + } + } + + @Override + public Http2FrameListener frameListener() { + Http2FrameListener frameListener = frameListener0(); + // Unwrap the original Http2FrameListener as we add this decoder under the hood. + if (frameListener instanceof Http2EmptyDataFrameListener) { + return ((Http2EmptyDataFrameListener) frameListener).listener; + } + return frameListener; + } + + // Package-private for testing + Http2FrameListener frameListener0() { + return super.frameListener(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListener.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListener.java new file mode 100644 index 0000000..7f194bf --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListener.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.ObjectUtil; + +/** + * Enforce a limit on the maximum number of consecutive empty DATA frames (without end_of_stream flag) that are allowed + * before the connection will be closed. + */ +final class Http2EmptyDataFrameListener extends Http2FrameListenerDecorator { + private final int maxConsecutiveEmptyFrames; + + private boolean violationDetected; + private int emptyDataFrames; + + Http2EmptyDataFrameListener(Http2FrameListener listener, int maxConsecutiveEmptyFrames) { + super(listener); + this.maxConsecutiveEmptyFrames = ObjectUtil.checkPositive( + maxConsecutiveEmptyFrames, "maxConsecutiveEmptyFrames"); + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + if (endOfStream || data.isReadable()) { + emptyDataFrames = 0; + } else if (emptyDataFrames++ == maxConsecutiveEmptyFrames && !violationDetected) { + violationDetected = true; + throw Http2Exception.connectionError(Http2Error.ENHANCE_YOUR_CALM, + "Maximum number %d of empty data frames without end_of_stream flag received", + maxConsecutiveEmptyFrames); + } + + return super.onDataRead(ctx, streamId, data, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream) throws Http2Exception { + emptyDataFrames = 0; + 
super.onHeadersRead(ctx, streamId, headers, padding, endStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception { + emptyDataFrames = 0; + super.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Error.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Error.java new file mode 100644 index 0000000..96187b7 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Error.java @@ -0,0 +1,65 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * All error codes identified by the HTTP/2 spec. 
+ */ +@UnstableApi +public enum Http2Error { + NO_ERROR(0x0), + PROTOCOL_ERROR(0x1), + INTERNAL_ERROR(0x2), + FLOW_CONTROL_ERROR(0x3), + SETTINGS_TIMEOUT(0x4), + STREAM_CLOSED(0x5), + FRAME_SIZE_ERROR(0x6), + REFUSED_STREAM(0x7), + CANCEL(0x8), + COMPRESSION_ERROR(0x9), + CONNECT_ERROR(0xA), + ENHANCE_YOUR_CALM(0xB), + INADEQUATE_SECURITY(0xC), + HTTP_1_1_REQUIRED(0xD); + + private final long code; + private static final Http2Error[] INT_TO_ENUM_MAP; + static { + Http2Error[] errors = values(); + Http2Error[] map = new Http2Error[errors.length]; + for (Http2Error error : errors) { + map[(int) error.code()] = error; + } + INT_TO_ENUM_MAP = map; + } + + Http2Error(long code) { + this.code = code; + } + + /** + * Gets the code for this error used on the wire. + */ + public long code() { + return code; + } + + public static Http2Error valueOf(long value) { + return value >= INT_TO_ENUM_MAP.length || value < 0 ? null : INT_TO_ENUM_MAP[(int) value]; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EventAdapter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EventAdapter.java new file mode 100644 index 0000000..328b823 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2EventAdapter.java @@ -0,0 +1,115 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * This class brings {@link Http2Connection.Listener} and {@link Http2FrameListener} together to provide + * NOOP implementation so inheriting classes can selectively choose which methods to override. + */ +@UnstableApi +public class Http2EventAdapter implements Http2Connection.Listener, Http2FrameListener { + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + return data.readableBytes() + padding; + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endStream) throws Http2Exception { + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception { + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + } + + @Override + 
public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + } + + @Override + public void onStreamAdded(Http2Stream stream) { + } + + @Override + public void onStreamActive(Http2Stream stream) { + } + + @Override + public void onStreamHalfClosed(Http2Stream stream) { + } + + @Override + public void onStreamClosed(Http2Stream stream) { + } + + @Override + public void onStreamRemoved(Http2Stream stream) { + } + + @Override + public void onGoAwaySent(int lastStreamId, long errorCode, ByteBuf debugData) { + } + + @Override + public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java new file mode 100644 index 0000000..9b3e61f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Exception.java @@ -0,0 +1,344 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
package io.netty.handler.codec.http2;

import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SuppressJava6Requirement;
import io.netty.util.internal.ThrowableUtil;
import io.netty.util.internal.UnstableApi;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;

import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * Exception thrown when an HTTP/2 error was encountered.
 */
@UnstableApi
public class Http2Exception extends Exception {
    private static final long serialVersionUID = -6941186345430164209L;

    private final Http2Error error;
    private final ShutdownHint shutdownHint;

    public Http2Exception(Http2Error error) {
        this(error, ShutdownHint.HARD_SHUTDOWN);
    }

    public Http2Exception(Http2Error error, ShutdownHint shutdownHint) {
        this.error = checkNotNull(error, "error");
        this.shutdownHint = checkNotNull(shutdownHint, "shutdownHint");
    }

    public Http2Exception(Http2Error error, String message) {
        this(error, message, ShutdownHint.HARD_SHUTDOWN);
    }

    public Http2Exception(Http2Error error, String message, ShutdownHint shutdownHint) {
        super(message);
        this.error = checkNotNull(error, "error");
        this.shutdownHint = checkNotNull(shutdownHint, "shutdownHint");
    }

    public Http2Exception(Http2Error error, String message, Throwable cause) {
        this(error, message, cause, ShutdownHint.HARD_SHUTDOWN);
    }

    public Http2Exception(Http2Error error, String message, Throwable cause, ShutdownHint shutdownHint) {
        super(message, cause);
        this.error = checkNotNull(error, "error");
        this.shutdownHint = checkNotNull(shutdownHint, "shutdownHint");
    }

    /**
     * Creates a shared, stackless exception suitable for caching in a {@code static final} field.
     * On Java 7+ writable-stacktrace suppression is used; otherwise the stack trace is overwritten
     * with a synthetic single-frame trace pointing at {@code clazz#method}.
     */
    static Http2Exception newStatic(Http2Error error, String message, ShutdownHint shutdownHint,
                                    Class<?> clazz, String method) {
        final Http2Exception exception;
        if (PlatformDependent.javaVersion() >= 7) {
            exception = new StacklessHttp2Exception(error, message, shutdownHint, true);
        } else {
            exception = new StacklessHttp2Exception(error, message, shutdownHint);
        }
        return ThrowableUtil.unknownStackTrace(exception, clazz, method);
    }

    @SuppressJava6Requirement(reason = "uses Java 7+ Exception.(String, Throwable, boolean, boolean)" +
            " but is guarded by version checks")
    private Http2Exception(Http2Error error, String message, ShutdownHint shutdownHint, boolean shared) {
        // Disable suppression and stack-trace writing for shared (static) instances.
        super(message, null, false, true);
        assert shared;
        this.error = checkNotNull(error, "error");
        this.shutdownHint = checkNotNull(shutdownHint, "shutdownHint");
    }

    public Http2Error error() {
        return error;
    }

    /**
     * Provide a hint as to what type of shutdown should be executed. Note this hint may be ignored.
     */
    public ShutdownHint shutdownHint() {
        return shutdownHint;
    }

    /**
     * Use if an error has occurred which can not be isolated to a single stream, but instead applies
     * to the entire connection.
     * @param error The type of error as defined by the HTTP/2 specification.
     * @param fmt String with the content and format for the additional debug data.
     * @param args Objects which fit into the format defined by {@code fmt}.
     * @return An exception which can be translated into an HTTP/2 error.
     */
    public static Http2Exception connectionError(Http2Error error, String fmt, Object... args) {
        return new Http2Exception(error, formatErrorMessage(fmt, args));
    }

    /**
     * Use if an error has occurred which can not be isolated to a single stream, but instead applies
     * to the entire connection.
     * @param error The type of error as defined by the HTTP/2 specification.
     * @param cause The object which caused the error.
     * @param fmt String with the content and format for the additional debug data.
     * @param args Objects which fit into the format defined by {@code fmt}.
     * @return An exception which can be translated into an HTTP/2 error.
     */
    public static Http2Exception connectionError(Http2Error error, Throwable cause,
                                                 String fmt, Object... args) {
        return new Http2Exception(error, formatErrorMessage(fmt, args), cause);
    }

    /**
     * Use if an error has occurred which can not be isolated to a single stream, but instead applies
     * to the entire connection.
     * @param error The type of error as defined by the HTTP/2 specification.
     * @param fmt String with the content and format for the additional debug data.
     * @param args Objects which fit into the format defined by {@code fmt}.
     * @return An exception which can be translated into an HTTP/2 error.
     */
    public static Http2Exception closedStreamError(Http2Error error, String fmt, Object... args) {
        return new ClosedStreamCreationException(error, formatErrorMessage(fmt, args));
    }

    /**
     * Use if an error which can be isolated to a single stream has occurred. If the {@code id} is not
     * {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a {@link StreamException} will be returned.
     * Otherwise the error is considered a connection error and a {@link Http2Exception} is returned.
     * @param id The stream id for which the error is isolated to.
     * @param error The type of error as defined by the HTTP/2 specification.
     * @param fmt String with the content and format for the additional debug data.
     * @param args Objects which fit into the format defined by {@code fmt}.
     * @return If the {@code id} is not
     * {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a {@link StreamException} will be returned.
     * Otherwise the error is considered a connection error and a {@link Http2Exception} is returned.
     */
    public static Http2Exception streamError(int id, Http2Error error, String fmt, Object... args) {
        return CONNECTION_STREAM_ID == id ?
                connectionError(error, fmt, args) :
                new StreamException(id, error, formatErrorMessage(fmt, args));
    }

    /**
     * Use if an error which can be isolated to a single stream has occurred. If the {@code id} is not
     * {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a {@link StreamException} will be returned.
     * Otherwise the error is considered a connection error and a {@link Http2Exception} is returned.
     * @param id The stream id for which the error is isolated to.
     * @param error The type of error as defined by the HTTP/2 specification.
     * @param cause The object which caused the error.
     * @param fmt String with the content and format for the additional debug data.
     * @param args Objects which fit into the format defined by {@code fmt}.
     * @return If the {@code id} is not
     * {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a {@link StreamException} will be returned.
     * Otherwise the error is considered a connection error and a {@link Http2Exception} is returned.
     */
    public static Http2Exception streamError(int id, Http2Error error, Throwable cause,
                                             String fmt, Object... args) {
        return CONNECTION_STREAM_ID == id ?
                connectionError(error, cause, fmt, args) :
                new StreamException(id, error, formatErrorMessage(fmt, args), cause);
    }

    /**
     * A specific stream error resulting from failing to decode headers that exceeds the max header size list.
     * If the {@code id} is not {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a
     * {@link StreamException} will be returned. Otherwise the error is considered a
     * connection error and a {@link Http2Exception} is returned.
     * @param id The stream id for which the error is isolated to.
     * @param error The type of error as defined by the HTTP/2 specification.
     * @param onDecode Whether this error was caught while decoding headers
     * @param fmt String with the content and format for the additional debug data.
     * @param args Objects which fit into the format defined by {@code fmt}.
     * @return If the {@code id} is not
     * {@link Http2CodecUtil#CONNECTION_STREAM_ID} then a {@link HeaderListSizeException}
     * will be returned. Otherwise the error is considered a connection error and a {@link Http2Exception} is
     * returned.
     */
    public static Http2Exception headerListSizeError(int id, Http2Error error, boolean onDecode,
                                                     String fmt, Object... args) {
        return CONNECTION_STREAM_ID == id ?
                connectionError(error, fmt, args) :
                new HeaderListSizeException(id, error, formatErrorMessage(fmt, args), onDecode);
    }

    /**
     * Formats the debug message, falling back to a generic message when {@code fmt} is absent
     * so callers never produce a {@code null} exception message.
     */
    private static String formatErrorMessage(String fmt, Object[] args) {
        if (fmt == null) {
            if (args == null || args.length == 0) {
                return "Unexpected error";
            }
            return "Unexpected error: " + Arrays.toString(args);
        }
        return String.format(fmt, args);
    }

    /**
     * Check if an exception is isolated to a single stream or the entire connection.
     * @param e The exception to check.
     * @return {@code true} if {@code e} is an instance of {@link StreamException}.
     * {@code false} otherwise.
     */
    public static boolean isStreamError(Http2Exception e) {
        return e instanceof StreamException;
    }

    /**
     * Get the stream id associated with an exception.
     * @param e The exception to get the stream id for.
     * @return {@link Http2CodecUtil#CONNECTION_STREAM_ID} if {@code e} is a connection error.
     * Otherwise the stream id associated with the stream error.
     */
    public static int streamId(Http2Exception e) {
        return isStreamError(e) ? ((StreamException) e).streamId() : CONNECTION_STREAM_ID;
    }

    /**
     * Provides a hint as to if shutdown is justified, what type of shutdown should be executed.
     */
    public enum ShutdownHint {
        /**
         * Do not shutdown the underlying channel.
         */
        NO_SHUTDOWN,
        /**
         * Attempt to execute a "graceful" shutdown. The definition of "graceful" is left to the implementation.
         * An example of "graceful" would be wait for some amount of time until all active streams are closed.
         */
        GRACEFUL_SHUTDOWN,
        /**
         * Close the channel immediately after a {@code GOAWAY} is sent.
         */
        HARD_SHUTDOWN
    }

    /**
     * Used when a stream creation attempt fails but may be because the stream was previously closed.
     */
    public static final class ClosedStreamCreationException extends Http2Exception {
        private static final long serialVersionUID = -6746542974372246206L;

        public ClosedStreamCreationException(Http2Error error) {
            super(error);
        }

        public ClosedStreamCreationException(Http2Error error, String message) {
            super(error, message);
        }

        public ClosedStreamCreationException(Http2Error error, String message, Throwable cause) {
            super(error, message, cause);
        }
    }

    /**
     * Represents an exception that can be isolated to a single stream (as opposed to the entire connection).
     */
    public static class StreamException extends Http2Exception {
        private static final long serialVersionUID = 602472544416984384L;
        private final int streamId;

        StreamException(int streamId, Http2Error error, String message) {
            super(error, message, ShutdownHint.NO_SHUTDOWN);
            this.streamId = streamId;
        }

        StreamException(int streamId, Http2Error error, String message, Throwable cause) {
            super(error, message, cause, ShutdownHint.NO_SHUTDOWN);
            this.streamId = streamId;
        }

        public int streamId() {
            return streamId;
        }
    }

    /**
     * A {@link StreamException} raised when a header list exceeds the advertised maximum size.
     */
    public static final class HeaderListSizeException extends StreamException {
        private static final long serialVersionUID = -8807603212183882637L;

        private final boolean decode;

        HeaderListSizeException(int streamId, Http2Error error, String message, boolean decode) {
            super(streamId, error, message);
            this.decode = decode;
        }

        public boolean duringDecode() {
            return decode;
        }
    }

    /**
     * Provides the ability to handle multiple stream exceptions with one throw statement.
     */
    public static final class CompositeStreamException extends Http2Exception
            implements Iterable<Http2Exception.StreamException> {
        private static final long serialVersionUID = 7091134858213711015L;
        private final List<StreamException> exceptions;

        public CompositeStreamException(Http2Error error, int initialCapacity) {
            super(error, ShutdownHint.NO_SHUTDOWN);
            exceptions = new ArrayList<StreamException>(initialCapacity);
        }

        public void add(StreamException e) {
            exceptions.add(e);
        }

        @Override
        public Iterator<StreamException> iterator() {
            return exceptions.iterator();
        }
    }

    private static final class StacklessHttp2Exception extends Http2Exception {

        private static final long serialVersionUID = 1077888485687219443L;

        StacklessHttp2Exception(Http2Error error, String message, ShutdownHint shutdownHint) {
            super(error, message, shutdownHint);
        }

        StacklessHttp2Exception(Http2Error error, String message, ShutdownHint shutdownHint, boolean shared) {
            super(error, message, shutdownHint, shared);
        }

        // Override fillInStackTrace() so we not populate the backtrace via a native call and so leak the
        // Classloader.
        @Override
        public Throwable fillInStackTrace() {
            return this;
        }
    }
}
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Provides utility methods for accessing specific flags as defined by the HTTP/2 spec. + */ +@UnstableApi +public final class Http2Flags { + public static final short END_STREAM = 0x1; + public static final short END_HEADERS = 0x4; + public static final short ACK = 0x1; + public static final short PADDED = 0x8; + public static final short PRIORITY = 0x20; + + private short value; + + public Http2Flags() { + } + + public Http2Flags(short value) { + this.value = value; + } + + /** + * Gets the underlying flags value. + */ + public short value() { + return value; + } + + /** + * Determines whether the {@link #END_STREAM} flag is set. Only applies to DATA and HEADERS + * frames. + */ + public boolean endOfStream() { + return isFlagSet(END_STREAM); + } + + /** + * Determines whether the {@link #END_HEADERS} flag is set. Only applies for HEADERS, + * PUSH_PROMISE, and CONTINUATION frames. + */ + public boolean endOfHeaders() { + return isFlagSet(END_HEADERS); + } + + /** + * Determines whether the flag is set indicating the presence of the exclusive, stream + * dependency, and weight fields in a HEADERS frame. + */ + public boolean priorityPresent() { + return isFlagSet(PRIORITY); + } + + /** + * Determines whether the flag is set indicating that this frame is an ACK. Only applies for + * SETTINGS and PING frames. 
+ */ + public boolean ack() { + return isFlagSet(ACK); + } + + /** + * For frames that include padding, indicates if the {@link #PADDED} field is present. Only + * applies to DATA, HEADERS, PUSH_PROMISE and CONTINUATION frames. + */ + public boolean paddingPresent() { + return isFlagSet(PADDED); + } + + /** + * Gets the number of bytes expected for the priority fields of the payload. This is determined + * by the {@link #priorityPresent()} flag. + */ + public int getNumPriorityBytes() { + return priorityPresent() ? 5 : 0; + } + + /** + * Gets the length in bytes of the padding presence field expected in the payload. This is + * determined by the {@link #paddingPresent()} flag. + */ + public int getPaddingPresenceFieldLength() { + return paddingPresent() ? 1 : 0; + } + + /** + * Sets the {@link #END_STREAM} flag. + */ + public Http2Flags endOfStream(boolean endOfStream) { + return setFlag(endOfStream, END_STREAM); + } + + /** + * Sets the {@link #END_HEADERS} flag. + */ + public Http2Flags endOfHeaders(boolean endOfHeaders) { + return setFlag(endOfHeaders, END_HEADERS); + } + + /** + * Sets the {@link #PRIORITY} flag. + */ + public Http2Flags priorityPresent(boolean priorityPresent) { + return setFlag(priorityPresent, PRIORITY); + } + + /** + * Sets the {@link #PADDED} flag. + */ + public Http2Flags paddingPresent(boolean paddingPresent) { + return setFlag(paddingPresent, PADDED); + } + + /** + * Sets the {@link #ACK} flag. + */ + public Http2Flags ack(boolean ack) { + return setFlag(ack, ACK); + } + + /** + * Generic method to set any flag. + * @param on if the flag should be enabled or disabled. + * @param mask the mask that identifies the bit for the flag. + * @return this instance. + */ + public Http2Flags setFlag(boolean on, short mask) { + if (on) { + value |= mask; + } else { + value &= ~mask; + } + return this; + } + + /** + * Indicates whether or not a particular flag is set. 
+ * @param mask the mask identifying the bit for the particular flag being tested + * @return {@code true} if the flag is set + */ + public boolean isFlagSet(short mask) { + return (value & mask) != 0; + } + + @Override + public int hashCode() { + final int prime = 31; + int result = 1; + result = prime * result + value; + return result; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (obj == null) { + return false; + } + if (getClass() != obj.getClass()) { + return false; + } + + return value == ((Http2Flags) obj).value; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(); + builder.append("value = ").append(value).append(" ("); + if (ack()) { + builder.append("ACK,"); + } + if (endOfHeaders()) { + builder.append("END_OF_HEADERS,"); + } + if (endOfStream()) { + builder.append("END_OF_STREAM,"); + } + if (priorityPresent()) { + builder.append("PRIORITY_PRESENT,"); + } + if (paddingPresent()) { + builder.append("PADDING_PRESENT,"); + } + builder.append(')'); + return builder.toString(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FlowController.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FlowController.java new file mode 100644 index 0000000..2846b3d --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FlowController.java @@ -0,0 +1,81 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
package io.netty.handler.codec.http2;

import io.netty.channel.ChannelHandlerContext;
import io.netty.util.internal.UnstableApi;

/**
 * Base interface for all HTTP/2 flow controllers.
 */
@UnstableApi
public interface Http2FlowController {

    /**
     * Sets the {@link ChannelHandlerContext} for which to apply flow control on.
     * <p>
     * This must be called to properly initialize the {@link Http2FlowController}.
     * Not calling this is considered a programming error.
     *
     * @param ctx The {@link ChannelHandlerContext} for which to apply flow control on.
     * @throws Http2Exception if any protocol-related error occurred.
     */
    void channelHandlerContext(ChannelHandlerContext ctx) throws Http2Exception;

    /**
     * Sets the connection-wide initial flow control window and updates all stream windows (but not the connection
     * stream window) by the delta.
     * <p>
     * Represents the value for {@code SETTINGS_INITIAL_WINDOW_SIZE}. This method should
     * only be called by Netty (not users) as a result of a receiving a {@code SETTINGS} frame.
     *
     * @param newWindowSize the new initial window size.
     * @throws Http2Exception thrown if any protocol-related error occurred.
     */
    void initialWindowSize(int newWindowSize) throws Http2Exception;

    /**
     * Gets the connection-wide initial flow control window size that is used as the basis for new stream flow
     * control windows.
     * <p>
     * Represents the value for {@code SETTINGS_INITIAL_WINDOW_SIZE}. The initial value
     * returned by this method must be {@link Http2CodecUtil#DEFAULT_WINDOW_SIZE}.
     */
    int initialWindowSize();

    /**
     * Get the portion of the flow control window for the given stream that is currently available for
     * sending/receiving frames which are subject to flow control. This quantity is measured in number of bytes.
     */
    int windowSize(Http2Stream stream);

    /**
     * Increments the size of the stream's flow control window by the given delta.
     * <p>
     * In the case of a {@link Http2RemoteFlowController} this is called upon receipt of a
     * {@code WINDOW_UPDATE} frame from the remote endpoint to mirror the changes to the window
     * size.
     * <p>
     * For a {@link Http2LocalFlowController} this can be called to request the expansion of the
     * window size published by this endpoint. It is up to the implementation, however, as to when a
     * {@code WINDOW_UPDATE} is actually sent.
     *
     * @param stream The subject stream. Use {@link Http2Connection#connectionStream()} for
     *               requesting the size of the connection window.
     * @param delta the change in size of the flow control window.
     * @throws Http2Exception thrown if a protocol-related error occurred.
     */
    void incrementWindowSize(Http2Stream stream, int delta) throws Http2Exception;
}

// ---- io/netty/handler/codec/http2/Http2Frame.java ----

package io.netty.handler.codec.http2;

import io.netty.util.internal.UnstableApi;

/** An HTTP/2 frame. */
@UnstableApi
public interface Http2Frame {

    /**
     * Returns the name of the HTTP/2 frame e.g. DATA, GOAWAY, etc.
     */
    String name();
}
+ */ + String name(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameAdapter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameAdapter.java new file mode 100644 index 0000000..e491e6c --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameAdapter.java @@ -0,0 +1,90 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * Convenience class that provides no-op implementations for all methods of {@link Http2FrameListener}. 
+ */ +@UnstableApi +public class Http2FrameAdapter implements Http2FrameListener { + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, + boolean endOfStream) throws Http2Exception { + return data.readableBytes() + padding; + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream) throws Http2Exception { + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endStream) + throws Http2Exception { + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, + short weight, boolean exclusive) throws Http2Exception { + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) + throws Http2Exception { + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) + throws Http2Exception { + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, + ByteBuf debugData) throws Http2Exception { + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + 
ByteBuf payload) { + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java new file mode 100644 index 0000000..017a3c5 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodec.java @@ -0,0 +1,769 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.UnsupportedMessageTypeException; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeEvent; +import io.netty.handler.codec.http2.Http2Connection.PropertyKey; +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2ChannelClosedException; +import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import static io.netty.buffer.ByteBufUtil.writeAscii; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid; +import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; +import static io.netty.util.internal.logging.InternalLogLevel.DEBUG; + +/** + *

This API is very immature. The Http2Connection-based API is currently preferred over this API. + * This API is targeted to eventually replace or reduce the need for the {@link Http2ConnectionHandler} API. + * + *

An HTTP/2 handler that maps HTTP/2 frames to {@link Http2Frame} objects and vice versa. For every incoming HTTP/2 + * frame, an {@link Http2Frame} object is created and propagated via {@link #channelRead}. Outbound {@link Http2Frame} + * objects received via {@link #write} are converted to the HTTP/2 wire format. HTTP/2 frames specific to a stream + * implement the {@link Http2StreamFrame} interface. The {@link Http2FrameCodec} is instantiated using the + * {@link Http2FrameCodecBuilder}. It's recommended for channel handlers to inherit from the + * {@link Http2ChannelDuplexHandler}, as it provides additional functionality like iterating over all active streams or + * creating outbound streams. + * + *

Stream Lifecycle

+ *

+ * The frame codec delivers and writes frames for active streams. An active stream is closed when either side sends a + * {@code RST_STREAM} frame or both sides send a frame with the {@code END_STREAM} flag set. Each + * {@link Http2StreamFrame} has a {@link Http2FrameStream} object attached that uniquely identifies a particular stream. + * + *

{@link Http2StreamFrame}s read from the channel always have a {@link Http2FrameStream} object set, while when writing a + * {@link Http2StreamFrame} the application code needs to set a {@link Http2FrameStream} object using + * {@link Http2StreamFrame#stream(Http2FrameStream)}. + * + *

Flow control

+ *

+ * The frame codec automatically increments stream and connection flow control windows. + * + *

Incoming flow controlled frames need to be consumed by writing a {@link Http2WindowUpdateFrame} with the consumed + * number of bytes and the corresponding stream identifier set to the frame codec. + * + *

The local stream-level flow control window can be changed by writing a {@link Http2SettingsFrame} with the + * {@link Http2Settings#initialWindowSize()} set to the targeted value. + * + *

The connection-level flow control window can be changed by writing a {@link Http2WindowUpdateFrame} with the + * desired window size increment in bytes and the stream identifier set to {@code 0}. By default the initial + * connection-level flow control window is the same as initial stream-level flow control window. + * + *

New inbound Streams

+ *

+ * The first frame of an HTTP/2 stream must be an {@link Http2HeadersFrame}, which will have an {@link Http2FrameStream} + * object attached. + * + *

New outbound Streams

+ *

+ * An outbound HTTP/2 stream can be created by first instantiating a new {@link Http2FrameStream} object via + * {@link Http2ChannelDuplexHandler#newStream()}, and then writing a {@link Http2HeadersFrame} object with the stream + * attached. + * + *

+ *     final Http2Stream2 stream = handler.newStream();
+ *     ctx.write(headersFrame.stream(stream)).addListener(new ChannelFutureListener() {
+ *
+ *         @Override
+ *         public void operationComplete(ChannelFuture f) {
+ *             if (f.isSuccess()) {
+ *                 // Stream is active and stream.id() returns a valid stream identifier.
+ *                 System.out.println("New stream with id " + stream.id() + " created.");
+ *             } else {
+ *                 // Stream failed to become active. Handle error.
+ *                 if (f.cause() instanceof Http2NoMoreStreamIdsException) {
+ *
+ *                 } else if (f.cause() instanceof Http2GoAwayException) {
+ *
+ *                 } else {
+ *
+ *                 }
+ *             }
+ *         }
+ *     }
+ * 
+ * + *

If a new stream cannot be created due to stream id exhaustion of the endpoint, the {@link ChannelPromise} of the + * HEADERS frame will fail with a {@link Http2NoMoreStreamIdsException}. + * + *

The HTTP/2 standard allows for an endpoint to limit the maximum number of concurrently active streams via the + * {@code SETTINGS_MAX_CONCURRENT_STREAMS} setting. When this limit is reached, no new streams can be created. However, + * the {@link Http2FrameCodec} can be built with + * {@link Http2FrameCodecBuilder#encoderEnforceMaxConcurrentStreams(boolean)} enabled, in which case a new stream and + * its associated frames will be buffered until either the limit is increased or an active stream is closed. It's, + * however, possible that a buffered stream will never become active. That is, the channel might + * get closed or a GO_AWAY frame might be received. In the first case, all writes of buffered streams will fail with a + * {@link Http2ChannelClosedException}. In the second case, all writes of buffered streams with an identifier less than + * the last stream identifier of the GO_AWAY frame will fail with a {@link Http2GoAwayException}. + * + *

Error Handling

+ *

+ * Exceptions and errors are propagated via {@link ChannelInboundHandler#exceptionCaught}. Exceptions that apply to + * a specific HTTP/2 stream are wrapped in a {@link Http2FrameStreamException} and have the corresponding + * {@link Http2FrameStream} object attached. + * + *

Reference Counting

+ *

+ * Some {@link Http2StreamFrame}s implement the {@link ReferenceCounted} interface, as they carry + * reference counted objects (e.g. {@link ByteBuf}s). The frame codec will call {@link ReferenceCounted#retain()} before + * propagating a reference counted object through the pipeline, and thus an application handler needs to release such + * an object after having consumed it. For more information on reference counting take a look at + * https://netty.io/wiki/reference-counted-objects.html + * + *

HTTP Upgrade

+ *

+ * Server-side HTTP to HTTP/2 upgrade is supported in conjunction with {@link Http2ServerUpgradeCodec}; the necessary + * HTTP-to-HTTP/2 conversion is performed automatically. + */ +@UnstableApi +public class Http2FrameCodec extends Http2ConnectionHandler { + + private static final InternalLogger LOG = InternalLoggerFactory.getInstance(Http2FrameCodec.class); + + protected final PropertyKey streamKey; + private final PropertyKey upgradeKey; + + private final Integer initialFlowControlWindowSize; + + ChannelHandlerContext ctx; + + /** + * Number of buffered streams if the {@link StreamBufferingEncoder} is used. + **/ + private int numBufferedStreams; + private final IntObjectMap frameStreamToInitializeMap = + new IntObjectHashMap(8); + + Http2FrameCodec(Http2ConnectionEncoder encoder, Http2ConnectionDecoder decoder, Http2Settings initialSettings, + boolean decoupleCloseAndGoAway, boolean flushPreface) { + super(decoder, encoder, initialSettings, decoupleCloseAndGoAway, flushPreface); + + decoder.frameListener(new FrameListener()); + connection().addListener(new ConnectionListener()); + connection().remote().flowController().listener(new Http2RemoteFlowControllerListener()); + streamKey = connection().newKey(); + upgradeKey = connection().newKey(); + initialFlowControlWindowSize = initialSettings.initialWindowSize(); + } + + /** + * Creates a new outbound/local stream. + */ + DefaultHttp2FrameStream newStream() { + return new DefaultHttp2FrameStream(); + } + + /** + * Iterates over all active HTTP/2 streams. + * + *

This method must not be called outside of the event loop. + */ + final void forEachActiveStream(final Http2FrameStreamVisitor streamVisitor) throws Http2Exception { + assert ctx.executor().inEventLoop(); + if (connection().numActiveStreams() > 0) { + connection().forEachActiveStream(new Http2StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + try { + return streamVisitor.visit((Http2FrameStream) stream.getProperty(streamKey)); + } catch (Throwable cause) { + onError(ctx, false, cause); + return false; + } + } + }); + } + } + + /** + * Retrieve the number of streams currently in the process of being initialized. + *

+ * This is package-private for testing only. + */ + int numInitializingStreams() { + return frameStreamToInitializeMap.size(); + } + + @Override + public final void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + super.handlerAdded(ctx); + handlerAdded0(ctx); + // Must be after Http2ConnectionHandler does its initialization in handlerAdded above. + // The server will not send a connection preface so we are good to send a window update. + Http2Connection connection = connection(); + if (connection.isServer()) { + tryExpandConnectionFlowControlWindow(connection); + } + } + + private void tryExpandConnectionFlowControlWindow(Http2Connection connection) throws Http2Exception { + if (initialFlowControlWindowSize != null) { + // The window size in the settings explicitly excludes the connection window. So we manually manipulate the + // connection window to accommodate more concurrent data per connection. + Http2Stream connectionStream = connection.connectionStream(); + Http2LocalFlowController localFlowController = connection.local().flowController(); + final int delta = initialFlowControlWindowSize - localFlowController.initialWindowSize(connectionStream); + // Only increase the connection window, don't decrease it. + if (delta > 0) { + // Double the delta just so a single stream can't exhaust the connection window. + localFlowController.incrementWindowSize(connectionStream, Math.max(delta << 1, delta)); + flush(ctx); + } + } + } + + void handlerAdded0(@SuppressWarnings("unsed") ChannelHandlerContext ctx) throws Exception { + // sub-class can override this for extra steps that needs to be done when the handler is added. + } + + /** + * Handles the cleartext HTTP upgrade event. If an upgrade occurred, sends a simple response via + * HTTP/2 on stream 1 (the stream specifically reserved for cleartext HTTP upgrade). 
+ */ + @Override + public final void userEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + if (evt == Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE) { + // The user event implies that we are on the client. + tryExpandConnectionFlowControlWindow(connection()); + + // We schedule this on the EventExecutor to allow to have any extra handlers added to the pipeline + // before we pass the event to the next handler. This is needed as the event may be called from within + // handlerAdded(...) which will be run before other handlers will be added to the pipeline. + ctx.executor().execute(new Runnable() { + @Override + public void run() { + ctx.fireUserEventTriggered(evt); + } + }); + } else if (evt instanceof UpgradeEvent) { + UpgradeEvent upgrade = (UpgradeEvent) evt; + try { + onUpgradeEvent(ctx, upgrade.retain()); + Http2Stream stream = connection().stream(HTTP_UPGRADE_STREAM_ID); + if (stream.getProperty(streamKey) == null) { + // TODO: improve handler/stream lifecycle so that stream isn't active before handler added. + // The stream was already made active, but ctx may have been null so it wasn't initialized. + // https://github.com/netty/netty/issues/4942 + onStreamActive0(stream); + } + upgrade.upgradeRequest().headers().setInt( + HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), HTTP_UPGRADE_STREAM_ID); + stream.setProperty(upgradeKey, true); + InboundHttpToHttp2Adapter.handle( + ctx, connection(), decoder().frameListener(), upgrade.upgradeRequest().retain()); + } finally { + upgrade.release(); + } + } else { + onUserEventTriggered(ctx, evt); + ctx.fireUserEventTriggered(evt); + } + } + + void onUserEventTriggered(final ChannelHandlerContext ctx, final Object evt) throws Exception { + // noop + } + + /** + * Processes all {@link Http2Frame}s. {@link Http2StreamFrame}s may only originate in child + * streams. 
+ */ + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof Http2DataFrame) { + Http2DataFrame dataFrame = (Http2DataFrame) msg; + encoder().writeData(ctx, dataFrame.stream().id(), dataFrame.content(), + dataFrame.padding(), dataFrame.isEndStream(), promise); + } else if (msg instanceof Http2HeadersFrame) { + writeHeadersFrame(ctx, (Http2HeadersFrame) msg, promise); + } else if (msg instanceof Http2WindowUpdateFrame) { + Http2WindowUpdateFrame frame = (Http2WindowUpdateFrame) msg; + Http2FrameStream frameStream = frame.stream(); + // It is legit to send a WINDOW_UPDATE frame for the connection stream. The parent channel doesn't attempt + // to set the Http2FrameStream so we assume if it is null the WINDOW_UPDATE is for the connection stream. + try { + if (frameStream == null) { + increaseInitialConnectionWindow(frame.windowSizeIncrement()); + } else { + consumeBytes(frameStream.id(), frame.windowSizeIncrement()); + } + promise.setSuccess(); + } catch (Throwable t) { + promise.setFailure(t); + } + } else if (msg instanceof Http2ResetFrame) { + Http2ResetFrame rstFrame = (Http2ResetFrame) msg; + int id = rstFrame.stream().id(); + // Only ever send a reset frame if stream may have existed before as otherwise we may send a RST on a + // stream in an invalid state and cause a connection error. 
+ if (connection().streamMayHaveExisted(id)) { + encoder().writeRstStream(ctx, rstFrame.stream().id(), rstFrame.errorCode(), promise); + } else { + ReferenceCountUtil.release(rstFrame); + promise.setFailure(Http2Exception.streamError( + rstFrame.stream().id(), Http2Error.PROTOCOL_ERROR, "Stream never existed")); + } + } else if (msg instanceof Http2PingFrame) { + Http2PingFrame frame = (Http2PingFrame) msg; + encoder().writePing(ctx, frame.ack(), frame.content(), promise); + } else if (msg instanceof Http2SettingsFrame) { + encoder().writeSettings(ctx, ((Http2SettingsFrame) msg).settings(), promise); + } else if (msg instanceof Http2SettingsAckFrame) { + // In the event of manual SETTINGS ACK, it is assumed the encoder will apply the earliest received but not + // yet ACKed settings. + encoder().writeSettingsAck(ctx, promise); + } else if (msg instanceof Http2GoAwayFrame) { + writeGoAwayFrame(ctx, (Http2GoAwayFrame) msg, promise); + } else if (msg instanceof Http2PushPromiseFrame) { + Http2PushPromiseFrame pushPromiseFrame = (Http2PushPromiseFrame) msg; + writePushPromise(ctx, pushPromiseFrame, promise); + } else if (msg instanceof Http2PriorityFrame) { + Http2PriorityFrame priorityFrame = (Http2PriorityFrame) msg; + encoder().writePriority(ctx, priorityFrame.stream().id(), priorityFrame.streamDependency(), + priorityFrame.weight(), priorityFrame.exclusive(), promise); + } else if (msg instanceof Http2UnknownFrame) { + Http2UnknownFrame unknownFrame = (Http2UnknownFrame) msg; + encoder().writeFrame(ctx, unknownFrame.frameType(), unknownFrame.stream().id(), + unknownFrame.flags(), unknownFrame.content(), promise); + } else if (!(msg instanceof Http2Frame)) { + ctx.write(msg, promise); + } else { + ReferenceCountUtil.release(msg); + throw new UnsupportedMessageTypeException(msg); + } + } + + private void increaseInitialConnectionWindow(int deltaBytes) throws Http2Exception { + // The LocalFlowController is responsible for detecting over/under flow. 
+ connection().local().flowController().incrementWindowSize(connection().connectionStream(), deltaBytes); + } + + final boolean consumeBytes(int streamId, int bytes) throws Http2Exception { + Http2Stream stream = connection().stream(streamId); + // Upgraded requests are ineligible for stream control. We add the null check + // in case the stream has been deregistered. + if (stream != null && streamId == Http2CodecUtil.HTTP_UPGRADE_STREAM_ID) { + Boolean upgraded = stream.getProperty(upgradeKey); + if (Boolean.TRUE.equals(upgraded)) { + return false; + } + } + + return connection().local().flowController().consumeBytes(stream, bytes); + } + + private void writeGoAwayFrame(ChannelHandlerContext ctx, Http2GoAwayFrame frame, ChannelPromise promise) { + if (frame.lastStreamId() > -1) { + frame.release(); + throw new IllegalArgumentException("Last stream id must not be set on GOAWAY frame"); + } + + int lastStreamCreated = connection().remote().lastStreamCreated(); + long lastStreamId = lastStreamCreated + ((long) frame.extraStreamIds()) * 2; + // Check if the computation overflowed. 
+ if (lastStreamId > Integer.MAX_VALUE) { + lastStreamId = Integer.MAX_VALUE; + } + goAway(ctx, (int) lastStreamId, frame.errorCode(), frame.content(), promise); + } + + private void writeHeadersFrame(final ChannelHandlerContext ctx, Http2HeadersFrame headersFrame, + final ChannelPromise promise) { + + if (isStreamIdValid(headersFrame.stream().id())) { + encoder().writeHeaders(ctx, headersFrame.stream().id(), headersFrame.headers(), headersFrame.padding(), + headersFrame.isEndStream(), promise); + } else if (initializeNewStream(ctx, (DefaultHttp2FrameStream) headersFrame.stream(), promise)) { + final int streamId = headersFrame.stream().id(); + + encoder().writeHeaders(ctx, streamId, headersFrame.headers(), headersFrame.padding(), + headersFrame.isEndStream(), promise); + + if (!promise.isDone()) { + numBufferedStreams++; + // Clean up the stream being initialized if writing the headers fails and also + // decrement the number of buffered streams. + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture channelFuture) { + numBufferedStreams--; + handleHeaderFuture(channelFuture, streamId); + } + }); + } else { + handleHeaderFuture(promise, streamId); + } + } + } + + private void writePushPromise(final ChannelHandlerContext ctx, Http2PushPromiseFrame pushPromiseFrame, + final ChannelPromise promise) { + if (isStreamIdValid(pushPromiseFrame.pushStream().id())) { + encoder().writePushPromise(ctx, pushPromiseFrame.stream().id(), pushPromiseFrame.pushStream().id(), + pushPromiseFrame.http2Headers(), pushPromiseFrame.padding(), promise); + } else if (initializeNewStream(ctx, (DefaultHttp2FrameStream) pushPromiseFrame.pushStream(), promise)) { + final int streamId = pushPromiseFrame.stream().id(); + encoder().writePushPromise(ctx, streamId, pushPromiseFrame.pushStream().id(), + pushPromiseFrame.http2Headers(), pushPromiseFrame.padding(), promise); + + if (promise.isDone()) { + handleHeaderFuture(promise, streamId); + } 
else { + numBufferedStreams++; + // Clean up the stream being initialized if writing the headers fails and also + // decrement the number of buffered streams. + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture channelFuture) { + numBufferedStreams--; + handleHeaderFuture(channelFuture, streamId); + } + }); + } + } + } + + private boolean initializeNewStream(ChannelHandlerContext ctx, DefaultHttp2FrameStream http2FrameStream, + ChannelPromise promise) { + final Http2Connection connection = connection(); + final int streamId = connection.local().incrementAndGetNextStreamId(); + if (streamId < 0) { + promise.setFailure(new Http2NoMoreStreamIdsException()); + + // Simulate a GOAWAY being received due to stream exhaustion on this connection. We use the maximum + // valid stream ID for the current peer. + onHttp2Frame(ctx, new DefaultHttp2GoAwayFrame(connection.isServer() ? Integer.MAX_VALUE : + Integer.MAX_VALUE - 1, NO_ERROR.code(), + writeAscii(ctx.alloc(), "Stream IDs exhausted on local stream creation"))); + + return false; + } + http2FrameStream.id = streamId; + + // Use a Map to store all pending streams as we may have multiple. This is needed as if we would store the + // stream in a field directly we may override the stored field before onStreamAdded(...) was called + // and so not correctly set the property for the buffered stream. + // + // See https://github.com/netty/netty/issues/8692 + Object old = frameStreamToInitializeMap.put(streamId, http2FrameStream); + + // We should not re-use ids. 
+ assert old == null; + return true; + } + + private void handleHeaderFuture(ChannelFuture channelFuture, int streamId) { + if (!channelFuture.isSuccess()) { + frameStreamToInitializeMap.remove(streamId); + } + } + + private void onStreamActive0(Http2Stream stream) { + if (stream.id() != Http2CodecUtil.HTTP_UPGRADE_STREAM_ID && + connection().local().isValidStreamId(stream.id())) { + return; + } + + DefaultHttp2FrameStream stream2 = newStream().setStreamAndProperty(streamKey, stream); + onHttp2StreamStateChanged(ctx, stream2); + } + + private final class ConnectionListener extends Http2ConnectionAdapter { + @Override + public void onStreamAdded(Http2Stream stream) { + DefaultHttp2FrameStream frameStream = frameStreamToInitializeMap.remove(stream.id()); + + if (frameStream != null) { + frameStream.setStreamAndProperty(streamKey, stream); + } + } + + @Override + public void onStreamActive(Http2Stream stream) { + onStreamActive0(stream); + } + + @Override + public void onStreamClosed(Http2Stream stream) { + onHttp2StreamStateChanged0(stream); + } + + @Override + public void onStreamHalfClosed(Http2Stream stream) { + onHttp2StreamStateChanged0(stream); + } + + private void onHttp2StreamStateChanged0(Http2Stream stream) { + DefaultHttp2FrameStream stream2 = stream.getProperty(streamKey); + if (stream2 != null) { + onHttp2StreamStateChanged(ctx, stream2); + } + } + } + + @Override + protected void onConnectionError( + ChannelHandlerContext ctx, boolean outbound, Throwable cause, Http2Exception http2Ex) { + if (!outbound) { + // allow the user to handle it first in the pipeline, and then automatically clean up. + // If this is not desired behavior the user can override this method. + // + // We only forward non outbound errors as outbound errors will already be reflected by failing the promise. 
+ ctx.fireExceptionCaught(cause); + } + super.onConnectionError(ctx, outbound, cause, http2Ex); + } + + /** + * Exceptions for unknown streams, that is streams that have no {@link Http2FrameStream} object attached + * are simply logged and replied to by sending a RST_STREAM frame. + */ + @Override + protected final void onStreamError(ChannelHandlerContext ctx, boolean outbound, Throwable cause, + Http2Exception.StreamException streamException) { + int streamId = streamException.streamId(); + Http2Stream connectionStream = connection().stream(streamId); + if (connectionStream == null) { + onHttp2UnknownStreamError(ctx, cause, streamException); + // Write a RST_STREAM + super.onStreamError(ctx, outbound, cause, streamException); + return; + } + + Http2FrameStream stream = connectionStream.getProperty(streamKey); + if (stream == null) { + LOG.warn("Stream exception thrown without stream object attached.", cause); + // Write a RST_STREAM + super.onStreamError(ctx, outbound, cause, streamException); + return; + } + + if (!outbound) { + // We only forward non outbound errors as outbound errors will already be reflected by failing the promise. + onHttp2FrameStreamException(ctx, new Http2FrameStreamException(stream, streamException.error(), cause)); + } + } + + private static void onHttp2UnknownStreamError(@SuppressWarnings("unused") ChannelHandlerContext ctx, + Throwable cause, Http2Exception.StreamException streamException) { + // We log here for debugging purposes. This exception will be propagated to the upper layers through other ways: + // - fireExceptionCaught + // - fireUserEventTriggered(Http2ResetFrame), see Http2MultiplexHandler#channelRead(...) + // - by failing write promise + // Receiver of the error is responsible for correct handling of this exception. 
+ LOG.log(DEBUG, "Stream exception thrown for unknown stream {}.", streamException.streamId(), cause); + } + + @Override + protected final boolean isGracefulShutdownComplete() { + return super.isGracefulShutdownComplete() && numBufferedStreams == 0; + } + + private final class FrameListener implements Http2FrameListener { + + @Override + public void onUnknownFrame( + ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) { + if (streamId == 0) { + // Ignore unknown frames on connection stream, for example: HTTP/2 GREASE testing + return; + } + onHttp2Frame(ctx, new DefaultHttp2UnknownFrame(frameType, flags, payload) + .stream(requireStream(streamId)).retain()); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) { + onHttp2Frame(ctx, new DefaultHttp2SettingsFrame(settings)); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) { + onHttp2Frame(ctx, new DefaultHttp2PingFrame(data, false)); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) { + onHttp2Frame(ctx, new DefaultHttp2PingFrame(data, true)); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) { + onHttp2Frame(ctx, new DefaultHttp2ResetFrame(errorCode).stream(requireStream(streamId))); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) { + if (streamId == 0) { + // Ignore connection window updates. 
+ return; + } + onHttp2Frame(ctx, new DefaultHttp2WindowUpdateFrame(windowSizeIncrement).stream(requireStream(streamId))); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int streamDependency, short weight, boolean + exclusive, int padding, boolean endStream) { + onHeadersRead(ctx, streamId, headers, padding, endStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endOfStream) { + onHttp2Frame(ctx, new DefaultHttp2HeadersFrame(headers, endOfStream, padding) + .stream(requireStream(streamId))); + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, + boolean endOfStream) { + onHttp2Frame(ctx, new DefaultHttp2DataFrame(data, endOfStream, padding) + .stream(requireStream(streamId)).retain()); + // We return the bytes in consumeBytes() once the stream channel consumed the bytes. + return 0; + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) { + onHttp2Frame(ctx, new DefaultHttp2GoAwayFrame(lastStreamId, errorCode, debugData).retain()); + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, + short weight, boolean exclusive) { + + Http2Stream stream = connection().stream(streamId); + if (stream == null) { + // The stream was not opened yet, let's just ignore this for now. 
+ return; + } + onHttp2Frame(ctx, new DefaultHttp2PriorityFrame(streamDependency, weight, exclusive) + .stream(requireStream(streamId))); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) { + onHttp2Frame(ctx, Http2SettingsAckFrame.INSTANCE); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) { + onHttp2Frame(ctx, new DefaultHttp2PushPromiseFrame(headers, padding, promisedStreamId) + .pushStream(new DefaultHttp2FrameStream() + .setStreamAndProperty(streamKey, connection().stream(promisedStreamId))) + .stream(requireStream(streamId))); + } + + private Http2FrameStream requireStream(int streamId) { + Http2FrameStream stream = connection().stream(streamId).getProperty(streamKey); + if (stream == null) { + throw new IllegalStateException("Stream object required for identifier: " + streamId); + } + return stream; + } + } + + private void onUpgradeEvent(ChannelHandlerContext ctx, UpgradeEvent evt) { + ctx.fireUserEventTriggered(evt); + } + + private void onHttp2StreamWritabilityChanged(ChannelHandlerContext ctx, DefaultHttp2FrameStream stream, + @SuppressWarnings("unused") boolean writable) { + ctx.fireUserEventTriggered(stream.writabilityChanged); + } + + void onHttp2StreamStateChanged(ChannelHandlerContext ctx, DefaultHttp2FrameStream stream) { + ctx.fireUserEventTriggered(stream.stateChanged); + } + + void onHttp2Frame(ChannelHandlerContext ctx, Http2Frame frame) { + ctx.fireChannelRead(frame); + } + + void onHttp2FrameStreamException(ChannelHandlerContext ctx, Http2FrameStreamException cause) { + ctx.fireExceptionCaught(cause); + } + + private final class Http2RemoteFlowControllerListener implements Http2RemoteFlowController.Listener { + @Override + public void writabilityChanged(Http2Stream stream) { + DefaultHttp2FrameStream frameStream = stream.getProperty(streamKey); + if (frameStream == null) { + return; + } + 
onHttp2StreamWritabilityChanged( + ctx, frameStream, connection().remote().flowController().isWritable(stream)); + } + } + + /** + * {@link Http2FrameStream} implementation. + */ + // TODO(buchgr): Merge Http2FrameStream and Http2Stream. + static class DefaultHttp2FrameStream implements Http2FrameStream { + + private volatile int id = -1; + private volatile Http2Stream stream; + + final Http2FrameStreamEvent stateChanged = Http2FrameStreamEvent.stateChanged(this); + final Http2FrameStreamEvent writabilityChanged = Http2FrameStreamEvent.writabilityChanged(this); + + Channel attachment; + + DefaultHttp2FrameStream setStreamAndProperty(PropertyKey streamKey, Http2Stream stream) { + assert id == -1 || stream.id() == id; + this.stream = stream; + stream.setProperty(streamKey, this); + return this; + } + + @Override + public int id() { + Http2Stream stream = this.stream; + return stream == null ? id : stream.id(); + } + + @Override + public State state() { + Http2Stream stream = this.stream; + return stream == null ? State.IDLE : stream.state(); + } + + @Override + public String toString() { + return String.valueOf(id()); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java new file mode 100644 index 0000000..4ea57c9 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameCodecBuilder.java @@ -0,0 +1,245 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Builder for the {@link Http2FrameCodec}. + */ +@UnstableApi +public class Http2FrameCodecBuilder extends + AbstractHttp2ConnectionHandlerBuilder { + + private Http2FrameWriter frameWriter; + + /** + * Allows overriding behavior of existing builder. + *

+ * Users of this constructor are responsible for invoking {@link #server(boolean)} method or overriding + * {@link #isServer()} method to give the builder information if the {@link Http2Connection}(s) it creates are in + * server or client mode. + * + * @see AbstractHttp2ConnectionHandlerBuilder + */ + protected Http2FrameCodecBuilder() { + } + + Http2FrameCodecBuilder(boolean server) { + server(server); + // For backwards compatibility we should disable to timeout by default at this layer. + gracefulShutdownTimeoutMillis(0); + } + + /** + * Creates a builder for an HTTP/2 client. + */ + public static Http2FrameCodecBuilder forClient() { + return new Http2FrameCodecBuilder(false); + } + + /** + * Creates a builder for an HTTP/2 server. + */ + public static Http2FrameCodecBuilder forServer() { + return new Http2FrameCodecBuilder(true); + } + + // For testing only. + Http2FrameCodecBuilder frameWriter(Http2FrameWriter frameWriter) { + this.frameWriter = checkNotNull(frameWriter, "frameWriter"); + return this; + } + + @Override + public Http2Settings initialSettings() { + return super.initialSettings(); + } + + @Override + public Http2FrameCodecBuilder initialSettings(Http2Settings settings) { + return super.initialSettings(settings); + } + + @Override + public long gracefulShutdownTimeoutMillis() { + return super.gracefulShutdownTimeoutMillis(); + } + + @Override + public Http2FrameCodecBuilder gracefulShutdownTimeoutMillis(long gracefulShutdownTimeoutMillis) { + return super.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis); + } + + @Override + public boolean isServer() { + return super.isServer(); + } + + @Override + public int maxReservedStreams() { + return super.maxReservedStreams(); + } + + @Override + public Http2FrameCodecBuilder maxReservedStreams(int maxReservedStreams) { + return super.maxReservedStreams(maxReservedStreams); + } + + @Override + public boolean isValidateHeaders() { + return super.isValidateHeaders(); + } + + @Override + public 
Http2FrameCodecBuilder validateHeaders(boolean validateHeaders) { + return super.validateHeaders(validateHeaders); + } + + @Override + public Http2FrameLogger frameLogger() { + return super.frameLogger(); + } + + @Override + public Http2FrameCodecBuilder frameLogger(Http2FrameLogger frameLogger) { + return super.frameLogger(frameLogger); + } + + @Override + public boolean encoderEnforceMaxConcurrentStreams() { + return super.encoderEnforceMaxConcurrentStreams(); + } + + @Override + public Http2FrameCodecBuilder encoderEnforceMaxConcurrentStreams(boolean encoderEnforceMaxConcurrentStreams) { + return super.encoderEnforceMaxConcurrentStreams(encoderEnforceMaxConcurrentStreams); + } + + @Override + public int encoderEnforceMaxQueuedControlFrames() { + return super.encoderEnforceMaxQueuedControlFrames(); + } + + @Override + public Http2FrameCodecBuilder encoderEnforceMaxQueuedControlFrames(int maxQueuedControlFrames) { + return super.encoderEnforceMaxQueuedControlFrames(maxQueuedControlFrames); + } + + @Override + public Http2HeadersEncoder.SensitivityDetector headerSensitivityDetector() { + return super.headerSensitivityDetector(); + } + + @Override + public Http2FrameCodecBuilder headerSensitivityDetector( + Http2HeadersEncoder.SensitivityDetector headerSensitivityDetector) { + return super.headerSensitivityDetector(headerSensitivityDetector); + } + + @Override + public Http2FrameCodecBuilder encoderIgnoreMaxHeaderListSize(boolean ignoreMaxHeaderListSize) { + return super.encoderIgnoreMaxHeaderListSize(ignoreMaxHeaderListSize); + } + + @Override + @Deprecated + public Http2FrameCodecBuilder initialHuffmanDecodeCapacity(int initialHuffmanDecodeCapacity) { + return super.initialHuffmanDecodeCapacity(initialHuffmanDecodeCapacity); + } + + @Override + public Http2FrameCodecBuilder autoAckSettingsFrame(boolean autoAckSettings) { + return super.autoAckSettingsFrame(autoAckSettings); + } + + @Override + public Http2FrameCodecBuilder autoAckPingFrame(boolean 
autoAckPingFrame) { + return super.autoAckPingFrame(autoAckPingFrame); + } + + @Override + public Http2FrameCodecBuilder decoupleCloseAndGoAway(boolean decoupleCloseAndGoAway) { + return super.decoupleCloseAndGoAway(decoupleCloseAndGoAway); + } + + @Override + public Http2FrameCodecBuilder flushPreface(boolean flushPreface) { + return super.flushPreface(flushPreface); + } + + @Override + public int decoderEnforceMaxConsecutiveEmptyDataFrames() { + return super.decoderEnforceMaxConsecutiveEmptyDataFrames(); + } + + @Override + public Http2FrameCodecBuilder decoderEnforceMaxConsecutiveEmptyDataFrames(int maxConsecutiveEmptyFrames) { + return super.decoderEnforceMaxConsecutiveEmptyDataFrames(maxConsecutiveEmptyFrames); + } + + @Override + public Http2FrameCodecBuilder decoderEnforceMaxRstFramesPerWindow( + int maxRstFramesPerWindow, int secondsPerWindow) { + return super.decoderEnforceMaxRstFramesPerWindow(maxRstFramesPerWindow, secondsPerWindow); + } + + /** + * Build a {@link Http2FrameCodec} object. + */ + @Override + public Http2FrameCodec build() { + Http2FrameWriter frameWriter = this.frameWriter; + if (frameWriter != null) { + // This is to support our tests and will never be executed by the user as frameWriter(...) + // is package-private. + DefaultHttp2Connection connection = new DefaultHttp2Connection(isServer(), maxReservedStreams()); + Long maxHeaderListSize = initialSettings().maxHeaderListSize(); + Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? 
+ new DefaultHttp2HeadersDecoder(isValidateHeaders()) : + new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize)); + + if (frameLogger() != null) { + frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); + frameReader = new Http2InboundFrameLogger(frameReader, frameLogger()); + } + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); + if (encoderEnforceMaxConcurrentStreams()) { + encoder = new StreamBufferingEncoder(encoder); + } + Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader, + promisedRequestVerifier(), isAutoAckSettingsFrame(), isAutoAckPingFrame(), isValidateHeaders()); + int maxConsecutiveEmptyDataFrames = decoderEnforceMaxConsecutiveEmptyDataFrames(); + if (maxConsecutiveEmptyDataFrames > 0) { + decoder = new Http2EmptyDataFrameConnectionDecoder(decoder, maxConsecutiveEmptyDataFrames); + } + return build(decoder, encoder, initialSettings()); + } + return super.build(); + } + + @Override + protected Http2FrameCodec build( + Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings) { + Http2FrameCodec codec = new Http2FrameCodec(encoder, decoder, initialSettings, + decoupleCloseAndGoAway(), flushPreface()); + codec.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis()); + return codec; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListener.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListener.java new file mode 100644 index 0000000..355e4c6 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListener.java @@ -0,0 +1,220 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * An listener of HTTP/2 frames. + */ +@UnstableApi +public interface Http2FrameListener { + /** + * Handles an inbound {@code DATA} frame. + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the subject stream for the frame. + * @param data payload buffer for the frame. This buffer will be released by the codec. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + * @param endOfStream Indicates whether this is the last frame to be sent from the remote endpoint for this stream. + * @return the number of bytes that have been processed by the application. The returned bytes are used by the + * inbound flow controller to determine the appropriate time to expand the inbound flow control window (i.e. send + * {@code WINDOW_UPDATE}). Returning a value equal to the length of {@code data} + {@code padding} will effectively + * opt-out of application-level flow control for this frame. Returning a value less than the length of {@code data} + * + {@code padding} will defer the returning of the processed bytes, which the application must later return via + * {@link Http2LocalFlowController#consumeBytes(Http2Stream, int)}. The returned value must + * be >= {@code 0} and <= {@code data.readableBytes()} + {@code padding}. 
+ */ + int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, + boolean endOfStream) throws Http2Exception; + + /** + * Handles an inbound {@code HEADERS} frame. + *

+ * Only one of the following methods will be called for each {@code HEADERS} frame sequence. + * One will be called when the {@code END_HEADERS} flag has been received. + *

+ * <ul>
+ * <li>{@link #onHeadersRead(ChannelHandlerContext, int, Http2Headers, int, boolean)}</li>
+ * <li>{@link #onHeadersRead(ChannelHandlerContext, int, Http2Headers, int, short, boolean, int, boolean)}</li>
+ * <li>{@link #onPushPromiseRead(ChannelHandlerContext, int, int, Http2Headers, int)}</li>
+ * </ul>
+ * <p>

+ * To say it another way; the {@link Http2Headers} will contain all of the headers + * for the current message exchange step (additional queuing is not necessary). + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the subject stream for the frame. + * @param headers the received headers. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + * @param endOfStream Indicates whether this is the last frame to be sent from the remote endpoint + * for this stream. + */ + void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception; + + /** + * Handles an inbound {@code HEADERS} frame with priority information specified. + * Only called if {@code END_HEADERS} encountered. + *

+ * Only one of the following methods will be called for each {@code HEADERS} frame sequence. + * One will be called when the {@code END_HEADERS} flag has been received. + *

+ * <ul>
+ * <li>{@link #onHeadersRead(ChannelHandlerContext, int, Http2Headers, int, boolean)}</li>
+ * <li>{@link #onHeadersRead(ChannelHandlerContext, int, Http2Headers, int, short, boolean, int, boolean)}</li>
+ * <li>{@link #onPushPromiseRead(ChannelHandlerContext, int, int, Http2Headers, int)}</li>
+ * </ul>
+ * <p>

+ * To say it another way; the {@link Http2Headers} will contain all of the headers + * for the current message exchange step (additional queuing is not necessary). + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the subject stream for the frame. + * @param headers the received headers. + * @param streamDependency the stream on which this stream depends, or 0 if dependent on the + * connection. + * @param weight the new weight for the stream. + * @param exclusive whether or not the stream should be the exclusive dependent of its parent. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + * @param endOfStream Indicates whether this is the last frame to be sent from the remote endpoint + * for this stream. + */ + void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endOfStream) + throws Http2Exception; + + /** + * Handles an inbound {@code PRIORITY} frame. + *

+ * Note that is it possible to have this method called and no stream object exist for either + * {@code streamId}, {@code streamDependency}, or both. This is because the {@code PRIORITY} frame can be + * sent/received when streams are in the {@code CLOSED} state. + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the subject stream for the frame. + * @param streamDependency the stream on which this stream depends, or 0 if dependent on the + * connection. + * @param weight the new weight for the stream. + * @param exclusive whether or not the stream should be the exclusive dependent of its parent. + */ + void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, + short weight, boolean exclusive) throws Http2Exception; + + /** + * Handles an inbound {@code RST_STREAM} frame. + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the stream that is terminating. + * @param errorCode the error code identifying the type of failure. + */ + void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception; + + /** + * Handles an inbound {@code SETTINGS} acknowledgment frame. + * @param ctx the context from the handler where the frame was read. + */ + void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception; + + /** + * Handles an inbound {@code SETTINGS} frame. + * + * @param ctx the context from the handler where the frame was read. + * @param settings the settings received from the remote endpoint. + */ + void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception; + + /** + * Handles an inbound {@code PING} frame. + * + * @param ctx the context from the handler where the frame was read. + * @param data the payload of the frame. + */ + void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception; + + /** + * Handles an inbound {@code PING} acknowledgment. 
+ * + * @param ctx the context from the handler where the frame was read. + * @param data the payload of the frame. + */ + void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception; + + /** + * Handles an inbound {@code PUSH_PROMISE} frame. Only called if {@code END_HEADERS} encountered. + *

+ * Promised requests MUST be authoritative, cacheable, and safe. + * See [RFC 7540], Section 8.2. + *

+ * Only one of the following methods will be called for each {@code HEADERS} frame sequence. + * One will be called when the {@code END_HEADERS} flag has been received. + *

+ * <ul>
+ * <li>{@link #onHeadersRead(ChannelHandlerContext, int, Http2Headers, int, boolean)}</li>
+ * <li>{@link #onHeadersRead(ChannelHandlerContext, int, Http2Headers, int, short, boolean, int, boolean)}</li>
+ * <li>{@link #onPushPromiseRead(ChannelHandlerContext, int, int, Http2Headers, int)}</li>
+ * </ul>
+ * <p>

+ * To say it another way; the {@link Http2Headers} will contain all of the headers + * for the current message exchange step (additional queuing is not necessary). + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the stream the frame was sent on. + * @param promisedStreamId the ID of the promised stream. + * @param headers the received headers. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + */ + void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception; + + /** + * Handles an inbound {@code GO_AWAY} frame. + * + * @param ctx the context from the handler where the frame was read. + * @param lastStreamId the last known stream of the remote endpoint. + * @param errorCode the error code, if abnormal closure. + * @param debugData application-defined debug data. If this buffer needs to be retained by the + * listener they must make a copy. + */ + void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception; + + /** + * Handles an inbound {@code WINDOW_UPDATE} frame. + * + * @param ctx the context from the handler where the frame was read. + * @param streamId the stream the frame was sent on. + * @param windowSizeIncrement the increased number of bytes of the remote endpoint's flow + * control window. + */ + void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception; + + /** + * Handler for a frame not defined by the HTTP/2 spec. + * + * @param ctx the context from the handler where the frame was read. + * @param frameType the frame type from the HTTP/2 header. + * @param streamId the stream the frame was sent on. + * @param flags the flags in the frame header. + * @param payload the payload of the frame. 
+ */ + void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, ByteBuf payload) + throws Http2Exception; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListenerDecorator.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListenerDecorator.java new file mode 100644 index 0000000..f028537 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameListenerDecorator.java @@ -0,0 +1,105 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * Provides a decorator around a {@link Http2FrameListener} and delegates all method calls + */ +@UnstableApi +public class Http2FrameListenerDecorator implements Http2FrameListener { + protected final Http2FrameListener listener; + + public Http2FrameListenerDecorator(Http2FrameListener listener) { + this.listener = checkNotNull(listener, "listener"); + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + return listener.onDataRead(ctx, streamId, data, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endStream) throws Http2Exception { + listener.onHeadersRead(ctx, streamId, headers, padding, endStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception { + listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream); + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + listener.onRstStreamRead(ctx, streamId, errorCode); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + listener.onSettingsAckRead(ctx); + } + + @Override + public void 
onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + listener.onSettingsRead(ctx, settings); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + listener.onPingRead(ctx, data); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + listener.onPingAckRead(ctx, data); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, Http2Headers headers, + int padding) throws Http2Exception { + listener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java new file mode 100644 index 0000000..9e69530 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameLogger.java @@ -0,0 +1,176 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.logging.LogLevel; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogLevel; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Logs HTTP2 frames for debugging purposes. 
+ */ +@UnstableApi +public class Http2FrameLogger extends ChannelHandlerAdapter { + + public enum Direction { + INBOUND, + OUTBOUND + } + + private static final int BUFFER_LENGTH_THRESHOLD = 64; + private final InternalLogger logger; + private final InternalLogLevel level; + + public Http2FrameLogger(LogLevel level) { + this(checkAndConvertLevel(level), InternalLoggerFactory.getInstance(Http2FrameLogger.class)); + } + + public Http2FrameLogger(LogLevel level, String name) { + this(checkAndConvertLevel(level), InternalLoggerFactory.getInstance(checkNotNull(name, "name"))); + } + + public Http2FrameLogger(LogLevel level, Class clazz) { + this(checkAndConvertLevel(level), InternalLoggerFactory.getInstance(checkNotNull(clazz, "clazz"))); + } + + private Http2FrameLogger(InternalLogLevel level, InternalLogger logger) { + this.level = level; + this.logger = logger; + } + + private static InternalLogLevel checkAndConvertLevel(LogLevel level) { + return checkNotNull(level, "level").toInternalLevel(); + } + + public boolean isEnabled() { + return logger.isEnabled(level); + } + + public void logData(Direction direction, ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, + boolean endStream) { + if (isEnabled()) { + logger.log(level, "{} {} DATA: streamId={} padding={} endStream={} length={} bytes={}", ctx.channel(), + direction.name(), streamId, padding, endStream, data.readableBytes(), toString(data)); + } + } + + public void logHeaders(Direction direction, ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream) { + if (isEnabled()) { + logger.log(level, "{} {} HEADERS: streamId={} headers={} padding={} endStream={}", ctx.channel(), + direction.name(), streamId, headers, padding, endStream); + } + } + + public void logHeaders(Direction direction, ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endStream) { + if (isEnabled()) 
{ + logger.log(level, "{} {} HEADERS: streamId={} headers={} streamDependency={} weight={} exclusive={} " + + "padding={} endStream={}", ctx.channel(), + direction.name(), streamId, headers, streamDependency, weight, exclusive, padding, endStream); + } + } + + public void logPriority(Direction direction, ChannelHandlerContext ctx, int streamId, int streamDependency, + short weight, boolean exclusive) { + if (isEnabled()) { + logger.log(level, "{} {} PRIORITY: streamId={} streamDependency={} weight={} exclusive={}", ctx.channel(), + direction.name(), streamId, streamDependency, weight, exclusive); + } + } + + public void logRstStream(Direction direction, ChannelHandlerContext ctx, int streamId, long errorCode) { + if (isEnabled()) { + logger.log(level, "{} {} RST_STREAM: streamId={} errorCode={}", ctx.channel(), + direction.name(), streamId, errorCode); + } + } + + public void logSettingsAck(Direction direction, ChannelHandlerContext ctx) { + logger.log(level, "{} {} SETTINGS: ack=true", ctx.channel(), direction.name()); + } + + public void logSettings(Direction direction, ChannelHandlerContext ctx, Http2Settings settings) { + if (isEnabled()) { + logger.log(level, "{} {} SETTINGS: ack=false settings={}", ctx.channel(), direction.name(), settings); + } + } + + public void logPing(Direction direction, ChannelHandlerContext ctx, long data) { + if (isEnabled()) { + logger.log(level, "{} {} PING: ack=false bytes={}", ctx.channel(), + direction.name(), data); + } + } + + public void logPingAck(Direction direction, ChannelHandlerContext ctx, long data) { + if (isEnabled()) { + logger.log(level, "{} {} PING: ack=true bytes={}", ctx.channel(), + direction.name(), data); + } + } + + public void logPushPromise(Direction direction, ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) { + if (isEnabled()) { + logger.log(level, "{} {} PUSH_PROMISE: streamId={} promisedStreamId={} headers={} padding={}", + ctx.channel(), 
direction.name(), streamId, promisedStreamId, headers, padding); + } + } + + public void logGoAway(Direction direction, ChannelHandlerContext ctx, int lastStreamId, long errorCode, + ByteBuf debugData) { + if (isEnabled()) { + logger.log(level, "{} {} GO_AWAY: lastStreamId={} errorCode={} length={} bytes={}", ctx.channel(), + direction.name(), lastStreamId, errorCode, debugData.readableBytes(), toString(debugData)); + } + } + + public void logWindowsUpdate(Direction direction, ChannelHandlerContext ctx, int streamId, + int windowSizeIncrement) { + if (isEnabled()) { + logger.log(level, "{} {} WINDOW_UPDATE: streamId={} windowSizeIncrement={}", ctx.channel(), + direction.name(), streamId, windowSizeIncrement); + } + } + + public void logUnknownFrame(Direction direction, ChannelHandlerContext ctx, byte frameType, int streamId, + Http2Flags flags, ByteBuf data) { + if (isEnabled()) { + logger.log(level, "{} {} UNKNOWN: frameType={} streamId={} flags={} length={} bytes={}", ctx.channel(), + direction.name(), frameType & 0xFF, streamId, flags.value(), data.readableBytes(), toString(data)); + } + } + + private String toString(ByteBuf buf) { + if (level == InternalLogLevel.TRACE || buf.readableBytes() <= BUFFER_LENGTH_THRESHOLD) { + // Log the entire buffer. + return ByteBufUtil.hexDump(buf); + } + + // Otherwise just log the first 64 bytes. 
+ int length = Math.min(buf.readableBytes(), BUFFER_LENGTH_THRESHOLD); + return ByteBufUtil.hexDump(buf, buf.readerIndex(), length) + "..."; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameReader.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameReader.java new file mode 100644 index 0000000..7d5ac3f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameReader.java @@ -0,0 +1,62 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +import java.io.Closeable; + +/** + * Reads HTTP/2 frames from an input {@link ByteBuf} and notifies the specified + * {@link Http2FrameListener} when frames are complete. 
+ */ +@UnstableApi +public interface Http2FrameReader extends Closeable { + /** + * Configuration specific to {@link Http2FrameReader} + */ + interface Configuration { + /** + * Get the {@link Http2HeadersDecoder.Configuration} for this {@link Http2FrameReader} + */ + Http2HeadersDecoder.Configuration headersConfiguration(); + + /** + * Get the {@link Http2FrameSizePolicy} for this {@link Http2FrameReader} + */ + Http2FrameSizePolicy frameSizePolicy(); + } + + /** + * Attempts to read the next frame from the input buffer. If enough data is available to fully + * read the frame, notifies the listener of the read frame. + */ + void readFrame(ChannelHandlerContext ctx, ByteBuf input, Http2FrameListener listener) + throws Http2Exception; + + /** + * Get the configuration related elements for this {@link Http2FrameReader} + */ + Configuration configuration(); + + /** + * Closes this reader and frees any allocated resources. + */ + @Override + void close(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameSizePolicy.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameSizePolicy.java new file mode 100644 index 0000000..23abd3a --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameSizePolicy.java @@ -0,0 +1,39 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +@UnstableApi +public interface Http2FrameSizePolicy { + /** + * Sets the maximum allowed frame size. Attempts to write frames longer than this maximum will fail. + *

+ * This value is used to represent + * SETTINGS_MAX_FRAME_SIZE. This method should + * only be called by Netty (not users) as a result of a receiving a {@code SETTINGS} frame. + */ + void maxFrameSize(int max) throws Http2Exception; + + /** + * Gets the maximum allowed frame size. + *

+ * This value is used to represent + * SETTINGS_MAX_FRAME_SIZE. The initial value + * defined by the RFC is unlimited but enforcing a lower limit is generally permitted. + * {@link Http2CodecUtil#DEFAULT_MAX_FRAME_SIZE} can be used as a more conservative default. + */ + int maxFrameSize(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStream.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStream.java new file mode 100644 index 0000000..5c3e980 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStream.java @@ -0,0 +1,39 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.util.internal.UnstableApi; + +/** + * A single stream within an HTTP/2 connection. To be used with the {@link Http2FrameCodec}. + */ +@UnstableApi +public interface Http2FrameStream { + /** + * Returns the stream identifier. + * + *

Use {@link Http2CodecUtil#isStreamIdValid(int)} to check if the stream has already been assigned an + * identifier. + */ + int id(); + + /** + * Returns the state of this stream. + */ + State state(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamEvent.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamEvent.java new file mode 100644 index 0000000..f9e8269 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamEvent.java @@ -0,0 +1,52 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +@UnstableApi +public final class Http2FrameStreamEvent { + + private final Http2FrameStream stream; + private final Type type; + + @UnstableApi + public enum Type { + State, + Writability + } + + private Http2FrameStreamEvent(Http2FrameStream stream, Type type) { + this.stream = stream; + this.type = type; + } + + public Http2FrameStream stream() { + return stream; + } + + public Type type() { + return type; + } + + static Http2FrameStreamEvent stateChanged(Http2FrameStream stream) { + return new Http2FrameStreamEvent(stream, Type.State); + } + + static Http2FrameStreamEvent writabilityChanged(Http2FrameStream stream) { + return new Http2FrameStreamEvent(stream, Type.Writability); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamException.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamException.java new file mode 100644 index 0000000..fbafbcd --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamException.java @@ -0,0 +1,47 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * An HTTP/2 exception for a specific {@link Http2FrameStream}. + */ +@UnstableApi +public final class Http2FrameStreamException extends Exception { + + private static final long serialVersionUID = -4407186173493887044L; + + private final Http2Error error; + private final Http2FrameStream stream; + + public Http2FrameStreamException(Http2FrameStream stream, Http2Error error, Throwable cause) { + super(cause.getMessage(), cause); + this.stream = checkNotNull(stream, "stream"); + this.error = checkNotNull(error, "error"); + } + + public Http2Error error() { + return error; + } + + public Http2FrameStream stream() { + return stream; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamVisitor.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamVisitor.java new file mode 100644 index 0000000..8c75d9e --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameStreamVisitor.java @@ -0,0 +1,38 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * A visitor that allows to iterate over a collection of {@link Http2FrameStream}s. + */ +@UnstableApi +public interface Http2FrameStreamVisitor { + + /** + * This method is called once for each stream of the collection. + * + *

If an {@link Exception} is thrown, the loop is stopped. + * + * @return

    + *
  • {@code true} if the visitor wants to continue the loop and handle the stream.
  • + *
  • {@code false} if the visitor wants to stop handling the stream and abort the loop.
  • + *
+ */ + boolean visit(Http2FrameStream stream); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameTypes.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameTypes.java new file mode 100644 index 0000000..54bcc3b --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameTypes.java @@ -0,0 +1,38 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Registry of all standard frame types defined by the HTTP/2 specification. 
+ */ +@UnstableApi +public final class Http2FrameTypes { + public static final byte DATA = 0x0; + public static final byte HEADERS = 0x1; + public static final byte PRIORITY = 0x2; + public static final byte RST_STREAM = 0x3; + public static final byte SETTINGS = 0x4; + public static final byte PUSH_PROMISE = 0x5; + public static final byte PING = 0x6; + public static final byte GO_AWAY = 0x7; + public static final byte WINDOW_UPDATE = 0x8; + public static final byte CONTINUATION = 0x9; + + private Http2FrameTypes() { + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java new file mode 100644 index 0000000..c312cdb --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2FrameWriter.java @@ -0,0 +1,229 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.UnstableApi; + +import java.io.Closeable; + +/** + * A writer responsible for marshaling HTTP/2 frames to the channel. All of the write methods in + * this interface write to the context, but DO NOT FLUSH. 
To perform a flush, you must separately + * call {@link ChannelHandlerContext#flush()}. + */ +@UnstableApi +public interface Http2FrameWriter extends Http2DataWriter, Closeable { + /** + * Configuration specific to {@link Http2FrameWriter} + */ + interface Configuration { + /** + * Get the {@link Http2HeadersEncoder.Configuration} for this {@link Http2FrameWriter} + */ + Http2HeadersEncoder.Configuration headersConfiguration(); + + /** + * Get the {@link Http2FrameSizePolicy} for this {@link Http2FrameWriter} + */ + Http2FrameSizePolicy frameSizePolicy(); + } + + /** + * Writes a HEADERS frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param streamId the stream for which to send the frame. + * @param headers the headers to be sent. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + * @param endStream indicates if this is the last frame to be sent for the stream. + * @param promise the promise for the write. + * @return the future for the write. + * Section 10.5.1 states the following: + *
+     * The header block MUST be processed to ensure a consistent connection state, unless the connection is closed.
+     * 
+ * If this call has modified the HPACK header state you MUST throw a connection error. + *

+ * If this call has NOT modified the HPACK header state you are free to throw a stream error. + */ + ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream, ChannelPromise promise); + + /** + * Writes a HEADERS frame with priority specified to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param streamId the stream for which to send the frame. + * @param headers the headers to be sent. + * @param streamDependency the stream on which this stream should depend, or 0 if it should + * depend on the connection. + * @param weight the weight for this stream. + * @param exclusive whether this stream should be the exclusive dependant of its parent. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + * @param endStream indicates if this is the last frame to be sent for the stream. + * @param promise the promise for the write. + * @return the future for the write. + * Section 10.5.1 states the following: + *

+     * The header block MUST be processed to ensure a consistent connection state, unless the connection is closed.
+     * 
+ * If this call has modified the HPACK header state you MUST throw a connection error. + *

+ * If this call has NOT modified the HPACK header state you are free to throw a stream error. + */ + ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endStream, + ChannelPromise promise); + + /** + * Writes a PRIORITY frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param streamId the stream for which to send the frame. + * @param streamDependency the stream on which this stream should depend, or 0 if it should + * depend on the connection. + * @param weight the weight for this stream. + * @param exclusive whether this stream should be the exclusive dependant of its parent. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, int streamDependency, + short weight, boolean exclusive, ChannelPromise promise); + + /** + * Writes a RST_STREAM frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param streamId the stream for which to send the frame. + * @param errorCode the error code indicating the nature of the failure. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, + ChannelPromise promise); + + /** + * Writes a SETTINGS frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param settings the settings to be sent. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writeSettings(ChannelHandlerContext ctx, Http2Settings settings, + ChannelPromise promise); + + /** + * Writes a SETTINGS acknowledgment to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param promise the promise for the write. + * @return the future for the write. 
+ */ + ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise); + + /** + * Writes a PING frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param ack indicates whether this is an ack of a PING frame previously received from the + * remote endpoint. + * @param data the payload of the frame. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, long data, + ChannelPromise promise); + + /** + * Writes a PUSH_PROMISE frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param streamId the stream for which to send the frame. + * @param promisedStreamId the ID of the promised stream. + * @param headers the headers to be sent. + * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and + * 256 (inclusive). + * @param promise the promise for the write. + * @return the future for the write. + * Section 10.5.1 states the following: + *

+     * The header block MUST be processed to ensure a consistent connection state, unless the connection is closed.
+     * 
+ * If this call has modified the HPACK header state you MUST throw a connection error. + *

+ * If this call has NOT modified the HPACK header state you are free to throw a stream error. + */ + ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding, ChannelPromise promise); + + /** + * Writes a GO_AWAY frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param lastStreamId the last known stream of this endpoint. + * @param errorCode the error code, if the connection was abnormally terminated. + * @param debugData application-defined debug data. This will be released by this method. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, + ByteBuf debugData, ChannelPromise promise); + + /** + * Writes a WINDOW_UPDATE frame to the remote endpoint. + * + * @param ctx the context to use for writing. + * @param streamId the stream for which to send the frame. + * @param windowSizeIncrement the number of bytes by which the local inbound flow control window + * is increasing. + * @param promise the promise for the write. + * @return the future for the write. + */ + ChannelFuture writeWindowUpdate(ChannelHandlerContext ctx, int streamId, + int windowSizeIncrement, ChannelPromise promise); + + /** + * Generic write method for any HTTP/2 frame. This allows writing of non-standard frames. + * + * @param ctx the context to use for writing. + * @param frameType the frame type identifier. + * @param streamId the stream for which to send the frame. + * @param flags the flags to write for this frame. + * @param payload the payload to write for this frame. This will be released by this method. + * @param promise the promise for the write. + * @return the future for the write. 
+ */ + ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, + Http2Flags flags, ByteBuf payload, ChannelPromise promise); + + /** + * Get the configuration related elements for this {@link Http2FrameWriter} + */ + Configuration configuration(); + + /** + * Closes this writer and frees any allocated resources. + */ + @Override + void close(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2GoAwayFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2GoAwayFrame.java new file mode 100644 index 0000000..70f22d3 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2GoAwayFrame.java @@ -0,0 +1,89 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 GOAWAY frame. + * + *

The last stream identifier must not be set by the application, but instead the + * relative {@link #extraStreamIds()} should be used. The {@link #lastStreamId()} will only be + * set for incoming GOAWAY frames by the HTTP/2 codec. + * + *

Graceful shutdown as described in the HTTP/2 spec can be accomplished by calling + * {@code #setExtraStreamIds(Integer.MAX_VALUE)}. + */ +@UnstableApi +public interface Http2GoAwayFrame extends Http2Frame, ByteBufHolder { + /** + * The reason for beginning closure of the connection. Represented as an HTTP/2 error code. + */ + long errorCode(); + + /** + * The number of IDs to reserve for the receiver to use while GOAWAY is in transit. This allows + * for new streams currently en route to still be created, up to a point, which allows for very + * graceful shutdown of both sides. + */ + int extraStreamIds(); + + /** + * Sets the number of IDs to reserve for the receiver to use while GOAWAY is in transit. + * + * @see #extraStreamIds + * @return {@code this} + */ + Http2GoAwayFrame setExtraStreamIds(int extraStreamIds); + + /** + * Returns the last stream identifier if set, or {@code -1} otherwise. + */ + int lastStreamId(); + + /** + * Optional debugging information describing the cause of the GOAWAY. Will not be {@code null}, but may + * be empty.
+ */ + @Override + ByteBuf content(); + + @Override + Http2GoAwayFrame copy(); + + @Override + Http2GoAwayFrame duplicate(); + + @Override + Http2GoAwayFrame retainedDuplicate(); + + @Override + Http2GoAwayFrame replace(ByteBuf content); + + @Override + Http2GoAwayFrame retain(); + + @Override + Http2GoAwayFrame retain(int increment); + + @Override + Http2GoAwayFrame touch(); + + @Override + Http2GoAwayFrame touch(Object hint); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Headers.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Headers.java new file mode 100644 index 0000000..1d21994 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2Headers.java @@ -0,0 +1,205 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.Headers; +import io.netty.util.AsciiString; +import io.netty.util.internal.UnstableApi; + +import java.util.Iterator; +import java.util.Map.Entry; + +/** + * A collection of headers sent or received via HTTP/2. + */ +@UnstableApi +public interface Http2Headers extends Headers { + + /** + * HTTP/2 pseudo-headers names. + */ + enum PseudoHeaderName { + /** + * {@code :method}. + */ + METHOD(":method", true), + + /** + * {@code :scheme}. 
+ */ + SCHEME(":scheme", true), + + /** + * {@code :authority}. + */ + AUTHORITY(":authority", true), + + /** + * {@code :path}. + */ + PATH(":path", true), + + /** + * {@code :status}. + */ + STATUS(":status", false), + + /** + * {@code :protocol}, as defined in RFC 8441, + * Bootstrapping WebSockets with HTTP/2. + */ + PROTOCOL(":protocol", true); + + private static final char PSEUDO_HEADER_PREFIX = ':'; + private static final byte PSEUDO_HEADER_PREFIX_BYTE = (byte) PSEUDO_HEADER_PREFIX; + + private final AsciiString value; + private final boolean requestOnly; + private static final CharSequenceMap PSEUDO_HEADERS = new CharSequenceMap(); + + static { + for (PseudoHeaderName pseudoHeader : values()) { + PSEUDO_HEADERS.add(pseudoHeader.value(), pseudoHeader); + } + } + + PseudoHeaderName(String value, boolean requestOnly) { + this.value = AsciiString.cached(value); + this.requestOnly = requestOnly; + } + + public AsciiString value() { + // Return a slice so that the buffer gets its own reader index. + return value; + } + + /** + * Indicates whether the specified header follows the pseudo-header format (begins with ':' character) + * + * @return {@code true} if the header follow the pseudo-header format + */ + public static boolean hasPseudoHeaderFormat(CharSequence headerName) { + if (headerName instanceof AsciiString) { + final AsciiString asciiHeaderName = (AsciiString) headerName; + return asciiHeaderName.length() > 0 && asciiHeaderName.byteAt(0) == PSEUDO_HEADER_PREFIX_BYTE; + } else { + return headerName.length() > 0 && headerName.charAt(0) == PSEUDO_HEADER_PREFIX; + } + } + + /** + * Indicates whether the given header name is a valid HTTP/2 pseudo header. + */ + public static boolean isPseudoHeader(CharSequence header) { + return PSEUDO_HEADERS.contains(header); + } + + /** + * Returns the {@link PseudoHeaderName} corresponding to the specified header name. + * + * @return corresponding {@link PseudoHeaderName} if any, {@code null} otherwise. 
+ */ + public static PseudoHeaderName getPseudoHeader(CharSequence header) { + return PSEUDO_HEADERS.get(header); + } + + /** + * Indicates whether the pseudo-header is to be used in a request context. + * + * @return {@code true} if the pseudo-header is to be used in a request context + */ + public boolean isRequestOnly() { + return requestOnly; + } + } + + /** + * Returns an iterator over all HTTP/2 headers. The iteration order is as follows: + * 1. All pseudo headers (order not specified). + * 2. All non-pseudo headers (in insertion order). + */ + @Override + Iterator> iterator(); + + /** + * Equivalent to {@link #getAll(Object)} but no intermediate list is generated. + * @param name the name of the header to retrieve + * @return an {@link Iterator} of header values corresponding to {@code name}. + */ + Iterator valueIterator(CharSequence name); + + /** + * Sets the {@link PseudoHeaderName#METHOD} header + */ + Http2Headers method(CharSequence value); + + /** + * Sets the {@link PseudoHeaderName#SCHEME} header + */ + Http2Headers scheme(CharSequence value); + + /** + * Sets the {@link PseudoHeaderName#AUTHORITY} header + */ + Http2Headers authority(CharSequence value); + + /** + * Sets the {@link PseudoHeaderName#PATH} header + */ + Http2Headers path(CharSequence value); + + /** + * Sets the {@link PseudoHeaderName#STATUS} header + */ + Http2Headers status(CharSequence value); + + /** + * Gets the {@link PseudoHeaderName#METHOD} header or {@code null} if there is no such header + */ + CharSequence method(); + + /** + * Gets the {@link PseudoHeaderName#SCHEME} header or {@code null} if there is no such header + */ + CharSequence scheme(); + + /** + * Gets the {@link PseudoHeaderName#AUTHORITY} header or {@code null} if there is no such header + */ + CharSequence authority(); + + /** + * Gets the {@link PseudoHeaderName#PATH} header or {@code null} if there is no such header + */ + CharSequence path(); + + /** + * Gets the {@link PseudoHeaderName#STATUS} header or 
{@code null} if there is no such header + */ + CharSequence status(); + + /** + * Returns {@code true} if a header with the {@code name} and {@code value} exists, {@code false} otherwise. + *

+ * If {@code caseInsensitive} is {@code true} then a case insensitive compare is done on the value. + * + * @param name the name of the header to find + * @param value the value of the header to find + * @param caseInsensitive {@code true} then a case insensitive compare is run to compare values. + * otherwise a case sensitive compare is run to compare values. + */ + boolean contains(CharSequence name, CharSequence value, boolean caseInsensitive); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersDecoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersDecoder.java new file mode 100644 index 0000000..b519c33 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersDecoder.java @@ -0,0 +1,80 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.UnstableApi; + +/** + * Decodes HPACK-encoded headers blocks into {@link Http2Headers}. + */ +@UnstableApi +public interface Http2HeadersDecoder { + /** + * Configuration related elements for the {@link Http2HeadersDecoder} interface + */ + interface Configuration { + /** + * Represents the value for + * SETTINGS_HEADER_TABLE_SIZE. 
+ * This method should only be called by Netty (not users) as a result of receiving a {@code SETTINGS} frame. + */ + void maxHeaderTableSize(long max) throws Http2Exception; + + /** + * Represents the value for + * SETTINGS_HEADER_TABLE_SIZE. The initial value + * returned by this method must be {@link Http2CodecUtil#DEFAULT_HEADER_TABLE_SIZE}. + */ + long maxHeaderTableSize(); + + /** + * Configure the maximum allowed size in bytes of each set of headers. + *

+ * This method should only be called by Netty (not users) as a result of receiving a {@code SETTINGS} frame. + * @param max SETTINGS_MAX_HEADER_LIST_SIZE. + * If this limit is exceeded the implementation should attempt to keep the HPACK header tables up to date + * by processing data from the peer, but a {@code RST_STREAM} frame will be sent for the offending stream. + * @param goAwayMax Must be {@code >= max}. A {@code GO_AWAY} frame will be generated if this limit is exceeded + * for any particular stream. + * @throws Http2Exception if limits exceed the RFC's boundaries or {@code max > goAwayMax}. + */ + void maxHeaderListSize(long max, long goAwayMax) throws Http2Exception; + + /** + * Represents the value for + * SETTINGS_MAX_HEADER_LIST_SIZE. + */ + long maxHeaderListSize(); + + /** + * Represents the upper bound in bytes for a set of headers before a {@code GO_AWAY} should be sent. + * This will be {@code <=} + * SETTINGS_MAX_HEADER_LIST_SIZE. + */ + long maxHeaderListSizeGoAway(); + } + + /** + * Decodes the given headers block and returns the headers. + */ + Http2Headers decodeHeaders(int streamId, ByteBuf headerBlock) throws Http2Exception; + + /** + * Get the {@link Configuration} for this {@link Http2HeadersDecoder} + */ + Configuration configuration(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersEncoder.java new file mode 100644 index 0000000..aa4edc5 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersEncoder.java @@ -0,0 +1,110 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License.
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.UnstableApi; + +/** + * Encodes {@link Http2Headers} into HPACK-encoded headers blocks. + */ +@UnstableApi +public interface Http2HeadersEncoder { + /** + * Configuration related elements for the {@link Http2HeadersEncoder} interface + */ + interface Configuration { + /** + * Represents the value for + * SETTINGS_HEADER_TABLE_SIZE. + * This method should only be called by Netty (not users) as a result of a receiving a {@code SETTINGS} frame. + */ + void maxHeaderTableSize(long max) throws Http2Exception; + + /** + * Represents the value for + * SETTINGS_HEADER_TABLE_SIZE. + * The initial value returned by this method must be {@link Http2CodecUtil#DEFAULT_HEADER_TABLE_SIZE}. + */ + long maxHeaderTableSize(); + + /** + * Represents the value for + * SETTINGS_MAX_HEADER_LIST_SIZE. + * This method should only be called by Netty (not users) as a result of a receiving a {@code SETTINGS} frame. + */ + void maxHeaderListSize(long max) throws Http2Exception; + + /** + * Represents the value for + * SETTINGS_MAX_HEADER_LIST_SIZE. + */ + long maxHeaderListSize(); + } + + /** + * Determine if a header name/value pair is treated as + * sensitive. + * If the object can be dynamically modified and shared across multiple connections it may need to be thread safe. + */ + interface SensitivityDetector { + /** + * Determine if a header {@code name}/{@code value} pair should be treated as + * sensitive. + * + * @param name The name for the header. 
+ * @param value The value of the header. + * @return {@code true} if a header {@code name}/{@code value} pair should be treated as + * sensitive. + * {@code false} otherwise. + */ + boolean isSensitive(CharSequence name, CharSequence value); + } + + /** + * Encodes the given headers and writes the output headers block to the given output buffer. + * + * @param streamId the identifier of the stream for which the headers are encoded. + * @param headers the headers to be encoded. + * @param buffer the buffer to receive the encoded headers. + */ + void encodeHeaders(int streamId, Http2Headers headers, ByteBuf buffer) throws Http2Exception; + + /** + * Get the {@link Configuration} for this {@link Http2HeadersEncoder} + */ + Configuration configuration(); + + /** + * Always return {@code false} for {@link SensitivityDetector#isSensitive(CharSequence, CharSequence)}. + */ + SensitivityDetector NEVER_SENSITIVE = new SensitivityDetector() { + @Override + public boolean isSensitive(CharSequence name, CharSequence value) { + return false; + } + }; + + /** + * Always return {@code true} for {@link SensitivityDetector#isSensitive(CharSequence, CharSequence)}. + */ + SensitivityDetector ALWAYS_SENSITIVE = new SensitivityDetector() { + @Override + public boolean isSensitive(CharSequence name, CharSequence value) { + return true; + } + }; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersFrame.java new file mode 100644 index 0000000..2a24f63 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2HeadersFrame.java @@ -0,0 +1,40 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 HEADERS frame. + */ +@UnstableApi +public interface Http2HeadersFrame extends Http2StreamFrame { + + /** + * A complete header list. CONTINUATION frames are automatically handled. + */ + Http2Headers headers(); + + /** + * Frame padding to use. Must be non-negative and less than 256. + */ + int padding(); + + /** + * Returns {@code true} if the END_STREAM flag is set. + */ + boolean isEndStream(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2InboundFrameLogger.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2InboundFrameLogger.java new file mode 100644 index 0000000..342af5d --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2InboundFrameLogger.java @@ -0,0 +1,147 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import static io.netty.handler.codec.http2.Http2FrameLogger.Direction.INBOUND; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * Decorator around a {@link Http2FrameReader} that logs all inbound frames before calling + * back the listener. + */ +@UnstableApi +public class Http2InboundFrameLogger implements Http2FrameReader { + private final Http2FrameReader reader; + private final Http2FrameLogger logger; + + public Http2InboundFrameLogger(Http2FrameReader reader, Http2FrameLogger logger) { + this.reader = checkNotNull(reader, "reader"); + this.logger = checkNotNull(logger, "logger"); + } + + @Override + public void readFrame(ChannelHandlerContext ctx, ByteBuf input, final Http2FrameListener listener) + throws Http2Exception { + reader.readFrame(ctx, input, new Http2FrameListener() { + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endOfStream) + throws Http2Exception { + logger.logData(INBOUND, ctx, streamId, data, padding, endOfStream); + return listener.onDataRead(ctx, streamId, data, padding, endOfStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int padding, boolean endStream) + throws Http2Exception { + logger.logHeaders(INBOUND, ctx, streamId, headers, padding, endStream); + listener.onHeadersRead(ctx, streamId, headers, padding, endStream); + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int streamDependency, short weight, boolean exclusive, + int padding, boolean endStream) throws Http2Exception { + logger.logHeaders(INBOUND, ctx, streamId, headers, streamDependency, weight, exclusive, + padding, endStream); + listener.onHeadersRead(ctx, streamId, headers, 
streamDependency, weight, exclusive, + padding, endStream); + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, + int streamDependency, short weight, boolean exclusive) throws Http2Exception { + logger.logPriority(INBOUND, ctx, streamId, streamDependency, weight, exclusive); + listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) + throws Http2Exception { + logger.logRstStream(INBOUND, ctx, streamId, errorCode); + listener.onRstStreamRead(ctx, streamId, errorCode); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + logger.logSettingsAck(INBOUND, ctx); + listener.onSettingsAckRead(ctx); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) + throws Http2Exception { + logger.logSettings(INBOUND, ctx, settings); + listener.onSettingsRead(ctx, settings); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + logger.logPing(INBOUND, ctx, data); + listener.onPingRead(ctx, data); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + logger.logPingAck(INBOUND, ctx, data); + listener.onPingAckRead(ctx, data); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, + int promisedStreamId, Http2Headers headers, int padding) throws Http2Exception { + logger.logPushPromise(INBOUND, ctx, streamId, promisedStreamId, headers, padding); + listener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, + ByteBuf debugData) throws Http2Exception { + logger.logGoAway(INBOUND, ctx, lastStreamId, errorCode, debugData); + listener.onGoAwayRead(ctx, lastStreamId, errorCode, 
debugData); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + logger.logWindowsUpdate(INBOUND, ctx, streamId, windowSizeIncrement); + listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, + Http2Flags flags, ByteBuf payload) throws Http2Exception { + logger.logUnknownFrame(INBOUND, ctx, frameType, streamId, flags, payload); + listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); + } + }); + } + + @Override + public void close() { + reader.close(); + } + + @Override + public Configuration configuration() { + return reader.configuration(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2LifecycleManager.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2LifecycleManager.java new file mode 100644 index 0000000..2322708 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2LifecycleManager.java @@ -0,0 +1,98 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
package io.netty.handler.codec.http2;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPromise;
import io.netty.util.internal.UnstableApi;

/**
 * Manager for the life cycle of the HTTP/2 connection. Handles graceful shutdown of the channel,
 * closing only after all of the streams have closed.
 */
@UnstableApi
public interface Http2LifecycleManager {

    /**
     * Closes the local side of the {@code stream}. Depending on the {@code stream} state this may result in
     * {@code stream} being closed. See {@link #closeStream(Http2Stream, ChannelFuture)}.
     * @param stream the stream to be half closed.
     * @param future See {@link #closeStream(Http2Stream, ChannelFuture)}.
     */
    void closeStreamLocal(Http2Stream stream, ChannelFuture future);

    /**
     * Closes the remote side of the {@code stream}. Depending on the {@code stream} state this may result in
     * {@code stream} being closed. See {@link #closeStream(Http2Stream, ChannelFuture)}.
     * @param stream the stream to be half closed.
     * @param future See {@link #closeStream(Http2Stream, ChannelFuture)}.
     */
    void closeStreamRemote(Http2Stream stream, ChannelFuture future);

    /**
     * Closes and deactivates the given {@code stream}. A listener is also attached to {@code future} and upon
     * completion the underlying channel will be closed if {@link Http2Connection#numActiveStreams()} is 0.
     * @param stream the stream to be closed and deactivated.
     * @param future when completed if {@link Http2Connection#numActiveStreams()} is 0 then the underlying channel
     * will be closed.
     */
    void closeStream(Http2Stream stream, ChannelFuture future);

    /**
     * Ensure the stream identified by {@code streamId} is reset. If our local state does not indicate the stream has
     * been reset yet then a {@code RST_STREAM} will be sent to the peer. If our local state indicates the stream
     * has already been reset then the return status will indicate success without sending anything to the peer.
     * @param ctx The context used for communication and buffer allocation if necessary.
     * @param streamId The identifier of the stream to reset.
     * @param errorCode Justification as to why this stream is being reset. See {@link Http2Error}.
     * @param promise Used to indicate the return status of this operation.
     * @return Will be considered successful when the connection and stream state has been updated, and a
     * {@code RST_STREAM} frame has been sent to the peer. If the stream state has already been updated and a
     * {@code RST_STREAM} frame has been sent then the return status may indicate success immediately.
     */
    ChannelFuture resetStream(ChannelHandlerContext ctx, int streamId, long errorCode,
                              ChannelPromise promise);

    /**
     * Prevents the peer from creating streams and close the connection if {@code errorCode} is not
     * {@link Http2Error#NO_ERROR}. After this call the peer is not allowed to create any new streams and the local
     * endpoint will be limited to creating streams with {@code stream identifier <= lastStreamId}. This may result in
     * sending a {@code GO_AWAY} frame (assuming we have not already sent one with
     * {@code Last-Stream-ID <= lastStreamId}), or may just return success if a {@code GO_AWAY} has previously been
     * sent.
     * @param ctx The context used for communication and buffer allocation if necessary.
     * @param lastStreamId The last stream that the local endpoint is claiming it will accept.
     * @param errorCode The rationale as to why the connection is being closed. See {@link Http2Error}.
     * @param debugData For diagnostic purposes (carries no semantic value).
     * @param promise Used to indicate the return status of this operation.
     * @return Will be considered successful when the connection and stream state has been updated, and a
     * {@code GO_AWAY} frame has been sent to the peer. If the stream state has already been updated and a
     * {@code GO_AWAY} frame has been sent then the return status may indicate success immediately.
     */
    ChannelFuture goAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode,
                         ByteBuf debugData, ChannelPromise promise);

    /**
     * Processes the given error.
     *
     * @param ctx The context used for communication and buffer allocation if necessary.
     * @param outbound {@code true} if the error was caused by an outbound operation and so the corresponding
     * {@link ChannelPromise} was failed as well.
     * @param cause the error.
     */
    void onError(ChannelHandlerContext ctx, boolean outbound, Throwable cause);
}
package io.netty.handler.codec.http2;

import io.netty.buffer.ByteBuf;
import io.netty.util.internal.UnstableApi;

/**
 * A {@link Http2FlowController} for controlling the inbound flow of {@code DATA} frames from the remote endpoint.
 */
@UnstableApi
public interface Http2LocalFlowController extends Http2FlowController {
    /**
     * Sets the writer to be use for sending {@code WINDOW_UPDATE} frames. This must be called before any flow
     * controlled data is received.
     *
     * @param frameWriter the HTTP/2 frame writer.
     * @return this controller, to allow call chaining.
     */
    Http2LocalFlowController frameWriter(Http2FrameWriter frameWriter);

    /**
     * Receives an inbound {@code DATA} frame from the remote endpoint and applies flow control policies to it for both
     * the {@code stream} as well as the connection. If any flow control policies have been violated, an exception is
     * raised immediately, otherwise the frame is considered to have "passed" flow control.
     * <p>
     * If {@code stream} is {@code null} or closed, flow control should only be applied to the connection window and the
     * bytes are immediately consumed.
     *
     * @param stream the subject stream for the received frame. The connection stream object must not be used. If {@code
     * stream} is {@code null} or closed, flow control should only be applied to the connection window and the bytes are
     * immediately consumed.
     * @param data payload buffer for the frame.
     * @param padding additional bytes that should be added to obscure the true content size. Must be between 0 and
     * 256 (inclusive).
     * @param endOfStream Indicates whether this is the last frame to be sent from the remote endpoint for this stream.
     * @throws Http2Exception if any flow control errors are encountered.
     */
    void receiveFlowControlledFrame(Http2Stream stream, ByteBuf data, int padding,
                                    boolean endOfStream) throws Http2Exception;

    /**
     * Indicates that the application has consumed a number of bytes for the given stream and is therefore ready to
     * receive more data from the remote endpoint. The application must consume any bytes that it receives or the flow
     * control window will collapse. Consuming bytes enables the flow controller to send {@code WINDOW_UPDATE} to
     * restore a portion of the flow control window for the stream.
     * <p>
     * If {@code stream} is {@code null} or closed (i.e. {@link Http2Stream#state()} method returns {@link
     * Http2Stream.State#CLOSED}), calling this method has no effect.
     *
     * @param stream the stream for which window space should be freed. The connection stream object must not be used.
     * If {@code stream} is {@code null} or closed (i.e. {@link Http2Stream#state()} method returns {@link
     * Http2Stream.State#CLOSED}), calling this method has no effect.
     * @param numBytes the number of bytes to be returned to the flow control window.
     * @return true if a {@code WINDOW_UPDATE} was sent, false otherwise.
     * @throws Http2Exception if the number of bytes returned exceeds the {@link #unconsumedBytes(Http2Stream)} for the
     * stream.
     */
    boolean consumeBytes(Http2Stream stream, int numBytes) throws Http2Exception;

    /**
     * The number of bytes for the given stream that have been received but not yet consumed by the
     * application.
     *
     * @param stream the stream for which window space should be freed.
     * @return the number of unconsumed bytes for the stream.
     */
    int unconsumedBytes(Http2Stream stream);

    /**
     * Get the initial flow control window size for the given stream. This quantity is measured in number of bytes. Note
     * the unavailable window portion can be calculated by {@link #initialWindowSize()} - {@link
     * #windowSize(Http2Stream)}.
     *
     * @param stream the stream whose initial window size is queried.
     * @return the initial flow control window size, in bytes.
     */
    int initialWindowSize(Http2Stream stream);
}
+ */ +final class Http2MaxRstFrameDecoder extends DecoratingHttp2ConnectionDecoder { + private final int maxRstFramesPerWindow; + private final int secondsPerWindow; + + Http2MaxRstFrameDecoder(Http2ConnectionDecoder delegate, int maxRstFramesPerWindow, int secondsPerWindow) { + super(delegate); + this.maxRstFramesPerWindow = checkPositive(maxRstFramesPerWindow, "maxRstFramesPerWindow"); + this.secondsPerWindow = checkPositive(secondsPerWindow, "secondsPerWindow"); + } + + @Override + public void frameListener(Http2FrameListener listener) { + if (listener != null) { + super.frameListener(new Http2MaxRstFrameListener(listener, maxRstFramesPerWindow, secondsPerWindow)); + } else { + super.frameListener(null); + } + } + + @Override + public Http2FrameListener frameListener() { + Http2FrameListener frameListener = frameListener0(); + // Unwrap the original Http2FrameListener as we add this decoder under the hood. + if (frameListener instanceof Http2MaxRstFrameListener) { + return ((Http2MaxRstFrameListener) frameListener).listener; + } + return frameListener; + } + + // Package-private for testing + Http2FrameListener frameListener0() { + return super.frameListener(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MaxRstFrameListener.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MaxRstFrameListener.java new file mode 100644 index 0000000..ce8c2a8 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MaxRstFrameListener.java @@ -0,0 +1,60 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.concurrent.TimeUnit; + + +final class Http2MaxRstFrameListener extends Http2FrameListenerDecorator { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2MaxRstFrameListener.class); + private static final Http2Exception RST_FRAME_RATE_EXCEEDED = Http2Exception.newStatic(Http2Error.ENHANCE_YOUR_CALM, + "Maximum number of RST frames reached", + Http2Exception.ShutdownHint.HARD_SHUTDOWN, Http2MaxRstFrameListener.class, "onRstStreamRead(..)"); + + private final long nanosPerWindow; + private final int maxRstFramesPerWindow; + private long lastRstFrameNano = System.nanoTime(); + private int receivedRstInWindow; + + Http2MaxRstFrameListener(Http2FrameListener listener, int maxRstFramesPerWindow, int secondsPerWindow) { + super(listener); + this.maxRstFramesPerWindow = maxRstFramesPerWindow; + this.nanosPerWindow = TimeUnit.SECONDS.toNanos(secondsPerWindow); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + long currentNano = System.nanoTime(); + if (currentNano - lastRstFrameNano >= nanosPerWindow) { + lastRstFrameNano = currentNano; + receivedRstInWindow = 1; + } else { + receivedRstInWindow++; + if (receivedRstInWindow > maxRstFramesPerWindow) { + logger.debug("{} Maximum number {} of RST frames reached within {} 
seconds, " + + "closing connection with {} error", ctx.channel(), maxRstFramesPerWindow, + TimeUnit.NANOSECONDS.toSeconds(nanosPerWindow), RST_FRAME_RATE_EXCEEDED.error(), + RST_FRAME_RATE_EXCEEDED); + throw RST_FRAME_RATE_EXCEEDED; + } + } + super.onRstStreamRead(ctx, streamId, errorCode); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexActiveStreamsException.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexActiveStreamsException.java new file mode 100644 index 0000000..8a06f38 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexActiveStreamsException.java @@ -0,0 +1,33 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +/** + * {@link Exception} that can be used to wrap some {@link Throwable} and fire it through the pipeline. + * The {@link Http2MultiplexHandler} will unwrap the original {@link Throwable} and fire it to all its + * active {@link Http2StreamChannel}. 
+ */ +public final class Http2MultiplexActiveStreamsException extends Exception { + + public Http2MultiplexActiveStreamsException(Throwable cause) { + super(cause); + } + + @Override + public Throwable fillInStackTrace() { + return this; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java new file mode 100644 index 0000000..b2abea9 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodec.java @@ -0,0 +1,346 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.EventLoop; +import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.handler.ssl.SslCloseCompletionEvent; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayDeque; +import java.util.Queue; + +import static io.netty.handler.codec.http2.AbstractHttp2StreamChannel.CHANNEL_INPUT_SHUTDOWN_READ_COMPLETE_VISITOR; +import static io.netty.handler.codec.http2.AbstractHttp2StreamChannel.CHANNEL_OUTPUT_SHUTDOWN_EVENT_VISITOR; +import static io.netty.handler.codec.http2.AbstractHttp2StreamChannel.SSL_CLOSE_COMPLETION_EVENT_VISITOR; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_STREAM_ID; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; + +/** + * An HTTP/2 handler that creates child channels for each stream. + * + *

When a new stream is created, a new {@link Channel} is created for it. Applications send and + * receive {@link Http2StreamFrame}s on the created channel. {@link ByteBuf}s cannot be processed by the channel; + * all writes that reach the head of the pipeline must be an instance of {@link Http2StreamFrame}. Writes that reach + * the head of the pipeline are processed directly by this handler and cannot be intercepted. + * + *

The child channel will be notified of user events that impact the stream, such as {@link + * Http2GoAwayFrame} and {@link Http2ResetFrame}, as soon as they occur. Although {@code + * Http2GoAwayFrame} and {@code Http2ResetFrame} signify that the remote is ignoring further + * communication, closing of the channel is delayed until any inbound queue is drained with {@link + * Channel#read()}, which follows the default behavior of channels in Netty. Applications are + * free to close the channel in response to such events if they don't have use for any queued + * messages. Any connection level events like {@link Http2SettingsFrame} and {@link Http2GoAwayFrame} + * will be processed internally and also propagated down the pipeline for other handlers to act on. + * + *

Outbound streams are supported via the {@link Http2StreamChannelBootstrap}. + * + *

{@link ChannelConfig#setMaxMessagesPerRead(int)} and {@link ChannelConfig#setAutoRead(boolean)} are supported. + * + *

Reference Counting

+ * + * Some {@link Http2StreamFrame}s implement the {@link ReferenceCounted} interface, as they carry + * reference counted objects (e.g. {@link ByteBuf}s). The multiplex codec will call {@link ReferenceCounted#retain()} + * before propagating a reference counted object through the pipeline, and thus an application handler needs to release + * such an object after having consumed it. For more information on reference counting take a look at + * https://netty.io/wiki/reference-counted-objects.html + * + *

Channel Events

+ * + * A child channel becomes active as soon as it is registered to an {@link EventLoop}. Therefore, an active channel + * does not map to an active HTTP/2 stream immediately. Only once a {@link Http2HeadersFrame} has been successfully sent + * or received, does the channel map to an active HTTP/2 stream. In case it is not possible to open a new HTTP/2 stream + * (i.e. due to the maximum number of active streams being exceeded), the child channel receives an exception + * indicating the cause and is closed immediately thereafter. + * + *

Writability and Flow Control

+ * + * A child channel observes outbound/remote flow control via the channel's writability. A channel only becomes writable + * when it maps to an active HTTP/2 stream and the stream's flow control window is greater than zero. A child channel + * does not know about the connection-level flow control window. {@link ChannelHandler}s are free to ignore the + * channel's writability, in which case the excessive writes will be buffered by the parent channel. It's important to + * note that only {@link Http2DataFrame}s are subject to HTTP/2 flow control. + * + * @deprecated use {@link Http2FrameCodecBuilder} together with {@link Http2MultiplexHandler}. + */ +@Deprecated +@UnstableApi +public class Http2MultiplexCodec extends Http2FrameCodec { + + private final ChannelHandler inboundStreamHandler; + private final ChannelHandler upgradeStreamHandler; + private final Queue readCompletePendingQueue = + new MaxCapacityQueue(new ArrayDeque(8), + // Choose 100 which is what is used most of the times as default. + Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS); + + private boolean parentReadInProgress; + private int idCount; + + // Need to be volatile as accessed from within the Http2MultiplexCodecStreamChannel in a multi-threaded fashion. 
+ volatile ChannelHandlerContext ctx; + + Http2MultiplexCodec(Http2ConnectionEncoder encoder, + Http2ConnectionDecoder decoder, + Http2Settings initialSettings, + ChannelHandler inboundStreamHandler, + ChannelHandler upgradeStreamHandler, boolean decoupleCloseAndGoAway, boolean flushPreface) { + super(encoder, decoder, initialSettings, decoupleCloseAndGoAway, flushPreface); + this.inboundStreamHandler = inboundStreamHandler; + this.upgradeStreamHandler = upgradeStreamHandler; + } + + @Override + public void onHttpClientUpgrade() throws Http2Exception { + // We must have an upgrade handler or else we can't handle the stream + if (upgradeStreamHandler == null) { + throw connectionError(INTERNAL_ERROR, "Client is misconfigured for upgrade requests"); + } + // Creates the Http2Stream in the Connection. + super.onHttpClientUpgrade(); + } + + @Override + public final void handlerAdded0(ChannelHandlerContext ctx) throws Exception { + if (ctx.executor() != ctx.channel().eventLoop()) { + throw new IllegalStateException("EventExecutor must be EventLoop of Channel"); + } + this.ctx = ctx; + } + + @Override + public final void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved0(ctx); + + readCompletePendingQueue.clear(); + } + + @Override + final void onHttp2Frame(ChannelHandlerContext ctx, Http2Frame frame) { + if (frame instanceof Http2StreamFrame) { + Http2StreamFrame streamFrame = (Http2StreamFrame) frame; + AbstractHttp2StreamChannel channel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) streamFrame.stream()).attachment; + channel.fireChildRead(streamFrame); + return; + } + if (frame instanceof Http2GoAwayFrame) { + onHttp2GoAwayFrame(ctx, (Http2GoAwayFrame) frame); + } + // Send frames down the pipeline + ctx.fireChannelRead(frame); + } + + @Override + final void onHttp2StreamStateChanged(ChannelHandlerContext ctx, DefaultHttp2FrameStream stream) { + switch (stream.state()) { + case HALF_CLOSED_LOCAL: + if (stream.id() != 
HTTP_UPGRADE_STREAM_ID) { + // Ignore everything which was not caused by an upgrade + break; + } + // fall-through + case HALF_CLOSED_REMOTE: + // fall-through + case OPEN: + if (stream.attachment != null) { + // ignore if child channel was already created. + break; + } + final Http2MultiplexCodecStreamChannel streamChannel; + // We need to handle upgrades special when on the client side. + if (stream.id() == HTTP_UPGRADE_STREAM_ID && !connection().isServer()) { + // Add our upgrade handler to the channel and then register the channel. + // The register call fires the channelActive, etc. + assert upgradeStreamHandler != null; + streamChannel = new Http2MultiplexCodecStreamChannel(stream, upgradeStreamHandler); + streamChannel.closeOutbound(); + } else { + streamChannel = new Http2MultiplexCodecStreamChannel(stream, inboundStreamHandler); + } + ChannelFuture future = ctx.channel().eventLoop().register(streamChannel); + if (future.isDone()) { + Http2MultiplexHandler.registerDone(future); + } else { + future.addListener(Http2MultiplexHandler.CHILD_CHANNEL_REGISTRATION_LISTENER); + } + break; + case CLOSED: + AbstractHttp2StreamChannel channel = (AbstractHttp2StreamChannel) stream.attachment; + if (channel != null) { + channel.streamClosed(); + } + break; + default: + // ignore for now + break; + } + } + + // TODO: This is most likely not the best way to expose this, need to think more about it. + final Http2StreamChannel newOutboundStream() { + return new Http2MultiplexCodecStreamChannel(newStream(), null); + } + + @Override + final void onHttp2FrameStreamException(ChannelHandlerContext ctx, Http2FrameStreamException cause) { + Http2FrameStream stream = cause.stream(); + AbstractHttp2StreamChannel channel = (AbstractHttp2StreamChannel) ((DefaultHttp2FrameStream) stream).attachment; + + try { + channel.pipeline().fireExceptionCaught(cause.getCause()); + } finally { + // Close with the correct error that causes this stream exception. 
+ // See https://github.com/netty/netty/issues/13235#issuecomment-1441994672 + channel.closeWithError(cause.error()); + } + } + + private void onHttp2GoAwayFrame(ChannelHandlerContext ctx, final Http2GoAwayFrame goAwayFrame) { + if (goAwayFrame.lastStreamId() == Integer.MAX_VALUE) { + // None of the streams can have an id greater than Integer.MAX_VALUE + return; + } + // Notify which streams were not processed by the remote peer and are safe to retry on another connection: + try { + forEachActiveStream(new Http2FrameStreamVisitor() { + @Override + public boolean visit(Http2FrameStream stream) { + final int streamId = stream.id(); + AbstractHttp2StreamChannel channel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + if (streamId > goAwayFrame.lastStreamId() && connection().local().isValidStreamId(streamId)) { + channel.pipeline().fireUserEventTriggered(goAwayFrame.retainedDuplicate()); + } + return true; + } + }); + } catch (Http2Exception e) { + ctx.fireExceptionCaught(e); + ctx.close(); + } + } + + /** + * Notifies any child streams of the read completion. + */ + @Override + public final void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + processPendingReadCompleteQueue(); + channelReadComplete0(ctx); + } + + private void processPendingReadCompleteQueue() { + parentReadInProgress = true; + try { + // If we have many child channel we can optimize for the case when multiple call flush() in + // channelReadComplete(...) callbacks and only do it once as otherwise we will end-up with multiple + // write calls on the socket which is expensive. + for (;;) { + AbstractHttp2StreamChannel childChannel = readCompletePendingQueue.poll(); + if (childChannel == null) { + break; + } + childChannel.fireChildReadComplete(); + } + } finally { + parentReadInProgress = false; + readCompletePendingQueue.clear(); + // We always flush as this is what Http2ConnectionHandler does for now. 
+ flush0(ctx); + } + } + @Override + public final void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + parentReadInProgress = true; + super.channelRead(ctx, msg); + } + + @Override + public final void channelWritabilityChanged(final ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isWritable()) { + // While the writability state may change during iterating of the streams we just set all of the streams + // to writable to not affect fairness. These will be "limited" by their own watermarks in any case. + forEachActiveStream(AbstractHttp2StreamChannel.WRITABLE_VISITOR); + } + + super.channelWritabilityChanged(ctx); + } + + @Override + final void onUserEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == ChannelInputShutdownReadComplete.INSTANCE) { + forEachActiveStream(CHANNEL_INPUT_SHUTDOWN_READ_COMPLETE_VISITOR); + } else if (evt == ChannelOutputShutdownEvent.INSTANCE) { + forEachActiveStream(CHANNEL_OUTPUT_SHUTDOWN_EVENT_VISITOR); + } else if (evt == SslCloseCompletionEvent.SUCCESS) { + forEachActiveStream(SSL_CLOSE_COMPLETION_EVENT_VISITOR); + } + super.onUserEventTriggered(ctx, evt); + } + + final void flush0(ChannelHandlerContext ctx) { + flush(ctx); + } + + private final class Http2MultiplexCodecStreamChannel extends AbstractHttp2StreamChannel { + + Http2MultiplexCodecStreamChannel(DefaultHttp2FrameStream stream, ChannelHandler inboundHandler) { + super(stream, ++idCount, inboundHandler); + } + + @Override + protected boolean isParentReadInProgress() { + return parentReadInProgress; + } + + @Override + protected void addChannelToReadCompletePendingQueue() { + // If there is no space left in the queue, just keep on processing everything that is already + // stored there and try again. 
+ while (!readCompletePendingQueue.offer(this)) { + processPendingReadCompleteQueue(); + } + } + + @Override + protected ChannelHandlerContext parentContext() { + return ctx; + } + + @Override + protected ChannelFuture write0(ChannelHandlerContext ctx, Object msg) { + ChannelPromise promise = ctx.newPromise(); + Http2MultiplexCodec.this.write(ctx, msg, promise); + return promise; + } + + @Override + protected void flush0(ChannelHandlerContext ctx) { + Http2MultiplexCodec.this.flush0(ctx); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java new file mode 100644 index 0000000..927c577 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilder.java @@ -0,0 +1,260 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.util.internal.UnstableApi; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A builder for {@link Http2MultiplexCodec}. + * + * @deprecated use {@link Http2FrameCodecBuilder} together with {@link Http2MultiplexHandler}. 
+ */ +@Deprecated +@UnstableApi +public class Http2MultiplexCodecBuilder + extends AbstractHttp2ConnectionHandlerBuilder { + private Http2FrameWriter frameWriter; + + final ChannelHandler childHandler; + private ChannelHandler upgradeStreamHandler; + + Http2MultiplexCodecBuilder(boolean server, ChannelHandler childHandler) { + server(server); + this.childHandler = checkSharable(checkNotNull(childHandler, "childHandler")); + // For backwards compatibility we should disable to timeout by default at this layer. + gracefulShutdownTimeoutMillis(0); + } + + private static ChannelHandler checkSharable(ChannelHandler handler) { + if (handler instanceof ChannelHandlerAdapter && !((ChannelHandlerAdapter) handler).isSharable() && + !handler.getClass().isAnnotationPresent(ChannelHandler.Sharable.class)) { + throw new IllegalArgumentException("The handler must be Sharable"); + } + return handler; + } + + // For testing only. + Http2MultiplexCodecBuilder frameWriter(Http2FrameWriter frameWriter) { + this.frameWriter = checkNotNull(frameWriter, "frameWriter"); + return this; + } + + /** + * Creates a builder for an HTTP/2 client. + * + * @param childHandler the handler added to channels for remotely-created streams. It must be + * {@link ChannelHandler.Sharable}. + */ + public static Http2MultiplexCodecBuilder forClient(ChannelHandler childHandler) { + return new Http2MultiplexCodecBuilder(false, childHandler); + } + + /** + * Creates a builder for an HTTP/2 server. + * + * @param childHandler the handler added to channels for remotely-created streams. It must be + * {@link ChannelHandler.Sharable}. 
+ */ + public static Http2MultiplexCodecBuilder forServer(ChannelHandler childHandler) { + return new Http2MultiplexCodecBuilder(true, childHandler); + } + + public Http2MultiplexCodecBuilder withUpgradeStreamHandler(ChannelHandler upgradeStreamHandler) { + if (isServer()) { + throw new IllegalArgumentException("Server codecs don't use an extra handler for the upgrade stream"); + } + this.upgradeStreamHandler = upgradeStreamHandler; + return this; + } + + @Override + public Http2Settings initialSettings() { + return super.initialSettings(); + } + + @Override + public Http2MultiplexCodecBuilder initialSettings(Http2Settings settings) { + return super.initialSettings(settings); + } + + @Override + public long gracefulShutdownTimeoutMillis() { + return super.gracefulShutdownTimeoutMillis(); + } + + @Override + public Http2MultiplexCodecBuilder gracefulShutdownTimeoutMillis(long gracefulShutdownTimeoutMillis) { + return super.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis); + } + + @Override + public boolean isServer() { + return super.isServer(); + } + + @Override + public int maxReservedStreams() { + return super.maxReservedStreams(); + } + + @Override + public Http2MultiplexCodecBuilder maxReservedStreams(int maxReservedStreams) { + return super.maxReservedStreams(maxReservedStreams); + } + + @Override + public boolean isValidateHeaders() { + return super.isValidateHeaders(); + } + + @Override + public Http2MultiplexCodecBuilder validateHeaders(boolean validateHeaders) { + return super.validateHeaders(validateHeaders); + } + + @Override + public Http2FrameLogger frameLogger() { + return super.frameLogger(); + } + + @Override + public Http2MultiplexCodecBuilder frameLogger(Http2FrameLogger frameLogger) { + return super.frameLogger(frameLogger); + } + + @Override + public boolean encoderEnforceMaxConcurrentStreams() { + return super.encoderEnforceMaxConcurrentStreams(); + } + + @Override + public Http2MultiplexCodecBuilder 
encoderEnforceMaxConcurrentStreams(boolean encoderEnforceMaxConcurrentStreams) { + return super.encoderEnforceMaxConcurrentStreams(encoderEnforceMaxConcurrentStreams); + } + + @Override + public int encoderEnforceMaxQueuedControlFrames() { + return super.encoderEnforceMaxQueuedControlFrames(); + } + + @Override + public Http2MultiplexCodecBuilder encoderEnforceMaxQueuedControlFrames(int maxQueuedControlFrames) { + return super.encoderEnforceMaxQueuedControlFrames(maxQueuedControlFrames); + } + + @Override + public Http2HeadersEncoder.SensitivityDetector headerSensitivityDetector() { + return super.headerSensitivityDetector(); + } + + @Override + public Http2MultiplexCodecBuilder headerSensitivityDetector( + Http2HeadersEncoder.SensitivityDetector headerSensitivityDetector) { + return super.headerSensitivityDetector(headerSensitivityDetector); + } + + @Override + public Http2MultiplexCodecBuilder encoderIgnoreMaxHeaderListSize(boolean ignoreMaxHeaderListSize) { + return super.encoderIgnoreMaxHeaderListSize(ignoreMaxHeaderListSize); + } + + @Override + @Deprecated + public Http2MultiplexCodecBuilder initialHuffmanDecodeCapacity(int initialHuffmanDecodeCapacity) { + return super.initialHuffmanDecodeCapacity(initialHuffmanDecodeCapacity); + } + + @Override + public Http2MultiplexCodecBuilder autoAckSettingsFrame(boolean autoAckSettings) { + return super.autoAckSettingsFrame(autoAckSettings); + } + + @Override + public Http2MultiplexCodecBuilder autoAckPingFrame(boolean autoAckPingFrame) { + return super.autoAckPingFrame(autoAckPingFrame); + } + + @Override + public Http2MultiplexCodecBuilder decoupleCloseAndGoAway(boolean decoupleCloseAndGoAway) { + return super.decoupleCloseAndGoAway(decoupleCloseAndGoAway); + } + + @Override + public Http2MultiplexCodecBuilder flushPreface(boolean flushPreface) { + return super.flushPreface(flushPreface); + } + + @Override + public int decoderEnforceMaxConsecutiveEmptyDataFrames() { + return 
super.decoderEnforceMaxConsecutiveEmptyDataFrames(); + } + + @Override + public Http2MultiplexCodecBuilder decoderEnforceMaxConsecutiveEmptyDataFrames(int maxConsecutiveEmptyFrames) { + return super.decoderEnforceMaxConsecutiveEmptyDataFrames(maxConsecutiveEmptyFrames); + } + + @Override + public Http2MultiplexCodecBuilder decoderEnforceMaxRstFramesPerWindow( + int maxRstFramesPerWindow, int secondsPerWindow) { + return super.decoderEnforceMaxRstFramesPerWindow(maxRstFramesPerWindow, secondsPerWindow); + } + + @Override + public Http2MultiplexCodec build() { + Http2FrameWriter frameWriter = this.frameWriter; + if (frameWriter != null) { + // This is to support our tests and will never be executed by the user as frameWriter(...) + // is package-private. + DefaultHttp2Connection connection = new DefaultHttp2Connection(isServer(), maxReservedStreams()); + Long maxHeaderListSize = initialSettings().maxHeaderListSize(); + Http2FrameReader frameReader = new DefaultHttp2FrameReader(maxHeaderListSize == null ? 
+ new DefaultHttp2HeadersDecoder(isValidateHeaders()) : + new DefaultHttp2HeadersDecoder(isValidateHeaders(), maxHeaderListSize)); + + if (frameLogger() != null) { + frameWriter = new Http2OutboundFrameLogger(frameWriter, frameLogger()); + frameReader = new Http2InboundFrameLogger(frameReader, frameLogger()); + } + Http2ConnectionEncoder encoder = new DefaultHttp2ConnectionEncoder(connection, frameWriter); + if (encoderEnforceMaxConcurrentStreams()) { + encoder = new StreamBufferingEncoder(encoder); + } + Http2ConnectionDecoder decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, frameReader, + promisedRequestVerifier(), isAutoAckSettingsFrame(), isAutoAckPingFrame(), isValidateHeaders()); + int maxConsecutiveEmptyDataFrames = decoderEnforceMaxConsecutiveEmptyDataFrames(); + if (maxConsecutiveEmptyDataFrames > 0) { + decoder = new Http2EmptyDataFrameConnectionDecoder(decoder, maxConsecutiveEmptyDataFrames); + } + + return build(decoder, encoder, initialSettings()); + } + return super.build(); + } + + @Override + protected Http2MultiplexCodec build( + Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, Http2Settings initialSettings) { + Http2MultiplexCodec codec = new Http2MultiplexCodec(encoder, decoder, initialSettings, childHandler, + upgradeStreamHandler, decoupleCloseAndGoAway(), flushPreface()); + codec.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis()); + return codec; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java new file mode 100644 index 0000000..5a6af8a --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2MultiplexHandler.java @@ -0,0 +1,415 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this 
file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoop; +import io.netty.channel.ServerChannel; +import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.handler.codec.http2.Http2FrameCodec.DefaultHttp2FrameStream; +import io.netty.handler.ssl.SslCloseCompletionEvent; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayDeque; +import java.util.Queue; +import javax.net.ssl.SSLException; + +import static io.netty.handler.codec.http2.AbstractHttp2StreamChannel.CHANNEL_INPUT_SHUTDOWN_READ_COMPLETE_VISITOR; +import static io.netty.handler.codec.http2.AbstractHttp2StreamChannel.CHANNEL_OUTPUT_SHUTDOWN_EVENT_VISITOR; +import static io.netty.handler.codec.http2.AbstractHttp2StreamChannel.SSL_CLOSE_COMPLETION_EVENT_VISITOR; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; + +/** + * An HTTP/2 handler that creates child channels for each stream. 
This handler must be used in combination + * with {@link Http2FrameCodec}. + * + *

When a new stream is created, a new {@link Http2StreamChannel} is created for it. Applications send and + * receive {@link Http2StreamFrame}s on the created channel. {@link ByteBuf}s cannot be processed by the channel; + * all writes that reach the head of the pipeline must be an instance of {@link Http2StreamFrame}. Writes that reach + * the head of the pipeline are processed directly by this handler and cannot be intercepted. + * + *

The child channel will be notified of user events that impact the stream, such as {@link + * Http2GoAwayFrame} and {@link Http2ResetFrame}, as soon as they occur. Although {@code + * Http2GoAwayFrame} and {@code Http2ResetFrame} signify that the remote is ignoring further + * communication, closing of the channel is delayed until any inbound queue is drained with {@link + * Channel#read()}, which follows the default behavior of channels in Netty. Applications are + * free to close the channel in response to such events if they don't have use for any queued + * messages. Any connection level events like {@link Http2SettingsFrame} and {@link Http2GoAwayFrame} + * will be processed internally and also propagated down the pipeline for other handlers to act on. + * + *

Outbound streams are supported via the {@link Http2StreamChannelBootstrap}. + * + *

{@link ChannelConfig#setMaxMessagesPerRead(int)} and {@link ChannelConfig#setAutoRead(boolean)} are supported. + * + *

Reference Counting

+ * + * Some {@link Http2StreamFrame}s implement the {@link ReferenceCounted} interface, as they carry + * reference counted objects (e.g. {@link ByteBuf}s). The multiplex codec will call {@link ReferenceCounted#retain()} + * before propagating a reference counted object through the pipeline, and thus an application handler needs to release + * such an object after having consumed it. For more information on reference counting take a look at + * the reference counted docs. + * + *

Channel Events

+ * + * A child channel becomes active as soon as it is registered to an {@link EventLoop}. Therefore, an active channel + * does not map to an active HTTP/2 stream immediately. Only once a {@link Http2HeadersFrame} has been successfully sent + * or received, does the channel map to an active HTTP/2 stream. In case it is not possible to open a new HTTP/2 stream + * (i.e. due to the maximum number of active streams being exceeded), the child channel receives an exception + * indicating the cause and is closed immediately thereafter. + * + *

Writability and Flow Control


+ *
+ * A child channel observes outbound/remote flow control via the channel's writability. A channel only becomes writable
+ * when it maps to an active HTTP/2 stream. A child channel does not know about the connection-level flow control
+ * window. {@link ChannelHandler}s are free to ignore the channel's writability, in which case the excessive writes will
+ * be buffered by the parent channel. It's important to note that only {@link Http2DataFrame}s are subject to
+ * HTTP/2 flow control.
+ *
+ * 

Closing a {@link Http2StreamChannel}

+ * + * Once you close a {@link Http2StreamChannel} a {@link Http2ResetFrame} will be sent to the remote peer with + * {@link Http2Error#CANCEL} if needed. If you want to close the stream with another {@link Http2Error} (due + * errors / limits) you should propagate a {@link Http2FrameStreamException} through the {@link ChannelPipeline}. + * Once it reaches the end of the {@link ChannelPipeline} it will automatically close the {@link Http2StreamChannel} + * and send a {@link Http2ResetFrame} with the unwrapped {@link Http2Error} set. Another possibility is to just + * directly write a {@link Http2ResetFrame} to the {@link Http2StreamChannel}l. + */ +@UnstableApi +public final class Http2MultiplexHandler extends Http2ChannelDuplexHandler { + + static final ChannelFutureListener CHILD_CHANNEL_REGISTRATION_LISTENER = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + registerDone(future); + } + }; + + private final ChannelHandler inboundStreamHandler; + private final ChannelHandler upgradeStreamHandler; + private final Queue readCompletePendingQueue = + new MaxCapacityQueue(new ArrayDeque(8), + // Choose 100 which is what is used most of the times as default. + Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS); + + private boolean parentReadInProgress; + private int idCount; + + // Need to be volatile as accessed from within the Http2MultiplexHandlerStreamChannel in a multi-threaded fashion. + private volatile ChannelHandlerContext ctx; + + /** + * Creates a new instance + * + * @param inboundStreamHandler the {@link ChannelHandler} that will be added to the {@link ChannelPipeline} of + * the {@link Channel}s created for new inbound streams. 
+ */ + public Http2MultiplexHandler(ChannelHandler inboundStreamHandler) { + this(inboundStreamHandler, null); + } + + /** + * Creates a new instance + * + * @param inboundStreamHandler the {@link ChannelHandler} that will be added to the {@link ChannelPipeline} of + * the {@link Channel}s created for new inbound streams. + * @param upgradeStreamHandler the {@link ChannelHandler} that will be added to the {@link ChannelPipeline} of the + * upgraded {@link Channel}. + */ + public Http2MultiplexHandler(ChannelHandler inboundStreamHandler, ChannelHandler upgradeStreamHandler) { + this.inboundStreamHandler = ObjectUtil.checkNotNull(inboundStreamHandler, "inboundStreamHandler"); + this.upgradeStreamHandler = upgradeStreamHandler; + } + + static void registerDone(ChannelFuture future) { + // Handle any errors that occurred on the local thread while registering. Even though + // failures can happen after this point, they will be handled by the channel by closing the + // childChannel. + if (!future.isSuccess()) { + Channel childChannel = future.channel(); + if (childChannel.isRegistered()) { + childChannel.close(); + } else { + childChannel.unsafe().closeForcibly(); + } + } + } + + @Override + protected void handlerAdded0(ChannelHandlerContext ctx) { + if (ctx.executor() != ctx.channel().eventLoop()) { + throw new IllegalStateException("EventExecutor must be EventLoop of Channel"); + } + this.ctx = ctx; + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) { + readCompletePendingQueue.clear(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + parentReadInProgress = true; + if (msg instanceof Http2StreamFrame) { + if (msg instanceof Http2WindowUpdateFrame) { + // We dont want to propagate update frames to the user + return; + } + Http2StreamFrame streamFrame = (Http2StreamFrame) msg; + DefaultHttp2FrameStream s = + (DefaultHttp2FrameStream) streamFrame.stream(); + + AbstractHttp2StreamChannel 
channel = (AbstractHttp2StreamChannel) s.attachment; + if (msg instanceof Http2ResetFrame) { + // Reset frames needs to be propagated via user events as these are not flow-controlled and so + // must not be controlled by suppressing channel.read() on the child channel. + channel.pipeline().fireUserEventTriggered(msg); + + // RST frames will also trigger closing of the streams which then will call + // AbstractHttp2StreamChannel.streamClosed() + } else { + channel.fireChildRead(streamFrame); + } + return; + } + + if (msg instanceof Http2GoAwayFrame) { + // goaway frames will also trigger closing of the streams which then will call + // AbstractHttp2StreamChannel.streamClosed() + onHttp2GoAwayFrame(ctx, (Http2GoAwayFrame) msg); + } + + // Send everything down the pipeline + ctx.fireChannelRead(msg); + } + + @Override + public void channelWritabilityChanged(final ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isWritable()) { + // While the writability state may change during iterating of the streams we just set all of the streams + // to writable to not affect fairness. These will be "limited" by their own watermarks in any case. + forEachActiveStream(AbstractHttp2StreamChannel.WRITABLE_VISITOR); + } + + ctx.fireChannelWritabilityChanged(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof Http2FrameStreamEvent) { + Http2FrameStreamEvent event = (Http2FrameStreamEvent) evt; + DefaultHttp2FrameStream stream = (DefaultHttp2FrameStream) event.stream(); + if (event.type() == Http2FrameStreamEvent.Type.State) { + switch (stream.state()) { + case HALF_CLOSED_LOCAL: + if (stream.id() != Http2CodecUtil.HTTP_UPGRADE_STREAM_ID) { + // Ignore everything which was not caused by an upgrade + break; + } + // fall-through + case HALF_CLOSED_REMOTE: + // fall-through + case OPEN: + if (stream.attachment != null) { + // ignore if child channel was already created. 
+ break; + } + final AbstractHttp2StreamChannel ch; + // We need to handle upgrades special when on the client side. + if (stream.id() == Http2CodecUtil.HTTP_UPGRADE_STREAM_ID && !isServer(ctx)) { + // We must have an upgrade handler or else we can't handle the stream + if (upgradeStreamHandler == null) { + throw connectionError(INTERNAL_ERROR, + "Client is misconfigured for upgrade requests"); + } + ch = new Http2MultiplexHandlerStreamChannel(stream, upgradeStreamHandler); + ch.closeOutbound(); + } else { + ch = new Http2MultiplexHandlerStreamChannel(stream, inboundStreamHandler); + } + ChannelFuture future = ctx.channel().eventLoop().register(ch); + if (future.isDone()) { + registerDone(future); + } else { + future.addListener(CHILD_CHANNEL_REGISTRATION_LISTENER); + } + break; + case CLOSED: + AbstractHttp2StreamChannel channel = (AbstractHttp2StreamChannel) stream.attachment; + if (channel != null) { + channel.streamClosed(); + } + break; + default: + // ignore for now + break; + } + } + return; + } + if (evt == ChannelInputShutdownReadComplete.INSTANCE) { + forEachActiveStream(CHANNEL_INPUT_SHUTDOWN_READ_COMPLETE_VISITOR); + } else if (evt == ChannelOutputShutdownEvent.INSTANCE) { + forEachActiveStream(CHANNEL_OUTPUT_SHUTDOWN_EVENT_VISITOR); + } else if (evt == SslCloseCompletionEvent.SUCCESS) { + forEachActiveStream(SSL_CLOSE_COMPLETION_EVENT_VISITOR); + } + ctx.fireUserEventTriggered(evt); + } + + // TODO: This is most likely not the best way to expose this, need to think more about it. 
+ Http2StreamChannel newOutboundStream() { + return new Http2MultiplexHandlerStreamChannel((DefaultHttp2FrameStream) newStream(), null); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, final Throwable cause) throws Exception { + if (cause instanceof Http2FrameStreamException) { + Http2FrameStreamException exception = (Http2FrameStreamException) cause; + Http2FrameStream stream = exception.stream(); + AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + try { + childChannel.pipeline().fireExceptionCaught(cause.getCause()); + } finally { + // Close with the correct error that causes this stream exception. + // See https://github.com/netty/netty/issues/13235#issuecomment-1441994672 + childChannel.closeWithError(exception.error()); + } + return; + } + if (cause instanceof Http2MultiplexActiveStreamsException) { + // Unwrap the cause that was used to create it and fire it for all the active streams. 
+ fireExceptionCaughtForActiveStream(cause.getCause()); + return; + } + + if (cause.getCause() instanceof SSLException) { + fireExceptionCaughtForActiveStream(cause); + } + ctx.fireExceptionCaught(cause); + } + + private void fireExceptionCaughtForActiveStream(final Throwable cause) throws Http2Exception { + forEachActiveStream(new Http2FrameStreamVisitor() { + @Override + public boolean visit(Http2FrameStream stream) { + AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + childChannel.pipeline().fireExceptionCaught(cause); + return true; + } + }); + } + + private static boolean isServer(ChannelHandlerContext ctx) { + return ctx.channel().parent() instanceof ServerChannel; + } + + private void onHttp2GoAwayFrame(ChannelHandlerContext ctx, final Http2GoAwayFrame goAwayFrame) { + if (goAwayFrame.lastStreamId() == Integer.MAX_VALUE) { + // None of the streams can have an id greater than Integer.MAX_VALUE + return; + } + // Notify which streams were not processed by the remote peer and are safe to retry on another connection: + try { + final boolean server = isServer(ctx); + forEachActiveStream(new Http2FrameStreamVisitor() { + @Override + public boolean visit(Http2FrameStream stream) { + final int streamId = stream.id(); + if (streamId > goAwayFrame.lastStreamId() && Http2CodecUtil.isStreamIdValid(streamId, server)) { + final AbstractHttp2StreamChannel childChannel = (AbstractHttp2StreamChannel) + ((DefaultHttp2FrameStream) stream).attachment; + childChannel.pipeline().fireUserEventTriggered(goAwayFrame.retainedDuplicate()); + } + return true; + } + }); + } catch (Http2Exception e) { + ctx.fireExceptionCaught(e); + ctx.close(); + } + } + + /** + * Notifies any child streams of the read completion. 
+ */ + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + processPendingReadCompleteQueue(); + ctx.fireChannelReadComplete(); + } + + private void processPendingReadCompleteQueue() { + parentReadInProgress = true; + // If we have many child channel we can optimize for the case when multiple call flush() in + // channelReadComplete(...) callbacks and only do it once as otherwise we will end-up with multiple + // write calls on the socket which is expensive. + AbstractHttp2StreamChannel childChannel = readCompletePendingQueue.poll(); + if (childChannel != null) { + try { + do { + childChannel.fireChildReadComplete(); + childChannel = readCompletePendingQueue.poll(); + } while (childChannel != null); + } finally { + parentReadInProgress = false; + readCompletePendingQueue.clear(); + ctx.flush(); + } + } else { + parentReadInProgress = false; + } + } + + private final class Http2MultiplexHandlerStreamChannel extends AbstractHttp2StreamChannel { + + Http2MultiplexHandlerStreamChannel(DefaultHttp2FrameStream stream, ChannelHandler inboundHandler) { + super(stream, ++idCount, inboundHandler); + } + + @Override + protected boolean isParentReadInProgress() { + return parentReadInProgress; + } + + @Override + protected void addChannelToReadCompletePendingQueue() { + // If there is no space left in the queue, just keep on processing everything that is already + // stored there and try again. 
+ while (!readCompletePendingQueue.offer(this)) { + processPendingReadCompleteQueue(); + } + } + + @Override + protected ChannelHandlerContext parentContext() { + return ctx; + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2NoMoreStreamIdsException.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2NoMoreStreamIdsException.java new file mode 100644 index 0000000..121ae96 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2NoMoreStreamIdsException.java @@ -0,0 +1,36 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; + +/** + * This exception is thrown when there are no more stream IDs available for the current connection + */ +@UnstableApi +public class Http2NoMoreStreamIdsException extends Http2Exception { + private static final long serialVersionUID = -7756236161274851110L; + private static final String ERROR_MESSAGE = "No more streams can be created on this connection"; + + public Http2NoMoreStreamIdsException() { + super(PROTOCOL_ERROR, ERROR_MESSAGE, ShutdownHint.GRACEFUL_SHUTDOWN); + } + + public Http2NoMoreStreamIdsException(Throwable cause) { + super(PROTOCOL_ERROR, ERROR_MESSAGE, cause, ShutdownHint.GRACEFUL_SHUTDOWN); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2OutboundFrameLogger.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2OutboundFrameLogger.java new file mode 100644 index 0000000..076985f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2OutboundFrameLogger.java @@ -0,0 +1,139 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import static io.netty.handler.codec.http2.Http2FrameLogger.Direction.OUTBOUND; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.UnstableApi; + +/** + * Decorator around a {@link Http2FrameWriter} that logs all outbound frames before calling the + * writer. + */ +@UnstableApi +public class Http2OutboundFrameLogger implements Http2FrameWriter { + private final Http2FrameWriter writer; + private final Http2FrameLogger logger; + + public Http2OutboundFrameLogger(Http2FrameWriter writer, Http2FrameLogger logger) { + this.writer = checkNotNull(writer, "writer"); + this.logger = checkNotNull(logger, "logger"); + } + + @Override + public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endStream, ChannelPromise promise) { + logger.logData(OUTBOUND, ctx, streamId, data, padding, endStream); + return writer.writeData(ctx, streamId, data, padding, endStream, promise); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int padding, boolean endStream, ChannelPromise promise) { + logger.logHeaders(OUTBOUND, ctx, streamId, headers, padding, endStream); + return writer.writeHeaders(ctx, streamId, headers, padding, endStream, promise); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, + Http2Headers headers, int streamDependency, short weight, boolean exclusive, + int padding, boolean endStream, ChannelPromise promise) { + logger.logHeaders(OUTBOUND, ctx, streamId, headers, streamDependency, weight, exclusive, + padding, endStream); + return writer.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endStream, promise); + } + + 
@Override + public ChannelFuture writePriority(ChannelHandlerContext ctx, int streamId, + int streamDependency, short weight, boolean exclusive, ChannelPromise promise) { + logger.logPriority(OUTBOUND, ctx, streamId, streamDependency, weight, exclusive); + return writer.writePriority(ctx, streamId, streamDependency, weight, exclusive, promise); + } + + @Override + public ChannelFuture writeRstStream(ChannelHandlerContext ctx, + int streamId, long errorCode, ChannelPromise promise) { + logger.logRstStream(OUTBOUND, ctx, streamId, errorCode); + return writer.writeRstStream(ctx, streamId, errorCode, promise); + } + + @Override + public ChannelFuture writeSettings(ChannelHandlerContext ctx, + Http2Settings settings, ChannelPromise promise) { + logger.logSettings(OUTBOUND, ctx, settings); + return writer.writeSettings(ctx, settings, promise); + } + + @Override + public ChannelFuture writeSettingsAck(ChannelHandlerContext ctx, ChannelPromise promise) { + logger.logSettingsAck(OUTBOUND, ctx); + return writer.writeSettingsAck(ctx, promise); + } + + @Override + public ChannelFuture writePing(ChannelHandlerContext ctx, boolean ack, + long data, ChannelPromise promise) { + if (ack) { + logger.logPingAck(OUTBOUND, ctx, data); + } else { + logger.logPing(OUTBOUND, ctx, data); + } + return writer.writePing(ctx, ack, data, promise); + } + + @Override + public ChannelFuture writePushPromise(ChannelHandlerContext ctx, int streamId, + int promisedStreamId, Http2Headers headers, int padding, ChannelPromise promise) { + logger.logPushPromise(OUTBOUND, ctx, streamId, promisedStreamId, headers, padding); + return writer.writePushPromise(ctx, streamId, promisedStreamId, headers, padding, promise); + } + + @Override + public ChannelFuture writeGoAway(ChannelHandlerContext ctx, int lastStreamId, long errorCode, + ByteBuf debugData, ChannelPromise promise) { + logger.logGoAway(OUTBOUND, ctx, lastStreamId, errorCode, debugData); + return writer.writeGoAway(ctx, lastStreamId, errorCode, 
debugData, promise); + } + + @Override + public ChannelFuture writeWindowUpdate(ChannelHandlerContext ctx, + int streamId, int windowSizeIncrement, ChannelPromise promise) { + logger.logWindowsUpdate(OUTBOUND, ctx, streamId, windowSizeIncrement); + return writer.writeWindowUpdate(ctx, streamId, windowSizeIncrement, promise); + } + + @Override + public ChannelFuture writeFrame(ChannelHandlerContext ctx, byte frameType, int streamId, + Http2Flags flags, ByteBuf payload, ChannelPromise promise) { + logger.logUnknownFrame(OUTBOUND, ctx, frameType, streamId, flags, payload); + return writer.writeFrame(ctx, frameType, streamId, flags, payload, promise); + } + + @Override + public void close() { + writer.close(); + } + + @Override + public Configuration configuration() { + return writer.configuration(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PingFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PingFrame.java new file mode 100644 index 0000000..16b1807 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PingFrame.java @@ -0,0 +1,36 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 PING Frame. 
+ */ +@UnstableApi +public interface Http2PingFrame extends Http2Frame { + + /** + * When {@code true}, indicates that this ping is a ping response. + */ + boolean ack(); + + /** + * Returns the eight byte opaque data. + */ + long content(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PriorityFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PriorityFrame.java new file mode 100644 index 0000000..403028f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PriorityFrame.java @@ -0,0 +1,44 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 Priority Frame + */ +@UnstableApi +public interface Http2PriorityFrame extends Http2StreamFrame { + + /** + * Parent Stream Id of this Priority request + */ + int streamDependency(); + + /** + * Stream weight + */ + short weight(); + + /** + * Set to {@code true} if this stream is exclusive else set to {@code false} + */ + boolean exclusive(); + + @Override + Http2PriorityFrame stream(Http2FrameStream stream); + +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PromisedRequestVerifier.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PromisedRequestVerifier.java new file mode 100644 index 0000000..0642350 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PromisedRequestVerifier.java @@ -0,0 +1,74 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * Provides an extensibility point for users to define the validity of push requests. + * @see [RFC 7540], Section 8.2. 
+ */
+@UnstableApi
+public interface Http2PromisedRequestVerifier {
+ /**
+ * Determine if a {@link Http2Headers} are authoritative for a particular {@link ChannelHandlerContext}.
+ * @param ctx The context on which the {@code headers} were received.
+ * @param headers The headers to be verified.
+ * @return {@code true} if the {@code ctx} is authoritative for the {@code headers}, {@code false} otherwise.
+ * @see
+ * [RFC 7540], Section 10.1.
+ */
+ boolean isAuthoritative(ChannelHandlerContext ctx, Http2Headers headers);
+
+ /**
+ * Determine if a request is cacheable.
+ * @param headers The headers for a push request.
+ * @return {@code true} if the request associated with {@code headers} is known to be cacheable,
+ * {@code false} otherwise.
+ * @see [RFC 7231], Section 4.2.3.
+ */
+ boolean isCacheable(Http2Headers headers);
+
+ /**
+ * Determine if a request is safe.
+ * @param headers The headers for a push request.
+ * @return {@code true} if the request associated with {@code headers} is known to be safe,
+ * {@code false} otherwise.
+ * @see [RFC 7231], Section 4.2.1.
+ */
+ boolean isSafe(Http2Headers headers);
+
+ /**
+ * A default implementation of {@link Http2PromisedRequestVerifier} which always returns positive responses for
+ * all verification challenges.
+ */ + Http2PromisedRequestVerifier ALWAYS_VERIFY = new Http2PromisedRequestVerifier() { + @Override + public boolean isAuthoritative(ChannelHandlerContext ctx, Http2Headers headers) { + return true; + } + + @Override + public boolean isCacheable(Http2Headers headers) { + return true; + } + + @Override + public boolean isSafe(Http2Headers headers) { + return true; + } + }; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PushPromiseFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PushPromiseFrame.java new file mode 100644 index 0000000..dc5d7cb --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2PushPromiseFrame.java @@ -0,0 +1,55 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 Push Promise Frame + */ +@UnstableApi +public interface Http2PushPromiseFrame extends Http2StreamFrame { + + /** + * Set the Promise {@link Http2FrameStream} object for this frame. + */ + Http2StreamFrame pushStream(Http2FrameStream stream); + + /** + * Returns the Promise {@link Http2FrameStream} object for this frame, or {@code null} if the + * frame has yet to be associated with a stream. 
+ */ + Http2FrameStream pushStream(); + + /** + * {@link Http2Headers} sent in Push Promise + */ + Http2Headers http2Headers(); + + /** + * Frame padding to use. Will be non-negative and less than 256. + */ + int padding(); + + /** + * Promised Stream ID + */ + int promisedStreamId(); + + @Override + Http2PushPromiseFrame stream(Http2FrameStream stream); + +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java new file mode 100644 index 0000000..acec99c --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2RemoteFlowController.java @@ -0,0 +1,170 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.UnstableApi; + +/** + * A {@link Http2FlowController} for controlling the flow of outbound {@code DATA} frames to the remote + * endpoint. + */ +@UnstableApi +public interface Http2RemoteFlowController extends Http2FlowController { + /** + * Get the {@link ChannelHandlerContext} for which to apply flow control on. + *

+ * This is intended for use by {@link FlowControlled} implementations only. Use with caution.
+ * @return The {@link ChannelHandlerContext} for which to apply flow control on.
+ */
+ ChannelHandlerContext channelHandlerContext();
+
+ /**
+ * Queues a payload for transmission to the remote endpoint. There is no guarantee as to when the data
+ * will be written or how it will be assigned to frames
+ * before sending.
+ *

+ * Writes do not actually occur until {@link #writePendingBytes()} is called. + * + * @param stream the subject stream. Must not be the connection stream object. + * @param payload payload to write subject to flow-control accounting and ordering rules. + */ + void addFlowControlled(Http2Stream stream, FlowControlled payload); + + /** + * Determine if {@code stream} has any {@link FlowControlled} frames currently queued. + * @param stream the stream to check if it has flow controlled frames. + * @return {@code true} if {@code stream} has any {@link FlowControlled} frames currently queued. + */ + boolean hasFlowControlled(Http2Stream stream); + + /** + * Write all data pending in the flow controller up to the flow-control limits. + * + * @throws Http2Exception throws if a protocol-related error occurred. + */ + void writePendingBytes() throws Http2Exception; + + /** + * Set the active listener on the flow-controller. + * + * @param listener to notify when the a write occurs, can be {@code null}. + */ + void listener(Listener listener); + + /** + * Determine if the {@code stream} has bytes remaining for use in the flow control window. + *

+ * Note that this method respects channel writability. The channel must be writable for this method to
+ * return {@code true}.
+ *
+ * @param stream The stream to test.
+ * @return {@code true} if the {@code stream} has bytes remaining for use in the flow control window and the
+ * channel is writable, {@code false} otherwise.
+ */
+ boolean isWritable(Http2Stream stream);
+
+ /**
+ * Notification that the writability of {@link #channelHandlerContext()} has changed.
+ * @throws Http2Exception If any writes occur as a result of this call and encounter errors.
+ */
+ void channelWritabilityChanged() throws Http2Exception;
+
+ /**
+ * Explicitly update the dependency tree. This method is called independently of stream state changes.
+ * @param childStreamId The stream identifier associated with the child stream.
+ * @param parentStreamId The stream identifier associated with the parent stream. May be {@code 0},
+ * to make {@code childStreamId} an immediate child of the connection.
+ * @param weight The weight which is used relative to other child streams for {@code parentStreamId}. This value
+ * must be between 1 and 256 (inclusive).
+ * @param exclusive If {@code childStreamId} should be the exclusive dependency of {@code parentStreamId}.
+ */
+ void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive);
+
+ /**
+ * Implementations of this interface are used to progressively write chunks of the underlying
+ * payload to the stream. A payload is considered to be fully written if {@link #write} has
+ * been called at least once and its {@link #size} is now zero.
+ */
+ interface FlowControlled {
+ /**
+ * The size of the payload in terms of bytes applied to the flow-control window.
+ * Some payloads like {@code HEADER} frames have no cost against flow control and would
+ * return 0 for this value even though they produce a non-zero number of bytes on
+ * the wire.
Other frames like {@code DATA} frames have both their payload and padding count + * against flow-control. + */ + int size(); + + /** + * Called to indicate that an error occurred before this object could be completely written. + *

+ * The {@link Http2RemoteFlowController} will make exactly one call to either + * this method or {@link #writeComplete()}. + *

+ * + * @param ctx The context to use if any communication needs to occur as a result of the error. + * This may be {@code null} if an exception occurs when the connection has not been established yet. + * @param cause of the error. + */ + void error(ChannelHandlerContext ctx, Throwable cause); + + /** + * Called after this object has been successfully written. + *

+ * The {@link Http2RemoteFlowController} will make exactly one call to either + * this method or {@link #error(ChannelHandlerContext, Throwable)}. + *

+ */
+ void writeComplete();
+
+ /**
+ * Writes up to {@code allowedBytes} of the encapsulated payload to the stream. Note that
+ * a value of 0 may be passed which will allow payloads with flow-control size == 0 to be
+ * written. The flow-controller may call this method multiple times with different values until
+ * the payload is fully written, i.e. its size after the write is 0.
+ *

+ * When an exception is thrown the {@link Http2RemoteFlowController} will make a call to + * {@link #error(ChannelHandlerContext, Throwable)}. + *

+ * + * @param ctx The context to use for writing. + * @param allowedBytes an upper bound on the number of bytes the payload can write at this time. + */ + void write(ChannelHandlerContext ctx, int allowedBytes); + + /** + * Merge the contents of the {@code next} message into this message so they can be written out as one unit. + * This allows many small messages to be written as a single DATA frame. + * + * @return {@code true} if {@code next} was successfully merged and does not need to be enqueued, + * {@code false} otherwise. + */ + boolean merge(ChannelHandlerContext ctx, FlowControlled next); + } + + /** + * Listener to the number of flow-controlled bytes written per stream. + */ + interface Listener { + /** + * Notification that {@link Http2RemoteFlowController#isWritable(Http2Stream)} has changed for {@code stream}. + *

+ * This method should not throw. Any thrown exceptions are considered a programming error and are ignored. + * @param stream The stream which writability has changed for. + */ + void writabilityChanged(Http2Stream stream); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ResetFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ResetFrame.java new file mode 100644 index 0000000..431a572 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ResetFrame.java @@ -0,0 +1,28 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** HTTP/2 RST_STREAM frame. */ +@UnstableApi +public interface Http2ResetFrame extends Http2StreamFrame { + + /** + * The reason for resetting the stream. Represented as an HTTP/2 error code. 
+ */ + long errorCode(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SecurityUtil.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SecurityUtil.java new file mode 100644 index 0000000..5e3ef1f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SecurityUtil.java @@ -0,0 +1,80 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +/** + * Provides utilities related to security requirements specific to HTTP/2. + */ +@UnstableApi +public final class Http2SecurityUtil { + /** + * The following list is derived from SunJSSE Supported + * Ciphers and Mozilla Modern Cipher + * Suites in accordance with the HTTP/2 Specification. + * + * According to the + * JSSE documentation "the names mentioned in the TLS RFCs prefixed with TLS_ are functionally equivalent + * to the JSSE cipher suites prefixed with SSL_". + * Both variants are used to support JVMs supporting the one or the other. + */ + public static final List CIPHERS; + + /** + * Mozilla Modern Cipher Suites Intermediate compatibility minus the following cipher suites that are black + * listed by the HTTP/2 RFC. 
+ */ + private static final List CIPHERS_JAVA_MOZILLA_MODERN_SECURITY = Collections.unmodifiableList(Arrays + .asList( + /* openssl = ECDHE-ECDSA-AES128-GCM-SHA256 */ + "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", + + /* REQUIRED BY HTTP/2 SPEC */ + /* openssl = ECDHE-RSA-AES128-GCM-SHA256 */ + "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", + /* REQUIRED BY HTTP/2 SPEC */ + + /* openssl = ECDHE-ECDSA-AES256-GCM-SHA384 */ + "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", + /* openssl = ECDHE-RSA-AES256-GCM-SHA384 */ + "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384", + /* openssl = ECDHE-ECDSA-CHACHA20-POLY1305 */ + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", + /* openssl = ECDHE-RSA-CHACHA20-POLY1305 */ + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", + + /* TLS 1.3 ciphers */ + "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384", + "TLS_CHACHA20_POLY1305_SHA256" + )); + + static { + CIPHERS = Collections.unmodifiableList(new ArrayList(CIPHERS_JAVA_MOZILLA_MODERN_SECURITY)); + } + + private Http2SecurityUtil() { } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodec.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodec.java new file mode 100644 index 0000000..e5689ef --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodec.java @@ -0,0 +1,213 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.base64.Base64; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.CharBuffer; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +import static io.netty.handler.codec.base64.Base64Dialect.URL_SAFE; +import static io.netty.handler.codec.http2.Http2CodecUtil.FRAME_HEADER_LENGTH; +import static io.netty.handler.codec.http2.Http2CodecUtil.HTTP_UPGRADE_SETTINGS_HEADER; +import static io.netty.handler.codec.http2.Http2CodecUtil.writeFrameHeader; +import static io.netty.handler.codec.http2.Http2FrameTypes.SETTINGS; + +/** + * Server-side codec for performing a cleartext upgrade from HTTP/1.x to HTTP/2. 
/**
 * Server-side codec for performing a cleartext (h2c) upgrade from HTTP/1.x to HTTP/2,
 * as described by the HTTP/2 specification. Validates the {@code HTTP2-Settings} header
 * of the upgrade request and, on success, installs the HTTP/2 connection handler (plus
 * any extra handlers) into the pipeline.
 */
@UnstableApi
public class Http2ServerUpgradeCodec implements HttpServerUpgradeHandler.UpgradeCodec {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2ServerUpgradeCodec.class);
    // The upgrade request MUST carry exactly this one extra header (HTTP2-Settings).
    private static final List<CharSequence> REQUIRED_UPGRADE_HEADERS =
            Collections.singletonList(HTTP_UPGRADE_SETTINGS_HEADER);
    private static final ChannelHandler[] EMPTY_HANDLERS = new ChannelHandler[0];

    // Pipeline name for the connection handler, or null to auto-generate.
    private final String handlerName;
    private final Http2ConnectionHandler connectionHandler;
    // Extra handlers added after the connection handler on upgrade; never null (EMPTY_HANDLERS default).
    private final ChannelHandler[] handlers;
    // Used only to parse the synthetic SETTINGS frame built from the HTTP2-Settings header.
    private final Http2FrameReader frameReader;

    // Decoded settings from prepareUpgradeResponse(); consumed later by upgradeTo().
    private Http2Settings settings;

    /**
     * Creates the codec using a default name for the connection handler when adding to the
     * pipeline.
     *
     * @param connectionHandler the HTTP/2 connection handler
     */
    public Http2ServerUpgradeCodec(Http2ConnectionHandler connectionHandler) {
        this(null, connectionHandler, EMPTY_HANDLERS);
    }

    /**
     * Creates the codec using a default name for the connection handler when adding to the
     * pipeline.
     *
     * @param http2Codec the HTTP/2 multiplexing handler.
     */
    public Http2ServerUpgradeCodec(Http2MultiplexCodec http2Codec) {
        this(null, http2Codec, EMPTY_HANDLERS);
    }

    /**
     * Creates the codec providing an upgrade to the given handler for HTTP/2.
     *
     * @param handlerName the name of the HTTP/2 connection handler to be used in the pipeline,
     *                    or {@code null} to auto-generate the name
     * @param connectionHandler the HTTP/2 connection handler
     */
    public Http2ServerUpgradeCodec(String handlerName, Http2ConnectionHandler connectionHandler) {
        this(handlerName, connectionHandler, EMPTY_HANDLERS);
    }

    /**
     * Creates the codec providing an upgrade to the given handler for HTTP/2.
     *
     * @param handlerName the name of the HTTP/2 connection handler to be used in the pipeline.
     * @param http2Codec the HTTP/2 multiplexing handler.
     */
    public Http2ServerUpgradeCodec(String handlerName, Http2MultiplexCodec http2Codec) {
        this(handlerName, http2Codec, EMPTY_HANDLERS);
    }

    /**
     * Creates the codec using a default name for the connection handler when adding to the
     * pipeline.
     *
     * @param http2Codec the HTTP/2 frame handler.
     * @param handlers the handlers that will handle the {@link Http2Frame}s.
     */
    public Http2ServerUpgradeCodec(Http2FrameCodec http2Codec, ChannelHandler... handlers) {
        this(null, http2Codec, handlers);
    }

    private Http2ServerUpgradeCodec(String handlerName, Http2ConnectionHandler connectionHandler,
                                    ChannelHandler... handlers) {
        this.handlerName = handlerName;
        this.connectionHandler = connectionHandler;
        this.handlers = handlers;
        frameReader = new DefaultHttp2FrameReader();
    }

    @Override
    public Collection<CharSequence> requiredUpgradeHeaders() {
        return REQUIRED_UPGRADE_HEADERS;
    }

    /**
     * Validates the upgrade request by decoding its {@code HTTP2-Settings} header.
     * On success the decoded settings are cached for {@link #upgradeTo} and {@code true}
     * is returned; any failure (missing/duplicate header, malformed payload) is logged
     * and {@code false} is returned, rejecting the upgrade.
     */
    @Override
    public boolean prepareUpgradeResponse(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest,
                                          HttpHeaders headers) {
        try {
            // Decode the HTTP2-Settings header and set the settings on the handler to make
            // sure everything is fine with the request.
            List<String> upgradeHeaders = upgradeRequest.headers().getAll(HTTP_UPGRADE_SETTINGS_HEADER);
            if (upgradeHeaders.size() != 1) {
                throw new IllegalArgumentException("There must be 1 and only 1 " +
                        HTTP_UPGRADE_SETTINGS_HEADER + " header.");
            }
            settings = decodeSettingsHeader(ctx, upgradeHeaders.get(0));
            // Everything looks good.
            return true;
        } catch (Throwable cause) {
            logger.info("Error during upgrade to HTTP/2", cause);
            return false;
        }
    }

    /**
     * Performs the actual pipeline surgery: inserts the connection handler (and extra
     * handlers) after the current handler and applies the previously decoded settings.
     * Any {@link Http2Exception} is propagated via {@code fireExceptionCaught} and the
     * channel is closed.
     */
    @Override
    public void upgradeTo(final ChannelHandlerContext ctx, FullHttpRequest upgradeRequest) {
        try {
            // Add the HTTP/2 connection handler to the pipeline immediately following the current handler.
            ctx.pipeline().addAfter(ctx.name(), handlerName, connectionHandler);

            // Add also all extra handlers as these may handle events / messages produced by the connectionHandler.
            // See https://github.com/netty/netty/issues/9314
            // NOTE(review): handlers is never null here (constructors always pass an array);
            // the null check looks defensive/historical.
            if (handlers != null) {
                final String name = ctx.pipeline().context(connectionHandler).name();
                // Iterate in reverse so addAfter(name, ...) preserves the caller-supplied order.
                for (int i = handlers.length - 1; i >= 0; i--) {
                    ctx.pipeline().addAfter(name, null, handlers[i]);
                }
            }
            connectionHandler.onHttpServerUpgrade(settings);
        } catch (Http2Exception e) {
            ctx.fireExceptionCaught(e);
            ctx.close();
        }
    }

    /**
     * Decodes the settings header and returns a {@link Http2Settings} object.
     * The header value is base64url-decoded into a raw SETTINGS payload first.
     */
    private Http2Settings decodeSettingsHeader(ChannelHandlerContext ctx, CharSequence settingsHeader)
            throws Http2Exception {
        ByteBuf header = ByteBufUtil.encodeString(ctx.alloc(), CharBuffer.wrap(settingsHeader), CharsetUtil.UTF_8);
        try {
            // Decode the SETTINGS payload.
            ByteBuf payload = Base64.decode(header, URL_SAFE);

            // Create an HTTP/2 frame for the settings.
            ByteBuf frame = createSettingsFrame(ctx, payload);

            // Decode the SETTINGS frame and return the settings object.
            return decodeSettings(ctx, frame);
        } finally {
            header.release();
        }
    }

    /**
     * Decodes the settings frame and returns the settings. Releases {@code frame}.
     */
    private Http2Settings decodeSettings(ChannelHandlerContext ctx, ByteBuf frame) throws Http2Exception {
        try {
            final Http2Settings decodedSettings = new Http2Settings();
            frameReader.readFrame(ctx, frame, new Http2FrameAdapter() {
                @Override
                public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) {
                    decodedSettings.copyFrom(settings);
                }
            });
            return decodedSettings;
        } finally {
            frame.release();
        }
    }

    /**
     * Creates an HTTP/2 SETTINGS frame (header + payload) from the given payload.
     * The payload buffer is released.
     */
    private static ByteBuf createSettingsFrame(ChannelHandlerContext ctx, ByteBuf payload) {
        ByteBuf frame = ctx.alloc().buffer(FRAME_HEADER_LENGTH + payload.readableBytes());
        writeFrameHeader(frame, payload.readableBytes(), SETTINGS, new Http2Flags(), 0);
        frame.writeBytes(payload);
        payload.release();
        return frame;
    }
}
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.util.collection.CharObjectHashMap; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_UPPER_BOUND; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.NUM_STANDARD_SETTINGS; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_ENABLE_PUSH; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_MAX_FRAME_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.SETTINGS_MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.isMaxFrameSizeValid; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Integer.toHexString; + +/** + * Settings for one endpoint in an HTTP/2 connection. 
Each of the values are optional as defined in + * the spec for the SETTINGS frame. Permits storage of arbitrary key/value pairs but provides helper + * methods for standard settings. + */ +@UnstableApi +public final class Http2Settings extends CharObjectHashMap { + /** + * Default capacity based on the number of standard settings from the HTTP/2 spec, adjusted so that adding all of + * the standard settings will not cause the map capacity to change. + */ + private static final int DEFAULT_CAPACITY = (int) (NUM_STANDARD_SETTINGS / DEFAULT_LOAD_FACTOR) + 1; + private static final Long FALSE = 0L; + private static final Long TRUE = 1L; + + public Http2Settings() { + this(DEFAULT_CAPACITY); + } + + public Http2Settings(int initialCapacity, float loadFactor) { + super(initialCapacity, loadFactor); + } + + public Http2Settings(int initialCapacity) { + super(initialCapacity); + } + + /** + * Adds the given setting key/value pair. For standard settings defined by the HTTP/2 spec, performs + * validation on the values. + * + * @throws IllegalArgumentException if verification for a standard HTTP/2 setting fails. + */ + @Override + public Long put(char key, Long value) { + verifyStandardSetting(key, value); + return super.put(key, value); + } + + /** + * Gets the {@code SETTINGS_HEADER_TABLE_SIZE} value. If unavailable, returns {@code null}. + */ + public Long headerTableSize() { + return get(SETTINGS_HEADER_TABLE_SIZE); + } + + /** + * Sets the {@code SETTINGS_HEADER_TABLE_SIZE} value. + * + * @throws IllegalArgumentException if verification of the setting fails. + */ + public Http2Settings headerTableSize(long value) { + put(SETTINGS_HEADER_TABLE_SIZE, Long.valueOf(value)); + return this; + } + + /** + * Gets the {@code SETTINGS_ENABLE_PUSH} value. If unavailable, returns {@code null}. 
+ */ + public Boolean pushEnabled() { + Long value = get(SETTINGS_ENABLE_PUSH); + if (value == null) { + return null; + } + return TRUE.equals(value); + } + + /** + * Sets the {@code SETTINGS_ENABLE_PUSH} value. + */ + public Http2Settings pushEnabled(boolean enabled) { + put(SETTINGS_ENABLE_PUSH, enabled ? TRUE : FALSE); + return this; + } + + /** + * Gets the {@code SETTINGS_MAX_CONCURRENT_STREAMS} value. If unavailable, returns {@code null}. + */ + public Long maxConcurrentStreams() { + return get(SETTINGS_MAX_CONCURRENT_STREAMS); + } + + /** + * Sets the {@code SETTINGS_MAX_CONCURRENT_STREAMS} value. + * + * @throws IllegalArgumentException if verification of the setting fails. + */ + public Http2Settings maxConcurrentStreams(long value) { + put(SETTINGS_MAX_CONCURRENT_STREAMS, Long.valueOf(value)); + return this; + } + + /** + * Gets the {@code SETTINGS_INITIAL_WINDOW_SIZE} value. If unavailable, returns {@code null}. + */ + public Integer initialWindowSize() { + return getIntValue(SETTINGS_INITIAL_WINDOW_SIZE); + } + + /** + * Sets the {@code SETTINGS_INITIAL_WINDOW_SIZE} value. + * + * @throws IllegalArgumentException if verification of the setting fails. + */ + public Http2Settings initialWindowSize(int value) { + put(SETTINGS_INITIAL_WINDOW_SIZE, Long.valueOf(value)); + return this; + } + + /** + * Gets the {@code SETTINGS_MAX_FRAME_SIZE} value. If unavailable, returns {@code null}. + */ + public Integer maxFrameSize() { + return getIntValue(SETTINGS_MAX_FRAME_SIZE); + } + + /** + * Sets the {@code SETTINGS_MAX_FRAME_SIZE} value. + * + * @throws IllegalArgumentException if verification of the setting fails. + */ + public Http2Settings maxFrameSize(int value) { + put(SETTINGS_MAX_FRAME_SIZE, Long.valueOf(value)); + return this; + } + + /** + * Gets the {@code SETTINGS_MAX_HEADER_LIST_SIZE} value. If unavailable, returns {@code null}. 
+ */ + public Long maxHeaderListSize() { + return get(SETTINGS_MAX_HEADER_LIST_SIZE); + } + + /** + * Sets the {@code SETTINGS_MAX_HEADER_LIST_SIZE} value. + * + * @throws IllegalArgumentException if verification of the setting fails. + */ + public Http2Settings maxHeaderListSize(long value) { + put(SETTINGS_MAX_HEADER_LIST_SIZE, Long.valueOf(value)); + return this; + } + + /** + * Clears and then copies the given settings into this object. + */ + public Http2Settings copyFrom(Http2Settings settings) { + clear(); + putAll(settings); + return this; + } + + /** + * A helper method that returns {@link Long#intValue()} on the return of {@link #get(char)}, if present. Note that + * if the range of the value exceeds {@link Integer#MAX_VALUE}, the {@link #get(char)} method should + * be used instead to avoid truncation of the value. + */ + public Integer getIntValue(char key) { + Long value = get(key); + if (value == null) { + return null; + } + return value.intValue(); + } + + private static void verifyStandardSetting(int key, Long value) { + checkNotNull(value, "value"); + switch (key) { + case SETTINGS_HEADER_TABLE_SIZE: + if (value < MIN_HEADER_TABLE_SIZE || value > MAX_HEADER_TABLE_SIZE) { + throw new IllegalArgumentException("Setting HEADER_TABLE_SIZE is invalid: " + value + + ", expected [" + MIN_HEADER_TABLE_SIZE + ", " + MAX_HEADER_TABLE_SIZE + ']'); + } + break; + case SETTINGS_ENABLE_PUSH: + if (value != 0L && value != 1L) { + throw new IllegalArgumentException("Setting ENABLE_PUSH is invalid: " + value + + ", expected [0, 1]"); + } + break; + case SETTINGS_MAX_CONCURRENT_STREAMS: + if (value < MIN_CONCURRENT_STREAMS || value > MAX_CONCURRENT_STREAMS) { + throw new IllegalArgumentException("Setting MAX_CONCURRENT_STREAMS is invalid: " + value + + ", expected [" + MIN_CONCURRENT_STREAMS + ", " + MAX_CONCURRENT_STREAMS + ']'); + } + break; + case SETTINGS_INITIAL_WINDOW_SIZE: + if (value < MIN_INITIAL_WINDOW_SIZE || value > MAX_INITIAL_WINDOW_SIZE) { + throw new 
IllegalArgumentException("Setting INITIAL_WINDOW_SIZE is invalid: " + value + + ", expected [" + MIN_INITIAL_WINDOW_SIZE + ", " + MAX_INITIAL_WINDOW_SIZE + ']'); + } + break; + case SETTINGS_MAX_FRAME_SIZE: + if (!isMaxFrameSizeValid(value.intValue())) { + throw new IllegalArgumentException("Setting MAX_FRAME_SIZE is invalid: " + value + + ", expected [" + MAX_FRAME_SIZE_LOWER_BOUND + ", " + MAX_FRAME_SIZE_UPPER_BOUND + ']'); + } + break; + case SETTINGS_MAX_HEADER_LIST_SIZE: + if (value < MIN_HEADER_LIST_SIZE || value > MAX_HEADER_LIST_SIZE) { + throw new IllegalArgumentException("Setting MAX_HEADER_LIST_SIZE is invalid: " + value + + ", expected [" + MIN_HEADER_LIST_SIZE + ", " + MAX_HEADER_LIST_SIZE + ']'); + } + break; + default: + // Non-standard HTTP/2 setting + if (value < 0 || value > MAX_UNSIGNED_INT) { + throw new IllegalArgumentException("Non-standard setting 0x" + toHexString(key) + " is invalid: " + + value + ", expected unsigned 32-bit value"); + } + break; + } + } + + @Override + protected String keyToString(char key) { + switch (key) { + case SETTINGS_HEADER_TABLE_SIZE: + return "HEADER_TABLE_SIZE"; + case SETTINGS_ENABLE_PUSH: + return "ENABLE_PUSH"; + case SETTINGS_MAX_CONCURRENT_STREAMS: + return "MAX_CONCURRENT_STREAMS"; + case SETTINGS_INITIAL_WINDOW_SIZE: + return "INITIAL_WINDOW_SIZE"; + case SETTINGS_MAX_FRAME_SIZE: + return "MAX_FRAME_SIZE"; + case SETTINGS_MAX_HEADER_LIST_SIZE: + return "MAX_HEADER_LIST_SIZE"; + default: + // Unknown keys. 
+ return "0x" + toHexString(key); + } + } + + public static Http2Settings defaultSettings() { + return new Http2Settings().maxHeaderListSize(DEFAULT_HEADER_LIST_SIZE); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsAckFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsAckFrame.java new file mode 100644 index 0000000..a497e87 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsAckFrame.java @@ -0,0 +1,29 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +/** + * An ack for a previously received {@link Http2SettingsFrame}. + *

+ * The HTTP/2 protocol enforces that ACKs are applied in + * order, so this ACK will apply to the earliest received and not yet ACKed {@link Http2SettingsFrame} frame. + */ +public interface Http2SettingsAckFrame extends Http2Frame { + Http2SettingsAckFrame INSTANCE = new DefaultHttp2SettingsAckFrame(); + + @Override + String name(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsFrame.java new file mode 100644 index 0000000..f809062 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsFrame.java @@ -0,0 +1,28 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +/** + * HTTP/2 SETTINGS frame. 
+ */ +public interface Http2SettingsFrame extends Http2Frame { + + Http2Settings settings(); + + @Override + String name(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsReceivedConsumer.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsReceivedConsumer.java new file mode 100644 index 0000000..69ba718 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2SettingsReceivedConsumer.java @@ -0,0 +1,25 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +/** + * Provides a Consumer like interface to consume remote settings received but not yet ACKed. + */ +public interface Http2SettingsReceivedConsumer { + /** + * Consume the most recently received but not yet ACKed settings. 
/**
 * A single stream within an HTTP2 connection. Streams are compared to each other by priority.
 */
@UnstableApi
public interface Http2Stream {

    /**
     * The allowed states of an HTTP2 stream.
     */
    enum State {
        IDLE(false, false),
        RESERVED_LOCAL(false, false),
        RESERVED_REMOTE(false, false),
        OPEN(true, true),
        HALF_CLOSED_LOCAL(false, true),
        HALF_CLOSED_REMOTE(true, false),
        CLOSED(false, false);

        private final boolean localSideOpen;
        private final boolean remoteSideOpen;

        State(boolean localSideOpen, boolean remoteSideOpen) {
            this.localSideOpen = localSideOpen;
            this.remoteSideOpen = remoteSideOpen;
        }

        /**
         * Indicates whether the local side of this stream is open (i.e. the state is either
         * {@link State#OPEN} or {@link State#HALF_CLOSED_REMOTE}).
         */
        public boolean localSideOpen() {
            return localSideOpen;
        }

        /**
         * Indicates whether the remote side of this stream is open (i.e. the state is either
         * {@link State#OPEN} or {@link State#HALF_CLOSED_LOCAL}).
         */
        public boolean remoteSideOpen() {
            return remoteSideOpen;
        }
    }

    /**
     * Gets the unique identifier for this stream within the connection.
     */
    int id();

    /**
     * Gets the state of this stream.
     */
    State state();

    /**
     * Opens this stream, making it available via
     * {@link Http2Connection#forEachActiveStream(Http2StreamVisitor)} and transition state to:
     * <ul>
     * <li>{@link State#OPEN} if {@link #state()} is {@link State#IDLE} and {@code halfClosed}
     * is {@code false}.</li>
     * <li>{@link State#HALF_CLOSED_LOCAL} if {@link #state()} is {@link State#IDLE} and {@code halfClosed}
     * is {@code true} and the stream is local. In this state, {@link #isHeadersSent()} is {@code true}</li>
     * <li>{@link State#HALF_CLOSED_REMOTE} if {@link #state()} is {@link State#IDLE} and {@code halfClosed}
     * is {@code true} and the stream is remote. In this state, {@link #isHeadersReceived()} is {@code true}</li>
     * <li>{@link State#RESERVED_LOCAL} if {@link #state()} is {@link State#HALF_CLOSED_REMOTE}.</li>
     * <li>{@link State#RESERVED_REMOTE} if {@link #state()} is {@link State#HALF_CLOSED_LOCAL}.</li>
     * </ul>
     */
    Http2Stream open(boolean halfClosed) throws Http2Exception;

    /**
     * Closes the stream.
     */
    Http2Stream close();

    /**
     * Closes the local side of this stream. If this makes the stream closed, the child is closed as
     * well.
     */
    Http2Stream closeLocalSide();

    /**
     * Closes the remote side of this stream. If this makes the stream closed, the child is closed
     * as well.
     */
    Http2Stream closeRemoteSide();

    /**
     * Indicates whether a {@code RST_STREAM} frame has been sent from the local endpoint for this stream.
     */
    boolean isResetSent();

    /**
     * Sets the flag indicating that a {@code RST_STREAM} frame has been sent from the local endpoint
     * for this stream. This does not affect the stream state.
     */
    Http2Stream resetSent();

    /**
     * Associates the application-defined data with this stream.
     * @return The value that was previously associated with {@code key}, or {@code null} if there was none.
     */
    <V> V setProperty(Http2Connection.PropertyKey key, V value);

    /**
     * Returns application-defined data if any was associated with this stream.
     */
    <V> V getProperty(Http2Connection.PropertyKey key);

    /**
     * Returns and removes application-defined data if any was associated with this stream.
     */
    <V> V removeProperty(Http2Connection.PropertyKey key);

    /**
     * Indicates that headers have been sent to the remote endpoint on this stream. The first call to this
     * method would be for the initial headers (see {@link #isHeadersSent()}) and the second call would
     * indicate the trailers (see {@link #isTrailersSent()}).
     * @param isInformational {@code true} if the headers contain an informational status code (for responses only).
     */
    Http2Stream headersSent(boolean isInformational);

    /**
     * Indicates whether or not headers were sent to the remote endpoint.
     */
    boolean isHeadersSent();

    /**
     * Indicates whether or not trailers were sent to the remote endpoint.
     */
    boolean isTrailersSent();

    /**
     * Indicates that headers have been received. The first call to this method would be for the initial headers
     * (see {@link #isHeadersReceived()}) and the second call would indicate the trailers
     * (see {@link #isTrailersReceived()}).
     * @param isInformational {@code true} if the headers contain an informational status code (for responses only).
     */
    Http2Stream headersReceived(boolean isInformational);

    /**
     * Indicates whether or not the initial headers have been received.
     */
    boolean isHeadersReceived();

    /**
     * Indicates whether or not the trailers have been received.
     */
    boolean isTrailersReceived();

    /**
     * Indicates that a push promise was sent to the remote endpoint.
     */
    Http2Stream pushPromiseSent();

    /**
     * Indicates whether or not a push promise was sent to the remote endpoint.
     */
    boolean isPushPromiseSent();
}
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.channel.Channel; +import io.netty.util.internal.UnstableApi; + +// TODO: Should we have an extra method to "open" the stream and so Channel and take care of sending the +// Http2HeadersFrame under the hood ? +// TODO: Should we extend SocketChannel and map input and output state to the stream state ? +// +@UnstableApi +public interface Http2StreamChannel extends Channel { + + /** + * Returns the {@link Http2FrameStream} that belongs to this channel. + */ + Http2FrameStream stream(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrap.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrap.java new file mode 100644 index 0000000..057dfcf --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrap.java @@ -0,0 +1,256 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * A bootstrap for creating new outbound {@link Http2StreamChannel}s on an existing HTTP/2
 * connection {@link Channel}. Options, attributes and a handler configured here are applied to
 * each newly opened stream channel.
 */
@UnstableApi
public final class Http2StreamChannelBootstrap {
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(Http2StreamChannelBootstrap.class);
    @SuppressWarnings("unchecked")
    private static final Map.Entry<ChannelOption<?>, Object>[] EMPTY_OPTION_ARRAY = new Map.Entry[0];
    @SuppressWarnings("unchecked")
    private static final Map.Entry<AttributeKey<?>, Object>[] EMPTY_ATTRIBUTE_ARRAY = new Map.Entry[0];

    // The order in which ChannelOptions are applied is important they may depend on each other for validation
    // purposes.
    private final Map<ChannelOption<?>, Object> options = new LinkedHashMap<ChannelOption<?>, Object>();
    private final Map<AttributeKey<?>, Object> attrs = new ConcurrentHashMap<AttributeKey<?>, Object>();
    private final Channel channel;
    private volatile ChannelHandler handler;

    // Cache the ChannelHandlerContext to speed up open(...) operations.
    private volatile ChannelHandlerContext multiplexCtx;

    public Http2StreamChannelBootstrap(Channel channel) {
        this.channel = ObjectUtil.checkNotNull(channel, "channel");
    }

    /**
     * Allow to specify a {@link ChannelOption} which is used for the {@link Http2StreamChannel} instances once they got
     * created. Use a value of {@code null} to remove a previous set {@link ChannelOption}.
     */
    public <T> Http2StreamChannelBootstrap option(ChannelOption<T> option, T value) {
        ObjectUtil.checkNotNull(option, "option");

        // options is a LinkedHashMap (insertion-ordered, not thread-safe), so guard mutation.
        synchronized (options) {
            if (value == null) {
                options.remove(option);
            } else {
                options.put(option, value);
            }
        }
        return this;
    }

    /**
     * Allow to specify an initial attribute of the newly created {@link Http2StreamChannel}. If the {@code value} is
     * {@code null}, the attribute of the specified {@code key} is removed.
     */
    public <T> Http2StreamChannelBootstrap attr(AttributeKey<T> key, T value) {
        ObjectUtil.checkNotNull(key, "key");
        if (value == null) {
            attrs.remove(key);
        } else {
            attrs.put(key, value);
        }
        return this;
    }

    /**
     * the {@link ChannelHandler} to use for serving the requests.
     */
    public Http2StreamChannelBootstrap handler(ChannelHandler handler) {
        this.handler = ObjectUtil.checkNotNull(handler, "handler");
        return this;
    }

    /**
     * Open a new {@link Http2StreamChannel} to use.
     * @return the {@link Future} that will be notified once the channel was opened successfully or it failed.
     */
    public Future<Http2StreamChannel> open() {
        return open(channel.eventLoop().newPromise());
    }

    /**
     * Open a new {@link Http2StreamChannel} to use and notifies the given {@link Promise}.
     * @return the {@link Future} that will be notified once the channel was opened successfully or it failed.
     */
    public Future<Http2StreamChannel> open(final Promise<Http2StreamChannel> promise) {
        try {
            ChannelHandlerContext ctx = findCtx();
            EventExecutor executor = ctx.executor();
            if (executor.inEventLoop()) {
                // Already on the event loop: open synchronously.
                open0(ctx, promise);
            } else {
                final ChannelHandlerContext finalCtx = ctx;
                // Hop onto the event loop; re-check liveness there since the channel may have
                // closed between findCtx() and execution of this task.
                executor.execute(new Runnable() {
                    @Override
                    public void run() {
                        if (channel.isActive()) {
                            open0(finalCtx, promise);
                        } else {
                            promise.setFailure(new ClosedChannelException());
                        }
                    }
                });
            }
        } catch (Throwable cause) {
            promise.setFailure(cause);
        }
        return promise;
    }

    @SuppressWarnings("deprecation")
    private ChannelHandlerContext findCtx() throws ClosedChannelException {
        // First try to use cached context and if this not work lets try to lookup the context.
        ChannelHandlerContext ctx = multiplexCtx;
        if (ctx != null && !ctx.isRemoved()) {
            return ctx;
        }
        ChannelPipeline pipeline = channel.pipeline();
        // Support both the deprecated Http2MultiplexCodec and the newer Http2MultiplexHandler.
        ctx = pipeline.context(Http2MultiplexCodec.class);
        if (ctx == null) {
            ctx = pipeline.context(Http2MultiplexHandler.class);
        }
        if (ctx == null) {
            if (channel.isActive()) {
                throw new IllegalStateException(StringUtil.simpleClassName(Http2MultiplexCodec.class) + " or "
                        + StringUtil.simpleClassName(Http2MultiplexHandler.class)
                        + " must be in the ChannelPipeline of Channel " + channel);
            } else {
                throw new ClosedChannelException();
            }
        }
        multiplexCtx = ctx;
        return ctx;
    }

    /**
     * @deprecated should not be used directly. Use {@link #open()} or {@link #open(Promise)}
     */
    @Deprecated
    public void open0(ChannelHandlerContext ctx, final Promise<Http2StreamChannel> promise) {
        assert ctx.executor().inEventLoop();
        if (!promise.setUncancellable()) {
            // Promise was already cancelled; nothing to do.
            return;
        }
        final Http2StreamChannel streamChannel;
        try {
            if (ctx.handler() instanceof Http2MultiplexCodec) {
                streamChannel = ((Http2MultiplexCodec) ctx.handler()).newOutboundStream();
            } else {
                streamChannel = ((Http2MultiplexHandler) ctx.handler()).newOutboundStream();
            }
        } catch (Exception e) {
            promise.setFailure(e);
            return;
        }
        try {
            init(streamChannel);
        } catch (Exception e) {
            // Not registered yet, so close forcibly.
            streamChannel.unsafe().closeForcibly();
            promise.setFailure(e);
            return;
        }

        ChannelFuture future = ctx.channel().eventLoop().register(streamChannel);
        future.addListener(new ChannelFutureListener() {
            @Override
            public void operationComplete(ChannelFuture future) {
                if (future.isSuccess()) {
                    promise.setSuccess(streamChannel);
                } else if (future.isCancelled()) {
                    promise.cancel(false);
                } else {
                    // Registration failed: close via close() if registered, otherwise forcibly.
                    if (streamChannel.isRegistered()) {
                        streamChannel.close();
                    } else {
                        streamChannel.unsafe().closeForcibly();
                    }

                    promise.setFailure(future.cause());
                }
            }
        });
    }

    // Applies the configured handler, options and attributes to a freshly created stream channel.
    private void init(Channel channel) {
        ChannelPipeline p = channel.pipeline();
        ChannelHandler handler = this.handler;
        if (handler != null) {
            p.addLast(handler);
        }
        final Map.Entry<ChannelOption<?>, Object> [] optionArray;
        synchronized (options) {
            // Snapshot under the lock so concurrent option(...) calls can't corrupt iteration.
            optionArray = options.entrySet().toArray(EMPTY_OPTION_ARRAY);
        }

        setChannelOptions(channel, optionArray);
        setAttributes(channel, attrs.entrySet().toArray(EMPTY_ATTRIBUTE_ARRAY));
    }

    private static void setChannelOptions(
            Channel channel, Map.Entry<ChannelOption<?>, Object>[] options) {
        for (Map.Entry<ChannelOption<?>, Object> e: options) {
            setChannelOption(channel, e.getKey(), e.getValue());
        }
    }

    private static void setChannelOption(
            Channel channel, ChannelOption<?> option, Object value) {
        try {
            @SuppressWarnings("unchecked")
            ChannelOption<Object> opt = (ChannelOption<Object>) option;
            if (!channel.config().setOption(opt, value)) {
                // Unknown options are logged, not fatal.
                logger.warn("Unknown channel option '{}' for channel '{}'", option, channel);
            }
        } catch (Throwable t) {
            logger.warn(
                    "Failed to set channel option '{}' with value '{}' for channel '{}'", option, value, channel, t);
        }
    }

    private static void setAttributes(
            Channel channel, Map.Entry<AttributeKey<?>, Object>[] options) {
        for (Map.Entry<AttributeKey<?>, Object> e: options) {
            @SuppressWarnings("unchecked")
            AttributeKey<Object> key = (AttributeKey<Object>) e.getKey();
            channel.attr(key).set(e.getValue());
        }
    }
}
+ */ +final class Http2StreamChannelId implements ChannelId { + private static final long serialVersionUID = -6642338822166867585L; + + private final int id; + private final ChannelId parentId; + + Http2StreamChannelId(ChannelId parentId, int id) { + this.parentId = parentId; + this.id = id; + } + + @Override + public String asShortText() { + return parentId.asShortText() + '/' + id; + } + + @Override + public String asLongText() { + return parentId.asLongText() + '/' + id; + } + + @Override + public int compareTo(ChannelId o) { + if (o instanceof Http2StreamChannelId) { + Http2StreamChannelId otherId = (Http2StreamChannelId) o; + int res = parentId.compareTo(otherId.parentId); + if (res == 0) { + return id - otherId.id; + } else { + return res; + } + } + return parentId.compareTo(o); + } + + @Override + public int hashCode() { + return id * 31 + parentId.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof Http2StreamChannelId)) { + return false; + } + Http2StreamChannelId otherId = (Http2StreamChannelId) obj; + return id == otherId.id && parentId.equals(otherId.parentId); + } + + @Override + public String toString() { + return asShortText(); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrame.java new file mode 100644 index 0000000..9ad448b --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrame.java @@ -0,0 +1,38 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * A frame whose meaning may apply to a particular stream, instead of the entire connection. It is still + * possible for this frame type to apply to the entire connection. In such cases, the {@link #stream()} must return + * {@code null}. If the frame applies to a stream, the {@link Http2FrameStream#id()} must be greater than zero. + */ +@UnstableApi +public interface Http2StreamFrame extends Http2Frame { + + /** + * Set the {@link Http2FrameStream} object for this frame. + */ + Http2StreamFrame stream(Http2FrameStream stream); + + /** + * Returns the {@link Http2FrameStream} object for this frame, or {@code null} if the frame has yet to be associated + * with a stream. + */ + Http2FrameStream stream(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java new file mode 100644 index 0000000..96b4f8d --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodec.java @@ -0,0 +1,287 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.MessageToMessageCodec; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.HttpStatusClass; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.ssl.SslHandler; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.internal.UnstableApi; + +import java.util.List; + +/** + * This handler converts from {@link Http2StreamFrame} to {@link HttpObject}, + * and back. 
It can be used as an adapter in conjunction with {@link + * Http2MultiplexCodec} to make http/2 connections backward-compatible with + * {@link ChannelHandler}s expecting {@link HttpObject} + * + * For simplicity, it converts to chunked encoding unless the entire stream + * is a single header. + */ +@UnstableApi +@Sharable +public class Http2StreamFrameToHttpObjectCodec extends MessageToMessageCodec { + + private static final AttributeKey SCHEME_ATTR_KEY = + AttributeKey.valueOf(HttpScheme.class, "STREAMFRAMECODEC_SCHEME"); + + private final boolean isServer; + private final boolean validateHeaders; + + public Http2StreamFrameToHttpObjectCodec(final boolean isServer, + final boolean validateHeaders) { + this.isServer = isServer; + this.validateHeaders = validateHeaders; + } + + public Http2StreamFrameToHttpObjectCodec(final boolean isServer) { + this(isServer, true); + } + + @Override + public boolean acceptInboundMessage(Object msg) throws Exception { + return msg instanceof Http2HeadersFrame || msg instanceof Http2DataFrame; + } + + @Override + protected void decode(ChannelHandlerContext ctx, Http2StreamFrame frame, List out) throws Exception { + if (frame instanceof Http2HeadersFrame) { + Http2HeadersFrame headersFrame = (Http2HeadersFrame) frame; + Http2Headers headers = headersFrame.headers(); + Http2FrameStream stream = headersFrame.stream(); + int id = stream == null ? 0 : stream.id(); + + final CharSequence status = headers.status(); + + // 1xx response (excluding 101) is a special case where Http2HeadersFrame#isEndStream=false + // but we need to decode it as a FullHttpResponse to play nice with HttpObjectAggregator. 
+ if (null != status && isInformationalResponseHeaderFrame(status)) { + final FullHttpMessage fullMsg = newFullMessage(id, headers, ctx.alloc()); + out.add(fullMsg); + return; + } + + if (headersFrame.isEndStream()) { + if (headers.method() == null && status == null) { + LastHttpContent last = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, validateHeaders); + HttpConversionUtil.addHttp2ToHttpHeaders(id, headers, last.trailingHeaders(), + HttpVersion.HTTP_1_1, true, true); + out.add(last); + } else { + FullHttpMessage full = newFullMessage(id, headers, ctx.alloc()); + out.add(full); + } + } else { + HttpMessage req = newMessage(id, headers); + if ((status == null || !isContentAlwaysEmpty(status)) && !HttpUtil.isContentLengthSet(req)) { + req.headers().add(HttpHeaderNames.TRANSFER_ENCODING, HttpHeaderValues.CHUNKED); + } + out.add(req); + } + } else if (frame instanceof Http2DataFrame) { + Http2DataFrame dataFrame = (Http2DataFrame) frame; + if (dataFrame.isEndStream()) { + out.add(new DefaultLastHttpContent(dataFrame.content().retain(), validateHeaders)); + } else { + out.add(new DefaultHttpContent(dataFrame.content().retain())); + } + } + } + + private void encodeLastContent(LastHttpContent last, List out) { + boolean needFiller = !(last instanceof FullHttpMessage) && last.trailingHeaders().isEmpty(); + if (last.content().isReadable() || needFiller) { + out.add(new DefaultHttp2DataFrame(last.content().retain(), last.trailingHeaders().isEmpty())); + } + if (!last.trailingHeaders().isEmpty()) { + Http2Headers headers = HttpConversionUtil.toHttp2Headers(last.trailingHeaders(), validateHeaders); + out.add(new DefaultHttp2HeadersFrame(headers, true)); + } + } + + /** + * Encode from an {@link HttpObject} to an {@link Http2StreamFrame}. This method will + * be called for each written message that can be handled by this encoder. + * + * NOTE: 100-Continue responses that are NOT {@link FullHttpResponse} will be rejected. 
+ * + * @param ctx the {@link ChannelHandlerContext} which this handler belongs to + * @param obj the {@link HttpObject} message to encode + * @param out the {@link List} into which the encoded msg should be added + * needs to do some kind of aggregation + * @throws Exception is thrown if an error occurs + */ + @Override + protected void encode(ChannelHandlerContext ctx, HttpObject obj, List out) throws Exception { + // 1xx (excluding 101) is typically a FullHttpResponse, but the decoded + // Http2HeadersFrame should not be marked as endStream=true + if (obj instanceof HttpResponse) { + final HttpResponse res = (HttpResponse) obj; + final HttpResponseStatus status = res.status(); + final int code = status.code(); + final HttpStatusClass statusClass = status.codeClass(); + // An informational response using a 1xx status code other than 101 is + // transmitted as a HEADERS frame + if (statusClass == HttpStatusClass.INFORMATIONAL && code != 101) { + if (res instanceof FullHttpResponse) { + final Http2Headers headers = toHttp2Headers(ctx, res); + out.add(new DefaultHttp2HeadersFrame(headers, false)); + return; + } else { + throw new EncoderException(status + " must be a FullHttpResponse"); + } + } + } + + if (obj instanceof HttpMessage) { + Http2Headers headers = toHttp2Headers(ctx, (HttpMessage) obj); + boolean noMoreFrames = false; + if (obj instanceof FullHttpMessage) { + FullHttpMessage full = (FullHttpMessage) obj; + noMoreFrames = !full.content().isReadable() && full.trailingHeaders().isEmpty(); + } + + out.add(new DefaultHttp2HeadersFrame(headers, noMoreFrames)); + } + + if (obj instanceof LastHttpContent) { + LastHttpContent last = (LastHttpContent) obj; + encodeLastContent(last, out); + } else if (obj instanceof HttpContent) { + HttpContent cont = (HttpContent) obj; + out.add(new DefaultHttp2DataFrame(cont.content().retain(), false)); + } + } + + private Http2Headers toHttp2Headers(final ChannelHandlerContext ctx, final HttpMessage msg) { + if (msg instanceof 
HttpRequest) { + msg.headers().set( + HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), + connectionScheme(ctx)); + } + + return HttpConversionUtil.toHttp2Headers(msg, validateHeaders); + } + + private HttpMessage newMessage(final int id, + final Http2Headers headers) throws Http2Exception { + return isServer ? + HttpConversionUtil.toHttpRequest(id, headers, validateHeaders) : + HttpConversionUtil.toHttpResponse(id, headers, validateHeaders); + } + + private FullHttpMessage newFullMessage(final int id, + final Http2Headers headers, + final ByteBufAllocator alloc) throws Http2Exception { + return isServer ? + HttpConversionUtil.toFullHttpRequest(id, headers, alloc, validateHeaders) : + HttpConversionUtil.toFullHttpResponse(id, headers, alloc, validateHeaders); + } + + @Override + public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + + // this handler is typically used on an Http2StreamChannel. At this + // stage, ssl handshake should've been established. checking for the + // presence of SslHandler in the parent's channel pipeline to + // determine the HTTP scheme should suffice, even for the case where + // SniHandler is used. + final Attribute schemeAttribute = connectionSchemeAttribute(ctx); + if (schemeAttribute.get() == null) { + final HttpScheme scheme = isSsl(ctx) ? HttpScheme.HTTPS : HttpScheme.HTTP; + schemeAttribute.set(scheme); + } + } + + protected boolean isSsl(final ChannelHandlerContext ctx) { + final Channel connChannel = connectionChannel(ctx); + return null != connChannel.pipeline().get(SslHandler.class); + } + + private static HttpScheme connectionScheme(ChannelHandlerContext ctx) { + final HttpScheme scheme = connectionSchemeAttribute(ctx).get(); + return scheme == null ? 
HttpScheme.HTTP : scheme; + } + + private static Attribute connectionSchemeAttribute(ChannelHandlerContext ctx) { + final Channel ch = connectionChannel(ctx); + return ch.attr(SCHEME_ATTR_KEY); + } + + private static Channel connectionChannel(ChannelHandlerContext ctx) { + final Channel ch = ctx.channel(); + return ch instanceof Http2StreamChannel ? ch.parent() : ch; + } + + /** + * An informational response using a 1xx status code other than 101 is + * transmitted as a HEADERS frame + */ + private static boolean isInformationalResponseHeaderFrame(CharSequence status) { + if (status.length() == 3) { + char char0 = status.charAt(0); + char char1 = status.charAt(1); + char char2 = status.charAt(2); + return char0 == '1' + && char1 >= '0' && char1 <= '9' + && char2 >= '0' && char2 <= '9' && char2 != '1'; + } + return false; + } + + /* + * https://datatracker.ietf.org/doc/html/rfc9113#section-8.1.1 + * '204' or '304' responses contain no content + */ + private static boolean isContentAlwaysEmpty(CharSequence status) { + if (status.length() == 3) { + char char0 = status.charAt(0); + char char1 = status.charAt(1); + char char2 = status.charAt(2); + return (char0 == '2' || char0 == '3') + && char1 == '0' + && char2 == '4'; + } + return false; + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamVisitor.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamVisitor.java new file mode 100644 index 0000000..229cd1b --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2StreamVisitor.java @@ -0,0 +1,31 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * A visitor that allows iteration over a collection of streams. + */ +@UnstableApi +public interface Http2StreamVisitor { + /** + * @return
+ *     <ul>
+ *     <li>{@code true} if the visitor wants to continue the loop and handle the entry.</li>
+ *     <li>{@code false} if the visitor wants to stop handling headers and abort the loop.</li>
+ *     </ul>
+ */ + boolean visit(Http2Stream stream) throws Http2Exception; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java new file mode 100644 index 0000000..9ce1ecd --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2UnknownFrame.java @@ -0,0 +1,58 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.util.internal.UnstableApi; + +@UnstableApi +public interface Http2UnknownFrame extends Http2StreamFrame, ByteBufHolder { + + @Override + Http2FrameStream stream(); + + @Override + Http2UnknownFrame stream(Http2FrameStream stream); + + byte frameType(); + + Http2Flags flags(); + + @Override + Http2UnknownFrame copy(); + + @Override + Http2UnknownFrame duplicate(); + + @Override + Http2UnknownFrame retainedDuplicate(); + + @Override + Http2UnknownFrame replace(ByteBuf content); + + @Override + Http2UnknownFrame retain(); + + @Override + Http2UnknownFrame retain(int increment); + + @Override + Http2UnknownFrame touch(); + + @Override + Http2UnknownFrame touch(Object hint); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2WindowUpdateFrame.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2WindowUpdateFrame.java new file mode 100644 index 0000000..cdd015e --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/Http2WindowUpdateFrame.java @@ -0,0 +1,30 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * HTTP/2 WINDOW_UPDATE frame. 
+ */ +@UnstableApi +public interface Http2WindowUpdateFrame extends Http2StreamFrame { + + /** + * Number of bytes to increment the HTTP/2 stream's or connection's flow control window. + */ + int windowSizeIncrement(); +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java new file mode 100644 index 0000000..d17a17e --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpConversionUtil.java @@ -0,0 +1,710 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.codec.UnsupportedValueConverter; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.AsciiString; +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.UnstableApi; + +import java.net.URI; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.COOKIE; +import static io.netty.handler.codec.http.HttpHeaderNames.TE; +import static io.netty.handler.codec.http.HttpHeaderValues.TRAILERS; +import static io.netty.handler.codec.http.HttpResponseStatus.parseLine; +import static io.netty.handler.codec.http.HttpScheme.HTTP; +import static io.netty.handler.codec.http.HttpScheme.HTTPS; +import static io.netty.handler.codec.http.HttpUtil.isAsteriskForm; +import static io.netty.handler.codec.http.HttpUtil.isOriginForm; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static 
io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http2.Http2Exception.streamError; +import static io.netty.util.AsciiString.EMPTY_STRING; +import static io.netty.util.AsciiString.contentEqualsIgnoreCase; +import static io.netty.util.AsciiString.indexOf; +import static io.netty.util.AsciiString.trim; +import static io.netty.util.ByteProcessor.FIND_COMMA; +import static io.netty.util.ByteProcessor.FIND_SEMI_COLON; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.StringUtil.isNullOrEmpty; +import static io.netty.util.internal.StringUtil.length; +import static io.netty.util.internal.StringUtil.unescapeCsvFields; + +/** + * Provides utility methods and constants for the HTTP/2 to HTTP conversion + */ +@UnstableApi +public final class HttpConversionUtil { + /** + * The set of headers that should not be directly copied when converting headers from HTTP to HTTP/2. + */ + private static final CharSequenceMap HTTP_TO_HTTP2_HEADER_BLACKLIST = + new CharSequenceMap(); + static { + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(CONNECTION, EMPTY_STRING); + @SuppressWarnings("deprecation") + AsciiString keepAlive = HttpHeaderNames.KEEP_ALIVE; + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(keepAlive, EMPTY_STRING); + @SuppressWarnings("deprecation") + AsciiString proxyConnection = HttpHeaderNames.PROXY_CONNECTION; + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(proxyConnection, EMPTY_STRING); + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(HttpHeaderNames.TRANSFER_ENCODING, EMPTY_STRING); + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(HttpHeaderNames.HOST, EMPTY_STRING); + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(HttpHeaderNames.UPGRADE, EMPTY_STRING); + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(ExtensionHeaderNames.STREAM_ID.text(), EMPTY_STRING); + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(ExtensionHeaderNames.SCHEME.text(), EMPTY_STRING); + HTTP_TO_HTTP2_HEADER_BLACKLIST.add(ExtensionHeaderNames.PATH.text(), EMPTY_STRING); + } + + /** + * This will 
be the method used for {@link HttpRequest} objects generated out of the HTTP message flow defined in [RFC 7540], Section 8.1 + */ + public static final HttpMethod OUT_OF_MESSAGE_SEQUENCE_METHOD = HttpMethod.OPTIONS; + + /** + * This will be the path used for {@link HttpRequest} objects generated out of the HTTP message flow defined in [RFC 7540], Section 8.1 + */ + public static final String OUT_OF_MESSAGE_SEQUENCE_PATH = ""; + + /** + * This will be the status code used for {@link HttpResponse} objects generated out of the HTTP message flow defined + * in [RFC 7540], Section 8.1 + */ + public static final HttpResponseStatus OUT_OF_MESSAGE_SEQUENCE_RETURN_CODE = HttpResponseStatus.OK; + + /** + * [RFC 7540], 8.1.2.3 states the path must not + * be empty, and instead should be {@code /}. + */ + private static final AsciiString EMPTY_REQUEST_PATH = AsciiString.cached("/"); + + private HttpConversionUtil() { + } + + /** + * Provides the HTTP header extensions used to carry HTTP/2 information in HTTP objects + */ + public enum ExtensionHeaderNames { + /** + * HTTP extension header which will identify the stream id from the HTTP/2 event(s) responsible for + * generating an {@code HttpObject} + *

+ * {@code "x-http2-stream-id"} + */ + STREAM_ID("x-http2-stream-id"), + /** + * HTTP extension header which will identify the scheme pseudo header from the HTTP/2 event(s) responsible for + * generating an {@code HttpObject} + *

+ * {@code "x-http2-scheme"} + */ + SCHEME("x-http2-scheme"), + /** + * HTTP extension header which will identify the path pseudo header from the HTTP/2 event(s) responsible for + * generating an {@code HttpObject} + *

+ * {@code "x-http2-path"} + */ + PATH("x-http2-path"), + /** + * HTTP extension header which will identify the stream id used to create this stream in an HTTP/2 push promise + * frame + *

+ * {@code "x-http2-stream-promise-id"} + */ + STREAM_PROMISE_ID("x-http2-stream-promise-id"), + /** + * HTTP extension header which will identify the stream id which this stream is dependent on. This stream will + * be a child node of the stream id associated with this header value. + *

+ * {@code "x-http2-stream-dependency-id"} + */ + STREAM_DEPENDENCY_ID("x-http2-stream-dependency-id"), + /** + * HTTP extension header which will identify the weight (if non-default and the priority is not on the default + * stream) of the associated HTTP/2 stream responsible responsible for generating an {@code HttpObject} + *

+ * {@code "x-http2-stream-weight"} + */ + STREAM_WEIGHT("x-http2-stream-weight"); + + private final AsciiString text; + + ExtensionHeaderNames(String text) { + this.text = AsciiString.cached(text); + } + + public AsciiString text() { + return text; + } + } + + /** + * Apply HTTP/2 rules while translating status code to {@link HttpResponseStatus} + * + * @param status The status from an HTTP/2 frame + * @return The HTTP/1.x status + * @throws Http2Exception If there is a problem translating from HTTP/2 to HTTP/1.x + */ + public static HttpResponseStatus parseStatus(CharSequence status) throws Http2Exception { + HttpResponseStatus result; + try { + result = parseLine(status); + if (result == HttpResponseStatus.SWITCHING_PROTOCOLS) { + throw connectionError(PROTOCOL_ERROR, "Invalid HTTP/2 status code '%d'", result.code()); + } + } catch (Http2Exception e) { + throw e; + } catch (Throwable t) { + throw connectionError(PROTOCOL_ERROR, t, + "Unrecognized HTTP status code '%s' encountered in translation to HTTP/1.x", status); + } + return result; + } + + /** + * Create a new object to contain the response data + * + * @param streamId The stream associated with the response + * @param http2Headers The initial set of HTTP/2 headers to create the response with + * @param alloc The {@link ByteBufAllocator} to use to generate the content of the message + * @param validateHttpHeaders

+ *        <ul>
+ *        <li>{@code true} to validate HTTP headers in the http-codec</li>
+ *        <li>{@code false} not to validate HTTP headers in the http-codec</li>
+ *        </ul>
+ * @return A new response object which represents headers/data + * @throws Http2Exception see {@link #addHttp2ToHttpHeaders(int, Http2Headers, FullHttpMessage, boolean)} + */ + public static FullHttpResponse toFullHttpResponse(int streamId, Http2Headers http2Headers, ByteBufAllocator alloc, + boolean validateHttpHeaders) throws Http2Exception { + return toFullHttpResponse(streamId, http2Headers, alloc.buffer(), validateHttpHeaders); + } + + /** + * Create a new object to contain the response data + * + * @param streamId The stream associated with the response + * @param http2Headers The initial set of HTTP/2 headers to create the response with + * @param content {@link ByteBuf} content to put in {@link FullHttpResponse} + * @param validateHttpHeaders
+ *        <ul>
+ *        <li>{@code true} to validate HTTP headers in the http-codec</li>
+ *        <li>{@code false} not to validate HTTP headers in the http-codec</li>
+ *        </ul>
+ * @return A new response object which represents headers/data + * @throws Http2Exception see {@link #addHttp2ToHttpHeaders(int, Http2Headers, FullHttpMessage, boolean)} + */ + public static FullHttpResponse toFullHttpResponse(int streamId, Http2Headers http2Headers, ByteBuf content, + boolean validateHttpHeaders) + throws Http2Exception { + HttpResponseStatus status = parseStatus(http2Headers.status()); + // HTTP/2 does not define a way to carry the version or reason phrase that is included in an + // HTTP/1.1 status line. + FullHttpResponse msg = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, status, content, + validateHttpHeaders); + try { + addHttp2ToHttpHeaders(streamId, http2Headers, msg, false); + } catch (Http2Exception e) { + msg.release(); + throw e; + } catch (Throwable t) { + msg.release(); + throw streamError(streamId, PROTOCOL_ERROR, t, "HTTP/2 to HTTP/1.x headers conversion error"); + } + return msg; + } + + /** + * Create a new object to contain the request data + * + * @param streamId The stream associated with the request + * @param http2Headers The initial set of HTTP/2 headers to create the request with + * @param alloc The {@link ByteBufAllocator} to use to generate the content of the message + * @param validateHttpHeaders
+ *        <ul>
+ *        <li>{@code true} to validate HTTP headers in the http-codec</li>
+ *        <li>{@code false} not to validate HTTP headers in the http-codec</li>
+ *        </ul>
+ * @return A new request object which represents headers/data + * @throws Http2Exception see {@link #addHttp2ToHttpHeaders(int, Http2Headers, FullHttpMessage, boolean)} + */ + public static FullHttpRequest toFullHttpRequest(int streamId, Http2Headers http2Headers, ByteBufAllocator alloc, + boolean validateHttpHeaders) throws Http2Exception { + return toFullHttpRequest(streamId, http2Headers, alloc.buffer(), validateHttpHeaders); + } + + private static String extractPath(CharSequence method, Http2Headers headers) { + if (HttpMethod.CONNECT.asciiName().contentEqualsIgnoreCase(method)) { + // See https://tools.ietf.org/html/rfc7231#section-4.3.6 + return checkNotNull(headers.authority(), + "authority header cannot be null in the conversion to HTTP/1.x").toString(); + } else { + return checkNotNull(headers.path(), + "path header cannot be null in conversion to HTTP/1.x").toString(); + } + } + + /** + * Create a new object to contain the request data + * + * @param streamId The stream associated with the request + * @param http2Headers The initial set of HTTP/2 headers to create the request with + * @param content {@link ByteBuf} content to put in {@link FullHttpRequest} + * @param validateHttpHeaders
    + *
  • {@code true} to validate HTTP headers in the http-codec
  • + *
  • {@code false} not to validate HTTP headers in the http-codec
  • + *
+ * @return A new request object which represents headers/data + * @throws Http2Exception see {@link #addHttp2ToHttpHeaders(int, Http2Headers, FullHttpMessage, boolean)} + */ + public static FullHttpRequest toFullHttpRequest(int streamId, Http2Headers http2Headers, ByteBuf content, + boolean validateHttpHeaders) throws Http2Exception { + // HTTP/2 does not define a way to carry the version identifier that is included in the HTTP/1.1 request line. + final CharSequence method = checkNotNull(http2Headers.method(), + "method header cannot be null in conversion to HTTP/1.x"); + final CharSequence path = extractPath(method, http2Headers); + FullHttpRequest msg = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.valueOf(method + .toString()), path.toString(), content, validateHttpHeaders); + try { + addHttp2ToHttpHeaders(streamId, http2Headers, msg, false); + } catch (Http2Exception e) { + msg.release(); + throw e; + } catch (Throwable t) { + msg.release(); + throw streamError(streamId, PROTOCOL_ERROR, t, "HTTP/2 to HTTP/1.x headers conversion error"); + } + return msg; + } + + /** + * Create a new object to contain the request data. + * + * @param streamId The stream associated with the request + * @param http2Headers The initial set of HTTP/2 headers to create the request with + * @param validateHttpHeaders
    + *
  • {@code true} to validate HTTP headers in the http-codec
  • + *
  • {@code false} not to validate HTTP headers in the http-codec
  • + *
+ * @return A new request object which represents headers for a chunked request + * @throws Http2Exception see {@link #addHttp2ToHttpHeaders(int, Http2Headers, FullHttpMessage, boolean)} + */ + public static HttpRequest toHttpRequest(int streamId, Http2Headers http2Headers, boolean validateHttpHeaders) + throws Http2Exception { + // HTTP/2 does not define a way to carry the version identifier that is included in the HTTP/1.1 request line. + final CharSequence method = checkNotNull(http2Headers.method(), + "method header cannot be null in conversion to HTTP/1.x"); + final CharSequence path = extractPath(method, http2Headers); + HttpRequest msg = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.valueOf(method.toString()), + path.toString(), validateHttpHeaders); + try { + addHttp2ToHttpHeaders(streamId, http2Headers, msg.headers(), msg.protocolVersion(), false, true); + } catch (Http2Exception e) { + throw e; + } catch (Throwable t) { + throw streamError(streamId, PROTOCOL_ERROR, t, "HTTP/2 to HTTP/1.x headers conversion error"); + } + return msg; + } + + /** + * Create a new object to contain the response data. + * + * @param streamId The stream associated with the response + * @param http2Headers The initial set of HTTP/2 headers to create the response with + * @param validateHttpHeaders
    + *
  • {@code true} to validate HTTP headers in the http-codec
  • + *
  • {@code false} not to validate HTTP headers in the http-codec
  • + *
+ * @return A new response object which represents headers for a chunked response + * @throws Http2Exception see {@link #addHttp2ToHttpHeaders(int, Http2Headers, + * HttpHeaders, HttpVersion, boolean, boolean)} + */ + public static HttpResponse toHttpResponse(final int streamId, + final Http2Headers http2Headers, + final boolean validateHttpHeaders) throws Http2Exception { + final HttpResponseStatus status = parseStatus(http2Headers.status()); + // HTTP/2 does not define a way to carry the version or reason phrase that is included in an + // HTTP/1.1 status line. + final HttpResponse msg = new DefaultHttpResponse(HttpVersion.HTTP_1_1, status, validateHttpHeaders); + try { + addHttp2ToHttpHeaders(streamId, http2Headers, msg.headers(), msg.protocolVersion(), false, false); + } catch (final Http2Exception e) { + throw e; + } catch (final Throwable t) { + throw streamError(streamId, PROTOCOL_ERROR, t, "HTTP/2 to HTTP/1.x headers conversion error"); + } + return msg; + } + + /** + * Translate and add HTTP/2 headers to HTTP/1.x headers. + * + * @param streamId The stream associated with {@code sourceHeaders}. + * @param inputHeaders The HTTP/2 headers to convert. + * @param destinationMessage The object which will contain the resulting HTTP/1.x headers. + * @param addToTrailer {@code true} to add to trailing headers. {@code false} to add to initial headers. + * @throws Http2Exception If not all HTTP/2 headers can be translated to HTTP/1.x. + * @see #addHttp2ToHttpHeaders(int, Http2Headers, HttpHeaders, HttpVersion, boolean, boolean) + */ + public static void addHttp2ToHttpHeaders(int streamId, Http2Headers inputHeaders, + FullHttpMessage destinationMessage, boolean addToTrailer) throws Http2Exception { + addHttp2ToHttpHeaders(streamId, inputHeaders, + addToTrailer ? 
destinationMessage.trailingHeaders() : destinationMessage.headers(), + destinationMessage.protocolVersion(), addToTrailer, destinationMessage instanceof HttpRequest); + } + + /** + * Translate and add HTTP/2 headers to HTTP/1.x headers. + * + * @param streamId The stream associated with {@code sourceHeaders}. + * @param inputHeaders The HTTP/2 headers to convert. + * @param outputHeaders The object which will contain the resulting HTTP/1.x headers. + * @param httpVersion What HTTP/1.x version {@code outputHeaders} should be treated as when doing the conversion. + * @param isTrailer {@code true} if {@code outputHeaders} should be treated as trailing headers. + * {@code false} otherwise. + * @param isRequest {@code true} if the {@code outputHeaders} will be used in a request message. + * {@code false} for response message. + * @throws Http2Exception If not all HTTP/2 headers can be translated to HTTP/1.x. + */ + public static void addHttp2ToHttpHeaders(int streamId, Http2Headers inputHeaders, HttpHeaders outputHeaders, + HttpVersion httpVersion, boolean isTrailer, boolean isRequest) throws Http2Exception { + Http2ToHttpHeaderTranslator translator = new Http2ToHttpHeaderTranslator(streamId, outputHeaders, isRequest); + try { + translator.translateHeaders(inputHeaders); + } catch (Http2Exception ex) { + throw ex; + } catch (Throwable t) { + throw streamError(streamId, PROTOCOL_ERROR, t, "HTTP/2 to HTTP/1.x headers conversion error"); + } + + outputHeaders.remove(HttpHeaderNames.TRANSFER_ENCODING); + outputHeaders.remove(HttpHeaderNames.TRAILER); + if (!isTrailer) { + outputHeaders.setInt(ExtensionHeaderNames.STREAM_ID.text(), streamId); + HttpUtil.setKeepAlive(outputHeaders, httpVersion, true); + } + } + + /** + * Converts the given HTTP/1.x headers into HTTP/2 headers. + * The following headers are only used if they cannot be found in the {@code HOST} header or the + * {@code Request-Line} as defined by rfc7230 + *
    + *
  • {@link ExtensionHeaderNames#SCHEME}
  • + *
+ * {@link ExtensionHeaderNames#PATH} is ignored and instead extracted from the {@code Request-Line}. + */ + public static Http2Headers toHttp2Headers(HttpMessage in, boolean validateHeaders) { + HttpHeaders inHeaders = in.headers(); + final Http2Headers out = new DefaultHttp2Headers(validateHeaders, inHeaders.size()); + if (in instanceof HttpRequest) { + HttpRequest request = (HttpRequest) in; + String host = inHeaders.getAsString(HttpHeaderNames.HOST); + if (isOriginForm(request.uri()) || isAsteriskForm(request.uri())) { + out.path(new AsciiString(request.uri())); + setHttp2Scheme(inHeaders, out); + } else { + URI requestTargetUri = URI.create(request.uri()); + out.path(toHttp2Path(requestTargetUri)); + // Take from the request-line if HOST header was empty + host = isNullOrEmpty(host) ? requestTargetUri.getAuthority() : host; + setHttp2Scheme(inHeaders, requestTargetUri, out); + } + setHttp2Authority(host, out); + out.method(request.method().asciiName()); + } else if (in instanceof HttpResponse) { + HttpResponse response = (HttpResponse) in; + out.status(response.status().codeAsText()); + } + + // Add the HTTP headers which have not been consumed above + toHttp2Headers(inHeaders, out); + return out; + } + + public static Http2Headers toHttp2Headers(HttpHeaders inHeaders, boolean validateHeaders) { + if (inHeaders.isEmpty()) { + return EmptyHttp2Headers.INSTANCE; + } + + final Http2Headers out = new DefaultHttp2Headers(validateHeaders, inHeaders.size()); + toHttp2Headers(inHeaders, out); + return out; + } + + private static CharSequenceMap toLowercaseMap(Iterator valuesIter, + int arraySizeHint) { + UnsupportedValueConverter valueConverter = UnsupportedValueConverter.instance(); + CharSequenceMap result = new CharSequenceMap(true, valueConverter, arraySizeHint); + + while (valuesIter.hasNext()) { + AsciiString lowerCased = AsciiString.of(valuesIter.next()).toLowerCase(); + try { + int index = lowerCased.forEachByte(FIND_COMMA); + if (index != -1) { + int start = 
0; + do { + result.add(lowerCased.subSequence(start, index, false).trim(), EMPTY_STRING); + start = index + 1; + } while (start < lowerCased.length() && + (index = lowerCased.forEachByte(start, lowerCased.length() - start, FIND_COMMA)) != -1); + result.add(lowerCased.subSequence(start, lowerCased.length(), false).trim(), EMPTY_STRING); + } else { + result.add(lowerCased.trim(), EMPTY_STRING); + } + } catch (Exception e) { + // This is not expected to happen because FIND_COMMA never throws but must be caught + // because of the ByteProcessor interface. + throw new IllegalStateException(e); + } + } + return result; + } + + /** + * Filter the {@link HttpHeaderNames#TE} header according to the + * special rules in the HTTP/2 RFC. + * @param entry An entry whose name is {@link HttpHeaderNames#TE}. + * @param out the resulting HTTP/2 headers. + */ + private static void toHttp2HeadersFilterTE(Entry entry, + Http2Headers out) { + if (indexOf(entry.getValue(), ',', 0) == -1) { + if (contentEqualsIgnoreCase(trim(entry.getValue()), TRAILERS)) { + out.add(TE, TRAILERS); + } + } else { + List teValues = unescapeCsvFields(entry.getValue()); + for (CharSequence teValue : teValues) { + if (contentEqualsIgnoreCase(trim(teValue), TRAILERS)) { + out.add(TE, TRAILERS); + break; + } + } + } + } + + public static void toHttp2Headers(HttpHeaders inHeaders, Http2Headers out) { + Iterator> iter = inHeaders.iteratorCharSequence(); + // Choose 8 as a default size because it is unlikely we will see more than 4 Connection headers values, but + // still allowing for "enough" space in the map to reduce the chance of hash code collision. 
+ CharSequenceMap connectionBlacklist = + toLowercaseMap(inHeaders.valueCharSequenceIterator(CONNECTION), 8); + while (iter.hasNext()) { + Entry entry = iter.next(); + final AsciiString aName = AsciiString.of(entry.getKey()).toLowerCase(); + if (!HTTP_TO_HTTP2_HEADER_BLACKLIST.contains(aName) && !connectionBlacklist.contains(aName)) { + // https://tools.ietf.org/html/rfc7540#section-8.1.2.2 makes a special exception for TE + if (aName.contentEqualsIgnoreCase(TE)) { + toHttp2HeadersFilterTE(entry, out); + } else if (aName.contentEqualsIgnoreCase(COOKIE)) { + AsciiString value = AsciiString.of(entry.getValue()); + // split up cookies to allow for better compression + // https://tools.ietf.org/html/rfc7540#section-8.1.2.5 + try { + int index = value.forEachByte(FIND_SEMI_COLON); + if (index != -1) { + int start = 0; + do { + out.add(COOKIE, value.subSequence(start, index, false)); + // skip 2 characters "; " (see https://tools.ietf.org/html/rfc6265#section-4.2.1) + start = index + 2; + } while (start < value.length() && + (index = value.forEachByte(start, value.length() - start, FIND_SEMI_COLON)) != -1); + if (start >= value.length()) { + throw new IllegalArgumentException("cookie value is of unexpected format: " + value); + } + out.add(COOKIE, value.subSequence(start, value.length(), false)); + } else { + out.add(COOKIE, value); + } + } catch (Exception e) { + // This is not expected to happen because FIND_SEMI_COLON never throws but must be caught + // because of the ByteProcessor interface. + throw new IllegalStateException(e); + } + } else { + out.add(aName, entry.getValue()); + } + } + } + } + + /** + * Generate an HTTP/2 {@code :path} from a URI in accordance with + * rfc7230, 5.3. 
+ */ + private static AsciiString toHttp2Path(URI uri) { + StringBuilder pathBuilder = new StringBuilder(length(uri.getRawPath()) + + length(uri.getRawQuery()) + length(uri.getRawFragment()) + 2); + if (!isNullOrEmpty(uri.getRawPath())) { + pathBuilder.append(uri.getRawPath()); + } + if (!isNullOrEmpty(uri.getRawQuery())) { + pathBuilder.append('?'); + pathBuilder.append(uri.getRawQuery()); + } + if (!isNullOrEmpty(uri.getRawFragment())) { + pathBuilder.append('#'); + pathBuilder.append(uri.getRawFragment()); + } + String path = pathBuilder.toString(); + return path.isEmpty() ? EMPTY_REQUEST_PATH : new AsciiString(path); + } + + // package-private for testing only + static void setHttp2Authority(String authority, Http2Headers out) { + // The authority MUST NOT include the deprecated "userinfo" subcomponent + if (authority != null) { + if (authority.isEmpty()) { + out.authority(EMPTY_STRING); + } else { + int start = authority.indexOf('@') + 1; + int length = authority.length() - start; + if (length == 0) { + throw new IllegalArgumentException("authority: " + authority); + } + out.authority(new AsciiString(authority, start, length)); + } + } + } + + private static void setHttp2Scheme(HttpHeaders in, Http2Headers out) { + setHttp2Scheme(in, URI.create(""), out); + } + + private static void setHttp2Scheme(HttpHeaders in, URI uri, Http2Headers out) { + String value = uri.getScheme(); + if (!isNullOrEmpty(value)) { + out.scheme(new AsciiString(value)); + return; + } + + // Consume the Scheme extension header if present + CharSequence cValue = in.get(ExtensionHeaderNames.SCHEME.text()); + if (cValue != null) { + out.scheme(AsciiString.of(cValue)); + return; + } + + if (uri.getPort() == HTTPS.port()) { + out.scheme(HTTPS.name()); + } else if (uri.getPort() == HTTP.port()) { + out.scheme(HTTP.name()); + } else { + throw new IllegalArgumentException(":scheme must be specified. 
" + + "see https://tools.ietf.org/html/rfc7540#section-8.1.2.3"); + } + } + + /** + * Utility which translates HTTP/2 headers to HTTP/1 headers. + */ + private static final class Http2ToHttpHeaderTranslator { + /** + * Translations from HTTP/2 header name to the HTTP/1.x equivalent. + */ + private static final CharSequenceMap + REQUEST_HEADER_TRANSLATIONS = new CharSequenceMap(); + private static final CharSequenceMap + RESPONSE_HEADER_TRANSLATIONS = new CharSequenceMap(); + static { + RESPONSE_HEADER_TRANSLATIONS.add(Http2Headers.PseudoHeaderName.AUTHORITY.value(), + HttpHeaderNames.HOST); + RESPONSE_HEADER_TRANSLATIONS.add(Http2Headers.PseudoHeaderName.SCHEME.value(), + ExtensionHeaderNames.SCHEME.text()); + REQUEST_HEADER_TRANSLATIONS.add(RESPONSE_HEADER_TRANSLATIONS); + RESPONSE_HEADER_TRANSLATIONS.add(Http2Headers.PseudoHeaderName.PATH.value(), + ExtensionHeaderNames.PATH.text()); + } + + private final int streamId; + private final HttpHeaders output; + private final CharSequenceMap translations; + + /** + * Create a new instance + * + * @param output The HTTP/1.x headers object to store the results of the translation + * @param request if {@code true}, translates headers using the request translation map. Otherwise uses the + * response translation map. + */ + Http2ToHttpHeaderTranslator(int streamId, HttpHeaders output, boolean request) { + this.streamId = streamId; + this.output = output; + translations = request ? 
REQUEST_HEADER_TRANSLATIONS : RESPONSE_HEADER_TRANSLATIONS; + } + + void translateHeaders(Iterable> inputHeaders) throws Http2Exception { + // lazily created as needed + StringBuilder cookies = null; + + for (Entry entry : inputHeaders) { + final CharSequence name = entry.getKey(); + final CharSequence value = entry.getValue(); + AsciiString translatedName = translations.get(name); + if (translatedName != null) { + output.add(translatedName, AsciiString.of(value)); + } else if (!Http2Headers.PseudoHeaderName.isPseudoHeader(name)) { + // https://tools.ietf.org/html/rfc7540#section-8.1.2.3 + // All headers that start with ':' are only valid in HTTP/2 context + if (name.length() == 0 || name.charAt(0) == ':') { + throw streamError(streamId, PROTOCOL_ERROR, + "Invalid HTTP/2 header '%s' encountered in translation to HTTP/1.x", name); + } + if (COOKIE.equals(name)) { + // combine the cookie values into 1 header entry. + // https://tools.ietf.org/html/rfc7540#section-8.1.2.5 + if (cookies == null) { + cookies = InternalThreadLocalMap.get().stringBuilder(); + } else if (cookies.length() > 0) { + cookies.append("; "); + } + cookies.append(value); + } else { + output.add(name, value); + } + } + } + if (cookies != null) { + output.add(COOKIE, cookies.toString()); + } + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java new file mode 100644 index 0000000..9a5c321 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandler.java @@ -0,0 +1,166 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http.EmptyHttpHeaders; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMessage; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.UnstableApi; + +/** + * Translates HTTP/1.x object writes into HTTP/2 frames. + *

+ * See {@link InboundHttp2ToHttpAdapter} to get translation from HTTP/2 frames to HTTP/1.x objects. + */ +@UnstableApi +public class HttpToHttp2ConnectionHandler extends Http2ConnectionHandler { + + private final boolean validateHeaders; + private int currentStreamId; + private HttpScheme httpScheme; + + protected HttpToHttp2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings, boolean validateHeaders) { + super(decoder, encoder, initialSettings); + this.validateHeaders = validateHeaders; + } + + protected HttpToHttp2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings, boolean validateHeaders, + boolean decoupleCloseAndGoAway) { + this(decoder, encoder, initialSettings, validateHeaders, decoupleCloseAndGoAway, null); + } + + protected HttpToHttp2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings, boolean validateHeaders, + boolean decoupleCloseAndGoAway, HttpScheme httpScheme) { + super(decoder, encoder, initialSettings, decoupleCloseAndGoAway); + this.validateHeaders = validateHeaders; + this.httpScheme = httpScheme; + } + + protected HttpToHttp2ConnectionHandler(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings, boolean validateHeaders, + boolean decoupleCloseAndGoAway, boolean flushPreface, + HttpScheme httpScheme) { + super(decoder, encoder, initialSettings, decoupleCloseAndGoAway, flushPreface); + this.validateHeaders = validateHeaders; + this.httpScheme = httpScheme; + } + + /** + * Get the next stream id either from the {@link HttpHeaders} object or HTTP/2 codec + * + * @param httpHeaders The HTTP/1.x headers object to look for the stream id + * @return The stream id to use with this {@link HttpHeaders} object + * @throws Exception If the {@code httpHeaders} object specifies an invalid stream id + */ + private int 
getStreamId(HttpHeaders httpHeaders) throws Exception { + return httpHeaders.getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), + connection().local().incrementAndGetNextStreamId()); + } + + /** + * Handles conversion of {@link HttpMessage} and {@link HttpContent} to HTTP/2 frames. + */ + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + + if (!(msg instanceof HttpMessage || msg instanceof HttpContent)) { + ctx.write(msg, promise); + return; + } + + boolean release = true; + SimpleChannelPromiseAggregator promiseAggregator = + new SimpleChannelPromiseAggregator(promise, ctx.channel(), ctx.executor()); + try { + Http2ConnectionEncoder encoder = encoder(); + boolean endStream = false; + if (msg instanceof HttpMessage) { + final HttpMessage httpMsg = (HttpMessage) msg; + + // Provide the user the opportunity to specify the streamId + currentStreamId = getStreamId(httpMsg.headers()); + + // Add HttpScheme if it's defined in constructor and header does not contain it. + if (httpScheme != null && + !httpMsg.headers().contains(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text())) { + httpMsg.headers().set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), httpScheme.name()); + } + + // Convert and write the headers. + Http2Headers http2Headers = HttpConversionUtil.toHttp2Headers(httpMsg, validateHeaders); + endStream = msg instanceof FullHttpMessage && !((FullHttpMessage) msg).content().isReadable(); + writeHeaders(ctx, encoder, currentStreamId, httpMsg.headers(), http2Headers, + endStream, promiseAggregator); + } + + if (!endStream && msg instanceof HttpContent) { + boolean isLastContent = false; + HttpHeaders trailers = EmptyHttpHeaders.INSTANCE; + Http2Headers http2Trailers = EmptyHttp2Headers.INSTANCE; + if (msg instanceof LastHttpContent) { + isLastContent = true; + + // Convert any trailing headers. 
+ final LastHttpContent lastContent = (LastHttpContent) msg; + trailers = lastContent.trailingHeaders(); + http2Trailers = HttpConversionUtil.toHttp2Headers(trailers, validateHeaders); + } + + // Write the data + final ByteBuf content = ((HttpContent) msg).content(); + endStream = isLastContent && trailers.isEmpty(); + encoder.writeData(ctx, currentStreamId, content, 0, endStream, promiseAggregator.newPromise()); + release = false; + + if (!trailers.isEmpty()) { + // Write trailing headers. + writeHeaders(ctx, encoder, currentStreamId, trailers, http2Trailers, true, promiseAggregator); + } + } + } catch (Throwable t) { + onError(ctx, true, t); + promiseAggregator.setFailure(t); + } finally { + if (release) { + ReferenceCountUtil.release(msg); + } + promiseAggregator.doneAllocatingPromises(); + } + } + + private static void writeHeaders(ChannelHandlerContext ctx, Http2ConnectionEncoder encoder, int streamId, + HttpHeaders headers, Http2Headers http2Headers, boolean endStream, + SimpleChannelPromiseAggregator promiseAggregator) { + int dependencyId = headers.getInt( + HttpConversionUtil.ExtensionHeaderNames.STREAM_DEPENDENCY_ID.text(), 0); + short weight = headers.getShort( + HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT); + encoder.writeHeaders(ctx, streamId, http2Headers, dependencyId, weight, false, + 0, endStream, promiseAggregator.newPromise()); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerBuilder.java new file mode 100644 index 0000000..9b23a36 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerBuilder.java @@ -0,0 +1,123 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the 
"License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http2.Http2HeadersEncoder.SensitivityDetector; +import io.netty.util.internal.UnstableApi; + +/** + * Builder which builds {@link HttpToHttp2ConnectionHandler} objects. + */ +@UnstableApi +public final class HttpToHttp2ConnectionHandlerBuilder extends + AbstractHttp2ConnectionHandlerBuilder { + + private HttpScheme httpScheme; + + @Override + public HttpToHttp2ConnectionHandlerBuilder validateHeaders(boolean validateHeaders) { + return super.validateHeaders(validateHeaders); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder initialSettings(Http2Settings settings) { + return super.initialSettings(settings); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder frameListener(Http2FrameListener frameListener) { + return super.frameListener(frameListener); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder gracefulShutdownTimeoutMillis(long gracefulShutdownTimeoutMillis) { + return super.gracefulShutdownTimeoutMillis(gracefulShutdownTimeoutMillis); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder server(boolean isServer) { + return super.server(isServer); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder connection(Http2Connection connection) { + return super.connection(connection); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder codec(Http2ConnectionDecoder decoder, + 
Http2ConnectionEncoder encoder) { + return super.codec(decoder, encoder); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder frameLogger(Http2FrameLogger frameLogger) { + return super.frameLogger(frameLogger); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder encoderEnforceMaxConcurrentStreams( + boolean encoderEnforceMaxConcurrentStreams) { + return super.encoderEnforceMaxConcurrentStreams(encoderEnforceMaxConcurrentStreams); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder headerSensitivityDetector( + SensitivityDetector headerSensitivityDetector) { + return super.headerSensitivityDetector(headerSensitivityDetector); + } + + @Override + @Deprecated + public HttpToHttp2ConnectionHandlerBuilder initialHuffmanDecodeCapacity(int initialHuffmanDecodeCapacity) { + return super.initialHuffmanDecodeCapacity(initialHuffmanDecodeCapacity); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder decoupleCloseAndGoAway(boolean decoupleCloseAndGoAway) { + return super.decoupleCloseAndGoAway(decoupleCloseAndGoAway); + } + + @Override + public HttpToHttp2ConnectionHandlerBuilder flushPreface(boolean flushPreface) { + return super.flushPreface(flushPreface); + } + + /** + * Add {@code scheme} in {@link Http2Headers} if not already present. + * + * @param httpScheme {@link HttpScheme} type + * @return {@code this}. 
+ */ + public HttpToHttp2ConnectionHandlerBuilder httpScheme(HttpScheme httpScheme) { + this.httpScheme = httpScheme; + return self(); + } + + @Override + public HttpToHttp2ConnectionHandler build() { + return super.build(); + } + + @Override + protected HttpToHttp2ConnectionHandler build(Http2ConnectionDecoder decoder, Http2ConnectionEncoder encoder, + Http2Settings initialSettings) { + return new HttpToHttp2ConnectionHandler(decoder, encoder, initialSettings, isValidateHeaders(), + decoupleCloseAndGoAway(), flushPreface(), httpScheme); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java new file mode 100644 index 0000000..ca4cecf --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapter.java @@ -0,0 +1,360 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpStatusClass; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.util.internal.UnstableApi; + +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.handler.codec.http.HttpResponseStatus.OK; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * This adapter provides just header/data events from the HTTP message flow defined + * in [RFC 7540], Section 8.1. + *

+ * See {@link HttpToHttp2ConnectionHandler} to get translation from HTTP/1.x objects to HTTP/2 frames for writes. + */ +@UnstableApi +public class InboundHttp2ToHttpAdapter extends Http2EventAdapter { + private static final ImmediateSendDetector DEFAULT_SEND_DETECTOR = new ImmediateSendDetector() { + @Override + public boolean mustSendImmediately(FullHttpMessage msg) { + if (msg instanceof FullHttpResponse) { + return ((FullHttpResponse) msg).status().codeClass() == HttpStatusClass.INFORMATIONAL; + } + if (msg instanceof FullHttpRequest) { + return msg.headers().contains(HttpHeaderNames.EXPECT); + } + return false; + } + + @Override + public FullHttpMessage copyIfNeeded(ByteBufAllocator allocator, FullHttpMessage msg) { + if (msg instanceof FullHttpRequest) { + FullHttpRequest copy = ((FullHttpRequest) msg).replace(allocator.buffer(0)); + copy.headers().remove(HttpHeaderNames.EXPECT); + return copy; + } + return null; + } + }; + + private final int maxContentLength; + private final ImmediateSendDetector sendDetector; + private final Http2Connection.PropertyKey messageKey; + private final boolean propagateSettings; + protected final Http2Connection connection; + protected final boolean validateHttpHeaders; + + protected InboundHttp2ToHttpAdapter(Http2Connection connection, int maxContentLength, + boolean validateHttpHeaders, boolean propagateSettings) { + this.connection = checkNotNull(connection, "connection"); + this.maxContentLength = checkPositive(maxContentLength, "maxContentLength"); + this.validateHttpHeaders = validateHttpHeaders; + this.propagateSettings = propagateSettings; + sendDetector = DEFAULT_SEND_DETECTOR; + messageKey = connection.newKey(); + } + + /** + * The stream is out of scope for the HTTP message flow and will no longer be tracked + * @param stream The stream to remove associated state with + * @param release {@code true} to call release on the value if it is present. {@code false} to not call release. 
+ */ + protected final void removeMessage(Http2Stream stream, boolean release) { + FullHttpMessage msg = stream.removeProperty(messageKey); + if (release && msg != null) { + msg.release(); + } + } + + /** + * Get the {@link FullHttpMessage} associated with {@code stream}. + * @param stream The stream to get the associated state from + * @return The {@link FullHttpMessage} associated with {@code stream}. + */ + protected final FullHttpMessage getMessage(Http2Stream stream) { + return (FullHttpMessage) stream.getProperty(messageKey); + } + + /** + * Make {@code message} be the state associated with {@code stream}. + * @param stream The stream which {@code message} is associated with. + * @param message The message which contains the HTTP semantics. + */ + protected final void putMessage(Http2Stream stream, FullHttpMessage message) { + FullHttpMessage previous = stream.setProperty(messageKey, message); + if (previous != message && previous != null) { + previous.release(); + } + } + + @Override + public void onStreamRemoved(Http2Stream stream) { + removeMessage(stream, true); + } + + /** + * Set final headers and fire a channel read event + * + * @param ctx The context to fire the event on + * @param msg The message to send + * @param release {@code true} to call release on the value if it is present. {@code false} to not call release. + * @param stream the stream of the message which is being fired + */ + protected void fireChannelRead(ChannelHandlerContext ctx, FullHttpMessage msg, boolean release, + Http2Stream stream) { + removeMessage(stream, release); + HttpUtil.setContentLength(msg, msg.content().readableBytes()); + ctx.fireChannelRead(msg); + } + + /** + * Create a new {@link FullHttpMessage} based upon the current connection parameters + * + * @param stream The stream to create a message for + * @param headers The headers associated with {@code stream} + * @param validateHttpHeaders + *

    + *
  • {@code true} to validate HTTP headers in the http-codec
  • + *
  • {@code false} not to validate HTTP headers in the http-codec
  • + *
+ * @param alloc The {@link ByteBufAllocator} to use to generate the content of the message + * @throws Http2Exception If there is an error when creating {@link FullHttpMessage} from + * {@link Http2Stream} and {@link Http2Headers} + */ + protected FullHttpMessage newMessage(Http2Stream stream, Http2Headers headers, boolean validateHttpHeaders, + ByteBufAllocator alloc) throws Http2Exception { + return connection.isServer() ? HttpConversionUtil.toFullHttpRequest(stream.id(), headers, alloc, + validateHttpHeaders) : HttpConversionUtil.toFullHttpResponse(stream.id(), headers, alloc, + validateHttpHeaders); + } + + /** + * Provides translation between HTTP/2 and HTTP header objects while ensuring the stream + * is in a valid state for additional headers. + * + * @param ctx The context for which this message has been received. + * Used to send informational header if detected. + * @param stream The stream the {@code headers} apply to + * @param headers The headers to process + * @param endOfStream {@code true} if the {@code stream} has received the end of stream flag + * @param allowAppend + *
    + *
  • {@code true} if headers will be appended if the stream already exists.
  • + *
  • if {@code false} and the stream already exists this method returns {@code null}.
  • + *
+ * @param appendToTrailer + *
    + *
  • {@code true} if a message {@code stream} already exists then the headers + * should be added to the trailing headers.
  • + *
  • {@code false} then appends will be done to the initial headers.
  • + *
+ * @return The object used to track the stream corresponding to {@code stream}. {@code null} if + * {@code allowAppend} is {@code false} and the stream already exists. + * @throws Http2Exception If the stream id is not in the correct state to process the headers request + */ + protected FullHttpMessage processHeadersBegin(ChannelHandlerContext ctx, Http2Stream stream, Http2Headers headers, + boolean endOfStream, boolean allowAppend, boolean appendToTrailer) + throws Http2Exception { + FullHttpMessage msg = getMessage(stream); + boolean release = true; + if (msg == null) { + msg = newMessage(stream, headers, validateHttpHeaders, ctx.alloc()); + } else if (allowAppend) { + release = false; + HttpConversionUtil.addHttp2ToHttpHeaders(stream.id(), headers, msg, appendToTrailer); + } else { + release = false; + msg = null; + } + + if (sendDetector.mustSendImmediately(msg)) { + // Copy the message (if necessary) before sending. The content is not expected to be copied (or used) in + // this operation but just in case it is used do the copy before sending and the resource may be released + final FullHttpMessage copy = endOfStream ? null : sendDetector.copyIfNeeded(ctx.alloc(), msg); + fireChannelRead(ctx, msg, release, stream); + return copy; + } + + return msg; + } + + /** + * After HTTP/2 headers have been processed by {@link #processHeadersBegin} this method either + * sends the result up the pipeline or retains the message for future processing. 
+ * + * @param ctx The context for which this message has been received + * @param stream The stream the {@code objAccumulator} corresponds to + * @param msg The object which represents all headers/data for corresponding to {@code stream} + * @param endOfStream {@code true} if this is the last event for the stream + */ + private void processHeadersEnd(ChannelHandlerContext ctx, Http2Stream stream, FullHttpMessage msg, + boolean endOfStream) { + if (endOfStream) { + // Release if the msg from the map is different from the object being forwarded up the pipeline. + fireChannelRead(ctx, msg, getMessage(stream) != msg, stream); + } else { + putMessage(stream, msg); + } + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + FullHttpMessage msg = getMessage(stream); + if (msg == null) { + throw connectionError(PROTOCOL_ERROR, "Data Frame received for unknown stream id %d", streamId); + } + + ByteBuf content = msg.content(); + final int dataReadableBytes = data.readableBytes(); + if (content.readableBytes() > maxContentLength - dataReadableBytes) { + throw connectionError(INTERNAL_ERROR, + "Content length exceeded max of %d for stream id %d", maxContentLength, streamId); + } + + content.writeBytes(data, data.readerIndex(), dataReadableBytes); + + if (endOfStream) { + fireChannelRead(ctx, msg, false, stream); + } + + // All bytes have been processed. 
+ return dataReadableBytes + padding; + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endOfStream) throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + FullHttpMessage msg = processHeadersBegin(ctx, stream, headers, endOfStream, true, true); + if (msg != null) { + processHeadersEnd(ctx, stream, msg, endOfStream); + } + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endOfStream) + throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + FullHttpMessage msg = processHeadersBegin(ctx, stream, headers, endOfStream, true, true); + if (msg != null) { + // Add headers for dependency and weight. + // See https://github.com/netty/netty/issues/5866 + if (streamDependency != Http2CodecUtil.CONNECTION_STREAM_ID) { + msg.headers().setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_DEPENDENCY_ID.text(), + streamDependency); + } + msg.headers().setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), weight); + + processHeadersEnd(ctx, stream, msg, endOfStream); + } + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + Http2Stream stream = connection.stream(streamId); + FullHttpMessage msg = getMessage(stream); + if (msg != null) { + onRstStreamRead(stream, msg); + } + ctx.fireExceptionCaught(Http2Exception.streamError(streamId, Http2Error.valueOf(errorCode), + "HTTP/2 to HTTP layer caught stream reset")); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + // A push promise should not be allowed to add headers to an existing stream + Http2Stream promisedStream = connection.stream(promisedStreamId); + 
if (headers.status() == null) { + // A PUSH_PROMISE frame has no Http response status. + // https://tools.ietf.org/html/rfc7540#section-8.2.1 + // Server push is semantically equivalent to a server responding to a + // request; however, in this case, that request is also sent by the + // server, as a PUSH_PROMISE frame. + headers.status(OK.codeAsText()); + } + FullHttpMessage msg = processHeadersBegin(ctx, promisedStream, headers, false, false, false); + if (msg == null) { + throw connectionError(PROTOCOL_ERROR, "Push Promise Frame received for pre-existing stream id %d", + promisedStreamId); + } + + msg.headers().setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_PROMISE_ID.text(), streamId); + msg.headers().setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), + Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT); + + processHeadersEnd(ctx, promisedStream, msg, false); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + if (propagateSettings) { + // Provide an interface for non-listeners to capture settings + ctx.fireChannelRead(settings); + } + } + + /** + * Called if a {@code RST_STREAM} is received but we have some data for that stream. + */ + protected void onRstStreamRead(Http2Stream stream, FullHttpMessage msg) { + removeMessage(stream, true); + } + + /** + * Allows messages to be sent up the pipeline before the next phase in the + * HTTP message flow is detected. + */ + private interface ImmediateSendDetector { + /** + * Determine if the response should be sent immediately, or wait for the end of the stream + * + * @param msg The response to test + * @return {@code true} if the message should be sent immediately + * {@code false) if we should wait for the end of the stream + */ + boolean mustSendImmediately(FullHttpMessage msg); + + /** + * Determine if a copy must be made after an immediate send happens. + *

+ * An example of this use case is if a request is received + * with a 'Expect: 100-continue' header. The message will be sent immediately, + * and the data will be queued and sent at the end of the stream. + * + * @param allocator The {@link ByteBufAllocator} that can be used to allocate + * @param msg The message which has just been sent due to {@link #mustSendImmediately(FullHttpMessage)} + * @return A modified copy of the {@code msg} or {@code null} if a copy is not needed. + */ + FullHttpMessage copyIfNeeded(ByteBufAllocator allocator, FullHttpMessage msg); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterBuilder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterBuilder.java new file mode 100644 index 0000000..4e3691b --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterBuilder.java @@ -0,0 +1,65 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * Builds an {@link InboundHttp2ToHttpAdapter}. 
+ */ +@UnstableApi +public final class InboundHttp2ToHttpAdapterBuilder + extends AbstractInboundHttp2ToHttpAdapterBuilder { + + /** + * Creates a new {@link InboundHttp2ToHttpAdapter} builder for the specified {@link Http2Connection}. + * + * @param connection the object which will provide connection notification events + * for the current connection + */ + public InboundHttp2ToHttpAdapterBuilder(Http2Connection connection) { + super(connection); + } + + @Override + public InboundHttp2ToHttpAdapterBuilder maxContentLength(int maxContentLength) { + return super.maxContentLength(maxContentLength); + } + + @Override + public InboundHttp2ToHttpAdapterBuilder validateHttpHeaders(boolean validate) { + return super.validateHttpHeaders(validate); + } + + @Override + public InboundHttp2ToHttpAdapterBuilder propagateSettings(boolean propagate) { + return super.propagateSettings(propagate); + } + + @Override + public InboundHttp2ToHttpAdapter build() { + return super.build(); + } + + @Override + protected InboundHttp2ToHttpAdapter build(Http2Connection connection, + int maxContentLength, + boolean validateHttpHeaders, + boolean propagateSettings) throws Exception { + + return new InboundHttp2ToHttpAdapter(connection, maxContentLength, + validateHttpHeaders, propagateSettings); + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttpToHttp2Adapter.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttpToHttp2Adapter.java new file mode 100644 index 0000000..f1ff07f --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/InboundHttpToHttp2Adapter.java @@ -0,0 +1,81 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.util.internal.UnstableApi; + +/** + * Translates HTTP/1.x object reads into HTTP/2 frames. + */ +@UnstableApi +public class InboundHttpToHttp2Adapter extends ChannelInboundHandlerAdapter { + private final Http2Connection connection; + private final Http2FrameListener listener; + + public InboundHttpToHttp2Adapter(Http2Connection connection, Http2FrameListener listener) { + this.connection = connection; + this.listener = listener; + } + + private static int getStreamId(Http2Connection connection, HttpHeaders httpHeaders) { + return httpHeaders.getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), + connection.remote().incrementAndGetNextStreamId()); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof FullHttpMessage) { + handle(ctx, connection, listener, (FullHttpMessage) msg); + } else { + super.channelRead(ctx, msg); + } + } + + // note that this may behave strangely when used for the initial upgrade + // message when using h2c, since that message is ineligible for flow + // control, but there is not yet an API for signaling that. 
+ static void handle(ChannelHandlerContext ctx, Http2Connection connection, + Http2FrameListener listener, FullHttpMessage message) throws Http2Exception { + try { + int streamId = getStreamId(connection, message.headers()); + Http2Stream stream = connection.stream(streamId); + if (stream == null) { + stream = connection.remote().createStream(streamId, false); + } + message.headers().set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), HttpScheme.HTTP.name()); + Http2Headers messageHeaders = HttpConversionUtil.toHttp2Headers(message, true); + boolean hasContent = message.content().isReadable(); + boolean hasTrailers = !message.trailingHeaders().isEmpty(); + listener.onHeadersRead( + ctx, streamId, messageHeaders, 0, !(hasContent || hasTrailers)); + if (hasContent) { + listener.onDataRead(ctx, streamId, message.content(), 0, !hasTrailers); + } + if (hasTrailers) { + Http2Headers headers = HttpConversionUtil.toHttp2Headers(message.trailingHeaders(), true); + listener.onHeadersRead(ctx, streamId, headers, 0, true); + } + stream.closeRemoteSide(); + } finally { + message.release(); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/MaxCapacityQueue.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/MaxCapacityQueue.java new file mode 100644 index 0000000..90b1d0a --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/MaxCapacityQueue.java @@ -0,0 +1,129 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
import java.util.Collection;
import java.util.Iterator;
import java.util.Queue;

/**
 * A {@link Queue} decorator that bounds the number of elements to a fixed maximum capacity.
 * Inserts beyond the capacity are rejected: {@link #offer(Object)} returns {@code false},
 * while {@link #add(Object)} and {@link #addAll(Collection)} throw {@link IllegalStateException},
 * matching the {@link Queue} contract for capacity-restricted queues.
 * All other operations delegate to the wrapped queue.
 * <p>
 * Not thread-safe unless the backing queue is, and even then the size check in
 * {@link #offer(Object)} is not atomic with the insert.
 */
final class MaxCapacityQueue<E> implements Queue<E> {
    private final Queue<E> queue;
    private final int maxCapacity;

    /**
     * @param queue the backing queue that stores the elements
     * @param maxCapacity the maximum number of elements the queue may hold
     */
    MaxCapacityQueue(Queue<E> queue, int maxCapacity) {
        this.queue = queue;
        this.maxCapacity = maxCapacity;
    }

    @Override
    public boolean add(E element) {
        if (offer(element)) {
            return true;
        }
        // Capacity-restricted add must throw rather than return false.
        throw new IllegalStateException();
    }

    @Override
    public boolean offer(E element) {
        if (maxCapacity <= queue.size()) {
            return false;
        }
        return queue.offer(element);
    }

    @Override
    public E remove() {
        return queue.remove();
    }

    @Override
    public E poll() {
        return queue.poll();
    }

    @Override
    public E element() {
        return queue.element();
    }

    @Override
    public E peek() {
        return queue.peek();
    }

    @Override
    public int size() {
        return queue.size();
    }

    @Override
    public boolean isEmpty() {
        return queue.isEmpty();
    }

    @Override
    public boolean contains(Object o) {
        return queue.contains(o);
    }

    @Override
    public Iterator<E> iterator() {
        return queue.iterator();
    }

    @Override
    public Object[] toArray() {
        return queue.toArray();
    }

    @Override
    public <T> T[] toArray(T[] a) {
        return queue.toArray(a);
    }

    @Override
    public boolean remove(Object o) {
        return queue.remove(o);
    }

    @Override
    public boolean containsAll(Collection<?> c) {
        return queue.containsAll(c);
    }

    @Override
    public boolean addAll(Collection<? extends E> c) {
        // Reject the whole batch up front if it cannot fit entirely.
        if (maxCapacity >= size() + c.size()) {
            return queue.addAll(c);
        }
        throw new IllegalStateException();
    }

    @Override
    public boolean removeAll(Collection<?> c) {
        return queue.removeAll(c);
    }

    @Override
    public boolean retainAll(Collection<?> c) {
        return queue.retainAll(c);
    }

    @Override
    public void clear() {
        queue.clear();
    }
}
+ */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.Headers; +import io.netty.util.AsciiString; +import io.netty.util.HashingStrategy; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +import static io.netty.handler.codec.CharSequenceValueConverter.*; +import static io.netty.handler.codec.http2.DefaultHttp2Headers.*; +import static io.netty.util.AsciiString.*; +import static io.netty.util.internal.EmptyArrays.*; +import static io.netty.util.internal.ObjectUtil.checkNotNullArrayParam; + +/** + * A variant of {@link Http2Headers} which only supports read-only methods. + *

+ * Any array passed to this class may be used directly in the underlying data structures of this class. If these + * arrays may be modified it is the caller's responsibility to supply this class with a copy of the array. + *

+ * This may be a good alternative to {@link DefaultHttp2Headers} if your have a fixed set of headers which will not + * change. + */ +public final class ReadOnlyHttp2Headers implements Http2Headers { + private static final byte PSEUDO_HEADER_TOKEN = (byte) ':'; + private final AsciiString[] pseudoHeaders; + private final AsciiString[] otherHeaders; + + /** + * Used to create read only object designed to represent trailers. + *

+ * If this is used for a purpose other than trailers you may violate the header serialization ordering defined by + * RFC 7540, 8.1.2.1. + * @param validateHeaders {@code true} will run validation on each header name/value pair to ensure protocol + * compliance. + * @param otherHeaders An array of key:value pairs. Must not contain any + * pseudo headers + * or {@code null} names/values. + * A copy will NOT be made of this array. If the contents of this array + * may be modified externally you are responsible for passing in a copy. + * @return A read only representation of the headers. + */ + public static ReadOnlyHttp2Headers trailers(boolean validateHeaders, AsciiString... otherHeaders) { + return new ReadOnlyHttp2Headers(validateHeaders, EMPTY_ASCII_STRINGS, otherHeaders); + } + + /** + * Create a new read only representation of headers used by clients. + * @param validateHeaders {@code true} will run validation on each header name/value pair to ensure protocol + * compliance. + * @param method The value for {@link PseudoHeaderName#METHOD}. + * @param path The value for {@link PseudoHeaderName#PATH}. + * @param scheme The value for {@link PseudoHeaderName#SCHEME}. + * @param authority The value for {@link PseudoHeaderName#AUTHORITY}. + * @param otherHeaders An array of key:value pairs. Must not contain any + * pseudo headers + * or {@code null} names/values. + * A copy will NOT be made of this array. If the contents of this array + * may be modified externally you are responsible for passing in a copy. + * @return a new read only representation of headers used by clients. + */ + public static ReadOnlyHttp2Headers clientHeaders(boolean validateHeaders, + AsciiString method, AsciiString path, + AsciiString scheme, AsciiString authority, + AsciiString... 
otherHeaders) { + return new ReadOnlyHttp2Headers(validateHeaders, + new AsciiString[] { + PseudoHeaderName.METHOD.value(), method, PseudoHeaderName.PATH.value(), path, + PseudoHeaderName.SCHEME.value(), scheme, PseudoHeaderName.AUTHORITY.value(), authority + }, + otherHeaders); + } + + /** + * Create a new read only representation of headers used by servers. + * @param validateHeaders {@code true} will run validation on each header name/value pair to ensure protocol + * compliance. + * @param status The value for {@link PseudoHeaderName#STATUS}. + * @param otherHeaders An array of key:value pairs. Must not contain any + * pseudo headers + * or {@code null} names/values. + * A copy will NOT be made of this array. If the contents of this array + * may be modified externally you are responsible for passing in a copy. + * @return a new read only representation of headers used by servers. + */ + public static ReadOnlyHttp2Headers serverHeaders(boolean validateHeaders, + AsciiString status, + AsciiString... otherHeaders) { + return new ReadOnlyHttp2Headers(validateHeaders, + new AsciiString[] { PseudoHeaderName.STATUS.value(), status }, + otherHeaders); + } + + private ReadOnlyHttp2Headers(boolean validateHeaders, AsciiString[] pseudoHeaders, AsciiString... otherHeaders) { + assert (pseudoHeaders.length & 1) == 0; // pseudoHeaders are only set internally so assert should be enough. + if ((otherHeaders.length & 1) != 0) { + throw newInvalidArraySizeException(); + } + if (validateHeaders) { + validateHeaders(pseudoHeaders, otherHeaders); + } + this.pseudoHeaders = pseudoHeaders; + this.otherHeaders = otherHeaders; + } + + private static IllegalArgumentException newInvalidArraySizeException() { + return new IllegalArgumentException("pseudoHeaders and otherHeaders must be arrays of [name, value] pairs"); + } + + private static void validateHeaders(AsciiString[] pseudoHeaders, AsciiString... otherHeaders) { + // We are only validating values... 
so start at 1 and go until end. + for (int i = 1; i < pseudoHeaders.length; i += 2) { + // pseudoHeaders names are only set internally so they are assumed to be valid. + checkNotNullArrayParam(pseudoHeaders[i], i, "pseudoHeaders"); + } + + boolean seenNonPseudoHeader = false; + final int otherHeadersEnd = otherHeaders.length - 1; + for (int i = 0; i < otherHeadersEnd; i += 2) { + AsciiString name = otherHeaders[i]; + HTTP2_NAME_VALIDATOR.validateName(name); + if (!seenNonPseudoHeader && !name.isEmpty() && name.byteAt(0) != PSEUDO_HEADER_TOKEN) { + seenNonPseudoHeader = true; + } else if (seenNonPseudoHeader && !name.isEmpty() && name.byteAt(0) == PSEUDO_HEADER_TOKEN) { + throw new IllegalArgumentException( + "otherHeaders name at index " + i + " is a pseudo header that appears after non-pseudo headers."); + } + checkNotNullArrayParam(otherHeaders[i + 1], i + 1, "otherHeaders"); + } + } + + private AsciiString get0(CharSequence name) { + final int nameHash = AsciiString.hashCode(name); + + final int pseudoHeadersEnd = pseudoHeaders.length - 1; + for (int i = 0; i < pseudoHeadersEnd; i += 2) { + AsciiString roName = pseudoHeaders[i]; + if (roName.hashCode() == nameHash && roName.contentEqualsIgnoreCase(name)) { + return pseudoHeaders[i + 1]; + } + } + + final int otherHeadersEnd = otherHeaders.length - 1; + for (int i = 0; i < otherHeadersEnd; i += 2) { + AsciiString roName = otherHeaders[i]; + if (roName.hashCode() == nameHash && roName.contentEqualsIgnoreCase(name)) { + return otherHeaders[i + 1]; + } + } + return null; + } + + @Override + public CharSequence get(CharSequence name) { + return get0(name); + } + + @Override + public CharSequence get(CharSequence name, CharSequence defaultValue) { + CharSequence value = get(name); + return value != null ? 
value : defaultValue; + } + + @Override + public CharSequence getAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public CharSequence getAndRemove(CharSequence name, CharSequence defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public List getAll(CharSequence name) { + final int nameHash = AsciiString.hashCode(name); + List values = new ArrayList(); + + final int pseudoHeadersEnd = pseudoHeaders.length - 1; + for (int i = 0; i < pseudoHeadersEnd; i += 2) { + AsciiString roName = pseudoHeaders[i]; + if (roName.hashCode() == nameHash && roName.contentEqualsIgnoreCase(name)) { + values.add(pseudoHeaders[i + 1]); + } + } + + final int otherHeadersEnd = otherHeaders.length - 1; + for (int i = 0; i < otherHeadersEnd; i += 2) { + AsciiString roName = otherHeaders[i]; + if (roName.hashCode() == nameHash && roName.contentEqualsIgnoreCase(name)) { + values.add(otherHeaders[i + 1]); + } + } + + return values; + } + + @Override + public List getAllAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Boolean getBoolean(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToBoolean(value) : null; + } + + @Override + public boolean getBoolean(CharSequence name, boolean defaultValue) { + Boolean value = getBoolean(name); + return value != null ? value : defaultValue; + } + + @Override + public Byte getByte(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToByte(value) : null; + } + + @Override + public byte getByte(CharSequence name, byte defaultValue) { + Byte value = getByte(name); + return value != null ? value : defaultValue; + } + + @Override + public Character getChar(CharSequence name) { + AsciiString value = get0(name); + return value != null ? 
INSTANCE.convertToChar(value) : null; + } + + @Override + public char getChar(CharSequence name, char defaultValue) { + Character value = getChar(name); + return value != null ? value : defaultValue; + } + + @Override + public Short getShort(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToShort(value) : null; + } + + @Override + public short getShort(CharSequence name, short defaultValue) { + Short value = getShort(name); + return value != null ? value : defaultValue; + } + + @Override + public Integer getInt(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToInt(value) : null; + } + + @Override + public int getInt(CharSequence name, int defaultValue) { + Integer value = getInt(name); + return value != null ? value : defaultValue; + } + + @Override + public Long getLong(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToLong(value) : null; + } + + @Override + public long getLong(CharSequence name, long defaultValue) { + Long value = getLong(name); + return value != null ? value : defaultValue; + } + + @Override + public Float getFloat(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToFloat(value) : null; + } + + @Override + public float getFloat(CharSequence name, float defaultValue) { + Float value = getFloat(name); + return value != null ? value : defaultValue; + } + + @Override + public Double getDouble(CharSequence name) { + AsciiString value = get0(name); + return value != null ? INSTANCE.convertToDouble(value) : null; + } + + @Override + public double getDouble(CharSequence name, double defaultValue) { + Double value = getDouble(name); + return value != null ? value : defaultValue; + } + + @Override + public Long getTimeMillis(CharSequence name) { + AsciiString value = get0(name); + return value != null ? 
INSTANCE.convertToTimeMillis(value) : null; + } + + @Override + public long getTimeMillis(CharSequence name, long defaultValue) { + Long value = getTimeMillis(name); + return value != null ? value : defaultValue; + } + + @Override + public Boolean getBooleanAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public boolean getBooleanAndRemove(CharSequence name, boolean defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Byte getByteAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public byte getByteAndRemove(CharSequence name, byte defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Character getCharAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public char getCharAndRemove(CharSequence name, char defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Short getShortAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public short getShortAndRemove(CharSequence name, short defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Integer getIntAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public int getIntAndRemove(CharSequence name, int defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Long getLongAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public long getLongAndRemove(CharSequence name, long defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Float getFloatAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + 
@Override + public float getFloatAndRemove(CharSequence name, float defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Double getDoubleAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public double getDoubleAndRemove(CharSequence name, double defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Long getTimeMillisAndRemove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public long getTimeMillisAndRemove(CharSequence name, long defaultValue) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public boolean contains(CharSequence name) { + return get(name) != null; + } + + @Override + public boolean contains(CharSequence name, CharSequence value) { + return contains(name, value, false); + } + + @Override + public boolean containsObject(CharSequence name, Object value) { + if (value instanceof CharSequence) { + return contains(name, (CharSequence) value); + } + return contains(name, value.toString()); + } + + @Override + public boolean containsBoolean(CharSequence name, boolean value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsByte(CharSequence name, byte value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsChar(CharSequence name, char value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsShort(CharSequence name, short value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsInt(CharSequence name, int value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsLong(CharSequence name, long value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsFloat(CharSequence name, float value) 
{ + return false; + } + + @Override + public boolean containsDouble(CharSequence name, double value) { + return contains(name, String.valueOf(value)); + } + + @Override + public boolean containsTimeMillis(CharSequence name, long value) { + return contains(name, String.valueOf(value)); + } + + @Override + public int size() { + return pseudoHeaders.length + otherHeaders.length >>> 1; + } + + @Override + public boolean isEmpty() { + return pseudoHeaders.length == 0 && otherHeaders.length == 0; + } + + @Override + public Set names() { + if (isEmpty()) { + return Collections.emptySet(); + } + Set names = new LinkedHashSet(size()); + final int pseudoHeadersEnd = pseudoHeaders.length - 1; + for (int i = 0; i < pseudoHeadersEnd; i += 2) { + names.add(pseudoHeaders[i]); + } + + final int otherHeadersEnd = otherHeaders.length - 1; + for (int i = 0; i < otherHeadersEnd; i += 2) { + names.add(otherHeaders[i]); + } + return names; + } + + @Override + public Http2Headers add(CharSequence name, CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers add(CharSequence name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers add(CharSequence name, CharSequence... values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addObject(CharSequence name, Object value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addObject(CharSequence name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addObject(CharSequence name, Object... 
values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addBoolean(CharSequence name, boolean value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addByte(CharSequence name, byte value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addChar(CharSequence name, char value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addShort(CharSequence name, short value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addInt(CharSequence name, int value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addLong(CharSequence name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addFloat(CharSequence name, float value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addDouble(CharSequence name, double value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers addTimeMillis(CharSequence name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers add(Headers headers) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers set(CharSequence name, CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers set(CharSequence name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers set(CharSequence name, CharSequence... 
values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setObject(CharSequence name, Object value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setObject(CharSequence name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setObject(CharSequence name, Object... values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setBoolean(CharSequence name, boolean value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setByte(CharSequence name, byte value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setChar(CharSequence name, char value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setShort(CharSequence name, short value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setInt(CharSequence name, int value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setLong(CharSequence name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setFloat(CharSequence name, float value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setDouble(CharSequence name, double value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setTimeMillis(CharSequence name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers set(Headers headers) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers setAll(Headers headers) { + throw new UnsupportedOperationException("read only"); + } + + 
@Override + public boolean remove(CharSequence name) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers clear() { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Iterator> iterator() { + return new ReadOnlyIterator(); + } + + @Override + public Iterator valueIterator(CharSequence name) { + return new ReadOnlyValueIterator(name); + } + + @Override + public Http2Headers method(CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers scheme(CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers authority(CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers path(CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public Http2Headers status(CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public CharSequence method() { + return get(PseudoHeaderName.METHOD.value()); + } + + @Override + public CharSequence scheme() { + return get(PseudoHeaderName.SCHEME.value()); + } + + @Override + public CharSequence authority() { + return get(PseudoHeaderName.AUTHORITY.value()); + } + + @Override + public CharSequence path() { + return get(PseudoHeaderName.PATH.value()); + } + + @Override + public CharSequence status() { + return get(PseudoHeaderName.STATUS.value()); + } + + @Override + public boolean contains(CharSequence name, CharSequence value, boolean caseInsensitive) { + final int nameHash = AsciiString.hashCode(name); + final HashingStrategy strategy = + caseInsensitive ? 
CASE_INSENSITIVE_HASHER : CASE_SENSITIVE_HASHER; + final int valueHash = strategy.hashCode(value); + + return contains(name, nameHash, value, valueHash, strategy, otherHeaders) + || contains(name, nameHash, value, valueHash, strategy, pseudoHeaders); + } + + private static boolean contains(CharSequence name, int nameHash, CharSequence value, int valueHash, + HashingStrategy hashingStrategy, AsciiString[] headers) { + final int headersEnd = headers.length - 1; + for (int i = 0; i < headersEnd; i += 2) { + AsciiString roName = headers[i]; + AsciiString roValue = headers[i + 1]; + if (roName.hashCode() == nameHash && roValue.hashCode() == valueHash && + roName.contentEqualsIgnoreCase(name) && hashingStrategy.equals(roValue, value)) { + return true; + } + } + return false; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(getClass().getSimpleName()).append('['); + String separator = ""; + for (Map.Entry entry : this) { + builder.append(separator); + builder.append(entry.getKey()).append(": ").append(entry.getValue()); + separator = ", "; + } + return builder.append(']').toString(); + } + + private final class ReadOnlyValueIterator implements Iterator { + private int i; + private final int nameHash; + private final CharSequence name; + private AsciiString[] current = pseudoHeaders.length != 0 ? 
pseudoHeaders : otherHeaders; + private AsciiString next; + + ReadOnlyValueIterator(CharSequence name) { + nameHash = AsciiString.hashCode(name); + this.name = name; + calculateNext(); + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public CharSequence next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + CharSequence current = next; + calculateNext(); + return current; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("read only"); + } + + private void calculateNext() { + for (; i < current.length; i += 2) { + AsciiString roName = current[i]; + if (roName.hashCode() == nameHash && roName.contentEqualsIgnoreCase(name)) { + if (i + 1 < current.length) { + next = current[i + 1]; + i += 2; + } + return; + } + } + if (current == pseudoHeaders) { + i = 0; + current = otherHeaders; + calculateNext(); + } else { + next = null; + } + } + } + + private final class ReadOnlyIterator implements Map.Entry, + Iterator> { + private int i; + private AsciiString[] current = pseudoHeaders.length != 0 ? 
pseudoHeaders : otherHeaders; + private AsciiString key; + private AsciiString value; + + @Override + public boolean hasNext() { + return i != current.length; + } + + @Override + public Map.Entry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + key = current[i]; + value = current[i + 1]; + i += 2; + if (i == current.length && current == pseudoHeaders) { + current = otherHeaders; + i = 0; + } + return this; + } + + @Override + public CharSequence getKey() { + return key; + } + + @Override + public CharSequence getValue() { + return value; + } + + @Override + public CharSequence setValue(CharSequence value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("read only"); + } + + @Override + public String toString() { + return key.toString() + '=' + value.toString(); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java new file mode 100644 index 0000000..56cb9c0 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamBufferingEncoder.java @@ -0,0 +1,382 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayDeque; +import java.util.Iterator; +import java.util.Map; +import java.util.Queue; +import java.util.TreeMap; + +import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; + +/** + * Implementation of a {@link Http2ConnectionEncoder} that dispatches all method call to another + * {@link Http2ConnectionEncoder}, until {@code SETTINGS_MAX_CONCURRENT_STREAMS} is reached. + *

+ * <p>
+ * When this limit is hit, instead of rejecting any new streams this implementation buffers newly
+ * created streams and their corresponding frames. Once an active stream gets closed or the maximum
+ * number of concurrent streams is increased, this encoder will automatically try to empty its
+ * buffer and create as many new streams as possible.
+ * <p>
+ * If a {@code GOAWAY} frame is received from the remote endpoint, all buffered writes for streams
+ * with an ID less than the specified {@code lastStreamId} will immediately fail with a
+ * {@link Http2GoAwayException}.
+ * <p>
+ * If the channel/encoder gets closed, all new and buffered writes will immediately fail with a
+ * {@link Http2ChannelClosedException}.
+ * <p>
+ * This implementation makes the buffering mostly transparent and is expected to be used as a
+ * drop-in decorator of {@link DefaultHttp2ConnectionEncoder}.

+ */ +@UnstableApi +public class StreamBufferingEncoder extends DecoratingHttp2ConnectionEncoder { + + /** + * Thrown if buffered streams are terminated due to this encoder being closed. + */ + public static final class Http2ChannelClosedException extends Http2Exception { + private static final long serialVersionUID = 4768543442094476971L; + + public Http2ChannelClosedException() { + super(Http2Error.REFUSED_STREAM, "Connection closed"); + } + } + + private static final class GoAwayDetail { + private final int lastStreamId; + private final long errorCode; + private final byte[] debugData; + + GoAwayDetail(int lastStreamId, long errorCode, byte[] debugData) { + this.lastStreamId = lastStreamId; + this.errorCode = errorCode; + this.debugData = debugData.clone(); + } + } + + /** + * Thrown by {@link StreamBufferingEncoder} if buffered streams are terminated due to + * receipt of a {@code GOAWAY}. + */ + public static final class Http2GoAwayException extends Http2Exception { + private static final long serialVersionUID = 1326785622777291198L; + private final GoAwayDetail goAwayDetail; + + public Http2GoAwayException(int lastStreamId, long errorCode, byte[] debugData) { + this(new GoAwayDetail(lastStreamId, errorCode, debugData)); + } + + Http2GoAwayException(GoAwayDetail goAwayDetail) { + super(Http2Error.STREAM_CLOSED); + this.goAwayDetail = goAwayDetail; + } + + public int lastStreamId() { + return goAwayDetail.lastStreamId; + } + + public long errorCode() { + return goAwayDetail.errorCode; + } + + public byte[] debugData() { + return goAwayDetail.debugData.clone(); + } + } + + /** + * Buffer for any streams and corresponding frames that could not be created due to the maximum + * concurrent stream limit being hit. 
+ */ + private final TreeMap pendingStreams = new TreeMap(); + private int maxConcurrentStreams; + private boolean closed; + private GoAwayDetail goAwayDetail; + + public StreamBufferingEncoder(Http2ConnectionEncoder delegate) { + this(delegate, SMALLEST_MAX_CONCURRENT_STREAMS); + } + + public StreamBufferingEncoder(Http2ConnectionEncoder delegate, int initialMaxConcurrentStreams) { + super(delegate); + maxConcurrentStreams = initialMaxConcurrentStreams; + connection().addListener(new Http2ConnectionAdapter() { + + @Override + public void onGoAwayReceived(int lastStreamId, long errorCode, ByteBuf debugData) { + goAwayDetail = new GoAwayDetail( + // Using getBytes(..., false) is safe here as GoAwayDetail(...) will clone the byte[]. + lastStreamId, errorCode, + ByteBufUtil.getBytes(debugData, debugData.readerIndex(), debugData.readableBytes(), false)); + cancelGoAwayStreams(goAwayDetail); + } + + @Override + public void onStreamClosed(Http2Stream stream) { + tryCreatePendingStreams(); + } + }); + } + + /** + * Indicates the number of streams that are currently buffered, awaiting creation. 
+ */ + public int numBufferedStreams() { + return pendingStreams.size(); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int padding, boolean endStream, ChannelPromise promise) { + return writeHeaders(ctx, streamId, headers, 0, Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT, + false, padding, endStream, promise); + } + + @Override + public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + if (closed) { + return promise.setFailure(new Http2ChannelClosedException()); + } + if (isExistingStream(streamId) || canCreateStream()) { + return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, + exclusive, padding, endOfStream, promise); + } + if (goAwayDetail != null) { + return promise.setFailure(new Http2GoAwayException(goAwayDetail)); + } + PendingStream pendingStream = pendingStreams.get(streamId); + if (pendingStream == null) { + pendingStream = new PendingStream(ctx, streamId); + pendingStreams.put(streamId, pendingStream); + } + pendingStream.frames.add(new HeadersFrame(headers, streamDependency, weight, exclusive, + padding, endOfStream, promise)); + return promise; + } + + @Override + public ChannelFuture writeRstStream(ChannelHandlerContext ctx, int streamId, long errorCode, + ChannelPromise promise) { + if (isExistingStream(streamId)) { + return super.writeRstStream(ctx, streamId, errorCode, promise); + } + // Since the delegate doesn't know about any buffered streams we have to handle cancellation + // of the promises and releasing of the ByteBufs here. 
+ PendingStream stream = pendingStreams.remove(streamId); + if (stream != null) { + // Sending a RST_STREAM to a buffered stream will succeed the promise of all frames + // associated with the stream, as sending a RST_STREAM means that someone "doesn't care" + // about the stream anymore and thus there is not point in failing the promises and invoking + // error handling routines. + stream.close(null); + promise.setSuccess(); + } else { + promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); + } + return promise; + } + + @Override + public ChannelFuture writeData(ChannelHandlerContext ctx, int streamId, ByteBuf data, + int padding, boolean endOfStream, ChannelPromise promise) { + if (isExistingStream(streamId)) { + return super.writeData(ctx, streamId, data, padding, endOfStream, promise); + } + PendingStream pendingStream = pendingStreams.get(streamId); + if (pendingStream != null) { + pendingStream.frames.add(new DataFrame(data, padding, endOfStream, promise)); + } else { + ReferenceCountUtil.safeRelease(data); + promise.setFailure(connectionError(PROTOCOL_ERROR, "Stream does not exist %d", streamId)); + } + return promise; + } + + @Override + public void remoteSettings(Http2Settings settings) throws Http2Exception { + // Need to let the delegate decoder handle the settings first, so that it sees the + // new setting before we attempt to create any new streams. + super.remoteSettings(settings); + + // Get the updated value for SETTINGS_MAX_CONCURRENT_STREAMS. + maxConcurrentStreams = connection().local().maxActiveStreams(); + + // Try to create new streams up to the new threshold. + tryCreatePendingStreams(); + } + + @Override + public void close() { + try { + if (!closed) { + closed = true; + + // Fail all buffered streams. 
+ Http2ChannelClosedException e = new Http2ChannelClosedException(); + while (!pendingStreams.isEmpty()) { + PendingStream stream = pendingStreams.pollFirstEntry().getValue(); + stream.close(e); + } + } + } finally { + super.close(); + } + } + + private void tryCreatePendingStreams() { + while (!pendingStreams.isEmpty() && canCreateStream()) { + Map.Entry entry = pendingStreams.pollFirstEntry(); + PendingStream pendingStream = entry.getValue(); + try { + pendingStream.sendFrames(); + } catch (Throwable t) { + pendingStream.close(t); + } + } + } + + private void cancelGoAwayStreams(GoAwayDetail goAwayDetail) { + Iterator iter = pendingStreams.values().iterator(); + Exception e = new Http2GoAwayException(goAwayDetail); + while (iter.hasNext()) { + PendingStream stream = iter.next(); + if (stream.streamId > goAwayDetail.lastStreamId) { + iter.remove(); + stream.close(e); + } + } + } + + /** + * Determines whether or not we're allowed to create a new stream right now. + */ + private boolean canCreateStream() { + return connection().local().numActiveStreams() < maxConcurrentStreams; + } + + private boolean isExistingStream(int streamId) { + return streamId <= connection().local().lastStreamCreated(); + } + + private static final class PendingStream { + final ChannelHandlerContext ctx; + final int streamId; + final Queue frames = new ArrayDeque(2); + + PendingStream(ChannelHandlerContext ctx, int streamId) { + this.ctx = ctx; + this.streamId = streamId; + } + + void sendFrames() { + for (Frame frame : frames) { + frame.send(ctx, streamId); + } + } + + void close(Throwable t) { + for (Frame frame : frames) { + frame.release(t); + } + } + } + + private abstract static class Frame { + final ChannelPromise promise; + + Frame(ChannelPromise promise) { + this.promise = promise; + } + + /** + * Release any resources (features, buffers, ...) associated with the frame. 
+ */ + void release(Throwable t) { + if (t == null) { + promise.setSuccess(); + } else { + promise.setFailure(t); + } + } + + abstract void send(ChannelHandlerContext ctx, int streamId); + } + + private final class HeadersFrame extends Frame { + final Http2Headers headers; + final int streamDependency; + final short weight; + final boolean exclusive; + final int padding; + final boolean endOfStream; + + HeadersFrame(Http2Headers headers, int streamDependency, short weight, boolean exclusive, + int padding, boolean endOfStream, ChannelPromise promise) { + super(promise); + this.headers = headers; + this.streamDependency = streamDependency; + this.weight = weight; + this.exclusive = exclusive; + this.padding = padding; + this.endOfStream = endOfStream; + } + + @Override + void send(ChannelHandlerContext ctx, int streamId) { + writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream, promise); + } + } + + private final class DataFrame extends Frame { + final ByteBuf data; + final int padding; + final boolean endOfStream; + + DataFrame(ByteBuf data, int padding, boolean endOfStream, ChannelPromise promise) { + super(promise); + this.data = data; + this.padding = padding; + this.endOfStream = endOfStream; + } + + @Override + void release(Throwable t) { + super.release(t); + ReferenceCountUtil.safeRelease(data); + } + + @Override + void send(ChannelHandlerContext ctx, int streamId) { + writeData(ctx, streamId, data, padding, endOfStream, promise); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamByteDistributor.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamByteDistributor.java new file mode 100644 index 0000000..2fe840d --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/StreamByteDistributor.java @@ -0,0 +1,112 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under 
the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +/** + * An object (used by remote flow control) that is responsible for distributing the bytes to be + * written across the streams in the connection. + */ +@UnstableApi +public interface StreamByteDistributor { + + /** + * State information for the stream, indicating the number of bytes that are currently + * streamable. This is provided to the {@link #updateStreamableBytes(StreamState)} method. + */ + interface StreamState { + /** + * Gets the stream this state is associated with. + */ + Http2Stream stream(); + + /** + * Get the amount of bytes this stream has pending to send. The actual amount written must not exceed + * {@link #windowSize()}! + * @return The amount of bytes this stream has pending to send. + * @see Http2CodecUtil#streamableBytes(StreamState) + */ + long pendingBytes(); + + /** + * Indicates whether or not there are frames pending for this stream. + */ + boolean hasFrame(); + + /** + * The size (in bytes) of the stream's flow control window. The amount written must not exceed this amount! + *

A {@link StreamByteDistributor} needs to know the stream's window size in order to avoid allocating bytes
+ * if the window size is negative. The window size being {@code 0} may also be significant to determine
+ * whether a stream has been given a chance to write an empty frame, and also enables optimizations like
+ * not writing empty frames in some situations (don't write headers until data can also be written).
+ * @return the size of the stream's flow control window.
+ * @see Http2CodecUtil#streamableBytes(StreamState)
+ */
+ int windowSize();
+ }
+
+ /**
+ * Object that performs the writing of the bytes that have been allocated for a stream.
+ */
+ interface Writer {
+ /**
+ * Writes the allocated bytes for this stream.
+ * <p>

+ * Any {@link Throwable} thrown from this method is considered a programming error.
+ * A {@code GOAWAY} frame will be sent and the connection will be closed.
+ * @param stream the stream for which to perform the write.
+ * @param numBytes the number of bytes to write.
+ */
+ void write(Http2Stream stream, int numBytes);
+ }
+
+ /**
+ * Called when the streamable bytes for a stream have changed. Until this
+ * method is called for the first time for a given stream, the stream is assumed to have no
+ * streamable bytes.
+ */
+ void updateStreamableBytes(StreamState state);
+
+ /**
+ * Explicitly update the dependency tree. This method is called independently of stream state changes.
+ * @param childStreamId The stream identifier associated with the child stream.
+ * @param parentStreamId The stream identifier associated with the parent stream. May be {@code 0},
+ * to make {@code childStreamId} an immediate child of the connection.
+ * @param weight The weight which is used relative to other child streams for {@code parentStreamId}. This value
+ * must be between 1 and 256 (inclusive).
+ * @param exclusive If {@code childStreamId} should be the exclusive dependency of {@code parentStreamId}.
+ */
+ void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive);
+
+ /**
+ * Distributes up to {@code maxBytes} to those streams containing streamable bytes and
+ * iterates across those streams to write the appropriate bytes. Criteria for
+ * traversing streams are undefined and it is up to the implementation to determine when to stop
+ * at a given stream.
+ *
+ * <p>

The streamable bytes are not automatically updated by calling this method. It is up to the + * caller to indicate the number of bytes streamable after the write by calling + * {@link #updateStreamableBytes(StreamState)}. + * + * @param maxBytes the maximum number of bytes to write. + * @return {@code true} if there are still streamable bytes that have not yet been written, + * otherwise {@code false}. + * @throws Http2Exception If an internal exception occurs and internal connection state would otherwise be + * corrupted. + */ + boolean distribute(int maxBytes, Writer writer) throws Http2Exception; +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java new file mode 100644 index 0000000..0cf0cdf --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/UniformStreamByteDistributor.java @@ -0,0 +1,205 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; + +import java.util.ArrayDeque; +import java.util.Deque; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MIN_ALLOCATION_CHUNK; +import static io.netty.handler.codec.http2.Http2CodecUtil.streamableBytes; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static java.lang.Math.max; +import static java.lang.Math.min; + +/** + * A {@link StreamByteDistributor} that ignores stream priority and uniformly allocates bytes to all + * streams. This class uses a minimum chunk size that will be allocated to each stream. While + * fewer streams may be written to in each call to {@link #distribute(int, Writer)}, doing this + * should improve the goodput on each written stream. + */ +@UnstableApi +public final class UniformStreamByteDistributor implements StreamByteDistributor { + private final Http2Connection.PropertyKey stateKey; + private final Deque queue = new ArrayDeque(4); + + /** + * The minimum number of bytes that we will attempt to allocate to a stream. This is to + * help improve goodput on a per-stream basis. + */ + private int minAllocationChunk = DEFAULT_MIN_ALLOCATION_CHUNK; + private long totalStreamableBytes; + + public UniformStreamByteDistributor(Http2Connection connection) { + // Add a state for the connection. + stateKey = connection.newKey(); + Http2Stream connectionStream = connection.connectionStream(); + connectionStream.setProperty(stateKey, new State(connectionStream)); + + // Register for notification of new streams. 
+ connection.addListener(new Http2ConnectionAdapter() { + @Override + public void onStreamAdded(Http2Stream stream) { + stream.setProperty(stateKey, new State(stream)); + } + + @Override + public void onStreamClosed(Http2Stream stream) { + state(stream).close(); + } + }); + } + + /** + * Sets the minimum allocation chunk that will be allocated to each stream. Defaults to 1KiB. + * + * @param minAllocationChunk the minimum number of bytes that will be allocated to each stream. + * Must be > 0. + */ + public void minAllocationChunk(int minAllocationChunk) { + checkPositive(minAllocationChunk, "minAllocationChunk"); + this.minAllocationChunk = minAllocationChunk; + } + + @Override + public void updateStreamableBytes(StreamState streamState) { + state(streamState.stream()).updateStreamableBytes(streamableBytes(streamState), + streamState.hasFrame(), + streamState.windowSize()); + } + + @Override + public void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive) { + // This class ignores priority and dependency! + } + + @Override + public boolean distribute(int maxBytes, Writer writer) throws Http2Exception { + final int size = queue.size(); + if (size == 0) { + return totalStreamableBytes > 0; + } + + final int chunkSize = max(minAllocationChunk, maxBytes / size); + + State state = queue.pollFirst(); + do { + state.enqueued = false; + if (state.windowNegative) { + continue; + } + if (maxBytes == 0 && state.streamableBytes > 0) { + // Stop at the first state that can't send. Add this state back to the head of the queue. Note + // that empty frames at the head of the queue will always be written, assuming the stream window + // is not negative. + queue.addFirst(state); + state.enqueued = true; + break; + } + + // Allocate as much data as we can for this stream. + int chunk = min(chunkSize, min(maxBytes, state.streamableBytes)); + maxBytes -= chunk; + + // Write the allocated bytes and enqueue as necessary. 
+ state.write(chunk, writer); + } while ((state = queue.pollFirst()) != null); + + return totalStreamableBytes > 0; + } + + private State state(Http2Stream stream) { + return checkNotNull(stream, "stream").getProperty(stateKey); + } + + /** + * The remote flow control state for a single stream. + */ + private final class State { + final Http2Stream stream; + int streamableBytes; + boolean windowNegative; + boolean enqueued; + boolean writing; + + State(Http2Stream stream) { + this.stream = stream; + } + + void updateStreamableBytes(int newStreamableBytes, boolean hasFrame, int windowSize) { + assert hasFrame || newStreamableBytes == 0 : + "hasFrame: " + hasFrame + " newStreamableBytes: " + newStreamableBytes; + + int delta = newStreamableBytes - streamableBytes; + if (delta != 0) { + streamableBytes = newStreamableBytes; + totalStreamableBytes += delta; + } + // In addition to only enqueuing state when they have frames we enforce the following restrictions: + // 1. If the window has gone negative. We never want to queue a state. However we also don't want to + // Immediately remove the item if it is already queued because removal from deque is O(n). So + // we allow it to stay queued and rely on the distribution loop to remove this state. + // 2. If the window is zero we only want to queue if we are not writing. If we are writing that means + // we gave the state a chance to write zero length frames. We wait until updateStreamableBytes is + // called again before this state is allowed to write. + windowNegative = windowSize < 0; + if (hasFrame && (windowSize > 0 || windowSize == 0 && !writing)) { + addToQueue(); + } + } + + /** + * Write any allocated bytes for the given stream and updates the streamable bytes, + * assuming all of the bytes will be written. + */ + void write(int numBytes, Writer writer) throws Http2Exception { + writing = true; + try { + // Write the allocated bytes. 
+ writer.write(stream, numBytes); + } catch (Throwable t) { + throw connectionError(INTERNAL_ERROR, t, "byte distribution write error"); + } finally { + writing = false; + } + } + + void addToQueue() { + if (!enqueued) { + enqueued = true; + queue.addLast(this); + } + } + + void removeFromQueue() { + if (enqueued) { + enqueued = false; + queue.remove(this); + } + } + + void close() { + // Remove this state from the queue. + removeFromQueue(); + + // Clear the streamable bytes. + updateStreamableBytes(0, false, 0); + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java new file mode 100644 index 0000000..3672de8 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributor.java @@ -0,0 +1,803 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.collection.IntCollections; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import io.netty.util.internal.DefaultPriorityQueue; +import io.netty.util.internal.EmptyPriorityQueue; +import io.netty.util.internal.MathUtil; +import io.netty.util.internal.PriorityQueue; +import io.netty.util.internal.PriorityQueueNode; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.UnstableApi; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MIN_ALLOCATION_CHUNK; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.streamableBytes; +import static io.netty.handler.codec.http2.Http2Error.INTERNAL_ERROR; +import static io.netty.handler.codec.http2.Http2Exception.connectionError; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; +import static java.lang.Integer.MAX_VALUE; +import static java.lang.Math.max; +import static java.lang.Math.min; + +/** + * A {@link StreamByteDistributor} that is sensitive to stream priority and uses + * Weighted Fair Queueing approach for distributing + * bytes. + *
+ * <p>
+ * Inspiration for this distributor was taken from Linux's + * Completely Fair Scheduler + * to model the distribution of bytes to simulate an "ideal multi-tasking CPU", but in this case we are simulating + * an "ideal multi-tasking NIC". + *
+ * <p>
+ * Each write operation will use the {@link #allocationQuantum(int)} to know how many more bytes should be allocated + * relative to the next stream which wants to write. This is to balance fairness while also considering goodput. + */ +@UnstableApi +public final class WeightedFairQueueByteDistributor implements StreamByteDistributor { + /** + * The initial size of the children map is chosen to be conservative on initial memory allocations under + * the assumption that most streams will have a small number of children. This choice may be + * sub-optimal if when children are present there are many children (i.e. a web page which has many + * dependencies to load). + * + * Visible only for testing! + */ + static final int INITIAL_CHILDREN_MAP_SIZE = + max(1, SystemPropertyUtil.getInt("io.netty.http2.childrenMapSize", 2)); + /** + * FireFox currently uses 5 streams to establish QoS classes. + */ + private static final int DEFAULT_MAX_STATE_ONLY_SIZE = 5; + + private final Http2Connection.PropertyKey stateKey; + /** + * If there is no Http2Stream object, but we still persist priority information then this is where the state will + * reside. + */ + private final IntObjectMap stateOnlyMap; + /** + * This queue will hold streams that are not active and provides the capability to retain priority for streams which + * have no {@link Http2Stream} object. See {@link StateOnlyComparator} for the priority comparator. + */ + private final PriorityQueue stateOnlyRemovalQueue; + private final Http2Connection connection; + private final State connectionState; + /** + * The minimum number of bytes that we will attempt to allocate to a stream. This is to + * help improve goodput on a per-stream basis. 
+ */ + private int allocationQuantum = DEFAULT_MIN_ALLOCATION_CHUNK; + private final int maxStateOnlySize; + + public WeightedFairQueueByteDistributor(Http2Connection connection) { + this(connection, DEFAULT_MAX_STATE_ONLY_SIZE); + } + + public WeightedFairQueueByteDistributor(Http2Connection connection, int maxStateOnlySize) { + checkPositiveOrZero(maxStateOnlySize, "maxStateOnlySize"); + if (maxStateOnlySize == 0) { + stateOnlyMap = IntCollections.emptyMap(); + stateOnlyRemovalQueue = EmptyPriorityQueue.instance(); + } else { + stateOnlyMap = new IntObjectHashMap(maxStateOnlySize); + // +2 because we may exceed the limit by 2 if a new dependency has no associated Http2Stream object. We need + // to create the State objects to put them into the dependency tree, which then impacts priority. + stateOnlyRemovalQueue = new DefaultPriorityQueue(StateOnlyComparator.INSTANCE, maxStateOnlySize + 2); + } + this.maxStateOnlySize = maxStateOnlySize; + + this.connection = connection; + stateKey = connection.newKey(); + final Http2Stream connectionStream = connection.connectionStream(); + connectionStream.setProperty(stateKey, connectionState = new State(connectionStream, 16)); + + // Register for notification of new streams. + connection.addListener(new Http2ConnectionAdapter() { + @Override + public void onStreamAdded(Http2Stream stream) { + State state = stateOnlyMap.remove(stream.id()); + if (state == null) { + state = new State(stream); + // Only the stream which was just added will change parents. So we only need an array of size 1. 
+ List events = new ArrayList(1); + connectionState.takeChild(state, false, events); + notifyParentChanged(events); + } else { + stateOnlyRemovalQueue.removeTyped(state); + state.stream = stream; + } + switch (stream.state()) { + case RESERVED_REMOTE: + case RESERVED_LOCAL: + state.setStreamReservedOrActivated(); + // wasStreamReservedOrActivated is part of the comparator for stateOnlyRemovalQueue there is no + // need to reprioritize here because it will not be in stateOnlyRemovalQueue. + break; + default: + break; + } + stream.setProperty(stateKey, state); + } + + @Override + public void onStreamActive(Http2Stream stream) { + state(stream).setStreamReservedOrActivated(); + // wasStreamReservedOrActivated is part of the comparator for stateOnlyRemovalQueue there is no need to + // reprioritize here because it will not be in stateOnlyRemovalQueue. + } + + @Override + public void onStreamClosed(Http2Stream stream) { + state(stream).close(); + } + + @Override + public void onStreamRemoved(Http2Stream stream) { + // The stream has been removed from the connection. We can no longer rely on the stream's property + // storage to track the State. If we have room, and the precedence of the stream is sufficient, we + // should retain the State in the stateOnlyMap. + State state = state(stream); + + // Typically the stream is set to null when the stream is closed because it is no longer needed to write + // data. However if the stream was not activated it may not be closed (reserved streams) so we ensure + // the stream reference is set to null to avoid retaining a reference longer than necessary. 
+ state.stream = null; + + if (WeightedFairQueueByteDistributor.this.maxStateOnlySize == 0) { + state.parent.removeChild(state); + return; + } + if (stateOnlyRemovalQueue.size() == WeightedFairQueueByteDistributor.this.maxStateOnlySize) { + State stateToRemove = stateOnlyRemovalQueue.peek(); + if (StateOnlyComparator.INSTANCE.compare(stateToRemove, state) >= 0) { + // The "lowest priority" stream is a "higher priority" than the stream being removed, so we + // just discard the state. + state.parent.removeChild(state); + return; + } + stateOnlyRemovalQueue.poll(); + stateToRemove.parent.removeChild(stateToRemove); + stateOnlyMap.remove(stateToRemove.streamId); + } + stateOnlyRemovalQueue.add(state); + stateOnlyMap.put(state.streamId, state); + } + }); + } + + @Override + public void updateStreamableBytes(StreamState state) { + state(state.stream()).updateStreamableBytes(streamableBytes(state), + state.hasFrame() && state.windowSize() >= 0); + } + + @Override + public void updateDependencyTree(int childStreamId, int parentStreamId, short weight, boolean exclusive) { + State state = state(childStreamId); + if (state == null) { + // If there is no State object that means there is no Http2Stream object and we would have to keep the + // State object in the stateOnlyMap and stateOnlyRemovalQueue. However if maxStateOnlySize is 0 this means + // stateOnlyMap and stateOnlyRemovalQueue are empty collections and cannot be modified so we drop the State. + if (maxStateOnlySize == 0) { + return; + } + state = new State(childStreamId); + stateOnlyRemovalQueue.add(state); + stateOnlyMap.put(childStreamId, state); + } + + State newParent = state(parentStreamId); + if (newParent == null) { + // If there is no State object that means there is no Http2Stream object and we would have to keep the + // State object in the stateOnlyMap and stateOnlyRemovalQueue. 
However if maxStateOnlySize is 0 this means + // stateOnlyMap and stateOnlyRemovalQueue are empty collections and cannot be modified so we drop the State. + if (maxStateOnlySize == 0) { + return; + } + newParent = new State(parentStreamId); + stateOnlyRemovalQueue.add(newParent); + stateOnlyMap.put(parentStreamId, newParent); + // Only the stream which was just added will change parents. So we only need an array of size 1. + List events = new ArrayList(1); + connectionState.takeChild(newParent, false, events); + notifyParentChanged(events); + } + + // if activeCountForTree == 0 then it will not be in its parent's pseudoTimeQueue and thus should not be counted + // toward parent.totalQueuedWeights. + if (state.activeCountForTree != 0 && state.parent != null) { + state.parent.totalQueuedWeights += weight - state.weight; + } + state.weight = weight; + + if (newParent != state.parent || exclusive && newParent.children.size() != 1) { + final List events; + if (newParent.isDescendantOf(state)) { + events = new ArrayList(2 + (exclusive ? newParent.children.size() : 0)); + state.parent.takeChild(newParent, false, events); + } else { + events = new ArrayList(1 + (exclusive ? newParent.children.size() : 0)); + } + newParent.takeChild(state, exclusive, events); + notifyParentChanged(events); + } + + // The location in the dependency tree impacts the priority in the stateOnlyRemovalQueue map. If we created new + // State objects we must check if we exceeded the limit after we insert into the dependency tree to ensure the + // stateOnlyRemovalQueue has been updated. + while (stateOnlyRemovalQueue.size() > maxStateOnlySize) { + State stateToRemove = stateOnlyRemovalQueue.poll(); + stateToRemove.parent.removeChild(stateToRemove); + stateOnlyMap.remove(stateToRemove.streamId); + } + } + + @Override + public boolean distribute(int maxBytes, Writer writer) throws Http2Exception { + // As long as there is some active frame we should write at least 1 time. 
+ if (connectionState.activeCountForTree == 0) { + return false; + } + + // The goal is to write until we write all the allocated bytes or are no longer making progress. + // We still attempt to write even after the number of allocated bytes has been exhausted to allow empty frames + // to be sent. Making progress means the active streams rooted at the connection stream has changed. + int oldIsActiveCountForTree; + do { + oldIsActiveCountForTree = connectionState.activeCountForTree; + // connectionState will never be active, so go right to its children. + maxBytes -= distributeToChildren(maxBytes, writer, connectionState); + } while (connectionState.activeCountForTree != 0 && + (maxBytes > 0 || oldIsActiveCountForTree != connectionState.activeCountForTree)); + + return connectionState.activeCountForTree != 0; + } + + /** + * Sets the amount of bytes that will be allocated to each stream. Defaults to 1KiB. + * @param allocationQuantum the amount of bytes that will be allocated to each stream. Must be > 0. + */ + public void allocationQuantum(int allocationQuantum) { + checkPositive(allocationQuantum, "allocationQuantum"); + this.allocationQuantum = allocationQuantum; + } + + private int distribute(int maxBytes, Writer writer, State state) throws Http2Exception { + if (state.isActive()) { + int nsent = min(maxBytes, state.streamableBytes); + state.write(nsent, writer); + if (nsent == 0 && maxBytes != 0) { + // If a stream sends zero bytes, then we gave it a chance to write empty frames and it is now + // considered inactive until the next call to updateStreamableBytes. This allows descendant streams to + // be allocated bytes when the parent stream can't utilize them. This may be as a result of the + // stream's flow control window being 0. 
+ state.updateStreamableBytes(state.streamableBytes, false); + } + return nsent; + } + + return distributeToChildren(maxBytes, writer, state); + } + + /** + * It is a pre-condition that {@code state.poll()} returns a non-{@code null} value. This is a result of the way + * the allocation algorithm is structured and can be explained in the following cases: + *
+     * <h3>For the recursive case</h3>
+ * If a stream has no children (in the allocation tree) than that node must be active or it will not be in the + * allocation tree. If a node is active then it will not delegate to children and recursion ends. + *
+     * <h3>For the initial case</h3>
+ * We check connectionState.activeCountForTree == 0 before any allocation is done. So if the connection stream + * has no active children we don't get into this method. + */ + private int distributeToChildren(int maxBytes, Writer writer, State state) throws Http2Exception { + long oldTotalQueuedWeights = state.totalQueuedWeights; + State childState = state.pollPseudoTimeQueue(); + State nextChildState = state.peekPseudoTimeQueue(); + childState.setDistributing(); + try { + assert nextChildState == null || nextChildState.pseudoTimeToWrite >= childState.pseudoTimeToWrite : + "nextChildState[" + nextChildState.streamId + "].pseudoTime(" + nextChildState.pseudoTimeToWrite + + ") < " + " childState[" + childState.streamId + "].pseudoTime(" + childState.pseudoTimeToWrite + ')'; + int nsent = distribute(nextChildState == null ? maxBytes : + min(maxBytes, (int) min((nextChildState.pseudoTimeToWrite - childState.pseudoTimeToWrite) * + childState.weight / oldTotalQueuedWeights + allocationQuantum, MAX_VALUE) + ), + writer, + childState); + state.pseudoTime += nsent; + childState.updatePseudoTime(state, nsent, oldTotalQueuedWeights); + return nsent; + } finally { + childState.unsetDistributing(); + // Do in finally to ensure the internal flags is not corrupted if an exception is thrown. + // The offer operation is delayed until we unroll up the recursive stack, so we don't have to remove from + // the priority pseudoTimeQueue due to a write operation. + if (childState.activeCountForTree != 0) { + state.offerPseudoTimeQueue(childState); + } + } + } + + private State state(Http2Stream stream) { + return stream.getProperty(stateKey); + } + + private State state(int streamId) { + Http2Stream stream = connection.stream(streamId); + return stream != null ? state(stream) : stateOnlyMap.get(streamId); + } + + /** + * For testing only! 
+ */ + boolean isChild(int childId, int parentId, short weight) { + State parent = state(parentId); + State child; + return parent.children.containsKey(childId) && + (child = state(childId)).parent == parent && child.weight == weight; + } + + /** + * For testing only! + */ + int numChildren(int streamId) { + State state = state(streamId); + return state == null ? 0 : state.children.size(); + } + + /** + * Notify all listeners of the priority tree change events (in ascending order) + * @param events The events (top down order) which have changed + */ + void notifyParentChanged(List events) { + for (int i = 0; i < events.size(); ++i) { + ParentChangedEvent event = events.get(i); + stateOnlyRemovalQueue.priorityChanged(event.state); + if (event.state.parent != null && event.state.activeCountForTree != 0) { + event.state.parent.offerAndInitializePseudoTime(event.state); + event.state.parent.activeCountChangeForTree(event.state.activeCountForTree); + } + } + } + + /** + * A comparator for {@link State} which has no associated {@link Http2Stream} object. The general precedence is: + *
    + *
  • Was a stream activated or reserved (streams only used for priority are higher priority)
  • + *
  • Depth in the priority tree (closer to root is higher priority>
  • + *
  • Stream ID (higher stream ID is higher priority - used for tie breaker)
  • + *
+ */ + private static final class StateOnlyComparator implements Comparator, Serializable { + private static final long serialVersionUID = -4806936913002105966L; + + static final StateOnlyComparator INSTANCE = new StateOnlyComparator(); + + @Override + public int compare(State o1, State o2) { + // "priority only streams" (which have not been activated) are higher priority than streams used for data. + boolean o1Actived = o1.wasStreamReservedOrActivated(); + if (o1Actived != o2.wasStreamReservedOrActivated()) { + return o1Actived ? -1 : 1; + } + // Numerically greater depth is higher priority. + int x = o2.dependencyTreeDepth - o1.dependencyTreeDepth; + + // I also considered tracking the number of streams which are "activated" (eligible transfer data) at each + // subtree. This would require a traversal from each node to the root on dependency tree structural changes, + // and then it would require a re-prioritization at each of these nodes (instead of just the nodes where the + // direct parent changed). The costs of this are judged to be relatively high compared to the nominal + // benefit it provides to the heuristic. Instead folks should just increase maxStateOnlySize. + + // Last resort is to give larger stream ids more priority. + return x != 0 ? x : o1.streamId - o2.streamId; + } + } + + private static final class StatePseudoTimeComparator implements Comparator, Serializable { + private static final long serialVersionUID = -1437548640227161828L; + + static final StatePseudoTimeComparator INSTANCE = new StatePseudoTimeComparator(); + + @Override + public int compare(State o1, State o2) { + return MathUtil.compare(o1.pseudoTimeToWrite, o2.pseudoTimeToWrite); + } + } + + /** + * The remote flow control state for a single stream. 
+ */ + private final class State implements PriorityQueueNode { + private static final byte STATE_IS_ACTIVE = 0x1; + private static final byte STATE_IS_DISTRIBUTING = 0x2; + private static final byte STATE_STREAM_ACTIVATED = 0x4; + + /** + * Maybe {@code null} if the stream if the stream is not active. + */ + Http2Stream stream; + State parent; + IntObjectMap children = IntCollections.emptyMap(); + private final PriorityQueue pseudoTimeQueue; + final int streamId; + int streamableBytes; + int dependencyTreeDepth; + /** + * Count of nodes rooted at this sub tree with {@link #isActive()} equal to {@code true}. + */ + int activeCountForTree; + private int pseudoTimeQueueIndex = INDEX_NOT_IN_QUEUE; + private int stateOnlyQueueIndex = INDEX_NOT_IN_QUEUE; + /** + * An estimate of when this node should be given the opportunity to write data. + */ + long pseudoTimeToWrite; + /** + * A pseudo time maintained for immediate children to base their {@link #pseudoTimeToWrite} off of. + */ + long pseudoTime; + long totalQueuedWeights; + private byte flags; + short weight = DEFAULT_PRIORITY_WEIGHT; + + State(int streamId) { + this(streamId, null, 0); + } + + State(Http2Stream stream) { + this(stream, 0); + } + + State(Http2Stream stream, int initialSize) { + this(stream.id(), stream, initialSize); + } + + State(int streamId, Http2Stream stream, int initialSize) { + this.stream = stream; + this.streamId = streamId; + pseudoTimeQueue = new DefaultPriorityQueue(StatePseudoTimeComparator.INSTANCE, initialSize); + } + + boolean isDescendantOf(State state) { + State next = parent; + while (next != null) { + if (next == state) { + return true; + } + next = next.parent; + } + return false; + } + + void takeChild(State child, boolean exclusive, List events) { + takeChild(null, child, exclusive, events); + } + + /** + * Adds a child to this priority. If exclusive is set, any children of this node are moved to being dependent on + * the child. 
+ */ + void takeChild(Iterator> childItr, State child, boolean exclusive, + List events) { + State oldParent = child.parent; + + if (oldParent != this) { + events.add(new ParentChangedEvent(child, oldParent)); + child.setParent(this); + // If the childItr is not null we are iterating over the oldParent.children collection and should + // use the iterator to remove from the collection to avoid concurrent modification. Otherwise it is + // assumed we are not iterating over this collection and it is safe to call remove directly. + if (childItr != null) { + childItr.remove(); + } else if (oldParent != null) { + oldParent.children.remove(child.streamId); + } + + // Lazily initialize the children to save object allocations. + initChildrenIfEmpty(); + + final State oldChild = children.put(child.streamId, child); + assert oldChild == null : "A stream with the same stream ID was already in the child map."; + } + + if (exclusive && !children.isEmpty()) { + // If it was requested that this child be the exclusive dependency of this node, + // move any previous children to the child node, becoming grand children of this node. + Iterator> itr = removeAllChildrenExcept(child).entries().iterator(); + while (itr.hasNext()) { + child.takeChild(itr, itr.next().value(), false, events); + } + } + } + + /** + * Removes the child priority and moves any of its dependencies to being direct dependencies on this node. + */ + void removeChild(State child) { + if (children.remove(child.streamId) != null) { + List events = new ArrayList(1 + child.children.size()); + events.add(new ParentChangedEvent(child, child.parent)); + child.setParent(null); + + if (!child.children.isEmpty()) { + // Move up any grand children to be directly dependent on this node. + Iterator> itr = child.children.entries().iterator(); + long totalWeight = child.getTotalWeight(); + do { + // Redistribute the weight of child to its dependency proportionally. 
+ State dependency = itr.next().value(); + dependency.weight = (short) max(1, dependency.weight * child.weight / totalWeight); + takeChild(itr, dependency, false, events); + } while (itr.hasNext()); + } + + notifyParentChanged(events); + } + } + + private long getTotalWeight() { + long totalWeight = 0L; + for (State state : children.values()) { + totalWeight += state.weight; + } + return totalWeight; + } + + /** + * Remove all children with the exception of {@code streamToRetain}. + * This method is intended to be used to support an exclusive priority dependency operation. + * @return The map of children prior to this operation, excluding {@code streamToRetain} if present. + */ + private IntObjectMap removeAllChildrenExcept(State stateToRetain) { + stateToRetain = children.remove(stateToRetain.streamId); + IntObjectMap prevChildren = children; + // This map should be re-initialized in anticipation for the 1 exclusive child which will be added. + // It will either be added directly in this method, or after this method is called...but it will be added. + initChildren(); + if (stateToRetain != null) { + children.put(stateToRetain.streamId, stateToRetain); + } + return prevChildren; + } + + private void setParent(State newParent) { + // if activeCountForTree == 0 then it will not be in its parent's pseudoTimeQueue. + if (activeCountForTree != 0 && parent != null) { + parent.removePseudoTimeQueue(this); + parent.activeCountChangeForTree(-activeCountForTree); + } + parent = newParent; + // Use MAX_VALUE if no parent because lower depth is considered higher priority by StateOnlyComparator. + dependencyTreeDepth = newParent == null ? 
MAX_VALUE : newParent.dependencyTreeDepth + 1; + } + + private void initChildrenIfEmpty() { + if (children == IntCollections.emptyMap()) { + initChildren(); + } + } + + private void initChildren() { + children = new IntObjectHashMap(INITIAL_CHILDREN_MAP_SIZE); + } + + void write(int numBytes, Writer writer) throws Http2Exception { + assert stream != null; + try { + writer.write(stream, numBytes); + } catch (Throwable t) { + throw connectionError(INTERNAL_ERROR, t, "byte distribution write error"); + } + } + + void activeCountChangeForTree(int increment) { + assert activeCountForTree + increment >= 0; + activeCountForTree += increment; + if (parent != null) { + assert activeCountForTree != increment || + pseudoTimeQueueIndex == INDEX_NOT_IN_QUEUE || + parent.pseudoTimeQueue.containsTyped(this) : + "State[" + streamId + "].activeCountForTree changed from 0 to " + increment + " is in a " + + "pseudoTimeQueue, but not in parent[ " + parent.streamId + "]'s pseudoTimeQueue"; + if (activeCountForTree == 0) { + parent.removePseudoTimeQueue(this); + } else if (activeCountForTree == increment && !isDistributing()) { + // If frame count was 0 but is now not, and this node is not already in a pseudoTimeQueue (assumed + // to be pState's pseudoTimeQueue) then enqueue it. If this State object is being processed the + // pseudoTime for this node should not be adjusted, and the node will be added back to the + // pseudoTimeQueue/tree structure after it is done being processed. This may happen if the + // activeCountForTree == 0 (a node which can't stream anything and is blocked) is at/near root of + // the tree, and is popped off the pseudoTimeQueue during processing, and then put back on the + // pseudoTimeQueue because a child changes position in the priority tree (or is closed because it is + // not blocked and finished writing all data). 
+ parent.offerAndInitializePseudoTime(this); + } + parent.activeCountChangeForTree(increment); + } + } + + void updateStreamableBytes(int newStreamableBytes, boolean isActive) { + if (isActive() != isActive) { + if (isActive) { + activeCountChangeForTree(1); + setActive(); + } else { + activeCountChangeForTree(-1); + unsetActive(); + } + } + + streamableBytes = newStreamableBytes; + } + + /** + * Assumes the parents {@link #totalQueuedWeights} includes this node's weight. + */ + void updatePseudoTime(State parentState, int nsent, long totalQueuedWeights) { + assert streamId != CONNECTION_STREAM_ID && nsent >= 0; + // If the current pseudoTimeToSend is greater than parentState.pseudoTime then we previously over accounted + // and should use parentState.pseudoTime. + pseudoTimeToWrite = min(pseudoTimeToWrite, parentState.pseudoTime) + nsent * totalQueuedWeights / weight; + } + + /** + * The concept of pseudoTime can be influenced by priority tree manipulations or if a stream goes from "active" + * to "non-active". This method accounts for that by initializing the {@link #pseudoTimeToWrite} for + * {@code state} to {@link #pseudoTime} of this node and then calls {@link #offerPseudoTimeQueue(State)}. + */ + void offerAndInitializePseudoTime(State state) { + state.pseudoTimeToWrite = pseudoTime; + offerPseudoTimeQueue(state); + } + + void offerPseudoTimeQueue(State state) { + pseudoTimeQueue.offer(state); + totalQueuedWeights += state.weight; + } + + /** + * Must only be called if the pseudoTimeQueue is non-empty! + */ + State pollPseudoTimeQueue() { + State state = pseudoTimeQueue.poll(); + // This method is only ever called if the pseudoTimeQueue is non-empty. 
+ totalQueuedWeights -= state.weight; + return state; + } + + void removePseudoTimeQueue(State state) { + if (pseudoTimeQueue.removeTyped(state)) { + totalQueuedWeights -= state.weight; + } + } + + State peekPseudoTimeQueue() { + return pseudoTimeQueue.peek(); + } + + void close() { + updateStreamableBytes(0, false); + stream = null; + } + + boolean wasStreamReservedOrActivated() { + return (flags & STATE_STREAM_ACTIVATED) != 0; + } + + void setStreamReservedOrActivated() { + flags |= STATE_STREAM_ACTIVATED; + } + + boolean isActive() { + return (flags & STATE_IS_ACTIVE) != 0; + } + + private void setActive() { + flags |= STATE_IS_ACTIVE; + } + + private void unsetActive() { + flags &= ~STATE_IS_ACTIVE; + } + + boolean isDistributing() { + return (flags & STATE_IS_DISTRIBUTING) != 0; + } + + void setDistributing() { + flags |= STATE_IS_DISTRIBUTING; + } + + void unsetDistributing() { + flags &= ~STATE_IS_DISTRIBUTING; + } + + @Override + public int priorityQueueIndex(DefaultPriorityQueue queue) { + return queue == stateOnlyRemovalQueue ? stateOnlyQueueIndex : pseudoTimeQueueIndex; + } + + @Override + public void priorityQueueIndex(DefaultPriorityQueue queue, int i) { + if (queue == stateOnlyRemovalQueue) { + stateOnlyQueueIndex = i; + } else { + pseudoTimeQueueIndex = i; + } + } + + @Override + public String toString() { + // Use activeCountForTree as a rough estimate for how many nodes are in this subtree. + StringBuilder sb = new StringBuilder(256 * (activeCountForTree > 0 ? 
activeCountForTree : 1)); + toString(sb); + return sb.toString(); + } + + private void toString(StringBuilder sb) { + sb.append("{streamId ").append(streamId) + .append(" streamableBytes ").append(streamableBytes) + .append(" activeCountForTree ").append(activeCountForTree) + .append(" pseudoTimeQueueIndex ").append(pseudoTimeQueueIndex) + .append(" pseudoTimeToWrite ").append(pseudoTimeToWrite) + .append(" pseudoTime ").append(pseudoTime) + .append(" flags ").append(flags) + .append(" pseudoTimeQueue.size() ").append(pseudoTimeQueue.size()) + .append(" stateOnlyQueueIndex ").append(stateOnlyQueueIndex) + .append(" parent.streamId ").append(parent == null ? -1 : parent.streamId).append("} ["); + + if (!pseudoTimeQueue.isEmpty()) { + for (State s : pseudoTimeQueue) { + s.toString(sb); + sb.append(", "); + } + // Remove the last ", " + sb.setLength(sb.length() - 2); + } + sb.append(']'); + } + } + + /** + * Allows a correlation to be made between a stream and its old parent before a parent change occurs. + */ + private static final class ParentChangedEvent { + final State state; + final State oldParent; + + /** + * Create a new instance. + * @param state The state who has had a parent change. + * @param oldParent The previous parent. + */ + ParentChangedEvent(State state, State oldParent) { + this.state = state; + this.oldParent = oldParent; + } + } +} diff --git a/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/package-info.java b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/package-info.java new file mode 100644 index 0000000..c8c7cbc --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/io/netty/handler/codec/http2/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +/** + * Handlers for sending and receiving HTTP/2 frames. + */ +@UnstableApi +package io.netty.handler.codec.http2; + +import io.netty.util.internal.UnstableApi; diff --git a/netty-handler-codec-http2/src/main/java/module-info.java b/netty-handler-codec-http2/src/main/java/module-info.java new file mode 100644 index 0000000..9a9a809 --- /dev/null +++ b/netty-handler-codec-http2/src/main/java/module-info.java @@ -0,0 +1,12 @@ +module org.xbib.io.netty.handler.codec.httptwo { + exports io.netty.handler.codec.http2; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.handler; + requires org.xbib.io.netty.handler.codec; + requires org.xbib.io.netty.handler.codec.compression; + requires org.xbib.io.netty.handler.codec.http; + requires org.xbib.io.netty.handler.ssl; + requires org.xbib.io.netty.util; + requires com.aayushatharva.brotli4j; +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractDecoratingHttp2ConnectionDecoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractDecoratingHttp2ConnectionDecoderTest.java new file mode 100644 index 0000000..cd5e43b --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractDecoratingHttp2ConnectionDecoderTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the 
License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public abstract class AbstractDecoratingHttp2ConnectionDecoderTest { + + protected abstract DecoratingHttp2ConnectionDecoder newDecoder(Http2ConnectionDecoder decoder); + + protected abstract Class delegatingFrameListenerType(); + + @Test + public void testDecoration() { + Http2ConnectionDecoder delegate = mock(Http2ConnectionDecoder.class); + final ArgumentCaptor listenerArgumentCaptor = + ArgumentCaptor.forClass(Http2FrameListener.class); + when(delegate.frameListener()).then(new Answer() { + @Override + public Http2FrameListener answer(InvocationOnMock invocationOnMock) { + return listenerArgumentCaptor.getValue(); + } + }); + Http2FrameListener listener = mock(Http2FrameListener.class); + DecoratingHttp2ConnectionDecoder decoder = newDecoder(delegate); + decoder.frameListener(listener); + verify(delegate).frameListener(listenerArgumentCaptor.capture()); + + assertThat(decoder.frameListener(), + CoreMatchers.not(CoreMatchers.instanceOf(delegatingFrameListenerType()))); + } + + @Test + public void testDecorationWithNull() { + Http2ConnectionDecoder delegate = 
mock(Http2ConnectionDecoder.class); + + DecoratingHttp2ConnectionDecoder decoder = newDecoder(delegate); + decoder.frameListener(null); + assertNull(decoder.frameListener()); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractWeightedFairQueueByteDistributorDependencyTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractWeightedFairQueueByteDistributorDependencyTest.java new file mode 100644 index 0000000..724a653 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/AbstractWeightedFairQueueByteDistributorDependencyTest.java @@ -0,0 +1,72 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http2.Http2TestUtil.TestStreamByteDistributorStreamState; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import org.mockito.Mock; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +abstract class AbstractWeightedFairQueueByteDistributorDependencyTest { + Http2Connection connection; + WeightedFairQueueByteDistributor distributor; + private final IntObjectMap stateMap = + new IntObjectHashMap(); + + @Mock + StreamByteDistributor.Writer writer; + + Http2Stream stream(int streamId) { + return connection.stream(streamId); + } + + Answer writeAnswer(final boolean closeIfNoFrame) { + return new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + Http2Stream stream = in.getArgument(0); + int numBytes = in.getArgument(1); + TestStreamByteDistributorStreamState state = stateMap.get(stream.id()); + state.pendingBytes -= numBytes; + state.hasFrame = state.pendingBytes > 0; + state.isWriteAllowed = state.hasFrame; + if (closeIfNoFrame && !state.hasFrame) { + stream.close(); + } + distributor.updateStreamableBytes(state); + return null; + } + }; + } + + void initState(final int streamId, final long streamableBytes, final boolean hasFrame) { + initState(streamId, streamableBytes, hasFrame, hasFrame); + } + + void initState(final int streamId, final long pendingBytes, final boolean hasFrame, + final boolean isWriteAllowed) { + final Http2Stream stream = stream(streamId); + TestStreamByteDistributorStreamState state = new TestStreamByteDistributorStreamState(stream, pendingBytes, + hasFrame, isWriteAllowed); + stateMap.put(streamId, state); + distributor.updateStreamableBytes(state); + } + + void setPriority(int streamId, int parent, int weight, boolean exclusive) throws Http2Exception { + distributor.updateDependencyTree(streamId, parent, (short) weight, exclusive); + } +} diff --git 
a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java new file mode 100644 index 0000000..696650c --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/CleartextHttp2ServerUpgradeHandlerTest.java @@ -0,0 +1,291 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpServerCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodec; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeCodecFactory; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeEvent; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http2.CleartextHttp2ServerUpgradeHandler.PriorKnowledgeUpgradeEvent; +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Test; + +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link CleartextHttp2ServerUpgradeHandler} + */ +public class CleartextHttp2ServerUpgradeHandlerTest { + private EmbeddedChannel channel; + + private Http2FrameListener frameListener; + + 
private Http2ConnectionHandler http2ConnectionHandler; + + private List userEvents; + + private void setUpServerChannel() { + frameListener = mock(Http2FrameListener.class); + + http2ConnectionHandler = new Http2ConnectionHandlerBuilder() + .frameListener(frameListener).build(); + + UpgradeCodecFactory upgradeCodecFactory = new UpgradeCodecFactory() { + @Override + public UpgradeCodec newUpgradeCodec(CharSequence protocol) { + return new Http2ServerUpgradeCodec(http2ConnectionHandler); + } + }; + + userEvents = new ArrayList(); + + HttpServerCodec httpServerCodec = new HttpServerCodec(); + HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, upgradeCodecFactory); + + CleartextHttp2ServerUpgradeHandler handler = new CleartextHttp2ServerUpgradeHandler( + httpServerCodec, upgradeHandler, http2ConnectionHandler); + channel = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + userEvents.add(evt); + } + }); + } + + @AfterEach + public void tearDown() throws Exception { + channel.finishAndReleaseAll(); + } + + @Test + public void priorKnowledge() throws Exception { + setUpServerChannel(); + + channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); + + ByteBuf settingsFrame = settingsFrameBuf(); + + assertFalse(channel.writeInbound(settingsFrame)); + + assertEquals(1, userEvents.size()); + assertTrue(userEvents.get(0) instanceof PriorKnowledgeUpgradeEvent); + + assertEquals(100, http2ConnectionHandler.connection().local().maxActiveStreams()); + assertEquals(65535, http2ConnectionHandler.connection().local().flowController().initialWindowSize()); + + verify(frameListener).onSettingsRead( + any(ChannelHandlerContext.class), eq(expectedSettings())); + } + + @Test + public void upgrade() throws Exception { + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: Upgrade, 
HTTP2-Settings\r\n" + + "Upgrade: h2c\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + validateClearTextUpgrade(upgradeString); + } + + @Test + public void upgradeWithMultipleConnectionHeaders() { + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: keep-alive\r\n" + + "Connection: Upgrade, HTTP2-Settings\r\n" + + "Upgrade: h2c\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + validateClearTextUpgrade(upgradeString); + } + + @Test + public void requiredHeadersInSeparateConnectionHeaders() { + String upgradeString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n" + + "Connection: keep-alive\r\n" + + "Connection: HTTP2-Settings\r\n" + + "Connection: Upgrade\r\n" + + "Upgrade: h2c\r\n" + + "HTTP2-Settings: AAMAAABkAAQAAP__\r\n\r\n"; + validateClearTextUpgrade(upgradeString); + } + + @Test + public void priorKnowledgeInFragments() throws Exception { + setUpServerChannel(); + + ByteBuf connectionPreface = Http2CodecUtil.connectionPrefaceBuf(); + assertFalse(channel.writeInbound(connectionPreface.readBytes(5), connectionPreface)); + + ByteBuf settingsFrame = settingsFrameBuf(); + assertFalse(channel.writeInbound(settingsFrame)); + + assertEquals(1, userEvents.size()); + assertTrue(userEvents.get(0) instanceof PriorKnowledgeUpgradeEvent); + + assertEquals(100, http2ConnectionHandler.connection().local().maxActiveStreams()); + assertEquals(65535, http2ConnectionHandler.connection().local().flowController().initialWindowSize()); + + verify(frameListener).onSettingsRead( + any(ChannelHandlerContext.class), eq(expectedSettings())); + } + + @Test + public void downgrade() throws Exception { + setUpServerChannel(); + + String requestString = "GET / HTTP/1.1\r\n" + + "Host: example.com\r\n\r\n"; + ByteBuf inbound = Unpooled.buffer().writeBytes(requestString.getBytes(CharsetUtil.US_ASCII)); + + assertTrue(channel.writeInbound(inbound)); + + Object firstInbound = channel.readInbound(); + assertTrue(firstInbound instanceof 
HttpRequest); + HttpRequest request = (HttpRequest) firstInbound; + assertEquals(HttpMethod.GET, request.method()); + assertEquals("/", request.uri()); + assertEquals(HttpVersion.HTTP_1_1, request.protocolVersion()); + assertEquals(new DefaultHttpHeaders().add("Host", "example.com"), request.headers()); + + ((LastHttpContent) channel.readInbound()).release(); + + assertNull(channel.readInbound()); + } + + @Test + public void usedHttp2MultiplexCodec() throws Exception { + final Http2MultiplexCodec http2Codec = new Http2MultiplexCodecBuilder(true, new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + } + }).build(); + UpgradeCodecFactory upgradeCodecFactory = new UpgradeCodecFactory() { + @Override + public UpgradeCodec newUpgradeCodec(CharSequence protocol) { + return new Http2ServerUpgradeCodec(http2Codec); + } + }; + http2ConnectionHandler = http2Codec; + + userEvents = new ArrayList(); + + HttpServerCodec httpServerCodec = new HttpServerCodec(); + HttpServerUpgradeHandler upgradeHandler = new HttpServerUpgradeHandler(httpServerCodec, upgradeCodecFactory); + + CleartextHttp2ServerUpgradeHandler handler = new CleartextHttp2ServerUpgradeHandler( + httpServerCodec, upgradeHandler, http2Codec); + channel = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + userEvents.add(evt); + } + }); + + assertFalse(channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf())); + + ByteBuf settingsFrame = settingsFrameBuf(); + + assertTrue(channel.writeInbound(settingsFrame)); + + assertEquals(1, userEvents.size()); + assertTrue(userEvents.get(0) instanceof PriorKnowledgeUpgradeEvent); + } + + private static ByteBuf settingsFrameBuf() { + ByteBuf settingsFrame = Unpooled.buffer(); + settingsFrame.writeMedium(12); // Payload length + settingsFrame.writeByte(0x4); // Frame type + settingsFrame.writeByte(0x0); // Flags + 
settingsFrame.writeInt(0x0); // StreamId + settingsFrame.writeShort(0x3); + settingsFrame.writeInt(100); + settingsFrame.writeShort(0x4); + settingsFrame.writeInt(65535); + + return settingsFrame; + } + + private static Http2Settings expectedSettings() { + return new Http2Settings().maxConcurrentStreams(100).initialWindowSize(65535); + } + + private void validateClearTextUpgrade(String upgradeString) { + setUpServerChannel(); + + ByteBuf upgrade = Unpooled.copiedBuffer(upgradeString, CharsetUtil.US_ASCII); + + assertFalse(channel.writeInbound(upgrade)); + + assertEquals(1, userEvents.size()); + + Object userEvent = userEvents.get(0); + assertTrue(userEvent instanceof UpgradeEvent); + assertEquals("h2c", ((UpgradeEvent) userEvent).protocol()); + ReferenceCountUtil.release(userEvent); + + assertEquals(100, http2ConnectionHandler.connection().local().maxActiveStreams()); + assertEquals(65535, http2ConnectionHandler.connection().local().flowController().initialWindowSize()); + + assertEquals(1, http2ConnectionHandler.connection().numActiveStreams()); + assertNotNull(http2ConnectionHandler.connection().stream(1)); + + Http2Stream stream = http2ConnectionHandler.connection().stream(1); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + assertFalse(stream.isHeadersSent()); + + String expectedHttpResponse = "HTTP/1.1 101 Switching Protocols\r\n" + + "connection: upgrade\r\n" + + "upgrade: h2c\r\n\r\n"; + ByteBuf responseBuffer = channel.readOutbound(); + assertEquals(expectedHttpResponse, responseBuffer.toString(CharsetUtil.UTF_8)); + responseBuffer.release(); + + // Check that the preface was send (a.k.a the settings frame) + ByteBuf settingsBuffer = channel.readOutbound(); + assertNotNull(settingsBuffer); + settingsBuffer.release(); + + assertNull(channel.readOutbound()); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java 
b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java new file mode 100644 index 0000000..0401449 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DataCompressionHttp2Test.java @@ -0,0 +1,533 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.compression.Brotli; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.AfterEach; +import 
org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.util.Random; +import java.util.concurrent.CountDownLatch; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.verify; + +/** + * Test for data decompression in the HTTP/2 codec. 
+ */ +public class DataCompressionHttp2Test { + private static final AsciiString GET = new AsciiString("GET"); + private static final AsciiString POST = new AsciiString("POST"); + private static final AsciiString PATH = new AsciiString("/some/path"); + + @Mock + private Http2FrameListener serverListener; + @Mock + private Http2FrameListener clientListener; + + private Http2ConnectionEncoder clientEncoder; + private ServerBootstrap sb; + private Bootstrap cb; + private Channel serverChannel; + private Channel clientChannel; + private volatile Channel serverConnectedChannel; + private CountDownLatch serverLatch; + private Http2Connection serverConnection; + private Http2Connection clientConnection; + private Http2ConnectionHandler clientHandler; + private ByteArrayOutputStream serverOut; + + @BeforeAll + public static void beforeAllTests() throws Throwable { + Brotli.ensureAvailability(); + } + + @BeforeEach + public void setup() throws InterruptedException, Http2Exception { + MockitoAnnotations.initMocks(this); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + if (invocation.getArgument(4)) { + serverConnection.stream((Integer) invocation.getArgument(1)).close(); + } + return null; + } + }).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyBoolean()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + if (invocation.getArgument(7)) { + serverConnection.stream((Integer) invocation.getArgument(1)).close(); + } + return null; + } + }).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + } + + @AfterEach + public void cleanup() throws IOException { + serverOut.close(); + } + + @AfterEach + public void teardown() throws InterruptedException { + if (clientChannel != null) { + 
clientChannel.close().sync(); + clientChannel = null; + } + if (serverChannel != null) { + serverChannel.close().sync(); + serverChannel = null; + } + final Channel serverConnectedChannel = this.serverConnectedChannel; + if (serverConnectedChannel != null) { + serverConnectedChannel.close().sync(); + this.serverConnectedChannel = null; + } + Future serverGroup = sb.config().group().shutdownGracefully(0, 0, MILLISECONDS); + Future serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 0, MILLISECONDS); + Future clientGroup = cb.config().group().shutdownGracefully(0, 0, MILLISECONDS); + serverGroup.sync(); + serverChildGroup.sync(); + clientGroup.sync(); + } + + @Test + public void justHeadersNoData() throws Exception { + bootstrapEnv(0); + final Http2Headers headers = new DefaultHttp2Headers().method(GET).path(PATH) + .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, true, newPromiseClient()); + clientHandler.flush(ctxClient()); + } + }); + awaitServer(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true)); + } + + @Test + public void gzipEncodingSingleEmptyMessage() throws Exception { + final String text = ""; + final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); + bootstrapEnv(data.readableBytes()); + try { + final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) + .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientHandler.flush(ctxClient()); 
+ } + }); + awaitServer(); + assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); + } finally { + data.release(); + } + } + + @Test + public void gzipEncodingSingleMessage() throws Exception { + final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; + final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); + bootstrapEnv(data.readableBytes()); + try { + final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) + .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientHandler.flush(ctxClient()); + } + }); + awaitServer(); + assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); + } finally { + data.release(); + } + } + + @Test + public void gzipEncodingMultipleMessages() throws Exception { + final String text1 = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; + final String text2 = "dddddddddddddddddddeeeeeeeeeeeeeeeeeeeffffffffffffffffffff"; + final ByteBuf data1 = Unpooled.copiedBuffer(text1.getBytes()); + final ByteBuf data2 = Unpooled.copiedBuffer(text2.getBytes()); + bootstrapEnv(data1.readableBytes() + data2.readableBytes()); + try { + final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) + .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data1.retain(), 0, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data2.retain(), 0, true, 
newPromiseClient()); + clientHandler.flush(ctxClient()); + } + }); + awaitServer(); + assertEquals(text1 + text2, serverOut.toString(CharsetUtil.UTF_8.name())); + } finally { + data1.release(); + data2.release(); + } + } + + @Test + public void brotliEncodingSingleEmptyMessage() throws Exception { + final String text = ""; + final ByteBuf data = Unpooled.copiedBuffer(text.getBytes()); + bootstrapEnv(data.readableBytes()); + try { + final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) + .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.BR); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientHandler.flush(ctxClient()); + } + }); + awaitServer(); + assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); + } finally { + data.release(); + } + } + + @Test + public void brotliEncodingSingleMessage() throws Exception { + final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc"; + final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.UTF_8.name())); + bootstrapEnv(data.readableBytes()); + try { + final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH) + .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.BR); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient()); + clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient()); + clientHandler.flush(ctxClient()); + } + }); + awaitServer(); + assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name())); + } finally { + data.release(); + } + } + + @Test + public void 
zstdEncodingSingleEmptyMessage() throws Exception {
        // Empty zstd-encoded message must round-trip to an empty server-side payload.
        final String text = "";
        final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
        bootstrapEnv(data.readableBytes());
        try {
            final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
                    .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.ZSTD);

            runInChannel(clientChannel, new Http2Runnable() {
                @Override
                public void run() throws Http2Exception {
                    clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
                    clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
                    clientHandler.flush(ctxClient());
                }
            });
            awaitServer();
            assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
        } finally {
            data.release();
        }
    }

    /**
     * A single zstd-encoded DATA frame must be decompressed back to the original text.
     */
    @Test
    public void zstdEncodingSingleMessage() throws Exception {
        final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
        final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.UTF_8.name()));
        bootstrapEnv(data.readableBytes());
        try {
            final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
                    .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.ZSTD);

            runInChannel(clientChannel, new Http2Runnable() {
                @Override
                public void run() throws Http2Exception {
                    clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
                    clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
                    clientHandler.flush(ctxClient());
                }
            });
            awaitServer();
            assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
        } finally {
            data.release();
        }
    }

    /**
     * An empty snappy-encoded message must round-trip to an empty server-side payload.
     */
    @Test
    public void snappyEncodingSingleEmptyMessage() throws Exception {
        final String text = "";
        final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.US_ASCII));
        bootstrapEnv(data.readableBytes());
        try {
            final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
                    .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.SNAPPY);

            runInChannel(clientChannel, new Http2Runnable() {
                @Override
                public void run() throws Http2Exception {
                    clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
                    clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
                    clientHandler.flush(ctxClient());
                }
            });
            awaitServer();
            assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
        } finally {
            data.release();
        }
    }

    /**
     * A single snappy-encoded DATA frame must be decompressed back to the original text.
     */
    @Test
    public void snappyEncodingSingleMessage() throws Exception {
        final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
        final ByteBuf data = Unpooled.copiedBuffer(text.getBytes(CharsetUtil.US_ASCII));
        bootstrapEnv(data.readableBytes());
        try {
            final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
                    .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.SNAPPY);

            runInChannel(clientChannel, new Http2Runnable() {
                @Override
                public void run() throws Http2Exception {
                    clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
                    clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
                    clientHandler.flush(ctxClient());
                }
            });
            awaitServer();
            assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
        } finally {
            data.release();
        }
    }

    /**
     * A large (4 KiB) random payload must survive deflate compression/decompression intact.
     */
    @Test
    public void deflateEncodingWriteLargeMessage() throws Exception {
        final int BUFFER_SIZE = 1 << 12;
        final byte[] bytes = new byte[BUFFER_SIZE];
        new Random().nextBytes(bytes);
        bootstrapEnv(BUFFER_SIZE);
        final ByteBuf data = Unpooled.wrappedBuffer(bytes);
        try {
            final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
                    .set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.DEFLATE);

            runInChannel(clientChannel, new Http2Runnable() {
                @Override
                public void run()
                        throws Http2Exception {
                    clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
                    clientEncoder.writeData(ctxClient(), 3, data.retain(), 0, true, newPromiseClient());
                    clientHandler.flush(ctxClient());
                }
            });
            awaitServer();
            // Reset the reader index: writeData consumed the buffer, so rewind before comparing.
            assertEquals(data.resetReaderIndex().toString(CharsetUtil.UTF_8),
                    serverOut.toString(CharsetUtil.UTF_8.name()));
        } finally {
            data.release();
        }
    }

    /**
     * Builds a real client/server NIO pair whose pipelines wrap the default HTTP/2
     * encoder in a {@code CompressorHttp2ConnectionEncoder} and the frame listener in a
     * {@code DelegatingDecompressorFrameListener}. Decompressed server-side DATA bytes
     * are captured into {@code serverOut}; blocks until both the client preface is
     * written and the server child channel is initialized.
     *
     * @param serverOutSize initial capacity hint for the server-side capture buffer
     */
    private void bootstrapEnv(int serverOutSize) throws Exception {
        final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1);
        serverOut = new ByteArrayOutputStream(serverOutSize);
        serverLatch = new CountDownLatch(1);
        sb = new ServerBootstrap();
        cb = new Bootstrap();

        // Streams are created before the normal flow for this test, so these connection must be initialized up front.
        serverConnection = new DefaultHttp2Connection(true);
        clientConnection = new DefaultHttp2Connection(false);

        // Count down once the test stream is fully closed on the server side.
        serverConnection.addListener(new Http2ConnectionAdapter() {
            @Override
            public void onStreamClosed(Http2Stream stream) {
                serverLatch.countDown();
            }
        });

        // Capture decompressed DATA into serverOut; close the stream when end-of-stream arrives.
        doAnswer(new Answer() {
            @Override
            public Integer answer(InvocationOnMock in) throws Throwable {
                ByteBuf buf = (ByteBuf) in.getArguments()[2];
                int padding = (Integer) in.getArguments()[3];
                int processedBytes = buf.readableBytes() + padding;

                buf.readBytes(serverOut, buf.readableBytes());

                if (in.getArgument(4)) {
                    serverConnection.stream((Integer) in.getArgument(1)).close();
                }
                return processedBytes;
            }
        }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(),
                any(ByteBuf.class), anyInt(), anyBoolean());

        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
        sb.group(new NioEventLoopGroup(), new NioEventLoopGroup());
        sb.channel(NioServerSocketChannel.class);
        sb.childHandler(new ChannelInitializer() {
            @Override
            protected void initChannel(Channel ch) throws Exception {
                serverConnectedChannel = ch;
                ChannelPipeline p = ch.pipeline();
                Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
                serverConnection.remote().flowController(
                        new DefaultHttp2RemoteFlowController(serverConnection));
                serverConnection.local().flowController(
                        new DefaultHttp2LocalFlowController(serverConnection).frameWriter(frameWriter));
                // Server encoder compresses outbound frames; listener decompresses inbound ones.
                Http2ConnectionEncoder encoder = new CompressorHttp2ConnectionEncoder(
                        new DefaultHttp2ConnectionEncoder(serverConnection, frameWriter));
                Http2ConnectionDecoder decoder =
                        new DefaultHttp2ConnectionDecoder(serverConnection, encoder, new DefaultHttp2FrameReader());
                Http2ConnectionHandler connectionHandler = new Http2ConnectionHandlerBuilder()
                        .frameListener(new DelegatingDecompressorFrameListener(serverConnection, serverListener))
                        .codec(decoder, encoder).build();
                p.addLast(connectionHandler);
                serverChannelLatch.countDown();
            }
        });

        cb.group(new NioEventLoopGroup());
        cb.channel(NioSocketChannel.class);
        cb.handler(new ChannelInitializer() {
            @Override
            protected void initChannel(Channel ch) throws Exception {
                ChannelPipeline p = ch.pipeline();
                Http2FrameWriter frameWriter = new DefaultHttp2FrameWriter();
                clientConnection.remote().flowController(
                        new DefaultHttp2RemoteFlowController(clientConnection));
                clientConnection.local().flowController(
                        new DefaultHttp2LocalFlowController(clientConnection).frameWriter(frameWriter));
                clientEncoder = new CompressorHttp2ConnectionEncoder(
                        new DefaultHttp2ConnectionEncoder(clientConnection, frameWriter));

                Http2ConnectionDecoder decoder =
                        new DefaultHttp2ConnectionDecoder(clientConnection, clientEncoder,
                                new DefaultHttp2FrameReader());
                clientHandler = new Http2ConnectionHandlerBuilder()
                        .frameListener(new DelegatingDecompressorFrameListener(clientConnection, clientListener))
                        // By default tests don't wait for server to gracefully shutdown streams
                        .gracefulShutdownTimeoutMillis(0)
                        .codec(decoder, clientEncoder).build();
                p.addLast(clientHandler);
                // One-shot handler: release the latch once the HTTP/2 preface has been written.
                p.addLast(new ChannelInboundHandlerAdapter() {
                    @Override
                    public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                        if (evt == Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE) {
                            prefaceWrittenLatch.countDown();
                            ctx.pipeline().remove(this);
                        }
                    }
                });
            }
        });

        // Bind to an ephemeral port, then connect the client to it.
        serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel();
        int port = ((InetSocketAddress) serverChannel.localAddress()).getPort();

        ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port));
        assertTrue(ccf.awaitUninterruptibly().isSuccess());
        clientChannel = ccf.channel();
        assertTrue(prefaceWrittenLatch.await(5, SECONDS));
        assertTrue(serverChannelLatch.await(5, SECONDS));
    }

    /** Waits for the server to close the stream, then flushes the capture buffer. */
    private void awaitServer() throws Exception {
        assertTrue(serverLatch.await(5, SECONDS));
        serverOut.flush();
    }

    /** Context of the first handler in the client pipeline (the HTTP/2 connection handler). */
    private ChannelHandlerContext ctxClient() {
        return clientChannel.pipeline().firstContext();
    }

    private ChannelPromise newPromiseClient() {
        return ctxClient().newPromise();
    }
}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoderTest.java
new file mode 100644
index 0000000..fd72b36
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DecoratingHttp2ConnectionEncoderTest.java
@@ -0,0 +1,54 @@
/*
 * Copyright 2019 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License.
You may obtain a
 * copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.codec.http2;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.times;

import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Tests for {@link DecoratingHttp2ConnectionEncoder#consumeReceivedSettings(Http2Settings)}:
 * it must throw when the delegate cannot consume settings, and delegate when it can.
 */
public class DecoratingHttp2ConnectionEncoderTest {

    /**
     * A delegate that does not implement {@link Http2SettingsReceivedConsumer}
     * cannot consume settings, so the decorator must throw.
     */
    @Test
    public void testConsumeReceivedSettingsThrows() {
        Http2ConnectionEncoder encoder = mock(Http2ConnectionEncoder.class);
        final DecoratingHttp2ConnectionEncoder decoratingHttp2ConnectionEncoder =
                new DecoratingHttp2ConnectionEncoder(encoder);
        assertThrows(IllegalStateException.class, new Executable() {
            @Override
            public void execute() {
                decoratingHttp2ConnectionEncoder.consumeReceivedSettings(Http2Settings.defaultSettings());
            }
        });
    }

    /**
     * When the delegate also implements {@link Http2SettingsReceivedConsumer},
     * the decorator must forward the exact settings instance it was given.
     */
    @Test
    public void testConsumeReceivedSettingsDelegate() {
        TestHttp2ConnectionEncoder encoder = mock(TestHttp2ConnectionEncoder.class);
        DecoratingHttp2ConnectionEncoder decoratingHttp2ConnectionEncoder =
                new DecoratingHttp2ConnectionEncoder(encoder);

        Http2Settings settings = Http2Settings.defaultSettings();
        // Fix: pass the SAME instance we verify against. The original passed a second,
        // freshly built defaultSettings() and only matched via Http2Settings.equals(),
        // which does not prove the decorator delegated the object it received.
        decoratingHttp2ConnectionEncoder.consumeReceivedSettings(settings);
        verify(encoder, times(1)).consumeReceivedSettings(eq(settings));
    }

    // Combines both roles so a single mock can stand in for a settings-consuming encoder.
    private interface TestHttp2ConnectionEncoder extends Http2ConnectionEncoder, Http2SettingsReceivedConsumer { }
}
diff --git
a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java new file mode 100644 index 0000000..de60402 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionDecoderTest.java @@ -0,0 +1,1055 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpResponseStatus; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; +import static io.netty.handler.codec.http2.Http2Stream.State.OPEN; +import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_REMOTE; +import static io.netty.util.CharsetUtil.UTF_8; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsString; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.not; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static 
org.junit.jupiter.api.Assertions.fail;
import static org.mockito.Mockito.any;
import static org.mockito.Mockito.anyBoolean;
import static org.mockito.Mockito.anyInt;
import static org.mockito.Mockito.anyLong;
import static org.mockito.Mockito.anyShort;
import static org.mockito.Mockito.doAnswer;
import static org.mockito.Mockito.doNothing;
import static org.mockito.Mockito.eq;
import static org.mockito.Mockito.isNull;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link DefaultHttp2ConnectionDecoder}.
 */
public class DefaultHttp2ConnectionDecoderTest {
    private static final int STREAM_ID = 3;
    private static final int PUSH_STREAM_ID = 2;
    // NOTE(review): not referenced in the tests visible here — presumably used further down.
    private static final int STREAM_DEPENDENCY_ID = 5;
    // Bit flags tracking how far header reception has progressed on the mock stream.
    private static final int STATE_RECV_HEADERS = 1;
    private static final int STATE_RECV_TRAILERS = 1 << 1;

    private Http2ConnectionDecoder decoder;
    private ChannelPromise promise;

    @Mock
    private Http2Connection connection;

    @Mock
    private Http2Connection.Endpoint remote;

    @Mock
    private Http2Connection.Endpoint local;

    @Mock
    private Http2LocalFlowController localFlow;

    @Mock
    private Http2RemoteFlowController remoteFlow;

    @Mock
    private ChannelHandlerContext ctx;

    @Mock
    private Channel channel;

    @Mock
    private ChannelFuture future;

    @Mock
    private Http2Stream stream;

    @Mock
    private Http2Stream pushStream;

    @Mock
    private Http2FrameListener listener;

    @Mock
    private Http2FrameReader reader;

    @Mock
    private Http2FrameWriter writer;

    @Mock
    private Http2ConnectionEncoder encoder;

    @Mock
    private Http2LifecycleManager lifecycleManager;

    /**
     * Wires the mock connection/stream/flow-controller graph, builds the decoder under
     * test, and replays the initial SETTINGS exchange so {@code prefaceReceived()} holds.
     */
    @BeforeEach
    public void setup() throws Exception {
        MockitoAnnotations.initMocks(this);

        promise = new DefaultChannelPromise(channel);

        final AtomicInteger headersReceivedState = new AtomicInteger();
        when(channel.isActive()).thenReturn(true);
        when(stream.id()).thenReturn(STREAM_ID);
        when(stream.state()).thenReturn(OPEN);
        when(stream.open(anyBoolean())).thenReturn(stream);

        // Back the mock's property storage with a real identity map.
        final Map properties = new IdentityHashMap();
        when(stream.getProperty(ArgumentMatchers.any())).thenAnswer(new Answer() {
            @Override
            public Object answer(InvocationOnMock invocationOnMock) {
                return properties.get(invocationOnMock.getArgument(0));
            }
        });
        when(stream.setProperty(ArgumentMatchers.any(), any())).then(new Answer() {
            @Override
            public Object answer(InvocationOnMock invocationOnMock) {
                return properties.put(invocationOnMock.getArgument(0), invocationOnMock.getArgument(1));
            }
        });

        when(pushStream.id()).thenReturn(PUSH_STREAM_ID);
        doAnswer(new Answer() {
            @Override
            public Boolean answer(InvocationOnMock in) throws Throwable {
                return (headersReceivedState.get() & STATE_RECV_HEADERS) != 0;
            }
        }).when(stream).isHeadersReceived();
        doAnswer(new Answer() {
            @Override
            public Boolean answer(InvocationOnMock in) throws Throwable {
                return (headersReceivedState.get() & STATE_RECV_TRAILERS) != 0;
            }
        }).when(stream).isTrailersReceived();
        // Emulate real state transitions: informational headers don't advance state;
        // the first real HEADERS sets HEADERS, the second sets TRAILERS, a third throws.
        doAnswer(new Answer() {
            @Override
            public Http2Stream answer(InvocationOnMock in) throws Throwable {
                boolean isInformational = in.getArgument(0);
                if (isInformational) {
                    return stream;
                }
                for (;;) {
                    int current = headersReceivedState.get();
                    int next = current;
                    if ((current & STATE_RECV_HEADERS) != 0) {
                        if ((current & STATE_RECV_TRAILERS) != 0) {
                            throw new IllegalStateException("already sent headers!");
                        }
                        next |= STATE_RECV_TRAILERS;
                    } else {
                        next |= STATE_RECV_HEADERS;
                    }
                    if (headersReceivedState.compareAndSet(current, next)) {
                        break;
                    }
                }
                return stream;
            }
        }).when(stream).headersReceived(anyBoolean());
        // Visit the single mock stream; honor the visitor's stop signal.
        doAnswer(new Answer() {
            @Override
            public Http2Stream answer(InvocationOnMock in) throws Throwable {
                Http2StreamVisitor visitor = in.getArgument(0);
                if (!visitor.visit(stream)) {
                    return stream;
                }
                return null;
            }
        }).when(connection).forEachActiveStream(any(Http2StreamVisitor.class));
        when(connection.stream(STREAM_ID)).thenReturn(stream);
        when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(true);
        when(connection.local()).thenReturn(local);
        when(local.flowController()).thenReturn(localFlow);
        when(encoder.flowController()).thenReturn(remoteFlow);
        when(encoder.frameWriter()).thenReturn(writer);
        when(connection.remote()).thenReturn(remote);
        when(local.reservePushStream(eq(PUSH_STREAM_ID), eq(stream))).thenReturn(pushStream);
        when(remote.reservePushStream(eq(PUSH_STREAM_ID), eq(stream))).thenReturn(pushStream);
        when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT);
        when(ctx.channel()).thenReturn(channel);
        when(ctx.newSucceededFuture()).thenReturn(future);
        when(ctx.newPromise()).thenReturn(promise);
        when(ctx.write(any())).thenReturn(future);

        decoder = new DefaultHttp2ConnectionDecoder(connection, encoder, reader);
        decoder.lifecycleManager(lifecycleManager);
        decoder.frameListener(listener);

        // Simulate receiving the initial settings from the remote endpoint.
        decode().onSettingsRead(ctx, new Http2Settings());
        verify(listener).onSettingsRead(eq(ctx), eq(new Http2Settings()));
        assertTrue(decoder.prefaceReceived());
        verify(encoder).writeSettingsAck(eq(ctx), eq(promise));

        // Simulate receiving the SETTINGS ACK for the initial settings.
        decode().onSettingsAckRead(ctx);

        // Disallow any further flushes now that settings ACK has been sent
        when(ctx.flush()).then(new Answer() {
            @Override
            public ChannelHandlerContext answer(InvocationOnMock invocationOnMock) {
                fail();
                return null;
            }
        });
    }

    /**
     * After a GOAWAY was sent, DATA on a remote-created stream is flow-controlled
     * but must not reach the listener.
     */
    @Test
    public void dataReadAfterGoAwaySentShouldApplyFlowControl() throws Exception {
        mockGoAwaySent();

        final ByteBuf data = dummyData();
        int padding = 10;
        int processedBytes = data.readableBytes() + padding;
        mockFlowControl(processedBytes);
        try {
            decode().onDataRead(ctx, STREAM_ID, data, padding, true);
            verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(padding), eq(true));
            verify(localFlow).consumeBytes(eq(stream), eq(processedBytes));

            // Verify that the event was absorbed and not propagated to the observer.
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /**
     * After GOAWAY, frames for a stream created by the LOCAL endpoint are still delivered.
     */
    @Test
    public void dataReadAfterGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception {
        mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint();

        final ByteBuf data = dummyData();
        int padding = 10;
        int processedBytes = data.readableBytes() + padding;
        mockFlowControl(processedBytes);
        try {
            decode().onDataRead(ctx, STREAM_ID, data, padding, true);
            verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(padding), eq(true));
            verify(localFlow).consumeBytes(eq(stream), eq(processedBytes));

            // Verify that the frame WAS delivered to the listener (locally-created stream).
            verify(listener).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /**
     * DATA for a stream that may have existed but is now unknown: flow control is still
     * applied and bytes consumed, but the decoder fails with a stream-level error.
     */
    @Test
    public void dataReadForUnknownStreamShouldApplyFlowControlAndFail() throws Exception {
        when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(true);
        when(connection.stream(STREAM_ID)).thenReturn(null);
        final ByteBuf data = dummyData();
        final int padding = 10;
        int processedBytes = data.readableBytes() + padding;
        assertThrows(Http2Exception.StreamException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode().onDataRead(ctx, STREAM_ID, data, padding, true);
            }
        });
        try {
            verify(localFlow)
                    .receiveFlowControlledFrame(eq((Http2Stream) null), eq(data), eq(padding), eq(true));
            verify(localFlow).consumeBytes(eq((Http2Stream) null), eq(processedBytes));
            verify(localFlow).frameWriter(any(Http2FrameWriter.class));
            verifyNoMoreInteractions(localFlow);
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /**
     * DATA for a stream that could never have existed is a connection error,
     * NOT a stream-level error.
     */
    @Test
    public void dataReadForUnknownStreamThatCouldntExistFail() throws Exception {
        when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false);
        when(connection.stream(STREAM_ID)).thenReturn(null);
        final ByteBuf data = dummyData();
        final int padding = 10;
        int processedBytes = data.readableBytes() + padding;
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                try {
                    decode().onDataRead(ctx, STREAM_ID, data, padding, true);
                } catch (Http2Exception ex) {
                    // Must be a connection error, not a StreamException.
                    assertThat(ex, not(instanceOf(Http2Exception.StreamException.class)));
                    throw ex;
                }
            }
        });
        try {
            verify(localFlow)
                    .receiveFlowControlledFrame(eq((Http2Stream) null), eq(data), eq(padding), eq(true));
            verify(localFlow).consumeBytes(eq((Http2Stream) null), eq(processedBytes));
            verify(localFlow).frameWriter(any(Http2FrameWriter.class));
            verifyNoMoreInteractions(localFlow);
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /**
     * DATA for an unknown stream still goes through flow control before failing.
     */
    @Test
    public void dataReadForUnknownStreamShouldApplyFlowControl() throws Exception {
        when(connection.stream(STREAM_ID)).thenReturn(null);
        final ByteBuf data = dummyData();
        final int padding = 10;
        int processedBytes = data.readableBytes() + padding;
        try {
            assertThrows(Http2Exception.class, new Executable() {
                @Override
                public void execute() throws Throwable {
                    decode().onDataRead(ctx, STREAM_ID, data, padding, true);
                }
            });
            verify(localFlow)
                    .receiveFlowControlledFrame(eq((Http2Stream) null), eq(data), eq(padding), eq(true));
            verify(localFlow).consumeBytes(eq((Http2Stream) null), eq(processedBytes));
            verify(localFlow).frameWriter(any(Http2FrameWriter.class));
            verifyNoMoreInteractions(localFlow);

            // Verify that the event was absorbed and not propagated to the observer.
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /**
     * An empty DATA frame is flow-controlled, consumes zero bytes, and is still
     * propagated to the listener.
     */
    @Test
    public void emptyDataFrameShouldApplyFlowControl() throws Exception {
        final ByteBuf data = EMPTY_BUFFER;
        int padding = 0;
        mockFlowControl(0);
        try {
            decode().onDataRead(ctx, STREAM_ID, data, padding, true);
            verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(padding), eq(true));

            // Now we ignore the empty bytes inside consumeBytes method, so it will be called once.
            verify(localFlow).consumeBytes(eq(stream), eq(0));

            // Verify that the empty data event was propagated to the observer.
            verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(padding), eq(true));
        } finally {
            data.release();
        }
    }

    /** DATA on a CLOSED stream (no GOAWAY/RST context) must throw. */
    @Test
    public void dataReadForStreamInInvalidStateShouldThrow() throws Exception {
        // Throw an exception when checking stream state.
        when(stream.state()).thenReturn(Http2Stream.State.CLOSED);
        final ByteBuf data = dummyData();
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode().onDataRead(ctx, STREAM_ID, data, 10, true);
            }
        });
        data.release();
    }

    /** After GOAWAY, DATA on a CLOSED stream is silently ignored (but flow-controlled). */
    @Test
    public void dataReadAfterGoAwaySentForStreamInInvalidStateShouldIgnore() throws Exception {
        // Throw an exception when checking stream state.
        when(stream.state()).thenReturn(Http2Stream.State.CLOSED);
        mockGoAwaySent();
        final ByteBuf data = dummyData();
        try {
            decode().onDataRead(ctx, STREAM_ID, data, 10, true);
            verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(10), eq(true));
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /** After GOAWAY, DATA on an unknown stream is silently ignored (but flow-controlled). */
    @Test
    public void dataReadAfterGoAwaySentOnUnknownStreamShouldIgnore() throws Exception {
        // Throw an exception when checking stream state.
        when(connection.stream(STREAM_ID)).thenReturn(null);
        mockGoAwaySent();
        final ByteBuf data = dummyData();
        try {
            decode().onDataRead(ctx, STREAM_ID, data, 10, true);
            verify(localFlow).receiveFlowControlledFrame((Http2Stream) isNull(), eq(data), eq(10), eq(true));
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /** After we sent RST_STREAM, DATA on the closed stream is ignored (but flow-controlled). */
    @Test
    public void dataReadAfterRstStreamForStreamInInvalidStateShouldIgnore() throws Exception {
        // Throw an exception when checking stream state.
        when(stream.state()).thenReturn(Http2Stream.State.CLOSED);
        when(stream.isResetSent()).thenReturn(true);
        final ByteBuf data = dummyData();
        try {
            decode().onDataRead(ctx, STREAM_ID, data, 10, true);
            verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(10), eq(true));
            verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean());
        } finally {
            data.release();
        }
    }

    /** DATA with end-of-stream closes the remote side of the stream and reaches the listener. */
    @Test
    public void dataReadWithEndOfStreamShouldcloseStreamRemote() throws Exception {
        final ByteBuf data = dummyData();
        try {
            decode().onDataRead(ctx, STREAM_ID, data, 10, true);
            verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(10), eq(true));
            verify(lifecycleManager).closeStreamRemote(eq(stream), eq(future));
            verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(10), eq(true));
        } finally {
            data.release();
        }
    }

    /**
     * If the listener throws mid-delivery, the decoder must still account for every
     * flow-controlled byte (unconsumed count ends at zero).
     */
    @Test
    public void errorDuringDeliveryShouldReturnCorrectNumberOfBytes() throws Exception {
        final ByteBuf data = dummyData();
        final int padding = 10;
        final AtomicInteger unprocessed = new AtomicInteger(data.readableBytes() + padding);
        doAnswer(new Answer() {
            @Override
            public Integer answer(InvocationOnMock in) throws Throwable {
                return unprocessed.get();
            }
        }).when(localFlow).unconsumedBytes(eq(stream));
        doAnswer(new Answer() {
            @Override
            public Void answer(InvocationOnMock in) throws Throwable {
                int delta = (Integer) in.getArguments()[1];
                int newValue = unprocessed.addAndGet(-delta);
                if (newValue < 0) {
                    throw new RuntimeException("Returned too many bytes");
                }
                return null;
            }
        }).when(localFlow).consumeBytes(eq(stream), anyInt());
        // When the listener callback is called, process a few bytes and then throw.
        doAnswer(new Answer() {
            @Override
            public Integer answer(InvocationOnMock in) throws Throwable {
                localFlow.consumeBytes(stream, 4);
                throw new RuntimeException("Fake Exception");
            }
        }).when(listener).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class), eq(10), eq(true));
        try {
            assertThrows(RuntimeException.class, new Executable() {
                @Override
                public void execute() throws Throwable {
                    decode().onDataRead(ctx, STREAM_ID, data, padding, true);
                }
            });
            verify(localFlow)
                    .receiveFlowControlledFrame(eq(stream), eq(data), eq(padding), eq(true));
            verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(padding), eq(true));
            // All bytes must have been accounted for despite the listener's failure.
            assertEquals(0, localFlow.unconsumedBytes(stream));
        } finally {
            data.release();
        }
    }

    /** HEADERS for a stream that may have existed but is unknown must throw. */
    @Test
    public void headersReadForUnknownStreamShouldThrow() throws Exception {
        when(connection.stream(STREAM_ID)).thenReturn(null);
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
            }
        });
    }

    /** HEADERS arriving after we sent RST_STREAM are dropped without side effects. */
    @Test
    public void headersReadForStreamThatAlreadySentResetShouldBeIgnored() throws Exception {
        when(stream.isResetSent()).thenReturn(true);
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
        verify(remote, never()).createStream(anyInt(), anyBoolean());
        verify(stream, never()).open(anyBoolean());

        // Verify that the event was absorbed and not propagated to the observer.
        verify(listener, never()).onHeadersRead(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyBoolean());
        verify(remote, never()).createStream(anyInt(), anyBoolean());
        verify(stream, never()).open(anyBoolean());
    }

    /** HEADERS for an unknown stream after GOAWAY are dropped without creating a stream. */
    @Test
    public void headersReadForUnknownStreamAfterGoAwayShouldBeIgnored() throws Exception {
        mockGoAwaySent();
        when(connection.stream(STREAM_ID)).thenReturn(null);
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
        verify(remote, never()).createStream(anyInt(), anyBoolean());
        verify(stream, never()).open(anyBoolean());

        // Verify that the event was absorbed and not propagated to the observer.
        verify(listener, never()).onHeadersRead(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyBoolean());
        verify(remote, never()).createStream(anyInt(), anyBoolean());
        verify(stream, never()).open(anyBoolean());
    }

    /** HEADERS for a brand-new stream id create the stream and notify the listener. */
    @Test
    public void headersReadForUnknownStreamShouldCreateStream() throws Exception {
        final int streamId = 5;
        when(remote.createStream(eq(streamId), anyBoolean())).thenReturn(stream);
        decode().onHeadersRead(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false);
        verify(remote).createStream(eq(streamId), eq(false));
        verify(listener).onHeadersRead(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false));
    }

    /** HEADERS with end-of-stream create the stream half-closed. */
    @Test
    public void headersReadForUnknownStreamShouldCreateHalfClosedStream() throws Exception {
        final int streamId = 5;
        when(remote.createStream(eq(streamId), anyBoolean())).thenReturn(stream);
        decode().onHeadersRead(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true);
        verify(remote).createStream(eq(streamId), eq(true));
        verify(listener).onHeadersRead(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(true));
    }

    /** HEADERS on a RESERVED_REMOTE (push-promised) stream half-open it. */
    @Test
    public void headersReadForPromisedStreamShouldHalfOpenStream() throws Exception {
        when(stream.state()).thenReturn(RESERVED_REMOTE);
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
        verify(stream).open(false);
        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), eq(false));
    }

    /** A second HEADERS block without end-of-stream is an invalid trailer and must throw. */
    @Test
    public void trailersDoNotEndStreamThrows() throws Exception {
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
        // Trailers must end the stream!
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
            }
        });
    }

    /** Trailers containing any pseudo-header are a PROTOCOL_ERROR naming the offender. */
    @ParameterizedTest
    @ValueSource(strings = {":scheme", ":custom-pseudo-header"})
    public void trailersWithPseudoHeadersThrows(String pseudoHeader) throws Exception {
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);

        final Http2Headers trailers = new DefaultHttp2Headers(false);
        trailers.add(pseudoHeader, "something");
        Http2Exception ex = assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode().onHeadersRead(ctx, STREAM_ID, trailers, 0, true);
            }
        });
        assertEquals(PROTOCOL_ERROR, ex.error());
        assertThat(ex.getMessage(), containsString(pseudoHeader));
    }

    @Test
    public void tooManyHeadersEOSThrows() throws Exception {
        tooManyHeaderThrows(true);
    }

    @Test
    public void tooManyHeadersNoEOSThrows() throws Exception {
        tooManyHeaderThrows(false);
    }

    // A third HEADERS block after headers + trailers must be rejected regardless of EOS.
    private void tooManyHeaderThrows(final boolean eos) throws Exception {
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true);
        // We already received the trailers!
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, eos);
            }
        });
    }

    // Builds a 100 Continue informational header block.
    private static Http2Headers informationalHeaders() {
        Http2Headers headers = new DefaultHttp2Headers();
        headers.status(HttpResponseStatus.CONTINUE.codeAsText());
        return headers;
    }

    @Test
    public void infoHeadersAndTrailersAllowed() throws Exception {
        infoHeadersAndTrailersAllowed(true, 1);
    }

    @Test
    public void multipleInfoHeadersAndTrailersAllowed() throws Exception {
        infoHeadersAndTrailersAllowed(true, 10);
    }

    @Test
    public void infoHeadersAndTrailersNoEOSThrows() throws Exception {
        infoHeadersAndTrailersAllowed(false, 1);
    }

    @Test
    public void multipleInfoHeadersAndTrailersNoEOSThrows() throws Exception {
        infoHeadersAndTrailersAllowed(false, 10);
    }

    // Any number of informational header blocks may precede headers + trailers;
    // trailers without end-of-stream must still throw.
    private void infoHeadersAndTrailersAllowed(final boolean eos, int infoHeaderCount)
            throws Exception {
        for (int i = 0; i < infoHeaderCount; ++i) {
            decode().onHeadersRead(ctx, STREAM_ID, informationalHeaders(), 0, false);
        }
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false);
        if (eos) {
            decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, eos);
        } else {
            assertThrows(Http2Exception.class, new Executable() {
                @Override
                public void execute() throws Throwable {
                    decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, eos);
                }
            });
        }
    }

    /** HEADERS with EOS on a push-promised stream open it and immediately close the remote side. */
    @Test
    public void headersReadForPromisedStreamShouldCloseStream() throws Exception {
        when(stream.state()).thenReturn(RESERVED_REMOTE);
        decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true);
        verify(stream).open(true);
        verify(lifecycleManager).closeStreamRemote(eq(stream), eq(future));
        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0),
                eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0),
eq(true)); + } + + @Test + public void headersDependencyNotCreatedShouldCreateAndSucceed() throws Exception { + final short weight = 1; + decode().onHeadersRead(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, STREAM_DEPENDENCY_ID, + weight, true, 0, true); + verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(STREAM_DEPENDENCY_ID), + eq(weight), eq(true), eq(0), eq(true)); + verify(remoteFlow).updateDependencyTree(eq(STREAM_ID), eq(STREAM_DEPENDENCY_ID), eq(weight), eq(true)); + verify(lifecycleManager).closeStreamRemote(eq(stream), any(ChannelFuture.class)); + } + + @Test + public void pushPromiseReadAfterGoAwaySentShouldBeIgnored() throws Exception { + mockGoAwaySent(); + decode().onPushPromiseRead(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0); + verify(remote, never()).reservePushStream(anyInt(), any(Http2Stream.class)); + verify(listener, never()).onPushPromiseRead(eq(ctx), anyInt(), anyInt(), any(Http2Headers.class), anyInt()); + } + + @Test + public void pushPromiseReadAfterGoAwayShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception { + mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint(); + decode().onPushPromiseRead(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0); + verify(remote).reservePushStream(anyInt(), any(Http2Stream.class)); + verify(listener).onPushPromiseRead(eq(ctx), anyInt(), anyInt(), any(Http2Headers.class), anyInt()); + } + + @Test + public void pushPromiseReadForUnknownStreamShouldThrow() throws Exception { + when(connection.stream(STREAM_ID)).thenReturn(null); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onPushPromiseRead(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0); + } + }); + } + + @Test + public void pushPromiseReadShouldSucceed() throws Exception { + decode().onPushPromiseRead(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0); + 
verify(remote).reservePushStream(eq(PUSH_STREAM_ID), eq(stream)); + verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(PUSH_STREAM_ID), + eq(EmptyHttp2Headers.INSTANCE), eq(0)); + } + + @Test + public void priorityReadAfterGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception { + mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint(); + decode().onPriorityRead(ctx, STREAM_ID, 0, (short) 255, true); + verify(remoteFlow).updateDependencyTree(eq(STREAM_ID), eq(0), eq((short) 255), eq(true)); + verify(listener).onPriorityRead(eq(ctx), anyInt(), anyInt(), anyShort(), anyBoolean()); + } + + @Test + public void priorityReadForUnknownStreamShouldNotBeIgnored() throws Exception { + when(connection.stream(STREAM_ID)).thenReturn(null); + decode().onPriorityRead(ctx, STREAM_ID, 0, (short) 255, true); + verify(remoteFlow).updateDependencyTree(eq(STREAM_ID), eq(0), eq((short) 255), eq(true)); + verify(listener).onPriorityRead(eq(ctx), eq(STREAM_ID), eq(0), eq((short) 255), eq(true)); + } + + @Test + public void priorityReadShouldNotCreateNewStream() throws Exception { + when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false); + when(connection.stream(STREAM_ID)).thenReturn(null); + decode().onPriorityRead(ctx, STREAM_ID, STREAM_DEPENDENCY_ID, (short) 255, true); + verify(remoteFlow).updateDependencyTree(eq(STREAM_ID), eq(STREAM_DEPENDENCY_ID), eq((short) 255), eq(true)); + verify(listener).onPriorityRead(eq(ctx), eq(STREAM_ID), eq(STREAM_DEPENDENCY_ID), eq((short) 255), eq(true)); + verify(remote, never()).createStream(eq(STREAM_ID), anyBoolean()); + verify(stream, never()).open(anyBoolean()); + } + + @Test + public void windowUpdateReadAfterGoAwaySentShouldBeIgnored() throws Exception { + mockGoAwaySent(); + decode().onWindowUpdateRead(ctx, STREAM_ID, 10); + verify(remoteFlow, never()).incrementWindowSize(any(Http2Stream.class), anyInt()); + verify(listener, never()).onWindowUpdateRead(eq(ctx), anyInt(), anyInt()); + } + + 
@Test + public void windowUpdateReadAfterGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() throws Exception { + mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint(); + decode().onWindowUpdateRead(ctx, STREAM_ID, 10); + verify(remoteFlow).incrementWindowSize(any(Http2Stream.class), anyInt()); + verify(listener).onWindowUpdateRead(eq(ctx), anyInt(), anyInt()); + } + + @Test + public void windowUpdateReadForUnknownStreamShouldThrow() throws Exception { + when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false); + when(connection.stream(STREAM_ID)).thenReturn(null); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onWindowUpdateRead(ctx, STREAM_ID, 10); + } + }); + } + + @Test + public void windowUpdateReadForUnknownStreamShouldBeIgnored() throws Exception { + when(connection.stream(STREAM_ID)).thenReturn(null); + decode().onWindowUpdateRead(ctx, STREAM_ID, 10); + verify(remoteFlow, never()).incrementWindowSize(any(Http2Stream.class), anyInt()); + verify(listener, never()).onWindowUpdateRead(eq(ctx), anyInt(), anyInt()); + } + + @Test + public void windowUpdateReadShouldSucceed() throws Exception { + decode().onWindowUpdateRead(ctx, STREAM_ID, 10); + verify(remoteFlow).incrementWindowSize(eq(stream), eq(10)); + verify(listener).onWindowUpdateRead(eq(ctx), eq(STREAM_ID), eq(10)); + } + + @Test + public void rstStreamReadAfterGoAwayShouldSucceed() throws Exception { + when(connection.goAwaySent()).thenReturn(true); + decode().onRstStreamRead(ctx, STREAM_ID, PROTOCOL_ERROR.code()); + verify(lifecycleManager).closeStream(eq(stream), eq(future)); + verify(listener).onRstStreamRead(eq(ctx), anyInt(), anyLong()); + } + + @Test + public void rstStreamReadForUnknownStreamShouldThrow() throws Exception { + when(connection.streamMayHaveExisted(STREAM_ID)).thenReturn(false); + when(connection.stream(STREAM_ID)).thenReturn(null); + assertThrows(Http2Exception.class, new 
Executable() { + @Override + public void execute() throws Throwable { + decode().onRstStreamRead(ctx, STREAM_ID, PROTOCOL_ERROR.code()); + } + }); + } + + @Test + public void rstStreamReadForUnknownStreamShouldBeIgnored() throws Exception { + when(connection.stream(STREAM_ID)).thenReturn(null); + decode().onRstStreamRead(ctx, STREAM_ID, PROTOCOL_ERROR.code()); + verify(lifecycleManager, never()).closeStream(eq(stream), eq(future)); + verify(listener, never()).onRstStreamRead(eq(ctx), anyInt(), anyLong()); + } + + @Test + public void rstStreamReadShouldCloseStream() throws Exception { + decode().onRstStreamRead(ctx, STREAM_ID, PROTOCOL_ERROR.code()); + verify(lifecycleManager).closeStream(eq(stream), eq(future)); + verify(listener).onRstStreamRead(eq(ctx), eq(STREAM_ID), eq(PROTOCOL_ERROR.code())); + } + + @Test + public void rstStreamOnIdleStreamShouldThrow() throws Exception { + when(stream.state()).thenReturn(IDLE); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onRstStreamRead(ctx, STREAM_ID, PROTOCOL_ERROR.code()); + } + }); + verify(listener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + + @Test + public void pingReadWithAckShouldNotifyListener() throws Exception { + decode().onPingAckRead(ctx, 0L); + verify(listener).onPingAckRead(eq(ctx), eq(0L)); + } + + @Test + public void pingReadShouldReplyWithAck() throws Exception { + decode().onPingRead(ctx, 0L); + verify(encoder).writePing(eq(ctx), eq(true), eq(0L), eq(promise)); + verify(listener, never()).onPingAckRead(eq(ctx), any(long.class)); + } + + @Test + public void settingsReadWithAckShouldNotifyListener() throws Exception { + decode().onSettingsAckRead(ctx); + // Take into account the time this was called during setup(). 
+ verify(listener, times(2)).onSettingsAckRead(eq(ctx)); + } + + @Test + public void settingsReadShouldSetValues() throws Exception { + Http2Settings settings = new Http2Settings(); + settings.pushEnabled(true); + settings.initialWindowSize(123); + settings.maxConcurrentStreams(456); + settings.headerTableSize(789); + decode().onSettingsRead(ctx, settings); + verify(encoder).remoteSettings(settings); + verify(listener).onSettingsRead(eq(ctx), eq(settings)); + } + + @Test + public void goAwayShouldReadShouldUpdateConnectionState() throws Exception { + decode().onGoAwayRead(ctx, 1, 2L, EMPTY_BUFFER); + verify(connection).goAwayReceived(eq(1), eq(2L), eq(EMPTY_BUFFER)); + verify(listener).onGoAwayRead(eq(ctx), eq(1), eq(2L), eq(EMPTY_BUFFER)); + } + + @Test + public void dataContentLengthMissmatch() throws Exception { + dataContentLengthInvalid(false); + } + + @Test + public void dataContentLengthInvalid() throws Exception { + dataContentLengthInvalid(true); + } + + private void dataContentLengthInvalid(boolean negative) throws Exception { + final ByteBuf data = dummyData(); + final int padding = 10; + int processedBytes = data.readableBytes() + padding; + mockFlowControl(processedBytes); + try { + if (negative) { + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onHeadersRead(ctx, STREAM_ID, new DefaultHttp2Headers() + .setLong(HttpHeaderNames.CONTENT_LENGTH, -1L), padding, false); + } + }); + } else { + decode().onHeadersRead(ctx, STREAM_ID, new DefaultHttp2Headers() + .setLong(HttpHeaderNames.CONTENT_LENGTH, 1L), padding, false); + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onDataRead(ctx, STREAM_ID, data, padding, true); + } + }); + verify(localFlow).receiveFlowControlledFrame(eq(stream), eq(data), eq(padding), eq(true)); + verify(localFlow).consumeBytes(eq(stream), eq(processedBytes)); 
+ + verify(listener, times(1)).onHeadersRead(eq(ctx), anyInt(), + any(Http2Headers.class), eq(0), eq(DEFAULT_PRIORITY_WEIGHT), eq(false), + eq(padding), eq(false)); + } + // Verify that the event was absorbed and not propagated to the observer. + verify(listener, never()).onDataRead(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean()); + } finally { + data.release(); + } + } + + @Test + public void headersContentLengthPositiveSign() throws Exception { + headersContentLengthSign("+1"); + } + + @Test + public void headersContentLengthNegativeSign() throws Exception { + headersContentLengthSign("-1"); + } + + private void headersContentLengthSign(final String length) throws Exception { + final int padding = 10; + when(connection.isServer()).thenReturn(true); + + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onHeadersRead(ctx, STREAM_ID, new DefaultHttp2Headers() + .set(HttpHeaderNames.CONTENT_LENGTH, length), padding, false); + } + }); + + // Verify that the event was absorbed and not propagated to the observer. + verify(listener, never()).onHeadersRead(eq(ctx), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + } + + @Test + public void headersContentLengthMissmatch() throws Exception { + headersContentLength(false); + } + + @Test + public void headersContentLengthInvalid() throws Exception { + headersContentLength(true); + } + + private void headersContentLength(final boolean negative) throws Exception { + final int padding = 10; + when(connection.isServer()).thenReturn(true); + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onHeadersRead(ctx, STREAM_ID, new DefaultHttp2Headers() + .setLong(HttpHeaderNames.CONTENT_LENGTH, negative ? 
-1L : 1L), padding, true); + } + }); + + // Verify that the event was absorbed and not propagated to the observer. + verify(listener, never()).onHeadersRead(eq(ctx), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + } + + @Test + public void multipleHeadersContentLengthSame() throws Exception { + multipleHeadersContentLength(true); + } + + @Test + public void multipleHeadersContentLengthDifferent() throws Exception { + multipleHeadersContentLength(false); + } + + private void multipleHeadersContentLength(boolean same) throws Exception { + final int padding = 10; + when(connection.isServer()).thenReturn(true); + final Http2Headers headers = new DefaultHttp2Headers(); + if (same) { + headers.addLong(HttpHeaderNames.CONTENT_LENGTH, 0); + headers.addLong(HttpHeaderNames.CONTENT_LENGTH, 0); + } else { + headers.addLong(HttpHeaderNames.CONTENT_LENGTH, 0); + headers.addLong(HttpHeaderNames.CONTENT_LENGTH, 1); + } + + if (same) { + decode().onHeadersRead(ctx, STREAM_ID, headers, padding, true); + verify(listener, times(1)).onHeadersRead(eq(ctx), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + assertEquals(1, headers.getAll(HttpHeaderNames.CONTENT_LENGTH).size()); + } else { + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + decode().onHeadersRead(ctx, STREAM_ID, headers, padding, true); + } + }); + + // Verify that the event was absorbed and not propagated to the observer. + verify(listener, never()).onHeadersRead(eq(ctx), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + } + } + + private static ByteBuf dummyData() { + // The buffer is purposely 8 bytes so it will even work for a ping frame. 
+ return wrappedBuffer("abcdefgh".getBytes(UTF_8)); + } + + /** + * Calls the decode method on the handler and gets back the captured internal listener + */ + private Http2FrameListener decode() throws Exception { + ArgumentCaptor internalListener = ArgumentCaptor.forClass(Http2FrameListener.class); + doNothing().when(reader).readFrame(eq(ctx), any(ByteBuf.class), internalListener.capture()); + decoder.decodeFrame(ctx, EMPTY_BUFFER, Collections.emptyList()); + return internalListener.getValue(); + } + + private void mockFlowControl(final int processedBytes) throws Http2Exception { + doAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock invocation) throws Throwable { + return processedBytes; + } + }).when(listener).onDataRead(any(ChannelHandlerContext.class), anyInt(), + any(ByteBuf.class), anyInt(), anyBoolean()); + } + + private void mockGoAwaySent() { + when(connection.goAwaySent()).thenReturn(true); + when(remote.isValidStreamId(STREAM_ID)).thenReturn(true); + when(remote.lastStreamKnownByPeer()).thenReturn(0); + } + + private void mockGoAwaySentShouldAllowFramesForStreamCreatedByLocalEndpoint() { + when(connection.goAwaySent()).thenReturn(true); + when(remote.isValidStreamId(STREAM_ID)).thenReturn(false); + when(remote.lastStreamKnownByPeer()).thenReturn(0); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java new file mode 100644 index 0000000..4eee7ce --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionEncoderTest.java @@ -0,0 +1,957 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.PooledByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http2.Http2RemoteFlowController.FlowControlled; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.InOrder; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.List; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_REMOTE; +import static io.netty.handler.codec.http2.Http2Stream.State.RESERVED_LOCAL; +import static 
io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise; +import static io.netty.util.CharsetUtil.UTF_8; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.inOrder; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link DefaultHttp2ConnectionEncoder} + */ +public class DefaultHttp2ConnectionEncoderTest { + private static final int STREAM_ID = 2; + private static final int PUSH_STREAM_ID = 4; + + @Mock + private Http2RemoteFlowController remoteFlow; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private Channel channel; + + @Mock + private Channel.Unsafe unsafe; + + @Mock + private ChannelPipeline pipeline; + + @Mock + private Http2FrameWriter writer; + + @Mock + private Http2FrameWriter.Configuration writerConfig; + + @Mock + private Http2FrameSizePolicy frameSizePolicy; + + @Mock + private Http2LifecycleManager lifecycleManager; + + private DefaultHttp2ConnectionEncoder encoder; + private Http2Connection connection; + private 
ArgumentCaptor payloadCaptor; + private List writtenData; + private List writtenPadding; + private boolean streamClosed; + + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + ChannelMetadata metadata = new ChannelMetadata(false, 16); + when(channel.isActive()).thenReturn(true); + when(channel.pipeline()).thenReturn(pipeline); + when(channel.metadata()).thenReturn(metadata); + when(channel.unsafe()).thenReturn(unsafe); + ChannelConfig config = new DefaultChannelConfig(channel); + when(channel.config()).thenReturn(config); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) { + return newPromise().setFailure((Throwable) in.getArgument(0)); + } + }).when(channel).newFailedFuture(any(Throwable.class)); + + when(writer.configuration()).thenReturn(writerConfig); + when(writerConfig.frameSizePolicy()).thenReturn(frameSizePolicy); + when(frameSizePolicy.maxFrameSize()).thenReturn(64); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) throws Throwable { + return ((ChannelPromise) in.getArguments()[2]).setSuccess(); + } + }).when(writer).writeSettings(eq(ctx), any(Http2Settings.class), any(ChannelPromise.class)); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) throws Throwable { + ((ByteBuf) in.getArguments()[3]).release(); + return ((ChannelPromise) in.getArguments()[4]).setSuccess(); + } + }).when(writer).writeGoAway(eq(ctx), anyInt(), anyInt(), any(ByteBuf.class), any(ChannelPromise.class)); + writtenData = new ArrayList(); + writtenPadding = new ArrayList(); + when(writer.writeData(eq(ctx), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean(), + any(ChannelPromise.class))).then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) throws Throwable { + // Make sure we only receive stream closure on the last frame and that void promises + // are used for all writes except the last one. 
+ ChannelPromise promise = (ChannelPromise) in.getArguments()[5]; + if (streamClosed) { + fail("Stream already closed"); + } else { + streamClosed = (Boolean) in.getArguments()[4]; + } + writtenPadding.add((Integer) in.getArguments()[3]); + ByteBuf data = (ByteBuf) in.getArguments()[2]; + writtenData.add(data.toString(UTF_8)); + // Release the buffer just as DefaultHttp2FrameWriter does + data.release(); + // Let the promise succeed to trigger listeners. + return promise.setSuccess(); + } + }); + when(writer.writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), + anyInt(), anyBoolean(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ChannelPromise promise = invocationOnMock.getArgument(8); + if (streamClosed) { + fail("Stream already closed"); + } else { + streamClosed = invocationOnMock.getArgument(5); + } + return promise.setSuccess(); + } + }); + when(writer.writeHeaders(eq(ctx), anyInt(), any(Http2Headers.class), + anyInt(), anyBoolean(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ChannelPromise promise = invocationOnMock.getArgument(5); + if (streamClosed) { + fail("Stream already closed"); + } else { + streamClosed = invocationOnMock.getArgument(4); + } + return promise.setSuccess(); + } + }); + payloadCaptor = ArgumentCaptor.forClass(Http2RemoteFlowController.FlowControlled.class); + doNothing().when(remoteFlow).addFlowControlled(any(Http2Stream.class), payloadCaptor.capture()); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(ctx.channel()).thenReturn(channel); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock in) throws Throwable { + return newPromise(); + } + }).when(ctx).newPromise(); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock in) throws 
Throwable { + return newSucceededFuture(); + } + }).when(ctx).newSucceededFuture(); + when(ctx.flush()).thenAnswer(new Answer() { + @Override + public ChannelHandlerContext answer(InvocationOnMock invocationOnMock) { + fail("forbidden"); + return null; + } + }); + when(channel.alloc()).thenReturn(PooledByteBufAllocator.DEFAULT); + + // Use a server-side connection so we can test server push. + connection = new DefaultHttp2Connection(true); + connection.remote().flowController(remoteFlow); + + encoder = new DefaultHttp2ConnectionEncoder(connection, writer); + encoder.lifecycleManager(lifecycleManager); + } + + @Test + public void dataWithEndOfStreamWriteShouldSignalThatFrameWasConsumedOnError() throws Exception { + dataWriteShouldSignalThatFrameWasConsumedOnError0(true); + } + + @Test + public void dataWriteShouldSignalThatFrameWasConsumedOnError() throws Exception { + dataWriteShouldSignalThatFrameWasConsumedOnError0(false); + } + + private void dataWriteShouldSignalThatFrameWasConsumedOnError0(boolean endOfStream) throws Exception { + createStream(STREAM_ID, false); + final ByteBuf data = dummyData(); + ChannelPromise p = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 0, endOfStream, p); + + FlowControlled controlled = payloadCaptor.getValue(); + assertEquals(8, controlled.size()); + payloadCaptor.getValue().write(ctx, 4); + assertEquals(4, controlled.size()); + + Throwable error = new IllegalStateException(); + payloadCaptor.getValue().error(ctx, error); + payloadCaptor.getValue().write(ctx, 8); + assertEquals(0, controlled.size()); + assertEquals("abcd", writtenData.get(0)); + assertEquals(0, data.refCnt()); + assertSame(error, p.cause()); + } + + @Test + public void dataWriteShouldSucceed() throws Exception { + createStream(STREAM_ID, false); + final ByteBuf data = dummyData(); + ChannelPromise p = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 0, true, p); + assertEquals(8, payloadCaptor.getValue().size()); + 
payloadCaptor.getValue().write(ctx, 8); + assertEquals(0, payloadCaptor.getValue().size()); + assertEquals("abcdefgh", writtenData.get(0)); + assertEquals(0, data.refCnt()); + assertTrue(p.isSuccess()); + } + + @Test + public void dataFramesShouldMerge() throws Exception { + createStream(STREAM_ID, false); + final ByteBuf data = dummyData().retain(); + + ChannelPromise promise1 = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 0, true, promise1); + ChannelPromise promise2 = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 0, true, promise2); + + // Now merge the two payloads. + List capturedWrites = payloadCaptor.getAllValues(); + FlowControlled mergedPayload = capturedWrites.get(0); + mergedPayload.merge(ctx, capturedWrites.get(1)); + assertEquals(16, mergedPayload.size()); + assertFalse(promise1.isDone()); + assertFalse(promise2.isDone()); + + // Write the merged payloads and verify it was written correctly. + mergedPayload.write(ctx, 16); + assertEquals(0, mergedPayload.size()); + assertEquals("abcdefghabcdefgh", writtenData.get(0)); + assertEquals(0, data.refCnt()); + assertTrue(promise1.isSuccess()); + assertTrue(promise2.isSuccess()); + } + + @Test + public void dataFramesShouldMergeUseVoidPromise() throws Exception { + createStream(STREAM_ID, false); + final ByteBuf data = dummyData().retain(); + + ChannelPromise promise1 = newVoidPromise(channel); + encoder.writeData(ctx, STREAM_ID, data, 0, true, promise1); + ChannelPromise promise2 = newVoidPromise(channel); + encoder.writeData(ctx, STREAM_ID, data, 0, true, promise2); + + // Now merge the two payloads. + List capturedWrites = payloadCaptor.getAllValues(); + FlowControlled mergedPayload = capturedWrites.get(0); + mergedPayload.merge(ctx, capturedWrites.get(1)); + assertEquals(16, mergedPayload.size()); + assertFalse(promise1.isSuccess()); + assertFalse(promise2.isSuccess()); + + // Write the merged payloads and verify it was written correctly. 
+ mergedPayload.write(ctx, 16); + assertEquals(0, mergedPayload.size()); + assertEquals("abcdefghabcdefgh", writtenData.get(0)); + assertEquals(0, data.refCnt()); + + // The promises won't be set since there are no listeners. + assertFalse(promise1.isSuccess()); + assertFalse(promise2.isSuccess()); + } + + @Test + public void dataFramesDontMergeWithHeaders() throws Exception { + createStream(STREAM_ID, false); + final ByteBuf data = dummyData().retain(); + encoder.writeData(ctx, STREAM_ID, data, 0, false, newPromise()); + when(remoteFlow.hasFlowControlled(any(Http2Stream.class))).thenReturn(true); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, newPromise()); + List capturedWrites = payloadCaptor.getAllValues(); + assertFalse(capturedWrites.get(0).merge(ctx, capturedWrites.get(1))); + } + + @Test + public void emptyFrameShouldSplitPadding() throws Exception { + ByteBuf data = Unpooled.buffer(0); + assertSplitPaddingOnEmptyBuffer(data); + assertEquals(0, data.refCnt()); + } + + @Test + public void writeHeadersUsingVoidPromise() throws Exception { + final Throwable cause = new RuntimeException("fake exception"); + when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), + anyInt(), anyBoolean(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + ChannelPromise promise = invocationOnMock.getArgument(5); + assertFalse(promise.isVoid()); + return promise.setFailure(cause); + } + }); + createStream(STREAM_ID, false); + // END_STREAM flag, so that a listener is added to the future. + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, newVoidPromise(channel)); + + verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), any(Http2Headers.class), + anyInt(), anyBoolean(), any(ChannelPromise.class)); + // When using a void promise, the error should be propagated via the channel pipeline. 
+ verify(pipeline).fireExceptionCaught(cause); + } + + private void assertSplitPaddingOnEmptyBuffer(ByteBuf data) throws Exception { + createStream(STREAM_ID, false); + when(frameSizePolicy.maxFrameSize()).thenReturn(5); + ChannelPromise p = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 10, true, p); + assertEquals(10, payloadCaptor.getValue().size()); + payloadCaptor.getValue().write(ctx, 10); + // writer should only have been called once + assertEquals(1, writtenData.size()); + assertEquals("", writtenData.get(0)); + assertEquals(10, (int) writtenPadding.get(0)); + assertEquals(0, data.refCnt()); + assertTrue(p.isSuccess()); + } + + @Test + public void headersWriteForUnknownStreamShouldCreateStream() throws Exception { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise); + verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(0), + eq(false), eq(promise)); + assertTrue(promise.isSuccess()); + } + + @Test + public void headersWriteShouldOpenStreamForPush() throws Exception { + writeAllFlowControlledFrames(); + Http2Stream parent = createStream(STREAM_ID, false); + reservePushStream(PUSH_STREAM_ID, parent); + + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise); + assertEquals(HALF_CLOSED_REMOTE, stream(PUSH_STREAM_ID).state()); + verify(writer).writeHeaders(eq(ctx), eq(PUSH_STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + } + + @Test + public void trailersDoNotEndStreamThrows() { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise); + + ChannelPromise promise2 = newPromise(); + ChannelFuture future = encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, 
false, promise2); + assertTrue(future.isDone()); + assertFalse(future.isSuccess()); + + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + } + + @Test + public void trailersDoNotEndStreamWithDataThrows() { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise); + + Http2Stream stream = connection.stream(streamId); + when(remoteFlow.hasFlowControlled(eq(stream))).thenReturn(true); + + ChannelPromise promise2 = newPromise(); + ChannelFuture future = encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise2); + assertTrue(future.isDone()); + assertFalse(future.isSuccess()); + + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + } + + @Test + public void tooManyHeadersNoEOSThrows() { + tooManyHeadersThrows(false); + } + + @Test + public void tooManyHeadersEOSThrows() { + tooManyHeadersThrows(true); + } + + private void tooManyHeadersThrows(boolean eos) { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise); + ChannelPromise promise2 = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise2); + + ChannelPromise promise3 = newPromise(); + ChannelFuture future = encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, eos, promise3); + assertTrue(future.isDone()); + assertFalse(future.isSuccess()); + + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(true), eq(promise2)); + } + + @Test + public void 
infoHeadersAndTrailersAllowed() throws Exception { + infoHeadersAndTrailers(true, 1); + } + + @Test + public void multipleInfoHeadersAndTrailersAllowed() throws Exception { + infoHeadersAndTrailers(true, 10); + } + + @Test + public void infoHeadersAndTrailersNoEOSThrows() throws Exception { + infoHeadersAndTrailers(false, 1); + } + + @Test + public void multipleInfoHeadersAndTrailersNoEOSThrows() throws Exception { + infoHeadersAndTrailers(false, 10); + } + + private void infoHeadersAndTrailers(boolean eos, int infoHeaderCount) { + writeAllFlowControlledFrames(); + final int streamId = 6; + Http2Headers infoHeaders = informationalHeaders(); + for (int i = 0; i < infoHeaderCount; ++i) { + encoder.writeHeaders(ctx, streamId, infoHeaders, 0, false, newPromise()); + } + ChannelPromise promise2 = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise2); + + ChannelPromise promise3 = newPromise(); + ChannelFuture future = encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, eos, promise3); + assertTrue(future.isDone()); + assertEquals(eos, future.isSuccess()); + + verify(writer, times(infoHeaderCount)).writeHeaders(eq(ctx), eq(streamId), eq(infoHeaders), + eq(0), eq(false), any(ChannelPromise.class)); + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise2)); + if (eos) { + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(true), eq(promise3)); + } + } + + private static Http2Headers informationalHeaders() { + Http2Headers headers = new DefaultHttp2Headers(); + headers.status(HttpResponseStatus.CONTINUE.codeAsText()); + return headers; + } + + @Test + public void tooManyHeadersWithDataNoEOSThrows() { + tooManyHeadersWithDataThrows(false); + } + + @Test + public void tooManyHeadersWithDataEOSThrows() { + tooManyHeadersWithDataThrows(true); + } + + private void 
tooManyHeadersWithDataThrows(boolean eos) { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise); + + Http2Stream stream = connection.stream(streamId); + when(remoteFlow.hasFlowControlled(eq(stream))).thenReturn(true); + + ChannelPromise promise2 = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, true, promise2); + + ChannelPromise promise3 = newPromise(); + ChannelFuture future = encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, eos, promise3); + assertTrue(future.isDone()); + assertFalse(future.isSuccess()); + + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(true), eq(promise2)); + } + + @Test + public void infoHeadersAndTrailersWithDataAllowed() { + infoHeadersAndTrailersWithData(true, 1); + } + + @Test + public void multipleInfoHeadersAndTrailersWithDataAllowed() { + infoHeadersAndTrailersWithData(true, 10); + } + + @Test + public void infoHeadersAndTrailersWithDataNoEOSThrows() { + infoHeadersAndTrailersWithData(false, 1); + } + + @Test + public void multipleInfoHeadersAndTrailersWithDataNoEOSThrows() { + infoHeadersAndTrailersWithData(false, 10); + } + + private void infoHeadersAndTrailersWithData(boolean eos, int infoHeaderCount) { + writeAllFlowControlledFrames(); + final int streamId = 6; + Http2Headers infoHeaders = informationalHeaders(); + for (int i = 0; i < infoHeaderCount; ++i) { + encoder.writeHeaders(ctx, streamId, infoHeaders, 0, false, newPromise()); + } + + Http2Stream stream = connection.stream(streamId); + when(remoteFlow.hasFlowControlled(eq(stream))).thenReturn(true); + + ChannelPromise promise2 = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 
0, false, promise2); + + ChannelPromise promise3 = newPromise(); + ChannelFuture future = encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, eos, promise3); + assertTrue(future.isDone()); + assertEquals(eos, future.isSuccess()); + + verify(writer, times(infoHeaderCount)).writeHeaders(eq(ctx), eq(streamId), eq(infoHeaders), + eq(0), eq(false), any(ChannelPromise.class)); + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise2)); + if (eos) { + verify(writer, times(1)).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(true), eq(promise3)); + } + } + + @Test + public void pushPromiseWriteAfterGoAwayReceivedShouldFail() throws Exception { + createStream(STREAM_ID, false); + goAwayReceived(0); + ChannelFuture future = encoder.writePushPromise(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, + newPromise()); + assertTrue(future.isDone()); + assertFalse(future.isSuccess()); + } + + @Test + public void pushPromiseWriteShouldReserveStream() throws Exception { + createStream(STREAM_ID, false); + ChannelPromise promise = newPromise(); + encoder.writePushPromise(ctx, STREAM_ID, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, promise); + assertEquals(RESERVED_LOCAL, stream(PUSH_STREAM_ID).state()); + verify(writer).writePushPromise(eq(ctx), eq(STREAM_ID), eq(PUSH_STREAM_ID), + eq(EmptyHttp2Headers.INSTANCE), eq(0), eq(promise)); + } + + @Test + public void priorityWriteAfterGoAwayShouldSucceed() throws Exception { + createStream(STREAM_ID, false); + goAwayReceived(Integer.MAX_VALUE); + ChannelPromise promise = newPromise(); + encoder.writePriority(ctx, STREAM_ID, 0, (short) 255, true, promise); + verify(writer).writePriority(eq(ctx), eq(STREAM_ID), eq(0), eq((short) 255), eq(true), eq(promise)); + } + + @Test + public void priorityWriteShouldSetPriorityForStream() throws Exception { + ChannelPromise promise = newPromise(); + short weight = 255; + 
encoder.writePriority(ctx, STREAM_ID, 0, weight, true, promise); + + // Verify that this did NOT create a stream object. + Http2Stream stream = stream(STREAM_ID); + assertNull(stream); + + verify(writer).writePriority(eq(ctx), eq(STREAM_ID), eq(0), eq((short) 255), eq(true), eq(promise)); + } + + @Test + public void priorityWriteOnPreviouslyExistingStreamShouldSucceed() throws Exception { + createStream(STREAM_ID, false).close(); + ChannelPromise promise = newPromise(); + short weight = 255; + encoder.writePriority(ctx, STREAM_ID, 0, weight, true, promise); + verify(writer).writePriority(eq(ctx), eq(STREAM_ID), eq(0), eq(weight), eq(true), eq(promise)); + } + + @Test + public void priorityWriteOnPreviouslyExistingParentStreamShouldSucceed() throws Exception { + final int parentStreamId = STREAM_ID + 2; + createStream(STREAM_ID, false); + createStream(parentStreamId, false).close(); + + ChannelPromise promise = newPromise(); + short weight = 255; + encoder.writePriority(ctx, STREAM_ID, parentStreamId, weight, true, promise); + verify(writer).writePriority(eq(ctx), eq(STREAM_ID), eq(parentStreamId), eq(weight), eq(true), eq(promise)); + } + + @Test + public void rstStreamWriteForUnknownStreamShouldIgnore() throws Exception { + ChannelPromise promise = newPromise(); + encoder.writeRstStream(ctx, 5, PROTOCOL_ERROR.code(), promise); + verify(writer, never()).writeRstStream(eq(ctx), anyInt(), anyLong(), eq(promise)); + } + + @Test + public void rstStreamShouldCloseStream() throws Exception { + // Create the stream and send headers. + writeAllFlowControlledFrames(); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, newPromise()); + + // Now verify that a stream reset is performed. 
+ stream(STREAM_ID); + ChannelPromise promise = newPromise(); + encoder.writeRstStream(ctx, STREAM_ID, PROTOCOL_ERROR.code(), promise); + verify(lifecycleManager).resetStream(eq(ctx), eq(STREAM_ID), anyLong(), eq(promise)); + } + + @Test + public void pingWriteAfterGoAwayShouldSucceed() throws Exception { + ChannelPromise promise = newPromise(); + goAwayReceived(0); + encoder.writePing(ctx, false, 0L, promise); + verify(writer).writePing(eq(ctx), eq(false), eq(0L), eq(promise)); + } + + @Test + public void pingWriteShouldSucceed() throws Exception { + ChannelPromise promise = newPromise(); + encoder.writePing(ctx, false, 0L, promise); + verify(writer).writePing(eq(ctx), eq(false), eq(0L), eq(promise)); + } + + @Test + public void settingsWriteAfterGoAwayShouldSucceed() throws Exception { + goAwayReceived(0); + ChannelPromise promise = newPromise(); + encoder.writeSettings(ctx, new Http2Settings(), promise); + verify(writer).writeSettings(eq(ctx), any(Http2Settings.class), eq(promise)); + } + + @Test + public void settingsWriteShouldNotUpdateSettings() throws Exception { + Http2Settings settings = new Http2Settings(); + settings.initialWindowSize(100); + settings.maxConcurrentStreams(1000); + settings.headerTableSize(2000); + + ChannelPromise promise = newPromise(); + encoder.writeSettings(ctx, settings, promise); + verify(writer).writeSettings(eq(ctx), eq(settings), eq(promise)); + } + + @Test + public void dataWriteShouldCreateHalfClosedStream() throws Exception { + writeAllFlowControlledFrames(); + + Http2Stream stream = createStream(STREAM_ID, false); + ByteBuf data = dummyData(); + ChannelPromise promise = newPromise(); + encoder.writeData(ctx, STREAM_ID, data.retain(), 0, true, promise); + assertTrue(promise.isSuccess()); + verify(remoteFlow).addFlowControlled(eq(stream), any(FlowControlled.class)); + verify(lifecycleManager).closeStreamLocal(stream, promise); + assertEquals(data.toString(UTF_8), writtenData.get(0)); + data.release(); + } + + @Test + public 
void headersWriteShouldHalfCloseStream() throws Exception { + writeAllFlowControlledFrames(); + createStream(STREAM_ID, false); + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + + assertTrue(promise.isSuccess()); + verify(lifecycleManager).closeStreamLocal(eq(stream(STREAM_ID)), eq(promise)); + } + + @Test + public void headersWriteShouldHalfClosePushStream() throws Exception { + writeAllFlowControlledFrames(); + Http2Stream parent = createStream(STREAM_ID, false); + Http2Stream stream = reservePushStream(PUSH_STREAM_ID, parent); + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, PUSH_STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + assertEquals(HALF_CLOSED_REMOTE, stream.state()); + assertTrue(promise.isSuccess()); + verify(lifecycleManager).closeStreamLocal(eq(stream), eq(promise)); + } + + @Test + public void headersWriteShouldHalfCloseAfterOnErrorForPreCreatedStream() throws Exception { + final ChannelPromise promise = newPromise(); + final Throwable ex = new RuntimeException(); + // Fake an encoding error, like HPACK's HeaderListSizeException + when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), eq(true), eq(promise))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) { + promise.setFailure(ex); + return promise; + } + }); + + writeAllFlowControlledFrames(); + Http2Stream stream = createStream(STREAM_ID, false); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertFalse(stream.isHeadersSent()); + InOrder inOrder = inOrder(lifecycleManager); + inOrder.verify(lifecycleManager).onError(eq(ctx), eq(true), eq(ex)); + inOrder.verify(lifecycleManager).closeStreamLocal(eq(stream(STREAM_ID)), eq(promise)); + } + + @Test + public void 
headersWriteShouldHalfCloseAfterOnErrorForImplicitlyCreatedStream() throws Exception { + final ChannelPromise promise = newPromise(); + final Throwable ex = new RuntimeException(); + // Fake an encoding error, like HPACK's HeaderListSizeException + when(writer.writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), eq(0), eq(true), eq(promise))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) { + promise.setFailure(ex); + return promise; + } + }); + + writeAllFlowControlledFrames(); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, true, promise); + + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertFalse(stream(STREAM_ID).isHeadersSent()); + InOrder inOrder = inOrder(lifecycleManager); + inOrder.verify(lifecycleManager).onError(eq(ctx), eq(true), eq(ex)); + inOrder.verify(lifecycleManager).closeStreamLocal(eq(stream(STREAM_ID)), eq(promise)); + } + + @Test + public void encoderDelegatesGoAwayToLifeCycleManager() { + ChannelPromise promise = newPromise(); + encoder.writeGoAway(ctx, STREAM_ID, Http2Error.INTERNAL_ERROR.code(), null, promise); + verify(lifecycleManager).goAway(eq(ctx), eq(STREAM_ID), eq(Http2Error.INTERNAL_ERROR.code()), + eq((ByteBuf) null), eq(promise)); + verifyNoMoreInteractions(writer); + } + + @Test + public void dataWriteToClosedStreamShouldFail() throws Exception { + createStream(STREAM_ID, false).close(); + ByteBuf data = mock(ByteBuf.class); + ChannelPromise promise = newPromise(); + encoder.writeData(ctx, STREAM_ID, data, 0, false, promise); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertThat(promise.cause(), instanceOf(IllegalArgumentException.class)); + verify(data).release(); + } + + @Test + public void dataWriteToHalfClosedLocalStreamShouldFail() throws Exception { + createStream(STREAM_ID, true); + ByteBuf data = mock(ByteBuf.class); + ChannelPromise promise = newPromise(); + encoder.writeData(ctx, 
STREAM_ID, data, 0, false, promise); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertThat(promise.cause(), instanceOf(IllegalStateException.class)); + verify(data).release(); + } + + @Test + public void canWriteDataFrameAfterGoAwaySent() throws Exception { + Http2Stream stream = createStream(STREAM_ID, false); + connection.goAwaySent(0, 0, EMPTY_BUFFER); + ByteBuf data = mock(ByteBuf.class); + encoder.writeData(ctx, STREAM_ID, data, 0, false, newPromise()); + verify(remoteFlow).addFlowControlled(eq(stream), any(FlowControlled.class)); + } + + @Test + public void canWriteHeaderFrameAfterGoAwaySent() throws Exception { + writeAllFlowControlledFrames(); + createStream(STREAM_ID, false); + goAwaySent(0); + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise); + verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + } + + @Test + public void canWriteDataFrameAfterGoAwayReceived() throws Exception { + Http2Stream stream = createStream(STREAM_ID, false); + goAwayReceived(STREAM_ID); + ByteBuf data = mock(ByteBuf.class); + encoder.writeData(ctx, STREAM_ID, data, 0, false, newPromise()); + verify(remoteFlow).addFlowControlled(eq(stream), any(FlowControlled.class)); + } + + @Test + public void canWriteHeaderFrameAfterGoAwayReceived() throws Http2Exception { + writeAllFlowControlledFrames(); + goAwayReceived(STREAM_ID); + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, STREAM_ID, EmptyHttp2Headers.INSTANCE, 0, false, promise); + verify(writer).writeHeaders(eq(ctx), eq(STREAM_ID), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + } + + @Test + public void headersWithNoPriority() { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 0, false, promise); + 
verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), + eq(0), eq(false), eq(promise)); + } + + @Test + public void headersWithPriority() { + writeAllFlowControlledFrames(); + final int streamId = 6; + ChannelPromise promise = newPromise(); + encoder.writeHeaders(ctx, streamId, EmptyHttp2Headers.INSTANCE, 10, DEFAULT_PRIORITY_WEIGHT, + true, 1, false, promise); + verify(writer).writeHeaders(eq(ctx), eq(streamId), eq(EmptyHttp2Headers.INSTANCE), eq(10), + eq(DEFAULT_PRIORITY_WEIGHT), eq(true), eq(1), eq(false), eq(promise)); + } + + private void writeAllFlowControlledFrames() { + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + FlowControlled flowControlled = (FlowControlled) invocationOnMock.getArguments()[1]; + flowControlled.write(ctx, Integer.MAX_VALUE); + flowControlled.writeComplete(); + return null; + } + }).when(remoteFlow).addFlowControlled(any(Http2Stream.class), payloadCaptor.capture()); + } + + private Http2Stream createStream(int streamId, boolean halfClosed) throws Http2Exception { + return connection.local().createStream(streamId, halfClosed); + } + + private Http2Stream reservePushStream(int pushStreamId, Http2Stream parent) throws Http2Exception { + return connection.local().reservePushStream(pushStreamId, parent); + } + + private Http2Stream stream(int streamId) { + return connection.stream(streamId); + } + + private void goAwayReceived(int lastStreamId) throws Http2Exception { + connection.goAwayReceived(lastStreamId, 0, EMPTY_BUFFER); + } + + private void goAwaySent(int lastStreamId) throws Http2Exception { + connection.goAwaySent(lastStreamId, 0, EMPTY_BUFFER); + } + + private ChannelPromise newPromise() { + return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + } + + private ChannelFuture newSucceededFuture() { + return newPromise().setSuccess(); + } + + private static ByteBuf dummyData() { + // The buffer is purposely 8 bytes so 
it will even work for a ping frame. + return wrappedBuffer("abcdefgh".getBytes(UTF_8)); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java new file mode 100644 index 0000000..37d401c --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2ConnectionTest.java @@ -0,0 +1,731 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.handler.codec.http2.Http2Connection.Endpoint; +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static java.lang.Integer.MAX_VALUE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** + * Tests for {@link DefaultHttp2Connection}. 
+ */ +public class DefaultHttp2ConnectionTest { + + private DefaultHttp2Connection server; + private DefaultHttp2Connection client; + private static DefaultEventLoopGroup group; + + @Mock + private Http2Connection.Listener clientListener; + + @Mock + private Http2Connection.Listener clientListener2; + + @BeforeAll + public static void beforeClass() { + group = new DefaultEventLoopGroup(2); + } + + @AfterAll + public static void afterClass() { + group.shutdownGracefully(); + } + + @BeforeEach + public void setup() { + MockitoAnnotations.initMocks(this); + + server = new DefaultHttp2Connection(true); + client = new DefaultHttp2Connection(false); + client.addListener(clientListener); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + assertNotNull(client.stream(((Http2Stream) invocation.getArgument(0)).id())); + return null; + } + }).when(clientListener).onStreamClosed(any(Http2Stream.class)); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + assertNull(client.stream(((Http2Stream) invocation.getArgument(0)).id())); + return null; + } + }).when(clientListener).onStreamRemoved(any(Http2Stream.class)); + } + + @Test + public void getStreamWithoutStreamShouldReturnNull() { + assertNull(server.stream(100)); + } + + @Test + public void removeAllStreamsWithEmptyStreams() throws InterruptedException { + testRemoveAllStreams(); + } + + @Test + public void removeAllStreamsWithJustOneLocalStream() throws Exception { + client.local().createStream(3, false); + testRemoveAllStreams(); + } + + @Test + public void removeAllStreamsWithJustOneRemoveStream() throws Exception { + client.remote().createStream(2, false); + testRemoveAllStreams(); + } + + @Test + public void removeAllStreamsWithManyActiveStreams() throws Exception { + Endpoint remote = client.remote(); + Endpoint local = client.local(); + for (int c = 3, s = 2; c < 5000; c += 2, s += 2) { + 
local.createStream(c, false); + remote.createStream(s, false); + } + testRemoveAllStreams(); + } + + @Test + public void removeIndividualStreamsWhileCloseDoesNotNPE() throws Exception { + final Http2Stream streamA = client.local().createStream(3, false); + final Http2Stream streamB = client.remote().createStream(2, false); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + streamA.close(); + streamB.close(); + return null; + } + }).when(clientListener2).onStreamClosed(any(Http2Stream.class)); + try { + client.addListener(clientListener2); + testRemoveAllStreams(); + } finally { + client.removeListener(clientListener2); + } + } + + @Test + public void removeAllStreamsWhileIteratingActiveStreams() throws Exception { + final Endpoint remote = client.remote(); + final Endpoint local = client.local(); + for (int c = 3, s = 2; c < 5000; c += 2, s += 2) { + local.createStream(c, false); + remote.createStream(s, false); + } + final Promise promise = group.next().newPromise(); + final CountDownLatch latch = new CountDownLatch(client.numActiveStreams()); + client.forEachActiveStream(new Http2StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) { + client.close(promise).addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + assertTrue(promise.isDone()); + latch.countDown(); + } + }); + return true; + } + }); + assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + + @Test + public void removeAllStreamsWhileIteratingActiveStreamsAndExceptionOccurs() + throws Exception { + final Endpoint remote = client.remote(); + final Endpoint local = client.local(); + for (int c = 3, s = 2; c < 5000; c += 2, s += 2) { + local.createStream(c, false); + remote.createStream(s, false); + } + final Promise promise = group.next().newPromise(); + final CountDownLatch latch = new CountDownLatch(1); + try { + client.forEachActiveStream(new Http2StreamVisitor() { + 
@Override + public boolean visit(Http2Stream stream) throws Http2Exception { + // This close call is basically a noop, because the following statement will throw an exception. + client.close(promise); + // Do an invalid operation while iterating. + remote.createStream(3, false); + return true; + } + }); + } catch (Http2Exception ignored) { + client.close(promise).addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + assertTrue(promise.isDone()); + latch.countDown(); + } + }); + } + assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + + @Test + public void goAwayReceivedShouldCloseStreamsGreaterThanLastStream() throws Exception { + Http2Stream stream1 = client.local().createStream(3, false); + Http2Stream stream2 = client.local().createStream(5, false); + Http2Stream remoteStream = client.remote().createStream(4, false); + + assertEquals(State.OPEN, stream1.state()); + assertEquals(State.OPEN, stream2.state()); + + client.goAwayReceived(3, 8, null); + + assertEquals(State.OPEN, stream1.state()); + assertEquals(State.CLOSED, stream2.state()); + assertEquals(State.OPEN, remoteStream.state()); + assertEquals(3, client.local().lastStreamKnownByPeer()); + assertEquals(5, client.local().lastStreamCreated()); + // The remote endpoint must not be affected by a received GOAWAY frame. 
+ assertEquals(-1, client.remote().lastStreamKnownByPeer()); + assertEquals(State.OPEN, remoteStream.state()); + } + + @Test + public void goAwaySentShouldCloseStreamsGreaterThanLastStream() throws Exception { + Http2Stream stream1 = server.remote().createStream(3, false); + Http2Stream stream2 = server.remote().createStream(5, false); + Http2Stream localStream = server.local().createStream(4, false); + + server.goAwaySent(3, 8, null); + + assertEquals(State.OPEN, stream1.state()); + assertEquals(State.CLOSED, stream2.state()); + + assertEquals(3, server.remote().lastStreamKnownByPeer()); + assertEquals(5, server.remote().lastStreamCreated()); + // The local endpoint must not be affected by a sent GOAWAY frame. + assertEquals(-1, server.local().lastStreamKnownByPeer()); + assertEquals(State.OPEN, localStream.state()); + } + + @Test + public void serverCreateStreamShouldSucceed() throws Http2Exception { + Http2Stream stream = server.local().createStream(2, false); + assertEquals(2, stream.id()); + assertEquals(State.OPEN, stream.state()); + assertEquals(1, server.numActiveStreams()); + assertEquals(2, server.local().lastStreamCreated()); + + stream = server.local().createStream(4, true); + assertEquals(4, stream.id()); + assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); + assertEquals(2, server.numActiveStreams()); + assertEquals(4, server.local().lastStreamCreated()); + + stream = server.remote().createStream(3, true); + assertEquals(3, stream.id()); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + assertEquals(3, server.numActiveStreams()); + assertEquals(3, server.remote().lastStreamCreated()); + + stream = server.remote().createStream(5, false); + assertEquals(5, stream.id()); + assertEquals(State.OPEN, stream.state()); + assertEquals(4, server.numActiveStreams()); + assertEquals(5, server.remote().lastStreamCreated()); + } + + @Test + public void clientCreateStreamShouldSucceed() throws Http2Exception { + Http2Stream stream = 
client.remote().createStream(2, false); + assertEquals(2, stream.id()); + assertEquals(State.OPEN, stream.state()); + assertEquals(1, client.numActiveStreams()); + assertEquals(2, client.remote().lastStreamCreated()); + + stream = client.remote().createStream(4, true); + assertEquals(4, stream.id()); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + assertEquals(2, client.numActiveStreams()); + assertEquals(4, client.remote().lastStreamCreated()); + assertTrue(stream.isHeadersReceived()); + + stream = client.local().createStream(3, true); + assertEquals(3, stream.id()); + assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); + assertEquals(3, client.numActiveStreams()); + assertEquals(3, client.local().lastStreamCreated()); + assertTrue(stream.isHeadersSent()); + + stream = client.local().createStream(5, false); + assertEquals(5, stream.id()); + assertEquals(State.OPEN, stream.state()); + assertEquals(4, client.numActiveStreams()); + assertEquals(5, client.local().lastStreamCreated()); + } + + @Test + public void serverReservePushStreamShouldSucceed() throws Http2Exception { + Http2Stream stream = server.remote().createStream(3, true); + Http2Stream pushStream = server.local().reservePushStream(2, stream); + assertEquals(2, pushStream.id()); + assertEquals(State.RESERVED_LOCAL, pushStream.state()); + assertEquals(1, server.numActiveStreams()); + assertEquals(2, server.local().lastStreamCreated()); + } + + @Test + public void clientReservePushStreamShouldSucceed() throws Http2Exception { + Http2Stream stream = server.remote().createStream(3, true); + Http2Stream pushStream = server.local().reservePushStream(4, stream); + assertEquals(4, pushStream.id()); + assertEquals(State.RESERVED_LOCAL, pushStream.state()); + assertEquals(1, server.numActiveStreams()); + assertEquals(4, server.local().lastStreamCreated()); + } + + @Test + public void serverRemoteIncrementAndGetStreamShouldSucceed() throws Http2Exception { + 
incrementAndGetStreamShouldSucceed(server.remote()); + } + + @Test + public void serverLocalIncrementAndGetStreamShouldSucceed() throws Http2Exception { + incrementAndGetStreamShouldSucceed(server.local()); + } + + @Test + public void clientRemoteIncrementAndGetStreamShouldSucceed() throws Http2Exception { + incrementAndGetStreamShouldSucceed(client.remote()); + } + + @Test + public void clientLocalIncrementAndGetStreamShouldSucceed() throws Http2Exception { + incrementAndGetStreamShouldSucceed(client.local()); + } + + @Test + public void serverRemoteIncrementAndGetStreamShouldRespectOverflow() throws Http2Exception { + incrementAndGetStreamShouldRespectOverflow(server.remote(), MAX_VALUE); + } + + @Test + public void serverLocalIncrementAndGetStreamShouldRespectOverflow() throws Http2Exception { + incrementAndGetStreamShouldRespectOverflow(server.local(), MAX_VALUE - 1); + } + + @Test + public void clientRemoteIncrementAndGetStreamShouldRespectOverflow() throws Http2Exception { + incrementAndGetStreamShouldRespectOverflow(client.remote(), MAX_VALUE - 1); + } + + @Test + public void clientLocalIncrementAndGetStreamShouldRespectOverflow() throws Http2Exception { + incrementAndGetStreamShouldRespectOverflow(client.local(), MAX_VALUE); + } + + @Test + public void clientLocalCreateStreamExhaustedSpace() throws Http2Exception { + client.local().createStream(MAX_VALUE, true); + Http2Exception expected = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + client.local().createStream(MAX_VALUE, true); + } + }); + assertEquals(Http2Error.REFUSED_STREAM, expected.error()); + assertEquals(Http2Exception.ShutdownHint.GRACEFUL_SHUTDOWN, expected.shutdownHint()); + } + + @Test + public void newStreamBehindExpectedShouldThrow() throws Http2Exception { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.local().createStream(0, true); + } + }); + } + + 
@Test + public void newStreamNotForServerShouldThrow() throws Http2Exception { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.local().createStream(11, true); + } + }); + } + + @Test + public void newStreamNotForClientShouldThrow() throws Http2Exception { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + client.local().createStream(10, true); + } + }); + } + + @Test + public void createShouldThrowWhenMaxAllowedStreamsOpenExceeded() throws Http2Exception { + server.local().maxActiveStreams(0); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.local().createStream(2, true); + } + }); + } + + @Test + public void serverCreatePushShouldFailOnRemoteEndpointWhenMaxAllowedStreamsExceeded() throws Http2Exception { + server = new DefaultHttp2Connection(true, 0); + server.remote().maxActiveStreams(1); + final Http2Stream requestStream = server.remote().createStream(3, false); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.remote().reservePushStream(2, requestStream); + } + }); + } + + @Test + public void clientCreatePushShouldFailOnRemoteEndpointWhenMaxAllowedStreamsExceeded() throws Http2Exception { + client = new DefaultHttp2Connection(false, 0); + client.remote().maxActiveStreams(1); + final Http2Stream requestStream = client.remote().createStream(2, false); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + client.remote().reservePushStream(4, requestStream); + } + }); + } + + @Test + public void serverCreatePushShouldSucceedOnLocalEndpointWhenMaxAllowedStreamsExceeded() throws Http2Exception { + server = new DefaultHttp2Connection(true, 0); + server.local().maxActiveStreams(1); + Http2Stream requestStream = 
server.remote().createStream(3, false); + assertNotNull(server.local().reservePushStream(2, requestStream)); + } + + @Test + public void reserveWithPushDisallowedShouldThrow() throws Http2Exception { + final Http2Stream stream = server.remote().createStream(3, true); + server.remote().allowPushTo(false); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.local().reservePushStream(2, stream); + } + }); + } + + @Test + public void goAwayReceivedShouldDisallowLocalCreation() throws Http2Exception { + server.goAwayReceived(0, 1L, Unpooled.EMPTY_BUFFER); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.local().createStream(3, true); + } + }); + } + + @Test + public void goAwayReceivedShouldAllowRemoteCreation() throws Http2Exception { + server.goAwayReceived(0, 1L, Unpooled.EMPTY_BUFFER); + server.remote().createStream(3, true); + } + + @Test + public void goAwaySentShouldDisallowRemoteCreation() throws Http2Exception { + server.goAwaySent(0, 1L, Unpooled.EMPTY_BUFFER); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + server.remote().createStream(2, true); + } + }); + } + + @Test + public void goAwaySentShouldAllowLocalCreation() throws Http2Exception { + server.goAwaySent(0, 1L, Unpooled.EMPTY_BUFFER); + server.local().createStream(2, true); + } + + @Test + public void closeShouldSucceed() throws Http2Exception { + Http2Stream stream = server.remote().createStream(3, true); + stream.close(); + assertEquals(State.CLOSED, stream.state()); + assertEquals(0, server.numActiveStreams()); + } + + @Test + public void closeLocalWhenOpenShouldSucceed() throws Http2Exception { + Http2Stream stream = server.remote().createStream(3, false); + stream.closeLocalSide(); + assertEquals(State.HALF_CLOSED_LOCAL, stream.state()); + assertEquals(1, server.numActiveStreams()); + } + + 
@Test + public void closeRemoteWhenOpenShouldSucceed() throws Http2Exception { + Http2Stream stream = server.remote().createStream(3, false); + stream.closeRemoteSide(); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + assertEquals(1, server.numActiveStreams()); + } + + @Test + public void closeOnlyOpenSideShouldClose() throws Http2Exception { + Http2Stream stream = server.remote().createStream(3, true); + stream.closeLocalSide(); + assertEquals(State.CLOSED, stream.state()); + assertEquals(0, server.numActiveStreams()); + } + + @SuppressWarnings("NumericOverflow") + @Test + public void localStreamInvalidStreamIdShouldThrow() { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + client.local().createStream(MAX_VALUE + 2, false); + } + }); + } + + @SuppressWarnings("NumericOverflow") + @Test + public void remoteStreamInvalidStreamIdShouldThrow() { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + client.remote().createStream(MAX_VALUE + 1, false); + } + }); + } + + /** + * We force {@link #clientListener} methods to all throw a {@link RuntimeException} and verify the following: + *
<ol>
+ * <li>all listener methods are called for both {@link #clientListener} and {@link #clientListener2}</li>
+ * <li>{@link #clientListener2} is notified after {@link #clientListener}</li>
+ * <li>{@link #clientListener2} methods are all still called despite {@link #clientListener}'s
+ * method throwing a {@link RuntimeException}</li>
+ * </ol>
+ */ + @Test + public void listenerThrowShouldNotPreventOtherListenersFromBeingNotified() throws Http2Exception { + final boolean[] calledArray = new boolean[128]; + // The following setup will ensure that clientListener throws exceptions, and marks a value in an array + // such that clientListener2 will verify that is is set or fail the test. + int methodIndex = 0; + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onStreamAdded(any(Http2Stream.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onStreamAdded(any(Http2Stream.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onStreamActive(any(Http2Stream.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onStreamActive(any(Http2Stream.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onStreamHalfClosed(any(Http2Stream.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onStreamHalfClosed(any(Http2Stream.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onStreamClosed(any(Http2Stream.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onStreamClosed(any(Http2Stream.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onStreamRemoved(any(Http2Stream.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onStreamRemoved(any(Http2Stream.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onGoAwaySent(anyInt(), anyLong(), any(ByteBuf.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onGoAwaySent(anyInt(), anyLong(), 
any(ByteBuf.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onGoAwayReceived(anyInt(), anyLong(), any(ByteBuf.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onGoAwayReceived(anyInt(), anyLong(), any(ByteBuf.class)); + + doAnswer(new ListenerExceptionThrower(calledArray, methodIndex)) + .when(clientListener).onStreamAdded(any(Http2Stream.class)); + doAnswer(new ListenerVerifyCallAnswer(calledArray, methodIndex++)) + .when(clientListener2).onStreamAdded(any(Http2Stream.class)); + + // Now we add clientListener2 and exercise all listener functionality + try { + client.addListener(clientListener2); + Http2Stream stream = client.local().createStream(3, false); + verify(clientListener).onStreamAdded(any(Http2Stream.class)); + verify(clientListener2).onStreamAdded(any(Http2Stream.class)); + verify(clientListener).onStreamActive(any(Http2Stream.class)); + verify(clientListener2).onStreamActive(any(Http2Stream.class)); + + Http2Stream reservedStream = client.remote().reservePushStream(2, stream); + verify(clientListener, never()).onStreamActive(streamEq(reservedStream)); + verify(clientListener2, never()).onStreamActive(streamEq(reservedStream)); + + reservedStream.open(false); + verify(clientListener).onStreamActive(streamEq(reservedStream)); + verify(clientListener2).onStreamActive(streamEq(reservedStream)); + + stream.closeLocalSide(); + verify(clientListener).onStreamHalfClosed(any(Http2Stream.class)); + verify(clientListener2).onStreamHalfClosed(any(Http2Stream.class)); + + stream.close(); + verify(clientListener).onStreamClosed(any(Http2Stream.class)); + verify(clientListener2).onStreamClosed(any(Http2Stream.class)); + verify(clientListener).onStreamRemoved(any(Http2Stream.class)); + verify(clientListener2).onStreamRemoved(any(Http2Stream.class)); + + client.goAwaySent(client.connectionStream().id(), Http2Error.INTERNAL_ERROR.code(), Unpooled.EMPTY_BUFFER); + 
verify(clientListener).onGoAwaySent(anyInt(), anyLong(), any(ByteBuf.class)); + verify(clientListener2).onGoAwaySent(anyInt(), anyLong(), any(ByteBuf.class)); + + client.goAwayReceived(client.connectionStream().id(), + Http2Error.INTERNAL_ERROR.code(), Unpooled.EMPTY_BUFFER); + verify(clientListener).onGoAwayReceived(anyInt(), anyLong(), any(ByteBuf.class)); + verify(clientListener2).onGoAwayReceived(anyInt(), anyLong(), any(ByteBuf.class)); + } finally { + client.removeListener(clientListener2); + } + } + + private void testRemoveAllStreams() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final Promise promise = group.next().newPromise(); + client.close(promise).addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + assertTrue(promise.isDone()); + latch.countDown(); + } + }); + assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + + private static void incrementAndGetStreamShouldRespectOverflow(final Endpoint endpoint, int streamId) { + assertTrue(streamId > 0); + try { + endpoint.createStream(streamId, true); + streamId = endpoint.incrementAndGetNextStreamId(); + } catch (Throwable t) { + fail(t); + } + assertTrue(streamId < 0); + final int finalStreamId = streamId; + assertThrows(Http2NoMoreStreamIdsException.class, new Executable() { + @Override + public void execute() throws Throwable { + endpoint.createStream(finalStreamId, true); + } + }); + } + + private static void incrementAndGetStreamShouldSucceed(Endpoint endpoint) throws Http2Exception { + Http2Stream streamA = endpoint.createStream(endpoint.incrementAndGetNextStreamId(), true); + Http2Stream streamB = endpoint.createStream(streamA.id() + 2, true); + Http2Stream streamC = endpoint.createStream(endpoint.incrementAndGetNextStreamId(), true); + assertEquals(streamB.id() + 2, streamC.id()); + endpoint.createStream(streamC.id() + 2, true); + } + + private static final class ListenerExceptionThrower implements 
Answer { + private static final RuntimeException FAKE_EXCEPTION = new RuntimeException("Fake Exception"); + private final boolean[] array; + private final int index; + + ListenerExceptionThrower(boolean[] array, int index) { + this.array = array; + this.index = index; + } + + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + array[index] = true; + throw FAKE_EXCEPTION; + } + } + + private static final class ListenerVerifyCallAnswer implements Answer { + private final boolean[] array; + private final int index; + + ListenerVerifyCallAnswer(boolean[] array, int index) { + this.array = array; + this.index = index; + } + + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + assertTrue(array[index]); + return null; + } + } + + @SuppressWarnings("unchecked") + private static T streamEq(T stream) { + return (T) (stream == null ? ArgumentMatchers.isNull() : eq(stream)); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java new file mode 100644 index 0000000..bfedf0a --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameReaderTest.java @@ -0,0 +1,453 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import static io.netty.handler.codec.http2.Http2CodecUtil.*; +import static io.netty.handler.codec.http2.Http2FrameTypes.*; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.*; + + +/** + * Tests for {@link DefaultHttp2FrameReader}. + */ +public class DefaultHttp2FrameReaderTest { + @Mock + private Http2FrameListener listener; + + @Mock + private ChannelHandlerContext ctx; + + private DefaultHttp2FrameReader frameReader; + + // Used to generate frame + private HpackEncoder hpackEncoder; + + @BeforeEach + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + + frameReader = new DefaultHttp2FrameReader(); + hpackEncoder = new HpackEncoder(); + } + + @AfterEach + public void tearDown() { + frameReader.close(); + } + + @Test + public void readHeaderFrame() throws Http2Exception { + final int streamId = 1; + + ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + Http2Flags flags = new Http2Flags().endOfHeaders(true).endOfStream(true); + writeHeaderFrame(input, streamId, headers, flags); + frameReader.readFrame(ctx, input, listener); + + verify(listener).onHeadersRead(ctx, 1, headers, 0, true); + } finally { + input.release(); + } + } + + @Test + public void readHeaderFrameAndContinuationFrame() throws Http2Exception { + final int streamId = 1; + + ByteBuf input = 
Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFrame(input, streamId, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, streamId, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(true)); + + frameReader.readFrame(ctx, input, listener); + + verify(listener).onHeadersRead(ctx, 1, headers.add("foo", "bar"), 0, true); + } finally { + input.release(); + } + } + + @Test + public void readUnknownFrame() throws Http2Exception { + ByteBuf input = Unpooled.buffer(); + ByteBuf payload = Unpooled.buffer(); + try { + payload.writeByte(1); + + writeFrameHeader(input, payload.readableBytes(), (byte) 0xff, new Http2Flags(), 0); + input.writeBytes(payload); + frameReader.readFrame(ctx, input, listener); + + verify(listener).onUnknownFrame( + ctx, (byte) 0xff, 0, new Http2Flags(), payload.slice(0, 1)); + } finally { + payload.release(); + input.release(); + } + } + + @Test + public void failedWhenUnknownFrameInMiddleOfHeaderBlock() throws Http2Exception { + final int streamId = 1; + + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + Http2Flags flags = new Http2Flags().endOfHeaders(false).endOfStream(true); + writeHeaderFrame(input, streamId, headers, flags); + writeFrameHeader(input, 0, (byte) 0xff, new Http2Flags(), streamId); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void failedWhenContinuationFrameStreamIdMismatch() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + 
.path("/") + .scheme("https"); + writeHeaderFrame(input, 1, headers, + new Http2Flags().endOfHeaders(false).endOfStream(true)); + writeContinuationFrame(input, 3, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(true)); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void failedWhenContinuationFrameNotFollowHeaderFrame() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + writeContinuationFrame(input, 1, new DefaultHttp2Headers().add("foo", "bar"), + new Http2Flags().endOfHeaders(true)); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void failedWhenHeaderFrameDependsOnItself() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + writeHeaderFramePriorityPresent( + input, 1, headers, + new Http2Flags().endOfHeaders(true).endOfStream(true).priorityPresent(true), + 1, 10); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void readHeaderAndData() throws Http2Exception { + ByteBuf input = Unpooled.buffer(); + ByteBuf dataPayload = Unpooled.buffer(); + try { + Http2Headers headers = new DefaultHttp2Headers() + .authority("foo") + .method("get") + .path("/") + .scheme("https"); + dataPayload.writeByte(1); + writeHeaderFrameWithData(input, 1, headers, dataPayload); + + frameReader.readFrame(ctx, input, listener); + + 
verify(listener).onHeadersRead(ctx, 1, headers, 0, false); + verify(listener).onDataRead(ctx, 1, dataPayload.slice(0, 1), 0, true); + } finally { + input.release(); + dataPayload.release(); + } + } + + @Test + public void failedWhenDataFrameNotAssociateWithStream() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + ByteBuf payload = Unpooled.buffer(); + try { + payload.writeByte(1); + + writeFrameHeader(input, payload.readableBytes(), DATA, new Http2Flags().endOfStream(true), 0); + input.writeBytes(payload); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + payload.release(); + input.release(); + } + } + + @Test + public void readPriorityFrame() throws Http2Exception { + ByteBuf input = Unpooled.buffer(); + try { + writePriorityFrame(input, 1, 0, 10); + frameReader.readFrame(ctx, input, listener); + } finally { + input.release(); + } + } + + @Test + public void failedWhenPriorityFrameDependsOnItself() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + writePriorityFrame(input, 1, 1, 10); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void failedWhenWindowUpdateFrameWithZeroDelta() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + writeFrameHeader(input, 4, WINDOW_UPDATE, new Http2Flags(), 0); + input.writeInt(0); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void readSettingsFrame() throws Http2Exception { + ByteBuf input = Unpooled.buffer(); + try { + writeFrameHeader(input, 6, SETTINGS, new Http2Flags(), 
0); + input.writeShort(SETTINGS_MAX_HEADER_LIST_SIZE); + input.writeInt(1024); + frameReader.readFrame(ctx, input, listener); + + listener.onSettingsRead(ctx, new Http2Settings().maxHeaderListSize(1024)); + } finally { + input.release(); + } + } + + @Test + public void readAckSettingsFrame() throws Http2Exception { + ByteBuf input = Unpooled.buffer(); + try { + writeFrameHeader(input, 0, SETTINGS, new Http2Flags().ack(true), 0); + frameReader.readFrame(ctx, input, listener); + + listener.onSettingsAckRead(ctx); + } finally { + input.release(); + } + } + + @Test + public void failedWhenSettingsFrameOnNonZeroStream() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + writeFrameHeader(input, 6, SETTINGS, new Http2Flags(), 1); + input.writeShort(SETTINGS_MAX_HEADER_LIST_SIZE); + input.writeInt(1024); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void failedWhenAckSettingsFrameWithPayload() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + writeFrameHeader(input, 1, SETTINGS, new Http2Flags().ack(true), 0); + input.writeByte(1); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + @Test + public void failedWhenSettingsFrameWithWrongPayloadLength() throws Http2Exception { + final ByteBuf input = Unpooled.buffer(); + try { + writeFrameHeader(input, 8, SETTINGS, new Http2Flags(), 0); + input.writeInt(SETTINGS_MAX_HEADER_LIST_SIZE); + input.writeInt(1024); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + frameReader.readFrame(ctx, input, listener); + } + }); + } finally { + input.release(); + } + } + + private void 
writeHeaderFrame( + ByteBuf output, int streamId, Http2Headers headers, + Http2Flags flags) throws Http2Exception { + ByteBuf headerBlock = Unpooled.buffer(); + try { + hpackEncoder.encodeHeaders(streamId, headerBlock, headers, Http2HeadersEncoder.NEVER_SENSITIVE); + writeFrameHeader(output, headerBlock.readableBytes(), HEADERS, flags, streamId); + output.writeBytes(headerBlock, headerBlock.readableBytes()); + } finally { + headerBlock.release(); + } + } + + private void writeHeaderFrameWithData( + ByteBuf output, int streamId, Http2Headers headers, + ByteBuf dataPayload) throws Http2Exception { + ByteBuf headerBlock = Unpooled.buffer(); + try { + hpackEncoder.encodeHeaders(streamId, headerBlock, headers, Http2HeadersEncoder.NEVER_SENSITIVE); + writeFrameHeader(output, headerBlock.readableBytes(), HEADERS, + new Http2Flags().endOfHeaders(true), streamId); + output.writeBytes(headerBlock, headerBlock.readableBytes()); + + writeFrameHeader(output, dataPayload.readableBytes(), DATA, new Http2Flags().endOfStream(true), streamId); + output.writeBytes(dataPayload); + } finally { + headerBlock.release(); + } + } + + private void writeHeaderFramePriorityPresent( + ByteBuf output, int streamId, Http2Headers headers, + Http2Flags flags, int streamDependency, int weight) throws Http2Exception { + ByteBuf headerBlock = Unpooled.buffer(); + try { + headerBlock.writeInt(streamDependency); + headerBlock.writeByte(weight - 1); + hpackEncoder.encodeHeaders(streamId, headerBlock, headers, Http2HeadersEncoder.NEVER_SENSITIVE); + writeFrameHeader(output, headerBlock.readableBytes(), HEADERS, flags, streamId); + output.writeBytes(headerBlock, headerBlock.readableBytes()); + } finally { + headerBlock.release(); + } + } + + private void writeContinuationFrame( + ByteBuf output, int streamId, Http2Headers headers, + Http2Flags flags) throws Http2Exception { + ByteBuf headerBlock = Unpooled.buffer(); + try { + hpackEncoder.encodeHeaders(streamId, headerBlock, headers, 
Http2HeadersEncoder.NEVER_SENSITIVE); + writeFrameHeader(output, headerBlock.readableBytes(), CONTINUATION, flags, streamId); + output.writeBytes(headerBlock, headerBlock.readableBytes()); + } finally { + headerBlock.release(); + } + } + + private static void writePriorityFrame( + ByteBuf output, int streamId, int streamDependency, int weight) { + writeFrameHeader(output, 5, PRIORITY, new Http2Flags(), streamId); + output.writeInt(streamDependency); + output.writeByte(weight - 1); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriterTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriterTest.java new file mode 100644 index 0000000..e134fb8 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2FrameWriterTest.java @@ -0,0 +1,390 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.util.Arrays; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link DefaultHttp2FrameWriter}. 
+ */ +public class DefaultHttp2FrameWriterTest { + private DefaultHttp2FrameWriter frameWriter; + + private ByteBuf outbound; + + private ByteBuf expectedOutbound; + + private ChannelPromise promise; + + private Http2HeadersEncoder http2HeadersEncoder; + + @Mock + private Channel channel; + + @Mock + private ChannelFuture future; + + @Mock + private ChannelHandlerContext ctx; + + @BeforeEach + public void setUp() throws Exception { + MockitoAnnotations.initMocks(this); + http2HeadersEncoder = new DefaultHttp2HeadersEncoder( + Http2HeadersEncoder.NEVER_SENSITIVE, new HpackEncoder(false, 16, 0)); + + frameWriter = new DefaultHttp2FrameWriter(new DefaultHttp2HeadersEncoder( + Http2HeadersEncoder.NEVER_SENSITIVE, new HpackEncoder(false, 16, 0))); + + outbound = Unpooled.buffer(); + + expectedOutbound = Unpooled.EMPTY_BUFFER; + + promise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + + Answer answer = new Answer() { + @Override + public Object answer(InvocationOnMock var1) throws Throwable { + Object msg = var1.getArgument(0); + if (msg instanceof ByteBuf) { + outbound.writeBytes((ByteBuf) msg); + } + ReferenceCountUtil.release(msg); + return future; + } + }; + when(ctx.write(any())).then(answer); + when(ctx.write(any(), any(ChannelPromise.class))).then(answer); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(ctx.channel()).thenReturn(channel); + when(ctx.executor()).thenReturn(ImmediateEventExecutor.INSTANCE); + } + + @AfterEach + public void tearDown() throws Exception { + outbound.release(); + expectedOutbound.release(); + frameWriter.close(); + } + + @Test + public void writeHeaders() throws Exception { + int streamId = 1; + Http2Headers headers = new DefaultHttp2Headers() + .method("GET").path("/").authority("foo.com").scheme("https"); + + frameWriter.writeHeaders(ctx, streamId, headers, 0, true, promise); + + byte[] expectedPayload = headerPayload(streamId, headers); + byte[] expectedFrameBytes = { + (byte) 0x00, 
(byte) 0x00, (byte) 0x0a, // payload length = 10 + (byte) 0x01, // payload type = 1 + (byte) 0x05, // flags = (0x01 | 0x04) + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01 // stream id = 1 + }; + expectedOutbound = Unpooled.copiedBuffer(expectedFrameBytes, expectedPayload); + assertEquals(expectedOutbound, outbound); + } + + @Test + public void writeHeadersWithPadding() throws Exception { + int streamId = 1; + Http2Headers headers = new DefaultHttp2Headers() + .method("GET").path("/").authority("foo.com").scheme("https"); + + frameWriter.writeHeaders(ctx, streamId, headers, 5, true, promise); + + byte[] expectedPayload = headerPayload(streamId, headers, (byte) 4); + byte[] expectedFrameBytes = { + (byte) 0x00, (byte) 0x00, (byte) 0x0f, // payload length = 16 + (byte) 0x01, // payload type = 1 + (byte) 0x0d, // flags = (0x01 | 0x04 | 0x08) + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01 // stream id = 1 + }; + expectedOutbound = Unpooled.copiedBuffer(expectedFrameBytes, expectedPayload); + assertEquals(expectedOutbound, outbound); + } + + @Test + public void writeHeadersNotEndStream() throws Exception { + int streamId = 1; + Http2Headers headers = new DefaultHttp2Headers() + .method("GET").path("/").authority("foo.com").scheme("https"); + + frameWriter.writeHeaders(ctx, streamId, headers, 0, false, promise); + + byte[] expectedPayload = headerPayload(streamId, headers); + byte[] expectedFrameBytes = { + (byte) 0x00, (byte) 0x00, (byte) 0x0a, // payload length = 10 + (byte) 0x01, // payload type = 1 + (byte) 0x04, // flags = 0x04 + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01 // stream id = 1 + }; + ByteBuf expectedOutbound = Unpooled.copiedBuffer(expectedFrameBytes, expectedPayload); + assertEquals(expectedOutbound, outbound); + } + + @Test + public void writeEmptyDataWithPadding() { + int streamId = 1; + + ByteBuf payloadByteBuf = Unpooled.buffer(); + frameWriter.writeData(ctx, streamId, payloadByteBuf, 2, true, promise); + + assertEquals(0, 
payloadByteBuf.refCnt()); + + byte[] expectedFrameBytes = { + (byte) 0x00, (byte) 0x00, (byte) 0x02, // payload length + (byte) 0x00, // payload type + (byte) 0x09, // flags + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01, // stream id + (byte) 0x01, (byte) 0x00, // padding + }; + expectedOutbound = Unpooled.copiedBuffer(expectedFrameBytes); + assertEquals(expectedOutbound, outbound); + } + + /** + * Test large headers that exceed {@link DefaultHttp2FrameWriter#maxFrameSize()}; + * the remaining headers will be sent in a CONTINUATION frame + */ + @Test + public void writeLargeHeaders() throws Exception { + int streamId = 1; + Http2Headers headers = new DefaultHttp2Headers() + .method("GET").path("/").authority("foo.com").scheme("https"); + headers = dummyHeaders(headers, 20); + + http2HeadersEncoder.configuration().maxHeaderListSize(Integer.MAX_VALUE); + frameWriter.headersConfiguration().maxHeaderListSize(Integer.MAX_VALUE); + frameWriter.maxFrameSize(Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND); + frameWriter.writeHeaders(ctx, streamId, headers, 0, true, promise); + + byte[] expectedPayload = headerPayload(streamId, headers); + + // First frame: HEADER(length=0x4000, flags=0x01) + assertEquals(Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND, + outbound.readUnsignedMedium()); + assertEquals(0x01, outbound.readByte()); + assertEquals(0x01, outbound.readByte()); + assertEquals(streamId, outbound.readInt()); + + byte[] firstPayload = new byte[Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND]; + outbound.readBytes(firstPayload); + + int remainPayloadLength = expectedPayload.length - Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND; + // Second frame: CONTINUATION(length=remainPayloadLength, flags=0x04) + assertEquals(remainPayloadLength, outbound.readUnsignedMedium()); + assertEquals(0x09, outbound.readByte()); + assertEquals(0x04, outbound.readByte()); + assertEquals(streamId, outbound.readInt()); + + byte[] secondPayload = new byte[remainPayloadLength]; + 
outbound.readBytes(secondPayload); + + assertArrayEquals(Arrays.copyOfRange(expectedPayload, 0, firstPayload.length), + firstPayload); + assertArrayEquals(Arrays.copyOfRange(expectedPayload, firstPayload.length, + expectedPayload.length), + secondPayload); + } + + @Test + public void writeLargeHeaderWithPadding() throws Exception { + int streamId = 1; + Http2Headers headers = new DefaultHttp2Headers() + .method("GET").path("/").authority("foo.com").scheme("https"); + headers = dummyHeaders(headers, 20); + + http2HeadersEncoder.configuration().maxHeaderListSize(Integer.MAX_VALUE); + frameWriter.headersConfiguration().maxHeaderListSize(Integer.MAX_VALUE); + frameWriter.maxFrameSize(Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND); + frameWriter.writeHeaders(ctx, streamId, headers, 5, true, promise); + + byte[] expectedPayload = buildLargeHeaderPayload(streamId, headers, (byte) 4, + Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND); + + // First frame: HEADER(length=0x4000, flags=0x09) + assertEquals(Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND, + outbound.readUnsignedMedium()); + assertEquals(0x01, outbound.readByte()); + assertEquals(0x09, outbound.readByte()); // 0x01 + 0x08 + assertEquals(streamId, outbound.readInt()); + + byte[] firstPayload = new byte[Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND]; + outbound.readBytes(firstPayload); + + int remainPayloadLength = expectedPayload.length - Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND; + // Second frame: CONTINUATION(length=remainPayloadLength, flags=0x04) + assertEquals(remainPayloadLength, outbound.readUnsignedMedium()); + assertEquals(0x09, outbound.readByte()); + assertEquals(0x04, outbound.readByte()); + assertEquals(streamId, outbound.readInt()); + + byte[] secondPayload = new byte[remainPayloadLength]; + outbound.readBytes(secondPayload); + + assertArrayEquals(Arrays.copyOfRange(expectedPayload, 0, firstPayload.length), + firstPayload); + assertArrayEquals(Arrays.copyOfRange(expectedPayload, firstPayload.length, + 
expectedPayload.length), + secondPayload); + } + + @Test + public void writeFrameZeroPayload() throws Exception { + frameWriter.writeFrame(ctx, (byte) 0xf, 0, new Http2Flags(), Unpooled.EMPTY_BUFFER, promise); + + byte[] expectedFrameBytes = { + (byte) 0x00, (byte) 0x00, (byte) 0x00, // payload length + (byte) 0x0f, // payload type + (byte) 0x00, // flags + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00 // stream id + }; + + expectedOutbound = Unpooled.wrappedBuffer(expectedFrameBytes); + assertEquals(expectedOutbound, outbound); + } + + @Test + public void writeFrameHasPayload() throws Exception { + byte[] payload = {(byte) 0x01, (byte) 0x03, (byte) 0x05, (byte) 0x07, (byte) 0x09}; + + // will auto release after frameWriter.writeFrame succeeds + ByteBuf payloadByteBuf = Unpooled.wrappedBuffer(payload); + frameWriter.writeFrame(ctx, (byte) 0xf, 0, new Http2Flags(), payloadByteBuf, promise); + + byte[] expectedFrameHeaderBytes = { + (byte) 0x00, (byte) 0x00, (byte) 0x05, // payload length + (byte) 0x0f, // payload type + (byte) 0x00, // flags + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00 // stream id + }; + expectedOutbound = Unpooled.copiedBuffer(expectedFrameHeaderBytes, payload); + assertEquals(expectedOutbound, outbound); + } + + @Test + public void writePriority() { + frameWriter.writePriority( + ctx, /* streamId= */ 1, /* dependencyId= */ 2, /* weight= */ (short) 256, /* exclusive= */ true, promise); + + expectedOutbound = Unpooled.copiedBuffer(new byte[] { + (byte) 0x00, (byte) 0x00, (byte) 0x05, // payload length = 5 + (byte) 0x02, // payload type = 2 + (byte) 0x00, // flags = 0x00 + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01, // stream id = 1 + (byte) 0x80, (byte) 0x00, (byte) 0x00, (byte) 0x02, // dependency id = 2 | exclusive = 1 << 31 + (byte) 0xFF, // weight = 255 (implicit +1) + }); + assertEquals(expectedOutbound, outbound); + } + + @Test + public void writePriorityDefaults() { + frameWriter.writePriority( + ctx, /* streamId= */ 1, 
/* dependencyId= */ 0, /* weight= */ (short) 16, /* exclusive= */ false, promise); + + expectedOutbound = Unpooled.copiedBuffer(new byte[] { + (byte) 0x00, (byte) 0x00, (byte) 0x05, // payload length = 5 + (byte) 0x02, // payload type = 2 + (byte) 0x00, // flags = 0x00 + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x01, // stream id = 1 + (byte) 0x00, (byte) 0x00, (byte) 0x00, (byte) 0x00, // dependency id = 0 | exclusive = 0 << 31 + (byte) 0x0F, // weight = 15 (implicit +1) + }); + assertEquals(expectedOutbound, outbound); + } + + private byte[] headerPayload(int streamId, Http2Headers headers, byte padding) throws Http2Exception, IOException { + if (padding == 0) { + return headerPayload(streamId, headers); + } + + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + outputStream.write(padding); + outputStream.write(headerPayload(streamId, headers)); + outputStream.write(new byte[padding]); + return outputStream.toByteArray(); + } finally { + outputStream.close(); + } + } + + private byte[] headerPayload(int streamId, Http2Headers headers) throws Http2Exception { + ByteBuf byteBuf = Unpooled.buffer(); + try { + http2HeadersEncoder.encodeHeaders(streamId, headers, byteBuf); + byte[] bytes = new byte[byteBuf.readableBytes()]; + byteBuf.readBytes(bytes); + return bytes; + } finally { + byteBuf.release(); + } + } + + private byte[] buildLargeHeaderPayload(int streamId, Http2Headers headers, byte padding, int maxFrameSize) + throws Http2Exception, IOException { + ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); + try { + outputStream.write(padding); + byte[] payload = headerPayload(streamId, headers); + int firstPayloadSize = maxFrameSize - (padding + 1); //1 for padding length + outputStream.write(payload, 0, firstPayloadSize); + outputStream.write(new byte[padding]); + outputStream.write(payload, firstPayloadSize, payload.length - firstPayloadSize); + return outputStream.toByteArray(); + } finally { + outputStream.close(); + 
} + } + + private static Http2Headers dummyHeaders(Http2Headers headers, int times) { + final String largeValue = repeat("dummy-value", 100); + for (int i = 0; i < times; i++) { + headers.add(String.format("dummy-%d", i), largeValue); + } + return headers; + } + + private static String repeat(String str, int count) { + return String.format(String.format("%%%ds", count), " ").replace(" ", str); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoderTest.java new file mode 100644 index 0000000..1ab87ae --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersDecoderTest.java @@ -0,0 +1,308 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.List; + +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2HeadersEncoder.NEVER_SENSITIVE; +import static io.netty.handler.codec.http2.Http2TestUtil.newTestEncoder; +import static io.netty.handler.codec.http2.Http2TestUtil.randomBytes; +import static io.netty.util.CharsetUtil.UTF_8; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +/** + * Tests for {@link DefaultHttp2HeadersDecoder}. 
+ */ +public class DefaultHttp2HeadersDecoderTest { + + private DefaultHttp2HeadersDecoder decoder; + + @BeforeEach + public void setup() { + decoder = new DefaultHttp2HeadersDecoder(false); + } + + @Test + public void decodeShouldSucceed() throws Exception { + ByteBuf buf = encode(b(":method"), b("GET"), b("akey"), b("avalue"), randomBytes(), randomBytes()); + try { + Http2Headers headers = decoder.decodeHeaders(0, buf); + assertEquals(3, headers.size()); + assertEquals("GET", headers.method().toString()); + assertEquals("avalue", headers.get(new AsciiString("akey")).toString()); + } finally { + buf.release(); + } + } + + @Test + public void testExceedHeaderSize() throws Exception { + final int maxListSize = 100; + decoder.configuration().maxHeaderListSize(maxListSize, maxListSize); + final ByteBuf buf = encode(randomBytes(maxListSize), randomBytes(1)); + + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decodeHeaders(0, buf); + } + }); + } finally { + buf.release(); + } + } + + @Test + public void decodeLargerThanHeaderListSizeButLessThanGoAway() throws Exception { + decoder.maxHeaderListSize(MIN_HEADER_LIST_SIZE, MAX_HEADER_LIST_SIZE); + final ByteBuf buf = encode(b(":method"), b("GET")); + final int streamId = 1; + Http2Exception.HeaderListSizeException e = + assertThrows(Http2Exception.HeaderListSizeException.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decodeHeaders(streamId, buf); + } + }); + assertEquals(streamId, e.streamId()); + buf.release(); + } + + @Test + public void decodeLargerThanHeaderListSizeButLessThanGoAwayWithInitialDecoderSettings() throws Exception { + final ByteBuf buf = encode(b(":method"), b("GET"), b("test_header"), + b(String.format("%09000d", 0).replace('0', 'A'))); + final int streamId = 1; + try { + Http2Exception.HeaderListSizeException e = assertThrows(Http2Exception.HeaderListSizeException.class, + new 
Executable() { + @Override + public void execute() throws Throwable { + decoder.decodeHeaders(streamId, buf); + } + }); + assertEquals(streamId, e.streamId()); + } finally { + buf.release(); + } + } + + @Test + public void decodeLargerThanHeaderListSizeGoAway() throws Exception { + decoder.maxHeaderListSize(MIN_HEADER_LIST_SIZE, MIN_HEADER_LIST_SIZE); + final ByteBuf buf = encode(b(":method"), b("GET")); + final int streamId = 1; + try { + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decodeHeaders(streamId, buf); + } + }); + assertEquals(Http2Error.PROTOCOL_ERROR, e.error()); + } finally { + buf.release(); + } + } + + @Test + public void duplicatePseudoHeadersMustFailValidation() throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true); + verifyValidationFails(decoder, encode(b(":authority"), b("abc"), b(":authority"), b("def"))); + } + + @Test + public void decodingTrailersTeHeaderMustNotFailValidation() throws Exception { + // The TE header is expressly allowed to have the value "trailers". + ByteBuf buf = null; + try { + buf = encode(b(":method"), b("GET"), b("te"), b("trailers")); + Http2Headers headers = decoder.decodeHeaders(1, buf); // This must not throw. 
+ assertThat(headers.get(HttpHeaderNames.TE)).isEqualToIgnoringCase(HttpHeaderValues.TRAILERS); + } finally { + ReferenceCountUtil.release(buf); + } + } + + @Test + public void decodingConnectionRelatedHeadersMustFailValidation() throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true, true); + // Standard connection related headers + verifyValidationFails(decoder, encode(b(":method"), b("GET"), b("keep-alive"), b("timeout=5"))); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), + b("connection"), b("keep-alive"), b("keep-alive"), b("timeout=5"))); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), b("transfer-encoding"), b("chunked"))); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), + b("connection"), b("transfer-encoding"), b("transfer-encoding"), b("chunked"))); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), b("upgrade"), b("foo/2"))); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), + b("connection"), b("upgrade"), b("upgrade"), b("foo/2"))); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), b("connection"), b("close"))); + + // Non-standard connection related headers: + verifyValidationFails(decoder, encode(b(":method"), b("GET"), b("proxy-connection"), b("keep-alive"))); + + // Only "trailers" is allowed for the TE header: + verifyValidationFails(decoder, encode(b(":method"), b("GET"), b("te"), b("compress"))); + } + + public static List illegalFirstChar() { + ArrayList list = new ArrayList(); + for (int i = 0; i < 0x21; i++) { + list.add(i); + } + list.add(0x7F); + return list; + } + + @ParameterizedTest + @MethodSource("illegalFirstChar") + void decodingInvalidHeaderValueMustFailValidationIfFirstCharIsIllegal(int illegalFirstChar)throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true, true); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), + b("test_header"), new 
byte[]{ (byte) illegalFirstChar, (byte) 'a' })); + } + + public static List illegalNotFirstChar() { + ArrayList list = new ArrayList(); + for (int i = 0; i < 0x21; i++) { + if (i == ' ' || i == '\t') { + continue; // Space and horizontal tab are only illegal as first chars. + } + list.add(i); + } + list.add(0x7F); + return list; + } + + @ParameterizedTest + @MethodSource("illegalNotFirstChar") + void decodingInvalidHeaderValueMustFailValidationIfANotFirstCharIsIllegal(int illegalSecondChar) throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true, true); + verifyValidationFails(decoder, encode(b(":method"), b("GET"), + b("test_header"), new byte[]{ (byte) 'a', (byte) illegalSecondChar })); + } + + @Test + public void headerValuesAllowSpaceAfterFirstCharacter() throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true); + ByteBuf buf = null; + try { + buf = encode(b(":method"), b("GET"), b("test_header"), b("a b")); + Http2Headers headers = decoder.decodeHeaders(1, buf); // This must not throw. + assertThat(headers.get("test_header")).isEqualToIgnoringCase("a b"); + } finally { + ReferenceCountUtil.release(buf); + } + } + + @Test + public void headerValuesAllowHorzontalTabAfterFirstCharacter() throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true); + ByteBuf buf = null; + try { + buf = encode(b(":method"), b("GET"), b("test_header"), b("a\tb")); + Http2Headers headers = decoder.decodeHeaders(1, buf); // This must not throw. 
+ assertThat(headers.get("test_header")).isEqualToIgnoringCase("a\tb"); + } finally { + ReferenceCountUtil.release(buf); + } + } + + public static List validObsText() { + ArrayList list = new ArrayList(); + for (int i = 0x80; i <= 0xFF; i++) { + list.add(i); + } + return list; + } + + @ParameterizedTest + @MethodSource("validObsText") + void headerValuesAllowObsTextInFirstChar(int i) throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true); + ByteBuf buf = null; + try { + byte[] bytes = {(byte) i, 'a'}; + buf = encode(b(":method"), b("GET"), b("test_header"), bytes); + Http2Headers headers = decoder.decodeHeaders(1, buf); // This must not throw. + assertThat(headers.get("test_header")).isEqualTo(new AsciiString(bytes)); + } finally { + ReferenceCountUtil.release(buf); + } + } + + @ParameterizedTest + @MethodSource("validObsText") + void headerValuesAllowObsTextInNonFirstChar(int i) throws Exception { + final DefaultHttp2HeadersDecoder decoder = new DefaultHttp2HeadersDecoder(true); + ByteBuf buf = null; + try { + byte[] bytes = {(byte) 'a', (byte) i}; + buf = encode(b(":method"), b("GET"), b("test_header"), bytes); + Http2Headers headers = decoder.decodeHeaders(1, buf); // This must not throw. + assertThat(headers.get("test_header")).isEqualTo(new AsciiString(bytes)); + } finally { + ReferenceCountUtil.release(buf); + } + } + + private static void verifyValidationFails(final DefaultHttp2HeadersDecoder decoder, final ByteBuf buf) { + try { + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decoder.decodeHeaders(1, buf); + } + }); + assertEquals(Http2Error.PROTOCOL_ERROR, e.error()); + } finally { + buf.release(); + } + } + + private static byte[] b(String string) { + return string.getBytes(UTF_8); + } + + private static ByteBuf encode(byte[]... 
entries) throws Exception { + HpackEncoder hpackEncoder = newTestEncoder(); + ByteBuf out = Unpooled.buffer(); + Http2Headers http2Headers = new DefaultHttp2Headers(false); + for (int ix = 0; ix < entries.length;) { + http2Headers.add(new AsciiString(entries[ix++], false), new AsciiString(entries[ix++], false)); + } + hpackEncoder.encodeHeaders(3 /* randomly chosen */, out, http2Headers, NEVER_SENSITIVE); + return out; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoderTest.java new file mode 100644 index 0000000..3091c86 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersEncoderTest.java @@ -0,0 +1,70 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.http2.Http2Exception.StreamException; +import io.netty.util.AsciiString; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static io.netty.handler.codec.http2.Http2TestUtil.newTestEncoder; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for {@link DefaultHttp2HeadersEncoder}. + */ +public class DefaultHttp2HeadersEncoderTest { + + private DefaultHttp2HeadersEncoder encoder; + + @BeforeEach + public void setup() { + encoder = new DefaultHttp2HeadersEncoder(Http2HeadersEncoder.NEVER_SENSITIVE, newTestEncoder()); + } + + @Test + public void encodeShouldSucceed() throws Http2Exception { + Http2Headers headers = headers(); + ByteBuf buf = Unpooled.buffer(); + try { + encoder.encodeHeaders(3 /* randomly chosen */, headers, buf); + assertTrue(buf.writerIndex() > 0); + } finally { + buf.release(); + } + } + + @Test + public void headersExceedMaxSetSizeShouldFail() throws Http2Exception { + final Http2Headers headers = headers(); + encoder.maxHeaderListSize(2); + assertThrows(StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + encoder.encodeHeaders(3 /* randomly chosen */, headers, Unpooled.buffer()); + } + }); + } + + private static Http2Headers headers() { + return new DefaultHttp2Headers().method(new AsciiString("GET")).add(new AsciiString("a"), new AsciiString("1")) + .add(new AsciiString("a"), new AsciiString("2")); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersTest.java new file mode 100644 index 0000000..612f1d8 --- /dev/null +++ 
b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2HeadersTest.java @@ -0,0 +1,253 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http2.Http2Headers.PseudoHeaderName; +import io.netty.util.AsciiString; +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; + +import java.util.Map.Entry; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONTENT_LENGTH; +import static io.netty.util.AsciiString.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class DefaultHttp2HeadersTest { + + @Test + public void nullHeaderNameNotAllowed() { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + new DefaultHttp2Headers().add(null, "foo"); + } + }); + } + + @Test + public void emptyHeaderNameNotAllowed() { + 
assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + new DefaultHttp2Headers().add(StringUtil.EMPTY_STRING, "foo"); + } + }); + } + + @Test + public void testPseudoHeadersMustComeFirstWhenIterating() { + Http2Headers headers = newHeaders(); + + verifyPseudoHeadersFirst(headers); + verifyAllPseudoHeadersPresent(headers); + } + + @Test + public void testPseudoHeadersWithRemovePreservesPseudoIterationOrder() { + Http2Headers headers = newHeaders(); + + Http2Headers nonPseudoHeaders = new DefaultHttp2Headers(); + for (Entry entry : headers) { + if (entry.getKey().length() == 0 || entry.getKey().charAt(0) != ':' && + !nonPseudoHeaders.contains(entry.getKey())) { + nonPseudoHeaders.add(entry.getKey(), entry.getValue()); + } + } + + assertFalse(nonPseudoHeaders.isEmpty()); + + // Remove all the non-pseudo headers and verify + for (Entry nonPseudoHeaderEntry : nonPseudoHeaders) { + assertTrue(headers.remove(nonPseudoHeaderEntry.getKey())); + verifyPseudoHeadersFirst(headers); + verifyAllPseudoHeadersPresent(headers); + } + + // Add back all non-pseudo headers + for (Entry nonPseudoHeaderEntry : nonPseudoHeaders) { + headers.add(nonPseudoHeaderEntry.getKey(), of("goo")); + verifyPseudoHeadersFirst(headers); + verifyAllPseudoHeadersPresent(headers); + } + } + + @Test + public void testPseudoHeadersWithClearDoesNotLeak() { + Http2Headers headers = newHeaders(); + + assertFalse(headers.isEmpty()); + headers.clear(); + assertTrue(headers.isEmpty()); + + // Combine 2 headers together, make sure pseudo headers stay up front. + headers.add("name1", "value1").scheme("nothing"); + verifyPseudoHeadersFirst(headers); + + Http2Headers other = new DefaultHttp2Headers().add("name2", "value2").authority("foo"); + verifyPseudoHeadersFirst(other); + + headers.add(other); + verifyPseudoHeadersFirst(headers); + + // Make sure the headers are what we expect them to be, and no leaking behind the scenes. 
+ assertEquals(4, headers.size()); + assertEquals("value1", headers.get("name1")); + assertEquals("value2", headers.get("name2")); + assertEquals("nothing", headers.scheme()); + assertEquals("foo", headers.authority()); + } + + @Test + public void testSetHeadersOrdersPseudoHeadersCorrectly() { + Http2Headers headers = newHeaders(); + Http2Headers other = new DefaultHttp2Headers().add("name2", "value2").authority("foo"); + + headers.set(other); + verifyPseudoHeadersFirst(headers); + assertEquals(other.size(), headers.size()); + assertEquals("foo", headers.authority()); + assertEquals("value2", headers.get("name2")); + } + + @Test + public void testSetAllOrdersPseudoHeadersCorrectly() { + Http2Headers headers = newHeaders(); + Http2Headers other = new DefaultHttp2Headers().add("name2", "value2").authority("foo"); + + int headersSizeBefore = headers.size(); + headers.setAll(other); + verifyPseudoHeadersFirst(headers); + verifyAllPseudoHeadersPresent(headers); + assertEquals(headersSizeBefore + 1, headers.size()); + assertEquals("foo", headers.authority()); + assertEquals("value2", headers.get("name2")); + } + + @Test + public void testHeaderNameValidation() { + final Http2Headers headers = newHeaders(); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + headers.add(of("Foo"), of("foo")); + } + }); + } + + @Test + public void testClearResetsPseudoHeaderDivision() { + DefaultHttp2Headers http2Headers = new DefaultHttp2Headers(); + http2Headers.method("POST"); + http2Headers.set("some", "value"); + http2Headers.clear(); + http2Headers.method("GET"); + assertEquals(1, http2Headers.names().size()); + } + + @Test + public void testContainsNameAndValue() { + Http2Headers headers = newHeaders(); + assertTrue(headers.contains("name1", "value2")); + assertFalse(headers.contains("name1", "Value2")); + assertTrue(headers.contains("2name", "Value3", true)); + assertFalse(headers.contains("2name", "Value3", false)); + 
} + + @Test + public void testContainsName() { + Http2Headers headers = new DefaultHttp2Headers(); + headers.add(CONTENT_LENGTH, "36"); + assertFalse(headers.contains("Content-Length")); + assertTrue(headers.contains("content-length")); + assertTrue(headers.contains(CONTENT_LENGTH)); + headers.remove(CONTENT_LENGTH); + assertFalse(headers.contains("Content-Length")); + assertFalse(headers.contains("content-length")); + assertFalse(headers.contains(CONTENT_LENGTH)); + + assertFalse(headers.contains("non-existent-name")); + assertFalse(headers.contains(new AsciiString("non-existent-name"))); + } + + @Test + void setMustOverwritePseudoHeaders() { + Http2Headers headers = newHeaders(); + // The headers are already populated with pseudo headers. + headers.method(of("GET")); + headers.path(of("/index2.html")); + headers.status(of("101")); + headers.authority(of("github.com")); + headers.scheme(of("http")); + headers.set(of(":protocol"), of("http")); + assertEquals(of("GET"), headers.method()); + assertEquals(of("/index2.html"), headers.path()); + assertEquals(of("101"), headers.status()); + assertEquals(of("github.com"), headers.authority()); + assertEquals(of("http"), headers.scheme()); + } + + @ParameterizedTest(name = "{displayName} [{index}] name={0} value={1}") + @CsvSource(value = {"upgrade,protocol1", "connection,close", "keep-alive,timeout=5", "proxy-connection,close", + "transfer-encoding,chunked", "te,something-else"}) + void possibleToAddConnectionHeaders(String name, String value) { + Http2Headers headers = newHeaders(); + headers.add(name, value); + assertTrue(headers.contains(name, value)); + } + + private static void verifyAllPseudoHeadersPresent(Http2Headers headers) { + for (PseudoHeaderName pseudoName : PseudoHeaderName.values()) { + assertNotNull(headers.get(pseudoName.value())); + } + } + + static void verifyPseudoHeadersFirst(Http2Headers headers) { + CharSequence lastNonPseudoName = null; + for (Entry entry: headers) { + if (entry.getKey().length() 
== 0 || entry.getKey().charAt(0) != ':') { + lastNonPseudoName = entry.getKey(); + } else if (lastNonPseudoName != null) { + fail("All pseudo headers must be first in iteration. Pseudo header " + entry.getKey() + + " is after a non pseudo header " + lastNonPseudoName); + } + } + } + + private static Http2Headers newHeaders() { + Http2Headers headers = new DefaultHttp2Headers(); + headers.add(of("name1"), of("value1"), of("value2")); + headers.method(of("POST")); + headers.add(of("2name"), of("value3")); + headers.path(of("/index.html")); + headers.status(of("200")); + headers.authority(of("netty.io")); + headers.add(of("name3"), of("value4")); + headers.scheme(of("https")); + headers.add(of(":protocol"), of("websocket")); + return headers; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowControllerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowControllerTest.java new file mode 100644 index 0000000..2117c85 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2LocalFlowControllerTest.java @@ -0,0 +1,460 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import static io.netty.handler.codec.http2.DefaultHttp2LocalFlowController.DEFAULT_WINDOW_UPDATE_RATIO; +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.util.concurrent.EventExecutor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +/** + * Tests for {@link DefaultHttp2LocalFlowController}. 
+ */ +public class DefaultHttp2LocalFlowControllerTest { + private static final int STREAM_ID = 1; + + private DefaultHttp2LocalFlowController controller; + + @Mock + private Http2FrameWriter frameWriter; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private EventExecutor executor; + + @Mock + private ChannelPromise promise; + + private DefaultHttp2Connection connection; + + @BeforeEach + public void setup() throws Http2Exception { + MockitoAnnotations.initMocks(this); + setupChannelHandlerContext(false); + when(executor.inEventLoop()).thenReturn(true); + + initController(false); + } + + private void setupChannelHandlerContext(boolean allowFlush) { + reset(ctx); + when(ctx.newPromise()).thenReturn(promise); + if (allowFlush) { + when(ctx.flush()).then(new Answer() { + @Override + public ChannelHandlerContext answer(InvocationOnMock invocationOnMock) { + return ctx; + } + }); + } else { + when(ctx.flush()).then(new Answer() { + @Override + public ChannelHandlerContext answer(InvocationOnMock invocationOnMock) { + fail("forbidden"); + return null; + } + }); + } + when(ctx.executor()).thenReturn(executor); + } + + @Test + public void dataFrameShouldBeAccepted() throws Http2Exception { + receiveFlowControlledFrame(STREAM_ID, 10, 0, false); + verifyWindowUpdateNotSent(); + } + + @Test + public void windowUpdateShouldSendOnceBytesReturned() throws Http2Exception { + int dataSize = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + receiveFlowControlledFrame(STREAM_ID, dataSize, 0, false); + + // Return only a few bytes and verify that the WINDOW_UPDATE hasn't been sent. + assertFalse(consumeBytes(STREAM_ID, 10)); + verifyWindowUpdateNotSent(STREAM_ID); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + + // Return the rest and verify the WINDOW_UPDATE is sent. 
+ assertTrue(consumeBytes(STREAM_ID, dataSize - 10)); + verifyWindowUpdateSent(STREAM_ID, dataSize); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, dataSize); + verifyNoMoreInteractions(frameWriter); + } + + @Test + public void connectionWindowShouldAutoRefillWhenDataReceived() throws Http2Exception { + // Reconfigure controller to auto-refill the connection window. + initController(true); + + int dataSize = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + receiveFlowControlledFrame(STREAM_ID, dataSize, 0, false); + // Verify that we immediately refill the connection window. + verifyWindowUpdateSent(CONNECTION_STREAM_ID, dataSize); + + // Return only a few bytes and verify that the WINDOW_UPDATE hasn't been sent for the stream. + assertFalse(consumeBytes(STREAM_ID, 10)); + verifyWindowUpdateNotSent(STREAM_ID); + + // Return the rest and verify the WINDOW_UPDATE is sent for the stream. + assertTrue(consumeBytes(STREAM_ID, dataSize - 10)); + verifyWindowUpdateSent(STREAM_ID, dataSize); + verifyNoMoreInteractions(frameWriter); + } + + @Test + public void connectionFlowControlExceededShouldThrow() throws Http2Exception { + // Window exceeded because of the padding. + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + receiveFlowControlledFrame(STREAM_ID, DEFAULT_WINDOW_SIZE, 1, true); + } + }); + } + + @Test + public void windowUpdateShouldNotBeSentAfterEndOfStream() throws Http2Exception { + int dataSize = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + + // Set end-of-stream on the frame, so no window update will be sent for the stream. 
+ receiveFlowControlledFrame(STREAM_ID, dataSize, 0, true); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + verifyWindowUpdateNotSent(STREAM_ID); + + assertTrue(consumeBytes(STREAM_ID, dataSize)); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, dataSize); + verifyWindowUpdateNotSent(STREAM_ID); + } + + @Test + public void windowUpdateShouldNotBeSentAfterStreamIsClosedForUnconsumedBytes() throws Http2Exception { + int dataSize = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + + // Don't set end-of-stream on the frame as we want to verify that we not return the unconsumed bytes in this + // case once the stream was closed, + receiveFlowControlledFrame(STREAM_ID, dataSize, 0, false); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + verifyWindowUpdateNotSent(STREAM_ID); + + // Close the stream + Http2Stream stream = connection.stream(STREAM_ID); + stream.close(); + assertEquals(State.CLOSED, stream.state()); + assertNull(connection.stream(STREAM_ID)); + + // The window update for the connection should made it through but not the update for the already closed + // stream + verifyWindowUpdateSent(CONNECTION_STREAM_ID, dataSize); + verifyWindowUpdateNotSent(STREAM_ID); + } + + @Test + public void windowUpdateShouldBeWrittenWhenStreamIsClosedAndFlushed() throws Http2Exception { + int dataSize = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + + setupChannelHandlerContext(true); + + receiveFlowControlledFrame(STREAM_ID, dataSize, 0, false); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + verifyWindowUpdateNotSent(STREAM_ID); + + connection.stream(STREAM_ID).close(); + + verifyWindowUpdateSent(CONNECTION_STREAM_ID, dataSize); + + // Verify we saw one flush. 
+ verify(ctx).flush(); + } + + @Test + public void halfWindowRemainingShouldUpdateAllWindows() throws Http2Exception { + int dataSize = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + int initialWindowSize = DEFAULT_WINDOW_SIZE; + int windowDelta = getWindowDelta(initialWindowSize, initialWindowSize, dataSize); + + // Don't set end-of-stream so we'll get a window update for the stream as well. + receiveFlowControlledFrame(STREAM_ID, dataSize, 0, false); + assertTrue(consumeBytes(STREAM_ID, dataSize)); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, windowDelta); + verifyWindowUpdateSent(STREAM_ID, windowDelta); + } + + @Test + public void initialWindowUpdateShouldAllowMoreFrames() throws Http2Exception { + // Send a frame that takes up the entire window. + int initialWindowSize = DEFAULT_WINDOW_SIZE; + receiveFlowControlledFrame(STREAM_ID, initialWindowSize, 0, false); + assertEquals(0, window(STREAM_ID)); + assertEquals(0, window(CONNECTION_STREAM_ID)); + consumeBytes(STREAM_ID, initialWindowSize); + assertEquals(initialWindowSize, window(STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE, window(CONNECTION_STREAM_ID)); + + // Update the initial window size to allow another frame. + int newInitialWindowSize = 2 * initialWindowSize; + controller.initialWindowSize(newInitialWindowSize); + assertEquals(newInitialWindowSize, window(STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE, window(CONNECTION_STREAM_ID)); + + // Clear any previous calls to the writer. + reset(frameWriter); + + // Send the next frame and verify that the expected window updates were sent. 
+ receiveFlowControlledFrame(STREAM_ID, initialWindowSize, 0, false); + assertTrue(consumeBytes(STREAM_ID, initialWindowSize)); + int delta = newInitialWindowSize - initialWindowSize; + verifyWindowUpdateSent(STREAM_ID, delta); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, delta); + } + + @Test + public void connectionWindowShouldAdjustWithMultipleStreams() throws Http2Exception { + int newStreamId = 3; + connection.local().createStream(newStreamId, false); + + try { + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE, window(CONNECTION_STREAM_ID)); + + // Test that both stream and connection window are updated (or not updated) together + int data1 = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) + 1; + receiveFlowControlledFrame(STREAM_ID, data1, 0, false); + verifyWindowUpdateNotSent(STREAM_ID); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + assertEquals(DEFAULT_WINDOW_SIZE - data1, window(STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE - data1, window(CONNECTION_STREAM_ID)); + assertTrue(consumeBytes(STREAM_ID, data1)); + verifyWindowUpdateSent(STREAM_ID, data1); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, data1); + + reset(frameWriter); + + // Create a scenario where data is depleted from multiple streams, but not enough data + // to generate a window update on those streams. The amount will be enough to generate + // a window update for the connection stream. 
+ --data1; + int data2 = data1 >> 1; + receiveFlowControlledFrame(STREAM_ID, data1, 0, false); + receiveFlowControlledFrame(newStreamId, data1, 0, false); + verifyWindowUpdateNotSent(STREAM_ID); + verifyWindowUpdateNotSent(newStreamId); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + assertEquals(DEFAULT_WINDOW_SIZE - data1, window(STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE - data1, window(newStreamId)); + assertEquals(DEFAULT_WINDOW_SIZE - (data1 << 1), window(CONNECTION_STREAM_ID)); + assertFalse(consumeBytes(STREAM_ID, data1)); + assertTrue(consumeBytes(newStreamId, data2)); + verifyWindowUpdateNotSent(STREAM_ID); + verifyWindowUpdateNotSent(newStreamId); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, data1 + data2); + assertEquals(DEFAULT_WINDOW_SIZE - data1, window(STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE - data1, window(newStreamId)); + assertEquals(DEFAULT_WINDOW_SIZE - (data1 - data2), window(CONNECTION_STREAM_ID)); + } finally { + connection.stream(newStreamId).close(); + } + } + + @Test + public void closeShouldConsumeBytes() throws Http2Exception { + receiveFlowControlledFrame(STREAM_ID, 10, 0, false); + assertEquals(10, controller.unconsumedBytes(connection.connectionStream())); + stream(STREAM_ID).close(); + assertEquals(0, controller.unconsumedBytes(connection.connectionStream())); + } + + @Test + public void closeShouldNotConsumeConnectionWindowWhenAutoRefilled() throws Http2Exception { + // Reconfigure controller to auto-refill the connection window. 
+ initController(true); + + receiveFlowControlledFrame(STREAM_ID, 10, 0, false); + assertEquals(0, controller.unconsumedBytes(connection.connectionStream())); + stream(STREAM_ID).close(); + assertEquals(0, controller.unconsumedBytes(connection.connectionStream())); + } + + @Test + public void dataReceivedForClosedStreamShouldImmediatelyConsumeBytes() throws Http2Exception { + Http2Stream stream = stream(STREAM_ID); + stream.close(); + receiveFlowControlledFrame(stream, 10, 0, false); + assertEquals(0, controller.unconsumedBytes(connection.connectionStream())); + } + + @Test + public void dataReceivedForNullStreamShouldImmediatelyConsumeBytes() throws Http2Exception { + receiveFlowControlledFrame(null, 10, 0, false); + assertEquals(0, controller.unconsumedBytes(connection.connectionStream())); + } + + @Test + public void consumeBytesForNullStreamShouldIgnore() throws Http2Exception { + controller.consumeBytes(null, 10); + assertEquals(0, controller.unconsumedBytes(connection.connectionStream())); + } + + @Test + public void globalRatioShouldImpactStreams() throws Http2Exception { + float ratio = 0.6f; + controller.windowUpdateRatio(ratio); + testRatio(ratio, DEFAULT_WINDOW_SIZE << 1, 3, false); + } + + @Test + public void streamlRatioShouldImpactStreams() throws Http2Exception { + float ratio = 0.6f; + testRatio(ratio, DEFAULT_WINDOW_SIZE << 1, 3, true); + } + + @Test + public void consumeBytesForZeroNumBytesShouldIgnore() throws Http2Exception { + assertFalse(controller.consumeBytes(connection.stream(STREAM_ID), 0)); + } + + @Test + public void consumeBytesForNegativeNumBytesShouldFail() throws Http2Exception { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.consumeBytes(connection.stream(STREAM_ID), -1); + } + }); + } + + private void testRatio(float ratio, int newDefaultWindowSize, int newStreamId, boolean setStreamRatio) + throws Http2Exception { + int delta = 
newDefaultWindowSize - DEFAULT_WINDOW_SIZE; + controller.incrementWindowSize(stream(0), delta); + Http2Stream stream = connection.local().createStream(newStreamId, false); + if (setStreamRatio) { + controller.windowUpdateRatio(stream, ratio); + } + controller.incrementWindowSize(stream, delta); + reset(frameWriter); + try { + int data1 = (int) (newDefaultWindowSize * ratio) + 1; + int data2 = (int) (DEFAULT_WINDOW_SIZE * DEFAULT_WINDOW_UPDATE_RATIO) >> 1; + receiveFlowControlledFrame(STREAM_ID, data2, 0, false); + receiveFlowControlledFrame(newStreamId, data1, 0, false); + verifyWindowUpdateNotSent(STREAM_ID); + verifyWindowUpdateNotSent(newStreamId); + verifyWindowUpdateNotSent(CONNECTION_STREAM_ID); + assertEquals(DEFAULT_WINDOW_SIZE - data2, window(STREAM_ID)); + assertEquals(newDefaultWindowSize - data1, window(newStreamId)); + assertEquals(newDefaultWindowSize - data2 - data1, window(CONNECTION_STREAM_ID)); + assertFalse(consumeBytes(STREAM_ID, data2)); + assertTrue(consumeBytes(newStreamId, data1)); + verifyWindowUpdateNotSent(STREAM_ID); + verifyWindowUpdateSent(newStreamId, data1); + verifyWindowUpdateSent(CONNECTION_STREAM_ID, data1 + data2); + assertEquals(DEFAULT_WINDOW_SIZE - data2, window(STREAM_ID)); + assertEquals(newDefaultWindowSize, window(newStreamId)); + assertEquals(newDefaultWindowSize, window(CONNECTION_STREAM_ID)); + } finally { + connection.stream(newStreamId).close(); + } + } + + private static int getWindowDelta(int initialSize, int windowSize, int dataSize) { + int newWindowSize = windowSize - dataSize; + return initialSize - newWindowSize; + } + + private void receiveFlowControlledFrame(int streamId, int dataSize, int padding, + boolean endOfStream) throws Http2Exception { + receiveFlowControlledFrame(stream(streamId), dataSize, padding, endOfStream); + } + + private void receiveFlowControlledFrame(Http2Stream stream, int dataSize, int padding, + boolean endOfStream) throws Http2Exception { + final ByteBuf buf = dummyData(dataSize); + 
try { + controller.receiveFlowControlledFrame(stream, buf, padding, endOfStream); + } finally { + buf.release(); + } + } + + private static ByteBuf dummyData(int size) { + final ByteBuf buffer = Unpooled.buffer(size); + buffer.writerIndex(size); + return buffer; + } + + private boolean consumeBytes(int streamId, int numBytes) throws Http2Exception { + return controller.consumeBytes(stream(streamId), numBytes); + } + + private void verifyWindowUpdateSent(int streamId, int windowSizeIncrement) { + verify(frameWriter).writeWindowUpdate(eq(ctx), eq(streamId), eq(windowSizeIncrement), eq(promise)); + } + + private void verifyWindowUpdateNotSent(int streamId) { + verify(frameWriter, never()).writeWindowUpdate(eq(ctx), eq(streamId), anyInt(), eq(promise)); + } + + private void verifyWindowUpdateNotSent() { + verify(frameWriter, never()).writeWindowUpdate(any(ChannelHandlerContext.class), anyInt(), anyInt(), + any(ChannelPromise.class)); + } + + private int window(int streamId) { + return controller.windowSize(stream(streamId)); + } + + private Http2Stream stream(int streamId) { + return connection.stream(streamId); + } + + private void initController(boolean autoRefillConnectionWindow) throws Http2Exception { + connection = new DefaultHttp2Connection(false); + controller = new DefaultHttp2LocalFlowController(connection, + DEFAULT_WINDOW_UPDATE_RATIO, autoRefillConnectionWindow).frameWriter(frameWriter); + connection.local().flowController(controller); + connection.local().createStream(STREAM_ID, false); + controller.channelHandlerContext(ctx); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java new file mode 100644 index 0000000..04acf60 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2PushPromiseFrameTest.java @@ -0,0 +1,239 @@ +/* + * 
Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class DefaultHttp2PushPromiseFrameTest { + + private final EventLoopGroup eventLoopGroup = new NioEventLoopGroup(2); + private final ClientHandler clientHandler = new ClientHandler(); + private final Map contentMap = new ConcurrentHashMap(); + + @BeforeEach + public void setup() throws InterruptedException { + ServerBootstrap 
serverBootstrap = new ServerBootstrap() + .group(eventLoopGroup) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); + + Http2FrameCodec frameCodec = Http2FrameCodecBuilder.forServer() + .autoAckSettingsFrame(true) + .autoAckPingFrame(true) + .build(); + + pipeline.addLast(frameCodec); + pipeline.addLast(new ServerHandler()); + } + }); + + ChannelFuture channelFuture = serverBootstrap.bind(0).sync(); + + final Bootstrap bootstrap = new Bootstrap() + .group(eventLoopGroup) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(SocketChannel ch) { + ChannelPipeline pipeline = ch.pipeline(); + + Http2FrameCodec frameCodec = Http2FrameCodecBuilder.forClient() + .autoAckSettingsFrame(true) + .autoAckPingFrame(true) + .initialSettings(Http2Settings.defaultSettings().pushEnabled(true)) + .build(); + + pipeline.addLast(frameCodec); + pipeline.addLast(clientHandler); + } + }); + + bootstrap.connect(channelFuture.channel().localAddress()).sync(); + } + + @Test + public void send() throws Exception { + clientHandler.write(); + } + + @AfterEach + public void shutdown() { + eventLoopGroup.shutdownGracefully(); + } + + private final class ServerHandler extends Http2ChannelDuplexHandler { + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { + + if (msg instanceof Http2HeadersFrame) { + final Http2HeadersFrame receivedFrame = (Http2HeadersFrame) msg; + + Http2Headers pushRequestHeaders = new DefaultHttp2Headers(); + pushRequestHeaders.path("/meow") + .method("GET") + .scheme("https") + .authority("localhost:5555"); + + // Write PUSH_PROMISE request headers + final Http2FrameStream newPushFrameStream = newStream(); + Http2PushPromiseFrame pushPromiseFrame = new DefaultHttp2PushPromiseFrame(pushRequestHeaders); + 
pushPromiseFrame.stream(receivedFrame.stream()); + pushPromiseFrame.pushStream(newPushFrameStream); + ctx.writeAndFlush(pushPromiseFrame).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + contentMap.put(newPushFrameStream.id(), "Meow, I am Pushed via HTTP/2"); + + // Write headers for actual request + Http2Headers http2Headers = new DefaultHttp2Headers(); + http2Headers.status("200"); + http2Headers.add("push", "false"); + Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(http2Headers, false); + headersFrame.stream(receivedFrame.stream()); + ChannelFuture channelFuture = ctx.writeAndFlush(headersFrame); + + // Write Data of actual request + channelFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Http2DataFrame dataFrame = new DefaultHttp2DataFrame( + Unpooled.wrappedBuffer("Meow".getBytes()), true); + dataFrame.stream(receivedFrame.stream()); + ctx.writeAndFlush(dataFrame); + } + }); + } + }); + } else if (msg instanceof Http2PriorityFrame) { + Http2PriorityFrame priorityFrame = (Http2PriorityFrame) msg; + String content = contentMap.get(priorityFrame.stream().id()); + if (content == null) { + ctx.writeAndFlush(new DefaultHttp2GoAwayFrame(Http2Error.REFUSED_STREAM)); + return; + } + + // Write headers for Priority request + Http2Headers http2Headers = new DefaultHttp2Headers(); + http2Headers.status("200"); + http2Headers.add("push", "true"); + Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(http2Headers, false); + headersFrame.stream(priorityFrame.stream()); + ctx.writeAndFlush(headersFrame); + + // Write Data of Priority request + Http2DataFrame dataFrame = new DefaultHttp2DataFrame(Unpooled.wrappedBuffer(content.getBytes()), true); + dataFrame.stream(priorityFrame.stream()); + ctx.writeAndFlush(dataFrame); + } + } + } + + private static final class ClientHandler extends 
Http2ChannelDuplexHandler { + + private final CountDownLatch latch = new CountDownLatch(1); + private volatile ChannelHandlerContext ctx; + + @Override + public void channelActive(ChannelHandlerContext ctx) throws InterruptedException { + this.ctx = ctx; + latch.countDown(); + } + + void write() throws InterruptedException { + latch.await(); + Http2Headers http2Headers = new DefaultHttp2Headers(); + http2Headers.path("/") + .authority("localhost") + .method("GET") + .scheme("https"); + + Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(http2Headers, true); + headersFrame.stream(newStream()); + ctx.writeAndFlush(headersFrame); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + + if (msg instanceof Http2PushPromiseFrame) { + Http2PushPromiseFrame pushPromiseFrame = (Http2PushPromiseFrame) msg; + + assertEquals("/meow", pushPromiseFrame.http2Headers().path().toString()); + assertEquals("GET", pushPromiseFrame.http2Headers().method().toString()); + assertEquals("https", pushPromiseFrame.http2Headers().scheme().toString()); + assertEquals("localhost:5555", pushPromiseFrame.http2Headers().authority().toString()); + + Http2PriorityFrame priorityFrame = new DefaultHttp2PriorityFrame(pushPromiseFrame.stream().id(), + Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT, true); + priorityFrame.stream(pushPromiseFrame.pushStream()); + ctx.writeAndFlush(priorityFrame); + } else if (msg instanceof Http2HeadersFrame) { + Http2HeadersFrame headersFrame = (Http2HeadersFrame) msg; + + if (headersFrame.stream().id() == 3) { + assertEquals("200", headersFrame.headers().status().toString()); + assertEquals("false", headersFrame.headers().get("push").toString()); + } else if (headersFrame.stream().id() == 2) { + assertEquals("200", headersFrame.headers().status().toString()); + assertEquals("true", headersFrame.headers().get("push").toString()); + } else { + ctx.writeAndFlush(new DefaultHttp2GoAwayFrame(Http2Error.REFUSED_STREAM)); + } + } else if 
(msg instanceof Http2DataFrame) { + Http2DataFrame dataFrame = (Http2DataFrame) msg; + + try { + if (dataFrame.stream().id() == 3) { + assertEquals("Meow", dataFrame.content().toString(CharsetUtil.UTF_8)); + } else if (dataFrame.stream().id() == 2) { + assertEquals("Meow, I am Pushed via HTTP/2", dataFrame.content().toString(CharsetUtil.UTF_8)); + } else { + ctx.writeAndFlush(new DefaultHttp2GoAwayFrame(Http2Error.REFUSED_STREAM)); + } + } finally { + ReferenceCountUtil.release(dataFrame); + } + } + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java new file mode 100644 index 0000000..0ded0e1 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/DefaultHttp2RemoteFlowControllerTest.java @@ -0,0 +1,1146 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.EventExecutor; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.opentest4j.AssertionFailedError; + +import java.util.concurrent.atomic.AtomicInteger; + +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +/** + * 
Tests for {@link DefaultHttp2RemoteFlowController}. + */ +public abstract class DefaultHttp2RemoteFlowControllerTest { + private static final int STREAM_A = 1; + private static final int STREAM_B = 3; + private static final int STREAM_C = 5; + private static final int STREAM_D = 7; + + private DefaultHttp2RemoteFlowController controller; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private Channel channel; + + @Mock + private ChannelConfig config; + + @Mock + private EventExecutor executor; + + @Mock + private ChannelPromise promise; + + @Mock + private Http2RemoteFlowController.Listener listener; + + private DefaultHttp2Connection connection; + + @BeforeEach + public void setup() throws Http2Exception { + MockitoAnnotations.initMocks(this); + + when(ctx.newPromise()).thenReturn(promise); + when(ctx.flush()).thenThrow(new AssertionFailedError("forbidden")); + setChannelWritability(true); + when(channel.config()).thenReturn(config); + when(executor.inEventLoop()).thenReturn(true); + + initConnectionAndController(); + + resetCtx(); + // This is intentionally left out of initConnectionAndController so it can be tested below. 
+ controller.channelHandlerContext(ctx); + assertWritabilityChanged(1, true); + reset(listener); + } + + protected abstract StreamByteDistributor newDistributor(Http2Connection connection); + + private void initConnectionAndController() throws Http2Exception { + connection = new DefaultHttp2Connection(false); + controller = new DefaultHttp2RemoteFlowController(connection, newDistributor(connection), listener); + connection.remote().flowController(controller); + + connection.local().createStream(STREAM_A, false); + connection.local().createStream(STREAM_B, false); + Http2Stream streamC = connection.local().createStream(STREAM_C, false); + Http2Stream streamD = connection.local().createStream(STREAM_D, false); + controller.updateDependencyTree(streamC.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, false); + controller.updateDependencyTree(streamD.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, false); + } + + @Test + public void initialWindowSizeShouldOnlyChangeStreams() throws Http2Exception { + controller.initialWindowSize(0); + assertEquals(DEFAULT_WINDOW_SIZE, window(CONNECTION_STREAM_ID)); + assertEquals(0, window(STREAM_A)); + assertEquals(0, window(STREAM_B)); + assertEquals(0, window(STREAM_C)); + assertEquals(0, window(STREAM_D)); + assertWritabilityChanged(1, false); + } + + @Test + public void windowUpdateShouldChangeConnectionWindow() throws Http2Exception { + incrementWindowSize(CONNECTION_STREAM_ID, 100); + assertEquals(DEFAULT_WINDOW_SIZE + 100, window(CONNECTION_STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_A)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); + verifyZeroInteractions(listener); + } + + @Test + public void windowUpdateShouldChangeStreamWindow() throws Http2Exception { + incrementWindowSize(STREAM_A, 100); + assertEquals(DEFAULT_WINDOW_SIZE, window(CONNECTION_STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE + 100, 
window(STREAM_A)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); + verifyZeroInteractions(listener); + } + + @Test + public void payloadSmallerThanWindowShouldBeWrittenImmediately() throws Http2Exception { + FakeFlowControlled data = new FakeFlowControlled(5); + sendData(STREAM_A, data); + data.assertNotWritten(); + verifyZeroInteractions(listener); + controller.writePendingBytes(); + data.assertFullyWritten(); + verifyZeroInteractions(listener); + } + + @Test + public void emptyPayloadShouldBeWrittenImmediately() throws Http2Exception { + FakeFlowControlled data = new FakeFlowControlled(0); + sendData(STREAM_A, data); + data.assertNotWritten(); + controller.writePendingBytes(); + data.assertFullyWritten(); + verifyZeroInteractions(listener); + } + + @Test + public void unflushedPayloadsShouldBeDroppedOnCancel() throws Http2Exception { + FakeFlowControlled data = new FakeFlowControlled(5); + Http2Stream streamA = stream(STREAM_A); + sendData(STREAM_A, data); + streamA.close(); + controller.writePendingBytes(); + data.assertNotWritten(); + controller.writePendingBytes(); + data.assertNotWritten(); + verify(listener, times(1)).writabilityChanged(streamA); + assertFalse(controller.isWritable(streamA)); + } + + @Test + public void payloadsShouldMerge() throws Http2Exception { + controller.initialWindowSize(15); + FakeFlowControlled data1 = new FakeFlowControlled(5, true); + FakeFlowControlled data2 = new FakeFlowControlled(10, true); + sendData(STREAM_A, data1); + sendData(STREAM_A, data2); + data1.assertNotWritten(); + data1.assertNotWritten(); + data2.assertMerged(); + controller.writePendingBytes(); + data1.assertFullyWritten(); + data2.assertNotWritten(); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + } + + @Test + public void 
flowControllerCorrectlyAccountsForBytesWithMerge() throws Http2Exception { + controller.initialWindowSize(112); // This must be more than the total merged frame size 110 + FakeFlowControlled data1 = new FakeFlowControlled(5, 2, true); + FakeFlowControlled data2 = new FakeFlowControlled(5, 100, true); + sendData(STREAM_A, data1); + sendData(STREAM_A, data2); + data1.assertNotWritten(); + data1.assertNotWritten(); + data2.assertMerged(); + controller.writePendingBytes(); + data1.assertFullyWritten(); + data2.assertNotWritten(); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + assertTrue(controller.isWritable(stream(STREAM_A))); + } + + @Test + public void stalledStreamShouldQueuePayloads() throws Http2Exception { + controller.initialWindowSize(0); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(15); + FakeFlowControlled moreData = new FakeFlowControlled(0); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + sendData(STREAM_A, moreData); + controller.writePendingBytes(); + moreData.assertNotWritten(); + verifyZeroInteractions(listener); + } + + @Test + public void queuedPayloadsReceiveErrorOnStreamClose() throws Http2Exception { + controller.initialWindowSize(0); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(15); + FakeFlowControlled moreData = new FakeFlowControlled(0); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + sendData(STREAM_A, moreData); + controller.writePendingBytes(); + moreData.assertNotWritten(); + + connection.stream(STREAM_A).close(); + data.assertError(Http2Error.STREAM_CLOSED); + moreData.assertError(Http2Error.STREAM_CLOSED); + 
verifyZeroInteractions(listener); + } + + @Test + public void payloadLargerThanWindowShouldWritePartial() throws Http2Exception { + controller.initialWindowSize(5); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + assertTrue(controller.isWritable(stream(STREAM_A))); + reset(listener); + + final FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + // Verify that a partial frame of 5 remains to be sent + data.assertPartiallyWritten(5); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + verifyNoMoreInteractions(listener); + } + + @Test + public void windowUpdateAndFlushShouldTriggerWrite() throws Http2Exception { + controller.initialWindowSize(10); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + assertTrue(controller.isWritable(stream(STREAM_A))); + + FakeFlowControlled data = new FakeFlowControlled(20); + FakeFlowControlled moreData = new FakeFlowControlled(10); + sendData(STREAM_A, data); + sendData(STREAM_A, moreData); + controller.writePendingBytes(); + data.assertPartiallyWritten(10); + moreData.assertNotWritten(); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + reset(listener); + resetCtx(); + + // Update the window and verify that the rest of data and some of moreData are written + incrementWindowSize(STREAM_A, 15); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + reset(listener); + + controller.writePendingBytes(); + + data.assertFullyWritten(); + moreData.assertPartiallyWritten(5); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + assertFalse(controller.isWritable(stream(STREAM_A))); + + assertEquals(DEFAULT_WINDOW_SIZE - 25, window(CONNECTION_STREAM_ID)); + assertEquals(0, window(STREAM_A)); + assertEquals(10, 
window(STREAM_B)); + assertEquals(10, window(STREAM_C)); + assertEquals(10, window(STREAM_D)); + } + + @Test + public void initialWindowUpdateShouldSendPayload() throws Http2Exception { + incrementWindowSize(CONNECTION_STREAM_ID, -window(CONNECTION_STREAM_ID) + 10); + assertWritabilityChanged(0, true); + reset(listener); + + controller.initialWindowSize(0); + assertWritabilityChanged(1, false); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + + // Verify that the entire frame was sent. + controller.initialWindowSize(10); + data.assertFullyWritten(); + assertWritabilityChanged(0, false); + } + + @Test + public void successiveSendsShouldNotInteract() throws Http2Exception { + // Collapse the connection window to force queueing. + incrementWindowSize(CONNECTION_STREAM_ID, -window(CONNECTION_STREAM_ID)); + assertEquals(0, window(CONNECTION_STREAM_ID)); + assertWritabilityChanged(1, false); + reset(listener); + + FakeFlowControlled dataA = new FakeFlowControlled(10); + // Queue data for stream A and allow most of it to be written. + sendData(STREAM_A, dataA); + controller.writePendingBytes(); + dataA.assertNotWritten(); + incrementWindowSize(CONNECTION_STREAM_ID, 8); + assertWritabilityChanged(0, false); + reset(listener); + + controller.writePendingBytes(); + dataA.assertPartiallyWritten(8); + assertEquals(65527, window(STREAM_A)); + assertEquals(0, window(CONNECTION_STREAM_ID)); + assertWritabilityChanged(0, false); + reset(listener); + + // Queue data for stream B and allow the rest of A and all of B to be written. 
+ FakeFlowControlled dataB = new FakeFlowControlled(10); + sendData(STREAM_B, dataB); + controller.writePendingBytes(); + dataB.assertNotWritten(); + incrementWindowSize(CONNECTION_STREAM_ID, 12); + assertWritabilityChanged(0, false); + reset(listener); + + controller.writePendingBytes(); + assertEquals(0, window(CONNECTION_STREAM_ID)); + assertWritabilityChanged(0, false); + + // Verify the rest of A is written. + dataA.assertFullyWritten(); + assertEquals(65525, window(STREAM_A)); + + dataB.assertFullyWritten(); + assertEquals(65525, window(STREAM_B)); + verifyNoMoreInteractions(listener); + } + + @Test + public void negativeWindowShouldNotThrowException() throws Http2Exception { + final int initWindow = 20; + final int secondWindowSize = 10; + controller.initialWindowSize(initWindow); + assertWritabilityChanged(0, true); + reset(listener); + + FakeFlowControlled data1 = new FakeFlowControlled(initWindow); + FakeFlowControlled data2 = new FakeFlowControlled(5); + + // Deplete the stream A window to 0 + sendData(STREAM_A, data1); + controller.writePendingBytes(); + data1.assertFullyWritten(); + assertTrue(window(CONNECTION_STREAM_ID) > 0); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + // Make the window size for stream A negative + controller.initialWindowSize(initWindow - secondWindowSize); + assertEquals(-secondWindowSize, window(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, 
never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + // Queue up a write. It should not be written now because the window is negative + sendData(STREAM_A, data2); + controller.writePendingBytes(); + data2.assertNotWritten(); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + // Open the window size back up a bit (no send should happen) + incrementWindowSize(STREAM_A, 5); + controller.writePendingBytes(); + assertEquals(-5, window(STREAM_A)); + data2.assertNotWritten(); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + // Open the window size back up a bit (no send should happen) + incrementWindowSize(STREAM_A, 5); + controller.writePendingBytes(); + assertEquals(0, window(STREAM_A)); + data2.assertNotWritten(); + verify(listener, 
never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + // Open the window size back up and allow the write to happen + incrementWindowSize(STREAM_A, 5); + controller.writePendingBytes(); + data2.assertFullyWritten(); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + } + + @Test + public void initialWindowUpdateShouldSendEmptyFrame() throws Http2Exception { + controller.initialWindowSize(0); + assertWritabilityChanged(1, false); + reset(listener); + + // First send a frame that will get buffered. + FakeFlowControlled data = new FakeFlowControlled(10, false); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + + // Now send an empty frame on the same stream and verify that it's also buffered. + FakeFlowControlled data2 = new FakeFlowControlled(0, false); + sendData(STREAM_A, data2); + controller.writePendingBytes(); + data2.assertNotWritten(); + + // Re-expand the window and verify that both frames were sent. 
+ controller.initialWindowSize(10); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_B)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_C)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + + data.assertFullyWritten(); + data2.assertFullyWritten(); + } + + @Test + public void initialWindowUpdateShouldSendPartialFrame() throws Http2Exception { + controller.initialWindowSize(0); + assertWritabilityChanged(1, false); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + + // Verify that a partial frame of 5 was sent. + controller.initialWindowSize(5); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_B)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_C)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + + data.assertPartiallyWritten(5); + } + + @Test + public void connectionWindowUpdateShouldSendFrame() throws Http2Exception { + // Set the connection window size to zero. 
+ exhaustStreamWindow(CONNECTION_STREAM_ID); + assertWritabilityChanged(1, false); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + assertWritabilityChanged(0, false); + reset(listener); + + // Verify that the entire frame was sent. + incrementWindowSize(CONNECTION_STREAM_ID, 10); + assertWritabilityChanged(0, false); + reset(listener); + data.assertNotWritten(); + + controller.writePendingBytes(); + data.assertFullyWritten(); + assertWritabilityChanged(0, false); + assertEquals(0, window(CONNECTION_STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE - 10, window(STREAM_A)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); + } + + @Test + public void connectionWindowUpdateShouldSendPartialFrame() throws Http2Exception { + // Set the connection window size to zero. + exhaustStreamWindow(CONNECTION_STREAM_ID); + assertWritabilityChanged(1, false); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + + // Verify that a partial frame of 5 was sent. + incrementWindowSize(CONNECTION_STREAM_ID, 5); + data.assertNotWritten(); + assertWritabilityChanged(0, false); + reset(listener); + + controller.writePendingBytes(); + data.assertPartiallyWritten(5); + assertWritabilityChanged(0, false); + assertEquals(0, window(CONNECTION_STREAM_ID)); + assertEquals(DEFAULT_WINDOW_SIZE - 5, window(STREAM_A)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); + } + + @Test + public void streamWindowUpdateShouldSendFrame() throws Http2Exception { + // Set the stream window size to zero. 
+ exhaustStreamWindow(STREAM_A); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + + // Verify that the entire frame was sent. + incrementWindowSize(STREAM_A, 10); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + data.assertNotWritten(); + controller.writePendingBytes(); + data.assertFullyWritten(); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + assertEquals(DEFAULT_WINDOW_SIZE - 10, window(CONNECTION_STREAM_ID)); + assertEquals(0, window(STREAM_A)); + assertEquals(DEFAULT_WINDOW_SIZE, 
window(STREAM_B)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); + } + + @Test + public void streamWindowUpdateShouldSendPartialFrame() throws Http2Exception { + // Set the stream window size to zero. + exhaustStreamWindow(STREAM_A); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + FakeFlowControlled data = new FakeFlowControlled(10); + sendData(STREAM_A, data); + controller.writePendingBytes(); + data.assertNotWritten(); + + // Verify that a partial frame of 5 was sent. + incrementWindowSize(STREAM_A, 5); + verify(listener, never()).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + reset(listener); + + data.assertNotWritten(); + controller.writePendingBytes(); + data.assertPartiallyWritten(5); + assertEquals(DEFAULT_WINDOW_SIZE - 5, window(CONNECTION_STREAM_ID)); + assertEquals(0, window(STREAM_A)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_B)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_C)); + assertEquals(DEFAULT_WINDOW_SIZE, window(STREAM_D)); + } + + @Test + public void flowControlledWriteThrowsAnException() throws Exception { + 
final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); + final Http2Stream stream = stream(STREAM_A); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) { + stream.closeLocalSide(); + return null; + } + }).when(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); + + int windowBefore = window(STREAM_A); + + controller.addFlowControlled(stream, flowControlled); + controller.writePendingBytes(); + + verify(flowControlled, atLeastOnce()).write(any(ChannelHandlerContext.class), anyInt()); + verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); + verify(flowControlled, never()).writeComplete(); + + assertEquals(90, windowBefore - window(STREAM_A)); + verify(listener, times(1)).writabilityChanged(stream(STREAM_A)); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + } + + @Test + public void flowControlledWriteAndErrorThrowAnException() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = mockedFlowControlledThatThrowsOnWrite(); + final Http2Stream stream = stream(STREAM_A); + final RuntimeException fakeException = new RuntimeException("error failed"); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) { + throw fakeException; + } + }).when(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); + + int windowBefore = window(STREAM_A); + + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { 
+ controller.addFlowControlled(stream, flowControlled); + controller.writePendingBytes(); + } + }); + assertSame(fakeException, e.getCause()); + + verify(flowControlled, atLeastOnce()).write(any(ChannelHandlerContext.class), anyInt()); + verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); + verify(flowControlled, never()).writeComplete(); + + assertEquals(90, windowBefore - window(STREAM_A)); + verifyZeroInteractions(listener); + } + + @Test + public void flowControlledWriteCompleteThrowsAnException() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = + mock(Http2RemoteFlowController.FlowControlled.class); + Http2Stream streamA = stream(STREAM_A); + final AtomicInteger size = new AtomicInteger(150); + doAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock invocationOnMock) throws Throwable { + return size.get(); + } + }).when(flowControlled).size(); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + size.addAndGet(-50); + return null; + } + }).when(flowControlled).write(any(ChannelHandlerContext.class), anyInt()); + + final Http2Stream stream = stream(STREAM_A); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) { + throw new RuntimeException("writeComplete failed"); + } + }).when(flowControlled).writeComplete(); + + int windowBefore = window(STREAM_A); + + controller.addFlowControlled(stream, flowControlled); + controller.writePendingBytes(); + + verify(flowControlled, times(3)).write(any(ChannelHandlerContext.class), anyInt()); + verify(flowControlled, never()).error(any(ChannelHandlerContext.class), any(Throwable.class)); + verify(flowControlled).writeComplete(); + + assertEquals(150, windowBefore - window(STREAM_A)); + verify(listener, times(1)).writabilityChanged(streamA); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, 
never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(streamA)); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + } + + @Test + public void closeStreamInFlowControlledError() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = + mock(Http2RemoteFlowController.FlowControlled.class); + final Http2Stream stream = stream(STREAM_A); + when(flowControlled.size()).thenReturn(100); + doThrow(new RuntimeException("write failed")) + .when(flowControlled).write(any(ChannelHandlerContext.class), anyInt()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) { + stream.close(); + return null; + } + }).when(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); + + controller.addFlowControlled(stream, flowControlled); + controller.writePendingBytes(); + + verify(flowControlled).write(any(ChannelHandlerContext.class), anyInt()); + verify(flowControlled).error(any(ChannelHandlerContext.class), any(Throwable.class)); + verify(flowControlled, never()).writeComplete(); + verify(listener, times(1)).writabilityChanged(stream); + verify(listener, never()).writabilityChanged(stream(STREAM_B)); + verify(listener, never()).writabilityChanged(stream(STREAM_C)); + verify(listener, never()).writabilityChanged(stream(STREAM_D)); + assertFalse(controller.isWritable(stream)); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + } + + @Test + public void nonWritableChannelDoesNotAttemptToWrite() throws Exception { + // Start the channel as not writable and exercise the public methods of the flow controller + // making sure no frames are written. 
+ setChannelWritability(false); + assertWritabilityChanged(1, false); + reset(listener); + FakeFlowControlled dataA = new FakeFlowControlled(1); + FakeFlowControlled dataB = new FakeFlowControlled(1); + final Http2Stream stream = stream(STREAM_A); + + controller.addFlowControlled(stream, dataA); + controller.writePendingBytes(); + dataA.assertNotWritten(); + + controller.incrementWindowSize(stream, 100); + controller.writePendingBytes(); + dataA.assertNotWritten(); + + controller.addFlowControlled(stream, dataB); + controller.writePendingBytes(); + dataA.assertNotWritten(); + dataB.assertNotWritten(); + assertWritabilityChanged(0, false); + + // Now change the channel to writable and make sure frames are written. + setChannelWritability(true); + assertWritabilityChanged(1, true); + controller.writePendingBytes(); + dataA.assertFullyWritten(); + dataB.assertFullyWritten(); + } + + @Test + public void contextShouldSendQueuedFramesWhenSet() throws Exception { + // Re-initialize the controller so we can ensure the context hasn't been set yet. + initConnectionAndController(); + + FakeFlowControlled dataA = new FakeFlowControlled(1); + final Http2Stream stream = stream(STREAM_A); + + // Queue some frames + controller.addFlowControlled(stream, dataA); + dataA.assertNotWritten(); + + controller.incrementWindowSize(stream, 100); + dataA.assertNotWritten(); + + assertWritabilityChanged(0, false); + + // Set the controller + controller.channelHandlerContext(ctx); + dataA.assertFullyWritten(); + + assertWritabilityChanged(1, true); + } + + @Test + public void initialWindowSizeWithNoContextShouldNotThrow() throws Exception { + // Re-initialize the controller so we can ensure the context hasn't been set yet. + initConnectionAndController(); + + // This should not throw. 
+ controller.initialWindowSize(1024 * 100); + + FakeFlowControlled dataA = new FakeFlowControlled(1); + final Http2Stream stream = stream(STREAM_A); + + // Queue some frames + controller.addFlowControlled(stream, dataA); + dataA.assertNotWritten(); + + // Set the controller + controller.channelHandlerContext(ctx); + dataA.assertFullyWritten(); + } + + @Test + public void invalidParentStreamIdThrows() { + assertThrows(AssertionError.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.updateDependencyTree(STREAM_D, -1, DEFAULT_PRIORITY_WEIGHT, true); + } + }); + } + + @Test + public void invalidChildStreamIdThrows() { + assertThrows(AssertionError.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.updateDependencyTree(-1, STREAM_D, DEFAULT_PRIORITY_WEIGHT, true); + } + }); + } + + @Test + public void connectionChildStreamIdThrows() { + assertThrows(AssertionError.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.updateDependencyTree(0, STREAM_D, DEFAULT_PRIORITY_WEIGHT, true); + } + }); + } + + @Test + public void invalidWeightTooSmallThrows() { + assertThrows(AssertionError.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.updateDependencyTree(STREAM_A, STREAM_D, (short) (MIN_WEIGHT - 1), true); + } + }); + } + + @Test + public void invalidWeightTooBigThrows() { + assertThrows(AssertionError.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.updateDependencyTree(STREAM_A, STREAM_D, (short) (MAX_WEIGHT + 1), true); + } + }); + } + + @Test + public void dependencyOnSelfThrows() { + assertThrows(AssertionError.class, new Executable() { + @Override + public void execute() throws Throwable { + controller.updateDependencyTree(STREAM_A, STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + } + }); + } + + private void assertWritabilityChanged(int amt, boolean writable) { 
+ verify(listener, times(amt)).writabilityChanged(stream(STREAM_A)); + verify(listener, times(amt)).writabilityChanged(stream(STREAM_B)); + verify(listener, times(amt)).writabilityChanged(stream(STREAM_C)); + verify(listener, times(amt)).writabilityChanged(stream(STREAM_D)); + if (writable) { + assertTrue(controller.isWritable(stream(STREAM_A))); + assertTrue(controller.isWritable(stream(STREAM_B))); + assertTrue(controller.isWritable(stream(STREAM_C))); + assertTrue(controller.isWritable(stream(STREAM_D))); + } else { + assertFalse(controller.isWritable(stream(STREAM_A))); + assertFalse(controller.isWritable(stream(STREAM_B))); + assertFalse(controller.isWritable(stream(STREAM_C))); + assertFalse(controller.isWritable(stream(STREAM_D))); + } + } + + private static Http2RemoteFlowController.FlowControlled mockedFlowControlledThatThrowsOnWrite() throws Exception { + final Http2RemoteFlowController.FlowControlled flowControlled = + mock(Http2RemoteFlowController.FlowControlled.class); + when(flowControlled.size()).thenReturn(100); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + // Write most of the bytes and then fail + when(flowControlled.size()).thenReturn(10); + throw new RuntimeException("Write failed"); + } + }).when(flowControlled).write(any(ChannelHandlerContext.class), anyInt()); + return flowControlled; + } + + private void sendData(int streamId, FakeFlowControlled data) { + Http2Stream stream = stream(streamId); + controller.addFlowControlled(stream, data); + } + + private void exhaustStreamWindow(int streamId) throws Http2Exception { + incrementWindowSize(streamId, -window(streamId)); + } + + private int window(int streamId) { + return controller.windowSize(stream(streamId)); + } + + private void incrementWindowSize(int streamId, int delta) throws Http2Exception { + controller.incrementWindowSize(stream(streamId), delta); + } + + private Http2Stream stream(int streamId) { + return 
connection.stream(streamId); + } + + private void resetCtx() { + reset(ctx); + when(ctx.channel()).thenReturn(channel); + when(ctx.executor()).thenReturn(executor); + } + + private void setChannelWritability(boolean isWritable) throws Http2Exception { + when(channel.bytesBeforeUnwritable()).thenReturn(isWritable ? Long.MAX_VALUE : 0); + when(channel.isWritable()).thenReturn(isWritable); + if (controller != null) { + controller.channelWritabilityChanged(); + } + } + + private static final class FakeFlowControlled implements Http2RemoteFlowController.FlowControlled { + private int currentPadding; + private int currentPayloadSize; + private int originalPayloadSize; + private int originalPadding; + private boolean writeCalled; + private final boolean mergeable; + private boolean merged; + + private Throwable t; + + private FakeFlowControlled(int size) { + this(size, false); + } + + private FakeFlowControlled(int size, boolean mergeable) { + this(size, 0, mergeable); + } + + private FakeFlowControlled(int payloadSize, int padding, boolean mergeable) { + currentPayloadSize = originalPayloadSize = payloadSize; + currentPadding = originalPadding = padding; + this.mergeable = mergeable; + } + + @Override + public int size() { + return currentPayloadSize + currentPadding; + } + + private int originalSize() { + return originalPayloadSize + originalPadding; + } + + @Override + public void error(ChannelHandlerContext ctx, Throwable t) { + this.t = t; + } + + @Override + public void writeComplete() { + } + + @Override + public void write(ChannelHandlerContext ctx, int allowedBytes) { + if (allowedBytes <= 0 && size() != 0) { + // Write has been called but no data can be written + return; + } + writeCalled = true; + int written = Math.min(size(), allowedBytes); + if (written > currentPayloadSize) { + written -= currentPayloadSize; + currentPayloadSize = 0; + currentPadding -= written; + } else { + currentPayloadSize -= written; + } + } + + @Override + public boolean 
merge(ChannelHandlerContext ctx, Http2RemoteFlowController.FlowControlled next) { + if (mergeable && next instanceof FakeFlowControlled) { + FakeFlowControlled ffcNext = (FakeFlowControlled) next; + originalPayloadSize += ffcNext.originalPayloadSize; + currentPayloadSize += ffcNext.originalPayloadSize; + currentPadding = originalPadding = Math.max(originalPadding, ffcNext.originalPadding); + ffcNext.merged = true; + return true; + } + return false; + } + + public int written() { + return originalSize() - size(); + } + + public void assertNotWritten() { + assertFalse(writeCalled); + } + + public void assertPartiallyWritten(int expectedWritten) { + assertPartiallyWritten(expectedWritten, 0); + } + + public void assertPartiallyWritten(int expectedWritten, int delta) { + assertTrue(writeCalled); + assertEquals(expectedWritten, written(), delta); + } + + public void assertFullyWritten() { + assertTrue(writeCalled); + assertEquals(0, currentPayloadSize); + assertEquals(0, currentPadding); + } + + public boolean assertMerged() { + return merged; + } + + public void assertError(Http2Error error) { + assertNotNull(t); + if (error != null) { + assertSame(error, ((Http2Exception) t).error()); + } + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HashCollisionTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HashCollisionTest.java new file mode 100644 index 0000000..f71df3b --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HashCollisionTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.util.AsciiString; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Disabled; + +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.PrintStream; +import java.lang.reflect.Field; +import java.net.URISyntaxException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.Set; + +@Disabled +public final class HashCollisionTest { + private HashCollisionTest() { } + + public static void main(String[] args) throws IllegalAccessException, IOException, URISyntaxException { + // Big initial size for when all name sources are pulled in. 
+ List<CharSequence> strings = new ArrayList<CharSequence>(350000); + addHttpHeaderNames(strings); + addHttpHeaderValues(strings); + addHttp2HeaderNames(strings); + addWordsFromFile(new File("/usr/share/dict/words"), strings); + // More "english words" can be found here: + // https://gist.github.com/Scottmitch/de2f03912778016ecee3c140478f07e0#file-englishwords-txt + + Map<Integer, List<CharSequence>> dups = calculateDuplicates(strings, new Function<CharSequence, Integer>() { + @Override + public Integer apply(CharSequence string) { + int h = 0; + for (int i = 0; i < string.length(); ++i) { + // masking with 0x1F reduces the number of overall bits that impact the hash code but makes the hash + // code the same regardless of character case (upper case or lower case hash is the same). + h = h * 31 + (string.charAt(i) & 0x1F); + } + return h; + } + }); + PrintStream writer = System.out; + writer.println("==Old Duplicates=="); + printResults(writer, dups); + + dups = calculateDuplicates(strings, new Function<CharSequence, Integer>() { + @Override + public Integer apply(CharSequence string) { + return PlatformDependent.hashCodeAscii(string); + } + }); + writer.println(); + writer.println("==New Duplicates=="); + printResults(writer, dups); + } + + private static void addHttpHeaderNames(List<CharSequence> values) throws IllegalAccessException { + for (Field f : HttpHeaderNames.class.getFields()) { + if (f.getType() == AsciiString.class) { + values.add((AsciiString) f.get(null)); + } + } + } + + private static void addHttpHeaderValues(List<CharSequence> values) throws IllegalAccessException { + for (Field f : HttpHeaderValues.class.getFields()) { + if (f.getType() == AsciiString.class) { + values.add((AsciiString) f.get(null)); + } + } + } + + private static void addHttp2HeaderNames(List<CharSequence> values) throws IllegalAccessException { + for (Http2Headers.PseudoHeaderName name : Http2Headers.PseudoHeaderName.values()) { + values.add(name.value()); + } + } + + private static void addWordsFromFile(File file, List<CharSequence> values) + throws IllegalAccessException, IOException { + BufferedReader br = new BufferedReader(new
FileReader(file)); + try { + String line; + while ((line = br.readLine()) != null) { + // Make a "best effort" to prune input which contains characters that are not valid in HTTP header names + if (line.indexOf('\'') < 0) { + values.add(line); + } + } + } finally { + br.close(); + } + } + + private static Map<Integer, List<CharSequence>> calculateDuplicates(List<CharSequence> strings, + Function<CharSequence, Integer> hasher) { + Map<Integer, List<CharSequence>> hashResults = new HashMap<Integer, List<CharSequence>>(); + Set<Integer> duplicateHashCodes = new HashSet<Integer>(); + + for (CharSequence str : strings) { + Integer hash = hasher.apply(str); + List<CharSequence> results = hashResults.get(hash); + if (results == null) { + results = new ArrayList<CharSequence>(1); + hashResults.put(hash, results); + } else { + duplicateHashCodes.add(hash); + } + results.add(str); + } + + if (duplicateHashCodes.isEmpty()) { + return Collections.emptyMap(); + } + Map<Integer, List<CharSequence>> duplicates = + new HashMap<Integer, List<CharSequence>>(duplicateHashCodes.size()); + for (Integer duplicateHashCode : duplicateHashCodes) { + List<CharSequence> realDups = new ArrayList<CharSequence>(2); + Iterator<CharSequence> itr = hashResults.get(duplicateHashCode).iterator(); + // there should be at least 2 elements in the list ...
because there may be duplicates + realDups.add(itr.next()); + checknext: do { + CharSequence next = itr.next(); + for (CharSequence potentialDup : realDups) { + if (!AsciiString.contentEqualsIgnoreCase(next, potentialDup)) { + realDups.add(next); + break checknext; + } + } + } while (itr.hasNext()); + + if (realDups.size() > 1) { + duplicates.put(duplicateHashCode, realDups); + } + } + return duplicates; + } + + private static void printResults(PrintStream stream, Map<Integer, List<CharSequence>> dups) { + stream.println("Number duplicates: " + dups.size()); + for (Entry<Integer, List<CharSequence>> entry : dups.entrySet()) { + stream.print(entry.getValue().size() + " duplicates for hash: " + entry.getKey() + " values: "); + for (CharSequence str : entry.getValue()) { + stream.print("[" + str + "] "); + } + stream.println(); + } + } + + private interface Function<P, R> { + R apply(P param); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java new file mode 100644 index 0000000..f70f528 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDecoderTest.java @@ -0,0 +1,914 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc.
+ * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.AsciiString; +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.CsvSource; +import org.mockito.MockingDetails; +import org.mockito.invocation.Invocation; + +import java.lang.reflect.Method; + +import static io.netty.handler.codec.http2.HpackDecoder.decodeULE128; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2HeadersEncoder.NEVER_SENSITIVE; +import static io.netty.util.AsciiString.EMPTY_STRING; +import static io.netty.util.AsciiString.of; +import static java.lang.Integer.MAX_VALUE; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.mockingDetails; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.times; 
+import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +public class HpackDecoderTest { + + private HpackDecoder hpackDecoder; + private Http2Headers mockHeaders; + + private static String hex(String s) { + return StringUtil.toHexString(s.getBytes()); + } + + private void decode(String encoded) throws Http2Exception { + byte[] b = StringUtil.decodeHexDump(encoded); + ByteBuf in = Unpooled.wrappedBuffer(b); + try { + hpackDecoder.decode(0, in, mockHeaders, true); + } finally { + in.release(); + } + } + + @BeforeEach + public void setUp() { + hpackDecoder = new HpackDecoder(8192); + mockHeaders = mock(Http2Headers.class); + } + + @Test + public void testDecodeULE128IntMax() throws Http2Exception { + byte[] input = {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0x07}; + ByteBuf in = Unpooled.wrappedBuffer(input); + try { + assertEquals(MAX_VALUE, decodeULE128(in, 0)); + } finally { + in.release(); + } + } + + @Test + public void testDecodeULE128IntOverflow1() throws Http2Exception { + byte[] input = {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0x07}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + final int readerIndex = in.readerIndex(); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decodeULE128(in, 1); + } + }); + } finally { + assertEquals(readerIndex, in.readerIndex()); + in.release(); + } + } + + @Test + public void testDecodeULE128IntOverflow2() throws Http2Exception { + byte[] input = {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0x08}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + final int readerIndex = in.readerIndex(); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decodeULE128(in, 0); + } + }); + } finally { + assertEquals(readerIndex, in.readerIndex()); + in.release(); + } + } + + @Test + public void 
testDecodeULE128LongMax() throws Http2Exception { + byte[] input = {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, + (byte) 0xFF, (byte) 0x7F}; + ByteBuf in = Unpooled.wrappedBuffer(input); + try { + assertEquals(Long.MAX_VALUE, decodeULE128(in, 0L)); + } finally { + in.release(); + } + } + + @Test + public void testDecodeULE128LongOverflow1() throws Http2Exception { + byte[] input = {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, + (byte) 0xFF, (byte) 0xFF}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + final int readerIndex = in.readerIndex(); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decodeULE128(in, 0L); + } + }); + } finally { + assertEquals(readerIndex, in.readerIndex()); + in.release(); + } + } + + @Test + public void testDecodeULE128LongOverflow2() throws Http2Exception { + byte[] input = {(byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, + (byte) 0xFF, (byte) 0x7F}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + final int readerIndex = in.readerIndex(); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decodeULE128(in, 1L); + } + }); + } finally { + assertEquals(readerIndex, in.readerIndex()); + in.release(); + } + } + + @Test + public void testSetTableSizeWithMaxUnsigned32BitValueSucceeds() throws Http2Exception { + byte[] input = {(byte) 0x3F, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0x0E}; + ByteBuf in = Unpooled.wrappedBuffer(input); + try { + final long expectedHeaderSize = 4026531870L; // based on the input above + hpackDecoder.setMaxHeaderTableSize(expectedHeaderSize); + hpackDecoder.decode(0, in, mockHeaders, true); + assertEquals(expectedHeaderSize, hpackDecoder.getMaxHeaderTableSize()); + } finally { + in.release(); + } + } + + @Test + 
public void testSetTableSizeOverLimitFails() throws Http2Exception { + byte[] input = {(byte) 0x3F, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0x0E}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + try { + hpackDecoder.setMaxHeaderTableSize(4026531870L - 1); // based on the input above ... 1 less than is above. + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(0, in, mockHeaders, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void testLiteralHuffmanEncodedWithEmptyNameAndValue() throws Http2Exception { + byte[] input = {0, (byte) 0x80, 0}; + ByteBuf in = Unpooled.wrappedBuffer(input); + try { + hpackDecoder.decode(0, in, mockHeaders, true); + verify(mockHeaders, times(1)).add(EMPTY_STRING, EMPTY_STRING); + } finally { + in.release(); + } + } + + @Test + public void testLiteralHuffmanEncodedWithPaddingGreaterThan7Throws() throws Http2Exception { + byte[] input = {0, (byte) 0x81, -1}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(0, in, mockHeaders, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void testLiteralHuffmanEncodedWithDecodingEOSThrows() throws Http2Exception { + byte[] input = {0, (byte) 0x84, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF, (byte) 0xFF}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(0, in, mockHeaders, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void testLiteralHuffmanEncodedWithPaddingNotCorrespondingToMSBThrows() throws Http2Exception { + byte[] input = {0, (byte) 0x81, 0}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + try { + assertThrows(Http2Exception.class, 
new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(0, in, mockHeaders, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void testIncompleteIndex() throws Http2Exception { + byte[] compressed = StringUtil.decodeHexDump("FFF0"); + final ByteBuf in = Unpooled.wrappedBuffer(compressed); + try { + assertEquals(2, in.readableBytes()); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(0, in, mockHeaders, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void testUnusedIndex() throws Http2Exception { + // Index 0 is not used + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("80"); + } + }); + } + + @Test + public void testIllegalIndex() throws Http2Exception { + // Index larger than the header table + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("FF00"); + } + }); + } + + @Test + public void testInsidiousIndex() throws Http2Exception { + // Insidious index so the last shift causes sign overflow + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("FF8080808007"); + } + }); + } + + @Test + public void testDynamicTableSizeUpdate() throws Http2Exception { + decode("20"); + assertEquals(0, hpackDecoder.getMaxHeaderTableSize()); + decode("3FE11F"); + assertEquals(4096, hpackDecoder.getMaxHeaderTableSize()); + } + + @Test + public void testDynamicTableSizeUpdateRequired() throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(32); + decode("3F00"); + assertEquals(31, hpackDecoder.getMaxHeaderTableSize()); + } + + @Test + public void testIllegalDynamicTableSizeUpdate() throws Http2Exception { + // max header table size = MAX_HEADER_TABLE_SIZE + 1 + assertThrows(Http2Exception.class, new 
Executable() { + @Override + public void execute() throws Throwable { + decode("3FE21F"); + } + }); + } + + @Test + public void testInsidiousMaxDynamicTableSize() throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(MAX_VALUE); + // max header table size sign overflow + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("3FE1FFFFFF07"); + } + }); + } + + @Test + public void testMaxValidDynamicTableSize() throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(MAX_VALUE); + String baseValue = "3FE1FFFFFF0"; + for (int i = 0; i < 7; ++i) { + decode(baseValue + i); + } + } + + @Test + public void testReduceMaxDynamicTableSize() throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(0); + assertEquals(0, hpackDecoder.getMaxHeaderTableSize()); + decode("2081"); + } + + @Test + public void testTooLargeDynamicTableSizeUpdate() throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(0); + assertEquals(0, hpackDecoder.getMaxHeaderTableSize()); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("21"); // encoder max header table size not small enough + } + }); + } + + @Test + public void testMissingDynamicTableSizeUpdate() throws Http2Exception { + hpackDecoder.setMaxHeaderTableSize(0); + assertEquals(0, hpackDecoder.getMaxHeaderTableSize()); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("81"); + } + }); + } + + @Test + public void testDynamicTableSizeUpdateAfterTheBeginingOfTheBlock() throws Http2Exception { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("8120"); + } + }); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode("813FE11F"); + } + }); + } + + @Test + public void 
testLiteralWithIncrementalIndexingWithEmptyName() throws Http2Exception { + decode("400005" + hex("value")); + verify(mockHeaders, times(1)).add(EMPTY_STRING, of("value")); + } + + @Test + public void testLiteralWithIncrementalIndexingCompleteEviction() throws Http2Exception { + // Verify indexed host header + decode("4004" + hex("name") + "05" + hex("value")); + verify(mockHeaders).add(of("name"), of("value")); + verifyNoMoreInteractions(mockHeaders); + + reset(mockHeaders); + StringBuilder sb = new StringBuilder(); + for (int i = 0; i < 4096; i++) { + sb.append('a'); + } + String value = sb.toString(); + sb = new StringBuilder(); + sb.append("417F811F"); + for (int i = 0; i < 4096; i++) { + sb.append("61"); // 'a' + } + decode(sb.toString()); + verify(mockHeaders).add(of(":authority"), of(value)); + MockingDetails details = mockingDetails(mockHeaders); + for (Invocation invocation : details.getInvocations()) { + Method method = invocation.getMethod(); + if ("authority".equals(method.getName()) + && invocation.getArguments().length == 0) { + invocation.markVerified(); + } else if ("contains".equals(method.getName()) + && invocation.getArguments().length == 1 + && invocation.getArgument(0).equals(of(":authority"))) { + invocation.markVerified(); + } + } + verifyNoMoreInteractions(mockHeaders); + reset(mockHeaders); + + // Verify next header is inserted at index 62 + decode("4004" + hex("name") + "05" + hex("value") + "BE"); + verify(mockHeaders, times(2)).add(of("name"), of("value")); + verifyNoMoreInteractions(mockHeaders); + } + + @Test + public void testLiteralWithIncrementalIndexingWithLargeValue() throws Http2Exception { + // Ignore header that exceeds max header size + final StringBuilder sb = new StringBuilder(); + sb.append("4004"); + sb.append(hex("name")); + sb.append("7F813F"); + for (int i = 0; i < 8192; i++) { + sb.append("61"); // 'a' + } + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + 
decode(sb.toString()); + } + }); + } + + @Test + public void testLiteralWithoutIndexingWithEmptyName() throws Http2Exception { + decode("000005" + hex("value")); + verify(mockHeaders, times(1)).add(EMPTY_STRING, of("value")); + } + + @Test + public void testLiteralWithoutIndexingWithLargeName() throws Http2Exception { + // Ignore header name that exceeds max header size + final StringBuilder sb = new StringBuilder(); + sb.append("007F817F"); + for (int i = 0; i < 16384; i++) { + sb.append("61"); // 'a' + } + sb.append("00"); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode(sb.toString()); + } + }); + } + + @Test + public void testLiteralWithoutIndexingWithLargeValue() throws Http2Exception { + // Ignore header that exceeds max header size + final StringBuilder sb = new StringBuilder(); + sb.append("0004"); + sb.append(hex("name")); + sb.append("7F813F"); + for (int i = 0; i < 8192; i++) { + sb.append("61"); // 'a' + } + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode(sb.toString()); + } + }); + } + + @Test + public void testLiteralNeverIndexedWithEmptyName() throws Http2Exception { + decode("100005" + hex("value")); + verify(mockHeaders, times(1)).add(EMPTY_STRING, of("value")); + } + + @Test + public void testLiteralNeverIndexedWithLargeName() throws Http2Exception { + // Ignore header name that exceeds max header size + final StringBuilder sb = new StringBuilder(); + sb.append("107F817F"); + for (int i = 0; i < 16384; i++) { + sb.append("61"); // 'a' + } + sb.append("00"); + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode(sb.toString()); + } + }); + } + + @Test + public void testLiteralNeverIndexedWithLargeValue() throws Http2Exception { + // Ignore header that exceeds max header size + final StringBuilder sb = new StringBuilder(); + sb.append("1004"); + 
sb.append(hex("name")); + sb.append("7F813F"); + for (int i = 0; i < 8192; i++) { + sb.append("61"); // 'a' + } + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + decode(sb.toString()); + } + }); + } + + @Test + public void testDecodeLargerThanMaxHeaderListSizeUpdatesDynamicTable() throws Http2Exception { + final ByteBuf in = Unpooled.buffer(300); + try { + hpackDecoder.setMaxHeaderListSize(200); + HpackEncoder hpackEncoder = new HpackEncoder(true); + + // encode headers that are slightly larger than maxHeaderListSize + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add("test_1", "1"); + toEncode.add("test_2", "2"); + toEncode.add("long", String.format("%0100d", 0).replace('0', 'A')); + toEncode.add("test_3", "3"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + // decode the headers, we should get an exception + final Http2Headers decoded = new DefaultHttp2Headers(); + assertThrows(Http2Exception.HeaderListSizeException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(1, in, decoded, true); + } + }); + + // but the dynamic table should have been updated, so that later blocks + // can refer to earlier headers + in.clear(); + // 0x80, "indexed header field representation" + // index 62, the first (most recent) dynamic table entry + in.writeByte(0x80 | 62); + Http2Headers decoded2 = new DefaultHttp2Headers(); + hpackDecoder.decode(1, in, decoded2, true); + + Http2Headers golden = new DefaultHttp2Headers(); + golden.add("test_3", "3"); + assertEquals(golden, decoded2); + } finally { + in.release(); + } + } + + @Test + public void testDecodeCountsNamesOnlyOnce() throws Http2Exception { + ByteBuf in = Unpooled.buffer(200); + try { + hpackDecoder.setMaxHeaderListSize(3500); + HpackEncoder hpackEncoder = new HpackEncoder(true); + + // encode headers that are slightly larger than maxHeaderListSize + Http2Headers toEncode = 
new DefaultHttp2Headers(); + toEncode.add(String.format("%03000d", 0).replace('0', 'f'), "value"); + toEncode.add("accept", "value"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + Http2Headers decoded = new DefaultHttp2Headers(); + hpackDecoder.decode(1, in, decoded, true); + assertEquals(2, decoded.size()); + } finally { + in.release(); + } + } + + @Test + public void testAccountForHeaderOverhead() throws Exception { + final ByteBuf in = Unpooled.buffer(100); + try { + String headerName = "12345"; + String headerValue = "56789"; + long headerSize = headerName.length() + headerValue.length(); + hpackDecoder.setMaxHeaderListSize(headerSize); + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add(headerName, headerValue); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + // SETTINGS_MAX_HEADER_LIST_SIZE is big enough for the header to fit... + assertThat(hpackDecoder.getMaxHeaderListSize(), is(greaterThanOrEqualTo(headerSize))); + + // ... 
but decode should fail because we add some overhead for each header entry + assertThrows(Http2Exception.HeaderListSizeException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(1, in, decoded, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void testIncompleteHeaderFieldRepresentation() throws Http2Exception { + // Incomplete Literal Header Field with Incremental Indexing + byte[] input = {(byte) 0x40}; + final ByteBuf in = Unpooled.wrappedBuffer(input); + try { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(0, in, mockHeaders, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void unknownPseudoHeader() throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(false); + toEncode.add(":test", "1"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(true); + + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(1, in, decoded, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void disableHeaderValidation() throws Exception { + ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(false); + toEncode.add(":test", "1"); + toEncode.add(":status", "200"); + toEncode.add(":method", "GET"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + Http2Headers decoded = new DefaultHttp2Headers(false); + + hpackDecoder.decode(1, in, decoded, false); + + assertThat(decoded.valueIterator(":test").next().toString(), is("1")); + assertThat(decoded.status().toString(), is("200")); + 
assertThat(decoded.method().toString(), is("GET")); + } finally { + in.release(); + } + } + + @Test + public void requestPseudoHeaderInResponse() throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add(":status", "200"); + toEncode.add(":method", "GET"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(1, in, decoded, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void responsePseudoHeaderInRequest() throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add(":method", "GET"); + toEncode.add(":status", "200"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(1, in, decoded, true); + } + }); + } finally { + in.release(); + } + } + + @Test + public void pseudoHeaderAfterRegularHeader() throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new InOrderHttp2Headers(); + toEncode.add("test", "1"); + toEncode.add(":method", "GET"); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + Http2Exception.StreamException e = assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(3, 
in, decoded, true); + } + }); + assertThat(e.streamId(), is(3)); + assertThat(e.error(), is(PROTOCOL_ERROR)); + } finally { + in.release(); + } + } + + @ParameterizedTest(name = "{displayName} [{index}] name={0} value={1}") + @CsvSource(value = {"upgrade,protocol1", "connection,close", "keep-alive,timeout=5", "proxy-connection,close", + "transfer-encoding,chunked", "te,something-else"}) + public void receivedConnectionHeader(String name, String value) throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new InOrderHttp2Headers(); + toEncode.add(":method", "GET"); + toEncode.add(name, value); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + Http2Exception.StreamException e = assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(3, in, decoded, true); + } + }); + assertThat(e.streamId(), is(3)); + assertThat(e.error(), is(PROTOCOL_ERROR)); + } finally { + in.release(); + } + } + + @Test + public void failedValidationDoesntCorruptHpack() throws Exception { + final ByteBuf in1 = Unpooled.buffer(200); + ByteBuf in2 = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new DefaultHttp2Headers(); + toEncode.add(":method", "GET"); + toEncode.add(":status", "200"); + toEncode.add("foo", "bar"); + hpackEncoder.encodeHeaders(1, in1, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + Http2Exception.StreamException expected = + assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(1, in1, decoded, true); + } + }); + assertEquals(1, expected.streamId()); + + // Do it again, this time without validation, to make sure the 
HPACK state is still sane. + decoded.clear(); + hpackEncoder.encodeHeaders(1, in2, toEncode, NEVER_SENSITIVE); + hpackDecoder.decode(1, in2, decoded, false); + + assertEquals(3, decoded.size()); + assertEquals("GET", decoded.method().toString()); + assertEquals("200", decoded.status().toString()); + assertEquals("bar", decoded.get("foo").toString()); + } finally { + in1.release(); + in2.release(); + } + } + + @ParameterizedTest + @CsvSource(value = {":method,''", ":scheme,''", ":authority,''", ":path,''"}) + public void testPseudoHeaderEmptyValidationEnabled(String name, String value) throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new InOrderHttp2Headers(); + toEncode.add(name, value); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(); + + Http2Exception.StreamException e = assertThrows(Http2Exception.StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + hpackDecoder.decode(3, in, decoded, true); + } + }); + assertThat(e.streamId(), is(3)); + assertThat(e.error(), is(PROTOCOL_ERROR)); + } finally { + in.release(); + } + } + + @ParameterizedTest + @CsvSource(value = {":method,''", ":scheme,''", ":authority,''", ":path,''"}) + public void testPseudoHeaderEmptyValidationDisabled(String name, String value) throws Exception { + final ByteBuf in = Unpooled.buffer(200); + try { + HpackEncoder hpackEncoder = new HpackEncoder(true); + + Http2Headers toEncode = new InOrderHttp2Headers(); + toEncode.add(name, value); + hpackEncoder.encodeHeaders(1, in, toEncode, NEVER_SENSITIVE); + + final Http2Headers decoded = new DefaultHttp2Headers(false); + hpackDecoder.decode(3, in, decoded, true); + + assertSame(AsciiString.EMPTY_STRING, decoded.get(name)); + } finally { + in.release(); + } + } +} diff --git 
a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDynamicTableTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDynamicTableTest.java new file mode 100644 index 0000000..56b0501 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackDynamicTableTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class HpackDynamicTableTest { + private static final AsciiString FOO = AsciiString.cached("foo"); + private static final AsciiString BAR = AsciiString.cached("bar"); + private static final AsciiString HELLO = AsciiString.cached("hello"); + private static final AsciiString WORLD = AsciiString.cached("world"); + + @Test + public void testLength() { + HpackDynamicTable table = new HpackDynamicTable(100); + assertEquals(0, table.length()); + HpackHeaderField entry = new HpackHeaderField(FOO, BAR); + table.add(entry); + assertEquals(1, table.length()); + table.clear(); + assertEquals(0, table.length()); + } + + @Test + public void testSize() { + HpackDynamicTable table = new HpackDynamicTable(100); + assertEquals(0, table.size()); + HpackHeaderField entry = new HpackHeaderField(FOO, BAR); + table.add(entry); + assertEquals(entry.size(), table.size()); + table.clear(); + assertEquals(0, table.size()); + } + + @Test + public void testGetEntry() { + final HpackDynamicTable table = new HpackDynamicTable(100); + HpackHeaderField entry = new HpackHeaderField(FOO, BAR); + table.add(entry); + assertEquals(entry, table.getEntry(1)); + table.clear(); + + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Throwable { + table.getEntry(1); + } + }); + } + + @Test + public void testGetEntryExceptionally() { + final HpackDynamicTable table = new HpackDynamicTable(1); + assertThrows(IndexOutOfBoundsException.class, new Executable() { + @Override + public void execute() throws Throwable { + table.getEntry(1); + } + }); + } + + @Test + public void testRemove() { + 
HpackDynamicTable table = new HpackDynamicTable(100); + assertNull(table.remove()); + HpackHeaderField entry1 = new HpackHeaderField(FOO, BAR); + HpackHeaderField entry2 = new HpackHeaderField(HELLO, WORLD); + table.add(entry1); + table.add(entry2); + assertEquals(entry1, table.remove()); + assertEquals(entry2, table.getEntry(1)); + assertEquals(1, table.length()); + assertEquals(entry2.size(), table.size()); + } + + @Test + public void testSetCapacity() { + HpackHeaderField entry1 = new HpackHeaderField(FOO, BAR); + HpackHeaderField entry2 = new HpackHeaderField(HELLO, WORLD); + final int size1 = entry1.size(); + final int size2 = entry2.size(); + HpackDynamicTable table = new HpackDynamicTable(size1 + size2); + table.add(entry1); + table.add(entry2); + assertEquals(2, table.length()); + assertEquals(size1 + size2, table.size()); + table.setCapacity(((long) size1 + size2) * 2); //larger capacity + assertEquals(2, table.length()); + assertEquals(size1 + size2, table.size()); + table.setCapacity(size2); //smaller capacity + //entry1 will be removed + assertEquals(1, table.length()); + assertEquals(size2, table.size()); + assertEquals(entry2, table.getEntry(1)); + table.setCapacity(0); //clear all + assertEquals(0, table.length()); + assertEquals(0, table.size()); + } + + @Test + public void testAdd() { + HpackDynamicTable table = new HpackDynamicTable(100); + assertEquals(0, table.size()); + HpackHeaderField entry1 = new HpackHeaderField(FOO, BAR); //size:3+3+32=38 + HpackHeaderField entry2 = new HpackHeaderField(HELLO, WORLD); + table.add(entry1); //success + assertEquals(entry1.size(), table.size()); + table.setCapacity(32); //entry1 is removed from table + assertEquals(0, table.size()); + assertEquals(0, table.length()); + table.add(entry1); //fail quietly + assertEquals(0, table.size()); + assertEquals(0, table.length()); + table.setCapacity(64); + table.add(entry1); //success + assertEquals(entry1.size(), table.size()); + assertEquals(1, table.length()); + 
table.add(entry2); //entry2 is added, but entry1 is removed from table + assertEquals(entry2.size(), table.size()); + assertEquals(1, table.length()); + assertEquals(entry2, table.getEntry(1)); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java new file mode 100644 index 0000000..2a54fbf --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackEncoderTest.java @@ -0,0 +1,281 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.handler.codec.http2;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.util.Random;

import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE;
import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.mockito.Mockito.mock;

/**
 * Tests for {@link HpackEncoder}: header-table size signalling, header-list
 * size limits, static/dynamic table usage, eviction, table resizing and value
 * sanitization. Expected wire encodings are spelled out as raw byte values and
 * verified by decoding them back with an {@link HpackDecoder}.
 */
public class HpackEncoderTest {
    // Decoder used to verify what the encoder produced.
    private HpackDecoder hpackDecoder;
    private HpackEncoder hpackEncoder;
    private Http2Headers mockHeaders;
    // Target buffer shared by each test; released in teardown().
    private ByteBuf buf;

    @BeforeEach
    public void setUp() {
        hpackEncoder = new HpackEncoder();
        hpackDecoder = new HpackDecoder(DEFAULT_HEADER_LIST_SIZE);
        mockHeaders = mock(Http2Headers.class);
        buf = Unpooled.buffer();
    }

    @AfterEach
    public void teardown() {
        buf.release();
    }

    @Test
    public void testSetMaxHeaderTableSizeToMaxValue() throws Http2Exception {
        // The dynamic-table-size update is written into buf and consumed by the decoder.
        hpackEncoder.setMaxHeaderTableSize(buf, MAX_HEADER_TABLE_SIZE);
        hpackDecoder.setMaxHeaderTableSize(MAX_HEADER_TABLE_SIZE);
        hpackDecoder.decode(0, buf, mockHeaders, true);
        assertEquals(MAX_HEADER_TABLE_SIZE, hpackDecoder.getMaxHeaderTableSize());
    }

    @Test
    public void testSetMaxHeaderTableSizeOverflow() {
        // One above the protocol maximum must be rejected.
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                hpackEncoder.setMaxHeaderTableSize(buf, MAX_HEADER_TABLE_SIZE + 1);
            }
        });
    }

    /**
     * The encoder should not impose an arbitrary limit on the header size if
     * the server has not specified any limit.
     *
     * @throws Http2Exception if encoding or decoding fails unexpectedly
     */
    @Test
    public void testWillEncode16MBHeaderByDefault() throws Http2Exception {
        String bigHeaderName = "x-big-header";
        int bigHeaderSize = 1024 * 1024 * 16;
        String bigHeaderVal = new String(new char[bigHeaderSize]).replace('\0', 'X');
        Http2Headers headersIn = new DefaultHttp2Headers().add(
                "x-big-header", bigHeaderVal);
        Http2Headers headersOut = new DefaultHttp2Headers();

        hpackEncoder.encodeHeaders(0, buf, headersIn, Http2HeadersEncoder.NEVER_SENSITIVE);
        hpackDecoder.setMaxHeaderListSize(bigHeaderSize + 1024);
        hpackDecoder.decode(0, buf, headersOut, false);
        assertEquals(headersOut.get(bigHeaderName).toString(), bigHeaderVal);
    }

    @Test
    public void testSetMaxHeaderListSizeEnforcedAfterSet() throws Http2Exception {
        final Http2Headers headers = new DefaultHttp2Headers().add(
                "x-big-header",
                new String(new char[1024 * 16]).replace('\0', 'X')
        );

        hpackEncoder.setMaxHeaderListSize(1000);

        // The 16 KiB header value exceeds the 1000-byte limit set above.
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                hpackEncoder.encodeHeaders(0, buf, headers, Http2HeadersEncoder.NEVER_SENSITIVE);
            }
        });
    }

    @Test
    public void testEncodeUsingBothStaticAndDynamicTable() throws Http2Exception {
        final Http2Headers headers = new DefaultHttp2Headers()
                // :method -> POST is found in the static table.
                .add(":method", "POST")

                // ":path" is found in the static table but only matches "/" and "/index.html".
                .add(":path", "/dev/null")

                // "accept-language" is found in the static table, but with no matching value.
                .add("accept-language", "fr")

                // k -> x is not in the static table.
                .add("k", "x");

        // :method -> POST gets encoded by reference.
        // :path -> /dev/null
        // :path gets encoded by reference, /dev/null literally.
        // accept-language -> fr
        // accept-language gets encoded by reference, fr literally.
        // k -> x
        // both k and x get encoded literally.
        verifyEncoding(headers,
                -125, 68, 9, 47, 100, 101, 118, 47, 110, 117, 108, 108, 81, 2, 102, 114, 64, 1, 107, 1, 120);

        // encoded using references to previous headers.
        verifyEncoding(headers, -125, -64, -65, -66);
    }

    @Test
    public void testSameHeaderNameMultipleValues() throws Http2Exception {
        final Http2Headers headers = new DefaultHttp2Headers()
                .add("k", "x")
                .add("k", "y");

        // k -> x encoded literally, k -> y encoded by referencing k of the k -> x header,
        // y gets encoded literally.
        verifyEncoding(headers, 64, 1, 107, 1, 120, 126, 1, 121);

        // both k -> x and k -> y encoded by reference.
        verifyEncoding(headers, -65, -66);
    }

    @Test
    public void testEviction() throws Http2Exception {
        // Table sized so that adding a second tiny entry evicts the first.
        setMaxTableSize(2 * HpackHeaderField.HEADER_ENTRY_OVERHEAD + 3);

        // k -> x encoded literally
        verifyEncoding(new DefaultHttp2Headers().add("k", "x"), 63, 36, 64, 1, 107, 1, 120);

        // k -> x encoded by referencing the previously encoded k -> x.
        verifyEncoding(new DefaultHttp2Headers().add("k", "x"), -66);

        // k -> x gets evicted
        verifyEncoding(new DefaultHttp2Headers().add("k", "y"), 64, 1, 107, 1, 121);

        // k -> x was evicted, so we are back to literal encoding.
        verifyEncoding(new DefaultHttp2Headers().add("k", "x"), 64, 1, 107, 1, 120);
    }

    @Test
    public void testTableResize() throws Http2Exception {
        verifyEncoding(new DefaultHttp2Headers().add("k", "x").add("k", "y"), 64, 1, 107, 1, 120, 126, 1, 121);

        // k -> x gets encoded by referencing the previously encoded k -> x.
        verifyEncoding(new DefaultHttp2Headers().add("k", "x"), -65);

        // k -> x gets evicted
        setMaxTableSize(2 * HpackHeaderField.HEADER_ENTRY_OVERHEAD + 3);

        // k -> x header was evicted, so we are back to literal encoding.
        verifyEncoding(new DefaultHttp2Headers().add("k", "x"), 63, 36, 64, 1, 107, 1, 120);

        // make room for k -> y
        setMaxTableSize(1000);

        verifyEncoding(new DefaultHttp2Headers().add("k", "y"), 63, -55, 7, 126, 1, 121);

        // both k -> x and k -> y are encoded by reference.
        verifyEncoding(new DefaultHttp2Headers().add("k", "x").add("k", "y"), -65, -66);
    }

    @Test
    public void testManyHeaderCombinations() throws Http2Exception {
        // Fixed seed keeps this stress test deterministic across runs.
        final Random r = new Random(0);
        for (int i = 0; i < 50000; i++) {
            if (r.nextInt(10) == 0) {
                setMaxTableSize(r.nextBoolean() ? 0 : r.nextInt(4096));
            }
            verifyRoundTrip(new DefaultHttp2Headers()
                    .add("k" + r.nextInt(20), "x" + r.nextInt(500))
                    .add(":method", r.nextBoolean() ? "GET" : "POST")
                    .add(":path", "/dev/null")
                    .add("accept-language", String.valueOf(r.nextBoolean()))
            );
            buf.clear();
        }
    }

    @Test
    public void testSanitization() throws Http2Exception {
        final int headerValueSize = 300;
        StringBuilder actualHeaderValueBuilder = new StringBuilder();
        StringBuilder expectedHeaderValueBuilder = new StringBuilder();

        for (int i = 0; i < headerValueSize; i++) {
            actualHeaderValueBuilder.append((char) i); // Use the index as the code point value of the character.
            if (i <= 255) {
                expectedHeaderValueBuilder.append((char) i);
            } else {
                expectedHeaderValueBuilder.append('?'); // Expect this character to be sanitized.
            }
        }
        String actualHeaderValue = actualHeaderValueBuilder.toString();
        String expectedHeaderValue = expectedHeaderValueBuilder.toString();
        HpackEncoder encoderWithHuffmanEncoding =
                new HpackEncoder(false, 64, 0); // Low Huffman code threshold.
        HpackEncoder encoderWithoutHuffmanEncoding =
                new HpackEncoder(false, 64, Integer.MAX_VALUE); // High Huffman code threshold.

        // Expect the same decoded header value regardless of whether Huffman encoding is enabled or not.
        verifyHeaderValueSanitization(encoderWithHuffmanEncoding, actualHeaderValue, expectedHeaderValue);
        verifyHeaderValueSanitization(encoderWithoutHuffmanEncoding, actualHeaderValue, expectedHeaderValue);
    }

    // Encodes actualHeaderValue with the given encoder and asserts that the
    // decoder observes expectedHeaderValue.
    private void verifyHeaderValueSanitization(
            HpackEncoder encoder,
            String actualHeaderValue,
            String expectedHeaderValue
    ) throws Http2Exception {

        String headerKey = "some-key";
        Http2Headers toBeEncodedHeaders = new DefaultHttp2Headers().add(headerKey, actualHeaderValue);
        encoder.encodeHeaders(0, buf, toBeEncodedHeaders, Http2HeadersEncoder.NEVER_SENSITIVE);
        DefaultHttp2Headers decodedHeaders = new DefaultHttp2Headers();
        hpackDecoder.decode(0, buf, decodedHeaders, true);
        buf.clear();
        String decodedHeaderValue = decodedHeaders.get(headerKey).toString();
        Assertions.assertEquals(expectedHeaderValue, decodedHeaderValue);
    }

    // Applies the same table size to encoder and decoder so both sides stay in sync.
    private void setMaxTableSize(int maxHeaderTableSize) throws Http2Exception {
        hpackEncoder.setMaxHeaderTableSize(buf, maxHeaderTableSize);
        hpackDecoder.setMaxHeaderTableSize(maxHeaderTableSize);
    }

    // Round-trips the headers and asserts the exact wire bytes produced.
    private void verifyEncoding(Http2Headers encodedHeaders, int... encoding) throws Http2Exception {
        verifyRoundTrip(encodedHeaders);
        verifyEncodedBytes(encoding);
        buf.clear();
    }

    private void verifyRoundTrip(Http2Headers encodedHeaders) throws Http2Exception {
        hpackEncoder.encodeHeaders(0, buf, encodedHeaders, Http2HeadersEncoder.NEVER_SENSITIVE);
        DefaultHttp2Headers decodedHeaders = new DefaultHttp2Headers();
        hpackDecoder.decode(0, buf, decodedHeaders, true);
        assertEquals(encodedHeaders, decodedHeaders);
    }

    private void verifyEncodedBytes(int... expectedEncoding) {
        // We want to copy everything that was written to the buffer.
        byte[] actualEncoding = new byte[buf.writerIndex()];
        buf.getBytes(0, actualEncoding);
        Assertions.assertArrayEquals(toByteArray(expectedEncoding), actualEncoding);
    }

    // int[] literals are easier to write in tests (they allow values > 127);
    // narrow each element to a byte here.
    private byte[] toByteArray(int[] encoding) {
        byte[] expectedEncoding = new byte[encoding.length];
        for (int i = 0; i < encoding.length; i++) {
            expectedEncoding[i] = (byte) encoding[i];
        }
        return expectedEncoding;
    }
}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackHuffmanTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackHuffmanTest.java
new file mode 100644
index 0000000..8b4bddc
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackHuffmanTest.java
@@ -0,0 +1,247 @@
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

/*
 * Copyright 2014 Twitter, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package io.netty.handler.codec.http2;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.util.AsciiString;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.util.Arrays;
import java.util.Random;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Tests for the HPACK Huffman coder ({@link HpackHuffmanEncoder} /
 * {@link HpackHuffmanDecoder}): encode/decode round-trips, rejection of
 * EOS and over-long or invalid padding sequences, and sanitization of
 * multi-byte characters.
 */
public class HpackHuffmanTest {

    @Test
    public void testHuffman() throws Http2Exception {
        // Round-trip every prefix of an alphanumeric string...
        String s = "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789";
        for (int i = 0; i < s.length(); i++) {
            roundTrip(s.substring(0, i));
        }

        // ... and a fixed-seed blob of random bytes.
        Random random = new Random(123456789L);
        byte[] buf = new byte[4096];
        random.nextBytes(buf);
        roundTrip(buf);
    }

    @Test
    public void testDecodeEOS() throws Http2Exception {
        // 32 one-bits: input containing (a prefix of) the EOS symbol must be rejected.
        final byte[] buf = new byte[4];
        for (int i = 0; i < 4; i++) {
            buf[i] = (byte) 0xFF;
        }
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeIllegalPadding() throws Http2Exception {
        final byte[] buf = new byte[1];
        buf[0] = 0x00; // '0', invalid padding
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeExtraPadding() throws Http2Exception {
        final byte[] buf = makeBuf(0x0f, 0xFF); // '1', 'EOS'
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeExtraPadding1byte() throws Http2Exception {
        // A full byte of padding and no symbol at all.
        final byte[] buf = makeBuf(0xFF);
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeExtraPadding2byte() throws Http2Exception {
        final byte[] buf = makeBuf(0x1F, 0xFF); // 'a'
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeExtraPadding3byte() throws Http2Exception {
        final byte[] buf = makeBuf(0x1F, 0xFF, 0xFF); // 'a'
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeExtraPadding4byte() throws Http2Exception {
        final byte[] buf = makeBuf(0x1F, 0xFF, 0xFF, 0xFF); // 'a'
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodeExtraPadding29bit() throws Http2Exception {
        final byte[] buf = makeBuf(0xFF, 0x9F, 0xFF, 0xFF, 0xFF); // '|'
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testDecodePartialSymbol() throws Http2Exception {
        final byte[] buf =
                makeBuf(0x52, 0xBC, 0x30, 0xFF, 0xFF, 0xFF, 0xFF); // " pFA\x00", 31 bits of padding, a.k.a. EOS
        assertThrows(Http2Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                decode(buf);
            }
        });
    }

    @Test
    public void testEncoderSanitizingMultiByteCharacters() throws Http2Exception {
        final int inputLen = 500;
        StringBuilder sb = new StringBuilder();
        for (int i = 0; i < inputLen; i++) {
            // Starts with 0x4E01 because certain suboptimal sanitization could cause some problem with this input.
            // For example, if a multibyte character C is sanitized by doing (C & OxFF), if C == 0x4E01, then
            // (0x4E01 & OxFF) is greater than zero which indicates insufficient sanitization.
            sb.append((char) (0x4E01 + i));
        }
        HpackHuffmanEncoder encoder = new HpackHuffmanEncoder();
        String toBeEncoded = sb.toString();
        ByteBuf buffer = Unpooled.buffer();
        byte[] bytes;
        try {
            encoder.encode(buffer, toBeEncoded);
            bytes = new byte[buffer.readableBytes()];
            buffer.readBytes(bytes);
        } finally {
            buffer.release(); // Release as soon as possible.
        }
        byte[] actualBytes = decode(bytes);
        String actualDecoded = new String(actualBytes);
        char[] charArray = new char[inputLen];
        Arrays.fill(charArray, '?');
        String expectedDecoded = new String(charArray);
        assertEquals(
                expectedDecoded,
                actualDecoded,
                "Expect the decoded string to be sanitized and contains only '?' characters."
        );
    }

    // Builds a byte[] from int literals (ints allow values > 127 without casts).
    private static byte[] makeBuf(int ... bytes) {
        byte[] buf = new byte[bytes.length];
        for (int i = 0; i < buf.length; i++) {
            buf[i] = (byte) bytes[i];
        }
        return buf;
    }

    private static void roundTrip(String s) throws Http2Exception {
        roundTrip(new HpackHuffmanEncoder(), s);
    }

    private static void roundTrip(HpackHuffmanEncoder encoder, String s)
            throws Http2Exception {
        roundTrip(encoder, s.getBytes());
    }

    private static void roundTrip(byte[] buf) throws Http2Exception {
        roundTrip(new HpackHuffmanEncoder(), buf);
    }

    // Encodes buf, decodes the result and asserts the original bytes come back.
    private static void roundTrip(HpackHuffmanEncoder encoder, byte[] buf)
            throws Http2Exception {
        ByteBuf buffer = Unpooled.buffer();
        try {
            encoder.encode(buffer, new AsciiString(buf, false));
            byte[] bytes = new byte[buffer.readableBytes()];
            buffer.readBytes(bytes);

            byte[] actualBytes = decode(bytes);

            assertArrayEquals(buf, actualBytes);
        } finally {
            buffer.release();
        }
    }

    // Decodes the whole input; also asserts the decoder consumed every byte.
    private static byte[] decode(byte[] bytes) throws Http2Exception {
        ByteBuf buffer = Unpooled.wrappedBuffer(bytes);
        try {
            AsciiString decoded = new HpackHuffmanDecoder().decode(buffer, buffer.readableBytes());
            assertFalse(buffer.isReadable());
            return decoded.toByteArray();
        } finally {
            buffer.release();
        }
    }
}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackStaticTableTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackStaticTableTest.java
new file mode 100644
index 0000000..6f238f3
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackStaticTableTest.java
@@ -0,0 +1,76 @@
/*
 * Copyright 2022 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class HpackStaticTableTest { + + @Test + public void testEmptyHeaderName() { + assertEquals(-1, HpackStaticTable.getIndex("")); + } + + @Test + public void testMissingHeaderName() { + assertEquals(-1, HpackStaticTable.getIndex("missing")); + } + + @Test + public void testExistingHeaderName() { + assertEquals(6, HpackStaticTable.getIndex(":scheme")); + } + + @Test + public void testMissingHeaderNameAndValue() { + assertEquals(-1, HpackStaticTable.getIndexInsensitive("missing", "value")); + } + + @Test + public void testMissingHeaderNameButValueExists() { + assertEquals(-1, HpackStaticTable.getIndexInsensitive("missing", "https")); + } + + @Test + public void testExistingHeaderNameAndValueFirstMatch() { + assertEquals(6, HpackStaticTable.getIndexInsensitive(":scheme", "http")); + } + + @Test + public void testExistingHeaderNameAndValueSecondMatch() { + assertEquals(7, HpackStaticTable.getIndexInsensitive( + AsciiString.cached(":scheme"), AsciiString.cached("https"))); + } + + @Test + public void testExistingHeaderNameAndEmptyValueMismatch() { + assertEquals(-1, HpackStaticTable.getIndexInsensitive(":scheme", "")); + } + + @Test + public void testExistingHeaderNameAndEmptyValueMatch() { + assertEquals(27, HpackStaticTable.getIndexInsensitive("content-language", "")); + } + + @Test + public void testExistingHeaderNameButMissingValue() { + assertEquals(-1, HpackStaticTable.getIndexInsensitive(":scheme", "missing")); + } + +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java new file mode 100644 index 0000000..342fbef --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTest.java @@ -0,0 +1,61 @@ +/* + * 
Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.ResourcesUtil; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.io.File; +import java.io.InputStream; + +public class HpackTest { + + private static final String TEST_DIR = '/' + HpackTest.class.getPackage().getName().replaceAll("\\.", "/") + + "/testdata/"; + + public static File[] files() { + File[] files = ResourcesUtil.getFile(HpackTest.class, TEST_DIR).listFiles(); + ObjectUtil.checkNotNull(files, "files"); + return files; + } + + @ParameterizedTest(name = "file = {0}") + @MethodSource("files") + public void test(File file) throws Exception { + InputStream is = HpackTest.class.getResourceAsStream(TEST_DIR + file.getName()); + HpackTestCase hpackTestCase = HpackTestCase.load(is); + hpackTestCase.testCompress(); + hpackTestCase.testDecompress(); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTestCase.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTestCase.java new file mode 100644 index 0000000..6d230ca --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HpackTestCase.java @@ -0,0 +1,290 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.netty.handler.codec.http2; + +import com.google.gson.FieldNamingPolicy; +import com.google.gson.Gson; +import com.google.gson.GsonBuilder; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.internal.StringUtil; + +import java.io.InputStream; +import java.io.InputStreamReader; +import java.lang.reflect.Type; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; +import java.util.Set; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2TestUtil.newTestEncoder; + +final class HpackTestCase { + + private static final Gson GSON = new GsonBuilder() + .setFieldNamingPolicy(FieldNamingPolicy.LOWER_CASE_WITH_UNDERSCORES) + .registerTypeAdapter(HpackHeaderField.class, new HeaderFieldDeserializer()) + .create(); + + int maxHeaderTableSize = -1; + boolean sensitiveHeaders; + + List headerBlocks; + + private HpackTestCase() { + } + + static HpackTestCase load(InputStream is) { + InputStreamReader r = new 
InputStreamReader(is); + HpackTestCase hpackTestCase = GSON.fromJson(r, HpackTestCase.class); + for (HeaderBlock headerBlock : hpackTestCase.headerBlocks) { + headerBlock.encodedBytes = StringUtil.decodeHexDump(headerBlock.getEncodedStr()); + } + return hpackTestCase; + } + + void testCompress() throws Exception { + HpackEncoder hpackEncoder = createEncoder(); + + for (HeaderBlock headerBlock : headerBlocks) { + + byte[] actual = + encode(hpackEncoder, headerBlock.getHeaders(), headerBlock.getMaxHeaderTableSize(), + sensitiveHeaders); + + if (!Arrays.equals(actual, headerBlock.encodedBytes)) { + throw new AssertionError( + "\nEXPECTED:\n" + headerBlock.getEncodedStr() + + "\nACTUAL:\n" + StringUtil.toHexString(actual)); + } + + List actualDynamicTable = new ArrayList(); + for (int index = 0; index < hpackEncoder.length(); index++) { + actualDynamicTable.add(hpackEncoder.getHeaderField(index)); + } + + List expectedDynamicTable = headerBlock.getDynamicTable(); + + if (!headersEqual(expectedDynamicTable, actualDynamicTable)) { + throw new AssertionError( + "\nEXPECTED DYNAMIC TABLE:\n" + expectedDynamicTable + + "\nACTUAL DYNAMIC TABLE:\n" + actualDynamicTable); + } + + if (headerBlock.getTableSize() != hpackEncoder.size()) { + throw new AssertionError( + "\nEXPECTED TABLE SIZE: " + headerBlock.getTableSize() + + "\n ACTUAL TABLE SIZE : " + hpackEncoder.size()); + } + } + } + + void testDecompress() throws Exception { + HpackDecoder hpackDecoder = createDecoder(); + + for (HeaderBlock headerBlock : headerBlocks) { + + List actualHeaders = decode(hpackDecoder, headerBlock.encodedBytes); + + List expectedHeaders = new ArrayList(); + for (HpackHeaderField h : headerBlock.getHeaders()) { + expectedHeaders.add(new HpackHeaderField(h.name, h.value)); + } + + if (!headersEqual(expectedHeaders, actualHeaders)) { + throw new AssertionError( + "\nEXPECTED:\n" + expectedHeaders + + "\nACTUAL:\n" + actualHeaders); + } + + List actualDynamicTable = new ArrayList(); + for (int 
index = 0; index < hpackDecoder.length(); index++) { + actualDynamicTable.add(hpackDecoder.getHeaderField(index)); + } + + List expectedDynamicTable = headerBlock.getDynamicTable(); + + if (!headersEqual(expectedDynamicTable, actualDynamicTable)) { + throw new AssertionError( + "\nEXPECTED DYNAMIC TABLE:\n" + expectedDynamicTable + + "\nACTUAL DYNAMIC TABLE:\n" + actualDynamicTable); + } + + if (headerBlock.getTableSize() != hpackDecoder.size()) { + throw new AssertionError( + "\nEXPECTED TABLE SIZE: " + headerBlock.getTableSize() + + "\n ACTUAL TABLE SIZE : " + hpackDecoder.size()); + } + } + } + + private HpackEncoder createEncoder() { + int maxHeaderTableSize = this.maxHeaderTableSize; + if (maxHeaderTableSize == -1) { + maxHeaderTableSize = Integer.MAX_VALUE; + } + + try { + return newTestEncoder(true, MAX_HEADER_LIST_SIZE, maxHeaderTableSize); + } catch (Http2Exception e) { + throw new Error("invalid initial values!", e); + } + } + + private HpackDecoder createDecoder() { + int maxHeaderTableSize = this.maxHeaderTableSize; + if (maxHeaderTableSize == -1) { + maxHeaderTableSize = Integer.MAX_VALUE; + } + + return new HpackDecoder(DEFAULT_HEADER_LIST_SIZE, maxHeaderTableSize); + } + + private static byte[] encode(HpackEncoder hpackEncoder, List headers, int maxHeaderTableSize, + final boolean sensitive) throws Http2Exception { + Http2Headers http2Headers = toHttp2Headers(headers); + Http2HeadersEncoder.SensitivityDetector sensitivityDetector = new Http2HeadersEncoder.SensitivityDetector() { + @Override + public boolean isSensitive(CharSequence name, CharSequence value) { + return sensitive; + } + }; + ByteBuf buffer = Unpooled.buffer(); + try { + if (maxHeaderTableSize != -1) { + hpackEncoder.setMaxHeaderTableSize(buffer, maxHeaderTableSize); + } + + hpackEncoder.encodeHeaders(3 /* randomly chosen */, buffer, http2Headers, sensitivityDetector); + byte[] bytes = new byte[buffer.readableBytes()]; + buffer.readBytes(bytes); + return bytes; + } finally { + 
buffer.release(); + } + } + + private static Http2Headers toHttp2Headers(List inHeaders) { + Http2Headers headers = new DefaultHttp2Headers(false); + for (HpackHeaderField e : inHeaders) { + headers.add(e.name, e.value); + } + return headers; + } + + private static List decode(HpackDecoder hpackDecoder, byte[] expected) throws Exception { + ByteBuf in = Unpooled.wrappedBuffer(expected); + try { + List headers = new ArrayList(); + TestHeaderListener listener = new TestHeaderListener(headers); + hpackDecoder.decode(0, in, listener, true); + return headers; + } finally { + in.release(); + } + } + + private static String concat(List l) { + StringBuilder ret = new StringBuilder(); + for (String s : l) { + ret.append(s); + } + return ret.toString(); + } + + private static boolean headersEqual(List expected, List actual) { + if (expected.size() != actual.size()) { + return false; + } + for (int i = 0; i < expected.size(); i++) { + if (!expected.get(i).equalsForTest(actual.get(i))) { + return false; + } + } + return true; + } + + static class HeaderBlock { + @SuppressWarnings("FieldMayBeFinal") + private int maxHeaderTableSize = -1; + private byte[] encodedBytes; + private List encoded; + private List headers; + private List dynamicTable; + private int tableSize; + + private int getMaxHeaderTableSize() { + return maxHeaderTableSize; + } + + public String getEncodedStr() { + return concat(encoded).replaceAll(" ", ""); + } + + public List getHeaders() { + return headers; + } + + public List getDynamicTable() { + return dynamicTable; + } + + public int getTableSize() { + return tableSize; + } + } + + static class HeaderFieldDeserializer implements JsonDeserializer { + + @Override + public HpackHeaderField deserialize(JsonElement json, Type typeOfT, + JsonDeserializationContext context) { + JsonObject jsonObject = json.getAsJsonObject(); + Set> entrySet = jsonObject.entrySet(); + if (entrySet.size() != 1) { + throw new JsonParseException("JSON Object has multiple entries: " + 
entrySet); + } + Map.Entry entry = entrySet.iterator().next(); + String name = entry.getKey(); + String value = entry.getValue().getAsString(); + return new HpackHeaderField(name, value); + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java new file mode 100644 index 0000000..9b9999c --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ClientUpgradeCodecTest.java @@ -0,0 +1,86 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class Http2ClientUpgradeCodecTest { + + @Test + public void testUpgradeToHttp2ConnectionHandler() throws Exception { + testUpgrade(new Http2ConnectionHandlerBuilder().server(false).frameListener( + new Http2FrameAdapter()).build(), null); + } + + @Test + public void testUpgradeToHttp2FrameCodec() throws Exception { + testUpgrade(Http2FrameCodecBuilder.forClient().build(), null); + } + + @Test + public void testUpgradeToHttp2MultiplexCodec() throws Exception { + testUpgrade(Http2MultiplexCodecBuilder.forClient(new HttpInboundHandler()) + .withUpgradeStreamHandler(new ChannelInboundHandlerAdapter()).build(), null); + } + + @Test + public void testUpgradeToHttp2FrameCodecWithMultiplexer() throws Exception { + testUpgrade(Http2FrameCodecBuilder.forClient().build(), + new Http2MultiplexHandler(new HttpInboundHandler(), new HttpInboundHandler())); + } + + private static void testUpgrade(Http2ConnectionHandler handler, Http2MultiplexHandler multiplexer) + throws Exception { + FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "*"); + + EmbeddedChannel channel = new EmbeddedChannel(new ChannelInboundHandlerAdapter()); + ChannelHandlerContext ctx = channel.pipeline().firstContext(); + + Http2ClientUpgradeCodec codec; + + if (multiplexer == null) { + codec = new Http2ClientUpgradeCodec("connectionHandler", handler); + } 
else { + codec = new Http2ClientUpgradeCodec("connectionHandler", handler, multiplexer); + } + + codec.setUpgradeHeaders(ctx, request); + // Flush the channel to ensure we write out all buffered data + channel.flush(); + + codec.upgradeTo(ctx, null); + assertNotNull(channel.pipeline().get("connectionHandler")); + + if (multiplexer != null) { + assertNotNull(channel.pipeline().get(Http2MultiplexHandler.class)); + } + + assertTrue(channel.finishAndReleaseAll()); + } + + @ChannelHandler.Sharable + private static final class HttpInboundHandler extends ChannelInboundHandlerAdapter { } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java new file mode 100644 index 0000000..9d5a1c4 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionHandlerTest.java @@ -0,0 +1,884 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelConfig; +import io.netty.channel.DefaultChannelPromise; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http2.Http2CodecUtil.SimpleChannelPromiseAggregator; +import io.netty.handler.codec.http2.Http2Exception.ShutdownHint; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; +import org.mockito.ArgumentCaptor; +import org.mockito.ArgumentMatchers; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.buffer.Unpooled.copiedBuffer; +import static io.netty.handler.codec.http2.Http2CodecUtil.connectionPrefaceBuf; +import static io.netty.handler.codec.http2.Http2Error.CANCEL; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static 
io.netty.handler.codec.http2.Http2Error.STREAM_CLOSED; +import static io.netty.handler.codec.http2.Http2Stream.State.CLOSED; +import static io.netty.handler.codec.http2.Http2Stream.State.IDLE; +import static io.netty.handler.codec.http2.Http2TestUtil.newVoidPromise; +import static io.netty.util.CharsetUtil.US_ASCII; +import static io.netty.util.CharsetUtil.UTF_8; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; +import static org.mockito.Mockito.verifyZeroInteractions; +import static org.mockito.Mockito.when; + +/** + * Tests for {@link Http2ConnectionHandler} + */ +public class Http2ConnectionHandlerTest { + private static final int STREAM_ID = 1; + private static final int NON_EXISTANT_STREAM_ID = 13; + + private Http2ConnectionHandler handler; + private ChannelPromise promise; + private ChannelPromise voidPromise; + + @Mock + private Http2Connection connection; + + @Mock + private Http2RemoteFlowController remoteFlow; + + @Mock + private Http2LocalFlowController localFlow; + + @Mock + private Http2Connection.Endpoint remote; + + @Mock + private Http2RemoteFlowController remoteFlowController; + + @Mock + private Http2Connection.Endpoint local; + + @Mock + private Http2LocalFlowController localFlowController; + + @Mock + private 
ChannelHandlerContext ctx; + + @Mock + private EventExecutor executor; + + @Mock + private Channel channel; + + @Mock + private ChannelPipeline pipeline; + + @Mock + private ChannelFuture future; + + @Mock + private Http2Stream stream; + + @Mock + private Http2ConnectionDecoder decoder; + + @Mock + private Http2ConnectionEncoder encoder; + + @Mock + private Http2FrameWriter frameWriter; + + private String goAwayDebugCap; + + @SuppressWarnings("unchecked") + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + promise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + voidPromise = new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + + when(channel.metadata()).thenReturn(new ChannelMetadata(false)); + DefaultChannelConfig config = new DefaultChannelConfig(channel); + when(channel.config()).thenReturn(config); + + Throwable fakeException = new RuntimeException("Fake exception"); + when(encoder.connection()).thenReturn(connection); + when(decoder.connection()).thenReturn(connection); + when(encoder.frameWriter()).thenReturn(frameWriter); + when(encoder.flowController()).thenReturn(remoteFlow); + when(decoder.flowController()).thenReturn(localFlow); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + ByteBuf buf = invocation.getArgument(3); + goAwayDebugCap = buf.toString(UTF_8); + buf.release(); + return future; + } + }).when(frameWriter).writeGoAway( + any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), any(ChannelPromise.class)); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + Object o = invocation.getArguments()[0]; + if (o instanceof ChannelFutureListener) { + ((ChannelFutureListener) o).operationComplete(future); + } + return future; + } + }).when(future).addListener(any(GenericFutureListener.class)); + 
when(future.cause()).thenReturn(fakeException); + when(future.channel()).thenReturn(channel); + when(channel.isActive()).thenReturn(true); + when(channel.pipeline()).thenReturn(pipeline); + when(connection.remote()).thenReturn(remote); + when(remote.flowController()).thenReturn(remoteFlowController); + when(connection.local()).thenReturn(local); + when(local.flowController()).thenReturn(localFlowController); + doAnswer(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock in) throws Throwable { + Http2StreamVisitor visitor = in.getArgument(0); + if (!visitor.visit(stream)) { + return stream; + } + return null; + } + }).when(connection).forEachActiveStream(any(Http2StreamVisitor.class)); + when(connection.stream(NON_EXISTANT_STREAM_ID)).thenReturn(null); + when(connection.numActiveStreams()).thenReturn(1); + when(connection.stream(STREAM_ID)).thenReturn(stream); + when(connection.goAwaySent(anyInt(), anyLong(), any(ByteBuf.class))).thenReturn(true); + when(stream.open(anyBoolean())).thenReturn(stream); + when(encoder.writeSettings(eq(ctx), any(Http2Settings.class), eq(promise))).thenReturn(future); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(ctx.channel()).thenReturn(channel); + when(ctx.newSucceededFuture()).thenReturn(future); + when(ctx.newPromise()).thenReturn(promise); + when(ctx.voidPromise()).thenReturn(voidPromise); + when(ctx.write(any())).thenReturn(future); + when(ctx.executor()).thenReturn(executor); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock in) throws Throwable { + Object msg = in.getArgument(0); + ReferenceCountUtil.release(msg); + return null; + } + }).when(ctx).fireChannelRead(any()); + } + + private Http2ConnectionHandler newHandler(boolean flushPreface) throws Exception { + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder().codec(decoder, encoder) + .flushPreface(flushPreface).build(); + handler.handlerAdded(ctx); + return handler; + } + + private 
Http2ConnectionHandler newHandler() throws Exception { + return newHandler(true); + } + + @AfterEach + public void tearDown() throws Exception { + if (handler != null) { + handler.handlerRemoved(ctx); + } + } + + @Test + public void onHttpServerUpgradeWithoutHandlerAdded() throws Exception { + handler = new Http2ConnectionHandlerBuilder().frameListener(new Http2FrameAdapter()).server(true).build(); + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + handler.onHttpServerUpgrade(new Http2Settings()); + } + }); + assertEquals(Http2Error.INTERNAL_ERROR, e.error()); + } + + @Test + public void onHttpClientUpgradeWithoutHandlerAdded() throws Exception { + handler = new Http2ConnectionHandlerBuilder().frameListener(new Http2FrameAdapter()).server(false).build(); + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + handler.onHttpClientUpgrade(); + } + }); + assertEquals(Http2Error.INTERNAL_ERROR, e.error()); + } + + @ParameterizedTest + @ValueSource(booleans = { true, false }) + public void clientShouldveSentPrefaceAndSettingsFrameWhenUserEventIsTriggered(boolean flushPreface) + throws Exception { + when(connection.isServer()).thenReturn(false); + when(channel.isActive()).thenReturn(false); + handler = newHandler(flushPreface); + when(channel.isActive()).thenReturn(true); + + final Http2ConnectionPrefaceAndSettingsFrameWrittenEvent evt = + Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE; + + final AtomicBoolean verified = new AtomicBoolean(false); + final Answer verifier = new Answer() { + @Override + public Object answer(final InvocationOnMock in) throws Throwable { + assertEquals(in.getArgument(0), evt); // sanity check... 
+ verify(ctx).write(eq(connectionPrefaceBuf())); + verify(encoder).writeSettings(eq(ctx), any(Http2Settings.class), any(ChannelPromise.class)); + verified.set(true); + return null; + } + }; + + doAnswer(verifier).when(ctx).fireUserEventTriggered(evt); + + handler.channelActive(ctx); + if (flushPreface) { + verify(ctx, times(1)).flush(); + } else { + verify(ctx, never()).flush(); + } + assertTrue(verified.get()); + } + + @Test + public void clientShouldSendClientPrefaceStringWhenActive() throws Exception { + when(connection.isServer()).thenReturn(false); + when(channel.isActive()).thenReturn(false); + handler = newHandler(); + when(channel.isActive()).thenReturn(true); + handler.channelActive(ctx); + verify(ctx).write(eq(connectionPrefaceBuf())); + } + + @Test + public void serverShouldNotSendClientPrefaceStringWhenActive() throws Exception { + when(connection.isServer()).thenReturn(true); + when(channel.isActive()).thenReturn(false); + handler = newHandler(); + when(channel.isActive()).thenReturn(true); + handler.channelActive(ctx); + verify(ctx, never()).write(eq(connectionPrefaceBuf())); + } + + @Test + public void serverReceivingInvalidClientPrefaceStringShouldHandleException() throws Exception { + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + handler.channelRead(ctx, copiedBuffer("BAD_PREFACE", UTF_8)); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), eq(PROTOCOL_ERROR.code()), + captor.capture(), eq(promise)); + assertEquals(0, captor.getValue().refCnt()); + } + + @Test + public void serverReceivingHttp1ClientPrefaceStringShouldIncludePreface() throws Exception { + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + handler.channelRead(ctx, copiedBuffer("GET /path HTTP/1.1", US_ASCII)); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), 
eq(PROTOCOL_ERROR.code()), + captor.capture(), eq(promise)); + assertEquals(0, captor.getValue().refCnt()); + assertTrue(goAwayDebugCap.contains("/path")); + } + + @Test + public void serverReceivingClientPrefaceStringFollowedByNonSettingsShouldHandleException() + throws Exception { + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + + // Create a connection preface followed by a bunch of zeros (i.e. not a settings frame). + ByteBuf buf = Unpooled.buffer().writeBytes(connectionPrefaceBuf()).writeZero(10); + handler.channelRead(ctx, buf); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter, atLeastOnce()).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), eq(PROTOCOL_ERROR.code()), + captor.capture(), eq(promise)); + assertEquals(0, captor.getValue().refCnt()); + } + + @Test + public void serverReceivingValidClientPrefaceStringShouldContinueReadingFrames() throws Exception { + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + ByteBuf prefacePlusSome = addSettingsHeader(Unpooled.buffer().writeBytes(connectionPrefaceBuf())); + handler.channelRead(ctx, prefacePlusSome); + verify(decoder, atLeastOnce()).decodeFrame(any(ChannelHandlerContext.class), + any(ByteBuf.class), ArgumentMatchers.>any()); + } + + @Test + public void verifyChannelHandlerCanBeReusedInPipeline() throws Exception { + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + // Only read the connection preface...after preface is read internal state of Http2ConnectionHandler + // is expected to change relative to the pipeline. + ByteBuf preface = connectionPrefaceBuf(); + handler.channelRead(ctx, preface); + verify(decoder, never()).decodeFrame(any(ChannelHandlerContext.class), + any(ByteBuf.class), ArgumentMatchers.>any()); + + // Now remove and add the handler...this is setting up the test condition. 
+ handler.handlerRemoved(ctx); + handler.handlerAdded(ctx); + + // Now verify we can continue as normal, reading connection preface plus more. + ByteBuf prefacePlusSome = addSettingsHeader(Unpooled.buffer().writeBytes(connectionPrefaceBuf())); + handler.channelRead(ctx, prefacePlusSome); + verify(decoder, atLeastOnce()).decodeFrame(eq(ctx), any(ByteBuf.class), ArgumentMatchers.>any()); + } + + @SuppressWarnings("unchecked") + @Test + public void channelInactiveShouldCloseStreams() throws Exception { + handler = newHandler(); + handler.channelInactive(ctx); + verify(connection).close(any(Promise.class)); + } + + @Test + public void connectionErrorShouldStartShutdown() throws Exception { + handler = newHandler(); + Http2Exception e = new Http2Exception(PROTOCOL_ERROR); + // There's no guarantee that lastStreamCreated in correct, as the error could have occurred during header + // processing before it was updated. Thus, it should _not_ be used for the GOAWAY. + // https://github.com/netty/netty/issues/10670 + when(remote.lastStreamCreated()).thenReturn(STREAM_ID); + handler.exceptionCaught(ctx, e); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), eq(PROTOCOL_ERROR.code()), + captor.capture(), eq(promise)); + captor.getValue().release(); + } + + @Test + public void serverShouldSend431OnHeaderSizeErrorWhenDecodingInitialHeaders() throws Exception { + int padding = 0; + handler = newHandler(); + Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, PROTOCOL_ERROR, + "Header size exceeded max allowed size 8196", true); + + when(stream.id()).thenReturn(STREAM_ID); + when(connection.isServer()).thenReturn(true); + when(stream.isHeadersSent()).thenReturn(false); + when(remote.lastStreamCreated()).thenReturn(STREAM_ID); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), + eq(PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future); + + handler.exceptionCaught(ctx, e); 
+ + ArgumentCaptor captor = ArgumentCaptor.forClass(Http2Headers.class); + verify(encoder).writeHeaders(eq(ctx), eq(STREAM_ID), + captor.capture(), eq(padding), eq(true), eq(promise)); + Http2Headers headers = captor.getValue(); + assertEquals(HttpResponseStatus.REQUEST_HEADER_FIELDS_TOO_LARGE.codeAsText(), headers.status()); + verify(frameWriter).writeRstStream(ctx, STREAM_ID, PROTOCOL_ERROR.code(), promise); + } + + @Test + public void serverShouldNeverSend431HeaderSizeErrorWhenEncoding() throws Exception { + int padding = 0; + handler = newHandler(); + Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, PROTOCOL_ERROR, + "Header size exceeded max allowed size 8196", false); + + when(stream.id()).thenReturn(STREAM_ID); + when(connection.isServer()).thenReturn(true); + when(stream.isHeadersSent()).thenReturn(false); + when(remote.lastStreamCreated()).thenReturn(STREAM_ID); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), + eq(PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future); + + handler.exceptionCaught(ctx, e); + + verify(encoder, never()).writeHeaders(eq(ctx), eq(STREAM_ID), + any(Http2Headers.class), eq(padding), eq(true), eq(promise)); + verify(frameWriter).writeRstStream(ctx, STREAM_ID, PROTOCOL_ERROR.code(), promise); + } + + @Test + public void clientShouldNeverSend431WhenHeadersAreTooLarge() throws Exception { + int padding = 0; + handler = newHandler(); + Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, PROTOCOL_ERROR, + "Header size exceeded max allowed size 8196", true); + + when(stream.id()).thenReturn(STREAM_ID); + when(connection.isServer()).thenReturn(false); + when(stream.isHeadersSent()).thenReturn(false); + when(remote.lastStreamCreated()).thenReturn(STREAM_ID); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), + eq(PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future); + + handler.exceptionCaught(ctx, e); + + verify(encoder, never()).writeHeaders(eq(ctx), eq(STREAM_ID), + 
any(Http2Headers.class), eq(padding), eq(true), eq(promise)); + verify(frameWriter).writeRstStream(ctx, STREAM_ID, PROTOCOL_ERROR.code(), promise); + } + + @Test + public void prefaceUserEventProcessed() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + handler = new Http2ConnectionHandler(decoder, encoder, new Http2Settings()) { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE) { + latch.countDown(); + } + } + }; + handler.handlerAdded(ctx); + assertTrue(latch.await(5, SECONDS)); + } + + @Test + public void serverShouldNeverSend431IfHeadersAlreadySent() throws Exception { + int padding = 0; + handler = newHandler(); + Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, PROTOCOL_ERROR, + "Header size exceeded max allowed size 8196", true); + + when(stream.id()).thenReturn(STREAM_ID); + when(connection.isServer()).thenReturn(true); + when(stream.isHeadersSent()).thenReturn(true); + when(remote.lastStreamCreated()).thenReturn(STREAM_ID); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), + eq(PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future); + handler.exceptionCaught(ctx, e); + + verify(encoder, never()).writeHeaders(eq(ctx), eq(STREAM_ID), + any(Http2Headers.class), eq(padding), eq(true), eq(promise)); + + verify(frameWriter).writeRstStream(ctx, STREAM_ID, PROTOCOL_ERROR.code(), promise); + } + + @Test + public void serverShouldCreateStreamIfNeededBeforeSending431() throws Exception { + int padding = 0; + handler = newHandler(); + Http2Exception e = new Http2Exception.HeaderListSizeException(STREAM_ID, PROTOCOL_ERROR, + "Header size exceeded max allowed size 8196", true); + + when(connection.stream(STREAM_ID)).thenReturn(null); + when(remote.createStream(STREAM_ID, true)).thenReturn(stream); + when(stream.id()).thenReturn(STREAM_ID); + + 
when(connection.isServer()).thenReturn(true); + when(stream.isHeadersSent()).thenReturn(false); + when(remote.lastStreamCreated()).thenReturn(STREAM_ID); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), + eq(PROTOCOL_ERROR.code()), eq(promise))).thenReturn(future); + handler.exceptionCaught(ctx, e); + + verify(remote).createStream(STREAM_ID, true); + verify(encoder).writeHeaders(eq(ctx), eq(STREAM_ID), + any(Http2Headers.class), eq(padding), eq(true), eq(promise)); + + verify(frameWriter).writeRstStream(ctx, STREAM_ID, PROTOCOL_ERROR.code(), promise); + } + + @Test + public void encoderAndDecoderAreClosedOnChannelInactive() throws Exception { + handler = newHandler(); + handler.channelActive(ctx); + when(channel.isActive()).thenReturn(false); + handler.channelInactive(ctx); + verify(encoder).close(); + verify(decoder).close(); + } + + @Test + public void writeRstOnNonExistantStreamShouldSucceed() throws Exception { + handler = newHandler(); + when(frameWriter.writeRstStream(eq(ctx), eq(NON_EXISTANT_STREAM_ID), + eq(STREAM_CLOSED.code()), eq(promise))).thenReturn(future); + handler.resetStream(ctx, NON_EXISTANT_STREAM_ID, STREAM_CLOSED.code(), promise); + verify(frameWriter).writeRstStream(eq(ctx), eq(NON_EXISTANT_STREAM_ID), eq(STREAM_CLOSED.code()), eq(promise)); + } + + @Test + public void writeRstOnClosedStreamShouldSucceed() throws Exception { + handler = newHandler(); + when(stream.id()).thenReturn(STREAM_ID); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), + anyLong(), any(ChannelPromise.class))).thenReturn(future); + when(stream.state()).thenReturn(CLOSED); + when(stream.isHeadersSent()).thenReturn(true); + // The stream is "closed" but is still known about by the connection (connection().stream(..) + // will return the stream). We should still write a RST_STREAM frame in this scenario. 
+ handler.resetStream(ctx, STREAM_ID, STREAM_CLOSED.code(), promise); + verify(frameWriter).writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class)); + } + + @Test + public void writeRstOnIdleStreamShouldNotWriteButStillSucceed() throws Exception { + handler = newHandler(); + when(stream.state()).thenReturn(IDLE); + handler.resetStream(ctx, STREAM_ID, STREAM_CLOSED.code(), promise); + verify(frameWriter, never()).writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class)); + verify(stream).close(); + } + + @SuppressWarnings("unchecked") + @Test + public void closeListenerShouldBeNotifiedOnlyOneTime() throws Exception { + handler = newHandler(); + when(future.isDone()).thenReturn(true); + when(future.isSuccess()).thenReturn(true); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + Object[] args = invocation.getArguments(); + GenericFutureListener listener = (GenericFutureListener) args[0]; + // Simulate that all streams have become inactive by the time the future completes. + doAnswer(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock in) throws Throwable { + return null; + } + }).when(connection).forEachActiveStream(any(Http2StreamVisitor.class)); + when(connection.numActiveStreams()).thenReturn(0); + // Simulate the future being completed. + listener.operationComplete(future); + return future; + } + }).when(future).addListener(any(GenericFutureListener.class)); + handler.close(ctx, promise); + if (future.isDone()) { + when(connection.numActiveStreams()).thenReturn(0); + } + handler.closeStream(stream, future); + // Simulate another stream close call being made after the context should already be closed. 
+ handler.closeStream(stream, future); + verify(ctx, times(1)).close(any(ChannelPromise.class)); + } + + @SuppressWarnings("unchecked") + @Test + public void canSendGoAwayFrame() throws Exception { + ByteBuf data = dummyData(); + long errorCode = Http2Error.INTERNAL_ERROR.code(); + when(future.isDone()).thenReturn(true); + when(future.isSuccess()).thenReturn(true); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocation) throws Throwable { + ((GenericFutureListener) invocation.getArgument(0)).operationComplete(future); + return null; + } + }).when(future).addListener(any(GenericFutureListener.class)); + handler = newHandler(); + handler.goAway(ctx, STREAM_ID, errorCode, data, promise); + + verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data)); + verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), + eq(promise)); + verify(ctx).close(); + assertEquals(0, data.refCnt()); + } + + @Test + public void canSendGoAwayFramesWithDecreasingLastStreamIds() throws Exception { + handler = newHandler(); + ByteBuf data = dummyData(); + long errorCode = Http2Error.INTERNAL_ERROR.code(); + + handler.goAway(ctx, STREAM_ID + 2, errorCode, data.retain(), promise); + verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID + 2), eq(errorCode), eq(data), + eq(promise)); + verify(connection).goAwaySent(eq(STREAM_ID + 2), eq(errorCode), eq(data)); + promise = new DefaultChannelPromise(channel); + handler.goAway(ctx, STREAM_ID, errorCode, data, promise); + verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise)); + verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data)); + assertEquals(0, data.refCnt()); + } + + @Test + public void cannotSendGoAwayFrameWithIncreasingLastStreamIds() throws Exception { + handler = newHandler(); + ByteBuf data = dummyData(); + long errorCode = Http2Error.INTERNAL_ERROR.code(); + + handler.goAway(ctx, STREAM_ID, errorCode, data.retain(), 
promise); + verify(connection).goAwaySent(eq(STREAM_ID), eq(errorCode), eq(data)); + verify(frameWriter).writeGoAway(eq(ctx), eq(STREAM_ID), eq(errorCode), eq(data), eq(promise)); + // The frameWriter is only mocked, so it should not have interacted with the promise. + assertFalse(promise.isDone()); + + when(connection.goAwaySent()).thenReturn(true); + when(remote.lastStreamKnownByPeer()).thenReturn(STREAM_ID); + doAnswer(new Answer() { + @Override + public Boolean answer(InvocationOnMock invocationOnMock) { + throw new IllegalStateException(); + } + }).when(connection).goAwaySent(anyInt(), anyLong(), any(ByteBuf.class)); + handler.goAway(ctx, STREAM_ID + 2, errorCode, data, promise); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + assertEquals(0, data.refCnt()); + verifyNoMoreInteractions(frameWriter); + } + + @Test + public void canSendGoAwayUsingVoidPromise() throws Exception { + handler = newHandler(); + ByteBuf data = dummyData(); + long errorCode = Http2Error.INTERNAL_ERROR.code(); + handler = newHandler(); + final Throwable cause = new RuntimeException("fake exception"); + doAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + ChannelPromise promise = invocation.getArgument(4); + assertFalse(promise.isVoid()); + // This is what DefaultHttp2FrameWriter does... I hate mocking :-(. 
+ SimpleChannelPromiseAggregator aggregatedPromise = + new SimpleChannelPromiseAggregator(promise, channel, ImmediateEventExecutor.INSTANCE); + aggregatedPromise.newPromise(); + aggregatedPromise.doneAllocatingPromises(); + return aggregatedPromise.setFailure(cause); + } + }).when(frameWriter).writeGoAway( + any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), any(ChannelPromise.class)); + handler.goAway(ctx, STREAM_ID, errorCode, data, newVoidPromise(channel)); + verify(pipeline).fireExceptionCaught(cause); + } + + @Test + public void canCloseStreamWithVoidPromise() throws Exception { + handler = newHandler(); + handler.closeStream(stream, ctx.voidPromise().setSuccess()); + verify(stream, times(1)).close(); + verifyNoMoreInteractions(stream); + } + + @Test + public void channelReadCompleteTriggersFlush() throws Exception { + handler = newHandler(); + handler.channelReadComplete(ctx); + verify(ctx, times(1)).flush(); + } + + @Test + public void channelReadCompleteCallsReadWhenAutoReadFalse() throws Exception { + channel.config().setAutoRead(false); + handler = newHandler(); + handler.channelReadComplete(ctx); + verify(ctx, times(1)).read(); + } + + @Test + public void channelClosedDoesNotThrowPrefaceException() throws Exception { + when(connection.isServer()).thenReturn(true); + handler = newHandler(); + when(channel.isActive()).thenReturn(false); + handler.channelInactive(ctx); + verify(frameWriter, never()).writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class), any(ChannelPromise.class)); + verify(frameWriter, never()).writeRstStream(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ChannelPromise.class)); + } + + @Test + public void clientChannelClosedDoesNotSendGoAwayBeforePreface() throws Exception { + when(connection.isServer()).thenReturn(false); + when(channel.isActive()).thenReturn(false); + handler = newHandler(); + when(channel.isActive()).thenReturn(true); + handler.close(ctx, promise); 
+ verifyZeroInteractions(frameWriter); + } + + @Test + public void writeRstStreamForUnknownStreamUsingVoidPromise() throws Exception { + writeRstStreamUsingVoidPromise(NON_EXISTANT_STREAM_ID); + } + + @Test + public void writeRstStreamForKnownStreamUsingVoidPromise() throws Exception { + writeRstStreamUsingVoidPromise(STREAM_ID); + } + + @Test + public void gracefulShutdownTimeoutWhenConnectionErrorHardShutdownTest() throws Exception { + gracefulShutdownTimeoutWhenConnectionErrorTest0(ShutdownHint.HARD_SHUTDOWN); + } + + @Test + public void gracefulShutdownTimeoutWhenConnectionErrorGracefulShutdownTest() throws Exception { + gracefulShutdownTimeoutWhenConnectionErrorTest0(ShutdownHint.GRACEFUL_SHUTDOWN); + } + + private void gracefulShutdownTimeoutWhenConnectionErrorTest0(ShutdownHint hint) throws Exception { + handler = newHandler(); + final long expectedMillis = 1234; + handler.gracefulShutdownTimeoutMillis(expectedMillis); + Http2Exception exception = new Http2Exception(PROTOCOL_ERROR, "Test error", hint); + handler.onConnectionError(ctx, false, exception, exception); + verify(executor, atLeastOnce()).schedule(any(Runnable.class), eq(expectedMillis), eq(TimeUnit.MILLISECONDS)); + } + + @Test + public void gracefulShutdownTimeoutTest() throws Exception { + handler = newHandler(); + final long expectedMillis = 1234; + handler.gracefulShutdownTimeoutMillis(expectedMillis); + handler.close(ctx, promise); + verify(executor, atLeastOnce()).schedule(any(Runnable.class), eq(expectedMillis), eq(TimeUnit.MILLISECONDS)); + } + + @Test + public void gracefulShutdownTimeoutNoActiveStreams() throws Exception { + handler = newHandler(); + when(connection.numActiveStreams()).thenReturn(0); + final long expectedMillis = 1234; + handler.gracefulShutdownTimeoutMillis(expectedMillis); + handler.close(ctx, promise); + verify(executor, atLeastOnce()).schedule(any(Runnable.class), eq(expectedMillis), eq(TimeUnit.MILLISECONDS)); + } + + @Test + public void 
gracefulShutdownIndefiniteTimeoutTest() throws Exception { + handler = newHandler(); + handler.gracefulShutdownTimeoutMillis(-1); + handler.close(ctx, promise); + verify(executor, never()).schedule(any(Runnable.class), anyLong(), any(TimeUnit.class)); + } + + @Test + public void writeMultipleRstFramesForSameStream() throws Exception { + handler = newHandler(); + when(stream.id()).thenReturn(STREAM_ID); + + final AtomicBoolean resetSent = new AtomicBoolean(); + when(stream.resetSent()).then(new Answer() { + @Override + public Http2Stream answer(InvocationOnMock invocationOnMock) { + resetSent.set(true); + return stream; + } + }); + when(stream.isResetSent()).then(new Answer() { + @Override + public Boolean answer(InvocationOnMock invocationOnMock) { + return resetSent.get(); + } + }); + when(frameWriter.writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + ChannelPromise promise = invocationOnMock.getArgument(3); + return promise.setSuccess(); + } + }); + + ChannelPromise promise = + new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + final ChannelPromise promise2 = + new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + handler.resetStream(ctx, STREAM_ID, STREAM_CLOSED.code(), promise2); + } + }); + + handler.resetStream(ctx, STREAM_ID, CANCEL.code(), promise); + verify(frameWriter).writeRstStream(eq(ctx), eq(STREAM_ID), anyLong(), any(ChannelPromise.class)); + assertTrue(promise.isSuccess()); + assertTrue(promise2.isSuccess()); + } + + private void writeRstStreamUsingVoidPromise(int streamId) throws Exception { + handler = newHandler(); + final Throwable cause = new RuntimeException("fake exception"); + when(stream.id()).thenReturn(STREAM_ID); + 
when(frameWriter.writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class))) + .then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) throws Throwable { + ChannelPromise promise = invocationOnMock.getArgument(3); + assertFalse(promise.isVoid()); + return promise.setFailure(cause); + } + }); + handler.resetStream(ctx, streamId, STREAM_CLOSED.code(), newVoidPromise(channel)); + verify(frameWriter).writeRstStream(eq(ctx), eq(streamId), anyLong(), any(ChannelPromise.class)); + verify(pipeline).fireExceptionCaught(cause); + } + + private static ByteBuf dummyData() { + return Unpooled.buffer().writeBytes("abcdefgh".getBytes(UTF_8)); + } + + private static ByteBuf addSettingsHeader(ByteBuf buf) { + buf.writeMedium(Http2CodecUtil.SETTING_ENTRY_LENGTH); + buf.writeByte(Http2FrameTypes.SETTINGS); + buf.writeByte(0); + buf.writeInt(0); + return buf; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java new file mode 100644 index 0000000..27cca9c --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ConnectionRoundtripTest.java @@ -0,0 +1,1325 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown; +import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable; +import io.netty.util.AsciiString; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.io.ByteArrayOutputStream; +import java.util.Random; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.handler.codec.http2.Http2CodecUtil.CONNECTION_STREAM_ID; +import static 
io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; +import static io.netty.handler.codec.http2.Http2Error.PROTOCOL_ERROR; +import static io.netty.handler.codec.http2.Http2TestUtil.randomString; +import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; +import static java.lang.Integer.MAX_VALUE; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Tests the full HTTP/2 framing stack including the connection and preface handlers. 
+ */ +public class Http2ConnectionRoundtripTest { + + private static final long DEFAULT_AWAIT_TIMEOUT_SECONDS = 15; + + @Mock + private Http2FrameListener clientListener; + + @Mock + private Http2FrameListener serverListener; + + private Http2ConnectionHandler http2Client; + private Http2ConnectionHandler http2Server; + private ServerBootstrap sb; + private Bootstrap cb; + private Channel serverChannel; + private volatile Channel serverConnectedChannel; + private Channel clientChannel; + private FrameCountDown serverFrameCountDown; + private CountDownLatch requestLatch; + private CountDownLatch serverSettingsAckLatch; + private CountDownLatch dataLatch; + private CountDownLatch trailersLatch; + private CountDownLatch goAwayLatch; + + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + mockFlowControl(clientListener); + mockFlowControl(serverListener); + } + + @AfterEach + public void teardown() throws Exception { + if (clientChannel != null) { + clientChannel.close().syncUninterruptibly(); + clientChannel = null; + } + if (serverChannel != null) { + serverChannel.close().syncUninterruptibly(); + serverChannel = null; + } + final Channel serverConnectedChannel = this.serverConnectedChannel; + if (serverConnectedChannel != null) { + serverConnectedChannel.close().syncUninterruptibly(); + this.serverConnectedChannel = null; + } + Future serverGroup = sb.config().group().shutdownGracefully(0, 5, SECONDS); + Future serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 5, SECONDS); + Future clientGroup = cb.config().group().shutdownGracefully(0, 5, SECONDS); + serverGroup.syncUninterruptibly(); + serverChildGroup.syncUninterruptibly(); + clientGroup.syncUninterruptibly(); + } + + @Test + public void inflightFrameAfterStreamResetShouldNotMakeConnectionUnusable() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock 
invocationOnMock) throws Throwable { + ChannelHandlerContext ctx = invocationOnMock.getArgument(0); + http2Server.encoder().writeHeaders(ctx, + (Integer) invocationOnMock.getArgument(1), + (Http2Headers) invocationOnMock.getArgument(2), + 0, + false, + ctx.newPromise()); + http2Server.flush(ctx); + return null; + } + }).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + latch.countDown(); + return null; + } + }).when(clientListener).onHeadersRead(any(ChannelHandlerContext.class), eq(5), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + + bootstrapEnv(1, 1, 2, 1); + + // Create a single stream by sending a HEADERS frame to the server. + final short weight = 16; + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, weight, false, 0, false, newPromise()); + http2Client.flush(ctx()); + http2Client.encoder().writeRstStream(ctx(), 3, Http2Error.INTERNAL_ERROR.code(), newPromise()); + http2Client.flush(ctx()); + } + }); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, weight, false, 0, false, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(latch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + } + + @Test + public void headersWithEndStreamShouldNotSendError() throws Exception { + bootstrapEnv(1, 1, 2, 1); + + // Create a single stream by sending a HEADERS frame to the server. 
+ final short weight = 16; + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, weight, false, 0, true, + newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), + eq(0), eq(weight), eq(false), eq(0), eq(true)); + // Wait for some time to see if a go_away or reset frame will be received. + Thread.sleep(1000); + + // Verify that no errors have been received. + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), + anyLong(), any(ByteBuf.class)); + verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), + anyLong()); + + // The server will not respond, and so don't wait for graceful shutdown + setClientGracefulShutdownTime(0); + } + + @Test + public void encodeViolatesMaxHeaderListSizeCanStillUseConnection() throws Exception { + final CountDownLatch serverSettingsAckLatch1 = new CountDownLatch(2); + final CountDownLatch serverSettingsAckLatch2 = new CountDownLatch(3); + final CountDownLatch clientSettingsLatch1 = new CountDownLatch(3); + final CountDownLatch serverRevHeadersLatch = new CountDownLatch(1); + final CountDownLatch clientHeadersLatch = new CountDownLatch(1); + final CountDownLatch clientDataWrite = new CountDownLatch(1); + final AtomicReference clientHeadersWriteException = new AtomicReference(); + final AtomicReference clientHeadersWriteException2 = new AtomicReference(); + final AtomicReference clientDataWriteException = new AtomicReference(); + + final Http2Headers headers = dummyHeaders(); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + serverSettingsAckLatch1.countDown(); + 
serverSettingsAckLatch2.countDown(); + return null; + } + }).when(serverListener).onSettingsAckRead(any(ChannelHandlerContext.class)); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + clientSettingsLatch1.countDown(); + return null; + } + }).when(clientListener).onSettingsRead(any(ChannelHandlerContext.class), any(Http2Settings.class)); + + // Manually add a listener for when we receive the expected headers on the server. + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + serverRevHeadersLatch.countDown(); + return null; + } + }).when(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(5), eq(headers), + anyInt(), anyShort(), anyBoolean(), eq(0), eq(true)); + + bootstrapEnv(1, 2, 2, 0, 0); + + // Set the maxHeaderListSize to 100 so we may be able to write some headers, but not all. We want to verify + // that we don't corrupt state if some can be written but not all. + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeSettings(serverCtx(), + new Http2Settings().copyFrom(http2Server.decoder().localSettings()) + .maxHeaderListSize(100), + serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + assertTrue(serverSettingsAckLatch1.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, false, newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientHeadersWriteException.set(future.cause()); + } + }); + // It is expected that this write should fail locally and the remote peer will never see this. 
+ http2Client.encoder().writeData(ctx(), 3, Unpooled.buffer(), 0, true, newPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientDataWriteException.set(future.cause()); + clientDataWrite.countDown(); + } + }); + http2Client.flush(ctx()); + } + }); + + assertTrue(clientDataWrite.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertNotNull(clientHeadersWriteException.get(), "Header encode should have exceeded maxHeaderListSize!"); + assertNotNull(clientDataWriteException.get(), "Data on closed stream should fail!"); + + // Set the maxHeaderListSize to the max value so we can send the headers. + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeSettings(serverCtx(), + new Http2Settings().copyFrom(http2Server.decoder().localSettings()) + .maxHeaderListSize(Http2CodecUtil.MAX_HEADER_LIST_SIZE), + serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + assertTrue(clientSettingsLatch1.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(serverSettingsAckLatch2.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, true, + newPromise()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientHeadersWriteException2.set(future.cause()); + clientHeadersLatch.countDown(); + } + }); + http2Client.flush(ctx()); + } + }); + + assertTrue(clientHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertNull(clientHeadersWriteException2.get(), + "Client write of headers should succeed with increased header list size!"); + assertTrue(serverRevHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + verify(serverListener, 
never()).onDataRead(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), + anyInt(), anyBoolean()); + + // Verify that no errors have been received. + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + + @Test + public void testSettingsAckIsSentBeforeUsingFlowControl() throws Exception { + final CountDownLatch serverSettingsAckLatch1 = new CountDownLatch(1); + final CountDownLatch serverSettingsAckLatch2 = new CountDownLatch(2); + final CountDownLatch serverDataLatch = new CountDownLatch(1); + final CountDownLatch clientWriteDataLatch = new CountDownLatch(1); + final byte[] data = new byte[] {1, 2, 3, 4, 5}; + final ByteArrayOutputStream out = new ByteArrayOutputStream(data.length); + + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + serverSettingsAckLatch1.countDown(); + serverSettingsAckLatch2.countDown(); + return null; + } + }).when(serverListener).onSettingsAckRead(any(ChannelHandlerContext.class)); + doAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock in) throws Throwable { + ByteBuf buf = (ByteBuf) in.getArguments()[2]; + int padding = (Integer) in.getArguments()[3]; + int processedBytes = buf.readableBytes() + padding; + + buf.readBytes(out, buf.readableBytes()); + serverDataLatch.countDown(); + return processedBytes; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), anyBoolean()); + + bootstrapEnv(1, 1, 2, 1); + + final Http2Headers headers = dummyHeaders(); + + // The server initially 
reduces the connection flow control window to 0. + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeSettings(serverCtx(), + new Http2Settings().copyFrom(http2Server.decoder().localSettings()) + .initialWindowSize(0), + serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + assertTrue(serverSettingsAckLatch1.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // The client should now attempt to send data, but the window size is 0 so it will be queued in the flow + // controller. + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.encoder().writeData(ctx(), 3, Unpooled.wrappedBuffer(data), 0, true, newPromise()); + http2Client.flush(ctx()); + clientWriteDataLatch.countDown(); + } + }); + + assertTrue(clientWriteDataLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Now the server opens up the connection window to allow the client to send the pending data. + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeSettings(serverCtx(), + new Http2Settings().copyFrom(http2Server.decoder().localSettings()) + .initialWindowSize(data.length), + serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + assertTrue(serverSettingsAckLatch2.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(serverDataLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertArrayEquals(data, out.toByteArray()); + + // Verify that no errors have been received. 
+ verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + + @Test + public void priorityUsingHigherValuedStreamIdDoesNotPreventUsingLowerStreamId() throws Exception { + bootstrapEnv(1, 1, 3, 0); + + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writePriority(ctx(), 5, 3, (short) 14, false, newPromise()); + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + verify(serverListener).onPriorityRead(any(ChannelHandlerContext.class), eq(5), eq(3), eq((short) 14), + eq(false)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), + eq((short) 16), eq(false), eq(0), eq(false)); + + // Verify that no errors have been received. 
+ verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + + @Test + public void headersUsingHigherValuedStreamIdPreventsUsingLowerStreamId() throws Exception { + bootstrapEnv(1, 1, 2, 0); + + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.encoder().frameWriter().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(5), eq(headers), eq(0), + eq((short) 16), eq(false), eq(0), eq(false)); + verify(serverListener, never()).onHeadersRead(any(ChannelHandlerContext.class), eq(3), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean()); + + // Client should receive a RST_STREAM for stream 3, but there is no Http2Stream object so the listener is never + // notified. 
+ verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(serverListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + + @Test + public void headersWriteForPeerStreamWhichWasResetShouldNotGoAway() throws Exception { + final CountDownLatch serverGotRstLatch = new CountDownLatch(1); + final CountDownLatch serverWriteHeadersLatch = new CountDownLatch(1); + final AtomicReference serverWriteHeadersCauseRef = new AtomicReference(); + + final int streamId = 3; + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + if (streamId == (Integer) invocationOnMock.getArgument(1)) { + serverGotRstLatch.countDown(); + } + return null; + } + }).when(serverListener).onRstStreamRead(any(ChannelHandlerContext.class), eq(streamId), anyLong()); + + bootstrapEnv(1, 1, 1, 0); + + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), streamId, headers, CONNECTION_STREAM_ID, + DEFAULT_PRIORITY_WEIGHT, false, 0, false, newPromise()); + http2Client.encoder().writeRstStream(ctx(), streamId, Http2Error.CANCEL.code(), newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(serverGotRstLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(streamId), eq(headers), anyInt(), + anyShort(), anyBoolean(), anyInt(), eq(false)); + + // Now have the server attempt to send a headers frame 
simulating some asynchronous work. + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeHeaders(serverCtx(), streamId, headers, 0, true, serverNewPromise()) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + serverWriteHeadersCauseRef.set(future.cause()); + serverWriteHeadersLatch.countDown(); + } + }); + http2Server.flush(serverCtx()); + } + }); + + assertTrue(serverWriteHeadersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + Throwable serverWriteHeadersCause = serverWriteHeadersCauseRef.get(); + assertNotNull(serverWriteHeadersCause); + assertThat(serverWriteHeadersCauseRef.get(), not(instanceOf(Http2Exception.class))); + + // Server should receive a RST_STREAM for stream 3. + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + verify(clientListener, never()).onRstStreamRead(any(ChannelHandlerContext.class), anyInt(), anyLong()); + } + + @Test + public void http2ExceptionInPipelineShouldCloseConnection() throws Exception { + bootstrapEnv(1, 1, 2, 1); + + // Create a latch to track when the close occurs. + final CountDownLatch closeLatch = new CountDownLatch(1); + clientChannel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + closeLatch.countDown(); + } + }); + + // Create a single stream by sending a HEADERS frame to the server. 
+ final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.flush(ctx()); + } + }); + + // Wait for the server to create the stream. + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Add a handler that will immediately throw an exception. + clientChannel.pipeline().addFirst(new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + throw Http2Exception.connectionError(PROTOCOL_ERROR, "Fake Exception"); + } + }); + + // Wait for the close to occur. + assertTrue(closeLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertFalse(clientChannel.isOpen()); + } + + @Test + public void listenerExceptionShouldCloseConnection() throws Exception { + final Http2Headers headers = dummyHeaders(); + doThrow(new RuntimeException("Fake Exception")).when(serverListener).onHeadersRead( + any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), eq((short) 16), + eq(false), eq(0), eq(false)); + + bootstrapEnv(1, 0, 1, 1); + + // Create a latch to track when the close occurs. + final CountDownLatch closeLatch = new CountDownLatch(1); + clientChannel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + closeLatch.countDown(); + } + }); + + // Create a single stream by sending a HEADERS frame to the server. + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.flush(ctx()); + } + }); + + // Wait for the server to create the stream. 
+ assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Wait for the close to occur. + assertTrue(closeLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertFalse(clientChannel.isOpen()); + } + + private enum WriteEmptyBufferMode { + SINGLE_END_OF_STREAM, + SECOND_END_OF_STREAM, + SINGLE_WITH_TRAILERS, + SECOND_WITH_TRAILERS + } + + @Test + public void writeOfEmptyReleasedBufferSingleBufferQueuedInFlowControllerShouldFail() throws Exception { + writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SINGLE_END_OF_STREAM); + } + + @Test + public void writeOfEmptyReleasedBufferSingleBufferTrailersQueuedInFlowControllerShouldFail() throws Exception { + writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SINGLE_WITH_TRAILERS); + } + + @Test + public void writeOfEmptyReleasedBufferMultipleBuffersQueuedInFlowControllerShouldFail() throws Exception { + writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SECOND_END_OF_STREAM); + } + + @Test + public void writeOfEmptyReleasedBufferMultipleBuffersTrailersQueuedInFlowControllerShouldFail() throws Exception { + writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(WriteEmptyBufferMode.SECOND_WITH_TRAILERS); + } + + private void writeOfEmptyReleasedBufferQueuedInFlowControllerShouldFail(final WriteEmptyBufferMode mode) + throws Exception { + bootstrapEnv(1, 1, 2, 1); + + final ChannelPromise emptyDataPromise = newPromise(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0, (short) 16, false, 0, false, + newPromise()); + ByteBuf emptyBuf = Unpooled.buffer(); + emptyBuf.release(); + switch (mode) { + case SINGLE_END_OF_STREAM: + http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, true, 
emptyDataPromise); + break; + case SECOND_END_OF_STREAM: + http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, false, emptyDataPromise); + http2Client.encoder().writeData(ctx(), 3, randomBytes(8), 0, true, newPromise()); + break; + case SINGLE_WITH_TRAILERS: + http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, false, emptyDataPromise); + http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0, + (short) 16, false, 0, true, newPromise()); + break; + case SECOND_WITH_TRAILERS: + http2Client.encoder().writeData(ctx(), 3, emptyBuf, 0, false, emptyDataPromise); + http2Client.encoder().writeData(ctx(), 3, randomBytes(8), 0, false, newPromise()); + http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0, + (short) 16, false, 0, true, newPromise()); + break; + default: + throw new Error(); + } + http2Client.flush(ctx()); + } + }); + + ExecutionException e = assertThrows(ExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + emptyDataPromise.get(); + } + }); + assertThat(e.getCause(), is(instanceOf(IllegalReferenceCountException.class))); + } + + @Test + public void writeFailureFlowControllerRemoveFrame() + throws Exception { + bootstrapEnv(1, 1, 3, 1); + + final ChannelPromise dataPromise = newPromise(); + final ChannelPromise assertPromise = newPromise(); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, EmptyHttp2Headers.INSTANCE, 0, (short) 16, false, 0, false, + newPromise()); + clientChannel.pipeline().addFirst(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ReferenceCountUtil.release(msg); + + // Ensure we update the window size so we will try to write the rest of the frame while + // processing the flush. 
+ http2Client.encoder().flowController().initialWindowSize(8); + promise.setFailure(new IllegalStateException()); + } + }); + + http2Client.encoder().flowController().initialWindowSize(4); + http2Client.encoder().writeData(ctx(), 3, randomBytes(8), 0, false, dataPromise); + assertTrue(http2Client.encoder().flowController() + .hasFlowControlled(http2Client.connection().stream(3))); + + http2Client.flush(ctx()); + + try { + // The Frame should have been removed after the write failed. + assertFalse(http2Client.encoder().flowController() + .hasFlowControlled(http2Client.connection().stream(3))); + assertPromise.setSuccess(); + } catch (Throwable error) { + assertPromise.setFailure(error); + } + } + }); + + ExecutionException e = assertThrows(ExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + dataPromise.get(); + } + }); + assertThat(e.getCause(), is(instanceOf(IllegalStateException.class))); + assertPromise.sync(); + } + + @Test + public void nonHttp2ExceptionInPipelineShouldNotCloseConnection() throws Exception { + bootstrapEnv(1, 1, 2, 1); + + // Create a latch to track when the close occurs. + final CountDownLatch closeLatch = new CountDownLatch(1); + clientChannel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + closeLatch.countDown(); + } + }); + + // Create a single stream by sending a HEADERS frame to the server. + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, false, + newPromise()); + http2Client.flush(ctx()); + } + }); + + // Wait for the server to create the stream. 
+ assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Add a handler that will immediately throw an exception. + clientChannel.pipeline().addFirst(new ChannelHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + throw new RuntimeException("Fake Exception"); + } + }); + + // The close should NOT occur. + assertFalse(closeLatch.await(2, SECONDS)); + assertTrue(clientChannel.isOpen()); + + // Set the timeout very low because we know graceful shutdown won't complete + setClientGracefulShutdownTime(0); + } + + @Test + public void noMoreStreamIdsShouldSendGoAway() throws Exception { + bootstrapEnv(1, 1, 4, 1, 1); + + // Don't wait for the server to close streams + setClientGracefulShutdownTime(0); + + // Create a single stream by sending a HEADERS frame to the server. + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, + true, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), MAX_VALUE + 1, headers, 0, (short) 16, false, 0, + true, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(goAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + verify(serverListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(0), + eq(PROTOCOL_ERROR.code()), any(ByteBuf.class)); + } + + @Test + public void createStreamAfterReceiveGoAwayShouldNotSendGoAway() throws Exception { + final CountDownLatch clientGoAwayLatch = new CountDownLatch(1); + doAnswer(new Answer() { + @Override + public 
Void answer(InvocationOnMock invocationOnMock) throws Throwable { + clientGoAwayLatch.countDown(); + return null; + } + }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); + + bootstrapEnv(1, 1, 2, 1, 1); + + // We want both sides to do graceful shutdown during the test. + setClientGracefulShutdownTime(10000); + setServerGracefulShutdownTime(10000); + + // Create a single stream by sending a HEADERS frame to the server. + final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, + false, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Server has received the headers, so the stream is open + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(serverChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeGoAway(serverCtx(), 3, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + // wait for the client to receive the GO_AWAY. 
+ assertTrue(clientGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + verify(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(3), eq(NO_ERROR.code()), + any(ByteBuf.class)); + + final AtomicReference clientWriteAfterGoAwayFutureRef = new AtomicReference(); + final CountDownLatch clientWriteAfterGoAwayLatch = new CountDownLatch(1); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + ChannelFuture f = http2Client.encoder().writeHeaders(ctx(), 5, headers, 0, (short) 16, false, 0, + true, newPromise()); + clientWriteAfterGoAwayFutureRef.set(f); + http2Client.flush(ctx()); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + clientWriteAfterGoAwayLatch.countDown(); + } + }); + } + }); + + // Wait for the client's write operation to complete. + assertTrue(clientWriteAfterGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + ChannelFuture clientWriteAfterGoAwayFuture = clientWriteAfterGoAwayFutureRef.get(); + assertNotNull(clientWriteAfterGoAwayFuture); + Throwable clientCause = clientWriteAfterGoAwayFuture.cause(); + assertThat(clientCause, is(instanceOf(Http2Exception.StreamException.class))); + assertEquals(Http2Error.REFUSED_STREAM.code(), ((Http2Exception.StreamException) clientCause).error().code()); + + // Wait for the server to receive a GO_AWAY, but this is expected to timeout! 
+ assertFalse(goAwayLatch.await(1, SECONDS)); + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + + // Shutdown shouldn't wait for the server to close streams + setClientGracefulShutdownTime(0); + setServerGracefulShutdownTime(0); + } + + @Test + public void listenerIsNotifiedOfGoawayBeforeStreamsAreRemovedFromTheConnection() throws Exception { + final AtomicReference clientStream3State = new AtomicReference(); + final CountDownLatch clientGoAwayLatch = new CountDownLatch(1); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock invocationOnMock) throws Throwable { + clientStream3State.set(http2Client.connection().stream(3).state()); + clientGoAwayLatch.countDown(); + return null; + } + }).when(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class)); + + bootstrapEnv(1, 1, 3, 1, 1); + + // We want both sides to do graceful shutdown during the test. + setClientGracefulShutdownTime(10000); + setServerGracefulShutdownTime(10000); + + // Create a single stream by sending a HEADERS frame to the server. 
+ final Http2Headers headers = dummyHeaders(); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 1, headers, 0, (short) 16, false, 0, + false, newPromise()); + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, + false, newPromise()); + http2Client.flush(ctx()); + } + }); + + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Server has received the headers, so the stream is open + assertTrue(requestLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + runInChannel(serverChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Server.encoder().writeGoAway(serverCtx(), 1, NO_ERROR.code(), EMPTY_BUFFER, serverNewPromise()); + http2Server.flush(serverCtx()); + } + }); + + // wait for the client to receive the GO_AWAY. + assertTrue(clientGoAwayLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + verify(clientListener).onGoAwayRead(any(ChannelHandlerContext.class), eq(1), eq(NO_ERROR.code()), + any(ByteBuf.class)); + assertEquals(Http2Stream.State.OPEN, clientStream3State.get()); + + // Make sure that stream 3 has been closed which is true if it's gone. 
+ final CountDownLatch probeStreamCount = new CountDownLatch(1); + final AtomicBoolean stream3Exists = new AtomicBoolean(); + final AtomicInteger streamCount = new AtomicInteger(); + runInChannel(this.clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + stream3Exists.set(http2Client.connection().stream(3) != null); + streamCount.set(http2Client.connection().numActiveStreams()); + probeStreamCount.countDown(); + } + }); + // The stream should be closed right after + assertTrue(probeStreamCount.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertEquals(1, streamCount.get()); + assertFalse(stream3Exists.get()); + + // Wait for the server to receive a GO_AWAY, but this is expected to timeout! + assertFalse(goAwayLatch.await(1, SECONDS)); + verify(serverListener, never()).onGoAwayRead(any(ChannelHandlerContext.class), anyInt(), anyLong(), + any(ByteBuf.class)); + + // Shutdown shouldn't wait for the server to close streams + setClientGracefulShutdownTime(0); + setServerGracefulShutdownTime(0); + } + + @Test + public void flowControlProperlyChunksLargeMessage() throws Exception { + final Http2Headers headers = dummyHeaders(); + final Http2Headers trailers = dummyTrailers(); + + // Create a large message to send. + final int length = 10485760; // 10MB + + // Create a buffer filled with random bytes. 
+ final ByteBuf data = randomBytes(length); + final ByteArrayOutputStream out = new ByteArrayOutputStream(length); + doAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock in) throws Throwable { + ByteBuf buf = (ByteBuf) in.getArguments()[2]; + int padding = (Integer) in.getArguments()[3]; + int processedBytes = buf.readableBytes() + padding; + + buf.readBytes(out, buf.readableBytes()); + return processedBytes; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), anyBoolean()); + try { + // Initialize the data latch based on the number of bytes expected. + bootstrapEnv(length, 1, 3, 1); + + // Create the stream and send all of the data at once. + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + http2Client.encoder().writeHeaders(ctx(), 3, headers, 0, (short) 16, false, 0, + false, newPromise()); + http2Client.encoder().writeData(ctx(), 3, data.retainedDuplicate(), 0, false, newPromise()); + + // Write trailers. + http2Client.encoder().writeHeaders(ctx(), 3, trailers, 0, (short) 16, false, 0, + true, newPromise()); + http2Client.flush(ctx()); + } + }); + + // Wait for the trailers to be received. + assertTrue(serverSettingsAckLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + assertTrue(trailersLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + + // Verify that headers and trailers were received. + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(headers), eq(0), + eq((short) 16), eq(false), eq(0), eq(false)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(trailers), eq(0), + eq((short) 16), eq(false), eq(0), eq(true)); + + // Verify we received all the bytes. 
+ assertEquals(0, dataLatch.getCount()); + out.flush(); + byte[] received = out.toByteArray(); + assertArrayEquals(data.array(), received); + } finally { + // Don't wait for server to close streams + setClientGracefulShutdownTime(0); + data.release(); + out.close(); + } + } + + @Test + public void stressTest() throws Exception { + final Http2Headers headers = dummyHeaders(); + final Http2Headers trailers = dummyTrailers(); + int length = 10; + final ByteBuf data = randomBytes(length); + final String dataAsHex = ByteBufUtil.hexDump(data); + final long pingData = 8; + final int numStreams = 2000; + + // Collect all the ping buffers as we receive them at the server. + final long[] receivedPings = new long[numStreams]; + doAnswer(new Answer() { + int nextIndex; + + @Override + public Void answer(InvocationOnMock in) throws Throwable { + receivedPings[nextIndex++] = (Long) in.getArguments()[1]; + return null; + } + }).when(serverListener).onPingRead(any(ChannelHandlerContext.class), any(Long.class)); + + // Collect all the data buffers as we receive them at the server. 
+ final StringBuilder[] receivedData = new StringBuilder[numStreams]; + doAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock in) throws Throwable { + int streamId = (Integer) in.getArguments()[1]; + ByteBuf buf = (ByteBuf) in.getArguments()[2]; + int padding = (Integer) in.getArguments()[3]; + int processedBytes = buf.readableBytes() + padding; + + int streamIndex = (streamId - 3) / 2; + StringBuilder builder = receivedData[streamIndex]; + if (builder == null) { + builder = new StringBuilder(dataAsHex.length()); + receivedData[streamIndex] = builder; + } + builder.append(ByteBufUtil.hexDump(buf)); + return processedBytes; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), anyInt(), + any(ByteBuf.class), anyInt(), anyBoolean()); + try { + bootstrapEnv(numStreams * length, 1, numStreams * 4 + 1 , numStreams); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + int upperLimit = 3 + 2 * numStreams; + for (int streamId = 3; streamId < upperLimit; streamId += 2) { + // Send a bunch of data on each stream. + http2Client.encoder().writeHeaders(ctx(), streamId, headers, 0, (short) 16, + false, 0, false, newPromise()); + http2Client.encoder().writePing(ctx(), false, pingData, + newPromise()); + http2Client.encoder().writeData(ctx(), streamId, data.retainedSlice(), 0, + false, newPromise()); + // Write trailers. + http2Client.encoder().writeHeaders(ctx(), streamId, trailers, 0, (short) 16, + false, 0, true, newPromise()); + http2Client.flush(ctx()); + } + } + }); + // Wait for all frames to be received. 
+ assertTrue(serverSettingsAckLatch.await(60, SECONDS)); + assertTrue(trailersLatch.await(60, SECONDS)); + verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), + eq(headers), eq(0), eq((short) 16), eq(false), eq(0), eq(false)); + verify(serverListener, times(numStreams)).onHeadersRead(any(ChannelHandlerContext.class), anyInt(), + eq(trailers), eq(0), eq((short) 16), eq(false), eq(0), eq(true)); + verify(serverListener, times(numStreams)).onPingRead(any(ChannelHandlerContext.class), + any(long.class)); + verify(serverListener, never()).onDataRead(any(ChannelHandlerContext.class), + anyInt(), any(ByteBuf.class), eq(0), eq(true)); + for (StringBuilder builder : receivedData) { + assertEquals(dataAsHex, builder.toString()); + } + for (long receivedPing : receivedPings) { + assertEquals(pingData, receivedPing); + } + } finally { + // Don't wait for server to close streams + setClientGracefulShutdownTime(0); + data.release(); + } + } + + private void bootstrapEnv(int dataCountDown, int settingsAckCount, + int requestCountDown, int trailersCountDown) throws Exception { + bootstrapEnv(dataCountDown, settingsAckCount, requestCountDown, trailersCountDown, -1); + } + + private void bootstrapEnv(int dataCountDown, int settingsAckCount, + int requestCountDown, int trailersCountDown, int goAwayCountDown) throws Exception { + final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1); + requestLatch = new CountDownLatch(requestCountDown); + serverSettingsAckLatch = new CountDownLatch(settingsAckCount); + dataLatch = new CountDownLatch(dataCountDown); + trailersLatch = new CountDownLatch(trailersCountDown); + goAwayLatch = goAwayCountDown > 0 ? 
new CountDownLatch(goAwayCountDown) : requestLatch; + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + final AtomicReference serverHandlerRef = new AtomicReference(); + final CountDownLatch serverInitLatch = new CountDownLatch(1); + sb.group(new DefaultEventLoopGroup()); + sb.channel(LocalServerChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + serverConnectedChannel = ch; + ChannelPipeline p = ch.pipeline(); + serverFrameCountDown = + new FrameCountDown(serverListener, serverSettingsAckLatch, + requestLatch, dataLatch, trailersLatch, goAwayLatch); + serverHandlerRef.set(new Http2ConnectionHandlerBuilder() + .server(true) + .frameListener(serverFrameCountDown) + .validateHeaders(false) + .build()); + p.addLast(serverHandlerRef.get()); + serverInitLatch.countDown(); + } + }); + + cb.group(new DefaultEventLoopGroup()); + cb.channel(LocalChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + p.addLast(new Http2ConnectionHandlerBuilder() + .server(false) + .frameListener(clientListener) + .validateHeaders(false) + .gracefulShutdownTimeoutMillis(0) + .build()); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE) { + prefaceWrittenLatch.countDown(); + ctx.pipeline().remove(this); + } + } + }); + } + }); + + serverChannel = sb.bind(new LocalAddress(getClass())).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + assertTrue(prefaceWrittenLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + http2Client = clientChannel.pipeline().get(Http2ConnectionHandler.class); + 
assertTrue(serverInitLatch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + http2Server = serverHandlerRef.get(); + } + + private ChannelHandlerContext ctx() { + return clientChannel.pipeline().firstContext(); + } + + private ChannelHandlerContext serverCtx() { + return serverConnectedChannel.pipeline().firstContext(); + } + + private ChannelPromise newPromise() { + return ctx().newPromise(); + } + + private ChannelPromise serverNewPromise() { + return serverCtx().newPromise(); + } + + private static Http2Headers dummyHeaders() { + return new DefaultHttp2Headers(false).method(new AsciiString("GET")).scheme(new AsciiString("https")) + .authority(new AsciiString("example.org")).path(new AsciiString("/some/path/resource2")) + .add(randomString(), randomString()); + } + + private static Http2Headers dummyTrailers() { + return new DefaultHttp2Headers(false) + .add("header-" + randomString(), randomString()); + } + + private static void mockFlowControl(Http2FrameListener listener) throws Http2Exception { + doAnswer(new Answer() { + @Override + public Integer answer(InvocationOnMock invocation) throws Throwable { + ByteBuf buf = (ByteBuf) invocation.getArguments()[2]; + int padding = (Integer) invocation.getArguments()[3]; + return buf.readableBytes() + padding; + } + + }).when(listener).onDataRead(any(ChannelHandlerContext.class), anyInt(), + any(ByteBuf.class), anyInt(), anyBoolean()); + } + + private void setClientGracefulShutdownTime(final long millis) throws InterruptedException { + setGracefulShutdownTime(clientChannel, http2Client, millis); + } + + private void setServerGracefulShutdownTime(final long millis) throws InterruptedException { + setGracefulShutdownTime(serverChannel, http2Server, millis); + } + + private static void setGracefulShutdownTime(Channel channel, final Http2ConnectionHandler handler, + final long millis) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + runInChannel(channel, new Http2Runnable() { + 
@Override + public void run() throws Http2Exception { + handler.gracefulShutdownTimeoutMillis(millis); + latch.countDown(); + } + }); + + assertTrue(latch.await(DEFAULT_AWAIT_TIMEOUT_SECONDS, SECONDS)); + } + + /** + * Creates a {@link ByteBuf} of the given length, filled with random bytes. + */ + private static ByteBuf randomBytes(int length) { + final byte[] bytes = new byte[length]; + new Random().nextBytes(bytes); + return Unpooled.wrappedBuffer(bytes); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoderTest.java new file mode 100644 index 0000000..db6f83d --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ControlFrameLimitEncoderTest.java @@ -0,0 +1,277 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.DefaultMessageSizeEstimator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + + +import java.util.ArrayDeque; +import java.util.Queue; + +import static io.netty.handler.codec.http2.Http2CodecUtil.*; +import static io.netty.handler.codec.http2.Http2Error.CANCEL; +import static io.netty.handler.codec.http2.Http2Error.ENHANCE_YOUR_CALM; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +/** + * Tests for {@link Http2ControlFrameLimitEncoder}. + */ +public class Http2ControlFrameLimitEncoderTest { + + private Http2ControlFrameLimitEncoder encoder; + + @Mock + private Http2FrameWriter writer; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private Channel channel; + + @Mock + private Channel.Unsafe unsafe; + + @Mock + private ChannelConfig config; + + @Mock + private EventExecutor executor; + + private int numWrites; + + private final Queue goAwayPromises = new ArrayDeque(); + + /** + * Init fields and do mocking. 
+ */ + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + numWrites = 0; + + Http2FrameWriter.Configuration configuration = mock(Http2FrameWriter.Configuration.class); + Http2FrameSizePolicy frameSizePolicy = mock(Http2FrameSizePolicy.class); + when(writer.configuration()).thenReturn(configuration); + when(configuration.frameSizePolicy()).thenReturn(frameSizePolicy); + when(frameSizePolicy.maxFrameSize()).thenReturn(DEFAULT_MAX_FRAME_SIZE); + + when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return handlePromise(invocationOnMock, 3); + } + }); + when(writer.writeSettingsAck(any(ChannelHandlerContext.class), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return handlePromise(invocationOnMock, 1); + } + }); + when(writer.writePing(any(ChannelHandlerContext.class), anyBoolean(), anyLong(), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ChannelPromise promise = handlePromise(invocationOnMock, 3); + if (invocationOnMock.getArgument(1) == Boolean.FALSE) { + promise.trySuccess(); + } + return promise; + } + }); + when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + ReferenceCountUtil.release(invocationOnMock.getArgument(3)); + ChannelPromise promise = invocationOnMock.getArgument(4); + goAwayPromises.offer(promise); + return promise; + } + }); + Http2Connection connection = new DefaultHttp2Connection(false); + connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); + 
connection.local().flowController(new DefaultHttp2LocalFlowController(connection).frameWriter(writer)); + + DefaultHttp2ConnectionEncoder defaultEncoder = + new DefaultHttp2ConnectionEncoder(connection, writer); + encoder = new Http2ControlFrameLimitEncoder(defaultEncoder, 2); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, mock(Http2FrameReader.class)); + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(mock(Http2FrameListener.class)) + .codec(decoder, encoder).build(); + + // Set LifeCycleManager on encoder and decoder + when(ctx.channel()).thenReturn(channel); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(executor.inEventLoop()).thenReturn(true); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { + return newPromise(); + } + }).when(ctx).newPromise(); + when(ctx.executor()).thenReturn(executor); + when(channel.isActive()).thenReturn(false); + when(channel.config()).thenReturn(config); + when(channel.isWritable()).thenReturn(true); + when(channel.bytesBeforeUnwritable()).thenReturn(Long.MAX_VALUE); + when(config.getWriteBufferHighWaterMark()).thenReturn(Integer.MAX_VALUE); + when(config.getMessageSizeEstimator()).thenReturn(DefaultMessageSizeEstimator.DEFAULT); + ChannelMetadata metadata = new ChannelMetadata(false, 16); + when(channel.metadata()).thenReturn(metadata); + when(channel.unsafe()).thenReturn(unsafe); + handler.handlerAdded(ctx); + } + + private ChannelPromise handlePromise(InvocationOnMock invocationOnMock, int promiseIdx) { + ChannelPromise promise = invocationOnMock.getArgument(promiseIdx); + if (++numWrites == 2) { + promise.setSuccess(); + } + return promise; + } + + @AfterEach + public void teardown() { + // Close and release any buffered frames. 
+ encoder.close(); + + // Notify all goAway ChannelPromise instances now as these will also release the retained ByteBuf for the + // debugData. + for (;;) { + ChannelPromise promise = goAwayPromises.poll(); + if (promise == null) { + break; + } + promise.setSuccess(); + } + } + + @Test + public void testLimitSettingsAck() { + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. + assertTrue(encoder.writeSettingsAck(ctx, newPromise()).isSuccess()); + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + + verifyFlushAndClose(1, true); + } + + @Test + public void testLimitPingAck() { + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. 
+ assertTrue(encoder.writePing(ctx, true, 8, newPromise()).isSuccess()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isDone()); + + verifyFlushAndClose(1, true); + } + + @Test + public void testNotLimitPing() { + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + + verifyFlushAndClose(0, false); + } + + @Test + public void testLimitRst() { + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. + assertTrue(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isSuccess()); + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + + verifyFlushAndClose(1, true); + } + + @Test + public void testLimit() { + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + // The second write is always marked as success by our mock, which means it will also not be queued and so + // not count to the number of queued frames. 
+ assertTrue(encoder.writePing(ctx, false, 8, newPromise()).isSuccess()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isSuccess()); + + verifyFlushAndClose(0, false); + + assertFalse(encoder.writeSettingsAck(ctx, newPromise()).isDone()); + assertFalse(encoder.writeRstStream(ctx, 1, CANCEL.code(), newPromise()).isDone()); + assertFalse(encoder.writePing(ctx, true, 8, newPromise()).isSuccess()); + + verifyFlushAndClose(1, true); + } + + private void verifyFlushAndClose(int invocations, boolean failed) { + verify(ctx, atLeast(invocations)).flush(); + verify(ctx, times(invocations)).close(); + if (failed) { + verify(writer, times(1)).writeGoAway(eq(ctx), eq(Integer.MAX_VALUE), eq(ENHANCE_YOUR_CALM.code()), + any(ByteBuf.class), any(ChannelPromise.class)); + } + } + + private ChannelPromise newPromise() { + return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DataChunkedInputTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DataChunkedInputTest.java new file mode 100644 index 0000000..26ea3fd --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DataChunkedInputTest.java @@ -0,0 +1,177 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.stream.ChunkedFile; +import io.netty.handler.stream.ChunkedInput; +import io.netty.handler.stream.ChunkedNioFile; +import io.netty.handler.stream.ChunkedNioStream; +import io.netty.handler.stream.ChunkedStream; +import io.netty.handler.stream.ChunkedWriteHandler; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.nio.channels.Channels; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class Http2DataChunkedInputTest { + private static final byte[] BYTES = new byte[1024 * 64]; + private static final File TMP; + + // Just a dummy interface implementation of stream + private static final Http2FrameStream STREAM = new Http2FrameStream() { + @Override + public int id() { + return 1; + } + + @Override + public Http2Stream.State state() { + return Http2Stream.State.OPEN; + } + }; + + static { + for (int i = 0; i < BYTES.length; i++) { + BYTES[i] = (byte) i; + } + + FileOutputStream out = null; + try { + TMP = PlatformDependent.createTempFile("netty-chunk-", ".tmp", null); + TMP.deleteOnExit(); + out = new FileOutputStream(TMP); + out.write(BYTES); + out.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + if (out != null) { + try { + out.close(); + } catch (IOException e) { + // ignore + } + } + } + } + + @Test + public void testChunkedStream() { + check(new Http2DataChunkedInput(new ChunkedStream(new 
ByteArrayInputStream(BYTES)), STREAM)); + } + + @Test + public void testChunkedNioStream() { + check(new Http2DataChunkedInput(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES))), + STREAM)); + } + + @Test + public void testChunkedFile() throws IOException { + check(new Http2DataChunkedInput(new ChunkedFile(TMP), STREAM)); + } + + @Test + public void testChunkedNioFile() throws IOException { + check(new Http2DataChunkedInput(new ChunkedNioFile(TMP), STREAM)); + } + + @Test + public void testWrappedReturnNull() throws Exception { + Http2DataChunkedInput input = new Http2DataChunkedInput(new ChunkedInput() { + + @Override + public boolean isEndOfInput() throws Exception { + return false; + } + + @Override + public void close() throws Exception { + // NOOP + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return null; + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + return null; + } + + @Override + public long length() { + return 0; + } + + @Override + public long progress() { + return 0; + } + }, STREAM); + assertNull(input.readChunk(ByteBufAllocator.DEFAULT)); + } + + private static void check(ChunkedInput... 
inputs) { + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + for (ChunkedInput input : inputs) { + ch.writeOutbound(input); + } + + assertTrue(ch.finish()); + + int i = 0; + int read = 0; + Http2DataFrame http2DataFrame = null; + for (;;) { + Http2DataFrame dataFrame = ch.readOutbound(); + if (dataFrame == null) { + break; + } + + ByteBuf buffer = dataFrame.content(); + while (buffer.isReadable()) { + assertEquals(BYTES[i++], buffer.readByte()); + read++; + if (i == BYTES.length) { + i = 0; + } + } + buffer.release(); + + // Save last chunk + http2DataFrame = dataFrame; + } + + assertEquals(BYTES.length * inputs.length, read); + assertNotNull(http2DataFrame); + assertTrue(http2DataFrame.isEndStream(), "Last chunk must be Http2DataFrame#isEndStream() set to true"); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DefaultFramesTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DefaultFramesTest.java new file mode 100644 index 0000000..aeb3c24 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2DefaultFramesTest.java @@ -0,0 +1,44 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.DefaultByteBufHolder; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class Http2DefaultFramesTest { + + @SuppressWarnings("SimplifiableJUnitAssertion") + @Test + public void testEqualOperation() { + // in this case, 'goAwayFrame' and 'unknownFrame' will also have an EMPTY_BUFFER data + // so we want to check that 'dflt' will not consider them equal. + DefaultHttp2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(1); + DefaultHttp2UnknownFrame unknownFrame = new DefaultHttp2UnknownFrame((byte) 1, new Http2Flags((short) 1)); + DefaultByteBufHolder dflt = new DefaultByteBufHolder(Unpooled.EMPTY_BUFFER); + try { + // not using 'assertNotEquals' to be explicit about which object we are calling .equals() on + assertFalse(dflt.equals(goAwayFrame)); + assertFalse(dflt.equals(unknownFrame)); + } finally { + goAwayFrame.release(); + unknownFrame.release(); + dflt.release(); + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoderTest.java new file mode 100644 index 0000000..901db88 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameConnectionDecoderTest.java @@ -0,0 +1,28 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +public class Http2EmptyDataFrameConnectionDecoderTest extends AbstractDecoratingHttp2ConnectionDecoderTest { + + @Override + protected DecoratingHttp2ConnectionDecoder newDecoder(Http2ConnectionDecoder decoder) { + return new Http2EmptyDataFrameConnectionDecoder(decoder, 2); + } + + @Override + protected Class delegatingFrameListenerType() { + return Http2EmptyDataFrameListener.class; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListenerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListenerTest.java new file mode 100644 index 0000000..9d1adf5 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2EmptyDataFrameListenerTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; +import static org.mockito.MockitoAnnotations.initMocks; + +public class Http2EmptyDataFrameListenerTest { + + @Mock + private Http2FrameListener frameListener; + @Mock + private ChannelHandlerContext ctx; + + @Mock + private ByteBuf nonEmpty; + + private Http2EmptyDataFrameListener listener; + + @BeforeEach + public void setUp() { + initMocks(this); + when(nonEmpty.isReadable()).thenReturn(true); + listener = new Http2EmptyDataFrameListener(frameListener, 2); + } + + @Test + public void testEmptyDataFrames() throws Http2Exception { + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + } + }); + verify(frameListener, times(2)).onDataRead(eq(ctx), eq(1), any(ByteBuf.class), eq(0), eq(false)); + } + + @Test + public void testEmptyDataFramesWithNonEmptyInBetween() throws Http2Exception { + final Http2EmptyDataFrameListener listener = new Http2EmptyDataFrameListener(frameListener, 2); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, nonEmpty, 0, false); + + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + + 
assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + } + }); + verify(frameListener, times(4)).onDataRead(eq(ctx), eq(1), any(ByteBuf.class), eq(0), eq(false)); + } + + @Test + public void testEmptyDataFramesWithEndOfStreamInBetween() throws Http2Exception { + final Http2EmptyDataFrameListener listener = new Http2EmptyDataFrameListener(frameListener, 2); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, true); + + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + } + }); + + verify(frameListener, times(1)).onDataRead(eq(ctx), eq(1), any(ByteBuf.class), eq(0), eq(true)); + verify(frameListener, times(3)).onDataRead(eq(ctx), eq(1), any(ByteBuf.class), eq(0), eq(false)); + } + + @Test + public void testEmptyDataFramesWithHeaderFrameInBetween() throws Http2Exception { + final Http2EmptyDataFrameListener listener = new Http2EmptyDataFrameListener(frameListener, 2); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onHeadersRead(ctx, 1, EmptyHttp2Headers.INSTANCE, 0, true); + + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + } + }); + + verify(frameListener, times(1)).onHeadersRead(eq(ctx), eq(1), eq(EmptyHttp2Headers.INSTANCE), eq(0), eq(true)); + verify(frameListener, times(3)).onDataRead(eq(ctx), eq(1), any(ByteBuf.class), eq(0), eq(false)); + } + + 
@Test + public void testEmptyDataFramesWithHeaderFrameInBetween2() throws Http2Exception { + final Http2EmptyDataFrameListener listener = new Http2EmptyDataFrameListener(frameListener, 2); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onHeadersRead(ctx, 1, EmptyHttp2Headers.INSTANCE, 0, (short) 0, false, 0, true); + + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + listener.onDataRead(ctx, 1, Unpooled.EMPTY_BUFFER, 0, false); + } + }); + + verify(frameListener, times(1)).onHeadersRead(eq(ctx), eq(1), + eq(EmptyHttp2Headers.INSTANCE), eq(0), eq((short) 0), eq(false), eq(0), eq(true)); + verify(frameListener, times(3)).onDataRead(eq(ctx), eq(1), any(ByteBuf.class), eq(0), eq(false)); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ExceptionTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ExceptionTest.java new file mode 100644 index 0000000..782aa10 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ExceptionTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.http.TooLongHttpLineException; +import org.junit.jupiter.api.Test; + +import static io.netty.handler.codec.http2.Http2Error.COMPRESSION_ERROR; +import static org.junit.jupiter.api.Assertions.assertEquals; + +class Http2ExceptionTest { + + @Test + public void connectionErrorHandlesMessage() { + DecoderException e = new TooLongHttpLineException("An HTTP line is larger than 1024 bytes."); + Http2Exception http2Exception = Http2Exception.connectionError(COMPRESSION_ERROR, e, e.getMessage()); + assertEquals(COMPRESSION_ERROR, http2Exception.error()); + assertEquals("An HTTP line is larger than 1024 bytes.", http2Exception.getMessage()); + } + + @Test + public void connectionErrorHandlesNullExceptionMessage() { + Exception e = new RuntimeException(); + Http2Exception http2Exception = Http2Exception.connectionError(COMPRESSION_ERROR, e, e.getMessage()); + assertEquals(COMPRESSION_ERROR, http2Exception.error()); + assertEquals("Unexpected error", http2Exception.getMessage()); + } + + @Test + public void connectionErrorHandlesMultipleMessages() { + Exception e = new RuntimeException(); + Http2Exception http2Exception = Http2Exception.connectionError(COMPRESSION_ERROR, e, e.getMessage(), "a", "b"); + assertEquals(COMPRESSION_ERROR, http2Exception.error()); + assertEquals("Unexpected error: [a, b]", http2Exception.getMessage()); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java new file mode 100644 index 0000000..69187a9 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameCodecTest.java @@ -0,0 +1,941 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * 
"License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.UnsupportedMessageTypeException; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.HttpServerUpgradeHandler; +import io.netty.handler.codec.http.HttpServerUpgradeHandler.UpgradeEvent; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http2.Http2Exception.StreamException; +import io.netty.handler.codec.http2.Http2Stream.State; +import io.netty.handler.logging.LogLevel; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.GlobalEventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ReflectionUtil; +import org.junit.jupiter.api.AfterEach; +import 
org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; +import org.mockito.ArgumentCaptor; + +import java.lang.reflect.Constructor; +import java.net.InetSocketAddress; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid; +import static io.netty.handler.codec.http2.Http2Error.NO_ERROR; +import static io.netty.handler.codec.http2.Http2TestUtil.anyChannelPromise; +import static io.netty.handler.codec.http2.Http2TestUtil.anyHttp2Settings; +import static io.netty.handler.codec.http2.Http2TestUtil.assertEqualsAndRelease; +import static io.netty.handler.codec.http2.Http2TestUtil.bb; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Unit tests for {@link Http2FrameCodec}. 
+ */ +public class Http2FrameCodecTest { + + // For verifying outbound frames + private Http2FrameWriter frameWriter; + private Http2FrameCodec frameCodec; + private EmbeddedChannel channel; + + // For injecting inbound frames + private Http2FrameInboundWriter frameInboundWriter; + + private LastInboundHandler inboundHandler; + + private final Http2Headers request = new DefaultHttp2Headers() + .method(HttpMethod.GET.asciiName()).scheme(HttpScheme.HTTPS.name()) + .authority(new AsciiString("example.org")).path(new AsciiString("/foo")); + private final Http2Headers response = new DefaultHttp2Headers() + .status(HttpResponseStatus.OK.codeAsText()); + + @BeforeEach + public void setUp() throws Exception { + setUp(Http2FrameCodecBuilder.forServer(), new Http2Settings()); + } + + @AfterEach + public void tearDown() throws Exception { + if (inboundHandler != null) { + inboundHandler.finishAndReleaseAll(); + inboundHandler = null; + } + if (channel != null) { + channel.finishAndReleaseAll(); + channel.close(); + channel = null; + } + } + + private void setUp(Http2FrameCodecBuilder frameCodecBuilder, Http2Settings initialRemoteSettings) throws Exception { + /* + * Some tests call this method twice. Once with JUnit's @Before and once directly to pass special settings. + * This call ensures that in case of two consecutive calls to setUp(), the previous channel is shutdown and + * ByteBufs are released correctly. 
+ */ + tearDown(); + + frameWriter = Http2TestUtil.mockedFrameWriter(); + + frameCodec = frameCodecBuilder.frameWriter(frameWriter).frameLogger(new Http2FrameLogger(LogLevel.TRACE)) + .initialSettings(initialRemoteSettings).build(); + inboundHandler = new LastInboundHandler(); + + channel = new EmbeddedChannel(); + frameInboundWriter = new Http2FrameInboundWriter(channel); + channel.connect(new InetSocketAddress(0)); + channel.pipeline().addLast(frameCodec); + channel.pipeline().addLast(inboundHandler); + channel.pipeline().fireChannelActive(); + + // Handshake + verify(frameWriter).writeSettings(eqFrameCodecCtx(), anyHttp2Settings(), anyChannelPromise()); + verifyNoMoreInteractions(frameWriter); + channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); + + frameInboundWriter.writeInboundSettings(initialRemoteSettings); + + verify(frameWriter).writeSettingsAck(eqFrameCodecCtx(), anyChannelPromise()); + + frameInboundWriter.writeInboundSettingsAck(); + + Http2SettingsFrame settingsFrame = inboundHandler.readInbound(); + assertNotNull(settingsFrame); + Http2SettingsAckFrame settingsAckFrame = inboundHandler.readInbound(); + assertNotNull(settingsAckFrame); + } + + @Test + public void stateChanges() throws Exception { + frameInboundWriter.writeInboundHeaders(1, request, 31, true); + + Http2Stream stream = frameCodec.connection().stream(1); + assertNotNull(stream); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + + Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); + assertEquals(State.HALF_CLOSED_REMOTE, event.stream().state()); + + Http2StreamFrame inboundFrame = inboundHandler.readInbound(); + Http2FrameStream stream2 = inboundFrame.stream(); + assertNotNull(stream2); + assertEquals(1, stream2.id()); + assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2)); + assertNull(inboundHandler.readInbound()); + + channel.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 
27).stream(stream2)); + verify(frameWriter).writeHeaders( + eqFrameCodecCtx(), eq(1), eq(response), + eq(27), eq(true), anyChannelPromise()); + verify(frameWriter, never()).writeRstStream( + eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise()); + + assertEquals(State.CLOSED, stream.state()); + event = inboundHandler.readInboundMessageOrUserEvent(); + assertEquals(State.CLOSED, event.stream().state()); + + assertTrue(channel.isActive()); + } + + @Test + public void headerRequestHeaderResponse() throws Exception { + frameInboundWriter.writeInboundHeaders(1, request, 31, true); + + Http2Stream stream = frameCodec.connection().stream(1); + assertNotNull(stream); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + + Http2StreamFrame inboundFrame = inboundHandler.readInbound(); + Http2FrameStream stream2 = inboundFrame.stream(); + assertNotNull(stream2); + assertEquals(1, stream2.id()); + assertEquals(inboundFrame, new DefaultHttp2HeadersFrame(request, true, 31).stream(stream2)); + assertNull(inboundHandler.readInbound()); + + channel.writeOutbound(new DefaultHttp2HeadersFrame(response, true, 27).stream(stream2)); + verify(frameWriter).writeHeaders( + eqFrameCodecCtx(), eq(1), eq(response), + eq(27), eq(true), anyChannelPromise()); + verify(frameWriter, never()).writeRstStream( + eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise()); + + assertEquals(State.CLOSED, stream.state()); + assertTrue(channel.isActive()); + } + + @Test + public void flowControlShouldBeResilientToMissingStreams() throws Http2Exception { + Http2Connection conn = new DefaultHttp2Connection(true); + Http2ConnectionEncoder enc = new DefaultHttp2ConnectionEncoder(conn, new DefaultHttp2FrameWriter()); + Http2ConnectionDecoder dec = new DefaultHttp2ConnectionDecoder(conn, enc, new DefaultHttp2FrameReader()); + Http2FrameCodec codec = new Http2FrameCodec(enc, dec, new Http2Settings(), false, true); + EmbeddedChannel em = new EmbeddedChannel(codec); + + // We call #consumeBytes on a 
stream id which has not been seen yet to emulate the case + // where a stream is deregistered which in reality can happen in response to a RST. + assertFalse(codec.consumeBytes(1, 1)); + assertTrue(em.finishAndReleaseAll()); + } + + @Test + public void entityRequestEntityResponse() throws Exception { + frameInboundWriter.writeInboundHeaders(1, request, 0, false); + + Http2Stream stream = frameCodec.connection().stream(1); + assertNotNull(stream); + assertEquals(State.OPEN, stream.state()); + + Http2HeadersFrame inboundHeaders = inboundHandler.readInbound(); + Http2FrameStream stream2 = inboundHeaders.stream(); + assertNotNull(stream2); + assertEquals(1, stream2.id()); + assertEquals(new DefaultHttp2HeadersFrame(request, false).stream(stream2), inboundHeaders); + assertNull(inboundHandler.readInbound()); + + ByteBuf hello = bb("hello"); + frameInboundWriter.writeInboundData(1, hello, 31, true); + Http2DataFrame inboundData = inboundHandler.readInbound(); + Http2DataFrame expected = new DefaultHttp2DataFrame(bb("hello"), true, 31).stream(stream2); + assertEqualsAndRelease(expected, inboundData); + + assertNull(inboundHandler.readInbound()); + + channel.writeOutbound(new DefaultHttp2HeadersFrame(response, false).stream(stream2)); + verify(frameWriter).writeHeaders(eqFrameCodecCtx(), eq(1), eq(response), + eq(0), eq(false), anyChannelPromise()); + + channel.writeOutbound(new DefaultHttp2DataFrame(bb("world"), true, 27).stream(stream2)); + ArgumentCaptor outboundData = ArgumentCaptor.forClass(ByteBuf.class); + verify(frameWriter).writeData(eqFrameCodecCtx(), eq(1), outboundData.capture(), eq(27), + eq(true), anyChannelPromise()); + + ByteBuf bb = bb("world"); + assertEquals(bb, outboundData.getValue()); + assertEquals(1, outboundData.getValue().refCnt()); + bb.release(); + outboundData.getValue().release(); + + verify(frameWriter, never()).writeRstStream(eqFrameCodecCtx(), anyInt(), anyLong(), anyChannelPromise()); + assertTrue(channel.isActive()); + } + + @Test + 
public void sendRstStream() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, true); + + Http2Stream stream = frameCodec.connection().stream(3); + assertNotNull(stream); + assertEquals(State.HALF_CLOSED_REMOTE, stream.state()); + + Http2HeadersFrame inboundHeaders = inboundHandler.readInbound(); + assertNotNull(inboundHeaders); + assertTrue(inboundHeaders.isEndStream()); + + Http2FrameStream stream2 = inboundHeaders.stream(); + assertNotNull(stream2); + assertEquals(3, stream2.id()); + + channel.writeOutbound(new DefaultHttp2ResetFrame(314 /* non-standard error */).stream(stream2)); + verify(frameWriter).writeRstStream(eqFrameCodecCtx(), eq(3), eq(314L), anyChannelPromise()); + assertEquals(State.CLOSED, stream.state()); + assertTrue(channel.isActive()); + } + + @Test + public void receiveRstStream() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + + Http2Stream stream = frameCodec.connection().stream(3); + assertNotNull(stream); + assertEquals(State.OPEN, stream.state()); + + Http2HeadersFrame expectedHeaders = new DefaultHttp2HeadersFrame(request, false, 31); + Http2HeadersFrame actualHeaders = inboundHandler.readInbound(); + assertEquals(expectedHeaders.stream(actualHeaders.stream()), actualHeaders); + + frameInboundWriter.writeInboundRstStream(3, NO_ERROR.code()); + + Http2ResetFrame expectedRst = new DefaultHttp2ResetFrame(NO_ERROR).stream(actualHeaders.stream()); + Http2ResetFrame actualRst = inboundHandler.readInbound(); + assertEquals(expectedRst, actualRst); + + assertNull(inboundHandler.readInbound()); + } + + @Test + public void sendGoAway() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + Http2Stream stream = frameCodec.connection().stream(3); + assertNotNull(stream); + assertEquals(State.OPEN, stream.state()); + + ByteBuf debugData = bb("debug"); + ByteBuf expected = debugData.copy(); + + Http2GoAwayFrame goAwayFrame = new 
DefaultHttp2GoAwayFrame(NO_ERROR.code(), + debugData.retainedDuplicate()); + goAwayFrame.setExtraStreamIds(2); + + channel.writeOutbound(goAwayFrame); + verify(frameWriter).writeGoAway(eqFrameCodecCtx(), eq(7), + eq(NO_ERROR.code()), eq(expected), anyChannelPromise()); + assertEquals(State.OPEN, stream.state()); + assertTrue(channel.isActive()); + expected.release(); + debugData.release(); + } + + @Test + public void receiveGoaway() throws Exception { + ByteBuf debugData = bb("foo"); + frameInboundWriter.writeInboundGoAway(2, NO_ERROR.code(), debugData); + Http2GoAwayFrame expectedFrame = new DefaultHttp2GoAwayFrame(2, NO_ERROR.code(), bb("foo")); + Http2GoAwayFrame actualFrame = inboundHandler.readInbound(); + + assertEqualsAndRelease(expectedFrame, actualFrame); + + assertNull(inboundHandler.readInbound()); + } + + @Test + public void unknownFrameTypeShouldThrowAndBeReleased() throws Exception { + class UnknownHttp2Frame extends AbstractReferenceCounted implements Http2Frame { + @Override + public String name() { + return "UNKNOWN"; + } + + @Override + protected void deallocate() { + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + } + + UnknownHttp2Frame frame = new UnknownHttp2Frame(); + assertEquals(1, frame.refCnt()); + + ChannelFuture f = channel.write(frame); + f.await(); + assertTrue(f.isDone()); + assertFalse(f.isSuccess()); + assertThat(f.cause(), instanceOf(UnsupportedMessageTypeException.class)); + assertEquals(0, frame.refCnt()); + } + + @Test + public void unknownFrameTypeOnConnectionStream() throws Exception { + // handle the case where unknown frames are sent before a stream is created, + // for example: HTTP/2 GREASE testing + ByteBuf debugData = bb("debug"); + frameInboundWriter.writeInboundFrame((byte) 0xb, 0, new Http2Flags(), debugData); + channel.flush(); + + assertEquals(0, debugData.refCnt()); + assertTrue(channel.isActive()); + } + + @Test + public void goAwayLastStreamIdOverflowed() throws Exception { + 
frameInboundWriter.writeInboundHeaders(5, request, 31, false); + + Http2Stream stream = frameCodec.connection().stream(5); + assertNotNull(stream); + assertEquals(State.OPEN, stream.state()); + + ByteBuf debugData = bb("debug"); + Http2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(NO_ERROR.code(), + debugData.retainedDuplicate()); + goAwayFrame.setExtraStreamIds(Integer.MAX_VALUE); + + channel.writeOutbound(goAwayFrame); + // When the last stream id computation overflows, the last stream id should just be set to 2^31 - 1. + verify(frameWriter).writeGoAway(eqFrameCodecCtx(), eq(Integer.MAX_VALUE), + eq(NO_ERROR.code()), eq(debugData), anyChannelPromise()); + debugData.release(); + assertEquals(State.OPEN, stream.state()); + assertTrue(channel.isActive()); + } + + @Test + public void streamErrorShouldFireExceptionForInbound() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + + Http2Stream stream = frameCodec.connection().stream(3); + assertNotNull(stream); + + StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); + channel.pipeline().fireExceptionCaught(streamEx); + + Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); + assertEquals(Http2FrameStreamEvent.Type.State, event.type()); + assertEquals(State.OPEN, event.stream().state()); + Http2HeadersFrame headersFrame = inboundHandler.readInboundMessageOrUserEvent(); + assertNotNull(headersFrame); + + Http2FrameStreamException e = assertThrows(Http2FrameStreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + assertEquals(streamEx, e.getCause()); + + assertNull(inboundHandler.readInboundMessageOrUserEvent()); + } + + @Test + public void streamErrorShouldNotFireExceptionForOutbound() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + + Http2Stream stream = frameCodec.connection().stream(3); + 
assertNotNull(stream); + + StreamException streamEx = new StreamException(3, Http2Error.INTERNAL_ERROR, "foo"); + frameCodec.onError(frameCodec.ctx, true, streamEx); + + Http2FrameStreamEvent event = inboundHandler.readInboundMessageOrUserEvent(); + assertEquals(Http2FrameStreamEvent.Type.State, event.type()); + assertEquals(State.OPEN, event.stream().state()); + Http2HeadersFrame headersFrame = inboundHandler.readInboundMessageOrUserEvent(); + assertNotNull(headersFrame); + + // No exception expected + inboundHandler.checkException(); + + assertNull(inboundHandler.readInboundMessageOrUserEvent()); + } + + @Test + public void windowUpdateFrameDecrementsConsumedBytes() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + + Http2Connection connection = frameCodec.connection(); + Http2Stream stream = connection.stream(3); + assertNotNull(stream); + + ByteBuf data = Unpooled.buffer(100).writeZero(100); + frameInboundWriter.writeInboundData(3, data, 0, false); + + Http2HeadersFrame inboundHeaders = inboundHandler.readInbound(); + assertNotNull(inboundHeaders); + assertNotNull(inboundHeaders.stream()); + + Http2FrameStream stream2 = inboundHeaders.stream(); + + int before = connection.local().flowController().unconsumedBytes(stream); + ChannelFuture f = channel.write(new DefaultHttp2WindowUpdateFrame(100).stream(stream2)); + int after = connection.local().flowController().unconsumedBytes(stream); + assertEquals(100, before - after); + assertTrue(f.isSuccess()); + } + + @Test + public void windowUpdateMayFail() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + Http2Connection connection = frameCodec.connection(); + Http2Stream stream = connection.stream(3); + assertNotNull(stream); + + Http2HeadersFrame inboundHeaders = inboundHandler.readInbound(); + assertNotNull(inboundHeaders); + + Http2FrameStream stream2 = inboundHeaders.stream(); + + // Fails, cause trying to return too many bytes to the flow 
controller + ChannelFuture f = channel.write(new DefaultHttp2WindowUpdateFrame(100).stream(stream2)); + assertTrue(f.isDone()); + assertFalse(f.isSuccess()); + assertThat(f.cause(), instanceOf(Http2Exception.class)); + } + + @Test + public void inboundWindowUpdateShouldBeForwarded() throws Exception { + frameInboundWriter.writeInboundHeaders(3, request, 31, false); + frameInboundWriter.writeInboundWindowUpdate(3, 100); + // Connection-level window update + frameInboundWriter.writeInboundWindowUpdate(0, 100); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + Http2WindowUpdateFrame windowUpdateFrame = inboundHandler.readInbound(); + assertNotNull(windowUpdateFrame); + assertEquals(3, windowUpdateFrame.stream().id()); + assertEquals(100, windowUpdateFrame.windowSizeIncrement()); + + // Window update for the connection should not be forwarded. + assertNull(inboundHandler.readInbound()); + } + + @Test + public void streamZeroWindowUpdateIncrementsConnectionWindow() throws Http2Exception { + Http2Connection connection = frameCodec.connection(); + Http2LocalFlowController localFlow = connection.local().flowController(); + int initialWindowSizeBefore = localFlow.initialWindowSize(); + Http2Stream connectionStream = connection.connectionStream(); + int connectionWindowSizeBefore = localFlow.windowSize(connectionStream); + // We only replenish the flow control window after the amount consumed drops below the following threshold. + // We make the threshold very "high" so that window updates will be sent when the delta is relatively small. + ((DefaultHttp2LocalFlowController) localFlow).windowUpdateRatio(connectionStream, .999f); + + int windowUpdate = 1024; + + channel.write(new DefaultHttp2WindowUpdateFrame(windowUpdate)); + + // The initial window size is only changed by Http2Settings, so it shouldn't change. 
+ assertEquals(initialWindowSizeBefore, localFlow.initialWindowSize()); + // The connection window should be increased by the delta amount. + assertEquals(connectionWindowSizeBefore + windowUpdate, localFlow.windowSize(connectionStream)); + } + + @Test + public void windowUpdateDoesNotOverflowConnectionWindow() { + Http2Connection connection = frameCodec.connection(); + Http2LocalFlowController localFlow = connection.local().flowController(); + int initialWindowSizeBefore = localFlow.initialWindowSize(); + + channel.write(new DefaultHttp2WindowUpdateFrame(Integer.MAX_VALUE)); + + // The initial window size is only changed by Http2Settings, so it shouldn't change. + assertEquals(initialWindowSizeBefore, localFlow.initialWindowSize()); + // The connection window should be increased by the delta amount. + assertEquals(Integer.MAX_VALUE, localFlow.windowSize(connection.connectionStream())); + } + + @Test + public void writeUnknownFrame() { + final Http2FrameStream stream = frameCodec.newStream(); + + ByteBuf buffer = Unpooled.buffer().writeByte(1); + DefaultHttp2UnknownFrame unknownFrame = new DefaultHttp2UnknownFrame( + (byte) 20, new Http2Flags().ack(true), buffer); + unknownFrame.stream(stream); + channel.write(unknownFrame); + + verify(frameWriter).writeFrame(eqFrameCodecCtx(), eq(unknownFrame.frameType()), + eq(unknownFrame.stream().id()), eq(unknownFrame.flags()), eq(buffer), any(ChannelPromise.class)); + } + + @Test + public void sendSettingsFrame() { + Http2Settings settings = new Http2Settings(); + channel.write(new DefaultHttp2SettingsFrame(settings)); + + verify(frameWriter).writeSettings(eqFrameCodecCtx(), same(settings), any(ChannelPromise.class)); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void newOutboundStream() { + final Http2FrameStream stream = frameCodec.newStream(); + + assertNotNull(stream); + assertFalse(isStreamIdValid(stream.id())); + + final Promise listenerExecuted = new 
DefaultPromise(GlobalEventExecutor.INSTANCE); + + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers(), false).stream(stream)) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertTrue(future.isSuccess()); + assertTrue(isStreamIdValid(stream.id())); + listenerExecuted.setSuccess(null); + } + } + ); + ByteBuf data = Unpooled.buffer().writeZero(100); + ChannelFuture f = channel.writeAndFlush(new DefaultHttp2DataFrame(data).stream(stream)); + assertTrue(f.isSuccess()); + + listenerExecuted.syncUninterruptibly(); + assertTrue(listenerExecuted.isSuccess()); + } + + @Test + public void newOutboundStreamsShouldBeBuffered() throws Exception { + setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true), + new Http2Settings().maxConcurrentStreams(1)); + + Http2FrameStream stream1 = frameCodec.newStream(); + Http2FrameStream stream2 = frameCodec.newStream(); + + ChannelPromise promise1 = channel.newPromise(); + ChannelPromise promise2 = channel.newPromise(); + + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream1), promise1); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream2), promise2); + + assertTrue(isStreamIdValid(stream1.id())); + channel.runPendingTasks(); + assertTrue(isStreamIdValid(stream2.id())); + + assertTrue(promise1.syncUninterruptibly().isSuccess()); + assertFalse(promise2.isDone()); + + // Increase concurrent streams limit to 2 + frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(2)); + + channel.flush(); + + assertTrue(promise2.syncUninterruptibly().isSuccess()); + } + + @Test + public void multipleNewOutboundStreamsShouldBeBuffered() throws Exception { + // We use a limit of 1 and then increase it step by step. 
+ setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true), + new Http2Settings().maxConcurrentStreams(1)); + + Http2FrameStream stream1 = frameCodec.newStream(); + Http2FrameStream stream2 = frameCodec.newStream(); + Http2FrameStream stream3 = frameCodec.newStream(); + + ChannelPromise promise1 = channel.newPromise(); + ChannelPromise promise2 = channel.newPromise(); + ChannelPromise promise3 = channel.newPromise(); + + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream1), promise1); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream2), promise2); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream3), promise3); + + assertTrue(isStreamIdValid(stream1.id())); + channel.runPendingTasks(); + assertTrue(isStreamIdValid(stream2.id())); + + assertTrue(promise1.syncUninterruptibly().isSuccess()); + assertFalse(promise2.isDone()); + assertFalse(promise3.isDone()); + + // Increase concurrent streams limit to 2 + frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(2)); + channel.flush(); + + // As we increased the limit to 2 we should have also succeed the second frame. + assertTrue(promise2.syncUninterruptibly().isSuccess()); + assertFalse(promise3.isDone()); + + frameInboundWriter.writeInboundSettings(new Http2Settings().maxConcurrentStreams(3)); + channel.flush(); + + // With the max streams of 3 all streams should be succeed now. 
+ assertTrue(promise3.syncUninterruptibly().isSuccess()); + + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void doNotLeakOnFailedInitializationForChannels() throws Exception { + setUp(Http2FrameCodecBuilder.forServer(), new Http2Settings().maxConcurrentStreams(2)); + + Http2FrameStream stream1 = frameCodec.newStream(); + Http2FrameStream stream2 = frameCodec.newStream(); + + ChannelPromise stream1HeaderPromise = channel.newPromise(); + ChannelPromise stream2HeaderPromise = channel.newPromise(); + + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream1), + stream1HeaderPromise); + channel.runPendingTasks(); + + frameInboundWriter.writeInboundGoAway(stream1.id(), 0L, Unpooled.EMPTY_BUFFER); + + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream2), + stream2HeaderPromise); + channel.runPendingTasks(); + + assertTrue(stream1HeaderPromise.syncUninterruptibly().isSuccess()); + assertTrue(stream2HeaderPromise.isDone()); + + assertEquals(0, frameCodec.numInitializingStreams()); + assertFalse(channel.finishAndReleaseAll()); + } + + @Test + public void streamIdentifiersExhausted() throws Http2Exception { + int maxServerStreamId = Integer.MAX_VALUE - 1; + + assertNotNull(frameCodec.connection().local().createStream(maxServerStreamId, false)); + + Http2FrameStream stream = frameCodec.newStream(); + assertNotNull(stream); + + ChannelPromise writePromise = channel.newPromise(); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream), writePromise); + + Http2GoAwayFrame goAwayFrame = inboundHandler.readInbound(); + assertNotNull(goAwayFrame); + assertEquals(NO_ERROR.code(), goAwayFrame.errorCode()); + assertEquals(Integer.MAX_VALUE, goAwayFrame.lastStreamId()); + goAwayFrame.release(); + assertThat(writePromise.cause(), instanceOf(Http2NoMoreStreamIdsException.class)); + } + + @Test + public void receivePing() throws Http2Exception { + 
frameInboundWriter.writeInboundPing(false, 12345L); + + Http2PingFrame pingFrame = inboundHandler.readInbound(); + assertNotNull(pingFrame); + + assertEquals(12345, pingFrame.content()); + assertFalse(pingFrame.ack()); + } + + @Test + public void sendPing() { + channel.writeAndFlush(new DefaultHttp2PingFrame(12345)); + + verify(frameWriter).writePing(eqFrameCodecCtx(), eq(false), eq(12345L), anyChannelPromise()); + } + + @Test + public void receiveSettings() throws Http2Exception { + Http2Settings settings = new Http2Settings().maxConcurrentStreams(1); + frameInboundWriter.writeInboundSettings(settings); + + Http2SettingsFrame settingsFrame = inboundHandler.readInbound(); + assertNotNull(settingsFrame); + assertEquals(settings, settingsFrame.settings()); + } + + @Test + public void sendSettings() { + Http2Settings settings = new Http2Settings().maxConcurrentStreams(1); + channel.writeAndFlush(new DefaultHttp2SettingsFrame(settings)); + + verify(frameWriter).writeSettings(eqFrameCodecCtx(), eq(settings), anyChannelPromise()); + } + + @Test + public void iterateActiveStreams() throws Exception { + setUp(Http2FrameCodecBuilder.forServer().encoderEnforceMaxConcurrentStreams(true), + new Http2Settings().maxConcurrentStreams(1)); + + frameInboundWriter.writeInboundHeaders(3, request, 0, false); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + Http2FrameStream activeInbond = headersFrame.stream(); + + Http2FrameStream activeOutbound = frameCodec.newStream(); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(activeOutbound)); + + Http2FrameStream bufferedOutbound = frameCodec.newStream(); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(bufferedOutbound)); + + @SuppressWarnings("unused") + Http2FrameStream idleStream = frameCodec.newStream(); + + final Set activeStreams = new HashSet(); + frameCodec.forEachActiveStream(new 
Http2FrameStreamVisitor() { + @Override + public boolean visit(Http2FrameStream stream) { + activeStreams.add(stream); + return true; + } + }); + + assertEquals(2, activeStreams.size()); + + Set expectedStreams = new HashSet(); + expectedStreams.add(activeInbond); + expectedStreams.add(activeOutbound); + assertEquals(expectedStreams, activeStreams); + } + + @Test + public void autoAckPingTrue() throws Exception { + setUp(Http2FrameCodecBuilder.forServer().autoAckPingFrame(true), new Http2Settings()); + frameInboundWriter.writeInboundPing(false, 8); + Http2PingFrame frame = inboundHandler.readInbound(); + assertFalse(frame.ack()); + assertEquals(8, frame.content()); + verify(frameWriter).writePing(eqFrameCodecCtx(), eq(true), eq(8L), anyChannelPromise()); + } + + @Test + public void autoAckPingFalse() throws Exception { + setUp(Http2FrameCodecBuilder.forServer().autoAckPingFrame(false), new Http2Settings()); + frameInboundWriter.writeInboundPing(false, 8); + verify(frameWriter, never()).writePing(eqFrameCodecCtx(), eq(true), eq(8L), anyChannelPromise()); + Http2PingFrame frame = inboundHandler.readInbound(); + assertFalse(frame.ack()); + assertEquals(8, frame.content()); + + // Now ack the frame manually. 
+ channel.writeAndFlush(new DefaultHttp2PingFrame(8, true)); + verify(frameWriter).writePing(eqFrameCodecCtx(), eq(true), eq(8L), anyChannelPromise()); + } + + @Test + public void streamShouldBeOpenInListener() { + final Http2FrameStream stream2 = frameCodec.newStream(); + assertEquals(State.IDLE, stream2.state()); + + final AtomicBoolean listenerExecuted = new AtomicBoolean(); + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers()).stream(stream2)) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + assertTrue(future.isSuccess()); + assertEquals(State.OPEN, stream2.state()); + listenerExecuted.set(true); + } + }); + + assertTrue(listenerExecuted.get()); + } + + @Test + public void upgradeEventNoRefCntError() throws Exception { + frameInboundWriter.writeInboundHeaders(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); + // Using reflect as the constructor is package-private and the class is final. + Constructor constructor = + UpgradeEvent.class.getDeclaredConstructor(CharSequence.class, FullHttpRequest.class); + + // Check if we could make it accessible which may fail on java9. + Assumptions.assumeTrue(ReflectionUtil.trySetAccessible(constructor, true) == null); + + HttpServerUpgradeHandler.UpgradeEvent upgradeEvent = constructor.newInstance( + "HTTP/2", new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/")); + channel.pipeline().fireUserEventTriggered(upgradeEvent); + assertEquals(1, upgradeEvent.refCnt()); + } + + @Test + public void upgradeWithoutFlowControlling() throws Exception { + channel.pipeline().addAfter(frameCodec.ctx.name(), null, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof Http2DataFrame) { + // Simulate consuming the frame and update the flow-controller. 
+ Http2DataFrame data = (Http2DataFrame) msg; + ctx.writeAndFlush(new DefaultHttp2WindowUpdateFrame(data.initialFlowControlledBytes()) + .stream(data.stream())).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + Throwable cause = future.cause(); + if (cause != null) { + ctx.fireExceptionCaught(cause); + } + } + }); + } + ReferenceCountUtil.release(msg); + } + }); + + frameInboundWriter.writeInboundHeaders(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, request, 31, false); + + // Using reflect as the constructor is package-private and the class is final. + Constructor constructor = + UpgradeEvent.class.getDeclaredConstructor(CharSequence.class, FullHttpRequest.class); + + // Check if we could make it accessible which may fail on java9. + Assumptions.assumeTrue(ReflectionUtil.trySetAccessible(constructor, true) == null); + + String longString = new String(new char[70000]).replace("\0", "*"); + DefaultFullHttpRequest request = + new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/", bb(longString)); + + HttpServerUpgradeHandler.UpgradeEvent upgradeEvent = constructor.newInstance( + "HTTP/2", request); + channel.pipeline().fireUserEventTriggered(upgradeEvent); + } + + @Test + public void priorityForNonExistingStream() { + writeHeaderAndAssert(1); + + frameInboundWriter.writeInboundPriority(3, 1, (short) 31, true); + } + + @Test + public void priorityForExistingStream() { + writeHeaderAndAssert(1); + writeHeaderAndAssert(3); + frameInboundWriter.writeInboundPriority(3, 1, (short) 31, true); + + assertInboundStreamFrame(3, new DefaultHttp2PriorityFrame(1, (short) 31, true)); + } + + private void writeHeaderAndAssert(int streamId) { + frameInboundWriter.writeInboundHeaders(streamId, request, 31, false); + + Http2Stream stream = frameCodec.connection().stream(streamId); + assertNotNull(stream); + assertEquals(State.OPEN, stream.state()); + + assertInboundStreamFrame(streamId, new 
DefaultHttp2HeadersFrame(request, false, 31)); + } + + private void assertInboundStreamFrame(int expectedId, Http2StreamFrame streamFrame) { + Http2StreamFrame inboundFrame = inboundHandler.readInbound(); + Http2FrameStream stream2 = inboundFrame.stream(); + assertNotNull(stream2); + assertEquals(expectedId, stream2.id()); + assertEquals(inboundFrame, streamFrame.stream(stream2)); + } + + private ChannelHandlerContext eqFrameCodecCtx() { + return eq(frameCodec.ctx); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java new file mode 100644 index 0000000..65ea48b --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameInboundWriter.java @@ -0,0 +1,340 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelProgressivePromise; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.concurrent.EventExecutor; + +import java.net.SocketAddress; + +/** + * Utility class which allows easy writing of HTTP2 frames via {@link EmbeddedChannel#writeInbound(Object...)}. + */ +final class Http2FrameInboundWriter { + + private final ChannelHandlerContext ctx; + private final Http2FrameWriter writer; + + Http2FrameInboundWriter(EmbeddedChannel channel) { + this(channel, new DefaultHttp2FrameWriter()); + } + + Http2FrameInboundWriter(EmbeddedChannel channel, Http2FrameWriter writer) { + ctx = new WriteInboundChannelHandlerContext(channel); + this.writer = writer; + } + + void writeInboundData(int streamId, ByteBuf data, int padding, boolean endStream) { + writer.writeData(ctx, streamId, data, padding, endStream, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundHeaders(int streamId, Http2Headers headers, + int padding, boolean endStream) { + writer.writeHeaders(ctx, streamId, headers, padding, endStream, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundHeaders(int streamId, Http2Headers headers, + int streamDependency, short weight, boolean exclusive, int padding, boolean endStream) { + writer.writeHeaders(ctx, streamId, headers, streamDependency, + weight, exclusive, padding, endStream, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundPriority(int streamId, int streamDependency, + short weight, 
boolean exclusive) { + writer.writePriority(ctx, streamId, streamDependency, weight, + exclusive, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundRstStream(int streamId, long errorCode) { + writer.writeRstStream(ctx, streamId, errorCode, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundSettings(Http2Settings settings) { + writer.writeSettings(ctx, settings, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundSettingsAck() { + writer.writeSettingsAck(ctx, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundPing(boolean ack, long data) { + writer.writePing(ctx, ack, data, ctx.newPromise()).syncUninterruptibly(); + } + + void writePushPromise(int streamId, int promisedStreamId, + Http2Headers headers, int padding) { + writer.writePushPromise(ctx, streamId, promisedStreamId, + headers, padding, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundGoAway(int lastStreamId, long errorCode, ByteBuf debugData) { + writer.writeGoAway(ctx, lastStreamId, errorCode, debugData, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundWindowUpdate(int streamId, int windowSizeIncrement) { + writer.writeWindowUpdate(ctx, streamId, windowSizeIncrement, ctx.newPromise()).syncUninterruptibly(); + } + + void writeInboundFrame(byte frameType, int streamId, + Http2Flags flags, ByteBuf payload) { + writer.writeFrame(ctx, frameType, streamId, flags, payload, ctx.newPromise()).syncUninterruptibly(); + } + + private static final class WriteInboundChannelHandlerContext extends ChannelOutboundHandlerAdapter + implements ChannelHandlerContext { + private final EmbeddedChannel channel; + + WriteInboundChannelHandlerContext(EmbeddedChannel channel) { + this.channel = channel; + } + + @Override + public Channel channel() { + return channel; + } + + @Override + public EventExecutor executor() { + return channel.eventLoop(); + } + + @Override + public String name() { + return "WriteInbound"; + } + + @Override + 
public ChannelHandler handler() { + return this; + } + + @Override + public boolean isRemoved() { + return false; + } + + @Override + public ChannelHandlerContext fireChannelRegistered() { + channel.pipeline().fireChannelRegistered(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelUnregistered() { + channel.pipeline().fireChannelUnregistered(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelActive() { + channel.pipeline().fireChannelActive(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelInactive() { + channel.pipeline().fireChannelInactive(); + return this; + } + + @Override + public ChannelHandlerContext fireExceptionCaught(Throwable cause) { + channel.pipeline().fireExceptionCaught(cause); + return this; + } + + @Override + public ChannelHandlerContext fireUserEventTriggered(Object evt) { + channel.pipeline().fireUserEventTriggered(evt); + return this; + } + + @Override + public ChannelHandlerContext fireChannelRead(Object msg) { + channel.pipeline().fireChannelRead(msg); + return this; + } + + @Override + public ChannelHandlerContext fireChannelReadComplete() { + channel.pipeline().fireChannelReadComplete(); + return this; + } + + @Override + public ChannelHandlerContext fireChannelWritabilityChanged() { + channel.pipeline().fireChannelWritabilityChanged(); + return this; + } + + @Override + public ChannelHandlerContext read() { + channel.read(); + return this; + } + + @Override + public ChannelHandlerContext flush() { + channel.pipeline().fireChannelReadComplete(); + return this; + } + + @Override + public ChannelPipeline pipeline() { + return channel.pipeline(); + } + + @Override + public ByteBufAllocator alloc() { + return channel.alloc(); + } + + @Override + public Attribute attr(AttributeKey key) { + return channel.attr(key); + } + + @Override + public boolean hasAttr(AttributeKey key) { + return channel.hasAttr(key); + } + + @Override + public ChannelFuture 
bind(SocketAddress localAddress) { + return channel.bind(localAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress) { + return channel.connect(remoteAddress); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress) { + return channel.connect(remoteAddress, localAddress); + } + + @Override + public ChannelFuture disconnect() { + return channel.disconnect(); + } + + @Override + public ChannelFuture close() { + return channel.close(); + } + + @Override + public ChannelFuture deregister() { + return channel.deregister(); + } + + @Override + public ChannelFuture bind(SocketAddress localAddress, ChannelPromise promise) { + return channel.bind(localAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, ChannelPromise promise) { + return channel.connect(remoteAddress, promise); + } + + @Override + public ChannelFuture connect(SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) { + return channel.connect(remoteAddress, localAddress, promise); + } + + @Override + public ChannelFuture disconnect(ChannelPromise promise) { + return channel.disconnect(promise); + } + + @Override + public ChannelFuture close(ChannelPromise promise) { + return channel.close(promise); + } + + @Override + public ChannelFuture deregister(ChannelPromise promise) { + return channel.deregister(promise); + } + + @Override + public ChannelFuture write(Object msg) { + return write(msg, newPromise()); + } + + @Override + public ChannelFuture write(Object msg, ChannelPromise promise) { + return writeAndFlush(msg, promise); + } + + @Override + public ChannelFuture writeAndFlush(Object msg, ChannelPromise promise) { + try { + channel.writeInbound(msg); + channel.runPendingTasks(); + promise.setSuccess(); + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + + @Override + public ChannelFuture writeAndFlush(Object msg) { + 
return writeAndFlush(msg, newPromise()); + } + + @Override + public ChannelPromise newPromise() { + return channel.newPromise(); + } + + @Override + public ChannelProgressivePromise newProgressivePromise() { + return channel.newProgressivePromise(); + } + + @Override + public ChannelFuture newSucceededFuture() { + return channel.newSucceededFuture(); + } + + @Override + public ChannelFuture newFailedFuture(Throwable cause) { + return channel.newFailedFuture(cause); + } + + @Override + public ChannelPromise voidPromise() { + return channel.voidPromise(); + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameRoundtripTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameRoundtripTest.java new file mode 100644 index 0000000..ff33281 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2FrameRoundtripTest.java @@ -0,0 +1,491 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.EmptyByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.AsciiString; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.GlobalEventExecutor; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.LinkedList; +import java.util.List; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_PADDING; +import static io.netty.handler.codec.http2.Http2HeadersEncoder.NEVER_SENSITIVE; +import static io.netty.handler.codec.http2.Http2TestUtil.newTestDecoder; +import static io.netty.handler.codec.http2.Http2TestUtil.newTestEncoder; +import static io.netty.handler.codec.http2.Http2TestUtil.randomString; +import static io.netty.util.CharsetUtil.UTF_8; +import static java.lang.Math.min; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.isA; +import static org.mockito.Mockito.atLeastOnce; +import static 
org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +/** + * Tests encoding/decoding each HTTP2 frame type. + */ +public class Http2FrameRoundtripTest { + private static final byte[] MESSAGE = "hello world".getBytes(UTF_8); + private static final int STREAM_ID = 0x7FFFFFFF; + private static final int WINDOW_UPDATE = 0x7FFFFFFF; + private static final long ERROR_CODE = 0xFFFFFFFFL; + + @Mock + private Http2FrameListener listener; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private EventExecutor executor; + + @Mock + private Channel channel; + + @Mock + private ByteBufAllocator alloc; + + private Http2FrameWriter writer; + private Http2FrameReader reader; + private final List needReleasing = new LinkedList(); + + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + when(ctx.alloc()).thenReturn(alloc); + when(ctx.executor()).thenReturn(executor); + when(ctx.channel()).thenReturn(channel); + doAnswer(new Answer() { + @Override + public ByteBuf answer(InvocationOnMock in) throws Throwable { + return Unpooled.buffer(); + } + }).when(alloc).buffer(); + doAnswer(new Answer() { + @Override + public ByteBuf answer(InvocationOnMock in) throws Throwable { + return Unpooled.buffer((Integer) in.getArguments()[0]); + } + }).when(alloc).buffer(anyInt()); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { + return new DefaultChannelPromise(channel, GlobalEventExecutor.INSTANCE); + } + }).when(ctx).newPromise(); + + writer = new DefaultHttp2FrameWriter(new DefaultHttp2HeadersEncoder(NEVER_SENSITIVE, newTestEncoder())); + reader = new DefaultHttp2FrameReader(new DefaultHttp2HeadersDecoder(false, false, newTestDecoder())); + } + + @AfterEach + public void teardown() { + try { + // Release all of the buffers. 
+            for (ByteBuf buf : needReleasing) {
+                buf.release();
+            }
+            // Now verify that all of the reference counts are zero.
+            for (ByteBuf buf : needReleasing) {
+                int expectedFinalRefCount = 0;
+                if (buf.isReadOnly() || buf instanceof EmptyByteBuf) {
+                    // Special case for when we're writing slices of the padding buffer.
+                    expectedFinalRefCount = 1;
+                }
+                assertEquals(expectedFinalRefCount, buf.refCnt());
+            }
+        } finally {
+            needReleasing.clear();
+        }
+    }
+
+    @Test
+    public void emptyDataShouldMatch() throws Exception {
+        final ByteBuf data = EMPTY_BUFFER;
+        writer.writeData(ctx, STREAM_ID, data.slice(), 0, false, ctx.newPromise());
+        readFrames();
+        verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(0), eq(false));
+    }
+
+    @Test
+    public void dataShouldMatch() throws Exception {
+        final ByteBuf data = data(10);
+        writer.writeData(ctx, STREAM_ID, data.slice(), 1, false, ctx.newPromise());
+        readFrames();
+        verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(1), eq(false));
+    }
+
+    @Test
+    public void dataWithPaddingShouldMatch() throws Exception {
+        final ByteBuf data = data(10);
+        writer.writeData(ctx, STREAM_ID, data.slice(), MAX_PADDING, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), eq(data), eq(MAX_PADDING), eq(true));
+    }
+
+    @Test
+    public void largeDataFrameShouldMatch() throws Exception {
+        // Create a large message to force chunking.
+        final ByteBuf originalData = data(1024 * 1024);
+        final int originalPadding = 100;
+        final boolean endOfStream = true;
+
+        writer.writeData(ctx, STREAM_ID, originalData.slice(), originalPadding,
+                endOfStream, ctx.newPromise());
+        readFrames();
+
+        // Verify that at least one frame was sent with eos=false and exactly one with eos=true.
+        verify(listener, atLeastOnce()).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class),
+                anyInt(), eq(false));
+        verify(listener).onDataRead(eq(ctx), eq(STREAM_ID), any(ByteBuf.class),
+                anyInt(), eq(true));
+
+        // Capture the read data and padding.
+        ArgumentCaptor<ByteBuf> dataCaptor = ArgumentCaptor.forClass(ByteBuf.class);
+        ArgumentCaptor<Integer> paddingCaptor = ArgumentCaptor.forClass(Integer.class);
+        verify(listener, atLeastOnce()).onDataRead(eq(ctx), eq(STREAM_ID), dataCaptor.capture(),
+                paddingCaptor.capture(), anyBoolean());
+
+        // Make sure the data matches the original.
+        for (ByteBuf chunk : dataCaptor.getAllValues()) {
+            ByteBuf originalChunk = originalData.readSlice(chunk.readableBytes());
+            assertEquals(originalChunk, chunk);
+        }
+        assertFalse(originalData.isReadable());
+
+        // Make sure the padding matches the original.
+        int totalReadPadding = 0;
+        for (int framePadding : paddingCaptor.getAllValues()) {
+            totalReadPadding += framePadding;
+        }
+        assertEquals(originalPadding, totalReadPadding);
+    }
+
+    @Test
+    public void emptyHeadersShouldMatch() throws Exception {
+        final Http2Headers headers = EmptyHttp2Headers.INSTANCE;
+        writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true));
+    }
+
+    @Test
+    public void emptyHeadersWithPaddingShouldMatch() throws Exception {
+        final Http2Headers headers = EmptyHttp2Headers.INSTANCE;
+        writer.writeHeaders(ctx, STREAM_ID, headers, MAX_PADDING, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(MAX_PADDING), eq(true));
+    }
+
+    @Test
+    public void binaryHeadersWithoutPriorityShouldMatch() throws Exception {
+        final Http2Headers headers = binaryHeaders();
+        writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true));
+    }
+
+    @Test
+    public void headersFrameWithoutPriorityShouldMatch() throws Exception {
+        final Http2Headers headers = headers();
+        writer.writeHeaders(ctx, STREAM_ID, headers, 0, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(0), eq(true));
+    }
+
+    @Test
+    public void headersFrameWithPriorityShouldMatch() throws Exception {
+        final Http2Headers headers = headers();
+        writer.writeHeaders(ctx, STREAM_ID, headers, 4, (short) 255, true, 0, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(4), eq((short) 255),
+                eq(true), eq(0), eq(true));
+    }
+
+    @Test
+    public void headersWithPaddingWithoutPriorityShouldMatch() throws Exception {
+        final Http2Headers headers = headers();
+        writer.writeHeaders(ctx, STREAM_ID, headers, MAX_PADDING, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(MAX_PADDING), eq(true));
+    }
+
+    @Test
+    public void headersWithPaddingWithPriorityShouldMatch() throws Exception {
+        final Http2Headers headers = headers();
+        writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 1, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true),
+                eq(1), eq(true));
+    }
+
+    @Test
+    public void continuedHeadersShouldMatch() throws Exception {
+        final Http2Headers headers = largeHeaders();
+        writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, 0, true, ctx.newPromise());
+        readFrames();
+        verify(listener)
+                .onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true), eq(0), eq(true));
+    }
+
+    @Test
+    public void continuedHeadersWithPaddingShouldMatch() throws Exception {
+        final Http2Headers headers = largeHeaders();
+        writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, MAX_PADDING, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onHeadersRead(eq(ctx), eq(STREAM_ID), eq(headers), eq(2), eq((short) 3), eq(true),
+                eq(MAX_PADDING), eq(true));
+    }
+
+    @Test
+    public void headersThatAreTooBigShouldFail() throws Exception {
+        reader = new DefaultHttp2FrameReader(false);
+        final int maxListSize = 100;
+        reader.configuration().headersConfiguration().maxHeaderListSize(maxListSize, maxListSize);
+        final Http2Headers headers = headersOfSize(maxListSize + 1);
+        writer.writeHeaders(ctx, STREAM_ID, headers, 2, (short) 3, true, MAX_PADDING, true, ctx.newPromise());
+        assertThrows(Http2Exception.class, new Executable() {
+            @Override
+            public void execute() throws Throwable {
+                readFrames();
+            }
+        });
+        verify(listener, never()).onHeadersRead(any(ChannelHandlerContext.class), anyInt(),
+                any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(),
+                anyBoolean());
+    }
+
+    @Test
+    public void emptyPushPromiseShouldMatch() throws Exception {
+        final Http2Headers headers = EmptyHttp2Headers.INSTANCE;
+        writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0, ctx.newPromise());
+        readFrames();
+        verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0));
+    }
+
+    @Test
+    public void pushPromiseFrameShouldMatch() throws Exception {
+        final Http2Headers headers = headers();
+        writer.writePushPromise(ctx, STREAM_ID, 1, headers, 5, ctx.newPromise());
+        readFrames();
+        verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(1), eq(headers), eq(5));
+    }
+
+    @Test
+    public void pushPromiseWithPaddingShouldMatch() throws Exception {
+        final Http2Headers headers = headers();
+        writer.writePushPromise(ctx, STREAM_ID, 2, headers, MAX_PADDING, ctx.newPromise());
+        readFrames();
+        verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(MAX_PADDING));
+    }
+
+    @Test
+    public void continuedPushPromiseShouldMatch() throws Exception {
+        final Http2Headers headers = largeHeaders();
+        writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0,
+                ctx.newPromise());
+        readFrames();
+        verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0));
+    }
+
+    @Test
+    public void continuedPushPromiseWithPaddingShouldMatch() throws Exception {
+        final Http2Headers headers = largeHeaders();
+        writer.writePushPromise(ctx, STREAM_ID, 2, headers, 0xFF, ctx.newPromise());
+        readFrames();
+        verify(listener).onPushPromiseRead(eq(ctx), eq(STREAM_ID), eq(2), eq(headers), eq(0xFF));
+    }
+
+    @Test
+    public void goAwayFrameShouldMatch() throws Exception {
+        final String text = "test";
+        final ByteBuf data = buf(text.getBytes());
+
+        writer.writeGoAway(ctx, STREAM_ID, ERROR_CODE, data.slice(), ctx.newPromise());
+        readFrames();
+
+        ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
+        verify(listener).onGoAwayRead(eq(ctx), eq(STREAM_ID), eq(ERROR_CODE), captor.capture());
+        assertEquals(data, captor.getValue());
+    }
+
+    @Test
+    public void pingFrameShouldMatch() throws Exception {
+        writer.writePing(ctx, false, 1234567, ctx.newPromise());
+        readFrames();
+
+        ArgumentCaptor<Long> captor = ArgumentCaptor.forClass(long.class);
+        verify(listener).onPingRead(eq(ctx), captor.capture());
+        assertEquals(1234567, (long) captor.getValue());
+    }
+
+    @Test
+    public void pingAckFrameShouldMatch() throws Exception {
+        writer.writePing(ctx, true, 1234567, ctx.newPromise());
+        readFrames();
+
+        ArgumentCaptor<Long> captor = ArgumentCaptor.forClass(long.class);
+        verify(listener).onPingAckRead(eq(ctx), captor.capture());
+        assertEquals(1234567, (long) captor.getValue());
+    }
+
+    @Test
+    public void priorityFrameShouldMatch() throws Exception {
+        writer.writePriority(ctx, STREAM_ID, 1, (short) 1, true, ctx.newPromise());
+        readFrames();
+        verify(listener).onPriorityRead(eq(ctx), eq(STREAM_ID), eq(1), eq((short) 1), eq(true));
+    }
+
+    @Test
+    public void rstStreamFrameShouldMatch() throws Exception {
+        writer.writeRstStream(ctx, STREAM_ID, ERROR_CODE, ctx.newPromise());
+        readFrames();
+        verify(listener).onRstStreamRead(eq(ctx), eq(STREAM_ID), eq(ERROR_CODE));
+    }
+
+    @Test
+    public void emptySettingsFrameShouldMatch() throws Exception {
+        final Http2Settings settings = new Http2Settings();
+        writer.writeSettings(ctx, settings, ctx.newPromise());
+        readFrames();
+        verify(listener).onSettingsRead(eq(ctx), eq(settings));
+    }
+
+    @Test
+    public void settingsShouldStripShouldMatch() throws Exception {
+        final Http2Settings settings = new Http2Settings();
+        settings.pushEnabled(true);
+        settings.headerTableSize(4096);
+        settings.initialWindowSize(123);
+        settings.maxConcurrentStreams(456);
+
+        writer.writeSettings(ctx, settings, ctx.newPromise());
+        readFrames();
+        verify(listener).onSettingsRead(eq(ctx), eq(settings));
+    }
+
+    @Test
+    public void settingsAckShouldMatch() throws Exception {
+        writer.writeSettingsAck(ctx, ctx.newPromise());
+        readFrames();
+        verify(listener).onSettingsAckRead(eq(ctx));
+    }
+
+    @Test
+    public void windowUpdateFrameShouldMatch() throws Exception {
+        writer.writeWindowUpdate(ctx, STREAM_ID, WINDOW_UPDATE, ctx.newPromise());
+        readFrames();
+        verify(listener).onWindowUpdateRead(eq(ctx), eq(STREAM_ID), eq(WINDOW_UPDATE));
+    }
+
+    private void readFrames() throws Http2Exception {
+        // Now read all of the written frames.
+        ByteBuf write = captureWrites();
+        reader.readFrame(ctx, write, listener);
+    }
+
+    private static ByteBuf data(int size) {
+        byte[] data = new byte[size];
+        for (int ix = 0; ix < data.length;) {
+            int length = min(MESSAGE.length, data.length - ix);
+            System.arraycopy(MESSAGE, 0, data, ix, length);
+            ix += length;
+        }
+        return buf(data);
+    }
+
+    private static ByteBuf buf(byte[] bytes) {
+        return Unpooled.wrappedBuffer(bytes);
+    }
+
+    private <T extends ByteBuf> T releaseLater(T buf) {
+        needReleasing.add(buf);
+        return buf;
+    }
+
+    private ByteBuf captureWrites() {
+        ArgumentCaptor<ByteBuf> captor = ArgumentCaptor.forClass(ByteBuf.class);
+        verify(ctx, atLeastOnce()).write(captor.capture(), isA(ChannelPromise.class));
+        CompositeByteBuf composite = releaseLater(Unpooled.compositeBuffer());
+        for (ByteBuf buf : captor.getAllValues()) {
+            buf = releaseLater(buf.retain());
+            composite.addComponent(true, buf);
+        }
+        return composite;
+    }
+
+    private static Http2Headers headers() {
+        return new DefaultHttp2Headers(false).method(AsciiString.of("GET")).scheme(AsciiString.of("https"))
+                .authority(AsciiString.of("example.org")).path(AsciiString.of("/some/path/resource2"))
+                .add(randomString(), randomString());
+    }
+
+    private static Http2Headers largeHeaders() {
+        DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
+        for (int i = 0; i < 100; ++i) {
+            String key = "this-is-a-test-header-key-" + i;
+            String value = "this-is-a-test-header-value-" + i;
+            headers.add(AsciiString.of(key), AsciiString.of(value));
+        }
+        return headers;
+    }
+
+    private static Http2Headers headersOfSize(final int minSize) {
+        final AsciiString singleByte = new AsciiString(new byte[]{0}, false);
+        DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
+        for (int size = 0; size < minSize; size += 2) {
+            headers.add(singleByte, singleByte);
+        }
+        return headers;
+    }
+
+    private static Http2Headers binaryHeaders() {
+        DefaultHttp2Headers headers = new DefaultHttp2Headers(false);
+        for (int ix = 0; ix < 10; ++ix)
+        {
+            headers.add(randomString(), randomString());
+        }
+        return headers;
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2HeaderBlockIOTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2HeaderBlockIOTest.java
new file mode 100644
index 0000000..927f4c1
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2HeaderBlockIOTest.java
@@ -0,0 +1,101 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.netty.handler.codec.http2;
+
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.util.AsciiString;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+
+import static io.netty.handler.codec.http2.Http2TestUtil.randomString;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+
+/**
+ * Tests for encoding/decoding HTTP2 header blocks.
+ */
+public class Http2HeaderBlockIOTest {
+
+    private DefaultHttp2HeadersDecoder decoder;
+    private DefaultHttp2HeadersEncoder encoder;
+    private ByteBuf buffer;
+
+    @BeforeEach
+    public void setup() {
+        encoder = new DefaultHttp2HeadersEncoder();
+        decoder = new DefaultHttp2HeadersDecoder(false);
+        buffer = Unpooled.buffer();
+    }
+
+    @AfterEach
+    public void teardown() {
+        buffer.release();
+    }
+
+    @Test
+    public void roundtripShouldBeSuccessful() throws Http2Exception {
+        Http2Headers in = headers();
+        assertRoundtripSuccessful(in);
+    }
+
+    @Test
+    public void successiveCallsShouldSucceed() throws Http2Exception {
+        Http2Headers in = new DefaultHttp2Headers().method(new AsciiString("GET")).scheme(new AsciiString("https"))
+                .authority(new AsciiString("example.org")).path(new AsciiString("/some/path"))
+                .add(new AsciiString("accept"), new AsciiString("*/*"));
+        assertRoundtripSuccessful(in);
+
+        in = new DefaultHttp2Headers().method(new AsciiString("GET")).scheme(new AsciiString("https"))
+                .authority(new AsciiString("example.org")).path(new AsciiString("/some/path/resource1"))
+                .add(new AsciiString("accept"), new AsciiString("image/jpeg"))
+                .add(new AsciiString("cache-control"), new AsciiString("no-cache"));
+        assertRoundtripSuccessful(in);
+
+        in = new DefaultHttp2Headers().method(new AsciiString("GET")).scheme(new AsciiString("https"))
+                .authority(new AsciiString("example.org")).path(new AsciiString("/some/path/resource2"))
+                .add(new AsciiString("accept"), new AsciiString("image/png"))
+                .add(new AsciiString("cache-control"), new AsciiString("no-cache"));
+        assertRoundtripSuccessful(in);
+    }
+
+    @Test
+    public void setMaxHeaderSizeShouldBeSuccessful() throws Http2Exception {
+        encoder.maxHeaderTableSize(10);
+        Http2Headers in = headers();
+        assertRoundtripSuccessful(in);
+        assertEquals(10, decoder.maxHeaderTableSize());
+    }
+
+    private void assertRoundtripSuccessful(Http2Headers in) throws Http2Exception {
+        encoder.encodeHeaders(3 /* randomly chosen */,
+                in, buffer);
+
+        Http2Headers out = decoder.decodeHeaders(0, buffer);
+        assertEquals(in, out);
+    }
+
+    private static Http2Headers headers() {
+        return new DefaultHttp2Headers(false).method(new AsciiString("GET")).scheme(new AsciiString("https"))
+                .authority(new AsciiString("example.org")).path(new AsciiString("/some/path/resource2"))
+                .add(new AsciiString("accept"), new AsciiString("image/png"))
+                .add(new AsciiString("cache-control"), new AsciiString("no-cache"))
+                .add(new AsciiString("custom"), new AsciiString("value1"))
+                .add(new AsciiString("custom"), new AsciiString("value2"))
+                .add(new AsciiString("custom"), new AsciiString("value3"))
+                .add(new AsciiString("custom"), new AsciiString("custom4"))
+                .add(randomString(), randomString());
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameConnectionDecoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameConnectionDecoderTest.java
new file mode 100644
index 0000000..5ec3538
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameConnectionDecoderTest.java
@@ -0,0 +1,28 @@
+/*
+ * Copyright 2023 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.handler.codec.http2;
+
+public class Http2MaxRstFrameConnectionDecoderTest extends AbstractDecoratingHttp2ConnectionDecoderTest {
+
+    @Override
+    protected DecoratingHttp2ConnectionDecoder newDecoder(Http2ConnectionDecoder decoder) {
+        return new Http2MaxRstFrameDecoder(decoder, 200, 30);
+    }
+
+    @Override
+    protected Class<? extends Http2FrameListener> delegatingFrameListenerType() {
+        return Http2MaxRstFrameListener.class;
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameListenerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameListenerTest.java
new file mode 100644
index 0000000..381834b
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MaxRstFrameListenerTest.java
@@ -0,0 +1,68 @@
+/*
+ * Copyright 2023 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.channel.ChannelHandlerContext;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+import org.mockito.Mock;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.mockito.ArgumentMatchers.anyInt;
+import static org.mockito.Mockito.eq;
+import static org.mockito.Mockito.times;
+import static org.mockito.Mockito.verify;
+import static org.mockito.MockitoAnnotations.initMocks;
+
+public class Http2MaxRstFrameListenerTest {
+
+    @Mock
+    private Http2FrameListener frameListener;
+    @Mock
+    private ChannelHandlerContext ctx;
+
+    private Http2MaxRstFrameListener listener;
+
+    @BeforeEach
+    public void setUp() {
+        initMocks(this);
+    }
+
+    @Test
+    public void testMaxRstFramesReached() throws Http2Exception {
+        listener = new Http2MaxRstFrameListener(frameListener, 1, 10);
+        listener.onRstStreamRead(ctx, 1, Http2Error.STREAM_CLOSED.code());
+
+        Http2Exception ex = assertThrows(Http2Exception.class, new Executable() {
+            @Override
+            public void execute() throws Throwable {
+                listener.onRstStreamRead(ctx, 2, Http2Error.STREAM_CLOSED.code());
+            }
+        });
+        assertEquals(Http2Error.ENHANCE_YOUR_CALM, ex.error());
+        verify(frameListener, times(1)).onRstStreamRead(eq(ctx), anyInt(), eq(Http2Error.STREAM_CLOSED.code()));
+    }
+
+    @Test
+    public void testRstFrames() throws Exception {
+        listener = new Http2MaxRstFrameListener(frameListener, 1, 1);
+        listener.onRstStreamRead(ctx, 1, Http2Error.STREAM_CLOSED.code());
+        Thread.sleep(1100);
+        listener.onRstStreamRead(ctx, 1, Http2Error.STREAM_CLOSED.code());
+        verify(frameListener, times(2)).onRstStreamRead(eq(ctx), anyInt(), eq(Http2Error.STREAM_CLOSED.code()));
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexClientUpgradeTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexClientUpgradeTest.java
new file mode 100644
index 0000000..d3a2f65
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexClientUpgradeTest.java
@@ -0,0 +1,96 @@
+/*
+ * Copyright 2018 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.embedded.EmbeddedChannel;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+
+public abstract class Http2MultiplexClientUpgradeTest<C extends Http2FrameCodec> {
+
+    @ChannelHandler.Sharable
+    static final class NoopHandler extends ChannelInboundHandlerAdapter {
+        @Override
+        public void channelActive(ChannelHandlerContext ctx) {
+            ctx.channel().close();
+        }
+    }
+
+    private static final class UpgradeHandler extends ChannelInboundHandlerAdapter {
+        Http2Stream.State stateOnActive;
+        int streamId;
+        boolean channelInactiveCalled;
+
+        @Override
+        public void channelActive(ChannelHandlerContext ctx) throws Exception {
+            Http2StreamChannel ch = (Http2StreamChannel) ctx.channel();
+            stateOnActive = ch.stream().state();
+            streamId = ch.stream().id();
+            super.channelActive(ctx);
+        }
+
+        @Override
+        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+            channelInactiveCalled = true;
+            super.channelInactive(ctx);
+        }
+    }
+
+    protected abstract C newCodec(ChannelHandler upgradeHandler);
+
+    protected abstract ChannelHandler newMultiplexer(ChannelHandler upgradeHandler);
+
+    @Test
+    public void upgradeHandlerGetsActivated() throws Exception {
+        UpgradeHandler upgradeHandler = new UpgradeHandler();
+        C codec = newCodec(upgradeHandler);
+        EmbeddedChannel ch = new EmbeddedChannel(codec, newMultiplexer(upgradeHandler));
+
+        codec.onHttpClientUpgrade();
+
+        assertFalse(upgradeHandler.stateOnActive.localSideOpen());
+        assertTrue(upgradeHandler.stateOnActive.remoteSideOpen());
+        assertNotNull(codec.connection().stream(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID).getProperty(codec.streamKey));
+        assertEquals(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, upgradeHandler.streamId);
+        assertTrue(ch.finishAndReleaseAll());
+        assertTrue(upgradeHandler.channelInactiveCalled);
+    }
+
+    @Test
+    public void clientUpgradeWithoutUpgradeHandlerThrowsHttp2Exception() throws Http2Exception {
+        final C codec = newCodec(null);
+        final EmbeddedChannel ch = new EmbeddedChannel(codec, newMultiplexer(null));
+
+        assertThrows(Http2Exception.class, new Executable() {
+            @Override
+            public void execute() throws Http2Exception {
+                try {
+                    codec.onHttpClientUpgrade();
+                } finally {
+                    ch.finishAndReleaseAll();
+                }
+            }
+        });
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilderTest.java
new file mode 100644
index 0000000..4934515
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecBuilderTest.java
@@ -0,0 +1,268 @@
+/*
+ * Copyright 2016 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.bootstrap.Bootstrap;
+import io.netty.bootstrap.ServerBootstrap;
+import io.netty.buffer.ByteBuf;
+import io.netty.buffer.Unpooled;
+import io.netty.channel.Channel;
+import io.netty.channel.ChannelHandler;
+import io.netty.channel.ChannelHandler.Sharable;
+import io.netty.channel.ChannelHandlerAdapter;
+import io.netty.channel.ChannelHandlerContext;
+import io.netty.channel.ChannelInboundHandlerAdapter;
+import io.netty.channel.ChannelInitializer;
+import io.netty.channel.DefaultEventLoop;
+import io.netty.channel.EventLoopGroup;
+import io.netty.channel.local.LocalAddress;
+import io.netty.channel.local.LocalChannel;
+import io.netty.channel.local.LocalServerChannel;
+import org.junit.jupiter.api.AfterAll;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeAll;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Test;
+import org.junit.jupiter.api.function.Executable;
+
+import java.util.concurrent.CountDownLatch;
+
+import static io.netty.handler.codec.http2.Http2CodecUtil.isStreamIdValid;
+import static java.util.concurrent.TimeUnit.SECONDS;
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertFalse;
+import static org.junit.jupiter.api.Assertions.assertNotNull;
+import static org.junit.jupiter.api.Assertions.assertThrows;
+import static org.junit.jupiter.api.Assertions.assertTrue;
+import static org.junit.jupiter.api.Assertions.fail;
+
+/**
+ * Unit tests for {@link Http2MultiplexCodec}.
+ */
+public class Http2MultiplexCodecBuilderTest {
+
+    private static EventLoopGroup group;
+    private Channel serverChannel;
+    private volatile Channel serverConnectedChannel;
+    private Channel clientChannel;
+    private LastInboundHandler serverLastInboundHandler;
+
+    @BeforeAll
+    public static void init() {
+        group = new DefaultEventLoop();
+    }
+
+    @BeforeEach
+    public void setUp() throws InterruptedException {
+        final CountDownLatch serverChannelLatch = new CountDownLatch(1);
+        LocalAddress serverAddress = new LocalAddress(getClass());
+        serverLastInboundHandler = new SharableLastInboundHandler();
+        ServerBootstrap sb = new ServerBootstrap()
+                .channel(LocalServerChannel.class)
+                .group(group)
+                .childHandler(new ChannelInitializer<Channel>() {
+                    @Override
+                    protected void initChannel(Channel ch) throws Exception {
+                        serverConnectedChannel = ch;
+                        ch.pipeline().addLast(new Http2MultiplexCodecBuilder(true, new ChannelInitializer<Channel>() {
+
+                            @Override
+                            protected void initChannel(Channel ch) throws Exception {
+                                ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
+                                    private boolean writable;
+
+                                    @Override
+                                    public void channelActive(ChannelHandlerContext ctx) throws Exception {
+                                        writable |= ctx.channel().isWritable();
+                                        super.channelActive(ctx);
+                                    }
+
+                                    @Override
+                                    public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception {
+                                        writable |= ctx.channel().isWritable();
+                                        super.channelWritabilityChanged(ctx);
+                                    }
+
+                                    @Override
+                                    public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+                                        assertTrue(writable);
+                                        super.channelInactive(ctx);
+                                    }
+                                });
+                                ch.pipeline().addLast(serverLastInboundHandler);
+                            }
+                        }).build());
+                        serverChannelLatch.countDown();
+                    }
+                });
+        serverChannel = sb.bind(serverAddress).sync().channel();
+
+        Bootstrap cb = new Bootstrap()
+                .channel(LocalChannel.class)
+                .group(group)
+                .handler(new Http2MultiplexCodecBuilder(false, new ChannelInitializer<Channel>() {
+                    @Override
+                    protected void initChannel(Channel ch) throws Exception {
+                        fail("Should not be called for outbound streams");
+                    }
+                }).build());
+        clientChannel = cb.connect(serverAddress).sync().channel();
+        assertTrue(serverChannelLatch.await(5, SECONDS));
+    }
+
+    @AfterAll
+    public static void shutdown() {
+        group.shutdownGracefully(0, 5, SECONDS);
+    }
+
+    @AfterEach
+    public void tearDown() throws Exception {
+        if (clientChannel != null) {
+            clientChannel.close().syncUninterruptibly();
+            clientChannel = null;
+        }
+        if (serverChannel != null) {
+            serverChannel.close().syncUninterruptibly();
+            serverChannel = null;
+        }
+        final Channel serverConnectedChannel = this.serverConnectedChannel;
+        if (serverConnectedChannel != null) {
+            serverConnectedChannel.close().syncUninterruptibly();
+            this.serverConnectedChannel = null;
+        }
+    }
+
+    private Http2StreamChannel newOutboundStream(ChannelHandler handler) {
+        return new Http2StreamChannelBootstrap(clientChannel).handler(handler).open().syncUninterruptibly().getNow();
+    }
+
+    @Test
+    public void multipleOutboundStreams() throws Exception {
+        Http2StreamChannel childChannel1 = newOutboundStream(new TestChannelInitializer());
+        assertTrue(childChannel1.isActive());
+        assertFalse(isStreamIdValid(childChannel1.stream().id()));
+        Http2StreamChannel childChannel2 = newOutboundStream(new TestChannelInitializer());
+        assertTrue(childChannel2.isActive());
+        assertFalse(isStreamIdValid(childChannel2.stream().id()));
+
+        Http2Headers headers1 = new DefaultHttp2Headers();
+        Http2Headers headers2 = new DefaultHttp2Headers();
+        // Test that streams can be made active (headers sent) in different order than the corresponding channels
+        // have been created.
+        childChannel2.writeAndFlush(new DefaultHttp2HeadersFrame(headers2));
+        childChannel1.writeAndFlush(new DefaultHttp2HeadersFrame(headers1));
+
+        Http2HeadersFrame headersFrame2 = serverLastInboundHandler.blockingReadInbound();
+        assertNotNull(headersFrame2);
+        assertEquals(3, headersFrame2.stream().id());
+
+        Http2HeadersFrame headersFrame1 = serverLastInboundHandler.blockingReadInbound();
+        assertNotNull(headersFrame1);
+        assertEquals(5, headersFrame1.stream().id());
+
+        assertEquals(3, childChannel2.stream().id());
+        assertEquals(5, childChannel1.stream().id());
+
+        childChannel1.close();
+        childChannel2.close();
+
+        serverLastInboundHandler.checkException();
+    }
+
+    @Test
+    public void createOutboundStream() throws Exception {
+        Channel childChannel = newOutboundStream(new TestChannelInitializer());
+        assertTrue(childChannel.isRegistered());
+        assertTrue(childChannel.isActive());
+
+        Http2Headers headers = new DefaultHttp2Headers();
+        childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers));
+        ByteBuf data = Unpooled.buffer(100).writeZero(100);
+        try {
+            childChannel.writeAndFlush(new DefaultHttp2DataFrame(data.retainedDuplicate(), true));
+
+            Http2HeadersFrame headersFrame = serverLastInboundHandler.blockingReadInbound();
+            assertNotNull(headersFrame);
+            assertEquals(3, headersFrame.stream().id());
+            assertEquals(headers, headersFrame.headers());
+
+            Http2DataFrame dataFrame = serverLastInboundHandler.blockingReadInbound();
+            assertNotNull(dataFrame);
+            assertEquals(3, dataFrame.stream().id());
+            assertEquals(data, dataFrame.content());
+            assertTrue(dataFrame.isEndStream());
+            dataFrame.release();
+
+            childChannel.close();
+
+            Http2ResetFrame rstFrame = serverLastInboundHandler.blockingReadInbound();
+            assertNotNull(rstFrame);
+            assertEquals(3, rstFrame.stream().id());
+
+            serverLastInboundHandler.checkException();
+        } finally {
+            data.release();
+        }
+    }
+
+    @Sharable
+    private static class SharableLastInboundHandler extends LastInboundHandler {
+
+        @Override
+        public void channelActive(ChannelHandlerContext ctx) throws Exception {
+            ctx.fireChannelActive();
+        }
+
+        @Override
+        public void channelInactive(ChannelHandlerContext ctx) throws Exception {
+            ctx.fireChannelInactive();
+        }
+    }
+
+    private static class SharableChannelHandler1 extends ChannelHandlerAdapter {
+        @Override
+        public boolean isSharable() {
+            return true;
+        }
+    }
+
+    @Sharable
+    private static class SharableChannelHandler2 extends ChannelHandlerAdapter {
+    }
+
+    private static class UnsharableChannelHandler extends ChannelHandlerAdapter {
+        @Override
+        public boolean isSharable() {
+            return false;
+        }
+    }
+
+    @Test
+    public void testSharableCheck() {
+        assertNotNull(Http2MultiplexCodecBuilder.forServer(new SharableChannelHandler1()));
+        assertNotNull(Http2MultiplexCodecBuilder.forServer(new SharableChannelHandler2()));
+    }
+
+    @Test
+    public void testUnsharableHandler() {
+        assertThrows(IllegalArgumentException.class, new Executable() {
+            @Override
+            public void execute() throws Throwable {
+                Http2MultiplexCodecBuilder.forServer(new UnsharableChannelHandler());
+            }
+        });
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java
new file mode 100644
index 0000000..1446385
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecClientUpgradeTest.java
@@ -0,0 +1,34 @@
+/*
+ * Copyright 2019 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.channel.ChannelHandler;
+
+public class Http2MultiplexCodecClientUpgradeTest extends Http2MultiplexClientUpgradeTest<Http2MultiplexCodec> {
+
+    @Override
+    protected Http2MultiplexCodec newCodec(ChannelHandler upgradeHandler) {
+        Http2MultiplexCodecBuilder builder = Http2MultiplexCodecBuilder.forClient(new NoopHandler());
+        if (upgradeHandler != null) {
+            builder.withUpgradeStreamHandler(upgradeHandler);
+        }
+        return builder.build();
+    }
+
+    @Override
+    protected ChannelHandler newMultiplexer(ChannelHandler upgradeHandler) {
+        return null;
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java
new file mode 100644
index 0000000..2abed4d
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexCodecTest.java
@@ -0,0 +1,40 @@
+/*
+ * Copyright 2019 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.handler.codec.http2;
+
+import io.netty.channel.ChannelHandler;
+
+public class Http2MultiplexCodecTest extends Http2MultiplexTest<Http2FrameCodec> {
+
+    @Override
+    protected Http2FrameCodec newCodec(TestChannelInitializer childChannelInitializer, Http2FrameWriter frameWriter) {
+        return new Http2MultiplexCodecBuilder(true, childChannelInitializer).frameWriter(frameWriter).build();
+    }
+
+    @Override
+    protected ChannelHandler newMultiplexer(TestChannelInitializer childChannelInitializer) {
+        return null;
+    }
+
+    @Override
+    protected boolean useUserEventForResetFrame() {
+        return false;
+    }
+
+    @Override
+    protected boolean ignoreWindowUpdateFrames() {
+        return false;
+    }
+}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerClientUpgradeTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerClientUpgradeTest.java
new file mode 100644
index 0000000..68e55a6
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerClientUpgradeTest.java
@@ -0,0 +1,30 @@
+/*
+ * Copyright 2019 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandler; + +public class Http2MultiplexHandlerClientUpgradeTest extends Http2MultiplexClientUpgradeTest { + + @Override + protected Http2FrameCodec newCodec(ChannelHandler upgradeHandler) { + return Http2FrameCodecBuilder.forClient().build(); + } + + @Override + protected ChannelHandler newMultiplexer(ChannelHandler upgradeHandler) { + return new Http2MultiplexHandler(new NoopHandler(), upgradeHandler); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerTest.java new file mode 100644 index 0000000..1ef84cf --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexHandlerTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import javax.net.ssl.SSLException; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Unit tests for {@link Http2MultiplexHandler}. + */ +public class Http2MultiplexHandlerTest extends Http2MultiplexTest { + + @Override + protected Http2FrameCodec newCodec(TestChannelInitializer childChannelInitializer, Http2FrameWriter frameWriter) { + return new Http2FrameCodecBuilder(true).frameWriter(frameWriter).build(); + } + + @Override + protected ChannelHandler newMultiplexer(TestChannelInitializer childChannelInitializer) { + return new Http2MultiplexHandler(childChannelInitializer, null); + } + + @Override + protected boolean useUserEventForResetFrame() { + return true; + } + + @Override + protected boolean ignoreWindowUpdateFrames() { + return true; + } + + @Test + public void sslExceptionTriggersChildChannelException() { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(channel.isActive()); + final RuntimeException testExc = new RuntimeException(new SSLException("foo")); + channel.parent().pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause != testExc) { + super.exceptionCaught(ctx, cause); + } + } + }); + channel.parent().pipeline().fireExceptionCaught(testExc); + + assertTrue(channel.isActive()); + RuntimeException exc = assertThrows(RuntimeException.class, new Executable() { + @Override + public void execute() throws 
Throwable { + inboundHandler.checkException(); + } + }); + assertEquals(testExc, exc); + } + + @Test + public void customExceptionForwarding() { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(channel.isActive()); + final RuntimeException testExc = new RuntimeException("xyz"); + channel.parent().pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause != testExc) { + super.exceptionCaught(ctx, cause); + } else { + ctx.pipeline().fireExceptionCaught(new Http2MultiplexActiveStreamsException(cause)); + } + } + }); + channel.parent().pipeline().fireExceptionCaught(testExc); + + assertTrue(channel.isActive()); + RuntimeException exc = assertThrows(RuntimeException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + assertEquals(testExc, exc); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTest.java new file mode 100644 index 0000000..f277dbb --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTest.java @@ -0,0 +1,1441 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.WriteBufferWaterMark; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.socket.ChannelInputShutdownReadComplete; +import io.netty.channel.socket.ChannelOutputShutdownEvent; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http2.Http2Exception.StreamException; +import io.netty.handler.codec.http2.LastInboundHandler.Consumer; +import io.netty.handler.ssl.SslCloseCompletionEvent; +import io.netty.util.AsciiString; +import io.netty.util.AttributeKey; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentMatcher; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.net.InetSocketAddress; +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Arrays; +import java.util.Collection; +import java.util.Queue; +import java.util.UUID; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import 
java.util.concurrent.atomic.AtomicReference; + +import static io.netty.handler.codec.http2.Http2TestUtil.anyChannelPromise; +import static io.netty.handler.codec.http2.Http2TestUtil.anyHttp2Settings; +import static io.netty.handler.codec.http2.Http2TestUtil.assertEqualsAndRelease; +import static io.netty.handler.codec.http2.Http2TestUtil.bb; +import static io.netty.util.ReferenceCountUtil.release; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.ArgumentMatchers.any; +import static org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.argThat; +import static org.mockito.ArgumentMatchers.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public abstract class Http2MultiplexTest { + private final Http2Headers request = new DefaultHttp2Headers() + .method(HttpMethod.GET.asciiName()).scheme(HttpScheme.HTTPS.name()) + .authority(new AsciiString("example.org")).path(new AsciiString("/foo")); + + private EmbeddedChannel parentChannel; + private Http2FrameWriter frameWriter; + private Http2FrameInboundWriter frameInboundWriter; + private TestChannelInitializer childChannelInitializer; + private C codec; + + private static final int initialRemoteStreamWindow = 1024; + + protected abstract C newCodec(TestChannelInitializer childChannelInitializer, Http2FrameWriter frameWriter); + protected abstract ChannelHandler newMultiplexer(TestChannelInitializer 
childChannelInitializer); + + @BeforeEach + public void setUp() { + childChannelInitializer = new TestChannelInitializer(); + parentChannel = new EmbeddedChannel(); + frameInboundWriter = new Http2FrameInboundWriter(parentChannel); + parentChannel.connect(new InetSocketAddress(0)); + frameWriter = Http2TestUtil.mockedFrameWriter(); + codec = newCodec(childChannelInitializer, frameWriter); + parentChannel.pipeline().addLast(codec); + ChannelHandler multiplexer = newMultiplexer(childChannelInitializer); + if (multiplexer != null) { + parentChannel.pipeline().addLast(multiplexer); + } + + parentChannel.runPendingTasks(); + parentChannel.pipeline().fireChannelActive(); + + parentChannel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); + + Http2Settings settings = new Http2Settings().initialWindowSize(initialRemoteStreamWindow); + frameInboundWriter.writeInboundSettings(settings); + + verify(frameWriter).writeSettingsAck(eqCodecCtx(), anyChannelPromise()); + + frameInboundWriter.writeInboundSettingsAck(); + + Http2SettingsFrame settingsFrame = parentChannel.readInbound(); + assertNotNull(settingsFrame); + Http2SettingsAckFrame settingsAckFrame = parentChannel.readInbound(); + assertNotNull(settingsAckFrame); + + // Handshake + verify(frameWriter).writeSettings(eqCodecCtx(), + anyHttp2Settings(), anyChannelPromise()); + } + + private ChannelHandlerContext eqCodecCtx() { + return eq(codec.ctx); + } + + @AfterEach + public void tearDown() throws Exception { + if (childChannelInitializer.handler instanceof LastInboundHandler) { + ((LastInboundHandler) childChannelInitializer.handler).finishAndReleaseAll(); + } + parentChannel.finishAndReleaseAll(); + codec = null; + } + + // TODO(buchgr): Flush from child channel + // TODO(buchgr): ChildChannel.childReadComplete() + // TODO(buchgr): GOAWAY Logic + // TODO(buchgr): Test ChannelConfig.setMaxMessagesPerRead + + @Test + public void writeUnknownFrame() { + Http2StreamChannel childChannel = newOutboundStream(new 
ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + ctx.writeAndFlush(new DefaultHttp2UnknownFrame((byte) 99, new Http2Flags())); + ctx.fireChannelActive(); + } + }); + assertTrue(childChannel.isActive()); + + parentChannel.runPendingTasks(); + + verify(frameWriter).writeFrame(eq(codec.ctx), eq((byte) 99), eqStreamId(childChannel), any(Http2Flags.class), + any(ByteBuf.class), any(ChannelPromise.class)); + } + + Http2StreamChannel newInboundStream(int streamId, boolean endStream, final ChannelHandler childHandler) { + return newInboundStream(streamId, endStream, null, childHandler); + } + + private Http2StreamChannel newInboundStream(int streamId, boolean endStream, + AtomicInteger maxReads, final ChannelHandler childHandler) { + final AtomicReference streamChannelRef = new AtomicReference(); + childChannelInitializer.maxReads = maxReads; + childChannelInitializer.handler = new ChannelInboundHandlerAdapter() { + @Override + public void channelRegistered(ChannelHandlerContext ctx) { + assertNull(streamChannelRef.get()); + streamChannelRef.set((Http2StreamChannel) ctx.channel()); + ctx.pipeline().addLast(childHandler); + ctx.fireChannelRegistered(); + } + }; + + frameInboundWriter.writeInboundHeaders(streamId, request, 0, endStream); + parentChannel.runPendingTasks(); + Http2StreamChannel channel = streamChannelRef.get(); + assertEquals(streamId, channel.stream().id()); + return channel; + } + + @Test + public void readUnkownFrame() { + LastInboundHandler handler = new LastInboundHandler(); + + Http2StreamChannel channel = newInboundStream(3, true, handler); + frameInboundWriter.writeInboundFrame((byte) 99, channel.stream().id(), new Http2Flags(), Unpooled.EMPTY_BUFFER); + + // header frame and unknown frame + verifyFramesMultiplexedToCorrectChannel(channel, handler, 2); + + Channel childChannel = newOutboundStream(new 
ChannelInboundHandlerAdapter()); + assertTrue(childChannel.isActive()); + } + + @Test + public void headerAndDataFramesShouldBeDelivered() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(request).stream(channel.stream()); + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("hello")).stream(channel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("world")).stream(channel.stream()); + + assertTrue(inboundHandler.isChannelActive()); + frameInboundWriter.writeInboundData(channel.stream().id(), bb("hello"), 0, false); + frameInboundWriter.writeInboundData(channel.stream().id(), bb("world"), 0, false); + + assertEquals(headersFrame, inboundHandler.readInbound()); + + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + + assertNull(inboundHandler.readInbound()); + } + + @Test + public void headerMultipleContentLengthValidationShouldPropagate() { + headerMultipleContentLengthValidationShouldPropagate(false); + } + + @Test + public void headerMultipleContentLengthValidationShouldPropagateWithEndStream() { + headerMultipleContentLengthValidationShouldPropagate(true); + } + + private void headerMultipleContentLengthValidationShouldPropagate(boolean endStream) { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + request.addLong(HttpHeaderNames.CONTENT_LENGTH, 0); + request.addLong(HttpHeaderNames.CONTENT_LENGTH, 1); + Http2StreamChannel channel = newInboundStream(3, endStream, inboundHandler); + + assertThrows(StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + assertNull(inboundHandler.readInbound()); + assertFalse(channel.isActive()); + } + + @Test + public void 
headerPlusSignContentLengthValidationShouldPropagate() { + headerSignContentLengthValidationShouldPropagateWithEndStream(false, false); + } + + @Test + public void headerPlusSignContentLengthValidationShouldPropagateWithEndStream() { + headerSignContentLengthValidationShouldPropagateWithEndStream(false, true); + } + + @Test + public void headerMinusSignContentLengthValidationShouldPropagate() { + headerSignContentLengthValidationShouldPropagateWithEndStream(true, false); + } + + @Test + public void headerMinusSignContentLengthValidationShouldPropagateWithEndStream() { + headerSignContentLengthValidationShouldPropagateWithEndStream(true, true); + } + + private void headerSignContentLengthValidationShouldPropagateWithEndStream(boolean minus, boolean endStream) { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + request.add(HttpHeaderNames.CONTENT_LENGTH, (minus ? "-" : "+") + 1); + Http2StreamChannel channel = newInboundStream(3, endStream, inboundHandler); + assertThrows(StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + + assertNull(inboundHandler.readInbound()); + assertFalse(channel.isActive()); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagate() { + headerContentLengthNotMatchValidationShouldPropagate(false, false, false); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagateWithEndStream() { + headerContentLengthNotMatchValidationShouldPropagate(false, true, false); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagateCloseLocal() { + headerContentLengthNotMatchValidationShouldPropagate(true, false, false); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagateWithEndStreamCloseLocal() { + headerContentLengthNotMatchValidationShouldPropagate(true, true, false); + } + + @Test + public void 
headerContentLengthNotMatchValidationShouldPropagateTrailers() { + headerContentLengthNotMatchValidationShouldPropagate(false, false, true); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagateWithEndStreamTrailers() { + headerContentLengthNotMatchValidationShouldPropagate(false, true, true); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagateCloseLocalTrailers() { + headerContentLengthNotMatchValidationShouldPropagate(true, false, true); + } + + @Test + public void headerContentLengthNotMatchValidationShouldPropagateWithEndStreamCloseLocalTrailers() { + headerContentLengthNotMatchValidationShouldPropagate(true, true, true); + } + + private void headerContentLengthNotMatchValidationShouldPropagate( + boolean closeLocal, boolean endStream, boolean trailer) { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + request.addLong(HttpHeaderNames.CONTENT_LENGTH, 1); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(channel.isActive()); + + if (closeLocal) { + channel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers(), true)) + .syncUninterruptibly(); + assertEquals(Http2Stream.State.HALF_CLOSED_LOCAL, channel.stream().state()); + } else { + assertEquals(Http2Stream.State.OPEN, channel.stream().state()); + } + + if (trailer) { + frameInboundWriter.writeInboundHeaders(channel.stream().id(), new DefaultHttp2Headers(), 0, endStream); + } else { + frameInboundWriter.writeInboundData(channel.stream().id(), bb("foo"), 0, endStream); + } + + assertThrows(StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + + Http2HeadersFrame headersFrame = new DefaultHttp2HeadersFrame(request).stream(channel.stream()); + assertEquals(headersFrame, inboundHandler.readInbound()); + assertNull(inboundHandler.readInbound()); + assertFalse(channel.isActive()); + } + + @Test + 
public void streamExceptionCauseRstStreamWithProtocolError() { + request.addLong(HttpHeaderNames.CONTENT_LENGTH, 10); + Http2StreamChannel channel = newInboundStream(3, false, new ChannelInboundHandlerAdapter()); + channel.pipeline().fireExceptionCaught(new Http2FrameStreamException(channel.stream(), + Http2Error.PROTOCOL_ERROR, new IllegalArgumentException())); + assertFalse(channel.isActive()); + verify(frameWriter).writeRstStream(eqCodecCtx(), eq(3), + eq(Http2Error.PROTOCOL_ERROR.code()), anyChannelPromise()); + } + + @Test + public void contentLengthNotMatchRstStreamWithProtocolError() { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + request.addLong(HttpHeaderNames.CONTENT_LENGTH, 10); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + frameInboundWriter.writeInboundData(3, bb(8), 0, true); + assertThrows(StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + assertNotNull(inboundHandler.readInbound()); + assertFalse(channel.isActive()); + verify(frameWriter).writeRstStream(eqCodecCtx(), eq(3), + eq(Http2Error.PROTOCOL_ERROR.code()), anyChannelPromise()); + } + + @Test + public void framesShouldBeMultiplexed() { + LastInboundHandler handler1 = new LastInboundHandler(); + Http2StreamChannel channel1 = newInboundStream(3, false, handler1); + LastInboundHandler handler2 = new LastInboundHandler(); + Http2StreamChannel channel2 = newInboundStream(5, false, handler2); + LastInboundHandler handler3 = new LastInboundHandler(); + Http2StreamChannel channel3 = newInboundStream(11, false, handler3); + + verifyFramesMultiplexedToCorrectChannel(channel1, handler1, 1); + verifyFramesMultiplexedToCorrectChannel(channel2, handler2, 1); + verifyFramesMultiplexedToCorrectChannel(channel3, handler3, 1); + + frameInboundWriter.writeInboundData(channel2.stream().id(), bb("hello"), 0, false); + 
frameInboundWriter.writeInboundData(channel1.stream().id(), bb("foo"), 0, true); + frameInboundWriter.writeInboundData(channel2.stream().id(), bb("world"), 0, true); + frameInboundWriter.writeInboundData(channel3.stream().id(), bb("bar"), 0, true); + + verifyFramesMultiplexedToCorrectChannel(channel1, handler1, 1); + verifyFramesMultiplexedToCorrectChannel(channel2, handler2, 2); + verifyFramesMultiplexedToCorrectChannel(channel3, handler3, 1); + } + + @Test + public void inboundDataFrameShouldUpdateLocalFlowController() throws Http2Exception { + Http2LocalFlowController flowController = Mockito.mock(Http2LocalFlowController.class); + codec.connection().local().flowController(flowController); + + LastInboundHandler handler = new LastInboundHandler(); + final Http2StreamChannel channel = newInboundStream(3, false, handler); + + ByteBuf tenBytes = bb("0123456789"); + + frameInboundWriter.writeInboundData(channel.stream().id(), tenBytes, 0, true); + + // Verify we marked the bytes as consumed + verify(flowController).consumeBytes(argThat(new ArgumentMatcher() { + @Override + public boolean matches(Http2Stream http2Stream) { + return http2Stream.id() == channel.stream().id(); + } + }), eq(10)); + + // headers and data frame + verifyFramesMultiplexedToCorrectChannel(channel, handler, 2); + } + + @Test + public void unhandledHttp2FramesShouldBePropagated() { + Http2PingFrame pingFrame = new DefaultHttp2PingFrame(0); + frameInboundWriter.writeInboundPing(false, 0); + assertEquals(parentChannel.readInbound(), pingFrame); + + DefaultHttp2GoAwayFrame goAwayFrame = new DefaultHttp2GoAwayFrame(1, + parentChannel.alloc().buffer().writeLong(8)); + frameInboundWriter.writeInboundGoAway(0, goAwayFrame.errorCode(), goAwayFrame.content().retainedDuplicate()); + + Http2GoAwayFrame frame = parentChannel.readInbound(); + assertEqualsAndRelease(frame, goAwayFrame); + } + + @Test + public void channelReadShouldRespectAutoRead() { + LastInboundHandler inboundHandler = new 
LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + assertTrue(childChannel.config().isAutoRead()); + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + childChannel.config().setAutoRead(false); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + Http2DataFrame dataFrame0 = inboundHandler.readInbound(); + assertNotNull(dataFrame0); + release(dataFrame0); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); + + assertNull(inboundHandler.readInbound()); + + childChannel.config().setAutoRead(true); + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 2); + } + + @Test + public void channelReadShouldRespectAutoReadAndNotProduceNPE() throws Exception { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + assertTrue(childChannel.config().isAutoRead()); + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + childChannel.config().setAutoRead(false); + childChannel.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + private int count; + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.fireChannelRead(msg); + // Close channel after 2 reads so there is still something in the inboundBuffer when the close happens. 
+ if (++count == 2) { + ctx.close(); + } + } + }); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + Http2DataFrame dataFrame0 = inboundHandler.readInbound(); + assertNotNull(dataFrame0); + release(dataFrame0); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); + + assertNull(inboundHandler.readInbound()); + + childChannel.config().setAutoRead(true); + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 3); + inboundHandler.checkException(); + } + + @Test + public void readInChannelReadWithoutAutoRead() { + useReadWithoutAutoRead(false); + } + + @Test + public void readInChannelReadCompleteWithoutAutoRead() { + useReadWithoutAutoRead(true); + } + + private void useReadWithoutAutoRead(final boolean readComplete) { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + assertTrue(childChannel.config().isAutoRead()); + childChannel.config().setAutoRead(false); + assertFalse(childChannel.config().isAutoRead()); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + // Add a handler which will request reads. 
+ childChannel.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + if (!readComplete) { + ctx.read(); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.fireChannelReadComplete(); + if (readComplete) { + ctx.read(); + } + } + }); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, true); + + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 6); + } + + private Http2StreamChannel newOutboundStream(ChannelHandler handler) { + return new Http2StreamChannelBootstrap(parentChannel).handler(handler) + .open().syncUninterruptibly().getNow(); + } + + /** + * A child channel for an HTTP/2 stream in IDLE state (that is no headers sent or received), + * should not emit a RST_STREAM frame on close, as this is a connection error of type protocol error. 
+ */ + @Test + public void idleOutboundStreamShouldNotWriteResetFrameOnClose() { + LastInboundHandler handler = new LastInboundHandler(); + + Channel childChannel = newOutboundStream(handler); + assertTrue(childChannel.isActive()); + + childChannel.close(); + parentChannel.runPendingTasks(); + + assertFalse(childChannel.isOpen()); + assertFalse(childChannel.isActive()); + assertNull(parentChannel.readOutbound()); + } + + @Test + public void outboundStreamShouldWriteResetFrameOnClose_headersSent() { + ChannelHandler handler = new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + ctx.fireChannelActive(); + } + }; + + Http2StreamChannel childChannel = newOutboundStream(handler); + assertTrue(childChannel.isActive()); + + childChannel.close(); + verify(frameWriter).writeRstStream(eqCodecCtx(), + eqStreamId(childChannel), eq(Http2Error.CANCEL.code()), anyChannelPromise()); + } + + @Test + public void outboundStreamShouldNotWriteResetFrameOnClose_IfStreamDidntExist() { + when(frameWriter.writeHeaders(eqCodecCtx(), anyInt(), + any(Http2Headers.class), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + + private boolean headersWritten; + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + // We want to fail to write the first headers frame. This is what happens if the connection + // refuses to allocate a new stream due to having received a GOAWAY. 
+ if (!headersWritten) { + headersWritten = true; + return ((ChannelPromise) invocationOnMock.getArgument(5)).setFailure(new Exception("boom")); + } + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + ctx.fireChannelActive(); + } + }); + + assertFalse(childChannel.isActive()); + + childChannel.close(); + parentChannel.runPendingTasks(); + // The channel was never active so we should not generate a RST frame. + verify(frameWriter, never()).writeRstStream(eqCodecCtx(), eqStreamId(childChannel), anyLong(), + anyChannelPromise()); + + assertTrue(parentChannel.outboundMessages().isEmpty()); + } + + @Test + public void inboundRstStreamFireChannelInactive() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(inboundHandler.isChannelActive()); + frameInboundWriter.writeInboundRstStream(channel.stream().id(), Http2Error.INTERNAL_ERROR.code()); + + assertFalse(inboundHandler.isChannelActive()); + + // A RST_STREAM frame should NOT be emitted, as we received a RST_STREAM. 
+ verify(frameWriter, never()).writeRstStream(eqCodecCtx(), eqStreamId(channel), + anyLong(), anyChannelPromise()); + } + + @Test + public void streamExceptionTriggersChildChannelExceptionAndClose() throws Exception { + final LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(channel.isActive()); + StreamException cause = new StreamException(channel.stream().id(), Http2Error.PROTOCOL_ERROR, "baaam!"); + parentChannel.pipeline().fireExceptionCaught(cause); + + assertFalse(channel.isActive()); + + assertThrows(StreamException.class, new Executable() { + @Override + public void execute() throws Throwable { + inboundHandler.checkException(); + } + }); + } + + @Test + public void streamClosedErrorTranslatedToClosedChannelExceptionOnWrites() throws Exception { + LastInboundHandler inboundHandler = new LastInboundHandler(); + + final Http2StreamChannel childChannel = newOutboundStream(inboundHandler); + assertTrue(childChannel.isActive()); + + Http2Headers headers = new DefaultHttp2Headers(); + when(frameWriter.writeHeaders(eqCodecCtx(), anyInt(), + eq(headers), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(5)).setFailure( + new StreamException(childChannel.stream().id(), Http2Error.STREAM_CLOSED, "Stream Closed")); + } + }); + final ChannelFuture future = childChannel.writeAndFlush( + new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + + parentChannel.flush(); + + assertFalse(childChannel.isActive()); + assertFalse(childChannel.isOpen()); + + inboundHandler.checkException(); + + assertThrows(ClosedChannelException.class, new Executable() { + @Override + public void execute() { + future.syncUninterruptibly(); + } + }); + } + + @Test + public void 
creatingWritingReadingAndClosingOutboundStreamShouldWork() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newOutboundStream(inboundHandler); + assertTrue(childChannel.isActive()); + assertTrue(inboundHandler.isChannelActive()); + + // Write to the child channel + Http2Headers headers = new DefaultHttp2Headers().scheme("https").method("GET").path("/foo.txt"); + childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); + + // Read from the child channel + frameInboundWriter.writeInboundHeaders(childChannel.stream().id(), headers, 0, false); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + assertEquals(headers, headersFrame.headers()); + + // Close the child channel. + childChannel.close(); + + parentChannel.runPendingTasks(); + // An active outbound stream should emit a RST_STREAM frame. + verify(frameWriter).writeRstStream(eqCodecCtx(), eqStreamId(childChannel), + anyLong(), anyChannelPromise()); + + assertFalse(childChannel.isOpen()); + assertFalse(childChannel.isActive()); + assertFalse(inboundHandler.isChannelActive()); + } + + // Test failing the promise of the first headers frame of an outbound stream. In practice this error case would most + // likely happen due to the max concurrent streams limit being hit or the channel running out of stream identifiers. 
+ // + @Test + public void failedOutboundStreamCreationThrowsAndClosesChannel() throws Exception { + LastInboundHandler handler = new LastInboundHandler(); + Http2StreamChannel childChannel = newOutboundStream(handler); + assertTrue(childChannel.isActive()); + + Http2Headers headers = new DefaultHttp2Headers(); + when(frameWriter.writeHeaders(eqCodecCtx(), anyInt(), + eq(headers), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(5)).setFailure( + new Http2NoMoreStreamIdsException()); + } + }); + + final ChannelFuture future = childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); + parentChannel.flush(); + + assertFalse(childChannel.isActive()); + assertFalse(childChannel.isOpen()); + + handler.checkException(); + + assertThrows(Http2NoMoreStreamIdsException.class, new Executable() { + @Override + public void execute() { + future.syncUninterruptibly(); + } + }); + } + + @Test + public void channelClosedWhenCloseListenerCompletes() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + + assertTrue(childChannel.isOpen()); + assertTrue(childChannel.isActive()); + + final AtomicBoolean channelOpen = new AtomicBoolean(true); + final AtomicBoolean channelActive = new AtomicBoolean(true); + + // Create a promise before actually doing the close, because otherwise we would be adding a listener to a future + // that is already completed because we are using EmbeddedChannel which executes code in the JUnit thread. 
+ ChannelPromise p = childChannel.newPromise(); + p.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + channelOpen.set(future.channel().isOpen()); + channelActive.set(future.channel().isActive()); + } + }); + childChannel.close(p).syncUninterruptibly(); + + assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); + assertFalse(childChannel.isActive()); + } + + @Test + public void channelClosedWhenChannelClosePromiseCompletes() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + + assertTrue(childChannel.isOpen()); + assertTrue(childChannel.isActive()); + + final AtomicBoolean channelOpen = new AtomicBoolean(true); + final AtomicBoolean channelActive = new AtomicBoolean(true); + + childChannel.closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + channelOpen.set(future.channel().isOpen()); + channelActive.set(future.channel().isActive()); + } + }); + childChannel.close().syncUninterruptibly(); + + assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); + assertFalse(childChannel.isActive()); + } + + @Test + public void channelClosedWhenWriteFutureFails() { + final Queue writePromises = new ArrayDeque(); + + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + + assertTrue(childChannel.isOpen()); + assertTrue(childChannel.isActive()); + + final AtomicBoolean channelOpen = new AtomicBoolean(true); + final AtomicBoolean channelActive = new AtomicBoolean(true); + + Http2Headers headers = new DefaultHttp2Headers(); + when(frameWriter.writeHeaders(eqCodecCtx(), anyInt(), + eq(headers), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock 
invocationOnMock) { + ChannelPromise promise = invocationOnMock.getArgument(5); + writePromises.offer(promise); + return promise; + } + }); + + ChannelFuture f = childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(headers)); + assertFalse(f.isDone()); + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + channelOpen.set(future.channel().isOpen()); + channelActive.set(future.channel().isActive()); + } + }); + + ChannelPromise first = writePromises.poll(); + first.setFailure(new ClosedChannelException()); + f.awaitUninterruptibly(); + + assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); + assertFalse(childChannel.isActive()); + } + + @Test + public void channelClosedTwiceMarksPromiseAsSuccessful() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + + assertTrue(childChannel.isOpen()); + assertTrue(childChannel.isActive()); + childChannel.close().syncUninterruptibly(); + childChannel.close().syncUninterruptibly(); + + assertFalse(childChannel.isOpen()); + assertFalse(childChannel.isActive()); + } + + @Test + public void settingChannelOptsAndAttrs() { + AttributeKey key = AttributeKey.newInstance(UUID.randomUUID().toString()); + + Channel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter()); + childChannel.config().setAutoRead(false).setWriteSpinCount(1000); + childChannel.attr(key).set("bar"); + assertFalse(childChannel.config().isAutoRead()); + assertEquals(1000, childChannel.config().getWriteSpinCount()); + assertEquals("bar", childChannel.attr(key).get()); + } + + @Test + public void outboundFlowControlWritability() { + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter()); + assertTrue(childChannel.isActive()); + + assertTrue(childChannel.isWritable()); + childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new 
DefaultHttp2Headers())); + parentChannel.flush(); + + // Test for initial window size + assertTrue(initialRemoteStreamWindow < childChannel.config().getWriteBufferHighWaterMark()); + + assertTrue(childChannel.isWritable()); + childChannel.write(new DefaultHttp2DataFrame(Unpooled.buffer().writeZero(16 * 1024 * 1024))); + assertEquals(0, childChannel.bytesBeforeUnwritable()); + assertFalse(childChannel.isWritable()); + } + + @Test + public void writabilityOfParentIsRespected() { + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter()); + childChannel.config().setWriteBufferWaterMark(new WriteBufferWaterMark(2048, 4096)); + parentChannel.config().setWriteBufferWaterMark(new WriteBufferWaterMark(256, 512)); + assertTrue(childChannel.isWritable()); + assertTrue(parentChannel.isActive()); + + childChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers())); + parentChannel.flush(); + + assertTrue(childChannel.isWritable()); + childChannel.write(new DefaultHttp2DataFrame(Unpooled.buffer().writeZero(256))); + assertTrue(childChannel.isWritable()); + childChannel.writeAndFlush(new DefaultHttp2DataFrame(Unpooled.buffer().writeZero(512))); + + long bytesBeforeUnwritable = childChannel.bytesBeforeUnwritable(); + assertNotEquals(0, bytesBeforeUnwritable); + // Add something to the ChannelOutboundBuffer of the parent to simulate queuing in the parents channel buffer + // and verify that this only affect the writability of the parent channel while the child stays writable + // until it used all of its credits. + parentChannel.unsafe().outboundBuffer().addMessage( + Unpooled.buffer().writeZero(800), 800, parentChannel.voidPromise()); + assertFalse(parentChannel.isWritable()); + + assertTrue(childChannel.isWritable()); + assertEquals(4097, childChannel.bytesBeforeUnwritable()); + + // Flush everything which simulate writing everything to the socket. 
+ parentChannel.flush(); + assertTrue(parentChannel.isWritable()); + assertTrue(childChannel.isWritable()); + assertEquals(bytesBeforeUnwritable, childChannel.bytesBeforeUnwritable()); + + ChannelFuture future = childChannel.writeAndFlush(new DefaultHttp2DataFrame( + Unpooled.buffer().writeZero((int) bytesBeforeUnwritable))); + assertFalse(childChannel.isWritable()); + assertTrue(parentChannel.isWritable()); + + parentChannel.flush(); + assertFalse(future.isDone()); + assertTrue(parentChannel.isWritable()); + assertFalse(childChannel.isWritable()); + + // Now write an window update frame for the stream which then should ensure we will flush the bytes that were + // queued in the RemoteFlowController before for the stream. + frameInboundWriter.writeInboundWindowUpdate(childChannel.stream().id(), (int) bytesBeforeUnwritable); + assertTrue(childChannel.isWritable()); + assertTrue(future.isDone()); + } + + @Test + public void channelClosedWhenInactiveFired() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + + final AtomicBoolean channelOpen = new AtomicBoolean(false); + final AtomicBoolean channelActive = new AtomicBoolean(false); + assertTrue(childChannel.isOpen()); + assertTrue(childChannel.isActive()); + + childChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + channelOpen.set(ctx.channel().isOpen()); + channelActive.set(ctx.channel().isActive()); + + super.channelInactive(ctx); + } + }); + + childChannel.close().syncUninterruptibly(); + assertFalse(channelOpen.get()); + assertFalse(channelActive.get()); + } + + @Test + public void channelInactiveHappensAfterExceptionCaughtEvents() throws Exception { + final AtomicInteger count = new AtomicInteger(0); + final AtomicInteger exceptionCaught = new AtomicInteger(-1); + final AtomicInteger channelInactive = new 
AtomicInteger(-1); + final AtomicInteger channelUnregistered = new AtomicInteger(-1); + Http2StreamChannel childChannel = newOutboundStream(new ChannelInboundHandlerAdapter() { + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + ctx.close(); + throw new Exception("exception"); + } + }); + + childChannel.pipeline().addLast(new ChannelInboundHandlerAdapter() { + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + channelInactive.set(count.getAndIncrement()); + super.channelInactive(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + exceptionCaught.set(count.getAndIncrement()); + super.exceptionCaught(ctx, cause); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + channelUnregistered.set(count.getAndIncrement()); + super.channelUnregistered(ctx); + } + }); + + childChannel.pipeline().fireUserEventTriggered(new Object()); + parentChannel.runPendingTasks(); + + // The events should have happened in this order because the inactive and deregistration events + // get deferred as they do in the AbstractChannel. 
+ assertEquals(0, exceptionCaught.get()); + assertEquals(1, channelInactive.get()); + assertEquals(2, channelUnregistered.get()); + } + + @Test + public void callUnsafeCloseMultipleTimes() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + childChannel.unsafe().close(childChannel.voidPromise()); + + ChannelPromise promise = childChannel.newPromise(); + childChannel.unsafe().close(promise); + promise.syncUninterruptibly(); + childChannel.closeFuture().syncUninterruptibly(); + } + + @Test + public void endOfStreamDoesNotDiscardData() { + AtomicInteger numReads = new AtomicInteger(1); + final AtomicBoolean shouldDisableAutoRead = new AtomicBoolean(); + Consumer ctxConsumer = new Consumer() { + @Override + public void accept(ChannelHandlerContext obj) { + if (shouldDisableAutoRead.get()) { + obj.channel().config().setAutoRead(false); + } + } + }; + LastInboundHandler inboundHandler = new LastInboundHandler(ctxConsumer); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + childChannel.config().setAutoRead(false); + + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("1")).stream(childChannel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("2")).stream(childChannel.stream()); + Http2DataFrame dataFrame3 = new DefaultHttp2DataFrame(bb("3")).stream(childChannel.stream()); + Http2DataFrame dataFrame4 = new DefaultHttp2DataFrame(bb("4")).stream(childChannel.stream()); + + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); + + ChannelHandler readCompleteSupressHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + // We want to simulate the parent channel calling channelRead and delay calling channelReadComplete. 
+ } + }; + + parentChannel.pipeline().addFirst(readCompleteSupressHandler); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("1"), 0, false); + + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); + + // Deliver frames, and then a stream closed while read is inactive. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("2"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("3"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("4"), 0, false); + + shouldDisableAutoRead.set(true); + childChannel.config().setAutoRead(true); + numReads.set(1); + + frameInboundWriter.writeInboundRstStream(childChannel.stream().id(), Http2Error.NO_ERROR.code()); + + // Detecting EOS should flush all pending data regardless of read calls. + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + assertNull(inboundHandler.readInbound()); + + // As we limited the number to 1 we also need to call read() again. + childChannel.read(); + + assertEqualsAndRelease(dataFrame3, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame4, inboundHandler.readInbound()); + + Http2ResetFrame resetFrame = useUserEventForResetFrame() ? inboundHandler.readUserEvent() : + inboundHandler.readInbound(); + + assertEquals(childChannel.stream(), resetFrame.stream()); + assertEquals(Http2Error.NO_ERROR.code(), resetFrame.errorCode()); + + assertNull(inboundHandler.readInbound()); + + // Now we want to call channelReadComplete and simulate the end of the read loop. 
+ parentChannel.pipeline().remove(readCompleteSupressHandler); + parentChannel.flushInbound(); + + childChannel.closeFuture().syncUninterruptibly(); + } + + protected abstract boolean useUserEventForResetFrame(); + + protected abstract boolean ignoreWindowUpdateFrames(); + + @Test + public void windowUpdateFrames() { + AtomicInteger numReads = new AtomicInteger(1); + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); + + frameInboundWriter.writeInboundWindowUpdate(childChannel.stream().id(), 4); + + Http2WindowUpdateFrame updateFrame = inboundHandler.readInbound(); + if (ignoreWindowUpdateFrames()) { + assertNull(updateFrame); + } else { + assertEquals(new DefaultHttp2WindowUpdateFrame(4).stream(childChannel.stream()), updateFrame); + } + + frameInboundWriter.writeInboundWindowUpdate(Http2CodecUtil.CONNECTION_STREAM_ID, 6); + + assertNull(parentChannel.readInbound()); + childChannel.close().syncUninterruptibly(); + } + + @Test + public void childQueueIsDrainedAndNewDataIsDispatchedInParentReadLoopAutoRead() { + AtomicInteger numReads = new AtomicInteger(1); + final AtomicInteger channelReadCompleteCount = new AtomicInteger(0); + final AtomicBoolean shouldDisableAutoRead = new AtomicBoolean(); + Consumer ctxConsumer = new Consumer() { + @Override + public void accept(ChannelHandlerContext obj) { + channelReadCompleteCount.incrementAndGet(); + if (shouldDisableAutoRead.get()) { + obj.channel().config().setAutoRead(false); + } + } + }; + LastInboundHandler inboundHandler = new LastInboundHandler(ctxConsumer); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + childChannel.config().setAutoRead(false); + + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("1")).stream(childChannel.stream()); + Http2DataFrame 
dataFrame2 = new DefaultHttp2DataFrame(bb("2")).stream(childChannel.stream()); + Http2DataFrame dataFrame3 = new DefaultHttp2DataFrame(bb("3")).stream(childChannel.stream()); + Http2DataFrame dataFrame4 = new DefaultHttp2DataFrame(bb("4")).stream(childChannel.stream()); + + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); + + ChannelHandler readCompleteSupressHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + // We want to simulate the parent channel calling channelRead and delay calling channelReadComplete. + } + }; + parentChannel.pipeline().addFirst(readCompleteSupressHandler); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("1"), 0, false); + + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); + + // We want one item to be in the queue, and allow the numReads to be larger than 1. This will ensure that + // when beginRead() is called the child channel is added to the readPending queue of the parent channel. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("2"), 0, false); + + numReads.set(10); + shouldDisableAutoRead.set(true); + childChannel.config().setAutoRead(true); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("3"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("4"), 0, false); + + // Detecting EOS should flush all pending data regardless of read calls. + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame3, inboundHandler.readInbound()); + assertEqualsAndRelease(dataFrame4, inboundHandler.readInbound()); + + assertNull(inboundHandler.readInbound()); + + // Now we want to call channelReadComplete and simulate the end of the read loop. 
+ parentChannel.pipeline().remove(readCompleteSupressHandler); + parentChannel.flushInbound(); + + // 3 = 1 for initialization + 1 for read when auto read was off + 1 for when auto read was back on + assertEquals(3, channelReadCompleteCount.get()); + } + + @Test + public void childQueueIsDrainedAndNewDataIsDispatchedInParentReadLoopNoAutoRead() { + final AtomicInteger numReads = new AtomicInteger(1); + final AtomicInteger channelReadCompleteCount = new AtomicInteger(0); + final AtomicBoolean shouldDisableAutoRead = new AtomicBoolean(); + Consumer ctxConsumer = new Consumer() { + @Override + public void accept(ChannelHandlerContext obj) { + channelReadCompleteCount.incrementAndGet(); + if (shouldDisableAutoRead.get()) { + obj.channel().config().setAutoRead(false); + } + } + }; + final LastInboundHandler inboundHandler = new LastInboundHandler(ctxConsumer); + Http2StreamChannel childChannel = newInboundStream(3, false, numReads, inboundHandler); + childChannel.config().setAutoRead(false); + + Http2DataFrame dataFrame1 = new DefaultHttp2DataFrame(bb("1")).stream(childChannel.stream()); + Http2DataFrame dataFrame2 = new DefaultHttp2DataFrame(bb("2")).stream(childChannel.stream()); + Http2DataFrame dataFrame3 = new DefaultHttp2DataFrame(bb("3")).stream(childChannel.stream()); + Http2DataFrame dataFrame4 = new DefaultHttp2DataFrame(bb("4")).stream(childChannel.stream()); + + assertEquals(new DefaultHttp2HeadersFrame(request).stream(childChannel.stream()), inboundHandler.readInbound()); + + ChannelHandler readCompleteSupressHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + // We want to simulate the parent channel calling channelRead and delay calling channelReadComplete. 
+ } + }; + parentChannel.pipeline().addFirst(readCompleteSupressHandler); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("1"), 0, false); + + assertEqualsAndRelease(dataFrame1, inboundHandler.readInbound()); + + // We want one item to be in the queue, and allow the numReads to be larger than 1. This will ensure that + // when beginRead() is called the child channel is added to the readPending queue of the parent channel. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("2"), 0, false); + + numReads.set(2); + childChannel.read(); + + assertEqualsAndRelease(dataFrame2, inboundHandler.readInbound()); + + assertNull(inboundHandler.readInbound()); + + // This is the second item that was read, this should be the last until we call read() again. This should also + // notify of readComplete(). + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("3"), 0, false); + + assertEqualsAndRelease(dataFrame3, inboundHandler.readInbound()); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("4"), 0, false); + assertNull(inboundHandler.readInbound()); + + childChannel.read(); + + assertEqualsAndRelease(dataFrame4, inboundHandler.readInbound()); + + assertNull(inboundHandler.readInbound()); + + // Now we want to call channelReadComplete and simulate the end of the read loop. 
+ parentChannel.pipeline().remove(readCompleteSupressHandler); + parentChannel.flushInbound(); + + // 3 = 1 for initialization + 1 for first read of 2 items + 1 for second read of 2 items + + // 1 for parent channel readComplete + assertEquals(4, channelReadCompleteCount.get()); + } + + @Test + public void useReadWithoutAutoReadInRead() { + useReadWithoutAutoReadBuffered(false); + } + + @Test + public void useReadWithoutAutoReadInReadComplete() { + useReadWithoutAutoReadBuffered(true); + } + + private void useReadWithoutAutoReadBuffered(final boolean triggerOnReadComplete) { + LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + assertTrue(childChannel.config().isAutoRead()); + childChannel.config().setAutoRead(false); + assertFalse(childChannel.config().isAutoRead()); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + // Write some bytes to get the channel into the idle state with buffered data and also verify we + // do not dispatch it until we receive a read() call. + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar"), 0, false); + + // Add a handler which will request reads. 
+ childChannel.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + super.channelReadComplete(ctx); + if (triggerOnReadComplete) { + ctx.read(); + ctx.read(); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + if (!triggerOnReadComplete) { + ctx.read(); + ctx.read(); + } + } + }); + + inboundHandler.channel().read(); + + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 3); + + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("hello world2"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("foo2"), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb("bar2"), 0, true); + + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 3); + } + + private static final class FlushSniffer extends ChannelOutboundHandlerAdapter { + + private boolean didFlush; + + public boolean checkFlush() { + boolean r = didFlush; + didFlush = false; + return r; + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + didFlush = true; + super.flush(ctx); + } + } + + @Test + public void windowUpdatesAreFlushed() { + LastInboundHandler inboundHandler = new LastInboundHandler(); + FlushSniffer flushSniffer = new FlushSniffer(); + parentChannel.pipeline().addFirst(flushSniffer); + + Http2StreamChannel childChannel = newInboundStream(3, false, inboundHandler); + assertTrue(childChannel.config().isAutoRead()); + childChannel.config().setAutoRead(false); + assertFalse(childChannel.config().isAutoRead()); + + Http2HeadersFrame headersFrame = inboundHandler.readInbound(); + assertNotNull(headersFrame); + + assertTrue(flushSniffer.checkFlush()); + + // Write some bytes to get the channel into the idle state with buffered data and also verify we + // do not dispatch it until we receive a read() call. 
+ frameInboundWriter.writeInboundData(childChannel.stream().id(), bb(16 * 1024), 0, false); + frameInboundWriter.writeInboundData(childChannel.stream().id(), bb(16 * 1024), 0, false); + assertTrue(flushSniffer.checkFlush()); + + verify(frameWriter, never()).writeWindowUpdate(eqCodecCtx(), anyInt(), anyInt(), anyChannelPromise()); + // only the first one was read because it was legacy auto-read behavior. + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 1); + assertFalse(flushSniffer.checkFlush()); + + // Trigger a read of the second frame. + childChannel.read(); + verifyFramesMultiplexedToCorrectChannel(childChannel, inboundHandler, 1); + // We expect a flush here because the StreamChannel will flush the smaller increment but the + // connection will collect the bytes and decide not to send a wire level frame until more are consumed. + assertTrue(flushSniffer.checkFlush()); + verify(frameWriter, never()).writeWindowUpdate(eqCodecCtx(), anyInt(), anyInt(), anyChannelPromise()); + + // Call read one more time which should trigger the writing of the flow control update. 
+ childChannel.read(); + verify(frameWriter).writeWindowUpdate(eqCodecCtx(), eq(0), eq(32 * 1024), anyChannelPromise()); + verify(frameWriter).writeWindowUpdate( + eqCodecCtx(), eq(childChannel.stream().id()), eq(32 * 1024), anyChannelPromise()); + assertTrue(flushSniffer.checkFlush()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @MethodSource("userEvents") + public void userEventsThatPropagatedToChildChannels(Object userEvent) { + final LastInboundHandler inboundParentHandler = new LastInboundHandler(); + final LastInboundHandler inboundHandler = new LastInboundHandler(); + Http2StreamChannel channel = newInboundStream(3, false, inboundHandler); + assertTrue(channel.isActive()); + parentChannel.pipeline().addLast(inboundParentHandler); + parentChannel.pipeline().fireUserEventTriggered(userEvent); + assertEquals(userEvent, inboundHandler.readUserEvent()); + assertEquals(userEvent, inboundParentHandler.readUserEvent()); + assertNull(inboundHandler.readUserEvent()); + assertNull(inboundParentHandler.readUserEvent()); + } + + private static Collection userEvents() { + return Arrays.asList(ChannelInputShutdownReadComplete.INSTANCE, + ChannelOutputShutdownEvent.INSTANCE, SslCloseCompletionEvent.SUCCESS); + } + + private static void verifyFramesMultiplexedToCorrectChannel(Http2StreamChannel streamChannel, + LastInboundHandler inboundHandler, + int numFrames) { + for (int i = 0; i < numFrames; i++) { + Http2StreamFrame frame = inboundHandler.readInbound(); + assertNotNull(frame, i + " out of " + numFrames + " received"); + assertEquals(streamChannel.stream(), frame.stream()); + release(frame); + } + assertNull(inboundHandler.readInbound()); + } + + private static int eqStreamId(Http2StreamChannel channel) { + return eq(channel.stream().id()); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java 
b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java new file mode 100644 index 0000000..fd201c6 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2MultiplexTransportTest.java @@ -0,0 +1,746 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.ApplicationProtocolConfig; +import io.netty.handler.ssl.ApplicationProtocolNames; +import io.netty.handler.ssl.ApplicationProtocolNegotiationHandler; +import io.netty.handler.ssl.ClientAuth; +import io.netty.handler.ssl.OpenSsl; +import io.netty.handler.ssl.SslContext; +import 
io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.SupportedCipherSuiteFilter; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.condition.DisabledOnOs; +import org.junit.jupiter.api.condition.OS; + +import javax.net.ssl.SSLException; +import javax.net.ssl.X509TrustManager; +import java.net.InetSocketAddress; +import java.security.cert.CertificateException; +import java.security.cert.CertificateExpiredException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.concurrent.TimeUnit.MILLISECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class Http2MultiplexTransportTest { + private static final ChannelHandler DISCARD_HANDLER = new ChannelInboundHandlerAdapter() { + + @Override + public boolean isSharable() { + return true; + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object 
msg) { + ReferenceCountUtil.release(msg); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + ReferenceCountUtil.release(evt); + } + }; + + private EventLoopGroup eventLoopGroup; + private Channel clientChannel; + private Channel serverChannel; + private Channel serverConnectedChannel; + + private static final class MultiplexInboundStream extends ChannelInboundHandlerAdapter { + ChannelFuture responseFuture; + final AtomicInteger handlerInactivatedFlushed; + final AtomicInteger handleInactivatedNotFlushed; + final CountDownLatch latchHandlerInactive; + static final String LARGE_STRING = generateLargeString(10240); + + MultiplexInboundStream(AtomicInteger handleInactivatedFlushed, + AtomicInteger handleInactivatedNotFlushed, CountDownLatch latchHandlerInactive) { + this.handlerInactivatedFlushed = handleInactivatedFlushed; + this.handleInactivatedNotFlushed = handleInactivatedNotFlushed; + this.latchHandlerInactive = latchHandlerInactive; + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2HeadersFrame && ((Http2HeadersFrame) msg).isEndStream()) { + ByteBuf response = Unpooled.copiedBuffer(LARGE_STRING, CharsetUtil.US_ASCII); + responseFuture = ctx.writeAndFlush(new DefaultHttp2DataFrame(response, true)); + } + ReferenceCountUtil.release(msg); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + if (responseFuture.isSuccess()) { + handlerInactivatedFlushed.incrementAndGet(); + } else { + handleInactivatedNotFlushed.incrementAndGet(); + } + latchHandlerInactive.countDown(); + ctx.fireChannelInactive(); + } + + private static String generateLargeString(int sizeInBytes) { + StringBuilder sb = new StringBuilder(sizeInBytes); + for (int i = 0; i < sizeInBytes; i++) { + sb.append('X'); + } + return sb.toString(); + } + } + + @BeforeEach + public void setup() { + eventLoopGroup = new NioEventLoopGroup(); + } + + @AfterEach + 
public void teardown() { + if (clientChannel != null) { + clientChannel.close(); + } + if (serverChannel != null) { + serverChannel.close(); + } + if (serverConnectedChannel != null) { + serverConnectedChannel.close(); + } + eventLoopGroup.shutdownGracefully(0, 0, MILLISECONDS); + } + + @Test + @Timeout(value = 10000, unit = MILLISECONDS) + public void asyncSettingsAckWithMultiplexCodec() throws InterruptedException { + asyncSettingsAck0(new Http2MultiplexCodecBuilder(true, DISCARD_HANDLER).build(), null); + } + + @Test + @Timeout(value = 10000, unit = MILLISECONDS) + public void asyncSettingsAckWithMultiplexHandler() throws InterruptedException { + asyncSettingsAck0(new Http2FrameCodecBuilder(true).build(), + new Http2MultiplexHandler(DISCARD_HANDLER)); + } + + private void asyncSettingsAck0(final Http2FrameCodec codec, final ChannelHandler multiplexer) + throws InterruptedException { + // The client expects 2 settings frames. One from the connection setup and one from this test. + final CountDownLatch serverAckOneLatch = new CountDownLatch(1); + final CountDownLatch serverAckAllLatch = new CountDownLatch(2); + final CountDownLatch clientSettingsLatch = new CountDownLatch(2); + final CountDownLatch serverConnectedChannelLatch = new CountDownLatch(1); + final AtomicReference serverConnectedChannelRef = new AtomicReference(); + ServerBootstrap sb = new ServerBootstrap(); + sb.group(eventLoopGroup); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(codec); + if (multiplexer != null) { + ch.pipeline().addLast(multiplexer); + } + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + serverConnectedChannelRef.set(ctx.channel()); + serverConnectedChannelLatch.countDown(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof 
Http2SettingsAckFrame) { + serverAckOneLatch.countDown(); + serverAckAllLatch.countDown(); + } + ReferenceCountUtil.release(msg); + } + }); + } + }); + serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).awaitUninterruptibly().channel(); + + Bootstrap bs = new Bootstrap(); + bs.group(eventLoopGroup); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(Http2MultiplexCodecBuilder + .forClient(DISCARD_HANDLER).autoAckSettingsFrame(false).build()); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2SettingsFrame) { + clientSettingsLatch.countDown(); + } + ReferenceCountUtil.release(msg); + } + }); + } + }); + clientChannel = bs.connect(serverChannel.localAddress()).awaitUninterruptibly().channel(); + serverConnectedChannelLatch.await(); + serverConnectedChannel = serverConnectedChannelRef.get(); + + serverConnectedChannel.writeAndFlush(new DefaultHttp2SettingsFrame(new Http2Settings() + .maxConcurrentStreams(10))).sync(); + + clientSettingsLatch.await(); + + // We expect a timeout here because we want to asynchronously generate the SETTINGS ACK below. + assertFalse(serverAckOneLatch.await(300, MILLISECONDS)); + + // We expect 2 settings frames, the initial settings frame during connection establishment and the setting frame + // written in this test. We should ack both of these settings frames. 
+ clientChannel.writeAndFlush(Http2SettingsAckFrame.INSTANCE).sync(); + clientChannel.writeAndFlush(Http2SettingsAckFrame.INSTANCE).sync(); + + serverAckAllLatch.await(); + } + + @Test + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testFlushNotDiscarded() + throws InterruptedException { + final ScheduledExecutorService executorService = Executors.newScheduledThreadPool(1); + + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(eventLoopGroup); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new Http2FrameCodecBuilder(true).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2HeadersFrame && ((Http2HeadersFrame) msg).isEndStream()) { + executorService.schedule(new Runnable() { + @Override + public void run() { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame( + new DefaultHttp2Headers(), false)).addListener( + new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + ctx.write(new DefaultHttp2DataFrame( + Unpooled.copiedBuffer("Hello World", CharsetUtil.US_ASCII), + true)); + ctx.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + ctx.flush(); + } + }); + } + }); + } + }, 500, MILLISECONDS); + } + ReferenceCountUtil.release(msg); + } + })); + } + }); + serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).syncUninterruptibly().channel(); + + final CountDownLatch latch = new CountDownLatch(1); + Bootstrap bs = new Bootstrap(); + bs.group(eventLoopGroup); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new Http2FrameCodecBuilder(false).build()); + ch.pipeline().addLast(new 
Http2MultiplexHandler(DISCARD_HANDLER)); + } + }); + clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + Http2StreamChannelBootstrap h2Bootstrap = new Http2StreamChannelBootstrap(clientChannel); + h2Bootstrap.handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2DataFrame && ((Http2DataFrame) msg).isEndStream()) { + latch.countDown(); + } + ReferenceCountUtil.release(msg); + } + }); + Http2StreamChannel streamChannel = h2Bootstrap.open().syncUninterruptibly().getNow(); + streamChannel.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers(), true)) + .syncUninterruptibly(); + + latch.await(); + } finally { + executorService.shutdown(); + } + } + + @Test + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testSSLExceptionOpenSslTLSv12() throws Exception { + testSslException(SslProvider.OPENSSL, false); + } + + @Test + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testSSLExceptionOpenSslTLSv13() throws Exception { + testSslException(SslProvider.OPENSSL, true); + } + + @Disabled("JDK SSLEngine does not produce an alert") + @Test + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testSSLExceptionJDKTLSv12() throws Exception { + testSslException(SslProvider.JDK, false); + } + + @Disabled("JDK SSLEngine does not produce an alert") + @Test + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testSSLExceptionJDKTLSv13() throws Exception { + testSslException(SslProvider.JDK, true); + } + + private void testSslException(SslProvider provider, final boolean tlsv13) throws Exception { + assumeTrue(SslProvider.isAlpnSupported(provider)); + if (tlsv13) { + assumeTrue(SslProvider.isTlsv13Supported(provider)); + } + final String protocol = tlsv13 ? 
"TLSv1.3" : "TLSv1.2"; + SelfSignedCertificate ssc = null; + try { + ssc = new SelfSignedCertificate(); + final SslContext sslCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .trustManager(new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateExpiredException(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateExpiredException(); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }).sslProvider(provider) + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .protocols(protocol) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. 
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2, + ApplicationProtocolNames.HTTP_1_1)).clientAuth(ClientAuth.REQUIRE) + .build(); + + ServerBootstrap sb = new ServerBootstrap(); + sb.group(eventLoopGroup); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new Http2FrameCodecBuilder(true).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER)); + } + }); + serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).syncUninterruptibly().channel(); + + final SslContext clientCtx = SslContextBuilder.forClient() + .keyManager(ssc.key(), ssc.cert()) + .sslProvider(provider) + /* NOTE: the cipher filter may not include all ciphers required by the HTTP/2 specification. + * Please refer to the HTTP/2 specification for cipher requirements. */ + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(protocol) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. 
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2, + ApplicationProtocolNames.HTTP_1_1)) + .build(); + + final CountDownLatch latch = new CountDownLatch(2); + final AtomicReference errorRef = new AtomicReference(); + Bootstrap bs = new Bootstrap(); + bs.group(eventLoopGroup); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new Http2FrameCodecBuilder(false).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER)); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent handshakeCompletionEvent = + (SslHandshakeCompletionEvent) evt; + if (handshakeCompletionEvent.isSuccess()) { + // In case of TLSv1.3 we should succeed the handshake. The alert for + // the mTLS failure will be send in the next round-trip. 
+ if (!tlsv13) { + errorRef.set(new AssertionError("TLSv1.3 expected")); + } + + Http2StreamChannelBootstrap h2Bootstrap = + new Http2StreamChannelBootstrap(ctx.channel()); + h2Bootstrap.handler(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause.getCause() instanceof SSLException) { + latch.countDown(); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + latch.countDown(); + } + }); + h2Bootstrap.open().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if (future.isSuccess()) { + future.getNow().writeAndFlush(new DefaultHttp2HeadersFrame( + new DefaultHttp2Headers(), false)); + } + } + }); + + } else if (handshakeCompletionEvent.cause() instanceof SSLException) { + // In case of TLSv1.2 we should never see the handshake succeed as the alert for + // the mTLS failure will be send in the same round-trip. + if (tlsv13) { + errorRef.set(new AssertionError("TLSv1.2 expected")); + } + latch.countDown(); + latch.countDown(); + } + } + } + }); + } + }); + clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + latch.await(); + AssertionError error = errorRef.get(); + if (error != null) { + throw error; + } + } finally { + if (ssc != null) { + ssc.delete(); + } + } + } + + @Test + @DisabledOnOs(value = OS.WINDOWS, disabledReason = "See: https://github.com/netty/netty/issues/11542") + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testFireChannelReadAfterHandshakeSuccess_JDK() throws Exception { + assumeTrue(SslProvider.isAlpnSupported(SslProvider.JDK)); + testFireChannelReadAfterHandshakeSuccess(SslProvider.JDK); + } + + @Test + @DisabledOnOs(value = OS.WINDOWS, disabledReason = "See: https://github.com/netty/netty/issues/11542") + @Timeout(value = 5000L, unit = MILLISECONDS) + public void testFireChannelReadAfterHandshakeSuccess_OPENSSL() throws Exception { + 
assumeTrue(OpenSsl.isAvailable()); + assumeTrue(SslProvider.isAlpnSupported(SslProvider.OPENSSL)); + testFireChannelReadAfterHandshakeSuccess(SslProvider.OPENSSL); + } + + private void testFireChannelReadAfterHandshakeSuccess(SslProvider provider) throws Exception { + SelfSignedCertificate ssc = null; + try { + ssc = new SelfSignedCertificate(); + final SslContext serverCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(provider) + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2, + ApplicationProtocolNames.HTTP_1_1)) + .build(); + + ServerBootstrap sb = new ServerBootstrap(); + sb.group(eventLoopGroup); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(serverCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + ctx.pipeline().addLast(new Http2FrameCodecBuilder(true).build()); + ctx.pipeline().addLast(new Http2MultiplexHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2HeadersFrame && ((Http2HeadersFrame) msg).isEndStream()) { + ctx.writeAndFlush(new DefaultHttp2HeadersFrame( + new DefaultHttp2Headers(), false)) + .addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + ctx.writeAndFlush(new DefaultHttp2DataFrame( + Unpooled.copiedBuffer("Hello World", CharsetUtil.US_ASCII), + true)); + } + }); + } 
+ ReferenceCountUtil.release(msg); + } + })); + } + }); + } + }); + serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).sync().channel(); + + final SslContext clientCtx = SslContextBuilder.forClient() + .sslProvider(provider) + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .applicationProtocolConfig(new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + ApplicationProtocolNames.HTTP_2, + ApplicationProtocolNames.HTTP_1_1)) + .build(); + + final CountDownLatch latch = new CountDownLatch(1); + Bootstrap bs = new Bootstrap(); + bs.group(eventLoopGroup); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new Http2FrameCodecBuilder(false).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER)); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent handshakeCompletionEvent = + (SslHandshakeCompletionEvent) evt; + if (handshakeCompletionEvent.isSuccess()) { + Http2StreamChannelBootstrap h2Bootstrap = + new Http2StreamChannelBootstrap(clientChannel); + h2Bootstrap.handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2DataFrame && ((Http2DataFrame) msg).isEndStream()) { + latch.countDown(); + } + ReferenceCountUtil.release(msg); + } + }); + h2Bootstrap.open().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if 
(future.isSuccess()) { + future.getNow().writeAndFlush(new DefaultHttp2HeadersFrame( + new DefaultHttp2Headers(), true)); + } + } + }); + } + } + } + }); + } + }); + clientChannel = bs.connect(serverChannel.localAddress()).sync().channel(); + + latch.await(); + } finally { + if (ssc != null) { + ssc.delete(); + } + } + } + + /** + * When an HTTP/2 server stream channel receives a frame with EOS flag, and when it responds with a EOS + * flag, then the server side stream will be closed, hence the stream handler will be inactivated. This test + * verifies that the ChannelFuture of the server response is successful at the time the server stream handler is + * inactivated. + */ + @Test + @Timeout(value = 120000L, unit = MILLISECONDS) + public void streamHandlerInactivatedResponseFlushed() throws InterruptedException { + EventLoopGroup serverEventLoopGroup = null; + EventLoopGroup clientEventLoopGroup = null; + + try { + serverEventLoopGroup = new NioEventLoopGroup(1, new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + return new Thread(r, "serverloop"); + } + }); + + clientEventLoopGroup = new NioEventLoopGroup(1, new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + return new Thread(r, "clientloop"); + } + }); + + final int streams = 10; + final CountDownLatch latchClientResponses = new CountDownLatch(streams); + final CountDownLatch latchHandlerInactive = new CountDownLatch(streams); + + final AtomicInteger handlerInactivatedFlushed = new AtomicInteger(); + final AtomicInteger handleInactivatedNotFlushed = new AtomicInteger(); + final ServerBootstrap sb = new ServerBootstrap(); + + sb.group(serverEventLoopGroup); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + // using a short sndbuf size will trigger writability events + ch.config().setOption(ChannelOption.SO_SNDBUF, 1); + ch.pipeline().addLast(new 
Http2FrameCodecBuilder(true).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().remove(this); + ch.pipeline().addLast(new MultiplexInboundStream(handlerInactivatedFlushed, + handleInactivatedNotFlushed, latchHandlerInactive)); + } + })); + } + }); + serverChannel = sb.bind(new InetSocketAddress(NetUtil.LOCALHOST, 0)).syncUninterruptibly().channel(); + + final Bootstrap bs = new Bootstrap(); + + bs.group(clientEventLoopGroup); + bs.channel(NioSocketChannel.class); + bs.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(new Http2FrameCodecBuilder(false).build()); + ch.pipeline().addLast(new Http2MultiplexHandler(DISCARD_HANDLER)); + } + }); + + clientChannel = bs.connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + final Http2StreamChannelBootstrap h2Bootstrap = new Http2StreamChannelBootstrap(clientChannel); + h2Bootstrap.handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof Http2DataFrame && ((Http2DataFrame) msg).isEndStream()) { + latchClientResponses.countDown(); + } + ReferenceCountUtil.release(msg); + } + @Override + public boolean isSharable() { + return true; + } + }); + + List streamFutures = new ArrayList(); + for (int i = 0; i < streams; i ++) { + Http2StreamChannel stream = h2Bootstrap.open().syncUninterruptibly().getNow(); + streamFutures.add(stream.writeAndFlush(new DefaultHttp2HeadersFrame(new DefaultHttp2Headers(), true))); + } + for (int i = 0; i < streams; i ++) { + streamFutures.get(i).syncUninterruptibly(); + } + + assertTrue(latchHandlerInactive.await(120000, MILLISECONDS)); + assertTrue(latchClientResponses.await(120000, MILLISECONDS)); + assertEquals(0, handleInactivatedNotFlushed.get()); + assertEquals(streams, handlerInactivatedFlushed.get()); + } finally { + if 
(serverEventLoopGroup != null) { + serverEventLoopGroup.shutdownGracefully(0, 0, MILLISECONDS); + } + if (clientEventLoopGroup != null) { + clientEventLoopGroup.shutdownGracefully(0, 0, MILLISECONDS); + } + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SecurityUtilTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SecurityUtilTest.java new file mode 100644 index 0000000..81ec236 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SecurityUtilTest.java @@ -0,0 +1,49 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.SupportedCipherSuiteFilter; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; + +public class Http2SecurityUtilTest { + + @Test + public void testTLSv13CiphersIncluded() throws SSLException { + Assumptions.assumeTrue(SslProvider.isTlsv13Supported(SslProvider.JDK)); + testCiphersIncluded("TLSv1.3"); + } + + @Test + public void testTLSv12CiphersIncluded() throws SSLException { + testCiphersIncluded("TLSv1.2"); + } + + private static void testCiphersIncluded(String protocol) throws SSLException { + SslContext context = SslContextBuilder.forClient().sslProvider(SslProvider.JDK).protocols(protocol) + .ciphers(Http2SecurityUtil.CIPHERS, SupportedCipherSuiteFilter.INSTANCE).build(); + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + Assertions.assertTrue(engine.getEnabledCipherSuites().length > 0, "No " + protocol + " ciphers found"); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodecTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodecTest.java new file mode 100644 index 0000000..88f548f --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2ServerUpgradeCodecTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.DefaultChannelId; +import io.netty.channel.ServerChannel; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpVersion; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import org.junit.jupiter.api.Test; +import org.mockito.Mockito; + +public class Http2ServerUpgradeCodecTest { + + @Test + public void testUpgradeToHttp2ConnectionHandler() { + testUpgrade(new Http2ConnectionHandlerBuilder().frameListener(new Http2FrameAdapter()).build(), null); + } + + @Test + public void testUpgradeToHttp2FrameCodec() { + testUpgrade(new Http2FrameCodecBuilder(true).build(), null); + } + + @Test + public void testUpgradeToHttp2MultiplexCodec() { + testUpgrade(new Http2MultiplexCodecBuilder(true, new HttpInboundHandler()).build(), null); + } + + @Test + public void testUpgradeToHttp2FrameCodecWithMultiplexer() { + 
testUpgrade(new Http2FrameCodecBuilder(true).build(), + new Http2MultiplexHandler(new HttpInboundHandler())); + } + + private static void testUpgrade(Http2ConnectionHandler handler, ChannelHandler multiplexer) { + FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.OPTIONS, "*"); + request.headers().set(HttpHeaderNames.HOST, "netty.io"); + request.headers().set(HttpHeaderNames.CONNECTION, "Upgrade, HTTP2-Settings"); + request.headers().set(HttpHeaderNames.UPGRADE, "h2c"); + request.headers().set("HTTP2-Settings", "AAMAAABkAAQAAP__"); + + ServerChannel parent = Mockito.mock(ServerChannel.class); + EmbeddedChannel channel = new EmbeddedChannel(parent, DefaultChannelId.newInstance(), true, false, + new ChannelInboundHandlerAdapter()); + ChannelHandlerContext ctx = channel.pipeline().firstContext(); + Http2ServerUpgradeCodec codec; + if (multiplexer == null) { + codec = new Http2ServerUpgradeCodec(handler); + } else { + codec = new Http2ServerUpgradeCodec((Http2FrameCodec) handler, multiplexer); + } + assertTrue(codec.prepareUpgradeResponse(ctx, request, new DefaultHttpHeaders())); + codec.upgradeTo(ctx, request); + // Flush the channel to ensure we write out all buffered data + channel.flush(); + + channel.writeInbound(Http2CodecUtil.connectionPrefaceBuf()); + Http2FrameInboundWriter writer = new Http2FrameInboundWriter(channel); + writer.writeInboundSettings(new Http2Settings()); + writer.writeInboundRstStream(Http2CodecUtil.HTTP_UPGRADE_STREAM_ID, Http2Error.CANCEL.code()); + + assertSame(handler, channel.pipeline().remove(handler.getClass())); + assertNull(channel.pipeline().get(handler.getClass())); + assertTrue(channel.finish()); + + // Check that the preface was send (a.k.a the settings frame) + ByteBuf settingsBuffer = channel.readOutbound(); + assertNotNull(settingsBuffer); + settingsBuffer.release(); + + ByteBuf buf = channel.readOutbound(); + assertNotNull(buf); + buf.release(); + + assertNull(channel.readOutbound()); + } + 
+ @ChannelHandler.Sharable + private static final class HttpInboundHandler extends ChannelInboundHandlerAdapter { } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SettingsTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SettingsTest.java new file mode 100644 index 0000000..eb48bb4 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2SettingsTest.java @@ -0,0 +1,253 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_LOWER_BOUND; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_FRAME_SIZE_UPPER_BOUND; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_INITIAL_WINDOW_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_UNSIGNED_INT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_HEADER_TABLE_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_INITIAL_WINDOW_SIZE; +import static java.lang.Long.MAX_VALUE; +import static java.lang.Long.MIN_VALUE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Tests for {@link Http2Settings}. 
+ */ +public class Http2SettingsTest { + + private Http2Settings settings; + + @BeforeEach + public void setup() { + settings = new Http2Settings(); + } + + @Test + public void standardSettingsShouldBeNotSet() { + assertEquals(0, settings.size()); + assertNull(settings.headerTableSize()); + assertNull(settings.initialWindowSize()); + assertNull(settings.maxConcurrentStreams()); + assertNull(settings.pushEnabled()); + assertNull(settings.maxFrameSize()); + assertNull(settings.maxHeaderListSize()); + } + + @Test + public void standardSettingsShouldBeSet() { + settings.initialWindowSize(1); + settings.maxConcurrentStreams(2); + settings.pushEnabled(true); + settings.headerTableSize(3); + settings.maxFrameSize(MAX_FRAME_SIZE_UPPER_BOUND); + settings.maxHeaderListSize(4); + assertEquals(1, (int) settings.initialWindowSize()); + assertEquals(2L, (long) settings.maxConcurrentStreams()); + assertTrue(settings.pushEnabled()); + assertEquals(3L, (long) settings.headerTableSize()); + assertEquals(MAX_FRAME_SIZE_UPPER_BOUND, (int) settings.maxFrameSize()); + assertEquals(4L, (long) settings.maxHeaderListSize()); + } + + @Test + public void settingsShouldSupportUnsignedShort() { + char key = (char) (Short.MAX_VALUE + 1); + settings.put(key, (Long) 123L); + assertEquals(123L, (long) settings.get(key)); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_HEADER_LIST_SIZE, MIN_HEADER_LIST_SIZE + 1L, + MAX_HEADER_LIST_SIZE - 1L, MAX_HEADER_LIST_SIZE}) + public void headerListSize(final long value) { + settings.maxHeaderListSize(value); + assertEquals(value, (long) settings.maxHeaderListSize()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, MIN_HEADER_LIST_SIZE - 1L, MAX_HEADER_LIST_SIZE + 1L, MAX_VALUE}) + public void headerListSizeBoundCheck(final long value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + 
settings.maxHeaderListSize(value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_HEADER_TABLE_SIZE, MIN_HEADER_TABLE_SIZE + 1L, + MAX_HEADER_TABLE_SIZE - 1L, MAX_HEADER_TABLE_SIZE}) + public void headerTableSize(final long value) { + settings.headerTableSize(value); + assertEquals(value, (long) settings.headerTableSize()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, MIN_HEADER_TABLE_SIZE - 1L, MAX_HEADER_TABLE_SIZE + 1L, MAX_VALUE}) + public void headerTableSizeBoundCheck(final long value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.headerTableSize(value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(booleans = {false, true}) + public void pushEnabled(final boolean value) { + settings.pushEnabled(value); + assertEquals(value, settings.pushEnabled()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {0, 1}) + public void enablePush(final long value) { + settings.put(Http2CodecUtil.SETTINGS_ENABLE_PUSH, (Long) value); + assertEquals(value, (long) settings.get(Http2CodecUtil.SETTINGS_ENABLE_PUSH)); + assertEquals(value == 1, settings.pushEnabled()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, -1, 2, MAX_VALUE}) + public void enablePushBoundCheck(final long value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.put(Http2CodecUtil.SETTINGS_ENABLE_PUSH, (Long) value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_CONCURRENT_STREAMS, MIN_CONCURRENT_STREAMS + 1L, + MAX_CONCURRENT_STREAMS - 1L, MAX_CONCURRENT_STREAMS}) + public void maxConcurrentStreams(final long value) { + 
settings.maxConcurrentStreams(value); + assertEquals(value, (long) settings.maxConcurrentStreams()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, MIN_CONCURRENT_STREAMS - 1L, MAX_CONCURRENT_STREAMS + 1L, MAX_VALUE}) + public void maxConcurrentStreamsBoundCheck(final long value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.maxConcurrentStreams(value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(ints = {MIN_INITIAL_WINDOW_SIZE, MIN_INITIAL_WINDOW_SIZE + 1, + MAX_INITIAL_WINDOW_SIZE - 1, MAX_INITIAL_WINDOW_SIZE}) + public void initialWindowSize(final int value) { + settings.initialWindowSize(value); + assertEquals(value, (int) settings.initialWindowSize()); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(ints = {Integer.MIN_VALUE, MIN_INITIAL_WINDOW_SIZE - 1}) + public void initialWindowSizeIntBoundCheck(final int value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.initialWindowSize(value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, MIN_INITIAL_WINDOW_SIZE - 1L, MAX_INITIAL_WINDOW_SIZE + 1L, MAX_VALUE}) + public void initialWindowSizeBoundCheck(final long value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.put(Http2CodecUtil.SETTINGS_INITIAL_WINDOW_SIZE, (Long) value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(ints = {MAX_FRAME_SIZE_LOWER_BOUND, MAX_FRAME_SIZE_LOWER_BOUND + 1, + MAX_FRAME_SIZE_UPPER_BOUND - 1, MAX_FRAME_SIZE_UPPER_BOUND}) + public void maxFrameSize(final int value) { + settings.maxFrameSize(value); + assertEquals(value, (int) settings.maxFrameSize()); + } + + 
@ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(ints = {Integer.MIN_VALUE, 0, MAX_FRAME_SIZE_LOWER_BOUND - 1, + MAX_FRAME_SIZE_UPPER_BOUND + 1, Integer.MAX_VALUE}) + public void maxFrameSizeIntBoundCheck(final int value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.maxFrameSize(value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, 0L, MAX_FRAME_SIZE_LOWER_BOUND - 1L, MAX_FRAME_SIZE_UPPER_BOUND + 1L, + Integer.MAX_VALUE, MAX_VALUE}) + public void maxFrameSizeBoundCheck(final long value) { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.put(Http2CodecUtil.SETTINGS_MAX_FRAME_SIZE, (Long) value); + } + }); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {0L, 1L, 123L, Integer.MAX_VALUE, MAX_UNSIGNED_INT - 1L, MAX_UNSIGNED_INT}) + public void nonStandardSetting(final long value) { + char key = 0; + settings.put(key, (Long) value); + assertEquals(value, (long) settings.get(key)); + } + + @ParameterizedTest(name = "{displayName} [{index}] value={0}") + @ValueSource(longs = {MIN_VALUE, Integer.MIN_VALUE, -1L, MAX_UNSIGNED_INT + 1L, MAX_VALUE}) + public void nonStandardSettingBoundCheck(final long value) { + final char key = 0; + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + settings.put(key, (Long) value); + } + }); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrapTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrapTest.java new file mode 100644 index 0000000..b46a0cb --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelBootstrapTest.java @@ -0,0 +1,166 @@ +/* + 
* Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.DefaultEventLoop; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.hamcrest.core.IsInstanceOf; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.nio.channels.ClosedChannelException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; + +import static io.netty.handler.codec.http2.Http2FrameCodecBuilder.forClient; +import static io.netty.handler.codec.http2.Http2FrameCodecBuilder.forServer; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; 
+import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + +public class Http2StreamChannelBootstrapTest { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(Http2StreamChannelBootstrapTest.class); + + private volatile Channel serverConnectedChannel; + + @Test + public void testStreamIsNotCreatedIfParentConnectionIsClosedConcurrently() throws Exception { + EventLoopGroup group = null; + Channel serverChannel = null; + Channel clientChannel = null; + try { + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + group = new DefaultEventLoop(); + LocalAddress serverAddress = new LocalAddress(getClass().getName()); + ServerBootstrap sb = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(group) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + serverConnectedChannel = ch; + ch.pipeline().addLast(forServer().build(), newMultiplexedHandler()); + serverChannelLatch.countDown(); + } + }); + serverChannel = sb.bind(serverAddress).sync().channel(); + + Bootstrap cb = new Bootstrap() + .channel(LocalChannel.class) + .group(group) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(forClient().build(), newMultiplexedHandler()); + } + }); + clientChannel = cb.connect(serverAddress).sync().channel(); + assertTrue(serverChannelLatch.await(3, SECONDS)); + + final CountDownLatch closeLatch = new CountDownLatch(1); + final Channel clientChannelToClose = clientChannel; + group.execute(new Runnable() { + @Override + public void run() { + try { + closeLatch.await(); + clientChannelToClose.close().syncUninterruptibly(); + } catch (InterruptedException e) { + logger.error(e); + } + } + }); + + 
Http2StreamChannelBootstrap bootstrap = new Http2StreamChannelBootstrap(clientChannel); + final Promise promise = clientChannel.eventLoop().newPromise(); + bootstrap.open(promise); + assertThat(promise.isDone(), is(false)); + closeLatch.countDown(); + + ExecutionException exception = assertThrows(ExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + promise.get(3, SECONDS); + } + }); + assertThat(exception.getCause(), IsInstanceOf.instanceOf(ClosedChannelException.class)); + } finally { + safeClose(clientChannel); + safeClose(serverConnectedChannel); + safeClose(serverChannel); + if (group != null) { + group.shutdownGracefully(0, 3, SECONDS); + } + } + } + + private static Http2MultiplexHandler newMultiplexedHandler() { + return new Http2MultiplexHandler(new ChannelInitializer() { + @Override + protected void initChannel(Http2StreamChannel ch) { + // noop + } + }); + } + + private static void safeClose(Channel channel) { + if (channel != null) { + try { + channel.close().syncUninterruptibly(); + } catch (Exception e) { + logger.error(e); + } + } + } + + @Test + public void open0FailsPromiseOnHttp2MultiplexHandlerError() { + Http2StreamChannelBootstrap bootstrap = new Http2StreamChannelBootstrap(mock(Channel.class)); + + Http2MultiplexHandler handler = new Http2MultiplexHandler(mock(ChannelHandler.class)); + EventExecutor executor = mock(EventExecutor.class); + when(executor.inEventLoop()).thenReturn(true); + ChannelHandlerContext ctx = mock(ChannelHandlerContext.class); + when(ctx.executor()).thenReturn(executor); + when(ctx.handler()).thenReturn(handler); + + Promise promise = new DefaultPromise(mock(EventExecutor.class)); + bootstrap.open0(ctx, promise); + assertThat(promise.isDone(), is(true)); + assertThat(promise.cause(), is(instanceOf(IllegalStateException.class))); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelIdTest.java 
b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelIdTest.java new file mode 100644 index 0000000..7482230 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamChannelIdTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import java.io.ObjectInputStream; +import java.io.ObjectOutputStream; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelId; +import io.netty.channel.DefaultChannelId; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class Http2StreamChannelIdTest { + + @Test + public void testSerialization() throws Exception { + ChannelId normalInstance = new Http2StreamChannelId(DefaultChannelId.newInstance(), 0); + + ByteBuf buf = Unpooled.buffer(); + ObjectOutputStream outStream = new ObjectOutputStream(new ByteBufOutputStream(buf)); + try { + outStream.writeObject(normalInstance); + } finally { + outStream.close(); + } + + ObjectInputStream inStream = new ObjectInputStream(new ByteBufInputStream(buf, true)); + final ChannelId deserializedInstance; + try { + deserializedInstance = (ChannelId) inStream.readObject(); + } finally { + 
inStream.close(); + } + + assertEquals(normalInstance, deserializedInstance); + assertEquals(normalInstance.hashCode(), deserializedInstance.hashCode()); + assertEquals(0, normalInstance.compareTo(deserializedInstance)); + assertEquals(normalInstance.asLongText(), deserializedInstance.asLongText()); + assertEquals(normalInstance.asShortText(), deserializedInstance.asShortText()); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java new file mode 100644 index 0000000..75d98d8 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2StreamFrameToHttpObjectCodecTest.java @@ -0,0 +1,995 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultHttpResponse; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.FullHttpResponse; +import io.netty.handler.codec.http.HttpContent; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpResponse; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http.HttpUtil; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslProvider; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static 
org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class Http2StreamFrameToHttpObjectCodecTest { + + @Test + public void testUpgradeEmptyFullResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + assertTrue(ch.writeOutbound(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK))); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + assertThat(headersFrame.headers().status().toString(), is("200")); + assertTrue(headersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void encode100ContinueAsHttp2HeadersFrameThatIsNotEndStream() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + assertTrue(ch.writeOutbound(new DefaultFullHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE))); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + assertThat(headersFrame.headers().status().toString(), is("100")); + assertFalse(headersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void encodeNonFullHttpResponse100ContinueIsRejected() throws Exception { + final EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + assertThrows(EncoderException.class, new Executable() { + @Override + public void execute() { + ch.writeOutbound(new DefaultHttpResponse( + HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE)); + } + }); + ch.finishAndReleaseAll(); + } + + @Test + public void 
testUpgradeNonEmptyFullResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + assertTrue(ch.writeOutbound(new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, hello))); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + assertThat(headersFrame.headers().status().toString(), is("200")); + assertFalse(headersFrame.isEndStream()); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertTrue(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeEmptyFullResponseWithTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + HttpHeaders trailers = response.trailingHeaders(); + trailers.set("key", "value"); + assertTrue(ch.writeOutbound(response)); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + assertThat(headersFrame.headers().status().toString(), is("200")); + assertFalse(headersFrame.isEndStream()); + + Http2HeadersFrame trailersFrame = ch.readOutbound(); + assertThat(trailersFrame.headers().get("key").toString(), is("value")); + assertTrue(trailersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeNonEmptyFullResponseWithTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + FullHttpResponse response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, hello); + 
HttpHeaders trailers = response.trailingHeaders(); + trailers.set("key", "value"); + assertTrue(ch.writeOutbound(response)); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + assertThat(headersFrame.headers().status().toString(), is("200")); + assertFalse(headersFrame.isEndStream()); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + Http2HeadersFrame trailersFrame = ch.readOutbound(); + assertThat(trailersFrame.headers().get("key").toString(), is("value")); + assertTrue(trailersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeHeaders() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + HttpResponse response = new DefaultHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + assertTrue(ch.writeOutbound(response)); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + assertThat(headersFrame.headers().status().toString(), is("200")); + assertFalse(headersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeChunk() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + HttpContent content = new DefaultHttpContent(hello); + assertTrue(ch.writeOutbound(content)); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeEmptyEnd() throws 
Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + LastHttpContent end = LastHttpContent.EMPTY_LAST_CONTENT; + assertTrue(ch.writeOutbound(end)); + + Http2DataFrame emptyFrame = ch.readOutbound(); + try { + assertThat(emptyFrame.content().readableBytes(), is(0)); + assertTrue(emptyFrame.isEndStream()); + } finally { + emptyFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeDataEnd() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + LastHttpContent end = new DefaultLastHttpContent(hello, true); + assertTrue(ch.writeOutbound(end)); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertTrue(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + LastHttpContent trailers = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, true); + HttpHeaders headers = trailers.trailingHeaders(); + headers.set("key", "value"); + assertTrue(ch.writeOutbound(trailers)); + + Http2HeadersFrame headerFrame = ch.readOutbound(); + assertThat(headerFrame.headers().get("key").toString(), is("value")); + assertTrue(headerFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testUpgradeDataEndWithTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + LastHttpContent trailers 
= new DefaultLastHttpContent(hello, true); + HttpHeaders headers = trailers.trailingHeaders(); + headers.set("key", "value"); + assertTrue(ch.writeOutbound(trailers)); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + Http2HeadersFrame headerFrame = ch.readOutbound(); + assertThat(headerFrame.headers().get("key").toString(), is("value")); + assertTrue(headerFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDowngradeHeaders() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.path("/"); + headers.method("GET"); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers))); + + HttpRequest request = ch.readInbound(); + assertThat(request.uri(), is("/")); + assertThat(request.method(), is(HttpMethod.GET)); + assertThat(request.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertFalse(request instanceof FullHttpRequest); + assertTrue(HttpUtil.isTransferEncodingChunked(request)); + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDowngradeHeadersWithContentLength() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.path("/"); + headers.method("GET"); + headers.setInt("content-length", 0); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers))); + + HttpRequest request = ch.readInbound(); + assertThat(request.uri(), is("/")); + assertThat(request.method(), is(HttpMethod.GET)); + assertThat(request.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertFalse(request instanceof FullHttpRequest); + 
assertFalse(HttpUtil.isTransferEncodingChunked(request)); + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDowngradeFullHeaders() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.path("/"); + headers.method("GET"); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers, true))); + + FullHttpRequest request = ch.readInbound(); + try { + assertThat(request.uri(), is("/")); + assertThat(request.method(), is(HttpMethod.GET)); + assertThat(request.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertThat(request.content().readableBytes(), is(0)); + assertTrue(request.trailingHeaders().isEmpty()); + assertFalse(HttpUtil.isTransferEncodingChunked(request)); + } finally { + request.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDowngradeTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.set("key", "value"); + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers, true))); + + LastHttpContent trailers = ch.readInbound(); + try { + assertThat(trailers.content().readableBytes(), is(0)); + assertThat(trailers.trailingHeaders().get("key"), is("value")); + assertFalse(trailers instanceof FullHttpRequest); + } finally { + trailers.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDowngradeData() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + assertTrue(ch.writeInbound(new DefaultHttp2DataFrame(hello))); + + HttpContent content = ch.readInbound(); + try { + 
assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(content instanceof LastHttpContent); + } finally { + content.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDowngradeEndData() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + assertTrue(ch.writeInbound(new DefaultHttp2DataFrame(hello, true))); + + LastHttpContent content = ch.readInbound(); + try { + assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertTrue(content.trailingHeaders().isEmpty()); + } finally { + content.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testPassThroughOther() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(true)); + Http2ResetFrame reset = new DefaultHttp2ResetFrame(0); + Http2GoAwayFrame goaway = new DefaultHttp2GoAwayFrame(0); + assertTrue(ch.writeInbound(reset)); + assertTrue(ch.writeInbound(goaway.retain())); + + assertEquals(reset, ch.readInbound()); + + Http2GoAwayFrame frame = ch.readInbound(); + try { + assertEquals(goaway, frame); + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } finally { + goaway.release(); + frame.release(); + } + } + + // client-specific tests + @Test + public void testEncodeEmptyFullRequest() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + assertTrue(ch.writeOutbound(new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"))); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), 
is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertTrue(headersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeHttpsSchemeWhenSslHandlerExists() throws Exception { + final Queue frames = new ConcurrentLinkedQueue(); + + final SslContext ctx = SslContextBuilder.forClient().sslProvider(SslProvider.JDK).build(); + EmbeddedChannel ch = new EmbeddedChannel(ctx.newHandler(ByteBufAllocator.DEFAULT), + new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (msg instanceof Http2StreamFrame) { + frames.add((Http2StreamFrame) msg); + ctx.write(Unpooled.EMPTY_BUFFER, promise); + } else { + ctx.write(msg, promise); + } + } + }, new Http2StreamFrameToHttpObjectCodec(false)); + + try { + FullHttpRequest req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"); + assertTrue(ch.writeOutbound(req)); + + ch.finishAndReleaseAll(); + + Http2HeadersFrame headersFrame = (Http2HeadersFrame) frames.poll(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("https")); + assertThat(headers.method().toString(), is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertTrue(headersFrame.isEndStream()); + assertNull(frames.poll()); + } finally { + ch.finishAndReleaseAll(); + } + } + + @Test + public void testEncodeNonEmptyFullRequest() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + assertTrue(ch.writeOutbound(new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.PUT, "/hello/world", hello))); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + Http2Headers headers = headersFrame.headers(); + + 
assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), is("PUT")); + assertThat(headers.path().toString(), is("/hello/world")); + assertFalse(headersFrame.isEndStream()); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertTrue(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeEmptyFullRequestWithTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + FullHttpRequest request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.PUT, "/hello/world"); + + HttpHeaders trailers = request.trailingHeaders(); + trailers.set("key", "value"); + assertTrue(ch.writeOutbound(request)); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), is("PUT")); + assertThat(headers.path().toString(), is("/hello/world")); + assertFalse(headersFrame.isEndStream()); + + Http2HeadersFrame trailersFrame = ch.readOutbound(); + assertThat(trailersFrame.headers().get("key").toString(), is("value")); + assertTrue(trailersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeNonEmptyFullRequestWithTrailers() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + FullHttpRequest request = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.PUT, "/hello/world", hello); + + HttpHeaders trailers = request.trailingHeaders(); + trailers.set("key", "value"); + 
assertTrue(ch.writeOutbound(request)); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), is("PUT")); + assertThat(headers.path().toString(), is("/hello/world")); + assertFalse(headersFrame.isEndStream()); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + Http2HeadersFrame trailersFrame = ch.readOutbound(); + assertThat(trailersFrame.headers().get("key").toString(), is("value")); + assertTrue(trailersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeRequestHeaders() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + HttpRequest request = new DefaultHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"); + assertTrue(ch.writeOutbound(request)); + + Http2HeadersFrame headersFrame = ch.readOutbound(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertFalse(headersFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeChunkAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + HttpContent content = new DefaultHttpContent(hello); + assertTrue(ch.writeOutbound(content)); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), 
is("hello world")); + assertFalse(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeEmptyEndAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + LastHttpContent end = LastHttpContent.EMPTY_LAST_CONTENT; + assertTrue(ch.writeOutbound(end)); + + Http2DataFrame emptyFrame = ch.readOutbound(); + try { + assertThat(emptyFrame.content().readableBytes(), is(0)); + assertTrue(emptyFrame.isEndStream()); + } finally { + emptyFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeDataEndAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + LastHttpContent end = new DefaultLastHttpContent(hello, true); + assertTrue(ch.writeOutbound(end)); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertTrue(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testEncodeTrailersAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + LastHttpContent trailers = new DefaultLastHttpContent(Unpooled.EMPTY_BUFFER, true); + HttpHeaders headers = trailers.trailingHeaders(); + headers.set("key", "value"); + assertTrue(ch.writeOutbound(trailers)); + + Http2HeadersFrame headerFrame = ch.readOutbound(); + assertThat(headerFrame.headers().get("key").toString(), is("value")); + assertTrue(headerFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + 
@Test + public void testEncodeDataEndWithTrailersAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + LastHttpContent trailers = new DefaultLastHttpContent(hello, true); + HttpHeaders headers = trailers.trailingHeaders(); + headers.set("key", "value"); + assertTrue(ch.writeOutbound(trailers)); + + Http2DataFrame dataFrame = ch.readOutbound(); + try { + assertThat(dataFrame.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(dataFrame.isEndStream()); + } finally { + dataFrame.release(); + } + + Http2HeadersFrame headerFrame = ch.readOutbound(); + assertThat(headerFrame.headers().get("key").toString(), is("value")); + assertTrue(headerFrame.isEndStream()); + + assertThat(ch.readOutbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void decode100ContinueHttp2HeadersAsFullHttpResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.scheme(HttpScheme.HTTP.name()); + headers.status(HttpResponseStatus.CONTINUE.codeAsText()); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers, false))); + + final FullHttpResponse response = ch.readInbound(); + try { + assertThat(response.status(), is(HttpResponseStatus.CONTINUE)); + assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1)); + } finally { + response.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + /** + * An informational response using a 1xx status code other than 101 is + * transmitted as a HEADERS frame, followed by zero or more CONTINUATION + * frames. + * Trailing header fields are sent as a header block after both the + * request or response header block and all the DATA frames have been + * sent. 
The HEADERS frame starting the trailers header block has the + * END_STREAM flag set. + */ + @Test + public void decode103EarlyHintsHttp2HeadersAsFullHttpResponse() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.scheme(HttpScheme.HTTP.name()); + headers.status(HttpResponseStatus.EARLY_HINTS.codeAsText()); + headers.set("key", "value"); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers, false))); + + final FullHttpResponse response = ch.readInbound(); + try { + assertThat(response.status(), is(HttpResponseStatus.EARLY_HINTS)); + assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertThat(response.headers().get("key"), is("value")); + } finally { + response.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDecodeResponseHeaders() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.scheme(HttpScheme.HTTP.name()); + headers.status(HttpResponseStatus.OK.codeAsText()); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers))); + + HttpResponse response = ch.readInbound(); + assertThat(response.status(), is(HttpResponseStatus.OK)); + assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertFalse(response instanceof FullHttpResponse); + assertTrue(HttpUtil.isTransferEncodingChunked(response)); + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDecodeResponseHeadersWithContentLength() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.scheme(HttpScheme.HTTP.name()); + headers.status(HttpResponseStatus.OK.codeAsText()); 
+ headers.setInt("content-length", 0); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers))); + + HttpResponse response = ch.readInbound(); + assertThat(response.status(), is(HttpResponseStatus.OK)); + assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertFalse(response instanceof FullHttpResponse); + assertFalse(HttpUtil.isTransferEncodingChunked(response)); + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @ParameterizedTest() + @ValueSource(strings = {"204", "304"}) + public void testDecodeResponseHeadersContentAlwaysEmpty(String statusCode) { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.scheme(HttpScheme.HTTP.name()); + headers.status(statusCode); + + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers))); + + HttpResponse request = ch.readInbound(); + assertThat(request.status().codeAsText().toString(), is(statusCode)); + assertThat(request.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertThat(request, is(not(instanceOf(FullHttpResponse.class)))); + assertFalse(HttpUtil.isTransferEncodingChunked(request)); + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDecodeFullResponseHeaders() throws Exception { + testDecodeFullResponseHeaders(false); + } + + @Test + public void testDecodeFullResponseHeadersWithStreamID() throws Exception { + testDecodeFullResponseHeaders(true); + } + + private void testDecodeFullResponseHeaders(boolean withStreamId) throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.scheme(HttpScheme.HTTP.name()); + headers.status(HttpResponseStatus.OK.codeAsText()); + + Http2HeadersFrame frame = new DefaultHttp2HeadersFrame(headers, true); + if (withStreamId) { + frame.stream(new 
Http2FrameStream() { + @Override + public int id() { + return 1; + } + + @Override + public Http2Stream.State state() { + return Http2Stream.State.OPEN; + } + }); + } + + assertTrue(ch.writeInbound(frame)); + + FullHttpResponse response = ch.readInbound(); + try { + assertThat(response.status(), is(HttpResponseStatus.OK)); + assertThat(response.protocolVersion(), is(HttpVersion.HTTP_1_1)); + assertThat(response.content().readableBytes(), is(0)); + assertTrue(response.trailingHeaders().isEmpty()); + assertFalse(HttpUtil.isTransferEncodingChunked(response)); + if (withStreamId) { + assertEquals(1, + (int) response.headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text())); + } + } finally { + response.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDecodeResponseTrailersAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2Headers headers = new DefaultHttp2Headers(); + headers.set("key", "value"); + assertTrue(ch.writeInbound(new DefaultHttp2HeadersFrame(headers, true))); + + LastHttpContent trailers = ch.readInbound(); + try { + assertThat(trailers.content().readableBytes(), is(0)); + assertThat(trailers.trailingHeaders().get("key"), is("value")); + assertFalse(trailers instanceof FullHttpRequest); + } finally { + trailers.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDecodeDataAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + assertTrue(ch.writeInbound(new DefaultHttp2DataFrame(hello))); + + HttpContent content = ch.readInbound(); + try { + assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertFalse(content instanceof LastHttpContent); + } 
finally { + content.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testDecodeEndDataAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + ByteBuf hello = Unpooled.copiedBuffer("hello world", CharsetUtil.UTF_8); + assertTrue(ch.writeInbound(new DefaultHttp2DataFrame(hello, true))); + + LastHttpContent content = ch.readInbound(); + try { + assertThat(content.content().toString(CharsetUtil.UTF_8), is("hello world")); + assertTrue(content.trailingHeaders().isEmpty()); + } finally { + content.release(); + } + + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } + + @Test + public void testPassThroughOtherAsClient() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new Http2StreamFrameToHttpObjectCodec(false)); + Http2ResetFrame reset = new DefaultHttp2ResetFrame(0); + Http2GoAwayFrame goaway = new DefaultHttp2GoAwayFrame(0); + assertTrue(ch.writeInbound(reset)); + assertTrue(ch.writeInbound(goaway.retain())); + + assertEquals(reset, ch.readInbound()); + + Http2GoAwayFrame frame = ch.readInbound(); + try { + assertEquals(goaway, frame); + assertThat(ch.readInbound(), is(nullValue())); + assertFalse(ch.finish()); + } finally { + goaway.release(); + frame.release(); + } + } + + @Test + public void testIsSharableBetweenChannels() throws Exception { + final Queue frames = new ConcurrentLinkedQueue(); + final ChannelHandler sharedHandler = new Http2StreamFrameToHttpObjectCodec(false); + + final SslContext ctx = SslContextBuilder.forClient().sslProvider(SslProvider.JDK).build(); + EmbeddedChannel tlsCh = new EmbeddedChannel(ctx.newHandler(ByteBufAllocator.DEFAULT), + new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof Http2StreamFrame) { + frames.add((Http2StreamFrame) msg); + 
promise.setSuccess(); + } else { + ctx.write(msg, promise); + } + } + }, sharedHandler); + + EmbeddedChannel plaintextCh = new EmbeddedChannel( + new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof Http2StreamFrame) { + frames.add((Http2StreamFrame) msg); + promise.setSuccess(); + } else { + ctx.write(msg, promise); + } + } + }, sharedHandler); + + FullHttpRequest req = new DefaultFullHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"); + assertTrue(tlsCh.writeOutbound(req)); + assertTrue(tlsCh.finishAndReleaseAll()); + + Http2HeadersFrame headersFrame = (Http2HeadersFrame) frames.poll(); + Http2Headers headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("https")); + assertThat(headers.method().toString(), is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertTrue(headersFrame.isEndStream()); + assertNull(frames.poll()); + + // Run the plaintext channel + req = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/hello/world"); + assertFalse(plaintextCh.writeOutbound(req)); + assertFalse(plaintextCh.finishAndReleaseAll()); + + headersFrame = (Http2HeadersFrame) frames.poll(); + headers = headersFrame.headers(); + + assertThat(headers.scheme().toString(), is("http")); + assertThat(headers.method().toString(), is("GET")); + assertThat(headers.path().toString(), is("/hello/world")); + assertTrue(headersFrame.isEndStream()); + assertNull(frames.poll()); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java new file mode 100644 index 0000000..2356ca3 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/Http2TestUtil.java @@ -0,0 +1,538 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project 
licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.util.AsciiString; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.Random; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; + +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_LIST_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_HEADER_TABLE_SIZE; +import static io.netty.util.ReferenceCountUtil.release; +import static java.lang.Math.min; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.ArgumentMatchers.any; +import static 
org.mockito.ArgumentMatchers.anyBoolean; +import static org.mockito.ArgumentMatchers.anyByte; +import static org.mockito.ArgumentMatchers.anyInt; +import static org.mockito.ArgumentMatchers.anyLong; +import static org.mockito.ArgumentMatchers.anyShort; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.when; + +/** + * Utilities for the integration tests. + */ +public final class Http2TestUtil { + /** + * Interface that allows for running a operation that throws a {@link Http2Exception}. + */ + interface Http2Runnable { + void run() throws Http2Exception; + } + + /** + * Runs the given operation within the event loop thread of the given {@link Channel}. + */ + static void runInChannel(Channel channel, final Http2Runnable runnable) { + channel.eventLoop().execute(new Runnable() { + @Override + public void run() { + try { + runnable.run(); + } catch (Http2Exception e) { + throw new RuntimeException(e); + } + } + }); + } + + /** + * Returns a byte array filled with random data. + */ + public static byte[] randomBytes() { + return randomBytes(100); + } + + /** + * Returns a byte array filled with random data. + */ + public static byte[] randomBytes(int size) { + byte[] data = new byte[size]; + new Random().nextBytes(data); + return data; + } + + /** + * Returns an {@link AsciiString} that wraps a randomly-filled byte array. 
+ */ + public static AsciiString randomString() { + return new AsciiString(randomBytes()); + } + + public static CharSequence of(String s) { + return s; + } + + public static HpackEncoder newTestEncoder() { + try { + return newTestEncoder(true, MAX_HEADER_LIST_SIZE, MAX_HEADER_TABLE_SIZE); + } catch (Http2Exception e) { + throw new Error("max size not allowed?", e); + } + } + + public static HpackEncoder newTestEncoder(boolean ignoreMaxHeaderListSize, + long maxHeaderListSize, long maxHeaderTableSize) throws Http2Exception { + HpackEncoder hpackEncoder = new HpackEncoder(false, 16, 0); + ByteBuf buf = Unpooled.buffer(); + try { + hpackEncoder.setMaxHeaderTableSize(buf, maxHeaderTableSize); + hpackEncoder.setMaxHeaderListSize(maxHeaderListSize); + } finally { + buf.release(); + } + return hpackEncoder; + } + + public static HpackDecoder newTestDecoder() { + try { + return newTestDecoder(MAX_HEADER_LIST_SIZE, MAX_HEADER_TABLE_SIZE); + } catch (Http2Exception e) { + throw new Error("max size not allowed?", e); + } + } + + public static HpackDecoder newTestDecoder(long maxHeaderListSize, long maxHeaderTableSize) throws Http2Exception { + HpackDecoder hpackDecoder = new HpackDecoder(maxHeaderListSize); + hpackDecoder.setMaxHeaderTableSize(maxHeaderTableSize); + return hpackDecoder; + } + + private Http2TestUtil() { + } + + /** + * A decorator around a {@link Http2FrameListener} that counts down the latch so that we can await the completion of + * the request. 
+ */ + static class FrameCountDown implements Http2FrameListener { + private final Http2FrameListener listener; + private final CountDownLatch messageLatch; + private final CountDownLatch settingsAckLatch; + private final CountDownLatch dataLatch; + private final CountDownLatch trailersLatch; + private final CountDownLatch goAwayLatch; + + FrameCountDown(Http2FrameListener listener, CountDownLatch settingsAckLatch, CountDownLatch messageLatch, + CountDownLatch dataLatch, CountDownLatch trailersLatch) { + this(listener, settingsAckLatch, messageLatch, dataLatch, trailersLatch, messageLatch); + } + + FrameCountDown(Http2FrameListener listener, CountDownLatch settingsAckLatch, CountDownLatch messageLatch, + CountDownLatch dataLatch, CountDownLatch trailersLatch, CountDownLatch goAwayLatch) { + this.listener = listener; + this.messageLatch = messageLatch; + this.settingsAckLatch = settingsAckLatch; + this.dataLatch = dataLatch; + this.trailersLatch = trailersLatch; + this.goAwayLatch = goAwayLatch; + } + + @Override + public int onDataRead(ChannelHandlerContext ctx, int streamId, ByteBuf data, int padding, boolean endOfStream) + throws Http2Exception { + int numBytes = data.readableBytes(); + int processed = listener.onDataRead(ctx, streamId, data, padding, endOfStream); + messageLatch.countDown(); + if (dataLatch != null) { + for (int i = 0; i < numBytes; ++i) { + dataLatch.countDown(); + } + } + return processed; + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding, + boolean endStream) throws Http2Exception { + listener.onHeadersRead(ctx, streamId, headers, padding, endStream); + messageLatch.countDown(); + if (trailersLatch != null && endStream) { + trailersLatch.countDown(); + } + } + + @Override + public void onHeadersRead(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int streamDependency, + short weight, boolean exclusive, int padding, boolean endStream) throws Http2Exception { 
+ listener.onHeadersRead(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endStream); + messageLatch.countDown(); + if (trailersLatch != null && endStream) { + trailersLatch.countDown(); + } + } + + @Override + public void onPriorityRead(ChannelHandlerContext ctx, int streamId, int streamDependency, short weight, + boolean exclusive) throws Http2Exception { + listener.onPriorityRead(ctx, streamId, streamDependency, weight, exclusive); + messageLatch.countDown(); + } + + @Override + public void onRstStreamRead(ChannelHandlerContext ctx, int streamId, long errorCode) throws Http2Exception { + listener.onRstStreamRead(ctx, streamId, errorCode); + messageLatch.countDown(); + } + + @Override + public void onSettingsAckRead(ChannelHandlerContext ctx) throws Http2Exception { + listener.onSettingsAckRead(ctx); + settingsAckLatch.countDown(); + } + + @Override + public void onSettingsRead(ChannelHandlerContext ctx, Http2Settings settings) throws Http2Exception { + listener.onSettingsRead(ctx, settings); + messageLatch.countDown(); + } + + @Override + public void onPingRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + listener.onPingRead(ctx, data); + messageLatch.countDown(); + } + + @Override + public void onPingAckRead(ChannelHandlerContext ctx, long data) throws Http2Exception { + listener.onPingAckRead(ctx, data); + messageLatch.countDown(); + } + + @Override + public void onPushPromiseRead(ChannelHandlerContext ctx, int streamId, int promisedStreamId, + Http2Headers headers, int padding) throws Http2Exception { + listener.onPushPromiseRead(ctx, streamId, promisedStreamId, headers, padding); + messageLatch.countDown(); + } + + @Override + public void onGoAwayRead(ChannelHandlerContext ctx, int lastStreamId, long errorCode, ByteBuf debugData) + throws Http2Exception { + listener.onGoAwayRead(ctx, lastStreamId, errorCode, debugData); + goAwayLatch.countDown(); + } + + @Override + public void onWindowUpdateRead(ChannelHandlerContext 
ctx, int streamId, int windowSizeIncrement) + throws Http2Exception { + listener.onWindowUpdateRead(ctx, streamId, windowSizeIncrement); + messageLatch.countDown(); + } + + @Override + public void onUnknownFrame(ChannelHandlerContext ctx, byte frameType, int streamId, Http2Flags flags, + ByteBuf payload) throws Http2Exception { + listener.onUnknownFrame(ctx, frameType, streamId, flags, payload); + messageLatch.countDown(); + } + } + + static ChannelPromise newVoidPromise(final Channel channel) { + return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE) { + @Override + public ChannelPromise addListener( + GenericFutureListener> listener) { + fail(); + return null; + } + + @Override + public ChannelPromise addListeners( + GenericFutureListener>... listeners) { + fail(); + return null; + } + + @Override + public boolean isVoid() { + return true; + } + + @Override + public boolean tryFailure(Throwable cause) { + channel().pipeline().fireExceptionCaught(cause); + return true; + } + + @Override + public ChannelPromise setFailure(Throwable cause) { + tryFailure(cause); + return this; + } + + @Override + public ChannelPromise unvoid() { + ChannelPromise promise = + new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + promise.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + channel().pipeline().fireExceptionCaught(future.cause()); + } + } + }); + return promise; + } + }; + } + + static final class TestStreamByteDistributorStreamState implements StreamByteDistributor.StreamState { + private final Http2Stream stream; + boolean isWriteAllowed; + long pendingBytes; + boolean hasFrame; + + TestStreamByteDistributorStreamState(Http2Stream stream, long pendingBytes, boolean hasFrame, + boolean isWriteAllowed) { + this.stream = stream; + this.isWriteAllowed = isWriteAllowed; + this.pendingBytes = pendingBytes; + this.hasFrame = hasFrame; 
+ } + + @Override + public Http2Stream stream() { + return stream; + } + + @Override + public long pendingBytes() { + return pendingBytes; + } + + @Override + public boolean hasFrame() { + return hasFrame; + } + + @Override + public int windowSize() { + return isWriteAllowed ? (int) min(pendingBytes, Integer.MAX_VALUE) : -1; + } + } + + static Http2FrameWriter mockedFrameWriter() { + Http2FrameWriter.Configuration configuration = new Http2FrameWriter.Configuration() { + private final Http2HeadersEncoder.Configuration headerConfiguration = + new Http2HeadersEncoder.Configuration() { + @Override + public void maxHeaderTableSize(long max) { + // NOOP + } + + @Override + public long maxHeaderTableSize() { + return 0; + } + + @Override + public void maxHeaderListSize(long max) { + // NOOP + } + + @Override + public long maxHeaderListSize() { + return 0; + } + }; + + private final Http2FrameSizePolicy policy = new Http2FrameSizePolicy() { + @Override + public void maxFrameSize(int max) { + // NOOP + } + + @Override + public int maxFrameSize() { + return 0; + } + }; + @Override + public Http2HeadersEncoder.Configuration headersConfiguration() { + return headerConfiguration; + } + + @Override + public Http2FrameSizePolicy frameSizePolicy() { + return policy; + } + }; + + final ConcurrentLinkedQueue buffers = new ConcurrentLinkedQueue(); + + Http2FrameWriter frameWriter = Mockito.mock(Http2FrameWriter.class); + doAnswer(new Answer() { + @Override + public Object answer(InvocationOnMock invocationOnMock) { + for (;;) { + ByteBuf buf = buffers.poll(); + if (buf == null) { + break; + } + buf.release(); + } + return null; + } + }).when(frameWriter).close(); + + when(frameWriter.configuration()).thenReturn(configuration); + when(frameWriter.writeSettings(any(ChannelHandlerContext.class), any(Http2Settings.class), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) 
invocationOnMock.getArgument(2)).setSuccess(); + } + }); + + when(frameWriter.writeSettingsAck(any(ChannelHandlerContext.class), any(ChannelPromise.class))) + .thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(1)).setSuccess(); + } + }); + + when(frameWriter.writeGoAway(any(ChannelHandlerContext.class), anyInt(), + anyLong(), any(ByteBuf.class), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + buffers.offer((ByteBuf) invocationOnMock.getArgument(3)); + return ((ChannelPromise) invocationOnMock.getArgument(4)).setSuccess(); + } + }); + when(frameWriter.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), + anyBoolean(), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + when(frameWriter.writeHeaders(any(ChannelHandlerContext.class), anyInt(), + any(Http2Headers.class), anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(8)).setSuccess(); + } + }); + + when(frameWriter.writeData(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), + anyBoolean(), any(ChannelPromise.class))).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + buffers.offer((ByteBuf) invocationOnMock.getArgument(2)); + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + when(frameWriter.writeRstStream(any(ChannelHandlerContext.class), anyInt(), + anyLong(), any(ChannelPromise.class))).thenAnswer(new 
Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(3)).setSuccess(); + } + }); + + when(frameWriter.writeWindowUpdate(any(ChannelHandlerContext.class), anyInt(), anyInt(), + any(ChannelPromise.class))).then(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(3)).setSuccess(); + } + }); + + when(frameWriter.writePushPromise(any(ChannelHandlerContext.class), anyInt(), anyInt(), any(Http2Headers.class), + anyInt(), anyChannelPromise())).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + + when(frameWriter.writeFrame(any(ChannelHandlerContext.class), anyByte(), anyInt(), any(Http2Flags.class), + any(ByteBuf.class), anyChannelPromise())).thenAnswer(new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocationOnMock) { + buffers.offer((ByteBuf) invocationOnMock.getArgument(4)); + return ((ChannelPromise) invocationOnMock.getArgument(5)).setSuccess(); + } + }); + return frameWriter; + } + + static ChannelPromise anyChannelPromise() { + return any(ChannelPromise.class); + } + + static Http2Settings anyHttp2Settings() { + return any(Http2Settings.class); + } + + static ByteBuf bb(String s) { + return ByteBufUtil.writeUtf8(UnpooledByteBufAllocator.DEFAULT, s); + } + + static ByteBuf bb(int size) { + return UnpooledByteBufAllocator.DEFAULT.buffer().writeZero(size); + } + + static void assertEqualsAndRelease(Http2Frame expected, Http2Frame actual) { + try { + assertEquals(expected, actual); + } finally { + release(expected); + release(actual); + // Will return -1 when not implements ReferenceCounted. 
+ assertTrue(ReferenceCountUtil.refCnt(expected) <= 0); + assertTrue(ReferenceCountUtil.refCnt(actual) <= 0); + } + } + +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java new file mode 100644 index 0000000..365dbda --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpConversionUtilTest.java @@ -0,0 +1,279 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http.DefaultHttpHeaders; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static io.netty.handler.codec.http.HttpHeaderNames.CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.COOKIE; +import static io.netty.handler.codec.http.HttpHeaderNames.HOST; +import static io.netty.handler.codec.http.HttpHeaderNames.KEEP_ALIVE; +import static io.netty.handler.codec.http.HttpHeaderNames.PROXY_CONNECTION; +import static io.netty.handler.codec.http.HttpHeaderNames.TE; +import static io.netty.handler.codec.http.HttpHeaderNames.TRANSFER_ENCODING; +import static io.netty.handler.codec.http.HttpHeaderNames.UPGRADE; +import static io.netty.handler.codec.http.HttpHeaderValues.GZIP; +import static io.netty.handler.codec.http.HttpHeaderValues.TRAILERS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class HttpConversionUtilTest { + + @Test + public void connectNoPath() throws Exception { + String authority = "netty.io:80"; + Http2Headers headers = new DefaultHttp2Headers(); + headers.authority(authority); + headers.method(HttpMethod.CONNECT.asciiName()); + HttpRequest request = HttpConversionUtil.toHttpRequest(0, headers, true); + assertNotNull(request); + assertEquals(authority, 
request.uri()); + assertEquals(authority, request.headers().get(HOST)); + } + + @Test + public void setHttp2AuthorityWithoutUserInfo() { + Http2Headers headers = new DefaultHttp2Headers(); + + HttpConversionUtil.setHttp2Authority("foo", headers); + assertEquals(new AsciiString("foo"), headers.authority()); + } + + @Test + public void setHttp2AuthorityWithUserInfo() { + Http2Headers headers = new DefaultHttp2Headers(); + + HttpConversionUtil.setHttp2Authority("info@foo", headers); + assertEquals(new AsciiString("foo"), headers.authority()); + + HttpConversionUtil.setHttp2Authority("@foo.bar", headers); + assertEquals(new AsciiString("foo.bar"), headers.authority()); + } + + @Test + public void setHttp2AuthorityNullOrEmpty() { + Http2Headers headers = new DefaultHttp2Headers(); + + HttpConversionUtil.setHttp2Authority(null, headers); + assertNull(headers.authority()); + + // https://datatracker.ietf.org/doc/html/rfc9113#section-8.3.1 + // Clients that generate HTTP/2 requests directly MUST use the ":authority" pseudo-header + // field to convey authority information, unless there is no authority information to convey + // (in which case it MUST NOT generate ":authority"). + // An intermediary that forwards a request over HTTP/2 MUST construct an ":authority" pseudo-header + // field using the authority information from the control data of the original request, unless the + // original request's target URI does not contain authority information + // (in which case it MUST NOT generate ":authority"). 
+ assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() { + HttpConversionUtil.setHttp2Authority("", new DefaultHttp2Headers()); + } + }); + } + + @Test + public void setHttp2AuthorityWithEmptyAuthority() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + HttpConversionUtil.setHttp2Authority("info@", new DefaultHttp2Headers()); + } + }); + } + + @Test + public void stripTEHeaders() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(TE, GZIP); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertTrue(out.isEmpty()); + } + + @Test + public void stripTEHeadersExcludingTrailers() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(TE, GZIP); + inHeaders.add(TE, TRAILERS); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertSame(TRAILERS, out.get(TE)); + } + + @Test + public void stripTEHeadersCsvSeparatedExcludingTrailers() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(TE, GZIP + "," + TRAILERS); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertSame(TRAILERS, out.get(TE)); + } + + @Test + public void stripTEHeadersCsvSeparatedAccountsForValueSimilarToTrailers() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(TE, GZIP + "," + TRAILERS + "foo"); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertFalse(out.contains(TE)); + } + + @Test + public void stripTEHeadersAccountsForValueSimilarToTrailers() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(TE, TRAILERS + "foo"); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertFalse(out.contains(TE)); + } + + @Test + public void 
stripTEHeadersAccountsForOWS() { + HttpHeaders inHeaders = new DefaultHttpHeaders(false); + inHeaders.add(TE, " " + TRAILERS + ' '); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertSame(TRAILERS, out.get(TE)); + } + + @Test + public void stripConnectionHeadersAndNominees() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(CONNECTION, "foo"); + inHeaders.add("foo", "bar"); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertTrue(out.isEmpty()); + } + + @Test + public void stripConnectionNomineesWithCsv() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(CONNECTION, "foo, bar"); + inHeaders.add("foo", "baz"); + inHeaders.add("bar", "qux"); + inHeaders.add("hello", "world"); + Http2Headers out = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, out); + assertEquals(1, out.size()); + assertSame("world", out.get("hello")); + } + + @Test + public void handlesRequest() throws Exception { + boolean validateHeaders = true; + HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "http://example.com/path/to/something", validateHeaders); + HttpHeaders inHeaders = msg.headers(); + inHeaders.add(CONNECTION, "foo, bar"); + inHeaders.add("hello", "world"); + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, validateHeaders); + assertEquals(new AsciiString("/path/to/something"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertEquals(new AsciiString("example.com"), out.authority()); + assertEquals(HttpMethod.GET.asciiName(), out.method()); + assertEquals("world", out.get("hello")); + } + + @Test + public void handlesRequestWithDoubleSlashPath() throws Exception { + boolean validateHeaders = true; + HttpRequest msg = new DefaultHttpRequest( + HttpVersion.HTTP_1_1, HttpMethod.GET, "//path/to/something", validateHeaders); + HttpHeaders 
inHeaders = msg.headers(); + inHeaders.add(CONNECTION, "foo, bar"); + inHeaders.add(HOST, "example.com"); + inHeaders.add(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + inHeaders.add("hello", "world"); + Http2Headers out = HttpConversionUtil.toHttp2Headers(msg, validateHeaders); + assertEquals(new AsciiString("//path/to/something"), out.path()); + assertEquals(new AsciiString("http"), out.scheme()); + assertEquals(new AsciiString("example.com"), out.authority()); + assertEquals(HttpMethod.GET.asciiName(), out.method()); + } + + @Test + public void addHttp2ToHttpHeadersCombinesCookies() throws Http2Exception { + Http2Headers inHeaders = new DefaultHttp2Headers(); + inHeaders.add("yes", "no"); + inHeaders.add(COOKIE, "foo=bar"); + inHeaders.add(COOKIE, "bax=baz"); + + HttpHeaders outHeaders = new DefaultHttpHeaders(); + + HttpConversionUtil.addHttp2ToHttpHeaders(5, inHeaders, outHeaders, HttpVersion.HTTP_1_1, false, false); + assertEquals("no", outHeaders.get("yes")); + assertEquals("foo=bar; bax=baz", outHeaders.get(COOKIE.toString())); + } + + @Test + public void connectionSpecificHeadersShouldBeRemoved() { + HttpHeaders inHeaders = new DefaultHttpHeaders(); + inHeaders.add(CONNECTION, "keep-alive"); + inHeaders.add(HOST, "example.com"); + @SuppressWarnings("deprecation") + AsciiString keepAlive = KEEP_ALIVE; + inHeaders.add(keepAlive, "timeout=5, max=1000"); + @SuppressWarnings("deprecation") + AsciiString proxyConnection = PROXY_CONNECTION; + inHeaders.add(proxyConnection, "timeout=5, max=1000"); + inHeaders.add(TRANSFER_ENCODING, "chunked"); + inHeaders.add(UPGRADE, "h2c"); + + Http2Headers outHeaders = new DefaultHttp2Headers(); + HttpConversionUtil.toHttp2Headers(inHeaders, outHeaders); + + assertFalse(outHeaders.contains(CONNECTION)); + assertFalse(outHeaders.contains(HOST)); + assertFalse(outHeaders.contains(keepAlive)); + assertFalse(outHeaders.contains(proxyConnection)); + assertFalse(outHeaders.contains(TRANSFER_ENCODING)); + 
assertFalse(outHeaders.contains(UPGRADE)); + } + + @Test + public void http2ToHttpHeaderTest() throws Exception { + Http2Headers http2Headers = new DefaultHttp2Headers(); + http2Headers.status("200"); + http2Headers.path("/meow"); // HTTP/2 Header response should not contain 'path' in response. + http2Headers.set("cat", "meow"); + + HttpHeaders httpHeaders = new DefaultHttpHeaders(); + HttpConversionUtil.addHttp2ToHttpHeaders(3, http2Headers, httpHeaders, HttpVersion.HTTP_1_1, false, true); + assertFalse(httpHeaders.contains(HttpConversionUtil.ExtensionHeaderNames.PATH.text())); + assertEquals("meow", httpHeaders.get("cat")); + + httpHeaders.clear(); + HttpConversionUtil.addHttp2ToHttpHeaders(3, http2Headers, httpHeaders, HttpVersion.HTTP_1_1, false, false); + assertTrue(httpHeaders.contains(HttpConversionUtil.ExtensionHeaderNames.PATH.text())); + assertEquals("meow", httpHeaders.get("cat")); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java new file mode 100644 index 0000000..465c173 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/HttpToHttp2ConnectionHandlerTest.java @@ -0,0 +1,640 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultHttpContent; +import io.netty.handler.codec.http.DefaultHttpRequest; +import io.netty.handler.codec.http.DefaultLastHttpContent; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpRequest; +import io.netty.handler.codec.http.HttpScheme; +import io.netty.handler.codec.http.LastHttpContent; +import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown; +import io.netty.util.AsciiString; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static io.netty.handler.codec.http.HttpMethod.CONNECT; +import static io.netty.handler.codec.http.HttpMethod.GET; +import static io.netty.handler.codec.http.HttpMethod.OPTIONS; +import static 
io.netty.handler.codec.http.HttpMethod.POST; +import static io.netty.handler.codec.http.HttpVersion.HTTP_1_1; +import static io.netty.handler.codec.http2.Http2TestUtil.of; +import static io.netty.util.CharsetUtil.UTF_8; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.instanceOf; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +/** + * Testing the {@link HttpToHttp2ConnectionHandler} for {@link FullHttpRequest} objects into HTTP/2 frames + */ +public class HttpToHttp2ConnectionHandlerTest { + private static final int WAIT_TIME_SECONDS = 5; + + @Mock + private Http2FrameListener clientListener; + + @Mock + private Http2FrameListener serverListener; + + private ServerBootstrap sb; + private Bootstrap cb; + private Channel serverChannel; + private volatile Channel serverConnectedChannel; + private Channel clientChannel; + private CountDownLatch requestLatch; + private CountDownLatch serverSettingsAckLatch; + private CountDownLatch trailersLatch; + private FrameCountDown serverFrameCountDown; + + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + } + + @AfterEach + public void teardown() throws Exception { + if (clientChannel != null) { + clientChannel.close().syncUninterruptibly(); + clientChannel = null; + } + if (serverChannel != null) { + serverChannel.close().syncUninterruptibly(); + serverChannel = null; + } + final Channel serverConnectedChannel = 
this.serverConnectedChannel; + if (serverConnectedChannel != null) { + serverConnectedChannel.close().syncUninterruptibly(); + this.serverConnectedChannel = null; + } + Future serverGroup = sb.config().group().shutdownGracefully(0, 5, SECONDS); + Future serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 5, SECONDS); + Future clientGroup = cb.config().group().shutdownGracefully(0, 5, SECONDS); + serverGroup.syncUninterruptibly(); + serverChildGroup.syncUninterruptibly(); + clientGroup.syncUninterruptibly(); + } + + @Test + public void testHeadersOnlyRequest() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, + "http://my-user_name@www.example.org:5555/example"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "my-user_name@www.example.org:5555"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + httpHeaders.add(of("foo"), of("goo")); + httpHeaders.add(of("foo"), of("goo2")); + httpHeaders.add(of("foo2"), of("goo2")); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")).path(new AsciiString("/example")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("http")) + .add(new AsciiString("foo"), new AsciiString("goo")) + .add(new AsciiString("foo"), new AsciiString("goo2")) + .add(new AsciiString("foo2"), new AsciiString("goo2")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testHttpScheme() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, + "http://my-user_name@www.example.org:5555/example"); + final HttpHeaders httpHeaders = request.headers(); + 
httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "my-user_name@www.example.org:5555"); + httpHeaders.add(of("foo"), of("goo")); + httpHeaders.add(of("foo"), of("goo2")); + httpHeaders.add(of("foo2"), of("goo2")); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")).path(new AsciiString("/example")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("http")) + .scheme(new AsciiString("http")) + .add(new AsciiString("foo"), new AsciiString("goo")) + .add(new AsciiString("foo"), new AsciiString("goo2")) + .add(new AsciiString("foo2"), new AsciiString("goo2")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testMultipleCookieEntriesAreCombined() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, + "http://my-user_name@www.example.org:5555/example"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "my-user_name@www.example.org:5555"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + httpHeaders.set(HttpHeaderNames.COOKIE, "a=b; c=d; e=f"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")).path(new AsciiString("/example")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("http")) + .add(HttpHeaderNames.COOKIE, "a=b") + .add(HttpHeaderNames.COOKIE, "c=d") + .add(HttpHeaderNames.COOKIE, "e=f"); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testOriginFormRequestTargetHandled() 
throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/where?q=now&f=then#section1"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")) + .path(new AsciiString("/where?q=now&f=then#section1")) + .scheme(new AsciiString("http")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testOriginFormRequestTargetHandledFromUrlencodedUri() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest( + HTTP_1_1, GET, "/where%2B0?q=now%2B0&f=then%2B0#section1%2B0"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")) + .path(new AsciiString("/where%2B0?q=now%2B0&f=then%2B0#section1%2B0")) + .scheme(new AsciiString("http")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testAbsoluteFormRequestTargetHandledFromHeaders() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/pub/WWW/TheProject.html"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "foouser@www.example.org:5555"); + 
httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.PATH.text(), "ignored_path"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "https"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")) + .path(new AsciiString("/pub/WWW/TheProject.html")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("https")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testAbsoluteFormRequestTargetHandledFromRequestTargetUri() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, + "http://foouser@www.example.org:5555/pub/WWW/TheProject.html"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")) + .path(new AsciiString("/pub/WWW/TheProject.html")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("http")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testAuthorityFormRequestTargetHandled() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, CONNECT, "http://www.example.com:80"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("CONNECT")).path(new AsciiString("/")) + .scheme(new AsciiString("http")).authority(new AsciiString("www.example.com:80")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, 
writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testAsterikFormRequestTargetHandled() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, OPTIONS, "*"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "www.example.com:80"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("OPTIONS")).path(new AsciiString("*")) + .scheme(new AsciiString("http")).authority(new AsciiString("www.example.com:80")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testHostIPv6FormRequestTargetHandled() throws Exception { + // Valid according to + // https://tools.ietf.org/html/rfc7230#section-2.7.1 -> https://tools.ietf.org/html/rfc3986#section-3.2.2 + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "[::1]:80"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")).path(new AsciiString("/")) + .scheme(new AsciiString("http")).authority(new AsciiString("[::1]:80")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testHostFormRequestTargetHandled() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new 
DefaultFullHttpRequest(HTTP_1_1, GET, "/"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "localhost:80"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")).path(new AsciiString("/")) + .scheme(new AsciiString("http")).authority(new AsciiString("localhost:80")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testHostIPv4FormRequestTargetHandled() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "1.2.3.4:80"); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("GET")).path(new AsciiString("/")) + .scheme(new AsciiString("http")).authority(new AsciiString("1.2.3.4:80")); + + ChannelPromise writePromise = newPromise(); + verifyHeadersOnly(http2Headers, writePromise, clientChannel.writeAndFlush(request, writePromise)); + } + + @Test + public void testNoSchemeRequestTargetHandled() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, GET, "/"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders.set(HttpHeaderNames.HOST, "localhost"); + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise); + + 
assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isDone()); + assertFalse(writePromise.isSuccess()); + assertTrue(writeFuture.isDone()); + assertFalse(writeFuture.isSuccess()); + } + + @Test + public void testInvalidStreamId() throws Exception { + bootstrapEnv(2, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, "/foo", + Unpooled.copiedBuffer("foobar", UTF_8)); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), -1); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http"); + httpHeaders.set(HttpHeaderNames.HOST, "localhost"); + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise); + + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isDone()); + assertFalse(writePromise.isSuccess()); + Throwable cause = writePromise.cause(); + assertThat(cause, instanceOf(Http2NoMoreStreamIdsException.class)); + + assertTrue(writeFuture.isDone()); + assertFalse(writeFuture.isSuccess()); + cause = writeFuture.cause(); + assertThat(cause, instanceOf(Http2NoMoreStreamIdsException.class)); + } + + @Test + public void testRequestWithBody() throws Exception { + final String text = "foooooogoooo"; + final List receivedBuffers = Collections.synchronizedList(new ArrayList()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8)); + return null; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), eq(true)); + bootstrapEnv(3, 1, 0); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, + "http://your_user-name123@www.example.org:5555/example", + Unpooled.copiedBuffer(text, UTF_8)); + 
final HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpHeaderNames.HOST, "www.example-origin.org:5555"); + httpHeaders.add(of("foo"), of("goo")); + httpHeaders.add(of("foo"), of("goo2")); + httpHeaders.add(of("foo2"), of("goo2")); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("POST")).path(new AsciiString("/example")) + .authority(new AsciiString("www.example-origin.org:5555")).scheme(new AsciiString("http")) + .add(new AsciiString("foo"), new AsciiString("goo")) + .add(new AsciiString("foo"), new AsciiString("goo2")) + .add(new AsciiString("foo2"), new AsciiString("goo2")); + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise); + + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isSuccess()); + assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writeFuture.isSuccess()); + awaitRequests(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0), + anyShort(), anyBoolean(), eq(0), eq(false)); + verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0), + eq(true)); + assertEquals(1, receivedBuffers.size()); + assertEquals(text, receivedBuffers.get(0)); + } + + @Test + public void testRequestWithBodyAndTrailingHeaders() throws Exception { + final String text = "foooooogoooo"; + final List receivedBuffers = Collections.synchronizedList(new ArrayList()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8)); + return null; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), eq(false)); + bootstrapEnv(4, 1, 1); + final FullHttpRequest request = new DefaultFullHttpRequest(HTTP_1_1, POST, + 
"http://your_user-name123@www.example.org:5555/example", + Unpooled.copiedBuffer(text, UTF_8)); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpHeaderNames.HOST, "www.example.org:5555"); + httpHeaders.add(of("foo"), of("goo")); + httpHeaders.add(of("foo"), of("goo2")); + httpHeaders.add(of("foo2"), of("goo2")); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("POST")).path(new AsciiString("/example")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("http")) + .add(new AsciiString("foo"), new AsciiString("goo")) + .add(new AsciiString("foo"), new AsciiString("goo2")) + .add(new AsciiString("foo2"), new AsciiString("goo2")); + + request.trailingHeaders().add(of("trailing"), of("bar")); + + final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers() + .add(new AsciiString("trailing"), new AsciiString("bar")); + + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.writeAndFlush(request, writePromise); + + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isSuccess()); + assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writeFuture.isSuccess()); + awaitRequests(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0), + anyShort(), anyBoolean(), eq(0), eq(false)); + verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0), + eq(false)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0), + anyShort(), anyBoolean(), eq(0), eq(true)); + assertEquals(1, receivedBuffers.size()); + assertEquals(text, receivedBuffers.get(0)); + } + + @Test + public void testChunkedRequestWithBodyAndTrailingHeaders() throws Exception { + final String text = "foooooo"; + final String text2 = "goooo"; + final List 
receivedBuffers = Collections.synchronizedList(new ArrayList()); + doAnswer(new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + receivedBuffers.add(((ByteBuf) in.getArguments()[2]).toString(UTF_8)); + return null; + } + }).when(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), + any(ByteBuf.class), eq(0), eq(false)); + bootstrapEnv(4, 1, 1); + final HttpRequest request = new DefaultHttpRequest(HTTP_1_1, POST, + "http://your_user-name123@www.example.org:5555/example"); + final HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpHeaderNames.HOST, "www.example.org:5555"); + httpHeaders.add(HttpHeaderNames.TRANSFER_ENCODING, "chunked"); + httpHeaders.add(of("foo"), of("goo")); + httpHeaders.add(of("foo"), of("goo2")); + httpHeaders.add(of("foo2"), of("goo2")); + final Http2Headers http2Headers = + new DefaultHttp2Headers().method(new AsciiString("POST")).path(new AsciiString("/example")) + .authority(new AsciiString("www.example.org:5555")).scheme(new AsciiString("http")) + .add(new AsciiString("foo"), new AsciiString("goo")) + .add(new AsciiString("foo"), new AsciiString("goo2")) + .add(new AsciiString("foo2"), new AsciiString("goo2")); + + final DefaultHttpContent httpContent = new DefaultHttpContent(Unpooled.copiedBuffer(text, UTF_8)); + final LastHttpContent lastHttpContent = new DefaultLastHttpContent(Unpooled.copiedBuffer(text2, UTF_8)); + + lastHttpContent.trailingHeaders().add(of("trailing"), of("bar")); + + final Http2Headers http2TrailingHeaders = new DefaultHttp2Headers() + .add(new AsciiString("trailing"), new AsciiString("bar")); + + ChannelPromise writePromise = newPromise(); + ChannelFuture writeFuture = clientChannel.write(request, writePromise); + ChannelPromise contentPromise = newPromise(); + ChannelFuture contentFuture = clientChannel.write(httpContent, contentPromise); + ChannelPromise lastContentPromise = newPromise(); + ChannelFuture lastContentFuture = 
clientChannel.write(lastHttpContent, lastContentPromise); + + clientChannel.flush(); + + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writePromise.isSuccess()); + assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writeFuture.isSuccess()); + + assertTrue(contentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(contentPromise.isSuccess()); + assertTrue(contentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(contentFuture.isSuccess()); + + assertTrue(lastContentPromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(lastContentPromise.isSuccess()); + assertTrue(lastContentFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(lastContentFuture.isSuccess()); + + awaitRequests(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2Headers), eq(0), + anyShort(), anyBoolean(), eq(0), eq(false)); + verify(serverListener).onDataRead(any(ChannelHandlerContext.class), eq(3), any(ByteBuf.class), eq(0), + eq(false)); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(3), eq(http2TrailingHeaders), eq(0), + anyShort(), anyBoolean(), eq(0), eq(true)); + assertEquals(1, receivedBuffers.size()); + assertEquals(text + text2, receivedBuffers.get(0)); + } + + private void bootstrapEnv(int requestCountDown, int serverSettingsAckCount, int trailersCount) throws Exception { + final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1); + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + requestLatch = new CountDownLatch(requestCountDown); + serverSettingsAckLatch = new CountDownLatch(serverSettingsAckCount); + trailersLatch = trailersCount == 0 ? 
null : new CountDownLatch(trailersCount); + + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new DefaultEventLoopGroup()); + sb.channel(LocalServerChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + serverConnectedChannel = ch; + ChannelPipeline p = ch.pipeline(); + serverFrameCountDown = + new FrameCountDown(serverListener, serverSettingsAckLatch, requestLatch, null, trailersLatch); + p.addLast(new HttpToHttp2ConnectionHandlerBuilder() + .server(true) + .frameListener(serverFrameCountDown) + .httpScheme(HttpScheme.HTTP) + .build()); + serverChannelLatch.countDown(); + } + }); + + cb.group(new DefaultEventLoopGroup()); + cb.channel(LocalChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + HttpToHttp2ConnectionHandler handler = new HttpToHttp2ConnectionHandlerBuilder() + .server(false) + .frameListener(clientListener) + .gracefulShutdownTimeoutMillis(0) + .build(); + p.addLast(handler); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE) { + prefaceWrittenLatch.countDown(); + ctx.pipeline().remove(this); + } + } + }); + } + }); + + serverChannel = sb.bind(new LocalAddress(getClass())).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + assertTrue(prefaceWrittenLatch.await(5, SECONDS)); + assertTrue(serverChannelLatch.await(WAIT_TIME_SECONDS, SECONDS)); + } + + private void verifyHeadersOnly(Http2Headers expected, ChannelPromise writePromise, ChannelFuture writeFuture) + throws Exception { + assertTrue(writePromise.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + 
assertTrue(writePromise.isSuccess()); + assertTrue(writeFuture.awaitUninterruptibly(WAIT_TIME_SECONDS, SECONDS)); + assertTrue(writeFuture.isSuccess()); + awaitRequests(); + verify(serverListener).onHeadersRead(any(ChannelHandlerContext.class), eq(5), + eq(expected), eq(0), anyShort(), anyBoolean(), eq(0), eq(true)); + verify(serverListener, never()).onDataRead(any(ChannelHandlerContext.class), anyInt(), + any(ByteBuf.class), anyInt(), anyBoolean()); + } + + private void awaitRequests() throws Exception { + assertTrue(requestLatch.await(WAIT_TIME_SECONDS, SECONDS)); + if (trailersLatch != null) { + assertTrue(trailersLatch.await(WAIT_TIME_SECONDS, SECONDS)); + } + assertTrue(serverSettingsAckLatch.await(WAIT_TIME_SECONDS, SECONDS)); + } + + private ChannelHandlerContext ctx() { + return clientChannel.pipeline().firstContext(); + } + + private ChannelPromise newPromise() { + return ctx().newPromise(); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InOrderHttp2Headers.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InOrderHttp2Headers.java new file mode 100644 index 0000000..e566e94 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InOrderHttp2Headers.java @@ -0,0 +1,104 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.CharSequenceValueConverter; +import io.netty.handler.codec.DefaultHeaders; + +import static io.netty.util.AsciiString.CASE_INSENSITIVE_HASHER; +import static io.netty.util.AsciiString.CASE_SENSITIVE_HASHER; + +/** + * Http2Headers implementation that preserves headers insertion order. + */ +public class InOrderHttp2Headers + extends DefaultHeaders implements Http2Headers { + + InOrderHttp2Headers() { + super(CharSequenceValueConverter.INSTANCE); + } + + @Override + public boolean equals(Object o) { + return o instanceof Http2Headers && equals((Http2Headers) o, CASE_SENSITIVE_HASHER); + } + + @Override + public int hashCode() { + return hashCode(CASE_SENSITIVE_HASHER); + } + + @Override + public Http2Headers method(CharSequence value) { + set(PseudoHeaderName.METHOD.value(), value); + return this; + } + + @Override + public Http2Headers scheme(CharSequence value) { + set(PseudoHeaderName.SCHEME.value(), value); + return this; + } + + @Override + public Http2Headers authority(CharSequence value) { + set(PseudoHeaderName.AUTHORITY.value(), value); + return this; + } + + @Override + public Http2Headers path(CharSequence value) { + set(PseudoHeaderName.PATH.value(), value); + return this; + } + + @Override + public Http2Headers status(CharSequence value) { + set(PseudoHeaderName.STATUS.value(), value); + return this; + } + + @Override + public CharSequence method() { + return get(PseudoHeaderName.METHOD.value()); + } + + @Override + public CharSequence scheme() { + return get(PseudoHeaderName.SCHEME.value()); + } + + @Override + public CharSequence authority() { + return get(PseudoHeaderName.AUTHORITY.value()); + } + + @Override + public CharSequence path() { + return get(PseudoHeaderName.PATH.value()); + } + + @Override + public CharSequence status() { + return get(PseudoHeaderName.STATUS.value()); + } + + @Override + public boolean contains(CharSequence name, CharSequence value, boolean 
caseInsensitive) { + return contains(name, value, caseInsensitive ? CASE_INSENSITIVE_HASHER : CASE_SENSITIVE_HASHER); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java new file mode 100644 index 0000000..092f2c2 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/InboundHttp2ToHttpAdapterTest.java @@ -0,0 +1,853 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerAdapter; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.codec.http.DefaultFullHttpRequest; +import io.netty.handler.codec.http.DefaultFullHttpResponse; +import io.netty.handler.codec.http.FullHttpMessage; +import io.netty.handler.codec.http.FullHttpRequest; +import io.netty.handler.codec.http.HttpHeaderNames; +import io.netty.handler.codec.http.HttpHeaderValues; +import io.netty.handler.codec.http.HttpHeaders; +import io.netty.handler.codec.http.HttpMethod; +import io.netty.handler.codec.http.HttpObject; +import io.netty.handler.codec.http.HttpResponseStatus; +import io.netty.handler.codec.http.HttpVersion; +import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.Future; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; + +import java.util.List; +import java.util.concurrent.CountDownLatch; + +import static io.netty.handler.codec.http2.Http2CodecUtil.getEmbeddedHttp2Exception; +import static 
io.netty.handler.codec.http2.Http2Exception.isStreamError; +import static io.netty.handler.codec.http2.Http2TestUtil.of; +import static io.netty.handler.codec.http2.Http2TestUtil.runInChannel; +import static java.util.concurrent.TimeUnit.SECONDS; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +/** + * Testing the {@link InboundHttp2ToHttpAdapter} and base class {@link InboundHttp2ToHttpAdapter} for HTTP/2 + * frames into {@link HttpObject}s + */ +public class InboundHttp2ToHttpAdapterTest { + private List capturedRequests; + private List capturedResponses; + + @Mock + private HttpResponseListener serverListener; + + @Mock + private HttpResponseListener clientListener; + + @Mock + private HttpSettingsListener settingsListener; + + private Http2ConnectionHandler serverHandler; + private Http2ConnectionHandler clientHandler; + private ServerBootstrap sb; + private Bootstrap cb; + private Channel serverChannel; + private volatile Channel serverConnectedChannel; + private Channel clientChannel; + private CountDownLatch serverLatch; + private CountDownLatch clientLatch; + private CountDownLatch serverLatch2; + private CountDownLatch clientLatch2; + private CountDownLatch settingsLatch; + private CountDownLatch clientHandlersAddedLatch; + private int maxContentLength; + private HttpResponseDelegator serverDelegator; + private HttpResponseDelegator clientDelegator; + private HttpSettingsDelegator settingsDelegator; + private Http2Exception clientException; + + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + } + + @AfterEach + public void teardown() throws Exception { + cleanupCapturedRequests(); + cleanupCapturedResponses(); + if (clientChannel != null) { + clientChannel.close().syncUninterruptibly(); + 
clientChannel = null; + } + if (serverChannel != null) { + serverChannel.close().syncUninterruptibly(); + serverChannel = null; + } + final Channel serverConnectedChannel = this.serverConnectedChannel; + if (serverConnectedChannel != null) { + serverConnectedChannel.close().syncUninterruptibly(); + this.serverConnectedChannel = null; + } + Future serverGroup = sb.config().group().shutdownGracefully(0, 5, SECONDS); + Future serverChildGroup = sb.config().childGroup().shutdownGracefully(0, 5, SECONDS); + Future clientGroup = cb.config().group().shutdownGracefully(0, 5, SECONDS); + serverGroup.syncUninterruptibly(); + serverChildGroup.syncUninterruptibly(); + clientGroup.syncUninterruptibly(); + } + + @Test + public void clientRequestSingleHeaderNoDataFrames() throws Exception { + boostrapEnv(1, 1, 1); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "https"); + httpHeaders.set(HttpHeaderNames.HOST, "example.org"); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")). 
+ scheme(new AsciiString("https")).authority(new AsciiString("example.org")) + .path(new AsciiString("/some/path/resource2")); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestSingleHeaderCookieSplitIntoMultipleEntries() throws Exception { + boostrapEnv(1, 1, 1); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "https"); + httpHeaders.set(HttpHeaderNames.HOST, "example.org"); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + httpHeaders.set(HttpHeaderNames.COOKIE, "a=b; c=d; e=f"); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")). 
+ scheme(new AsciiString("https")).authority(new AsciiString("example.org")) + .path(new AsciiString("/some/path/resource2")) + .add(HttpHeaderNames.COOKIE, "a=b") + .add(HttpHeaderNames.COOKIE, "c=d") + .add(HttpHeaderNames.COOKIE, "e=f"); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestSingleHeaderCookieSplitIntoMultipleEntries2() throws Exception { + boostrapEnv(1, 1, 1); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "https"); + httpHeaders.set(HttpHeaderNames.HOST, "example.org"); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + httpHeaders.set(HttpHeaderNames.COOKIE, "a=b; c=d; e=f"); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")). 
+ scheme(new AsciiString("https")).authority(new AsciiString("example.org")) + .path(new AsciiString("/some/path/resource2")) + .add(HttpHeaderNames.COOKIE, "a=b; c=d") + .add(HttpHeaderNames.COOKIE, "e=f"); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestSingleHeaderNonAsciiShouldThrow() throws Exception { + boostrapEnv(1, 1, 1); + final Http2Headers http2Headers = new DefaultHttp2Headers() + .method(new AsciiString("GET")) + .scheme(new AsciiString("https")) + .authority(new AsciiString("example.org")) + .path(new AsciiString("/some/path/resource2")) + .add(new AsciiString("çã".getBytes(CharsetUtil.UTF_8)), + new AsciiString("Ãã".getBytes(CharsetUtil.UTF_8))); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + awaitResponses(); + assertTrue(isStreamError(clientException)); + } + + @Test + public void clientRequestOneDataFrame() throws Exception { + boostrapEnv(1, 1, 1); + final String text = "hello world"; + final ByteBuf content = Unpooled.copiedBuffer(text.getBytes()); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", content, true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 
3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")).path( + new AsciiString("/some/path/resource2")); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, false, newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 3, content.retainedDuplicate(), 0, true, + newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestMultipleDataFrames() throws Exception { + boostrapEnv(1, 1, 1); + final String text = "hello world big time data!"; + final ByteBuf content = Unpooled.copiedBuffer(text.getBytes()); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", content, true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")).path( + new AsciiString("/some/path/resource2")); + final int midPoint = text.length() / 2; + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, 
http2Headers, 0, false, newPromiseClient()); + clientHandler.encoder().writeData( + ctxClient(), 3, content.retainedSlice(0, midPoint), 0, false, newPromiseClient()); + clientHandler.encoder().writeData( + ctxClient(), 3, content.retainedSlice(midPoint, text.length() - midPoint), + 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestMultipleEmptyDataFrames() throws Exception { + boostrapEnv(1, 1, 1); + final String text = ""; + final ByteBuf content = Unpooled.copiedBuffer(text.getBytes()); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", content, true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")).path( + new AsciiString("/some/path/resource2")); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, false, newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 3, content.retain(), 0, false, newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 3, content.retain(), 0, false, newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 3, content.retain(), 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + 
awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestTrailingHeaders() throws Exception { + boostrapEnv(1, 1, 1); + final String text = "some data"; + final ByteBuf content = Unpooled.copiedBuffer(text.getBytes()); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, + "/some/path/resource2", content, true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + HttpHeaders trailingHeaders = request.trailingHeaders(); + trailingHeaders.set(of("Foo"), of("goo")); + trailingHeaders.set(of("fOo2"), of("goo2")); + trailingHeaders.add(of("foO2"), of("goo3")); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("GET")).path( + new AsciiString("/some/path/resource2")); + final Http2Headers http2Headers2 = new DefaultHttp2Headers() + .set(new AsciiString("foo"), new AsciiString("goo")) + .set(new AsciiString("foo2"), new AsciiString("goo2")) + .add(new AsciiString("foo2"), new AsciiString("goo3")); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, false, newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 3, content.retainedDuplicate(), 0, false, + newPromiseClient()); + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers2, 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + 
awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + } finally { + request.release(); + } + } + + @Test + public void clientRequestStreamDependencyInHttpMessageFlow() throws Exception { + boostrapEnv(1, 2, 1); + final String text = "hello world big time data!"; + final ByteBuf content = Unpooled.copiedBuffer(text.getBytes()); + final String text2 = "hello world big time data...number 2!!"; + final ByteBuf content2 = Unpooled.copiedBuffer(text2.getBytes()); + final FullHttpRequest request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, + "/some/path/resource", content, true); + final FullHttpMessage request2 = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, + "/some/path/resource2", content2, true); + try { + HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + HttpHeaders httpHeaders2 = request2.headers(); + httpHeaders2.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders2.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_DEPENDENCY_ID.text(), 3); + httpHeaders2.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 123); + httpHeaders2.setInt(HttpHeaderNames.CONTENT_LENGTH, text2.length()); + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("PUT")).path( + new AsciiString("/some/path/resource")); + final Http2Headers http2Headers2 = new DefaultHttp2Headers().method(new AsciiString("PUT")).path( + new AsciiString("/some/path/resource2")); + runInChannel(clientChannel, new Http2Runnable() 
{ + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, false, newPromiseClient()); + clientHandler.encoder().writeHeaders(ctxClient(), 5, http2Headers2, 3, (short) 123, true, 0, + false, newPromiseClient()); + clientChannel.flush(); // Headers are queued in the flow controller and so flush them. + clientHandler.encoder().writeData(ctxClient(), 3, content.retainedDuplicate(), 0, true, + newPromiseClient()); + clientHandler.encoder().writeData(ctxClient(), 5, content2.retainedDuplicate(), 0, true, + newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor httpObjectCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener, times(2)).messageReceived(httpObjectCaptor.capture()); + capturedRequests = httpObjectCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + assertEquals(request2, capturedRequests.get(1)); + } finally { + request.release(); + request2.release(); + } + } + + @Test + public void serverRequestPushPromise() throws Exception { + boostrapEnv(1, 1, 1); + final String text = "hello world big time data!"; + final ByteBuf content = Unpooled.copiedBuffer(text.getBytes()); + final String text2 = "hello world smaller data?"; + final ByteBuf content2 = Unpooled.copiedBuffer(text2.getBytes()); + final FullHttpMessage response = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, + content, true); + final FullHttpMessage response2 = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CREATED, + content2, true); + final FullHttpMessage request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, "/push/test", + true); + try { + HttpHeaders httpHeaders = response.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + 
httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + HttpHeaders httpHeaders2 = response2.headers(); + httpHeaders2.set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "https"); + httpHeaders2.set(HttpHeaderNames.HOST, "example.org"); + httpHeaders2.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 5); + httpHeaders2.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_PROMISE_ID.text(), 3); + httpHeaders2.setInt(HttpHeaderNames.CONTENT_LENGTH, text2.length()); + + httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + final Http2Headers http2Headers3 = new DefaultHttp2Headers().method(new AsciiString("GET")) + .path(new AsciiString("/push/test")); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers3, 0, true, newPromiseClient()); + clientChannel.flush(); + } + }); + awaitRequests(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(request, capturedRequests.get(0)); + + final Http2Headers http2Headers = new DefaultHttp2Headers().status(new AsciiString("200")); + // The PUSH_PROMISE frame includes a header block that contains a + // complete set of request header fields that the server attributes to + // the request. + // https://tools.ietf.org/html/rfc7540#section-8.2.1 + // Therefore, we should consider the case where there is no Http response status. 
+ final Http2Headers http2Headers2 = new DefaultHttp2Headers() + .scheme(new AsciiString("https")) + .authority(new AsciiString("example.org")); + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + serverHandler.encoder().writeHeaders(ctxServer(), 3, http2Headers, 0, false, newPromiseServer()); + serverHandler.encoder().writePushPromise(ctxServer(), 3, 2, http2Headers2, 0, newPromiseServer()); + serverHandler.encoder().writeData(ctxServer(), 3, content.retainedDuplicate(), 0, true, + newPromiseServer()); + serverHandler.encoder().writeData(ctxServer(), 5, content2.retainedDuplicate(), 0, true, + newPromiseServer()); + serverConnectedChannel.flush(); + } + }); + awaitResponses(); + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(clientListener).messageReceived(responseCaptor.capture()); + capturedResponses = responseCaptor.getAllValues(); + assertEquals(response, capturedResponses.get(0)); + } finally { + request.release(); + response.release(); + response2.release(); + } + } + + @Test + public void serverResponseHeaderInformational() throws Exception { + boostrapEnv(1, 2, 1, 2, 1); + final FullHttpMessage request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.PUT, "/info/test", + true); + HttpHeaders httpHeaders = request.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.set(HttpHeaderNames.EXPECT, HttpHeaderValues.CONTINUE); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + + final Http2Headers http2Headers = new DefaultHttp2Headers().method(new AsciiString("PUT")) + .path(new AsciiString("/info/test")) + .set(new AsciiString(HttpHeaderNames.EXPECT.toString()), + new AsciiString(HttpHeaderValues.CONTINUE.toString())); + final FullHttpMessage response = new 
DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.CONTINUE); + final String text = "a big payload"; + final ByteBuf payload = Unpooled.copiedBuffer(text.getBytes()); + final FullHttpMessage request2 = request.replace(payload); + final FullHttpMessage response2 = new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK); + + try { + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + clientHandler.encoder().writeHeaders(ctxClient(), 3, http2Headers, 0, false, newPromiseClient()); + clientChannel.flush(); + } + }); + + awaitRequests(); + httpHeaders = response.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + final Http2Headers http2HeadersResponse = new DefaultHttp2Headers().status(new AsciiString("100")); + runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + serverHandler.encoder().writeHeaders(ctxServer(), 3, http2HeadersResponse, 0, false, + newPromiseServer()); + serverConnectedChannel.flush(); + } + }); + + awaitResponses(); + httpHeaders = request2.headers(); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, text.length()); + httpHeaders.remove(HttpHeaderNames.EXPECT); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() { + clientHandler.encoder().writeData(ctxClient(), 3, payload.retainedDuplicate(), 0, true, + newPromiseClient()); + clientChannel.flush(); + } + }); + + awaitRequests2(); + httpHeaders = response2.headers(); + httpHeaders.setInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), 3); + httpHeaders.setInt(HttpHeaderNames.CONTENT_LENGTH, 0); + httpHeaders.setShort(HttpConversionUtil.ExtensionHeaderNames.STREAM_WEIGHT.text(), (short) 16); + + final Http2Headers http2HeadersResponse2 = new DefaultHttp2Headers().status(new AsciiString("200")); + 
runInChannel(serverConnectedChannel, new Http2Runnable() { + @Override + public void run() throws Http2Exception { + serverHandler.encoder().writeHeaders(ctxServer(), 3, http2HeadersResponse2, 0, true, + newPromiseServer()); + serverConnectedChannel.flush(); + } + }); + + awaitResponses2(); + ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(serverListener, times(2)).messageReceived(requestCaptor.capture()); + capturedRequests = requestCaptor.getAllValues(); + assertEquals(2, capturedRequests.size()); + // We do not expect to have this header in the captured request so remove it now. + assertNotNull(request.headers().remove("x-http2-stream-weight")); + + assertEquals(request, capturedRequests.get(0)); + assertEquals(request2, capturedRequests.get(1)); + + ArgumentCaptor responseCaptor = ArgumentCaptor.forClass(FullHttpMessage.class); + verify(clientListener, times(2)).messageReceived(responseCaptor.capture()); + capturedResponses = responseCaptor.getAllValues(); + assertEquals(2, capturedResponses.size()); + assertEquals(response, capturedResponses.get(0)); + assertEquals(response2, capturedResponses.get(1)); + } finally { + request.release(); + request2.release(); + response.release(); + response2.release(); + } + } + + @Test + public void propagateSettings() throws Exception { + boostrapEnv(1, 1, 2); + final Http2Settings settings = new Http2Settings().pushEnabled(true); + runInChannel(clientChannel, new Http2Runnable() { + @Override + public void run() { + clientHandler.encoder().writeSettings(ctxClient(), settings, newPromiseClient()); + clientChannel.flush(); + } + }); + assertTrue(settingsLatch.await(5, SECONDS)); + ArgumentCaptor settingsCaptor = ArgumentCaptor.forClass(Http2Settings.class); + verify(settingsListener, times(2)).messageReceived(settingsCaptor.capture()); + assertEquals(settings, settingsCaptor.getValue()); + } + + private void boostrapEnv(int clientLatchCount, int serverLatchCount, int settingsLatchCount) 
+ throws InterruptedException { + boostrapEnv(clientLatchCount, clientLatchCount, serverLatchCount, serverLatchCount, settingsLatchCount); + } + + private void boostrapEnv(int clientLatchCount, int clientLatchCount2, int serverLatchCount, int serverLatchCount2, + int settingsLatchCount) throws InterruptedException { + final CountDownLatch prefaceWrittenLatch = new CountDownLatch(1); + clientDelegator = null; + serverDelegator = null; + serverConnectedChannel = null; + maxContentLength = 1024; + final CountDownLatch serverChannelLatch = new CountDownLatch(1); + serverLatch = new CountDownLatch(serverLatchCount); + clientLatch = new CountDownLatch(clientLatchCount); + serverLatch2 = new CountDownLatch(serverLatchCount2); + clientLatch2 = new CountDownLatch(clientLatchCount2); + settingsLatch = new CountDownLatch(settingsLatchCount); + clientHandlersAddedLatch = new CountDownLatch(1); + + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new DefaultEventLoopGroup()); + sb.channel(LocalServerChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + serverConnectedChannel = ch; + ChannelPipeline p = ch.pipeline(); + Http2Connection connection = new DefaultHttp2Connection(true); + + serverHandler = new Http2ConnectionHandlerBuilder().frameListener( + new InboundHttp2ToHttpAdapterBuilder(connection) + .maxContentLength(maxContentLength) + .validateHttpHeaders(true) + .propagateSettings(true) + .build()) + .connection(connection) + .gracefulShutdownTimeoutMillis(0) + .build(); + p.addLast(serverHandler); + + serverDelegator = new HttpResponseDelegator(serverListener, serverLatch, serverLatch2); + p.addLast(serverDelegator); + settingsDelegator = new HttpSettingsDelegator(settingsListener, settingsLatch); + p.addLast(settingsDelegator); + serverChannelLatch.countDown(); + } + }); + + cb.group(new DefaultEventLoopGroup()); + cb.channel(LocalChannel.class); + cb.handler(new 
ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + Http2Connection connection = new DefaultHttp2Connection(false); + + clientHandler = new Http2ConnectionHandlerBuilder().frameListener( + new InboundHttp2ToHttpAdapterBuilder(connection) + .maxContentLength(maxContentLength) + .build()) + .connection(connection) + .gracefulShutdownTimeoutMillis(0) + .build(); + p.addLast(clientHandler); + + clientDelegator = new HttpResponseDelegator(clientListener, clientLatch, clientLatch2); + p.addLast(clientDelegator); + p.addLast(new ChannelHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + Http2Exception e = getEmbeddedHttp2Exception(cause); + if (e != null) { + clientException = e; + clientLatch.countDown(); + } else { + super.exceptionCaught(ctx, cause); + } + } + }); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == Http2ConnectionPrefaceAndSettingsFrameWrittenEvent.INSTANCE) { + prefaceWrittenLatch.countDown(); + ctx.pipeline().remove(this); + } + } + }); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + clientHandlersAddedLatch.countDown(); + } + }); + } + }); + + serverChannel = sb.bind(new LocalAddress(getClass())).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + assertTrue(prefaceWrittenLatch.await(5, SECONDS)); + assertTrue(serverChannelLatch.await(5, SECONDS)); + assertTrue(clientHandlersAddedLatch.await(5, SECONDS)); + } + + private void cleanupCapturedRequests() { + if (capturedRequests != null) { + for (FullHttpMessage capturedRequest : capturedRequests) { + capturedRequest.release(); + } + 
capturedRequests = null; + } + } + + private void cleanupCapturedResponses() { + if (capturedResponses != null) { + for (FullHttpMessage capturedResponse : capturedResponses) { + capturedResponse.release(); + } + capturedResponses = null; + } + } + + private void awaitRequests() throws Exception { + assertTrue(serverLatch.await(5, SECONDS)); + } + + private void awaitResponses() throws Exception { + assertTrue(clientLatch.await(5, SECONDS)); + } + + private void awaitRequests2() throws Exception { + assertTrue(serverLatch2.await(5, SECONDS)); + } + + private void awaitResponses2() throws Exception { + assertTrue(clientLatch2.await(5, SECONDS)); + } + + private ChannelHandlerContext ctxClient() { + return clientChannel.pipeline().firstContext(); + } + + private ChannelPromise newPromiseClient() { + return ctxClient().newPromise(); + } + + private ChannelHandlerContext ctxServer() { + return serverConnectedChannel.pipeline().firstContext(); + } + + private ChannelPromise newPromiseServer() { + return ctxServer().newPromise(); + } + + private interface HttpResponseListener { + void messageReceived(HttpObject obj); + } + + private interface HttpSettingsListener { + void messageReceived(Http2Settings settings); + } + + private static final class HttpResponseDelegator extends SimpleChannelInboundHandler { + private final HttpResponseListener listener; + private final CountDownLatch latch; + private final CountDownLatch latch2; + + HttpResponseDelegator(HttpResponseListener listener, CountDownLatch latch, CountDownLatch latch2) { + super(false); + this.listener = listener; + this.latch = latch; + this.latch2 = latch2; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, HttpObject msg) throws Exception { + listener.messageReceived(msg); + latch.countDown(); + latch2.countDown(); + } + } + + private static final class HttpSettingsDelegator extends SimpleChannelInboundHandler { + private final HttpSettingsListener listener; + private final 
CountDownLatch latch; + + HttpSettingsDelegator(HttpSettingsListener listener, CountDownLatch latch) { + super(false); + this.listener = listener; + this.latch = latch; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, Http2Settings settings) throws Exception { + listener.messageReceived(settings); + latch.countDown(); + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java new file mode 100644 index 0000000..9dec606 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/LastInboundHandler.java @@ -0,0 +1,222 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.PlatformDependent; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.locks.LockSupport; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +/** + * Channel handler that allows to easily access inbound messages. 
+ */ +public class LastInboundHandler extends ChannelDuplexHandler { + private final List queue = new ArrayList(); + private final Consumer channelReadCompleteConsumer; + private Throwable lastException; + private ChannelHandlerContext ctx; + private boolean channelActive; + private String writabilityStates = ""; + + // TODO(scott): use JDK 8's Consumer + public interface Consumer { + void accept(T obj); + } + + private static final Consumer NOOP_CONSUMER = new Consumer() { + @Override + public void accept(Object obj) { + } + }; + + @SuppressWarnings("unchecked") + public static Consumer noopConsumer() { + return (Consumer) NOOP_CONSUMER; + } + + public LastInboundHandler() { + this(LastInboundHandler.noopConsumer()); + } + + public LastInboundHandler(Consumer channelReadCompleteConsumer) { + this.channelReadCompleteConsumer = checkNotNull(channelReadCompleteConsumer, "channelReadCompleteConsumer"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + super.handlerAdded(ctx); + this.ctx = ctx; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (channelActive) { + throw new IllegalStateException("channelActive may only be fired once."); + } + channelActive = true; + super.channelActive(ctx); + } + + public boolean isChannelActive() { + return channelActive; + } + + public String writabilityStates() { + return writabilityStates; + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + if (!channelActive) { + throw new IllegalStateException("channelInactive may only be fired once after channelActive."); + } + channelActive = false; + super.channelInactive(ctx); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if ("".equals(writabilityStates)) { + writabilityStates = String.valueOf(ctx.channel().isWritable()); + } else { + writabilityStates += "," + ctx.channel().isWritable(); + } + 
super.channelWritabilityChanged(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + queue.add(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadCompleteConsumer.accept(ctx); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + queue.add(new UserEvent(evt)); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (lastException != null) { + cause.printStackTrace(); + } else { + lastException = cause; + } + } + + public void checkException() throws Exception { + if (lastException == null) { + return; + } + Throwable t = lastException; + lastException = null; + PlatformDependent.throwException(t); + } + + @SuppressWarnings("unchecked") + public T readInbound() { + for (int i = 0; i < queue.size(); i++) { + Object o = queue.get(i); + if (!(o instanceof UserEvent)) { + queue.remove(i); + return (T) o; + } + } + + return null; + } + + public T blockingReadInbound() { + T msg; + while ((msg = readInbound()) == null) { + LockSupport.parkNanos(MILLISECONDS.toNanos(10)); + } + return msg; + } + + @SuppressWarnings("unchecked") + public T readUserEvent() { + for (int i = 0; i < queue.size(); i++) { + Object o = queue.get(i); + if (o instanceof UserEvent) { + queue.remove(i); + return (T) ((UserEvent) o).evt; + } + } + + return null; + } + + /** + * Useful to test order of events and messages. + */ + @SuppressWarnings("unchecked") + public T readInboundMessageOrUserEvent() { + if (queue.isEmpty()) { + return null; + } + Object o = queue.remove(0); + if (o instanceof UserEvent) { + return (T) ((UserEvent) o).evt; + } + return (T) o; + } + + public void writeOutbound(Object... 
msgs) throws Exception { + for (Object msg : msgs) { + ctx.write(msg); + } + ctx.flush(); + EmbeddedChannel ch = (EmbeddedChannel) ctx.channel(); + ch.runPendingTasks(); + ch.checkException(); + checkException(); + } + + public void finishAndReleaseAll() throws Exception { + checkException(); + Object o; + while ((o = readInboundMessageOrUserEvent()) != null) { + ReferenceCountUtil.release(o); + } + } + + public Channel channel() { + return ctx.channel(); + } + + private static final class UserEvent { + private final Object evt; + + UserEvent(Object evt) { + this.evt = evt; + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/ReadOnlyHttp2HeadersTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/ReadOnlyHttp2HeadersTest.java new file mode 100644 index 0000000..66dcf16 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/ReadOnlyHttp2HeadersTest.java @@ -0,0 +1,298 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.http2; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; + +import static io.netty.handler.codec.http2.DefaultHttp2HeadersTest.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ReadOnlyHttp2HeadersTest { + @Test + public void notKeyValuePairThrows() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + ReadOnlyHttp2Headers.trailers(false, new AsciiString[]{ null }); + } + }); + } + + @Test + public void nullTrailersNotAllowed() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + ReadOnlyHttp2Headers.trailers(false, (AsciiString[]) null); + } + }); + } + + @Test + public void nullHeaderNameNotChecked() { + ReadOnlyHttp2Headers.trailers(false, null, null); + } + + @Test + public void nullHeaderNameValidated() { + assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() { + ReadOnlyHttp2Headers.trailers(true, null, new AsciiString("foo")); + } + }); + } + + @Test + public void pseudoHeaderNotAllowedAfterNonPseudoHeaders() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + ReadOnlyHttp2Headers.trailers(true, new AsciiString(":scheme"), new AsciiString("foo"), + new AsciiString("othername"), new AsciiString("goo"), + new AsciiString(":path"), new AsciiString("val")); + } + }); + } + + @Test + public void nullValuesAreNotAllowed() { + assertThrows(IllegalArgumentException.class, new Executable() { + 
@Override + public void execute() { + ReadOnlyHttp2Headers.trailers(true, new AsciiString("foo"), null); + } + }); + } + + @Test + public void emptyHeaderNameAllowed() { + ReadOnlyHttp2Headers.trailers(false, AsciiString.EMPTY_STRING, new AsciiString("foo")); + } + + @Test + public void testPseudoHeadersMustComeFirstWhenIteratingServer() { + Http2Headers headers = newServerHeaders(); + verifyPseudoHeadersFirst(headers); + } + + @Test + public void testPseudoHeadersMustComeFirstWhenIteratingClient() { + Http2Headers headers = newClientHeaders(); + verifyPseudoHeadersFirst(headers); + } + + @Test + public void testIteratorReadOnlyClient() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testIteratorReadOnly(newClientHeaders()); + } + }); + } + + @Test + public void testIteratorReadOnlyServer() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testIteratorReadOnly(newServerHeaders()); + } + }); + } + + @Test + public void testIteratorReadOnlyTrailers() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testIteratorReadOnly(newTrailers()); + } + }); + } + + @Test + public void testIteratorEntryReadOnlyClient() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testIteratorEntryReadOnly(newClientHeaders()); + } + }); + } + + @Test + public void testIteratorEntryReadOnlyServer() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testIteratorEntryReadOnly(newServerHeaders()); + } + }); + } + + @Test + public void testIteratorEntryReadOnlyTrailers() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + testIteratorEntryReadOnly(newTrailers()); + } + }); + } + + @Test + public void testSize() { + Http2Headers 
headers = newTrailers(); + assertEquals(otherHeaders().length / 2, headers.size()); + } + + @Test + public void testIsNotEmpty() { + Http2Headers headers = newTrailers(); + assertFalse(headers.isEmpty()); + } + + @Test + public void testIsEmpty() { + Http2Headers headers = ReadOnlyHttp2Headers.trailers(false); + assertTrue(headers.isEmpty()); + } + + @Test + public void testContainsName() { + Http2Headers headers = newClientHeaders(); + assertTrue(headers.contains("Name1")); + assertTrue(headers.contains(Http2Headers.PseudoHeaderName.PATH.value())); + assertFalse(headers.contains(Http2Headers.PseudoHeaderName.STATUS.value())); + assertFalse(headers.contains("a missing header")); + } + + @Test + public void testContainsNameAndValue() { + Http2Headers headers = newClientHeaders(); + assertTrue(headers.contains("Name1", "value1")); + assertFalse(headers.contains("Name1", "Value1")); + assertTrue(headers.contains("name2", "Value2", true)); + assertFalse(headers.contains("name2", "Value2", false)); + assertTrue(headers.contains(Http2Headers.PseudoHeaderName.PATH.value(), "/foo")); + assertFalse(headers.contains(Http2Headers.PseudoHeaderName.STATUS.value(), "200")); + assertFalse(headers.contains("a missing header", "a missing value")); + } + + @Test + public void testGet() { + Http2Headers headers = newClientHeaders(); + assertTrue(AsciiString.contentEqualsIgnoreCase("value1", headers.get("Name1"))); + assertTrue(AsciiString.contentEqualsIgnoreCase("/foo", + headers.get(Http2Headers.PseudoHeaderName.PATH.value()))); + assertNull(headers.get(Http2Headers.PseudoHeaderName.STATUS.value())); + assertNull(headers.get("a missing header")); + } + + @Test + public void testClientOtherValueIterator() { + testValueIteratorSingleValue(newClientHeaders(), "name2", "value2"); + } + + @Test + public void testClientPsuedoValueIterator() { + testValueIteratorSingleValue(newClientHeaders(), ":path", "/foo"); + } + + @Test + public void testServerPsuedoValueIterator() { + 
testValueIteratorSingleValue(newServerHeaders(), ":status", "200"); + } + + @Test + public void testEmptyValueIterator() { + Http2Headers headers = newServerHeaders(); + final Iterator itr = headers.valueIterator("foo"); + assertFalse(itr.hasNext()); + assertThrows(NoSuchElementException.class, new Executable() { + @Override + public void execute() { + itr.next(); + } + }); + } + + @Test + public void testIteratorMultipleValues() { + Http2Headers headers = ReadOnlyHttp2Headers.serverHeaders(false, new AsciiString("200"), + new AsciiString("name2"), new AsciiString("value1"), + new AsciiString("name1"), new AsciiString("value2"), + new AsciiString("name2"), new AsciiString("value3")); + Iterator itr = headers.valueIterator("name2"); + assertTrue(itr.hasNext()); + assertTrue(AsciiString.contentEqualsIgnoreCase("value1", itr.next())); + assertTrue(itr.hasNext()); + assertTrue(AsciiString.contentEqualsIgnoreCase("value3", itr.next())); + assertFalse(itr.hasNext()); + } + + private static void testValueIteratorSingleValue(Http2Headers headers, CharSequence name, CharSequence value) { + Iterator itr = headers.valueIterator(name); + assertTrue(itr.hasNext()); + assertTrue(AsciiString.contentEqualsIgnoreCase(value, itr.next())); + assertFalse(itr.hasNext()); + } + + private static void testIteratorReadOnly(Http2Headers headers) { + Iterator> itr = headers.iterator(); + assertTrue(itr.hasNext()); + itr.remove(); + } + + private static void testIteratorEntryReadOnly(Http2Headers headers) { + Iterator> itr = headers.iterator(); + assertTrue(itr.hasNext()); + itr.next().setValue("foo"); + } + + private static ReadOnlyHttp2Headers newServerHeaders() { + return ReadOnlyHttp2Headers.serverHeaders(false, new AsciiString("200"), otherHeaders()); + } + + private static ReadOnlyHttp2Headers newClientHeaders() { + return ReadOnlyHttp2Headers.clientHeaders(false, new AsciiString("meth"), new AsciiString("/foo"), + new AsciiString("schemer"), new AsciiString("respect_my_authority"), 
otherHeaders()); + } + + private static ReadOnlyHttp2Headers newTrailers() { + return ReadOnlyHttp2Headers.trailers(false, otherHeaders()); + } + + private static AsciiString[] otherHeaders() { + return new AsciiString[] { + new AsciiString("name1"), new AsciiString("value1"), + new AsciiString("name2"), new AsciiString("value2"), + new AsciiString("name3"), new AsciiString("value3") + }; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java new file mode 100644 index 0000000..3089e6a --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/StreamBufferingEncoderTest.java @@ -0,0 +1,581 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.handler.codec.http2; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MAX_FRAME_SIZE; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.SMALLEST_MAX_CONCURRENT_STREAMS; +import static io.netty.handler.codec.http2.Http2Error.CANCEL; +import static io.netty.handler.codec.http2.Http2Stream.State.HALF_CLOSED_LOCAL; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyBoolean; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.anyLong; +import static org.mockito.Mockito.anyShort; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelMetadata; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelPromise; +import io.netty.channel.DefaultMessageSizeEstimator; +import io.netty.handler.codec.http2.StreamBufferingEncoder.Http2GoAwayException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.ImmediateEventExecutor; +import org.junit.jupiter.api.AfterEach; 
+import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; + +import java.util.ArrayList; +import java.util.List; + +/** + * Tests for {@link StreamBufferingEncoder}. + */ +public class StreamBufferingEncoderTest { + + private StreamBufferingEncoder encoder; + + private Http2Connection connection; + + @Mock + private Http2FrameWriter writer; + + @Mock + private ChannelHandlerContext ctx; + + @Mock + private Channel channel; + + @Mock + private Channel.Unsafe unsafe; + + @Mock + private ChannelConfig config; + + @Mock + private EventExecutor executor; + + /** + * Init fields and do mocking. + */ + @BeforeEach + public void setup() throws Exception { + MockitoAnnotations.initMocks(this); + + Http2FrameWriter.Configuration configuration = mock(Http2FrameWriter.Configuration.class); + Http2FrameSizePolicy frameSizePolicy = mock(Http2FrameSizePolicy.class); + when(writer.configuration()).thenReturn(configuration); + when(configuration.frameSizePolicy()).thenReturn(frameSizePolicy); + when(frameSizePolicy.maxFrameSize()).thenReturn(DEFAULT_MAX_FRAME_SIZE); + when(writer.writeData(any(ChannelHandlerContext.class), anyInt(), any(ByteBuf.class), anyInt(), anyBoolean(), + any(ChannelPromise.class))).thenAnswer(successAnswer()); + when(writer.writeRstStream(eq(ctx), anyInt(), anyLong(), any(ChannelPromise.class))).thenAnswer( + successAnswer()); + when(writer.writeGoAway(any(ChannelHandlerContext.class), anyInt(), anyLong(), any(ByteBuf.class), + any(ChannelPromise.class))) + .thenAnswer(successAnswer()); + when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyBoolean(), any(ChannelPromise.class))).thenAnswer(noopAnswer()); + 
when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), + anyInt(), anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class))) + .thenAnswer(noopAnswer()); + + connection = new DefaultHttp2Connection(false); + connection.remote().flowController(new DefaultHttp2RemoteFlowController(connection)); + connection.local().flowController(new DefaultHttp2LocalFlowController(connection).frameWriter(writer)); + + DefaultHttp2ConnectionEncoder defaultEncoder = + new DefaultHttp2ConnectionEncoder(connection, writer); + encoder = new StreamBufferingEncoder(defaultEncoder); + DefaultHttp2ConnectionDecoder decoder = + new DefaultHttp2ConnectionDecoder(connection, encoder, mock(Http2FrameReader.class)); + Http2ConnectionHandler handler = new Http2ConnectionHandlerBuilder() + .frameListener(mock(Http2FrameListener.class)) + .codec(decoder, encoder).build(); + + // Set LifeCycleManager on encoder and decoder + when(ctx.channel()).thenReturn(channel); + when(ctx.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(channel.alloc()).thenReturn(UnpooledByteBufAllocator.DEFAULT); + when(executor.inEventLoop()).thenReturn(true); + doAnswer(new Answer() { + @Override + public ChannelPromise answer(InvocationOnMock invocation) throws Throwable { + return newPromise(); + } + }).when(ctx).newPromise(); + when(ctx.executor()).thenReturn(executor); + when(channel.isActive()).thenReturn(false); + when(channel.config()).thenReturn(config); + when(channel.isWritable()).thenReturn(true); + when(channel.bytesBeforeUnwritable()).thenReturn(Long.MAX_VALUE); + when(config.getWriteBufferHighWaterMark()).thenReturn(Integer.MAX_VALUE); + when(config.getMessageSizeEstimator()).thenReturn(DefaultMessageSizeEstimator.DEFAULT); + ChannelMetadata metadata = new ChannelMetadata(false, 16); + when(channel.metadata()).thenReturn(metadata); + when(channel.unsafe()).thenReturn(unsafe); + handler.handlerAdded(ctx); + } + + @AfterEach + public void 
teardown() { + // Close and release any buffered frames. + encoder.close(); + } + + @Test + public void multipleWritesToActiveStream() { + encoder.writeSettingsAck(ctx, newPromise()); + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + ByteBuf data = data(); + final int expectedBytes = data.readableBytes() * 3; + encoder.writeData(ctx, 3, data, 0, false, newPromise()); + encoder.writeData(ctx, 3, data(), 0, false, newPromise()); + encoder.writeData(ctx, 3, data(), 0, false, newPromise()); + encoderWriteHeaders(3, newPromise()); + + writeVerifyWriteHeaders(times(1), 3); + // Contiguous data writes are coalesced + ArgumentCaptor bufCaptor = ArgumentCaptor.forClass(ByteBuf.class); + verify(writer, times(1)) + .writeData(eq(ctx), eq(3), bufCaptor.capture(), eq(0), eq(false), any(ChannelPromise.class)); + assertEquals(expectedBytes, bufCaptor.getValue().readableBytes()); + } + + @Test + public void ensureCanCreateNextStreamWhenStreamCloses() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + + // This one gets buffered. + encoderWriteHeaders(5, newPromise()); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + // Now prevent us from creating another stream. + setMaxConcurrentStreams(0); + + // Close the previous stream. + connection.stream(3).close(); + + // Ensure that no streams are currently active and that only the HEADERS from the first + // stream were written. 
+ writeVerifyWriteHeaders(times(1), 3); + writeVerifyWriteHeaders(never(), 5); + assertEquals(0, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + } + + @Test + public void alternatingWritesToActiveAndBufferedStreams() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + + encoderWriteHeaders(5, newPromise()); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, false, newPromise()); + writeVerifyWriteHeaders(times(1), 3); + encoder.writeData(ctx, 5, EMPTY_BUFFER, 0, false, newPromise()); + verify(writer, never()) + .writeData(eq(ctx), eq(5), any(ByteBuf.class), eq(0), eq(false), eq(newPromise())); + } + + @Test + public void bufferingNewStreamFailsAfterGoAwayReceived() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + connection.goAwayReceived(1, 8, EMPTY_BUFFER); + + ChannelPromise promise = newPromise(); + encoderWriteHeaders(3, promise); + assertEquals(0, encoder.numBufferedStreams()); + assertTrue(promise.isDone()); + assertFalse(promise.isSuccess()); + } + + @Test + public void receivingGoAwayFailsBufferedStreams() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(5); + + int streamId = 3; + List futures = new ArrayList(); + for (int i = 0; i < 9; i++) { + futures.add(encoderWriteHeaders(streamId, newPromise())); + streamId += 2; + } + assertEquals(5, connection.numActiveStreams()); + assertEquals(4, encoder.numBufferedStreams()); + + connection.goAwayReceived(11, 8, EMPTY_BUFFER); + + assertEquals(5, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + int failCount = 0; + for (ChannelFuture f : futures) { + if (f.cause() != null) { + assertTrue(f.cause() instanceof Http2GoAwayException); + 
failCount++; + } + } + assertEquals(4, failCount); + } + + @Test + public void receivingGoAwayFailsNewStreamIfMaxConcurrentStreamsReached() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + encoderWriteHeaders(3, newPromise()); + connection.goAwayReceived(11, 8, EMPTY_BUFFER); + ChannelFuture f = encoderWriteHeaders(5, newPromise()); + + assertTrue(f.cause() instanceof Http2GoAwayException); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void sendingGoAwayShouldNotFailStreams() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), + anyBoolean(), any(ChannelPromise.class))).thenAnswer(successAnswer()); + when(writer.writeHeaders(any(ChannelHandlerContext.class), anyInt(), any(Http2Headers.class), anyInt(), + anyShort(), anyBoolean(), anyInt(), anyBoolean(), any(ChannelPromise.class))).thenAnswer(successAnswer()); + + ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + ChannelFuture f2 = encoderWriteHeaders(5, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + ChannelFuture f3 = encoderWriteHeaders(7, newPromise()); + assertEquals(2, encoder.numBufferedStreams()); + + ByteBuf empty = Unpooled.buffer(0); + encoder.writeGoAway(ctx, 3, CANCEL.code(), empty, newPromise()); + + assertEquals(1, connection.numActiveStreams()); + assertEquals(2, encoder.numBufferedStreams()); + assertFalse(f1.isDone()); + assertFalse(f2.isDone()); + assertFalse(f3.isDone()); + } + + @Test + public void endStreamDoesNotFailBufferedStream() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + + encoderWriteHeaders(3, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeData(ctx, 3, EMPTY_BUFFER, 0, true, newPromise()); + + assertEquals(0, 
connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + // Simulate that we received a SETTINGS frame which + // increased MAX_CONCURRENT_STREAMS to 1. + setMaxConcurrentStreams(1); + encoder.writeSettingsAck(ctx, newPromise()); + + assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(HALF_CLOSED_LOCAL, connection.stream(3).state()); + } + + @Test + public void rstStreamClosesBufferedStream() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + + encoderWriteHeaders(3, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + + ChannelPromise rstStreamPromise = newPromise(); + encoder.writeRstStream(ctx, 3, CANCEL.code(), rstStreamPromise); + assertTrue(rstStreamPromise.isSuccess()); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void bufferUntilActiveStreamsAreReset() throws Exception { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(1); + + encoderWriteHeaders(3, newPromise()); + assertEquals(0, encoder.numBufferedStreams()); + encoderWriteHeaders(5, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + encoderWriteHeaders(7, newPromise()); + assertEquals(2, encoder.numBufferedStreams()); + + writeVerifyWriteHeaders(times(1), 3); + writeVerifyWriteHeaders(never(), 5); + writeVerifyWriteHeaders(never(), 7); + + encoder.writeRstStream(ctx, 3, CANCEL.code(), newPromise()); + connection.remote().flowController().writePendingBytes(); + writeVerifyWriteHeaders(times(1), 5); + writeVerifyWriteHeaders(never(), 7); + assertEquals(1, connection.numActiveStreams()); + assertEquals(1, encoder.numBufferedStreams()); + + encoder.writeRstStream(ctx, 5, CANCEL.code(), newPromise()); + connection.remote().flowController().writePendingBytes(); + writeVerifyWriteHeaders(times(1), 7); + assertEquals(1, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + + 
encoder.writeRstStream(ctx, 7, CANCEL.code(), newPromise()); + assertEquals(0, connection.numActiveStreams()); + assertEquals(0, encoder.numBufferedStreams()); + } + + @Test + public void bufferUntilMaxStreamsIncreased() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(2); + + encoderWriteHeaders(3, newPromise()); + encoderWriteHeaders(5, newPromise()); + encoderWriteHeaders(7, newPromise()); + encoderWriteHeaders(9, newPromise()); + assertEquals(2, encoder.numBufferedStreams()); + + writeVerifyWriteHeaders(times(1), 3); + writeVerifyWriteHeaders(times(1), 5); + writeVerifyWriteHeaders(never(), 7); + writeVerifyWriteHeaders(never(), 9); + + // Simulate that we received a SETTINGS frame which + // increased MAX_CONCURRENT_STREAMS to 5. + setMaxConcurrentStreams(5); + encoder.writeSettingsAck(ctx, newPromise()); + + assertEquals(0, encoder.numBufferedStreams()); + writeVerifyWriteHeaders(times(1), 7); + writeVerifyWriteHeaders(times(1), 9); + + encoderWriteHeaders(11, newPromise()); + + writeVerifyWriteHeaders(times(1), 11); + + assertEquals(5, connection.local().numActiveStreams()); + } + + @Test + public void bufferUntilSettingsReceived() throws Http2Exception { + int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; + int numStreams = initialLimit * 2; + for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { + encoderWriteHeaders(nextStreamId, newPromise()); + if (ix < initialLimit) { + writeVerifyWriteHeaders(times(1), nextStreamId); + } else { + writeVerifyWriteHeaders(never(), nextStreamId); + } + } + assertEquals(numStreams / 2, encoder.numBufferedStreams()); + + // Simulate that we received a SETTINGS frame. 
+ setMaxConcurrentStreams(initialLimit * 2); + + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(numStreams, connection.local().numActiveStreams()); + } + + @Test + public void bufferUntilSettingsReceivedWithNoMaxConcurrentStreamValue() throws Http2Exception { + int initialLimit = SMALLEST_MAX_CONCURRENT_STREAMS; + int numStreams = initialLimit * 2; + for (int ix = 0, nextStreamId = 3; ix < numStreams; ++ix, nextStreamId += 2) { + encoderWriteHeaders(nextStreamId, newPromise()); + if (ix < initialLimit) { + writeVerifyWriteHeaders(times(1), nextStreamId); + } else { + writeVerifyWriteHeaders(never(), nextStreamId); + } + } + assertEquals(numStreams / 2, encoder.numBufferedStreams()); + + // Simulate that we received an empty SETTINGS frame. + encoder.remoteSettings(new Http2Settings()); + + assertEquals(0, encoder.numBufferedStreams()); + assertEquals(numStreams, connection.local().numActiveStreams()); + } + + @Test + public void exhaustedStreamsDoNotBuffer() throws Http2Exception { + // Write the highest possible stream ID for the client. + // This will cause the next stream ID to be negative. + encoderWriteHeaders(Integer.MAX_VALUE, newPromise()); + + // Disallow any further streams. + setMaxConcurrentStreams(0); + + // Simulate numeric overflow for the next stream ID. + ChannelFuture f = encoderWriteHeaders(-1, newPromise()); + + // Verify that the write fails. 
+ assertNotNull(f.cause()); + } + + @Test + public void closedBufferedStreamReleasesByteBuf() { + encoder.writeSettingsAck(ctx, newPromise()); + setMaxConcurrentStreams(0); + ByteBuf data = mock(ByteBuf.class); + ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); + assertEquals(1, encoder.numBufferedStreams()); + ChannelFuture f2 = encoder.writeData(ctx, 3, data, 0, false, newPromise()); + + ChannelPromise rstPromise = mock(ChannelPromise.class); + encoder.writeRstStream(ctx, 3, CANCEL.code(), rstPromise); + + assertEquals(0, encoder.numBufferedStreams()); + verify(rstPromise).setSuccess(); + assertTrue(f1.isSuccess()); + assertTrue(f2.isSuccess()); + verify(data).release(); + } + + @Test + public void closeShouldCancelAllBufferedStreams() throws Http2Exception { + encoder.writeSettingsAck(ctx, newPromise()); + connection.local().maxActiveStreams(0); + + ChannelFuture f1 = encoderWriteHeaders(3, newPromise()); + ChannelFuture f2 = encoderWriteHeaders(5, newPromise()); + ChannelFuture f3 = encoderWriteHeaders(7, newPromise()); + + encoder.close(); + assertNotNull(f1.cause()); + assertNotNull(f2.cause()); + assertNotNull(f3.cause()); + } + + @Test + public void headersAfterCloseShouldImmediatelyFail() { + encoder.writeSettingsAck(ctx, newPromise()); + encoder.close(); + + ChannelFuture f = encoderWriteHeaders(3, newPromise()); + assertNotNull(f.cause()); + } + + private void setMaxConcurrentStreams(int newValue) { + try { + encoder.remoteSettings(new Http2Settings().maxConcurrentStreams(newValue)); + // Flush the remote flow controller to write data + encoder.flowController().writePendingBytes(); + } catch (Http2Exception e) { + throw new RuntimeException(e); + } + } + + private ChannelFuture encoderWriteHeaders(int streamId, ChannelPromise promise) { + encoder.writeHeaders(ctx, streamId, new DefaultHttp2Headers(), 0, DEFAULT_PRIORITY_WEIGHT, + false, 0, false, promise); + try { + encoder.flowController().writePendingBytes(); + return promise; + } catch 
(Http2Exception e) { + throw new RuntimeException(e); + } + } + + private void writeVerifyWriteHeaders(VerificationMode mode, int streamId) { + verify(writer, mode).writeHeaders(eq(ctx), eq(streamId), any(Http2Headers.class), eq(0), + eq(DEFAULT_PRIORITY_WEIGHT), eq(false), eq(0), + eq(false), any(ChannelPromise.class)); + } + + private Answer successAnswer() { + return new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + for (Object a : invocation.getArguments()) { + ReferenceCountUtil.safeRelease(a); + } + + ChannelPromise future = newPromise(); + future.setSuccess(); + return future; + } + }; + } + + private Answer noopAnswer() { + return new Answer() { + @Override + public ChannelFuture answer(InvocationOnMock invocation) throws Throwable { + for (Object a : invocation.getArguments()) { + if (a instanceof ChannelPromise) { + return (ChannelFuture) a; + } + } + return newPromise(); + } + }; + } + + private ChannelPromise newPromise() { + return new DefaultChannelPromise(channel, ImmediateEventExecutor.INSTANCE); + } + + private static ByteBuf data() { + ByteBuf buf = Unpooled.buffer(10); + for (int i = 0; i < buf.writableBytes(); i++) { + buf.writeByte(i); + } + return buf; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java new file mode 100644 index 0000000..c85d989 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestChannelInitializer.java @@ -0,0 +1,122 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.http2; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.RecvByteBufAllocator; +import io.netty.util.UncheckedBooleanSupplier; + +import java.util.concurrent.atomic.AtomicInteger; + +/** + * Channel initializer useful in tests. + */ +@Sharable +public class TestChannelInitializer extends ChannelInitializer { + ChannelHandler handler; + AtomicInteger maxReads; + + @Override + public void initChannel(Channel channel) { + if (handler != null) { + channel.pipeline().addLast(handler); + handler = null; + } + if (maxReads != null) { + channel.config().setRecvByteBufAllocator(new TestNumReadsRecvByteBufAllocator(maxReads)); + } + } + + /** + * Designed to read a single byte at a time to control the number of reads done at a fine granularity. 
+ */ + static final class TestNumReadsRecvByteBufAllocator implements RecvByteBufAllocator { + private final AtomicInteger numReads; + private TestNumReadsRecvByteBufAllocator(AtomicInteger numReads) { + this.numReads = numReads; + } + + @Override + public ExtendedHandle newHandle() { + return new ExtendedHandle() { + private int attemptedBytesRead; + private int lastBytesRead; + private int numMessagesRead; + @Override + public ByteBuf allocate(ByteBufAllocator alloc) { + return alloc.ioBuffer(guess(), guess()); + } + + @Override + public int guess() { + return 1; // only ever allocate buffers of size 1 to ensure the number of reads is controlled. + } + + @Override + public void reset(ChannelConfig config) { + numMessagesRead = 0; + } + + @Override + public void incMessagesRead(int numMessages) { + numMessagesRead += numMessages; + } + + @Override + public void lastBytesRead(int bytes) { + lastBytesRead = bytes; + } + + @Override + public int lastBytesRead() { + return lastBytesRead; + } + + @Override + public void attemptedBytesRead(int bytes) { + attemptedBytesRead = bytes; + } + + @Override + public int attemptedBytesRead() { + return attemptedBytesRead; + } + + @Override + public boolean continueReading() { + return numMessagesRead < numReads.get(); + } + + @Override + public boolean continueReading(UncheckedBooleanSupplier maybeMoreDataSupplier) { + return continueReading(); + } + + @Override + public void readComplete() { + // Nothing needs to be done or adjusted after each read cycle is completed. 
+ } + }; + } + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestHeaderListener.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestHeaderListener.java new file mode 100644 index 0000000..1240561 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/TestHeaderListener.java @@ -0,0 +1,49 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/* + * Copyright 2014 Twitter, Inc. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package io.netty.handler.codec.http2; + +import java.util.List; + +final class TestHeaderListener extends DefaultHttp2Headers { + + private final List headers; + + TestHeaderListener(List headers) { + this.headers = headers; + } + + @Override + public TestHeaderListener add(CharSequence name, CharSequence value) { + headers.add(new HpackHeaderField(name, value)); + return this; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorFlowControllerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorFlowControllerTest.java new file mode 100644 index 0000000..d7b039d --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorFlowControllerTest.java @@ -0,0 +1,22 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
 */
package io.netty.handler.codec.http2;

/**
 * Runs the {@link DefaultHttp2RemoteFlowControllerTest} suite with a
 * {@link UniformStreamByteDistributor} supplying the byte-distribution
 * strategy under test.
 */
public class UniformStreamByteDistributorFlowControllerTest extends DefaultHttp2RemoteFlowControllerTest {
    @Override
    protected StreamByteDistributor newDistributor(Http2Connection connection) {
        return new UniformStreamByteDistributor(connection);
    }
}
diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java
new file mode 100644
index 0000000..f1e2294
--- /dev/null
+++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/UniformStreamByteDistributorTest.java
@@ -0,0 +1,283 @@
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License. You may obtain a
 * copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software distributed under the License
 * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
 * or implied. See the License for the specific language governing permissions and limitations under
 * the License.
+ */ +package io.netty.handler.codec.http2; + +import io.netty.handler.codec.http2.Http2TestUtil.TestStreamByteDistributorStreamState; +import io.netty.util.collection.IntObjectHashMap; +import io.netty.util.collection.IntObjectMap; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; +import org.mockito.verification.VerificationMode; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_MIN_ALLOCATION_CHUNK; +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doNothing; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.verifyNoMoreInteractions; + +/** + * Tests for {@link UniformStreamByteDistributor}. 
+ */ +public class UniformStreamByteDistributorTest { + private static final int CHUNK_SIZE = DEFAULT_MIN_ALLOCATION_CHUNK; + + private static final int STREAM_A = 1; + private static final int STREAM_B = 3; + private static final int STREAM_C = 5; + private static final int STREAM_D = 7; + + private Http2Connection connection; + private UniformStreamByteDistributor distributor; + private IntObjectMap stateMap; + + @Mock + private StreamByteDistributor.Writer writer; + + @BeforeEach + public void setup() throws Http2Exception { + MockitoAnnotations.initMocks(this); + + stateMap = new IntObjectHashMap(); + connection = new DefaultHttp2Connection(false); + distributor = new UniformStreamByteDistributor(connection); + + // Assume we always write all the allocated bytes. + resetWriter(); + + connection.local().createStream(STREAM_A, false); + connection.local().createStream(STREAM_B, false); + Http2Stream streamC = connection.local().createStream(STREAM_C, false); + Http2Stream streamD = connection.local().createStream(STREAM_D, false); + setPriority(streamC.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, false); + } + + private Answer writeAnswer() { + return new Answer() { + @Override + public Void answer(InvocationOnMock in) throws Throwable { + Http2Stream stream = in.getArgument(0); + int numBytes = in.getArgument(1); + TestStreamByteDistributorStreamState state = stateMap.get(stream.id()); + state.pendingBytes -= numBytes; + state.hasFrame = state.pendingBytes > 0; + distributor.updateStreamableBytes(state); + return null; + } + }; + } + + private void resetWriter() { + reset(writer); + doAnswer(writeAnswer()).when(writer).write(any(Http2Stream.class), anyInt()); + } + + @Test + public void bytesUnassignedAfterProcessing() throws Http2Exception { + initState(STREAM_A, 1, true); + initState(STREAM_B, 2, true); + initState(STREAM_C, 3, true); + initState(STREAM_D, 4, true); + + assertFalse(write(10)); 
+ verifyWrite(STREAM_A, 1); + verifyWrite(STREAM_B, 2); + verifyWrite(STREAM_C, 3); + verifyWrite(STREAM_D, 4); + verifyNoMoreInteractions(writer); + + assertFalse(write(10)); + verifyNoMoreInteractions(writer); + } + + @Test + public void connectionErrorForWriterException() throws Http2Exception { + initState(STREAM_A, 1, true); + initState(STREAM_B, 2, true); + initState(STREAM_C, 3, true); + initState(STREAM_D, 4, true); + + Exception fakeException = new RuntimeException("Fake exception"); + doThrow(fakeException).when(writer).write(same(stream(STREAM_C)), eq(3)); + + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + write(10); + } + }); + assertFalse(Http2Exception.isStreamError(e)); + assertEquals(Http2Error.INTERNAL_ERROR, e.error()); + assertSame(fakeException, e.getCause()); + + verifyWrite(atMost(1), STREAM_A, 1); + verifyWrite(atMost(1), STREAM_B, 2); + verifyWrite(STREAM_C, 3); + verifyWrite(atMost(1), STREAM_D, 4); + + doNothing().when(writer).write(same(stream(STREAM_C)), eq(3)); + write(10); + verifyWrite(STREAM_A, 1); + verifyWrite(STREAM_B, 2); + verifyWrite(STREAM_C, 3); + verifyWrite(STREAM_D, 4); + } + + /** + * In this test, we verify that each stream is allocated a minimum chunk size. When bytes + * run out, the remaining streams will be next in line for the next iteration. + */ + @Test + public void minChunkShouldBeAllocatedPerStream() throws Http2Exception { + // Re-assign weights. + setPriority(STREAM_A, 0, (short) 50, false); + setPriority(STREAM_B, 0, (short) 200, false); + setPriority(STREAM_C, STREAM_A, (short) 100, false); + setPriority(STREAM_D, STREAM_A, (short) 100, false); + + // Update the streams. + initState(STREAM_A, CHUNK_SIZE, true); + initState(STREAM_B, CHUNK_SIZE, true); + initState(STREAM_C, CHUNK_SIZE, true); + initState(STREAM_D, CHUNK_SIZE, true); + + // Only write 3 * chunkSize, so that we'll only write to the first 3 streams. 
+ int written = 3 * CHUNK_SIZE; + assertTrue(write(written)); + assertEquals(CHUNK_SIZE, captureWrite(STREAM_A)); + assertEquals(CHUNK_SIZE, captureWrite(STREAM_B)); + assertEquals(CHUNK_SIZE, captureWrite(STREAM_C)); + verifyNoMoreInteractions(writer); + + resetWriter(); + + // Now write again and verify that the last stream is written to. + assertFalse(write(CHUNK_SIZE)); + assertEquals(CHUNK_SIZE, captureWrite(STREAM_D)); + verifyNoMoreInteractions(writer); + } + + @Test + public void streamWithMoreDataShouldBeEnqueuedAfterWrite() throws Http2Exception { + // Give the stream a bunch of data. + initState(STREAM_A, 2 * CHUNK_SIZE, true); + + // Write only part of the data. + assertTrue(write(CHUNK_SIZE)); + assertEquals(CHUNK_SIZE, captureWrite(STREAM_A)); + verifyNoMoreInteractions(writer); + + resetWriter(); + + // Now write the rest of the data. + assertFalse(write(CHUNK_SIZE)); + assertEquals(CHUNK_SIZE, captureWrite(STREAM_A)); + verifyNoMoreInteractions(writer); + } + + @Test + public void emptyFrameAtHeadIsWritten() throws Http2Exception { + initState(STREAM_A, 10, true); + initState(STREAM_B, 0, true); + initState(STREAM_C, 0, true); + initState(STREAM_D, 10, true); + + assertTrue(write(10)); + verifyWrite(STREAM_A, 10); + verifyWrite(STREAM_B, 0); + verifyWrite(STREAM_C, 0); + verifyNoMoreInteractions(writer); + } + + @Test + public void streamWindowExhaustedDoesNotWrite() throws Http2Exception { + initState(STREAM_A, 0, true, false); + initState(STREAM_B, 0, true); + initState(STREAM_C, 0, true); + initState(STREAM_D, 0, true, false); + + assertFalse(write(10)); + verifyWrite(STREAM_B, 0); + verifyWrite(STREAM_C, 0); + verifyNoMoreInteractions(writer); + } + + @Test + public void streamWindowLargerThanIntDoesNotInfiniteLoop() throws Http2Exception { + initState(STREAM_A, Integer.MAX_VALUE + 1L, true, true); + assertTrue(write(Integer.MAX_VALUE)); + verifyWrite(STREAM_A, Integer.MAX_VALUE); + assertFalse(write(1)); + verifyWrite(STREAM_A, 1); + } + + 
private Http2Stream stream(int streamId) { + return connection.stream(streamId); + } + + private void initState(final int streamId, final long streamableBytes, final boolean hasFrame) { + initState(streamId, streamableBytes, hasFrame, hasFrame); + } + + private void initState(final int streamId, final long pendingBytes, final boolean hasFrame, + final boolean isWriteAllowed) { + final Http2Stream stream = stream(streamId); + TestStreamByteDistributorStreamState state = new TestStreamByteDistributorStreamState(stream, pendingBytes, + hasFrame, isWriteAllowed); + stateMap.put(streamId, state); + distributor.updateStreamableBytes(state); + } + + private void setPriority(int streamId, int parent, int weight, boolean exclusive) { + distributor.updateDependencyTree(streamId, parent, (short) weight, exclusive); + } + + private boolean write(int numBytes) throws Http2Exception { + return distributor.distribute(numBytes, writer); + } + + private void verifyWrite(int streamId, int numBytes) { + verify(writer).write(same(stream(streamId)), eq(numBytes)); + } + + private void verifyWrite(VerificationMode mode, int streamId, int numBytes) { + verify(writer, mode).write(same(stream(streamId)), eq(numBytes)); + } + + private int captureWrite(int streamId) { + ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class); + verify(writer).write(same(stream(streamId)), captor.capture()); + return captor.getValue(); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorDependencyTreeTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorDependencyTreeTest.java new file mode 100644 index 0000000..349e876 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorDependencyTreeTest.java @@ -0,0 +1,980 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you 
under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec.http2; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.mockito.MockitoAnnotations; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MAX_WEIGHT; +import static io.netty.handler.codec.http2.Http2CodecUtil.MIN_WEIGHT; +import static io.netty.handler.codec.http2.WeightedFairQueueByteDistributor.INITIAL_CHILDREN_MAP_SIZE; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.doAnswer; + +public class WeightedFairQueueByteDistributorDependencyTreeTest extends + AbstractWeightedFairQueueByteDistributorDependencyTest { + private static final int leadersId = 3; // js, css + private static final int unblockedId = 5; + private static final int backgroundId = 7; + private static final int speculativeId = 9; + private static final int followersId = 11; // images + private static final short leadersWeight = 201; + private static final short unblockedWeight = 101; + private static final short backgroundWeight = 1; + private static final short speculativeWeight = 1; + private static final short followersWeight = 1; + + 
@BeforeEach + public void setup() throws Http2Exception { + MockitoAnnotations.initMocks(this); + + setup(0); + } + + private void setup(int maxStateOnlySize) { + connection = new DefaultHttp2Connection(false); + distributor = new WeightedFairQueueByteDistributor(connection, maxStateOnlySize); + + // Assume we always write all the allocated bytes. + doAnswer(writeAnswer(false)).when(writer).write(any(Http2Stream.class), anyInt()); + } + + @Test + public void closingStreamWithChildrenDoesNotCauseConcurrentModification() throws Http2Exception { + // We create enough streams to wrap around the child array. We carefully craft the stream ids so that they hash + // codes overlap with respect to the child collection. If the implementation is not careful this may lead to a + // concurrent modification exception while promoting all children to the connection stream. + final Http2Stream streamA = connection.local().createStream(1, false); + final int numStreams = INITIAL_CHILDREN_MAP_SIZE - 1; + for (int i = 0, streamId = 3; i < numStreams; ++i, streamId += INITIAL_CHILDREN_MAP_SIZE) { + final Http2Stream stream = connection.local().createStream(streamId, false); + setPriority(stream.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + } + assertEquals(INITIAL_CHILDREN_MAP_SIZE, connection.numActiveStreams()); + streamA.close(); + assertEquals(numStreams, connection.numActiveStreams()); + } + + @Test + public void closeWhileIteratingDoesNotNPE() throws Http2Exception { + final Http2Stream streamA = connection.local().createStream(3, false); + final Http2Stream streamB = connection.local().createStream(5, false); + final Http2Stream streamC = connection.local().createStream(7, false); + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + connection.forEachActiveStream(new Http2StreamVisitor() { + @Override + public boolean visit(Http2Stream stream) throws Http2Exception { + streamA.close(); + setPriority(streamB.id(), streamC.id(), 
DEFAULT_PRIORITY_WEIGHT, false); + return true; + } + }); + } + + @Test + public void localStreamCanDependUponIdleStream() throws Http2Exception { + setup(1); + + Http2Stream streamA = connection.local().createStream(1, false); + setPriority(3, streamA.id(), MIN_WEIGHT, true); + assertTrue(distributor.isChild(3, streamA.id(), MIN_WEIGHT)); + } + + @Test + public void remoteStreamCanDependUponIdleStream() throws Http2Exception { + setup(1); + + Http2Stream streamA = connection.remote().createStream(2, false); + setPriority(4, streamA.id(), MIN_WEIGHT, true); + assertTrue(distributor.isChild(4, streamA.id(), MIN_WEIGHT)); + } + + @Test + public void prioritizeShouldUseDefaults() throws Exception { + Http2Stream stream = connection.local().createStream(1, false); + assertTrue(distributor.isChild(stream.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + assertEquals(0, distributor.numChildren(stream.id())); + } + + @Test + public void reprioritizeWithNoChangeShouldDoNothing() throws Exception { + Http2Stream stream = connection.local().createStream(1, false); + setPriority(stream.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT, false); + assertTrue(distributor.isChild(stream.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + assertEquals(0, distributor.numChildren(stream.id())); + } + + @Test + public void stateOnlyPriorityShouldBePreservedWhenStreamsAreCreatedAndClosed() throws Http2Exception { + setup(3); + + short weight3 = MIN_WEIGHT + 1; + short weight5 = (short) (weight3 + 1); + short weight7 = (short) (weight5 + 1); + setPriority(3, connection.connectionStream().id(), weight3, true); + setPriority(5, connection.connectionStream().id(), weight5, true); + setPriority(7, connection.connectionStream().id(), weight7, true); + + assertEquals(0, 
connection.numActiveStreams()); + verifyStateOnlyPriorityShouldBePreservedWhenStreamsAreCreated(weight3, weight5, weight7); + + // Now create stream objects and ensure the state and dependency tree is preserved. + Http2Stream streamA = connection.local().createStream(3, false); + Http2Stream streamB = connection.local().createStream(5, false); + Http2Stream streamC = connection.local().createStream(7, false); + + assertEquals(3, connection.numActiveStreams()); + verifyStateOnlyPriorityShouldBePreservedWhenStreamsAreCreated(weight3, weight5, weight7); + + // Close all the streams and ensure the state and dependency tree is preserved. + streamA.close(); + streamB.close(); + streamC.close(); + + assertEquals(0, connection.numActiveStreams()); + verifyStateOnlyPriorityShouldBePreservedWhenStreamsAreCreated(weight3, weight5, weight7); + } + + private void verifyStateOnlyPriorityShouldBePreservedWhenStreamsAreCreated(short weight3, short weight5, + short weight7) { + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(7, connection.connectionStream().id(), weight7)); + assertEquals(1, distributor.numChildren(7)); + + // Level 2 + assertTrue(distributor.isChild(5, 7, weight5)); + assertEquals(1, distributor.numChildren(5)); + + // Level 3 + assertTrue(distributor.isChild(3, 5, weight3)); + assertEquals(0, distributor.numChildren(3)); + } + + @Test + public void fireFoxQoSStreamsRemainAfterDataStreamsAreClosed() throws Http2Exception { + // https://bitsup.blogspot.com/2015/01/http2-dependency-priorities-in-firefox.html + setup(5); + + setPriority(leadersId, connection.connectionStream().id(), leadersWeight, false); + setPriority(unblockedId, connection.connectionStream().id(), unblockedWeight, false); + setPriority(backgroundId, connection.connectionStream().id(), backgroundWeight, false); + setPriority(speculativeId, backgroundId, speculativeWeight, false); + setPriority(followersId, 
leadersId, followersWeight, false); + + verifyFireFoxQoSStreams(); + + // Simulate a HTML request + short htmlGetStreamWeight = 2; + Http2Stream htmlGetStream = connection.local().createStream(13, false); + setPriority(htmlGetStream.id(), followersId, htmlGetStreamWeight, false); + Http2Stream favIconStream = connection.local().createStream(15, false); + setPriority(favIconStream.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT, false); + Http2Stream cssStream = connection.local().createStream(17, false); + setPriority(cssStream.id(), leadersId, DEFAULT_PRIORITY_WEIGHT, false); + Http2Stream jsStream = connection.local().createStream(19, false); + setPriority(jsStream.id(), leadersId, DEFAULT_PRIORITY_WEIGHT, false); + Http2Stream imageStream = connection.local().createStream(21, false); + setPriority(imageStream.id(), followersId, 1, false); + + // Level 0 + assertEquals(4, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(leadersId, connection.connectionStream().id(), leadersWeight)); + assertEquals(3, distributor.numChildren(leadersId)); + + assertTrue(distributor.isChild(unblockedId, connection.connectionStream().id(), unblockedWeight)); + assertEquals(0, distributor.numChildren(unblockedId)); + + assertTrue(distributor.isChild(backgroundId, connection.connectionStream().id(), backgroundWeight)); + assertEquals(1, distributor.numChildren(backgroundId)); + + assertTrue(distributor.isChild(favIconStream.id(), connection.connectionStream().id(), + DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(favIconStream.id())); + + // Level 2 + assertTrue(distributor.isChild(followersId, leadersId, followersWeight)); + assertEquals(2, distributor.numChildren(followersId)); + + assertTrue(distributor.isChild(speculativeId, backgroundId, speculativeWeight)); + assertEquals(0, distributor.numChildren(speculativeId)); + + assertTrue(distributor.isChild(cssStream.id(), leadersId, 
DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(cssStream.id())); + + assertTrue(distributor.isChild(jsStream.id(), leadersId, DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(jsStream.id())); + + // Level 3 + assertTrue(distributor.isChild(htmlGetStream.id(), followersId, htmlGetStreamWeight)); + assertEquals(0, distributor.numChildren(htmlGetStream.id())); + + assertTrue(distributor.isChild(imageStream.id(), followersId, followersWeight)); + assertEquals(0, distributor.numChildren(imageStream.id())); + + // Close all the data streams and ensure the "priority only streams" are retained in the dependency tree. + htmlGetStream.close(); + favIconStream.close(); + cssStream.close(); + jsStream.close(); + imageStream.close(); + + verifyFireFoxQoSStreams(); + } + + private void verifyFireFoxQoSStreams() { + // Level 0 + assertEquals(3, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(leadersId, connection.connectionStream().id(), leadersWeight)); + assertEquals(1, distributor.numChildren(leadersId)); + + assertTrue(distributor.isChild(unblockedId, connection.connectionStream().id(), unblockedWeight)); + assertEquals(0, distributor.numChildren(unblockedId)); + + assertTrue(distributor.isChild(backgroundId, connection.connectionStream().id(), backgroundWeight)); + assertEquals(1, distributor.numChildren(backgroundId)); + + // Level 2 + assertTrue(distributor.isChild(followersId, leadersId, followersWeight)); + assertEquals(0, distributor.numChildren(followersId)); + + assertTrue(distributor.isChild(speculativeId, backgroundId, speculativeWeight)); + assertEquals(0, distributor.numChildren(speculativeId)); + } + + @Test + public void lowestPrecedenceStateShouldBeDropped() throws Http2Exception { + setup(3); + + short weight3 = MAX_WEIGHT; + short weight5 = (short) (weight3 - 1); + short weight7 = (short) (weight5 - 1); + short weight9 = (short) (weight7 - 1); + 
setPriority(3, connection.connectionStream().id(), weight3, true); + setPriority(5, connection.connectionStream().id(), weight5, true); + setPriority(7, connection.connectionStream().id(), weight7, false); + assertEquals(0, connection.numActiveStreams()); + verifyLowestPrecedenceStateShouldBeDropped1(weight3, weight5, weight7); + + // Attempt to create a new item in the dependency tree but the maximum amount of "state only" streams is meet + // so a stream will have to be dropped. Currently the new stream is the lowest "precedence" so it is dropped. + setPriority(9, 3, weight9, false); + assertEquals(0, connection.numActiveStreams()); + verifyLowestPrecedenceStateShouldBeDropped1(weight3, weight5, weight7); + + // Set the priority for stream 9 such that its depth in the dependency tree is numerically lower than stream 3, + // and therefore the dependency state associated with stream 3 will be dropped. + setPriority(9, 5, weight9, true); + verifyLowestPrecedenceStateShouldBeDropped2(weight9, weight5, weight7); + + // Test that stream which has been activated is lower priority than other streams that have not been activated. + Http2Stream streamA = connection.local().createStream(5, false); + streamA.close(); + verifyLowestPrecedenceStateShouldBeDropped2(weight9, weight5, weight7); + + // Stream 3 (hasn't been opened) should result in stream 5 being dropped. + // dropping stream 5 will distribute its weight to children (only 9) + setPriority(3, 9, weight3, false); + verifyLowestPrecedenceStateShouldBeDropped3(weight3, weight7, weight5); + + // Stream 5's state has been discarded so we should be able to re-insert this state. + setPriority(5, 0, weight5, false); + verifyLowestPrecedenceStateShouldBeDropped4(weight5, weight7, weight5); + + // All streams are at the same level, so stream ID should be used to drop the numeric lowest valued stream. 
+ short weight11 = (short) (weight9 - 1); + setPriority(11, 0, weight11, false); + verifyLowestPrecedenceStateShouldBeDropped5(weight7, weight5, weight11); + } + + private void verifyLowestPrecedenceStateShouldBeDropped1(short weight3, short weight5, short weight7) { + // Level 0 + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(7, connection.connectionStream().id(), weight7)); + assertEquals(0, distributor.numChildren(7)); + + assertTrue(distributor.isChild(5, connection.connectionStream().id(), weight5)); + assertEquals(1, distributor.numChildren(5)); + + // Level 2 + assertTrue(distributor.isChild(3, 5, weight3)); + assertEquals(0, distributor.numChildren(3)); + } + + private void verifyLowestPrecedenceStateShouldBeDropped2(short weight9, short weight5, short weight7) { + // Level 0 + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(7, connection.connectionStream().id(), weight7)); + assertEquals(0, distributor.numChildren(7)); + + assertTrue(distributor.isChild(5, connection.connectionStream().id(), weight5)); + assertEquals(1, distributor.numChildren(5)); + + // Level 2 + assertTrue(distributor.isChild(9, 5, weight9)); + assertEquals(0, distributor.numChildren(9)); + } + + private void verifyLowestPrecedenceStateShouldBeDropped3(short weight3, short weight7, short weight9) { + // Level 0 + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(7, connection.connectionStream().id(), weight7)); + assertEquals(0, distributor.numChildren(7)); + + assertTrue(distributor.isChild(9, connection.connectionStream().id(), weight9)); + assertEquals(1, distributor.numChildren(9)); + + // Level 2 + assertTrue(distributor.isChild(3, 9, weight3)); + assertEquals(0, distributor.numChildren(3)); + } + + private void 
verifyLowestPrecedenceStateShouldBeDropped4(short weight5, short weight7, short weight9) { + // Level 0 + assertEquals(3, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(5, connection.connectionStream().id(), weight5)); + assertEquals(0, distributor.numChildren(5)); + + assertTrue(distributor.isChild(7, connection.connectionStream().id(), weight7)); + assertEquals(0, distributor.numChildren(7)); + + assertTrue(distributor.isChild(9, connection.connectionStream().id(), weight9)); + assertEquals(0, distributor.numChildren(9)); + } + + private void verifyLowestPrecedenceStateShouldBeDropped5(short weight7, short weight9, short weight11) { + // Level 0 + assertEquals(3, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(11, connection.connectionStream().id(), weight11)); + assertEquals(0, distributor.numChildren(11)); + + assertTrue(distributor.isChild(7, connection.connectionStream().id(), weight7)); + assertEquals(0, distributor.numChildren(7)); + + assertTrue(distributor.isChild(9, connection.connectionStream().id(), weight9)); + assertEquals(0, distributor.numChildren(9)); + } + + @Test + public void priorityOnlyStreamsArePreservedWhenReservedStreamsAreClosed() throws Http2Exception { + setup(1); + + short weight3 = MIN_WEIGHT; + setPriority(3, connection.connectionStream().id(), weight3, true); + + Http2Stream streamA = connection.local().createStream(5, false); + Http2Stream streamB = connection.remote().reservePushStream(4, streamA); + + // Level 0 + assertEquals(3, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(3, connection.connectionStream().id(), weight3)); + assertEquals(0, distributor.numChildren(3)); + + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamA.id())); + + 
assertTrue(distributor.isChild(streamB.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + // Close both streams. + streamB.close(); + streamA.close(); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(3, connection.connectionStream().id(), weight3)); + assertEquals(0, distributor.numChildren(3)); + } + + @Test + public void insertExclusiveShouldAddNewLevel() throws Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, true); + + assertEquals(4, connection.numActiveStreams()); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamA.id())); + + // Level 2 + assertTrue(distributor.isChild(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamD.id())); + + // Level 3 + assertTrue(distributor.isChild(streamB.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamC.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamC.id())); + } + + @Test + public void existingChildMadeExclusiveShouldNotCreateTreeCycle() throws Http2Exception { + Http2Stream streamA = 
connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Stream C is already dependent on Stream A, but now make that an exclusive dependency + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, true); + + assertEquals(4, connection.numActiveStreams()); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamA.id())); + + // Level 2 + assertTrue(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamC.id())); + + // Level 3 + assertTrue(distributor.isChild(streamB.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamD.id())); + } + + @Test + public void newExclusiveChildShouldUpdateOldParentCorrectly() throws Http2Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + Http2Stream streamE = connection.local().createStream(9, false); + Http2Stream streamF = connection.local().createStream(11, false); + + setPriority(streamB.id(), streamA.id(), 
DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamF.id(), streamE.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // F is now going to be exclusively dependent on A, after this we should check that stream E + // prioritizableForTree is not over decremented. + setPriority(streamF.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, true); + + assertEquals(6, connection.numActiveStreams()); + + // Level 0 + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamE.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamE.id())); + + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamA.id())); + + // Level 2 + assertTrue(distributor.isChild(streamF.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamF.id())); + + // Level 3 + assertTrue(distributor.isChild(streamB.id(), streamF.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamC.id(), streamF.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamC.id())); + + // Level 4 + assertTrue(distributor.isChild(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamD.id())); + } + + @Test + public void weightChangeWithNoTreeChangeShouldBeRespected() throws Http2Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + + 
setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, true); + + assertEquals(4, connection.numActiveStreams()); + + short newWeight = (short) (DEFAULT_PRIORITY_WEIGHT + 1); + setPriority(streamD.id(), streamA.id(), newWeight, false); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamA.id())); + + // Level 2 + assertTrue(distributor.isChild(streamD.id(), streamA.id(), newWeight)); + assertEquals(2, distributor.numChildren(streamD.id())); + + // Level 3 + assertTrue(distributor.isChild(streamB.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamC.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamC.id())); + } + + @Test + public void sameNodeDependentShouldNotStackOverflowNorChangePrioritizableForTree() throws Http2Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, true); + + boolean[] exclusives = { true, false }; + short[] weights = { DEFAULT_PRIORITY_WEIGHT, 100, 200, DEFAULT_PRIORITY_WEIGHT }; + + assertEquals(4, connection.numActiveStreams()); + + // The goal is to call setPriority with the same parent and vary the 
parameters + // we were at one point adding a circular depends to the tree and then throwing + // a StackOverflow due to infinite recursive operation. + for (short weight : weights) { + for (boolean exclusive : exclusives) { + setPriority(streamD.id(), streamA.id(), weight, exclusive); + + assertEquals(0, distributor.numChildren(streamB.id())); + assertEquals(0, distributor.numChildren(streamC.id())); + assertEquals(1, distributor.numChildren(streamA.id())); + assertEquals(2, distributor.numChildren(streamD.id())); + assertFalse(distributor.isChild(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertFalse(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertTrue(distributor.isChild(streamB.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertTrue(distributor.isChild(streamC.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertTrue(distributor.isChild(streamD.id(), streamA.id(), weight)); + } + } + } + + @Test + public void multipleCircularDependencyShouldUpdatePrioritizable() throws Http2Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, true); + + assertEquals(4, connection.numActiveStreams()); + + // Bring B to the root + setPriority(streamA.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, true); + + // Move all streams to be children of B + setPriority(streamC.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Move A back to the root + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, 
true); + + // Move all streams to be children of A + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(3, distributor.numChildren(streamA.id())); + + // Level 2 + assertTrue(distributor.isChild(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamC.id())); + + assertTrue(distributor.isChild(streamD.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamD.id())); + } + + @Test + public void removeWithPrioritizableDependentsShouldNotRestructureTree() throws Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Default removal policy will cause it to be removed immediately. + // Closing streamB will distribute its weight to the children (C & D) equally. 
+ streamB.close(); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamA.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamA.id())); + + // Level 2 + short halfWeight = DEFAULT_PRIORITY_WEIGHT / 2; + assertTrue(distributor.isChild(streamC.id(), streamA.id(), halfWeight)); + assertEquals(0, distributor.numChildren(streamC.id())); + + assertTrue(distributor.isChild(streamD.id(), streamA.id(), halfWeight)); + assertEquals(0, distributor.numChildren(streamD.id())); + } + + @Test + public void closeWithNoPrioritizableDependentsShouldRestructureTree() throws Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + Http2Stream streamE = connection.local().createStream(9, false); + Http2Stream streamF = connection.local().createStream(11, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Close internal nodes, leave 1 leaf node open, the only remaining stream is the one that is not closed (E). + streamA.close(); + // Closing streamB will distribute its weight to the children (C & D) equally. 
+ streamB.close(); + streamC.close(); + streamD.close(); + streamF.close(); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + short halfWeight = DEFAULT_PRIORITY_WEIGHT / 2; + assertTrue(distributor.isChild(streamE.id(), connection.connectionStream().id(), halfWeight)); + assertEquals(0, distributor.numChildren(streamE.id())); + } + + @Test + public void closeStreamWithChildrenShouldRedistributeWeightToChildren() throws Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + Http2Stream streamE = connection.local().createStream(9, false); + Http2Stream streamF = connection.local().createStream(11, false); + Http2Stream streamG = connection.local().createStream(13, false); + Http2Stream streamH = connection.local().createStream(15, false); + + setPriority(streamC.id(), streamA.id(), MAX_WEIGHT, false); + setPriority(streamD.id(), streamA.id(), MAX_WEIGHT, false); + setPriority(streamE.id(), streamA.id(), MAX_WEIGHT, false); + + setPriority(streamF.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamG.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamH.id(), streamB.id(), 2 * DEFAULT_PRIORITY_WEIGHT, false); + + streamE.close(); + // closing stream A will distribute its weight to the children (C & D) equally + streamA.close(); + // closing stream B will distribute its weight to the children (F & G & H) proportionally + streamB.close(); + // Level 0 + assertEquals(5, distributor.numChildren(connection.connectionStream().id())); + // Level 1 + short halfWeight = DEFAULT_PRIORITY_WEIGHT / 2; + assertTrue(distributor.isChild(streamC.id(), connection.connectionStream().id(), halfWeight)); + assertTrue(distributor.isChild(streamD.id(), 
connection.connectionStream().id(), halfWeight)); + + short quarterWeight = DEFAULT_PRIORITY_WEIGHT / 4; + assertTrue(distributor.isChild(streamF.id(), connection.connectionStream().id(), quarterWeight)); + assertTrue(distributor.isChild(streamG.id(), connection.connectionStream().id(), quarterWeight)); + assertTrue(distributor.isChild(streamH.id(), connection.connectionStream().id(), (short) (2 * quarterWeight))); + } + + @Test + public void priorityChangeWithNoPrioritizableDependentsShouldRestructureTree() throws Exception { + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + Http2Stream streamE = connection.local().createStream(9, false); + Http2Stream streamF = connection.local().createStream(11, false); + + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamC.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), streamB.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Leave leaf nodes open (E & F) + streamA.close(); + // Closing streamB will distribute its weight to the children (C & D) equally. + streamB.close(); + streamC.close(); + streamD.close(); + + // Move F to depend on C, even though C is closed. 
+ setPriority(streamF.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Level 0 + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + short halfWeight = DEFAULT_PRIORITY_WEIGHT / 2; + assertTrue(distributor.isChild(streamE.id(), connection.connectionStream().id(), halfWeight)); + assertEquals(0, distributor.numChildren(streamE.id())); + + assertTrue(distributor.isChild(streamF.id(), connection.connectionStream().id(), halfWeight)); + assertEquals(0, distributor.numChildren(streamF.id())); + } + + @Test + public void circularDependencyShouldRestructureTree() throws Exception { + // Using example from https://tools.ietf.org/html/rfc7540#section-5.3.3 + // Initialize all the nodes + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + Http2Stream streamE = connection.local().createStream(9, false); + Http2Stream streamF = connection.local().createStream(11, false); + + assertEquals(6, distributor.numChildren(connection.connectionStream().id())); + assertEquals(0, distributor.numChildren(streamA.id())); + assertEquals(0, distributor.numChildren(streamB.id())); + assertEquals(0, distributor.numChildren(streamC.id())); + assertEquals(0, distributor.numChildren(streamD.id())); + assertEquals(0, distributor.numChildren(streamE.id())); + assertEquals(0, distributor.numChildren(streamF.id())); + + // Build the tree + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(5, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamA.id())); + + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(4, 
distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamA.id())); + + setPriority(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(3, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamC.id())); + + setPriority(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamC.id())); + + setPriority(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamD.id())); + + assertEquals(6, connection.numActiveStreams()); + + // Non-exclusive re-prioritization of a->d. 
+ setPriority(streamA.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT, false); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamD.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamD.id())); + + // Level 2 + assertTrue(distributor.isChild(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamF.id())); + + assertTrue(distributor.isChild(streamA.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamA.id())); + + // Level 3 + assertTrue(distributor.isChild(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamC.id())); + + // Level 4 + assertTrue(distributor.isChild(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamE.id())); + } + + @Test + public void circularDependencyWithExclusiveShouldRestructureTree() throws Exception { + // Using example from https://tools.ietf.org/html/rfc7540#section-5.3.3 + // Initialize all the nodes + Http2Stream streamA = connection.local().createStream(1, false); + Http2Stream streamB = connection.local().createStream(3, false); + Http2Stream streamC = connection.local().createStream(5, false); + Http2Stream streamD = connection.local().createStream(7, false); + Http2Stream streamE = connection.local().createStream(9, false); + Http2Stream streamF = connection.local().createStream(11, false); + + assertEquals(6, distributor.numChildren(connection.connectionStream().id())); + assertEquals(0, distributor.numChildren(streamA.id())); + assertEquals(0, distributor.numChildren(streamB.id())); + assertEquals(0, distributor.numChildren(streamC.id())); + 
assertEquals(0, distributor.numChildren(streamD.id())); + assertEquals(0, distributor.numChildren(streamE.id())); + assertEquals(0, distributor.numChildren(streamF.id())); + + // Build the tree + setPriority(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(5, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamA.id())); + + setPriority(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(4, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamA.id())); + + setPriority(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(3, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamD.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamC.id())); + + setPriority(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(2, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(2, distributor.numChildren(streamC.id())); + + setPriority(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT, false); + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + assertTrue(distributor.isChild(streamF.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamD.id())); + + assertEquals(6, connection.numActiveStreams()); + + // Exclusive re-prioritization of a->d. 
+ setPriority(streamA.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT, true); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamD.id(), connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamD.id())); + + // Level 2 + assertTrue(distributor.isChild(streamA.id(), streamD.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(3, distributor.numChildren(streamA.id())); + + // Level 3 + assertTrue(distributor.isChild(streamB.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + + assertTrue(distributor.isChild(streamF.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamF.id())); + + assertTrue(distributor.isChild(streamC.id(), streamA.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamC.id())); + + // Level 4; + assertTrue(distributor.isChild(streamE.id(), streamC.id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamE.id())); + } + + // Unknown parent streams can come about in two ways: + // 1. Because the stream is old and its state was purged + // 2. This is the first reference to the stream, as implied at least by RFC7540§5.3.1: + // > A dependency on a stream that is not currently in the tree — such as a stream in the + // > "idle" state — results in that stream being given a default priority + @Test + public void unknownParentShouldBeCreatedUnderConnection() throws Exception { + setup(5); + + // Purposefully avoid creating streamA's Http2Stream so that is it completely unknown. 
+ // It shouldn't matter whether the ID is before or after streamB.id() + int streamAId = 1; + Http2Stream streamB = connection.local().createStream(3, false); + + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + assertEquals(0, distributor.numChildren(streamB.id())); + + // Build the tree + setPriority(streamB.id(), streamAId, DEFAULT_PRIORITY_WEIGHT, false); + + assertEquals(1, connection.numActiveStreams()); + + // Level 0 + assertEquals(1, distributor.numChildren(connection.connectionStream().id())); + + // Level 1 + assertTrue(distributor.isChild(streamAId, connection.connectionStream().id(), DEFAULT_PRIORITY_WEIGHT)); + assertEquals(1, distributor.numChildren(streamAId)); + + // Level 2 + assertTrue(distributor.isChild(streamB.id(), streamAId, DEFAULT_PRIORITY_WEIGHT)); + assertEquals(0, distributor.numChildren(streamB.id())); + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java new file mode 100644 index 0000000..714fb2c --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueByteDistributorTest.java @@ -0,0 +1,964 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.ArgumentCaptor; +import org.mockito.MockitoAnnotations; +import org.mockito.verification.VerificationMode; + +import static io.netty.handler.codec.http2.Http2CodecUtil.DEFAULT_PRIORITY_WEIGHT; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.anyInt; +import static org.mockito.Mockito.atLeastOnce; +import static org.mockito.Mockito.atMost; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.reset; +import static org.mockito.Mockito.same; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; + +public class WeightedFairQueueByteDistributorTest extends AbstractWeightedFairQueueByteDistributorDependencyTest { + private static final int STREAM_A = 1; + private static final int STREAM_B = 3; + private static final int STREAM_C = 5; + private static final int STREAM_D = 7; + private static final int STREAM_E = 9; + private static final int ALLOCATION_QUANTUM = 100; + + @BeforeEach + public void setup() throws Http2Exception { + MockitoAnnotations.initMocks(this); + + // Assume we always write all the allocated bytes. + doAnswer(writeAnswer(false)).when(writer).write(any(Http2Stream.class), anyInt()); + + setup(-1); + } + + private void setup(int maxStateOnlySize) throws Http2Exception { + connection = new DefaultHttp2Connection(false); + distributor = maxStateOnlySize >= 0 ? 
new WeightedFairQueueByteDistributor(connection, maxStateOnlySize) + : new WeightedFairQueueByteDistributor(connection); + distributor.allocationQuantum(ALLOCATION_QUANTUM); + + connection.local().createStream(STREAM_A, false); + connection.local().createStream(STREAM_B, false); + Http2Stream streamC = connection.local().createStream(STREAM_C, false); + Http2Stream streamD = connection.local().createStream(STREAM_D, false); + setPriority(streamC.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, false); + setPriority(streamD.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, false); + } + + /** + * In this test, we block B such that it has no frames. We distribute enough bytes for all streams and stream B + * should be preserved in the priority queue structure until it has no "active" children, but it should not be + * doubly added to stream 0. + * + *
+     *         0
+     *         |
+     *         A
+     *         |
+     *        [B]
+     *         |
+     *         C
+     *         |
+     *         D
+     * 
+ * + * After the write: + *
+     *         0
+     * 
+ */ + @Test + public void writeWithNonActiveStreamShouldNotDobuleAddToPriorityQueue() throws Http2Exception { + initState(STREAM_A, 400, true); + initState(STREAM_B, 500, true); + initState(STREAM_C, 600, true); + initState(STREAM_D, 700, true); + + setPriority(STREAM_B, STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + setPriority(STREAM_D, STREAM_C, DEFAULT_PRIORITY_WEIGHT, true); + + // Block B, but it should still remain in the queue/tree structure. + initState(STREAM_B, 0, false); + + // Get the streams before the write, because they may be be closed. + Http2Stream streamA = stream(STREAM_A); + Http2Stream streamB = stream(STREAM_B); + Http2Stream streamC = stream(STREAM_C); + Http2Stream streamD = stream(STREAM_D); + + reset(writer); + doAnswer(writeAnswer(true)).when(writer).write(any(Http2Stream.class), anyInt()); + + assertFalse(write(400 + 600 + 700)); + assertEquals(400, captureWrites(streamA)); + verifyNeverWrite(streamB); + assertEquals(600, captureWrites(streamC)); + assertEquals(700, captureWrites(streamD)); + } + + @Test + public void bytesUnassignedAfterProcessing() throws Http2Exception { + initState(STREAM_A, 1, true); + initState(STREAM_B, 2, true); + initState(STREAM_C, 3, true); + initState(STREAM_D, 4, true); + + assertFalse(write(10)); + verifyWrite(STREAM_A, 1); + verifyWrite(STREAM_B, 2); + verifyWrite(STREAM_C, 3); + verifyWrite(STREAM_D, 4); + + assertFalse(write(10)); + verifyAnyWrite(STREAM_A, 1); + verifyAnyWrite(STREAM_B, 1); + verifyAnyWrite(STREAM_C, 1); + verifyAnyWrite(STREAM_D, 1); + } + + @Test + public void connectionErrorForWriterException() throws Http2Exception { + initState(STREAM_A, 1, true); + initState(STREAM_B, 2, true); + initState(STREAM_C, 3, true); + initState(STREAM_D, 4, true); + + Exception fakeException = new RuntimeException("Fake exception"); + doThrow(fakeException).when(writer).write(same(stream(STREAM_C)), eq(3)); + + Http2Exception e = assertThrows(Http2Exception.class, new Executable() { + @Override + public 
void execute() throws Throwable { + write(10); + } + }); + assertFalse(Http2Exception.isStreamError(e)); + assertEquals(Http2Error.INTERNAL_ERROR, e.error()); + assertSame(fakeException, e.getCause()); + + verifyWrite(atMost(1), STREAM_A, 1); + verifyWrite(atMost(1), STREAM_B, 2); + verifyWrite(STREAM_C, 3); + verifyWrite(atMost(1), STREAM_D, 4); + + doAnswer(writeAnswer(false)).when(writer).write(same(stream(STREAM_C)), eq(3)); + assertFalse(write(10)); + verifyWrite(STREAM_A, 1); + verifyWrite(STREAM_B, 2); + verifyWrite(times(2), STREAM_C, 3); + verifyWrite(STREAM_D, 4); + } + + /** + * In this test, we verify that each stream is allocated a minimum chunk size. When bytes + * run out, the remaining streams will be next in line for the next iteration. + */ + @Test + public void minChunkShouldBeAllocatedPerStream() throws Http2Exception { + // Re-assign weights. + setPriority(STREAM_A, 0, (short) 50, false); + setPriority(STREAM_B, 0, (short) 200, false); + setPriority(STREAM_C, STREAM_A, (short) 100, false); + setPriority(STREAM_D, STREAM_A, (short) 100, false); + + // Update the streams. + initState(STREAM_A, ALLOCATION_QUANTUM, true); + initState(STREAM_B, ALLOCATION_QUANTUM, true); + initState(STREAM_C, ALLOCATION_QUANTUM, true); + initState(STREAM_D, ALLOCATION_QUANTUM, true); + + // Only write 3 * chunkSize, so that we'll only write to the first 3 streams. + int written = 3 * ALLOCATION_QUANTUM; + assertTrue(write(written)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_A)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_B)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_C)); + verifyWrite(atMost(1), STREAM_D, 0); + + // Now write again and verify that the last stream is written to. 
+ assertFalse(write(ALLOCATION_QUANTUM)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_A)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_B)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_C)); + assertEquals(ALLOCATION_QUANTUM, captureWrites(STREAM_D)); + } + + /** + * In this test, we verify that the highest priority frame which has 0 bytes to send, but an empty frame is able + * to send that empty frame. + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the tree shift: + * + *
+     *         0
+     *         |
+     *         A
+     *         |
+     *         B
+     *        / \
+     *       C   D
+     * 
+ */ + @Test + public void emptyFrameAtHeadIsWritten() throws Http2Exception { + initState(STREAM_A, 0, true); + initState(STREAM_B, 0, true); + initState(STREAM_C, 0, true); + initState(STREAM_D, 10, true); + + setPriority(STREAM_B, STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + + assertFalse(write(10)); + verifyWrite(STREAM_A, 0); + verifyWrite(STREAM_B, 0); + verifyWrite(STREAM_C, 0); + verifyWrite(STREAM_D, 10); + } + + /** + * In this test, we block A which allows bytes to be written by C and D. Here's a view of the tree (stream A is + * blocked). + * + *
+     *         0
+     *        / \
+     *      [A]  B
+     *      / \
+     *     C   D
+     * 
+ */ + @Test + public void blockedStreamNoDataShouldSpreadDataToChildren() throws Http2Exception { + blockedStreamShouldSpreadDataToChildren(false); + } + + /** + * In this test, we block A and also give it an empty data frame to send. + * All bytes should be delegated to by C and D. Here's a view of the tree (stream A is blocked). + * + *
+     *           0
+     *         /   \
+     *      [A](0)  B
+     *      / \
+     *     C   D
+     * 
+ */ + @Test + public void blockedStreamWithDataAndNotAllowedToSendShouldSpreadDataToChildren() throws Http2Exception { + // A cannot stream. + initState(STREAM_A, 0, true, false); + blockedStreamShouldSpreadDataToChildren(false); + } + + /** + * In this test, we allow A to send, but expect the flow controller will only write to the stream 1 time. + * This is because we give the stream a chance to write its empty frame 1 time, and the stream will not + * be written to again until a update stream is called. + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ */ + @Test + public void streamWithZeroFlowControlWindowAndDataShouldWriteOnlyOnce() throws Http2Exception { + initState(STREAM_A, 0, true, true); + blockedStreamShouldSpreadDataToChildren(true); + + // Make sure if we call update stream again, A should write 1 more time. + initState(STREAM_A, 0, true, true); + assertFalse(write(1)); + verifyWrite(times(2), STREAM_A, 0); + + // Try to write again, but since no initState A should not write again + assertFalse(write(1)); + verifyWrite(times(2), STREAM_A, 0); + } + + private void blockedStreamShouldSpreadDataToChildren(boolean streamAShouldWriteZero) throws Http2Exception { + initState(STREAM_B, 10, true); + initState(STREAM_C, 10, true); + initState(STREAM_D, 10, true); + + // Write up to 10 bytes. + assertTrue(write(10)); + + if (streamAShouldWriteZero) { + verifyWrite(STREAM_A, 0); + } else { + verifyNeverWrite(STREAM_A); + } + verifyWrite(atMost(1), STREAM_C, 0); + verifyWrite(atMost(1), STREAM_D, 0); + + // B is entirely written + verifyWrite(STREAM_B, 10); + + // Now test that writes get delegated from A (which is blocked) to its children + assertTrue(write(5)); + if (streamAShouldWriteZero) { + verifyWrite(times(1), STREAM_A, 0); + } else { + verifyNeverWrite(STREAM_A); + } + verifyWrite(STREAM_D, 5); + verifyWrite(atMost(1), STREAM_C, 0); + + assertTrue(write(5)); + if (streamAShouldWriteZero) { + verifyWrite(times(1), STREAM_A, 0); + } else { + verifyNeverWrite(STREAM_A); + } + assertEquals(10, captureWrites(STREAM_C) + captureWrites(STREAM_D)); + + assertTrue(write(5)); + assertFalse(write(5)); + if (streamAShouldWriteZero) { + verifyWrite(times(1), STREAM_A, 0); + } else { + verifyNeverWrite(STREAM_A); + } + verifyWrite(times(2), STREAM_C, 5); + verifyWrite(times(2), STREAM_D, 5); + } + + /** + * In this test, we block B which allows all bytes to be written by A. A should not share the data with its children + * since it's not blocked. + * + *
+     *         0
+     *        / \
+     *       A  [B]
+     *      / \
+     *     C   D
+     * 
+ */ + @Test + public void childrenShouldNotSendDataUntilParentBlocked() throws Http2Exception { + // B cannot stream. + initState(STREAM_A, 10, true); + initState(STREAM_C, 10, true); + initState(STREAM_D, 10, true); + + // Write up to 10 bytes. + assertTrue(write(10)); + + // A is assigned all of the bytes. + verifyWrite(STREAM_A, 10); + verifyNeverWrite(STREAM_B); + verifyWrite(atMost(1), STREAM_C, 0); + verifyWrite(atMost(1), STREAM_D, 0); + } + + /** + * In this test, we block B which allows all bytes to be written by A. Once A is complete, it will spill over the + * remaining of its portion to its children. + * + *
+     *         0
+     *        / \
+     *       A  [B]
+     *      / \
+     *     C   D
+     * 
+ */ + @Test + public void parentShouldWaterFallDataToChildren() throws Http2Exception { + // B cannot stream. + initState(STREAM_A, 5, true); + initState(STREAM_C, 10, true); + initState(STREAM_D, 10, true); + + // Write up to 10 bytes. + assertTrue(write(10)); + + verifyWrite(STREAM_A, 5); + verifyNeverWrite(STREAM_B); + verifyWrite(STREAM_C, 5); + verifyNeverWrite(STREAM_D); + + assertFalse(write(15)); + verifyAnyWrite(STREAM_A, 1); + verifyNeverWrite(STREAM_B); + verifyWrite(times(2), STREAM_C, 5); + verifyWrite(STREAM_D, 10); + } + + /** + * In this test, we verify re-prioritizing a stream. We start out with B blocked: + * + *
+     *         0
+     *        / \
+     *       A  [B]
+     *      / \
+     *     C   D
+     * 
+ * + * We then re-prioritize D so that it's directly off of the connection and verify that A and D split the written + * bytes between them. + * + *
+     *           0
+     *          /|\
+     *        /  |  \
+     *       A  [B]  D
+     *      /
+     *     C
+     * 
+ */ + @Test + public void reprioritizeShouldAdjustOutboundFlow() throws Http2Exception { + // B cannot stream. + initState(STREAM_A, 10, true); + initState(STREAM_C, 10, true); + initState(STREAM_D, 10, true); + + // Re-prioritize D as a direct child of the connection. + setPriority(STREAM_D, 0, DEFAULT_PRIORITY_WEIGHT, false); + + assertTrue(write(10)); + + verifyWrite(STREAM_A, 10); + verifyNeverWrite(STREAM_B); + verifyNeverWrite(STREAM_C); + verifyWrite(atMost(1), STREAM_D, 0); + + assertFalse(write(20)); + verifyAnyWrite(STREAM_A, 1); + verifyNeverWrite(STREAM_B); + verifyWrite(STREAM_C, 10); + verifyWrite(STREAM_D, 10); + } + + /** + * Test that the maximum allowed amount the flow controller allows to be sent is always fully allocated if + * the streams have at least this much data to send. See https://github.com/netty/netty/issues/4266. + *
+     *            0
+     *          / | \
+     *        /   |   \
+     *      A(0) B(0) C(0)
+     *     /
+     *    D(> allowed to send in 1 allocation attempt)
+     * 
+ */ + @Test + public void unstreamableParentsShouldFeedHungryChildren() throws Http2Exception { + // Setup the priority tree. + setPriority(STREAM_A, 0, (short) 32, false); + setPriority(STREAM_B, 0, (short) 16, false); + setPriority(STREAM_C, 0, (short) 16, false); + setPriority(STREAM_D, STREAM_A, (short) 16, false); + + final int writableBytes = 100; + + // Send enough so it can not be completely written out + final int expectedUnsentAmount = 1; + initState(STREAM_D, writableBytes + expectedUnsentAmount, true); + + assertTrue(write(writableBytes)); + verifyWrite(STREAM_D, writableBytes); + + assertFalse(write(expectedUnsentAmount)); + verifyWrite(STREAM_D, expectedUnsentAmount); + } + + /** + * In this test, we root all streams at the connection, and then verify that data is split appropriately based on + * weight (all available data is the same). + * + *
+     *           0
+     *        / / \ \
+     *       A B   C D
+     * 
+ */ + @Test + public void writeShouldPreferHighestWeight() throws Http2Exception { + // Root the streams at the connection and assign weights. + setPriority(STREAM_A, 0, (short) 50, false); + setPriority(STREAM_B, 0, (short) 200, false); + setPriority(STREAM_C, 0, (short) 100, false); + setPriority(STREAM_D, 0, (short) 100, false); + + initState(STREAM_A, 1000, true); + initState(STREAM_B, 1000, true); + initState(STREAM_C, 1000, true); + initState(STREAM_D, 1000, true); + + // Set allocation quantum to 1 so it is easier to see the ratio of total bytes written between each stream. + distributor.allocationQuantum(1); + assertTrue(write(1000)); + + assertEquals(100, captureWrites(STREAM_A)); + assertEquals(450, captureWrites(STREAM_B)); + assertEquals(225, captureWrites(STREAM_C)); + assertEquals(225, captureWrites(STREAM_D)); + } + + /** + * In this test, we root all streams at the connection, block streams C and D, and then verify that data is + * prioritized toward stream B which has a higher weight than stream A. + *

+ * We also verify that the amount that is written is not uniform, and not always the allocation quantum. + * + *

+     *            0
+     *        / /  \  \
+     *       A B   [C] [D]
+     * 
+ */ + @Test + public void writeShouldFavorPriority() throws Http2Exception { + // Root the streams at the connection and assign weights. + setPriority(STREAM_A, 0, (short) 50, false); + setPriority(STREAM_B, 0, (short) 200, false); + setPriority(STREAM_C, 0, (short) 100, false); + setPriority(STREAM_D, 0, (short) 100, false); + + initState(STREAM_A, 1000, true); + initState(STREAM_B, 1000, true); + initState(STREAM_C, 1000, false); + initState(STREAM_D, 1000, false); + + // Set allocation quantum to 1 so it is easier to see the ratio of total bytes written between each stream. + distributor.allocationQuantum(1); + + assertTrue(write(100)); + assertEquals(20, captureWrites(STREAM_A)); + verifyWrite(times(20), STREAM_A, 1); + assertEquals(80, captureWrites(STREAM_B)); + verifyWrite(times(0), STREAM_B, 1); + verifyNeverWrite(STREAM_C); + verifyNeverWrite(STREAM_D); + + assertTrue(write(100)); + assertEquals(40, captureWrites(STREAM_A)); + verifyWrite(times(40), STREAM_A, 1); + assertEquals(160, captureWrites(STREAM_B)); + verifyWrite(atMost(1), STREAM_B, 1); + verifyNeverWrite(STREAM_C); + verifyNeverWrite(STREAM_D); + + assertTrue(write(1050)); + assertEquals(250, captureWrites(STREAM_A)); + verifyWrite(times(250), STREAM_A, 1); + assertEquals(1000, captureWrites(STREAM_B)); + verifyWrite(atMost(2), STREAM_B, 1); + verifyNeverWrite(STREAM_C); + verifyNeverWrite(STREAM_D); + + assertFalse(write(750)); + assertEquals(1000, captureWrites(STREAM_A)); + verifyWrite(times(1), STREAM_A, 750); + assertEquals(1000, captureWrites(STREAM_B)); + verifyWrite(times(0), STREAM_B, 0); + verifyNeverWrite(STREAM_C); + verifyNeverWrite(STREAM_D); + } + + /** + * In this test, we root all streams at the connection, and then verify that data is split equally among the stream, + * since they all have the same weight. + * + *
+     *           0
+     *        / / \ \
+     *       A B   C D
+     * 
+ */ + @Test + public void samePriorityShouldDistributeBasedOnData() throws Http2Exception { + // Root the streams at the connection with the same weights. + setPriority(STREAM_A, 0, DEFAULT_PRIORITY_WEIGHT, false); + setPriority(STREAM_B, 0, DEFAULT_PRIORITY_WEIGHT, false); + setPriority(STREAM_C, 0, DEFAULT_PRIORITY_WEIGHT, false); + setPriority(STREAM_D, 0, DEFAULT_PRIORITY_WEIGHT, false); + + initState(STREAM_A, 400, true); + initState(STREAM_B, 500, true); + initState(STREAM_C, 0, true); + initState(STREAM_D, 700, true); + + // Set allocation quantum to 1 so it is easier to see the ratio of total bytes written between each stream. + distributor.allocationQuantum(1); + assertTrue(write(999)); + + assertEquals(333, captureWrites(STREAM_A)); + assertEquals(333, captureWrites(STREAM_B)); + verifyWrite(times(1), STREAM_C, 0); + assertEquals(333, captureWrites(STREAM_D)); + } + + /** + * In this test, we call distribute with 0 bytes and verify that all streams with 0 bytes are written. + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the tree shift: + * + *
+     *         0
+     *         |
+     *        [A]
+     *         |
+     *         B
+     *        / \
+     *       C   D
+     * 
+ */ + @Test + public void zeroDistributeShouldWriteAllZeroFrames() throws Http2Exception { + initState(STREAM_A, 400, false); + initState(STREAM_B, 0, true); + initState(STREAM_C, 0, true); + initState(STREAM_D, 0, true); + + setPriority(STREAM_B, STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + + assertFalse(write(0)); + verifyNeverWrite(STREAM_A); + verifyWrite(STREAM_B, 0); + verifyAnyWrite(STREAM_B, 1); + verifyWrite(STREAM_C, 0); + verifyAnyWrite(STREAM_C, 1); + verifyWrite(STREAM_D, 0); + verifyAnyWrite(STREAM_D, 1); + } + + /** + * In this test, we call distribute with 100 bytes which is the total amount eligible to be written, and also have + * streams with 0 bytes to write. All of these streams should be written with a single call to distribute. + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the tree shift: + * + *
+     *         0
+     *         |
+     *        [A]
+     *         |
+     *         B
+     *        / \
+     *       C   D
+     * 
+ */ + @Test + public void nonZeroDistributeShouldWriteAllZeroFramesIfAllEligibleDataIsWritten() throws Http2Exception { + initState(STREAM_A, 400, false); + initState(STREAM_B, 100, true); + initState(STREAM_C, 0, true); + initState(STREAM_D, 0, true); + + setPriority(STREAM_B, STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + + assertFalse(write(100)); + verifyNeverWrite(STREAM_A); + verifyWrite(STREAM_B, 100); + verifyAnyWrite(STREAM_B, 1); + verifyWrite(STREAM_C, 0); + verifyAnyWrite(STREAM_C, 1); + verifyWrite(STREAM_D, 0); + verifyAnyWrite(STREAM_D, 1); + } + + /** + * In this test, we shift the priority tree and verify priority bytes for each subtree are correct + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the tree shift: + * + *
+     *         0
+     *         |
+     *         A
+     *         |
+     *         B
+     *        / \
+     *       C   D
+     * 
+ */ + @Test + public void bytesDistributedWithRestructureShouldBeCorrect() throws Http2Exception { + initState(STREAM_A, 400, true); + initState(STREAM_B, 500, true); + initState(STREAM_C, 600, true); + initState(STREAM_D, 700, true); + + setPriority(STREAM_B, STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + + assertTrue(write(500)); + assertEquals(400, captureWrites(STREAM_A)); + verifyWrite(STREAM_B, 100); + verifyNeverWrite(STREAM_C); + verifyNeverWrite(STREAM_D); + + assertTrue(write(400)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + verifyWrite(atMost(1), STREAM_C, 0); + verifyWrite(atMost(1), STREAM_D, 0); + + assertFalse(write(1300)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + assertEquals(600, captureWrites(STREAM_C)); + assertEquals(700, captureWrites(STREAM_D)); + } + + /** + * In this test, we add a node to the priority tree and verify + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the tree shift: + * + *
+     *         0
+     *        / \
+     *       A   B
+     *       |
+     *       E
+     *      / \
+     *     C   D
+     * 
+ */ + @Test + public void bytesDistributedWithAdditionShouldBeCorrect() throws Http2Exception { + Http2Stream streamE = connection.local().createStream(STREAM_E, false); + setPriority(streamE.id(), STREAM_A, DEFAULT_PRIORITY_WEIGHT, true); + + // Send a bunch of data on each stream. + initState(STREAM_A, 400, true); + initState(STREAM_B, 500, true); + initState(STREAM_C, 600, true); + initState(STREAM_D, 700, true); + initState(STREAM_E, 900, true); + + assertTrue(write(900)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + verifyNeverWrite(STREAM_C); + verifyNeverWrite(STREAM_D); + verifyWrite(atMost(1), STREAM_E, 0); + + assertTrue(write(900)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + verifyWrite(atMost(1), STREAM_C, 0); + verifyWrite(atMost(1), STREAM_D, 0); + assertEquals(900, captureWrites(STREAM_E)); + + assertFalse(write(1301)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + assertEquals(600, captureWrites(STREAM_C)); + assertEquals(700, captureWrites(STREAM_D)); + assertEquals(900, captureWrites(STREAM_E)); + } + + /** + * In this test, we close an internal stream in the priority tree. + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the close: + *
+     *          0
+     *        / | \
+     *       C  D  B
+     * 
+ */ + @Test + public void bytesDistributedShouldBeCorrectWithInternalStreamClose() throws Http2Exception { + initState(STREAM_A, 400, true); + initState(STREAM_B, 500, true); + initState(STREAM_C, 600, true); + initState(STREAM_D, 700, true); + + stream(STREAM_A).close(); + + assertTrue(write(500)); + verifyNeverWrite(STREAM_A); + assertEquals(500, captureWrites(STREAM_B) + captureWrites(STREAM_C) + captureWrites(STREAM_D)); + + assertFalse(write(1300)); + verifyNeverWrite(STREAM_A); + assertEquals(500, captureWrites(STREAM_B)); + assertEquals(600, captureWrites(STREAM_C)); + assertEquals(700, captureWrites(STREAM_D)); + } + + /** + * In this test, we close a leaf stream in the priority tree and verify distribution. + * + *
+     *         0
+     *        / \
+     *       A   B
+     *      / \
+     *     C   D
+     * 
+ * + * After the close: + *
+     *         0
+     *        / \
+     *       A   B
+     *       |
+     *       D
+     * 
+ */ + @Test + public void bytesDistributedShouldBeCorrectWithLeafStreamClose() throws Http2Exception { + initState(STREAM_A, 400, true); + initState(STREAM_B, 500, true); + initState(STREAM_C, 600, true); + initState(STREAM_D, 700, true); + + stream(STREAM_C).close(); + + assertTrue(write(900)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + verifyNeverWrite(STREAM_C); + verifyWrite(atMost(1), STREAM_D, 0); + + assertFalse(write(700)); + assertEquals(400, captureWrites(STREAM_A)); + assertEquals(500, captureWrites(STREAM_B)); + verifyNeverWrite(STREAM_C); + assertEquals(700, captureWrites(STREAM_D)); + } + + @Test + public void activeStreamDependentOnNewNonActiveStreamGetsQuantum() throws Http2Exception { + setup(0); + initState(STREAM_D, 700, true); + setPriority(STREAM_D, STREAM_E, DEFAULT_PRIORITY_WEIGHT, true); + + assertFalse(write(700)); + assertEquals(700, captureWrites(STREAM_D)); + } + + @Test + public void streamWindowLargerThanIntDoesNotInfiniteLoop() throws Http2Exception { + initState(STREAM_A, Integer.MAX_VALUE + 1L, true, true); + assertTrue(write(Integer.MAX_VALUE)); + verifyWrite(STREAM_A, Integer.MAX_VALUE); + assertFalse(write(1)); + verifyWrite(STREAM_A, 1); + } + + private boolean write(int numBytes) throws Http2Exception { + return distributor.distribute(numBytes, writer); + } + + private void verifyWrite(int streamId, int numBytes) { + verify(writer).write(same(stream(streamId)), eq(numBytes)); + } + + private void verifyWrite(VerificationMode mode, int streamId, int numBytes) { + verify(writer, mode).write(same(stream(streamId)), eq(numBytes)); + } + + private void verifyAnyWrite(int streamId, int times) { + verify(writer, times(times)).write(same(stream(streamId)), anyInt()); + } + + private void verifyNeverWrite(int streamId) { + verifyNeverWrite(stream(streamId)); + } + + private void verifyNeverWrite(Http2Stream stream) { + verify(writer, never()).write(same(stream), anyInt()); + } + + 
private int captureWrites(int streamId) { + return captureWrites(stream(streamId)); + } + + private int captureWrites(Http2Stream stream) { + ArgumentCaptor captor = ArgumentCaptor.forClass(Integer.class); + verify(writer, atLeastOnce()).write(same(stream), captor.capture()); + int total = 0; + for (Integer x : captor.getAllValues()) { + total += x; + } + return total; + } +} diff --git a/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueRemoteFlowControllerTest.java b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueRemoteFlowControllerTest.java new file mode 100644 index 0000000..f65bd57 --- /dev/null +++ b/netty-handler-codec-http2/src/test/java/io/netty/handler/codec/http2/WeightedFairQueueRemoteFlowControllerTest.java @@ -0,0 +1,22 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec.http2; + +public class WeightedFairQueueRemoteFlowControllerTest extends DefaultHttp2RemoteFlowControllerTest { + @Override + protected StreamByteDistributor newDistributor(Http2Connection connection) { + return new WeightedFairQueueByteDistributor(connection); + } +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testDuplicateHeaders.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testDuplicateHeaders.json new file mode 100644 index 0000000..8e07eff --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testDuplicateHeaders.json @@ -0,0 +1,66 @@ +{ + "header_blocks": + [ + { + "headers": [ + { ":path": "/somepath" }, + { "x-custom": "val" } + ], + "encoded": [ + "4487 6107 a4b5 8d33 ff40 86f2 b12d 424f", + "4f03 7661 6c" + ], + "dynamic_table": [ + { "x-custom": "val" }, + { ":path": "/somepath" } + ], + "table_size": 89 + }, + { + "headers": [ + { ":path": "/somepath" }, + { "x-custom": "val" }, + { "x-custom": "val" } + ], + "encoded": [ + "bfbe be" + ], + "dynamic_table": [ + { "x-custom": "val" }, + { ":path": "/somepath" } + ], + "table_size": 89 + }, + { + "headers": [ + { ":path": "/somepath" }, + { "x-custom": "val" }, + { "foo": "bar" }, + { "x-custom": "val" } + ], + "encoded": [ + "bfbe 4082 94e7 0362 6172 bf" + ], + "dynamic_table": [ + { "foo": "bar" }, + { "x-custom": "val" }, + { ":path": "/somepath" } + ], + "table_size": 127 + }, + { + "headers": [ + { ":path": "/somepath" } + ], + "encoded": [ + "c0" + ], + "dynamic_table": [ + { "foo": "bar" }, + { "x-custom": "val" }, + { ":path": "/somepath" } + ], + "table_size": 127 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEmpty.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEmpty.json new file mode 100644 index 
0000000..19b46dd --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEmpty.json @@ -0,0 +1,14 @@ +{ + "header_blocks": + [ + { + "headers": [ + ], + "encoded": [ + ], + "dynamic_table": [ + ], + "table_size": 0 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEviction.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEviction.json new file mode 100644 index 0000000..fb5ddcf --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testEviction.json @@ -0,0 +1,57 @@ +{ + "max_header_table_size": 128, + "header_blocks": + [ + { + "headers": [ + { ":path": "/somepath" }, + { "x-custom": "val1" }, + { "x-custom": "val2" }, + { "x-custom": "val3" } + ], + "encoded": [ + "4487 6107 a4b5 8d33 ff40 86f2 b12d 424f", + "4f83 ee3a 037e 83ee 3a05 7e83 ee3a 19" + ], + "dynamic_table": [ + { "x-custom": "val3" }, + { "x-custom": "val2" } + ], + "table_size": 88 + }, + { + "headers": [ + { ":path": "/somepath" }, + { "x-custom": "val4" }, + { "x-custom": "val5" }, + { "x-custom": "val6" } + ], + "encoded": [ + "4487 6107 a4b5 8d33 ff40 86f2 b12d 424f", + "4f83 ee3a 1a7e 83ee 3a1b 7e83 ee3a 1c" + ], + "dynamic_table": [ + { "x-custom": "val6" }, + { "x-custom": "val5" } + ], + "table_size": 88 + }, + { + "headers": [ + { ":path": "/somepath" }, + { "x-custom": "val1" }, + { "x-custom": "val2" }, + { "x-custom": "val3" } + ], + "encoded": [ + "4487 6107 a4b5 8d33 ff40 86f2 b12d 424f", + "4f83 ee3a 037e 83ee 3a05 7e83 ee3a 19" + ], + "dynamic_table": [ + { "x-custom": "val3" }, + { "x-custom": "val2" } + ], + "table_size": 88 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testMaxHeaderTableSize.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testMaxHeaderTableSize.json new file 
mode 100644 index 0000000..69d3b8c --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testMaxHeaderTableSize.json @@ -0,0 +1,55 @@ +{ + "max_header_table_size": 128, + "header_blocks": + [ + { + "headers": [ + { "name1": "val1" }, + { "name2": "val2" }, + { "name3": "val3" } + ], + "encoded": [ + "4084 a874 943f 83ee 3a03 4084 a874 945f", + "83ee 3a05 4084 a874 959f 83ee 3a19" + ], + "dynamic_table": [ + { "name3": "val3" }, + { "name2": "val2" }, + { "name1": "val1" } + ], + "table_size": 123 + }, + { + "max_header_table_size": 81, + "headers": [ + { "name3": "val3" }, + { "name2": "val2" } + ], + "encoded": [ + "3f32 be40 84a8 7494 5f83 ee3a 05" + ], + "dynamic_table": [ + { "name2": "val2" } + ], + "table_size": 41 + }, + { + "max_header_table_size": 128, + "headers": [ + { "name1": "val1" }, + { "name2": "val2" }, + { "name3": "val3" } + ], + "encoded": [ + "3f61 4084 a874 943f 83ee 3a03 bf40 84a8", + "7495 9f83 ee3a 19" + ], + "dynamic_table": [ + { "name3": "val3" }, + { "name1": "val1" }, + { "name2": "val2" } + ], + "table_size": 123 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_1.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_1.json new file mode 100644 index 0000000..0838aba --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_1.json @@ -0,0 +1,18 @@ +{ + "header_blocks": + [ + { + "headers": [ + { "custom-key": "custom-header" } + ], + "encoded": [ + "4088 25a8 49e9 5ba9 7d7f 8925 a849 e95a", + "728e 42d9" + ], + "dynamic_table": [ + { "custom-key": "custom-header" } + ], + "table_size": 55 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_2.json 
b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_2.json new file mode 100644 index 0000000..01fb8fe --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_2.json @@ -0,0 +1,17 @@ +{ + "header_blocks": + [ + { + "headers": [ + { ":path": "/sample/path" } + ], + "encoded": [ + "4489 6103 a6ba 0ac5 634c ff" + ], + "dynamic_table": [ + { ":path": "/sample/path" } + ], + "table_size": 49 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_3.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_3.json new file mode 100644 index 0000000..50892d6 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_3.json @@ -0,0 +1,17 @@ +{ + "sensitive_headers": true, + "header_blocks": + [ + { + "headers": [ + { "password": "secret" } + ], + "encoded": [ + "1086 ac68 4783 d927 8441 4961 53" + ], + "dynamic_table": [ + ], + "table_size": 0 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_4.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_4.json new file mode 100644 index 0000000..4e8c483 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC2_4.json @@ -0,0 +1,17 @@ +{ + "force_huffman_off": true, + "header_blocks": + [ + { + "headers": [ + { ":method": "GET" } + ], + "encoded": [ + "82" + ], + "dynamic_table": [ + ], + "table_size": 0 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC3.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC3.json new file mode 
100644 index 0000000..3e8f658 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC3.json @@ -0,0 +1,57 @@ +{ + "header_blocks": + [ + { + "headers": [ + { ":method": "GET" }, + { ":scheme": "http" }, + { ":path": "/" }, + { ":authority": "www.example.com" } + ], + "encoded": [ + "8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4", + "ff" + ], + "dynamic_table": [ + { ":authority": "www.example.com" } + ], + "table_size": 57 + }, + { + "headers": [ + { ":method": "GET" }, + { ":scheme": "http" }, + { ":path": "/" }, + { ":authority": "www.example.com" }, + { "cache-control": "no-cache" } + ], + "encoded": [ + "8286 84be 5886 a8eb 1064 9cbf" + ], + "dynamic_table": [ + { "cache-control": "no-cache" }, + { ":authority": "www.example.com" } + ], + "table_size": 110 + }, + { + "headers": [ + { ":method": "GET" }, + { ":scheme": "https" }, + { ":path": "/index.html" }, + { ":authority": "www.example.com" }, + { "custom-key": "custom-value" } + ], + "encoded": [ + "8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925", + "a849 e95b b8e8 b4bf" + ], + "dynamic_table": [ + { "custom-key": "custom-value" }, + { "cache-control": "no-cache" }, + { ":authority": "www.example.com" } + ], + "table_size": 164 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC4.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC4.json new file mode 100644 index 0000000..0543f53 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC4.json @@ -0,0 +1,58 @@ +{ + "force_huffman_on": true, + "header_blocks": + [ + { + "headers": [ + { ":method": "GET" }, + { ":scheme": "http" }, + { ":path": "/" }, + { ":authority": "www.example.com" } + ], + "encoded": [ + "8286 8441 8cf1 e3c2 e5f2 3a6b a0ab 90f4", + "ff" + ], + "dynamic_table": [ + { ":authority": "www.example.com" } + 
], + "table_size": 57 + }, + { + "headers": [ + { ":method": "GET" }, + { ":scheme": "http" }, + { ":path": "/" }, + { ":authority": "www.example.com" }, + { "cache-control": "no-cache" } + ], + "encoded": [ + "8286 84be 5886 a8eb 1064 9cbf" + ], + "dynamic_table": [ + { "cache-control": "no-cache" }, + { ":authority": "www.example.com" } + ], + "table_size": 110 + }, + { + "headers": [ + { ":method": "GET" }, + { ":scheme": "https" }, + { ":path": "/index.html" }, + { ":authority": "www.example.com" }, + { "custom-key": "custom-value" } + ], + "encoded": [ + "8287 85bf 4088 25a8 49e9 5ba9 7d7f 8925", + "a849 e95b b8e8 b4bf" + ], + "dynamic_table": [ + { "custom-key": "custom-value" }, + { "cache-control": "no-cache" }, + { ":authority": "www.example.com" } + ], + "table_size": 164 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC5.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC5.json new file mode 100644 index 0000000..28ec503 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC5.json @@ -0,0 +1,68 @@ +{ + "max_header_table_size": 256, + "header_blocks": + [ + { + "headers": [ + { ":status": "302" }, + { "cache-control": "private" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "location": "https://www.example.com" } + ], + "encoded": [ + "4882 6402 5885 aec3 771a 4b61 96d0 7abe", + "9410 54d4 44a8 2005 9504 0b81 66e0 82a6", + "2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8", + "e9ae 82ae 43d3" + ], + "dynamic_table": [ + { "location": "https://www.example.com" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "cache-control": "private" }, + { ":status": "302" } + ], + "table_size": 222 + }, + { + "headers": [ + { ":status": "307" }, + { "cache-control": "private" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "location": "https://www.example.com" } + ], + 
"encoded": [ + "4803 3330 37c1 c0bf" + ], + "dynamic_table": [ + { ":status": "307" }, + { "location": "https://www.example.com" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "cache-control": "private" } + ], + "table_size": 222 + }, + { + "headers": [ + { ":status": "200" }, + { "cache-control": "private" }, + { "date": "Mon, 21 Oct 2013 20:13:22 GMT" }, + { "location": "https://www.example.com" }, + { "content-encoding": "gzip" }, + { "set-cookie": "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1" } + ], + "encoded": [ + "88c1 6196 d07a be94 1054 d444 a820 0595", + "040b 8166 e084 a62d 1bff c05a 839b d9ab", + "77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b", + "3960 d5af 2708 7f36 72c1 ab27 0fb5 291f", + "9587 3160 65c0 03ed 4ee5 b106 3d50 07" + ], + "dynamic_table": [ + { "set-cookie": "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1" }, + { "content-encoding": "gzip" }, + { "date": "Mon, 21 Oct 2013 20:13:22 GMT" } + ], + "table_size": 215 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC6.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC6.json new file mode 100644 index 0000000..28ec503 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testSpecExampleC6.json @@ -0,0 +1,68 @@ +{ + "max_header_table_size": 256, + "header_blocks": + [ + { + "headers": [ + { ":status": "302" }, + { "cache-control": "private" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "location": "https://www.example.com" } + ], + "encoded": [ + "4882 6402 5885 aec3 771a 4b61 96d0 7abe", + "9410 54d4 44a8 2005 9504 0b81 66e0 82a6", + "2d1b ff6e 919d 29ad 1718 63c7 8f0b 97c8", + "e9ae 82ae 43d3" + ], + "dynamic_table": [ + { "location": "https://www.example.com" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "cache-control": "private" }, + { ":status": "302" } + ], + "table_size": 
222 + }, + { + "headers": [ + { ":status": "307" }, + { "cache-control": "private" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "location": "https://www.example.com" } + ], + "encoded": [ + "4803 3330 37c1 c0bf" + ], + "dynamic_table": [ + { ":status": "307" }, + { "location": "https://www.example.com" }, + { "date": "Mon, 21 Oct 2013 20:13:21 GMT" }, + { "cache-control": "private" } + ], + "table_size": 222 + }, + { + "headers": [ + { ":status": "200" }, + { "cache-control": "private" }, + { "date": "Mon, 21 Oct 2013 20:13:22 GMT" }, + { "location": "https://www.example.com" }, + { "content-encoding": "gzip" }, + { "set-cookie": "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1" } + ], + "encoded": [ + "88c1 6196 d07a be94 1054 d444 a820 0595", + "040b 8166 e084 a62d 1bff c05a 839b d9ab", + "77ad 94e7 821d d7f2 e6c7 b335 dfdf cd5b", + "3960 d5af 2708 7f36 72c1 ab27 0fb5 291f", + "9587 3160 65c0 03ed 4ee5 b106 3d50 07" + ], + "dynamic_table": [ + { "set-cookie": "foo=ASDJKHQKBZXOQWEOPIUAXQWEOIU; max-age=3600; version=1" }, + { "content-encoding": "gzip" }, + { "date": "Mon, 21 Oct 2013 20:13:22 GMT" } + ], + "table_size": 215 + } + ] +} diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableEntries.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableEntries.json new file mode 100644 index 0000000..9110306 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableEntries.json @@ -0,0 +1,72 @@ +{ + "header_blocks": + [ + { + "headers": [ + { ":authority": "" }, + { ":method": "GET" }, + { ":method": "POST" }, + { ":path": "/" }, + { ":path": "/index.html" }, + { ":scheme": "http" }, + { ":scheme": "https" }, + { "accept-charset": "" }, + { "accept-encoding": "gzip, deflate" }, + { "accept-language": "" }, + { "accept-ranges": "" }, + { "accept": "" }, + { 
"access-control-allow-origin": "" }, + { "age": "" }, + { "allow": "" }, + { "authorization": "" }, + { "cache-control": "" }, + { "content-disposition": "" }, + { "content-encoding": "" }, + { "content-language": "" }, + { "content-length": "" }, + { "content-location": "" }, + { "content-range": "" }, + { "content-type": "" }, + { "cookie": "" }, + { "date": "" }, + { "etag": "" }, + { "expect": "" }, + { "expires": "" }, + { "from": "" }, + { "host": "" }, + { "if-match": "" }, + { "if-modified-since": "" }, + { "if-none-match": "" }, + { "if-range": "" }, + { "if-unmodified-since": "" }, + { "last-modified": "" }, + { "link": "" }, + { "location": "" }, + { "max-forwards": "" }, + { "proxy-authenticate": "" }, + { "proxy-authorization": "" }, + { "range": "" }, + { "referer": "" }, + { "refresh": "" }, + { "retry-after": "" }, + { "server": "" }, + { "set-cookie": "" }, + { "strict-transport-security": "" }, + { "user-agent": "" }, + { "vary": "" }, + { "via": "" }, + { "www-authenticate": "" } + ], + "encoded": [ + "8182 8384 8586 87 8f90", + "9192 9394 9596 9798 999a 9b9c 9d9e 9fa0", + "a1a2 a3a4 a5a6 a7a8 a9aa abac adae afb0", + "b1b2 b3b4 b5b6 b7b8 babb bcbd" + ], + "dynamic_table": [ + ], + "table_size": 0 + } + ] +} + diff --git a/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableResponseEntries.json b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableResponseEntries.json new file mode 100644 index 0000000..91ec07b --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/io/netty/handler/codec/http2/testdata/testStaticTableResponseEntries.json @@ -0,0 +1,23 @@ +{ + "header_blocks": + [ + { + "headers": [ + { ":status": "200" }, + { ":status": "204" }, + { ":status": "206" }, + { ":status": "304" }, + { ":status": "400" }, + { ":status": "404" }, + { ":status": "500" } + ], + "encoded": [ + "8889 8a8b 8c8d 8e" + ], + "dynamic_table": [ + ], + 
"table_size": 0 + } + ] +} + diff --git a/netty-handler-codec-http2/src/test/resources/junit-platform.properties b/netty-handler-codec-http2/src/test/resources/junit-platform.properties new file mode 100644 index 0000000..4bcf35e --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/junit-platform.properties @@ -0,0 +1,16 @@ +# Copyright 2022 The Netty Project +# +# The Netty Project licenses this file to you under the Apache License, +# version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. + +junit.jupiter.execution.parallel.enabled = true +junit.jupiter.execution.parallel.mode.default = concurrent diff --git a/netty-handler-codec-http2/src/test/resources/logging.properties b/netty-handler-codec-http2/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-handler-codec-http2/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-handler-codec-protobuf/build.gradle b/netty-handler-codec-protobuf/build.gradle new file mode 100644 index 0000000..71f6f7a --- /dev/null +++ b/netty-handler-codec-protobuf/build.gradle @@ -0,0 +1,4 @@ +dependencies { + api 
project(':netty-handler-codec') + implementation libs.protobuf +} diff --git a/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java new file mode 100644 index 0000000..31f0376 --- /dev/null +++ b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufDecoder.java @@ -0,0 +1,133 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.protobuf; + +import com.google.protobuf.ExtensionRegistry; +import com.google.protobuf.ExtensionRegistryLite; +import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.LengthFieldPrepender; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.internal.ObjectUtil; + +import java.util.List; + +/** + * Decodes a received {@link ByteBuf} into a + * Google Protocol Buffers + * {@link Message} and {@link MessageLite}. 
Please note that this decoder must + * be used with a proper {@link ByteToMessageDecoder} such as {@link ProtobufVarint32FrameDecoder} + * or {@link LengthFieldBasedFrameDecoder} if you are using a stream-based + * transport such as TCP/IP. A typical setup for TCP/IP would be: + *
+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder",
+ *                  new {@link LengthFieldBasedFrameDecoder}(1048576, 0, 4, 0, 4));
+ * pipeline.addLast("protobufDecoder",
+ *                  new {@link ProtobufDecoder}(MyMessage.getDefaultInstance()));
+ *
+ * // Encoder
+ * pipeline.addLast("frameEncoder", new {@link LengthFieldPrepender}(4));
+ * pipeline.addLast("protobufEncoder", new {@link ProtobufEncoder}());
+ * 
+ * and then you can use a {@code MyMessage} instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, Object msg) {
+ *     MyMessage req = (MyMessage) msg;
+ *     MyMessage res = MyMessage.newBuilder().setText(
+ *                               "Did you say '" + req.getText() + "'?").build();
+ *     ch.write(res);
+ * }
+ * 
+ */ +@Sharable +public class ProtobufDecoder extends MessageToMessageDecoder { + + private static final boolean HAS_PARSER; + + static { + boolean hasParser = false; + try { + // MessageLite.getParserForType() is not available until protobuf 2.5.0. + MessageLite.class.getDeclaredMethod("getParserForType"); + hasParser = true; + } catch (Throwable t) { + // Ignore + } + + HAS_PARSER = hasParser; + } + + private final MessageLite prototype; + private final ExtensionRegistryLite extensionRegistry; + + /** + * Creates a new instance. + */ + public ProtobufDecoder(MessageLite prototype) { + this(prototype, null); + } + + public ProtobufDecoder(MessageLite prototype, ExtensionRegistry extensionRegistry) { + this(prototype, (ExtensionRegistryLite) extensionRegistry); + } + + public ProtobufDecoder(MessageLite prototype, ExtensionRegistryLite extensionRegistry) { + this.prototype = ObjectUtil.checkNotNull(prototype, "prototype").getDefaultInstanceForType(); + this.extensionRegistry = extensionRegistry; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) + throws Exception { + final byte[] array; + final int offset; + final int length = msg.readableBytes(); + if (msg.hasArray()) { + array = msg.array(); + offset = msg.arrayOffset() + msg.readerIndex(); + } else { + array = ByteBufUtil.getBytes(msg, msg.readerIndex(), length, false); + offset = 0; + } + + if (extensionRegistry == null) { + if (HAS_PARSER) { + out.add(prototype.getParserForType().parseFrom(array, offset, length)); + } else { + out.add(prototype.newBuilderForType().mergeFrom(array, offset, length).build()); + } + } else { + if (HAS_PARSER) { + out.add(prototype.getParserForType().parseFrom( + array, offset, length, extensionRegistry)); + } else { + out.add(prototype.newBuilderForType().mergeFrom( + array, offset, length, extensionRegistry).build()); + } + } + } +} diff --git 
a/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufEncoder.java b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufEncoder.java new file mode 100644 index 0000000..776e425 --- /dev/null +++ b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufEncoder.java @@ -0,0 +1,74 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.protobuf; + +import com.google.protobuf.Message; +import com.google.protobuf.MessageLite; +import com.google.protobuf.MessageLiteOrBuilder; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.LengthFieldPrepender; +import io.netty.handler.codec.MessageToMessageEncoder; + +import java.util.List; + +import static io.netty.buffer.Unpooled.*; + +/** + * Encodes the requested Google + * Protocol Buffers {@link Message} and {@link MessageLite} into a + * {@link ByteBuf}. A typical setup for TCP/IP would be: + *
+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder",
+ *                  new {@link LengthFieldBasedFrameDecoder}(1048576, 0, 4, 0, 4));
+ * pipeline.addLast("protobufDecoder",
+ *                  new {@link ProtobufDecoder}(MyMessage.getDefaultInstance()));
+ *
+ * // Encoder
+ * pipeline.addLast("frameEncoder", new {@link LengthFieldPrepender}(4));
+ * pipeline.addLast("protobufEncoder", new {@link ProtobufEncoder}());
+ * 
+ * and then you can use a {@code MyMessage} instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, Object msg) {
+ *     MyMessage req = (MyMessage) msg;
+ *     MyMessage res = MyMessage.newBuilder().setText(
+ *                               "Did you say '" + req.getText() + "'?").build();
+ *     ch.write(res);
+ * }
+ * 
+ */ +@Sharable +public class ProtobufEncoder extends MessageToMessageEncoder { + @Override + protected void encode(ChannelHandlerContext ctx, MessageLiteOrBuilder msg, List out) + throws Exception { + if (msg instanceof MessageLite) { + out.add(wrappedBuffer(((MessageLite) msg).toByteArray())); + return; + } + if (msg instanceof MessageLite.Builder) { + out.add(wrappedBuffer(((MessageLite.Builder) msg).build().toByteArray())); + } + } +} diff --git a/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java new file mode 100644 index 0000000..85c0d26 --- /dev/null +++ b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32FrameDecoder.java @@ -0,0 +1,119 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.protobuf; + +import com.google.protobuf.CodedInputStream; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.CorruptedFrameException; + +import java.util.List; + +/** + * A decoder that splits the received {@link ByteBuf}s dynamically by the + * value of the Google Protocol Buffers + * Base + * 128 Varints integer length field in the message. For example: + *
+ * BEFORE DECODE (302 bytes)       AFTER DECODE (300 bytes)
+ * +--------+---------------+      +---------------+
+ * | Length | Protobuf Data |----->| Protobuf Data |
+ * | 0xAC02 |  (300 bytes)  |      |  (300 bytes)  |
+ * +--------+---------------+      +---------------+
+ * 
+ * + * @see CodedInputStream + */ +public class ProtobufVarint32FrameDecoder extends ByteToMessageDecoder { + + // TODO maxFrameLength + safe skip + fail-fast option + // (just like LengthFieldBasedFrameDecoder) + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) + throws Exception { + in.markReaderIndex(); + int preIndex = in.readerIndex(); + int length = readRawVarint32(in); + if (preIndex == in.readerIndex()) { + return; + } + if (length < 0) { + throw new CorruptedFrameException("negative length: " + length); + } + + if (in.readableBytes() < length) { + in.resetReaderIndex(); + } else { + out.add(in.readRetainedSlice(length)); + } + } + + /** + * Reads variable length 32bit int from buffer + * + * @return decoded int if buffers readerIndex has been forwarded else nonsense value + */ + private static int readRawVarint32(ByteBuf buffer) { + if (!buffer.isReadable()) { + return 0; + } + buffer.markReaderIndex(); + byte tmp = buffer.readByte(); + if (tmp >= 0) { + return tmp; + } else { + int result = tmp & 127; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + result |= tmp << 7; + } else { + result |= (tmp & 127) << 7; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + result |= tmp << 14; + } else { + result |= (tmp & 127) << 14; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + if ((tmp = buffer.readByte()) >= 0) { + result |= tmp << 21; + } else { + result |= (tmp & 127) << 21; + if (!buffer.isReadable()) { + buffer.resetReaderIndex(); + return 0; + } + result |= (tmp = buffer.readByte()) << 28; + if (tmp < 0) { + throw new CorruptedFrameException("malformed varint."); + } + } + } + } + return result; + } + } +} diff --git a/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32LengthFieldPrepender.java 
b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32LengthFieldPrepender.java new file mode 100644 index 0000000..54d14ff --- /dev/null +++ b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/ProtobufVarint32LengthFieldPrepender.java @@ -0,0 +1,88 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.protobuf; + +import com.google.protobuf.CodedOutputStream; +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.MessageToByteEncoder; + +/** + * An encoder that prepends the Google Protocol Buffers + * Base + * 128 Varints integer length field. For example: + *
+ * BEFORE ENCODE (300 bytes)       AFTER ENCODE (302 bytes)
+ * +---------------+               +--------+---------------+
+ * | Protobuf Data |-------------->| Length | Protobuf Data |
+ * |  (300 bytes)  |               | 0xAC02 |  (300 bytes)  |
+ * +---------------+               +--------+---------------+
+ * 
* + * + * @see CodedOutputStream + */ +@Sharable +public class ProtobufVarint32LengthFieldPrepender extends MessageToByteEncoder { + + @Override + protected void encode( + ChannelHandlerContext ctx, ByteBuf msg, ByteBuf out) throws Exception { + int bodyLen = msg.readableBytes(); + int headerLen = computeRawVarint32Size(bodyLen); + out.ensureWritable(headerLen + bodyLen); + writeRawVarint32(out, bodyLen); + out.writeBytes(msg, msg.readerIndex(), bodyLen); + } + + /** + * Writes protobuf varint32 to (@link ByteBuf). + * @param out to be written to + * @param value to be written + */ + static void writeRawVarint32(ByteBuf out, int value) { + while (true) { + if ((value & ~0x7F) == 0) { + out.writeByte(value); + return; + } else { + out.writeByte((value & 0x7F) | 0x80); + value >>>= 7; + } + } + } + + /** + * Computes size of protobuf varint32 after encoding. + * @param value which is to be encoded. + * @return size of value encoded as protobuf varint32. + */ + static int computeRawVarint32Size(final int value) { + if ((value & (0xffffffff << 7)) == 0) { + return 1; + } + if ((value & (0xffffffff << 14)) == 0) { + return 2; + } + if ((value & (0xffffffff << 21)) == 0) { + return 3; + } + if ((value & (0xffffffff << 28)) == 0) { + return 4; + } + return 5; + } +} diff --git a/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/package-info.java b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/package-info.java new file mode 100644 index 0000000..74730e8 --- /dev/null +++ b/netty-handler-codec-protobuf/src/main/java/io/netty/handler/codec/protobuf/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder and decoder which transform a + * Google Protocol Buffers + * {@link com.google.protobuf.Message} into a + * {@link io.netty.buffer.ByteBuf} and vice versa. + */ +package io.netty.handler.codec.protobuf; diff --git a/netty-handler-codec-protobuf/src/main/java/module-info.java b/netty-handler-codec-protobuf/src/main/java/module-info.java new file mode 100644 index 0000000..38dbfac --- /dev/null +++ b/netty-handler-codec-protobuf/src/main/java/module-info.java @@ -0,0 +1,8 @@ +module org.xbib.io.netty.handler.codec.protobuf { + exports io.netty.handler.codec.protobuf; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.util; + requires org.xbib.io.netty.handler.codec; + requires com.google.protobuf; +} diff --git a/netty-handler-codec/build.gradle b/netty-handler-codec/build.gradle new file mode 100644 index 0000000..aec8cc2 --- /dev/null +++ b/netty-handler-codec/build.gradle @@ -0,0 +1,4 @@ +dependencies { + api project(':netty-channel') + testImplementation testLibs.mockito.core +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/AsciiHeadersEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/AsciiHeadersEncoder.java new file mode 100644 index 0000000..aaa72bb --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/AsciiHeadersEncoder.java @@ -0,0 +1,121 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may 
not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec; + + +import java.util.Map.Entry; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.util.AsciiString; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; + +public final class AsciiHeadersEncoder { + + /** + * The separator characters to insert between a header name and a header value. + */ + public enum SeparatorType { + /** + * {@code ':'} + */ + COLON, + /** + * {@code ': '} + */ + COLON_SPACE, + } + + /** + * The newline characters to insert between header entries. 
+ */ + public enum NewlineType { + /** + * {@code '\n'} + */ + LF, + /** + * {@code '\r\n'} + */ + CRLF + } + + private final ByteBuf buf; + private final SeparatorType separatorType; + private final NewlineType newlineType; + + public AsciiHeadersEncoder(ByteBuf buf) { + this(buf, SeparatorType.COLON_SPACE, NewlineType.CRLF); + } + + public AsciiHeadersEncoder(ByteBuf buf, SeparatorType separatorType, NewlineType newlineType) { + this.buf = ObjectUtil.checkNotNull(buf, "buf"); + this.separatorType = ObjectUtil.checkNotNull(separatorType, "separatorType"); + this.newlineType = ObjectUtil.checkNotNull(newlineType, "newlineType"); + } + + public void encode(Entry entry) { + final CharSequence name = entry.getKey(); + final CharSequence value = entry.getValue(); + final ByteBuf buf = this.buf; + final int nameLen = name.length(); + final int valueLen = value.length(); + final int entryLen = nameLen + valueLen + 4; + int offset = buf.writerIndex(); + buf.ensureWritable(entryLen); + writeAscii(buf, offset, name); + offset += nameLen; + + switch (separatorType) { + case COLON: + buf.setByte(offset ++, ':'); + break; + case COLON_SPACE: + buf.setByte(offset ++, ':'); + buf.setByte(offset ++, ' '); + break; + default: + throw new Error(); + } + + writeAscii(buf, offset, value); + offset += valueLen; + + switch (newlineType) { + case LF: + buf.setByte(offset ++, '\n'); + break; + case CRLF: + buf.setByte(offset ++, '\r'); + buf.setByte(offset ++, '\n'); + break; + default: + throw new Error(); + } + + buf.writerIndex(offset); + } + + private static void writeAscii(ByteBuf buf, int offset, CharSequence value) { + if (value instanceof AsciiString) { + ByteBufUtil.copy((AsciiString) value, 0, buf, offset, value.length()); + } else { + buf.setCharSequence(offset, value, CharsetUtil.US_ASCII); + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageCodec.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageCodec.java 
new file mode 100644 index 0000000..6f505d2 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageCodec.java @@ -0,0 +1,175 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.TypeParameterMatcher; + +import java.util.List; + +/** + * A Codec for on-the-fly encoding/decoding of bytes to messages and vise-versa. + * + * This can be thought of as a combination of {@link ByteToMessageDecoder} and {@link MessageToByteEncoder}. + * + * Be aware that sub-classes of {@link ByteToMessageCodec} MUST NOT + * annotated with @Sharable. 
+ */ +public abstract class ByteToMessageCodec extends ChannelDuplexHandler { + + private final TypeParameterMatcher outboundMsgMatcher; + private final MessageToByteEncoder encoder; + + private final ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + @Override + public void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ByteToMessageCodec.this.decode(ctx, in, out); + } + + @Override + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ByteToMessageCodec.this.decodeLast(ctx, in, out); + } + }; + + /** + * see {@link #ByteToMessageCodec(boolean)} with {@code true} as boolean parameter. + */ + protected ByteToMessageCodec() { + this(true); + } + + /** + * see {@link #ByteToMessageCodec(Class, boolean)} with {@code true} as boolean value. + */ + protected ByteToMessageCodec(Class outboundMessageType) { + this(outboundMessageType, true); + } + + /** + * Create a new instance which will try to detect the types to match out of the type parameter of the class. + * + * @param preferDirect {@code true} if a direct {@link ByteBuf} should be tried to be used as target for + * the encoded messages. If {@code false} is used it will allocate a heap + * {@link ByteBuf}, which is backed by an byte array. + */ + protected ByteToMessageCodec(boolean preferDirect) { + ensureNotSharable(); + outboundMsgMatcher = TypeParameterMatcher.find(this, ByteToMessageCodec.class, "I"); + encoder = new Encoder(preferDirect); + } + + /** + * Create a new instance + * + * @param outboundMessageType The type of messages to match + * @param preferDirect {@code true} if a direct {@link ByteBuf} should be tried to be used as target for + * the encoded messages. If {@code false} is used it will allocate a heap + * {@link ByteBuf}, which is backed by an byte array. 
+ */ + protected ByteToMessageCodec(Class outboundMessageType, boolean preferDirect) { + ensureNotSharable(); + outboundMsgMatcher = TypeParameterMatcher.get(outboundMessageType); + encoder = new Encoder(preferDirect); + } + + /** + * Returns {@code true} if and only if the specified message can be encoded by this codec. + * + * @param msg the message + */ + public boolean acceptOutboundMessage(Object msg) throws Exception { + return outboundMsgMatcher.match(msg); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + decoder.channelRead(ctx, msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + encoder.write(ctx, msg, promise); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + decoder.channelReadComplete(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + decoder.channelInactive(ctx); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + try { + decoder.handlerAdded(ctx); + } finally { + encoder.handlerAdded(ctx); + } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + decoder.handlerRemoved(ctx); + } finally { + encoder.handlerRemoved(ctx); + } + } + + /** + * @see MessageToByteEncoder#encode(ChannelHandlerContext, Object, ByteBuf) + */ + protected abstract void encode(ChannelHandlerContext ctx, I msg, ByteBuf out) throws Exception; + + /** + * @see ByteToMessageDecoder#decode(ChannelHandlerContext, ByteBuf, List) + */ + protected abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception; + + /** + * @see ByteToMessageDecoder#decodeLast(ChannelHandlerContext, ByteBuf, List) + */ + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (in.isReadable()) { + // Only call decode() if 
there is something left in the buffer to decode. + // See https://github.com/netty/netty/issues/4386 + decode(ctx, in, out); + } + } + + private final class Encoder extends MessageToByteEncoder { + Encoder(boolean preferDirect) { + super(preferDirect); + } + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + return ByteToMessageCodec.this.acceptOutboundMessage(msg); + } + + @Override + protected void encode(ChannelHandlerContext ctx, I msg, ByteBuf out) throws Exception { + ByteToMessageCodec.this.encode(ctx, msg, out); + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java new file mode 100644 index 0000000..0ec4c34 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ByteToMessageDecoder.java @@ -0,0 +1,586 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static java.lang.Integer.MAX_VALUE; + +/** + * {@link ChannelInboundHandlerAdapter} which decodes bytes in a stream-like fashion from one {@link ByteBuf} to an + * other Message type. + * + * For example here is an implementation which reads all readable bytes from + * the input {@link ByteBuf} and create a new {@link ByteBuf}. + * + *
+ *     public class SquareDecoder extends {@link ByteToMessageDecoder} {
+ *         {@code @Override}
+ *         public void decode({@link ChannelHandlerContext} ctx, {@link ByteBuf} in, List<Object> out)
+ *                 throws {@link Exception} {
+ *             out.add(in.readBytes(in.readableBytes()));
+ *         }
+ *     }
+ * 
+ * + *

Frame detection

+ *

+ * Generally frame detection should be handled earlier in the pipeline by adding a + * {@link DelimiterBasedFrameDecoder}, {@link FixedLengthFrameDecoder}, {@link LengthFieldBasedFrameDecoder}, + * or {@link LineBasedFrameDecoder}. + *

+ * If a custom frame decoder is required, then one needs to be careful when implementing + * one with {@link ByteToMessageDecoder}. Ensure there are enough bytes in the buffer for a + * complete frame by checking {@link ByteBuf#readableBytes()}. If there are not enough bytes + * for a complete frame, return without modifying the reader index to allow more bytes to arrive. + *

+ * To check for complete frames without modifying the reader index, use methods like {@link ByteBuf#getInt(int)}. + * One MUST use the reader index when using methods like {@link ByteBuf#getInt(int)}. + * For example calling in.getInt(0) is assuming the frame starts at the beginning of the buffer, which + * is not always the case. Use in.getInt(in.readerIndex()) instead. + *

Pitfalls

+ *

+ * Be aware that sub-classes of {@link ByteToMessageDecoder} MUST NOT + * be annotated with {@code @Sharable}. + *

+ * Some methods such as {@link ByteBuf#readBytes(int)} will cause a memory leak if the returned buffer + * is not released or added to the out {@link List}. Use derived buffers like {@link ByteBuf#readSlice(int)} + * to avoid leaking memory. + */ +public abstract class ByteToMessageDecoder extends ChannelInboundHandlerAdapter { + + /** + * Cumulate {@link ByteBuf}s by merge them into one {@link ByteBuf}'s, using memory copies. + */ + public static final Cumulator MERGE_CUMULATOR = new Cumulator() { + @Override + public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) { + if (cumulation == in) { + // when the in buffer is the same as the cumulation it is doubly retained, release it once + in.release(); + return cumulation; + } + if (!cumulation.isReadable() && in.isContiguous()) { + // If cumulation is empty and input buffer is contiguous, use it directly + cumulation.release(); + return in; + } + try { + final int required = in.readableBytes(); + if (required > cumulation.maxWritableBytes() || + required > cumulation.maxFastWritableBytes() && cumulation.refCnt() > 1 || + cumulation.isReadOnly()) { + // Expand cumulation (by replacing it) under the following conditions: + // - cumulation cannot be resized to accommodate the additional data + // - cumulation can be expanded with a reallocation operation to accommodate but the buffer is + // assumed to be shared (e.g. refCnt() > 1) and the reallocation may not be safe. + return expandCumulation(alloc, cumulation, in); + } + cumulation.writeBytes(in, in.readerIndex(), required); + in.readerIndex(in.writerIndex()); + return cumulation; + } finally { + // We must release in all cases as otherwise it may produce a leak if writeBytes(...) throw + // for whatever release (for example because of OutOfMemoryError) + in.release(); + } + } + }; + + /** + * Cumulate {@link ByteBuf}s by add them to a {@link CompositeByteBuf} and so do no memory copy whenever possible. 
+ * Be aware that {@link CompositeByteBuf} use a more complex indexing implementation so depending on your use-case + * and the decoder implementation this may be slower than just use the {@link #MERGE_CUMULATOR}. + */ + public static final Cumulator COMPOSITE_CUMULATOR = new Cumulator() { + @Override + public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in) { + if (cumulation == in) { + // when the in buffer is the same as the cumulation it is doubly retained, release it once + in.release(); + return cumulation; + } + if (!cumulation.isReadable()) { + cumulation.release(); + return in; + } + CompositeByteBuf composite = null; + try { + if (cumulation instanceof CompositeByteBuf && cumulation.refCnt() == 1) { + composite = (CompositeByteBuf) cumulation; + // Writer index must equal capacity if we are going to "write" + // new components to the end + if (composite.writerIndex() != composite.capacity()) { + composite.capacity(composite.writerIndex()); + } + } else { + composite = alloc.compositeBuffer(Integer.MAX_VALUE).addFlattenedComponents(true, cumulation); + } + composite.addFlattenedComponents(true, in); + in = null; + return composite; + } finally { + if (in != null) { + // We must release if the ownership was not transferred as otherwise it may produce a leak + in.release(); + // Also release any new buffer allocated if we're not returning it + if (composite != null && composite != cumulation) { + composite.release(); + } + } + } + } + }; + + private static final byte STATE_INIT = 0; + private static final byte STATE_CALLING_CHILD_DECODE = 1; + private static final byte STATE_HANDLER_REMOVED_PENDING = 2; + + ByteBuf cumulation; + private Cumulator cumulator = MERGE_CUMULATOR; + private boolean singleDecode; + private boolean first; + + /** + * This flag is used to determine if we need to call {@link ChannelHandlerContext#read()} to consume more data + * when {@link ChannelConfig#isAutoRead()} is {@code false}. 
+ */ + private boolean firedChannelRead; + + private boolean selfFiredChannelRead; + + /** + * A bitmask where the bits are defined as + *

    + *
  • {@link #STATE_INIT}
  • + *
  • {@link #STATE_CALLING_CHILD_DECODE}
  • + *
  • {@link #STATE_HANDLER_REMOVED_PENDING}
  • + *
+ */ + private byte decodeState = STATE_INIT; + private int discardAfterReads = 16; + private int numReads; + + protected ByteToMessageDecoder() { + ensureNotSharable(); + } + + /** + * If set then only one message is decoded on each {@link #channelRead(ChannelHandlerContext, Object)} + * call. This may be useful if you need to do some protocol upgrade and want to make sure nothing is mixed up. + * + * Default is {@code false} as this has performance impacts. + */ + public void setSingleDecode(boolean singleDecode) { + this.singleDecode = singleDecode; + } + + /** + * If {@code true} then only one message is decoded on each + * {@link #channelRead(ChannelHandlerContext, Object)} call. + * + * Default is {@code false} as this has performance impacts. + */ + public boolean isSingleDecode() { + return singleDecode; + } + + /** + * Set the {@link Cumulator} to use for cumulate the received {@link ByteBuf}s. + */ + public void setCumulator(Cumulator cumulator) { + this.cumulator = ObjectUtil.checkNotNull(cumulator, "cumulator"); + } + + /** + * Set the number of reads after which {@link ByteBuf#discardSomeReadBytes()} are called and so free up memory. + * The default is {@code 16}. + */ + public void setDiscardAfterReads(int discardAfterReads) { + checkPositive(discardAfterReads, "discardAfterReads"); + this.discardAfterReads = discardAfterReads; + } + + /** + * Returns the actual number of readable bytes in the internal cumulative + * buffer of this decoder. You usually do not need to rely on this value + * to write a decoder. Use it only when you must use it at your own risk. + * This method is a shortcut to {@link #internalBuffer() internalBuffer().readableBytes()}. + */ + protected int actualReadableBytes() { + return internalBuffer().readableBytes(); + } + + /** + * Returns the internal cumulative buffer of this decoder. You usually + * do not need to access the internal buffer directly to write a decoder. + * Use it only when you must use it at your own risk. 
+ */ + protected ByteBuf internalBuffer() { + if (cumulation != null) { + return cumulation; + } else { + return Unpooled.EMPTY_BUFFER; + } + } + + @Override + public final void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + if (decodeState == STATE_CALLING_CHILD_DECODE) { + decodeState = STATE_HANDLER_REMOVED_PENDING; + return; + } + ByteBuf buf = cumulation; + if (buf != null) { + // Directly set this to null, so we are sure we not access it in any other method here anymore. + cumulation = null; + numReads = 0; + int readable = buf.readableBytes(); + if (readable > 0) { + ctx.fireChannelRead(buf); + ctx.fireChannelReadComplete(); + } else { + buf.release(); + } + } + handlerRemoved0(ctx); + } + + /** + * Gets called after the {@link ByteToMessageDecoder} was removed from the actual context and it doesn't handle + * events anymore. + */ + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg instanceof ByteBuf) { + selfFiredChannelRead = true; + CodecOutputList out = CodecOutputList.newInstance(); + try { + first = cumulation == null; + cumulation = cumulator.cumulate(ctx.alloc(), + first ? 
Unpooled.EMPTY_BUFFER : cumulation, (ByteBuf) msg); + callDecode(ctx, cumulation, out); + } catch (DecoderException e) { + throw e; + } catch (Exception e) { + throw new DecoderException(e); + } finally { + try { + if (cumulation != null && !cumulation.isReadable()) { + numReads = 0; + try { + cumulation.release(); + } catch (IllegalReferenceCountException e) { + //noinspection ThrowFromFinallyBlock + throw new IllegalReferenceCountException( + getClass().getSimpleName() + "#decode() might have released its input buffer, " + + "or passed it down the pipeline without a retain() call, " + + "which is not allowed.", e); + } + cumulation = null; + } else if (++numReads >= discardAfterReads) { + // We did enough reads already try to discard some bytes, so we not risk to see a OOME. + // See https://github.com/netty/netty/issues/4275 + numReads = 0; + discardSomeReadBytes(); + } + + int size = out.size(); + firedChannelRead |= out.insertSinceRecycled(); + fireChannelRead(ctx, out, size); + } finally { + out.recycle(); + } + } + } else { + ctx.fireChannelRead(msg); + } + } + + /** + * Get {@code numElements} out of the {@link List} and forward these through the pipeline. + */ + static void fireChannelRead(ChannelHandlerContext ctx, List msgs, int numElements) { + if (msgs instanceof CodecOutputList) { + fireChannelRead(ctx, (CodecOutputList) msgs, numElements); + } else { + for (int i = 0; i < numElements; i++) { + ctx.fireChannelRead(msgs.get(i)); + } + } + } + + /** + * Get {@code numElements} out of the {@link CodecOutputList} and forward these through the pipeline. 
+ */ + static void fireChannelRead(ChannelHandlerContext ctx, CodecOutputList msgs, int numElements) { + for (int i = 0; i < numElements; i ++) { + ctx.fireChannelRead(msgs.getUnsafe(i)); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + numReads = 0; + discardSomeReadBytes(); + if (selfFiredChannelRead && !firedChannelRead && !ctx.channel().config().isAutoRead()) { + ctx.read(); + } + firedChannelRead = false; + ctx.fireChannelReadComplete(); + } + + protected final void discardSomeReadBytes() { + if (cumulation != null && !first && cumulation.refCnt() == 1) { + // discard some bytes if possible to make more room in the + // buffer but only if the refCnt == 1 as otherwise the user may have + // used slice().retain() or duplicate().retain(). + // + // See: + // - https://github.com/netty/netty/issues/2327 + // - https://github.com/netty/netty/issues/1764 + cumulation.discardSomeReadBytes(); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + channelInputClosed(ctx, true); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof ChannelInputShutdownEvent) { + // The decodeLast method is invoked when a channelInactive event is encountered. + // This method is responsible for ending requests in some situations and must be called + // when the input has been shutdown. 
+ channelInputClosed(ctx, false); + } + super.userEventTriggered(ctx, evt); + } + + private void channelInputClosed(ChannelHandlerContext ctx, boolean callChannelInactive) { + CodecOutputList out = CodecOutputList.newInstance(); + try { + channelInputClosed(ctx, out); + } catch (DecoderException e) { + throw e; + } catch (Exception e) { + throw new DecoderException(e); + } finally { + try { + if (cumulation != null) { + cumulation.release(); + cumulation = null; + } + int size = out.size(); + fireChannelRead(ctx, out, size); + if (size > 0) { + // Something was read, call fireChannelReadComplete() + ctx.fireChannelReadComplete(); + } + if (callChannelInactive) { + ctx.fireChannelInactive(); + } + } finally { + // Recycle in all cases + out.recycle(); + } + } + } + + /** + * Called when the input of the channel was closed which may be because it changed to inactive or because of + * {@link ChannelInputShutdownEvent}. + */ + void channelInputClosed(ChannelHandlerContext ctx, List out) throws Exception { + if (cumulation != null) { + callDecode(ctx, cumulation, out); + // If callDecode(...) removed the handle from the pipeline we should not call decodeLast(...) as this would + // be unexpected. + if (!ctx.isRemoved()) { + // Use Unpooled.EMPTY_BUFFER if cumulation become null after calling callDecode(...). + // See https://github.com/netty/netty/issues/10802. + ByteBuf buffer = cumulation == null ? Unpooled.EMPTY_BUFFER : cumulation; + decodeLast(ctx, buffer, out); + } + } else { + decodeLast(ctx, Unpooled.EMPTY_BUFFER, out); + } + } + + /** + * Called once data should be decoded from the given {@link ByteBuf}. This method will call + * {@link #decode(ChannelHandlerContext, ByteBuf, List)} as long as decoding should take place. 
+ * + * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to + * @param in the {@link ByteBuf} from which to read data + * @param out the {@link List} to which decoded messages should be added + */ + protected void callDecode(ChannelHandlerContext ctx, ByteBuf in, List out) { + try { + while (in.isReadable()) { + final int outSize = out.size(); + + if (outSize > 0) { + fireChannelRead(ctx, out, outSize); + out.clear(); + + // Check if this handler was removed before continuing with decoding. + // If it was removed, it is not safe to continue to operate on the buffer. + // + // See: + // - https://github.com/netty/netty/issues/4635 + if (ctx.isRemoved()) { + break; + } + } + + int oldInputLength = in.readableBytes(); + decodeRemovalReentryProtection(ctx, in, out); + + // Check if this handler was removed before continuing the loop. + // If it was removed, it is not safe to continue to operate on the buffer. + // + // See https://github.com/netty/netty/issues/1664 + if (ctx.isRemoved()) { + break; + } + + if (out.isEmpty()) { + if (oldInputLength == in.readableBytes()) { + break; + } else { + continue; + } + } + + if (oldInputLength == in.readableBytes()) { + throw new DecoderException( + StringUtil.simpleClassName(getClass()) + + ".decode() did not read anything but decoded a message."); + } + + if (isSingleDecode()) { + break; + } + } + } catch (DecoderException e) { + throw e; + } catch (Exception cause) { + throw new DecoderException(cause); + } + } + + /** + * Decode the from one {@link ByteBuf} to an other. This method will be called till either the input + * {@link ByteBuf} has nothing to read when return from this method or till nothing was read from the input + * {@link ByteBuf}. 
+ * + * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to + * @param in the {@link ByteBuf} from which to read data + * @param out the {@link List} to which decoded messages should be added + * @throws Exception is thrown if an error occurs + */ + protected abstract void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception; + + /** + * Decode the from one {@link ByteBuf} to an other. This method will be called till either the input + * {@link ByteBuf} has nothing to read when return from this method or till nothing was read from the input + * {@link ByteBuf}. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to + * @param in the {@link ByteBuf} from which to read data + * @param out the {@link List} to which decoded messages should be added + * @throws Exception is thrown if an error occurs + */ + final void decodeRemovalReentryProtection(ChannelHandlerContext ctx, ByteBuf in, List out) + throws Exception { + decodeState = STATE_CALLING_CHILD_DECODE; + try { + decode(ctx, in, out); + } finally { + boolean removePending = decodeState == STATE_HANDLER_REMOVED_PENDING; + decodeState = STATE_INIT; + if (removePending) { + fireChannelRead(ctx, out, out.size()); + out.clear(); + handlerRemoved(ctx); + } + } + } + + /** + * Is called one last time when the {@link ChannelHandlerContext} goes in-active. Which means the + * {@link #channelInactive(ChannelHandlerContext)} was triggered. + * + * By default, this will just call {@link #decode(ChannelHandlerContext, ByteBuf, List)} but sub-classes may + * override this for some special cleanup operation. + */ + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (in.isReadable()) { + // Only call decode() if there is something left in the buffer to decode. 
            // See https://github.com/netty/netty/issues/4386
            decodeRemovalReentryProtection(ctx, in, out);
        }
    }

    // Allocates a larger cumulation buffer and copies both the old cumulation and the new input into it.
    // Exactly one of the two candidate buffers is released in the finally block: the freshly allocated
    // buffer if the copy threw, otherwise the old (now superseded) cumulation.
    static ByteBuf expandCumulation(ByteBufAllocator alloc, ByteBuf oldCumulation, ByteBuf in) {
        int oldBytes = oldCumulation.readableBytes();
        int newBytes = in.readableBytes();
        int totalBytes = oldBytes + newBytes;
        ByteBuf newCumulation = alloc.buffer(alloc.calculateNewCapacity(totalBytes, MAX_VALUE));
        ByteBuf toRelease = newCumulation;
        try {
            // This avoids redundant checks and stack depth compared to calling writeBytes(...)
            newCumulation.setBytes(0, oldCumulation, oldCumulation.readerIndex(), oldBytes)
                    .setBytes(oldBytes, in, in.readerIndex(), newBytes)
                    .writerIndex(totalBytes);
            // Mark the input as fully consumed.
            in.readerIndex(in.writerIndex());
            // Copy succeeded: the old cumulation is now the buffer to release.
            toRelease = oldCumulation;
            return newCumulation;
        } finally {
            toRelease.release();
        }
    }

    /**
     * Cumulate {@link ByteBuf}s.
     */
    public interface Cumulator {
        /**
         * Cumulate the given {@link ByteBuf}s and return the {@link ByteBuf} that holds the cumulated bytes.
         * The implementation is responsible to correctly handle the life-cycle of the given {@link ByteBuf}s and so
         * call {@link ByteBuf#release()} if a {@link ByteBuf} is fully consumed.
         */
        ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in);
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java
new file mode 100644
index 0000000..cfcd60d
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/CharSequenceValueConverter.java
@@ -0,0 +1,150 @@
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
 * "License"); you may not use this file except in compliance with the License.
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec; + +import io.netty.util.AsciiString; +import io.netty.util.internal.PlatformDependent; + +import java.text.ParseException; +import java.util.Date; + +/** + * Converts to/from native types, general {@link Object}, and {@link CharSequence}s. + */ +public class CharSequenceValueConverter implements ValueConverter { + public static final CharSequenceValueConverter INSTANCE = new CharSequenceValueConverter(); + private static final AsciiString TRUE_ASCII = new AsciiString("true"); + + @Override + public CharSequence convertObject(Object value) { + if (value instanceof CharSequence) { + return (CharSequence) value; + } + return value.toString(); + } + + @Override + public CharSequence convertInt(int value) { + return String.valueOf(value); + } + + @Override + public CharSequence convertLong(long value) { + return String.valueOf(value); + } + + @Override + public CharSequence convertDouble(double value) { + return String.valueOf(value); + } + + @Override + public CharSequence convertChar(char value) { + return String.valueOf(value); + } + + @Override + public CharSequence convertBoolean(boolean value) { + return String.valueOf(value); + } + + @Override + public CharSequence convertFloat(float value) { + return String.valueOf(value); + } + + @Override + public boolean convertToBoolean(CharSequence value) { + return AsciiString.contentEqualsIgnoreCase(value, TRUE_ASCII); + } + + @Override + public CharSequence convertByte(byte value) { + return String.valueOf(value); + } + + @Override + public byte 
convertToByte(CharSequence value) { + if (value instanceof AsciiString && value.length() == 1) { + return ((AsciiString) value).byteAt(0); + } + return Byte.parseByte(value.toString()); + } + + @Override + public char convertToChar(CharSequence value) { + return value.charAt(0); + } + + @Override + public CharSequence convertShort(short value) { + return String.valueOf(value); + } + + @Override + public short convertToShort(CharSequence value) { + if (value instanceof AsciiString) { + return ((AsciiString) value).parseShort(); + } + return Short.parseShort(value.toString()); + } + + @Override + public int convertToInt(CharSequence value) { + if (value instanceof AsciiString) { + return ((AsciiString) value).parseInt(); + } + return Integer.parseInt(value.toString()); + } + + @Override + public long convertToLong(CharSequence value) { + if (value instanceof AsciiString) { + return ((AsciiString) value).parseLong(); + } + return Long.parseLong(value.toString()); + } + + @Override + public CharSequence convertTimeMillis(long value) { + return DateFormatter.format(new Date(value)); + } + + @Override + public long convertToTimeMillis(CharSequence value) { + Date date = DateFormatter.parseHttpDate(value); + if (date == null) { + PlatformDependent.throwException(new ParseException("header can't be parsed into a Date: " + value, 0)); + return 0; + } + return date.getTime(); + } + + @Override + public float convertToFloat(CharSequence value) { + if (value instanceof AsciiString) { + return ((AsciiString) value).parseFloat(); + } + return Float.parseFloat(value.toString()); + } + + @Override + public double convertToDouble(CharSequence value) { + if (value instanceof AsciiString) { + return ((AsciiString) value).parseDouble(); + } + return Double.parseDouble(value.toString()); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/CodecException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/CodecException.java new file mode 100644 
index 0000000..d7e8ed0
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/CodecException.java
@@ -0,0 +1,51 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec;

/**
 * An {@link Exception} which is thrown by a codec.
 * Root of the codec exception hierarchy (see {@link DecoderException} and {@link EncoderException}).
 */
public class CodecException extends RuntimeException {

    private static final long serialVersionUID = -1464830400709348473L;

    /**
     * Creates a new instance.
     */
    public CodecException() {
    }

    /**
     * Creates a new instance.
     */
    public CodecException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates a new instance.
     */
    public CodecException(String message) {
        super(message);
    }

    /**
     * Creates a new instance.
     */
    public CodecException(Throwable cause) {
        super(cause);
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/CodecOutputList.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/CodecOutputList.java
new file mode 100644
index 0000000..8f7b0fe
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/CodecOutputList.java
@@ -0,0 +1,232 @@
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.MathUtil; + +import java.util.AbstractList; +import java.util.RandomAccess; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Special {@link AbstractList} implementation which is used within our codec base classes. + */ +final class CodecOutputList extends AbstractList implements RandomAccess { + + private static final CodecOutputListRecycler NOOP_RECYCLER = new CodecOutputListRecycler() { + @Override + public void recycle(CodecOutputList object) { + // drop on the floor and let the GC handle it. + } + }; + + private static final FastThreadLocal CODEC_OUTPUT_LISTS_POOL = + new FastThreadLocal() { + @Override + protected CodecOutputLists initialValue() throws Exception { + // 16 CodecOutputList per Thread are cached. + return new CodecOutputLists(16); + } + }; + + private interface CodecOutputListRecycler { + void recycle(CodecOutputList codecOutputList); + } + + private static final class CodecOutputLists implements CodecOutputListRecycler { + private final CodecOutputList[] elements; + private final int mask; + + private int currentIdx; + private int count; + + CodecOutputLists(int numElements) { + elements = new CodecOutputList[MathUtil.safeFindNextPositivePowerOfTwo(numElements)]; + for (int i = 0; i < elements.length; ++i) { + // Size of 16 should be good enough for the majority of all users as an initial capacity. 
+ elements[i] = new CodecOutputList(this, 16); + } + count = elements.length; + currentIdx = elements.length; + mask = elements.length - 1; + } + + public CodecOutputList getOrCreate() { + if (count == 0) { + // Return a new CodecOutputList which will not be cached. We use a size of 4 to keep the overhead + // low. + return new CodecOutputList(NOOP_RECYCLER, 4); + } + --count; + + int idx = (currentIdx - 1) & mask; + CodecOutputList list = elements[idx]; + currentIdx = idx; + return list; + } + + @Override + public void recycle(CodecOutputList codecOutputList) { + int idx = currentIdx; + elements[idx] = codecOutputList; + currentIdx = (idx + 1) & mask; + ++count; + assert count <= elements.length; + } + } + + static CodecOutputList newInstance() { + return CODEC_OUTPUT_LISTS_POOL.get().getOrCreate(); + } + + private final CodecOutputListRecycler recycler; + private int size; + private Object[] array; + private boolean insertSinceRecycled; + + private CodecOutputList(CodecOutputListRecycler recycler, int size) { + this.recycler = recycler; + array = new Object[size]; + } + + @Override + public Object get(int index) { + checkIndex(index); + return array[index]; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean add(Object element) { + checkNotNull(element, "element"); + try { + insert(size, element); + } catch (IndexOutOfBoundsException ignore) { + // This should happen very infrequently so we just catch the exception and try again. 
+ expandArray(); + insert(size, element); + } + ++ size; + return true; + } + + @Override + public Object set(int index, Object element) { + checkNotNull(element, "element"); + checkIndex(index); + + Object old = array[index]; + insert(index, element); + return old; + } + + @Override + public void add(int index, Object element) { + checkNotNull(element, "element"); + checkIndex(index); + + if (size == array.length) { + expandArray(); + } + + if (index != size) { + System.arraycopy(array, index, array, index + 1, size - index); + } + + insert(index, element); + ++ size; + } + + @Override + public Object remove(int index) { + checkIndex(index); + Object old = array[index]; + + int len = size - index - 1; + if (len > 0) { + System.arraycopy(array, index + 1, array, index, len); + } + array[-- size] = null; + + return old; + } + + @Override + public void clear() { + // We only set the size to 0 and not null out the array. Null out the array will explicit requested by + // calling recycle() + size = 0; + } + + /** + * Returns {@code true} if any elements where added or set. This will be reset once {@link #recycle()} was called. + */ + boolean insertSinceRecycled() { + return insertSinceRecycled; + } + + /** + * Recycle the array which will clear it and null out all entries in the internal storage. + */ + void recycle() { + for (int i = 0 ; i < size; i ++) { + array[i] = null; + } + size = 0; + insertSinceRecycled = false; + + recycler.recycle(this); + } + + /** + * Returns the element on the given index. This operation will not do any range-checks and so is considered unsafe. 
+ */ + Object getUnsafe(int index) { + return array[index]; + } + + private void checkIndex(int index) { + if (index >= size) { + throw new IndexOutOfBoundsException("expected: index < (" + + size + "),but actual is (" + size + ")"); + } + } + + private void insert(int index, Object element) { + array[index] = element; + insertSinceRecycled = true; + } + + private void expandArray() { + // double capacity + int newCapacity = array.length << 1; + + if (newCapacity < 0) { + throw new OutOfMemoryError(); + } + + Object[] newArray = new Object[newCapacity]; + System.arraycopy(array, 0, newArray, 0, array.length); + + array = newArray; + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/CorruptedFrameException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/CorruptedFrameException.java new file mode 100644 index 0000000..b98f35b --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/CorruptedFrameException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * An {@link DecoderException} which is thrown when the received frame data could not be decoded by + * an inbound handler. + */ +public class CorruptedFrameException extends DecoderException { + + private static final long serialVersionUID = 3918052232492988408L; + + /** + * Creates a new instance. 
     */
    public CorruptedFrameException() {
    }

    /**
     * Creates a new instance.
     */
    public CorruptedFrameException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates a new instance.
     */
    public CorruptedFrameException(String message) {
        super(message);
    }

    /**
     * Creates a new instance.
     */
    public CorruptedFrameException(Throwable cause) {
        super(cause);
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketDecoder.java
new file mode 100644
index 0000000..9f43cd8
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketDecoder.java
@@ -0,0 +1,115 @@
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.socket.DatagramPacket;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

import java.util.List;

/**
 * A decoder that decodes the content of the received {@link DatagramPacket} using
 * the specified {@link ByteBuf} decoder. E.g.,
 *
 *
 * <pre>
 * {@link ChannelPipeline} pipeline = ...;
 * pipeline.addLast("udpDecoder", new {@link DatagramPacketDecoder}(new ProtobufDecoder(...)));
 * </pre>
+ */ +public class DatagramPacketDecoder extends MessageToMessageDecoder { + + private final MessageToMessageDecoder decoder; + + /** + * Create a {@link DatagramPacket} decoder using the specified {@link ByteBuf} decoder. + * + * @param decoder the specified {@link ByteBuf} decoder + */ + public DatagramPacketDecoder(MessageToMessageDecoder decoder) { + this.decoder = checkNotNull(decoder, "decoder"); + } + + @Override + public boolean acceptInboundMessage(Object msg) throws Exception { + if (msg instanceof DatagramPacket) { + return decoder.acceptInboundMessage(((DatagramPacket) msg).content()); + } + return false; + } + + @Override + protected void decode(ChannelHandlerContext ctx, DatagramPacket msg, List out) throws Exception { + decoder.decode(ctx, msg.content(), out); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + decoder.channelRegistered(ctx); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + decoder.channelUnregistered(ctx); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + decoder.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + decoder.channelInactive(ctx); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + decoder.channelReadComplete(ctx); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + decoder.userEventTriggered(ctx, evt); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + decoder.channelWritabilityChanged(ctx); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + decoder.exceptionCaught(ctx, cause); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + decoder.handlerAdded(ctx); + 
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        decoder.handlerRemoved(ctx);
    }

    @Override
    public boolean isSharable() {
        // Sharability is dictated by the wrapped decoder.
        return decoder.isSharable();
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketEncoder.java
new file mode 100644
index 0000000..77c83cb
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DatagramPacketEncoder.java
@@ -0,0 +1,147 @@
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec;

import io.netty.buffer.ByteBuf;
import io.netty.channel.AddressedEnvelope;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.ChannelPromise;
import io.netty.channel.socket.DatagramPacket;
import io.netty.util.internal.StringUtil;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.List;

/**
 * An encoder that encodes the content in {@link AddressedEnvelope} to {@link DatagramPacket} using
 * the specified message encoder. E.g.,
 *
 *
 * <pre>
 * {@link ChannelPipeline} pipeline = ...;
 * pipeline.addLast("udpEncoder", new {@link DatagramPacketEncoder}(new ProtobufEncoder(...)));
 * </pre>
+ * + * Note: As UDP packets are out-of-order, you should make sure the encoded message size are not greater than + * the max safe packet size in your particular network path which guarantees no packet fragmentation. + * + * @param the type of message to be encoded + */ +public class DatagramPacketEncoder extends MessageToMessageEncoder> { + + private final MessageToMessageEncoder encoder; + + /** + * Create an encoder that encodes the content in {@link AddressedEnvelope} to {@link DatagramPacket} using + * the specified message encoder. + * + * @param encoder the specified message encoder + */ + public DatagramPacketEncoder(MessageToMessageEncoder encoder) { + this.encoder = checkNotNull(encoder, "encoder"); + } + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + if (super.acceptOutboundMessage(msg)) { + @SuppressWarnings("rawtypes") + AddressedEnvelope envelope = (AddressedEnvelope) msg; + return encoder.acceptOutboundMessage(envelope.content()) + && (envelope.sender() instanceof InetSocketAddress || envelope.sender() == null) + && envelope.recipient() instanceof InetSocketAddress; + } + return false; + } + + @Override + protected void encode( + ChannelHandlerContext ctx, AddressedEnvelope msg, List out) throws Exception { + assert out.isEmpty(); + + encoder.encode(ctx, msg.content(), out); + if (out.size() != 1) { + throw new EncoderException( + StringUtil.simpleClassName(encoder) + " must produce only one message."); + } + Object content = out.get(0); + if (content instanceof ByteBuf) { + // Replace the ByteBuf with a DatagramPacket. 
            out.set(0, new DatagramPacket((ByteBuf) content, msg.recipient(), msg.sender()));
        } else {
            throw new EncoderException(
                    StringUtil.simpleClassName(encoder) + " must produce only ByteBuf.");
        }
    }

    // All outbound operations and lifecycle callbacks below are forwarded verbatim to the wrapped encoder.

    @Override
    public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception {
        encoder.bind(ctx, localAddress, promise);
    }

    @Override
    public void connect(
            ChannelHandlerContext ctx, SocketAddress remoteAddress,
            SocketAddress localAddress, ChannelPromise promise) throws Exception {
        encoder.connect(ctx, remoteAddress, localAddress, promise);
    }

    @Override
    public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        encoder.disconnect(ctx, promise);
    }

    @Override
    public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        encoder.close(ctx, promise);
    }

    @Override
    public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception {
        encoder.deregister(ctx, promise);
    }

    @Override
    public void read(ChannelHandlerContext ctx) throws Exception {
        encoder.read(ctx);
    }

    @Override
    public void flush(ChannelHandlerContext ctx) throws Exception {
        encoder.flush(ctx);
    }

    @Override
    public void handlerAdded(ChannelHandlerContext ctx) throws Exception {
        encoder.handlerAdded(ctx);
    }

    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {
        encoder.handlerRemoved(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        encoder.exceptionCaught(ctx, cause);
    }

    @Override
    public boolean isSharable() {
        // Sharability is dictated by the wrapped encoder.
        return encoder.isSharable();
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DateFormatter.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DateFormatter.java
new file mode 100644
index 0000000..8897cb5
--- /dev/null
+++
b/netty-handler-codec/src/main/java/io/netty/handler/codec/DateFormatter.java
@@ -0,0 +1,448 @@
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

import io.netty.util.AsciiString;
import io.netty.util.concurrent.FastThreadLocal;

import java.util.BitSet;
import java.util.Calendar;
import java.util.Date;
import java.util.GregorianCalendar;
import java.util.TimeZone;

/**
 * A formatter for HTTP header dates, such as "Expires" and "Date" headers, or "expires" field in "Set-Cookie".
 *
 * On the parsing side, it honors RFC6265 (so it supports RFC1123).
 * Note that:
 * <ul>
 *     <li>Day of week is ignored and not validated</li>
 *     <li>Timezone is ignored, as RFC6265 assumes UTC</li>
 * </ul>
+ * If you're looking for a date format that validates day of week, or supports other timezones, consider using + * java.util.DateTimeFormatter.RFC_1123_DATE_TIME. + * + * On the formatting side, it uses a subset of RFC1123 (2 digit day-of-month and 4 digit year) as per RFC2616. + * This subset supports RFC6265. + * + * @see RFC6265 for the parsing side + * @see RFC1123 and + * RFC2616 for the encoding side. + */ +public final class DateFormatter { + + private static final BitSet DELIMITERS = new BitSet(); + static { + DELIMITERS.set(0x09); + for (char c = 0x20; c <= 0x2F; c++) { + DELIMITERS.set(c); + } + for (char c = 0x3B; c <= 0x40; c++) { + DELIMITERS.set(c); + } + for (char c = 0x5B; c <= 0x60; c++) { + DELIMITERS.set(c); + } + for (char c = 0x7B; c <= 0x7E; c++) { + DELIMITERS.set(c); + } + } + + private static final String[] DAY_OF_WEEK_TO_SHORT_NAME = + new String[]{"Sun", "Mon", "Tue", "Wed", "Thu", "Fri", "Sat"}; + + private static final String[] CALENDAR_MONTH_TO_SHORT_NAME = + new String[]{"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"}; + + private static final FastThreadLocal INSTANCES = + new FastThreadLocal() { + @Override + protected DateFormatter initialValue() { + return new DateFormatter(); + } + }; + + /** + * Parse some text into a {@link Date}, according to RFC6265 + * @param txt text to parse + * @return a {@link Date}, or null if text couldn't be parsed + */ + public static Date parseHttpDate(CharSequence txt) { + return parseHttpDate(txt, 0, txt.length()); + } + + /** + * Parse some text into a {@link Date}, according to RFC6265 + * @param txt text to parse + * @param start the start index inside {@code txt} + * @param end the end index inside {@code txt} + * @return a {@link Date}, or null if text couldn't be parsed + */ + public static Date parseHttpDate(CharSequence txt, int start, int end) { + int length = end - start; + if (length == 0) { + return null; + } else if (length < 0) { + throw new 
 IllegalArgumentException("Can't have end < start");
        } else if (length > 64) {
            // Sanity bound: no legitimate HTTP date is this long.
            throw new IllegalArgumentException("Can't parse more than 64 chars, " +
                    "looks like a user error or a malformed header");
        }
        return formatter().parse0(checkNotNull(txt, "txt"), start, end);
    }

    /**
     * Format a {@link Date} into RFC1123 format
     * @param date the date to format
     * @return a RFC1123 string
     */
    public static String format(Date date) {
        return formatter().format0(checkNotNull(date, "date"));
    }

    /**
     * Append a {@link Date} to a {@link StringBuilder} into RFC1123 format
     * @param date the date to format
     * @param sb the StringBuilder
     * @return the same StringBuilder
     */
    public static StringBuilder append(Date date, StringBuilder sb) {
        return formatter().append0(checkNotNull(date, "date"), checkNotNull(sb, "sb"));
    }

    // Returns this thread's formatter, reset to a clean state.
    private static DateFormatter formatter() {
        DateFormatter formatter = INSTANCES.get();
        formatter.reset();
        return formatter;
    }

    // delimiter = %x09 / %x20-2F / %x3B-40 / %x5B-60 / %x7B-7E
    private static boolean isDelim(char c) {
        return DELIMITERS.get(c);
    }

    private static boolean isDigit(char c) {
        // '0'..'9' in ASCII
        return c >= 48 && c <= 57;
    }

    private static int getNumericalValue(char c) {
        return c - 48;
    }

    private final GregorianCalendar cal = new GregorianCalendar(TimeZone.getTimeZone("UTC"));
    private final StringBuilder sb = new StringBuilder(29); // Sun, 27 Nov 2016 19:37:15 GMT
    // Parsing state: each component is tracked with a found-flag plus its value.
    private boolean timeFound;
    private int hours;
    private int minutes;
    private int seconds;
    private boolean dayOfMonthFound;
    private int dayOfMonth;
    private boolean monthFound;
    private int month;
    private boolean yearFound;
    private int year;

    private DateFormatter() {
        reset();
    }

    public void reset() {
        timeFound = false;
        hours = -1;
        minutes = -1;
        seconds = -1;
        dayOfMonthFound = false;
        dayOfMonth = -1;
        monthFound = false;
        month = -1;
        yearFound = false;
        year = -1;
        cal.clear();
        sb.setLength(0);
    }

    // Parses "h:m:s" .. "hh:mm:ss" into hours/minutes/seconds; returns true on success.
    private boolean tryParseTime(CharSequence txt, int tokenStart, int tokenEnd) {
        int len = tokenEnd - tokenStart;

        // h:m:s to hh:mm:ss
        if (len < 5 || len > 8) {
            return false;
        }

        int localHours = -1;
        int localMinutes = -1;
        int localSeconds = -1;
        int currentPartNumber = 0;
        int currentPartValue = 0;
        int numDigits = 0;

        for (int i = tokenStart; i < tokenEnd; i++) {
            char c = txt.charAt(i);
            if (isDigit(c)) {
                currentPartValue = currentPartValue * 10 + getNumericalValue(c);
                if (++numDigits > 2) {
                    return false; // too many digits in this part
                }
            } else if (c == ':') {
                if (numDigits == 0) {
                    // no digits between separators
                    return false;
                }
                switch (currentPartNumber) {
                case 0:
                    // flushing hours
                    localHours = currentPartValue;
                    break;
                case 1:
                    // flushing minutes
                    localMinutes = currentPartValue;
                    break;
                default:
                    // invalid, too many :
                    return false;
                }
                currentPartValue = 0;
                currentPartNumber++;
                numDigits = 0;
            } else {
                // invalid char
                return false;
            }
        }

        if (numDigits > 0) {
            // pending seconds
            localSeconds = currentPartValue;
        }

        // Commit to fields only when all three parts were present.
        if (localHours >= 0 && localMinutes >= 0 && localSeconds >= 0) {
            hours = localHours;
            minutes = localMinutes;
            seconds = localSeconds;
            return true;
        }

        return false;
    }

    // Parses a 1- or 2-digit day of month; returns true on success.
    private boolean tryParseDayOfMonth(CharSequence txt, int tokenStart, int tokenEnd) {
        int len = tokenEnd - tokenStart;

        if (len == 1) {
            char c0 = txt.charAt(tokenStart);
            if (isDigit(c0)) {
                dayOfMonth = getNumericalValue(c0);
                return true;
            }

        } else if (len == 2) {
            char c0 = txt.charAt(tokenStart);
            char c1 = txt.charAt(tokenStart + 1);
            if (isDigit(c0) && isDigit(c1)) {
                dayOfMonth = getNumericalValue(c0) * 10 + getNumericalValue(c1);
                return true;
            }
        }

        return false;
    }

    // Parses a case-insensitive 3-letter month abbreviation; returns true on success.
    private boolean tryParseMonth(CharSequence txt, int tokenStart, int tokenEnd) {
        int len = tokenEnd - tokenStart;

        if (len != 3) {
            return false;
        }
        // Compare lower-cased characters to match month names case-insensitively.
        char monthChar1 = AsciiString.toLowerCase(txt.charAt(tokenStart));
        char monthChar2 = AsciiString.toLowerCase(txt.charAt(tokenStart + 1));
        char monthChar3 = AsciiString.toLowerCase(txt.charAt(tokenStart + 2));

        if (monthChar1 == 'j' && monthChar2 == 'a' && monthChar3 == 'n') {
            month = Calendar.JANUARY;
        } else if (monthChar1 == 'f' && monthChar2 == 'e' && monthChar3 == 'b') {
            month = Calendar.FEBRUARY;
        } else if (monthChar1 == 'm' && monthChar2 == 'a' && monthChar3 == 'r') {
            month = Calendar.MARCH;
        } else if (monthChar1 == 'a' && monthChar2 == 'p' && monthChar3 == 'r') {
            month = Calendar.APRIL;
        } else if (monthChar1 == 'm' && monthChar2 == 'a' && monthChar3 == 'y') {
            month = Calendar.MAY;
        } else if (monthChar1 == 'j' && monthChar2 == 'u' && monthChar3 == 'n') {
            month = Calendar.JUNE;
        } else if (monthChar1 == 'j' && monthChar2 == 'u' && monthChar3 == 'l') {
            month = Calendar.JULY;
        } else if (monthChar1 == 'a' && monthChar2 == 'u' && monthChar3 == 'g') {
            month = Calendar.AUGUST;
        } else if (monthChar1 == 's' && monthChar2 == 'e' && monthChar3 == 'p') {
            month = Calendar.SEPTEMBER;
        } else if (monthChar1 == 'o' && monthChar2 == 'c' && monthChar3 == 't') {
            month = Calendar.OCTOBER;
        } else if (monthChar1 == 'n' && monthChar2 == 'o' && monthChar3 == 'v') {
            month = Calendar.NOVEMBER;
        } else if (monthChar1 == 'd' && monthChar2 == 'e' && monthChar3 == 'c') {
            month = Calendar.DECEMBER;
        } else {
            return false;
        }

        return true;
    }

    // Parses a 2- or 4-digit year; returns true on success (2-digit years are
    // normalized later in normalizeAndValidate()).
    private boolean tryParseYear(CharSequence txt, int tokenStart, int tokenEnd) {
        int len = tokenEnd - tokenStart;

        if (len == 2) {
            char c0 = txt.charAt(tokenStart);
            char c1 = txt.charAt(tokenStart + 1);
            if (isDigit(c0) && isDigit(c1)) {
                year = getNumericalValue(c0) * 10 + getNumericalValue(c1);
                return true;
            }

        } else if (len == 4) {
            char c0 = txt.charAt(tokenStart);
            char c1 = txt.charAt(tokenStart + 1);
            char c2 = txt.charAt(tokenStart + 2);
            char c3 =
txt.charAt(tokenStart + 3); + if (isDigit(c0) && isDigit(c1) && isDigit(c2) && isDigit(c3)) { + year = getNumericalValue(c0) * 1000 + + getNumericalValue(c1) * 100 + + getNumericalValue(c2) * 10 + + getNumericalValue(c3); + return true; + } + } + + return false; + } + + private boolean parseToken(CharSequence txt, int tokenStart, int tokenEnd) { + // return true if all parts are found + if (!timeFound) { + timeFound = tryParseTime(txt, tokenStart, tokenEnd); + if (timeFound) { + return dayOfMonthFound && monthFound && yearFound; + } + } + + if (!dayOfMonthFound) { + dayOfMonthFound = tryParseDayOfMonth(txt, tokenStart, tokenEnd); + if (dayOfMonthFound) { + return timeFound && monthFound && yearFound; + } + } + + if (!monthFound) { + monthFound = tryParseMonth(txt, tokenStart, tokenEnd); + if (monthFound) { + return timeFound && dayOfMonthFound && yearFound; + } + } + + if (!yearFound) { + yearFound = tryParseYear(txt, tokenStart, tokenEnd); + } + return timeFound && dayOfMonthFound && monthFound && yearFound; + } + + private Date parse0(CharSequence txt, int start, int end) { + boolean allPartsFound = parse1(txt, start, end); + return allPartsFound && normalizeAndValidate() ? 
computeDate() : null; + } + + private boolean parse1(CharSequence txt, int start, int end) { + // return true if all parts are found + int tokenStart = -1; + + for (int i = start; i < end; i++) { + char c = txt.charAt(i); + + if (isDelim(c)) { + if (tokenStart != -1) { + // terminate token + if (parseToken(txt, tokenStart, i)) { + return true; + } + tokenStart = -1; + } + } else if (tokenStart == -1) { + // start new token + tokenStart = i; + } + } + + // terminate trailing token + return tokenStart != -1 && parseToken(txt, tokenStart, txt.length()); + } + + private boolean normalizeAndValidate() { + if (dayOfMonth < 1 + || dayOfMonth > 31 + || hours > 23 + || minutes > 59 + || seconds > 59) { + return false; + } + + if (year >= 70 && year <= 99) { + year += 1900; + } else if (year >= 0 && year < 70) { + year += 2000; + } else if (year < 1601) { + // invalid value + return false; + } + return true; + } + + private Date computeDate() { + cal.set(Calendar.DAY_OF_MONTH, dayOfMonth); + cal.set(Calendar.MONTH, month); + cal.set(Calendar.YEAR, year); + cal.set(Calendar.HOUR_OF_DAY, hours); + cal.set(Calendar.MINUTE, minutes); + cal.set(Calendar.SECOND, seconds); + return cal.getTime(); + } + + private String format0(Date date) { + append0(date, sb); + return sb.toString(); + } + + private StringBuilder append0(Date date, StringBuilder sb) { + cal.setTime(date); + + sb.append(DAY_OF_WEEK_TO_SHORT_NAME[cal.get(Calendar.DAY_OF_WEEK) - 1]).append(", "); + appendZeroLeftPadded(cal.get(Calendar.DAY_OF_MONTH), sb).append(' '); + sb.append(CALENDAR_MONTH_TO_SHORT_NAME[cal.get(Calendar.MONTH)]).append(' '); + sb.append(cal.get(Calendar.YEAR)).append(' '); + appendZeroLeftPadded(cal.get(Calendar.HOUR_OF_DAY), sb).append(':'); + appendZeroLeftPadded(cal.get(Calendar.MINUTE), sb).append(':'); + return appendZeroLeftPadded(cal.get(Calendar.SECOND), sb).append(" GMT"); + } + + private static StringBuilder appendZeroLeftPadded(int value, StringBuilder sb) { + if (value < 10) { + 
            sb.append('0');
        }
        return sb.append(value);
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderException.java
new file mode 100644
index 0000000..0a1ff99
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderException.java
@@ -0,0 +1,51 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec;

/**
 * A {@link CodecException} which is thrown by a decoder when inbound data
 * cannot be decoded.
 */
public class DecoderException extends CodecException {

    private static final long serialVersionUID = 6926716840699621852L;

    /**
     * Creates a new instance with neither a message nor a cause.
     */
    public DecoderException() {
    }

    /**
     * Creates a new instance.
     *
     * @param message the detail message
     * @param cause   the underlying cause of the decoding failure
     */
    public DecoderException(String message, Throwable cause) {
        super(message, cause);
    }

    /**
     * Creates a new instance.
     *
     * @param message the detail message
     */
    public DecoderException(String message) {
        super(message);
    }

    /**
     * Creates a new instance.
     *
     * @param cause the underlying cause of the decoding failure
     */
    public DecoderException(Throwable cause) {
        super(cause);
    }
}
diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResult.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResult.java
new file mode 100644
index 0000000..58d91b2
--- /dev/null
+++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResult.java
@@ -0,0 +1,76 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.handler.codec; + +import io.netty.util.Signal; +import io.netty.util.internal.ObjectUtil; + +public class DecoderResult { + + protected static final Signal SIGNAL_UNFINISHED = Signal.valueOf(DecoderResult.class, "UNFINISHED"); + protected static final Signal SIGNAL_SUCCESS = Signal.valueOf(DecoderResult.class, "SUCCESS"); + + public static final DecoderResult UNFINISHED = new DecoderResult(SIGNAL_UNFINISHED); + public static final DecoderResult SUCCESS = new DecoderResult(SIGNAL_SUCCESS); + + public static DecoderResult failure(Throwable cause) { + return new DecoderResult(ObjectUtil.checkNotNull(cause, "cause")); + } + + private final Throwable cause; + + protected DecoderResult(Throwable cause) { + this.cause = ObjectUtil.checkNotNull(cause, "cause"); + } + + public boolean isFinished() { + return cause != SIGNAL_UNFINISHED; + } + + public boolean isSuccess() { + return cause == SIGNAL_SUCCESS; + } + + public boolean isFailure() { + return cause != SIGNAL_SUCCESS && cause != SIGNAL_UNFINISHED; + } + + public Throwable cause() { + if (isFailure()) { + return cause; + } else { + return null; + } + } + + @Override + public String toString() { + if (isFinished()) { + if (isSuccess()) { + return "success"; + } + + String cause = cause().toString(); + return new StringBuilder(cause.length() + 17) + .append("failure(") + .append(cause) + .append(')') + .toString(); + } else { + return "unfinished"; + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResultProvider.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResultProvider.java new file mode 100644 index 0000000..1418c90 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DecoderResultProvider.java @@ -0,0 +1,33 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + 
* with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec; + +/** + * Provides the accessor methods for the {@link DecoderResult} property of a decoded message. + */ +public interface DecoderResultProvider { + /** + * Returns the result of decoding this object. + */ + DecoderResult decoderResult(); + + /** + * Updates the result of decoding this object. This method is supposed to be invoked by a decoder. + * Do not call this method unless you know what you are doing. + */ + void setDecoderResult(DecoderResult result); +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DefaultHeaders.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DefaultHeaders.java new file mode 100644 index 0000000..6ff05fd --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DefaultHeaders.java @@ -0,0 +1,1446 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.util.HashingStrategy; + +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.LinkedList; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; +import java.util.Set; + +import static io.netty.util.HashingStrategy.JAVA_HASHER; +import static io.netty.util.internal.MathUtil.findNextPositivePowerOfTwo; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Math.max; +import static java.lang.Math.min; + +/** + * Default implementation of {@link Headers}; + * + * @param the type of the header name. + * @param the type of the header value. + * @param the type to use for return values when the intention is to return {@code this} object. + */ +public class DefaultHeaders> implements Headers { + /** + * Constant used to seed the hash code generation. Could be anything but this was borrowed from murmur3. + */ + static final int HASH_CODE_SEED = 0xc2b2ae35; + + private final HeaderEntry[] entries; + protected final HeaderEntry head; + + private final byte hashMask; + private final ValueConverter valueConverter; + private final NameValidator nameValidator; + private final ValueValidator valueValidator; + private final HashingStrategy hashingStrategy; + int size; + + public interface NameValidator { + /** + * Verify that {@code name} is valid. + * @param name The name to validate. + * @throws RuntimeException if {@code name} is not valid. + */ + void validateName(K name); + + @SuppressWarnings("rawtypes") + NameValidator NOT_NULL = new NameValidator() { + @Override + public void validateName(Object name) { + checkNotNull(name, "name"); + } + }; + } + + public interface ValueValidator { + /** + * Validate the given value. If the validation fails, then an implementation specific runtime exception may be + * thrown. 
+ * + * @param value The value to validate. + */ + void validate(V value); + + ValueValidator NO_VALIDATION = new ValueValidator() { + @Override + public void validate(Object value) { + } + }; + } + + @SuppressWarnings("unchecked") + public DefaultHeaders(ValueConverter valueConverter) { + this(JAVA_HASHER, valueConverter); + } + + @SuppressWarnings("unchecked") + public DefaultHeaders(ValueConverter valueConverter, NameValidator nameValidator) { + this(JAVA_HASHER, valueConverter, nameValidator); + } + + @SuppressWarnings("unchecked") + public DefaultHeaders(HashingStrategy nameHashingStrategy, ValueConverter valueConverter) { + this(nameHashingStrategy, valueConverter, NameValidator.NOT_NULL); + } + + public DefaultHeaders(HashingStrategy nameHashingStrategy, + ValueConverter valueConverter, NameValidator nameValidator) { + this(nameHashingStrategy, valueConverter, nameValidator, 16); + } + + /** + * Create a new instance. + * @param nameHashingStrategy Used to hash and equality compare names. + * @param valueConverter Used to convert values to/from native types. + * @param nameValidator Used to validate name elements. + * @param arraySizeHint A hint as to how large the hash data structure should be. + * The next positive power of two will be used. An upper bound may be enforced. + */ + @SuppressWarnings("unchecked") + public DefaultHeaders(HashingStrategy nameHashingStrategy, + ValueConverter valueConverter, NameValidator nameValidator, int arraySizeHint) { + this(nameHashingStrategy, valueConverter, nameValidator, arraySizeHint, + (ValueValidator) ValueValidator.NO_VALIDATION); + } + + /** + * Create a new instance. + * @param nameHashingStrategy Used to hash and equality compare names. + * @param valueConverter Used to convert values to/from native types. + * @param nameValidator Used to validate name elements. + * @param arraySizeHint A hint as to how large the hash data structure should be. + * The next positive power of two will be used. 
An upper bound may be enforced. + * @param valueValidator The validation strategy for entry values. + */ + @SuppressWarnings("unchecked") + public DefaultHeaders(HashingStrategy nameHashingStrategy, ValueConverter valueConverter, + NameValidator nameValidator, int arraySizeHint, ValueValidator valueValidator) { + this.valueConverter = checkNotNull(valueConverter, "valueConverter"); + this.nameValidator = checkNotNull(nameValidator, "nameValidator"); + hashingStrategy = checkNotNull(nameHashingStrategy, "nameHashingStrategy"); + this.valueValidator = checkNotNull(valueValidator, "valueValidator"); + // Enforce a bound of [2, 128] because hashMask is a byte. The max possible value of hashMask is one less + // than the length of this array, and we want the mask to be > 0. + entries = new HeaderEntry[findNextPositivePowerOfTwo(max(2, min(arraySizeHint, 128)))]; + hashMask = (byte) (entries.length - 1); + head = new HeaderEntry(); + } + + @Override + public V get(K name) { + checkNotNull(name, "name"); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + HeaderEntry e = entries[i]; + V value = null; + // loop until the first header was found + while (e != null) { + if (e.hash == h && hashingStrategy.equals(name, e.key)) { + value = e.value; + } + + e = e.next; + } + return value; + } + + @Override + public V get(K name, V defaultValue) { + V value = get(name); + if (value == null) { + return defaultValue; + } + return value; + } + + @Override + public V getAndRemove(K name) { + int h = hashingStrategy.hashCode(name); + return remove0(h, index(h), checkNotNull(name, "name")); + } + + @Override + public V getAndRemove(K name, V defaultValue) { + V value = getAndRemove(name); + if (value == null) { + return defaultValue; + } + return value; + } + + @Override + public List getAll(K name) { + checkNotNull(name, "name"); + + LinkedList values = new LinkedList(); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + HeaderEntry e = entries[i]; + while 
(e != null) { + if (e.hash == h && hashingStrategy.equals(name, e.key)) { + values.addFirst(e.getValue()); + } + e = e.next; + } + return values; + } + + /** + * Equivalent to {@link #getAll(Object)} but no intermediate list is generated. + * @param name the name of the header to retrieve + * @return an {@link Iterator} of header values corresponding to {@code name}. + */ + public Iterator valueIterator(K name) { + return new ValueIterator(name); + } + + @Override + public List getAllAndRemove(K name) { + List all = getAll(name); + remove(name); + return all; + } + + @Override + public boolean contains(K name) { + return get(name) != null; + } + + @Override + public boolean containsObject(K name, Object value) { + return contains(name, fromObject(name, value)); + } + + @Override + public boolean containsBoolean(K name, boolean value) { + return contains(name, fromBoolean(name, value)); + } + + @Override + public boolean containsByte(K name, byte value) { + return contains(name, fromByte(name, value)); + } + + @Override + public boolean containsChar(K name, char value) { + return contains(name, fromChar(name, value)); + } + + @Override + public boolean containsShort(K name, short value) { + return contains(name, fromShort(name, value)); + } + + @Override + public boolean containsInt(K name, int value) { + return contains(name, fromInt(name, value)); + } + + @Override + public boolean containsLong(K name, long value) { + return contains(name, fromLong(name, value)); + } + + @Override + public boolean containsFloat(K name, float value) { + return contains(name, fromFloat(name, value)); + } + + @Override + public boolean containsDouble(K name, double value) { + return contains(name, fromDouble(name, value)); + } + + @Override + public boolean containsTimeMillis(K name, long value) { + return contains(name, fromTimeMillis(name, value)); + } + + @SuppressWarnings("unchecked") + @Override + public boolean contains(K name, V value) { + return contains(name, value, 
JAVA_HASHER); + } + + public final boolean contains(K name, V value, HashingStrategy valueHashingStrategy) { + checkNotNull(name, "name"); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + HeaderEntry e = entries[i]; + while (e != null) { + if (e.hash == h && hashingStrategy.equals(name, e.key) && valueHashingStrategy.equals(value, e.value)) { + return true; + } + e = e.next; + } + return false; + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return head == head.after; + } + + @Override + public Set names() { + if (isEmpty()) { + return Collections.emptySet(); + } + Set names = new LinkedHashSet(size()); + HeaderEntry e = head.after; + while (e != head) { + names.add(e.getKey()); + e = e.after; + } + return names; + } + + @Override + public T add(K name, V value) { + validateName(nameValidator, true, name); + validateValue(valueValidator, name, value); + checkNotNull(value, "value"); + int h = hashingStrategy.hashCode(name); + int i = index(h); + add0(h, i, name, value); + return thisT(); + } + + @Override + public T add(K name, Iterable values) { + validateName(nameValidator, true, name); + int h = hashingStrategy.hashCode(name); + int i = index(h); + for (V v: values) { + validateValue(valueValidator, name, v); + add0(h, i, name, v); + } + return thisT(); + } + + @Override + public T add(K name, V... values) { + validateName(nameValidator, true, name); + int h = hashingStrategy.hashCode(name); + int i = index(h); + for (V v: values) { + validateValue(valueValidator, name, v); + add0(h, i, name, v); + } + return thisT(); + } + + @Override + public T addObject(K name, Object value) { + return add(name, fromObject(name, value)); + } + + @Override + public T addObject(K name, Iterable values) { + for (Object value : values) { + addObject(name, value); + } + return thisT(); + } + + @Override + public T addObject(K name, Object... 
values) { + for (Object value: values) { + addObject(name, value); + } + return thisT(); + } + + @Override + public T addInt(K name, int value) { + return add(name, fromInt(name, value)); + } + + @Override + public T addLong(K name, long value) { + return add(name, fromLong(name, value)); + } + + @Override + public T addDouble(K name, double value) { + return add(name, fromDouble(name, value)); + } + + @Override + public T addTimeMillis(K name, long value) { + return add(name, fromTimeMillis(name, value)); + } + + @Override + public T addChar(K name, char value) { + return add(name, fromChar(name, value)); + } + + @Override + public T addBoolean(K name, boolean value) { + return add(name, fromBoolean(name, value)); + } + + @Override + public T addFloat(K name, float value) { + return add(name, fromFloat(name, value)); + } + + @Override + public T addByte(K name, byte value) { + return add(name, fromByte(name, value)); + } + + @Override + public T addShort(K name, short value) { + return add(name, fromShort(name, value)); + } + + @Override + public T add(Headers headers) { + if (headers == this) { + throw new IllegalArgumentException("can't add to itself."); + } + addImpl(headers); + return thisT(); + } + + protected void addImpl(Headers headers) { + if (headers instanceof DefaultHeaders) { + @SuppressWarnings("unchecked") + final DefaultHeaders defaultHeaders = + (DefaultHeaders) headers; + HeaderEntry e = defaultHeaders.head.after; + if (defaultHeaders.hashingStrategy == hashingStrategy && + defaultHeaders.nameValidator == nameValidator) { + // Fastest copy + while (e != defaultHeaders.head) { + add0(e.hash, index(e.hash), e.key, e.value); + e = e.after; + } + } else { + // Fast copy + while (e != defaultHeaders.head) { + add(e.key, e.value); + e = e.after; + } + } + } else { + // Slow copy + for (Entry header : headers) { + add(header.getKey(), header.getValue()); + } + } + } + + @Override + public T set(K name, V value) { + validateName(nameValidator, false, 
name); + validateValue(valueValidator, name, value); + checkNotNull(value, "value"); + int h = hashingStrategy.hashCode(name); + int i = index(h); + remove0(h, i, name); + add0(h, i, name, value); + return thisT(); + } + + @Override + public T set(K name, Iterable values) { + validateName(nameValidator, false, name); + checkNotNull(values, "values"); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + + remove0(h, i, name); + for (V v: values) { + if (v == null) { + break; + } + validateValue(valueValidator, name, v); + add0(h, i, name, v); + } + + return thisT(); + } + + @Override + public T set(K name, V... values) { + validateName(nameValidator, false, name); + checkNotNull(values, "values"); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + + remove0(h, i, name); + for (V v: values) { + if (v == null) { + break; + } + validateValue(valueValidator, name, v); + add0(h, i, name, v); + } + + return thisT(); + } + + @Override + public T setObject(K name, Object value) { + V convertedValue = checkNotNull(fromObject(name, value), "convertedValue"); + return set(name, convertedValue); + } + + @Override + public T setObject(K name, Iterable values) { + validateName(nameValidator, false, name); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + + remove0(h, i, name); + for (Object v: values) { + if (v == null) { + break; + } + V converted = fromObject(name, v); + validateValue(valueValidator, name, converted); + add0(h, i, name, converted); + } + + return thisT(); + } + + @Override + public T setObject(K name, Object... 
values) { + validateName(nameValidator, false, name); + + int h = hashingStrategy.hashCode(name); + int i = index(h); + + remove0(h, i, name); + for (Object v: values) { + if (v == null) { + break; + } + V converted = fromObject(name, v); + validateValue(valueValidator, name, converted); + add0(h, i, name, converted); + } + + return thisT(); + } + + @Override + public T setInt(K name, int value) { + return set(name, fromInt(name, value)); + } + + @Override + public T setLong(K name, long value) { + return set(name, fromLong(name, value)); + } + + @Override + public T setDouble(K name, double value) { + return set(name, fromDouble(name, value)); + } + + @Override + public T setTimeMillis(K name, long value) { + return set(name, fromTimeMillis(name, value)); + } + + @Override + public T setFloat(K name, float value) { + return set(name, fromFloat(name, value)); + } + + @Override + public T setChar(K name, char value) { + return set(name, fromChar(name, value)); + } + + @Override + public T setBoolean(K name, boolean value) { + return set(name, fromBoolean(name, value)); + } + + @Override + public T setByte(K name, byte value) { + return set(name, fromByte(name, value)); + } + + @Override + public T setShort(K name, short value) { + return set(name, fromShort(name, value)); + } + + @Override + public T set(Headers headers) { + if (headers != this) { + clear(); + addImpl(headers); + } + return thisT(); + } + + @Override + public T setAll(Headers headers) { + if (headers != this) { + for (K key : headers.names()) { + remove(key); + } + addImpl(headers); + } + return thisT(); + } + + @Override + public boolean remove(K name) { + return getAndRemove(name) != null; + } + + @Override + public T clear() { + Arrays.fill(entries, null); + head.before = head.after = head; + size = 0; + return thisT(); + } + + @Override + public Iterator> iterator() { + return new HeaderIterator(); + } + + @Override + public Boolean getBoolean(K name) { + V v = get(name); + try { + return v != 
null ? toBoolean(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public boolean getBoolean(K name, boolean defaultValue) { + Boolean v = getBoolean(name); + return v != null ? v : defaultValue; + } + + @Override + public Byte getByte(K name) { + V v = get(name); + try { + return v != null ? toByte(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public byte getByte(K name, byte defaultValue) { + Byte v = getByte(name); + return v != null ? v : defaultValue; + } + + @Override + public Character getChar(K name) { + V v = get(name); + try { + return v != null ? toChar(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public char getChar(K name, char defaultValue) { + Character v = getChar(name); + return v != null ? v : defaultValue; + } + + @Override + public Short getShort(K name) { + V v = get(name); + try { + return v != null ? toShort(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public short getShort(K name, short defaultValue) { + Short v = getShort(name); + return v != null ? v : defaultValue; + } + + @Override + public Integer getInt(K name) { + V v = get(name); + try { + return v != null ? toInt(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public int getInt(K name, int defaultValue) { + Integer v = getInt(name); + return v != null ? v : defaultValue; + } + + @Override + public Long getLong(K name) { + V v = get(name); + try { + return v != null ? toLong(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public long getLong(K name, long defaultValue) { + Long v = getLong(name); + return v != null ? v : defaultValue; + } + + @Override + public Float getFloat(K name) { + V v = get(name); + try { + return v != null ? 
toFloat(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public float getFloat(K name, float defaultValue) { + Float v = getFloat(name); + return v != null ? v : defaultValue; + } + + @Override + public Double getDouble(K name) { + V v = get(name); + try { + return v != null ? toDouble(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public double getDouble(K name, double defaultValue) { + Double v = getDouble(name); + return v != null ? v : defaultValue; + } + + @Override + public Long getTimeMillis(K name) { + V v = get(name); + try { + return v != null ? toTimeMillis(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public long getTimeMillis(K name, long defaultValue) { + Long v = getTimeMillis(name); + return v != null ? v : defaultValue; + } + + @Override + public Boolean getBooleanAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toBoolean(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public boolean getBooleanAndRemove(K name, boolean defaultValue) { + Boolean v = getBooleanAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Byte getByteAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toByte(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public byte getByteAndRemove(K name, byte defaultValue) { + Byte v = getByteAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Character getCharAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toChar(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public char getCharAndRemove(K name, char defaultValue) { + Character v = getCharAndRemove(name); + return v != null ? 
v : defaultValue; + } + + @Override + public Short getShortAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toShort(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public short getShortAndRemove(K name, short defaultValue) { + Short v = getShortAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Integer getIntAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toInt(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public int getIntAndRemove(K name, int defaultValue) { + Integer v = getIntAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Long getLongAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toLong(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public long getLongAndRemove(K name, long defaultValue) { + Long v = getLongAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Float getFloatAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toFloat(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public float getFloatAndRemove(K name, float defaultValue) { + Float v = getFloatAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Double getDoubleAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? toDouble(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public double getDoubleAndRemove(K name, double defaultValue) { + Double v = getDoubleAndRemove(name); + return v != null ? v : defaultValue; + } + + @Override + public Long getTimeMillisAndRemove(K name) { + V v = getAndRemove(name); + try { + return v != null ? 
toTimeMillis(name, v) : null; + } catch (RuntimeException ignore) { + return null; + } + } + + @Override + public long getTimeMillisAndRemove(K name, long defaultValue) { + Long v = getTimeMillisAndRemove(name); + return v != null ? v : defaultValue; + } + + @SuppressWarnings("unchecked") + @Override + public boolean equals(Object o) { + if (!(o instanceof Headers)) { + return false; + } + + return equals((Headers) o, JAVA_HASHER); + } + + @SuppressWarnings("unchecked") + @Override + public int hashCode() { + return hashCode(JAVA_HASHER); + } + + /** + * Test this object for equality against {@code h2}. + * @param h2 The object to check equality for. + * @param valueHashingStrategy Defines how values will be compared for equality. + * @return {@code true} if this object equals {@code h2} given {@code valueHashingStrategy}. + * {@code false} otherwise. + */ + public final boolean equals(Headers h2, HashingStrategy valueHashingStrategy) { + if (h2.size() != size()) { + return false; + } + + if (this == h2) { + return true; + } + + for (K name : names()) { + List otherValues = h2.getAll(name); + List values = getAll(name); + if (otherValues.size() != values.size()) { + return false; + } + for (int i = 0; i < otherValues.size(); i++) { + if (!valueHashingStrategy.equals(otherValues.get(i), values.get(i))) { + return false; + } + } + } + return true; + } + + /** + * Generate a hash code for this object given a {@link HashingStrategy} to generate hash codes for + * individual values. + * @param valueHashingStrategy Defines how values will be hashed. 
+ */ + public final int hashCode(HashingStrategy valueHashingStrategy) { + int result = HASH_CODE_SEED; + for (K name : names()) { + result = 31 * result + hashingStrategy.hashCode(name); + List values = getAll(name); + for (int i = 0; i < values.size(); ++i) { + result = 31 * result + valueHashingStrategy.hashCode(values.get(i)); + } + } + return result; + } + + @Override + public String toString() { + return HeadersUtils.toString(getClass(), iterator(), size()); + } + + /** + * Call out to the given {@link NameValidator} to validate the given name. + * + * @param validator the validator to use + * @param forAdd {@code true } if this validation is for adding to the headers, or {@code false} if this is for + * setting (overwriting) the given header. + * @param name the name to validate. + */ + protected void validateName(NameValidator validator, boolean forAdd, K name) { + validator.validateName(name); + } + + protected void validateValue(ValueValidator validator, K name, V value) { + try { + validator.validate(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Validation failed for header '" + name + "'", e); + } + } + + protected HeaderEntry newHeaderEntry(int h, K name, V value, HeaderEntry next) { + return new HeaderEntry(h, name, value, next, head); + } + + protected ValueConverter valueConverter() { + return valueConverter; + } + + protected NameValidator nameValidator() { + return nameValidator; + } + + protected ValueValidator valueValidator() { + return valueValidator; + } + + private int index(int hash) { + return hash & hashMask; + } + + private void add0(int h, int i, K name, V value) { + // Update the hash table. + entries[i] = newHeaderEntry(h, name, value, entries[i]); + ++size; + } + + /** + * @return the first value inserted whose hash code equals {@code h} and whose name is equal to {@code name}. 
+ */ + private V remove0(int h, int i, K name) { + HeaderEntry e = entries[i]; + if (e == null) { + return null; + } + + V value = null; + HeaderEntry next = e.next; + while (next != null) { + if (next.hash == h && hashingStrategy.equals(name, next.key)) { + value = next.value; + e.next = next.next; + next.remove(); + --size; + } else { + e = next; + } + + next = e.next; + } + + e = entries[i]; + if (e.hash == h && hashingStrategy.equals(name, e.key)) { + if (value == null) { + value = e.value; + } + entries[i] = e.next; + e.remove(); + --size; + } + + return value; + } + + HeaderEntry remove0(HeaderEntry entry, HeaderEntry previous) { + int i = index(entry.hash); + HeaderEntry firstEntry = entries[i]; + if (firstEntry == entry) { + entries[i] = entry.next; + previous = entries[i]; + } else if (previous == null) { + // If we don't have any existing starting point, then start from the beginning. + previous = firstEntry; + HeaderEntry next = firstEntry.next; + while (next != null && next != entry) { + previous = next; + next = next.next; + } + assert next != null: "Entry not found in its hash bucket: " + entry; + previous.next = entry.next; + } else { + previous.next = entry.next; + } + entry.remove(); + --size; + return previous; + } + + @SuppressWarnings("unchecked") + private T thisT() { + return (T) this; + } + + private V fromObject(K name, Object value) { + try { + return valueConverter.convertObject(checkNotNull(value, "value")); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert object value for header '" + name + '\'', e); + } + } + + private V fromBoolean(K name, boolean value) { + try { + return valueConverter.convertBoolean(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert boolean value for header '" + name + '\'', e); + } + } + + private V fromByte(K name, byte value) { + try { + return valueConverter.convertByte(value); + } catch 
(IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert byte value for header '" + name + '\'', e); + } + } + + private V fromChar(K name, char value) { + try { + return valueConverter.convertChar(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert char value for header '" + name + '\'', e); + } + } + + private V fromShort(K name, short value) { + try { + return valueConverter.convertShort(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert short value for header '" + name + '\'', e); + } + } + + private V fromInt(K name, int value) { + try { + return valueConverter.convertInt(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert int value for header '" + name + '\'', e); + } + } + + private V fromLong(K name, long value) { + try { + return valueConverter.convertLong(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert long value for header '" + name + '\'', e); + } + } + + private V fromFloat(K name, float value) { + try { + return valueConverter.convertFloat(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert float value for header '" + name + '\'', e); + } + } + + private V fromDouble(K name, double value) { + try { + return valueConverter.convertDouble(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert double value for header '" + name + '\'', e); + } + } + + private V fromTimeMillis(K name, long value) { + try { + return valueConverter.convertTimeMillis(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert millsecond value for header '" + name + '\'', e); + } + } + + private boolean toBoolean(K name, V value) { + try { + return valueConverter.convertToBoolean(value); + 
} catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to boolean for header '" + name + '\''); + } + } + + private byte toByte(K name, V value) { + try { + return valueConverter.convertToByte(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to byte for header '" + name + '\''); + } + } + + private char toChar(K name, V value) { + try { + return valueConverter.convertToChar(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to char for header '" + name + '\''); + } + } + + private short toShort(K name, V value) { + try { + return valueConverter.convertToShort(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to short for header '" + name + '\''); + } + } + + private int toInt(K name, V value) { + try { + return valueConverter.convertToInt(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to int for header '" + name + '\''); + } + } + + private long toLong(K name, V value) { + try { + return valueConverter.convertToLong(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to long for header '" + name + '\''); + } + } + + private float toFloat(K name, V value) { + try { + return valueConverter.convertToFloat(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to float for header '" + name + '\''); + } + } + + private double toDouble(K name, V value) { + try { + return valueConverter.convertToDouble(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException("Failed to convert header value to double for header '" + name + '\''); + } + } + + private long toTimeMillis(K name, V value) { + try { + 
return valueConverter.convertToTimeMillis(value); + } catch (IllegalArgumentException e) { + throw new IllegalArgumentException( + "Failed to convert header value to millsecond for header '" + name + '\''); + } + } + + /** + * Returns a deep copy of this instance. + */ + public DefaultHeaders copy() { + DefaultHeaders copy = new DefaultHeaders( + hashingStrategy, valueConverter, nameValidator, entries.length); + copy.addImpl(this); + return copy; + } + + private final class HeaderIterator implements Iterator> { + private HeaderEntry current = head; + + @Override + public boolean hasNext() { + return current.after != head; + } + + @Override + public Entry next() { + current = current.after; + + if (current == head) { + throw new NoSuchElementException(); + } + + return current; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("read only"); + } + } + + private final class ValueIterator implements Iterator { + private final K name; + private final int hash; + private HeaderEntry removalPrevious; + private HeaderEntry previous; + private HeaderEntry next; + + ValueIterator(K name) { + this.name = checkNotNull(name, "name"); + hash = hashingStrategy.hashCode(name); + calculateNext(entries[index(hash)]); + } + + @Override + public boolean hasNext() { + return next != null; + } + + @Override + public V next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + if (previous != null) { + removalPrevious = previous; + } + previous = next; + calculateNext(next.next); + return previous.value; + } + + @Override + public void remove() { + if (previous == null) { + throw new IllegalStateException(); + } + removalPrevious = remove0(previous, removalPrevious); + previous = null; + } + + private void calculateNext(HeaderEntry entry) { + while (entry != null) { + if (entry.hash == hash && hashingStrategy.equals(name, entry.key)) { + next = entry; + return; + } + entry = entry.next; + } + next = null; + } + } + + protected static class 
/**
 * A single header entry. Participates in two linked lists at once: a per-bucket
 * singly-linked list (via {@code next}) and the overall insertion-order circular
 * doubly-linked list (via {@code before}/{@code after}).
 */
protected static class HeaderEntry implements Entry {
    protected final int hash;
    protected final K key;
    protected V value;
    /**
     * In bucket linked list
     */
    protected HeaderEntry next;
    /**
     * Overall insertion order linked list
     */
    protected HeaderEntry before, after;

    protected HeaderEntry(int hash, K key) {
        this.hash = hash;
        this.key = key;
    }

    HeaderEntry(int hash, K key, V value, HeaderEntry next, HeaderEntry head) {
        this.hash = hash;
        this.key = key;
        this.value = value;
        this.next = next;

        // Link in just before 'head', i.e. at the tail of the insertion-order list.
        after = head;
        before = head.before;
        pointNeighborsToThis();
    }

    // Constructs the sentinel head node: hash -1, null key, linked to itself.
    HeaderEntry() {
        hash = -1;
        key = null;
        before = after = this;
    }

    protected final void pointNeighborsToThis() {
        before.after = this;
        after.before = this;
    }

    public final HeaderEntry before() {
        return before;
    }

    public final HeaderEntry after() {
        return after;
    }

    // Unlinks this entry from the insertion-order list only; bucket unlinking
    // is handled by the enclosing class (remove0).
    protected void remove() {
        before.after = after;
        after.before = before;
    }

    @Override
    public final K getKey() {
        return key;
    }

    @Override
    public final V getValue() {
        return value;
    }

    @Override
    public final V setValue(V value) {
        checkNotNull(value, "value");
        V oldValue = this.value;
        this.value = value;
        return oldValue;
    }

    @Override
    public final String toString() {
        return key.toString() + '=' + value.toString();
    }

    @Override
    public boolean equals(Object o) {
        // Follows the Map.Entry equality contract: key-equal AND value-equal,
        // with null-safe comparisons on both sides.
        if (!(o instanceof Map.Entry)) {
            return false;
        }
        Entry other = (Entry) o;
        return (getKey() == null ? other.getKey() == null : getKey().equals(other.getKey())) &&
                (getValue() == null ? other.getValue() == null : getValue().equals(other.getValue()));
    }

    @Override
    public int hashCode() {
        // XOR of key/value hashes, per the Map.Entry hashCode contract.
        return (key == null ? 0 : key.hashCode()) ^ (value == null ? 0 : value.hashCode());
    }
}
}
+ */ +public final class DefaultHeadersImpl extends DefaultHeaders> { + public DefaultHeadersImpl(HashingStrategy nameHashingStrategy, + ValueConverter valueConverter, NameValidator nameValidator) { + super(nameHashingStrategy, valueConverter, nameValidator); + } + + public DefaultHeadersImpl(HashingStrategy nameHashingStrategy, ValueConverter valueConverter, + NameValidator nameValidator, int arraySizeHint, ValueValidator valueValidator) { + super(nameHashingStrategy, valueConverter, nameValidator, arraySizeHint, valueValidator); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java new file mode 100644 index 0000000..6819736 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/DelimiterBasedFrameDecoder.java @@ -0,0 +1,332 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.ObjectUtil; + +import java.util.List; + +/** + * A decoder that splits the received {@link ByteBuf}s by one or more + * delimiters. 
It is particularly useful for decoding the frames which ends + * with a delimiter such as {@link Delimiters#nulDelimiter() NUL} or + * {@linkplain Delimiters#lineDelimiter() newline characters}. + * + *

Predefined delimiters

+ *

+ * {@link Delimiters} defines frequently used delimiters for convenience' sake. + * + *

Specifying more than one delimiter

+ *

+ * {@link DelimiterBasedFrameDecoder} allows you to specify more than one + * delimiter. If more than one delimiter is found in the buffer, it chooses + * the delimiter which produces the shortest frame. For example, if you have + * the following data in the buffer: + *

+ * +--------------+
+ * | ABC\nDEF\r\n |
+ * +--------------+
+ * 
+ * a {@link DelimiterBasedFrameDecoder}({@link Delimiters#lineDelimiter() Delimiters.lineDelimiter()}) + * will choose {@code '\n'} as the first delimiter and produce two frames: + *
+ * +-----+-----+
+ * | ABC | DEF |
+ * +-----+-----+
+ * 
+ * rather than incorrectly choosing {@code '\r\n'} as the first delimiter: + *
+ * +----------+
+ * | ABC\nDEF |
+ * +----------+
+ * 
 */
public class DelimiterBasedFrameDecoder extends ByteToMessageDecoder {

    private final ByteBuf[] delimiters;       // null when delegating to lineBasedDecoder
    private final int maxFrameLength;
    private final boolean stripDelimiter;
    private final boolean failFast;
    private boolean discardingTooLongFrame;   // true while skipping an over-long frame
    private int tooLongFrameLength;           // bytes discarded so far for the current over-long frame
    /** Set only when decoding with "\n" and "\r\n" as the delimiter. */
    private final LineBasedFrameDecoder lineBasedDecoder;

    /**
     * Creates a new instance.
     *
     * @param maxFrameLength  the maximum length of the decoded frame.
     *                        A {@link TooLongFrameException} is thrown if
     *                        the length of the frame exceeds this value.
     * @param delimiter  the delimiter
     */
    public DelimiterBasedFrameDecoder(int maxFrameLength, ByteBuf delimiter) {
        this(maxFrameLength, true, delimiter);
    }

    /**
     * Creates a new instance.
     *
     * @param maxFrameLength  the maximum length of the decoded frame.
     *                        A {@link TooLongFrameException} is thrown if
     *                        the length of the frame exceeds this value.
     * @param stripDelimiter  whether the decoded frame should strip out the
     *                        delimiter or not
     * @param delimiter  the delimiter
     */
    public DelimiterBasedFrameDecoder(
            int maxFrameLength, boolean stripDelimiter, ByteBuf delimiter) {
        this(maxFrameLength, stripDelimiter, true, delimiter);
    }

    /**
     * Creates a new instance.
     *
     * @param maxFrameLength  the maximum length of the decoded frame.
     *                        A {@link TooLongFrameException} is thrown if
     *                        the length of the frame exceeds this value.
     * @param stripDelimiter  whether the decoded frame should strip out the
     *                        delimiter or not
     * @param failFast  If true, a {@link TooLongFrameException} is
     *                  thrown as soon as the decoder notices the length of the
     *                  frame will exceed maxFrameLength regardless of
     *                  whether the entire frame has been read.
     *                  If false, a {@link TooLongFrameException} is
     *                  thrown after the entire frame that exceeds
     *                  maxFrameLength has been read.
     * @param delimiter  the delimiter
     */
    public DelimiterBasedFrameDecoder(
            int maxFrameLength, boolean stripDelimiter, boolean failFast,
            ByteBuf delimiter) {
        this(maxFrameLength, stripDelimiter, failFast, new ByteBuf[] {
                delimiter.slice(delimiter.readerIndex(), delimiter.readableBytes())});
    }

    /**
     * Creates a new instance.
     *
     * @param maxFrameLength  the maximum length of the decoded frame.
     *                        A {@link TooLongFrameException} is thrown if
     *                        the length of the frame exceeds this value.
     * @param delimiters  the delimiters
     */
    public DelimiterBasedFrameDecoder(int maxFrameLength, ByteBuf... delimiters) {
        this(maxFrameLength, true, delimiters);
    }

    /**
     * Creates a new instance.
     *
     * @param maxFrameLength  the maximum length of the decoded frame.
     *                        A {@link TooLongFrameException} is thrown if
     *                        the length of the frame exceeds this value.
     * @param stripDelimiter  whether the decoded frame should strip out the
     *                        delimiter or not
     * @param delimiters  the delimiters
     */
    public DelimiterBasedFrameDecoder(
            int maxFrameLength, boolean stripDelimiter, ByteBuf... delimiters) {
        this(maxFrameLength, stripDelimiter, true, delimiters);
    }

    /**
     * Creates a new instance.
     *
     * @param maxFrameLength  the maximum length of the decoded frame.
     *                        A {@link TooLongFrameException} is thrown if
     *                        the length of the frame exceeds this value.
     * @param stripDelimiter  whether the decoded frame should strip out the
     *                        delimiter or not
     * @param failFast  If true, a {@link TooLongFrameException} is
     *                  thrown as soon as the decoder notices the length of the
     *                  frame will exceed maxFrameLength regardless of
     *                  whether the entire frame has been read.
     *                  If false, a {@link TooLongFrameException} is
     *                  thrown after the entire frame that exceeds
     *                  maxFrameLength has been read.
     * @param delimiters  the delimiters
     */
    public DelimiterBasedFrameDecoder(
            int maxFrameLength, boolean stripDelimiter, boolean failFast, ByteBuf... delimiters) {
        validateMaxFrameLength(maxFrameLength);
        ObjectUtil.checkNonEmpty(delimiters, "delimiters");

        if (isLineBased(delimiters) && !isSubclass()) {
            // Fast path: delegate to the specialized line decoder (only safe when
            // no subclass may have customized behavior).
            lineBasedDecoder = new LineBasedFrameDecoder(maxFrameLength, stripDelimiter, failFast);
            this.delimiters = null;
        } else {
            this.delimiters = new ByteBuf[delimiters.length];
            for (int i = 0; i < delimiters.length; i ++) {
                ByteBuf d = delimiters[i];
                validateDelimiter(d);
                // Keep a slice with fixed indices so later index changes on the
                // caller's buffer do not affect matching.
                this.delimiters[i] = d.slice(d.readerIndex(), d.readableBytes());
            }
            lineBasedDecoder = null;
        }
        this.maxFrameLength = maxFrameLength;
        this.stripDelimiter = stripDelimiter;
        this.failFast = failFast;
    }

    /** Returns true if the delimiters are "\n" and "\r\n". */
    private static boolean isLineBased(final ByteBuf[] delimiters) {
        if (delimiters.length != 2) {
            return false;
        }
        ByteBuf a = delimiters[0];
        ByteBuf b = delimiters[1];
        if (a.capacity() < b.capacity()) {
            // Normalize so that 'a' is the longer ("\r\n") candidate.
            a = delimiters[1];
            b = delimiters[0];
        }
        return a.capacity() == 2 && b.capacity() == 1
                && a.getByte(0) == '\r' && a.getByte(1) == '\n'
                && b.getByte(0) == '\n';
    }

    /**
     * Return {@code true} if the current instance is a subclass of DelimiterBasedFrameDecoder
     */
    private boolean isSubclass() {
        return getClass() != DelimiterBasedFrameDecoder.class;
    }

    // Framework entry point: adapts the single-frame decode(ctx, buffer) to the
    // List-producing ByteToMessageDecoder contract.
    @Override
    protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception {
        Object decoded = decode(ctx, in);
        if (decoded != null) {
            out.add(decoded);
        }
    }
 */
protected Object decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
    if (lineBasedDecoder != null) {
        // "\n"/"\r\n" delimiters are handled by the specialized line decoder.
        return lineBasedDecoder.decode(ctx, buffer);
    }
    // Try all delimiters and choose the delimiter which yields the shortest frame.
    int minFrameLength = Integer.MAX_VALUE;
    ByteBuf minDelim = null;
    for (ByteBuf delim: delimiters) {
        int frameLength = indexOf(buffer, delim);
        if (frameLength >= 0 && frameLength < minFrameLength) {
            minFrameLength = frameLength;
            minDelim = delim;
        }
    }

    if (minDelim != null) {
        int minDelimLength = minDelim.capacity();
        ByteBuf frame;

        if (discardingTooLongFrame) {
            // We've just finished discarding a very large frame.
            // Go back to the initial state.
            discardingTooLongFrame = false;
            buffer.skipBytes(minFrameLength + minDelimLength);

            int tooLongFrameLength = this.tooLongFrameLength;
            this.tooLongFrameLength = 0;
            if (!failFast) {
                // With failFast the error was already reported when discarding began.
                fail(tooLongFrameLength);
            }
            return null;
        }

        if (minFrameLength > maxFrameLength) {
            // Discard read frame.
            buffer.skipBytes(minFrameLength + minDelimLength);
            fail(minFrameLength);
            return null;
        }

        if (stripDelimiter) {
            frame = buffer.readRetainedSlice(minFrameLength);
            buffer.skipBytes(minDelimLength);
        } else {
            frame = buffer.readRetainedSlice(minFrameLength + minDelimLength);
        }

        return frame;
    } else {
        if (!discardingTooLongFrame) {
            if (buffer.readableBytes() > maxFrameLength) {
                // Discard the content of the buffer until a delimiter is found.
                tooLongFrameLength = buffer.readableBytes();
                buffer.skipBytes(buffer.readableBytes());
                discardingTooLongFrame = true;
                if (failFast) {
                    fail(tooLongFrameLength);
                }
            }
        } else {
            // Still discarding the buffer since a delimiter is not found.
            tooLongFrameLength += buffer.readableBytes();
            buffer.skipBytes(buffer.readableBytes());
        }
        return null;
    }
}

// Reports an over-long frame; a non-positive length means the frame's total
// length is not yet known (still discarding).
private void fail(long frameLength) {
    if (frameLength > 0) {
        throw new TooLongFrameException(
                "frame length exceeds " + maxFrameLength +
                ": " + frameLength + " - discarded");
    } else {
        throw new TooLongFrameException(
                "frame length exceeds " + maxFrameLength +
                " - discarding");
    }
}

/**
 * Returns the number of bytes between the readerIndex of the haystack and
 * the first needle found in the haystack. -1 is returned if no needle is
 * found in the haystack.
 */
private static int indexOf(ByteBuf haystack, ByteBuf needle) {
    // ByteBufUtil.indexOf returns an absolute index; convert to a relative offset.
    int index = ByteBufUtil.indexOf(needle, haystack);
    if (index == -1) {
        return -1;
    }
    return index - haystack.readerIndex();
}

private static void validateDelimiter(ByteBuf delimiter) {
    ObjectUtil.checkNotNull(delimiter, "delimiter");
    if (!delimiter.isReadable()) {
        throw new IllegalArgumentException("empty delimiter");
    }
}

private static void validateMaxFrameLength(int maxFrameLength) {
    checkPositive(maxFrameLength, "maxFrameLength");
}
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; + +/** + * A set of commonly used delimiters for {@link DelimiterBasedFrameDecoder}. + */ +public final class Delimiters { + + /** + * Returns a {@code NUL (0x00)} delimiter, which could be used for + * Flash XML socket or any similar protocols. + */ + public static ByteBuf[] nulDelimiter() { + return new ByteBuf[] { + Unpooled.wrappedBuffer(new byte[] { 0 }) }; + } + + /** + * Returns {@code CR ('\r')} and {@code LF ('\n')} delimiters, which could + * be used for text-based line protocols. + */ + public static ByteBuf[] lineDelimiter() { + return new ByteBuf[] { + Unpooled.wrappedBuffer(new byte[] { '\r', '\n' }), + Unpooled.wrappedBuffer(new byte[] { '\n' }), + }; + } + + private Delimiters() { + // Unused + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/EmptyHeaders.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/EmptyHeaders.java new file mode 100644 index 0000000..c102497 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/EmptyHeaders.java @@ -0,0 +1,526 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec; + +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Set; + +import static io.netty.handler.codec.DefaultHeaders.HASH_CODE_SEED; + +public class EmptyHeaders> implements Headers { + @Override + public V get(K name) { + return null; + } + + @Override + public V get(K name, V defaultValue) { + return defaultValue; + } + + @Override + public V getAndRemove(K name) { + return null; + } + + @Override + public V getAndRemove(K name, V defaultValue) { + return defaultValue; + } + + @Override + public List getAll(K name) { + return Collections.emptyList(); + } + + @Override + public List getAllAndRemove(K name) { + return Collections.emptyList(); + } + + @Override + public Boolean getBoolean(K name) { + return null; + } + + @Override + public boolean getBoolean(K name, boolean defaultValue) { + return defaultValue; + } + + @Override + public Byte getByte(K name) { + return null; + } + + @Override + public byte getByte(K name, byte defaultValue) { + return defaultValue; + } + + @Override + public Character getChar(K name) { + return null; + } + + @Override + public char getChar(K name, char defaultValue) { + return defaultValue; + } + + @Override + public Short getShort(K name) { + return null; + } + + @Override + public short getShort(K name, short defaultValue) { + return defaultValue; + } + + @Override + public Integer getInt(K name) { + return null; + } + + @Override + public int getInt(K name, int defaultValue) { + return defaultValue; + } 
+ + @Override + public Long getLong(K name) { + return null; + } + + @Override + public long getLong(K name, long defaultValue) { + return defaultValue; + } + + @Override + public Float getFloat(K name) { + return null; + } + + @Override + public float getFloat(K name, float defaultValue) { + return defaultValue; + } + + @Override + public Double getDouble(K name) { + return null; + } + + @Override + public double getDouble(K name, double defaultValue) { + return defaultValue; + } + + @Override + public Long getTimeMillis(K name) { + return null; + } + + @Override + public long getTimeMillis(K name, long defaultValue) { + return defaultValue; + } + + @Override + public Boolean getBooleanAndRemove(K name) { + return null; + } + + @Override + public boolean getBooleanAndRemove(K name, boolean defaultValue) { + return defaultValue; + } + + @Override + public Byte getByteAndRemove(K name) { + return null; + } + + @Override + public byte getByteAndRemove(K name, byte defaultValue) { + return defaultValue; + } + + @Override + public Character getCharAndRemove(K name) { + return null; + } + + @Override + public char getCharAndRemove(K name, char defaultValue) { + return defaultValue; + } + + @Override + public Short getShortAndRemove(K name) { + return null; + } + + @Override + public short getShortAndRemove(K name, short defaultValue) { + return defaultValue; + } + + @Override + public Integer getIntAndRemove(K name) { + return null; + } + + @Override + public int getIntAndRemove(K name, int defaultValue) { + return defaultValue; + } + + @Override + public Long getLongAndRemove(K name) { + return null; + } + + @Override + public long getLongAndRemove(K name, long defaultValue) { + return defaultValue; + } + + @Override + public Float getFloatAndRemove(K name) { + return null; + } + + @Override + public float getFloatAndRemove(K name, float defaultValue) { + return defaultValue; + } + + @Override + public Double getDoubleAndRemove(K name) { + return null; + } + + 
@Override + public double getDoubleAndRemove(K name, double defaultValue) { + return defaultValue; + } + + @Override + public Long getTimeMillisAndRemove(K name) { + return null; + } + + @Override + public long getTimeMillisAndRemove(K name, long defaultValue) { + return defaultValue; + } + + @Override + public boolean contains(K name) { + return false; + } + + @Override + public boolean contains(K name, V value) { + return false; + } + + @Override + public boolean containsObject(K name, Object value) { + return false; + } + + @Override + public boolean containsBoolean(K name, boolean value) { + return false; + } + + @Override + public boolean containsByte(K name, byte value) { + return false; + } + + @Override + public boolean containsChar(K name, char value) { + return false; + } + + @Override + public boolean containsShort(K name, short value) { + return false; + } + + @Override + public boolean containsInt(K name, int value) { + return false; + } + + @Override + public boolean containsLong(K name, long value) { + return false; + } + + @Override + public boolean containsFloat(K name, float value) { + return false; + } + + @Override + public boolean containsDouble(K name, double value) { + return false; + } + + @Override + public boolean containsTimeMillis(K name, long value) { + return false; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public Set names() { + return Collections.emptySet(); + } + + @Override + public T add(K name, V value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T add(K name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T add(K name, V... 
values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addObject(K name, Object value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addObject(K name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addObject(K name, Object... values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addBoolean(K name, boolean value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addByte(K name, byte value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addChar(K name, char value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addShort(K name, short value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addInt(K name, int value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addLong(K name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addFloat(K name, float value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addDouble(K name, double value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T addTimeMillis(K name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T add(Headers headers) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T set(K name, V value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T set(K name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T set(K name, V... 
values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setObject(K name, Object value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setObject(K name, Iterable values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setObject(K name, Object... values) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setBoolean(K name, boolean value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setByte(K name, byte value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setChar(K name, char value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setShort(K name, short value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setInt(K name, int value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setLong(K name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setFloat(K name, float value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setDouble(K name, double value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setTimeMillis(K name, long value) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T set(Headers headers) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public T setAll(Headers headers) { + throw new UnsupportedOperationException("read only"); + } + + @Override + public boolean remove(K name) { + return false; + } + + @Override + public T clear() { + return thisT(); + } + + /** + * Equivalent to {@link #getAll(Object)} but no intermediate list is generated. 
+ * @param name the name of the header to retrieve + * @return an {@link Iterator} of header values corresponding to {@code name}. + */ + public Iterator valueIterator(@SuppressWarnings("unused") K name) { + List empty = Collections.emptyList(); + return empty.iterator(); + } + + @Override + public Iterator> iterator() { + List> empty = Collections.emptyList(); + return empty.iterator(); + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof Headers)) { + return false; + } + + Headers rhs = (Headers) o; + return isEmpty() && rhs.isEmpty(); + } + + @Override + public int hashCode() { + return HASH_CODE_SEED; + } + + @Override + public String toString() { + return new StringBuilder(getClass().getSimpleName()).append('[').append(']').toString(); + } + + @SuppressWarnings("unchecked") + private T thisT() { + return (T) this; + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/EncoderException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/EncoderException.java new file mode 100644 index 0000000..ebf6efb --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/EncoderException.java @@ -0,0 +1,51 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * An {@link CodecException} which is thrown by an encoder. 
+ */ +public class EncoderException extends CodecException { + + private static final long serialVersionUID = -5086121160476476774L; + + /** + * Creates a new instance. + */ + public EncoderException() { + } + + /** + * Creates a new instance. + */ + public EncoderException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public EncoderException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public EncoderException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java new file mode 100644 index 0000000..cf41607 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/FixedLengthFrameDecoder.java @@ -0,0 +1,79 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; + +import java.util.List; + +/** + * A decoder that splits the received {@link ByteBuf}s by the fixed number + * of bytes. For example, if you received the following four fragmented packets: + *
+ * +---+----+------+----+
+ * | A | BC | DEFG | HI |
+ * +---+----+------+----+
+ * 
+ * A {@link FixedLengthFrameDecoder}{@code (3)} will decode them into the + * following three packets with the fixed length: + *
+ * +-----+-----+-----+
+ * | ABC | DEF | GHI |
+ * +-----+-----+-----+
+ * 
+ */ +public class FixedLengthFrameDecoder extends ByteToMessageDecoder { + + private final int frameLength; + + /** + * Creates a new instance. + * + * @param frameLength the length of the frame + */ + public FixedLengthFrameDecoder(int frameLength) { + checkPositive(frameLength, "frameLength"); + this.frameLength = frameLength; + } + + @Override + protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + Object decoded = decode(ctx, in); + if (decoded != null) { + out.add(decoded); + } + } + + /** + * Create a frame out of the {@link ByteBuf} and return it. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to + * @param in the {@link ByteBuf} from which to read data + * @return frame the {@link ByteBuf} which represent the frame or {@code null} if no frame could + * be created. + */ + protected Object decode( + @SuppressWarnings("UnusedParameters") ChannelHandlerContext ctx, ByteBuf in) throws Exception { + if (in.readableBytes() < frameLength) { + return null; + } else { + return in.readRetainedSlice(frameLength); + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/Headers.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/Headers.java new file mode 100644 index 0000000..97301ba --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/Headers.java @@ -0,0 +1,998 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec; + +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Set; + +/** + * Common interface for {@link Headers} which represents a mapping of key to value. + * Duplicate keys may be allowed by implementations. + * + * @param the type of the header name. + * @param the type of the header value. + * @param the type to use for return values when the intention is to return {@code this} object. + */ +public interface Headers> extends Iterable> { + /** + * Returns the value of a header with the specified name. If there is more than one value for the specified name, + * the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the first header value if the header is found. {@code null} if there's no such header + */ + V get(K name); + + /** + * Returns the value of a header with the specified name. If there is more than one value for the specified name, + * the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the first header value or {@code defaultValue} if there is no such header + */ + V get(K name, V defaultValue); + + /** + * Returns the value of a header with the specified name and removes it from this object. If there is more than + * one value for the specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the first header value or {@code null} if there is no such header + */ + V getAndRemove(K name); + + /** + * Returns the value of a header with the specified name and removes it from this object. If there is more than + * one value for the specified name, the first value in insertion order is returned. 
+ * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the first header value or {@code defaultValue} if there is no such header + */ + V getAndRemove(K name, V defaultValue); + + /** + * Returns all values for the header with the specified name. The returned {@link List} can't be modified. + * + * @param name the name of the header to retrieve + * @return a {@link List} of header values or an empty {@link List} if no values are found. + */ + List getAll(K name); + + /** + * Returns all values for the header with the specified name and removes them from this object. + * The returned {@link List} can't be modified. + * + * @param name the name of the header to retrieve + * @return a {@link List} of header values or an empty {@link List} if no values are found. + */ + List getAllAndRemove(K name); + + /** + * Returns the {@code boolean} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the {@code boolean} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code boolean}. + */ + Boolean getBoolean(K name); + + /** + * Returns the {@code boolean} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code boolean} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code boolean}. + */ + boolean getBoolean(K name, boolean defaultValue); + + /** + * Returns the {@code byte} value of a header with the specified name. 
If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the {@code byte} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code byte}. + */ + Byte getByte(K name); + + /** + * Returns the {@code byte} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code byte} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code byte}. + */ + byte getByte(K name, byte defaultValue); + + /** + * Returns the {@code char} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the {@code char} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code char}. + */ + Character getChar(K name); + + /** + * Returns the {@code char} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code char} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code char}. + */ + char getChar(K name, char defaultValue); + + /** + * Returns the {@code short} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. 
+ * + * @param name the name of the header to retrieve + * @return the {@code short} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code short}. + */ + Short getShort(K name); + + /** + * Returns the {@code short} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code short} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code short}. + */ + short getShort(K name, short defaultValue); + + /** + * Returns the {@code int} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the {@code int} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code int}. + */ + Integer getInt(K name); + + /** + * Returns the {@code int} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code int} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code int}. + */ + int getInt(K name, int defaultValue); + + /** + * Returns the {@code long} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. 
+ * + * @param name the name of the header to retrieve + * @return the {@code long} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code long}. + */ + Long getLong(K name); + + /** + * Returns the {@code long} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code long} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code long}. + */ + long getLong(K name, long defaultValue); + + /** + * Returns the {@code float} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the {@code float} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code float}. + */ + Float getFloat(K name); + + /** + * Returns the {@code float} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code float} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code float}. + */ + float getFloat(K name, float defaultValue); + + /** + * Returns the {@code double} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. 
+ * + * @param name the name of the header to retrieve + * @return the {@code double} value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to {@code double}. + */ + Double getDouble(K name); + + /** + * Returns the {@code double} value of a header with the specified name. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the {@code double} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code double}. + */ + double getDouble(K name, double defaultValue); + + /** + * Returns the value of a header with the specified name in milliseconds. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @return the milliseconds value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to milliseconds. + */ + Long getTimeMillis(K name); + + /** + * Returns the value of a header with the specified name in milliseconds. If there is more than one value for the + * specified name, the first value in insertion order is returned. + * + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the milliseconds value of the first value in insertion order or {@code defaultValue} if there is no such + * value or it can't be converted to milliseconds. + */ + long getTimeMillis(K name, long defaultValue); + + /** + * Returns the {@code boolean} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. 
+ * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to retrieve + * @return the {@code boolean} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code boolean}. + */ + Boolean getBooleanAndRemove(K name); + + /** + * Returns the {@code boolean} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code boolean} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code boolean}. + */ + boolean getBooleanAndRemove(K name, boolean defaultValue); + + /** + * Returns the {@code byte} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code byte} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code byte}. + */ + Byte getByteAndRemove(K name); + + /** + * Returns the {@code byte} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code byte} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code byte}. + */ + byte getByteAndRemove(K name, byte defaultValue); + + /** + * Returns the {@code char} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code char} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code char}. + */ + Character getCharAndRemove(K name); + + /** + * Returns the {@code char} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code char} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code char}. + */ + char getCharAndRemove(K name, char defaultValue); + + /** + * Returns the {@code short} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code short} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code short}. + */ + Short getShortAndRemove(K name); + + /** + * Returns the {@code short} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code short} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code short}. + */ + short getShortAndRemove(K name, short defaultValue); + + /** + * Returns the {@code int} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code int} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code int}. + */ + Integer getIntAndRemove(K name); + + /** + * Returns the {@code int} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code int} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code int}. + */ + int getIntAndRemove(K name, int defaultValue); + + /** + * Returns the {@code long} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code long} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code long}. + */ + Long getLongAndRemove(K name); + + /** + * Returns the {@code long} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code long} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code long}. + */ + long getLongAndRemove(K name, long defaultValue); + + /** + * Returns the {@code float} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code float} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code float}. + */ + Float getFloatAndRemove(K name); + + /** + * Returns the {@code float} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code float} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code float}. + */ + float getFloatAndRemove(K name, float defaultValue); + + /** + * Returns the {@code double} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @return the {@code double} value of the first value in insertion order or {@code null} if there is no + * such value or it can't be converted to {@code double}. + */ + Double getDoubleAndRemove(K name); + + /** + * Returns the {@code double} value of a header with the specified {@code name} and removes the header from this + * object. If there is more than one value for the specified name, the first value in insertion order is returned. + * In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to search + * @param defaultValue the default value + * @return the {@code double} value of the first value in insertion order or {@code defaultValue} if there is no + * such value or it can't be converted to {@code double}. + */ + double getDoubleAndRemove(K name, double defaultValue); + + /** + * Returns the value of a header with the specified {@code name} in milliseconds and removes the header from this + * object. If there is more than one value for the specified {@code name}, the first value in insertion order is + * returned. In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to retrieve + * @return the milliseconds value of the first value in insertion order or {@code null} if there is no such + * value or it can't be converted to milliseconds. + */ + Long getTimeMillisAndRemove(K name); + + /** + * Returns the value of a header with the specified {@code name} in milliseconds and removes the header from this + * object. If there is more than one value for the specified {@code name}, the first value in insertion order is + * returned. In any case all values for {@code name} are removed. + *

+ * If an exception occurs during the translation from type {@code T} all entries with {@code name} may still + * be removed. + * @param name the name of the header to retrieve + * @param defaultValue the default value + * @return the milliseconds value of the first value in insertion order or {@code defaultValue} if there is no such + * value or it can't be converted to milliseconds. + */ + long getTimeMillisAndRemove(K name, long defaultValue); + + /** + * Returns {@code true} if a header with the {@code name} exists, {@code false} otherwise. + * + * @param name the header name + */ + boolean contains(K name); + + /** + * Returns {@code true} if a header with the {@code name} and {@code value} exists, {@code false} otherwise. + *

+ * The {@link Object#equals(Object)} method is used to test for equality of {@code value}. + *

+ * @param name the header name + * @param value the header value of the header to find + */ + boolean contains(K name, V value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsObject(K name, Object value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsBoolean(K name, boolean value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsByte(K name, byte value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsChar(K name, char value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsShort(K name, short value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsInt(K name, int value); + + /** + * Returns {@code true} if a header with the name and value exists. 
+ * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsLong(K name, long value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsFloat(K name, float value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsDouble(K name, double value); + + /** + * Returns {@code true} if a header with the name and value exists. + * + * @param name the header name + * @param value the header value + * @return {@code true} if it contains it {@code false} otherwise + */ + boolean containsTimeMillis(K name, long value); + + /** + * Returns the number of headers in this object. + */ + int size(); + + /** + * Returns {@code true} if {@link #size()} equals {@code 0}. + */ + boolean isEmpty(); + + /** + * Returns a {@link Set} of all header names in this object. The returned {@link Set} cannot be modified. + */ + Set names(); + + /** + * Adds a new header with the specified {@code name} and {@code value}. + * + * @param name the name of the header + * @param value the value of the header + * @return {@code this} + */ + T add(K name, V value); + + /** + * Adds new headers with the specified {@code name} and {@code values}. This method is semantically equivalent to + * + *
+     * for (T value : values) {
+     *     headers.add(name, value);
+     * }
+     * 
+ * + * @param name the header name + * @param values the values of the header + * @return {@code this} + */ + T add(K name, Iterable values); + + /** + * Adds new headers with the specified {@code name} and {@code values}. This method is semantically equivalent to + * + *
+     * for (T value : values) {
+     *     headers.add(name, value);
+     * }
+     * 
+ * + * @param name the header name + * @param values the values of the header + * @return {@code this} + */ + T add(K name, V... values); + + /** + * Adds a new header. Before the {@code value} is added, it's converted to type {@code T}. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addObject(K name, Object value); + + /** + * Adds a new header with the specified name and values. This method is equivalent to + * + *
+     * for (Object v : values) {
+     *     headers.addObject(name, v);
+     * }
+     * 
+ * + * @param name the header name + * @param values the value of the header + * @return {@code this} + */ + T addObject(K name, Iterable values); + + /** + * Adds a new header with the specified name and values. This method is equivalent to + * + *
+     * for (Object v : values) {
+     *     headers.addObject(name, v);
+     * }
+     * 
+ * + * @param name the header name + * @param values the value of the header + * @return {@code this} + */ + T addObject(K name, Object... values); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addBoolean(K name, boolean value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addByte(K name, byte value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addChar(K name, char value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addShort(K name, short value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addInt(K name, int value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addLong(K name, long value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addFloat(K name, float value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addDouble(K name, double value); + + /** + * Adds a new header. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T addTimeMillis(K name, long value); + + /** + * Adds all header names and values of {@code headers} to this object. + * + * @throws IllegalArgumentException if {@code headers == this}. + * @return {@code this} + */ + T add(Headers headers); + + /** + * Sets a header with the specified name and value. 
Any existing headers with the same name are overwritten. + * + * @param name the header name + * @param value the value of the header + * @return {@code this} + */ + T set(K name, V value); + + /** + * Sets a new header with the specified name and values. This method is equivalent to + * + *
+     * headers.remove(name);
+     * for (T v : values) {
+     *     headers.add(name, v);
+     * }
+     * 
+ * + * @param name the header name + * @param values the value of the header + * @return {@code this} + */ + T set(K name, Iterable values); + + /** + * Sets a header with the specified name and values. Any existing headers with this name are removed. This method + * is equivalent to: + * + *
+     * headers.remove(name);
+     * for (T v : values) {
+     *     headers.add(name, v);
+     * }
+     * 
+ * + * @param name the header name + * @param values the value of the header + * @return {@code this} + */ + T set(K name, V... values); + + /** + * Sets a new header. Any existing headers with this name are removed. Before the {@code value} is add, it's + * converted to type {@code T}. + * + * @param name the header name + * @param value the value of the header + * @throws NullPointerException if either {@code name} or {@code value} before or after its conversion is + * {@code null}. + * @return {@code this} + */ + T setObject(K name, Object value); + + /** + * Sets a header with the specified name and values. Any existing headers with this name are removed. This method + * is equivalent to: + * + *
+     * headers.remove(name);
+     * for (Object v : values) {
+     *     headers.addObject(name, v);
+     * }
+     * 
+ * + * @param name the header name + * @param values the values of the header + * @return {@code this} + */ + T setObject(K name, Iterable values); + + /** + * Sets a header with the specified name and values. Any existing headers with this name are removed. This method + * is equivalent to: + * + *
+     * headers.remove(name);
+     * for (Object v : values) {
+     *     headers.addObject(name, v);
+     * }
+     * 
+ * + * @param name the header name + * @param values the values of the header + * @return {@code this} + */ + T setObject(K name, Object... values); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setBoolean(K name, boolean value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setByte(K name, byte value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setChar(K name, char value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setShort(K name, short value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setInt(K name, int value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setLong(K name, long value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setFloat(K name, float value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. 
+ * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setDouble(K name, double value); + + /** + * Set the {@code name} to {@code value}. This will remove all previous values associated with {@code name}. + * @param name The name to modify + * @param value The value + * @return {@code this} + */ + T setTimeMillis(K name, long value); + + /** + * Clears the current header entries and copies all header entries of the specified {@code headers}. + * + * @return {@code this} + */ + T set(Headers headers); + + /** + * Retains all current headers but calls {@link #set(K, V)} for each entry in {@code headers}. + * + * @param headers The headers used to {@link #set(K, V)} values in this instance + * @return {@code this} + */ + T setAll(Headers headers); + + /** + * Removes all headers with the specified {@code name}. + * + * @param name the header name + * @return {@code true} if at least one entry has been removed. + */ + boolean remove(K name); + + /** + * Removes all headers. After a call to this method {@link #size()} equals {@code 0}. + * + * @return {@code this} + */ + T clear(); + + @Override + Iterator> iterator(); +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/HeadersUtils.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/HeadersUtils.java new file mode 100644 index 0000000..e8ee5d5 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/HeadersUtils.java @@ -0,0 +1,221 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import java.util.AbstractCollection; +import java.util.AbstractList; +import java.util.Iterator; +import java.util.List; +import java.util.Map.Entry; +import java.util.Set; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Provides utility methods related to {@link Headers}. + */ +public final class HeadersUtils { + + private HeadersUtils() { + } + + /** + * {@link Headers#get(Object)} and convert each element of {@link List} to a {@link String}. + * @param name the name of the header to retrieve + * @return a {@link List} of header values or an empty {@link List} if no values are found. + */ + public static List getAllAsString(Headers headers, K name) { + final List allNames = headers.getAll(name); + return new AbstractList() { + @Override + public String get(int index) { + V value = allNames.get(index); + return value != null ? value.toString() : null; + } + + @Override + public int size() { + return allNames.size(); + } + }; + } + + /** + * {@link Headers#get(Object)} and convert the result to a {@link String}. + * @param headers the headers to get the {@code name} from + * @param name the name of the header to retrieve + * @return the first header value if the header is found. {@code null} if there's no such entry. + */ + public static String getAsString(Headers headers, K name) { + V orig = headers.get(name); + return orig != null ? orig.toString() : null; + } + + /** + * {@link Headers#iterator()} which converts each {@link Entry}'s key and value to a {@link String}. 
+ */ + public static Iterator> iteratorAsString( + Iterable> headers) { + return new StringEntryIterator(headers.iterator()); + } + + /** + * Helper for implementing toString for {@link DefaultHeaders} and wrappers such as DefaultHttpHeaders. + * @param headersClass the class of headers + * @param headersIt the iterator on the actual headers + * @param size the size of the iterator + * @return a String representation of the headers + */ + public static String toString(Class headersClass, Iterator> headersIt, int size) { + String simpleName = headersClass.getSimpleName(); + if (size == 0) { + return simpleName + "[]"; + } else { + // original capacity assumes 20 chars per headers + StringBuilder sb = new StringBuilder(simpleName.length() + 2 + size * 20) + .append(simpleName) + .append('['); + while (headersIt.hasNext()) { + Entry header = headersIt.next(); + sb.append(header.getKey()).append(": ").append(header.getValue()).append(", "); + } + sb.setLength(sb.length() - 2); + return sb.append(']').toString(); + } + } + + /** + * {@link Headers#names()} and convert each element of {@link Set} to a {@link String}. + * @param headers the headers to get the names from + * @return a {@link Set} of header values or an empty {@link Set} if no values are found. 
+ */ + public static Set namesAsString(Headers headers) { + return new DelegatingNameSet(headers); + } + + private static final class StringEntryIterator implements Iterator> { + private final Iterator> iter; + + StringEntryIterator(Iterator> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Entry next() { + return new StringEntry(iter.next()); + } + + @Override + public void remove() { + iter.remove(); + } + } + + private static final class StringEntry implements Entry { + private final Entry entry; + private String name; + private String value; + + StringEntry(Entry entry) { + this.entry = entry; + } + + @Override + public String getKey() { + if (name == null) { + name = entry.getKey().toString(); + } + return name; + } + + @Override + public String getValue() { + if (value == null && entry.getValue() != null) { + value = entry.getValue().toString(); + } + return value; + } + + @Override + public String setValue(String value) { + String old = getValue(); + entry.setValue(value); + return old; + } + + @Override + public String toString() { + return entry.toString(); + } + } + + private static final class StringIterator implements Iterator { + private final Iterator iter; + + StringIterator(Iterator iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public String next() { + T next = iter.next(); + return next != null ? 
next.toString() : null; + } + + @Override + public void remove() { + iter.remove(); + } + } + + private static final class DelegatingNameSet extends AbstractCollection implements Set { + private final Headers headers; + + DelegatingNameSet(Headers headers) { + this.headers = checkNotNull(headers, "headers"); + } + + @Override + public int size() { + return headers.names().size(); + } + + @Override + public boolean isEmpty() { + return headers.isEmpty(); + } + + @Override + public boolean contains(Object o) { + return headers.contains(o.toString()); + } + + @Override + public Iterator iterator() { + return new StringIterator(headers.names().iterator()); + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java new file mode 100644 index 0000000..2ae4cae --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/LengthFieldBasedFrameDecoder.java @@ -0,0 +1,516 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import java.nio.ByteOrder; +import java.util.List; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; + +/** + * A decoder that splits the received {@link ByteBuf}s dynamically by the + * value of the length field in the message. It is particularly useful when you + * decode a binary message which has an integer header field that represents the + * length of the message body or the whole message. + *

+ * {@link LengthFieldBasedFrameDecoder} has many configuration parameters so + * that it can decode any message with a length field, which is often seen in + * proprietary client-server protocols. Here are some examples that will give + * you the basic idea on which option does what. + * + *

2 bytes length field at offset 0, do not strip header

+ * + * The value of the length field in this example is 12 (0x0C) which + * represents the length of "HELLO, WORLD". By default, the decoder assumes + * that the length field represents the number of the bytes that follows the + * length field. Therefore, it can be decoded with the simplistic parameter + * combination. + *
+ * lengthFieldOffset   = 0
+ * lengthFieldLength   = 2
+ * lengthAdjustment    = 0
+ * initialBytesToStrip = 0 (= do not strip header)
+ *
+ * BEFORE DECODE (14 bytes)         AFTER DECODE (14 bytes)
+ * +--------+----------------+      +--------+----------------+
+ * | Length | Actual Content |----->| Length | Actual Content |
+ * | 0x000C | "HELLO, WORLD" |      | 0x000C | "HELLO, WORLD" |
+ * +--------+----------------+      +--------+----------------+
+ * 
+ * + *

2 bytes length field at offset 0, strip header

+ * + * Because we can get the length of the content by calling + * {@link ByteBuf#readableBytes()}, you might want to strip the length + * field by specifying initialBytesToStrip. In this example, we + * specified 2, that is same with the length of the length field, to + * strip the first two bytes. + *
+ * lengthFieldOffset   = 0
+ * lengthFieldLength   = 2
+ * lengthAdjustment    = 0
+ * initialBytesToStrip = 2 (= the length of the Length field)
+ *
+ * BEFORE DECODE (14 bytes)         AFTER DECODE (12 bytes)
+ * +--------+----------------+      +----------------+
+ * | Length | Actual Content |----->| Actual Content |
+ * | 0x000C | "HELLO, WORLD" |      | "HELLO, WORLD" |
+ * +--------+----------------+      +----------------+
+ * 
+ * + *

2 bytes length field at offset 0, do not strip header, the length field + * represents the length of the whole message

+ * + * In most cases, the length field represents the length of the message body + * only, as shown in the previous examples. However, in some protocols, the + * length field represents the length of the whole message, including the + * message header. In such a case, we specify a non-zero + * lengthAdjustment. Because the length value in this example message + * is always greater than the body length by 2, we specify -2 + * as lengthAdjustment for compensation. + *
+ * lengthFieldOffset   =  0
+ * lengthFieldLength   =  2
+ * lengthAdjustment    = -2 (= the length of the Length field)
+ * initialBytesToStrip =  0
+ *
+ * BEFORE DECODE (14 bytes)         AFTER DECODE (14 bytes)
+ * +--------+----------------+      +--------+----------------+
+ * | Length | Actual Content |----->| Length | Actual Content |
+ * | 0x000E | "HELLO, WORLD" |      | 0x000E | "HELLO, WORLD" |
+ * +--------+----------------+      +--------+----------------+
+ * 
+ * + *

3 bytes length field at the end of 5 bytes header, do not strip header

+ * + * The following message is a simple variation of the first example. An extra + * header value is prepended to the message. lengthAdjustment is zero + * again because the decoder always takes the length of the prepended data into + * account during frame length calculation. + *
+ * lengthFieldOffset   = 2 (= the length of Header 1)
+ * lengthFieldLength   = 3
+ * lengthAdjustment    = 0
+ * initialBytesToStrip = 0
+ *
+ * BEFORE DECODE (17 bytes)                      AFTER DECODE (17 bytes)
+ * +----------+----------+----------------+      +----------+----------+----------------+
+ * | Header 1 |  Length  | Actual Content |----->| Header 1 |  Length  | Actual Content |
+ * |  0xCAFE  | 0x00000C | "HELLO, WORLD" |      |  0xCAFE  | 0x00000C | "HELLO, WORLD" |
+ * +----------+----------+----------------+      +----------+----------+----------------+
+ * 
+ * + *

3 bytes length field at the beginning of 5 bytes header, do not strip header

+ * + * This is an advanced example that shows the case where there is an extra + * header between the length field and the message body. You have to specify a + * positive lengthAdjustment so that the decoder counts the extra + * header into the frame length calculation. + *
+ * lengthFieldOffset   = 0
+ * lengthFieldLength   = 3
+ * lengthAdjustment    = 2 (= the length of Header 1)
+ * initialBytesToStrip = 0
+ *
+ * BEFORE DECODE (17 bytes)                      AFTER DECODE (17 bytes)
+ * +----------+----------+----------------+      +----------+----------+----------------+
+ * |  Length  | Header 1 | Actual Content |----->|  Length  | Header 1 | Actual Content |
+ * | 0x00000C |  0xCAFE  | "HELLO, WORLD" |      | 0x00000C |  0xCAFE  | "HELLO, WORLD" |
+ * +----------+----------+----------------+      +----------+----------+----------------+
+ * 
+ * + *

2 bytes length field at offset 1 in the middle of 4 bytes header, + * strip the first header field and the length field

+ * + * This is a combination of all the examples above. There are the prepended + * header before the length field and the extra header after the length field. + * The prepended header affects the lengthFieldOffset and the extra + * header affects the lengthAdjustment. We also specified a non-zero + * initialBytesToStrip to strip the length field and the prepended + * header from the frame. If you don't want to strip the prepended header, you + * could specify 0 for initialBytesToStrip. + *
+ * lengthFieldOffset   = 1 (= the length of HDR1)
+ * lengthFieldLength   = 2
+ * lengthAdjustment    = 1 (= the length of HDR2)
+ * initialBytesToStrip = 3 (= the length of HDR1 + LEN)
+ *
+ * BEFORE DECODE (16 bytes)                       AFTER DECODE (13 bytes)
+ * +------+--------+------+----------------+      +------+----------------+
+ * | HDR1 | Length | HDR2 | Actual Content |----->| HDR2 | Actual Content |
+ * | 0xCA | 0x000C | 0xFE | "HELLO, WORLD" |      | 0xFE | "HELLO, WORLD" |
+ * +------+--------+------+----------------+      +------+----------------+
+ * 
+ * + *

2 bytes length field at offset 1 in the middle of 4 bytes header, + * strip the first header field and the length field, the length field + * represents the length of the whole message

+ * + * Let's give another twist to the previous example. The only difference from + * the previous example is that the length field represents the length of the + * whole message instead of the message body, just like the third example. + * We have to count the length of HDR1 and Length into lengthAdjustment. + * Please note that we don't need to take the length of HDR2 into account + * because the length field already includes the whole header length. + *
+ * lengthFieldOffset   =  1
+ * lengthFieldLength   =  2
+ * lengthAdjustment    = -3 (= the length of HDR1 + LEN, negative)
+ * initialBytesToStrip =  3
+ *
+ * BEFORE DECODE (16 bytes)                       AFTER DECODE (13 bytes)
+ * +------+--------+------+----------------+      +------+----------------+
+ * | HDR1 | Length | HDR2 | Actual Content |----->| HDR2 | Actual Content |
+ * | 0xCA | 0x0010 | 0xFE | "HELLO, WORLD" |      | 0xFE | "HELLO, WORLD" |
+ * +------+--------+------+----------------+      +------+----------------+
+ * 
/**
 * A decoder that splits the received {@link ByteBuf}s dynamically by the
 * value of a length field embedded in the message.
 *
 * @see LengthFieldPrepender
 */
public class LengthFieldBasedFrameDecoder extends ByteToMessageDecoder {

    /** Byte order of the length field. */
    private final ByteOrder byteOrder;
    /** Frames whose adjusted length exceeds this trigger {@link TooLongFrameException}. */
    private final int maxFrameLength;
    /** Offset of the length field from the start of the frame. */
    private final int lengthFieldOffset;
    /** Width of the length field in bytes: 1, 2, 3, 4, or 8. */
    private final int lengthFieldLength;
    /** Cached {@code lengthFieldOffset + lengthFieldLength}. */
    private final int lengthFieldEndOffset;
    /** Compensation value added to the raw length-field value. */
    private final int lengthAdjustment;
    /** Number of leading bytes stripped from every decoded frame. */
    private final int initialBytesToStrip;
    /** If true, throw as soon as an oversized frame is detected instead of after it is fully read. */
    private final boolean failFast;
    // ---- mutable decoder state ----
    // True while the remainder of an oversized frame is being discarded.
    private boolean discardingTooLongFrame;
    // Adjusted length of the oversized frame currently being discarded.
    private long tooLongFrameLength;
    // Bytes of the oversized frame still to be discarded.
    private long bytesToDiscard;
    // Adjusted length of the frame currently being assembled; -1 means no frame in progress.
    private int frameLengthInt = -1;

    /**
     * Creates a new instance with no length adjustment and no stripping.
     *
     * @param maxFrameLength
     *        the maximum length of the frame. If the length of the frame is
     *        greater than this value, {@link TooLongFrameException} will be thrown.
     * @param lengthFieldOffset
     *        the offset of the length field
     * @param lengthFieldLength
     *        the length of the length field
     */
    public LengthFieldBasedFrameDecoder(
            int maxFrameLength,
            int lengthFieldOffset, int lengthFieldLength) {
        this(maxFrameLength, lengthFieldOffset, lengthFieldLength, 0, 0);
    }

    /**
     * Creates a new fail-fast instance.
     *
     * @param maxFrameLength
     *        the maximum length of the frame. If the length of the frame is
     *        greater than this value, {@link TooLongFrameException} will be thrown.
     * @param lengthFieldOffset
     *        the offset of the length field
     * @param lengthFieldLength
     *        the length of the length field
     * @param lengthAdjustment
     *        the compensation value to add to the value of the length field
     * @param initialBytesToStrip
     *        the number of first bytes to strip out from the decoded frame
     */
    public LengthFieldBasedFrameDecoder(
            int maxFrameLength,
            int lengthFieldOffset, int lengthFieldLength,
            int lengthAdjustment, int initialBytesToStrip) {
        this(
                maxFrameLength,
                lengthFieldOffset, lengthFieldLength, lengthAdjustment,
                initialBytesToStrip, true);
    }

    /**
     * Creates a new big-endian instance.
     *
     * @param maxFrameLength
     *        the maximum length of the frame. If the length of the frame is
     *        greater than this value, {@link TooLongFrameException} will be thrown.
     * @param lengthFieldOffset
     *        the offset of the length field
     * @param lengthFieldLength
     *        the length of the length field
     * @param lengthAdjustment
     *        the compensation value to add to the value of the length field
     * @param initialBytesToStrip
     *        the number of first bytes to strip out from the decoded frame
     * @param failFast
     *        If true, a {@link TooLongFrameException} is thrown as soon as the
     *        decoder notices the length of the frame will exceed
     *        {@code maxFrameLength} regardless of whether the entire frame has
     *        been read. If false, it is thrown only after the entire
     *        oversized frame has been read and discarded.
     */
    public LengthFieldBasedFrameDecoder(
            int maxFrameLength, int lengthFieldOffset, int lengthFieldLength,
            int lengthAdjustment, int initialBytesToStrip, boolean failFast) {
        this(
                ByteOrder.BIG_ENDIAN, maxFrameLength, lengthFieldOffset, lengthFieldLength,
                lengthAdjustment, initialBytesToStrip, failFast);
    }

    /**
     * Creates a new instance.
     *
     * @param byteOrder
     *        the {@link ByteOrder} of the length field
     * @param maxFrameLength
     *        the maximum length of the frame. If the length of the frame is
     *        greater than this value, {@link TooLongFrameException} will be thrown.
     * @param lengthFieldOffset
     *        the offset of the length field
     * @param lengthFieldLength
     *        the length of the length field
     * @param lengthAdjustment
     *        the compensation value to add to the value of the length field
     * @param initialBytesToStrip
     *        the number of first bytes to strip out from the decoded frame
     * @param failFast
     *        If true, a {@link TooLongFrameException} is thrown as soon as the
     *        decoder notices the length of the frame will exceed
     *        {@code maxFrameLength} regardless of whether the entire frame has
     *        been read. If false, it is thrown only after the entire
     *        oversized frame has been read and discarded.
     */
    public LengthFieldBasedFrameDecoder(
            ByteOrder byteOrder, int maxFrameLength, int lengthFieldOffset, int lengthFieldLength,
            int lengthAdjustment, int initialBytesToStrip, boolean failFast) {

        this.byteOrder = checkNotNull(byteOrder, "byteOrder");

        checkPositive(maxFrameLength, "maxFrameLength");

        checkPositiveOrZero(lengthFieldOffset, "lengthFieldOffset");

        checkPositiveOrZero(initialBytesToStrip, "initialBytesToStrip");

        // Written as a subtraction (rather than offset + length > max) so the
        // comparison cannot overflow int.
        if (lengthFieldOffset > maxFrameLength - lengthFieldLength) {
            throw new IllegalArgumentException(
                    "maxFrameLength (" + maxFrameLength + ") " +
                    "must be equal to or greater than " +
                    "lengthFieldOffset (" + lengthFieldOffset + ") + " +
                    "lengthFieldLength (" + lengthFieldLength + ").");
        }

        this.maxFrameLength = maxFrameLength;
        this.lengthFieldOffset = lengthFieldOffset;
        this.lengthFieldLength = lengthFieldLength;
        this.lengthAdjustment = lengthAdjustment;
        this.lengthFieldEndOffset = lengthFieldOffset + lengthFieldLength;
        this.initialBytesToStrip = initialBytesToStrip;
        this.failFast = failFast;
    }

    @Override
    protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Delegate to the single-frame decode(); null means "not enough data yet".
        Object decoded = decode(ctx, in);
        if (decoded != null) {
            out.add(decoded);
        }
    }

    /**
     * Skips as many bytes of the current oversized frame as are available,
     * then fails (or resets) via {@link #failIfNecessary(boolean)}.
     */
    private void discardingTooLongFrame(ByteBuf in) {
        long bytesToDiscard = this.bytesToDiscard;
        int localBytesToDiscard = (int) Math.min(bytesToDiscard, in.readableBytes());
        in.skipBytes(localBytesToDiscard);
        bytesToDiscard -= localBytesToDiscard;
        this.bytesToDiscard = bytesToDiscard;

        failIfNecessary(false);
    }

    // Skips past the bogus length field before raising, so the stream can resync.
    private static void failOnNegativeLengthField(ByteBuf in, long frameLength, int lengthFieldEndOffset) {
        in.skipBytes(lengthFieldEndOffset);
        throw new CorruptedFrameException(
                "negative pre-adjustment length field: " + frameLength);
    }

    // An adjusted frame length smaller than the length field's own end offset is corrupt.
    private static void failOnFrameLengthLessThanLengthFieldEndOffset(ByteBuf in,
                                                                     long frameLength,
                                                                     int lengthFieldEndOffset) {
        in.skipBytes(lengthFieldEndOffset);
        throw new CorruptedFrameException(
                "Adjusted frame length (" + frameLength + ") is less " +
                "than lengthFieldEndOffset: " + lengthFieldEndOffset);
    }

    /**
     * Handles a frame whose adjusted length exceeds {@code maxFrameLength}:
     * either skips the whole frame immediately (if fully buffered) or enters
     * discard mode for the bytes that have not yet arrived.
     */
    private void exceededFrameLength(ByteBuf in, long frameLength) {
        long discard = frameLength - in.readableBytes();
        tooLongFrameLength = frameLength;

        if (discard < 0) {
            // buffer contains more bytes than the frameLength so we can discard all now
            in.skipBytes((int) frameLength);
        } else {
            // Enter the discard mode and discard everything received so far.
            discardingTooLongFrame = true;
            bytesToDiscard = discard;
            in.skipBytes(in.readableBytes());
        }
        failIfNecessary(true);
    }

    // The whole (already buffered) frame is skipped before raising.
    private static void failOnFrameLengthLessThanInitialBytesToStrip(ByteBuf in,
                                                                    long frameLength,
                                                                    int initialBytesToStrip) {
        in.skipBytes((int) frameLength);
        throw new CorruptedFrameException(
                "Adjusted frame length (" + frameLength + ") is less " +
                "than initialBytesToStrip: " + initialBytesToStrip);
    }

    /**
     * Create a frame out of the {@link ByteBuf} and return it.
     *
     * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to
     * @param in the {@link ByteBuf} from which to read data
     * @return frame the {@link ByteBuf} which represent the frame or {@code null} if no frame could
     *         be created.
     */
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        long frameLength = 0;
        if (frameLengthInt == -1) { // new frame

            if (discardingTooLongFrame) {
                // Note: intentionally falls through after discarding — the bytes
                // remaining in 'in' may already belong to the next frame.
                discardingTooLongFrame(in);
            }

            // Not even the length field has fully arrived yet.
            if (in.readableBytes() < lengthFieldEndOffset) {
                return null;
            }

            int actualLengthFieldOffset = in.readerIndex() + lengthFieldOffset;
            frameLength = getUnadjustedFrameLength(in, actualLengthFieldOffset, lengthFieldLength, byteOrder);

            if (frameLength < 0) {
                failOnNegativeLengthField(in, frameLength, lengthFieldEndOffset);
            }

            // From here on, frameLength is the total adjusted frame length
            // including the length field itself.
            frameLength += lengthAdjustment + lengthFieldEndOffset;

            if (frameLength < lengthFieldEndOffset) {
                failOnFrameLengthLessThanLengthFieldEndOffset(in, frameLength, lengthFieldEndOffset);
            }

            if (frameLength > maxFrameLength) {
                exceededFrameLength(in, frameLength);
                return null;
            }
            // never overflows because it's less than maxFrameLength
            frameLengthInt = (int) frameLength;
        }
        if (in.readableBytes() < frameLengthInt) { // frameLengthInt exist , just check buf
            return null;
        }
        if (initialBytesToStrip > frameLengthInt) {
            // NOTE(review): when resuming a partially received frame
            // (frameLengthInt != -1 on entry), the local frameLength is still 0
            // here, so the exception message reports 0 — confirm acceptable.
            failOnFrameLengthLessThanInitialBytesToStrip(in, frameLength, initialBytesToStrip);
        }
        in.skipBytes(initialBytesToStrip);

        // extract frame
        int readerIndex = in.readerIndex();
        int actualFrameLength = frameLengthInt - initialBytesToStrip;
        ByteBuf frame = extractFrame(ctx, in, readerIndex, actualFrameLength);
        in.readerIndex(readerIndex + actualFrameLength);
        frameLengthInt = -1; // start processing the next frame
        return frame;
    }

    /**
     * Decodes the specified region of the buffer into an unadjusted frame length. The default implementation is
     * capable of decoding the specified region into an unsigned 8/16/24/32/64 bit integer. Override this method to
     * decode the length field encoded differently. Note that this method must not modify the state of the specified
     * buffer (e.g. {@code readerIndex}, {@code writerIndex}, and the content of the buffer.)
     *
     * @throws DecoderException if failed to decode the specified region
     */
    protected long getUnadjustedFrameLength(ByteBuf buf, int offset, int length, ByteOrder order) {
        // order() returns a view with the requested byte order; the original
        // buffer's state is not modified.
        buf = buf.order(order);
        long frameLength;
        switch (length) {
        case 1:
            frameLength = buf.getUnsignedByte(offset);
            break;
        case 2:
            frameLength = buf.getUnsignedShort(offset);
            break;
        case 3:
            frameLength = buf.getUnsignedMedium(offset);
            break;
        case 4:
            frameLength = buf.getUnsignedInt(offset);
            break;
        case 8:
            // 8-byte fields are read signed; negative values are rejected by the caller.
            frameLength = buf.getLong(offset);
            break;
        default:
            throw new DecoderException(
                    "unsupported lengthFieldLength: " + lengthFieldLength + " (expected: 1, 2, 3, 4, or 8)");
        }
        return frameLength;
    }

    /**
     * Decides whether to raise {@link TooLongFrameException} now, depending on
     * {@code failFast} and on whether the oversized frame has been fully discarded.
     *
     * @param firstDetectionOfTooLongFrame {@code true} only on the call made when
     *        the oversized frame is first detected
     */
    private void failIfNecessary(boolean firstDetectionOfTooLongFrame) {
        if (bytesToDiscard == 0) {
            // Reset to the initial state and tell the handlers that
            // the frame was too large.
            long tooLongFrameLength = this.tooLongFrameLength;
            this.tooLongFrameLength = 0;
            discardingTooLongFrame = false;
            if (!failFast || firstDetectionOfTooLongFrame) {
                fail(tooLongFrameLength);
            }
        } else {
            // Keep discarding and notify handlers if necessary.
            if (failFast && firstDetectionOfTooLongFrame) {
                fail(tooLongFrameLength);
            }
        }
    }

    /**
     * Extract the sub-region of the specified buffer.
     * The returned buffer is a retained slice; the caller owns the new reference.
     */
    protected ByteBuf extractFrame(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) {
        return buffer.retainedSlice(index, length);
    }

    // frameLength == 0 means the oversized frame's length was never recorded
    // (fail-fast before full discard), hence the vaguer message.
    private void fail(long frameLength) {
        if (frameLength > 0) {
            throw new TooLongFrameException(
                    "Adjusted frame length exceeds " + maxFrameLength +
                    ": " + frameLength + " - discarded");
        } else {
            throw new TooLongFrameException(
                    "Adjusted frame length exceeds " + maxFrameLength +
                    " - discarding");
        }
    }
}

+ * For example, {@link LengthFieldPrepender}(2) will encode the + * following 12-bytes string: + *

+ * +----------------+
+ * | "HELLO, WORLD" |
+ * +----------------+
+ * 
+ * into the following: + *
+ * +--------+----------------+
+ * | 0x000C | "HELLO, WORLD" |
+ * +--------+----------------+
+ * 
+ * If you turned on the {@code lengthIncludesLengthFieldLength} flag in the + * constructor, the encoded data would look like the following + * (12 (original data) + 2 (prepended data) = 14 (0xE)): + *
+ * +--------+----------------+
+ * | 0x000E | "HELLO, WORLD" |
+ * +--------+----------------+
+ * 
+ */ +@Sharable +public class LengthFieldPrepender extends MessageToMessageEncoder { + + private final ByteOrder byteOrder; + private final int lengthFieldLength; + private final boolean lengthIncludesLengthFieldLength; + private final int lengthAdjustment; + + /** + * Creates a new instance. + * + * @param lengthFieldLength the length of the prepended length field. + * Only 1, 2, 3, 4, and 8 are allowed. + * + * @throws IllegalArgumentException + * if {@code lengthFieldLength} is not 1, 2, 3, 4, or 8 + */ + public LengthFieldPrepender(int lengthFieldLength) { + this(lengthFieldLength, false); + } + + /** + * Creates a new instance. + * + * @param lengthFieldLength the length of the prepended length field. + * Only 1, 2, 3, 4, and 8 are allowed. + * @param lengthIncludesLengthFieldLength + * if {@code true}, the length of the prepended + * length field is added to the value of the + * prepended length field. + * + * @throws IllegalArgumentException + * if {@code lengthFieldLength} is not 1, 2, 3, 4, or 8 + */ + public LengthFieldPrepender(int lengthFieldLength, boolean lengthIncludesLengthFieldLength) { + this(lengthFieldLength, 0, lengthIncludesLengthFieldLength); + } + + /** + * Creates a new instance. + * + * @param lengthFieldLength the length of the prepended length field. + * Only 1, 2, 3, 4, and 8 are allowed. + * @param lengthAdjustment the compensation value to add to the value + * of the length field + * + * @throws IllegalArgumentException + * if {@code lengthFieldLength} is not 1, 2, 3, 4, or 8 + */ + public LengthFieldPrepender(int lengthFieldLength, int lengthAdjustment) { + this(lengthFieldLength, lengthAdjustment, false); + } + + /** + * Creates a new instance. + * + * @param lengthFieldLength the length of the prepended length field. + * Only 1, 2, 3, 4, and 8 are allowed. 
+ * @param lengthAdjustment the compensation value to add to the value + * of the length field + * @param lengthIncludesLengthFieldLength + * if {@code true}, the length of the prepended + * length field is added to the value of the + * prepended length field. + * + * @throws IllegalArgumentException + * if {@code lengthFieldLength} is not 1, 2, 3, 4, or 8 + */ + public LengthFieldPrepender(int lengthFieldLength, int lengthAdjustment, boolean lengthIncludesLengthFieldLength) { + this(ByteOrder.BIG_ENDIAN, lengthFieldLength, lengthAdjustment, lengthIncludesLengthFieldLength); + } + + /** + * Creates a new instance. + * + * @param byteOrder the {@link ByteOrder} of the length field + * @param lengthFieldLength the length of the prepended length field. + * Only 1, 2, 3, 4, and 8 are allowed. + * @param lengthAdjustment the compensation value to add to the value + * of the length field + * @param lengthIncludesLengthFieldLength + * if {@code true}, the length of the prepended + * length field is added to the value of the + * prepended length field. 
+ * + * @throws IllegalArgumentException + * if {@code lengthFieldLength} is not 1, 2, 3, 4, or 8 + */ + public LengthFieldPrepender( + ByteOrder byteOrder, int lengthFieldLength, + int lengthAdjustment, boolean lengthIncludesLengthFieldLength) { + if (lengthFieldLength != 1 && lengthFieldLength != 2 && + lengthFieldLength != 3 && lengthFieldLength != 4 && + lengthFieldLength != 8) { + throw new IllegalArgumentException( + "lengthFieldLength must be either 1, 2, 3, 4, or 8: " + + lengthFieldLength); + } + this.byteOrder = ObjectUtil.checkNotNull(byteOrder, "byteOrder"); + this.lengthFieldLength = lengthFieldLength; + this.lengthIncludesLengthFieldLength = lengthIncludesLengthFieldLength; + this.lengthAdjustment = lengthAdjustment; + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + int length = msg.readableBytes() + lengthAdjustment; + if (lengthIncludesLengthFieldLength) { + length += lengthFieldLength; + } + + checkPositiveOrZero(length, "length"); + + switch (lengthFieldLength) { + case 1: + if (length >= 256) { + throw new IllegalArgumentException( + "length does not fit into a byte: " + length); + } + out.add(ctx.alloc().buffer(1).order(byteOrder).writeByte((byte) length)); + break; + case 2: + if (length >= 65536) { + throw new IllegalArgumentException( + "length does not fit into a short integer: " + length); + } + out.add(ctx.alloc().buffer(2).order(byteOrder).writeShort((short) length)); + break; + case 3: + if (length >= 16777216) { + throw new IllegalArgumentException( + "length does not fit into a medium integer: " + length); + } + out.add(ctx.alloc().buffer(3).order(byteOrder).writeMedium(length)); + break; + case 4: + out.add(ctx.alloc().buffer(4).order(byteOrder).writeInt(length)); + break; + case 8: + out.add(ctx.alloc().buffer(8).order(byteOrder).writeLong(length)); + break; + default: + throw new Error("should not reach here"); + } + out.add(msg.retain()); + } +} diff --git 
a/netty-handler-codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java new file mode 100644 index 0000000..93289eb --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/LineBasedFrameDecoder.java @@ -0,0 +1,180 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.ByteProcessor; + +import java.util.List; + +/** + * A decoder that splits the received {@link ByteBuf}s on line endings. + *

+ * Both {@code "\n"} and {@code "\r\n"} are handled. + *

+ * The byte stream is expected to be in UTF-8 character encoding or ASCII. The current implementation + * uses a direct {@code byte} to {@code char} cast and then compares that {@code char} to a few low-range + * ASCII characters like {@code '\n'} or {@code '\r'}. UTF-8 does not use low-range [0..0x7F] + * byte values for multibyte codepoint representations and is therefore fully supported by this implementation. + *

/**
 * A decoder that splits the received {@link ByteBuf}s on line endings.
 * Both {@code "\n"} and {@code "\r\n"} are handled.
 * <p>
 * For a more general delimiter-based decoder, see {@link DelimiterBasedFrameDecoder}.
 */
public class LineBasedFrameDecoder extends ByteToMessageDecoder {

    /** Maximum length of a frame we're willing to decode. */
    private final int maxLength;
    /** Whether or not to throw an exception as soon as we exceed maxLength. */
    private final boolean failFast;
    /** Whether the line delimiter is removed from the emitted frame. */
    private final boolean stripDelimiter;

    /** True if we're discarding input because we're already over maxLength. */
    private boolean discarding;
    /** Number of bytes discarded so far while in discard mode. */
    private int discardedBytes;

    /** Last scan position, so bytes already searched are not re-scanned. */
    private int offset;

    /**
     * Creates a new decoder that strips the delimiter and fails only after
     * the oversized line has been fully read.
     *
     * @param maxLength the maximum length of the decoded frame.
     *                  A {@link TooLongFrameException} is thrown if
     *                  the length of the frame exceeds this value.
     */
    public LineBasedFrameDecoder(final int maxLength) {
        this(maxLength, true, false);
    }

    /**
     * Creates a new decoder.
     *
     * @param maxLength the maximum length of the decoded frame.
     *                  A {@link TooLongFrameException} is thrown if
     *                  the length of the frame exceeds this value.
     * @param stripDelimiter whether the decoded frame should strip out the
     *                  delimiter or not
     * @param failFast  If true, a {@link TooLongFrameException} is
     *                  thrown as soon as the decoder notices the length of the
     *                  frame will exceed maxLength regardless of
     *                  whether the entire frame has been read.
     *                  If false, a {@link TooLongFrameException} is
     *                  thrown after the entire frame that exceeds
     *                  maxLength has been read.
     */
    public LineBasedFrameDecoder(final int maxLength, final boolean stripDelimiter, final boolean failFast) {
        this.maxLength = maxLength;
        this.failFast = failFast;
        this.stripDelimiter = stripDelimiter;
    }

    @Override
    protected final void decode(ChannelHandlerContext ctx, ByteBuf in, List<Object> out) throws Exception {
        // Delegate to the single-frame decode(); null means "no complete line yet".
        Object decoded = decode(ctx, in);
        if (decoded != null) {
            out.add(decoded);
        }
    }

    /**
     * Create a frame out of the {@link ByteBuf} and return it.
     *
     * @param ctx the {@link ChannelHandlerContext} which this {@link ByteToMessageDecoder} belongs to
     * @param buffer the {@link ByteBuf} from which to read data
     * @return frame the {@link ByteBuf} which represent the frame or {@code null} if no frame could
     *         be created.
     */
    protected Object decode(ChannelHandlerContext ctx, ByteBuf buffer) throws Exception {
        final int eol = findEndOfLine(buffer);
        if (!discarding) {
            if (eol >= 0) {
                final ByteBuf frame;
                final int length = eol - buffer.readerIndex();
                // findEndOfLine points at '\r' for CRLF, so the delimiter is 2 bytes there.
                final int delimLength = buffer.getByte(eol) == '\r'? 2 : 1;

                if (length > maxLength) {
                    // Oversized but complete line: skip it whole and fail.
                    buffer.readerIndex(eol + delimLength);
                    fail(ctx, length);
                    return null;
                }

                if (stripDelimiter) {
                    frame = buffer.readRetainedSlice(length);
                    buffer.skipBytes(delimLength);
                } else {
                    frame = buffer.readRetainedSlice(length + delimLength);
                }

                return frame;
            } else {
                // No delimiter yet; enter discard mode if the partial line is
                // already over the limit.
                final int length = buffer.readableBytes();
                if (length > maxLength) {
                    discardedBytes = length;
                    buffer.readerIndex(buffer.writerIndex());
                    discarding = true;
                    offset = 0;
                    if (failFast) {
                        fail(ctx, "over " + discardedBytes);
                    }
                }
                return null;
            }
        } else {
            if (eol >= 0) {
                // End of the oversized line found: drop it and leave discard mode.
                final int length = discardedBytes + eol - buffer.readerIndex();
                final int delimLength = buffer.getByte(eol) == '\r'? 2 : 1;
                buffer.readerIndex(eol + delimLength);
                discardedBytes = 0;
                discarding = false;
                // With failFast the exception was already raised on entry to discard mode.
                if (!failFast) {
                    fail(ctx, length);
                }
            } else {
                discardedBytes += buffer.readableBytes();
                buffer.readerIndex(buffer.writerIndex());
                // We skip everything in the buffer, we need to set the offset to 0 again.
                offset = 0;
            }
            return null;
        }
    }

    private void fail(final ChannelHandlerContext ctx, int length) {
        fail(ctx, String.valueOf(length));
    }

    // Reports the oversized line via the pipeline rather than throwing,
    // so decoding of subsequent data can continue.
    private void fail(final ChannelHandlerContext ctx, String length) {
        ctx.fireExceptionCaught(
                new TooLongFrameException(
                        "frame length (" + length + ") exceeds the allowed maximum (" + maxLength + ')'));
    }

    /**
     * Returns the index in the buffer of the end of line found.
     * Returns -1 if no end of line was found in the buffer.
     * Only the bytes after {@code offset} (the previous scan position) are
     * examined; for CRLF the returned index points at the {@code '\r'}.
     */
    private int findEndOfLine(final ByteBuf buffer) {
        int totalLength = buffer.readableBytes();
        int i = buffer.forEachByte(buffer.readerIndex() + offset, totalLength - offset, ByteProcessor.FIND_LF);
        if (i >= 0) {
            offset = 0;
            if (i > 0 && buffer.getByte(i - 1) == '\r') {
                i--;
            }
        } else {
            // Nothing found: remember how far we scanned.
            offset = totalLength;
        }
        return i;
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec; + +/** + * Raised by {@link MessageAggregator} when aggregation fails due to an unexpected message sequence. + */ +public class MessageAggregationException extends IllegalStateException { + + private static final long serialVersionUID = -1995826182950310255L; + + public MessageAggregationException() { } + + public MessageAggregationException(String s) { + super(s); + } + + public MessageAggregationException(String message, Throwable cause) { + super(message, cause); + } + + public MessageAggregationException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageAggregator.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageAggregator.java new file mode 100644 index 0000000..bdef311 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageAggregator.java @@ -0,0 +1,471 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.util.ReferenceCountUtil; + +import java.util.List; + +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * An abstract {@link ChannelHandler} that aggregates a series of message objects into a single aggregated message. + *

+ * 'A series of messages' is composed of the following: + *

    + *
  • a single start message which optionally contains the first part of the content, and
  • + *
  • 1 or more content messages.
  • + *
+ * The content of the aggregated message will be the merged content of the start message and its following content + * messages. If this aggregator encounters a content message where {@link #isLastContentMessage(ByteBufHolder)} + * return {@code true} for, the aggregator will finish the aggregation and produce the aggregated message and expect + * another start message. + *

+ * + * @param the type that covers both start message and content message + * @param the type of the start message + * @param the type of the content message (must be a subtype of {@link ByteBufHolder}) + * @param the type of the aggregated message (must be a subtype of {@code S} and {@link ByteBufHolder}) + */ +public abstract class MessageAggregator + extends MessageToMessageDecoder { + + private static final int DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS = 1024; + + private final int maxContentLength; + private O currentMessage; + private boolean handlingOversizedMessage; + + private int maxCumulationBufferComponents = DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS; + private ChannelHandlerContext ctx; + private ChannelFutureListener continueResponseWriteListener; + + private boolean aggregating; + private boolean handleIncompleteAggregateDuringClose = true; + + /** + * Creates a new instance. + * + * @param maxContentLength + * the maximum length of the aggregated content. + * If the length of the aggregated content exceeds this value, + * {@link #handleOversizedMessage(ChannelHandlerContext, Object)} will be called. + */ + protected MessageAggregator(int maxContentLength) { + validateMaxContentLength(maxContentLength); + this.maxContentLength = maxContentLength; + } + + protected MessageAggregator(int maxContentLength, Class inboundMessageType) { + super(inboundMessageType); + validateMaxContentLength(maxContentLength); + this.maxContentLength = maxContentLength; + } + + private static void validateMaxContentLength(int maxContentLength) { + checkPositiveOrZero(maxContentLength, "maxContentLength"); + } + + @Override + public boolean acceptInboundMessage(Object msg) throws Exception { + // No need to match last and full types because they are subset of first and middle types. 
+ if (!super.acceptInboundMessage(msg)) { + return false; + } + + @SuppressWarnings("unchecked") + I in = (I) msg; + + if (isAggregated(in)) { + return false; + } + + // NOTE: It's tempting to make this check only if aggregating is false. There are however + // side conditions in decode(...) in respect to large messages. + if (isStartMessage(in)) { + return true; + } else { + return aggregating && isContentMessage(in); + } + } + + /** + * Returns {@code true} if and only if the specified message is a start message. Typically, this method is + * implemented as a single {@code return} statement with {@code instanceof}: + *
+     * return msg instanceof MyStartMessage;
+     * 
+ */ + protected abstract boolean isStartMessage(I msg) throws Exception; + + /** + * Returns {@code true} if and only if the specified message is a content message. Typically, this method is + * implemented as a single {@code return} statement with {@code instanceof}: + *
+     * return msg instanceof MyContentMessage;
+     * 
+ */ + protected abstract boolean isContentMessage(I msg) throws Exception; + + /** + * Returns {@code true} if and only if the specified message is the last content message. Typically, this method is + * implemented as a single {@code return} statement with {@code instanceof}: + *
+     * return msg instanceof MyLastContentMessage;
+     * 
+ * or with {@code instanceof} and boolean field check: + *
+     * return msg instanceof MyContentMessage && msg.isLastFragment();
+     * 
+ */ + protected abstract boolean isLastContentMessage(C msg) throws Exception; + + /** + * Returns {@code true} if and only if the specified message is already aggregated. If this method returns + * {@code true}, this handler will simply forward the message to the next handler as-is. + */ + protected abstract boolean isAggregated(I msg) throws Exception; + + /** + * Returns the maximum allowed length of the aggregated message in bytes. + */ + public final int maxContentLength() { + return maxContentLength; + } + + /** + * Returns the maximum number of components in the cumulation buffer. If the number of + * the components in the cumulation buffer exceeds this value, the components of the + * cumulation buffer are consolidated into a single component, involving memory copies. + * The default value of this property is {@value #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS}. + */ + public final int maxCumulationBufferComponents() { + return maxCumulationBufferComponents; + } + + /** + * Sets the maximum number of components in the cumulation buffer. If the number of + * the components in the cumulation buffer exceeds this value, the components of the + * cumulation buffer are consolidated into a single component, involving memory copies. + * The default value of this property is {@value #DEFAULT_MAX_COMPOSITEBUFFER_COMPONENTS} + * and its minimum allowed value is {@code 2}. + */ + public final void setMaxCumulationBufferComponents(int maxCumulationBufferComponents) { + if (maxCumulationBufferComponents < 2) { + throw new IllegalArgumentException( + "maxCumulationBufferComponents: " + maxCumulationBufferComponents + + " (expected: >= 2)"); + } + + if (ctx == null) { + this.maxCumulationBufferComponents = maxCumulationBufferComponents; + } else { + throw new IllegalStateException( + "decoder properties cannot be changed once the decoder is added to a pipeline."); + } + } + + /** + * @deprecated This method will be removed in future releases. 
+ */ + @Deprecated + public final boolean isHandlingOversizedMessage() { + return handlingOversizedMessage; + } + + protected final ChannelHandlerContext ctx() { + if (ctx == null) { + throw new IllegalStateException("not added to a pipeline yet"); + } + return ctx; + } + + @Override + protected void decode(final ChannelHandlerContext ctx, I msg, List out) throws Exception { + if (isStartMessage(msg)) { + aggregating = true; + handlingOversizedMessage = false; + if (currentMessage != null) { + currentMessage.release(); + currentMessage = null; + throw new MessageAggregationException(); + } + + @SuppressWarnings("unchecked") + S m = (S) msg; + + // Send the continue response if necessary (e.g. 'Expect: 100-continue' header) + // Check before content length. Failing an expectation may result in a different response being sent. + Object continueResponse = newContinueResponse(m, maxContentLength, ctx.pipeline()); + if (continueResponse != null) { + // Cache the write listener for reuse. + ChannelFutureListener listener = continueResponseWriteListener; + if (listener == null) { + continueResponseWriteListener = listener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + if (!future.isSuccess()) { + ctx.fireExceptionCaught(future.cause()); + } + } + }; + } + + // Make sure to call this before writing, otherwise reference counts may be invalid. 
+ boolean closeAfterWrite = closeAfterContinueResponse(continueResponse); + handlingOversizedMessage = ignoreContentAfterContinueResponse(continueResponse); + + final ChannelFuture future = ctx.writeAndFlush(continueResponse).addListener(listener); + + if (closeAfterWrite) { + handleIncompleteAggregateDuringClose = false; + future.addListener(ChannelFutureListener.CLOSE); + return; + } + if (handlingOversizedMessage) { + return; + } + } else if (isContentLengthInvalid(m, maxContentLength)) { + // if content length is set, preemptively close if it's too large + invokeHandleOversizedMessage(ctx, m); + return; + } + + if (m instanceof DecoderResultProvider && !((DecoderResultProvider) m).decoderResult().isSuccess()) { + O aggregated; + if (m instanceof ByteBufHolder) { + aggregated = beginAggregation(m, ((ByteBufHolder) m).content().retain()); + } else { + aggregated = beginAggregation(m, EMPTY_BUFFER); + } + finishAggregation0(aggregated); + out.add(aggregated); + return; + } + + // A streamed message - initialize the cumulative buffer, and wait for incoming chunks. + CompositeByteBuf content = ctx.alloc().compositeBuffer(maxCumulationBufferComponents); + if (m instanceof ByteBufHolder) { + appendPartialContent(content, ((ByteBufHolder) m).content()); + } + currentMessage = beginAggregation(m, content); + } else if (isContentMessage(msg)) { + if (currentMessage == null) { + // it is possible that a TooLongFrameException was already thrown but we can still discard data + // until the begging of the next request/response. + return; + } + + // Merge the received chunk into the content of the current message. + CompositeByteBuf content = (CompositeByteBuf) currentMessage.content(); + + @SuppressWarnings("unchecked") + final C m = (C) msg; + // Handle oversized message. + if (content.readableBytes() > maxContentLength - m.content().readableBytes()) { + // By convention, full message type extends first message type. 
+ @SuppressWarnings("unchecked") + S s = (S) currentMessage; + invokeHandleOversizedMessage(ctx, s); + return; + } + + // Append the content of the chunk. + appendPartialContent(content, m.content()); + + // Give the subtypes a chance to merge additional information such as trailing headers. + aggregate(currentMessage, m); + + final boolean last; + if (m instanceof DecoderResultProvider) { + DecoderResult decoderResult = ((DecoderResultProvider) m).decoderResult(); + if (!decoderResult.isSuccess()) { + if (currentMessage instanceof DecoderResultProvider) { + ((DecoderResultProvider) currentMessage).setDecoderResult( + DecoderResult.failure(decoderResult.cause())); + } + last = true; + } else { + last = isLastContentMessage(m); + } + } else { + last = isLastContentMessage(m); + } + + if (last) { + finishAggregation0(currentMessage); + + // All done + out.add(currentMessage); + currentMessage = null; + } + } else { + throw new MessageAggregationException(); + } + } + + private static void appendPartialContent(CompositeByteBuf content, ByteBuf partialContent) { + if (partialContent.isReadable()) { + content.addComponent(true, partialContent.retain()); + } + } + + /** + * Determine if the message {@code start}'s content length is known, and if it greater than + * {@code maxContentLength}. + * @param start The message which may indicate the content length. + * @param maxContentLength The maximum allowed content length. + * @return {@code true} if the message {@code start}'s content length is known, and if it greater than + * {@code maxContentLength}. {@code false} otherwise. + */ + protected abstract boolean isContentLengthInvalid(S start, int maxContentLength) throws Exception; + + /** + * Returns the 'continue response' for the specified start message if necessary. For example, this method is + * useful to handle an HTTP 100-continue header. 
+ * + * @return the 'continue response', or {@code null} if there's no message to send + */ + protected abstract Object newContinueResponse(S start, int maxContentLength, ChannelPipeline pipeline) + throws Exception; + + /** + * Determine if the channel should be closed after the result of + * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written. + * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}. + * @return {@code true} if the channel should be closed after the result of + * {@link #newContinueResponse(Object, int, ChannelPipeline)} is written. {@code false} otherwise. + */ + protected abstract boolean closeAfterContinueResponse(Object msg) throws Exception; + + /** + * Determine if all objects for the current request/response should be ignored or not. + * Messages will stop being ignored the next time {@link #isContentMessage(Object)} returns {@code true}. + * + * @param msg The return value from {@link #newContinueResponse(Object, int, ChannelPipeline)}. + * @return {@code true} if all objects for the current request/response should be ignored or not. + * {@code false} otherwise. + */ + protected abstract boolean ignoreContentAfterContinueResponse(Object msg) throws Exception; + + /** + * Creates a new aggregated message from the specified start message and the specified content. If the start + * message implements {@link ByteBufHolder}, its content is appended to the specified {@code content}. + * This aggregator will continue to append the received content to the specified {@code content}. + */ + protected abstract O beginAggregation(S start, ByteBuf content) throws Exception; + + /** + * Transfers the information provided by the specified content message to the specified aggregated message. + * Note that the content of the specified content message has been appended to the content of the specified + * aggregated message already, so that you don't need to. 
Use this method to transfer the additional information + * that the content message provides to {@code aggregated}. + */ + protected void aggregate(O aggregated, C content) throws Exception { } + + private void finishAggregation0(O aggregated) throws Exception { + aggregating = false; + finishAggregation(aggregated); + } + + /** + * Invoked when the specified {@code aggregated} message is about to be passed to the next handler in the pipeline. + */ + protected void finishAggregation(O aggregated) throws Exception { } + + private void invokeHandleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception { + handlingOversizedMessage = true; + currentMessage = null; + handleIncompleteAggregateDuringClose = false; + try { + handleOversizedMessage(ctx, oversized); + } finally { + // Release the message in case it is a full one. + ReferenceCountUtil.release(oversized); + } + } + + /** + * Invoked when an incoming request exceeds the maximum content length. The default behvaior is to trigger an + * {@code exceptionCaught()} event with a {@link TooLongFrameException}. + * + * @param ctx the {@link ChannelHandlerContext} + * @param oversized the accumulated message up to this point, whose type is {@code S} or {@code O} + */ + protected void handleOversizedMessage(ChannelHandlerContext ctx, S oversized) throws Exception { + ctx.fireExceptionCaught( + new TooLongFrameException("content length exceeded " + maxContentLength() + " bytes.")); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + // We might need keep reading the channel until the full message is aggregated. 
+ // + // See https://github.com/netty/netty/issues/6583 + if (currentMessage != null && !ctx.channel().config().isAutoRead()) { + ctx.read(); + } + ctx.fireChannelReadComplete(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + if (aggregating && handleIncompleteAggregateDuringClose) { + ctx.fireExceptionCaught( + new PrematureChannelClosureException("Channel closed while still aggregating message")); + } + try { + // release current message if it is not null as it may be a left-over + super.channelInactive(ctx); + } finally { + releaseCurrentMessage(); + } + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + try { + super.handlerRemoved(ctx); + } finally { + // release current message if it is not null as it may be a left-over as there is not much more we can do in + // this case + releaseCurrentMessage(); + } + } + + private void releaseCurrentMessage() { + if (currentMessage != null) { + currentMessage.release(); + currentMessage = null; + handlingOversizedMessage = false; + aggregating = false; + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToByteEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToByteEncoder.java new file mode 100644 index 0000000..1342428 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToByteEncoder.java @@ -0,0 +1,160 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.TypeParameterMatcher; + + +/** + * {@link ChannelOutboundHandlerAdapter} which encodes message in a stream-like fashion from one message to an + * {@link ByteBuf}. + * + * + * Example implementation which encodes {@link Integer}s to a {@link ByteBuf}. + * + *
+ *     public class IntegerEncoder extends {@link MessageToByteEncoder}<{@link Integer}> {
+ *         {@code @Override}
+ *         public void encode({@link ChannelHandlerContext} ctx, {@link Integer} msg, {@link ByteBuf} out)
+ *                 throws {@link Exception} {
+ *             out.writeInt(msg);
+ *         }
+ *     }
+ * 
+ */ +public abstract class MessageToByteEncoder extends ChannelOutboundHandlerAdapter { + + private final TypeParameterMatcher matcher; + private final boolean preferDirect; + + /** + * see {@link #MessageToByteEncoder(boolean)} with {@code true} as boolean parameter. + */ + protected MessageToByteEncoder() { + this(true); + } + + /** + * see {@link #MessageToByteEncoder(Class, boolean)} with {@code true} as boolean value. + */ + protected MessageToByteEncoder(Class outboundMessageType) { + this(outboundMessageType, true); + } + + /** + * Create a new instance which will try to detect the types to match out of the type parameter of the class. + * + * @param preferDirect {@code true} if a direct {@link ByteBuf} should be tried to be used as target for + * the encoded messages. If {@code false} is used it will allocate a heap + * {@link ByteBuf}, which is backed by an byte array. + */ + protected MessageToByteEncoder(boolean preferDirect) { + matcher = TypeParameterMatcher.find(this, MessageToByteEncoder.class, "I"); + this.preferDirect = preferDirect; + } + + /** + * Create a new instance + * + * @param outboundMessageType The type of messages to match + * @param preferDirect {@code true} if a direct {@link ByteBuf} should be tried to be used as target for + * the encoded messages. If {@code false} is used it will allocate a heap + * {@link ByteBuf}, which is backed by an byte array. + */ + protected MessageToByteEncoder(Class outboundMessageType, boolean preferDirect) { + matcher = TypeParameterMatcher.get(outboundMessageType); + this.preferDirect = preferDirect; + } + + /** + * Returns {@code true} if the given message should be handled. If {@code false} it will be passed to the next + * {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. 
+ */ + public boolean acceptOutboundMessage(Object msg) throws Exception { + return matcher.match(msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ByteBuf buf = null; + try { + if (acceptOutboundMessage(msg)) { + @SuppressWarnings("unchecked") + I cast = (I) msg; + buf = allocateBuffer(ctx, cast, preferDirect); + try { + encode(ctx, cast, buf); + } finally { + ReferenceCountUtil.release(cast); + } + + if (buf.isReadable()) { + ctx.write(buf, promise); + } else { + buf.release(); + ctx.write(Unpooled.EMPTY_BUFFER, promise); + } + buf = null; + } else { + ctx.write(msg, promise); + } + } catch (EncoderException e) { + throw e; + } catch (Throwable e) { + throw new EncoderException(e); + } finally { + if (buf != null) { + buf.release(); + } + } + } + + /** + * Allocate a {@link ByteBuf} which will be used as argument of {@link #encode(ChannelHandlerContext, I, ByteBuf)}. + * Sub-classes may override this method to return {@link ByteBuf} with a perfect matching {@code initialCapacity}. + */ + protected ByteBuf allocateBuffer(ChannelHandlerContext ctx, @SuppressWarnings("unused") I msg, + boolean preferDirect) throws Exception { + if (preferDirect) { + return ctx.alloc().ioBuffer(); + } else { + return ctx.alloc().heapBuffer(); + } + } + + /** + * Encode a message into a {@link ByteBuf}. This method will be called for each written message that can be handled + * by this encoder. 
+ * + * @param ctx the {@link ChannelHandlerContext} which this {@link MessageToByteEncoder} belongs to + * @param msg the message to encode + * @param out the {@link ByteBuf} into which the encoded message will be written + * @throws Exception is thrown if an error occurs + */ + protected abstract void encode(ChannelHandlerContext ctx, I msg, ByteBuf out) throws Exception; + + protected boolean isPreferDirect() { + return preferDirect; + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageCodec.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageCodec.java new file mode 100644 index 0000000..9e99fa1 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageCodec.java @@ -0,0 +1,148 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.TypeParameterMatcher; + +import java.util.List; + +/** + * A Codec for on-the-fly encoding/decoding of message. + * + * This can be thought of as a combination of {@link MessageToMessageDecoder} and {@link MessageToMessageEncoder}. 
+ * + * Here is an example of a {@link MessageToMessageCodec} which just decode from {@link Integer} to {@link Long} + * and encode from {@link Long} to {@link Integer}. + * + *
+ *     public class NumberCodec extends
+ *             {@link MessageToMessageCodec}<{@link Integer}, {@link Long}> {
+ *         {@code @Override}
+ *         public void decode({@link ChannelHandlerContext} ctx, {@link Integer} msg, List&lt;Object&gt; out)
+ *                 throws {@link Exception} {
+ *             out.add(msg.longValue());
+ *         }
+ *
+ *         {@code @Override}
+ *         public void encode({@link ChannelHandlerContext} ctx, {@link Long} msg, List&lt;Object&gt; out)
+ *                 throws {@link Exception} {
+ *             out.add(msg.intValue());
+ *         }
+ *     }
+ * 
+ * + * Be aware that you need to call {@link ReferenceCounted#retain()} on messages that are just passed through if they + * are of type {@link ReferenceCounted}. This is needed as the {@link MessageToMessageCodec} will call + * {@link ReferenceCounted#release()} on encoded / decoded messages. + */ +public abstract class MessageToMessageCodec extends ChannelDuplexHandler { + + private final MessageToMessageEncoder encoder = new MessageToMessageEncoder() { + + @Override + public boolean acceptOutboundMessage(Object msg) throws Exception { + return MessageToMessageCodec.this.acceptOutboundMessage(msg); + } + + @Override + @SuppressWarnings("unchecked") + protected void encode(ChannelHandlerContext ctx, Object msg, List out) throws Exception { + MessageToMessageCodec.this.encode(ctx, (OUTBOUND_IN) msg, out); + } + }; + + private final MessageToMessageDecoder decoder = new MessageToMessageDecoder() { + + @Override + public boolean acceptInboundMessage(Object msg) throws Exception { + return MessageToMessageCodec.this.acceptInboundMessage(msg); + } + + @Override + @SuppressWarnings("unchecked") + protected void decode(ChannelHandlerContext ctx, Object msg, List out) throws Exception { + MessageToMessageCodec.this.decode(ctx, (INBOUND_IN) msg, out); + } + }; + + private final TypeParameterMatcher inboundMsgMatcher; + private final TypeParameterMatcher outboundMsgMatcher; + + /** + * Create a new instance which will try to detect the types to decode and encode out of the type parameter + * of the class. + */ + protected MessageToMessageCodec() { + inboundMsgMatcher = TypeParameterMatcher.find(this, MessageToMessageCodec.class, "INBOUND_IN"); + outboundMsgMatcher = TypeParameterMatcher.find(this, MessageToMessageCodec.class, "OUTBOUND_IN"); + } + + /** + * Create a new instance. 
+ * + * @param inboundMessageType The type of messages to decode + * @param outboundMessageType The type of messages to encode + */ + protected MessageToMessageCodec( + Class inboundMessageType, Class outboundMessageType) { + inboundMsgMatcher = TypeParameterMatcher.get(inboundMessageType); + outboundMsgMatcher = TypeParameterMatcher.get(outboundMessageType); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + decoder.channelRead(ctx, msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + encoder.write(ctx, msg, promise); + } + + /** + * Returns {@code true} if and only if the specified message can be decoded by this codec. + * + * @param msg the message + */ + public boolean acceptInboundMessage(Object msg) throws Exception { + return inboundMsgMatcher.match(msg); + } + + /** + * Returns {@code true} if and only if the specified message can be encoded by this codec. 
+ * + * @param msg the message + */ + public boolean acceptOutboundMessage(Object msg) throws Exception { + return outboundMsgMatcher.match(msg); + } + + /** + * @see MessageToMessageEncoder#encode(ChannelHandlerContext, Object, List) + */ + protected abstract void encode(ChannelHandlerContext ctx, OUTBOUND_IN msg, List out) + throws Exception; + + /** + * @see MessageToMessageDecoder#decode(ChannelHandlerContext, Object, List) + */ + protected abstract void decode(ChannelHandlerContext ctx, INBOUND_IN msg, List out) + throws Exception; +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java new file mode 100644 index 0000000..e088736 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageDecoder.java @@ -0,0 +1,121 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.TypeParameterMatcher; + +import java.util.List; + +/** + * {@link ChannelInboundHandlerAdapter} which decodes from one message to an other message. + * + * + * For example here is an implementation which decodes a {@link String} to an {@link Integer} which represent + * the length of the {@link String}. + * + *
+ *     public class StringToIntegerDecoder extends
+ *             {@link MessageToMessageDecoder}<{@link String}> {
+ *
+ *         {@code @Override}
+ *         public void decode({@link ChannelHandlerContext} ctx, {@link String} message,
+ *                            List<Object> out) throws {@link Exception} {
+ *             out.add(message.length());
+ *         }
+ *     }
+ * 
+ * + * Be aware that you need to call {@link ReferenceCounted#retain()} on messages that are just passed through if they + * are of type {@link ReferenceCounted}. This is needed as the {@link MessageToMessageDecoder} will call + * {@link ReferenceCounted#release()} on decoded messages. + * + */ +public abstract class MessageToMessageDecoder extends ChannelInboundHandlerAdapter { + + private final TypeParameterMatcher matcher; + + /** + * Create a new instance which will try to detect the types to match out of the type parameter of the class. + */ + protected MessageToMessageDecoder() { + matcher = TypeParameterMatcher.find(this, MessageToMessageDecoder.class, "I"); + } + + /** + * Create a new instance + * + * @param inboundMessageType The type of messages to match and so decode + */ + protected MessageToMessageDecoder(Class inboundMessageType) { + matcher = TypeParameterMatcher.get(inboundMessageType); + } + + /** + * Returns {@code true} if the given message should be handled. If {@code false} it will be passed to the next + * {@link ChannelInboundHandler} in the {@link ChannelPipeline}. + */ + public boolean acceptInboundMessage(Object msg) throws Exception { + return matcher.match(msg); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + CodecOutputList out = CodecOutputList.newInstance(); + try { + if (acceptInboundMessage(msg)) { + @SuppressWarnings("unchecked") + I cast = (I) msg; + try { + decode(ctx, cast, out); + } finally { + ReferenceCountUtil.release(cast); + } + } else { + out.add(msg); + } + } catch (DecoderException e) { + throw e; + } catch (Exception e) { + throw new DecoderException(e); + } finally { + try { + int size = out.size(); + for (int i = 0; i < size; i++) { + ctx.fireChannelRead(out.getUnsafe(i)); + } + } finally { + out.recycle(); + } + } + } + + /** + * Decode from one message to an other. This method will be called for each written message that can be handled + * by this decoder. 
+ * + * @param ctx the {@link ChannelHandlerContext} which this {@link MessageToMessageDecoder} belongs to + * @param msg the message to decode to an other one + * @param out the {@link List} to which decoded messages should be added + * @throws Exception is thrown if an error occurs + */ + protected abstract void decode(ChannelHandlerContext ctx, I msg, List out) throws Exception; +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java new file mode 100644 index 0000000..4973c9b --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/MessageToMessageEncoder.java @@ -0,0 +1,156 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.PromiseCombiner; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.TypeParameterMatcher; + +import java.util.List; + +/** + * {@link ChannelOutboundHandlerAdapter} which encodes from one message to an other message + * + * For example here is an implementation which decodes an {@link Integer} to an {@link String}. + * + *
+ *     public class IntegerToStringEncoder extends
+ *             {@link MessageToMessageEncoder}<{@link Integer}> {
+ *
+ *         {@code @Override}
+ *         public void encode({@link ChannelHandlerContext} ctx, {@link Integer} message, List<Object> out)
+ *                 throws {@link Exception} {
+ *             out.add(message.toString());
+ *         }
+ *     }
+ * 
+ * + * Be aware that you need to call {@link ReferenceCounted#retain()} on messages that are just passed through if they + * are of type {@link ReferenceCounted}. This is needed as the {@link MessageToMessageEncoder} will call + * {@link ReferenceCounted#release()} on encoded messages. + */ +public abstract class MessageToMessageEncoder extends ChannelOutboundHandlerAdapter { + + private final TypeParameterMatcher matcher; + + /** + * Create a new instance which will try to detect the types to match out of the type parameter of the class. + */ + protected MessageToMessageEncoder() { + matcher = TypeParameterMatcher.find(this, MessageToMessageEncoder.class, "I"); + } + + /** + * Create a new instance + * + * @param outboundMessageType The type of messages to match and so encode + */ + protected MessageToMessageEncoder(Class outboundMessageType) { + matcher = TypeParameterMatcher.get(outboundMessageType); + } + + /** + * Returns {@code true} if the given message should be handled. If {@code false} it will be passed to the next + * {@link ChannelOutboundHandler} in the {@link ChannelPipeline}. 
+ */ + public boolean acceptOutboundMessage(Object msg) throws Exception { + return matcher.match(msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + CodecOutputList out = null; + try { + if (acceptOutboundMessage(msg)) { + out = CodecOutputList.newInstance(); + @SuppressWarnings("unchecked") + I cast = (I) msg; + try { + encode(ctx, cast, out); + } catch (Throwable th) { + ReferenceCountUtil.safeRelease(cast); + PlatformDependent.throwException(th); + } + ReferenceCountUtil.release(cast); + + if (out.isEmpty()) { + throw new EncoderException( + StringUtil.simpleClassName(this) + " must produce at least one message."); + } + } else { + ctx.write(msg, promise); + } + } catch (EncoderException e) { + throw e; + } catch (Throwable t) { + throw new EncoderException(t); + } finally { + if (out != null) { + try { + final int sizeMinusOne = out.size() - 1; + if (sizeMinusOne == 0) { + ctx.write(out.getUnsafe(0), promise); + } else if (sizeMinusOne > 0) { + // Check if we can use a voidPromise for our extra writes to reduce GC-Pressure + // See https://github.com/netty/netty/issues/2525 + if (promise == ctx.voidPromise()) { + writeVoidPromise(ctx, out); + } else { + writePromiseCombiner(ctx, out, promise); + } + } + } finally { + out.recycle(); + } + } + } + } + + private static void writeVoidPromise(ChannelHandlerContext ctx, CodecOutputList out) { + final ChannelPromise voidPromise = ctx.voidPromise(); + for (int i = 0; i < out.size(); i++) { + ctx.write(out.getUnsafe(i), voidPromise); + } + } + + private static void writePromiseCombiner(ChannelHandlerContext ctx, CodecOutputList out, ChannelPromise promise) { + final PromiseCombiner combiner = new PromiseCombiner(ctx.executor()); + for (int i = 0; i < out.size(); i++) { + combiner.add(ctx.write(out.getUnsafe(i))); + } + combiner.finish(promise); + } + + /** + * Encode from one message to an other. 
This method will be called for each written message that can be handled + * by this encoder. + * + * @param ctx the {@link ChannelHandlerContext} which this {@link MessageToMessageEncoder} belongs to + * @param msg the message to encode to an other one + * @param out the {@link List} into which the encoded msg should be added + * needs to do some kind of aggregation + * @throws Exception is thrown if an error occurs + */ + protected abstract void encode(ChannelHandlerContext ctx, I msg, List out) throws Exception; +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/PrematureChannelClosureException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/PrematureChannelClosureException.java new file mode 100644 index 0000000..680f2d0 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/PrematureChannelClosureException.java @@ -0,0 +1,54 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.channel.Channel; + +/** + * A {@link CodecException} which is thrown when a {@link Channel} is closed unexpectedly before + * the codec finishes handling the current message, such as missing response while waiting for a + * request. 
+ */ +public class PrematureChannelClosureException extends CodecException { + + private static final long serialVersionUID = 4907642202594703094L; + + /** + * Creates a new instance. + */ + public PrematureChannelClosureException() { } + + /** + * Creates a new instance. + */ + public PrematureChannelClosureException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public PrematureChannelClosureException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public PrematureChannelClosureException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java new file mode 100644 index 0000000..10da049 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionResult.java @@ -0,0 +1,80 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Result of detecting a protocol. 
+ * + * @param <T> the type of the protocol + */ +public final class ProtocolDetectionResult<T> { + + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static final ProtocolDetectionResult NEEDS_MORE_DATA = + new ProtocolDetectionResult(ProtocolDetectionState.NEEDS_MORE_DATA, null); + @SuppressWarnings({ "rawtypes", "unchecked" }) + private static final ProtocolDetectionResult INVALID = + new ProtocolDetectionResult(ProtocolDetectionState.INVALID, null); + + private final ProtocolDetectionState state; + private final T result; + + /** + * Returns a {@link ProtocolDetectionResult} that signals that more data is needed to detect the protocol. + */ + @SuppressWarnings("unchecked") + public static <T> ProtocolDetectionResult<T> needsMoreData() { + return NEEDS_MORE_DATA; + } + + /** + * Returns a {@link ProtocolDetectionResult} that signals the data was invalid for the protocol. + */ + @SuppressWarnings("unchecked") + public static <T> ProtocolDetectionResult<T> invalid() { + return INVALID; + } + + /** + * Returns a {@link ProtocolDetectionResult} which holds the detected protocol. + */ + @SuppressWarnings("unchecked") + public static <T> ProtocolDetectionResult<T> detected(T protocol) { + return new ProtocolDetectionResult<T>(ProtocolDetectionState.DETECTED, checkNotNull(protocol, "protocol")); + } + + private ProtocolDetectionResult(ProtocolDetectionState state, T result) { + this.state = state; + this.result = result; + } + + /** + * Return the {@link ProtocolDetectionState}. If the state is {@link ProtocolDetectionState#DETECTED} you + * can retrieve the protocol via {@link #detectedProtocol()}. + */ + public ProtocolDetectionState state() { + return state; + } + + /** + * Returns the protocol if {@link #state()} returns {@link ProtocolDetectionState#DETECTED}, otherwise {@code null}. 
+ */ + public T detectedProtocol() { + return result; + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java new file mode 100644 index 0000000..0f072fc --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ProtocolDetectionState.java @@ -0,0 +1,36 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * The state of the current detection. + */ +public enum ProtocolDetectionState { + /** + * Need more data to detect the protocol. + */ + NEEDS_MORE_DATA, + + /** + * The data was invalid. + */ + INVALID, + + /** + * Protocol was detected, + */ + DETECTED +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java new file mode 100644 index 0000000..c134c55 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoder.java @@ -0,0 +1,424 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.util.Signal; +import io.netty.util.internal.StringUtil; + +import java.util.List; + +/** + * A specialized variation of {@link ByteToMessageDecoder} which enables implementation + * of a non-blocking decoder in the blocking I/O paradigm. + *

+ * The biggest difference between {@link ReplayingDecoder} and + * {@link ByteToMessageDecoder} is that {@link ReplayingDecoder} allows you to + * implement the {@code decode()} and {@code decodeLast()} methods as if + * all required bytes had been received already, rather than checking the + * availability of the required bytes. For example, the following + * {@link ByteToMessageDecoder} implementation:

+ * public class IntegerHeaderFrameDecoder extends {@link ByteToMessageDecoder} {
+ *
+ *   {@code @Override}
+ *   protected void decode({@link ChannelHandlerContext} ctx,
+ *                           {@link ByteBuf} buf, List<Object> out) throws Exception {
+ *
+ *     if (buf.readableBytes() < 4) {
+ *        return;
+ *     }
+ *
+ *     buf.markReaderIndex();
+ *     int length = buf.readInt();
+ *
+ *     if (buf.readableBytes() < length) {
+ *        buf.resetReaderIndex();
+ *        return;
+ *     }
+ *
+ *     out.add(buf.readBytes(length));
+ *   }
+ * }
+ * 
+ * is simplified like the following with {@link ReplayingDecoder}: + *
+ * public class IntegerHeaderFrameDecoder
+ *      extends {@link ReplayingDecoder}<{@link Void}> {
+ *
+ *   protected void decode({@link ChannelHandlerContext} ctx,
+ *                           {@link ByteBuf} buf, List<Object> out) throws Exception {
+ *
+ *     out.add(buf.readBytes(buf.readInt()));
+ *   }
+ * }
+ * 
+ * + *

How does this work?

+ *

+ * {@link ReplayingDecoder} passes a specialized {@link ByteBuf} + * implementation which throws an {@link Error} of certain type when there's not + * enough data in the buffer. In the {@code IntegerHeaderFrameDecoder} above, + * you just assumed that there will be 4 or more bytes in the buffer when + * you call {@code buf.readInt()}. If there's really 4 bytes in the buffer, + * it will return the integer header as you expected. Otherwise, the + * {@link Error} will be raised and the control will be returned to + * {@link ReplayingDecoder}. If {@link ReplayingDecoder} catches the + * {@link Error}, then it will rewind the {@code readerIndex} of the buffer + * back to the 'initial' position (i.e. the beginning of the buffer) and call + * the {@code decode(..)} method again when more data is received into the + * buffer. + *

+ * Please note that {@link ReplayingDecoder} always throws the same cached + * {@link Error} instance to avoid the overhead of creating a new {@link Error} + * and filling its stack trace for every throw. + * + *

Limitations

+ *

+ * At the cost of the simplicity, {@link ReplayingDecoder} places a few + * limitations on you: + *

    + *
  • Some buffer operations are prohibited.
  • + *
  • Performance can be worse if the network is slow and the message + * format is complicated unlike the example above. In this case, your + * decoder might have to decode the same part of the message over and over + * again.
  • + *
  • You must keep in mind that {@code decode(..)} method can be called many + * times to decode a single message. For example, the following code will + * not work: + *
     public class MyDecoder extends {@link ReplayingDecoder}<{@link Void}> {
    + *
    + *   private final Queue<Integer> values = new LinkedList<Integer>();
    + *
    + *   {@code @Override}
    + *   public void decode(.., {@link ByteBuf} buf, List<Object> out) throws Exception {
    + *
    + *     // A message contains 2 integers.
    + *     values.offer(buf.readInt());
    + *     values.offer(buf.readInt());
    + *
    + *     // This assertion will fail intermittently since values.offer()
    + *     // can be called more than two times!
    + *     assert values.size() == 2;
    + *     out.add(values.poll() + values.poll());
    + *   }
    + * }
    + * The correct implementation looks like the following, and you can also + * utilize the 'checkpoint' feature which is explained in detail in the + * next section. + *
     public class MyDecoder extends {@link ReplayingDecoder}<{@link Void}> {
    + *
    + *   private final Queue<Integer> values = new LinkedList<Integer>();
    + *
    + *   {@code @Override}
    + *   public void decode(.., {@link ByteBuf} buf, List<Object> out) throws Exception {
    + *
    + *     // Revert the state of the variable that might have been changed
    + *     // since the last partial decode.
    + *     values.clear();
    + *
    + *     // A message contains 2 integers.
    + *     values.offer(buf.readInt());
    + *     values.offer(buf.readInt());
    + *
    + *     // Now we know this assertion will never fail.
    + *     assert values.size() == 2;
    + *     out.add(values.poll() + values.poll());
    + *   }
    + * }
    + *
  • + *
+ * + *

Improving the performance

+ *

+ * Fortunately, the performance of a complex decoder implementation can be + * improved significantly with the {@code checkpoint()} method. The + * {@code checkpoint()} method updates the 'initial' position of the buffer so + * that {@link ReplayingDecoder} rewinds the {@code readerIndex} of the buffer + * to the last position where you called the {@code checkpoint()} method. + * + *

Calling {@code checkpoint(T)} with an {@link Enum}

+ *

+ * Although you can just use {@code checkpoint()} method and manage the state + * of the decoder by yourself, the easiest way to manage the state of the + * decoder is to create an {@link Enum} type which represents the current state + * of the decoder and to call {@code checkpoint(T)} method whenever the state + * changes. You can have as many states as you want depending on the + * complexity of the message you want to decode: + * + *

+ * public enum MyDecoderState {
+ *   READ_LENGTH,
+ *   READ_CONTENT;
+ * }
+ *
+ * public class IntegerHeaderFrameDecoder
+ *      extends {@link ReplayingDecoder}<MyDecoderState> {
+ *
+ *   private int length;
+ *
+ *   public IntegerHeaderFrameDecoder() {
+ *     // Set the initial state.
+ *     super(MyDecoderState.READ_LENGTH);
+ *   }
+ *
+ *   {@code @Override}
+ *   protected void decode({@link ChannelHandlerContext} ctx,
+ *                           {@link ByteBuf} buf, List<Object> out) throws Exception {
+ *     switch (state()) {
+ *     case READ_LENGTH:
+ *       length = buf.readInt();
+ *       checkpoint(MyDecoderState.READ_CONTENT);
+ *     case READ_CONTENT:
+ *       ByteBuf frame = buf.readBytes(length);
+ *       checkpoint(MyDecoderState.READ_LENGTH);
+ *       out.add(frame);
+ *       break;
+ *     default:
+ *       throw new Error("Shouldn't reach here.");
+ *     }
+ *   }
+ * }
+ * 
+ * + *

Calling {@code checkpoint()} with no parameter

+ *

+ * An alternative way to manage the decoder state is to manage it by yourself. + *

+ * public class IntegerHeaderFrameDecoder
+ *      extends {@link ReplayingDecoder}<{@link Void}> {
+ *
+ *   private boolean readLength;
+ *   private int length;
+ *
+ *   {@code @Override}
+ *   protected void decode({@link ChannelHandlerContext} ctx,
+ *                           {@link ByteBuf} buf, List<Object> out) throws Exception {
+ *     if (!readLength) {
+ *       length = buf.readInt();
+ *       readLength = true;
+ *       checkpoint();
+ *     }
+ *
+ *     if (readLength) {
+ *       ByteBuf frame = buf.readBytes(length);
+ *       readLength = false;
+ *       checkpoint();
+ *       out.add(frame);
+ *     }
+ *   }
+ * }
+ * 
+ * + *

Replacing a decoder with another decoder in a pipeline

+ *

+ * If you are going to write a protocol multiplexer, you will probably want to + * replace a {@link ReplayingDecoder} (protocol detector) with another + * {@link ReplayingDecoder}, {@link ByteToMessageDecoder} or {@link MessageToMessageDecoder} + * (actual protocol decoder). + * It is not possible to achieve this simply by calling + * {@link ChannelPipeline#replace(ChannelHandler, String, ChannelHandler)}, but + * some additional steps are required: + *

+ * public class FirstDecoder extends {@link ReplayingDecoder}<{@link Void}> {
+ *
+ *     {@code @Override}
+ *     protected void decode({@link ChannelHandlerContext} ctx,
+ *                             {@link ByteBuf} buf, List<Object> out) {
+ *         ...
+ *         // Decode the first message
+ *         Object firstMessage = ...;
+ *
+ *         // Add the second decoder
+ *         ctx.pipeline().addLast("second", new SecondDecoder());
+ *
+ *         if (buf.isReadable()) {
+ *             // Hand off the remaining data to the second decoder
+ *             out.add(firstMessage);
+ *             out.add(buf.readBytes(super.actualReadableBytes()));
+ *         } else {
+ *             // Nothing to hand off
+ *             out.add(firstMessage);
+ *         }
+ *         // Remove the first decoder (me)
+ *         ctx.pipeline().remove(this);
+ *     }
+ * 
+ * @param + * the state type which is usually an {@link Enum}; use {@link Void} if state management is + * unused + */ +public abstract class ReplayingDecoder extends ByteToMessageDecoder { + + static final Signal REPLAY = Signal.valueOf(ReplayingDecoder.class, "REPLAY"); + + private final ReplayingDecoderByteBuf replayable = new ReplayingDecoderByteBuf(); + private S state; + private int checkpoint = -1; + + /** + * Creates a new instance with no initial state (i.e: {@code null}). + */ + protected ReplayingDecoder() { + this(null); + } + + /** + * Creates a new instance with the specified initial state. + */ + protected ReplayingDecoder(S initialState) { + state = initialState; + } + + /** + * Stores the internal cumulative buffer's reader position. + */ + protected void checkpoint() { + checkpoint = internalBuffer().readerIndex(); + } + + /** + * Stores the internal cumulative buffer's reader position and updates + * the current decoder state. + */ + protected void checkpoint(S state) { + checkpoint(); + state(state); + } + + /** + * Returns the current state of this decoder. + * @return the current state of this decoder + */ + protected S state() { + return state; + } + + /** + * Sets the current state of this decoder. 
+ * @return the old state of this decoder + */ + protected S state(S newState) { + S oldState = state; + state = newState; + return oldState; + } + + @Override + final void channelInputClosed(ChannelHandlerContext ctx, List out) throws Exception { + try { + replayable.terminate(); + if (cumulation != null) { + callDecode(ctx, internalBuffer(), out); + } else { + replayable.setCumulation(Unpooled.EMPTY_BUFFER); + } + decodeLast(ctx, replayable, out); + } catch (Signal replay) { + // Ignore + replay.expect(REPLAY); + } + } + + @Override + protected void callDecode(ChannelHandlerContext ctx, ByteBuf in, List out) { + replayable.setCumulation(in); + try { + while (in.isReadable()) { + int oldReaderIndex = checkpoint = in.readerIndex(); + int outSize = out.size(); + + if (outSize > 0) { + fireChannelRead(ctx, out, outSize); + out.clear(); + + // Check if this handler was removed before continuing with decoding. + // If it was removed, it is not safe to continue to operate on the buffer. + // + // See: + // - https://github.com/netty/netty/issues/4635 + if (ctx.isRemoved()) { + break; + } + outSize = 0; + } + + S oldState = state; + int oldInputLength = in.readableBytes(); + try { + decodeRemovalReentryProtection(ctx, replayable, out); + + // Check if this handler was removed before continuing the loop. + // If it was removed, it is not safe to continue to operate on the buffer. + // + // See https://github.com/netty/netty/issues/1664 + if (ctx.isRemoved()) { + break; + } + + if (outSize == out.size()) { + if (oldInputLength == in.readableBytes() && oldState == state) { + throw new DecoderException( + StringUtil.simpleClassName(getClass()) + ".decode() must consume the inbound " + + "data or change its state if it did not decode anything."); + } else { + // Previous data has been discarded or caused state transition. + // Probably it is reading on. 
+ continue; + } + } + } catch (Signal replay) { + replay.expect(REPLAY); + + // Check if this handler was removed before continuing the loop. + // If it was removed, it is not safe to continue to operate on the buffer. + // + // See https://github.com/netty/netty/issues/1664 + if (ctx.isRemoved()) { + break; + } + + // Return to the checkpoint (or oldPosition) and retry. + int checkpoint = this.checkpoint; + if (checkpoint >= 0) { + in.readerIndex(checkpoint); + } else { + // Called by cleanup() - no need to maintain the readerIndex + // anymore because the buffer has been released already. + } + break; + } + + if (oldReaderIndex == in.readerIndex() && oldState == state) { + throw new DecoderException( + StringUtil.simpleClassName(getClass()) + ".decode() method must consume the inbound data " + + "or change its state if it decoded something."); + } + if (isSingleDecode()) { + break; + } + } + } catch (DecoderException e) { + throw e; + } catch (Exception cause) { + throw new DecoderException(cause); + } + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoderByteBuf.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoderByteBuf.java new file mode 100644 index 0000000..b75a6c8 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ReplayingDecoderByteBuf.java @@ -0,0 +1,1147 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import java.io.InputStream; +import java.io.OutputStream; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.channels.FileChannel; +import java.nio.channels.GatheringByteChannel; +import java.nio.channels.ScatteringByteChannel; +import java.nio.charset.Charset; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.SwappedByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.ByteProcessor; +import io.netty.util.Signal; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +/** + * Special {@link ByteBuf} implementation which is used by the {@link ReplayingDecoder} + */ +final class ReplayingDecoderByteBuf extends ByteBuf { + + private static final Signal REPLAY = ReplayingDecoder.REPLAY; + + private ByteBuf buffer; + private boolean terminated; + private SwappedByteBuf swapped; + + @SuppressWarnings("checkstyle:StaticFinalBuffer") // Unpooled.EMPTY_BUFFER is not writeable or readable. 
+ static final ReplayingDecoderByteBuf EMPTY_BUFFER = new ReplayingDecoderByteBuf(Unpooled.EMPTY_BUFFER); + + static { + EMPTY_BUFFER.terminate(); + } + + ReplayingDecoderByteBuf() { } + + ReplayingDecoderByteBuf(ByteBuf buffer) { + setCumulation(buffer); + } + + void setCumulation(ByteBuf buffer) { + this.buffer = buffer; + } + + void terminate() { + terminated = true; + } + + @Override + public int capacity() { + if (terminated) { + return buffer.capacity(); + } else { + return Integer.MAX_VALUE; + } + } + + @Override + public ByteBuf capacity(int newCapacity) { + throw reject(); + } + + @Override + public int maxCapacity() { + return capacity(); + } + + @Override + public ByteBufAllocator alloc() { + return buffer.alloc(); + } + + @Override + public boolean isReadOnly() { + return false; + } + + @SuppressWarnings("deprecation") + @Override + public ByteBuf asReadOnly() { + return Unpooled.unmodifiableBuffer(this); + } + + @Override + public boolean isDirect() { + return buffer.isDirect(); + } + + @Override + public boolean hasArray() { + return false; + } + + @Override + public byte[] array() { + throw new UnsupportedOperationException(); + } + + @Override + public int arrayOffset() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean hasMemoryAddress() { + return false; + } + + @Override + public long memoryAddress() { + throw new UnsupportedOperationException(); + } + + @Override + public ByteBuf clear() { + throw reject(); + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + @Override + public int compareTo(ByteBuf buffer) { + throw reject(); + } + + @Override + public ByteBuf copy() { + throw reject(); + } + + @Override + public ByteBuf copy(int index, int length) { + checkIndex(index, length); + return buffer.copy(index, length); + } + + @Override + public ByteBuf discardReadBytes() { + throw reject(); + } + + @Override + public ByteBuf ensureWritable(int writableBytes) { + throw reject(); + } + 
+ @Override + public int ensureWritable(int minWritableBytes, boolean force) { + throw reject(); + } + + @Override + public ByteBuf duplicate() { + throw reject(); + } + + @Override + public ByteBuf retainedDuplicate() { + throw reject(); + } + + @Override + public boolean getBoolean(int index) { + checkIndex(index, 1); + return buffer.getBoolean(index); + } + + @Override + public byte getByte(int index) { + checkIndex(index, 1); + return buffer.getByte(index); + } + + @Override + public short getUnsignedByte(int index) { + checkIndex(index, 1); + return buffer.getUnsignedByte(index); + } + + @Override + public ByteBuf getBytes(int index, byte[] dst, int dstIndex, int length) { + checkIndex(index, length); + buffer.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, byte[] dst) { + checkIndex(index, dst.length); + buffer.getBytes(index, dst); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuffer dst) { + throw reject(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int dstIndex, int length) { + checkIndex(index, length); + buffer.getBytes(index, dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst, int length) { + throw reject(); + } + + @Override + public ByteBuf getBytes(int index, ByteBuf dst) { + throw reject(); + } + + @Override + public int getBytes(int index, GatheringByteChannel out, int length) { + throw reject(); + } + + @Override + public int getBytes(int index, FileChannel out, long position, int length) { + throw reject(); + } + + @Override + public ByteBuf getBytes(int index, OutputStream out, int length) { + throw reject(); + } + + @Override + public int getInt(int index) { + checkIndex(index, 4); + return buffer.getInt(index); + } + + @Override + public int getIntLE(int index) { + checkIndex(index, 4); + return buffer.getIntLE(index); + } + + @Override + public long getUnsignedInt(int index) 
{ + checkIndex(index, 4); + return buffer.getUnsignedInt(index); + } + + @Override + public long getUnsignedIntLE(int index) { + checkIndex(index, 4); + return buffer.getUnsignedIntLE(index); + } + + @Override + public long getLong(int index) { + checkIndex(index, 8); + return buffer.getLong(index); + } + + @Override + public long getLongLE(int index) { + checkIndex(index, 8); + return buffer.getLongLE(index); + } + + @Override + public int getMedium(int index) { + checkIndex(index, 3); + return buffer.getMedium(index); + } + + @Override + public int getMediumLE(int index) { + checkIndex(index, 3); + return buffer.getMediumLE(index); + } + + @Override + public int getUnsignedMedium(int index) { + checkIndex(index, 3); + return buffer.getUnsignedMedium(index); + } + + @Override + public int getUnsignedMediumLE(int index) { + checkIndex(index, 3); + return buffer.getUnsignedMediumLE(index); + } + + @Override + public short getShort(int index) { + checkIndex(index, 2); + return buffer.getShort(index); + } + + @Override + public short getShortLE(int index) { + checkIndex(index, 2); + return buffer.getShortLE(index); + } + + @Override + public int getUnsignedShort(int index) { + checkIndex(index, 2); + return buffer.getUnsignedShort(index); + } + + @Override + public int getUnsignedShortLE(int index) { + checkIndex(index, 2); + return buffer.getUnsignedShortLE(index); + } + + @Override + public char getChar(int index) { + checkIndex(index, 2); + return buffer.getChar(index); + } + + @Override + public float getFloat(int index) { + checkIndex(index, 4); + return buffer.getFloat(index); + } + + @Override + public double getDouble(int index) { + checkIndex(index, 8); + return buffer.getDouble(index); + } + + @Override + public CharSequence getCharSequence(int index, int length, Charset charset) { + checkIndex(index, length); + return buffer.getCharSequence(index, length, charset); + } + + @Override + public int hashCode() { + throw reject(); + } + + @Override + public int 
indexOf(int fromIndex, int toIndex, byte value) { + if (fromIndex == toIndex) { + return -1; + } + + if (Math.max(fromIndex, toIndex) > buffer.writerIndex()) { + throw REPLAY; + } + + return buffer.indexOf(fromIndex, toIndex, value); + } + + @Override + public int bytesBefore(byte value) { + int bytes = buffer.bytesBefore(value); + if (bytes < 0) { + throw REPLAY; + } + return bytes; + } + + @Override + public int bytesBefore(int length, byte value) { + return bytesBefore(buffer.readerIndex(), length, value); + } + + @Override + public int bytesBefore(int index, int length, byte value) { + final int writerIndex = buffer.writerIndex(); + if (index >= writerIndex) { + throw REPLAY; + } + + if (index <= writerIndex - length) { + return buffer.bytesBefore(index, length, value); + } + + int res = buffer.bytesBefore(index, writerIndex - index, value); + if (res < 0) { + throw REPLAY; + } else { + return res; + } + } + + @Override + public int forEachByte(ByteProcessor processor) { + int ret = buffer.forEachByte(processor); + if (ret < 0) { + throw REPLAY; + } else { + return ret; + } + } + + @Override + public int forEachByte(int index, int length, ByteProcessor processor) { + final int writerIndex = buffer.writerIndex(); + if (index >= writerIndex) { + throw REPLAY; + } + + if (index <= writerIndex - length) { + return buffer.forEachByte(index, length, processor); + } + + int ret = buffer.forEachByte(index, writerIndex - index, processor); + if (ret < 0) { + throw REPLAY; + } else { + return ret; + } + } + + @Override + public int forEachByteDesc(ByteProcessor processor) { + if (terminated) { + return buffer.forEachByteDesc(processor); + } else { + throw reject(); + } + } + + @Override + public int forEachByteDesc(int index, int length, ByteProcessor processor) { + if (index + length > buffer.writerIndex()) { + throw REPLAY; + } + + return buffer.forEachByteDesc(index, length, processor); + } + + @Override + public ByteBuf markReaderIndex() { + buffer.markReaderIndex(); 
+ return this; + } + + @Override + public ByteBuf markWriterIndex() { + throw reject(); + } + + @Override + public ByteOrder order() { + return buffer.order(); + } + + @Override + public ByteBuf order(ByteOrder endianness) { + if (ObjectUtil.checkNotNull(endianness, "endianness") == order()) { + return this; + } + + SwappedByteBuf swapped = this.swapped; + if (swapped == null) { + this.swapped = swapped = new SwappedByteBuf(this); + } + return swapped; + } + + @Override + public boolean isReadable() { + return !terminated || buffer.isReadable(); + } + + @Override + public boolean isReadable(int size) { + return !terminated || buffer.isReadable(size); + } + + @Override + public int readableBytes() { + if (terminated) { + return buffer.readableBytes(); + } else { + return Integer.MAX_VALUE - buffer.readerIndex(); + } + } + + @Override + public boolean readBoolean() { + checkReadableBytes(1); + return buffer.readBoolean(); + } + + @Override + public byte readByte() { + checkReadableBytes(1); + return buffer.readByte(); + } + + @Override + public short readUnsignedByte() { + checkReadableBytes(1); + return buffer.readUnsignedByte(); + } + + @Override + public ByteBuf readBytes(byte[] dst, int dstIndex, int length) { + checkReadableBytes(length); + buffer.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf readBytes(byte[] dst) { + checkReadableBytes(dst.length); + buffer.readBytes(dst); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuffer dst) { + throw reject(); + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int dstIndex, int length) { + checkReadableBytes(length); + buffer.readBytes(dst, dstIndex, length); + return this; + } + + @Override + public ByteBuf readBytes(ByteBuf dst, int length) { + throw reject(); + } + + @Override + public ByteBuf readBytes(ByteBuf dst) { + checkReadableBytes(dst.writableBytes()); + buffer.readBytes(dst); + return this; + } + + @Override + public int 
readBytes(GatheringByteChannel out, int length) { + throw reject(); + } + + @Override + public int readBytes(FileChannel out, long position, int length) { + throw reject(); + } + + @Override + public ByteBuf readBytes(int length) { + checkReadableBytes(length); + return buffer.readBytes(length); + } + + @Override + public ByteBuf readSlice(int length) { + checkReadableBytes(length); + return buffer.readSlice(length); + } + + @Override + public ByteBuf readRetainedSlice(int length) { + checkReadableBytes(length); + return buffer.readRetainedSlice(length); + } + + @Override + public ByteBuf readBytes(OutputStream out, int length) { + throw reject(); + } + + @Override + public int readerIndex() { + return buffer.readerIndex(); + } + + @Override + public ByteBuf readerIndex(int readerIndex) { + buffer.readerIndex(readerIndex); + return this; + } + + @Override + public int readInt() { + checkReadableBytes(4); + return buffer.readInt(); + } + + @Override + public int readIntLE() { + checkReadableBytes(4); + return buffer.readIntLE(); + } + + @Override + public long readUnsignedInt() { + checkReadableBytes(4); + return buffer.readUnsignedInt(); + } + + @Override + public long readUnsignedIntLE() { + checkReadableBytes(4); + return buffer.readUnsignedIntLE(); + } + + @Override + public long readLong() { + checkReadableBytes(8); + return buffer.readLong(); + } + + @Override + public long readLongLE() { + checkReadableBytes(8); + return buffer.readLongLE(); + } + + @Override + public int readMedium() { + checkReadableBytes(3); + return buffer.readMedium(); + } + + @Override + public int readMediumLE() { + checkReadableBytes(3); + return buffer.readMediumLE(); + } + + @Override + public int readUnsignedMedium() { + checkReadableBytes(3); + return buffer.readUnsignedMedium(); + } + + @Override + public int readUnsignedMediumLE() { + checkReadableBytes(3); + return buffer.readUnsignedMediumLE(); + } + + @Override + public short readShort() { + checkReadableBytes(2); + return 
buffer.readShort(); + } + + @Override + public short readShortLE() { + checkReadableBytes(2); + return buffer.readShortLE(); + } + + @Override + public int readUnsignedShort() { + checkReadableBytes(2); + return buffer.readUnsignedShort(); + } + + @Override + public int readUnsignedShortLE() { + checkReadableBytes(2); + return buffer.readUnsignedShortLE(); + } + + @Override + public char readChar() { + checkReadableBytes(2); + return buffer.readChar(); + } + + @Override + public float readFloat() { + checkReadableBytes(4); + return buffer.readFloat(); + } + + @Override + public double readDouble() { + checkReadableBytes(8); + return buffer.readDouble(); + } + + @Override + public CharSequence readCharSequence(int length, Charset charset) { + checkReadableBytes(length); + return buffer.readCharSequence(length, charset); + } + + @Override + public ByteBuf resetReaderIndex() { + buffer.resetReaderIndex(); + return this; + } + + @Override + public ByteBuf resetWriterIndex() { + throw reject(); + } + + @Override + public ByteBuf setBoolean(int index, boolean value) { + throw reject(); + } + + @Override + public ByteBuf setByte(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src, int srcIndex, int length) { + throw reject(); + } + + @Override + public ByteBuf setBytes(int index, byte[] src) { + throw reject(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuffer src) { + throw reject(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + throw reject(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int length) { + throw reject(); + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src) { + throw reject(); + } + + @Override + public int setBytes(int index, InputStream in, int length) { + throw reject(); + } + + @Override + public ByteBuf setZero(int index, int length) { + throw reject(); + } + + @Override + public int 
setBytes(int index, ScatteringByteChannel in, int length) { + throw reject(); + } + + @Override + public int setBytes(int index, FileChannel in, long position, int length) { + throw reject(); + } + + @Override + public ByteBuf setIndex(int readerIndex, int writerIndex) { + throw reject(); + } + + @Override + public ByteBuf setInt(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setIntLE(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setLong(int index, long value) { + throw reject(); + } + + @Override + public ByteBuf setLongLE(int index, long value) { + throw reject(); + } + + @Override + public ByteBuf setMedium(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setMediumLE(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setShort(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setShortLE(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setChar(int index, int value) { + throw reject(); + } + + @Override + public ByteBuf setFloat(int index, float value) { + throw reject(); + } + + @Override + public ByteBuf setDouble(int index, double value) { + throw reject(); + } + + @Override + public ByteBuf skipBytes(int length) { + checkReadableBytes(length); + buffer.skipBytes(length); + return this; + } + + @Override + public ByteBuf slice() { + throw reject(); + } + + @Override + public ByteBuf retainedSlice() { + throw reject(); + } + + @Override + public ByteBuf slice(int index, int length) { + checkIndex(index, length); + return buffer.slice(index, length); + } + + @Override + public ByteBuf retainedSlice(int index, int length) { + checkIndex(index, length); + return buffer.retainedSlice(index, length); + } + + @Override + public int nioBufferCount() { + return buffer.nioBufferCount(); + } + + @Override + public ByteBuffer nioBuffer() { + throw reject(); + } + + @Override + public ByteBuffer 
nioBuffer(int index, int length) { + checkIndex(index, length); + return buffer.nioBuffer(index, length); + } + + @Override + public ByteBuffer[] nioBuffers() { + throw reject(); + } + + @Override + public ByteBuffer[] nioBuffers(int index, int length) { + checkIndex(index, length); + return buffer.nioBuffers(index, length); + } + + @Override + public ByteBuffer internalNioBuffer(int index, int length) { + checkIndex(index, length); + return buffer.internalNioBuffer(index, length); + } + + @Override + public String toString(int index, int length, Charset charset) { + checkIndex(index, length); + return buffer.toString(index, length, charset); + } + + @Override + public String toString(Charset charsetName) { + throw reject(); + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + '(' + + "ridx=" + + readerIndex() + + ", " + + "widx=" + + writerIndex() + + ')'; + } + + @Override + public boolean isWritable() { + return false; + } + + @Override + public boolean isWritable(int size) { + return false; + } + + @Override + public int writableBytes() { + return 0; + } + + @Override + public int maxWritableBytes() { + return 0; + } + + @Override + public ByteBuf writeBoolean(boolean value) { + throw reject(); + } + + @Override + public ByteBuf writeByte(int value) { + throw reject(); + } + + @Override + public ByteBuf writeBytes(byte[] src, int srcIndex, int length) { + throw reject(); + } + + @Override + public ByteBuf writeBytes(byte[] src) { + throw reject(); + } + + @Override + public ByteBuf writeBytes(ByteBuffer src) { + throw reject(); + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int srcIndex, int length) { + throw reject(); + } + + @Override + public ByteBuf writeBytes(ByteBuf src, int length) { + throw reject(); + } + + @Override + public ByteBuf writeBytes(ByteBuf src) { + throw reject(); + } + + @Override + public int writeBytes(InputStream in, int length) { + throw reject(); + } + + @Override + public int 
writeBytes(ScatteringByteChannel in, int length) { + throw reject(); + } + + @Override + public int writeBytes(FileChannel in, long position, int length) { + throw reject(); + } + + @Override + public ByteBuf writeInt(int value) { + throw reject(); + } + + @Override + public ByteBuf writeIntLE(int value) { + throw reject(); + } + + @Override + public ByteBuf writeLong(long value) { + throw reject(); + } + + @Override + public ByteBuf writeLongLE(long value) { + throw reject(); + } + + @Override + public ByteBuf writeMedium(int value) { + throw reject(); + } + + @Override + public ByteBuf writeMediumLE(int value) { + throw reject(); + } + + @Override + public ByteBuf writeZero(int length) { + throw reject(); + } + + @Override + public int writerIndex() { + return buffer.writerIndex(); + } + + @Override + public ByteBuf writerIndex(int writerIndex) { + throw reject(); + } + + @Override + public ByteBuf writeShort(int value) { + throw reject(); + } + + @Override + public ByteBuf writeShortLE(int value) { + throw reject(); + } + + @Override + public ByteBuf writeChar(int value) { + throw reject(); + } + + @Override + public ByteBuf writeFloat(float value) { + throw reject(); + } + + @Override + public ByteBuf writeDouble(double value) { + throw reject(); + } + + @Override + public int setCharSequence(int index, CharSequence sequence, Charset charset) { + throw reject(); + } + + @Override + public int writeCharSequence(CharSequence sequence, Charset charset) { + throw reject(); + } + + private void checkIndex(int index, int length) { + if (index + length > buffer.writerIndex()) { + throw REPLAY; + } + } + + private void checkReadableBytes(int readableBytes) { + if (buffer.readableBytes() < readableBytes) { + throw REPLAY; + } + } + + @Override + public ByteBuf discardSomeReadBytes() { + throw reject(); + } + + @Override + public int refCnt() { + return buffer.refCnt(); + } + + @Override + public ByteBuf retain() { + throw reject(); + } + + @Override + public ByteBuf 
retain(int increment) { + throw reject(); + } + + @Override + public ByteBuf touch() { + buffer.touch(); + return this; + } + + @Override + public ByteBuf touch(Object hint) { + buffer.touch(hint); + return this; + } + + @Override + public boolean release() { + throw reject(); + } + + @Override + public boolean release(int decrement) { + throw reject(); + } + + @Override + public ByteBuf unwrap() { + throw reject(); + } + + private static UnsupportedOperationException reject() { + return new UnsupportedOperationException("not a replayable operation"); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/TooLongFrameException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/TooLongFrameException.java new file mode 100644 index 0000000..e905f52 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/TooLongFrameException.java @@ -0,0 +1,52 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * An {@link DecoderException} which is thrown when the length of the frame + * decoded is greater than the allowed maximum. + */ +public class TooLongFrameException extends DecoderException { + + private static final long serialVersionUID = -1995801950698951640L; + + /** + * Creates a new instance. + */ + public TooLongFrameException() { + } + + /** + * Creates a new instance. 
+ */ + public TooLongFrameException(String message, Throwable cause) { + super(message, cause); + } + + /** + * Creates a new instance. + */ + public TooLongFrameException(String message) { + super(message); + } + + /** + * Creates a new instance. + */ + public TooLongFrameException(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedMessageTypeException.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedMessageTypeException.java new file mode 100644 index 0000000..48a8f04 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedMessageTypeException.java @@ -0,0 +1,63 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * Thrown if an unsupported message is received by an codec. + */ +public class UnsupportedMessageTypeException extends CodecException { + + private static final long serialVersionUID = 2799598826487038726L; + + public UnsupportedMessageTypeException( + Object message, Class... expectedTypes) { + super(message( + message == null? 
"null" : message.getClass().getName(), expectedTypes)); + } + + public UnsupportedMessageTypeException() { } + + public UnsupportedMessageTypeException(String message, Throwable cause) { + super(message, cause); + } + + public UnsupportedMessageTypeException(String s) { + super(s); + } + + public UnsupportedMessageTypeException(Throwable cause) { + super(cause); + } + + private static String message( + String actualType, Class... expectedTypes) { + StringBuilder buf = new StringBuilder(actualType); + + if (expectedTypes != null && expectedTypes.length > 0) { + buf.append(" (expected: ").append(expectedTypes[0].getName()); + for (int i = 1; i < expectedTypes.length; i ++) { + Class t = expectedTypes[i]; + if (t == null) { + break; + } + buf.append(", ").append(t.getName()); + } + buf.append(')'); + } + + return buf.toString(); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedValueConverter.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedValueConverter.java new file mode 100644 index 0000000..2567045 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/UnsupportedValueConverter.java @@ -0,0 +1,125 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +/** + * {@link UnsupportedOperationException} will be thrown from all {@link ValueConverter} methods. 
+ */ +public final class UnsupportedValueConverter implements ValueConverter { + @SuppressWarnings("rawtypes") + private static final UnsupportedValueConverter INSTANCE = new UnsupportedValueConverter(); + private UnsupportedValueConverter() { } + + @SuppressWarnings("unchecked") + public static UnsupportedValueConverter instance() { + return (UnsupportedValueConverter) INSTANCE; + } + + @Override + public V convertObject(Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertBoolean(boolean value) { + throw new UnsupportedOperationException(); + } + + @Override + public boolean convertToBoolean(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertByte(byte value) { + throw new UnsupportedOperationException(); + } + + @Override + public byte convertToByte(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertChar(char value) { + throw new UnsupportedOperationException(); + } + + @Override + public char convertToChar(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertShort(short value) { + throw new UnsupportedOperationException(); + } + + @Override + public short convertToShort(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertInt(int value) { + throw new UnsupportedOperationException(); + } + + @Override + public int convertToInt(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertLong(long value) { + throw new UnsupportedOperationException(); + } + + @Override + public long convertToLong(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertTimeMillis(long value) { + throw new UnsupportedOperationException(); + } + + @Override + public long convertToTimeMillis(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertFloat(float value) { + throw new 
UnsupportedOperationException(); + } + + @Override + public float convertToFloat(V value) { + throw new UnsupportedOperationException(); + } + + @Override + public V convertDouble(double value) { + throw new UnsupportedOperationException(); + } + + @Override + public double convertToDouble(V value) { + throw new UnsupportedOperationException(); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/ValueConverter.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/ValueConverter.java new file mode 100644 index 0000000..4480f11 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/ValueConverter.java @@ -0,0 +1,57 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */package io.netty.handler.codec; + +/** + * Converts to/from a generic object to the type. 
+ */ +public interface ValueConverter { + T convertObject(Object value); + + T convertBoolean(boolean value); + + boolean convertToBoolean(T value); + + T convertByte(byte value); + + byte convertToByte(T value); + + T convertChar(char value); + + char convertToChar(T value); + + T convertShort(short value); + + short convertToShort(T value); + + T convertInt(int value); + + int convertToInt(T value); + + T convertLong(long value); + + long convertToLong(T value); + + T convertTimeMillis(long value); + + long convertToTimeMillis(T value); + + T convertFloat(float value); + + float convertToFloat(T value); + + T convertDouble(double value); + + double convertToDouble(T value); +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64.java new file mode 100644 index 0000000..2018a74 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64.java @@ -0,0 +1,429 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
/*
 * Written by Robert Harder and released to the public domain, as explained at
 * https://creativecommons.org/licenses/publicdomain
 */
package io.netty.handler.codec.base64;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufAllocator;
import io.netty.util.ByteProcessor;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;

import java.nio.ByteOrder;

/**
 * Utility class for {@link ByteBuf} that encodes and decodes to and from
 * Base64 notation.
 * <p>
 * The encoding and decoding algorithm in this class has been derived from
 * Robert Harder's Public Domain Base64 Encoder/Decoder.
 */
public final class Base64 {

    /** Maximum line length (76) of Base64 output. */
    private static final int MAX_LINE_LENGTH = 76;

    /** The equals sign (=) as a byte. */
    private static final byte EQUALS_SIGN = (byte) '=';

    /** The new line character (\n) as a byte. */
    private static final byte NEW_LINE = (byte) '\n';

    /** Decodabet sentinel: character is ignorable white space. */
    private static final byte WHITE_SPACE_ENC = -5; // Indicates white space in encoding

    /** Decodabet sentinel: character is the '=' padding sign. */
    private static final byte EQUALS_SIGN_ENC = -1; // Indicates equals sign in encoding

    // Lookup-table accessors; each null-checks the dialect argument.
    private static byte[] alphabet(Base64Dialect dialect) {
        return ObjectUtil.checkNotNull(dialect, "dialect").alphabet;
    }

    private static byte[] decodabet(Base64Dialect dialect) {
        return ObjectUtil.checkNotNull(dialect, "dialect").decodabet;
    }

    private static boolean breakLines(Base64Dialect dialect) {
        return ObjectUtil.checkNotNull(dialect, "dialect").breakLinesByDefault;
    }

    /** Encodes the readable bytes of {@code src} using the STANDARD dialect. */
    public static ByteBuf encode(ByteBuf src) {
        return encode(src, Base64Dialect.STANDARD);
    }

    public static ByteBuf encode(ByteBuf src, Base64Dialect dialect) {
        return encode(src, breakLines(dialect), dialect);
    }

    public static ByteBuf encode(ByteBuf src, boolean breakLines) {
        return encode(src, breakLines, Base64Dialect.STANDARD);
    }

    /**
     * Encodes all readable bytes of {@code src}; the reader index is advanced
     * to the writer index, consuming the input.
     */
    public static ByteBuf encode(ByteBuf src, boolean breakLines, Base64Dialect dialect) {
        ObjectUtil.checkNotNull(src, "src");

        ByteBuf dest = encode(src, src.readerIndex(), src.readableBytes(), breakLines, dialect);
        src.readerIndex(src.writerIndex());
        return dest;
    }

    public static ByteBuf encode(ByteBuf src, int off, int len) {
        return encode(src, off, len, Base64Dialect.STANDARD);
    }

    public static ByteBuf encode(ByteBuf src, int off, int len, Base64Dialect dialect) {
        return encode(src, off, len, breakLines(dialect), dialect);
    }

    public static ByteBuf encode(
            ByteBuf src, int off, int len, boolean breakLines) {
        return encode(src, off, len, breakLines, Base64Dialect.STANDARD);
    }

    public static ByteBuf encode(
            ByteBuf src, int off, int len, boolean breakLines, Base64Dialect dialect) {
        return encode(src, off, len, breakLines, dialect, src.alloc());
    }

    /**
     * Encodes {@code len} bytes of {@code src} starting at {@code off} into a
     * freshly allocated buffer. The source is not consumed by this overload.
     *
     * @param breakLines insert a '\n' every {@link #MAX_LINE_LENGTH} output chars
     * @return a slice of the destination buffer trimmed to the encoded length
     */
    public static ByteBuf encode(
            ByteBuf src, int off, int len, boolean breakLines, Base64Dialect dialect, ByteBufAllocator allocator) {
        ObjectUtil.checkNotNull(src, "src");
        ObjectUtil.checkNotNull(dialect, "dialect");

        ByteBuf dest = allocator.buffer(encodedBufferSize(len, breakLines)).order(src.order());
        byte[] alphabet = alphabet(dialect);
        int d = 0;                 // consumed source bytes
        int e = 0;                 // produced output bytes
        int len2 = len - 2;        // main loop stops while 3 full input bytes remain
        int lineLength = 0;
        for (; d < len2; d += 3, e += 4) {
            encode3to4(src, d + off, 3, dest, e, alphabet);

            lineLength += 4;

            if (breakLines && lineLength == MAX_LINE_LENGTH) {
                dest.setByte(e + 4, NEW_LINE);
                e ++;
                lineLength = 0;
            } // end if: end of line
        } // end for: each piece of array

        if (d < len) {
            // 1 or 2 trailing bytes: emit a final padded quartet.
            encode3to4(src, d + off, len - d, dest, e, alphabet);
            e += 4;
        } // end if: some padding needed

        // Remove last byte if it's a newline
        if (e > 1 && dest.getByte(e - 1) == NEW_LINE) {
            e--;
        }

        return dest.slice(0, e);
    }

    private static void encode3to4(
            ByteBuf src, int srcOffset, int numSigBytes, ByteBuf dest, int destOffset, byte[] alphabet) {
        //           1         2         3
        // 01234567890123456789012345678901 Bit position
        // --------000000001111111122222222 Array position from threeBytes
        // --------|    ||    ||    ||    | Six bit groups to index ALPHABET
        //          >>18  >>12  >> 6  >> 0  Right shift necessary
        //                0x3f  0x3f  0x3f  Additional AND

        // Create buffer with zero-padding if there are only one or two
        // significant bytes passed in the array.
        // We have to shift left 24 in order to flush out the 1's that appear
        // when Java treats a value as negative that is cast from a byte to an int.
        if (src.order() == ByteOrder.BIG_ENDIAN) {
            final int inBuff;
            switch (numSigBytes) {
                case 1:
                    inBuff = toInt(src.getByte(srcOffset));
                    break;
                case 2:
                    inBuff = toIntBE(src.getShort(srcOffset));
                    break;
                default:
                    inBuff = numSigBytes <= 0 ? 0 : toIntBE(src.getMedium(srcOffset));
                    break;
            }
            encode3to4BigEndian(inBuff, numSigBytes, dest, destOffset, alphabet);
        } else {
            final int inBuff;
            switch (numSigBytes) {
                case 1:
                    inBuff = toInt(src.getByte(srcOffset));
                    break;
                case 2:
                    inBuff = toIntLE(src.getShort(srcOffset));
                    break;
                default:
                    inBuff = numSigBytes <= 0 ? 0 : toIntLE(src.getMedium(srcOffset));
                    break;
            }
            encode3to4LittleEndian(inBuff, numSigBytes, dest, destOffset, alphabet);
        }
    }

    // package-private for testing
    static int encodedBufferSize(int len, boolean breakLines) {
        // Cast len to long to prevent overflow
        long len43 = ((long) len << 2) / 3;

        // Account for padding
        long ret = (len43 + 3) & ~3;

        if (breakLines) {
            ret += len43 / MAX_LINE_LENGTH;
        }

        return ret < Integer.MAX_VALUE ? (int) ret : Integer.MAX_VALUE;
    }

    // Place a single byte's 8 bits at positions 23..16 of the work int.
    private static int toInt(byte value) {
        return (value & 0xff) << 16;
    }

    // Big-endian short -> bits 23..8 of the work int.
    private static int toIntBE(short value) {
        return (value & 0xff00) << 8 | (value & 0xff) << 8;
    }

    // Little-endian short -> byte-swapped into bits 23..8 of the work int.
    private static int toIntLE(short value) {
        return (value & 0xff) << 16 | (value & 0xff00);
    }

    private static int toIntBE(int mediumValue) {
        return (mediumValue & 0xff0000) | (mediumValue & 0xff00) | (mediumValue & 0xff);
    }

    private static int toIntLE(int mediumValue) {
        return (mediumValue & 0xff) << 16 | (mediumValue & 0xff00) | (mediumValue & 0xff0000) >>> 16;
    }

    private static void encode3to4BigEndian(
            int inBuff, int numSigBytes, ByteBuf dest, int destOffset, byte[] alphabet) {
        // Packing bytes into an int to reduce bound and reference count checking.
        switch (numSigBytes) {
            case 3:
                dest.setInt(destOffset, alphabet[inBuff >>> 18       ] << 24 |
                                        alphabet[inBuff >>> 12 & 0x3f] << 16 |
                                        alphabet[inBuff >>>  6 & 0x3f] <<  8 |
                                        alphabet[inBuff        & 0x3f]);
                break;
            case 2:
                dest.setInt(destOffset, alphabet[inBuff >>> 18       ] << 24 |
                                        alphabet[inBuff >>> 12 & 0x3f] << 16 |
                                        alphabet[inBuff >>> 6  & 0x3f] << 8  |
                                        EQUALS_SIGN);
                break;
            case 1:
                dest.setInt(destOffset, alphabet[inBuff >>> 18       ] << 24 |
                                        alphabet[inBuff >>> 12 & 0x3f] << 16 |
                                        EQUALS_SIGN << 8                     |
                                        EQUALS_SIGN);
                break;
            default:
                // NOOP
                break;
        }
    }

    private static void encode3to4LittleEndian(
            int inBuff, int numSigBytes, ByteBuf dest, int destOffset, byte[] alphabet) {
        // Packing bytes into an int to reduce bound and reference count checking.
        switch (numSigBytes) {
            case 3:
                dest.setInt(destOffset, alphabet[inBuff >>> 18       ]       |
                                        alphabet[inBuff >>> 12 & 0x3f] << 8  |
                                        alphabet[inBuff >>> 6  & 0x3f] << 16 |
                                        alphabet[inBuff        & 0x3f] << 24);
                break;
            case 2:
                dest.setInt(destOffset, alphabet[inBuff >>> 18       ]       |
                                        alphabet[inBuff >>> 12 & 0x3f] << 8  |
                                        alphabet[inBuff >>> 6  & 0x3f] << 16 |
                                        EQUALS_SIGN << 24);
                break;
            case 1:
                dest.setInt(destOffset, alphabet[inBuff >>> 18       ]      |
                                        alphabet[inBuff >>> 12 & 0x3f] << 8 |
                                        EQUALS_SIGN << 16                   |
                                        EQUALS_SIGN << 24);
                break;
            default:
                // NOOP
                break;
        }
    }

    /** Decodes the readable bytes of {@code src} using the STANDARD dialect. */
    public static ByteBuf decode(ByteBuf src) {
        return decode(src, Base64Dialect.STANDARD);
    }

    /**
     * Decodes all readable bytes of {@code src}; the reader index is advanced
     * to the writer index, consuming the input.
     */
    public static ByteBuf decode(ByteBuf src, Base64Dialect dialect) {
        ObjectUtil.checkNotNull(src, "src");

        ByteBuf dest = decode(src, src.readerIndex(), src.readableBytes(), dialect);
        src.readerIndex(src.writerIndex());
        return dest;
    }

    public static ByteBuf decode(
            ByteBuf src, int off, int len) {
        return decode(src, off, len, Base64Dialect.STANDARD);
    }

    public static ByteBuf decode(
            ByteBuf src, int off, int len, Base64Dialect dialect) {
        return decode(src, off, len, dialect, src.alloc());
    }

    /**
     * Decodes {@code len} bytes of {@code src} starting at {@code off} into a
     * freshly allocated buffer. The source is not consumed by this overload.
     *
     * @throws IllegalArgumentException if the input contains a character that
     *         is not valid Base64 in the given dialect
     */
    public static ByteBuf decode(
            ByteBuf src, int off, int len, Base64Dialect dialect, ByteBufAllocator allocator) {
        ObjectUtil.checkNotNull(src, "src");
        ObjectUtil.checkNotNull(dialect, "dialect");

        // Using a ByteProcessor to reduce bound and reference count checking.
        return new Decoder().decode(src, off, len, allocator, dialect);
    }

    // package-private for testing
    static int decodedBufferSize(int len) {
        return len - (len >>> 2);
    }

    /**
     * Stateful one-shot decoder driven as a {@link ByteProcessor}: bytes are
     * accumulated into a quartet buffer and flushed to the destination three
     * output bytes at a time.
     */
    private static final class Decoder implements ByteProcessor {
        private final byte[] b4 = new byte[4];   // current Base64 quartet
        private int b4Posn;                      // fill level of b4
        private byte[] decodabet;                // dialect lookup table
        private int outBuffPosn;                 // write position in dest
        private ByteBuf dest;

        ByteBuf decode(ByteBuf src, int off, int len, ByteBufAllocator allocator, Base64Dialect dialect) {
            dest = allocator.buffer(decodedBufferSize(len)).order(src.order()); // Upper limit on size of output

            decodabet = decodabet(dialect);
            try {
                src.forEachByte(off, len, this);
                return dest.slice(0, outBuffPosn);
            } catch (Throwable cause) {
                // On any failure release the partially written buffer before rethrowing.
                dest.release();
                PlatformDependent.throwException(cause);
                return null;
            }
        }

        @Override
        public boolean process(byte value) throws Exception {
            if (value > 0) {
                byte sbiDecode = decodabet[value];
                if (sbiDecode >= WHITE_SPACE_ENC) { // White space, Equals sign or better
                    if (sbiDecode >= EQUALS_SIGN_ENC) { // Equals sign or better
                        b4[b4Posn ++] = value;
                        if (b4Posn > 3) { // Quartet built
                            outBuffPosn += decode4to3(b4, dest, outBuffPosn, decodabet);
                            b4Posn = 0;

                            // If that was the equals sign, break out of 'for' loop
                            return value != EQUALS_SIGN;
                        }
                    }
                    return true;
                }
            }
            throw new IllegalArgumentException(
                    "invalid Base64 input character: " + (short) (value & 0xFF) + " (decimal)");
        }

        /**
         * Decodes one quartet {@code src[0..3]} into 1-3 bytes at
         * {@code destOffset}, returning how many bytes were written
         * (determined by the position of '=' padding).
         */
        private static int decode4to3(byte[] src, ByteBuf dest, int destOffset, byte[] decodabet) {
            final byte src0 = src[0];
            final byte src1 = src[1];
            final byte src2 = src[2];
            final int decodedValue;
            if (src2 == EQUALS_SIGN) {
                // Example: Dk==
                try {
                    decodedValue = (decodabet[src0] & 0xff) << 2 | (decodabet[src1] & 0xff) >>> 4;
                } catch (IndexOutOfBoundsException ignored) {
                    throw new IllegalArgumentException("not encoded in Base64");
                }
                dest.setByte(destOffset, decodedValue);
                return 1;
            }

            final byte src3 = src[3];
            if (src3 == EQUALS_SIGN) {
                // Example: DkL=
                final byte b1 = decodabet[src1];
                // Packing bytes into a short to reduce bound and reference count checking.
                try {
                    if (dest.order() == ByteOrder.BIG_ENDIAN) {
                        // The decodabet bytes are meant to straddle byte boundaries and so we must carefully mask out
                        // the bits we care about.
                        decodedValue = ((decodabet[src0] & 0x3f) << 2 | (b1 & 0xf0) >> 4) << 8 |
                                       (b1 & 0xf) << 4 | (decodabet[src2] & 0xfc) >>> 2;
                    } else {
                        // This is just a simple byte swap of the operation above.
                        decodedValue = (decodabet[src0] & 0x3f) << 2 | (b1 & 0xf0) >> 4 |
                                       ((b1 & 0xf) << 4 | (decodabet[src2] & 0xfc) >>> 2) << 8;
                    }
                } catch (IndexOutOfBoundsException ignored) {
                    throw new IllegalArgumentException("not encoded in Base64");
                }
                dest.setShort(destOffset, decodedValue);
                return 2;
            }

            // Example: DkLE
            try {
                if (dest.order() == ByteOrder.BIG_ENDIAN) {
                    decodedValue = (decodabet[src0] & 0x3f) << 18 |
                                   (decodabet[src1] & 0xff) << 12 |
                                   (decodabet[src2] & 0xff) << 6 |
                                   decodabet[src3] & 0xff;
                } else {
                    final byte b1 = decodabet[src1];
                    final byte b2 = decodabet[src2];
                    // The goal is to byte swap the BIG_ENDIAN case above. There are 2 interesting things to consider:
                    // 1. We are byte swapping a 3 byte data type. The left and the right byte switch, but the middle
                    //    remains the same.
                    // 2. The contents straddles byte boundaries. This means bytes will be pulled apart during the byte
                    //    swapping process.
                    decodedValue = (decodabet[src0] & 0x3f) << 2 |
                                   // The bottom half of b1 remains in the middle.
                                   (b1 & 0xf) << 12 |
                                   // The top half of b1 are the least significant bits after the swap.
                                   (b1 & 0xf0) >>> 4 |
                                   // The bottom 2 bits of b2 will be the most significant bits after the swap.
                                   (b2 & 0x3) << 22 |
                                   // The remaining 6 bits of b2 remain in the middle.
                                   (b2 & 0xfc) << 6 |
                                   (decodabet[src3] & 0xff) << 16;
                }
            } catch (IndexOutOfBoundsException ignored) {
                throw new IllegalArgumentException("not encoded in Base64");
            }
            dest.setMedium(destOffset, decodedValue);
            return 3;
        }
    }

    private Base64() {
        // Unused
    }
}
+ */ +package io.netty.handler.codec.base64; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.internal.ObjectUtil; + +import java.util.List; + +/** + * Decodes a Base64-encoded {@link ByteBuf} or US-ASCII {@link String} + * into a {@link ByteBuf}. Please note that this decoder must be used + * with a proper {@link ByteToMessageDecoder} such as {@link DelimiterBasedFrameDecoder} + * if you are using a stream-based transport such as TCP/IP. A typical decoder + * setup for TCP/IP would be: + *

+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder", new {@link DelimiterBasedFrameDecoder}(80, {@link Delimiters#nulDelimiter()}));
+ * pipeline.addLast("base64Decoder", new {@link Base64Decoder}());
+ *
+ * // Encoder
+ * pipeline.addLast("base64Encoder", new {@link Base64Encoder}());
+ * 
+ */ +@Sharable +public class Base64Decoder extends MessageToMessageDecoder { + + private final Base64Dialect dialect; + + public Base64Decoder() { + this(Base64Dialect.STANDARD); + } + + public Base64Decoder(Base64Dialect dialect) { + this.dialect = ObjectUtil.checkNotNull(dialect, "dialect"); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + out.add(Base64.decode(msg, msg.readerIndex(), msg.readableBytes(), dialect)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64Dialect.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64Dialect.java new file mode 100644 index 0000000..c40721f --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/Base64Dialect.java @@ -0,0 +1,207 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Written by Robert Harder and released to the public domain, as explained at + * https://creativecommons.org/licenses/publicdomain + */ +package io.netty.handler.codec.base64; + +/** + * Enumeration of supported Base64 dialects. + *

/**
 * Enumeration of supported Base64 dialects.
 * <p>
 * The internal lookup tables in this class have been derived from
 * <a href="http://iharder.sourceforge.net/current/java/base64/">Robert Harder's Public Domain
 * Base64 Encoder/Decoder</a>.
 * <p>
 * Each constant carries three pieces of state: the 64-character encoding
 * alphabet, a 128-entry "decodabet" mapping an ASCII byte to its 6-bit value
 * (with {@code -9} meaning "invalid", {@code -5} meaning "whitespace" and
 * {@code -1} meaning "padding, i.e. '='"), and whether line breaks are
 * inserted by default when encoding.
 */
public enum Base64Dialect {
    /**
     * Standard Base64 encoding as described in Section 3 of
     * <a href="https://www.ietf.org/rfc/rfc3548.txt">RFC 3548</a>.
     */
    STANDARD(new byte[] {
            'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
            'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
            'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '+', '/' },
    new byte[] {
            -9, -9, -9, -9, -9, -9, -9, -9, -9,                   // 0 - 8
            -5, -5,                                               // tab, linefeed
            -9, -9,                                               // 11 - 12
            -5,                                                   // carriage return
            -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9,   // 14 - 26
            -9, -9, -9, -9, -9,                                   // 27 - 31
            -5,                                                   // space
            -9, -9, -9, -9, -9, -9, -9, -9, -9, -9,               // 33 - 42
            62,                                                   // '+' (43)
            -9, -9, -9,                                           // 44 - 46
            63,                                                   // '/' (47)
            52, 53, 54, 55, 56, 57, 58, 59, 60, 61,               // '0' - '9'
            -9, -9, -9,                                           // 58 - 60
            -1,                                                   // '=' (61)
            -9, -9, -9,                                           // 62 - 64
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,         // 'A' - 'N'
            14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,       // 'O' - 'Z'
            -9, -9, -9, -9, -9, -9,                               // 91 - 96
            26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,   // 'a' - 'm'
            39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,   // 'n' - 'z'
            -9, -9, -9, -9, -9                                    // 123 - 127
    }, true),
    /**
     * Base64-like encoding that is URL-safe as described in Section 4 of
     * <a href="https://www.ietf.org/rfc/rfc3548.txt">RFC 3548</a>. It is
     * important to note that data encoded this way is <em>not</em> officially
     * valid Base64, or at the very least should not be called Base64 without
     * also specifying that it was encoded using the URL-safe dialect.
     */
    URL_SAFE(new byte[] {
            'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
            'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
            'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z',
            '0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-', '_' },
    new byte[] {
            -9, -9, -9, -9, -9, -9, -9, -9, -9,                   // 0 - 8
            -5, -5,                                               // tab, linefeed
            -9, -9,                                               // 11 - 12
            -5,                                                   // carriage return
            -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9,   // 14 - 26
            -9, -9, -9, -9, -9,                                   // 27 - 31
            -5,                                                   // space
            -9, -9, -9, -9, -9, -9, -9, -9, -9, -9,               // 33 - 42
            -9,                                                   // '+' not used here
            -9,                                                   // 44
            62,                                                   // '-' (45)
            -9,                                                   // 46
            -9,                                                   // '/' not used here
            52, 53, 54, 55, 56, 57, 58, 59, 60, 61,               // '0' - '9'
            -9, -9, -9,                                           // 58 - 60
            -1,                                                   // '=' (61)
            -9, -9, -9,                                           // 62 - 64
            0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13,         // 'A' - 'N'
            14, 15, 16, 17, 18, 19, 20, 21, 22, 23, 24, 25,       // 'O' - 'Z'
            -9, -9, -9, -9,                                       // 91 - 94
            63,                                                   // '_' (95)
            -9,                                                   // 96
            26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36, 37, 38,   // 'a' - 'm'
            39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50, 51,   // 'n' - 'z'
            -9, -9, -9, -9, -9                                    // 123 - 127
    }, false),
    /**
     * Special "ordered" dialect of Base64 described in
     * <a href="https://www.ietf.org/rfc/rfc1940.txt">RFC 1940</a>.
     */
    ORDERED(new byte[] {
            '-', '0', '1', '2', '3', '4', '5', '6', '7', '8', '9',
            'A', 'B', 'C', 'D', 'E', 'F', 'G', 'H', 'I', 'J', 'K', 'L', 'M',
            'N', 'O', 'P', 'Q', 'R', 'S', 'T', 'U', 'V', 'W', 'X', 'Y', 'Z',
            '_',
            'a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm',
            'n', 'o', 'p', 'q', 'r', 's', 't', 'u', 'v', 'w', 'x', 'y', 'z' },
    new byte[] {
            -9, -9, -9, -9, -9, -9, -9, -9, -9,                   // 0 - 8
            -5, -5,                                               // tab, linefeed
            -9, -9,                                               // 11 - 12
            -5,                                                   // carriage return
            -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9, -9,   // 14 - 26
            -9, -9, -9, -9, -9,                                   // 27 - 31
            -5,                                                   // space
            -9, -9, -9, -9, -9, -9, -9, -9, -9, -9,               // 33 - 42
            -9,                                                   // '+' not used here
            -9,                                                   // 44
            0,                                                    // '-' (45)
            -9,                                                   // 46
            -9,                                                   // '/' not used here
            1, 2, 3, 4, 5, 6, 7, 8, 9, 10,                        // '0' - '9'
            -9, -9, -9,                                           // 58 - 60
            -1,                                                   // '=' (61)
            -9, -9, -9,                                           // 62 - 64
            11, 12, 13, 14, 15, 16, 17, 18, 19, 20, 21, 22, 23,   // 'A' - 'M'
            24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35, 36,   // 'N' - 'Z'
            -9, -9, -9, -9,                                       // 91 - 94
            37,                                                   // '_' (95)
            -9,                                                   // 96
            38, 39, 40, 41, 42, 43, 44, 45, 46, 47, 48, 49, 50,   // 'a' - 'm'
            51, 52, 53, 54, 55, 56, 57, 58, 59, 60, 61, 62, 63,   // 'n' - 'z'
            -9, -9, -9, -9, -9                                    // 123 - 127
    }, true);

    // 64-character encoding alphabet.
    final byte[] alphabet;
    // ASCII byte -> 6-bit value; negative values flag invalid/whitespace/padding.
    final byte[] decodabet;
    // Whether the encoder inserts line breaks by default for this dialect.
    final boolean breakLinesByDefault;

    Base64Dialect(byte[] alphabet, byte[] decodabet, boolean breakLinesByDefault) {
        this.alphabet = alphabet;
        this.decodabet = decodabet;
        this.breakLinesByDefault = breakLinesByDefault;
    }
}
Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.base64; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.util.internal.ObjectUtil; + +import java.util.List; + +/** + * Encodes a {@link ByteBuf} into a Base64-encoded {@link ByteBuf}. + * A typical setup for TCP/IP would be: + *

+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder", new {@link DelimiterBasedFrameDecoder}(80, {@link Delimiters#nulDelimiter()}));
+ * pipeline.addLast("base64Decoder", new {@link Base64Decoder}());
+ *
+ * // Encoder
+ * pipeline.addLast("base64Encoder", new {@link Base64Encoder}());
+ * 
+ */ +@Sharable +public class Base64Encoder extends MessageToMessageEncoder { + + private final boolean breakLines; + private final Base64Dialect dialect; + + public Base64Encoder() { + this(true); + } + + public Base64Encoder(boolean breakLines) { + this(breakLines, Base64Dialect.STANDARD); + } + + public Base64Encoder(boolean breakLines, Base64Dialect dialect) { + this.dialect = ObjectUtil.checkNotNull(dialect, "dialect"); + this.breakLines = breakLines; + } + + @Override + protected void encode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + out.add(Base64.encode(msg, msg.readerIndex(), msg.readableBytes(), breakLines, dialect)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/package-info.java new file mode 100644 index 0000000..96a8549 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/base64/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder and decoder which transform a + * Base64-encoded + * {@link java.lang.String} or {@link io.netty.buffer.ByteBuf} + * into a decoded {@link io.netty.buffer.ByteBuf} and vice versa. 
+ */ +package io.netty.handler.codec.base64; diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java new file mode 100644 index 0000000..d8ef90e --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayDecoder.java @@ -0,0 +1,58 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.bytes; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.LengthFieldPrepender; +import io.netty.handler.codec.MessageToMessageDecoder; + +import java.util.List; + +/** + * Decodes a received {@link ByteBuf} into an array of bytes. + * A typical setup for TCP/IP would be: + *
+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder",
+ *                  new {@link LengthFieldBasedFrameDecoder}(1048576, 0, 4, 0, 4));
+ * pipeline.addLast("bytesDecoder",
+ *                  new {@link ByteArrayDecoder}());
+ *
+ * // Encoder
+ * pipeline.addLast("frameEncoder", new {@link LengthFieldPrepender}(4));
+ * pipeline.addLast("bytesEncoder", new {@link ByteArrayEncoder}());
+ * 
+ * and then you can use an array of bytes instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, byte[] bytes) {
+ *     ...
+ * }
+ * 
+ */ +public class ByteArrayDecoder extends MessageToMessageDecoder { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + // copy the ByteBuf content to a byte array + out.add(ByteBufUtil.getBytes(msg)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayEncoder.java new file mode 100644 index 0000000..7b82c9b --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/ByteArrayEncoder.java @@ -0,0 +1,59 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.bytes; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.LengthFieldPrepender; +import io.netty.handler.codec.MessageToMessageEncoder; + +import java.util.List; + +/** + * Encodes the requested array of bytes into a {@link ByteBuf}. + * A typical setup for TCP/IP would be: + *
+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder",
+ *                  new {@link LengthFieldBasedFrameDecoder}(1048576, 0, 4, 0, 4));
+ * pipeline.addLast("bytesDecoder",
+ *                  new {@link ByteArrayDecoder}());
+ *
+ * // Encoder
+ * pipeline.addLast("frameEncoder", new {@link LengthFieldPrepender}(4));
+ * pipeline.addLast("bytesEncoder", new {@link ByteArrayEncoder}());
+ * 
+ * and then you can use an array of bytes instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, byte[] bytes) {
+ *     ...
+ * }
+ * 
+ */ +@Sharable +public class ByteArrayEncoder extends MessageToMessageEncoder { + @Override + protected void encode(ChannelHandlerContext ctx, byte[] msg, List out) throws Exception { + out.add(Unpooled.wrappedBuffer(msg)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/package-info.java new file mode 100644 index 0000000..1944465 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/bytes/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder and decoder which transform an array of bytes into a + * {@link io.netty.buffer.ByteBuf} and vice versa. + */ +package io.netty.handler.codec.bytes; diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java new file mode 100644 index 0000000..a1b8c3e --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/json/JsonObjectDecoder.java @@ -0,0 +1,237 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec.json; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.channel.ChannelPipeline; + +import java.util.List; + +/** + * Splits a byte stream of JSON objects and arrays into individual objects/arrays and passes them up the + * {@link ChannelPipeline}. + *

+ * The byte stream is expected to be in UTF-8 character encoding or ASCII. The current implementation + * uses direct {@code byte} to {@code char} cast and then compares that {@code char} to a few low range + * ASCII characters like {@code '{'}, {@code '['} or {@code '"'}. UTF-8 is not using low range [0..0x7F] + * byte values for multibyte codepoint representations therefore fully supported by this implementation. + *

+ * This class does not do any real parsing or validation. A sequence of bytes is considered a JSON object/array + * if it contains a matching number of opening and closing braces/brackets. It's up to a subsequent + * {@link ChannelHandler} to parse the JSON text into a more usable form i.e. a POJO. + */ +public class JsonObjectDecoder extends ByteToMessageDecoder { + + private static final int ST_CORRUPTED = -1; + private static final int ST_INIT = 0; + private static final int ST_DECODING_NORMAL = 1; + private static final int ST_DECODING_ARRAY_STREAM = 2; + + private int openBraces; + private int idx; + + private int lastReaderIndex; + + private int state; + private boolean insideString; + + private final int maxObjectLength; + private final boolean streamArrayElements; + + public JsonObjectDecoder() { + // 1 MB + this(1024 * 1024); + } + + public JsonObjectDecoder(int maxObjectLength) { + this(maxObjectLength, false); + } + + public JsonObjectDecoder(boolean streamArrayElements) { + this(1024 * 1024, streamArrayElements); + } + + /** + * @param maxObjectLength maximum number of bytes a JSON object/array may use (including braces and all). + * Objects exceeding this length are dropped and an {@link TooLongFrameException} + * is thrown. + * @param streamArrayElements if set to true and the "top level" JSON object is an array, each of its entries + * is passed through the pipeline individually and immediately after it was fully + * received, allowing for arrays with "infinitely" many elements. 
+ * + */ + public JsonObjectDecoder(int maxObjectLength, boolean streamArrayElements) { + this.maxObjectLength = checkPositive(maxObjectLength, "maxObjectLength"); + this.streamArrayElements = streamArrayElements; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (state == ST_CORRUPTED) { + in.skipBytes(in.readableBytes()); + return; + } + + if (this.idx > in.readerIndex() && lastReaderIndex != in.readerIndex()) { + this.idx = in.readerIndex() + (idx - lastReaderIndex); + } + + // index of next byte to process. + int idx = this.idx; + int wrtIdx = in.writerIndex(); + + if (wrtIdx > maxObjectLength) { + // buffer size exceeded maxObjectLength; discarding the complete buffer. + in.skipBytes(in.readableBytes()); + reset(); + throw new TooLongFrameException( + "object length exceeds " + maxObjectLength + ": " + wrtIdx + " bytes discarded"); + } + + for (/* use current idx */; idx < wrtIdx; idx++) { + byte c = in.getByte(idx); + if (state == ST_DECODING_NORMAL) { + decodeByte(c, in, idx); + + // All opening braces/brackets have been closed. That's enough to conclude + // that the JSON object/array is complete. + if (openBraces == 0) { + ByteBuf json = extractObject(ctx, in, in.readerIndex(), idx + 1 - in.readerIndex()); + if (json != null) { + out.add(json); + } + + // The JSON object/array was extracted => discard the bytes from + // the input buffer. + in.readerIndex(idx + 1); + // Reset the object state to get ready for the next JSON object/text + // coming along the byte stream. + reset(); + } + } else if (state == ST_DECODING_ARRAY_STREAM) { + decodeByte(c, in, idx); + + if (!insideString && (openBraces == 1 && c == ',' || openBraces == 0 && c == ']')) { + // skip leading spaces. No range check is needed and the loop will terminate + // because the byte at position idx is not a whitespace. 
+ for (int i = in.readerIndex(); Character.isWhitespace(in.getByte(i)); i++) { + in.skipBytes(1); + } + + // skip trailing spaces. + int idxNoSpaces = idx - 1; + while (idxNoSpaces >= in.readerIndex() && Character.isWhitespace(in.getByte(idxNoSpaces))) { + idxNoSpaces--; + } + + ByteBuf json = extractObject(ctx, in, in.readerIndex(), idxNoSpaces + 1 - in.readerIndex()); + if (json != null) { + out.add(json); + } + + in.readerIndex(idx + 1); + + if (c == ']') { + reset(); + } + } + // JSON object/array detected. Accumulate bytes until all braces/brackets are closed. + } else if (c == '{' || c == '[') { + initDecoding(c); + + if (state == ST_DECODING_ARRAY_STREAM) { + // Discard the array bracket + in.skipBytes(1); + } + // Discard leading spaces in front of a JSON object/array. + } else if (Character.isWhitespace(c)) { + in.skipBytes(1); + } else { + state = ST_CORRUPTED; + throw new CorruptedFrameException( + "invalid JSON received at byte position " + idx + ": " + ByteBufUtil.hexDump(in)); + } + } + + if (in.readableBytes() == 0) { + this.idx = 0; + } else { + this.idx = idx; + } + this.lastReaderIndex = in.readerIndex(); + } + + /** + * Override this method if you want to filter the json objects/arrays that get passed through the pipeline. + */ + @SuppressWarnings("UnusedParameters") + protected ByteBuf extractObject(ChannelHandlerContext ctx, ByteBuf buffer, int index, int length) { + return buffer.retainedSlice(index, length); + } + + private void decodeByte(byte c, ByteBuf in, int idx) { + if ((c == '{' || c == '[') && !insideString) { + openBraces++; + } else if ((c == '}' || c == ']') && !insideString) { + openBraces--; + } else if (c == '"') { + // start of a new JSON string. It's necessary to detect strings as they may + // also contain braces/brackets and that could lead to incorrect results. 
+ if (!insideString) { + insideString = true; + } else { + int backslashCount = 0; + idx--; + while (idx >= 0) { + if (in.getByte(idx) == '\\') { + backslashCount++; + idx--; + } else { + break; + } + } + // The double quote isn't escaped only if there are even "\"s. + if (backslashCount % 2 == 0) { + // Since the double quote isn't escaped then this is the end of a string. + insideString = false; + } + } + } + } + + private void initDecoding(byte openingBrace) { + openBraces = 1; + if (openingBrace == '[' && streamArrayElements) { + state = ST_DECODING_ARRAY_STREAM; + } else { + state = ST_DECODING_NORMAL; + } + } + + private void reset() { + insideString = false; + state = ST_INIT; + openBraces = 0; + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/json/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/json/package-info.java new file mode 100644 index 0000000..6c2252a --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/json/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * JSON specific codecs. 
+ */ +package io.netty.handler.codec.json; diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/package-info.java new file mode 100644 index 0000000..b25306c --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Extensible decoder and its common implementations which deal with the + * packet fragmentation and reassembly issue found in a stream-based transport + * such as TCP/IP. + */ +package io.netty.handler.codec; diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CachingClassResolver.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CachingClassResolver.java new file mode 100644 index 0000000..825b028 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/CachingClassResolver.java @@ -0,0 +1,46 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.serialization; + +import java.util.Map; + +class CachingClassResolver implements ClassResolver { + + private final Map> classCache; + private final ClassResolver delegate; + + CachingClassResolver(ClassResolver delegate, Map> classCache) { + this.delegate = delegate; + this.classCache = classCache; + } + + @Override + public Class resolve(String className) throws ClassNotFoundException { + // Query the cache first. + Class clazz; + clazz = classCache.get(className); + if (clazz != null) { + return clazz; + } + + // And then try to load. + clazz = delegate.resolve(className); + + classCache.put(className, clazz); + return clazz; + } + +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassLoaderClassResolver.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassLoaderClassResolver.java new file mode 100644 index 0000000..4e7fde6 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassLoaderClassResolver.java @@ -0,0 +1,35 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.serialization; + +class ClassLoaderClassResolver implements ClassResolver { + + private final ClassLoader classLoader; + + ClassLoaderClassResolver(ClassLoader classLoader) { + this.classLoader = classLoader; + } + + @Override + public Class resolve(String className) throws ClassNotFoundException { + try { + return classLoader.loadClass(className); + } catch (ClassNotFoundException ignored) { + return Class.forName(className, false, classLoader); + } + } + +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassResolver.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassResolver.java new file mode 100644 index 0000000..86dbc30 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/ClassResolver.java @@ -0,0 +1,36 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

/**
 * Resolves a {@link Class} from its fully qualified name.
 * Please use {@link ClassResolvers} as instance factory.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
public interface ClassResolver {

    /**
     * Returns the {@link Class} for the given name.
     *
     * @throws ClassNotFoundException if no class with that name can be found
     */
    Class<?> resolve(String className) throws ClassNotFoundException;
}

/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import io.netty.util.internal.PlatformDependent;

import java.lang.ref.Reference;
import java.util.HashMap;

/**
 * Factory methods for creating {@link ClassResolver} instances.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
public final class ClassResolvers {

    /**
     * Cache disabled.
     *
     * @param classLoader - specific classLoader to use, or null if you want to revert to default
     * @return new instance of class resolver
     */
    public static ClassResolver cacheDisabled(ClassLoader classLoader) {
        return new ClassLoaderClassResolver(defaultClassLoader(classLoader));
    }

    /**
     * Non-aggressive non-concurrent cache; good for a non-shared default cache.
     *
     * @param classLoader - specific classLoader to use, or null if you want to revert to default
     * @return new instance of class resolver
     */
    public static ClassResolver weakCachingResolver(ClassLoader classLoader) {
        return new CachingClassResolver(
                new ClassLoaderClassResolver(defaultClassLoader(classLoader)),
                new WeakReferenceMap<String, Class<?>>(new HashMap<String, Reference<Class<?>>>()));
    }

    /**
     * Aggressive non-concurrent cache; good for a non-shared cache,
     * when we're not worried about class unloading.
     *
     * @param classLoader - specific classLoader to use, or null if you want to revert to default
     * @return new instance of class resolver
     */
    public static ClassResolver softCachingResolver(ClassLoader classLoader) {
        return new CachingClassResolver(
                new ClassLoaderClassResolver(defaultClassLoader(classLoader)),
                new SoftReferenceMap<String, Class<?>>(new HashMap<String, Reference<Class<?>>>()));
    }

    /**
     * Non-aggressive concurrent cache; good for a shared cache,
     * when we're worried about class unloading.
     *
     * @param classLoader - specific classLoader to use, or null if you want to revert to default
     * @return new instance of class resolver
     */
    public static ClassResolver weakCachingConcurrentResolver(ClassLoader classLoader) {
        return new CachingClassResolver(
                new ClassLoaderClassResolver(defaultClassLoader(classLoader)),
                new WeakReferenceMap<String, Class<?>>(
                        PlatformDependent.<String, Reference<Class<?>>>newConcurrentHashMap()));
    }

    /**
     * Aggressive concurrent cache; good for a shared cache,
     * when we're not worried about class unloading.
     *
     * @param classLoader - specific classLoader to use, or null if you want to revert to default
     * @return new instance of class resolver
     */
    public static ClassResolver softCachingConcurrentResolver(ClassLoader classLoader) {
        return new CachingClassResolver(
                new ClassLoaderClassResolver(defaultClassLoader(classLoader)),
                new SoftReferenceMap<String, Class<?>>(
                        PlatformDependent.<String, Reference<Class<?>>>newConcurrentHashMap()));
    }

    /**
     * Returns {@code classLoader} if non-null, otherwise the thread context
     * class loader, otherwise the loader of this class.
     */
    static ClassLoader defaultClassLoader(ClassLoader classLoader) {
        if (classLoader != null) {
            return classLoader;
        }

        final ClassLoader contextClassLoader = PlatformDependent.getContextClassLoader();
        if (contextClassLoader != null) {
            return contextClassLoader;
        }

        return PlatformDependent.getClassLoader(ClassResolvers.class);
    }

    private ClassResolvers() {
        // Unused
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import java.io.EOFException;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInputStream;
import java.io.ObjectStreamClass;
import java.io.StreamCorruptedException;

/**
 * An {@link ObjectInputStream} that reads the compact wire format produced by
 * {@link CompactObjectOutputStream}: a one-byte stream header and either a
 * full ("fat") class descriptor or just the class name ("thin"), resolved via
 * the supplied {@link ClassResolver}.
 */
class CompactObjectInputStream extends ObjectInputStream {

    private final ClassResolver classResolver;

    CompactObjectInputStream(InputStream in, ClassResolver classResolver) throws IOException {
        super(in);
        this.classResolver = classResolver;
    }

    @Override
    protected void readStreamHeader() throws IOException {
        // CompactObjectOutputStream writes only a single version byte instead
        // of the standard magic + version header.
        int version = readByte() & 0xFF;
        if (version != STREAM_VERSION) {
            throw new StreamCorruptedException(
                    "Unsupported version: " + version);
        }
    }

    @Override
    protected ObjectStreamClass readClassDescriptor()
            throws IOException, ClassNotFoundException {
        int type = read();
        if (type < 0) {
            throw new EOFException();
        }
        switch (type) {
        case CompactObjectOutputStream.TYPE_FAT_DESCRIPTOR:
            return super.readClassDescriptor();
        case CompactObjectOutputStream.TYPE_THIN_DESCRIPTOR:
            // Thin descriptors carry only the class name; rebuild the
            // descriptor from the locally resolved class.
            String className = readUTF();
            Class<?> clazz = classResolver.resolve(className);
            return ObjectStreamClass.lookupAny(clazz);
        default:
            throw new StreamCorruptedException(
                    "Unexpected class descriptor type: " + type);
        }
    }

    @Override
    protected Class<?> resolveClass(ObjectStreamClass desc) throws IOException, ClassNotFoundException {
        Class<?> clazz;
        try {
            clazz = classResolver.resolve(desc.getName());
        } catch (ClassNotFoundException ignored) {
            // Fall back to the default resolution strategy.
            clazz = super.resolveClass(desc);
        }

        return clazz;
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.ObjectStreamClass;
import java.io.OutputStream;

/**
 * An {@link ObjectOutputStream} that produces a compact wire format: a
 * one-byte stream header, and for ordinary classes only the class name
 * ("thin" descriptor) instead of the full serialized class descriptor.
 * Read back with {@link CompactObjectInputStream}.
 */
class CompactObjectOutputStream extends ObjectOutputStream {

    /** Marker: a full (standard) class descriptor follows. */
    static final int TYPE_FAT_DESCRIPTOR = 0;
    /** Marker: only the class name follows; the reader resolves it locally. */
    static final int TYPE_THIN_DESCRIPTOR = 1;

    CompactObjectOutputStream(OutputStream out) throws IOException {
        super(out);
    }

    @Override
    protected void writeStreamHeader() throws IOException {
        // Write just the version byte; the standard magic number is omitted.
        writeByte(STREAM_VERSION);
    }

    @Override
    protected void writeClassDescriptor(ObjectStreamClass desc) throws IOException {
        Class<?> clazz = desc.forClass();
        if (clazz.isPrimitive() || clazz.isArray() || clazz.isInterface() ||
            desc.getSerialVersionUID() == 0) {
            // These cases cannot be reconstructed from the name alone on the
            // receiving side, so keep the full descriptor.
            write(TYPE_FAT_DESCRIPTOR);
            super.writeClassDescriptor(desc);
        } else {
            write(TYPE_THIN_DESCRIPTOR);
            writeUTF(desc.getName());
        }
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.Serializable;

/**
 * An encoder which serializes a Java object into a {@link ByteBuf}
 * (interoperability version).
 * <p>
 * This encoder is interoperable with the standard Java object streams such as
 * {@link ObjectInputStream} and {@link ObjectOutputStream}.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
public class CompatibleObjectEncoder extends MessageToByteEncoder<Serializable> {
    private final int resetInterval;
    private int writtenObjects;

    /**
     * Creates a new instance with the reset interval of {@code 16}.
     */
    public CompatibleObjectEncoder() {
        this(16); // Reset at every sixteen writes
    }

    /**
     * Creates a new instance.
     *
     * @param resetInterval
     *        the number of objects between {@link ObjectOutputStream#reset()}.
     *        {@code 0} will disable resetting the stream, but the remote
     *        peer will be at the risk of getting {@link OutOfMemoryError} in
     *        the long term.
     */
    public CompatibleObjectEncoder(int resetInterval) {
        this.resetInterval = checkPositiveOrZero(resetInterval, "resetInterval");
    }

    /**
     * Creates a new {@link ObjectOutputStream} which wraps the specified
     * {@link OutputStream}. Override this method to use a subclass of the
     * {@link ObjectOutputStream}.
     */
    protected ObjectOutputStream newObjectOutputStream(OutputStream out) throws Exception {
        return new ObjectOutputStream(out);
    }

    @Override
    protected void encode(ChannelHandlerContext ctx, Serializable msg, ByteBuf out) throws Exception {
        // Suppress a warning about resource leak since oos is closed below
        ObjectOutputStream oos = newObjectOutputStream(
                new ByteBufOutputStream(out));
        try {
            if (resetInterval != 0) {
                // Resetting will prevent OOM on the receiving side: it clears
                // the back-reference table that would otherwise pin every
                // object ever written.
                writtenObjects ++;
                if (writtenObjects % resetInterval == 0) {
                    oos.reset();
                }
            }

            oos.writeObject(msg);
            oos.flush();
        } finally {
            oos.close();
        }
    }
}

/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufInputStream;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.LengthFieldBasedFrameDecoder;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.StreamCorruptedException;

/**
 * A decoder which deserializes the received {@link ByteBuf}s into Java
 * objects.
 * <p>
 * Please note that the serialized form this decoder expects is not
 * compatible with the standard {@link ObjectOutputStream}. Please use
 * {@link ObjectEncoder} or {@link ObjectEncoderOutputStream} to ensure the
 * interoperability with this decoder.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
public class ObjectDecoder extends LengthFieldBasedFrameDecoder {

    private final ClassResolver classResolver;

    /**
     * Creates a new decoder whose maximum object size is {@code 1048576}
     * bytes. If the size of the received object is greater than
     * {@code 1048576} bytes, a {@link StreamCorruptedException} will be
     * raised.
     *
     * @param classResolver the {@link ClassResolver} to use for this decoder
     */
    public ObjectDecoder(ClassResolver classResolver) {
        this(1048576, classResolver);
    }

    /**
     * Creates a new decoder with the specified maximum object size.
     *
     * @param maxObjectSize the maximum byte length of the serialized object.
     *                      if the length of the received object is greater
     *                      than this value, {@link StreamCorruptedException}
     *                      will be raised.
     * @param classResolver the {@link ClassResolver} which will load the class
     *                      of the serialized object
     */
    public ObjectDecoder(int maxObjectSize, ClassResolver classResolver) {
        // Frames are a 4-byte length prefix followed by the payload; the
        // prefix is stripped before decode() sees the frame.
        super(maxObjectSize, 0, 4, 0, 4);
        this.classResolver = classResolver;
    }

    @Override
    protected Object decode(ChannelHandlerContext ctx, ByteBuf in) throws Exception {
        ByteBuf frame = (ByteBuf) super.decode(ctx, in);
        if (frame == null) {
            // Not enough data accumulated yet for a whole frame.
            return null;
        }

        // The 'true' flag makes the stream release the frame on close().
        ObjectInputStream ois = new CompactObjectInputStream(new ByteBufInputStream(frame, true), classResolver);
        try {
            return ois.readObject();
        } finally {
            ois.close();
        }
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import io.netty.util.internal.ObjectUtil;

import java.io.BufferedReader;
import java.io.DataInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.ObjectInput;
import java.io.StreamCorruptedException;

/**
 * An {@link ObjectInput} which is interoperable with {@link ObjectEncoder}
 * and {@link ObjectEncoderOutputStream}.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
public class ObjectDecoderInputStream extends InputStream implements
        ObjectInput {

    private final DataInputStream in;
    private final int maxObjectSize;
    private final ClassResolver classResolver;

    /**
     * Creates a new {@link ObjectInput}.
     *
     * @param in
     *        the {@link InputStream} where the serialized form will be
     *        read from
     */
    public ObjectDecoderInputStream(InputStream in) {
        this(in, null);
    }

    /**
     * Creates a new {@link ObjectInput}.
     *
     * @param in
     *        the {@link InputStream} where the serialized form will be
     *        read from
     * @param classLoader
     *        the {@link ClassLoader} which will load the class of the
     *        serialized object
     */
    public ObjectDecoderInputStream(InputStream in, ClassLoader classLoader) {
        this(in, classLoader, 1048576);
    }

    /**
     * Creates a new {@link ObjectInput}.
     *
     * @param in
     *        the {@link InputStream} where the serialized form will be
     *        read from
     * @param maxObjectSize
     *        the maximum byte length of the serialized object. if the length
     *        of the received object is greater than this value,
     *        a {@link StreamCorruptedException} will be raised.
     */
    public ObjectDecoderInputStream(InputStream in, int maxObjectSize) {
        this(in, null, maxObjectSize);
    }

    /**
     * Creates a new {@link ObjectInput}.
     *
     * @param in
     *        the {@link InputStream} where the serialized form will be
     *        read from
     * @param classLoader
     *        the {@link ClassLoader} which will load the class of the
     *        serialized object
     * @param maxObjectSize
     *        the maximum byte length of the serialized object. if the length
     *        of the received object is greater than this value,
     *        a {@link StreamCorruptedException} will be raised.
     */
    public ObjectDecoderInputStream(InputStream in, ClassLoader classLoader, int maxObjectSize) {
        ObjectUtil.checkNotNull(in, "in");
        ObjectUtil.checkPositive(maxObjectSize, "maxObjectSize");

        if (in instanceof DataInputStream) {
            this.in = (DataInputStream) in;
        } else {
            this.in = new DataInputStream(in);
        }
        classResolver = ClassResolvers.weakCachingResolver(classLoader);
        this.maxObjectSize = maxObjectSize;
    }

    /**
     * Reads a length-prefixed serialized object written by
     * {@link ObjectEncoderOutputStream} or {@link ObjectEncoder}.
     *
     * @throws StreamCorruptedException if the length prefix is non-positive
     *         or exceeds {@code maxObjectSize}
     */
    @Override
    public Object readObject() throws ClassNotFoundException, IOException {
        int dataLen = readInt();
        if (dataLen <= 0) {
            throw new StreamCorruptedException("invalid data length: " + dataLen);
        }
        if (dataLen > maxObjectSize) {
            throw new StreamCorruptedException(
                    "data length too big: " + dataLen + " (max: " + maxObjectSize + ')');
        }

        return new CompactObjectInputStream(in, classResolver).readObject();
    }

    @Override
    public int available() throws IOException {
        return in.available();
    }

    @Override
    public void close() throws IOException {
        in.close();
    }

    // Suppress a warning since the class is not thread-safe
    @Override
    public void mark(int readlimit) {
        in.mark(readlimit);
    }

    @Override
    public boolean markSupported() {
        return in.markSupported();
    }

    // Suppress a warning since the class is not thread-safe
    @Override
    public int read() throws IOException {
        return in.read();
    }

    @Override
    public final int read(byte[] b, int off, int len) throws IOException {
        return in.read(b, off, len);
    }

    @Override
    public final int read(byte[] b) throws IOException {
        return in.read(b);
    }

    @Override
    public final boolean readBoolean() throws IOException {
        return in.readBoolean();
    }

    @Override
    public final byte readByte() throws IOException {
        return in.readByte();
    }

    @Override
    public final char readChar() throws IOException {
        return in.readChar();
    }

    @Override
    public final double readDouble() throws IOException {
        return in.readDouble();
    }

    @Override
    public final float readFloat() throws IOException {
        return in.readFloat();
    }

    @Override
    public final void readFully(byte[] b, int off, int len) throws IOException {
        in.readFully(b, off, len);
    }

    @Override
    public final void readFully(byte[] b) throws IOException {
        in.readFully(b);
    }

    @Override
    public final int readInt() throws IOException {
        return in.readInt();
    }

    /**
     * @deprecated Use {@link BufferedReader#readLine()} instead.
     */
    @Override
    @Deprecated
    public final String readLine() throws IOException {
        return in.readLine();
    }

    @Override
    public final long readLong() throws IOException {
        return in.readLong();
    }

    @Override
    public final short readShort() throws IOException {
        return in.readShort();
    }

    @Override
    public final int readUnsignedByte() throws IOException {
        return in.readUnsignedByte();
    }

    @Override
    public final int readUnsignedShort() throws IOException {
        return in.readUnsignedShort();
    }

    @Override
    public final String readUTF() throws IOException {
        return in.readUTF();
    }

    @Override
    public void reset() throws IOException {
        in.reset();
    }

    @Override
    public long skip(long n) throws IOException {
        return in.skip(n);
    }

    @Override
    public final int skipBytes(int n) throws IOException {
        return in.skipBytes(n);
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.channel.ChannelHandler.Sharable;
import io.netty.channel.ChannelHandlerContext;
import io.netty.handler.codec.MessageToByteEncoder;

import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.io.Serializable;

/**
 * An encoder which serializes a Java object into a {@link ByteBuf}.
 * <p>
 * Please note that the serialized form this encoder produces is not
 * compatible with the standard {@link ObjectInputStream}. Please use
 * {@link ObjectDecoder} or {@link ObjectDecoderInputStream} to ensure the
 * interoperability with this encoder.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
@Sharable
public class ObjectEncoder extends MessageToByteEncoder<Serializable> {
    // Reserves 4 bytes for the length prefix, patched in after serialization.
    private static final byte[] LENGTH_PLACEHOLDER = new byte[4];

    @Override
    protected void encode(ChannelHandlerContext ctx, Serializable msg, ByteBuf out) throws Exception {
        int startIdx = out.writerIndex();

        ByteBufOutputStream bout = new ByteBufOutputStream(out);
        ObjectOutputStream oout = null;
        try {
            bout.write(LENGTH_PLACEHOLDER);
            oout = new CompactObjectOutputStream(bout);
            oout.writeObject(msg);
            oout.flush();
        } finally {
            if (oout != null) {
                oout.close();
            } else {
                // CompactObjectOutputStream construction failed; still close
                // the underlying stream.
                bout.close();
            }
        }

        int endIdx = out.writerIndex();

        // Back-patch the payload length (excluding the 4-byte prefix itself).
        out.setInt(startIdx, endIdx - startIdx - 4);
    }
}
/*
 * Copyright 2012 The Netty Project
 *
 * Licensed under the Apache License, Version 2.0 (https://www.apache.org/licenses/LICENSE-2.0).
 */
package io.netty.handler.codec.serialization;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.ByteBufOutputStream;
import io.netty.buffer.Unpooled;
import io.netty.util.internal.ObjectUtil;

import java.io.DataOutputStream;
import java.io.IOException;
import java.io.ObjectOutput;
import java.io.ObjectOutputStream;
import java.io.OutputStream;

/**
 * An {@link ObjectOutput} which is interoperable with {@link ObjectDecoder}
 * and {@link ObjectDecoderInputStream}.
 * <p>
 * <strong>Security:</strong> serialization can be a security liability,
 * and should not be used without defining a list of classes that are
 * allowed to be deserialized. Such a list can be specified with the
 * <tt>jdk.serialFilter</tt> system property, for instance.
 * See the <a href="https://docs.oracle.com/en/java/javase/17/core/serialization-filtering1.html">
 * serialization filtering</a> article for more information.
 *
 * @deprecated This class has been deprecated with no replacement,
 *             because serialization can be a security liability
 */
@Deprecated
public class ObjectEncoderOutputStream extends OutputStream implements
        ObjectOutput {

    private final DataOutputStream out;
    private final int estimatedLength;

    /**
     * Creates a new {@link ObjectOutput} with the estimated length of 512
     * bytes.
     *
     * @param out
     *        the {@link OutputStream} where the serialized form will be
     *        written out
     */
    public ObjectEncoderOutputStream(OutputStream out) {
        this(out, 512);
    }

    /**
     * Creates a new {@link ObjectOutput}.
     *
     * @param out
     *        the {@link OutputStream} where the serialized form will be
     *        written out
     *
     * @param estimatedLength
     *        the estimated byte length of the serialized form of an object.
     *        If the length of the serialized form exceeds this value, the
     *        internal buffer will be expanded automatically at the cost of
     *        memory bandwidth. If this value is too big, it will also waste
     *        memory bandwidth. To avoid unnecessary memory copy or allocation
     *        cost, please specify the properly estimated value.
     */
    public ObjectEncoderOutputStream(OutputStream out, int estimatedLength) {
        ObjectUtil.checkNotNull(out, "out");
        ObjectUtil.checkPositiveOrZero(estimatedLength, "estimatedLength");

        if (out instanceof DataOutputStream) {
            this.out = (DataOutputStream) out;
        } else {
            this.out = new DataOutputStream(out);
        }
        this.estimatedLength = estimatedLength;
    }

    /**
     * Serializes {@code obj} into a temporary buffer, then writes a 4-byte
     * length prefix followed by the serialized form.
     */
    @Override
    public void writeObject(Object obj) throws IOException {
        ByteBuf buf = Unpooled.buffer(estimatedLength);
        try {
            // Suppress a warning about resource leak since oout is closed below
            ObjectOutputStream oout = new CompactObjectOutputStream(
                    new ByteBufOutputStream(buf));
            try {
                oout.writeObject(obj);
                oout.flush();
            } finally {
                oout.close();
            }

            int objectSize = buf.readableBytes();
            writeInt(objectSize);
            buf.getBytes(0, this, objectSize);
        } finally {
            // Release the temporary buffer even on failure.
            buf.release();
        }
    }

    @Override
    public void write(int b) throws IOException {
        out.write(b);
    }

    @Override
    public void close() throws IOException {
        out.close();
    }

    @Override
    public void flush() throws IOException {
        out.flush();
    }

    public final int size() {
        return out.size();
    }

    @Override
    public void write(byte[] b, int off, int len) throws IOException {
        out.write(b, off, len);
    }

    @Override
    public void write(byte[] b) throws IOException {
        out.write(b);
    }

    @Override
    public final void writeBoolean(boolean v) throws IOException {
        out.writeBoolean(v);
    }

    @Override
    public final void writeByte(int v) throws IOException {
        out.writeByte(v);
    }

    @Override
    public final void writeBytes(String s) throws IOException {
        out.writeBytes(s);
    }

    @Override
    public final void writeChar(int v) throws IOException {
        out.writeChar(v);
    }

    @Override
    public final void writeChars(String s) throws IOException {
        out.writeChars(s);
    }

    @Override
    public final void writeDouble(double v) throws IOException {
        out.writeDouble(v);
    }

    @Override
    public final void writeFloat(float v) throws IOException {
        out.writeFloat(v);
    }

    @Override
    public final void writeInt(int v) throws IOException {
        out.writeInt(v);
    }

    @Override
    public final void writeLong(long v) throws IOException {
        out.writeLong(v);
    }

    @Override
    public final void writeShort(int v) throws IOException {
        out.writeShort(v);
    }

    @Override
    public final void writeUTF(String str) throws IOException {
        out.writeUTF(str);
    }
}
+ */ +package io.netty.handler.codec.serialization; + +import java.lang.ref.Reference; +import java.util.Collection; +import java.util.Map; +import java.util.Set; + +abstract class ReferenceMap implements Map { + + private final Map> delegate; + + protected ReferenceMap(Map> delegate) { + this.delegate = delegate; + } + + abstract Reference fold(V value); + + private V unfold(Reference ref) { + if (ref == null) { + return null; + } + + return ref.get(); + } + + @Override + public int size() { + return delegate.size(); + } + + @Override + public boolean isEmpty() { + return delegate.isEmpty(); + } + + @Override + public boolean containsKey(Object key) { + return delegate.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public V get(Object key) { + return unfold(delegate.get(key)); + } + + @Override + public V put(K key, V value) { + return unfold(delegate.put(key, fold(value))); + } + + @Override + public V remove(Object key) { + return unfold(delegate.remove(key)); + } + + @Override + public void putAll(Map m) { + for (Entry entry : m.entrySet()) { + delegate.put(entry.getKey(), fold(entry.getValue())); + } + } + + @Override + public void clear() { + delegate.clear(); + } + + @Override + public Set keySet() { + return delegate.keySet(); + } + + @Override + public Collection values() { + throw new UnsupportedOperationException(); + } + + @Override + public Set> entrySet() { + throw new UnsupportedOperationException(); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/SoftReferenceMap.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/SoftReferenceMap.java new file mode 100644 index 0000000..afe2b97 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/SoftReferenceMap.java @@ -0,0 +1,33 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses 
this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.serialization; + +import java.lang.ref.Reference; +import java.lang.ref.SoftReference; +import java.util.Map; + +final class SoftReferenceMap extends ReferenceMap { + + SoftReferenceMap(Map> delegate) { + super(delegate); + } + + @Override + Reference fold(V value) { + return new SoftReference(value); + } + +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/WeakReferenceMap.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/WeakReferenceMap.java new file mode 100644 index 0000000..7fc2f45 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/WeakReferenceMap.java @@ -0,0 +1,33 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.serialization; + +import java.lang.ref.Reference; +import java.lang.ref.WeakReference; +import java.util.Map; + +final class WeakReferenceMap extends ReferenceMap { + + WeakReferenceMap(Map> delegate) { + super(delegate); + } + + @Override + Reference fold(V value) { + return new WeakReference(value); + } + +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/package-info.java new file mode 100644 index 0000000..30473f1 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/serialization/package-info.java @@ -0,0 +1,32 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder, decoder and their compatibility stream implementations which + * transform a {@link java.io.Serializable} object into a byte buffer and + * vice versa. + *

+ * Security: serialization can be a security liability, + * and should not be used without defining a list of classes that are + * allowed to be desirialized. Such a list can be specified with the + * jdk.serialFilter system property, for instance. + * See the + * serialization filtering article for more information. + * + * @deprecated This package has been deprecated with no replacement, + * because serialization can be a security liability + */ +package io.netty.handler.codec.serialization; diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineEncoder.java new file mode 100644 index 0000000..75d327b --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineEncoder.java @@ -0,0 +1,94 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; + +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.util.List; + +/** + * Apply a line separator to the requested {@link String} and encode it into a {@link ByteBuf}. + * A typical setup for a text-based line protocol in a TCP/IP socket would be: + *

+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder", new {@link LineBasedFrameDecoder}(80));
+ * pipeline.addLast("stringDecoder", new {@link StringDecoder}(CharsetUtil.UTF_8));
+ *
+ * // Encoder
+ * pipeline.addLast("lineEncoder", new {@link LineEncoder}(LineSeparator.UNIX, CharsetUtil.UTF_8));
+ * 
+ * and then you can use a {@link String} instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, {@link String} msg) {
+ *     ch.write("Did you say '" + msg + "'?");
+ * }
+ * 
+ */ +@Sharable +public class LineEncoder extends MessageToMessageEncoder { + + private final Charset charset; + private final byte[] lineSeparator; + + /** + * Creates a new instance with the current system line separator and UTF-8 charset encoding. + */ + public LineEncoder() { + this(LineSeparator.DEFAULT, CharsetUtil.UTF_8); + } + + /** + * Creates a new instance with the specified line separator and UTF-8 charset encoding. + */ + public LineEncoder(LineSeparator lineSeparator) { + this(lineSeparator, CharsetUtil.UTF_8); + } + + /** + * Creates a new instance with the specified character set. + */ + public LineEncoder(Charset charset) { + this(LineSeparator.DEFAULT, charset); + } + + /** + * Creates a new instance with the specified line separator and character set. + */ + public LineEncoder(LineSeparator lineSeparator, Charset charset) { + this.charset = ObjectUtil.checkNotNull(charset, "charset"); + this.lineSeparator = ObjectUtil.checkNotNull(lineSeparator, "lineSeparator").value().getBytes(charset); + } + + @Override + protected void encode(ChannelHandlerContext ctx, CharSequence msg, List out) throws Exception { + ByteBuf buffer = ByteBufUtil.encodeString(ctx.alloc(), CharBuffer.wrap(msg), charset, lineSeparator.length); + buffer.writeBytes(lineSeparator); + out.add(buffer); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineSeparator.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineSeparator.java new file mode 100644 index 0000000..35759d8 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/LineSeparator.java @@ -0,0 +1,83 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBufUtil; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +/** + * A class to represent line separators in different environments. + */ +public final class LineSeparator { + + /** + * The default line separator in the current system. + */ + public static final LineSeparator DEFAULT = new LineSeparator(StringUtil.NEWLINE); + + /** + * The Unix line separator(LF) + */ + public static final LineSeparator UNIX = new LineSeparator("\n"); + + /** + * The Windows line separator(CRLF) + */ + public static final LineSeparator WINDOWS = new LineSeparator("\r\n"); + + private final String value; + + /** + * Create {@link LineSeparator} with the specified {@code lineSeparator} string. + */ + public LineSeparator(String lineSeparator) { + this.value = ObjectUtil.checkNotNull(lineSeparator, "lineSeparator"); + } + + /** + * Return the string value of this line separator. + */ + public String value() { + return value; + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof LineSeparator)) { + return false; + } + LineSeparator that = (LineSeparator) o; + return value != null ? value.equals(that.value) : that.value == null; + } + + @Override + public int hashCode() { + return value != null ? value.hashCode() : 0; + } + + /** + * Return a hex dump of the line separator in UTF-8 encoding. 
+ */ + @Override + public String toString() { + return ByteBufUtil.hexDump(value.getBytes(CharsetUtil.UTF_8)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringDecoder.java new file mode 100644 index 0000000..b043c6b --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringDecoder.java @@ -0,0 +1,79 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.MessageToMessageDecoder; +import io.netty.util.internal.ObjectUtil; + +import java.nio.charset.Charset; +import java.util.List; + +/** + * Decodes a received {@link ByteBuf} into a {@link String}. Please + * note that this decoder must be used with a proper {@link ByteToMessageDecoder} + * such as {@link DelimiterBasedFrameDecoder} or {@link LineBasedFrameDecoder} + * if you are using a stream-based transport such as TCP/IP. 
A typical setup for a + * text-based line protocol in a TCP/IP socket would be: + *
+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder", new {@link LineBasedFrameDecoder}(80));
+ * pipeline.addLast("stringDecoder", new {@link StringDecoder}(CharsetUtil.UTF_8));
+ *
+ * // Encoder
+ * pipeline.addLast("stringEncoder", new {@link StringEncoder}(CharsetUtil.UTF_8));
+ * 
+ * and then you can use a {@link String} instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, {@link String} msg) {
+ *     ch.write("Did you say '" + msg + "'?\n");
+ * }
+ * 
+ */ +@Sharable +public class StringDecoder extends MessageToMessageDecoder { + + // TODO Use CharsetDecoder instead. + private final Charset charset; + + /** + * Creates a new instance with the current system character set. + */ + public StringDecoder() { + this(Charset.defaultCharset()); + } + + /** + * Creates a new instance with the specified character set. + */ + public StringDecoder(Charset charset) { + this.charset = ObjectUtil.checkNotNull(charset, "charset"); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + out.add(msg.toString(charset)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringEncoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringEncoder.java new file mode 100644 index 0000000..53d7968 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/StringEncoder.java @@ -0,0 +1,79 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.codec.LineBasedFrameDecoder; +import io.netty.handler.codec.MessageToMessageEncoder; +import io.netty.util.internal.ObjectUtil; + +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.util.List; + +/** + * Encodes the requested {@link String} into a {@link ByteBuf}. + * A typical setup for a text-based line protocol in a TCP/IP socket would be: + *
+ * {@link ChannelPipeline} pipeline = ...;
+ *
+ * // Decoders
+ * pipeline.addLast("frameDecoder", new {@link LineBasedFrameDecoder}(80));
+ * pipeline.addLast("stringDecoder", new {@link StringDecoder}(CharsetUtil.UTF_8));
+ *
+ * // Encoder
+ * pipeline.addLast("stringEncoder", new {@link StringEncoder}(CharsetUtil.UTF_8));
+ * 
+ * and then you can use a {@link String} instead of a {@link ByteBuf} + * as a message: + *
+ * void channelRead({@link ChannelHandlerContext} ctx, {@link String} msg) {
+ *     ch.write("Did you say '" + msg + "'?\n");
+ * }
+ * 
+ */ +@Sharable +public class StringEncoder extends MessageToMessageEncoder { + + private final Charset charset; + + /** + * Creates a new instance with the current system character set. + */ + public StringEncoder() { + this(Charset.defaultCharset()); + } + + /** + * Creates a new instance with the specified character set. + */ + public StringEncoder(Charset charset) { + this.charset = ObjectUtil.checkNotNull(charset, "charset"); + } + + @Override + protected void encode(ChannelHandlerContext ctx, CharSequence msg, List out) throws Exception { + if (msg.length() == 0) { + return; + } + + out.add(ByteBufUtil.encodeString(ctx.alloc(), CharBuffer.wrap(msg), charset)); + } +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/string/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/package-info.java new file mode 100644 index 0000000..ba2b09b --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/string/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Encoder and decoder which transform a {@link java.lang.String} into a + * {@link io.netty.buffer.ByteBuf} and vice versa. 
+ */ +package io.netty.handler.codec.string; diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java new file mode 100644 index 0000000..ed169a6 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/xml/XmlFrameDecoder.java @@ -0,0 +1,245 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.xml; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.TooLongFrameException; + +import java.util.List; + +/** + * A frame decoder for single separate XML based message streams. + *

+ * A couple examples will better help illustrate + * what this decoder actually does. + *

+ * Given an input array of bytes split over 3 frames like this: + *

+ * +-----+-----+-----------+
+ * | <an | Xml | Element/> |
+ * +-----+-----+-----------+
+ * 
+ *

+ * this decoder would output a single frame: + *

+ *

+ * +-----------------+
+ * | <anXmlElement/> |
+ * +-----------------+
+ * 
+ * + * Given an input array of bytes split over 5 frames like this: + *
+ * +-----+-----+-----------+-----+----------------------------------+
+ * | <an | Xml | Element/> | <ro | ot><child>content</child></root> |
+ * +-----+-----+-----------+-----+----------------------------------+
+ * 
+ *

+ * this decoder would output two frames: + *

+ *

+ * +-----------------+-------------------------------------+
+ * | <anXmlElement/> | <root><child>content</child></root> |
+ * +-----------------+-------------------------------------+
+ * 
+ * + *

+ * The byte stream is expected to be in UTF-8 character encoding or ASCII. The current implementation + * uses direct {@code byte} to {@code char} cast and then compares that {@code char} to a few low range + * ASCII characters like {@code '<'}, {@code '>'} or {@code '/'}. UTF-8 is not using low range [0..0x7F] + * byte values for multibyte codepoint representations therefore fully supported by this implementation. + *

+ * Please note that this decoder is not suitable for + * xml streaming protocols such as + * XMPP, + * where an initial xml element opens the stream and only + * gets closed at the end of the session, although this class + * could probably allow for such type of message flow with + * minor modifications. + */ +public class XmlFrameDecoder extends ByteToMessageDecoder { + + private final int maxFrameLength; + + public XmlFrameDecoder(int maxFrameLength) { + this.maxFrameLength = checkPositive(maxFrameLength, "maxFrameLength"); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + boolean openingBracketFound = false; + boolean atLeastOneXmlElementFound = false; + boolean inCDATASection = false; + long openBracketsCount = 0; + int length = 0; + int leadingWhiteSpaceCount = 0; + final int bufferLength = in.writerIndex(); + + if (bufferLength > maxFrameLength) { + // bufferLength exceeded maxFrameLength; dropping frame + in.skipBytes(in.readableBytes()); + fail(bufferLength); + return; + } + + for (int i = in.readerIndex(); i < bufferLength; i++) { + final byte readByte = in.getByte(i); + if (!openingBracketFound && Character.isWhitespace(readByte)) { + // xml has not started and whitespace char found + leadingWhiteSpaceCount++; + } else if (!openingBracketFound && readByte != '<') { + // garbage found before xml start + fail(ctx); + in.skipBytes(in.readableBytes()); + return; + } else if (!inCDATASection && readByte == '<') { + openingBracketFound = true; + + if (i < bufferLength - 1) { + final byte peekAheadByte = in.getByte(i + 1); + if (peekAheadByte == '/') { + // found we can decrement openBracketsCount + if (in.getByte(peekFurtherAheadIndex) == '>') { + openBracketsCount--; + break; + } + peekFurtherAheadIndex++; + } + } else if (isValidStartCharForXmlElement(peekAheadByte)) { + atLeastOneXmlElementFound = true; + // char after < is a valid xml element start char, + // incrementing openBracketsCount + 
openBracketsCount++; + } else if (peekAheadByte == '!') { + if (isCommentBlockStart(in, i)) { + // start found + openBracketsCount++; + } else if (isCDATABlockStart(in, i)) { + // start found + openBracketsCount++; + } + } + } else if (!inCDATASection && readByte == '/') { + if (i < bufferLength - 1 && in.getByte(i + 1) == '>') { + // found />, decrementing openBracketsCount + openBracketsCount--; + } + } else if (readByte == '>') { + length = i + 1; + + if (i - 1 > -1) { + final byte peekBehindByte = in.getByte(i - 1); + + if (!inCDATASection) { + if (peekBehindByte == '?') { + // an tag was closed + openBracketsCount--; + } else if (peekBehindByte == '-' && i - 2 > -1 && in.getByte(i - 2) == '-') { + // a was closed + openBracketsCount--; + } + } else if (peekBehindByte == ']' && i - 2 > -1 && in.getByte(i - 2) == ']') { + // a block was closed + openBracketsCount--; + inCDATASection = false; + } + } + + if (atLeastOneXmlElementFound && openBracketsCount == 0) { + // xml is balanced, bailing out + break; + } + } + } + + final int readerIndex = in.readerIndex(); + int xmlElementLength = length - readerIndex; + + if (openBracketsCount == 0 && xmlElementLength > 0) { + if (readerIndex + xmlElementLength >= bufferLength) { + xmlElementLength = in.readableBytes(); + } + final ByteBuf frame = + extractFrame(in, readerIndex + leadingWhiteSpaceCount, xmlElementLength - leadingWhiteSpaceCount); + in.skipBytes(xmlElementLength); + out.add(frame); + } + } + + private void fail(long frameLength) { + if (frameLength > 0) { + throw new TooLongFrameException( + "frame length exceeds " + maxFrameLength + ": " + frameLength + " - discarded"); + } else { + throw new TooLongFrameException( + "frame length exceeds " + maxFrameLength + " - discarding"); + } + } + + private static void fail(ChannelHandlerContext ctx) { + ctx.fireExceptionCaught(new CorruptedFrameException("frame contains content before the xml starts")); + } + + private static ByteBuf extractFrame(ByteBuf buffer, int 
index, int length) { + return buffer.copy(index, length); + } + + /** + * Asks whether the given byte is a valid + * start char for an xml element name. + *

+ * Please refer to the + * NameStartChar + * formal definition in the W3C XML spec for further info. + * + * @param b the input char + * @return true if the char is a valid start char + */ + private static boolean isValidStartCharForXmlElement(final byte b) { + return b >= 'a' && b <= 'z' || b >= 'A' && b <= 'Z' || b == ':' || b == '_'; + } + + private static boolean isCommentBlockStart(final ByteBuf in, final int i) { + return i < in.writerIndex() - 3 + && in.getByte(i + 2) == '-' + && in.getByte(i + 3) == '-'; + } + + private static boolean isCDATABlockStart(final ByteBuf in, final int i) { + return i < in.writerIndex() - 8 + && in.getByte(i + 2) == '[' + && in.getByte(i + 3) == 'C' + && in.getByte(i + 4) == 'D' + && in.getByte(i + 5) == 'A' + && in.getByte(i + 6) == 'T' + && in.getByte(i + 7) == 'A' + && in.getByte(i + 8) == '['; + } + +} diff --git a/netty-handler-codec/src/main/java/io/netty/handler/codec/xml/package-info.java b/netty-handler-codec/src/main/java/io/netty/handler/codec/xml/package-info.java new file mode 100644 index 0000000..91d77f8 --- /dev/null +++ b/netty-handler-codec/src/main/java/io/netty/handler/codec/xml/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Xml specific codecs. 
+ */ +package io.netty.handler.codec.xml; diff --git a/netty-handler-codec/src/main/java/module-info.java b/netty-handler-codec/src/main/java/module-info.java new file mode 100644 index 0000000..6a1fac6 --- /dev/null +++ b/netty-handler-codec/src/main/java/module-info.java @@ -0,0 +1,11 @@ +module org.xbib.io.netty.handler.codec { + exports io.netty.handler.codec; + exports io.netty.handler.codec.base64; + exports io.netty.handler.codec.bytes; + exports io.netty.handler.codec.json; + exports io.netty.handler.codec.string; + exports io.netty.handler.codec.xml; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.util; +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageCodecTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageCodecTest.java new file mode 100644 index 0000000..0e2d720 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageCodecTest.java @@ -0,0 +1,113 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ByteToMessageCodecTest { + + @Test + public void testSharable() { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + new InvalidByteToMessageCodec(); + } + }); + } + + @Test + public void testSharable2() { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + new InvalidByteToMessageCodec2(); + } + }); + } + + @Test + public void testForwardPendingData() { + ByteToMessageCodec codec = new ByteToMessageCodec() { + @Override + protected void encode(ChannelHandlerContext ctx, Integer msg, ByteBuf out) { + out.writeInt(msg); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + if (in.readableBytes() >= 4) { + out.add(in.readInt()); + } + } + }; + + ByteBuf buffer = Unpooled.buffer(); + buffer.writeInt(1); + buffer.writeByte('0'); + + EmbeddedChannel ch = new EmbeddedChannel(codec); + assertTrue(ch.writeInbound(buffer)); + ch.pipeline().remove(codec); + assertTrue(ch.finish()); + assertEquals(1, (Integer) ch.readInbound()); + + ByteBuf buf = ch.readInbound(); + assertEquals(Unpooled.wrappedBuffer(new byte[]{'0'}), buf); + buf.release(); + assertNull(ch.readInbound()); + assertNull(ch.readOutbound()); + } + + @ChannelHandler.Sharable + private static final class InvalidByteToMessageCodec extends ByteToMessageCodec { + 
InvalidByteToMessageCodec() { + super(true); + } + + @Override + protected void encode(ChannelHandlerContext ctx, Integer msg, ByteBuf out) throws Exception { } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { } + } + + @ChannelHandler.Sharable + private static final class InvalidByteToMessageCodec2 extends ByteToMessageCodec { + InvalidByteToMessageCodec2() { + super(Integer.class, true); + } + + @Override + protected void encode(ChannelHandlerContext ctx, Integer msg, ByteBuf out) throws Exception { } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java new file mode 100644 index 0000000..8597bf5 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/ByteToMessageDecoderTest.java @@ -0,0 +1,656 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.AbstractByteBufAllocator; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.buffer.UnpooledHeapByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ByteToMessageDecoderTest { + + @Test + public void testRemoveItself() { + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + private boolean removed; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + in.readByte(); + ctx.pipeline().remove(this); + removed = true; + } + }); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] {'a', 'b', 'c'}); + channel.writeInbound(buf.copy()); + ByteBuf b = channel.readInbound(); + assertEquals(b, 
buf.skipBytes(1)); + b.release(); + buf.release(); + } + + @Test + public void testRemoveItselfWriteBuffer() { + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b', 'c'}); + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + private boolean removed; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + in.readByte(); + ctx.pipeline().remove(this); + + // This should not let it keep call decode + buf.writeByte('d'); + removed = true; + } + }); + + channel.writeInbound(buf.copy()); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b', 'c'}); + ByteBuf b = channel.readInbound(); + assertEquals(expected, b); + expected.release(); + buf.release(); + b.release(); + } + + /** + * Verifies that internal buffer of the ByteToMessageDecoder is released once decoder is removed from pipeline. In + * this case input is read fully. + */ + @Test + public void testInternalBufferClearReadAll() { + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a'}); + EmbeddedChannel channel = newInternalBufferTestChannel(); + assertFalse(channel.writeInbound(buf)); + assertFalse(channel.finish()); + } + + /** + * Verifies that internal buffer of the ByteToMessageDecoder is released once decoder is removed from pipeline. In + * this case input was not fully read. 
+ */ + @Test + public void testInternalBufferClearReadPartly() { + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b'}); + EmbeddedChannel channel = newInternalBufferTestChannel(); + assertTrue(channel.writeInbound(buf)); + assertTrue(channel.finish()); + ByteBuf expected = Unpooled.wrappedBuffer(new byte[] {'b'}); + ByteBuf b = channel.readInbound(); + assertEquals(expected, b); + assertNull(channel.readInbound()); + expected.release(); + b.release(); + } + + private EmbeddedChannel newInternalBufferTestChannel() { + return new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ByteBuf byteBuf = internalBuffer(); + assertEquals(1, byteBuf.refCnt()); + in.readByte(); + // Removal from pipeline should clear internal buffer + ctx.pipeline().remove(this); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + assertCumulationReleased(internalBuffer()); + } + }); + } + + @Test + public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress() { + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ctx.pipeline().remove(this); + assertTrue(in.refCnt() != 0); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + assertCumulationReleased(internalBuffer()); + } + }); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes))); + assertTrue(channel.finishAndReleaseAll()); + } + + private static void assertCumulationReleased(ByteBuf byteBuf) { + assertTrue(byteBuf == null || byteBuf == Unpooled.EMPTY_BUFFER || byteBuf.refCnt() == 0, + "unexpected value: " + byteBuf); + } + + @Test + public void testFireChannelReadCompleteOnInactive() throws 
InterruptedException { + final BlockingQueue queue = new LinkedBlockingDeque(); + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b'}); + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + int readable = in.readableBytes(); + assertTrue(readable > 0); + in.skipBytes(readable); + } + + @Override + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(in.isReadable()); + out.add("data"); + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + queue.add(3); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + queue.add(1); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + if (!ctx.channel().isActive()) { + queue.add(2); + } + } + }); + assertFalse(channel.writeInbound(buf)); + channel.finish(); + assertEquals(1, (int) queue.take()); + assertEquals(2, (int) queue.take()); + assertEquals(3, (int) queue.take()); + assertTrue(queue.isEmpty()); + } + + // See https://github.com/netty/netty/issues/4635 + @Test + public void testRemoveWhileInCallDecode() { + final Object upgradeMessage = new Object(); + final ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertEquals('a', in.readByte()); + out.add(upgradeMessage); + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(decoder, new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (msg == upgradeMessage) { + ctx.pipeline().remove(decoder); + return; + } + ctx.fireChannelRead(msg); + } + }); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] { 
'a', 'b', 'c' }); + assertTrue(channel.writeInbound(buf.copy())); + ByteBuf b = channel.readInbound(); + assertEquals(b, buf.skipBytes(1)); + assertFalse(channel.finish()); + buf.release(); + b.release(); + } + + @Test + public void testDecodeLastEmptyBuffer() { + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + int readable = in.readableBytes(); + assertTrue(readable > 0); + out.add(in.readBytes(readable)); + } + }); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(bytes))); + assertBuffer(Unpooled.wrappedBuffer(bytes), (ByteBuf) channel.readInbound()); + assertNull(channel.readInbound()); + assertFalse(channel.finish()); + assertNull(channel.readInbound()); + } + + @Test + public void testDecodeLastNonEmptyBuffer() { + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + private boolean decodeLast; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + int readable = in.readableBytes(); + assertTrue(readable > 0); + if (!decodeLast && readable == 1) { + return; + } + out.add(in.readBytes(decodeLast ? 
readable : readable - 1)); + } + + @Override + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(decodeLast); + decodeLast = true; + super.decodeLast(ctx, in, out); + } + }); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertTrue(channel.writeInbound(Unpooled.copiedBuffer(bytes))); + assertBuffer(Unpooled.wrappedBuffer(bytes, 0, bytes.length - 1), (ByteBuf) channel.readInbound()); + assertNull(channel.readInbound()); + assertTrue(channel.finish()); + assertBuffer(Unpooled.wrappedBuffer(bytes, bytes.length - 1, 1), (ByteBuf) channel.readInbound()); + assertNull(channel.readInbound()); + } + + private static void assertBuffer(ByteBuf expected, ByteBuf buffer) { + try { + assertEquals(expected, buffer); + } finally { + buffer.release(); + expected.release(); + } + } + + @Test + public void testReadOnlyBuffer() { + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + } + }); + assertFalse(channel.writeInbound(Unpooled.buffer(8).writeByte(1).asReadOnly())); + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(new byte[] { (byte) 2 }))); + assertFalse(channel.finish()); + } + + static class WriteFailingByteBuf extends UnpooledHeapByteBuf { + private final Error error = new Error(); + private int untilFailure; + + WriteFailingByteBuf(int untilFailure, int capacity) { + super(UnpooledByteBufAllocator.DEFAULT, capacity, capacity); + this.untilFailure = untilFailure; + } + + @Override + public ByteBuf setBytes(int index, ByteBuf src, int srcIndex, int length) { + if (--untilFailure <= 0) { + throw error; + } + return super.setBytes(index, src, srcIndex, length); + } + + Error writeError() { + return error; + } + } + + @Test + public void releaseWhenMergeCumulateThrows() { + WriteFailingByteBuf oldCumulation = new WriteFailingByteBuf(1, 
64); + oldCumulation.writeZero(1); + ByteBuf in = Unpooled.buffer().writeZero(12); + + Throwable thrown = null; + try { + ByteToMessageDecoder.MERGE_CUMULATOR.cumulate(UnpooledByteBufAllocator.DEFAULT, oldCumulation, in); + } catch (Throwable t) { + thrown = t; + } + + assertSame(oldCumulation.writeError(), thrown); + assertEquals(0, in.refCnt()); + assertEquals(1, oldCumulation.refCnt()); + oldCumulation.release(); + } + + @Test + public void releaseWhenMergeCumulateThrowsInExpand() { + releaseWhenMergeCumulateThrowsInExpand(1, true); + releaseWhenMergeCumulateThrowsInExpand(2, true); + releaseWhenMergeCumulateThrowsInExpand(3, false); // sentinel test case + } + + private void releaseWhenMergeCumulateThrowsInExpand(int untilFailure, boolean shouldFail) { + ByteBuf oldCumulation = UnpooledByteBufAllocator.DEFAULT.heapBuffer(8, 8).writeZero(1); + final WriteFailingByteBuf newCumulation = new WriteFailingByteBuf(untilFailure, 16); + + ByteBufAllocator allocator = new AbstractByteBufAllocator(false) { + @Override + public boolean isDirectBufferPooled() { + return false; + } + + @Override + protected ByteBuf newHeapBuffer(int initialCapacity, int maxCapacity) { + return newCumulation; + } + + @Override + protected ByteBuf newDirectBuffer(int initialCapacity, int maxCapacity) { + throw new UnsupportedOperationException(); + } + }; + + ByteBuf in = Unpooled.buffer().writeZero(12); + Throwable thrown = null; + try { + ByteToMessageDecoder.MERGE_CUMULATOR.cumulate(allocator, oldCumulation, in); + } catch (Throwable t) { + thrown = t; + } + + assertEquals(0, in.refCnt()); + + if (shouldFail) { + assertSame(newCumulation.writeError(), thrown); + assertEquals(1, oldCumulation.refCnt()); + oldCumulation.release(); + assertEquals(0, newCumulation.refCnt()); + } else { + assertNull(thrown); + assertEquals(0, oldCumulation.refCnt()); + assertEquals(1, newCumulation.refCnt()); + newCumulation.release(); + } + } + + @Test + public void releaseWhenCompositeCumulateThrows() { + 
final Error error = new Error(); + + ByteBuf cumulation = new CompositeByteBuf(UnpooledByteBufAllocator.DEFAULT, false, 64) { + @Override + public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer) { + throw error; + } + @Override + public CompositeByteBuf addFlattenedComponents(boolean increaseWriterIndex, ByteBuf buffer) { + throw error; + } + }.writeZero(1); + ByteBuf in = Unpooled.buffer().writeZero(12); + try { + ByteToMessageDecoder.COMPOSITE_CUMULATOR.cumulate(UnpooledByteBufAllocator.DEFAULT, cumulation, in); + fail(); + } catch (Error expected) { + assertSame(error, expected); + assertEquals(0, in.refCnt()); + cumulation.release(); + } + } + + private static final class ReadInterceptingHandler extends ChannelOutboundHandlerAdapter { + private int readsTriggered; + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + readsTriggered++; + super.read(ctx); + } + } + + @Test + public void testDoesNotOverRead() { + ReadInterceptingHandler interceptor = new ReadInterceptingHandler(); + + EmbeddedChannel channel = new EmbeddedChannel(); + channel.config().setAutoRead(false); + channel.pipeline().addLast(interceptor, new FixedLengthFrameDecoder(3)); + assertEquals(0, interceptor.readsTriggered); + + // 0 complete frames, 1 partial frame: SHOULD trigger a read + channel.writeInbound(wrappedBuffer(new byte[] { 0, 1 })); + assertEquals(1, interceptor.readsTriggered); + + // 2 complete frames, 0 partial frames: should NOT trigger a read + channel.writeInbound(wrappedBuffer(new byte[] { 2 }), wrappedBuffer(new byte[] { 3, 4, 5 })); + assertEquals(1, interceptor.readsTriggered); + + // 1 complete frame, 1 partial frame: should NOT trigger a read + channel.writeInbound(wrappedBuffer(new byte[] { 6, 7, 8 }), wrappedBuffer(new byte[] { 9 })); + assertEquals(1, interceptor.readsTriggered); + + // 1 complete frame, 1 partial frame: should NOT trigger a read + channel.writeInbound(wrappedBuffer(new byte[] { 10, 11 }), 
wrappedBuffer(new byte[] { 12 })); + assertEquals(1, interceptor.readsTriggered); + + // 0 complete frames, 1 partial frame: SHOULD trigger a read + channel.writeInbound(wrappedBuffer(new byte[] { 13 })); + assertEquals(2, interceptor.readsTriggered); + + // 1 complete frame, 0 partial frames: should NOT trigger a read + channel.writeInbound(wrappedBuffer(new byte[] { 14 })); + assertEquals(2, interceptor.readsTriggered); + + for (int i = 0; i < 5; i++) { + ByteBuf read = channel.readInbound(); + assertEquals(i * 3 + 0, read.getByte(0)); + assertEquals(i * 3 + 1, read.getByte(1)); + assertEquals(i * 3 + 2, read.getByte(2)); + read.release(); + } + assertFalse(channel.finish()); + } + + @Test + public void testDisorder() { + ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + int count; + + //read 4 byte then remove this decoder + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + out.add(in.readByte()); + if (++count >= 4) { + ctx.pipeline().remove(this); + } + } + }; + EmbeddedChannel channel = new EmbeddedChannel(decoder); + assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{1, 2, 3, 4, 5}))); + assertEquals((byte) 1, (Byte) channel.readInbound()); + assertEquals((byte) 2, (Byte) channel.readInbound()); + assertEquals((byte) 3, (Byte) channel.readInbound()); + assertEquals((byte) 4, (Byte) channel.readInbound()); + ByteBuf buffer5 = channel.readInbound(); + assertEquals((byte) 5, buffer5.readByte()); + assertFalse(buffer5.isReadable()); + assertTrue(buffer5.release()); + assertFalse(channel.finish()); + } + + @Test + public void testDecodeLast() { + final AtomicBoolean removeHandler = new AtomicBoolean(); + EmbeddedChannel channel = new EmbeddedChannel(new ByteToMessageDecoder() { + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + if (removeHandler.get()) { + ctx.pipeline().remove(this); + } + } + }); + byte[] bytes = new byte[1024]; + 
PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertFalse(channel.writeInbound(Unpooled.copiedBuffer(bytes))); + assertNull(channel.readInbound()); + removeHandler.set(true); + // This should trigger channelInputClosed(...) + channel.pipeline().fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE); + + assertTrue(channel.finish()); + assertBuffer(Unpooled.wrappedBuffer(bytes), (ByteBuf) channel.readInbound()); + assertNull(channel.readInbound()); + } + + @Test + void testUnexpectRead() { + EmbeddedChannel channel = new EmbeddedChannel(); + channel.config().setAutoRead(false); + ReadInterceptingHandler interceptor = new ReadInterceptingHandler(); + channel.pipeline().addLast( + interceptor, + new SimpleChannelInboundHandler() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + ctx.pipeline().replace(this, "fix", new FixedLengthFrameDecoder(3)); + } + } + ); + + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(new byte[]{1}))); + assertEquals(0, interceptor.readsTriggered); + assertNotNull(channel.pipeline().get(FixedLengthFrameDecoder.class)); + assertFalse(channel.finish()); + } + + @Test + public void testReuseInputBufferJustLargeEnoughToContainMessage_MergeCumulator() { + testReusedBuffer(Unpooled.buffer(16), false, ByteToMessageDecoder.MERGE_CUMULATOR); + } + + @Test + public void testReuseInputBufferJustLargeEnoughToContainMessagePartiallyReceived2x_MergeCumulator() { + testReusedBuffer(Unpooled.buffer(16), true, ByteToMessageDecoder.MERGE_CUMULATOR); + } + + @Test + public void testReuseInputBufferSufficientlyLargeToContainDuplicateMessage_MergeCumulator() { + testReusedBuffer(Unpooled.buffer(1024), false, ByteToMessageDecoder.MERGE_CUMULATOR); + } + + @Test + public void testReuseInputBufferSufficientlyLargeToContainDuplicateMessagePartiallyReceived2x_MergeCumulator() { + testReusedBuffer(Unpooled.buffer(1024), true, ByteToMessageDecoder.MERGE_CUMULATOR); + } + + @Test + public 
void testReuseInputBufferJustLargeEnoughToContainMessage_CompositeCumulator() { + testReusedBuffer(Unpooled.buffer(16), false, ByteToMessageDecoder.COMPOSITE_CUMULATOR); + } + + @Test + public void testReuseInputBufferJustLargeEnoughToContainMessagePartiallyReceived2x_CompositeCumulator() { + testReusedBuffer(Unpooled.buffer(16), true, ByteToMessageDecoder.COMPOSITE_CUMULATOR); + } + + @Test + public void testReuseInputBufferSufficientlyLargeToContainDuplicateMessage_CompositeCumulator() { + testReusedBuffer(Unpooled.buffer(1024), false, ByteToMessageDecoder.COMPOSITE_CUMULATOR); + } + + @Test + public void testReuseInputBufferSufficientlyLargeToContainDuplicateMessagePartiallyReceived2x_CompositeCumulator() { + testReusedBuffer(Unpooled.buffer(1024), true, ByteToMessageDecoder.COMPOSITE_CUMULATOR); + } + + static void testReusedBuffer(ByteBuf buffer, boolean secondPartial, ByteToMessageDecoder.Cumulator cumulator) { + ByteToMessageDecoder decoder = new ByteToMessageDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + while (in.readableBytes() >= 4) { + int index = in.readerIndex(); + int len = in.readInt(); + assert len < (1 << 30) : "In-plausibly long message: " + len; + if (in.readableBytes() >= len) { + byte[] bytes = new byte[len]; + in.readBytes(bytes); + String message = new String(bytes, CharsetUtil.UTF_8); + out.add(message); + } else { + in.readerIndex(index); + return; + } + } + } + }; + decoder.setCumulator(cumulator); + EmbeddedChannel channel = new EmbeddedChannel(decoder); + + buffer.retain(); // buffer is allocated from the pool, the pool would call retain() + buffer.writeInt(11); // total length of message + buffer.writeByte('h').writeByte('e').writeByte('l').writeByte('l'); + if (secondPartial) { + assertFalse(channel.writeInbound(buffer)); // try reading incomplete message + assertTrue(channel.inboundMessages().isEmpty()); + assertEquals(0, buffer.readerIndex(), "Incomplete message should still be 
readable in buffer"); + buffer.retain(); // buffer is allocated from the pool - reusing same buffer, the pool would call retain() + } + buffer.writeByte('o').writeByte(' '); + assertFalse(channel.writeInbound(buffer)); // try reading incomplete message + assertTrue(channel.inboundMessages().isEmpty()); + assertEquals(0, buffer.readerIndex(), "Incomplete message should still be readable in buffer"); + + buffer.retain(); // buffer is allocated from the pool - reusing same buffer, the pool would call retain() + buffer.writeByte('w').writeByte('o').writeByte('r').writeByte('l').writeByte('d'); + assertTrue(channel.writeInbound(buffer)); + assertFalse(channel.inboundMessages().isEmpty(), "Message should be received"); + assertEquals("hello world", channel.inboundMessages().poll(), "Message should be received correctly"); + assertTrue(channel.inboundMessages().isEmpty(), "Only a single message should be received"); + assertFalse(buffer.isReadable(), "Buffer should not have remaining data after reading complete message"); + + buffer.release(); // we are done with the buffer - release it from the pool + assertEquals(0, buffer.refCnt(), "Buffer should be released"); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java new file mode 100644 index 0000000..c92ec0a --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/CharSequenceValueConverterTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec; + +import io.netty.util.AsciiString; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CharSequenceValueConverterTest { + + private final CharSequenceValueConverter converter = CharSequenceValueConverter.INSTANCE; + + @Test + public void testBoolean() { + assertTrue(converter.convertToBoolean(converter.convertBoolean(true))); + assertFalse(converter.convertToBoolean(converter.convertBoolean(false))); + } + + @Test + public void testByteFromAsciiString() { + assertEquals(127, converter.convertToByte(AsciiString.of("127"))); + } + + @Test + public void testByteFromEmptyAsciiString() { + assertThrows(NumberFormatException.class, new Executable() { + @Override + public void execute() { + converter.convertToByte(AsciiString.EMPTY_STRING); + } + }); + } + + @Test + public void testByte() { + assertEquals(Byte.MAX_VALUE, converter.convertToByte(converter.convertByte(Byte.MAX_VALUE))); + } + + @Test + public void testChar() { + assertEquals(Character.MAX_VALUE, converter.convertToChar(converter.convertChar(Character.MAX_VALUE))); + } + + @Test + public void testDouble() { + assertEquals(Double.MAX_VALUE, converter.convertToDouble(converter.convertDouble(Double.MAX_VALUE)), 0); + } + + @Test + public void testFloat() { + 
assertEquals(Float.MAX_VALUE, converter.convertToFloat(converter.convertFloat(Float.MAX_VALUE)), 0); + } + + @Test + public void testInt() { + assertEquals(Integer.MAX_VALUE, converter.convertToInt(converter.convertInt(Integer.MAX_VALUE))); + } + + @Test + public void testShort() { + assertEquals(Short.MAX_VALUE, converter.convertToShort(converter.convertShort(Short.MAX_VALUE))); + } + + @Test + public void testLong() { + assertEquals(Long.MAX_VALUE, converter.convertToLong(converter.convertLong(Long.MAX_VALUE))); + } + + @Test + public void testTimeMillis() { + // Zero out the millis as this is what the convert is doing as well. + long millis = (System.currentTimeMillis() / 1000) * 1000; + assertEquals(millis, converter.convertToTimeMillis(converter.convertTimeMillis(millis))); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java new file mode 100644 index 0000000..5da3395 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/CodecOutputListTest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class CodecOutputListTest { + + @Test + public void testCodecOutputListAdd() { + CodecOutputList codecOutputList = CodecOutputList.newInstance(); + try { + assertEquals(0, codecOutputList.size()); + assertTrue(codecOutputList.isEmpty()); + + codecOutputList.add(1); + assertEquals(1, codecOutputList.size()); + assertFalse(codecOutputList.isEmpty()); + assertEquals(1, codecOutputList.get(0)); + + codecOutputList.add(0, 0); + assertEquals(2, codecOutputList.size()); + assertFalse(codecOutputList.isEmpty()); + assertEquals(0, codecOutputList.get(0)); + assertEquals(1, codecOutputList.get(1)); + + codecOutputList.add(1, 2); + assertEquals(3, codecOutputList.size()); + assertFalse(codecOutputList.isEmpty()); + assertEquals(0, codecOutputList.get(0)); + assertEquals(2, codecOutputList.get(1)); + assertEquals(1, codecOutputList.get(2)); + } finally { + codecOutputList.recycle(); + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketDecoderTest.java new file mode 100644 index 0000000..fb15b62 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketDecoderTest.java @@ -0,0 +1,96 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.socket.DatagramPacket; +import io.netty.handler.codec.string.StringDecoder; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.SocketUtils; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.net.InetSocketAddress; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DatagramPacketDecoderTest { + + private EmbeddedChannel channel; + + @BeforeEach + public void setUp() { + channel = new EmbeddedChannel( + new DatagramPacketDecoder( + new StringDecoder(CharsetUtil.UTF_8))); + } + + @AfterEach + public void tearDown() { + assertFalse(channel.finish()); + } + + @Test + public void testDecode() { + InetSocketAddress recipient = SocketUtils.socketAddress("127.0.0.1", 10000); + InetSocketAddress sender = SocketUtils.socketAddress("127.0.0.1", 20000); + ByteBuf content = Unpooled.wrappedBuffer("netty".getBytes(CharsetUtil.UTF_8)); + assertTrue(channel.writeInbound(new DatagramPacket(content, recipient, sender))); + assertEquals("netty", channel.readInbound()); + } + + @Test + public void testIsNotSharable() { + testIsSharable(false); + } + + @Test + public void 
testIsSharable() { + testIsSharable(true); + } + + private static void testIsSharable(boolean sharable) { + MessageToMessageDecoder wrapped = new TestMessageToMessageDecoder(sharable); + DatagramPacketDecoder decoder = new DatagramPacketDecoder(wrapped); + assertEquals(wrapped.isSharable(), decoder.isSharable()); + } + + private static final class TestMessageToMessageDecoder extends MessageToMessageDecoder { + + private final boolean sharable; + + TestMessageToMessageDecoder(boolean sharable) { + this.sharable = sharable; + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf msg, List out) throws Exception { + // NOOP + } + + @Override + public boolean isSharable() { + return sharable; + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketEncoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketEncoderTest.java new file mode 100644 index 0000000..fee5159 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/DatagramPacketEncoderTest.java @@ -0,0 +1,139 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.handler.codec;

import io.netty.channel.AddressedEnvelope;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.DefaultAddressedEnvelope;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.DatagramPacket;
import io.netty.handler.codec.string.StringEncoder;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.SocketUtils;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;

import java.net.InetSocketAddress;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Tests for {@link DatagramPacketEncoder}: encoding of addressed envelopes
 * into {@link DatagramPacket}s, pass-through of unmatched message types, and
 * propagation of the wrapped encoder's sharable flag.
 *
 * <p>NOTE(review): generic type parameters below were reconstructed — the
 * extracted source had all {@code <...>} tokens stripped (the stray
 * {@code >}/{@code >>} residues mark where nested generics once were).
 */
public class DatagramPacketEncoderTest {

    private EmbeddedChannel channel;

    @BeforeEach
    public void setUp() {
        channel = new EmbeddedChannel(
                new DatagramPacketEncoder<DefaultAddressedEnvelope<String, InetSocketAddress>>(
                        new StringEncoder(CharsetUtil.UTF_8)));
    }

    @AfterEach
    public void tearDown() {
        // Every test must leave the channel fully drained.
        assertFalse(channel.finish());
    }

    @Test
    public void testEncode() {
        testEncode(false);
    }

    @Test
    public void testEncodeWithSenderIsNull() {
        // A null sender is legal for an outbound datagram (unconnected send).
        testEncode(true);
    }

    private void testEncode(boolean senderIsNull) {
        InetSocketAddress recipient = SocketUtils.socketAddress("127.0.0.1", 10000);
        InetSocketAddress sender = senderIsNull ? null : SocketUtils.socketAddress("127.0.0.1", 20000);
        assertTrue(channel.writeOutbound(
                new DefaultAddressedEnvelope<String, InetSocketAddress>("netty", recipient, sender)));
        DatagramPacket packet = channel.readOutbound();
        try {
            assertEquals("netty", packet.content().toString(CharsetUtil.UTF_8));
            assertEquals(recipient, packet.recipient());
            assertEquals(sender, packet.sender());
        } finally {
            // The packet holds a reference-counted buffer; release it so the
            // leak detector stays quiet.
            packet.release();
        }
    }

    @Test
    public void testUnmatchedMessageType() {
        // The wrapped StringEncoder cannot handle a Long payload, so the
        // envelope must pass through unchanged.
        InetSocketAddress recipient = SocketUtils.socketAddress("127.0.0.1", 10000);
        InetSocketAddress sender = SocketUtils.socketAddress("127.0.0.1", 20000);
        DefaultAddressedEnvelope<Long, InetSocketAddress> envelope =
                new DefaultAddressedEnvelope<Long, InetSocketAddress>(1L, recipient, sender);
        assertTrue(channel.writeOutbound(envelope));
        DefaultAddressedEnvelope<Long, InetSocketAddress> output = channel.readOutbound();
        try {
            assertSame(envelope, output);
        } finally {
            output.release();
        }
    }

    @Test
    public void testUnmatchedType() {
        // Messages that are not AddressedEnvelopes at all also pass through.
        String netty = "netty";
        assertTrue(channel.writeOutbound(netty));
        assertSame(netty, channel.readOutbound());
    }

    @Test
    public void testIsNotSharable() {
        testSharable(false);
    }

    @Test
    public void testIsSharable() {
        testSharable(true);
    }

    // The encoder must report exactly the sharability of the encoder it wraps.
    private static void testSharable(boolean sharable) {
        MessageToMessageEncoder<AddressedEnvelope<CharSequence, InetSocketAddress>> wrapped =
                new TestMessageToMessageEncoder(sharable);

        DatagramPacketEncoder<AddressedEnvelope<CharSequence, InetSocketAddress>> encoder =
                new DatagramPacketEncoder<AddressedEnvelope<CharSequence, InetSocketAddress>>(wrapped);
        assertEquals(wrapped.isSharable(), encoder.isSharable());
    }

    /** Encoder stub whose only purpose is to expose a configurable sharable flag. */
    private static final class TestMessageToMessageEncoder
            extends MessageToMessageEncoder<AddressedEnvelope<CharSequence, InetSocketAddress>> {

        private final boolean sharable;

        TestMessageToMessageEncoder(boolean sharable) {
            this.sharable = sharable;
        }

        @Override
        protected void encode(
                ChannelHandlerContext ctx,
                AddressedEnvelope<CharSequence, InetSocketAddress> msg,
                List<Object> out) {
            // NOOP — this stub never encodes anything.
        }

        @Override
        public boolean isSharable() {
            return sharable;
        }
    }
}
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.codec;

import org.junit.jupiter.api.Test;

import java.util.Calendar;
import java.util.Date;

import static io.netty.handler.codec.DateFormatter.append;
import static io.netty.handler.codec.DateFormatter.format;
import static io.netty.handler.codec.DateFormatter.parseHttpDate;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;

/**
 * Tests for {@link DateFormatter}: parsing of the HTTP date formats
 * (RFC 1123, RFC 850 and asctime), rejection of malformed input, and
 * formatting back to the RFC 1123 representation.
 */
public class DateFormatterTest {

    /** The RFC reference timestamp, "Sun, 06 Nov 1994 08:49:37 GMT". */
    private static final long TIMESTAMP = 784111777000L;
    private static final Date DATE = new Date(TIMESTAMP);

    @Test
    public void testParseWithSingleDigitDay() {
        assertEquals(DATE, parseHttpDate("Sun, 6 Nov 1994 08:49:37 GMT"));
    }

    @Test
    public void testParseWithDoubleDigitDay() {
        assertEquals(DATE, parseHttpDate("Sun, 06 Nov 1994 08:49:37 GMT"));
    }

    @Test
    public void testParseWithDashSeparatorSingleDigitDay() {
        assertEquals(DATE, parseHttpDate("Sunday, 6-Nov-94 08:49:37 GMT"));
    }

    @Test
    public void testParseWithDashSeparatorDoubleDigitDay() {
        assertEquals(DATE, parseHttpDate("Sunday, 06-Nov-94 08:49:37 GMT"));
    }

    @Test
    public void testParseWithoutGMT() {
        // asctime() style: no explicit timezone, year at the end.
        assertEquals(DATE, parseHttpDate("Sun Nov 06 08:49:37 1994"));
    }

    @Test
    public void testParseWithFunkyTimezone() {
        assertEquals(DATE, parseHttpDate("Sun Nov 06 08:49:37 1994 -0000"));
    }

    @Test
    public void testParseWithSingleDigitHourMinutesAndSecond() {
        assertEquals(DATE, parseHttpDate("Sunday, 06-Nov-94 8:49:37 GMT"));
    }

    @Test
    public void testParseWithSingleDigitTime() {
        assertEquals(DATE, parseHttpDate("Sunday, 06 Nov 1994 8:49:37 GMT"));

        // 40 minutes earlier, i.e. 08:09:37.
        Date at080937 = new Date(TIMESTAMP - 40 * 60 * 1000);
        assertEquals(at080937, parseHttpDate("Sunday, 06 Nov 1994 8:9:37 GMT"));
        assertEquals(at080937, parseHttpDate("Sunday, 06 Nov 1994 8:09:37 GMT"));

        // 40 minutes 30 seconds earlier, i.e. 08:09:07.
        Date at080907 = new Date(TIMESTAMP - (40 * 60 + 30) * 1000);
        assertEquals(at080907, parseHttpDate("Sunday, 06 Nov 1994 8:9:7 GMT"));
        assertEquals(at080907, parseHttpDate("Sunday, 06 Nov 1994 8:9:07 GMT"));
    }

    @Test
    public void testParseMidnight() {
        assertEquals(new Date(784080000000L), parseHttpDate("Sunday, 06 Nov 1994 00:00:00 GMT"));
    }

    @Test
    public void testParseInvalidInput() {
        // A missing field must yield null rather than a best-effort date.
        assertNull(parseHttpDate("Sun, Nov 1994 08:49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 1994 08:49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 08:49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 :49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 08::37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 08:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 08:49: GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 08:49 GMT"));
        // Out-of-range or unknown field values are rejected as well.
        assertNull(parseHttpDate("Sun, 06 FOO 1994 08:49:37 GMT"));
        assertNull(parseHttpDate("Sun, 36 Nov 1994 08:49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 28:49:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 08:69:37 GMT"));
        assertNull(parseHttpDate("Sun, 06 Nov 1994 08:49:67 GMT"));
        // Time components must have at most two digits.
        assertNull(parseHttpDate("Sunday, 06 Nov 1994 0:0:000 GMT"));
        assertNull(parseHttpDate("Sunday, 06 Nov 1994 0:000:0 GMT"));
        assertNull(parseHttpDate("Sunday, 06 Nov 1994 000:0:0 GMT"));
    }

    @Test
    public void testFormat() {
        assertEquals("Sun, 06 Nov 1994 08:49:37 GMT", format(DATE));
    }

    @Test
    public void testAppend() {
        StringBuilder buf = new StringBuilder();
        append(DATE, buf);
        assertEquals("Sun, 06 Nov 1994 08:49:37 GMT", buf.toString());
    }

    @Test
    public void testParseAllMonths() {
        // Calendar months are zero-based (Calendar.JANUARY == 0), so the array
        // index is exactly the expected month value.
        String[] abbreviations = {"Jan", "Feb", "Mar", "Apr", "May", "Jun", "Jul", "Aug", "Sep"};
        for (int month = 0; month < abbreviations.length; month++) {
            assertEquals(month,
                    getMonth(parseHttpDate("Sun, 06 " + abbreviations[month] + " 1994 08:49:37 GMT")));
        }
        // The remaining months are exercised through the asctime() format.
        assertEquals(Calendar.OCTOBER, getMonth(parseHttpDate("Sun Oct 06 08:49:37 1994")));
        assertEquals(Calendar.NOVEMBER, getMonth(parseHttpDate("Sun Nov 06 08:49:37 1994")));
        assertEquals(Calendar.DECEMBER, getMonth(parseHttpDate("Sun Dec 06 08:49:37 1994")));
    }

    /** Extracts the zero-based month of {@code referenceDate} in the default time zone. */
    private static int getMonth(Date referenceDate) {
        Calendar calendar = Calendar.getInstance();
        calendar.setTime(referenceDate);
        return calendar.get(Calendar.MONTH);
    }
}
b/netty-handler-codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java new file mode 100644 index 0000000..8d23613 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/DefaultHeadersTest.java @@ -0,0 +1,833 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec; + +import io.netty.util.AsciiString; +import io.netty.util.HashingStrategy; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.ArrayList; +import java.util.Iterator; +import java.util.List; +import java.util.Map; +import java.util.Map.Entry; +import java.util.NoSuchElementException; + +import static io.netty.util.AsciiString.of; +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.containsInAnyOrder; +import static org.hamcrest.Matchers.hasSize; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + 
+/** + * Tests for {@link DefaultHeaders}. + */ +public class DefaultHeadersTest { + + private static final class TestDefaultHeaders extends + DefaultHeaders { + TestDefaultHeaders() { + this(CharSequenceValueConverter.INSTANCE); + } + + TestDefaultHeaders(ValueConverter converter) { + super(converter); + } + + TestDefaultHeaders(HashingStrategy nameHashingStrategy) { + super(nameHashingStrategy, CharSequenceValueConverter.INSTANCE); + } + } + + private static TestDefaultHeaders newInstance() { + return new TestDefaultHeaders(); + } + + @Test + public void addShouldIncreaseAndRemoveShouldDecreaseTheSize() { + TestDefaultHeaders headers = newInstance(); + assertEquals(0, headers.size()); + headers.add(of("name1"), of("value1"), of("value2")); + assertEquals(2, headers.size()); + headers.add(of("name2"), of("value3"), of("value4")); + assertEquals(4, headers.size()); + headers.add(of("name3"), of("value5")); + assertEquals(5, headers.size()); + + headers.remove(of("name3")); + assertEquals(4, headers.size()); + headers.remove(of("name1")); + assertEquals(2, headers.size()); + headers.remove(of("name2")); + assertEquals(0, headers.size()); + assertTrue(headers.isEmpty()); + } + + @Test + public void afterClearHeadersShouldBeEmpty() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name2"), of("value2")); + assertEquals(2, headers.size()); + headers.clear(); + assertEquals(0, headers.size()); + assertTrue(headers.isEmpty()); + assertFalse(headers.contains(of("name1"))); + assertFalse(headers.contains(of("name2"))); + } + + @Test + public void removingANameForASecondTimeShouldReturnFalse() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name2"), of("value2")); + assertTrue(headers.remove(of("name2"))); + assertFalse(headers.remove(of("name2"))); + } + + @Test + public void multipleValuesPerNameShouldBeAllowed() { + TestDefaultHeaders headers = 
newInstance(); + headers.add(of("name"), of("value1")); + headers.add(of("name"), of("value2")); + headers.add(of("name"), of("value3")); + assertEquals(3, headers.size()); + + List values = headers.getAll(of("name")); + assertEquals(3, values.size()); + assertTrue(values.containsAll(asList(of("value1"), of("value2"), of("value3")))); + } + + @Test + public void multipleValuesPerNameIteratorWithOtherNames() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name1"), of("value2")); + headers.add(of("name2"), of("value4")); + headers.add(of("name1"), of("value3")); + assertEquals(4, headers.size()); + + List values = new ArrayList(); + Iterator itr = headers.valueIterator(of("name1")); + while (itr.hasNext()) { + values.add(itr.next()); + itr.remove(); + } + assertEquals(3, values.size()); + assertEquals(1, headers.size()); + assertFalse(headers.isEmpty()); + assertTrue(values.containsAll(asList(of("value1"), of("value2"), of("value3")))); + itr = headers.valueIterator(of("name1")); + assertFalse(itr.hasNext()); + itr = headers.valueIterator(of("name2")); + assertTrue(itr.hasNext()); + assertEquals(of("value4"), itr.next()); + assertFalse(itr.hasNext()); + } + + @Test + public void multipleValuesPerNameIterator() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name1"), of("value2")); + assertEquals(2, headers.size()); + + List values = new ArrayList(); + Iterator itr = headers.valueIterator(of("name1")); + while (itr.hasNext()) { + values.add(itr.next()); + itr.remove(); + } + assertEquals(2, values.size()); + assertEquals(0, headers.size()); + assertTrue(headers.isEmpty()); + assertTrue(values.containsAll(asList(of("value1"), of("value2")))); + itr = headers.valueIterator(of("name1")); + assertFalse(itr.hasNext()); + } + + @Test + public void valuesItrRemoveThrowsWhenEmpty() { + TestDefaultHeaders headers = newInstance(); + assertEquals(0, 
headers.size()); + assertTrue(headers.isEmpty()); + final Iterator itr = headers.valueIterator(of("name")); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + itr.remove(); + } + }); + } + + @Test + public void valuesItrRemoveThrowsAfterLastElement() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name"), of("value1")); + assertEquals(1, headers.size()); + + List values = new ArrayList(); + Iterator itr = headers.valueIterator(of("name")); + while (itr.hasNext()) { + values.add(itr.next()); + itr.remove(); + } + assertEquals(1, values.size()); + assertEquals(0, headers.size()); + assertTrue(headers.isEmpty()); + assertTrue(values.contains(of("value1"))); + try { + itr.remove(); + fail(); + } catch (IllegalStateException ignored) { + // ignored + } + } + + @Test + public void multipleValuesPerNameIteratorEmpty() { + TestDefaultHeaders headers = newInstance(); + + List values = new ArrayList(); + Iterator itr = headers.valueIterator(of("name")); + while (itr.hasNext()) { + values.add(itr.next()); + } + assertEquals(0, values.size()); + try { + itr.next(); + fail(); + } catch (NoSuchElementException ignored) { + // ignored + } + } + + @Test + public void testContains() { + TestDefaultHeaders headers = newInstance(); + headers.addBoolean(of("boolean"), true); + assertTrue(headers.containsBoolean(of("boolean"), true)); + assertFalse(headers.containsBoolean(of("boolean"), false)); + + headers.addLong(of("long"), Long.MAX_VALUE); + assertTrue(headers.containsLong(of("long"), Long.MAX_VALUE)); + assertFalse(headers.containsLong(of("long"), Long.MIN_VALUE)); + + headers.addInt(of("int"), Integer.MIN_VALUE); + assertTrue(headers.containsInt(of("int"), Integer.MIN_VALUE)); + assertFalse(headers.containsInt(of("int"), Integer.MAX_VALUE)); + + headers.addShort(of("short"), Short.MAX_VALUE); + assertTrue(headers.containsShort(of("short"), Short.MAX_VALUE)); + assertFalse(headers.containsShort(of("short"), 
Short.MIN_VALUE)); + + headers.addChar(of("char"), Character.MAX_VALUE); + assertTrue(headers.containsChar(of("char"), Character.MAX_VALUE)); + assertFalse(headers.containsChar(of("char"), Character.MIN_VALUE)); + + headers.addByte(of("byte"), Byte.MAX_VALUE); + assertTrue(headers.containsByte(of("byte"), Byte.MAX_VALUE)); + assertFalse(headers.containsLong(of("byte"), Byte.MIN_VALUE)); + + headers.addDouble(of("double"), Double.MAX_VALUE); + assertTrue(headers.containsDouble(of("double"), Double.MAX_VALUE)); + assertFalse(headers.containsDouble(of("double"), Double.MIN_VALUE)); + + headers.addFloat(of("float"), Float.MAX_VALUE); + assertTrue(headers.containsFloat(of("float"), Float.MAX_VALUE)); + assertFalse(headers.containsFloat(of("float"), Float.MIN_VALUE)); + + long millis = System.currentTimeMillis(); + headers.addTimeMillis(of("millis"), millis); + assertTrue(headers.containsTimeMillis(of("millis"), millis)); + // This test doesn't work on midnight, January 1, 1970 UTC + assertFalse(headers.containsTimeMillis(of("millis"), 0)); + + headers.addObject(of("object"), "Hello World"); + assertTrue(headers.containsObject(of("object"), "Hello World")); + assertFalse(headers.containsObject(of("object"), "")); + + headers.add(of("name"), of("value")); + assertTrue(headers.contains(of("name"), of("value"))); + assertFalse(headers.contains(of("name"), of("value1"))); + } + + @Test + public void testCopy() throws Exception { + TestDefaultHeaders headers = newInstance(); + headers.addBoolean(of("boolean"), true); + headers.addLong(of("long"), Long.MAX_VALUE); + headers.addInt(of("int"), Integer.MIN_VALUE); + headers.addShort(of("short"), Short.MAX_VALUE); + headers.addChar(of("char"), Character.MAX_VALUE); + headers.addByte(of("byte"), Byte.MAX_VALUE); + headers.addDouble(of("double"), Double.MAX_VALUE); + headers.addFloat(of("float"), Float.MAX_VALUE); + long millis = System.currentTimeMillis(); + headers.addTimeMillis(of("millis"), millis); + 
headers.addObject(of("object"), "Hello World"); + headers.add(of("name"), of("value")); + + headers = newInstance().add(headers); + + assertTrue(headers.containsBoolean(of("boolean"), true)); + assertFalse(headers.containsBoolean(of("boolean"), false)); + + assertTrue(headers.containsLong(of("long"), Long.MAX_VALUE)); + assertFalse(headers.containsLong(of("long"), Long.MIN_VALUE)); + + assertTrue(headers.containsInt(of("int"), Integer.MIN_VALUE)); + assertFalse(headers.containsInt(of("int"), Integer.MAX_VALUE)); + + assertTrue(headers.containsShort(of("short"), Short.MAX_VALUE)); + assertFalse(headers.containsShort(of("short"), Short.MIN_VALUE)); + + assertTrue(headers.containsChar(of("char"), Character.MAX_VALUE)); + assertFalse(headers.containsChar(of("char"), Character.MIN_VALUE)); + + assertTrue(headers.containsByte(of("byte"), Byte.MAX_VALUE)); + assertFalse(headers.containsLong(of("byte"), Byte.MIN_VALUE)); + + assertTrue(headers.containsDouble(of("double"), Double.MAX_VALUE)); + assertFalse(headers.containsDouble(of("double"), Double.MIN_VALUE)); + + assertTrue(headers.containsFloat(of("float"), Float.MAX_VALUE)); + assertFalse(headers.containsFloat(of("float"), Float.MIN_VALUE)); + + assertTrue(headers.containsTimeMillis(of("millis"), millis)); + // This test doesn't work on midnight, January 1, 1970 UTC + assertFalse(headers.containsTimeMillis(of("millis"), 0)); + + assertTrue(headers.containsObject(of("object"), "Hello World")); + assertFalse(headers.containsObject(of("object"), "")); + + assertTrue(headers.contains(of("name"), of("value"))); + assertFalse(headers.contains(of("name"), of("value1"))); + } + + @Test + public void canMixConvertedAndNormalValues() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name"), of("value")); + headers.addInt(of("name"), 100); + headers.addBoolean(of("name"), false); + + assertEquals(3, headers.size()); + assertTrue(headers.contains(of("name"))); + assertTrue(headers.contains(of("name"), 
of("value"))); + assertTrue(headers.containsInt(of("name"), 100)); + assertTrue(headers.containsBoolean(of("name"), false)); + } + + @Test + public void testGetAndRemove() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name2"), of("value2"), of("value3")); + headers.add(of("name3"), of("value4"), of("value5"), of("value6")); + + assertEquals(of("value1"), headers.getAndRemove(of("name1"), of("defaultvalue"))); + assertEquals(of("value2"), headers.getAndRemove(of("name2"))); + assertNull(headers.getAndRemove(of("name2"))); + assertEquals(asList(of("value4"), of("value5"), of("value6")), headers.getAllAndRemove(of("name3"))); + assertEquals(0, headers.size()); + assertNull(headers.getAndRemove(of("noname"))); + assertEquals(of("defaultvalue"), headers.getAndRemove(of("noname"), of("defaultvalue"))); + } + + @Test + public void whenNameContainsMultipleValuesGetShouldReturnTheFirst() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1"), of("value2")); + assertEquals(of("value1"), headers.get(of("name1"))); + } + + @Test + public void getWithDefaultValueWorks() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + + assertEquals(of("value1"), headers.get(of("name1"), of("defaultvalue"))); + assertEquals(of("defaultvalue"), headers.get(of("noname"), of("defaultvalue"))); + } + + @Test + public void setShouldOverWritePreviousValue() { + TestDefaultHeaders headers = newInstance(); + headers.set(of("name"), of("value1")); + headers.set(of("name"), of("value2")); + assertEquals(1, headers.size()); + assertEquals(1, headers.getAll(of("name")).size()); + assertEquals(of("value2"), headers.getAll(of("name")).get(0)); + assertEquals(of("value2"), headers.get(of("name"))); + } + + @Test + public void setAllShouldOverwriteSomeAndLeaveOthersUntouched() { + TestDefaultHeaders h1 = newInstance(); + + h1.add(of("name1"), of("value1")); + 
h1.add(of("name2"), of("value2")); + h1.add(of("name2"), of("value3")); + h1.add(of("name3"), of("value4")); + + TestDefaultHeaders h2 = newInstance(); + h2.add(of("name1"), of("value5")); + h2.add(of("name2"), of("value6")); + h2.add(of("name1"), of("value7")); + + TestDefaultHeaders expected = newInstance(); + expected.add(of("name1"), of("value5")); + expected.add(of("name2"), of("value6")); + expected.add(of("name1"), of("value7")); + expected.add(of("name3"), of("value4")); + + h1.setAll(h2); + + assertEquals(expected, h1); + } + + @Test + public void headersWithSameNamesAndValuesShouldBeEquivalent() { + TestDefaultHeaders headers1 = newInstance(); + headers1.add(of("name1"), of("value1")); + headers1.add(of("name2"), of("value2")); + headers1.add(of("name2"), of("value3")); + + TestDefaultHeaders headers2 = newInstance(); + headers2.add(of("name1"), of("value1")); + headers2.add(of("name2"), of("value2")); + headers2.add(of("name2"), of("value3")); + + assertEquals(headers1, headers2); + assertEquals(headers2, headers1); + assertEquals(headers1, headers1); + assertEquals(headers2, headers2); + assertEquals(headers1.hashCode(), headers2.hashCode()); + assertEquals(headers1.hashCode(), headers1.hashCode()); + assertEquals(headers2.hashCode(), headers2.hashCode()); + } + + @Test + public void emptyHeadersShouldBeEqual() { + TestDefaultHeaders headers1 = newInstance(); + TestDefaultHeaders headers2 = newInstance(); + assertNotSame(headers1, headers2); + assertEquals(headers1, headers2); + assertEquals(headers1.hashCode(), headers2.hashCode()); + } + + @Test + public void headersWithSameNamesButDifferentValuesShouldNotBeEquivalent() { + TestDefaultHeaders headers1 = newInstance(); + headers1.add(of("name1"), of("value1")); + TestDefaultHeaders headers2 = newInstance(); + headers1.add(of("name1"), of("value2")); + assertNotEquals(headers1, headers2); + } + + @Test + public void subsetOfHeadersShouldNotBeEquivalent() { + TestDefaultHeaders headers1 = newInstance(); 
+ headers1.add(of("name1"), of("value1")); + headers1.add(of("name2"), of("value2")); + TestDefaultHeaders headers2 = newInstance(); + headers1.add(of("name1"), of("value1")); + assertNotEquals(headers1, headers2); + } + + @Test + public void headersWithDifferentNamesAndValuesShouldNotBeEquivalent() { + TestDefaultHeaders h1 = newInstance(); + h1.set(of("name1"), of("value1")); + TestDefaultHeaders h2 = newInstance(); + h2.set(of("name2"), of("value2")); + assertNotEquals(h1, h2); + assertNotEquals(h2, h1); + assertEquals(h1, h1); + assertEquals(h2, h2); + } + + @Test + public void iterateEmptyHeadersShouldThrow() { + final Iterator> iterator = newInstance().iterator(); + assertFalse(iterator.hasNext()); + assertThrows(NoSuchElementException.class, new Executable() { + @Override + public void execute() { + iterator.next(); + } + }); + } + + @Test + public void iteratorShouldReturnAllNameValuePairs() { + TestDefaultHeaders headers1 = newInstance(); + headers1.add(of("name1"), of("value1"), of("value2")); + headers1.add(of("name2"), of("value3")); + headers1.add(of("name3"), of("value4"), of("value5"), of("value6")); + headers1.add(of("name1"), of("value7"), of("value8")); + assertEquals(8, headers1.size()); + + TestDefaultHeaders headers2 = newInstance(); + for (Entry entry : headers1) { + headers2.add(entry.getKey(), entry.getValue()); + } + + assertEquals(headers1, headers2); + } + + @Test + public void iteratorSetValueShouldChangeHeaderValue() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1"), of("value2"), of("value3")); + headers.add(of("name2"), of("value4")); + assertEquals(4, headers.size()); + + Iterator> iter = headers.iterator(); + while (iter.hasNext()) { + Entry header = iter.next(); + if (of("name1").equals(header.getKey()) && of("value2").equals(header.getValue())) { + header.setValue(of("updatedvalue2")); + assertEquals(of("updatedvalue2"), header.getValue()); + } + if (of("name1").equals(header.getKey()) && 
of("value3").equals(header.getValue())) { + header.setValue(of("updatedvalue3")); + assertEquals(of("updatedvalue3"), header.getValue()); + } + } + + assertEquals(4, headers.size()); + assertTrue(headers.contains(of("name1"), of("updatedvalue2"))); + assertFalse(headers.contains(of("name1"), of("value2"))); + assertTrue(headers.contains(of("name1"), of("updatedvalue3"))); + assertFalse(headers.contains(of("name1"), of("value3"))); + } + + @Test + public void testEntryEquals() { + Map.Entry same1 = newInstance().add("name", "value").iterator().next(); + Map.Entry same2 = newInstance().add("name", "value").iterator().next(); + assertEquals(same1, same2); + assertEquals(same1.hashCode(), same2.hashCode()); + + Map.Entry nameDifferent1 = newInstance().add("name1", "value").iterator().next(); + Map.Entry nameDifferent2 = newInstance().add("name2", "value").iterator().next(); + assertNotEquals(nameDifferent1, nameDifferent2); + assertNotEquals(nameDifferent1.hashCode(), nameDifferent2.hashCode()); + + Map.Entry valueDifferent1 = newInstance().add("name", "value1").iterator().next(); + Map.Entry valueDifferent2 = newInstance().add("name", "value2").iterator().next(); + assertNotEquals(valueDifferent1, valueDifferent2); + assertNotEquals(valueDifferent1.hashCode(), valueDifferent2.hashCode()); + } + + @Test + public void getAllReturnsEmptyListForUnknownName() { + TestDefaultHeaders headers = newInstance(); + assertEquals(0, headers.getAll(of("noname")).size()); + } + + @Test + public void setHeadersShouldClearAndOverwrite() { + TestDefaultHeaders headers1 = newInstance(); + headers1.add(of("name"), of("value")); + + TestDefaultHeaders headers2 = newInstance(); + headers2.add(of("name"), of("newvalue")); + headers2.add(of("name1"), of("value1")); + + headers1.set(headers2); + assertEquals(headers1, headers2); + } + + @Test + public void setAllHeadersShouldOnlyOverwriteHeaders() { + TestDefaultHeaders headers1 = newInstance(); + headers1.add(of("name"), of("value")); + 
headers1.add(of("name1"), of("value1")); + + TestDefaultHeaders headers2 = newInstance(); + headers2.add(of("name"), of("newvalue")); + headers2.add(of("name2"), of("value2")); + + TestDefaultHeaders expected = newInstance(); + expected.add(of("name"), of("newvalue")); + expected.add(of("name1"), of("value1")); + expected.add(of("name2"), of("value2")); + + headers1.setAll(headers2); + assertEquals(headers1, expected); + } + + @Test + public void testAddSelf() { + final TestDefaultHeaders headers = newInstance(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + headers.add(headers); + } + }); + } + + @Test + public void testSetSelfIsNoOp() { + TestDefaultHeaders headers = newInstance(); + headers.add("name", "value"); + headers.set(headers); + assertEquals(1, headers.size()); + } + + @Test + public void testToString() { + TestDefaultHeaders headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name1"), of("value2")); + headers.add(of("name2"), of("value3")); + assertEquals("TestDefaultHeaders[name1: value1, name1: value2, name2: value3]", headers.toString()); + + headers = newInstance(); + headers.add(of("name1"), of("value1")); + headers.add(of("name2"), of("value2")); + headers.add(of("name3"), of("value3")); + assertEquals("TestDefaultHeaders[name1: value1, name2: value2, name3: value3]", headers.toString()); + + headers = newInstance(); + headers.add(of("name1"), of("value1")); + assertEquals("TestDefaultHeaders[name1: value1]", headers.toString()); + + headers = newInstance(); + assertEquals("TestDefaultHeaders[]", headers.toString()); + } + + @Test + public void testNotThrowWhenConvertFails() { + TestDefaultHeaders headers = new TestDefaultHeaders(new ValueConverter() { + @Override + public CharSequence convertObject(Object value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertBoolean(boolean value) { + throw new 
IllegalArgumentException(); + } + + @Override + public boolean convertToBoolean(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertByte(byte value) { + throw new IllegalArgumentException(); + } + + @Override + public byte convertToByte(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertChar(char value) { + throw new IllegalArgumentException(); + } + + @Override + public char convertToChar(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertShort(short value) { + throw new IllegalArgumentException(); + } + + @Override + public short convertToShort(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertInt(int value) { + throw new IllegalArgumentException(); + } + + @Override + public int convertToInt(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertLong(long value) { + throw new IllegalArgumentException(); + } + + @Override + public long convertToLong(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertTimeMillis(long value) { + throw new IllegalArgumentException(); + } + + @Override + public long convertToTimeMillis(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertFloat(float value) { + throw new IllegalArgumentException(); + } + + @Override + public float convertToFloat(CharSequence value) { + throw new IllegalArgumentException(); + } + + @Override + public CharSequence convertDouble(double value) { + throw new IllegalArgumentException(); + } + + @Override + public double convertToDouble(CharSequence value) { + throw new IllegalArgumentException(); + } + }); + headers.set("name1", ""); + assertNull(headers.getInt("name1")); + assertEquals(1, headers.getInt("name1", 
1)); + + assertNull(headers.getBoolean("")); + assertFalse(headers.getBoolean("name1", false)); + + assertNull(headers.getByte("name1")); + assertEquals(1, headers.getByte("name1", (byte) 1)); + + assertNull(headers.getChar("name")); + assertEquals('n', headers.getChar("name1", 'n')); + + assertNull(headers.getDouble("name")); + assertEquals(1, headers.getDouble("name1", 1), 0); + + assertNull(headers.getFloat("name")); + assertEquals(Float.MAX_VALUE, headers.getFloat("name1", Float.MAX_VALUE), 0); + + assertNull(headers.getLong("name")); + assertEquals(Long.MAX_VALUE, headers.getLong("name1", Long.MAX_VALUE)); + + assertNull(headers.getShort("name")); + assertEquals(Short.MAX_VALUE, headers.getShort("name1", Short.MAX_VALUE)); + + assertNull(headers.getTimeMillis("name")); + assertEquals(Long.MAX_VALUE, headers.getTimeMillis("name1", Long.MAX_VALUE)); + } + + @Test + public void testGetBooleanInvalidValue() { + TestDefaultHeaders headers = newInstance(); + headers.set("name1", "invalid"); + headers.set("name2", new AsciiString("invalid")); + headers.set("name3", new StringBuilder("invalid")); + + assertFalse(headers.getBoolean("name1", false)); + assertFalse(headers.getBoolean("name2", false)); + assertFalse(headers.getBoolean("name3", false)); + } + + @Test + public void testGetBooleanFalseValue() { + TestDefaultHeaders headers = newInstance(); + headers.set("name1", "false"); + headers.set("name2", new AsciiString("false")); + headers.set("name3", new StringBuilder("false")); + + assertFalse(headers.getBoolean("name1", true)); + assertFalse(headers.getBoolean("name2", true)); + assertFalse(headers.getBoolean("name3", true)); + } + + @Test + public void testGetBooleanTrueValue() { + TestDefaultHeaders headers = newInstance(); + headers.set("name1", "true"); + headers.set("name2", new AsciiString("true")); + headers.set("name3", new StringBuilder("true")); + + assertTrue(headers.getBoolean("name1", false)); + assertTrue(headers.getBoolean("name2", false)); + 
assertTrue(headers.getBoolean("name3", false)); + } + + @Test + public void handlingOfHeaderNameHashCollisions() { + TestDefaultHeaders headers = new TestDefaultHeaders(new HashingStrategy() { + @Override + public int hashCode(CharSequence obj) { + return 0; // Degenerate hashing strategy to enforce collisions. + } + + @Override + public boolean equals(CharSequence a, CharSequence b) { + return a.equals(b); + } + }); + + headers.add("Cookie", "a=b; c=d; e=f"); + headers.add("other", "text/plain"); // Add another header which will be saved in the same entries[index] + + simulateCookieSplitting(headers); + List cookies = headers.getAll("Cookie"); + + assertThat(cookies, hasSize(3)); + assertThat(cookies, containsInAnyOrder((CharSequence) "a=b", "c=d", "e=f")); + } + + /** + * Split up cookies into individual cookie crumb headers. + */ + static void simulateCookieSplitting(TestDefaultHeaders headers) { + Iterator cookieItr = headers.valueIterator("Cookie"); + if (!cookieItr.hasNext()) { + return; + } + // We want to avoid "concurrent modifications" of the headers while we are iterating. So we insert crumbs + // into an intermediate collection and insert them after the split process concludes. 
+ List cookiesToAdd = new ArrayList(); + while (cookieItr.hasNext()) { + //noinspection DynamicRegexReplaceableByCompiledPattern + String[] cookies = cookieItr.next().toString().split("; "); + cookiesToAdd.addAll(asList(cookies)); + cookieItr.remove(); + } + for (CharSequence crumb : cookiesToAdd) { + headers.add("Cookie", crumb); + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/DelimiterBasedFrameDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/DelimiterBasedFrameDecoderTest.java new file mode 100644 index 0000000..d3dd830 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/DelimiterBasedFrameDecoderTest.java @@ -0,0 +1,128 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import java.nio.charset.Charset; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class DelimiterBasedFrameDecoderTest { + + @Test + public void testMultipleLinesStrippedDelimiters() { + EmbeddedChannel ch = new EmbeddedChannel(new DelimiterBasedFrameDecoder(8192, true, + Delimiters.lineDelimiter())); + ch.writeInbound(Unpooled.copiedBuffer("TestLine\r\ng\r\n", Charset.defaultCharset())); + + ByteBuf buf = ch.readInbound(); + assertEquals("TestLine", buf.toString(Charset.defaultCharset())); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("g", buf2.toString(Charset.defaultCharset())); + assertNull(ch.readInbound()); + ch.finish(); + + buf.release(); + buf2.release(); + } + + @Test + public void testIncompleteLinesStrippedDelimiters() { + EmbeddedChannel ch = new EmbeddedChannel(new DelimiterBasedFrameDecoder(8192, true, + Delimiters.lineDelimiter())); + ch.writeInbound(Unpooled.copiedBuffer("Test", Charset.defaultCharset())); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.copiedBuffer("Line\r\ng\r\n", Charset.defaultCharset())); + + ByteBuf buf = ch.readInbound(); + assertEquals("TestLine", buf.toString(Charset.defaultCharset())); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("g", buf2.toString(Charset.defaultCharset())); + assertNull(ch.readInbound()); + ch.finish(); + + buf.release(); + buf2.release(); + } + + @Test + public void testMultipleLines() { + EmbeddedChannel ch = new EmbeddedChannel(new DelimiterBasedFrameDecoder(8192, false, + Delimiters.lineDelimiter())); + ch.writeInbound(Unpooled.copiedBuffer("TestLine\r\ng\r\n", Charset.defaultCharset())); + + ByteBuf buf = ch.readInbound(); 
+ assertEquals("TestLine\r\n", buf.toString(Charset.defaultCharset())); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("g\r\n", buf2.toString(Charset.defaultCharset())); + assertNull(ch.readInbound()); + ch.finish(); + + buf.release(); + buf2.release(); + } + + @Test + public void testIncompleteLines() { + EmbeddedChannel ch = new EmbeddedChannel(new DelimiterBasedFrameDecoder(8192, false, + Delimiters.lineDelimiter())); + ch.writeInbound(Unpooled.copiedBuffer("Test", Charset.defaultCharset())); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.copiedBuffer("Line\r\ng\r\n", Charset.defaultCharset())); + + ByteBuf buf = ch.readInbound(); + assertEquals("TestLine\r\n", buf.toString(Charset.defaultCharset())); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("g\r\n", buf2.toString(Charset.defaultCharset())); + assertNull(ch.readInbound()); + ch.finish(); + + buf.release(); + buf2.release(); + } + + @Test + public void testDecode() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel( + new DelimiterBasedFrameDecoder(8192, true, Delimiters.lineDelimiter())); + + ch.writeInbound(Unpooled.copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII)); + + ByteBuf buf = ch.readInbound(); + assertEquals("first", buf.toString(CharsetUtil.US_ASCII)); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("second", buf2.toString(CharsetUtil.US_ASCII)); + assertNull(ch.readInbound()); + ch.finish(); + + ReferenceCountUtil.release(ch.readInbound()); + + buf.release(); + buf2.release(); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/EmptyHeadersTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/EmptyHeadersTest.java new file mode 100644 index 0000000..c90f5c6 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/EmptyHeadersTest.java @@ -0,0 +1,585 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + 
* "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.codec; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class EmptyHeadersTest { + + private static final TestEmptyHeaders HEADERS = new TestEmptyHeaders(); + + @Test + public void testAddStringValue() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.add("name", "value"); + } + }); + } + + @Test + public void testAddStringValues() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.add("name", "value1", "value2"); + } + }); + } + + @Test + public void testAddStringValuesIterable() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.add("name", Arrays.asList("value1", "value2")); + } + }); + } + + @Test + public void testAddBoolean() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addBoolean("name", 
true); + } + }); + } + + @Test + public void testAddByte() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addByte("name", (byte) 1); + } + }); + } + + @Test + public void testAddChar() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addChar("name", 'a'); + } + }); + } + + @Test + public void testAddDouble() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addDouble("name", 0); + } + }); + } + + @Test + public void testAddFloat() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addFloat("name", 0); + } + }); + } + + @Test + public void testAddInt() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addInt("name", 0); + } + }); + } + + @Test + public void testAddLong() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addLong("name", 0); + } + }); + } + + @Test + public void testAddShort() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addShort("name", (short) 0); + } + }); + } + + @Test + public void testAddTimeMillis() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.addTimeMillis("name", 0); + } + }); + } + + @Test + public void testSetStringValue() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.set("name", "value"); + } + }); + } + + @Test + public void testSetStringValues() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.set("name", "value1", 
"value2"); + } + }); + } + + @Test + public void testSetStringValuesIterable() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.set("name", Arrays.asList("value1", "value2")); + } + }); + } + + @Test + public void testSetBoolean() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setBoolean("name", true); + } + }); + } + + @Test + public void testSetByte() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setByte("name", (byte) 1); + } + }); + } + + @Test + public void testSetChar() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setChar("name", 'a'); + } + }); + } + + @Test + public void testSetDouble() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setDouble("name", 0); + } + }); + } + + @Test + public void testSetFloat() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setFloat("name", 0); + } + }); + } + + @Test + public void testSetInt() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setInt("name", 0); + } + }); + } + + @Test + public void testSetLong() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setLong("name", 0); + } + }); + } + + @Test + public void testSetShort() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setShort("name", (short) 0); + } + }); + } + + @Test + public void testSetTimeMillis() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + 
HEADERS.setTimeMillis("name", 0); + } + }); + } + + @Test + public void testSetAll() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.setAll(new TestEmptyHeaders()); + } + }); + } + + @Test + public void testSet() { + assertThrows(UnsupportedOperationException.class, new Executable() { + @Override + public void execute() { + HEADERS.set(new TestEmptyHeaders()); + } + }); + } + + @Test + public void testGet() { + assertNull(HEADERS.get("name1")); + } + + @Test + public void testGetDefault() { + assertEquals("default", HEADERS.get("name1", "default")); + } + + @Test + public void testGetAndRemove() { + assertNull(HEADERS.getAndRemove("name1")); + } + + @Test + public void testGetAndRemoveDefault() { + assertEquals("default", HEADERS.getAndRemove("name1", "default")); + } + + @Test + public void testGetAll() { + assertEquals(Collections.emptyList(), HEADERS.getAll("name1")); + } + + @Test + public void testGetAllAndRemove() { + assertEquals(Collections.emptyList(), HEADERS.getAllAndRemove("name1")); + } + + @Test + public void testGetBoolean() { + assertNull(HEADERS.getBoolean("name1")); + } + + @Test + public void testGetBooleanDefault() { + assertTrue(HEADERS.getBoolean("name1", true)); + } + + @Test + public void testGetBooleanAndRemove() { + assertNull(HEADERS.getBooleanAndRemove("name1")); + } + + @Test + public void testGetBooleanAndRemoveDefault() { + assertTrue(HEADERS.getBooleanAndRemove("name1", true)); + } + + @Test + public void testGetByte() { + assertNull(HEADERS.getByte("name1")); + } + + @Test + public void testGetByteDefault() { + assertEquals((byte) 0, HEADERS.getByte("name1", (byte) 0)); + } + + @Test + public void testGetByteAndRemove() { + assertNull(HEADERS.getByteAndRemove("name1")); + } + + @Test + public void testGetByteAndRemoveDefault() { + assertEquals((byte) 0, HEADERS.getByteAndRemove("name1", (byte) 0)); + } + + @Test + public void testGetChar() { + 
assertNull(HEADERS.getChar("name1")); + } + + @Test + public void testGetCharDefault() { + assertEquals('x', HEADERS.getChar("name1", 'x')); + } + + @Test + public void testGetCharAndRemove() { + assertNull(HEADERS.getCharAndRemove("name1")); + } + + @Test + public void testGetCharAndRemoveDefault() { + assertEquals('x', HEADERS.getCharAndRemove("name1", 'x')); + } + + @Test + public void testGetDouble() { + assertNull(HEADERS.getDouble("name1")); + } + + @Test + public void testGetDoubleDefault() { + assertEquals(1, HEADERS.getDouble("name1", 1), 0); + } + + @Test + public void testGetDoubleAndRemove() { + assertNull(HEADERS.getDoubleAndRemove("name1")); + } + + @Test + public void testGetDoubleAndRemoveDefault() { + assertEquals(1, HEADERS.getDoubleAndRemove("name1", 1), 0); + } + + @Test + public void testGetFloat() { + assertNull(HEADERS.getFloat("name1")); + } + + @Test + public void testGetFloatDefault() { + assertEquals(1, HEADERS.getFloat("name1", 1), 0); + } + + @Test + public void testGetFloatAndRemove() { + assertNull(HEADERS.getFloatAndRemove("name1")); + } + + @Test + public void testGetFloatAndRemoveDefault() { + assertEquals(1, HEADERS.getFloatAndRemove("name1", 1), 0); + } + + @Test + public void testGetInt() { + assertNull(HEADERS.getInt("name1")); + } + + @Test + public void testGetIntDefault() { + assertEquals(1, HEADERS.getInt("name1", 1)); + } + + @Test + public void testGetIntAndRemove() { + assertNull(HEADERS.getIntAndRemove("name1")); + } + + @Test + public void testGetIntAndRemoveDefault() { + assertEquals(1, HEADERS.getIntAndRemove("name1", 1)); + } + + @Test + public void testGetLong() { + assertNull(HEADERS.getLong("name1")); + } + + @Test + public void testGetLongDefault() { + assertEquals(1, HEADERS.getLong("name1", 1)); + } + + @Test + public void testGetLongAndRemove() { + assertNull(HEADERS.getLongAndRemove("name1")); + } + + @Test + public void testGetLongAndRemoveDefault() { + assertEquals(1, HEADERS.getLongAndRemove("name1", 1)); 
+ } + + @Test + public void testGetShort() { + assertNull(HEADERS.getShort("name1")); + } + + @Test + public void testGetShortDefault() { + assertEquals(1, HEADERS.getShort("name1", (short) 1)); + } + + @Test + public void testGetShortAndRemove() { + assertNull(HEADERS.getShortAndRemove("name1")); + } + + @Test + public void testGetShortAndRemoveDefault() { + assertEquals(1, HEADERS.getShortAndRemove("name1", (short) 1)); + } + + @Test + public void testGetTimeMillis() { + assertNull(HEADERS.getTimeMillis("name1")); + } + + @Test + public void testGetTimeMillisDefault() { + assertEquals(1, HEADERS.getTimeMillis("name1", 1)); + } + + @Test + public void testGetTimeMillisAndRemove() { + assertNull(HEADERS.getTimeMillisAndRemove("name1")); + } + + @Test + public void testGetTimeMillisAndRemoveDefault() { + assertEquals(1, HEADERS.getTimeMillisAndRemove("name1", 1)); + } + + @Test + public void testContains() { + assertFalse(HEADERS.contains("name1")); + } + + @Test + public void testContainsWithValue() { + assertFalse(HEADERS.contains("name1", "value1")); + } + + @Test + public void testContainsBoolean() { + assertFalse(HEADERS.containsBoolean("name1", false)); + } + + @Test + public void testContainsByte() { + assertFalse(HEADERS.containsByte("name1", (byte) 'x')); + } + + @Test + public void testContainsChar() { + assertFalse(HEADERS.containsChar("name1", 'x')); + } + + @Test + public void testContainsDouble() { + assertFalse(HEADERS.containsDouble("name1", 1)); + } + + @Test + public void testContainsFloat() { + assertFalse(HEADERS.containsFloat("name1", 1)); + } + + @Test + public void testContainsInt() { + assertFalse(HEADERS.containsInt("name1", 1)); + } + + @Test + public void testContainsLong() { + assertFalse(HEADERS.containsLong("name1", 1)); + } + + @Test + public void testContainsShort() { + assertFalse(HEADERS.containsShort("name1", (short) 1)); + } + + @Test + public void testContainsTimeMillis() { + assertFalse(HEADERS.containsTimeMillis("name1", 1)); + 
} + + @Test + public void testContainsObject() { + assertFalse(HEADERS.containsObject("name1", "")); + } + + @Test + public void testIsEmpty() { + assertTrue(HEADERS.isEmpty()); + } + + @Test + public void testClear() { + assertSame(HEADERS, HEADERS.clear()); + } + + @Test + public void testSize() { + assertEquals(0, HEADERS.size()); + } + + @Test + public void testValueIterator() { + assertFalse(HEADERS.valueIterator("name1").hasNext()); + } + + private static final class TestEmptyHeaders extends EmptyHeaders { } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/LengthFieldBasedFrameDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/LengthFieldBasedFrameDecoderTest.java new file mode 100644 index 0000000..afc86b3 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/LengthFieldBasedFrameDecoderTest.java @@ -0,0 +1,89 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class LengthFieldBasedFrameDecoderTest { + + @Test + public void testDiscardTooLongFrame1() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(32); + for (int i = 0; i < 32; i++) { + buf.writeByte(i); + } + buf.writeInt(1); + buf.writeByte('a'); + EmbeddedChannel channel = new EmbeddedChannel(new LengthFieldBasedFrameDecoder(16, 0, 4)); + try { + channel.writeInbound(buf); + fail(); + } catch (TooLongFrameException e) { + // expected + } + assertTrue(channel.finish()); + + ByteBuf b = channel.readInbound(); + assertEquals(5, b.readableBytes()); + assertEquals(1, b.readInt()); + assertEquals('a', b.readByte()); + b.release(); + + assertNull(channel.readInbound()); + channel.finish(); + } + + @Test + public void testDiscardTooLongFrame2() { + ByteBuf buf = Unpooled.buffer(); + buf.writeInt(32); + for (int i = 0; i < 32; i++) { + buf.writeByte(i); + } + buf.writeInt(1); + buf.writeByte('a'); + EmbeddedChannel channel = new EmbeddedChannel(new LengthFieldBasedFrameDecoder(16, 0, 4)); + try { + channel.writeInbound(buf.readRetainedSlice(14)); + fail(); + } catch (TooLongFrameException e) { + // expected + } + assertTrue(channel.writeInbound(buf.readRetainedSlice(buf.readableBytes()))); + + assertTrue(channel.finish()); + + ByteBuf b = channel.readInbound(); + assertEquals(5, b.readableBytes()); + assertEquals(1, b.readInt()); + assertEquals('a', b.readByte()); + b.release(); + + assertNull(channel.readInbound()); + channel.finish(); + + buf.release(); + } +} diff --git 
a/netty-handler-codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java new file mode 100644 index 0000000..82ff201 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/LineBasedFrameDecoderTest.java @@ -0,0 +1,216 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.Test; + +import static io.netty.buffer.Unpooled.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.CoreMatchers.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class LineBasedFrameDecoderTest { + @Test + public void testDecodeWithStrip() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(8192, true, false)); + + ch.writeInbound(copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII)); + + ByteBuf buf = ch.readInbound(); + 
assertEquals("first", buf.toString(CharsetUtil.US_ASCII)); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("second", buf2.toString(CharsetUtil.US_ASCII)); + assertNull(ch.readInbound()); + ch.finish(); + + ReferenceCountUtil.release(ch.readInbound()); + + buf.release(); + buf2.release(); + } + + @Test + public void testDecodeWithoutStrip() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(8192, false, false)); + + ch.writeInbound(copiedBuffer("first\r\nsecond\nthird", CharsetUtil.US_ASCII)); + + ByteBuf buf = ch.readInbound(); + assertEquals("first\r\n", buf.toString(CharsetUtil.US_ASCII)); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("second\n", buf2.toString(CharsetUtil.US_ASCII)); + assertNull(ch.readInbound()); + ch.finish(); + ReferenceCountUtil.release(ch.readInbound()); + + buf.release(); + buf2.release(); + } + + @Test + public void testTooLongLine1() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(16, false, false)); + + try { + ch.writeInbound(copiedBuffer("12345678901234567890\r\nfirst\nsecond", CharsetUtil.US_ASCII)); + fail(); + } catch (Exception e) { + assertThat(e, is(instanceOf(TooLongFrameException.class))); + } + + ByteBuf buf = ch.readInbound(); + ByteBuf buf2 = copiedBuffer("first\n", CharsetUtil.US_ASCII); + assertThat(buf, is(buf2)); + assertThat(ch.finish(), is(false)); + + buf.release(); + buf2.release(); + } + + @Test + public void testTooLongLine2() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(16, false, false)); + + assertFalse(ch.writeInbound(copiedBuffer("12345678901234567", CharsetUtil.US_ASCII))); + try { + ch.writeInbound(copiedBuffer("890\r\nfirst\r\n", CharsetUtil.US_ASCII)); + fail(); + } catch (Exception e) { + assertThat(e, is(instanceOf(TooLongFrameException.class))); + } + + ByteBuf buf = ch.readInbound(); + ByteBuf buf2 = copiedBuffer("first\r\n", CharsetUtil.US_ASCII); + assertThat(buf, 
is(buf2)); + assertThat(ch.finish(), is(false)); + + buf.release(); + buf2.release(); + } + + @Test + public void testTooLongLineWithFailFast() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(16, false, true)); + + try { + ch.writeInbound(copiedBuffer("12345678901234567", CharsetUtil.US_ASCII)); + fail(); + } catch (Exception e) { + assertThat(e, is(instanceOf(TooLongFrameException.class))); + } + + assertThat(ch.writeInbound(copiedBuffer("890", CharsetUtil.US_ASCII)), is(false)); + assertThat(ch.writeInbound(copiedBuffer("123\r\nfirst\r\n", CharsetUtil.US_ASCII)), is(true)); + + ByteBuf buf = ch.readInbound(); + ByteBuf buf2 = copiedBuffer("first\r\n", CharsetUtil.US_ASCII); + assertThat(buf, is(buf2)); + assertThat(ch.finish(), is(false)); + + buf.release(); + buf2.release(); + } + + @Test + public void testDecodeSplitsCorrectly() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(8192, false, false)); + + assertTrue(ch.writeInbound(copiedBuffer("line\r\n.\r\n", CharsetUtil.US_ASCII))); + + ByteBuf buf = ch.readInbound(); + assertEquals("line\r\n", buf.toString(CharsetUtil.US_ASCII)); + + ByteBuf buf2 = ch.readInbound(); + assertEquals(".\r\n", buf2.toString(CharsetUtil.US_ASCII)); + assertFalse(ch.finishAndReleaseAll()); + + buf.release(); + buf2.release(); + } + + @Test + public void testFragmentedDecode() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(8192, false, false)); + + assertFalse(ch.writeInbound(copiedBuffer("huu", CharsetUtil.US_ASCII))); + assertNull(ch.readInbound()); + + assertFalse(ch.writeInbound(copiedBuffer("haa\r", CharsetUtil.US_ASCII))); + assertNull(ch.readInbound()); + + assertTrue(ch.writeInbound(copiedBuffer("\nhuuhaa\r\n", CharsetUtil.US_ASCII))); + ByteBuf buf = ch.readInbound(); + assertEquals("huuhaa\r\n", buf.toString(CharsetUtil.US_ASCII)); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("huuhaa\r\n", 
buf2.toString(CharsetUtil.US_ASCII)); + assertFalse(ch.finishAndReleaseAll()); + + buf.release(); + buf2.release(); + } + + @Test + public void testEmptyLine() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(8192, true, false)); + + assertTrue(ch.writeInbound(copiedBuffer("\nabcna\r\n", CharsetUtil.US_ASCII))); + + ByteBuf buf = ch.readInbound(); + assertEquals("", buf.toString(CharsetUtil.US_ASCII)); + + ByteBuf buf2 = ch.readInbound(); + assertEquals("abcna", buf2.toString(CharsetUtil.US_ASCII)); + + assertFalse(ch.finishAndReleaseAll()); + + buf.release(); + buf2.release(); + } + + @Test + public void testNotFailFast() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new LineBasedFrameDecoder(2, false, false)); + assertFalse(ch.writeInbound(wrappedBuffer(new byte[] { 0, 1, 2 }))); + assertFalse(ch.writeInbound(wrappedBuffer(new byte[]{ 3, 4 }))); + try { + ch.writeInbound(wrappedBuffer(new byte[] { '\n' })); + fail(); + } catch (TooLongFrameException expected) { + // Expected once we received a full frame. + } + assertFalse(ch.writeInbound(wrappedBuffer(new byte[] { '5' }))); + assertTrue(ch.writeInbound(wrappedBuffer(new byte[] { '\n' }))); + + ByteBuf expected = wrappedBuffer(new byte[] { '5', '\n' }); + ByteBuf buffer = ch.readInbound(); + assertEquals(expected, buffer); + expected.release(); + buffer.release(); + + assertFalse(ch.finish()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/MessageAggregatorTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/MessageAggregatorTest.java new file mode 100644 index 0000000..69ef074 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/MessageAggregatorTest.java @@ -0,0 +1,137 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.codec; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.DefaultByteBufHolder; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.function.Executable; + +public class MessageAggregatorTest { + private static final class ReadCounter extends ChannelOutboundHandlerAdapter { + int value; + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + value++; + ctx.read(); + } + } + + abstract static class MockMessageAggregator + extends MessageAggregator { + + protected MockMessageAggregator() { + super(1024); + } + + @Override + protected ByteBufHolder beginAggregation(ByteBufHolder start, ByteBuf content) throws Exception { + return start.replace(content); + } + } + + private static ByteBufHolder message(String string) { + return new DefaultByteBufHolder( + Unpooled.copiedBuffer(string, CharsetUtil.US_ASCII)); + } + + @Test + public void testReadFlowManagement() throws Exception { + ReadCounter counter = new ReadCounter(); + ByteBufHolder 
first = message("first"); + ByteBufHolder chunk = message("chunk"); + ByteBufHolder last = message("last"); + + MockMessageAggregator agg = spy(MockMessageAggregator.class); + when(agg.isStartMessage(first)).thenReturn(true); + when(agg.isContentMessage(chunk)).thenReturn(true); + when(agg.isContentMessage(last)).thenReturn(true); + when(agg.isLastContentMessage(last)).thenReturn(true); + + EmbeddedChannel embedded = new EmbeddedChannel(counter, agg); + embedded.config().setAutoRead(false); + + assertFalse(embedded.writeInbound(first)); + assertFalse(embedded.writeInbound(chunk)); + assertTrue(embedded.writeInbound(last)); + + assertEquals(3, counter.value); // 2 reads issued from MockMessageAggregator + // 1 read issued from EmbeddedChannel constructor + + ByteBufHolder all = new DefaultByteBufHolder(Unpooled.wrappedBuffer( + first.content().retain(), chunk.content().retain(), last.content().retain())); + ByteBufHolder out = embedded.readInbound(); + + assertEquals(all, out); + assertTrue(all.release() && out.release()); + assertFalse(embedded.finish()); + } + + @Test + public void testCloseWhileAggregating() throws Exception { + ReadCounter counter = new ReadCounter(); + ByteBufHolder first = new TestMessage(Unpooled.copiedBuffer("first", CharsetUtil.US_ASCII)); + + MockMessageAggregator agg = spy(MockMessageAggregator.class); + when(agg.isStartMessage(first)).thenReturn(true); + when(agg.isLastContentMessage(first)).thenReturn(false); + + final EmbeddedChannel embedded = new EmbeddedChannel(counter, agg); + embedded.config().setAutoRead(false); + + assertFalse(embedded.writeInbound(first)); + + assertEquals(2, counter.value); + assertThrows(PrematureChannelClosureException.class, new Executable() { + @Override + public void execute() { + embedded.finish(); + } + }); + assertEquals(0, first.refCnt()); + } + + private static final class TestMessage extends DefaultByteBufHolder implements DecoderResultProvider { + TestMessage(ByteBuf data) { + super(data); + } + + 
@Override + public DecoderResult decoderResult() { + return DecoderResult.SUCCESS; + } + + @Override + public void setDecoderResult(DecoderResult result) { + // NOOP + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java new file mode 100644 index 0000000..714600c --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/MessageToMessageEncoderTest.java @@ -0,0 +1,86 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; + +import java.util.List; + +public class MessageToMessageEncoderTest { + + /** + * Test-case for https://github.com/netty/netty/issues/1656 + */ + @Test + public void testException() { + final EmbeddedChannel channel = new EmbeddedChannel(new MessageToMessageEncoder() { + @Override + protected void encode(ChannelHandlerContext ctx, Object msg, List out) throws Exception { + throw new Exception(); + } + }); + assertThrows(EncoderException.class, new Executable() { + @Override + public void execute() { + channel.writeOutbound(new Object()); + } + }); + } + + @Test + public void testIntermediateWriteFailures() { + ChannelHandler encoder = new MessageToMessageEncoder() { + @Override + protected void encode(ChannelHandlerContext ctx, Object msg, List out) { + out.add(new Object()); + out.add(msg); + } + }; + + final Exception firstWriteException = new Exception(); + + ChannelHandler writeThrower = new ChannelOutboundHandlerAdapter() { + private boolean firstWritten; + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (firstWritten) { + ctx.write(msg, promise); + } else { + firstWritten = true; + promise.setFailure(firstWriteException); + } + } + }; + + EmbeddedChannel channel = new EmbeddedChannel(writeThrower, encoder); + Object msg = new Object(); + ChannelFuture write = channel.writeAndFlush(msg); + assertSame(firstWriteException, 
write.cause()); + assertSame(msg, channel.readOutbound()); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderByteBufTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderByteBufTest.java new file mode 100644 index 0000000..0c48f6e --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderByteBufTest.java @@ -0,0 +1,129 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.Signal; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ReplayingDecoderByteBufTest { + + /** + * See https://github.com/netty/netty/issues/445 + */ + @Test + public void testGetUnsignedByte() { + ByteBuf buf = Unpooled.copiedBuffer("TestBuffer", CharsetUtil.ISO_8859_1); + ReplayingDecoderByteBuf buffer = new ReplayingDecoderByteBuf(buf); + + boolean error; + int i = 0; + try { + for (;;) { + buffer.getUnsignedByte(i); + i++; + } + } catch (Signal e) { + error = true; + } + + assertTrue(error); + assertEquals(10, i); + + buf.release(); + } + + /** + * See https://github.com/netty/netty/issues/445 + */ + @Test + public void testGetByte() { + ByteBuf buf = Unpooled.copiedBuffer("TestBuffer", CharsetUtil.ISO_8859_1); + ReplayingDecoderByteBuf buffer = new ReplayingDecoderByteBuf(buf); + + boolean error; + int i = 0; + try { + for (;;) { + buffer.getByte(i); + i++; + } + } catch (Signal e) { + error = true; + } + + assertTrue(error); + assertEquals(10, i); + + buf.release(); + } + + /** + * See https://github.com/netty/netty/issues/445 + */ + @Test + public void testGetBoolean() { + ByteBuf buf = Unpooled.buffer(10); + while (buf.isWritable()) { + buf.writeBoolean(true); + } + ReplayingDecoderByteBuf buffer = new ReplayingDecoderByteBuf(buf); + + boolean error; + int i = 0; + try { + for (;;) { + buffer.getBoolean(i); + i++; + } + } catch (Signal e) { + error = true; + } + + assertTrue(error); + assertEquals(10, i); + + buf.release(); + } + + // See https://github.com/netty/netty/issues/13455 + @Test + void testRetainedSlice() { + ByteBuf buf = Unpooled.buffer(10); + int i = 0; + while (buf.isWritable()) { + buf.writeByte(i++); + } + ReplayingDecoderByteBuf buffer = new 
ReplayingDecoderByteBuf(buf); + ByteBuf slice = buffer.retainedSlice(0, 4); + assertEquals(2, slice.refCnt()); + + i = 0; + while (slice.isReadable()) { + assertEquals(i++, slice.readByte()); + } + slice.release(); + buf.release(); + assertEquals(0, slice.refCnt()); + assertEquals(0, buf.refCnt()); + assertEquals(0, buffer.refCnt()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java new file mode 100644 index 0000000..18dafd9 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/ReplayingDecoderTest.java @@ -0,0 +1,319 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; + +import java.util.List; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ReplayingDecoderTest { + + @Test + public void testLineProtocol() { + EmbeddedChannel ch = new EmbeddedChannel(new LineDecoder()); + + // Ordinary input + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 'A' })); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 'B' })); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 'C' })); + assertNull(ch.readInbound()); + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { '\n' })); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] { 'A', 'B', 'C' }); + ByteBuf buf2 = ch.readInbound(); + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + + // Truncated input + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 'A' })); + assertNull(ch.readInbound()); + + ch.finish(); + assertNull(ch.readInbound()); + } + + private static final class LineDecoder extends ReplayingDecoder { + + LineDecoder() { + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + ByteBuf msg = in.readBytes(in.bytesBefore((byte) '\n')); + out.add(msg); + in.skipBytes(1); + } + } + + @Test + public void 
testReplacement() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel(new BloatedLineDecoder()); + + // "AB" should be forwarded to LineDecoder by BloatedLineDecoder. + ch.writeInbound(Unpooled.wrappedBuffer(new byte[]{'A', 'B'})); + assertNull(ch.readInbound()); + + // "C\n" should be appended to "AB" so that LineDecoder decodes it correctly. + ch.writeInbound(Unpooled.wrappedBuffer(new byte[]{'C', '\n'})); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] { 'A', 'B', 'C' }); + ByteBuf buf2 = ch.readInbound(); + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + + ch.finish(); + assertNull(ch.readInbound()); + } + + private static final class BloatedLineDecoder extends ChannelInboundHandlerAdapter { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.pipeline().replace(this, "less-bloated", new LineDecoder()); + ctx.pipeline().fireChannelRead(msg); + } + } + + @Test + public void testSingleDecode() throws Exception { + LineDecoder decoder = new LineDecoder(); + decoder.setSingleDecode(true); + EmbeddedChannel ch = new EmbeddedChannel(decoder); + + // "C\n" should be appended to "AB" so that LineDecoder decodes it correctly. 
+ ch.writeInbound(Unpooled.wrappedBuffer(new byte[]{'C', '\n' , 'B', '\n'})); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] {'C'}); + ByteBuf buf2 = ch.readInbound(); + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + + assertNull(ch.readInbound(), "Must be null as it must only decode one frame"); + + ch.read(); + ch.finish(); + + buf = Unpooled.wrappedBuffer(new byte[] {'B'}); + buf2 = ch.readInbound(); + assertEquals(buf, buf2); + + buf.release(); + buf2.release(); + + assertNull(ch.readInbound()); + } + + @Test + public void testRemoveItself() { + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder() { + private boolean removed; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + in.readByte(); + ctx.pipeline().remove(this); + removed = true; + } + }); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] {'a', 'b', 'c'}); + channel.writeInbound(buf.copy()); + ByteBuf b = channel.readInbound(); + assertEquals(b, buf.skipBytes(1)); + b.release(); + buf.release(); + } + + @Test + public void testRemoveItselfWithReplayError() { + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder() { + private boolean removed; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + ctx.pipeline().remove(this); + + in.readBytes(1000); + + removed = true; + } + }); + + ByteBuf buf = Unpooled.wrappedBuffer(new byte[] {'a', 'b', 'c'}); + channel.writeInbound(buf.copy()); + ByteBuf b = channel.readInbound(); + + assertEquals(b, buf, "Expect to have still all bytes in the buffer"); + b.release(); + buf.release(); + } + + @Test + public void testRemoveItselfWriteBuffer() { + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[] {'a', 'b', 'c'}); + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder() { + private boolean removed; + + @Override + protected void 
decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(removed); + in.readByte(); + ctx.pipeline().remove(this); + + // This should not let it keep call decode + buf.writeByte('d'); + removed = true; + } + }); + + channel.writeInbound(buf.copy()); + ByteBuf b = channel.readInbound(); + assertEquals(b, Unpooled.wrappedBuffer(new byte[] { 'b', 'c'})); + b.release(); + buf.release(); + } + + @Test + public void testFireChannelReadCompleteOnInactive() throws InterruptedException { + final BlockingQueue queue = new LinkedBlockingDeque(); + final ByteBuf buf = Unpooled.buffer().writeBytes(new byte[]{'a', 'b'}); + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder() { + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + int readable = in.readableBytes(); + assertTrue(readable > 0); + in.skipBytes(readable); + out.add("data"); + } + + @Override + protected void decodeLast(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + assertFalse(in.isReadable()); + out.add("data"); + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + queue.add(3); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + queue.add(1); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + if (!ctx.channel().isActive()) { + queue.add(2); + } + } + }); + assertFalse(channel.writeInbound(buf)); + channel.finish(); + assertEquals(1, (int) queue.take()); + assertEquals(1, (int) queue.take()); + assertEquals(2, (int) queue.take()); + assertEquals(3, (int) queue.take()); + assertTrue(queue.isEmpty()); + } + + @Test + public void testChannelInputShutdownEvent() { + final AtomicReference error = new AtomicReference(); + + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder(0) { + private boolean 
decoded; + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (!(in instanceof ReplayingDecoderByteBuf)) { + error.set(new AssertionError("in must be of type " + ReplayingDecoderByteBuf.class + + " but was " + in.getClass())); + return; + } + if (!decoded) { + decoded = true; + in.readByte(); + state(1); + } else { + // This will throw an ReplayingError + in.skipBytes(Integer.MAX_VALUE); + } + } + }); + + assertFalse(channel.writeInbound(Unpooled.wrappedBuffer(new byte[] {0, 1}))); + channel.pipeline().fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE); + assertFalse(channel.finishAndReleaseAll()); + + Error err = error.get(); + if (err != null) { + throw err; + } + } + + @Test + public void handlerRemovedWillNotReleaseBufferIfDecodeInProgress() { + EmbeddedChannel channel = new EmbeddedChannel(new ReplayingDecoder() { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + ctx.pipeline().remove(this); + assertTrue(in.refCnt() != 0); + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + assertCumulationReleased(internalBuffer()); + } + }); + byte[] bytes = new byte[1024]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + assertTrue(channel.writeInbound(Unpooled.wrappedBuffer(bytes))); + assertTrue(channel.finishAndReleaseAll()); + } + + private static void assertCumulationReleased(ByteBuf byteBuf) { + assertTrue(byteBuf == null || byteBuf == Unpooled.EMPTY_BUFFER || byteBuf.refCnt() == 0, + "unexpected value: " + byteBuf); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/base64/Base64Test.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/base64/Base64Test.java new file mode 100644 index 0000000..90803c6 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/base64/Base64Test.java @@ -0,0 +1,191 @@ +/* + * Copyright 2015 The Netty 
Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.base64; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.Test; + +import java.io.ByteArrayInputStream; +import java.nio.ByteOrder; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; + +import static io.netty.buffer.Unpooled.copiedBuffer; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class Base64Test { + + @Test + public void testNotAddNewLineWhenEndOnLimit() { + ByteBuf src = copiedBuffer("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyzabcde", + CharsetUtil.US_ASCII); + ByteBuf expectedEncoded = + copiedBuffer("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5emFiY2Rl", + CharsetUtil.US_ASCII); + testEncode(src, expectedEncoded); + } + + @Test + public void testAddNewLine() { + ByteBuf src = copiedBuffer("abcdefghijklmnopqrstuvwxyzabcdefghijklmnopqrstuvwxyz12345678", + CharsetUtil.US_ASCII); + ByteBuf expectedEncoded = + copiedBuffer("YWJjZGVmZ2hpamtsbW5vcHFyc3R1dnd4eXphYmNkZWZnaGlqa2xtbm9wcXJzdHV2d3h5ejEyMzQ1\nNjc4", 
+ CharsetUtil.US_ASCII); + testEncode(src, expectedEncoded); + } + + @Test + public void testEncodeEmpty() { + ByteBuf src = Unpooled.EMPTY_BUFFER; + ByteBuf expectedEncoded = Unpooled.EMPTY_BUFFER; + testEncode(src, expectedEncoded); + } + + @Test + public void testPaddingNewline() throws Exception { + String cert = "-----BEGIN CERTIFICATE-----\n" + + "MIICqjCCAjGgAwIBAgICI1YwCQYHKoZIzj0EATAmMSQwIgYDVQQDDBtUcnVzdGVk\n" + + "IFRoaW4gQ2xpZW50IFJvb3QgQ0EwIhcRMTYwMTI0MTU0OTQ1LTA2MDAXDTE2MDQy\n" + + "NTIyNDk0NVowYzEwMC4GA1UEAwwnREMgMGRlYzI0MGYtOTI2OS00MDY5LWE2MTYt\n" + + "YjJmNTI0ZjA2ZGE0MREwDwYDVQQLDAhEQyBJUFNFQzEcMBoGA1UECgwTVHJ1c3Rl\n" + + "ZCBUaGluIENsaWVudDB2MBAGByqGSM49AgEGBSuBBAAiA2IABOB7pZYC24sF5gJm\n" + + "OHXhasxmrNYebdtSAiQRgz0M0pIsogsFeTU/W0HTlTOqwDDckphHESAKHVxa6EBL\n" + + "d+/8HYZ1AaCmXtG73XpaOyaRr3TipJl2IaJzwuehgDHs0L+qcqOB8TCB7jAwBgYr\n" + + "BgEBEAQEJgwkMGRlYzI0MGYtOTI2OS00MDY5LWE2MTYtYjJmNTI0ZjA2ZGE0MCMG\n" + + "CisGAQQBjCHbZwEEFQwTNDkwNzUyMjc1NjM3MTE3Mjg5NjAUBgorBgEEAYwh22cC\n" + + "BAYMBDIwNTkwCwYDVR0PBAQDAgXgMAkGA1UdEwQCMAAwHQYDVR0OBBYEFGWljaKj\n" + + "wiGqW61PgLL/zLxj4iirMB8GA1UdIwQYMBaAFA2FRBtG/dGnl0iXP2uKFwJHmEQI\n" + + "MCcGA1UdJQQgMB4GCCsGAQUFBwMCBggrBgEFBQcDAQYIKwYBBQUHAwkwCQYHKoZI\n" + + "zj0EAQNoADBlAjAQFP8rMLUxl36u8610LsSCiRG8pP3gjuLaaJMm3tjbVue/TI4C\n" + + "z3iL8i96YWK0VxcCMQC7pf6Wk3RhUU2Sg6S9e6CiirFLDyzLkaWxuCnXcOwTvuXT\n" + + "HUQSeUCp2Q6ygS5qKyc=\n" + + "-----END CERTIFICATE-----"; + + String expected = "MIICqjCCAjGgAwIBAgICI1YwCQYHKoZIzj0EATAmMSQwIgYDVQQDDBtUcnVzdGVkIFRoaW4gQ2xp\n" + + "ZW50IFJvb3QgQ0EwIhcRMTYwMTI0MTU0OTQ1LTA2MDAXDTE2MDQyNTIyNDk0NVowYzEwMC4GA1UE\n" + + "AwwnREMgMGRlYzI0MGYtOTI2OS00MDY5LWE2MTYtYjJmNTI0ZjA2ZGE0MREwDwYDVQQLDAhEQyBJ\n" + + "UFNFQzEcMBoGA1UECgwTVHJ1c3RlZCBUaGluIENsaWVudDB2MBAGByqGSM49AgEGBSuBBAAiA2IA\n" + + "BOB7pZYC24sF5gJmOHXhasxmrNYebdtSAiQRgz0M0pIsogsFeTU/W0HTlTOqwDDckphHESAKHVxa\n" + + "6EBLd+/8HYZ1AaCmXtG73XpaOyaRr3TipJl2IaJzwuehgDHs0L+qcqOB8TCB7jAwBgYrBgEBEAQE\n" + + 
"JgwkMGRlYzI0MGYtOTI2OS00MDY5LWE2MTYtYjJmNTI0ZjA2ZGE0MCMGCisGAQQBjCHbZwEEFQwT\n" + + "NDkwNzUyMjc1NjM3MTE3Mjg5NjAUBgorBgEEAYwh22cCBAYMBDIwNTkwCwYDVR0PBAQDAgXgMAkG\n" + + "A1UdEwQCMAAwHQYDVR0OBBYEFGWljaKjwiGqW61PgLL/zLxj4iirMB8GA1UdIwQYMBaAFA2FRBtG\n" + + "/dGnl0iXP2uKFwJHmEQIMCcGA1UdJQQgMB4GCCsGAQUFBwMCBggrBgEFBQcDAQYIKwYBBQUHAwkw\n" + + "CQYHKoZIzj0EAQNoADBlAjAQFP8rMLUxl36u8610LsSCiRG8pP3gjuLaaJMm3tjbVue/TI4Cz3iL\n" + + "8i96YWK0VxcCMQC7pf6Wk3RhUU2Sg6S9e6CiirFLDyzLkaWxuCnXcOwTvuXTHUQSeUCp2Q6ygS5q\n" + + "Kyc="; + + ByteBuf src = Unpooled.wrappedBuffer(certFromString(cert).getEncoded()); + ByteBuf expectedEncoded = copiedBuffer(expected, CharsetUtil.US_ASCII); + testEncode(src, expectedEncoded); + } + + private static X509Certificate certFromString(String string) throws Exception { + CertificateFactory factory = CertificateFactory.getInstance("X.509"); + ByteArrayInputStream bin = new ByteArrayInputStream(string.getBytes(CharsetUtil.US_ASCII)); + try { + return (X509Certificate) factory.generateCertificate(bin); + } finally { + bin.close(); + } + } + + private static void testEncode(ByteBuf src, ByteBuf expectedEncoded) { + ByteBuf encoded = Base64.encode(src, true, Base64Dialect.STANDARD); + try { + assertEquals(expectedEncoded, encoded); + } finally { + src.release(); + expectedEncoded.release(); + encoded.release(); + } + } + + @Test + public void testEncodeDecodeBE() { + testEncodeDecode(ByteOrder.BIG_ENDIAN); + } + + @Test + public void testEncodeDecodeLE() { + testEncodeDecode(ByteOrder.LITTLE_ENDIAN); + } + + private static void testEncodeDecode(ByteOrder order) { + testEncodeDecode(64, order); + testEncodeDecode(128, order); + testEncodeDecode(512, order); + testEncodeDecode(1024, order); + testEncodeDecode(4096, order); + testEncodeDecode(8192, order); + testEncodeDecode(16384, order); + } + + private static void testEncodeDecode(int size, ByteOrder order) { + byte[] bytes = new byte[size]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + + 
ByteBuf src = Unpooled.wrappedBuffer(bytes).order(order); + ByteBuf encoded = Base64.encode(src); + ByteBuf decoded = Base64.decode(encoded); + ByteBuf expectedBuf = Unpooled.wrappedBuffer(bytes); + try { + assertEquals(expectedBuf, decoded, + StringUtil.NEWLINE + "expected: " + ByteBufUtil.hexDump(expectedBuf) + + StringUtil.NEWLINE + "actual--: " + ByteBufUtil.hexDump(decoded)); + } finally { + src.release(); + encoded.release(); + decoded.release(); + expectedBuf.release(); + } + } + + @Test + public void testOverflowEncodedBufferSize() { + assertEquals(Integer.MAX_VALUE, Base64.encodedBufferSize(Integer.MAX_VALUE, true)); + assertEquals(Integer.MAX_VALUE, Base64.encodedBufferSize(Integer.MAX_VALUE, false)); + } + + @Test + public void testOverflowDecodedBufferSize() { + assertEquals(1610612736, Base64.decodedBufferSize(Integer.MAX_VALUE)); + } + + @Test + public void decodingFailsOnInvalidInputByte() { + char[] invalidChars = {'\u007F', '\u0080', '\u00BD', '\u00FF'}; + for (char invalidChar : invalidChars) { + ByteBuf buf = copiedBuffer("eHh4" + invalidChar, CharsetUtil.ISO_8859_1); + try { + Base64.decode(buf); + fail("Invalid character in not detected: " + invalidChar); + } catch (IllegalArgumentException ignored) { + // as expected + } finally { + assertTrue(buf.release()); + } + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayDecoderTest.java new file mode 100644 index 0000000..f7beef9 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayDecoderTest.java @@ -0,0 +1,58 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.bytes; + +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static io.netty.buffer.Unpooled.*; +import static org.hamcrest.core.Is.*; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ByteArrayDecoderTest { + + private EmbeddedChannel ch; + + @BeforeEach + public void setUp() { + ch = new EmbeddedChannel(new ByteArrayDecoder()); + } + + @Test + public void testDecode() { + byte[] b = new byte[2048]; + new Random().nextBytes(b); + ch.writeInbound(wrappedBuffer(b)); + assertThat((byte[]) ch.readInbound(), is(b)); + } + + @Test + public void testDecodeEmpty() { + ch.writeInbound(EMPTY_BUFFER); + assertThat((byte[]) ch.readInbound(), is(EmptyArrays.EMPTY_BYTES)); + } + + @Test + public void testDecodeOtherType() { + String str = "Meep!"; + ch.writeInbound(str); + assertThat(ch.readInbound(), is((Object) str)); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayEncoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayEncoderTest.java new file mode 100644 index 0000000..9e1df98 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/bytes/ByteArrayEncoderTest.java @@ -0,0 +1,67 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use 
this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.bytes; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import java.util.Random; + +import static io.netty.buffer.Unpooled.*; +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; + +public class ByteArrayEncoderTest { + + private EmbeddedChannel ch; + + @BeforeEach + public void setUp() { + ch = new EmbeddedChannel(new ByteArrayEncoder()); + } + + @AfterEach + public void tearDown() { + assertThat(ch.finish(), is(false)); + } + + @Test + public void testEncode() { + byte[] b = new byte[2048]; + new Random().nextBytes(b); + ch.writeOutbound(b); + ByteBuf encoded = ch.readOutbound(); + assertThat(encoded, is(wrappedBuffer(b))); + encoded.release(); + } + + @Test + public void testEncodeEmpty() { + ch.writeOutbound(EmptyArrays.EMPTY_BYTES); + assertThat((ByteBuf) ch.readOutbound(), is(sameInstance(EMPTY_BUFFER))); + } + + @Test + public void testEncodeOtherType() { + String str = "Meep!"; + ch.writeOutbound(str); + assertThat(ch.readOutbound(), is((Object) str)); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/DelimiterBasedFrameDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/DelimiterBasedFrameDecoderTest.java new file mode 100644 index 0000000..6fa63eb 
--- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/DelimiterBasedFrameDecoderTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.DelimiterBasedFrameDecoder; +import io.netty.handler.codec.Delimiters; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class DelimiterBasedFrameDecoderTest { + + @Test + public void testFailSlowTooLongFrameRecovery() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel( + new DelimiterBasedFrameDecoder(1, true, false, Delimiters.nulDelimiter())); + + for (int i = 0; i < 2; i ++) { + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 1, 2 })); + try { + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0 }))); + fail(DecoderException.class.getSimpleName() + " must be raised."); + } catch (TooLongFrameException e) { + // Expected + } + + ch.writeInbound(Unpooled.wrappedBuffer(new 
byte[] { 'A', 0 })); + ByteBuf buf = ch.readInbound(); + assertEquals("A", buf.toString(CharsetUtil.ISO_8859_1)); + + buf.release(); + } + } + + @Test + public void testFailFastTooLongFrameRecovery() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel( + new DelimiterBasedFrameDecoder(1, Delimiters.nulDelimiter())); + + for (int i = 0; i < 2; i ++) { + try { + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 1, 2 }))); + fail(DecoderException.class.getSimpleName() + " must be raised."); + } catch (TooLongFrameException e) { + // Expected + } + + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0, 'A', 0 })); + ByteBuf buf = ch.readInbound(); + assertEquals("A", buf.toString(CharsetUtil.ISO_8859_1)); + + buf.release(); + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldBasedFrameDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldBasedFrameDecoderTest.java new file mode 100644 index 0000000..ecdd212 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldBasedFrameDecoderTest.java @@ -0,0 +1,73 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.LengthFieldBasedFrameDecoder; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class LengthFieldBasedFrameDecoderTest { + @Test + public void testFailSlowTooLongFrameRecovery() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel( + new LengthFieldBasedFrameDecoder(5, 0, 4, 0, 4, false)); + + for (int i = 0; i < 2; i ++) { + assertFalse(ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0, 0, 0, 2 }))); + try { + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0, 0 }))); + fail(DecoderException.class.getSimpleName() + " must be raised."); + } catch (TooLongFrameException e) { + // Expected + } + + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0, 0, 0, 1, 'A' })); + ByteBuf buf = ch.readInbound(); + assertEquals("A", buf.toString(CharsetUtil.ISO_8859_1)); + buf.release(); + } + } + + @Test + public void testFailFastTooLongFrameRecovery() throws Exception { + EmbeddedChannel ch = new EmbeddedChannel( + new LengthFieldBasedFrameDecoder(5, 0, 4, 0, 4)); + + for (int i = 0; i < 2; i ++) { + try { + assertTrue(ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0, 0, 0, 2 }))); + fail(DecoderException.class.getSimpleName() + " must be raised."); + } catch (TooLongFrameException e) { + // Expected + } + + ch.writeInbound(Unpooled.wrappedBuffer(new byte[] { 0, 0, 0, 0, 0, 1, 'A' })); + ByteBuf buf = ch.readInbound(); + assertEquals("A", buf.toString(CharsetUtil.ISO_8859_1)); + 
buf.release(); + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldPrependerTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldPrependerTest.java new file mode 100644 index 0000000..325f3a9 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/LengthFieldPrependerTest.java @@ -0,0 +1,116 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.frame; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.EncoderException; +import io.netty.handler.codec.LengthFieldPrepender; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static io.netty.buffer.Unpooled.*; +import java.nio.ByteOrder; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.fail; + +public class LengthFieldPrependerTest { + + private ByteBuf msg; + + @BeforeEach + public void setUp() throws Exception { + msg = copiedBuffer("A", CharsetUtil.ISO_8859_1); + } + + @Test + public void testPrependLength() throws Exception { + final EmbeddedChannel ch = new EmbeddedChannel(new LengthFieldPrepender(4)); + ch.writeOutbound(msg); + ByteBuf buf = ch.readOutbound(); + assertEquals(4, buf.readableBytes()); + assertEquals(msg.readableBytes(), buf.readInt()); + buf.release(); + + buf = ch.readOutbound(); + assertSame(buf, msg); + buf.release(); + } + + @Test + public void testPrependLengthIncludesLengthFieldLength() throws Exception { + final EmbeddedChannel ch = new EmbeddedChannel(new LengthFieldPrepender(4, true)); + ch.writeOutbound(msg); + ByteBuf buf = ch.readOutbound(); + assertEquals(4, buf.readableBytes()); + assertEquals(5, buf.readInt()); + buf.release(); + + buf = ch.readOutbound(); + assertSame(buf, msg); + buf.release(); + } + + @Test + public void testPrependAdjustedLength() throws Exception { + final EmbeddedChannel ch = new EmbeddedChannel(new LengthFieldPrepender(4, -1)); + ch.writeOutbound(msg); + ByteBuf buf = ch.readOutbound(); + assertEquals(4, buf.readableBytes()); + assertEquals(msg.readableBytes() - 1, buf.readInt()); + buf.release(); + + buf = ch.readOutbound(); + assertSame(buf, msg); 
+ buf.release(); + } + + @Test + public void testAdjustedLengthLessThanZero() throws Exception { + final EmbeddedChannel ch = new EmbeddedChannel(new LengthFieldPrepender(4, -2)); + try { + ch.writeOutbound(msg); + fail(EncoderException.class.getSimpleName() + " must be raised."); + } catch (EncoderException e) { + // Expected + } + } + + @Test + public void testPrependLengthInLittleEndian() throws Exception { + final EmbeddedChannel ch = new EmbeddedChannel(new LengthFieldPrepender(ByteOrder.LITTLE_ENDIAN, 4, 0, false)); + ch.writeOutbound(msg); + ByteBuf buf = ch.readOutbound(); + assertEquals(4, buf.readableBytes()); + byte[] writtenBytes = new byte[buf.readableBytes()]; + buf.getBytes(0, writtenBytes); + assertEquals(1, writtenBytes[0]); + assertEquals(0, writtenBytes[1]); + assertEquals(0, writtenBytes[2]); + assertEquals(0, writtenBytes[3]); + buf.release(); + + buf = ch.readOutbound(); + assertSame(buf, msg); + buf.release(); + assertFalse(ch.finish(), "The channel must have been completely read"); + } + +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/package-info.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/package-info.java new file mode 100644 index 0000000..d68433b --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/frame/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Test classes for frame based decoders + */ +package io.netty.handler.codec.frame; diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/json/JsonObjectDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/json/JsonObjectDecoderTest.java new file mode 100644 index 0000000..b3d64f6 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/json/JsonObjectDecoderTest.java @@ -0,0 +1,418 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.json; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class JsonObjectDecoderTest { + @Test + public void testJsonObjectOverMultipleWrites() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String objectPart1 = "{ \"firstname\": \"John"; + String objectPart2 = "\" ,\n \"surname\" :"; + String objectPart3 = "\"Doe\", age:22 \n}"; + + // Test object + ch.writeInbound(Unpooled.copiedBuffer(" \n\n " + objectPart1, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(objectPart2, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(objectPart3 + " \n\n \n", CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(objectPart1 + objectPart2 + objectPart3, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testMultipleJsonObjectsOverMultipleWrites() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String objectPart1 = "{\"name\":\"Jo"; + String objectPart2 = "hn\"}{\"name\":\"John\"}{\"name\":\"Jo"; + String objectPart3 = "hn\"}"; + + ch.writeInbound(Unpooled.copiedBuffer(objectPart1, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(objectPart2, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(objectPart3, CharsetUtil.UTF_8)); + + for (int i = 0; i < 3; i++) { + 
ByteBuf res = ch.readInbound(); + assertEquals("{\"name\":\"John\"}", res.toString(CharsetUtil.UTF_8)); + res.release(); + } + + assertFalse(ch.finish()); + } + + @Test + public void testJsonArrayOverMultipleWrites() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String arrayPart1 = "[{\"test"; + String arrayPart2 = "case\" : \"\\\"}]Escaped dou\\\"ble quotes \\\" in JSON str\\\"ing\""; + String arrayPart3 = " }\n\n , "; + String arrayPart4 = "{\"testcase\" : \"Streaming string me"; + String arrayPart5 = "ssage\"} ]"; + + // Test array + ch.writeInbound(Unpooled.copiedBuffer(" " + arrayPart1, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(arrayPart2, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(arrayPart3, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(arrayPart4, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(arrayPart5 + " ", CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(arrayPart1 + arrayPart2 + arrayPart3 + arrayPart4 + arrayPart5, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testStreamJsonArrayOverMultipleWrites1() { + String[] array = new String[] { + " [{\"test", + "case\" : \"\\\"}]Escaped dou\\\"ble quotes \\\" in JSON str\\\"ing\"", + " }\n\n , ", + "{\"testcase\" : \"Streaming string me", + "ssage\"} ] " + }; + String[] result = new String[] { + "{\"testcase\" : \"\\\"}]Escaped dou\\\"ble quotes \\\" in JSON str\\\"ing\" }", + "{\"testcase\" : \"Streaming string message\"}" + }; + doTestStreamJsonArrayOverMultipleWrites(2, array, result); + } + + @Test + public void testStreamJsonArrayOverMultipleWrites2() { + String[] array = new String[] { + " [{\"test", + "case\" : \"\\\"}]Escaped dou\\\"ble quotes \\\" in JSON str\\\"ing\"", + " }\n\n , {\"test", + "case\" : \"Streaming string me", + "ssage\"} ] " + }; + String[] result = new String[] { + "{\"testcase\" : \"\\\"}]Escaped 
dou\\\"ble quotes \\\" in JSON str\\\"ing\" }", + "{\"testcase\" : \"Streaming string message\"}" + }; + doTestStreamJsonArrayOverMultipleWrites(2, array, result); + } + + @Test + public void testStreamJsonArrayOverMultipleWrites3() { + String[] array = new String[] { + " [{\"test", + "case\" : \"\\\"}]Escaped dou\\\"ble quotes \\\" in JSON str\\\"ing\"", + " }\n\n , [{\"test", + "case\" : \"Streaming string me", + "ssage\"}] ] " + }; + String[] result = new String[] { + "{\"testcase\" : \"\\\"}]Escaped dou\\\"ble quotes \\\" in JSON str\\\"ing\" }", + "[{\"testcase\" : \"Streaming string message\"}]" + }; + doTestStreamJsonArrayOverMultipleWrites(2, array, result); + } + + private static void doTestStreamJsonArrayOverMultipleWrites(int indexDataAvailable, + String[] array, String[] result) { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder(true)); + + boolean dataAvailable = false; + for (String part : array) { + dataAvailable = ch.writeInbound(Unpooled.copiedBuffer(part, CharsetUtil.UTF_8)); + if (indexDataAvailable > 0) { + assertFalse(dataAvailable); + } else { + assertTrue(dataAvailable); + } + indexDataAvailable--; + } + + for (String part : result) { + ByteBuf res = ch.readInbound(); + assertEquals(part, res.toString(CharsetUtil.UTF_8)); + res.release(); + } + + assertFalse(ch.finish()); + } + + @Test + public void testSingleByteStream() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String json = "{\"foo\" : {\"bar\" : [{},{}]}}"; + for (byte c : json.getBytes(CharsetUtil.UTF_8)) { + ch.writeInbound(Unpooled.copiedBuffer(new byte[] {c})); + } + + ByteBuf res = ch.readInbound(); + assertEquals(json, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testBackslashInString1() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + // {"foo" : "bar\""} + String json = "{\"foo\" : \"bar\\\"\"}"; + + ch.writeInbound(Unpooled.copiedBuffer(json, 
CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(json, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testBackslashInString2() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + // {"foo" : "bar\\"} + String json = "{\"foo\" : \"bar\\\\\"}"; + + ch.writeInbound(Unpooled.copiedBuffer(json, CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(json, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testBackslashInString3() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + // {"foo" : "bar\\\""} + String json = "{\"foo\" : \"bar\\\\\\\"\"}"; + + ch.writeInbound(Unpooled.copiedBuffer(json, CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(json, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testMultipleJsonObjectsInOneWrite() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String object1 = "{\"key\" : \"value1\"}", + object2 = "{\"key\" : \"value2\"}", + object3 = "{\"key\" : \"value3\"}"; + + ch.writeInbound(Unpooled.copiedBuffer(object1 + object2 + object3, CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(object1, res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals(object2, res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals(object3, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testNonJsonContent1() { + final EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + try { + assertThrows(CorruptedFrameException.class, new Executable() { + @Override + public void execute() { + ch.writeInbound(Unpooled.copiedBuffer(" b [1,2,3]", CharsetUtil.UTF_8)); + } + }); + } finally { + 
assertFalse(ch.finish()); + } + } + + @Test + public void testNonJsonContent2() { + final EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + ch.writeInbound(Unpooled.copiedBuffer(" [1,2,3] ", CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals("[1,2,3]", res.toString(CharsetUtil.UTF_8)); + res.release(); + + try { + assertThrows(CorruptedFrameException.class, new Executable() { + @Override + public void execute() { + ch.writeInbound(Unpooled.copiedBuffer(" a {\"key\" : 10}", CharsetUtil.UTF_8)); + } + }); + } finally { + assertFalse(ch.finish()); + } + } + + @Test + public void testMaxObjectLength() { + final EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder(6)); + try { + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.writeInbound(Unpooled.copiedBuffer("[2,4,5]", CharsetUtil.UTF_8)); + } + }); + } finally { + assertFalse(ch.finish()); + } + } + + @Test + public void testOneJsonObjectPerWrite() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String object1 = "{\"key\" : \"value1\"}", + object2 = "{\"key\" : \"value2\"}", + object3 = "{\"key\" : \"value3\"}"; + + ch.writeInbound(Unpooled.copiedBuffer(object1, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(object2, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(object3, CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(object1, res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals(object2, res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals(object3, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testSpecialJsonCharsInString() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + String object = "{ \"key\" : \"[]{}}\\\"}}'}\"}"; + 
ch.writeInbound(Unpooled.copiedBuffer(object, CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals(object, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testStreamArrayElementsSimple() { + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder(Integer.MAX_VALUE, true)); + + String array = "[ 12, \"bla\" , 13.4 \t ,{\"key0\" : [1,2], \"key1\" : 12, \"key2\" : {}} , " + + "true, false, null, [\"bla\", {}, [1,2,3]] ]"; + String object = "{\"bla\" : \"blub\"}"; + ch.writeInbound(Unpooled.copiedBuffer(array, CharsetUtil.UTF_8)); + ch.writeInbound(Unpooled.copiedBuffer(object, CharsetUtil.UTF_8)); + + ByteBuf res = ch.readInbound(); + assertEquals("12", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("\"bla\"", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("13.4", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("{\"key0\" : [1,2], \"key1\" : 12, \"key2\" : {}}", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("true", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("false", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("null", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals("[\"bla\", {}, [1,2,3]]", res.toString(CharsetUtil.UTF_8)); + res.release(); + res = ch.readInbound(); + assertEquals(object, res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } + + @Test + public void testCorruptedFrameException() { + final String part1 = "{\"a\":{\"b\":{\"c\":{ \"d\":\"27301\", \"med\":\"d\", \"path\":\"27310\"} }," + + " \"status\":\"OK\" } }{\""; + final String part2 = "a\":{\"b\":{\"c\":{\"ory\":[{\"competi\":[{\"event\":[{" + 
"\"externalI\":{\"external\"" + + ":[{\"id\":\"O\"} ]"; + + EmbeddedChannel ch = new EmbeddedChannel(new JsonObjectDecoder()); + + ByteBuf res; + + ch.writeInbound(Unpooled.copiedBuffer(part1, CharsetUtil.UTF_8)); + res = ch.readInbound(); + assertEquals("{\"a\":{\"b\":{\"c\":{ \"d\":\"27301\", \"med\":\"d\", \"path\":\"27310\"} }, " + + "\"status\":\"OK\" } }", res.toString(CharsetUtil.UTF_8)); + res.release(); + + ch.writeInbound(Unpooled.copiedBuffer(part2, CharsetUtil.UTF_8)); + res = ch.readInbound(); + + assertNull(res); + + ch.writeInbound(Unpooled.copiedBuffer("}}]}]}]}}}}", CharsetUtil.UTF_8)); + res = ch.readInbound(); + + assertEquals("{\"a\":{\"b\":{\"c\":{\"ory\":[{\"competi\":[{\"event\":[{" + "\"externalI\":{" + + "\"external\":[{\"id\":\"O\"} ]}}]}]}]}}}}", res.toString(CharsetUtil.UTF_8)); + res.release(); + + assertFalse(ch.finish()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompactObjectSerializationTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompactObjectSerializationTest.java new file mode 100644 index 0000000..f39b7ea --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompactObjectSerializationTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.serialization; + +import java.io.PipedInputStream; +import java.io.PipedOutputStream; +import java.util.List; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class CompactObjectSerializationTest { + + @Test + public void testInterfaceSerialization() throws Exception { + PipedOutputStream pipeOut = new PipedOutputStream(); + PipedInputStream pipeIn = new PipedInputStream(pipeOut); + CompactObjectOutputStream out = new CompactObjectOutputStream(pipeOut); + CompactObjectInputStream in = new CompactObjectInputStream(pipeIn, ClassResolvers.cacheDisabled(null)); + out.writeObject(List.class); + Assertions.assertSame(List.class, in.readObject()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompatibleObjectEncoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompatibleObjectEncoderTest.java new file mode 100644 index 0000000..3b48836 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/serialization/CompatibleObjectEncoderTest.java @@ -0,0 +1,79 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.serialization; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import java.io.IOException; +import java.io.ObjectInputStream; +import java.io.Serializable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class CompatibleObjectEncoderTest { + @Test + public void testMultipleEncodeReferenceCount() throws IOException, ClassNotFoundException { + EmbeddedChannel channel = new EmbeddedChannel(new CompatibleObjectEncoder()); + testEncode(channel, new TestSerializable(6, 8)); + testEncode(channel, new TestSerializable(10, 5)); + testEncode(channel, new TestSerializable(1, 5)); + assertFalse(channel.finishAndReleaseAll()); + } + + private static void testEncode(EmbeddedChannel channel, TestSerializable original) + throws IOException, ClassNotFoundException { + channel.writeOutbound(original); + Object o = channel.readOutbound(); + ByteBuf buf = (ByteBuf) o; + ObjectInputStream ois = new ObjectInputStream(new ByteBufInputStream(buf)); + try { + assertEquals(original, ois.readObject()); + } finally { + buf.release(); + ois.close(); + } + } + + private static final class TestSerializable implements Serializable { + private static final long serialVersionUID = 2235771472534930360L; + + public final int x; + public final int y; + + TestSerializable(int x, int y) { + this.x = x; + this.y = y; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof TestSerializable)) { + return false; + } + TestSerializable rhs = (TestSerializable) o; + return x == rhs.x && y == rhs.y; + } + + @Override + public int hashCode() { + return 31 * (31 + x) + y; + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/string/LineEncoderTest.java 
b/netty-handler-codec/src/test/java/io/netty/handler/codec/string/LineEncoderTest.java new file mode 100644 index 0000000..a8f2e3f --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/string/LineEncoderTest.java @@ -0,0 +1,52 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class LineEncoderTest { + + @Test + public void testEncode() { + testLineEncode(LineSeparator.DEFAULT, "abc"); + testLineEncode(LineSeparator.WINDOWS, "abc"); + testLineEncode(LineSeparator.UNIX, "abc"); + } + + private static void testLineEncode(LineSeparator lineSeparator, String msg) { + EmbeddedChannel channel = new EmbeddedChannel(new LineEncoder(lineSeparator, CharsetUtil.UTF_8)); + assertTrue(channel.writeOutbound(msg)); + ByteBuf buf = channel.readOutbound(); + try { + byte[] data = new byte[buf.readableBytes()]; + buf.readBytes(data); + byte[] expected = (msg + lineSeparator.value()).getBytes(CharsetUtil.UTF_8); + 
assertArrayEquals(expected, data); + assertNull(channel.readOutbound()); + } finally { + buf.release(); + assertFalse(channel.finish()); + } + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringDecoderTest.java new file mode 100644 index 0000000..f0da531 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringDecoderTest.java @@ -0,0 +1,42 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class StringDecoderTest { + + @Test + public void testDecode() { + String msg = "abc123"; + ByteBuf byteBuf = Unpooled.copiedBuffer(msg, CharsetUtil.UTF_8); + EmbeddedChannel channel = new EmbeddedChannel(new StringDecoder()); + assertTrue(channel.writeInbound(byteBuf)); + String result = channel.readInbound(); + assertEquals(msg, result); + assertNull(channel.readInbound()); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringEncoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringEncoderTest.java new file mode 100644 index 0000000..f0309b8 --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/string/StringEncoderTest.java @@ -0,0 +1,42 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.codec.string; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; + +public class StringEncoderTest { + + @Test + public void testEncode() { + String msg = "Test"; + EmbeddedChannel channel = new EmbeddedChannel(new StringEncoder()); + Assertions.assertTrue(channel.writeOutbound(msg)); + Assertions.assertTrue(channel.finish()); + ByteBuf buf = channel.readOutbound(); + byte[] data = new byte[buf.readableBytes()]; + buf.readBytes(data); + Assertions.assertArrayEquals(msg.getBytes(CharsetUtil.UTF_8), data); + Assertions.assertNull(channel.readOutbound()); + buf.release(); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler-codec/src/test/java/io/netty/handler/codec/xml/XmlFrameDecoderTest.java b/netty-handler-codec/src/test/java/io/netty/handler/codec/xml/XmlFrameDecoderTest.java new file mode 100644 index 0000000..a52835a --- /dev/null +++ b/netty-handler-codec/src/test/java/io/netty/handler/codec/xml/XmlFrameDecoderTest.java @@ -0,0 +1,230 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.codec.xml; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.codec.CorruptedFrameException; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.IOException; +import java.net.URISyntaxException; +import java.net.URL; +import java.nio.ByteBuffer; +import java.nio.charset.StandardCharsets; +import java.nio.file.Files; +import java.nio.file.Paths; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.List; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class XmlFrameDecoderTest { + + private final List xmlSamples; + + public XmlFrameDecoderTest() throws IOException, URISyntaxException { + xmlSamples = Arrays.asList( + sample("01"), sample("02"), sample("03"), + sample("04"), sample("05"), sample("06") + ); + } + + @Test + public void testConstructorWithIllegalArgs01() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new XmlFrameDecoder(0); + } + }); + } + + @Test + public void testConstructorWithIllegalArgs02() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new XmlFrameDecoder(-23); + } + }); + } + + @Test + public void testDecodeWithFrameExceedingMaxLength() { + XmlFrameDecoder decoder = new XmlFrameDecoder(3); + final EmbeddedChannel ch = new EmbeddedChannel(decoder); + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() { + ch.writeInbound(Unpooled.copiedBuffer("", CharsetUtil.UTF_8)); + } + }); + } + + @Test + public void testDecodeWithInvalidInput() { + 
XmlFrameDecoder decoder = new XmlFrameDecoder(1048576); + final EmbeddedChannel ch = new EmbeddedChannel(decoder); + assertThrows(CorruptedFrameException.class, new Executable() { + @Override + public void execute() { + ch.writeInbound(Unpooled.copiedBuffer("invalid XML", CharsetUtil.UTF_8)); + } + }); + } + + @Test + public void testDecodeWithInvalidContentBeforeXml() { + XmlFrameDecoder decoder = new XmlFrameDecoder(1048576); + final EmbeddedChannel ch = new EmbeddedChannel(decoder); + assertThrows(CorruptedFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.writeInbound(Unpooled.copiedBuffer("invalid XML", CharsetUtil.UTF_8)); + } + }); + } + + @Test + public void testDecodeShortValidXml() { + testDecodeWithXml("", ""); + } + + @Test + public void testDecodeShortValidXmlWithLeadingWhitespace01() { + testDecodeWithXml(" ", ""); + } + + @Test + public void testDecodeShortValidXmlWithLeadingWhitespace02() { + testDecodeWithXml(" \n\r \t\t", ""); + } + + @Test + public void testDecodeShortValidXmlWithLeadingWhitespace02AndTrailingGarbage() { + testDecodeWithXml(" \n\r \t\ttrash", "", CorruptedFrameException.class); + } + + @Test + public void testDecodeInvalidXml() { + testDecodeWithXml("" + + ""; + testDecodeWithXml(xml, xml); + } + + @Test + public void testDecodeWithCDATABlockContainingNestedUnbalancedXml() { + //
isn't closed, also
should have been + final String xml = "" + + "ACME Inc.]]>" + + ""; + testDecodeWithXml(xml, xml); + } + + @Test + public void testDecodeWithMultipleMessages() { + final String input = "\n\n" + + "\n\n" + + ""; + final String frame1 = ""; + final String frame2 = "\n\n"; + final String frame3 = ""; + testDecodeWithXml(input, frame1, frame2, frame3); + } + + @Test + public void testFraming() { + testDecodeWithXml(Arrays.asList("123"), "123"); + } + + @Test + public void testDecodeWithSampleXml() { + for (final String xmlSample : xmlSamples) { + testDecodeWithXml(xmlSample, xmlSample); + } + } + + private static void testDecodeWithXml(List xmlFrames, Object... expected) { + EmbeddedChannel ch = new EmbeddedChannel(new XmlFrameDecoder(1048576)); + Exception cause = null; + try { + for (String xmlFrame : xmlFrames) { + ch.writeInbound(Unpooled.copiedBuffer(xmlFrame, CharsetUtil.UTF_8)); + } + } catch (Exception e) { + cause = e; + } + List actual = new ArrayList(); + for (;;) { + ByteBuf buf = ch.readInbound(); + if (buf == null) { + break; + } + actual.add(buf.toString(CharsetUtil.UTF_8)); + buf.release(); + } + + if (cause != null) { + actual.add(cause.getClass()); + } + + try { + List expectedList = new ArrayList(); + Collections.addAll(expectedList, expected); + assertThat(actual, is(expectedList)); + } finally { + ch.finish(); + } + } + + private static void testDecodeWithXml(String xml, Object... 
expected) { + testDecodeWithXml(Collections.singletonList(xml), expected); + } + + private String sample(String number) throws IOException, URISyntaxException { + String path = "io/netty/handler/codec/xml/sample-" + number + ".xml"; + URL url = getClass().getClassLoader().getResource(path); + if (url == null) { + throw new IllegalArgumentException("file not found: " + path); + } + byte[] buf = Files.readAllBytes(Paths.get(url.toURI())); + return StandardCharsets.UTF_8.decode(ByteBuffer.wrap(buf)).toString(); + } +} diff --git a/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-01.xml b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-01.xml new file mode 100644 index 0000000..2408a6f --- /dev/null +++ b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-01.xml @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-02.xml b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-02.xml new file mode 100644 index 0000000..ffd40cd --- /dev/null +++ b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-02.xml @@ -0,0 +1,3 @@ + + + \ No newline at end of file diff --git a/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-03.xml b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-03.xml new file mode 100644 index 0000000..a928402 --- /dev/null +++ b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-03.xml @@ -0,0 +1,65 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + \ No newline at end of file diff --git a/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-04.xml b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-04.xml new file mode 100644 index 0000000..22b9d70 --- /dev/null +++ 
b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-04.xml @@ -0,0 +1,752 @@ + + + + + 4.0.0 + + org.sonatype.oss + oss-parent + 7 + + + io.netty + netty-parent + pom + 4.0.14.Final-SNAPSHOT + + Netty + https://netty.io/ + + Netty is an asynchronous event-driven network application framework for + rapid development of maintainable high performance protocol servers and + clients. + + + + The Netty Project + https://netty.io/ + + + + + Apache License, Version 2.0 + https://www.apache.org/licenses/LICENSE-2.0 + + + 2008 + + + https://github.com/netty/netty + scm:git:git://github.com/netty/netty.git + scm:git:ssh://git@github.com/netty/netty.git + HEAD + + + + + netty.io + The Netty Project Contributors + netty@googlegroups.com + https://netty.io/ + The Netty Project + https://netty.io/ + + + + + UTF-8 + UTF-8 + 1.3.18.GA + + -server + -dsa -da -ea:io.netty... + -XX:+AggressiveOpts + -XX:+TieredCompilation + -XX:+UseBiasedLocking + -XX:+UseFastAccessorMethods + -XX:+UseStringCache + -XX:+OptimizeStringConcat + -XX:+HeapDumpOnOutOfMemoryError + + + + + common + buffer + codec + codec-http + codec-socks + transport + transport-rxtx + transport-sctp + transport-udt + handler + example + testsuite + microbench + all + tarball + + + + + + + org.jboss.marshalling + jboss-marshalling + ${jboss.marshalling.version} + compile + true + + + + com.google.protobuf + protobuf-java + 2.6.1 + + + com.jcraft + jzlib + 1.1.3 + + + + org.rxtx + rxtx + 2.1.7 + + + + com.barchart.udt + barchart-udt-bundle + 2.3.0 + + + + javax.servlet + servlet-api + 2.5 + + + + org.slf4j + slf4j-api + 1.7.21 + + + commons-logging + commons-logging + 1.2 + + + log4j + log4j + 1.2.17 + + + mail + javax.mail + + + jms + javax.jms + + + jmxtools + com.sun.jdmk + + + jmxri + com.sun.jmx + + + true + + + + + com.yammer.metrics + metrics-core + 2.2.0 + + + + + org.jboss.marshalling + jboss-marshalling-serial + ${jboss.marshalling.version} + test + + + org.jboss.marshalling + 
jboss-marshalling-river + ${jboss.marshalling.version} + test + + + + + com.google.caliper + caliper + 0.5-rc1 + test + + + + + + + + org.easymock + easymock + 3.4 + test + + + org.easymock + easymockclassextension + 3.2 + test + + + org.jmock + jmock-junit4 + 2.8.2 + test + + + ch.qos.logback + logback-classic + 1.1.7 + test + + + + + + + maven-enforcer-plugin + 1.4.1 + + + enforce-tools + + enforce + + + + + + + [1.7.0,) + + + [3.0.5,3.1) + + + + + + + + maven-compiler-plugin + 3.5.1 + + 1.7 + true + 1.6 + 1.6 + true + true + true + true + -Xlint:-options + + + 256m + 1024m + + + + + org.codehaus.mojo + animal-sniffer-maven-plugin + 1.15 + + + org.codehaus.mojo.signature + java16 + 1.1 + + + sun.misc.Unsafe + sun.misc.Cleaner + + java.util.zip.Deflater + + + java.nio.channels.DatagramChannel + java.nio.channels.MembershipKey + java.net.StandardProtocolFamily + + + java.nio.channels.AsynchronousChannel + java.nio.channels.AsynchronousSocketChannel + java.nio.channels.AsynchronousServerSocketChannel + java.nio.channels.AsynchronousChannelGroup + java.nio.channels.NetworkChannel + java.nio.channels.InterruptedByTimeoutException + java.net.StandardSocketOptions + java.net.SocketOption + + + + + process-classes + + check + + + + + + maven-checkstyle-plugin + 2.12.1 + + + check-style + + check + + validate + + true + true + true + true + io/netty/checkstyle.xml + true + + + + + + ${project.groupId} + netty-build + 19 + + + + + maven-surefire-plugin + + + **/*Test*.java + **/*Benchmark*.java + + + **/Abstract* + **/TestUtil* + + random + ${test.jvm.argLine} + + + + + org.apache.felix + maven-bundle-plugin + 2.5.4 + + + generate-manifest + process-classes + + manifest + + + + ${project.groupId}.* + + sun.misc.*;resolution:=optional,* + + !* + + + + + + + maven-source-plugin + 3.0.1 + + + + attach-sources + invalid + + jar + + + + attach-sources-no-fork + package + + jar-no-fork + + + + + + maven-javadoc-plugin + 2.10.4 + + false + true + false + false + true + + + + 
maven-deploy-plugin + 2.8.2 + + 10 + + + + maven-release-plugin + + 2.5.3 + + false + -P release,sonatype-oss-release,full,no-osgi + true + false + netty-@{project.version} + + + + + + maven-antrun-plugin + + + + write-version-properties + initialize + + run + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + Current commit: ${shortCommitHash} on ${commitDate} + + + + + + + + + + + + + + + + + + + + + + + org.apache.ant + ant + 1.9.7 + + + org.apache.ant + ant-launcher + 1.9.7 + + + ant-contrib + ant-contrib + 1.0b3 + + + ant + ant + + + + + + + + + + + + maven-surefire-plugin + 2.19.1 + + + + maven-failsafe-plugin + 2.19.1 + + + maven-clean-plugin + 3.0.0 + + + maven-resources-plugin + 3.0.1 + + + maven-jar-plugin + 3.0.2 + + + default-jar + + + + true + + true + ${project.build.outputDirectory}/META-INF/MANIFEST.MF + + + + + + + maven-dependency-plugin + 2.10 + + + maven-assembly-plugin + 2.6 + + + + maven-jxr-plugin + 2.2 + + + maven-antrun-plugin + 1.8 + + + ant-contrib + ant-contrib + 1.0b3 + + + ant + ant + + + + + + + org.codehaus.mojo + build-helper-maven-plugin + 1.10 + + + + + org.eclipse.m2e + lifecycle-mapping + 1.0.0 + + + + + + org.apache.maven.plugins + maven-antrun-plugin + [1.7,) + + run + + + + + + + + + org.apache.maven.plugins + maven-checkstyle-plugin + [1.0,) + + check + + + + + false + + + + + + org.apache.maven.plugins + maven-enforcer-plugin + [1.0,) + + enforce + + + + + false + + + + + + org.apache.maven.plugins + maven-clean-plugin + [1.0,) + + clean + + + + + false + + + + + + org.apache.felix + maven-bundle-plugin + [2.4,) + + manifest + + + + + + + + + + + + + + \ No newline at end of file diff --git a/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-05.xml b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-05.xml new file mode 100644 index 0000000..427cbf8 --- /dev/null +++ 
b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-05.xml @@ -0,0 +1,81 @@ + + + + Rocket Launching + + + Configuration + + + Configuring rockets should look very familiar if you're used + to Jakarta Commons-Rocket or Commons-Space. You will first + create a normal + ContextSource + then wrap it in a + RocketContextSource + . + + + ... + + + + + + + + + + + ... + +]]> + + In a real world example you would probably configure the + rocket options and enable connection validation; the above + serves as an example to demonstrate the general idea. + + + + Validation Configuration + + + Adding validation and a few rocket configuration tweaks to + the above example is straight forward. Inject a + RocketContextValidator + and set when validation should occur and the rocket is + ready to go. + + + ... + + + + + + + + + + + + + + + + ... + +]]> + + The above example will test each + RocketContext + before it is passed to the client application and test + RocketContext + s that have been sitting idle in orbit. + + + + \ No newline at end of file diff --git a/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-06.xml b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-06.xml new file mode 100644 index 0000000..ecf5f8d --- /dev/null +++ b/netty-handler-codec/src/test/resources/io/netty/handler/codec/xml/sample-06.xml @@ -0,0 +1,62 @@ + + + + Rocket Launching + + + Configuration + + + Configuring rockets should look very familiar if you're used + to Jakarta Commons-Rocket or Commons-Space. You will first + create a normal + ContextSource + then wrap it in a + RocketContextSource + . + + + ... + + + ... +]]> + + In a real world example you would probably configure the + rocket options and enable connection validation; the above + serves as an example to demonstrate the general idea. 
+ + + + Validation Configuration + + + Adding validation and a few rocket configuration tweaks to + the above example is straight forward. Inject a + RocketContextValidator + and set when validation should occur and the rocket is + ready to go. + + + ... + + + ... + + + ... +]]> + + The above example will test each + RocketContext + before it is passed to the client application and test + RocketContext + s that have been sitting idle in orbit. + + + + \ No newline at end of file diff --git a/netty-handler-ssl/build.gradle b/netty-handler-ssl/build.gradle new file mode 100644 index 0000000..9f6a9ca --- /dev/null +++ b/netty-handler-ssl/build.gradle @@ -0,0 +1,18 @@ +dependencies { + api project(':netty-handler-codec') + api project(':netty-internal-tcnative') + api project(':netty-channel-unix') + implementation libs.bouncycastle + implementation libs.conscrypt + testImplementation testLibs.mockito.core + testImplementation testLibs.assertj + testImplementation project(':netty-handler') + testImplementation (testLibs.amazonCorrettoCrypt) { + artifact { + classifier = 'linux-x86_64' + } + } + testRuntimeOnly(variantOf(testLibs.netty.tcnative.boringssl.static) { + classifier('linux-x86_64') + }) +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java new file mode 100644 index 0000000..8a6f967 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/AbstractSniHandler.java @@ -0,0 +1,222 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.CharsetUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.ScheduledFuture; + +import java.util.Locale; +import java.util.concurrent.TimeUnit; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + *

Enables SNI
+ * (Server Name Indication) extension for server side SSL. For clients
+ * that support SNI, the server could have multiple host names bound on a single IP.
+ * The client will send the host name in the handshake data so the server could decide
+ * which certificate to choose for the host name.

+ * @param <T> the parameter type
+ */
+public abstract class AbstractSniHandler<T> extends SslClientHelloHandler<T> {
+
+    private static String extractSniHostname(ByteBuf in) {
+        // See https://tools.ietf.org/html/rfc5246#section-7.4.1.2
+        //
+        // Decode the ssl client hello packet.
+        //
+        // struct {
+        //     ProtocolVersion client_version;
+        //     Random random;
+        //     SessionID session_id;
+        //     CipherSuite cipher_suites<2..2^16-2>;
+        //     CompressionMethod compression_methods<1..2^8-1>;
+        //     select (extensions_present) {
+        //         case false:
+        //             struct {};
+        //         case true:
+        //             Extension extensions<0..2^16-1>;
+        //     };
+        // } ClientHello;
+        //
+
+        // We have to skip bytes until SessionID (which sum to 34 bytes in this case).
+        int offset = in.readerIndex();
+        int endOffset = in.writerIndex();
+        offset += 34;
+
+        if (endOffset - offset >= 6) {
+            final int sessionIdLength = in.getUnsignedByte(offset);
+            offset += sessionIdLength + 1;
+
+            final int cipherSuitesLength = in.getUnsignedShort(offset);
+            offset += cipherSuitesLength + 2;
+
+            final int compressionMethodLength = in.getUnsignedByte(offset);
+            offset += compressionMethodLength + 1;
+
+            final int extensionsLength = in.getUnsignedShort(offset);
+            offset += 2;
+            final int extensionsLimit = offset + extensionsLength;
+
+            // Extensions should never exceed the record boundary.
+            if (extensionsLimit <= endOffset) {
+                while (extensionsLimit - offset >= 4) {
+                    final int extensionType = in.getUnsignedShort(offset);
+                    offset += 2;
+
+                    final int extensionLength = in.getUnsignedShort(offset);
+                    offset += 2;
+
+                    if (extensionsLimit - offset < extensionLength) {
+                        break;
+                    }
+
+                    // SNI
+                    // See https://tools.ietf.org/html/rfc6066#page-6
+                    if (extensionType == 0) {
+                        offset += 2;
+                        if (extensionsLimit - offset < 3) {
+                            break;
+                        }
+
+                        final int serverNameType = in.getUnsignedByte(offset);
+                        offset++;
+
+                        if (serverNameType == 0) {
+                            final int serverNameLength = in.getUnsignedShort(offset);
+                            offset += 2;
+
+                            if (extensionsLimit - offset < serverNameLength) {
+                                break;
+                            }
+
+                            final String hostname = in.toString(offset, serverNameLength, CharsetUtil.US_ASCII);
+                            return hostname.toLowerCase(Locale.US);
+                        } else {
+                            // invalid enum value
+                            break;
+                        }
+                    }
+
+                    offset += extensionLength;
+                }
+            }
+        }
+        return null;
+    }
+
+    protected final long handshakeTimeoutMillis;
+    private ScheduledFuture<?> timeoutFuture;
+    private String hostname;
+
+    /**
+     * @param handshakeTimeoutMillis the handshake timeout in milliseconds
+     */
+    protected AbstractSniHandler(long handshakeTimeoutMillis) {
+        this(0, handshakeTimeoutMillis);
+    }
+
+    /**
+     * @param maxClientHelloLength the maximum length of the client hello message.
+ * @param handshakeTimeoutMillis the handshake timeout in milliseconds + */ + protected AbstractSniHandler(int maxClientHelloLength, long handshakeTimeoutMillis) { + super(maxClientHelloLength); + this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + } + + public AbstractSniHandler() { + this(0, 0L); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isActive()) { + checkStartTimeout(ctx); + } + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelActive(); + checkStartTimeout(ctx); + } + + private void checkStartTimeout(final ChannelHandlerContext ctx) { + if (handshakeTimeoutMillis <= 0 || timeoutFuture != null) { + return; + } + timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (ctx.channel().isActive()) { + SslHandshakeTimeoutException exception = new SslHandshakeTimeoutException( + "handshake timed out after " + handshakeTimeoutMillis + "ms"); + ctx.fireUserEventTriggered(new SniCompletionEvent(exception)); + ctx.close(); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + } + + @Override + protected Future lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception { + hostname = clientHello == null ? null : extractSniHostname(clientHello); + + return lookup(ctx, hostname); + } + + @Override + protected void onLookupComplete(ChannelHandlerContext ctx, Future future) throws Exception { + if (timeoutFuture != null) { + timeoutFuture.cancel(false); + } + try { + onLookupComplete(ctx, hostname, future); + } finally { + fireSniCompletionEvent(ctx, hostname, future); + } + } + + /** + * Kicks off a lookup for the given SNI value and returns a {@link Future} which in turn will + * notify the {@link #onLookupComplete(ChannelHandlerContext, String, Future)} on completion. 
+ * + * @see #onLookupComplete(ChannelHandlerContext, String, Future) + */ + protected abstract Future lookup(ChannelHandlerContext ctx, String hostname) throws Exception; + + /** + * Called upon completion of the {@link #lookup(ChannelHandlerContext, String)} {@link Future}. + * + * @see #lookup(ChannelHandlerContext, String) + */ + protected abstract void onLookupComplete(ChannelHandlerContext ctx, + String hostname, Future future) throws Exception; + + private static void fireSniCompletionEvent(ChannelHandlerContext ctx, String hostname, Future future) { + Throwable cause = future.cause(); + if (cause == null) { + ctx.fireUserEventTriggered(new SniCompletionEvent(hostname)); + } else { + ctx.fireUserEventTriggered(new SniCompletionEvent(hostname, cause)); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolAccessor.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolAccessor.java new file mode 100644 index 0000000..489d59d --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolAccessor.java @@ -0,0 +1,30 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +/** + * Provides a way to get the application-level protocol name from ALPN or NPN. 
+ */ +interface ApplicationProtocolAccessor { + /** + * Returns the name of the negotiated application-level protocol. + * + * @return the application-level protocol name or + * {@code null} if the negotiation failed or the client does not have ALPN/NPN extension + */ + String getNegotiatedApplicationProtocol(); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolConfig.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolConfig.java new file mode 100644 index 0000000..2392411 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolConfig.java @@ -0,0 +1,184 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import java.util.Collections; +import java.util.List; + +import javax.net.ssl.SSLEngine; + +import static io.netty.handler.ssl.ApplicationProtocolUtil.toList; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +/** + * Provides an {@link SSLEngine} agnostic way to configure a {@link ApplicationProtocolNegotiator}. + */ +public final class ApplicationProtocolConfig { + + /** + * The configuration that disables application protocol negotiation. 
    /**
     * The default, "disabled" configuration: no application protocol negotiation
     * (ALPN/NPN) will be performed.
     */
    public static final ApplicationProtocolConfig DISABLED = new ApplicationProtocolConfig();

    // Unmodifiable list of protocol names, in preference order (most preferred first).
    private final List<String> supportedProtocols;
    // Which TLS extension(s) to negotiate with: NPN, ALPN or both. Never NONE for a public instance.
    private final Protocol protocol;
    // Behavior of the peer that performs the protocol selection.
    private final SelectorFailureBehavior selectorBehavior;
    // Behavior of the peer that is notified of the selected protocol.
    private final SelectedListenerFailureBehavior selectedBehavior;

    /**
     * Create a new instance.
     * @param protocol The application protocol functionality to use.
     * @param selectorBehavior How the peer selecting the protocol should behave.
     * @param selectedBehavior How the peer being notified of the selected protocol should behave.
     * @param supportedProtocols The order of iteration determines the preference of support for protocols.
     */
    public ApplicationProtocolConfig(Protocol protocol, SelectorFailureBehavior selectorBehavior,
            SelectedListenerFailureBehavior selectedBehavior, Iterable<String> supportedProtocols) {
        this(protocol, selectorBehavior, selectedBehavior, toList(supportedProtocols));
    }

    /**
     * Create a new instance.
     * @param protocol The application protocol functionality to use.
     * @param selectorBehavior How the peer selecting the protocol should behave.
     * @param selectedBehavior How the peer being notified of the selected protocol should behave.
     * @param supportedProtocols The order of iteration determines the preference of support for protocols.
     */
    public ApplicationProtocolConfig(Protocol protocol, SelectorFailureBehavior selectorBehavior,
            SelectedListenerFailureBehavior selectedBehavior, String... supportedProtocols) {
        this(protocol, selectorBehavior, selectedBehavior, toList(supportedProtocols));
    }

    /**
     * Create a new instance.
     * @param protocol The application protocol functionality to use.
     * @param selectorBehavior How the peer selecting the protocol should behave.
     * @param selectedBehavior How the peer being notified of the selected protocol should behave.
     * @param supportedProtocols The order of iteration determines the preference of support for protocols.
     */
    private ApplicationProtocolConfig(
            Protocol protocol, SelectorFailureBehavior selectorBehavior,
            SelectedListenerFailureBehavior selectedBehavior, List<String> supportedProtocols) {
        // Wrap first so the stored list can never be mutated through the caller's reference.
        this.supportedProtocols = Collections.unmodifiableList(checkNotNull(supportedProtocols, "supportedProtocols"));
        this.protocol = checkNotNull(protocol, "protocol");
        this.selectorBehavior = checkNotNull(selectorBehavior, "selectorBehavior");
        this.selectedBehavior = checkNotNull(selectedBehavior, "selectedBehavior");

        // Protocol.NONE is reserved for the DISABLED singleton created via the no-arg constructor.
        if (protocol == Protocol.NONE) {
            throw new IllegalArgumentException("protocol (" + Protocol.NONE + ") must not be " + Protocol.NONE + '.');
        }
        checkNonEmpty(supportedProtocols, "supportedProtocols");
    }

    /**
     * A special constructor that is used to instantiate {@link #DISABLED}.
     */
    private ApplicationProtocolConfig() {
        supportedProtocols = Collections.emptyList();
        protocol = Protocol.NONE;
        selectorBehavior = SelectorFailureBehavior.CHOOSE_MY_LAST_PROTOCOL;
        selectedBehavior = SelectedListenerFailureBehavior.ACCEPT;
    }

    /**
     * Defines which application level protocol negotiation to use.
     */
    public enum Protocol {
        NONE, NPN, ALPN, NPN_AND_ALPN
    }

    /**
     * Defines the most common behaviors for the peer that selects the application protocol.
     */
    public enum SelectorFailureBehavior {
        /**
         * If the peer who selects the application protocol doesn't find a match this will result in failing the
         * handshake with a fatal alert. For example in the case of ALPN this will result in a
         * {@code no_application_protocol(120)} alert.
         */
        FATAL_ALERT,
        /**
         * If the peer who selects the application protocol doesn't find a match it will pretend not to support
         * the TLS extension by not advertising support for the TLS extension in the handshake. This is used in
         * cases where a "best effort" is desired to talk even if there is no matching protocol.
         */
        NO_ADVERTISE,
        /**
         * If the peer who selects the application protocol doesn't find a match it will just select the last
         * protocol it advertised support for. This is used in cases where a "best effort" is desired to talk
         * even if there is no matching protocol, and the assumption is the "most general" fallback protocol is
         * typically listed last. This may be illegal for some RFCs but was observed behavior by some SSL
         * implementations, and is supported for flexibility/compatibility.
         */
        CHOOSE_MY_LAST_PROTOCOL
    }

    /**
     * Defines the most common behaviors for the peer which is notified of the selected protocol.
     */
    public enum SelectedListenerFailureBehavior {
        /**
         * If the peer who is notified what protocol was selected determines the selection was not matched, or the
         * peer didn't advertise support for the TLS extension then the handshake will continue and the application
         * protocol is assumed to be accepted.
         */
        ACCEPT,
        /**
         * If the peer who is notified what protocol was selected determines the selection was not matched, or the
         * peer didn't advertise support for the TLS extension then the handshake will be failed with a fatal alert.
         */
        FATAL_ALERT,
        /**
         * If the peer who is notified what protocol was selected determines the selection was not matched, or the
         * peer didn't advertise support for the TLS extension then the handshake will continue assuming the last
         * protocol supported by this peer is used. This is used in cases where a "best effort" is desired to talk
         * even if there is no matching protocol, and the assumption is the "most general" fallback protocol is
         * typically listed last.
         */
        CHOOSE_MY_LAST_PROTOCOL
    }

    /**
     * The application level protocols supported.
     */
    public List<String> supportedProtocols() {
        return supportedProtocols;
    }

    /**
     * Get which application level protocol negotiation to use.
     */
    public Protocol protocol() {
        return protocol;
    }

    /**
     * Get the desired behavior for the peer who selects the application protocol.
     */
    public SelectorFailureBehavior selectorFailureBehavior() {
        return selectorBehavior;
    }

    /**
     * Get the desired behavior for the peer who is notified of the selected protocol.
     */
    public SelectedListenerFailureBehavior selectedListenerFailureBehavior() {
        return selectedBehavior;
    }
}
/**
 * Well-known application-level protocol names used with ALPN and NPN.
 *
 * @see <a href="https://tools.ietf.org/html/rfc7540">RFC 7540 (HTTP/2)</a>
 * @see <a href="https://tools.ietf.org/html/rfc7301">RFC 7301 (TLS ALPN Extension)</a>
 */
public final class ApplicationProtocolNames {

    /** {@code "h2"}: HTTP version 2. */
    public static final String HTTP_2 = "h2";

    /** {@code "http/1.1"}: HTTP version 1.1. */
    public static final String HTTP_1_1 = "http/1.1";

    /** {@code "spdy/3.1"}: SPDY version 3.1. */
    public static final String SPDY_3_1 = "spdy/3.1";

    /** {@code "spdy/3"}: SPDY version 3. */
    public static final String SPDY_3 = "spdy/3";

    /** {@code "spdy/2"}: SPDY version 2. */
    public static final String SPDY_2 = "spdy/2";

    /** {@code "spdy/1"}: SPDY version 1. */
    public static final String SPDY_1 = "spdy/1";

    private ApplicationProtocolNames() {
        // constants only
    }
}
+ */ +package io.netty.handler.ssl; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.socket.ChannelInputShutdownEvent; +import io.netty.handler.codec.DecoderException; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.RecyclableArrayList; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import javax.net.ssl.SSLException; + +/** + * Configures a {@link ChannelPipeline} depending on the application-level protocol negotiation result of + * {@link SslHandler}. For example, you could configure your HTTP pipeline depending on the result of ALPN: + *

+ * public class MyInitializer extends {@link ChannelInitializer}<{@link Channel}> {
+ *     private final {@link SslContext} sslCtx;
+ *
+ *     public MyInitializer({@link SslContext} sslCtx) {
+ *         this.sslCtx = sslCtx;
+ *     }
+ *
+ *     protected void initChannel({@link Channel} ch) {
+ *         {@link ChannelPipeline} p = ch.pipeline();
+ *         p.addLast(sslCtx.newHandler(...)); // Adds {@link SslHandler}
+ *         p.addLast(new MyNegotiationHandler());
+ *     }
+ * }
+ *
+ * public class MyNegotiationHandler extends {@link ApplicationProtocolNegotiationHandler} {
+ *     public MyNegotiationHandler() {
+ *         super({@link ApplicationProtocolNames}.HTTP_1_1);
+ *     }
+ *
+ *     protected void configurePipeline({@link ChannelHandlerContext} ctx, String protocol) {
+ *         if ({@link ApplicationProtocolNames}.HTTP_2.equals(protocol)) {
+ *             configureHttp2(ctx);
+ *         } else if ({@link ApplicationProtocolNames}.HTTP_1_1.equals(protocol)) {
+ *             configureHttp1(ctx);
+ *         } else {
+ *             throw new IllegalStateException("unknown protocol: " + protocol);
+ *         }
+ *     }
+ * }
+ * 
+ */ +public abstract class ApplicationProtocolNegotiationHandler extends ChannelInboundHandlerAdapter { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(ApplicationProtocolNegotiationHandler.class); + + private final String fallbackProtocol; + private final RecyclableArrayList bufferedMessages = RecyclableArrayList.newInstance(); + private ChannelHandlerContext ctx; + private boolean sslHandlerChecked; + + /** + * Creates a new instance with the specified fallback protocol name. + * + * @param fallbackProtocol the name of the protocol to use when + * ALPN/NPN negotiation fails or the client does not support ALPN/NPN + */ + protected ApplicationProtocolNegotiationHandler(String fallbackProtocol) { + this.fallbackProtocol = ObjectUtil.checkNotNull(fallbackProtocol, "fallbackProtocol"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + super.handlerAdded(ctx); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + fireBufferedMessages(); + bufferedMessages.recycle(); + super.handlerRemoved(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Let's buffer all data until this handler will be removed from the pipeline. + bufferedMessages.add(msg); + if (!sslHandlerChecked) { + sslHandlerChecked = true; + if (ctx.pipeline().get(SslHandler.class) == null) { + // Just remove ourself if there is no SslHandler in the pipeline and so we would otherwise + // buffer forever. + removeSelfIfPresent(ctx); + } + } + } + + /** + * Process all backlog into pipeline from List. 
+ */ + private void fireBufferedMessages() { + if (!bufferedMessages.isEmpty()) { + for (int i = 0; i < bufferedMessages.size(); i++) { + ctx.fireChannelRead(bufferedMessages.get(i)); + } + ctx.fireChannelReadComplete(); + bufferedMessages.clear(); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent handshakeEvent = (SslHandshakeCompletionEvent) evt; + try { + if (handshakeEvent.isSuccess()) { + SslHandler sslHandler = ctx.pipeline().get(SslHandler.class); + if (sslHandler == null) { + throw new IllegalStateException("cannot find an SslHandler in the pipeline (required for " + + "application-level protocol negotiation)"); + } + String protocol = sslHandler.applicationProtocol(); + configurePipeline(ctx, protocol != null ? protocol : fallbackProtocol); + } else { + // if the event is not produced because of an successful handshake we will receive the same + // exception in exceptionCaught(...) and handle it there. This will allow us more fine-grained + // control over which exception we propagate down the ChannelPipeline. + // + // See https://github.com/netty/netty/issues/10342 + } + } catch (Throwable cause) { + exceptionCaught(ctx, cause); + } finally { + // Handshake failures are handled in exceptionCaught(...). + if (handshakeEvent.isSuccess()) { + removeSelfIfPresent(ctx); + } + } + } + + if (evt instanceof ChannelInputShutdownEvent) { + fireBufferedMessages(); + } + + ctx.fireUserEventTriggered(evt); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + fireBufferedMessages(); + super.channelInactive(ctx); + } + + private void removeSelfIfPresent(ChannelHandlerContext ctx) { + ChannelPipeline pipeline = ctx.pipeline(); + if (!ctx.isRemoved()) { + pipeline.remove(this); + } + } + + /** + * Invoked on successful initial SSL/TLS handshake. 
Implement this method to configure your pipeline + * for the negotiated application-level protocol. + * + * @param protocol the name of the negotiated application-level protocol, or + * the fallback protocol name specified in the constructor call if negotiation failed or the client + * isn't aware of ALPN/NPN extension + */ + protected abstract void configurePipeline(ChannelHandlerContext ctx, String protocol) throws Exception; + + /** + * Invoked on failed initial SSL/TLS handshake. + */ + protected void handshakeFailure(ChannelHandlerContext ctx, Throwable cause) throws Exception { + logger.warn("{} TLS handshake failed:", ctx.channel(), cause); + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + Throwable wrapped; + if (cause instanceof DecoderException && ((wrapped = cause.getCause()) instanceof SSLException)) { + try { + handshakeFailure(ctx, wrapped); + return; + } finally { + removeSelfIfPresent(ctx); + } + } + logger.warn("{} Failed to select the application-level protocol:", ctx.channel(), cause); + ctx.fireExceptionCaught(cause); + ctx.close(); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiator.java new file mode 100644 index 0000000..2d5820d --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ApplicationProtocolNegotiator.java @@ -0,0 +1,37 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import java.util.List; + +/** + * Interface to support Application Protocol Negotiation. + *

+ * Default implementations are provided for: + *

/**
 * Interface to support Application Protocol Negotiation.
 * <p>
 * Default implementations are provided for ALPN and NPN.
 *
 * @deprecated use {@link ApplicationProtocolConfig}
 */
@Deprecated
@SuppressWarnings("deprecation")
public interface ApplicationProtocolNegotiator {
    /**
     * Get the collection of application protocols supported by this application (in preference order).
     *
     * @return the supported protocol names, most preferred first
     */
    List<String> protocols();
}
+ */ +final class ApplicationProtocolUtil { + private static final int DEFAULT_LIST_SIZE = 2; + + private ApplicationProtocolUtil() { + } + + static List toList(Iterable protocols) { + return toList(DEFAULT_LIST_SIZE, protocols); + } + + static List toList(int initialListSize, Iterable protocols) { + if (protocols == null) { + return null; + } + + List result = new ArrayList(initialListSize); + for (String p : protocols) { + result.add(checkNonEmpty(p, "p")); + } + + return checkNonEmpty(result, "result"); + } + + static List toList(String... protocols) { + return toList(DEFAULT_LIST_SIZE, protocols); + } + + static List toList(int initialListSize, String... protocols) { + if (protocols == null) { + return null; + } + + List result = new ArrayList(initialListSize); + for (String p : protocols) { + result.add(checkNonEmpty(p, "p")); + } + + return checkNonEmpty(result, "result"); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/AsyncRunnable.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/AsyncRunnable.java new file mode 100644 index 0000000..6fb620e --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/AsyncRunnable.java @@ -0,0 +1,20 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * A {@link Runnable} variant whose completion is signalled asynchronously via a callback.
 */
interface AsyncRunnable extends Runnable {
    void run(Runnable completionCallback);
}

/**
 * Contains methods that can be used to detect if BouncyCastle is usable.
 */
final class BouncyCastle {

    private static final boolean BOUNCY_CASTLE_ON_CLASSPATH;

    static {
        boolean available = false;
        try {
            Class.forName("org.bouncycastle.jsse.provider.BouncyCastleJsseProvider");
            available = true;
        } catch (Throwable ignore) {
            // BouncyCastle JSSE provider is not on the classpath.
        }
        BOUNCY_CASTLE_ON_CLASSPATH = available;
    }

    private BouncyCastle() {
        // static helpers only
    }

    /**
     * Indicates whether or not BouncyCastle is available on the current system.
     */
    static boolean isAvailable() {
        return BOUNCY_CASTLE_ON_CLASSPATH;
    }

    /**
     * Indicates whether or not BouncyCastle is the provider behind the given {@link SSLEngine}.
     */
    static boolean isInUse(SSLEngine engine) {
        return engine.getClass().getPackage().getName().startsWith("org.bouncycastle.jsse.provider");
    }
}
+ */ +package io.netty.handler.ssl; + +import io.netty.util.internal.SuppressJava6Requirement; + +import javax.net.ssl.SSLEngine; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class BouncyCastleAlpnSslEngine extends JdkAlpnSslEngine { + + BouncyCastleAlpnSslEngine(SSLEngine engine, + @SuppressWarnings("deprecation") JdkApplicationProtocolNegotiator applicationNegotiator, + boolean isServer) { + super(engine, applicationNegotiator, isServer, + new BiConsumer() { + @Override + public void accept(SSLEngine e, AlpnSelector s) { + BouncyCastleAlpnSslUtils.setHandshakeApplicationProtocolSelector(e, s); + } + }, + new BiConsumer>() { + @Override + public void accept(SSLEngine e, List p) { + BouncyCastleAlpnSslUtils.setApplicationProtocols(e, p); + } + }); + } + + public String getApplicationProtocol() { + return BouncyCastleAlpnSslUtils.getApplicationProtocol(getWrappedEngine()); + } + + public String getHandshakeApplicationProtocol() { + return BouncyCastleAlpnSslUtils.getHandshakeApplicationProtocol(getWrappedEngine()); + } + + public void setHandshakeApplicationProtocolSelector(BiFunction, String> selector) { + BouncyCastleAlpnSslUtils.setHandshakeApplicationProtocolSelector(getWrappedEngine(), selector); + } + + public BiFunction, String> getHandshakeApplicationProtocolSelector() { + return BouncyCastleAlpnSslUtils.getHandshakeApplicationProtocolSelector(getWrappedEngine()); + } + +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastleAlpnSslUtils.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastleAlpnSslUtils.java new file mode 100644 index 0000000..50d4ab6 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/BouncyCastleAlpnSslUtils.java @@ -0,0 +1,259 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache 
/**
 * Reflection-based bridge to the BouncyCastle JSSE ALPN API ({@code BCSSLEngine} /
 * {@code BCApplicationProtocolSelector}), so this module has no compile-time dependency
 * on BouncyCastle. All reflected handles are resolved once in the static initializer;
 * when resolution fails they are left {@code null} and every call site will fail with
 * an {@link IllegalStateException} (typically a wrapped NPE).
 */
@SuppressJava6Requirement(reason = "Usage guarded by java version check")
final class BouncyCastleAlpnSslUtils {
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(BouncyCastleAlpnSslUtils.class);
    // Reflected BCSSLEngine / BCSSLParameters accessors; null when BouncyCastle JSSE is
    // unavailable (i.e. the static initializer failed).
    private static final Method SET_PARAMETERS;
    private static final Method GET_PARAMETERS;
    private static final Method SET_APPLICATION_PROTOCOLS;
    private static final Method GET_APPLICATION_PROTOCOL;
    private static final Method GET_HANDSHAKE_APPLICATION_PROTOCOL;
    private static final Method SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR;
    private static final Method GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR;
    private static final Class<?> BC_APPLICATION_PROTOCOL_SELECTOR;
    private static final Method BC_APPLICATION_PROTOCOL_SELECTOR_SELECT;

    static {
        Class<?> bcSslEngine;
        Method getParameters;
        Method setParameters;
        Method setApplicationProtocols;
        Method getApplicationProtocol;
        Method getHandshakeApplicationProtocol;
        Method setHandshakeApplicationProtocolSelector;
        Method getHandshakeApplicationProtocolSelector;
        Method bcApplicationProtocolSelectorSelect;
        Class<?> bcApplicationProtocolSelector;

        try {
            bcSslEngine = Class.forName("org.bouncycastle.jsse.BCSSLEngine");
            final Class<?> testBCSslEngine = bcSslEngine;

            bcApplicationProtocolSelector =
                    Class.forName("org.bouncycastle.jsse.BCApplicationProtocolSelector");

            final Class<?> testBCApplicationProtocolSelector = bcApplicationProtocolSelector;

            bcApplicationProtocolSelectorSelect = AccessController.doPrivileged(
                    new PrivilegedExceptionAction<Method>() {
                        @Override
                        public Method run() throws Exception {
                            return testBCApplicationProtocolSelector.getMethod("select", Object.class, List.class);
                        }
                    });

            // Create a throwaway BCJSSE engine so each reflected method can be invoked once,
            // verifying up front that the BouncyCastle version actually supports this API.
            SSLContext context = getSSLContext("BCJSSE");
            SSLEngine engine = context.createSSLEngine();

            getParameters = AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                @Override
                public Method run() throws Exception {
                    return testBCSslEngine.getMethod("getParameters");
                }
            });

            // BCSSLParameters is looked up from the live instance rather than by name.
            final Object bcSslParameters = getParameters.invoke(engine);
            final Class<?> bCSslParametersClass = bcSslParameters.getClass();

            setParameters = AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                @Override
                public Method run() throws Exception {
                    return testBCSslEngine.getMethod("setParameters", bCSslParametersClass);
                }
            });
            setParameters.invoke(engine, bcSslParameters);

            setApplicationProtocols = AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                @Override
                public Method run() throws Exception {
                    return bCSslParametersClass.getMethod("setApplicationProtocols", String[].class);
                }
            });
            setApplicationProtocols.invoke(bcSslParameters, new Object[]{EmptyArrays.EMPTY_STRINGS});

            getApplicationProtocol = AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                @Override
                public Method run() throws Exception {
                    return testBCSslEngine.getMethod("getApplicationProtocol");
                }
            });
            getApplicationProtocol.invoke(engine);

            getHandshakeApplicationProtocol = AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                @Override
                public Method run() throws Exception {
                    return testBCSslEngine.getMethod("getHandshakeApplicationProtocol");
                }
            });
            getHandshakeApplicationProtocol.invoke(engine);

            setHandshakeApplicationProtocolSelector =
                    AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                        @Override
                        public Method run() throws Exception {
                            return testBCSslEngine.getMethod("setBCHandshakeApplicationProtocolSelector",
                                    testBCApplicationProtocolSelector);
                        }
                    });

            getHandshakeApplicationProtocolSelector =
                    AccessController.doPrivileged(new PrivilegedExceptionAction<Method>() {
                        @Override
                        public Method run() throws Exception {
                            return testBCSslEngine.getMethod("getBCHandshakeApplicationProtocolSelector");
                        }
                    });
            getHandshakeApplicationProtocolSelector.invoke(engine);

        } catch (Throwable t) {
            // Any failure (missing classes, missing methods, engine creation) disables the
            // whole bridge; callers are expected to have checked BouncyCastle.isAvailable().
            logger.error("Unable to initialize BouncyCastleAlpnSslUtils.", t);
            setParameters = null;
            getParameters = null;
            setApplicationProtocols = null;
            getApplicationProtocol = null;
            getHandshakeApplicationProtocol = null;
            setHandshakeApplicationProtocolSelector = null;
            getHandshakeApplicationProtocolSelector = null;
            bcApplicationProtocolSelectorSelect = null;
            bcApplicationProtocolSelector = null;
        }
        SET_PARAMETERS = setParameters;
        GET_PARAMETERS = getParameters;
        SET_APPLICATION_PROTOCOLS = setApplicationProtocols;
        GET_APPLICATION_PROTOCOL = getApplicationProtocol;
        GET_HANDSHAKE_APPLICATION_PROTOCOL = getHandshakeApplicationProtocol;
        SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = setHandshakeApplicationProtocolSelector;
        GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getHandshakeApplicationProtocolSelector;
        BC_APPLICATION_PROTOCOL_SELECTOR_SELECT = bcApplicationProtocolSelectorSelect;
        BC_APPLICATION_PROTOCOL_SELECTOR = bcApplicationProtocolSelector;
    }

    private BouncyCastleAlpnSslUtils() {
    }

    /**
     * Returns the protocol negotiated for the current session of the given BC engine.
     */
    static String getApplicationProtocol(SSLEngine sslEngine) {
        try {
            return (String) GET_APPLICATION_PROTOCOL.invoke(sslEngine);
        } catch (UnsupportedOperationException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new IllegalStateException(ex);
        }
    }

    /**
     * Sets the ALPN protocol list on the BC engine's parameters (and, on Java 9+, also on the
     * standard javax.net.ssl parameters so both APIs agree).
     */
    static void setApplicationProtocols(SSLEngine engine, List<String> supportedProtocols) {
        String[] protocolArray = supportedProtocols.toArray(EmptyArrays.EMPTY_STRINGS);
        try {
            // Read-modify-write of the BCSSLParameters object.
            Object bcSslParameters = GET_PARAMETERS.invoke(engine);
            SET_APPLICATION_PROTOCOLS.invoke(bcSslParameters, new Object[]{protocolArray});
            SET_PARAMETERS.invoke(engine, bcSslParameters);
        } catch (UnsupportedOperationException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new IllegalStateException(ex);
        }
        if (PlatformDependent.javaVersion() >= 9) {
            JdkAlpnSslUtils.setApplicationProtocols(engine, supportedProtocols);
        }
    }

    /**
     * Returns the protocol negotiated so far during the ongoing handshake of the given BC engine.
     */
    static String getHandshakeApplicationProtocol(SSLEngine sslEngine) {
        try {
            return (String) GET_HANDSHAKE_APPLICATION_PROTOCOL.invoke(sslEngine);
        } catch (UnsupportedOperationException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new IllegalStateException(ex);
        }
    }

    /**
     * Installs {@code selector} on the BC engine by wrapping it in a dynamic proxy that
     * implements {@code BCApplicationProtocolSelector}.
     */
    static void setHandshakeApplicationProtocolSelector(
            SSLEngine engine, final BiFunction<SSLEngine, List<String>, String> selector) {
        try {
            Object selectorProxyInstance = Proxy.newProxyInstance(
                    BouncyCastleAlpnSslUtils.class.getClassLoader(),
                    new Class[]{BC_APPLICATION_PROTOCOL_SELECTOR},
                    new InvocationHandler() {
                        @Override
                        public Object invoke(Object proxy, Method method, Object[] args) throws Throwable {
                            if (method.getName().equals("select")) {
                                try {
                                    return selector.apply((SSLEngine) args[0], (List<String>) args[1]);
                                } catch (ClassCastException e) {
                                    throw new RuntimeException("BCApplicationProtocolSelector select method " +
                                            "parameter of invalid type.", e);
                                }
                            } else {
                                throw new UnsupportedOperationException(String.format("Method '%s' not supported.",
                                        method.getName()));
                            }
                        }
                    });

            SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(engine, selectorProxyInstance);
        } catch (UnsupportedOperationException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new IllegalStateException(ex);
        }
    }

    /**
     * Returns the installed selector of the BC engine, adapted back into a {@link BiFunction}
     * that delegates to the underlying {@code BCApplicationProtocolSelector.select}.
     */
    @SuppressWarnings("unchecked")
    static BiFunction<SSLEngine, List<String>, String> getHandshakeApplicationProtocolSelector(SSLEngine engine) {
        try {
            final Object selector = GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(engine);
            return new BiFunction<SSLEngine, List<String>, String>() {

                @Override
                public String apply(SSLEngine sslEngine, List<String> strings) {
                    try {
                        return (String) BC_APPLICATION_PROTOCOL_SELECTOR_SELECT.invoke(selector, sslEngine,
                                strings);
                    } catch (Exception e) {
                        throw new RuntimeException("Could not call getHandshakeApplicationProtocolSelector", e);
                    }
                }
            };

        } catch (UnsupportedOperationException ex) {
            throw ex;
        } catch (Exception ex) {
            throw new IllegalStateException(ex);
        }
    }
}
package io.netty.handler.ssl;

import io.netty.util.CharsetUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import org.bouncycastle.asn1.pkcs.PrivateKeyInfo;
import org.bouncycastle.openssl.PEMDecryptorProvider;
import org.bouncycastle.openssl.PEMEncryptedKeyPair;
import org.bouncycastle.openssl.PEMKeyPair;
import org.bouncycastle.openssl.PEMParser;
import org.bouncycastle.openssl.jcajce.JcaPEMKeyConverter;
import org.bouncycastle.openssl.jcajce.JceOpenSSLPKCS8DecryptorProviderBuilder;
import org.bouncycastle.openssl.jcajce.JcePEMDecryptorProviderBuilder;
import org.bouncycastle.operator.InputDecryptorProvider;
import org.bouncycastle.operator.OperatorCreationException;
import org.bouncycastle.pkcs.PKCS8EncryptedPrivateKeyInfo;
import org.bouncycastle.pkcs.PKCSException;

import java.io.File;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.security.AccessController;
import java.security.PrivateKey;
import java.security.PrivilegedAction;
import java.security.Provider;

/**
 * Reads PEM encoded private keys (PKCS#1 or PKCS#8, optionally password protected) with the help
 * of Bouncy Castle, if the Bouncy Castle provider and PKIX classes are on the classpath.
 * All methods degrade gracefully (returning {@code null}) when Bouncy Castle is unavailable.
 */
final class BouncyCastlePemReader {
    private static final String BC_PROVIDER = "org.bouncycastle.jce.provider.BouncyCastleProvider";
    private static final String BC_PEMPARSER = "org.bouncycastle.openssl.PEMParser";
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(BouncyCastlePemReader.class);

    // Why Bouncy Castle could not be loaded; null means it is available (once attemptedLoading is true).
    private static volatile Throwable unavailabilityCause;
    private static volatile Provider bcProvider;
    private static volatile boolean attemptedLoading;

    /**
     * Returns {@code true} if a classpath probe for Bouncy Castle was already performed.
     */
    public static boolean hasAttemptedLoading() {
        return attemptedLoading;
    }

    /**
     * Returns {@code true} if Bouncy Castle is available, probing the classpath lazily on first call.
     */
    public static boolean isAvailable() {
        if (!hasAttemptedLoading()) {
            tryLoading();
        }
        return unavailabilityCause == null;
    }

    /**
     * @return the cause if unavailable. {@code null} if available.
     */
    public static Throwable unavailabilityCause() {
        return unavailabilityCause;
    }

    /**
     * Probes the classpath for bcprov (provider) and bcpkix (PEMParser) and records either the
     * instantiated provider or the failure cause. Runs privileged so it also works under a
     * SecurityManager.
     */
    @SuppressWarnings("unchecked")
    private static void tryLoading() {
        AccessController.doPrivileged(new PrivilegedAction<Void>() {
            @Override
            public Void run() {
                try {
                    ClassLoader classLoader = getClass().getClassLoader();
                    // Check for bcprov-jdk15on:
                    Class<Provider> bcProviderClass =
                            (Class<Provider>) Class.forName(BC_PROVIDER, true, classLoader);
                    // Check for bcpkix-jdk15on:
                    Class.forName(BC_PEMPARSER, true, classLoader);
                    bcProvider = bcProviderClass.getConstructor().newInstance();
                    logger.debug("Bouncy Castle provider available");
                    attemptedLoading = true;
                } catch (Throwable e) {
                    logger.debug("Cannot load Bouncy Castle provider", e);
                    unavailabilityCause = e;
                    // Set last so isAvailable() never observes "attempted" without a recorded outcome.
                    attemptedLoading = true;
                }
                return null;
            }
        });
    }

    /**
     * Generates a new {@link PrivateKey}.
     *
     * @param keyInputStream an input stream for a PKCS#1 or PKCS#8 private key in PEM format.
     * @param keyPassword the password of the {@code keyFile}.
     *                    {@code null} if it's not password-protected.
     * @return generated {@link PrivateKey}, or {@code null} if Bouncy Castle is unavailable or
     *         the key could not be extracted.
     */
    public static PrivateKey getPrivateKey(InputStream keyInputStream, String keyPassword) {
        if (!isAvailable()) {
            if (logger.isDebugEnabled()) {
                logger.debug("Bouncy castle provider is unavailable.", unavailabilityCause());
            }
            return null;
        }
        try {
            PEMParser parser = newParser(keyInputStream);
            return getPrivateKey(parser, keyPassword);
        } catch (Exception e) {
            logger.debug("Unable to extract private key", e);
            return null;
        }
    }

    /**
     * Generates a new {@link PrivateKey}.
     *
     * @param keyFile a PKCS#1 or PKCS#8 private key file in PEM format.
     * @param keyPassword the password of the {@code keyFile}.
     *                    {@code null} if it's not password-protected.
     * @return generated {@link PrivateKey}, or {@code null} if Bouncy Castle is unavailable or
     *         the key could not be extracted.
     */
    public static PrivateKey getPrivateKey(File keyFile, String keyPassword) {
        if (!isAvailable()) {
            if (logger.isDebugEnabled()) {
                logger.debug("Bouncy castle provider is unavailable.", unavailabilityCause());
            }
            return null;
        }
        try {
            PEMParser parser = newParser(keyFile);
            return getPrivateKey(parser, keyPassword);
        } catch (Exception e) {
            logger.debug("Unable to extract private key", e);
            return null;
        }
    }

    private static JcaPEMKeyConverter newConverter() {
        return new JcaPEMKeyConverter().setProvider(bcProvider);
    }

    /**
     * Walks the PEM entries of {@code pemParser} until a private key is found, handling both
     * unencrypted (PKCS#1 {@link PEMKeyPair} / PKCS#8 {@link PrivateKeyInfo}) and encrypted
     * ({@link PEMEncryptedKeyPair} / {@link PKCS8EncryptedPrivateKeyInfo}) forms. Always closes
     * the parser.
     *
     * @return the extracted key, or {@code null} if no usable key entry was found.
     */
    private static PrivateKey getPrivateKey(PEMParser pemParser, String keyPassword) throws IOException,
            PKCSException, OperatorCreationException {
        try {
            JcaPEMKeyConverter converter = newConverter();
            PrivateKey pk = null;

            Object object = pemParser.readObject();
            while (object != null && pk == null) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Parsed PEM object of type {} and assume key is {}encrypted",
                            object.getClass().getName(), keyPassword == null? "not " : "");
                }

                if (keyPassword == null) {
                    // assume private key is not encrypted
                    if (object instanceof PrivateKeyInfo) {
                        pk = converter.getPrivateKey((PrivateKeyInfo) object);
                    } else if (object instanceof PEMKeyPair) {
                        pk = converter.getKeyPair((PEMKeyPair) object).getPrivate();
                    } else {
                        logger.debug("Unable to handle PEM object of type {} as a non encrypted key",
                                object.getClass());
                    }
                } else {
                    // assume private key is encrypted
                    if (object instanceof PEMEncryptedKeyPair) {
                        PEMDecryptorProvider decProv = new JcePEMDecryptorProviderBuilder()
                                .setProvider(bcProvider)
                                .build(keyPassword.toCharArray());
                        pk = converter.getKeyPair(((PEMEncryptedKeyPair) object).decryptKeyPair(decProv)).getPrivate();
                    } else if (object instanceof PKCS8EncryptedPrivateKeyInfo) {
                        InputDecryptorProvider pkcs8InputDecryptorProvider =
                                new JceOpenSSLPKCS8DecryptorProviderBuilder()
                                        .setProvider(bcProvider)
                                        .build(keyPassword.toCharArray());
                        pk = converter.getPrivateKey(((PKCS8EncryptedPrivateKeyInfo) object)
                                .decryptPrivateKeyInfo(pkcs8InputDecryptorProvider));
                    } else {
                        logger.debug("Unable to handle PEM object of type {} as a encrypted key", object.getClass());
                    }
                }

                // Try reading next entry in the pem file if private key is not yet found
                if (pk == null) {
                    object = pemParser.readObject();
                }
            }

            if (pk == null) {
                if (logger.isDebugEnabled()) {
                    logger.debug("No key found");
                }
            }

            return pk;
        } finally {
            if (pemParser != null) {
                try {
                    pemParser.close();
                } catch (Exception exception) {
                    logger.debug("Failed closing pem parser", exception);
                }
            }
        }
    }

    private static PEMParser newParser(File keyFile) throws FileNotFoundException {
        // PEM is an ASCII (Base64) format: decode with an explicit charset instead of FileReader,
        // which would use the platform default charset. Keeps behavior consistent with the
        // InputStream overload below.
        return new PEMParser(new InputStreamReader(new FileInputStream(keyFile), CharsetUtil.US_ASCII));
    }

    private static PEMParser newParser(InputStream keyInputStream) {
        return new PEMParser(new InputStreamReader(keyInputStream, CharsetUtil.US_ASCII));
    }

    private BouncyCastlePemReader() { }
}
a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/CipherSuiteConverter.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/CipherSuiteConverter.java new file mode 100644 index 0000000..910b5e3 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/CipherSuiteConverter.java @@ -0,0 +1,516 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.ConcurrentMap; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +import static java.util.Collections.singletonMap; + +/** + * Converts a Java cipher suite string to an OpenSSL cipher suite string and vice versa. 
/**
 * Converts a Java cipher suite string to an OpenSSL cipher suite string and vice versa.
 *
 * @see <a href="https://en.wikipedia.org/wiki/Cipher_suite">Wikipedia page about cipher suite</a>
 */
@UnstableApi
public final class CipherSuiteConverter {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(CipherSuiteConverter.class);

    /**
     * A_B_WITH_C_D, where:
     *
     *   A - TLS or SSL (protocol)
     *   B - handshake algorithm (key exchange and authentication algorithms to be precise)
     *   C - bulk cipher
     *   D - HMAC algorithm
     *
     * This regular expression assumes that:
     *
     *   1) A is always TLS or SSL, and
     *   2) D is always a single word.
     */
    private static final Pattern JAVA_CIPHERSUITE_PATTERN =
            Pattern.compile("^(?:TLS|SSL)_((?:(?!_WITH_).)+)_WITH_(.*)_(.*)$");

    /**
     * A-B-C, where:
     *
     *   A - handshake algorithm (key exchange and authentication algorithms to be precise)
     *   B - bulk cipher
     *   C - HMAC algorithm
     *
     * This regular expression assumes that:
     *
     *   1) A has some deterministic pattern as shown below, and
     *   2) C is always a single word
     */
    private static final Pattern OPENSSL_CIPHERSUITE_PATTERN =
            // Be very careful not to break the indentation while editing.
            Pattern.compile(
                    "^(?:(" + // BEGIN handshake algorithm
                        "(?:(?:EXP-)?" +
                            "(?:" +
                                "(?:DHE|EDH|ECDH|ECDHE|SRP|RSA)-(?:DSS|RSA|ECDSA|PSK)|" +
                                "(?:ADH|AECDH|KRB5|PSK|SRP)" +
                            ')' +
                        ")|" +
                        "EXP" +
                    ")-)?" +  // END handshake algorithm
                    "(.*)-(.*)$");

    private static final Pattern JAVA_AES_CBC_PATTERN = Pattern.compile("^(AES)_([0-9]+)_CBC$");
    private static final Pattern JAVA_AES_PATTERN = Pattern.compile("^(AES)_([0-9]+)_(.*)$");
    private static final Pattern OPENSSL_AES_CBC_PATTERN = Pattern.compile("^(AES)([0-9]+)$");
    private static final Pattern OPENSSL_AES_PATTERN = Pattern.compile("^(AES)([0-9]+)-(.*)$");

    /**
     * Used to store nullable values in a CHM (ConcurrentHashMap rejects null values, but a failed
     * conversion must still be cached to avoid repeating the regex work).
     */
    private static final class CachedValue {

        private static final CachedValue NULL = new CachedValue(null);

        static CachedValue of(String value) {
            return value != null ? new CachedValue(value) : NULL;
        }

        final String value;
        private CachedValue(String value) {
            this.value = value;
        }
    }

    /**
     * Java-to-OpenSSL cipher suite conversion map
     * Note that the Java cipher suite has the protocol prefix (TLS_, SSL_)
     */
    private static final ConcurrentMap<String, CachedValue> j2o = PlatformDependent.newConcurrentHashMap();

    /**
     * OpenSSL-to-Java cipher suite conversion map.
     * Note that one OpenSSL cipher suite can be converted to more than one Java cipher suites because
     * a Java cipher suite has the protocol name prefix (TLS_, SSL_)
     */
    private static final ConcurrentMap<String, Map<String, String>> o2j = PlatformDependent.newConcurrentHashMap();

    // TLSv1.3 names do not follow the legacy patterns above, so they are mapped via fixed tables.
    private static final Map<String, String> j2oTls13;
    private static final Map<String, Map<String, String>> o2jTls13;

    static {
        Map<String, String> j2oTls13Map = new HashMap<String, String>();
        j2oTls13Map.put("TLS_AES_128_GCM_SHA256", "AEAD-AES128-GCM-SHA256");
        j2oTls13Map.put("TLS_AES_256_GCM_SHA384", "AEAD-AES256-GCM-SHA384");
        j2oTls13Map.put("TLS_CHACHA20_POLY1305_SHA256", "AEAD-CHACHA20-POLY1305-SHA256");
        j2oTls13 = Collections.unmodifiableMap(j2oTls13Map);

        Map<String, Map<String, String>> o2jTls13Map = new HashMap<String, Map<String, String>>();
        o2jTls13Map.put("TLS_AES_128_GCM_SHA256", singletonMap("TLS", "TLS_AES_128_GCM_SHA256"));
        o2jTls13Map.put("TLS_AES_256_GCM_SHA384", singletonMap("TLS", "TLS_AES_256_GCM_SHA384"));
        o2jTls13Map.put("TLS_CHACHA20_POLY1305_SHA256", singletonMap("TLS", "TLS_CHACHA20_POLY1305_SHA256"));
        o2jTls13Map.put("AEAD-AES128-GCM-SHA256", singletonMap("TLS", "TLS_AES_128_GCM_SHA256"));
        o2jTls13Map.put("AEAD-AES256-GCM-SHA384", singletonMap("TLS", "TLS_AES_256_GCM_SHA384"));
        o2jTls13Map.put("AEAD-CHACHA20-POLY1305-SHA256", singletonMap("TLS", "TLS_CHACHA20_POLY1305_SHA256"));
        o2jTls13 = Collections.unmodifiableMap(o2jTls13Map);
    }

    /**
     * Clears the cache for testing purpose.
     */
    static void clearCache() {
        j2o.clear();
        o2j.clear();
    }

    /**
     * Tests if the specified key-value pair has been cached in Java-to-OpenSSL cache.
     */
    static boolean isJ2OCached(String key, String value) {
        CachedValue cached = j2o.get(key);
        return cached != null && value.equals(cached.value);
    }

    /**
     * Tests if the specified key-value pair has been cached in OpenSSL-to-Java cache.
     */
    static boolean isO2JCached(String key, String protocol, String value) {
        Map<String, String> p2j = o2j.get(key);
        if (p2j == null) {
            return false;
        } else {
            return value.equals(p2j.get(protocol));
        }
    }

    /**
     * Converts the specified Java cipher suite to its corresponding OpenSSL cipher suite name.
     *
     * @return {@code null} if the conversion has failed
     */
    public static String toOpenSsl(String javaCipherSuite, boolean boringSSL) {
        CachedValue converted = j2o.get(javaCipherSuite);
        if (converted != null) {
            return converted.value;
        }
        return cacheFromJava(javaCipherSuite, boringSSL);
    }

    // Performs the uncached conversion and populates both the forward (j2o) and reverse (o2j) caches.
    private static String cacheFromJava(String javaCipherSuite, boolean boringSSL) {
        String converted = j2oTls13.get(javaCipherSuite);
        if (converted != null) {
            // TLSv1.3: BoringSSL uses the AEAD-* aliases, OpenSSL keeps the Java (IANA) names.
            return boringSSL ? converted : javaCipherSuite;
        }

        String openSslCipherSuite = toOpenSslUncached(javaCipherSuite, boringSSL);

        // Cache the mapping.
        j2o.putIfAbsent(javaCipherSuite, CachedValue.of(openSslCipherSuite));

        if (openSslCipherSuite == null) {
            return null;
        }

        // Cache the reverse mapping after stripping the protocol prefix (TLS_ or SSL_)
        final String javaCipherSuiteSuffix = javaCipherSuite.substring(4);
        Map<String, String> p2j = new HashMap<String, String>(4);
        p2j.put("", javaCipherSuiteSuffix);
        p2j.put("SSL", "SSL_" + javaCipherSuiteSuffix);
        p2j.put("TLS", "TLS_" + javaCipherSuiteSuffix);
        o2j.put(openSslCipherSuite, p2j);

        logger.debug("Cipher suite mapping: {} => {}", javaCipherSuite, openSslCipherSuite);

        return openSslCipherSuite;
    }

    /**
     * Java-to-OpenSSL conversion without consulting or populating the caches.
     *
     * @return {@code null} if the Java name does not match the expected pattern
     */
    static String toOpenSslUncached(String javaCipherSuite, boolean boringSSL) {
        String converted = j2oTls13.get(javaCipherSuite);
        if (converted != null) {
            return boringSSL ? converted : javaCipherSuite;
        }

        Matcher m = JAVA_CIPHERSUITE_PATTERN.matcher(javaCipherSuite);
        if (!m.matches()) {
            return null;
        }

        String handshakeAlgo = toOpenSslHandshakeAlgo(m.group(1));
        String bulkCipher = toOpenSslBulkCipher(m.group(2));
        String hmacAlgo = toOpenSslHmacAlgo(m.group(3));
        if (handshakeAlgo.isEmpty()) {
            return bulkCipher + '-' + hmacAlgo;
        } else if (bulkCipher.contains("CHACHA20")) {
            // OpenSSL's CHACHA20 names omit the HMAC portion.
            return handshakeAlgo + '-' + bulkCipher;
        } else {
            return handshakeAlgo + '-' + bulkCipher + '-' + hmacAlgo;
        }
    }

    private static String toOpenSslHandshakeAlgo(String handshakeAlgo) {
        final boolean export = handshakeAlgo.endsWith("_EXPORT");
        if (export) {
            // Strip the "_EXPORT" suffix; re-attached below as an "EXP" / "EXP-" prefix.
            handshakeAlgo = handshakeAlgo.substring(0, handshakeAlgo.length() - 7);
        }

        if ("RSA".equals(handshakeAlgo)) {
            // Plain RSA has no explicit prefix on the OpenSSL side.
            handshakeAlgo = "";
        } else if (handshakeAlgo.endsWith("_anon")) {
            // e.g. DH_anon -> ADH, ECDH_anon -> AECDH
            handshakeAlgo = 'A' + handshakeAlgo.substring(0, handshakeAlgo.length() - 5);
        }

        if (export) {
            if (handshakeAlgo.isEmpty()) {
                handshakeAlgo = "EXP";
            } else {
                handshakeAlgo = "EXP-" + handshakeAlgo;
            }
        }

        return handshakeAlgo.replace('_', '-');
    }

    private static String toOpenSslBulkCipher(String bulkCipher) {
        if (bulkCipher.startsWith("AES_")) {
            Matcher m = JAVA_AES_CBC_PATTERN.matcher(bulkCipher);
            if (m.matches()) {
                // AES_128_CBC -> AES128 (CBC is implicit in OpenSSL names)
                return m.replaceFirst("$1$2");
            }

            m = JAVA_AES_PATTERN.matcher(bulkCipher);
            if (m.matches()) {
                // AES_128_GCM -> AES128-GCM
                return m.replaceFirst("$1$2-$3");
            }
        }

        if ("3DES_EDE_CBC".equals(bulkCipher)) {
            return "DES-CBC3";
        }

        if ("RC4_128".equals(bulkCipher) || "RC4_40".equals(bulkCipher)) {
            return "RC4";
        }

        if ("DES40_CBC".equals(bulkCipher) || "DES_CBC_40".equals(bulkCipher)) {
            return "DES-CBC";
        }

        if ("RC2_CBC_40".equals(bulkCipher)) {
            return "RC2-CBC";
        }

        return bulkCipher.replace('_', '-');
    }

    private static String toOpenSslHmacAlgo(String hmacAlgo) {
        // Java and OpenSSL use the same algorithm names for:
        //
        //   * SHA
        //   * SHA256
        //   * MD5
        //
        return hmacAlgo;
    }

    /**
     * Convert from OpenSSL cipher suite name convention to java cipher suite name convention.
     * @param openSslCipherSuite An OpenSSL cipher suite name.
     * @param protocol The cryptographic protocol (i.e. SSL, TLS, ...).
     * @return The translated cipher suite name according to java conventions, or {@code null}
     *         if the conversion failed (e.g. the name does not match any known pattern).
     */
    public static String toJava(String openSslCipherSuite, String protocol) {
        Map<String, String> p2j = o2j.get(openSslCipherSuite);
        if (p2j == null) {
            p2j = cacheFromOpenSsl(openSslCipherSuite);
            // This may happen if this method is queried when OpenSSL doesn't yet have a cipher setup. It will return
            // "(NONE)" in this case.
            if (p2j == null) {
                return null;
            }
        }

        String javaCipherSuite = p2j.get(protocol);
        if (javaCipherSuite == null) {
            // No exact entry for this protocol: build one from the prefix-less form.
            String cipher = p2j.get("");
            if (cipher == null) {
                return null;
            }
            javaCipherSuite = protocol + '_' + cipher;
        }

        return javaCipherSuite;
    }

    // Performs the uncached conversion and populates both the reverse (o2j) and forward (j2o) caches.
    private static Map<String, String> cacheFromOpenSsl(String openSslCipherSuite) {
        Map<String, String> converted = o2jTls13.get(openSslCipherSuite);
        if (converted != null) {
            return converted;
        }

        String javaCipherSuiteSuffix = toJavaUncached0(openSslCipherSuite, false);
        if (javaCipherSuiteSuffix == null) {
            return null;
        }

        final String javaCipherSuiteSsl = "SSL_" + javaCipherSuiteSuffix;
        final String javaCipherSuiteTls = "TLS_" + javaCipherSuiteSuffix;

        // Cache the mapping.
        final Map<String, String> p2j = new HashMap<String, String>(4);
        p2j.put("", javaCipherSuiteSuffix);
        p2j.put("SSL", javaCipherSuiteSsl);
        p2j.put("TLS", javaCipherSuiteTls);
        o2j.putIfAbsent(openSslCipherSuite, p2j);

        // Cache the reverse mapping after adding the protocol prefix (TLS_ or SSL_)
        CachedValue cachedValue = CachedValue.of(openSslCipherSuite);
        j2o.putIfAbsent(javaCipherSuiteTls, cachedValue);
        j2o.putIfAbsent(javaCipherSuiteSsl, cachedValue);

        logger.debug("Cipher suite mapping: {} => {}", javaCipherSuiteTls, openSslCipherSuite);
        logger.debug("Cipher suite mapping: {} => {}", javaCipherSuiteSsl, openSslCipherSuite);

        return p2j;
    }

    /**
     * OpenSSL-to-Java conversion without consulting or populating the caches.
     * Unlike {@link #toJavaUncached0}, this also resolves TLSv1.3 names.
     */
    static String toJavaUncached(String openSslCipherSuite) {
        return toJavaUncached0(openSslCipherSuite, true);
    }

    // Returns the Java cipher suite name WITHOUT the protocol prefix (except for TLSv1.3 names,
    // which are returned in full), or null when the OpenSSL name cannot be parsed.
    private static String toJavaUncached0(String openSslCipherSuite, boolean checkTls13) {
        if (checkTls13) {
            Map<String, String> converted = o2jTls13.get(openSslCipherSuite);
            if (converted != null) {
                return converted.get("TLS");
            }
        }

        Matcher m = OPENSSL_CIPHERSUITE_PATTERN.matcher(openSslCipherSuite);
        if (!m.matches()) {
            return null;
        }

        String handshakeAlgo = m.group(1);
        final boolean export;
        if (handshakeAlgo == null) {
            handshakeAlgo = "";
            export = false;
        } else if (handshakeAlgo.startsWith("EXP-")) {
            handshakeAlgo = handshakeAlgo.substring(4);
            export = true;
        } else if ("EXP".equals(handshakeAlgo)) {
            handshakeAlgo = "";
            export = true;
        } else {
            export = false;
        }

        handshakeAlgo = toJavaHandshakeAlgo(handshakeAlgo, export);
        String bulkCipher = toJavaBulkCipher(m.group(2), export);
        String hmacAlgo = toJavaHmacAlgo(m.group(3));

        String javaCipherSuite = handshakeAlgo + "_WITH_" + bulkCipher + '_' + hmacAlgo;
        // For historical reasons the CHACHA20 ciphers do not follow OpenSSL's custom naming convention and omits the
        // HMAC algorithm portion of the name. There is currently no way to derive this information because it is
        // omitted from the OpenSSL cipher name, but they currently all use SHA256 for HMAC [1].
        // [1] https://www.openssl.org/docs/man1.1.0/apps/ciphers.html
        return bulkCipher.contains("CHACHA20") ? javaCipherSuite + "_SHA256" : javaCipherSuite;
    }

    private static String toJavaHandshakeAlgo(String handshakeAlgo, boolean export) {
        if (handshakeAlgo.isEmpty()) {
            // No prefix on the OpenSSL side means plain RSA.
            handshakeAlgo = "RSA";
        } else if ("ADH".equals(handshakeAlgo)) {
            handshakeAlgo = "DH_anon";
        } else if ("AECDH".equals(handshakeAlgo)) {
            handshakeAlgo = "ECDH_anon";
        }

        handshakeAlgo = handshakeAlgo.replace('-', '_');
        if (export) {
            return handshakeAlgo + "_EXPORT";
        } else {
            return handshakeAlgo;
        }
    }

    private static String toJavaBulkCipher(String bulkCipher, boolean export) {
        if (bulkCipher.startsWith("AES")) {
            Matcher m = OPENSSL_AES_CBC_PATTERN.matcher(bulkCipher);
            if (m.matches()) {
                // AES128 -> AES_128_CBC
                return m.replaceFirst("$1_$2_CBC");
            }

            m = OPENSSL_AES_PATTERN.matcher(bulkCipher);
            if (m.matches()) {
                // AES128-GCM -> AES_128_GCM
                return m.replaceFirst("$1_$2_$3");
            }
        }

        if ("DES-CBC3".equals(bulkCipher)) {
            return "3DES_EDE_CBC";
        }

        if ("RC4".equals(bulkCipher)) {
            // Key size is implied by the export flag on the OpenSSL side.
            if (export) {
                return "RC4_40";
            } else {
                return "RC4_128";
            }
        }

        if ("DES-CBC".equals(bulkCipher)) {
            if (export) {
                return "DES_CBC_40";
            } else {
                return "DES_CBC";
            }
        }

        if ("RC2-CBC".equals(bulkCipher)) {
            if (export) {
                return "RC2_CBC_40";
            } else {
                return "RC2_CBC";
            }
        }

        return bulkCipher.replace('-', '_');
    }

    private static String toJavaHmacAlgo(String hmacAlgo) {
        // Java and OpenSSL use the same algorithm names for:
        //
        //   * SHA
        //   * SHA256
        //   * MD5
        //
        return hmacAlgo;
    }

    /**
     * Convert the given ciphers if needed to OpenSSL format and append them to the correct {@link StringBuilder}
     * depending on if its a TLSv1.3 cipher or not. If this methods returns without throwing an exception its
     * guaranteed that at least one of the {@link StringBuilder}s contain some ciphers that can be used to configure
     * OpenSSL.
     */
    static void convertToCipherStrings(Iterable<String> cipherSuites, StringBuilder cipherBuilder,
                                       StringBuilder cipherTLSv13Builder, boolean boringSSL) {
        for (String c: cipherSuites) {
            if (c == null) {
                break;
            }

            String converted = toOpenSsl(c, boringSSL);
            if (converted == null) {
                // Could not convert: pass the name through and let the availability check decide.
                converted = c;
            }

            if (!OpenSsl.isCipherSuiteAvailable(converted)) {
                throw new IllegalArgumentException("unsupported cipher suite: " + c + '(' + converted + ')');
            }

            if (SslUtils.isTLSv13Cipher(converted) || SslUtils.isTLSv13Cipher(c)) {
                cipherTLSv13Builder.append(converted);
                cipherTLSv13Builder.append(':');
            } else {
                cipherBuilder.append(converted);
                cipherBuilder.append(':');
            }
        }

        if (cipherBuilder.length() == 0 && cipherTLSv13Builder.length() == 0) {
            throw new IllegalArgumentException("empty cipher suites");
        }
        // Drop the trailing ':' separators.
        if (cipherBuilder.length() > 0) {
            cipherBuilder.setLength(cipherBuilder.length() - 1);
        }
        if (cipherTLSv13Builder.length() > 0) {
            cipherTLSv13Builder.setLength(cipherTLSv13Builder.length() - 1);
        }
    }

    private CipherSuiteConverter() { }
}
package io.netty.handler.ssl;

import javax.net.ssl.SSLEngine;
import java.util.List;
import java.util.Set;

/**
 * Provides a means to filter the supplied cipher suite based upon the supported and default cipher suites.
 */
public interface CipherSuiteFilter {
    /**
     * Filter the requested {@code ciphers} based upon other cipher characteristics.
     * @param ciphers The requested ciphers
     * @param defaultCiphers The default recommended ciphers for the current {@link SSLEngine} as determined by Netty
     * @param supportedCiphers The supported ciphers for the current {@link SSLEngine}
     * @return The filter list of ciphers. Must not return {@code null}.
     */
    String[] filterCipherSuites(Iterable<String> ciphers, List<String> defaultCiphers, Set<String> supportedCiphers);
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +/** + * Cipher suites + */ +public final class Ciphers { + + /** + * TLS_AES_256_GCM_SHA384 + */ + public static final String TLS_AES_256_GCM_SHA384 = "TLS_AES_256_GCM_SHA384"; + + /** + * TLS_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_CHACHA20_POLY1305_SHA256 = "TLS_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_AES_128_GCM_SHA256 + */ + public static final String TLS_AES_128_GCM_SHA256 = "TLS_AES_128_GCM_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 = "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 = "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_DHE_DSS_WITH_AES_256_GCM_SHA384 = "TLS_DHE_DSS_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 = "TLS_DHE_RSA_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256 = + "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = + "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 + */ + public 
static final String TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256 = "TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_CCM8 + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_256_CBC_CCM8 = "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_CCM8"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_CCM + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_256_CBC_CCM = "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_CCM"; + + /** + * TLS_DHE_RSA_WITH_AES_256_CBC_CCM8 + */ + public static final String TLS_DHE_RSA_WITH_AES_256_CBC_CCM8 = "TLS_DHE_RSA_WITH_AES_256_CBC_CCM8"; + + /** + * TLS_DHE_RSA_WITH_AES_256_CBC_CCM + */ + public static final String TLS_DHE_RSA_WITH_AES_256_CBC_CCM = "TLS_DHE_RSA_WITH_AES_256_CBC_CCM"; + + /** + * TLS_ECDHE_ECDSA_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_ECDHE_ECDSA_WITH_ARIA256_GCM_SHA384 = "TLS_ECDHE_ECDSA_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_RSA_WITH_ECDHE_ARIA256_GCM_SHA384 + */ + public static final String TLS_RSA_WITH_ECDHE_ARIA256_GCM_SHA384 = "TLS_RSA_WITH_ECDHE_ARIA256_GCM_SHA384"; + + /** + * TLS_DHE_DSS_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_DHE_DSS_WITH_ARIA256_GCM_SHA384 = "TLS_DHE_DSS_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_DHE_RSA_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_DHE_RSA_WITH_ARIA256_GCM_SHA384 = "TLS_DHE_RSA_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_DH_anon_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_DH_anon_WITH_AES_256_GCM_SHA384 = "TLS_DH_anon_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 = "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 + */ + public static final String 
TLS_DHE_DSS_WITH_AES_128_GCM_SHA256 = "TLS_DHE_DSS_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 = "TLS_DHE_RSA_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_128_CBC_CCM8 + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_128_CBC_CCM8 = "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_CCM8"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_128_CBC_CCM + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_128_CBC_CCM = "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_CCM"; + + /** + * TLS_DHE_RSA_WITH_AES_128_CBC_CCM8 + */ + public static final String TLS_DHE_RSA_WITH_AES_128_CBC_CCM8 = "TLS_DHE_RSA_WITH_AES_128_CBC_CCM8"; + + /** + * TLS_DHE_RSA_WITH_AES_128_CBC_CCM + */ + public static final String TLS_DHE_RSA_WITH_AES_128_CBC_CCM = "TLS_DHE_RSA_WITH_AES_128_CBC_CCM"; + + /** + * TLS_ECDHE_ECDSA_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_ECDHE_ECDSA_WITH_ARIA128_GCM_SHA256 = "TLS_ECDHE_ECDSA_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_RSA_WITH_ECDHE_ARIA128_GCM_SHA256 + */ + public static final String TLS_RSA_WITH_ECDHE_ARIA128_GCM_SHA256 = "TLS_RSA_WITH_ECDHE_ARIA128_GCM_SHA256"; + + /** + * TLS_DHE_DSS_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_DHE_DSS_WITH_ARIA128_GCM_SHA256 = "TLS_DHE_DSS_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_DHE_RSA_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_DHE_RSA_WITH_ARIA128_GCM_SHA256 = "TLS_DHE_RSA_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_DH_anon_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_DH_anon_WITH_AES_128_GCM_SHA256 = "TLS_DH_anon_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 = "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384"; + + /** + * TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 + */ + public static final String TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 = 
"TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384"; + + /** + * TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 + */ + public static final String TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 = "TLS_DHE_RSA_WITH_AES_256_CBC_SHA256"; + + /** + * TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 + */ + public static final String TLS_DHE_DSS_WITH_AES_256_CBC_SHA256 = "TLS_DHE_DSS_WITH_AES_256_CBC_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_CAMELLIA256_SHA384 + */ + public static final String TLS_ECDHE_ECDSA_WITH_CAMELLIA256_SHA384 = "TLS_ECDHE_ECDSA_WITH_CAMELLIA256_SHA384"; + + /** + * TLS_ECDHE_RSA_WITH_CAMELLIA256_SHA384 + */ + public static final String TLS_ECDHE_RSA_WITH_CAMELLIA256_SHA384 = "TLS_ECDHE_RSA_WITH_CAMELLIA256_SHA384"; + + /** + * TLS_DHE_RSA_WITH_CAMELLIA256_SHA256 + */ + public static final String TLS_DHE_RSA_WITH_CAMELLIA256_SHA256 = "TLS_DHE_RSA_WITH_CAMELLIA256_SHA256"; + + /** + * TLS_DHE_DSS_WITH_CAMELLIA256_SHA256 + */ + public static final String TLS_DHE_DSS_WITH_CAMELLIA256_SHA256 = "TLS_DHE_DSS_WITH_CAMELLIA256_SHA256"; + + /** + * TLS_DH_anon_WITH_AES_256_CBC_SHA256 + */ + public static final String TLS_DH_anon_WITH_AES_256_CBC_SHA256 = "TLS_DH_anon_WITH_AES_256_CBC_SHA256"; + + /** + * TLS_DH_anon_WITH_CAMELLIA256_SHA256 + */ + public static final String TLS_DH_anon_WITH_CAMELLIA256_SHA256 = "TLS_DH_anon_WITH_CAMELLIA256_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 = "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 = "TLS_DHE_RSA_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_DHE_DSS_WITH_AES_128_CBC_SHA256 = "TLS_DHE_DSS_WITH_AES_128_CBC_SHA256"; + + /** + * 
TLS_ECDHE_ECDSA_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_ECDHE_ECDSA_WITH_CAMELLIA128_SHA256 = "TLS_ECDHE_ECDSA_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_ECDHE_RSA_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_ECDHE_RSA_WITH_CAMELLIA128_SHA256 = "TLS_ECDHE_RSA_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_DHE_RSA_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_DHE_RSA_WITH_CAMELLIA128_SHA256 = "TLS_DHE_RSA_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_DHE_DSS_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_DHE_DSS_WITH_CAMELLIA128_SHA256 = "TLS_DHE_DSS_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_DH_anon_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_DH_anon_WITH_AES_128_CBC_SHA256 = "TLS_DH_anon_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_DH_anon_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_DH_anon_WITH_CAMELLIA128_SHA256 = "TLS_DH_anon_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA = "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA"; + + /** + * TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA + */ + public static final String TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA = "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"; + + /** + * TLS_DHE_RSA_WITH_AES_256_CBC_SHA + */ + public static final String TLS_DHE_RSA_WITH_AES_256_CBC_SHA = "TLS_DHE_RSA_WITH_AES_256_CBC_SHA"; + + /** + * TLS_DHE_DSS_WITH_AES_256_CBC_SHA + */ + public static final String TLS_DHE_DSS_WITH_AES_256_CBC_SHA = "TLS_DHE_DSS_WITH_AES_256_CBC_SHA"; + + /** + * TLS_DHE_RSA_WITH_CAMELLIA256_SHA + */ + public static final String TLS_DHE_RSA_WITH_CAMELLIA256_SHA = "TLS_DHE_RSA_WITH_CAMELLIA256_SHA"; + + /** + * TLS_DHE_DSS_WITH_CAMELLIA256_SHA + */ + public static final String TLS_DHE_DSS_WITH_CAMELLIA256_SHA = "TLS_DHE_DSS_WITH_CAMELLIA256_SHA"; + + /** + * TLS_ECDH_anon_WITH_AES_256_CBC_SHA + */ + public static final String TLS_ECDH_anon_WITH_AES_256_CBC_SHA = 
"TLS_ECDH_anon_WITH_AES_256_CBC_SHA"; + + /** + * TLS_DH_anon_WITH_AES_256_CBC_SHA + */ + public static final String TLS_DH_anon_WITH_AES_256_CBC_SHA = "TLS_DH_anon_WITH_AES_256_CBC_SHA"; + + /** + * TLS_DH_anon_WITH_CAMELLIA256_SHA + */ + public static final String TLS_DH_anon_WITH_CAMELLIA256_SHA = "TLS_DH_anon_WITH_CAMELLIA256_SHA"; + + /** + * TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA + */ + public static final String TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA = "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA"; + + /** + * TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA + */ + public static final String TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"; + + /** + * TLS_DHE_RSA_WITH_AES_128_CBC_SHA + */ + public static final String TLS_DHE_RSA_WITH_AES_128_CBC_SHA = "TLS_DHE_RSA_WITH_AES_128_CBC_SHA"; + + /** + * TLS_DHE_DSS_WITH_AES_128_CBC_SHA + */ + public static final String TLS_DHE_DSS_WITH_AES_128_CBC_SHA = "TLS_DHE_DSS_WITH_AES_128_CBC_SHA"; + + /** + * TLS_DHE_RSA_WITH_SEED_SHA + */ + public static final String TLS_DHE_RSA_WITH_SEED_SHA = "TLS_DHE_RSA_WITH_SEED_SHA"; + + /** + * TLS_DHE_DSS_WITH_SEED_SHA + */ + public static final String TLS_DHE_DSS_WITH_SEED_SHA = "TLS_DHE_DSS_WITH_SEED_SHA"; + + /** + * TLS_DHE_RSA_WITH_CAMELLIA128_SHA + */ + public static final String TLS_DHE_RSA_WITH_CAMELLIA128_SHA = "TLS_DHE_RSA_WITH_CAMELLIA128_SHA"; + + /** + * TLS_DHE_DSS_WITH_CAMELLIA128_SHA + */ + public static final String TLS_DHE_DSS_WITH_CAMELLIA128_SHA = "TLS_DHE_DSS_WITH_CAMELLIA128_SHA"; + + /** + * TLS_ECDH_anon_WITH_AES_128_CBC_SHA + */ + public static final String TLS_ECDH_anon_WITH_AES_128_CBC_SHA = "TLS_ECDH_anon_WITH_AES_128_CBC_SHA"; + + /** + * TLS_DH_anon_WITH_AES_128_CBC_SHA + */ + public static final String TLS_DH_anon_WITH_AES_128_CBC_SHA = "TLS_DH_anon_WITH_AES_128_CBC_SHA"; + + /** + * TLS_DH_anon_WITH_SEED_SHA + */ + public static final String TLS_DH_anon_WITH_SEED_SHA = "TLS_DH_anon_WITH_SEED_SHA"; + + /** + * TLS_DH_anon_WITH_CAMELLIA128_SHA + */ 
+ public static final String TLS_DH_anon_WITH_CAMELLIA128_SHA = "TLS_DH_anon_WITH_CAMELLIA128_SHA"; + + /** + * TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_RSA_PSK_WITH_AES_256_GCM_SHA384 = "TLS_RSA_PSK_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_DHE_PSK_WITH_AES_256_GCM_SHA384 = "TLS_DHE_PSK_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256 = "TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256 = "TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256 = + "TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_DHE_PSK_WITH_AES_256_CBC_CCM8 + */ + public static final String TLS_DHE_PSK_WITH_AES_256_CBC_CCM8 = "TLS_DHE_PSK_WITH_AES_256_CBC_CCM8"; + + /** + * TLS_DHE_PSK_WITH_AES_256_CBC_CCM + */ + public static final String TLS_DHE_PSK_WITH_AES_256_CBC_CCM = "TLS_DHE_PSK_WITH_AES_256_CBC_CCM"; + + /** + * TLS_RSA_PSK_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_RSA_PSK_WITH_ARIA256_GCM_SHA384 = "TLS_RSA_PSK_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_DHE_PSK_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_DHE_PSK_WITH_ARIA256_GCM_SHA384 = "TLS_DHE_PSK_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_RSA_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_RSA_WITH_AES_256_GCM_SHA384 = "TLS_RSA_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_RSA_WITH_AES_256_CBC_CCM8 + */ + public static final String TLS_RSA_WITH_AES_256_CBC_CCM8 = "TLS_RSA_WITH_AES_256_CBC_CCM8"; + + /** + * TLS_RSA_WITH_AES_256_CBC_CCM + */ + public static final String TLS_RSA_WITH_AES_256_CBC_CCM = 
"TLS_RSA_WITH_AES_256_CBC_CCM"; + + /** + * TLS_RSA_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_RSA_WITH_ARIA256_GCM_SHA384 = "TLS_RSA_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_PSK_WITH_AES_256_GCM_SHA384 + */ + public static final String TLS_PSK_WITH_AES_256_GCM_SHA384 = "TLS_PSK_WITH_AES_256_GCM_SHA384"; + + /** + * TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 + */ + public static final String TLS_PSK_WITH_CHACHA20_POLY1305_SHA256 = "TLS_PSK_WITH_CHACHA20_POLY1305_SHA256"; + + /** + * TLS_PSK_WITH_AES_256_CBC_CCM8 + */ + public static final String TLS_PSK_WITH_AES_256_CBC_CCM8 = "TLS_PSK_WITH_AES_256_CBC_CCM8"; + + /** + * TLS_PSK_WITH_AES_256_CBC_CCM + */ + public static final String TLS_PSK_WITH_AES_256_CBC_CCM = "TLS_PSK_WITH_AES_256_CBC_CCM"; + + /** + * TLS_PSK_WITH_ARIA256_GCM_SHA384 + */ + public static final String TLS_PSK_WITH_ARIA256_GCM_SHA384 = "TLS_PSK_WITH_ARIA256_GCM_SHA384"; + + /** + * TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_RSA_PSK_WITH_AES_128_GCM_SHA256 = "TLS_RSA_PSK_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_DHE_PSK_WITH_AES_128_GCM_SHA256 = "TLS_DHE_PSK_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_DHE_PSK_WITH_AES_128_CBC_CCM8 + */ + public static final String TLS_DHE_PSK_WITH_AES_128_CBC_CCM8 = "TLS_DHE_PSK_WITH_AES_128_CBC_CCM8"; + + /** + * TLS_DHE_PSK_WITH_AES_128_CBC_CCM + */ + public static final String TLS_DHE_PSK_WITH_AES_128_CBC_CCM = "TLS_DHE_PSK_WITH_AES_128_CBC_CCM"; + + /** + * TLS_RSA_PSK_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_RSA_PSK_WITH_ARIA128_GCM_SHA256 = "TLS_RSA_PSK_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_DHE_PSK_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_DHE_PSK_WITH_ARIA128_GCM_SHA256 = "TLS_DHE_PSK_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_RSA_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_RSA_WITH_AES_128_GCM_SHA256 = 
"TLS_RSA_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_RSA_WITH_AES_128_CBC_CCM8 + */ + public static final String TLS_RSA_WITH_AES_128_CBC_CCM8 = "TLS_RSA_WITH_AES_128_CBC_CCM8"; + + /** + * TLS_RSA_WITH_AES_128_CBC_CCM + */ + public static final String TLS_RSA_WITH_AES_128_CBC_CCM = "TLS_RSA_WITH_AES_128_CBC_CCM"; + + /** + * TLS_RSA_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_RSA_WITH_ARIA128_GCM_SHA256 = "TLS_RSA_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_PSK_WITH_AES_128_GCM_SHA256 + */ + public static final String TLS_PSK_WITH_AES_128_GCM_SHA256 = "TLS_PSK_WITH_AES_128_GCM_SHA256"; + + /** + * TLS_PSK_WITH_AES_128_CBC_CCM8 + */ + public static final String TLS_PSK_WITH_AES_128_CBC_CCM8 = "TLS_PSK_WITH_AES_128_CBC_CCM8"; + + /** + * TLS_PSK_WITH_AES_128_CBC_CCM + */ + public static final String TLS_PSK_WITH_AES_128_CBC_CCM = "TLS_PSK_WITH_AES_128_CBC_CCM"; + + /** + * TLS_PSK_WITH_ARIA128_GCM_SHA256 + */ + public static final String TLS_PSK_WITH_ARIA128_GCM_SHA256 = "TLS_PSK_WITH_ARIA128_GCM_SHA256"; + + /** + * TLS_RSA_WITH_AES_256_CBC_SHA256 + */ + public static final String TLS_RSA_WITH_AES_256_CBC_SHA256 = "TLS_RSA_WITH_AES_256_CBC_SHA256"; + + /** + * TLS_RSA_WITH_CAMELLIA256_SHA256 + */ + public static final String TLS_RSA_WITH_CAMELLIA256_SHA256 = "TLS_RSA_WITH_CAMELLIA256_SHA256"; + + /** + * TLS_RSA_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_RSA_WITH_AES_128_CBC_SHA256 = "TLS_RSA_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_RSA_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_RSA_WITH_CAMELLIA128_SHA256 = "TLS_RSA_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 + */ + public static final String TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384 = "TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA384"; + + /** + * TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA + */ + public static final String TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA = "TLS_ECDHE_PSK_WITH_AES_256_CBC_SHA"; + + /** + * TLS_SRP_DSS_WITH_AES_256_CBC_SHA + */ + 
public static final String TLS_SRP_DSS_WITH_AES_256_CBC_SHA = "TLS_SRP_DSS_WITH_AES_256_CBC_SHA"; + + /** + * TLS_SRP_RSA_WITH_AES_256_CBC_SHA + */ + public static final String TLS_SRP_RSA_WITH_AES_256_CBC_SHA = "TLS_SRP_RSA_WITH_AES_256_CBC_SHA"; + + /** + * TLS_SRP_WITH_AES_256_CBC_SHA + */ + public static final String TLS_SRP_WITH_AES_256_CBC_SHA = "TLS_SRP_WITH_AES_256_CBC_SHA"; + + /** + * TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 + */ + public static final String TLS_RSA_PSK_WITH_AES_256_CBC_SHA384 = "TLS_RSA_PSK_WITH_AES_256_CBC_SHA384"; + + /** + * TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 + */ + public static final String TLS_DHE_PSK_WITH_AES_256_CBC_SHA384 = "TLS_DHE_PSK_WITH_AES_256_CBC_SHA384"; + + /** + * TLS_RSA_PSK_WITH_AES_256_CBC_SHA + */ + public static final String TLS_RSA_PSK_WITH_AES_256_CBC_SHA = "TLS_RSA_PSK_WITH_AES_256_CBC_SHA"; + + /** + * TLS_DHE_PSK_WITH_AES_256_CBC_SHA + */ + public static final String TLS_DHE_PSK_WITH_AES_256_CBC_SHA = "TLS_DHE_PSK_WITH_AES_256_CBC_SHA"; + + /** + * TLS_ECDHE_PSK_WITH_CAMELLIA256_SHA384 + */ + public static final String TLS_ECDHE_PSK_WITH_CAMELLIA256_SHA384 = "TLS_ECDHE_PSK_WITH_CAMELLIA256_SHA384"; + + /** + * TLS_RSA_PSK_WITH_CAMELLIA256_SHA384 + */ + public static final String TLS_RSA_PSK_WITH_CAMELLIA256_SHA384 = "TLS_RSA_PSK_WITH_CAMELLIA256_SHA384"; + + /** + * TLS_DHE_PSK_WITH_CAMELLIA256_SHA384 + */ + public static final String TLS_DHE_PSK_WITH_CAMELLIA256_SHA384 = "TLS_DHE_PSK_WITH_CAMELLIA256_SHA384"; + + /** + * TLS_RSA_WITH_AES_256_CBC_SHA + */ + public static final String TLS_RSA_WITH_AES_256_CBC_SHA = "TLS_RSA_WITH_AES_256_CBC_SHA"; + + /** + * TLS_RSA_WITH_CAMELLIA256_SHA + */ + public static final String TLS_RSA_WITH_CAMELLIA256_SHA = "TLS_RSA_WITH_CAMELLIA256_SHA"; + + /** + * TLS_PSK_WITH_AES_256_CBC_SHA384 + */ + public static final String TLS_PSK_WITH_AES_256_CBC_SHA384 = "TLS_PSK_WITH_AES_256_CBC_SHA384"; + + /** + * TLS_PSK_WITH_AES_256_CBC_SHA + */ + public static final String 
TLS_PSK_WITH_AES_256_CBC_SHA = "TLS_PSK_WITH_AES_256_CBC_SHA"; + + /** + * TLS_PSK_WITH_CAMELLIA256_SHA384 + */ + public static final String TLS_PSK_WITH_CAMELLIA256_SHA384 = "TLS_PSK_WITH_CAMELLIA256_SHA384"; + + /** + * TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256 = "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA + */ + public static final String TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA = "TLS_ECDHE_PSK_WITH_AES_128_CBC_SHA"; + + /** + * TLS_SRP_DSS_WITH_AES_128_CBC_SHA + */ + public static final String TLS_SRP_DSS_WITH_AES_128_CBC_SHA = "TLS_SRP_DSS_WITH_AES_128_CBC_SHA"; + + /** + * TLS_SRP_RSA_WITH_AES_128_CBC_SHA + */ + public static final String TLS_SRP_RSA_WITH_AES_128_CBC_SHA = "TLS_SRP_RSA_WITH_AES_128_CBC_SHA"; + + /** + * TLS_SRP_WITH_AES_128_CBC_SHA + */ + public static final String TLS_SRP_WITH_AES_128_CBC_SHA = "TLS_SRP_WITH_AES_128_CBC_SHA"; + + /** + * TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_RSA_PSK_WITH_AES_128_CBC_SHA256 = "TLS_RSA_PSK_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_DHE_PSK_WITH_AES_128_CBC_SHA256 = "TLS_DHE_PSK_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_RSA_PSK_WITH_AES_128_CBC_SHA + */ + public static final String TLS_RSA_PSK_WITH_AES_128_CBC_SHA = "TLS_RSA_PSK_WITH_AES_128_CBC_SHA"; + + /** + * TLS_DHE_PSK_WITH_AES_128_CBC_SHA + */ + public static final String TLS_DHE_PSK_WITH_AES_128_CBC_SHA = "TLS_DHE_PSK_WITH_AES_128_CBC_SHA"; + + /** + * TLS_ECDHE_PSK_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_ECDHE_PSK_WITH_CAMELLIA128_SHA256 = "TLS_ECDHE_PSK_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_RSA_PSK_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_RSA_PSK_WITH_CAMELLIA128_SHA256 = "TLS_RSA_PSK_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_DHE_PSK_WITH_CAMELLIA128_SHA256 + */ + public static final String 
TLS_DHE_PSK_WITH_CAMELLIA128_SHA256 = "TLS_DHE_PSK_WITH_CAMELLIA128_SHA256"; + + /** + * TLS_RSA_WITH_AES_128_CBC_SHA + */ + public static final String TLS_RSA_WITH_AES_128_CBC_SHA = "TLS_RSA_WITH_AES_128_CBC_SHA"; + + /** + * TLS_RSA_WITH_SEED_SHA + */ + public static final String TLS_RSA_WITH_SEED_SHA = "TLS_RSA_WITH_SEED_SHA"; + + /** + * TLS_RSA_WITH_CAMELLIA128_SHA + */ + public static final String TLS_RSA_WITH_CAMELLIA128_SHA = "TLS_RSA_WITH_CAMELLIA128_SHA"; + + /** + * TLS_RSA_WITH_IDEA_CBC_SHA + */ + public static final String TLS_RSA_WITH_IDEA_CBC_SHA = "TLS_RSA_WITH_IDEA_CBC_SHA"; + + /** + * TLS_PSK_WITH_AES_128_CBC_SHA256 + */ + public static final String TLS_PSK_WITH_AES_128_CBC_SHA256 = "TLS_PSK_WITH_AES_128_CBC_SHA256"; + + /** + * TLS_PSK_WITH_AES_128_CBC_SHA + */ + public static final String TLS_PSK_WITH_AES_128_CBC_SHA = "TLS_PSK_WITH_AES_128_CBC_SHA"; + + /** + * TLS_PSK_WITH_CAMELLIA128_SHA256 + */ + public static final String TLS_PSK_WITH_CAMELLIA128_SHA256 = "TLS_PSK_WITH_CAMELLIA128_SHA256"; + + private Ciphers() { + // Prevent outside initialization + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ClientAuth.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ClientAuth.java new file mode 100644 index 0000000..82696d9 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ClientAuth.java @@ -0,0 +1,38 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +/** + * Indicates the state of the {@link javax.net.ssl.SSLEngine} with respect to client authentication. + * This configuration item really only applies when building the server-side {@link SslContext}. + */ +public enum ClientAuth { + /** + * Indicates that the {@link javax.net.ssl.SSLEngine} will not request client authentication. + */ + NONE, + + /** + * Indicates that the {@link javax.net.ssl.SSLEngine} will request client authentication. + */ + OPTIONAL, + + /** + * Indicates that the {@link javax.net.ssl.SSLEngine} will *require* client authentication. + */ + REQUIRE +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Conscrypt.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Conscrypt.java new file mode 100644 index 0000000..c5af3fd --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Conscrypt.java @@ -0,0 +1,75 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.PlatformDependent; + +import javax.net.ssl.SSLEngine; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; + +/** + * Contains methods that can be used to detect if conscrypt is usable. 
+ */ +final class Conscrypt { + // This class exists to avoid loading other conscrypt related classes using features only available in JDK8+, + // because we need to maintain JDK6+ runtime compatibility. + private static final Method IS_CONSCRYPT_SSLENGINE; + + static { + Method isConscryptSSLEngine = null; + + if ((PlatformDependent.javaVersion() >= 8 && + // Only works on Java14 and earlier for now + // See https://github.com/google/conscrypt/issues/838 + PlatformDependent.javaVersion() < 15) || PlatformDependent.isAndroid()) { + try { + Class providerClass = Class.forName("org.conscrypt.OpenSSLProvider", true, + PlatformDependent.getClassLoader(ConscryptAlpnSslEngine.class)); + providerClass.newInstance(); + + Class conscryptClass = Class.forName("org.conscrypt.Conscrypt", true, + PlatformDependent.getClassLoader(ConscryptAlpnSslEngine.class)); + isConscryptSSLEngine = conscryptClass.getMethod("isConscrypt", SSLEngine.class); + } catch (Throwable ignore) { + // ignore + } + } + IS_CONSCRYPT_SSLENGINE = isConscryptSSLEngine; + } + + /** + * Indicates whether or not conscrypt is available on the current system. + */ + static boolean isAvailable() { + return IS_CONSCRYPT_SSLENGINE != null; + } + + /** + * Returns {@code true} if the passed in {@link SSLEngine} is handled by Conscrypt, {@code false} otherwise. 
+ */ + static boolean isEngineSupported(SSLEngine engine) { + try { + return IS_CONSCRYPT_SSLENGINE != null && (Boolean) IS_CONSCRYPT_SSLENGINE.invoke(null, engine); + } catch (IllegalAccessException ignore) { + return false; + } catch (InvocationTargetException ex) { + throw new RuntimeException(ex); + } + } + + private Conscrypt() { } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java new file mode 100644 index 0000000..917ebae --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ConscryptAlpnSslEngine.java @@ -0,0 +1,212 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import static io.netty.handler.ssl.SslUtils.toSSLHandshakeException; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.lang.Math.min; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelectionListener; +import io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelector; +import java.nio.ByteBuffer; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.SystemPropertyUtil; +import org.conscrypt.AllocatedBuffer; +import org.conscrypt.BufferAllocator; +import org.conscrypt.Conscrypt; +import org.conscrypt.HandshakeListener; + +/** + * A {@link JdkSslEngine} that uses the Conscrypt provider or SSL with ALPN. + */ +abstract class ConscryptAlpnSslEngine extends JdkSslEngine { + private static final boolean USE_BUFFER_ALLOCATOR = SystemPropertyUtil.getBoolean( + "io.netty.handler.ssl.conscrypt.useBufferAllocator", true); + + static ConscryptAlpnSslEngine newClientEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator) { + return new ClientEngine(engine, alloc, applicationNegotiator); + } + + static ConscryptAlpnSslEngine newServerEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator) { + return new ServerEngine(engine, alloc, applicationNegotiator); + } + + private ConscryptAlpnSslEngine(SSLEngine engine, ByteBufAllocator alloc, List protocols) { + super(engine); + + // Configure the Conscrypt engine to use Netty's buffer allocator. This is a trade-off of memory vs + // performance. 
+ // + // If no allocator is provided, the engine will internally allocate a direct buffer of max packet size in + // order to optimize JNI calls (this happens the first time it is provided a non-direct buffer from the + // application). + // + // Alternatively, if an allocator is provided, no internal buffer will be created and direct buffers will be + // retrieved from the allocator on-demand. + if (USE_BUFFER_ALLOCATOR) { + Conscrypt.setBufferAllocator(engine, new BufferAllocatorAdapter(alloc)); + } + + // Set the list of supported ALPN protocols on the engine. + Conscrypt.setApplicationProtocols(engine, protocols.toArray(EmptyArrays.EMPTY_STRINGS)); + } + + /** + * Calculates the maximum size of the encrypted output buffer required to wrap the given plaintext bytes. Assumes + * as a worst case that there is one TLS record per buffer. + * + * @param plaintextBytes the number of plaintext bytes to be wrapped. + * @param numBuffers the number of buffers that the plaintext bytes are spread across. + * @return the maximum size of the encrypted output buffer required for the wrap operation. + */ + final int calculateOutNetBufSize(int plaintextBytes, int numBuffers) { + // Assuming a max of one frame per component in a composite buffer. + return calculateSpace(plaintextBytes, numBuffers, Integer.MAX_VALUE); + } + + /** + * Calculate the space necessary in an out buffer to hold the max size that the given + * plaintextBytes and numBuffers can produce when encrypted. Assumes as a worst case + * that there is one TLS record per buffer. + * @param plaintextBytes the number of plaintext bytes to be wrapped. + * @param numBuffers the number of buffers that the plaintext bytes are spread across. + * @return the maximum size of the encrypted output buffer required for the wrap operation. 
+ */ + final int calculateRequiredOutBufSpace(int plaintextBytes, int numBuffers) { + return calculateSpace(plaintextBytes, numBuffers, Conscrypt.maxEncryptedPacketLength()); + } + + private int calculateSpace(int plaintextBytes, int numBuffers, long maxPacketLength) { + long maxOverhead = (long) Conscrypt.maxSealOverhead(getWrappedEngine()) * numBuffers; + return (int) min(maxPacketLength, plaintextBytes + maxOverhead); + } + + final SSLEngineResult unwrap(ByteBuffer[] srcs, ByteBuffer[] dests) throws SSLException { + return Conscrypt.unwrap(getWrappedEngine(), srcs, dests); + } + + private static final class ClientEngine extends ConscryptAlpnSslEngine { + private final ProtocolSelectionListener protocolListener; + + ClientEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator) { + super(engine, alloc, applicationNegotiator.protocols()); + // Register for completion of the handshake. + Conscrypt.setHandshakeListener(engine, new HandshakeListener() { + @Override + public void onHandshakeFinished() throws SSLException { + selectProtocol(); + } + }); + + protocolListener = checkNotNull(applicationNegotiator + .protocolListenerFactory().newListener(this, applicationNegotiator.protocols()), + "protocolListener"); + } + + private void selectProtocol() throws SSLException { + String protocol = Conscrypt.getApplicationProtocol(getWrappedEngine()); + try { + protocolListener.selected(protocol); + } catch (Throwable e) { + throw toSSLHandshakeException(e); + } + } + } + + private static final class ServerEngine extends ConscryptAlpnSslEngine { + private final ProtocolSelector protocolSelector; + + ServerEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator) { + super(engine, alloc, applicationNegotiator.protocols()); + + // Register for completion of the handshake. 
+ Conscrypt.setHandshakeListener(engine, new HandshakeListener() { + @Override + public void onHandshakeFinished() throws SSLException { + selectProtocol(); + } + }); + + protocolSelector = checkNotNull(applicationNegotiator.protocolSelectorFactory() + .newSelector(this, + new LinkedHashSet(applicationNegotiator.protocols())), + "protocolSelector"); + } + + private void selectProtocol() throws SSLException { + try { + String protocol = Conscrypt.getApplicationProtocol(getWrappedEngine()); + protocolSelector.select(protocol != null ? Collections.singletonList(protocol) + : Collections.emptyList()); + } catch (Throwable e) { + throw toSSLHandshakeException(e); + } + } + } + + private static final class BufferAllocatorAdapter extends BufferAllocator { + private final ByteBufAllocator alloc; + + BufferAllocatorAdapter(ByteBufAllocator alloc) { + this.alloc = alloc; + } + + @Override + public AllocatedBuffer allocateDirectBuffer(int capacity) { + return new BufferAdapter(alloc.directBuffer(capacity)); + } + } + + private static final class BufferAdapter extends AllocatedBuffer { + private final ByteBuf nettyBuffer; + private final ByteBuffer buffer; + + BufferAdapter(ByteBuf nettyBuffer) { + this.nettyBuffer = nettyBuffer; + buffer = nettyBuffer.nioBuffer(0, nettyBuffer.capacity()); + } + + @Override + public ByteBuffer nioBuffer() { + return buffer; + } + + @Override + public AllocatedBuffer retain() { + nettyBuffer.retain(); + return this; + } + + @Override + public AllocatedBuffer release() { + nettyBuffer.release(); + return this; + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java new file mode 100644 index 0000000..6673044 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/DefaultOpenSslKeyMaterial.java @@ -0,0 +1,126 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this 
file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSL; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakDetectorFactory; +import io.netty.util.ResourceLeakTracker; + +import java.security.cert.X509Certificate; + +final class DefaultOpenSslKeyMaterial extends AbstractReferenceCounted implements OpenSslKeyMaterial { + + private static final ResourceLeakDetector leakDetector = + ResourceLeakDetectorFactory.instance().newResourceLeakDetector(DefaultOpenSslKeyMaterial.class); + private final ResourceLeakTracker leak; + private final X509Certificate[] x509CertificateChain; + private long chain; + private long privateKey; + + DefaultOpenSslKeyMaterial(long chain, long privateKey, X509Certificate[] x509CertificateChain) { + this.chain = chain; + this.privateKey = privateKey; + this.x509CertificateChain = x509CertificateChain; + leak = leakDetector.track(this); + } + + @Override + public X509Certificate[] certificateChain() { + return x509CertificateChain.clone(); + } + + @Override + public long certificateChainAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } + return chain; + } + + @Override + public long privateKeyAddress() { + if (refCnt() <= 0) { + throw new IllegalReferenceCountException(); + } + return privateKey; + } + + 
@Override + protected void deallocate() { + SSL.freeX509Chain(chain); + chain = 0; + SSL.freePrivateKey(privateKey); + privateKey = 0; + if (leak != null) { + boolean closed = leak.close(this); + assert closed; + } + } + + @Override + public DefaultOpenSslKeyMaterial retain() { + if (leak != null) { + leak.record(); + } + super.retain(); + return this; + } + + @Override + public DefaultOpenSslKeyMaterial retain(int increment) { + if (leak != null) { + leak.record(); + } + super.retain(increment); + return this; + } + + @Override + public DefaultOpenSslKeyMaterial touch() { + if (leak != null) { + leak.record(); + } + super.touch(); + return this; + } + + @Override + public DefaultOpenSslKeyMaterial touch(Object hint) { + if (leak != null) { + leak.record(hint); + } + return this; + } + + @Override + public boolean release() { + if (leak != null) { + leak.record(); + } + return super.release(); + } + + @Override + public boolean release(int decrement) { + if (leak != null) { + leak.record(); + } + return super.release(decrement); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java new file mode 100644 index 0000000..aeaf8c6 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/DelegatingSslContext.java @@ -0,0 +1,122 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.internal.ObjectUtil; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSessionContext; +import java.util.List; +import java.util.concurrent.Executor; + +/** + * Adapter class which allows to wrap another {@link SslContext} and init {@link SSLEngine} instances. + */ +public abstract class DelegatingSslContext extends SslContext { + + private final SslContext ctx; + + protected DelegatingSslContext(SslContext ctx) { + this.ctx = ObjectUtil.checkNotNull(ctx, "ctx"); + } + + @Override + public final boolean isClient() { + return ctx.isClient(); + } + + @Override + public final List cipherSuites() { + return ctx.cipherSuites(); + } + + @Override + public final long sessionCacheSize() { + return ctx.sessionCacheSize(); + } + + @Override + public final long sessionTimeout() { + return ctx.sessionTimeout(); + } + + @Override + public final ApplicationProtocolNegotiator applicationProtocolNegotiator() { + return ctx.applicationProtocolNegotiator(); + } + + @Override + public final SSLEngine newEngine(ByteBufAllocator alloc) { + SSLEngine engine = ctx.newEngine(alloc); + initEngine(engine); + return engine; + } + + @Override + public final SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) { + SSLEngine engine = ctx.newEngine(alloc, peerHost, peerPort); + initEngine(engine); + return engine; + } + + @Override + protected final SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { + SslHandler handler = ctx.newHandler(alloc, startTls); + initHandler(handler); + return handler; + } + + @Override + protected final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls) { + SslHandler handler = ctx.newHandler(alloc, peerHost, peerPort, startTls); + initHandler(handler); + return handler; + 
} + + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls, Executor executor) { + SslHandler handler = ctx.newHandler(alloc, startTls, executor); + initHandler(handler); + return handler; + } + + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, + boolean startTls, Executor executor) { + SslHandler handler = ctx.newHandler(alloc, peerHost, peerPort, startTls, executor); + initHandler(handler); + return handler; + } + + @Override + public final SSLSessionContext sessionContext() { + return ctx.sessionContext(); + } + + /** + * Init the {@link SSLEngine}. + */ + protected abstract void initEngine(SSLEngine engine); + + /** + * Init the {@link SslHandler}. This will by default call {@link #initEngine(SSLEngine)}, sub-classes may override + * this. + */ + protected void initHandler(SslHandler handler) { + initEngine(handler.engine()); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java new file mode 100644 index 0000000..c2c3e90 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/EnhancingX509ExtendedTrustManager.java @@ -0,0 +1,124 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl; + +import io.netty.util.internal.SuppressJava6Requirement; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import java.net.Socket; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.List; + + +/** + * Wraps an existing {@link X509ExtendedTrustManager} and enhances the {@link CertificateException} that is thrown + * because of hostname validation. + */ +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class EnhancingX509ExtendedTrustManager extends X509ExtendedTrustManager { + private final X509ExtendedTrustManager wrapped; + + EnhancingX509ExtendedTrustManager(X509TrustManager wrapped) { + this.wrapped = (X509ExtendedTrustManager) wrapped; + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + wrapped.checkClientTrusted(chain, authType, socket); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + try { + wrapped.checkServerTrusted(chain, authType, socket); + } catch (CertificateException e) { + throwEnhancedCertificateException(chain, e); + } + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + wrapped.checkClientTrusted(chain, authType, engine); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + try { + wrapped.checkServerTrusted(chain, authType, engine); + } catch (CertificateException e) { + throwEnhancedCertificateException(chain, e); + } + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + 
wrapped.checkClientTrusted(chain, authType); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + try { + wrapped.checkServerTrusted(chain, authType); + } catch (CertificateException e) { + throwEnhancedCertificateException(chain, e); + } + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return wrapped.getAcceptedIssuers(); + } + + private static void throwEnhancedCertificateException(X509Certificate[] chain, CertificateException e) + throws CertificateException { + // Matching the message is the best we can do sadly. + String message = e.getMessage(); + if (message != null && e.getMessage().startsWith("No subject alternative DNS name matching")) { + StringBuilder names = new StringBuilder(64); + for (int i = 0; i < chain.length; i++) { + X509Certificate cert = chain[i]; + Collection> collection = cert.getSubjectAlternativeNames(); + if (collection != null) { + for (List altNames : collection) { + // 2 is dNSName. See X509Certificate javadocs. 
+ if (altNames.size() >= 2 && ((Integer) altNames.get(0)).intValue() == 2) { + names.append((String) altNames.get(1)).append(","); + } + } + } + } + if (names.length() != 0) { + // Strip of , + names.setLength(names.length() - 1); + throw new CertificateException(message + + " Subject alternative DNS names in the certificate chain of " + chain.length + + " certificate(s): " + names, e); + } + } + throw e; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java new file mode 100644 index 0000000..5924325 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ExtendedOpenSslSession.java @@ -0,0 +1,241 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */
+package io.netty.handler.ssl;
+
+import io.netty.util.internal.EmptyArrays;
+import io.netty.util.internal.SuppressJava6Requirement;
+
+import javax.net.ssl.ExtendedSSLSession;
+import javax.net.ssl.SSLException;
+import javax.net.ssl.SSLPeerUnverifiedException;
+import javax.net.ssl.SSLSessionBindingEvent;
+import javax.net.ssl.SSLSessionBindingListener;
+import javax.security.cert.X509Certificate;
+import java.security.Principal;
+import java.security.cert.Certificate;
+import java.util.Collections;
+import java.util.List;
+
+/**
+ * Delegates all operations to a wrapped {@link OpenSslSession} except the methods defined by {@link ExtendedSSLSession}
+ * itself.
+ */
+@SuppressJava6Requirement(reason = "Usage guarded by java version check")
+abstract class ExtendedOpenSslSession extends ExtendedSSLSession implements OpenSslSession {
+
+    // TODO: use OpenSSL API to actually fetch the real data but for now just do what Conscrypt does:
+    // https://github.com/google/conscrypt/blob/1.2.0/common/
+    // src/main/java/org/conscrypt/Java7ExtendedSSLSession.java#L32
+    private static final String[] LOCAL_SUPPORTED_SIGNATURE_ALGORITHMS = {
+            "SHA512withRSA", "SHA512withECDSA", "SHA384withRSA", "SHA384withECDSA", "SHA256withRSA",
+            "SHA256withECDSA", "SHA224withRSA", "SHA224withECDSA", "SHA1withRSA", "SHA1withECDSA",
+            "RSASSA-PSS",
+    };
+
+    // Session that receives every delegated call.
+    private final OpenSslSession wrapped;
+
+    ExtendedOpenSslSession(OpenSslSession wrapped) {
+        this.wrapped = wrapped;
+    }
+
+    // Use rawtypes an unchecked override to be able to also work on java7.
+    @Override
+    @SuppressWarnings({ "unchecked", "rawtypes" })
+    public abstract List getRequestedServerNames();
+
+    // Do not mark as override so we can compile on java8.
+    public List getStatusResponses() {
+        // Just return an empty list for now until we support it as otherwise we will fail in java9
+        // because of their sun.security.ssl.X509TrustManagerImpl class.
+        return Collections.emptyList();
+    }
+
+    @Override
+    public OpenSslSessionId sessionId() {
+        return wrapped.sessionId();
+    }
+
+    @Override
+    public void setSessionId(OpenSslSessionId id) {
+        wrapped.setSessionId(id);
+    }
+
+    @Override
+    public final void setLocalCertificate(Certificate[] localCertificate) {
+        wrapped.setLocalCertificate(localCertificate);
+    }
+
+    @Override
+    public String[] getPeerSupportedSignatureAlgorithms() {
+        // Not exposed by the native layer; report "none" rather than guessing.
+        return EmptyArrays.EMPTY_STRINGS;
+    }
+
+    @Override
+    public final void tryExpandApplicationBufferSize(int packetLengthDataOnly) {
+        wrapped.tryExpandApplicationBufferSize(packetLengthDataOnly);
+    }
+
+    @Override
+    public final String[] getLocalSupportedSignatureAlgorithms() {
+        // Defensive copy so callers cannot mutate the shared constant.
+        return LOCAL_SUPPORTED_SIGNATURE_ALGORITHMS.clone();
+    }
+
+    @Override
+    public final byte[] getId() {
+        return wrapped.getId();
+    }
+
+    @Override
+    public final OpenSslSessionContext getSessionContext() {
+        return wrapped.getSessionContext();
+    }
+
+    @Override
+    public final long getCreationTime() {
+        return wrapped.getCreationTime();
+    }
+
+    @Override
+    public final long getLastAccessedTime() {
+        return wrapped.getLastAccessedTime();
+    }
+
+    @Override
+    public final void invalidate() {
+        wrapped.invalidate();
+    }
+
+    @Override
+    public final boolean isValid() {
+        return wrapped.isValid();
+    }
+
+    @Override
+    public final void putValue(String name, Object value) {
+        if (value instanceof SSLSessionBindingListener) {
+            // Decorate the value if needed so we submit the correct SSLSession instance
+            value = new SSLSessionBindingListenerDecorator((SSLSessionBindingListener) value);
+        }
+        wrapped.putValue(name, value);
+    }
+
+    @Override
+    public final Object getValue(String s) {
+        Object value = wrapped.getValue(s);
+        if (value instanceof SSLSessionBindingListenerDecorator) {
+            // Unwrap as needed so we return the original value
+            return ((SSLSessionBindingListenerDecorator) value).delegate;
+        }
+        return value;
+    }
+
+    @Override
+    public final void removeValue(String s) {
+        wrapped.removeValue(s);
+    }
+
+    @Override
+    public final String[] getValueNames() {
+        return wrapped.getValueNames();
+    }
+
+    @Override
+    public final Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException {
+        return wrapped.getPeerCertificates();
+    }
+
+    @Override
+    public final Certificate[] getLocalCertificates() {
+        return wrapped.getLocalCertificates();
+    }
+
+    @Override
+    public final X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException {
+        return wrapped.getPeerCertificateChain();
+    }
+
+    @Override
+    public final Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
+        return wrapped.getPeerPrincipal();
+    }
+
+    @Override
+    public final Principal getLocalPrincipal() {
+        return wrapped.getLocalPrincipal();
+    }
+
+    @Override
+    public final String getCipherSuite() {
+        return wrapped.getCipherSuite();
+    }
+
+    @Override
+    public String getProtocol() {
+        return wrapped.getProtocol();
+    }
+
+    @Override
+    public final String getPeerHost() {
+        return wrapped.getPeerHost();
+    }
+
+    @Override
+    public final int getPeerPort() {
+        return wrapped.getPeerPort();
+    }
+
+    @Override
+    public final int getPacketBufferSize() {
+        return wrapped.getPacketBufferSize();
+    }
+
+    @Override
+    public final int getApplicationBufferSize() {
+        return wrapped.getApplicationBufferSize();
+    }
+
+    // Rewrites binding events so listeners observe this (outer) session instance
+    // instead of the wrapped one.
+    private final class SSLSessionBindingListenerDecorator implements SSLSessionBindingListener {
+
+        final SSLSessionBindingListener delegate;
+
+        SSLSessionBindingListenerDecorator(SSLSessionBindingListener delegate) {
+            this.delegate = delegate;
+        }
+
+        @Override
+        public void valueBound(SSLSessionBindingEvent event) {
+            delegate.valueBound(new SSLSessionBindingEvent(ExtendedOpenSslSession.this, event.getName()));
+        }
+
+        @Override
+        public void valueUnbound(SSLSessionBindingEvent event) {
+            delegate.valueUnbound(new SSLSessionBindingEvent(ExtendedOpenSslSession.this, event.getName()));
+        }
+    }
+
+    @Override
+    public void handshakeFinished(byte[] id, String cipher, String protocol, byte[] peerCertificate,
+                                  byte[][] peerCertificateChain, long creationTime, long timeout) throws SSLException {
+        wrapped.handshakeFinished(id, cipher, protocol, peerCertificate, peerCertificateChain, creationTime, timeout);
+    }
+
+    @Override
+    public String toString() {
+        return "ExtendedOpenSslSession{" +
+                "wrapped=" + wrapped +
+                '}';
+    }
+}
diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/GroupsConverter.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/GroupsConverter.java
new file mode 100644
index 0000000..28b3329
--- /dev/null
+++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/GroupsConverter.java
@@ -0,0 +1,50 @@
+/*
+ * Copyright 2021 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.handler.ssl;
+
+import java.util.Collections;
+import java.util.HashMap;
+import java.util.Map;
+
+/**
+ * Convert java naming to OpenSSL naming if possible and if not return the original name.
+ */ +final class GroupsConverter { + + private static final Map mappings; + + static { + // See https://tools.ietf.org/search/rfc4492#appendix-A and https://www.java.com/en/configure_crypto.html + Map map = new HashMap(); + map.put("secp224r1", "P-224"); + map.put("prime256v1", "P-256"); + map.put("secp256r1", "P-256"); + map.put("secp384r1", "P-384"); + map.put("secp521r1", "P-521"); + map.put("x25519", "X25519"); + mappings = Collections.unmodifiableMap(map); + } + + static String toOpenSsl(String key) { + String mapping = mappings.get(key); + if (mapping == null) { + return key; + } + return mapping; + } + + private GroupsConverter() { } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java new file mode 100644 index 0000000..0ceba28 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/IdentityCipherSuiteFilter.java @@ -0,0 +1,64 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.EmptyArrays; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +/** + * This class will not do any filtering of ciphers suites. 
+ */ +public final class IdentityCipherSuiteFilter implements CipherSuiteFilter { + + /** + * Defaults to default ciphers when provided ciphers are null + */ + public static final IdentityCipherSuiteFilter INSTANCE = new IdentityCipherSuiteFilter(true); + + /** + * Defaults to supported ciphers when provided ciphers are null + */ + public static final IdentityCipherSuiteFilter INSTANCE_DEFAULTING_TO_SUPPORTED_CIPHERS = + new IdentityCipherSuiteFilter(false); + + private final boolean defaultToDefaultCiphers; + + private IdentityCipherSuiteFilter(boolean defaultToDefaultCiphers) { + this.defaultToDefaultCiphers = defaultToDefaultCiphers; + } + + @Override + public String[] filterCipherSuites(Iterable ciphers, List defaultCiphers, + Set supportedCiphers) { + if (ciphers == null) { + return defaultToDefaultCiphers ? + defaultCiphers.toArray(EmptyArrays.EMPTY_STRINGS) : + supportedCiphers.toArray(EmptyArrays.EMPTY_STRINGS); + } else { + List newCiphers = new ArrayList(supportedCiphers.size()); + for (String c : ciphers) { + if (c == null) { + break; + } + newCiphers.add(c); + } + return newCiphers.toArray(EmptyArrays.EMPTY_STRINGS); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java7SslParametersUtils.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java7SslParametersUtils.java new file mode 100644 index 0000000..5a18dc5 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java7SslParametersUtils.java @@ -0,0 +1,38 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.handler.ssl;
+
+import io.netty.util.internal.SuppressJava6Requirement;
+
+import javax.net.ssl.SSLParameters;
+import java.security.AlgorithmConstraints;
+
+final class Java7SslParametersUtils {
+
+    private Java7SslParametersUtils() {
+        // Utility
+    }
+
+    /**
+     * Utility method that is used by {@link OpenSslEngine} and so allows us to not have any reference to
+     * {@link AlgorithmConstraints} in the code. This helps us to not get into trouble when using it in java
+     * version < 7 and especially when using on android.
+     */
+    @SuppressJava6Requirement(reason = "Usage guarded by java version check")
+    static void setAlgorithmConstraints(SSLParameters sslParameters, Object algorithmConstraints) {
+        sslParameters.setAlgorithmConstraints((AlgorithmConstraints) algorithmConstraints);
+    }
+}
diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java8SslUtils.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java8SslUtils.java
new file mode 100644
index 0000000..4396ef4
--- /dev/null
+++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/Java8SslUtils.java
@@ -0,0 +1,114 @@
+/*
+ * Copyright 2016 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.CharsetUtil; + +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLParameters; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class Java8SslUtils { + + private Java8SslUtils() { } + + static List getSniHostNames(SSLParameters sslParameters) { + List names = sslParameters.getServerNames(); + if (names == null || names.isEmpty()) { + return Collections.emptyList(); + } + List strings = new ArrayList(names.size()); + + for (SNIServerName serverName : names) { + if (serverName instanceof SNIHostName) { + strings.add(((SNIHostName) serverName).getAsciiName()); + } else { + throw new IllegalArgumentException("Only " + SNIHostName.class.getName() + + " instances are supported, but found: " + serverName); + } + } + return strings; + } + + static void setSniHostNames(SSLParameters sslParameters, List names) { + sslParameters.setServerNames(getSniHostNames(names)); + } + + static boolean isValidHostNameForSNI(String hostname) { + try { + new SNIHostName(hostname); + return true; + } catch (IllegalArgumentException illegal) { + return false; + } + } + + static List getSniHostNames(List names) { + if (names == null || names.isEmpty()) { + return Collections.emptyList(); + } + List 
sniServerNames = new ArrayList(names.size()); + for (String name: names) { + sniServerNames.add(new SNIHostName(name.getBytes(CharsetUtil.UTF_8))); + } + return sniServerNames; + } + + static List getSniHostName(byte[] hostname) { + if (hostname == null || hostname.length == 0) { + return Collections.emptyList(); + } + return Collections.singletonList(new SNIHostName(hostname)); + } + + static boolean getUseCipherSuitesOrder(SSLParameters sslParameters) { + return sslParameters.getUseCipherSuitesOrder(); + } + + static void setUseCipherSuitesOrder(SSLParameters sslParameters, boolean useOrder) { + sslParameters.setUseCipherSuitesOrder(useOrder); + } + + @SuppressWarnings("unchecked") + static void setSNIMatchers(SSLParameters sslParameters, Collection matchers) { + sslParameters.setSNIMatchers((Collection) matchers); + } + + @SuppressWarnings("unchecked") + static boolean checkSniHostnameMatch(Collection matchers, byte[] hostname) { + if (matchers != null && !matchers.isEmpty()) { + SNIHostName name = new SNIHostName(hostname); + Iterator matcherIt = (Iterator) matchers.iterator(); + while (matcherIt.hasNext()) { + SNIMatcher matcher = matcherIt.next(); + // type 0 is for hostname + if (matcher.getType() == 0 && matcher.matches(name)) { + return true; + } + } + return false; + } + return true; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnApplicationProtocolNegotiator.java new file mode 100644 index 0000000..d50d336 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnApplicationProtocolNegotiator.java @@ -0,0 +1,154 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; + +import javax.net.ssl.SSLEngine; + +/** + * The {@link JdkApplicationProtocolNegotiator} to use if you need ALPN and are using {@link SslProvider#JDK}. + * + * @deprecated use {@link ApplicationProtocolConfig}. + */ +@Deprecated +public final class JdkAlpnApplicationProtocolNegotiator extends JdkBaseApplicationProtocolNegotiator { + private static final boolean AVAILABLE = Conscrypt.isAvailable() || + JdkAlpnSslUtils.supportsAlpn() || + BouncyCastle.isAvailable(); + + private static final SslEngineWrapperFactory ALPN_WRAPPER = AVAILABLE ? new AlpnWrapper() : new FailureWrapper(); + + /** + * Create a new instance. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(Iterable protocols) { + this(false, protocols); + } + + /** + * Create a new instance. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(String... protocols) { + this(false, protocols); + } + + /** + * Create a new instance. + * @param failIfNoCommonProtocols Fail with a fatal alert if not common protocols are detected. + * @param protocols The order of iteration determines the preference of support for protocols. 
+ */ + public JdkAlpnApplicationProtocolNegotiator(boolean failIfNoCommonProtocols, Iterable protocols) { + this(failIfNoCommonProtocols, failIfNoCommonProtocols, protocols); + } + + /** + * Create a new instance. + * @param failIfNoCommonProtocols Fail with a fatal alert if not common protocols are detected. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(boolean failIfNoCommonProtocols, String... protocols) { + this(failIfNoCommonProtocols, failIfNoCommonProtocols, protocols); + } + + /** + * Create a new instance. + * @param clientFailIfNoCommonProtocols Client side fail with a fatal alert if not common protocols are detected. + * @param serverFailIfNoCommonProtocols Server side fail with a fatal alert if not common protocols are detected. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(boolean clientFailIfNoCommonProtocols, + boolean serverFailIfNoCommonProtocols, Iterable protocols) { + this(serverFailIfNoCommonProtocols ? FAIL_SELECTOR_FACTORY : NO_FAIL_SELECTOR_FACTORY, + clientFailIfNoCommonProtocols ? FAIL_SELECTION_LISTENER_FACTORY : NO_FAIL_SELECTION_LISTENER_FACTORY, + protocols); + } + + /** + * Create a new instance. + * @param clientFailIfNoCommonProtocols Client side fail with a fatal alert if not common protocols are detected. + * @param serverFailIfNoCommonProtocols Server side fail with a fatal alert if not common protocols are detected. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(boolean clientFailIfNoCommonProtocols, + boolean serverFailIfNoCommonProtocols, String... protocols) { + this(serverFailIfNoCommonProtocols ? FAIL_SELECTOR_FACTORY : NO_FAIL_SELECTOR_FACTORY, + clientFailIfNoCommonProtocols ? 
FAIL_SELECTION_LISTENER_FACTORY : NO_FAIL_SELECTION_LISTENER_FACTORY, + protocols); + } + + /** + * Create a new instance. + * @param selectorFactory The factory which provides classes responsible for selecting the protocol. + * @param listenerFactory The factory which provides to be notified of which protocol was selected. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(ProtocolSelectorFactory selectorFactory, + ProtocolSelectionListenerFactory listenerFactory, Iterable protocols) { + super(ALPN_WRAPPER, selectorFactory, listenerFactory, protocols); + } + + /** + * Create a new instance. + * @param selectorFactory The factory which provides classes responsible for selecting the protocol. + * @param listenerFactory The factory which provides to be notified of which protocol was selected. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + public JdkAlpnApplicationProtocolNegotiator(ProtocolSelectorFactory selectorFactory, + ProtocolSelectionListenerFactory listenerFactory, String... protocols) { + super(ALPN_WRAPPER, selectorFactory, listenerFactory, protocols); + } + + private static final class FailureWrapper extends AllocatorAwareSslEngineWrapperFactory { + @Override + public SSLEngine wrapSslEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator, boolean isServer) { + throw new RuntimeException("ALPN unsupported. Is your classpath configured correctly?" + + " For Conscrypt, add the appropriate Conscrypt JAR to classpath and set the security provider." 
+ + " For Jetty-ALPN, see " + + "https://www.eclipse.org/jetty/documentation/current/alpn-chapter.html#alpn-starting"); + } + } + + private static final class AlpnWrapper extends AllocatorAwareSslEngineWrapperFactory { + @Override + public SSLEngine wrapSslEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator, boolean isServer) { + if (Conscrypt.isEngineSupported(engine)) { + return isServer ? ConscryptAlpnSslEngine.newServerEngine(engine, alloc, applicationNegotiator) + : ConscryptAlpnSslEngine.newClientEngine(engine, alloc, applicationNegotiator); + } + if (BouncyCastle.isInUse(engine)) { + return new BouncyCastleAlpnSslEngine(engine, applicationNegotiator, isServer); + } + // ALPN support was recently backported to Java8 as + // https://bugs.java.com/bugdatabase/view_bug.do?bug_id=8230977. + // Because of this lets not do a Java version runtime check but just depend on if the required methods are + // present + if (JdkAlpnSslUtils.supportsAlpn()) { + return new JdkAlpnSslEngine(engine, applicationNegotiator, isServer); + } + throw new UnsupportedOperationException("ALPN not supported. Unable to wrap SSLEngine of type '" + + engine.getClass().getName() + "')"); + } + } + + static boolean isAlpnSupported() { + return AVAILABLE; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslEngine.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslEngine.java new file mode 100644 index 0000000..08fdc36 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslEngine.java @@ -0,0 +1,207 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SuppressJava6Requirement; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; + +import java.nio.ByteBuffer; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.function.BiConsumer; +import java.util.function.BiFunction; + +import static io.netty.handler.ssl.SslUtils.toSSLHandshakeException; +import static io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelectionListener; +import static io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelector; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +class JdkAlpnSslEngine extends JdkSslEngine { + private final ProtocolSelectionListener selectionListener; + private final AlpnSelector alpnSelector; + + final class AlpnSelector implements BiFunction, String> { + private final ProtocolSelector selector; + private boolean called; + + AlpnSelector(ProtocolSelector selector) { + this.selector = selector; + } + + @Override + public String apply(SSLEngine sslEngine, List strings) { + assert !called; + called = true; + + try { + String selected = selector.select(strings); + return selected == null ? StringUtil.EMPTY_STRING : selected; + } catch (Exception cause) { + // Returning null means we want to fail the handshake. 
+ // + // See https://download.java.net/java/jdk9/docs/api/javax/net/ssl/ + // SSLEngine.html#setHandshakeApplicationProtocolSelector-java.util.function.BiFunction- + return null; + } + } + + void checkUnsupported() { + if (called) { + // ALPN message was received by peer and so apply(...) was called. + // See: + // https://hg.openjdk.java.net/jdk9/dev/jdk/file/65464a307408/src/ + // java.base/share/classes/sun/security/ssl/ServerHandshaker.java#l933 + return; + } + String protocol = getApplicationProtocol(); + assert protocol != null; + + if (protocol.isEmpty()) { + // ALPN is not supported + selector.unsupported(); + } + } + } + + JdkAlpnSslEngine(SSLEngine engine, + @SuppressWarnings("deprecation") JdkApplicationProtocolNegotiator applicationNegotiator, + boolean isServer, BiConsumer setHandshakeApplicationProtocolSelector, + BiConsumer> setApplicationProtocols) { + super(engine); + if (isServer) { + selectionListener = null; + alpnSelector = new AlpnSelector(applicationNegotiator.protocolSelectorFactory(). 
+ newSelector(this, new LinkedHashSet(applicationNegotiator.protocols()))); + setHandshakeApplicationProtocolSelector.accept(engine, alpnSelector); + } else { + selectionListener = applicationNegotiator.protocolListenerFactory() + .newListener(this, applicationNegotiator.protocols()); + alpnSelector = null; + setApplicationProtocols.accept(engine, applicationNegotiator.protocols()); + } + } + + JdkAlpnSslEngine(SSLEngine engine, + @SuppressWarnings("deprecation") JdkApplicationProtocolNegotiator applicationNegotiator, + boolean isServer) { + this(engine, applicationNegotiator, isServer, + new BiConsumer() { + @Override + public void accept(SSLEngine e, AlpnSelector s) { + JdkAlpnSslUtils.setHandshakeApplicationProtocolSelector(e, s); + } + }, + new BiConsumer>() { + @Override + public void accept(SSLEngine e, List p) { + JdkAlpnSslUtils.setApplicationProtocols(e, p); + } + }); + } + + private SSLEngineResult verifyProtocolSelection(SSLEngineResult result) throws SSLException { + if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED) { + if (alpnSelector == null) { + // This means we are using client-side and + try { + String protocol = getApplicationProtocol(); + assert protocol != null; + if (protocol.isEmpty()) { + // If empty the server did not announce ALPN: + // See: + // https://hg.openjdk.java.net/jdk9/dev/jdk/file/65464a307408/src/java.base/ + // share/classes/sun/security/ssl/ClientHandshaker.java#l741 + selectionListener.unsupported(); + } else { + selectionListener.selected(protocol); + } + } catch (Throwable e) { + throw toSSLHandshakeException(e); + } + } else { + assert selectionListener == null; + alpnSelector.checkUnsupported(); + } + } + return result; + } + + @Override + public SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException { + return verifyProtocolSelection(super.wrap(src, dst)); + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, ByteBuffer dst) throws SSLException { + return 
verifyProtocolSelection(super.wrap(srcs, dst)); + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, int offset, int len, ByteBuffer dst) throws SSLException { + return verifyProtocolSelection(super.wrap(srcs, offset, len, dst)); + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException { + return verifyProtocolSelection(super.unwrap(src, dst)); + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws SSLException { + return verifyProtocolSelection(super.unwrap(src, dsts)); + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dst, int offset, int len) throws SSLException { + return verifyProtocolSelection(super.unwrap(src, dst, offset, len)); + } + + @Override + void setNegotiatedApplicationProtocol(String applicationProtocol) { + // Do nothing as this is handled internally by the Java8u251+ implementation of SSLEngine. + } + + @Override + public String getNegotiatedApplicationProtocol() { + String protocol = getApplicationProtocol(); + if (protocol != null) { + return protocol.isEmpty() ? null : protocol; + } + return null; + } + + // These methods will override the methods defined by Java 8u251 and later. As we may compile with an earlier + // java8 version we don't use @Override annotations here. 
+ public String getApplicationProtocol() { + return JdkAlpnSslUtils.getApplicationProtocol(getWrappedEngine()); + } + + public String getHandshakeApplicationProtocol() { + return JdkAlpnSslUtils.getHandshakeApplicationProtocol(getWrappedEngine()); + } + + public void setHandshakeApplicationProtocolSelector(BiFunction, String> selector) { + JdkAlpnSslUtils.setHandshakeApplicationProtocolSelector(getWrappedEngine(), selector); + } + + public BiFunction, String> getHandshakeApplicationProtocolSelector() { + return JdkAlpnSslUtils.getHandshakeApplicationProtocolSelector(getWrappedEngine()); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslUtils.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslUtils.java new file mode 100644 index 0000000..cc230ef --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkAlpnSslUtils.java @@ -0,0 +1,181 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import java.lang.reflect.Method; +import java.security.AccessController; +import java.security.PrivilegedExceptionAction; +import java.util.List; +import java.util.function.BiFunction; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class JdkAlpnSslUtils { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(JdkAlpnSslUtils.class); + private static final Method SET_APPLICATION_PROTOCOLS; + private static final Method GET_APPLICATION_PROTOCOL; + private static final Method GET_HANDSHAKE_APPLICATION_PROTOCOL; + private static final Method SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR; + private static final Method GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR; + + static { + Method getHandshakeApplicationProtocol; + Method getApplicationProtocol; + Method setApplicationProtocols; + Method setHandshakeApplicationProtocolSelector; + Method getHandshakeApplicationProtocolSelector; + + try { + SSLContext context = SSLContext.getInstance(JdkSslContext.PROTOCOL); + context.init(null, null, null); + SSLEngine engine = context.createSSLEngine(); + getHandshakeApplicationProtocol = AccessController.doPrivileged(new PrivilegedExceptionAction() { + @Override + public Method run() throws Exception { + return SSLEngine.class.getMethod("getHandshakeApplicationProtocol"); + } + }); + getHandshakeApplicationProtocol.invoke(engine); + getApplicationProtocol = AccessController.doPrivileged(new PrivilegedExceptionAction() { + @Override + public Method run() throws Exception { + return 
SSLEngine.class.getMethod("getApplicationProtocol"); + } + }); + getApplicationProtocol.invoke(engine); + + setApplicationProtocols = AccessController.doPrivileged(new PrivilegedExceptionAction() { + @Override + public Method run() throws Exception { + return SSLParameters.class.getMethod("setApplicationProtocols", String[].class); + } + }); + setApplicationProtocols.invoke(engine.getSSLParameters(), new Object[]{EmptyArrays.EMPTY_STRINGS}); + + setHandshakeApplicationProtocolSelector = + AccessController.doPrivileged(new PrivilegedExceptionAction() { + @Override + public Method run() throws Exception { + return SSLEngine.class.getMethod("setHandshakeApplicationProtocolSelector", BiFunction.class); + } + }); + setHandshakeApplicationProtocolSelector.invoke(engine, new BiFunction, String>() { + @Override + public String apply(SSLEngine sslEngine, List strings) { + return null; + } + }); + + getHandshakeApplicationProtocolSelector = + AccessController.doPrivileged(new PrivilegedExceptionAction() { + @Override + public Method run() throws Exception { + return SSLEngine.class.getMethod("getHandshakeApplicationProtocolSelector"); + } + }); + getHandshakeApplicationProtocolSelector.invoke(engine); + } catch (Throwable t) { + int version = PlatformDependent.javaVersion(); + if (version >= 9) { + // We only log when run on java9+ as this is expected on some earlier java8 versions + logger.error("Unable to initialize JdkAlpnSslUtils, but the detected java version was: {}", version, t); + } + getHandshakeApplicationProtocol = null; + getApplicationProtocol = null; + setApplicationProtocols = null; + setHandshakeApplicationProtocolSelector = null; + getHandshakeApplicationProtocolSelector = null; + } + GET_HANDSHAKE_APPLICATION_PROTOCOL = getHandshakeApplicationProtocol; + GET_APPLICATION_PROTOCOL = getApplicationProtocol; + SET_APPLICATION_PROTOCOLS = setApplicationProtocols; + SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = setHandshakeApplicationProtocolSelector; + 
GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR = getHandshakeApplicationProtocolSelector; + } + + private JdkAlpnSslUtils() { + } + + static boolean supportsAlpn() { + return GET_APPLICATION_PROTOCOL != null; + } + + static String getApplicationProtocol(SSLEngine sslEngine) { + try { + return (String) GET_APPLICATION_PROTOCOL.invoke(sslEngine); + } catch (UnsupportedOperationException ex) { + throw ex; + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + static String getHandshakeApplicationProtocol(SSLEngine sslEngine) { + try { + return (String) GET_HANDSHAKE_APPLICATION_PROTOCOL.invoke(sslEngine); + } catch (UnsupportedOperationException ex) { + throw ex; + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + static void setApplicationProtocols(SSLEngine engine, List supportedProtocols) { + SSLParameters parameters = engine.getSSLParameters(); + + String[] protocolArray = supportedProtocols.toArray(EmptyArrays.EMPTY_STRINGS); + try { + SET_APPLICATION_PROTOCOLS.invoke(parameters, new Object[]{protocolArray}); + } catch (UnsupportedOperationException ex) { + throw ex; + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + engine.setSSLParameters(parameters); + } + + static void setHandshakeApplicationProtocolSelector( + SSLEngine engine, BiFunction, String> selector) { + try { + SET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(engine, selector); + } catch (UnsupportedOperationException ex) { + throw ex; + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } + + @SuppressWarnings("unchecked") + static BiFunction, String> getHandshakeApplicationProtocolSelector(SSLEngine engine) { + try { + return (BiFunction, String>) + GET_HANDSHAKE_APPLICATION_PROTOCOL_SELECTOR.invoke(engine); + } catch (UnsupportedOperationException ex) { + throw ex; + } catch (Exception ex) { + throw new IllegalStateException(ex); + } + } +} diff --git 
a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkApplicationProtocolNegotiator.java new file mode 100644 index 0000000..14dec62 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkApplicationProtocolNegotiator.java @@ -0,0 +1,162 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import javax.net.ssl.SSLEngine; +import java.util.List; +import java.util.Set; + +/** + * JDK extension methods to support {@link ApplicationProtocolNegotiator} + * + * @deprecated use {@link ApplicationProtocolConfig} + */ +@Deprecated +public interface JdkApplicationProtocolNegotiator extends ApplicationProtocolNegotiator { + /** + * Abstract factory pattern for wrapping an {@link SSLEngine} object. This is useful for NPN/APLN JDK support. + */ + interface SslEngineWrapperFactory { + /** + * Abstract factory pattern for wrapping an {@link SSLEngine} object. This is useful for NPN/APLN support. + * + * @param engine The engine to wrap. + * @param applicationNegotiator The application level protocol negotiator + * @param isServer
+ * <ul>
+ * <li>{@code true} if the engine is for server side of connections</li>
+ * <li>{@code false} if the engine is for client side of connections</li>
+ * </ul>
+ * @return The resulting wrapped engine. This may just be {@code engine}. + */ + SSLEngine wrapSslEngine( + SSLEngine engine, JdkApplicationProtocolNegotiator applicationNegotiator, boolean isServer); + } + + abstract class AllocatorAwareSslEngineWrapperFactory implements SslEngineWrapperFactory { + + @Override + public final SSLEngine wrapSslEngine(SSLEngine engine, + JdkApplicationProtocolNegotiator applicationNegotiator, boolean isServer) { + return wrapSslEngine(engine, ByteBufAllocator.DEFAULT, applicationNegotiator, isServer); + } + + /** + * Abstract factory pattern for wrapping an {@link SSLEngine} object. This is useful for NPN/APLN support. + * + * @param engine The engine to wrap. + * @param alloc the buffer allocator. + * @param applicationNegotiator The application level protocol negotiator + * @param isServer
+ * <ul>
+ * <li>{@code true} if the engine is for server side of connections</li>
+ * <li>{@code false} if the engine is for client side of connections</li>
+ * </ul>
+ * @return The resulting wrapped engine. This may just be {@code engine}. + */ + abstract SSLEngine wrapSslEngine(SSLEngine engine, ByteBufAllocator alloc, + JdkApplicationProtocolNegotiator applicationNegotiator, boolean isServer); + } + + /** + * Interface to define the role of an application protocol selector in the SSL handshake process. Either + * {@link ProtocolSelector#unsupported()} OR {@link ProtocolSelector#select(List)} will be called for each SSL + * handshake. + */ + interface ProtocolSelector { + /** + * Callback invoked to let the application know that the peer does not support this + * {@link ApplicationProtocolNegotiator}. + */ + void unsupported(); + + /** + * Callback invoked to select the application level protocol from the {@code protocols} provided. + * + * @param protocols the protocols sent by the protocol advertiser + * @return the protocol selected by this {@link ProtocolSelector}. A {@code null} value will indicate the no + * protocols were selected but the handshake should not fail. The decision to fail the handshake is left to the + * other end negotiating the SSL handshake. + * @throws Exception If the {@code protocols} provide warrant failing the SSL handshake with a fatal alert. + */ + String select(List protocols) throws Exception; + } + + /** + * A listener to be notified by which protocol was select by its peer. Either the + * {@link ProtocolSelectionListener#unsupported()} OR the {@link ProtocolSelectionListener#selected(String)} method + * will be called for each SSL handshake. + */ + interface ProtocolSelectionListener { + /** + * Callback invoked to let the application know that the peer does not support this + * {@link ApplicationProtocolNegotiator}. + */ + void unsupported(); + + /** + * Callback invoked to let this application know the protocol chosen by the peer. + * + * @param protocol the protocol selected by the peer. May be {@code null} or empty as supported by the + * application negotiation protocol. 
+ * @throws Exception This may be thrown if the selected protocol is not acceptable and the desired behavior is + * to fail the handshake with a fatal alert. + */ + void selected(String protocol) throws Exception; + } + + /** + * Factory interface for {@link ProtocolSelector} objects. + */ + interface ProtocolSelectorFactory { + /** + * Generate a new instance of {@link ProtocolSelector}. + * @param engine The {@link SSLEngine} that the returned {@link ProtocolSelector} will be used to create an + * instance for. + * @param supportedProtocols The protocols that are supported. + * @return A new instance of {@link ProtocolSelector}. + */ + ProtocolSelector newSelector(SSLEngine engine, Set supportedProtocols); + } + + /** + * Factory interface for {@link ProtocolSelectionListener} objects. + */ + interface ProtocolSelectionListenerFactory { + /** + * Generate a new instance of {@link ProtocolSelectionListener}. + * @param engine The {@link SSLEngine} that the returned {@link ProtocolSelectionListener} will be used to + * create an instance for. + * @param supportedProtocols The protocols that are supported in preference order. + * @return A new instance of {@link ProtocolSelectionListener}. + */ + ProtocolSelectionListener newListener(SSLEngine engine, List supportedProtocols); + } + + /** + * Get the {@link SslEngineWrapperFactory}. + */ + SslEngineWrapperFactory wrapperFactory(); + + /** + * Get the {@link ProtocolSelectorFactory}. + */ + ProtocolSelectorFactory protocolSelectorFactory(); + + /** + * Get the {@link ProtocolSelectionListenerFactory}. 
+ */ + ProtocolSelectionListenerFactory protocolListenerFactory(); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkBaseApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkBaseApplicationProtocolNegotiator.java new file mode 100644 index 0000000..54bb216 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkBaseApplicationProtocolNegotiator.java @@ -0,0 +1,209 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import static io.netty.handler.ssl.ApplicationProtocolUtil.toList; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +import java.util.Collections; +import java.util.List; +import java.util.Set; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLHandshakeException; + +/** + * Common base class for {@link JdkApplicationProtocolNegotiator} classes to inherit from. + */ +class JdkBaseApplicationProtocolNegotiator implements JdkApplicationProtocolNegotiator { + private final List protocols; + private final ProtocolSelectorFactory selectorFactory; + private final ProtocolSelectionListenerFactory listenerFactory; + private final SslEngineWrapperFactory wrapperFactory; + + /** + * Create a new instance. + * @param wrapperFactory Determines which application protocol will be used by wrapping the SSLEngine in use. 
+ * @param selectorFactory How the peer selecting the protocol should behave. + * @param listenerFactory How the peer being notified of the selected protocol should behave. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + JdkBaseApplicationProtocolNegotiator(SslEngineWrapperFactory wrapperFactory, + ProtocolSelectorFactory selectorFactory, ProtocolSelectionListenerFactory listenerFactory, + Iterable protocols) { + this(wrapperFactory, selectorFactory, listenerFactory, toList(protocols)); + } + + /** + * Create a new instance. + * @param wrapperFactory Determines which application protocol will be used by wrapping the SSLEngine in use. + * @param selectorFactory How the peer selecting the protocol should behave. + * @param listenerFactory How the peer being notified of the selected protocol should behave. + * @param protocols The order of iteration determines the preference of support for protocols. + */ + JdkBaseApplicationProtocolNegotiator(SslEngineWrapperFactory wrapperFactory, + ProtocolSelectorFactory selectorFactory, ProtocolSelectionListenerFactory listenerFactory, + String... protocols) { + this(wrapperFactory, selectorFactory, listenerFactory, toList(protocols)); + } + + /** + * Create a new instance. + * @param wrapperFactory Determines which application protocol will be used by wrapping the SSLEngine in use. + * @param selectorFactory How the peer selecting the protocol should behave. + * @param listenerFactory How the peer being notified of the selected protocol should behave. + * @param protocols The order of iteration determines the preference of support for protocols. 
+ */ + private JdkBaseApplicationProtocolNegotiator(SslEngineWrapperFactory wrapperFactory, + ProtocolSelectorFactory selectorFactory, ProtocolSelectionListenerFactory listenerFactory, + List protocols) { + this.wrapperFactory = checkNotNull(wrapperFactory, "wrapperFactory"); + this.selectorFactory = checkNotNull(selectorFactory, "selectorFactory"); + this.listenerFactory = checkNotNull(listenerFactory, "listenerFactory"); + this.protocols = Collections.unmodifiableList(checkNotNull(protocols, "protocols")); + } + + @Override + public List protocols() { + return protocols; + } + + @Override + public ProtocolSelectorFactory protocolSelectorFactory() { + return selectorFactory; + } + + @Override + public ProtocolSelectionListenerFactory protocolListenerFactory() { + return listenerFactory; + } + + @Override + public SslEngineWrapperFactory wrapperFactory() { + return wrapperFactory; + } + + static final ProtocolSelectorFactory FAIL_SELECTOR_FACTORY = new ProtocolSelectorFactory() { + @Override + public ProtocolSelector newSelector(SSLEngine engine, Set supportedProtocols) { + return new FailProtocolSelector((JdkSslEngine) engine, supportedProtocols); + } + }; + + static final ProtocolSelectorFactory NO_FAIL_SELECTOR_FACTORY = new ProtocolSelectorFactory() { + @Override + public ProtocolSelector newSelector(SSLEngine engine, Set supportedProtocols) { + return new NoFailProtocolSelector((JdkSslEngine) engine, supportedProtocols); + } + }; + + static final ProtocolSelectionListenerFactory FAIL_SELECTION_LISTENER_FACTORY = + new ProtocolSelectionListenerFactory() { + @Override + public ProtocolSelectionListener newListener(SSLEngine engine, List supportedProtocols) { + return new FailProtocolSelectionListener((JdkSslEngine) engine, supportedProtocols); + } + }; + + static final ProtocolSelectionListenerFactory NO_FAIL_SELECTION_LISTENER_FACTORY = + new ProtocolSelectionListenerFactory() { + @Override + public ProtocolSelectionListener newListener(SSLEngine engine, List 
supportedProtocols) { + return new NoFailProtocolSelectionListener((JdkSslEngine) engine, supportedProtocols); + } + }; + + static class NoFailProtocolSelector implements ProtocolSelector { + private final JdkSslEngine engineWrapper; + private final Set supportedProtocols; + + NoFailProtocolSelector(JdkSslEngine engineWrapper, Set supportedProtocols) { + this.engineWrapper = engineWrapper; + this.supportedProtocols = supportedProtocols; + } + + @Override + public void unsupported() { + engineWrapper.setNegotiatedApplicationProtocol(null); + } + + @Override + public String select(List protocols) throws Exception { + for (String p : supportedProtocols) { + if (protocols.contains(p)) { + engineWrapper.setNegotiatedApplicationProtocol(p); + return p; + } + } + return noSelectMatchFound(); + } + + public String noSelectMatchFound() throws Exception { + engineWrapper.setNegotiatedApplicationProtocol(null); + return null; + } + } + + private static final class FailProtocolSelector extends NoFailProtocolSelector { + FailProtocolSelector(JdkSslEngine engineWrapper, Set supportedProtocols) { + super(engineWrapper, supportedProtocols); + } + + @Override + public String noSelectMatchFound() throws Exception { + throw new SSLHandshakeException("Selected protocol is not supported"); + } + } + + private static class NoFailProtocolSelectionListener implements ProtocolSelectionListener { + private final JdkSslEngine engineWrapper; + private final List supportedProtocols; + + NoFailProtocolSelectionListener(JdkSslEngine engineWrapper, List supportedProtocols) { + this.engineWrapper = engineWrapper; + this.supportedProtocols = supportedProtocols; + } + + @Override + public void unsupported() { + engineWrapper.setNegotiatedApplicationProtocol(null); + } + + @Override + public void selected(String protocol) throws Exception { + if (supportedProtocols.contains(protocol)) { + engineWrapper.setNegotiatedApplicationProtocol(protocol); + } else { + noSelectedMatchFound(protocol); + } + } + 
+ protected void noSelectedMatchFound(String protocol) throws Exception { + // Will never be called. + } + } + + private static final class FailProtocolSelectionListener extends NoFailProtocolSelectionListener { + FailProtocolSelectionListener(JdkSslEngine engineWrapper, List supportedProtocols) { + super(engineWrapper, supportedProtocols); + } + + @Override + protected void noSelectedMatchFound(String protocol) throws Exception { + throw new SSLHandshakeException("No compatible protocols found"); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkDefaultApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkDefaultApplicationProtocolNegotiator.java new file mode 100644 index 0000000..587c599 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkDefaultApplicationProtocolNegotiator.java @@ -0,0 +1,60 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import java.util.Collections; +import java.util.List; + +import javax.net.ssl.SSLEngine; + +/** + * The {@link JdkApplicationProtocolNegotiator} to use if you do not care about NPN or ALPN and are using + * {@link SslProvider#JDK}. 
+ */ +final class JdkDefaultApplicationProtocolNegotiator implements JdkApplicationProtocolNegotiator { + public static final JdkDefaultApplicationProtocolNegotiator INSTANCE = + new JdkDefaultApplicationProtocolNegotiator(); + private static final SslEngineWrapperFactory DEFAULT_SSL_ENGINE_WRAPPER_FACTORY = new SslEngineWrapperFactory() { + @Override + public SSLEngine wrapSslEngine(SSLEngine engine, + JdkApplicationProtocolNegotiator applicationNegotiator, boolean isServer) { + return engine; + } + }; + + private JdkDefaultApplicationProtocolNegotiator() { + } + + @Override + public SslEngineWrapperFactory wrapperFactory() { + return DEFAULT_SSL_ENGINE_WRAPPER_FACTORY; + } + + @Override + public ProtocolSelectorFactory protocolSelectorFactory() { + throw new UnsupportedOperationException("Application protocol negotiation unsupported"); + } + + @Override + public ProtocolSelectionListenerFactory protocolListenerFactory() { + throw new UnsupportedOperationException("Application protocol negotiation unsupported"); + } + + @Override + public List protocols() { + return Collections.emptyList(); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java new file mode 100644 index 0000000..2a018e3 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslClientContext.java @@ -0,0 +1,313 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
package io.netty.handler.ssl;

import java.io.File;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.cert.X509Certificate;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLSessionContext;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

/**
 * A client-side {@link SslContext} which uses JDK's SSL/TLS implementation.
 *
 * @deprecated Use {@link SslContextBuilder} to create {@link JdkSslContext} instances and only
 * use {@link JdkSslContext} in your code.
 */
@Deprecated
public final class JdkSslClientContext extends JdkSslContext {

    /**
     * Creates a new instance using system-default trust management.
     *
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext() throws SSLException {
        this(null, null);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(File certChainFile) throws SSLException {
        this(certChainFile, null);
    }

    /**
     * Creates a new instance.
     *
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(TrustManagerFactory trustManagerFactory) throws SSLException {
        this(null, trustManagerFactory);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory) throws SSLException {
        this(certChainFile, trustManagerFactory, null, IdentityCipherSuiteFilter.INSTANCE,
                JdkDefaultApplicationProtocolNegotiator.INSTANCE, 0, 0);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param nextProtocols the application layer protocols to accept, in the order of preference.
     *                      {@code null} to disable TLS NPN/ALPN extension.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(
            File certChainFile, TrustManagerFactory trustManagerFactory,
            Iterable<String> ciphers, Iterable<String> nextProtocols,
            long sessionCacheSize, long sessionTimeout) throws SSLException {
        this(certChainFile, trustManagerFactory, ciphers, IdentityCipherSuiteFilter.INSTANCE,
                toNegotiator(toApplicationProtocolConfig(nextProtocols), false), sessionCacheSize, sessionTimeout);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param cipherFilter a filter to apply over the supplied list of ciphers
     * @param apn Provides a means to configure parameters related to application protocol negotiation.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(
            File certChainFile, TrustManagerFactory trustManagerFactory,
            Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
            long sessionCacheSize, long sessionTimeout) throws SSLException {
        this(certChainFile, trustManagerFactory, ciphers, cipherFilter,
                toNegotiator(apn, false), sessionCacheSize, sessionTimeout);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param cipherFilter a filter to apply over the supplied list of ciphers
     * @param apn Application Protocol Negotiator object.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(
            File certChainFile, TrustManagerFactory trustManagerFactory,
            Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn,
            long sessionCacheSize, long sessionTimeout) throws SSLException {
        this(null, certChainFile, trustManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
    }

    JdkSslClientContext(Provider provider,
                        File trustCertCollectionFile, TrustManagerFactory trustManagerFactory,
                        Iterable<String> ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn,
                        long sessionCacheSize, long sessionTimeout) throws SSLException {
        // Trust-only configuration: no key material is supplied (all key-related
        // arguments to newSSLContext are null).
        super(newSSLContext(provider, toX509CertificatesInternal(trustCertCollectionFile),
                trustManagerFactory, null, null,
                null, null, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()), true,
                ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
    }

    /**
     * Creates a new instance supporting mutual authentication.
     *
     * @param trustCertCollectionFile an X.509 certificate collection file in PEM format.
     *                                {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default or the results of parsing
     *                            {@code trustCertCollectionFile}
     * @param keyCertChainFile an X.509 certificate chain file in PEM format.
     *                         This provides the public key for mutual authentication.
     *                         {@code null} to use the system default
     * @param keyFile a PKCS#8 private key file in PEM format.
     *                This provides the private key for mutual authentication.
     *                {@code null} for no mutual authentication.
     * @param keyPassword the password of the {@code keyFile}.
     *                    {@code null} if it's not password-protected.
     *                    Ignored if {@code keyFile} is {@code null}.
     * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s
     *                          that is used to encrypt data being sent to servers.
     *                          {@code null} to use the default or the results of parsing
     *                          {@code keyCertChainFile} and {@code keyFile}.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param cipherFilter a filter to apply over the supplied list of ciphers
     * @param apn Provides a means to configure parameters related to application protocol negotiation.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(File trustCertCollectionFile, TrustManagerFactory trustManagerFactory,
                               File keyCertChainFile, File keyFile, String keyPassword,
                               KeyManagerFactory keyManagerFactory,
                               Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
                               ApplicationProtocolConfig apn,
                               long sessionCacheSize, long sessionTimeout) throws SSLException {
        this(trustCertCollectionFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword, keyManagerFactory,
                ciphers, cipherFilter, toNegotiator(apn, false), sessionCacheSize, sessionTimeout);
    }

    /**
     * Creates a new instance supporting mutual authentication.
     *
     * @param trustCertCollectionFile an X.509 certificate collection file in PEM format.
     *                                {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default or the results of parsing
     *                            {@code trustCertCollectionFile}
     * @param keyCertChainFile an X.509 certificate chain file in PEM format.
     *                         This provides the public key for mutual authentication.
     *                         {@code null} to use the system default
     * @param keyFile a PKCS#8 private key file in PEM format.
     *                This provides the private key for mutual authentication.
     *                {@code null} for no mutual authentication.
     * @param keyPassword the password of the {@code keyFile}.
     *                    {@code null} if it's not password-protected.
     *                    Ignored if {@code keyFile} is {@code null}.
     * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s
     *                          that is used to encrypt data being sent to servers.
     *                          {@code null} to use the default or the results of parsing
     *                          {@code keyCertChainFile} and {@code keyFile}.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param cipherFilter a filter to apply over the supplied list of ciphers
     * @param apn Application Protocol Negotiator object.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public JdkSslClientContext(File trustCertCollectionFile, TrustManagerFactory trustManagerFactory,
                               File keyCertChainFile, File keyFile, String keyPassword,
                               KeyManagerFactory keyManagerFactory,
                               Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
                               JdkApplicationProtocolNegotiator apn,
                               long sessionCacheSize, long sessionTimeout) throws SSLException {
        super(newSSLContext(null, toX509CertificatesInternal(
                trustCertCollectionFile), trustManagerFactory,
                toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
                keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()), true,
                ciphers, cipherFilter, apn, ClientAuth.NONE, null, false);
    }

    JdkSslClientContext(Provider sslContextProvider,
                        X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
                        X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
                        KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
                        CipherSuiteFilter cipherFilter,
                        ApplicationProtocolConfig apn, String[] protocols, long sessionCacheSize, long sessionTimeout,
                        String keyStoreType)
            throws SSLException {
        super(newSSLContext(sslContextProvider, trustCertCollection, trustManagerFactory,
                keyCertChain, key, keyPassword, keyManagerFactory, sessionCacheSize,
                sessionTimeout, keyStoreType),
                true, ciphers, cipherFilter, toNegotiator(apn, false), ClientAuth.NONE, protocols, false);
    }

    /**
     * Builds the underlying JDK {@link SSLContext}, wiring up trust and key material and
     * applying session cache size/timeout if set ({@code > 0}).
     *
     * @throws SSLException if the context cannot be initialized; any non-SSL cause is wrapped
     */
    private static SSLContext newSSLContext(Provider sslContextProvider,
                                            X509Certificate[] trustCertCollection,
                                            TrustManagerFactory trustManagerFactory, X509Certificate[] keyCertChain,
                                            PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory,
                                            long sessionCacheSize, long sessionTimeout,
                                            String keyStore) throws SSLException {
        try {
            if (trustCertCollection != null) {
                trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory, keyStore);
            }
            if (keyCertChain != null) {
                keyManagerFactory = buildKeyManagerFactory(keyCertChain, null,
                        key, keyPassword, keyManagerFactory, keyStore);
            }
            SSLContext ctx = sslContextProvider == null ? SSLContext.getInstance(PROTOCOL)
                    : SSLContext.getInstance(PROTOCOL, sslContextProvider);
            ctx.init(keyManagerFactory == null ? null : keyManagerFactory.getKeyManagers(),
                    trustManagerFactory == null ? null : trustManagerFactory.getTrustManagers(),
                    null);

            SSLSessionContext sessCtx = ctx.getClientSessionContext();
            if (sessionCacheSize > 0) {
                // Session cache APIs take int; clamp the long configuration value.
                sessCtx.setSessionCacheSize((int) Math.min(sessionCacheSize, Integer.MAX_VALUE));
            }
            if (sessionTimeout > 0) {
                sessCtx.setSessionTimeout((int) Math.min(sessionTimeout, Integer.MAX_VALUE));
            }
            return ctx;
        } catch (Exception e) {
            if (e instanceof SSLException) {
                throw (SSLException) e;
            }
            throw new SSLException("failed to initialize the client-side SSL context", e);
        }
    }
}
+ */ + +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.File; +import java.io.IOException; +import java.security.InvalidAlgorithmParameterException; +import java.security.KeyException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.Provider; +import java.security.Security; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.spec.InvalidKeySpecException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import javax.crypto.NoSuchPaddingException; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSessionContext; + +import static io.netty.handler.ssl.SslUtils.DEFAULT_CIPHER_SUITES; +import static io.netty.handler.ssl.SslUtils.addIfSupported; +import static io.netty.handler.ssl.SslUtils.useFallbackCiphersIfDefaultIsEmpty; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * An {@link SslContext} which uses JDK's SSL/TLS implementation. 
+ */ +public class JdkSslContext extends SslContext { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(JdkSslContext.class); + + static final String PROTOCOL = "TLS"; + private static final String[] DEFAULT_PROTOCOLS; + private static final List DEFAULT_CIPHERS; + private static final List DEFAULT_CIPHERS_NON_TLSV13; + private static final Set SUPPORTED_CIPHERS; + private static final Set SUPPORTED_CIPHERS_NON_TLSV13; + private static final Provider DEFAULT_PROVIDER; + + static { + Defaults defaults = new Defaults(); + defaults.init(); + + DEFAULT_PROVIDER = defaults.defaultProvider; + DEFAULT_PROTOCOLS = defaults.defaultProtocols; + SUPPORTED_CIPHERS = defaults.supportedCiphers; + DEFAULT_CIPHERS = defaults.defaultCiphers; + DEFAULT_CIPHERS_NON_TLSV13 = defaults.defaultCiphersNonTLSv13; + SUPPORTED_CIPHERS_NON_TLSV13 = defaults.supportedCiphersNonTLSv13; + + if (logger.isDebugEnabled()) { + logger.debug("Default protocols (JDK): {} ", Arrays.asList(DEFAULT_PROTOCOLS)); + logger.debug("Default cipher suites (JDK): {}", DEFAULT_CIPHERS); + } + } + + private static final class Defaults { + String[] defaultProtocols; + List defaultCiphers; + List defaultCiphersNonTLSv13; + Set supportedCiphers; + Set supportedCiphersNonTLSv13; + Provider defaultProvider; + + void init() { + SSLContext context; + try { + context = SSLContext.getInstance(PROTOCOL); + context.init(null, null, null); + } catch (Exception e) { + throw new Error("failed to initialize the default SSL context", e); + } + + defaultProvider = context.getProvider(); + + SSLEngine engine = context.createSSLEngine(); + defaultProtocols = defaultProtocols(context, engine); + + supportedCiphers = Collections.unmodifiableSet(supportedCiphers(engine)); + defaultCiphers = Collections.unmodifiableList(defaultCiphers(engine, supportedCiphers)); + + List ciphersNonTLSv13 = new ArrayList(defaultCiphers); + ciphersNonTLSv13.removeAll(Arrays.asList(SslUtils.DEFAULT_TLSV13_CIPHER_SUITES)); + 
defaultCiphersNonTLSv13 = Collections.unmodifiableList(ciphersNonTLSv13); + + Set suppertedCiphersNonTLSv13 = new LinkedHashSet(supportedCiphers); + suppertedCiphersNonTLSv13.removeAll(Arrays.asList(SslUtils.DEFAULT_TLSV13_CIPHER_SUITES)); + supportedCiphersNonTLSv13 = Collections.unmodifiableSet(suppertedCiphersNonTLSv13); + } + } + + private static String[] defaultProtocols(SSLContext context, SSLEngine engine) { + // Choose the sensible default list of protocols that respects JDK flags, eg. jdk.tls.client.protocols + final String[] supportedProtocols = context.getDefaultSSLParameters().getProtocols(); + Set supportedProtocolsSet = new HashSet(supportedProtocols.length); + Collections.addAll(supportedProtocolsSet, supportedProtocols); + List protocols = new ArrayList(); + addIfSupported( + supportedProtocolsSet, protocols, + SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2, + SslProtocols.TLS_v1_1, SslProtocols.TLS_v1); + + if (!protocols.isEmpty()) { + return protocols.toArray(EmptyArrays.EMPTY_STRINGS); + } + return engine.getEnabledProtocols(); + } + + private static Set supportedCiphers(SSLEngine engine) { + // Choose the sensible default list of cipher suites. + final String[] supportedCiphers = engine.getSupportedCipherSuites(); + Set supportedCiphersSet = new LinkedHashSet(supportedCiphers.length); + for (int i = 0; i < supportedCiphers.length; ++i) { + String supportedCipher = supportedCiphers[i]; + supportedCiphersSet.add(supportedCipher); + // IBM's J9 JVM utilizes a custom naming scheme for ciphers and only returns ciphers with the "SSL_" + // prefix instead of the "TLS_" prefix (as defined in the JSSE cipher suite names [1]). According to IBM's + // documentation [2] the "SSL_" prefix is "interchangeable" with the "TLS_" prefix. + // See the IBM forum discussion [3] and issue on IBM's JVM [4] for more details. 
+ //[1] https://docs.oracle.com/javase/8/docs/technotes/guides/security/StandardNames.html#ciphersuites + //[2] https://www.ibm.com/support/knowledgecenter/en/SSYKE2_8.0.0/com.ibm.java.security.component.80.doc/ + // security-component/jsse2Docs/ciphersuites.html + //[3] https://www.ibm.com/developerworks/community/forums/html/topic?id=9b5a56a9-fa46-4031-b33b-df91e28d77c2 + //[4] https://www.ibm.com/developerworks/rfe/execute?use_case=viewRfe&CR_ID=71770 + if (supportedCipher.startsWith("SSL_")) { + final String tlsPrefixedCipherName = "TLS_" + supportedCipher.substring("SSL_".length()); + try { + engine.setEnabledCipherSuites(new String[]{tlsPrefixedCipherName}); + supportedCiphersSet.add(tlsPrefixedCipherName); + } catch (IllegalArgumentException ignored) { + // The cipher is not supported ... move on to the next cipher. + } + } + } + return supportedCiphersSet; + } + + private static List defaultCiphers(SSLEngine engine, Set supportedCiphers) { + List ciphers = new ArrayList(); + addIfSupported(supportedCiphers, ciphers, DEFAULT_CIPHER_SUITES); + useFallbackCiphersIfDefaultIsEmpty(ciphers, engine.getEnabledCipherSuites()); + return ciphers; + } + + private static boolean isTlsV13Supported(String[] protocols) { + for (String protocol: protocols) { + if (SslProtocols.TLS_v1_3.equals(protocol)) { + return true; + } + } + return false; + } + + private final String[] protocols; + private final String[] cipherSuites; + private final List unmodifiableCipherSuites; + @SuppressWarnings("deprecation") + private final JdkApplicationProtocolNegotiator apn; + private final ClientAuth clientAuth; + private final SSLContext sslContext; + private final boolean isClient; + + /** + * Creates a new {@link JdkSslContext} from a pre-configured {@link SSLContext}. + * + * @param sslContext the {@link SSLContext} to use. + * @param isClient {@code true} if this context should create {@link SSLEngine}s for client-side usage. + * @param clientAuth the {@link ClientAuth} to use. 
This will only be used when {@param isClient} is {@code false}. + * @deprecated Use {@link #JdkSslContext(SSLContext, boolean, Iterable, CipherSuiteFilter, + * ApplicationProtocolConfig, ClientAuth, String[], boolean)} + */ + @Deprecated + public JdkSslContext(SSLContext sslContext, boolean isClient, + ClientAuth clientAuth) { + this(sslContext, isClient, null, IdentityCipherSuiteFilter.INSTANCE, + JdkDefaultApplicationProtocolNegotiator.INSTANCE, clientAuth, null, false); + } + + /** + * Creates a new {@link JdkSslContext} from a pre-configured {@link SSLContext}. + * + * @param sslContext the {@link SSLContext} to use. + * @param isClient {@code true} if this context should create {@link SSLEngine}s for client-side usage. + * @param ciphers the ciphers to use or {@code null} if the standard should be used. + * @param cipherFilter the filter to use. + * @param apn the {@link ApplicationProtocolConfig} to use. + * @param clientAuth the {@link ClientAuth} to use. This will only be used when {@param isClient} is {@code false}. + * @deprecated Use {@link #JdkSslContext(SSLContext, boolean, Iterable, CipherSuiteFilter, + * ApplicationProtocolConfig, ClientAuth, String[], boolean)} + */ + @Deprecated + public JdkSslContext(SSLContext sslContext, boolean isClient, Iterable ciphers, + CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + ClientAuth clientAuth) { + this(sslContext, isClient, ciphers, cipherFilter, apn, clientAuth, null, false); + } + + /** + * Creates a new {@link JdkSslContext} from a pre-configured {@link SSLContext}. + * + * @param sslContext the {@link SSLContext} to use. + * @param isClient {@code true} if this context should create {@link SSLEngine}s for client-side usage. + * @param ciphers the ciphers to use or {@code null} if the standard should be used. + * @param cipherFilter the filter to use. + * @param apn the {@link ApplicationProtocolConfig} to use. + * @param clientAuth the {@link ClientAuth} to use. 
This will only be used when {@param isClient} is {@code false}. + * @param protocols the protocols to enable, or {@code null} to enable the default protocols. + * @param startTls {@code true} if the first write request shouldn't be encrypted + */ + public JdkSslContext(SSLContext sslContext, + boolean isClient, + Iterable ciphers, + CipherSuiteFilter cipherFilter, + ApplicationProtocolConfig apn, + ClientAuth clientAuth, + String[] protocols, + boolean startTls) { + this(sslContext, + isClient, + ciphers, + cipherFilter, + toNegotiator(apn, !isClient), + clientAuth, + protocols == null ? null : protocols.clone(), + startTls); + } + + @SuppressWarnings("deprecation") + JdkSslContext(SSLContext sslContext, boolean isClient, Iterable ciphers, CipherSuiteFilter cipherFilter, + JdkApplicationProtocolNegotiator apn, ClientAuth clientAuth, String[] protocols, boolean startTls) { + super(startTls); + this.apn = checkNotNull(apn, "apn"); + this.clientAuth = checkNotNull(clientAuth, "clientAuth"); + this.sslContext = checkNotNull(sslContext, "sslContext"); + + final List defaultCiphers; + final Set supportedCiphers; + if (DEFAULT_PROVIDER.equals(sslContext.getProvider())) { + this.protocols = protocols == null? DEFAULT_PROTOCOLS : protocols; + if (isTlsV13Supported(this.protocols)) { + supportedCiphers = SUPPORTED_CIPHERS; + defaultCiphers = DEFAULT_CIPHERS; + } else { + // TLSv1.3 is not supported, ensure we do not include any TLSv1.3 ciphersuite. + supportedCiphers = SUPPORTED_CIPHERS_NON_TLSV13; + defaultCiphers = DEFAULT_CIPHERS_NON_TLSV13; + } + } else { + // This is a different Provider then the one used by the JDK by default so we can not just assume + // the same protocols and ciphers are supported. For example even if Java11+ is used Conscrypt will + // not support TLSv1.3 and the TLSv1.3 ciphersuites. 
+ SSLEngine engine = sslContext.createSSLEngine(); + try { + if (protocols == null) { + this.protocols = defaultProtocols(sslContext, engine); + } else { + this.protocols = protocols; + } + supportedCiphers = supportedCiphers(engine); + defaultCiphers = defaultCiphers(engine, supportedCiphers); + if (!isTlsV13Supported(this.protocols)) { + // TLSv1.3 is not supported, ensure we do not include any TLSv1.3 ciphersuite. + for (String cipher: SslUtils.DEFAULT_TLSV13_CIPHER_SUITES) { + supportedCiphers.remove(cipher); + defaultCiphers.remove(cipher); + } + } + } finally { + ReferenceCountUtil.release(engine); + } + } + + cipherSuites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites( + ciphers, defaultCiphers, supportedCiphers); + + unmodifiableCipherSuites = Collections.unmodifiableList(Arrays.asList(cipherSuites)); + this.isClient = isClient; + } + + /** + * Returns the JDK {@link SSLContext} object held by this context. + */ + public final SSLContext context() { + return sslContext; + } + + @Override + public final boolean isClient() { + return isClient; + } + + /** + * Returns the JDK {@link SSLSessionContext} object held by this context. 
+ */ + @Override + public final SSLSessionContext sessionContext() { + if (isServer()) { + return context().getServerSessionContext(); + } else { + return context().getClientSessionContext(); + } + } + + @Override + public final List cipherSuites() { + return unmodifiableCipherSuites; + } + + @Override + public final SSLEngine newEngine(ByteBufAllocator alloc) { + return configureAndWrapEngine(context().createSSLEngine(), alloc); + } + + @Override + public final SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) { + return configureAndWrapEngine(context().createSSLEngine(peerHost, peerPort), alloc); + } + + @SuppressWarnings("deprecation") + private SSLEngine configureAndWrapEngine(SSLEngine engine, ByteBufAllocator alloc) { + engine.setEnabledCipherSuites(cipherSuites); + engine.setEnabledProtocols(protocols); + engine.setUseClientMode(isClient()); + if (isServer()) { + switch (clientAuth) { + case OPTIONAL: + engine.setWantClientAuth(true); + break; + case REQUIRE: + engine.setNeedClientAuth(true); + break; + case NONE: + break; // exhaustive cases + default: + throw new Error("Unknown auth " + clientAuth); + } + } + JdkApplicationProtocolNegotiator.SslEngineWrapperFactory factory = apn.wrapperFactory(); + if (factory instanceof JdkApplicationProtocolNegotiator.AllocatorAwareSslEngineWrapperFactory) { + return ((JdkApplicationProtocolNegotiator.AllocatorAwareSslEngineWrapperFactory) factory) + .wrapSslEngine(engine, alloc, apn, isServer()); + } + return factory.wrapSslEngine(engine, apn, isServer()); + } + + @Override + public final JdkApplicationProtocolNegotiator applicationProtocolNegotiator() { + return apn; + } + + /** + * Translate a {@link ApplicationProtocolConfig} object to a {@link JdkApplicationProtocolNegotiator} object. + * @param config The configuration which defines the translation + * @param isServer {@code true} if a server {@code false} otherwise. 
+ * @return The results of the translation + */ + @SuppressWarnings("deprecation") + static JdkApplicationProtocolNegotiator toNegotiator(ApplicationProtocolConfig config, boolean isServer) { + if (config == null) { + return JdkDefaultApplicationProtocolNegotiator.INSTANCE; + } + + switch(config.protocol()) { + case NONE: + return JdkDefaultApplicationProtocolNegotiator.INSTANCE; + case ALPN: + if (isServer) { + switch(config.selectorFailureBehavior()) { + case FATAL_ALERT: + return new JdkAlpnApplicationProtocolNegotiator(true, config.supportedProtocols()); + case NO_ADVERTISE: + return new JdkAlpnApplicationProtocolNegotiator(false, config.supportedProtocols()); + default: + throw new UnsupportedOperationException(new StringBuilder("JDK provider does not support ") + .append(config.selectorFailureBehavior()).append(" failure behavior").toString()); + } + } else { + switch(config.selectedListenerFailureBehavior()) { + case ACCEPT: + return new JdkAlpnApplicationProtocolNegotiator(false, config.supportedProtocols()); + case FATAL_ALERT: + return new JdkAlpnApplicationProtocolNegotiator(true, config.supportedProtocols()); + default: + throw new UnsupportedOperationException(new StringBuilder("JDK provider does not support ") + .append(config.selectedListenerFailureBehavior()).append(" failure behavior").toString()); + } + } + default: + throw new UnsupportedOperationException(new StringBuilder("JDK provider does not support ") + .append(config.protocol()).append(" protocol").toString()); + } + } + + /** + * Build a {@link KeyManagerFactory} based upon a key file, key file password, and a certificate chain. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. 
+ * @param kmf The existing {@link KeyManagerFactory} that will be used if not {@code null} + * @param keyStore the {@link KeyStore} that should be used in the {@link KeyManagerFactory} + * @return A {@link KeyManagerFactory} based upon a key file, key file password, and a certificate chain. + */ + static KeyManagerFactory buildKeyManagerFactory(File certChainFile, File keyFile, String keyPassword, + KeyManagerFactory kmf, String keyStore) + throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException, + NoSuchPaddingException, InvalidKeySpecException, InvalidAlgorithmParameterException, + CertificateException, KeyException, IOException { + String algorithm = Security.getProperty("ssl.KeyManagerFactory.algorithm"); + if (algorithm == null) { + algorithm = "SunX509"; + } + return buildKeyManagerFactory(certChainFile, algorithm, keyFile, keyPassword, kmf, keyStore); + } + + /** + * Build a {@link KeyManagerFactory} based upon a key file, key file password, and a certificate chain. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param kmf The existing {@link KeyManagerFactory} that will be used if not {@code null} + * @return A {@link KeyManagerFactory} based upon a key file, key file password, and a certificate chain. + * @deprecated will be removed. 
+ */ + @Deprecated + protected static KeyManagerFactory buildKeyManagerFactory(File certChainFile, File keyFile, String keyPassword, + KeyManagerFactory kmf) + throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException, + NoSuchPaddingException, InvalidKeySpecException, InvalidAlgorithmParameterException, + CertificateException, KeyException, IOException { + return buildKeyManagerFactory(certChainFile, keyFile, keyPassword, kmf, KeyStore.getDefaultType()); + } + + /** + * Build a {@link KeyManagerFactory} based upon a key algorithm, key file, key file password, + * and a certificate chain. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyAlgorithm the standard name of the requested algorithm. See the Java Secure Socket Extension + * Reference Guide for information about standard algorithm names. + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param kmf The existing {@link KeyManagerFactory} that will be used if not {@code null} + * @param keyStore the {@link KeyStore} that should be used in the {@link KeyManagerFactory} + * @return A {@link KeyManagerFactory} based upon a key algorithm, key file, key file password, + * and a certificate chain. 
+     */
+    static KeyManagerFactory buildKeyManagerFactory(File certChainFile,
+            String keyAlgorithm, File keyFile, String keyPassword, KeyManagerFactory kmf,
+            String keyStore)
+            throws KeyStoreException, NoSuchAlgorithmException, NoSuchPaddingException,
+            InvalidKeySpecException, InvalidAlgorithmParameterException, IOException,
+            CertificateException, KeyException, UnrecoverableKeyException {
+        return buildKeyManagerFactory(toX509Certificates(certChainFile), keyAlgorithm,
+                toPrivateKey(keyFile, keyPassword), keyPassword, kmf, keyStore);
+    }
+
+    /**
+     * Build a {@link KeyManagerFactory} based upon a key algorithm, key file, key file password,
+     * and a certificate chain.
+     * @param certChainFile an X.509 certificate chain file in PEM format
+     * @param keyAlgorithm the standard name of the requested algorithm. See the Java Secure Socket Extension
+     *                     Reference Guide for information about standard algorithm names.
+     * @param keyFile a PKCS#8 private key file in PEM format
+     * @param keyPassword the password of the {@code keyFile}.
+     *                    {@code null} if it's not password-protected.
+     * @param kmf The existing {@link KeyManagerFactory} that will be used if not {@code null}
+     * @return A {@link KeyManagerFactory} based upon a key algorithm, key file, key file password,
+     *         and a certificate chain.
+     * @deprecated will be removed.
+ */ + @Deprecated + protected static KeyManagerFactory buildKeyManagerFactory(File certChainFile, + String keyAlgorithm, File keyFile, + String keyPassword, KeyManagerFactory kmf) + throws KeyStoreException, NoSuchAlgorithmException, NoSuchPaddingException, + InvalidKeySpecException, InvalidAlgorithmParameterException, IOException, + CertificateException, KeyException, UnrecoverableKeyException { + return buildKeyManagerFactory(toX509Certificates(certChainFile), keyAlgorithm, + toPrivateKey(keyFile, keyPassword), keyPassword, kmf, KeyStore.getDefaultType()); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslEngine.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslEngine.java new file mode 100644 index 0000000..8f7fa66 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslEngine.java @@ -0,0 +1,215 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.util.internal.SuppressJava6Requirement; + +import java.nio.ByteBuffer; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; + +class JdkSslEngine extends SSLEngine implements ApplicationProtocolAccessor { + private final SSLEngine engine; + private volatile String applicationProtocol; + + JdkSslEngine(SSLEngine engine) { + this.engine = engine; + } + + @Override + public String getNegotiatedApplicationProtocol() { + return applicationProtocol; + } + + void setNegotiatedApplicationProtocol(String applicationProtocol) { + this.applicationProtocol = applicationProtocol; + } + + @Override + public SSLSession getSession() { + return engine.getSession(); + } + + public SSLEngine getWrappedEngine() { + return engine; + } + + @Override + public void closeInbound() throws SSLException { + engine.closeInbound(); + } + + @Override + public void closeOutbound() { + engine.closeOutbound(); + } + + @Override + public String getPeerHost() { + return engine.getPeerHost(); + } + + @Override + public int getPeerPort() { + return engine.getPeerPort(); + } + + @Override + public SSLEngineResult wrap(ByteBuffer byteBuffer, ByteBuffer byteBuffer2) throws SSLException { + return engine.wrap(byteBuffer, byteBuffer2); + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] byteBuffers, ByteBuffer byteBuffer) throws SSLException { + return engine.wrap(byteBuffers, byteBuffer); + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i2, ByteBuffer byteBuffer) throws SSLException { + return engine.wrap(byteBuffers, i, i2, byteBuffer); + } + + @Override + public SSLEngineResult unwrap(ByteBuffer byteBuffer, ByteBuffer byteBuffer2) throws SSLException { + return engine.unwrap(byteBuffer, byteBuffer2); + } + + @Override + public 
SSLEngineResult unwrap(ByteBuffer byteBuffer, ByteBuffer[] byteBuffers) throws SSLException { + return engine.unwrap(byteBuffer, byteBuffers); + } + + @Override + public SSLEngineResult unwrap(ByteBuffer byteBuffer, ByteBuffer[] byteBuffers, int i, int i2) throws SSLException { + return engine.unwrap(byteBuffer, byteBuffers, i, i2); + } + + @Override + public Runnable getDelegatedTask() { + return engine.getDelegatedTask(); + } + + @Override + public boolean isInboundDone() { + return engine.isInboundDone(); + } + + @Override + public boolean isOutboundDone() { + return engine.isOutboundDone(); + } + + @Override + public String[] getSupportedCipherSuites() { + return engine.getSupportedCipherSuites(); + } + + @Override + public String[] getEnabledCipherSuites() { + return engine.getEnabledCipherSuites(); + } + + @Override + public void setEnabledCipherSuites(String[] strings) { + engine.setEnabledCipherSuites(strings); + } + + @Override + public String[] getSupportedProtocols() { + return engine.getSupportedProtocols(); + } + + @Override + public String[] getEnabledProtocols() { + return engine.getEnabledProtocols(); + } + + @Override + public void setEnabledProtocols(String[] strings) { + engine.setEnabledProtocols(strings); + } + + @SuppressJava6Requirement(reason = "Can only be called when running on JDK7+") + @Override + public SSLSession getHandshakeSession() { + return engine.getHandshakeSession(); + } + + @Override + public void beginHandshake() throws SSLException { + engine.beginHandshake(); + } + + @Override + public HandshakeStatus getHandshakeStatus() { + return engine.getHandshakeStatus(); + } + + @Override + public void setUseClientMode(boolean b) { + engine.setUseClientMode(b); + } + + @Override + public boolean getUseClientMode() { + return engine.getUseClientMode(); + } + + @Override + public void setNeedClientAuth(boolean b) { + engine.setNeedClientAuth(b); + } + + @Override + public boolean getNeedClientAuth() { + return 
engine.getNeedClientAuth(); + } + + @Override + public void setWantClientAuth(boolean b) { + engine.setWantClientAuth(b); + } + + @Override + public boolean getWantClientAuth() { + return engine.getWantClientAuth(); + } + + @Override + public void setEnableSessionCreation(boolean b) { + engine.setEnableSessionCreation(b); + } + + @Override + public boolean getEnableSessionCreation() { + return engine.getEnableSessionCreation(); + } + + @Override + public SSLParameters getSSLParameters() { + return engine.getSSLParameters(); + } + + @Override + public void setSSLParameters(SSLParameters sslParameters) { + engine.setSSLParameters(sslParameters); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java new file mode 100644 index 0000000..49a2330 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/JdkSslServerContext.java @@ -0,0 +1,317 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; + +import java.security.KeyStore; +import java.security.Provider; +import javax.net.ssl.KeyManager; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import java.io.File; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +/** + * A server-side {@link SslContext} which uses JDK's SSL/TLS implementation. + * + * @deprecated Use {@link SslContextBuilder} to create {@link JdkSslContext} instances and only + * use {@link JdkSslContext} in your code. + */ +@Deprecated +public final class JdkSslServerContext extends JdkSslContext { + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext(File certChainFile, File keyFile) throws SSLException { + this(null, certChainFile, keyFile, null, null, IdentityCipherSuiteFilter.INSTANCE, + JdkDefaultApplicationProtocolNegotiator.INSTANCE, 0, 0, null); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. 
+ * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext(File certChainFile, File keyFile, String keyPassword) throws SSLException { + this(certChainFile, keyFile, keyPassword, null, IdentityCipherSuiteFilter.INSTANCE, + JdkDefaultApplicationProtocolNegotiator.INSTANCE, 0, 0); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, certChainFile, keyFile, keyPassword, ciphers, IdentityCipherSuiteFilter.INSTANCE, + toNegotiator(toApplicationProtocolConfig(nextProtocols), true), sessionCacheSize, + sessionTimeout, KeyStore.getDefaultType()); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. 
+ * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, certChainFile, keyFile, keyPassword, ciphers, cipherFilter, + toNegotiator(apn, true), sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Application Protocol Negotiator object. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, certChainFile, keyFile, keyPassword, ciphers, cipherFilter, apn, + sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()); + } + + JdkSslServerContext(Provider provider, + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, CipherSuiteFilter cipherFilter, JdkApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout, String keyStore) throws SSLException { + super(newSSLContext(provider, null, null, + toX509CertificatesInternal(certChainFile), toPrivateKeyInternal(keyFile, keyPassword), + keyPassword, null, sessionCacheSize, sessionTimeout, keyStore), false, + ciphers, cipherFilter, apn, ClientAuth.NONE, null, false); + } + + /** + * Creates a new instance. + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * This provides the certificate collection used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. 
+ * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext(File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, + KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + super(newSSLContext(null, toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory, + toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), + keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, null), false, + ciphers, cipherFilter, apn, ClientAuth.NONE, null, false); + } + + /** + * Creates a new instance. + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * This provides the certificate collection used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. 
+ * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile} + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Application Protocol Negotiator object. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. 
+ * {@code 0} to use the default value + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public JdkSslServerContext(File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, + KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, + JdkApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + super(newSSLContext(null, toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory, + toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), + keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()), false, + ciphers, cipherFilter, apn, ClientAuth.NONE, null, false); + } + + JdkSslServerContext(Provider provider, + X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, + KeyManagerFactory keyManagerFactory, Iterable ciphers, CipherSuiteFilter cipherFilter, + ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout, + ClientAuth clientAuth, String[] protocols, boolean startTls, + String keyStore) throws SSLException { + super(newSSLContext(provider, trustCertCollection, trustManagerFactory, keyCertChain, key, + keyPassword, keyManagerFactory, sessionCacheSize, sessionTimeout, keyStore), false, + ciphers, cipherFilter, toNegotiator(apn, true), clientAuth, protocols, startTls); + } + + private static SSLContext newSSLContext(Provider sslContextProvider, X509Certificate[] trustCertCollection, + TrustManagerFactory trustManagerFactory, X509Certificate[] keyCertChain, + PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + long sessionCacheSize, long sessionTimeout, String keyStore) + throws SSLException { + if (key == null && keyManagerFactory == null) { + throw new NullPointerException("key, 
keyManagerFactory"); + } + + try { + if (trustCertCollection != null) { + trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory, keyStore); + } else if (trustManagerFactory == null) { + // Mimic the way SSLContext.getInstance(KeyManager[], null, null) works + trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + } + + if (key != null) { + keyManagerFactory = buildKeyManagerFactory(keyCertChain, null, + key, keyPassword, keyManagerFactory, null); + } + + // Initialize the SSLContext to work with our key managers. + SSLContext ctx = sslContextProvider == null ? SSLContext.getInstance(PROTOCOL) + : SSLContext.getInstance(PROTOCOL, sslContextProvider); + ctx.init(keyManagerFactory.getKeyManagers(), + wrapTrustManagerIfNeeded(trustManagerFactory.getTrustManagers()), + null); + + SSLSessionContext sessCtx = ctx.getServerSessionContext(); + if (sessionCacheSize > 0) { + sessCtx.setSessionCacheSize((int) Math.min(sessionCacheSize, Integer.MAX_VALUE)); + } + if (sessionTimeout > 0) { + sessCtx.setSessionTimeout((int) Math.min(sessionTimeout, Integer.MAX_VALUE)); + } + return ctx; + } catch (Exception e) { + if (e instanceof SSLException) { + throw (SSLException) e; + } + throw new SSLException("failed to initialize the server-side SSL context", e); + } + } + + @SuppressJava6Requirement(reason = "Guarded by java version check") + private static TrustManager[] wrapTrustManagerIfNeeded(TrustManager[] trustManagers) { + if (PlatformDependent.javaVersion() >= 7) { + for (int i = 0; i < trustManagers.length; i++) { + TrustManager tm = trustManagers[i]; + if (tm instanceof X509ExtendedTrustManager) { + // Wrap the TrustManager to provide a better exception message for users to debug hostname + // validation failures. 
+                    trustManagers[i] = new EnhancingX509ExtendedTrustManager((X509ExtendedTrustManager) tm);
+                }
+            }
+        }
+        return trustManagers;
+    }
+}
diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/NotSslRecordException.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/NotSslRecordException.java
new file mode 100644
index 0000000..4e74a61
--- /dev/null
+++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/NotSslRecordException.java
@@ -0,0 +1,48 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.handler.ssl;
+
+import javax.net.ssl.SSLException;
+
+/**
+ * Special {@link SSLException} which will get thrown if a packet is
+ * received that does not look like a TLS/SSL record. A user can check for
+ * this {@link NotSslRecordException} and so detect if one peer tries to
+ * use a secure and the other a plain connection.
+ * + * + */ +public class NotSslRecordException extends SSLException { + + private static final long serialVersionUID = -4316784434770656841L; + + public NotSslRecordException() { + super(""); + } + + public NotSslRecordException(String message) { + super(message); + } + + public NotSslRecordException(Throwable cause) { + super(cause); + } + + public NotSslRecordException(String message, Throwable cause) { + super(message, cause); + } + +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSsl.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSsl.java new file mode 100644 index 0000000..78f11ee --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSsl.java @@ -0,0 +1,790 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.internal.tcnative.Buffer; +import io.netty.internal.tcnative.Library; +import io.netty.internal.tcnative.SSL; +import io.netty.internal.tcnative.SSLContext; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.NativeLibraryLoader; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import static io.netty.handler.ssl.SslUtils.*; + +/** + * Tells if {@code netty-tcnative} and its OpenSSL support + * are available. 
+ */ +public final class OpenSsl { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(OpenSsl.class); + private static final Throwable UNAVAILABILITY_CAUSE; + + static final List DEFAULT_CIPHERS; + static final Set AVAILABLE_CIPHER_SUITES; + private static final Set AVAILABLE_OPENSSL_CIPHER_SUITES; + private static final Set AVAILABLE_JAVA_CIPHER_SUITES; + private static final boolean SUPPORTS_KEYMANAGER_FACTORY; + private static final boolean USE_KEYMANAGER_FACTORY; + private static final boolean SUPPORTS_OCSP; + private static final boolean TLSV13_SUPPORTED; + private static final boolean IS_BORINGSSL; + private static final Set CLIENT_DEFAULT_PROTOCOLS; + private static final Set SERVER_DEFAULT_PROTOCOLS; + static final Set SUPPORTED_PROTOCOLS_SET; + static final String[] EXTRA_SUPPORTED_TLS_1_3_CIPHERS; + static final String EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING; + static final String[] NAMED_GROUPS; + + static final boolean JAVAX_CERTIFICATE_CREATION_SUPPORTED; + + // Use default that is supported in java 11 and earlier and also in OpenSSL / BoringSSL. 
+ // See https://github.com/netty/netty-tcnative/issues/567 + // See https://www.java.com/en/configure_crypto.html for ordering + private static final String[] DEFAULT_NAMED_GROUPS = { "x25519", "secp256r1", "secp384r1", "secp521r1" }; + + // self-signed certificate for netty.io and the matching private-key + private static final String CERT = "-----BEGIN CERTIFICATE-----\n" + + "MIICrjCCAZagAwIBAgIIdSvQPv1QAZQwDQYJKoZIhvcNAQELBQAwFjEUMBIGA1UEAxMLZXhhbXBs\n" + + "ZS5jb20wIBcNMTgwNDA2MjIwNjU5WhgPOTk5OTEyMzEyMzU5NTlaMBYxFDASBgNVBAMTC2V4YW1w\n" + + "bGUuY29tMIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAggbWsmDQ6zNzRZ5AW8E3eoGl\n" + + "qWvOBDb5Fs1oBRrVQHuYmVAoaqwDzXYJ0LOwa293AgWEQ1jpcbZ2hpoYQzqEZBTLnFhMrhRFlH6K\n" + + "bJND8Y33kZ/iSVBBDuGbdSbJShlM+4WwQ9IAso4MZ4vW3S1iv5fGGpLgbtXRmBf/RU8omN0Gijlv\n" + + "WlLWHWijLN8xQtySFuBQ7ssW8RcKAary3pUm6UUQB+Co6lnfti0Tzag8PgjhAJq2Z3wbsGRnP2YS\n" + + "vYoaK6qzmHXRYlp/PxrjBAZAmkLJs4YTm/XFF+fkeYx4i9zqHbyone5yerRibsHaXZWLnUL+rFoe\n" + + "MdKvr0VS3sGmhQIDAQABMA0GCSqGSIb3DQEBCwUAA4IBAQADQi441pKmXf9FvUV5EHU4v8nJT9Iq\n" + + "yqwsKwXnr7AsUlDGHBD7jGrjAXnG5rGxuNKBQ35wRxJATKrUtyaquFUL6H8O6aGQehiFTk6zmPbe\n" + + "12Gu44vqqTgIUxnv3JQJiox8S2hMxsSddpeCmSdvmalvD6WG4NthH6B9ZaBEiep1+0s0RUaBYn73\n" + + "I7CCUaAtbjfR6pcJjrFk5ei7uwdQZFSJtkP2z8r7zfeANJddAKFlkaMWn7u+OIVuB4XPooWicObk\n" + + "NAHFtP65bocUYnDpTVdiyvn8DdqyZ/EO8n1bBKBzuSLplk2msW4pdgaFgY7Vw/0wzcFXfUXmL1uy\n" + + "G8sQD/wx\n" + + "-----END CERTIFICATE-----"; + + private static final String KEY = "-----BEGIN PRIVATE KEY-----\n" + + "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQCCBtayYNDrM3NFnkBbwTd6gaWp\n" + + "a84ENvkWzWgFGtVAe5iZUChqrAPNdgnQs7Brb3cCBYRDWOlxtnaGmhhDOoRkFMucWEyuFEWUfops\n" + + "k0PxjfeRn+JJUEEO4Zt1JslKGUz7hbBD0gCyjgxni9bdLWK/l8YakuBu1dGYF/9FTyiY3QaKOW9a\n" + + "UtYdaKMs3zFC3JIW4FDuyxbxFwoBqvLelSbpRRAH4KjqWd+2LRPNqDw+COEAmrZnfBuwZGc/ZhK9\n" + + "ihorqrOYddFiWn8/GuMEBkCaQsmzhhOb9cUX5+R5jHiL3OodvKid7nJ6tGJuwdpdlYudQv6sWh4x\n" + + 
"0q+vRVLewaaFAgMBAAECggEAP8tPJvFtTxhNJAkCloHz0D0vpDHqQBMgntlkgayqmBqLwhyb18pR\n" + + "i0qwgh7HHc7wWqOOQuSqlEnrWRrdcI6TSe8R/sErzfTQNoznKWIPYcI/hskk4sdnQ//Yn9/Jvnsv\n" + + "U/BBjOTJxtD+sQbhAl80JcA3R+5sArURQkfzzHOL/YMqzAsn5hTzp7HZCxUqBk3KaHRxV7NefeOE\n" + + "xlZuWSmxYWfbFIs4kx19/1t7h8CHQWezw+G60G2VBtSBBxDnhBWvqG6R/wpzJ3nEhPLLY9T+XIHe\n" + + "ipzdMOOOUZorfIg7M+pyYPji+ZIZxIpY5OjrOzXHciAjRtr5Y7l99K1CG1LguQKBgQDrQfIMxxtZ\n" + + "vxU/1cRmUV9l7pt5bjV5R6byXq178LxPKVYNjdZ840Q0/OpZEVqaT1xKVi35ohP1QfNjxPLlHD+K\n" + + "iDAR9z6zkwjIrbwPCnb5kuXy4lpwPcmmmkva25fI7qlpHtbcuQdoBdCfr/KkKaUCMPyY89LCXgEw\n" + + "5KTDj64UywKBgQCNfbO+eZLGzhiHhtNJurresCsIGWlInv322gL8CSfBMYl6eNfUTZvUDdFhPISL\n" + + "UljKWzXDrjw0ujFSPR0XhUGtiq89H+HUTuPPYv25gVXO+HTgBFZEPl4PpA+BUsSVZy0NddneyqLk\n" + + "42Wey9omY9Q8WsdNQS5cbUvy0uG6WFoX7wKBgQDZ1jpW8pa0x2bZsQsm4vo+3G5CRnZlUp+XlWt2\n" + + "dDcp5dC0xD1zbs1dc0NcLeGDOTDv9FSl7hok42iHXXq8AygjEm/QcuwwQ1nC2HxmQP5holAiUs4D\n" + + "WHM8PWs3wFYPzE459EBoKTxeaeP/uWAn+he8q7d5uWvSZlEcANs/6e77eQKBgD21Ar0hfFfj7mK8\n" + + "9E0FeRZBsqK3omkfnhcYgZC11Xa2SgT1yvs2Va2n0RcdM5kncr3eBZav2GYOhhAdwyBM55XuE/sO\n" + + "eokDVutNeuZ6d5fqV96TRaRBpvgfTvvRwxZ9hvKF4Vz+9wfn/JvCwANaKmegF6ejs7pvmF3whq2k\n" + + "drZVAoGAX5YxQ5XMTD0QbMAl7/6qp6S58xNoVdfCkmkj1ZLKaHKIjS/benkKGlySVQVPexPfnkZx\n" + + "p/Vv9yyphBoudiTBS9Uog66ueLYZqpgxlM/6OhYg86Gm3U2ycvMxYjBM1NFiyze21AqAhI+HX+Ot\n" + + "mraV2/guSgDgZAhukRZzeQ2RucI=\n" + + "-----END PRIVATE KEY-----"; + + static { + Throwable cause = null; + + if (SystemPropertyUtil.getBoolean("io.netty.handler.ssl.noOpenSsl", false)) { + cause = new UnsupportedOperationException( + "OpenSSL was explicit disabled with -Dio.netty.handler.ssl.noOpenSsl=true"); + + logger.debug( + "netty-tcnative explicit disabled; " + + OpenSslEngine.class.getSimpleName() + " will be unavailable.", cause); + } else { + // Test if netty-tcnative is in the classpath first. 
+ try { + Class.forName("io.netty.internal.tcnative.SSLContext", false, + PlatformDependent.getClassLoader(OpenSsl.class)); + } catch (ClassNotFoundException t) { + cause = t; + logger.debug( + "netty-tcnative not in the classpath; " + + OpenSslEngine.class.getSimpleName() + " will be unavailable."); + } + + // If in the classpath, try to load the native library and initialize netty-tcnative. + if (cause == null) { + try { + // The JNI library was not already loaded. Load it now. + loadTcNative(); + } catch (Throwable t) { + cause = t; + logger.debug( + "Failed to load netty-tcnative; " + + OpenSslEngine.class.getSimpleName() + " will be unavailable, unless the " + + "application has already loaded the symbols by some other means. " + + "See https://netty.io/wiki/forked-tomcat-native.html for more information.", t); + } + + try { + String engine = SystemPropertyUtil.get("io.netty.handler.ssl.openssl.engine", null); + if (engine == null) { + logger.debug("Initialize netty-tcnative using engine: 'default'"); + } else { + logger.debug("Initialize netty-tcnative using engine: '{}'", engine); + } + initializeTcNative(engine); + + // The library was initialized successfully. If loading the library failed above, + // reset the cause now since it appears that the library was loaded by some other + // means. + cause = null; + } catch (Throwable t) { + if (cause == null) { + cause = t; + } + logger.debug( + "Failed to initialize netty-tcnative; " + + OpenSslEngine.class.getSimpleName() + " will be unavailable. 
" + + "See https://netty.io/wiki/forked-tomcat-native.html for more information.", t); + } + } + } + + UNAVAILABILITY_CAUSE = cause; + CLIENT_DEFAULT_PROTOCOLS = protocols("jdk.tls.client.protocols"); + SERVER_DEFAULT_PROTOCOLS = protocols("jdk.tls.server.protocols"); + + if (cause == null) { + logger.debug("netty-tcnative using native library: {}", SSL.versionString()); + + final List defaultCiphers = new ArrayList(); + final Set availableOpenSslCipherSuites = new LinkedHashSet(128); + boolean supportsKeyManagerFactory = false; + boolean useKeyManagerFactory = false; + boolean tlsv13Supported = false; + String[] namedGroups = DEFAULT_NAMED_GROUPS; + String[] defaultConvertedNamedGroups = new String[namedGroups.length]; + for (int i = 0; i < namedGroups.length; i++) { + defaultConvertedNamedGroups[i] = GroupsConverter.toOpenSsl(namedGroups[i]); + } + + IS_BORINGSSL = "BoringSSL".equals(versionString()); + if (IS_BORINGSSL) { + EXTRA_SUPPORTED_TLS_1_3_CIPHERS = new String [] { "TLS_AES_128_GCM_SHA256", + "TLS_AES_256_GCM_SHA384" , + "TLS_CHACHA20_POLY1305_SHA256" }; + + StringBuilder ciphersBuilder = new StringBuilder(128); + for (String cipher: EXTRA_SUPPORTED_TLS_1_3_CIPHERS) { + ciphersBuilder.append(cipher).append(":"); + } + ciphersBuilder.setLength(ciphersBuilder.length() - 1); + EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING = ciphersBuilder.toString(); + } else { + EXTRA_SUPPORTED_TLS_1_3_CIPHERS = EmptyArrays.EMPTY_STRINGS; + EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING = StringUtil.EMPTY_STRING; + } + + try { + final long sslCtx = SSLContext.make(SSL.SSL_PROTOCOL_ALL, SSL.SSL_MODE_SERVER); + long certBio = 0; + long keyBio = 0; + long cert = 0; + long key = 0; + try { + // As we delegate to the KeyManager / TrustManager of the JDK we need to ensure it can actually + // handle TLSv13 as otherwise we may see runtime exceptions + if (SslProvider.isTlsv13Supported(SslProvider.JDK)) { + try { + StringBuilder tlsv13Ciphers = new StringBuilder(); + + for (String cipher : 
TLSV13_CIPHERS) { + String converted = CipherSuiteConverter.toOpenSsl(cipher, IS_BORINGSSL); + if (converted != null) { + tlsv13Ciphers.append(converted).append(':'); + } + } + if (tlsv13Ciphers.length() == 0) { + tlsv13Supported = false; + } else { + tlsv13Ciphers.setLength(tlsv13Ciphers.length() - 1); + SSLContext.setCipherSuite(sslCtx, tlsv13Ciphers.toString(), true); + tlsv13Supported = true; + } + + } catch (Exception ignore) { + tlsv13Supported = false; + } + } + + SSLContext.setCipherSuite(sslCtx, "ALL", false); + + final long ssl = SSL.newSSL(sslCtx, true); + try { + for (String c: SSL.getCiphers(ssl)) { + // Filter out bad input. + if (c == null || c.isEmpty() || availableOpenSslCipherSuites.contains(c) || + // Filter out TLSv1.3 ciphers if not supported. + !tlsv13Supported && isTLSv13Cipher(c)) { + continue; + } + availableOpenSslCipherSuites.add(c); + } + if (IS_BORINGSSL) { + // Currently BoringSSL does not include these when calling SSL.getCiphers() even when these + // are supported. + Collections.addAll(availableOpenSslCipherSuites, EXTRA_SUPPORTED_TLS_1_3_CIPHERS); + Collections.addAll(availableOpenSslCipherSuites, + "AEAD-AES128-GCM-SHA256", + "AEAD-AES256-GCM-SHA384", + "AEAD-CHACHA20-POLY1305-SHA256"); + } + + PemEncoded privateKey = PemPrivateKey.valueOf(KEY.getBytes(CharsetUtil.US_ASCII)); + try { + // Let's check if we can set a callback, which may not work if the used OpenSSL version + // is to old. 
+ SSLContext.setCertificateCallback(sslCtx, null); + + X509Certificate certificate = selfSignedCertificate(); + certBio = ReferenceCountedOpenSslContext.toBIO(ByteBufAllocator.DEFAULT, certificate); + cert = SSL.parseX509Chain(certBio); + + keyBio = ReferenceCountedOpenSslContext.toBIO( + UnpooledByteBufAllocator.DEFAULT, privateKey.retain()); + key = SSL.parsePrivateKey(keyBio, null); + + SSL.setKeyMaterial(ssl, cert, key); + supportsKeyManagerFactory = true; + try { + boolean propertySet = SystemPropertyUtil.contains( + "io.netty.handler.ssl.openssl.useKeyManagerFactory"); + if (!IS_BORINGSSL) { + useKeyManagerFactory = SystemPropertyUtil.getBoolean( + "io.netty.handler.ssl.openssl.useKeyManagerFactory", true); + + if (propertySet) { + logger.info("System property " + + "'io.netty.handler.ssl.openssl.useKeyManagerFactory'" + + " is deprecated and so will be ignored in the future"); + } + } else { + useKeyManagerFactory = true; + if (propertySet) { + logger.info("System property " + + "'io.netty.handler.ssl.openssl.useKeyManagerFactory'" + + " is deprecated and will be ignored when using BoringSSL"); + } + } + } catch (Throwable ignore) { + logger.debug("Failed to get useKeyManagerFactory system property."); + } + } catch (Error ignore) { + logger.debug("KeyManagerFactory not supported."); + } finally { + privateKey.release(); + } + } finally { + SSL.freeSSL(ssl); + if (certBio != 0) { + SSL.freeBIO(certBio); + } + if (keyBio != 0) { + SSL.freeBIO(keyBio); + } + if (cert != 0) { + SSL.freeX509Chain(cert); + } + if (key != 0) { + SSL.freePrivateKey(key); + } + } + + String groups = SystemPropertyUtil.get("jdk.tls.namedGroups", null); + if (groups != null) { + String[] nGroups = groups.split(","); + Set supportedNamedGroups = new LinkedHashSet(nGroups.length); + Set supportedConvertedNamedGroups = new LinkedHashSet(nGroups.length); + + Set unsupportedNamedGroups = new LinkedHashSet(); + for (String namedGroup : nGroups) { + String converted = 
GroupsConverter.toOpenSsl(namedGroup); + if (SSLContext.setCurvesList(sslCtx, converted)) { + supportedConvertedNamedGroups.add(converted); + supportedNamedGroups.add(namedGroup); + } else { + unsupportedNamedGroups.add(namedGroup); + } + } + + if (supportedNamedGroups.isEmpty()) { + namedGroups = defaultConvertedNamedGroups; + logger.info("All configured namedGroups are not supported: {}. Use default: {}.", + Arrays.toString(unsupportedNamedGroups.toArray(EmptyArrays.EMPTY_STRINGS)), + Arrays.toString(DEFAULT_NAMED_GROUPS)); + } else { + String[] groupArray = supportedNamedGroups.toArray(EmptyArrays.EMPTY_STRINGS); + if (unsupportedNamedGroups.isEmpty()) { + logger.info("Using configured namedGroups -D 'jdk.tls.namedGroup': {} ", + Arrays.toString(groupArray)); + } else { + logger.info("Using supported configured namedGroups: {}. Unsupported namedGroups: {}. ", + Arrays.toString(groupArray), + Arrays.toString(unsupportedNamedGroups.toArray(EmptyArrays.EMPTY_STRINGS))); + } + namedGroups = supportedConvertedNamedGroups.toArray(EmptyArrays.EMPTY_STRINGS); + } + } else { + namedGroups = defaultConvertedNamedGroups; + } + } finally { + SSLContext.free(sslCtx); + } + } catch (Exception e) { + logger.warn("Failed to get the list of available OpenSSL cipher suites.", e); + } + NAMED_GROUPS = namedGroups; + AVAILABLE_OPENSSL_CIPHER_SUITES = Collections.unmodifiableSet(availableOpenSslCipherSuites); + final Set availableJavaCipherSuites = new LinkedHashSet( + AVAILABLE_OPENSSL_CIPHER_SUITES.size() * 2); + for (String cipher: AVAILABLE_OPENSSL_CIPHER_SUITES) { + // Included converted but also openssl cipher name + if (!isTLSv13Cipher(cipher)) { + availableJavaCipherSuites.add(CipherSuiteConverter.toJava(cipher, "TLS")); + availableJavaCipherSuites.add(CipherSuiteConverter.toJava(cipher, "SSL")); + } else { + // TLSv1.3 ciphers have the correct format. 
+ availableJavaCipherSuites.add(cipher); + } + } + + addIfSupported(availableJavaCipherSuites, defaultCiphers, DEFAULT_CIPHER_SUITES); + addIfSupported(availableJavaCipherSuites, defaultCiphers, TLSV13_CIPHER_SUITES); + // Also handle the extra supported ciphers as these will contain some more stuff on BoringSSL. + addIfSupported(availableJavaCipherSuites, defaultCiphers, EXTRA_SUPPORTED_TLS_1_3_CIPHERS); + + useFallbackCiphersIfDefaultIsEmpty(defaultCiphers, availableJavaCipherSuites); + DEFAULT_CIPHERS = Collections.unmodifiableList(defaultCiphers); + + AVAILABLE_JAVA_CIPHER_SUITES = Collections.unmodifiableSet(availableJavaCipherSuites); + + final Set availableCipherSuites = new LinkedHashSet( + AVAILABLE_OPENSSL_CIPHER_SUITES.size() + AVAILABLE_JAVA_CIPHER_SUITES.size()); + availableCipherSuites.addAll(AVAILABLE_OPENSSL_CIPHER_SUITES); + availableCipherSuites.addAll(AVAILABLE_JAVA_CIPHER_SUITES); + + AVAILABLE_CIPHER_SUITES = availableCipherSuites; + SUPPORTS_KEYMANAGER_FACTORY = supportsKeyManagerFactory; + USE_KEYMANAGER_FACTORY = useKeyManagerFactory; + + Set protocols = new LinkedHashSet(6); + // Seems like there is no way to explicitly disable SSLv2Hello in openssl so it is always enabled + protocols.add(SslProtocols.SSL_v2_HELLO); + if (doesSupportProtocol(SSL.SSL_PROTOCOL_SSLV2, SSL.SSL_OP_NO_SSLv2)) { + protocols.add(SslProtocols.SSL_v2); + } + if (doesSupportProtocol(SSL.SSL_PROTOCOL_SSLV3, SSL.SSL_OP_NO_SSLv3)) { + protocols.add(SslProtocols.SSL_v3); + } + if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1, SSL.SSL_OP_NO_TLSv1)) { + protocols.add(SslProtocols.TLS_v1); + } + if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_1, SSL.SSL_OP_NO_TLSv1_1)) { + protocols.add(SslProtocols.TLS_v1_1); + } + if (doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_2, SSL.SSL_OP_NO_TLSv1_2)) { + protocols.add(SslProtocols.TLS_v1_2); + } + + // This is only supported by java8u272 and later. 
+ if (tlsv13Supported && doesSupportProtocol(SSL.SSL_PROTOCOL_TLSV1_3, SSL.SSL_OP_NO_TLSv1_3)) { + protocols.add(SslProtocols.TLS_v1_3); + TLSV13_SUPPORTED = true; + } else { + TLSV13_SUPPORTED = false; + } + + SUPPORTED_PROTOCOLS_SET = Collections.unmodifiableSet(protocols); + SUPPORTS_OCSP = doesSupportOcsp(); + + if (logger.isDebugEnabled()) { + logger.debug("Supported protocols (OpenSSL): {} ", SUPPORTED_PROTOCOLS_SET); + logger.debug("Default cipher suites (OpenSSL): {}", DEFAULT_CIPHERS); + } + + // Check if we can create a javax.security.cert.X509Certificate from our cert. This might fail on + // JDK17 and above. In this case we will later throw an UnsupportedOperationException if someone + // tries to access these via SSLSession. See https://github.com/netty/netty/issues/13560. + boolean javaxCertificateCreationSupported; + try { + javax.security.cert.X509Certificate.getInstance(CERT.getBytes(CharsetUtil.US_ASCII)); + javaxCertificateCreationSupported = true; + } catch (javax.security.cert.CertificateException ex) { + javaxCertificateCreationSupported = false; + } + JAVAX_CERTIFICATE_CREATION_SUPPORTED = javaxCertificateCreationSupported; + } else { + DEFAULT_CIPHERS = Collections.emptyList(); + AVAILABLE_OPENSSL_CIPHER_SUITES = Collections.emptySet(); + AVAILABLE_JAVA_CIPHER_SUITES = Collections.emptySet(); + AVAILABLE_CIPHER_SUITES = Collections.emptySet(); + SUPPORTS_KEYMANAGER_FACTORY = false; + USE_KEYMANAGER_FACTORY = false; + SUPPORTED_PROTOCOLS_SET = Collections.emptySet(); + SUPPORTS_OCSP = false; + TLSV13_SUPPORTED = false; + IS_BORINGSSL = false; + EXTRA_SUPPORTED_TLS_1_3_CIPHERS = EmptyArrays.EMPTY_STRINGS; + EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING = StringUtil.EMPTY_STRING; + NAMED_GROUPS = DEFAULT_NAMED_GROUPS; + JAVAX_CERTIFICATE_CREATION_SUPPORTED = false; + } + } + + static String checkTls13Ciphers(InternalLogger logger, String ciphers) { + if (IS_BORINGSSL && !ciphers.isEmpty()) { + assert EXTRA_SUPPORTED_TLS_1_3_CIPHERS.length > 0; + Set 
boringsslTlsv13Ciphers = new HashSet(EXTRA_SUPPORTED_TLS_1_3_CIPHERS.length); + Collections.addAll(boringsslTlsv13Ciphers, EXTRA_SUPPORTED_TLS_1_3_CIPHERS); + boolean ciphersNotMatch = false; + for (String cipher: ciphers.split(":")) { + if (boringsslTlsv13Ciphers.isEmpty()) { + ciphersNotMatch = true; + break; + } + if (!boringsslTlsv13Ciphers.remove(cipher) && + !boringsslTlsv13Ciphers.remove(CipherSuiteConverter.toJava(cipher, "TLS"))) { + ciphersNotMatch = true; + break; + } + } + + // Also check if there are ciphers left. + ciphersNotMatch |= !boringsslTlsv13Ciphers.isEmpty(); + + if (ciphersNotMatch) { + if (logger.isInfoEnabled()) { + StringBuilder javaCiphers = new StringBuilder(128); + for (String cipher : ciphers.split(":")) { + javaCiphers.append(CipherSuiteConverter.toJava(cipher, "TLS")).append(":"); + } + javaCiphers.setLength(javaCiphers.length() - 1); + logger.info( + "BoringSSL doesn't allow to enable or disable TLSv1.3 ciphers explicitly." + + " Provided TLSv1.3 ciphers: '{}', default TLSv1.3 ciphers that will be used: '{}'.", + javaCiphers, EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING); + } + return EXTRA_SUPPORTED_TLS_1_3_CIPHERS_STRING; + } + } + return ciphers; + } + + static boolean isSessionCacheSupported() { + return version() >= 0x10100000L; + } + + /** + * Returns a self-signed {@link X509Certificate} for {@code netty.io}. 
+ */ + static X509Certificate selfSignedCertificate() throws CertificateException { + return (X509Certificate) SslContext.X509_CERT_FACTORY.generateCertificate( + new ByteArrayInputStream(CERT.getBytes(CharsetUtil.US_ASCII)) + ); + } + + private static boolean doesSupportOcsp() { + boolean supportsOcsp = false; + if (version() >= 0x10002000L) { + long sslCtx = -1; + try { + sslCtx = SSLContext.make(SSL.SSL_PROTOCOL_TLSV1_2, SSL.SSL_MODE_SERVER); + SSLContext.enableOcsp(sslCtx, false); + supportsOcsp = true; + } catch (Exception ignore) { + // ignore + } finally { + if (sslCtx != -1) { + SSLContext.free(sslCtx); + } + } + } + return supportsOcsp; + } + private static boolean doesSupportProtocol(int protocol, int opt) { + if (opt == 0) { + // If the opt is 0 the protocol is not supported. This is for example the case with BoringSSL and SSLv2. + return false; + } + long sslCtx = -1; + try { + sslCtx = SSLContext.make(protocol, SSL.SSL_MODE_COMBINED); + return true; + } catch (Exception ignore) { + return false; + } finally { + if (sslCtx != -1) { + SSLContext.free(sslCtx); + } + } + } + + /** + * Returns {@code true} if and only if + * {@code netty-tcnative} and its OpenSSL support + * are available. + */ + public static boolean isAvailable() { + return UNAVAILABILITY_CAUSE == null; + } + + /** + * Returns {@code true} if the used version of openssl supports + * ALPN. + * + * @deprecated use {@link SslProvider#isAlpnSupported(SslProvider)} with {@link SslProvider#OPENSSL}. + */ + @Deprecated + public static boolean isAlpnSupported() { + return version() >= 0x10002000L; + } + + /** + * Returns {@code true} if the used version of OpenSSL supports OCSP stapling. + */ + public static boolean isOcspSupported() { + return SUPPORTS_OCSP; + } + + /** + * Returns the version of the used available OpenSSL library or {@code -1} if {@link #isAvailable()} + * returns {@code false}. + */ + public static int version() { + return isAvailable() ? 
SSL.version() : -1; + } + + /** + * Returns the version string of the used available OpenSSL library or {@code null} if {@link #isAvailable()} + * returns {@code false}. + */ + public static String versionString() { + return isAvailable() ? SSL.versionString() : null; + } + + /** + * Ensure that {@code netty-tcnative} and + * its OpenSSL support are available. + * + * @throws UnsatisfiedLinkError if unavailable + */ + public static void ensureAvailability() { + if (UNAVAILABILITY_CAUSE != null) { + throw (Error) new UnsatisfiedLinkError( + "failed to load the required native library").initCause(UNAVAILABILITY_CAUSE); + } + } + + /** + * Returns the cause of unavailability of + * {@code netty-tcnative} and its OpenSSL support. + * + * @return the cause if unavailable. {@code null} if available. + */ + public static Throwable unavailabilityCause() { + return UNAVAILABILITY_CAUSE; + } + + /** + * @deprecated use {@link #availableOpenSslCipherSuites()} + */ + @Deprecated + public static Set availableCipherSuites() { + return availableOpenSslCipherSuites(); + } + + /** + * Returns all the available OpenSSL cipher suites. + * Please note that the returned array may include the cipher suites that are insecure or non-functional. + */ + public static Set availableOpenSslCipherSuites() { + return AVAILABLE_OPENSSL_CIPHER_SUITES; + } + + /** + * Returns all the available cipher suites (Java-style). + * Please note that the returned array may include the cipher suites that are insecure or non-functional. + */ + public static Set availableJavaCipherSuites() { + return AVAILABLE_JAVA_CIPHER_SUITES; + } + + /** + * Returns {@code true} if and only if the specified cipher suite is available in OpenSSL. + * Both Java-style cipher suite and OpenSSL-style cipher suite are accepted. 
+ */ + public static boolean isCipherSuiteAvailable(String cipherSuite) { + String converted = CipherSuiteConverter.toOpenSsl(cipherSuite, IS_BORINGSSL); + if (converted != null) { + cipherSuite = converted; + } + return AVAILABLE_OPENSSL_CIPHER_SUITES.contains(cipherSuite); + } + + /** + * Returns {@code true} if {@link javax.net.ssl.KeyManagerFactory} is supported when using OpenSSL. + */ + public static boolean supportsKeyManagerFactory() { + return SUPPORTS_KEYMANAGER_FACTORY; + } + + /** + * Always returns {@code true} if {@link #isAvailable()} returns {@code true}. + * + * @deprecated Will be removed because hostname validation is always done by a + * {@link javax.net.ssl.TrustManager} implementation. + */ + @Deprecated + public static boolean supportsHostnameValidation() { + return isAvailable(); + } + + static boolean useKeyManagerFactory() { + return USE_KEYMANAGER_FACTORY; + } + + static long memoryAddress(ByteBuf buf) { + assert buf.isDirect(); + return buf.hasMemoryAddress() ? buf.memoryAddress() : + // Use internalNioBuffer to reduce object creation. + Buffer.address(buf.internalNioBuffer(0, buf.readableBytes())); + } + + private OpenSsl() { } + + private static void loadTcNative() throws Exception { + String os = PlatformDependent.normalizedOs(); + String arch = PlatformDependent.normalizedArch(); + + Set libNames = new LinkedHashSet(5); + String staticLibName = "netty_tcnative"; + + // First, try loading the platform-specific library. Platform-specific + // libraries will be available if using a tcnative uber jar. + if ("linux".equals(os)) { + Set classifiers = PlatformDependent.normalizedLinuxClassifiers(); + for (String classifier : classifiers) { + libNames.add(staticLibName + "_" + os + '_' + arch + "_" + classifier); + } + // generic arch-dependent library + libNames.add(staticLibName + "_" + os + '_' + arch); + + // Fedora SSL lib so naming (libssl.so.10 vs libssl.so.1.0.0). 
+ // note: should already be included from the classifiers but if not, we use this as an + // additional fallback option here + libNames.add(staticLibName + "_" + os + '_' + arch + "_fedora"); + } else { + libNames.add(staticLibName + "_" + os + '_' + arch); + } + libNames.add(staticLibName + "_" + arch); + libNames.add(staticLibName); + + NativeLibraryLoader.loadFirstAvailable(PlatformDependent.getClassLoader(SSLContext.class), + libNames.toArray(EmptyArrays.EMPTY_STRINGS)); + } + + private static boolean initializeTcNative(String engine) throws Exception { + return Library.initialize("provided", engine); + } + + static void releaseIfNeeded(ReferenceCounted counted) { + if (counted.refCnt() > 0) { + ReferenceCountUtil.safeRelease(counted); + } + } + + static boolean isTlsv13Supported() { + return TLSV13_SUPPORTED; + } + + static boolean isOptionSupported(SslContextOption option) { + if (isAvailable()) { + if (option == OpenSslContextOption.USE_TASKS) { + return true; + } + // Check for options that are only supported by BoringSSL atm. + if (isBoringSSL()) { + return option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD || + option == OpenSslContextOption.PRIVATE_KEY_METHOD || + option == OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS || + option == OpenSslContextOption.TLS_FALSE_START || + option == OpenSslContextOption.MAX_CERTIFICATE_LIST_BYTES; + } + } + return false; + } + + private static Set protocols(String property) { + String protocolsString = SystemPropertyUtil.get(property, null); + if (protocolsString != null) { + Set protocols = new HashSet(); + for (String proto : protocolsString.split(",")) { + String p = proto.trim(); + protocols.add(p); + } + return protocols; + } + return null; + } + + static String[] defaultProtocols(boolean isClient) { + final Collection defaultProtocols = isClient ? 
CLIENT_DEFAULT_PROTOCOLS : SERVER_DEFAULT_PROTOCOLS; + if (defaultProtocols == null) { + return null; + } + List protocols = new ArrayList(defaultProtocols.size()); + for (String proto : defaultProtocols) { + if (SUPPORTED_PROTOCOLS_SET.contains(proto)) { + protocols.add(proto); + } + } + return protocols.toArray(EmptyArrays.EMPTY_STRINGS); + } + + static boolean isBoringSSL() { + return IS_BORINGSSL; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslApplicationProtocolNegotiator.java new file mode 100644 index 0000000..98349ff --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslApplicationProtocolNegotiator.java @@ -0,0 +1,40 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +/** + * OpenSSL version of {@link ApplicationProtocolNegotiator}. + * + * @deprecated use {@link ApplicationProtocolConfig} + */ +@Deprecated +public interface OpenSslApplicationProtocolNegotiator extends ApplicationProtocolNegotiator { + + /** + * Returns the {@link ApplicationProtocolConfig.Protocol} which should be used. + */ + ApplicationProtocolConfig.Protocol protocol(); + + /** + * Get the desired behavior for the peer who selects the application protocol. 
+ */ + ApplicationProtocolConfig.SelectorFailureBehavior selectorFailureBehavior(); + + /** + * Get the desired behavior for the peer who is notified of the selected protocol. + */ + ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedListenerFailureBehavior(); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java new file mode 100644 index 0000000..27edaa6 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslAsyncPrivateKeyMethod.java @@ -0,0 +1,58 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSLPrivateKeyMethod; +import io.netty.util.concurrent.Future; + +import javax.net.ssl.SSLEngine; + +public interface OpenSslAsyncPrivateKeyMethod { + int SSL_SIGN_RSA_PKCS1_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA1; + int SSL_SIGN_RSA_PKCS1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256; + int SSL_SIGN_RSA_PKCS1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384; + int SSL_SIGN_RSA_PKCS1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512; + int SSL_SIGN_ECDSA_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SHA1; + int SSL_SIGN_ECDSA_SECP256R1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256; + int SSL_SIGN_ECDSA_SECP384R1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384; + int SSL_SIGN_ECDSA_SECP521R1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512; + int SSL_SIGN_RSA_PSS_RSAE_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256; + int SSL_SIGN_RSA_PSS_RSAE_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384; + int SSL_SIGN_RSA_PSS_RSAE_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512; + int SSL_SIGN_ED25519 = SSLPrivateKeyMethod.SSL_SIGN_ED25519; + int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_MD5_SHA1; + + /** + * Signs the input with the given key and notifies the returned {@link Future} with the signed bytes. + * + * @param engine the {@link SSLEngine} + * @param signatureAlgorithm the algorithm to use for signing + * @param input the digest itself + * @return the {@link Future} that will be notified with the signed data + * (must not be {@code null}) when the operation completes. + */ + Future sign(SSLEngine engine, int signatureAlgorithm, byte[] input); + + /** + * Decrypts the input with the given key and notifies the returned {@link Future} with the decrypted bytes. 
+ * + * @param engine the {@link SSLEngine} + * @param input the input which should be decrypted + * @return the {@link Future} that will be notified with the decrypted data + * (must not be {@code null}) when the operation completes. + */ + Future decrypt(SSLEngine engine, byte[] input); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java new file mode 100644 index 0000000..a55007d --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProvider.java @@ -0,0 +1,79 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; + +import javax.net.ssl.X509KeyManager; +import java.util.Iterator; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; + +/** + * {@link OpenSslKeyMaterialProvider} that will cache the {@link OpenSslKeyMaterial} to reduce the overhead + * of parsing the chain and the key for generation of the material. 
+ */ +final class OpenSslCachingKeyMaterialProvider extends OpenSslKeyMaterialProvider { + + private final int maxCachedEntries; + private volatile boolean full; + private final ConcurrentMap cache = new ConcurrentHashMap(); + + OpenSslCachingKeyMaterialProvider(X509KeyManager keyManager, String password, int maxCachedEntries) { + super(keyManager, password); + this.maxCachedEntries = maxCachedEntries; + } + + @Override + OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception { + OpenSslKeyMaterial material = cache.get(alias); + if (material == null) { + material = super.chooseKeyMaterial(allocator, alias); + if (material == null) { + // No keymaterial should be used. + return null; + } + + if (full) { + return material; + } + if (cache.size() > maxCachedEntries) { + full = true; + // Do not cache... + return material; + } + OpenSslKeyMaterial old = cache.putIfAbsent(alias, material); + if (old != null) { + material.release(); + material = old; + } + } + // We need to call retain() as we want to always have at least a refCnt() of 1 before destroy() was called. + return material.retain(); + } + + @Override + void destroy() { + // Remove and release all entries. + do { + Iterator iterator = cache.values().iterator(); + while (iterator.hasNext()) { + iterator.next().release(); + iterator.remove(); + } + } while (!cache.isEmpty()); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java new file mode 100644 index 0000000..7f644e2 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslCachingX509KeyManagerFactory.java @@ -0,0 +1,80 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.ssl;

import io.netty.util.internal.ObjectUtil;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.KeyManagerFactorySpi;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.X509KeyManager;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.X509Certificate;

/**
 * Wraps another {@link KeyManagerFactory} and caches its chains / certs for an alias for better performance when using
 * {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT}.
 * <p>
 * Because of the caching it is important that the wrapped {@link KeyManagerFactory}s {@link X509KeyManager}s always
 * return the same {@link X509Certificate} chain and {@link PrivateKey} for the same alias.
 */
public final class OpenSslCachingX509KeyManagerFactory extends KeyManagerFactory {

    private final int maxCachedEntries;

    public OpenSslCachingX509KeyManagerFactory(final KeyManagerFactory factory) {
        // Default cache size of 1024 aliases.
        this(factory, 1024);
    }

    public OpenSslCachingX509KeyManagerFactory(final KeyManagerFactory factory, int maxCachedEntries) {
        // Delegate all SPI operations to the wrapped factory.
        super(new KeyManagerFactorySpi() {
            @Override
            protected void engineInit(KeyStore keyStore, char[] chars)
                    throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
                factory.init(keyStore, chars);
            }

            @Override
            protected void engineInit(ManagerFactoryParameters managerFactoryParameters)
                    throws InvalidAlgorithmParameterException {
                factory.init(managerFactoryParameters);
            }

            @Override
            protected KeyManager[] engineGetKeyManagers() {
                return factory.getKeyManagers();
            }
        }, factory.getProvider(), factory.getAlgorithm());
        this.maxCachedEntries = ObjectUtil.checkPositive(maxCachedEntries, "maxCachedEntries");
    }

    OpenSslKeyMaterialProvider newProvider(String password) {
        X509KeyManager keyManager = ReferenceCountedOpenSslContext.chooseX509KeyManager(getKeyManagers());
        if ("sun.security.ssl.X509KeyManagerImpl".equals(keyManager.getClass().getName())) {
            // Don't do caching if X509KeyManagerImpl is used as the returned aliases are not stable and will change
            // between invocations.
            return new OpenSslKeyMaterialProvider(keyManager, password);
        }
        // Reuse the already-chosen key manager instead of resolving it a second time.
        return new OpenSslCachingKeyMaterialProvider(keyManager, password, maxCachedEntries);
    }
}
/**
 * Provides compression and decompression implementations for TLS Certificate Compression
 * (<a href="https://datatracker.ietf.org/doc/html/rfc8879">RFC 8879</a>).
 */
public interface OpenSslCertificateCompressionAlgorithm {

    /**
     * Compress the given input with the specified algorithm and return the compressed bytes.
     *
     * @param engine the {@link SSLEngine}
     * @param uncompressedCertificate the uncompressed certificate
     * @return the compressed form of the certificate
     * @throws Exception thrown if an error occurs while compressing
     */
    byte[] compress(SSLEngine engine, byte[] uncompressedCertificate) throws Exception;

    /**
     * Decompress the given input with the specified algorithm and return the decompressed bytes.
     *
     * <p>Implementation
     * <a href="https://datatracker.ietf.org/doc/html/rfc8879#section-5">Security Considerations</a>:
     * implementations SHOULD bound the memory usage when decompressing the CompressedCertificate message.
     * Implementations MUST limit the size of the resulting decompressed chain to the specified
     * {@code uncompressedLen}, and they MUST abort the connection (throw an exception) if the size of
     * the output of the decompression function exceeds that limit.
     *
     * @param engine the {@link SSLEngine}
     * @param uncompressedLen the expected length of the decompressed certificate that will be returned
     * @param compressedCertificate the compressed form of the certificate
     * @return the decompressed form of the certificate
     * @throws Exception thrown if an error occurs while decompressing or output size exceeds
     *                   {@code uncompressedLen}
     */
    byte[] decompress(SSLEngine engine, int uncompressedLen, byte[] compressedCertificate) throws Exception;

    /**
     * Return the ID for the compression algorithm provided for by a given implementation.
     *
     * @return compression algorithm ID as specified by
     *         <a href="https://datatracker.ietf.org/doc/html/rfc8879">RFC 8879</a>.
     */
    int algorithmId();
}
package io.netty.handler.ssl;

import io.netty.util.internal.ObjectUtil;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Iterator;
import java.util.List;

/**
 * Configuration for the TLS1.3 certificate compression extension.
 */
public final class OpenSslCertificateCompressionConfig
        implements Iterable<OpenSslCertificateCompressionConfig.AlgorithmConfig> {

    // Immutable, ordered list of configured algorithm / mode pairs.
    private final List<AlgorithmConfig> pairList;

    private OpenSslCertificateCompressionConfig(AlgorithmConfig... pairs) {
        pairList = Collections.unmodifiableList(Arrays.asList(pairs));
    }

    @Override
    public Iterator<AlgorithmConfig> iterator() {
        return pairList.iterator();
    }

    /**
     * Creates a new {@link Builder} for a config.
     *
     * @return a builder
     */
    public static Builder newBuilder() {
        return new Builder();
    }

    /**
     * Builder for an {@link OpenSslCertificateCompressionConfig}.
     */
    public static final class Builder {
        private final List<AlgorithmConfig> algorithmList = new ArrayList<AlgorithmConfig>();

        private Builder() { }

        /**
         * Adds a certificate compression algorithm.
         * For servers, algorithm preference order is dictated by the order of algorithm registration.
         * Most preferred algorithm should be registered first.
         *
         * @param algorithm implementation of the compression and or decompression algorithm as a
         *                  {@link OpenSslCertificateCompressionAlgorithm}
         * @param mode indicates whether decompression support should be advertised, compression should be applied
         *             for peers which support it, or both. This allows the caller to support one way compression
         *             only.
         * @return self.
         */
        public Builder addAlgorithm(OpenSslCertificateCompressionAlgorithm algorithm, AlgorithmMode mode) {
            algorithmList.add(new AlgorithmConfig(algorithm, mode));
            return this;
        }

        /**
         * Build a new {@link OpenSslCertificateCompressionConfig} based on the previously
         * added {@link OpenSslCertificateCompressionAlgorithm}s.
         *
         * @return a new config.
         */
        public OpenSslCertificateCompressionConfig build() {
            return new OpenSslCertificateCompressionConfig(algorithmList.toArray(new AlgorithmConfig[0]));
        }
    }

    /**
     * The configuration for an algorithm.
     */
    public static final class AlgorithmConfig {
        private final OpenSslCertificateCompressionAlgorithm algorithm;
        private final AlgorithmMode mode;

        private AlgorithmConfig(OpenSslCertificateCompressionAlgorithm algorithm, AlgorithmMode mode) {
            this.algorithm = ObjectUtil.checkNotNull(algorithm, "algorithm");
            this.mode = ObjectUtil.checkNotNull(mode, "mode");
        }

        /**
         * The {@link AlgorithmMode}.
         *
         * @return the usage mode.
         */
        public AlgorithmMode mode() {
            return mode;
        }

        /**
         * The configured {@link OpenSslCertificateCompressionAlgorithm}.
         *
         * @return the algorithm
         */
        public OpenSslCertificateCompressionAlgorithm algorithm() {
            return algorithm;
        }
    }

    /**
     * The usage mode of the {@link OpenSslCertificateCompressionAlgorithm}.
     */
    public enum AlgorithmMode {
        /**
         * Compression supported and should be advertised.
         */
        Compress,

        /**
         * Decompression supported and should be advertised.
         */
        Decompress,

        /**
         * Compression and Decompression are supported and both should be advertised.
         */
        Both
    }
}
package io.netty.handler.ssl;

import io.netty.internal.tcnative.CertificateVerifier;

import java.security.cert.CertificateException;

/**
 * A special {@link CertificateException} which allows to specify which error code is included in the
 * SSL Record. This only works when {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT} is used.
 */
public final class OpenSslCertificateException extends CertificateException {
    private static final long serialVersionUID = 5542675253797129798L;

    // OpenSSL verify error code that will be reported to the peer in the SSL record.
    private final int errorCode;

    /**
     * Construct a new exception with the error code.
     *
     * @param errorCode the OpenSSL certificate-verify error code
     */
    public OpenSslCertificateException(int errorCode) {
        this((String) null, errorCode);
    }

    /**
     * Construct a new exception with the msg and error code.
     *
     * @param msg       the detail message
     * @param errorCode the OpenSSL certificate-verify error code
     */
    public OpenSslCertificateException(String msg, int errorCode) {
        super(msg);
        this.errorCode = checkErrorCode(errorCode);
    }

    /**
     * Construct a new exception with the msg, cause and error code.
     *
     * @param message   the detail message
     * @param cause     the cause of this exception
     * @param errorCode the OpenSSL certificate-verify error code
     */
    public OpenSslCertificateException(String message, Throwable cause, int errorCode) {
        super(message, cause);
        this.errorCode = checkErrorCode(errorCode);
    }

    /**
     * Construct a new exception with the cause and error code.
     *
     * @param cause     the cause of this exception
     * @param errorCode the OpenSSL certificate-verify error code
     */
    public OpenSslCertificateException(Throwable cause, int errorCode) {
        this(null, cause, errorCode);
    }

    /**
     * Return the error code to use.
     */
    public int errorCode() {
        return errorCode;
    }

    // Validates the error code against the codes OpenSSL understands; skipped when the
    // native library is unavailable.
    private static int checkErrorCode(int errorCode) {
        // Call OpenSsl.isAvailable() to ensure we try to load the native lib as CertificateVerifier.isValid(...)
        // will depend on it. If loading fails we will just skip the validation.
        if (OpenSsl.isAvailable() && !CertificateVerifier.isValid(errorCode)) {
            throw new IllegalArgumentException("errorCode '" + errorCode +
                    "' invalid, see https://www.openssl.org/docs/man1.0.2/apps/verify.html.");
        }
        return errorCode;
    }
}
package io.netty.handler.ssl;

import io.netty.internal.tcnative.SSL;

import java.io.File;
import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Map;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;

import static io.netty.handler.ssl.ReferenceCountedOpenSslClientContext.newSessionContext;

/**
 * A client-side {@link SslContext} which uses OpenSSL's SSL/TLS implementation.
 * <p>
 * This class will use a finalizer to ensure native resources are automatically cleaned up. To avoid finalizers
 * and manually release the native memory see {@link ReferenceCountedOpenSslClientContext}.
 */
public final class OpenSslClientContext extends OpenSslContext {
    private final OpenSslSessionContext sessionContext;

    /**
     * Creates a new instance.
     *
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext() throws SSLException {
        this(null, null, null, null, null, null, null, IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext(File certChainFile) throws SSLException {
        this(certChainFile, null);
    }

    /**
     * Creates a new instance.
     *
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext(TrustManagerFactory trustManagerFactory) throws SSLException {
        this(null, trustManagerFactory);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format.
     *                      {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory) throws SSLException {
        this(certChainFile, trustManagerFactory, null, null, null, null, null,
                IdentityCipherSuiteFilter.INSTANCE, null, 0, 0);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param apn Provides a means to configure parameters related to application protocol negotiation.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory,
                                Iterable<String> ciphers,
                                ApplicationProtocolConfig apn, long sessionCacheSize, long sessionTimeout)
            throws SSLException {
        this(certChainFile, trustManagerFactory, null, null, null, null, ciphers, IdentityCipherSuiteFilter.INSTANCE,
                apn, sessionCacheSize, sessionTimeout);
    }

    /**
     * Creates a new instance.
     *
     * @param certChainFile an X.509 certificate chain file in PEM format
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param cipherFilter a filter to apply over the supplied list of ciphers
     * @param apn Provides a means to configure parameters related to application protocol negotiation.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext(File certChainFile, TrustManagerFactory trustManagerFactory,
                                Iterable<String> ciphers,
                                CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
                                long sessionCacheSize, long sessionTimeout) throws SSLException {
        this(certChainFile, trustManagerFactory, null, null, null, null,
                ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout);
    }

    /**
     * Creates a new instance.
     *
     * @param trustCertCollectionFile an X.509 certificate collection file in PEM format.
     *                                {@code null} to use the system default
     * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s
     *                            that verifies the certificates sent from servers.
     *                            {@code null} to use the default or the results of parsing
     *                            {@code trustCertCollectionFile}
     * @param keyCertChainFile an X.509 certificate chain file in PEM format.
     *                         This provides the public key for mutual authentication.
     *                         {@code null} to use the system default
     * @param keyFile a PKCS#8 private key file in PEM format.
     *                This provides the private key for mutual authentication.
     *                {@code null} for no mutual authentication.
     * @param keyPassword the password of the {@code keyFile}.
     *                    {@code null} if it's not password-protected.
     *                    Ignored if {@code keyFile} is {@code null}.
     * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link javax.net.ssl.KeyManager}s
     *                          that is used to encrypt data being sent to servers.
     *                          {@code null} to use the default or the results of parsing
     *                          {@code keyCertChainFile} and {@code keyFile}.
     * @param ciphers the cipher suites to enable, in the order of preference.
     *                {@code null} to use the default cipher suites.
     * @param cipherFilter a filter to apply over the supplied list of ciphers
     * @param apn Application Protocol Negotiator object.
     * @param sessionCacheSize the size of the cache used for storing SSL session objects.
     *                         {@code 0} to use the default value.
     * @param sessionTimeout the timeout for the cached SSL session objects, in seconds.
     *                       {@code 0} to use the default value.
     * @deprecated use {@link SslContextBuilder}
     */
    @Deprecated
    public OpenSslClientContext(File trustCertCollectionFile, TrustManagerFactory trustManagerFactory,
                                File keyCertChainFile, File keyFile, String keyPassword,
                                KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
                                CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
                                long sessionCacheSize, long sessionTimeout)
            throws SSLException {
        this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory,
                toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword),
                keyPassword, keyManagerFactory, ciphers, cipherFilter, apn, null, sessionCacheSize,
                sessionTimeout, false, KeyStore.getDefaultType());
    }

    OpenSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
                         X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
                         KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
                         CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, String[] protocols,
                         long sessionCacheSize, long sessionTimeout, boolean enableOcsp, String keyStore,
                         Map.Entry<SslContextOption<?>, Object>... options)
            throws SSLException {
        super(ciphers, cipherFilter, apn, SSL.SSL_MODE_CLIENT, keyCertChain,
                ClientAuth.NONE, protocols, false, enableOcsp, options);
        boolean success = false;
        try {
            OpenSslKeyMaterialProvider.validateKeyMaterialSupported(keyCertChain, key, keyPassword);
            sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory,
                    keyCertChain, key, keyPassword, keyManagerFactory, keyStore,
                    sessionCacheSize, sessionTimeout);
            success = true;
        } finally {
            if (!success) {
                // Release the native context if construction failed part-way to avoid a leak.
                release();
            }
        }
    }

    @Override
    public OpenSslSessionContext sessionContext() {
        return sessionContext;
    }
}
+ */ +final class OpenSslClientSessionCache extends OpenSslSessionCache { + // TODO: Should we support to have a List of OpenSslSessions for a Host/Port key and so be able to + // support sessions for different protocols / ciphers to the same remote peer ? + private final Map sessions = new HashMap(); + + OpenSslClientSessionCache(OpenSslEngineMap engineMap) { + super(engineMap); + } + + @Override + protected boolean sessionCreated(NativeSslSession session) { + assert Thread.holdsLock(this); + HostPort hostPort = keyFor(session.getPeerHost(), session.getPeerPort()); + if (hostPort == null || sessions.containsKey(hostPort)) { + return false; + } + sessions.put(hostPort, session); + return true; + } + + @Override + protected void sessionRemoved(NativeSslSession session) { + assert Thread.holdsLock(this); + HostPort hostPort = keyFor(session.getPeerHost(), session.getPeerPort()); + if (hostPort == null) { + return; + } + sessions.remove(hostPort); + } + + @Override + void setSession(long ssl, String host, int port) { + HostPort hostPort = keyFor(host, port); + if (hostPort == null) { + return; + } + final NativeSslSession session; + final boolean reused; + synchronized (this) { + session = sessions.get(hostPort); + if (session == null) { + return; + } + if (!session.isValid()) { + removeSessionWithId(session.sessionId()); + return; + } + // Try to set the session, if true is returned OpenSSL incremented the reference count + // of the underlying SSL_SESSION*. 
+ reused = SSL.setSession(ssl, session.session()); + } + + if (reused) { + if (session.shouldBeSingleUse()) { + // Should only be used once + session.invalidate(); + } + session.updateLastAccessedTime(); + } + } + + private static HostPort keyFor(String host, int port) { + if (host == null && port < 1) { + return null; + } + return new HostPort(host, port); + } + + @Override + synchronized void clear() { + super.clear(); + sessions.clear(); + } + + /** + * Host / Port tuple used to find a {@link OpenSslSession} in the cache. + */ + private static final class HostPort { + private final int hash; + private final String host; + private final int port; + + HostPort(String host, int port) { + this.host = host; + this.port = port; + // Calculate a hashCode that does ignore case. + this.hash = 31 * AsciiString.hashCode(host) + port; + } + + @Override + public int hashCode() { + return hash; + } + + @Override + public boolean equals(Object obj) { + if (!(obj instanceof HostPort)) { + return false; + } + HostPort other = (HostPort) obj; + return port == other.port && host.equalsIgnoreCase(other.host); + } + + @Override + public String toString() { + return "HostPort{" + + "host='" + host + '\'' + + ", port=" + port + + '}'; + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslContext.java new file mode 100644 index 0000000..e50108c --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslContext.java @@ -0,0 +1,60 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.ssl;

import io.netty.buffer.ByteBufAllocator;

import java.security.cert.Certificate;
import java.util.Map;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;

/**
 * This class will use a finalizer to ensure native resources are automatically cleaned up. To avoid finalizers
 * and manually release the native memory see {@link ReferenceCountedOpenSslContext}.
 */
public abstract class OpenSslContext extends ReferenceCountedOpenSslContext {
    OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apnCfg,
                   int mode, Certificate[] keyCertChain,
                   ClientAuth clientAuth, String[] protocols, boolean startTls, boolean enableOcsp,
                   Map.Entry<SslContextOption<?>, Object>... options)
            throws SSLException {
        super(ciphers, cipherFilter, toNegotiator(apnCfg), mode, keyCertChain,
                clientAuth, protocols, startTls, enableOcsp, false, options);
    }

    OpenSslContext(Iterable<String> ciphers, CipherSuiteFilter cipherFilter,
                   OpenSslApplicationProtocolNegotiator apn,
                   int mode, Certificate[] keyCertChain,
                   ClientAuth clientAuth, String[] protocols, boolean startTls, boolean enableOcsp,
                   Map.Entry<SslContextOption<?>, Object>... options)
            throws SSLException {
        super(ciphers, cipherFilter, apn, mode, keyCertChain,
                clientAuth, protocols, startTls, enableOcsp, false, options);
    }

    @Override
    final SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort, boolean jdkCompatibilityMode) {
        return new OpenSslEngine(this, alloc, peerHost, peerPort, jdkCompatibilityMode);
    }

    @Override
    @SuppressWarnings("FinalizeDeclaration")
    protected final void finalize() throws Throwable {
        super.finalize();
        // Release the native resources if they were not already released explicitly.
        OpenSsl.releaseIfNeeded(this);
    }
}
+ */ + public static final OpenSslContextOption USE_TASKS = + new OpenSslContextOption("USE_TASKS"); + /** + * If enabled TLS false start will be enabled if supported. + * When TLS false start is enabled the flow of {@link SslHandshakeCompletionEvent}s may be different compared when, + * not enabled. + * + * This is currently only supported when {@code BoringSSL} and ALPN is used. + */ + public static final OpenSslContextOption TLS_FALSE_START = + new OpenSslContextOption("TLS_FALSE_START"); + + /** + * Set the {@link OpenSslPrivateKeyMethod} to use. This allows to offload private-key operations + * if needed. + * + * This is currently only supported when {@code BoringSSL} is used. + */ + public static final OpenSslContextOption PRIVATE_KEY_METHOD = + new OpenSslContextOption("PRIVATE_KEY_METHOD"); + + /** + * Set the {@link OpenSslAsyncPrivateKeyMethod} to use. This allows to offload private-key operations + * if needed. + * + * This is currently only supported when {@code BoringSSL} is used. + */ + public static final OpenSslContextOption ASYNC_PRIVATE_KEY_METHOD = + new OpenSslContextOption("ASYNC_PRIVATE_KEY_METHOD"); + + /** + * Set the {@link OpenSslCertificateCompressionConfig} to use. This allows for the configuration of certificate + * compression algorithms which should be used, the priority of those algorithms and the directions in which + * they should be used. + * + * This is currently only supported when {@code BoringSSL} is used. + */ + public static final OpenSslContextOption CERTIFICATE_COMPRESSION_ALGORITHMS = + new OpenSslContextOption("CERTIFICATE_COMPRESSION_ALGORITHMS"); + + /** + * Set the maximum number of bytes that is allowed during the handshake for certificate chain. 
+ */ + public static final OpenSslContextOption MAX_CERTIFICATE_LIST_BYTES = + new OpenSslContextOption("MAX_CERTIFICATE_LIST_BYTES"); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslDefaultApplicationProtocolNegotiator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslDefaultApplicationProtocolNegotiator.java new file mode 100644 index 0000000..5efd282 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslDefaultApplicationProtocolNegotiator.java @@ -0,0 +1,53 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * OpenSSL {@link ApplicationProtocolNegotiator} for ALPN and NPN. + * + * @deprecated use {@link ApplicationProtocolConfig}. 
package io.netty.handler.ssl;

import java.util.List;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * OpenSSL {@link ApplicationProtocolNegotiator} for ALPN and NPN.
 *
 * @deprecated use {@link ApplicationProtocolConfig}.
 */
@Deprecated
public final class OpenSslDefaultApplicationProtocolNegotiator implements OpenSslApplicationProtocolNegotiator {

    // Immutable after construction; all accessors simply delegate to this config.
    private final ApplicationProtocolConfig config;

    public OpenSslDefaultApplicationProtocolNegotiator(ApplicationProtocolConfig config) {
        this.config = checkNotNull(config, "config");
    }

    @Override
    public List<String> protocols() {
        return config.supportedProtocols();
    }

    @Override
    public ApplicationProtocolConfig.Protocol protocol() {
        return config.protocol();
    }

    @Override
    public ApplicationProtocolConfig.SelectorFailureBehavior selectorFailureBehavior() {
        return config.selectorFailureBehavior();
    }

    @Override
    public ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedListenerFailureBehavior() {
        return config.selectedListenerFailureBehavior();
    }
}
+ * This class will use a finalizer to ensure native resources are automatically cleaned up. To avoid finalizers + * and manually release the native memory see {@link ReferenceCountedOpenSslEngine}. + */ +public final class OpenSslEngine extends ReferenceCountedOpenSslEngine { + OpenSslEngine(OpenSslContext context, ByteBufAllocator alloc, String peerHost, int peerPort, + boolean jdkCompatibilityMode) { + super(context, alloc, peerHost, peerPort, jdkCompatibilityMode, false); + } + + @Override + @SuppressWarnings("FinalizeDeclaration") + protected void finalize() throws Throwable { + super.finalize(); + OpenSsl.releaseIfNeeded(this); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslEngineMap.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslEngineMap.java new file mode 100644 index 0000000..68e2df5 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslEngineMap.java @@ -0,0 +1,35 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +interface OpenSslEngineMap { + + /** + * Remove the {@link OpenSslEngine} with the given {@code ssl} address and + * return it. + */ + ReferenceCountedOpenSslEngine remove(long ssl); + + /** + * Add a {@link OpenSslEngine} to this {@link OpenSslEngineMap}. 
+ */ + void add(ReferenceCountedOpenSslEngine engine); + + /** + * Get the {@link OpenSslEngine} for the given {@code ssl} address. + */ + ReferenceCountedOpenSslEngine get(long ssl); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java new file mode 100644 index 0000000..88131e4 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterial.java @@ -0,0 +1,59 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.ReferenceCounted; + +import java.security.cert.X509Certificate; + +/** + * Holds references to the native key-material that is used by OpenSSL. + */ +interface OpenSslKeyMaterial extends ReferenceCounted { + + /** + * Returns the configured {@link X509Certificate}s. + */ + X509Certificate[] certificateChain(); + + /** + * Returns the pointer to the {@code STACK_OF(X509)} which holds the certificate chain. + */ + long certificateChainAddress(); + + /** + * Returns the pointer to the {@code EVP_PKEY}. 
+ */ + long privateKeyAddress(); + + @Override + OpenSslKeyMaterial retain(); + + @Override + OpenSslKeyMaterial retain(int increment); + + @Override + OpenSslKeyMaterial touch(); + + @Override + OpenSslKeyMaterial touch(Object hint); + + @Override + boolean release(); + + @Override + boolean release(int decrement); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java new file mode 100644 index 0000000..e2e2069 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslKeyMaterialManager.java @@ -0,0 +1,138 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509KeyManager; +import javax.security.auth.x500.X500Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + + +/** + * Manages key material for {@link OpenSslEngine}s and so set the right {@link PrivateKey}s and + * {@link X509Certificate}s. 
package io.netty.handler.ssl;

import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509KeyManager;
import javax.security.auth.x500.X500Principal;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Set;

/**
 * Manages key material for {@link OpenSslEngine}s and so set the right {@link PrivateKey}s and
 * {@link X509Certificate}s.
 */
final class OpenSslKeyMaterialManager {

    // Code in this class is inspired by code of conscrypts:
    // - https://android.googlesource.com/platform/external/
    //   conscrypt/+/master/src/main/java/org/conscrypt/OpenSSLEngineImpl.java
    // - https://android.googlesource.com/platform/external/
    //   conscrypt/+/master/src/main/java/org/conscrypt/SSLParametersImpl.java
    //
    static final String KEY_TYPE_RSA = "RSA";
    static final String KEY_TYPE_DH_RSA = "DH_RSA";
    static final String KEY_TYPE_EC = "EC";
    static final String KEY_TYPE_EC_EC = "EC_EC";
    static final String KEY_TYPE_EC_RSA = "EC_RSA";

    // Maps an OpenSSL auth method name onto the JSSE key type used for alias selection.
    private static final Map<String, String> KEY_TYPES = new HashMap<String, String>();
    static {
        KEY_TYPES.put("RSA", KEY_TYPE_RSA);
        KEY_TYPES.put("DHE_RSA", KEY_TYPE_RSA);
        KEY_TYPES.put("ECDHE_RSA", KEY_TYPE_RSA);
        KEY_TYPES.put("ECDHE_ECDSA", KEY_TYPE_EC);
        KEY_TYPES.put("ECDH_RSA", KEY_TYPE_EC_RSA);
        KEY_TYPES.put("ECDH_ECDSA", KEY_TYPE_EC_EC);
        KEY_TYPES.put("DH_RSA", KEY_TYPE_DH_RSA);
    }

    private final OpenSslKeyMaterialProvider provider;

    OpenSslKeyMaterialManager(OpenSslKeyMaterialProvider provider) {
        this.provider = provider;
    }

    /**
     * Selects and installs server-side key material for the given engine, or fails the
     * handshake if none of the engine's auth methods can be satisfied.
     */
    void setKeyMaterialServerSide(ReferenceCountedOpenSslEngine engine) throws SSLException {
        String[] authMethods = engine.authMethods();
        if (authMethods.length == 0) {
            throw new SSLHandshakeException("Unable to find key material");
        }

        // authMethods may contain duplicates or may result in the same type
        // but call chooseServerAlias(...) may be expensive. So let's ensure
        // we filter out duplicates.
        Set<String> typeSet = new HashSet<String>(KEY_TYPES.size());
        for (String authMethod : authMethods) {
            String type = KEY_TYPES.get(authMethod);
            if (type != null && typeSet.add(type)) {
                String alias = chooseServerAlias(engine, type);
                if (alias != null) {
                    // We found a match... let's set the key material and return.
                    setKeyMaterial(engine, alias);
                    return;
                }
            }
        }
        throw new SSLHandshakeException("Unable to find key material for auth method(s): "
                + Arrays.toString(authMethods));
    }

    /**
     * Selects and installs client-side key material for the given engine. Unlike the server
     * side, a missing alias is not an error.
     */
    void setKeyMaterialClientSide(ReferenceCountedOpenSslEngine engine, String[] keyTypes,
                                  X500Principal[] issuer) throws SSLException {
        String alias = chooseClientAlias(engine, keyTypes, issuer);
        // Only try to set the keymaterial if we have a match. This is also consistent with what OpenJDK does:
        // https://hg.openjdk.java.net/jdk/jdk11/file/76072a077ee1/
        //     src/java.base/share/classes/sun/security/ssl/CertificateRequest.java#l362
        if (alias != null) {
            setKeyMaterial(engine, alias);
        }
    }

    private void setKeyMaterial(ReferenceCountedOpenSslEngine engine, String alias) throws SSLException {
        OpenSslKeyMaterial keyMaterial = null;
        try {
            keyMaterial = provider.chooseKeyMaterial(engine.alloc, alias);
            if (keyMaterial == null) {
                return;
            }
            engine.setKeyMaterial(keyMaterial);
        } catch (SSLException e) {
            throw e;
        } catch (Exception e) {
            throw new SSLException(e);
        } finally {
            // The engine retains what it needs; drop our reference in all cases.
            if (keyMaterial != null) {
                keyMaterial.release();
            }
        }
    }

    private String chooseClientAlias(ReferenceCountedOpenSslEngine engine,
                                     String[] keyTypes, X500Principal[] issuer) {
        X509KeyManager manager = provider.keyManager();
        if (manager instanceof X509ExtendedKeyManager) {
            return ((X509ExtendedKeyManager) manager).chooseEngineClientAlias(keyTypes, issuer, engine);
        }
        return manager.chooseClientAlias(keyTypes, issuer, null);
    }

    private String chooseServerAlias(ReferenceCountedOpenSslEngine engine, String type) {
        X509KeyManager manager = provider.keyManager();
        if (manager instanceof X509ExtendedKeyManager) {
            return ((X509ExtendedKeyManager) manager).chooseEngineServerAlias(type, null, engine);
        }
        return manager.chooseServerAlias(type, null, null);
    }
}
package io.netty.handler.ssl;

import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.internal.tcnative.SSL;

import javax.net.ssl.SSLException;
import javax.net.ssl.X509KeyManager;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;

import static io.netty.handler.ssl.ReferenceCountedOpenSslContext.toBIO;

/**
 * Provides {@link OpenSslKeyMaterial} for a given alias.
 */
class OpenSslKeyMaterialProvider {

    private final X509KeyManager keyManager;
    private final String password;

    OpenSslKeyMaterialProvider(X509KeyManager keyManager, String password) {
        this.keyManager = keyManager;
        this.password = password;
    }

    /**
     * Validates that the given certificate chain and private key can be parsed by the native
     * OpenSSL code, throwing an {@link SSLException} if not.
     */
    static void validateKeyMaterialSupported(X509Certificate[] keyCertChain, PrivateKey key, String keyPassword)
            throws SSLException {
        validateSupported(keyCertChain);
        validateSupported(key, keyPassword);
    }

    private static void validateSupported(PrivateKey key, String password) throws SSLException {
        if (key == null) {
            return;
        }

        long pkeyBio = 0;
        long pkey = 0;

        try {
            pkeyBio = toBIO(UnpooledByteBufAllocator.DEFAULT, key);
            pkey = SSL.parsePrivateKey(pkeyBio, password);
        } catch (Exception e) {
            throw new SSLException("PrivateKey type not supported " + key.getFormat(), e);
        } finally {
            // This was only a validation parse; always free the temporary native objects.
            SSL.freeBIO(pkeyBio);
            if (pkey != 0) {
                SSL.freePrivateKey(pkey);
            }
        }
    }

    private static void validateSupported(X509Certificate[] certificates) throws SSLException {
        if (certificates == null || certificates.length == 0) {
            return;
        }

        long chainBio = 0;
        long chain = 0;
        PemEncoded encoded = null;
        try {
            encoded = PemX509Certificate.toPEM(UnpooledByteBufAllocator.DEFAULT, true, certificates);
            chainBio = toBIO(UnpooledByteBufAllocator.DEFAULT, encoded.retain());
            chain = SSL.parseX509Chain(chainBio);
        } catch (Exception e) {
            throw new SSLException("Certificate type not supported", e);
        } finally {
            SSL.freeBIO(chainBio);
            if (chain != 0) {
                SSL.freeX509Chain(chain);
            }
            if (encoded != null) {
                encoded.release();
            }
        }
    }

    /**
     * Returns the underlying {@link X509KeyManager} that is used.
     */
    X509KeyManager keyManager() {
        return keyManager;
    }

    /**
     * Returns the {@link OpenSslKeyMaterial} or {@code null} (if none) that should be used during the handshake by
     * OpenSSL.
     */
    OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception {
        X509Certificate[] certificates = keyManager.getCertificateChain(alias);
        if (certificates == null || certificates.length == 0) {
            return null;
        }

        PrivateKey key = keyManager.getPrivateKey(alias);
        PemEncoded encoded = PemX509Certificate.toPEM(allocator, true, certificates);
        long chainBio = 0;
        long pkeyBio = 0;
        long chain = 0;
        long pkey = 0;
        try {
            chainBio = toBIO(allocator, encoded.retain());
            chain = SSL.parseX509Chain(chainBio);

            OpenSslKeyMaterial keyMaterial;
            if (key instanceof OpenSslPrivateKey) {
                keyMaterial = ((OpenSslPrivateKey) key).newKeyMaterial(chain, certificates);
            } else {
                pkeyBio = toBIO(allocator, key);
                pkey = key == null ? 0 : SSL.parsePrivateKey(pkeyBio, password);
                keyMaterial = new DefaultOpenSslKeyMaterial(chain, pkey, certificates);
            }

            // Set the chain and pkey to 0 so we will not release it as the ownership was
            // transferred to OpenSslKeyMaterial.
            chain = 0;
            pkey = 0;
            return keyMaterial;
        } finally {
            SSL.freeBIO(chainBio);
            SSL.freeBIO(pkeyBio);
            if (chain != 0) {
                SSL.freeX509Chain(chain);
            }
            if (pkey != 0) {
                SSL.freePrivateKey(pkey);
            }
            encoded.release();
        }
    }

    /**
     * Will be invoked once the provider should be destroyed.
     */
    void destroy() {
        // NOOP.
    }
}
package io.netty.handler.ssl;

import static io.netty.handler.ssl.ApplicationProtocolUtil.toList;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

import java.util.List;

/**
 * OpenSSL {@link ApplicationProtocolNegotiator} for NPN.
 *
 * @deprecated use {@link ApplicationProtocolConfig}
 */
@Deprecated
public final class OpenSslNpnApplicationProtocolNegotiator implements OpenSslApplicationProtocolNegotiator {

    private final List<String> protocols;

    public OpenSslNpnApplicationProtocolNegotiator(Iterable<String> protocols) {
        this.protocols = checkNotNull(toList(protocols), "protocols");
    }

    public OpenSslNpnApplicationProtocolNegotiator(String... protocols) {
        this.protocols = checkNotNull(toList(protocols), "protocols");
    }

    @Override
    public ApplicationProtocolConfig.Protocol protocol() {
        return ApplicationProtocolConfig.Protocol.NPN;
    }

    @Override
    public List<String> protocols() {
        return protocols;
    }

    // NPN semantics are fixed for this negotiator: fall back to our last protocol on
    // selection failure, and accept whatever the listener reports.
    @Override
    public ApplicationProtocolConfig.SelectorFailureBehavior selectorFailureBehavior() {
        return ApplicationProtocolConfig.SelectorFailureBehavior.CHOOSE_MY_LAST_PROTOCOL;
    }

    @Override
    public ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedListenerFailureBehavior() {
        return ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT;
    }
}
package io.netty.handler.ssl;

import io.netty.internal.tcnative.SSL;
import io.netty.util.AbstractReferenceCounted;
import io.netty.util.IllegalReferenceCountException;
import io.netty.util.internal.EmptyArrays;

import javax.security.auth.Destroyable;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;

/**
 * A {@link PrivateKey} backed by a native OpenSSL {@code EVP_PKEY}, with reference counting
 * controlling the lifetime of the native memory.
 */
final class OpenSslPrivateKey extends AbstractReferenceCounted implements PrivateKey {

    // Native EVP_PKEY pointer; reset to 0 once deallocated.
    private long privateKeyAddress;

    OpenSslPrivateKey(long privateKeyAddress) {
        this.privateKeyAddress = privateKeyAddress;
    }

    @Override
    public String getAlgorithm() {
        return "unknown";
    }

    @Override
    public String getFormat() {
        // As we do not support encoding we should return null as stated in the javadocs of PrivateKey.
        return null;
    }

    @Override
    public byte[] getEncoded() {
        return null;
    }

    private long privateKeyAddress() {
        if (refCnt() <= 0) {
            throw new IllegalReferenceCountException();
        }
        return privateKeyAddress;
    }

    @Override
    protected void deallocate() {
        SSL.freePrivateKey(privateKeyAddress);
        privateKeyAddress = 0;
    }

    @Override
    public OpenSslPrivateKey retain() {
        super.retain();
        return this;
    }

    @Override
    public OpenSslPrivateKey retain(int increment) {
        super.retain(increment);
        return this;
    }

    @Override
    public OpenSslPrivateKey touch() {
        super.touch();
        return this;
    }

    @Override
    public OpenSslPrivateKey touch(Object hint) {
        return this;
    }

    /**
     * NOTE: This is a JDK8 interface/method. Due to backwards compatibility
     * reasons it's not possible to slap the {@code @Override} annotation onto
     * this method.
     *
     * @see Destroyable#destroy()
     */
    @Override
    public void destroy() {
        release(refCnt());
    }

    /**
     * NOTE: This is a JDK8 interface/method. Due to backwards compatibility
     * reasons it's not possible to slap the {@code @Override} annotation onto
     * this method.
     *
     * @see Destroyable#isDestroyed()
     */
    @Override
    public boolean isDestroyed() {
        return refCnt() == 0;
    }

    /**
     * Create a new {@link OpenSslKeyMaterial} which uses the private key that is held by {@link OpenSslPrivateKey}.
     *
     * When the material is created we increment the reference count of the enclosing {@link OpenSslPrivateKey} and
     * decrement it again when the reference count of the {@link OpenSslKeyMaterial} reaches {@code 0}.
     */
    OpenSslKeyMaterial newKeyMaterial(long certificateChain, X509Certificate[] chain) {
        return new OpenSslPrivateKeyMaterial(certificateChain, chain);
    }

    // Package-private for unit-test only
    final class OpenSslPrivateKeyMaterial extends AbstractReferenceCounted implements OpenSslKeyMaterial {

        // Package-private for unit-test only.
        // Native STACK_OF(X509) pointer; reset to 0 once released.
        long certificateChain;
        private final X509Certificate[] x509CertificateChain;

        OpenSslPrivateKeyMaterial(long certificateChain, X509Certificate[] x509CertificateChain) {
            this.certificateChain = certificateChain;
            this.x509CertificateChain = x509CertificateChain == null ?
                    EmptyArrays.EMPTY_X509_CERTIFICATES : x509CertificateChain;
            // Keep the enclosing key alive for as long as this material exists.
            OpenSslPrivateKey.this.retain();
        }

        @Override
        public X509Certificate[] certificateChain() {
            // Defensive copy so callers cannot mutate our internal array.
            return x509CertificateChain.clone();
        }

        @Override
        public long certificateChainAddress() {
            if (refCnt() <= 0) {
                throw new IllegalReferenceCountException();
            }
            return certificateChain;
        }

        @Override
        public long privateKeyAddress() {
            if (refCnt() <= 0) {
                throw new IllegalReferenceCountException();
            }
            return OpenSslPrivateKey.this.privateKeyAddress();
        }

        @Override
        public OpenSslKeyMaterial touch(Object hint) {
            OpenSslPrivateKey.this.touch(hint);
            return this;
        }

        @Override
        public OpenSslKeyMaterial retain() {
            super.retain();
            return this;
        }

        @Override
        public OpenSslKeyMaterial retain(int increment) {
            super.retain(increment);
            return this;
        }

        @Override
        public OpenSslKeyMaterial touch() {
            OpenSslPrivateKey.this.touch();
            return this;
        }

        @Override
        protected void deallocate() {
            releaseChain();
            OpenSslPrivateKey.this.release();
        }

        private void releaseChain() {
            SSL.freeX509Chain(certificateChain);
            certificateChain = 0;
        }
    }
}
package io.netty.handler.ssl;

import io.netty.internal.tcnative.SSLPrivateKeyMethod;
import io.netty.util.internal.UnstableApi;

import javax.net.ssl.SSLEngine;

/**
 * Allow to customize private key signing / decrypting (when using RSA). Only supported when using BoringSSL atm.
 */
@UnstableApi
public interface OpenSslPrivateKeyMethod {
    // Signature algorithm identifiers, mirrored from the tcnative SSLPrivateKeyMethod constants.
    int SSL_SIGN_RSA_PKCS1_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA1;
    int SSL_SIGN_RSA_PKCS1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256;
    int SSL_SIGN_RSA_PKCS1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA384;
    int SSL_SIGN_RSA_PKCS1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA512;
    int SSL_SIGN_ECDSA_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SHA1;
    int SSL_SIGN_ECDSA_SECP256R1_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP256R1_SHA256;
    int SSL_SIGN_ECDSA_SECP384R1_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP384R1_SHA384;
    int SSL_SIGN_ECDSA_SECP521R1_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_ECDSA_SECP521R1_SHA512;
    int SSL_SIGN_RSA_PSS_RSAE_SHA256 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256;
    int SSL_SIGN_RSA_PSS_RSAE_SHA384 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA384;
    int SSL_SIGN_RSA_PSS_RSAE_SHA512 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA512;
    int SSL_SIGN_ED25519 = SSLPrivateKeyMethod.SSL_SIGN_ED25519;
    int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = SSLPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_MD5_SHA1;

    /**
     * Signs the input with the given key and returns the signed bytes.
     *
     * @param engine the {@link SSLEngine}
     * @param signatureAlgorithm the algorithm to use for signing
     * @param input the digest itself
     * @return the signed data (must not be {@code null})
     * @throws Exception thrown if an error is encountered during the signing
     */
    byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception;

    /**
     * Decrypts the input with the given key and returns the decrypted bytes.
     *
     * @param engine the {@link SSLEngine}
     * @param input the input which should be decrypted
     * @return the decrypted data (must not be {@code null})
     * @throws Exception thrown if an error is encountered during the decrypting
     */
    byte[] decrypt(SSLEngine engine, byte[] input) throws Exception;
}
+ */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSL; + +import java.io.File; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Map; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; + +import static io.netty.handler.ssl.ReferenceCountedOpenSslServerContext.newSessionContext; + +/** + * A server-side {@link SslContext} which uses OpenSSL's SSL/TLS implementation. + *

This class will use a finalizer to ensure native resources are automatically cleaned up. To avoid finalizers + * and manually release the native memory see {@link ReferenceCountedOpenSslServerContext}. + */ +public final class OpenSslServerContext extends OpenSslContext { + private final OpenSslServerSessionContext sessionContext; + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext(File certChainFile, File keyFile) throws SSLException { + this(certChainFile, keyFile, null); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext(File certChainFile, File keyFile, String keyPassword) throws SSLException { + this(certChainFile, keyFile, keyPassword, null, IdentityCipherSuiteFilter.INSTANCE, + ApplicationProtocolConfig.DISABLED, 0, 0); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. 
+ * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(certChainFile, keyFile, keyPassword, ciphers, IdentityCipherSuiteFilter.INSTANCE, + apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(certChainFile, keyFile, keyPassword, ciphers, + toApplicationProtocolConfig(nextProtocols), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. 
+ * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param config Application protocol config. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, + Iterable ciphers, ApplicationProtocolConfig config, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(certChainFile, keyFile, keyPassword, trustManagerFactory, ciphers, + toNegotiator(config), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param apn Application protocol negotiator. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, + Iterable ciphers, OpenSslApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, + ciphers, null, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, null, certChainFile, keyFile, keyPassword, null, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * This provides the certificate collection used for mutual authentication. 
+ * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param config Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(trustCertCollectionFile, trustManagerFactory, keyCertChainFile, keyFile, keyPassword, keyManagerFactory, + ciphers, cipherFilter, toNegotiator(config), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param config Application protocol config. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext(File certChainFile, File keyFile, String keyPassword, + TrustManagerFactory trustManagerFactory, Iterable ciphers, + CipherSuiteFilter cipherFilter, ApplicationProtocolConfig config, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter, + toNegotiator(config), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. 
+ * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Application protocol negotiator. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder}} + */ + @Deprecated + public OpenSslServerContext( + File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(null, trustManagerFactory, certChainFile, keyFile, keyPassword, null, ciphers, cipherFilter, + apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new instance. + * + * + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * This provides the certificate collection used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. 
+ * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Application Protocol Negotiator object + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @deprecated use {@link SslContextBuilder} + */ + @Deprecated + public OpenSslServerContext( + File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + this(toX509CertificatesInternal(trustCertCollectionFile), trustManagerFactory, + toX509CertificatesInternal(keyCertChainFile), toPrivateKeyInternal(keyFile, keyPassword), + keyPassword, keyManagerFactory, ciphers, cipherFilter, + apn, sessionCacheSize, sessionTimeout, ClientAuth.NONE, null, false, false, KeyStore.getDefaultType()); + } + + OpenSslServerContext( + X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, 
boolean startTls, + boolean enableOcsp, String keyStore, Map.Entry, Object>... options) + throws SSLException { + this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers, + cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls, + enableOcsp, keyStore, options); + } + + @SuppressWarnings("deprecation") + private OpenSslServerContext( + X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls, + boolean enableOcsp, String keyStore, Map.Entry, Object>... options) + throws SSLException { + super(ciphers, cipherFilter, apn, SSL.SSL_MODE_SERVER, keyCertChain, + clientAuth, protocols, startTls, enableOcsp, options); + + // Create a new SSL_CTX and configure it. 
+ boolean success = false; + try { + OpenSslKeyMaterialProvider.validateKeyMaterialSupported(keyCertChain, key, keyPassword); + sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, + keyCertChain, key, keyPassword, keyManagerFactory, keyStore, + sessionCacheSize, sessionTimeout); + success = true; + } finally { + if (!success) { + release(); + } + } + } + + @Override + public OpenSslServerSessionContext sessionContext() { + return sessionContext; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java new file mode 100644 index 0000000..eba161f --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslServerSessionContext.java @@ -0,0 +1,50 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSL; +import io.netty.internal.tcnative.SSLContext; + +import java.util.concurrent.locks.Lock; + + +/** + * {@link OpenSslSessionContext} implementation which offers extra methods which are only useful for the server-side. 
+ */ +public final class OpenSslServerSessionContext extends OpenSslSessionContext { + OpenSslServerSessionContext(ReferenceCountedOpenSslContext context, OpenSslKeyMaterialProvider provider) { + super(context, provider, SSL.SSL_SESS_CACHE_SERVER, new OpenSslSessionCache(context.engineMap)); + } + + /** + * Set the context within which session be reused (server side only) + * See + * man SSL_CTX_set_session_id_context + * + * @param sidCtx can be any kind of binary data, it is therefore possible to use e.g. the name + * of the application and/or the hostname and/or service name + * @return {@code true} if success, {@code false} otherwise. + */ + public boolean setSessionIdContext(byte[] sidCtx) { + Lock writerLock = context.ctxLock.writeLock(); + writerLock.lock(); + try { + return SSLContext.setSessionIdContext(context.ctx, sidCtx); + } finally { + writerLock.unlock(); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSession.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSession.java new file mode 100644 index 0000000..4e6ef35 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSession.java @@ -0,0 +1,62 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.util.ReferenceCounted; + +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSession; +import java.security.cert.Certificate; + +/** + * {@link SSLSession} that is specific to our native implementation and {@link ReferenceCounted} to track native + * resources. + */ +interface OpenSslSession extends SSLSession { + + /** + * Return the {@link OpenSslSessionId} that can be used to identify this session. + */ + OpenSslSessionId sessionId(); + + /** + * Set the local certificate chain that is used. It is not expected that this array will be changed at all + * and so its ok to not copy the array. + */ + void setLocalCertificate(Certificate[] localCertificate); + + /** + * Set the {@link OpenSslSessionId} for the {@link OpenSslSession}. + */ + void setSessionId(OpenSslSessionId id); + + @Override + OpenSslSessionContext getSessionContext(); + + /** + * Expand (or increase) the value returned by {@link #getApplicationBufferSize()} if necessary. + *

+ * This is only called in a synchronized block, so no need to use atomic operations. + * @param packetLengthDataOnly The packet size which exceeds the current {@link #getApplicationBufferSize()}. + */ + void tryExpandApplicationBufferSize(int packetLengthDataOnly); + + /** + * Called once the handshake has completed. + */ + void handshakeFinished(byte[] id, String cipher, String protocol, byte[] peerCertificate, + byte[][] peerCertificateChain, long creationTime, long timeout) throws SSLException; +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionCache.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionCache.java new file mode 100644 index 0000000..2881a45 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionCache.java @@ -0,0 +1,492 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSLSession; +import io.netty.internal.tcnative.SSLSessionCache; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakDetectorFactory; +import io.netty.util.ResourceLeakTracker; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.SystemPropertyUtil; + +import javax.security.cert.X509Certificate; +import java.security.Principal; +import java.security.cert.Certificate; +import java.util.ArrayList; +import java.util.Iterator; +import java.util.LinkedHashMap; +import java.util.List; +import java.util.Map; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * {@link SSLSessionCache} implementation for our native SSL implementation. + */ +class OpenSslSessionCache implements SSLSessionCache { + private static final OpenSslSession[] EMPTY_SESSIONS = new OpenSslSession[0]; + + private static final int DEFAULT_CACHE_SIZE; + static { + // Respect the same system property as the JDK implementation to make it easy to switch between implementations. + int cacheSize = SystemPropertyUtil.getInt("javax.net.ssl.sessionCacheSize", 20480); + if (cacheSize >= 0) { + DEFAULT_CACHE_SIZE = cacheSize; + } else { + DEFAULT_CACHE_SIZE = 20480; + } + } + private final OpenSslEngineMap engineMap; + + private final Map sessions = + new LinkedHashMap() { + + private static final long serialVersionUID = -7773696788135734448L; + + @Override + protected boolean removeEldestEntry(Map.Entry eldest) { + int maxSize = maximumCacheSize.get(); + if (maxSize >= 0 && size() > maxSize) { + removeSessionWithId(eldest.getKey()); + } + // We always need to return false as we modify the map directly. + return false; + } + }; + + private final AtomicInteger maximumCacheSize = new AtomicInteger(DEFAULT_CACHE_SIZE); + + // Let's use the same default value as OpenSSL does. 
+ // See https://www.openssl.org/docs/man1.1.1/man3/SSL_get_default_timeout.html + private final AtomicInteger sessionTimeout = new AtomicInteger(300); + private int sessionCounter; + + OpenSslSessionCache(OpenSslEngineMap engineMap) { + this.engineMap = engineMap; + } + + final void setSessionTimeout(int seconds) { + int oldTimeout = sessionTimeout.getAndSet(seconds); + if (oldTimeout > seconds) { + // Drain the whole cache as this way we can use the ordering of the LinkedHashMap to detect early + // if there are any other sessions left that are invalid. + clear(); + } + } + + final int getSessionTimeout() { + return sessionTimeout.get(); + } + + /** + * Called once a new {@link OpenSslSession} was created. + * + * @param session the new session. + * @return {@code true} if the session should be cached, {@code false} otherwise. + */ + protected boolean sessionCreated(NativeSslSession session) { + return true; + } + + /** + * Called once an {@link OpenSslSession} was removed from the cache. + * + * @param session the session to remove. + */ + protected void sessionRemoved(NativeSslSession session) { } + + final void setSessionCacheSize(int size) { + long oldSize = maximumCacheSize.getAndSet(size); + if (oldSize > size || size == 0) { + // Just keep it simple for now and drain the whole cache. + clear(); + } + } + + final int getSessionCacheSize() { + return maximumCacheSize.get(); + } + + private void expungeInvalidSessions() { + if (sessions.isEmpty()) { + return; + } + long now = System.currentTimeMillis(); + Iterator> iterator = sessions.entrySet().iterator(); + while (iterator.hasNext()) { + NativeSslSession session = iterator.next().getValue(); + // As we use a LinkedHashMap we can break the while loop as soon as we find a valid session. + // This is true as we always drain the cache as soon as we change the timeout to a smaller value as + // it was set before. This way its true that the insertion order matches the timeout order. 
+ if (session.isValid(now)) { + break; + } + iterator.remove(); + + notifyRemovalAndFree(session); + } + } + + @Override + public final boolean sessionCreated(long ssl, long sslSession) { + ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); + if (engine == null) { + // We couldn't find the engine itself. + return false; + } + NativeSslSession session = new NativeSslSession(sslSession, engine.getPeerHost(), engine.getPeerPort(), + getSessionTimeout() * 1000L); + engine.setSessionId(session.sessionId()); + synchronized (this) { + // Mimic what OpenSSL is doing and expunge every 255 new sessions + // See https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_flush_sessions.html + if (++sessionCounter == 255) { + sessionCounter = 0; + expungeInvalidSessions(); + } + + if (!sessionCreated(session)) { + // Should not be cached, return false. In this case we also need to call close() to ensure we + // close the ResourceLeakTracker. + session.close(); + return false; + } + + final NativeSslSession old = sessions.put(session.sessionId(), session); + if (old != null) { + notifyRemovalAndFree(old); + } + } + return true; + } + + @Override + public final long getSession(long ssl, byte[] sessionId) { + OpenSslSessionId id = new OpenSslSessionId(sessionId); + final NativeSslSession session; + synchronized (this) { + session = sessions.get(id); + if (session == null) { + return -1; + } + + // If the session is not valid anymore we should remove it from the cache and just signal back + // that we couldn't find a session that is re-usable. + if (!session.isValid() || + // This needs to happen in the synchronized block so we ensure we never destroy it before we + // incremented the reference count. If we cant increment the reference count there is something + // wrong. In this case just remove the session from the cache and signal back that we couldn't + // find a session for re-use. + !session.upRef()) { + // Remove the session from the cache. 
This will also take care of calling SSL_SESSION_free(...) + removeSessionWithId(session.sessionId()); + return -1; + } + + // At this point we already incremented the reference count via SSL_SESSION_up_ref(...). + if (session.shouldBeSingleUse()) { + // Should only be used once. In this case invalidate the session which will also ensure we remove it + // from the cache and call SSL_SESSION_free(...). + removeSessionWithId(session.sessionId()); + } + } + session.updateLastAccessedTime(); + return session.session(); + } + + void setSession(long ssl, String host, int port) { + // Do nothing by default as this needs special handling for the client side. + } + + /** + * Remove the session with the given id from the cache + */ + final synchronized void removeSessionWithId(OpenSslSessionId id) { + NativeSslSession sslSession = sessions.remove(id); + if (sslSession != null) { + notifyRemovalAndFree(sslSession); + } + } + + /** + * Returns {@code true} if there is a session for the given id in the cache. + */ + final synchronized boolean containsSessionWithId(OpenSslSessionId id) { + return sessions.containsKey(id); + } + + private void notifyRemovalAndFree(NativeSslSession session) { + sessionRemoved(session); + session.free(); + } + + /** + * Return the {@link OpenSslSession} which is cached for the given id. + */ + final synchronized OpenSslSession getSession(OpenSslSessionId id) { + NativeSslSession session = sessions.get(id); + if (session != null && !session.isValid()) { + // The session is not valid anymore, let's remove it and just signal back that there is no session + // with the given ID in the cache anymore. This also takes care of calling SSL_SESSION_free(...) + removeSessionWithId(session.sessionId()); + return null; + } + return session; + } + + /** + * Returns a snapshot of the session ids of the current valid sessions. 
+ */ + final List getIds() { + final OpenSslSession[] sessionsArray; + synchronized (this) { + sessionsArray = sessions.values().toArray(EMPTY_SESSIONS); + } + List ids = new ArrayList(sessionsArray.length); + for (OpenSslSession session: sessionsArray) { + if (session.isValid()) { + ids.add(session.sessionId()); + } + } + return ids; + } + + /** + * Clear the cache and free all cached SSL_SESSION*. + */ + synchronized void clear() { + Iterator> iterator = sessions.entrySet().iterator(); + while (iterator.hasNext()) { + NativeSslSession session = iterator.next().getValue(); + iterator.remove(); + + // Notify about removal. This also takes care of calling SSL_SESSION_free(...). + notifyRemovalAndFree(session); + } + } + + /** + * {@link OpenSslSession} implementation which wraps the native SSL_SESSION* while in cache. + */ + static final class NativeSslSession implements OpenSslSession { + static final ResourceLeakDetector LEAK_DETECTOR = ResourceLeakDetectorFactory.instance() + .newResourceLeakDetector(NativeSslSession.class); + private final ResourceLeakTracker leakTracker; + private final long session; + private final String peerHost; + private final int peerPort; + private final OpenSslSessionId id; + private final long timeout; + private final long creationTime = System.currentTimeMillis(); + private volatile long lastAccessedTime = creationTime; + private volatile boolean valid = true; + private boolean freed; + + NativeSslSession(long session, String peerHost, int peerPort, long timeout) { + this.session = session; + this.peerHost = peerHost; + this.peerPort = peerPort; + this.timeout = timeout; + this.id = new OpenSslSessionId(io.netty.internal.tcnative.SSLSession.getSessionId(session)); + leakTracker = LEAK_DETECTOR.track(this); + } + + @Override + public void setSessionId(OpenSslSessionId id) { + throw new UnsupportedOperationException(); + } + + boolean shouldBeSingleUse() { + assert !freed; + return SSLSession.shouldBeSingleUse(session); + } + + long 
session() { + assert !freed; + return session; + } + + boolean upRef() { + assert !freed; + return SSLSession.upRef(session); + } + + synchronized void free() { + close(); + SSLSession.free(session); + } + + void close() { + assert !freed; + freed = true; + invalidate(); + if (leakTracker != null) { + leakTracker.close(this); + } + } + + @Override + public OpenSslSessionId sessionId() { + return id; + } + + boolean isValid(long now) { + return creationTime + timeout >= now && valid; + } + + @Override + public void setLocalCertificate(Certificate[] localCertificate) { + throw new UnsupportedOperationException(); + } + + @Override + public OpenSslSessionContext getSessionContext() { + return null; + } + + @Override + public void tryExpandApplicationBufferSize(int packetLengthDataOnly) { + throw new UnsupportedOperationException(); + } + + @Override + public void handshakeFinished(byte[] id, String cipher, String protocol, byte[] peerCertificate, + byte[][] peerCertificateChain, long creationTime, long timeout) { + throw new UnsupportedOperationException(); + } + + @Override + public byte[] getId() { + return id.cloneBytes(); + } + + @Override + public long getCreationTime() { + return creationTime; + } + + void updateLastAccessedTime() { + lastAccessedTime = System.currentTimeMillis(); + } + + @Override + public long getLastAccessedTime() { + return lastAccessedTime; + } + + @Override + public void invalidate() { + valid = false; + } + + @Override + public boolean isValid() { + return isValid(System.currentTimeMillis()); + } + + @Override + public void putValue(String name, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public Object getValue(String name) { + return null; + } + + @Override + public void removeValue(String name) { + // NOOP + } + + @Override + public String[] getValueNames() { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public Certificate[] getPeerCertificates() { + throw new 
UnsupportedOperationException(); + } + + @Override + public Certificate[] getLocalCertificates() { + throw new UnsupportedOperationException(); + } + + @Override + public X509Certificate[] getPeerCertificateChain() { + throw new UnsupportedOperationException(); + } + + @Override + public Principal getPeerPrincipal() { + throw new UnsupportedOperationException(); + } + + @Override + public Principal getLocalPrincipal() { + throw new UnsupportedOperationException(); + } + + @Override + public String getCipherSuite() { + return null; + } + + @Override + public String getProtocol() { + return null; + } + + @Override + public String getPeerHost() { + return peerHost; + } + + @Override + public int getPeerPort() { + return peerPort; + } + + @Override + public int getPacketBufferSize() { + return ReferenceCountedOpenSslEngine.MAX_RECORD_SIZE; + } + + @Override + public int getApplicationBufferSize() { + return ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH; + } + + @Override + public int hashCode() { + return id.hashCode(); + } + + @Override + public boolean equals(Object o) { + if (this == o) { + return true; + } + if (!(o instanceof OpenSslSession)) { + return false; + } + OpenSslSession session1 = (OpenSslSession) o; + return id.equals(session1.sessionId()); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java new file mode 100644 index 0000000..0da26c9 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionContext.java @@ -0,0 +1,229 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.handler.ssl;

import io.netty.internal.tcnative.SSL;
import io.netty.internal.tcnative.SSLContext;
import io.netty.internal.tcnative.SessionTicketKey;
import io.netty.util.internal.ObjectUtil;

import javax.net.ssl.SSLSession;
import javax.net.ssl.SSLSessionContext;
import java.util.Arrays;
import java.util.Enumeration;
import java.util.Iterator;
import java.util.concurrent.locks.Lock;

/**
 * OpenSSL specific {@link SSLSessionContext} implementation.
 */
public abstract class OpenSslSessionContext implements SSLSessionContext {

    private final OpenSslSessionStats stats;

    // The OpenSslKeyMaterialProvider is not really used by the OpenSslSessionContext but only be stored here
    // to make it easier to destroy it later because the ReferenceCountedOpenSslContext will hold a reference
    // to OpenSslSessionContext.
    private final OpenSslKeyMaterialProvider provider;

    final ReferenceCountedOpenSslContext context;

    private final OpenSslSessionCache sessionCache;
    private final long mask;

    // IMPORTANT: We take the OpenSslContext and not just the long (which points the native instance) to prevent
    //            the GC to collect OpenSslContext as this would also free the pointer and so could result in a
    //            segfault when the user calls any of the methods here that try to pass the pointer down to the
    //            native level.
    OpenSslSessionContext(ReferenceCountedOpenSslContext context, OpenSslKeyMaterialProvider provider, long mask,
                          OpenSslSessionCache cache) {
        this.context = context;
        this.provider = provider;
        this.mask = mask;
        stats = new OpenSslSessionStats(context);
        sessionCache = cache;
        SSLContext.setSSLSessionCache(context.ctx, cache);
    }

    /**
     * Returns {@code true} if this context was constructed with an {@link OpenSslKeyMaterialProvider}.
     */
    final boolean useKeyManager() {
        return provider != null;
    }

    @Override
    public void setSessionCacheSize(int size) {
        ObjectUtil.checkPositiveOrZero(size, "size");
        sessionCache.setSessionCacheSize(size);
    }

    @Override
    public int getSessionCacheSize() {
        return sessionCache.getSessionCacheSize();
    }

    @Override
    public void setSessionTimeout(int seconds) {
        ObjectUtil.checkPositiveOrZero(seconds, "seconds");

        // Propagate the timeout to the native context as well. The write lock guards the native pointer
        // against concurrent release.
        Lock writerLock = context.ctxLock.writeLock();
        writerLock.lock();
        try {
            SSLContext.setSessionCacheTimeout(context.ctx, seconds);
            sessionCache.setSessionTimeout(seconds);
        } finally {
            writerLock.unlock();
        }
    }

    @Override
    public int getSessionTimeout() {
        return sessionCache.getSessionTimeout();
    }

    @Override
    public SSLSession getSession(byte[] bytes) {
        return sessionCache.getSession(new OpenSslSessionId(bytes));
    }

    @Override
    public Enumeration<byte[]> getIds() {
        return new Enumeration<byte[]>() {
            private final Iterator<OpenSslSessionId> ids = sessionCache.getIds().iterator();

            @Override
            public boolean hasMoreElements() {
                return ids.hasNext();
            }

            @Override
            public byte[] nextElement() {
                return ids.next().cloneBytes();
            }
        };
    }

    /**
     * Sets the SSL session ticket keys of this context.
     *
     * @deprecated use {@link #setTicketKeys(OpenSslSessionTicketKey...)}.
     */
    @Deprecated
    public void setTicketKeys(byte[] keys) {
        if (keys.length % SessionTicketKey.TICKET_KEY_SIZE != 0) {
            throw new IllegalArgumentException("keys.length % " + SessionTicketKey.TICKET_KEY_SIZE + " != 0");
        }
        SessionTicketKey[] tickets = new SessionTicketKey[keys.length / SessionTicketKey.TICKET_KEY_SIZE];
        for (int i = 0, a = 0; i < tickets.length; i++) {
            // Arrays.copyOfRange takes an *exclusive end index*, so every slice must end at
            // "offset + length". The previous code passed the bare length constants as the end index
            // (correct only for the very first slice) and additionally advanced the loop counter "i"
            // instead of the byte offset "a" after the HMAC slice, skipping ticket slots entirely.
            byte[] name = Arrays.copyOfRange(keys, a, a + SessionTicketKey.NAME_SIZE);
            a += SessionTicketKey.NAME_SIZE;
            byte[] hmacKey = Arrays.copyOfRange(keys, a, a + SessionTicketKey.HMAC_KEY_SIZE);
            a += SessionTicketKey.HMAC_KEY_SIZE;
            byte[] aesKey = Arrays.copyOfRange(keys, a, a + SessionTicketKey.AES_KEY_SIZE);
            a += SessionTicketKey.AES_KEY_SIZE;
            tickets[i] = new SessionTicketKey(name, hmacKey, aesKey);
        }
        Lock writerLock = context.ctxLock.writeLock();
        writerLock.lock();
        try {
            SSLContext.clearOptions(context.ctx, SSL.SSL_OP_NO_TICKET);
            SSLContext.setSessionTicketKeys(context.ctx, tickets);
        } finally {
            writerLock.unlock();
        }
    }

    /**
     * Sets the SSL session ticket keys of this context. Depending on the underlying native library you may omit the
     * argument or pass an empty array and so let the native library handle the key generation and rotating for you.
     * If this is supported by the underlying native library should be checked in this case. For example
     * BoringSSL is known to support this.
     */
    public void setTicketKeys(OpenSslSessionTicketKey... keys) {
        ObjectUtil.checkNotNull(keys, "keys");
        SessionTicketKey[] ticketKeys = new SessionTicketKey[keys.length];
        for (int i = 0; i < ticketKeys.length; i++) {
            ticketKeys[i] = keys[i].key;
        }
        Lock writerLock = context.ctxLock.writeLock();
        writerLock.lock();
        try {
            SSLContext.clearOptions(context.ctx, SSL.SSL_OP_NO_TICKET);
            if (ticketKeys.length > 0) {
                // An empty array means "let the native library generate and rotate keys itself".
                SSLContext.setSessionTicketKeys(context.ctx, ticketKeys);
            }
        } finally {
            writerLock.unlock();
        }
    }

    /**
     * Enable or disable caching of SSL sessions.
     */
    public void setSessionCacheEnabled(boolean enabled) {
        long mode = enabled ? mask | SSL.SSL_SESS_CACHE_NO_INTERNAL_LOOKUP |
                SSL.SSL_SESS_CACHE_NO_INTERNAL_STORE : SSL.SSL_SESS_CACHE_OFF;
        Lock writerLock = context.ctxLock.writeLock();
        writerLock.lock();
        try {
            SSLContext.setSessionCacheMode(context.ctx, mode);
            if (!enabled) {
                // Sessions kept while caching was on must not be re-used once caching is disabled.
                sessionCache.clear();
            }
        } finally {
            writerLock.unlock();
        }
    }

    /**
     * Return {@code true} if caching of SSL sessions is enabled, {@code false} otherwise.
     */
    public boolean isSessionCacheEnabled() {
        Lock readerLock = context.ctxLock.readLock();
        readerLock.lock();
        try {
            return (SSLContext.getSessionCacheMode(context.ctx) & mask) != 0;
        } finally {
            readerLock.unlock();
        }
    }

    /**
     * Returns the stats of this context.
     */
    public OpenSslSessionStats stats() {
        return stats;
    }

    /**
     * Remove the given {@link OpenSslSession} from the cache, and so not re-use it for new connections.
     */
    final void removeFromCache(OpenSslSessionId id) {
        sessionCache.removeSessionWithId(id);
    }

    final boolean isInCache(OpenSslSessionId id) {
        return sessionCache.containsSessionWithId(id);
    }

    void setSessionFromCache(String host, int port, long ssl) {
        sessionCache.setSession(ssl, host, port);
    }

    /**
     * Releases the key material provider (if any) and drops all cached sessions.
     */
    final void destroy() {
        if (provider != null) {
            provider.destroy();
        }
        sessionCache.clear();
    }
}
package io.netty.handler.ssl;

import io.netty.util.internal.EmptyArrays;

import java.util.Arrays;

/**
 * Represent the session ID used by an {@link OpenSslSession}.
 */
final class OpenSslSessionId {

    private final byte[] id;
    private final int hashCode;

    static final OpenSslSessionId NULL_ID = new OpenSslSessionId(EmptyArrays.EMPTY_BYTES);

    OpenSslSessionId(byte[] id) {
        // Ownership of the array is transferred to this instance, so no defensive copy is needed.
        this.id = id;
        // The array never changes after construction, so the hash can be computed once up front.
        this.hashCode = Arrays.hashCode(id);
    }

    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (o instanceof OpenSslSessionId) {
            return Arrays.equals(id, ((OpenSslSessionId) o).id);
        }
        return false;
    }

    @Override
    public String toString() {
        return "OpenSslSessionId{id=" + Arrays.toString(id) + '}';
    }

    @Override
    public int hashCode() {
        return hashCode;
    }

    /**
     * Returns a fresh copy of the raw session id bytes.
     */
    byte[] cloneBytes() {
        return id.clone();
    }
}
+ * + * @see SSL_CTX_sess_number + */ +public final class OpenSslSessionStats { + + private final ReferenceCountedOpenSslContext context; + + // IMPORTANT: We take the OpenSslContext and not just the long (which points the native instance) to prevent + // the GC to collect OpenSslContext as this would also free the pointer and so could result in a + // segfault when the user calls any of the methods here that try to pass the pointer down to the native + // level. + OpenSslSessionStats(ReferenceCountedOpenSslContext context) { + this.context = context; + } + + /** + * Returns the current number of sessions in the internal session cache. + */ + public long number() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionNumber(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of started SSL/TLS handshakes in client mode. + */ + public long connect() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionConnect(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of successfully established SSL/TLS sessions in client mode. + */ + public long connectGood() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionConnectGood(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of start renegotiations in client mode. + */ + public long connectRenegotiate() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionConnectRenegotiate(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of started SSL/TLS handshakes in server mode. 
+ */ + public long accept() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionAccept(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of successfully established SSL/TLS sessions in server mode. + */ + public long acceptGood() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionAcceptGood(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of start renegotiations in server mode. + */ + public long acceptRenegotiate() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionAcceptRenegotiate(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of successfully reused sessions. In client mode, a session set with {@code SSL_set_session} + * successfully reused is counted as a hit. In server mode, a session successfully retrieved from internal or + * external cache is counted as a hit. + */ + public long hits() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionHits(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of successfully retrieved sessions from the external session cache in server mode. + */ + public long cbHits() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionCbHits(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of sessions proposed by clients that were not found in the internal session cache + * in server mode. 
+ */ + public long misses() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionMisses(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of sessions proposed by clients and either found in the internal or external session cache + * in server mode, but that were invalid due to timeout. These sessions are not included in the {@link #hits()} + * count. + */ + public long timeouts() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionTimeouts(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of sessions that were removed because the maximum session cache size was exceeded. + */ + public long cacheFull() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionCacheFull(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of times a client presented a ticket that did not match any key in the list. + */ + public long ticketKeyFail() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionTicketKeyFail(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of times a client did not present a ticket and we issued a new one + */ + public long ticketKeyNew() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionTicketKeyNew(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of times a client presented a ticket derived from an older key, + * and we upgraded to the primary key. 
+ */ + public long ticketKeyRenew() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionTicketKeyRenew(context.ctx); + } finally { + readerLock.unlock(); + } + } + + /** + * Returns the number of times a client presented a ticket derived from the primary key. + */ + public long ticketKeyResume() { + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + try { + return SSLContext.sessionTicketKeyResume(context.ctx); + } finally { + readerLock.unlock(); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionTicketKey.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionTicketKey.java new file mode 100644 index 0000000..175a37e --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OpenSslSessionTicketKey.java @@ -0,0 +1,78 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.handler.ssl;

import io.netty.internal.tcnative.SessionTicketKey;

/**
 * Session Ticket Key.
 *
 * Immutable wrapper around the tcnative {@link SessionTicketKey}; all byte arrays are defensively
 * copied on the way in and out so callers can never mutate the key material held by this instance.
 */
public final class OpenSslSessionTicketKey {

    /**
     * Size of the session ticket key name.
     */
    public static final int NAME_SIZE = SessionTicketKey.NAME_SIZE;
    /**
     * Size of the session ticket key HMAC key.
     */
    public static final int HMAC_KEY_SIZE = SessionTicketKey.HMAC_KEY_SIZE;
    /**
     * Size of the session ticket key AES key.
     */
    public static final int AES_KEY_SIZE = SessionTicketKey.AES_KEY_SIZE;
    /**
     * Total size of a session ticket key (name + HMAC key + AES key).
     */
    public static final int TICKET_KEY_SIZE = SessionTicketKey.TICKET_KEY_SIZE;

    // Package-private so OpenSslSessionContext can pass the wrapped key straight to the native layer.
    final SessionTicketKey key;

    /**
     * Construct a OpenSslSessionTicketKey.
     *
     * @param name the name of the session ticket key
     * @param hmacKey the HMAC key of the session ticket key
     * @param aesKey the AES key of the session ticket key
     */
    public OpenSslSessionTicketKey(byte[] name, byte[] hmacKey, byte[] aesKey) {
        // Clone each array so later mutation by the caller cannot change the key material.
        key = new SessionTicketKey(name.clone(), hmacKey.clone(), aesKey.clone());
    }

    /**
     * Get name.
     * @return a copy of the name of the session ticket key
     */
    public byte[] name() {
        return key.getName().clone();
    }

    /**
     * Get HMAC key.
     * @return a copy of the HMAC key of the session ticket key
     */
    public byte[] hmacKey() {
        return key.getHmacKey().clone();
    }

    /**
     * Get AES Key.
     * @return a copy of the AES key of the session ticket key
     */
    public byte[] aesKey() {
        return key.getAesKey().clone();
    }
}
package io.netty.handler.ssl;

import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.internal.tcnative.SSL;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.internal.ObjectUtil;

import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.KeyManagerFactorySpi;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.X509KeyManager;
import java.io.File;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.security.InvalidAlgorithmParameterException;
import java.security.Key;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.KeyStoreSpi;
import java.security.NoSuchAlgorithmException;
import java.security.PrivateKey;
import java.security.Provider;
import java.security.UnrecoverableKeyException;
import java.security.cert.Certificate;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.Collections;
import java.util.Date;
import java.util.Enumeration;
import java.util.HashMap;
import java.util.Map;

/**
 * Special {@link KeyManagerFactory} that pre-computes the keymaterial used when {@link SslProvider#OPENSSL} or
 * {@link SslProvider#OPENSSL_REFCNT} is used and so will improve handshake times and its performance.
 *
 * Because the keymaterial is pre-computed any modification to the {@link KeyStore} is ignored after
 * {@link #init(KeyStore, char[])} is called.
 *
 * {@link #init(ManagerFactoryParameters)} is not supported by this implementation and so a call to it will always
 * result in an {@link InvalidAlgorithmParameterException}.
 */
public final class OpenSslX509KeyManagerFactory extends KeyManagerFactory {

    private final OpenSslKeyManagerFactorySpi spi;

    public OpenSslX509KeyManagerFactory() {
        this(newOpenSslKeyManagerFactorySpi(null));
    }

    public OpenSslX509KeyManagerFactory(Provider provider) {
        this(newOpenSslKeyManagerFactorySpi(provider));
    }

    public OpenSslX509KeyManagerFactory(String algorithm, Provider provider) throws NoSuchAlgorithmException {
        this(newOpenSslKeyManagerFactorySpi(algorithm, provider));
    }

    private OpenSslX509KeyManagerFactory(OpenSslKeyManagerFactorySpi spi) {
        super(spi, spi.kmf.getProvider(), spi.kmf.getAlgorithm());
        this.spi = spi;
    }

    private static OpenSslKeyManagerFactorySpi newOpenSslKeyManagerFactorySpi(Provider provider) {
        try {
            return newOpenSslKeyManagerFactorySpi(null, provider);
        } catch (NoSuchAlgorithmException e) {
            // This should never happen as we use the default algorithm.
            throw new IllegalStateException(e);
        }
    }

    private static OpenSslKeyManagerFactorySpi newOpenSslKeyManagerFactorySpi(String algorithm, Provider provider)
            throws NoSuchAlgorithmException {
        if (algorithm == null) {
            algorithm = KeyManagerFactory.getDefaultAlgorithm();
        }
        return new OpenSslKeyManagerFactorySpi(
                provider == null ? KeyManagerFactory.getInstance(algorithm) :
                        KeyManagerFactory.getInstance(algorithm, provider));
    }

    OpenSslKeyMaterialProvider newProvider() {
        return spi.newProvider();
    }

    private static final class OpenSslKeyManagerFactorySpi extends KeyManagerFactorySpi {
        final KeyManagerFactory kmf;
        // Written once in engineInit(...), read unsynchronized afterwards.
        private volatile ProviderFactory providerFactory;

        OpenSslKeyManagerFactorySpi(KeyManagerFactory kmf) {
            this.kmf = ObjectUtil.checkNotNull(kmf, "kmf");
        }

        @Override
        protected synchronized void engineInit(KeyStore keyStore, char[] chars)
                throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
            if (providerFactory != null) {
                throw new KeyStoreException("Already initialized");
            }
            if (!keyStore.aliases().hasMoreElements()) {
                throw new KeyStoreException("No aliases found");
            }

            kmf.init(keyStore, chars);
            providerFactory = new ProviderFactory(ReferenceCountedOpenSslContext.chooseX509KeyManager(
                    kmf.getKeyManagers()), password(chars), Collections.list(keyStore.aliases()));
        }

        private static String password(char[] password) {
            if (password == null || password.length == 0) {
                return null;
            }
            return new String(password);
        }

        @Override
        protected void engineInit(ManagerFactoryParameters managerFactoryParameters)
                throws InvalidAlgorithmParameterException {
            throw new InvalidAlgorithmParameterException("Not supported");
        }

        @Override
        protected KeyManager[] engineGetKeyManagers() {
            ProviderFactory providerFactory = this.providerFactory;
            if (providerFactory == null) {
                throw new IllegalStateException("engineInit(...) not called yet");
            }
            return new KeyManager[] { providerFactory.keyManager };
        }

        OpenSslKeyMaterialProvider newProvider() {
            ProviderFactory providerFactory = this.providerFactory;
            if (providerFactory == null) {
                throw new IllegalStateException("engineInit(...) not called yet");
            }
            return providerFactory.newProvider();
        }

        private static final class ProviderFactory {
            private final X509KeyManager keyManager;
            private final String password;
            private final Iterable<String> aliases;

            ProviderFactory(X509KeyManager keyManager, String password, Iterable<String> aliases) {
                this.keyManager = keyManager;
                this.password = password;
                this.aliases = aliases;
            }

            OpenSslKeyMaterialProvider newProvider() {
                return new OpenSslPopulatedKeyMaterialProvider(keyManager,
                        password, aliases);
            }

            /**
             * {@link OpenSslKeyMaterialProvider} implementation that pre-compute the {@link OpenSslKeyMaterial} for
             * all aliases.
             */
            private static final class OpenSslPopulatedKeyMaterialProvider extends OpenSslKeyMaterialProvider {
                // Values are either OpenSslKeyMaterial (success) or Exception (deferred failure for that alias).
                private final Map<String, Object> materialMap;

                OpenSslPopulatedKeyMaterialProvider(
                        X509KeyManager keyManager, String password, Iterable<String> aliases) {
                    super(keyManager, password);
                    materialMap = new HashMap<String, Object>();
                    boolean initComplete = false;
                    try {
                        for (String alias: aliases) {
                            if (alias != null && !materialMap.containsKey(alias)) {
                                try {
                                    materialMap.put(alias, super.chooseKeyMaterial(
                                            UnpooledByteBufAllocator.DEFAULT, alias));
                                } catch (Exception e) {
                                    // Just store the exception and rethrow it when we try to choose the keymaterial
                                    // for this alias later on.
                                    materialMap.put(alias, e);
                                }
                            }
                        }
                        initComplete = true;
                    } finally {
                        if (!initComplete) {
                            // Release any key material we already retained before propagating the failure.
                            destroy();
                        }
                    }
                    checkNonEmpty(materialMap, "materialMap");
                }

                @Override
                OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception {
                    Object value = materialMap.get(alias);
                    if (value == null) {
                        // There is no keymaterial for the requested alias, return null
                        return null;
                    }
                    if (value instanceof OpenSslKeyMaterial) {
                        return ((OpenSslKeyMaterial) value).retain();
                    }
                    // The stored value is the Exception captured during pre-computation for this alias.
                    throw (Exception) value;
                }

                @Override
                void destroy() {
                    for (Object material: materialMap.values()) {
                        ReferenceCountUtil.release(material);
                    }
                    materialMap.clear();
                }
            }
        }
    }

    /**
     * Create a new initialized {@link OpenSslX509KeyManagerFactory} which loads its {@link PrivateKey} directly from
     * an {@code OpenSSL engine} via the {@code ENGINE_load_private_key} function.
     */
    public static OpenSslX509KeyManagerFactory newEngineBased(File certificateChain, String password)
            throws CertificateException, IOException,
                   KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
        return newEngineBased(SslContext.toX509Certificates(certificateChain), password);
    }

    /**
     * Create a new initialized {@link OpenSslX509KeyManagerFactory} which loads its {@link PrivateKey} directly from
     * an {@code OpenSSL engine} via the {@code ENGINE_load_private_key} function.
     */
    public static OpenSslX509KeyManagerFactory newEngineBased(X509Certificate[] certificateChain, String password)
            throws CertificateException, IOException,
                   KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
        checkNotNull(certificateChain, "certificateChain");
        KeyStore store = new OpenSslKeyStore(certificateChain.clone(), false);
        store.load(null, null);
        OpenSslX509KeyManagerFactory factory = new OpenSslX509KeyManagerFactory();
        factory.init(store, password == null ? null : password.toCharArray());
        return factory;
    }

    /**
     * See {@link OpenSslX509KeyManagerFactory#newEngineBased(X509Certificate[], String)}.
     */
    public static OpenSslX509KeyManagerFactory newKeyless(File chain)
            throws CertificateException, IOException,
                   KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
        return newKeyless(SslContext.toX509Certificates(chain));
    }

    /**
     * See {@link OpenSslX509KeyManagerFactory#newEngineBased(X509Certificate[], String)}.
     */
    public static OpenSslX509KeyManagerFactory newKeyless(InputStream chain)
            throws CertificateException, IOException,
                   KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
        return newKeyless(SslContext.toX509Certificates(chain));
    }

    /**
     * Returns a new initialized {@link OpenSslX509KeyManagerFactory} which will provide its private key by using the
     * {@link OpenSslPrivateKeyMethod}.
     */
    public static OpenSslX509KeyManagerFactory newKeyless(X509Certificate... certificateChain)
            throws CertificateException, IOException,
                   KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
        checkNotNull(certificateChain, "certificateChain");
        KeyStore store = new OpenSslKeyStore(certificateChain.clone(), true);
        store.load(null, null);
        OpenSslX509KeyManagerFactory factory = new OpenSslX509KeyManagerFactory();
        factory.init(store, null);
        return factory;
    }

    /**
     * Minimal single-alias {@link KeyStore} that exposes the given certificate chain under
     * {@link SslContext#ALIAS} and (unless keyless) loads the private key from an OpenSSL engine.
     */
    private static final class OpenSslKeyStore extends KeyStore {
        private OpenSslKeyStore(final X509Certificate[] certificateChain, final boolean keyless) {
            super(new KeyStoreSpi() {

                private final Date creationDate = new Date();

                @Override
                public Key engineGetKey(String alias, char[] password) throws UnrecoverableKeyException {
                    if (engineContainsAlias(alias)) {
                        final long privateKeyAddress;
                        if (keyless) {
                            // Keyless mode: the private key operations are delegated via
                            // OpenSslPrivateKeyMethod, so there is no native key to load.
                            privateKeyAddress = 0;
                        } else {
                            try {
                                privateKeyAddress = SSL.loadPrivateKeyFromEngine(
                                        alias, password == null ? null : new String(password));
                            } catch (Exception e) {
                                UnrecoverableKeyException keyException =
                                        new UnrecoverableKeyException("Unable to load key from engine");
                                keyException.initCause(e);
                                throw keyException;
                            }
                        }
                        return new OpenSslPrivateKey(privateKeyAddress);
                    }
                    return null;
                }

                @Override
                public Certificate[] engineGetCertificateChain(String alias) {
                    return engineContainsAlias(alias)? certificateChain.clone() : null;
                }

                @Override
                public Certificate engineGetCertificate(String alias) {
                    return engineContainsAlias(alias)? certificateChain[0] : null;
                }

                @Override
                public Date engineGetCreationDate(String alias) {
                    return engineContainsAlias(alias)? creationDate : null;
                }

                @Override
                public void engineSetKeyEntry(String alias, Key key, char[] password, Certificate[] chain)
                        throws KeyStoreException {
                    throw new KeyStoreException("Not supported");
                }

                @Override
                public void engineSetKeyEntry(String alias, byte[] key, Certificate[] chain) throws KeyStoreException {
                    throw new KeyStoreException("Not supported");
                }

                @Override
                public void engineSetCertificateEntry(String alias, Certificate cert) throws KeyStoreException {
                    throw new KeyStoreException("Not supported");
                }

                @Override
                public void engineDeleteEntry(String alias) throws KeyStoreException {
                    throw new KeyStoreException("Not supported");
                }

                @Override
                public Enumeration<String> engineAliases() {
                    return Collections.enumeration(Collections.singleton(SslContext.ALIAS));
                }

                @Override
                public boolean engineContainsAlias(String alias) {
                    return SslContext.ALIAS.equals(alias);
                }

                @Override
                public int engineSize() {
                    return 1;
                }

                @Override
                public boolean engineIsKeyEntry(String alias) {
                    return engineContainsAlias(alias);
                }

                @Override
                public boolean engineIsCertificateEntry(String alias) {
                    return engineContainsAlias(alias);
                }

                @Override
                public String engineGetCertificateAlias(Certificate cert) {
                    if (cert instanceof X509Certificate) {
                        for (X509Certificate x509Certificate : certificateChain) {
                            if (x509Certificate.equals(cert)) {
                                return SslContext.ALIAS;
                            }
                        }
                    }
                    return null;
                }

                @Override
                public void engineStore(OutputStream stream, char[] password) {
                    throw new UnsupportedOperationException();
                }

                @Override
                public void engineLoad(InputStream stream, char[] password) {
                    // load(null, null) is the only supported way to initialize this store.
                    if (stream != null && password != null) {
                        throw new UnsupportedOperationException();
                    }
                }
            }, null, "native");

            OpenSsl.ensureAvailability();
        }
    }
}
+ */ +package io.netty.handler.ssl; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import java.lang.reflect.Field; +import java.security.AccessController; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.PrivilegedAction; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +/** + * Utility which allows to wrap {@link X509TrustManager} implementations with the internal implementation used by + * {@code SSLContextImpl} that provides extended verification. + * + * This is really a "hack" until there is an official API as requested on the in + * JDK-8210843. + */ +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class OpenSslX509TrustManagerWrapper { + private static final InternalLogger LOGGER = InternalLoggerFactory + .getInstance(OpenSslX509TrustManagerWrapper.class); + private static final TrustManagerWrapper WRAPPER; + + static { + // By default we will not do any wrapping but just return the passed in manager. + TrustManagerWrapper wrapper = new TrustManagerWrapper() { + @Override + public X509TrustManager wrapIfNeeded(X509TrustManager manager) { + return manager; + } + }; + + Throwable cause = null; + Throwable unsafeCause = PlatformDependent.getUnsafeUnavailabilityCause(); + if (unsafeCause == null) { + SSLContext context; + try { + context = newSSLContext(); + // Now init with an array that only holds a X509TrustManager. 
This should be wrapped into an + // AbstractTrustManagerWrapper which will delegate the TrustManager itself but also do extra + // validations. + // + // See: + // - https://hg.openjdk.java.net/jdk8u/jdk8u/jdk/file/ + // cadea780bc76/src/share/classes/sun/security/ssl/SSLContextImpl.java#l127 + context.init(null, new TrustManager[] { + new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + throw new CertificateException(); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + throw new CertificateException(); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } + }, null); + } catch (Throwable error) { + context = null; + cause = error; + } + if (cause != null) { + LOGGER.debug("Unable to access wrapped TrustManager", cause); + } else { + final SSLContext finalContext = context; + Object maybeWrapper = AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Object run() { + try { + Field contextSpiField = SSLContext.class.getDeclaredField("contextSpi"); + final long spiOffset = PlatformDependent.objectFieldOffset(contextSpiField); + Object spi = PlatformDependent.getObject(finalContext, spiOffset); + if (spi != null) { + Class clazz = spi.getClass(); + + // Let's cycle through the whole hierarchy until we find what we are looking for or + // there is nothing left in which case we will not wrap at all. 
+ do { + try { + Field trustManagerField = clazz.getDeclaredField("trustManager"); + final long tmOffset = PlatformDependent.objectFieldOffset(trustManagerField); + Object trustManager = PlatformDependent.getObject(spi, tmOffset); + if (trustManager instanceof X509ExtendedTrustManager) { + return new UnsafeTrustManagerWrapper(spiOffset, tmOffset); + } + } catch (NoSuchFieldException ignore) { + // try next + } + clazz = clazz.getSuperclass(); + } while (clazz != null); + } + throw new NoSuchFieldException(); + } catch (NoSuchFieldException e) { + return e; + } catch (SecurityException e) { + return e; + } + } + }); + if (maybeWrapper instanceof Throwable) { + LOGGER.debug("Unable to access wrapped TrustManager", (Throwable) maybeWrapper); + } else { + wrapper = (TrustManagerWrapper) maybeWrapper; + } + } + } else { + LOGGER.debug("Unable to access wrapped TrustManager", cause); + } + WRAPPER = wrapper; + } + + private OpenSslX509TrustManagerWrapper() { } + + static X509TrustManager wrapIfNeeded(X509TrustManager trustManager) { + return WRAPPER.wrapIfNeeded(trustManager); + } + + private interface TrustManagerWrapper { + X509TrustManager wrapIfNeeded(X509TrustManager manager); + } + + private static SSLContext newSSLContext() throws NoSuchAlgorithmException, NoSuchProviderException { + // As this depends on the implementation detail we should explicit select the correct provider. 
+ // See https://github.com/netty/netty/issues/10374 + return SSLContext.getInstance("TLS", "SunJSSE"); + } + + private static final class UnsafeTrustManagerWrapper implements TrustManagerWrapper { + private final long spiOffset; + private final long tmOffset; + + UnsafeTrustManagerWrapper(long spiOffset, long tmOffset) { + this.spiOffset = spiOffset; + this.tmOffset = tmOffset; + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + @Override + public X509TrustManager wrapIfNeeded(X509TrustManager manager) { + if (!(manager instanceof X509ExtendedTrustManager)) { + try { + SSLContext ctx = newSSLContext(); + ctx.init(null, new TrustManager[] { manager }, null); + Object spi = PlatformDependent.getObject(ctx, spiOffset); + if (spi != null) { + Object tm = PlatformDependent.getObject(spi, tmOffset); + if (tm instanceof X509ExtendedTrustManager) { + return (X509TrustManager) tm; + } + } + } catch (NoSuchAlgorithmException e) { + // This should never happen as we did the same in the static block + // before. + PlatformDependent.throwException(e); + } catch (KeyManagementException e) { + // This should never happen as we did the same in the static block + // before. + PlatformDependent.throwException(e); + } catch (NoSuchProviderException e) { + // This should never happen as we did the same in the static block + // before. + PlatformDependent.throwException(e); + } + } + return manager; + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java new file mode 100644 index 0000000..45afdbe --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/OptionalSslHandler.java @@ -0,0 +1,117 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectUtil; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLParameters; +import java.util.List; + +/** + * {@link OptionalSslHandler} is a utility decoder to support both SSL and non-SSL handlers + * based on the first message received. + */ +public class OptionalSslHandler extends ByteToMessageDecoder { + + private final SslContext sslContext; + + public OptionalSslHandler(SslContext sslContext) { + this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext"); + } + + @Override + protected void decode(ChannelHandlerContext context, ByteBuf in, List out) throws Exception { + if (in.readableBytes() < SslUtils.SSL_RECORD_HEADER_LENGTH) { + return; + } + if (SslHandler.isEncrypted(in)) { + handleSsl(context); + } else { + handleNonSsl(context); + } + } + + private void handleSsl(ChannelHandlerContext context) { + SslHandler sslHandler = null; + try { + sslHandler = newSslHandler(context, sslContext); + context.pipeline().replace(this, newSslHandlerName(), sslHandler); + sslHandler = null; + } finally { + // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not + // transferred to the SslHandler. 
+ if (sslHandler != null) { + ReferenceCountUtil.safeRelease(sslHandler.engine()); + } + } + } + + private void handleNonSsl(ChannelHandlerContext context) { + ChannelHandler handler = newNonSslHandler(context); + if (handler != null) { + context.pipeline().replace(this, newNonSslHandlerName(), handler); + } else { + context.pipeline().remove(this); + } + } + + /** + * Optionally specify the SSL handler name, this method may return {@code null}. + * @return the name of the SSL handler. + */ + protected String newSslHandlerName() { + return null; + } + + /** + * Override to configure the SslHandler eg. {@link SSLParameters#setEndpointIdentificationAlgorithm(String)}. + * The hostname and port is not known by this method so servers may want to override this method and use the + * {@link SslContext#newHandler(ByteBufAllocator, String, int)} variant. + * + * @param context the {@link ChannelHandlerContext} to use. + * @param sslContext the {@link SSLContext} to use. + * @return the {@link SslHandler} which will replace the {@link OptionalSslHandler} in the pipeline if the + * traffic is SSL. + */ + protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) { + return sslContext.newHandler(context.alloc()); + } + + /** + * Optionally specify the non-SSL handler name, this method may return {@code null}. + * @return the name of the non-SSL handler. + */ + protected String newNonSslHandlerName() { + return null; + } + + /** + * Override to configure the ChannelHandler. + * @param context the {@link ChannelHandlerContext} to use. + * @return the {@link ChannelHandler} which will replace the {@link OptionalSslHandler} in the pipeline + * or {@code null} to simply remove the {@link OptionalSslHandler} if the traffic is non-SSL. 
+ */ + protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) { + return null; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemEncoded.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemEncoded.java new file mode 100644 index 0000000..634d9e2 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemEncoded.java @@ -0,0 +1,55 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; + +/** + * A marker interface for PEM encoded values. + */ +interface PemEncoded extends ByteBufHolder { + + /** + * Returns {@code true} if the PEM encoded value is considered + * sensitive information such as a private key. 
+ */ + boolean isSensitive(); + + @Override + PemEncoded copy(); + + @Override + PemEncoded duplicate(); + + @Override + PemEncoded retainedDuplicate(); + + @Override + PemEncoded replace(ByteBuf content); + + @Override + PemEncoded retain(); + + @Override + PemEncoded retain(int increment); + + @Override + PemEncoded touch(); + + @Override + PemEncoded touch(Object hint); +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemPrivateKey.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemPrivateKey.java new file mode 100644 index 0000000..887d85b --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemPrivateKey.java @@ -0,0 +1,230 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import java.security.PrivateKey; + +import javax.security.auth.Destroyable; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.ObjectUtil; + +/** + * This is a special purpose implementation of a {@link PrivateKey} which allows the + * user to pass PEM/PKCS#8 encoded key material straight into {@link OpenSslContext} + * without having to parse and re-encode bytes in Java land. 
+ * + * All methods other than what's implemented in {@link PemEncoded} and {@link Destroyable} + * throw {@link UnsupportedOperationException}s. + * + * @see PemEncoded + * @see OpenSslContext + * @see #valueOf(byte[]) + * @see #valueOf(ByteBuf) + */ +public final class PemPrivateKey extends AbstractReferenceCounted implements PrivateKey, PemEncoded { + private static final long serialVersionUID = 7978017465645018936L; + + private static final byte[] BEGIN_PRIVATE_KEY = "-----BEGIN PRIVATE KEY-----\n".getBytes(CharsetUtil.US_ASCII); + private static final byte[] END_PRIVATE_KEY = "\n-----END PRIVATE KEY-----\n".getBytes(CharsetUtil.US_ASCII); + + private static final String PKCS8_FORMAT = "PKCS#8"; + + /** + * Creates a {@link PemEncoded} value from the {@link PrivateKey}. + */ + static PemEncoded toPEM(ByteBufAllocator allocator, boolean useDirect, PrivateKey key) { + // We can take a shortcut if the private key happens to be already + // PEM/PKCS#8 encoded. This is the ideal case and reason why all + // this exists. It allows the user to pass pre-encoded bytes straight + // into OpenSSL without having to do any of the extra work. + if (key instanceof PemEncoded) { + return ((PemEncoded) key).retain(); + } + + byte[] bytes = key.getEncoded(); + if (bytes == null) { + throw new IllegalArgumentException(key.getClass().getName() + " does not support encoding"); + } + + return toPEM(allocator, useDirect, bytes); + } + + static PemEncoded toPEM(ByteBufAllocator allocator, boolean useDirect, byte[] bytes) { + ByteBuf encoded = Unpooled.wrappedBuffer(bytes); + try { + ByteBuf base64 = SslUtils.toBase64(allocator, encoded); + try { + int size = BEGIN_PRIVATE_KEY.length + base64.readableBytes() + END_PRIVATE_KEY.length; + + boolean success = false; + final ByteBuf pem = useDirect ? 
allocator.directBuffer(size) : allocator.buffer(size); + try { + pem.writeBytes(BEGIN_PRIVATE_KEY); + pem.writeBytes(base64); + pem.writeBytes(END_PRIVATE_KEY); + + PemValue value = new PemValue(pem, true); + success = true; + return value; + } finally { + // Make sure we never leak that PEM ByteBuf if there's an Exception. + if (!success) { + SslUtils.zerooutAndRelease(pem); + } + } + } finally { + SslUtils.zerooutAndRelease(base64); + } + } finally { + SslUtils.zerooutAndRelease(encoded); + } + } + + /** + * Creates a {@link PemPrivateKey} from raw {@code byte[]}. + * + * ATTENTION: It's assumed that the given argument is a PEM/PKCS#8 encoded value. + * No input validation is performed to validate it. + */ + public static PemPrivateKey valueOf(byte[] key) { + return valueOf(Unpooled.wrappedBuffer(key)); + } + + /** + * Creates a {@link PemPrivateKey} from raw {@code ByteBuf}. + * + * ATTENTION: It's assumed that the given argument is a PEM/PKCS#8 encoded value. + * No input validation is performed to validate it. 
+ */ + public static PemPrivateKey valueOf(ByteBuf key) { + return new PemPrivateKey(key); + } + + private final ByteBuf content; + + private PemPrivateKey(ByteBuf content) { + this.content = ObjectUtil.checkNotNull(content, "content"); + } + + @Override + public boolean isSensitive() { + return true; + } + + @Override + public ByteBuf content() { + int count = refCnt(); + if (count <= 0) { + throw new IllegalReferenceCountException(count); + } + + return content; + } + + @Override + public PemPrivateKey copy() { + return replace(content.copy()); + } + + @Override + public PemPrivateKey duplicate() { + return replace(content.duplicate()); + } + + @Override + public PemPrivateKey retainedDuplicate() { + return replace(content.retainedDuplicate()); + } + + @Override + public PemPrivateKey replace(ByteBuf content) { + return new PemPrivateKey(content); + } + + @Override + public PemPrivateKey touch() { + content.touch(); + return this; + } + + @Override + public PemPrivateKey touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public PemPrivateKey retain() { + return (PemPrivateKey) super.retain(); + } + + @Override + public PemPrivateKey retain(int increment) { + return (PemPrivateKey) super.retain(increment); + } + + @Override + protected void deallocate() { + // Private Keys are sensitive. We need to zero the bytes + // before we're releasing the underlying ByteBuf + SslUtils.zerooutAndRelease(content); + } + + @Override + public byte[] getEncoded() { + throw new UnsupportedOperationException(); + } + + @Override + public String getAlgorithm() { + throw new UnsupportedOperationException(); + } + + @Override + public String getFormat() { + return PKCS8_FORMAT; + } + + /** + * NOTE: This is a JDK8 interface/method. Due to backwards compatibility + * reasons it's not possible to slap the {@code @Override} annotation onto + * this method. 
+ * + * @see Destroyable#destroy() + */ + @Override + public void destroy() { + release(refCnt()); + } + + /** + * NOTE: This is a JDK8 interface/method. Due to backwards compatibility + * reasons it's not possible to slap the {@code @Override} annotation onto + * this method. + * + * @see Destroyable#isDestroyed() + */ + @Override + public boolean isDestroyed() { + return refCnt() == 0; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemReader.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemReader.java new file mode 100644 index 0000000..33328ea --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemReader.java @@ -0,0 +1,203 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.ByteArrayOutputStream; +import java.io.File; +import java.io.FileInputStream; +import java.io.FileNotFoundException; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.security.KeyException; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.List; +import java.util.regex.Matcher; +import java.util.regex.Pattern; + +/** + * Reads a PEM file and converts it into a list of DERs so that they are imported into a {@link KeyStore} easily. + */ +final class PemReader { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PemReader.class); + + private static final Pattern CERT_HEADER = Pattern.compile( + "-+BEGIN\\s[^-\\r\\n]*CERTIFICATE[^-\\r\\n]*-+(?:\\s|\\r|\\n)+"); + private static final Pattern CERT_FOOTER = Pattern.compile( + "-+END\\s[^-\\r\\n]*CERTIFICATE[^-\\r\\n]*-+(?:\\s|\\r|\\n)*"); + private static final Pattern KEY_HEADER = Pattern.compile( + "-+BEGIN\\s[^-\\r\\n]*PRIVATE\\s+KEY[^-\\r\\n]*-+(?:\\s|\\r|\\n)+"); + private static final Pattern KEY_FOOTER = Pattern.compile( + "-+END\\s[^-\\r\\n]*PRIVATE\\s+KEY[^-\\r\\n]*-+(?:\\s|\\r|\\n)*"); + private static final Pattern BODY = Pattern.compile("[a-z0-9+/=][a-z0-9+/=\\r\\n]*", Pattern.CASE_INSENSITIVE); + + static ByteBuf[] readCertificates(File file) throws CertificateException { + try { + InputStream in = new FileInputStream(file); + + try { + return readCertificates(in); + } finally { + safeClose(in); + } + } catch (FileNotFoundException e) { + throw new CertificateException("could not find certificate file: " + file); + } + } + + static 
ByteBuf[] readCertificates(InputStream in) throws CertificateException { + String content; + try { + content = readContent(in); + } catch (IOException e) { + throw new CertificateException("failed to read certificate input stream", e); + } + + List certs = new ArrayList(); + Matcher m = CERT_HEADER.matcher(content); + int start = 0; + for (;;) { + if (!m.find(start)) { + break; + } + + // Here and below it's necessary to save the position as it is reset + // after calling usePattern() on Android due to a bug. + // + // See https://issuetracker.google.com/issues/293206296 + start = m.end(); + m.usePattern(BODY); + if (!m.find(start)) { + break; + } + + ByteBuf base64 = Unpooled.copiedBuffer(m.group(0), CharsetUtil.US_ASCII); + start = m.end(); + m.usePattern(CERT_FOOTER); + if (!m.find(start)) { + // Certificate is incomplete. + break; + } + ByteBuf der = Base64.decode(base64); + base64.release(); + certs.add(der); + + start = m.end(); + m.usePattern(CERT_HEADER); + } + + if (certs.isEmpty()) { + throw new CertificateException("found no certificates in input stream"); + } + + return certs.toArray(new ByteBuf[0]); + } + + static ByteBuf readPrivateKey(File file) throws KeyException { + try { + InputStream in = new FileInputStream(file); + + try { + return readPrivateKey(in); + } finally { + safeClose(in); + } + } catch (FileNotFoundException e) { + throw new KeyException("could not find key file: " + file); + } + } + + static ByteBuf readPrivateKey(InputStream in) throws KeyException { + String content; + try { + content = readContent(in); + } catch (IOException e) { + throw new KeyException("failed to read key input stream", e); + } + int start = 0; + Matcher m = KEY_HEADER.matcher(content); + if (!m.find(start)) { + throw keyNotFoundException(); + } + start = m.end(); + m.usePattern(BODY); + if (!m.find(start)) { + throw keyNotFoundException(); + } + + ByteBuf base64 = Unpooled.copiedBuffer(m.group(0), CharsetUtil.US_ASCII); + start = m.end(); + 
m.usePattern(KEY_FOOTER); + if (!m.find(start)) { + // Key is incomplete. + throw keyNotFoundException(); + } + ByteBuf der = Base64.decode(base64); + base64.release(); + return der; + } + + private static KeyException keyNotFoundException() { + return new KeyException("could not find a PKCS #8 private key in input stream" + + " (see https://netty.io/wiki/sslcontextbuilder-and-private-key.html for more information)"); + } + + private static String readContent(InputStream in) throws IOException { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + byte[] buf = new byte[8192]; + for (;;) { + int ret = in.read(buf); + if (ret < 0) { + break; + } + out.write(buf, 0, ret); + } + return out.toString(CharsetUtil.US_ASCII.name()); + } finally { + safeClose(out); + } + } + + private static void safeClose(InputStream in) { + try { + in.close(); + } catch (IOException e) { + logger.warn("Failed to close a stream.", e); + } + } + + private static void safeClose(OutputStream out) { + try { + out.close(); + } catch (IOException e) { + logger.warn("Failed to close a stream.", e); + } + } + + private PemReader() { } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemValue.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemValue.java new file mode 100644 index 0000000..634e408 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemValue.java @@ -0,0 +1,105 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.ObjectUtil; + +/** + * A PEM encoded value. + * + * @see PemEncoded + * @see PemPrivateKey#toPEM(ByteBufAllocator, boolean, java.security.PrivateKey) + * @see PemX509Certificate#toPEM(ByteBufAllocator, boolean, java.security.cert.X509Certificate[]) + */ +class PemValue extends AbstractReferenceCounted implements PemEncoded { + + private final ByteBuf content; + + private final boolean sensitive; + + PemValue(ByteBuf content, boolean sensitive) { + this.content = ObjectUtil.checkNotNull(content, "content"); + this.sensitive = sensitive; + } + + @Override + public boolean isSensitive() { + return sensitive; + } + + @Override + public ByteBuf content() { + int count = refCnt(); + if (count <= 0) { + throw new IllegalReferenceCountException(count); + } + + return content; + } + + @Override + public PemValue copy() { + return replace(content.copy()); + } + + @Override + public PemValue duplicate() { + return replace(content.duplicate()); + } + + @Override + public PemValue retainedDuplicate() { + return replace(content.retainedDuplicate()); + } + + @Override + public PemValue replace(ByteBuf content) { + return new PemValue(content, sensitive); + } + + @Override + public PemValue touch() { + return (PemValue) super.touch(); + } + + @Override + public PemValue touch(Object hint) { + content.touch(hint); + return this; + } + + @Override + public PemValue retain() { + return (PemValue) super.retain(); + } + + @Override + public PemValue retain(int increment) { + return (PemValue) super.retain(increment); + } + + @Override + protected void deallocate() { + if (sensitive) { + SslUtils.zeroout(content); + } + content.release(); + 
} +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemX509Certificate.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemX509Certificate.java new file mode 100644 index 0000000..dcb1ffc --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/PemX509Certificate.java @@ -0,0 +1,403 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +import java.math.BigInteger; +import java.security.Principal; +import java.security.PublicKey; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Date; +import java.util.Set; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.util.CharsetUtil; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.internal.ObjectUtil; + +/** + * This is a special purpose implementation of a {@link X509Certificate} which allows + * the user to pass PEM/PKCS#8 encoded data straight into {@link OpenSslContext} without + * having to parse and re-encode bytes in Java land. + * + * All methods other than what's implemented in {@link PemEncoded}'s throw + * {@link UnsupportedOperationException}s. 
 *
 * @see PemEncoded
 * @see OpenSslContext
 * @see #valueOf(byte[])
 * @see #valueOf(ByteBuf)
 */
public final class PemX509Certificate extends X509Certificate implements PemEncoded {

    // PEM armor lines bracketing each Base64-encoded certificate body.
    private static final byte[] BEGIN_CERT = "-----BEGIN CERTIFICATE-----\n".getBytes(CharsetUtil.US_ASCII);
    private static final byte[] END_CERT = "\n-----END CERTIFICATE-----\n".getBytes(CharsetUtil.US_ASCII);

    /**
     * Creates a {@link PemEncoded} value from the {@link X509Certificate}s.
     *
     * @param allocator used to allocate the destination buffer
     * @param useDirect whether the destination buffer should be a direct buffer
     * @param chain     one or more certificates; must not be empty or contain {@code null} elements
     * @throws CertificateEncodingException if a certificate in the chain cannot be encoded
     */
    static PemEncoded toPEM(ByteBufAllocator allocator, boolean useDirect,
                            X509Certificate... chain) throws CertificateEncodingException {

        checkNonEmpty(chain, "chain");

        // We can take a shortcut if there is only one certificate and
        // it already happens to be a PemEncoded instance. This is the
        // ideal case and reason why all this exists. It allows the user
        // to pass pre-encoded bytes straight into OpenSSL without having
        // to do any of the extra work.
        if (chain.length == 1) {
            X509Certificate first = chain[0];
            if (first instanceof PemEncoded) {
                return ((PemEncoded) first).retain();
            }
        }

        boolean success = false;
        ByteBuf pem = null;
        try {
            for (X509Certificate cert : chain) {

                if (cert == null) {
                    throw new IllegalArgumentException("Null element in chain: " + Arrays.toString(chain));
                }

                if (cert instanceof PemEncoded) {
                    pem = append(allocator, useDirect, (PemEncoded) cert, chain.length, pem);
                } else {
                    pem = append(allocator, useDirect, cert, chain.length, pem);
                }
            }

            PemValue value = new PemValue(pem, false);
            success = true;
            return value;
        } finally {
            // Make sure we never leak the PEM's ByteBuf in the event of an Exception
            if (!success && pem != null) {
                pem.release();
            }
        }
    }

    /**
     * Appends the {@link PemEncoded} value to the {@link ByteBuf} (last arg) and returns it.
     * If the {@link ByteBuf} didn't exist yet it'll create it using the {@link ByteBufAllocator}.
     */
    private static ByteBuf append(ByteBufAllocator allocator, boolean useDirect,
                                  PemEncoded encoded, int count, ByteBuf pem) {

        ByteBuf content = encoded.content();

        if (pem == null) {
            // see the other append() method
            pem = newBuffer(allocator, useDirect, content.readableBytes() * count);
        }

        pem.writeBytes(content.slice());
        return pem;
    }

    /**
     * Appends the {@link X509Certificate} value to the {@link ByteBuf} (last arg) and returns it.
     * If the {@link ByteBuf} didn't exist yet it'll create it using the {@link ByteBufAllocator}.
     */
    private static ByteBuf append(ByteBufAllocator allocator, boolean useDirect,
                                  X509Certificate cert, int count, ByteBuf pem) throws CertificateEncodingException {

        ByteBuf encoded = Unpooled.wrappedBuffer(cert.getEncoded());
        try {
            ByteBuf base64 = SslUtils.toBase64(allocator, encoded);
            try {
                if (pem == null) {
                    // We try to approximate the buffer's initial size. The sizes of
                    // certificates can vary a lot so it'll be off a bit depending
                    // on the number of elements in the array (count argument).
                    pem = newBuffer(allocator, useDirect,
                            (BEGIN_CERT.length + base64.readableBytes() + END_CERT.length) * count);
                }

                pem.writeBytes(BEGIN_CERT);
                pem.writeBytes(base64);
                pem.writeBytes(END_CERT);
            } finally {
                base64.release();
            }
        } finally {
            encoded.release();
        }

        return pem;
    }

    // Allocates either a direct or a heap buffer depending on the caller's preference.
    private static ByteBuf newBuffer(ByteBufAllocator allocator, boolean useDirect, int initialCapacity) {
        return useDirect ? allocator.directBuffer(initialCapacity) : allocator.buffer(initialCapacity);
    }

    /**
     * Creates a {@link PemX509Certificate} from raw {@code byte[]}.
     *
     * ATTENTION: It's assumed that the given argument is a PEM encoded X.509 certificate
     * (PKCS#8 applies to private keys, not certificates).
     * No input validation is performed to validate it.
     */
    public static PemX509Certificate valueOf(byte[] key) {
        return valueOf(Unpooled.wrappedBuffer(key));
    }

    /**
     * Creates a {@link PemX509Certificate} from raw {@code ByteBuf}.
     *
     * ATTENTION: It's assumed that the given argument is a PEM encoded X.509 certificate
     * (PKCS#8 applies to private keys, not certificates).
     * No input validation is performed to validate it.
     */
    public static PemX509Certificate valueOf(ByteBuf key) {
        return new PemX509Certificate(key);
    }

    // Raw PEM bytes; reference counting of this object delegates to this buffer.
    private final ByteBuf content;

    private PemX509Certificate(ByteBuf content) {
        this.content = ObjectUtil.checkNotNull(content, "content");
    }

    @Override
    public boolean isSensitive() {
        // There is no sensitive information in a X509 Certificate
        return false;
    }

    @Override
    public int refCnt() {
        return content.refCnt();
    }

    /**
     * Returns the PEM bytes.
     *
     * @throws IllegalReferenceCountException if this object has already been released
     */
    @Override
    public ByteBuf content() {
        int count = refCnt();
        if (count <= 0) {
            throw new IllegalReferenceCountException(count);
        }

        return content;
    }

    @Override
    public PemX509Certificate copy() {
        return replace(content.copy());
    }

    @Override
    public PemX509Certificate duplicate() {
        return replace(content.duplicate());
    }

    @Override
    public PemX509Certificate retainedDuplicate() {
        return replace(content.retainedDuplicate());
    }

    // Wraps the given buffer in a new instance; the receiver's own buffer is untouched.
    @Override
    public PemX509Certificate replace(ByteBuf content) {
        return new PemX509Certificate(content);
    }

    @Override
    public PemX509Certificate retain() {
        content.retain();
        return this;
    }

    @Override
    public PemX509Certificate retain(int increment) {
        content.retain(increment);
        return this;
    }

    @Override
    public PemX509Certificate touch() {
        content.touch();
        return this;
    }

    @Override
    public PemX509Certificate touch(Object hint) {
        content.touch(hint);
        return this;
    }

    @Override
    public boolean release() {
        return content.release();
    }

    @Override
    public boolean release(int decrement) {
        return content.release(decrement);
    }

    //
    // All remaining java.security.cert.X509Certificate / X509Extension methods below are
    // intentionally unsupported: this class only carries raw PEM bytes and never parses
    // the certificate content.
    //

    @Override
    public byte[] getEncoded() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean hasUnsupportedCriticalExtension() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Set<String> getCriticalExtensionOIDs() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Set<String> getNonCriticalExtensionOIDs() {
        throw new UnsupportedOperationException();
    }

    @Override
    public byte[] getExtensionValue(String oid) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void checkValidity() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void checkValidity(Date date) {
        throw new UnsupportedOperationException();
    }

    @Override
    public int getVersion() {
        throw new UnsupportedOperationException();
    }

    @Override
    public BigInteger getSerialNumber() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Principal getIssuerDN() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Principal getSubjectDN() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Date getNotBefore() {
        throw new UnsupportedOperationException();
    }

    @Override
    public Date getNotAfter() {
        throw new UnsupportedOperationException();
    }

    @Override
    public byte[] getTBSCertificate() {
        throw new UnsupportedOperationException();
    }

    @Override
    public byte[] getSignature() {
        throw new UnsupportedOperationException();
    }

    @Override
    public String getSigAlgName() {
        throw new UnsupportedOperationException();
    }

    @Override
    public String getSigAlgOID() {
        throw new UnsupportedOperationException();
    }

    @Override
    public byte[] getSigAlgParams() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean[] getIssuerUniqueID() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean[] getSubjectUniqueID() {
        throw new UnsupportedOperationException();
    }

    @Override
    public boolean[] getKeyUsage() {
        throw new UnsupportedOperationException();
    }

    @Override
    public int getBasicConstraints() {
        throw new UnsupportedOperationException();
    }

    @Override
    public void verify(PublicKey key) {
        throw new UnsupportedOperationException();
    }

    @Override
    public void verify(PublicKey key, String sigProvider) {
        throw new UnsupportedOperationException();
    }

    @Override
    public PublicKey getPublicKey() {
        throw new UnsupportedOperationException();
    }

    // Equality is by raw PEM bytes, not by parsed certificate identity.
    @Override
    public boolean equals(Object o) {
        if (o == this) {
            return true;
        }
        if (!(o instanceof PemX509Certificate)) {
            return false;
        }

        PemX509Certificate other = (PemX509Certificate) o;
        return content.equals(other.content);
    }

    @Override
    public int hashCode() {
        return content.hashCode();
    }

    // Decodes the PEM bytes as UTF-8 text.
    @Override
    public String toString() {
        return content.toString(CharsetUtil.UTF_8);
    }
}
/*
 * Copyright 2019 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
/* Apache License 2.0 - see the license header in the preceding file section. */

import javax.crypto.Mac;
import javax.crypto.spec.SecretKeySpec;
import java.security.GeneralSecurityException;
import java.util.Arrays;

/**
 * The TLS pseudorandom function (PRF) as defined in RFC 5246 (TLS 1.2), Section 5.
 * It takes as input a secret, a seed, and an identifying label and produces an
 * output of arbitrary length by iterating an HMAC-based expansion (P_hash).
 *
 * This is used by the TLS RFC to construct/deconstruct an array of bytes into
 * composite secrets.
 */
final class PseudoRandomFunction {

    /**
     * Constructor never to be called; this class only exposes static helpers.
     */
    private PseudoRandomFunction() {
    }

    /**
     * Use a single hash function to expand a secret and seed into an
     * arbitrary quantity of output (P_hash from RFC 5246, Section 5):
     *
     * P_hash(secret, seed) = HMAC_hash(secret, A(1) + seed) +
     *                        HMAC_hash(secret, A(2) + seed) +
     *                        HMAC_hash(secret, A(3) + seed) + ...
     * where + indicates concatenation.
     * A() is defined as:
     *   A(0) = seed
     *   A(i) = HMAC_hash(secret, A(i-1))
     *
     * @param secret the starting secret to use for expansion
     * @param label an ASCII label without a length byte or trailing null character
     * @param seed the seed of the hash
     * @param length the number of bytes to return; may be zero
     * @param algo the HMAC algorithm name understood by {@link Mac#getInstance(String)}
     * @return the expanded secret, exactly {@code length} bytes long
     * @throws IllegalArgumentException if {@code length} is negative, or if the
     *         algorithm could not be found or the key is invalid
     */
    static byte[] hash(byte[] secret, byte[] label, byte[] seed, int length, String algo) {
        if (length < 0) {
            throw new IllegalArgumentException("length: " + length + " (expected: >= 0)");
        }
        try {
            Mac hmac = Mac.getInstance(algo);
            hmac.init(new SecretKeySpec(secret, algo));

            // A(0) is label + seed. Each round advances A(i) and appends
            // HMAC(secret, A(i) + A(0)) to the output. Writing HMAC blocks
            // directly into a pre-sized array avoids the O(n^2) re-copying
            // that repeated whole-accumulator concatenation would cause,
            // while producing byte-identical output.
            byte[] data = concat(label, seed);
            byte[] expansion = new byte[length];
            byte[] a = data;
            int written = 0;
            while (written < length) {
                a = hmac.doFinal(a); // A(i) = HMAC_hash(secret, A(i-1))
                byte[] block = hmac.doFinal(concat(a, data));
                int n = Math.min(block.length, length - written);
                System.arraycopy(block, 0, expansion, written, n);
                written += n;
            }
            return expansion;
        } catch (GeneralSecurityException e) {
            throw new IllegalArgumentException("Could not find algo: " + algo, e);
        }
    }

    /**
     * Returns a new array containing {@code first} followed by {@code second}.
     */
    private static byte[] concat(byte[] first, byte[] second) {
        byte[] result = Arrays.copyOf(first, first.length + second.length);
        System.arraycopy(second, 0, result, first.length, second.length);
        return result;
    }
}
/* Apache License 2.0 - see the license header in the preceding file section. */
package io.netty.handler.ssl;

import io.netty.internal.tcnative.CertificateCallback;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.SuppressJava6Requirement;
import io.netty.internal.tcnative.SSL;
import io.netty.internal.tcnative.SSLContext;

import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;

import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.Map;
import java.util.Set;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedTrustManager;
import javax.net.ssl.X509TrustManager;
import javax.security.auth.x500.X500Principal;

/**
 * A client-side {@link SslContext} which uses OpenSSL's SSL/TLS implementation.
 * <p>
 * Instances of this class must be {@link #release() released} or else native memory will leak!
 * <p>
 * Instances of this class must not be released before any {@link ReferenceCountedOpenSslEngine}
 * which depends upon the instance of this class is released. Otherwise if any method of
 * {@link ReferenceCountedOpenSslEngine} is called which uses this class's JNI resources the JVM may crash.
 */
public final class ReferenceCountedOpenSslClientContext extends ReferenceCountedOpenSslContext {

    // Key types offered when the server does not constrain the client certificate types
    // (i.e. when the CertificateRequest carries no certificate_types list).
    private static final Set<String> SUPPORTED_KEY_TYPES = Collections.unmodifiableSet(new LinkedHashSet<String>(
            Arrays.asList(OpenSslKeyMaterialManager.KEY_TYPE_RSA,
                          OpenSslKeyMaterialManager.KEY_TYPE_DH_RSA,
                          OpenSslKeyMaterialManager.KEY_TYPE_EC,
                          OpenSslKeyMaterialManager.KEY_TYPE_EC_RSA,
                          OpenSslKeyMaterialManager.KEY_TYPE_EC_EC)));

    private final OpenSslSessionContext sessionContext;

    /**
     * Creates a new client-side context. If construction fails part-way, the partially
     * initialized native resources are released before the exception propagates.
     */
    ReferenceCountedOpenSslClientContext(X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory,
                                         X509Certificate[] keyCertChain, PrivateKey key, String keyPassword,
                                         KeyManagerFactory keyManagerFactory, Iterable<String> ciphers,
                                         CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn,
                                         String[] protocols, long sessionCacheSize, long sessionTimeout,
                                         boolean enableOcsp, String keyStore,
                                         Map.Entry<SslContextOption<?>, Object>... options) throws SSLException {
        super(ciphers, cipherFilter, toNegotiator(apn), SSL.SSL_MODE_CLIENT, keyCertChain,
              ClientAuth.NONE, protocols, false, enableOcsp, true, options);
        boolean success = false;
        try {
            sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory,
                                               keyCertChain, key, keyPassword, keyManagerFactory, keyStore,
                                               sessionCacheSize, sessionTimeout);
            success = true;
        } finally {
            if (!success) {
                // Construction failed: drop our reference so native resources are freed.
                release();
            }
        }
    }

    @Override
    public OpenSslSessionContext sessionContext() {
        return sessionContext;
    }

    /**
     * Configures key material, peer verification and session caching on the native
     * SSL_CTX identified by {@code ctx} and returns the resulting session context.
     *
     * @throws IllegalArgumentException if only one of {@code key}/{@code keyCertChain} is given
     * @throws SSLException if configuring key material or the trust manager fails
     */
    static OpenSslSessionContext newSessionContext(ReferenceCountedOpenSslContext thiz, long ctx,
                                                   OpenSslEngineMap engineMap,
                                                   X509Certificate[] trustCertCollection,
                                                   TrustManagerFactory trustManagerFactory,
                                                   X509Certificate[] keyCertChain, PrivateKey key,
                                                   String keyPassword, KeyManagerFactory keyManagerFactory,
                                                   String keyStore, long sessionCacheSize, long sessionTimeout)
            throws SSLException {
        if (key == null && keyCertChain != null || key != null && keyCertChain == null) {
            throw new IllegalArgumentException(
                    "Either both keyCertChain and key needs to be null or none of them");
        }
        OpenSslKeyMaterialProvider keyMaterialProvider = null;
        try {
            try {
                if (!OpenSsl.useKeyManagerFactory()) {
                    if (keyManagerFactory != null) {
                        throw new IllegalArgumentException(
                                "KeyManagerFactory not supported");
                    }
                    if (keyCertChain != null/* && key != null*/) {
                        setKeyMaterial(ctx, keyCertChain, key, keyPassword);
                    }
                } else {
                    // javadocs state that keyManagerFactory has precedent over keyCertChain
                    if (keyManagerFactory == null && keyCertChain != null) {
                        char[] keyPasswordChars = keyStorePassword(keyPassword);
                        KeyStore ks = buildKeyStore(keyCertChain, key, keyPasswordChars, keyStore);
                        if (ks.aliases().hasMoreElements()) {
                            keyManagerFactory = new OpenSslX509KeyManagerFactory();
                        } else {
                            keyManagerFactory = new OpenSslCachingX509KeyManagerFactory(
                                    KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()));
                        }
                        keyManagerFactory.init(ks, keyPasswordChars);
                        keyMaterialProvider = providerFor(keyManagerFactory, keyPassword);
                    } else if (keyManagerFactory != null) {
                        keyMaterialProvider = providerFor(keyManagerFactory, keyPassword);
                    }

                    if (keyMaterialProvider != null) {
                        OpenSslKeyMaterialManager materialManager = new OpenSslKeyMaterialManager(keyMaterialProvider);
                        SSLContext.setCertificateCallback(ctx, new OpenSslClientCertificateCallback(
                                engineMap, materialManager));
                    }
                }
            } catch (Exception e) {
                throw new SSLException("failed to set certificate and key", e);
            }

            // On the client side we always need to use SSL_CVERIFY_OPTIONAL (which will translate to SSL_VERIFY_PEER)
            // to ensure that when the TrustManager throws we will produce the correct alert back to the server.
            //
            // See:
            //  - https://www.openssl.org/docs/man1.0.2/man3/SSL_CTX_set_verify.html
            //  - https://github.com/netty/netty/issues/8942
            SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_OPTIONAL, VERIFY_DEPTH);

            try {
                if (trustCertCollection != null) {
                    trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory, keyStore);
                } else if (trustManagerFactory == null) {
                    trustManagerFactory = TrustManagerFactory.getInstance(
                            TrustManagerFactory.getDefaultAlgorithm());
                    trustManagerFactory.init((KeyStore) null);
                }
                final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers());

                // IMPORTANT: The callbacks set for verification must be static to prevent memory leak as
                //            otherwise the context can never be collected. This is because the JNI code holds
                //            a global reference to the callbacks.
                //
                //            See https://github.com/netty/netty/issues/5372

                setVerifyCallback(ctx, engineMap, manager);
            } catch (Exception e) {
                if (keyMaterialProvider != null) {
                    keyMaterialProvider.destroy();
                }
                throw new SSLException("unable to setup trustmanager", e);
            }
            OpenSslClientSessionContext context = new OpenSslClientSessionContext(thiz, keyMaterialProvider);
            context.setSessionCacheEnabled(CLIENT_ENABLE_SESSION_CACHE);
            if (sessionCacheSize > 0) {
                context.setSessionCacheSize((int) Math.min(sessionCacheSize, Integer.MAX_VALUE));
            }
            if (sessionTimeout > 0) {
                context.setSessionTimeout((int) Math.min(sessionTimeout, Integer.MAX_VALUE));
            }

            if (CLIENT_ENABLE_SESSION_TICKET) {
                context.setTicketKeys();
            }

            // Ownership of keyMaterialProvider has been handed to the session context;
            // null it out so the finally block below does not destroy it.
            keyMaterialProvider = null;
            return context;
        } finally {
            if (keyMaterialProvider != null) {
                keyMaterialProvider.destroy();
            }
        }
    }

    @SuppressJava6Requirement(reason = "Guarded by java version check")
    private static void setVerifyCallback(long ctx, OpenSslEngineMap engineMap, X509TrustManager manager) {
        // Use this to prevent an error when running on java < 7
        if (useExtendedTrustManager(manager)) {
            SSLContext.setCertVerifyCallback(ctx,
                    new ExtendedTrustManagerVerifyCallback(engineMap, (X509ExtendedTrustManager) manager));
        } else {
            SSLContext.setCertVerifyCallback(ctx, new TrustManagerVerifyCallback(engineMap, manager));
        }
    }

    // Client-side session context backed by OpenSSL's client session cache.
    static final class OpenSslClientSessionContext extends OpenSslSessionContext {
        OpenSslClientSessionContext(ReferenceCountedOpenSslContext context, OpenSslKeyMaterialProvider provider) {
            super(context, provider, SSL.SSL_SESS_CACHE_CLIENT, new OpenSslClientSessionCache(context.engineMap));
        }
    }

    // Delegates server certificate verification to a plain X509TrustManager.
    private static final class TrustManagerVerifyCallback extends AbstractCertificateVerifier {
        private final X509TrustManager manager;

        TrustManagerVerifyCallback(OpenSslEngineMap engineMap, X509TrustManager manager) {
            super(engineMap);
            this.manager = manager;
        }

        @Override
        void verify(ReferenceCountedOpenSslEngine engine, X509Certificate[] peerCerts, String auth)
                throws Exception {
            manager.checkServerTrusted(peerCerts, auth);
        }
    }

    // Delegates server certificate verification to an X509ExtendedTrustManager,
    // passing the engine so endpoint identification can be performed.
    @SuppressJava6Requirement(reason = "Usage guarded by java version check")
    private static final class ExtendedTrustManagerVerifyCallback extends AbstractCertificateVerifier {
        private final X509ExtendedTrustManager manager;

        ExtendedTrustManagerVerifyCallback(OpenSslEngineMap engineMap, X509ExtendedTrustManager manager) {
            super(engineMap);
            this.manager = manager;
        }

        @Override
        void verify(ReferenceCountedOpenSslEngine engine, X509Certificate[] peerCerts, String auth)
                throws Exception {
            manager.checkServerTrusted(peerCerts, auth, engine);
        }
    }

    // Invoked by native code when the server requests a client certificate;
    // selects and installs matching key material on the engine.
    private static final class OpenSslClientCertificateCallback implements CertificateCallback {
        private final OpenSslEngineMap engineMap;
        private final OpenSslKeyMaterialManager keyManagerHolder;

        OpenSslClientCertificateCallback(OpenSslEngineMap engineMap, OpenSslKeyMaterialManager keyManagerHolder) {
            this.engineMap = engineMap;
            this.keyManagerHolder = keyManagerHolder;
        }

        @Override
        public void handle(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals) throws Exception {
            final ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
            // May be null if it was destroyed in the meantime.
            if (engine == null) {
                return;
            }
            try {
                final Set<String> keyTypesSet = supportedClientKeyTypes(keyTypeBytes);
                final String[] keyTypes = keyTypesSet.toArray(EmptyArrays.EMPTY_STRINGS);
                final X500Principal[] issuers;
                if (asn1DerEncodedPrincipals == null) {
                    issuers = null;
                } else {
                    issuers = new X500Principal[asn1DerEncodedPrincipals.length];
                    for (int i = 0; i < asn1DerEncodedPrincipals.length; i++) {
                        issuers[i] = new X500Principal(asn1DerEncodedPrincipals[i]);
                    }
                }
                keyManagerHolder.setKeyMaterialClientSide(engine, keyTypes, issuers);
            } catch (Throwable cause) {
                // Record the failure on the engine so the handshake surfaces it, then rethrow
                // so native code also observes the error.
                engine.initHandshakeException(cause);
                if (cause instanceof Exception) {
                    throw (Exception) cause;
                }
                throw new SSLException(cause);
            }
        }

        /**
         * Gets the supported key types for client certificates.
         *
         * @param clientCertificateTypes {@code ClientCertificateType} values provided by the server.
         *        See https://www.ietf.org/assignments/tls-parameters/tls-parameters.xml.
         * @return supported key types that can be used in {@code X509KeyManager.chooseClientAlias} and
         *         {@code X509ExtendedKeyManager.chooseEngineClientAlias}.
         */
        private static Set<String> supportedClientKeyTypes(byte[] clientCertificateTypes) {
            if (clientCertificateTypes == null) {
                // Try all of the supported key types.
                return SUPPORTED_KEY_TYPES;
            }
            Set<String> result = new HashSet<String>(clientCertificateTypes.length);
            for (byte keyTypeCode : clientCertificateTypes) {
                String keyType = clientKeyType(keyTypeCode);
                if (keyType == null) {
                    // Unsupported client key type -- ignore
                    continue;
                }
                result.add(keyType);
            }
            return result;
        }

        // Maps a TLS ClientCertificateType code to the Java key type string, or null if unsupported.
        private static String clientKeyType(byte clientCertificateType) {
            // See also https://www.ietf.org/assignments/tls-parameters/tls-parameters.xml
            switch (clientCertificateType) {
                case CertificateCallback.TLS_CT_RSA_SIGN:
                    return OpenSslKeyMaterialManager.KEY_TYPE_RSA; // RFC rsa_sign
                case CertificateCallback.TLS_CT_RSA_FIXED_DH:
                    return OpenSslKeyMaterialManager.KEY_TYPE_DH_RSA; // RFC rsa_fixed_dh
                case CertificateCallback.TLS_CT_ECDSA_SIGN:
                    return OpenSslKeyMaterialManager.KEY_TYPE_EC; // RFC ecdsa_sign
                case CertificateCallback.TLS_CT_RSA_FIXED_ECDH:
                    return OpenSslKeyMaterialManager.KEY_TYPE_EC_RSA; // RFC rsa_fixed_ecdh
                case CertificateCallback.TLS_CT_ECDSA_FIXED_ECDH:
                    return OpenSslKeyMaterialManager.KEY_TYPE_EC_EC; // RFC ecdsa_fixed_ecdh
                default:
                    return null;
            }
        }
    }
}
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.ssl.util.LazyX509Certificate; +import io.netty.internal.tcnative.AsyncSSLPrivateKeyMethod; +import io.netty.internal.tcnative.CertificateCompressionAlgo; +import io.netty.internal.tcnative.CertificateVerifier; +import io.netty.internal.tcnative.ResultCallback; +import io.netty.internal.tcnative.SSL; +import io.netty.internal.tcnative.SSLContext; +import io.netty.internal.tcnative.SSLPrivateKeyMethod; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakDetectorFactory; +import io.netty.util.ResourceLeakTracker; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.security.PrivateKey; +import java.security.SignatureException; +import java.security.cert.CertPathValidatorException; +import java.security.cert.Certificate; +import java.security.cert.CertificateExpiredException; +import java.security.cert.CertificateNotYetValidException; +import 
java.security.cert.CertificateRevokedException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReadWriteLock; +import java.util.concurrent.locks.ReentrantReadWriteLock; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509KeyManager; +import javax.net.ssl.X509TrustManager; + +import static io.netty.handler.ssl.OpenSsl.DEFAULT_CIPHERS; +import static io.netty.handler.ssl.OpenSsl.availableJavaCipherSuites; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * An implementation of {@link SslContext} which works with libraries that support the + * OpenSsl C library API. + *

Instances of this class must be {@link #release() released} or else native memory will leak! + * + *

Instances of this class must not be released before any {@link ReferenceCountedOpenSslEngine} + * which depends upon the instance of this class is released. Otherwise if any method of + * {@link ReferenceCountedOpenSslEngine} is called which uses this class's JNI resources the JVM may crash. + */ +public abstract class ReferenceCountedOpenSslContext extends SslContext implements ReferenceCounted { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(ReferenceCountedOpenSslContext.class); + + private static final int DEFAULT_BIO_NON_APPLICATION_BUFFER_SIZE = Math.max(1, + SystemPropertyUtil.getInt("io.netty.handler.ssl.openssl.bioNonApplicationBufferSize", + 2048)); + // Let's use tasks by default but still allow the user to disable it via system property just in case. + static final boolean USE_TASKS = + SystemPropertyUtil.getBoolean("io.netty.handler.ssl.openssl.useTasks", true); + private static final Integer DH_KEY_LENGTH; + private static final ResourceLeakDetector leakDetector = + ResourceLeakDetectorFactory.instance().newResourceLeakDetector(ReferenceCountedOpenSslContext.class); + + // TODO: Maybe make configurable ? 
+ protected static final int VERIFY_DEPTH = 10; + + static final boolean CLIENT_ENABLE_SESSION_TICKET = + SystemPropertyUtil.getBoolean("jdk.tls.client.enableSessionTicketExtension", false); + + static final boolean CLIENT_ENABLE_SESSION_TICKET_TLSV13 = + SystemPropertyUtil.getBoolean("jdk.tls.client.enableSessionTicketExtension", true); + + static final boolean SERVER_ENABLE_SESSION_TICKET = + SystemPropertyUtil.getBoolean("jdk.tls.server.enableSessionTicketExtension", false); + + static final boolean SERVER_ENABLE_SESSION_TICKET_TLSV13 = + SystemPropertyUtil.getBoolean("jdk.tls.server.enableSessionTicketExtension", true); + + static final boolean SERVER_ENABLE_SESSION_CACHE = + SystemPropertyUtil.getBoolean("io.netty.handler.ssl.openssl.sessionCacheServer", true); + static final boolean CLIENT_ENABLE_SESSION_CACHE = + SystemPropertyUtil.getBoolean("io.netty.handler.ssl.openssl.sessionCacheClient", true); + + /** + * The OpenSSL SSL_CTX object. + * + * {@link #ctxLock} must be hold while using ctx! 
+ */ + protected long ctx; + private final List unmodifiableCiphers; + private final OpenSslApplicationProtocolNegotiator apn; + private final int mode; + + // Reference Counting + private final ResourceLeakTracker leak; + private final AbstractReferenceCounted refCnt = new AbstractReferenceCounted() { + @Override + public ReferenceCounted touch(Object hint) { + if (leak != null) { + leak.record(hint); + } + + return ReferenceCountedOpenSslContext.this; + } + + @Override + protected void deallocate() { + destroy(); + if (leak != null) { + boolean closed = leak.close(ReferenceCountedOpenSslContext.this); + assert closed; + } + } + }; + + final Certificate[] keyCertChain; + final ClientAuth clientAuth; + final String[] protocols; + final boolean enableOcsp; + final OpenSslEngineMap engineMap = new DefaultOpenSslEngineMap(); + final ReadWriteLock ctxLock = new ReentrantReadWriteLock(); + + private volatile int bioNonApplicationBufferSize = DEFAULT_BIO_NON_APPLICATION_BUFFER_SIZE; + + @SuppressWarnings("deprecation") + static final OpenSslApplicationProtocolNegotiator NONE_PROTOCOL_NEGOTIATOR = + new OpenSslApplicationProtocolNegotiator() { + @Override + public ApplicationProtocolConfig.Protocol protocol() { + return ApplicationProtocolConfig.Protocol.NONE; + } + + @Override + public List protocols() { + return Collections.emptyList(); + } + + @Override + public ApplicationProtocolConfig.SelectorFailureBehavior selectorFailureBehavior() { + return ApplicationProtocolConfig.SelectorFailureBehavior.CHOOSE_MY_LAST_PROTOCOL; + } + + @Override + public ApplicationProtocolConfig.SelectedListenerFailureBehavior selectedListenerFailureBehavior() { + return ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT; + } + }; + + static { + Integer dhLen = null; + + try { + String dhKeySize = SystemPropertyUtil.get("jdk.tls.ephemeralDHKeySize"); + if (dhKeySize != null) { + try { + dhLen = Integer.valueOf(dhKeySize); + } catch (NumberFormatException e) { + 
logger.debug("ReferenceCountedOpenSslContext supports -Djdk.tls.ephemeralDHKeySize={int}, but got: " + + dhKeySize); + } + } + } catch (Throwable ignore) { + // ignore + } + DH_KEY_LENGTH = dhLen; + } + + final boolean tlsFalseStart; + + ReferenceCountedOpenSslContext(Iterable ciphers, CipherSuiteFilter cipherFilter, + OpenSslApplicationProtocolNegotiator apn, int mode, Certificate[] keyCertChain, + ClientAuth clientAuth, String[] protocols, boolean startTls, boolean enableOcsp, + boolean leakDetection, Map.Entry, Object>... ctxOptions) + throws SSLException { + super(startTls); + + OpenSsl.ensureAvailability(); + + if (enableOcsp && !OpenSsl.isOcspSupported()) { + throw new IllegalStateException("OCSP is not supported."); + } + + if (mode != SSL.SSL_MODE_SERVER && mode != SSL.SSL_MODE_CLIENT) { + throw new IllegalArgumentException("mode most be either SSL.SSL_MODE_SERVER or SSL.SSL_MODE_CLIENT"); + } + + boolean tlsFalseStart = false; + boolean useTasks = USE_TASKS; + OpenSslPrivateKeyMethod privateKeyMethod = null; + OpenSslAsyncPrivateKeyMethod asyncPrivateKeyMethod = null; + OpenSslCertificateCompressionConfig certCompressionConfig = null; + Integer maxCertificateList = null; + + if (ctxOptions != null) { + for (Map.Entry, Object> ctxOpt : ctxOptions) { + SslContextOption option = ctxOpt.getKey(); + + if (option == OpenSslContextOption.TLS_FALSE_START) { + tlsFalseStart = (Boolean) ctxOpt.getValue(); + } else if (option == OpenSslContextOption.USE_TASKS) { + useTasks = (Boolean) ctxOpt.getValue(); + } else if (option == OpenSslContextOption.PRIVATE_KEY_METHOD) { + privateKeyMethod = (OpenSslPrivateKeyMethod) ctxOpt.getValue(); + } else if (option == OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD) { + asyncPrivateKeyMethod = (OpenSslAsyncPrivateKeyMethod) ctxOpt.getValue(); + } else if (option == OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS) { + certCompressionConfig = (OpenSslCertificateCompressionConfig) ctxOpt.getValue(); + } else if (option == 
OpenSslContextOption.MAX_CERTIFICATE_LIST_BYTES) { + maxCertificateList = (Integer) ctxOpt.getValue(); + } else { + logger.debug("Skipping unsupported " + SslContextOption.class.getSimpleName() + + ": " + ctxOpt.getKey()); + } + } + } + if (privateKeyMethod != null && asyncPrivateKeyMethod != null) { + throw new IllegalArgumentException("You can either only use " + + OpenSslAsyncPrivateKeyMethod.class.getSimpleName() + " or " + + OpenSslPrivateKeyMethod.class.getSimpleName()); + } + + this.tlsFalseStart = tlsFalseStart; + + leak = leakDetection ? leakDetector.track(this) : null; + this.mode = mode; + this.clientAuth = isServer() ? checkNotNull(clientAuth, "clientAuth") : ClientAuth.NONE; + this.protocols = protocols == null ? OpenSsl.defaultProtocols(mode == SSL.SSL_MODE_CLIENT) : protocols; + this.enableOcsp = enableOcsp; + + this.keyCertChain = keyCertChain == null ? null : keyCertChain.clone(); + + String[] suites = checkNotNull(cipherFilter, "cipherFilter").filterCipherSuites( + ciphers, DEFAULT_CIPHERS, availableJavaCipherSuites()); + // Filter out duplicates. + LinkedHashSet suitesSet = new LinkedHashSet(suites.length); + Collections.addAll(suitesSet, suites); + unmodifiableCiphers = new ArrayList(suitesSet); + + this.apn = checkNotNull(apn, "apn"); + + // Create a new SSL_CTX and configure it. + boolean success = false; + try { + boolean tlsv13Supported = OpenSsl.isTlsv13Supported(); + + try { + int protocolOpts = SSL.SSL_PROTOCOL_SSLV3 | SSL.SSL_PROTOCOL_TLSV1 | + SSL.SSL_PROTOCOL_TLSV1_1 | SSL.SSL_PROTOCOL_TLSV1_2; + if (tlsv13Supported) { + protocolOpts |= SSL.SSL_PROTOCOL_TLSV1_3; + } + ctx = SSLContext.make(protocolOpts, mode); + } catch (Exception e) { + throw new SSLException("failed to create an SSL_CTX", e); + } + + StringBuilder cipherBuilder = new StringBuilder(); + StringBuilder cipherTLSv13Builder = new StringBuilder(); + + /* List the ciphers that are permitted to negotiate. 
*/ + try { + if (unmodifiableCiphers.isEmpty()) { + // Set non TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, StringUtil.EMPTY_STRING, false); + if (tlsv13Supported) { + // Set TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, StringUtil.EMPTY_STRING, true); + } + } else { + CipherSuiteConverter.convertToCipherStrings( + unmodifiableCiphers, cipherBuilder, cipherTLSv13Builder, OpenSsl.isBoringSSL()); + + // Set non TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, cipherBuilder.toString(), false); + if (tlsv13Supported) { + // Set TLSv1.3 ciphers. + SSLContext.setCipherSuite(ctx, + OpenSsl.checkTls13Ciphers(logger, cipherTLSv13Builder.toString()), true); + } + } + } catch (SSLException e) { + throw e; + } catch (Exception e) { + throw new SSLException("failed to set cipher suite: " + unmodifiableCiphers, e); + } + + int options = SSLContext.getOptions(ctx) | + SSL.SSL_OP_NO_SSLv2 | + SSL.SSL_OP_NO_SSLv3 | + // Disable TLSv1 and TLSv1.1 by default as these are not considered secure anymore + // and the JDK is doing the same: + // https://www.oracle.com/java/technologies/javase/8u291-relnotes.html + SSL.SSL_OP_NO_TLSv1 | + SSL.SSL_OP_NO_TLSv1_1 | + + SSL.SSL_OP_CIPHER_SERVER_PREFERENCE | + + // We do not support compression at the moment so we should explicitly disable it. + SSL.SSL_OP_NO_COMPRESSION | + + // Disable ticket support by default to be more inline with SSLEngineImpl of the JDK. + // This also let SSLSession.getId() work the same way for the JDK implementation and the + // OpenSSLEngine. If tickets are supported SSLSession.getId() will only return an ID on the + // server-side if it could make use of tickets. 
+ SSL.SSL_OP_NO_TICKET; + + if (cipherBuilder.length() == 0) { + // No ciphers that are compatible with SSLv2 / SSLv3 / TLSv1 / TLSv1.1 / TLSv1.2 + options |= SSL.SSL_OP_NO_SSLv2 | SSL.SSL_OP_NO_SSLv3 | SSL.SSL_OP_NO_TLSv1 + | SSL.SSL_OP_NO_TLSv1_1 | SSL.SSL_OP_NO_TLSv1_2; + } + + if (!tlsv13Supported) { + // Explicit disable TLSv1.3 + // See https://github.com/netty/netty/issues/12968 + options |= SSL.SSL_OP_NO_TLSv1_3; + } + + SSLContext.setOptions(ctx, options); + + // We need to enable SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER as the memory address may change between + // calling OpenSSLEngine.wrap(...). + // See https://github.com/netty/netty-tcnative/issues/100 + SSLContext.setMode(ctx, SSLContext.getMode(ctx) | SSL.SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER); + + if (DH_KEY_LENGTH != null) { + SSLContext.setTmpDHLength(ctx, DH_KEY_LENGTH); + } + + List nextProtoList = apn.protocols(); + /* Set next protocols for next protocol negotiation extension, if specified */ + if (!nextProtoList.isEmpty()) { + String[] appProtocols = nextProtoList.toArray(EmptyArrays.EMPTY_STRINGS); + int selectorBehavior = opensslSelectorFailureBehavior(apn.selectorFailureBehavior()); + + switch (apn.protocol()) { + case NPN: + SSLContext.setNpnProtos(ctx, appProtocols, selectorBehavior); + break; + case ALPN: + SSLContext.setAlpnProtos(ctx, appProtocols, selectorBehavior); + break; + case NPN_AND_ALPN: + SSLContext.setNpnProtos(ctx, appProtocols, selectorBehavior); + SSLContext.setAlpnProtos(ctx, appProtocols, selectorBehavior); + break; + default: + throw new Error(); + } + } + + if (enableOcsp) { + SSLContext.enableOcsp(ctx, isClient()); + } + + SSLContext.setUseTasks(ctx, useTasks); + if (privateKeyMethod != null) { + SSLContext.setPrivateKeyMethod(ctx, new PrivateKeyMethod(engineMap, privateKeyMethod)); + } + if (asyncPrivateKeyMethod != null) { + SSLContext.setPrivateKeyMethod(ctx, new AsyncPrivateKeyMethod(engineMap, asyncPrivateKeyMethod)); + } + if (certCompressionConfig != null) { + for 
(OpenSslCertificateCompressionConfig.AlgorithmConfig configPair : certCompressionConfig) { + final CertificateCompressionAlgo algo = new CompressionAlgorithm(engineMap, configPair.algorithm()); + switch (configPair.mode()) { + case Decompress: + SSLContext.addCertificateCompressionAlgorithm( + ctx, SSL.SSL_CERT_COMPRESSION_DIRECTION_DECOMPRESS, algo); + break; + case Compress: + SSLContext.addCertificateCompressionAlgorithm( + ctx, SSL.SSL_CERT_COMPRESSION_DIRECTION_COMPRESS, algo); + break; + case Both: + SSLContext.addCertificateCompressionAlgorithm( + ctx, SSL.SSL_CERT_COMPRESSION_DIRECTION_BOTH, algo); + break; + default: + throw new IllegalStateException(); + } + } + } + if (maxCertificateList != null) { + SSLContext.setMaxCertList(ctx, maxCertificateList); + } + // Set the curves. + SSLContext.setCurvesList(ctx, OpenSsl.NAMED_GROUPS); + success = true; + } finally { + if (!success) { + release(); + } + } + } + + private static int opensslSelectorFailureBehavior(ApplicationProtocolConfig.SelectorFailureBehavior behavior) { + switch (behavior) { + case NO_ADVERTISE: + return SSL.SSL_SELECTOR_FAILURE_NO_ADVERTISE; + case CHOOSE_MY_LAST_PROTOCOL: + return SSL.SSL_SELECTOR_FAILURE_CHOOSE_MY_LAST_PROTOCOL; + default: + throw new Error(); + } + } + + @Override + public final List cipherSuites() { + return unmodifiableCiphers; + } + + @Override + public ApplicationProtocolNegotiator applicationProtocolNegotiator() { + return apn; + } + + @Override + public final boolean isClient() { + return mode == SSL.SSL_MODE_CLIENT; + } + + @Override + public final SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort) { + return newEngine0(alloc, peerHost, peerPort, true); + } + + @Override + protected final SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { + return new SslHandler(newEngine0(alloc, null, -1, false), startTls); + } + + @Override + protected final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, 
boolean startTls) { + return new SslHandler(newEngine0(alloc, peerHost, peerPort, false), startTls); + } + + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls, Executor executor) { + return new SslHandler(newEngine0(alloc, null, -1, false), startTls, executor); + } + + @Override + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, + boolean startTls, Executor executor) { + return new SslHandler(newEngine0(alloc, peerHost, peerPort, false), executor); + } + + SSLEngine newEngine0(ByteBufAllocator alloc, String peerHost, int peerPort, boolean jdkCompatibilityMode) { + return new ReferenceCountedOpenSslEngine(this, alloc, peerHost, peerPort, jdkCompatibilityMode, true); + } + + /** + * Returns a new server-side {@link SSLEngine} with the current configuration. + */ + @Override + public final SSLEngine newEngine(ByteBufAllocator alloc) { + return newEngine(alloc, null, -1); + } + + /** + * Returns the pointer to the {@code SSL_CTX} object for this {@link ReferenceCountedOpenSslContext}. + * Be aware that it is freed as soon as the {@link #finalize()} method is called. + * At this point {@code 0} will be returned. + * + * @deprecated this method is considered unsafe as the returned pointer may be released later. Dont use it! + */ + @Deprecated + public final long context() { + return sslCtxPointer(); + } + + /** + * Returns the stats of this context. + * + * @deprecated use #sessionContext#stats() + */ + @Deprecated + public final OpenSslSessionStats stats() { + return sessionContext().stats(); + } + + /** + * {@deprecated Renegotiation is not supported} + * Specify if remote initiated renegotiation is supported or not. If not supported and the remote side tries + * to initiate a renegotiation a {@link SSLHandshakeException} will be thrown during decoding. 
     */
    @Deprecated
    public void setRejectRemoteInitiatedRenegotiation(boolean rejectRemoteInitiatedRenegotiation) {
        // Renegotiation support was removed; only the (already effective) "reject" setting is accepted.
        if (!rejectRemoteInitiatedRenegotiation) {
            throw new UnsupportedOperationException("Renegotiation is not supported");
        }
    }

    /**
     * {@deprecated Renegotiation is not supported}
     * @return {@code true} because renegotiation is not supported.
     */
    @Deprecated
    public boolean getRejectRemoteInitiatedRenegotiation() {
        return true;
    }

    /**
     * Set the size of the buffer used by the BIO for non-application based writes
     * (e.g. handshake, renegotiation, etc...).
     */
    public void setBioNonApplicationBufferSize(int bioNonApplicationBufferSize) {
        this.bioNonApplicationBufferSize =
                checkPositiveOrZero(bioNonApplicationBufferSize, "bioNonApplicationBufferSize");
    }

    /**
     * Returns the size of the buffer used by the BIO for non-application based writes.
     */
    public int getBioNonApplicationBufferSize() {
        return bioNonApplicationBufferSize;
    }

    /**
     * Sets the SSL session ticket keys of this context.
     *
     * @deprecated use {@link OpenSslSessionContext#setTicketKeys(byte[])}
     */
    @Deprecated
    public final void setTicketKeys(byte[] keys) {
        sessionContext().setTicketKeys(keys);
    }

    @Override
    public abstract OpenSslSessionContext sessionContext();

    /**
     * Returns the pointer to the {@code SSL_CTX} object for this {@link ReferenceCountedOpenSslContext}.
     * Be aware that it is freed as soon as the {@link #release()} method is called.
     * At this point {@code 0} will be returned.
     *
     * @deprecated this method is considered unsafe as the returned pointer may be released later. Don't use it!
     */
    @Deprecated
    public final long sslCtxPointer() {
        // Take the read lock so destroy() cannot free the native SSL_CTX while we read it.
        Lock readerLock = ctxLock.readLock();
        readerLock.lock();
        try {
            return SSLContext.getSslCtx(ctx);
        } finally {
            readerLock.unlock();
        }
    }

    /**
     * Set the {@link OpenSslPrivateKeyMethod} to use. This allows to offload private-key operations
     * if needed.
+ * + * This method is currently only supported when {@code BoringSSL} is used. + * + * @param method method to use. + * @deprecated use {@link SslContextBuilder#option(SslContextOption, Object)} with + * {@link OpenSslContextOption#PRIVATE_KEY_METHOD}. + */ + @Deprecated + @UnstableApi + public final void setPrivateKeyMethod(OpenSslPrivateKeyMethod method) { + checkNotNull(method, "method"); + Lock writerLock = ctxLock.writeLock(); + writerLock.lock(); + try { + SSLContext.setPrivateKeyMethod(ctx, new PrivateKeyMethod(engineMap, method)); + } finally { + writerLock.unlock(); + } + } + + /** + * @deprecated use {@link SslContextBuilder#option(SslContextOption, Object)} with + * {@link OpenSslContextOption#USE_TASKS}. + */ + @Deprecated + public final void setUseTasks(boolean useTasks) { + Lock writerLock = ctxLock.writeLock(); + writerLock.lock(); + try { + SSLContext.setUseTasks(ctx, useTasks); + } finally { + writerLock.unlock(); + } + } + + // IMPORTANT: This method must only be called from either the constructor or the finalizer as a user MUST never + // get access to an OpenSslSessionContext after this method was called to prevent the user from + // producing a segfault. 
+ private void destroy() { + Lock writerLock = ctxLock.writeLock(); + writerLock.lock(); + try { + if (ctx != 0) { + if (enableOcsp) { + SSLContext.disableOcsp(ctx); + } + + SSLContext.free(ctx); + ctx = 0; + + OpenSslSessionContext context = sessionContext(); + if (context != null) { + context.destroy(); + } + } + } finally { + writerLock.unlock(); + } + } + + protected static X509Certificate[] certificates(byte[][] chain) { + X509Certificate[] peerCerts = new X509Certificate[chain.length]; + for (int i = 0; i < peerCerts.length; i++) { + peerCerts[i] = new LazyX509Certificate(chain[i]); + } + return peerCerts; + } + + protected static X509TrustManager chooseTrustManager(TrustManager[] managers) { + for (TrustManager m : managers) { + if (m instanceof X509TrustManager) { + X509TrustManager tm = (X509TrustManager) m; + if (PlatformDependent.javaVersion() >= 7) { + tm = OpenSslX509TrustManagerWrapper.wrapIfNeeded((X509TrustManager) m); + if (useExtendedTrustManager(tm)) { + // Wrap the TrustManager to provide a better exception message for users to debug hostname + // validation failures. + tm = new EnhancingX509ExtendedTrustManager(tm); + } + } + return tm; + } + } + throw new IllegalStateException("no X509TrustManager found"); + } + + protected static X509KeyManager chooseX509KeyManager(KeyManager[] kms) { + for (KeyManager km : kms) { + if (km instanceof X509KeyManager) { + return (X509KeyManager) km; + } + } + throw new IllegalStateException("no X509KeyManager found"); + } + + /** + * Translate a {@link ApplicationProtocolConfig} object to a + * {@link OpenSslApplicationProtocolNegotiator} object. 
+ * + * @param config The configuration which defines the translation + * @return The results of the translation + */ + @SuppressWarnings("deprecation") + static OpenSslApplicationProtocolNegotiator toNegotiator(ApplicationProtocolConfig config) { + if (config == null) { + return NONE_PROTOCOL_NEGOTIATOR; + } + + switch (config.protocol()) { + case NONE: + return NONE_PROTOCOL_NEGOTIATOR; + case ALPN: + case NPN: + case NPN_AND_ALPN: + switch (config.selectedListenerFailureBehavior()) { + case CHOOSE_MY_LAST_PROTOCOL: + case ACCEPT: + switch (config.selectorFailureBehavior()) { + case CHOOSE_MY_LAST_PROTOCOL: + case NO_ADVERTISE: + return new OpenSslDefaultApplicationProtocolNegotiator( + config); + default: + throw new UnsupportedOperationException( + new StringBuilder("OpenSSL provider does not support ") + .append(config.selectorFailureBehavior()) + .append(" behavior").toString()); + } + default: + throw new UnsupportedOperationException( + new StringBuilder("OpenSSL provider does not support ") + .append(config.selectedListenerFailureBehavior()) + .append(" behavior").toString()); + } + default: + throw new Error(); + } + } + + @SuppressJava6Requirement(reason = "Guarded by java version check") + static boolean useExtendedTrustManager(X509TrustManager trustManager) { + return PlatformDependent.javaVersion() >= 7 && trustManager instanceof X509ExtendedTrustManager; + } + + @Override + public final int refCnt() { + return refCnt.refCnt(); + } + + @Override + public final ReferenceCounted retain() { + refCnt.retain(); + return this; + } + + @Override + public final ReferenceCounted retain(int increment) { + refCnt.retain(increment); + return this; + } + + @Override + public final ReferenceCounted touch() { + refCnt.touch(); + return this; + } + + @Override + public final ReferenceCounted touch(Object hint) { + refCnt.touch(hint); + return this; + } + + @Override + public final boolean release() { + return refCnt.release(); + } + + @Override + public final boolean 
release(int decrement) {
        return refCnt.release(decrement);
    }

    /**
     * Base class for the native certificate-verification callback. Subclasses implement the actual
     * validation in {@link #verify(ReferenceCountedOpenSslEngine, X509Certificate[], String)}; any thrown
     * exception is translated here into an OpenSSL {@code X509_V_ERR_*} code.
     */
    abstract static class AbstractCertificateVerifier extends CertificateVerifier {
        private final OpenSslEngineMap engineMap;

        AbstractCertificateVerifier(OpenSslEngineMap engineMap) {
            this.engineMap = engineMap;
        }

        @Override
        public final int verify(long ssl, byte[][] chain, String auth) {
            final ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
            if (engine == null) {
                // May be null if it was destroyed in the meantime.
                return CertificateVerifier.X509_V_ERR_UNSPECIFIED;
            }
            X509Certificate[] peerCerts = certificates(chain);
            try {
                verify(engine, peerCerts, auth);
                return CertificateVerifier.X509_V_OK;
            } catch (Throwable cause) {
                logger.debug("verification of certificate failed", cause);
                // Remember the failure so the engine can surface it as a handshake exception later.
                engine.initHandshakeException(cause);

                // Try to extract the correct error code that should be used.
                if (cause instanceof OpenSslCertificateException) {
                    // This will never return a negative error code as it's validated when constructing the
                    // OpenSslCertificateException.
                    return ((OpenSslCertificateException) cause).errorCode();
                }
                if (cause instanceof CertificateExpiredException) {
                    return CertificateVerifier.X509_V_ERR_CERT_HAS_EXPIRED;
                }
                if (cause instanceof CertificateNotYetValidException) {
                    return CertificateVerifier.X509_V_ERR_CERT_NOT_YET_VALID;
                }
                if (PlatformDependent.javaVersion() >= 7) {
                    // Java 7+ only: inspects CertificateRevokedException / CertPathValidatorException.
                    return translateToError(cause);
                }

                // Could not detect a specific error code to use, so fall back to a default code.
                return CertificateVerifier.X509_V_ERR_UNSPECIFIED;
            }
        }

        @SuppressJava6Requirement(reason = "Usage guarded by java version check")
        private static int translateToError(Throwable cause) {
            if (cause instanceof CertificateRevokedException) {
                return CertificateVerifier.X509_V_ERR_CERT_REVOKED;
            }

            // The X509TrustManagerImpl uses a Validator which wraps a CertPathValidatorException into
            // a CertificateException.
So we need to handle the wrapped CertPathValidatorException to be
            // able to send the correct alert.
            Throwable wrapped = cause.getCause();
            // Walk the full cause chain: the CertPathValidatorException may be nested several levels deep.
            while (wrapped != null) {
                if (wrapped instanceof CertPathValidatorException) {
                    CertPathValidatorException ex = (CertPathValidatorException) wrapped;
                    CertPathValidatorException.Reason reason = ex.getReason();
                    if (reason == CertPathValidatorException.BasicReason.EXPIRED) {
                        return CertificateVerifier.X509_V_ERR_CERT_HAS_EXPIRED;
                    }
                    if (reason == CertPathValidatorException.BasicReason.NOT_YET_VALID) {
                        return CertificateVerifier.X509_V_ERR_CERT_NOT_YET_VALID;
                    }
                    if (reason == CertPathValidatorException.BasicReason.REVOKED) {
                        return CertificateVerifier.X509_V_ERR_CERT_REVOKED;
                    }
                }
                wrapped = wrapped.getCause();
            }
            return CertificateVerifier.X509_V_ERR_UNSPECIFIED;
        }

        // Performs the actual validation; throwing from here is mapped to an X509_V_ERR_* code above.
        abstract void verify(ReferenceCountedOpenSslEngine engine, X509Certificate[] peerCerts,
                             String auth) throws Exception;
    }

    /** Thread-safe registry mapping native {@code SSL*} pointers to their Java engine wrappers. */
    private static final class DefaultOpenSslEngineMap implements OpenSslEngineMap {
        // NOTE(review): the Map's generic parameters appear to have been stripped in this copy;
        // presumably Map<Long, ReferenceCountedOpenSslEngine> — confirm against upstream netty.
        private final Map engines = PlatformDependent.newConcurrentHashMap();

        @Override
        public ReferenceCountedOpenSslEngine remove(long ssl) {
            return engines.remove(ssl);
        }

        @Override
        public void add(ReferenceCountedOpenSslEngine engine) {
            engines.put(engine.sslPointer(), engine);
        }

        @Override
        public ReferenceCountedOpenSslEngine get(long ssl) {
            return engines.get(ssl);
        }
    }

    /**
     * Installs the certificate chain and (optionally password-protected) private key on the native
     * {@code SSL_CTX}. The PEM data is encoded once and fed through temporary in-memory BIOs that are
     * always freed before returning.
     */
    static void setKeyMaterial(long ctx, X509Certificate[] keyCertChain, PrivateKey key, String keyPassword)
            throws SSLException {
        /* Load the certificate file and private key.
 */
        long keyBio = 0;
        long keyCertChainBio = 0;
        long keyCertChainBio2 = 0;
        PemEncoded encoded = null;
        try {
            // Only encode one time; two retained references feed the two BIOs below.
            encoded = PemX509Certificate.toPEM(ByteBufAllocator.DEFAULT, true, keyCertChain);
            keyCertChainBio = toBIO(ByteBufAllocator.DEFAULT, encoded.retain());
            keyCertChainBio2 = toBIO(ByteBufAllocator.DEFAULT, encoded.retain());

            if (key != null) {
                keyBio = toBIO(ByteBufAllocator.DEFAULT, key);
            }

            SSLContext.setCertificateBio(
                    ctx, keyCertChainBio, keyBio,
                    keyPassword == null ? StringUtil.EMPTY_STRING : keyPassword);
            // We may have more than one cert in the chain so add all of them now.
            SSLContext.setCertificateChainBio(ctx, keyCertChainBio2, true);
        } catch (SSLException e) {
            throw e;
        } catch (Exception e) {
            throw new SSLException("failed to set certificate and key", e);
        } finally {
            // Native BIOs and the PEM encoding must always be released, success or failure.
            freeBio(keyBio);
            freeBio(keyCertChainBio);
            freeBio(keyCertChainBio2);
            if (encoded != null) {
                encoded.release();
            }
        }
    }

    // Frees a native BIO; 0 means "never allocated" and is a no-op.
    static void freeBio(long bio) {
        if (bio != 0) {
            SSL.freeBIO(bio);
        }
    }

    /**
     * Return the pointer to an in-memory BIO
     * or {@code 0} if the {@code key} is {@code null}. The BIO contains the content of the {@code key}.
     */
    static long toBIO(ByteBufAllocator allocator, PrivateKey key) throws Exception {
        if (key == null) {
            return 0;
        }

        PemEncoded pem = PemPrivateKey.toPEM(allocator, true, key);
        try {
            // toBIO(allocator, PemEncoded) consumes its own reference, hence the retain() here.
            return toBIO(allocator, pem.retain());
        } finally {
            pem.release();
        }
    }

    /**
     * Return the pointer to an in-memory BIO
     * or {@code 0} if the {@code certChain} is {@code null}. The BIO contains the content of the {@code certChain}.
     */
    static long toBIO(ByteBufAllocator allocator, X509Certificate...
certChain) throws Exception {
        if (certChain == null) {
            return 0;
        }

        checkNonEmpty(certChain, "certChain");

        PemEncoded pem = PemX509Certificate.toPEM(allocator, true, certChain);
        try {
            // toBIO(allocator, PemEncoded) consumes its own reference, hence the retain() here.
            return toBIO(allocator, pem.retain());
        } finally {
            pem.release();
        }
    }

    /**
     * Copies the PEM content into a native in-memory BIO and returns its pointer.
     * Consumes (releases) the passed-in {@code pem} reference.
     */
    static long toBIO(ByteBufAllocator allocator, PemEncoded pem) throws Exception {
        try {
            // We can turn direct buffers straight into BIOs. No need to
            // make yet another copy.
            ByteBuf content = pem.content();

            if (content.isDirect()) {
                return newBIO(content.retainedSlice());
            }

            // Heap buffer: copy into a direct buffer first, since the native side needs a memory address.
            ByteBuf buffer = allocator.directBuffer(content.readableBytes());
            try {
                buffer.writeBytes(content, content.readerIndex(), content.readableBytes());
                return newBIO(buffer.retainedSlice());
            } finally {
                try {
                    // If the contents of the ByteBuf is sensitive (e.g. a PrivateKey) we
                    // need to zero out the bytes of the copy before we're releasing it.
                    if (pem.isSensitive()) {
                        SslUtils.zeroout(buffer);
                    }
                } finally {
                    buffer.release();
                }
            }
        } finally {
            pem.release();
        }
    }

    // Creates a native memory BIO from a direct buffer; always releases the buffer.
    private static long newBIO(ByteBuf buffer) throws Exception {
        try {
            long bio = SSL.newMemBIO();
            int readable = buffer.readableBytes();
            if (SSL.bioWrite(bio, OpenSsl.memoryAddress(buffer) + buffer.readerIndex(), readable) != readable) {
                // Partial write: free the BIO before failing so no native memory leaks.
                SSL.freeBIO(bio);
                throw new IllegalStateException("Could not write data to memory BIO");
            }
            return bio;
        } finally {
            buffer.release();
        }
    }

    /**
     * Returns the {@link OpenSslKeyMaterialProvider} that should be used for OpenSSL. Depending on the given
     * {@link KeyManagerFactory} this may cache the {@link OpenSslKeyMaterial} for better performance if it can
     * ensure that the same material is always returned for the same alias.
     */
    static OpenSslKeyMaterialProvider providerFor(KeyManagerFactory factory, String password) {
        if (factory instanceof OpenSslX509KeyManagerFactory) {
            return ((OpenSslX509KeyManagerFactory) factory).newProvider();
        }

        if (factory instanceof OpenSslCachingX509KeyManagerFactory) {
            // The user explicitly used OpenSslCachingX509KeyManagerFactory which signals us that it's fine to cache.
            return ((OpenSslCachingX509KeyManagerFactory) factory).newProvider(password);
        }
        // We can not be sure if the material may change at runtime so we will not cache it.
        return new OpenSslKeyMaterialProvider(chooseX509KeyManager(factory.getKeyManagers()), password);
    }

    // Looks up the engine for a native SSL pointer, failing loudly if it has already been removed.
    private static ReferenceCountedOpenSslEngine retrieveEngine(OpenSslEngineMap engineMap, long ssl)
            throws SSLException {
        ReferenceCountedOpenSslEngine engine = engineMap.get(ssl);
        if (engine == null) {
            throw new SSLException("Could not find a "
                    + StringUtil.simpleClassName(ReferenceCountedOpenSslEngine.class) + " for sslPointer " + ssl);
        }
        return engine;
    }

    /**
     * Bridges a user-supplied {@link OpenSslPrivateKeyMethod} to the synchronous tcnative
     * {@code SSLPrivateKeyMethod} callback interface.
     */
    private static final class PrivateKeyMethod implements SSLPrivateKeyMethod {

        private final OpenSslEngineMap engineMap;
        private final OpenSslPrivateKeyMethod keyMethod;

        PrivateKeyMethod(OpenSslEngineMap engineMap, OpenSslPrivateKeyMethod keyMethod) {
            this.engineMap = engineMap;
            this.keyMethod = keyMethod;
        }

        @Override
        public byte[] sign(long ssl, int signatureAlgorithm, byte[] digest) throws Exception {
            ReferenceCountedOpenSslEngine engine = retrieveEngine(engineMap, ssl);
            try {
                return verifyResult(keyMethod.sign(engine, signatureAlgorithm, digest));
            } catch (Exception e) {
                // Record the failure on the engine so the handshake surfaces it, then rethrow to native.
                engine.initHandshakeException(e);
                throw e;
            }
        }

        @Override
        public byte[] decrypt(long ssl, byte[] input) throws Exception {
            ReferenceCountedOpenSslEngine engine = retrieveEngine(engineMap, ssl);
            try {
                return verifyResult(keyMethod.decrypt(engine, input));
            } catch (Exception e) {
                // Record the failure on the engine so the handshake surfaces it, then rethrow to native.
                engine.initHandshakeException(e);
                throw e;
            }
        }
    }

    /**
     * Bridges a user-supplied {@link OpenSslAsyncPrivateKeyMethod} to the asynchronous tcnative
     * {@code AsyncSSLPrivateKeyMethod} callback interface. Results are delivered via the native
     * {@code ResultCallback} once the user's Future completes.
     *
     * NOTE(review): the generic parameters of ResultCallback/Future/FutureListener appear to have been
     * stripped in this copy (presumably {@code <byte[]>}); confirm against upstream netty.
     */
    private static final class AsyncPrivateKeyMethod implements AsyncSSLPrivateKeyMethod {

        private final OpenSslEngineMap engineMap;
        private final OpenSslAsyncPrivateKeyMethod keyMethod;

        AsyncPrivateKeyMethod(OpenSslEngineMap engineMap, OpenSslAsyncPrivateKeyMethod keyMethod) {
            this.engineMap = engineMap;
            this.keyMethod = keyMethod;
        }

        @Override
        public void sign(long ssl, int signatureAlgorithm, byte[] bytes, ResultCallback resultCallback) {
            try {
                ReferenceCountedOpenSslEngine engine = retrieveEngine(engineMap, ssl);
                keyMethod.sign(engine, signatureAlgorithm, bytes)
                        .addListener(new ResultCallbackListener(engine, ssl, resultCallback));
            } catch (SSLException e) {
                // Engine lookup failed; report straight back to native.
                resultCallback.onError(ssl, e);
            }
        }

        @Override
        public void decrypt(long ssl, byte[] bytes, ResultCallback resultCallback) {
            try {
                ReferenceCountedOpenSslEngine engine = retrieveEngine(engineMap, ssl);
                keyMethod.decrypt(engine, bytes)
                        .addListener(new ResultCallbackListener(engine, ssl, resultCallback));
            } catch (SSLException e) {
                // Engine lookup failed; report straight back to native.
                resultCallback.onError(ssl, e);
            }
        }

        /** Forwards the completed Future's result (or failure) to the native result callback. */
        private static final class ResultCallbackListener implements FutureListener {
            private final ReferenceCountedOpenSslEngine engine;
            private final long ssl;
            private final ResultCallback resultCallback;

            ResultCallbackListener(ReferenceCountedOpenSslEngine engine, long ssl,
                                   ResultCallback resultCallback) {
                this.engine = engine;
                this.ssl = ssl;
                this.resultCallback = resultCallback;
            }

            @Override
            public void operationComplete(Future future) {
                Throwable cause = future.cause();
                if (cause == null) {
                    try {
                        byte[] result = verifyResult(future.getNow());
                        resultCallback.onSuccess(ssl, result);
                        return;
                    } catch (SignatureException e) {
                        // A null result is treated as a signing failure.
                        cause = e;
                        engine.initHandshakeException(e);
                    }
                }
                resultCallback.onError(ssl, cause);
            }
        }
    }

    // Rejects null results from user key-method callbacks with a SignatureException.
    private static byte[] verifyResult(byte[] result) throws SignatureException {
        if (result == null)
{
            throw new SignatureException();
        }
        return result;
    }

    /**
     * Bridges a user-supplied {@link OpenSslCertificateCompressionAlgorithm} to the native
     * {@code CertificateCompressionAlgo} callback interface.
     */
    private static final class CompressionAlgorithm implements CertificateCompressionAlgo {
        private final OpenSslEngineMap engineMap;
        private final OpenSslCertificateCompressionAlgorithm compressionAlgorithm;

        CompressionAlgorithm(OpenSslEngineMap engineMap, OpenSslCertificateCompressionAlgorithm compressionAlgorithm) {
            this.engineMap = engineMap;
            this.compressionAlgorithm = compressionAlgorithm;
        }

        @Override
        public byte[] compress(long ssl, byte[] bytes) throws Exception {
            ReferenceCountedOpenSslEngine engine = retrieveEngine(engineMap, ssl);
            return compressionAlgorithm.compress(engine, bytes);
        }

        @Override
        public byte[] decompress(long ssl, int len, byte[] bytes) throws Exception {
            ReferenceCountedOpenSslEngine engine = retrieveEngine(engineMap, ssl);
            return compressionAlgorithm.decompress(engine, len, bytes);
        }

        @Override
        public int algorithmId() {
            return compressionAlgorithm.algorithmId();
        }
    }
}
diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
new file mode 100644
index 0000000..5f81c4a
--- /dev/null
+++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslEngine.java
@@ -0,0 +1,2761 @@
+/*
+ * Copyright 2016 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.handler.ssl.util.LazyJavaxX509Certificate; +import io.netty.handler.ssl.util.LazyX509Certificate; +import io.netty.internal.tcnative.AsyncTask; +import io.netty.internal.tcnative.Buffer; +import io.netty.internal.tcnative.SSL; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.ResourceLeakDetector; +import io.netty.util.ResourceLeakDetectorFactory; +import io.netty.util.ResourceLeakTracker; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.ReadOnlyBufferException; +import java.security.Principal; +import java.security.cert.Certificate; +import java.security.cert.CertificateEncodingException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.locks.Lock; + +import javax.crypto.spec.SecretKeySpec; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLPeerUnverifiedException; +import 
javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionBindingEvent; +import javax.net.ssl.SSLSessionBindingListener; + +import static io.netty.handler.ssl.OpenSsl.memoryAddress; +import static io.netty.handler.ssl.SslUtils.SSL_RECORD_HEADER_LENGTH; +import static io.netty.util.internal.EmptyArrays.EMPTY_STRINGS; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkNotNullArrayParam; +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; +import static java.lang.Integer.MAX_VALUE; +import static java.lang.Math.min; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.FINISHED; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_TASK; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_UNWRAP; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NEED_WRAP; +import static javax.net.ssl.SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING; +import static javax.net.ssl.SSLEngineResult.Status.BUFFER_OVERFLOW; +import static javax.net.ssl.SSLEngineResult.Status.BUFFER_UNDERFLOW; +import static javax.net.ssl.SSLEngineResult.Status.CLOSED; +import static javax.net.ssl.SSLEngineResult.Status.OK; + +/** + * Implements a {@link SSLEngine} using + * OpenSSL BIO abstractions. + *

Instances of this class must be {@link #release() released} or else native memory will leak! + * + *

Instances of this class must be released before the {@link ReferenceCountedOpenSslContext} + * the instance depends upon are released. Otherwise if any method of this class is called which uses the + * the {@link ReferenceCountedOpenSslContext} JNI resources the JVM may crash. + */ +public class ReferenceCountedOpenSslEngine extends SSLEngine implements ReferenceCounted, ApplicationProtocolAccessor { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ReferenceCountedOpenSslEngine.class); + + private static final ResourceLeakDetector leakDetector = + ResourceLeakDetectorFactory.instance().newResourceLeakDetector(ReferenceCountedOpenSslEngine.class); + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV2 = 0; + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV3 = 1; + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1 = 2; + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_1 = 3; + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2 = 4; + private static final int OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3 = 5; + private static final int[] OPENSSL_OP_NO_PROTOCOLS = { + SSL.SSL_OP_NO_SSLv2, + SSL.SSL_OP_NO_SSLv3, + SSL.SSL_OP_NO_TLSv1, + SSL.SSL_OP_NO_TLSv1_1, + SSL.SSL_OP_NO_TLSv1_2, + SSL.SSL_OP_NO_TLSv1_3 + }; + + /** + * Depends upon tcnative ... only use if tcnative is available! + */ + static final int MAX_PLAINTEXT_LENGTH = SSL.SSL_MAX_PLAINTEXT_LENGTH; + /** + * Depends upon tcnative ... only use if tcnative is available! 
+ */ + static final int MAX_RECORD_SIZE = SSL.SSL_MAX_RECORD_LENGTH; + + private static final SSLEngineResult NEED_UNWRAP_OK = new SSLEngineResult(OK, NEED_UNWRAP, 0, 0); + private static final SSLEngineResult NEED_UNWRAP_CLOSED = new SSLEngineResult(CLOSED, NEED_UNWRAP, 0, 0); + private static final SSLEngineResult NEED_WRAP_OK = new SSLEngineResult(OK, NEED_WRAP, 0, 0); + private static final SSLEngineResult NEED_WRAP_CLOSED = new SSLEngineResult(CLOSED, NEED_WRAP, 0, 0); + private static final SSLEngineResult CLOSED_NOT_HANDSHAKING = new SSLEngineResult(CLOSED, NOT_HANDSHAKING, 0, 0); + + // OpenSSL state + private long ssl; + private long networkBIO; + + private enum HandshakeState { + /** + * Not started yet. + */ + NOT_STARTED, + /** + * Started via unwrap/wrap. + */ + STARTED_IMPLICITLY, + /** + * Started via {@link #beginHandshake()}. + */ + STARTED_EXPLICITLY, + /** + * Handshake is finished. + */ + FINISHED + } + + private HandshakeState handshakeState = HandshakeState.NOT_STARTED; + private boolean receivedShutdown; + private volatile boolean destroyed; + private volatile String applicationProtocol; + private volatile boolean needTask; + private String[] explicitlyEnabledProtocols; + private boolean sessionSet; + + // Reference Counting + private final ResourceLeakTracker leak; + private final AbstractReferenceCounted refCnt = new AbstractReferenceCounted() { + @Override + public ReferenceCounted touch(Object hint) { + if (leak != null) { + leak.record(hint); + } + + return ReferenceCountedOpenSslEngine.this; + } + + @Override + protected void deallocate() { + shutdown(); + if (leak != null) { + boolean closed = leak.close(ReferenceCountedOpenSslEngine.this); + assert closed; + } + parentContext.release(); + } + }; + + private volatile ClientAuth clientAuth = ClientAuth.NONE; + + // Updated once a new handshake is started and so the SSLSession reused. 
+ private volatile long lastAccessed = -1; + + private String endPointIdentificationAlgorithm; + // Store as object as AlgorithmConstraints only exists since java 7. + private Object algorithmConstraints; + private List sniHostNames; + + // Mark as volatile as accessed by checkSniHostnameMatch(...) and also not specify the SNIMatcher type to allow us + // using it with java7. + private volatile Collection matchers; + + // SSL Engine status variables + private boolean isInboundDone; + private boolean outboundClosed; + + final boolean jdkCompatibilityMode; + private final boolean clientMode; + final ByteBufAllocator alloc; + private final OpenSslEngineMap engineMap; + private final OpenSslApplicationProtocolNegotiator apn; + private final ReferenceCountedOpenSslContext parentContext; + private final OpenSslSession session; + private final ByteBuffer[] singleSrcBuffer = new ByteBuffer[1]; + private final ByteBuffer[] singleDstBuffer = new ByteBuffer[1]; + private final boolean enableOcsp; + private int maxWrapOverhead; + private int maxWrapBufferSize; + private Throwable pendingException; + + /** + * Create a new instance. + * @param context Reference count release responsibility is not transferred! The callee still owns this object. + * @param alloc The allocator to use. + * @param peerHost The peer host name. + * @param peerPort The peer port. + * @param jdkCompatibilityMode {@code true} to behave like described in + * https://docs.oracle.com/javase/7/docs/api/javax/net/ssl/SSLEngine.html. + * {@code false} allows for partial and/or multiple packets to be process in a single + * wrap or unwrap call. + * @param leakDetection {@code true} to enable leak detection of this object. 
+ */ + ReferenceCountedOpenSslEngine(ReferenceCountedOpenSslContext context, final ByteBufAllocator alloc, String peerHost, + int peerPort, boolean jdkCompatibilityMode, boolean leakDetection) { + super(peerHost, peerPort); + OpenSsl.ensureAvailability(); + engineMap = context.engineMap; + enableOcsp = context.enableOcsp; + this.jdkCompatibilityMode = jdkCompatibilityMode; + this.alloc = checkNotNull(alloc, "alloc"); + apn = (OpenSslApplicationProtocolNegotiator) context.applicationProtocolNegotiator(); + clientMode = context.isClient(); + + if (PlatformDependent.javaVersion() >= 7) { + session = new ExtendedOpenSslSession(new DefaultOpenSslSession(context.sessionContext())) { + private String[] peerSupportedSignatureAlgorithms; + private List requestedServerNames; + + @Override + public List getRequestedServerNames() { + if (clientMode) { + return Java8SslUtils.getSniHostNames(sniHostNames); + } else { + synchronized (ReferenceCountedOpenSslEngine.this) { + if (requestedServerNames == null) { + if (isDestroyed()) { + requestedServerNames = Collections.emptyList(); + } else { + String name = SSL.getSniHostname(ssl); + if (name == null) { + requestedServerNames = Collections.emptyList(); + } else { + // Convert to bytes as we do not want to do any strict validation of the + // SNIHostName while creating it. 
+ requestedServerNames = + Java8SslUtils.getSniHostName( + SSL.getSniHostname(ssl).getBytes(CharsetUtil.UTF_8)); + } + } + } + return requestedServerNames; + } + } + } + + @Override + public String[] getPeerSupportedSignatureAlgorithms() { + synchronized (ReferenceCountedOpenSslEngine.this) { + if (peerSupportedSignatureAlgorithms == null) { + if (isDestroyed()) { + peerSupportedSignatureAlgorithms = EMPTY_STRINGS; + } else { + String[] algs = SSL.getSigAlgs(ssl); + if (algs == null) { + peerSupportedSignatureAlgorithms = EMPTY_STRINGS; + } else { + Set algorithmList = new LinkedHashSet(algs.length); + for (String alg: algs) { + String converted = SignatureAlgorithmConverter.toJavaName(alg); + + if (converted != null) { + algorithmList.add(converted); + } + } + peerSupportedSignatureAlgorithms = algorithmList.toArray(EMPTY_STRINGS); + } + } + } + return peerSupportedSignatureAlgorithms.clone(); + } + } + + @Override + public List getStatusResponses() { + byte[] ocspResponse = null; + if (enableOcsp && clientMode) { + synchronized (ReferenceCountedOpenSslEngine.this) { + if (!isDestroyed()) { + ocspResponse = SSL.getOcspResponse(ssl); + } + } + } + return ocspResponse == null ? + Collections.emptyList() : Collections.singletonList(ocspResponse); + } + }; + } else { + session = new DefaultOpenSslSession(context.sessionContext()); + } + + if (!context.sessionContext().useKeyManager()) { + session.setLocalCertificate(context.keyCertChain); + } + + Lock readerLock = context.ctxLock.readLock(); + readerLock.lock(); + final long finalSsl; + try { + finalSsl = SSL.newSSL(context.ctx, !context.isClient()); + } finally { + readerLock.unlock(); + } + synchronized (this) { + ssl = finalSsl; + try { + networkBIO = SSL.bioNewByteBuffer(ssl, context.getBioNonApplicationBufferSize()); + + // Set the client auth mode, this needs to be done via setClientAuth(...) method so we actually call the + // needed JNI methods. + setClientAuth(clientMode ? 
ClientAuth.NONE : context.clientAuth); + + if (context.protocols != null) { + setEnabledProtocols0(context.protocols, true); + } else { + this.explicitlyEnabledProtocols = getEnabledProtocols(); + } + + // Use SNI if peerHost was specified and a valid hostname + // See https://github.com/netty/netty/issues/4746 + if (clientMode && SslUtils.isValidHostNameForSNI(peerHost)) { + // If on java8 and later we should do some extra validation to ensure we can construct the + // SNIHostName later again. + if (PlatformDependent.javaVersion() >= 8) { + if (Java8SslUtils.isValidHostNameForSNI(peerHost)) { + SSL.setTlsExtHostName(ssl, peerHost); + sniHostNames = Collections.singletonList(peerHost); + } + } else { + SSL.setTlsExtHostName(ssl, peerHost); + sniHostNames = Collections.singletonList(peerHost); + } + } + + if (enableOcsp) { + SSL.enableOcsp(ssl); + } + + if (!jdkCompatibilityMode) { + SSL.setMode(ssl, SSL.getMode(ssl) | SSL.SSL_MODE_ENABLE_PARTIAL_WRITE); + } + + if (isProtocolEnabled(SSL.getOptions(ssl), SSL.SSL_OP_NO_TLSv1_3, SslProtocols.TLS_v1_3)) { + final boolean enableTickets = clientMode ? + ReferenceCountedOpenSslContext.CLIENT_ENABLE_SESSION_TICKET_TLSV13 : + ReferenceCountedOpenSslContext.SERVER_ENABLE_SESSION_TICKET_TLSV13; + if (enableTickets) { + // We should enable session tickets for stateless resumption when TLSv1.3 is enabled. This + // is also done by OpenJDK and without this session resumption does not work at all with + // BoringSSL when TLSv1.3 is used as BoringSSL only supports stateless resumption with TLSv1.3: + // + // See: + // - https://bugs.openjdk.java.net/browse/JDK-8223922 + // - https://boringssl.googlesource.com/boringssl/+/refs/heads/master/ssl/tls13_server.cc#104 + SSL.clearOptions(ssl, SSL.SSL_OP_NO_TICKET); + } + } + + if (OpenSsl.isBoringSSL() && clientMode) { + // If in client-mode and BoringSSL let's allow to renegotiate once as the server may use this + // for client auth. 
+ // + // See https://github.com/netty/netty/issues/11529 + SSL.setRenegotiateMode(ssl, SSL.SSL_RENEGOTIATE_ONCE); + } + // setMode may impact the overhead. + calculateMaxWrapOverhead(); + } catch (Throwable cause) { + // Call shutdown so we are sure we correctly release all native memory and also guard against the + // case when shutdown() will be called by the finalizer again. + shutdown(); + + PlatformDependent.throwException(cause); + } + } + + // Now that everything looks good and we're going to successfully return the + // object so we need to retain a reference to the parent context. + parentContext = context; + parentContext.retain(); + + // Only create the leak after everything else was executed and so ensure we don't produce a false-positive for + // the ResourceLeakDetector. + leak = leakDetection ? leakDetector.track(this) : null; + } + + final synchronized String[] authMethods() { + if (isDestroyed()) { + return EMPTY_STRINGS; + } + return SSL.authenticationMethods(ssl); + } + + final boolean setKeyMaterial(OpenSslKeyMaterial keyMaterial) throws Exception { + synchronized (this) { + if (isDestroyed()) { + return false; + } + SSL.setKeyMaterial(ssl, keyMaterial.certificateChainAddress(), keyMaterial.privateKeyAddress()); + } + session.setLocalCertificate(keyMaterial.certificateChain()); + return true; + } + + final synchronized SecretKeySpec masterKey() { + if (isDestroyed()) { + return null; + } + return new SecretKeySpec(SSL.getMasterKey(ssl), "AES"); + } + + synchronized boolean isSessionReused() { + if (isDestroyed()) { + return false; + } + return SSL.isSessionReused(ssl); + } + + /** + * Sets the OCSP response. 
+ */ + @UnstableApi + public void setOcspResponse(byte[] response) { + if (!enableOcsp) { + throw new IllegalStateException("OCSP stapling is not enabled"); + } + + if (clientMode) { + throw new IllegalStateException("Not a server SSLEngine"); + } + + synchronized (this) { + if (!isDestroyed()) { + SSL.setOcspResponse(ssl, response); + } + } + } + + /** + * Returns the OCSP response or {@code null} if the server didn't provide a stapled OCSP response. + */ + @UnstableApi + public byte[] getOcspResponse() { + if (!enableOcsp) { + throw new IllegalStateException("OCSP stapling is not enabled"); + } + + if (!clientMode) { + throw new IllegalStateException("Not a client SSLEngine"); + } + + synchronized (this) { + if (isDestroyed()) { + return EmptyArrays.EMPTY_BYTES; + } + return SSL.getOcspResponse(ssl); + } + } + + @Override + public final int refCnt() { + return refCnt.refCnt(); + } + + @Override + public final ReferenceCounted retain() { + refCnt.retain(); + return this; + } + + @Override + public final ReferenceCounted retain(int increment) { + refCnt.retain(increment); + return this; + } + + @Override + public final ReferenceCounted touch() { + refCnt.touch(); + return this; + } + + @Override + public final ReferenceCounted touch(Object hint) { + refCnt.touch(hint); + return this; + } + + @Override + public final boolean release() { + return refCnt.release(); + } + + @Override + public final boolean release(int decrement) { + return refCnt.release(decrement); + } + + // These method will override the method defined by Java 8u251 and later. As we may compile with an earlier + // java8 version we don't use @Override annotations here. + public String getApplicationProtocol() { + return applicationProtocol; + } + + // These method will override the method defined by Java 8u251 and later. As we may compile with an earlier + // java8 version we don't use @Override annotations here. 
+ public String getHandshakeApplicationProtocol() { + return applicationProtocol; + } + + @Override + public final synchronized SSLSession getHandshakeSession() { + // Javadocs state return value should be: + // null if this instance is not currently handshaking, or if the current handshake has not + // progressed far enough to create a basic SSLSession. Otherwise, this method returns the + // SSLSession currently being negotiated. + switch(handshakeState) { + case NOT_STARTED: + case FINISHED: + return null; + default: + return session; + } + } + + /** + * Returns the pointer to the {@code SSL} object for this {@link ReferenceCountedOpenSslEngine}. + * Be aware that it is freed as soon as the {@link #release()} or {@link #shutdown()} methods are called. + * At this point {@code 0} will be returned. + */ + public final synchronized long sslPointer() { + return ssl; + } + + /** + * Destroys this engine. + */ + public final synchronized void shutdown() { + if (!destroyed) { + destroyed = true; + // Let's check if engineMap is null as it could be in theory if we throw an OOME during the construction of + // ReferenceCountedOpenSslEngine (before we assign the field). This is needed as shutdown() is called from + // the finalizer as well. + if (engineMap != null) { + engineMap.remove(ssl); + } + SSL.freeSSL(ssl); + ssl = networkBIO = 0; + + isInboundDone = outboundClosed = true; + } + + // On shutdown clear all errors + SSL.clearError(); + } + + /** + * Write plaintext data to the OpenSSL internal BIO + * + * Calling this function with src.remaining == 0 is undefined. 
+ */ + private int writePlaintextData(final ByteBuffer src, int len) { + final int pos = src.position(); + final int limit = src.limit(); + final int sslWrote; + + if (src.isDirect()) { + sslWrote = SSL.writeToSSL(ssl, bufferAddress(src) + pos, len); + if (sslWrote > 0) { + src.position(pos + sslWrote); + } + } else { + ByteBuf buf = alloc.directBuffer(len); + try { + src.limit(pos + len); + + buf.setBytes(0, src); + src.limit(limit); + + sslWrote = SSL.writeToSSL(ssl, memoryAddress(buf), len); + if (sslWrote > 0) { + src.position(pos + sslWrote); + } else { + src.position(pos); + } + } finally { + buf.release(); + } + } + return sslWrote; + } + + synchronized void bioSetFd(int fd) { + if (!isDestroyed()) { + SSL.bioSetFd(this.ssl, fd); + } + } + + /** + * Write encrypted data to the OpenSSL network BIO. + */ + private ByteBuf writeEncryptedData(final ByteBuffer src, int len) throws SSLException { + final int pos = src.position(); + if (src.isDirect()) { + SSL.bioSetByteBuffer(networkBIO, bufferAddress(src) + pos, len, false); + } else { + final ByteBuf buf = alloc.directBuffer(len); + try { + final int limit = src.limit(); + src.limit(pos + len); + buf.writeBytes(src); + // Restore the original position and limit because we don't want to consume from `src`. 
+ src.position(pos); + src.limit(limit); + + SSL.bioSetByteBuffer(networkBIO, memoryAddress(buf), len, false); + return buf; + } catch (Throwable cause) { + buf.release(); + PlatformDependent.throwException(cause); + } + } + return null; + } + + /** + * Read plaintext data from the OpenSSL internal BIO + */ + private int readPlaintextData(final ByteBuffer dst) throws SSLException { + final int sslRead; + final int pos = dst.position(); + if (dst.isDirect()) { + sslRead = SSL.readFromSSL(ssl, bufferAddress(dst) + pos, dst.limit() - pos); + if (sslRead > 0) { + dst.position(pos + sslRead); + } + } else { + final int limit = dst.limit(); + final int len = min(maxEncryptedPacketLength0(), limit - pos); + final ByteBuf buf = alloc.directBuffer(len); + try { + sslRead = SSL.readFromSSL(ssl, memoryAddress(buf), len); + if (sslRead > 0) { + dst.limit(pos + sslRead); + buf.getBytes(buf.readerIndex(), dst); + dst.limit(limit); + } + } finally { + buf.release(); + } + } + + return sslRead; + } + + /** + * Visible only for testing! + */ + final synchronized int maxWrapOverhead() { + return maxWrapOverhead; + } + + /** + * Visible only for testing! + */ + final synchronized int maxEncryptedPacketLength() { + return maxEncryptedPacketLength0(); + } + + /** + * This method is intentionally not synchronized, only use if you know you are in the EventLoop + * thread and visibility on {@link #maxWrapOverhead} is achieved via other synchronized blocks. + */ + final int maxEncryptedPacketLength0() { + return maxWrapOverhead + MAX_PLAINTEXT_LENGTH; + } + + /** + * This method is intentionally not synchronized, only use if you know you are in the EventLoop + * thread and visibility on {@link #maxWrapBufferSize} and {@link #maxWrapOverhead} is achieved + * via other synchronized blocks. + *
+ * Calculates the max size of a single wrap operation for the given plaintextLength and + * numComponents. + */ + final int calculateMaxLengthForWrap(int plaintextLength, int numComponents) { + return (int) min(maxWrapBufferSize, plaintextLength + (long) maxWrapOverhead * numComponents); + } + + /** + * This method is intentionally not synchronized, only use if you know you are in the EventLoop + * thread and visibility on {@link #maxWrapOverhead} is achieved via other synchronized blocks. + *
+ * Calculates the size of the out net buf to create for the given plaintextLength and numComponents. + * This is not related to the max size per wrap, as we can wrap chunks at a time into one out net buf. + */ + final int calculateOutNetBufSize(int plaintextLength, int numComponents) { + return (int) min(MAX_VALUE, plaintextLength + (long) maxWrapOverhead * numComponents); + } + + final synchronized int sslPending() { + return sslPending0(); + } + + /** + * It is assumed this method is called in a synchronized block (or the constructor)! + */ + private void calculateMaxWrapOverhead() { + maxWrapOverhead = SSL.getMaxWrapOverhead(ssl); + + // maxWrapBufferSize must be set after maxWrapOverhead because there is a dependency on this value. + // If jdkCompatibility mode is off we allow enough space to encrypt 16 buffers at a time. This could be + // configurable in the future if necessary. + maxWrapBufferSize = jdkCompatibilityMode ? maxEncryptedPacketLength0() : maxEncryptedPacketLength0() << 4; + } + + private int sslPending0() { + // OpenSSL has a limitation where if you call SSL_pending before the handshake is complete OpenSSL will throw a + // "called a function you should not call" error. Using the TLS_method instead of SSLv23_method may solve this + // issue but this API is only available in 1.1.0+ [1]. + // [1] https://www.openssl.org/docs/man1.1.0/ssl/SSL_CTX_new.html + return handshakeState != HandshakeState.FINISHED ? 
0 : SSL.sslPending(ssl); + } + + private boolean isBytesAvailableEnoughForWrap(int bytesAvailable, int plaintextLength, int numComponents) { + return bytesAvailable - (long) maxWrapOverhead * numComponents >= plaintextLength; + } + + @Override + public final SSLEngineResult wrap( + final ByteBuffer[] srcs, int offset, final int length, final ByteBuffer dst) throws SSLException { + // Throw required runtime exceptions + checkNotNullWithIAE(srcs, "srcs"); + checkNotNullWithIAE(dst, "dst"); + + if (offset >= srcs.length || offset + length > srcs.length) { + throw new IndexOutOfBoundsException( + "offset: " + offset + ", length: " + length + + " (expected: offset <= offset + length <= srcs.length (" + srcs.length + "))"); + } + + if (dst.isReadOnly()) { + throw new ReadOnlyBufferException(); + } + + synchronized (this) { + if (isOutboundDone()) { + // All drained in the outbound buffer + return isInboundDone() || isDestroyed() ? CLOSED_NOT_HANDSHAKING : NEED_UNWRAP_CLOSED; + } + + int bytesProduced = 0; + ByteBuf bioReadCopyBuf = null; + try { + // Setup the BIO buffer so that we directly write the encryption results into dst. + if (dst.isDirect()) { + SSL.bioSetByteBuffer(networkBIO, bufferAddress(dst) + dst.position(), dst.remaining(), + true); + } else { + bioReadCopyBuf = alloc.directBuffer(dst.remaining()); + SSL.bioSetByteBuffer(networkBIO, memoryAddress(bioReadCopyBuf), bioReadCopyBuf.writableBytes(), + true); + } + + int bioLengthBefore = SSL.bioLengthByteBuffer(networkBIO); + + // Explicitly use outboundClosed as we want to drain any bytes that are still present. + if (outboundClosed) { + // If the outbound was closed we want to ensure we can produce the alert to the destination buffer. + // This is true even if we not using jdkCompatibilityMode. + // + // We use a plaintextLength of 2 as we at least want to have an alert fit into it. 
+ // https://tools.ietf.org/html/rfc5246#section-7.2 + if (!isBytesAvailableEnoughForWrap(dst.remaining(), 2, 1)) { + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); + } + + // There is something left to drain. + // See https://github.com/netty/netty/issues/6260 + bytesProduced = SSL.bioFlushByteBuffer(networkBIO); + if (bytesProduced <= 0) { + return newResultMayFinishHandshake(NOT_HANDSHAKING, 0, 0); + } + // It is possible when the outbound was closed there was not enough room in the non-application + // buffers to hold the close_notify. We should keep trying to close until we consume all the data + // OpenSSL can give us. + if (!doSSLShutdown()) { + return newResultMayFinishHandshake(NOT_HANDSHAKING, 0, bytesProduced); + } + bytesProduced = bioLengthBefore - SSL.bioLengthByteBuffer(networkBIO); + return newResultMayFinishHandshake(NEED_WRAP, 0, bytesProduced); + } + + // Flush any data that may be implicitly generated by OpenSSL (handshake, close, etc..). + SSLEngineResult.HandshakeStatus status = NOT_HANDSHAKING; + HandshakeState oldHandshakeState = handshakeState; + + // Prepare OpenSSL to work in server mode and receive handshake + if (handshakeState != HandshakeState.FINISHED) { + if (handshakeState != HandshakeState.STARTED_EXPLICITLY) { + // Update accepted so we know we triggered the handshake via wrap + handshakeState = HandshakeState.STARTED_IMPLICITLY; + } + + // Flush any data that may have been written implicitly during the handshake by OpenSSL. + bytesProduced = SSL.bioFlushByteBuffer(networkBIO); + + if (pendingException != null) { + // TODO(scott): It is possible that when the handshake failed there was not enough room in the + // non-application buffers to hold the alert. We should get all the data before progressing on. + // However I'm not aware of a way to do this with the OpenSSL APIs. + // See https://github.com/netty/netty/issues/6385. 
+ + // We produced / consumed some data during the handshake, signal back to the caller. + // If there is a handshake exception and we have produced data, we should send the data before + // we allow handshake() to throw the handshake exception. + // + // When the user calls wrap() again we will propagate the handshake error back to the user as + // soon as there is no more data to was produced (as part of an alert etc). + if (bytesProduced > 0) { + return newResult(NEED_WRAP, 0, bytesProduced); + } + // Nothing was produced see if there is a handshakeException that needs to be propagated + // to the caller by calling handshakeException() which will return the right HandshakeStatus + // if it can "recover" from the exception for now. + return newResult(handshakeException(), 0, 0); + } + + status = handshake(); + + // Handshake may have generated more data, for example if the internal SSL buffer is small + // we may have freed up space by flushing above. + bytesProduced = bioLengthBefore - SSL.bioLengthByteBuffer(networkBIO); + + if (status == NEED_TASK) { + return newResult(status, 0, bytesProduced); + } + + if (bytesProduced > 0) { + // If we have filled up the dst buffer and we have not finished the handshake we should try to + // wrap again. Otherwise we should only try to wrap again if there is still data pending in + // SSL buffers. + return newResult(mayFinishHandshake(status != FINISHED ? + bytesProduced == bioLengthBefore ? NEED_WRAP : + getHandshakeStatus(SSL.bioLengthNonApplication(networkBIO)) : FINISHED), + 0, bytesProduced); + } + + if (status == NEED_UNWRAP) { + // Signal if the outbound is done or not. + return isOutboundDone() ? NEED_UNWRAP_CLOSED : NEED_UNWRAP_OK; + } + + // Explicit use outboundClosed and not outboundClosed() as we want to drain any bytes that are + // still present. 
+ if (outboundClosed) { + bytesProduced = SSL.bioFlushByteBuffer(networkBIO); + return newResultMayFinishHandshake(status, 0, bytesProduced); + } + } + + final int endOffset = offset + length; + if (jdkCompatibilityMode || + // If the handshake was not finished before we entered the method, we also ensure we only + // wrap one record. We do this to ensure we not produce any extra data before the caller + // of the method is able to observe handshake completion and react on it. + oldHandshakeState != HandshakeState.FINISHED) { + int srcsLen = 0; + for (int i = offset; i < endOffset; ++i) { + final ByteBuffer src = srcs[i]; + if (src == null) { + throw new IllegalArgumentException("srcs[" + i + "] is null"); + } + if (srcsLen == MAX_PLAINTEXT_LENGTH) { + continue; + } + + srcsLen += src.remaining(); + if (srcsLen > MAX_PLAINTEXT_LENGTH || srcsLen < 0) { + // If srcLen > MAX_PLAINTEXT_LENGTH or secLen < 0 just set it to MAX_PLAINTEXT_LENGTH. + // This also help us to guard against overflow. + // We not break out here as we still need to check for null entries in srcs[]. + srcsLen = MAX_PLAINTEXT_LENGTH; + } + } + + // jdkCompatibilityMode will only produce a single TLS packet, and we don't aggregate src buffers, + // so we always fix the number of buffers to 1 when checking if the dst buffer is large enough. + if (!isBytesAvailableEnoughForWrap(dst.remaining(), srcsLen, 1)) { + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), 0, 0); + } + } + + // There was no pending data in the network BIO -- encrypt any application data + int bytesConsumed = 0; + assert bytesProduced == 0; + + // Flush any data that may have been written implicitly by OpenSSL in case a shutdown/alert occurs. + bytesProduced = SSL.bioFlushByteBuffer(networkBIO); + + if (bytesProduced > 0) { + return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced); + } + // There was a pending exception that we just delayed because there was something to produce left. 
+ // Throw it now and shutdown the engine. + if (pendingException != null) { + Throwable error = pendingException; + pendingException = null; + shutdown(); + // Throw a new exception wrapping the pending exception, so the stacktrace is meaningful and + // contains all the details. + throw new SSLException(error); + } + + for (; offset < endOffset; ++offset) { + final ByteBuffer src = srcs[offset]; + final int remaining = src.remaining(); + if (remaining == 0) { + continue; + } + + final int bytesWritten; + if (jdkCompatibilityMode) { + // Write plaintext application data to the SSL engine. We don't have to worry about checking + // if there is enough space if jdkCompatibilityMode because we only wrap at most + // MAX_PLAINTEXT_LENGTH and we loop over the input before hand and check if there is space. + bytesWritten = writePlaintextData(src, min(remaining, MAX_PLAINTEXT_LENGTH - bytesConsumed)); + } else { + // OpenSSL's SSL_write keeps state between calls. We should make sure the amount we attempt to + // write is guaranteed to succeed so we don't have to worry about keeping state consistent + // between calls. + final int availableCapacityForWrap = dst.remaining() - bytesProduced - maxWrapOverhead; + if (availableCapacityForWrap <= 0) { + return new SSLEngineResult(BUFFER_OVERFLOW, getHandshakeStatus(), bytesConsumed, + bytesProduced); + } + bytesWritten = writePlaintextData(src, min(remaining, availableCapacityForWrap)); + } + + // Determine how much encrypted data was generated. + // + // Even if SSL_write doesn't consume any application data it is possible that OpenSSL will + // produce non-application data into the BIO. For example session tickets.... 
+ // See https://github.com/netty/netty/issues/10041 + final int pendingNow = SSL.bioLengthByteBuffer(networkBIO); + bytesProduced += bioLengthBefore - pendingNow; + bioLengthBefore = pendingNow; + + if (bytesWritten > 0) { + bytesConsumed += bytesWritten; + + if (jdkCompatibilityMode || bytesProduced == dst.remaining()) { + return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced); + } + } else { + int sslError = SSL.getError(ssl, bytesWritten); + if (sslError == SSL.SSL_ERROR_ZERO_RETURN) { + // This means the connection was shutdown correctly, close inbound and outbound + if (!receivedShutdown) { + closeAll(); + + bytesProduced += bioLengthBefore - SSL.bioLengthByteBuffer(networkBIO); + + // If we have filled up the dst buffer and we have not finished the handshake we should + // try to wrap again. Otherwise we should only try to wrap again if there is still data + // pending in SSL buffers. + SSLEngineResult.HandshakeStatus hs = mayFinishHandshake( + status != FINISHED ? bytesProduced == dst.remaining() ? NEED_WRAP + : getHandshakeStatus(SSL.bioLengthNonApplication(networkBIO)) + : FINISHED); + return newResult(hs, bytesConsumed, bytesProduced); + } + + return newResult(NOT_HANDSHAKING, bytesConsumed, bytesProduced); + } else if (sslError == SSL.SSL_ERROR_WANT_READ) { + // If there is no pending data to read from BIO we should go back to event loop and try + // to read more data [1]. It is also possible that event loop will detect the socket has + // been closed. [1] https://www.openssl.org/docs/manmaster/ssl/SSL_write.html + return newResult(NEED_UNWRAP, bytesConsumed, bytesProduced); + } else if (sslError == SSL.SSL_ERROR_WANT_WRITE) { + // SSL_ERROR_WANT_WRITE typically means that the underlying transport is not writable + // and we should set the "want write" flag on the selector and try again when the + // underlying transport is writable [1]. 
However we are not directly writing to the + // underlying transport and instead writing to a BIO buffer. The OpenSsl documentation + // says we should do the following [1]: + // + // "When using a buffering BIO, like a BIO pair, data must be written into or retrieved + // out of the BIO before being able to continue." + // + // In practice this means the destination buffer doesn't have enough space for OpenSSL + // to write encrypted data to. This is an OVERFLOW condition. + // [1] https://www.openssl.org/docs/manmaster/ssl/SSL_write.html + if (bytesProduced > 0) { + // If we produced something we should report this back and let the user call + // wrap again. + return newResult(NEED_WRAP, bytesConsumed, bytesProduced); + } + return newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced); + } else if (sslError == SSL.SSL_ERROR_WANT_X509_LOOKUP || + sslError == SSL.SSL_ERROR_WANT_CERTIFICATE_VERIFY || + sslError == SSL.SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) { + + return newResult(NEED_TASK, bytesConsumed, bytesProduced); + } else { + // Everything else is considered as error + throw shutdownWithError("SSL_write", sslError); + } + } + } + return newResultMayFinishHandshake(status, bytesConsumed, bytesProduced); + } finally { + SSL.bioClearByteBuffer(networkBIO); + if (bioReadCopyBuf == null) { + dst.position(dst.position() + bytesProduced); + } else { + assert bioReadCopyBuf.readableBytes() <= dst.remaining() : "The destination buffer " + dst + + " didn't have enough remaining space to hold the encrypted content in " + bioReadCopyBuf; + dst.put(bioReadCopyBuf.internalNioBuffer(bioReadCopyBuf.readerIndex(), bytesProduced)); + bioReadCopyBuf.release(); + } + } + } + } + + private SSLEngineResult newResult(SSLEngineResult.HandshakeStatus hs, int bytesConsumed, int bytesProduced) { + return newResult(OK, hs, bytesConsumed, bytesProduced); + } + + private SSLEngineResult newResult(SSLEngineResult.Status status, SSLEngineResult.HandshakeStatus hs, + int 
                                      bytesConsumed, int bytesProduced) {
        // If isOutboundDone(), then the data drained from the network BIO was the close_notify
        // message and it was fully consumed: we are not required to wait for the receipt of the
        // peer's close_notify message before shutting down.
        if (isOutboundDone()) {
            if (isInboundDone()) {
                // If the inbound was done as well, we need to ensure we return NOT_HANDSHAKING to signal we are done.
                hs = NOT_HANDSHAKING;

                // As the inbound and the outbound is done we can shutdown the engine now.
                shutdown();
            }
            return new SSLEngineResult(CLOSED, hs, bytesConsumed, bytesProduced);
        }
        if (hs == NEED_TASK) {
            // Set needTask to true so getHandshakeStatus() will return the correct value.
            needTask = true;
        }
        return new SSLEngineResult(status, hs, bytesConsumed, bytesProduced);
    }

    // Builds a result with status OK, first giving mayFinishHandshake(...) a chance to promote the
    // handshake status to FINISHED when the handshake just completed.
    private SSLEngineResult newResultMayFinishHandshake(SSLEngineResult.HandshakeStatus hs,
                                                        int bytesConsumed, int bytesProduced) throws SSLException {
        return newResult(mayFinishHandshake(hs, bytesConsumed, bytesProduced), bytesConsumed, bytesProduced);
    }

    // Same as above, but with an explicit result status (e.g. BUFFER_OVERFLOW / BUFFER_UNDERFLOW).
    private SSLEngineResult newResultMayFinishHandshake(SSLEngineResult.Status status,
                                                        SSLEngineResult.HandshakeStatus hs,
                                                        int bytesConsumed, int bytesProduced) throws SSLException {
        return newResult(status, mayFinishHandshake(hs, bytesConsumed, bytesProduced), bytesConsumed, bytesProduced);
    }

    /**
     * Log the error, shutdown the engine and throw an exception.
     */
    private SSLException shutdownWithError(String operations, int sslError) {
        return shutdownWithError(operations, sslError, SSL.getLastErrorNumber());
    }

    private SSLException shutdownWithError(String operation, int sslError, int error) {
        if (logger.isDebugEnabled()) {
            String errorString = SSL.getErrorString(error);
            logger.debug("{} failed with {}: OpenSSL error: {} {}",
                         operation, sslError, error, errorString);
        }

        // There was an internal error -- shutdown
        shutdown();

        SSLException exception = newSSLExceptionForError(error);
        // If we have a pendingException stored already we should include it as well to help the user debug things.
        if (pendingException != null) {
            exception.initCause(pendingException);
            pendingException = null;
        }
        return exception;
    }

    // Converts an SSLException raised while moving data through the BIO into a proper unwrap
    // result (NEED_WRAP when an alert still has to be flushed), or rethrows it when the OpenSSL
    // error stack is clean.
    private SSLEngineResult handleUnwrapException(int bytesConsumed, int bytesProduced, SSLException e)
            throws SSLException {
        int lastError = SSL.getLastErrorNumber();
        if (lastError != 0) {
            return sslReadErrorResult(SSL.SSL_ERROR_SSL, lastError, bytesConsumed,
                    bytesProduced);
        }
        throw e;
    }

    public final SSLEngineResult unwrap(
            final ByteBuffer[] srcs, int srcsOffset, final int srcsLength,
            final ByteBuffer[] dsts, int dstsOffset, final int dstsLength) throws SSLException {

        // Throw required runtime exceptions
        checkNotNullWithIAE(srcs, "srcs");
        if (srcsOffset >= srcs.length
                || srcsOffset + srcsLength > srcs.length) {
            throw new IndexOutOfBoundsException(
                    "offset: " + srcsOffset + ", length: " + srcsLength +
                            " (expected: offset <= offset + length <= srcs.length (" + srcs.length + "))");
        }
        checkNotNullWithIAE(dsts, "dsts");
        if (dstsOffset >= dsts.length || dstsOffset + dstsLength > dsts.length) {
            throw new IndexOutOfBoundsException(
                    "offset: " + dstsOffset + ", length: " + dstsLength +
                            " (expected: offset <= offset + length <= dsts.length (" + dsts.length + "))");
        }
        // Total writable capacity across all destination buffers.
        long capacity = 0;
        final int dstsEndOffset = dstsOffset + dstsLength;
        for (int i = dstsOffset; i < dstsEndOffset; i ++) {
            ByteBuffer dst = checkNotNullArrayParam(dsts[i], i, "dsts");
            if (dst.isReadOnly()) {
                throw new ReadOnlyBufferException();
            }
            capacity += dst.remaining();
        }

        // Total readable length across all source buffers.
        final int srcsEndOffset = srcsOffset + srcsLength;
        long len = 0;
        for (int i = srcsOffset; i < srcsEndOffset; i++) {
            ByteBuffer src = checkNotNullArrayParam(srcs[i], i, "srcs");
            len += src.remaining();
        }

        synchronized (this) {
            if (isInboundDone()) {
                return isOutboundDone() || isDestroyed() ? CLOSED_NOT_HANDSHAKING : NEED_WRAP_CLOSED;
            }

            SSLEngineResult.HandshakeStatus status = NOT_HANDSHAKING;
            HandshakeState oldHandshakeState = handshakeState;
            // Prepare OpenSSL to work in server mode and receive handshake
            if (handshakeState != HandshakeState.FINISHED) {
                if (handshakeState != HandshakeState.STARTED_EXPLICITLY) {
                    // Update accepted so we know we triggered the handshake via wrap
                    handshakeState = HandshakeState.STARTED_IMPLICITLY;
                }

                status = handshake();

                if (status == NEED_TASK) {
                    return newResult(status, 0, 0);
                }

                if (status == NEED_WRAP) {
                    return NEED_WRAP_OK;
                }
                // Check if the inbound is considered to be closed if so let us try to wrap again.
                if (isInboundDone) {
                    return NEED_WRAP_CLOSED;
                }
            }

            int sslPending = sslPending0();
            int packetLength;
            // The JDK implies that only a single SSL packet should be processed per unwrap call [1]. If we are in
            // JDK compatibility mode then we should honor this, but if not we just wrap as much as possible. If there
            // are multiple records or partial records this may reduce thrashing events through the pipeline.
            // [1] https://docs.oracle.com/javase/7/docs/api/javax/net/ssl/SSLEngine.html
            if (jdkCompatibilityMode ||
                    // If the handshake was not finished before we entered the method, we also ensure we only
                    // unwrap one record. We do this to ensure we not produce any extra data before the caller
                    // of the method is able to observe handshake completion and react on it.
                    oldHandshakeState != HandshakeState.FINISHED) {
                if (len < SSL_RECORD_HEADER_LENGTH) {
                    return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0);
                }

                packetLength = SslUtils.getEncryptedPacketLength(srcs, srcsOffset);
                if (packetLength == SslUtils.NOT_ENCRYPTED) {
                    throw new NotSslRecordException("not an SSL/TLS record");
                }

                final int packetLengthDataOnly = packetLength - SSL_RECORD_HEADER_LENGTH;
                if (packetLengthDataOnly > capacity) {
                    // Not enough space in the destination buffer so signal the caller that the buffer needs to be
                    // increased.
                    if (packetLengthDataOnly > MAX_RECORD_SIZE) {
                        // The packet length MUST NOT exceed 2^14 [1]. However we do accommodate more data to support
                        // legacy use cases which may violate this condition (e.g. OpenJDK's SslEngineImpl). If the max
                        // length is exceeded we fail fast here to avoid an infinite loop due to the fact that we
                        // won't allocate a buffer large enough.
                        // [1] https://tools.ietf.org/html/rfc5246#section-6.2.1
                        // NOTE(review): the comparison is against MAX_RECORD_SIZE but the message prints
                        // session.getApplicationBufferSize() -- confirm these are intended to be the same value.
                        throw new SSLException("Illegal packet length: " + packetLengthDataOnly + " > " +
                                session.getApplicationBufferSize());
                    } else {
                        session.tryExpandApplicationBufferSize(packetLengthDataOnly);
                    }
                    return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0);
                }

                if (len < packetLength) {
                    // We either don't have enough data to read the packet length or not enough for reading the whole
                    // packet.
                    return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0);
                }
            } else if (len == 0 && sslPending <= 0) {
                return newResultMayFinishHandshake(BUFFER_UNDERFLOW, status, 0, 0);
            } else if (capacity == 0) {
                return newResultMayFinishHandshake(BUFFER_OVERFLOW, status, 0, 0);
            } else {
                packetLength = (int) min(MAX_VALUE, len);
            }

            // This must always be the case when we reached here as if not we returned BUFFER_UNDERFLOW.
            assert srcsOffset < srcsEndOffset;

            // This must always be the case if we reached here.
            assert capacity > 0;

            // Number of produced bytes
            int bytesProduced = 0;
            int bytesConsumed = 0;
            try {
                srcLoop:
                for (;;) {
                    ByteBuffer src = srcs[srcsOffset];
                    int remaining = src.remaining();
                    final ByteBuf bioWriteCopyBuf;
                    int pendingEncryptedBytes;
                    if (remaining == 0) {
                        if (sslPending <= 0) {
                            // We must skip empty buffers as BIO_write will return 0 if asked to write something
                            // with length 0.
                            if (++srcsOffset >= srcsEndOffset) {
                                break;
                            }
                            continue;
                        } else {
                            bioWriteCopyBuf = null;
                            pendingEncryptedBytes = SSL.bioLengthByteBuffer(networkBIO);
                        }
                    } else {
                        // Write more encrypted data into the BIO. Ensure we only read one packet at a time as
                        // stated in the SSLEngine javadocs.
                        pendingEncryptedBytes = min(packetLength, remaining);
                        try {
                            bioWriteCopyBuf = writeEncryptedData(src, pendingEncryptedBytes);
                        } catch (SSLException e) {
                            // Ensure we correctly handle the error stack.
                            return handleUnwrapException(bytesConsumed, bytesProduced, e);
                        }
                    }
                    try {
                        for (;;) {
                            ByteBuffer dst = dsts[dstsOffset];
                            if (!dst.hasRemaining()) {
                                // No space left in the destination buffer, skip it.
                                if (++dstsOffset >= dstsEndOffset) {
                                    break srcLoop;
                                }
                                continue;
                            }

                            int bytesRead;
                            try {
                                bytesRead = readPlaintextData(dst);
                            } catch (SSLException e) {
                                // Ensure we correctly handle the error stack.
                                return handleUnwrapException(bytesConsumed, bytesProduced, e);
                            }
                            // We are directly using the ByteBuffer memory for the write, and so we only know what has
                            // been consumed after we let SSL decrypt the data. At this point we should update the
                            // number of bytes consumed, update the ByteBuffer position, and release temp ByteBuf.
                            int localBytesConsumed = pendingEncryptedBytes - SSL.bioLengthByteBuffer(networkBIO);
                            bytesConsumed += localBytesConsumed;
                            packetLength -= localBytesConsumed;
                            pendingEncryptedBytes -= localBytesConsumed;
                            src.position(src.position() + localBytesConsumed);

                            if (bytesRead > 0) {
                                bytesProduced += bytesRead;

                                if (!dst.hasRemaining()) {
                                    sslPending = sslPending0();
                                    // Move to the next dst buffer as this one is full.
                                    if (++dstsOffset >= dstsEndOffset) {
                                        return sslPending > 0 ?
                                                newResult(BUFFER_OVERFLOW, status, bytesConsumed, bytesProduced) :
                                                newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status,
                                                        bytesConsumed, bytesProduced);
                                    }
                                } else if (packetLength == 0 || jdkCompatibilityMode) {
                                    // We either consumed all data or we are in jdkCompatibilityMode and have consumed
                                    // a single TLS packet and should stop consuming until this method is called again.
                                    break srcLoop;
                                }
                            } else {
                                int sslError = SSL.getError(ssl, bytesRead);
                                if (sslError == SSL.SSL_ERROR_WANT_READ || sslError == SSL.SSL_ERROR_WANT_WRITE) {
                                    // break to the outer loop as we want to read more data which means we need to
                                    // write more to the BIO.
                                    break;
                                } else if (sslError == SSL.SSL_ERROR_ZERO_RETURN) {
                                    // This means the connection was shutdown correctly, close inbound and outbound
                                    if (!receivedShutdown) {
                                        closeAll();
                                    }
                                    return newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status,
                                            bytesConsumed, bytesProduced);
                                } else if (sslError == SSL.SSL_ERROR_WANT_X509_LOOKUP ||
                                        sslError == SSL.SSL_ERROR_WANT_CERTIFICATE_VERIFY ||
                                        sslError == SSL.SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
                                    return newResult(isInboundDone() ? CLOSED : OK,
                                            NEED_TASK, bytesConsumed, bytesProduced);
                                } else {
                                    return sslReadErrorResult(sslError, SSL.getLastErrorNumber(), bytesConsumed,
                                            bytesProduced);
                                }
                            }
                        }

                        if (++srcsOffset >= srcsEndOffset) {
                            break;
                        }
                    } finally {
                        if (bioWriteCopyBuf != null) {
                            bioWriteCopyBuf.release();
                        }
                    }
                }
            } finally {
                SSL.bioClearByteBuffer(networkBIO);
                rejectRemoteInitiatedRenegotiation();
            }

            // Check to see if we received a close_notify message from the peer.
            if (!receivedShutdown && (SSL.getShutdown(ssl) & SSL.SSL_RECEIVED_SHUTDOWN) == SSL.SSL_RECEIVED_SHUTDOWN) {
                closeAll();
            }

            return newResultMayFinishHandshake(isInboundDone() ? CLOSED : OK, status, bytesConsumed, bytesProduced);
        }
    }

    private boolean needWrapAgain(int stackError) {
        // Check if we have a pending handshakeException and if so see if we need to consume all pending data from the
        // BIO first or can just shutdown and throw it now.
        // This is needed so we ensure close_notify etc is correctly send to the remote peer.
        // See https://github.com/netty/netty/issues/3900
        if (SSL.bioLengthNonApplication(networkBIO) > 0) {
            // we seem to have data left that needs to be transferred and so the user needs
            // call wrap(...). Store the error so we can pick it up later.
            if (pendingException == null) {
                pendingException = newSSLExceptionForError(stackError);
            } else if (shouldAddSuppressed(pendingException, stackError)) {
                ThrowableUtil.addSuppressed(pendingException, newSSLExceptionForError(stackError));
            }
            // We need to clear all errors so we not pick up anything that was left on the stack on the next
            // operation. Note that shutdownWithError(...)
            // will cleanup the stack as well so it is only needed here.
            SSL.clearError();
            return true;
        }
        return false;
    }

    // Wraps an OpenSSL error-stack code into the engine's exception type: handshake failures get an
    // OpenSslHandshakeException while post-handshake failures get an OpenSslException.
    private SSLException newSSLExceptionForError(int stackError) {
        String message = SSL.getErrorString(stackError);
        return handshakeState == HandshakeState.FINISHED ?
                new OpenSslException(message, stackError) : new OpenSslHandshakeException(message, stackError);
    }

    // Returns true only when no suppressed exception with this exact errorCode was recorded yet,
    // preventing duplicate suppressed entries on repeated failures.
    private static boolean shouldAddSuppressed(Throwable target, int errorCode) {
        for (Throwable suppressed: ThrowableUtil.getSuppressed(target)) {
            if (suppressed instanceof NativeSslException &&
                    ((NativeSslException) suppressed).errorCode() == errorCode) {
                // An exception with this errorCode was already added before.
                return false;
            }
        }
        return true;
    }

    private SSLEngineResult sslReadErrorResult(int error, int stackError, int bytesConsumed, int bytesProduced)
            throws SSLException {
        if (needWrapAgain(stackError)) {
            // There is something that needs to be sent to the remote peer before we can teardown.
            // This is most likely some alert.
            return new SSLEngineResult(OK, NEED_WRAP, bytesConsumed, bytesProduced);
        }
        throw shutdownWithError("SSL_read", error, stackError);
    }

    // Marks the close_notify as received and closes both directions of the engine.
    private void closeAll() throws SSLException {
        receivedShutdown = true;
        closeOutbound();
        closeInbound();
    }

    private void rejectRemoteInitiatedRenegotiation() throws SSLHandshakeException {
        // As rejectRemoteInitiatedRenegotiation() is called in a finally block we also need to check if we shutdown
        // the engine before as otherwise SSL.getHandshakeCount(ssl) will throw an NPE if the passed in ssl is 0.
        // See https://github.com/netty/netty/issues/7353
        if (!isDestroyed() && (!clientMode && SSL.getHandshakeCount(ssl) > 1 ||
                // Let's allow to renegotiate once for client auth.
                clientMode && SSL.getHandshakeCount(ssl) > 2) &&
                // As we may count multiple handshakes when TLSv1.3 is used we should just ignore this here as
                // renegotiation is not supported in TLSv1.3 as per spec.
                !SslProtocols.TLS_v1_3.equals(session.getProtocol()) && handshakeState == HandshakeState.FINISHED) {
            // TODO: In future versions we may also want to send a fatal_alert to the client and so notify it
            // that the renegotiation failed.
            shutdown();
            throw new SSLHandshakeException("remote-initiated renegotiation not allowed");
        }
    }

    // Convenience overload: unwrap across the full extent of both arrays.
    public final SSLEngineResult unwrap(final ByteBuffer[] srcs, final ByteBuffer[] dsts) throws SSLException {
        return unwrap(srcs, 0, srcs.length, dsts, 0, dsts.length);
    }

    // The reusable single-element arrays below avoid allocating a fresh ByteBuffer[] for the common
    // single-buffer wrap/unwrap calls. They are cleared after each use (see the reset* methods) so the
    // caller's buffer does not stay reachable from this engine.
    private ByteBuffer[] singleSrcBuffer(ByteBuffer src) {
        singleSrcBuffer[0] = src;
        return singleSrcBuffer;
    }

    private void resetSingleSrcBuffer() {
        singleSrcBuffer[0] = null;
    }

    private ByteBuffer[] singleDstBuffer(ByteBuffer src) {
        singleDstBuffer[0] = src;
        return singleDstBuffer;
    }

    private void resetSingleDstBuffer() {
        singleDstBuffer[0] = null;
    }

    @Override
    public final synchronized SSLEngineResult unwrap(
            final ByteBuffer src, final ByteBuffer[] dsts, final int offset, final int length) throws SSLException {
        try {
            return unwrap(singleSrcBuffer(src), 0, 1, dsts, offset, length);
        } finally {
            resetSingleSrcBuffer();
        }
    }

    @Override
    public final synchronized SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException {
        try {
            return wrap(singleSrcBuffer(src), dst);
        } finally {
            resetSingleSrcBuffer();
        }
    }

    @Override
    public final synchronized SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException {
        try {
            return unwrap(singleSrcBuffer(src), singleDstBuffer(dst));
        } finally {
            resetSingleSrcBuffer();
            resetSingleDstBuffer();
        }
    }

    @Override
    public final synchronized SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws
SSLException { + try { + return unwrap(singleSrcBuffer(src), dsts); + } finally { + resetSingleSrcBuffer(); + } + } + + private class TaskDecorator implements Runnable { + protected final R task; + TaskDecorator(R task) { + this.task = task; + } + + @Override + public void run() { + runAndResetNeedTask(task); + } + } + + private final class AsyncTaskDecorator extends TaskDecorator implements AsyncRunnable { + AsyncTaskDecorator(AsyncTask task) { + super(task); + } + + @Override + public void run(final Runnable runnable) { + if (isDestroyed()) { + // The engine was destroyed in the meantime, just return. + return; + } + task.runAsync(new TaskDecorator(runnable)); + } + } + + private synchronized void runAndResetNeedTask(Runnable task) { + try { + if (isDestroyed()) { + // The engine was destroyed in the meantime, just return. + return; + } + task.run(); + } finally { + // The task was run, reset needTask to false so getHandshakeStatus() returns the correct value. + needTask = false; + } + } + + @Override + public final synchronized Runnable getDelegatedTask() { + if (isDestroyed()) { + return null; + } + final Runnable task = SSL.getTask(ssl); + if (task == null) { + return null; + } + if (task instanceof AsyncTask) { + return new AsyncTaskDecorator((AsyncTask) task); + } + return new TaskDecorator(task); + } + + @Override + public final synchronized void closeInbound() throws SSLException { + if (isInboundDone) { + return; + } + + isInboundDone = true; + + if (isOutboundDone()) { + // Only call shutdown if there is no outbound data pending. 
+ // See https://github.com/netty/netty/issues/6167 + shutdown(); + } + + if (handshakeState != HandshakeState.NOT_STARTED && !receivedShutdown) { + throw new SSLException( + "Inbound closed before receiving peer's close_notify: possible truncation attack?"); + } + } + + @Override + public final synchronized boolean isInboundDone() { + return isInboundDone; + } + + @Override + public final synchronized void closeOutbound() { + if (outboundClosed) { + return; + } + + outboundClosed = true; + + if (handshakeState != HandshakeState.NOT_STARTED && !isDestroyed()) { + int mode = SSL.getShutdown(ssl); + if ((mode & SSL.SSL_SENT_SHUTDOWN) != SSL.SSL_SENT_SHUTDOWN) { + doSSLShutdown(); + } + } else { + // engine closing before initial handshake + shutdown(); + } + } + + /** + * Attempt to call {@link SSL#shutdownSSL(long)}. + * @return {@code false} if the call to {@link SSL#shutdownSSL(long)} was not attempted or returned an error. + */ + private boolean doSSLShutdown() { + if (SSL.isInInit(ssl) != 0) { + // Only try to call SSL_shutdown if we are not in the init state anymore. + // Otherwise we will see 'error:140E0197:SSL routines:SSL_shutdown:shutdown while in init' in our logs. + // + // See also https://hg.nginx.org/nginx/rev/062c189fee20 + return false; + } + int err = SSL.shutdownSSL(ssl); + if (err < 0) { + int sslErr = SSL.getError(ssl, err); + if (sslErr == SSL.SSL_ERROR_SYSCALL || sslErr == SSL.SSL_ERROR_SSL) { + if (logger.isDebugEnabled()) { + int error = SSL.getLastErrorNumber(); + logger.debug("SSL_shutdown failed: OpenSSL error: {} {}", error, SSL.getErrorString(error)); + } + // There was an internal error -- shutdown + shutdown(); + return false; + } + SSL.clearError(); + } + return true; + } + + @Override + public final synchronized boolean isOutboundDone() { + // Check if there is anything left in the outbound buffer. + // We need to ensure we only call SSL.pendingWrittenBytesInBIO(...) if the engine was not destroyed yet. 
+ return outboundClosed && (networkBIO == 0 || SSL.bioLengthNonApplication(networkBIO) == 0); + } + + @Override + public final String[] getSupportedCipherSuites() { + return OpenSsl.AVAILABLE_CIPHER_SUITES.toArray(EMPTY_STRINGS); + } + + @Override + public final String[] getEnabledCipherSuites() { + final String[] extraCiphers; + final String[] enabled; + final boolean tls13Enabled; + synchronized (this) { + if (!isDestroyed()) { + enabled = SSL.getCiphers(ssl); + int opts = SSL.getOptions(ssl); + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1_3, SslProtocols.TLS_v1_3)) { + extraCiphers = OpenSsl.EXTRA_SUPPORTED_TLS_1_3_CIPHERS; + tls13Enabled = true; + } else { + extraCiphers = EMPTY_STRINGS; + tls13Enabled = false; + } + } else { + return EMPTY_STRINGS; + } + } + if (enabled == null) { + return EMPTY_STRINGS; + } else { + Set enabledSet = new LinkedHashSet(enabled.length + extraCiphers.length); + synchronized (this) { + for (int i = 0; i < enabled.length; i++) { + String mapped = toJavaCipherSuite(enabled[i]); + final String cipher = mapped == null ? 
enabled[i] : mapped; + if ((!tls13Enabled || !OpenSsl.isTlsv13Supported()) && SslUtils.isTLSv13Cipher(cipher)) { + continue; + } + enabledSet.add(cipher); + } + Collections.addAll(enabledSet, extraCiphers); + } + return enabledSet.toArray(EMPTY_STRINGS); + } + } + + @Override + public final void setEnabledCipherSuites(String[] cipherSuites) { + checkNotNull(cipherSuites, "cipherSuites"); + + final StringBuilder buf = new StringBuilder(); + final StringBuilder bufTLSv13 = new StringBuilder(); + + CipherSuiteConverter.convertToCipherStrings(Arrays.asList(cipherSuites), buf, bufTLSv13, OpenSsl.isBoringSSL()); + final String cipherSuiteSpec = buf.toString(); + final String cipherSuiteSpecTLSv13 = bufTLSv13.toString(); + + if (!OpenSsl.isTlsv13Supported() && !cipherSuiteSpecTLSv13.isEmpty()) { + throw new IllegalArgumentException("TLSv1.3 is not supported by this java version."); + } + synchronized (this) { + if (!isDestroyed()) { + try { + // Set non TLSv1.3 ciphers. + SSL.setCipherSuites(ssl, cipherSuiteSpec, false); + if (OpenSsl.isTlsv13Supported()) { + // Set TLSv1.3 ciphers. + SSL.setCipherSuites(ssl, OpenSsl.checkTls13Ciphers(logger, cipherSuiteSpecTLSv13), true); + } + + // We also need to update the enabled protocols to ensure we disable the protocol if there are + // no compatible ciphers left. + Set protocols = new HashSet(explicitlyEnabledProtocols.length); + Collections.addAll(protocols, explicitlyEnabledProtocols); + + // We have no ciphers that are compatible with none-TLSv1.3, let us explicit disable all other + // protocols. + if (cipherSuiteSpec.isEmpty()) { + protocols.remove(SslProtocols.TLS_v1); + protocols.remove(SslProtocols.TLS_v1_1); + protocols.remove(SslProtocols.TLS_v1_2); + protocols.remove(SslProtocols.SSL_v3); + protocols.remove(SslProtocols.SSL_v2); + protocols.remove(SslProtocols.SSL_v2_HELLO); + } + // We have no ciphers that are compatible with TLSv1.3, let us explicit disable it. 
+ if (cipherSuiteSpecTLSv13.isEmpty()) { + protocols.remove(SslProtocols.TLS_v1_3); + } + // Update the protocols but not cache the value. We only cache when we call it from the user + // code or when we construct the engine. + setEnabledProtocols0(protocols.toArray(EMPTY_STRINGS), false); + } catch (Exception e) { + throw new IllegalStateException("failed to enable cipher suites: " + cipherSuiteSpec, e); + } + } else { + throw new IllegalStateException("failed to enable cipher suites: " + cipherSuiteSpec); + } + } + } + + @Override + public final String[] getSupportedProtocols() { + return OpenSsl.SUPPORTED_PROTOCOLS_SET.toArray(EMPTY_STRINGS); + } + + @Override + public final String[] getEnabledProtocols() { + List enabled = new ArrayList(6); + // Seems like there is no way to explicit disable SSLv2Hello in openssl so it is always enabled + enabled.add(SslProtocols.SSL_v2_HELLO); + + int opts; + synchronized (this) { + if (!isDestroyed()) { + opts = SSL.getOptions(ssl); + } else { + return enabled.toArray(EMPTY_STRINGS); + } + } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1, SslProtocols.TLS_v1)) { + enabled.add(SslProtocols.TLS_v1); + } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1_1, SslProtocols.TLS_v1_1)) { + enabled.add(SslProtocols.TLS_v1_1); + } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1_2, SslProtocols.TLS_v1_2)) { + enabled.add(SslProtocols.TLS_v1_2); + } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_TLSv1_3, SslProtocols.TLS_v1_3)) { + enabled.add(SslProtocols.TLS_v1_3); + } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_SSLv2, SslProtocols.SSL_v2)) { + enabled.add(SslProtocols.SSL_v2); + } + if (isProtocolEnabled(opts, SSL.SSL_OP_NO_SSLv3, SslProtocols.SSL_v3)) { + enabled.add(SslProtocols.SSL_v3); + } + return enabled.toArray(EMPTY_STRINGS); + } + + private static boolean isProtocolEnabled(int opts, int disableMask, String protocolString) { + // We also need to check if the actual protocolString is supported as depending on the openssl API + 
// implementations it may use a disableMask of 0 (BoringSSL is doing this for example). + return (opts & disableMask) == 0 && OpenSsl.SUPPORTED_PROTOCOLS_SET.contains(protocolString); + } + + /** + * {@inheritDoc} + * TLS doesn't support a way to advertise non-contiguous versions from the client's perspective, and the client + * just advertises the max supported version. The TLS protocol also doesn't support all different combinations of + * discrete protocols, and instead assumes contiguous ranges. OpenSSL has some unexpected behavior + * (e.g. handshake failures) if non-contiguous protocols are used even where there is a compatible set of protocols + * and ciphers. For these reasons this method will determine the minimum protocol and the maximum protocol and + * enabled a contiguous range from [min protocol, max protocol] in OpenSSL. + */ + @Override + public final void setEnabledProtocols(String[] protocols) { + setEnabledProtocols0(protocols, true); + } + + private void setEnabledProtocols0(String[] protocols, boolean cache) { + // This is correct from the API docs + checkNotNullWithIAE(protocols, "protocols"); + int minProtocolIndex = OPENSSL_OP_NO_PROTOCOLS.length; + int maxProtocolIndex = 0; + for (String p: protocols) { + if (!OpenSsl.SUPPORTED_PROTOCOLS_SET.contains(p)) { + throw new IllegalArgumentException("Protocol " + p + " is not supported."); + } + if (p.equals(SslProtocols.SSL_v2)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV2) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV2; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV2) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV2; + } + } else if (p.equals(SslProtocols.SSL_v3)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV3) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV3; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV3) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_SSLV3; + } + } else if 
(p.equals(SslProtocols.TLS_v1)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1; + } + } else if (p.equals(SslProtocols.TLS_v1_1)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_1) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_1; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_1) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_1; + } + } else if (p.equals(SslProtocols.TLS_v1_2)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_2; + } + } else if (p.equals(SslProtocols.TLS_v1_3)) { + if (minProtocolIndex > OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3) { + minProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3; + } + if (maxProtocolIndex < OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3) { + maxProtocolIndex = OPENSSL_OP_NO_PROTOCOL_INDEX_TLSv1_3; + } + } + } + synchronized (this) { + if (cache) { + this.explicitlyEnabledProtocols = protocols; + } + if (!isDestroyed()) { + // Clear out options which disable protocols + SSL.clearOptions(ssl, SSL.SSL_OP_NO_SSLv2 | SSL.SSL_OP_NO_SSLv3 | SSL.SSL_OP_NO_TLSv1 | + SSL.SSL_OP_NO_TLSv1_1 | SSL.SSL_OP_NO_TLSv1_2 | SSL.SSL_OP_NO_TLSv1_3); + + int opts = 0; + for (int i = 0; i < minProtocolIndex; ++i) { + opts |= OPENSSL_OP_NO_PROTOCOLS[i]; + } + assert maxProtocolIndex != MAX_VALUE; + for (int i = maxProtocolIndex + 1; i < OPENSSL_OP_NO_PROTOCOLS.length; ++i) { + opts |= OPENSSL_OP_NO_PROTOCOLS[i]; + } + + // Disable protocols we do not want + SSL.setOptions(ssl, opts); + } else { + throw new IllegalStateException("failed to enable protocols: " + Arrays.asList(protocols)); + } + } + } + + @Override + 
    public final SSLSession getSession() {
        return session;
    }

    @Override
    public final synchronized void beginHandshake() throws SSLException {
        switch (handshakeState) {
            case STARTED_IMPLICITLY:
                checkEngineClosed();

                // A user did not start handshake by calling this method by him/herself,
                // but handshake has been started already by wrap() or unwrap() implicitly.
                // Because it's the user's first time to call this method, it is unfair to
                // raise an exception. From the user's standpoint, he or she never asked
                // for renegotiation.

                handshakeState = HandshakeState.STARTED_EXPLICITLY; // Next time this method is invoked by the user,
                calculateMaxWrapOverhead();
                // we should raise an exception.
                break;
            case STARTED_EXPLICITLY:
                // Nothing to do as the handshake is not done yet.
                break;
            case FINISHED:
                throw new SSLException("renegotiation unsupported");
            case NOT_STARTED:
                handshakeState = HandshakeState.STARTED_EXPLICITLY;
                if (handshake() == NEED_TASK) {
                    // Set needTask to true so getHandshakeStatus() will return the correct value.
                    needTask = true;
                }
                calculateMaxWrapOverhead();
                break;
            default:
                throw new Error();
        }
    }

    // Guard used before any native interaction: once destroyed the ssl pointer must not be touched.
    private void checkEngineClosed() throws SSLException {
        if (isDestroyed()) {
            throw new SSLException("engine closed");
        }
    }

    private static SSLEngineResult.HandshakeStatus pendingStatus(int pendingStatus) {
        // Depending on if there is something left in the BIO we need to WRAP or UNWRAP
        return pendingStatus > 0 ? NEED_WRAP : NEED_UNWRAP;
    }

    private static boolean isEmpty(Object[] arr) {
        return arr == null || arr.length == 0;
    }

    private static boolean isEmpty(byte[] cert) {
        return cert == null || cert.length == 0;
    }

    /**
     * Turns the stored {@code pendingException} into the status/exception the caller must observe.
     * If non-application bytes are still buffered we first ask for a WRAP so the alert/close data is flushed.
     */
    private SSLEngineResult.HandshakeStatus handshakeException() throws SSLException {
        if (SSL.bioLengthNonApplication(networkBIO) > 0) {
            // There is something pending, we need to consume it first via a WRAP so we don't lose anything.
            return NEED_WRAP;
        }

        Throwable exception = pendingException;
        assert exception != null;
        pendingException = null;
        shutdown();
        if (exception instanceof SSLHandshakeException) {
            throw (SSLHandshakeException) exception;
        }
        SSLHandshakeException e = new SSLHandshakeException("General OpenSslEngine problem");
        e.initCause(exception);
        throw e;
    }

    /**
     * Should be called if the handshake will be failed due to a callback that throws an exception.
     * This cause will then be used to give more details as part of the {@link SSLHandshakeException}.
     */
    final void initHandshakeException(Throwable cause) {
        if (pendingException == null) {
            pendingException = cause;
        } else {
            ThrowableUtil.addSuppressed(pendingException, cause);
        }
    }

    /**
     * Drives one step of the native handshake and maps OpenSSL's result onto a {@link SSLEngineResult.HandshakeStatus}.
     * NOTE: statement order matters here — the engine must be registered in {@code engineMap} and the cached
     * session applied before {@code SSL.doHandshake(...)} is invoked.
     */
    private SSLEngineResult.HandshakeStatus handshake() throws SSLException {
        if (needTask) {
            return NEED_TASK;
        }
        if (handshakeState == HandshakeState.FINISHED) {
            return FINISHED;
        }

        checkEngineClosed();

        if (pendingException != null) {
            // Let's call SSL.doHandshake(...) again in case there is some async operation pending that would fill the
            // outbound buffer.
            if (SSL.doHandshake(ssl) <= 0) {
                // Clear any error that was put on the stack by the handshake
                SSL.clearError();
            }
            return handshakeException();
        }

        // Adding the OpenSslEngine to the OpenSslEngineMap so it can be used in the AbstractCertificateVerifier.
        engineMap.add(this);

        if (!sessionSet) {
            parentContext.sessionContext().setSessionFromCache(getPeerHost(), getPeerPort(), ssl);
            sessionSet = true;
        }

        if (lastAccessed == -1) {
            lastAccessed = System.currentTimeMillis();
        }

        int code = SSL.doHandshake(ssl);
        if (code <= 0) {
            int sslError = SSL.getError(ssl, code);
            if (sslError == SSL.SSL_ERROR_WANT_READ || sslError == SSL.SSL_ERROR_WANT_WRITE) {
                return pendingStatus(SSL.bioLengthNonApplication(networkBIO));
            }

            if (sslError == SSL.SSL_ERROR_WANT_X509_LOOKUP ||
                    sslError == SSL.SSL_ERROR_WANT_CERTIFICATE_VERIFY ||
                    sslError == SSL.SSL_ERROR_WANT_PRIVATE_KEY_OPERATION) {
                return NEED_TASK;
            }

            if (needWrapAgain(SSL.getLastErrorNumber())) {
                // There is something that needs to be send to the remote peer before we can teardown.
                // This is most likely some alert.
                return NEED_WRAP;
            }
            // Check if we have a pending exception that was created during the handshake and if so throw it after
            // shutdown the connection.
            if (pendingException != null) {
                return handshakeException();
            }

            // Everything else is considered as error
            throw shutdownWithError("SSL_do_handshake", sslError);
        }
        // We have produced more data as part of the handshake if this is the case the user should call wrap(...)
        if (SSL.bioLengthNonApplication(networkBIO) > 0) {
            return NEED_WRAP;
        }
        // if SSL_do_handshake returns > 0 or sslError == SSL.SSL_ERROR_NAME it means the handshake was finished.
        session.handshakeFinished(SSL.getSessionId(ssl), SSL.getCipherForSSL(ssl), SSL.getVersion(ssl),
                SSL.getPeerCertificate(ssl), SSL.getPeerCertChain(ssl),
                SSL.getTime(ssl) * 1000L, parentContext.sessionTimeout() * 1000L);
        selectApplicationProtocol();
        return FINISHED;
    }

    private SSLEngineResult.HandshakeStatus mayFinishHandshake(
            SSLEngineResult.HandshakeStatus hs, int bytesConsumed, int bytesProduced) throws SSLException {
        return hs == NEED_UNWRAP && bytesProduced > 0 || hs == NEED_WRAP && bytesConsumed > 0 ?
                handshake() : mayFinishHandshake(hs != FINISHED ? getHandshakeStatus() : FINISHED);
    }

    private SSLEngineResult.HandshakeStatus mayFinishHandshake(SSLEngineResult.HandshakeStatus status)
            throws SSLException {
        if (status == NOT_HANDSHAKING) {
            if (handshakeState != HandshakeState.FINISHED) {
                // If the status was NOT_HANDSHAKING and we not finished the handshake we need to call
                // SSL_do_handshake() again
                return handshake();
            }
            if (!isDestroyed() && SSL.bioLengthNonApplication(networkBIO) > 0) {
                // We have something left that needs to be wrapped.
                return NEED_WRAP;
            }
        }
        return status;
    }

    @Override
    public final synchronized SSLEngineResult.HandshakeStatus getHandshakeStatus() {
        // Check if we are in the initial handshake phase or shutdown phase
        if (needPendingStatus()) {
            if (needTask) {
                // There is a task outstanding
                return NEED_TASK;
            }
            return pendingStatus(SSL.bioLengthNonApplication(networkBIO));
        }
        return NOT_HANDSHAKING;
    }

    // Same as getHandshakeStatus() but reuses a BIO length the caller already read, avoiding a second JNI call.
    private SSLEngineResult.HandshakeStatus getHandshakeStatus(int pending) {
        // Check if we are in the initial handshake phase or shutdown phase
        if (needPendingStatus()) {
            if (needTask) {
                // There is a task outstanding
                return NEED_TASK;
            }
            return pendingStatus(pending);
        }
        return NOT_HANDSHAKING;
    }

    private boolean needPendingStatus() {
        return handshakeState != HandshakeState.NOT_STARTED && !isDestroyed()
                && (handshakeState != HandshakeState.FINISHED || isInboundDone() || isOutboundDone());
    }

    /**
     * Converts the specified OpenSSL cipher suite to the Java cipher suite.
     */
    private String toJavaCipherSuite(String openSslCipherSuite) {
        if (openSslCipherSuite == null) {
            return null;
        }

        String version = SSL.getVersion(ssl);
        String prefix = toJavaCipherSuitePrefix(version);
        return CipherSuiteConverter.toJava(openSslCipherSuite, prefix);
    }

    /**
     * Converts the protocol version string returned by {@link SSL#getVersion(long)} to protocol family string.
+ */ + private static String toJavaCipherSuitePrefix(String protocolVersion) { + final char c; + if (protocolVersion == null || protocolVersion.isEmpty()) { + c = 0; + } else { + c = protocolVersion.charAt(0); + } + + switch (c) { + case 'T': + return "TLS"; + case 'S': + return "SSL"; + default: + return "UNKNOWN"; + } + } + + @Override + public final void setUseClientMode(boolean clientMode) { + if (clientMode != this.clientMode) { + throw new UnsupportedOperationException(); + } + } + + @Override + public final boolean getUseClientMode() { + return clientMode; + } + + @Override + public final void setNeedClientAuth(boolean b) { + setClientAuth(b ? ClientAuth.REQUIRE : ClientAuth.NONE); + } + + @Override + public final boolean getNeedClientAuth() { + return clientAuth == ClientAuth.REQUIRE; + } + + @Override + public final void setWantClientAuth(boolean b) { + setClientAuth(b ? ClientAuth.OPTIONAL : ClientAuth.NONE); + } + + @Override + public final boolean getWantClientAuth() { + return clientAuth == ClientAuth.OPTIONAL; + } + + /** + * See SSL_set_verify and + * {@link SSL#setVerify(long, int, int)}. 
     */
    @UnstableApi
    public final synchronized void setVerify(int verifyMode, int depth) {
        if (!isDestroyed()) {
            SSL.setVerify(ssl, verifyMode, depth);
        }
    }

    // Server-side only: translates the ClientAuth enum into the native verify mode.
    private void setClientAuth(ClientAuth mode) {
        if (clientMode) {
            // Client auth is a server-side concept; ignored when running in client mode.
            return;
        }
        synchronized (this) {
            if (clientAuth == mode) {
                // No need to issue any JNI calls if the mode is the same
                return;
            }
            if (!isDestroyed()) {
                switch (mode) {
                    case NONE:
                        SSL.setVerify(ssl, SSL.SSL_CVERIFY_NONE, ReferenceCountedOpenSslContext.VERIFY_DEPTH);
                        break;
                    case REQUIRE:
                        SSL.setVerify(ssl, SSL.SSL_CVERIFY_REQUIRED, ReferenceCountedOpenSslContext.VERIFY_DEPTH);
                        break;
                    case OPTIONAL:
                        SSL.setVerify(ssl, SSL.SSL_CVERIFY_OPTIONAL, ReferenceCountedOpenSslContext.VERIFY_DEPTH);
                        break;
                    default:
                        throw new Error(mode.toString());
                }
            }
            clientAuth = mode;
        }
    }

    @Override
    public final void setEnableSessionCreation(boolean b) {
        if (b) {
            throw new UnsupportedOperationException();
        }
    }

    @Override
    public final boolean getEnableSessionCreation() {
        return false;
    }

    @SuppressJava6Requirement(reason = "Usage guarded by java version check")
    @Override
    public final synchronized SSLParameters getSSLParameters() {
        SSLParameters sslParameters = super.getSSLParameters();

        int version = PlatformDependent.javaVersion();
        if (version >= 7) {
            sslParameters.setEndpointIdentificationAlgorithm(endPointIdentificationAlgorithm);
            Java7SslParametersUtils.setAlgorithmConstraints(sslParameters, algorithmConstraints);
            if (version >= 8) {
                if (sniHostNames != null) {
                    Java8SslUtils.setSniHostNames(sslParameters, sniHostNames);
                }
                if (!isDestroyed()) {
                    Java8SslUtils.setUseCipherSuitesOrder(
                            sslParameters, (SSL.getOptions(ssl) & SSL.SSL_OP_CIPHER_SERVER_PREFERENCE) != 0);
                }

                Java8SslUtils.setSNIMatchers(sslParameters, matchers);
            }
        }
        return sslParameters;
    }

    @SuppressJava6Requirement(reason = "Usage guarded by java version check")
    @Override
    public final synchronized void setSSLParameters(SSLParameters sslParameters) {
        int version = PlatformDependent.javaVersion();
        if (version >= 7) {
            if (sslParameters.getAlgorithmConstraints() != null) {
                throw new IllegalArgumentException("AlgorithmConstraints are not supported.");
            }

            boolean isDestroyed = isDestroyed();
            if (version >= 8) {
                if (!isDestroyed) {
                    if (clientMode) {
                        // NOTE(review): generic type argument appears stripped by extraction here —
                        // presumably List<String>; verify against Java8SslUtils.getSniHostNames.
                        final List sniHostNames = Java8SslUtils.getSniHostNames(sslParameters);
                        for (String name: sniHostNames) {
                            SSL.setTlsExtHostName(ssl, name);
                        }
                        this.sniHostNames = sniHostNames;
                    }
                    if (Java8SslUtils.getUseCipherSuitesOrder(sslParameters)) {
                        SSL.setOptions(ssl, SSL.SSL_OP_CIPHER_SERVER_PREFERENCE);
                    } else {
                        SSL.clearOptions(ssl, SSL.SSL_OP_CIPHER_SERVER_PREFERENCE);
                    }
                }
                matchers = sslParameters.getSNIMatchers();
            }

            final String endPointIdentificationAlgorithm = sslParameters.getEndpointIdentificationAlgorithm();
            if (!isDestroyed) {
                // If the user asks for hostname verification we must ensure we verify the peer.
                // If the user disables hostname verification we leave it up to the user to change the mode manually.
                if (clientMode && isEndPointVerificationEnabled(endPointIdentificationAlgorithm)) {
                    SSL.setVerify(ssl, SSL.SSL_CVERIFY_REQUIRED, -1);
                }
            }
            this.endPointIdentificationAlgorithm = endPointIdentificationAlgorithm;
            algorithmConstraints = sslParameters.getAlgorithmConstraints();
        }
        super.setSSLParameters(sslParameters);
    }

    private static boolean isEndPointVerificationEnabled(String endPointIdentificationAlgorithm) {
        return endPointIdentificationAlgorithm != null && !endPointIdentificationAlgorithm.isEmpty();
    }

    private boolean isDestroyed() {
        return destroyed;
    }

    final boolean checkSniHostnameMatch(byte[] hostname) {
        return Java8SslUtils.checkSniHostnameMatch(matchers, hostname);
    }

    @Override
    public String getNegotiatedApplicationProtocol() {
        return applicationProtocol;
    }

    // Returns the native memory address of a direct buffer, via Unsafe when available.
    private static long bufferAddress(ByteBuffer b) {
        assert b.isDirect();
        if (PlatformDependent.hasUnsafe()) {
            return PlatformDependent.directBufferAddress(b);
        }
        return Buffer.address(b);
    }

    /**
     * Select the application protocol used.
     */
    private void selectApplicationProtocol() throws SSLException {
        ApplicationProtocolConfig.SelectedListenerFailureBehavior behavior = apn.selectedListenerFailureBehavior();
        // NOTE(review): type argument stripped by extraction — presumably List<String>.
        List protocols = apn.protocols();
        String applicationProtocol;
        switch (apn.protocol()) {
            case NONE:
                break;
            // We always need to check for applicationProtocol == null as the remote peer may not support
            // the TLS extension or may have returned an empty selection.
            case ALPN:
                applicationProtocol = SSL.getAlpnSelected(ssl);
                if (applicationProtocol != null) {
                    ReferenceCountedOpenSslEngine.this.applicationProtocol = selectApplicationProtocol(
                            protocols, behavior, applicationProtocol);
                }
                break;
            case NPN:
                applicationProtocol = SSL.getNextProtoNegotiated(ssl);
                if (applicationProtocol != null) {
                    ReferenceCountedOpenSslEngine.this.applicationProtocol = selectApplicationProtocol(
                            protocols, behavior, applicationProtocol);
                }
                break;
            case NPN_AND_ALPN:
                // Prefer ALPN; fall back to NPN only when ALPN selected nothing.
                applicationProtocol = SSL.getAlpnSelected(ssl);
                if (applicationProtocol == null) {
                    applicationProtocol = SSL.getNextProtoNegotiated(ssl);
                }
                if (applicationProtocol != null) {
                    ReferenceCountedOpenSslEngine.this.applicationProtocol = selectApplicationProtocol(
                            protocols, behavior, applicationProtocol);
                }
                break;
            default:
                throw new Error();
        }
    }

    /**
     * Applies the configured failure behavior when the peer selected a protocol that is not in our list:
     * ACCEPT takes it as-is, CHOOSE_MY_LAST_PROTOCOL falls back to our last entry, otherwise fail the handshake.
     */
    private String selectApplicationProtocol(List protocols,
                                             ApplicationProtocolConfig.SelectedListenerFailureBehavior behavior,
                                             String applicationProtocol) throws SSLException {
        if (behavior == ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT) {
            return applicationProtocol;
        } else {
            int size = protocols.size();
            assert size > 0;
            if (protocols.contains(applicationProtocol)) {
                return applicationProtocol;
            } else {
                if (behavior == ApplicationProtocolConfig.SelectedListenerFailureBehavior.CHOOSE_MY_LAST_PROTOCOL) {
                    return protocols.get(size - 1);
                } else {
                    throw new SSLException("unknown protocol " + applicationProtocol);
                }
            }
        }
    }

    final void setSessionId(OpenSslSessionId id) {
        session.setSessionId(id);
    }

    // Sentinel array: marks that javax.security.cert certificates cannot be created on this JDK.
    // NOTE(review): element type looks affected by extraction-stripped imports — confirm which
    // X509Certificate this resolves to in the original file.
    private static final Certificate[] JAVAX_CERTS_NOT_SUPPORTED = new X509Certificate[0];

    private final class DefaultOpenSslSession implements OpenSslSession {
        private final OpenSslSessionContext sessionContext;

        // These are guarded by synchronized(OpenSslEngine.this) as handshakeFinished() may be triggered by any
        // thread.
        private Certificate[] x509PeerCerts;
        private Certificate[] peerCerts;

        private boolean valid = true;
        private String protocol;
        private String cipher;
        private OpenSslSessionId id = OpenSslSessionId.NULL_ID;
        private volatile long creationTime;
        private volatile int applicationBufferSize = MAX_PLAINTEXT_LENGTH;
        private volatile Certificate[] localCertificateChain;
        // lazy init for memory reasons
        // NOTE(review): type arguments stripped by extraction — presumably Map<String, Object>.
        private Map values;

        DefaultOpenSslSession(OpenSslSessionContext sessionContext) {
            this.sessionContext = sessionContext;
        }

        private SSLSessionBindingEvent newSSLSessionBindingEvent(String name) {
            return new SSLSessionBindingEvent(session, name);
        }

        @Override
        public void setSessionId(OpenSslSessionId sessionId) {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                // Only the first assignment wins; the id is immutable once set.
                if (this.id == OpenSslSessionId.NULL_ID) {
                    this.id = sessionId;
                    creationTime = System.currentTimeMillis();
                }
            }
        }

        @Override
        public OpenSslSessionId sessionId() {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                if (this.id == OpenSslSessionId.NULL_ID && !isDestroyed()) {
                    // Lazily pull the id from the native session the first time it is requested.
                    byte[] sessionId = SSL.getSessionId(ssl);
                    if (sessionId != null) {
                        id = new OpenSslSessionId(sessionId);
                    }
                }

                return id;
            }
        }

        @Override
        public void setLocalCertificate(Certificate[] localCertificate) {
            this.localCertificateChain = localCertificate;
        }

        @Override
        public byte[] getId() {
            return sessionId().cloneBytes();
        }

        @Override
        public OpenSslSessionContext getSessionContext() {
            return sessionContext;
        }

        @Override
        public long getCreationTime() {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                return creationTime;
            }
        }

        @Override
        public long getLastAccessedTime() {
            long lastAccessed = ReferenceCountedOpenSslEngine.this.lastAccessed;
            // if lastAccessed is -1 we will just return the creation time as the handshake was not started yet.
            return lastAccessed == -1 ? getCreationTime() : lastAccessed;
        }

        @Override
        public void invalidate() {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                valid = false;
                sessionContext.removeFromCache(id);
            }
        }

        @Override
        public boolean isValid() {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                return valid || sessionContext.isInCache(id);
            }
        }

        @Override
        public void putValue(String name, Object value) {
            checkNotNull(name, "name");
            checkNotNull(value, "value");

            final Object old;
            synchronized (this) {
                Map values = this.values;
                if (values == null) {
                    // Use size of 2 to keep the memory overhead small
                    values = this.values = new HashMap(2);
                }
                old = values.put(name, value);
            }

            // Listener callbacks are invoked outside the lock on purpose — never call alien code while locked.
            if (value instanceof SSLSessionBindingListener) {
                // Use newSSLSessionBindingEvent so we always use the wrapper if needed.
                ((SSLSessionBindingListener) value).valueBound(newSSLSessionBindingEvent(name));
            }
            notifyUnbound(old, name);
        }

        @Override
        public Object getValue(String name) {
            checkNotNull(name, "name");
            synchronized (this) {
                if (values == null) {
                    return null;
                }
                return values.get(name);
            }
        }

        @Override
        public void removeValue(String name) {
            checkNotNull(name, "name");

            final Object old;
            synchronized (this) {
                Map values = this.values;
                if (values == null) {
                    return;
                }
                old = values.remove(name);
            }

            notifyUnbound(old, name);
        }

        @Override
        public String[] getValueNames() {
            synchronized (this) {
                Map values = this.values;
                if (values == null || values.isEmpty()) {
                    return EMPTY_STRINGS;
                }
                return values.keySet().toArray(EMPTY_STRINGS);
            }
        }

        private void notifyUnbound(Object value, String name) {
            if (value instanceof SSLSessionBindingListener) {
                // Use newSSLSessionBindingEvent so we always use the wrapper if needed.
                ((SSLSessionBindingListener) value).valueUnbound(newSSLSessionBindingEvent(name));
            }
        }

        /**
         * Finish the handshake and so init everything in the {@link OpenSslSession} that should be accessible by
         * the user.
         */
        @Override
        public void handshakeFinished(byte[] id, String cipher, String protocol, byte[] peerCertificate,
                                      byte[][] peerCertificateChain, long creationTime, long timeout)
                throws SSLException {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                if (!isDestroyed()) {
                    this.creationTime = creationTime;
                    if (this.id == OpenSslSessionId.NULL_ID) {
                        this.id = id == null ? OpenSslSessionId.NULL_ID : new OpenSslSessionId(id);
                    }
                    this.cipher = toJavaCipherSuite(cipher);
                    this.protocol = protocol;

                    if (clientMode) {
                        if (isEmpty(peerCertificateChain)) {
                            peerCerts = EmptyArrays.EMPTY_CERTIFICATES;
                            if (OpenSsl.JAVAX_CERTIFICATE_CREATION_SUPPORTED) {
                                x509PeerCerts = EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES;
                            } else {
                                x509PeerCerts = JAVAX_CERTS_NOT_SUPPORTED;
                            }
                        } else {
                            peerCerts = new Certificate[peerCertificateChain.length];
                            if (OpenSsl.JAVAX_CERTIFICATE_CREATION_SUPPORTED) {
                                x509PeerCerts = new X509Certificate[peerCertificateChain.length];
                            } else {
                                x509PeerCerts = JAVAX_CERTS_NOT_SUPPORTED;
                            }
                            initCerts(peerCertificateChain, 0);
                        }
                    } else {
                        // if used on the server side SSL_get_peer_cert_chain(...) will not include the remote peer
                        // certificate. We use SSL_get_peer_certificate to get it in this case and add it to our
                        // array later.
                        //
                        // See https://www.openssl.org/docs/ssl/SSL_get_peer_cert_chain.html
                        if (isEmpty(peerCertificate)) {
                            peerCerts = EmptyArrays.EMPTY_CERTIFICATES;
                            x509PeerCerts = EmptyArrays.EMPTY_JAVAX_X509_CERTIFICATES;
                        } else {
                            if (isEmpty(peerCertificateChain)) {
                                peerCerts = new Certificate[] {new LazyX509Certificate(peerCertificate)};
                                if (OpenSsl.JAVAX_CERTIFICATE_CREATION_SUPPORTED) {
                                    x509PeerCerts = new X509Certificate[] {
                                            new LazyJavaxX509Certificate(peerCertificate)
                                    };
                                } else {
                                    x509PeerCerts = JAVAX_CERTS_NOT_SUPPORTED;
                                }
                            } else {
                                // Leave slot 0 for the peer's own certificate; the chain fills in from index 1.
                                peerCerts = new Certificate[peerCertificateChain.length + 1];
                                peerCerts[0] = new LazyX509Certificate(peerCertificate);

                                if (OpenSsl.JAVAX_CERTIFICATE_CREATION_SUPPORTED) {
                                    x509PeerCerts = new X509Certificate[peerCertificateChain.length + 1];
                                    x509PeerCerts[0] = new LazyJavaxX509Certificate(peerCertificate);
                                } else {
                                    x509PeerCerts = JAVAX_CERTS_NOT_SUPPORTED;
                                }

                                initCerts(peerCertificateChain, 1);
                            }
                        }
                    }

                    calculateMaxWrapOverhead();

                    handshakeState = HandshakeState.FINISHED;
                } else {
                    throw new SSLException("Already closed");
                }
            }
        }

        // Fills peerCerts (and, when supported, x509PeerCerts) from the raw DER chain starting at startPos.
        private void initCerts(byte[][] chain, int startPos) {
            for (int i = 0; i < chain.length; i++) {
                int certPos = startPos + i;
                peerCerts[certPos] = new LazyX509Certificate(chain[i]);
                if (x509PeerCerts != JAVAX_CERTS_NOT_SUPPORTED) {
                    x509PeerCerts[certPos] = new LazyJavaxX509Certificate(chain[i]);
                }
            }
        }

        @Override
        public Certificate[] getPeerCertificates() throws SSLPeerUnverifiedException {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                if (isEmpty(peerCerts)) {
                    throw new SSLPeerUnverifiedException("peer not verified");
                }
                return peerCerts.clone();
            }
        }

        @Override
        public Certificate[] getLocalCertificates() {
            Certificate[] localCerts = this.localCertificateChain;
            if (localCerts == null) {
                return null;
            }
            return localCerts.clone();
        }

        @Override
        public javax.security.cert.X509Certificate[] getPeerCertificateChain() throws SSLPeerUnverifiedException {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                if (x509PeerCerts == JAVAX_CERTS_NOT_SUPPORTED) {
                    // Not supported by the underlying JDK, so just throw. This is fine in terms of the API
                    // contract. See SSLSession.html#getPeerCertificateChain().
                    throw new UnsupportedOperationException();
                }
                if (isEmpty(x509PeerCerts)) {
                    throw new SSLPeerUnverifiedException("peer not verified");
                }
                // Re-encode each certificate into the deprecated javax.security.cert type on every call.
                javax.security.cert.X509Certificate[] certificates =
                        new javax.security.cert.X509Certificate[x509PeerCerts.length];
                int i = 0;
                for (Certificate certificate : x509PeerCerts) {
                    try {
                        certificates[i++] =
                                javax.security.cert.X509Certificate.getInstance(certificate.getEncoded());
                    } catch (javax.security.cert.CertificateException | CertificateEncodingException e) {
                        throw new RuntimeException(e);
                    }
                }
                return certificates;
            }
        }

        @Override
        public Principal getPeerPrincipal() throws SSLPeerUnverifiedException {
            Certificate[] peer = getPeerCertificates();
            // No need for null or length > 0 is needed as this is done in getPeerCertificates()
            // already.
            return ((java.security.cert.X509Certificate) peer[0]).getSubjectX500Principal();
        }

        @Override
        public Principal getLocalPrincipal() {
            Certificate[] local = this.localCertificateChain;
            if (local == null || local.length == 0) {
                return null;
            }
            return ((java.security.cert.X509Certificate) local[0]).getSubjectX500Principal();
        }

        @Override
        public String getCipherSuite() {
            synchronized (ReferenceCountedOpenSslEngine.this) {
                if (cipher == null) {
                    return SslUtils.INVALID_CIPHER;
                }
                return cipher;
            }
        }

        @Override
        public String getProtocol() {
            String protocol = this.protocol;
            if (protocol == null) {
                synchronized (ReferenceCountedOpenSslEngine.this) {
                    if (!isDestroyed()) {
                        protocol = SSL.getVersion(ssl);
                    } else {
                        protocol = StringUtil.EMPTY_STRING;
                    }
                }
            }
            return protocol;
        }

        @Override
        public String getPeerHost() {
            return ReferenceCountedOpenSslEngine.this.getPeerHost();
        }

        @Override
        public int getPeerPort() {
            return ReferenceCountedOpenSslEngine.this.getPeerPort();
        }

        @Override
        public int getPacketBufferSize() {
            return SSL.SSL_MAX_ENCRYPTED_LENGTH;
        }

        @Override
        public int getApplicationBufferSize() {
            return applicationBufferSize;
        }

        @Override
        public void tryExpandApplicationBufferSize(int packetLengthDataOnly) {
            if (packetLengthDataOnly > MAX_PLAINTEXT_LENGTH && applicationBufferSize != MAX_RECORD_SIZE) {
                applicationBufferSize = MAX_RECORD_SIZE;
            }
        }

        @Override
        public String toString() {
            return "DefaultOpenSslSession{" +
                    "sessionContext=" + sessionContext +
                    ", id=" + id +
                    '}';
        }
    }

    // Common accessor for the native error code carried by our SSLException subclasses.
    private interface NativeSslException {
        int errorCode();
    }

    private static final class OpenSslException extends SSLException implements NativeSslException {
        private final int errorCode;

        OpenSslException(String reason, int errorCode) {
            super(reason);
            this.errorCode = errorCode;
        }

        @Override
        public int errorCode() {
            return errorCode;
        }
    }

private static final class OpenSslHandshakeException extends SSLHandshakeException implements NativeSslException { + private final int errorCode; + + OpenSslHandshakeException(String reason, int errorCode) { + super(reason); + this.errorCode = errorCode; + } + + @Override + public int errorCode() { + return errorCode; + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java new file mode 100644 index 0000000..07e0a5d --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ReferenceCountedOpenSslServerContext.java @@ -0,0 +1,298 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.internal.tcnative.CertificateCallback; +import io.netty.internal.tcnative.SSL; +import io.netty.internal.tcnative.SSLContext; +import io.netty.internal.tcnative.SniHostNameMatcher; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import java.util.Map; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A server-side {@link SslContext} which uses OpenSSL's SSL/TLS implementation. + *

 * <p>Instances of this class must be {@link #release() released} or else native memory will leak!
 *

Instances of this class must not be released before any {@link ReferenceCountedOpenSslEngine} + * which depends upon the instance of this class is released. Otherwise if any method of + * {@link ReferenceCountedOpenSslEngine} is called which uses this class's JNI resources the JVM may crash. + */ +public final class ReferenceCountedOpenSslServerContext extends ReferenceCountedOpenSslContext { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(ReferenceCountedOpenSslServerContext.class); + private static final byte[] ID = {'n', 'e', 't', 't', 'y'}; + private final OpenSslServerSessionContext sessionContext; + + ReferenceCountedOpenSslServerContext( + X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls, + boolean enableOcsp, String keyStore, Map.Entry, Object>... options) + throws SSLException { + this(trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, ciphers, + cipherFilter, toNegotiator(apn), sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls, + enableOcsp, keyStore, options); + } + + ReferenceCountedOpenSslServerContext( + X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, OpenSslApplicationProtocolNegotiator apn, + long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls, + boolean enableOcsp, String keyStore, Map.Entry, Object>... 
options) + throws SSLException { + super(ciphers, cipherFilter, apn, SSL.SSL_MODE_SERVER, keyCertChain, + clientAuth, protocols, startTls, enableOcsp, true, options); + // Create a new SSL_CTX and configure it. + boolean success = false; + try { + sessionContext = newSessionContext(this, ctx, engineMap, trustCertCollection, trustManagerFactory, + keyCertChain, key, keyPassword, keyManagerFactory, keyStore, + sessionCacheSize, sessionTimeout); + if (SERVER_ENABLE_SESSION_TICKET) { + sessionContext.setTicketKeys(); + } + success = true; + } finally { + if (!success) { + release(); + } + } + } + + @Override + public OpenSslServerSessionContext sessionContext() { + return sessionContext; + } + + static OpenSslServerSessionContext newSessionContext(ReferenceCountedOpenSslContext thiz, long ctx, + OpenSslEngineMap engineMap, + X509Certificate[] trustCertCollection, + TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, + String keyPassword, KeyManagerFactory keyManagerFactory, + String keyStore, long sessionCacheSize, long sessionTimeout) + throws SSLException { + OpenSslKeyMaterialProvider keyMaterialProvider = null; + try { + try { + SSLContext.setVerify(ctx, SSL.SSL_CVERIFY_NONE, VERIFY_DEPTH); + if (!OpenSsl.useKeyManagerFactory()) { + if (keyManagerFactory != null) { + throw new IllegalArgumentException( + "KeyManagerFactory not supported"); + } + checkNotNull(keyCertChain, "keyCertChain"); + + setKeyMaterial(ctx, keyCertChain, key, keyPassword); + } else { + // javadocs state that keyManagerFactory has precedent over keyCertChain, and we must have a + // keyManagerFactory for the server so build one if it is not specified. 
+ if (keyManagerFactory == null) { + char[] keyPasswordChars = keyStorePassword(keyPassword); + KeyStore ks = buildKeyStore(keyCertChain, key, keyPasswordChars, keyStore); + if (ks.aliases().hasMoreElements()) { + keyManagerFactory = new OpenSslX509KeyManagerFactory(); + } else { + keyManagerFactory = new OpenSslCachingX509KeyManagerFactory( + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm())); + } + keyManagerFactory.init(ks, keyPasswordChars); + } + keyMaterialProvider = providerFor(keyManagerFactory, keyPassword); + + SSLContext.setCertificateCallback(ctx, new OpenSslServerCertificateCallback( + engineMap, new OpenSslKeyMaterialManager(keyMaterialProvider))); + } + } catch (Exception e) { + throw new SSLException("failed to set certificate and key", e); + } + try { + if (trustCertCollection != null) { + trustManagerFactory = buildTrustManagerFactory(trustCertCollection, trustManagerFactory, keyStore); + } else if (trustManagerFactory == null) { + // Mimic the way SSLContext.getInstance(KeyManager[], null, null) works + trustManagerFactory = TrustManagerFactory.getInstance( + TrustManagerFactory.getDefaultAlgorithm()); + trustManagerFactory.init((KeyStore) null); + } + + final X509TrustManager manager = chooseTrustManager(trustManagerFactory.getTrustManagers()); + + // IMPORTANT: The callbacks set for verification must be static to prevent memory leak as + // otherwise the context can never be collected. This is because the JNI code holds + // a global reference to the callbacks. 
+ // + // See https://github.com/netty/netty/issues/5372 + + setVerifyCallback(ctx, engineMap, manager); + + X509Certificate[] issuers = manager.getAcceptedIssuers(); + if (issuers != null && issuers.length > 0) { + long bio = 0; + try { + bio = toBIO(ByteBufAllocator.DEFAULT, issuers); + if (!SSLContext.setCACertificateBio(ctx, bio)) { + throw new SSLException("unable to setup accepted issuers for trustmanager " + manager); + } + } finally { + freeBio(bio); + } + } + + if (PlatformDependent.javaVersion() >= 8) { + // Only do on Java8+ as SNIMatcher is not supported in earlier releases. + // IMPORTANT: The callbacks set for hostname matching must be static to prevent memory leak as + // otherwise the context can never be collected. This is because the JNI code holds + // a global reference to the matcher. + SSLContext.setSniHostnameMatcher(ctx, new OpenSslSniHostnameMatcher(engineMap)); + } + } catch (SSLException e) { + throw e; + } catch (Exception e) { + throw new SSLException("unable to setup trustmanager", e); + } + + OpenSslServerSessionContext sessionContext = new OpenSslServerSessionContext(thiz, keyMaterialProvider); + sessionContext.setSessionIdContext(ID); + // Enable session caching by default + sessionContext.setSessionCacheEnabled(SERVER_ENABLE_SESSION_CACHE); + if (sessionCacheSize > 0) { + sessionContext.setSessionCacheSize((int) Math.min(sessionCacheSize, Integer.MAX_VALUE)); + } + if (sessionTimeout > 0) { + sessionContext.setSessionTimeout((int) Math.min(sessionTimeout, Integer.MAX_VALUE)); + } + + keyMaterialProvider = null; + + return sessionContext; + } finally { + if (keyMaterialProvider != null) { + keyMaterialProvider.destroy(); + } + } + } + + @SuppressJava6Requirement(reason = "Guarded by java version check") + private static void setVerifyCallback(long ctx, OpenSslEngineMap engineMap, X509TrustManager manager) { + // Use this to prevent an error when running on java < 7 + if (useExtendedTrustManager(manager)) { + 
SSLContext.setCertVerifyCallback(ctx, new ExtendedTrustManagerVerifyCallback( + engineMap, (X509ExtendedTrustManager) manager)); + } else { + SSLContext.setCertVerifyCallback(ctx, new TrustManagerVerifyCallback(engineMap, manager)); + } + } + + private static final class OpenSslServerCertificateCallback implements CertificateCallback { + private final OpenSslEngineMap engineMap; + private final OpenSslKeyMaterialManager keyManagerHolder; + + OpenSslServerCertificateCallback(OpenSslEngineMap engineMap, OpenSslKeyMaterialManager keyManagerHolder) { + this.engineMap = engineMap; + this.keyManagerHolder = keyManagerHolder; + } + + @Override + public void handle(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals) throws Exception { + final ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); + if (engine == null) { + // Maybe null if destroyed in the meantime. + return; + } + try { + // For now we just ignore the asn1DerEncodedPrincipals as this is kind of inline with what the + // OpenJDK SSLEngineImpl does. 
+ keyManagerHolder.setKeyMaterialServerSide(engine); + } catch (Throwable cause) { + engine.initHandshakeException(cause); + + if (cause instanceof Exception) { + throw (Exception) cause; + } + throw new SSLException(cause); + } + } + } + + private static final class TrustManagerVerifyCallback extends AbstractCertificateVerifier { + private final X509TrustManager manager; + + TrustManagerVerifyCallback(OpenSslEngineMap engineMap, X509TrustManager manager) { + super(engineMap); + this.manager = manager; + } + + @Override + void verify(ReferenceCountedOpenSslEngine engine, X509Certificate[] peerCerts, String auth) + throws Exception { + manager.checkClientTrusted(peerCerts, auth); + } + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private static final class ExtendedTrustManagerVerifyCallback extends AbstractCertificateVerifier { + private final X509ExtendedTrustManager manager; + + ExtendedTrustManagerVerifyCallback(OpenSslEngineMap engineMap, X509ExtendedTrustManager manager) { + super(engineMap); + this.manager = manager; + } + + @Override + void verify(ReferenceCountedOpenSslEngine engine, X509Certificate[] peerCerts, String auth) + throws Exception { + manager.checkClientTrusted(peerCerts, auth, engine); + } + } + + private static final class OpenSslSniHostnameMatcher implements SniHostNameMatcher { + private final OpenSslEngineMap engineMap; + + OpenSslSniHostnameMatcher(OpenSslEngineMap engineMap) { + this.engineMap = engineMap; + } + + @Override + public boolean match(long ssl, String hostname) { + ReferenceCountedOpenSslEngine engine = engineMap.get(ssl); + if (engine != null) { + // TODO: In the next release of tcnative we should pass the byte[] directly in and not use a String. 
/**
 * Converts OpenSSL signature algorithm names to Java signature algorithm names.
 */
final class SignatureAlgorithmConverter {

    private SignatureAlgorithmConverter() { }

    // OpenSSL has 3 different formats it uses at the moment we will match against all of these.
    // For example:
    //     ecdsa-with-SHA384
    //     hmacWithSHA384
    //     dsa_with_SHA224
    //
    // For more details see https://github.com/openssl/openssl/blob/OpenSSL_1_0_2p/crypto/objects/obj_dat.h
    //
    // BoringSSL uses a different format:
    // https://github.com/google/boringssl/blob/8525ff3/ssl/ssl_privkey.cc#L436
    //
    private static final Pattern PATTERN = Pattern.compile(
            // group 1 - 2, e.g. sha256WithRSAEncryption
            "(?:(^[a-zA-Z].+)With(.+)Encryption$)|" +
            // group 3 - 4, e.g. ecdsa-with-SHA384, dsa_with_SHA224, rsa_pkcs1_sha256
            "(?:(^[a-zA-Z].+)(?:_with_|-with-|_pkcs1_|_pss_rsae_)(.+$))|" +
            // group 5 - 6, e.g. ecdsa_sha256 (BoringSSL style)
            "(?:(^[a-zA-Z].+)_(.+$))");

    /**
     * Converts an OpenSSL algorithm name to a Java algorithm name and returns it,
     * or returns {@code null} if the conversion failed because the format is not known.
     *
     * @param opensslName the OpenSSL signature algorithm name, may be {@code null}
     * @return the Java signature algorithm name (e.g. {@code SHA256withRSA}),
     *         or {@code null} if {@code opensslName} is {@code null} or not convertible
     */
    static String toJavaName(String opensslName) {
        if (opensslName == null) {
            return null;
        }
        Matcher matcher = PATTERN.matcher(opensslName);
        if (matcher.matches()) {
            String group1 = matcher.group(1);
            if (group1 != null) {
                // sha256WithRSAEncryption -> SHA256withRSA
                return group1.toUpperCase(Locale.ROOT) + "with" + matcher.group(2).toUpperCase(Locale.ROOT);
            }
            if (matcher.group(3) != null) {
                // ecdsa-with-SHA384 -> SHA384withECDSA (digest comes first in the Java name)
                return matcher.group(4).toUpperCase(Locale.ROOT) + "with" + matcher.group(3).toUpperCase(Locale.ROOT);
            }
            if (matcher.group(5) != null) {
                // ecdsa_sha256 -> SHA256withECDSA
                return matcher.group(6).toUpperCase(Locale.ROOT) + "with" + matcher.group(5).toUpperCase(Locale.ROOT);
            }
        }
        return null;
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.internal.UnstableApi; + +/** + * Event that is fired once we did a selection of a {@link SslContext} based on the {@code SNI hostname}, + * which may be because it was successful or there was an error. + */ +@UnstableApi +public final class SniCompletionEvent extends SslCompletionEvent { + private final String hostname; + + public SniCompletionEvent(String hostname) { + this.hostname = hostname; + } + + public SniCompletionEvent(String hostname, Throwable cause) { + super(cause); + this.hostname = hostname; + } + + public SniCompletionEvent(Throwable cause) { + this(null, cause); + } + + /** + * Returns the SNI hostname send by the client if we were able to parse it, {@code null} otherwise. + */ + public String hostname() { + return hostname; + } + + @Override + public String toString() { + final Throwable cause = cause(); + return cause == null ? getClass().getSimpleName() + "(SUCCESS='" + hostname + "'\")": + getClass().getSimpleName() + '(' + cause + ')'; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SniHandler.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SniHandler.java new file mode 100644 index 0000000..0f1d069 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SniHandler.java @@ -0,0 +1,234 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.DecoderException; +import io.netty.util.AsyncMapping; +import io.netty.util.DomainNameMapping; +import io.netty.util.Mapping; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; + +/** + *

Enables SNI + * (Server Name Indication) extension for server side SSL. For clients + * support SNI, the server could have multiple host name bound on a single IP. + * The client will send host name in the handshake data so server could decide + * which certificate to choose for the host name.

+ */ +public class SniHandler extends AbstractSniHandler { + private static final Selection EMPTY_SELECTION = new Selection(null, null); + + protected final AsyncMapping mapping; + + private volatile Selection selection = EMPTY_SELECTION; + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link Mapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + */ + public SniHandler(Mapping mapping) { + this(new AsyncMappingAdapter(mapping)); + } + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link Mapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + * @param maxClientHelloLength the maximum length of the client hello message + * @param handshakeTimeoutMillis the handshake timeout in milliseconds + */ + public SniHandler(Mapping mapping, + int maxClientHelloLength, long handshakeTimeoutMillis) { + this(new AsyncMappingAdapter(mapping), maxClientHelloLength, handshakeTimeoutMillis); + } + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link DomainNameMapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + */ + public SniHandler(DomainNameMapping mapping) { + this((Mapping) mapping); + } + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link AsyncMapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + */ + @SuppressWarnings("unchecked") + public SniHandler(AsyncMapping mapping) { + this(mapping, 0, 0L); + } + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link AsyncMapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + * @param maxClientHelloLength the maximum length of the client hello message + * @param handshakeTimeoutMillis the handshake timeout in milliseconds + */ + 
@SuppressWarnings("unchecked") + public SniHandler(AsyncMapping mapping, + int maxClientHelloLength, long handshakeTimeoutMillis) { + super(maxClientHelloLength, handshakeTimeoutMillis); + this.mapping = (AsyncMapping) ObjectUtil.checkNotNull(mapping, "mapping"); + } + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link Mapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + * @param handshakeTimeoutMillis the handshake timeout in milliseconds + */ + public SniHandler(Mapping mapping, long handshakeTimeoutMillis) { + this(new AsyncMappingAdapter(mapping), handshakeTimeoutMillis); + } + + /** + * Creates a SNI detection handler with configured {@link SslContext} + * maintained by {@link AsyncMapping} + * + * @param mapping the mapping of domain name to {@link SslContext} + * @param handshakeTimeoutMillis the handshake timeout in milliseconds + */ + public SniHandler(AsyncMapping mapping, long handshakeTimeoutMillis) { + this(mapping, 0, handshakeTimeoutMillis); + } + + /** + * @return the selected hostname + */ + public String hostname() { + return selection.hostname; + } + + /** + * @return the selected {@link SslContext} + */ + public SslContext sslContext() { + return selection.context; + } + + /** + * The default implementation will simply call {@link AsyncMapping#map(Object, Promise)} but + * users can override this method to implement custom behavior. 
+ * + * @see AsyncMapping#map(Object, Promise) + */ + @Override + protected Future lookup(ChannelHandlerContext ctx, String hostname) throws Exception { + return mapping.map(hostname, ctx.executor().newPromise()); + } + + @Override + protected final void onLookupComplete(ChannelHandlerContext ctx, + String hostname, Future future) throws Exception { + if (!future.isSuccess()) { + final Throwable cause = future.cause(); + if (cause instanceof Error) { + throw (Error) cause; + } + throw new DecoderException("failed to get the SslContext for " + hostname, cause); + } + + SslContext sslContext = future.getNow(); + selection = new Selection(sslContext, hostname); + try { + replaceHandler(ctx, hostname, sslContext); + } catch (Throwable cause) { + selection = EMPTY_SELECTION; + PlatformDependent.throwException(cause); + } + } + + /** + * The default implementation of this method will simply replace {@code this} {@link SniHandler} + * instance with a {@link SslHandler}. Users may override this method to implement custom behavior. + * + * Please be aware that this method may get called after a client has already disconnected and + * custom implementations must take it into consideration when overriding this method. + * + * It's also possible for the hostname argument to be {@code null}. + */ + protected void replaceHandler(ChannelHandlerContext ctx, String hostname, SslContext sslContext) throws Exception { + SslHandler sslHandler = null; + try { + sslHandler = newSslHandler(sslContext, ctx.alloc()); + ctx.pipeline().replace(this, SslHandler.class.getName(), sslHandler); + sslHandler = null; + } finally { + // Since the SslHandler was not inserted into the pipeline the ownership of the SSLEngine was not + // transferred to the SslHandler. 
+ // See https://github.com/netty/netty/issues/5678 + if (sslHandler != null) { + ReferenceCountUtil.safeRelease(sslHandler.engine()); + } + } + } + + /** + * Returns a new {@link SslHandler} using the given {@link SslContext} and {@link ByteBufAllocator}. + * Users may override this method to implement custom behavior. + */ + protected SslHandler newSslHandler(SslContext context, ByteBufAllocator allocator) { + SslHandler sslHandler = context.newHandler(allocator); + sslHandler.setHandshakeTimeoutMillis(handshakeTimeoutMillis); + return sslHandler; + } + + private static final class AsyncMappingAdapter implements AsyncMapping { + private final Mapping mapping; + + private AsyncMappingAdapter(Mapping mapping) { + this.mapping = ObjectUtil.checkNotNull(mapping, "mapping"); + } + + @Override + public Future map(String input, Promise promise) { + final SslContext context; + try { + context = mapping.map(input); + } catch (Throwable cause) { + return promise.setFailure(cause); + } + return promise.setSuccess(context); + } + } + + private static final class Selection { + final SslContext context; + final String hostname; + + Selection(SslContext context, String hostname) { + this.context = context; + this.hostname = hostname; + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java new file mode 100644 index 0000000..660f135 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClientHelloHandler.java @@ -0,0 +1,349 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.SocketAddress; +import java.util.List; + +/** + * {@link ByteToMessageDecoder} which allows to be notified once a full {@code ClientHello} was received. + * + * @param the parameter type + */ +public abstract class SslClientHelloHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { + + /** + * The maximum length of client hello message as defined by + * RFC5246. 
+ */ + public static final int MAX_CLIENT_HELLO_LENGTH = 0xFFFFFF; + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(SslClientHelloHandler.class); + + private final int maxClientHelloLength; + private boolean handshakeFailed; + private boolean suppressRead; + private boolean readPending; + private ByteBuf handshakeBuffer; + + public SslClientHelloHandler() { + this(MAX_CLIENT_HELLO_LENGTH); + } + + protected SslClientHelloHandler(int maxClientHelloLength) { + // 16MB is the maximum as per RFC: + // See https://www.rfc-editor.org/rfc/rfc5246#section-6.2.1 + this.maxClientHelloLength = + ObjectUtil.checkInRange(maxClientHelloLength, 0, MAX_CLIENT_HELLO_LENGTH, "maxClientHelloLength"); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws Exception { + if (!suppressRead && !handshakeFailed) { + try { + int readerIndex = in.readerIndex(); + int readableBytes = in.readableBytes(); + int handshakeLength = -1; + + // Check if we have enough data to determine the record type and length. 
+ while (readableBytes >= SslUtils.SSL_RECORD_HEADER_LENGTH) { + final int contentType = in.getUnsignedByte(readerIndex); + switch (contentType) { + case SslUtils.SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC: + // fall-through + case SslUtils.SSL_CONTENT_TYPE_ALERT: + final int len = SslUtils.getEncryptedPacketLength(in, readerIndex); + + // Not an SSL/TLS packet + if (len == SslUtils.NOT_ENCRYPTED) { + handshakeFailed = true; + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(in.readableBytes()); + ctx.fireUserEventTriggered(new SniCompletionEvent(e)); + SslUtils.handleHandshakeFailure(ctx, e, true); + throw e; + } + if (len == SslUtils.NOT_ENOUGH_DATA) { + // Not enough data + return; + } + // No ClientHello + select(ctx, null); + return; + case SslUtils.SSL_CONTENT_TYPE_HANDSHAKE: + final int majorVersion = in.getUnsignedByte(readerIndex + 1); + // SSLv3 or TLS + if (majorVersion == 3) { + int packetLength = in.getUnsignedShort(readerIndex + 3) + + SslUtils.SSL_RECORD_HEADER_LENGTH; + + if (readableBytes < packetLength) { + // client hello incomplete; try again to decode once more data is ready. + return; + } else if (packetLength == SslUtils.SSL_RECORD_HEADER_LENGTH) { + select(ctx, null); + return; + } + + final int endOffset = readerIndex + packetLength; + + // Let's check if we already parsed the handshake length or not. 
+ if (handshakeLength == -1) { + if (readerIndex + 4 > endOffset) { + // Need more data to read HandshakeType and handshakeLength (4 bytes) + return; + } + + final int handshakeType = in.getUnsignedByte(readerIndex + + SslUtils.SSL_RECORD_HEADER_LENGTH); + + // Check if this is a clientHello(1) + // See https://tools.ietf.org/html/rfc5246#section-7.4 + if (handshakeType != 1) { + select(ctx, null); + return; + } + + // Read the length of the handshake as it may arrive in fragments + // See https://tools.ietf.org/html/rfc5246#section-7.4 + handshakeLength = in.getUnsignedMedium(readerIndex + + SslUtils.SSL_RECORD_HEADER_LENGTH + 1); + + if (handshakeLength > maxClientHelloLength && maxClientHelloLength != 0) { + TooLongFrameException e = new TooLongFrameException( + "ClientHello length exceeds " + maxClientHelloLength + + ": " + handshakeLength); + in.skipBytes(in.readableBytes()); + ctx.fireUserEventTriggered(new SniCompletionEvent(e)); + SslUtils.handleHandshakeFailure(ctx, e, true); + throw e; + } + // Consume handshakeType and handshakeLength (this sums up as 4 bytes) + readerIndex += 4; + packetLength -= 4; + + if (handshakeLength + 4 + SslUtils.SSL_RECORD_HEADER_LENGTH <= packetLength) { + // We have everything we need in one packet. + // Skip the record header + readerIndex += SslUtils.SSL_RECORD_HEADER_LENGTH; + select(ctx, in.retainedSlice(readerIndex, handshakeLength)); + return; + } else { + if (handshakeBuffer == null) { + handshakeBuffer = ctx.alloc().buffer(handshakeLength); + } else { + // Clear the buffer so we can aggregate into it again. 
+ handshakeBuffer.clear(); + } + } + } + + // Combine the encapsulated data in one buffer but not include the SSL_RECORD_HEADER + handshakeBuffer.writeBytes(in, readerIndex + SslUtils.SSL_RECORD_HEADER_LENGTH, + packetLength - SslUtils.SSL_RECORD_HEADER_LENGTH); + readerIndex += packetLength; + readableBytes -= packetLength; + if (handshakeLength <= handshakeBuffer.readableBytes()) { + ByteBuf clientHello = handshakeBuffer.setIndex(0, handshakeLength); + handshakeBuffer = null; + + select(ctx, clientHello); + return; + } + break; + } + // fall-through + default: + // not tls, ssl or application data + select(ctx, null); + return; + } + } + } catch (NotSslRecordException e) { + // Just rethrow as in this case we also closed the channel and this is consistent with SslHandler. + throw e; + } catch (TooLongFrameException e) { + // Just rethrow as in this case we also closed the channel + throw e; + } catch (Exception e) { + // unexpected encoding, ignore sni and use default + if (logger.isDebugEnabled()) { + logger.debug("Unexpected client hello packet: " + ByteBufUtil.hexDump(in), e); + } + select(ctx, null); + } + } + } + + private void releaseHandshakeBuffer() { + releaseIfNotNull(handshakeBuffer); + handshakeBuffer = null; + } + + private static void releaseIfNotNull(ByteBuf buffer) { + if (buffer != null) { + buffer.release(); + } + } + + private void select(final ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception { + final Future future; + try { + future = lookup(ctx, clientHello); + if (future.isDone()) { + onLookupComplete(ctx, future); + } else { + suppressRead = true; + final ByteBuf finalClientHello = clientHello; + future.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + releaseIfNotNull(finalClientHello); + try { + suppressRead = false; + try { + onLookupComplete(ctx, future); + } catch (DecoderException err) { + ctx.fireExceptionCaught(err); + } catch (Exception cause) { + 
ctx.fireExceptionCaught(new DecoderException(cause)); + } catch (Throwable cause) { + ctx.fireExceptionCaught(cause); + } + } finally { + if (readPending) { + readPending = false; + ctx.read(); + } + } + } + }); + + // Ownership was transferred to the FutureListener. + clientHello = null; + } + } catch (Throwable cause) { + PlatformDependent.throwException(cause); + } finally { + releaseIfNotNull(clientHello); + } + } + + @Override + protected void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + releaseHandshakeBuffer(); + + super.handlerRemoved0(ctx); + } + + /** + * Kicks off a lookup for the given {@code ClientHello} and returns a {@link Future} which in turn will + * notify the {@link #onLookupComplete(ChannelHandlerContext, Future)} on completion. + * + * See https://tools.ietf.org/html/rfc5246#section-7.4.1.2 + * + *
+     * struct {
+     *    ProtocolVersion client_version;
+     *    Random random;
+     *    SessionID session_id;
+     *    CipherSuite cipher_suites<2..2^16-2>;
+     *    CompressionMethod compression_methods<1..2^8-1>;
+     *    select (extensions_present) {
+     *        case false:
+     *            struct {};
+     *        case true:
+     *            Extension extensions<0..2^16-1>;
+     *    };
+     * } ClientHello;
+     * 
+ * + * @see #onLookupComplete(ChannelHandlerContext, Future) + * + * @param ctx ctx + * @param clientHello clientHello + */ + protected abstract Future lookup(ChannelHandlerContext ctx, ByteBuf clientHello) throws Exception; + + /** + * Called upon completion of the {@link #lookup(ChannelHandlerContext, ByteBuf)} {@link Future}. + * + * @see #lookup(ChannelHandlerContext, ByteBuf) + */ + protected abstract void onLookupComplete(ChannelHandlerContext ctx, Future future) throws Exception; + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + if (suppressRead) { + readPending = true; + } else { + ctx.read(); + } + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ctx.write(msg, promise); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + ctx.flush(); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslCloseCompletionEvent.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslCloseCompletionEvent.java new file mode 100644 index 0000000..679e346 --- /dev/null +++ 
b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslCloseCompletionEvent.java @@ -0,0 +1,37 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +/** + * Event that is fired once the close_notify was received or if an failure happens before it was received. + */ +public final class SslCloseCompletionEvent extends SslCompletionEvent { + + public static final SslCloseCompletionEvent SUCCESS = new SslCloseCompletionEvent(); + + /** + * Creates a new event that indicates a successful receiving of close_notify. + */ + private SslCloseCompletionEvent() { } + + /** + * Creates a new event that indicates an close_notify was not received because of an previous error. + * Use {@link #SUCCESS} to indicate a success. + */ + public SslCloseCompletionEvent(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClosedEngineException.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClosedEngineException.java new file mode 100644 index 0000000..0ef9c0f --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslClosedEngineException.java @@ -0,0 +1,31 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * {@link SSLException} raised when an operation is attempted on an
 * {@link javax.net.ssl.SSLEngine} that has already been closed.
 */
public final class SslClosedEngineException extends SSLException {

    private static final long serialVersionUID = -5204207600474401904L;

    /**
     * Creates a new instance.
     *
     * @param reason the detail message explaining why the engine is considered closed
     */
    public SslClosedEngineException(String reason) {
        super(reason);
    }
}
import java.util.Objects;

/**
 * Base class for events fired once an SSL/TLS related operation completed,
 * either successfully or with a failure.
 */
public abstract class SslCompletionEvent {

    // null for a successful completion, non-null for a failed one.
    private final Throwable cause;

    /** Creates an event representing a successful completion. */
    SslCompletionEvent() {
        cause = null;
    }

    /**
     * Creates an event representing a failed completion.
     *
     * @param cause the failure, must not be {@code null}
     * @throws NullPointerException if {@code cause} is {@code null}
     */
    SslCompletionEvent(Throwable cause) {
        // Stdlib equivalent of the former ObjectUtil.checkNotNull(cause, "cause"):
        // throws NullPointerException with "cause" as the message.
        this.cause = Objects.requireNonNull(cause, "cause");
    }

    /**
     * Return {@code true} if the completion was successful.
     */
    public final boolean isSuccess() {
        return cause == null;
    }

    /**
     * Return the {@link Throwable} if {@link #isSuccess()} returns {@code false}
     * and so the completion failed, otherwise {@code null}.
     */
    public final Throwable cause() {
        return cause;
    }

    @Override
    public String toString() {
        final Throwable cause = cause();
        return cause == null ? getClass().getSimpleName() + "(SUCCESS)"
                             : getClass().getSimpleName() + '(' + cause + ')';
    }
}
+ */ + +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufInputStream; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol; +import io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior; +import io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior; +import io.netty.util.AttributeMap; +import io.netty.util.DefaultAttributeMap; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; + +import java.io.BufferedInputStream; +import java.security.Provider; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.crypto.Cipher; +import javax.crypto.EncryptedPrivateKeyInfo; +import javax.crypto.NoSuchPaddingException; +import javax.crypto.SecretKey; +import javax.crypto.SecretKeyFactory; +import javax.crypto.spec.PBEKeySpec; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; + +import java.io.File; +import java.io.IOException; +import java.io.InputStream; +import java.security.AlgorithmParameters; +import java.security.InvalidAlgorithmParameterException; +import java.security.InvalidKeyException; +import java.security.KeyException; +import java.security.KeyFactory; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.UnrecoverableKeyException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.security.spec.InvalidKeySpecException; +import java.security.spec.PKCS8EncodedKeySpec; +import 
java.util.List; +import java.util.Map; +import java.util.concurrent.Executor; + +/** + * A secure socket protocol implementation which acts as a factory for {@link SSLEngine} and {@link SslHandler}. + * Internally, it is implemented via JDK's {@link SSLContext} or OpenSSL's {@code SSL_CTX}. + * + * Making your server support SSL/TLS + *
+ * // In your {@link ChannelInitializer}:
+ * {@link ChannelPipeline} p = channel.pipeline();
+ * {@link SslContext} sslCtx = {@link SslContextBuilder#forServer(File, File) SslContextBuilder.forServer(...)}.build();
+ * p.addLast("ssl", {@link #newHandler(ByteBufAllocator) sslCtx.newHandler(channel.alloc())});
+ * ...
+ * 
+ * + * Making your client support SSL/TLS + *
+ * // In your {@link ChannelInitializer}:
+ * {@link ChannelPipeline} p = channel.pipeline();
+ * {@link SslContext} sslCtx = {@link SslContextBuilder#forClient() SslContextBuilder.forClient()}.build();
+ * p.addLast("ssl", {@link #newHandler(ByteBufAllocator, String, int) sslCtx.newHandler(channel.alloc(), host, port)});
+ * ...
+ * 
+ */ +public abstract class SslContext { + static final String ALIAS = "key"; + + static final CertificateFactory X509_CERT_FACTORY; + static { + try { + X509_CERT_FACTORY = CertificateFactory.getInstance("X.509"); + } catch (CertificateException e) { + throw new IllegalStateException("unable to instance X.509 CertificateFactory", e); + } + } + + private final boolean startTls; + private final AttributeMap attributes = new DefaultAttributeMap(); + private static final String OID_PKCS5_PBES2 = "1.2.840.113549.1.5.13"; + private static final String PBES2 = "PBES2"; + + /** + * Returns the default server-side implementation provider currently in use. + * + * @return {@link SslProvider#OPENSSL} if OpenSSL is available. {@link SslProvider#JDK} otherwise. + */ + public static SslProvider defaultServerProvider() { + return defaultProvider(); + } + + /** + * Returns the default client-side implementation provider currently in use. + * + * @return {@link SslProvider#OPENSSL} if OpenSSL is available. {@link SslProvider#JDK} otherwise. + */ + public static SslProvider defaultClientProvider() { + return defaultProvider(); + } + + private static SslProvider defaultProvider() { + if (OpenSsl.isAvailable()) { + return SslProvider.OPENSSL; + } else { + return SslProvider.JDK; + } + } + + /** + * Creates a new server-side {@link SslContext}. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext(File certChainFile, File keyFile) throws SSLException { + return newServerContext(certChainFile, keyFile, null); + } + + /** + * Creates a new server-side {@link SslContext}. 
+ * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + File certChainFile, File keyFile, String keyPassword) throws SSLException { + return newServerContext(null, certChainFile, keyFile, keyPassword); + } + + /** + * Creates a new server-side {@link SslContext}. + * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + + return newServerContext( + null, certChainFile, keyFile, keyPassword, + ciphers, nextProtocols, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new server-side {@link SslContext}. 
+ * + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newServerContext( + null, certChainFile, keyFile, keyPassword, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new server-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + SslProvider provider, File certChainFile, File keyFile) throws SSLException { + return newServerContext(provider, certChainFile, keyFile, null); + } + + /** + * Creates a new server-side {@link SslContext}. 
+ * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + SslProvider provider, File certChainFile, File keyFile, String keyPassword) throws SSLException { + return newServerContext(provider, certChainFile, keyFile, keyPassword, null, IdentityCipherSuiteFilter.INSTANCE, + null, 0, 0); + } + + /** + * Creates a new server-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + SslProvider provider, + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newServerContext(provider, certChainFile, keyFile, keyPassword, + ciphers, IdentityCipherSuiteFilter.INSTANCE, + toApplicationProtocolConfig(nextProtocols), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new server-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + SslProvider provider, + File certChainFile, File keyFile, String keyPassword, TrustManagerFactory trustManagerFactory, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + + return newServerContext( + provider, null, trustManagerFactory, certChainFile, keyFile, keyPassword, + null, ciphers, IdentityCipherSuiteFilter.INSTANCE, + toApplicationProtocolConfig(nextProtocols), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new server-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext(SslProvider provider, + File certChainFile, File keyFile, String keyPassword, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newServerContext(provider, null, null, certChainFile, keyFile, keyPassword, null, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()); + } + + /** + * Creates a new server-side {@link SslContext}. + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * This provides the certificate collection used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile}. + * This parameter is ignored if {@code provider} is not {@link SslProvider#JDK}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * This parameter is ignored if {@code provider} is not {@link SslProvider#JDK}. + * @param ciphers the cipher suites to enable, in the order of preference. 
+ * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * @return a new server-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newServerContext( + SslProvider provider, + File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newServerContext(provider, trustCertCollectionFile, trustManagerFactory, keyCertChainFile, + keyFile, keyPassword, keyManagerFactory, ciphers, cipherFilter, apn, + sessionCacheSize, sessionTimeout, KeyStore.getDefaultType()); + } + + /** + * Creates a new server-side {@link SslContext}. + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * This provides the certificate collection used for mutual authentication. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from clients. + * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile}. 
+ * This parameter is ignored if {@code provider} is not {@link SslProvider#JDK}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to clients. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. + * This parameter is ignored if {@code provider} is not {@link SslProvider#JDK}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * Only required if {@code provider} is {@link SslProvider#JDK} + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * @param keyStore the keystore type that should be used + * @return a new server-side {@link SslContext} + */ + static SslContext newServerContext( + SslProvider provider, + File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout, String keyStore) throws SSLException { + try { + return newServerContextInternal(provider, null, toX509Certificates(trustCertCollectionFile), + trustManagerFactory, toX509Certificates(keyCertChainFile), + toPrivateKey(keyFile, keyPassword), + keyPassword, keyManagerFactory, ciphers, cipherFilter, apn, + sessionCacheSize, sessionTimeout, ClientAuth.NONE, null, + false, false, keyStore); + } catch (Exception e) { + if (e instanceof SSLException) { + throw (SSLException) e; + } + throw new SSLException("failed to initialize the server-side SSL context", e); + } + } + + static SslContext newServerContextInternal( + SslProvider provider, + Provider sslContextProvider, + X509Certificate[] trustCertCollection, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout, ClientAuth clientAuth, String[] protocols, boolean startTls, + boolean enableOcsp, String keyStoreType, Map.Entry, Object>... 
ctxOptions) + throws SSLException { + + if (provider == null) { + provider = defaultServerProvider(); + } + + switch (provider) { + case JDK: + if (enableOcsp) { + throw new IllegalArgumentException("OCSP is not supported with this SslProvider: " + provider); + } + return new JdkSslServerContext(sslContextProvider, + trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, + clientAuth, protocols, startTls, keyStoreType); + case OPENSSL: + verifyNullSslContextProvider(provider, sslContextProvider); + return new OpenSslServerContext( + trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, + clientAuth, protocols, startTls, enableOcsp, keyStoreType, ctxOptions); + case OPENSSL_REFCNT: + verifyNullSslContextProvider(provider, sslContextProvider); + return new ReferenceCountedOpenSslServerContext( + trustCertCollection, trustManagerFactory, keyCertChain, key, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, + clientAuth, protocols, startTls, enableOcsp, keyStoreType, ctxOptions); + default: + throw new Error(provider.toString()); + } + } + + private static void verifyNullSslContextProvider(SslProvider provider, Provider sslContextProvider) { + if (sslContextProvider != null) { + throw new IllegalArgumentException("Java Security Provider unsupported for SslProvider: " + provider); + } + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext() throws SSLException { + return newClientContext(null, null, null); + } + + /** + * Creates a new client-side {@link SslContext}. 
+ * + * @param certChainFile an X.509 certificate chain file in PEM format + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext(File certChainFile) throws SSLException { + return newClientContext(null, certChainFile); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext(TrustManagerFactory trustManagerFactory) throws SSLException { + return newClientContext(null, null, trustManagerFactory); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + File certChainFile, TrustManagerFactory trustManagerFactory) throws SSLException { + return newClientContext(null, certChainFile, trustManagerFactory); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. 
+ * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + File certChainFile, TrustManagerFactory trustManagerFactory, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newClientContext( + null, certChainFile, trustManagerFactory, + ciphers, nextProtocols, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + File certChainFile, TrustManagerFactory trustManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newClientContext( + null, certChainFile, trustManagerFactory, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext(SslProvider provider) throws SSLException { + return newClientContext(provider, null, null); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext(SslProvider provider, File certChainFile) throws SSLException { + return newClientContext(provider, certChainFile, null); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. 
+ * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + SslProvider provider, TrustManagerFactory trustManagerFactory) throws SSLException { + return newClientContext(provider, null, trustManagerFactory); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + SslProvider provider, File certChainFile, TrustManagerFactory trustManagerFactory) throws SSLException { + return newClientContext(provider, certChainFile, trustManagerFactory, null, IdentityCipherSuiteFilter.INSTANCE, + null, 0, 0); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param nextProtocols the application layer protocols to accept, in the order of preference. + * {@code null} to disable TLS NPN/ALPN extension. 
+ * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + SslProvider provider, + File certChainFile, TrustManagerFactory trustManagerFactory, + Iterable ciphers, Iterable nextProtocols, + long sessionCacheSize, long sessionTimeout) throws SSLException { + return newClientContext( + provider, certChainFile, trustManagerFactory, null, null, null, null, + ciphers, IdentityCipherSuiteFilter.INSTANCE, + toApplicationProtocolConfig(nextProtocols), sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new client-side {@link SslContext}. + * + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param certChainFile an X.509 certificate chain file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. 
+ * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + SslProvider provider, + File certChainFile, TrustManagerFactory trustManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + + return newClientContext( + provider, certChainFile, trustManagerFactory, null, null, null, null, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout); + } + + /** + * Creates a new client-side {@link SslContext}. + * @param provider the {@link SslContext} implementation to use. + * {@code null} to use the current default one. + * @param trustCertCollectionFile an X.509 certificate collection file in PEM format. + * {@code null} to use the system default + * @param trustManagerFactory the {@link TrustManagerFactory} that provides the {@link TrustManager}s + * that verifies the certificates sent from servers. + * {@code null} to use the default or the results of parsing + * {@code trustCertCollectionFile}. + * This parameter is ignored if {@code provider} is not {@link SslProvider#JDK}. + * @param keyCertChainFile an X.509 certificate chain file in PEM format. + * This provides the public key for mutual authentication. + * {@code null} to use the system default + * @param keyFile a PKCS#8 private key file in PEM format. + * This provides the private key for mutual authentication. + * {@code null} for no mutual authentication. + * @param keyPassword the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * Ignored if {@code keyFile} is {@code null}. + * @param keyManagerFactory the {@link KeyManagerFactory} that provides the {@link KeyManager}s + * that is used to encrypt data being sent to servers. + * {@code null} to use the default or the results of parsing + * {@code keyCertChainFile} and {@code keyFile}. 
+ * This parameter is ignored if {@code provider} is not {@link SslProvider#JDK}. + * @param ciphers the cipher suites to enable, in the order of preference. + * {@code null} to use the default cipher suites. + * @param cipherFilter a filter to apply over the supplied list of ciphers + * @param apn Provides a means to configure parameters related to application protocol negotiation. + * @param sessionCacheSize the size of the cache used for storing SSL session objects. + * {@code 0} to use the default value. + * @param sessionTimeout the timeout for the cached SSL session objects, in seconds. + * {@code 0} to use the default value. + * + * @return a new client-side {@link SslContext} + * @deprecated Replaced by {@link SslContextBuilder} + */ + @Deprecated + public static SslContext newClientContext( + SslProvider provider, + File trustCertCollectionFile, TrustManagerFactory trustManagerFactory, + File keyCertChainFile, File keyFile, String keyPassword, + KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter cipherFilter, ApplicationProtocolConfig apn, + long sessionCacheSize, long sessionTimeout) throws SSLException { + try { + return newClientContextInternal(provider, null, + toX509Certificates(trustCertCollectionFile), trustManagerFactory, + toX509Certificates(keyCertChainFile), toPrivateKey(keyFile, keyPassword), + keyPassword, keyManagerFactory, ciphers, cipherFilter, + apn, null, sessionCacheSize, sessionTimeout, false, + KeyStore.getDefaultType()); + } catch (Exception e) { + if (e instanceof SSLException) { + throw (SSLException) e; + } + throw new SSLException("failed to initialize the client-side SSL context", e); + } + } + + static SslContext newClientContextInternal( + SslProvider provider, + Provider sslContextProvider, + X509Certificate[] trustCert, TrustManagerFactory trustManagerFactory, + X509Certificate[] keyCertChain, PrivateKey key, String keyPassword, KeyManagerFactory keyManagerFactory, + Iterable ciphers, CipherSuiteFilter 
cipherFilter, ApplicationProtocolConfig apn, String[] protocols, + long sessionCacheSize, long sessionTimeout, boolean enableOcsp, String keyStoreType, + Map.Entry, Object>... options) throws SSLException { + if (provider == null) { + provider = defaultClientProvider(); + } + switch (provider) { + case JDK: + if (enableOcsp) { + throw new IllegalArgumentException("OCSP is not supported with this SslProvider: " + provider); + } + return new JdkSslClientContext(sslContextProvider, + trustCert, trustManagerFactory, keyCertChain, key, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize, + sessionTimeout, keyStoreType); + case OPENSSL: + verifyNullSslContextProvider(provider, sslContextProvider); + OpenSsl.ensureAvailability(); + return new OpenSslClientContext( + trustCert, trustManagerFactory, keyCertChain, key, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout, + enableOcsp, keyStoreType, options); + case OPENSSL_REFCNT: + verifyNullSslContextProvider(provider, sslContextProvider); + OpenSsl.ensureAvailability(); + return new ReferenceCountedOpenSslClientContext( + trustCert, trustManagerFactory, keyCertChain, key, keyPassword, + keyManagerFactory, ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout, + enableOcsp, keyStoreType, options); + default: + throw new Error(provider.toString()); + } + } + + static ApplicationProtocolConfig toApplicationProtocolConfig(Iterable nextProtocols) { + ApplicationProtocolConfig apn; + if (nextProtocols == null) { + apn = ApplicationProtocolConfig.DISABLED; + } else { + apn = new ApplicationProtocolConfig( + Protocol.NPN_AND_ALPN, SelectorFailureBehavior.CHOOSE_MY_LAST_PROTOCOL, + SelectedListenerFailureBehavior.ACCEPT, nextProtocols); + } + return apn; + } + + /** + * Creates a new instance (startTls set to {@code false}). + */ + protected SslContext() { + this(false); + } + + /** + * Creates a new instance. 
+ */ + protected SslContext(boolean startTls) { + this.startTls = startTls; + } + + /** + * Returns the {@link AttributeMap} that belongs to this {@link SslContext} . + */ + public final AttributeMap attributes() { + return attributes; + } + + /** + * Returns {@code true} if and only if this context is for server-side. + */ + public final boolean isServer() { + return !isClient(); + } + + /** + * Returns the {@code true} if and only if this context is for client-side. + */ + public abstract boolean isClient(); + + /** + * Returns the list of enabled cipher suites, in the order of preference. + */ + public abstract List cipherSuites(); + + /** + * Returns the size of the cache used for storing SSL session objects. + */ + public long sessionCacheSize() { + return sessionContext().getSessionCacheSize(); + } + + /** + * Returns the timeout for the cached SSL session objects, in seconds. + */ + public long sessionTimeout() { + return sessionContext().getSessionTimeout(); + } + + /** + * @deprecated Use {@link #applicationProtocolNegotiator()} instead. + */ + @Deprecated + public final List nextProtocols() { + return applicationProtocolNegotiator().protocols(); + } + + /** + * Returns the object responsible for negotiating application layer protocols for the TLS NPN/ALPN extensions. + */ + public abstract ApplicationProtocolNegotiator applicationProtocolNegotiator(); + + /** + * Creates a new {@link SSLEngine}. + *

If {@link SslProvider#OPENSSL_REFCNT} is used then the object must be released. One way to do this is to + * wrap in a {@link SslHandler} and insert it into a pipeline. See {@link #newHandler(ByteBufAllocator)}. + * @return a new {@link SSLEngine} + */ + public abstract SSLEngine newEngine(ByteBufAllocator alloc); + + /** + * Creates a new {@link SSLEngine} using advisory peer information. + *

If {@link SslProvider#OPENSSL_REFCNT} is used then the object must be released. One way to do this is to + * wrap in a {@link SslHandler} and insert it into a pipeline. + * See {@link #newHandler(ByteBufAllocator, String, int)}. + * @param peerHost the non-authoritative name of the host + * @param peerPort the non-authoritative port + * + * @return a new {@link SSLEngine} + */ + public abstract SSLEngine newEngine(ByteBufAllocator alloc, String peerHost, int peerPort); + + /** + * Returns the {@link SSLSessionContext} object held by this context. + */ + public abstract SSLSessionContext sessionContext(); + + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator, Executor) + */ + public final SslHandler newHandler(ByteBufAllocator alloc) { + return newHandler(alloc, startTls); + } + + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator) + */ + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls) { + return new SslHandler(newEngine(alloc), startTls); + } + + /** + * Creates a new {@link SslHandler}. + *

If {@link SslProvider#OPENSSL_REFCNT} is used then the returned {@link SslHandler} will release the engine + * that is wrapped. If the returned {@link SslHandler} is not inserted into a pipeline then you may leak native + * memory! + *

Beware: the underlying generated {@link SSLEngine} won't have + * hostname verification enabled by default. + * If you create {@link SslHandler} for the client side and want proper security, we advice that you configure + * the {@link SSLEngine} (see {@link javax.net.ssl.SSLParameters#setEndpointIdentificationAlgorithm(String)}): + *

+     * SSLEngine sslEngine = sslHandler.engine();
+     * SSLParameters sslParameters = sslEngine.getSSLParameters();
+     * // only available since Java 7
+     * sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
+     * sslEngine.setSSLParameters(sslParameters);
+     * 
+ *

+ * The underlying {@link SSLEngine} may not follow the restrictions imposed by the + * SSLEngine javadocs which + * limits wrap/unwrap to operate on a single SSL/TLS packet. + * @param alloc If supported by the SSLEngine then the SSLEngine will use this to allocate ByteBuf objects. + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. + * @return a new {@link SslHandler} + */ + public SslHandler newHandler(ByteBufAllocator alloc, Executor delegatedTaskExecutor) { + return newHandler(alloc, startTls, delegatedTaskExecutor); + } + + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator, String, int, boolean, Executor) + */ + protected SslHandler newHandler(ByteBufAllocator alloc, boolean startTls, Executor executor) { + return new SslHandler(newEngine(alloc), startTls, executor); + } + + /** + * Creates a new {@link SslHandler} + * + * @see #newHandler(ByteBufAllocator, String, int, Executor) + */ + public final SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort) { + return newHandler(alloc, peerHost, peerPort, startTls); + } + + /** + * Create a new SslHandler. + * @see #newHandler(ByteBufAllocator, String, int, boolean, Executor) + */ + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls) { + return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls); + } + + /** + * Creates a new {@link SslHandler} with advisory peer information. + *

If {@link SslProvider#OPENSSL_REFCNT} is used then the returned {@link SslHandler} will release the engine + * that is wrapped. If the returned {@link SslHandler} is not inserted into a pipeline then you may leak native + * memory! + *

Beware: the underlying generated {@link SSLEngine} won't have + * hostname verification enabled by default. + * If you create {@link SslHandler} for the client side and want proper security, we advice that you configure + * the {@link SSLEngine} (see {@link javax.net.ssl.SSLParameters#setEndpointIdentificationAlgorithm(String)}): + *

+     * SSLEngine sslEngine = sslHandler.engine();
+     * SSLParameters sslParameters = sslEngine.getSSLParameters();
+     * // only available since Java 7
+     * sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
+     * sslEngine.setSSLParameters(sslParameters);
+     * 
+ *

+ * The underlying {@link SSLEngine} may not follow the restrictions imposed by the + * SSLEngine javadocs which + * limits wrap/unwrap to operate on a single SSL/TLS packet. + * @param alloc If supported by the SSLEngine then the SSLEngine will use this to allocate ByteBuf objects. + * @param peerHost the non-authoritative name of the host + * @param peerPort the non-authoritative port + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. + * + * @return a new {@link SslHandler} + */ + public SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, + Executor delegatedTaskExecutor) { + return newHandler(alloc, peerHost, peerPort, startTls, delegatedTaskExecutor); + } + + protected SslHandler newHandler(ByteBufAllocator alloc, String peerHost, int peerPort, boolean startTls, + Executor delegatedTaskExecutor) { + return new SslHandler(newEngine(alloc, peerHost, peerPort), startTls, delegatedTaskExecutor); + } + + /** + * Generates a key specification for an (encrypted) private key. 
+ * + * @param password characters, if {@code null} an unencrypted key is assumed + * @param key bytes of the DER encoded private key + * + * @return a key specification + * + * @throws IOException if parsing {@code key} fails + * @throws NoSuchAlgorithmException if the algorithm used to encrypt {@code key} is unknown + * @throws NoSuchPaddingException if the padding scheme specified in the decryption algorithm is unknown + * @throws InvalidKeySpecException if the decryption key based on {@code password} cannot be generated + * @throws InvalidKeyException if the decryption key based on {@code password} cannot be used to decrypt + * {@code key} + * @throws InvalidAlgorithmParameterException if decryption algorithm parameters are somehow faulty + */ + @Deprecated + protected static PKCS8EncodedKeySpec generateKeySpec(char[] password, byte[] key) + throws IOException, NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeySpecException, + InvalidKeyException, InvalidAlgorithmParameterException { + + if (password == null) { + return new PKCS8EncodedKeySpec(key); + } + + EncryptedPrivateKeyInfo encryptedPrivateKeyInfo = new EncryptedPrivateKeyInfo(key); + String pbeAlgorithm = getPBEAlgorithm(encryptedPrivateKeyInfo); + SecretKeyFactory keyFactory = SecretKeyFactory.getInstance(pbeAlgorithm); + PBEKeySpec pbeKeySpec = new PBEKeySpec(password); + SecretKey pbeKey = keyFactory.generateSecret(pbeKeySpec); + + Cipher cipher = Cipher.getInstance(pbeAlgorithm); + cipher.init(Cipher.DECRYPT_MODE, pbeKey, encryptedPrivateKeyInfo.getAlgParameters()); + + return encryptedPrivateKeyInfo.getKeySpec(cipher); + } + + private static String getPBEAlgorithm(EncryptedPrivateKeyInfo encryptedPrivateKeyInfo) { + AlgorithmParameters parameters = encryptedPrivateKeyInfo.getAlgParameters(); + String algName = encryptedPrivateKeyInfo.getAlgName(); + // Java 8 ~ 16 returns OID_PKCS5_PBES2 + // Java 17+ returns PBES2 + if (PlatformDependent.javaVersion() >= 8 && parameters != null && + 
(OID_PKCS5_PBES2.equals(algName) || PBES2.equals(algName))) { + /* + * This should be "PBEWithAnd". + * Relying on the toString() implementation is potentially + * fragile but acceptable in this case since the JRE depends on + * the toString() implementation as well. + * In the future, if necessary, we can parse the value of + * parameters.getEncoded() but the associated complexity and + * unlikeliness of the JRE implementation changing means that + * Tomcat will use to toString() approach for now. + */ + return parameters.toString(); + } + return encryptedPrivateKeyInfo.getAlgName(); + } + + /** + * Generates a new {@link KeyStore}. + * + * @param certChain an X.509 certificate chain + * @param key a PKCS#8 private key + * @param keyPasswordChars the password of the {@code keyFile}. + * {@code null} if it's not password-protected. + * @param keyStoreType The KeyStore Type you want to use + * @return generated {@link KeyStore}. + */ + protected static KeyStore buildKeyStore(X509Certificate[] certChain, PrivateKey key, + char[] keyPasswordChars, String keyStoreType) + throws KeyStoreException, NoSuchAlgorithmException, + CertificateException, IOException { + if (keyStoreType == null) { + keyStoreType = KeyStore.getDefaultType(); + } + KeyStore ks = KeyStore.getInstance(keyStoreType); + ks.load(null, null); + ks.setKeyEntry(ALIAS, key, keyPasswordChars, certChain); + return ks; + } + + protected static PrivateKey toPrivateKey(File keyFile, String keyPassword) throws NoSuchAlgorithmException, + NoSuchPaddingException, InvalidKeySpecException, + InvalidAlgorithmParameterException, + KeyException, IOException { + return toPrivateKey(keyFile, keyPassword, true); + } + + static PrivateKey toPrivateKey(File keyFile, String keyPassword, boolean tryBouncyCastle) + throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeySpecException, + InvalidAlgorithmParameterException, + KeyException, IOException { + if (keyFile == null) { + return null; + } + + // try BC first, 
if this fail fallback to original key extraction process + if (tryBouncyCastle && BouncyCastlePemReader.isAvailable()) { + PrivateKey pk = BouncyCastlePemReader.getPrivateKey(keyFile, keyPassword); + if (pk != null) { + return pk; + } + } + + return getPrivateKeyFromByteBuffer(PemReader.readPrivateKey(keyFile), keyPassword); + } + + protected static PrivateKey toPrivateKey(InputStream keyInputStream, String keyPassword) + throws NoSuchAlgorithmException, + NoSuchPaddingException, InvalidKeySpecException, + InvalidAlgorithmParameterException, + KeyException, IOException { + if (keyInputStream == null) { + return null; + } + + // try BC first, if this fail fallback to original key extraction process + if (BouncyCastlePemReader.isAvailable()) { + if (!keyInputStream.markSupported()) { + // We need an input stream that supports resetting, in case BouncyCastle fails to read. + keyInputStream = new BufferedInputStream(keyInputStream); + } + keyInputStream.mark(1048576); // Be able to reset up to 1 MiB of data. + PrivateKey pk = BouncyCastlePemReader.getPrivateKey(keyInputStream, keyPassword); + if (pk != null) { + return pk; + } + // BouncyCastle could not read the key. Reset the input stream in case the input position changed. + keyInputStream.reset(); + } + + return getPrivateKeyFromByteBuffer(PemReader.readPrivateKey(keyInputStream), keyPassword); + } + + private static PrivateKey getPrivateKeyFromByteBuffer(ByteBuf encodedKeyBuf, String keyPassword) + throws NoSuchAlgorithmException, NoSuchPaddingException, InvalidKeySpecException, + InvalidAlgorithmParameterException, KeyException, IOException { + + byte[] encodedKey = new byte[encodedKeyBuf.readableBytes()]; + encodedKeyBuf.readBytes(encodedKey).release(); + + PKCS8EncodedKeySpec encodedKeySpec = generateKeySpec( + keyPassword == null ? 
null : keyPassword.toCharArray(), encodedKey); + try { + return KeyFactory.getInstance("RSA").generatePrivate(encodedKeySpec); + } catch (InvalidKeySpecException ignore) { + try { + return KeyFactory.getInstance("DSA").generatePrivate(encodedKeySpec); + } catch (InvalidKeySpecException ignore2) { + try { + return KeyFactory.getInstance("EC").generatePrivate(encodedKeySpec); + } catch (InvalidKeySpecException e) { + throw new InvalidKeySpecException("Neither RSA, DSA nor EC worked", e); + } + } + } + } + + /** + * Build a {@link TrustManagerFactory} from a certificate chain file. + * @param certChainFile The certificate file to build from. + * @param trustManagerFactory The existing {@link TrustManagerFactory} that will be used if not {@code null}. + * @return A {@link TrustManagerFactory} which contains the certificates in {@code certChainFile} + */ + @Deprecated + protected static TrustManagerFactory buildTrustManagerFactory( + File certChainFile, TrustManagerFactory trustManagerFactory) + throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException { + return buildTrustManagerFactory(certChainFile, trustManagerFactory, null); + } + + /** + * Build a {@link TrustManagerFactory} from a certificate chain file. + * @param certChainFile The certificate file to build from. + * @param trustManagerFactory The existing {@link TrustManagerFactory} that will be used if not {@code null}. 
+ * @param keyType The KeyStore Type you want to use + * @return A {@link TrustManagerFactory} which contains the certificates in {@code certChainFile} + */ + protected static TrustManagerFactory buildTrustManagerFactory( + File certChainFile, TrustManagerFactory trustManagerFactory, String keyType) + throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException { + X509Certificate[] x509Certs = toX509Certificates(certChainFile); + + return buildTrustManagerFactory(x509Certs, trustManagerFactory, keyType); + } + + protected static X509Certificate[] toX509Certificates(File file) throws CertificateException { + if (file == null) { + return null; + } + return getCertificatesFromBuffers(PemReader.readCertificates(file)); + } + + protected static X509Certificate[] toX509Certificates(InputStream in) throws CertificateException { + if (in == null) { + return null; + } + return getCertificatesFromBuffers(PemReader.readCertificates(in)); + } + + private static X509Certificate[] getCertificatesFromBuffers(ByteBuf[] certs) throws CertificateException { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + X509Certificate[] x509Certs = new X509Certificate[certs.length]; + + try { + for (int i = 0; i < certs.length; i++) { + InputStream is = new ByteBufInputStream(certs[i], false); + try { + x509Certs[i] = (X509Certificate) cf.generateCertificate(is); + } finally { + try { + is.close(); + } catch (IOException e) { + // This is not expected to happen, but re-throw in case it does. 
+ throw new RuntimeException(e); + } + } + } + } finally { + for (ByteBuf buf: certs) { + buf.release(); + } + } + return x509Certs; + } + + protected static TrustManagerFactory buildTrustManagerFactory( + X509Certificate[] certCollection, TrustManagerFactory trustManagerFactory, String keyStoreType) + throws NoSuchAlgorithmException, CertificateException, KeyStoreException, IOException { + if (keyStoreType == null) { + keyStoreType = KeyStore.getDefaultType(); + } + final KeyStore ks = KeyStore.getInstance(keyStoreType); + ks.load(null, null); + + int i = 1; + for (X509Certificate cert: certCollection) { + String alias = Integer.toString(i); + ks.setCertificateEntry(alias, cert); + i++; + } + + // Set up trust manager factory to use our key store. + if (trustManagerFactory == null) { + trustManagerFactory = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm()); + } + trustManagerFactory.init(ks); + + return trustManagerFactory; + } + + static PrivateKey toPrivateKeyInternal(File keyFile, String keyPassword) throws SSLException { + try { + return toPrivateKey(keyFile, keyPassword); + } catch (Exception e) { + throw new SSLException(e); + } + } + + static X509Certificate[] toX509CertificatesInternal(File file) throws SSLException { + try { + return toX509Certificates(file); + } catch (CertificateException e) { + throw new SSLException(e); + } + } + + protected static KeyManagerFactory buildKeyManagerFactory(X509Certificate[] certChainFile, + String keyAlgorithm, PrivateKey key, + String keyPassword, KeyManagerFactory kmf, + String keyStore) + throws KeyStoreException, NoSuchAlgorithmException, IOException, + CertificateException, UnrecoverableKeyException { + if (keyAlgorithm == null) { + keyAlgorithm = KeyManagerFactory.getDefaultAlgorithm(); + } + char[] keyPasswordChars = keyStorePassword(keyPassword); + KeyStore ks = buildKeyStore(certChainFile, key, keyPasswordChars, keyStore); + return buildKeyManagerFactory(ks, keyAlgorithm, 
keyPasswordChars, kmf); + } + + static KeyManagerFactory buildKeyManagerFactory(KeyStore ks, + String keyAlgorithm, + char[] keyPasswordChars, KeyManagerFactory kmf) + throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException { + // Set up key manager factory to use our key store + if (kmf == null) { + if (keyAlgorithm == null) { + keyAlgorithm = KeyManagerFactory.getDefaultAlgorithm(); + } + kmf = KeyManagerFactory.getInstance(keyAlgorithm); + } + kmf.init(ks, keyPasswordChars); + + return kmf; + } + + static char[] keyStorePassword(String keyPassword) { + return keyPassword == null ? EmptyArrays.EMPTY_CHARS : keyPassword.toCharArray(); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextBuilder.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextBuilder.java new file mode 100644 index 0000000..3fc8f6b --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextBuilder.java @@ -0,0 +1,632 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl; + +import io.netty.handler.ssl.util.KeyManagerFactoryWrapper; +import io.netty.handler.ssl.util.TrustManagerFactoryWrapper; +import io.netty.util.internal.UnstableApi; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import java.io.File; +import java.io.InputStream; +import java.security.KeyStore; +import java.security.PrivateKey; +import java.security.Provider; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + +import static io.netty.util.internal.EmptyArrays.EMPTY_STRINGS; +import static io.netty.util.internal.EmptyArrays.EMPTY_X509_CERTIFICATES; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +/** + * Builder for configuring a new SslContext for creation. + */ +public final class SslContextBuilder { + @SuppressWarnings("rawtypes") + private static final Map.Entry[] EMPTY_ENTRIES = new Map.Entry[0]; + + /** + * Creates a builder for new client-side {@link SslContext}. + */ + public static SslContextBuilder forClient() { + return new SslContextBuilder(false); + } + + /** + * Creates a builder for new server-side {@link SslContext}. + * + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @see #keyManager(File, File) + */ + public static SslContextBuilder forServer(File keyCertChainFile, File keyFile) { + return new SslContextBuilder(true).keyManager(keyCertChainFile, keyFile); + } + + /** + * Creates a builder for new server-side {@link SslContext}. 
+ * + * @param keyCertChainInputStream an input stream for an X.509 certificate chain in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. + * @param keyInputStream an input stream for a PKCS#8 private key in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. + * + * @see #keyManager(InputStream, InputStream) + */ + public static SslContextBuilder forServer(InputStream keyCertChainInputStream, InputStream keyInputStream) { + return new SslContextBuilder(true).keyManager(keyCertChainInputStream, keyInputStream); + } + + /** + * Creates a builder for new server-side {@link SslContext}. + * + * @param key a PKCS#8 private key + * @param keyCertChain the X.509 certificate chain + * @see #keyManager(PrivateKey, X509Certificate[]) + */ + public static SslContextBuilder forServer(PrivateKey key, X509Certificate... keyCertChain) { + return new SslContextBuilder(true).keyManager(key, keyCertChain); + } + + /** + * Creates a builder for new server-side {@link SslContext}. + * + * @param key a PKCS#8 private key + * @param keyCertChain the X.509 certificate chain + * @see #keyManager(PrivateKey, X509Certificate[]) + */ + public static SslContextBuilder forServer(PrivateKey key, Iterable keyCertChain) { + return forServer(key, toArray(keyCertChain, EMPTY_X509_CERTIFICATES)); + } + + /** + * Creates a builder for new server-side {@link SslContext}. 
+ * + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}, or {@code null} if it's not + * password-protected + * @see #keyManager(File, File, String) + */ + public static SslContextBuilder forServer( + File keyCertChainFile, File keyFile, String keyPassword) { + return new SslContextBuilder(true).keyManager(keyCertChainFile, keyFile, keyPassword); + } + + /** + * Creates a builder for new server-side {@link SslContext}. + * + * @param keyCertChainInputStream an input stream for an X.509 certificate chain in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. + * @param keyInputStream an input stream for a PKCS#8 private key in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. + * @param keyPassword the password of the {@code keyFile}, or {@code null} if it's not + * password-protected + * @see #keyManager(InputStream, InputStream, String) + */ + public static SslContextBuilder forServer( + InputStream keyCertChainInputStream, InputStream keyInputStream, String keyPassword) { + return new SslContextBuilder(true).keyManager(keyCertChainInputStream, keyInputStream, keyPassword); + } + + /** + * Creates a builder for new server-side {@link SslContext}. + * + * @param key a PKCS#8 private key + * @param keyCertChain the X.509 certificate chain + * @param keyPassword the password of the {@code keyFile}, or {@code null} if it's not + * password-protected + * @see #keyManager(File, File, String) + */ + public static SslContextBuilder forServer( + PrivateKey key, String keyPassword, X509Certificate... keyCertChain) { + return new SslContextBuilder(true).keyManager(key, keyPassword, keyCertChain); + } + + /** + * Creates a builder for new server-side {@link SslContext}. 
+     *
+     * @param key a PKCS#8 private key
+     * @param keyCertChain the X.509 certificate chain
+     * @param keyPassword the password of the {@code key}, or {@code null} if it's not
+     *     password-protected
+     * @see #keyManager(PrivateKey, String, X509Certificate[])
+     */
+    public static SslContextBuilder forServer(
+            PrivateKey key, String keyPassword, Iterable keyCertChain) {
+        return forServer(key, keyPassword, toArray(keyCertChain, EMPTY_X509_CERTIFICATES));
+    }
+
+    /**
+     * Creates a builder for new server-side {@link SslContext}.
+     *
+     * If you use {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT} consider using
+     * {@link OpenSslX509KeyManagerFactory} or {@link OpenSslCachingX509KeyManagerFactory}.
+     *
+     * @param keyManagerFactory non-{@code null} factory for server's private key
+     * @see #keyManager(KeyManagerFactory)
+     */
+    public static SslContextBuilder forServer(KeyManagerFactory keyManagerFactory) {
+        return new SslContextBuilder(true).keyManager(keyManagerFactory);
+    }
+
+    /**
+     * Creates a builder for new server-side {@link SslContext} with {@link KeyManager}.
+ * + * @param keyManager non-{@code null} KeyManager for server's private key + */ + public static SslContextBuilder forServer(KeyManager keyManager) { + return new SslContextBuilder(true).keyManager(keyManager); + } + + private final boolean forServer; + private SslProvider provider; + private Provider sslContextProvider; + private X509Certificate[] trustCertCollection; + private TrustManagerFactory trustManagerFactory; + private X509Certificate[] keyCertChain; + private PrivateKey key; + private String keyPassword; + private KeyManagerFactory keyManagerFactory; + private Iterable ciphers; + private CipherSuiteFilter cipherFilter = IdentityCipherSuiteFilter.INSTANCE; + private ApplicationProtocolConfig apn; + private long sessionCacheSize; + private long sessionTimeout; + private ClientAuth clientAuth = ClientAuth.NONE; + private String[] protocols; + private boolean startTls; + private boolean enableOcsp; + private String keyStoreType = KeyStore.getDefaultType(); + private final Map, Object> options = new HashMap, Object>(); + + private SslContextBuilder(boolean forServer) { + this.forServer = forServer; + } + + /** + * Configure a {@link SslContextOption}. + */ + public SslContextBuilder option(SslContextOption option, T value) { + if (value == null) { + options.remove(option); + } else { + options.put(option, value); + } + return this; + } + + /** + * The {@link SslContext} implementation to use. {@code null} uses the default one. + */ + public SslContextBuilder sslProvider(SslProvider provider) { + this.provider = provider; + return this; + } + + /** + * Sets the {@link KeyStore} type that should be used. {@code null} uses the default one. + */ + public SslContextBuilder keyStoreType(String keyStoreType) { + this.keyStoreType = keyStoreType; + return this; + } + + /** + * The SSLContext {@link Provider} to use. {@code null} uses the default one. This is only + * used with {@link SslProvider#JDK}. 
+ */ + public SslContextBuilder sslContextProvider(Provider sslContextProvider) { + this.sslContextProvider = sslContextProvider; + return this; + } + + /** + * Trusted certificates for verifying the remote endpoint's certificate. The file should + * contain an X.509 certificate collection in PEM format. {@code null} uses the system default. + */ + public SslContextBuilder trustManager(File trustCertCollectionFile) { + try { + return trustManager(SslContext.toX509Certificates(trustCertCollectionFile)); + } catch (Exception e) { + throw new IllegalArgumentException("File does not contain valid certificates: " + + trustCertCollectionFile, e); + } + } + + /** + * Trusted certificates for verifying the remote endpoint's certificate. The input stream should + * contain an X.509 certificate collection in PEM format. {@code null} uses the system default. + * + * The caller is responsible for calling {@link InputStream#close()} after {@link #build()} has been called. + */ + public SslContextBuilder trustManager(InputStream trustCertCollectionInputStream) { + try { + return trustManager(SslContext.toX509Certificates(trustCertCollectionInputStream)); + } catch (Exception e) { + throw new IllegalArgumentException("Input stream does not contain valid certificates.", e); + } + } + + /** + * Trusted certificates for verifying the remote endpoint's certificate, {@code null} uses the system default. + */ + public SslContextBuilder trustManager(X509Certificate... trustCertCollection) { + this.trustCertCollection = trustCertCollection != null ? trustCertCollection.clone() : null; + trustManagerFactory = null; + return this; + } + + /** + * Trusted certificates for verifying the remote endpoint's certificate, {@code null} uses the system default. + */ + public SslContextBuilder trustManager(Iterable trustCertCollection) { + return trustManager(toArray(trustCertCollection, EMPTY_X509_CERTIFICATES)); + } + + /** + * Trusted manager for verifying the remote endpoint's certificate. 
{@code null} uses the system default. + */ + public SslContextBuilder trustManager(TrustManagerFactory trustManagerFactory) { + trustCertCollection = null; + this.trustManagerFactory = trustManagerFactory; + return this; + } + + /** + * A single trusted manager for verifying the remote endpoint's certificate. + * This is helpful when custom implementation of {@link TrustManager} is needed. + * Internally, a simple wrapper of {@link TrustManagerFactory} that only produces this + * specified {@link TrustManager} will be created, thus all the requirements specified in + * {@link #trustManager(TrustManagerFactory trustManagerFactory)} also apply here. + */ + public SslContextBuilder trustManager(TrustManager trustManager) { + if (trustManager != null) { + this.trustManagerFactory = new TrustManagerFactoryWrapper(trustManager); + } else { + this.trustManagerFactory = null; + } + trustCertCollection = null; + return this; + } + + /** + * Identifying certificate for this host. {@code keyCertChainFile} and {@code keyFile} may + * be {@code null} for client contexts, which disables mutual authentication. + * + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + */ + public SslContextBuilder keyManager(File keyCertChainFile, File keyFile) { + return keyManager(keyCertChainFile, keyFile, null); + } + + /** + * Identifying certificate for this host. {@code keyCertChainInputStream} and {@code keyInputStream} may + * be {@code null} for client contexts, which disables mutual authentication. + * + * @param keyCertChainInputStream an input stream for an X.509 certificate chain in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. + * @param keyInputStream an input stream for a PKCS#8 private key in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. 
+ */ + public SslContextBuilder keyManager(InputStream keyCertChainInputStream, InputStream keyInputStream) { + return keyManager(keyCertChainInputStream, keyInputStream, null); + } + + /** + * Identifying certificate for this host. {@code keyCertChain} and {@code key} may + * be {@code null} for client contexts, which disables mutual authentication. + * + * @param key a PKCS#8 private key + * @param keyCertChain an X.509 certificate chain + */ + public SslContextBuilder keyManager(PrivateKey key, X509Certificate... keyCertChain) { + return keyManager(key, null, keyCertChain); + } + + /** + * Identifying certificate for this host. {@code keyCertChain} and {@code key} may + * be {@code null} for client contexts, which disables mutual authentication. + * + * @param key a PKCS#8 private key + * @param keyCertChain an X.509 certificate chain + */ + public SslContextBuilder keyManager(PrivateKey key, Iterable keyCertChain) { + return keyManager(key, toArray(keyCertChain, EMPTY_X509_CERTIFICATES)); + } + + /** + * Identifying certificate for this host. {@code keyCertChainFile} and {@code keyFile} may + * be {@code null} for client contexts, which disables mutual authentication. 
+ * + * @param keyCertChainFile an X.509 certificate chain file in PEM format + * @param keyFile a PKCS#8 private key file in PEM format + * @param keyPassword the password of the {@code keyFile}, or {@code null} if it's not + * password-protected + */ + public SslContextBuilder keyManager(File keyCertChainFile, File keyFile, String keyPassword) { + X509Certificate[] keyCertChain; + PrivateKey key; + try { + keyCertChain = SslContext.toX509Certificates(keyCertChainFile); + } catch (Exception e) { + throw new IllegalArgumentException("File does not contain valid certificates: " + keyCertChainFile, e); + } + try { + key = SslContext.toPrivateKey(keyFile, keyPassword); + } catch (Exception e) { + throw new IllegalArgumentException("File does not contain valid private key: " + keyFile, e); + } + return keyManager(key, keyPassword, keyCertChain); + } + + /** + * Identifying certificate for this host. {@code keyCertChainInputStream} and {@code keyInputStream} may + * be {@code null} for client contexts, which disables mutual authentication. + * + * @param keyCertChainInputStream an input stream for an X.509 certificate chain in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. + * @param keyInputStream an input stream for a PKCS#8 private key in PEM format. The caller is + * responsible for calling {@link InputStream#close()} after {@link #build()} + * has been called. 
+ * @param keyPassword the password of the {@code keyInputStream}, or {@code null} if it's not + * password-protected + */ + public SslContextBuilder keyManager(InputStream keyCertChainInputStream, InputStream keyInputStream, + String keyPassword) { + X509Certificate[] keyCertChain; + PrivateKey key; + try { + keyCertChain = SslContext.toX509Certificates(keyCertChainInputStream); + } catch (Exception e) { + throw new IllegalArgumentException("Input stream not contain valid certificates.", e); + } + try { + key = SslContext.toPrivateKey(keyInputStream, keyPassword); + } catch (Exception e) { + throw new IllegalArgumentException("Input stream does not contain valid private key.", e); + } + return keyManager(key, keyPassword, keyCertChain); + } + + /** + * Identifying certificate for this host. {@code keyCertChain} and {@code key} may + * be {@code null} for client contexts, which disables mutual authentication. + * + * @param key a PKCS#8 private key file + * @param keyPassword the password of the {@code key}, or {@code null} if it's not + * password-protected + * @param keyCertChain an X.509 certificate chain + */ + public SslContextBuilder keyManager(PrivateKey key, String keyPassword, X509Certificate... keyCertChain) { + if (forServer) { + checkNonEmpty(keyCertChain, "keyCertChain"); + checkNotNull(key, "key required for servers"); + } + if (keyCertChain == null || keyCertChain.length == 0) { + this.keyCertChain = null; + } else { + for (X509Certificate cert: keyCertChain) { + checkNotNullWithIAE(cert, "cert"); + } + this.keyCertChain = keyCertChain.clone(); + } + this.key = key; + this.keyPassword = keyPassword; + keyManagerFactory = null; + return this; + } + + /** + * Identifying certificate for this host. {@code keyCertChain} and {@code key} may + * be {@code null} for client contexts, which disables mutual authentication. 
+ * + * @param key a PKCS#8 private key file + * @param keyPassword the password of the {@code key}, or {@code null} if it's not + * password-protected + * @param keyCertChain an X.509 certificate chain + */ + public SslContextBuilder keyManager(PrivateKey key, String keyPassword, + Iterable keyCertChain) { + return keyManager(key, keyPassword, toArray(keyCertChain, EMPTY_X509_CERTIFICATES)); + } + + /** + * Identifying manager for this host. {@code keyManagerFactory} may be {@code null} for + * client contexts, which disables mutual authentication. Using a {@link KeyManagerFactory} + * is only supported for {@link SslProvider#JDK} or {@link SslProvider#OPENSSL} / {@link SslProvider#OPENSSL_REFCNT} + * if the used openssl version is 1.0.1+. You can check if your openssl version supports using a + * {@link KeyManagerFactory} by calling {@link OpenSsl#supportsKeyManagerFactory()}. If this is not the case + * you must use {@link #keyManager(File, File)} or {@link #keyManager(File, File, String)}. + * + * If you use {@link SslProvider#OPENSSL} or {@link SslProvider#OPENSSL_REFCNT} consider using + * {@link OpenSslX509KeyManagerFactory} or {@link OpenSslCachingX509KeyManagerFactory}. + */ + public SslContextBuilder keyManager(KeyManagerFactory keyManagerFactory) { + if (forServer) { + checkNotNull(keyManagerFactory, "keyManagerFactory required for servers"); + } + keyCertChain = null; + key = null; + keyPassword = null; + this.keyManagerFactory = keyManagerFactory; + return this; + } + + /** + * A single key manager managing the identity information of this host. + * This is helpful when custom implementation of {@link KeyManager} is needed. + * Internally, a wrapper of {@link KeyManagerFactory} that only produces this specified + * {@link KeyManager} will be created, thus all the requirements specified in + * {@link #keyManager(KeyManagerFactory keyManagerFactory)} also apply here. 
+ */ + public SslContextBuilder keyManager(KeyManager keyManager) { + if (forServer) { + checkNotNull(keyManager, "keyManager required for servers"); + } + if (keyManager != null) { + this.keyManagerFactory = new KeyManagerFactoryWrapper(keyManager); + } else { + this.keyManagerFactory = null; + } + keyCertChain = null; + key = null; + keyPassword = null; + return this; + } + + /** + * The cipher suites to enable, in the order of preference. {@code null} to use default + * cipher suites. + */ + public SslContextBuilder ciphers(Iterable ciphers) { + return ciphers(ciphers, IdentityCipherSuiteFilter.INSTANCE); + } + + /** + * The cipher suites to enable, in the order of preference. {@code cipherFilter} will be + * applied to the ciphers before use. If {@code ciphers} is {@code null}, then the default + * cipher suites will be used. + */ + public SslContextBuilder ciphers(Iterable ciphers, CipherSuiteFilter cipherFilter) { + this.cipherFilter = checkNotNull(cipherFilter, "cipherFilter"); + this.ciphers = ciphers; + return this; + } + + /** + * Application protocol negotiation configuration. {@code null} disables support. + */ + public SslContextBuilder applicationProtocolConfig(ApplicationProtocolConfig apn) { + this.apn = apn; + return this; + } + + /** + * Set the size of the cache used for storing SSL session objects. {@code 0} to use the + * default value. + */ + public SslContextBuilder sessionCacheSize(long sessionCacheSize) { + this.sessionCacheSize = sessionCacheSize; + return this; + } + + /** + * Set the timeout for the cached SSL session objects, in seconds. {@code 0} to use the + * default value. + */ + public SslContextBuilder sessionTimeout(long sessionTimeout) { + this.sessionTimeout = sessionTimeout; + return this; + } + + /** + * Sets the client authentication mode. 
+ */ + public SslContextBuilder clientAuth(ClientAuth clientAuth) { + this.clientAuth = checkNotNull(clientAuth, "clientAuth"); + return this; + } + + /** + * The TLS protocol versions to enable. + * @param protocols The protocols to enable, or {@code null} to enable the default protocols. + * @see SSLEngine#setEnabledCipherSuites(String[]) + */ + public SslContextBuilder protocols(String... protocols) { + this.protocols = protocols == null ? null : protocols.clone(); + return this; + } + + /** + * The TLS protocol versions to enable. + * @param protocols The protocols to enable, or {@code null} to enable the default protocols. + * @see SSLEngine#setEnabledCipherSuites(String[]) + */ + public SslContextBuilder protocols(Iterable protocols) { + return protocols(toArray(protocols, EMPTY_STRINGS)); + } + + /** + * {@code true} if the first write request shouldn't be encrypted. + */ + public SslContextBuilder startTls(boolean startTls) { + this.startTls = startTls; + return this; + } + + /** + * Enables OCSP stapling. Please note that not all {@link SslProvider} implementations support OCSP + * stapling and an exception will be thrown upon {@link #build()}. + * + * @see OpenSsl#isOcspSupported() + */ + @UnstableApi + public SslContextBuilder enableOcsp(boolean enableOcsp) { + this.enableOcsp = enableOcsp; + return this; + } + + /** + * Create new {@code SslContext} instance with configured settings. + *

If {@link #sslProvider(SslProvider)} is set to {@link SslProvider#OPENSSL_REFCNT} then the caller is + * responsible for releasing this object, or else native memory may leak. + */ + public SslContext build() throws SSLException { + if (forServer) { + return SslContext.newServerContextInternal(provider, sslContextProvider, trustCertCollection, + trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, + ciphers, cipherFilter, apn, sessionCacheSize, sessionTimeout, clientAuth, protocols, startTls, + enableOcsp, keyStoreType, toArray(options.entrySet(), EMPTY_ENTRIES)); + } else { + return SslContext.newClientContextInternal(provider, sslContextProvider, trustCertCollection, + trustManagerFactory, keyCertChain, key, keyPassword, keyManagerFactory, + ciphers, cipherFilter, apn, protocols, sessionCacheSize, sessionTimeout, enableOcsp, keyStoreType, + toArray(options.entrySet(), EMPTY_ENTRIES)); + } + } + + private static T[] toArray(Iterable iterable, T[] prototype) { + if (iterable == null) { + return null; + } + final List list = new ArrayList(); + for (T element : iterable) { + list.add(element); + } + return list.toArray(prototype); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextOption.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextOption.java new file mode 100644 index 0000000..36491d5 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslContextOption.java @@ -0,0 +1,86 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.util.AbstractConstant; +import io.netty.util.ConstantPool; +import io.netty.util.internal.ObjectUtil; + + +/** + * A {@link SslContextOption} allows to configure a {@link SslContext} in a type-safe + * way. Which {@link SslContextOption} is supported depends on the actual implementation + * of {@link SslContext} and may depend on the nature of the SSL implementation it belongs + * to. + * + * @param the type of the value which is valid for the {@link SslContextOption} + */ +public class SslContextOption extends AbstractConstant> { + + private static final ConstantPool> pool = new ConstantPool>() { + @Override + protected SslContextOption newConstant(int id, String name) { + return new SslContextOption(id, name); + } + }; + + /** + * Returns the {@link SslContextOption} of the specified name. + */ + @SuppressWarnings("unchecked") + public static SslContextOption valueOf(String name) { + return (SslContextOption) pool.valueOf(name); + } + + /** + * Shortcut of {@link #valueOf(String) valueOf(firstNameComponent.getName() + "#" + secondNameComponent)}. + */ + @SuppressWarnings("unchecked") + public static SslContextOption valueOf(Class firstNameComponent, String secondNameComponent) { + return (SslContextOption) pool.valueOf(firstNameComponent, secondNameComponent); + } + + /** + * Returns {@code true} if a {@link SslContextOption} exists for the given {@code name}. 
+ */ + public static boolean exists(String name) { + return pool.exists(name); + } + + /** + * Creates a new {@link SslContextOption} with the specified unique {@code name}. + */ + private SslContextOption(int id, String name) { + super(id, name); + } + + /** + * Should be used by sub-classes. + * + * @param name the name of the option + */ + protected SslContextOption(String name) { + this(pool.nextId(), name); + } + + /** + * Validate the value which is set for the {@link SslContextOption}. Sub-classes + * may override this for special checks. + */ + public void validate(T value) { + ObjectUtil.checkNotNull(value, "value"); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandler.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandler.java new file mode 100644 index 0000000..448e430 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandler.java @@ -0,0 +1,2478 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.AbstractCoalescingBufferQueue; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelException; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandler; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.unix.UnixChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.UnsupportedMessageTypeException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.DefaultPromise; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.ImmediateExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.IOException; +import java.net.SocketAddress; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.DatagramChannel; +import java.nio.channels.SocketChannel; +import java.util.List; +import java.util.concurrent.Executor; +import 
java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.regex.Pattern; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.HandshakeStatus; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLSession; + +import static io.netty.buffer.ByteBufUtil.ensureWritableSuccess; +import static io.netty.handler.ssl.SslUtils.NOT_ENOUGH_DATA; +import static io.netty.handler.ssl.SslUtils.getEncryptedPacketLength; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +/** + * The SslHandler + * Adds SSL + * · TLS and StartTLS support to a {@link Channel}. Please refer + * to the "SecureChat" example in the distribution or the web + * site for the detailed usage. + * + * Beginning the handshake + *

+ * Beside using the handshake {@link ChannelFuture} to get notified about the completion of the handshake it's + * also possible to detect it by implement the + * {@link ChannelInboundHandler#userEventTriggered(ChannelHandlerContext, Object)} + * method and check for a {@link SslHandshakeCompletionEvent}. + * + * Handshake + *

+ * The handshake will be automatically issued for you once the {@link Channel} is active and + * {@link SSLEngine#getUseClientMode()} returns {@code true}. + * So no need to bother with it by your self. + * + * Closing the session + *

+ * To close the SSL session, the {@link #closeOutbound()} method should be + * called to send the {@code close_notify} message to the remote peer. One + * exception is when you close the {@link Channel} - {@link SslHandler} + * intercepts the close request and send the {@code close_notify} message + * before the channel closure automatically. Once the SSL session is closed, + * it is not reusable, and consequently you should create a new + * {@link SslHandler} with a new {@link SSLEngine} as explained in the + * following section. + * + * Restarting the session + *

+ * To restart the SSL session, you must remove the existing closed + * {@link SslHandler} from the {@link ChannelPipeline}, insert a new + * {@link SslHandler} with a new {@link SSLEngine} into the pipeline, + * and start the handshake process as described in the first section. + * + * Implementing StartTLS + *

+ * StartTLS is the + * communication pattern that secures the wire in the middle of the plaintext + * connection. Please note that it is different from SSL · TLS, that + * secures the wire from the beginning of the connection. Typically, StartTLS + * is composed of three steps: + *

    + *
  1. Client sends a StartTLS request to server.
  2. + *
  3. Server sends a StartTLS response to client.
  4. + *
  5. Client begins SSL handshake.
  6. + *
+ * If you implement a server, you need to: + *
    + *
  1. create a new {@link SslHandler} instance with {@code startTls} flag set + * to {@code true},
  2. + *
  3. insert the {@link SslHandler} to the {@link ChannelPipeline}, and
  4. + *
  5. write a StartTLS response.
  6. + *
+ * Please note that you must insert {@link SslHandler} before sending + * the StartTLS response. Otherwise the client can begin the SSL handshake + * before {@link SslHandler} is inserted to the {@link ChannelPipeline}, causing + * data corruption. + *

+ * The client-side implementation is much simpler. + *

    + *
  1. Write a StartTLS request,
  2. + *
  3. wait for the StartTLS response,
  4. + *
  5. create a new {@link SslHandler} instance with {@code startTls} flag set + * to {@code false},
  6. + *
  7. insert the {@link SslHandler} to the {@link ChannelPipeline}, and
  8. + *
  9. Initiate SSL handshake.
  10. + *
+ * + * Known issues + *

+ * Because of a known issue with the current implementation of the SslEngine that comes + * with Java it may be possible that you see blocked IO-Threads while a full GC is done. + *

+ * So if you are affected you can workaround this problem by adjust the cache settings + * like shown below: + * + *

+ *     SslContext context = ...;
+ *     context.getServerSessionContext().setSessionCacheSize(someSaneSize);
+ *     context.getServerSessionContext().setSessionTime(someSameTimeout);
+ * 
+ *

+ * What values to use here depends on the nature of your application and should be set + * based on monitoring and debugging of it. + * For more details see + * #832 in our issue tracker. + */ +public class SslHandler extends ByteToMessageDecoder implements ChannelOutboundHandler { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(SslHandler.class); + private static final Pattern IGNORABLE_CLASS_IN_STACK = Pattern.compile( + "^.*(?:Socket|Datagram|Sctp|Udt)Channel.*$"); + private static final Pattern IGNORABLE_ERROR_MESSAGE = Pattern.compile( + "^.*(?:connection.*(?:reset|closed|abort|broken)|broken.*pipe).*$", Pattern.CASE_INSENSITIVE); + private static final int STATE_SENT_FIRST_MESSAGE = 1; + private static final int STATE_FLUSHED_BEFORE_HANDSHAKE = 1 << 1; + private static final int STATE_READ_DURING_HANDSHAKE = 1 << 2; + private static final int STATE_HANDSHAKE_STARTED = 1 << 3; + /** + * Set by wrap*() methods when something is produced. + * {@link #channelReadComplete(ChannelHandlerContext)} will check this flag, clear it, and call ctx.flush(). + */ + private static final int STATE_NEEDS_FLUSH = 1 << 4; + private static final int STATE_OUTBOUND_CLOSED = 1 << 5; + private static final int STATE_CLOSE_NOTIFY = 1 << 6; + private static final int STATE_PROCESS_TASK = 1 << 7; + /** + * This flag is used to determine if we need to call {@link ChannelHandlerContext#read()} to consume more data + * when {@link ChannelConfig#isAutoRead()} is {@code false}. + */ + private static final int STATE_FIRE_CHANNEL_READ = 1 << 8; + private static final int STATE_UNWRAP_REENTRY = 1 << 9; + + /** + * 2^14 which is the maximum sized plaintext chunk + * allowed by the TLS RFC. 
+ */ + private static final int MAX_PLAINTEXT_LENGTH = 16 * 1024; + + private enum SslEngineType { + TCNATIVE(true, COMPOSITE_CUMULATOR) { + @Override + SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException { + int nioBufferCount = in.nioBufferCount(); + int writerIndex = out.writerIndex(); + final SSLEngineResult result; + if (nioBufferCount > 1) { + /* + * If {@link OpenSslEngine} is in use, + * we can use a special {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} method + * that accepts multiple {@link ByteBuffer}s without additional memory copies. + */ + ReferenceCountedOpenSslEngine opensslEngine = (ReferenceCountedOpenSslEngine) handler.engine; + try { + handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes()); + result = opensslEngine.unwrap(in.nioBuffers(in.readerIndex(), len), handler.singleBuffer); + } finally { + handler.singleBuffer[0] = null; + } + } else { + result = handler.engine.unwrap(toByteBuffer(in, in.readerIndex(), len), + toByteBuffer(out, writerIndex, out.writableBytes())); + } + out.writerIndex(writerIndex + result.bytesProduced()); + return result; + } + + @Override + ByteBuf allocateWrapBuffer(SslHandler handler, ByteBufAllocator allocator, + int pendingBytes, int numComponents) { + return allocator.directBuffer(((ReferenceCountedOpenSslEngine) handler.engine) + .calculateOutNetBufSize(pendingBytes, numComponents)); + } + + @Override + int calculateRequiredOutBufSpace(SslHandler handler, int pendingBytes, int numComponents) { + return ((ReferenceCountedOpenSslEngine) handler.engine) + .calculateMaxLengthForWrap(pendingBytes, numComponents); + } + + @Override + int calculatePendingData(SslHandler handler, int guess) { + int sslPending = ((ReferenceCountedOpenSslEngine) handler.engine).sslPending(); + return sslPending > 0 ? 
sslPending : guess; + } + + @Override + boolean jdkCompatibilityMode(SSLEngine engine) { + return ((ReferenceCountedOpenSslEngine) engine).jdkCompatibilityMode; + } + }, + CONSCRYPT(true, COMPOSITE_CUMULATOR) { + @Override + SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException { + int nioBufferCount = in.nioBufferCount(); + int writerIndex = out.writerIndex(); + final SSLEngineResult result; + if (nioBufferCount > 1) { + /* + * Use a special unwrap method without additional memory copies. + */ + try { + handler.singleBuffer[0] = toByteBuffer(out, writerIndex, out.writableBytes()); + result = ((ConscryptAlpnSslEngine) handler.engine).unwrap( + in.nioBuffers(in.readerIndex(), len), + handler.singleBuffer); + } finally { + handler.singleBuffer[0] = null; + } + } else { + result = handler.engine.unwrap(toByteBuffer(in, in.readerIndex(), len), + toByteBuffer(out, writerIndex, out.writableBytes())); + } + out.writerIndex(writerIndex + result.bytesProduced()); + return result; + } + + @Override + ByteBuf allocateWrapBuffer(SslHandler handler, ByteBufAllocator allocator, + int pendingBytes, int numComponents) { + return allocator.directBuffer( + ((ConscryptAlpnSslEngine) handler.engine).calculateOutNetBufSize(pendingBytes, numComponents)); + } + + @Override + int calculateRequiredOutBufSpace(SslHandler handler, int pendingBytes, int numComponents) { + return ((ConscryptAlpnSslEngine) handler.engine) + .calculateRequiredOutBufSpace(pendingBytes, numComponents); + } + + @Override + int calculatePendingData(SslHandler handler, int guess) { + return guess; + } + + @Override + boolean jdkCompatibilityMode(SSLEngine engine) { + return true; + } + }, + JDK(false, MERGE_CUMULATOR) { + @Override + SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException { + int writerIndex = out.writerIndex(); + ByteBuffer inNioBuffer = toByteBuffer(in, in.readerIndex(), len); + int position = 
inNioBuffer.position(); + final SSLEngineResult result = handler.engine.unwrap(inNioBuffer, + toByteBuffer(out, writerIndex, out.writableBytes())); + out.writerIndex(writerIndex + result.bytesProduced()); + + // This is a workaround for a bug in Android 5.0. Android 5.0 does not correctly update the + // SSLEngineResult.bytesConsumed() in some cases and just return 0. + // + // See: + // - https://android-review.googlesource.com/c/platform/external/conscrypt/+/122080 + // - https://github.com/netty/netty/issues/7758 + if (result.bytesConsumed() == 0) { + int consumed = inNioBuffer.position() - position; + if (consumed != result.bytesConsumed()) { + // Create a new SSLEngineResult with the correct bytesConsumed(). + return new SSLEngineResult( + result.getStatus(), result.getHandshakeStatus(), consumed, result.bytesProduced()); + } + } + return result; + } + + @Override + ByteBuf allocateWrapBuffer(SslHandler handler, ByteBufAllocator allocator, + int pendingBytes, int numComponents) { + // For JDK we don't have a good source for the max wrap overhead. We need at least one packet buffer + // size, but may be able to fit more in based on the total requested. + return allocator.heapBuffer(Math.max(pendingBytes, handler.engine.getSession().getPacketBufferSize())); + } + + @Override + int calculateRequiredOutBufSpace(SslHandler handler, int pendingBytes, int numComponents) { + // As for the JDK SSLEngine we always need to operate on buffer space required by the SSLEngine + // (normally ~16KB). This is required even if the amount of data to encrypt is very small. Use heap + // buffers to reduce the native memory usage. + // + // Beside this the JDK SSLEngine also (as of today) will do an extra heap to direct buffer copy + // if a direct buffer is used as its internals operate on byte[]. 
+ return handler.engine.getSession().getPacketBufferSize(); + } + + @Override + int calculatePendingData(SslHandler handler, int guess) { + return guess; + } + + @Override + boolean jdkCompatibilityMode(SSLEngine engine) { + return true; + } + }; + + static SslEngineType forEngine(SSLEngine engine) { + return engine instanceof ReferenceCountedOpenSslEngine ? TCNATIVE : + engine instanceof ConscryptAlpnSslEngine ? CONSCRYPT : JDK; + } + + SslEngineType(boolean wantsDirectBuffer, Cumulator cumulator) { + this.wantsDirectBuffer = wantsDirectBuffer; + this.cumulator = cumulator; + } + + abstract SSLEngineResult unwrap(SslHandler handler, ByteBuf in, int len, ByteBuf out) throws SSLException; + + abstract int calculatePendingData(SslHandler handler, int guess); + + abstract boolean jdkCompatibilityMode(SSLEngine engine); + + abstract ByteBuf allocateWrapBuffer(SslHandler handler, ByteBufAllocator allocator, + int pendingBytes, int numComponents); + + abstract int calculateRequiredOutBufSpace(SslHandler handler, int pendingBytes, int numComponents); + + // BEGIN Platform-dependent flags + + /** + * {@code true} if and only if {@link SSLEngine} expects a direct buffer and so if a heap buffer + * is given will make an extra memory copy. + */ + final boolean wantsDirectBuffer; + + // END Platform-dependent flags + + /** + * When using JDK {@link SSLEngine}, we use {@link #MERGE_CUMULATOR} because it works only with + * one {@link ByteBuffer}. + * + * When using {@link OpenSslEngine}, we can use {@link #COMPOSITE_CUMULATOR} because it has + * {@link OpenSslEngine#unwrap(ByteBuffer[], ByteBuffer[])} which works with multiple {@link ByteBuffer}s + * and which does not need to do extra memory copies. 
+ */ + final Cumulator cumulator; + } + + private volatile ChannelHandlerContext ctx; + private final SSLEngine engine; + private final SslEngineType engineType; + private final Executor delegatedTaskExecutor; + private final boolean jdkCompatibilityMode; + + /** + * Used if {@link SSLEngine#wrap(ByteBuffer[], ByteBuffer)} and {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer[])} + * should be called with a {@link ByteBuf} that is only backed by one {@link ByteBuffer} to reduce the object + * creation. + */ + private final ByteBuffer[] singleBuffer = new ByteBuffer[1]; + + private final boolean startTls; + + private final SslTasksRunner sslTaskRunnerForUnwrap = new SslTasksRunner(true); + private final SslTasksRunner sslTaskRunner = new SslTasksRunner(false); + + private SslHandlerCoalescingBufferQueue pendingUnencryptedWrites; + private Promise handshakePromise = new LazyChannelPromise(); + private final LazyChannelPromise sslClosePromise = new LazyChannelPromise(); + + private int packetLength; + private short state; + + private volatile long handshakeTimeoutMillis = 10000; + private volatile long closeNotifyFlushTimeoutMillis = 3000; + private volatile long closeNotifyReadTimeoutMillis; + volatile int wrapDataSize = MAX_PLAINTEXT_LENGTH; + + /** + * Creates a new instance which runs all delegated tasks directly on the {@link EventExecutor}. + * + * @param engine the {@link SSLEngine} this handler will use + */ + public SslHandler(SSLEngine engine) { + this(engine, false); + } + + /** + * Creates a new instance which runs all delegated tasks directly on the {@link EventExecutor}. + * + * @param engine the {@link SSLEngine} this handler will use + * @param startTls {@code true} if the first write request shouldn't be + * encrypted by the {@link SSLEngine} + */ + public SslHandler(SSLEngine engine, boolean startTls) { + this(engine, startTls, ImmediateExecutor.INSTANCE); + } + + /** + * Creates a new instance. 
+ * + * @param engine the {@link SSLEngine} this handler will use + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. + */ + public SslHandler(SSLEngine engine, Executor delegatedTaskExecutor) { + this(engine, false, delegatedTaskExecutor); + } + + /** + * Creates a new instance. + * + * @param engine the {@link SSLEngine} this handler will use + * @param startTls {@code true} if the first write request shouldn't be + * encrypted by the {@link SSLEngine} + * @param delegatedTaskExecutor the {@link Executor} that will be used to execute tasks that are returned by + * {@link SSLEngine#getDelegatedTask()}. + */ + public SslHandler(SSLEngine engine, boolean startTls, Executor delegatedTaskExecutor) { + this.engine = ObjectUtil.checkNotNull(engine, "engine"); + this.delegatedTaskExecutor = ObjectUtil.checkNotNull(delegatedTaskExecutor, "delegatedTaskExecutor"); + engineType = SslEngineType.forEngine(engine); + this.startTls = startTls; + this.jdkCompatibilityMode = engineType.jdkCompatibilityMode(engine); + setCumulator(engineType.cumulator); + } + + public long getHandshakeTimeoutMillis() { + return handshakeTimeoutMillis; + } + + public void setHandshakeTimeout(long handshakeTimeout, TimeUnit unit) { + checkNotNull(unit, "unit"); + setHandshakeTimeoutMillis(unit.toMillis(handshakeTimeout)); + } + + public void setHandshakeTimeoutMillis(long handshakeTimeoutMillis) { + this.handshakeTimeoutMillis = checkPositiveOrZero(handshakeTimeoutMillis, "handshakeTimeoutMillis"); + } + + /** + * Sets the number of bytes to pass to each {@link SSLEngine#wrap(ByteBuffer[], int, int, ByteBuffer)} call. + *

+ * <p>
+ * This value will partition data which is passed to write
+ * {@link #write(ChannelHandlerContext, Object, ChannelPromise)}. The partitioning will work as follows:
+ * <ul>
+ * <li>If {@code wrapDataSize <= 0} then we will write each data chunk as is.</li>
+ * <li>If {@code wrapDataSize > data size} then we will attempt to aggregate multiple data chunks together.</li>
+ * <li>Else if {@code wrapDataSize <= data size} then we will divide the data
+ * into chunks of {@code wrapDataSize} when writing.</li>
+ * </ul>
+ * <p>

+ * If the {@link SSLEngine} doesn't support a gather wrap operation (e.g. {@link SslProvider#OPENSSL}) then + * aggregating data before wrapping can help reduce the ratio between TLS overhead vs data payload which will lead + * to better goodput. Writing fixed chunks of data can also help target the underlying transport's (e.g. TCP) + * frame size. Under lossy/congested network conditions this may help the peer get full TLS packets earlier and + * be able to do work sooner, as opposed to waiting for the all the pieces of the TLS packet to arrive. + * @param wrapDataSize the number of bytes which will be passed to each + * {@link SSLEngine#wrap(ByteBuffer[], int, int, ByteBuffer)} call. + */ + @UnstableApi + public final void setWrapDataSize(int wrapDataSize) { + this.wrapDataSize = wrapDataSize; + } + + /** + * @deprecated use {@link #getCloseNotifyFlushTimeoutMillis()} + */ + @Deprecated + public long getCloseNotifyTimeoutMillis() { + return getCloseNotifyFlushTimeoutMillis(); + } + + /** + * @deprecated use {@link #setCloseNotifyFlushTimeout(long, TimeUnit)} + */ + @Deprecated + public void setCloseNotifyTimeout(long closeNotifyTimeout, TimeUnit unit) { + setCloseNotifyFlushTimeout(closeNotifyTimeout, unit); + } + + /** + * @deprecated use {@link #setCloseNotifyFlushTimeoutMillis(long)} + */ + @Deprecated + public void setCloseNotifyTimeoutMillis(long closeNotifyFlushTimeoutMillis) { + setCloseNotifyFlushTimeoutMillis(closeNotifyFlushTimeoutMillis); + } + + /** + * Gets the timeout for flushing the close_notify that was triggered by closing the + * {@link Channel}. If the close_notify was not flushed in the given timeout the {@link Channel} will be closed + * forcibly. + */ + public final long getCloseNotifyFlushTimeoutMillis() { + return closeNotifyFlushTimeoutMillis; + } + + /** + * Sets the timeout for flushing the close_notify that was triggered by closing the + * {@link Channel}. 
If the close_notify was not flushed in the given timeout the {@link Channel} will be closed + * forcibly. + */ + public final void setCloseNotifyFlushTimeout(long closeNotifyFlushTimeout, TimeUnit unit) { + setCloseNotifyFlushTimeoutMillis(unit.toMillis(closeNotifyFlushTimeout)); + } + + /** + * See {@link #setCloseNotifyFlushTimeout(long, TimeUnit)}. + */ + public final void setCloseNotifyFlushTimeoutMillis(long closeNotifyFlushTimeoutMillis) { + this.closeNotifyFlushTimeoutMillis = checkPositiveOrZero(closeNotifyFlushTimeoutMillis, + "closeNotifyFlushTimeoutMillis"); + } + + /** + * Gets the timeout (in ms) for receiving the response for the close_notify that was triggered by closing the + * {@link Channel}. This timeout starts after the close_notify message was successfully written to the + * remote peer. Use {@code 0} to directly close the {@link Channel} and not wait for the response. + */ + public final long getCloseNotifyReadTimeoutMillis() { + return closeNotifyReadTimeoutMillis; + } + + /** + * Sets the timeout for receiving the response for the close_notify that was triggered by closing the + * {@link Channel}. This timeout starts after the close_notify message was successfully written to the + * remote peer. Use {@code 0} to directly close the {@link Channel} and not wait for the response. + */ + public final void setCloseNotifyReadTimeout(long closeNotifyReadTimeout, TimeUnit unit) { + setCloseNotifyReadTimeoutMillis(unit.toMillis(closeNotifyReadTimeout)); + } + + /** + * See {@link #setCloseNotifyReadTimeout(long, TimeUnit)}. + */ + public final void setCloseNotifyReadTimeoutMillis(long closeNotifyReadTimeoutMillis) { + this.closeNotifyReadTimeoutMillis = checkPositiveOrZero(closeNotifyReadTimeoutMillis, + "closeNotifyReadTimeoutMillis"); + } + + /** + * Returns the {@link SSLEngine} which is used by this handler. + */ + public SSLEngine engine() { + return engine; + } + + /** + * Returns the name of the current application-level protocol. 
+ * + * @return the protocol name or {@code null} if application-level protocol has not been negotiated + */ + public String applicationProtocol() { + SSLEngine engine = engine(); + if (!(engine instanceof ApplicationProtocolAccessor)) { + return null; + } + + return ((ApplicationProtocolAccessor) engine).getNegotiatedApplicationProtocol(); + } + + /** + * Returns a {@link Future} that will get notified once the current TLS handshake completes. + * + * @return the {@link Future} for the initial TLS handshake if {@link #renegotiate()} was not invoked. + * The {@link Future} for the most recent {@linkplain #renegotiate() TLS renegotiation} otherwise. + */ + public Future handshakeFuture() { + return handshakePromise; + } + + /** + * Use {@link #closeOutbound()} + */ + @Deprecated + public ChannelFuture close() { + return closeOutbound(); + } + + /** + * Use {@link #closeOutbound(ChannelPromise)} + */ + @Deprecated + public ChannelFuture close(ChannelPromise promise) { + return closeOutbound(promise); + } + + /** + * Sends an SSL {@code close_notify} message to the specified channel and + * destroys the underlying {@link SSLEngine}. This will not close the underlying + * {@link Channel}. If you want to also close the {@link Channel} use {@link Channel#close()} or + * {@link ChannelHandlerContext#close()} + */ + public ChannelFuture closeOutbound() { + return closeOutbound(ctx.newPromise()); + } + + /** + * Sends an SSL {@code close_notify} message to the specified channel and + * destroys the underlying {@link SSLEngine}. This will not close the underlying + * {@link Channel}. 
If you want to also close the {@link Channel} use {@link Channel#close()} or + * {@link ChannelHandlerContext#close()} + */ + public ChannelFuture closeOutbound(final ChannelPromise promise) { + final ChannelHandlerContext ctx = this.ctx; + if (ctx.executor().inEventLoop()) { + closeOutbound0(promise); + } else { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + closeOutbound0(promise); + } + }); + } + return promise; + } + + private void closeOutbound0(ChannelPromise promise) { + setState(STATE_OUTBOUND_CLOSED); + engine.closeOutbound(); + try { + flush(ctx, promise); + } catch (Exception e) { + if (!promise.tryFailure(e)) { + logger.warn("{} flush() raised a masked exception.", ctx.channel(), e); + } + } + } + + /** + * Return the {@link Future} that will get notified if the inbound of the {@link SSLEngine} is closed. + * + * This method will return the same {@link Future} all the time. + * + * @see SSLEngine + */ + public Future sslCloseFuture() { + return sslClosePromise; + } + + @Override + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + try { + if (pendingUnencryptedWrites != null && !pendingUnencryptedWrites.isEmpty()) { + // Check if queue is not empty first because create a new ChannelException is expensive + pendingUnencryptedWrites.releaseAndFailAll(ctx, + new ChannelException("Pending write on removal of SslHandler")); + } + pendingUnencryptedWrites = null; + + SSLException cause = null; + + // If the handshake or SSLEngine closure is not done yet we should fail corresponding promise and + // notify the rest of the + // pipeline. 
+ if (!handshakePromise.isDone()) { + cause = new SSLHandshakeException("SslHandler removed before handshake completed"); + if (handshakePromise.tryFailure(cause)) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + } + } + if (!sslClosePromise.isDone()) { + if (cause == null) { + cause = new SSLException("SslHandler removed before SSLEngine was closed"); + } + notifyClosePromise(cause); + } + } finally { + ReferenceCountUtil.release(engine); + } + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { + ctx.bind(localAddress, promise); + } + + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, SocketAddress localAddress, + ChannelPromise promise) throws Exception { + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + ctx.deregister(promise); + } + + @Override + public void disconnect(final ChannelHandlerContext ctx, + final ChannelPromise promise) throws Exception { + closeOutboundAndChannel(ctx, promise, true); + } + + @Override + public void close(final ChannelHandlerContext ctx, + final ChannelPromise promise) throws Exception { + closeOutboundAndChannel(ctx, promise, false); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + if (!handshakePromise.isDone()) { + setState(STATE_READ_DURING_HANDSHAKE); + } + + ctx.read(); + } + + private static IllegalStateException newPendingWritesNullException() { + return new IllegalStateException("pendingUnencryptedWrites is null, handlerRemoved0 called?"); + } + + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (!(msg instanceof ByteBuf)) { + UnsupportedMessageTypeException exception = new UnsupportedMessageTypeException(msg, ByteBuf.class); + 
ReferenceCountUtil.safeRelease(msg); + promise.setFailure(exception); + } else if (pendingUnencryptedWrites == null) { + ReferenceCountUtil.safeRelease(msg); + promise.setFailure(newPendingWritesNullException()); + } else { + pendingUnencryptedWrites.add((ByteBuf) msg, promise); + } + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + // Do not encrypt the first write request if this handler is + // created with startTLS flag turned on. + if (startTls && !isStateSet(STATE_SENT_FIRST_MESSAGE)) { + setState(STATE_SENT_FIRST_MESSAGE); + pendingUnencryptedWrites.writeAndRemoveAll(ctx); + forceFlush(ctx); + // Explicit start handshake processing once we send the first message. This will also ensure + // we will schedule the timeout if needed. + startHandshakeProcessing(true); + return; + } + + if (isStateSet(STATE_PROCESS_TASK)) { + return; + } + + try { + wrapAndFlush(ctx); + } catch (Throwable cause) { + setHandshakeFailure(ctx, cause); + PlatformDependent.throwException(cause); + } + } + + private void wrapAndFlush(ChannelHandlerContext ctx) throws SSLException { + if (pendingUnencryptedWrites.isEmpty()) { + // It's important to NOT use a voidPromise here as the user + // may want to add a ChannelFutureListener to the ChannelPromise later. + // + // See https://github.com/netty/netty/issues/3364 + pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, ctx.newPromise()); + } + if (!handshakePromise.isDone()) { + setState(STATE_FLUSHED_BEFORE_HANDSHAKE); + } + try { + wrap(ctx, false); + } finally { + // We may have written some parts of data before an exception was thrown so ensure we always flush. + // See https://github.com/netty/netty/issues/3900#issuecomment-172481830 + forceFlush(ctx); + } + } + + // This method will not call setHandshakeFailure(...) ! 
+ private void wrap(ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { + ByteBuf out = null; + ByteBufAllocator alloc = ctx.alloc(); + try { + final int wrapDataSize = this.wrapDataSize; + // Only continue to loop if the handler was not removed in the meantime. + // See https://github.com/netty/netty/issues/5860 + outer: while (!ctx.isRemoved()) { + ChannelPromise promise = ctx.newPromise(); + ByteBuf buf = wrapDataSize > 0 ? + pendingUnencryptedWrites.remove(alloc, wrapDataSize, promise) : + pendingUnencryptedWrites.removeFirst(promise); + if (buf == null) { + break; + } + + SSLEngineResult result; + + if (buf.readableBytes() > MAX_PLAINTEXT_LENGTH) { + // If we pulled a buffer larger than the supported packet size, we can slice it up and iteratively, + // encrypting multiple packets into a single larger buffer. This substantially saves on allocations + // for large responses. Here we estimate how large of a buffer we need. If we overestimate a bit, + // that's fine. If we underestimate, we'll simply re-enqueue the remaining buffer and get it on the + // next outer loop. + int readableBytes = buf.readableBytes(); + int numPackets = readableBytes / MAX_PLAINTEXT_LENGTH; + if (readableBytes % MAX_PLAINTEXT_LENGTH != 0) { + numPackets += 1; + } + + if (out == null) { + out = allocateOutNetBuf(ctx, readableBytes, buf.nioBufferCount() + numPackets); + } + result = wrapMultiple(alloc, engine, buf, out); + } else { + if (out == null) { + out = allocateOutNetBuf(ctx, buf.readableBytes(), buf.nioBufferCount()); + } + result = wrap(alloc, engine, buf, out); + } + + if (buf.isReadable()) { + pendingUnencryptedWrites.addFirst(buf, promise); + // When we add the buffer/promise pair back we need to be sure we don't complete the promise + // later. We only complete the promise if the buffer is completely consumed. 
+ promise = null; + } else { + buf.release(); + } + + // We need to write any data before we invoke any methods which may trigger re-entry, otherwise + // writes may occur out of order and TLS sequencing may be off (e.g. SSLV3_ALERT_BAD_RECORD_MAC). + if (out.isReadable()) { + final ByteBuf b = out; + out = null; + if (promise != null) { + ctx.write(b, promise); + } else { + ctx.write(b); + } + } else if (promise != null) { + ctx.write(Unpooled.EMPTY_BUFFER, promise); + } + // else out is not readable we can re-use it and so save an extra allocation + + if (result.getStatus() == Status.CLOSED) { + // First check if there is any write left that needs to be failed, if there is none we don't need + // to create a new exception or obtain an existing one. + if (!pendingUnencryptedWrites.isEmpty()) { + // Make a best effort to preserve any exception that way previously encountered from the + // handshake or the transport, else fallback to a general error. + Throwable exception = handshakePromise.cause(); + if (exception == null) { + exception = sslClosePromise.cause(); + if (exception == null) { + exception = new SslClosedEngineException("SSLEngine closed already"); + } + } + pendingUnencryptedWrites.releaseAndFailAll(ctx, exception); + } + + return; + } else { + switch (result.getHandshakeStatus()) { + case NEED_TASK: + if (!runDelegatedTasks(inUnwrap)) { + // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will + // resume once the task completes. + break outer; + } + break; + case FINISHED: + case NOT_HANDSHAKING: // work around for android bug that skips the FINISHED state. + setHandshakeSuccess(); + break; + case NEED_WRAP: + // If we are expected to wrap again and we produced some data we need to ensure there + // is something in the queue to process as otherwise we will not try again before there + // was more added. Failing to do so may fail to produce an alert that can be + // consumed by the remote peer. 
+ if (result.bytesProduced() > 0 && pendingUnencryptedWrites.isEmpty()) { + pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER); + } + break; + case NEED_UNWRAP: + // The underlying engine is starving so we need to feed it with more data. + // See https://github.com/netty/netty/pull/5039 + readIfNeeded(ctx); + return; + default: + throw new IllegalStateException( + "Unknown handshake status: " + result.getHandshakeStatus()); + } + } + } + } finally { + if (out != null) { + out.release(); + } + if (inUnwrap) { + setState(STATE_NEEDS_FLUSH); + } + } + } + + /** + * This method will not call + * {@link #setHandshakeFailure(ChannelHandlerContext, Throwable, boolean, boolean, boolean)} or + * {@link #setHandshakeFailure(ChannelHandlerContext, Throwable)}. + * @return {@code true} if this method ends on {@link SSLEngineResult.HandshakeStatus#NOT_HANDSHAKING}. + */ + private boolean wrapNonAppData(final ChannelHandlerContext ctx, boolean inUnwrap) throws SSLException { + ByteBuf out = null; + ByteBufAllocator alloc = ctx.alloc(); + try { + // Only continue to loop if the handler was not removed in the meantime. + // See https://github.com/netty/netty/issues/5860 + outer: while (!ctx.isRemoved()) { + if (out == null) { + // As this is called for the handshake we have no real idea how big the buffer needs to be. + // That said 2048 should give us enough room to include everything like ALPN / NPN data. + // If this is not enough we will increase the buffer in wrap(...). 
+ out = allocateOutNetBuf(ctx, 2048, 1); + } + SSLEngineResult result = wrap(alloc, engine, Unpooled.EMPTY_BUFFER, out); + if (result.bytesProduced() > 0) { + ctx.write(out).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + Throwable cause = future.cause(); + if (cause != null) { + setHandshakeFailureTransportFailure(ctx, cause); + } + } + }); + if (inUnwrap) { + setState(STATE_NEEDS_FLUSH); + } + out = null; + } + + HandshakeStatus status = result.getHandshakeStatus(); + switch (status) { + case FINISHED: + // We may be here because we read data and discovered the remote peer initiated a renegotiation + // and this write is to complete the new handshake. The user may have previously done a + // writeAndFlush which wasn't able to wrap data due to needing the pending handshake, so we + // attempt to wrap application data here if any is pending. + if (setHandshakeSuccess() && inUnwrap && !pendingUnencryptedWrites.isEmpty()) { + wrap(ctx, true); + } + return false; + case NEED_TASK: + if (!runDelegatedTasks(inUnwrap)) { + // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will + // resume once the task completes. + break outer; + } + break; + case NEED_UNWRAP: + if (inUnwrap || unwrapNonAppData(ctx) <= 0) { + // If we asked for a wrap, the engine requested an unwrap, and we are in unwrap there is + // no use in trying to call wrap again because we have already attempted (or will after we + // return) to feed more data to the engine. 
+ return false; + } + break; + case NEED_WRAP: + break; + case NOT_HANDSHAKING: + if (setHandshakeSuccess() && inUnwrap && !pendingUnencryptedWrites.isEmpty()) { + wrap(ctx, true); + } + // Workaround for TLS False Start problem reported at: + // https://github.com/netty/netty/issues/1108#issuecomment-14266970 + if (!inUnwrap) { + unwrapNonAppData(ctx); + } + return true; + default: + throw new IllegalStateException("Unknown handshake status: " + result.getHandshakeStatus()); + } + + // Check if did not produce any bytes and if so break out of the loop, but only if we did not process + // a task as last action. It's fine to not produce any data as part of executing a task. + if (result.bytesProduced() == 0 && status != HandshakeStatus.NEED_TASK) { + break; + } + + // It should not consume empty buffers when it is not handshaking + // Fix for Android, where it was encrypting empty buffers even when not handshaking + if (result.bytesConsumed() == 0 && result.getHandshakeStatus() == HandshakeStatus.NOT_HANDSHAKING) { + break; + } + } + } finally { + if (out != null) { + out.release(); + } + } + return false; + } + + private SSLEngineResult wrapMultiple(ByteBufAllocator alloc, SSLEngine engine, ByteBuf in, ByteBuf out) + throws SSLException { + SSLEngineResult result = null; + + do { + int nextSliceSize = Math.min(MAX_PLAINTEXT_LENGTH, in.readableBytes()); + // This call over-estimates, because we are slicing and not every nioBuffer will be part of + // every slice. We could improve the estimate by having an nioBufferCount(offset, length). + int nextOutSize = engineType.calculateRequiredOutBufSpace(this, nextSliceSize, in.nioBufferCount()); + + if (!out.isWritable(nextOutSize)) { + if (result != null) { + // We underestimated the space needed to encrypt the entire in buf. Break out, and + // upstream will re-enqueue the buffer for later. + break; + } + // This shouldn't happen, as the out buf was properly sized for at least packetLength + // prior to calling wrap. 
+ out.ensureWritable(nextOutSize); + } + + ByteBuf wrapBuf = in.readSlice(nextSliceSize); + result = wrap(alloc, engine, wrapBuf, out); + + if (result.getStatus() == Status.CLOSED) { + // If the engine gets closed, we can exit out early. Otherwise, we'll do a full handling of + // possible results once finished. + break; + } + + if (wrapBuf.isReadable()) { + // There may be some left-over, in which case we can just pick it up next loop, so reset the original + // reader index so its included again in the next slice. + in.readerIndex(in.readerIndex() - wrapBuf.readableBytes()); + } + } while (in.readableBytes() > 0); + + return result; + } + + private SSLEngineResult wrap(ByteBufAllocator alloc, SSLEngine engine, ByteBuf in, ByteBuf out) + throws SSLException { + ByteBuf newDirectIn = null; + try { + int readerIndex = in.readerIndex(); + int readableBytes = in.readableBytes(); + + // We will call SslEngine.wrap(ByteBuffer[], ByteBuffer) to allow efficient handling of + // CompositeByteBuf without force an extra memory copy when CompositeByteBuffer.nioBuffer() is called. + final ByteBuffer[] in0; + if (in.isDirect() || !engineType.wantsDirectBuffer) { + // As CompositeByteBuf.nioBufferCount() can be expensive (as it needs to check all composed ByteBuf + // to calculate the count) we will just assume a CompositeByteBuf contains more then 1 ByteBuf. + // The worst that can happen is that we allocate an extra ByteBuffer[] in CompositeByteBuf.nioBuffers() + // which is better then walking the composed ByteBuf in most cases. + if (!(in instanceof CompositeByteBuf) && in.nioBufferCount() == 1) { + in0 = singleBuffer; + // We know its only backed by 1 ByteBuffer so use internalNioBuffer to keep object allocation + // to a minimum. 
+ in0[0] = in.internalNioBuffer(readerIndex, readableBytes); + } else { + in0 = in.nioBuffers(); + } + } else { + // We could even go further here and check if its a CompositeByteBuf and if so try to decompose it and + // only replace the ByteBuffer that are not direct. At the moment we just will replace the whole + // CompositeByteBuf to keep the complexity to a minimum + newDirectIn = alloc.directBuffer(readableBytes); + newDirectIn.writeBytes(in, readerIndex, readableBytes); + in0 = singleBuffer; + in0[0] = newDirectIn.internalNioBuffer(newDirectIn.readerIndex(), readableBytes); + } + + for (;;) { + // Use toByteBuffer(...) which might be able to return the internal ByteBuffer and so reduce + // allocations. + ByteBuffer out0 = toByteBuffer(out, out.writerIndex(), out.writableBytes()); + SSLEngineResult result = engine.wrap(in0, out0); + in.skipBytes(result.bytesConsumed()); + out.writerIndex(out.writerIndex() + result.bytesProduced()); + + if (result.getStatus() == Status.BUFFER_OVERFLOW) { + out.ensureWritable(engine.getSession().getPacketBufferSize()); + } else { + return result; + } + } + } finally { + // Null out to allow GC of ByteBuffer + singleBuffer[0] = null; + + if (newDirectIn != null) { + newDirectIn.release(); + } + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + boolean handshakeFailed = handshakePromise.cause() != null; + + // Channel closed, we will generate 'ClosedChannelException' now. + ClosedChannelException exception = new ClosedChannelException(); + + // Add a supressed exception if the handshake was not completed yet. 
+ if (isStateSet(STATE_HANDSHAKE_STARTED) && !handshakePromise.isDone()) { + ThrowableUtil.addSuppressed(exception, StacklessSSLHandshakeException.newInstance( + "Connection closed while SSL/TLS handshake was in progress", + SslHandler.class, "channelInactive")); + } + + // Make sure to release SSLEngine, + // and notify the handshake future if the connection has been closed during handshake. + setHandshakeFailure(ctx, exception, !isStateSet(STATE_OUTBOUND_CLOSED), isStateSet(STATE_HANDSHAKE_STARTED), + false); + + // Ensure we always notify the sslClosePromise as well + notifyClosePromise(exception); + + try { + super.channelInactive(ctx); + } catch (DecoderException e) { + if (!handshakeFailed || !(e.getCause() instanceof SSLException)) { + // We only rethrow the exception if the handshake did not fail before channelInactive(...) was called + // as otherwise this may produce duplicated failures as super.channelInactive(...) will also call + // channelRead(...). + // + // See https://github.com/netty/netty/issues/10119 + throw e; + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (ignoreException(cause)) { + // It is safe to ignore the 'connection reset by peer' or + // 'broken pipe' error after sending close_notify. + if (logger.isDebugEnabled()) { + logger.debug( + "{} Swallowing a harmless 'connection reset by peer / broken pipe' error that occurred " + + "while writing close_notify in response to the peer's close_notify", ctx.channel(), cause); + } + + // Close the connection explicitly just in case the transport + // did not close the connection automatically. + if (ctx.channel().isActive()) { + ctx.close(); + } + } else { + ctx.fireExceptionCaught(cause); + } + } + + /** + * Checks if the given {@link Throwable} can be ignore and just "swallowed" + * + * When an ssl connection is closed a close_notify message is sent. 
+ * After that the peer also sends close_notify however, it's not mandatory to receive + * the close_notify. The party who sent the initial close_notify can close the connection immediately + * then the peer will get connection reset error. + * + */ + private boolean ignoreException(Throwable t) { + if (!(t instanceof SSLException) && t instanceof IOException && sslClosePromise.isDone()) { + String message = t.getMessage(); + + // first try to match connection reset / broken pipe based on the regex. This is the fastest way + // but may fail on different jdk impls or OS's + if (message != null && IGNORABLE_ERROR_MESSAGE.matcher(message).matches()) { + return true; + } + + // Inspect the StackTraceElements to see if it was a connection reset / broken pipe or not + StackTraceElement[] elements = t.getStackTrace(); + for (StackTraceElement element: elements) { + String classname = element.getClassName(); + String methodname = element.getMethodName(); + + // skip all classes that belong to the io.netty package + if (classname.startsWith("io.netty.")) { + continue; + } + + // check if the method name is read if not skip it + if (!"read".equals(methodname)) { + continue; + } + + // This will also match against SocketInputStream which is used by openjdk 7 and maybe + // also others + if (IGNORABLE_CLASS_IN_STACK.matcher(classname).matches()) { + return true; + } + + try { + // No match by now.. Try to load the class via classloader and inspect it. + // This is mainly done as other JDK implementations may differ in name of + // the impl. + Class clazz = PlatformDependent.getClassLoader(getClass()).loadClass(classname); + + if (SocketChannel.class.isAssignableFrom(clazz) + || DatagramChannel.class.isAssignableFrom(clazz)) { + return true; + } + + // also match against SctpChannel via String matching as it may not be present. 
+ if (PlatformDependent.javaVersion() >= 7 + && "com.sun.nio.sctp.SctpChannel".equals(clazz.getSuperclass().getName())) { + return true; + } + } catch (Throwable cause) { + if (logger.isDebugEnabled()) { + logger.debug("Unexpected exception while loading class {} classname {}", + getClass(), classname, cause); + } + } + } + } + + return false; + } + + /** + * Returns {@code true} if the given {@link ByteBuf} is encrypted. Be aware that this method + * will not increase the readerIndex of the given {@link ByteBuf}. + * + * @param buffer + * The {@link ByteBuf} to read from. Be aware that it must have at least 5 bytes to read, + * otherwise it will throw an {@link IllegalArgumentException}. + * @return encrypted + * {@code true} if the {@link ByteBuf} is encrypted, {@code false} otherwise. + * @throws IllegalArgumentException + * Is thrown if the given {@link ByteBuf} has not at least 5 bytes to read. + */ + public static boolean isEncrypted(ByteBuf buffer) { + if (buffer.readableBytes() < SslUtils.SSL_RECORD_HEADER_LENGTH) { + throw new IllegalArgumentException( + "buffer must have at least " + SslUtils.SSL_RECORD_HEADER_LENGTH + " readable bytes"); + } + return getEncryptedPacketLength(buffer, buffer.readerIndex()) != SslUtils.NOT_ENCRYPTED; + } + + private void decodeJdkCompatible(ChannelHandlerContext ctx, ByteBuf in) throws NotSslRecordException { + int packetLength = this.packetLength; + // If we calculated the length of the current SSL record before, use that information. + if (packetLength > 0) { + if (in.readableBytes() < packetLength) { + return; + } + } else { + // Get the packet length and wait until we get a packets worth of data to unwrap. 
+ final int readableBytes = in.readableBytes(); + if (readableBytes < SslUtils.SSL_RECORD_HEADER_LENGTH) { + return; + } + packetLength = getEncryptedPacketLength(in, in.readerIndex()); + if (packetLength == SslUtils.NOT_ENCRYPTED) { + // Not an SSL/TLS packet + NotSslRecordException e = new NotSslRecordException( + "not an SSL/TLS record: " + ByteBufUtil.hexDump(in)); + in.skipBytes(in.readableBytes()); + + // First fail the handshake promise as we may need to have access to the SSLEngine which may + // be released because the user will remove the SslHandler in an exceptionCaught(...) implementation. + setHandshakeFailure(ctx, e); + + throw e; + } + if (packetLength == NOT_ENOUGH_DATA) { + return; + } + assert packetLength > 0; + if (packetLength > readableBytes) { + // wait until the whole packet can be read + this.packetLength = packetLength; + return; + } + } + + // Reset the state of this class so we can get the length of the next packet. We assume the entire packet will + // be consumed by the SSLEngine. + this.packetLength = 0; + try { + final int bytesConsumed = unwrap(ctx, in, packetLength); + assert bytesConsumed == packetLength || engine.isInboundDone() : + "we feed the SSLEngine a packets worth of data: " + packetLength + " but it only consumed: " + + bytesConsumed; + } catch (Throwable cause) { + handleUnwrapThrowable(ctx, cause); + } + } + + private void decodeNonJdkCompatible(ChannelHandlerContext ctx, ByteBuf in) { + try { + unwrap(ctx, in, in.readableBytes()); + } catch (Throwable cause) { + handleUnwrapThrowable(ctx, cause); + } + } + + private void handleUnwrapThrowable(ChannelHandlerContext ctx, Throwable cause) { + try { + // We should attempt to notify the handshake failure before writing any pending data. 
If we are in unwrap + // and failed during the handshake process, and we attempt to wrap, then promises will fail, and if + // listeners immediately close the Channel then we may end up firing the handshake event after the Channel + // has been closed. + if (handshakePromise.tryFailure(cause)) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause)); + } + + // Let's check if the handler was removed in the meantime and so pendingUnencryptedWrites is null. + if (pendingUnencryptedWrites != null) { + // We need to flush one time as there may be an alert that we should send to the remote peer because + // of the SSLException reported here. + wrapAndFlush(ctx); + } + } catch (SSLException ex) { + logger.debug("SSLException during trying to call SSLEngine.wrap(...)" + + " because of an previous SSLException, ignoring...", ex); + } finally { + // ensure we always flush and close the channel. + setHandshakeFailure(ctx, cause, true, false, true); + } + PlatformDependent.throwException(cause); + } + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) throws SSLException { + if (isStateSet(STATE_PROCESS_TASK)) { + return; + } + if (jdkCompatibilityMode) { + decodeJdkCompatible(ctx, in); + } else { + decodeNonJdkCompatible(ctx, in); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + channelReadComplete0(ctx); + } + + private void channelReadComplete0(ChannelHandlerContext ctx) { + // Discard bytes of the cumulation buffer if needed. + discardSomeReadBytes(); + + flushIfNeeded(ctx); + readIfNeeded(ctx); + + clearState(STATE_FIRE_CHANNEL_READ); + ctx.fireChannelReadComplete(); + } + + private void readIfNeeded(ChannelHandlerContext ctx) { + // If handshake is not finished yet, we need more data. 
+ if (!ctx.channel().config().isAutoRead() && + (!isStateSet(STATE_FIRE_CHANNEL_READ) || !handshakePromise.isDone())) { + // No auto-read used and no message passed through the ChannelPipeline or the handshake was not complete + // yet, which means we need to trigger the read to ensure we not encounter any stalls. + ctx.read(); + } + } + + private void flushIfNeeded(ChannelHandlerContext ctx) { + if (isStateSet(STATE_NEEDS_FLUSH)) { + forceFlush(ctx); + } + } + + /** + * Calls {@link SSLEngine#unwrap(ByteBuffer, ByteBuffer)} with an empty buffer to handle handshakes, etc. + */ + private int unwrapNonAppData(ChannelHandlerContext ctx) throws SSLException { + return unwrap(ctx, Unpooled.EMPTY_BUFFER, 0); + } + + /** + * Unwraps inbound SSL records. + */ + private int unwrap(ChannelHandlerContext ctx, ByteBuf packet, int length) throws SSLException { + final int originalLength = length; + boolean wrapLater = false; + boolean notifyClosure = false; + boolean executedRead = false; + ByteBuf decodeOut = allocate(ctx, length); + try { + // Only continue to loop if the handler was not removed in the meantime. + // See https://github.com/netty/netty/issues/5860 + do { + final SSLEngineResult result = engineType.unwrap(this, packet, length, decodeOut); + final Status status = result.getStatus(); + final HandshakeStatus handshakeStatus = result.getHandshakeStatus(); + final int produced = result.bytesProduced(); + final int consumed = result.bytesConsumed(); + + // Skip bytes now in case unwrap is called in a re-entry scenario. For example LocalChannel.read() + // may entry this method in a re-entry fashion and if the peer is writing into a shared buffer we may + // unwrap the same data multiple times. + packet.skipBytes(consumed); + length -= consumed; + + // The expected sequence of events is: + // 1. Notify of handshake success + // 2. 
fireChannelRead for unwrapped data + if (handshakeStatus == HandshakeStatus.FINISHED || handshakeStatus == HandshakeStatus.NOT_HANDSHAKING) { + wrapLater |= (decodeOut.isReadable() ? + setHandshakeSuccessUnwrapMarkReentry() : setHandshakeSuccess()) || + handshakeStatus == HandshakeStatus.FINISHED; + } + + // Dispatch decoded data after we have notified of handshake success. If this method has been invoked + // in a re-entry fashion we execute a task on the executor queue to process after the stack unwinds + // to preserve order of events. + if (decodeOut.isReadable()) { + setState(STATE_FIRE_CHANNEL_READ); + if (isStateSet(STATE_UNWRAP_REENTRY)) { + executedRead = true; + executeChannelRead(ctx, decodeOut); + } else { + ctx.fireChannelRead(decodeOut); + } + decodeOut = null; + } + + if (status == Status.CLOSED) { + notifyClosure = true; // notify about the CLOSED state of the SSLEngine. See #137 + } else if (status == Status.BUFFER_OVERFLOW) { + if (decodeOut != null) { + decodeOut.release(); + } + final int applicationBufferSize = engine.getSession().getApplicationBufferSize(); + // Allocate a new buffer which can hold all the rest data and loop again. + // It may happen that applicationBufferSize < produced while there is still more to unwrap, in this + // case we will just allocate a new buffer with the capacity of applicationBufferSize and call + // unwrap again. + decodeOut = allocate(ctx, engineType.calculatePendingData(this, applicationBufferSize < produced ? + applicationBufferSize : applicationBufferSize - produced)); + continue; + } + + if (handshakeStatus == HandshakeStatus.NEED_TASK) { + boolean pending = runDelegatedTasks(true); + if (!pending) { + // We scheduled a task on the delegatingTaskExecutor, so stop processing as we will + // resume once the task completes. + // + // We break out of the loop only and do NOT return here as we still may need to notify + // about the closure of the SSLEngine. 
+ wrapLater = false; + break; + } + } else if (handshakeStatus == HandshakeStatus.NEED_WRAP) { + // If the wrap operation transitions the status to NOT_HANDSHAKING and there is no more data to + // unwrap then the next call to unwrap will not produce any data. We can avoid the potentially + // costly unwrap operation and break out of the loop. + if (wrapNonAppData(ctx, true) && length == 0) { + break; + } + } + + if (status == Status.BUFFER_UNDERFLOW || + // If we processed NEED_TASK we should try again even we did not consume or produce anything. + handshakeStatus != HandshakeStatus.NEED_TASK && (consumed == 0 && produced == 0 || + (length == 0 && handshakeStatus == HandshakeStatus.NOT_HANDSHAKING))) { + if (handshakeStatus == HandshakeStatus.NEED_UNWRAP) { + // The underlying engine is starving so we need to feed it with more data. + // See https://github.com/netty/netty/pull/5039 + readIfNeeded(ctx); + } + + break; + } else if (decodeOut == null) { + decodeOut = allocate(ctx, length); + } + } while (!ctx.isRemoved()); + + if (isStateSet(STATE_FLUSHED_BEFORE_HANDSHAKE) && handshakePromise.isDone()) { + // We need to call wrap(...) in case there was a flush done before the handshake completed to ensure + // we do not stale. + // + // See https://github.com/netty/netty/pull/2437 + clearState(STATE_FLUSHED_BEFORE_HANDSHAKE); + wrapLater = true; + } + + if (wrapLater) { + wrap(ctx, true); + } + } finally { + if (decodeOut != null) { + decodeOut.release(); + } + + if (notifyClosure) { + if (executedRead) { + executeNotifyClosePromise(ctx); + } else { + notifyClosePromise(null); + } + } + } + return originalLength - length; + } + + private boolean setHandshakeSuccessUnwrapMarkReentry() { + // setHandshakeSuccess calls out to external methods which may trigger re-entry. We need to preserve ordering of + // fireChannelRead for decodeOut relative to re-entry data. 
+ final boolean setReentryState = !isStateSet(STATE_UNWRAP_REENTRY); + if (setReentryState) { + setState(STATE_UNWRAP_REENTRY); + } + try { + return setHandshakeSuccess(); + } finally { + // It is unlikely this specific method will be re-entry because handshake completion is infrequent, but just + // in case we only clear the state if we set it in the first place. + if (setReentryState) { + clearState(STATE_UNWRAP_REENTRY); + } + } + } + + private void executeNotifyClosePromise(final ChannelHandlerContext ctx) { + try { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + notifyClosePromise(null); + } + }); + } catch (RejectedExecutionException e) { + notifyClosePromise(e); + } + } + + private void executeChannelRead(final ChannelHandlerContext ctx, final ByteBuf decodedOut) { + try { + ctx.executor().execute(new Runnable() { + @Override + public void run() { + ctx.fireChannelRead(decodedOut); + } + }); + } catch (RejectedExecutionException e) { + decodedOut.release(); + throw e; + } + } + + private static ByteBuffer toByteBuffer(ByteBuf out, int index, int len) { + return out.nioBufferCount() == 1 ? out.internalNioBuffer(index, len) : + out.nioBuffer(index, len); + } + + private static boolean inEventLoop(Executor executor) { + return executor instanceof EventExecutor && ((EventExecutor) executor).inEventLoop(); + } + + /** + * Will either run the delegated task directly calling {@link Runnable#run()} and return {@code true} or will + * offload the delegated task using {@link Executor#execute(Runnable)} and return {@code false}. + * + * If the task is offloaded it will take care to resume its work on the {@link EventExecutor} once there are no + * more tasks to process. + */ + private boolean runDelegatedTasks(boolean inUnwrap) { + if (delegatedTaskExecutor == ImmediateExecutor.INSTANCE || inEventLoop(delegatedTaskExecutor)) { + // We should run the task directly in the EventExecutor thread and not offload at all. 
As we are on the + // EventLoop we can just run all tasks at once. + for (;;) { + Runnable task = engine.getDelegatedTask(); + if (task == null) { + return true; + } + setState(STATE_PROCESS_TASK); + if (task instanceof AsyncRunnable) { + // Let's set the task to processing task before we try to execute it. + boolean pending = false; + try { + AsyncRunnable asyncTask = (AsyncRunnable) task; + AsyncTaskCompletionHandler completionHandler = new AsyncTaskCompletionHandler(inUnwrap); + asyncTask.run(completionHandler); + pending = completionHandler.resumeLater(); + if (pending) { + return false; + } + } finally { + if (!pending) { + // The task has completed, lets clear the state. If it is not completed we will clear the + // state once it is. + clearState(STATE_PROCESS_TASK); + } + } + } else { + try { + task.run(); + } finally { + clearState(STATE_PROCESS_TASK); + } + } + } + } else { + executeDelegatedTask(inUnwrap); + return false; + } + } + + private SslTasksRunner getTaskRunner(boolean inUnwrap) { + return inUnwrap ? 
sslTaskRunnerForUnwrap : sslTaskRunner; + } + + private void executeDelegatedTask(boolean inUnwrap) { + executeDelegatedTask(getTaskRunner(inUnwrap)); + } + + private void executeDelegatedTask(SslTasksRunner task) { + setState(STATE_PROCESS_TASK); + try { + delegatedTaskExecutor.execute(task); + } catch (RejectedExecutionException e) { + clearState(STATE_PROCESS_TASK); + throw e; + } + } + + private final class AsyncTaskCompletionHandler implements Runnable { + private final boolean inUnwrap; + boolean didRun; + boolean resumeLater; + + AsyncTaskCompletionHandler(boolean inUnwrap) { + this.inUnwrap = inUnwrap; + } + + @Override + public void run() { + didRun = true; + if (resumeLater) { + getTaskRunner(inUnwrap).runComplete(); + } + } + + boolean resumeLater() { + if (!didRun) { + resumeLater = true; + return true; + } + return false; + } + } + + /** + * {@link Runnable} that will be scheduled on the {@code delegatedTaskExecutor} and will take care + * of resume work on the {@link EventExecutor} once the task was executed. + */ + private final class SslTasksRunner implements Runnable { + private final boolean inUnwrap; + private final Runnable runCompleteTask = new Runnable() { + @Override + public void run() { + runComplete(); + } + }; + + SslTasksRunner(boolean inUnwrap) { + this.inUnwrap = inUnwrap; + } + + // Handle errors which happened during task processing. + private void taskError(Throwable e) { + if (inUnwrap) { + // As the error happened while the task was scheduled as part of unwrap(...) we also need to ensure + // we fire it through the pipeline as inbound error to be consistent with what we do in decode(...). + // + // This will also ensure we fail the handshake future and flush all produced data. + try { + handleUnwrapThrowable(ctx, e); + } catch (Throwable cause) { + safeExceptionCaught(cause); + } + } else { + setHandshakeFailure(ctx, e); + forceFlush(ctx); + } + } + + // Try to call exceptionCaught(...) 
+ private void safeExceptionCaught(Throwable cause) { + try { + exceptionCaught(ctx, wrapIfNeeded(cause)); + } catch (Throwable error) { + ctx.fireExceptionCaught(error); + } + } + + private Throwable wrapIfNeeded(Throwable cause) { + if (!inUnwrap) { + // If we are not in unwrap(...) we can just rethrow without wrapping at all. + return cause; + } + // As the exception would have been triggered by an inbound operation we will need to wrap it in a + // DecoderException to mimic what a decoder would do when decode(...) throws. + return cause instanceof DecoderException ? cause : new DecoderException(cause); + } + + private void tryDecodeAgain() { + try { + channelRead(ctx, Unpooled.EMPTY_BUFFER); + } catch (Throwable cause) { + safeExceptionCaught(cause); + } finally { + // As we called channelRead(...) we also need to call channelReadComplete(...) which + // will ensure we either call ctx.fireChannelReadComplete() or will trigger a ctx.read() if + // more data is needed. + channelReadComplete0(ctx); + } + } + + /** + * Executed after the wrapped {@code task} was executed via {@code delegatedTaskExecutor} to resume work + * on the {@link EventExecutor}. + */ + private void resumeOnEventExecutor() { + assert ctx.executor().inEventLoop(); + clearState(STATE_PROCESS_TASK); + try { + HandshakeStatus status = engine.getHandshakeStatus(); + switch (status) { + // There is another task that needs to be executed and offloaded to the delegatingTaskExecutor as + // a result of this. Let's reschedule.... + case NEED_TASK: + executeDelegatedTask(this); + + break; + + // The handshake finished, lets notify about the completion of it and resume processing. + case FINISHED: + // Not handshaking anymore, lets notify about the completion if not done yet and resume processing. + case NOT_HANDSHAKING: + setHandshakeSuccess(); // NOT_HANDSHAKING -> workaround for android skipping FINISHED state. 
+ try { + // Lets call wrap to ensure we produce the alert if there is any pending and also to + // ensure we flush any queued data.. + wrap(ctx, inUnwrap); + } catch (Throwable e) { + taskError(e); + return; + } + if (inUnwrap) { + // If we were in the unwrap call when the task was processed we should also try to unwrap + // non app data first as there may not anything left in the inbound buffer to process. + unwrapNonAppData(ctx); + } + + // Flush now as we may have written some data as part of the wrap call. + forceFlush(ctx); + + tryDecodeAgain(); + break; + + // We need more data so lets try to unwrap first and then call decode again which will feed us + // with buffered data (if there is any). + case NEED_UNWRAP: + try { + unwrapNonAppData(ctx); + } catch (SSLException e) { + handleUnwrapThrowable(ctx, e); + return; + } + tryDecodeAgain(); + break; + + // To make progress we need to call SSLEngine.wrap(...) which may produce more output data + // that will be written to the Channel. + case NEED_WRAP: + try { + if (!wrapNonAppData(ctx, false) && inUnwrap) { + // The handshake finished in wrapNonAppData(...), we need to try call + // unwrapNonAppData(...) as we may have some alert that we should read. + // + // This mimics what we would do when we are calling this method while in unwrap(...). + unwrapNonAppData(ctx); + } + + // Flush now as we may have written some data as part of the wrap call. + forceFlush(ctx); + } catch (Throwable e) { + taskError(e); + return; + } + + // Now try to feed in more data that we have buffered. + tryDecodeAgain(); + break; + + default: + // Should never reach here as we handle all cases. + throw new AssertionError(); + } + } catch (Throwable cause) { + safeExceptionCaught(cause); + } + } + + void runComplete() { + EventExecutor executor = ctx.executor(); + // Jump back on the EventExecutor. We do this even if we are already on the EventLoop to guard against + // reentrancy issues. 
Failing to do so could lead to the situation of tryDecode(...) be called and so + // channelRead(...) while still in the decode loop. In this case channelRead(...) might release the input + // buffer if its empty which would then result in an IllegalReferenceCountException when we try to continue + // decoding. + // + // See https://github.com/netty/netty-tcnative/issues/680 + executor.execute(new Runnable() { + @Override + public void run() { + resumeOnEventExecutor(); + } + }); + } + + @Override + public void run() { + try { + Runnable task = engine.getDelegatedTask(); + if (task == null) { + // The task was processed in the meantime. Let's just return. + return; + } + if (task instanceof AsyncRunnable) { + AsyncRunnable asyncTask = (AsyncRunnable) task; + asyncTask.run(runCompleteTask); + } else { + task.run(); + runComplete(); + } + } catch (final Throwable cause) { + handleException(cause); + } + } + + private void handleException(final Throwable cause) { + EventExecutor executor = ctx.executor(); + if (executor.inEventLoop()) { + clearState(STATE_PROCESS_TASK); + safeExceptionCaught(cause); + } else { + try { + executor.execute(new Runnable() { + @Override + public void run() { + clearState(STATE_PROCESS_TASK); + safeExceptionCaught(cause); + } + }); + } catch (RejectedExecutionException ignore) { + clearState(STATE_PROCESS_TASK); + // the context itself will handle the rejected exception when try to schedule the operation so + // ignore the RejectedExecutionException + ctx.fireExceptionCaught(cause); + } + } + } + } + + /** + * Notify all the handshake futures about the successfully handshake + * @return {@code true} if {@link #handshakePromise} was set successfully and a {@link SslHandshakeCompletionEvent} + * was fired. {@code false} otherwise. + */ + private boolean setHandshakeSuccess() { + // Our control flow may invoke this method multiple times for a single FINISHED event. 
For example + // wrapNonAppData may drain pendingUnencryptedWrites in wrap which transitions to handshake from FINISHED to + // NOT_HANDSHAKING which invokes setHandshakeSuccess, and then wrapNonAppData also directly invokes this method. + final boolean notified; + if (notified = !handshakePromise.isDone() && handshakePromise.trySuccess(ctx.channel())) { + if (logger.isDebugEnabled()) { + SSLSession session = engine.getSession(); + logger.debug( + "{} HANDSHAKEN: protocol:{} cipher suite:{}", + ctx.channel(), + session.getProtocol(), + session.getCipherSuite()); + } + ctx.fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS); + } + if (isStateSet(STATE_READ_DURING_HANDSHAKE)) { + clearState(STATE_READ_DURING_HANDSHAKE); + if (!ctx.channel().config().isAutoRead()) { + ctx.read(); + } + } + return notified; + } + + /** + * Notify all the handshake futures about the failure during the handshake. + */ + private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause) { + setHandshakeFailure(ctx, cause, true, true, false); + } + + /** + * Notify all the handshake futures about the failure during the handshake. + */ + private void setHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean closeInbound, + boolean notify, boolean alwaysFlushAndClose) { + try { + // Release all resources such as internal buffers that SSLEngine is managing. + setState(STATE_OUTBOUND_CLOSED); + engine.closeOutbound(); + + if (closeInbound) { + try { + engine.closeInbound(); + } catch (SSLException e) { + if (logger.isDebugEnabled()) { + // only log in debug mode as it most likely harmless and latest chrome still trigger + // this all the time. 
+ // + // See https://github.com/netty/netty/issues/1340 + String msg = e.getMessage(); + if (msg == null || !(msg.contains("possible truncation attack") || + msg.contains("closing inbound before receiving peer's close_notify"))) { + logger.debug("{} SSLEngine.closeInbound() raised an exception.", ctx.channel(), e); + } + } + } + } + if (handshakePromise.tryFailure(cause) || alwaysFlushAndClose) { + SslUtils.handleHandshakeFailure(ctx, cause, notify); + } + } finally { + // Ensure we remove and fail all pending writes in all cases and so release memory quickly. + releaseAndFailAll(ctx, cause); + } + } + + private void setHandshakeFailureTransportFailure(ChannelHandlerContext ctx, Throwable cause) { + // If TLS control frames fail to write we are in an unknown state and may become out of + // sync with our peer. We give up and close the channel. This will also take care of + // cleaning up any outstanding state (e.g. handshake promise, queued unencrypted data). + try { + SSLException transportFailure = new SSLException("failure when writing TLS control frames", cause); + releaseAndFailAll(ctx, transportFailure); + if (handshakePromise.tryFailure(transportFailure)) { + ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(transportFailure)); + } + } finally { + ctx.close(); + } + } + + private void releaseAndFailAll(ChannelHandlerContext ctx, Throwable cause) { + if (pendingUnencryptedWrites != null) { + pendingUnencryptedWrites.releaseAndFailAll(ctx, cause); + } + } + + private void notifyClosePromise(Throwable cause) { + if (cause == null) { + if (sslClosePromise.trySuccess(ctx.channel())) { + ctx.fireUserEventTriggered(SslCloseCompletionEvent.SUCCESS); + } + } else { + if (sslClosePromise.tryFailure(cause)) { + ctx.fireUserEventTriggered(new SslCloseCompletionEvent(cause)); + } + } + } + + private void closeOutboundAndChannel( + final ChannelHandlerContext ctx, final ChannelPromise promise, boolean disconnect) throws Exception { + 
setState(STATE_OUTBOUND_CLOSED); + engine.closeOutbound(); + + if (!ctx.channel().isActive()) { + if (disconnect) { + ctx.disconnect(promise); + } else { + ctx.close(promise); + } + return; + } + + ChannelPromise closeNotifyPromise = ctx.newPromise(); + try { + flush(ctx, closeNotifyPromise); + } finally { + if (!isStateSet(STATE_CLOSE_NOTIFY)) { + setState(STATE_CLOSE_NOTIFY); + // It's important that we do not pass the original ChannelPromise to safeClose(...) as when flush(....) + // throws an Exception it will be propagated to the AbstractChannelHandlerContext which will try + // to fail the promise because of this. This will then fail as it was already completed by + // safeClose(...). We create a new ChannelPromise and try to notify the original ChannelPromise + // once it is complete. If we fail to do so we just ignore it as in this case it was failed already + // because of a propagated Exception. + // + // See https://github.com/netty/netty/issues/5931 + safeClose(ctx, closeNotifyPromise, PromiseNotifier.cascade(false, ctx.newPromise(), promise)); + } else { + /// We are already handling the close_notify so just attach the promise to the sslClosePromise. 
+ sslClosePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + promise.setSuccess(); + } + }); + } + } + } + + private void flush(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + if (pendingUnencryptedWrites != null) { + pendingUnencryptedWrites.add(Unpooled.EMPTY_BUFFER, promise); + } else { + promise.setFailure(newPendingWritesNullException()); + } + flush(ctx); + } + + @Override + public void handlerAdded(final ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + Channel channel = ctx.channel(); + pendingUnencryptedWrites = new SslHandlerCoalescingBufferQueue(channel, 16); + + setOpensslEngineSocketFd(channel); + boolean fastOpen = Boolean.TRUE.equals(channel.config().getOption(ChannelOption.TCP_FASTOPEN_CONNECT)); + boolean active = channel.isActive(); + if (active || fastOpen) { + // Explicitly flush the handshake only if the channel is already active. + // With TCP Fast Open, we write to the outbound buffer before the TCP connect is established. + // The buffer will then be flushed as part of establishing the connection, saving us a round-trip. + startHandshakeProcessing(active); + // If we weren't able to include client_hello in the TCP SYN (e.g. no token, disabled at the OS) we have to + // flush pending data in the outbound buffer later in channelActive(). + final ChannelOutboundBuffer outboundBuffer; + if (fastOpen && ((outboundBuffer = channel.unsafe().outboundBuffer()) == null || + outboundBuffer.totalPendingWriteBytes() > 0)) { + setState(STATE_NEEDS_FLUSH); + } + } + } + + private void startHandshakeProcessing(boolean flushAtEnd) { + if (!isStateSet(STATE_HANDSHAKE_STARTED)) { + setState(STATE_HANDSHAKE_STARTED); + if (engine.getUseClientMode()) { + // Begin the initial handshake. + // channelActive() event has been fired already, which means this.channelActive() will + // not be invoked. We have to initialize here instead. 
+ handshake(flushAtEnd); + } + applyHandshakeTimeout(); + } else if (isStateSet(STATE_NEEDS_FLUSH)) { + forceFlush(ctx); + } + } + + /** + * Performs TLS renegotiation. + */ + public Future renegotiate() { + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException(); + } + + return renegotiate(ctx.executor().newPromise()); + } + + /** + * Performs TLS renegotiation. + */ + public Future renegotiate(final Promise promise) { + ObjectUtil.checkNotNull(promise, "promise"); + + ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + throw new IllegalStateException(); + } + + EventExecutor executor = ctx.executor(); + if (!executor.inEventLoop()) { + executor.execute(new Runnable() { + @Override + public void run() { + renegotiateOnEventLoop(promise); + } + }); + return promise; + } + + renegotiateOnEventLoop(promise); + return promise; + } + + private void renegotiateOnEventLoop(final Promise newHandshakePromise) { + final Promise oldHandshakePromise = handshakePromise; + if (!oldHandshakePromise.isDone()) { + // There's no need to handshake because handshake is in progress already. + // Merge the new promise into the old one. + PromiseNotifier.cascade(oldHandshakePromise, newHandshakePromise); + } else { + handshakePromise = newHandshakePromise; + handshake(true); + applyHandshakeTimeout(); + } + } + + /** + * Performs TLS (re)negotiation. + * @param flushAtEnd Set to {@code true} if the outbound buffer should be flushed (written to the network) at the + * end. Set to {@code false} if the handshake will be flushed later, e.g. as part of TCP Fast Open + * connect. + */ + private void handshake(boolean flushAtEnd) { + if (engine.getHandshakeStatus() != HandshakeStatus.NOT_HANDSHAKING) { + // Not all SSLEngine implementations support calling beginHandshake multiple times while a handshake + // is in progress. See https://github.com/netty/netty/issues/4718. 
+ return; + } + if (handshakePromise.isDone()) { + // If the handshake is done already lets just return directly as there is no need to trigger it again. + // This can happen if the handshake(...) was triggered before we called channelActive(...) by a + // flush() that was triggered by a ChannelFutureListener that was added to the ChannelFuture returned + // from the connect(...) method. In this case we will see the flush() happen before we had a chance to + // call fireChannelActive() on the pipeline. + return; + } + + // Begin handshake. + final ChannelHandlerContext ctx = this.ctx; + try { + engine.beginHandshake(); + wrapNonAppData(ctx, false); + } catch (Throwable e) { + setHandshakeFailure(ctx, e); + } finally { + if (flushAtEnd) { + forceFlush(ctx); + } + } + } + + private void applyHandshakeTimeout() { + final Promise localHandshakePromise = this.handshakePromise; + + // Set timeout if necessary. + final long handshakeTimeoutMillis = this.handshakeTimeoutMillis; + if (handshakeTimeoutMillis <= 0 || localHandshakePromise.isDone()) { + return; + } + + final Future timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (localHandshakePromise.isDone()) { + return; + } + SSLException exception = + new SslHandshakeTimeoutException("handshake timed out after " + handshakeTimeoutMillis + "ms"); + try { + if (localHandshakePromise.tryFailure(exception)) { + SslUtils.handleHandshakeFailure(ctx, exception, true); + } + } finally { + releaseAndFailAll(ctx, exception); + } + } + }, handshakeTimeoutMillis, TimeUnit.MILLISECONDS); + + // Cancel the handshake timeout when handshake is finished. 
+ localHandshakePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) throws Exception { + timeoutFuture.cancel(false); + } + }); + } + + private void forceFlush(ChannelHandlerContext ctx) { + clearState(STATE_NEEDS_FLUSH); + ctx.flush(); + } + + private void setOpensslEngineSocketFd(Channel c) { + if (c instanceof UnixChannel && engine instanceof ReferenceCountedOpenSslEngine) { + ((ReferenceCountedOpenSslEngine) engine).bioSetFd(((UnixChannel) c).fd().intValue()); + } + } + + /** + * Issues an initial TLS handshake once connected when used in client-mode + */ + @Override + public void channelActive(final ChannelHandlerContext ctx) throws Exception { + setOpensslEngineSocketFd(ctx.channel()); + if (!startTls) { + startHandshakeProcessing(true); + } + ctx.fireChannelActive(); + } + + private void safeClose( + final ChannelHandlerContext ctx, final ChannelFuture flushFuture, + final ChannelPromise promise) { + if (!ctx.channel().isActive()) { + ctx.close(promise); + return; + } + + final Future timeoutFuture; + if (!flushFuture.isDone()) { + long closeNotifyTimeout = closeNotifyFlushTimeoutMillis; + if (closeNotifyTimeout > 0) { + // Force-close the connection if close_notify is not fully sent in time. + timeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + // May be done in the meantime as cancel(...) is only best effort. + if (!flushFuture.isDone()) { + logger.warn("{} Last write attempt timed out; force-closing the connection.", + ctx.channel()); + addCloseListener(ctx.close(ctx.newPromise()), promise); + } + } + }, closeNotifyTimeout, TimeUnit.MILLISECONDS); + } else { + timeoutFuture = null; + } + } else { + timeoutFuture = null; + } + + // Close the connection if close_notify is sent in time. 
+ flushFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture f) { + if (timeoutFuture != null) { + timeoutFuture.cancel(false); + } + final long closeNotifyReadTimeout = closeNotifyReadTimeoutMillis; + if (closeNotifyReadTimeout <= 0) { + // Trigger the close in all cases to make sure the promise is notified + // See https://github.com/netty/netty/issues/2358 + addCloseListener(ctx.close(ctx.newPromise()), promise); + } else { + final Future closeNotifyReadTimeoutFuture; + + if (!sslClosePromise.isDone()) { + closeNotifyReadTimeoutFuture = ctx.executor().schedule(new Runnable() { + @Override + public void run() { + if (!sslClosePromise.isDone()) { + logger.debug( + "{} did not receive close_notify in {}ms; force-closing the connection.", + ctx.channel(), closeNotifyReadTimeout); + + // Do the close now... + addCloseListener(ctx.close(ctx.newPromise()), promise); + } + } + }, closeNotifyReadTimeout, TimeUnit.MILLISECONDS); + } else { + closeNotifyReadTimeoutFuture = null; + } + + // Do the close once the we received the close_notify. + sslClosePromise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (closeNotifyReadTimeoutFuture != null) { + closeNotifyReadTimeoutFuture.cancel(false); + } + addCloseListener(ctx.close(ctx.newPromise()), promise); + } + }); + } + } + }); + } + + private static void addCloseListener(ChannelFuture future, ChannelPromise promise) { + // We notify the promise in the ChannelPromiseNotifier as there is a "race" where the close(...) call + // by the timeoutFuture and the close call in the flushFuture listener will be called. Because of + // this we need to use trySuccess() and tryFailure(...) as otherwise we can cause an + // IllegalStateException. + // Also we not want to log if the notification happens as this is expected in some cases. 
+ // See https://github.com/netty/netty/issues/5598 + PromiseNotifier.cascade(false, future, promise); + } + + /** + * Always prefer a direct buffer when it's pooled, so that we reduce the number of memory copies + * in {@link OpenSslEngine}. + */ + private ByteBuf allocate(ChannelHandlerContext ctx, int capacity) { + ByteBufAllocator alloc = ctx.alloc(); + if (engineType.wantsDirectBuffer) { + return alloc.directBuffer(capacity); + } else { + return alloc.buffer(capacity); + } + } + + /** + * Allocates an outbound network buffer for {@link SSLEngine#wrap(ByteBuffer, ByteBuffer)} which can encrypt + * the specified amount of pending bytes. + */ + private ByteBuf allocateOutNetBuf(ChannelHandlerContext ctx, int pendingBytes, int numComponents) { + return engineType.allocateWrapBuffer(this, ctx.alloc(), pendingBytes, numComponents); + } + + private boolean isStateSet(int bit) { + return (state & bit) == bit; + } + + private void setState(int bit) { + state |= bit; + } + + private void clearState(int bit) { + state &= ~bit; + } + + /** + * Each call to SSL_write will introduce about ~100 bytes of overhead. This coalescing queue attempts to increase + * goodput by aggregating the plaintext in chunks of {@link #wrapDataSize}. If many small chunks are written + * this can increase goodput, decrease the amount of calls to SSL_write, and decrease overall encryption operations. 
+ */ + private final class SslHandlerCoalescingBufferQueue extends AbstractCoalescingBufferQueue { + + SslHandlerCoalescingBufferQueue(Channel channel, int initSize) { + super(channel, initSize); + } + + @Override + protected ByteBuf compose(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf next) { + final int wrapDataSize = SslHandler.this.wrapDataSize; + if (cumulation instanceof CompositeByteBuf) { + CompositeByteBuf composite = (CompositeByteBuf) cumulation; + int numComponents = composite.numComponents(); + if (numComponents == 0 || + !attemptCopyToCumulation(composite.internalComponent(numComponents - 1), next, wrapDataSize)) { + composite.addComponent(true, next); + } + return composite; + } + return attemptCopyToCumulation(cumulation, next, wrapDataSize) ? cumulation : + copyAndCompose(alloc, cumulation, next); + } + + @Override + protected ByteBuf composeFirst(ByteBufAllocator allocator, ByteBuf first) { + if (first instanceof CompositeByteBuf) { + CompositeByteBuf composite = (CompositeByteBuf) first; + if (engineType.wantsDirectBuffer) { + first = allocator.directBuffer(composite.readableBytes()); + } else { + first = allocator.heapBuffer(composite.readableBytes()); + } + try { + first.writeBytes(composite); + } catch (Throwable cause) { + first.release(); + PlatformDependent.throwException(cause); + } + composite.release(); + } + return first; + } + + @Override + protected ByteBuf removeEmptyValue() { + return null; + } + } + + private static boolean attemptCopyToCumulation(ByteBuf cumulation, ByteBuf next, int wrapDataSize) { + final int inReadableBytes = next.readableBytes(); + final int cumulationCapacity = cumulation.capacity(); + if (wrapDataSize - cumulation.readableBytes() >= inReadableBytes && + // Avoid using the same buffer if next's data would make cumulation exceed the wrapDataSize. + // Only copy if there is enough space available and the capacity is large enough, and attempt to + // resize if the capacity is small. 
+ (cumulation.isWritable(inReadableBytes) && cumulationCapacity >= wrapDataSize || + cumulationCapacity < wrapDataSize && + ensureWritableSuccess(cumulation.ensureWritable(inReadableBytes, false)))) { + cumulation.writeBytes(next); + next.release(); + return true; + } + return false; + } + + private final class LazyChannelPromise extends DefaultPromise { + + @Override + protected EventExecutor executor() { + if (ctx == null) { + throw new IllegalStateException(); + } + return ctx.executor(); + } + + @Override + protected void checkDeadLock() { + if (ctx == null) { + // If ctx is null the handlerAdded(...) callback was not called, in this case the checkDeadLock() + // method was called from another Thread then the one that is used by ctx.executor(). We need to + // guard against this as a user can see a race if handshakeFuture().sync() is called but the + // handlerAdded(..) method was not yet as it is called from the EventExecutor of the + // ChannelHandlerContext. If we not guard against this super.checkDeadLock() would cause an + // IllegalStateException when trying to call executor(). + return; + } + super.checkDeadLock(); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeCompletionEvent.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeCompletionEvent.java new file mode 100644 index 0000000..004c7b0 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeCompletionEvent.java @@ -0,0 +1,39 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + + +/** + * Event that is fired once the SSL handshake is complete, which may be because it was successful or there + * was an error. + */ +public final class SslHandshakeCompletionEvent extends SslCompletionEvent { + + public static final SslHandshakeCompletionEvent SUCCESS = new SslHandshakeCompletionEvent(); + + /** + * Creates a new event that indicates a successful handshake. + */ + private SslHandshakeCompletionEvent() { } + + /** + * Creates a new event that indicates an unsuccessful handshake. + * Use {@link #SUCCESS} to indicate a successful handshake. + */ + public SslHandshakeCompletionEvent(Throwable cause) { + super(cause); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeTimeoutException.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeTimeoutException.java new file mode 100644 index 0000000..cbbbd75 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslHandshakeTimeoutException.java @@ -0,0 +1,28 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * {@link SSLHandshakeException} that is used when a handshake failed due to a configured timeout.
 */
public final class SslHandshakeTimeoutException extends SSLHandshakeException {

    /**
     * Creates the timeout exception; package-private because it is only raised internally
     * by the SSL handlers when the handshake timer fires.
     *
     * @param reason human-readable description, e.g. {@code "handshake timed out after 5000ms"}
     */
    SslHandshakeTimeoutException(String reason) {
        super(reason);
    }
}
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.util.internal.ReflectionUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import javax.crypto.SecretKey; +import javax.crypto.spec.SecretKeySpec; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSession; +import java.lang.reflect.Field; + +/** + * The {@link SslMasterKeyHandler} is a channel-handler you can include in your pipeline to consume the master key + * & session identifier for a TLS session. + * This can be very useful, for instance the {@link WiresharkSslMasterKeyHandler} implementation will + * log the secret & identifier in a format that is consumable by Wireshark -- allowing easy decryption of pcap/tcpdumps. + */ +public abstract class SslMasterKeyHandler extends ChannelInboundHandlerAdapter { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslMasterKeyHandler.class); + + /** + * The JRE SSLSessionImpl cannot be imported + */ + private static final Class SSL_SESSIONIMPL_CLASS; + + /** + * The master key field in the SSLSessionImpl + */ + private static final Field SSL_SESSIONIMPL_MASTER_SECRET_FIELD; + + /** + * A system property that can be used to turn on/off the {@link SslMasterKeyHandler} dynamically without having + * to edit your pipeline. + * -Dio.netty.ssl.masterKeyHandler=true + */ + public static final String SYSTEM_PROP_KEY = "io.netty.ssl.masterKeyHandler"; + + /** + * The unavailability cause of whether the private Sun implementation of SSLSessionImpl is available. 
+ */ + private static final Throwable UNAVAILABILITY_CAUSE; + + static { + Throwable cause; + Class clazz = null; + Field field = null; + try { + clazz = Class.forName("sun.security.ssl.SSLSessionImpl"); + field = clazz.getDeclaredField("masterSecret"); + cause = ReflectionUtil.trySetAccessible(field, true); + } catch (Throwable e) { + cause = e; + if (logger.isTraceEnabled()) { + logger.debug("sun.security.ssl.SSLSessionImpl is unavailable.", e); + } else { + logger.debug("sun.security.ssl.SSLSessionImpl is unavailable: {}", e.getMessage()); + } + } + UNAVAILABILITY_CAUSE = cause; + SSL_SESSIONIMPL_CLASS = clazz; + SSL_SESSIONIMPL_MASTER_SECRET_FIELD = field; + } + + /** + * Constructor. + */ + protected SslMasterKeyHandler() { + } + + /** + * Ensure that SSLSessionImpl is available. + * @throws UnsatisfiedLinkError if unavailable + */ + public static void ensureSunSslEngineAvailability() { + if (UNAVAILABILITY_CAUSE != null) { + throw new IllegalStateException( + "Failed to find SSLSessionImpl on classpath", UNAVAILABILITY_CAUSE); + } + } + + /** + * Returns the cause of unavailability. + * + * @return the cause if unavailable. {@code null} if available. + */ + public static Throwable sunSslEngineUnavailabilityCause() { + return UNAVAILABILITY_CAUSE; + } + + /* Returns {@code true} if and only if sun.security.ssl.SSLSessionImpl exists in the runtime. + */ + public static boolean isSunSslEngineAvailable() { + return UNAVAILABILITY_CAUSE == null; + } + + /** + * Consume the master key for the session and the sessionId + * @param masterKey A 48-byte secret shared between the client and server. + * @param session The current TLS session + */ + protected abstract void accept(SecretKey masterKey, SSLSession session); + + @Override + public final void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + //only try to log the session info if the ssl handshake has successfully completed. 
+ if (evt == SslHandshakeCompletionEvent.SUCCESS && masterKeyHandlerEnabled()) { + final SslHandler handler = ctx.pipeline().get(SslHandler.class); + final SSLEngine engine = handler.engine(); + final SSLSession sslSession = engine.getSession(); + + //the OpenJDK does not expose a way to get the master secret, so try to use reflection to get it. + if (isSunSslEngineAvailable() && sslSession.getClass().equals(SSL_SESSIONIMPL_CLASS)) { + final SecretKey secretKey; + try { + secretKey = (SecretKey) SSL_SESSIONIMPL_MASTER_SECRET_FIELD.get(sslSession); + } catch (IllegalAccessException e) { + throw new IllegalArgumentException("Failed to access the field 'masterSecret' " + + "via reflection.", e); + } + accept(secretKey, sslSession); + } else if (OpenSsl.isAvailable() && engine instanceof ReferenceCountedOpenSslEngine) { + SecretKeySpec secretKey = ((ReferenceCountedOpenSslEngine) engine).masterKey(); + accept(secretKey, sslSession); + } + } + + ctx.fireUserEventTriggered(evt); + } + + /** + * Checks if the handler is set up to actually handle/accept the event. + * By default the {@link #SYSTEM_PROP_KEY} property is checked, but any implementations of this class are + * free to override if they have different mechanisms of checking. + * + * @return true if it should handle, false otherwise. + */ + protected boolean masterKeyHandlerEnabled() { + return SystemPropertyUtil.getBoolean(SYSTEM_PROP_KEY, false); + } + + /** + * Create a {@link WiresharkSslMasterKeyHandler} instance. + * This TLS master key handler logs the master key and session-id in a format + * understood by Wireshark -- this can be especially useful if you need to ever + * decrypt a TLS session and are using perfect forward secrecy (i.e. Diffie-Hellman) + * The key and session identifier are forwarded to the log named 'io.netty.wireshark'. 
+ */ + public static SslMasterKeyHandler newWireSharkSslMasterKeyHandler() { + return new WiresharkSslMasterKeyHandler(); + } + + /** + * Record the session identifier and master key to the {@link InternalLogger} named {@code io.netty.wireshark}. + * ex. {@code RSA Session-ID:XXX Master-Key:YYY} + * This format is understood by Wireshark 1.6.0. + * See: Wireshark + * The key and session identifier are forwarded to the log named 'io.netty.wireshark'. + */ + private static final class WiresharkSslMasterKeyHandler extends SslMasterKeyHandler { + + private static final InternalLogger wireshark_logger = + InternalLoggerFactory.getInstance("io.netty.wireshark"); + + @Override + protected void accept(SecretKey masterKey, SSLSession session) { + if (masterKey.getEncoded().length != 48) { + throw new IllegalArgumentException("An invalid length master key was provided."); + } + final byte[] sessionId = session.getId(); + wireshark_logger.warn("RSA Session-ID:{} Master-Key:{}", + ByteBufUtil.hexDump(sessionId).toLowerCase(), + ByteBufUtil.hexDump(masterKey.getEncoded()).toLowerCase()); + } + } + +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProtocols.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProtocols.java new file mode 100644 index 0000000..c38e1cf --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProtocols.java @@ -0,0 +1,76 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +/** + * SSL/TLS protocols + */ +public final class SslProtocols { + + /** + * SSL v2 Hello + * + * @deprecated SSLv2Hello is no longer secure. Consider using {@link #TLS_v1_2} or {@link #TLS_v1_3} + */ + @Deprecated + public static final String SSL_v2_HELLO = "SSLv2Hello"; + + /** + * SSL v2 + * + * @deprecated SSLv2 is no longer secure. Consider using {@link #TLS_v1_2} or {@link #TLS_v1_3} + */ + @Deprecated + public static final String SSL_v2 = "SSLv2"; + + /** + * SSLv3 + * + * @deprecated SSLv3 is no longer secure. Consider using {@link #TLS_v1_2} or {@link #TLS_v1_3} + */ + @Deprecated + public static final String SSL_v3 = "SSLv3"; + + /** + * TLS v1 + * + * @deprecated TLSv1 is no longer secure. Consider using {@link #TLS_v1_2} or {@link #TLS_v1_3} + */ + @Deprecated + public static final String TLS_v1 = "TLSv1"; + + /** + * TLS v1.1 + * + * @deprecated TLSv1.1 is no longer secure. Consider using {@link #TLS_v1_2} or {@link #TLS_v1_3} + */ + @Deprecated + public static final String TLS_v1_1 = "TLSv1.1"; + + /** + * TLS v1.2 + */ + public static final String TLS_v1_2 = "TLSv1.2"; + + /** + * TLS v1.3 + */ + public static final String TLS_v1_3 = "TLSv1.3"; + + private SslProtocols() { + // Prevent outside initialization + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProvider.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProvider.java new file mode 100644 index 0000000..952b9e8 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslProvider.java @@ -0,0 +1,115 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.UnstableApi; + +import java.security.Provider; + +/** + * An enumeration of SSL/TLS protocol providers. + */ +public enum SslProvider { + /** + * JDK's default implementation. + */ + JDK, + /** + * OpenSSL-based implementation. + */ + OPENSSL, + /** + * OpenSSL-based implementation which does not have finalizers and instead implements {@link ReferenceCounted}. + */ + @UnstableApi + OPENSSL_REFCNT; + + /** + * Returns {@code true} if the specified {@link SslProvider} supports + * TLS ALPN Extension, {@code false} otherwise. + */ + @SuppressWarnings("deprecation") + public static boolean isAlpnSupported(final SslProvider provider) { + switch (provider) { + case JDK: + return JdkAlpnApplicationProtocolNegotiator.isAlpnSupported(); + case OPENSSL: + case OPENSSL_REFCNT: + return OpenSsl.isAlpnSupported(); + default: + throw new Error("Unknown SslProvider: " + provider); + } + } + + /** + * Returns {@code true} if the specified {@link SslProvider} supports + * TLS 1.3, {@code false} otherwise. + */ + public static boolean isTlsv13Supported(final SslProvider sslProvider) { + return isTlsv13Supported(sslProvider, null); + } + + /** + * Returns {@code true} if the specified {@link SslProvider} supports + * TLS 1.3, {@code false} otherwise. 
+ */ + public static boolean isTlsv13Supported(final SslProvider sslProvider, Provider provider) { + switch (sslProvider) { + case JDK: + return SslUtils.isTLSv13SupportedByJDK(provider); + case OPENSSL: + case OPENSSL_REFCNT: + return OpenSsl.isTlsv13Supported(); + default: + throw new Error("Unknown SslProvider: " + sslProvider); + } + } + + /** + * Returns {@code true} if the specified {@link SslProvider} supports the specified {@link SslContextOption}, + * {@code false} otherwise. + */ + public static boolean isOptionSupported(SslProvider sslProvider, SslContextOption option) { + switch (sslProvider) { + case JDK: + // We currently don't support any SslContextOptions when using the JDK implementation + return false; + case OPENSSL: + case OPENSSL_REFCNT: + return OpenSsl.isOptionSupported(option); + default: + throw new Error("Unknown SslProvider: " + sslProvider); + } + } + + /** + * Returns {@code true} if the specified {@link SslProvider} enables + * TLS 1.3 by default, {@code false} otherwise. + */ + static boolean isTlsv13EnabledByDefault(final SslProvider sslProvider, Provider provider) { + switch (sslProvider) { + case JDK: + return SslUtils.isTLSv13EnabledByJDK(provider); + case OPENSSL: + case OPENSSL_REFCNT: + return OpenSsl.isTlsv13Supported(); + default: + throw new Error("Unknown SslProvider: " + sslProvider); + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslUtils.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslUtils.java new file mode 100644 index 0000000..98aeaff --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SslUtils.java @@ -0,0 +1,509 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.ByteBufUtil; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.base64.Base64; +import io.netty.handler.codec.base64.Base64Dialect; +import io.netty.util.NetUtil; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.Provider; +import java.util.Collections; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Set; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.TrustManager; + +import static java.util.Arrays.asList; + +/** + * Constants for SSL packets. 
+ */ +final class SslUtils { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SslUtils.class); + + // See https://tools.ietf.org/html/rfc8446#appendix-B.4 + static final Set TLSV13_CIPHERS = Collections.unmodifiableSet(new LinkedHashSet( + asList("TLS_AES_256_GCM_SHA384", "TLS_CHACHA20_POLY1305_SHA256", + "TLS_AES_128_GCM_SHA256", "TLS_AES_128_CCM_8_SHA256", + "TLS_AES_128_CCM_SHA256"))); + + static final short DTLS_1_0 = (short) 0xFEFF; + static final short DTLS_1_2 = (short) 0xFEFD; + static final short DTLS_1_3 = (short) 0xFEFC; + static final short DTLS_RECORD_HEADER_LENGTH = 13; + + /** + * GMSSL Protocol Version + */ + static final int GMSSL_PROTOCOL_VERSION = 0x101; + + static final String INVALID_CIPHER = "SSL_NULL_WITH_NULL_NULL"; + + /** + * change cipher spec + */ + static final int SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC = 20; + + /** + * alert + */ + static final int SSL_CONTENT_TYPE_ALERT = 21; + + /** + * handshake + */ + static final int SSL_CONTENT_TYPE_HANDSHAKE = 22; + + /** + * application data + */ + static final int SSL_CONTENT_TYPE_APPLICATION_DATA = 23; + + /** + * HeartBeat Extension + */ + static final int SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT = 24; + + /** + * the length of the ssl record header (in bytes) + */ + static final int SSL_RECORD_HEADER_LENGTH = 5; + + /** + * Not enough data in buffer to parse the record length + */ + static final int NOT_ENOUGH_DATA = -1; + + /** + * data is not encrypted + */ + static final int NOT_ENCRYPTED = -2; + + static final String[] DEFAULT_CIPHER_SUITES; + static final String[] DEFAULT_TLSV13_CIPHER_SUITES; + static final String[] TLSV13_CIPHER_SUITES = { "TLS_AES_128_GCM_SHA256", "TLS_AES_256_GCM_SHA384" }; + + private static final boolean TLSV1_3_JDK_SUPPORTED; + private static final boolean TLSV1_3_JDK_DEFAULT_ENABLED; + + static { + TLSV1_3_JDK_SUPPORTED = isTLSv13SupportedByJDK0(null); + TLSV1_3_JDK_DEFAULT_ENABLED = isTLSv13EnabledByJDK0(null); + if 
(TLSV1_3_JDK_SUPPORTED) { + DEFAULT_TLSV13_CIPHER_SUITES = TLSV13_CIPHER_SUITES; + } else { + DEFAULT_TLSV13_CIPHER_SUITES = EmptyArrays.EMPTY_STRINGS; + } + + Set defaultCiphers = new LinkedHashSet(); + // GCM (Galois/Counter Mode) requires JDK 8. + defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"); + defaultCiphers.add("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"); + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384"); + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA"); + // AES256 requires JCE unlimited strength jurisdiction policy files. + defaultCiphers.add("TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA"); + // GCM (Galois/Counter Mode) requires JDK 8. + defaultCiphers.add("TLS_RSA_WITH_AES_128_GCM_SHA256"); + defaultCiphers.add("TLS_RSA_WITH_AES_128_CBC_SHA"); + // AES256 requires JCE unlimited strength jurisdiction policy files. + defaultCiphers.add("TLS_RSA_WITH_AES_256_CBC_SHA"); + + Collections.addAll(defaultCiphers, DEFAULT_TLSV13_CIPHER_SUITES); + + DEFAULT_CIPHER_SUITES = defaultCiphers.toArray(EmptyArrays.EMPTY_STRINGS); + } + + /** + * Returns {@code true} if the JDK itself supports TLSv1.3, {@code false} otherwise. + */ + static boolean isTLSv13SupportedByJDK(Provider provider) { + if (provider == null) { + return TLSV1_3_JDK_SUPPORTED; + } + return isTLSv13SupportedByJDK0(provider); + } + + private static boolean isTLSv13SupportedByJDK0(Provider provider) { + try { + return arrayContains(newInitContext(provider) + .getSupportedSSLParameters().getProtocols(), SslProtocols.TLS_v1_3); + } catch (Throwable cause) { + logger.debug("Unable to detect if JDK SSLEngine with provider {} supports TLSv1.3, assuming no", + provider, cause); + return false; + } + } + + /** + * Returns {@code true} if the JDK itself supports TLSv1.3 and enabled it by default, {@code false} otherwise. 
+ */ + static boolean isTLSv13EnabledByJDK(Provider provider) { + if (provider == null) { + return TLSV1_3_JDK_DEFAULT_ENABLED; + } + return isTLSv13EnabledByJDK0(provider); + } + + private static boolean isTLSv13EnabledByJDK0(Provider provider) { + try { + return arrayContains(newInitContext(provider) + .getDefaultSSLParameters().getProtocols(), SslProtocols.TLS_v1_3); + } catch (Throwable cause) { + logger.debug("Unable to detect if JDK SSLEngine with provider {} enables TLSv1.3 by default," + + " assuming no", provider, cause); + return false; + } + } + + private static SSLContext newInitContext(Provider provider) + throws NoSuchAlgorithmException, KeyManagementException { + final SSLContext context; + if (provider == null) { + context = SSLContext.getInstance("TLS"); + } else { + context = SSLContext.getInstance("TLS", provider); + } + context.init(null, new TrustManager[0], null); + return context; + } + + static SSLContext getSSLContext(String provider) + throws NoSuchAlgorithmException, KeyManagementException, NoSuchProviderException { + final SSLContext context; + if (StringUtil.isNullOrEmpty(provider)) { + context = SSLContext.getInstance(getTlsVersion()); + } else { + context = SSLContext.getInstance(getTlsVersion(), provider); + } + context.init(null, new TrustManager[0], null); + return context; + } + + private static String getTlsVersion() { + return TLSV1_3_JDK_SUPPORTED ? SslProtocols.TLS_v1_3 : SslProtocols.TLS_v1_2; + } + + static boolean arrayContains(String[] array, String value) { + for (String v: array) { + if (value.equals(v)) { + return true; + } + } + return false; + } + + /** + * Add elements from {@code names} into {@code enabled} if they are in {@code supported}. + */ + static void addIfSupported(Set supported, List enabled, String... 
names) { + for (String n: names) { + if (supported.contains(n)) { + enabled.add(n); + } + } + } + + static void useFallbackCiphersIfDefaultIsEmpty(List defaultCiphers, Iterable fallbackCiphers) { + if (defaultCiphers.isEmpty()) { + for (String cipher : fallbackCiphers) { + if (cipher.startsWith("SSL_") || cipher.contains("_RC4_")) { + continue; + } + defaultCiphers.add(cipher); + } + } + } + + static void useFallbackCiphersIfDefaultIsEmpty(List defaultCiphers, String... fallbackCiphers) { + useFallbackCiphersIfDefaultIsEmpty(defaultCiphers, asList(fallbackCiphers)); + } + + /** + * Converts the given exception to a {@link SSLHandshakeException}, if it isn't already. + */ + static SSLHandshakeException toSSLHandshakeException(Throwable e) { + if (e instanceof SSLHandshakeException) { + return (SSLHandshakeException) e; + } + + return (SSLHandshakeException) new SSLHandshakeException(e.getMessage()).initCause(e); + } + + /** + * Return how much bytes can be read out of the encrypted data. Be aware that this method will not increase + * the readerIndex of the given {@link ByteBuf}. + * + * @param buffer + * The {@link ByteBuf} to read from. + * @return length + * The length of the encrypted packet that is included in the buffer or + * {@link #SslUtils#NOT_ENOUGH_DATA} if not enough data is present in the + * {@link ByteBuf}. This will return {@link SslUtils#NOT_ENCRYPTED} if + * the given {@link ByteBuf} is not encrypted at all. 
     */
    static int getEncryptedPacketLength(ByteBuf buffer, int offset) {
        int packetLength = 0;

        // SSLv3 or TLS - Check ContentType (first byte of the record header).
        boolean tls;
        switch (buffer.getUnsignedByte(offset)) {
            case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
            case SSL_CONTENT_TYPE_ALERT:
            case SSL_CONTENT_TYPE_HANDSHAKE:
            case SSL_CONTENT_TYPE_APPLICATION_DATA:
            case SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT:
                tls = true;
                break;
            default:
                // SSLv2 or bad data
                tls = false;
        }

        if (tls) {
            // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 - Check ProtocolVersion
            int majorVersion = buffer.getUnsignedByte(offset + 1);
            // getShort() sign-extends; the DTLS constants below are shorts, so the comparison
            // still matches the 0xFEFx wire values after promotion.
            int version = buffer.getShort(offset + 1);
            if (majorVersion == 3 || version == GMSSL_PROTOCOL_VERSION) {
                // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1
                // Record length is the unsigned 16-bit field at bytes 3..4 of the header.
                packetLength = unsignedShortBE(buffer, offset + 3) + SSL_RECORD_HEADER_LENGTH;
                if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
                    // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
                    tls = false;
                }
            } else if (version == DTLS_1_0 || version == DTLS_1_2 || version == DTLS_1_3) {
                if (buffer.readableBytes() < offset + DTLS_RECORD_HEADER_LENGTH) {
                    return NOT_ENOUGH_DATA;
                }
                // length is the last 2 bytes in the 13 byte header.
                packetLength = unsignedShortBE(buffer, offset + DTLS_RECORD_HEADER_LENGTH - 2) +
                        DTLS_RECORD_HEADER_LENGTH;
            } else {
                // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
                tls = false;
            }
        }

        if (!tls) {
            // SSLv2 or bad data - Check the version
            // SSLv2 headers are 2 bytes when the high bit of the first byte is set, otherwise 3 bytes.
            int headerLength = (buffer.getUnsignedByte(offset) & 0x80) != 0 ? 2 : 3;
            int majorVersion = buffer.getUnsignedByte(offset + headerLength + 1);
            if (majorVersion == 2 || majorVersion == 3) {
                // SSLv2: the record length shares the first bytes with the header-length flag,
                // so mask off 1 (0x7FFF) or 2 (0x3FFF) leading flag bits respectively.
                packetLength = headerLength == 2 ?
                        (shortBE(buffer, offset) & 0x7FFF) + 2 : (shortBE(buffer, offset) & 0x3FFF) + 3;
                if (packetLength <= headerLength) {
                    return NOT_ENOUGH_DATA;
                }
            } else {
                return NOT_ENCRYPTED;
            }
        }
        return packetLength;
    }

    // Reads a big-endian unsigned short integer from the buffer
    @SuppressWarnings("deprecation")
    private static int unsignedShortBE(ByteBuf buffer, int offset) {
        int value = buffer.getUnsignedShort(offset);
        if (buffer.order() == ByteOrder.LITTLE_ENDIAN) {
            // Swap the two low-order bytes back into big-endian order.
            value = Integer.reverseBytes(value) >>> Short.SIZE;
        }
        return value;
    }

    // Reads a big-endian short integer from the buffer
    @SuppressWarnings("deprecation")
    private static short shortBE(ByteBuf buffer, int offset) {
        short value = buffer.getShort(offset);
        if (buffer.order() == ByteOrder.LITTLE_ENDIAN) {
            value = Short.reverseBytes(value);
        }
        return value;
    }

    // Widens a byte to its unsigned value in [0, 255].
    private static short unsignedByte(byte b) {
        return (short) (b & 0xFF);
    }

    // Reads a big-endian unsigned short integer from the buffer
    private static int unsignedShortBE(ByteBuffer buffer, int offset) {
        return shortBE(buffer, offset) & 0xFFFF;
    }

    // Reads a big-endian short integer from the buffer
    private static short shortBE(ByteBuffer buffer, int offset) {
        return buffer.order() == ByteOrder.BIG_ENDIAN ?
                buffer.getShort(offset) : ByteBufUtil.swapShort(buffer.getShort(offset));
    }

    /**
     * Variant of {@link #getEncryptedPacketLength(ByteBuf, int)} that reads from an array of
     * {@link ByteBuffer}s starting at index {@code offset}. Positions of the source buffers are
     * not modified (slices/duplicates are used for the copy).
     */
    static int getEncryptedPacketLength(ByteBuffer[] buffers, int offset) {
        ByteBuffer buffer = buffers[offset];

        // Check if everything we need is in one ByteBuffer. If so we can make use of the fast-path.
        if (buffer.remaining() >= SSL_RECORD_HEADER_LENGTH) {
            return getEncryptedPacketLength(buffer);
        }

        // We need to copy 5 bytes into a temporary buffer so we can parse out the packet length easily.
        ByteBuffer tmp = ByteBuffer.allocate(5);

        // NOTE(review): if the buffers collectively hold fewer than 5 readable bytes this loop walks
        // past the end of the array and throws ArrayIndexOutOfBoundsException — presumably callers
        // guarantee at least SSL_RECORD_HEADER_LENGTH bytes; confirm at the call sites.
        do {
            buffer = buffers[offset++].duplicate();
            if (buffer.remaining() > tmp.remaining()) {
                buffer.limit(buffer.position() + tmp.remaining());
            }
            tmp.put(buffer);
        } while (tmp.hasRemaining());

        // Done, flip the buffer so we can read from it.
        tmp.flip();
        return getEncryptedPacketLength(tmp);
    }

    // Same parsing logic as getEncryptedPacketLength(ByteBuf, int), applied to a ByteBuffer
    // without moving its position.
    private static int getEncryptedPacketLength(ByteBuffer buffer) {
        int packetLength = 0;
        int pos = buffer.position();
        // SSLv3 or TLS - Check ContentType
        boolean tls;
        switch (unsignedByte(buffer.get(pos))) {
            case SSL_CONTENT_TYPE_CHANGE_CIPHER_SPEC:
            case SSL_CONTENT_TYPE_ALERT:
            case SSL_CONTENT_TYPE_HANDSHAKE:
            case SSL_CONTENT_TYPE_APPLICATION_DATA:
            case SSL_CONTENT_TYPE_EXTENSION_HEARTBEAT:
                tls = true;
                break;
            default:
                // SSLv2 or bad data
                tls = false;
        }

        if (tls) {
            // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1 - Check ProtocolVersion
            int majorVersion = unsignedByte(buffer.get(pos + 1));
            if (majorVersion == 3 || buffer.getShort(pos + 1) == GMSSL_PROTOCOL_VERSION) {
                // SSLv3 or TLS or GMSSLv1.0 or GMSSLv1.1
                packetLength = unsignedShortBE(buffer, pos + 3) + SSL_RECORD_HEADER_LENGTH;
                if (packetLength <= SSL_RECORD_HEADER_LENGTH) {
                    // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
                    tls = false;
                }
            } else {
                // Neither SSLv3 or TLSv1 (i.e. SSLv2 or bad data)
                // NOTE(review): unlike the ByteBuf overload, this path has no DTLS branch.
                tls = false;
            }
        }

        if (!tls) {
            // SSLv2 or bad data - Check the version
            int headerLength = (unsignedByte(buffer.get(pos)) & 0x80) != 0 ? 2 : 3;
            int majorVersion = unsignedByte(buffer.get(pos + headerLength + 1));
            if (majorVersion == 2 || majorVersion == 3) {
                // SSLv2
                packetLength = headerLength == 2 ?
                        (shortBE(buffer, pos) & 0x7FFF) + 2 : (shortBE(buffer, pos) & 0x3FFF) + 3;
                if (packetLength <= headerLength) {
                    return NOT_ENOUGH_DATA;
                }
            } else {
                return NOT_ENCRYPTED;
            }
        }
        return packetLength;
    }

    /**
     * Flushes any pending writes, optionally fires a failed {@link SslHandshakeCompletionEvent}
     * and closes the channel.
     */
    static void handleHandshakeFailure(ChannelHandlerContext ctx, Throwable cause, boolean notify) {
        // We may have written some parts of data before an exception was thrown so ensure we always flush.
        // See https://github.com/netty/netty/issues/3900#issuecomment-172481830
        ctx.flush();
        if (notify) {
            ctx.fireUserEventTriggered(new SslHandshakeCompletionEvent(cause));
        }
        ctx.close();
    }

    /**
     * Fills the {@link ByteBuf} with zero bytes.
     */
    static void zeroout(ByteBuf buffer) {
        // Read-only buffers cannot be zeroed; skip silently.
        if (!buffer.isReadOnly()) {
            buffer.setZero(0, buffer.capacity());
        }
    }

    /**
     * Fills the {@link ByteBuf} with zero bytes and releases it.
     */
    static void zerooutAndRelease(ByteBuf buffer) {
        zeroout(buffer);
        buffer.release();
    }

    /**
     * Same as {@link Base64#encode(ByteBuf, boolean)} but allows the use of a custom {@link ByteBufAllocator}.
     *
     * @see Base64#encode(ByteBuf, boolean)
     */
    static ByteBuf toBase64(ByteBufAllocator allocator, ByteBuf src) {
        ByteBuf dst = Base64.encode(src, src.readerIndex(),
                src.readableBytes(), true, Base64Dialect.STANDARD, allocator);
        // Consume the source fully, mirroring Base64.encode(ByteBuf, boolean).
        src.readerIndex(src.writerIndex());
        return dst;
    }

    /**
     * Validate that the given hostname can be used in SNI extension.
     */
    static boolean isValidHostNameForSNI(String hostname) {
        // See https://datatracker.ietf.org/doc/html/rfc6066#section-3
        return hostname != null &&
               // SNI HostName has to be a FQDN according to TLS SNI Extension spec (see [1]),
               // which means that is has to have at least a host name and a domain part.
               hostname.indexOf('.') > 0 &&
               !hostname.endsWith(".") && !hostname.startsWith("/") &&
               // RFC 6066 forbids IP literals in SNI.
               !NetUtil.isValidIpV4Address(hostname) &&
               !NetUtil.isValidIpV6Address(hostname);
    }

    /**
     * Returns {@code true} if the given cipher (in openssl format) is for TLSv1.3, {@code false} otherwise.
     */
    static boolean isTLSv13Cipher(String cipher) {
        // See https://tools.ietf.org/html/rfc8446#appendix-B.4
        return TLSV13_CIPHERS.contains(cipher);
    }

    // Utility class: no instances.
    private SslUtils() {
    }
}

// ---- next file: StacklessSSLHandshakeException.java ----
/*
 * Copyright 2023 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import io.netty.util.internal.ThrowableUtil;

import javax.net.ssl.SSLHandshakeException;

/**
 * A {@link SSLHandshakeException} that does not fill in the stack trace.
+ */ +final class StacklessSSLHandshakeException extends SSLHandshakeException { + + private static final long serialVersionUID = -1244781947804415549L; + + private StacklessSSLHandshakeException(String reason) { + super(reason); + } + + @Override + public Throwable fillInStackTrace() { + // This is a performance optimization to not fill in the + // stack trace as this is a stackless exception. + return this; + } + + /** + * Creates a new {@link StacklessSSLHandshakeException} which has the origin of the given {@link Class} and method. + */ + static StacklessSSLHandshakeException newInstance(String reason, Class clazz, String method) { + return ThrowableUtil.unknownStackTrace(new StacklessSSLHandshakeException(reason), clazz, method); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java new file mode 100644 index 0000000..14cc08e --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/SupportedCipherSuiteFilter.java @@ -0,0 +1,58 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; + +import javax.net.ssl.SSLEngine; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; + +/** + * This class will filter all requested ciphers out that are not supported by the current {@link SSLEngine}. + */ +public final class SupportedCipherSuiteFilter implements CipherSuiteFilter { + public static final SupportedCipherSuiteFilter INSTANCE = new SupportedCipherSuiteFilter(); + + private SupportedCipherSuiteFilter() { } + + @Override + public String[] filterCipherSuites(Iterable ciphers, List defaultCiphers, + Set supportedCiphers) { + ObjectUtil.checkNotNull(defaultCiphers, "defaultCiphers"); + ObjectUtil.checkNotNull(supportedCiphers, "supportedCiphers"); + + final List newCiphers; + if (ciphers == null) { + newCiphers = new ArrayList(defaultCiphers.size()); + ciphers = defaultCiphers; + } else { + newCiphers = new ArrayList(supportedCiphers.size()); + } + for (String c : ciphers) { + if (c == null) { + break; + } + if (supportedCiphers.contains(c)) { + newCiphers.add(c); + } + } + return newCiphers.toArray(EmptyArrays.EMPTY_STRINGS); + } + +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java new file mode 100644 index 0000000..296a32b --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/OcspClientHandler.java @@ -0,0 +1,57 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl.ocsp; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.handler.ssl.ReferenceCountedOpenSslEngine; +import io.netty.handler.ssl.SslHandshakeCompletionEvent; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.UnstableApi; + +import javax.net.ssl.SSLHandshakeException; + +/** + * A handler for SSL clients to handle and act upon stapled OCSP responses. + */ +@UnstableApi +public abstract class OcspClientHandler extends ChannelInboundHandlerAdapter { + + private final ReferenceCountedOpenSslEngine engine; + + protected OcspClientHandler(ReferenceCountedOpenSslEngine engine) { + this.engine = ObjectUtil.checkNotNull(engine, "engine"); + } + + /** + * @see ReferenceCountedOpenSslEngine#getOcspResponse() + */ + protected abstract boolean verify(ChannelHandlerContext ctx, ReferenceCountedOpenSslEngine engine) throws Exception; + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + ctx.pipeline().remove(this); + + SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt; + if (event.isSuccess() && !verify(ctx, engine)) { + throw new SSLHandshakeException("Bad OCSP response"); + } + } + + ctx.fireUserEventTriggered(evt); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/package-info.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/package-info.java new file mode 100644 index 
0000000..7e81ae6 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/ocsp/package-info.java @@ -0,0 +1,23 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * OCSP stapling, + * formally known as the TLS Certificate Status Request extension, is an + * alternative approach to the Online Certificate Status Protocol (OCSP) + * for checking the revocation status of X.509 digital certificates. + */ +package io.netty.handler.ssl.ocsp; diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/package-info.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/package-info.java new file mode 100644 index 0000000..583c63b --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * SSL · + * TLS implementation based on {@link javax.net.ssl.SSLEngine} + */ +package io.netty.handler.ssl; diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/BouncyCastleSelfSignedCertGenerator.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/BouncyCastleSelfSignedCertGenerator.java new file mode 100644 index 0000000..890ade3 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/BouncyCastleSelfSignedCertGenerator.java @@ -0,0 +1,64 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl.util; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.cert.X509CertificateHolder; +import org.bouncycastle.cert.X509v3CertificateBuilder; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.jce.provider.BouncyCastleProvider; +import org.bouncycastle.operator.ContentSigner; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +import java.math.BigInteger; +import java.security.KeyPair; +import java.security.PrivateKey; +import java.security.Provider; +import java.security.SecureRandom; +import java.security.cert.X509Certificate; +import java.util.Date; + +import static io.netty.handler.ssl.util.SelfSignedCertificate.newSelfSignedCertificate; + +/** + * Generates a self-signed certificate using Bouncy Castle. + */ +final class BouncyCastleSelfSignedCertGenerator { + + private static final Provider PROVIDER = new BouncyCastleProvider(); + + static String[] generate(String fqdn, KeyPair keypair, SecureRandom random, Date notBefore, Date notAfter, + String algorithm) throws Exception { + PrivateKey key = keypair.getPrivate(); + + // Prepare the information required for generating an X.509 certificate. + X500Name owner = new X500Name("CN=" + fqdn); + X509v3CertificateBuilder builder = new JcaX509v3CertificateBuilder( + owner, new BigInteger(64, random), notBefore, notAfter, owner, keypair.getPublic()); + + ContentSigner signer = new JcaContentSignerBuilder( + algorithm.equalsIgnoreCase("EC") ? 
"SHA256withECDSA" : "SHA256WithRSAEncryption").build(key); + X509CertificateHolder certHolder = builder.build(signer); + X509Certificate cert = new JcaX509CertificateConverter().setProvider(PROVIDER).getCertificate(certHolder); + cert.verify(keypair.getPublic()); + + return newSelfSignedCertificate(fqdn, key, cert); + } + + private BouncyCastleSelfSignedCertGenerator() { } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java new file mode 100644 index 0000000..cbe9b25 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactory.java @@ -0,0 +1,266 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl.util; + +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import java.security.KeyStore; +import java.security.MessageDigest; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.regex.Pattern; + +/** + * An {@link TrustManagerFactory} that trusts an X.509 certificate whose hash matches. + *

+ * NOTE: It is recommended to verify certificates and their chain to prevent + * Man-in-the-middle attacks. + * This {@link TrustManagerFactory} will only verify that the fingerprint of certificates match one + * of the given fingerprints. This procedure is called + * certificate pinning and + * is an effective protection. For maximum security one should verify that the whole certificate chain is as expected. + * It is worth mentioning that certain firewalls, proxies or other appliances found in corporate environments, + * actually perform Man-in-the-middle attacks and thus present a different certificate fingerprint. + *

+ *

+ * The hash of an X.509 certificate is calculated from its DER encoded format. You can get the fingerprint of + * an X.509 certificate using the {@code openssl} command. For example: + * + *

+ * $ openssl x509 -fingerprint -sha256 -in my_certificate.crt
+ * SHA256 Fingerprint=1C:53:0E:6B:FF:93:F0:DE:C2:E6:E7:9D:10:53:58:FF:DD:8E:68:CD:82:D9:C9:36:9B:43:EE:B3:DC:13:68:FB
+ * -----BEGIN CERTIFICATE-----
+ * MIIC/jCCAeagAwIBAgIIIMONxElm0AIwDQYJKoZIhvcNAQELBQAwPjE8MDoGA1UE
+ * AwwzZThhYzAyZmEwZDY1YTg0MjE5MDE2MDQ1ZGI4YjA1YzQ4NWI0ZWNkZi5uZXR0
+ * eS50ZXN0MCAXDTEzMDgwMjA3NTEzNloYDzk5OTkxMjMxMjM1OTU5WjA+MTwwOgYD
+ * VQQDDDNlOGFjMDJmYTBkNjVhODQyMTkwMTYwNDVkYjhiMDVjNDg1YjRlY2RmLm5l
+ * dHR5LnRlc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQDb+HBO3C0U
+ * RBKvDUgJHbhIlBye8X/cbNH3lDq3XOOFBz7L4XZKLDIXS+FeQqSAUMo2otmU+Vkj
+ * 0KorshMjbUXfE1KkTijTMJlaga2M2xVVt21fRIkJNWbIL0dWFLWyRq7OXdygyFkI
+ * iW9b2/LYaePBgET22kbtHSCAEj+BlSf265+1rNxyAXBGGGccCKzEbcqASBKHOgVp
+ * 6pLqlQAfuSy6g/OzGzces3zXRrGu1N3pBIzAIwCW429n52ZlYfYR0nr+REKDnRrP
+ * IIDsWASmEHhBezTD+v0qCJRyLz2usFgWY+7agUJE2yHHI2mTu2RAFngBilJXlMCt
+ * VwT0xGuQxkbHAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAEv8N7Xm8qaY2FgrOc6P
+ * a1GTgA+AOb3aU33TGwAR86f+nLf6BSPaohcQfOeJid7FkFuYInuXl+oqs+RqM/j8
+ * R0E5BuGYY2wOKpL/PbFi1yf/Kyvft7KVh8e1IUUec/i1DdYTDB0lNWvXXxjfMKGL
+ * ct3GMbEHKvLfHx42Iwz/+fva6LUrO4u2TDfv0ycHuR7UZEuC1DJ4xtFhbpq/QRAj
+ * CyfNx3cDc7L2EtJWnCmivTFA9l8MF1ZPMDSVd4ecQ7B0xZIFQ5cSSFt7WGaJCsGM
+ * zYkU4Fp4IykQcWxdlNX7wJZRwQ2TZJFFglpTiFZdeq6I6Ad9An1Encpz5W8UJ4tv
+ * hmw=
+ * -----END CERTIFICATE-----
+ * 
+ */ +public final class FingerprintTrustManagerFactory extends SimpleTrustManagerFactory { + + private static final Pattern FINGERPRINT_PATTERN = Pattern.compile("^[0-9a-fA-F:]+$"); + private static final Pattern FINGERPRINT_STRIP_PATTERN = Pattern.compile(":"); + + /** + * Creates a builder for {@link FingerprintTrustManagerFactory}. + * + * @param algorithm a hash algorithm + * @return a builder + */ + public static FingerprintTrustManagerFactoryBuilder builder(String algorithm) { + return new FingerprintTrustManagerFactoryBuilder(algorithm); + } + + private final FastThreadLocal tlmd; + + private final TrustManager tm = new X509TrustManager() { + + @Override + public void checkClientTrusted(X509Certificate[] chain, String s) throws CertificateException { + checkTrusted("client", chain); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String s) throws CertificateException { + checkTrusted("server", chain); + } + + private void checkTrusted(String type, X509Certificate[] chain) throws CertificateException { + X509Certificate cert = chain[0]; + byte[] fingerprint = fingerprint(cert); + boolean found = false; + for (byte[] allowedFingerprint: fingerprints) { + if (Arrays.equals(fingerprint, allowedFingerprint)) { + found = true; + break; + } + } + + if (!found) { + throw new CertificateException( + type + " certificate with unknown fingerprint: " + cert.getSubjectDN()); + } + } + + private byte[] fingerprint(X509Certificate cert) throws CertificateEncodingException { + MessageDigest md = tlmd.get(); + md.reset(); + return md.digest(cert.getEncoded()); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + }; + + private final byte[][] fingerprints; + + /** + * Creates a new instance. + * + * @deprecated This deprecated constructor uses SHA-1 that is considered insecure. 
+ * It is recommended to specify a stronger hash algorithm, such as SHA-256, + * by calling {@link FingerprintTrustManagerFactory#builder(String)} method. + * + * @param fingerprints a list of SHA1 fingerprints in hexadecimal form + */ + @Deprecated + public FingerprintTrustManagerFactory(Iterable fingerprints) { + this("SHA1", toFingerprintArray(fingerprints)); + } + + /** + * Creates a new instance. + * + * @deprecated This deprecated constructor uses SHA-1 that is considered insecure. + * It is recommended to specify a stronger hash algorithm, such as SHA-256, + * by calling {@link FingerprintTrustManagerFactory#builder(String)} method. + * + * @param fingerprints a list of SHA1 fingerprints in hexadecimal form + */ + @Deprecated + public FingerprintTrustManagerFactory(String... fingerprints) { + this("SHA1", toFingerprintArray(Arrays.asList(fingerprints))); + } + + /** + * Creates a new instance. + * + * @deprecated This deprecated constructor uses SHA-1 that is considered insecure. + * It is recommended to specify a stronger hash algorithm, such as SHA-256, + * by calling {@link FingerprintTrustManagerFactory#builder(String)} method. + * + * @param fingerprints a list of SHA1 fingerprints + */ + @Deprecated + public FingerprintTrustManagerFactory(byte[]... fingerprints) { + this("SHA1", fingerprints); + } + + /** + * Creates a new instance. 
+ * + * @param algorithm a hash algorithm + * @param fingerprints a list of fingerprints + */ + FingerprintTrustManagerFactory(final String algorithm, byte[][] fingerprints) { + ObjectUtil.checkNotNull(algorithm, "algorithm"); + ObjectUtil.checkNotNull(fingerprints, "fingerprints"); + + if (fingerprints.length == 0) { + throw new IllegalArgumentException("No fingerprints provided"); + } + + // check early if the hash algorithm is available + final MessageDigest md; + try { + md = MessageDigest.getInstance(algorithm); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException( + String.format("Unsupported hash algorithm: %s", algorithm), e); + } + + int hashLength = md.getDigestLength(); + List list = new ArrayList(fingerprints.length); + for (byte[] f: fingerprints) { + if (f == null) { + break; + } + if (f.length != hashLength) { + throw new IllegalArgumentException( + String.format("malformed fingerprint (length is %d but expected %d): %s", + f.length, hashLength, ByteBufUtil.hexDump(Unpooled.wrappedBuffer(f)))); + } + list.add(f.clone()); + } + + this.tlmd = new FastThreadLocal() { + + @Override + protected MessageDigest initialValue() { + try { + return MessageDigest.getInstance(algorithm); + } catch (NoSuchAlgorithmException e) { + throw new IllegalArgumentException( + String.format("Unsupported hash algorithm: %s", algorithm), e); + } + } + }; + + this.fingerprints = list.toArray(new byte[0][]); + } + + static byte[][] toFingerprintArray(Iterable fingerprints) { + ObjectUtil.checkNotNull(fingerprints, "fingerprints"); + + List list = new ArrayList(); + for (String f: fingerprints) { + if (f == null) { + break; + } + + if (!FINGERPRINT_PATTERN.matcher(f).matches()) { + throw new IllegalArgumentException("malformed fingerprint: " + f); + } + f = FINGERPRINT_STRIP_PATTERN.matcher(f).replaceAll(""); + + list.add(StringUtil.decodeHexDump(f)); + } + + return list.toArray(new byte[0][]); + } + + @Override + protected void engineInit(KeyStore 
keyStore) throws Exception { } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception { } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { tm }; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryBuilder.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryBuilder.java new file mode 100644 index 0000000..a7e939b --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryBuilder.java @@ -0,0 +1,87 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +/** + * A builder for creating {@link FingerprintTrustManagerFactory}. + */ +public final class FingerprintTrustManagerFactoryBuilder { + + /** + * A hash algorithm for fingerprints. + */ + private final String algorithm; + + /** + * A list of fingerprints. + */ + private final List fingerprints = new ArrayList(); + + /** + * Creates a builder. 
+ * + * @param algorithm a hash algorithm + */ + FingerprintTrustManagerFactoryBuilder(String algorithm) { + this.algorithm = checkNotNull(algorithm, "algorithm"); + } + + /** + * Adds fingerprints. + * + * @param fingerprints a number of fingerprints + * @return the same builder + */ + public FingerprintTrustManagerFactoryBuilder fingerprints(CharSequence... fingerprints) { + return fingerprints(Arrays.asList(checkNotNull(fingerprints, "fingerprints"))); + } + + /** + * Adds fingerprints. + * + * @param fingerprints a number of fingerprints + * @return the same builder + */ + public FingerprintTrustManagerFactoryBuilder fingerprints(Iterable fingerprints) { + checkNotNull(fingerprints, "fingerprints"); + for (CharSequence fingerprint : fingerprints) { + checkNotNullWithIAE(fingerprint, "fingerprint"); + this.fingerprints.add(fingerprint.toString()); + } + return this; + } + + /** + * Creates a {@link FingerprintTrustManagerFactory}. + * + * @return a new {@link FingerprintTrustManagerFactory} + */ + public FingerprintTrustManagerFactory build() { + if (fingerprints.isEmpty()) { + throw new IllegalStateException("No fingerprints provided"); + } + return new FingerprintTrustManagerFactory(this.algorithm, + FingerprintTrustManagerFactory.toFingerprintArray(this.fingerprints)); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java new file mode 100644 index 0000000..6efe959 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/InsecureTrustManagerFactory.java @@ -0,0 +1,77 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import java.security.KeyStore; +import java.security.cert.X509Certificate; + +/** + * An insecure {@link TrustManagerFactory} that trusts all X.509 certificates without any verification. + *

+ * NOTE: + * Never use this {@link TrustManagerFactory} in production. + * It is purely for testing purposes, and thus it is very insecure. + *

+ */ +public final class InsecureTrustManagerFactory extends SimpleTrustManagerFactory { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(InsecureTrustManagerFactory.class); + + public static final TrustManagerFactory INSTANCE = new InsecureTrustManagerFactory(); + + private static final TrustManager tm = new X509TrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String s) { + if (logger.isDebugEnabled()) { + logger.debug("Accepting a client certificate: " + chain[0].getSubjectDN()); + } + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String s) { + if (logger.isDebugEnabled()) { + logger.debug("Accepting a server certificate: " + chain[0].getSubjectDN()); + } + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + }; + + private InsecureTrustManagerFactory() { } + + @Override + protected void engineInit(KeyStore keyStore) throws Exception { } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception { } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { tm }; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/KeyManagerFactoryWrapper.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/KeyManagerFactoryWrapper.java new file mode 100644 index 0000000..3ec2056 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/KeyManagerFactoryWrapper.java @@ -0,0 +1,43 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.util.internal.ObjectUtil; + +import java.security.KeyStore; +import javax.net.ssl.KeyManager; +import javax.net.ssl.ManagerFactoryParameters; + +public final class KeyManagerFactoryWrapper extends SimpleKeyManagerFactory { + private final KeyManager km; + + public KeyManagerFactoryWrapper(KeyManager km) { + this.km = ObjectUtil.checkNotNull(km, "km"); + } + + @Override + protected void engineInit(KeyStore keyStore, char[] var2) throws Exception { } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) + throws Exception { } + + @Override + protected KeyManager[] engineGetKeyManagers() { + return new KeyManager[] {km}; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyJavaxX509Certificate.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyJavaxX509Certificate.java new file mode 100644 index 0000000..40efcb9 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyJavaxX509Certificate.java @@ -0,0 +1,202 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl.util; + +import io.netty.util.internal.ObjectUtil; + +import java.io.ByteArrayInputStream; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateExpiredException; +import java.security.cert.CertificateFactory; +import java.security.cert.CertificateNotYetValidException; +import java.security.cert.X509Certificate; +import java.math.BigInteger; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.Principal; +import java.security.PublicKey; +import java.security.SignatureException; +import java.util.Date; +import java.util.Set; + +public final class LazyJavaxX509Certificate extends X509Certificate { + private final byte[] bytes; + private X509Certificate wrapped; + + /** + * Creates a new instance which will lazy parse the given bytes. Be aware that the bytes will not be cloned. 
+ */ + public LazyJavaxX509Certificate(byte[] bytes) { + this.bytes = ObjectUtil.checkNotNull(bytes, "bytes"); + } + + @Override + public void checkValidity() throws CertificateExpiredException, CertificateNotYetValidException { + unwrap().checkValidity(); + } + + @Override + public void checkValidity(Date date) throws CertificateExpiredException, CertificateNotYetValidException { + unwrap().checkValidity(date); + } + + @Override + public int getVersion() { + return unwrap().getVersion(); + } + + @Override + public BigInteger getSerialNumber() { + return unwrap().getSerialNumber(); + } + + @Override + public Principal getIssuerDN() { + return unwrap().getIssuerDN(); + } + + @Override + public Principal getSubjectDN() { + return unwrap().getSubjectDN(); + } + + @Override + public Date getNotBefore() { + return unwrap().getNotBefore(); + } + + @Override + public Date getNotAfter() { + return unwrap().getNotAfter(); + } + + @Override + public byte[] getTBSCertificate() throws CertificateEncodingException { + return new byte[0]; // TODO + } + + @Override + public byte[] getSignature() { + return new byte[0]; + } + + @Override + public String getSigAlgName() { + return unwrap().getSigAlgName(); + } + + @Override + public String getSigAlgOID() { + return unwrap().getSigAlgOID(); + } + + @Override + public byte[] getSigAlgParams() { + return unwrap().getSigAlgParams(); + } + + @Override + public boolean[] getIssuerUniqueID() { + return new boolean[0]; + } + + @Override + public boolean[] getSubjectUniqueID() { + return new boolean[0]; + } + + @Override + public boolean[] getKeyUsage() { + return new boolean[0]; + } + + @Override + public int getBasicConstraints() { + return 0; + } + + @Override + public byte[] getEncoded() { + return bytes.clone(); + } + + /** + * Return the underyling {@code byte[]} without cloning it first. This {@code byte[]} must never + * be mutated. 
+ */ + byte[] getBytes() { + return bytes; + } + + @Override + public void verify(PublicKey key) + throws CertificateException, NoSuchAlgorithmException, InvalidKeyException, NoSuchProviderException, + SignatureException { + unwrap().verify(key); + } + + @Override + public void verify(PublicKey key, String sigProvider) + throws CertificateException, NoSuchAlgorithmException, InvalidKeyException, NoSuchProviderException, + SignatureException { + unwrap().verify(key, sigProvider); + } + + @Override + public String toString() { + return unwrap().toString(); + } + + @Override + public PublicKey getPublicKey() { + return unwrap().getPublicKey(); + } + + private X509Certificate unwrap() { + X509Certificate wrapped = this.wrapped; + if (wrapped == null) { + try { + CertificateFactory cf = CertificateFactory.getInstance("X.509"); + wrapped = this.wrapped = (X509Certificate) cf.generateCertificate(new ByteArrayInputStream(bytes)); + } catch (CertificateException e) { + throw new IllegalStateException(e); + } + } + return wrapped; + } + + @Override + public boolean hasUnsupportedCriticalExtension() { + return false; + } + + @Override + public Set getCriticalExtensionOIDs() { + return null; + } + + @Override + public Set getNonCriticalExtensionOIDs() { + return null; + } + + @Override + public byte[] getExtensionValue(String s) { + return new byte[0]; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java new file mode 100644 index 0000000..8402577 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/LazyX509Certificate.java @@ -0,0 +1,242 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl.util; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SuppressJava6Requirement; + +import javax.security.auth.x500.X500Principal; +import java.io.ByteArrayInputStream; +import java.math.BigInteger; +import java.security.InvalidKeyException; +import java.security.NoSuchAlgorithmException; +import java.security.NoSuchProviderException; +import java.security.Principal; +import java.security.Provider; +import java.security.PublicKey; +import java.security.SignatureException; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateExpiredException; +import java.security.cert.CertificateFactory; +import java.security.cert.CertificateNotYetValidException; +import java.security.cert.CertificateParsingException; +import java.security.cert.X509Certificate; +import java.util.Collection; +import java.util.Date; +import java.util.List; +import java.util.Set; + +public final class LazyX509Certificate extends X509Certificate { + + static final CertificateFactory X509_CERT_FACTORY; + static { + try { + X509_CERT_FACTORY = CertificateFactory.getInstance("X.509"); + } catch (CertificateException e) { + throw new ExceptionInInitializerError(e); + } + } + + private final byte[] bytes; + private X509Certificate wrapped; + + /** + * Creates a new instance which will lazy parse the given bytes. Be aware that the bytes will not be cloned. 
+ */ + public LazyX509Certificate(byte[] bytes) { + this.bytes = ObjectUtil.checkNotNull(bytes, "bytes"); + } + + @Override + public void checkValidity() throws CertificateExpiredException, CertificateNotYetValidException { + unwrap().checkValidity(); + } + + @Override + public void checkValidity(Date date) throws CertificateExpiredException, CertificateNotYetValidException { + unwrap().checkValidity(date); + } + + @Override + public X500Principal getIssuerX500Principal() { + return unwrap().getIssuerX500Principal(); + } + + @Override + public X500Principal getSubjectX500Principal() { + return unwrap().getSubjectX500Principal(); + } + + @Override + public List getExtendedKeyUsage() throws CertificateParsingException { + return unwrap().getExtendedKeyUsage(); + } + + @Override + public Collection> getSubjectAlternativeNames() throws CertificateParsingException { + return unwrap().getSubjectAlternativeNames(); + } + + @Override + public Collection> getIssuerAlternativeNames() throws CertificateParsingException { + return unwrap().getSubjectAlternativeNames(); + } + + // No @Override annotation as it was only introduced in Java8. 
+ @SuppressJava6Requirement(reason = "Can only be called from Java8 as class is package-private") + public void verify(PublicKey key, Provider sigProvider) + throws CertificateException, NoSuchAlgorithmException, InvalidKeyException, SignatureException { + unwrap().verify(key, sigProvider); + } + + @Override + public int getVersion() { + return unwrap().getVersion(); + } + + @Override + public BigInteger getSerialNumber() { + return unwrap().getSerialNumber(); + } + + @Override + public Principal getIssuerDN() { + return unwrap().getIssuerDN(); + } + + @Override + public Principal getSubjectDN() { + return unwrap().getSubjectDN(); + } + + @Override + public Date getNotBefore() { + return unwrap().getNotBefore(); + } + + @Override + public Date getNotAfter() { + return unwrap().getNotAfter(); + } + + @Override + public byte[] getTBSCertificate() throws CertificateEncodingException { + return unwrap().getTBSCertificate(); + } + + @Override + public byte[] getSignature() { + return unwrap().getSignature(); + } + + @Override + public String getSigAlgName() { + return unwrap().getSigAlgName(); + } + + @Override + public String getSigAlgOID() { + return unwrap().getSigAlgOID(); + } + + @Override + public byte[] getSigAlgParams() { + return unwrap().getSigAlgParams(); + } + + @Override + public boolean[] getIssuerUniqueID() { + return unwrap().getIssuerUniqueID(); + } + + @Override + public boolean[] getSubjectUniqueID() { + return unwrap().getSubjectUniqueID(); + } + + @Override + public boolean[] getKeyUsage() { + return unwrap().getKeyUsage(); + } + + @Override + public int getBasicConstraints() { + return unwrap().getBasicConstraints(); + } + + @Override + public byte[] getEncoded() { + return bytes.clone(); + } + + @Override + public void verify(PublicKey key) + throws CertificateException, NoSuchAlgorithmException, + InvalidKeyException, NoSuchProviderException, SignatureException { + unwrap().verify(key); + } + + @Override + public void verify(PublicKey key, String 
sigProvider) + throws CertificateException, NoSuchAlgorithmException, InvalidKeyException, + NoSuchProviderException, SignatureException { + unwrap().verify(key, sigProvider); + } + + @Override + public String toString() { + return unwrap().toString(); + } + + @Override + public PublicKey getPublicKey() { + return unwrap().getPublicKey(); + } + + @Override + public boolean hasUnsupportedCriticalExtension() { + return unwrap().hasUnsupportedCriticalExtension(); + } + + @Override + public Set getCriticalExtensionOIDs() { + return unwrap().getCriticalExtensionOIDs(); + } + + @Override + public Set getNonCriticalExtensionOIDs() { + return unwrap().getNonCriticalExtensionOIDs(); + } + + @Override + public byte[] getExtensionValue(String oid) { + return unwrap().getExtensionValue(oid); + } + + private X509Certificate unwrap() { + X509Certificate wrapped = this.wrapped; + if (wrapped == null) { + try { + wrapped = this.wrapped = (X509Certificate) X509_CERT_FACTORY.generateCertificate( + new ByteArrayInputStream(bytes)); + } catch (CertificateException e) { + throw new IllegalStateException(e); + } + } + return wrapped; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java new file mode 100644 index 0000000..959c22f --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SelfSignedCertificate.java @@ -0,0 +1,406 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.handler.codec.base64.Base64; +import io.netty.util.CharsetUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.File; +import java.io.FileInputStream; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.OutputStream; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.SecureRandom; +import java.security.cert.CertificateEncodingException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Date; + +/** + * Generates a temporary self-signed certificate for testing purposes. + *

+ * NOTE: + * Never use the certificate and private key generated by this class in production. + * It is purely for testing purposes, and thus it is very insecure. + * It even uses an insecure pseudo-random generator for faster generation internally. + *

+ * An X.509 certificate file and a EC/RSA private key file are generated in a system's temporary directory using + * {@link java.io.File#createTempFile(String, String)}, and they are deleted when the JVM exits using + * {@link java.io.File#deleteOnExit()}. + *

+ * The certificate and key pair are generated with Bouncy Castle
+ * ({@code org.bouncycastle.cert}); if Bouncy Castle PKIX is not available on the
+ * classpath, certificate generation fails.
+ *

+ */ +public final class SelfSignedCertificate { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SelfSignedCertificate.class); + + /** Current time minus 1 year, just in case software clock goes back due to time synchronization */ + private static final Date DEFAULT_NOT_BEFORE = new Date(SystemPropertyUtil.getLong( + "io.netty.selfSignedCertificate.defaultNotBefore", System.currentTimeMillis() - 86400000L * 365)); + /** The maximum possible value in X.509 specification: 9999-12-31 23:59:59 */ + private static final Date DEFAULT_NOT_AFTER = new Date(SystemPropertyUtil.getLong( + "io.netty.selfSignedCertificate.defaultNotAfter", 253402300799000L)); + + /** + * FIPS 140-2 encryption requires the RSA key length to be 2048 bits or greater. + * Let's use that as a sane default but allow the default to be set dynamically + * for those that need more stringent security requirements. + */ + private static final int DEFAULT_KEY_LENGTH_BITS = + SystemPropertyUtil.getInt("io.netty.handler.ssl.util.selfSignedKeyStrength", 2048); + + private final File certificate; + private final File privateKey; + private final X509Certificate cert; + private final PrivateKey key; + + /** + * Creates a new instance. + *

Algorithm: RSA

+ */ + public SelfSignedCertificate() throws CertificateException { + this(DEFAULT_NOT_BEFORE, DEFAULT_NOT_AFTER, "RSA", DEFAULT_KEY_LENGTH_BITS); + } + + /** + * Creates a new instance. + *

Algorithm: RSA

+ * + * @param notBefore Certificate is not valid before this time + * @param notAfter Certificate is not valid after this time + */ + public SelfSignedCertificate(Date notBefore, Date notAfter) + throws CertificateException { + this("localhost", notBefore, notAfter, "RSA", DEFAULT_KEY_LENGTH_BITS); + } + + /** + * Creates a new instance. + * + * @param notBefore Certificate is not valid before this time + * @param notAfter Certificate is not valid after this time + * @param algorithm Key pair algorithm + * @param bits the number of bits of the generated private key + */ + public SelfSignedCertificate(Date notBefore, Date notAfter, String algorithm, int bits) + throws CertificateException { + this("localhost", notBefore, notAfter, algorithm, bits); + } + + /** + * Creates a new instance. + *

Algorithm: RSA

+ * + * @param fqdn a fully qualified domain name + */ + public SelfSignedCertificate(String fqdn) throws CertificateException { + this(fqdn, DEFAULT_NOT_BEFORE, DEFAULT_NOT_AFTER, "RSA", DEFAULT_KEY_LENGTH_BITS); + } + + /** + * Creates a new instance. + * + * @param fqdn a fully qualified domain name + * @param algorithm Key pair algorithm + * @param bits the number of bits of the generated private key + */ + public SelfSignedCertificate(String fqdn, String algorithm, int bits) throws CertificateException { + this(fqdn, DEFAULT_NOT_BEFORE, DEFAULT_NOT_AFTER, algorithm, bits); + } + + /** + * Creates a new instance. + *

Algorithm: RSA

+ * + * @param fqdn a fully qualified domain name + * @param notBefore Certificate is not valid before this time + * @param notAfter Certificate is not valid after this time + */ + public SelfSignedCertificate(String fqdn, Date notBefore, Date notAfter) throws CertificateException { + // Bypass entropy collection by using insecure random generator. + // We just want to generate it without any delay because it's for testing purposes only. + this(fqdn, ThreadLocalInsecureRandom.current(), DEFAULT_KEY_LENGTH_BITS, notBefore, notAfter, "RSA"); + } + + /** + * Creates a new instance. + * + * @param fqdn a fully qualified domain name + * @param notBefore Certificate is not valid before this time + * @param notAfter Certificate is not valid after this time + * @param algorithm Key pair algorithm + * @param bits the number of bits of the generated private key + */ + public SelfSignedCertificate(String fqdn, Date notBefore, Date notAfter, String algorithm, int bits) + throws CertificateException { + // Bypass entropy collection by using insecure random generator. + // We just want to generate it without any delay because it's for testing purposes only. + this(fqdn, ThreadLocalInsecureRandom.current(), bits, notBefore, notAfter, algorithm); + } + + /** + * Creates a new instance. + *

Algorithm: RSA

+ * + * @param fqdn a fully qualified domain name + * @param random the {@link SecureRandom} to use + * @param bits the number of bits of the generated private key + */ + public SelfSignedCertificate(String fqdn, SecureRandom random, int bits) + throws CertificateException { + this(fqdn, random, bits, DEFAULT_NOT_BEFORE, DEFAULT_NOT_AFTER, "RSA"); + } + + /** + * Creates a new instance. + * + * @param fqdn a fully qualified domain name + * @param random the {@link SecureRandom} to use + * @param algorithm Key pair algorithm + * @param bits the number of bits of the generated private key + */ + public SelfSignedCertificate(String fqdn, SecureRandom random, String algorithm, int bits) + throws CertificateException { + this(fqdn, random, bits, DEFAULT_NOT_BEFORE, DEFAULT_NOT_AFTER, algorithm); + } + + /** + * Creates a new instance. + *

Algorithm: RSA

+ * + * @param fqdn a fully qualified domain name + * @param random the {@link SecureRandom} to use + * @param bits the number of bits of the generated private key + * @param notBefore Certificate is not valid before this time + * @param notAfter Certificate is not valid after this time + */ + public SelfSignedCertificate(String fqdn, SecureRandom random, int bits, Date notBefore, Date notAfter) + throws CertificateException { + this(fqdn, random, bits, notBefore, notAfter, "RSA"); + } + + /** + * Creates a new instance. + * + * @param fqdn a fully qualified domain name + * @param random the {@link SecureRandom} to use + * @param bits the number of bits of the generated private key + * @param notBefore Certificate is not valid before this time + * @param notAfter Certificate is not valid after this time + * @param algorithm Key pair algorithm + */ + public SelfSignedCertificate(String fqdn, SecureRandom random, int bits, Date notBefore, Date notAfter, + String algorithm) throws CertificateException { + + if (!"EC".equalsIgnoreCase(algorithm) && !"RSA".equalsIgnoreCase(algorithm)) { + throw new IllegalArgumentException("Algorithm not valid: " + algorithm); + } + + final KeyPair keypair; + try { + KeyPairGenerator keyGen = KeyPairGenerator.getInstance(algorithm); + keyGen.initialize(bits, random); + keypair = keyGen.generateKeyPair(); + } catch (NoSuchAlgorithmException e) { + // Should not reach here because every Java implementation must have RSA and EC key pair generator. + throw new Error(e); + } + + String[] paths = null; + try { + // Try Bouncy Castle first as otherwise we will see an IllegalAccessError on more recent JDKs. 
+ paths = BouncyCastleSelfSignedCertGenerator.generate( + fqdn, keypair, random, notBefore, notAfter, algorithm); + } catch (Throwable t) { + if (!isBouncyCastleAvailable()) { + logger.debug("Failed to generate a self-signed X.509 certificate because " + + "BouncyCastle PKIX is not available in classpath"); + } else { + logger.debug("Failed to generate a self-signed X.509 certificate using Bouncy Castle:", t); + } + } + + certificate = new File(paths[0]); + privateKey = new File(paths[1]); + key = keypair.getPrivate(); + FileInputStream certificateInput = null; + try { + certificateInput = new FileInputStream(certificate); + cert = (X509Certificate) CertificateFactory.getInstance("X509").generateCertificate(certificateInput); + } catch (Exception e) { + throw new CertificateEncodingException(e); + } finally { + if (certificateInput != null) { + try { + certificateInput.close(); + } catch (IOException e) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to close a file: " + certificate, e); + } + } + } + } + } + + /** + * Returns the generated X.509 certificate file in PEM format. + */ + public File certificate() { + return certificate; + } + + /** + * Returns the generated EC/RSA private key file in PEM format. + */ + public File privateKey() { + return privateKey; + } + + /** + * Returns the generated X.509 certificate. + */ + public X509Certificate cert() { + return cert; + } + + /** + * Returns the generated EC/RSA private key. + */ + public PrivateKey key() { + return key; + } + + /** + * Deletes the generated X.509 certificate file and EC/RSA private key file. + */ + public void delete() { + safeDelete(certificate); + safeDelete(privateKey); + } + + static String[] newSelfSignedCertificate( + String fqdn, PrivateKey key, X509Certificate cert) throws IOException, CertificateEncodingException { + // Encode the private key into a file. 
+ ByteBuf wrappedBuf = Unpooled.wrappedBuffer(key.getEncoded()); + ByteBuf encodedBuf; + final String keyText; + try { + encodedBuf = Base64.encode(wrappedBuf, true); + try { + keyText = "-----BEGIN PRIVATE KEY-----\n" + + encodedBuf.toString(CharsetUtil.US_ASCII) + + "\n-----END PRIVATE KEY-----\n"; + } finally { + encodedBuf.release(); + } + } finally { + wrappedBuf.release(); + } + + // Change all asterisk to 'x' for file name safety. + fqdn = fqdn.replaceAll("[^\\w.-]", "x"); + + File keyFile = PlatformDependent.createTempFile("keyutil_" + fqdn + '_', ".key", null); + keyFile.deleteOnExit(); + + OutputStream keyOut = new FileOutputStream(keyFile); + try { + keyOut.write(keyText.getBytes(CharsetUtil.US_ASCII)); + keyOut.close(); + keyOut = null; + } finally { + if (keyOut != null) { + safeClose(keyFile, keyOut); + safeDelete(keyFile); + } + } + + wrappedBuf = Unpooled.wrappedBuffer(cert.getEncoded()); + final String certText; + try { + encodedBuf = Base64.encode(wrappedBuf, true); + try { + // Encode the certificate into a CRT file. 
+ certText = "-----BEGIN CERTIFICATE-----\n" + + encodedBuf.toString(CharsetUtil.US_ASCII) + + "\n-----END CERTIFICATE-----\n"; + } finally { + encodedBuf.release(); + } + } finally { + wrappedBuf.release(); + } + + File certFile = PlatformDependent.createTempFile("keyutil_" + fqdn + '_', ".crt", null); + certFile.deleteOnExit(); + + OutputStream certOut = new FileOutputStream(certFile); + try { + certOut.write(certText.getBytes(CharsetUtil.US_ASCII)); + certOut.close(); + certOut = null; + } finally { + if (certOut != null) { + safeClose(certFile, certOut); + safeDelete(certFile); + safeDelete(keyFile); + } + } + + return new String[] { certFile.getPath(), keyFile.getPath() }; + } + + private static void safeDelete(File certFile) { + if (!certFile.delete()) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to delete a file: " + certFile); + } + } + } + + private static void safeClose(File keyFile, OutputStream keyOut) { + try { + keyOut.close(); + } catch (IOException e) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to close a file: " + keyFile, e); + } + } + } + + private static boolean isBouncyCastleAvailable() { + try { + Class.forName("org.bouncycastle.cert.X509v3CertificateBuilder"); + return true; + } catch (ClassNotFoundException e) { + return false; + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleKeyManagerFactory.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleKeyManagerFactory.java new file mode 100644 index 0000000..befc869 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleKeyManagerFactory.java @@ -0,0 +1,154 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SuppressJava6Requirement; +import java.security.InvalidAlgorithmParameterException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.Provider; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509KeyManager; + +/** + * Helps to implement a custom {@link KeyManagerFactory}. + */ +public abstract class SimpleKeyManagerFactory extends KeyManagerFactory { + + private static final Provider PROVIDER = new Provider("", 0.0, "") { + private static final long serialVersionUID = -2680540247105807895L; + }; + + /** + * {@link SimpleKeyManagerFactorySpi} must have a reference to {@link SimpleKeyManagerFactory} + * to delegate its callbacks back to {@link SimpleKeyManagerFactory}. However, it is impossible to do so, + * because {@link KeyManagerFactory} requires {@link KeyManagerFactorySpi} at construction time and + * does not provide a way to access it later. + * + * To work around this issue, we use an ugly hack which uses a {@link FastThreadLocal }. 
+ */ + private static final FastThreadLocal CURRENT_SPI = + new FastThreadLocal() { + @Override + protected SimpleKeyManagerFactorySpi initialValue() { + return new SimpleKeyManagerFactorySpi(); + } + }; + + /** + * Creates a new instance. + */ + protected SimpleKeyManagerFactory() { + this(StringUtil.EMPTY_STRING); + } + + /** + * Creates a new instance. + * + * @param name the name of this {@link KeyManagerFactory} + */ + protected SimpleKeyManagerFactory(String name) { + super(CURRENT_SPI.get(), PROVIDER, ObjectUtil.checkNotNull(name, "name")); + CURRENT_SPI.get().init(this); + CURRENT_SPI.remove(); + } + + /** + * Initializes this factory with a source of certificate authorities and related key material. + * + * @see KeyManagerFactorySpi#engineInit(KeyStore, char[]) + */ + protected abstract void engineInit(KeyStore keyStore, char[] var2) throws Exception; + + /** + * Initializes this factory with a source of provider-specific key material. + * + * @see KeyManagerFactorySpi#engineInit(ManagerFactoryParameters) + */ + protected abstract void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception; + + /** + * Returns one key manager for each type of key material. 
+ * + * @see KeyManagerFactorySpi#engineGetKeyManagers() + */ + protected abstract KeyManager[] engineGetKeyManagers(); + + private static final class SimpleKeyManagerFactorySpi extends KeyManagerFactorySpi { + + private SimpleKeyManagerFactory parent; + private volatile KeyManager[] keyManagers; + + void init(SimpleKeyManagerFactory parent) { + this.parent = parent; + } + + @Override + protected void engineInit(KeyStore keyStore, char[] pwd) throws KeyStoreException { + try { + parent.engineInit(keyStore, pwd); + } catch (KeyStoreException e) { + throw e; + } catch (Exception e) { + throw new KeyStoreException(e); + } + } + + @Override + protected void engineInit( + ManagerFactoryParameters managerFactoryParameters) throws InvalidAlgorithmParameterException { + try { + parent.engineInit(managerFactoryParameters); + } catch (InvalidAlgorithmParameterException e) { + throw e; + } catch (Exception e) { + throw new InvalidAlgorithmParameterException(e); + } + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + KeyManager[] keyManagers = this.keyManagers; + if (keyManagers == null) { + keyManagers = parent.engineGetKeyManagers(); + if (PlatformDependent.javaVersion() >= 7) { + wrapIfNeeded(keyManagers); + } + this.keyManagers = keyManagers; + } + return keyManagers.clone(); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private static void wrapIfNeeded(KeyManager[] keyManagers) { + for (int i = 0; i < keyManagers.length; i++) { + final KeyManager tm = keyManagers[i]; + if (tm instanceof X509KeyManager && !(tm instanceof X509ExtendedKeyManager)) { + keyManagers[i] = new X509KeyManagerWrapper((X509KeyManager) tm); + } + } + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java new file mode 100644 index 0000000..c6d4b8b --- /dev/null +++ 
b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/SimpleTrustManagerFactory.java @@ -0,0 +1,156 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.util.concurrent.FastThreadLocal; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SuppressJava6Requirement; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManagerFactorySpi; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import java.security.InvalidAlgorithmParameterException; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.Provider; + +/** + * Helps to implement a custom {@link TrustManagerFactory}. + */ +public abstract class SimpleTrustManagerFactory extends TrustManagerFactory { + + private static final Provider PROVIDER = new Provider("", 0.0, "") { + private static final long serialVersionUID = -2680540247105807895L; + }; + + /** + * {@link SimpleTrustManagerFactorySpi} must have a reference to {@link SimpleTrustManagerFactory} + * to delegate its callbacks back to {@link SimpleTrustManagerFactory}. 
However, it is impossible to do so, + * because {@link TrustManagerFactory} requires {@link TrustManagerFactorySpi} at construction time and + * does not provide a way to access it later. + * + * To work around this issue, we use an ugly hack which uses a {@link ThreadLocal}. + */ + private static final FastThreadLocal CURRENT_SPI = + new FastThreadLocal() { + @Override + protected SimpleTrustManagerFactorySpi initialValue() { + return new SimpleTrustManagerFactorySpi(); + } + }; + + /** + * Creates a new instance. + */ + protected SimpleTrustManagerFactory() { + this(""); + } + + /** + * Creates a new instance. + * + * @param name the name of this {@link TrustManagerFactory} + */ + protected SimpleTrustManagerFactory(String name) { + super(CURRENT_SPI.get(), PROVIDER, name); + CURRENT_SPI.get().init(this); + CURRENT_SPI.remove(); + + ObjectUtil.checkNotNull(name, "name"); + } + + /** + * Initializes this factory with a source of certificate authorities and related trust material. + * + * @see TrustManagerFactorySpi#engineInit(KeyStore) + */ + protected abstract void engineInit(KeyStore keyStore) throws Exception; + + /** + * Initializes this factory with a source of provider-specific key material. + * + * @see TrustManagerFactorySpi#engineInit(ManagerFactoryParameters) + */ + protected abstract void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception; + + /** + * Returns one trust manager for each type of trust material. 
+ * + * @see TrustManagerFactorySpi#engineGetTrustManagers() + */ + protected abstract TrustManager[] engineGetTrustManagers(); + + static final class SimpleTrustManagerFactorySpi extends TrustManagerFactorySpi { + + private SimpleTrustManagerFactory parent; + private volatile TrustManager[] trustManagers; + + void init(SimpleTrustManagerFactory parent) { + this.parent = parent; + } + + @Override + protected void engineInit(KeyStore keyStore) throws KeyStoreException { + try { + parent.engineInit(keyStore); + } catch (KeyStoreException e) { + throw e; + } catch (Exception e) { + throw new KeyStoreException(e); + } + } + + @Override + protected void engineInit( + ManagerFactoryParameters managerFactoryParameters) throws InvalidAlgorithmParameterException { + try { + parent.engineInit(managerFactoryParameters); + } catch (InvalidAlgorithmParameterException e) { + throw e; + } catch (Exception e) { + throw new InvalidAlgorithmParameterException(e); + } + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + TrustManager[] trustManagers = this.trustManagers; + if (trustManagers == null) { + trustManagers = parent.engineGetTrustManagers(); + if (PlatformDependent.javaVersion() >= 7) { + wrapIfNeeded(trustManagers); + } + this.trustManagers = trustManagers; + } + return trustManagers.clone(); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private static void wrapIfNeeded(TrustManager[] trustManagers) { + for (int i = 0; i < trustManagers.length; i++) { + final TrustManager tm = trustManagers[i]; + if (tm instanceof X509TrustManager && !(tm instanceof X509ExtendedTrustManager)) { + trustManagers[i] = new X509TrustManagerWrapper((X509TrustManager) tm); + } + } + } + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/ThreadLocalInsecureRandom.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/ThreadLocalInsecureRandom.java new file mode 100644 index 0000000..980938b --- /dev/null 
+++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/ThreadLocalInsecureRandom.java @@ -0,0 +1,101 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.util.internal.PlatformDependent; + +import java.security.SecureRandom; +import java.util.Random; + +/** + * Insecure {@link SecureRandom} which relies on {@link PlatformDependent#threadLocalRandom()} for random number + * generation. 
+ */ +final class ThreadLocalInsecureRandom extends SecureRandom { + + private static final long serialVersionUID = -8209473337192526191L; + + private static final SecureRandom INSTANCE = new ThreadLocalInsecureRandom(); + + static SecureRandom current() { + return INSTANCE; + } + + private ThreadLocalInsecureRandom() { } + + @Override + public String getAlgorithm() { + return "insecure"; + } + + @Override + public void setSeed(byte[] seed) { } + + @Override + public void setSeed(long seed) { } + + @Override + public void nextBytes(byte[] bytes) { + random().nextBytes(bytes); + } + + @Override + public byte[] generateSeed(int numBytes) { + byte[] seed = new byte[numBytes]; + random().nextBytes(seed); + return seed; + } + + @Override + public int nextInt() { + return random().nextInt(); + } + + @Override + public int nextInt(int n) { + return random().nextInt(n); + } + + @Override + public boolean nextBoolean() { + return random().nextBoolean(); + } + + @Override + public long nextLong() { + return random().nextLong(); + } + + @Override + public float nextFloat() { + return random().nextFloat(); + } + + @Override + public double nextDouble() { + return random().nextDouble(); + } + + @Override + public double nextGaussian() { + return random().nextGaussian(); + } + + private static Random random() { + return PlatformDependent.threadLocalRandom(); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/TrustManagerFactoryWrapper.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/TrustManagerFactoryWrapper.java new file mode 100644 index 0000000..e28df7f --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/TrustManagerFactoryWrapper.java @@ -0,0 +1,43 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import io.netty.util.internal.ObjectUtil; + +import java.security.KeyStore; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.TrustManager; + +public final class TrustManagerFactoryWrapper extends SimpleTrustManagerFactory { + private final TrustManager tm; + + public TrustManagerFactoryWrapper(TrustManager tm) { + this.tm = ObjectUtil.checkNotNull(tm, "tm"); + } + + @Override + protected void engineInit(KeyStore keyStore) throws Exception { } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) + throws Exception { } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] {tm}; + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509KeyManagerWrapper.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509KeyManagerWrapper.java new file mode 100644 index 0000000..0c95ec6 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509KeyManagerWrapper.java @@ -0,0 +1,78 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.util; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +import io.netty.util.internal.SuppressJava6Requirement; +import java.net.Socket; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509KeyManager; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class X509KeyManagerWrapper extends X509ExtendedKeyManager { + + private final X509KeyManager delegate; + + X509KeyManagerWrapper(X509KeyManager delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public String[] getClientAliases(String var1, Principal[] var2) { + return delegate.getClientAliases(var1, var2); + } + + @Override + public String chooseClientAlias(String[] var1, Principal[] var2, Socket var3) { + return delegate.chooseClientAlias(var1, var2, var3); + } + + @Override + public String[] getServerAliases(String var1, Principal[] var2) { + return delegate.getServerAliases(var1, var2); + } + + @Override + public String chooseServerAlias(String var1, Principal[] var2, Socket var3) { + return delegate.chooseServerAlias(var1, var2, var3); + } + + @Override + public X509Certificate[] getCertificateChain(String var1) { + return delegate.getCertificateChain(var1); + } + + @Override + public PrivateKey getPrivateKey(String var1) { + return delegate.getPrivateKey(var1); + } + + @Override + public String 
chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { + return delegate.chooseClientAlias(keyType, issuers, null); + } + + @Override + public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { + return delegate.chooseServerAlias(keyType, issuers, null); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509TrustManagerWrapper.java b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509TrustManagerWrapper.java new file mode 100644 index 0000000..acdab47 --- /dev/null +++ b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/X509TrustManagerWrapper.java @@ -0,0 +1,76 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl.util; + +import io.netty.util.internal.SuppressJava6Requirement; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import java.net.Socket; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +import static io.netty.util.internal.ObjectUtil.*; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class X509TrustManagerWrapper extends X509ExtendedTrustManager { + + private final X509TrustManager delegate; + + X509TrustManagerWrapper(X509TrustManager delegate) { + this.delegate = checkNotNull(delegate, "delegate"); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String s) throws CertificateException { + delegate.checkClientTrusted(chain, s); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String s, Socket socket) + throws CertificateException { + delegate.checkClientTrusted(chain, s); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String s, SSLEngine sslEngine) + throws CertificateException { + delegate.checkClientTrusted(chain, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String s) throws CertificateException { + delegate.checkServerTrusted(chain, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String s, Socket socket) + throws CertificateException { + delegate.checkServerTrusted(chain, s); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String s, SSLEngine sslEngine) + throws CertificateException { + delegate.checkServerTrusted(chain, s); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return delegate.getAcceptedIssuers(); + } +} diff --git a/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/package-info.java 
b/netty-handler-ssl/src/main/java/io/netty/handler/ssl/util/package-info.java
@@ -0,0 +1,20 @@
+/*
+ * Copyright 2012 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+/**
+ * Utility classes that helps easier development of TLS/SSL applications.
+ */
+package io.netty.handler.ssl.util;
diff --git a/netty-handler-ssl/src/main/java/module-info.java b/netty-handler-ssl/src/main/java/module-info.java
new file mode 100644
index 0000000..c460f53
--- /dev/null
+++ b/netty-handler-ssl/src/main/java/module-info.java
@@ -0,0 +1,14 @@
+// JPMS descriptor for the forked netty-handler-ssl module.
+module org.xbib.io.netty.handler.ssl {
+    exports io.netty.handler.ssl;
+    exports io.netty.handler.ssl.util;
+    exports io.netty.handler.ssl.ocsp;
+    requires org.xbib.io.netty.buffer;
+    requires org.xbib.io.netty.channel;
+    requires org.xbib.io.netty.channel.unix;
+    requires org.xbib.io.netty.handler.codec;
+    requires org.xbib.io.netty.internal.tcnative;
+    requires org.xbib.io.netty.util;
+    // NOTE(review): BouncyCastle/Conscrypt appear to back optional SslProvider
+    // implementations; consider 'requires static' if they should stay optional -- verify.
+    requires org.bouncycastle.pkix;
+    requires org.bouncycastle.provider;
+    requires org.conscrypt;
+}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java
new file mode 100644
index 0000000..498e734
--- /dev/null
+++
b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/AmazonCorrettoSslEngineTest.java
@@ -0,0 +1,103 @@
+/*
+ * Copyright 2019 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.handler.ssl;
+
+import com.amazon.corretto.crypto.provider.AmazonCorrettoCryptoProvider;
+import com.amazon.corretto.crypto.provider.SelfTestStatus;
+import org.junit.jupiter.api.AfterEach;
+import org.junit.jupiter.api.BeforeEach;
+import org.junit.jupiter.api.Disabled;
+import org.junit.jupiter.api.condition.DisabledIf;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.MethodSource;
+
+import javax.crypto.Cipher;
+import java.security.Security;
+
+import static org.junit.jupiter.api.Assertions.assertEquals;
+import static org.junit.jupiter.api.Assertions.assertNull;
+
+
+/**
+ * Runs the inherited {@link SSLEngineTest} suite against the JDK SslProvider with the
+ * Amazon Corretto Crypto Provider (ACCP) installed as the highest-priority JCA provider.
+ */
+@DisabledIf("checkIfAccpIsDisabled")
+public class AmazonCorrettoSslEngineTest extends SSLEngineTest {
+
+    // Skips the whole class when ACCP failed to load or its self-tests did not pass.
+    static boolean checkIfAccpIsDisabled() {
+        return AmazonCorrettoCryptoProvider.INSTANCE.getLoadingError() != null ||
+            !AmazonCorrettoCryptoProvider.INSTANCE.runSelfTests().equals(SelfTestStatus.PASSED);
+    }
+
+    public AmazonCorrettoSslEngineTest() {
+        super(SslProvider.isTlsv13Supported(SslProvider.JDK));
+    }
+
+    @Override
+    protected SslProvider sslClientProvider() {
+        return SslProvider.JDK;
+    }
+
+    @Override
+    protected SslProvider sslServerProvider() {
+        return SslProvider.JDK;
+    }
+
+    @BeforeEach
+    @Override
+    public void setup() {
+        // See https://github.com/corretto/amazon-corretto-crypto-provider/blob/develop/README.md#code
+        Security.insertProviderAt(AmazonCorrettoCryptoProvider.INSTANCE, 1);
+
+        // See https://github.com/corretto/amazon-corretto-crypto-provider/blob/develop/README.md#verification-optional
+        try {
+            AmazonCorrettoCryptoProvider.INSTANCE.assertHealthy();
+            String providerName = Cipher.getInstance("AES/GCM/NoPadding").getProvider().getName();
+            assertEquals(AmazonCorrettoCryptoProvider.PROVIDER_NAME, providerName);
+        } catch (Throwable e) {
+            // Roll back the provider registration before failing, so later tests
+            // do not run with a half-installed ACCP.
+            Security.removeProvider(AmazonCorrettoCryptoProvider.PROVIDER_NAME);
+            throw new AssertionError(e);
+        }
+        super.setup();
+    }
+
+    @AfterEach
+    @Override
+    public void tearDown() throws InterruptedException {
+        super.tearDown();
+
+        // Remove the provider again and verify that it was removed
+        Security.removeProvider(AmazonCorrettoCryptoProvider.PROVIDER_NAME);
+        assertNull(Security.getProvider(AmazonCorrettoCryptoProvider.PROVIDER_NAME));
+    }
+
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    @Disabled /* Does the JDK support a "max certificate chain length"? */
+    @Override
+    public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) {
+    }
+
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    @Disabled /* Does the JDK support a "max certificate chain length"? */
+    @Override
+    public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) {
+    }
+
+    @Override
+    protected boolean mySetupMutualAuthServerIsValidException(Throwable cause) {
+        // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException.
+        return super.mySetupMutualAuthServerIsValidException(cause) || causedBySSLException(cause);
+    }
+}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java
new file mode 100644
index 0000000..66a20b3
--- /dev/null
+++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ApplicationProtocolNegotiationHandlerTest.java
@@ -0,0 +1,232 @@
+/*
+ * Copyright 2020 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import io.netty.buffer.ByteBuf;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.channel.socket.ChannelInputShutdownEvent;
import io.netty.handler.codec.DecoderException;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLHandshakeException;
import java.security.NoSuchAlgorithmException;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static io.netty.handler.ssl.CloseNotifyTest.assertCloseNotify;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Tests for {@link ApplicationProtocolNegotiationHandler}: removal from the pipeline,
 * handshake success/failure handling and buffering of reads until the handshake completes.
 */
public class ApplicationProtocolNegotiationHandlerTest {

    /** The handler must remove itself (and pass messages through) when no SslHandler is present. */
    @Test
    public void testRemoveItselfIfNoSslHandlerPresent() throws NoSuchAlgorithmException {
        ChannelHandler alpnHandler = new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) {
            @Override
            protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
                fail();
            }
        };

        SSLEngine engine = SSLContext.getDefault().createSSLEngine();
        // This test is mocked/simulated and doesn't go through full TLS handshake. Currently only JDK SSLEngineImpl
        // client mode will generate a close_notify.
        engine.setUseClientMode(true);

        EmbeddedChannel channel = new EmbeddedChannel(alpnHandler);
        String msg = "msg";
        String msg2 = "msg2";

        assertTrue(channel.writeInbound(msg));
        assertTrue(channel.writeInbound(msg2));
        // Handler should have removed itself already...
        assertNull(channel.pipeline().context(alpnHandler));
        // ... and forwarded both messages untouched.
        assertEquals(msg, channel.readInbound());
        assertEquals(msg2, channel.readInbound());

        assertFalse(channel.finishAndReleaseAll());
    }

    /** A failed handshake must remove the handler without calling configurePipeline(...). */
    @Test
    public void testHandshakeFailure() {
        ChannelHandler alpnHandler = new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) {
            @Override
            protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
                fail();
            }
        };

        EmbeddedChannel channel = new EmbeddedChannel(alpnHandler);
        SSLHandshakeException exception = new SSLHandshakeException("error");
        SslHandshakeCompletionEvent completionEvent = new SslHandshakeCompletionEvent(exception);
        channel.pipeline().fireUserEventTriggered(completionEvent);
        channel.pipeline().fireExceptionCaught(new DecoderException(exception));
        assertNull(channel.pipeline().context(alpnHandler));
        assertFalse(channel.finishAndReleaseAll());
    }

    @Test
    public void testHandshakeSuccess() throws NoSuchAlgorithmException {
        testHandshakeSuccess0(false);
    }

    @Test
    public void testHandshakeSuccessWithSslHandlerAddedLater() throws NoSuchAlgorithmException {
        testHandshakeSuccess0(true);
    }

    /**
     * Shared body for the handshake-success tests.
     *
     * @param addLater whether the SslHandler is added to the pipeline after the ALPN handler
     */
    private static void testHandshakeSuccess0(boolean addLater) throws NoSuchAlgorithmException {
        final AtomicBoolean configureCalled = new AtomicBoolean(false);
        ChannelHandler alpnHandler = new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) {
            @Override
            protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
                configureCalled.set(true);
                assertEquals(ApplicationProtocolNames.HTTP_1_1, protocol);
            }
        };

        SSLEngine engine = SSLContext.getDefault().createSSLEngine();
        // This test is mocked/simulated and doesn't go through full TLS handshake. Currently only JDK SSLEngineImpl
        // client mode will generate a close_notify.
        engine.setUseClientMode(true);

        EmbeddedChannel channel = new EmbeddedChannel();
        if (addLater) {
            channel.pipeline().addLast(alpnHandler);
            channel.pipeline().addFirst(new SslHandler(engine));
        } else {
            channel.pipeline().addLast(new SslHandler(engine));
            channel.pipeline().addLast(alpnHandler);
        }
        channel.pipeline().fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
        assertNull(channel.pipeline().context(alpnHandler));
        // Should produce the close_notify messages
        channel.releaseOutbound();
        channel.close();
        assertCloseNotify((ByteBuf) channel.readOutbound());
        channel.finishAndReleaseAll();
        assertTrue(configureCalled.get());
    }

    /** Without an SslHandler in the pipeline a successful-handshake event is an error. */
    @Test
    public void testHandshakeSuccessButNoSslHandler() {
        ChannelHandler alpnHandler = new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) {
            @Override
            protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
                fail();
            }
        };
        final EmbeddedChannel channel = new EmbeddedChannel(alpnHandler);
        channel.pipeline().fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
        assertNull(channel.pipeline().context(alpnHandler));
        assertThrows(IllegalStateException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                channel.finishAndReleaseAll();
            }
        });
    }

    @Test
    public void testBufferMessagesUntilHandshakeComplete() throws Exception {
        testBufferMessagesUntilHandshakeComplete(null);
    }

    @Test
    public void testBufferMessagesUntilHandshakeCompleteWithClose() throws Exception {
        testBufferMessagesUntilHandshakeComplete(
                new ApplicationProtocolNegotiationHandlerTest.Consumer<ChannelHandlerContext>() {
                    @Override
                    public void consume(ChannelHandlerContext ctx) {
                        ctx.channel().close();
                    }
                });
    }

    @Test
    public void testBufferMessagesUntilHandshakeCompleteWithInputShutdown() throws Exception {
        testBufferMessagesUntilHandshakeComplete(
                new ApplicationProtocolNegotiationHandlerTest.Consumer<ChannelHandlerContext>() {
                    @Override
                    public void consume(ChannelHandlerContext ctx) {
                        ctx.fireUserEventTriggered(ChannelInputShutdownEvent.INSTANCE);
                    }
                });
    }

    /**
     * Verifies reads arriving in the same event-loop tick as the handshake-complete event are
     * delivered to the handlers installed by configurePipeline(...).
     *
     * @param pipelineConfigurator optional extra action run at the end of configurePipeline(...)
     */
    private void testBufferMessagesUntilHandshakeComplete(final Consumer<ChannelHandlerContext> pipelineConfigurator)
            throws Exception {
        final AtomicReference<byte[]> channelReadData = new AtomicReference<byte[]>();
        final AtomicBoolean channelReadCompleteCalled = new AtomicBoolean(false);
        ChannelHandler alpnHandler = new ApplicationProtocolNegotiationHandler(ApplicationProtocolNames.HTTP_1_1) {
            @Override
            protected void configurePipeline(ChannelHandlerContext ctx, String protocol) {
                assertEquals(ApplicationProtocolNames.HTTP_1_1, protocol);
                ctx.pipeline().addLast(new ChannelInboundHandlerAdapter() {
                    @Override
                    public void channelRead(ChannelHandlerContext ctx, Object msg) {
                        channelReadData.set((byte[]) msg);
                    }

                    @Override
                    public void channelReadComplete(ChannelHandlerContext ctx) {
                        channelReadCompleteCalled.set(true);
                    }
                });
                if (pipelineConfigurator != null) {
                    pipelineConfigurator.consume(ctx);
                }
            }
        };

        SSLEngine engine = SSLContext.getDefault().createSSLEngine();
        // This test is mocked/simulated and doesn't go through full TLS handshake. Currently only JDK SSLEngineImpl
        // client mode will generate a close_notify.
        engine.setUseClientMode(true);

        final byte[] someBytes = new byte[1024];

        EmbeddedChannel channel = new EmbeddedChannel(new SslHandler(engine), new ChannelInboundHandlerAdapter() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                // Inject a read between the handshake event and its propagation so the ALPN
                // handler has to buffer it.
                if (evt == SslHandshakeCompletionEvent.SUCCESS) {
                    ctx.fireChannelRead(someBytes);
                }
                ctx.fireUserEventTriggered(evt);
            }
        }, alpnHandler);
        channel.pipeline().fireUserEventTriggered(SslHandshakeCompletionEvent.SUCCESS);
        assertNull(channel.pipeline().context(alpnHandler));
        assertArrayEquals(someBytes, channelReadData.get());
        assertTrue(channelReadCompleteCalled.get());
        assertNull(channel.readInbound());
        assertTrue(channel.finishAndReleaseAll());
    }

    /** Minimal callback used because this code base targets Java without java.util.function. */
    private interface Consumer<T> {
        void consume(T t);
    }
}
/*
 * Copyright 2018 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version
 * 2.0 (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */

package io.netty.handler.ssl;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.ChannelPipeline;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Promise;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import java.net.SocketAddress;
import java.security.NoSuchAlgorithmException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLContext;
import javax.net.ssl.SSLEngine;

import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

/**
 * The purpose of this unit test is to act as a canary and catch changes in supported cipher suites.
 */
public class CipherSuiteCanaryTest {

    private static EventLoopGroup GROUP;

    private static SelfSignedCertificate CERT;

    /**
     * Parameter source: every server/client SslProvider combination for the canary cipher.
     * Each entry is {serverProvider, clientProvider, rfcCipherName, delegateToExecutor}.
     */
    static Collection<Object[]> parameters() {
        List<Object[]> dst = new ArrayList<Object[]>();
        dst.addAll(expand("TLS_DHE_RSA_WITH_AES_128_GCM_SHA256")); // DHE-RSA-AES128-GCM-SHA256
        return dst;
    }

    @BeforeAll
    public static void init() throws Exception {
        GROUP = new DefaultEventLoopGroup();
        CERT = new SelfSignedCertificate();
    }

    @AfterAll
    public static void destroy() {
        GROUP.shutdownGracefully();
        CERT.delete();
    }

    /**
     * Skips the test when the given provider does not support the cipher (varies per JDK
     * version and OpenSSL build).
     */
    private static void assumeCipherAvailable(SslProvider provider, String cipher) throws NoSuchAlgorithmException {
        boolean cipherSupported = false;
        if (provider == SslProvider.JDK) {
            SSLEngine engine = SSLContext.getDefault().createSSLEngine();
            for (String c : engine.getSupportedCipherSuites()) {
                if (cipher.equals(c)) {
                    cipherSupported = true;
                    break;
                }
            }
        } else {
            cipherSupported = OpenSsl.isCipherSuiteAvailable(cipher);
        }
        assumeTrue(cipherSupported, "Unsupported cipher: " + cipher);
    }

    /** Creates an SslHandler, optionally delegating handshake work to the given executor. */
    private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) {
        if (executor == null) {
            return sslCtx.newHandler(allocator);
        } else {
            return sslCtx.newHandler(allocator, executor);
        }
    }

    /**
     * Performs a full local-channel PING/PONG exchange over TLS restricted to a single cipher
     * suite and fails if the handshake (and therefore the cipher) no longer works.
     */
    @ParameterizedTest(
            name = "{index}: serverSslProvider = {0}, clientSslProvider = {1}, rfcCipherName = {2}, delegate = {3}")
    @MethodSource("parameters")
    public void testHandshake(SslProvider serverSslProvider, SslProvider clientSslProvider,
                              String rfcCipherName, boolean delegate) throws Exception {
        // Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API
        // implementations.
        assumeCipherAvailable(serverSslProvider, rfcCipherName);
        assumeCipherAvailable(clientSslProvider, rfcCipherName);

        List<String> ciphers = Collections.singletonList(rfcCipherName);

        final SslContext sslServerContext = SslContextBuilder.forServer(CERT.certificate(), CERT.privateKey())
                .sslProvider(serverSslProvider)
                .ciphers(ciphers)
                // As this is not a TLSv1.3 cipher we should ensure we talk something else.
                .protocols(SslProtocols.TLS_v1_2)
                .build();

        final ExecutorService executorService = delegate ? Executors.newCachedThreadPool() : null;

        try {
            final SslContext sslClientContext = SslContextBuilder.forClient()
                    .sslProvider(clientSslProvider)
                    .ciphers(ciphers)
                    // As this is not a TLSv1.3 cipher we should ensure we talk something else.
                    .protocols(SslProtocols.TLS_v1_2)
                    .trustManager(InsecureTrustManagerFactory.INSTANCE)
                    .build();

            try {
                final Promise<Void> serverPromise = GROUP.next().newPromise();
                final Promise<Void> clientPromise = GROUP.next().newPromise();

                ChannelHandler serverHandler = new ChannelInitializer<Channel>() {
                    @Override
                    protected void initChannel(Channel ch) throws Exception {
                        ChannelPipeline pipeline = ch.pipeline();
                        pipeline.addLast(newSslHandler(sslServerContext, ch.alloc(), executorService));

                        pipeline.addLast(new SimpleChannelInboundHandler<Object>() {
                            @Override
                            public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                                serverPromise.cancel(true);
                                ctx.fireChannelInactive();
                            }

                            @Override
                            public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
                                if (serverPromise.trySuccess(null)) {
                                    ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[] {'P', 'O', 'N', 'G'}));
                                }
                                ctx.close();
                            }

                            @Override
                            public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
                                if (!serverPromise.tryFailure(cause)) {
                                    ctx.fireExceptionCaught(cause);
                                }
                            }
                        });
                    }
                };

                LocalAddress address = new LocalAddress("test-" + serverSslProvider +
                        '-' + clientSslProvider + '-' + rfcCipherName);

                Channel server = server(address, serverHandler);
                try {
                    ChannelHandler clientHandler = new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            ChannelPipeline pipeline = ch.pipeline();
                            pipeline.addLast(newSslHandler(sslClientContext, ch.alloc(), executorService));

                            pipeline.addLast(new SimpleChannelInboundHandler<Object>() {
                                @Override
                                public void channelInactive(ChannelHandlerContext ctx) throws Exception {
                                    clientPromise.cancel(true);
                                    ctx.fireChannelInactive();
                                }

                                @Override
                                public void channelRead0(ChannelHandlerContext ctx, Object msg) throws Exception {
                                    clientPromise.trySuccess(null);
                                    ctx.close();
                                }

                                @Override
                                public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause)
                                        throws Exception {
                                    if (!clientPromise.tryFailure(cause)) {
                                        ctx.fireExceptionCaught(cause);
                                    }
                                }
                            });
                        }
                    };

                    Channel client = client(server, clientHandler);
                    try {
                        client.writeAndFlush(Unpooled.wrappedBuffer(new byte[] {'P', 'I', 'N', 'G'}))
                                .syncUninterruptibly();

                        assertTrue(clientPromise.await(5L, TimeUnit.SECONDS), "client timeout");
                        assertTrue(serverPromise.await(5L, TimeUnit.SECONDS), "server timeout");

                        clientPromise.sync();
                        serverPromise.sync();
                    } finally {
                        client.close().sync();
                    }
                } finally {
                    server.close().sync();
                }
            } finally {
                ReferenceCountUtil.release(sslClientContext);
            }
        } finally {
            ReferenceCountUtil.release(sslServerContext);

            if (executorService != null) {
                executorService.shutdown();
            }
        }
    }

    /** Binds a local server channel with the given child handler. */
    private static Channel server(LocalAddress address, ChannelHandler handler) throws Exception {
        ServerBootstrap bootstrap = new ServerBootstrap()
                .channel(LocalServerChannel.class)
                .group(GROUP)
                .childHandler(handler);

        return bootstrap.bind(address).sync().channel();
    }

    /** Connects a local client channel to the given server. */
    private static Channel client(Channel server, ChannelHandler handler) throws Exception {
        SocketAddress remoteAddress = server.localAddress();

        Bootstrap bootstrap = new Bootstrap()
                .channel(LocalChannel.class)
                .group(GROUP)
                .handler(handler);

        return bootstrap.connect(remoteAddress).sync().channel();
    }

    /**
     * Expands a single RFC cipher name into all testable provider combinations. Combinations
     * involving a native provider are skipped when OpenSSL is not available.
     */
    private static List<Object[]> expand(String rfcCipherName) {
        List<Object[]> dst = new ArrayList<Object[]>();

        for (SslProvider serverSslProvider : SslProvider.values()) {
            for (SslProvider clientSslProvider : SslProvider.values()) {
                if ((serverSslProvider != SslProvider.JDK || clientSslProvider != SslProvider.JDK)
                        && !OpenSsl.isAvailable()) {
                    continue;
                }

                dst.add(new Object[]{serverSslProvider, clientSslProvider, rfcCipherName, true});
                dst.add(new Object[]{serverSslProvider, clientSslProvider, rfcCipherName, false});
            }
        }

        if (dst.isEmpty()) {
            throw new IllegalStateException();
        }

        return dst;
    }
}
+ */ + +package io.netty.handler.ssl; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.parallel.Isolated; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.sameInstance; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +@Isolated +public class CipherSuiteConverterTest { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(CipherSuiteConverterTest.class); + + @Test + public void testJ2OMappings() throws Exception { + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "ECDHE-ECDSA-AES128-SHA256"); + testJ2OMapping("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256", "ECDHE-RSA-AES128-SHA256"); + testJ2OMapping("TLS_RSA_WITH_AES_128_CBC_SHA256", "AES128-SHA256"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA256", "ECDH-ECDSA-AES128-SHA256"); + testJ2OMapping("TLS_ECDH_RSA_WITH_AES_128_CBC_SHA256", "ECDH-RSA-AES128-SHA256"); + testJ2OMapping("TLS_DHE_RSA_WITH_AES_128_CBC_SHA256", "DHE-RSA-AES128-SHA256"); + testJ2OMapping("TLS_DHE_DSS_WITH_AES_128_CBC_SHA256", "DHE-DSS-AES128-SHA256"); + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA", "ECDHE-ECDSA-AES128-SHA"); + testJ2OMapping("TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA", "ECDHE-RSA-AES128-SHA"); + testJ2OMapping("TLS_RSA_WITH_AES_128_CBC_SHA", "AES128-SHA"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_AES_128_CBC_SHA", "ECDH-ECDSA-AES128-SHA"); + testJ2OMapping("TLS_ECDH_RSA_WITH_AES_128_CBC_SHA", "ECDH-RSA-AES128-SHA"); + testJ2OMapping("TLS_DHE_RSA_WITH_AES_128_CBC_SHA", "DHE-RSA-AES128-SHA"); + testJ2OMapping("TLS_DHE_DSS_WITH_AES_128_CBC_SHA", "DHE-DSS-AES128-SHA"); + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "ECDHE-ECDSA-AES128-GCM-SHA256"); + 
testJ2OMapping("TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256", "ECDHE-RSA-AES128-GCM-SHA256"); + testJ2OMapping("TLS_RSA_WITH_AES_128_GCM_SHA256", "AES128-GCM-SHA256"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_AES_128_GCM_SHA256", "ECDH-ECDSA-AES128-GCM-SHA256"); + testJ2OMapping("TLS_ECDH_RSA_WITH_AES_128_GCM_SHA256", "ECDH-RSA-AES128-GCM-SHA256"); + testJ2OMapping("TLS_DHE_RSA_WITH_AES_128_GCM_SHA256", "DHE-RSA-AES128-GCM-SHA256"); + testJ2OMapping("TLS_DHE_DSS_WITH_AES_128_GCM_SHA256", "DHE-DSS-AES128-GCM-SHA256"); + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA", "ECDHE-ECDSA-DES-CBC3-SHA"); + testJ2OMapping("TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", "ECDHE-RSA-DES-CBC3-SHA"); + testJ2OMapping("SSL_RSA_WITH_3DES_EDE_CBC_SHA", "DES-CBC3-SHA"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA", "ECDH-ECDSA-DES-CBC3-SHA"); + testJ2OMapping("TLS_ECDH_RSA_WITH_3DES_EDE_CBC_SHA", "ECDH-RSA-DES-CBC3-SHA"); + testJ2OMapping("SSL_DHE_RSA_WITH_3DES_EDE_CBC_SHA", "DHE-RSA-DES-CBC3-SHA"); + testJ2OMapping("SSL_DHE_DSS_WITH_3DES_EDE_CBC_SHA", "DHE-DSS-DES-CBC3-SHA"); + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_RC4_128_SHA", "ECDHE-ECDSA-RC4-SHA"); + testJ2OMapping("TLS_ECDHE_RSA_WITH_RC4_128_SHA", "ECDHE-RSA-RC4-SHA"); + testJ2OMapping("SSL_RSA_WITH_RC4_128_SHA", "RC4-SHA"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_RC4_128_SHA", "ECDH-ECDSA-RC4-SHA"); + testJ2OMapping("TLS_ECDH_RSA_WITH_RC4_128_SHA", "ECDH-RSA-RC4-SHA"); + testJ2OMapping("SSL_RSA_WITH_RC4_128_MD5", "RC4-MD5"); + testJ2OMapping("TLS_DH_anon_WITH_AES_128_GCM_SHA256", "ADH-AES128-GCM-SHA256"); + testJ2OMapping("TLS_DH_anon_WITH_AES_128_CBC_SHA256", "ADH-AES128-SHA256"); + testJ2OMapping("TLS_ECDH_anon_WITH_AES_128_CBC_SHA", "AECDH-AES128-SHA"); + testJ2OMapping("TLS_DH_anon_WITH_AES_128_CBC_SHA", "ADH-AES128-SHA"); + testJ2OMapping("TLS_ECDH_anon_WITH_3DES_EDE_CBC_SHA", "AECDH-DES-CBC3-SHA"); + testJ2OMapping("SSL_DH_anon_WITH_3DES_EDE_CBC_SHA", "ADH-DES-CBC3-SHA"); + testJ2OMapping("TLS_ECDH_anon_WITH_RC4_128_SHA", 
"AECDH-RC4-SHA"); + testJ2OMapping("SSL_DH_anon_WITH_RC4_128_MD5", "ADH-RC4-MD5"); + testJ2OMapping("SSL_RSA_WITH_DES_CBC_SHA", "DES-CBC-SHA"); + testJ2OMapping("SSL_DHE_RSA_WITH_DES_CBC_SHA", "DHE-RSA-DES-CBC-SHA"); + testJ2OMapping("SSL_DHE_DSS_WITH_DES_CBC_SHA", "DHE-DSS-DES-CBC-SHA"); + testJ2OMapping("SSL_DH_anon_WITH_DES_CBC_SHA", "ADH-DES-CBC-SHA"); + testJ2OMapping("SSL_RSA_EXPORT_WITH_DES40_CBC_SHA", "EXP-DES-CBC-SHA"); + testJ2OMapping("SSL_DHE_RSA_EXPORT_WITH_DES40_CBC_SHA", "EXP-DHE-RSA-DES-CBC-SHA"); + testJ2OMapping("SSL_DHE_DSS_EXPORT_WITH_DES40_CBC_SHA", "EXP-DHE-DSS-DES-CBC-SHA"); + testJ2OMapping("SSL_DH_anon_EXPORT_WITH_DES40_CBC_SHA", "EXP-ADH-DES-CBC-SHA"); + testJ2OMapping("SSL_RSA_EXPORT_WITH_RC4_40_MD5", "EXP-RC4-MD5"); + testJ2OMapping("SSL_DH_anon_EXPORT_WITH_RC4_40_MD5", "EXP-ADH-RC4-MD5"); + testJ2OMapping("TLS_RSA_WITH_NULL_SHA256", "NULL-SHA256"); + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_NULL_SHA", "ECDHE-ECDSA-NULL-SHA"); + testJ2OMapping("TLS_ECDHE_RSA_WITH_NULL_SHA", "ECDHE-RSA-NULL-SHA"); + testJ2OMapping("SSL_RSA_WITH_NULL_SHA", "NULL-SHA"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_NULL_SHA", "ECDH-ECDSA-NULL-SHA"); + testJ2OMapping("TLS_ECDH_RSA_WITH_NULL_SHA", "ECDH-RSA-NULL-SHA"); + testJ2OMapping("TLS_ECDH_anon_WITH_NULL_SHA", "AECDH-NULL-SHA"); + testJ2OMapping("SSL_RSA_WITH_NULL_MD5", "NULL-MD5"); + testJ2OMapping("TLS_KRB5_WITH_3DES_EDE_CBC_SHA", "KRB5-DES-CBC3-SHA"); + testJ2OMapping("TLS_KRB5_WITH_3DES_EDE_CBC_MD5", "KRB5-DES-CBC3-MD5"); + testJ2OMapping("TLS_KRB5_WITH_RC4_128_SHA", "KRB5-RC4-SHA"); + testJ2OMapping("TLS_KRB5_WITH_RC4_128_MD5", "KRB5-RC4-MD5"); + testJ2OMapping("TLS_KRB5_WITH_DES_CBC_SHA", "KRB5-DES-CBC-SHA"); + testJ2OMapping("TLS_KRB5_WITH_DES_CBC_MD5", "KRB5-DES-CBC-MD5"); + testJ2OMapping("TLS_KRB5_EXPORT_WITH_DES_CBC_40_SHA", "EXP-KRB5-DES-CBC-SHA"); + testJ2OMapping("TLS_KRB5_EXPORT_WITH_DES_CBC_40_MD5", "EXP-KRB5-DES-CBC-MD5"); + testJ2OMapping("TLS_KRB5_EXPORT_WITH_RC4_40_SHA", "EXP-KRB5-RC4-SHA"); + 
testJ2OMapping("TLS_KRB5_EXPORT_WITH_RC4_40_MD5", "EXP-KRB5-RC4-MD5"); + testJ2OMapping("SSL_RSA_EXPORT_WITH_RC2_CBC_40_MD5", "EXP-RC2-CBC-MD5"); + testJ2OMapping("TLS_DHE_DSS_WITH_AES_256_CBC_SHA", "DHE-DSS-AES256-SHA"); + testJ2OMapping("TLS_DHE_RSA_WITH_AES_256_CBC_SHA", "DHE-RSA-AES256-SHA"); + testJ2OMapping("TLS_DH_anon_WITH_AES_256_CBC_SHA", "ADH-AES256-SHA"); + testJ2OMapping("TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA", "ECDHE-ECDSA-AES256-SHA"); + testJ2OMapping("TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA", "ECDHE-RSA-AES256-SHA"); + testJ2OMapping("TLS_ECDH_ECDSA_WITH_AES_256_CBC_SHA", "ECDH-ECDSA-AES256-SHA"); + testJ2OMapping("TLS_ECDH_RSA_WITH_AES_256_CBC_SHA", "ECDH-RSA-AES256-SHA"); + testJ2OMapping("TLS_ECDH_anon_WITH_AES_256_CBC_SHA", "AECDH-AES256-SHA"); + testJ2OMapping("TLS_KRB5_EXPORT_WITH_RC2_CBC_40_MD5", "EXP-KRB5-RC2-CBC-MD5"); + testJ2OMapping("TLS_KRB5_EXPORT_WITH_RC2_CBC_40_SHA", "EXP-KRB5-RC2-CBC-SHA"); + testJ2OMapping("TLS_RSA_WITH_AES_256_CBC_SHA", "AES256-SHA"); + + // For historical reasons the CHACHA20 ciphers do not follow OpenSSL's custom naming + // convention and omits the HMAC algorithm portion of the name. 
        testJ2OMapping("TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-RSA-CHACHA20-POLY1305");
        testJ2OMapping("TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-ECDSA-CHACHA20-POLY1305");
        testJ2OMapping("TLS_DHE_RSA_WITH_CHACHA20_POLY1305_SHA256", "DHE-RSA-CHACHA20-POLY1305");
        testJ2OMapping("TLS_PSK_WITH_CHACHA20_POLY1305_SHA256", "PSK-CHACHA20-POLY1305");
        testJ2OMapping("TLS_ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-PSK-CHACHA20-POLY1305");
        testJ2OMapping("TLS_DHE_PSK_WITH_CHACHA20_POLY1305_SHA256", "DHE-PSK-CHACHA20-POLY1305");
        testJ2OMapping("TLS_RSA_PSK_WITH_CHACHA20_POLY1305_SHA256", "RSA-PSK-CHACHA20-POLY1305");

        // TLSv1.3 suites keep their IANA names unchanged in OpenSSL.
        testJ2OMapping("TLS_AES_128_GCM_SHA256", "TLS_AES_128_GCM_SHA256");
        testJ2OMapping("TLS_AES_256_GCM_SHA384", "TLS_AES_256_GCM_SHA384");
        testJ2OMapping("TLS_CHACHA20_POLY1305_SHA256", "TLS_CHACHA20_POLY1305_SHA256");
    }

    /**
     * Converts {@code javaCipherSuite} with the uncached Java-to-OpenSSL converter and asserts
     * that the result equals {@code openSslCipherSuite}.
     */
    private static void testJ2OMapping(String javaCipherSuite, String openSslCipherSuite) {
        final String actual = CipherSuiteConverter.toOpenSslUncached(javaCipherSuite, false);
        logger.info("{} => {}", javaCipherSuite, actual);
        assertThat(actual, is(openSslCipherSuite));
    }

    /**
     * Verifies the OpenSSL-to-Java cipher suite name mapping, one suite at a time.
     * The expected Java names are given without their {@code TLS_}/{@code SSL_} prefix.
     */
    @Test
    public void testO2JMappings() throws Exception {
        testO2JMapping("ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "ECDHE-ECDSA-AES128-SHA256");
        testO2JMapping("ECDHE_RSA_WITH_AES_128_CBC_SHA256", "ECDHE-RSA-AES128-SHA256");
        testO2JMapping("RSA_WITH_AES_128_CBC_SHA256", "AES128-SHA256");
        testO2JMapping("ECDH_ECDSA_WITH_AES_128_CBC_SHA256", "ECDH-ECDSA-AES128-SHA256");
        testO2JMapping("ECDH_RSA_WITH_AES_128_CBC_SHA256", "ECDH-RSA-AES128-SHA256");
        testO2JMapping("DHE_RSA_WITH_AES_128_CBC_SHA256", "DHE-RSA-AES128-SHA256");
        testO2JMapping("DHE_DSS_WITH_AES_128_CBC_SHA256", "DHE-DSS-AES128-SHA256");
        testO2JMapping("ECDHE_ECDSA_WITH_AES_128_CBC_SHA", "ECDHE-ECDSA-AES128-SHA");
        testO2JMapping("ECDHE_RSA_WITH_AES_128_CBC_SHA", "ECDHE-RSA-AES128-SHA");
        testO2JMapping("RSA_WITH_AES_128_CBC_SHA", "AES128-SHA");

        testO2JMapping("ECDH_ECDSA_WITH_AES_128_CBC_SHA", "ECDH-ECDSA-AES128-SHA");
        testO2JMapping("ECDH_RSA_WITH_AES_128_CBC_SHA", "ECDH-RSA-AES128-SHA");
        testO2JMapping("DHE_RSA_WITH_AES_128_CBC_SHA", "DHE-RSA-AES128-SHA");
        testO2JMapping("DHE_DSS_WITH_AES_128_CBC_SHA", "DHE-DSS-AES128-SHA");
        testO2JMapping("ECDHE_ECDSA_WITH_AES_128_GCM_SHA256", "ECDHE-ECDSA-AES128-GCM-SHA256");
        testO2JMapping("ECDHE_RSA_WITH_AES_128_GCM_SHA256", "ECDHE-RSA-AES128-GCM-SHA256");
        testO2JMapping("RSA_WITH_AES_128_GCM_SHA256", "AES128-GCM-SHA256");
        testO2JMapping("ECDH_ECDSA_WITH_AES_128_GCM_SHA256", "ECDH-ECDSA-AES128-GCM-SHA256");
        testO2JMapping("ECDH_RSA_WITH_AES_128_GCM_SHA256", "ECDH-RSA-AES128-GCM-SHA256");
        testO2JMapping("DHE_RSA_WITH_AES_128_GCM_SHA256", "DHE-RSA-AES128-GCM-SHA256");
        testO2JMapping("DHE_DSS_WITH_AES_128_GCM_SHA256", "DHE-DSS-AES128-GCM-SHA256");
        testO2JMapping("ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA", "ECDHE-ECDSA-DES-CBC3-SHA");
        testO2JMapping("ECDHE_RSA_WITH_3DES_EDE_CBC_SHA", "ECDHE-RSA-DES-CBC3-SHA");
        testO2JMapping("RSA_WITH_3DES_EDE_CBC_SHA", "DES-CBC3-SHA");
        testO2JMapping("ECDH_ECDSA_WITH_3DES_EDE_CBC_SHA", "ECDH-ECDSA-DES-CBC3-SHA");
        testO2JMapping("ECDH_RSA_WITH_3DES_EDE_CBC_SHA", "ECDH-RSA-DES-CBC3-SHA");
        testO2JMapping("DHE_RSA_WITH_3DES_EDE_CBC_SHA", "DHE-RSA-DES-CBC3-SHA");
        testO2JMapping("DHE_DSS_WITH_3DES_EDE_CBC_SHA", "DHE-DSS-DES-CBC3-SHA");
        testO2JMapping("ECDHE_ECDSA_WITH_RC4_128_SHA", "ECDHE-ECDSA-RC4-SHA");
        testO2JMapping("ECDHE_RSA_WITH_RC4_128_SHA", "ECDHE-RSA-RC4-SHA");
        testO2JMapping("RSA_WITH_RC4_128_SHA", "RC4-SHA");
        testO2JMapping("ECDH_ECDSA_WITH_RC4_128_SHA", "ECDH-ECDSA-RC4-SHA");
        testO2JMapping("ECDH_RSA_WITH_RC4_128_SHA", "ECDH-RSA-RC4-SHA");
        testO2JMapping("RSA_WITH_RC4_128_MD5", "RC4-MD5");
        testO2JMapping("DH_anon_WITH_AES_128_GCM_SHA256", "ADH-AES128-GCM-SHA256");
        testO2JMapping("DH_anon_WITH_AES_128_CBC_SHA256", "ADH-AES128-SHA256");
        testO2JMapping("ECDH_anon_WITH_AES_128_CBC_SHA", "AECDH-AES128-SHA");
        testO2JMapping("DH_anon_WITH_AES_128_CBC_SHA", "ADH-AES128-SHA");
        testO2JMapping("ECDH_anon_WITH_3DES_EDE_CBC_SHA", "AECDH-DES-CBC3-SHA");
        testO2JMapping("DH_anon_WITH_3DES_EDE_CBC_SHA", "ADH-DES-CBC3-SHA");
        testO2JMapping("ECDH_anon_WITH_RC4_128_SHA", "AECDH-RC4-SHA");
        testO2JMapping("DH_anon_WITH_RC4_128_MD5", "ADH-RC4-MD5");
        testO2JMapping("RSA_WITH_DES_CBC_SHA", "DES-CBC-SHA");
        testO2JMapping("DHE_RSA_WITH_DES_CBC_SHA", "DHE-RSA-DES-CBC-SHA");
        testO2JMapping("DHE_DSS_WITH_DES_CBC_SHA", "DHE-DSS-DES-CBC-SHA");
        testO2JMapping("DH_anon_WITH_DES_CBC_SHA", "ADH-DES-CBC-SHA");
        testO2JMapping("RSA_EXPORT_WITH_DES_CBC_40_SHA", "EXP-DES-CBC-SHA");
        testO2JMapping("DHE_RSA_EXPORT_WITH_DES_CBC_40_SHA", "EXP-DHE-RSA-DES-CBC-SHA");
        testO2JMapping("DHE_DSS_EXPORT_WITH_DES_CBC_40_SHA", "EXP-DHE-DSS-DES-CBC-SHA");
        testO2JMapping("DH_anon_EXPORT_WITH_DES_CBC_40_SHA", "EXP-ADH-DES-CBC-SHA");
        testO2JMapping("RSA_EXPORT_WITH_RC4_40_MD5", "EXP-RC4-MD5");
        testO2JMapping("DH_anon_EXPORT_WITH_RC4_40_MD5", "EXP-ADH-RC4-MD5");
        testO2JMapping("RSA_WITH_NULL_SHA256", "NULL-SHA256");
        testO2JMapping("ECDHE_ECDSA_WITH_NULL_SHA", "ECDHE-ECDSA-NULL-SHA");
        testO2JMapping("ECDHE_RSA_WITH_NULL_SHA", "ECDHE-RSA-NULL-SHA");
        testO2JMapping("RSA_WITH_NULL_SHA", "NULL-SHA");
        testO2JMapping("ECDH_ECDSA_WITH_NULL_SHA", "ECDH-ECDSA-NULL-SHA");
        testO2JMapping("ECDH_RSA_WITH_NULL_SHA", "ECDH-RSA-NULL-SHA");
        testO2JMapping("ECDH_anon_WITH_NULL_SHA", "AECDH-NULL-SHA");
        testO2JMapping("RSA_WITH_NULL_MD5", "NULL-MD5");
        testO2JMapping("KRB5_WITH_3DES_EDE_CBC_SHA", "KRB5-DES-CBC3-SHA");
        testO2JMapping("KRB5_WITH_3DES_EDE_CBC_MD5", "KRB5-DES-CBC3-MD5");
        testO2JMapping("KRB5_WITH_RC4_128_SHA", "KRB5-RC4-SHA");
        testO2JMapping("KRB5_WITH_RC4_128_MD5", "KRB5-RC4-MD5");
        testO2JMapping("KRB5_WITH_DES_CBC_SHA", "KRB5-DES-CBC-SHA");
        testO2JMapping("KRB5_WITH_DES_CBC_MD5", "KRB5-DES-CBC-MD5");
        testO2JMapping("KRB5_EXPORT_WITH_DES_CBC_40_SHA", "EXP-KRB5-DES-CBC-SHA");
        testO2JMapping("KRB5_EXPORT_WITH_DES_CBC_40_MD5", "EXP-KRB5-DES-CBC-MD5");
        testO2JMapping("KRB5_EXPORT_WITH_RC4_40_SHA", "EXP-KRB5-RC4-SHA");
        testO2JMapping("KRB5_EXPORT_WITH_RC4_40_MD5", "EXP-KRB5-RC4-MD5");
        testO2JMapping("RSA_EXPORT_WITH_RC2_CBC_40_MD5", "EXP-RC2-CBC-MD5");
        testO2JMapping("DHE_DSS_WITH_AES_256_CBC_SHA", "DHE-DSS-AES256-SHA");
        testO2JMapping("DHE_RSA_WITH_AES_256_CBC_SHA", "DHE-RSA-AES256-SHA");
        testO2JMapping("DH_anon_WITH_AES_256_CBC_SHA", "ADH-AES256-SHA");
        testO2JMapping("ECDHE_ECDSA_WITH_AES_256_CBC_SHA", "ECDHE-ECDSA-AES256-SHA");
        testO2JMapping("ECDHE_RSA_WITH_AES_256_CBC_SHA", "ECDHE-RSA-AES256-SHA");
        testO2JMapping("ECDH_ECDSA_WITH_AES_256_CBC_SHA", "ECDH-ECDSA-AES256-SHA");
        testO2JMapping("ECDH_RSA_WITH_AES_256_CBC_SHA", "ECDH-RSA-AES256-SHA");
        testO2JMapping("ECDH_anon_WITH_AES_256_CBC_SHA", "AECDH-AES256-SHA");
        testO2JMapping("KRB5_EXPORT_WITH_RC2_CBC_40_MD5", "EXP-KRB5-RC2-CBC-MD5");
        testO2JMapping("KRB5_EXPORT_WITH_RC2_CBC_40_SHA", "EXP-KRB5-RC2-CBC-SHA");
        testO2JMapping("RSA_WITH_AES_256_CBC_SHA", "AES256-SHA");

        // Test the known mappings that actually do not exist in Java
        testO2JMapping("EDH_DSS_WITH_3DES_EDE_CBC_SHA", "EDH-DSS-DES-CBC3-SHA");
        testO2JMapping("RSA_WITH_SEED_SHA", "SEED-SHA");
        testO2JMapping("RSA_WITH_CAMELLIA128_SHA", "CAMELLIA128-SHA");
        testO2JMapping("RSA_WITH_IDEA_CBC_SHA", "IDEA-CBC-SHA");
        testO2JMapping("PSK_WITH_AES_128_CBC_SHA", "PSK-AES128-CBC-SHA");
        testO2JMapping("PSK_WITH_3DES_EDE_CBC_SHA", "PSK-3DES-EDE-CBC-SHA");
        testO2JMapping("KRB5_WITH_IDEA_CBC_SHA", "KRB5-IDEA-CBC-SHA");
        testO2JMapping("KRB5_WITH_IDEA_CBC_MD5", "KRB5-IDEA-CBC-MD5");
        testO2JMapping("PSK_WITH_RC4_128_SHA", "PSK-RC4-SHA");
        testO2JMapping("ECDHE_RSA_WITH_AES_256_GCM_SHA384", "ECDHE-RSA-AES256-GCM-SHA384");
        testO2JMapping("ECDHE_ECDSA_WITH_AES_256_GCM_SHA384", "ECDHE-ECDSA-AES256-GCM-SHA384");
        testO2JMapping("ECDHE_RSA_WITH_AES_256_CBC_SHA384", "ECDHE-RSA-AES256-SHA384");
        testO2JMapping("ECDHE_ECDSA_WITH_AES_256_CBC_SHA384", "ECDHE-ECDSA-AES256-SHA384");
        testO2JMapping("DHE_DSS_WITH_AES_256_GCM_SHA384", "DHE-DSS-AES256-GCM-SHA384");
        testO2JMapping("DHE_RSA_WITH_AES_256_GCM_SHA384", "DHE-RSA-AES256-GCM-SHA384");
        testO2JMapping("DHE_RSA_WITH_AES_256_CBC_SHA256", "DHE-RSA-AES256-SHA256");
        testO2JMapping("DHE_DSS_WITH_AES_256_CBC_SHA256", "DHE-DSS-AES256-SHA256");
        testO2JMapping("DHE_RSA_WITH_CAMELLIA256_SHA", "DHE-RSA-CAMELLIA256-SHA");
        testO2JMapping("DHE_DSS_WITH_CAMELLIA256_SHA", "DHE-DSS-CAMELLIA256-SHA");
        testO2JMapping("ECDH_RSA_WITH_AES_256_GCM_SHA384", "ECDH-RSA-AES256-GCM-SHA384");
        testO2JMapping("ECDH_ECDSA_WITH_AES_256_GCM_SHA384", "ECDH-ECDSA-AES256-GCM-SHA384");
        testO2JMapping("ECDH_RSA_WITH_AES_256_CBC_SHA384", "ECDH-RSA-AES256-SHA384");
        testO2JMapping("ECDH_ECDSA_WITH_AES_256_CBC_SHA384", "ECDH-ECDSA-AES256-SHA384");
        testO2JMapping("RSA_WITH_AES_256_GCM_SHA384", "AES256-GCM-SHA384");
        testO2JMapping("RSA_WITH_AES_256_CBC_SHA256", "AES256-SHA256");
        testO2JMapping("RSA_WITH_CAMELLIA256_SHA", "CAMELLIA256-SHA");
        testO2JMapping("PSK_WITH_AES_256_CBC_SHA", "PSK-AES256-CBC-SHA");
        testO2JMapping("DHE_RSA_WITH_SEED_SHA", "DHE-RSA-SEED-SHA");
        testO2JMapping("DHE_DSS_WITH_SEED_SHA", "DHE-DSS-SEED-SHA");
        testO2JMapping("DHE_RSA_WITH_CAMELLIA128_SHA", "DHE-RSA-CAMELLIA128-SHA");
        testO2JMapping("DHE_DSS_WITH_CAMELLIA128_SHA", "DHE-DSS-CAMELLIA128-SHA");
        testO2JMapping("EDH_RSA_WITH_3DES_EDE_CBC_SHA", "EDH-RSA-DES-CBC3-SHA");
        testO2JMapping("SRP_DSS_WITH_AES_256_CBC_SHA", "SRP-DSS-AES-256-CBC-SHA");
        testO2JMapping("SRP_RSA_WITH_AES_256_CBC_SHA", "SRP-RSA-AES-256-CBC-SHA");
        testO2JMapping("SRP_WITH_AES_256_CBC_SHA", "SRP-AES-256-CBC-SHA");
        testO2JMapping("DH_anon_WITH_AES_256_GCM_SHA384", "ADH-AES256-GCM-SHA384");
        testO2JMapping("DH_anon_WITH_AES_256_CBC_SHA256", "ADH-AES256-SHA256");
        testO2JMapping("DH_anon_WITH_CAMELLIA256_SHA", "ADH-CAMELLIA256-SHA");

        testO2JMapping("SRP_DSS_WITH_AES_128_CBC_SHA", "SRP-DSS-AES-128-CBC-SHA");
        testO2JMapping("SRP_RSA_WITH_AES_128_CBC_SHA", "SRP-RSA-AES-128-CBC-SHA");
        testO2JMapping("SRP_WITH_AES_128_CBC_SHA", "SRP-AES-128-CBC-SHA");
        testO2JMapping("DH_anon_WITH_SEED_SHA", "ADH-SEED-SHA");
        testO2JMapping("DH_anon_WITH_CAMELLIA128_SHA", "ADH-CAMELLIA128-SHA");
        testO2JMapping("RSA_WITH_RC2_CBC_MD5", "RC2-CBC-MD5");
        testO2JMapping("SRP_DSS_WITH_3DES_EDE_CBC_SHA", "SRP-DSS-3DES-EDE-CBC-SHA");
        testO2JMapping("SRP_RSA_WITH_3DES_EDE_CBC_SHA", "SRP-RSA-3DES-EDE-CBC-SHA");
        testO2JMapping("SRP_WITH_3DES_EDE_CBC_SHA", "SRP-3DES-EDE-CBC-SHA");
        testO2JMapping("RSA_WITH_3DES_EDE_CBC_MD5", "DES-CBC3-MD5");
        testO2JMapping("EDH_RSA_WITH_DES_CBC_SHA", "EDH-RSA-DES-CBC-SHA");
        testO2JMapping("EDH_DSS_WITH_DES_CBC_SHA", "EDH-DSS-DES-CBC-SHA");
        testO2JMapping("RSA_WITH_DES_CBC_MD5", "DES-CBC-MD5");
        testO2JMapping("EDH_RSA_EXPORT_WITH_DES_CBC_40_SHA", "EXP-EDH-RSA-DES-CBC-SHA");
        testO2JMapping("EDH_DSS_EXPORT_WITH_DES_CBC_40_SHA", "EXP-EDH-DSS-DES-CBC-SHA");

        // For historical reasons the CHACHA20 ciphers do not follow OpenSSL's custom naming
        // convention and omit the HMAC algorithm portion of the name.
        testO2JMapping("ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-RSA-CHACHA20-POLY1305");
        testO2JMapping("ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-ECDSA-CHACHA20-POLY1305");
        testO2JMapping("DHE_RSA_WITH_CHACHA20_POLY1305_SHA256", "DHE-RSA-CHACHA20-POLY1305");
        testO2JMapping("PSK_WITH_CHACHA20_POLY1305_SHA256", "PSK-CHACHA20-POLY1305");
        testO2JMapping("ECDHE_PSK_WITH_CHACHA20_POLY1305_SHA256", "ECDHE-PSK-CHACHA20-POLY1305");
        testO2JMapping("DHE_PSK_WITH_CHACHA20_POLY1305_SHA256", "DHE-PSK-CHACHA20-POLY1305");
        testO2JMapping("RSA_PSK_WITH_CHACHA20_POLY1305_SHA256", "RSA-PSK-CHACHA20-POLY1305");
    }

    /**
     * Converts {@code openSslCipherSuite} with the uncached OpenSSL-to-Java converter and asserts
     * that the result equals {@code javaCipherSuite} (given without its TLS_/SSL_ prefix).
     */
    private static void testO2JMapping(String javaCipherSuite, String openSslCipherSuite) {
        final String actual = CipherSuiteConverter.toJavaUncached(openSslCipherSuite);
        logger.info("{} => {}", openSslCipherSuite, actual);
        assertThat(actual, is(javaCipherSuite));
    }

    @Test
    public void testCachedJ2OMappings() {
        testCachedJ2OMapping("TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "ECDHE-ECDSA-AES128-SHA256");
    }

    @Test
    public void testUnknownOpenSSLCiphersToJava() {
        testUnknownOpenSSLCiphersToJava("(NONE)");
        testUnknownOpenSSLCiphersToJava("unknown");
        testUnknownOpenSSLCiphersToJava("");
    }

    @Test
    public void testUnknownJavaCiphersToOpenSSL() {
        testUnknownJavaCiphersToOpenSSL("(NONE)");
        testUnknownJavaCiphersToOpenSSL("unknown");
        testUnknownJavaCiphersToOpenSSL("");
    }

    // An unknown OpenSSL name must map to null for both the TLS and the SSL prefix.
    private static void testUnknownOpenSSLCiphersToJava(String openSslCipherSuite) {
        CipherSuiteConverter.clearCache();

        assertNull(CipherSuiteConverter.toJava(openSslCipherSuite, "TLS"));
        assertNull(CipherSuiteConverter.toJava(openSslCipherSuite, "SSL"));
    }

    // An unknown Java name must map to null regardless of the boringSSL flag.
    private static void testUnknownJavaCiphersToOpenSSL(String javaCipherSuite) {
        CipherSuiteConverter.clearCache();

        assertNull(CipherSuiteConverter.toOpenSsl(javaCipherSuite, false));
        assertNull(CipherSuiteConverter.toOpenSsl(javaCipherSuite, true));
    }

    private static void
            testCachedJ2OMapping(String javaCipherSuite, String openSslCipherSuite) {
        CipherSuiteConverter.clearCache();

        // For TLSv1.3 this should make no difference if boringSSL is true or false
        final String actual1 = CipherSuiteConverter.toOpenSsl(javaCipherSuite, false);
        assertThat(actual1, is(openSslCipherSuite));
        final String actual2 = CipherSuiteConverter.toOpenSsl(javaCipherSuite, true);
        assertEquals(actual1, actual2);

        // Ensure that the cache entries have been created.
        assertThat(CipherSuiteConverter.isJ2OCached(javaCipherSuite, actual1), is(true));
        assertThat(CipherSuiteConverter.isO2JCached(actual1, "", javaCipherSuite.substring(4)), is(true));
        assertThat(CipherSuiteConverter.isO2JCached(actual1, "SSL", "SSL_" + javaCipherSuite.substring(4)), is(true));
        assertThat(CipherSuiteConverter.isO2JCached(actual1, "TLS", "TLS_" + javaCipherSuite.substring(4)), is(true));

        final String actual3 = CipherSuiteConverter.toOpenSsl(javaCipherSuite, false);
        assertThat(actual3, is(openSslCipherSuite));

        // Test if the returned cipher strings are identical,
        // so that the TLS sessions with the same cipher suite do not create many strings.
        assertThat(actual1, is(sameInstance(actual3)));
    }

    @Test
    public void testCachedO2JMappings() {
        testCachedO2JMapping("ECDHE_ECDSA_WITH_AES_128_CBC_SHA256", "ECDHE-ECDSA-AES128-SHA256");
    }

    // Exercises the cached OpenSSL-to-Java path: verifies the conversion result, that the
    // caches are populated, and that repeated lookups return the identical String instance.
    private static void testCachedO2JMapping(String javaCipherSuite, String openSslCipherSuite) {
        CipherSuiteConverter.clearCache();

        final String tlsExpected = "TLS_" + javaCipherSuite;
        final String sslExpected = "SSL_" + javaCipherSuite;

        final String tlsActual1 = CipherSuiteConverter.toJava(openSslCipherSuite, "TLS");
        final String sslActual1 = CipherSuiteConverter.toJava(openSslCipherSuite, "SSL");
        assertThat(tlsActual1, is(tlsExpected));
        assertThat(sslActual1, is(sslExpected));

        // Ensure that the cache entries have been created.
        assertThat(CipherSuiteConverter.isO2JCached(openSslCipherSuite, "", javaCipherSuite), is(true));
        assertThat(CipherSuiteConverter.isO2JCached(openSslCipherSuite, "SSL", sslExpected), is(true));
        assertThat(CipherSuiteConverter.isO2JCached(openSslCipherSuite, "TLS", tlsExpected), is(true));
        assertThat(CipherSuiteConverter.isJ2OCached(tlsExpected, openSslCipherSuite), is(true));
        assertThat(CipherSuiteConverter.isJ2OCached(sslExpected, openSslCipherSuite), is(true));

        final String tlsActual2 = CipherSuiteConverter.toJava(openSslCipherSuite, "TLS");
        final String sslActual2 = CipherSuiteConverter.toJava(openSslCipherSuite, "SSL");
        assertThat(tlsActual2, is(tlsExpected));
        assertThat(sslActual2, is(sslExpected));

        // Test if the returned cipher strings are identical,
        // so that the TLS sessions with the same cipher suite do not create many strings.
        assertThat(tlsActual1, is(sameInstance(tlsActual2)));
        assertThat(sslActual1, is(sameInstance(sslActual2)));
    }

    @Test
    public void testTlsv13Mappings() {
        CipherSuiteConverter.clearCache();

        // TLSv1.3 suite names only exist with the TLS_ prefix; the SSL_ variants must be null.
        assertEquals("TLS_AES_128_GCM_SHA256",
                CipherSuiteConverter.toJava("TLS_AES_128_GCM_SHA256", "TLS"));
        assertNull(CipherSuiteConverter.toJava("TLS_AES_128_GCM_SHA256", "SSL"));
        assertEquals("TLS_AES_256_GCM_SHA384",
                CipherSuiteConverter.toJava("TLS_AES_256_GCM_SHA384", "TLS"));
        assertNull(CipherSuiteConverter.toJava("TLS_AES_256_GCM_SHA384", "SSL"));
        assertEquals("TLS_CHACHA20_POLY1305_SHA256",
                CipherSuiteConverter.toJava("TLS_CHACHA20_POLY1305_SHA256", "TLS"));
        assertNull(CipherSuiteConverter.toJava("TLS_CHACHA20_POLY1305_SHA256", "SSL"));

        // BoringSSL uses different cipher naming than OpenSSL so we need to test for both
        assertEquals("TLS_AES_128_GCM_SHA256",
                CipherSuiteConverter.toOpenSsl("TLS_AES_128_GCM_SHA256", false));
        assertEquals("TLS_AES_256_GCM_SHA384",
                CipherSuiteConverter.toOpenSsl("TLS_AES_256_GCM_SHA384", false));
        assertEquals("TLS_CHACHA20_POLY1305_SHA256",
CipherSuiteConverter.toOpenSsl("TLS_CHACHA20_POLY1305_SHA256", false)); + + assertEquals("AEAD-AES128-GCM-SHA256", + CipherSuiteConverter.toOpenSsl("TLS_AES_128_GCM_SHA256", true)); + assertEquals("AEAD-AES256-GCM-SHA384", + CipherSuiteConverter.toOpenSsl("TLS_AES_256_GCM_SHA384", true)); + assertEquals("AEAD-CHACHA20-POLY1305-SHA256", + CipherSuiteConverter.toOpenSsl("TLS_CHACHA20_POLY1305_SHA256", true)); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/CloseNotifyTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/CloseNotifyTest.java new file mode 100644 index 0000000..179e6b4 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/CloseNotifyTest.java @@ -0,0 +1,230 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLSession; +import java.util.Collection; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; + +import static io.netty.buffer.ByteBufUtil.writeAscii; +import static io.netty.buffer.Unpooled.EMPTY_BUFFER; +import static io.netty.handler.codec.ByteToMessageDecoder.MERGE_CUMULATOR; +import static java.nio.charset.StandardCharsets.US_ASCII; +import static java.util.Arrays.asList; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.empty; +import static org.hamcrest.Matchers.equalTo; +import static org.hamcrest.Matchers.greaterThanOrEqualTo; +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.hamcrest.Matchers.notNullValue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class CloseNotifyTest { + + private static final UnpooledByteBufAllocator ALLOC = UnpooledByteBufAllocator.DEFAULT; + private static final Object INACTIVE = new Object() { + @Override + public String toString() { + return "INACTIVE"; + } + }; + + static Collection data() { + return asList(new Object[][] { + { SslProvider.JDK, SslProtocols.TLS_v1_2 }, + { SslProvider.JDK, SslProtocols.TLS_v1_3 }, + { SslProvider.OPENSSL, SslProtocols.TLS_v1_2 }, + { SslProvider.OPENSSL, SslProtocols.TLS_v1_3 }, + }); + } + + @ParameterizedTest(name = "{index}: provider={0}, 
protocol={1}") + @Timeout(30) + @MethodSource("data") + public void eventsOrder(SslProvider provider, String protocol) throws Exception { + assumeTrue(provider != SslProvider.OPENSSL || OpenSsl.isAvailable(), "OpenSSL is not available"); + + if (SslProtocols.TLS_v1_3.equals(protocol)) { + // Ensure we support TLSv1.3 + assumeTrue(SslProvider.isTlsv13Supported(provider)); + } + BlockingQueue clientEventQueue = new LinkedBlockingQueue(); + BlockingQueue serverEventQueue = new LinkedBlockingQueue(); + + EmbeddedChannel clientChannel = initChannel(provider, protocol, true, clientEventQueue); + EmbeddedChannel serverChannel = initChannel(provider, protocol, false, serverEventQueue); + + try { + // handshake: + forwardData(clientChannel, serverChannel); + forwardData(serverChannel, clientChannel); + forwardData(clientChannel, serverChannel); + forwardData(serverChannel, clientChannel); + assertThat(clientEventQueue.poll(), instanceOf(SslHandshakeCompletionEvent.class)); + assertThat(serverEventQueue.poll(), instanceOf(SslHandshakeCompletionEvent.class)); + assertThat(handshakenProtocol(clientChannel), equalTo(protocol)); + + // send data: + clientChannel.writeOutbound(writeAscii(ALLOC, "request_msg")); + forwardData(clientChannel, serverChannel); + assertThat(serverEventQueue.poll(), equalTo((Object) "request_msg")); + + // respond with data and close_notify: + serverChannel.writeOutbound(writeAscii(ALLOC, "response_msg")); + assertThat(serverChannel.finish(), is(true)); + assertThat(serverEventQueue.poll(), instanceOf(SslCloseCompletionEvent.class)); + assertThat(clientEventQueue, empty()); + + // consume server response with close_notify: + forwardAllWithCloseNotify(serverChannel, clientChannel); + assertThat(clientEventQueue.poll(), equalTo((Object) "response_msg")); + assertThat(clientEventQueue.poll(), instanceOf(SslCloseCompletionEvent.class)); + + // make sure client automatically responds with close_notify: + if (!jdkTls13(provider, protocol)) { + // JDK impl of 
TLSv1.3 does not automatically generate "close_notify" in response to the received + // "close_notify" alert. This is a legit behavior according to the spec: + // https://tools.ietf.org/html/rfc8446#section-6.1. Handle it differently: + assertCloseNotify((ByteBuf) clientChannel.readOutbound()); + } + } finally { + try { + clientChannel.finish(); + } finally { + serverChannel.finish(); + } + } + + if (jdkTls13(provider, protocol)) { + assertCloseNotify((ByteBuf) clientChannel.readOutbound()); + } else { + discardEmptyOutboundBuffers(clientChannel); + } + + assertThat(clientEventQueue.poll(), is(INACTIVE)); + assertThat(clientEventQueue, empty()); + assertThat(serverEventQueue.poll(), is(INACTIVE)); + assertThat(serverEventQueue, empty()); + + assertThat(clientChannel.releaseInbound(), is(false)); + assertThat(clientChannel.releaseOutbound(), is(false)); + assertThat(serverChannel.releaseInbound(), is(false)); + assertThat(serverChannel.releaseOutbound(), is(false)); + } + + private static boolean jdkTls13(SslProvider provider, String protocol) { + return provider == SslProvider.JDK && SslProtocols.TLS_v1_3.equals(protocol); + } + + private static EmbeddedChannel initChannel(SslProvider provider, String protocol, final boolean useClientMode, + final BlockingQueue eventQueue) throws Exception { + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final SslContext sslContext = (useClientMode + ? 
SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE) + : SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())) + .sslProvider(provider) + .protocols(protocol) + .build(); + return new EmbeddedChannel( + // use sslContext.newHandler(ALLOC) instead of new SslHandler(sslContext.newEngine(ALLOC)) to create + // non-JDK compatible OpenSSL engine that can process partial packets: + sslContext.newHandler(ALLOC), + new SimpleChannelInboundHandler() { + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { + eventQueue.add(msg.toString(US_ASCII)); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + eventQueue.add(evt); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + eventQueue.add(INACTIVE); + super.channelInactive(ctx); + } + } + ); + } + + private static void forwardData(EmbeddedChannel from, EmbeddedChannel to) { + ByteBuf in; + while ((in = from.readOutbound()) != null) { + to.writeInbound(in); + } + } + + private static void forwardAllWithCloseNotify(EmbeddedChannel from, EmbeddedChannel to) { + ByteBuf cumulation = EMPTY_BUFFER; + ByteBuf in, closeNotify = null; + while ((in = from.readOutbound()) != null) { + if (closeNotify != null) { + closeNotify.release(); + } + closeNotify = in.duplicate(); + cumulation = MERGE_CUMULATOR.cumulate(ALLOC, cumulation, in.retain()); + } + assertCloseNotify(closeNotify); + to.writeInbound(cumulation); + } + + private static String handshakenProtocol(EmbeddedChannel channel) { + SslHandler sslHandler = channel.pipeline().get(SslHandler.class); + SSLSession session = sslHandler.engine().getSession(); + return session.getProtocol(); + } + + private static void discardEmptyOutboundBuffers(EmbeddedChannel channel) { + Queue outbound = channel.outboundMessages(); + while (outbound.peek() instanceof ByteBuf) { + ByteBuf buf = (ByteBuf) outbound.peek(); + if (!buf.isReadable()) { + 
buf.release(); + outbound.poll(); + } else { + break; + } + } + } + + static void assertCloseNotify(ByteBuf closeNotify) { + assertThat(closeNotify, notNullValue()); + try { + assertThat("Doesn't match expected length of close_notify alert", + closeNotify.readableBytes(), greaterThanOrEqualTo(7)); + } finally { + closeNotify.release(); + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java new file mode 100644 index 0000000..6d8a9c7 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptJdkSslEngineInteropTest.java @@ -0,0 +1,90 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import java.security.Provider; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.condition.DisabledIf; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLSessionContext; + + +@DisabledIf("checkConscryptDisabled") +public class ConscryptJdkSslEngineInteropTest extends SSLEngineTest { + + public ConscryptJdkSslEngineInteropTest() { + super(false); + } + + static boolean checkConscryptDisabled() { + return !Conscrypt.isAvailable(); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @Override + protected Provider clientSslContextProvider() { + return Java8SslTestUtils.conscryptProvider(); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) + throws Exception { + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. 
+ return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); + } + + @Override + protected void invalidateSessionsAndAssert(SSLSessionContext context) { + // Not supported by conscrypt + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disabled due a conscrypt bug") + @Override + public void testInvalidSNIIsIgnoredAndNotThrow(SSLEngineTestParam param) throws Exception { + super.testInvalidSNIIsIgnoredAndNotThrow(param); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java new file mode 100644 index 0000000..5d530b9 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptOpenSslEngineInteropTest.java @@ -0,0 +1,219 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.condition.DisabledIf; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSessionContext; + +import java.security.Provider; +import java.util.ArrayList; +import java.util.List; + +import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +@DisabledIf("checkConscryptDisabled") +public class ConscryptOpenSslEngineInteropTest extends ConscryptSslEngineTest { + + @Override + protected List newTestParams() { + List params = super.newTestParams(); + List testParams = new ArrayList(); + for (SSLEngineTestParam param: params) { + testParams.add(new OpenSslEngineTestParam(true, param)); + //testParams.add(new OpenSslEngineTestParam(false, param)); + } + return testParams; + } + + @BeforeAll + public static void checkOpenssl() { + OpenSsl.ensureAvailability(); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.OPENSSL; + } + + @Override + protected Provider serverSslContextProvider() { + return null; + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("TODO: Make this work with Conscrypt") + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) { + super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("TODO: Make this work with Conscrypt") + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) { + super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(param); 
+ } + + @Override + protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionAfterHandshakeKeyManagerFactory(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactory(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSupportedSignatureAlgorithms(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + 
super.testSupportedSignatureAlgorithms(param); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithKeyManager(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionLocalWhenNonMutualWithKeyManager(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithoutKeyManager(SSLEngineTestParam param) throws Exception { + // This only really works when the KeyManagerFactory is supported as otherwise we not really know when + // we need to provide a cert. + assumeTrue(OpenSsl.supportsKeyManagerFactory()); + super.testSessionLocalWhenNonMutualWithoutKeyManager(param); + } + + @Override + protected void invalidateSessionsAndAssert(SSLSessionContext context) { + // Not supported by conscrypt + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCache(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCache(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheTimeout(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheSize(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheSize(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disabled due a conscrypt bug") + @Override + 
public void testInvalidSNIIsIgnoredAndNotThrow(SSLEngineTestParam param) throws Exception { + super.testInvalidSNIIsIgnoredAndNotThrow(param); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } + + @SuppressWarnings("deprecation") + @Override + protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) { + if (context instanceof OpenSslContext) { + if (param instanceof OpenSslEngineTestParam) { + ((OpenSslContext) context).setUseTasks(((OpenSslEngineTestParam) param).useTasks); + } + // Explicitly enable the session cache as it's disabled by default on the client side. + ((OpenSslContext) context).sessionContext().setSessionCacheEnabled(true); + } + return context; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java new file mode 100644 index 0000000..5822185 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ConscryptSslEngineTest.java @@ -0,0 +1,99 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.condition.DisabledIf; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLSessionContext; +import java.security.Provider; + +@DisabledIf("checkConscryptDisabled") +public class ConscryptSslEngineTest extends SSLEngineTest { + + static boolean checkConscryptDisabled() { + return !Conscrypt.isAvailable(); + } + + public ConscryptSslEngineTest() { + super(false); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @Override + protected Provider clientSslContextProvider() { + return Java8SslTestUtils.conscryptProvider(); + } + + @Override + protected Provider serverSslContextProvider() { + return Java8SslTestUtils.conscryptProvider(); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) { + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? 
*/ + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) { + } + + @Override + protected void invalidateSessionsAndAssert(SSLSessionContext context) { + // Not supported by conscrypt + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Possible Conscrypt bug") + @Override + public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception { + // Skip + // https://github.com/google/conscrypt/issues/851 + } + + @Disabled("Not supported") + @Override + public void testRSASSAPSS(SSLEngineTestParam param) { + // skip + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disabled due a conscrypt bug") + @Override + public void testInvalidSNIIsIgnoredAndNotThrow(SSLEngineTestParam param) throws Exception { + super.testInvalidSNIIsIgnoredAndNotThrow(param); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/DelegatingSslContextTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/DelegatingSslContextTest.java new file mode 100644 index 0000000..56b36de --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/DelegatingSslContextTest.java @@ -0,0 +1,60 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl; + +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class DelegatingSslContextTest { + private static final String[] EXPECTED_PROTOCOLS = { SslProtocols.TLS_v1_1 }; + + @Test + public void testInitEngineOnNewEngine() throws Exception { + SslContext delegating = newDelegatingSslContext(); + + SSLEngine engine = delegating.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertArrayEquals(EXPECTED_PROTOCOLS, engine.getEnabledProtocols()); + + engine = delegating.newEngine(UnpooledByteBufAllocator.DEFAULT, "localhost", 9090); + assertArrayEquals(EXPECTED_PROTOCOLS, engine.getEnabledProtocols()); + } + + @Test + public void testInitEngineOnNewSslHandler() throws Exception { + SslContext delegating = newDelegatingSslContext(); + + SslHandler handler = delegating.newHandler(UnpooledByteBufAllocator.DEFAULT); + assertArrayEquals(EXPECTED_PROTOCOLS, handler.engine().getEnabledProtocols()); + + handler = delegating.newHandler(UnpooledByteBufAllocator.DEFAULT, "localhost", 9090); + assertArrayEquals(EXPECTED_PROTOCOLS, handler.engine().getEnabledProtocols()); + } + + private static SslContext newDelegatingSslContext() throws Exception { + return new DelegatingSslContext(new JdkSslContext(SSLContext.getDefault(), false, ClientAuth.NONE)) { + @Override + protected void initEngine(SSLEngine engine) { + engine.setEnabledProtocols(EXPECTED_PROTOCOLS); + } + }; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java new file mode 100644 index 0000000..398fa98 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/EnhancedX509ExtendedTrustManagerTest.java @@ -0,0 +1,326 @@ +/* + * Copyright 2023 
The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +import io.netty.util.internal.EmptyArrays; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLSocket; +import javax.net.ssl.X509ExtendedTrustManager; +import java.math.BigInteger; +import java.net.Socket; +import java.security.Principal; +import java.security.PublicKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collection; +import java.util.Date; +import java.util.List; +import java.util.Set; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertInstanceOf; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; + +public class EnhancedX509ExtendedTrustManagerTest { + + private static final X509Certificate TEST_CERT = new X509Certificate() { + + @Override + public Collection> getSubjectAlternativeNames() { + return Arrays.asList(Arrays.asList(1, new Object()), Arrays.asList(2, "some.netty.io")); + } + + @Override + public void checkValidity() { + // NOOP 
+ } + + @Override + public void checkValidity(Date date) { + // NOOP + } + + @Override + public int getVersion() { + return 0; + } + + @Override + public BigInteger getSerialNumber() { + return null; + } + + @Override + public Principal getIssuerDN() { + return null; + } + + @Override + public Principal getSubjectDN() { + return null; + } + + @Override + public Date getNotBefore() { + return null; + } + + @Override + public Date getNotAfter() { + return null; + } + + @Override + public byte[] getTBSCertificate() { + return EmptyArrays.EMPTY_BYTES; + } + + @Override + public byte[] getSignature() { + return EmptyArrays.EMPTY_BYTES; + } + + @Override + public String getSigAlgName() { + return null; + } + + @Override + public String getSigAlgOID() { + return null; + } + + @Override + public byte[] getSigAlgParams() { + return EmptyArrays.EMPTY_BYTES; + } + + @Override + public boolean[] getIssuerUniqueID() { + return new boolean[0]; + } + + @Override + public boolean[] getSubjectUniqueID() { + return new boolean[0]; + } + + @Override + public boolean[] getKeyUsage() { + return new boolean[0]; + } + + @Override + public int getBasicConstraints() { + return 0; + } + + @Override + public byte[] getEncoded() { + return EmptyArrays.EMPTY_BYTES; + } + + @Override + public void verify(PublicKey key) { + // NOOP + } + + @Override + public void verify(PublicKey key, String sigProvider) { + // NOOP + } + + @Override + public String toString() { + return null; + } + + @Override + public PublicKey getPublicKey() { + return null; + } + + @Override + public boolean hasUnsupportedCriticalExtension() { + return false; + } + + @Override + public Set getCriticalExtensionOIDs() { + return null; + } + + @Override + public Set getNonCriticalExtensionOIDs() { + return null; + } + + @Override + public byte[] getExtensionValue(String oid) { + return EmptyArrays.EMPTY_BYTES; + } + }; + + private static final EnhancingX509ExtendedTrustManager MATCHING_MANAGER = + new 
EnhancingX509ExtendedTrustManager(new X509ExtendedTrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) { + fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + throw new CertificateException("No subject alternative DNS name matching netty.io."); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) { + fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + throw new CertificateException("No subject alternative DNS name matching netty.io."); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) { + fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException("No subject alternative DNS name matching netty.io."); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }); + + static List throwingMatchingExecutables() { + return Arrays.asList(new Executable() { + @Override + public void execute() throws Throwable { + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null); + } + }, new Executable() { + @Override + public void execute() throws Throwable { + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLEngine) null); + } + }, new Executable() { + @Override + public void execute() throws Throwable { + MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLSocket) null); + } + }); + } + + private static final EnhancingX509ExtendedTrustManager NON_MATCHING_MANAGER = + new EnhancingX509ExtendedTrustManager(new X509ExtendedTrustManager() { + @Override + public void 
checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + throw new CertificateException(); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + throw new CertificateException(); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + fail(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + throw new CertificateException(); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; + } + }); + + static List throwingNonMatchingExecutables() { + return Arrays.asList(new Executable() { + @Override + public void execute() throws Throwable { + NON_MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null); + } + }, new Executable() { + @Override + public void execute() throws Throwable { + NON_MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLEngine) null); + } + }, new Executable() { + @Override + public void execute() throws Throwable { + NON_MATCHING_MANAGER.checkServerTrusted(new X509Certificate[] { TEST_CERT }, null, (SSLSocket) null); + } + }); + } + + @ParameterizedTest + @MethodSource("throwingMatchingExecutables") + void testEnhanceException(Executable executable) { + CertificateException exception = assertThrows(CertificateException.class, executable); + // We should wrap the original cause with our own. 
+ assertInstanceOf(CertificateException.class, exception.getCause()); + assertThat(exception.getMessage(), Matchers.containsString("some.netty.io")); + } + + @ParameterizedTest + @MethodSource("throwingNonMatchingExecutables") + void testNotEnhanceException(Executable executable) { + CertificateException exception = assertThrows(CertificateException.class, executable); + // We should not wrap the original cause with our own. + assertNull(exception.getCause()); + assertThat(exception.getMessage(), Matchers.not(Matchers.containsString("some.netty.io"))); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/IdentityCipherSuiteFilterTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/IdentityCipherSuiteFilterTest.java new file mode 100644 index 0000000..b1ee341 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/IdentityCipherSuiteFilterTest.java @@ -0,0 +1,46 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.Test; + +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertArrayEquals; + +public class IdentityCipherSuiteFilterTest { + + @Test + public void regularInstanceDefaultsToDefaultCiphers() { + List defaultCiphers = Arrays.asList("FOO", "BAR"); + Set supportedCiphers = new HashSet(Arrays.asList("BAZ", "QIX")); + String[] filtered = IdentityCipherSuiteFilter.INSTANCE + .filterCipherSuites(null, defaultCiphers, supportedCiphers); + assertArrayEquals(defaultCiphers.toArray(), filtered); + } + + @Test + public void alternativeInstanceDefaultsToSupportedCiphers() { + List defaultCiphers = Arrays.asList("FOO", "BAR"); + Set supportedCiphers = new HashSet(Arrays.asList("BAZ", "QIX")); + String[] filtered = IdentityCipherSuiteFilter.INSTANCE_DEFAULTING_TO_SUPPORTED_CIPHERS + .filterCipherSuites(null, defaultCiphers, supportedCiphers); + assertArrayEquals(supportedCiphers.toArray(), filtered); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java new file mode 100644 index 0000000..1021619 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/Java8SslTestUtils.java @@ -0,0 +1,84 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +import org.conscrypt.OpenSSLProvider; + +import javax.net.ssl.SNIMatcher; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLParameters; +import java.io.InputStream; +import java.security.Provider; +import java.security.cert.CertificateFactory; +import java.security.cert.X509Certificate; +import java.util.Arrays; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public final class Java8SslTestUtils { + + private Java8SslTestUtils() { } + + static void setSNIMatcher(SSLParameters parameters, final byte[] match) { + SNIMatcher matcher = new SNIMatcher(0) { + @Override + public boolean matches(SNIServerName sniServerName) { + return Arrays.equals(match, sniServerName.getEncoded()); + } + }; + parameters.setSNIMatchers(Collections.singleton(matcher)); + } + + static Provider conscryptProvider() { + return new OpenSSLProvider(); + } + + /** + * Wraps the given {@link SSLEngine} to add extra tests while executing methods if possible / needed. + */ + static SSLEngine wrapSSLEngineForTesting(SSLEngine engine) { + if (engine instanceof ReferenceCountedOpenSslEngine) { + return new OpenSslErrorStackAssertSSLEngine((ReferenceCountedOpenSslEngine) engine); + } + return engine; + } + + public static X509Certificate[] loadCertCollection(String... 
resourceNames) + throws Exception { + CertificateFactory certFactory = CertificateFactory + .getInstance("X.509"); + + X509Certificate[] certCollection = new X509Certificate[resourceNames.length]; + for (int i = 0; i < resourceNames.length; i++) { + String resourceName = resourceNames[i]; + InputStream is = null; + try { + is = SslContextTest.class.getResourceAsStream(resourceName); + assertNotNull(is, "Cannot find " + resourceName); + certCollection[i] = (X509Certificate) certFactory + .generateCertificate(is); + } finally { + if (is != null) { + is.close(); + } + } + } + return certCollection; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java new file mode 100644 index 0000000..be453db --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkConscryptSslEngineInteropTest.java @@ -0,0 +1,105 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import java.security.Provider; + +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.condition.DisabledIf; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLSessionContext; + +@DisabledIf("checkConscryptDisabled") +public class JdkConscryptSslEngineInteropTest extends SSLEngineTest { + + static boolean checkConscryptDisabled() { + return !Conscrypt.isAvailable(); + } + + public JdkConscryptSslEngineInteropTest() { + super(false); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @Override + protected Provider serverSslContextProvider() { + return Java8SslTestUtils.conscryptProvider(); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("TODO: Make this work with Conscrypt") + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("TODO: Make this work with Conscrypt") + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) + throws Exception { + super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(param); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. 
+ return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Ignore due bug in Conscrypt") + @Override + public void testHandshakeSession(SSLEngineTestParam param) throws Exception { + // Ignore as Conscrypt does not correctly return the local certificates while the TrustManager is invoked. + // See https://github.com/google/conscrypt/issues/634 + } + + @Override + protected void invalidateSessionsAndAssert(SSLSessionContext context) { + // Not supported by conscrypt + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Possible Conscrypt bug") + @Override + public void testSessionCacheTimeout(SSLEngineTestParam param) { + // Skip + // https://github.com/google/conscrypt/issues/851 + } + + @Disabled("Not supported") + @Override + public void testRSASSAPSS(SSLEngineTestParam param) { + // skip + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java new file mode 100644 index 0000000..6ee49fb --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkOpenSslEngineInteroptTest.java @@ -0,0 +1,256 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLEngine; +import java.util.ArrayList; +import java.util.List; + +import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; +import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class JdkOpenSslEngineInteroptTest extends SSLEngineTest { + + public JdkOpenSslEngineInteroptTest() { + super(SslProvider.isTlsv13Supported(SslProvider.JDK) && + SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); + } + + @Override + protected List newTestParams() { + List params = super.newTestParams(); + List testParams = new ArrayList(); + for (SSLEngineTestParam param: params) { + testParams.add(new OpenSslEngineTestParam(true, param)); + // TODO this hangs! 
JP, 5.1.2024 + //testParams.add(new OpenSslEngineTestParam(false, param)); + } + return testParams; + } + + @BeforeAll + public static void checkOpenSsl() { + OpenSsl.ensureAvailability(); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.OPENSSL; + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthSameCerts(SSLEngineTestParam param) throws Throwable { + super.testMutualAuthSameCerts(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthDiffCerts(SSLEngineTestParam param) throws Exception { + super.testMutualAuthDiffCerts(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthDiffCertsServerFailure(SSLEngineTestParam param) throws Exception { + super.testMutualAuthDiffCertsServerFailure(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthDiffCertsClientFailure(SSLEngineTestParam param) throws Exception { + super.testMutualAuthDiffCertsClientFailure(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why 
this sometimes fail on the CI") + @Override + public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled("Disable until figured out why this sometimes fail on the CI") + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionAfterHandshakeKeyManagerFactory(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + 
super.testSessionAfterHandshakeKeyManagerFactory(param); + } + + @Override + protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { + ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) handler.engine(); + engine.setVerify(SSL_CVERIFY_IGNORED, 1); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. + return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testHandshakeSession(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testHandshakeSession(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSupportedSignatureAlgorithms(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSupportedSignatureAlgorithms(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithKeyManager(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionLocalWhenNonMutualWithKeyManager(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithoutKeyManager(SSLEngineTestParam param) throws Exception { + // This only really works when the KeyManagerFactory is supported as otherwise we not really know when + // we need to provide a cert. 
+ assumeTrue(OpenSsl.supportsKeyManagerFactory()); + super.testSessionLocalWhenNonMutualWithoutKeyManager(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCache(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCache(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheTimeout(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheSize(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheSize(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testRSASSAPSS(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testRSASSAPSS(param); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } + + @SuppressWarnings("deprecation") + @Override + protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) { + if (context instanceof OpenSslContext && param instanceof OpenSslEngineTestParam) { + ((OpenSslContext) context).setUseTasks(((OpenSslEngineTestParam) param).useTasks); + // Explicit enable the session cache as its disabled by default on the client side. 
+ ((OpenSslContext) context).sessionContext().setSessionCacheEnabled(true); + } + return context; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java new file mode 100644 index 0000000..e5e18c1 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslClientContextTest.java @@ -0,0 +1,29 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; + +import javax.net.ssl.SSLException; +import java.io.File; + +public class JdkSslClientContextTest extends SslContextTest { + @Override + protected SslContext newSslContext(File crtFile, File keyFile, String pass) throws SSLException { + return new JdkSslClientContext(crtFile, InsecureTrustManagerFactory.INSTANCE, crtFile, keyFile, pass, + null, null, IdentityCipherSuiteFilter.INSTANCE, ApplicationProtocolConfig.DISABLED, 0, 0); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java new file mode 100644 index 0000000..468f99c --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslEngineTest.java @@ -0,0 +1,358 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol; +import io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior; +import io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior; +import io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelector; +import io.netty.handler.ssl.JdkApplicationProtocolNegotiator.ProtocolSelectorFactory; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import java.security.Provider; + +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.opentest4j.TestAbortedException; + +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.concurrent.TimeUnit; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLHandshakeException; + +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class JdkSslEngineTest extends SSLEngineTest { + public enum ProviderType { + ALPN_JAVA { + @Override + boolean isAvailable() { + return JdkAlpnSslUtils.supportsAlpn(); + } + + @Override + Protocol protocol() { + return Protocol.ALPN; + } + + @Override + Provider provider() { + // Use the default provider. 
+ return null; + } + }, + ALPN_CONSCRYPT { + private Provider provider; + + @Override + boolean isAvailable() { + return Conscrypt.isAvailable(); + } + + @Override + Protocol protocol() { + return Protocol.ALPN; + } + + @Override + Provider provider() { + try { + if (provider == null) { + provider = (Provider) Class.forName("org.conscrypt.OpenSSLProvider") + .getConstructor().newInstance(); + } + return provider; + } catch (Exception e) { + throw new IllegalStateException(e); + } + } + }; + + abstract boolean isAvailable(); + abstract Protocol protocol(); + abstract Provider provider(); + + final void activate(JdkSslEngineTest instance) { + // Typical code will not have to check this, but will get an initialization error on class load. + // Check in this test just in case we have multiple tests that just use the class and we already ignored the + // initialization error. + if (!isAvailable()) { + throw tlsExtensionNotFound(protocol()); + } + instance.provider = provider(); + } + } + + private static final String PREFERRED_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http2"; + private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1"; + private static final String APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE = "my-protocol-FOO"; + + private Provider provider; + + public JdkSslEngineTest() { + super(SslProvider.isTlsv13Supported(SslProvider.JDK)); + } + + List newJdkParams() { + List params = newTestParams(); + + List jdkParams = new ArrayList(); + for (ProviderType providerType: ProviderType.values()) { + for (SSLEngineTestParam param: params) { + jdkParams.add(new JdkSSLEngineTestParam(providerType, param)); + } + } + return jdkParams; + } + + private static final class JdkSSLEngineTestParam extends SSLEngineTestParam { + final ProviderType providerType; + JdkSSLEngineTestParam(ProviderType providerType, SSLEngineTestParam param) { + super(param.type(), param.combo(), param.delegate()); + this.providerType = providerType; + } + + @Override + 
public String toString() { + return "JdkSSLEngineTestParam{" + + "type=" + type() + + ", protocolCipherCombo=" + combo() + + ", delegate=" + delegate() + + ", providerType=" + providerType + + '}'; + } + } + + @MethodSource("newJdkParams") + @ParameterizedTest + public void testTlsExtension(JdkSSLEngineTestParam param) throws Exception { + try { + param.providerType.activate(this); + ApplicationProtocolConfig apn = failingNegotiator(param.providerType.protocol(), + PREFERRED_APPLICATION_LEVEL_PROTOCOL); + setupHandlers(param, apn); + runTest(); + } catch (SkipTestException e) { + // ALPN availability is dependent on the java version. If ALPN is not available because of + // java version incompatibility don't fail the test, but instead just skip the test + throw new TestAbortedException("Not expected", e); + } + } + + @MethodSource("newJdkParams") + @ParameterizedTest + public void testTlsExtensionNoCompatibleProtocolsNoHandshakeFailure(JdkSSLEngineTestParam param) throws Exception { + try { + param.providerType.activate(this); + ApplicationProtocolConfig clientApn = acceptingNegotiator(param.providerType.protocol(), + PREFERRED_APPLICATION_LEVEL_PROTOCOL); + ApplicationProtocolConfig serverApn = acceptingNegotiator(param.providerType.protocol(), + APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE); + setupHandlers(param, serverApn, clientApn); + runTest(null); + } catch (SkipTestException e) { + // ALPN availability is dependent on the java version. 
If ALPN is not available because of + // java version incompatibility don't fail the test, but instead just skip the test + throw new TestAbortedException("Not expected", e); + } + } + + @MethodSource("newJdkParams") + @ParameterizedTest + public void testTlsExtensionNoCompatibleProtocolsClientHandshakeFailure(JdkSSLEngineTestParam param) + throws Exception { + try { + param.providerType.activate(this); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + JdkApplicationProtocolNegotiator clientApn = new JdkAlpnApplicationProtocolNegotiator(true, true, + PREFERRED_APPLICATION_LEVEL_PROTOCOL); + JdkApplicationProtocolNegotiator serverApn = new JdkAlpnApplicationProtocolNegotiator( + new ProtocolSelectorFactory() { + @Override + public ProtocolSelector newSelector(SSLEngine engine, Set supportedProtocols) { + return new ProtocolSelector() { + @Override + public void unsupported() { + } + + @Override + public String select(List protocols) { + return APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE; + } + }; + } + }, JdkBaseApplicationProtocolNegotiator.FAIL_SELECTION_LISTENER_FACTORY, + APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE); + + SslContext serverSslCtx = new JdkSslServerContext(param.providerType.provider(), + ssc.certificate(), ssc.privateKey(), null, null, + IdentityCipherSuiteFilter.INSTANCE, serverApn, 0, 0, null); + SslContext clientSslCtx = new JdkSslClientContext(param.providerType.provider(), null, + InsecureTrustManagerFactory.INSTANCE, null, + IdentityCipherSuiteFilter.INSTANCE, clientApn, 0, 0); + + setupHandlers(param.type(), param.delegate(), new TestDelegatingSslContext(param, serverSslCtx), + new TestDelegatingSslContext(param, clientSslCtx)); + assertTrue(clientLatch.await(2, TimeUnit.SECONDS)); + // When using TLSv1.3 the handshake is NOT sent in an extra round trip which means there will be + // no exception reported in this case but just the channel will be closed. 
+ assertTrue(clientException instanceof SSLHandshakeException || clientException == null); + } catch (SkipTestException e) { + // ALPN availability is dependent on the java version. If ALPN is not available because of + // java version incompatibility don't fail the test, but instead just skip the test + throw new TestAbortedException("Not expected", e); + } + } + + @MethodSource("newJdkParams") + @ParameterizedTest + public void testTlsExtensionNoCompatibleProtocolsServerHandshakeFailure(JdkSSLEngineTestParam param) + throws Exception { + try { + param.providerType.activate(this); + ApplicationProtocolConfig clientApn = acceptingNegotiator(param.providerType.protocol(), + PREFERRED_APPLICATION_LEVEL_PROTOCOL); + ApplicationProtocolConfig serverApn = failingNegotiator(param.providerType.protocol(), + APPLICATION_LEVEL_PROTOCOL_NOT_COMPATIBLE); + setupHandlers(param, serverApn, clientApn); + assertTrue(serverLatch.await(2, TimeUnit.SECONDS)); + assertTrue(serverException instanceof SSLHandshakeException); + } catch (SkipTestException e) { + // ALPN availability is dependent on the java version. If ALPN is not available because of + // java version incompatibility don't fail the test, but instead just skip the test + throw new TestAbortedException("Not expected", e); + } + } + + @MethodSource("newJdkParams") + @ParameterizedTest + public void testAlpnCompatibleProtocolsDifferentClientOrder(JdkSSLEngineTestParam param) throws Exception { + try { + param.providerType.activate(this); + // Even though the preferred application protocol appears second in the client's list, it will be picked + // because it's the first one on the server's list. 
+ ApplicationProtocolConfig clientApn = acceptingNegotiator(Protocol.ALPN, + FALLBACK_APPLICATION_LEVEL_PROTOCOL, PREFERRED_APPLICATION_LEVEL_PROTOCOL); + ApplicationProtocolConfig serverApn = failingNegotiator(Protocol.ALPN, + PREFERRED_APPLICATION_LEVEL_PROTOCOL, FALLBACK_APPLICATION_LEVEL_PROTOCOL); + setupHandlers(param, serverApn, clientApn); + assertNull(serverException); + runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL); + } catch (SkipTestException e) { + // ALPN availability is dependent on the java version. If ALPN is not available because of + // java version incompatibility don't fail the test, but instead just skip the test + throw new TestAbortedException("Not expected", e); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testEnablingAnAlreadyDisabledSslProtocol(SSLEngineTestParam param) throws Exception { + testEnablingAnAlreadyDisabledSslProtocol(param, new String[]{}, new String[]{ SslProtocols.TLS_v1_2 }); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) + throws Exception { + } + + @Override + protected boolean mySetupMutualAuthServerIsValidException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. 
+ return super.mySetupMutualAuthServerIsValidException(cause) || causedBySSLException(cause); + } + + private void runTest() throws Exception { + runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.JDK; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @Override + protected Provider clientSslContextProvider() { + return provider; + } + + @Override + protected Provider serverSslContextProvider() { + return provider; + } + + private static ApplicationProtocolConfig failingNegotiator(Protocol protocol, String... supportedProtocols) { + return new ApplicationProtocolConfig(protocol, + SelectorFailureBehavior.FATAL_ALERT, + SelectedListenerFailureBehavior.FATAL_ALERT, + supportedProtocols); + } + + private static ApplicationProtocolConfig acceptingNegotiator(Protocol protocol, String... supportedProtocols) { + return new ApplicationProtocolConfig(protocol, + SelectorFailureBehavior.NO_ADVERTISE, + SelectedListenerFailureBehavior.ACCEPT, + supportedProtocols); + } + + private static SkipTestException tlsExtensionNotFound(Protocol protocol) { + throw new SkipTestException(protocol + " not on classpath"); + } + + private static final class SkipTestException extends RuntimeException { + private static final long serialVersionUID = 9214869217774035223L; + + SkipTestException(String message) { + super(message); + } + } + + private static final class TestDelegatingSslContext extends DelegatingSslContext { + private final SSLEngineTestParam param; + + TestDelegatingSslContext(SSLEngineTestParam param, SslContext ctx) { + super(ctx); + this.param = param; + } + + @Override + protected void initEngine(SSLEngine engine) { + engine.setEnabledProtocols(param.protocols().toArray(EmptyArrays.EMPTY_STRINGS)); + engine.setEnabledCipherSuites(param.ciphers().toArray(EmptyArrays.EMPTY_STRINGS)); + } + } +} diff --git 
a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslRenegotiateTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslRenegotiateTest.java new file mode 100644 index 0000000..984f9b5 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslRenegotiateTest.java @@ -0,0 +1,24 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +public class JdkSslRenegotiateTest extends RenegotiateTest { + + @Override + protected SslProvider serverSslProvider() { + return SslProvider.JDK; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java new file mode 100644 index 0000000..06fa1ac --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/JdkSslServerContextTest.java @@ -0,0 +1,27 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import javax.net.ssl.SSLException; +import java.io.File; + +public class JdkSslServerContextTest extends SslContextTest { + + @Override + protected SslContext newSslContext(File crtFile, File keyFile, String pass) throws SSLException { + return new JdkSslServerContext(crtFile, keyFile, pass); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java new file mode 100644 index 0000000..969470e --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCachingKeyMaterialProviderTest.java @@ -0,0 +1,92 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.UnpooledByteBufAllocator; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.KeyManagerFactory; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; + +public class OpenSslCachingKeyMaterialProviderTest extends OpenSslKeyMaterialProviderTest { + + @Override + protected KeyManagerFactory newKeyManagerFactory() throws Exception { + return new OpenSslCachingX509KeyManagerFactory(super.newKeyManagerFactory()); + } + + @Override + protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory factory, String password) { + return new OpenSslCachingKeyMaterialProvider(ReferenceCountedOpenSslContext.chooseX509KeyManager( + factory.getKeyManagers()), password, Integer.MAX_VALUE); + } + + @Override + protected void assertRelease(OpenSslKeyMaterial material) { + assertFalse(material.release()); + } + + @Test + public void testMaterialCached() throws Exception { + OpenSslKeyMaterialProvider provider = newMaterialProvider(newKeyManagerFactory(), PASSWORD); + + OpenSslKeyMaterial material = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); + assertNotNull(material); + assertNotEquals(0, material.certificateChainAddress()); + assertNotEquals(0, material.privateKeyAddress()); + assertEquals(2, material.refCnt()); + + OpenSslKeyMaterial material2 = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); + assertNotNull(material2); + assertEquals(material.certificateChainAddress(), material2.certificateChainAddress()); + assertEquals(material.privateKeyAddress(), material2.privateKeyAddress()); + assertEquals(3, material.refCnt()); + assertEquals(3, 
material2.refCnt()); + + assertFalse(material.release()); + assertFalse(material2.release()); + + // After this the material should have been released. + provider.destroy(); + + assertEquals(0, material.refCnt()); + assertEquals(0, material2.refCnt()); + } + + @Test + public void testCacheForSunX509() throws Exception { + OpenSslCachingX509KeyManagerFactory factory = new OpenSslCachingX509KeyManagerFactory( + super.newKeyManagerFactory("SunX509")); + OpenSslKeyMaterialProvider provider = factory.newProvider(PASSWORD); + assertThat(provider, + CoreMatchers.instanceOf(OpenSslCachingKeyMaterialProvider.class)); + } + + @Test + public void testNotCacheForX509() throws Exception { + OpenSslCachingX509KeyManagerFactory factory = new OpenSslCachingX509KeyManagerFactory( + super.newKeyManagerFactory("PKIX")); + OpenSslKeyMaterialProvider provider = factory.newProvider(PASSWORD); + assertThat(provider, CoreMatchers.not( + CoreMatchers.instanceOf(OpenSslCachingKeyMaterialProvider.class))); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java new file mode 100644 index 0000000..9882c72 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateCompressionTest.java @@ -0,0 +1,441 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalEventLoopGroup; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.internal.tcnative.CertificateCompressionAlgo; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpenSslCertificateCompressionTest { + + private static SelfSignedCertificate cert; + private TestCertCompressionAlgo testZLibAlgoServer; + private TestCertCompressionAlgo testBrotliAlgoServer; + private TestCertCompressionAlgo testZstdAlgoServer; + private TestCertCompressionAlgo testZlibAlgoClient; + private TestCertCompressionAlgo testBrotliAlgoClient; + + @BeforeAll + public static void init() throws Exception { + assumeTrue(OpenSsl.isTlsv13Supported()); + cert = new SelfSignedCertificate(); + } + + @BeforeEach + public void refreshAlgos() { + testZLibAlgoServer = new 
TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB); + testBrotliAlgoServer = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_BROTLI); + testZstdAlgoServer = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZSTD); + testZlibAlgoClient = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB); + testBrotliAlgoClient = new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_BROTLI); + } + + @Test + public void testSimple() throws Throwable { + assumeTrue(OpenSsl.isBoringSSL()); + final SslContext clientSslContext = buildClientContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testBrotliAlgoClient, + OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .build() + ); + final SslContext serverSslContext = buildServerContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testBrotliAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .build() + ); + + runCertCompressionTest(clientSslContext, serverSslContext); + + assertCompress(testBrotliAlgoServer); + assertDecompress(testBrotliAlgoClient); + } + + @Test + public void testServerPriority() throws Throwable { + assumeTrue(OpenSsl.isBoringSSL()); + final SslContext clientSslContext = buildClientContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testBrotliAlgoClient, + OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .addAlgorithm(testZlibAlgoClient, OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .build() + ); + final SslContext serverSslContext = buildServerContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testZLibAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .addAlgorithm(testBrotliAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .build() + ); + + 
runCertCompressionTest(clientSslContext, serverSslContext); + + assertCompress(testZLibAlgoServer); + assertDecompress(testZlibAlgoClient); + assertNone(testBrotliAlgoClient, testBrotliAlgoServer); + } + + @Test + public void testServerPriorityReverse() throws Throwable { + assumeTrue(OpenSsl.isBoringSSL()); + final SslContext clientSslContext = buildClientContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testBrotliAlgoClient, + OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .addAlgorithm(testZlibAlgoClient, OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .build() + ); + final SslContext serverSslContext = buildServerContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testBrotliAlgoServer, + OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .addAlgorithm(testZLibAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .build() + ); + + runCertCompressionTest(clientSslContext, serverSslContext); + + assertCompress(testBrotliAlgoServer); + assertDecompress(testBrotliAlgoClient); + assertNone(testZLibAlgoServer, testZlibAlgoClient); + } + + @Test + public void testFailedNegotiation() throws Throwable { + assumeTrue(OpenSsl.isBoringSSL()); + final SslContext clientSslContext = buildClientContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testBrotliAlgoClient, + OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .addAlgorithm(testZlibAlgoClient, OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .build() + ); + final SslContext serverSslContext = buildServerContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testZstdAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .build() + ); + + runCertCompressionTest(clientSslContext, serverSslContext); + + assertNone(testBrotliAlgoClient, testZlibAlgoClient, testZstdAlgoServer); + } + + @Test + public void testAlgoFailure() 
throws Throwable { + assumeTrue(OpenSsl.isBoringSSL()); + TestCertCompressionAlgo badZlibAlgoClient = + new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB) { + @Override + public byte[] decompress(SSLEngine engine, int uncompressed_len, byte[] input) { + return input; + } + }; + final SslContext clientSslContext = buildClientContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(badZlibAlgoClient, OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .build() + ); + final SslContext serverSslContext = buildServerContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testZLibAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .build() + ); + + Assertions.assertThrows(SSLHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + runCertCompressionTest(clientSslContext, serverSslContext); + } + }); + } + + @Test + public void testAlgoException() throws Throwable { + assumeTrue(OpenSsl.isBoringSSL()); + TestCertCompressionAlgo badZlibAlgoClient = + new TestCertCompressionAlgo(CertificateCompressionAlgo.TLS_EXT_CERT_COMPRESSION_ZLIB) { + @Override + public byte[] decompress(SSLEngine engine, int uncompressed_len, byte[] input) { + throw new RuntimeException("broken"); + } + }; + final SslContext clientSslContext = buildClientContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(badZlibAlgoClient, OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress) + .build() + ); + final SslContext serverSslContext = buildServerContext( + OpenSslCertificateCompressionConfig.newBuilder() + .addAlgorithm(testZLibAlgoServer, OpenSslCertificateCompressionConfig.AlgorithmMode.Compress) + .build() + ); + + Assertions.assertThrows(SSLHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + runCertCompressionTest(clientSslContext, serverSslContext); + } + }); + } + + 
    @Test
    public void testTlsLessThan13() throws Throwable {
        // Certificate compression is only wired up for BoringSSL.
        assumeTrue(OpenSsl.isBoringSSL());
        // Both peers pin TLSv1.2 while still configuring compression algorithms.
        final SslContext clientSslContext = SslContextBuilder.forClient()
                .sslProvider(SslProvider.OPENSSL)
                .protocols(SslProtocols.TLS_v1_2)
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .option(OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS,
                        OpenSslCertificateCompressionConfig.newBuilder()
                                .addAlgorithm(testBrotliAlgoClient,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress)
                                .build())
                .build();
        final SslContext serverSslContext = SslContextBuilder.forServer(cert.key(), cert.cert())
                .sslProvider(SslProvider.OPENSSL)
                .protocols(SslProtocols.TLS_v1_2)
                .option(OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS,
                        OpenSslCertificateCompressionConfig.newBuilder()
                                .addAlgorithm(testBrotliAlgoServer,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Compress)
                                .build())
                .build();

        runCertCompressionTest(clientSslContext, serverSslContext);

        // BoringSSL returns success when calling SSL_CTX_add_cert_compression_alg
        // but only applies compression for TLSv1.3
        assertNone(testBrotliAlgoClient, testBrotliAlgoServer);
    }

    @Test
    public void testDuplicateAdd() throws Throwable {
        // Fails with "Failed trying to add certificate compression algorithm"
        assumeTrue(OpenSsl.isBoringSSL());
        // Registering the same algorithm id twice must be rejected on both sides.
        Assertions.assertThrows(Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                buildClientContext(
                        OpenSslCertificateCompressionConfig.newBuilder()
                                .addAlgorithm(testBrotliAlgoClient,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress)
                                .addAlgorithm(testBrotliAlgoClient,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Compress)
                                .build()
                );
            }
        });

        Assertions.assertThrows(Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                buildServerContext(
                        OpenSslCertificateCompressionConfig.newBuilder()
                                .addAlgorithm(testBrotliAlgoServer,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Compress)
                                .addAlgorithm(testBrotliAlgoServer,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Both).build()
                );
            }
        });
    }

    @Test
    public void testNotBoringAdd() throws Throwable {
        // Fails with "TLS Cert Compression only supported by BoringSSL"
        assumeTrue(!OpenSsl.isBoringSSL());
        Assertions.assertThrows(Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                buildClientContext(
                        OpenSslCertificateCompressionConfig.newBuilder()
                                .addAlgorithm(testBrotliAlgoClient,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Decompress)
                                .build()
                );
            }
        });

        Assertions.assertThrows(Exception.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                buildServerContext(
                        OpenSslCertificateCompressionConfig.newBuilder()
                                .addAlgorithm(testBrotliAlgoServer,
                                        OpenSslCertificateCompressionConfig.AlgorithmMode.Compress)
                                .build()
                );
            }
        });
    }

    /**
     * Runs a full TLS handshake between the two contexts over in-VM (local) channels.
     * Fails if either side does not complete within five seconds; handshake failures
     * are rethrown by {@code sync()} so callers can assert on the exception type.
     */
    // NOTE(review): raw Promise type here — generic arguments look stripped by the
    // patch extraction (likely Promise<Object> upstream); confirm against the fork.
    public void runCertCompressionTest(SslContext clientSslContext, SslContext serverSslContext) throws Throwable {
        EventLoopGroup group = new LocalEventLoopGroup();
        Promise clientPromise = group.next().newPromise();
        Promise serverPromise = group.next().newPromise();
        try {
            ServerBootstrap sb = new ServerBootstrap();
            sb.group(group).channel(LocalServerChannel.class)
                    .childHandler(new CertCompressionTestChannelInitializer(serverPromise, serverSslContext));
            Channel serverChannel = sb.bind(new LocalAddress("testCertificateCompression"))
                    .syncUninterruptibly().channel();

            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group).channel(LocalChannel.class)
                    .handler(new CertCompressionTestChannelInitializer(clientPromise, clientSslContext));

            Channel clientChannel = bootstrap.connect(serverChannel.localAddress()).syncUninterruptibly().channel();

            assertTrue(clientPromise.await(5L, TimeUnit.SECONDS), "client timeout");
            assertTrue(serverPromise.await(5L, TimeUnit.SECONDS), "server timeout");
            // Propagate any handshake failure recorded on either promise.
            clientPromise.sync();
            serverPromise.sync();
            clientChannel.close().syncUninterruptibly();
            serverChannel.close().syncUninterruptibly();
        } finally {
            group.shutdownGracefully();
        }
    }

    // Builds a TLSv1.3 OpenSSL server context with the given compression config.
    private SslContext buildServerContext(OpenSslCertificateCompressionConfig compressionConfig) throws SSLException {
        return SslContextBuilder.forServer(cert.key(), cert.cert())
                .sslProvider(SslProvider.OPENSSL)
                .protocols(SslProtocols.TLS_v1_3)
                .option(OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS,
                        compressionConfig)
                .build();
    }

    // Builds a TLSv1.3 OpenSSL client context (trust-all) with the given compression config.
    private SslContext buildClientContext(OpenSslCertificateCompressionConfig compressionConfig) throws SSLException {
        return SslContextBuilder.forClient()
                .sslProvider(SslProvider.OPENSSL)
                .protocols(SslProtocols.TLS_v1_3)
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .option(OpenSslContextOption.CERTIFICATE_COMPRESSION_ALGORITHMS,
                        compressionConfig)
                .build();
    }

    // Asserts the algorithm was used exactly in the compress direction.
    private void assertCompress(TestCertCompressionAlgo algo) {
        assertTrue(algo.compressCalled && !algo.decompressCalled);
    }

    // Asserts the algorithm was used exactly in the decompress direction.
    private void assertDecompress(TestCertCompressionAlgo algo) {
        assertTrue(!algo.compressCalled && algo.decompressCalled);
    }

    // Asserts none of the given algorithms were invoked in either direction.
    private void assertNone(TestCertCompressionAlgo... algos) {
        for (TestCertCompressionAlgo algo : algos) {
            assertTrue(!algo.compressCalled && !algo.decompressCalled);
        }
    }

    /**
     * Pipeline initializer that installs an SSL handler and completes the given
     * promise when the {@link SslHandshakeCompletionEvent} is observed (success or
     * failure).
     */
    // NOTE(review): raw ChannelInitializer/Promise/SimpleChannelInboundHandler types —
    // generic arguments appear lost in extraction; confirm against the fork.
    private static class CertCompressionTestChannelInitializer extends ChannelInitializer {

        private final Promise channelPromise;
        private final SslContext sslContext;

        CertCompressionTestChannelInitializer(Promise channelPromise, SslContext sslContext) {
            this.channelPromise = channelPromise;
            this.sslContext = sslContext;
        }

        @Override
        protected void initChannel(Channel ch) {
            ChannelPipeline pipeline = ch.pipeline();
            pipeline.addLast(sslContext.newHandler(ch.alloc()));
            pipeline.addLast(new SimpleChannelInboundHandler() {

                @Override
                public void channelRead0(ChannelHandlerContext ctx, Object msg) {
                    // Do nothing
                }

                @Override
                public void userEventTriggered(ChannelHandlerContext ctx, Object evt) {
                    if (evt instanceof SslHandshakeCompletionEvent) {
                        if (((SslHandshakeCompletionEvent) evt).isSuccess()) {
                            channelPromise.trySuccess(evt);
                        } else {
                            channelPromise.tryFailure(((SslHandshakeCompletionEvent) evt).cause());
                        }
                    }
                    ctx.fireUserEventTriggered(evt);
                }
            });
        }
    }

    /**
     * Fake compression algorithm that "compresses" by prefixing the payload with
     * {@code BASE_PADDING_SIZE + algorithmId} zero bytes and "decompresses" by
     * stripping them again, while recording which direction was exercised.
     */
    private static class TestCertCompressionAlgo implements OpenSslCertificateCompressionAlgorithm {

        private static final int BASE_PADDING_SIZE = 10;
        // Set when the corresponding callback fires; read by the assert* helpers.
        public boolean compressCalled;
        public boolean decompressCalled;
        private final int algorithmId;

        TestCertCompressionAlgo(int algorithmId) {
            this.algorithmId = algorithmId;
        }

        @Override
        public byte[] compress(SSLEngine engine, byte[] input) throws Exception {
            compressCalled = true;
            // Shift the payload right by BASE_PADDING_SIZE + algorithmId zero bytes.
            byte[] output = new byte[input.length + BASE_PADDING_SIZE + algorithmId];
            System.arraycopy(input, 0, output, BASE_PADDING_SIZE + algorithmId, input.length);
            return output;
        }

        @Override
        public byte[] decompress(SSLEngine engine, int uncompressed_len, byte[] input) {
            decompressCalled = true;
            // Inverse of compress(): strip the leading padding bytes.
            byte[] output = new byte[input.length - (BASE_PADDING_SIZE + algorithmId)];
            System.arraycopy(input, BASE_PADDING_SIZE + algorithmId, output, 0, output.length);
            return output;
        }

        @Override
        public int algorithmId() {
            return algorithmId;
        }
    }
}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java
new file mode 100644
index 0000000..6380855
--- /dev/null
+++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslCertificateExceptionTest.java
@@ -0,0 +1,61 @@
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.CertificateVerifier; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.lang.reflect.Field; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class OpenSslCertificateExceptionTest { + + @BeforeAll + public static void ensureOpenSsl() { + OpenSsl.ensureAvailability(); + } + + @Test + public void testValidErrorCode() throws Exception { + Field[] fields = CertificateVerifier.class.getFields(); + for (Field field : fields) { + if (field.isAccessible()) { + int errorCode = field.getInt(null); + OpenSslCertificateException exception = new OpenSslCertificateException(errorCode); + assertEquals(errorCode, exception.errorCode()); + } + } + } + + @Test + public void testNonValidErrorCode() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + new OpenSslCertificateException(Integer.MIN_VALUE); + } + }); + } + + @Test + public void testCanBeInstancedWhenOpenSslIsNotAvailable() { + new OpenSslCertificateException(0); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslClientContextTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslClientContextTest.java new file mode 100644 index 0000000..7c0e367 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslClientContextTest.java @@ -0,0 +1,36 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import org.junit.jupiter.api.BeforeAll; + +import javax.net.ssl.SSLException; +import java.io.File; + +public class OpenSslClientContextTest extends SslContextTest { + + @BeforeAll + public static void checkOpenSsl() { + OpenSsl.ensureAvailability(); + } + + @Override + protected SslContext newSslContext(File crtFile, File keyFile, String pass) throws SSLException { + return new OpenSslClientContext(crtFile, InsecureTrustManagerFactory.INSTANCE, crtFile, keyFile, pass, + null, null, IdentityCipherSuiteFilter.INSTANCE, ApplicationProtocolConfig.DISABLED, 0, 0); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java new file mode 100644 index 0000000..b5d4903 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslConscryptSslEngineInteropTest.java @@ -0,0 +1,191 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
 * You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.condition.DisabledIf;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLSessionContext;
import java.security.Provider;
import java.util.ArrayList;
import java.util.List;

import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

/**
 * Interop run of the Conscrypt SSL engine suite with an OPENSSL client engine and a
 * JDK server provider (Conscrypt wiring comes from the parent test, presumably —
 * confirm against ConscryptSslEngineTest).
 */
// NOTE(review): raw List types below — generic arguments look stripped by the patch
// extraction (likely List<SSLEngineTestParam> upstream); confirm against the fork.
@DisabledIf("checkConscryptDisabled")
public class OpenSslConscryptSslEngineInteropTest extends ConscryptSslEngineTest {
    @Override
    protected List newTestParams() {
        // Duplicate every parent parameter for both useTasks = true and false.
        List params = super.newTestParams();
        List testParams = new ArrayList();
        for (SSLEngineTestParam param: params) {
            testParams.add(new OpenSslEngineTestParam(true, param));
            testParams.add(new OpenSslEngineTestParam(false, param));
        }
        return testParams;
    }

    @BeforeAll
    public static void checkOpenssl() {
        OpenSsl.ensureAvailability();
    }

    @Override
    protected SslProvider sslClientProvider() {
        return SslProvider.OPENSSL;
    }

    @Override
    protected SslProvider sslServerProvider() {
        return SslProvider.JDK;
    }

    @Override
    protected Provider clientSslContextProvider() {
        // No JCA provider override for the client: it uses the OPENSSL engine.
        return null;
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Disabled("TODO: Make this work with Conscrypt")
    @Override
    public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) {
        super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Disabled("TODO: Make this work with Conscrypt")
    @Override
    public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) {
        super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(param);
    }

    @Override
    protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) {
        // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException.
        return super.mySetupMutualAuthServerIsValidClientException(cause) || causedBySSLException(cause);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testSupportedSignatureAlgorithms(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testSupportedSignatureAlgorithms(param);
    }

    @Override
    protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) {
        // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException.
        return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testSessionLocalWhenNonMutualWithKeyManager(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testSessionLocalWhenNonMutualWithKeyManager(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testSessionCache(SSLEngineTestParam param) throws Exception {
        assumeTrue(OpenSsl.isSessionCacheSupported());
        super.testSessionCache(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception {
        assumeTrue(OpenSsl.isSessionCacheSupported());
        super.testSessionCacheTimeout(param);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    @Override
    public void testSessionCacheSize(SSLEngineTestParam param) throws Exception {
        assumeTrue(OpenSsl.isSessionCacheSupported());
        super.testSessionCacheSize(param);
    }

    @Override
    protected void invalidateSessionsAndAssert(SSLSessionContext context) {
        // Not supported by conscrypt
    }

    @Override
    protected SSLEngine wrapEngine(SSLEngine engine) {
        return Java8SslTestUtils.wrapSSLEngineForTesting(engine);
    }

    @SuppressWarnings("deprecation")
    @Override
    protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) {
        if (context instanceof OpenSslContext) {
            if (param instanceof OpenSslEngineTestParam) {
                ((OpenSslContext) context).setUseTasks(((OpenSslEngineTestParam) param).useTasks);
            }
            // Explicitly enable the session cache, as it is disabled by default on the client side.
            ((OpenSslContext) context).sessionContext().setSessionCacheEnabled(true);
        }
        return context;
    }
}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java
new file mode 100644
index 0000000..ffbcd4a
--- /dev/null
+++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTest.java
@@ -0,0 +1,1649 @@
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol;
import io.netty.handler.ssl.ApplicationProtocolConfig.SelectedListenerFailureBehavior;
import io.netty.handler.ssl.ApplicationProtocolConfig.SelectorFailureBehavior;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.internal.tcnative.SSL;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.PlatformDependent;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.opentest4j.TestAbortedException;

import java.net.Socket;
import java.nio.ByteBuffer;
import java.security.AlgorithmConstraints;
import java.security.AlgorithmParameters;
import java.security.CryptoPrimitive;
import java.security.Key;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import javax.crypto.Cipher;
import javax.crypto.spec.IvParameterSpec;
import javax.crypto.spec.SecretKeySpec;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLEngineResult;
import javax.net.ssl.SSLEngineResult.HandshakeStatus;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.X509ExtendedKeyManager;

import static io.netty.handler.ssl.OpenSslContextOption.MAX_CERTIFICATE_LIST_BYTES;
import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory;
import static io.netty.handler.ssl.ReferenceCountedOpenSslEngine.MAX_PLAINTEXT_LENGTH;
import static io.netty.handler.ssl.SslProvider.OPENSSL;
import static io.netty.handler.ssl.SslProvider.isOptionSupported;
import static io.netty.internal.tcnative.SSL.SSL_CVERIFY_IGNORED;
import static java.lang.Integer.MAX_VALUE;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertFalse;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

/**
 * Runs the shared {@link SSLEngineTest} suite against the OPENSSL engine, plus a set
 * of OpenSSL-specific wrap/unwrap buffer-management tests.
 */
// NOTE(review): raw List types below — generic arguments look stripped by the patch
// extraction (likely List<SSLEngineTestParam> upstream); confirm against the fork.
public class OpenSslEngineTest extends SSLEngineTest {
    private static final String PREFERRED_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http2";
    private static final String FALLBACK_APPLICATION_LEVEL_PROTOCOL = "my-protocol-http1_1";

    public OpenSslEngineTest() {
        super(SslProvider.isTlsv13Supported(OPENSSL));
    }

    @Override
    protected List newTestParams() {
        // Only the useTasks = true variant is exercised; see TODO below.
        List params = super.newTestParams();
        List testParams = new ArrayList<>();
        for (SSLEngineTestParam param: params) {
            testParams.add(new OpenSslEngineTestParam(true, param));
            // TODO:: this hangs! JP, 5.1.2024
            //testParams.add(new OpenSslEngineTestParam(false, param));
        }
        return testParams;
    }

    @BeforeAll
    public static void checkOpenSsl() {
        OpenSsl.ensureAvailability();
    }

    @AfterEach
    @Override
    public void tearDown() throws InterruptedException {
        super.tearDown();
        // Each test must leave the thread-local OpenSSL error queue empty.
        assertEquals(0, SSL.getLastErrorNumber(), "SSL error stack not correctly consumed");
    }

    @Override
    public void testSessionAfterHandshakeKeyManagerFactory(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testSessionAfterHandshakeKeyManagerFactory(param);
    }

    @Override
    public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(param);
    }

    @Override
    public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(param);
    }

    @Override
    public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(param);
    }

    @Override
    public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(param);
    }

    @Override
    public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(param);
    }

    @Override
    public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param)
            throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(param);
    }

    @Override
    public void testHandshakeSession(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testHandshakeSession(param);
    }

    @Override
    public void testSupportedSignatureAlgorithms(SSLEngineTestParam param) throws Exception {
        checkShouldUseKeyManagerFactory();
        super.testSupportedSignatureAlgorithms(param);
    }

    /**
     * Returns {@code false} for LibreSSL 2.6.1 and later (which removed NPN) and
     * {@code true} for everything else, based on the "LibreSSL x.y.z" version string.
     */
    private static boolean isNpnSupported(String versionString) {
        String[] versionStringParts = versionString.split(" ", -1);
        if (versionStringParts.length == 2 && "LibreSSL".equals(versionStringParts[0])) {
            String[] versionParts = versionStringParts[1].split("\\.", -1);
            if (versionParts.length == 3) {
                int major = Integer.parseInt(versionParts[0]);
                if (major < 2) {
                    return true;
                }
                if (major > 2) {
                    return false;
                }
                int minor = Integer.parseInt(versionParts[1]);
                if (minor < 6) {
                    return true;
                }
                if (minor > 6) {
                    return false;
                }
                int bugfix = Integer.parseInt(versionParts[2]);
                if (bugfix > 0) {
                    return false;
                }
            }
        }
        return true;
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testNpn(SSLEngineTestParam param) throws Exception {
        String versionString = OpenSsl.versionString();
        assumeTrue(isNpnSupported(versionString), "LibreSSL 2.6.1 removed NPN support, detected " + versionString);
        ApplicationProtocolConfig apn = acceptingNegotiator(Protocol.NPN,
                PREFERRED_APPLICATION_LEVEL_PROTOCOL);
        setupHandlers(param, apn);
        runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testAlpn(SSLEngineTestParam param) throws Exception {
        assumeTrue(OpenSsl.isAlpnSupported());
        ApplicationProtocolConfig apn = acceptingNegotiator(Protocol.ALPN,
                PREFERRED_APPLICATION_LEVEL_PROTOCOL);
        setupHandlers(param, apn);
        runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testAlpnCompatibleProtocolsDifferentClientOrder(SSLEngineTestParam param) throws Exception {
        assumeTrue(OpenSsl.isAlpnSupported());
        // Client and server list the same protocols in opposite order; the server's
        // preference must win.
        ApplicationProtocolConfig clientApn = acceptingNegotiator(Protocol.ALPN,
                FALLBACK_APPLICATION_LEVEL_PROTOCOL, PREFERRED_APPLICATION_LEVEL_PROTOCOL);
        ApplicationProtocolConfig serverApn = acceptingNegotiator(Protocol.ALPN,
                PREFERRED_APPLICATION_LEVEL_PROTOCOL, FALLBACK_APPLICATION_LEVEL_PROTOCOL);
        setupHandlers(param, serverApn, clientApn);
        assertNull(serverException);
        runTest(PREFERRED_APPLICATION_LEVEL_PROTOCOL);
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testEnablingAnAlreadyDisabledSslProtocol(SSLEngineTestParam param) throws Exception {
        testEnablingAnAlreadyDisabledSslProtocol(param, new String[]{SslProtocols.SSL_v2_HELLO},
                new String[]{SslProtocols.SSL_v2_HELLO, SslProtocols.TLS_v1_2});
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testWrapBuffersNoWritePendingError(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine clientEngine = null;
        SSLEngine serverEngine = null;
        try {
            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            handshake(param.type(), param.delegate(), clientEngine, serverEngine);

            ByteBuffer src = allocateBuffer(param.type(), 1024 * 10);
            byte[] data = new byte[src.capacity()];
            PlatformDependent.threadLocalRandom().nextBytes(data);
            src.put(data).flip();
            // A 1-byte destination guarantees BUFFER_OVERFLOW on every wrap.
            ByteBuffer dst = allocateBuffer(param.type(), 1);
            // Try to wrap multiple times so we are more likely to hit the issue.
            for (int i = 0; i < 100; i++) {
                src.position(0);
                dst.position(0);
                assertSame(SSLEngineResult.Status.BUFFER_OVERFLOW, clientEngine.wrap(src, dst).getStatus());
            }
        } finally {
            cleanupClientSslEngine(clientEngine);
            cleanupServerSslEngine(serverEngine);
        }
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testOnlySmallBufferNeededForWrap(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine clientEngine = null;
        SSLEngine serverEngine = null;
        try {
            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            handshake(param.type(), param.delegate(), clientEngine, serverEngine);

            // Allocate a buffer which is small enough and set the limit to the capacity to mark its whole content
            // as readable.
            int srcLen = 1024;
            ByteBuffer src = allocateBuffer(param.type(), srcLen);

            ByteBuffer dstTooSmall = allocateBuffer(
                    param.type(), src.capacity() + unwrapEngine(clientEngine).maxWrapOverhead() - 1);
            ByteBuffer dst = allocateBuffer(
                    param.type(), src.capacity() + unwrapEngine(clientEngine).maxWrapOverhead());

            // Check that we fail to wrap if the dst buffers capacity is not at least
            // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH
            SSLEngineResult result = clientEngine.wrap(src, dstTooSmall);
            assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
            assertEquals(0, result.bytesConsumed());
            assertEquals(0, result.bytesProduced());
            assertEquals(src.remaining(), src.capacity());
            assertEquals(dst.remaining(), dst.capacity());

            // Check that we can wrap with a dst buffer that has the capacity of
            // src.capacity() + ReferenceCountedOpenSslEngine.MAX_TLS_RECORD_OVERHEAD_LENGTH
            result = clientEngine.wrap(src, dst);
            assertEquals(SSLEngineResult.Status.OK, result.getStatus());
            assertEquals(srcLen, result.bytesConsumed());
            assertEquals(0, src.remaining());
            assertTrue(result.bytesProduced() > srcLen);
            assertEquals(src.capacity() - result.bytesConsumed(), src.remaining());
            assertEquals(dst.capacity() - result.bytesProduced(), dst.remaining());
        } finally {
            cleanupClientSslEngine(clientEngine);
            cleanupServerSslEngine(serverEngine);
        }
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testNeededDstCapacityIsCorrectlyCalculated(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine clientEngine = null;
        SSLEngine serverEngine = null;
        try {
            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            handshake(param.type(), param.delegate(), clientEngine, serverEngine);

            ByteBuffer src = allocateBuffer(param.type(), 1024);
            ByteBuffer src2 = src.duplicate();

            // dst is sized for ONE src only, so wrapping both must overflow without
            // consuming or producing anything.
            ByteBuffer dst = allocateBuffer(param.type(), src.capacity()
                    + unwrapEngine(clientEngine).maxWrapOverhead());

            SSLEngineResult result = clientEngine.wrap(new ByteBuffer[] { src, src2 }, dst);
            assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());
            assertEquals(0, src.position());
            assertEquals(0, src2.position());
            assertEquals(0, dst.position());
            assertEquals(0, result.bytesConsumed());
            assertEquals(0, result.bytesProduced());
        } finally {
            cleanupClientSslEngine(clientEngine);
            cleanupServerSslEngine(serverEngine);
        }
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testSrcsLenOverFlowCorrectlyHandled(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine clientEngine = null;
        SSLEngine serverEngine = null;
        try {
            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            handshake(param.type(), param.delegate(), clientEngine, serverEngine);

            ByteBuffer src = allocateBuffer(param.type(), 1024);
            List srcList = new ArrayList();
            long srcsLen = 0;
            // Accumulate a total source length of 2 * Integer.MAX_VALUE so the summed
            // length overflows int; the engine must still handle it gracefully.
            long maxLen = ((long) MAX_VALUE) * 2;

            while (srcsLen < maxLen) {
                ByteBuffer dup = src.duplicate();
                srcList.add(dup);
                srcsLen += dup.capacity();
            }

            ByteBuffer[] srcs = srcList.toArray(new ByteBuffer[0]);
            ByteBuffer dst = allocateBuffer(
                    param.type(), unwrapEngine(clientEngine).maxEncryptedPacketLength() - 1);

            SSLEngineResult result = clientEngine.wrap(srcs, dst);
            assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus());

            for (ByteBuffer buffer : srcs) {
                assertEquals(0, buffer.position());
            }
            assertEquals(0, dst.position());
            assertEquals(0, result.bytesConsumed());
            assertEquals(0, result.bytesProduced());
        } finally {
            cleanupClientSslEngine(clientEngine);
            cleanupServerSslEngine(serverEngine);
        }
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testCalculateOutNetBufSizeOverflow(SSLEngineTestParam param) throws SSLException {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine clientEngine = null;
        try {
            clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
            // Even with a pending length of Integer.MAX_VALUE the computed buffer size
            // must stay positive (no int overflow).
            int value = ((ReferenceCountedOpenSslEngine) clientEngine).calculateOutNetBufSize(MAX_VALUE, 1);
            assertTrue(value > 0);
        } finally {
            cleanupClientSslEngine(clientEngine);
        }
    }

    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testCalculateOutNetBufSize0(SSLEngineTestParam param) throws SSLException {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
.build()); + SSLEngine clientEngine = null; + try { + clientEngine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertTrue(((ReferenceCountedOpenSslEngine) clientEngine).calculateOutNetBufSize(0, 1) > 0); + } finally { + cleanupClientSslEngine(clientEngine); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testCorrectlyCalculateSpaceForAlert(SSLEngineTestParam param) throws Exception { + testCorrectlyCalculateSpaceForAlert(param, true); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testCorrectlyCalculateSpaceForAlertJDKCompatabilityModeOff(SSLEngineTestParam param) throws Exception { + testCorrectlyCalculateSpaceForAlert(param, false); + } + + private void testCorrectlyCalculateSpaceForAlert(SSLEngineTestParam param, boolean jdkCompatabilityMode) + throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + if (jdkCompatabilityMode) { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + } else { + clientEngine = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + serverEngine = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + } + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + + // This should produce an alert + clientEngine.closeOutbound(); + + ByteBuffer 
empty = allocateBuffer(param.type(), 0); + ByteBuffer dst = allocateBuffer(param.type(), clientEngine.getSession().getPacketBufferSize()); + // Limit to something that is guaranteed to be too small to hold an SSL Record. + dst.limit(1); + + // As we called closeOutbound() before this should produce a BUFFER_OVERFLOW. + SSLEngineResult result = clientEngine.wrap(empty, dst); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + + // This must calculate a length that can hold an alert at least (or more). + dst.limit(dst.capacity()); + + result = clientEngine.wrap(empty, dst); + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + + // flip the buffer so we can verify we produced a full length buffer. + dst.flip(); + + int length = SslUtils.getEncryptedPacketLength(new ByteBuffer[] { dst }, 0); + assertEquals(length, dst.remaining()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @Override + protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { + ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) handler.engine(); + engine.setVerify(SSL_CVERIFY_IGNORED, 1); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testWrapWithDifferentSizesTLSv1(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build()); + + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "ECDHE-RSA-AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "DES-CBC3-SHA"); + 
testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "AECDH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "CAMELLIA128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "SEED-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "ADH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "EDH-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "ADH-RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "IDEA-CBC-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "CAMELLIA256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "AECDH-RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "ECDHE-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "ECDHE-RSA-AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1, "ECDHE-RSA-RC4-SHA"); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testWrapWithDifferentSizesTLSv1_1(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build()); + + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "ECDHE-RSA-AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "CAMELLIA256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "ECDHE-RSA-AES256-SHA"); + 
testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "SEED-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "CAMELLIA128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "IDEA-CBC-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "AECDH-RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "ADH-RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "ECDHE-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "EDH-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "AECDH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "ADH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_1, "DES-CBC3-SHA"); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testWrapWithDifferentSizesTLSv1_2(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build()); + + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AES128-GCM-SHA256"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-AES256-SHA384"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AECDH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AES256-GCM-SHA384"); + testWrapWithDifferentSizes(param, 
SslProtocols.TLS_v1_2, "AES256-SHA256"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-AES128-GCM-SHA256"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-AES128-SHA256"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "CAMELLIA128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "SEED-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ADH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "EDH-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ADH-RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "CAMELLIA256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AES128-SHA256"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "AECDH-RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-AES256-GCM-SHA384"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.TLS_v1_2, "ECDHE-RSA-RC4-SHA"); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testWrapWithDifferentSizesSSLv3(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .build()); + + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, 
"ADH-AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ADH-CAMELLIA128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "AECDH-AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "AECDH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "CAMELLIA128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "DHE-RSA-AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "SEED-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ADH-AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ADH-SEED-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ADH-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "EDH-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ADH-RC4-MD5"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "IDEA-CBC-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "DHE-RSA-AES128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "CAMELLIA256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "AECDH-RC4-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "DHE-RSA-SEED-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "AECDH-AES256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ECDHE-RSA-DES-CBC3-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ADH-CAMELLIA256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "DHE-RSA-CAMELLIA256-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "DHE-RSA-CAMELLIA128-SHA"); + testWrapWithDifferentSizes(param, SslProtocols.SSL_v3, "ECDHE-RSA-RC4-SHA"); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void 
testMultipleRecordsInOneBufferWithNonZeroPositionJDKCompatabilityModeOff(SSLEngineTestParam param) + throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + try { + // Choose buffer size small enough that we can put multiple buffers into one buffer and pass it into the + // unwrap call without exceed MAX_ENCRYPTED_PACKET_LENGTH. + final int plainClientOutLen = 1024; + ByteBuffer plainClientOut = allocateBuffer(param.type(), plainClientOutLen); + ByteBuffer plainServerOut = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize()); + + ByteBuffer encClientToServer = allocateBuffer(param.type(), client.getSession().getPacketBufferSize()); + + int positionOffset = 1; + // We need to be able to hold 2 records + positionOffset + ByteBuffer combinedEncClientToServer = allocateBuffer( + param.type(), encClientToServer.capacity() * 2 + positionOffset); + combinedEncClientToServer.position(positionOffset); + + handshake(param.type(), param.delegate(), client, server); + + plainClientOut.limit(plainClientOut.capacity()); + SSLEngineResult result = client.wrap(plainClientOut, encClientToServer); + assertEquals(plainClientOut.capacity(), result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + encClientToServer.flip(); + + // Copy the first record into the combined buffer + 
combinedEncClientToServer.put(encClientToServer); + + plainClientOut.clear(); + encClientToServer.clear(); + + result = client.wrap(plainClientOut, encClientToServer); + assertEquals(plainClientOut.capacity(), result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + encClientToServer.flip(); + + // Copy the first record into the combined buffer + combinedEncClientToServer.put(encClientToServer); + + encClientToServer.clear(); + + combinedEncClientToServer.flip(); + combinedEncClientToServer.position(positionOffset); + + // Make sure the limit takes positionOffset into account to the content we are looking at is correct. + combinedEncClientToServer.limit( + combinedEncClientToServer.limit() - positionOffset); + final int combinedEncClientToServerLen = combinedEncClientToServer.remaining(); + + result = server.unwrap(combinedEncClientToServer, plainServerOut); + assertEquals(0, combinedEncClientToServer.remaining()); + assertEquals(combinedEncClientToServerLen, result.bytesConsumed()); + assertEquals(plainClientOutLen, result.bytesProduced()); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testInputTooBigAndFillsUpBuffersJDKCompatabilityModeOff(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = 
wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + try { + ByteBuffer plainClient = allocateBuffer(param.type(), MAX_PLAINTEXT_LENGTH + 100); + ByteBuffer plainClient2 = allocateBuffer(param.type(), 512); + ByteBuffer plainClientTotal = + allocateBuffer(param.type(), plainClient.capacity() + plainClient2.capacity()); + plainClientTotal.put(plainClient); + plainClientTotal.put(plainClient2); + plainClient.clear(); + plainClient2.clear(); + plainClientTotal.flip(); + + // The capacity is designed to trigger an overflow condition. + ByteBuffer encClientToServerTooSmall = allocateBuffer(param.type(), MAX_PLAINTEXT_LENGTH + 28); + ByteBuffer encClientToServer = allocateBuffer(param.type(), client.getSession().getApplicationBufferSize()); + ByteBuffer encClientToServerTotal = + allocateBuffer(param.type(), client.getSession().getApplicationBufferSize() << 1); + ByteBuffer plainServer = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize() << 1); + + handshake(param.type(), param.delegate(), client, server); + + int plainClientRemaining = plainClient.remaining(); + int encClientToServerTooSmallRemaining = encClientToServerTooSmall.remaining(); + SSLEngineResult result = client.wrap(plainClient, encClientToServerTooSmall); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainClientRemaining - plainClient.remaining(), result.bytesConsumed()); + assertEquals(encClientToServerTooSmallRemaining - encClientToServerTooSmall.remaining(), + result.bytesProduced()); + + result = client.wrap(plainClient, encClientToServerTooSmall); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + + plainClientRemaining = plainClient.remaining(); + int encClientToServerRemaining = encClientToServer.remaining(); + result = client.wrap(plainClient, encClientToServer); + 
assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainClientRemaining, result.bytesConsumed()); + assertEquals(encClientToServerRemaining - encClientToServer.remaining(), result.bytesProduced()); + assertEquals(0, plainClient.remaining()); + + final int plainClient2Remaining = plainClient2.remaining(); + encClientToServerRemaining = encClientToServer.remaining(); + result = client.wrap(plainClient2, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainClient2Remaining, result.bytesConsumed()); + assertEquals(encClientToServerRemaining - encClientToServer.remaining(), result.bytesProduced()); + + // Concatenate the too small buffer + encClientToServerTooSmall.flip(); + encClientToServer.flip(); + encClientToServerTotal.put(encClientToServerTooSmall); + encClientToServerTotal.put(encClientToServer); + encClientToServerTotal.flip(); + + // Unwrap in a single call. + final int encClientToServerTotalRemaining = encClientToServerTotal.remaining(); + result = server.unwrap(encClientToServerTotal, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(encClientToServerTotalRemaining, result.bytesConsumed()); + plainServer.flip(); + assertEquals(plainClientTotal, plainServer); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testPartialPacketUnwrapJDKCompatabilityModeOff(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + serverSslCtx = wrapContext(param, SslContextBuilder + 
.forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + try { + ByteBuffer plainClient = allocateBuffer(param.type(), 1024); + ByteBuffer plainClient2 = allocateBuffer(param.type(), 512); + ByteBuffer plainClientTotal = + allocateBuffer(param.type(), plainClient.capacity() + plainClient2.capacity()); + plainClientTotal.put(plainClient); + plainClientTotal.put(plainClient2); + plainClient.clear(); + plainClient2.clear(); + plainClientTotal.flip(); + + ByteBuffer encClientToServer = allocateBuffer(param.type(), client.getSession().getPacketBufferSize()); + ByteBuffer plainServer = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize()); + + handshake(param.type(), param.delegate(), client, server); + + SSLEngineResult result = client.wrap(plainClient, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), plainClient.capacity()); + final int encClientLen = result.bytesProduced(); + + result = client.wrap(plainClient2, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), plainClient2.capacity()); + final int encClientLen2 = result.bytesProduced(); + + // Flip so we can read it. + encClientToServer.flip(); + + // Consume a partial TLS packet. + ByteBuffer encClientFirstHalf = encClientToServer.duplicate(); + encClientFirstHalf.limit(encClientLen / 2); + result = server.unwrap(encClientFirstHalf, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), encClientLen / 2); + encClientToServer.position(result.bytesConsumed()); + + // We now have half of the first packet and the whole second packet, so lets decode all but the last byte. 
+ ByteBuffer encClientAllButLastByte = encClientToServer.duplicate(); + final int encClientAllButLastByteLen = encClientAllButLastByte.remaining() - 1; + encClientAllButLastByte.limit(encClientAllButLastByte.limit() - 1); + result = server.unwrap(encClientAllButLastByte, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), encClientAllButLastByteLen); + encClientToServer.position(encClientToServer.position() + result.bytesConsumed()); + + // Read the last byte and verify the original content has been decrypted. + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), 1); + plainServer.flip(); + assertEquals(plainClientTotal, plainServer); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testBufferUnderFlowAvoidedIfJDKCompatabilityModeOff(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()); + + try { + ByteBuffer plainClient = allocateBuffer(param.type(), 1024); + plainClient.limit(plainClient.capacity()); + + ByteBuffer encClientToServer = allocateBuffer(param.type(), 
client.getSession().getPacketBufferSize()); + ByteBuffer plainServer = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize()); + + handshake(param.type(), param.delegate(), client, server); + + SSLEngineResult result = client.wrap(plainClient, encClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(result.bytesConsumed(), plainClient.capacity()); + + // Flip so we can read it. + encClientToServer.flip(); + int remaining = encClientToServer.remaining(); + + // We limit the buffer so we have less then the header to read, this should result in an BUFFER_UNDERFLOW. + encClientToServer.limit(SslUtils.SSL_RECORD_HEADER_LENGTH - 1); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(SslUtils.SSL_RECORD_HEADER_LENGTH - 1, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + remaining -= result.bytesConsumed(); + + // We limit the buffer so we can read the header but not the rest, this should result in an + // BUFFER_UNDERFLOW. + encClientToServer.limit(SslUtils.SSL_RECORD_HEADER_LENGTH); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(1, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + remaining -= result.bytesConsumed(); + + // We limit the buffer so we can read the header and partly the rest, this should result in an + // BUFFER_UNDERFLOW. 
+ encClientToServer.limit( + SslUtils.SSL_RECORD_HEADER_LENGTH + remaining - 1 - SslUtils.SSL_RECORD_HEADER_LENGTH); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(encClientToServer.limit() - SslUtils.SSL_RECORD_HEADER_LENGTH, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + remaining -= result.bytesConsumed(); + + // Reset limit so we can read the full record. + encClientToServer.limit(remaining); + assertEquals(0, encClientToServer.remaining()); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.BUFFER_UNDERFLOW, result.getStatus()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + + encClientToServer.position(0); + result = server.unwrap(encClientToServer, plainServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(remaining, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + private void testWrapWithDifferentSizes(SSLEngineTestParam param, String protocol, String cipher) throws Exception { + assumeTrue(OpenSsl.SUPPORTED_PROTOCOLS_SET.contains(protocol)); + if (!OpenSsl.isCipherSuiteAvailable(cipher)) { + return; + } + + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + clientEngine.setEnabledCipherSuites(new String[] { cipher }); + clientEngine.setEnabledProtocols(new String[] { protocol }); + serverEngine.setEnabledCipherSuites(new String[] { cipher }); + serverEngine.setEnabledProtocols(new String[] { protocol }); + + try { + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + } catch 
(SSLException e) { + if (e.getMessage().contains("unsupported protocol") || + e.getMessage().contains("no protocols available")) { + throw new TestAbortedException(protocol + " not supported with cipher " + cipher, e); + } + throw e; + } + + int srcLen = 64; + do { + testWrapDstBigEnough(param.type(), clientEngine, srcLen); + srcLen += 64; + } while (srcLen < MAX_PLAINTEXT_LENGTH); + + testWrapDstBigEnough(param.type(), clientEngine, MAX_PLAINTEXT_LENGTH); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + private void testWrapDstBigEnough(BufferType type, SSLEngine engine, int srcLen) throws SSLException { + ByteBuffer src = allocateBuffer(type, srcLen); + ByteBuffer dst = allocateBuffer(type, srcLen + unwrapEngine(engine).maxWrapOverhead()); + + SSLEngineResult result = engine.wrap(src, dst); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + int consumed = result.bytesConsumed(); + int produced = result.bytesProduced(); + assertEquals(srcLen, consumed); + assertTrue(produced > consumed); + + dst.flip(); + assertEquals(produced, dst.remaining()); + assertFalse(src.hasRemaining()); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSNIMatchersDoesNotThrow(SSLEngineTestParam param) throws Exception { + assumeTrue(PlatformDependent.javaVersion() >= 8); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + SSLEngine engine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + try { + SSLParameters parameters = new SSLParameters(); + Java8SslTestUtils.setSNIMatcher(parameters, EmptyArrays.EMPTY_BYTES); + engine.setSSLParameters(parameters); + } finally { + cleanupServerSslEngine(engine); + ssc.delete(); + } + } + + 
/* An SNI name containing '_' must be installable as a matcher, and a non-matching peer name must be rejected by checkSniHostnameMatch. */ @MethodSource("newTestParams") + @ParameterizedTest + public void testSNIMatchersWithSNINameWithUnderscore(SSLEngineTestParam param) throws Exception { + assumeTrue(PlatformDependent.javaVersion() >= 8); + byte[] name = "rb8hx3pww30y3tvw0mwy.v1_1".getBytes(CharsetUtil.UTF_8); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + SSLEngine engine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + try { + SSLParameters parameters = new SSLParameters(); + Java8SslTestUtils.setSNIMatcher(parameters, name); + engine.setSSLParameters(parameters); + assertFalse(unwrapEngine(engine).checkSniHostnameMatch("other".getBytes(CharsetUtil.UTF_8))); + } finally { + cleanupServerSslEngine(engine); + ssc.delete(); + } + } + + /* AlgorithmConstraints that permit nothing must make setSSLParameters fail with IllegalArgumentException. */ @MethodSource("newTestParams") + @ParameterizedTest + public void testAlgorithmConstraintsThrows(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + final SSLEngine engine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + final SSLParameters parameters = new SSLParameters(); + parameters.setAlgorithmConstraints(new AlgorithmConstraints() { + @Override + public boolean permits( + Set primitives, String algorithm, AlgorithmParameters parameters) { + return false; + } + + @Override + public boolean permits(Set primitives, Key key) { + return false; + } + + @Override + public boolean permits( + Set primitives, String algorithm, Key key, AlgorithmParameters parameters) { + return false; + } + }); + try { + 
assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + engine.setSSLParameters(parameters); + } + }); + } finally { + cleanupServerSslEngine(engine); + ssc.delete(); + } + } + + /* Drains all delegated handshake tasks while the engine reports NEED_TASK. */ private static void runTasksIfNeeded(SSLEngine engine) { + if (engine.getHandshakeStatus() == HandshakeStatus.NEED_TASK) { + for (;;) { + Runnable task = engine.getDelegatedTask(); + if (task == null) { + assertNotEquals(HandshakeStatus.NEED_TASK, engine.getHandshakeStatus()); + break; + } + task.run(); + } + } + } + + /* Reads client/server randoms and the master secret through the native SSL API, re-derives the TLS 1.2 key block (RFC 5246 section 6.3) and decrypts a client record by hand to prove the extracted master key is correct. TLSv1.2-only by design. */ @MethodSource("newTestParams") + @ParameterizedTest + public void testExtractMasterkeyWorksCorrectly(SSLEngineTestParam param) throws Exception { + if (param.combo() != ProtocolCipherCombo.tlsv12()) { + return; + } + SelfSignedCertificate cert = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(cert.key(), cert.cert()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslProvider(OPENSSL).build()); + final SSLEngine serverEngine = + wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(cert.certificate()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslProvider(OPENSSL).build()); + final SSLEngine clientEngine = + wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + final String enabledCipher = "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256"; + try { + //pin the cipher suite to a specific one using (elliptic-curve) Diffie-Hellman key exchange + assumeTrue(Arrays.asList(clientEngine.getSupportedCipherSuites()).contains(enabledCipher), + "The diffie hellman cipher is not supported on your runtime."); + + //https://www.ietf.org/rfc/rfc5289.txt + //For cipher suites ending with _SHA256, the PRF is the TLS PRF + //[RFC5246] with SHA-256 as the hash function. The MAC is HMAC + //[RFC2104] with SHA-256 as the hash function. 
+ clientEngine.setEnabledCipherSuites(new String[] { enabledCipher }); + serverEngine.setEnabledCipherSuites(new String[] { enabledCipher }); + + int appBufferMax = clientEngine.getSession().getApplicationBufferSize(); + int netBufferMax = clientEngine.getSession().getPacketBufferSize(); + + /* + * We'll make the input buffers a bit bigger than the max needed + * size, so that unwrap()s following a successful data transfer + * won't generate BUFFER_OVERFLOWS. + */ + ByteBuffer clientIn = ByteBuffer.allocate(appBufferMax + 50); + ByteBuffer serverIn = ByteBuffer.allocate(appBufferMax + 50); + + ByteBuffer cTOs = ByteBuffer.allocate(netBufferMax); + ByteBuffer sTOc = ByteBuffer.allocate(netBufferMax); + + ByteBuffer clientOut = ByteBuffer.wrap("Hi Server, I'm Client".getBytes(CharsetUtil.US_ASCII)); + ByteBuffer serverOut = ByteBuffer.wrap("Hello Client, I'm Server".getBytes(CharsetUtil.US_ASCII)); + + // This implementation is largely imitated from + // https://docs.oracle.com/javase/8/docs/technotes/ + // guides/security/jsse/samples/sslengine/SSLEngineSimpleDemo.java + // It has been simplified however without the need for running delegation tasks + + // Do handshake for SSL + // A typical handshake will usually contain the following steps: + // 1. wrap: ClientHello + // 2. unwrap: ServerHello/Cert/ServerHelloDone + // 3. wrap: ClientKeyExchange + // 4. wrap: ChangeCipherSpec + // 5. wrap: Finished + // 6. unwrap: ChangeCipherSpec + // 7. 
unwrap: Finished + + //use a bounded for loop instead of a while loop to guarantee we eventually quit + boolean asserted = false; + for (int i = 0; i < 1000; i++) { + + clientEngine.wrap(clientOut, cTOs); + serverEngine.wrap(serverOut, sTOc); + + cTOs.flip(); + sTOc.flip(); + + runTasksIfNeeded(clientEngine); + runTasksIfNeeded(serverEngine); + + clientEngine.unwrap(sTOc, clientIn); + serverEngine.unwrap(cTOs, serverIn); + + runTasksIfNeeded(clientEngine); + runTasksIfNeeded(serverEngine); + + // check when the application data has fully been consumed and sent + // for both the client and server + if ((clientOut.limit() == serverIn.position()) && + (serverOut.limit() == clientIn.position())) { + byte[] serverRandom = SSL.getServerRandom(unwrapEngine(serverEngine).sslPointer()); + byte[] clientRandom = SSL.getClientRandom(unwrapEngine(clientEngine).sslPointer()); + byte[] serverMasterKey = SSL.getMasterKey(unwrapEngine(serverEngine).sslPointer()); + byte[] clientMasterKey = SSL.getMasterKey(unwrapEngine(clientEngine).sslPointer()); + + asserted = true; + assertArrayEquals(serverMasterKey, clientMasterKey); + + // let us re-read the encrypted data and decrypt it ourselves! 
+ cTOs.flip(); + sTOc.flip(); + + // See https://tools.ietf.org/html/rfc5246#section-6.3: + // key_block = PRF(SecurityParameters.master_secret, "key expansion", + // SecurityParameters.server_random + SecurityParameters.client_random); + // + // partitioned: + // client_write_MAC_secret[SecurityParameters.hash_size] + // server_write_MAC_secret[SecurityParameters.hash_size] + // client_write_key[SecurityParameters.key_material_length] + // server_write_key[SecurityParameters.key_material_length] + + int keySize = 16; // AES is 16 bytes or 128 bits + int macSize = 32; // SHA256 is 32 bytes or 256 bits + int keyBlockSize = (2 * keySize) + (2 * macSize); + + byte[] seed = new byte[serverRandom.length + clientRandom.length]; + System.arraycopy(serverRandom, 0, seed, 0, serverRandom.length); + System.arraycopy(clientRandom, 0, seed, serverRandom.length, clientRandom.length); + byte[] keyBlock = PseudoRandomFunction.hash(serverMasterKey, + "key expansion".getBytes(CharsetUtil.US_ASCII), seed, keyBlockSize, "HmacSha256"); + + int offset = 0; + byte[] clientWriteMac = Arrays.copyOfRange(keyBlock, offset, offset + macSize); + offset += macSize; + + byte[] serverWriteMac = Arrays.copyOfRange(keyBlock, offset, offset + macSize); + offset += macSize; + + byte[] clientWriteKey = Arrays.copyOfRange(keyBlock, offset, offset + keySize); + offset += keySize; + + byte[] serverWriteKey = Arrays.copyOfRange(keyBlock, offset, offset + keySize); + offset += keySize; + + //advance the cipher text by 5 + //to take into account the TLS Record Header + cTOs.position(cTOs.position() + 5); + + byte[] ciphertext = new byte[cTOs.remaining()]; + cTOs.get(ciphertext); + + //the initialization vector is the first 16 bytes (128 bits) of the payload + byte[] clientWriteIV = Arrays.copyOfRange(ciphertext, 0, 16); + ciphertext = Arrays.copyOfRange(ciphertext, 16, ciphertext.length); + + SecretKeySpec secretKey = new SecretKeySpec(clientWriteKey, "AES"); + final IvParameterSpec ivForCBC = new 
IvParameterSpec(clientWriteIV); + Cipher cipher = Cipher.getInstance("AES/CBC/NoPadding"); + cipher.init(Cipher.DECRYPT_MODE, secretKey, ivForCBC); + byte[] plaintext = cipher.doFinal(ciphertext); + assertTrue(new String(plaintext).startsWith("Hi Server, I'm Client")); + break; + } else { + cTOs.compact(); + sTOc.compact(); + } + } + + assertTrue(asserted, "The assertions were never executed."); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + cert.delete(); + } + } + + /* A key manager that never supplies an alias, certificate chain or private key must make the handshake fail with SSLException. */ @MethodSource("newTestParams") + @ParameterizedTest + public void testNoKeyFound(final SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + final SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(new X509ExtendedKeyManager() { + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return new String[0]; + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return null; + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return new String[0]; + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return new X509Certificate[0]; + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return null; + } + }) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + final SSLEngine server = 
wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + assertThrows(SSLException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshake(param.type(), param.delegate(), client, server); + } + }); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithKeyManager(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionLocalWhenNonMutualWithKeyManager(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithoutKeyManager(SSLEngineTestParam param) throws Exception { + // This only really works when the KeyManagerFactory is supported as otherwise we do not really know when + // we need to provide a cert. + assumeTrue(OpenSsl.supportsKeyManagerFactory()); + super.testSessionLocalWhenNonMutualWithoutKeyManager(param); + } + + /* TLSv1 and TLSv1.1 are disabled by default; a peer restricted to one of them must fail the handshake. The four wrappers below cover both directions. */ @MethodSource("newTestParams") + @ParameterizedTest + public void testDefaultTLS1NotAcceptedByDefaultServer(SSLEngineTestParam param) throws Exception { + testDefaultTLS1NotAcceptedByDefault(param, null, SslProtocols.TLS_v1); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testDefaultTLS11NotAcceptedByDefaultServer(SSLEngineTestParam param) throws Exception { + testDefaultTLS1NotAcceptedByDefault(param, null, SslProtocols.TLS_v1_1); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testDefaultTLS1NotAcceptedByDefaultClient(SSLEngineTestParam param) throws Exception { + testDefaultTLS1NotAcceptedByDefault(param, SslProtocols.TLS_v1, null); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testDefaultTLS11NotAcceptedByDefaultClient(SSLEngineTestParam param) throws Exception { + testDefaultTLS1NotAcceptedByDefault(param, SslProtocols.TLS_v1_1, null); + } + + 
/* Shared driver: pins one side to a legacy protocol (null means provider defaults) and expects the handshake to fail with SSLHandshakeException. */ private void testDefaultTLS1NotAcceptedByDefault(final SSLEngineTestParam param, + String clientProtocol, String serverProtocol) throws Exception { + SslContextBuilder clientCtxBuilder = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()); + if (clientProtocol != null) { + clientCtxBuilder.protocols(clientProtocol); + } + clientSslCtx = wrapContext(param, clientCtxBuilder.build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + SslContextBuilder serverCtxBuilder = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()); + if (serverProtocol != null) { + serverCtxBuilder.protocols(serverProtocol); + } + serverSslCtx = wrapContext(param, serverCtxBuilder.build()); + final SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + final SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + assertThrows(SSLHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshake(param.type(), param.delegate(), client, server); + } + }); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + ssc.delete(); + } + } + + @Override + protected SslProvider sslClientProvider() { + return OPENSSL; + } + + @Override + protected SslProvider sslServerProvider() { + return OPENSSL; + } + + /* ALPN/NPN config that does not advertise selector failures and accepts the listener's outcome. */ private static ApplicationProtocolConfig acceptingNegotiator(Protocol protocol, + String... 
supportedProtocols) { + return new ApplicationProtocolConfig(protocol, + SelectorFailureBehavior.NO_ADVERTISE, + SelectedListenerFailureBehavior.ACCEPT, + supportedProtocols); + } + + /* Wraps engines with the error-stack-asserting test decorator on Java 8+. */ @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + if (PlatformDependent.javaVersion() >= 8) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } + return engine; + } + + /* Unwraps a possibly JdkSslEngine-wrapped engine back to the underlying ReferenceCountedOpenSslEngine. */ ReferenceCountedOpenSslEngine unwrapEngine(SSLEngine engine) { + if (engine instanceof JdkSslEngine) { + return (ReferenceCountedOpenSslEngine) ((JdkSslEngine) engine).getWrappedEngine(); + } + return (ReferenceCountedOpenSslEngine) engine; + } + + @SuppressWarnings("deprecation") + @Override + protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) { + if (context instanceof OpenSslContext) { + if (param instanceof OpenSslEngineTestParam) { + ((OpenSslContext) context).setUseTasks(((OpenSslEngineTestParam) param).useTasks); + } + // Explicitly enable the session cache as it is disabled by default on the client side. 
+ ((OpenSslContext) context).sessionContext().setSessionCacheEnabled(true); + } + return context; + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCache(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCache(param); + assertSessionContext(clientSslCtx); + assertSessionContext(serverSslCtx); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheTimeout(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheSize(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheSize(param); + } + + /* Verifies the cache is enabled and that disabling it flushes any stored session ids. */ private static void assertSessionContext(SslContext context) { + if (context == null) { + return; + } + OpenSslSessionContext serverSessionCtx = (OpenSslSessionContext) context.sessionContext(); + assertTrue(serverSessionCtx.isSessionCacheEnabled()); + if (serverSessionCtx.getIds().hasMoreElements()) { + serverSessionCtx.setSessionCacheEnabled(false); + assertFalse(serverSessionCtx.getIds().hasMoreElements()); + assertFalse(serverSessionCtx.isSessionCacheEnabled()); + } + } + + @Override + protected void assertSessionReusedForEngine(SSLEngine clientEngine, SSLEngine serverEngine, boolean reuse) { + assertEquals(reuse, unwrapEngine(clientEngine).isSessionReused()); + assertEquals(reuse, unwrapEngine(serverEngine).isSessionReused()); + } + + @Override + protected boolean isSessionMaybeReused(SSLEngine engine) { + return unwrapEngine(engine).isSessionReused(); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testRSASSAPSS(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testRSASSAPSS(param); + } + + @Test + public 
void testExtraDataInLastSrcBufferForClientUnwrapNonjdkCompatabilityMode() throws Exception { + SSLEngineTestParam param = new SSLEngineTestParam(BufferType.Direct, ProtocolCipherCombo.tlsv12(), false); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .clientAuth(ClientAuth.NONE) + .build()); + testExtraDataInLastSrcBufferForClientUnwrap(param, + wrapEngine(clientSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine()), + wrapEngine(serverSslCtx.newHandler(UnpooledByteBufAllocator.DEFAULT).engine())); + } + + /* A tiny MAX_CERTIFICATE_LIST_BYTES limit (10 bytes) must abort the mutual-auth handshake with SSLHandshakeException. */ @MethodSource("newTestParams") + @ParameterizedTest + public void testMaxCertificateList(final SSLEngineTestParam param) throws Exception { + assumeTrue(isOptionSupported(sslClientProvider(), MAX_CERTIFICATE_LIST_BYTES)); + assumeTrue(isOptionSupported(sslServerProvider(), MAX_CERTIFICATE_LIST_BYTES)); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .keyManager(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .option(MAX_CERTIFICATE_LIST_BYTES, 10) + .build()); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + 
.protocols(param.protocols()) + .ciphers(param.ciphers()) + .option(MAX_CERTIFICATE_LIST_BYTES, 10) + .clientAuth(ClientAuth.REQUIRE) + .build()); + + final SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + final SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + assertThrows(SSLHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshake(param.type(), param.delegate(), client, server); + } + }); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + ssc.delete(); + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java new file mode 100644 index 0000000..a8f725f --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslEngineTestParam.java @@ -0,0 +1,34 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +/* SSLEngineTestParam variant that additionally carries the OpenSSL-specific useTasks flag (applied to OpenSslContext by the test's wrapContext). */ final class OpenSslEngineTestParam extends SSLEngineTest.SSLEngineTestParam { + final boolean useTasks; + OpenSslEngineTestParam(boolean useTasks, SSLEngineTest.SSLEngineTestParam param) { + super(param.type(), param.combo(), param.delegate()); + this.useTasks = useTasks; + } + + @Override + public String toString() { + return "OpenSslEngineTestParam{" + + "type=" + type() + + ", protocolCipherCombo=" + combo() + + ", delegate=" + delegate() + + ", useTasks=" + useTasks + + '}'; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java new file mode 100644 index 0000000..9163950 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslErrorStackAssertSSLEngine.java @@ -0,0 +1,442 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.internal.tcnative.SSL; +import io.netty.util.ReferenceCounted; +import io.netty.util.internal.PlatformDependent; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLParameters; +import javax.net.ssl.SSLSession; +import java.nio.ByteBuffer; +import java.util.List; +import java.util.function.BiFunction; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +/** + * Special {@link SSLEngine} which allows to wrap a {@link ReferenceCountedOpenSslEngine} and verify that the + * error stack is empty after each method call. + */ +final class OpenSslErrorStackAssertSSLEngine extends JdkSslEngine implements ReferenceCounted { + + OpenSslErrorStackAssertSSLEngine(ReferenceCountedOpenSslEngine engine) { + super(engine); + } + + @Override + public String getPeerHost() { + try { + return getWrappedEngine().getPeerHost(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public int getPeerPort() { + try { + return getWrappedEngine().getPeerPort(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult wrap(ByteBuffer src, ByteBuffer dst) throws SSLException { + try { + return getWrappedEngine().wrap(src, dst); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] srcs, ByteBuffer dst) throws SSLException { + try { + return getWrappedEngine().wrap(srcs, dst); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult wrap(ByteBuffer[] byteBuffers, int i, int i1, ByteBuffer byteBuffer) throws SSLException { + try { + return getWrappedEngine().wrap(byteBuffers, i, i1, byteBuffer); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer dst) throws SSLException { + try { + return getWrappedEngine().unwrap(src, dst); + } finally { + 
assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult unwrap(ByteBuffer src, ByteBuffer[] dsts) throws SSLException { + try { + return getWrappedEngine().unwrap(src, dsts); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult unwrap(ByteBuffer byteBuffer, ByteBuffer[] byteBuffers, int i, int i1) throws SSLException { + try { + return getWrappedEngine().unwrap(byteBuffer, byteBuffers, i, i1); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public Runnable getDelegatedTask() { + try { + return getWrappedEngine().getDelegatedTask(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void closeInbound() throws SSLException { + try { + getWrappedEngine().closeInbound(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean isInboundDone() { + try { + return getWrappedEngine().isInboundDone(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void closeOutbound() { + try { + getWrappedEngine().closeOutbound(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean isOutboundDone() { + try { + return getWrappedEngine().isOutboundDone(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getSupportedCipherSuites() { + try { + return getWrappedEngine().getSupportedCipherSuites(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getEnabledCipherSuites() { + try { + return getWrappedEngine().getEnabledCipherSuites(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setEnabledCipherSuites(String[] strings) { + try { + getWrappedEngine().setEnabledCipherSuites(strings); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] getSupportedProtocols() { + try { + return getWrappedEngine().getSupportedProtocols(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public String[] 
getEnabledProtocols() { + try { + return getWrappedEngine().getEnabledProtocols(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setEnabledProtocols(String[] strings) { + try { + getWrappedEngine().setEnabledProtocols(strings); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLSession getSession() { + try { + return getWrappedEngine().getSession(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLSession getHandshakeSession() { + try { + return getWrappedEngine().getHandshakeSession(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void beginHandshake() throws SSLException { + try { + getWrappedEngine().beginHandshake(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLEngineResult.HandshakeStatus getHandshakeStatus() { + try { + return getWrappedEngine().getHandshakeStatus(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setUseClientMode(boolean b) { + try { + getWrappedEngine().setUseClientMode(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getUseClientMode() { + try { + return getWrappedEngine().getUseClientMode(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setNeedClientAuth(boolean b) { + try { + getWrappedEngine().setNeedClientAuth(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getNeedClientAuth() { + try { + return getWrappedEngine().getNeedClientAuth(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setWantClientAuth(boolean b) { + try { + getWrappedEngine().setWantClientAuth(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getWantClientAuth() { + try { + return getWrappedEngine().getWantClientAuth(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setEnableSessionCreation(boolean b) { + try { + 
getWrappedEngine().setEnableSessionCreation(b); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public boolean getEnableSessionCreation() { + try { + return getWrappedEngine().getEnableSessionCreation(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public SSLParameters getSSLParameters() { + try { + return getWrappedEngine().getSSLParameters(); + } finally { + assertErrorStackEmpty(); + } + } + + @Override + public void setSSLParameters(SSLParameters params) { + try { + getWrappedEngine().setSSLParameters(params); + } finally { + assertErrorStackEmpty(); + } + } + + public String getApplicationProtocol() { + if (PlatformDependent.javaVersion() >= 9) { + try { + return JdkAlpnSslUtils.getApplicationProtocol(getWrappedEngine()); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + public String getHandshakeApplicationProtocol() { + if (PlatformDependent.javaVersion() >= 9) { + try { + return JdkAlpnSslUtils.getHandshakeApplicationProtocol(getWrappedEngine()); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + public void setHandshakeApplicationProtocolSelector(BiFunction, String> selector) { + if (PlatformDependent.javaVersion() >= 9) { + try { + JdkAlpnSslUtils.setHandshakeApplicationProtocolSelector(getWrappedEngine(), selector); + } finally { + assertErrorStackEmpty(); + } + } + /* NOTE(review): unlike the getters above there is no return inside the if-branch, so on Java 9+ the selector is applied and UnsupportedOperationException is still thrown afterwards — looks like a missing 'return;'; confirm and fix upstream. */ + throw new UnsupportedOperationException(); + } + + public BiFunction, String> getHandshakeApplicationProtocolSelector() { + if (PlatformDependent.javaVersion() >= 9) { + try { + return JdkAlpnSslUtils.getHandshakeApplicationProtocolSelector(getWrappedEngine()); + } finally { + assertErrorStackEmpty(); + } + } + throw new UnsupportedOperationException(); + } + + @Override + public int refCnt() { + return getWrappedEngine().refCnt(); + } + + @Override + public OpenSslErrorStackAssertSSLEngine retain() { + getWrappedEngine().retain(); + return this; + } + + 
@Override + public OpenSslErrorStackAssertSSLEngine retain(int increment) { + getWrappedEngine().retain(increment); + return this; + } + + @Override + public OpenSslErrorStackAssertSSLEngine touch() { + getWrappedEngine().touch(); + return this; + } + + @Override + public OpenSslErrorStackAssertSSLEngine touch(Object hint) { + getWrappedEngine().touch(hint); + return this; + } + + @Override + public boolean release() { + return getWrappedEngine().release(); + } + + @Override + public boolean release(int decrement) { + return getWrappedEngine().release(decrement); + } + + @Override + public String getNegotiatedApplicationProtocol() { + return getWrappedEngine().getNegotiatedApplicationProtocol(); + } + + @Override + void setNegotiatedApplicationProtocol(String applicationProtocol) { + throw new UnsupportedOperationException(); + } + + @Override + public ReferenceCountedOpenSslEngine getWrappedEngine() { + return (ReferenceCountedOpenSslEngine) super.getWrappedEngine(); + } + + /* Fails the test when tcnative's thread-local SSL error queue is non-empty after a delegated call. */ private static void assertErrorStackEmpty() { + long error = SSL.getLastErrorNumber(); + assertEquals(0, error, "SSL error stack non-empty: " + SSL.getErrorString(error)); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java new file mode 100644 index 0000000..68e024f --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslJdkSslEngineInteroptTest.java @@ -0,0 +1,203 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.SSLEngine; + +import java.util.ArrayList; +import java.util.List; + +import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class OpenSslJdkSslEngineInteroptTest extends SSLEngineTest { + + public OpenSslJdkSslEngineInteroptTest() { + super(SslProvider.isTlsv13Supported(SslProvider.JDK) && + SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); + } + + @Override + protected List newTestParams() { + List params = super.newTestParams(); + List testParams = new ArrayList(); + for (SSLEngineTestParam param: params) { + testParams.add(new OpenSslEngineTestParam(true, param)); + // TODO hangs! JP, 5.1.2024 + //testParams.add(new OpenSslEngineTestParam(false, param)); + } + return testParams; + } + + @BeforeAll + public static void checkOpenSsl() { + OpenSsl.ensureAvailability(); + } + + @Override + protected SslProvider sslClientProvider() { + return SslProvider.OPENSSL; + } + + @Override + protected SslProvider sslServerProvider() { + return SslProvider.JDK; + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? 
*/ + @Override + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Disabled /* Does the JDK support a "max certificate chain length"? */ + @Override + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) + throws Exception { + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(SSLEngineTestParam param) + throws Exception { + checkShouldUseKeyManagerFactory(); + super.testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionAfterHandshakeKeyManagerFactoryMutualAuth(param); + } + + @Override + protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { + // TODO(scott): work around for a JDK issue. The exception should be SSLHandshakeException. 
+ return super.mySetupMutualAuthServerIsValidServerException(cause) || causedBySSLException(cause); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testHandshakeSession(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testHandshakeSession(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSupportedSignatureAlgorithms(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSupportedSignatureAlgorithms(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithKeyManager(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testSessionLocalWhenNonMutualWithKeyManager(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionLocalWhenNonMutualWithoutKeyManager(SSLEngineTestParam param) throws Exception { + // This only really works when the KeyManagerFactory is supported as otherwise we not really know when + // we need to provide a cert. 
+ assumeTrue(OpenSsl.supportsKeyManagerFactory()); + super.testSessionLocalWhenNonMutualWithoutKeyManager(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCache(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCache(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheTimeout(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testSessionCacheSize(SSLEngineTestParam param) throws Exception { + assumeTrue(OpenSsl.isSessionCacheSupported()); + super.testSessionCacheSize(param); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Override + public void testRSASSAPSS(SSLEngineTestParam param) throws Exception { + checkShouldUseKeyManagerFactory(); + super.testRSASSAPSS(param); + } + + @Override + protected SSLEngine wrapEngine(SSLEngine engine) { + return Java8SslTestUtils.wrapSSLEngineForTesting(engine); + } + + @SuppressWarnings("deprecation") + @Override + protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) { + if (context instanceof OpenSslContext && param instanceof OpenSslEngineTestParam) { + ((OpenSslContext) context).setUseTasks(((OpenSslEngineTestParam) param).useTasks); + // Explicit enable the session cache as its disabled by default on the client side. 
+ ((OpenSslContext) context).sessionContext().setSessionCacheEnabled(true); + } + return context; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java new file mode 100644 index 0000000..c828dc9 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialManagerTest.java @@ -0,0 +1,83 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.SSLException; +import javax.net.ssl.X509ExtendedKeyManager; +import java.net.Socket; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +import static org.junit.jupiter.api.Assertions.fail; + +public class OpenSslKeyMaterialManagerTest { + + @Test + public void testChooseClientAliasReturnsNull() throws SSLException { + OpenSsl.ensureAvailability(); + + X509ExtendedKeyManager keyManager = new X509ExtendedKeyManager() { + @Override + public String[] getClientAliases(String s, Principal[] principals) { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public String chooseClientAlias(String[] strings, Principal[] principals, Socket socket) { + return null; + } + + @Override + public String[] getServerAliases(String s, Principal[] principals) { + return EmptyArrays.EMPTY_STRINGS; + } + + @Override + public String chooseServerAlias(String s, Principal[] principals, Socket socket) { + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String s) { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + + @Override + public PrivateKey getPrivateKey(String s) { + return null; + } + }; + + OpenSslKeyMaterialManager manager = new OpenSslKeyMaterialManager( + new OpenSslKeyMaterialProvider(keyManager, null) { + @Override + OpenSslKeyMaterial chooseKeyMaterial(ByteBufAllocator allocator, String alias) throws Exception { + fail("Should not be called when alias is null"); + return null; + } + }); + SslContext context = SslContextBuilder.forClient().sslProvider(SslProvider.OPENSSL).build(); + OpenSslEngine engine = + (OpenSslEngine) context.newEngine(UnpooledByteBufAllocator.DEFAULT); + manager.setKeyMaterialClientSide(engine, EmptyArrays.EMPTY_STRINGS, null); + 
} +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java new file mode 100644 index 0000000..09976ee --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslKeyMaterialProviderTest.java @@ -0,0 +1,180 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.internal.tcnative.SSL; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.X509KeyManager; + +import java.net.Socket; +import java.security.KeyStore; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.X509Certificate; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpenSslKeyMaterialProviderTest { + + static final String PASSWORD = "example"; + static final String EXISTING_ALIAS = "1"; + private static final String NON_EXISTING_ALIAS = "nonexisting"; + + @BeforeAll + static void checkOpenSsl() { + OpenSsl.ensureAvailability(); + } + + protected KeyManagerFactory newKeyManagerFactory() throws Exception { + return newKeyManagerFactory(KeyManagerFactory.getDefaultAlgorithm()); + } + + protected KeyManagerFactory newKeyManagerFactory(String algorithm) throws Exception { + char[] password = PASSWORD.toCharArray(); + final KeyStore keystore = KeyStore.getInstance("PKCS12"); + keystore.load(getClass().getResourceAsStream("mutual_auth_server.p12"), password); + + KeyManagerFactory kmf = + KeyManagerFactory.getInstance(algorithm); + kmf.init(keystore, password); + return kmf; + } + + protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory factory, String password) { + return new OpenSslKeyMaterialProvider(ReferenceCountedOpenSslContext.chooseX509KeyManager( + factory.getKeyManagers()), password); + } + + protected void assertRelease(OpenSslKeyMaterial material) { + 
assertTrue(material.release()); + } + + @Test + public void testChooseKeyMaterial() throws Exception { + OpenSslKeyMaterialProvider provider = newMaterialProvider(newKeyManagerFactory(), PASSWORD); + OpenSslKeyMaterial nonExistingMaterial = provider.chooseKeyMaterial( + UnpooledByteBufAllocator.DEFAULT, NON_EXISTING_ALIAS); + assertNull(nonExistingMaterial); + + OpenSslKeyMaterial material = provider.chooseKeyMaterial(UnpooledByteBufAllocator.DEFAULT, EXISTING_ALIAS); + assertNotNull(material); + assertNotEquals(0, material.certificateChainAddress()); + assertNotEquals(0, material.privateKeyAddress()); + assertRelease(material); + + provider.destroy(); + } + + /** + * Test class used by testChooseOpenSslPrivateKeyMaterial(). + */ + private static final class SingleKeyManager implements X509KeyManager { + private final String keyAlias; + private final PrivateKey pk; + private final X509Certificate[] certChain; + + SingleKeyManager(String keyAlias, PrivateKey pk, X509Certificate[] certChain) { + this.keyAlias = keyAlias; + this.pk = pk; + this.certChain = certChain; + } + + @Override + public String[] getClientAliases(String keyType, Principal[] issuers) { + return new String[]{keyAlias}; + } + + @Override + public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { + return keyAlias; + } + + @Override + public String[] getServerAliases(String keyType, Principal[] issuers) { + return new String[]{keyAlias}; + } + + @Override + public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { + return keyAlias; + } + + @Override + public X509Certificate[] getCertificateChain(String alias) { + return certChain; + } + + @Override + public PrivateKey getPrivateKey(String alias) { + return pk; + } + } + + @Test + public void testChooseOpenSslPrivateKeyMaterial() throws Exception { + PrivateKey privateKey = SslContext.toPrivateKey( + getClass().getResourceAsStream("localhost_server.key"), + null); + 
assertNotNull(privateKey); + assertEquals("PKCS#8", privateKey.getFormat()); + final X509Certificate[] certChain = SslContext.toX509Certificates( + getClass().getResourceAsStream("localhost_server.pem")); + assertNotNull(certChain); + PemEncoded pemKey = null; + long pkeyBio = 0L; + OpenSslPrivateKey sslPrivateKey; + try { + pemKey = PemPrivateKey.toPEM(ByteBufAllocator.DEFAULT, true, privateKey); + pkeyBio = ReferenceCountedOpenSslContext.toBIO(ByteBufAllocator.DEFAULT, pemKey.retain()); + sslPrivateKey = new OpenSslPrivateKey(SSL.parsePrivateKey(pkeyBio, null)); + } finally { + ReferenceCountUtil.safeRelease(pemKey); + if (pkeyBio != 0L) { + SSL.freeBIO(pkeyBio); + } + } + final String keyAlias = "key"; + + OpenSslKeyMaterialProvider provider = new OpenSslKeyMaterialProvider( + new SingleKeyManager(keyAlias, sslPrivateKey, certChain), + null); + OpenSslKeyMaterial material = provider.chooseKeyMaterial(ByteBufAllocator.DEFAULT, keyAlias); + assertNotNull(material); + assertEquals(2, sslPrivateKey.refCnt()); + assertEquals(1, material.refCnt()); + assertTrue(material.release()); + assertEquals(1, sslPrivateKey.refCnt()); + // Can get material multiple times from the same key + material = provider.chooseKeyMaterial(ByteBufAllocator.DEFAULT, keyAlias); + assertNotNull(material); + assertEquals(2, sslPrivateKey.refCnt()); + assertTrue(material.release()); + assertTrue(sslPrivateKey.release()); + assertEquals(0, sslPrivateKey.refCnt()); + assertEquals(0, material.refCnt()); + assertEquals(0, ((OpenSslPrivateKey.OpenSslPrivateKeyMaterial) material).certificateChain); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java new file mode 100644 index 0000000..dc65305 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslPrivateKeyMethodTest.java @@ -0,0 +1,481 @@ +/* + * Copyright 2019 The Netty Project + * + * 
The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ThreadLocalRandom; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.SSLContext; +import 
javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLHandshakeException; +import java.net.SocketAddress; +import java.security.NoSuchAlgorithmException; +import java.security.Signature; +import java.security.SignatureException; +import java.security.spec.MGF1ParameterSpec; +import java.security.spec.PSSParameterSpec; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.handler.ssl.OpenSslTestUtils.checkShouldUseKeyManagerFactory; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class OpenSslPrivateKeyMethodTest { + private static final String RFC_CIPHER_NAME = "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"; + private static EventLoopGroup GROUP; + private static SelfSignedCertificate CERT; + private static ExecutorService EXECUTOR; + + static Collection parameters() { + List dst = new ArrayList(); + for (int a = 0; a < 2; a++) { + for (int b = 0; b < 2; b++) { + for (int c = 0; c < 2; c++) { + dst.add(new Object[] { a == 0, b == 0, c == 0 }); + } + } + } + return dst; + } + + @BeforeAll + public static void init() throws Exception { + checkShouldUseKeyManagerFactory(); + + assumeTrue(OpenSsl.isBoringSSL()); + // Check if the cipher is supported at all which may not be the case for various JDK versions and OpenSSL API + // implementations. 
+ assumeCipherAvailable(SslProvider.OPENSSL); + assumeCipherAvailable(SslProvider.JDK); + + GROUP = new DefaultEventLoopGroup(); + CERT = new SelfSignedCertificate(); + EXECUTOR = Executors.newCachedThreadPool(new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + return new DelegateThread(r); + } + }); + } + + @AfterAll + public static void destroy() { + if (OpenSsl.isBoringSSL()) { + GROUP.shutdownGracefully(); + CERT.delete(); + EXECUTOR.shutdown(); + } + } + + private static void assumeCipherAvailable(SslProvider provider) throws NoSuchAlgorithmException { + boolean cipherSupported = false; + if (provider == SslProvider.JDK) { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + for (String c: engine.getSupportedCipherSuites()) { + if (RFC_CIPHER_NAME.equals(c)) { + cipherSupported = true; + break; + } + } + } else { + cipherSupported = OpenSsl.isCipherSuiteAvailable(RFC_CIPHER_NAME); + } + assumeTrue(cipherSupported, "Unsupported cipher: " + RFC_CIPHER_NAME); + } + + private static SslHandler newSslHandler(SslContext sslCtx, ByteBufAllocator allocator, Executor executor) { + if (executor == null) { + return sslCtx.newHandler(allocator); + } else { + return sslCtx.newHandler(allocator, executor); + } + } + + private SslContext buildServerContext(OpenSslPrivateKeyMethod method) throws Exception { + List ciphers = Collections.singletonList(RFC_CIPHER_NAME); + + final KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert()); + + return SslContextBuilder.forServer(kmf) + .sslProvider(SslProvider.OPENSSL) + .ciphers(ciphers) + // As this is not a TLSv1.3 cipher we should ensure we talk something else. 
+ .protocols(SslProtocols.TLS_v1_2) + .option(OpenSslContextOption.PRIVATE_KEY_METHOD, method) + .build(); + } + + private SslContext buildClientContext() throws Exception { + return SslContextBuilder.forClient() + .sslProvider(SslProvider.JDK) + .ciphers(Collections.singletonList(RFC_CIPHER_NAME)) + // As this is not a TLSv1.3 cipher we should ensure we talk something else. + .protocols(SslProtocols.TLS_v1_2) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + } + + private static Executor delegateExecutor(boolean delegate) { + return delegate ? EXECUTOR : null; + } + private SslContext buildServerContext(OpenSslAsyncPrivateKeyMethod method) throws Exception { + List ciphers = Collections.singletonList(RFC_CIPHER_NAME); + + final KeyManagerFactory kmf = OpenSslX509KeyManagerFactory.newKeyless(CERT.cert()); + + return SslContextBuilder.forServer(kmf) + .sslProvider(SslProvider.OPENSSL) + .ciphers(ciphers) + // As this is not a TLSv1.3 cipher we should ensure we talk something else. + .protocols(SslProtocols.TLS_v1_2) + .option(OpenSslContextOption.ASYNC_PRIVATE_KEY_METHOD, method) + .build(); + } + + private static void assertThread(boolean delegate) { + if (delegate && OpenSslContext.USE_TASKS) { + assertEquals(DelegateThread.class, Thread.currentThread().getClass()); + } + } + + @ParameterizedTest(name = "{index}: delegate = {0}, async = {1}, newThread={2}") + @MethodSource("parameters") + public void testPrivateKeyMethod(final boolean delegate, boolean async, boolean newThread) throws Exception { + final AtomicBoolean signCalled = new AtomicBoolean(); + OpenSslPrivateKeyMethod keyMethod = new OpenSslPrivateKeyMethod() { + @Override + public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception { + signCalled.set(true); + assertThread(delegate); + + assertEquals(CERT.cert().getPublicKey(), + engine.getSession().getLocalCertificates()[0].getPublicKey()); + + // Delegate signing to Java implementation. 
+ final Signature signature; + // Depending on the Java version it will pick one or the other. + if (signatureAlgorithm == OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PKCS1_SHA256) { + signature = Signature.getInstance("SHA256withRSA"); + } else if (signatureAlgorithm == OpenSslPrivateKeyMethod.SSL_SIGN_RSA_PSS_RSAE_SHA256) { + signature = Signature.getInstance("RSASSA-PSS"); + signature.setParameter(new PSSParameterSpec("SHA-256", "MGF1", MGF1ParameterSpec.SHA256, + 32, 1)); + } else { + throw new AssertionError("Unexpected signature algorithm " + signatureAlgorithm); + } + signature.initSign(CERT.key()); + signature.update(input); + return signature.sign(); + } + + @Override + public byte[] decrypt(SSLEngine engine, byte[] input) { + throw new UnsupportedOperationException(); + } + }; + + final SslContext sslServerContext = async ? buildServerContext( + new OpenSslPrivateKeyMethodAdapter(keyMethod, newThread)) : buildServerContext(keyMethod); + + final SslContext sslClientContext = buildClientContext(); + try { + try { + final Promise serverPromise = GROUP.next().newPromise(); + final Promise clientPromise = GROUP.next().newPromise(); + + ChannelHandler serverHandler = new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(newSslHandler(sslServerContext, ch.alloc(), delegateExecutor(delegate))); + + pipeline.addLast(new SimpleChannelInboundHandler() { + @Override + public void channelInactive(ChannelHandlerContext ctx) { + serverPromise.cancel(true); + ctx.fireChannelInactive(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) { + if (serverPromise.trySuccess(null)) { + ctx.writeAndFlush(Unpooled.wrappedBuffer(new byte[] {'P', 'O', 'N', 'G'})); + } + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!serverPromise.tryFailure(cause)) { + ctx.fireExceptionCaught(cause); + } + } + }); + } 
+ }; + + LocalAddress address = new LocalAddress("test-" + SslProvider.OPENSSL + + '-' + SslProvider.JDK + '-' + RFC_CIPHER_NAME + '-' + delegate); + + Channel server = server(address, serverHandler); + try { + ChannelHandler clientHandler = new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(newSslHandler(sslClientContext, ch.alloc(), delegateExecutor(delegate))); + + pipeline.addLast(new SimpleChannelInboundHandler() { + @Override + public void channelInactive(ChannelHandlerContext ctx) { + clientPromise.cancel(true); + ctx.fireChannelInactive(); + } + + @Override + public void channelRead0(ChannelHandlerContext ctx, Object msg) { + clientPromise.trySuccess(null); + ctx.close(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (!clientPromise.tryFailure(cause)) { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }; + + Channel client = client(server, clientHandler); + try { + client.writeAndFlush(Unpooled.wrappedBuffer(new byte[] {'P', 'I', 'N', 'G'})) + .syncUninterruptibly(); + + assertTrue(clientPromise.await(5L, TimeUnit.SECONDS), "client timeout"); + assertTrue(serverPromise.await(5L, TimeUnit.SECONDS), "server timeout"); + + clientPromise.sync(); + serverPromise.sync(); + assertTrue(signCalled.get()); + } finally { + client.close().sync(); + } + } finally { + server.close().sync(); + } + } finally { + ReferenceCountUtil.release(sslClientContext); + } + } finally { + ReferenceCountUtil.release(sslServerContext); + } + } + + @ParameterizedTest(name = "{index}: delegate = {0}") + @MethodSource("parameters") + public void testPrivateKeyMethodFailsBecauseOfException(final boolean delegate) throws Exception { + testPrivateKeyMethodFails(delegate, false); + } + + @ParameterizedTest(name = "{index}: delegate = {0}") + @MethodSource("parameters") + public void testPrivateKeyMethodFailsBecauseOfNull(final boolean delegate) 
throws Exception { + testPrivateKeyMethodFails(delegate, true); + } + + private void testPrivateKeyMethodFails(final boolean delegate, final boolean returnNull) throws Exception { + final SslContext sslServerContext = buildServerContext(new OpenSslPrivateKeyMethod() { + @Override + public byte[] sign(SSLEngine engine, int signatureAlgorithm, byte[] input) throws Exception { + assertThread(delegate); + if (returnNull) { + return null; + } + throw new SignatureException(); + } + + @Override + public byte[] decrypt(SSLEngine engine, byte[] input) { + throw new UnsupportedOperationException(); + } + }); + final SslContext sslClientContext = buildClientContext(); + + SslHandler serverSslHandler = newSslHandler( + sslServerContext, UnpooledByteBufAllocator.DEFAULT, delegateExecutor(delegate)); + SslHandler clientSslHandler = newSslHandler( + sslClientContext, UnpooledByteBufAllocator.DEFAULT, delegateExecutor(delegate)); + + try { + try { + LocalAddress address = new LocalAddress("test-" + SslProvider.OPENSSL + + '-' + SslProvider.JDK + '-' + RFC_CIPHER_NAME + '-' + delegate); + + Channel server = server(address, serverSslHandler); + try { + Channel client = client(server, clientSslHandler); + try { + Throwable clientCause = clientSslHandler.handshakeFuture().await().cause(); + Throwable serverCause = serverSslHandler.handshakeFuture().await().cause(); + assertNotNull(clientCause); + assertThat(serverCause, Matchers.instanceOf(SSLHandshakeException.class)); + } finally { + client.close().sync(); + } + } finally { + server.close().sync(); + } + } finally { + ReferenceCountUtil.release(sslClientContext); + } + } finally { + ReferenceCountUtil.release(sslServerContext); + } + } + + private static Channel server(LocalAddress address, ChannelHandler handler) throws Exception { + ServerBootstrap bootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(GROUP) + .childHandler(handler); + + return bootstrap.bind(address).sync().channel(); + } + + private 
static Channel client(Channel server, ChannelHandler handler) throws Exception { + SocketAddress remoteAddress = server.localAddress(); + + Bootstrap bootstrap = new Bootstrap() + .channel(LocalChannel.class) + .group(GROUP) + .handler(handler); + + return bootstrap.connect(remoteAddress).sync().channel(); + } + + private static final class DelegateThread extends Thread { + DelegateThread(Runnable target) { + super(target); + } + } + + private static final class OpenSslPrivateKeyMethodAdapter implements OpenSslAsyncPrivateKeyMethod { + private final OpenSslPrivateKeyMethod keyMethod; + private final boolean newThread; + + OpenSslPrivateKeyMethodAdapter(OpenSslPrivateKeyMethod keyMethod, boolean newThread) { + this.keyMethod = keyMethod; + this.newThread = newThread; + } + + @Override + public Future sign(final SSLEngine engine, final int signatureAlgorithm, final byte[] input) { + final Promise promise = ImmediateEventExecutor.INSTANCE.newPromise(); + try { + if (newThread) { + // Let's run these in an extra thread to ensure that this would also work if the promise is + // notified later. + new DelegateThread(new Runnable() { + @Override + public void run() { + try { + // Let's sleep for some time to ensure we would notify in an async fashion + Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500)); + promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }).start(); + } else { + promise.setSuccess(keyMethod.sign(engine, signatureAlgorithm, input)); + } + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + + @Override + public Future decrypt(final SSLEngine engine, final byte[] input) { + final Promise promise = ImmediateEventExecutor.INSTANCE.newPromise(); + try { + if (newThread) { + // Let's run these in an extra thread to ensure that this would also work if the promise is + // notified later. 
+ new DelegateThread(new Runnable() { + @Override + public void run() { + try { + // Let's sleep for some time to ensure we would notify in an async fashion + Thread.sleep(ThreadLocalRandom.current().nextLong(100, 500)); + promise.setSuccess(keyMethod.decrypt(engine, input)); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }).start(); + } else { + promise.setSuccess(keyMethod.decrypt(engine, input)); + } + } catch (Throwable cause) { + promise.setFailure(cause); + } + return promise; + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslRenegotiateTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslRenegotiateTest.java new file mode 100644 index 0000000..7f5dea8 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslRenegotiateTest.java @@ -0,0 +1,45 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import org.junit.jupiter.api.BeforeAll;

import java.util.concurrent.atomic.AtomicReference;

import javax.net.ssl.SSLException;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;

/**
 * Runs the renegotiation tests from the {@link RenegotiateTest} base class with the
 * server side backed by the native OpenSSL provider.
 */
public class OpenSslRenegotiateTest extends RenegotiateTest {

    @BeforeAll
    public static void checkOpenSsl() {
        // Fail fast (skip the whole class) when the native OpenSSL bindings are unavailable.
        OpenSsl.ensureAvailability();
    }

    @Override
    protected SslProvider serverSslProvider() {
        return SslProvider.OPENSSL;
    }

    /**
     * Fix: this method overrides the base-class verification hook but was missing
     * {@code @Override}, so a signature drift in {@code RenegotiateTest} would silently
     * turn it into an unused overload instead of a compile error.
     */
    @Override
    protected void verifyResult(AtomicReference<Throwable> error) throws Throwable {
        Throwable cause = error.get();
        // Renegotiation is not supported by the OpenSslEngine, so an SSLException is expected
        // rather than a successful second handshake.
        assertThat(cause, is(instanceOf(SSLException.class)));
    }
}
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import org.junit.jupiter.api.BeforeAll;

import javax.net.ssl.SSLException;
import java.io.File;

/**
 * Exercises the shared {@link SslContextTest} suite against the OpenSSL-backed
 * server context implementation.
 */
public class OpenSslServerContextTest extends SslContextTest {

    @BeforeAll
    public static void checkOpenSsl() {
        // Skip the whole class unless the native OpenSSL bindings can be loaded.
        OpenSsl.ensureAvailability();
    }

    @Override
    protected SslContext newSslContext(File crtFile, File keyFile, String pass) throws SSLException {
        SslContext context = new OpenSslServerContext(crtFile, keyFile, pass);
        return context;
    }
}
+ */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OpenSslTest { + + @Test + public void testDefaultCiphers() { + if (!OpenSsl.isTlsv13Supported()) { + assertTrue( + OpenSsl.DEFAULT_CIPHERS.size() <= SslUtils.DEFAULT_CIPHER_SUITES.length); + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslTestUtils.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslTestUtils.java new file mode 100644 index 0000000..a5c9ac7 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslTestUtils.java @@ -0,0 +1,27 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +final class OpenSslTestUtils { + private OpenSslTestUtils() { + } + + static void checkShouldUseKeyManagerFactory() { + assumeTrue(OpenSsl.supportsKeyManagerFactory() && OpenSsl.useKeyManagerFactory()); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java new file mode 100644 index 0000000..c111ab8 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/OpenSslX509KeyManagerFactoryProviderTest.java @@ -0,0 +1,38 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2018 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import javax.net.ssl.KeyManagerFactory;
import java.io.InputStream;
import java.security.KeyStore;

/**
 * Runs the caching key-material-provider tests with a provider obtained from
 * {@link OpenSslX509KeyManagerFactory}.
 */
public class OpenSslX509KeyManagerFactoryProviderTest extends OpenSslCachingKeyMaterialProviderTest {

    @Override
    protected KeyManagerFactory newKeyManagerFactory() throws Exception {
        char[] password = PASSWORD.toCharArray();
        final KeyStore keyStore = KeyStore.getInstance("PKCS12");
        // Fix: KeyStore.load does not close the stream it reads from, so the
        // InputStream returned by getResourceAsStream leaked. Close it with
        // try-with-resources.
        try (InputStream in = getClass().getResourceAsStream("mutual_auth_server.p12")) {
            keyStore.load(in, password);
        }

        OpenSslX509KeyManagerFactory kmf = new OpenSslX509KeyManagerFactory();
        kmf.init(keyStore, password);
        return kmf;
    }

    @Override
    protected OpenSslKeyMaterialProvider newMaterialProvider(KeyManagerFactory kmf, String password) {
        return ((OpenSslX509KeyManagerFactory) kmf).newProvider();
    }
}
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelPipeline;

import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.Mock;
import org.mockito.Mockito;
import org.mockito.MockitoAnnotations;

import java.nio.charset.StandardCharsets;

import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyZeroInteractions;
import static org.mockito.Mockito.when;

/**
 * Unit tests for {@link OptionalSslHandler}: depending on the first bytes received it
 * either removes itself (plaintext), replaces itself with an SSL handler (TLS record),
 * or buffers until enough bytes are available.
 */
public class OptionalSslHandlerTest {

    private static final String SSL_HANDLER_NAME = "sslhandler";
    private static final String HANDLER_NAME = "handler";

    @Mock
    private ChannelHandlerContext context;

    @Mock
    private SslContext sslContext;

    @Mock
    private ChannelPipeline pipeline;

    @BeforeEach
    public void setUp() throws Exception {
        // NOTE(review): initMocks is deprecated since Mockito 3.4 in favour of
        // MockitoAnnotations.openMocks(this) — switch once the Mockito version is confirmed.
        MockitoAnnotations.initMocks(this);
        when(context.pipeline()).thenReturn(pipeline);
    }

    @Test
    public void handlerRemoved() throws Exception {
        OptionalSslHandler handler = new OptionalSslHandler(sslContext);
        // Fix: String.getBytes() without a charset used the platform default encoding;
        // encode explicitly so the test is deterministic across environments.
        final ByteBuf payload = Unpooled.copiedBuffer("plaintext".getBytes(StandardCharsets.US_ASCII));
        try {
            // Non-TLS bytes: the handler should remove itself from the pipeline.
            handler.decode(context, payload, null);
            verify(pipeline).remove(handler);
        } finally {
            payload.release();
        }
    }

    @Test
    public void handlerReplaced() throws Exception {
        final ChannelHandler nonSslHandler = Mockito.mock(ChannelHandler.class);
        OptionalSslHandler handler = new OptionalSslHandler(sslContext) {
            @Override
            protected ChannelHandler newNonSslHandler(ChannelHandlerContext context) {
                return nonSslHandler;
            }

            @Override
            protected String newNonSslHandlerName() {
                return HANDLER_NAME;
            }
        };
        // Fix: explicit charset instead of the platform default (see handlerRemoved).
        final ByteBuf payload = Unpooled.copiedBuffer("plaintext".getBytes(StandardCharsets.US_ASCII));
        try {
            // Non-TLS bytes with a custom non-SSL handler: the handler replaces itself.
            handler.decode(context, payload, null);
            verify(pipeline).replace(handler, HANDLER_NAME, nonSslHandler);
        } finally {
            payload.release();
        }
    }

    @Test
    public void sslHandlerReplaced() throws Exception {
        final SslHandler sslHandler = Mockito.mock(SslHandler.class);
        OptionalSslHandler handler = new OptionalSslHandler(sslContext) {
            @Override
            protected SslHandler newSslHandler(ChannelHandlerContext context, SslContext sslContext) {
                return sslHandler;
            }

            @Override
            protected String newSslHandlerName() {
                return SSL_HANDLER_NAME;
            }
        };
        // 22 = TLS handshake record type, followed by version 3.1 and a length field.
        final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3, 1, 0, 5 });
        try {
            handler.decode(context, payload, null);
            verify(pipeline).replace(handler, SSL_HANDLER_NAME, sslHandler);
        } finally {
            payload.release();
        }
    }

    @Test
    public void decodeBuffered() throws Exception {
        OptionalSslHandler handler = new OptionalSslHandler(sslContext);
        // Too few bytes to classify the protocol: the handler must buffer and
        // leave the pipeline untouched.
        final ByteBuf payload = Unpooled.wrappedBuffer(new byte[] { 22, 3 });
        try {
            handler.decode(context, payload, null);
            // NOTE(review): verifyZeroInteractions is deprecated in Mockito 3.x;
            // prefer verifyNoInteractions once the Mockito version is confirmed.
            verifyZeroInteractions(pipeline);
        } finally {
            payload.release();
        }
    }
}
+ */ +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.CompositeByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.ServerChannel; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.handler.ssl.util.SimpleTrustManagerFactory; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.Promise; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ResourcesUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; +import java.io.ByteArrayOutputStream; +import 
java.io.IOException; +import java.io.PrintStream; +import java.io.UnsupportedEncodingException; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.security.KeyStore; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +import static io.netty.buffer.ByteBufUtil.writeAscii; +import static io.netty.util.internal.ThreadLocalRandom.current; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ParameterizedSslHandlerTest { + + private static final String PARAMETERIZED_NAME = "{index}: clientProvider={0}, {index}: serverProvider={1}"; + + static Collection data() { + List providers = new ArrayList(3); + if (OpenSsl.isAvailable()) { + providers.add(SslProvider.OPENSSL); + providers.add(SslProvider.OPENSSL_REFCNT); + } + providers.add(SslProvider.JDK); + + List params = new ArrayList(); + + for (SslProvider cp: providers) { + for (SslProvider sp: providers) { + params.add(new Object[] { cp, sp }); + } + } + return params; + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 48000, unit = TimeUnit.MILLISECONDS) + public void testCompositeBufSizeEstimationGuaranteesSynchronousWrite( + SslProvider clientProvider, SslProvider serverProvider) + throws CertificateException, SSLException, ExecutionException, InterruptedException { + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + true, true, true); + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + true, true, false); + 
compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + true, false, true); + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + true, false, false); + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + false, true, true); + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + false, true, false); + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + false, false, true); + compositeBufSizeEstimationGuaranteesSynchronousWrite(serverProvider, clientProvider, + false, false, false); + } + + private static void compositeBufSizeEstimationGuaranteesSynchronousWrite( + SslProvider serverProvider, SslProvider clientProvider, + final boolean serverDisableWrapSize, + final boolean letHandlerCreateServerEngine, final boolean letHandlerCreateClientEngine) + throws CertificateException, SSLException, ExecutionException, InterruptedException { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise donePromise = group.next().newPromise(); + // The goal is to provide the SSLEngine with many ByteBuf components to ensure that the overhead for wrap + // is correctly accounted for on each component. + final int numComponents = 150; + // This is the TLS packet size. The goal is to divide the maximum amount of application data that can fit + // into a single TLS packet into many components to ensure the overhead is correctly taken into account. 
+ final int desiredBytes = 16384; + final int singleComponentSize = desiredBytes / numComponents; + final int expectedBytes = numComponents * singleComponentSize; + + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + final SslHandler handler = letHandlerCreateServerEngine + ? sslServerCtx.newHandler(ch.alloc()) + : new SslHandler(sslServerCtx.newEngine(ch.alloc())); + if (serverDisableWrapSize) { + handler.setWrapDataSize(-1); + } + ch.pipeline().addLast(handler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + private boolean sentData; + private Throwable writeCause; + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; + if (sslEvt.isSuccess()) { + CompositeByteBuf content = ctx.alloc().compositeDirectBuffer(numComponents); + for (int i = 0; i < numComponents; ++i) { + ByteBuf buf = ctx.alloc().directBuffer(singleComponentSize); + buf.writerIndex(buf.writerIndex() + singleComponentSize); + content.addComponent(true, buf); + } + ctx.writeAndFlush(content).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + writeCause = future.cause(); + if (writeCause == null) { + sentData = true; + } + } + }); + } else { + donePromise.tryFailure(sslEvt.cause()); + } + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + donePromise.tryFailure(new IllegalStateException("server exception sentData: " + + sentData + " writeCause: " + writeCause, cause)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + donePromise.tryFailure(new IllegalStateException("server closed sentData: " + + 
sentData + " writeCause: " + writeCause)); + } + }); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + if (letHandlerCreateClientEngine) { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + } else { + ch.pipeline().addLast(new SslHandler(sslClientCtx.newEngine(ch.alloc()))); + } + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + private int bytesSeen; + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + if (msg instanceof ByteBuf) { + bytesSeen += ((ByteBuf) msg).readableBytes(); + if (bytesSeen == expectedBytes) { + donePromise.trySuccess(null); + } + } + ReferenceCountUtil.release(msg); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; + if (!sslEvt.isSuccess()) { + donePromise.tryFailure(sslEvt.cause()); + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + donePromise.tryFailure(new IllegalStateException("client exception. bytesSeen: " + + bytesSeen, cause)); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + donePromise.tryFailure(new IllegalStateException("client closed. 
bytesSeen: " + + bytesSeen)); + } + }); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + donePromise.get(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + ssc.delete(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testAlertProducedAndSend(SslProvider clientProvider, SslProvider serverProvider) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .trustManager(new SimpleTrustManagerFactory() { + @Override + protected void engineInit(KeyStore keyStore) { } + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { new X509TrustManager() { + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + // Fail verification which should produce an alert that is send back to the client. 
+ throw new CertificateException(); + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, String s) { + // NOOP + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } }; + } + }).clientAuth(ClientAuth.REQUIRE).build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .keyManager(ResourcesUtil.getFile(getClass(), "test.crt"), + ResourcesUtil.getFile(getClass(), "test_unencrypted.pem")) + .sslProvider(clientProvider).build(); + + NioEventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise promise = group.next().newPromise(); + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + // Just trigger a close + ctx.close(); + } + }); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause.getCause() instanceof SSLException) { + // We received the alert and so produce an SSLException. 
+ promise.trySuccess(null); + } + } + }); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + promise.syncUninterruptibly(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testCloseNotify(SslProvider clientProvider, SslProvider serverProvider) throws Exception { + testCloseNotify(clientProvider, serverProvider, 5000, false); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testCloseNotifyReceivedTimeout(SslProvider clientProvider, SslProvider serverProvider) + throws Exception { + testCloseNotify(clientProvider, serverProvider, 100, true); + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testCloseNotifyNotWaitForResponse(SslProvider clientProvider, SslProvider serverProvider) + throws Exception { + testCloseNotify(clientProvider, serverProvider, 0, false); + } + + private void testCloseNotify(SslProvider clientProvider, SslProvider serverProvider, + final long closeNotifyReadTimeout, final boolean timeout) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + // Use TLSv1.2 as we depend on the fact that the handshake + // is done in an extra round trip in the test which + // is not true in TLSv1.3 + .protocols(SslProtocols.TLS_v1_2) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + 
.trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider) + // Use TLSv1.2 as we depend on the fact that the handshake + // is done in an extra round trip in the test which + // is not true in TLSv1.3 + .protocols(SslProtocols.TLS_v1_2) + .build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise clientPromise = group.next().newPromise(); + final Promise serverPromise = group.next().newPromise(); + + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + SslHandler handler = sslServerCtx.newHandler(ch.alloc()); + handler.setCloseNotifyReadTimeoutMillis(closeNotifyReadTimeout); + PromiseNotifier.cascade(handler.sslCloseFuture(), serverPromise); + handler.handshakeFuture().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if (!future.isSuccess()) { + // Something bad happened during handshake fail the promise! + serverPromise.tryFailure(future.cause()); + } + } + }); + ch.pipeline().addLast(handler); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + final AtomicBoolean closeSent = new AtomicBoolean(); + if (timeout) { + ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (closeSent.get()) { + // Drop data on the floor so we will get a timeout while waiting for the + // close_notify. 
+ ReferenceCountUtil.release(msg); + } else { + super.channelRead(ctx, msg); + } + } + }); + } + + SslHandler handler = sslClientCtx.newHandler(ch.alloc()); + handler.setCloseNotifyReadTimeoutMillis(closeNotifyReadTimeout); + PromiseNotifier.cascade(handler.sslCloseFuture(), clientPromise); + handler.handshakeFuture().addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + if (future.isSuccess()) { + closeSent.compareAndSet(false, true); + future.getNow().close(); + } else { + // Something bad happened during handshake fail the promise! + clientPromise.tryFailure(future.cause()); + } + } + }); + ch.pipeline().addLast(handler); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + serverPromise.awaitUninterruptibly(); + clientPromise.awaitUninterruptibly(); + + // Server always received the close_notify as the client triggers the close sequence. + assertTrue(serverPromise.isSuccess()); + + // Depending on if we wait for the response or not the promise will be failed or not. 
+ if (closeNotifyReadTimeout > 0 && !timeout) { + assertTrue(clientPromise.isSuccess()); + } else { + assertFalse(clientPromise.isSuccess()); + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void reentryOnHandshakeCompleteNioChannel(SslProvider clientProvider, SslProvider serverProvider) + throws Exception { + EventLoopGroup group = new NioEventLoopGroup(); + try { + Class serverClass = NioServerSocketChannel.class; + Class clientClass = NioSocketChannel.class; + SocketAddress bindAddress = new InetSocketAddress(0); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, false, false); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, false, true); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, true, false); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, true, true); + } finally { + group.shutdownGracefully(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void reentryOnHandshakeCompleteLocalChannel(SslProvider clientProvider, SslProvider serverProvider) + throws Exception { + EventLoopGroup group = new DefaultEventLoopGroup(); + try { + Class serverClass = LocalServerChannel.class; + Class clientClass = LocalChannel.class; + SocketAddress bindAddress = new LocalAddress(String.valueOf(current().nextLong())); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, 
bindAddress, + serverClass, clientClass, false, false); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, false, true); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, true, false); + reentryOnHandshakeComplete(clientProvider, serverProvider, group, bindAddress, + serverClass, clientClass, true, true); + } finally { + group.shutdownGracefully(); + } + } + + private void reentryOnHandshakeComplete(SslProvider clientProvider, SslProvider serverProvider, + EventLoopGroup group, SocketAddress bindAddress, + Class serverClass, + Class clientClass, boolean serverAutoRead, + boolean clientAutoRead) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(clientProvider) + .build(); + + Channel sc = null; + Channel cc = null; + try { + final String expectedContent = "HelloWorld"; + final CountDownLatch serverLatch = new CountDownLatch(1); + final CountDownLatch clientLatch = new CountDownLatch(1); + final StringBuilder serverQueue = new StringBuilder(expectedContent.length()); + final StringBuilder clientQueue = new StringBuilder(expectedContent.length()); + + sc = new ServerBootstrap() + .group(group) + .channel(serverClass) + .childOption(ChannelOption.AUTO_READ, serverAutoRead) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(disableHandshakeTimeout(sslServerCtx.newHandler(ch.alloc()))); + ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, serverQueue, + serverLatch)); + } + }).bind(bindAddress).syncUninterruptibly().channel(); + + cc = new Bootstrap() + 
.group(group) + .channel(clientClass) + .option(ChannelOption.AUTO_READ, clientAutoRead) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(disableHandshakeTimeout(sslClientCtx.newHandler(ch.alloc()))); + ch.pipeline().addLast(new ReentryWriteSslHandshakeHandler(expectedContent, clientQueue, + clientLatch)); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + serverLatch.await(); + assertEquals(expectedContent, serverQueue.toString()); + clientLatch.await(); + assertEquals(expectedContent, clientQueue.toString()); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + private static SslHandler disableHandshakeTimeout(SslHandler handler) { + handler.setHandshakeTimeoutMillis(0); + return handler; + } + + private static final class ReentryWriteSslHandshakeHandler extends SimpleChannelInboundHandler { + private static final InternalLogger LOGGER = + InternalLoggerFactory.getInstance(ReentryWriteSslHandshakeHandler.class); + private final String toWrite; + private final StringBuilder readQueue; + private final CountDownLatch doneLatch; + + ReentryWriteSslHandshakeHandler(String toWrite, StringBuilder readQueue, CountDownLatch doneLatch) { + this.toWrite = toWrite; + this.readQueue = readQueue; + this.doneLatch = doneLatch; + } + + @Override + public void channelActive(ChannelHandlerContext ctx) { + // Write toWrite in two chunks, first here then we get SslHandshakeCompletionEvent (which is re-entry). 
+ ctx.writeAndFlush(writeAscii(ctx.alloc(), toWrite.substring(0, toWrite.length() / 2))); + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) { + readQueue.append(msg.toString(CharsetUtil.US_ASCII)); + if (readQueue.length() >= toWrite.length()) { + doneLatch.countDown(); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + SslHandshakeCompletionEvent sslEvt = (SslHandshakeCompletionEvent) evt; + if (sslEvt.isSuccess()) { + // this is the re-entry write, it should be ordered after the subsequent write. + ctx.writeAndFlush(writeAscii(ctx.alloc(), toWrite.substring(toWrite.length() / 2))); + } else { + appendError(sslEvt.cause()); + } + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + appendError(cause); + ctx.fireExceptionCaught(cause); + } + + private void appendError(Throwable cause) { + LOGGER.error(new Exception("Caught possible write failure in ParameterizedSslHandlerTest.", cause)); + readQueue.append("failed to write '").append(toWrite).append("': "); + + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + cause.printStackTrace(new PrintStream(out)); + readQueue.append(out.toString(CharsetUtil.US_ASCII.name())); + } catch (UnsupportedEncodingException ignore) { + // Let's just fallback to using toString(). 
+ readQueue.append(cause); + } finally { + doneLatch.countDown(); + try { + out.close(); + } catch (IOException ignore) { + // ignore + } + } + } + } +} + diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/PemEncodedTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/PemEncodedTest.java new file mode 100644 index 0000000..94e7c0f --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/PemEncodedTest.java @@ -0,0 +1,123 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.handler.ssl;

import java.io.ByteArrayOutputStream;
import java.io.File;
import java.io.FileInputStream;
import java.security.PrivateKey;

import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assumptions.assumeFalse;

/**
 * Verifies that PEM encoded key and certificate material can be handed directly to an
 * OpenSSL based {@link SslContext} and that the reference counts of the PEM objects are
 * neither leaked nor over-released in the process.
 */
public class PemEncodedTest {

    @Test
    public void testPemEncodedOpenSsl() throws Exception {
        testPemEncoded(SslProvider.OPENSSL);
    }

    @Test
    public void testPemEncodedOpenSslRef() throws Exception {
        testPemEncoded(SslProvider.OPENSSL_REFCNT);
    }

    private static void testPemEncoded(SslProvider provider) throws Exception {
        OpenSsl.ensureAvailability();
        assumeFalse(OpenSsl.useKeyManagerFactory());

        // Turn a throw-away self-signed certificate into PEM value objects, then remove the files.
        final PemPrivateKey pemKey;
        final PemX509Certificate pemCert;
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        try {
            pemKey = PemPrivateKey.valueOf(toByteArray(ssc.privateKey()));
            pemCert = PemX509Certificate.valueOf(toByteArray(ssc.certificate()));
        } finally {
            ssc.delete();
        }

        SslContext context = SslContextBuilder.forServer(pemKey, pemCert)
                .sslProvider(provider)
                .build();
        // Building the context must not consume the caller's references.
        assertEquals(1, pemKey.refCnt());
        assertEquals(1, pemCert.refCnt());
        try {
            assertTrue(context instanceof ReferenceCountedOpenSslContext);
        } finally {
            ReferenceCountUtil.release(context);
            assertRelease(pemKey);
            assertRelease(pemCert);
        }
    }

    @Test
    public void testEncodedReturnsNull() throws Exception {
        // A PrivateKey whose getEncoded() yields null cannot be converted to PEM.
        final PrivateKey keyWithoutEncoding = new PrivateKey() {
            @Override
            public String getAlgorithm() {
                return null;
            }

            @Override
            public String getFormat() {
                return null;
            }

            @Override
            public byte[] getEncoded() {
                return null;
            }
        };
        assertThrows(IllegalArgumentException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                PemPrivateKey.toPEM(UnpooledByteBufAllocator.DEFAULT, true, keyWithoutEncoding);
            }
        });
    }

    // The release() here must drop the count to zero, i.e. return true.
    private static void assertRelease(PemEncoded encoded) {
        assertTrue(encoded.release());
    }

    // Reads the whole file into memory; kept Java-6 compatible like the rest of the file.
    private static byte[] toByteArray(File file) throws Exception {
        FileInputStream in = new FileInputStream(file);
        try {
            ByteArrayOutputStream content = new ByteArrayOutputStream();
            byte[] chunk = new byte[1024];
            for (;;) {
                int numRead = in.read(chunk);
                if (numRead == -1) {
                    break;
                }
                content.write(chunk, 0, numRead);
            }
            return content.toByteArray();
        } finally {
            in.close();
        }
    }
}
package io.netty.handler.ssl;

import io.netty.buffer.ByteBuf;
import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.Test;

import java.io.ByteArrayInputStream;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Tests for {@link PemReader}: extracting multiple concatenated PEM certificate blocks and a
 * PEM private key block from an {@link java.io.InputStream}.
 */
class PemReaderTest {
    @Test
    public void mustBeAbleToReadMultipleCertificates() throws Exception {
        // Two complete PEM certificate blocks concatenated in one stream.
        byte[] certs = ("-----BEGIN CERTIFICATE-----\n"
                + "MIICqjCCAZKgAwIBAgIIEaz8uuDHTcIwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJbG9jYWxo\n"
                + "b3N0MCAXDTIxMDYxNjE3MjYyOFoYDzk5OTkxMjMxMjM1OTU5WjAUMRIwEAYDVQQDDAlsb2NhbGhv\n"
                + "c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCVjENomtpMqHkg1yJ/uYZgSWmf/0Gb\n"
                + "U4yMDf30muPvMYb3gO6peEnoXa2b0WDOjLbLrcltp1YdjTlLhRRTYgDo9TAvHoUdoMGlTnfQtQne\n"
                + "2o+/92bnlZTroRIjUT0lqSxQ6UNXcOi9tNqVD4tML3vk20fudwBur8Plx+3hOhM/v64GbV46k06+\n"
                + "AblrFwBt9u6V0uIVtvgraOd+NgL4yNf594uND30mbB7Q7xe/Y6DiPhI6cVI/CbLlXVwKLvC5OziS\n"
                + "JKZ7svP0K3DBRxk+dOD9pg4SdaAEQVtR734ZlDh1XJ+mZssuDDda3NGZAjpCU4rkeV/J3Tr5KKMD\n"
                + "g3NEOmifAgMBAAEwDQYJKoZIhvcNAQELBQADggEBABejZGeRNCyPdIqac6cyAf99JPp5OySEMWHU\n"
                + "QXVCHEbQ8Eh6hsmrXSEaZS2zy/5dixPb5Rin0xaX5fqZdfLIUc0Mw/zXV/RiOIjrtUKez15Ko/4K\n"
                + "ONyxELjUq+SaJx2C1UcKEMfQBeq7O6XO60CLQ2AtVozmvJU9pt4KYQv0Kr+br3iNFpRuR43IYHyx\n"
                + "HP7QsD3L3LEqIqW/QtYEnAAngZofUiq0XELh4GB0L8DbcSJIxfZmYagFl7c2go9OZPD14mlaTnMV\n"
                + "Pjd+OkwMif5T7v+r+KVSmDSMQwa+NfW+V6Xngg5/bN3kWHdw9qFQGANojl9wsRVN/B3pu3Cc2XFD\n"
                + "MmQ=\n"
                + "-----END CERTIFICATE-----\n"
                + "-----BEGIN CERTIFICATE-----\n"
                + "MIICqjCCAZKgAwIBAgIIIsUS6UkDau4wDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJbG9jYWxo\n"
                + "b3N0MCAXDTIxMDYxNjE3MjYyOFoYDzk5OTkxMjMxMjM1OTU5WjAUMRIwEAYDVQQDDAlsb2NhbGhv\n"
                + "c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCTZmahFZXB0Dv3N8t6gfXTeqhTxRng\n"
                + "mIBBPmrbZBODrZm06vrR5KNhxB2FhWIq1Yu8xXXv8sO+PaO2Sw/h6TeslRJ4EkrNd9zmYhT2cJvP\n"
                + "d1CtkX5EHyMZRUKj7Eg4eUO1k/+JnhMmaY+nUAG7fCtvs8pS9SEXbEqYW7S4AQ1oopbCAMqQekly\n"
                + "KCdnjGlVhXwL2Lj2rr/uw1Fc2+WvY/leQGo0rbIqoc7OSAktsP+MXI6iQ1RWJOec15V6iFRzcdE3\n"
                + "Q4ODSMZ/R8wm9DH+4hkeQNPMbcc1wlvVZpDZ/FZegr1XimcYcJr2AoAQf3Xe1yFKAtBMXCjCIGm8\n"
                + "veCQ+xeHAgMBAAEwDQYJKoZIhvcNAQELBQADggEBAGyV+dib2MdpenbntKk7PdZEx+/vNl9cEpwL\n"
                + "BfWmQN/j2RmHUxrUM+PVkLTgyCq8okdCKqCvraKwBkF6vlzp4u5CL4L323z+/uxAi6pzbcWnG1EH\n"
                + "JpSkf1OhTUFu6UhLfpg3XqeiIujYdVZTpHr7KHVLRYUSQPprt4HjLZeCIg4P2pZ0yQ3SEBhVed89\n"
                + "GMj/+O4jjvuZv5NQc57NpMIrE9fNINczLG1CPTgnhvqMP42W6ahBuexQUe4gP+jmB/BZmBYKoauU\n"
                + "mPBKruq3mNuoXtbHufv5I7CFVXNgJ0/aT+lvEkQ4IlCIcJyvTgyUTOQVbqDp+SswymAIRowaRdxa\n"
                + "7Ss=\n"
                + "-----END CERTIFICATE-----\n").getBytes(CharsetUtil.US_ASCII);
        ByteArrayInputStream in = new ByteArrayInputStream(certs);
        ByteBuf[] bufs = PemReader.readCertificates(in);
        in.close();
        // One buffer per PEM block in the input.
        assertThat(bufs.length).isEqualTo(2);
        for (ByteBuf buf : bufs) {
            buf.release();
        }
    }

    @Test
    public void mustBeAbleToReadPrivateKey() throws Exception {
        // NOTE(review): the base64 payload here is the same as the certificate above, merely
        // wrapped in PRIVATE KEY markers — this suggests PemReader extracts the block contents
        // without validating the key structure; confirm against PemReader's implementation.
        byte[] key = ("-----BEGIN PRIVATE KEY-----\n"
                + "MIICqjCCAZKgAwIBAgIIEaz8uuDHTcIwDQYJKoZIhvcNAQELBQAwFDESMBAGA1UEAwwJbG9jYWxo\n"
                + "b3N0MCAXDTIxMDYxNjE3MjYyOFoYDzk5OTkxMjMxMjM1OTU5WjAUMRIwEAYDVQQDDAlsb2NhbGhv\n"
                + "c3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQCVjENomtpMqHkg1yJ/uYZgSWmf/0Gb\n"
                + "U4yMDf30muPvMYb3gO6peEnoXa2b0WDOjLbLrcltp1YdjTlLhRRTYgDo9TAvHoUdoMGlTnfQtQne\n"
                + "2o+/92bnlZTroRIjUT0lqSxQ6UNXcOi9tNqVD4tML3vk20fudwBur8Plx+3hOhM/v64GbV46k06+\n"
                + "AblrFwBt9u6V0uIVtvgraOd+NgL4yNf594uND30mbB7Q7xe/Y6DiPhI6cVI/CbLlXVwKLvC5OziS\n"
                + "JKZ7svP0K3DBRxk+dOD9pg4SdaAEQVtR734ZlDh1XJ+mZssuDDda3NGZAjpCU4rkeV/J3Tr5KKMD\n"
                + "g3NEOmifAgMBAAEwDQYJKoZIhvcNAQELBQADggEBABejZGeRNCyPdIqac6cyAf99JPp5OySEMWHU\n"
                + "QXVCHEbQ8Eh6hsmrXSEaZS2zy/5dixPb5Rin0xaX5fqZdfLIUc0Mw/zXV/RiOIjrtUKez15Ko/4K\n"
                + "ONyxELjUq+SaJx2C1UcKEMfQBeq7O6XO60CLQ2AtVozmvJU9pt4KYQv0Kr+br3iNFpRuR43IYHyx\n"
                + "HP7QsD3L3LEqIqW/QtYEnAAngZofUiq0XELh4GB0L8DbcSJIxfZmYagFl7c2go9OZPD14mlaTnMV\n"
                + "Pjd+OkwMif5T7v+r+KVSmDSMQwa+NfW+V6Xngg5/bN3kWHdw9qFQGANojl9wsRVN/B3pu3Cc2XFD\n"
                + "MmQ=\n"
                + "-----END PRIVATE KEY-----\n").getBytes(CharsetUtil.US_ASCII);
        ByteArrayInputStream in = new ByteArrayInputStream(key);
        ByteBuf buf = PemReader.readPrivateKey(in);
        in.close();
        // Size of the decoded/repackaged key material for this fixed input.
        assertThat(buf.readableBytes()).isEqualTo(686);
        buf.release();
    }
}
package io.netty.handler.ssl;

import io.netty.util.CharsetUtil;
import org.bouncycastle.util.encoders.Hex;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
 * The test vectors here were provided via:
 * https://www.ietf.org/mail-archive/web/tls/current/msg03416.html
 */
public class PseudoRandomFunctionTest {

    @Test
    public void testPrfSha256() {
        // Known-answer test: fixed secret/seed/label must produce the published PRF output.
        byte[] secret = Hex.decode("9b be 43 6b a9 40 f0 17 b1 76 52 84 9a 71 db 35");
        byte[] seed = Hex.decode("a0 ba 9f 93 6c da 31 18 27 a6 f7 96 ff d5 19 8c");
        byte[] label = "test label".getBytes(CharsetUtil.US_ASCII);
        byte[] expected = Hex.decode(
                "e3 f2 29 ba 72 7b e1 7b" +
                "8d 12 26 20 55 7c d4 53" +
                "c2 aa b2 1d 07 c3 d4 95" +
                "32 9b 52 d4 e6 1e db 5a" +
                "6b 30 17 91 e9 0d 35 c9" +
                "c9 a4 6b 4e 14 ba f9 af" +
                "0f a0 22 f7 07 7d ef 17" +
                "ab fd 37 97 c0 56 4b ab" +
                "4f bc 91 66 6e 9d ef 9b" +
                "97 fc e3 4f 79 67 89 ba" +
                "a4 80 82 d1 22 ee 42 c5" +
                "a7 2e 5a 51 10 ff f7 01" +
                "87 34 7b 66");
        // NOTE(review): "HmacSha256" differs in case from the standard JCA name "HmacSHA256";
        // presumably the lookup inside PseudoRandomFunction is case-insensitive — confirm.
        byte[] actual = PseudoRandomFunction.hash(secret, label, seed, expected.length, "HmacSha256");
        assertArrayEquals(expected, actual);
    }
}
package io.netty.handler.ssl;

import io.netty.buffer.UnpooledByteBufAllocator;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.util.ReferenceCountUtil;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import javax.net.ssl.SSLEngine;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Runs the OpenSSL engine test suite against the explicitly reference-counted implementation
 * ({@link SslProvider#OPENSSL_REFCNT}), releasing every engine and context by hand.
 */
public class ReferenceCountedOpenSslEngineTest extends OpenSslEngineTest {

    @Override
    protected SslProvider sslClientProvider() {
        return SslProvider.OPENSSL_REFCNT;
    }

    @Override
    protected SslProvider sslServerProvider() {
        return SslProvider.OPENSSL_REFCNT;
    }

    @Override
    protected void cleanupClientSslContext(SslContext ctx) {
        ReferenceCountUtil.release(ctx);
    }

    @Override
    protected void cleanupClientSslEngine(SSLEngine engine) {
        // Unwrap first: the reference count lives on the ReferenceCountedOpenSslEngine.
        ReferenceCountUtil.release(unwrapEngine(engine));
    }

    @Override
    protected void cleanupServerSslContext(SslContext ctx) {
        ReferenceCountUtil.release(ctx);
    }

    @Override
    protected void cleanupServerSslEngine(SSLEngine engine) {
        ReferenceCountUtil.release(unwrapEngine(engine));
    }

    /**
     * newEngine(null) must throw without leaking native resources held by the context.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testNotLeakOnException(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());

        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                clientSslCtx.newEngine(null);
            }
        });
    }

    @SuppressWarnings("deprecation")
    @Override
    protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) {
        if (context instanceof ReferenceCountedOpenSslContext) {
            if (param instanceof OpenSslEngineTestParam) {
                ((ReferenceCountedOpenSslContext) context).setUseTasks(((OpenSslEngineTestParam) param).useTasks);
            }
            // Explicitly enable the session cache as it is disabled by default on the client side.
            ((ReferenceCountedOpenSslContext) context).sessionContext().setSessionCacheEnabled(true);
        }
        return context;
    }

    /**
     * Creating an engine must retain the parent context (count 2); releasing the context drops
     * it to 1; releasing the engine drops the context to 0.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void parentContextIsRetainedByChildEngines(SSLEngineTestParam param) throws Exception {
        SslContext clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());

        SSLEngine engine = clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT);
        // assertEquals(expected, actual) — the original had the arguments reversed, which makes
        // JUnit's failure message report the wrong "expected" value.
        assertEquals(2, ReferenceCountUtil.refCnt(clientSslCtx));

        cleanupClientSslContext(clientSslCtx);
        assertEquals(1, ReferenceCountUtil.refCnt(clientSslCtx));

        cleanupClientSslEngine(engine);
        assertEquals(0, ReferenceCountUtil.refCnt(clientSslCtx));
    }
}
package io.netty.handler.ssl;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicReference;

/**
 * Base class for tests that trigger a TLS renegotiation from the server side once the initial
 * handshake completed. Subclasses supply the server {@link SslProvider} and may override
 * {@link #verifyResult(AtomicReference)} when a failure is the expected outcome.
 */
public abstract class RenegotiateTest {

    @Test
    @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
    public void testRenegotiateServer() throws Throwable {
        // First error observed by either peer; the latch waits for both renegotiation outcomes.
        final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
        final CountDownLatch latch = new CountDownLatch(2);
        SelfSignedCertificate cert = new SelfSignedCertificate();
        // DefaultEventLoopGroup replaces the deprecated LocalEventLoopGroup (same behavior for
        // local channels; consistent with the other tests in this module).
        EventLoopGroup group = new DefaultEventLoopGroup();
        try {
            // TLSv1.2 only: TLSv1.3 removed renegotiation entirely.
            final SslContext context = SslContextBuilder.forServer(cert.key(), cert.cert())
                    .sslProvider(serverSslProvider())
                    .protocols(SslProtocols.TLS_v1_2)
                    .build();

            ServerBootstrap sb = new ServerBootstrap();
            sb.group(group).channel(LocalServerChannel.class)
                    .childHandler(new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            SslHandler handler = context.newHandler(ch.alloc());
                            handler.setHandshakeTimeoutMillis(0);
                            ch.pipeline().addLast(handler);
                            ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
                                private boolean renegotiate;

                                @Override
                                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                                    ReferenceCountUtil.release(msg);
                                }

                                @Override
                                public void userEventTriggered(
                                        final ChannelHandlerContext ctx, Object evt) throws Exception {
                                    if (!renegotiate && evt instanceof SslHandshakeCompletionEvent) {
                                        SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;

                                        if (event.isSuccess()) {
                                            final SslHandler handler = ctx.pipeline().get(SslHandler.class);

                                            // Guard so the completion of the renegotiation does
                                            // not trigger yet another renegotiation.
                                            renegotiate = true;
                                            handler.renegotiate().addListener(new FutureListener<Channel>() {
                                                @Override
                                                public void operationComplete(Future<Channel> future)
                                                        throws Exception {
                                                    if (!future.isSuccess()) {
                                                        error.compareAndSet(null, future.cause());
                                                        ctx.close();
                                                    }
                                                    latch.countDown();
                                                }
                                            });
                                        } else {
                                            error.compareAndSet(null, event.cause());
                                            latch.countDown();

                                            ctx.close();
                                        }
                                    }
                                }
                            });
                        }
                    });
            Channel channel = sb.bind(new LocalAddress("RenegotiateTest")).syncUninterruptibly().channel();

            final SslContext clientContext = SslContextBuilder.forClient()
                    .trustManager(InsecureTrustManagerFactory.INSTANCE)
                    .sslProvider(SslProvider.JDK)
                    .protocols(SslProtocols.TLS_v1_2)
                    .build();

            Bootstrap bootstrap = new Bootstrap();
            bootstrap.group(group).channel(LocalChannel.class)
                    .handler(new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            SslHandler handler = clientContext.newHandler(ch.alloc());
                            handler.setHandshakeTimeoutMillis(0);
                            ch.pipeline().addLast(handler);
                            ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
                                @Override
                                public void userEventTriggered(
                                        ChannelHandlerContext ctx, Object evt) throws Exception {
                                    if (evt instanceof SslHandshakeCompletionEvent) {
                                        SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
                                        if (!event.isSuccess()) {
                                            error.compareAndSet(null, event.cause());
                                            ctx.close();
                                        }
                                        latch.countDown();
                                    }
                                }
                            });
                        }
                    });

            Channel clientChannel = bootstrap.connect(channel.localAddress()).syncUninterruptibly().channel();
            latch.await();
            clientChannel.close().syncUninterruptibly();
            channel.close().syncUninterruptibly();
            verifyResult(error);
        } finally {
            // Remove the temporary certificate files and stop the event loop.
            cert.delete();
            group.shutdownGracefully();
        }
    }

    /** The SslProvider used for the server side of the handshake. */
    protected abstract SslProvider serverSslProvider();

    /**
     * Default verification: no error means success. Subclasses that expect renegotiation to be
     * rejected override this.
     */
    protected void verifyResult(AtomicReference<Throwable> error) throws Throwable {
        Throwable cause = error.get();
        if (cause != null) {
            throw cause;
        }
    }
}
+ */ +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.CompositeByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.ssl.ApplicationProtocolConfig.Protocol; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.handler.ssl.util.SimpleTrustManagerFactory; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.PromiseNotifier; +import io.netty.util.internal.ResourcesUtil; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import org.conscrypt.OpenSSLProvider; +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; +import 
org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; +import org.mockito.ArgumentCaptor; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.opentest4j.AssertionFailedError; + +import java.io.ByteArrayInputStream; +import java.io.ByteArrayOutputStream; +import java.io.Closeable; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStream; +import java.io.OutputStream; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.ByteBuffer; +import java.nio.channels.ClosedChannelException; +import java.nio.file.Files; +import java.security.KeyStore; +import java.security.KeyStoreException; +import java.security.NoSuchAlgorithmException; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.Provider; +import java.security.UnrecoverableKeyException; +import java.security.cert.Certificate; +import java.security.cert.CertificateException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Enumeration; +import java.util.HashSet; +import java.util.List; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import javax.crypto.SecretKey; +import javax.net.ssl.ExtendedSSLSession; +import javax.net.ssl.KeyManager; +import javax.net.ssl.KeyManagerFactory; +import javax.net.ssl.KeyManagerFactorySpi; +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SNIHostName; +import javax.net.ssl.SNIServerName; +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLEngineResult; +import javax.net.ssl.SSLEngineResult.Status; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLHandshakeException; +import javax.net.ssl.SSLParameters; +import 
javax.net.ssl.SSLPeerUnverifiedException; +import javax.net.ssl.SSLSession; +import javax.net.ssl.SSLSessionBindingEvent; +import javax.net.ssl.SSLSessionBindingListener; +import javax.net.ssl.SSLSessionContext; +import javax.net.ssl.SSLSocketFactory; +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.TrustManagerFactorySpi; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509ExtendedTrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.cert.X509Certificate; + +import static io.netty.handler.ssl.SslUtils.*; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.mockito.Mockito.verify; + +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public abstract class SSLEngineTest { + + private static final String PRINCIPAL_NAME = "CN=e8ac02fa0d65a84219016045db8b05c485b4ecdf.netty.test"; + private final boolean tlsv13Supported; + + @Mock + protected MessageReceiver serverReceiver; + @Mock + protected MessageReceiver clientReceiver; + + protected Throwable serverException; + protected Throwable clientException; + protected SslContext serverSslCtx; + protected SslContext clientSslCtx; + protected ServerBootstrap sb; + protected Bootstrap cb; + protected Channel serverChannel; + protected Channel 
serverConnectedChannel; + protected Channel clientChannel; + protected CountDownLatch serverLatch; + protected CountDownLatch clientLatch; + + interface MessageReceiver { + void messageReceived(ByteBuf msg); + } + + protected static final class MessageDelegatorChannelHandler extends SimpleChannelInboundHandler { + private final MessageReceiver receiver; + private final CountDownLatch latch; + + public MessageDelegatorChannelHandler(MessageReceiver receiver, CountDownLatch latch) { + super(false); + this.receiver = receiver; + this.latch = latch; + } + + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + receiver.messageReceived(msg); + latch.countDown(); + } + } + + enum BufferType { + Direct, + Heap, + Mixed + } + + static final class ProtocolCipherCombo { + private static final ProtocolCipherCombo TLSV12 = new ProtocolCipherCombo( + SslProtocols.TLS_v1_2, "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256"); + private static final ProtocolCipherCombo TLSV13 = new ProtocolCipherCombo( + SslProtocols.TLS_v1_3, "TLS_AES_128_GCM_SHA256"); + final String protocol; + final String cipher; + + private ProtocolCipherCombo(String protocol, String cipher) { + this.protocol = protocol; + this.cipher = cipher; + } + + static ProtocolCipherCombo tlsv12() { + return TLSV12; + } + + static ProtocolCipherCombo tlsv13() { + return TLSV13; + } + + @Override + public String toString() { + return "ProtocolCipherCombo{" + + "protocol='" + protocol + '\'' + + ", cipher='" + cipher + '\'' + + '}'; + } + } + + protected SSLEngineTest(boolean tlsv13Supported) { + this.tlsv13Supported = tlsv13Supported; + } + + protected static class SSLEngineTestParam { + private final BufferType type; + private final ProtocolCipherCombo protocolCipherCombo; + private final boolean delegate; + + SSLEngineTestParam(BufferType type, ProtocolCipherCombo protocolCipherCombo, boolean delegate) { + this.type = type; + this.protocolCipherCombo = protocolCipherCombo; + 
this.delegate = delegate; + } + + final BufferType type() { + return type; + } + + final ProtocolCipherCombo combo() { + return protocolCipherCombo; + } + + final boolean delegate() { + return delegate; + } + + final List protocols() { + return Collections.singletonList(protocolCipherCombo.protocol); + } + + final List ciphers() { + return Collections.singletonList(protocolCipherCombo.cipher); + } + } + + protected List newTestParams() { + List params = new ArrayList(); + for (BufferType type: BufferType.values()) { + params.add(new SSLEngineTestParam(type, ProtocolCipherCombo.tlsv12(), false)); + params.add(new SSLEngineTestParam(type, ProtocolCipherCombo.tlsv12(), true)); + + if (tlsv13Supported) { + params.add(new SSLEngineTestParam(type, ProtocolCipherCombo.tlsv13(), false)); + params.add(new SSLEngineTestParam(type, ProtocolCipherCombo.tlsv13(), true)); + } + } + return params; + } + + private ExecutorService delegatingExecutor; + + protected ByteBuffer allocateBuffer(BufferType type, int len) { + switch (type) { + case Direct: + return ByteBuffer.allocateDirect(len); + case Heap: + return ByteBuffer.allocate(len); + case Mixed: + return PlatformDependent.threadLocalRandom().nextBoolean() ? + ByteBuffer.allocateDirect(len) : ByteBuffer.allocate(len); + default: + throw new Error(); + } + } + + private static final class TestByteBufAllocator implements ByteBufAllocator { + + private final ByteBufAllocator allocator; + private final BufferType type; + + TestByteBufAllocator(ByteBufAllocator allocator, BufferType type) { + this.allocator = allocator; + this.type = type; + } + + @Override + public ByteBuf buffer() { + switch (type) { + case Direct: + return allocator.directBuffer(); + case Heap: + return allocator.heapBuffer(); + case Mixed: + return PlatformDependent.threadLocalRandom().nextBoolean() ? 
+ allocator.directBuffer() : allocator.heapBuffer(); + default: + throw new Error(); + } + } + + @Override + public ByteBuf buffer(int initialCapacity) { + switch (type) { + case Direct: + return allocator.directBuffer(initialCapacity); + case Heap: + return allocator.heapBuffer(initialCapacity); + case Mixed: + return PlatformDependent.threadLocalRandom().nextBoolean() ? + allocator.directBuffer(initialCapacity) : allocator.heapBuffer(initialCapacity); + default: + throw new Error(); + } + } + + @Override + public ByteBuf buffer(int initialCapacity, int maxCapacity) { + switch (type) { + case Direct: + return allocator.directBuffer(initialCapacity, maxCapacity); + case Heap: + return allocator.heapBuffer(initialCapacity, maxCapacity); + case Mixed: + return PlatformDependent.threadLocalRandom().nextBoolean() ? + allocator.directBuffer(initialCapacity, maxCapacity) : + allocator.heapBuffer(initialCapacity, maxCapacity); + default: + throw new Error(); + } + } + + @Override + public ByteBuf ioBuffer() { + return allocator.ioBuffer(); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity) { + return allocator.ioBuffer(initialCapacity); + } + + @Override + public ByteBuf ioBuffer(int initialCapacity, int maxCapacity) { + return allocator.ioBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf heapBuffer() { + return allocator.heapBuffer(); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity) { + return allocator.heapBuffer(initialCapacity); + } + + @Override + public ByteBuf heapBuffer(int initialCapacity, int maxCapacity) { + return allocator.heapBuffer(initialCapacity, maxCapacity); + } + + @Override + public ByteBuf directBuffer() { + return allocator.directBuffer(); + } + + @Override + public ByteBuf directBuffer(int initialCapacity) { + return allocator.directBuffer(initialCapacity); + } + + @Override + public ByteBuf directBuffer(int initialCapacity, int maxCapacity) { + return allocator.directBuffer(initialCapacity, 
maxCapacity); + } + + @Override + public CompositeByteBuf compositeBuffer() { + switch (type) { + case Direct: + return allocator.compositeDirectBuffer(); + case Heap: + return allocator.compositeHeapBuffer(); + case Mixed: + return PlatformDependent.threadLocalRandom().nextBoolean() ? + allocator.compositeDirectBuffer() : + allocator.compositeHeapBuffer(); + default: + throw new Error(); + } + } + + @Override + public CompositeByteBuf compositeBuffer(int maxNumComponents) { + switch (type) { + case Direct: + return allocator.compositeDirectBuffer(maxNumComponents); + case Heap: + return allocator.compositeHeapBuffer(maxNumComponents); + case Mixed: + return PlatformDependent.threadLocalRandom().nextBoolean() ? + allocator.compositeDirectBuffer(maxNumComponents) : + allocator.compositeHeapBuffer(maxNumComponents); + default: + throw new Error(); + } + } + + @Override + public CompositeByteBuf compositeHeapBuffer() { + return allocator.compositeHeapBuffer(); + } + + @Override + public CompositeByteBuf compositeHeapBuffer(int maxNumComponents) { + return allocator.compositeHeapBuffer(maxNumComponents); + } + + @Override + public CompositeByteBuf compositeDirectBuffer() { + return allocator.compositeDirectBuffer(); + } + + @Override + public CompositeByteBuf compositeDirectBuffer(int maxNumComponents) { + return allocator.compositeDirectBuffer(maxNumComponents); + } + + @Override + public boolean isDirectBufferPooled() { + return allocator.isDirectBufferPooled(); + } + + @Override + public int calculateNewCapacity(int minNewCapacity, int maxCapacity) { + return allocator.calculateNewCapacity(minNewCapacity, maxCapacity); + } + } + + @BeforeEach + public void setup() { + MockitoAnnotations.initMocks(this); + serverLatch = new CountDownLatch(1); + clientLatch = new CountDownLatch(1); + delegatingExecutor = Executors.newCachedThreadPool(); + } + + @AfterEach + public void tearDown() throws InterruptedException { + ChannelFuture clientCloseFuture = null; + ChannelFuture 
serverConnectedCloseFuture = null; + ChannelFuture serverCloseFuture = null; + if (clientChannel != null) { + clientCloseFuture = clientChannel.close(); + clientChannel = null; + } + if (serverConnectedChannel != null) { + serverConnectedCloseFuture = serverConnectedChannel.close(); + serverConnectedChannel = null; + } + if (serverChannel != null) { + serverCloseFuture = serverChannel.close(); + serverChannel = null; + } + // We must wait for the Channel cleanup to finish. In the case if the ReferenceCountedOpenSslEngineTest + // the ReferenceCountedOpenSslEngine depends upon the SslContext and so we must wait the cleanup the + // SslContext to avoid JVM core dumps! + // + // See https://github.com/netty/netty/issues/5692 + if (clientCloseFuture != null) { + clientCloseFuture.sync(); + } + if (serverConnectedCloseFuture != null) { + serverConnectedCloseFuture.sync(); + } + if (serverCloseFuture != null) { + serverCloseFuture.sync(); + } + if (serverSslCtx != null) { + cleanupServerSslContext(serverSslCtx); + serverSslCtx = null; + } + if (clientSslCtx != null) { + cleanupClientSslContext(clientSslCtx); + clientSslCtx = null; + } + Future serverGroupShutdownFuture = null; + Future serverChildGroupShutdownFuture = null; + Future clientGroupShutdownFuture = null; + if (sb != null) { + serverGroupShutdownFuture = sb.config().group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + serverChildGroupShutdownFuture = sb.config().childGroup().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + if (cb != null) { + clientGroupShutdownFuture = cb.config().group().shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + if (serverGroupShutdownFuture != null) { + serverGroupShutdownFuture.sync(); + serverChildGroupShutdownFuture.sync(); + } + if (clientGroupShutdownFuture != null) { + clientGroupShutdownFuture.sync(); + } + delegatingExecutor.shutdown(); + serverException = null; + clientException = null; + } + + @MethodSource("newTestParams") + @ParameterizedTest + public 
void testMutualAuthSameCerts(SSLEngineTestParam param) throws Throwable { + mySetupMutualAuth(param, ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"), + ResourcesUtil.getFile(getClass(), "test.crt"), + null); + runTest(null); + assertTrue(serverLatch.await(2, TimeUnit.SECONDS)); + Throwable cause = serverException; + if (cause != null) { + throw cause; + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSetSupportedCiphers(SSLEngineTestParam param) throws Exception { + if (param.protocolCipherCombo != ProtocolCipherCombo.tlsv12()) { + return; + } + SelfSignedCertificate cert = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(cert.key(), cert.cert()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslProvider(sslServerProvider()).build()); + final SSLEngine serverEngine = + wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(cert.certificate()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslProvider(sslClientProvider()).build()); + final SSLEngine clientEngine = + wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + final String[] enabledCiphers = new String[]{ param.ciphers().get(0) }; + + try { + clientEngine.setEnabledCipherSuites(enabledCiphers); + serverEngine.setEnabledCipherSuites(enabledCiphers); + + assertArrayEquals(enabledCiphers, clientEngine.getEnabledCipherSuites()); + assertArrayEquals(enabledCiphers, serverEngine.getEnabledCipherSuites()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + cert.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testIncompatibleCiphers(final SSLEngineTestParam param) throws Exception { + assumeTrue(SslProvider.isTlsv13Supported(sslClientProvider())); + 
assumeTrue(SslProvider.isTlsv13Supported(sslServerProvider())); + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail + // due to no shared/supported cipher. + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2, SslProtocols.TLS_v1) + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .build()); + + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .protocols(SslProtocols.TLS_v1_3, SslProtocols.TLS_v1_2, SslProtocols.TLS_v1) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + // Set the server to only support a single TLSv1.2 cipher + final String serverCipher = "TLS_RSA_WITH_AES_128_CBC_SHA"; + serverEngine.setEnabledCipherSuites(new String[] { serverCipher }); + + // Set the client to only support a single TLSv1.3 cipher + final String clientCipher = "TLS_AES_256_GCM_SHA384"; + clientEngine.setEnabledCipherSuites(new String[] { clientCipher }); + + final SSLEngine client = clientEngine; + final SSLEngine server = serverEngine; + assertThrows(SSLHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshake(param.type(), param.delegate(), client, server); + } + }); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthDiffCerts(SSLEngineTestParam 
param) throws Exception { + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_encrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + String serverKeyPassword = "12345"; + File clientKeyFile = ResourcesUtil.getFile(getClass(), "test2_encrypted.pem"); + File clientCrtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); + String clientKeyPassword = "12345"; + mySetupMutualAuth(param, clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, + serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); + runTest(null); + assertTrue(serverLatch.await(2, TimeUnit.SECONDS)); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthDiffCertsServerFailure(SSLEngineTestParam param) throws Exception { + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_encrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + String serverKeyPassword = "12345"; + File clientKeyFile = ResourcesUtil.getFile(getClass(), "test2_encrypted.pem"); + File clientCrtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); + String clientKeyPassword = "12345"; + // Client trusts server but server only trusts itself + mySetupMutualAuth(param, serverCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, + serverCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); + assertTrue(serverLatch.await(10, TimeUnit.SECONDS)); + assertTrue(serverException instanceof SSLHandshakeException); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthDiffCertsClientFailure(SSLEngineTestParam param) throws Exception { + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + String serverKeyPassword = null; + File clientKeyFile = ResourcesUtil.getFile(getClass(), "test2_unencrypted.pem"); + File clientCrtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); + 
String clientKeyPassword = null; + // Server trusts client but client only trusts itself + mySetupMutualAuth(param, clientCrtFile, serverKeyFile, serverCrtFile, serverKeyPassword, + clientCrtFile, clientKeyFile, clientCrtFile, clientKeyPassword); + assertTrue(clientLatch.await(10, TimeUnit.SECONDS)); + assertTrue(clientException instanceof SSLHandshakeException); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthInvalidIntermediateCASucceedWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + testMutualAuthInvalidClientCertSucceed(param, ClientAuth.NONE); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthInvalidIntermediateCAFailWithOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + testMutualAuthClientCertFail(param, ClientAuth.OPTIONAL); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthInvalidIntermediateCAFailWithRequiredClientAuth(SSLEngineTestParam param) + throws Exception { + testMutualAuthClientCertFail(param, ClientAuth.REQUIRE); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthValidClientCertChainTooLongFailOptionalClientAuth(SSLEngineTestParam param) + throws Exception { + testMutualAuthClientCertFail(param, ClientAuth.OPTIONAL, "mutual_auth_client.p12", true); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMutualAuthValidClientCertChainTooLongFailRequireClientAuth(SSLEngineTestParam param) + throws Exception { + testMutualAuthClientCertFail(param, ClientAuth.REQUIRE, "mutual_auth_client.p12", true); + } + + private void testMutualAuthInvalidClientCertSucceed(SSLEngineTestParam param, ClientAuth auth) throws Exception { + char[] password = "example".toCharArray(); + final KeyStore serverKeyStore = KeyStore.getInstance("PKCS12"); + serverKeyStore.load(getClass().getResourceAsStream("mutual_auth_server.p12"), password); + final KeyStore 
clientKeyStore = KeyStore.getInstance("PKCS12"); + clientKeyStore.load(getClass().getResourceAsStream("mutual_auth_invalid_client.p12"), password); + final KeyManagerFactory serverKeyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + serverKeyManagerFactory.init(serverKeyStore, password); + final KeyManagerFactory clientKeyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + clientKeyManagerFactory.init(clientKeyStore, password); + File commonCertChain = ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"); + + mySetupMutualAuth(param, serverKeyManagerFactory, commonCertChain, clientKeyManagerFactory, commonCertChain, + auth, false, false); + assertTrue(clientLatch.await(10, TimeUnit.SECONDS)); + rethrowIfNotNull(clientException); + assertTrue(serverLatch.await(5, TimeUnit.SECONDS)); + rethrowIfNotNull(serverException); + } + + private void testMutualAuthClientCertFail(SSLEngineTestParam param, ClientAuth auth) throws Exception { + testMutualAuthClientCertFail(param, auth, "mutual_auth_invalid_client.p12", false); + } + + private void testMutualAuthClientCertFail(SSLEngineTestParam param, ClientAuth auth, String clientCert, + boolean serverInitEngine) + throws Exception { + char[] password = "example".toCharArray(); + final KeyStore serverKeyStore = KeyStore.getInstance("PKCS12"); + serverKeyStore.load(getClass().getResourceAsStream("mutual_auth_server.p12"), password); + final KeyStore clientKeyStore = KeyStore.getInstance("PKCS12"); + clientKeyStore.load(getClass().getResourceAsStream(clientCert), password); + final KeyManagerFactory serverKeyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + serverKeyManagerFactory.init(serverKeyStore, password); + final KeyManagerFactory clientKeyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + clientKeyManagerFactory.init(clientKeyStore, password); + File 
commonCertChain = ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"); + + mySetupMutualAuth(param, serverKeyManagerFactory, commonCertChain, clientKeyManagerFactory, commonCertChain, + auth, true, serverInitEngine); + assertTrue(clientLatch.await(10, TimeUnit.SECONDS)); + assertTrue(mySetupMutualAuthServerIsValidClientException(clientException), + "unexpected exception: " + clientException); + assertTrue(serverLatch.await(5, TimeUnit.SECONDS)); + assertTrue(mySetupMutualAuthServerIsValidServerException(serverException), + "unexpected exception: " + serverException); + } + + protected static boolean causedBySSLException(Throwable cause) { + Throwable next = cause; + do { + if (next instanceof SSLException) { + return true; + } + next = next.getCause(); + } while (next != null); + return false; + } + + protected boolean mySetupMutualAuthServerIsValidServerException(Throwable cause) { + return mySetupMutualAuthServerIsValidException(cause); + } + + protected boolean mySetupMutualAuthServerIsValidClientException(Throwable cause) { + return mySetupMutualAuthServerIsValidException(cause); + } + + protected boolean mySetupMutualAuthServerIsValidException(Throwable cause) { + // As in TLSv1.3 the handshake is sent without an extra roundtrip an SSLException is valid as well. 
+ return cause instanceof SSLException || cause instanceof ClosedChannelException; + } + + protected void mySetupMutualAuthServerInitSslHandler(SslHandler handler) { + } + + protected void mySetupMutualAuth(final SSLEngineTestParam param, KeyManagerFactory serverKMF, + final File serverTrustManager, + KeyManagerFactory clientKMF, File clientTrustManager, + ClientAuth clientAuth, final boolean failureExpected, + final boolean serverInitEngine) + throws SSLException, InterruptedException { + serverSslCtx = + wrapContext(param, SslContextBuilder.forServer(serverKMF) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .trustManager(serverTrustManager) + .clientAuth(clientAuth) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build()); + + clientSslCtx = + wrapContext(param, SslContextBuilder.forClient() + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .trustManager(clientTrustManager) + .keyManager(clientKMF) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build()); + + serverConnectedChannel = null; + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type())); + + ChannelPipeline p = ch.pipeline(); + SslHandler handler = !param.delegate ? 
serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + if (serverInitEngine) { + mySetupMutualAuthServerInitSslHandler(handler); + } + p.addLast(handler); + p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + if (failureExpected) { + serverException = new IllegalStateException("handshake complete. expected failure"); + } + serverLatch.countDown(); + } else if (evt instanceof SslHandshakeCompletionEvent) { + serverException = ((SslHandshakeCompletionEvent) evt).cause(); + serverLatch.countDown(); + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + serverException = cause.getCause(); + serverLatch.countDown(); + } else { + serverException = cause; + ctx.fireExceptionCaught(cause); + } + } + }); + serverConnectedChannel = ch; + } + }); + + cb.group(new NioEventLoopGroup()); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type)); + ChannelPipeline p = ch.pipeline(); + + SslHandler handler = !param.delegate ? 
clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + p.addLast(handler); + p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + // With TLS1.3 a mutal auth error will not be propagated as a handshake error most of the + // time as the handshake needs NO extra roundtrip. + if (!failureExpected) { + clientLatch.countDown(); + } + } else if (evt instanceof SslHandshakeCompletionEvent) { + clientException = ((SslHandshakeCompletionEvent) evt).cause(); + clientLatch.countDown(); + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLException) { + clientException = cause.getCause(); + clientLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); + + ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port)); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + } + + protected static void rethrowIfNotNull(Throwable error) { + if (error != null) { + throw new AssertionFailedError("Expected no error", error); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testClientHostnameValidationSuccess(SSLEngineTestParam param) throws Exception { + mySetupClientHostnameValidation(param, ResourcesUtil.getFile(getClass(), "localhost_server.pem"), + ResourcesUtil.getFile(getClass(), "localhost_server.key"), + ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"), + false); + assertTrue(clientLatch.await(10, 
TimeUnit.SECONDS)); + + rethrowIfNotNull(clientException); + assertTrue(serverLatch.await(5, TimeUnit.SECONDS)); + rethrowIfNotNull(serverException); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testClientHostnameValidationFail(SSLEngineTestParam param) throws Exception { + Future clientWriteFuture = + mySetupClientHostnameValidation(param, ResourcesUtil.getFile(getClass(), "notlocalhost_server.pem"), + ResourcesUtil.getFile(getClass(), "notlocalhost_server.key"), + ResourcesUtil.getFile(getClass(), "mutual_auth_ca.pem"), + true); + assertTrue(clientLatch.await(10, TimeUnit.SECONDS)); + assertTrue(mySetupMutualAuthServerIsValidClientException(clientException), + "unexpected exception: " + clientException); + assertTrue(serverLatch.await(5, TimeUnit.SECONDS)); + assertTrue(mySetupMutualAuthServerIsValidServerException(serverException), + "unexpected exception: " + serverException); + + // Verify that any pending writes are failed with the cached handshake exception and not a general SSLException. 
+ clientWriteFuture.awaitUninterruptibly(); + Throwable actualCause = clientWriteFuture.cause(); + assertSame(clientException, actualCause); + } + + private Future mySetupClientHostnameValidation(final SSLEngineTestParam param, File serverCrtFile, + File serverKeyFile, + File clientTrustCrtFile, + final boolean failureExpected) + throws SSLException, InterruptedException { + final String expectedHost = "localhost"; + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(serverCrtFile, serverKeyFile, null) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslContextProvider(serverSslContextProvider()) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0) + .build()); + + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .sslContextProvider(clientSslContextProvider()) + .trustManager(clientTrustCrtFile) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0) + .build()); + + serverConnectedChannel = null; + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type)); + ChannelPipeline p = ch.pipeline(); + + SslHandler handler = !param.delegate ? 
serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + p.addLast(handler); + p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + if (failureExpected) { + serverException = new IllegalStateException("handshake complete. expected failure"); + } + serverLatch.countDown(); + } else if (evt instanceof SslHandshakeCompletionEvent) { + serverException = ((SslHandshakeCompletionEvent) evt).cause(); + serverLatch.countDown(); + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + serverException = cause.getCause(); + serverLatch.countDown(); + } else { + serverException = cause; + ctx.fireExceptionCaught(cause); + } + } + }); + serverConnectedChannel = ch; + } + }); + + final Promise clientWritePromise = ImmediateEventExecutor.INSTANCE.newPromise(); + cb.group(new NioEventLoopGroup()); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type)); + ChannelPipeline p = ch.pipeline(); + InetSocketAddress remoteAddress = (InetSocketAddress) serverChannel.localAddress(); + + SslHandler sslHandler = !param.delegate ? 
+ clientSslCtx.newHandler(ch.alloc(), expectedHost, 0) : + clientSslCtx.newHandler(ch.alloc(), expectedHost, 0, delegatingExecutor); + + SSLParameters parameters = sslHandler.engine().getSSLParameters(); + if (SslUtils.isValidHostNameForSNI(expectedHost)) { + assertEquals(1, parameters.getServerNames().size()); + assertEquals(new SNIHostName(expectedHost), parameters.getServerNames().get(0)); + } + parameters.setEndpointIdentificationAlgorithm("HTTPS"); + sslHandler.engine().setSSLParameters(parameters); + p.addLast(sslHandler); + p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void handlerAdded(ChannelHandlerContext ctx) { + // Only write if there is a failure expected. We don't actually care about the write going + // through we just want to verify the local failure condition. This way we don't have to worry + // about verifying the payload and releasing the content on the server side. + if (failureExpected) { + ChannelFuture f = ctx.write(ctx.alloc().buffer(1).writeByte(1)); + PromiseNotifier.cascade(f, clientWritePromise); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + if (failureExpected) { + clientException = new IllegalStateException("handshake complete. 
expected failure"); + } + clientLatch.countDown(); + } else if (evt instanceof SslHandshakeCompletionEvent) { + clientException = ((SslHandshakeCompletionEvent) evt).cause(); + clientLatch.countDown(); + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + clientException = cause.getCause(); + clientLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(expectedHost, 0)).sync().channel(); + final int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); + + ChannelFuture ccf = cb.connect(new InetSocketAddress(expectedHost, port)); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + return clientWritePromise; + } + + private void mySetupMutualAuth(SSLEngineTestParam param, File keyFile, File crtFile, String keyPassword) + throws SSLException, InterruptedException { + mySetupMutualAuth(param, crtFile, keyFile, crtFile, keyPassword, crtFile, keyFile, crtFile, keyPassword); + } + + private void verifySSLSessionForMutualAuth( + SSLEngineTestParam param, SSLSession session, File certFile, String principalName) + throws Exception { + InputStream in = null; + try { + assertEquals(principalName, session.getLocalPrincipal().getName()); + assertEquals(principalName, session.getPeerPrincipal().getName()); + assertNotNull(session.getId()); + assertEquals(param.combo().cipher, session.getCipherSuite()); + assertEquals(param.combo().protocol, session.getProtocol()); + assertTrue(session.getApplicationBufferSize() > 0); + assertTrue(session.getCreationTime() > 0); + assertTrue(session.isValid()); + assertTrue(session.getLastAccessedTime() > 0); + + in = new FileInputStream(certFile); + final byte[] certBytes = SslContext.X509_CERT_FACTORY + .generateCertificate(in).getEncoded(); + + // Verify 
session + assertEquals(1, session.getPeerCertificates().length); + assertArrayEquals(certBytes, session.getPeerCertificates()[0].getEncoded()); + + try { + assertEquals(1, session.getPeerCertificateChain().length); + assertArrayEquals(certBytes, session.getPeerCertificateChain()[0].getEncoded()); + } catch (UnsupportedOperationException e) { + // See https://bugs.openjdk.java.net/browse/JDK-8241039 + assertTrue(PlatformDependent.javaVersion() >= 15); + } + + assertEquals(1, session.getLocalCertificates().length); + assertArrayEquals(certBytes, session.getLocalCertificates()[0].getEncoded()); + } finally { + if (in != null) { + in.close(); + } + } + } + + private void mySetupMutualAuth(final SSLEngineTestParam param, + File servertTrustCrtFile, File serverKeyFile, final File serverCrtFile, String serverKeyPassword, + File clientTrustCrtFile, File clientKeyFile, final File clientCrtFile, String clientKeyPassword) + throws InterruptedException, SSLException { + serverSslCtx = + wrapContext(param, SslContextBuilder.forServer(serverCrtFile, serverKeyFile, serverKeyPassword) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .trustManager(servertTrustCrtFile) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build()); + clientSslCtx = + wrapContext(param, SslContextBuilder.forClient() + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .trustManager(clientTrustCrtFile) + .keyManager(clientCrtFile, clientKeyFile, clientKeyPassword) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0).build()); + + serverConnectedChannel = null; + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + 
sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type)); + + ChannelPipeline p = ch.pipeline(); + final SSLEngine engine = wrapEngine(serverSslCtx.newEngine(ch.alloc())); + engine.setUseClientMode(false); + engine.setNeedClientAuth(true); + + p.addLast(new SslHandler(engine)); + p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + serverException = cause.getCause(); + serverLatch.countDown(); + } else { + serverException = cause; + ctx.fireExceptionCaught(cause); + } + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + try { + verifySSLSessionForMutualAuth( + param, engine.getSession(), serverCrtFile, PRINCIPAL_NAME); + } catch (Throwable cause) { + serverException = cause; + } + } + } + }); + serverConnectedChannel = ch; + } + }); + + cb.group(new NioEventLoopGroup()); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type)); + + final SslHandler handler = !param.delegate ? 
+ clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + handler.engine().setNeedClientAuth(true); + ChannelPipeline p = ch.pipeline(); + p.addLast(handler); + p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt == SslHandshakeCompletionEvent.SUCCESS) { + try { + verifySSLSessionForMutualAuth( + param, handler.engine().getSession(), clientCrtFile, PRINCIPAL_NAME); + } catch (Throwable cause) { + clientException = cause; + } + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + clientException = cause.getCause(); + clientLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); + + ChannelFuture ccf = cb.connect(new InetSocketAddress(NetUtil.LOCALHOST, port)); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + } + + protected void runTest(String expectedApplicationProtocol) throws Exception { + final ByteBuf clientMessage = Unpooled.copiedBuffer("I am a client".getBytes()); + final ByteBuf serverMessage = Unpooled.copiedBuffer("I am a server".getBytes()); + try { + writeAndVerifyReceived(clientMessage.retain(), clientChannel, serverLatch, serverReceiver); + writeAndVerifyReceived(serverMessage.retain(), serverConnectedChannel, clientLatch, clientReceiver); + verifyApplicationLevelProtocol(clientChannel, expectedApplicationProtocol); + verifyApplicationLevelProtocol(serverConnectedChannel, expectedApplicationProtocol); + } finally { + clientMessage.release(); + serverMessage.release(); + } + } + + private 
static void verifyApplicationLevelProtocol(Channel channel, String expectedApplicationProtocol) { + SslHandler handler = channel.pipeline().get(SslHandler.class); + assertNotNull(handler); + String appProto = handler.applicationProtocol(); + assertEquals(expectedApplicationProtocol, appProto); + + SSLEngine engine = handler.engine(); + if (engine instanceof JdkAlpnSslEngine) { + // Also verify the Java9 exposed method. + JdkAlpnSslEngine java9SslEngine = (JdkAlpnSslEngine) engine; + assertEquals(expectedApplicationProtocol == null ? StringUtil.EMPTY_STRING : expectedApplicationProtocol, + java9SslEngine.getApplicationProtocol()); + } + } + + private static void writeAndVerifyReceived(ByteBuf message, Channel sendChannel, CountDownLatch receiverLatch, + MessageReceiver receiver) throws Exception { + List dataCapture = null; + try { + assertTrue(sendChannel.writeAndFlush(message).await(10, TimeUnit.SECONDS)); + receiverLatch.await(5, TimeUnit.SECONDS); + message.resetReaderIndex(); + ArgumentCaptor captor = ArgumentCaptor.forClass(ByteBuf.class); + verify(receiver).messageReceived(captor.capture()); + dataCapture = captor.getAllValues(); + assertEquals(message, dataCapture.get(0)); + } finally { + if (dataCapture != null) { + for (ByteBuf data : dataCapture) { + data.release(); + } + } + } + } + + @Test + public void testGetCreationTime() throws Exception { + clientSslCtx = wrapContext(null, SslContextBuilder.forClient() + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()).build()); + SSLEngine engine = null; + try { + engine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + assertTrue(engine.getSession().getCreationTime() <= System.currentTimeMillis()); + } finally { + cleanupClientSslEngine(engine); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSessionInvalidate(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + 
.trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + + SSLSession session = serverEngine.getSession(); + assertTrue(session.isValid()); + session.invalidate(); + assertFalse(session.isValid()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSSLSessionId(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + // This test only works for non TLSv1.3 for now + .protocols(param.protocols()) + .sslContextProvider(clientSslContextProvider()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + // This test only works for non TLSv1.3 for now + .protocols(param.protocols()) + .sslContextProvider(serverSslContextProvider()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = 
wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + // Before the handshake the id should have length == 0 + assertEquals(0, clientEngine.getSession().getId().length); + assertEquals(0, serverEngine.getSession().getId().length); + + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + + if (param.protocolCipherCombo == ProtocolCipherCombo.TLSV13) { + // Allocate something which is big enough for sure + ByteBuffer packetBuffer = allocateBuffer(param.type(), 32 * 1024); + ByteBuffer appBuffer = allocateBuffer(param.type(), 32 * 1024); + + appBuffer.clear().position(4).flip(); + packetBuffer.clear(); + + do { + SSLEngineResult result; + + do { + result = serverEngine.wrap(appBuffer, packetBuffer); + } while (appBuffer.hasRemaining() || result.bytesProduced() > 0); + + appBuffer.clear(); + packetBuffer.flip(); + do { + result = clientEngine.unwrap(packetBuffer, appBuffer); + } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0); + + packetBuffer.clear(); + appBuffer.clear().position(4).flip(); + + do { + result = clientEngine.wrap(appBuffer, packetBuffer); + } while (appBuffer.hasRemaining() || result.bytesProduced() > 0); + + appBuffer.clear(); + packetBuffer.flip(); + + do { + result = serverEngine.unwrap(packetBuffer, appBuffer); + } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0); + + packetBuffer.clear(); + appBuffer.clear().position(4).flip(); + } while (clientEngine.getSession().getId().length == 0); + + // With TLS1.3 we should see pseudo IDs and so these should never match. 
+ assertFalse(Arrays.equals(clientEngine.getSession().getId(), serverEngine.getSession().getId())); + } else { + // After the handshake the id should have length > 0 + assertNotEquals(0, clientEngine.getSession().getId().length); + assertNotEquals(0, serverEngine.getSession().getId().length); + + assertArrayEquals(clientEngine.getSession().getId(), serverEngine.getSession().getId()); + } + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Timeout(30) + public void clientInitiatedRenegotiationWithFatalAlertDoesNotInfiniteLoopServer(final SSLEngineTestParam param) + throws Exception { + assumeTrue(PlatformDependent.javaVersion() >= 11); + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + sb = new ServerBootstrap() + .group(new NioEventLoopGroup(1)) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type)); + + ChannelPipeline p = ch.pipeline(); + + SslHandler handler = !param.delegate ? + serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + p.addLast(handler); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent && + ((SslHandshakeCompletionEvent) evt).isSuccess()) { + // This data will be sent to the client before any of the re-negotiation data can be + // sent. 
The client will read this, detect that it is not the response to + renegotiation which was expected, and respond with a fatal alert. + ctx.writeAndFlush(ctx.alloc().buffer(1).writeByte(100)); + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.release(msg); + // The server then attempts to trigger a flush operation once the application data is + // received from the client. The flush will encrypt all data and should not result in + // deadlock. + ctx.channel().eventLoop().schedule(new Runnable() { + @Override + public void run() { + ctx.writeAndFlush(ctx.alloc().buffer(1).writeByte(101)); + } + }, 500, TimeUnit.MILLISECONDS); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + serverLatch.countDown(); + } + }); + serverConnectedChannel = ch; + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + // OpenSslEngine doesn't support renegotiation on client side + .sslProvider(SslProvider.JDK) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + cb = new Bootstrap(); + cb.group(new NioEventLoopGroup(1)) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + public void initChannel(SocketChannel ch) { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type())); + + ChannelPipeline p = ch.pipeline(); + + SslHandler sslHandler = !param.delegate ? + clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + // The renegotiate is not expected to succeed, so we should stop trying in a timely manner so + // the unit test can terminate relatively quickly. 
+ sslHandler.setHandshakeTimeout(1, TimeUnit.SECONDS); + p.addLast(sslHandler); + p.addLast(new ChannelInboundHandlerAdapter() { + private int handshakeCount; + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + // OpenSSL SSLEngine sends a fatal alert for the renegotiation handshake because the + // user data read as part of the handshake. The client receives this fatal alert and is + // expected to shutdown the connection. The "invalid data" during the renegotiation + // handshake is also delivered to channelRead(..) on the server. + // JDK SSLEngine completes the renegotiation handshake and delivers the "invalid data" + // is also delivered to channelRead(..) on the server. JDK SSLEngine does not send a + // fatal error and so for testing purposes we close the connection after we have + // completed the first renegotiation handshake (which is the second handshake). + if (evt instanceof SslHandshakeCompletionEvent && ++handshakeCount == 2) { + ctx.close(); + return; + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.release(msg); + // Simulate a request that the server's application logic will think is invalid. 
+ ctx.writeAndFlush(ctx.alloc().buffer(1).writeByte(102)); + ctx.pipeline().get(SslHandler.class).renegotiate(); + } + }); + } + }); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.syncUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + + serverLatch.await(); + ssc.delete(); + } + + protected void testEnablingAnAlreadyDisabledSslProtocol(SSLEngineTestParam param, + String[] protocols1, String[] protocols2) throws Exception { + SSLEngine sslEngine = null; + try { + File serverKeyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File serverCrtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(serverCrtFile, serverKeyFile) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + sslEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + // Disable all protocols + sslEngine.setEnabledProtocols(EmptyArrays.EMPTY_STRINGS); + + // The only protocol that should be enabled is SSLv2Hello + String[] enabledProtocols = sslEngine.getEnabledProtocols(); + assertArrayEquals(protocols1, enabledProtocols); + + // Enable a protocol that is currently disabled + sslEngine.setEnabledProtocols(new String[]{ SslProtocols.TLS_v1_2 }); + + // The protocol that was just enabled should be returned + enabledProtocols = sslEngine.getEnabledProtocols(); + assertEquals(protocols2.length, enabledProtocols.length); + assertArrayEquals(protocols2, enabledProtocols); + } finally { + if (sslEngine != null) { + sslEngine.closeInbound(); + sslEngine.closeOutbound(); + cleanupServerSslEngine(sslEngine); + } + } + } + + protected void handshake(BufferType type, boolean delegate, SSLEngine clientEngine, SSLEngine serverEngine) + throws Exception { + ByteBuffer cTOs = allocateBuffer(type, clientEngine.getSession().getPacketBufferSize()); + 
ByteBuffer sTOc = allocateBuffer(type, serverEngine.getSession().getPacketBufferSize()); + + ByteBuffer serverAppReadBuffer = allocateBuffer(type, serverEngine.getSession().getApplicationBufferSize()); + ByteBuffer clientAppReadBuffer = allocateBuffer(type, clientEngine.getSession().getApplicationBufferSize()); + + clientEngine.beginHandshake(); + serverEngine.beginHandshake(); + + ByteBuffer empty = allocateBuffer(type, 0); + + SSLEngineResult clientResult; + SSLEngineResult serverResult; + + boolean clientHandshakeFinished = false; + boolean serverHandshakeFinished = false; + boolean cTOsHasRemaining; + boolean sTOcHasRemaining; + + do { + int cTOsPos = cTOs.position(); + int sTOcPos = sTOc.position(); + + if (!clientHandshakeFinished) { + clientResult = clientEngine.wrap(empty, cTOs); + runDelegatedTasks(delegate, clientResult, clientEngine); + assertEquals(empty.remaining(), clientResult.bytesConsumed()); + assertEquals(cTOs.position() - cTOsPos, clientResult.bytesProduced()); + + if (isHandshakeFinished(clientResult)) { + clientHandshakeFinished = true; + } + + if (clientResult.getStatus() == Status.BUFFER_OVERFLOW) { + cTOs = increaseDstBuffer(clientEngine.getSession().getPacketBufferSize(), type, cTOs); + } + } + + if (!serverHandshakeFinished) { + serverResult = serverEngine.wrap(empty, sTOc); + runDelegatedTasks(delegate, serverResult, serverEngine); + assertEquals(empty.remaining(), serverResult.bytesConsumed()); + assertEquals(sTOc.position() - sTOcPos, serverResult.bytesProduced()); + + if (isHandshakeFinished(serverResult)) { + serverHandshakeFinished = true; + } + + if (serverResult.getStatus() == Status.BUFFER_OVERFLOW) { + sTOc = increaseDstBuffer(serverEngine.getSession().getPacketBufferSize(), type, sTOc); + } + } + + cTOs.flip(); + sTOc.flip(); + + cTOsPos = cTOs.position(); + sTOcPos = sTOc.position(); + + if (!clientHandshakeFinished || + // After the handshake completes it is possible we have more data that was send by the server as + // the 
server will send session updates after the handshake. In this case continue to unwrap. + SslProtocols.TLS_v1_3.equals(clientEngine.getSession().getProtocol())) { + int clientAppReadBufferPos = clientAppReadBuffer.position(); + clientResult = clientEngine.unwrap(sTOc, clientAppReadBuffer); + + runDelegatedTasks(delegate, clientResult, clientEngine); + assertEquals(sTOc.position() - sTOcPos, clientResult.bytesConsumed()); + assertEquals(clientAppReadBuffer.position() - clientAppReadBufferPos, clientResult.bytesProduced()); + assertEquals(0, clientAppReadBuffer.position()); + + if (isHandshakeFinished(clientResult)) { + clientHandshakeFinished = true; + } + + if (clientResult.getStatus() == Status.BUFFER_OVERFLOW) { + clientAppReadBuffer = increaseDstBuffer( + clientEngine.getSession().getApplicationBufferSize(), type, clientAppReadBuffer); + } + } else { + assertEquals(0, sTOc.remaining()); + } + + if (!serverHandshakeFinished) { + int serverAppReadBufferPos = serverAppReadBuffer.position(); + serverResult = serverEngine.unwrap(cTOs, serverAppReadBuffer); + runDelegatedTasks(delegate, serverResult, serverEngine); + assertEquals(cTOs.position() - cTOsPos, serverResult.bytesConsumed()); + assertEquals(serverAppReadBuffer.position() - serverAppReadBufferPos, serverResult.bytesProduced()); + assertEquals(0, serverAppReadBuffer.position()); + + if (isHandshakeFinished(serverResult)) { + serverHandshakeFinished = true; + } + + if (serverResult.getStatus() == Status.BUFFER_OVERFLOW) { + serverAppReadBuffer = increaseDstBuffer( + serverEngine.getSession().getApplicationBufferSize(), type, serverAppReadBuffer); + } + } else { + assertFalse(cTOs.hasRemaining()); + } + + cTOsHasRemaining = compactOrClear(cTOs); + sTOcHasRemaining = compactOrClear(sTOc); + + serverAppReadBuffer.clear(); + clientAppReadBuffer.clear(); + + if (clientEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + clientHandshakeFinished = true; + } + + if 
(serverEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + serverHandshakeFinished = true; + } + } while (!clientHandshakeFinished || !serverHandshakeFinished || + // We need to ensure we feed all the data to the engine to not end up with a corrupted state. + // This is especially important with TLS1.3 which may produce sessions after the "main handshake" is + // done + cTOsHasRemaining || sTOcHasRemaining); + } + + private static boolean compactOrClear(ByteBuffer buffer) { + if (buffer.hasRemaining()) { + buffer.compact(); + return true; + } + buffer.clear(); + return false; + } + + private ByteBuffer increaseDstBuffer(int maxBufferSize, + BufferType type, ByteBuffer dstBuffer) { + assumeFalse(maxBufferSize == dstBuffer.remaining()); + // We need to increase the destination buffer + dstBuffer.flip(); + ByteBuffer tmpBuffer = allocateBuffer(type, maxBufferSize + dstBuffer.remaining()); + tmpBuffer.put(dstBuffer); + return tmpBuffer; + } + + private static boolean isHandshakeFinished(SSLEngineResult result) { + return result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.FINISHED; + } + + private void runDelegatedTasks(boolean delegate, SSLEngineResult result, SSLEngine engine) { + if (result.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NEED_TASK) { + for (;;) { + Runnable task = engine.getDelegatedTask(); + if (task == null) { + break; + } + if (!delegate) { + task.run(); + } else { + delegatingExecutor.execute(task); + } + } + } + } + + protected abstract SslProvider sslClientProvider(); + + protected abstract SslProvider sslServerProvider(); + + protected Provider clientSslContextProvider() { + return null; + } + protected Provider serverSslContextProvider() { + return null; + } + + /** + * Called from the test cleanup code and can be used to release the {@code ctx} if it must be done manually. 
+ */ + protected void cleanupClientSslContext(SslContext ctx) { + } + + /** + * Called from the test cleanup code and can be used to release the {@code ctx} if it must be done manually. + */ + protected void cleanupServerSslContext(SslContext ctx) { + } + + /** + * Called whenever an SSLEngine is not wrapped by a {@link SslHandler} and inserted into a pipeline. + */ + protected void cleanupClientSslEngine(SSLEngine engine) { + } + + /** + * Called whenever an SSLEngine is not wrapped by a {@link SslHandler} and inserted into a pipeline. + */ + protected void cleanupServerSslEngine(SSLEngine engine) { + } + + protected void setupHandlers(SSLEngineTestParam param, ApplicationProtocolConfig apn) + throws InterruptedException, SSLException, CertificateException { + setupHandlers(param, apn, apn); + } + + protected void setupHandlers(SSLEngineTestParam param, + ApplicationProtocolConfig serverApn, ApplicationProtocolConfig clientApn) + throws InterruptedException, SSLException, CertificateException { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + try { + SslContextBuilder serverCtxBuilder = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey(), null) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .ciphers(null, IdentityCipherSuiteFilter.INSTANCE) + .applicationProtocolConfig(serverApn) + .sessionCacheSize(0) + .sessionTimeout(0); + if (serverApn.protocol() == Protocol.NPN || serverApn.protocol() == Protocol.NPN_AND_ALPN) { + // NPN is not really well supported with TLSv1.3 so force to use TLSv1.2 + // See https://github.com/openssl/openssl/issues/3665 + serverCtxBuilder.protocols(SslProtocols.TLS_v1_2); + } + + SslContextBuilder clientCtxBuilder = SslContextBuilder.forClient() + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .applicationProtocolConfig(clientApn) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .ciphers(null, 
IdentityCipherSuiteFilter.INSTANCE) + .sessionCacheSize(0) + .sessionTimeout(0); + + if (clientApn.protocol() == Protocol.NPN || clientApn.protocol() == Protocol.NPN_AND_ALPN) { + // NPN is not really well supported with TLSv1.3 so force to use TLSv1.2 + // See https://github.com/openssl/openssl/issues/3665 + clientCtxBuilder.protocols(SslProtocols.TLS_v1_2); + } + + setupHandlers(param.type(), param.delegate(), + wrapContext(param, serverCtxBuilder.build()), wrapContext(param, clientCtxBuilder.build())); + } finally { + ssc.delete(); + } + } + + protected void setupHandlers(final BufferType type, final boolean delegate, + SslContext serverCtx, SslContext clientCtx) + throws InterruptedException, SSLException, CertificateException { + serverSslCtx = serverCtx; + clientSslCtx = clientCtx; + + serverConnectedChannel = null; + sb = new ServerBootstrap(); + cb = new Bootstrap(); + + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + + ChannelPipeline p = ch.pipeline(); + + SslHandler sslHandler = !delegate ? 
+ serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + p.addLast(sslHandler); + p.addLast(new MessageDelegatorChannelHandler(serverReceiver, serverLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + serverException = cause.getCause(); + serverLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + }); + serverConnectedChannel = ch; + } + }); + + cb.group(new NioEventLoopGroup()); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), type)); + + ChannelPipeline p = ch.pipeline(); + + SslHandler sslHandler = !delegate ? + clientSslCtx.newHandler(ch.alloc()) : + clientSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + p.addLast(sslHandler); + p.addLast(new MessageDelegatorChannelHandler(clientReceiver, clientLatch)); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (cause.getCause() instanceof SSLHandshakeException) { + clientException = cause.getCause(); + clientLatch.countDown(); + } else { + ctx.fireExceptionCaught(cause); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + clientLatch.countDown(); + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.syncUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + } + + @MethodSource("newTestParams") + @ParameterizedTest + @Timeout(30) + public void testMutualAuthSameCertChain(final SSLEngineTestParam param) 
throws Exception { + SelfSignedCertificate serverCert = new SelfSignedCertificate(); + SelfSignedCertificate clientCert = new SelfSignedCertificate(); + serverSslCtx = + wrapContext(param, SslContextBuilder.forServer(serverCert.certificate(), serverCert.privateKey()) + .trustManager(clientCert.cert()) + .clientAuth(ClientAuth.REQUIRE).sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()).build()); + + sb = new ServerBootstrap(); + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + + final Promise promise = sb.config().group().next().newPromise(); + serverChannel = sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type())); + + SslHandler sslHandler = !param.delegate ? + serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + ch.pipeline().addFirst(sslHandler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + Throwable cause = ((SslHandshakeCompletionEvent) evt).cause(); + if (cause == null) { + SSLSession session = ((SslHandler) ctx.pipeline().first()).engine().getSession(); + Certificate[] peerCertificates = session.getPeerCertificates(); + if (peerCertificates == null) { + promise.setFailure(new NullPointerException("peerCertificates")); + return; + } + try { + X509Certificate[] peerCertificateChain = session.getPeerCertificateChain(); + if (peerCertificateChain == null) { + promise.setFailure(new NullPointerException("peerCertificateChain")); + } else if (peerCertificateChain.length + peerCertificates.length != 4) { + String excTxtFmt = 
"peerCertificateChain.length:%s, peerCertificates.length:%s"; + promise.setFailure(new IllegalStateException(String.format(excTxtFmt, + peerCertificateChain.length, + peerCertificates.length))); + } else { + for (int i = 0; i < peerCertificateChain.length; i++) { + if (peerCertificateChain[i] == null || peerCertificates[i] == null) { + promise.setFailure( + new IllegalStateException("Certificate in chain is null")); + return; + } + } + promise.setSuccess(null); + } + } catch (UnsupportedOperationException e) { + // See https://bugs.openjdk.java.net/browse/JDK-8241039 + assertTrue(PlatformDependent.javaVersion() >= 15); + assertEquals(2, peerCertificates.length); + for (int i = 0; i < peerCertificates.length; i++) { + if (peerCertificates[i] == null) { + promise.setFailure( + new IllegalStateException("Certificate in chain is null")); + return; + } + } + promise.setSuccess(null); + } + } else { + promise.setFailure(cause); + } + } + } + }); + serverConnectedChannel = ch; + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + // We create a new chain for certificates which contains 2 certificates + ByteArrayOutputStream chainStream = new ByteArrayOutputStream(); + chainStream.write(Files.readAllBytes(clientCert.certificate().toPath())); + chainStream.write(Files.readAllBytes(serverCert.certificate().toPath())); + + clientSslCtx = + wrapContext(param, SslContextBuilder.forClient().keyManager( + new ByteArrayInputStream(chainStream.toByteArray()), + new FileInputStream(clientCert.privateKey())) + .trustManager(new FileInputStream(serverCert.certificate())) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()).ciphers(param.ciphers()).build()); + cb = new Bootstrap(); + cb.group(new NioEventLoopGroup()); + cb.channel(NioSocketChannel.class); + clientChannel = cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + 
ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type())); + ch.pipeline().addLast(new SslHandler(wrapEngine(clientSslCtx.newEngine(ch.alloc())))); + } + + }).connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + + promise.syncUninterruptibly(); + + serverCert.delete(); + clientCert.delete(); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testUnwrapBehavior(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + byte[] bytes = "Hello World".getBytes(CharsetUtil.US_ASCII); + + try { + ByteBuffer plainClientOut = allocateBuffer(param.type, client.getSession().getApplicationBufferSize()); + ByteBuffer encryptedClientToServer = allocateBuffer( + param.type, server.getSession().getPacketBufferSize() * 2); + ByteBuffer plainServerIn = allocateBuffer(param.type, server.getSession().getApplicationBufferSize()); + + handshake(param.type(), param.delegate(), client, server); + + // create two TLS frames + + // first frame + plainClientOut.put(bytes, 0, 5); + plainClientOut.flip(); + + SSLEngineResult result = client.wrap(plainClientOut, encryptedClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(5, result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + 
assertFalse(plainClientOut.hasRemaining()); + + // second frame + plainClientOut.clear(); + plainClientOut.put(bytes, 5, 6); + plainClientOut.flip(); + + result = client.wrap(plainClientOut, encryptedClientToServer); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(6, result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + // send over to server + encryptedClientToServer.flip(); + + // try with too small output buffer first (to check BUFFER_OVERFLOW case) + int remaining = encryptedClientToServer.remaining(); + ByteBuffer small = allocateBuffer(param.type, 3); + result = server.unwrap(encryptedClientToServer, small); + assertEquals(SSLEngineResult.Status.BUFFER_OVERFLOW, result.getStatus()); + assertEquals(remaining, encryptedClientToServer.remaining()); + + // now with big enough buffer + result = server.unwrap(encryptedClientToServer, plainServerIn); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + + assertEquals(5, result.bytesProduced()); + assertTrue(encryptedClientToServer.hasRemaining()); + + result = server.unwrap(encryptedClientToServer, plainServerIn); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(6, result.bytesProduced()); + assertFalse(encryptedClientToServer.hasRemaining()); + + plainServerIn.flip(); + + assertEquals(ByteBuffer.wrap(bytes), plainServerIn); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testProtocolMatch(SSLEngineTestParam param) throws Exception { + testProtocol(param, false, new String[] {"TLSv1.2"}, new String[] {"TLSv1", "TLSv1.1", "TLSv1.2"}); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testProtocolNoMatch(SSLEngineTestParam param) throws Exception { + testProtocol(param, true, new String[] {"TLSv1.2"}, new String[] {"TLSv1", "TLSv1.1"}); + } + + private void testProtocol(final 
SSLEngineTestParam param, boolean handshakeFails, + String[] clientProtocols, String[] serverProtocols) + throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .protocols(clientProtocols) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .protocols(serverProtocols) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + if (handshakeFails) { + final SSLEngine clientEngine = client; + final SSLEngine serverEngine = server; + assertThrows(SSLHandshakeException.class, new Executable() { + @Override + public void execute() throws Throwable { + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + } + }); + } else { + handshake(param.type(), param.delegate(), client, server); + } + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } + + private String[] nonContiguousProtocols(SslProvider provider) { + if (provider != null) { + // conscrypt not correctly filters out TLSv1 and TLSv1.1 which is required now by the JDK. + // https://github.com/google/conscrypt/issues/1013 + return new String[] { SslProtocols.TLS_v1_2 }; + } + return new String[] {SslProtocols.TLS_v1_2, SslProtocols.TLS_v1}; + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testHandshakeCompletesWithNonContiguousProtocolsTLSv1_2CipherOnly(SSLEngineTestParam param) + throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail + // due to no shared/supported cipher. 
+ final String sharedCipher = "TLS_RSA_WITH_AES_128_CBC_SHA"; + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .ciphers(Collections.singletonList(sharedCipher)) + .protocols(nonContiguousProtocols(sslClientProvider())) + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .build()); + + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .ciphers(Collections.singletonList(sharedCipher)) + .protocols(nonContiguousProtocols(sslServerProvider())) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testHandshakeCompletesWithoutFilteringSupportedCipher(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + // Select a mandatory cipher from the TLSv1.2 RFC https://www.ietf.org/rfc/rfc5246.txt so handshakes won't fail + // due to no shared/supported cipher. 
+ final String sharedCipher = "TLS_RSA_WITH_AES_128_CBC_SHA"; + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .ciphers(Collections.singletonList(sharedCipher), SupportedCipherSuiteFilter.INSTANCE) + .protocols(nonContiguousProtocols(sslClientProvider())) + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .build()); + + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .ciphers(Collections.singletonList(sharedCipher), SupportedCipherSuiteFilter.INSTANCE) + .protocols(nonContiguousProtocols(sslServerProvider())) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testPacketBufferSizeLimit(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + 
.ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + // Allocate an buffer that is bigger then the max plain record size. + ByteBuffer plainServerOut = allocateBuffer( + param.type(), server.getSession().getApplicationBufferSize() * 2); + + handshake(param.type(), param.delegate(), client, server); + + // Fill the whole buffer and flip it. + plainServerOut.position(plainServerOut.capacity()); + plainServerOut.flip(); + + ByteBuffer encryptedServerToClient = allocateBuffer( + param.type(), server.getSession().getPacketBufferSize()); + + int encryptedServerToClientPos = encryptedServerToClient.position(); + int plainServerOutPos = plainServerOut.position(); + SSLEngineResult result = server.wrap(plainServerOut, encryptedServerToClient); + assertEquals(SSLEngineResult.Status.OK, result.getStatus()); + assertEquals(plainServerOut.position() - plainServerOutPos, result.bytesConsumed()); + assertEquals(encryptedServerToClient.position() - encryptedServerToClientPos, result.bytesProduced()); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSSLEngineUnwrapNoSslRecord(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + final SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + final ByteBuffer src = allocateBuffer(param.type(), client.getSession().getApplicationBufferSize()); + final ByteBuffer dst = allocateBuffer(param.type(), client.getSession().getPacketBufferSize()); + ByteBuffer empty = allocateBuffer(param.type(), 0); + + SSLEngineResult clientResult = client.wrap(empty, dst); + 
assertEquals(SSLEngineResult.Status.OK, clientResult.getStatus()); + assertEquals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, clientResult.getHandshakeStatus()); + + assertThrows(SSLException.class, new Executable() { + @Override + public void execute() throws Throwable { + client.unwrap(src, dst); + } + }); + } finally { + cleanupClientSslEngine(client); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testBeginHandshakeAfterEngineClosed(SSLEngineTestParam param) throws SSLException { + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + client.closeInbound(); + client.closeOutbound(); + try { + client.beginHandshake(); + fail(); + } catch (SSLException expected) { + // expected + } catch (IllegalStateException e) { + if (!Conscrypt.isEngineSupported(client)) { + throw e; + } + // Workaround for conscrypt bug + // See https://github.com/google/conscrypt/issues/840 + } + } finally { + cleanupClientSslEngine(client); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testBeginHandshakeCloseOutbound(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) 
+ .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + testBeginHandshakeCloseOutbound(param, client); + testBeginHandshakeCloseOutbound(param, server); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } + + private void testBeginHandshakeCloseOutbound(SSLEngineTestParam param, SSLEngine engine) throws SSLException { + ByteBuffer dst = allocateBuffer(param.type(), engine.getSession().getPacketBufferSize()); + ByteBuffer empty = allocateBuffer(param.type(), 0); + engine.beginHandshake(); + engine.closeOutbound(); + + SSLEngineResult result; + for (;;) { + result = engine.wrap(empty, dst); + dst.flip(); + + assertEquals(0, result.bytesConsumed()); + assertEquals(dst.remaining(), result.bytesProduced()); + if (result.getHandshakeStatus() != SSLEngineResult.HandshakeStatus.NEED_WRAP) { + break; + } + dst.clear(); + } + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testCloseInboundAfterBeginHandshake(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + testCloseInboundAfterBeginHandshake(client); + 
testCloseInboundAfterBeginHandshake(server); + } finally { + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + cert.delete(); + } + } + + private static void testCloseInboundAfterBeginHandshake(SSLEngine engine) throws SSLException { + engine.beginHandshake(); + try { + engine.closeInbound(); + // Workaround for conscrypt bug + // See https://github.com/google/conscrypt/issues/839 + if (!Conscrypt.isEngineSupported(engine)) { + fail(); + } + } catch (SSLException expected) { + // expected + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testCloseNotifySequence(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + // This test only works for non TLSv1.3 for now + .protocols(SslProtocols.TLS_v1_2) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + // This test only works for non TLSv1.3 for now + .protocols(SslProtocols.TLS_v1_2) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + ByteBuffer plainClientOut = allocateBuffer(param.type(), client.getSession().getApplicationBufferSize()); + ByteBuffer plainServerOut = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize()); + + ByteBuffer encryptedClientToServer = + allocateBuffer(param.type(), client.getSession().getPacketBufferSize()); + ByteBuffer encryptedServerToClient = + allocateBuffer(param.type(), server.getSession().getPacketBufferSize()); + ByteBuffer empty = allocateBuffer(param.type(), 0); + + 
handshake(param.type(), param.delegate(), client, server); + + // This will produce a close_notify + client.closeOutbound(); + + // Something still pending in the outbound buffer. + assertFalse(client.isOutboundDone()); + assertFalse(client.isInboundDone()); + + // Now wrap and so drain the outbound buffer. + SSLEngineResult result = client.wrap(empty, encryptedClientToServer); + encryptedClientToServer.flip(); + + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + SSLEngineResult.HandshakeStatus hs = result.getHandshakeStatus(); + // Need an UNWRAP to read the response of the close_notify + if (sslClientProvider() == SslProvider.JDK || Conscrypt.isEngineSupported(client)) { + assertTrue(hs == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING + || hs == SSLEngineResult.HandshakeStatus.NEED_UNWRAP); + } else { + assertEquals(SSLEngineResult.HandshakeStatus.NEED_UNWRAP, hs); + } + + int produced = result.bytesProduced(); + int consumed = result.bytesConsumed(); + int closeNotifyLen = produced; + + assertTrue(produced > 0); + assertEquals(0, consumed); + assertEquals(produced, encryptedClientToServer.remaining()); + // Outbound buffer should be drained now. 
+ assertTrue(client.isOutboundDone()); + assertFalse(client.isInboundDone()); + + assertFalse(server.isOutboundDone()); + assertFalse(server.isInboundDone()); + result = server.unwrap(encryptedClientToServer, plainServerOut); + plainServerOut.flip(); + + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + // Need a WRAP to respond to the close_notify + assertEquals(SSLEngineResult.HandshakeStatus.NEED_WRAP, result.getHandshakeStatus()); + + produced = result.bytesProduced(); + consumed = result.bytesConsumed(); + assertEquals(closeNotifyLen, consumed); + assertEquals(0, produced); + // Should have consumed the complete close_notify + assertEquals(0, encryptedClientToServer.remaining()); + assertEquals(0, plainServerOut.remaining()); + + assertFalse(server.isOutboundDone()); + assertTrue(server.isInboundDone()); + + result = server.wrap(empty, encryptedServerToClient); + encryptedServerToClient.flip(); + + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + // UNWRAP/WRAP are not expected after this point + assertEquals(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); + + produced = result.bytesProduced(); + consumed = result.bytesConsumed(); + assertEquals(closeNotifyLen, produced); + assertEquals(0, consumed); + + assertEquals(produced, encryptedServerToClient.remaining()); + assertTrue(server.isOutboundDone()); + assertTrue(server.isInboundDone()); + + result = client.unwrap(encryptedServerToClient, plainClientOut); + + plainClientOut.flip(); + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + // UNWRAP/WRAP are not expected after this point + assertEquals(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); + + produced = result.bytesProduced(); + consumed = result.bytesConsumed(); + assertEquals(closeNotifyLen, consumed); + assertEquals(0, produced); + assertEquals(0, encryptedServerToClient.remaining()); + + assertTrue(client.isOutboundDone()); + 
assertTrue(client.isInboundDone()); + + // Ensure that calling wrap or unwrap again will not produce an SSLException + encryptedServerToClient.clear(); + plainServerOut.clear(); + + result = server.wrap(plainServerOut, encryptedServerToClient); + assertEngineRemainsClosed(result); + + encryptedClientToServer.clear(); + plainServerOut.clear(); + + result = server.unwrap(encryptedClientToServer, plainServerOut); + assertEngineRemainsClosed(result); + + encryptedClientToServer.clear(); + plainClientOut.clear(); + + result = client.wrap(plainClientOut, encryptedClientToServer); + assertEngineRemainsClosed(result); + + encryptedServerToClient.clear(); + plainClientOut.clear(); + + result = client.unwrap(encryptedServerToClient, plainClientOut); + assertEngineRemainsClosed(result); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + private static void assertEngineRemainsClosed(SSLEngineResult result) { + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + assertEquals(SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING, result.getHandshakeStatus()); + assertEquals(0, result.bytesConsumed()); + assertEquals(0, result.bytesProduced()); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testWrapAfterCloseOutbound(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) 
+ .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + ByteBuffer dst = allocateBuffer(param.type(), client.getSession().getPacketBufferSize()); + ByteBuffer src = allocateBuffer(param.type(), 1024); + + handshake(param.type(), param.delegate(), client, server); + + // This will produce a close_notify + client.closeOutbound(); + SSLEngineResult result = client.wrap(src, dst); + assertEquals(SSLEngineResult.Status.CLOSED, result.getStatus()); + assertEquals(0, result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + assertTrue(client.isOutboundDone()); + assertFalse(client.isInboundDone()); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + cleanupServerSslEngine(server); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMultipleRecordsInOneBufferWithNonZeroPosition(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + + clientSslCtx = wrapContext(param, SslContextBuilder + .forClient() + .trustManager(cert.cert()) + .sslContextProvider(clientSslContextProvider()) + .sslProvider(sslClientProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + serverSslCtx = wrapContext(param, SslContextBuilder + .forServer(cert.certificate(), cert.privateKey()) + .sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + try { + // Choose buffer size small enough that we can put multiple buffers into one buffer and pass it into the + // unwrap call without exceed MAX_ENCRYPTED_PACKET_LENGTH. 
+ ByteBuffer plainClientOut = allocateBuffer(param.type(), 1024); + ByteBuffer plainServerOut = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize()); + + ByteBuffer encClientToServer = allocateBuffer(param.type(), client.getSession().getPacketBufferSize()); + + int positionOffset = 1; + // We need to be able to hold 2 records + positionOffset + ByteBuffer combinedEncClientToServer = allocateBuffer( + param.type(), encClientToServer.capacity() * 2 + positionOffset); + combinedEncClientToServer.position(positionOffset); + + handshake(param.type(), param.delegate(), client, server); + + plainClientOut.limit(plainClientOut.capacity()); + SSLEngineResult result = client.wrap(plainClientOut, encClientToServer); + assertEquals(plainClientOut.capacity(), result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + encClientToServer.flip(); + + // Copy the first record into the combined buffer + combinedEncClientToServer.put(encClientToServer); + + plainClientOut.clear(); + encClientToServer.clear(); + + result = client.wrap(plainClientOut, encClientToServer); + assertEquals(plainClientOut.capacity(), result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + + encClientToServer.flip(); + + int encClientToServerLen = encClientToServer.remaining(); + + // Copy the first record into the combined buffer + combinedEncClientToServer.put(encClientToServer); + + encClientToServer.clear(); + + combinedEncClientToServer.flip(); + combinedEncClientToServer.position(positionOffset); + + // Ensure we have the first record and a tiny amount of the second record in the buffer + combinedEncClientToServer.limit( + combinedEncClientToServer.limit() - (encClientToServerLen - positionOffset)); + result = server.unwrap(combinedEncClientToServer, plainServerOut); + assertEquals(encClientToServerLen, result.bytesConsumed()); + assertTrue(result.bytesProduced() > 0); + } finally { + cert.delete(); + cleanupClientSslEngine(client); + 
            cleanupServerSslEngine(server);
        }
    }

    /**
     * Verifies that a client can wrap several TLS records into one destination buffer that is larger than
     * the packet buffer size, and that the server can unwrap from that multi-record buffer.
     */
    @Disabled("hangs")
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testMultipleRecordsInOneBufferBiggerThenPacketBufferSize(SSLEngineTestParam param) throws Exception {
        SelfSignedCertificate cert = new SelfSignedCertificate();

        clientSslCtx = wrapContext(param, SslContextBuilder
                .forClient()
                .trustManager(cert.cert())
                .sslContextProvider(clientSslContextProvider())
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        serverSslCtx = wrapContext(param, SslContextBuilder
                .forServer(cert.certificate(), cert.privateKey())
                .sslContextProvider(serverSslContextProvider())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        try {
            ByteBuffer plainClientOut = allocateBuffer(param.type(), 4096);
            ByteBuffer plainServerOut = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize());

            // Big enough to hold at least two full TLS records.
            ByteBuffer encClientToServer = allocateBuffer(param.type(), server.getSession().getPacketBufferSize() * 2);

            handshake(param.type(), param.delegate(), client, server);

            int srcLen = plainClientOut.remaining();
            SSLEngineResult result;

            int count = 0;
            do {
                int plainClientOutPosition = plainClientOut.position();
                int encClientToServerPosition = encClientToServer.position();
                result = client.wrap(plainClientOut, encClientToServer);
                if (result.getStatus() == Status.BUFFER_OVERFLOW) {
                    // We did not have enough room to wrap
                    assertEquals(plainClientOutPosition, plainClientOut.position());
                    assertEquals(encClientToServerPosition, encClientToServer.position());
                    break;
                }
                assertEquals(SSLEngineResult.Status.OK, result.getStatus());
                assertEquals(srcLen, result.bytesConsumed());
                assertTrue(result.bytesProduced() > 0);

                plainClientOut.clear();

                ++count;
            } while (encClientToServer.position() < server.getSession().getPacketBufferSize());

            // Check that we were able to wrap multiple times.
            assertTrue(count >= 2);
            encClientToServer.flip();

            // A single unwrap call only consumes one record, so some encrypted data must remain.
            result = server.unwrap(encClientToServer, plainServerOut);
            assertEquals(SSLEngineResult.Status.OK, result.getStatus());
            assertTrue(result.bytesConsumed() > 0);
            assertTrue(result.bytesProduced() > 0);
            assertTrue(encClientToServer.hasRemaining());
        } finally {
            cert.delete();
            cleanupClientSslEngine(client);
            cleanupServerSslEngine(server);
        }
    }

    /**
     * Verifies that unwrap reports {@code BUFFER_UNDERFLOW} (consuming and producing nothing) whenever the
     * source buffer holds less than one complete TLS record, and succeeds once the full record is visible.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testBufferUnderFlow(SSLEngineTestParam param) throws Exception {
        SelfSignedCertificate cert = new SelfSignedCertificate();

        clientSslCtx = wrapContext(param, SslContextBuilder
                .forClient()
                .trustManager(cert.cert())
                .sslContextProvider(clientSslContextProvider())
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        serverSslCtx = wrapContext(param, SslContextBuilder
                .forServer(cert.certificate(), cert.privateKey())
                .sslContextProvider(serverSslContextProvider())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        try {
            ByteBuffer plainClient = allocateBuffer(param.type(), 1024);
            plainClient.limit(plainClient.capacity());

            ByteBuffer encClientToServer = allocateBuffer(param.type(), client.getSession().getPacketBufferSize());
            ByteBuffer plainServer = allocateBuffer(param.type(), server.getSession().getApplicationBufferSize());

            handshake(param.type(), param.delegate(), client, server);

            SSLEngineResult result = client.wrap(plainClient, encClientToServer);
            assertEquals(SSLEngineResult.Status.OK, result.getStatus());
            assertEquals(result.bytesConsumed(), plainClient.capacity());

            // Flip so we can read it.
            encClientToServer.flip();
            int remaining = encClientToServer.remaining();

            // We limit the buffer so we have less then the header to read, this should result in an BUFFER_UNDERFLOW.
            encClientToServer.limit(SSL_RECORD_HEADER_LENGTH - 1);
            result = server.unwrap(encClientToServer, plainServer);
            assertResultIsBufferUnderflow(result);

            // We limit the buffer so we can read the header but not the rest, this should result in an
            // BUFFER_UNDERFLOW.
            encClientToServer.limit(SSL_RECORD_HEADER_LENGTH);
            result = server.unwrap(encClientToServer, plainServer);
            assertResultIsBufferUnderflow(result);

            // We limit the buffer so we can read the header and partly the rest, this should result in an
            // BUFFER_UNDERFLOW. (The expression below is equivalent to remaining - 1.)
            encClientToServer.limit(SSL_RECORD_HEADER_LENGTH + remaining - 1 - SSL_RECORD_HEADER_LENGTH);
            result = server.unwrap(encClientToServer, plainServer);
            assertResultIsBufferUnderflow(result);

            // Reset limit so we can read the full record.
            encClientToServer.limit(remaining);

            result = server.unwrap(encClientToServer, plainServer);
            assertEquals(SSLEngineResult.Status.OK, result.getStatus());
            assertEquals(result.bytesConsumed(), remaining);
            assertTrue(result.bytesProduced() > 0);
        } finally {
            cert.delete();
            cleanupClientSslEngine(client);
            cleanupServerSslEngine(server);
        }
    }

    // Asserts the exact contract of a BUFFER_UNDERFLOW result: status set, nothing consumed or produced.
    private static void assertResultIsBufferUnderflow(SSLEngineResult result) {
        assertEquals(SSLEngineResult.Status.BUFFER_UNDERFLOW, result.getStatus());
        assertEquals(0, result.bytesConsumed());
        assertEquals(0, result.bytesProduced());
    }

    /**
     * Verifies that wrap() leaves the source (plaintext) buffer contents untouched — every byte written
     * before wrapping must still be readable afterwards.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testWrapDoesNotZeroOutSrc(SSLEngineTestParam param) throws Exception {
        SelfSignedCertificate cert = new SelfSignedCertificate();

        clientSslCtx = wrapContext(param, SslContextBuilder
                .forClient()
                .trustManager(cert.cert())
                .sslContextProvider(clientSslContextProvider())
                .sslProvider(sslClientProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        serverSslCtx = wrapContext(param, SslContextBuilder
                .forServer(cert.certificate(), cert.privateKey())
                .sslContextProvider(serverSslContextProvider())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        try {
            ByteBuffer plainServerOut =
                    allocateBuffer(param.type(), server.getSession().getApplicationBufferSize() / 2);

            handshake(param.type(), param.delegate(), client, server);

            // Fill the whole buffer and flip it.
            for (int i = 0; i < plainServerOut.capacity(); i++) {
                plainServerOut.put(i, (byte) i);
            }
            plainServerOut.position(plainServerOut.capacity());
            plainServerOut.flip();

            ByteBuffer encryptedServerToClient =
                    allocateBuffer(param.type(), server.getSession().getPacketBufferSize());
            SSLEngineResult result = server.wrap(plainServerOut, encryptedServerToClient);
            assertEquals(SSLEngineResult.Status.OK, result.getStatus());
            assertTrue(result.bytesConsumed() > 0);

            // The source bytes must be unchanged after wrapping.
            for (int i = 0; i < plainServerOut.capacity(); i++) {
                assertEquals((byte) i, plainServerOut.get(i));
            }
        } finally {
            cleanupClientSslEngine(client);
            cleanupServerSslEngine(server);
            cert.delete();
        }
    }

    /**
     * Verifies that individual protocols can be removed from the enabled set — each call disables one more
     * protocol (cumulatively SSLv2, SSLv3, TLSv1, TLSv1.1, TLSv1.2).
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testDisableProtocols(SSLEngineTestParam param) throws Exception {
        testDisableProtocols(param, SslProtocols.SSL_v2, SslProtocols.SSL_v2);
        testDisableProtocols(param, SslProtocols.SSL_v3, SslProtocols.SSL_v2, SslProtocols.SSL_v3);
        testDisableProtocols(param, SslProtocols.TLS_v1, SslProtocols.SSL_v2, SslProtocols.SSL_v3, SslProtocols.TLS_v1);
        testDisableProtocols(param,
                SslProtocols.TLS_v1_1, SslProtocols.SSL_v2, SslProtocols.SSL_v3,
                SslProtocols.TLS_v1, SslProtocols.TLS_v1_1);
        testDisableProtocols(param, SslProtocols.TLS_v1_2, SslProtocols.SSL_v2,
                SslProtocols.SSL_v3, SslProtocols.TLS_v1, SslProtocols.TLS_v1_1, SslProtocols.TLS_v1_2);
    }

    /**
     * Enables all supported protocols, removes {@code disabledProtocols}, and asserts that
     * {@link SSLEngine#getEnabledProtocols()} reflects exactly the remaining set. Skipped silently when
     * {@code protocol} is not supported by the engine under test.
     */
    private void testDisableProtocols(SSLEngineTestParam param,
                                      String protocol, String... disabledProtocols) throws Exception {
        SelfSignedCertificate cert = new SelfSignedCertificate();

        SslContext ctx = wrapContext(param, SslContextBuilder
                .forServer(cert.certificate(), cert.privateKey())
                .sslContextProvider(serverSslContextProvider())
                .sslProvider(sslServerProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine server = wrapEngine(ctx.newEngine(UnpooledByteBufAllocator.DEFAULT));

        try {
            Set<String> supported = new HashSet<String>(Arrays.asList(server.getSupportedProtocols()));
            if (supported.contains(protocol)) {
                server.setEnabledProtocols(server.getSupportedProtocols());
                assertEquals(supported, new HashSet<String>(Arrays.asList(server.getSupportedProtocols())));

                for (String disabled : disabledProtocols) {
                    supported.remove(disabled);
                }
                if (supported.contains(SslProtocols.SSL_v2_HELLO) && supported.size() == 1) {
                    // It's not allowed to set only PROTOCOL_SSL_V2_HELLO if using JDK SSLEngine.
                    return;
                }
                server.setEnabledProtocols(supported.toArray(new String[0]));
                assertEquals(supported, new HashSet<String>(Arrays.asList(server.getEnabledProtocols())));
                server.setEnabledProtocols(server.getSupportedProtocols());
            }
        } finally {
            cleanupServerSslEngine(server);
            cleanupClientSslContext(ctx);
            cert.delete();
        }
    }

    @Disabled("hangs")
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testUsingX509TrustManagerVerifiesHostname(SSLEngineTestParam param) throws Exception {
        testUsingX509TrustManagerVerifiesHostname(param, false);
    }

    @Disabled("hangs")
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testUsingX509TrustManagerVerifiesSNIHostname(SSLEngineTestParam param) throws Exception {
        testUsingX509TrustManagerVerifiesHostname(param, true);
    }

    /**
     * Verifies HTTPS endpoint identification with an all-trusting X509TrustManager: the handshake must fail
     * with an {@link SSLException} when the peer host ("127.0.0.1") does not match the certificate FQDN, and
     * must succeed when SNI supplies the matching hostname.
     */
    private void testUsingX509TrustManagerVerifiesHostname(SSLEngineTestParam param, boolean useSNI) throws Exception {
        if (clientSslContextProvider() != null) {
            // Not supported when using conscrypt
            return;
        }
        String fqdn = "something.netty.io";
        SelfSignedCertificate cert = new SelfSignedCertificate(fqdn);
        clientSslCtx = wrapContext(param, SslContextBuilder
                .forClient()
                .trustManager(new TrustManagerFactory(new TrustManagerFactorySpi() {
                    @Override
                    protected void engineInit(KeyStore keyStore) {
                        // NOOP
                    }
                    @Override
                    protected TrustManager[] engineGetTrustManagers() {
                        // Provide a custom trust manager, this manager trust all certificates
                        return new TrustManager[] {
                                new X509TrustManager() {
                                    @Override
                                    public void checkClientTrusted(
                                            java.security.cert.X509Certificate[] x509Certificates, String s) {
                                        // NOOP
                                    }

                                    @Override
                                    public void checkServerTrusted(
                                            java.security.cert.X509Certificate[] x509Certificates, String s) {
                                        // NOOP
                                    }

                                    @Override
                                    public java.security.cert.X509Certificate[] getAcceptedIssuers() {
                                        return EmptyArrays.EMPTY_X509_CERTIFICATES;
                                    }
                                }
                        };
                    }

                    @Override
                    protected void engineInit(ManagerFactoryParameters managerFactoryParameters) {
                    }
                }, null, TrustManagerFactory.getDefaultAlgorithm()) {
                })
                .sslContextProvider(clientSslContextProvider())
                .sslProvider(sslClientProvider())
                .build());

        SSLEngine client = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT, "127.0.0.1", 1234));
        SSLParameters sslParameters = client.getSSLParameters();
        // "HTTPS" turns on hostname verification as part of the handshake.
        sslParameters.setEndpointIdentificationAlgorithm("HTTPS");
        if (useSNI) {
            sslParameters.setServerNames(Collections.singletonList(new SNIHostName(fqdn)));
        }
        client.setSSLParameters(sslParameters);

        serverSslCtx = wrapContext(param, SslContextBuilder
                .forServer(cert.certificate(), cert.privateKey())
                .sslContextProvider(serverSslContextProvider())
                .sslProvider(sslServerProvider())
                .build());

        SSLEngine server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
        try {
            handshake(param.type(), param.delegate(), client, server);
            if (!useSNI) {
                fail();
            }
        } catch (SSLException exception) {
            if (useSNI) {
                throw exception;
            }
            // expected as the hostname not matches.
        } finally {
            cleanupClientSslEngine(client);
            cleanupServerSslEngine(server);
            cert.delete();
        }
    }

    /**
     * Building a context with an unknown cipher suite in the list must fail with either
     * IllegalArgumentException or SSLException (provider dependent).
     */
    @Test
    public void testInvalidCipher() throws Exception {
        SelfSignedCertificate cert = new SelfSignedCertificate();
        List<String> cipherList = new ArrayList<String>();
        Collections.addAll(cipherList, ((SSLSocketFactory) SSLSocketFactory.getDefault()).getDefaultCipherSuites());
        cipherList.add("InvalidCipher");
        SSLEngine server = null;
        try {
            serverSslCtx = wrapContext(null, SslContextBuilder.forServer(cert.key(), cert.cert())
                    .sslContextProvider(serverSslContextProvider())
                    .sslProvider(sslServerProvider())
                    .ciphers(cipherList).build());
            server = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            fail();
        } catch (IllegalArgumentException expected) {
            // expected when invalid cipher is used.
        } catch (SSLException expected) {
            // expected when invalid cipher is used.
        } finally {
            cert.delete();
            cleanupServerSslEngine(server);
        }
    }

    /**
     * After a handshake both sides must report the same negotiated cipher suite, and it must be the one
     * requested by the test parameter.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testGetCiphersuite(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .sslContextProvider(clientSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .sslContextProvider(serverSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SSLEngine clientEngine = null;
        SSLEngine serverEngine = null;
        try {
            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            handshake(param.type(), param.delegate(), clientEngine, serverEngine);

            String clientCipher = clientEngine.getSession().getCipherSuite();
            String serverCipher = serverEngine.getSession().getCipherSuite();
            assertEquals(clientCipher, serverCipher);

            assertEquals(param.protocolCipherCombo.cipher, clientCipher);
        } finally {
            cleanupClientSslEngine(clientEngine);
            cleanupServerSslEngine(serverEngine);
            ssc.delete();
        }
    }

    /**
     * Exercises session caching: a second handshake for the same host/port must reuse the cached session,
     * a different host must not; afterwards invalidating all sessions must empty both caches.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testSessionCache(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .sslContextProvider(clientSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .sslContextProvider(serverSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());

        try {
            doHandshakeVerifyReusedAndClose(param, "a.netty.io", 9999, false);
            doHandshakeVerifyReusedAndClose(param, "a.netty.io", 9999, true);
            doHandshakeVerifyReusedAndClose(param, "b.netty.io", 9999, false);
            invalidateSessionsAndAssert(serverSslCtx.sessionContext());
            invalidateSessionsAndAssert(clientSslCtx.sessionContext());
        } finally {
            ssc.delete();
        }
    }

    /**
     * Invalidates every session in the given context and asserts each one becomes invalid and is no longer
     * retrievable by its id.
     */
    protected void invalidateSessionsAndAssert(SSLSessionContext context) {
        Enumeration<byte[]> ids = context.getIds();
        while (ids.hasMoreElements()) {
            byte[] id = ids.nextElement();
            SSLSession session = context.getSession(id);
            if (session != null) {
                session.invalidate();
                assertFalse(session.isValid());
                assertNull(context.getSession(id));
            }
        }
    }

    // Asserts the context holds exactly numSessions sessions, each with a non-empty, self-consistent id.
    private static void assertSessionCache(SSLSessionContext sessionContext, int numSessions) {
        Enumeration<byte[]> ids = sessionContext.getIds();
        int numIds = 0;
        while (ids.hasMoreElements()) {
            numIds++;
            byte[] id = ids.nextElement();
            assertNotEquals(0, id.length);
            SSLSession session = sessionContext.getSession(id);
            assertArrayEquals(id, session.getId());
        }
        assertEquals(numSessions, numIds);
    }

    /**
     * Performs a full handshake for {@code host:port} and verifies (via assertSessionReusedForEngine) whether
     * the session was or was not reused as expected, then closes both engines cleanly. For TLSv1.3 the
     * session ticket arrives after the handshake, so application data is pumped in both directions until the
     * cache state reflects the expectation.
     */
    private void doHandshakeVerifyReusedAndClose(SSLEngineTestParam param, String host, int port, boolean reuse)
            throws Exception {
        SSLEngine clientEngine = null;
        SSLEngine serverEngine = null;
        try {
            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT, host, port));
            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
            handshake(param.type(), param.delegate(), clientEngine, serverEngine);
            int clientSessions = currentSessionCacheSize(clientSslCtx.sessionContext());
            int serverSessions = currentSessionCacheSize(serverSslCtx.sessionContext());
            int nCSessions = clientSessions;
            int nSSessions = serverSessions;
            boolean clientSessionReused = false;
            boolean serverSessionReused = false;
            if (param.protocolCipherCombo == ProtocolCipherCombo.TLSV13) {
                // Allocate something which is big enough for sure
                ByteBuffer packetBuffer = allocateBuffer(param.type(), 32 * 1024);
                ByteBuffer appBuffer = allocateBuffer(param.type(), 32 * 1024);

                // 4 bytes of application data, sent back and forth until the session state settles.
                appBuffer.clear().position(4).flip();
                packetBuffer.clear();

                do {
                    SSLEngineResult result;

                    do {
                        result = serverEngine.wrap(appBuffer, packetBuffer);
                    } while (appBuffer.hasRemaining() || result.bytesProduced() > 0);

                    appBuffer.clear();
                    packetBuffer.flip();
                    do {
                        result = clientEngine.unwrap(packetBuffer, appBuffer);
                    } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0);

                    packetBuffer.clear();
                    appBuffer.clear().position(4).flip();

                    do {
                        result = clientEngine.wrap(appBuffer, packetBuffer);
                    } while (appBuffer.hasRemaining() || result.bytesProduced() > 0);

                    appBuffer.clear();
                    packetBuffer.flip();

                    do {
                        result = serverEngine.unwrap(packetBuffer, appBuffer);
                    } while (packetBuffer.hasRemaining() || result.bytesProduced() > 0);

                    packetBuffer.clear();
                    appBuffer.clear().position(4).flip();
                    nCSessions = currentSessionCacheSize(clientSslCtx.sessionContext());
                    nSSessions = currentSessionCacheSize(serverSslCtx.sessionContext());
                    clientSessionReused = isSessionMaybeReused(clientEngine);
                    serverSessionReused = isSessionMaybeReused(serverEngine);
                } while ((reuse && (!clientSessionReused || !serverSessionReused))
                        || (!reuse && (nCSessions < clientSessions ||
                        // server may use multiple sessions
                        nSSessions < serverSessions)));
            }

            assertSessionReusedForEngine(clientEngine, serverEngine, reuse);

            closeOutboundAndInbound(param.type(), clientEngine, serverEngine);
        } finally {
            cleanupClientSslEngine(clientEngine);
            cleanupServerSslEngine(serverEngine);
        }
    }

    // Overridden by provider-specific subclasses that can actually detect reuse; base class is permissive.
    protected boolean isSessionMaybeReused(SSLEngine engine) {
        return true;
    }

    // Counts the sessions currently held by the given context by walking its id enumeration.
    private static int currentSessionCacheSize(SSLSessionContext ctx) {
        Enumeration<byte[]> ids = ctx.getIds();
        int i = 0;
        while (ids.hasMoreElements()) {
            i++;
            ids.nextElement();
        }
        return i;
    }

    /**
     * Performs an orderly bidirectional TLS shutdown: close each outbound side, drain the close_notify
     * records through wrap/unwrap on the peer, then close both inbound sides.
     */
    private void closeOutboundAndInbound(
            BufferType type, SSLEngine clientEngine, SSLEngine serverEngine) throws SSLException {
        assertFalse(clientEngine.isInboundDone());
        assertFalse(clientEngine.isOutboundDone());
        assertFalse(serverEngine.isInboundDone());
        assertFalse(serverEngine.isOutboundDone());

        ByteBuffer empty = allocateBuffer(type, 0);

        // Ensure we allocate a bit more so we can fit in multiple packets. This is needed as we may call multiple
        // time wrap / unwrap in a for loop before we drain the buffer we are writing in.
        ByteBuffer cTOs = allocateBuffer(type, clientEngine.getSession().getPacketBufferSize() * 4);
        ByteBuffer sTOs = allocateBuffer(type, serverEngine.getSession().getPacketBufferSize() * 4);
        ByteBuffer cApps = allocateBuffer(type, clientEngine.getSession().getApplicationBufferSize() * 4);
        ByteBuffer sApps = allocateBuffer(type, serverEngine.getSession().getApplicationBufferSize() * 4);

        clientEngine.closeOutbound();
        for (;;) {
            // call wrap till we produced all data
            SSLEngineResult result = clientEngine.wrap(empty, cTOs);
            if (result.getStatus() == Status.CLOSED && result.bytesProduced() == 0) {
                break;
            }
            assertTrue(cTOs.hasRemaining());
        }
        cTOs.flip();

        for (;;) {
            // call unwrap till we consumed all data
            SSLEngineResult result = serverEngine.unwrap(cTOs, sApps);
            if (result.getStatus() == Status.CLOSED && result.bytesProduced() == 0) {
                break;
            }
            assertTrue(sApps.hasRemaining());
        }

        serverEngine.closeOutbound();
        for (;;) {
            // call wrap till we produced all data
            SSLEngineResult result = serverEngine.wrap(empty, sTOs);
            if (result.getStatus() == Status.CLOSED && result.bytesProduced() == 0) {
                break;
            }
            assertTrue(sTOs.hasRemaining());
        }
        sTOs.flip();

        for (;;) {
            // call unwrap till we consumed all data
            SSLEngineResult result = clientEngine.unwrap(sTOs, cApps);
            if (result.getStatus() == Status.CLOSED && result.bytesProduced() == 0) {
                break;
            }
            assertTrue(cApps.hasRemaining());
        }

        // Now close the inbound as well
        clientEngine.closeInbound();
        serverEngine.closeInbound();
    }

    // Hook for provider-specific subclasses to assert whether the session was reused; base class checks nothing.
    protected void assertSessionReusedForEngine(SSLEngine clientEngine, SSLEngine serverEngine, boolean reuse) {
        // NOOP
    }

    /**
     * With a 1-second session timeout, both caches must be empty after sleeping past the timeout.
     * NOTE(review): relies on wall-clock sleep (1500 ms) — may be flaky on a heavily loaded machine.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testSessionCacheTimeout(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .sslContextProvider(clientSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .sessionTimeout(1)
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .sslContextProvider(serverSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .sessionTimeout(1)
                .build());

        try {
            doHandshakeVerifyReusedAndClose(param, "a.netty.io", 9999, false);

            // Let's sleep for a bit more then 1 second so the cache should timeout the sessions.
            Thread.sleep(1500);

            assertSessionCache(serverSslCtx.sessionContext(), 0);
            assertSessionCache(clientSslCtx.sessionContext(), 0);
        } finally {
            ssc.delete();
        }
    }

    /**
     * With a client session cache capped at one entry, a handshake for a second host must evict the first,
     * and the most recent host ("b.netty.io") must still be reusable.
     */
    @MethodSource("newTestParams")
    @ParameterizedTest
    public void testSessionCacheSize(SSLEngineTestParam param) throws Exception {
        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
                .trustManager(InsecureTrustManagerFactory.INSTANCE)
                .sslProvider(sslClientProvider())
                .sslContextProvider(clientSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .sessionCacheSize(1)
                .build());
        SelfSignedCertificate ssc = new SelfSignedCertificate();
        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
                .sslProvider(sslServerProvider())
                .sslContextProvider(serverSslContextProvider())
                .protocols(param.protocols())
                .ciphers(param.ciphers())
                .build());

        try {
            doHandshakeVerifyReusedAndClose(param, "a.netty.io", 9999, false);
            // As we have a cache size of 1 we should never have more then one session in the cache
            doHandshakeVerifyReusedAndClose(param, "b.netty.io", 9999, false);

            // We should at least reuse b.netty.io
            doHandshakeVerifyReusedAndClose(param, "b.netty.io", 9999, true);
        } finally {
            ssc.delete();
        }
    }

+    /**
+     * Verifies SSLSession value-binding callbacks: bound/unbound events must fire with the
+     * correct name, session and source when values are put, replaced and removed.
+     */
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    public void testSessionBindingEvent(SSLEngineTestParam param) throws Exception {
+        clientSslCtx = wrapContext(param, SslContextBuilder.forClient()
+                .trustManager(InsecureTrustManagerFactory.INSTANCE)
+                .sslProvider(sslClientProvider())
+                .sslContextProvider(clientSslContextProvider())
+                .protocols(param.protocols())
+                .ciphers(param.ciphers())
+                .build());
+        SelfSignedCertificate ssc = new SelfSignedCertificate();
+        serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey())
+                .sslProvider(sslServerProvider())
+                .sslContextProvider(serverSslContextProvider())
+                .protocols(param.protocols())
+                .ciphers(param.ciphers())
+                .build());
+        SSLEngine clientEngine = null;
+        SSLEngine serverEngine = null;
+        try {
+            clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
+            serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT));
+            handshake(param.type(), param.delegate(), clientEngine, serverEngine);
+            SSLSession session = clientEngine.getSession();
+            assertEquals(0, session.getValueNames().length);
+
+            // Listener value that records the first bound/unbound event it receives and
+            // asserts each callback fires at most once.
+            class SSLSessionBindingEventValue implements SSLSessionBindingListener {
+                SSLSessionBindingEvent boundEvent;
+                SSLSessionBindingEvent unboundEvent;
+
+                @Override
+                public void valueBound(SSLSessionBindingEvent sslSessionBindingEvent) {
+                    assertNull(boundEvent);
+                    boundEvent = sslSessionBindingEvent;
+                }
+
+                @Override
+                public void valueUnbound(SSLSessionBindingEvent sslSessionBindingEvent) {
+                    assertNull(unboundEvent);
+                    unboundEvent = sslSessionBindingEvent;
+                }
+            }
+
+            String name = "name";
+            String name2 = "name2";
+
+            SSLSessionBindingEventValue value1 = new SSLSessionBindingEventValue();
+            session.putValue(name, value1);
+            assertSSLSessionBindingEventValue(name, session, value1.boundEvent);
+            assertNull(value1.unboundEvent);
+            assertEquals(1, session.getValueNames().length);
+
+            // A plain (non-listener) value under a second name must not trigger callbacks.
+            session.putValue(name2, "value");
+
+            // Replacing the value under the same name must unbind value1 and bind value2.
+            SSLSessionBindingEventValue value2 = new SSLSessionBindingEventValue();
+            session.putValue(name, value2);
+            assertEquals(2, session.getValueNames().length);
+
+            assertSSLSessionBindingEventValue(name, session, value1.unboundEvent);
+            assertSSLSessionBindingEventValue(name, session, value2.boundEvent);
+            assertNull(value2.unboundEvent);
+            assertEquals(2, session.getValueNames().length);
+
+            session.removeValue(name);
+            assertSSLSessionBindingEventValue(name, session, value2.unboundEvent);
+            assertEquals(1, session.getValueNames().length);
+            session.removeValue(name2);
+        } finally {
+            cleanupClientSslEngine(clientEngine);
+            cleanupServerSslEngine(serverEngine);
+            ssc.delete();
+        }
+    }
+
+    // Asserts the binding event carries the expected value name and originates from the session.
+    private static void assertSSLSessionBindingEventValue(
+            String name, SSLSession session, SSLSessionBindingEvent event) {
+        assertEquals(name, event.getName());
+        assertEquals(session, event.getSession());
+        assertEquals(session, event.getSource());
+    }
+
+    // The four variants below drive testSessionAfterHandshake0 over the
+    // {KeyManagerFactory, mutual-auth} matrix.
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    public void testSessionAfterHandshake(SSLEngineTestParam param) throws Exception {
+        testSessionAfterHandshake0(param, false, false);
+    }
+
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    public void testSessionAfterHandshakeMutualAuth(SSLEngineTestParam param) throws Exception {
+        testSessionAfterHandshake0(param, false, true);
+    }
+
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    public void testSessionAfterHandshakeKeyManagerFactory(SSLEngineTestParam param) throws Exception {
+        testSessionAfterHandshake0(param, true, false);
+    }
+
+    @MethodSource("newTestParams")
+    @ParameterizedTest
+    public void testSessionAfterHandshakeKeyManagerFactoryMutualAuth(SSLEngineTestParam param) throws Exception {
+        testSessionAfterHandshake0(param, true, true);
+    }
+
+    private void testSessionAfterHandshake0(
+            SSLEngineTestParam param, boolean useKeyManagerFactory, boolean mutualAuth) throws Exception {
+        SelfSignedCertificate ssc = new SelfSignedCertificate();
+
KeyManagerFactory kmf = useKeyManagerFactory ? + SslContext.buildKeyManagerFactory( + new java.security.cert.X509Certificate[] { ssc.cert()}, null, + ssc.key(), null, null, null) : null; + + SslContextBuilder clientContextBuilder = SslContextBuilder.forClient(); + if (mutualAuth) { + if (kmf != null) { + clientContextBuilder.keyManager(kmf); + } else { + clientContextBuilder.keyManager(ssc.key(), ssc.cert()); + } + } + clientSslCtx = wrapContext(param, clientContextBuilder + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + SslContextBuilder serverContextBuilder = kmf != null ? + SslContextBuilder.forServer(kmf) : + SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()); + if (mutualAuth) { + serverContextBuilder.clientAuth(ClientAuth.REQUIRE); + } + serverSslCtx = wrapContext(param, serverContextBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + + SSLSession clientSession = clientEngine.getSession(); + SSLSession serverSession = serverEngine.getSession(); + + assertNull(clientSession.getPeerHost()); + assertNull(serverSession.getPeerHost()); + assertEquals(-1, clientSession.getPeerPort()); + assertEquals(-1, serverSession.getPeerPort()); + + assertTrue(clientSession.getCreationTime() > 0); + assertTrue(serverSession.getCreationTime() > 0); + + assertTrue(clientSession.getLastAccessedTime() > 0); + 
assertTrue(serverSession.getLastAccessedTime() > 0); + + assertEquals(param.combo().protocol, clientSession.getProtocol()); + assertEquals(param.combo().protocol, serverSession.getProtocol()); + + assertEquals(param.combo().cipher, clientSession.getCipherSuite()); + assertEquals(param.combo().cipher, serverSession.getCipherSuite()); + + assertNotNull(clientSession.getId()); + assertNotNull(serverSession.getId()); + + assertTrue(clientSession.getApplicationBufferSize() > 0); + assertTrue(serverSession.getApplicationBufferSize() > 0); + + assertTrue(clientSession.getPacketBufferSize() > 0); + assertTrue(serverSession.getPacketBufferSize() > 0); + + assertNotNull(clientSession.getSessionContext()); + + // Workaround for JDK 14 regression. + // See https://bugs.openjdk.java.net/browse/JDK-8242008 + if (PlatformDependent.javaVersion() != 14) { + assertNotNull(serverSession.getSessionContext()); + } + + Object value = new Object(); + + assertEquals(0, clientSession.getValueNames().length); + clientSession.putValue("test", value); + assertEquals("test", clientSession.getValueNames()[0]); + assertSame(value, clientSession.getValue("test")); + clientSession.removeValue("test"); + assertEquals(0, clientSession.getValueNames().length); + + assertEquals(0, serverSession.getValueNames().length); + serverSession.putValue("test", value); + assertEquals("test", serverSession.getValueNames()[0]); + assertSame(value, serverSession.getValue("test")); + serverSession.removeValue("test"); + assertEquals(0, serverSession.getValueNames().length); + + Certificate[] serverLocalCertificates = serverSession.getLocalCertificates(); + assertEquals(1, serverLocalCertificates.length); + assertArrayEquals(ssc.cert().getEncoded(), serverLocalCertificates[0].getEncoded()); + + Principal serverLocalPrincipal = serverSession.getLocalPrincipal(); + assertNotNull(serverLocalPrincipal); + + if (mutualAuth) { + Certificate[] clientLocalCertificates = clientSession.getLocalCertificates(); + 
assertEquals(1, clientLocalCertificates.length); + + Certificate[] serverPeerCertificates = serverSession.getPeerCertificates(); + assertEquals(1, serverPeerCertificates.length); + assertArrayEquals(clientLocalCertificates[0].getEncoded(), serverPeerCertificates[0].getEncoded()); + + try { + X509Certificate[] serverPeerX509Certificates = serverSession.getPeerCertificateChain(); + assertEquals(1, serverPeerX509Certificates.length); + assertArrayEquals(clientLocalCertificates[0].getEncoded(), + serverPeerX509Certificates[0].getEncoded()); + } catch (UnsupportedOperationException e) { + // See https://bugs.openjdk.java.net/browse/JDK-8241039 + assertTrue(PlatformDependent.javaVersion() >= 15); + } + + Principal clientLocalPrincipial = clientSession.getLocalPrincipal(); + assertNotNull(clientLocalPrincipial); + + Principal serverPeerPrincipal = serverSession.getPeerPrincipal(); + assertEquals(clientLocalPrincipial, serverPeerPrincipal); + } else { + assertNull(clientSession.getLocalCertificates()); + assertNull(clientSession.getLocalPrincipal()); + + try { + serverSession.getPeerCertificates(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + + try { + serverSession.getPeerCertificateChain(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } catch (UnsupportedOperationException e) { + // See https://bugs.openjdk.java.net/browse/JDK-8241039 + assertTrue(PlatformDependent.javaVersion() >= 15); + } + + try { + serverSession.getPeerPrincipal(); + fail(); + } catch (SSLPeerUnverifiedException expected) { + // As we did not use mutual auth this is expected + } + } + + Certificate[] clientPeerCertificates = clientSession.getPeerCertificates(); + assertEquals(1, clientPeerCertificates.length); + assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerCertificates[0].getEncoded()); + + try { + X509Certificate[] 
clientPeerX509Certificates = clientSession.getPeerCertificateChain(); + assertEquals(1, clientPeerX509Certificates.length); + assertArrayEquals(serverLocalCertificates[0].getEncoded(), clientPeerX509Certificates[0].getEncoded()); + } catch (UnsupportedOperationException e) { + // See https://bugs.openjdk.java.net/browse/JDK-8241039 + assertTrue(PlatformDependent.javaVersion() >= 15); + } + Principal clientPeerPrincipal = clientSession.getPeerPrincipal(); + assertEquals(serverLocalPrincipal, clientPeerPrincipal); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSupportedSignatureAlgorithms(SSLEngineTestParam param) throws Exception { + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final class TestKeyManagerFactory extends KeyManagerFactory { + TestKeyManagerFactory(final KeyManagerFactory factory) { + super(new KeyManagerFactorySpi() { + + private final KeyManager[] managers = factory.getKeyManagers(); + + @Override + protected void engineInit(KeyStore keyStore, char[] chars) { + throw new UnsupportedOperationException(); + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + throw new UnsupportedOperationException(); + } + + @Override + protected KeyManager[] engineGetKeyManagers() { + KeyManager[] array = new KeyManager[managers.length]; + + for (int i = 0 ; i < array.length; i++) { + final X509ExtendedKeyManager x509ExtendedKeyManager = (X509ExtendedKeyManager) managers[i]; + + array[i] = new X509ExtendedKeyManager() { + @Override + public String[] getClientAliases(String s, Principal[] principals) { + fail(); + return null; + } + + @Override + public String chooseClientAlias( + String[] strings, Principal[] principals, Socket socket) { + fail(); + return null; + } + + @Override + public String[] getServerAliases(String s, Principal[] principals) { + 
fail(); + return null; + } + + @Override + public String chooseServerAlias(String s, Principal[] principals, Socket socket) { + fail(); + return null; + } + + @Override + public String chooseEngineClientAlias( + String[] strings, Principal[] principals, SSLEngine sslEngine) { + assertNotEquals(0, ((ExtendedSSLSession) sslEngine.getHandshakeSession()) + .getPeerSupportedSignatureAlgorithms().length); + assertNotEquals(0, ((ExtendedSSLSession) sslEngine.getHandshakeSession()) + .getLocalSupportedSignatureAlgorithms().length); + return x509ExtendedKeyManager.chooseEngineClientAlias( + strings, principals, sslEngine); + } + + @Override + public String chooseEngineServerAlias( + String s, Principal[] principals, SSLEngine sslEngine) { + assertNotEquals(0, ((ExtendedSSLSession) sslEngine.getHandshakeSession()) + .getPeerSupportedSignatureAlgorithms().length); + assertNotEquals(0, ((ExtendedSSLSession) sslEngine.getHandshakeSession()) + .getLocalSupportedSignatureAlgorithms().length); + return x509ExtendedKeyManager.chooseEngineServerAlias(s, principals, sslEngine); + } + + @Override + public java.security.cert.X509Certificate[] getCertificateChain(String s) { + return x509ExtendedKeyManager.getCertificateChain(s); + } + + @Override + public PrivateKey getPrivateKey(String s) { + return x509ExtendedKeyManager.getPrivateKey(s); + } + }; + } + return array; + } + }, factory.getProvider(), factory.getAlgorithm()); + } + } + + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .keyManager(new TestKeyManagerFactory(newKeyManagerFactory(ssc))) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + + serverSslCtx = wrapContext(param, SslContextBuilder.forServer( + new TestKeyManagerFactory(newKeyManagerFactory(ssc))) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + 
.sslContextProvider(serverSslContextProvider()) + .sslProvider(sslServerProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .clientAuth(ClientAuth.REQUIRE) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testHandshakeSession(SSLEngineTestParam param) throws Exception { + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final TestTrustManagerFactory clientTmf = new TestTrustManagerFactory(ssc.cert()); + final TestTrustManagerFactory serverTmf = new TestTrustManagerFactory(ssc.cert()); + + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(new SimpleTrustManagerFactory() { + @Override + protected void engineInit(KeyStore keyStore) { + // NOOP + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + // NOOP + } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { clientTmf }; + } + }) + .keyManager(newKeyManagerFactory(ssc)) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(newKeyManagerFactory(ssc)) + .trustManager(new SimpleTrustManagerFactory() { + @Override + protected void engineInit(KeyStore keyStore) { + // NOOP + } + + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { + // NOOP + } + + @Override + protected TrustManager[] 
engineGetTrustManagers() { + return new TrustManager[] { serverTmf }; + } + }) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .clientAuth(ClientAuth.REQUIRE) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + + assertTrue(clientTmf.isVerified()); + assertTrue(serverTmf.isVerified()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSessionLocalWhenNonMutualWithKeyManager(SSLEngineTestParam param) throws Exception { + testSessionLocalWhenNonMutual(param, true); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testSessionLocalWhenNonMutualWithoutKeyManager(SSLEngineTestParam param) throws Exception { + testSessionLocalWhenNonMutual(param, false); + } + + private void testSessionLocalWhenNonMutual(SSLEngineTestParam param, boolean useKeyManager) throws Exception { + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + + SslContextBuilder clientSslCtxBuilder = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()); + + if (useKeyManager) { + clientSslCtxBuilder.keyManager(newKeyManagerFactory(ssc)); + } else { + clientSslCtxBuilder.keyManager(ssc.certificate(), ssc.privateKey()); + } + clientSslCtx = wrapContext(param, clientSslCtxBuilder.build()); + + final SslContextBuilder serverSslCtxBuilder; + if (useKeyManager) { + serverSslCtxBuilder = 
SslContextBuilder.forServer(newKeyManagerFactory(ssc)); + } else { + serverSslCtxBuilder = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()); + } + serverSslCtxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .clientAuth(ClientAuth.NONE); + + serverSslCtx = wrapContext(param, serverSslCtxBuilder.build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + + SSLSession clientSession = clientEngine.getSession(); + assertNull(clientSession.getLocalCertificates()); + assertNull(clientSession.getLocalPrincipal()); + + SSLSession serverSession = serverEngine.getSession(); + assertNotNull(serverSession.getLocalCertificates()); + assertNotNull(serverSession.getLocalPrincipal()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testEnabledProtocolsAndCiphers(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine 
clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + assertEnabledProtocolsAndCipherSuites(clientEngine); + assertEnabledProtocolsAndCipherSuites(serverEngine); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + } + + private static void assertEnabledProtocolsAndCipherSuites(SSLEngine engine) { + String protocol = engine.getSession().getProtocol(); + String cipherSuite = engine.getSession().getCipherSuite(); + + assertArrayContains(protocol, engine.getEnabledProtocols()); + assertArrayContains(cipherSuite, engine.getEnabledCipherSuites()); + } + + private static void assertArrayContains(String expected, String[] array) { + for (String value: array) { + if (expected.equals(value)) { + return; + } + } + fail("Array did not contain '" + expected + "':" + Arrays.toString(array)); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testMasterKeyLogging(final SSLEngineTestParam param) throws Exception { + if (param.combo() != ProtocolCipherCombo.tlsv12()) { + return; + } + /* + * At the moment master key logging is not supported for conscrypt + */ + assumeFalse(serverSslContextProvider() instanceof OpenSSLProvider); + + /* + * The JDK SSL engine master key retrieval relies on being able to set field access to true. 
+ * That is not available in JDK9+ + */ + assumeFalse(sslServerProvider() == SslProvider.JDK && PlatformDependent.javaVersion() > 8); + + String originalSystemPropertyValue = SystemPropertyUtil.get(SslMasterKeyHandler.SYSTEM_PROP_KEY); + System.setProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY, Boolean.TRUE.toString()); + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + Socket socket = null; + + try { + sb = new ServerBootstrap(); + sb.group(new NioEventLoopGroup(), new NioEventLoopGroup()); + sb.channel(NioServerSocketChannel.class); + + final Promise promise = sb.config().group().next().newPromise(); + serverChannel = sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.config().setAllocator(new TestByteBufAllocator(ch.config().getAllocator(), param.type())); + + SslHandler sslHandler = !param.delegate() ? 
+ serverSslCtx.newHandler(ch.alloc()) : + serverSslCtx.newHandler(ch.alloc(), delegatingExecutor); + + ch.pipeline().addLast(sslHandler); + ch.pipeline().addLast(new SslMasterKeyHandler() { + @Override + protected void accept(SecretKey masterKey, SSLSession session) { + promise.setSuccess(masterKey); + } + }); + serverConnectedChannel = ch; + } + }).bind(new InetSocketAddress(0)).sync().channel(); + + int port = ((InetSocketAddress) serverChannel.localAddress()).getPort(); + + SSLContext sslContext = SSLContext.getInstance("TLS"); + sslContext.init(null, InsecureTrustManagerFactory.INSTANCE.getTrustManagers(), null); + socket = sslContext.getSocketFactory().createSocket(NetUtil.LOCALHOST, port); + OutputStream out = socket.getOutputStream(); + out.write(1); + out.flush(); + + assertTrue(promise.await(10, TimeUnit.SECONDS)); + SecretKey key = promise.get(); + assertEquals(48, key.getEncoded().length, "AES secret key must be 48 bytes"); + } finally { + closeQuietly(socket); + if (originalSystemPropertyValue != null) { + System.setProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY, originalSystemPropertyValue); + } else { + System.clearProperty(SslMasterKeyHandler.SYSTEM_PROP_KEY); + } + ssc.delete(); + } + } + + private static void closeQuietly(Closeable c) { + if (c != null) { + try { + c.close(); + } catch (IOException ignore) { + // ignore + } + } + } + + private static KeyManagerFactory newKeyManagerFactory(SelfSignedCertificate ssc) + throws UnrecoverableKeyException, KeyStoreException, NoSuchAlgorithmException, + CertificateException, IOException { + return SslContext.buildKeyManagerFactory( + new java.security.cert.X509Certificate[] { ssc.cert() }, null, ssc.key(), null, null, null); + } + + private static final class TestTrustManagerFactory extends X509ExtendedTrustManager { + private final Certificate localCert; + private volatile boolean verified; + + TestTrustManagerFactory(Certificate localCert) { + this.localCert = localCert; + } + + boolean isVerified() { + 
return verified; + } + + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, Socket socket) { + fail(); + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, Socket socket) { + fail(); + } + + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) { + verified = true; + assertFalse(sslEngine.getUseClientMode()); + SSLSession session = sslEngine.getHandshakeSession(); + assertNotNull(session); + Certificate[] localCertificates = session.getLocalCertificates(); + assertNotNull(localCertificates); + assertEquals(1, localCertificates.length); + assertEquals(localCert, localCertificates[0]); + assertNotNull(session.getLocalPrincipal()); + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s, SSLEngine sslEngine) { + verified = true; + assertTrue(sslEngine.getUseClientMode()); + SSLSession session = sslEngine.getHandshakeSession(); + assertNotNull(session); + assertNull(session.getLocalCertificates()); + assertNull(session.getLocalPrincipal()); + } + + @Override + public void checkClientTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s) { + fail(); + } + + @Override + public void checkServerTrusted( + java.security.cert.X509Certificate[] x509Certificates, String s) { + fail(); + } + + @Override + public java.security.cert.X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testDefaultProtocolsIncludeTLSv13(SSLEngineTestParam param) throws Exception { + // Don't specify the protocols as we want to test the default selection + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) 
+ .sslContextProvider(clientSslContextProvider()) + .ciphers(param.ciphers()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .ciphers(param.ciphers()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + String[] clientProtocols; + String[] serverProtocols; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + clientProtocols = clientEngine.getEnabledProtocols(); + serverProtocols = serverEngine.getEnabledProtocols(); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + ssc.delete(); + } + + assertEquals(SslProvider.isTlsv13EnabledByDefault(sslClientProvider(), clientSslContextProvider()), + arrayContains(clientProtocols, SslProtocols.TLS_v1_3)); + assertEquals(SslProvider.isTlsv13EnabledByDefault(sslServerProvider(), serverSslContextProvider()), + arrayContains(serverProtocols, SslProtocols.TLS_v1_3)); + } + + // IMPORTANT: If this test fails, try rerunning the 'generate-certificate.sh' script. 
+ @MethodSource("newTestParams") + @ParameterizedTest + public void testRSASSAPSS(SSLEngineTestParam param) throws Exception { + char[] password = "password".toCharArray(); + + final KeyStore serverKeyStore = KeyStore.getInstance("PKCS12"); + serverKeyStore.load(getClass().getResourceAsStream("rsaValidations-server-keystore.p12"), password); + + final KeyStore clientKeyStore = KeyStore.getInstance("PKCS12"); + clientKeyStore.load(getClass().getResourceAsStream("rsaValidation-user-certs.p12"), password); + + final KeyManagerFactory serverKeyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + serverKeyManagerFactory.init(serverKeyStore, password); + final KeyManagerFactory clientKeyManagerFactory = + KeyManagerFactory.getInstance(KeyManagerFactory.getDefaultAlgorithm()); + clientKeyManagerFactory.init(clientKeyStore, password); + + File commonChain = ResourcesUtil.getFile(getClass(), "rsapss-ca-cert.cert"); + ClientAuth auth = ClientAuth.REQUIRE; + + mySetupMutualAuth(param, serverKeyManagerFactory, commonChain, clientKeyManagerFactory, commonChain, + auth, false, true); + + assertTrue(clientLatch.await(10, TimeUnit.SECONDS)); + rethrowIfNotNull(clientException); + assertTrue(serverLatch.await(5, TimeUnit.SECONDS)); + rethrowIfNotNull(serverException); + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testInvalidSNIIsIgnoredAndNotThrow(SSLEngineTestParam param) throws Exception { + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + 
.protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT, "/invalid.path", 80)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + assertNotNull(clientEngine.getSSLParameters()); + assertNotNull(serverEngine.getSSLParameters()); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + @MethodSource("newTestParams") + @ParameterizedTest + public void testBufferUnderflowPacketSizeDependency(SSLEngineTestParam param) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .keyManager(ssc.certificate(), ssc.privateKey()) + .trustManager((TrustManagerFactory) null) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .clientAuth(ClientAuth.REQUIRE) + .build()); + SSLEngine clientEngine = null; + SSLEngine serverEngine = null; + try { + clientEngine = wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + serverEngine = wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)); + + handshake(param.type(), param.delegate(), clientEngine, serverEngine); + } catch (SSLHandshakeException expected) { + // Expected + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + @Test + public void testExtraDataInLastSrcBufferForClientUnwrap() 
throws Exception { + SSLEngineTestParam param = new SSLEngineTestParam(BufferType.Direct, ProtocolCipherCombo.tlsv12(), false); + SelfSignedCertificate ssc = new SelfSignedCertificate(); + clientSslCtx = wrapContext(param, SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(sslClientProvider()) + .sslContextProvider(clientSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .build()); + serverSslCtx = wrapContext(param, SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslServerProvider()) + .sslContextProvider(serverSslContextProvider()) + .protocols(param.protocols()) + .ciphers(param.ciphers()) + .clientAuth(ClientAuth.NONE) + .build()); + testExtraDataInLastSrcBufferForClientUnwrap(param, + wrapEngine(clientSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT)), + wrapEngine(serverSslCtx.newEngine(UnpooledByteBufAllocator.DEFAULT))); + } + + protected final void testExtraDataInLastSrcBufferForClientUnwrap( + SSLEngineTestParam param, SSLEngine clientEngine, SSLEngine serverEngine) throws Exception { + try { + ByteBuffer cTOs = allocateBuffer(param.type(), clientEngine.getSession().getPacketBufferSize()); + // Ensure we can fit two records as we want to include two records once the handshake completes on the + // server side. 
+ ByteBuffer sTOc = allocateBuffer(param.type(), serverEngine.getSession().getPacketBufferSize() * 2); + + ByteBuffer serverAppReadBuffer = + allocateBuffer(param.type(), serverEngine.getSession().getApplicationBufferSize()); + ByteBuffer clientAppReadBuffer = + allocateBuffer(param.type(), clientEngine.getSession().getApplicationBufferSize()); + + ByteBuffer empty = allocateBuffer(param.type(), 0); + + SSLEngineResult clientResult; + SSLEngineResult serverResult; + + boolean clientHandshakeFinished = false; + boolean serverHandshakeFinished = false; + + do { + int cTOsPos = cTOs.position(); + int sTOcPos = sTOc.position(); + + if (!clientHandshakeFinished) { + clientResult = clientEngine.wrap(empty, cTOs); + runDelegatedTasks(param.delegate(), clientResult, clientEngine); + assertEquals(empty.remaining(), clientResult.bytesConsumed()); + assertEquals(cTOs.position() - cTOsPos, clientResult.bytesProduced()); + + if (isHandshakeFinished(clientResult)) { + clientHandshakeFinished = true; + } + + if (clientResult.getStatus() == Status.BUFFER_OVERFLOW) { + cTOs = increaseDstBuffer(clientEngine.getSession().getPacketBufferSize(), param.type(), cTOs); + } + } + + if (!serverHandshakeFinished) { + serverResult = serverEngine.wrap(empty, sTOc); + runDelegatedTasks(param.delegate(), serverResult, serverEngine); + assertEquals(empty.remaining(), serverResult.bytesConsumed()); + assertEquals(sTOc.position() - sTOcPos, serverResult.bytesProduced()); + + if (isHandshakeFinished(serverResult)) { + serverHandshakeFinished = true; + // We finished the handshake on the server side, lets add another record to the sTOc buffer + // so we can test that we will not unwrap extra data before we actually consider the handshake + // complete on the client side as well. 
+ serverResult = serverEngine.wrap(ByteBuffer.wrap(new byte[8]), sTOc); + assertEquals(8, serverResult.bytesConsumed()); + } + + if (serverResult.getStatus() == Status.BUFFER_OVERFLOW) { + sTOc = increaseDstBuffer(serverEngine.getSession().getPacketBufferSize(), param.type(), sTOc); + } + } + + cTOs.flip(); + sTOc.flip(); + + cTOsPos = cTOs.position(); + sTOcPos = sTOc.position(); + + if (!clientHandshakeFinished) { + int clientAppReadBufferPos = clientAppReadBuffer.position(); + clientResult = clientEngine.unwrap(sTOc, clientAppReadBuffer); + + runDelegatedTasks(param.delegate(), clientResult, clientEngine); + assertEquals(sTOc.position() - sTOcPos, clientResult.bytesConsumed()); + assertEquals(clientAppReadBuffer.position() - clientAppReadBufferPos, clientResult.bytesProduced()); + assertEquals(0, clientAppReadBuffer.position()); + + if (isHandshakeFinished(clientResult)) { + clientHandshakeFinished = true; + } else { + assertEquals(0, clientAppReadBuffer.position() - clientAppReadBufferPos); + } + + if (clientResult.getStatus() == Status.BUFFER_OVERFLOW) { + clientAppReadBuffer = increaseDstBuffer( + clientEngine.getSession().getApplicationBufferSize(), + param.type(), clientAppReadBuffer); + } + } + + if (!serverHandshakeFinished) { + int serverAppReadBufferPos = serverAppReadBuffer.position(); + serverResult = serverEngine.unwrap(cTOs, serverAppReadBuffer); + runDelegatedTasks(param.delegate(), serverResult, serverEngine); + assertEquals(cTOs.position() - cTOsPos, serverResult.bytesConsumed()); + assertEquals(serverAppReadBuffer.position() - serverAppReadBufferPos, serverResult.bytesProduced()); + assertEquals(0, serverAppReadBuffer.position()); + + if (isHandshakeFinished(serverResult)) { + serverHandshakeFinished = true; + } + + if (serverResult.getStatus() == Status.BUFFER_OVERFLOW) { + serverAppReadBuffer = increaseDstBuffer( + serverEngine.getSession().getApplicationBufferSize(), + param.type(), serverAppReadBuffer); + } + } + + compactOrClear(cTOs); + 
compactOrClear(sTOc); + + serverAppReadBuffer.clear(); + clientAppReadBuffer.clear(); + + if (clientEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + clientHandshakeFinished = true; + } + + if (serverEngine.getHandshakeStatus() == SSLEngineResult.HandshakeStatus.NOT_HANDSHAKING) { + serverHandshakeFinished = true; + } + } while (!clientHandshakeFinished || !serverHandshakeFinished); + } finally { + cleanupClientSslEngine(clientEngine); + cleanupServerSslEngine(serverEngine); + } + } + + protected SSLEngine wrapEngine(SSLEngine engine) { + return engine; + } + + protected SslContext wrapContext(SSLEngineTestParam param, SslContext context) { + return context; + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java new file mode 100644 index 0000000..e0974c1 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SignatureAlgorithmConverterTest.java @@ -0,0 +1,59 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class SignatureAlgorithmConverterTest { + + @Test + public void testWithEncryption() { + assertEquals("SHA512withRSA", SignatureAlgorithmConverter.toJavaName("sha512WithRSAEncryption")); + } + + @Test + public void testWithDash() { + assertEquals("SHA256withECDSA", SignatureAlgorithmConverter.toJavaName("ecdsa-with-SHA256")); + } + + @Test + public void testWithUnderscore() { + assertEquals("SHA256withDSA", SignatureAlgorithmConverter.toJavaName("dsa_with_SHA256")); + } + + @Test + public void testBoringSSLOneUnderscore() { + assertEquals("SHA256withECDSA", SignatureAlgorithmConverter.toJavaName("ecdsa_sha256")); + } + + @Test + public void testBoringSSLPkcs1() { + assertEquals("SHA256withRSA", SignatureAlgorithmConverter.toJavaName("rsa_pkcs1_sha256")); + } + + @Test + public void testBoringSSLPSS() { + assertEquals("SHA256withRSA", SignatureAlgorithmConverter.toJavaName("rsa_pss_rsae_sha256")); + } + + @Test + public void testInvalid() { + assertNull(SignatureAlgorithmConverter.toJavaName("ThisIsSomethingInvalid")); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java new file mode 100644 index 0000000..fc1b134 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientJava8TestUtil.java @@ -0,0 +1,349 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.handler.ssl.util.SimpleTrustManagerFactory;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.EmptyArrays;
import io.netty.util.internal.ThrowableUtil;

import javax.net.ssl.ExtendedSSLSession;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.KeyManagerFactorySpi;
import javax.net.ssl.ManagerFactoryParameters;
import javax.net.ssl.SNIHostName;
import javax.net.ssl.SNIMatcher;
import javax.net.ssl.SNIServerName;
import javax.net.ssl.SSLEngine;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLParameters;
import javax.net.ssl.SSLSession;
import javax.net.ssl.TrustManager;
import javax.net.ssl.TrustManagerFactory;
import javax.net.ssl.X509ExtendedKeyManager;
import javax.net.ssl.X509ExtendedTrustManager;
import java.io.IOException;
import java.net.Socket;
import java.security.InvalidAlgorithmParameterException;
import java.security.KeyStore;
import java.security.KeyStoreException;
import java.security.NoSuchAlgorithmException;
import java.security.Principal;
import java.security.PrivateKey;
import java.security.UnrecoverableKeyException;
import java.security.cert.CertificateException;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertTrue;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * In extra class to be able to run tests with java7 without trying to load classes that not exists in java7.
 */
// NOTE(review): generic type arguments (e.g. ChannelInitializer<Channel>, Promise<Void>,
// List<SNIServerName>, List<KeyManager>) appear to have been stripped by the patch
// extraction; the raw types below still compile but should be restored against upstream
// netty — TODO confirm against the 4.1.104.Final sources.
final class SniClientJava8TestUtil {

    private SniClientJava8TestUtil() { }

    /**
     * Runs a full client/server TLS handshake over a LocalChannel pair with an SNIMatcher
     * installed on the server engine.
     *
     * @param sslClientProvider provider used to build the client SslContext
     * @param sslServerProvider provider used to build the server SslContext
     * @param match whether the server's SNIMatcher accepts the client's SNI name; when
     *              {@code false} the handshake is expected to fail with an SSLException
     * @throws Exception on setup failure, or (via the promise) when the handshake outcome
     *                   does not match the {@code match} expectation
     */
    static void testSniClient(SslProvider sslClientProvider, SslProvider sslServerProvider, final boolean match)
            throws Exception {
        final String sniHost = "sni.netty.io";
        SelfSignedCertificate cert = new SelfSignedCertificate();
        LocalAddress address = new LocalAddress("test");
        EventLoopGroup group = new DefaultEventLoopGroup(1);
        SslContext sslServerContext = null;
        SslContext sslClientContext = null;

        Channel sc = null;
        Channel cc = null;
        try {
            sslServerContext = SslContextBuilder.forServer(cert.key(), cert.cert())
                    .sslProvider(sslServerProvider).build();
            // Completed by the server-side handshake listener below: success when the
            // observed outcome matches the 'match' expectation, failure otherwise.
            final Promise promise = group.next().newPromise();
            ServerBootstrap sb = new ServerBootstrap();

            final SslContext finalContext = sslServerContext;
            sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    SslHandler handler = finalContext.newHandler(ch.alloc());
                    SSLParameters parameters = handler.engine().getSSLParameters();
                    // SNIMatcher whose accept/reject decision is fixed by the 'match' flag;
                    // the constructor argument 0 is the SNI server-name type (host_name).
                    SNIMatcher matcher = new SNIMatcher(0) {
                        @Override
                        public boolean matches(SNIServerName sniServerName) {
                            return match;
                        }
                    };
                    parameters.setSNIMatchers(Collections.singleton(matcher));
                    handler.engine().setSSLParameters(parameters);

                    ch.pipeline().addFirst(handler);
                    ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
                        @Override
                        public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                            if (evt instanceof SslHandshakeCompletionEvent) {
                                SslHandshakeCompletionEvent event = (SslHandshakeCompletionEvent) evt;
                                if (match) {
                                    // Matching SNI: the handshake must succeed.
                                    if (event.isSuccess()) {
                                        promise.setSuccess(null);
                                    } else {
                                        promise.setFailure(event.cause());
                                    }
                                } else {
                                    // Non-matching SNI: the handshake must fail with an SSLException.
                                    if (event.isSuccess()) {
                                        promise.setFailure(new AssertionError("expected SSLException"));
                                    } else {
                                        Throwable cause = event.cause();
                                        if (cause instanceof SSLException) {
                                            promise.setSuccess(null);
                                        } else {
                                            promise.setFailure(
                                                    new AssertionError("cause not of type SSLException: "
                                                            + ThrowableUtil.stackTraceToString(cause)));
                                        }
                                    }
                                }
                            }
                        }
                    });
                }
            }).bind(address).syncUninterruptibly().channel();

            sslClientContext = SslContextBuilder.forClient().trustManager(InsecureTrustManagerFactory.INSTANCE)
                    .sslProvider(sslClientProvider).build();

            // peerPort -1: no port-based session caching key for the local transport.
            SslHandler sslHandler = new SslHandler(
                    sslClientContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1));
            Bootstrap cb = new Bootstrap();
            cc = cb.group(group).channel(LocalChannel.class).handler(sslHandler)
                    .connect(address).syncUninterruptibly().channel();

            // Wait for the server-side verdict first, then for the client handshake itself.
            promise.syncUninterruptibly();
            sslHandler.handshakeFuture().syncUninterruptibly();
        } finally {
            if (cc != null) {
                cc.close().syncUninterruptibly();
            }
            if (sc != null) {
                sc.close().syncUninterruptibly();
            }

            ReferenceCountUtil.release(sslServerContext);
            ReferenceCountUtil.release(sslClientContext);

            cert.delete();

            group.shutdownGracefully();
        }
    }

    /** Convenience overload: wraps {@code name} in an SNIHostName and delegates below. */
    static void assertSSLSession(boolean clientSide, SSLSession session, String name) {
        assertSSLSession(clientSide, session, new SNIHostName(name));
    }

    /**
     * Asserts that an ExtendedSSLSession carries exactly the expected requested SNI name
     * and plausible signature-algorithm lists for the given side.
     */
    private static void assertSSLSession(boolean clientSide, SSLSession session, SNIServerName name) {
        assertNotNull(session);
        if (session instanceof ExtendedSSLSession) {
            ExtendedSSLSession extendedSSLSession = (ExtendedSSLSession) session;
            List names = extendedSSLSession.getRequestedServerNames();
            assertEquals(1, names.size());
            assertEquals(name, names.get(0));
            assertTrue(extendedSSLSession.getLocalSupportedSignatureAlgorithms().length > 0);
            if (clientSide) {
                assertEquals(0, extendedSSLSession.getPeerSupportedSignatureAlgorithms().length);
            } else {
                // NOTE(review): 'length >= 0' is always true, so this assertion cannot fail;
                // presumably '> 0' was intended, or the server-side list is genuinely allowed
                // to be empty — confirm upstream before tightening.
                assertTrue(extendedSSLSession.getPeerSupportedSignatureAlgorithms().length >= 0);
            }
        }
    }

    /** Returns a TrustManagerFactory that asserts the handshake session saw {@code name} as SNI. */
    static TrustManagerFactory newSniX509TrustmanagerFactory(String name) {
        return new SniX509TrustmanagerFactory(new SNIHostName(name));
    }

    /**
     * TrustManagerFactory whose trust manager fails every check except the engine-based
     * server check, where it instead asserts the handshake session's requested SNI name.
     */
    private static final class SniX509TrustmanagerFactory extends SimpleTrustManagerFactory {

        private final SNIServerName name;

        SniX509TrustmanagerFactory(SNIServerName name) {
            this.name = name;
        }

        @Override
        protected void engineInit(KeyStore keyStore) throws Exception {
            // NOOP
        }

        @Override
        protected void engineInit(ManagerFactoryParameters managerFactoryParameters) throws Exception {
            // NOOP
        }

        @Override
        protected TrustManager[] engineGetTrustManagers() {
            return new TrustManager[] { new X509ExtendedTrustManager() {
                @Override
                public void checkClientTrusted(X509Certificate[] x509Certificates, String s, Socket socket)
                        throws CertificateException {
                    fail();
                }

                @Override
                public void checkServerTrusted(X509Certificate[] x509Certificates, String s, Socket socket)
                        throws CertificateException {
                    fail();
                }

                @Override
                public void checkClientTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine)
                        throws CertificateException {
                    fail();
                }

                @Override
                public void checkServerTrusted(X509Certificate[] x509Certificates, String s, SSLEngine sslEngine)
                        throws CertificateException {
                    // The only accepted path: verify the SNI name on the in-flight handshake session.
                    assertSSLSession(sslEngine.getUseClientMode(), sslEngine.getHandshakeSession(), name);
                }

                @Override
                public void checkClientTrusted(X509Certificate[] x509Certificates, String s)
                        throws CertificateException {
                    fail();
                }

                @Override
                public void checkServerTrusted(X509Certificate[] x509Certificates, String s)
                        throws CertificateException {
                    fail();
                }

                @Override
                public X509Certificate[] getAcceptedIssuers() {
                    return EmptyArrays.EMPTY_X509_CERTIFICATES;
                }
            } };
        }
    }

    /**
     * Returns a KeyManagerFactory that, when asked to choose a server alias, first asserts
     * the handshake session's requested SNI name equals {@code hostname}.
     */
    static KeyManagerFactory newSniX509KeyManagerFactory(SelfSignedCertificate cert, String hostname)
            throws NoSuchAlgorithmException, KeyStoreException, UnrecoverableKeyException,
            IOException, CertificateException {
        return new SniX509KeyManagerFactory(
                new SNIHostName(hostname), SslContext.buildKeyManagerFactory(
                        new X509Certificate[] { cert.cert() }, null, cert.key(), null, null, null));
    }

    /**
     * KeyManagerFactory decorator: delegates everything to the wrapped factory, but wraps
     * each X509ExtendedKeyManager so chooseEngineServerAlias() additionally asserts the
     * handshake session's SNI name.
     */
    private static final class SniX509KeyManagerFactory extends KeyManagerFactory {

        SniX509KeyManagerFactory(final SNIServerName name, final KeyManagerFactory factory) {
            super(new KeyManagerFactorySpi() {
                @Override
                protected void engineInit(KeyStore keyStore, char[] chars)
                        throws KeyStoreException, NoSuchAlgorithmException, UnrecoverableKeyException {
                    factory.init(keyStore, chars);
                }

                @Override
                protected void engineInit(ManagerFactoryParameters managerFactoryParameters)
                        throws InvalidAlgorithmParameterException {
                    factory.init(managerFactoryParameters);
                }

                @Override
                protected KeyManager[] engineGetKeyManagers() {
                    List managers = new ArrayList();
                    for (final KeyManager km: factory.getKeyManagers()) {
                        if (km instanceof X509ExtendedKeyManager) {
                            // Pure delegation wrapper; only chooseEngineServerAlias adds behavior.
                            managers.add(new X509ExtendedKeyManager() {
                                @Override
                                public String[] getClientAliases(String s, Principal[] principals) {
                                    return ((X509ExtendedKeyManager) km).getClientAliases(s, principals);
                                }

                                @Override
                                public String chooseClientAlias(String[] strings, Principal[] principals,
                                                                Socket socket) {
                                    return ((X509ExtendedKeyManager) km).chooseClientAlias(strings, principals, socket);
                                }

                                @Override
                                public String[] getServerAliases(String s, Principal[] principals) {
                                    return ((X509ExtendedKeyManager) km).getServerAliases(s, principals);
                                }

                                @Override
                                public String chooseServerAlias(String s, Principal[] principals, Socket socket) {
                                    return ((X509ExtendedKeyManager) km).chooseServerAlias(s, principals, socket);
                                }

                                @Override
                                public X509Certificate[] getCertificateChain(String s) {
                                    return ((X509ExtendedKeyManager) km).getCertificateChain(s);
                                }

                                @Override
                                public PrivateKey getPrivateKey(String s) {
                                    return ((X509ExtendedKeyManager) km).getPrivateKey(s);
                                }

                                @Override
                                public String chooseEngineClientAlias(String[] strings, Principal[] principals,
                                                                      SSLEngine sslEngine) {
                                    return ((X509ExtendedKeyManager) km)
                                            .chooseEngineClientAlias(strings, principals, sslEngine);
                                }

                                @Override
                                public String chooseEngineServerAlias(String s, Principal[] principals,
                                                                      SSLEngine sslEngine) {
                                    // Assert the SNI name on the handshake session before delegating.
                                    SSLSession session = sslEngine.getHandshakeSession();
                                    assertSSLSession(sslEngine.getUseClientMode(), session, name);
                                    return ((X509ExtendedKeyManager) km)
                                            .chooseEngineServerAlias(s, principals, sslEngine);
                                }
                            });
                        } else {
                            managers.add(km);
                        }
                    }
                    return managers.toArray(new KeyManager[0]);
                }
            }, factory.getProvider(), factory.getAlgorithm());
        }
    }
}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientTest.java
new file mode 100644
index 0000000..007c82f
--- /dev/null
+++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniClientTest.java
@@ -0,0 +1,179 @@
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file
except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.ssl;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBufAllocator;
import io.netty.channel.Channel;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.netty.handler.ssl.util.SelfSignedCertificate;
import io.netty.util.Mapping;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.PlatformDependent;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.function.Executable;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;

import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException;
import javax.net.ssl.TrustManagerFactory;
import java.security.cert.X509Certificate;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.junit.jupiter.api.Assumptions.assumeTrue;

/**
 * Client-side SNI tests, parameterized over every available (server provider, client
 * provider) combination.
 */
// NOTE(review): generic type arguments (Collection<Object[]>, Promise<String>,
// Mapping<String, SslContext>, ChannelInitializer<Channel>) appear to have been stripped
// by the patch extraction; raw types below still compile — confirm against upstream netty.
public class SniClientTest {
    private static final String PARAMETERIZED_NAME = "{index}: serverSslProvider = {0}, clientSslProvider = {1}";

    /**
     * Builds the cross product of available SslProviders for both sides; OpenSSL-based
     * providers are dropped when the native library is not loadable on this platform.
     */
    static Collection parameters() {
        List providers = new ArrayList(Arrays.asList(SslProvider.values()));
        if (!OpenSsl.isAvailable()) {
            providers.remove(SslProvider.OPENSSL);
            providers.remove(SslProvider.OPENSSL_REFCNT);
        }

        List params = new ArrayList();
        for (SslProvider sp: providers) {
            for (SslProvider cp: providers) {
                params.add(new Object[] { sp, cp });
            }
        }
        return params;
    }

    /** A matching SNIMatcher must let the handshake succeed. Requires Java 8+ SNI APIs. */
    @ParameterizedTest(name = PARAMETERIZED_NAME)
    @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
    @MethodSource("parameters")
    public void testSniSNIMatcherMatchesClient(SslProvider serverProvider, SslProvider clientProvider)
            throws Exception {
        assumeTrue(PlatformDependent.javaVersion() >= 8);
        SniClientJava8TestUtil.testSniClient(serverProvider, clientProvider, true);
    }

    /** A rejecting SNIMatcher must make the handshake fail with an SSLException. */
    @ParameterizedTest(name = PARAMETERIZED_NAME)
    @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
    @MethodSource("parameters")
    public void testSniSNIMatcherDoesNotMatchClient(
            final SslProvider serverProvider, final SslProvider clientProvider) {
        assumeTrue(PlatformDependent.javaVersion() >= 8);
        assertThrows(SSLException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                SniClientJava8TestUtil.testSniClient(serverProvider, clientProvider, false);
            }
        });
    }

    /**
     * End-to-end SNI flow over a LocalChannel pair: the server's SniHandler must receive
     * the hostname the client put in its ClientHello, and both sides' sessions must carry
     * the expected SNI metadata (checked via SniClientJava8TestUtil on Java 8+).
     */
    @ParameterizedTest(name = PARAMETERIZED_NAME)
    @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS)
    @MethodSource("parameters")
    public void testSniClient(SslProvider sslServerProvider, SslProvider sslClientProvider) throws Exception {
        String sniHostName = "sni.netty.io";
        LocalAddress address = new LocalAddress("SniClientTest");
        EventLoopGroup group = new DefaultEventLoopGroup(1);
        SelfSignedCertificate cert = new SelfSignedCertificate();
        SslContext sslServerContext = null;
        SslContext sslClientContext = null;

        Channel sc = null;
        Channel cc = null;
        try {
            if ((sslServerProvider == SslProvider.OPENSSL || sslServerProvider == SslProvider.OPENSSL_REFCNT)
                    && !OpenSsl.useKeyManagerFactory()) {
                // Old OpenSSL without KeyManagerFactory support: configure cert/key files directly.
                sslServerContext = SslContextBuilder.forServer(cert.certificate(), cert.privateKey())
                        .sslProvider(sslServerProvider)
                        .build();
            } else {
                // The used OpenSSL version does support a KeyManagerFactory, so use it.
                // On Java 8+ the factory additionally asserts the SNI name during alias selection.
                KeyManagerFactory kmf = PlatformDependent.javaVersion() >= 8 ?
                        SniClientJava8TestUtil.newSniX509KeyManagerFactory(cert, sniHostName) :
                        SslContext.buildKeyManagerFactory(
                                new X509Certificate[] { cert.cert() }, null,
                                cert.key(), null, null, null);

                sslServerContext = SslContextBuilder.forServer(kmf)
                        .sslProvider(sslServerProvider)
                        .build();
            }

            final SslContext finalContext = sslServerContext;
            // Completed with the hostname the SniHandler extracted from the ClientHello.
            final Promise promise = group.next().newPromise();
            ServerBootstrap sb = new ServerBootstrap();
            sc = sb.group(group).channel(LocalServerChannel.class).childHandler(new ChannelInitializer() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    ch.pipeline().addFirst(new SniHandler(new Mapping() {
                        @Override
                        public SslContext map(String input) {
                            promise.setSuccess(input);
                            return finalContext;
                        }
                    }));
                }
            }).bind(address).syncUninterruptibly().channel();

            // On Java 8+ the trust manager also asserts the SNI name on the handshake session.
            TrustManagerFactory tmf = PlatformDependent.javaVersion() >= 8 ?
                    SniClientJava8TestUtil.newSniX509TrustmanagerFactory(sniHostName) :
                    InsecureTrustManagerFactory.INSTANCE;
            sslClientContext = SslContextBuilder.forClient().trustManager(tmf)
                    .sslProvider(sslClientProvider).build();
            Bootstrap cb = new Bootstrap();

            // peerPort -1: no port component needed for the local transport.
            SslHandler handler = new SslHandler(
                    sslClientContext.newEngine(ByteBufAllocator.DEFAULT, sniHostName, -1));
            cc = cb.group(group).channel(LocalChannel.class).handler(handler)
                    .connect(address).syncUninterruptibly().channel();
            assertEquals(sniHostName, promise.syncUninterruptibly().getNow());

            // After we are done with handshaking getHandshakeSession() should return null.
            handler.handshakeFuture().syncUninterruptibly();
            assertNull(handler.engine().getHandshakeSession());

            if (PlatformDependent.javaVersion() >= 8) {
                SniClientJava8TestUtil.assertSSLSession(
                        handler.engine().getUseClientMode(), handler.engine().getSession(), sniHostName);
            }
        } finally {
            if (cc != null) {
                cc.close().syncUninterruptibly();
            }
            if (sc != null) {
                sc.close().syncUninterruptibly();
            }
            ReferenceCountUtil.release(sslServerContext);
            ReferenceCountUtil.release(sslClientContext);

            cert.delete();

            group.shutdownGracefully();
        }
    }
}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniHandlerTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniHandlerTest.java
new file mode 100644
index 0000000..e01ba16
--- /dev/null
+++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SniHandlerTest.java
@@ -0,0 +1,858 @@
/*
 * Copyright 2014 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +import java.io.File; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; + +import io.netty.handler.codec.TooLongFrameException; +import io.netty.util.concurrent.Future; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; 
+import io.netty.util.DomainNameMapping; +import io.netty.util.DomainNameMappingBuilder; +import io.netty.util.Mapping; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.ResourcesUtil; +import io.netty.util.internal.StringUtil; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.MatcherAssert.assertThat; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.mock; + +public class SniHandlerTest { + + private static ApplicationProtocolConfig newApnConfig() { + return new ApplicationProtocolConfig( + ApplicationProtocolConfig.Protocol.ALPN, + // NO_ADVERTISE is currently the only mode supported by both OpenSsl and JDK providers. + ApplicationProtocolConfig.SelectorFailureBehavior.NO_ADVERTISE, + // ACCEPT is currently the only mode supported by both OpenSsl and JDK providers. 
+ ApplicationProtocolConfig.SelectedListenerFailureBehavior.ACCEPT, + "myprotocol"); + } + + private static void assumeApnSupported(SslProvider provider) { + switch (provider) { + case OPENSSL: + case OPENSSL_REFCNT: + assumeTrue(OpenSsl.isAlpnSupported()); + break; + case JDK: + assumeTrue(true); + break; + default: + throw new Error(); + } + } + + private static SslContext makeSslContext(SslProvider provider, boolean apn) throws Exception { + if (apn) { + assumeApnSupported(provider); + } + + File keyFile = ResourcesUtil.getFile(SniHandlerTest.class, "test_encrypted.pem"); + File crtFile = ResourcesUtil.getFile(SniHandlerTest.class, "test.crt"); + + SslContextBuilder sslCtxBuilder = SslContextBuilder.forServer(crtFile, keyFile, "12345") + .sslProvider(provider); + if (apn) { + sslCtxBuilder.applicationProtocolConfig(newApnConfig()); + } + return sslCtxBuilder.build(); + } + + private static SslContext makeSslClientContext(SslProvider provider, boolean apn) throws Exception { + if (apn) { + assumeApnSupported(provider); + } + + File crtFile = ResourcesUtil.getFile(SniHandlerTest.class, "test.crt"); + + SslContextBuilder sslCtxBuilder = SslContextBuilder.forClient().trustManager(crtFile).sslProvider(provider); + if (apn) { + sslCtxBuilder.applicationProtocolConfig(newApnConfig()); + } + return sslCtxBuilder.build(); + } + + static Iterable data() { + List params = new ArrayList(3); + if (OpenSsl.isAvailable()) { + params.add(SslProvider.OPENSSL); + params.add(SslProvider.OPENSSL_REFCNT); + } + params.add(SslProvider.JDK); + return params; + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testNonSslRecord(SslProvider provider) throws Exception { + SslContext nettyContext = makeSslContext(provider, false); + try { + final AtomicReference evtRef = + new AtomicReference(); + SniHandler handler = new SniHandler(new DomainNameMappingBuilder(nettyContext).build()); + final EmbeddedChannel ch = new EmbeddedChannel(handler, 
new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + assertTrue(evtRef.compareAndSet(null, (SslHandshakeCompletionEvent) evt)); + } + } + }); + + try { + final byte[] bytes = new byte[1024]; + bytes[0] = SslUtils.SSL_CONTENT_TYPE_ALERT; + + DecoderException e = assertThrows(DecoderException.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.writeInbound(Unpooled.wrappedBuffer(bytes)); + } + }); + assertThat(e.getCause(), CoreMatchers.instanceOf(NotSslRecordException.class)); + assertFalse(ch.finish()); + } finally { + ch.finishAndReleaseAll(); + } + assertThat(evtRef.get().cause(), CoreMatchers.instanceOf(NotSslRecordException.class)); + } finally { + releaseAll(nettyContext); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testServerNameParsing(SslProvider provider) throws Exception { + SslContext nettyContext = makeSslContext(provider, false); + SslContext leanContext = makeSslContext(provider, false); + SslContext leanContext2 = makeSslContext(provider, false); + + try { + DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + // input with custom cases + .add("*.LEANCLOUD.CN", leanContext) + // a hostname conflict with previous one, since we are using order-sensitive config, + // the engine won't be used with the handler. 
+ .add("chat4.leancloud.cn", leanContext2) + .build(); + + final AtomicReference evtRef = new AtomicReference(); + SniHandler handler = new SniHandler(mapping); + EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SniCompletionEvent) { + assertTrue(evtRef.compareAndSet(null, (SniCompletionEvent) evt)); + } else { + ctx.fireUserEventTriggered(evt); + } + } + }); + + try { + // hex dump of a client hello packet, which contains hostname "CHAT4.LEANCLOUD.CN" + String tlsHandshakeMessageHex1 = "16030100"; + // part 2 + String tlsHandshakeMessageHex = "c6010000c20303bb0855d66532c05a0ef784f7c384feeafa68b3" + + "b655ac7288650d5eed4aa3fb52000038c02cc030009fcca9cca8ccaac02b" + + "c02f009ec024c028006bc023c0270067c00ac0140039c009c0130033009d" + + "009c003d003c0035002f00ff010000610000001700150000124348415434" + + "2e4c45414e434c4f55442e434e000b000403000102000a000a0008001d00" + + "170019001800230000000d0020001e060106020603050105020503040104" + + "0204030301030203030201020202030016000000170000"; + + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex1))); + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex))); + + // This should produce an alert + assertTrue(ch.finish()); + + assertThat(handler.hostname(), is("chat4.leancloud.cn")); + assertThat(handler.sslContext(), is(leanContext)); + + SniCompletionEvent evt = evtRef.get(); + assertNotNull(evt); + assertEquals("chat4.leancloud.cn", evt.hostname()); + assertTrue(evt.isSuccess()); + assertNull(evt.cause()); + } finally { + ch.finishAndReleaseAll(); + } + } finally { + releaseAll(leanContext, leanContext2, nettyContext); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testNonAsciiServerNameParsing(SslProvider provider) throws Exception { + 
SslContext nettyContext = makeSslContext(provider, false); + SslContext leanContext = makeSslContext(provider, false); + SslContext leanContext2 = makeSslContext(provider, false); + + try { + DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + // input with custom cases + .add("*.LEANCLOUD.CN", leanContext) + // a hostname conflict with previous one, since we are using order-sensitive config, + // the engine won't be used with the handler. + .add("chat4.leancloud.cn", leanContext2) + .build(); + + SniHandler handler = new SniHandler(mapping); + final EmbeddedChannel ch = new EmbeddedChannel(handler); + + try { + // hex dump of a client hello packet, which contains an invalid hostname "CHAT4。LEANCLOUD。CN" + String tlsHandshakeMessageHex1 = "16030100"; + // part 2 + final String tlsHandshakeMessageHex = "bd010000b90303a74225676d1814ba57faff3b366" + + "3656ed05ee9dbb2a4dbb1bb1c32d2ea5fc39e0000000100008c0000001700150000164348" + + "415434E380824C45414E434C4F5544E38082434E000b000403000102000a00340032000e0" + + "00d0019000b000c00180009000a0016001700080006000700140015000400050012001300" + + "0100020003000f0010001100230000000d0020001e0601060206030501050205030401040" + + "20403030103020303020102020203000f00010133740000"; + + // Push the handshake message. 
+ // Decode should fail because of the badly encoded "HostName" string in the SNI extension + // that isn't ASCII as per RFC 6066 - https://tools.ietf.org/html/rfc6066#page-6 + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex1))); + + assertThrows(DecoderException.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump(tlsHandshakeMessageHex))); + } + }); + } finally { + ch.finishAndReleaseAll(); + } + } finally { + releaseAll(leanContext, leanContext2, nettyContext); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testFallbackToDefaultContext(SslProvider provider) throws Exception { + SslContext nettyContext = makeSslContext(provider, false); + SslContext leanContext = makeSslContext(provider, false); + SslContext leanContext2 = makeSslContext(provider, false); + + try { + DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + // input with custom cases + .add("*.LEANCLOUD.CN", leanContext) + // a hostname conflict with previous one, since we are using order-sensitive config, + // the engine won't be used with the handler. + .add("chat4.leancloud.cn", leanContext2) + .build(); + + SniHandler handler = new SniHandler(mapping); + EmbeddedChannel ch = new EmbeddedChannel(handler); + + // invalid + byte[] message = {22, 3, 1, 0, 0}; + try { + // Push the handshake message. + ch.writeInbound(Unpooled.wrappedBuffer(message)); + // TODO(scott): This should fail because the engine should reject zero length records during handshake. + // See https://github.com/netty/netty/issues/6348. + // fail(); + } catch (Exception e) { + // expected + } + + ch.close(); + + // When the channel is closed the SslHandler will write an empty buffer to the channel. 
+ ByteBuf buf = ch.readOutbound(); + // TODO(scott): if the engine is shutdown correctly then this buffer shouldn't be null! + // See https://github.com/netty/netty/issues/6348. + if (buf != null) { + assertFalse(buf.isReadable()); + buf.release(); + } + + assertThat(ch.finish(), is(false)); + assertThat(handler.hostname(), nullValue()); + assertThat(handler.sslContext(), is(nettyContext)); + } finally { + releaseAll(leanContext, leanContext2, nettyContext); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testMajorVersionNot3(SslProvider provider) throws Exception { + SslContext nettyContext = makeSslContext(provider, false); + + try { + DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext).build(); + + SniHandler handler = new SniHandler(mapping); + EmbeddedChannel ch = new EmbeddedChannel(handler); + + // invalid + byte[] message = {22, 2, 0, 0, 0}; + try { + // Push the handshake message. + ch.writeInbound(Unpooled.wrappedBuffer(message)); + // TODO(scott): This should fail because the engine should reject zero length records during handshake. + // See https://github.com/netty/netty/issues/6348. + // fail(); + } catch (Exception e) { + // expected + } + + ch.close(); + + // Consume all the outbound data that may be produced by the SSLEngine. 
+ for (;;) { + ByteBuf buf = ch.readOutbound(); + if (buf == null) { + break; + } + buf.release(); + } + + assertThat(ch.finish(), is(false)); + assertThat(handler.hostname(), nullValue()); + assertThat(handler.sslContext(), is(nettyContext)); + } finally { + releaseAll(nettyContext); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testSniWithApnHandler(SslProvider provider) throws Exception { + SslContext nettyContext = makeSslContext(provider, true); + SslContext sniContext = makeSslContext(provider, true); + final SslContext clientContext = makeSslClientContext(provider, true); + try { + final AtomicBoolean serverApnCtx = new AtomicBoolean(false); + final AtomicBoolean clientApnCtx = new AtomicBoolean(false); + final CountDownLatch serverApnDoneLatch = new CountDownLatch(1); + final CountDownLatch clientApnDoneLatch = new CountDownLatch(1); + + final DomainNameMapping mapping = new DomainNameMappingBuilder(nettyContext) + .add("*.netty.io", nettyContext) + .add("sni.fake.site", sniContext).build(); + final SniHandler handler = new SniHandler(mapping); + EventLoopGroup group = new NioEventLoopGroup(2); + Channel serverChannel = null; + Channel clientChannel = null; + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group); + sb.channel(NioServerSocketChannel.class); + sb.childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + // Server side SNI. + p.addLast(handler); + // Catch the notification event that APN has completed successfully. 
+ p.addLast(new ApplicationProtocolNegotiationHandler("foo") { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + // addresses issue #9131 + serverApnCtx.set(ctx.pipeline().context(this) != null); + serverApnDoneLatch.countDown(); + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group); + cb.channel(NioSocketChannel.class); + cb.handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(new SslHandler(clientContext.newEngine( + ch.alloc(), "sni.fake.site", -1))); + // Catch the notification event that APN has completed successfully. + ch.pipeline().addLast(new ApplicationProtocolNegotiationHandler("foo") { + @Override + protected void configurePipeline(ChannelHandlerContext ctx, String protocol) { + // addresses issue #9131 + clientApnCtx.set(ctx.pipeline().context(this) != null); + clientApnDoneLatch.countDown(); + } + }); + } + }); + + serverChannel = sb.bind(new InetSocketAddress(0)).sync().channel(); + + ChannelFuture ccf = cb.connect(serverChannel.localAddress()); + assertTrue(ccf.awaitUninterruptibly().isSuccess()); + clientChannel = ccf.channel(); + + assertTrue(serverApnDoneLatch.await(5, TimeUnit.SECONDS)); + assertTrue(clientApnDoneLatch.await(5, TimeUnit.SECONDS)); + assertTrue(serverApnCtx.get()); + assertTrue(clientApnCtx.get()); + assertThat(handler.hostname(), is("sni.fake.site")); + assertThat(handler.sslContext(), is(sniContext)); + } finally { + if (serverChannel != null) { + serverChannel.close().sync(); + } + if (clientChannel != null) { + clientChannel.close().sync(); + } + group.shutdownGracefully(0, 0, TimeUnit.MICROSECONDS); + } + } finally { + releaseAll(clientContext, nettyContext, sniContext); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testReplaceHandler(SslProvider provider) throws Exception 
{ + switch (provider) { + case OPENSSL: + case OPENSSL_REFCNT: + final String sniHost = "sni.netty.io"; + LocalAddress address = new LocalAddress("testReplaceHandler-" + Math.random()); + EventLoopGroup group = new DefaultEventLoopGroup(1); + Channel sc = null; + Channel cc = null; + SslContext sslContext = null; + + SelfSignedCertificate cert = new SelfSignedCertificate(); + + try { + final SslContext sslServerContext = SslContextBuilder + .forServer(cert.key(), cert.cert()) + .sslProvider(provider) + .build(); + + final Mapping mapping = new Mapping() { + @Override + public SslContext map(String input) { + return sslServerContext; + } + }; + + final Promise releasePromise = group.next().newPromise(); + + final SniHandler handler = new SniHandler(mapping) { + @Override + protected void replaceHandler(ChannelHandlerContext ctx, + String hostname, final SslContext sslContext) + throws Exception { + + boolean success = false; + try { + assertEquals(1, ((ReferenceCountedOpenSslContext) sslContext).refCnt()); + // The SniHandler's replaceHandler() method allows us to implement custom behavior. + // As an example, we want to release() the SslContext upon channelInactive() or rather + // when the SslHandler closes it's SslEngine. If you take a close look at SslHandler + // you'll see that it's doing it in the #handlerRemoved0() method. 
+ + SSLEngine sslEngine = sslContext.newEngine(ctx.alloc()); + try { + assertEquals(2, ((ReferenceCountedOpenSslContext) sslContext).refCnt()); + SslHandler customSslHandler = new CustomSslHandler(sslContext, sslEngine) { + @Override + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + try { + super.handlerRemoved0(ctx); + } finally { + releasePromise.trySuccess(null); + } + } + }; + ctx.pipeline().replace(this, CustomSslHandler.class.getName(), customSslHandler); + success = true; + } finally { + if (!success) { + ReferenceCountUtil.safeRelease(sslEngine); + } + } + } finally { + if (!success) { + ReferenceCountUtil.safeRelease(sslContext); + releasePromise.cancel(true); + } + } + } + }; + + ServerBootstrap sb = new ServerBootstrap(); + sc = sb.group(group).channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addFirst(handler); + } + }).bind(address).syncUninterruptibly().channel(); + + sslContext = SslContextBuilder.forClient().sslProvider(provider) + .trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + + Bootstrap cb = new Bootstrap(); + cc = cb.group(group).channel(LocalChannel.class).handler(new SslHandler( + sslContext.newEngine(ByteBufAllocator.DEFAULT, sniHost, -1))) + .connect(address).syncUninterruptibly().channel(); + + cc.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes())) + .syncUninterruptibly(); + + // Notice how the server's SslContext refCnt is 2 as it is incremented when the SSLEngine is created + // and only decremented once it is destroyed. 
+ assertEquals(2, ((ReferenceCounted) sslServerContext).refCnt()); + + // The client disconnects + cc.close().syncUninterruptibly(); + if (!releasePromise.awaitUninterruptibly(10L, TimeUnit.SECONDS)) { + throw new IllegalStateException("It doesn't seem #replaceHandler() got called."); + } + + // We should have successfully release() the SslContext + assertEquals(0, ((ReferenceCounted) sslServerContext).refCnt()); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + if (sslContext != null) { + ReferenceCountUtil.release(sslContext); + } + group.shutdownGracefully(); + + cert.delete(); + } + case JDK: + return; + default: + throw new Error(); + } + } + + /** + * This is a {@link SslHandler} that will call {@code release()} on the {@link SslContext} when + * the client disconnects. + * + * @see SniHandlerTest#testReplaceHandler(SslProvider) + */ + private static class CustomSslHandler extends SslHandler { + private final SslContext sslContext; + + CustomSslHandler(SslContext sslContext, SSLEngine sslEngine) { + super(sslEngine); + this.sslContext = ObjectUtil.checkNotNull(sslContext, "sslContext"); + } + + @Override + public void handlerRemoved0(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved0(ctx); + ReferenceCountUtil.release(sslContext); + } + } + + private static void releaseAll(SslContext... 
contexts) { + for (SslContext ctx: contexts) { + ReferenceCountUtil.release(ctx); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testNonFragmented(SslProvider provider) throws Exception { + testWithFragmentSize(provider, Integer.MAX_VALUE); + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testFragmented(SslProvider provider) throws Exception { + testWithFragmentSize(provider, 50); + } + + private void testWithFragmentSize(SslProvider provider, final int maxFragmentSize) throws Exception { + final String sni = "netty.io"; + SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext context = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(provider) + .build(); + try { + @SuppressWarnings("unchecked") final EmbeddedChannel server = new EmbeddedChannel( + new SniHandler(mock(DomainNameMapping.class)) { + @Override + protected Future lookup(final ChannelHandlerContext ctx, final String hostname) { + assertEquals(sni, hostname); + return ctx.executor().newSucceededFuture(context); + } + }); + + final List buffers = clientHelloInMultipleFragments(provider, sni, maxFragmentSize); + for (ByteBuf buffer : buffers) { + server.writeInbound(buffer); + } + assertTrue(server.finishAndReleaseAll()); + } finally { + releaseAll(context); + cert.delete(); + } + } + + private static List clientHelloInMultipleFragments( + SslProvider provider, String hostname, int maxTlsPlaintextSize) throws SSLException { + final EmbeddedChannel client = new EmbeddedChannel(); + final SslContext ctx = SslContextBuilder.forClient() + .sslProvider(provider) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + try { + final SslHandler sslHandler = ctx.newHandler(client.alloc(), hostname, -1); + client.pipeline().addLast(sslHandler); + final ByteBuf clientHello = client.readOutbound(); + List buffers = split(clientHello, maxTlsPlaintextSize); + 
assertTrue(client.finishAndReleaseAll()); + return buffers; + } finally { + releaseAll(ctx); + } + } + + private static List split(ByteBuf clientHello, int maxSize) { + final int type = clientHello.readUnsignedByte(); + final int version = clientHello.readUnsignedShort(); + final int length = clientHello.readUnsignedShort(); + assertEquals(length, clientHello.readableBytes()); + + final List result = new ArrayList(); + while (clientHello.readableBytes() > 0) { + final int toRead = Math.min(maxSize, clientHello.readableBytes()); + final ByteBuf bb = clientHello.alloc().buffer(SslUtils.SSL_RECORD_HEADER_LENGTH + toRead); + bb.writeByte(type); + bb.writeShort(version); + bb.writeShort(toRead); + bb.writeBytes(clientHello, toRead); + result.add(bb); + } + clientHello.release(); + return result; + } + + @Test + public void testSniHandlerFailsOnTooBigClientHello() throws Exception { + SniHandler handler = new SniHandler(new Mapping() { + @Override + public SslContext map(String input) { + throw new UnsupportedOperationException("Should not be called"); + } + }, 10, 0); + + final AtomicReference completionEventRef = + new AtomicReference(); + final EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + completionEventRef.set((SniCompletionEvent) evt); + } + } + }); + final ByteBuf buffer = ch.alloc().buffer(); + buffer.writeByte(0x16); // Content Type: Handshake + buffer.writeShort((short) 0x0303); // TLS 1.2 + buffer.writeShort((short) 0x0006); // Packet length + + // 16_777_215 + buffer.writeByte((byte) 0x01); // Client Hello + buffer.writeMedium(0xFFFFFF); // Length + buffer.writeShort((short) 0x0303); // TLS 1.2 + + assertThrows(TooLongFrameException.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.writeInbound(buffer); + } + }); + try { + while (completionEventRef.get() 
== null) { + Thread.sleep(100); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + ch.runPendingTasks(); + } + SniCompletionEvent completionEvent = completionEventRef.get(); + assertNotNull(completionEvent); + assertNotNull(completionEvent.cause()); + assertEquals(TooLongFrameException.class, completionEvent.cause().getClass()); + } finally { + ch.finishAndReleaseAll(); + } + } + + @Test + public void testSniHandlerFiresHandshakeTimeout() throws Exception { + SniHandler handler = new SniHandler(new Mapping() { + @Override + public SslContext map(String input) { + throw new UnsupportedOperationException("Should not be called"); + } + }, 0, 10); + + final AtomicReference completionEventRef = + new AtomicReference(); + EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + completionEventRef.set((SniCompletionEvent) evt); + } + } + }); + try { + while (completionEventRef.get() == null) { + Thread.sleep(100); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. 
+ ch.runPendingTasks(); + } + SniCompletionEvent completionEvent = completionEventRef.get(); + assertNotNull(completionEvent); + assertNotNull(completionEvent.cause()); + assertEquals(SslHandshakeTimeoutException.class, completionEvent.cause().getClass()); + } finally { + ch.finishAndReleaseAll(); + } + } + + @ParameterizedTest(name = "{index}: sslProvider={0}") + @MethodSource("data") + public void testSslHandlerFiresHandshakeTimeout(SslProvider provider) throws Exception { + final SslContext context = makeSslContext(provider, false); + SniHandler handler = new SniHandler(new Mapping() { + @Override + public SslContext map(String input) { + return context; + } + }, 0, 100); + + final AtomicReference sniCompletionEventRef = + new AtomicReference(); + final AtomicReference handshakeCompletionEventRef = + new AtomicReference(); + EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SniCompletionEvent) { + sniCompletionEventRef.set((SniCompletionEvent) evt); + } else if (evt instanceof SslHandshakeCompletionEvent) { + handshakeCompletionEventRef.set((SslHandshakeCompletionEvent) evt); + } + } + }); + try { + // Send enough data to add the SslHandler and let the handshake incomplete + // Client Hello with "host1" server name + ch.writeInbound(Unpooled.wrappedBuffer(StringUtil.decodeHexDump( + "16030301800100017c0303478ae7e536aa7a9debad1f873121862d2d3d3173e0ef42975c31007faeb2" + + "52522047f55f81fc84fe58951e2af14026147d6178498fde551fcbafc636462c016ec9005a13011302" + + "c02cc02bc030009dc02ec032009f00a3c02f009cc02dc031009e00a2c024c028003dc026c02a006b00" + + "6ac00ac0140035c005c00f00390038c023c027003cc025c02900670040c009c013002fc004c00e0033" + + "003200ff010000d90000000a0008000005686f737431000500050100000000000a00160014001d0017" + + "00180019001e01000101010201030104000b00020100000d0028002604030503060308040805080608" + + 
"09080a080b040105010601040203030301030202030201020200320028002604030503060308040805" + + "08060809080a080b040105010601040203030301030202030201020200110009000702000400000000" + + "00170000002b00050403040303002d00020101003300260024001d00200bbc37375e214c1e4e7cb90f" + + "869e131dc983a21f8205ba24456177f340904935"))); + + while (handshakeCompletionEventRef.get() == null) { + Thread.sleep(10); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + ch.runPendingTasks(); + } + SniCompletionEvent sniCompletionEvent = sniCompletionEventRef.get(); + assertNotNull(sniCompletionEvent); + assertEquals("host1", sniCompletionEvent.hostname()); + SslCompletionEvent handshakeCompletionEvent = handshakeCompletionEventRef.get(); + assertNotNull(handshakeCompletionEvent); + assertNotNull(handshakeCompletionEvent.cause()); + assertEquals(SslHandshakeTimeoutException.class, handshakeCompletionEvent.cause().getClass()); + } finally { + ch.finishAndReleaseAll(); + releaseAll(context); + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java new file mode 100644 index 0000000..788c1d8 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextBuilderTest.java @@ -0,0 +1,428 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.CharsetUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import javax.net.ssl.KeyManager; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509ExtendedKeyManager; +import javax.net.ssl.X509ExtendedTrustManager; +import java.io.ByteArrayInputStream; +import java.net.Socket; +import java.security.Principal; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Collections; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class SslContextBuilderTest { + + @Test + public void testClientContextFromFileJdk() throws Exception { + testClientContextFromFile(SslProvider.JDK); + } + + @Test + public void testClientContextFromFileOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + testClientContextFromFile(SslProvider.OPENSSL); + } + + @Test + public void testClientContextJdk() throws Exception { + testClientContext(SslProvider.JDK); + } + + @Test + public void testClientContextOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + testClientContext(SslProvider.OPENSSL); + } + + @Test + public void testCombinedPemFileClientContextJdk() throws Exception { + testServerContextWithCombinedCertAndKeyInPem(SslProvider.JDK); + } + + @Test + public void 
testCombinedPemFileClientContextOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + testServerContextWithCombinedCertAndKeyInPem(SslProvider.OPENSSL); + } + + @Test + public void testKeyStoreTypeJdk() throws Exception { + testKeyStoreType(SslProvider.JDK); + } + + @Test + public void testKeyStoreTypeOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + testKeyStoreType(SslProvider.OPENSSL); + } + + @Test + public void testServerContextFromFileJdk() throws Exception { + testServerContextFromFile(SslProvider.JDK); + } + + @Test + public void testServerContextFromFileOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + testServerContextFromFile(SslProvider.OPENSSL); + } + + @Test + public void testServerContextJdk() throws Exception { + testServerContext(SslProvider.JDK); + } + + @Test + public void testServerContextOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + testServerContext(SslProvider.OPENSSL); + } + + @Test + public void testContextFromManagersJdk() throws Exception { + testContextFromManagers(SslProvider.JDK); + } + + @Test + public void testContextFromManagersOpenssl() throws Exception { + OpenSsl.ensureAvailability(); + assumeTrue(OpenSsl.useKeyManagerFactory()); + testContextFromManagers(SslProvider.OPENSSL); + } + + @Test + public void testUnsupportedPrivateKeyFailsFastForServer() { + assumeTrue(OpenSsl.isBoringSSL()); + testUnsupportedPrivateKeyFailsFast(true); + } + + @Test + public void testUnsupportedPrivateKeyFailsFastForClient() { + assumeTrue(OpenSsl.isBoringSSL()); + testUnsupportedPrivateKeyFailsFast(false); + } + + private static void testUnsupportedPrivateKeyFailsFast(boolean server) { + assumeTrue(OpenSsl.isBoringSSL()); + String cert = "-----BEGIN CERTIFICATE-----\n" + + "MIICODCCAY2gAwIBAgIEXKTrajAKBggqhkjOPQQDBDBUMQswCQYDVQQGEwJVUzEM\n" + + "MAoGA1UECAwDTi9hMQwwCgYDVQQHDANOL2ExDDAKBgNVBAoMA04vYTEMMAoGA1UE\n" + + "CwwDTi9hMQ0wCwYDVQQDDARUZXN0MB4XDTE5MDQwMzE3MjA0MloXDTIwMDQwMjE3\n" + + 
"MjA0MlowVDELMAkGA1UEBhMCVVMxDDAKBgNVBAgMA04vYTEMMAoGA1UEBwwDTi9h\n" + + "MQwwCgYDVQQKDANOL2ExDDAKBgNVBAsMA04vYTENMAsGA1UEAwwEVGVzdDCBpzAQ\n" + + "BgcqhkjOPQIBBgUrgQQAJwOBkgAEBPYWoTjlS2pCMGEM2P8qZnmURWA5e7XxPfIh\n" + + "HA876sjmgjJluPgT0OkweuxI4Y/XjzcPnnEBONgzAV1X93UmXdtRiIau/zvsAeFb\n" + + "j/q+6sfj1jdnUk6QsMx22kAwplXHmdz1z5ShXQ7mDZPxDbhCPEAUXzIzOqvWIZyA\n" + + "HgFxZXmQKEhExA8nxgSIvzQ3ucMwMAoGCCqGSM49BAMEA4GYADCBlAJIAdPD6jaN\n" + + "vGxkxcsIbcHn2gSfP1F1G8iNJYrXIN91KbQm8OEp4wxqnBwX8gb/3rmSoEhIU/te\n" + + "CcHuFs0guBjfgRWtJ/eDnKB/AkgDbkqrB5wqJFBmVd/rJ5QdwUVNuGP/vDjFVlb6\n" + + "Esny6//gTL7jYubLUKHOPIMftCZ2Jn4b+5l0kAs62HD5XkZLPDTwRbf7VCE=\n" + + "-----END CERTIFICATE-----"; + String key = "-----BEGIN PRIVATE KEY-----\n" + + "MIIBCQIBADAQBgcqhkjOPQIBBgUrgQQAJwSB8TCB7gIBAQRIALNClTXqQWWlYDHw\n" + + "LjNxXpLk17iPepkmablhbxmYX/8CNzoz1o2gcUidoIO2DM9hm7adI/W31EOmSiUJ\n" + + "+UsC/ZH3i2qr0wn+oAcGBSuBBAAnoYGVA4GSAAQE9hahOOVLakIwYQzY/ypmeZRF\n" + + "YDl7tfE98iEcDzvqyOaCMmW4+BPQ6TB67Ejhj9ePNw+ecQE42DMBXVf3dSZd21GI\n" + + "hq7/O+wB4VuP+r7qx+PWN2dSTpCwzHbaQDCmVceZ3PXPlKFdDuYNk/ENuEI8QBRf\n" + + "MjM6q9YhnIAeAXFleZAoSETEDyfGBIi/NDe5wzA=\n" + + "-----END PRIVATE KEY-----"; + ByteArrayInputStream certStream = new ByteArrayInputStream(cert.getBytes(CharsetUtil.US_ASCII)); + ByteArrayInputStream keyStream = new ByteArrayInputStream(key.getBytes(CharsetUtil.US_ASCII)); + final SslContextBuilder builder; + try { + if (server) { + builder = SslContextBuilder.forServer(certStream, keyStream, null); + } else { + builder = SslContextBuilder.forClient().keyManager(certStream, keyStream, null); + } + } catch (IllegalArgumentException e) { + assumeFalse("Input stream not contain valid certificates.".equals(e.getMessage()) + && e.getCause() != null + && "java.io.IOException: Unknown named curve: 1.3.132.0.39".equals( + e.getCause().getMessage()), + "Cannot test that SslProvider rejects certificates with curve " + + "1.3.132.0.39 because the key manager does not know the curve either."); + 
throw e; + } + assertThrows(SSLException.class, new Executable() { + @Override + public void execute() throws Throwable { + builder.sslProvider(SslProvider.OPENSSL).build(); + } + }); + } + + private void testServerContextWithCombinedCertAndKeyInPem(SslProvider provider) throws SSLException { + String pem = "-----BEGIN CERTIFICATE-----\n" + + "MIIB1jCCAX0CCQDq4PSOirh7MDAJBgcqhkjOPQQBMHIxCzAJBgNVBAYTAlVTMQsw\n" + + "CQYDVQQIDAJDQTEMMAoGA1UEBwwDRm9vMQwwCgYDVQQKDANCYXIxDDAKBgNVBAsM\n" + + "A0JhejEQMA4GA1UEAwwHQmFyLmNvbTEaMBgGCSqGSIb3DQEJARYLZm9vQGJhci5j\n" + + "b20wHhcNMjIxMDAyMTYzODAyWhcNMjIxMjAxMTYzODAyWjB2MQswCQYDVQQGEwJV\n" + + "UzELMAkGA1UECAwCQ0ExDDAKBgNVBAcMA0ZvbzEMMAoGA1UECgwDQmFyMQwwCgYD\n" + + "VQQLDANiYXoxFDASBgNVBAMMC2Jhci5iYXIuYmF6MRowGAYJKoZIhvcNAQkBFgtm\n" + + "b29AYmFyLmNvbTBZMBMGByqGSM49AgEGCCqGSM49AwEHA0IABHiEmjPEqQbqXYMB\n" + + "nAPOv24rJf6MhTwHB0QC1suZ9q9XFUkalnqGryqf/emHs81RsXWKz4sCsbIJkmHz\n" + + "H8HYhmkwCQYHKoZIzj0EAQNIADBFAiBCgzxZ5qviemPdejt2WazSgwNJTbirzoQa\n" + + "FMv2XFTTCwIhANS3fZ8BulbYkdRWVEFwm2FGotqLfC60JA/gg/brlWSP\n" + + "-----END CERTIFICATE-----\n" + + "-----BEGIN EC PRIVATE KEY-----\n" + + "MHcCAQEEIF8RlaD0JX8u2Lryq1+AbYfDaTBPJnPSA8+N2L12YuuUoAoGCCqGSM49\n" + + "AwEHoUQDQgAEeISaM8SpBupdgwGcA86/bisl/oyFPAcHRALWy5n2r1cVSRqWeoav\n" + + "Kp/96YezzVGxdYrPiwKxsgmSYfMfwdiGaQ==\n" + + "-----END EC PRIVATE KEY-----"; + + ByteArrayInputStream certStream = new ByteArrayInputStream(pem.getBytes(CharsetUtil.US_ASCII)); + ByteArrayInputStream keyStream = new ByteArrayInputStream(pem.getBytes(CharsetUtil.US_ASCII)); + + SslContext context = SslContextBuilder.forServer(certStream, keyStream, null) + .sslProvider(provider) + .clientAuth(ClientAuth.OPTIONAL) + .build(); + + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertTrue(engine.getWantClientAuth()); + assertFalse(engine.getNeedClientAuth()); + engine.closeInbound(); + engine.closeOutbound(); + } + + @Test + public void testInvalidCipherJdk() throws Exception { + 
OpenSsl.ensureAvailability(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + testInvalidCipher(SslProvider.JDK); + } + }); + } + + @Test + public void testInvalidCipherOpenSSL() throws Exception { + OpenSsl.ensureAvailability(); + try { + // This may fail or not depending on the OpenSSL version used + // See https://github.com/openssl/openssl/issues/7196 + testInvalidCipher(SslProvider.OPENSSL); + if (!OpenSsl.versionString().contains("1.1.1")) { + fail(); + } + } catch (SSLException expected) { + // ok + } + } + + private static void testKeyStoreType(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forServer(cert.certificate(), cert.privateKey()) + .sslProvider(provider) + .keyStoreType("PKCS12"); + SslContext context = builder.build(); + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + engine.closeInbound(); + engine.closeOutbound(); + } + + private static void testInvalidCipher(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forClient() + .sslProvider(provider) + .ciphers(Collections.singleton("SOME_INVALID_CIPHER")) + .keyManager(cert.certificate(), + cert.privateKey()) + .trustManager(cert.certificate()); + SslContext context = builder.build(); + context.newEngine(UnpooledByteBufAllocator.DEFAULT); + } + + private static void testClientContextFromFile(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forClient() + .sslProvider(provider) + .keyManager(cert.certificate(), + cert.privateKey()) + .trustManager(cert.certificate()) + .clientAuth(ClientAuth.OPTIONAL); + SslContext context = builder.build(); + SSLEngine engine = 
context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertFalse(engine.getWantClientAuth()); + assertFalse(engine.getNeedClientAuth()); + engine.closeInbound(); + engine.closeOutbound(); + } + + private static void testClientContext(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forClient() + .sslProvider(provider) + .keyManager(cert.key(), cert.cert()) + .trustManager(cert.cert()) + .clientAuth(ClientAuth.OPTIONAL); + SslContext context = builder.build(); + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertFalse(engine.getWantClientAuth()); + assertFalse(engine.getNeedClientAuth()); + engine.closeInbound(); + engine.closeOutbound(); + } + + private static void testServerContextFromFile(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forServer(cert.certificate(), cert.privateKey()) + .sslProvider(provider) + .trustManager(cert.certificate()) + .clientAuth(ClientAuth.OPTIONAL); + SslContext context = builder.build(); + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertTrue(engine.getWantClientAuth()); + assertFalse(engine.getNeedClientAuth()); + engine.closeInbound(); + engine.closeOutbound(); + } + + private static void testServerContext(SslProvider provider) throws Exception { + SelfSignedCertificate cert = new SelfSignedCertificate(); + SslContextBuilder builder = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(provider) + .trustManager(cert.cert()) + .clientAuth(ClientAuth.REQUIRE); + SslContext context = builder.build(); + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertFalse(engine.getWantClientAuth()); + assertTrue(engine.getNeedClientAuth()); + engine.closeInbound(); + engine.closeOutbound(); + } + + private static void 
testContextFromManagers(SslProvider provider) throws Exception { + final SelfSignedCertificate cert = new SelfSignedCertificate(); + KeyManager customKeyManager = new X509ExtendedKeyManager() { + @Override + public String[] getClientAliases(String s, + Principal[] principals) { + return new String[0]; + } + + @Override + public String chooseClientAlias(String[] strings, + Principal[] principals, + Socket socket) { + return "cert_sent_to_server"; + } + + @Override + public String[] getServerAliases(String s, + Principal[] principals) { + return new String[0]; + } + + @Override + public String chooseServerAlias(String s, + Principal[] principals, + Socket socket) { + return null; + } + + @Override + public X509Certificate[] getCertificateChain(String s) { + X509Certificate[] certificates = new X509Certificate[1]; + certificates[0] = cert.cert(); + return new X509Certificate[0]; + } + + @Override + public PrivateKey getPrivateKey(String s) { + return cert.key(); + } + }; + TrustManager customTrustManager = new X509ExtendedTrustManager() { + @Override + public void checkClientTrusted( + X509Certificate[] x509Certificates, String s, + Socket socket) throws CertificateException { } + + @Override + public void checkServerTrusted( + X509Certificate[] x509Certificates, String s, + Socket socket) throws CertificateException { } + + @Override + public void checkClientTrusted( + X509Certificate[] x509Certificates, String s, + SSLEngine sslEngine) throws CertificateException { } + + @Override + public void checkServerTrusted( + X509Certificate[] x509Certificates, String s, + SSLEngine sslEngine) throws CertificateException { } + + @Override + public void checkClientTrusted( + X509Certificate[] x509Certificates, String s) + throws CertificateException { } + + @Override + public void checkServerTrusted( + X509Certificate[] x509Certificates, String s) + throws CertificateException { } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return new X509Certificate[0]; 
+ } + }; + SslContextBuilder client_builder = SslContextBuilder.forClient() + .sslProvider(provider) + .keyManager(customKeyManager) + .trustManager(customTrustManager) + .clientAuth(ClientAuth.OPTIONAL); + SslContext client_context = client_builder.build(); + SSLEngine client_engine = client_context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertFalse(client_engine.getWantClientAuth()); + assertFalse(client_engine.getNeedClientAuth()); + client_engine.closeInbound(); + client_engine.closeOutbound(); + SslContextBuilder server_builder = SslContextBuilder.forServer(customKeyManager) + .sslProvider(provider) + .trustManager(customTrustManager) + .clientAuth(ClientAuth.REQUIRE); + SslContext server_context = server_builder.build(); + SSLEngine server_engine = server_context.newEngine(UnpooledByteBufAllocator.DEFAULT); + assertFalse(server_engine.getWantClientAuth()); + assertTrue(server_engine.getNeedClientAuth()); + server_engine.closeInbound(); + server_engine.closeOutbound(); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTest.java new file mode 100644 index 0000000..5f43df2 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTest.java @@ -0,0 +1,371 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.util.internal.ResourcesUtil; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.File; +import java.io.IOException; +import java.security.KeyManagementException; +import java.security.NoSuchAlgorithmException; +import java.security.PrivateKey; +import java.security.cert.CertificateException; +import java.security.spec.InvalidKeySpecException; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; + +import static org.junit.jupiter.api.Assumptions.assumeTrue; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public abstract class SslContextTest { + + @Test + public void testUnencryptedEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey( + ResourcesUtil.getFile(getClass(), "test2_unencrypted.pem"), ""); + } + }); + } + + @Test + public void testUnEncryptedNullPassword() throws Exception { + PrivateKey key = SslContext.toPrivateKey( + ResourcesUtil.getFile(getClass(), "test2_unencrypted.pem"), null); + assertNotNull(key); + } + + @Test + public void testEncryptedEmptyPassword() throws Exception { + PrivateKey key = SslContext.toPrivateKey( + ResourcesUtil.getFile(getClass(), "test_encrypted_empty_pass.pem"), ""); + assertNotNull(key); + } + + @Test + public void testEncryptedNullPassword() throws Exception { + assertThrows(InvalidKeySpecException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey( + ResourcesUtil.getFile(getClass(), "test_encrypted_empty_pass.pem"), null); + } + }); + } + + @Test + public void testSslContextWithEncryptedPrivateKey() throws SSLException { + File keyFile = 
ResourcesUtil.getFile(getClass(), "test_encrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + + newSslContext(crtFile, keyFile, "12345"); + } + + @Test + public void testSslContextWithEncryptedPrivateKey2() throws SSLException { + File keyFile = ResourcesUtil.getFile(getClass(), "test2_encrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test2.crt"); + + newSslContext(crtFile, keyFile, "12345"); + } + + @Test + public void testSslContextWithUnencryptedPrivateKey() throws SSLException { + File keyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + + newSslContext(crtFile, keyFile, null); + } + + @Test + public void testSslContextWithUnencryptedPrivateKeyEmptyPass() throws SSLException { + final File keyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + final File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + + assertThrows(SSLException.class, new Executable() { + @Override + public void execute() throws Throwable { + newSslContext(crtFile, keyFile, ""); + } + }); + } + + @Test + public void testSupportedCiphers() throws KeyManagementException, NoSuchAlgorithmException, SSLException { + SSLContext jdkSslContext = SSLContext.getInstance("TLS"); + jdkSslContext.init(null, null, null); + SSLEngine sslEngine = jdkSslContext.createSSLEngine(); + + String unsupportedCipher = "TLS_DH_anon_WITH_DES_CBC_SHA"; + IllegalArgumentException exception = null; + try { + sslEngine.setEnabledCipherSuites(new String[] {unsupportedCipher}); + } catch (IllegalArgumentException e) { + exception = e; + } + assumeTrue(exception != null); + File keyFile = ResourcesUtil.getFile(getClass(), "test_unencrypted.pem"); + File crtFile = ResourcesUtil.getFile(getClass(), "test.crt"); + + SslContext sslContext = newSslContext(crtFile, keyFile, null); + assertFalse(sslContext.cipherSuites().contains(unsupportedCipher)); + } + + @Test + public void 
testUnsupportedParams() throws CertificateException { + assertThrows(CertificateException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toX509Certificates( + new File(getClass().getResource("ec_params_unsupported.pem").getFile())); + } + }); + } + + protected abstract SslContext newSslContext(File crtFile, File keyFile, String pass) throws SSLException; + + @Test + public void testPkcs1UnencryptedRsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey( + new File(getClass().getResource("rsa_pkcs1_unencrypted.key").getFile()), null); + assertNotNull(key); + } + + @Test + public void testPkcs8UnencryptedRsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey( + new File(getClass().getResource("rsa_pkcs8_unencrypted.key").getFile()), null); + assertNotNull(key); + } + + @Test + public void testPkcs8AesEncryptedRsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs8_aes_encrypted.key") + .getFile()), "example"); + assertNotNull(key); + } + + @Test + public void testPkcs8Des3EncryptedRsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs8_des3_encrypted.key") + .getFile()), "example"); + assertNotNull(key); + } + + @Test + public void testPkcs8Pbes2() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new File(getClass().getResource("rsa_pbes2_enc_pkcs8.key") + .getFile()), "12345678", false); + assertNotNull(key); + } + + @Test + public void testPkcs1UnencryptedRsaEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey( + new File(getClass().getResource("rsa_pkcs1_unencrypted.key").getFile()), ""); + } + }); + } + + @Test + public void testPkcs1Des3EncryptedRsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new 
File(getClass().getResource("rsa_pkcs1_des3_encrypted.key") + .getFile()), "example"); + assertNotNull(key); + } + + @Test + public void testPkcs1AesEncryptedRsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_aes_encrypted.key") + .getFile()), "example"); + assertNotNull(key); + } + + @Test + public void testPkcs1Des3EncryptedRsaNoPassword() throws Exception { + assertThrows(InvalidKeySpecException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_des3_encrypted.key") + .getFile()), null); + } + }); + } + + @Test + public void testPkcs1AesEncryptedRsaNoPassword() throws Exception { + assertThrows(InvalidKeySpecException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_aes_encrypted.key") + .getFile()), null); + } + }); + } + + @Test + public void testPkcs1Des3EncryptedRsaEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_des3_encrypted.key") + .getFile()), ""); + } + }); + } + + @Test + public void testPkcs1AesEncryptedRsaEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_aes_encrypted.key") + .getFile()), ""); + } + }); + } + + @Test + public void testPkcs1Des3EncryptedRsaWrongPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_des3_encrypted.key") + .getFile()), "wrong"); + } + }); + } + + @Test + public void 
testPkcs1AesEncryptedRsaWrongPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("rsa_pkcs1_aes_encrypted.key") + .getFile()), "wrong"); + } + }); + } + + @Test + public void testPkcs1UnencryptedDsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey( + new File(getClass().getResource("dsa_pkcs1_unencrypted.key").getFile()), null); + assertNotNull(key); + } + + @Test + public void testPkcs1UnencryptedDsaEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + PrivateKey key = SslContext.toPrivateKey( + new File(getClass().getResource("dsa_pkcs1_unencrypted.key").getFile()), ""); + } + }); + } + + @Test + public void testPkcs1Des3EncryptedDsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_des3_encrypted.key") + .getFile()), "example"); + assertNotNull(key); + } + + @Test + public void testPkcs1AesEncryptedDsa() throws Exception { + PrivateKey key = SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_aes_encrypted.key") + .getFile()), "example"); + assertNotNull(key); + } + + @Test + public void testPkcs1Des3EncryptedDsaNoPassword() throws Exception { + assertThrows(InvalidKeySpecException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_des3_encrypted.key") + .getFile()), null); + } + }); + } + + @Test + public void testPkcs1AesEncryptedDsaNoPassword() throws Exception { + assertThrows(InvalidKeySpecException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_aes_encrypted.key") + .getFile()), null); + } + }); + } + + @Test + public void 
testPkcs1Des3EncryptedDsaEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_des3_encrypted.key") + .getFile()), ""); + } + }); + } + + @Test + public void testPkcs1AesEncryptedDsaEmptyPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_aes_encrypted.key") + .getFile()), ""); + } + }); + } + + @Test + public void testPkcs1Des3EncryptedDsaWrongPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_des3_encrypted.key") + .getFile()), "wrong"); + } + }); + } + + @Test + public void testPkcs1AesEncryptedDsaWrongPassword() throws Exception { + assertThrows(IOException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContext.toPrivateKey(new File(getClass().getResource("dsa_pkcs1_aes_encrypted.key") + .getFile()), "wrong"); + } + }); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTrustManagerTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTrustManagerTest.java new file mode 100644 index 0000000..509ac19 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslContextTrustManagerTest.java @@ -0,0 +1,122 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl; + +import org.junit.jupiter.api.Test; + +import javax.net.ssl.TrustManager; +import javax.net.ssl.TrustManagerFactory; +import javax.net.ssl.X509TrustManager; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Arrays; + +import static io.netty.handler.ssl.Java8SslTestUtils.loadCertCollection; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.fail; + +public class SslContextTrustManagerTest { + @Test + public void testUsingAllCAs() throws Exception { + runTests(new String[] { "tm_test_ca_1a.pem", "tm_test_ca_1b.pem", + "tm_test_ca_2.pem" }, new String[] { "tm_test_eec_1.pem", + "tm_test_eec_2.pem", "tm_test_eec_3.pem" }, new boolean[] { + true, true, true }); + } + + @Test + public void testUsingAllCAsWithDuplicates() throws Exception { + runTests(new String[] { "tm_test_ca_1a.pem", "tm_test_ca_1b.pem", + "tm_test_ca_2.pem", "tm_test_ca_2.pem" }, + new String[] { "tm_test_eec_1.pem", "tm_test_eec_2.pem", + "tm_test_eec_3.pem" }, + new boolean[] { true, true, true }); + } + + @Test + public void testUsingCAsOneAandB() throws Exception { + runTests(new String[] { "tm_test_ca_1a.pem", "tm_test_ca_1b.pem", }, + new String[] { "tm_test_eec_1.pem", "tm_test_eec_2.pem", + "tm_test_eec_3.pem" }, new boolean[] { true, true, + false }); + } + + @Test + public void testUsingCAsOneAandTwo() throws Exception { + runTests(new String[] { "tm_test_ca_1a.pem", "tm_test_ca_2.pem" }, + new String[] { 
"tm_test_eec_1.pem", "tm_test_eec_2.pem", + "tm_test_eec_3.pem" }, new boolean[] { true, false, + true }); + } + + /** + * + * @param caResources + * an array of paths to CA Certificates in PEM format to load + * from the classpath (relative to this class). + * @param eecResources + * an array of paths to Server Certificates in PEM format to + * load from the classpath (relative to this class). + * @param expectations + * an array of expected results for each EEC Server Certificate + * (the array is expected to have the same length as the previous + * argument, and be arranged in matching order: true means + * expected to be valid, false otherwise). + */ + private static void runTests(String[] caResources, String[] eecResources, + boolean[] expectations) throws Exception { + X509TrustManager tm = getTrustManager(caResources); + + X509Certificate[] eecCerts = loadCertCollection(eecResources); + + for (int i = 0; i < eecResources.length; i++) { + X509Certificate eecCert = eecCerts[i]; + assertNotNull(eecCert, "Cannot use cert " + eecResources[i]); + try { + tm.checkServerTrusted(new X509Certificate[] { eecCert }, "RSA"); + if (!expectations[i]) { + fail(String.format( + "Certificate %s was expected not to be valid when using CAs %s, but its " + + "verification passed.", eecResources[i], + Arrays.asList(caResources))); + } + } catch (CertificateException e) { + if (expectations[i]) { + fail(String.format( + "Certificate %s was expected to be valid when using CAs %s, but its " + + "verification failed.", eecResources[i], + Arrays.asList(caResources))); + } + } + } + } + + private static X509TrustManager getTrustManager(String[] resourceNames) + throws Exception { + X509Certificate[] certCollection = loadCertCollection(resourceNames); + TrustManagerFactory tmf = SslContext.buildTrustManagerFactory( + certCollection, null, null); + + for (TrustManager tm : tmf.getTrustManagers()) { + if (tm instanceof X509TrustManager) { + return (X509TrustManager) tm; + } + } + + throw 
new Exception( + "Unable to find any X509TrustManager from this factory."); + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslErrorTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslErrorTest.java new file mode 100644 index 0000000..be73753 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslErrorTest.java @@ -0,0 +1,312 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.logging.LogLevel; +import io.netty.handler.logging.LoggingHandler; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.handler.ssl.util.SimpleTrustManagerFactory; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.EmptyArrays; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import javax.net.ssl.ManagerFactoryParameters; +import javax.net.ssl.SSLException; +import javax.net.ssl.TrustManager; +import javax.net.ssl.X509TrustManager; +import javax.security.auth.x500.X500Principal; +import java.io.File; +import java.security.KeyStore; +import java.security.cert.CRLReason; +import java.security.cert.CertPathValidatorException; +import java.security.cert.CertificateException; +import java.security.cert.CertificateExpiredException; +import java.security.cert.CertificateNotYetValidException; +import java.security.cert.CertificateRevokedException; +import java.security.cert.Extension; +import java.security.cert.X509Certificate; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.Date; +import java.util.List; +import java.util.Locale; +import java.util.concurrent.TimeUnit; + +public class SslErrorTest { + + static Collection 
data() { + List serverProviders = new ArrayList(2); + List clientProviders = new ArrayList(3); + + if (OpenSsl.isAvailable()) { + serverProviders.add(SslProvider.OPENSSL); + serverProviders.add(SslProvider.OPENSSL_REFCNT); + clientProviders.add(SslProvider.OPENSSL); + clientProviders.add(SslProvider.OPENSSL_REFCNT); + } + // We do not test with SslProvider.JDK on the server side as the JDK implementation currently just sends the same + // alert all the time, sigh..... + clientProviders.add(SslProvider.JDK); + + List exceptions = new ArrayList(6); + exceptions.add(new CertificateExpiredException()); + exceptions.add(new CertificateNotYetValidException()); + exceptions.add(new CertificateRevokedException( + new Date(), CRLReason.AA_COMPROMISE, new X500Principal(""), + Collections.emptyMap())); + + // Also use wrapped exceptions as this is what the JDK implementation of X509TrustManagerFactory is doing. + exceptions.add(newCertificateException(CertPathValidatorException.BasicReason.EXPIRED)); + exceptions.add(newCertificateException(CertPathValidatorException.BasicReason.NOT_YET_VALID)); + exceptions.add(newCertificateException(CertPathValidatorException.BasicReason.REVOKED)); + + List params = new ArrayList(); + for (SslProvider serverProvider: serverProviders) { + for (SslProvider clientProvider: clientProviders) { + for (CertificateException exception: exceptions) { + params.add(new Object[] { serverProvider, clientProvider, exception, true }); + params.add(new Object[] { serverProvider, clientProvider, exception, false }); + } + } + } + return params; + } + + private static CertificateException newCertificateException(CertPathValidatorException.Reason reason) { + return new TestCertificateException( + new CertPathValidatorException("x", null, null, -1, reason)); + } + + @ParameterizedTest( + name = "{index}: serverProvider = {0}, clientProvider = {1}, exception = {2}, serverProduceError = {3}") + @MethodSource("data") + @Timeout(value = 30000, unit = 
TimeUnit.MILLISECONDS) + public void testCorrectAlert(SslProvider serverProvider, final SslProvider clientProvider, + final CertificateException exception, final boolean serverProduceError) + throws Exception { + // As this only works correctly at the moment when OpenSslEngine is used on the server-side there is + // no need to run it if there is no openssl is available at all. + OpenSsl.ensureAvailability(); + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + SslContextBuilder sslServerCtxBuilder = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(serverProvider) + .clientAuth(ClientAuth.REQUIRE); + SslContextBuilder sslClientCtxBuilder = SslContextBuilder.forClient() + .keyManager(new File(getClass().getResource("test.crt").getFile()), + new File(getClass().getResource("test_unencrypted.pem").getFile())) + .sslProvider(clientProvider); + + if (serverProduceError) { + sslServerCtxBuilder.trustManager(new ExceptionTrustManagerFactory(exception)); + sslClientCtxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + } else { + sslServerCtxBuilder.trustManager(InsecureTrustManagerFactory.INSTANCE); + sslClientCtxBuilder.trustManager(new ExceptionTrustManagerFactory(exception)); + } + + final SslContext sslServerCtx = sslServerCtxBuilder.build(); + final SslContext sslClientCtx = sslClientCtxBuilder.build(); + + Channel serverChannel = null; + Channel clientChannel = null; + EventLoopGroup group = new NioEventLoopGroup(); + final Promise promise = group.next().newPromise(); + try { + serverChannel = new ServerBootstrap().group(group) + .channel(NioServerSocketChannel.class) + .handler(new LoggingHandler(LogLevel.INFO)) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + if (!serverProduceError) { + ch.pipeline().addLast(new AlertValidationHandler(clientProvider, serverProduceError, + exception, promise)); + } + 
ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.close(); + } + }); + } + }).bind(0).sync().channel(); + + clientChannel = new Bootstrap().group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + if (serverProduceError) { + ch.pipeline().addLast(new AlertValidationHandler(clientProvider, serverProduceError, + exception, promise)); + } + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.close(); + } + }); + } + }).connect(serverChannel.localAddress()).syncUninterruptibly().channel(); + // Block until we received the correct exception + promise.syncUninterruptibly(); + } finally { + if (clientChannel != null) { + clientChannel.close().syncUninterruptibly(); + } + if (serverChannel != null) { + serverChannel.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + private static final class ExceptionTrustManagerFactory extends SimpleTrustManagerFactory { + private final CertificateException exception; + + ExceptionTrustManagerFactory(CertificateException exception) { + this.exception = exception; + } + + @Override + protected void engineInit(KeyStore keyStore) { } + @Override + protected void engineInit(ManagerFactoryParameters managerFactoryParameters) { } + + @Override + protected TrustManager[] engineGetTrustManagers() { + return new TrustManager[] { new X509TrustManager() { + + @Override + public void checkClientTrusted(X509Certificate[] x509Certificates, String s) + throws CertificateException { + throw exception; + } + + @Override + public void checkServerTrusted(X509Certificate[] x509Certificates, 
                                               String s)
                        throws CertificateException {
                    throw exception;
                }

                @Override
                public X509Certificate[] getAcceptedIssuers() {
                    return EmptyArrays.EMPTY_X509_CERTIFICATES;
                }
            } };
        }
    }

    // Inspects the SSLException that reaches the non-failing peer and verifies that its message
    // matches the alert expected for the configured CertificateException (expired / revoked / bad).
    private static final class AlertValidationHandler extends ChannelInboundHandlerAdapter {
        private final SslProvider clientProvider;
        private final boolean serverProduceError;
        private final CertificateException exception;
        private final Promise promise; // completed once the alert message was (or was not) matched

        AlertValidationHandler(SslProvider clientProvider, boolean serverProduceError,
                               CertificateException exception, Promise promise) {
            this.clientProvider = clientProvider;
            this.serverProduceError = serverProduceError;
            this.exception = exception;
            this.promise = promise;
        }

        @Override
        public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) {
            // Unwrap as its wrapped by a DecoderException
            Throwable unwrappedCause = cause.getCause();
            if (unwrappedCause instanceof SSLException) {
                if (exception instanceof TestCertificateException) {
                    // TestCertificateException wraps a CertPathValidatorException; dispatch on its reason.
                    CertPathValidatorException.Reason reason =
                            ((CertPathValidatorException) exception.getCause()).getReason();
                    if (reason == CertPathValidatorException.BasicReason.EXPIRED) {
                        verifyException(clientProvider, serverProduceError, unwrappedCause, promise, "expired");
                    } else if (reason == CertPathValidatorException.BasicReason.NOT_YET_VALID) {
                        // BoringSSL may use "expired" in this case while others use "bad"
                        verifyException(clientProvider, serverProduceError, unwrappedCause, promise, "expired", "bad");
                    } else if (reason == CertPathValidatorException.BasicReason.REVOKED) {
                        verifyException(clientProvider, serverProduceError, unwrappedCause, promise, "revoked");
                    }
                } else if (exception instanceof CertificateExpiredException) {
                    verifyException(clientProvider, serverProduceError, unwrappedCause, promise, "expired");
                } else if (exception instanceof CertificateNotYetValidException) {
                    // BoringSSL may use "expired" in this case while others use "bad"
                    verifyException(clientProvider, serverProduceError, unwrappedCause, promise, "expired", "bad");
                } else if (exception instanceof CertificateRevokedException) {
                    verifyException(clientProvider, serverProduceError, unwrappedCause, promise, "revoked");
                }
            }
        }
    }

    // Its a bit hacky to verify against the message that is part of the exception but there is no other way
    // at the moment as there are no different exceptions for the different alerts.
    private static void verifyException(SslProvider clientProvider, boolean serverProduceError,
                                        Throwable cause, Promise promise, String... messageParts) {
        String message = cause.getMessage();
        // When the error is produced on the client side and the client side uses JDK as provider it will always
        // use "certificate unknown".
        if (!serverProduceError && clientProvider == SslProvider.JDK &&
                message.toLowerCase(Locale.UK).contains("unknown")) {
            promise.setSuccess(null);
            return;
        }

        // Succeed as soon as any expected fragment appears (case-insensitive).
        for (String m: messageParts) {
            if (message.toLowerCase(Locale.UK).contains(m.toLowerCase(Locale.UK))) {
                promise.setSuccess(null);
                return;
            }
        }
        Throwable error = new AssertionError("message not contains any of '"
                + Arrays.toString(messageParts) + "': " + message);
        error.initCause(cause);
        promise.setFailure(error);
    }

    // Marker exception: signals to AlertValidationHandler that the expected alert must be derived
    // from the wrapped CertPathValidatorException's reason rather than from the exception type.
    private static final class TestCertificateException extends CertificateException {
        private static final long serialVersionUID = -5816338303868751410L;

        TestCertificateException(Throwable cause) {
            super(cause);
        }
    }
}
diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslHandlerTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslHandlerTest.java
new file mode 100644
index 0000000..c1f4d73
--- /dev/null
+++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslHandlerTest.java
@@ -0,0 +1,1885 @@
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the
"License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.buffer.UnpooledByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelId; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.CodecException; +import io.netty.handler.codec.DecoderException; +import io.netty.handler.codec.UnsupportedMessageTypeException; 
+import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.AbstractReferenceCounted; +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.ReferenceCounted; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.ImmediateExecutor; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.PlatformDependent; +import org.hamcrest.CoreMatchers; +import org.hamcrest.Matchers; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.net.InetSocketAddress; +import java.net.Socket; +import java.nio.channels.ClosedChannelException; +import java.security.NoSuchAlgorithmException; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; +import java.util.Collections; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import javax.net.ssl.SSLProtocolException; +import javax.net.ssl.X509ExtendedTrustManager; + +import static io.netty.buffer.Unpooled.wrappedBuffer; +import static org.hamcrest.CoreMatchers.*; +import static 
org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeFalse; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class SslHandlerTest { + + private static final Executor DIRECT_EXECUTOR = new Executor() { + @Override + public void execute(Runnable command) { + command.run(); + } + }; + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testNonApplicationDataFailureFailsQueuedWrites() throws NoSuchAlgorithmException, InterruptedException { + final CountDownLatch writeLatch = new CountDownLatch(1); + final Queue writesToFail = new ConcurrentLinkedQueue(); + SSLEngine engine = newClientModeSSLEngine(); + SslHandler handler = new SslHandler(engine) { + @Override + public void write(final ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + super.write(ctx, msg, promise); + writeLatch.countDown(); + } + }; + EmbeddedChannel ch = new EmbeddedChannel(new ChannelDuplexHandler() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (msg instanceof ByteBuf) { + if (((ByteBuf) msg).isReadable()) { + writesToFail.add(promise); + } else { + promise.setSuccess(); + } + } + ReferenceCountUtil.release(msg); + } + }, handler); + + try { + final CountDownLatch writeCauseLatch = new CountDownLatch(1); + final AtomicReference failureRef = new AtomicReference(); + ch.write(Unpooled.wrappedBuffer(new byte[]{1})).addListener(new ChannelFutureListener() { + @Override + public void 
operationComplete(ChannelFuture future) { + failureRef.compareAndSet(null, future.cause()); + writeCauseLatch.countDown(); + } + }); + writeLatch.await(); + + // Simulate failing the SslHandler non-application writes after there are applications writes queued. + ChannelPromise promiseToFail; + while ((promiseToFail = writesToFail.poll()) != null) { + promiseToFail.setFailure(new RuntimeException("fake exception")); + } + + writeCauseLatch.await(); + Throwable writeCause = failureRef.get(); + assertNotNull(writeCause); + assertThat(writeCause, is(CoreMatchers.instanceOf(SSLException.class))); + Throwable cause = handler.handshakeFuture().cause(); + assertNotNull(cause); + assertThat(cause, is(CoreMatchers.instanceOf(SSLException.class))); + } finally { + assertFalse(ch.finishAndReleaseAll()); + } + } + + @Test + public void testNoSslHandshakeEventWhenNoHandshake() throws Exception { + final AtomicBoolean inActive = new AtomicBoolean(false); + + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + EmbeddedChannel ch = new EmbeddedChannel( + DefaultChannelId.newInstance(), false, false, new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // Not forward the event to the SslHandler but just close the Channel. + ctx.close(); + } + }, new SslHandler(engine) { + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + // We want to override what Channel.isActive() will return as otherwise it will + // return true and so trigger an handshake. 
+ inActive.set(true); + super.handlerAdded(ctx); + inActive.set(false); + } + }, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + throw (Exception) ((SslHandshakeCompletionEvent) evt).cause(); + } + } + }) { + @Override + public boolean isActive() { + return !inActive.get() && super.isActive(); + } + }; + + ch.register(); + assertFalse(ch.finishAndReleaseAll()); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testClientHandshakeTimeout() throws Exception { + assertThrows(SslHandshakeTimeoutException.class, new Executable() { + @Override + public void execute() throws Throwable { + testHandshakeTimeout(true); + } + }); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testServerHandshakeTimeout() throws Exception { + assertThrows(SslHandshakeTimeoutException.class, new Executable() { + @Override + public void execute() throws Throwable { + testHandshakeTimeout(false); + } + }); + } + + private static SSLEngine newServerModeSSLEngine() throws NoSuchAlgorithmException { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + // Set the mode before we try to do the handshake as otherwise it may throw an IllegalStateException. + // See: + // - https://docs.oracle.com/javase/10/docs/api/javax/net/ssl/SSLEngine.html#beginHandshake() + // - https://mail.openjdk.java.net/pipermail/security-dev/2018-July/017715.html + engine.setUseClientMode(false); + return engine; + } + + private static SSLEngine newClientModeSSLEngine() throws NoSuchAlgorithmException { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + // Set the mode before we try to do the handshake as otherwise it may throw an IllegalStateException. 
+ // See: + // - https://docs.oracle.com/javase/10/docs/api/javax/net/ssl/SSLEngine.html#beginHandshake() + // - https://mail.openjdk.java.net/pipermail/security-dev/2018-July/017715.html + engine.setUseClientMode(true); + return engine; + } + + private static void testHandshakeTimeout(boolean client) throws Exception { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + engine.setUseClientMode(client); + SslHandler handler = new SslHandler(engine); + handler.setHandshakeTimeoutMillis(1); + + EmbeddedChannel ch = new EmbeddedChannel(handler); + try { + while (!handler.handshakeFuture().isDone()) { + Thread.sleep(10); + // We need to run all pending tasks as the handshake timeout is scheduled on the EventLoop. + ch.runPendingTasks(); + } + + handler.handshakeFuture().syncUninterruptibly(); + } finally { + ch.finishAndReleaseAll(); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testHandshakeAndClosePromiseFailedOnRemoval() throws Exception { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + engine.setUseClientMode(true); + SslHandler handler = new SslHandler(engine); + final AtomicReference handshakeRef = new AtomicReference(); + final AtomicReference closeRef = new AtomicReference(); + EmbeddedChannel ch = new EmbeddedChannel(handler, new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + handshakeRef.set(((SslHandshakeCompletionEvent) evt).cause()); + } else if (evt instanceof SslCloseCompletionEvent) { + closeRef.set(((SslCloseCompletionEvent) evt).cause()); + } + } + }); + assertFalse(handler.handshakeFuture().isDone()); + assertFalse(handler.sslCloseFuture().isDone()); + + ch.pipeline().remove(handler); + + try { + while (!handler.handshakeFuture().isDone() || handshakeRef.get() == null + || !handler.sslCloseFuture().isDone() || closeRef.get() == null) { + Thread.sleep(10); 
                // Continue running all pending tasks until we notified for everything.
                ch.runPendingTasks();
            }

            assertSame(handler.handshakeFuture().cause(), handshakeRef.get());
            assertSame(handler.sslCloseFuture().cause(), closeRef.get());
        } finally {
            ch.finishAndReleaseAll();
        }
    }

    // Feeds a TLS handshake record in two chunks; the first (header-only) chunk must decode to
    // nothing, and the bogus second chunk must surface as a DecoderException caused by
    // SSLProtocolException.
    @Test
    public void testTruncatedPacket() throws Exception {
        SSLEngine engine = newServerModeSSLEngine();
        final EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine));

        // Push the first part of a 5-byte handshake message.
        ch.writeInbound(wrappedBuffer(new byte[]{22, 3, 1, 0, 5}));

        // Should decode nothing yet.
        assertThat(ch.readInbound(), is(nullValue()));

        DecoderException e = assertThrows(DecoderException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                // Push the second part of the 5-byte handshake message.
                ch.writeInbound(wrappedBuffer(new byte[]{2, 0, 0, 1, 0}));
            }
        });
        // Be sure we cleanup the channel and release any pending messages that may have been generated because
        // of an alert.
        // See https://github.com/netty/netty/issues/6057.
        ch.finishAndReleaseAll();

        // The pushed message is invalid, so it should raise an exception if it decoded the message correctly.
        assertThat(e.getCause(), is(instanceOf(SSLProtocolException.class)));
    }

    // A non-ByteBuf write must be rejected with UnsupportedMessageTypeException AND the message
    // must still be released (refCnt drops to 0).
    @Test
    public void testNonByteBufWriteIsReleased() throws Exception {
        SSLEngine engine = newServerModeSSLEngine();
        final EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine));

        final AbstractReferenceCounted referenceCounted = new AbstractReferenceCounted() {
            @Override
            public ReferenceCounted touch(Object hint) {
                return this;
            }

            @Override
            protected void deallocate() {
            }
        };

        ExecutionException e = assertThrows(ExecutionException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                ch.write(referenceCounted).get();
            }
        });
        assertThat(e.getCause(), is(instanceOf(UnsupportedMessageTypeException.class)));
        // The failed write must not leak the reference-counted message.
        assertEquals(0, referenceCounted.refCnt());
        assertTrue(ch.finishAndReleaseAll());
    }

    // SslHandler must not pass arbitrary (non-ByteBuf) messages through the pipeline.
    @Test
    public void testNonByteBufNotPassThrough() throws Exception {
        SSLEngine engine = newServerModeSSLEngine();
        final EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine));

        assertThrows(UnsupportedMessageTypeException.class, new Executable() {
            @Override
            public void execute() throws Throwable {
                ch.writeOutbound(new Object());
            }
        });
        ch.finishAndReleaseAll();
    }

    // A write queued before the handshake completes must stay pending, and must be failed with an
    // SSLException when the channel is torn down before the handshake finished.
    @Test
    public void testIncompleteWriteDoesNotCompletePromisePrematurely() throws NoSuchAlgorithmException {
        SSLEngine engine = newServerModeSSLEngine();
        EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(engine));

        ChannelPromise promise = ch.newPromise();
        ByteBuf buf = Unpooled.buffer(10).writeZero(10);
        ch.writeAndFlush(buf, promise);
        assertFalse(promise.isDone());
        assertTrue(ch.finishAndReleaseAll());
        assertTrue(promise.isDone());
        assertThat(promise.cause(), is(instanceOf(SSLException.class)));
    }

    @Test
    public void testReleaseSslEngine() throws Exception {
        OpenSsl.ensureAvailability();

        SelfSignedCertificate cert = new SelfSignedCertificate();
        try {
            SslContext sslContext =
SslContextBuilder.forServer(cert.certificate(), cert.privateKey()) + .sslProvider(SslProvider.OPENSSL) + .build(); + try { + assertEquals(1, ((ReferenceCounted) sslContext).refCnt()); + SSLEngine sslEngine = sslContext.newEngine(ByteBufAllocator.DEFAULT); + EmbeddedChannel ch = new EmbeddedChannel(new SslHandler(sslEngine)); + + assertEquals(2, ((ReferenceCounted) sslContext).refCnt()); + assertEquals(1, ((ReferenceCounted) sslEngine).refCnt()); + + assertTrue(ch.finishAndReleaseAll()); + ch.close().syncUninterruptibly(); + + assertEquals(1, ((ReferenceCounted) sslContext).refCnt()); + assertEquals(0, ((ReferenceCounted) sslEngine).refCnt()); + } finally { + ReferenceCountUtil.release(sslContext); + } + } finally { + cert.delete(); + } + } + + private static final class TlsReadTest extends ChannelOutboundHandlerAdapter { + private volatile boolean readIssued; + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + readIssued = true; + super.read(ctx); + } + + public void test(final boolean dropChannelActive) throws Exception { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + engine.setUseClientMode(true); + + EmbeddedChannel ch = new EmbeddedChannel(false, false, + this, + new SslHandler(engine), + new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (!dropChannelActive) { + ctx.fireChannelActive(); + } + } + } + ); + ch.config().setAutoRead(false); + assertFalse(ch.config().isAutoRead()); + + ch.register(); + + assertTrue(readIssued); + readIssued = false; + + assertTrue(ch.writeOutbound(Unpooled.EMPTY_BUFFER)); + assertTrue(readIssued); + assertTrue(ch.finishAndReleaseAll()); + } + } + + @Test + public void testIssueReadAfterActiveWriteFlush() throws Exception { + // the handshake is initiated by channelActive + new TlsReadTest().test(false); + } + + @Test + public void testIssueReadAfterWriteFlushActive() throws Exception { + // the handshake is 
initiated by flush + new TlsReadTest().test(true); + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testRemoval() throws Exception { + NioEventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + try { + final Promise clientPromise = group.next().newPromise(); + Bootstrap bootstrap = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(newHandler(SslContextBuilder.forClient().trustManager( + InsecureTrustManagerFactory.INSTANCE).build(), clientPromise)); + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final Promise serverPromise = group.next().newPromise(); + ServerBootstrap serverBootstrap = new ServerBootstrap() + .group(group, group) + .channel(NioServerSocketChannel.class) + .childHandler(newHandler(SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(), + serverPromise)); + sc = serverBootstrap.bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + cc = bootstrap.connect(sc.localAddress()).syncUninterruptibly().channel(); + + serverPromise.syncUninterruptibly(); + clientPromise.syncUninterruptibly(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + } + } + + private static ChannelHandler newHandler(final SslContext sslCtx, final Promise promise) { + return new ChannelInitializer() { + @Override + protected void initChannel(final Channel ch) { + final SslHandler sslHandler = sslCtx.newHandler(ch.alloc()); + sslHandler.setHandshakeTimeoutMillis(1000); + ch.pipeline().addFirst(sslHandler); + sslHandler.handshakeFuture().addListener(new FutureListener() { + @Override + public void operationComplete(final Future future) { + ch.pipeline().remove(sslHandler); + + // Schedule the close so removal has time to propagate exception if any. 
+ ch.eventLoop().execute(new Runnable() { + @Override + public void run() { + ch.close(); + } + }); + } + }); + + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + if (cause instanceof CodecException) { + cause = cause.getCause(); + } + if (cause instanceof IllegalReferenceCountException) { + promise.setFailure(cause); + } + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + promise.trySuccess(null); + } + }); + } + }; + } + + @Test + public void testCloseFutureNotified() throws Exception { + SSLEngine engine = newServerModeSSLEngine(); + SslHandler handler = new SslHandler(engine); + EmbeddedChannel ch = new EmbeddedChannel(handler); + + ch.close(); + + // When the channel is closed the SslHandler will write an empty buffer to the channel. + ByteBuf buf = ch.readOutbound(); + assertFalse(buf.isReadable()); + buf.release(); + + assertFalse(ch.finishAndReleaseAll()); + + assertThat(handler.handshakeFuture().cause(), instanceOf(ClosedChannelException.class)); + assertThat(handler.sslCloseFuture().cause(), instanceOf(ClosedChannelException.class)); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testEventsFired() throws Exception { + SSLEngine engine = newServerModeSSLEngine(); + final BlockingQueue events = new LinkedBlockingQueue(); + EmbeddedChannel channel = new EmbeddedChannel(new SslHandler(engine), new ChannelInboundHandlerAdapter() { + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslCompletionEvent) { + events.add((SslCompletionEvent) evt); + } + } + }); + assertTrue(events.isEmpty()); + assertTrue(channel.finishAndReleaseAll()); + + SslCompletionEvent evt = events.take(); + assertTrue(evt instanceof SslHandshakeCompletionEvent); + assertThat(evt.cause(), instanceOf(ClosedChannelException.class)); + + evt = events.take(); + 
assertTrue(evt instanceof SslCloseCompletionEvent); + assertThat(evt.cause(), instanceOf(ClosedChannelException.class)); + assertTrue(events.isEmpty()); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testHandshakeFailBeforeWritePromise() throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(); + final CountDownLatch latch = new CountDownLatch(2); + final CountDownLatch latch2 = new CountDownLatch(2); + final BlockingQueue events = new LinkedBlockingQueue(); + Channel serverChannel = null; + Channel clientChannel = null; + EventLoopGroup group = new DefaultEventLoopGroup(); + try { + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ByteBuf buf = ctx.alloc().buffer(10); + buf.writeZero(buf.capacity()); + ctx.writeAndFlush(buf).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + events.add(future); + latch.countDown(); + } + }); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslCompletionEvent) { + events.add(evt); + latch.countDown(); + latch2.countDown(); + } + } + }); + } + }); + + Bootstrap cb = new Bootstrap(); + cb.group(group) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addFirst(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ByteBuf buf = ctx.alloc().buffer(1000); + 
buf.writeZero(buf.capacity()); + ctx.writeAndFlush(buf); + } + }); + } + }); + + serverChannel = sb.bind(new LocalAddress("SslHandlerTest")).sync().channel(); + clientChannel = cb.connect(serverChannel.localAddress()).sync().channel(); + latch.await(); + + SslCompletionEvent evt = (SslCompletionEvent) events.take(); + assertTrue(evt instanceof SslHandshakeCompletionEvent); + assertThat(evt.cause(), is(instanceOf(SSLException.class))); + + ChannelFuture future = (ChannelFuture) events.take(); + assertThat(future.cause(), is(instanceOf(SSLException.class))); + + serverChannel.close().sync(); + serverChannel = null; + clientChannel.close().sync(); + clientChannel = null; + + latch2.await(); + evt = (SslCompletionEvent) events.take(); + assertTrue(evt instanceof SslCloseCompletionEvent); + assertThat(evt.cause(), is(instanceOf(ClosedChannelException.class))); + assertTrue(events.isEmpty()); + } finally { + if (serverChannel != null) { + serverChannel.close(); + } + if (clientChannel != null) { + clientChannel.close(); + } + group.shutdownGracefully(); + } + } + + @Test + public void writingReadOnlyBufferDoesNotBreakAggregation() throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()).build(); + + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final CountDownLatch serverReceiveLatch = new CountDownLatch(1); + try { + final int expectedBytes = 11; + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + ch.pipeline().addLast(new SimpleChannelInboundHandler() { + private int 
readBytes; + @Override + protected void channelRead0(ChannelHandlerContext ctx, ByteBuf msg) throws Exception { + readBytes += msg.readableBytes(); + if (readBytes >= expectedBytes) { + serverReceiveLatch.countDown(); + } + } + }); + } + }).bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslClientCtx.newHandler(ch.alloc())); + } + }).connect(sc.localAddress()).syncUninterruptibly().channel(); + + // We first write a ReadOnlyBuffer because SslHandler will attempt to take the first buffer and append to it + // until there is no room, or the aggregation size threshold is exceeded. We want to verify that we don't + // throw when a ReadOnlyBuffer is used and just verify that we don't aggregate in this case. + ByteBuf firstBuffer = Unpooled.buffer(10); + firstBuffer.writeByte(0); + firstBuffer = firstBuffer.asReadOnly(); + ByteBuf secondBuffer = Unpooled.buffer(10); + secondBuffer.writeZero(secondBuffer.capacity()); + cc.write(firstBuffer); + cc.writeAndFlush(secondBuffer).syncUninterruptibly(); + serverReceiveLatch.countDown(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testCloseOnHandshakeFailure() throws Exception { + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + + final SslContext sslServerCtx = SslContextBuilder.forServer(ssc.key(), ssc.cert()).build(); + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(new SelfSignedCertificate().cert()) + .build(); + + EventLoopGroup group = new NioEventLoopGroup(1); + 
Channel sc = null; + Channel cc = null; + try { + LocalAddress address = new LocalAddress(getClass().getSimpleName() + ".testCloseOnHandshakeFailure"); + ServerBootstrap sb = new ServerBootstrap() + .group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(sslServerCtx.newHandler(ch.alloc())); + } + }); + sc = sb.bind(address).syncUninterruptibly().channel(); + + final AtomicReference sslHandlerRef = new AtomicReference(); + Bootstrap b = new Bootstrap() + .group(group) + .channel(LocalChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + SslHandler handler = sslClientCtx.newHandler(ch.alloc()); + + // We propagate the SslHandler via an AtomicReference to the outer-scope as using + // pipeline.get(...) may return null if the pipeline was teared down by the time we call it. + // This will happen if the channel was closed in the meantime. 
+ sslHandlerRef.set(handler); + ch.pipeline().addLast(handler); + } + }); + cc = b.connect(sc.localAddress()).syncUninterruptibly().channel(); + SslHandler handler = sslHandlerRef.get(); + handler.handshakeFuture().awaitUninterruptibly(); + assertFalse(handler.handshakeFuture().isSuccess()); + + cc.closeFuture().syncUninterruptibly(); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslServerCtx); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + public void testOutboundClosedAfterChannelInactive() throws Exception { + SslContext context = SslContextBuilder.forClient().build(); + SSLEngine engine = context.newEngine(UnpooledByteBufAllocator.DEFAULT); + + EmbeddedChannel channel = new EmbeddedChannel(); + assertFalse(channel.finish()); + channel.pipeline().addLast(new SslHandler(engine)); + assertFalse(engine.isOutboundDone()); + channel.close().syncUninterruptibly(); + + assertTrue(engine.isOutboundDone()); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testHandshakeFailedByWriteBeforeChannelActive() throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .protocols(SslProtocols.SSL_v3) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final CountDownLatch activeLatch = new CountDownLatch(1); + final AtomicReference errorRef = new AtomicReference(); + final SslHandler sslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInboundHandlerAdapter()) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + cc = new Bootstrap() + .group(group) + 
.channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslHandler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) + throws Exception { + if (cause instanceof AssertionError) { + errorRef.set((AssertionError) cause); + } + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + activeLatch.countDown(); + } + }); + } + }).connect(sc.localAddress()).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Write something to trigger the handshake before fireChannelActive is called. + future.channel().writeAndFlush(wrappedBuffer(new byte [] { 1, 2, 3, 4 })); + } + }).syncUninterruptibly().channel(); + + // Ensure there is no AssertionError thrown by having the handshake failed by the writeAndFlush(...) before + // channelActive(...) was called. Let's first wait for the activeLatch countdown to happen and after this + // check if we saw and AssertionError (even if we timed out waiting). 
+ activeLatch.await(5, TimeUnit.SECONDS); + AssertionError error = errorRef.get(); + if (error != null) { + throw error; + } + assertThat(sslHandler.handshakeFuture().await().cause(), + CoreMatchers.instanceOf(SSLException.class)); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testHandshakeTimeoutFlushStartsHandshake() throws Exception { + testHandshakeTimeout0(false); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testHandshakeTimeoutStartTLS() throws Exception { + testHandshakeTimeout0(true); + } + + private static void testHandshakeTimeout0(final boolean startTls) throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .startTls(true) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler sslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + sslHandler.setHandshakeTimeout(500, TimeUnit.MILLISECONDS); + + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInboundHandlerAdapter()) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(sslHandler); + if (startTls) { + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(wrappedBuffer(new byte[] { 1, 2, 3, 4 })); + } + }); 
+ } + } + }).connect(sc.localAddress()); + if (!startTls) { + future.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // Write something to trigger the handshake before fireChannelActive is called. + future.channel().writeAndFlush(wrappedBuffer(new byte [] { 1, 2, 3, 4 })); + } + }); + } + cc = future.syncUninterruptibly().channel(); + + Throwable cause = sslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(cause.getMessage(), containsString("timed out")); + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + public void testHandshakeWithExecutorThatExecuteDirectlyJDK() throws Throwable { + testHandshakeWithExecutor(DIRECT_EXECUTOR, SslProvider.JDK, false); + } + + @Test + public void testHandshakeWithImmediateExecutorJDK() throws Throwable { + testHandshakeWithExecutor(ImmediateExecutor.INSTANCE, SslProvider.JDK, false); + } + + @Test + public void testHandshakeWithImmediateEventExecutorJDK() throws Throwable { + testHandshakeWithExecutor(ImmediateEventExecutor.INSTANCE, SslProvider.JDK, false); + } + + @Test + public void testHandshakeWithExecutorJDK() throws Throwable { + ExecutorService executorService = Executors.newCachedThreadPool(); + try { + testHandshakeWithExecutor(executorService, SslProvider.JDK, false); + } finally { + executorService.shutdown(); + } + } + + @Test + public void testHandshakeWithExecutorThatExecuteDirectlyOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + testHandshakeWithExecutor(DIRECT_EXECUTOR, SslProvider.OPENSSL, false); + } + + @Test + public void testHandshakeWithImmediateExecutorOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + testHandshakeWithExecutor(ImmediateExecutor.INSTANCE, 
SslProvider.OPENSSL, false); + } + + @Test + public void testHandshakeWithImmediateEventExecutorOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + testHandshakeWithExecutor(ImmediateEventExecutor.INSTANCE, SslProvider.OPENSSL, false); + } + + @Test + public void testHandshakeWithExecutorOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + ExecutorService executorService = Executors.newCachedThreadPool(); + try { + testHandshakeWithExecutor(executorService, SslProvider.OPENSSL, false); + } finally { + executorService.shutdown(); + } + } + + @Test + public void testHandshakeMTLSWithExecutorThatExecuteDirectlyJDK() throws Throwable { + testHandshakeWithExecutor(DIRECT_EXECUTOR, SslProvider.JDK, true); + } + + @Test + public void testHandshakeMTLSWithImmediateExecutorJDK() throws Throwable { + testHandshakeWithExecutor(ImmediateExecutor.INSTANCE, SslProvider.JDK, true); + } + + @Test + public void testHandshakeMTLSWithImmediateEventExecutorJDK() throws Throwable { + testHandshakeWithExecutor(ImmediateEventExecutor.INSTANCE, SslProvider.JDK, true); + } + + @Test + public void testHandshakeMTLSWithExecutorJDK() throws Throwable { + ExecutorService executorService = Executors.newCachedThreadPool(); + try { + testHandshakeWithExecutor(executorService, SslProvider.JDK, true); + } finally { + executorService.shutdown(); + } + } + + @Test + public void testHandshakeMTLSWithExecutorThatExecuteDirectlyOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + testHandshakeWithExecutor(DIRECT_EXECUTOR, SslProvider.OPENSSL, true); + } + + @Test + public void testHandshakeMTLSWithImmediateExecutorOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + testHandshakeWithExecutor(ImmediateExecutor.INSTANCE, SslProvider.OPENSSL, true); + } + + @Test + public void testHandshakeMTLSWithImmediateEventExecutorOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + testHandshakeWithExecutor(ImmediateEventExecutor.INSTANCE, SslProvider.OPENSSL, true); 
+ } + + @Test + public void testHandshakeMTLSWithExecutorOpenSsl() throws Throwable { + OpenSsl.ensureAvailability(); + ExecutorService executorService = Executors.newCachedThreadPool(); + try { + testHandshakeWithExecutor(executorService, SslProvider.OPENSSL, true); + } finally { + executorService.shutdown(); + } + } + + private static void testHandshakeWithExecutor(Executor executor, SslProvider provider, boolean mtls) + throws Throwable { + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslClientCtx; + final SslContext sslServerCtx; + if (mtls) { + sslClientCtx = SslContextBuilder.forClient().protocols(SslProtocols.TLS_v1_2) + .trustManager(InsecureTrustManagerFactory.INSTANCE).keyManager(cert.key(), cert.cert()) + .sslProvider(provider).build(); + + sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()).protocols(SslProtocols.TLS_v1_2) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .clientAuth(ClientAuth.REQUIRE) + .sslProvider(provider).build(); + } else { + sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(provider).build(); + + sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(provider).build(); + } + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler clientSslHandler = new SslHandler( + sslClientCtx.newEngine(UnpooledByteBufAllocator.DEFAULT), executor); + final SslHandler serverSslHandler = new SslHandler( + sslServerCtx.newEngine(UnpooledByteBufAllocator.DEFAULT), executor); + final AtomicReference causeRef = new AtomicReference(); + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(serverSslHandler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public 
void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + causeRef.compareAndSet(null, cause); + } + }); + } + }) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientSslHandler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + causeRef.compareAndSet(null, cause); + } + }); + } + }).connect(sc.localAddress()); + cc = future.syncUninterruptibly().channel(); + + assertTrue(clientSslHandler.handshakeFuture().await().isSuccess()); + assertTrue(serverSslHandler.handshakeFuture().await().isSuccess()); + Throwable cause = causeRef.get(); + if (cause != null) { + throw cause; + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + public void testClientHandshakeTimeoutBecauseExecutorNotExecute() throws Exception { + testHandshakeTimeoutBecauseExecutorNotExecute(true); + } + + @Test + public void testServerHandshakeTimeoutBecauseExecutorNotExecute() throws Exception { + testHandshakeTimeoutBecauseExecutorNotExecute(false); + } + + private static void testHandshakeTimeoutBecauseExecutorNotExecute(final boolean client) throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(SslProvider.JDK).build(); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + 
Channel cc = null; + final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, new Executor() { + @Override + public void execute(Runnable command) { + if (!client) { + command.run(); + } + // Do nothing to simulate slow execution. + } + }); + if (client) { + clientSslHandler.setHandshakeTimeout(100, TimeUnit.MILLISECONDS); + } + final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, new Executor() { + @Override + public void execute(Runnable command) { + if (client) { + command.run(); + } + // Do nothing to simulate slow execution. + } + }); + if (!client) { + serverSslHandler.setHandshakeTimeout(100, TimeUnit.MILLISECONDS); + } + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(serverSslHandler) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientSslHandler); + } + }).connect(sc.localAddress()); + cc = future.syncUninterruptibly().channel(); + + if (client) { + Throwable cause = clientSslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SslHandshakeTimeoutException.class)); + assertFalse(serverSslHandler.handshakeFuture().await().isSuccess()); + } else { + Throwable cause = serverSslHandler.handshakeFuture().await().cause(); + assertThat(cause, CoreMatchers.instanceOf(SslHandshakeTimeoutException.class)); + assertFalse(clientSslHandler.handshakeFuture().await().isSuccess()); + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void 
testSessionTicketsWithTLSv12() throws Throwable { + testSessionTickets(SslProvider.OPENSSL, SslProtocols.TLS_v1_2, true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testSessionTicketsWithTLSv13() throws Throwable { + assumeTrue(SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); + testSessionTickets(SslProvider.OPENSSL, SslProtocols.TLS_v1_3, true); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testSessionTicketsWithTLSv12AndNoKey() throws Throwable { + testSessionTickets(SslProvider.OPENSSL, SslProtocols.TLS_v1_2, false); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testSessionTicketsWithTLSv13AndNoKey() throws Throwable { + assumeTrue(OpenSsl.isTlsv13Supported()); + testSessionTickets(SslProvider.OPENSSL, SslProtocols.TLS_v1_3, false); + } + + private static void testSessionTickets(SslProvider provider, String protocol, boolean withKey) throws Throwable { + OpenSsl.ensureAvailability(); + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .sslProvider(provider) + .protocols(protocol) + .build(); + + // Explicit enable session cache as it's disabled by default atm. 
+ ((OpenSslContext) sslClientCtx).sessionContext() + .setSessionCacheEnabled(true); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(provider) + .protocols(protocol) + .build(); + + if (withKey) { + OpenSslSessionTicketKey key = new OpenSslSessionTicketKey(new byte[OpenSslSessionTicketKey.NAME_SIZE], + new byte[OpenSslSessionTicketKey.HMAC_KEY_SIZE], new byte[OpenSslSessionTicketKey.AES_KEY_SIZE]); + ((OpenSslSessionContext) sslClientCtx.sessionContext()).setTicketKeys(key); + ((OpenSslSessionContext) sslServerCtx.sessionContext()).setTicketKeys(key); + } else { + ((OpenSslSessionContext) sslClientCtx.sessionContext()).setTicketKeys(); + ((OpenSslSessionContext) sslServerCtx.sessionContext()).setTicketKeys(); + } + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + final byte[] bytes = new byte[96]; + PlatformDependent.threadLocalRandom().nextBytes(bytes); + try { + final AtomicReference assertErrorRef = new AtomicReference(); + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + final SslHandler sslHandler = sslServerCtx.newHandler(ch.alloc()); + ch.pipeline().addLast(sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT)); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + + private int handshakeCount; + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + handshakeCount++; + ReferenceCountedOpenSslEngine engine = + (ReferenceCountedOpenSslEngine) sslHandler.engine(); + // This test only works for non TLSv1.3 as TLSv1.3 will establish sessions after + // the handshake is done. 
+ // See https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_sess_set_get_cb.html + if (!SslProtocols.TLS_v1_3.equals(engine.getSession().getProtocol())) { + // First should not re-use the session + try { + assertEquals(handshakeCount > 1, engine.isSessionReused()); + } catch (AssertionError error) { + assertErrorRef.set(error); + return; + } + } + + ctx.writeAndFlush(Unpooled.wrappedBuffer(bytes)); + } + } + }); + } + }) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + InetSocketAddress serverAddr = (InetSocketAddress) sc.localAddress(); + testSessionTickets(serverAddr, group, sslClientCtx, bytes, false); + testSessionTickets(serverAddr, group, sslClientCtx, bytes, true); + AssertionError error = assertErrorRef.get(); + if (error != null) { + throw error; + } + } finally { + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + private static void testSessionTickets(InetSocketAddress serverAddress, EventLoopGroup group, + SslContext sslClientCtx, final byte[] bytes, boolean isReused) + throws Throwable { + Channel cc = null; + final BlockingQueue queue = new LinkedBlockingQueue(); + try { + final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT, + serverAddress.getAddress().getHostAddress(), serverAddress.getPort()); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientSslHandler); + ch.pipeline().addLast(new ByteToMessageDecoder() { + + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + if (in.readableBytes() == bytes.length) { + queue.add(in.readBytes(bytes.length)); + } + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + queue.add(cause); + } + }); + } + 
}).connect(serverAddress); + cc = future.syncUninterruptibly().channel(); + + assertTrue(clientSslHandler.handshakeFuture().sync().isSuccess()); + + ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) clientSslHandler.engine(); + // This test only works for non TLSv1.3 as TLSv1.3 will establish sessions after + // the handshake is done. + // See https://www.openssl.org/docs/man1.1.1/man3/SSL_CTX_sess_set_get_cb.html + if (!SslProtocols.TLS_v1_3.equals(engine.getSession().getProtocol())) { + assertEquals(isReused, engine.isSessionReused()); + } + Object obj = queue.take(); + if (obj instanceof ByteBuf) { + ByteBuf buffer = (ByteBuf) obj; + ByteBuf expected = Unpooled.wrappedBuffer(bytes); + try { + assertEquals(expected, buffer); + } finally { + expected.release(); + buffer.release(); + } + } else { + throw (Throwable) obj; + } + } finally { + if (cc != null) { + cc.close().syncUninterruptibly(); + } + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testHandshakeFailureOnlyFireExceptionOnce() throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(new X509ExtendedTrustManager() { + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + failVerification(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, Socket socket) + throws CertificateException { + failVerification(); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + failVerification(); + } + + @Override + public void checkServerTrusted(X509Certificate[] chain, String authType, SSLEngine engine) + throws CertificateException { + failVerification(); + } + + @Override + public void checkClientTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + failVerification(); + } + + 
@Override + public void checkServerTrusted(X509Certificate[] chain, String authType) + throws CertificateException { + failVerification(); + } + + @Override + public X509Certificate[] getAcceptedIssuers() { + return EmptyArrays.EMPTY_X509_CERTIFICATES; + } + + private void failVerification() throws CertificateException { + throw new CertificateException(); + } + }) + .sslProvider(SslProvider.JDK).build(); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .sslProvider(SslProvider.JDK).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + + try { + final Object terminalEvent = new Object(); + final BlockingQueue errorQueue = new LinkedBlockingQueue(); + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(serverSslHandler); + ch.pipeline().addLast(new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(final ChannelHandlerContext ctx, Throwable cause) { + errorQueue.add(cause); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) { + errorQueue.add(terminalEvent); + } + }); + } + }) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + final ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ch.pipeline().addLast(clientSslHandler); + } + }).connect(sc.localAddress()); + future.syncUninterruptibly(); + clientSslHandler.handshakeFuture().addListener(new FutureListener() { + @Override + public void 
operationComplete(Future f) { + future.channel().close(); + } + }); + assertFalse(clientSslHandler.handshakeFuture().await().isSuccess()); + assertFalse(serverSslHandler.handshakeFuture().await().isSuccess()); + + Object error = errorQueue.take(); + assertThat(error, Matchers.instanceOf(DecoderException.class)); + assertThat(((Throwable) error).getCause(), Matchers.instanceOf(SSLException.class)); + Object terminal = errorQueue.take(); + assertSame(terminalEvent, terminal); + + assertNull(errorQueue.poll(1, TimeUnit.MILLISECONDS)); + } finally { + if (sc != null) { + sc.close().syncUninterruptibly(); + } + group.shutdownGracefully(); + } + } + + @Test + public void testHandshakeFailureCipherMissmatchTLSv12Jdk() throws Exception { + testHandshakeFailureCipherMissmatch(SslProvider.JDK, false); + } + + @Test + public void testHandshakeFailureCipherMissmatchTLSv13Jdk() throws Exception { + assumeTrue(SslProvider.isTlsv13Supported(SslProvider.JDK)); + testHandshakeFailureCipherMissmatch(SslProvider.JDK, true); + } + + @Test + public void testHandshakeFailureCipherMissmatchTLSv12OpenSsl() throws Exception { + OpenSsl.ensureAvailability(); + testHandshakeFailureCipherMissmatch(SslProvider.OPENSSL, false); + } + + @Test + public void testHandshakeFailureCipherMissmatchTLSv13OpenSsl() throws Exception { + OpenSsl.ensureAvailability(); + assumeTrue(SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); + assumeFalse(OpenSsl.isBoringSSL(), "BoringSSL does not support setting ciphers for TLSv1.3 explicit"); + testHandshakeFailureCipherMissmatch(SslProvider.OPENSSL, true); + } + + private static void testHandshakeFailureCipherMissmatch(SslProvider provider, boolean tls13) throws Exception { + final String clientCipher; + final String serverCipher; + final String protocol; + + if (tls13) { + clientCipher = "TLS_AES_128_GCM_SHA256"; + serverCipher = "TLS_AES_256_GCM_SHA384"; + protocol = SslProtocols.TLS_v1_3; + } else { + clientCipher = "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256"; + 
serverCipher = "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384"; + protocol = SslProtocols.TLS_v1_2; + } + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(protocol) + .ciphers(Collections.singleton(clientCipher)) + .sslProvider(provider).build(); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .protocols(protocol) + .ciphers(Collections.singleton(serverCipher)) + .sslProvider(provider).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + Channel sc = null; + Channel cc = null; + final SslHandler clientSslHandler = sslClientCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + final SslHandler serverSslHandler = sslServerCtx.newHandler(UnpooledByteBufAllocator.DEFAULT); + + class SslEventHandler extends ChannelInboundHandlerAdapter { + private final AtomicReference ref; + + SslEventHandler(AtomicReference ref) { + this.ref = ref; + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (evt instanceof SslHandshakeCompletionEvent) { + ref.set((SslHandshakeCompletionEvent) evt); + } + super.userEventTriggered(ctx, evt); + } + } + final AtomicReference clientEvent = + new AtomicReference(); + final AtomicReference serverEvent = + new AtomicReference(); + try { + sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ch.pipeline().addLast(serverSslHandler); + ch.pipeline().addLast(new SslEventHandler(serverEvent)); + } + }) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + ChannelFuture future = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + 
ch.pipeline().addLast(clientSslHandler); + ch.pipeline().addLast(new SslEventHandler(clientEvent)); + } + }).connect(sc.localAddress()); + cc = future.syncUninterruptibly().channel(); + + Throwable clientCause = clientSslHandler.handshakeFuture().await().cause(); + assertThat(clientCause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(clientCause.getCause(), not(CoreMatchers.instanceOf(ClosedChannelException.class))); + Throwable serverCause = serverSslHandler.handshakeFuture().await().cause(); + assertThat(serverCause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(serverCause.getCause(), not(CoreMatchers.instanceOf(ClosedChannelException.class))); + cc.close().syncUninterruptibly(); + sc.close().syncUninterruptibly(); + + Throwable eventClientCause = clientEvent.get().cause(); + assertThat(eventClientCause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(eventClientCause.getCause(), + not(CoreMatchers.instanceOf(ClosedChannelException.class))); + Throwable serverEventCause = serverEvent.get().cause(); + + assertThat(serverEventCause, CoreMatchers.instanceOf(SSLException.class)); + assertThat(serverEventCause.getCause(), + not(CoreMatchers.instanceOf(ClosedChannelException.class))); + } finally { + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + } + } + + @Test + public void testSslCompletionEventsTls12JDK() throws Exception { + testSslCompletionEvents(SslProvider.JDK, SslProtocols.TLS_v1_2, true); + testSslCompletionEvents(SslProvider.JDK, SslProtocols.TLS_v1_2, false); + } + + @Test + public void testSslCompletionEventsTls12Openssl() throws Exception { + OpenSsl.ensureAvailability(); + testSslCompletionEvents(SslProvider.OPENSSL, SslProtocols.TLS_v1_2, true); + testSslCompletionEvents(SslProvider.OPENSSL, SslProtocols.TLS_v1_2, false); + } + + @Test + public void testSslCompletionEventsTls13JDK() throws Exception { + assumeTrue(SslProvider.isTlsv13Supported(SslProvider.JDK)); + 
testSslCompletionEvents(SslProvider.JDK, SslProtocols.TLS_v1_3, true); + testSslCompletionEvents(SslProvider.JDK, SslProtocols.TLS_v1_3, false); + } + + @Test + public void testSslCompletionEventsTls13Openssl() throws Exception { + OpenSsl.ensureAvailability(); + assumeTrue(SslProvider.isTlsv13Supported(SslProvider.OPENSSL)); + testSslCompletionEvents(SslProvider.OPENSSL, SslProtocols.TLS_v1_3, true); + testSslCompletionEvents(SslProvider.OPENSSL, SslProtocols.TLS_v1_3, false); + } + + private void testSslCompletionEvents(SslProvider provider, final String protocol, boolean clientClose) + throws Exception { + final SslContext sslClientCtx = SslContextBuilder.forClient() + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .protocols(protocol) + .sslProvider(provider).build(); + + final SelfSignedCertificate cert = new SelfSignedCertificate(); + final SslContext sslServerCtx = SslContextBuilder.forServer(cert.key(), cert.cert()) + .protocols(protocol) + .sslProvider(provider).build(); + + EventLoopGroup group = new NioEventLoopGroup(); + + final LinkedBlockingQueue acceptedChannels = + new LinkedBlockingQueue(); + + final LinkedBlockingQueue serverHandshakeCompletionEvents = + new LinkedBlockingQueue(); + + final LinkedBlockingQueue clientHandshakeCompletionEvents = + new LinkedBlockingQueue(); + + final LinkedBlockingQueue serverCloseCompletionEvents = + new LinkedBlockingQueue(); + + final LinkedBlockingQueue clientCloseCompletionEvents = + new LinkedBlockingQueue(); + try { + Channel sc = new ServerBootstrap() + .group(group) + .channel(NioServerSocketChannel.class) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + acceptedChannels.add(ch); + SslHandler handler = sslServerCtx.newHandler(ch.alloc()); + if (!SslProtocols.TLS_v1_3.equals(protocol)) { + handler.setCloseNotifyReadTimeout(5, TimeUnit.SECONDS); + } + ch.pipeline().addLast(handler); + ch.pipeline().addLast(new SslCompletionEventHandler( + 
serverHandshakeCompletionEvents, serverCloseCompletionEvents)); + } + }) + .bind(new InetSocketAddress(0)).syncUninterruptibly().channel(); + + Bootstrap bs = new Bootstrap() + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + SslHandler handler = sslClientCtx.newHandler( + ch.alloc(), "netty.io", 9999); + if (!SslProtocols.TLS_v1_3.equals(protocol)) { + handler.setCloseNotifyReadTimeout(5, TimeUnit.SECONDS); + } + ch.pipeline().addLast(handler); + ch.pipeline().addLast( + new SslCompletionEventHandler( + clientHandshakeCompletionEvents, clientCloseCompletionEvents)); + } + }) + .remoteAddress(sc.localAddress()); + + Channel cc1 = bs.connect().sync().channel(); + Channel cc2 = bs.connect().sync().channel(); + + // We expect 4 events as we have 2 connections and for each connection there should be one event + // on the server-side and one on the client-side. + for (int i = 0; i < 2; i++) { + SslHandshakeCompletionEvent event = clientHandshakeCompletionEvents.take(); + assertTrue(event.isSuccess()); + } + for (int i = 0; i < 2; i++) { + SslHandshakeCompletionEvent event = serverHandshakeCompletionEvents.take(); + assertTrue(event.isSuccess()); + } + + assertEquals(0, clientCloseCompletionEvents.size()); + assertEquals(0, serverCloseCompletionEvents.size()); + + if (clientClose) { + cc1.close().sync(); + cc2.close().sync(); + + acceptedChannels.take().closeFuture().sync(); + acceptedChannels.take().closeFuture().sync(); + } else { + acceptedChannels.take().close().sync(); + acceptedChannels.take().close().sync(); + + cc1.closeFuture().sync(); + cc2.closeFuture().sync(); + } + + // We expect 4 events as we have 2 connections and for each connection there should be one event + // on the server-side and one on the client-side. 
+ for (int i = 0; i < 2; i++) { + SslCloseCompletionEvent event = clientCloseCompletionEvents.take(); + if (clientClose) { + // When we use TLSv1.3 the remote peer is not required to send a close_notify as response. + // See: + // - https://datatracker.ietf.org/doc/html/rfc8446#section-6.1 + // - https://bugs.openjdk.org/browse/JDK-8208526 + if (SslProtocols.TLS_v1_3.equals(protocol)) { + assertNotNull(event); + } else { + assertTrue(event.isSuccess()); + } + } else { + assertTrue(event.isSuccess()); + } + } + for (int i = 0; i < 2; i++) { + SslCloseCompletionEvent event = serverCloseCompletionEvents.take(); + + if (clientClose) { + assertTrue(event.isSuccess()); + } else { + // When we use TLSv1.3 the remote peer is not required to send a close_notify as response. + // See: + // - https://datatracker.ietf.org/doc/html/rfc8446#section-6.1 + // - https://bugs.openjdk.org/browse/JDK-8208526 + if (SslProtocols.TLS_v1_3.equals(protocol)) { + assertNotNull(event); + } else { + assertTrue(event.isSuccess()); + } + } + } + + sc.close().sync(); + assertEquals(0, clientHandshakeCompletionEvents.size()); + assertEquals(0, serverHandshakeCompletionEvents.size()); + assertEquals(0, clientCloseCompletionEvents.size()); + assertEquals(0, serverCloseCompletionEvents.size()); + } finally { + group.shutdownGracefully(); + ReferenceCountUtil.release(sslClientCtx); + ReferenceCountUtil.release(sslServerCtx); + } + } + + private static class SslCompletionEventHandler extends ChannelInboundHandlerAdapter { + private final Queue handshakeCompletionEvents; + private final Queue closeCompletionEvents; + + SslCompletionEventHandler(Queue handshakeCompletionEvents, + Queue closeCompletionEvents) { + this.handshakeCompletionEvents = handshakeCompletionEvents; + this.closeCompletionEvents = closeCompletionEvents; + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof SslHandshakeCompletionEvent) { + 
handshakeCompletionEvents.add((SslHandshakeCompletionEvent) evt); + } else if (evt instanceof SslCloseCompletionEvent) { + closeCompletionEvents.add((SslCloseCompletionEvent) evt); + } + } + + @Override + public boolean isSharable() { + return true; + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslUtilsTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslUtilsTest.java new file mode 100644 index 0000000..b158652 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/SslUtilsTest.java @@ -0,0 +1,175 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ssl; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.ValueSource; + +import javax.net.ssl.SSLContext; +import javax.net.ssl.SSLEngine; +import javax.net.ssl.SSLException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.security.NoSuchAlgorithmException; + +import static io.netty.handler.ssl.SslUtils.DTLS_1_0; +import static io.netty.handler.ssl.SslUtils.DTLS_1_2; +import static io.netty.handler.ssl.SslUtils.DTLS_1_3; +import static io.netty.handler.ssl.SslUtils.NOT_ENOUGH_DATA; +import static io.netty.handler.ssl.SslUtils.getEncryptedPacketLength; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SslUtilsTest { + + @SuppressWarnings("deprecation") + @Test + public void testPacketLength() throws SSLException, NoSuchAlgorithmException { + SSLEngine engineLE = newEngine(); + SSLEngine engineBE = newEngine(); + + ByteBuffer empty = ByteBuffer.allocate(0); + ByteBuffer cTOsLE = ByteBuffer.allocate(17 * 1024).order(ByteOrder.LITTLE_ENDIAN); + ByteBuffer cTOsBE = ByteBuffer.allocate(17 * 1024); + + assertTrue(engineLE.wrap(empty, cTOsLE).bytesProduced() > 0); + cTOsLE.flip(); + + assertTrue(engineBE.wrap(empty, cTOsBE).bytesProduced() > 0); + cTOsBE.flip(); + + ByteBuf bufferLE = Unpooled.buffer().order(ByteOrder.LITTLE_ENDIAN).writeBytes(cTOsLE); + ByteBuf bufferBE = Unpooled.buffer().writeBytes(cTOsBE); + + // Test that the packet-length for BE and LE is the same + assertEquals(getEncryptedPacketLength(bufferBE, 0), getEncryptedPacketLength(bufferLE, 0)); + assertEquals(getEncryptedPacketLength(new ByteBuffer[] { bufferBE.nioBuffer() }, 0), + getEncryptedPacketLength(new ByteBuffer[] { 
bufferLE.nioBuffer().order(ByteOrder.LITTLE_ENDIAN) }, 0)); + } + + private static SSLEngine newEngine() throws SSLException, NoSuchAlgorithmException { + SSLEngine engine = SSLContext.getDefault().createSSLEngine(); + engine.setUseClientMode(true); + engine.beginHandshake(); + return engine; + } + + @Test + public void testIsTLSv13Cipher() { + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_128_GCM_SHA256")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_256_GCM_SHA384")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_CHACHA20_POLY1305_SHA256")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_128_CCM_SHA256")); + assertTrue(SslUtils.isTLSv13Cipher("TLS_AES_128_CCM_8_SHA256")); + assertFalse(SslUtils.isTLSv13Cipher("TLS_DHE_RSA_WITH_AES_128_GCM_SHA256")); + } + + @Test + public void shouldGetPacketLengthOfGmsslProtocolFromByteBuf() { + int bodyLength = 65; + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(SslUtils.GMSSL_PROTOCOL_VERSION) + .writeShort(bodyLength); + + int packetLength = getEncryptedPacketLength(buf, 0); + assertEquals(bodyLength + SslUtils.SSL_RECORD_HEADER_LENGTH, packetLength); + buf.release(); + } + + @Test + public void shouldGetPacketLengthOfGmsslProtocolFromByteBuffer() { + int bodyLength = 65; + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(SslUtils.GMSSL_PROTOCOL_VERSION) + .writeShort(bodyLength); + + int packetLength = getEncryptedPacketLength(new ByteBuffer[] { buf.nioBuffer() }, 0); + assertEquals(bodyLength + SslUtils.SSL_RECORD_HEADER_LENGTH, packetLength); + buf.release(); + } + + @ParameterizedTest + @ValueSource(ints = {DTLS_1_0, DTLS_1_2, DTLS_1_3}) // six numbers + public void shouldGetPacketLengthOfDtlsRecordFromByteBuf(int dtlsVersion) { + int bodyLength = 65; + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(dtlsVersion) + .writeShort(0) // epoch + .writeBytes(new byte[6]) // sequence number 
+ .writeShort(bodyLength); + + int packetLength = getEncryptedPacketLength(buf, 0); + // bodyLength + DTLS_RECORD_HEADER_LENGTH = 65 + 13 = 78 + assertEquals(78, packetLength); + buf.release(); + } + + @ParameterizedTest + @ValueSource(ints = {DTLS_1_0, DTLS_1_2, DTLS_1_3}) // six numbers + public void shouldGetPacketLengthOfFirstDtlsRecordFromByteBuf(int dtlsVersion) { + int bodyLength = 65; + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(dtlsVersion) + .writeShort(0) // epoch + .writeBytes(new byte[6]) // sequence number + .writeShort(bodyLength) + .writeBytes(new byte[65]) + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(dtlsVersion) + .writeShort(0) // epoch + .writeBytes(new byte[6]) // sequence number + .writeShort(bodyLength) + .writeBytes(new byte[65]); + + int packetLength = getEncryptedPacketLength(buf, 0); + assertEquals(78, packetLength); + buf.release(); + } + + @ParameterizedTest + @ValueSource(ints = {DTLS_1_0, DTLS_1_2, DTLS_1_3}) // six numbers + public void shouldSupportIncompletePackets(int dtlsVersion) { + ByteBuf buf = Unpooled.buffer() + .writeByte(SslUtils.SSL_CONTENT_TYPE_HANDSHAKE) + .writeShort(dtlsVersion) + .writeShort(0) // epoch + .writeBytes(new byte[6]) // sequence number + .writeByte(0); + // Left off the last byte of the length on purpose + + int packetLength = getEncryptedPacketLength(buf, 0); + assertEquals(NOT_ENOUGH_DATA, packetLength); + buf.release(); + } + + @Test + public void testValidHostNameForSni() { + assertFalse(SslUtils.isValidHostNameForSNI("/test.de"), "SNI domain can't start with /"); + assertFalse(SslUtils.isValidHostNameForSNI("test.de."), "SNI domain can't end with a dot/"); + assertTrue(SslUtils.isValidHostNameForSNI("test.de")); + // see https://datatracker.ietf.org/doc/html/rfc6066#section-3 + // it has to be test.local to qualify as SNI + assertFalse(SslUtils.isValidHostNameForSNI("test"), "SNI has to be FQDN"); + } +} diff --git 
a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java new file mode 100644 index 0000000..e389a7a --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/ocsp/OcspTest.java @@ -0,0 +1,535 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.handler.ssl.ocsp; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.handler.ssl.OpenSsl; +import io.netty.handler.ssl.ReferenceCountedOpenSslEngine; +import io.netty.handler.ssl.SslContext; +import io.netty.handler.ssl.SslContextBuilder; +import io.netty.handler.ssl.SslHandler; +import io.netty.handler.ssl.SslProvider; +import io.netty.handler.ssl.util.InsecureTrustManagerFactory; +import 
io.netty.handler.ssl.util.SelfSignedCertificate; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import org.hamcrest.CoreMatchers; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.net.SocketAddress; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReference; + +import javax.net.ssl.SSLHandshakeException; + +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assumptions.assumeTrue; + +public class OcspTest { + + @BeforeAll + public static void checkOcspSupported() { + assumeTrue(OpenSsl.isOcspSupported()); + } + + @Test + public void testJdkClientEnableOcsp() throws Exception { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContextBuilder.forClient() + .sslProvider(SslProvider.JDK) + .enableOcsp(true) + .build(); + } + }); + } + + @Test + public void testJdkServerEnableOcsp() throws Exception { + final SelfSignedCertificate ssc = new SelfSignedCertificate(); + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(SslProvider.JDK) + .enableOcsp(true) + .build(); + } + }); + } finally { + 
ssc.delete(); + } + } + + @Test + public void testClientOcspNotEnabledOpenSsl() throws Exception { + testClientOcspNotEnabled(SslProvider.OPENSSL); + } + + @Test + public void testClientOcspNotEnabledOpenSslRefCnt() throws Exception { + testClientOcspNotEnabled(SslProvider.OPENSSL_REFCNT); + } + + private static void testClientOcspNotEnabled(SslProvider sslProvider) throws Exception { + SslContext context = SslContextBuilder.forClient() + .sslProvider(sslProvider) + .build(); + try { + SslHandler sslHandler = context.newHandler(ByteBufAllocator.DEFAULT); + final ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); + try { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + engine.getOcspResponse(); + } + }); + } finally { + engine.release(); + } + } finally { + ReferenceCountUtil.release(context); + } + } + + @Test + public void testServerOcspNotEnabledOpenSsl() throws Exception { + testServerOcspNotEnabled(SslProvider.OPENSSL); + } + + @Test + public void testServerOcspNotEnabledOpenSslRefCnt() throws Exception { + testServerOcspNotEnabled(SslProvider.OPENSSL_REFCNT); + } + + private static void testServerOcspNotEnabled(SslProvider sslProvider) throws Exception { + SelfSignedCertificate ssc = new SelfSignedCertificate(); + try { + SslContext context = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslProvider) + .build(); + try { + SslHandler sslHandler = context.newHandler(ByteBufAllocator.DEFAULT); + final ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); + try { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + engine.setOcspResponse(new byte[] { 1, 2, 3 }); + } + }); + } finally { + engine.release(); + } + } finally { + ReferenceCountUtil.release(context); + } + } finally { + ssc.delete(); + } + } + + @Test + @Timeout(value = 10000L, unit = 
TimeUnit.MILLISECONDS) + public void testClientAcceptingOcspStapleOpenSsl() throws Exception { + testClientAcceptingOcspStaple(SslProvider.OPENSSL); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testClientAcceptingOcspStapleOpenSslRefCnt() throws Exception { + testClientAcceptingOcspStaple(SslProvider.OPENSSL_REFCNT); + } + + /** + * The Server provides an OCSP staple and the Client accepts it. + */ + private static void testClientAcceptingOcspStaple(SslProvider sslProvider) throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + ChannelInboundHandlerAdapter serverHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes())); + ctx.fireChannelActive(); + } + }; + + ChannelInboundHandlerAdapter clientHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + try { + ReferenceCountUtil.release(msg); + } finally { + latch.countDown(); + } + } + }; + + byte[] response = newOcspResponse(); + TestClientOcspContext callback = new TestClientOcspContext(true); + + handshake(sslProvider, latch, serverHandler, response, clientHandler, callback); + + byte[] actual = callback.response(); + + assertNotNull(actual); + assertNotSame(response, actual); + assertArrayEquals(response, actual); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testClientRejectingOcspStapleOpenSsl() throws Exception { + testClientRejectingOcspStaple(SslProvider.OPENSSL); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testClientRejectingOcspStapleOpenSslRefCnt() throws Exception { + testClientRejectingOcspStaple(SslProvider.OPENSSL_REFCNT); + } + + /** + * The Server provides an OCSP staple and the Client rejects it. 
+ */ + private static void testClientRejectingOcspStaple(SslProvider sslProvider) throws Exception { + final AtomicReference causeRef = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + + ChannelInboundHandlerAdapter clientHandler = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + try { + causeRef.set(cause); + } finally { + latch.countDown(); + } + } + }; + + byte[] response = newOcspResponse(); + TestClientOcspContext callback = new TestClientOcspContext(false); + + handshake(sslProvider, latch, null, response, clientHandler, callback); + + byte[] actual = callback.response(); + + assertNotNull(actual); + assertNotSame(response, actual); + assertArrayEquals(response, actual); + + Throwable cause = causeRef.get(); + assertThat(cause, CoreMatchers.instanceOf(SSLHandshakeException.class)); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testServerHasNoStapleOpenSsl() throws Exception { + testServerHasNoStaple(SslProvider.OPENSSL); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testServerHasNoStapleOpenSslRefCnt() throws Exception { + testServerHasNoStaple(SslProvider.OPENSSL_REFCNT); + } + + /** + * The server has OCSP stapling enabled but doesn't provide a staple. 
+ */ + private static void testServerHasNoStaple(SslProvider sslProvider) throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + ChannelInboundHandlerAdapter serverHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.writeAndFlush(Unpooled.wrappedBuffer("Hello, World!".getBytes())); + ctx.fireChannelActive(); + } + }; + + ChannelInboundHandlerAdapter clientHandler = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + try { + ReferenceCountUtil.release(msg); + } finally { + latch.countDown(); + } + } + }; + + byte[] response = null; + TestClientOcspContext callback = new TestClientOcspContext(true); + + handshake(sslProvider, latch, serverHandler, response, clientHandler, callback); + + byte[] actual = callback.response(); + + assertNull(response); + assertNull(actual); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testClientExceptionOpenSsl() throws Exception { + testClientException(SslProvider.OPENSSL); + } + + @Test + @Timeout(value = 10000L, unit = TimeUnit.MILLISECONDS) + public void testClientExceptionOpenSslRefCnt() throws Exception { + testClientException(SslProvider.OPENSSL_REFCNT); + } + + /** + * Testing what happens if the {@link OcspClientCallback} throws an {@link Exception}. + * + * The exception should bubble up on the client side and the connection should get closed. 
+ */ + private static void testClientException(SslProvider sslProvider) throws Exception { + final AtomicReference causeRef = new AtomicReference(); + final CountDownLatch latch = new CountDownLatch(1); + + ChannelInboundHandlerAdapter clientHandler = new ChannelInboundHandlerAdapter() { + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + try { + causeRef.set(cause); + } finally { + latch.countDown(); + } + } + }; + + final OcspTestException clientException = new OcspTestException("testClientException"); + byte[] response = newOcspResponse(); + OcspClientCallback callback = new OcspClientCallback() { + @Override + public boolean verify(byte[] response) throws Exception { + throw clientException; + } + }; + + handshake(sslProvider, latch, null, response, clientHandler, callback); + + assertSame(clientException, causeRef.get()); + } + + private static void handshake(SslProvider sslProvider, CountDownLatch latch, ChannelHandler serverHandler, + byte[] response, ChannelHandler clientHandler, OcspClientCallback callback) throws Exception { + + SelfSignedCertificate ssc = new SelfSignedCertificate(); + try { + SslContext serverSslContext = SslContextBuilder.forServer(ssc.certificate(), ssc.privateKey()) + .sslProvider(sslProvider) + .enableOcsp(true) + .build(); + + try { + SslContext clientSslContext = SslContextBuilder.forClient() + .sslProvider(sslProvider) + .enableOcsp(true) + .trustManager(InsecureTrustManagerFactory.INSTANCE) + .build(); + + try { + EventLoopGroup group = new DefaultEventLoopGroup(); + try { + LocalAddress address = new LocalAddress("handshake-" + Math.random()); + Channel server = newServer(group, address, serverSslContext, response, serverHandler); + Channel client = newClient(group, address, clientSslContext, callback, clientHandler); + try { + assertTrue(latch.await(10L, TimeUnit.SECONDS)); + } finally { + client.close().syncUninterruptibly(); + server.close().syncUninterruptibly(); + } + 
} finally { + group.shutdownGracefully(1L, 1L, TimeUnit.SECONDS); + } + } finally { + ReferenceCountUtil.release(clientSslContext); + } + } finally { + ReferenceCountUtil.release(serverSslContext); + } + } finally { + ssc.delete(); + } + } + + private static Channel newServer(EventLoopGroup group, SocketAddress address, + SslContext context, byte[] response, ChannelHandler handler) { + + ServerBootstrap bootstrap = new ServerBootstrap() + .channel(LocalServerChannel.class) + .group(group) + .childHandler(newServerHandler(context, response, handler)); + + return bootstrap.bind(address) + .syncUninterruptibly() + .channel(); + } + + private static Channel newClient(EventLoopGroup group, SocketAddress address, + SslContext context, OcspClientCallback callback, ChannelHandler handler) { + + Bootstrap bootstrap = new Bootstrap() + .channel(LocalChannel.class) + .group(group) + .handler(newClientHandler(context, callback, handler)); + + return bootstrap.connect(address) + .syncUninterruptibly() + .channel(); + } + + private static ChannelHandler newServerHandler(final SslContext context, + final byte[] response, final ChannelHandler handler) { + return new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + SslHandler sslHandler = context.newHandler(ch.alloc()); + + if (response != null) { + ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); + engine.setOcspResponse(response); + } + + pipeline.addLast(sslHandler); + + if (handler != null) { + pipeline.addLast(handler); + } + } + }; + } + + private static ChannelHandler newClientHandler(final SslContext context, + final OcspClientCallback callback, final ChannelHandler handler) { + return new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) throws Exception { + ChannelPipeline pipeline = ch.pipeline(); + + SslHandler sslHandler = context.newHandler(ch.alloc()); + 
ReferenceCountedOpenSslEngine engine = (ReferenceCountedOpenSslEngine) sslHandler.engine(); + + pipeline.addLast(sslHandler); + pipeline.addLast(new OcspClientCallbackHandler(engine, callback)); + + if (handler != null) { + pipeline.addLast(handler); + } + } + }; + } + + private static byte[] newOcspResponse() { + // Assume we got the OCSP staple from somewhere. Using a bogus byte[] + // in the test because getting a true staple from the CA is quite involved. + // It requires HttpCodec and Bouncycastle and the test may be very unreliable + // because the OCSP responder servers are basically being DDoS'd by the + // Internet. + + return "I am a bogus OCSP staple. OpenSSL does not care about the format of the byte[]!" + .getBytes(CharsetUtil.US_ASCII); + } + + private interface OcspClientCallback { + boolean verify(byte[] staple) throws Exception; + } + + private static final class TestClientOcspContext implements OcspClientCallback { + + private final CountDownLatch latch = new CountDownLatch(1); + private final boolean valid; + + private volatile byte[] response; + + TestClientOcspContext(boolean valid) { + this.valid = valid; + } + + public byte[] response() throws InterruptedException, TimeoutException { + assertTrue(latch.await(10L, TimeUnit.SECONDS)); + return response; + } + + @Override + public boolean verify(byte[] response) throws Exception { + this.response = response; + latch.countDown(); + + return valid; + } + } + + private static final class OcspClientCallbackHandler extends OcspClientHandler { + + private final OcspClientCallback callback; + + OcspClientCallbackHandler(ReferenceCountedOpenSslEngine engine, OcspClientCallback callback) { + super(engine); + this.callback = callback; + } + + @Override + protected boolean verify(ChannelHandlerContext ctx, ReferenceCountedOpenSslEngine engine) throws Exception { + byte[] response = engine.getOcspResponse(); + return callback.verify(response); + } + } + + private static final class OcspTestException extends 
IllegalStateException { + private static final long serialVersionUID = 4516426833250228159L; + + OcspTestException(String message) { + super(message); + } + } +} diff --git a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryTest.java new file mode 100644 index 0000000..3aad29b --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/FingerprintTrustManagerFactoryTest.java @@ -0,0 +1,141 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.handler.ssl.util; + + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import javax.net.ssl.X509TrustManager; +import java.security.cert.CertificateException; +import java.security.cert.X509Certificate; + +import static io.netty.handler.ssl.Java8SslTestUtils.loadCertCollection; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FingerprintTrustManagerFactoryTest { + + private static final String FIRST_CERT_SHA1_FINGERPRINT + = "18:C7:C2:76:1F:DF:72:3B:2A:A7:BB:2C:B0:30:D4:C0:C0:72:AD:84"; + + private static final String FIRST_CERT_SHA256_FINGERPRINT + = "1C:53:0E:6B:FF:93:F0:DE:C2:E6:E7:9D:10:53:58:FF:" + + "DD:8E:68:CD:82:D9:C9:36:9B:43:EE:B3:DC:13:68:FB"; + + private static final X509Certificate[] FIRST_CHAIN; + + private static final X509Certificate[] SECOND_CHAIN; + + static { + try { + FIRST_CHAIN = loadCertCollection("test.crt"); + SECOND_CHAIN = loadCertCollection("test2.crt"); + } catch (Exception e) { + throw new Error(e); + } + } + + @Test + public void testFingerprintWithInvalidLength() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + FingerprintTrustManagerFactory.builder("SHA-256").fingerprints("00:00:00").build(); + } + }); + } + + @Test + public void testFingerprintWithUnexpectedCharacters() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + FingerprintTrustManagerFactory.builder("SHA-256").fingerprints("00:00:00\n").build(); + } + }); + } + + @Test + public void testWithNoFingerprints() { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + FingerprintTrustManagerFactory.builder("SHA-256").build(); + } + }); + } + + @Test + public void testWithNullFingerprint() { + 
assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + FingerprintTrustManagerFactory + .builder("SHA-256") + .fingerprints(FIRST_CERT_SHA256_FINGERPRINT, null) + .build(); + } + }); + } + + @Test + public void testValidSHA1Fingerprint() throws Exception { + FingerprintTrustManagerFactory factory = new FingerprintTrustManagerFactory(FIRST_CERT_SHA1_FINGERPRINT); + + assertTrue(factory.engineGetTrustManagers().length > 0); + assertTrue(factory.engineGetTrustManagers()[0] instanceof X509TrustManager); + X509TrustManager tm = (X509TrustManager) factory.engineGetTrustManagers()[0]; + tm.checkClientTrusted(FIRST_CHAIN, "test"); + } + + @Test + public void testTrustedCertificateWithSHA256Fingerprint() throws Exception { + FingerprintTrustManagerFactory factory = FingerprintTrustManagerFactory + .builder("SHA-256") + .fingerprints(FIRST_CERT_SHA256_FINGERPRINT) + .build(); + + X509Certificate[] keyCertChain = loadCertCollection("test.crt"); + assertNotNull(keyCertChain); + assertTrue(factory.engineGetTrustManagers().length > 0); + assertTrue(factory.engineGetTrustManagers()[0] instanceof X509TrustManager); + X509TrustManager tm = (X509TrustManager) factory.engineGetTrustManagers()[0]; + tm.checkClientTrusted(keyCertChain, "test"); + } + + @Test + public void testUntrustedCertificateWithSHA256Fingerprint() throws Exception { + FingerprintTrustManagerFactory factory = FingerprintTrustManagerFactory + .builder("SHA-256") + .fingerprints(FIRST_CERT_SHA256_FINGERPRINT) + .build(); + + assertTrue(factory.engineGetTrustManagers().length > 0); + assertTrue(factory.engineGetTrustManagers()[0] instanceof X509TrustManager); + final X509TrustManager tm = (X509TrustManager) factory.engineGetTrustManagers()[0]; + + assertThrows(CertificateException.class, new Executable() { + @Override + public void execute() throws Throwable { + tm.checkClientTrusted(SECOND_CHAIN, "test"); + } + }); + } + +} diff --git 
a/netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/SelfSignedCertificateTest.java b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/SelfSignedCertificateTest.java new file mode 100644 index 0000000..59ca376 --- /dev/null +++ b/netty-handler-ssl/src/test/java/io/netty/handler/ssl/util/SelfSignedCertificateTest.java @@ -0,0 +1,50 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ssl.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.security.cert.CertificateException; + +import static org.junit.jupiter.api.Assertions.*; + +class SelfSignedCertificateTest { + + @Test + void fqdnAsteriskDoesNotThrowTest() { + assertDoesNotThrow(new Executable() { + @Override + public void execute() throws Throwable { + new SelfSignedCertificate("*.netty.io", "EC", 256); + } + }); + + assertDoesNotThrow(new Executable() { + @Override + public void execute() throws Throwable { + new SelfSignedCertificate("*.netty.io", "RSA", 2048); + } + }); + } + + @Test + void fqdnAsteriskFileNameTest() throws CertificateException { + SelfSignedCertificate ssc = new SelfSignedCertificate("*.netty.io", "EC", 256); + assertFalse(ssc.certificate().getName().contains("*")); + assertFalse(ssc.privateKey().getName().contains("*")); + } +} diff --git 
a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem new file mode 100644 index 0000000..cafaea4 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/ec_params_unsupported.pem @@ -0,0 +1,18 @@ +-----BEGIN CERTIFICATE----- +MIIC8TCCApagAwIBAgIJAOeu9WKx0IutMAoGCCqGSM49BAMCMFkxCzAJBgNVBAYT +AkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRn +aXRzIFB0eSBMdGQxEjAQBgNVBAMMCWxvY2FsaG9zdDAeFw0xODExMDEyMDAwMTha +Fw0yMDEwMzEyMDAwMThaMFkxCzAJBgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0 +YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBXaWRnaXRzIFB0eSBMdGQxEjAQBgNVBAMM +CWxvY2FsaG9zdDCCAUswggEDBgcqhkjOPQIBMIH3AgEBMCwGByqGSM49AQECIQD/ +////AAAAAQAAAAAAAAAAAAAAAP///////////////zBbBCD/////AAAAAQAAAAAA +AAAAAAAAAP///////////////AQgWsY12Ko6k+ez671VdpiGvGUdBrDMU7D2O848 +PifSYEsDFQDEnTYIhucEk2pmeOETnSa3gZ9+kARBBGsX0fLhLEJH+Lzm5WOkQPJ3 +A32BLeszoPShOUXYmMKWT+NC4v4af5uO5+tKfA+eFivOM1drMV7Oy7ZAaDe/UfUC +IQD/////AAAAAP//////////vOb6racXnoTzucrC/GMlUQIBAQNCAAQ3G/YXF+YE +XuASiyC1822n0iNPumHgFplF+6/veicKm+mDNA3NA/1zTRKJOyqpDdMyB9tgFrdV +zcHzw7JW+lDpo1MwUTAdBgNVHQ4EFgQUonraQIcnNMppU+GoJ6+vPbC84pEwHwYD +VR0jBBgwFoAUonraQIcnNMppU+GoJ6+vPbC84pEwDwYDVR0TAQH/BAUwAwEB/zAK +BggqhkjOPQQDAgNJADBGAiEAoIkAinhds0VvNtWdi6f+r+U8AA9rUsR1sJBzVOYD +ErACIQCMMyfEWW8d4N3q8fpZ/lWTNaionVWeZZHWjseTmafWQg== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certificate.sh b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certificate.sh new file mode 100755 index 0000000..deed5e9 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certificate.sh @@ -0,0 +1,22 @@ +# Generate CA key and certificate. 
+export PASS="password" +openssl req -x509 -newkey rsa:2048 -days 3650 -keyout rsapss-ca-key.pem -out rsapss-ca-cert.cert -subj "/C=GB/O=Netty/OU=netty-parent/CN=west.int" -sigopt rsa_padding_mode:pss -sha256 -sigopt rsa_pss_saltlen:20 -passin env:PASS -passout env:PASS + +# Generate user key nand. +openssl req -newkey rsa:2048 -keyout rsapss-user-key.pem -out rsaValidation-req.pem -subj "/C=GB/O=Netty/OU=netty-parent/CN=c1" -sigopt rsa_padding_mode:pss -sha256 -sigopt rsa_pss_saltlen:20 -passin env:PASS -passout env:PASS + +# Sign user cert request using CA certificate. +openssl x509 -req -in rsaValidation-req.pem -days 3650 -extensions ext -extfile rsapss-signing-ext.txt -CA rsapss-ca-cert.cert -CAkey rsapss-ca-key.pem -CAcreateserial -out rsapss-user-signed.cert -sigopt rsa_padding_mode:pss -sha256 -sigopt rsa_pss_saltlen:20 -passin env:PASS + +# Create user certificate keystore. +openssl pkcs12 -export -out rsaValidation-user-certs.p12 -inkey rsapss-user-key.pem -in rsapss-user-signed.cert -passin env:PASS -passout env:PASS + +# create keystore for the +openssl pkcs12 -in rsapss-ca-cert.cert -inkey rsapss-ca-key.pem -passin env:PASS -certfile rsapss-ca-cert.cert -export -out rsaValidations-server-keystore.p12 -passout env:PASS -name localhost + +# Create Trustore to verify the EndEntity certificate we have created. +keytool -importcert -storetype PKCS12 -keystore rsaValidations-truststore.p12 -storepass $PASS -alias ca -file rsapss-ca-cert.cert -noprompt + +# Clean up files we don't need for the test. 
+echo "# Cleaning up:" +rm -v rsapss-ca-cert.srl rsapss-ca-key.pem rsapss-user-key.pem rsapss-user-signed.cert rsaValidation-req.pem rsaValidations-truststore.p12 \ No newline at end of file diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certs.sh b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certs.sh new file mode 100755 index 0000000..ac82cff --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/generate-certs.sh @@ -0,0 +1,85 @@ +#!/bin/sh +# ---------------------------------------------------------------------------- +# Copyright 2021 The Netty Project +# +# The Netty Project licenses this file to you under the Apache License, +# version 2.0 (the "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at: +# +# https://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +# License for the specific language governing permissions and limitations +# under the License. 
+# ---------------------------------------------------------------------------- + +# Generate a new, self-signed root CA +openssl req -extensions v3_ca -new -x509 -days 30 -nodes -subj "/CN=NettyTestRoot" -newkey rsa:2048 -sha512 -out mutual_auth_ca.pem -keyout mutual_auth_ca.key + +# Generate a certificate/key for the server +openssl req -new -keyout mutual_auth_server.key -nodes -newkey rsa:2048 -subj "/CN=NettyTestServer" | \ + openssl x509 -req -CAkey mutual_auth_ca.key -CA mutual_auth_ca.pem -days 36500 -set_serial $RANDOM -sha512 -out mutual_auth_server.pem + +# Generate a certificate/key for the server to use for Hostname Verification via localhost +openssl req -new -keyout localhost_server_rsa.key -nodes -newkey rsa:2048 -subj "/CN=localhost" | \ + openssl x509 -req -CAkey mutual_auth_ca.key -CA mutual_auth_ca.pem -days 36500 -set_serial $RANDOM -sha512 -out localhost_server.pem +openssl pkcs8 -topk8 -inform PEM -outform PEM -in localhost_server_rsa.key -out localhost_server.key -nocrypt +rm localhost_server_rsa.key + +# Generate a certificate/key for the server to fail for Hostname Verification via localhost +openssl req -new -keyout notlocalhost_server_rsa.key -nodes -newkey rsa:2048 -subj "/CN=NOTlocalhost" | \ + openssl x509 -req -CAkey mutual_auth_ca.key -CA mutual_auth_ca.pem -days 36500 -set_serial $RANDOM -sha512 -out notlocalhost_server.pem +openssl pkcs8 -topk8 -inform PEM -outform PEM -in notlocalhost_server_rsa.key -out notlocalhost_server.key -nocrypt +rm notlocalhost_server_rsa.key + +# Generate an invalid intermediate CA which will be used to sign the client certificate +openssl req -new -keyout mutual_auth_invalid_intermediate_ca.key -nodes -newkey rsa:2048 -subj "/CN=NettyTestInvalidIntermediate" | \ + openssl x509 -req -CAkey mutual_auth_ca.key -CA mutual_auth_ca.pem -days 36500 -set_serial $RANDOM -sha512 -out mutual_auth_invalid_intermediate_ca.pem + +# Generate a client certificate signed by the invalid intermediate CA +openssl req -new 
-keyout mutual_auth_invalid_client.key -nodes -newkey rsa:2048 -subj "/CN=NettyTestInvalidClient/UID=ClientWithInvalidCa" | \ + openssl x509 -req -CAkey mutual_auth_invalid_intermediate_ca.key -CA mutual_auth_invalid_intermediate_ca.pem -days 36500 -set_serial $RANDOM -sha512 -out mutual_auth_invalid_client.pem + +# Generate a valid intermediate CA which will be used to sign the client certificate +openssl req -new -keyout mutual_auth_intermediate_ca.key -nodes -newkey rsa:2048 -out mutual_auth_intermediate_ca.key +openssl req -new -sha512 -key mutual_auth_intermediate_ca.key -subj "/CN=NettyTestIntermediate" -out intermediate.csr +openssl x509 -req -days 1825 -in intermediate.csr -extfile openssl.cnf -extensions v3_ca -CA mutual_auth_ca.pem -CAkey mutual_auth_ca.key -set_serial $RANDOM -out mutual_auth_intermediate_ca.pem + +# Generate a client certificate signed by the intermediate CA +openssl req -new -keyout mutual_auth_client.key -nodes -newkey rsa:2048 -subj "/CN=NettyTestClient/UID=Client" | \ + openssl x509 -req -CAkey mutual_auth_intermediate_ca.key -CA mutual_auth_intermediate_ca.pem -days 36500 -set_serial $RANDOM -sha512 -out mutual_auth_client.pem + +# For simplicity, squish everything down into PKCS#12 keystores +cat mutual_auth_invalid_intermediate_ca.pem mutual_auth_ca.pem > mutual_auth_invalid_client_cert_chain.pem +cat mutual_auth_intermediate_ca.pem mutual_auth_ca.pem > mutual_auth_client_cert_chain.pem +openssl pkcs12 -export -in mutual_auth_server.pem -inkey mutual_auth_server.key -certfile mutual_auth_ca.pem -out mutual_auth_server.p12 -password pass:example +openssl pkcs12 -export -in mutual_auth_invalid_client.pem -inkey mutual_auth_invalid_client.key -certfile mutual_auth_invalid_client_cert_chain.pem -out mutual_auth_invalid_client.p12 -password pass:example +openssl pkcs12 -export -in mutual_auth_client.pem -inkey mutual_auth_client.key -certfile mutual_auth_client_cert_chain.pem -out mutual_auth_client.p12 -password pass:example + 
+#PKCS#1 +openssl genrsa -out rsa_pkcs8_unencrypted.key 2048 +openssl genrsa -des3 -out rsa_pkcs8_des3_encrypted.key -passout pass:example 2048 +openssl genrsa -aes128 -out rsa_pkcs8_aes_encrypted.key -passout pass:example 2048 +# If using OpenSSL >3 use -traditional with openssl genrsa to generate traditional PKCS#1 keys +openssl genrsa -traditional -out rsa_pkcs1_unencrypted.key 2048 +openssl genrsa -traditional -des3 -out rsa_pkcs1_des3_encrypted.key -passout pass:example 2048 +openssl genrsa -traditional -aes128 -out rsa_pkcs1_aes_encrypted.key -passout pass:example 2048 +openssl dsaparam -out dsaparam.pem 2048 +openssl gendsa -out dsa_pkcs1_unencrypted.key dsaparam.pem +openssl gendsa -des3 -out dsa_pkcs1_des3_encrypted.key -passout pass:example dsaparam.pem +openssl gendsa -aes128 -out dsa_pkcs1_aes_encrypted.key -passout pass:example dsaparam.pem + +# PBES2 +openssl genrsa -out rsa_pbes2.key +openssl req -new -subj "/CN=NettyTest" -key rsa_pbes2.key -out rsa_pbes2.csr +openssl x509 -req -days 36500 -in rsa_pbes2.csr -signkey rsa_pbes2.key -out rsa_pbes2.crt +openssl pkcs8 -topk8 -inform PEM -in rsa_pbes2.key -outform pem -out rsa_pbes2_enc_pkcs8.key -v2 aes-256-cbc -passin pass:12345678 -passout pass:12345678 + +# Clean up intermediate files +rm intermediate.csr +rm mutual_auth_ca.key mutual_auth_invalid_client.key mutual_auth_client.key mutual_auth_server.key mutual_auth_invalid_intermediate_ca.key mutual_auth_intermediate_ca.key +rm mutual_auth_invalid_client.pem mutual_auth_client.pem mutual_auth_server.pem mutual_auth_client_cert_chain.pem mutual_auth_invalid_intermediate_ca.pem mutual_auth_intermediate_ca.pem mutual_auth_invalid_client_cert_chain.pem +rm dsaparam.pem +rm rsa_pbes2.crt rsa_pbes2.csr rsa_pbes2.key diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/localhost_server.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/localhost_server.pem new file mode 100644 index 0000000..70759b2 --- /dev/null +++ 
b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/localhost_server.pem @@ -0,0 +1,17 @@ +-----BEGIN CERTIFICATE----- +MIICozCCAYsCAnS/MA0GCSqGSIb3DQEBDQUAMBgxFjAUBgNVBAMTDU5ldHR5VGVz +dFJvb3QwIBcNMTcwMjE3MDMzMzQ0WhgPMjExNzAxMjQwMzMzNDRaMBQxEjAQBgNV +BAMTCWxvY2FsaG9zdDCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBANis +u0yVnOh7YE+IJREXdKjlc0t+w4twBURFdeOfGpfox7HnlZ0mLLq5ZMUptEdYl1tY +Qt1nFWXn4Zeky/c52Qpm37X1l1J8HK/psHlE11k4Qaco4dJjZd2fNicXkkFpYTHR +++28g9k5SVYNaCxDMmTVRCG75ecFzs/WEdg2/CxU05H4cP0sZ5sPL5Rx+IvfhfAD +IF0dSxtwivyGW0AFyPq81uo4ud2lTzoPFT3P1vU8OaQVV+KwSWGkMSGnGZMLAjbZ +SzUYLwPzUsxnMyVtZLNN808S6o3MlgaIW39c/A+Q8/JW+2LRdY8FCnDMkUVRfnEq +w4YRGiUQtFPTI1BjOcUCAwEAATANBgkqhkiG9w0BAQ0FAAOCAQEAQNXnwE2MJFy5 +ti07xyi8h/mY0Kl1dwZUqx4F9D9eoxLCq2/p3h/Z18AlOmjdW06pvC2sGtQtyEqL +YjuQFbMjXRo9c+6+d+xwdDKTu7+XOTHvznJ8xJpKnFOlohGq/n3efBIJSsaeasTU +slFzmdKYABDZzbsQ4X6YCIOF4XVdEQqmXpS+uEbn5C2sVtG+LXI8srmkVGpCcRew +SuTGanwxLparhBBeN1ARjKzNxXUWuK2UKZ9p8c7n7TXGhd12ZNTcLhk4rCnOFq1J +ySFvP5YL2q29fpEt+Tq0zm3V7An2qtaNDp26cEdevtKPjRyOLkCJx8OlZxc9DZvJ +HjalFDoRUw== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_ca.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_ca.pem new file mode 100644 index 0000000..9c9241b --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_ca.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDLDCCAhSgAwIBAgIJAO1m5pioZhLLMA0GCSqGSIb3DQEBDQUAMBgxFjAUBgNV +BAMTDU5ldHR5VGVzdFJvb3QwHhcNMTcwMjE3MDMzMzQ0WhcNMTcwMzE5MDMzMzQ0 +WjAYMRYwFAYDVQQDEw1OZXR0eVRlc3RSb290MIIBIjANBgkqhkiG9w0BAQEFAAOC +AQ8AMIIBCgKCAQEAnC7Y/p/TSWI1KxBKETfFKaRWCPEkoYn5G973WbCF0VDT90PX +xK6yHvhqNdDQZPmddgfDAQfjekHeeIFkjCKlvQu0js0G4Bubz4NffNumd/Mgsix8 +SWJ13lPk+Ly4PDv0bK1zB6BxP1qQm1qxVwsPy9zNP8ylJrM0Div4TXHmnWOfc0JD +4/XPpfeUHH1tt/GMtsS2Gx6EpTVPD2w7LDKUza1/rQ7d9sqmFpgsNcI9Db/sAtFP +lK2iJku5WIXQkmHimn4bqZ9wkiXJ85pm5ggGQqGMPSbe+2Lh24AvZMIBiwPbkjEU 
+EDFXEJfKOC3Dl71JgWOthtHZ9vcCRDQ3Sky6AQIDAQABo3kwdzAdBgNVHQ4EFgQU +qT+cH8qrebiVPpKCBQDB6At2iOAwSAYDVR0jBEEwP4AUqT+cH8qrebiVPpKCBQDB +6At2iOChHKQaMBgxFjAUBgNVBAMTDU5ldHR5VGVzdFJvb3SCCQDtZuaYqGYSyzAM +BgNVHRMEBTADAQH/MA0GCSqGSIb3DQEBDQUAA4IBAQCEemXTIew4pR2cHEFpVsW2 +bLHXLAnC23wBMT46D3tqyxscukMYjFuWosCdEsgRW8d50BXy9o4dHWeg94+aDo3A +DX4OTRN/veQGIG7dgM6poDzFuVJlSN0ubKKg6gpDD60IhopZpMviFAOsmzr7OXwS +9hjbTqUWujMIEHQ95sPlQFdSaavYSFfqhSltWmVCPSbArxrw0lZ2QcnUqGN47EFp +whc5wFB+rSw/ojU1jBLMvgvgzf/8V8zr1IBTDSiHNlknGqGpOOaookzUh95YRiAT +hH82y9bBeflqroOeztqMpONpWoZjlz0sWbJNvXztXINL7LaNmVYOcoUrCcxPS54T +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_client.p12 b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_client.p12 new file mode 100644 index 0000000000000000000000000000000000000000..85e74daa26d54360070772fdda36fe19fc8263be GIT binary patch literal 3997 zcmV;O4`T2zf)AMj0Ru3C4_^ieDuzgg_YDCD0ic2pPy~VxOfZ5ENHBs4{{{&vhDe6@ z4FLxRpn?hTFoFr}0s#Opf(hjY2`Yw2hW8Bt2LUh~1_~;MNQU%=fX? 
zuKbujOslwsDja>5nfa%PJ9QWwv?r$>fE-85 zgZY%86WJhgujVDHNbGppP!Z<7m6+Jaj*R81?B)KpE1Wv z`zgbz8r+T5oD%9ht58pE2WNzX!Oz+Zzkog4?I}#ey=|U%{bfJMOZGuo?>U!{A&jN= z=KoJNRFLf`rC>AdNv@7=mwtdU$eK7(gC;_aje&{din9e8rdy zuJWoUZEYK;3J3Jh-it+`tzVv1Jasd-OJ0^{3Bk3aTE%odP&nOH+V*kXpbP^aTIWUu zzEP8N5us)QFEn*=G|KPjGWBd9cslxda)m&ThZO=m`)&~?wU>?;{yg@jx56YyeD4Z6VcR#-sz^Iu`K${I8@oNG$y0*#h**g&zMomMI(R-xtJVhBa#Eut zk}oH5HJVe~7X%l`&H(XYfC79E>k+|!=SXog4KwRU79j0xLf_NPZ}%arG0#GnsJ3ZT zCu3w?1(*`2L^GK8q_7C_k@PBCG}c=TP>rn1Ew~6sY~<#JN_i@MQK$qZ@#BxTUUMl2 zwaV0qxf~kKjOTb;b;Kl{)One$Do-XDECJe*=*BK!dIaZ2O((0go)*fMFXL2z6qCDf zK8`4w87Q-OcL_1~Sdb;b3pi+6aEykrG!pbaPB+BcY5NFK$G2VGbC%(zr#na;G$@*h z6}>k=8vfWvgw7C)(4GernwR52MSLb(pQT8yB$!Bs7v}H1)PHe-XFhHbupw!AeRvFk z-0x#w4eOH?p~Yt7+cpEHABe+JSn z`c%1oTREm^X`sRHC53u{ju=z}a|`k)xmOXC9c}MF`zeq&VJTlcltKFs8((?%#OiU& zcW-Yd`IkVIL@yVUHH18&1j)w7_>O@wmY`MYM=!_cb6AOwpI4z!<3t#tC@=smOLDZ_ z-e02mLWC=$+~xQ#k7SqvZg8A_Lh)z7sWhmd3@}ZJl}tZmBYVtL-o;Pl++Arfh3EU&R*6G-#bS zNWZrqM-4s&=-pp4(@H5rNu?ylP^H7tkGYsh8lKL_q2(G zpVinDh)?6bOE=4DuD|p!EDNpjo%1FSdb;UK_)Zz`&cvs2)q?f$aD>B4L4~@$8Qu#R{PB`<(Sf z6HLHW&P_ZPPjgm=RaYS0CV>eZ0ZFi!Xsgy#OD?Y>NnCkb=cRo$Q z5qJ@(lL-;p4c5*o4-`i@tvrikItg))H zo^AnObFBwU!#=7X&sr*3k4wwwvoc+06Tc7r81R)#im}pUOAKVCmHz+CEf2X^0yjhn zup7>zFw9^FI0FzwwAmf`&Yt6vXTMw_WwcSK-*8sl)Nk!Cg#*1?PgAzNQ5r#l-O z=$&hh@&nVKD@Bn|bZI+RsY%Ia6_wR=M4N{(0&|_`OeNIjt3bBpXj~kQeFPEd!WvEM zu?=skrAczoxRYFoWdHJiwc#bsbv;Sw zwkugJ$S&o63d}&m2}64*g3W!1@k>xt6DL8^Ae|EM?sW`U)m&~1W7N?OF-==jS2IE~ z-KKPJaM$^6wo1v#Qo1iMl5X7ezDTC1`7t^(pdZM_GFPWg(}Xjm zyzc!8gbKPBSXOq|E&uU}zI2MP80L4vuU2UTtpv5+Bl#?dglYpn82K)i+l+wmTYlh$ z1cmpPnKrf#<7*9I$Q-Za&8ZXtwv?f=HlGmptpS_mKwPP+9O^lE=lVRsH>Kc##OJvX z#hreLPp{1J86qsx!#^Xa*^*hT5|q*1h~fvW{TP`;>j8W^3vq44C%@Ha;%&{sG0&83tg&MI_?$V#@<98nE;l!$&a@cX|cM?=dS!L%$T z?0B&ljs=G5f?zG^ABk_`7DXZcY3Q2;jqO?nvaSP-4V|BeZ=d;@Ebvf3!Dx_I=2U&{ z=0#S9AP5{yDCbb-c_QJ8Ul$5V)$6tJKg96B0B_FVGfhwe_>uG*x3f}GK2ek zX3|i%F9!mk{EmQv*VXTm`BcOsn*;&?KtKV%-p3^3=xUfcWqLHcMV_%d;faOT_IGwY 
zf|Oocb#dXUq3{;TA6_>1Y8Lp7mU?I(7L)wOFoD>#Ntpt~y|218FWw*_=(t6_TAae* zb`*t0Vj|JELWaFA;4Na=_Xec-$6ZF1YDvf|PQtYIr&P;xAFa#2(dbeDSIy99#@X4< z2o;XYyrbD!n0a^Nw8E%hgxBmU4RTH~aGL~T%ou7N@`POh{=0$-yM*LwY*#|s)=%9< z&A|$xg5jWtJ#hqSPBZ_PfkJdwNuCh6bjK7Z&I)ex;lDy!0vfCtEsJwqRTgZnsbS*y zuV`E53ul<1tc4(NRtc)G4*)R2)8T7D++IW@tx-VP`4h}=m(X}t7pvykPcWO9xSoco z_t&;|34U?Jm>-y6rumEFe{@{rcF|53tIk3BZFUGTb5EeOC>F5`iuL;JsXQ!BxH+oc z^!CtJFoFd^1_>&LNQU>$#&y4Gaz}buc zKy+X1g~OJlK4+O-%Nc8`+HJGoF5h%OIoiF6wgRs5zw}(W)t9Y=(#Fj=TBOIZ^YFZI zAv-g+G*%Y`G+b**^P-yq(Kuz|=g_UxnALON+)oQ9lkVZM-IhgcR@ohq2NgMC`YW9A^S2I2lZ_mYK|+#>L9Tya|r zDEiI-`TNv-gmNtn`L^k{c1 z2->Mpg1wHvsQ(3j%=E1o$1P4%{-C0rzg_#~C>3>WIM5X-lBfoj7Al7(a;4fM}|6cK+C=Cx%cFU*VLHc;^dv&5H^bZ`|2w_;aH#p7o%} zm4kYh-l#B2@i1Sb)vU1=d29IP6O8EPR!Lu7Sl?sS-m)e@C$IP^VYlLOIkH&q0yjMp zQi|u_pfb5VUw(RLPpFC9_IA7XF_6U*qisJN1%hMFP3>3D=wy>*NPXhM&l0i|Ag|>tvJ}P@`J{C9=>pCc^uAIzY?tSH zo!R&viQX3Lo!=A1=&tzE)FX>H^uzG zTU{Sf@%Nr8hZ%j6(M&NVFe3&DDuzgg_YDCF6)_eB6r%_uQ4FHnE<`qq?2cHyc$(H! zb}%t8AutIB1uG5%0vZJX1QeztrP(+-x4nID2B(D|vdhxe(b5D6IAMcVR#c3&0s;sC DBJX;$ literal 0 HcmV?d00001 diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_invalid_client.p12 b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_invalid_client.p12 new file mode 100644 index 0000000000000000000000000000000000000000..cfac6627c82881932e6f308ace0d93aab7280a53 GIT binary patch literal 3949 zcmY+GXEYm(+r}dz#ICJUN^L=GReOa})zG49?^vm=cEoC}8m%p;z11qUsZqN|6}4ki zo1*@G&iTLZ^S&SMbD#6O?sI><&JT{EasUz#!ZB3BUM)(;80Wkp%LvD&BXNAJOGWPQ!4{jt()xp7hSBmNY# zw852O(5UuAfgqX3#VISWDOK*JJ4fzI7rN%*kCWY!ba zCmj(3wGeuXM(jx_Rvi-*Z!m4VSJbL_nuji3lj@}Mjy5>ciaFgIYUDb6gOto9SBiU} ze2!$K&AWxaL49#v{l<-IDb1eZ({^Z+x0W%uX!!m6VbDh1JH;PM`DN`!=qf3No1NHG zS~}kwA>&6c7CP5I`#0Ch3s&h@%t?*!fJ&Aftouu9<`yu|zh2HQFvf3a5UgL5YZM;w zF3Y3xANPIXgyx{b2Goa)lQJj=ws~1kJ}+EUyzKD(RxN9>3t*@lYm(P4A+U2+3}Z;2 zC37#e#MRC8p~hpvitOGs83|EbXz1WK8p@x`lp!+2|0Ki1q9GcM+2>@kYn zx<^wjSAe+X-MrdoS1+PeRc^%N(`)pFq*|2AaYDk}@#kZXx&-&BJm^ha4=D3V&dVCN 
z6!OH-ZbxDHe(LHPc_DFZ>yObvt1#k2qj+i=MwK*U{?hMrtx>RGp}=z%Df&Q!P*=Em zhdgWP!Ru9&)t7_Yc8miol6pE#?YKlePobhs^^0^|XrH1NPZF1zGiIxBEg9V~73s!H z_twnav%s6?nn!yh*O`1>mDPKIIRL$f`;?3YZFvkWE|#N!WfNCa)Y6>w6D)Md>|c{c z@#_gNH`N9kj~*5YWZc|RjvCS8Sev6`6F>6e4XnXP?MH|D%bT(x1s$#H8qaAAp<)5fNixE>s_g*)Ep+OQ^Z z`kh}e6sP|D>5ywee4sHneD;+lym_WF**lqc{IejRZ{OItaxQ0e>>H=aGx`Bj072W~Z1LFzi`>ahweLkby>i10!Csbp&crrvq#<_)ZNS`8X*#e2kIoAbR#j ze_}`2sfAR0+uUPPa_ zC$PR<$SFju7fSf6lWcVG+xL-Ho2sU87B%3kG>g7X%tCNKMD^qLnCWXpI7~}lM&D`u^k}+kLlJaJK7$)!>A;gTR^wTY zTmm@d=q!yY8tU)(HFxLn2>bexhIam{(V!TthkgU~5-px$e=F+CxIM`BEcl@R**Wv& zp`GcApD5*~IRQs==@SA_vgMt6ufoBc54hv5>)6>O_0;?At5ha-=M`z}hPT%GS3=1r zyaA1oVaNOl--p*^!9LLdB7AaNF(ZD&r|HH!fn{;<#cR2{aWDJ?hSX+?=Z1Ivt_cyU zDq*!p*Z^~tR)h4e*$%Rn{suuZL@{E!H4A|dQEGQN2lwrrCr%`Bz9nSXMXEk`b40MW zKQ^``jY73k4f+I6@e`Ab>?cyhh1$O6R3P$zu^D;}+|v2uNMy$iXjGJ1vl^KSibdQA z;ERG^`vqMV^mDL=ad+u2{eWn=!eWIINTmjuQFW7emT47So@$TErSoDcUtgUa*(d;-T=J~MS49mwcF~1~GPVV-qKS2!-$RvK;-TqdKmm4LCBTCF z_}xDzLJ`#R44uWRbR`CizCle1{bvFyIR(@*>UZES*kP(ksVssCAzqEUpK0OA} zNFM#9>ZtwhZq_Esv@a5hJ@x<9!KtM~) zdWaN1@QCA1=}-`xK}53I=C{`g5>AE~qFftF1nxr>cw5v-irIVqmXm(5^V_MGlxIx~ zK=GyEOuTpx^d9T!m!Lgzs9NeB2?yECW!y&kCnl~Q&9Bx_9F_mQDT3{tH69CWncD&7 zN!pFf^~mYk+j13u%F?~`(7i+$5lw3pjOnHi47yN0z#(`>?f`kc8f9!L$V_}p=>=sMcsd22O6)mKYL zbWO`C@-8%3-o52LvyG|rnbDp$+W?)Q#}wPV^Do`v11O%t(pJ+JjANbq2A&rFvQk7I zNik{ch;3efClbE{s5?Dei^&>!wh2hm7aRIii{aEY?y^8{(NtPZhJTphYOs^Qy?q36 zH4+1)9m(Z2R$K;TJ-n@O&yJ91#XWCw*_`yD!k+z>xlOZ2FwteUN!n!xL8OZD&bkY{ErnCG#;F+@uLN8E52 zB2f^ANa!!-|2y1Zvj5RIQUV|j1KNXQKwJOM?g;;}J5^8OW^EEb(SPg?gaP##HmENW zSp7!4-^;z zo@jKgh-&^Njkf3PJ{}haR}}PgNYF$a8nv(A~PxKOg*vZ&J7#-tPofZE!p~=oF3^+1Iu#S4;9jQ5A`*2(LpqX8By&`nz_kSRNH{ zV2EDPhUv~;*Zj_@{;ol!H~g_xQ)q3vxNDhCA7sLOcHNHhBFjWYw%Vk;eNn_I6i@Q4 z>9hoFCMK^N;f@QsnTp>%bUA3b6JV&;tH?PaEyBWSv{{f>ug&xXM58}HhbX;|Y)q(V z7ed?STd$M-VPI!aK@F?Z?_WxP%wC|Z7POU(X0|rpD!Hl#EI}C_Yeq;ceSavC-K%#R z^?b`tTyN6^W3BZlrxXBbeL_x6XHNrYhx1lY2gyFB zX2l{t+M|_zW z-Ss>hGY$y(q!kmLhEh#2{QZa3O15?$qVC7z1#btyS@1d7V&;-##j+nMF(Ne-&OBP@ 
zItq^@S-D2O47`t`gA3*^b>TN5^3T*pWr0HdewyYlHdaZ*OU~GRMHMrU?TQ21b-R*A zUPc#BjF0mji^pGFibSpN_cq@SrdzJ*we{@Jgz=MY$s}($Otn*ewHSWLNK2=p)30l^ zK0QDY)Isd4qaQy9k+qH%8f2K>++#wOd!{f3e_RpgYNj-;hTW(Cn&h9aJwpt5i*9;I zFLbJ+T{90T$dg*7a`3&?gXZYluD~$3RBrKPi5A+MJr%5FdfzCyhrd=!6OcTagNt5T zx)gpEGMr8OQ~vH6yLX!^bMx#ewXl7v|AU8;fG$xBzR52T#?jacgz!2^LL||r4WxzI za*RA_C~SNmbMhS``kAvSyRHCk;oWkRs+wMjbyYswQ5dzw__y;JoJ>j^d-c!o1^S6P zG+QK|T5Y>OrO0F*r9xU!eoT-?*2>MqV@hUdF@wclXM%I6c=x z?k%>F%FWuEkMdntQfnu|pT;q^bZ{oNjeJXP_3FUbW2-m6fL^vqPXUM?_HI9L=Rf=ZU>b!Ic-yUJKUZ~KMi0CHP#P4}L&aStH(yDM} zX52<8XFHt5DJAM@dWm+TJ`y(+(gY;EOBsCdpzckzOxBCtpi7AB=njQy9v~QqLuQurb3pO(U85nld=k-8|N zXjy7Yv-frAhAsD%Qu^sG#uV$BWn~X&UvE(%w)6qbaztZn4sI`gFJa|NR1mid<5Jo| z$zhd?vu`r}4M)Y{n|I}Y0wxlmBm~|j0)XgVeQvfEo#pMsx|k5$4GvLo^W1a?kv!QJ^`SC7*ZZd> F{{td>W=jA7 literal 0 HcmV?d00001 diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_server.p12 b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/mutual_auth_server.p12 new file mode 100644 index 0000000000000000000000000000000000000000..368462a0e1e9081be7d5cc6d2eaeed12ee67da17 GIT binary patch literal 3149 zcmY+FXHXLe0)+{L8VCtOK%`5RVuDoZ9qG~o(mNt65J8HRurw(W=>`m;Dcul?vIy83@Ax?(h4@n2CPHGt|b%lW4Pzf44B>Lzhj6l6V%iU zRIz9}-@=ZuG;)OJ67;E>I=;J1{^W+DuFlnPKw!5A+0NCoK-$EyMnLRE;nXBFT$s&3 z`ZQo8^QoSwr|C`QXyLnd?FO$M@S4G<(%c9_3JQiDh}q1iB4Gem?#BD_qa8VN;3$ zw>Q8=t$+6bhb~sHgms@RN#DsGt?d*J{UfBb!rLL!uLO7L9#B(*XPqS#N{&8a) ze_5Zs@a-#>mic;il?Eb3Z?ZaqhY{ts*XOnAB}U}Y_F0CoHmT!0L4d^si05l7#msD# z8Y7nu`xv;-*V9d`-_}dpmu-W7xQpoPs#1<0EZtwkYx`&Nfm@C_AN=a8eP*F(Kh61q zkRFyKooAi5Ku%jiLhcHs7d483U!jm(ibbS+8VIZxwLkvL!kgs$(y`y-@?d|(Rto~t zePnu%-ThgwC3OXKEj2zrIdV9BkGf)@;6+or1M`|&iY1>8rH)#hqnksZPag3*)jTwX zBhCnRp?bl)#me(zRYUW`n$sTY`Q`>zT5yNRK_{mb1o>s{!q-iKX${2Bs$DcrA4a*U zhbJ_cXvXrVExrOviuyEPnPMv%p!zIWA?O7+bHu_1#KLW*hEK2CrZwoDE%`aUj9<{4 z9)Uehl21KcgW`ZurL)D;gX~FNGr#J4z-?g)uURFE^$y3HHgRv?i-Dvi>3e&2IUDl5xnAY}qsJjlCQi44EW4|gEcf_QA zcRt3bN@#xhGZnF2apAQ-#>6Je7Ho%K8YyJX!t7+lBGE+V_UPPAaG-)e?nO+4b+p0x 
zm)T2aO~8*zqDR~;7~ThsbkkQ0+m6HzinS(++su!B!@h}a-oFzO*r#l(!JyoasppGo zEjDy!}dbQ!={sku#6>L0&xZKTg^jh8a z+Gw9KxUy`sgEM`Y&A%Z<%ccl)ww!Z!Poh+ZHoGFOLFjG4DpFawkJi&g5*}bmjZAv| zj&HDqCypJ-N8OfYSC{za81YT?&>Ou{SMDDor`U8~S1=TvgP!@+K-_wla z3jZtvw~ozV`xJh;+blSk^uuE4@R$pjra03|EJh45$;F|em!D2w@qBhgkUhipC;QR@c=Pia2AHbRLpP3yp zA}GX1YbVAocOR?b;nmL$`jg2-eFXk2p|*?r0HaZz*`pfHh;f&wG^c7*FtZ{dlGb)Z zzJ2v7n$XWs%q7sfILmlR>AtItHd`DaSg{)a^eTJy6Vf< zh(s88hr4x z!2w%hyvu1|<%^M+Fyf#HYHn8iEw@O1E2EcCc5{_`rc8PDL!5;+-cN$#dB)IgIW4m5qo9D_;(rohFLeH6TKJ?inlgwd;7XSJD$t;H`-F}4dZYU3!-|r*FLNc%vT{yowP>*}J98&Gp92Li zkBUmL`OuPAp}0$srua0kLN;j&3-x49-H~KCmta((%nf=8UY*OKSYDw7Te|uuz)~C` zl{vafoU){#MtaY~l?@j%$4S-)I3_YR6wf+uB#?g;W7Sld8`kdUAt+_7)&&HMn6;MM ztf2z)`hp0UcV(uD?168G*U862q2}I>fSKXbN~wDzp1xR-yRDSEsBlm17o}BlojQdP zfw$+*U4omoe>W~*U1hC|01u9ytAtGI%aLdXY1NTHt+4^DtO?GDqe}MU9FJocWPQ|z z8f2j|jp8?XnRqh&P zJ(EfOTz$gwWP?nVhM0yMdFf_L+slm>`(h=*CXwE2EF%QmqW@wp!_bV6aUS{frcLUM zA1^`#e*M4i$jE_!0ofm}o?fm{WTXlGcf=(r{!yHA$x3hpToBGkM=J@Yq2{Kg0_NoyT zt0kpq)u#v3RHw)h(3)XfB?GFcjuo5Ajb=0{`Ud|0Fpb$a3(VG^2!?`n+-$SWL;sP zdL4o~z_od}(i4f|`Rq|g0Zw7&1jBbcb~tKEW&?zSHq*0j#+_AJUe zKQt6jt&+jq3pW*QA9K&kv{EuFp}!@SCxK71!|D7wq3whkzx1iBRi#aEU`Kec^+@;xBpZ5%bjDdWS zH@q(bzFKER#(VKE7HvgWC}^oP>N{0aAi^I&Do)l5zFn0$?E?l+1*{@aZGMA?V@7Nk zr2IFC&2?M-pi~=cdu2iVry$9-uY@(-vD-yJNyyrzoA36FLXh={sGLYtJpqa>{+Tz$4k3s zkp}A0$zi-S6`;za9Dh?)#u_z2sg(Uivgj6+?_#$D->%5xMxe(+7|&BmvhpWJ1u=cXR++JJ`JbXKI_aN5Vo_FnFR$4*P`$6|iUdFM4yK!cktGrV?#TGyOAQBObmo%zs*(jmv9P|tqn5cI3Eik$M+Dgf$ z|3c@ZV+Tj{-Cy9UJ>zA_Ufeu4MHjN@iYuhP+W68#K1>k`-qozyTC1ypq8jK|mPU7J zB{IpnnhHZy!ydMui?0^;L1jkg$7)LR8-KbC)L$ZJ`833q5w7@%;kppa%XK5+ulb7Z zfMP@MHm$yK^=h4d5oP`)T&bnHe4>yyul2q1_DzjY?ykk~tx3l}dA4rD4Le>7$jjJj zSz0E3srAJr(N*~?w-t~EA7icsg+J$@OBD8jal)I-_QJ5odQ|VGpOBS!V!wZ0;u>Uz zK=?Q$4H&SDrV0a*?67e2*|xI`9vK5#B#j~`cSh*;!Tl%A?dCJx<)v^-rA`u-c|7xI zdh@HLMi_omR*|1ql4izx-iicv3Gd!5M}iN=){8hK$H#806vRT;1JW~Qv?^A5 z=l9aVCR$JIt{jFD`9JU!VHgz{FpRRNvDE2_fmr^-5HkRH>J|AED$@VyME|!=eX}Dd 
z086m%-#Rg17zXKd7rr1^9_?xNEdZj7JSNE-NA>aEeBSran)7Z^B{RAf<2Rj`xn)^7 z=4%fwcm|1Ooe8QjwdKjCkP6<3k1(kw6>64m+2=Hc<@4KT(Bjz*hx&0L&A%u6K+6Y0 zsT-7hh0%L@Vq7?;6nlD5lWcSB3>53!fNqjeGq|K>2Y-9+J!kUqOQ`CUPmiKEV=m*L z`E!CfLE^nz`F6rBHF`0n?rhj$#B+Wf!`&3@>>HvAB8mtwcfh&tb#q!H3a)RdQdL?M27o&j z&IWw!9?5Cpacw-AT}sIm7@^rWKeONsp_fR##W|&!RaiXDgY18{{e!er^~bsKf!soI z9||PPukGCA6Ez0b(K^ajzde{~cLFW|maj%T2_+#!k_C*#J$yEm^du81Z&xmw;2SmL zZVmtV{D#n_%?Y-q>>crlXvf${)7pYOY4o$gnrpV(p61N}Ep!ZvB3B?~4!D75T9CJP zQJM69TY>STaaBq-SDd#ZHj{p;=gE$5`+xU7T)*w8-)l{#Hu1)l8ELl~0(CtMnljYeaoY87vXkk|nxB_aMa&{L~L9#o7 zH#w-`LPkXA)orCR=e`5bOp*tX74?g6WfZ9yLhGD#-|S6l&^$CDvY0fB-W)y`GrCTD zn&Y0kZ1+ge?*!pnooIs`PoiD|-4B%ZHFJbok*fhS;}pe6{)$S0kZEk`&`!H`#T*cm zR3w`D!AJ;5XpvBCS8}s$SnK^z#+NH|{Z3@CK7=fSuhWTv|Ao7{!I^x)=Z2?q3cK^Q z9ToAs=-GnFz;t?D{}3)|x2HMN2Ig(077WeN(`Fi}1g`SQUHwN(>|J_^0&2PZO_l$Ey`)>(d+m6|r@bz; z<#uNhVaj8p4xhG&^<|1FmDmsBgho$zXinR!WNJzG(#@X|pr!nGgQgv{297?0AG7Nf zr2sbL_^5|Zl?&+xm|5~4xo1T~Y{x z!6`sp?2^-ixbnni1vqz1U;-_8a#(6JH<8_ZU8npKlB$NmKqgDxG0r_K7@N4ie8nH~ zp0PZ~==;I7wWPE%aXIG=C}6)AxN zqWp%H8hTsf#t5>bjTWnf0qiwO*~lc@XxK}5I%@E$k&rQ!ip~vGy-wPB(UR8)3W-Dg zK=5Q%g&zxqs`d8BLL5;))}>b1;1jF#0xfkN6TYl7ajT^QhJ2hwJu+V+BLP*w+Dy(v ze=kTc(WoM?2Z@Q{%xqW0 Mv2qxG`fnxw1#!IPg8%>k literal 0 HcmV?d00001 diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsaValidations-server-keystore.p12 b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsaValidations-server-keystore.p12 new file mode 100644 index 0000000000000000000000000000000000000000..dc16cab487640062f20c84eec3825f71c080c9ee GIT binary patch literal 3456 zcmY+Fc{CIZ+s4h9v1NqpJHumaW31UJTXqw&4zgxVh>9MJHDheiDA|*3M3zKj9eYKR zEku@tF^U*7mcE{IzVCb9_mBHr_qne7+<#s_1R4?x1kfYUkjo6rvMEuFnlJUH@rDJ$Fmncn%D1)HOyX$Z4y`Ci>WL6AXAPrYE>;0yerIuy;V_X#DTNB0+V)=7jv;s73^2dN7$y6SGrgisa1cAYKi|w&uU_8h8nNu8= zW1HR*dcgX`Uf*KF1RoG{90L1T-a)U-$}xHiFkz;j<1;Qq)t98TjW{QgRP_H!E!iBG zc$Svw2_If$K9w-`L(i)Gz@%{T>d=SZr+=e&_%3hjJBH_Szd&wU(H1ghP%B9=Gn=ob 
z$9o{d%H442Hsra^9eeqx4=bp*3UyrC<7(=bnng6JrIAjlC7A=f4ka;TRjNj-=p{8}U-+Pl= zeS%G+Q73x@77yb33urE+<)dQX#^hMXPHXtY+;)-2e?i^zG z&W!$qUO5%yD=xY=nt9mSHY;$$|NSKhlP1d@n<&7_?^OS9auzhzY53`vV$Hf;b(6w# zp^Aqnvr~EDt9Be0Yn5!x4n_JTu_>GOI?0ta+55RV?w2i~SMn%J($;6}Mn}As@(KsP z6WU8rt#O$@)ukO1Ti7H7JLg(4VRt(}XH?9o!p0iY-kU}kj$#vd1)tvKFl?6W)apoZ zMZMgNQtx}sY~nR#0C?4#A+sZzvT)@aAbJ8*_V%vU&ZI9dIl8z2ePh$-Oxn|C-778s zZwgh%iC||fo`Xhowls{n>OM=o*ldJm$~qxwn;)j9*~vxv8_{dGY2(Na`yr^VVr=pJ zgAibAQsZb(&(6-^mkh@+&qv~h;uoECLOsBA&n?3{R03lC%?Y0h?kid#fV-@tO^WQf zPeSeRH6vjCV2(KdXP=g4nX`y7=k0TNPe8F3P0XSg*#=OC#RDJpc#!)CmL7Qtx$yXFfsxl zw9O#rh2f5X5dd)u>f!%vB5Ffl3PjLM2-cOWSWm?NHJ})C`GlpP|%)p z#>k`5;EWqQWieYbO2y^1aweCgq*%#SSexFfKfbM|IWy5xA=9mp;)2Y&Ja|>gBx!tF z+X>pO_M7tIcxN;GQ>?%a^$N_Z>b3`95vwU>5HWp~!w*j>aNO6(lrJEkaw)@y?go`> zAGr;&#Bdm^(X<5$f03Em$G-A*L+)=^2Ojm0d0WOj8KW#5zOR0DMCCH;T}csAj|M)f zvmN06eHzTt`0%X#|Cg2t3CG3n@EcDhHm;{R?156Irq4| zexoZA1xu28hu`cj95|!cZ^A_PkZarco-}vp}HUZLwL<( zeet3oPjtQP#cF8Z?u1K*2GN*@PjC1spdERx%CsR_V07}h-P(Dne#7%(dzP*f`@^hP zh46TZ6Mv$^D=-~1Dy)1si%xv)E5>>zs@BALsZ75$@S|SXdEp>F44ECAJ-d$(ngaM?G%I#Fcj3%+tw8+Np6)47~)&kulwg>JS^K z$~^KGzH_D^u2z~_2EjaS<0U@%{URU@{?=-C7gc%x*+7!2q0y?_J;`^`Wp)|%K<9Q{ zWl@x`Dzj~OLU&#=uBe?wu3h|^4Bbo%B8Z;>E|7CJQXhf6Lu5-|3)rAww0MJ*#!~V= zhaxVLZY~;AuTN{exu5L;aLmLjL+cEQH@W7IMV>x8o{;tlRX@rXDwxGNw_FKty+$tn z5aZ3+=Wi524%TX zxg|%-M5{wzh4V5V@2i8_h zlrTJPt{k2COX+pp6d`kUz|HTQ*Q@6m9UkRY4{OG?V6XZ}tiXLF?IHa90)u5nqrEse zJWrKT_H3KC@%_%Xp{_!2qg8pqP6wB44NG<(oa=)SqPle7qLI)5-xq>&f31cGq@rb4 zO)EO(ES_F#?~x0m6Pm6Ii8t6KDylRU-O81N(GEc6wq;_Y)9B>vU9<8hA(e(BvsT4- zS$`-@Iz4MAXbASgnwVd@KIrnM*o7Q|AO938&w$DW^1Ih`jhgUbdPTlg?5mhB9^eJE zU2&<`0eN25XTo9k)`GbU+CZe9TQ6^&^p1rORI_Q7+Aam${5+D)e03VS_TpFCjWs>L zchy!`0bWY4@^ z!D4$|!_!ad9@&>J@_Va=eO!Dt?pcJOu zDJNw!K4c>*7G?KJ+sR3F(7?~Zb%AxKlLRCNc=8?Hz&UgDcQ{57iMr)RG{D@I# zx*n@q-wYM=V5~-(@6-_YSm_Ev{QujLnM;XFgp)3q?k1f#-St0u@n!Qnf917DBQ78!I YkwbVW`(r^&Xnp%gg$r8l{}k!}0Atv3&;S4c literal 0 HcmV?d00001 diff --git 
a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-ca-cert.cert b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-ca-cert.cert new file mode 100644 index 0000000..78ed3fb --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-ca-cert.cert @@ -0,0 +1,21 @@ +-----BEGIN CERTIFICATE----- +MIIDYDCCAh0CCQCEtGYLeiM4NjA4BgkqhkiG9w0BAQowK6ANMAsGCWCGSAFlAwQC +AaEaMBgGCSqGSIb3DQEBCDALBglghkgBZQMEAgEwRzELMAkGA1UEBhMCR0IxDjAM +BgNVBAoMBU5ldHR5MRUwEwYDVQQLDAxuZXR0eS1wYXJlbnQxETAPBgNVBAMMCHdl +c3QuaW50MB4XDTIzMDgzMTAyMDUwN1oXDTMzMDgyODAyMDUwN1owRzELMAkGA1UE +BhMCR0IxDjAMBgNVBAoMBU5ldHR5MRUwEwYDVQQLDAxuZXR0eS1wYXJlbnQxETAP +BgNVBAMMCHdlc3QuaW50MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEA +56mDzI8MrVmNSu5RyMxTBlbJ8K/A2Sr0j5d/pw2xkxINZkwEa4409hR2QKc1YW4C +NV+QNnQJRlDCoqvBYFv67kQEunEsYG50byiKbceYmU8SjXO0kxujK+vtkGW3LDsA +c1KhwNKsN3JPx77N4PvRxk6YK1XcvFs0aBwrPRjJrDnqcu5pjZH+j0S0c2csU5jU ++Bc+fTah1zRcilFWMSDfmInATU8ckadnNSCPDR8DtcIwubYJMT2hAiWi1LULWd+s +NUAVOzshr81MQexsF46vQ+wLFRBFwuvEzSQmwpESwXSPa8MScEc2UxytftwdAkL7 +x5h8HxuUUpSSikVSRu9nJwIDAQABMDgGCSqGSIb3DQEBCjAroA0wCwYJYIZIAWUD +BAIBoRowGAYJKoZIhvcNAQEIMAsGCWCGSAFlAwQCAQOCAQEAGFTO7GZ2UhzLBAXn +CNq5DhuL374h4JZWdLa2zyvqlheYX2y/zA7xp6PQkefPHYBgLP+OH46+UGi0ZPuR +9Jm6t1jzn36jtVIwfN7kc4Dz5chtnLPWlXw7jYjUtT2SfHAAUBN77EjD8XOI7apj +BnMAmQngfke3bUbgvV5nUWnnKSiYIg+zMbEk2n2DmMokJD0nzV353p928T+T8+7D +kTXWUMRjzKhVq8kuMnJREwXCd4Is4kOj31mKZGuRRLJPDX4IGbrFxmqZI86xYTIN +3qaioheon+I3gSJYsl4QtCCZVM1RoTbhznFPnJgqSBR8DDAi1RLodavMmWYYgTcx +FxnKjA== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-signing-ext.txt b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-signing-ext.txt new file mode 100644 index 0000000..9716541 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/rsapss-signing-ext.txt @@ -0,0 +1,21 @@ +[ ext ] +extendedKeyUsage = clientAuth +keyUsage = nonRepudiation, digitalSignature, 
keyEncipherment +#subjectKeyIdentifier = hash +authorityKeyIdentifier = keyid,issuer + + +[ exts ] +extendedKeyUsage = serverAuth +keyUsage = nonRepudiation, digitalSignature, keyEncipherment +#subjectKeyIdentifier = hash +authorityKeyIdentifier = keyid,issuer +#subjectAltName = @alt_names + +[ extca ] +authorityKeyIdentifier = keyid,issuer +basicConstraints=CA:TRUE +subjectKeyIdentifier = hash + +[alt_names] +DNS.1 = aws-dev-node.skylo.local diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_encrypted.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_encrypted.pem new file mode 100644 index 0000000..a17f9dc --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_encrypted.pem @@ -0,0 +1,29 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIE6jAcBgoqhkiG9w0BDAEDMA4ECCqT2dycwPtCAgIIAASCBMg/Z60Q85kVL5kv +q8WIIY9tbXo/2Q+6rspxdit9SRd86MV9QRdfZ5Vjwt0JTa+Rd1gMaNK4PySW23bq +F2+dD0sjVBcE24Qg0h4BcmL+YBdTftBfk7NDH/rHhsew7DZru9fdDvkO9bV3jXIz +fARW9U7JIfgAi6CfJ8Q1PS7sg6dVtrcjMRIie32x0TSbZrn+h9AaXpLHsC8oXiyY +BhWe4i9B7PobyJ0r/CTBFhbfUCGwRyHac0+bZXvlcwX9wy3W7jagc6RDlznOpowU +FP35CQGeKsJ9WD+yy5MU8X8M8v+eeaJk4oX+PSWJX669CxbYocVP/+LUtOXpe+4h +7yMmVNLUtsgBlY6tNsU0XBQkrqqb+voSxVBEVZ1WTKgLWsE/EiQ2P2GU8Gnr+J6c +/yHxw0D4q9J3jV40SiuXQlgFwlf8u9FuVjOcGxTidfKXyvNqPKqgkf9QD+7E09q3 +JQoNbI/A8BXrpdx9h87Gt0TblPwVJP2nf5whig9W62R4y9SWybUUNr2MFNkvEfKe +1QK8isf+HlvIO+VBYi4jof9HkWLwnAszlkpC+k1cOiSjNRn8QyLzsqX7A/VuS6W8 +6kKeND4yRNA4b7rfQqhyGg7gBwiwN+22UF6SKiikX4TB1ZyLdzlbPe0L+X/Gq0Jz +Kf+8/slgzB5K9WpDtKsARH/lRPAx1rcascvFxMuCJL5O9MO9l4xWDJor71WgPC2N +KwXxvEW3Kyvs3pSgWc8MC0BKcD9WIAahAlAVmSQBxDNWvJlGTgUVhzPqan7h03Fd +nWAxSn315ObfK9rjbqUBO9x/nkSZFS9nApmeiWkOIwVzgNfAfb9md07TYyC/rpK3 +nGIsThekqqQULMQaAPmEFqUj6A/0KlpBj1gZwddYvVvEL/MuQO0QBdz4n/OncxYP +TVoQEqXsndmNQnkuk2Kr4FACV2M9rbr84HJUIZVGGVSM5h80GrRqK03qpTzM8Nkc +e04R4KDpLDKHm+G4xYZbbraIGXNTkhxTqdNA2FyjJWFurmpQyFay55vC6WBFBVNA 
+BGVIqD1/9K3dJJGlpiHyymRCK9YGvflZlSr7dm7PW7PPEthwTijbAHkABOKsFSiu +xaUj027WIVuDb5FFIAaF3Wmn4GFXvsSH+8L95CQuXGB8J/5Buo+/Hg6S7PeDwrf+ +qNRAfg9vxo+AZOWpWfGEYGHQeX6BxVjdffar9RwL99cele4h2FgBLtIuAXvgLPyx +b+MIjDliCe1Nqx0PCCuaB1xRnaKiwbl7itDidzI8BUAaFcKxbBH2lpr44+vYPVHb +70Xrw55RLvrVYKAcaZgryTNOvbRatifJIMg3kf8V++2rwUMoZ+DQfXin/C4S/2/b +c6I1OvYaGxmI1YiI6qSpOryDSzTNlDEWcdh5feuixiP5RbyaQFswq2fH0hsWWHS4 +OsCeqT0nm5vd1CdUFQJ4Nuh/TTdgCAVKk5yJZJvH2BX77I2d4T0ZRGHLDKUm8P0E +n6ntrMqLFR+QooONAZg0DTaxvbsCvaupRJCn9NgiwtXyYJKbvf5F8NEOe57NoGwd +LqQ332mVTuJ1DiqnChLoe7Mz7OY21RsTa/AK5Q/onClvBATrLD0ynK4WiLn4+hGs +HK5t3audgdnrLxs4UoA= +-----END ENCRYPTED PRIVATE KEY----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_unencrypted.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_unencrypted.pem new file mode 100644 index 0000000..209a9c0 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test2_unencrypted.pem @@ -0,0 +1,28 @@ +-----BEGIN PRIVATE KEY----- +MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQC4HXSOVyXFCuOT +ja+EE32WcWySXjmCjxFlSFivz1s+hCV7d4Q3aZp2kVCgL+JjCBZcVlZMU6mv3Qd4 +Y3JuYT7zYDwFGhFszLRnd9qfebCc28Ylr8YX+a6hdJ9WZOSl/hHGrIgrqTYN0y+4 +H00kmX2hkln+OtKGgkulglOF/V01IcrOzeQRTPrkCxNOPAp5AQjtQJ/fNBWrv9Pg +S14FvoiQul1kyW6Wfq+6w6QSpfh5kEObYBpQmXJ4ZkvW82E94aGoOiZA6CMGxbFT +pMHpsWRm/BrBhO+P6681CfTHUfxEwgNpJtZ2uhErLdP8JtKg93JSWoT/MRKGiflt +UxRos0VpAgMBAAECggEAYiTZd/L+oD3AuGwjrp0RKjwGKzPtJiqLlFjvZbB8LCQX +Muyv3zX8781glDNSU4YBHXGsiP1kC+ofzE3+ttZBz0xyUinmNgAc/rbGJJKi0crZ +okdDqo4fR9O6CDy6Ib4Azc40vEl0FgSIgHa3EZZ8gL9aF4pVpPwZxP1m9prrr6EP +SOlJP7rJNA/sTpuy0gz+UAu2Xf53pdkREUW7E2uzIGwrHxQVserN7Xxtft/zT79/ +oIHF09pHfiqE8a2TuVvVavjwV6787PSewFs7j8iKId9bpo1O7iqvj0UKOE+/63Lf +1pWRn7lRGS9ACw8EoyTY/M0njUbDEfaObJUzt08pjQKBgQDevZLRQjbGDtKOfQe6 +PKb/6PeFEE466NPFKH1bEz26VmC5vzF8U7lk71S11Dma51+vbOENzS5VlqOWqO+N +CyXTzb8a0rHXXUEP4+V6CazesTOEoBKViDswt2ffJfQYoCOFfKrcKq0j1Ps8Svhq 
+yzcMjAfX8eKIDWxK3qk+09SBtwKBgQDTm2Te4ENYwV5be+Z5L5See8gHNU5w3RtU +koO54TYBeJOTsTTtGDqEg60MoWIcx69OAJlHwTp5nPV5fhrjB8I9WUmI+2sPK7sU +OmhV/QzPjr6HW7fpbvbZ6fT+/Ay3aREa+qsJMypXsoqML1/fAeBno3hvHQt5Neog +leu3m0/x3wKBgQCCc8b8FeqcfuvkleejtIgeU2Q8I3ud1uTIkNkyMQezDYni385s +wWBQdDdJsvz181LAHGWGvsfHSs2OnGyIT6Ic9WBaplGQD8beNpwcqHP9jQzePR4F +Q99evdvw/nqCva9wK76p6bizxrZJ7qKlcVVRXOXvHHSPOEVXaCb5a/kG6wKBgGN6 +2G8XC1I8hfmIRA+Q2NOw6ZbJ7riMmf6mapsGT3ddkjOKyZD1JP2LUd1wOUnCbp3D +FkxvgOgPbC/Toxw8V4qz4Sgu2mPlcSvPUaGrN0yUlOnZqpppek9z96OwJuJK2KnQ +Unweu7dCznOdCfszTKYsacAC7ZPsTsdG8+v7bhgNAoGBAL8wlTp3tfQ2iuGDnQaf +268BBUtqp2qPlGPXCdkc5XXbnHXLFY/UYGw27Vh+UNW8UORTFYEb8XPvUxB4q2Mx +8ZZdcjFB1J4dM2+KGr51CEuzzpFuhFU8Nn4D/hcfYNKg733gTeSoI0Gs2Y9R+bDo ++cA9UxmyFSgS+Dq/7BOmPCDI +-----END PRIVATE KEY----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted.pem new file mode 100644 index 0000000..58d181e --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted.pem @@ -0,0 +1,29 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIE9jAoBgoqhkiG9w0BDAEDMBoEFDBlaUwB8TQ9ImbApCmAyVRTTX+kAgIIAASC +BMhC8QFNyn0VbVp7I+R9Yvmr+Ksl0xZshGg3zaUN8/HRblNSS3gPiP673rmnhcU3 +PfSNFR9hOrTqdtd5i6Qq4HznECs81KBlqRNB9ihgy++ByFkf6GTzdfBA6zJInhNx +qSWjUwpFtV4or1w/N23bTcpdGmjfdCSFBMQdbkIDgT7GaWxd3mCLxSbfVzF64tev +x+V22nA/TR0VWnG+aj7aVbReK6VpepiCX7ZmQ5KehXAeB0SDrgT89kcz2VIfDxvE +hkCymNTcJY/ETdPfTSiR+DSZvVJMgVmfk7j1toZZSnoMwl4IhlXmIPmDOUE465l3 +sNWLygkNKymTmMI5FTT1hChAIdsmeVTfDmVzNPK4HQi5gfEnTCy0uxj9U3HCZWr1 +Zlzmw7/430TRqNYSEJ/XkhFaV5V+6LfeZOyuwf2VJAs+CwNo+UYzEQqkW11JMqhA +i9fz8bCNoy4/dyWbE/wEK8UPGif1rzCpoodBYeWTt0QtHcIokE3ylXWyTTarz7jV +u9Rnbq4HAXYYEwPjLmWFQ6NeD/rx/t44oEAyekxS+ZPIHNTVXRLBH5Tl/LDkpK15 +x0FoIZ0vrDiFbmtHCq/TeDyFtudSbmihnn0Of6PtXKZJpXgEADQBnak/P4IE39/d +1hWd3H635goC6OkqHv9IAAyLlCNZCOVqC5Wa8TvyZdaKi5A2mZfGrpxPrUQDlnqN 
+8d3xlysNCaRH1hSMw4hGHu0xxGJaK4DQtklxfZB7IMMw5MkQh6Rim5TOXfopmzmK +PISJge1atiHbVIBP6sr3Egik3h6v0j7xXVmwj3UUQRaSBznZ43ShlYieLnin9sh8 +x/gLyvQrtJRvScN6skgrXFKVH3Jojxut9if64jjLo4C61UgNrvuka05treRTI+jT +hHB3GLy7hwSHnbsOvwvYbG3WgyePPq6jIM+LV4Vm3fPX6NPNI/jZMebROGwjTL0C +2403yvgeIpEOQyZpKsDBqAwgKB91Na53K05qGSbr8AgcZvgFflJdLzai+5Cg7hNg +YTEff0NKPeYnk4u3xQ8EqxI2jwdqfgzd0RcPcx60CHRBTULaKOU2sAYTSpwQmApj ++TnJNcQnWRAEcZ35b/b+oGlVH/BUmvjSdu2qvvU3g4GoHL7MuVGvzk0Cgo1Esktt +S6gO/pTQPaKGJ1ztxoHu2zzi7/URaus3sqI5qV9krWMSa35BMG21Eik/y9rou6LC +yT0EtMLOCxSrfM1I26XTU/7qPIEJlVZg0CJ39niZ7EEm1Hef0cmT8Aq9t5cRTyvR +BqbqBCJpcsgeIZUMH6RJ1zv616eJvY7wjd13Sl0Tbj9+nNS482D9PIlaXSD8UySh +mZ0bMPhCeyOsmRmz2qT1X+Zct8XtdXc/NPKBA6rnOtH8vJAHn7S120le5XIn5t9l +rDiO1Hozhb+0xcTk+SNc/vIORA6KrBoZrNpJpmyL3BzRp+/VLbR+/S3ikTDkYj7J +sktK2ap6vK7u50Jnrt9C/wynVACzGx1tlDVxiVerDmwjfQWL08qCXHlouEdjh9dD +L5XyVlT2FxEXXLRgKGHxFaSQw3Fzzug/o4SgizbNjKffJU5xQlC0aq3WX5+/l3Ic +LWTalgdli3edsR/9RGuu8EsZ11dmNh3csGs= +-----END ENCRYPTED PRIVATE KEY----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted_empty_pass.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted_empty_pass.pem new file mode 100644 index 0000000..d28b047 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_encrypted_empty_pass.pem @@ -0,0 +1,29 @@ +-----BEGIN ENCRYPTED PRIVATE KEY----- +MIIE6TAbBgkqhkiG9w0BBQMwDgQIr2JTEf0lU/cCAggABIIEyKHtsAnPt1FmHeAF +AWZFF4KddWUW13iaaxXs+cHENCN1dpVxDPiCWNIkstQxC16bwKlfb0qFz42jQg4O +DM4kmtUyJoAt76wewteZvFPjzSuIepvJc8+SVaQuZPafSbGUcSPyATZbA2LmRWcx +HIvuya7gJX40aZm1VT1NM439UybJgmlrByoRGBjsrvYXGMZnMWLEYBgccN1wZ6WI +Hpuzv8LlIqFXL1DZZFFeFI60INlvfdwZWsz5SGf1zP68ZoGlS3R7LuJWMTEXfDUR +CLSturHFKI0jRAsloSYuHV/dJ6Io8hYWhAu84wc++dMTT+iSnaMWv7McF2nEFWx9 +UVKZiar509z/Rm/C+yN7TkT8TMaPRoYC6jxDW3IsE2wpBdLSCeF8Nh0vCNeay13I +2OcOFz1UrZcopEqXwl5SBUfhU3VjNm5N+h2p9WW9HTa3TXZDbZ04kyJEPYj3e8OE +viVDMRunJGWRkJ2//oCp2E9+NBYMMj9gN11mCUTyFmfpY/Kec/0nhQMCcZRlMzUD 
+i9AAzCVsm1HYnSzRq4JwnTbt6KBnPyS/rW+IfKQ+2aI/7Or6JfzrpX/v7/ECe98S +zaVhO9rJjsfML0ceR8UBK+dkl1R6hYV2+xXO8UZEo+X9IR1c03gDXMW0nZ4/Jgq4 +d3PmzuAuObCGnyr2k3PJqOoNZk3f44qfkNPQgbIROgX5zA8GNpL4XxDFFOW4YN2i +opod3doJ5r00WVmJowrKjWYMs0Ljik9FFla3986oZ+s/+WEWoFOssXvJMYiYQUIT +mrvXsyqNI+Nhs1rl8zzuZQ0BKGfwHyrUEcAgngH7itKR+IoqNXx1+pcIqJe/oIQr +oIEr5bYxAtDSi6DUGrvsuVoe5l5ByE+YvW68S7gIore4UXl77EVWQ5yfuyFteQE6 +Bm8vcnjGeoo0YDPwwziN1MUsYJPX6cnFIV20phXm4w5YAYfrI2yKt3AIKFkDU9n0 +5yLogR4NlkKgHd2DjELJXqvFBSivW6wZXL0wusDI5imopET/SD3vUexYML480Ulk +CxhbYc9FAgcKgYr4K6vyPAYj4K/w0qKC7ceAdJarPxS2qeXBHeHAepr1wHBu7G9S +7J8yHht0xB06Ar3nJ6J6KcCikv6ZMpQImYeSbE/5jrsh8hnqdDmKz5dAvwq4fwy9 +r0zUev0HB4+ON4XH8XAcwt5p/3zJISRpQZb4G7vm7VaA5cIbPOneuMxO2pw3CTI3 +qXnvWWkGdoc4BYacHJxA/awCUk6SrglCM5X479BG+ZnMzkCuXbJHn+zh2BgSQ7B3 +Z3JsNtnCaPGp8QEqFYk9tLdFcZ6/PlUsTawcIjG/6VMvpLcNBahHqwqDg1rax4Mx +xJDnGfVXFsbBQIYku8yvrXi7QnANgD/N+AAAjo6CDNJZmhsc2nhE7fsVWxz1hbl1 +vK1N13kT8CbfOFCUs8Od/O369vbv5C7yvK3JeG11sNfOZ/MqYBAVPDFLPdtJwFu6 +0M/c8+TDgLOp+qlRkIgDwXkjed4ncjDHSl9g41Zp4n1uR94VsKtneCNLuj4+qr4h +uhBnUap1s7uBZ0PW9TA5DXXNvhCrHu2vz3YIMBmlYv/i5qu+bPtPki7eTfm8UWAB +p9pJJHa/9J0gfM4a4w== +-----END ENCRYPTED PRIVATE KEY----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_unencrypted.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_unencrypted.pem new file mode 100644 index 0000000..608e7f4 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/test_unencrypted.pem @@ -0,0 +1,24 @@ +-----BEGIN PRIVATE KEY----- +MIIEvwIBADANBgkqhkiG9w0BAQEFAASCBKkwggSlAgEAAoIBAQDb+HBO3C0URBKvDUgJHbhIlBye +8X/cbNH3lDq3XOOFBz7L4XZKLDIXS+FeQqSAUMo2otmU+Vkj0KorshMjbUXfE1KkTijTMJlaga2M +2xVVt21fRIkJNWbIL0dWFLWyRq7OXdygyFkIiW9b2/LYaePBgET22kbtHSCAEj+BlSf265+1rNxy +AXBGGGccCKzEbcqASBKHOgVp6pLqlQAfuSy6g/OzGzces3zXRrGu1N3pBIzAIwCW429n52ZlYfYR +0nr+REKDnRrPIIDsWASmEHhBezTD+v0qCJRyLz2usFgWY+7agUJE2yHHI2mTu2RAFngBilJXlMCt 
+VwT0xGuQxkbHAgMBAAECggEBAJJdKaVfXWNptCDkLnVaYB9y5eRgfppVkhQxfiw5023Vl1QjrgjG +hYH4zHli0IBMwXA/RZWZoFVzZ3dxoshk0iQPgGKxWvrDEJcnSCo8MGL7jPvh52jILp6uzsGZQBji +bTgFPmOBS7ShdgZiQKD9PD2psrmqHZ1yTwjIm5cGfzQM8Y6tjm0xLBn676ecJNdS1TL10y9vmSUM +Ofdkmeg9Z9TEK95lP2fF/NIcxCo0LF9JcHUvTuYBDnBH0XMZi0w0ZcRReMSdAZ2lLiXgBeCO53el +2NIrtkRx+qOvLua9UfwO2h/0rs66ZeV0YuFCjv067nytyZf2zhU/QbCHRypzfrkCgYEA/facuAJs +6MQKsNvhozoBeDRMkrZPMh8Sb0w50EqzIGz3pdms6UvCiggoMbhxKOwuYWZ689fBPGwm7x0RdwDO +jyUuEbFnQFe+CpdHy6VK7vIQed1SwAcdTMDwCYbkJNglqHEB7qUYYTFLr8okGyWVdthUoh4IAubU +TR3TFbGraDUCgYEA3bwJ/UNA5pHtb/nh4/dNL7/bRMwXyPZPpC5z+gjjgUMgsSRBz8+iPNTB4iSQ +1j9zm+pnXGi35zWZcI4jvIcFusb08eS7xcZDb+7X2r2wenLNmyuTOa1812y233FicU+ah91fa9aD +yUfTjj3GFawbgNNhMyWa3aEMV+c73t6sKosCgYEA35oQZhsMlOx2lT0jrzlVLeauPMZzeCfPbVrp +1DDRAg2vBcFf8pCXmjyQVyaTy3oXY/585tDh/DclGIa5Z9O4CmSr6TwPMqGOW3jS58SC81sBkqqB +Pz2EWJ3POjQgDyiYD3RgRSPrETf78azCmXw/2sGh0pMqbpOZ/MPzpDgoOLkCgYEAsdv4g09kCs75 +Dz34hRzErE2P+8JePdPdlEuyudhRbUlEOvNjWucpMvRSRSyhhUnGWUWP/V7+TRcAanmJjtsbrHOU +3Udlm0HqrCmAubQ4kC/wXsx4Pua7Yi2RDvBrT4rT4LGgreaXNWhI+Srx7kZslUx5Bkbez3I0bXpM +2vvwS/sCgYAducNt1KC4W7jzMWUivvuy5hQQmX/G0JHtu1pfv9cmA8agnc1I/r7xoirftuSG25Pm +r+eP5SKbKb8ZQlp10JeBkNnk8eAG8OkQyBaECYDBadEr1/LK2LmIEjYKzKAjYQ4cX2KMtY271jjX +WrzzXNqBdThFfMHiJE8k9xYmaLDKhQ== +-----END PRIVATE KEY----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1a.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1a.pem new file mode 100644 index 0000000..120859e --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1a.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/TCCAeWgAwIBAgIBATANBgkqhkiG9w0BAQsFADAXMRUwEwYDVQQDEwxuZXR0 +eS50ZXN0LjEwIBcNMTYwMzIyMTIwMDAwWhgPMjExNjAzMjIxMjAwMDBaMBcxFTAT +BgNVBAMTDG5ldHR5LnRlc3QuMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBALWMMCP4QWYWJNt+fNqpwLNM9/LkJlS3NtzJl1chvnyHpxt8OFSD8/cYSl6z +MrbgRYyGNuaL3lsKIL5p2ZnUYzcR61niAhjuMXQgM6ZkptlIsgK6426OTALOSN6l 
+HukItWDDL/om0Mnc8zMuLL/kIpfnzYOKMseUf/1R1MftzlNSSAMPQ7Rn8So/3nUG +j42NywEInoONv89UZ4L+xPpyJwrp0k/u19ckwhFWdudw7l2lVo6s5aBJW9CK8v/f +uUxC75eUYiQ57suKhXCy1Vf8T4vVDiEjKxa3whD1QxlxRxZNYdHJA6tEQIhiWjCC +RiDZZcaAcCCD0evE/0l5V9nnRc8CAwEAAaNSMFAwDwYDVR0TAQH/BAUwAwEB/zAd +BgNVHQ4EFgQUvdqINGhE1D1xZi9Q8NyR+G+5bLwwCwYDVR0PBAQDAgEGMBEGCWCG +SAGG+EIBAQQEAwIABzANBgkqhkiG9w0BAQsFAAOCAQEAR8gHn7MJp6cNwMR6qF3e +jU3tAzCshZVM03NyoHvMpcHsILlR0g/q2KTjcgHzpMMo5PrUGf3oR6ad4JFr5els +kstgbCe4Vv/XzEC6faTEuhLolHGMyzr3Pd6k/wJSsMktF7Ob+YjsyZbgQbyhXqJV +UDQDDncIwxl5rdsRwfiltLUOle4702b4hSCb/1NsDsvsuZQVfeAHHzT1aS8XDSwK +bHOgrDgQGhVR6rBTH9WhcRgFY9rKQ4vVjhoNbwWweQvHmQSO8xYNUhtQnxVOeB7B +NzBM+kx5nw7oIqPCYT0hBINNqeoac9Bidfl4UoTB5YjsQFse4BNuBDPFowAXq5ZB +fg== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1b.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1b.pem new file mode 100644 index 0000000..df75823 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_1b.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/TCCAeWgAwIBAgIBATANBgkqhkiG9w0BAQsFADAXMRUwEwYDVQQDEwxuZXR0 +eS50ZXN0LjEwIBcNMTYwMzIyMTMwMDAwWhgPMjExNjAzMjIxMzAwMDBaMBcxFTAT +BgNVBAMTDG5ldHR5LnRlc3QuMTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBANEUVJUwvWr+qyS28W2EMiDq5frQwforMED7Q8wMiMod+LFxy7y04p12zYWW +35iqC3RaQUQ31DOknAxc7H8vfr0vdl87BIsxc27Ud9h+Do0ggktCaz9Te8/q2Yxo +4TQ8QEFJ8x37zPB05LVqF4djim4GE/yaj0WFMuaRaZLUFvGbHTL7ilC2l6p9SuYx +y40cCucP5nNAXGNhnVYsJCPa/LkyIDLGbkvMMARorkbr7zfaYI2D1YfedwmCOEo1 +CkfBm2qL9/+ig/8VFrTRPlYzUWHsyPvCzfL9F/69NxRCdVk7XCyMEgmf/ztyy/7k +3iZeBhQ0z+RXiNLqBkK3RbaMD7ECAwEAAaNSMFAwDwYDVR0TAQH/BAUwAwEB/zAd +BgNVHQ4EFgQU2vr0yImPHyJv84PXeoSJYbI12uEwCwYDVR0PBAQDAgEGMBEGCWCG +SAGG+EIBAQQEAwIABzANBgkqhkiG9w0BAQsFAAOCAQEAT8mmZi/4dzozqa76tgxM +8ZQgw2C+WgetC81CKqIN3F3tFtu2KEsaAsXEpOVvDR278bLk+r3H3d47Or0xn053 +grk6kdI4C9IPHP7IDaNmAskZ5u9Hrl25P1fxMKG6hXwrk2Je7gD8aNP5IkOSKulo 
+e9b3XSW53WdtHZ+b98LKVMO0lRLQsiG1EmNrL0kJwMXuPxq5s0Ljqz/L19iWGupk +kybRWPcmjHnWIOnnYTwFswI/h79/afvwW5xUP4HgcU/nKrNDWveE7lSYq66zcvpt +rBCESrr3gvETNTJKCPN4u41EOJKGGgoN4U9fBopU4DfzIrcwZ5a4eFLsAFEqAB23 +3A== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_2.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_2.pem new file mode 100644 index 0000000..1d96bfc --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_ca_2.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIC/TCCAeWgAwIBAgIBATANBgkqhkiG9w0BAQUFADAXMRUwEwYDVQQDEwxuZXR0 +eS50ZXN0LjIwIBcNMTYwMzIyMTIwMDAwWhgPMjExNjAzMjIxMjAwMDBaMBcxFTAT +BgNVBAMTDG5ldHR5LnRlc3QuMjCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoC +ggEBAMFTCqEUaMvZPwV8PFHgjpcqlrVfIadY+B2DAsYBQp6HirRv2STYdBue9bHS +Ya8n6J99Qcp0jct692MeOh9BhWoX7wOvi8Tckiu+LAMo7vBHAyUSamJ5qKyNvZNW +3Uwrng1LFwOo6uhY6N6vqyv5CIoDGv+afOlnZOS7484ZYvmYEPejbIPLZpj7IP8Z +c2xmi7TOj781uO2rwUzZgqGSEKsYZNMhVp7yZrsJ3el9T2+2Dma1aYt2w/grOW1t +pzxWvXqjkrjbNjJamAxyiy3qyQ2iDhpFYz8ONqSxqN/QRE72q2N8e+QtajceCFz7 +lpMWRqr7Z6Fdh5zUS7yJrCksJh8CAwEAAaNSMFAwDwYDVR0TAQH/BAUwAwEB/zAd +BgNVHQ4EFgQUnk0/s/aeR4sjFcubiyB/QMWryj8wCwYDVR0PBAQDAgEGMBEGCWCG +SAGG+EIBAQQEAwIABzANBgkqhkiG9w0BAQUFAAOCAQEAWu+RyHkL3lNl8dLlVuDa +A/Vakxf/xbd8+qFEIox2nLvSYZ3OkLzE+vUip/KP0JyQmmmzaz3sx5eZXx33gw9Z +rRYW3I0/c2QPjT5xNYnITUoX5z17FKd71lMr/bz8uhaF9Do+ZV84HgORwtmOCwNg +bOIIVtHO6Ht3V2RmLcQgUV4dK3neNJHa75/Wi3OkJNEZqbzcJX2r69BqupoLte8j +FxqkLBmwUruuCVl5gUFoXsxT3+qgWMxNweLSxEmbqkQ54g8W+06PTHMM/BpbsApv +Ce5mKeC8lHvbV3CxaOYp8w5xJPJbEt/vK6w8jrN47Tz6LaQcimDMdJVVM2H5zubG +RA== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_1.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_1.pem new file mode 100644 index 0000000..e90e205 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_1.pem @@ -0,0 +1,19 @@ 
+-----BEGIN CERTIFICATE----- +MIIDADCCAeigAwIBAgIBAjANBgkqhkiG9w0BAQUFADAXMRUwEwYDVQQDEwxuZXR0 +eS50ZXN0LjEwIBcNMTYwMzIyMTQwMDAwWhgPMjExNjAzMjIxMjAwMDBaMB0xGzAZ +BgNVBAMTEm5ldHR5LnRlc3Quc2VydmVyMTCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBALEytHncQ4iJ0d3nbLcRhe3wWmHzeqr9JKvAMSnIGw2Oc+nUK3FP +xMjdYnY3CSw6BYjK5zJsSF2UWU/jycqc5caQTEXVUQJDjQENScHHsWB+jEevd1i7 +HDmd4Ykw/bqRRRP/I6npukpwJkPuqJ7LAsDfA7FuHJY15ZlhDod5+1zleWv2tdDK +fRCvODi8ehujyNqBqNeRL2YF4wt4yqviFbDNy5+JreyfzPvv1iujSprN5JJCRJ1C +7AABrHYYMQXuEYb10t9dS908Zg2B3sDUa969F1RnU86bCppwK1otQr/RO2hUqqdq +IItb9FHRkeko81OvUiM6nvLzzJLBOInyAzMCAwEAAaNPME0wDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUt5uJ55JYS1Qw2YE5OWOhgv6RxUAwCwYDVR0PBAQDAgXgMBEG +CWCGSAGG+EIBAQQEAwIGQDANBgkqhkiG9w0BAQUFAAOCAQEAUS9HzI9VXyZiaGM6 +RwpUEDgUhbAeI7i7xsdJgqlbvKrTQQy+MKIbxDgsyoz2buqwdX7ekvykTmo0pltS +ASr36gTTW4dwRtiecn/HutrnyuJIckbvMZzld5xIdNERqLHnoiRAopVhe1Fc5UFd +YGEOd+685X2fuc9PMy3G8JjQAOftYOx21JaaNumyVVLcyvciGK0Ptwh/q+6hf4+h +XUHHtIzjnPAM9vkcCmHttVbl3uvare7TfeAoU82NODz0sUaOrIwG8dQbmEdrafHa +JHXti1wv+9ZEEiYKcecvnB3q4e0MT3atf3qedw4B9ZkzoniHEOhFpZgQg6UVA7/f +ga0mCg== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_2.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_2.pem new file mode 100644 index 0000000..898a6e0 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_2.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDADCCAeigAwIBAgIBAjANBgkqhkiG9w0BAQUFADAXMRUwEwYDVQQDEwxuZXR0 +eS50ZXN0LjEwIBcNMTYwMzIyMTQwMDAwWhgPMjExNjAzMjIxMzAwMDBaMB0xGzAZ +BgNVBAMTEm5ldHR5LnRlc3Quc2VydmVyMjCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBAMFf61lsJ5Tpcg6ux5lbZhPEzvLi117aZ42figORTeZg0fMX8a3W +XfRUHldTYD6CugTIpgheZDzsYfwrSLyK5jMyQl332rCoozj2dXdi+HY/JQ5gw8DY +odeQCjCWS3V6HpHTu8tqgNquqjygSLQHSsh+oCVUa/5IWA1N+zZY5/XARnyajwS0 +8Nrgok9kd//jR7hIwHm40YThrTawRetooDK/MuFxJb6FX6UPwZ5/9g2UZMUaKzIa 
+hrGfjAQmRxzkRyACbVqYv+wjBMSCq2SqYZ2Fq3nKeW+dvpPeFfHGOyF/F8kIqtpa +BRawXPHKaUoQYn0PDU+RRfZjfkWTQuz/OMcCAwEAAaNPME0wDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUIUVnoWei3itUqTyFxPczCYsvW4EwCwYDVR0PBAQDAgXgMBEG +CWCGSAGG+EIBAQQEAwIGQDANBgkqhkiG9w0BAQUFAAOCAQEAfK+YlGqBVExATkGF +1ZIcJZtvaiX8rGH8mwqj1wPvKjPRCHvNpPTDLNGhHrFu/0sJlZQDz6hDn0NpJpD8 +TffF+jqmBfvGQW1MEd+jyfp5IXHwR0ZejJepQIeGYuMwyrlZXUKnXvQR2QDkLyx+ +rxmO58XWLNoFUkM4guts3Jb7oAgfCbzYnmBELMVhI8v+SQhuZamvL6S5Wdb18O9i +/N/zH/KDwJmIVtWo7D8UOAMeq69s9zYZLKkqwt8o+DSXth0YPZcNcU8IouDzEJ14 +C35My7Ll7vFehgetXq9D7cMYltx2VPKKYOeT5ZzI580ZvtryT8yCTBj4GoSvAzb6 +RHwFRw== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_3.pem b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_3.pem new file mode 100644 index 0000000..82fa9c9 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/io/netty/handler/ssl/tm_test_eec_3.pem @@ -0,0 +1,19 @@ +-----BEGIN CERTIFICATE----- +MIIDADCCAeigAwIBAgIBAjANBgkqhkiG9w0BAQUFADAXMRUwEwYDVQQDEwxuZXR0 +eS50ZXN0LjIwIBcNMTYwMzIyMTUwMDAwWhgPMjExNjAzMjIxMjAwMDBaMB0xGzAZ +BgNVBAMTEm5ldHR5LnRlc3Quc2VydmVyMzCCASIwDQYJKoZIhvcNAQEBBQADggEP +ADCCAQoCggEBALGdif7vCqwoWg001wmw2X3HnF/9cjZGTY6BWiex3cIarO+aby6L +QObv7lXzzh3r7bJc1JquaawKUek6CIt2mNM+KLAIscScNUOCYg4T1XOD0rh3Qrin +V8Q+hY8InSrUqN4cuX95YOdJoOddw3mDXs376fcNBU10raP4L7k7EsyvnQqIKyNP +ysU9PpDoDLPCMBEGB8cDASv1bkopvvH9u5sic1OHtTLVllsbGNnG6Y9B+1ysG1UO +iJFAX+teigPXKVokJ1+a8dt1CkzDd+iTK1j6PrY+TXc4XOhP/cSnbxwq6JDzkkSb +3pMTJK8ypi7DQNkDbPx73A8qbjRT32gJz4ECAwEAAaNPME0wDAYDVR0TAQH/BAIw +ADAdBgNVHQ4EFgQUCIQ1MChHCo6/1mI6B7S6QPPvEZcwCwYDVR0PBAQDAgXgMBEG +CWCGSAGG+EIBAQQEAwIGQDANBgkqhkiG9w0BAQUFAAOCAQEAZas05SOIsNFmKUYY +kyR1ctlgaA7OZwSzeRPh6vZJ4YaT2lVhNPUeO84tf3LqKE8B827FzWH9mcO/2zeJ +6PTR+QYls/wg8VR881V0Xb5KVNGfwTYpmfhH9+JSzKvKiEtlOoHyvYBMdUon7LJL +ojvlragwXm4QA246345+md5C8PEyQYQf/AoZVZWeLL/BRXZ2ZjsuIT+LzpMIXuTW +AKoH7IlFbKQ5tccQDGCzZb6V1txRDFlKZ/5bvFQZqo12n0MeJy2WPjrmeRm2NC+9 
+imP9oR9GIGNyGKTT1h1qjnaZZwK24cx/82eb63qQKUx80pD4DYW9EDU6/tULz5gs +Kw0iig== +-----END CERTIFICATE----- diff --git a/netty-handler-ssl/src/test/resources/logging.properties b/netty-handler-ssl/src/test/resources/logging.properties new file mode 100644 index 0000000..0d14c34 --- /dev/null +++ b/netty-handler-ssl/src/test/resources/logging.properties @@ -0,0 +1,9 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF +org.junit.jupiter.engine.execution.ParameterResolutionUtils.level=OFF +org.junit.jupiter.engine.extension.MutableExtensionRegistry.level=OFF \ No newline at end of file diff --git a/netty-handler/build.gradle b/netty-handler/build.gradle new file mode 100644 index 0000000..1f7bd0b --- /dev/null +++ b/netty-handler/build.gradle @@ -0,0 +1,6 @@ +dependencies { + api project(':netty-handler-codec') + api project(':netty-resolver') + testImplementation testLibs.mockito.core + testImplementation testLibs.assertj +} diff --git a/netty-handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java b/netty-handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java new file mode 100644 index 0000000..801c501 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/address/DynamicAddressConnectHandler.java @@ -0,0 +1,82 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.address; + +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; + +import java.net.NetworkInterface; +import java.net.SocketAddress; + +/** + * {@link ChannelOutboundHandler} implementation which allows to dynamically replace the used + * {@code remoteAddress} and / or {@code localAddress} when making a connection attempt. + *

+ * This can be useful to for example bind to a specific {@link NetworkInterface} based on + * the {@code remoteAddress}. + */ +public abstract class DynamicAddressConnectHandler extends ChannelOutboundHandlerAdapter { + + @Override + public final void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + final SocketAddress remote; + final SocketAddress local; + try { + remote = remoteAddress(remoteAddress, localAddress); + local = localAddress(remoteAddress, localAddress); + } catch (Exception e) { + promise.setFailure(e); + return; + } + ctx.connect(remote, local, promise).addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + if (future.isSuccess()) { + // We only remove this handler from the pipeline once the connect was successful as otherwise + // the user may try to connect again. + future.channel().pipeline().remove(DynamicAddressConnectHandler.this); + } + } + }); + } + + /** + * Returns the local {@link SocketAddress} to use for + * {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress)} based on the original {@code remoteAddress} + * and {@code localAddress}. + * By default, this method returns the given {@code localAddress}. + */ + protected SocketAddress localAddress( + @SuppressWarnings("unused") SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + return localAddress; + } + + /** + * Returns the remote {@link SocketAddress} to use for + * {@link ChannelHandlerContext#connect(SocketAddress, SocketAddress)} based on the original {@code remoteAddress} + * and {@code localAddress}. + * By default, this method returns the given {@code remoteAddress}. 
+ */ + protected SocketAddress remoteAddress( + SocketAddress remoteAddress, @SuppressWarnings("unused") SocketAddress localAddress) throws Exception { + return remoteAddress; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/address/ResolveAddressHandler.java b/netty-handler/src/main/java/io/netty/handler/address/ResolveAddressHandler.java new file mode 100644 index 0000000..bdabc85 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/address/ResolveAddressHandler.java @@ -0,0 +1,66 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.address; + +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.resolver.AddressResolver; +import io.netty.resolver.AddressResolverGroup; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.internal.ObjectUtil; + +import java.net.SocketAddress; + +/** + * {@link ChannelOutboundHandlerAdapter} which will resolve the {@link SocketAddress} that is passed to + * {@link #connect(ChannelHandlerContext, SocketAddress, SocketAddress, ChannelPromise)} if it is not already resolved + * and the {@link AddressResolver} supports the type of {@link SocketAddress}. 
+ */ +@Sharable +public class ResolveAddressHandler extends ChannelOutboundHandlerAdapter { + + private final AddressResolverGroup resolverGroup; + + public ResolveAddressHandler(AddressResolverGroup resolverGroup) { + this.resolverGroup = ObjectUtil.checkNotNull(resolverGroup, "resolverGroup"); + } + + @Override + public void connect(final ChannelHandlerContext ctx, SocketAddress remoteAddress, + final SocketAddress localAddress, final ChannelPromise promise) { + AddressResolver resolver = resolverGroup.getResolver(ctx.executor()); + if (resolver.isSupported(remoteAddress) && !resolver.isResolved(remoteAddress)) { + resolver.resolve(remoteAddress).addListener(new FutureListener() { + @Override + public void operationComplete(Future future) { + Throwable cause = future.cause(); + if (cause != null) { + promise.setFailure(cause); + } else { + ctx.connect(future.getNow(), localAddress, promise); + } + ctx.pipeline().remove(ResolveAddressHandler.this); + } + }); + } else { + ctx.connect(remoteAddress, localAddress, promise); + ctx.pipeline().remove(this); + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/address/package-info.java b/netty-handler/src/main/java/io/netty/handler/address/package-info.java new file mode 100644 index 0000000..1f83996 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/address/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Package to dynamically replace local / remote {@link java.net.SocketAddress}. + */ +package io.netty.handler.address; diff --git a/netty-handler/src/main/java/io/netty/handler/flow/FlowControlHandler.java b/netty-handler/src/main/java/io/netty/handler/flow/FlowControlHandler.java new file mode 100644 index 0000000..7f1f542 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/flow/FlowControlHandler.java @@ -0,0 +1,256 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. 
+ */ +package io.netty.handler.flow; + +import java.util.ArrayDeque; +import java.util.Queue; + +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.codec.MessageToByteEncoder; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.ObjectPool; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * The {@link FlowControlHandler} ensures that only one message per {@code read()} is sent downstream. + * + * Classes such as {@link ByteToMessageDecoder} or {@link MessageToByteEncoder} are free to emit as + * many events as they like for any given input. A channel's auto reading configuration doesn't usually + * apply in these scenarios. This is causing problems in downstream {@link ChannelHandler}s that would + * like to hold subsequent events while they're processing one event. It's a common problem with the + * {@code HttpObjectDecoder} that will very often fire an {@code HttpRequest} that is immediately followed + * by a {@code LastHttpContent} event. + * + *

{@code
+ * ChannelPipeline pipeline = ...;
+ *
+ * pipeline.addLast(new HttpServerCodec());
+ * pipeline.addLast(new FlowControlHandler());
+ *
+ * pipeline.addLast(new MyExampleHandler());
+ *
+ * class MyExampleHandler extends ChannelInboundHandlerAdapter {
+ *   @Override
+ *   public void channelRead(ChannelHandlerContext ctx, Object msg) {
+ *     if (msg instanceof HttpRequest) {
+ *       ctx.channel().config().setAutoRead(false);
+ *
+ *       // The FlowControlHandler will hold any subsequent events that
+ *       // were emitted by HttpObjectDecoder until auto reading is turned
+ *       // back on or Channel#read() is being called.
+ *     }
+ *   }
+ * }
+ * }
+ * + * @see ChannelConfig#setAutoRead(boolean) + */ +public class FlowControlHandler extends ChannelDuplexHandler { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(FlowControlHandler.class); + + private final boolean releaseMessages; + + private RecyclableArrayDeque queue; + + private ChannelConfig config; + + private boolean shouldConsume; + + public FlowControlHandler() { + this(true); + } + + public FlowControlHandler(boolean releaseMessages) { + this.releaseMessages = releaseMessages; + } + + /** + * Determine if the underlying {@link Queue} is empty. This method exists for + * testing, debugging and inspection purposes and it is not Thread safe! + */ + boolean isQueueEmpty() { + return queue == null || queue.isEmpty(); + } + + /** + * Releases all messages and destroys the {@link Queue}. + */ + private void destroy() { + if (queue != null) { + + if (!queue.isEmpty()) { + logger.trace("Non-empty queue: {}", queue); + + if (releaseMessages) { + Object msg; + while ((msg = queue.poll()) != null) { + ReferenceCountUtil.safeRelease(msg); + } + } + } + + queue.recycle(); + queue = null; + } + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + config = ctx.channel().config(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + super.handlerRemoved(ctx); + if (!isQueueEmpty()) { + dequeue(ctx, queue.size()); + } + destroy(); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + destroy(); + ctx.fireChannelInactive(); + } + + @Override + public void read(ChannelHandlerContext ctx) throws Exception { + if (dequeue(ctx, 1) == 0) { + // It seems no messages were consumed. We need to read() some + // messages from upstream and once one arrives it need to be + // relayed to downstream to keep the flow going. 
+ shouldConsume = true; + ctx.read(); + } else if (config.isAutoRead()) { + ctx.read(); + } + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (queue == null) { + queue = RecyclableArrayDeque.newInstance(); + } + + queue.offer(msg); + + // We just received one message. Do we need to relay it regardless + // of the auto reading configuration? The answer is yes if this + // method was called as a result of a prior read() call. + int minConsume = shouldConsume ? 1 : 0; + shouldConsume = false; + + dequeue(ctx, minConsume); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + if (isQueueEmpty()) { + ctx.fireChannelReadComplete(); + } else { + // Don't relay completion events from upstream as they + // make no sense in this context. See dequeue() where + // a new set of completion events is being produced. + } + } + + /** + * Dequeues one or many (or none) messages depending on the channel's auto + * reading state and returns the number of messages that were consumed from + * the internal queue. + * + * The {@code minConsume} argument is used to force {@code dequeue()} into + * consuming that number of messages regardless of the channel's auto + * reading configuration. + * + * @see #read(ChannelHandlerContext) + * @see #channelRead(ChannelHandlerContext, Object) + */ + private int dequeue(ChannelHandlerContext ctx, int minConsume) { + int consumed = 0; + + // fireChannelRead(...) may call ctx.read() and so this method may reentrance. Because of this we need to + // check if queue was set to null in the meantime and if so break the loop. 
+ while (queue != null && (consumed < minConsume || config.isAutoRead())) { + Object msg = queue.poll(); + if (msg == null) { + break; + } + + ++consumed; + ctx.fireChannelRead(msg); + } + + // We're firing a completion event every time one (or more) + // messages were consumed and the queue ended up being drained + // to an empty state. + if (queue != null && queue.isEmpty()) { + queue.recycle(); + queue = null; + + if (consumed > 0) { + ctx.fireChannelReadComplete(); + } + } + + return consumed; + } + + /** + * A recyclable {@link ArrayDeque}. + */ + private static final class RecyclableArrayDeque extends ArrayDeque { + + private static final long serialVersionUID = 0L; + + /** + * A value of {@code 2} should be a good choice for most scenarios. + */ + private static final int DEFAULT_NUM_ELEMENTS = 2; + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public RecyclableArrayDeque newObject(Handle handle) { + return new RecyclableArrayDeque(DEFAULT_NUM_ELEMENTS, handle); + } + }); + + public static RecyclableArrayDeque newInstance() { + return RECYCLER.get(); + } + + private final Handle handle; + + private RecyclableArrayDeque(int numElements, Handle handle) { + super(numElements); + this.handle = handle; + } + + public void recycle() { + clear(); + handle.recycle(this); + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/flow/package-info.java b/netty-handler/src/main/java/io/netty/handler/flow/package-info.java new file mode 100644 index 0000000..285901f --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/flow/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Package to control the flow of messages. + */ +package io.netty.handler.flow; diff --git a/netty-handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java b/netty-handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java new file mode 100644 index 0000000..8a93571 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/flush/FlushConsolidationHandler.java @@ -0,0 +1,220 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.flush; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelOutboundInvoker; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.ObjectUtil; + +import java.util.concurrent.Future; + +/** + * {@link ChannelDuplexHandler} which consolidates {@link Channel#flush()} / {@link ChannelHandlerContext#flush()} + * operations (which also includes + * {@link Channel#writeAndFlush(Object)} / {@link Channel#writeAndFlush(Object, ChannelPromise)} and + * {@link ChannelOutboundInvoker#writeAndFlush(Object)} / + * {@link ChannelOutboundInvoker#writeAndFlush(Object, ChannelPromise)}). + *

+ * Flush operations are generally speaking expensive as these may trigger a syscall on the transport level. Thus it is + * in most cases (where write latency can be traded with throughput) a good idea to try to minimize flush operations + * as much as possible. + *

+ * If a read loop is currently ongoing, {@link #flush(ChannelHandlerContext)} will not be passed on to the next + * {@link ChannelOutboundHandler} in the {@link ChannelPipeline}, as it will pick up any pending flushes when + * {@link #channelReadComplete(ChannelHandlerContext)} is triggered. + * If no read loop is ongoing, the behavior depends on the {@code consolidateWhenNoReadInProgress} constructor argument: + *

    + *
  • if {@code false}, flushes are passed on to the next handler directly;
  • + *
  • if {@code true}, the invocation of the next handler is submitted as a separate task on the event loop. Under + * high throughput, this gives the opportunity to process other flushes before the task gets executed, thus + * batching multiple flushes into one.
  • + *
+ * If {@code explicitFlushAfterFlushes} is reached the flush will be forwarded as well (whether while in a read loop, or + * while batching outside of a read loop). + *

+ * If the {@link Channel} becomes non-writable it will also try to execute any pending flush operations. + *

+ * The {@link FlushConsolidationHandler} should be put as first {@link ChannelHandler} in the + * {@link ChannelPipeline} to have the best effect. + */ +public class FlushConsolidationHandler extends ChannelDuplexHandler { + private final int explicitFlushAfterFlushes; + private final boolean consolidateWhenNoReadInProgress; + private final Runnable flushTask; + private int flushPendingCount; + private boolean readInProgress; + private ChannelHandlerContext ctx; + private Future nextScheduledFlush; + + /** + * The default number of flushes after which a flush will be forwarded to downstream handlers (whether while in a + * read loop, or while batching outside of a read loop). + */ + public static final int DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES = 256; + + /** + * Create new instance which explicit flush after {@value DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES} pending flush + * operations at the latest. + */ + public FlushConsolidationHandler() { + this(DEFAULT_EXPLICIT_FLUSH_AFTER_FLUSHES, false); + } + + /** + * Create new instance which doesn't consolidate flushes when no read is in progress. + * + * @param explicitFlushAfterFlushes the number of flushes after which an explicit flush will be done. + */ + public FlushConsolidationHandler(int explicitFlushAfterFlushes) { + this(explicitFlushAfterFlushes, false); + } + + /** + * Create new instance. + * + * @param explicitFlushAfterFlushes the number of flushes after which an explicit flush will be done. + * @param consolidateWhenNoReadInProgress whether to consolidate flushes even when no read loop is currently + * ongoing. + */ + public FlushConsolidationHandler(int explicitFlushAfterFlushes, boolean consolidateWhenNoReadInProgress) { + this.explicitFlushAfterFlushes = + ObjectUtil.checkPositive(explicitFlushAfterFlushes, "explicitFlushAfterFlushes"); + this.consolidateWhenNoReadInProgress = consolidateWhenNoReadInProgress; + this.flushTask = consolidateWhenNoReadInProgress ? 
+ new Runnable() { + @Override + public void run() { + if (flushPendingCount > 0 && !readInProgress) { + flushPendingCount = 0; + nextScheduledFlush = null; + ctx.flush(); + } // else we'll flush when the read completes + } + } + : null; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + if (readInProgress) { + // If there is still a read in progress we are sure we will see a channelReadComplete(...) call. Thus + // we only need to flush if we reach the explicitFlushAfterFlushes limit. + if (++flushPendingCount == explicitFlushAfterFlushes) { + flushNow(ctx); + } + } else if (consolidateWhenNoReadInProgress) { + // Flush immediately if we reach the threshold, otherwise schedule + if (++flushPendingCount == explicitFlushAfterFlushes) { + flushNow(ctx); + } else { + scheduleFlush(ctx); + } + } else { + // Always flush directly + flushNow(ctx); + } + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + // This may be the last event in the read loop, so flush now! + resetReadAndFlushIfNeeded(ctx); + ctx.fireChannelReadComplete(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + readInProgress = true; + ctx.fireChannelRead(msg); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + // To ensure we not miss to flush anything, do it now. + resetReadAndFlushIfNeeded(ctx); + ctx.fireExceptionCaught(cause); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + // Try to flush one last time if flushes are pending before disconnect the channel. 
+ resetReadAndFlushIfNeeded(ctx); + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + // Try to flush one last time if flushes are pending before close the channel. + resetReadAndFlushIfNeeded(ctx); + ctx.close(promise); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if (!ctx.channel().isWritable()) { + // The writability of the channel changed to false, so flush all consolidated flushes now to free up memory. + flushIfNeeded(ctx); + } + ctx.fireChannelWritabilityChanged(); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + flushIfNeeded(ctx); + } + + private void resetReadAndFlushIfNeeded(ChannelHandlerContext ctx) { + readInProgress = false; + flushIfNeeded(ctx); + } + + private void flushIfNeeded(ChannelHandlerContext ctx) { + if (flushPendingCount > 0) { + flushNow(ctx); + } + } + + private void flushNow(ChannelHandlerContext ctx) { + cancelScheduledFlush(); + flushPendingCount = 0; + ctx.flush(); + } + + private void scheduleFlush(final ChannelHandlerContext ctx) { + if (nextScheduledFlush == null) { + // Run as soon as possible, but still yield to give a chance for additional writes to enqueue. 
+ nextScheduledFlush = ctx.channel().eventLoop().submit(flushTask); + } + } + + private void cancelScheduledFlush() { + if (nextScheduledFlush != null) { + nextScheduledFlush.cancel(false); + nextScheduledFlush = null; + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/flush/package-info.java b/netty-handler/src/main/java/io/netty/handler/flush/package-info.java new file mode 100644 index 0000000..8dc2070 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/flush/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Package to control flush behavior. + */ +package io.netty.handler.flush; diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/AbstractRemoteAddressFilter.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/AbstractRemoteAddressFilter.java new file mode 100644 index 0000000..aafa92a --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/AbstractRemoteAddressFilter.java @@ -0,0 +1,109 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; + +import java.net.SocketAddress; + +/** + * This class provides the functionality to either accept or reject new {@link Channel}s + * based on their IP address. + *

+ * You should inherit from this class if you would like to implement your own IP-based filter. Basically you have to + * implement {@link #accept(ChannelHandlerContext, SocketAddress)} to decided whether you want to accept or reject + * a connection from the remote address. + *

+ * Furthermore overriding {@link #channelRejected(ChannelHandlerContext, SocketAddress)} gives you the + * flexibility to respond to rejected (denied) connections. If you do not want to send a response, just have it return + * null. Take a look at {@link RuleBasedIpFilter} for details. + */ +public abstract class AbstractRemoteAddressFilter extends ChannelInboundHandlerAdapter { + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + handleNewChannel(ctx); + ctx.fireChannelRegistered(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (!handleNewChannel(ctx)) { + throw new IllegalStateException("cannot determine to accept or reject a channel: " + ctx.channel()); + } else { + ctx.fireChannelActive(); + } + } + + private boolean handleNewChannel(ChannelHandlerContext ctx) throws Exception { + @SuppressWarnings("unchecked") + T remoteAddress = (T) ctx.channel().remoteAddress(); + + // If the remote address is not available yet, defer the decision. + if (remoteAddress == null) { + return false; + } + + // No need to keep this handler in the pipeline anymore because the decision is going to be made now. + // Also, this will prevent the subsequent events from being handled by this handler. + ctx.pipeline().remove(this); + + if (accept(ctx, remoteAddress)) { + channelAccepted(ctx, remoteAddress); + } else { + ChannelFuture rejectedFuture = channelRejected(ctx, remoteAddress); + if (rejectedFuture != null) { + rejectedFuture.addListener(ChannelFutureListener.CLOSE); + } else { + ctx.close(); + } + } + + return true; + } + + /** + * This method is called immediately after a {@link io.netty.channel.Channel} gets registered. + * + * @return Return true if connections from this IP address and port should be accepted. False otherwise. 
+ */ + protected abstract boolean accept(ChannelHandlerContext ctx, T remoteAddress) throws Exception; + + /** + * This method is called if {@code remoteAddress} gets accepted by + * {@link #accept(ChannelHandlerContext, SocketAddress)}. You should override it if you would like to handle + * (e.g. respond to) accepted addresses. + */ + @SuppressWarnings("UnusedParameters") + protected void channelAccepted(ChannelHandlerContext ctx, T remoteAddress) { } + + /** + * This method is called if {@code remoteAddress} gets rejected by + * {@link #accept(ChannelHandlerContext, SocketAddress)}. You should override it if you would like to handle + * (e.g. respond to) rejected addresses. + * + * @return A {@link ChannelFuture} if you perform I/O operations, so that + * the {@link Channel} can be closed once it completes. Null otherwise. + */ + @SuppressWarnings("UnusedParameters") + protected ChannelFuture channelRejected(ChannelHandlerContext ctx, T remoteAddress) { + return null; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRule.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRule.java new file mode 100644 index 0000000..51b2013 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRule.java @@ -0,0 +1,36 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ipfilter; + +import java.net.InetSocketAddress; + +/** + * Implement this interface to create new rules. + */ +public interface IpFilterRule { + /** + * @return This method should return true if remoteAddress is valid according to your criteria. False otherwise. + */ + boolean matches(InetSocketAddress remoteAddress); + + /** + * @return This method should return {@link IpFilterRuleType#ACCEPT} if all + * {@link IpFilterRule#matches(InetSocketAddress)} for which {@link #matches(InetSocketAddress)} + * returns true should the accepted. If you want to exclude all of those IP addresses then + * {@link IpFilterRuleType#REJECT} should be returned. + */ + IpFilterRuleType ruleType(); +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRuleType.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRuleType.java new file mode 100644 index 0000000..64ce9be --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpFilterRuleType.java @@ -0,0 +1,24 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +/** + * Used in {@link IpFilterRule} to decide if a matching IP Address should be allowed or denied to connect. 
+ */ +public enum IpFilterRuleType { + ACCEPT, + REJECT +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java new file mode 100644 index 0000000..15d19df --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilter.java @@ -0,0 +1,226 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.ObjectUtil; + +import java.net.Inet4Address; +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; + +/** + *

+ * <p> This class allows one to filter new {@link Channel}s based on the
+ * {@link IpSubnetFilter}s passed to its constructor. If no rules are provided, all connections
+ * will be accepted since {@code acceptIfNotFound} is {@code true} by default. </p>
+ *
+ * <p> If you would like to explicitly take action on rejected {@link Channel}s, you should override
+ * {@link AbstractRemoteAddressFilter#channelRejected(ChannelHandlerContext, SocketAddress)}. </p>
+ *
+ * <p> Few Points to keep in mind:
+ *     <ol>
+ *     <li> Since {@link IpSubnetFilter} uses a binary search algorithm, it's a good
+ *     idea to insert IP addresses in incremental order. </li>
+ *     <li> Remove any over-lapping CIDR. </li>
+ *     </ol>
+ * </p>
+ * + */ +@Sharable +public class IpSubnetFilter extends AbstractRemoteAddressFilter { + + private final boolean acceptIfNotFound; + private final List ipv4Rules; + private final List ipv6Rules; + private final IpFilterRuleType ipFilterRuleTypeIPv4; + private final IpFilterRuleType ipFilterRuleTypeIPv6; + + /** + *

<p> Create new {@link IpSubnetFilter} Instance with specified {@link IpSubnetFilterRule} as array. </p>
+ *
+ * <p> {@code acceptIfNotFound} is set to {@code true}. </p>
+ * + * @param rules {@link IpSubnetFilterRule} as an array + */ + public IpSubnetFilter(IpSubnetFilterRule... rules) { + this(true, Arrays.asList(ObjectUtil.checkNotNull(rules, "rules"))); + } + + /** + *

<p> Create new {@link IpSubnetFilter} Instance with specified {@link IpSubnetFilterRule} as array
+ * and specify if we'll accept a connection if we don't find it in the rule(s). </p>
+ * + * @param acceptIfNotFound {@code true} if we'll accept connection if not found in rule(s). + * @param rules {@link IpSubnetFilterRule} as an array + */ + public IpSubnetFilter(boolean acceptIfNotFound, IpSubnetFilterRule... rules) { + this(acceptIfNotFound, Arrays.asList(ObjectUtil.checkNotNull(rules, "rules"))); + } + + /** + *

<p> Create new {@link IpSubnetFilter} Instance with specified {@link IpSubnetFilterRule} as {@link List}. </p>
+ *
+ * <p> {@code acceptIfNotFound} is set to {@code true}. </p>
+ * + * @param rules {@link IpSubnetFilterRule} as a {@link List} + */ + public IpSubnetFilter(List rules) { + this(true, rules); + } + + /** + *

<p> Create new {@link IpSubnetFilter} Instance with specified {@link IpSubnetFilterRule} as {@link List}
+ * and specify if we'll accept a connection if we don't find it in the rule(s). </p>
+ * + * @param acceptIfNotFound {@code true} if we'll accept connection if not found in rule(s). + * @param rules {@link IpSubnetFilterRule} as a {@link List} + */ + public IpSubnetFilter(boolean acceptIfNotFound, List rules) { + ObjectUtil.checkNotNull(rules, "rules"); + this.acceptIfNotFound = acceptIfNotFound; + + int numAcceptIPv4 = 0; + int numRejectIPv4 = 0; + int numAcceptIPv6 = 0; + int numRejectIPv6 = 0; + + List unsortedIPv4Rules = new ArrayList(); + List unsortedIPv6Rules = new ArrayList(); + + // Iterate over rules and check for `null` rule. + for (IpSubnetFilterRule ipSubnetFilterRule : rules) { + ObjectUtil.checkNotNull(ipSubnetFilterRule, "rule"); + + if (ipSubnetFilterRule.getFilterRule() instanceof IpSubnetFilterRule.Ip4SubnetFilterRule) { + unsortedIPv4Rules.add(ipSubnetFilterRule); + + if (ipSubnetFilterRule.ruleType() == IpFilterRuleType.ACCEPT) { + numAcceptIPv4++; + } else { + numRejectIPv4++; + } + } else { + unsortedIPv6Rules.add(ipSubnetFilterRule); + + if (ipSubnetFilterRule.ruleType() == IpFilterRuleType.ACCEPT) { + numAcceptIPv6++; + } else { + numRejectIPv6++; + } + } + } + + /* + * If Number of ACCEPT rule is 0 and number of REJECT rules is more than 0, + * then all rules are of "REJECT" type. + * + * In this case, we'll set `ipFilterRuleTypeIPv4` to `IpFilterRuleType.REJECT`. + * + * If Number of ACCEPT rules are more than 0 and number of REJECT rules is 0, + * then all rules are of "ACCEPT" type. + * + * In this case, we'll set `ipFilterRuleTypeIPv4` to `IpFilterRuleType.ACCEPT`. 
+ */ + if (numAcceptIPv4 == 0 && numRejectIPv4 > 0) { + ipFilterRuleTypeIPv4 = IpFilterRuleType.REJECT; + } else if (numAcceptIPv4 > 0 && numRejectIPv4 == 0) { + ipFilterRuleTypeIPv4 = IpFilterRuleType.ACCEPT; + } else { + ipFilterRuleTypeIPv4 = null; + } + + if (numAcceptIPv6 == 0 && numRejectIPv6 > 0) { + ipFilterRuleTypeIPv6 = IpFilterRuleType.REJECT; + } else if (numAcceptIPv6 > 0 && numRejectIPv6 == 0) { + ipFilterRuleTypeIPv6 = IpFilterRuleType.ACCEPT; + } else { + ipFilterRuleTypeIPv6 = null; + } + + this.ipv4Rules = sortAndFilter(unsortedIPv4Rules); + this.ipv6Rules = sortAndFilter(unsortedIPv6Rules); + } + + @Override + protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) { + if (remoteAddress.getAddress() instanceof Inet4Address) { + int indexOf = Collections.binarySearch(ipv4Rules, remoteAddress, IpSubnetFilterRuleComparator.INSTANCE); + if (indexOf >= 0) { + if (ipFilterRuleTypeIPv4 == null) { + return ipv4Rules.get(indexOf).ruleType() == IpFilterRuleType.ACCEPT; + } else { + return ipFilterRuleTypeIPv4 == IpFilterRuleType.ACCEPT; + } + } + } else { + int indexOf = Collections.binarySearch(ipv6Rules, remoteAddress, IpSubnetFilterRuleComparator.INSTANCE); + if (indexOf >= 0) { + if (ipFilterRuleTypeIPv6 == null) { + return ipv6Rules.get(indexOf).ruleType() == IpFilterRuleType.ACCEPT; + } else { + return ipFilterRuleTypeIPv6 == IpFilterRuleType.ACCEPT; + } + } + } + + return acceptIfNotFound; + } + + /** + *
+ * <ol>
+ *     <li> Sort the list </li>
+ *     <li> Remove over-lapping subnet </li>
+ *     <li> Sort the list again </li>
+ * </ol>
+ */ +package io.netty.handler.ipfilter; + +import io.netty.util.NetUtil; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SocketUtils; + +import java.math.BigInteger; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; + +/** + * Use this class to create rules for {@link RuleBasedIpFilter} that group IP addresses into subnets. + * Supports both, IPv4 and IPv6. + */ +public final class IpSubnetFilterRule implements IpFilterRule, Comparable { + + private final IpFilterRule filterRule; + private final String ipAddress; + + public IpSubnetFilterRule(String ipAddress, int cidrPrefix, IpFilterRuleType ruleType) { + try { + this.ipAddress = ipAddress; + filterRule = selectFilterRule(SocketUtils.addressByName(ipAddress), cidrPrefix, ruleType); + } catch (UnknownHostException e) { + throw new IllegalArgumentException("ipAddress", e); + } + } + + public IpSubnetFilterRule(InetAddress ipAddress, int cidrPrefix, IpFilterRuleType ruleType) { + this.ipAddress = ipAddress.getHostAddress(); + filterRule = selectFilterRule(ipAddress, cidrPrefix, ruleType); + } + + private static IpFilterRule selectFilterRule(InetAddress ipAddress, int cidrPrefix, IpFilterRuleType ruleType) { + ObjectUtil.checkNotNull(ipAddress, "ipAddress"); + ObjectUtil.checkNotNull(ruleType, "ruleType"); + + if (ipAddress instanceof Inet4Address) { + return new Ip4SubnetFilterRule((Inet4Address) ipAddress, cidrPrefix, ruleType); + } else if (ipAddress instanceof Inet6Address) { + return new Ip6SubnetFilterRule((Inet6Address) ipAddress, cidrPrefix, ruleType); + } else { + throw new IllegalArgumentException("Only IPv4 and IPv6 addresses are supported"); + } + } + + @Override + public boolean matches(InetSocketAddress remoteAddress) { + return filterRule.matches(remoteAddress); + } + + @Override + public IpFilterRuleType ruleType() { + return filterRule.ruleType(); + } + + /** 
+ * Get IP Address of this rule + */ + String getIpAddress() { + return ipAddress; + } + + /** + * {@link Ip4SubnetFilterRule} or {@link Ip6SubnetFilterRule} + */ + IpFilterRule getFilterRule() { + return filterRule; + } + + @Override + public int compareTo(IpSubnetFilterRule ipSubnetFilterRule) { + if (filterRule instanceof Ip4SubnetFilterRule) { + return compareInt(((Ip4SubnetFilterRule) filterRule).networkAddress, + ((Ip4SubnetFilterRule) ipSubnetFilterRule.filterRule).networkAddress); + } else { + return ((Ip6SubnetFilterRule) filterRule).networkAddress + .compareTo(((Ip6SubnetFilterRule) ipSubnetFilterRule.filterRule).networkAddress); + } + } + + /** + * It'll compare IP address with {@link Ip4SubnetFilterRule#networkAddress} or + * {@link Ip6SubnetFilterRule#networkAddress}. + * + * @param inetSocketAddress {@link InetSocketAddress} to match + * @return 0 if IP Address match else difference index. + */ + int compareTo(InetSocketAddress inetSocketAddress) { + if (filterRule instanceof Ip4SubnetFilterRule) { + Ip4SubnetFilterRule ip4SubnetFilterRule = (Ip4SubnetFilterRule) filterRule; + return compareInt(ip4SubnetFilterRule.networkAddress, NetUtil.ipv4AddressToInt((Inet4Address) + inetSocketAddress.getAddress()) & ip4SubnetFilterRule.subnetMask); + } else { + Ip6SubnetFilterRule ip6SubnetFilterRule = (Ip6SubnetFilterRule) filterRule; + return ip6SubnetFilterRule.networkAddress + .compareTo(Ip6SubnetFilterRule.ipToInt((Inet6Address) inetSocketAddress.getAddress()) + .and(ip6SubnetFilterRule.networkAddress)); + } + } + + /** + * Equivalent to {@link Integer#compare(int, int)} + */ + private static int compareInt(int x, int y) { + return (x < y) ? -1 : ((x == y) ? 
0 : 1); + } + + static final class Ip4SubnetFilterRule implements IpFilterRule { + + private final int networkAddress; + private final int subnetMask; + private final IpFilterRuleType ruleType; + + private Ip4SubnetFilterRule(Inet4Address ipAddress, int cidrPrefix, IpFilterRuleType ruleType) { + if (cidrPrefix < 0 || cidrPrefix > 32) { + throw new IllegalArgumentException(String.format("IPv4 requires the subnet prefix to be in range of " + + "[0,32]. The prefix was: %d", cidrPrefix)); + } + + subnetMask = prefixToSubnetMask(cidrPrefix); + networkAddress = NetUtil.ipv4AddressToInt(ipAddress) & subnetMask; + this.ruleType = ruleType; + } + + @Override + public boolean matches(InetSocketAddress remoteAddress) { + final InetAddress inetAddress = remoteAddress.getAddress(); + if (inetAddress instanceof Inet4Address) { + int ipAddress = NetUtil.ipv4AddressToInt((Inet4Address) inetAddress); + return (ipAddress & subnetMask) == networkAddress; + } + return false; + } + + @Override + public IpFilterRuleType ruleType() { + return ruleType; + } + + private static int prefixToSubnetMask(int cidrPrefix) { + /* + * Perform the shift on a long and downcast it to int afterwards. + * This is necessary to handle a cidrPrefix of zero correctly. + * The left shift operator on an int only uses the five least + * significant bits of the right-hand operand. Thus -1 << 32 evaluates + * to -1 instead of 0. The left shift operator applied on a long + * uses the six least significant bits. 
+ * + * Also see https://github.com/netty/netty/issues/2767 + */ + return (int) (-1L << 32 - cidrPrefix); + } + } + + static final class Ip6SubnetFilterRule implements IpFilterRule { + + private static final BigInteger MINUS_ONE = BigInteger.valueOf(-1); + + private final BigInteger networkAddress; + private final BigInteger subnetMask; + private final IpFilterRuleType ruleType; + + private Ip6SubnetFilterRule(Inet6Address ipAddress, int cidrPrefix, IpFilterRuleType ruleType) { + if (cidrPrefix < 0 || cidrPrefix > 128) { + throw new IllegalArgumentException(String.format("IPv6 requires the subnet prefix to be in range of " + + "[0,128]. The prefix was: %d", cidrPrefix)); + } + + subnetMask = prefixToSubnetMask(cidrPrefix); + networkAddress = ipToInt(ipAddress).and(subnetMask); + this.ruleType = ruleType; + } + + @Override + public boolean matches(InetSocketAddress remoteAddress) { + final InetAddress inetAddress = remoteAddress.getAddress(); + if (inetAddress instanceof Inet6Address) { + BigInteger ipAddress = ipToInt((Inet6Address) inetAddress); + return ipAddress.and(subnetMask).equals(subnetMask) || ipAddress.and(subnetMask).equals(networkAddress); + } + return false; + } + + @Override + public IpFilterRuleType ruleType() { + return ruleType; + } + + private static BigInteger ipToInt(Inet6Address ipAddress) { + byte[] octets = ipAddress.getAddress(); + assert octets.length == 16; + + return new BigInteger(octets); + } + + private static BigInteger prefixToSubnetMask(int cidrPrefix) { + return MINUS_ONE.shiftLeft(128 - cidrPrefix); + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRuleComparator.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRuleComparator.java new file mode 100644 index 0000000..35bf2bf --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/IpSubnetFilterRuleComparator.java @@ -0,0 +1,36 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project 
licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +import java.net.InetSocketAddress; +import java.util.Comparator; + +/** + * This comparator is only used for searching. + */ +final class IpSubnetFilterRuleComparator implements Comparator { + + static final IpSubnetFilterRuleComparator INSTANCE = new IpSubnetFilterRuleComparator(); + + private IpSubnetFilterRuleComparator() { + // Prevent outside initialization + } + + @Override + public int compare(Object o1, Object o2) { + return ((IpSubnetFilterRule) o1).compareTo((InetSocketAddress) o2); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/RuleBasedIpFilter.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/RuleBasedIpFilter.java new file mode 100644 index 0000000..92eac83 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/RuleBasedIpFilter.java @@ -0,0 +1,92 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.ObjectUtil; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; + +/** + *

+ * <p> This class allows one to filter new {@link Channel}s based on the
+ * {@link IpFilterRule}s passed to its constructor. If no rules are provided, all connections
+ * will be accepted. </p>
+ *
+ * <p> If you would like to explicitly take action on rejected {@link Channel}s, you should override
+ * {@link AbstractRemoteAddressFilter#channelRejected(ChannelHandlerContext, SocketAddress)}. </p>
+ *
+ * <p> Consider using {@link IpSubnetFilter} for better performance while not as
+ * general purpose as this filter. </p>
+ */ +@Sharable +public class RuleBasedIpFilter extends AbstractRemoteAddressFilter { + + private final boolean acceptIfNotFound; + private final List rules; + + /** + *

<p> Create new Instance of {@link RuleBasedIpFilter} and filter incoming connections
+ * based on their IP address and {@code rules} applied. </p>
+ *
+ * <p> {@code acceptIfNotFound} is set to {@code true}. </p>
+ * + * @param rules An array of {@link IpFilterRule} containing all rules. + */ + public RuleBasedIpFilter(IpFilterRule... rules) { + this(true, rules); + } + + /** + * Create new Instance of {@link RuleBasedIpFilter} and filter incoming connections + * based on their IP address and {@code rules} applied. + * + * @param acceptIfNotFound If {@code true} then accept connection from IP Address if it + * doesn't match any rule. + * @param rules An array of {@link IpFilterRule} containing all rules. + */ + public RuleBasedIpFilter(boolean acceptIfNotFound, IpFilterRule... rules) { + ObjectUtil.checkNotNull(rules, "rules"); + + this.acceptIfNotFound = acceptIfNotFound; + this.rules = new ArrayList(rules.length); + + for (IpFilterRule rule : rules) { + if (rule != null) { + this.rules.add(rule); + } + } + } + + @Override + protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) throws Exception { + for (IpFilterRule rule : rules) { + if (rule.matches(remoteAddress)) { + return rule.ruleType() == IpFilterRuleType.ACCEPT; + } + } + + return acceptIfNotFound; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java new file mode 100644 index 0000000..a3a5d1e --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/UniqueIpFilter.java @@ -0,0 +1,53 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.ipfilter; + +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.ConcurrentSet; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.Set; + +/** + * This class allows one to ensure that at all times for every IP address there is at most one + * {@link Channel} connected to the server. + */ +@ChannelHandler.Sharable +public class UniqueIpFilter extends AbstractRemoteAddressFilter { + + private final Set connected = new ConcurrentSet(); + + @Override + protected boolean accept(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) throws Exception { + final InetAddress remoteIp = remoteAddress.getAddress(); + if (!connected.add(remoteIp)) { + return false; + } else { + ctx.channel().closeFuture().addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + connected.remove(remoteIp); + } + }); + return true; + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/ipfilter/package-info.java b/netty-handler/src/main/java/io/netty/handler/ipfilter/package-info.java new file mode 100644 index 0000000..9c1fde3 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/ipfilter/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Package to filter IP addresses (allow/deny). + */ +package io.netty.handler.ipfilter; diff --git a/netty-handler/src/main/java/io/netty/handler/logging/ByteBufFormat.java b/netty-handler/src/main/java/io/netty/handler/logging/ByteBufFormat.java new file mode 100644 index 0000000..11950c8 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/logging/ByteBufFormat.java @@ -0,0 +1,36 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.logging; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.buffer.ByteBufUtil; + +/** + * Used to control the format and verbosity of logging for {@link ByteBuf}s and {@link ByteBufHolder}s. + * + * @see LoggingHandler + */ +public enum ByteBufFormat { + /** + * {@link ByteBuf}s will be logged in a simple format, with no hex dump included. 
+ */ + SIMPLE, + /** + * {@link ByteBuf}s will be logged using {@link ByteBufUtil#appendPrettyHexDump(StringBuilder, ByteBuf)}. + */ + HEX_DUMP +} diff --git a/netty-handler/src/main/java/io/netty/handler/logging/LogLevel.java b/netty-handler/src/main/java/io/netty/handler/logging/LogLevel.java new file mode 100644 index 0000000..7e59e60 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/logging/LogLevel.java @@ -0,0 +1,46 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.logging; + +import io.netty.util.internal.logging.InternalLogLevel; + +/** + * Maps the regular {@link LogLevel}s with the {@link InternalLogLevel} ones. + */ +public enum LogLevel { + TRACE(InternalLogLevel.TRACE), + DEBUG(InternalLogLevel.DEBUG), + INFO(InternalLogLevel.INFO), + WARN(InternalLogLevel.WARN), + ERROR(InternalLogLevel.ERROR); + + private final InternalLogLevel internalLevel; + + LogLevel(InternalLogLevel internalLevel) { + this.internalLevel = internalLevel; + } + + /** + * For internal use only. + * + *

Converts the specified {@link LogLevel} to its {@link InternalLogLevel} variant. + * + * @return the converted level. + */ + public InternalLogLevel toInternalLevel() { + return internalLevel; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/logging/LoggingHandler.java b/netty-handler/src/main/java/io/netty/handler/logging/LoggingHandler.java new file mode 100644 index 0000000..6b497d1 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/logging/LoggingHandler.java @@ -0,0 +1,427 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.logging; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandler; +import io.netty.channel.ChannelPromise; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogLevel; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.net.SocketAddress; + +import static io.netty.buffer.ByteBufUtil.appendPrettyHexDump; +import static io.netty.util.internal.StringUtil.NEWLINE; + +/** + * A {@link ChannelHandler} that logs all events using a logging framework. + * By default, all events are logged at DEBUG level and full hex dumps are recorded for ByteBufs. + */ +@Sharable +@SuppressWarnings({ "StringConcatenationInsideStringBufferAppend", "StringBufferReplaceableByString" }) +public class LoggingHandler extends ChannelDuplexHandler { + + private static final LogLevel DEFAULT_LEVEL = LogLevel.DEBUG; + + protected final InternalLogger logger; + protected final InternalLogLevel internalLevel; + + private final LogLevel level; + private final ByteBufFormat byteBufFormat; + + /** + * Creates a new instance whose logger name is the fully qualified class + * name of the instance with hex dump enabled. + */ + public LoggingHandler() { + this(DEFAULT_LEVEL); + } + /** + * Creates a new instance whose logger name is the fully qualified class + * name of the instance. + * + * @param format Format of ByteBuf dumping + */ + public LoggingHandler(ByteBufFormat format) { + this(DEFAULT_LEVEL, format); + } + + /** + * Creates a new instance whose logger name is the fully qualified class + * name of the instance. 
+ * + * @param level the log level + */ + public LoggingHandler(LogLevel level) { + this(level, ByteBufFormat.HEX_DUMP); + } + + /** + * Creates a new instance whose logger name is the fully qualified class + * name of the instance. + * + * @param level the log level + * @param byteBufFormat the ByteBuf format + */ + public LoggingHandler(LogLevel level, ByteBufFormat byteBufFormat) { + this.level = ObjectUtil.checkNotNull(level, "level"); + this.byteBufFormat = ObjectUtil.checkNotNull(byteBufFormat, "byteBufFormat"); + logger = InternalLoggerFactory.getInstance(getClass()); + internalLevel = level.toInternalLevel(); + } + + /** + * Creates a new instance with the specified logger name and with hex dump + * enabled. + * + * @param clazz the class type to generate the logger for + */ + public LoggingHandler(Class clazz) { + this(clazz, DEFAULT_LEVEL); + } + + /** + * Creates a new instance with the specified logger name. + * + * @param clazz the class type to generate the logger for + * @param level the log level + */ + public LoggingHandler(Class clazz, LogLevel level) { + this(clazz, level, ByteBufFormat.HEX_DUMP); + } + + /** + * Creates a new instance with the specified logger name. + * + * @param clazz the class type to generate the logger for + * @param level the log level + * @param byteBufFormat the ByteBuf format + */ + public LoggingHandler(Class clazz, LogLevel level, ByteBufFormat byteBufFormat) { + ObjectUtil.checkNotNull(clazz, "clazz"); + this.level = ObjectUtil.checkNotNull(level, "level"); + this.byteBufFormat = ObjectUtil.checkNotNull(byteBufFormat, "byteBufFormat"); + logger = InternalLoggerFactory.getInstance(clazz); + internalLevel = level.toInternalLevel(); + } + + /** + * Creates a new instance with the specified logger name using the default log level. 
+ * + * @param name the name of the class to use for the logger + */ + public LoggingHandler(String name) { + this(name, DEFAULT_LEVEL); + } + + /** + * Creates a new instance with the specified logger name. + * + * @param name the name of the class to use for the logger + * @param level the log level + */ + public LoggingHandler(String name, LogLevel level) { + this(name, level, ByteBufFormat.HEX_DUMP); + } + + /** + * Creates a new instance with the specified logger name. + * + * @param name the name of the class to use for the logger + * @param level the log level + * @param byteBufFormat the ByteBuf format + */ + public LoggingHandler(String name, LogLevel level, ByteBufFormat byteBufFormat) { + ObjectUtil.checkNotNull(name, "name"); + + this.level = ObjectUtil.checkNotNull(level, "level"); + this.byteBufFormat = ObjectUtil.checkNotNull(byteBufFormat, "byteBufFormat"); + logger = InternalLoggerFactory.getInstance(name); + internalLevel = level.toInternalLevel(); + } + + /** + * Returns the {@link LogLevel} that this handler uses to log + */ + public LogLevel level() { + return level; + } + + /** + * Returns the {@link ByteBufFormat} that this handler uses to log + */ + public ByteBufFormat byteBufFormat() { + return byteBufFormat; + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "REGISTERED")); + } + ctx.fireChannelRegistered(); + } + + @Override + public void channelUnregistered(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "UNREGISTERED")); + } + ctx.fireChannelUnregistered(); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "ACTIVE")); + } + ctx.fireChannelActive(); + } + + @Override + public void 
channelInactive(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "INACTIVE")); + } + ctx.fireChannelInactive(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "EXCEPTION", cause), cause); + } + ctx.fireExceptionCaught(cause); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "USER_EVENT", evt)); + } + ctx.fireUserEventTriggered(evt); + } + + @Override + public void bind(ChannelHandlerContext ctx, SocketAddress localAddress, ChannelPromise promise) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "BIND", localAddress)); + } + ctx.bind(localAddress, promise); + } + + @Override + public void connect( + ChannelHandlerContext ctx, + SocketAddress remoteAddress, SocketAddress localAddress, ChannelPromise promise) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "CONNECT", remoteAddress, localAddress)); + } + ctx.connect(remoteAddress, localAddress, promise); + } + + @Override + public void disconnect(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "DISCONNECT")); + } + ctx.disconnect(promise); + } + + @Override + public void close(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "CLOSE")); + } + ctx.close(promise); + } + + @Override + public void deregister(ChannelHandlerContext ctx, ChannelPromise promise) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, 
"DEREGISTER")); + } + ctx.deregister(promise); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "READ COMPLETE")); + } + ctx.fireChannelReadComplete(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "READ", msg)); + } + ctx.fireChannelRead(msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "WRITE", msg)); + } + ctx.write(msg, promise); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "WRITABILITY CHANGED")); + } + ctx.fireChannelWritabilityChanged(); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + if (logger.isEnabled(internalLevel)) { + logger.log(internalLevel, format(ctx, "FLUSH")); + } + ctx.flush(); + } + + /** + * Formats an event and returns the formatted message. + * + * @param eventName the name of the event + */ + protected String format(ChannelHandlerContext ctx, String eventName) { + String chStr = ctx.channel().toString(); + return new StringBuilder(chStr.length() + 1 + eventName.length()) + .append(chStr) + .append(' ') + .append(eventName) + .toString(); + } + + /** + * Formats an event and returns the formatted message. 
+ * + * @param eventName the name of the event + * @param arg the argument of the event + */ + protected String format(ChannelHandlerContext ctx, String eventName, Object arg) { + if (arg instanceof ByteBuf) { + return formatByteBuf(ctx, eventName, (ByteBuf) arg); + } else if (arg instanceof ByteBufHolder) { + return formatByteBufHolder(ctx, eventName, (ByteBufHolder) arg); + } else { + return formatSimple(ctx, eventName, arg); + } + } + + /** + * Formats an event and returns the formatted message. This method is currently only used for formatting + * {@link ChannelOutboundHandler#connect(ChannelHandlerContext, SocketAddress, SocketAddress, ChannelPromise)}. + * + * @param eventName the name of the event + * @param firstArg the first argument of the event + * @param secondArg the second argument of the event + */ + protected String format(ChannelHandlerContext ctx, String eventName, Object firstArg, Object secondArg) { + if (secondArg == null) { + return formatSimple(ctx, eventName, firstArg); + } + + String chStr = ctx.channel().toString(); + String arg1Str = String.valueOf(firstArg); + String arg2Str = secondArg.toString(); + StringBuilder buf = new StringBuilder( + chStr.length() + 1 + eventName.length() + 2 + arg1Str.length() + 2 + arg2Str.length()); + buf.append(chStr).append(' ').append(eventName).append(": ").append(arg1Str).append(", ").append(arg2Str); + return buf.toString(); + } + + /** + * Generates the default log message of the specified event whose argument is a {@link ByteBuf}. 
+ */ + private String formatByteBuf(ChannelHandlerContext ctx, String eventName, ByteBuf msg) { + String chStr = ctx.channel().toString(); + int length = msg.readableBytes(); + if (length == 0) { + StringBuilder buf = new StringBuilder(chStr.length() + 1 + eventName.length() + 4); + buf.append(chStr).append(' ').append(eventName).append(": 0B"); + return buf.toString(); + } else { + int outputLength = chStr.length() + 1 + eventName.length() + 2 + 10 + 1; + if (byteBufFormat == ByteBufFormat.HEX_DUMP) { + int rows = length / 16 + (length % 15 == 0? 0 : 1) + 4; + int hexDumpLength = 2 + rows * 80; + outputLength += hexDumpLength; + } + StringBuilder buf = new StringBuilder(outputLength); + buf.append(chStr).append(' ').append(eventName).append(": ").append(length).append('B'); + if (byteBufFormat == ByteBufFormat.HEX_DUMP) { + buf.append(NEWLINE); + appendPrettyHexDump(buf, msg); + } + + return buf.toString(); + } + } + + /** + * Generates the default log message of the specified event whose argument is a {@link ByteBufHolder}. + */ + private String formatByteBufHolder(ChannelHandlerContext ctx, String eventName, ByteBufHolder msg) { + String chStr = ctx.channel().toString(); + String msgStr = msg.toString(); + ByteBuf content = msg.content(); + int length = content.readableBytes(); + if (length == 0) { + StringBuilder buf = new StringBuilder(chStr.length() + 1 + eventName.length() + 2 + msgStr.length() + 4); + buf.append(chStr).append(' ').append(eventName).append(", ").append(msgStr).append(", 0B"); + return buf.toString(); + } else { + int outputLength = chStr.length() + 1 + eventName.length() + 2 + msgStr.length() + 2 + 10 + 1; + if (byteBufFormat == ByteBufFormat.HEX_DUMP) { + int rows = length / 16 + (length % 15 == 0? 
0 : 1) + 4; + int hexDumpLength = 2 + rows * 80; + outputLength += hexDumpLength; + } + StringBuilder buf = new StringBuilder(outputLength); + buf.append(chStr).append(' ').append(eventName).append(": ") + .append(msgStr).append(", ").append(length).append('B'); + if (byteBufFormat == ByteBufFormat.HEX_DUMP) { + buf.append(NEWLINE); + appendPrettyHexDump(buf, content); + } + + return buf.toString(); + } + } + + /** + * Generates the default log message of the specified event whose argument is an arbitrary object. + */ + private static String formatSimple(ChannelHandlerContext ctx, String eventName, Object msg) { + String chStr = ctx.channel().toString(); + String msgStr = String.valueOf(msg); + StringBuilder buf = new StringBuilder(chStr.length() + 1 + eventName.length() + 2 + msgStr.length()); + return buf.append(chStr).append(' ').append(eventName).append(": ").append(msgStr).toString(); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/logging/package-info.java b/netty-handler/src/main/java/io/netty/handler/logging/package-info.java new file mode 100644 index 0000000..04db33c --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/logging/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Logs the I/O events for debugging purpose. 
 */
package io.netty.handler.logging;
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.pcap;

import io.netty.buffer.ByteBuf;

/**
 * Serialises a minimal Ethernet II frame (destination MAC, source MAC, EtherType,
 * payload) in front of an L3 payload, using fixed dummy MAC addresses from the
 * documentation range.
 */
final class EthernetPacket {

    /**
     * MAC Address: 00:00:5E:00:53:00
     */
    private static final byte[] DUMMY_SOURCE_MAC_ADDRESS = new byte[]{0, 0, 94, 0, 83, 0};

    /**
     * MAC Address: 00:00:5E:00:53:FF
     */
    private static final byte[] DUMMY_DESTINATION_MAC_ADDRESS = new byte[]{0, 0, 94, 0, 83, -1};

    /**
     * EtherType for IPv4
     */
    private static final int V4 = 0x0800;

    /**
     * EtherType for IPv6
     */
    private static final int V6 = 0x86dd;

    private EthernetPacket() {
        // Prevent outside initialization
    }

    /**
     * Write IPv4 Ethernet Packet. It uses a dummy MAC address for both source and destination.
     *
     * @param byteBuf ByteBuf where Ethernet Packet data will be set
     * @param payload Payload of IPv4
     */
    static void writeIPv4(ByteBuf byteBuf, ByteBuf payload) {
        writePacket(byteBuf, payload, DUMMY_SOURCE_MAC_ADDRESS, DUMMY_DESTINATION_MAC_ADDRESS, V4);
    }

    /**
     * Write IPv6 Ethernet Packet. It uses a dummy MAC address for both source and destination.
     *
     * @param byteBuf ByteBuf where Ethernet Packet data will be set
     * @param payload Payload of IPv6
     */
    static void writeIPv6(ByteBuf byteBuf, ByteBuf payload) {
        writePacket(byteBuf, payload, DUMMY_SOURCE_MAC_ADDRESS, DUMMY_DESTINATION_MAC_ADDRESS, V6);
    }

    /**
     * Write an Ethernet frame of the given type. (Original javadoc said "IPv6",
     * a copy-paste slip — this helper serves both IPv4 and IPv6.)
     *
     * @param byteBuf    ByteBuf where Ethernet Packet data will be set
     * @param payload    Payload of L3
     * @param srcAddress Source MAC Address
     * @param dstAddress Destination MAC Address
     * @param type       Type of Frame (EtherType)
     */
    private static void writePacket(ByteBuf byteBuf, ByteBuf payload, byte[] srcAddress, byte[] dstAddress, int type) {
        byteBuf.writeBytes(dstAddress); // Destination MAC Address
        byteBuf.writeBytes(srcAddress); // Source MAC Address
        byteBuf.writeShort(type);       // Frame Type (IPv4 or IPv6)
        byteBuf.writeBytes(payload);    // Payload of L3
    }
}
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.handler.pcap; + +import io.netty.buffer.ByteBuf; + +final class IPPacket { + + private static final byte MAX_TTL = (byte) 255; + private static final short V4_HEADER_SIZE = 20; + private static final byte TCP = 6 & 0xff; + private static final byte UDP = 17 & 0xff; + + /** + * Version + Traffic class + Flow label + */ + private static final int IPV6_VERSION_TRAFFIC_FLOW = 60000000; + + private IPPacket() { + // Prevent outside initialization + } + + /** + * Write IPv4 Packet for UDP Packet + * + * @param byteBuf ByteBuf where IP Packet data will be set + * @param payload Payload of UDP + * @param srcAddress Source IPv4 Address + * @param dstAddress Destination IPv4 Address + */ + static void writeUDPv4(ByteBuf byteBuf, ByteBuf payload, int srcAddress, int dstAddress) { + writePacketv4(byteBuf, payload, UDP, srcAddress, dstAddress); + } + + /** + * Write IPv6 Packet for UDP Packet + * + * @param byteBuf ByteBuf where IP Packet data will be set + * @param payload Payload of UDP + * @param srcAddress Source IPv6 Address + * @param dstAddress Destination IPv6 Address + */ + static void writeUDPv6(ByteBuf byteBuf, ByteBuf payload, byte[] srcAddress, byte[] dstAddress) { + writePacketv6(byteBuf, payload, UDP, srcAddress, dstAddress); + } + + /** + * Write IPv4 Packet for TCP Packet + * + * @param byteBuf ByteBuf where IP Packet data will be set + * @param payload Payload of TCP + * @param srcAddress Source IPv4 Address + * @param dstAddress Destination IPv4 Address + */ + static void writeTCPv4(ByteBuf byteBuf, ByteBuf payload, int srcAddress, int dstAddress) { + writePacketv4(byteBuf, payload, TCP, srcAddress, dstAddress); + } + + /** + * Write IPv6 Packet for TCP Packet + * + * @param byteBuf ByteBuf where IP Packet data will be set + * @param payload Payload of TCP + * @param srcAddress Source IPv6 Address + * @param dstAddress Destination IPv6 Address + */ + static void writeTCPv6(ByteBuf byteBuf, ByteBuf payload, byte[] srcAddress, byte[] 
dstAddress) { + writePacketv6(byteBuf, payload, TCP, srcAddress, dstAddress); + } + + private static void writePacketv4(ByteBuf byteBuf, ByteBuf payload, int protocol, int srcAddress, + int dstAddress) { + + byteBuf.writeByte(0x45); // Version + IHL + byteBuf.writeByte(0x00); // DSCP + byteBuf.writeShort(V4_HEADER_SIZE + payload.readableBytes()); // Length + byteBuf.writeShort(0x0000); // Identification + byteBuf.writeShort(0x0000); // Fragment + byteBuf.writeByte(MAX_TTL); // TTL + byteBuf.writeByte(protocol); // Protocol + byteBuf.writeShort(0); // Checksum + byteBuf.writeInt(srcAddress); // Source IPv4 Address + byteBuf.writeInt(dstAddress); // Destination IPv4 Address + byteBuf.writeBytes(payload); // Payload of L4 + } + + private static void writePacketv6(ByteBuf byteBuf, ByteBuf payload, int protocol, byte[] srcAddress, + byte[] dstAddress) { + + byteBuf.writeInt(IPV6_VERSION_TRAFFIC_FLOW); // Version + Traffic class + Flow label + byteBuf.writeShort(payload.readableBytes()); // Payload length + byteBuf.writeByte(protocol & 0xff); // Next header + byteBuf.writeByte(MAX_TTL); // Hop limit + byteBuf.writeBytes(srcAddress); // Source IPv6 Address + byteBuf.writeBytes(dstAddress); // Destination IPv6 Address + byteBuf.writeBytes(payload); // Payload of L4 + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/pcap/PcapHeaders.java b/netty-handler/src/main/java/io/netty/handler/pcap/PcapHeaders.java new file mode 100644 index 0000000..fda1afb --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/pcap/PcapHeaders.java @@ -0,0 +1,69 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.pcap; + +import io.netty.buffer.ByteBuf; + +import java.io.IOException; +import java.io.OutputStream; + +final class PcapHeaders { + + /** + * Pcap Global Header built from: + *

    + *
  1. magic_number
  2. + *
  3. version_major
  4. + *
  5. version_minor
  6. + *
  7. thiszone
  8. + *
  9. sigfigs
  10. + *
  11. snaplen
  12. + *
  13. network
  14. + *
+ */ + private static final byte[] GLOBAL_HEADER = {-95, -78, -61, -44, 0, 2, 0, 4, 0, 0, + 0, 0, 0, 0, 0, 0, 0, 0, -1, -1, 0, 0, 0, 1}; + + private PcapHeaders() { + // Prevent outside initialization + } + + /** + * Writes the Pcap Global Header to the provided {@code OutputStream} + * + * @param outputStream OutputStream where Pcap data will be written. + * @throws IOException if there is an error writing to the {@code OutputStream} + */ + static void writeGlobalHeader(OutputStream outputStream) throws IOException { + outputStream.write(GLOBAL_HEADER); + } + + /** + * Write Pcap Packet Header + * + * @param byteBuf ByteBuf where we'll write header data + * @param ts_sec timestamp seconds + * @param ts_usec timestamp microseconds + * @param incl_len number of octets of packet saved in file + * @param orig_len actual length of packet + */ + static void writePacketHeader(ByteBuf byteBuf, int ts_sec, int ts_usec, int incl_len, int orig_len) { + byteBuf.writeInt(ts_sec); + byteBuf.writeInt(ts_usec); + byteBuf.writeInt(incl_len); + byteBuf.writeInt(orig_len); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java b/netty-handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java new file mode 100644 index 0000000..bbd6871 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/pcap/PcapWriteHandler.java @@ -0,0 +1,861 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.pcap; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.ServerChannel; +import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramPacket; +import io.netty.channel.socket.ServerSocketChannel; +import io.netty.channel.socket.SocketChannel; +import io.netty.util.NetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.io.OutputStream; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.concurrent.atomic.AtomicReference; + +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + *

+ * <p> {@link PcapWriteHandler} captures {@link ByteBuf} from {@link SocketChannel} / {@link ServerChannel}
+ * or {@link DatagramPacket} and writes it into Pcap {@link OutputStream}. </p>
+ *
+ * <p>
+ * Things to keep in mind when using {@link PcapWriteHandler} with TCP:
+ *
+ * <ul>
+ *     <li> Whenever {@link ChannelInboundHandlerAdapter#channelActive(ChannelHandlerContext)} is called,
+ *     a fake TCP 3-way handshake (SYN, SYN+ACK, ACK) is simulated as a new connection in Pcap. </li>
+ *
+ *     <li> Whenever {@link ChannelInboundHandlerAdapter#handlerRemoved(ChannelHandlerContext)} is called,
+ *     a fake TCP 3-way handshake (FIN+ACK, FIN+ACK, ACK) is simulated as connection shutdown in Pcap. </li>
+ *
+ *     <li> Whenever {@link ChannelInboundHandlerAdapter#exceptionCaught(ChannelHandlerContext, Throwable)}
+ *     is called, a fake TCP RST is sent to simulate a connection reset in Pcap. </li>
+ *
+ *     <li> An ACK is sent each time data is sent / received. </li>
+ *
+ *     <li> Zero-length data packets can cause a TCP Double ACK error in Wireshark. To tackle this,
+ *     set {@code captureZeroByte} to {@code false}. </li>
+ * </ul>
+ * </p>
+ */ +public final class PcapWriteHandler extends ChannelDuplexHandler implements Closeable { + + /** + * Logger for logging events + */ + private final InternalLogger logger = InternalLoggerFactory.getInstance(PcapWriteHandler.class); + + /** + * {@link PcapWriter} Instance + */ + private PcapWriter pCapWriter; + + /** + * {@link OutputStream} where we'll write Pcap data. + */ + private final OutputStream outputStream; + + /** + * {@code true} if we want to capture packets with zero bytes else {@code false}. + */ + private final boolean captureZeroByte; + + /** + * {@code true} if we want to write Pcap Global Header on initialization of + * {@link PcapWriter} else {@code false}. + */ + private final boolean writePcapGlobalHeader; + + /** + * {@code true} if we want to synchronize on the {@link OutputStream} while writing + * else {@code false}. + */ + private final boolean sharedOutputStream; + + /** + * TCP Sender Segment Number. + * It'll start with 1 and keep incrementing with number of bytes read/sent. + */ + private int sendSegmentNumber = 1; + + /** + * TCP Receiver Segment Number. + * It'll start with 1 and keep incrementing with number of bytes read/sent. + */ + private int receiveSegmentNumber = 1; + + /** + * Type of the channel this handler is registered on + */ + private ChannelType channelType; + + /** + * Address of the initiator of the connection + */ + private InetSocketAddress initiatorAddr; + + /** + * Address of the receiver of the connection + */ + private InetSocketAddress handlerAddr; + + /** + * Set to {@code true} if this handler is registered on a server pipeline + */ + private boolean isServerPipeline; + + /** + * Current of this {@link PcapWriteHandler} + */ + private final AtomicReference state = new AtomicReference(State.INIT); + + /** + * Create new {@link PcapWriteHandler} Instance. + * {@code captureZeroByte} is set to {@code false} and + * {@code writePcapGlobalHeader} is set to {@code true}. 
+ * + * @param outputStream OutputStream where Pcap data will be written. Call {@link #close()} to close this + * OutputStream. + * @throws NullPointerException If {@link OutputStream} is {@code null} then we'll throw an + * {@link NullPointerException} + * + * @deprecated Use {@link Builder} instead. + */ + @Deprecated + public PcapWriteHandler(OutputStream outputStream) { + this(outputStream, false, true); + } + + /** + * Create new {@link PcapWriteHandler} Instance + * + * @param outputStream OutputStream where Pcap data will be written. Call {@link #close()} to close this + * OutputStream. + * @param captureZeroByte Set to {@code true} to enable capturing packets with empty (0 bytes) payload. + * Otherwise, if set to {@code false}, empty packets will be filtered out. + * @param writePcapGlobalHeader Set to {@code true} to write Pcap Global Header on initialization. + * Otherwise, if set to {@code false}, Pcap Global Header will not be written + * on initialization. This could when writing Pcap data on a existing file where + * Pcap Global Header is already present. + * @throws NullPointerException If {@link OutputStream} is {@code null} then we'll throw an + * {@link NullPointerException} + * + * @deprecated Use {@link Builder} instead. 
+ */ + @Deprecated + public PcapWriteHandler(OutputStream outputStream, boolean captureZeroByte, boolean writePcapGlobalHeader) { + this.outputStream = checkNotNull(outputStream, "OutputStream"); + this.captureZeroByte = captureZeroByte; + this.writePcapGlobalHeader = writePcapGlobalHeader; + sharedOutputStream = false; + } + + private PcapWriteHandler(Builder builder, OutputStream outputStream) { + this.outputStream = outputStream; + captureZeroByte = builder.captureZeroByte; + sharedOutputStream = builder.sharedOutputStream; + writePcapGlobalHeader = builder.writePcapGlobalHeader; + channelType = builder.channelType; + handlerAddr = builder.handlerAddr; + initiatorAddr = builder.initiatorAddr; + isServerPipeline = builder.isServerPipeline; + } + + /** + * Writes the Pcap Global Header to the provided {@code OutputStream} + * + * @param outputStream OutputStream where Pcap data will be written. + * @throws IOException if there is an error writing to the {@code OutputStream} + */ + public static void writeGlobalHeader(OutputStream outputStream) throws IOException { + PcapHeaders.writeGlobalHeader(outputStream); + } + + private void initializeIfNecessary(ChannelHandlerContext ctx) throws Exception { + // If State is not 'INIT' then it means we're already initialized so then no need to initiaize again. + if (state.get() != State.INIT) { + return; + } + + pCapWriter = new PcapWriter(this); + + if (channelType == null) { + // infer channel type + if (ctx.channel() instanceof SocketChannel) { + channelType = ChannelType.TCP; + + // If Channel belongs to `SocketChannel` then we're handling TCP. 
+ // Capture correct `localAddress` and `remoteAddress` + if (ctx.channel().parent() instanceof ServerSocketChannel) { + isServerPipeline = true; + initiatorAddr = (InetSocketAddress) ctx.channel().remoteAddress(); + handlerAddr = getLocalAddress(ctx.channel(), initiatorAddr); + } else { + isServerPipeline = false; + handlerAddr = (InetSocketAddress) ctx.channel().remoteAddress(); + initiatorAddr = getLocalAddress(ctx.channel(), handlerAddr); + } + } else if (ctx.channel() instanceof DatagramChannel) { + channelType = ChannelType.UDP; + + DatagramChannel datagramChannel = (DatagramChannel) ctx.channel(); + + // If `DatagramChannel` is connected then we can get + // `localAddress` and `remoteAddress` from Channel. + if (datagramChannel.isConnected()) { + handlerAddr = (InetSocketAddress) ctx.channel().remoteAddress(); + initiatorAddr = getLocalAddress(ctx.channel(), handlerAddr); + } + } + } + + if (channelType == ChannelType.TCP) { + logger.debug("Initiating Fake TCP 3-Way Handshake"); + + ByteBuf tcpBuf = ctx.alloc().buffer(); + + try { + // Write SYN with Normal Source and Destination Address + TCPPacket.writePacket(tcpBuf, null, 0, 0, + initiatorAddr.getPort(), handlerAddr.getPort(), TCPPacket.TCPFlag.SYN); + completeTCPWrite(initiatorAddr, handlerAddr, tcpBuf, ctx.alloc(), ctx); + + // Write SYN+ACK with Reversed Source and Destination Address + TCPPacket.writePacket(tcpBuf, null, 0, 1, + handlerAddr.getPort(), initiatorAddr.getPort(), TCPPacket.TCPFlag.SYN, TCPPacket.TCPFlag.ACK); + completeTCPWrite(handlerAddr, initiatorAddr, tcpBuf, ctx.alloc(), ctx); + + // Write ACK with Normal Source and Destination Address + TCPPacket.writePacket(tcpBuf, null, 1, 1, initiatorAddr.getPort(), + handlerAddr.getPort(), TCPPacket.TCPFlag.ACK); + completeTCPWrite(initiatorAddr, handlerAddr, tcpBuf, ctx.alloc(), ctx); + } finally { + tcpBuf.release(); + } + + logger.debug("Finished Fake TCP 3-Way Handshake"); + } + + state.set(State.WRITING); + } + + @Override + public void 
channelActive(ChannelHandlerContext ctx) throws Exception { + initializeIfNecessary(ctx); + super.channelActive(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + // Initialize if needed + if (state.get() == State.INIT) { + initializeIfNecessary(ctx); + } + + // Only write if State is STARTED + if (state.get() == State.WRITING) { + if (channelType == ChannelType.TCP) { + handleTCP(ctx, msg, false); + } else if (channelType == ChannelType.UDP) { + handleUDP(ctx, msg); + } else { + logDiscard(); + } + } + super.channelRead(ctx, msg); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + // Initialize if needed + if (state.get() == State.INIT) { + initializeIfNecessary(ctx); + } + + // Only write if State is STARTED + if (state.get() == State.WRITING) { + if (channelType == ChannelType.TCP) { + handleTCP(ctx, msg, true); + } else if (channelType == ChannelType.UDP) { + handleUDP(ctx, msg); + } else { + logDiscard(); + } + } + super.write(ctx, msg, promise); + } + + /** + * Handle TCP L4 + * + * @param ctx {@link ChannelHandlerContext} for {@link ByteBuf} allocation and + * {@code fireExceptionCaught} + * @param msg {@link Object} must be {@link ByteBuf} else it'll be discarded + * @param isWriteOperation Set {@code true} if we have to process packet when packets are being sent out + * else set {@code false} + */ + private void handleTCP(ChannelHandlerContext ctx, Object msg, boolean isWriteOperation) { + if (msg instanceof ByteBuf) { + + // If bytes are 0 and `captureZeroByte` is false, we won't capture this. + if (((ByteBuf) msg).readableBytes() == 0 && !captureZeroByte) { + logger.debug("Discarding Zero Byte TCP Packet. 
isWriteOperation {}", isWriteOperation); + return; + } + + ByteBufAllocator byteBufAllocator = ctx.alloc(); + ByteBuf packet = ((ByteBuf) msg).duplicate(); + ByteBuf tcpBuf = byteBufAllocator.buffer(); + int bytes = packet.readableBytes(); + + try { + if (isWriteOperation) { + final InetSocketAddress srcAddr; + final InetSocketAddress dstAddr; + if (isServerPipeline) { + srcAddr = handlerAddr; + dstAddr = initiatorAddr; + } else { + srcAddr = initiatorAddr; + dstAddr = handlerAddr; + } + + TCPPacket.writePacket(tcpBuf, packet, sendSegmentNumber, receiveSegmentNumber, srcAddr.getPort(), + dstAddr.getPort(), TCPPacket.TCPFlag.ACK); + completeTCPWrite(srcAddr, dstAddr, tcpBuf, byteBufAllocator, ctx); + logTCP(true, bytes, sendSegmentNumber, receiveSegmentNumber, srcAddr, dstAddr, false); + + sendSegmentNumber += bytes; + + TCPPacket.writePacket(tcpBuf, null, receiveSegmentNumber, sendSegmentNumber, dstAddr.getPort(), + srcAddr.getPort(), TCPPacket.TCPFlag.ACK); + completeTCPWrite(dstAddr, srcAddr, tcpBuf, byteBufAllocator, ctx); + logTCP(true, bytes, sendSegmentNumber, receiveSegmentNumber, dstAddr, srcAddr, true); + } else { + final InetSocketAddress srcAddr; + final InetSocketAddress dstAddr; + if (isServerPipeline) { + srcAddr = initiatorAddr; + dstAddr = handlerAddr; + } else { + srcAddr = handlerAddr; + dstAddr = initiatorAddr; + } + + TCPPacket.writePacket(tcpBuf, packet, receiveSegmentNumber, sendSegmentNumber, srcAddr.getPort(), + dstAddr.getPort(), TCPPacket.TCPFlag.ACK); + completeTCPWrite(srcAddr, dstAddr, tcpBuf, byteBufAllocator, ctx); + logTCP(false, bytes, receiveSegmentNumber, sendSegmentNumber, srcAddr, dstAddr, false); + + receiveSegmentNumber += bytes; + + TCPPacket.writePacket(tcpBuf, null, sendSegmentNumber, receiveSegmentNumber, dstAddr.getPort(), + srcAddr.getPort(), TCPPacket.TCPFlag.ACK); + completeTCPWrite(dstAddr, srcAddr, tcpBuf, byteBufAllocator, ctx); + logTCP(false, bytes, sendSegmentNumber, receiveSegmentNumber, dstAddr, srcAddr, true); 
+ } + } finally { + tcpBuf.release(); + } + } else { + logger.debug("Discarding Pcap Write for TCP Object: {}", msg); + } + } + + /** + * Write TCP/IP L3 and L2 here. + * + * @param srcAddr {@link InetSocketAddress} Source Address of this Packet + * @param dstAddr {@link InetSocketAddress} Destination Address of this Packet + * @param tcpBuf {@link ByteBuf} containing TCP L4 Data + * @param byteBufAllocator {@link ByteBufAllocator} for allocating bytes for TCP/IP L3 and L2 data. + * @param ctx {@link ChannelHandlerContext} for {@code fireExceptionCaught} + */ + private void completeTCPWrite(InetSocketAddress srcAddr, InetSocketAddress dstAddr, ByteBuf tcpBuf, + ByteBufAllocator byteBufAllocator, ChannelHandlerContext ctx) { + + ByteBuf ipBuf = byteBufAllocator.buffer(); + ByteBuf ethernetBuf = byteBufAllocator.buffer(); + ByteBuf pcap = byteBufAllocator.buffer(); + + try { + if (srcAddr.getAddress() instanceof Inet4Address && dstAddr.getAddress() instanceof Inet4Address) { + IPPacket.writeTCPv4(ipBuf, tcpBuf, + NetUtil.ipv4AddressToInt((Inet4Address) srcAddr.getAddress()), + NetUtil.ipv4AddressToInt((Inet4Address) dstAddr.getAddress())); + + EthernetPacket.writeIPv4(ethernetBuf, ipBuf); + } else if (srcAddr.getAddress() instanceof Inet6Address && dstAddr.getAddress() instanceof Inet6Address) { + IPPacket.writeTCPv6(ipBuf, tcpBuf, + srcAddr.getAddress().getAddress(), + dstAddr.getAddress().getAddress()); + + EthernetPacket.writeIPv6(ethernetBuf, ipBuf); + } else { + logger.error("Source and Destination IP Address versions are not same. 
Source Address: {}, " + + "Destination Address: {}", srcAddr.getAddress(), dstAddr.getAddress()); + return; + } + + // Write Packet into Pcap + pCapWriter.writePacket(pcap, ethernetBuf); + } catch (IOException ex) { + logger.error("Caught Exception While Writing Packet into Pcap", ex); + ctx.fireExceptionCaught(ex); + } finally { + ipBuf.release(); + ethernetBuf.release(); + pcap.release(); + } + } + + /** + * Handle UDP l4 + * + * @param ctx {@link ChannelHandlerContext} for {@code localAddress} / {@code remoteAddress}, + * {@link ByteBuf} allocation and {@code fireExceptionCaught} + * @param msg {@link DatagramPacket} or {@link ByteBuf} + */ + private void handleUDP(ChannelHandlerContext ctx, Object msg) { + ByteBuf udpBuf = ctx.alloc().buffer(); + + try { + if (msg instanceof DatagramPacket) { + + // If bytes are 0 and `captureZeroByte` is false, we won't capture this. + if (((DatagramPacket) msg).content().readableBytes() == 0 && !captureZeroByte) { + logger.debug("Discarding Zero Byte UDP Packet"); + return; + } + + DatagramPacket datagramPacket = ((DatagramPacket) msg).duplicate(); + InetSocketAddress srcAddr = datagramPacket.sender(); + InetSocketAddress dstAddr = datagramPacket.recipient(); + + // If `datagramPacket.sender()` is `null` then DatagramPacket is initialized + // `sender` (local) address. In this case, we'll get source address from Channel. + if (srcAddr == null) { + srcAddr = getLocalAddress(ctx.channel(), dstAddr); + } + + logger.debug("Writing UDP Data of {} Bytes, Src Addr {}, Dst Addr {}", + datagramPacket.content().readableBytes(), srcAddr, dstAddr); + + UDPPacket.writePacket(udpBuf, datagramPacket.content(), srcAddr.getPort(), dstAddr.getPort()); + completeUDPWrite(srcAddr, dstAddr, udpBuf, ctx.alloc(), ctx); + } else if (msg instanceof ByteBuf && + (!(ctx.channel() instanceof DatagramChannel) || ((DatagramChannel) ctx.channel()).isConnected())) { + + // If bytes are 0 and `captureZeroByte` is false, we won't capture this. 
+ if (((ByteBuf) msg).readableBytes() == 0 && !captureZeroByte) { + logger.debug("Discarding Zero Byte UDP Packet"); + return; + } + + ByteBuf byteBuf = ((ByteBuf) msg).duplicate(); + + logger.debug("Writing UDP Data of {} Bytes, Src Addr {}, Dst Addr {}", + byteBuf.readableBytes(), initiatorAddr, handlerAddr); + + UDPPacket.writePacket(udpBuf, byteBuf, initiatorAddr.getPort(), handlerAddr.getPort()); + completeUDPWrite(initiatorAddr, handlerAddr, udpBuf, ctx.alloc(), ctx); + } else { + logger.debug("Discarding Pcap Write for UDP Object: {}", msg); + } + } finally { + udpBuf.release(); + } + } + + /** + * Write UDP/IP L3 and L2 here. + * + * @param srcAddr {@link InetSocketAddress} Source Address of this Packet + * @param dstAddr {@link InetSocketAddress} Destination Address of this Packet + * @param udpBuf {@link ByteBuf} containing UDP L4 Data + * @param byteBufAllocator {@link ByteBufAllocator} for allocating bytes for UDP/IP L3 and L2 data. + * @param ctx {@link ChannelHandlerContext} for {@code fireExceptionCaught} + */ + private void completeUDPWrite(InetSocketAddress srcAddr, InetSocketAddress dstAddr, ByteBuf udpBuf, + ByteBufAllocator byteBufAllocator, ChannelHandlerContext ctx) { + + ByteBuf ipBuf = byteBufAllocator.buffer(); + ByteBuf ethernetBuf = byteBufAllocator.buffer(); + ByteBuf pcap = byteBufAllocator.buffer(); + + try { + if (srcAddr.getAddress() instanceof Inet4Address && dstAddr.getAddress() instanceof Inet4Address) { + IPPacket.writeUDPv4(ipBuf, udpBuf, + NetUtil.ipv4AddressToInt((Inet4Address) srcAddr.getAddress()), + NetUtil.ipv4AddressToInt((Inet4Address) dstAddr.getAddress())); + + EthernetPacket.writeIPv4(ethernetBuf, ipBuf); + } else if (srcAddr.getAddress() instanceof Inet6Address && dstAddr.getAddress() instanceof Inet6Address) { + IPPacket.writeUDPv6(ipBuf, udpBuf, + srcAddr.getAddress().getAddress(), + dstAddr.getAddress().getAddress()); + + EthernetPacket.writeIPv6(ethernetBuf, ipBuf); + } else { + logger.error("Source and 
Destination IP Address versions are not same. Source Address: {}, " + + "Destination Address: {}", srcAddr.getAddress(), dstAddr.getAddress()); + return; + } + + // Write Packet into Pcap + pCapWriter.writePacket(pcap, ethernetBuf); + } catch (IOException ex) { + logger.error("Caught Exception While Writing Packet into Pcap", ex); + ctx.fireExceptionCaught(ex); + } finally { + ipBuf.release(); + ethernetBuf.release(); + pcap.release(); + } + } + + /** + * Get the local address of a channel. If the address is a wildcard address ({@code 0.0.0.0} or {@code ::}), and + * the address family does not match that of the {@code remote}, return the wildcard address of the {@code remote}'s + * family instead. + * + * @param ch The channel to get the local address from + * @param remote The remote address + * @return The fixed local address + */ + private static InetSocketAddress getLocalAddress(Channel ch, InetSocketAddress remote) { + InetSocketAddress local = (InetSocketAddress) ch.localAddress(); + if (remote != null && local.getAddress().isAnyLocalAddress()) { + if (local.getAddress() instanceof Inet4Address && remote.getAddress() instanceof Inet6Address) { + return new InetSocketAddress(WildcardAddressHolder.wildcard6, local.getPort()); + } + if (local.getAddress() instanceof Inet6Address && remote.getAddress() instanceof Inet4Address) { + return new InetSocketAddress(WildcardAddressHolder.wildcard4, local.getPort()); + } + } + return local; + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + + // If `isTCP` is true, then we'll simulate a `FIN` flow. 
    @Override
    public void handlerRemoved(ChannelHandlerContext ctx) throws Exception {

        // For TCP captures, simulate a graceful close so the capture ends with a
        // well-formed FIN+ACK / FIN+ACK / ACK exchange.
        if (channelType == ChannelType.TCP) {
            logger.debug("Starting Fake TCP FIN+ACK Flow to close connection");

            ByteBufAllocator byteBufAllocator = ctx.alloc();
            ByteBuf tcpBuf = byteBufAllocator.buffer();

            try {
                // tcpBuf is reused for all three segments; completeTCPWrite is expected to
                // drain it each time -- NOTE(review): relies on IPPacket consuming the
                // buffer's readable bytes, confirm against IPPacket.writeTCPv4/v6.

                // Write FIN+ACK with Normal Source and Destination Address
                TCPPacket.writePacket(tcpBuf, null, sendSegmentNumber, receiveSegmentNumber, initiatorAddr.getPort(),
                        handlerAddr.getPort(), TCPPacket.TCPFlag.FIN, TCPPacket.TCPFlag.ACK);
                completeTCPWrite(initiatorAddr, handlerAddr, tcpBuf, byteBufAllocator, ctx);

                // Write FIN+ACK with Reversed Source and Destination Address
                TCPPacket.writePacket(tcpBuf, null, receiveSegmentNumber, sendSegmentNumber, handlerAddr.getPort(),
                        initiatorAddr.getPort(), TCPPacket.TCPFlag.FIN, TCPPacket.TCPFlag.ACK);
                completeTCPWrite(handlerAddr, initiatorAddr, tcpBuf, byteBufAllocator, ctx);

                // Write the final ACK with Normal Source and Destination Address; FIN
                // consumes one sequence number on each side, hence the +1 on both counters.
                TCPPacket.writePacket(tcpBuf, null, sendSegmentNumber + 1, receiveSegmentNumber + 1,
                        initiatorAddr.getPort(), handlerAddr.getPort(), TCPPacket.TCPFlag.ACK);
                completeTCPWrite(initiatorAddr, handlerAddr, tcpBuf, byteBufAllocator, ctx);
            } finally {
                tcpBuf.release();
            }

            logger.debug("Finished Fake TCP FIN+ACK Flow to close connection");
        }

        // Stop writing pcap data; close() only closes the writer, not this handler.
        close();
        super.handlerRemoved(ctx);
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {

        // For TCP captures, record an RST+ACK so the capture shows the connection aborting.
        if (channelType == ChannelType.TCP) {
            ByteBuf tcpBuf = ctx.alloc().buffer();

            try {
                // Write RST with Normal Source and Destination Address
                TCPPacket.writePacket(tcpBuf, null, sendSegmentNumber, receiveSegmentNumber, initiatorAddr.getPort(),
                        handlerAddr.getPort(), TCPPacket.TCPFlag.RST, TCPPacket.TCPFlag.ACK);
                completeTCPWrite(initiatorAddr, handlerAddr, tcpBuf, ctx.alloc(), ctx);
            } finally {
                tcpBuf.release();
            }

            logger.debug("Sent Fake TCP RST to close connection");
        }

        // Stop writing pcap data, then keep the exception flowing down the pipeline.
        close();
        ctx.fireExceptionCaught(cause);
    }
TCP + */ + private void logTCP(boolean isWriteOperation, int bytes, int sendSegmentNumber, int receiveSegmentNumber, + InetSocketAddress srcAddr, InetSocketAddress dstAddr, boolean ackOnly) { + // If `ackOnly` is `true` when we don't need to write any data so we'll not + // log number of bytes being written and mark the operation as "TCP ACK". + if (logger.isDebugEnabled()) { + if (ackOnly) { + logger.debug("Writing TCP ACK, isWriteOperation {}, Segment Number {}, Ack Number {}, Src Addr {}, " + + "Dst Addr {}", isWriteOperation, sendSegmentNumber, receiveSegmentNumber, dstAddr, srcAddr); + } else { + logger.debug("Writing TCP Data of {} Bytes, isWriteOperation {}, Segment Number {}, Ack Number {}, " + + "Src Addr {}, Dst Addr {}", bytes, isWriteOperation, sendSegmentNumber, + receiveSegmentNumber, srcAddr, dstAddr); + } + } + } + + OutputStream outputStream() { + return outputStream; + } + + boolean sharedOutputStream() { + return sharedOutputStream; + } + + /** + * Returns {@code true} if the {@link PcapWriteHandler} is currently + * writing packets to the {@link OutputStream} else returns {@code false}. + */ + public boolean isWriting() { + return state.get() == State.WRITING; + } + + State state() { + return state.get(); + } + + /** + * Pause the {@link PcapWriteHandler} from writing packets to the {@link OutputStream}. + */ + public void pause() { + if (!state.compareAndSet(State.WRITING, State.PAUSED)) { + throw new IllegalStateException("State must be 'STARTED' to pause but current state is: " + state); + } + } + + /** + * Resume the {@link PcapWriteHandler} to writing packets to the {@link OutputStream}. + */ + public void resume() { + if (!state.compareAndSet(State.PAUSED, State.WRITING)) { + throw new IllegalStateException("State must be 'PAUSED' to resume but current state is: " + state); + } + } + + void markClosed() { + if (state.get() != State.CLOSED) { + state.set(State.CLOSED); + } + } + + // Visible for testing only. 
    // Visible for testing only.
    PcapWriter pCapWriter() {
        return pCapWriter;
    }

    /** Warn that the channel type could not be inferred, so this write is dropped. */
    private void logDiscard() {
        logger.warn("Discarding pcap write because channel type is unknown. The channel this handler is registered " +
                "on is not a SocketChannel or DatagramChannel, so the inference does not work. Please call " +
                "forceTcpChannel or forceUdpChannel before registering the handler.");
    }

    // Diagnostic summary of the handler's configuration and connection metadata.
    @Override
    public String toString() {
        return "PcapWriteHandler{" +
                "captureZeroByte=" + captureZeroByte +
                ", writePcapGlobalHeader=" + writePcapGlobalHeader +
                ", sharedOutputStream=" + sharedOutputStream +
                ", sendSegmentNumber=" + sendSegmentNumber +
                ", receiveSegmentNumber=" + receiveSegmentNumber +
                ", channelType=" + channelType +
                ", initiatorAddr=" + initiatorAddr +
                ", handlerAddr=" + handlerAddr +
                ", isServerPipeline=" + isServerPipeline +
                ", state=" + state +
                '}';
    }

     * Close {@code PcapWriter} and {@link OutputStream}.
     * <p>
     * Note: Calling this method does not close {@link PcapWriteHandler}.
     * Only Pcap Writes are closed.
+ * + * @throws IOException If {@link OutputStream#close()} throws an exception + */ + @Override + public void close() throws IOException { + if (state.get() == State.CLOSED) { + logger.debug("PcapWriterHandler is already closed"); + } else { + pCapWriter.close(); + markClosed(); + logger.debug("PcapWriterHandler is now closed"); + } + } + + private enum ChannelType { + TCP, UDP + } + + public static Builder builder() { + return new Builder(); + } + + /** + * Builder for {@link PcapWriteHandler} + */ + public static final class Builder { + private boolean captureZeroByte; + private boolean sharedOutputStream; + private boolean writePcapGlobalHeader = true; + + private ChannelType channelType; + private InetSocketAddress initiatorAddr; + private InetSocketAddress handlerAddr; + private boolean isServerPipeline; + + private Builder() { + } + + /** + * Set to {@code true} to enable capturing packets with empty (0 bytes) payload. Otherwise, if set to + * {@code false}, empty packets will be filtered out. + * + * @param captureZeroByte Whether to filter out empty packets. + * @return this builder + */ + public Builder captureZeroByte(boolean captureZeroByte) { + this.captureZeroByte = captureZeroByte; + return this; + } + + /** + * Set to {@code true} if multiple {@link PcapWriteHandler} instances will be + * writing to the same {@link OutputStream} concurrently, and write locking is + * required. Otherwise, if set to {@code false}, no locking will be done. + * Additionally, {@link #close} will not close the underlying {@code OutputStream}. + * Note: it is probably an error to have both {@code writePcapGlobalHeader} and + * {@code sharedOutputStream} set to {@code true} at the same time. 
        /**
         * Set to {@code true} if multiple {@link PcapWriteHandler} instances will be
         * writing to the same {@link OutputStream} concurrently, and write locking is
         * required. Otherwise, if set to {@code false}, no locking will be done.
         * Additionally, {@link #close} will not close the underlying {@code OutputStream}.
         * Note: it is probably an error to have both {@code writePcapGlobalHeader} and
         * {@code sharedOutputStream} set to {@code true} at the same time.
         *
         * @param sharedOutputStream Whether {@link OutputStream} is shared or not
         * @return this builder
         */
        public Builder sharedOutputStream(boolean sharedOutputStream) {
            this.sharedOutputStream = sharedOutputStream;
            return this;
        }

        /**
         * Set to {@code true} to write the Pcap Global Header on initialization. Otherwise, if set to
         * {@code false}, the Pcap Global Header will not be written on initialization. This is useful
         * when appending Pcap data to an existing file whose Pcap Global Header is already present.
         *
         * @param writePcapGlobalHeader Whether to write the pcap global header.
         * @return this builder
         */
        public Builder writePcapGlobalHeader(boolean writePcapGlobalHeader) {
            this.writePcapGlobalHeader = writePcapGlobalHeader;
            return this;
        }

        /**
         * Force this handler to write data as if they were TCP packets, with the given connection metadata.
         * If this method isn't called, the metadata is inferred from the channel.
         *
         * @param serverAddress    The address of the TCP server (handler)
         * @param clientAddress    The address of the TCP client (initiator)
         * @param isServerPipeline Whether the handler is part of the server channel
         * @return this builder
         */
        public Builder forceTcpChannel(InetSocketAddress serverAddress, InetSocketAddress clientAddress,
                                       boolean isServerPipeline) {
            channelType = ChannelType.TCP;
            handlerAddr = checkNotNull(serverAddress, "serverAddress");
            initiatorAddr = checkNotNull(clientAddress, "clientAddress");
            this.isServerPipeline = isServerPipeline;
            return this;
        }
        /**
         * Force this handler to write data as if they were UDP packets, with the given connection metadata.
         * If this method isn't called, the metadata is inferred from the channel.
         * <p>
         * Note that even if this method is called, the address information on {@link DatagramPacket} takes
         * precedence if it is present.
         *
         * @param localAddress  The address of the UDP local
         * @param remoteAddress The address of the UDP remote
         * @return this builder
         */
        public Builder forceUdpChannel(InetSocketAddress localAddress, InetSocketAddress remoteAddress) {
            channelType = ChannelType.UDP;
            handlerAddr = checkNotNull(remoteAddress, "remoteAddress");
            initiatorAddr = checkNotNull(localAddress, "localAddress");
            return this;
        }

        /**
         * Build the {@link PcapWriteHandler}.
         *
         * @param outputStream The output stream to write the pcap data to.
         * @return The handler.
         */
        public PcapWriteHandler build(OutputStream outputStream) {
            checkNotNull(outputStream, "outputStream");
            return new PcapWriteHandler(this, outputStream);
        }
    }

    /**
     * Lazy holder for the IPv4/IPv6 wildcard addresses, so the byte-array lookups run
     * only when first needed.
     */
    private static final class WildcardAddressHolder {
        static final InetAddress wildcard4; // 0.0.0.0
        static final InetAddress wildcard6; // ::

        static {
            try {
                // All-zero byte arrays of the right length yield the wildcard addresses.
                wildcard4 = InetAddress.getByAddress(new byte[4]);
                wildcard6 = InetAddress.getByAddress(new byte[16]);
            } catch (UnknownHostException e) {
                // would only happen if the byte array was of incorrect size
                throw new AssertionError(e);
            }
        }
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.pcap; + +import io.netty.buffer.ByteBuf; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.io.Closeable; +import java.io.IOException; +import java.io.OutputStream; + +final class PcapWriter implements Closeable { + + /** + * Logger + */ + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PcapWriter.class); + + private final PcapWriteHandler pcapWriteHandler; + + /** + * Reference declared so that we can use this as mutex in clean way. + */ + private final OutputStream outputStream; + + /** + * This uses {@link OutputStream} for writing Pcap data. + * + * @throws IOException If {@link OutputStream#write(byte[])} throws an exception + */ + PcapWriter(PcapWriteHandler pcapWriteHandler) throws IOException { + this.pcapWriteHandler = pcapWriteHandler; + outputStream = pcapWriteHandler.outputStream(); + + // If OutputStream is not shared then we have to write Global Header. + if (!pcapWriteHandler.sharedOutputStream()) { + PcapHeaders.writeGlobalHeader(pcapWriteHandler.outputStream()); + } + } + + /** + * Write Packet in Pcap OutputStream. 
+ * + * @param packetHeaderBuf Packer Header {@link ByteBuf} + * @param packet Packet + * @throws IOException If {@link OutputStream#write(byte[])} throws an exception + */ + void writePacket(ByteBuf packetHeaderBuf, ByteBuf packet) throws IOException { + if (pcapWriteHandler.state() == State.CLOSED) { + logger.debug("Pcap Write attempted on closed PcapWriter"); + } + + long timestamp = System.currentTimeMillis(); + + PcapHeaders.writePacketHeader( + packetHeaderBuf, + (int) (timestamp / 1000L), + (int) (timestamp % 1000L * 1000L), + packet.readableBytes(), + packet.readableBytes() + ); + + if (pcapWriteHandler.sharedOutputStream()) { + synchronized (outputStream) { + packetHeaderBuf.readBytes(outputStream, packetHeaderBuf.readableBytes()); + packet.readBytes(outputStream, packet.readableBytes()); + } + } else { + packetHeaderBuf.readBytes(outputStream, packetHeaderBuf.readableBytes()); + packet.readBytes(outputStream, packet.readableBytes()); + } + } + + @Override + public String toString() { + return "PcapWriter{" + + "outputStream=" + outputStream + + '}'; + } + + @Override + public void close() throws IOException { + if (pcapWriteHandler.state() == State.CLOSED) { + logger.debug("PcapWriter is already closed"); + } else { + if (pcapWriteHandler.sharedOutputStream()) { + synchronized (outputStream) { + outputStream.flush(); + } + } else { + outputStream.flush(); + outputStream.close(); + } + pcapWriteHandler.markClosed(); + logger.debug("PcapWriter is now closed"); + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/pcap/State.java b/netty-handler/src/main/java/io/netty/handler/pcap/State.java new file mode 100644 index 0000000..bd69b51 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/pcap/State.java @@ -0,0 +1,42 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
/**
 * The lifecycle state of the {@link PcapWriteHandler}.
 */
enum State {

    /**
     * The handler is not active (initial state, before writing has started).
     */
    INIT,

    /**
     * The handler is active and actively writing Pcap data.
     */
    WRITING,

    /**
     * The handler is paused. No Pcap data will be written until resumed.
     */
    PAUSED,

    /**
     * The handler is closed. This state is terminal.
     */
    CLOSED
}
+ */ + private static final short OFFSET = 0x5000; + + private TCPPacket() { + // Prevent outside initialization + } + + /** + * Write TCP Packet + * + * @param byteBuf ByteBuf where Packet data will be set + * @param payload Payload of this Packet + * @param srcPort Source Port + * @param dstPort Destination Port + */ + static void writePacket(ByteBuf byteBuf, ByteBuf payload, int segmentNumber, int ackNumber, int srcPort, + int dstPort, TCPFlag... tcpFlags) { + + byteBuf.writeShort(srcPort); // Source Port + byteBuf.writeShort(dstPort); // Destination Port + byteBuf.writeInt(segmentNumber); // Segment Number + byteBuf.writeInt(ackNumber); // Acknowledgment Number + byteBuf.writeShort(OFFSET | TCPFlag.getFlag(tcpFlags)); // Flags + byteBuf.writeShort(65535); // Window Size + byteBuf.writeShort(0x0001); // Checksum + byteBuf.writeShort(0); // Urgent Pointer + + if (payload != null) { + byteBuf.writeBytes(payload); // Payload of Data + } + } + + enum TCPFlag { + FIN(1), + SYN(1 << 1), + RST(1 << 2), + PSH(1 << 3), + ACK(1 << 4), + URG(1 << 5), + ECE(1 << 6), + CWR(1 << 7); + + private final int value; + + TCPFlag(int value) { + this.value = value; + } + + static int getFlag(TCPFlag... tcpFlags) { + int flags = 0; + + for (TCPFlag tcpFlag : tcpFlags) { + flags |= tcpFlag.value; + } + + return flags; + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/pcap/UDPPacket.java b/netty-handler/src/main/java/io/netty/handler/pcap/UDPPacket.java new file mode 100644 index 0000000..c59abf7 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/pcap/UDPPacket.java @@ -0,0 +1,43 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.pcap; + +import io.netty.buffer.ByteBuf; + +final class UDPPacket { + + private static final short UDP_HEADER_SIZE = 8; + + private UDPPacket() { + // Prevent outside initialization + } + + /** + * Write UDP Packet + * + * @param byteBuf ByteBuf where Packet data will be set + * @param payload Payload of this Packet + * @param srcPort Source Port + * @param dstPort Destination Port + */ + static void writePacket(ByteBuf byteBuf, ByteBuf payload, int srcPort, int dstPort) { + byteBuf.writeShort(srcPort); // Source Port + byteBuf.writeShort(dstPort); // Destination Port + byteBuf.writeShort(UDP_HEADER_SIZE + payload.readableBytes()); // UDP Header Length + Payload Length + byteBuf.writeShort(0x0001); // Checksum + byteBuf.writeBytes(payload); // Payload of Data + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/pcap/package-info.java b/netty-handler/src/main/java/io/netty/handler/pcap/package-info.java new file mode 100644 index 0000000..e990b22 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/pcap/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Capture data and write into Pcap format which helps in troubleshooting. + */ +package io.netty.handler.pcap; diff --git a/netty-handler/src/main/java/io/netty/handler/stream/ChunkedFile.java b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedFile.java new file mode 100644 index 0000000..5153c0a --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedFile.java @@ -0,0 +1,170 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; +import io.netty.util.internal.ObjectUtil; + +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; + +/** + * A {@link ChunkedInput} that fetches data from a file chunk by chunk. + *

+ * If your operating system supports + * zero-copy file transfer + * such as {@code sendfile()}, you might want to use {@link FileRegion} instead. + */ +public class ChunkedFile implements ChunkedInput { + + private final RandomAccessFile file; + private final long startOffset; + private final long endOffset; + private final int chunkSize; + private long offset; + + /** + * Creates a new instance that fetches data from the specified file. + */ + public ChunkedFile(File file) throws IOException { + this(file, ChunkedStream.DEFAULT_CHUNK_SIZE); + } + + /** + * Creates a new instance that fetches data from the specified file. + * + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedFile(File file, int chunkSize) throws IOException { + this(new RandomAccessFile(file, "r"), chunkSize); + } + + /** + * Creates a new instance that fetches data from the specified file. + */ + public ChunkedFile(RandomAccessFile file) throws IOException { + this(file, ChunkedStream.DEFAULT_CHUNK_SIZE); + } + + /** + * Creates a new instance that fetches data from the specified file. + * + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedFile(RandomAccessFile file, int chunkSize) throws IOException { + this(file, 0, file.length(), chunkSize); + } + + /** + * Creates a new instance that fetches data from the specified file. 
+ * + * @param offset the offset of the file where the transfer begins + * @param length the number of bytes to transfer + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedFile(RandomAccessFile file, long offset, long length, int chunkSize) throws IOException { + ObjectUtil.checkNotNull(file, "file"); + ObjectUtil.checkPositiveOrZero(offset, "offset"); + ObjectUtil.checkPositiveOrZero(length, "length"); + ObjectUtil.checkPositive(chunkSize, "chunkSize"); + + this.file = file; + this.offset = startOffset = offset; + this.endOffset = offset + length; + this.chunkSize = chunkSize; + + file.seek(offset); + } + + /** + * Returns the offset in the file where the transfer began. + */ + public long startOffset() { + return startOffset; + } + + /** + * Returns the offset in the file where the transfer will end. + */ + public long endOffset() { + return endOffset; + } + + /** + * Returns the offset in the file where the transfer is happening currently. + */ + public long currentOffset() { + return offset; + } + + @Override + public boolean isEndOfInput() throws Exception { + return !(offset < endOffset && file.getChannel().isOpen()); + } + + @Override + public void close() throws Exception { + file.close(); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + long offset = this.offset; + if (offset >= endOffset) { + return null; + } + + int chunkSize = (int) Math.min(this.chunkSize, endOffset - offset); + // Check if the buffer is backed by an byte array. 
If so we can optimize it a bit an safe a copy + + ByteBuf buf = allocator.heapBuffer(chunkSize); + boolean release = true; + try { + file.readFully(buf.array(), buf.arrayOffset(), chunkSize); + buf.writerIndex(chunkSize); + this.offset = offset + chunkSize; + release = false; + return buf; + } finally { + if (release) { + buf.release(); + } + } + } + + @Override + public long length() { + return endOffset - startOffset; + } + + @Override + public long progress() { + return offset - startOffset; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/stream/ChunkedInput.java b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedInput.java new file mode 100644 index 0000000..6d932d6 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedInput.java @@ -0,0 +1,81 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + + +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; + +/** + * A data stream of indefinite length which is consumed by {@link ChunkedWriteHandler}. + */ +public interface ChunkedInput { + + /** + * Return {@code true} if and only if there is no data left in the stream + * and the stream has reached at its end. + */ + boolean isEndOfInput() throws Exception; + + /** + * Releases the resources associated with the input. 
+ */ + void close() throws Exception; + + /** + * @deprecated Use {@link #readChunk(ByteBufAllocator)}. + * + *

Fetches a chunked data from the stream. Once this method returns the last chunk + * and thus the stream has reached at its end, any subsequent {@link #isEndOfInput()} + * call must return {@code true}. + * + * @param ctx The context which provides a {@link ByteBufAllocator} if buffer allocation is necessary. + * @return the fetched chunk. + * {@code null} if there is no data left in the stream. + * Please note that {@code null} does not necessarily mean that the + * stream has reached at its end. In a slow stream, the next chunk + * might be unavailable just momentarily. + */ + @Deprecated + B readChunk(ChannelHandlerContext ctx) throws Exception; + + /** + * Fetches a chunked data from the stream. Once this method returns the last chunk + * and thus the stream has reached at its end, any subsequent {@link #isEndOfInput()} + * call must return {@code true}. + * + * @param allocator {@link ByteBufAllocator} if buffer allocation is necessary. + * @return the fetched chunk. + * {@code null} if there is no data left in the stream. + * Please note that {@code null} does not necessarily mean that the + * stream has reached at its end. In a slow stream, the next chunk + * might be unavailable just momentarily. + */ + B readChunk(ByteBufAllocator allocator) throws Exception; + + /** + * Returns the length of the input. + * @return the length of the input if the length of the input is known. + * a negative value if the length of the input is unknown. + */ + long length(); + + /** + * Returns current transfer progress. 
+ */ + long progress(); + +} diff --git a/netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioFile.java b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioFile.java new file mode 100644 index 0000000..2fce8e2 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioFile.java @@ -0,0 +1,181 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.FileRegion; +import io.netty.util.internal.ObjectUtil; + +import java.io.File; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; + +/** + * A {@link ChunkedInput} that fetches data from a file chunk by chunk using + * NIO {@link FileChannel}. + *

+ * If your operating system supports + * zero-copy file transfer + * such as {@code sendfile()}, you might want to use {@link FileRegion} instead. + */ +public class ChunkedNioFile implements ChunkedInput { + + private final FileChannel in; + private final long startOffset; + private final long endOffset; + private final int chunkSize; + private long offset; + + /** + * Creates a new instance that fetches data from the specified file. + */ + public ChunkedNioFile(File in) throws IOException { + this(new RandomAccessFile(in, "r").getChannel()); + } + + /** + * Creates a new instance that fetches data from the specified file. + * + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedNioFile(File in, int chunkSize) throws IOException { + this(new RandomAccessFile(in, "r").getChannel(), chunkSize); + } + + /** + * Creates a new instance that fetches data from the specified file. + */ + public ChunkedNioFile(FileChannel in) throws IOException { + this(in, ChunkedStream.DEFAULT_CHUNK_SIZE); + } + + /** + * Creates a new instance that fetches data from the specified file. + * + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedNioFile(FileChannel in, int chunkSize) throws IOException { + this(in, 0, in.size(), chunkSize); + } + + /** + * Creates a new instance that fetches data from the specified file. 
+ * + * @param offset the offset of the file where the transfer begins + * @param length the number of bytes to transfer + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedNioFile(FileChannel in, long offset, long length, int chunkSize) + throws IOException { + ObjectUtil.checkNotNull(in, "in"); + ObjectUtil.checkPositiveOrZero(offset, "offset"); + ObjectUtil.checkPositiveOrZero(length, "length"); + ObjectUtil.checkPositive(chunkSize, "chunkSize"); + if (!in.isOpen()) { + throw new ClosedChannelException(); + } + this.in = in; + this.chunkSize = chunkSize; + this.offset = startOffset = offset; + endOffset = offset + length; + } + + /** + * Returns the offset in the file where the transfer began. + */ + public long startOffset() { + return startOffset; + } + + /** + * Returns the offset in the file where the transfer will end. + */ + public long endOffset() { + return endOffset; + } + + /** + * Returns the offset in the file where the transfer is happening currently. 
+ */ + public long currentOffset() { + return offset; + } + + @Override + public boolean isEndOfInput() throws Exception { + return !(offset < endOffset && in.isOpen()); + } + + @Override + public void close() throws Exception { + in.close(); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + long offset = this.offset; + if (offset >= endOffset) { + return null; + } + + int chunkSize = (int) Math.min(this.chunkSize, endOffset - offset); + ByteBuf buffer = allocator.buffer(chunkSize); + boolean release = true; + try { + int readBytes = 0; + for (;;) { + int localReadBytes = buffer.writeBytes(in, offset + readBytes, chunkSize - readBytes); + if (localReadBytes < 0) { + break; + } + readBytes += localReadBytes; + if (readBytes == chunkSize) { + break; + } + } + this.offset += readBytes; + release = false; + return buffer; + } finally { + if (release) { + buffer.release(); + } + } + } + + @Override + public long length() { + return endOffset - startOffset; + } + + @Override + public long progress() { + return offset - startOffset; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioStream.java b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioStream.java new file mode 100644 index 0000000..ecf65df --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedNioStream.java @@ -0,0 +1,143 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; + +import java.nio.ByteBuffer; +import java.nio.channels.ReadableByteChannel; + +/** + * A {@link ChunkedInput} that fetches data from a {@link ReadableByteChannel} + * chunk by chunk. Please note that the {@link ReadableByteChannel} must + * operate in blocking mode. Non-blocking mode channels are not supported. + */ +public class ChunkedNioStream implements ChunkedInput { + + private final ReadableByteChannel in; + + private final int chunkSize; + private long offset; + + /** + * Associated ByteBuffer + */ + private final ByteBuffer byteBuffer; + + /** + * Creates a new instance that fetches data from the specified channel. + */ + public ChunkedNioStream(ReadableByteChannel in) { + this(in, ChunkedStream.DEFAULT_CHUNK_SIZE); + } + + /** + * Creates a new instance that fetches data from the specified channel. + * + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedNioStream(ReadableByteChannel in, int chunkSize) { + this.in = checkNotNull(in, "in"); + this.chunkSize = checkPositive(chunkSize, "chunkSize"); + byteBuffer = ByteBuffer.allocate(chunkSize); + } + + /** + * Returns the number of transferred bytes. 
+ */ + public long transferredBytes() { + return offset; + } + + @Override + public boolean isEndOfInput() throws Exception { + if (byteBuffer.position() > 0) { + // A previous read was not over, so there is a next chunk in the buffer at least + return false; + } + if (in.isOpen()) { + // Try to read a new part, and keep this part (no rewind) + int b = in.read(byteBuffer); + if (b < 0) { + return true; + } else { + offset += b; + return false; + } + } + return true; + } + + @Override + public void close() throws Exception { + in.close(); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (isEndOfInput()) { + return null; + } + // buffer cannot be not be empty from there + int readBytes = byteBuffer.position(); + for (;;) { + int localReadBytes = in.read(byteBuffer); + if (localReadBytes < 0) { + break; + } + readBytes += localReadBytes; + offset += localReadBytes; + if (readBytes == chunkSize) { + break; + } + } + byteBuffer.flip(); + boolean release = true; + ByteBuf buffer = allocator.buffer(byteBuffer.remaining()); + try { + buffer.writeBytes(byteBuffer); + byteBuffer.clear(); + release = false; + return buffer; + } finally { + if (release) { + buffer.release(); + } + } + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return offset; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/stream/ChunkedStream.java b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedStream.java new file mode 100644 index 0000000..e051a7f --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedStream.java @@ -0,0 +1,148 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in 
compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.channel.ChannelHandlerContext; +import io.netty.util.internal.ObjectUtil; + +import java.io.InputStream; +import java.io.PushbackInputStream; + +/** + * A {@link ChunkedInput} that fetches data from an {@link InputStream} chunk by + * chunk. + *

+ * Please note that the {@link InputStream} instance that feeds data into + * {@link ChunkedStream} must implement {@link InputStream#available()} as + * accurately as possible, rather than using the default implementation. + * Otherwise, {@link ChunkedStream} will generate many too small chunks or + * block unnecessarily often. + */ +public class ChunkedStream implements ChunkedInput { + + static final int DEFAULT_CHUNK_SIZE = 8192; + + private final PushbackInputStream in; + private final int chunkSize; + private long offset; + private boolean closed; + + /** + * Creates a new instance that fetches data from the specified stream. + */ + public ChunkedStream(InputStream in) { + this(in, DEFAULT_CHUNK_SIZE); + } + + /** + * Creates a new instance that fetches data from the specified stream. + * + * @param chunkSize the number of bytes to fetch on each + * {@link #readChunk(ChannelHandlerContext)} call + */ + public ChunkedStream(InputStream in, int chunkSize) { + ObjectUtil.checkNotNull(in, "in"); + ObjectUtil.checkPositive(chunkSize, "chunkSize"); + + if (in instanceof PushbackInputStream) { + this.in = (PushbackInputStream) in; + } else { + this.in = new PushbackInputStream(in); + } + this.chunkSize = chunkSize; + } + + /** + * Returns the number of transferred bytes. 
+ */ + public long transferredBytes() { + return offset; + } + + @Override + public boolean isEndOfInput() throws Exception { + if (closed) { + return true; + } + if (in.available() > 0) { + return false; + } + + int b = in.read(); + if (b < 0) { + return true; + } else { + in.unread(b); + return false; + } + } + + @Override + public void close() throws Exception { + closed = true; + in.close(); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (isEndOfInput()) { + return null; + } + + final int availableBytes = in.available(); + final int chunkSize; + if (availableBytes <= 0) { + chunkSize = this.chunkSize; + } else { + chunkSize = Math.min(this.chunkSize, in.available()); + } + + boolean release = true; + ByteBuf buffer = allocator.buffer(chunkSize); + try { + // transfer to buffer + int written = buffer.writeBytes(in, chunkSize); + if (written < 0) { + return null; + } + offset += written; + release = false; + return buffer; + } finally { + if (release) { + buffer.release(); + } + } + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return offset; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java new file mode 100644 index 0000000..d8058d0 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/ChunkedWriteHandler.java @@ -0,0 +1,384 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelProgressivePromise; +import io.netty.channel.ChannelPromise; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.nio.channels.ClosedChannelException; +import java.util.ArrayDeque; +import java.util.Queue; + +/** + * A {@link ChannelHandler} that adds support for writing a large data stream + * asynchronously without spending a lot of memory or triggering an + * {@link OutOfMemoryError}. Large data streaming such as file + * transfer requires complicated state management in a {@link ChannelHandler} + * implementation. {@link ChunkedWriteHandler} manages such complicated states + * so that you can send a large data stream without difficulties. + *

+ * To use {@link ChunkedWriteHandler} in your application, you have to insert + * a new {@link ChunkedWriteHandler} instance: + *

+ * {@link ChannelPipeline} p = ...;
+ * p.addLast("streamer", new {@link ChunkedWriteHandler}());
+ * p.addLast("handler", new MyHandler());
+ * 
+ * Once inserted, you can write a {@link ChunkedInput} so that the + * {@link ChunkedWriteHandler} can pick it up and fetch the content of the + * stream chunk by chunk and write the fetched chunk downstream: + *
+ * {@link Channel} ch = ...;
+ * ch.write(new {@link ChunkedFile}(new File("video.mkv")));
+ * 
+ * + *

Sending a stream which generates a chunk intermittently

+ * + * Some {@link ChunkedInput} generates a chunk on a certain event or timing. + * Such {@link ChunkedInput} implementation often returns {@code null} on + * {@link ChunkedInput#readChunk(ChannelHandlerContext)}, resulting in the indefinitely suspended + * transfer. To resume the transfer when a new chunk is available, you have to + * call {@link #resumeTransfer()}. + */ +public class ChunkedWriteHandler extends ChannelDuplexHandler { + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(ChunkedWriteHandler.class); + + private final Queue queue = new ArrayDeque(); + private volatile ChannelHandlerContext ctx; + + public ChunkedWriteHandler() { + } + + /** + * @deprecated use {@link #ChunkedWriteHandler()} + */ + @Deprecated + public ChunkedWriteHandler(int maxPendingWrites) { + checkPositive(maxPendingWrites, "maxPendingWrites"); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + this.ctx = ctx; + } + + /** + * Continues to fetch the chunks from the input. 
+ */ + public void resumeTransfer() { + final ChannelHandlerContext ctx = this.ctx; + if (ctx == null) { + return; + } + if (ctx.executor().inEventLoop()) { + resumeTransfer0(ctx); + } else { + // let the transfer resume on the next event loop round + ctx.executor().execute(new Runnable() { + + @Override + public void run() { + resumeTransfer0(ctx); + } + }); + } + } + + private void resumeTransfer0(ChannelHandlerContext ctx) { + try { + doFlush(ctx); + } catch (Exception e) { + logger.warn("Unexpected exception while sending chunks.", e); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + queue.add(new PendingWrite(msg, promise)); + } + + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + doFlush(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + doFlush(ctx); + ctx.fireChannelInactive(); + } + + @Override + public void channelWritabilityChanged(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isWritable()) { + // channel is writable again try to continue flushing + doFlush(ctx); + } + ctx.fireChannelWritabilityChanged(); + } + + private void discard(Throwable cause) { + for (;;) { + PendingWrite currentWrite = queue.poll(); + + if (currentWrite == null) { + break; + } + Object message = currentWrite.msg; + if (message instanceof ChunkedInput) { + ChunkedInput in = (ChunkedInput) message; + boolean endOfInput; + long inputLength; + try { + endOfInput = in.isEndOfInput(); + inputLength = in.length(); + closeInput(in); + } catch (Exception e) { + closeInput(in); + currentWrite.fail(e); + if (logger.isWarnEnabled()) { + logger.warn(ChunkedInput.class.getSimpleName() + " failed", e); + } + continue; + } + + if (!endOfInput) { + if (cause == null) { + cause = new ClosedChannelException(); + } + currentWrite.fail(cause); + } else { + currentWrite.success(inputLength); + } + } else { + if (cause == null) 
{ + cause = new ClosedChannelException(); + } + currentWrite.fail(cause); + } + } + } + + private void doFlush(final ChannelHandlerContext ctx) { + final Channel channel = ctx.channel(); + if (!channel.isActive()) { + discard(null); + return; + } + + boolean requiresFlush = true; + ByteBufAllocator allocator = ctx.alloc(); + while (channel.isWritable()) { + final PendingWrite currentWrite = queue.peek(); + + if (currentWrite == null) { + break; + } + + if (currentWrite.promise.isDone()) { + // This might happen e.g. in the case when a write operation + // failed, but there're still unconsumed chunks left. + // Most chunked input sources would stop generating chunks + // and report end of input, but this doesn't work with any + // source wrapped in HttpChunkedInput. + // Note, that we're not trying to release the message/chunks + // as this had to be done already by someone who resolved the + // promise (using ChunkedInput.close method). + // See https://github.com/netty/netty/issues/8700. + queue.remove(); + continue; + } + + final Object pendingMessage = currentWrite.msg; + + if (pendingMessage instanceof ChunkedInput) { + final ChunkedInput chunks = (ChunkedInput) pendingMessage; + boolean endOfInput; + boolean suspend; + Object message = null; + try { + message = chunks.readChunk(allocator); + endOfInput = chunks.isEndOfInput(); + + if (message == null) { + // No need to suspend when reached at the end. + suspend = !endOfInput; + } else { + suspend = false; + } + } catch (final Throwable t) { + queue.remove(); + + if (message != null) { + ReferenceCountUtil.release(message); + } + + closeInput(chunks); + currentWrite.fail(t); + break; + } + + if (suspend) { + // ChunkedInput.nextChunk() returned null and it has + // not reached at the end of input. Let's wait until + // more chunks arrive. Nothing to write or notify. + break; + } + + if (message == null) { + // If message is null write an empty ByteBuf. 
+ // See https://github.com/netty/netty/issues/1671 + message = Unpooled.EMPTY_BUFFER; + } + + if (endOfInput) { + // We need to remove the element from the queue before we call writeAndFlush() as this operation + // may cause an action that also touches the queue. + queue.remove(); + } + // Flush each chunk to conserve memory + ChannelFuture f = ctx.writeAndFlush(message); + if (endOfInput) { + if (f.isDone()) { + handleEndOfInputFuture(f, currentWrite); + } else { + // Register a listener which will close the input once the write is complete. + // This is needed because the Chunk may have some resource bound that can not + // be closed before its not written. + // + // See https://github.com/netty/netty/issues/303 + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + handleEndOfInputFuture(future, currentWrite); + } + }); + } + } else { + final boolean resume = !channel.isWritable(); + if (f.isDone()) { + handleFuture(f, currentWrite, resume); + } else { + f.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + handleFuture(future, currentWrite, resume); + } + }); + } + } + requiresFlush = false; + } else { + queue.remove(); + ctx.write(pendingMessage, currentWrite.promise); + requiresFlush = true; + } + + if (!channel.isActive()) { + discard(new ClosedChannelException()); + break; + } + } + + if (requiresFlush) { + ctx.flush(); + } + } + + private static void handleEndOfInputFuture(ChannelFuture future, PendingWrite currentWrite) { + ChunkedInput input = (ChunkedInput) currentWrite.msg; + if (!future.isSuccess()) { + closeInput(input); + currentWrite.fail(future.cause()); + } else { + // read state of the input in local variables before closing it + long inputProgress = input.progress(); + long inputLength = input.length(); + closeInput(input); + currentWrite.progress(inputProgress, inputLength); + currentWrite.success(inputLength); + } + } + + 
private void handleFuture(ChannelFuture future, PendingWrite currentWrite, boolean resume) { + ChunkedInput input = (ChunkedInput) currentWrite.msg; + if (!future.isSuccess()) { + closeInput(input); + currentWrite.fail(future.cause()); + } else { + currentWrite.progress(input.progress(), input.length()); + if (resume && future.channel().isWritable()) { + resumeTransfer(); + } + } + } + + private static void closeInput(ChunkedInput chunks) { + try { + chunks.close(); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to close a chunked input.", t); + } + } + } + + private static final class PendingWrite { + final Object msg; + final ChannelPromise promise; + + PendingWrite(Object msg, ChannelPromise promise) { + this.msg = msg; + this.promise = promise; + } + + void fail(Throwable cause) { + ReferenceCountUtil.release(msg); + promise.tryFailure(cause); + } + + void success(long total) { + if (promise.isDone()) { + // No need to notify the progress or fulfill the promise because it's done already. + return; + } + progress(total, total); + promise.trySuccess(); + } + + void progress(long progress, long total) { + if (promise instanceof ChannelProgressivePromise) { + ((ChannelProgressivePromise) promise).tryProgress(progress, total); + } + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/stream/package-info.java b/netty-handler/src/main/java/io/netty/handler/stream/package-info.java new file mode 100644 index 0000000..fb4e05c --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/stream/package-info.java @@ -0,0 +1,22 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Writes very large data stream asynchronously neither spending a lot of + * memory nor getting {@link java.lang.OutOfMemoryError}. For a detailed + * example, please refer to {@code io.netty.example.http.file}. + */ +package io.netty.handler.stream; diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/IdleState.java b/netty-handler/src/main/java/io/netty/handler/timeout/IdleState.java new file mode 100644 index 0000000..7f185b6 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/IdleState.java @@ -0,0 +1,37 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + +import io.netty.channel.Channel; + + +/** + * An {@link Enum} that represents the idle state of a {@link Channel}. + */ +public enum IdleState { + /** + * No data was received for a while. + */ + READER_IDLE, + /** + * No data was sent for a while. 
+ */ + WRITER_IDLE, + /** + * No data was either received or sent for a while. + */ + ALL_IDLE +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/IdleStateEvent.java b/netty-handler/src/main/java/io/netty/handler/timeout/IdleStateEvent.java new file mode 100644 index 0000000..29db5f2 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/IdleStateEvent.java @@ -0,0 +1,85 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + +import io.netty.channel.Channel; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; + +/** + * A user event triggered by {@link IdleStateHandler} when a {@link Channel} is idle. 
+ */ +public class IdleStateEvent { + public static final IdleStateEvent FIRST_READER_IDLE_STATE_EVENT = + new DefaultIdleStateEvent(IdleState.READER_IDLE, true); + public static final IdleStateEvent READER_IDLE_STATE_EVENT = + new DefaultIdleStateEvent(IdleState.READER_IDLE, false); + public static final IdleStateEvent FIRST_WRITER_IDLE_STATE_EVENT = + new DefaultIdleStateEvent(IdleState.WRITER_IDLE, true); + public static final IdleStateEvent WRITER_IDLE_STATE_EVENT = + new DefaultIdleStateEvent(IdleState.WRITER_IDLE, false); + public static final IdleStateEvent FIRST_ALL_IDLE_STATE_EVENT = + new DefaultIdleStateEvent(IdleState.ALL_IDLE, true); + public static final IdleStateEvent ALL_IDLE_STATE_EVENT = + new DefaultIdleStateEvent(IdleState.ALL_IDLE, false); + + private final IdleState state; + private final boolean first; + + /** + * Constructor for sub-classes. + * + * @param state the {@link IdleStateEvent} which triggered the event. + * @param first {@code true} if its the first idle event for the {@link IdleStateEvent}. + */ + protected IdleStateEvent(IdleState state, boolean first) { + this.state = ObjectUtil.checkNotNull(state, "state"); + this.first = first; + } + + /** + * Returns the idle state. + */ + public IdleState state() { + return state; + } + + /** + * Returns {@code true} if this was the first event for the {@link IdleState} + */ + public boolean isFirst() { + return first; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + '(' + state + (first ? ", first" : "") + ')'; + } + + private static final class DefaultIdleStateEvent extends IdleStateEvent { + private final String representation; + + DefaultIdleStateEvent(IdleState state, boolean first) { + super(state, first); + this.representation = "IdleStateEvent(" + state + (first ? 
", first" : "") + ')'; + } + + @Override + public String toString() { + return representation; + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/IdleStateHandler.java b/netty-handler/src/main/java/io/netty/handler/timeout/IdleStateHandler.java new file mode 100644 index 0000000..33f69ce --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/IdleStateHandler.java @@ -0,0 +1,587 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.Channel.Unsafe; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.internal.ObjectUtil; + +import java.util.concurrent.TimeUnit; + +/** + * Triggers an {@link IdleStateEvent} when a {@link Channel} has not performed + * read, write, or both operation for a while. + * + *

Supported idle states

+ * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + * + *
PropertyMeaning
{@code readerIdleTime}an {@link IdleStateEvent} whose state is {@link IdleState#READER_IDLE} + * will be triggered when no read was performed for the specified period of + * time. Specify {@code 0} to disable.
{@code writerIdleTime}an {@link IdleStateEvent} whose state is {@link IdleState#WRITER_IDLE} + * will be triggered when no write was performed for the specified period of + * time. Specify {@code 0} to disable.
{@code allIdleTime}an {@link IdleStateEvent} whose state is {@link IdleState#ALL_IDLE} + * will be triggered when neither read nor write was performed for the + * specified period of time. Specify {@code 0} to disable.
+ * + *
+ * // An example that sends a ping message when there is no outbound traffic
+ * // for 30 seconds.  The connection is closed when there is no inbound traffic
+ * // for 60 seconds.
+ *
+ * public class MyChannelInitializer extends {@link ChannelInitializer}<{@link Channel}> {
+ *     {@code @Override}
+ *     public void initChannel({@link Channel} channel) {
+ *         channel.pipeline().addLast("idleStateHandler", new {@link IdleStateHandler}(60, 30, 0));
+ *         channel.pipeline().addLast("myHandler", new MyHandler());
+ *     }
+ * }
+ *
+ * // Handler should handle the {@link IdleStateEvent} triggered by {@link IdleStateHandler}.
+ * public class MyHandler extends {@link ChannelDuplexHandler} {
+ *     {@code @Override}
+ *     public void userEventTriggered({@link ChannelHandlerContext} ctx, {@link Object} evt) throws {@link Exception} {
+ *         if (evt instanceof {@link IdleStateEvent}) {
+ *             {@link IdleStateEvent} e = ({@link IdleStateEvent}) evt;
+ *             if (e.state() == {@link IdleState}.READER_IDLE) {
+ *                 ctx.close();
+ *             } else if (e.state() == {@link IdleState}.WRITER_IDLE) {
+ *                 ctx.writeAndFlush(new PingMessage());
+ *             }
+ *         }
+ *     }
+ * }
+ *
+ * {@link ServerBootstrap} bootstrap = ...;
+ * ...
+ * bootstrap.childHandler(new MyChannelInitializer());
+ * ...
+ * 
+ * + * @see ReadTimeoutHandler + * @see WriteTimeoutHandler + */ +public class IdleStateHandler extends ChannelDuplexHandler { + private static final long MIN_TIMEOUT_NANOS = TimeUnit.MILLISECONDS.toNanos(1); + + // Not create a new ChannelFutureListener per write operation to reduce GC pressure. + private final ChannelFutureListener writeListener = new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) throws Exception { + lastWriteTime = ticksInNanos(); + firstWriterIdleEvent = firstAllIdleEvent = true; + } + }; + + private final boolean observeOutput; + private final long readerIdleTimeNanos; + private final long writerIdleTimeNanos; + private final long allIdleTimeNanos; + + private Future readerIdleTimeout; + private long lastReadTime; + private boolean firstReaderIdleEvent = true; + + private Future writerIdleTimeout; + private long lastWriteTime; + private boolean firstWriterIdleEvent = true; + + private Future allIdleTimeout; + private boolean firstAllIdleEvent = true; + + private byte state; // 0 - none, 1 - initialized, 2 - destroyed + private boolean reading; + + private long lastChangeCheckTimeStamp; + private int lastMessageHashCode; + private long lastPendingWriteBytes; + private long lastFlushProgress; + + /** + * Creates a new instance firing {@link IdleStateEvent}s. + * + * @param readerIdleTimeSeconds + * an {@link IdleStateEvent} whose state is {@link IdleState#READER_IDLE} + * will be triggered when no read was performed for the specified + * period of time. Specify {@code 0} to disable. + * @param writerIdleTimeSeconds + * an {@link IdleStateEvent} whose state is {@link IdleState#WRITER_IDLE} + * will be triggered when no write was performed for the specified + * period of time. Specify {@code 0} to disable. 
+ * @param allIdleTimeSeconds + * an {@link IdleStateEvent} whose state is {@link IdleState#ALL_IDLE} + * will be triggered when neither read nor write was performed for + * the specified period of time. Specify {@code 0} to disable. + */ + public IdleStateHandler( + int readerIdleTimeSeconds, + int writerIdleTimeSeconds, + int allIdleTimeSeconds) { + + this(readerIdleTimeSeconds, writerIdleTimeSeconds, allIdleTimeSeconds, + TimeUnit.SECONDS); + } + + /** + * @see #IdleStateHandler(boolean, long, long, long, TimeUnit) + */ + public IdleStateHandler( + long readerIdleTime, long writerIdleTime, long allIdleTime, + TimeUnit unit) { + this(false, readerIdleTime, writerIdleTime, allIdleTime, unit); + } + + /** + * Creates a new instance firing {@link IdleStateEvent}s. + * + * @param observeOutput + * whether or not the consumption of {@code bytes} should be taken into + * consideration when assessing write idleness. The default is {@code false}. + * @param readerIdleTime + * an {@link IdleStateEvent} whose state is {@link IdleState#READER_IDLE} + * will be triggered when no read was performed for the specified + * period of time. Specify {@code 0} to disable. + * @param writerIdleTime + * an {@link IdleStateEvent} whose state is {@link IdleState#WRITER_IDLE} + * will be triggered when no write was performed for the specified + * period of time. Specify {@code 0} to disable. + * @param allIdleTime + * an {@link IdleStateEvent} whose state is {@link IdleState#ALL_IDLE} + * will be triggered when neither read nor write was performed for + * the specified period of time. Specify {@code 0} to disable. 
+ * @param unit + * the {@link TimeUnit} of {@code readerIdleTime}, + * {@code writeIdleTime}, and {@code allIdleTime} + */ + public IdleStateHandler(boolean observeOutput, + long readerIdleTime, long writerIdleTime, long allIdleTime, + TimeUnit unit) { + ObjectUtil.checkNotNull(unit, "unit"); + + this.observeOutput = observeOutput; + + if (readerIdleTime <= 0) { + readerIdleTimeNanos = 0; + } else { + readerIdleTimeNanos = Math.max(unit.toNanos(readerIdleTime), MIN_TIMEOUT_NANOS); + } + if (writerIdleTime <= 0) { + writerIdleTimeNanos = 0; + } else { + writerIdleTimeNanos = Math.max(unit.toNanos(writerIdleTime), MIN_TIMEOUT_NANOS); + } + if (allIdleTime <= 0) { + allIdleTimeNanos = 0; + } else { + allIdleTimeNanos = Math.max(unit.toNanos(allIdleTime), MIN_TIMEOUT_NANOS); + } + } + + /** + * Return the readerIdleTime that was given when instance this class in milliseconds. + * + */ + public long getReaderIdleTimeInMillis() { + return TimeUnit.NANOSECONDS.toMillis(readerIdleTimeNanos); + } + + /** + * Return the writerIdleTime that was given when instance this class in milliseconds. + * + */ + public long getWriterIdleTimeInMillis() { + return TimeUnit.NANOSECONDS.toMillis(writerIdleTimeNanos); + } + + /** + * Return the allIdleTime that was given when instance this class in milliseconds. + * + */ + public long getAllIdleTimeInMillis() { + return TimeUnit.NANOSECONDS.toMillis(allIdleTimeNanos); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + if (ctx.channel().isActive() && ctx.channel().isRegistered()) { + // channelActive() event has been fired already, which means this.channelActive() will + // not be invoked. We have to initialize here instead. + initialize(ctx); + } else { + // channelActive() event has not been fired yet. this.channelActive() will be invoked + // and initialization will occur there. 
+ } + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + destroy(); + } + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + // Initialize early if channel is active already. + if (ctx.channel().isActive()) { + initialize(ctx); + } + super.channelRegistered(ctx); + } + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + // This method will be invoked only if this handler was added + // before channelActive() event is fired. If a user adds this handler + // after the channelActive() event, initialize() will be called by beforeAdd(). + initialize(ctx); + super.channelActive(ctx); + } + + @Override + public void channelInactive(ChannelHandlerContext ctx) throws Exception { + destroy(); + super.channelInactive(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + if (readerIdleTimeNanos > 0 || allIdleTimeNanos > 0) { + reading = true; + firstReaderIdleEvent = firstAllIdleEvent = true; + } + ctx.fireChannelRead(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) throws Exception { + if ((readerIdleTimeNanos > 0 || allIdleTimeNanos > 0) && reading) { + lastReadTime = ticksInNanos(); + reading = false; + } + ctx.fireChannelReadComplete(); + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + // Allow writing with void promise if handler is only configured for read timeout events. + if (writerIdleTimeNanos > 0 || allIdleTimeNanos > 0) { + ctx.write(msg, promise.unvoid()).addListener(writeListener); + } else { + ctx.write(msg, promise); + } + } + + private void initialize(ChannelHandlerContext ctx) { + // Avoid the case where destroy() is called before scheduling timeouts. 
+ // See: https://github.com/netty/netty/issues/143 + switch (state) { + case 1: + case 2: + return; + default: + break; + } + + state = 1; + initOutputChanged(ctx); + + lastReadTime = lastWriteTime = ticksInNanos(); + if (readerIdleTimeNanos > 0) { + readerIdleTimeout = schedule(ctx, new ReaderIdleTimeoutTask(ctx), + readerIdleTimeNanos, TimeUnit.NANOSECONDS); + } + if (writerIdleTimeNanos > 0) { + writerIdleTimeout = schedule(ctx, new WriterIdleTimeoutTask(ctx), + writerIdleTimeNanos, TimeUnit.NANOSECONDS); + } + if (allIdleTimeNanos > 0) { + allIdleTimeout = schedule(ctx, new AllIdleTimeoutTask(ctx), + allIdleTimeNanos, TimeUnit.NANOSECONDS); + } + } + + /** + * This method is visible for testing! + */ + long ticksInNanos() { + return System.nanoTime(); + } + + /** + * This method is visible for testing! + */ + Future schedule(ChannelHandlerContext ctx, Runnable task, long delay, TimeUnit unit) { + return ctx.executor().schedule(task, delay, unit); + } + + private void destroy() { + state = 2; + + if (readerIdleTimeout != null) { + readerIdleTimeout.cancel(false); + readerIdleTimeout = null; + } + if (writerIdleTimeout != null) { + writerIdleTimeout.cancel(false); + writerIdleTimeout = null; + } + if (allIdleTimeout != null) { + allIdleTimeout.cancel(false); + allIdleTimeout = null; + } + } + + /** + * Is called when an {@link IdleStateEvent} should be fired. This implementation calls + * {@link ChannelHandlerContext#fireUserEventTriggered(Object)}. + */ + protected void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) throws Exception { + ctx.fireUserEventTriggered(evt); + } + + /** + * Returns a {@link IdleStateEvent}. + */ + protected IdleStateEvent newIdleStateEvent(IdleState state, boolean first) { + switch (state) { + case ALL_IDLE: + return first ? IdleStateEvent.FIRST_ALL_IDLE_STATE_EVENT : IdleStateEvent.ALL_IDLE_STATE_EVENT; + case READER_IDLE: + return first ? 
IdleStateEvent.FIRST_READER_IDLE_STATE_EVENT : IdleStateEvent.READER_IDLE_STATE_EVENT; + case WRITER_IDLE: + return first ? IdleStateEvent.FIRST_WRITER_IDLE_STATE_EVENT : IdleStateEvent.WRITER_IDLE_STATE_EVENT; + default: + throw new IllegalArgumentException("Unhandled: state=" + state + ", first=" + first); + } + } + + /** + * @see #hasOutputChanged(ChannelHandlerContext, boolean) + */ + private void initOutputChanged(ChannelHandlerContext ctx) { + if (observeOutput) { + Channel channel = ctx.channel(); + Unsafe unsafe = channel.unsafe(); + ChannelOutboundBuffer buf = unsafe.outboundBuffer(); + + if (buf != null) { + lastMessageHashCode = System.identityHashCode(buf.current()); + lastPendingWriteBytes = buf.totalPendingWriteBytes(); + lastFlushProgress = buf.currentProgress(); + } + } + } + + /** + * Returns {@code true} if and only if the {@link IdleStateHandler} was constructed + * with {@link #observeOutput} enabled and there has been an observed change in the + * {@link ChannelOutboundBuffer} between two consecutive calls of this method. + * + * https://github.com/netty/netty/issues/6150 + */ + private boolean hasOutputChanged(ChannelHandlerContext ctx, boolean first) { + if (observeOutput) { + + // We can take this shortcut if the ChannelPromises that got passed into write() + // appear to complete. It indicates "change" on message level and we simply assume + // that there's change happening on byte level. If the user doesn't observe channel + // writability events then they'll eventually OOME and there's clearly a different + // problem and idleness is least of their concerns. + if (lastChangeCheckTimeStamp != lastWriteTime) { + lastChangeCheckTimeStamp = lastWriteTime; + + // But this applies only if it's the non-first call. 
+ if (!first) { + return true; + } + } + + Channel channel = ctx.channel(); + Unsafe unsafe = channel.unsafe(); + ChannelOutboundBuffer buf = unsafe.outboundBuffer(); + + if (buf != null) { + int messageHashCode = System.identityHashCode(buf.current()); + long pendingWriteBytes = buf.totalPendingWriteBytes(); + + if (messageHashCode != lastMessageHashCode || pendingWriteBytes != lastPendingWriteBytes) { + lastMessageHashCode = messageHashCode; + lastPendingWriteBytes = pendingWriteBytes; + + if (!first) { + return true; + } + } + + long flushProgress = buf.currentProgress(); + if (flushProgress != lastFlushProgress) { + lastFlushProgress = flushProgress; + return !first; + } + } + } + + return false; + } + + private abstract static class AbstractIdleTask implements Runnable { + + private final ChannelHandlerContext ctx; + + AbstractIdleTask(ChannelHandlerContext ctx) { + this.ctx = ctx; + } + + @Override + public void run() { + if (!ctx.channel().isOpen()) { + return; + } + + run(ctx); + } + + protected abstract void run(ChannelHandlerContext ctx); + } + + private final class ReaderIdleTimeoutTask extends AbstractIdleTask { + + ReaderIdleTimeoutTask(ChannelHandlerContext ctx) { + super(ctx); + } + + @Override + protected void run(ChannelHandlerContext ctx) { + long nextDelay = readerIdleTimeNanos; + if (!reading) { + nextDelay -= ticksInNanos() - lastReadTime; + } + + if (nextDelay <= 0) { + // Reader is idle - set a new timeout and notify the callback. + readerIdleTimeout = schedule(ctx, this, readerIdleTimeNanos, TimeUnit.NANOSECONDS); + + boolean first = firstReaderIdleEvent; + firstReaderIdleEvent = false; + + try { + IdleStateEvent event = newIdleStateEvent(IdleState.READER_IDLE, first); + channelIdle(ctx, event); + } catch (Throwable t) { + ctx.fireExceptionCaught(t); + } + } else { + // Read occurred before the timeout - set a new timeout with shorter delay. 
+ readerIdleTimeout = schedule(ctx, this, nextDelay, TimeUnit.NANOSECONDS); + } + } + } + + private final class WriterIdleTimeoutTask extends AbstractIdleTask { + + WriterIdleTimeoutTask(ChannelHandlerContext ctx) { + super(ctx); + } + + @Override + protected void run(ChannelHandlerContext ctx) { + + long lastWriteTime = IdleStateHandler.this.lastWriteTime; + long nextDelay = writerIdleTimeNanos - (ticksInNanos() - lastWriteTime); + if (nextDelay <= 0) { + // Writer is idle - set a new timeout and notify the callback. + writerIdleTimeout = schedule(ctx, this, writerIdleTimeNanos, TimeUnit.NANOSECONDS); + + boolean first = firstWriterIdleEvent; + firstWriterIdleEvent = false; + + try { + if (hasOutputChanged(ctx, first)) { + return; + } + + IdleStateEvent event = newIdleStateEvent(IdleState.WRITER_IDLE, first); + channelIdle(ctx, event); + } catch (Throwable t) { + ctx.fireExceptionCaught(t); + } + } else { + // Write occurred before the timeout - set a new timeout with shorter delay. + writerIdleTimeout = schedule(ctx, this, nextDelay, TimeUnit.NANOSECONDS); + } + } + } + + private final class AllIdleTimeoutTask extends AbstractIdleTask { + + AllIdleTimeoutTask(ChannelHandlerContext ctx) { + super(ctx); + } + + @Override + protected void run(ChannelHandlerContext ctx) { + + long nextDelay = allIdleTimeNanos; + if (!reading) { + nextDelay -= ticksInNanos() - Math.max(lastReadTime, lastWriteTime); + } + if (nextDelay <= 0) { + // Both reader and writer are idle - set a new timeout and + // notify the callback. 
+ allIdleTimeout = schedule(ctx, this, allIdleTimeNanos, TimeUnit.NANOSECONDS); + + boolean first = firstAllIdleEvent; + firstAllIdleEvent = false; + + try { + if (hasOutputChanged(ctx, first)) { + return; + } + + IdleStateEvent event = newIdleStateEvent(IdleState.ALL_IDLE, first); + channelIdle(ctx, event); + } catch (Throwable t) { + ctx.fireExceptionCaught(t); + } + } else { + // Either read or write occurred before the timeout - set a new + // timeout with shorter delay. + allIdleTimeout = schedule(ctx, this, nextDelay, TimeUnit.NANOSECONDS); + } + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutException.java b/netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutException.java new file mode 100644 index 0000000..d320adf --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + +import io.netty.util.internal.PlatformDependent; + +/** + * A {@link TimeoutException} raised by {@link ReadTimeoutHandler} when no data + * was read within a certain period of time. 
+ */ +public final class ReadTimeoutException extends TimeoutException { + + private static final long serialVersionUID = 169287984113283421L; + + public static final ReadTimeoutException INSTANCE = PlatformDependent.javaVersion() >= 7 ? + new ReadTimeoutException(true) : new ReadTimeoutException(); + + public ReadTimeoutException() { } + + public ReadTimeoutException(String message) { + super(message, false); + } + + private ReadTimeoutException(boolean shared) { + super(null, shared); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java b/netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java new file mode 100644 index 0000000..727e972 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/ReadTimeoutHandler.java @@ -0,0 +1,103 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; + +import java.util.concurrent.TimeUnit; + +/** + * Raises a {@link ReadTimeoutException} when no data was read within a certain + * period of time. + * + *
+ * // The connection is closed when there is no inbound traffic
+ * // for 30 seconds.
+ *
+ * public class MyChannelInitializer extends {@link ChannelInitializer}<{@link Channel}> {
+ *     public void initChannel({@link Channel} channel) {
+ *         channel.pipeline().addLast("readTimeoutHandler", new {@link ReadTimeoutHandler}(30));
+ *         channel.pipeline().addLast("myHandler", new MyHandler());
+ *     }
+ * }
+ *
+ * // Handler should handle the {@link ReadTimeoutException}.
+ * public class MyHandler extends {@link ChannelDuplexHandler} {
+ *     {@code @Override}
+ *     public void exceptionCaught({@link ChannelHandlerContext} ctx, {@link Throwable} cause)
+ *             throws {@link Exception} {
+ *         if (cause instanceof {@link ReadTimeoutException}) {
+ *             // do something
+ *         } else {
+ *             super.exceptionCaught(ctx, cause);
+ *         }
+ *     }
+ * }
+ *
+ * {@link ServerBootstrap} bootstrap = ...;
+ * ...
+ * bootstrap.childHandler(new MyChannelInitializer());
+ * ...
+ * 
+ * @see WriteTimeoutHandler + * @see IdleStateHandler + */ +public class ReadTimeoutHandler extends IdleStateHandler { + private boolean closed; + + /** + * Creates a new instance. + * + * @param timeoutSeconds + * read timeout in seconds + */ + public ReadTimeoutHandler(int timeoutSeconds) { + this(timeoutSeconds, TimeUnit.SECONDS); + } + + /** + * Creates a new instance. + * + * @param timeout + * read timeout + * @param unit + * the {@link TimeUnit} of {@code timeout} + */ + public ReadTimeoutHandler(long timeout, TimeUnit unit) { + super(timeout, 0, 0, unit); + } + + @Override + protected final void channelIdle(ChannelHandlerContext ctx, IdleStateEvent evt) throws Exception { + assert evt.state() == IdleState.READER_IDLE; + readTimedOut(ctx); + } + + /** + * Is called when a read timeout was detected. + */ + protected void readTimedOut(ChannelHandlerContext ctx) throws Exception { + if (!closed) { + ctx.fireExceptionCaught(ReadTimeoutException.INSTANCE); + ctx.close(); + closed = true; + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/TimeoutException.java b/netty-handler/src/main/java/io/netty/handler/timeout/TimeoutException.java new file mode 100644 index 0000000..6c2b299 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/TimeoutException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.timeout; + +import io.netty.channel.ChannelException; + +/** + * A {@link TimeoutException} when no data was either read or written within a + * certain period of time. + */ +public class TimeoutException extends ChannelException { + + private static final long serialVersionUID = 4673641882869672533L; + + TimeoutException() { + } + + TimeoutException(String message, boolean shared) { + super(message, null, shared); + } + + // Suppress a warning since the method doesn't need synchronization + @Override + public Throwable fillInStackTrace() { + return this; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutException.java b/netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutException.java new file mode 100644 index 0000000..f0ac79e --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutException.java @@ -0,0 +1,40 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + +import io.netty.util.internal.PlatformDependent; + +/** + * A {@link TimeoutException} raised by {@link WriteTimeoutHandler} when a write operation + * cannot finish in a certain period of time. 
+ */ +public final class WriteTimeoutException extends TimeoutException { + + private static final long serialVersionUID = -144786655770296065L; + + public static final WriteTimeoutException INSTANCE = PlatformDependent.javaVersion() >= 7 ? + new WriteTimeoutException(true) : new WriteTimeoutException(); + + public WriteTimeoutException() { } + + public WriteTimeoutException(String message) { + super(message, false); + } + + private WriteTimeoutException(boolean shared) { + super(null, shared); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutHandler.java b/netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutHandler.java new file mode 100644 index 0000000..be2c0c2 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/WriteTimeoutHandler.java @@ -0,0 +1,236 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.timeout; + +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.Future; +import io.netty.util.internal.ObjectUtil; + +import java.util.concurrent.TimeUnit; + +/** + * Raises a {@link WriteTimeoutException} when a write operation cannot finish in a certain period of time. + * + *
+ * // The connection is closed when a write operation cannot finish in 30 seconds.
+ *
+ * public class MyChannelInitializer extends {@link ChannelInitializer}<{@link Channel}> {
+ *     public void initChannel({@link Channel} channel) {
+ *         channel.pipeline().addLast("writeTimeoutHandler", new {@link WriteTimeoutHandler}(30));
+ *         channel.pipeline().addLast("myHandler", new MyHandler());
+ *     }
+ * }
+ *
+ * // Handler should handle the {@link WriteTimeoutException}.
+ * public class MyHandler extends {@link ChannelDuplexHandler} {
+ *     {@code @Override}
+ *     public void exceptionCaught({@link ChannelHandlerContext} ctx, {@link Throwable} cause)
+ *             throws {@link Exception} {
+ *         if (cause instanceof {@link WriteTimeoutException}) {
+ *             // do something
+ *         } else {
+ *             super.exceptionCaught(ctx, cause);
+ *         }
+ *     }
+ * }
+ *
+ * {@link ServerBootstrap} bootstrap = ...;
+ * ...
+ * bootstrap.childHandler(new MyChannelInitializer());
+ * ...
+ * 
+ * @see ReadTimeoutHandler + * @see IdleStateHandler + */ +public class WriteTimeoutHandler extends ChannelOutboundHandlerAdapter { + private static final long MIN_TIMEOUT_NANOS = TimeUnit.MILLISECONDS.toNanos(1); + + private final long timeoutNanos; + + /** + * A doubly-linked list to track all WriteTimeoutTasks + */ + private WriteTimeoutTask lastTask; + + private boolean closed; + + /** + * Creates a new instance. + * + * @param timeoutSeconds + * write timeout in seconds + */ + public WriteTimeoutHandler(int timeoutSeconds) { + this(timeoutSeconds, TimeUnit.SECONDS); + } + + /** + * Creates a new instance. + * + * @param timeout + * write timeout + * @param unit + * the {@link TimeUnit} of {@code timeout} + */ + public WriteTimeoutHandler(long timeout, TimeUnit unit) { + ObjectUtil.checkNotNull(unit, "unit"); + + if (timeout <= 0) { + timeoutNanos = 0; + } else { + timeoutNanos = Math.max(unit.toNanos(timeout), MIN_TIMEOUT_NANOS); + } + } + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + if (timeoutNanos > 0) { + promise = promise.unvoid(); + scheduleTimeout(ctx, promise); + } + ctx.write(msg, promise); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + assert ctx.executor().inEventLoop(); + WriteTimeoutTask task = lastTask; + lastTask = null; + while (task != null) { + assert task.ctx.executor().inEventLoop(); + task.scheduledFuture.cancel(false); + WriteTimeoutTask prev = task.prev; + task.prev = null; + task.next = null; + task = prev; + } + } + + private void scheduleTimeout(final ChannelHandlerContext ctx, final ChannelPromise promise) { + // Schedule a timeout. 
+ final WriteTimeoutTask task = new WriteTimeoutTask(ctx, promise); + task.scheduledFuture = ctx.executor().schedule(task, timeoutNanos, TimeUnit.NANOSECONDS); + + if (!task.scheduledFuture.isDone()) { + addWriteTimeoutTask(task); + + // Cancel the scheduled timeout if the flush promise is complete. + promise.addListener(task); + } + } + + private void addWriteTimeoutTask(WriteTimeoutTask task) { + assert task.ctx.executor().inEventLoop(); + if (lastTask != null) { + lastTask.next = task; + task.prev = lastTask; + } + lastTask = task; + } + + private void removeWriteTimeoutTask(WriteTimeoutTask task) { + assert task.ctx.executor().inEventLoop(); + if (task == lastTask) { + // task is the tail of list + assert task.next == null; + lastTask = lastTask.prev; + if (lastTask != null) { + lastTask.next = null; + } + } else if (task.prev == null && task.next == null) { + // Since task is not lastTask, then it has been removed or not been added. + return; + } else if (task.prev == null) { + // task is the head of list and the list has at least 2 nodes + task.next.prev = null; + } else { + task.prev.next = task.next; + task.next.prev = task.prev; + } + task.prev = null; + task.next = null; + } + + /** + * Is called when a write timeout was detected + */ + protected void writeTimedOut(ChannelHandlerContext ctx) throws Exception { + if (!closed) { + ctx.fireExceptionCaught(WriteTimeoutException.INSTANCE); + ctx.close(); + closed = true; + } + } + + private final class WriteTimeoutTask implements Runnable, ChannelFutureListener { + + private final ChannelHandlerContext ctx; + private final ChannelPromise promise; + + // WriteTimeoutTask is also a node of a doubly-linked list + WriteTimeoutTask prev; + WriteTimeoutTask next; + + Future scheduledFuture; + + WriteTimeoutTask(ChannelHandlerContext ctx, ChannelPromise promise) { + this.ctx = ctx; + this.promise = promise; + } + + @Override + public void run() { + // Was not written yet so issue a write timeout + // The promise 
itself will be failed with a ClosedChannelException once the close() was issued + // See https://github.com/netty/netty/issues/2159 + if (!promise.isDone()) { + try { + writeTimedOut(ctx); + } catch (Throwable t) { + ctx.fireExceptionCaught(t); + } + } + removeWriteTimeoutTask(this); + } + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + // scheduledFuture has already be set when reaching here + scheduledFuture.cancel(false); + + // Check if its safe to modify the "doubly-linked-list" that we maintain. If its not we will schedule the + // modification so its picked up by the executor.. + if (ctx.executor().inEventLoop()) { + removeWriteTimeoutTask(this); + } else { + // So let's just pass outself to the executor which will then take care of remove this task + // from the doubly-linked list. Schedule ourself is fine as the promise itself is done. + // + // This fixes https://github.com/netty/netty/issues/11053 + assert promise.isDone(); + ctx.executor().execute(this); + } + } + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/timeout/package-info.java b/netty-handler/src/main/java/io/netty/handler/timeout/package-info.java new file mode 100644 index 0000000..c403681 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/timeout/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +/** + * Adds support for read and write timeout and idle connection notification + * using a {@link io.netty.util.Timer}. + */ +package io.netty.handler.timeout; diff --git a/netty-handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java b/netty-handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java new file mode 100644 index 0000000..0fecbf8 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/traffic/AbstractTrafficShapingHandler.java @@ -0,0 +1,658 @@ +/* + * Copyright 2011 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.traffic; + +import static io.netty.util.internal.ObjectUtil.checkPositive; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufHolder; +import io.netty.channel.Channel; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundBuffer; +import io.netty.channel.ChannelPromise; +import io.netty.channel.FileRegion; +import io.netty.util.Attribute; +import io.netty.util.AttributeKey; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.concurrent.TimeUnit; + +/** + *

AbstractTrafficShapingHandler allows to limit the global bandwidth + * (see {@link GlobalTrafficShapingHandler}) or per session + * bandwidth (see {@link ChannelTrafficShapingHandler}), as traffic shaping. + * It allows you to implement an almost real time monitoring of the bandwidth using + * the monitors from {@link TrafficCounter} that will call back every checkInterval + * the method doAccounting of this handler.

+ * + *

If you want for any particular reasons to stop the monitoring (accounting) or to change + * the read/write limit or the check interval, several methods allow that for you:

+ *
    + *
  • configure allows you to change read or write limits, or the checkInterval
  • + *
  • getTrafficCounter allows you to have access to the TrafficCounter and so to stop + * or start the monitoring, to change the checkInterval directly, or to have access to its values.
  • + *
+ */ +public abstract class AbstractTrafficShapingHandler extends ChannelDuplexHandler { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(AbstractTrafficShapingHandler.class); + /** + * Default delay between two checks: 1s + */ + public static final long DEFAULT_CHECK_INTERVAL = 1000; + + /** + * Default max delay in case of traffic shaping + * (during which no communication will occur). + * Shall be less than TIMEOUT. Here half of "standard" 30s + */ + public static final long DEFAULT_MAX_TIME = 15000; + + /** + * Default max size to not exceed in buffer (write only). + */ + static final long DEFAULT_MAX_SIZE = 4 * 1024 * 1024L; + + /** + * Default minimal time to wait: 10ms + */ + static final long MINIMAL_WAIT = 10; + + /** + * Traffic Counter + */ + protected TrafficCounter trafficCounter; + + /** + * Limit in B/s to apply to write + */ + private volatile long writeLimit; + + /** + * Limit in B/s to apply to read + */ + private volatile long readLimit; + + /** + * Max delay in wait + */ + protected volatile long maxTime = DEFAULT_MAX_TIME; // default 15 s + + /** + * Delay between two performance snapshots + */ + protected volatile long checkInterval = DEFAULT_CHECK_INTERVAL; // default 1 s + + static final AttributeKey READ_SUSPENDED = AttributeKey + .valueOf(AbstractTrafficShapingHandler.class.getName() + ".READ_SUSPENDED"); + static final AttributeKey REOPEN_TASK = AttributeKey.valueOf(AbstractTrafficShapingHandler.class + .getName() + ".REOPEN_TASK"); + + /** + * Max time to delay before proposing to stop writing new objects from next handlers + */ + volatile long maxWriteDelay = 4 * DEFAULT_CHECK_INTERVAL; // default 4 s + /** + * Max size in the list before proposing to stop writing new objects from next handlers + */ + volatile long maxWriteSize = DEFAULT_MAX_SIZE; // default 4MB + + /** + * Rank in UserDefinedWritability (1 for Channel, 2 for Global TrafficShapingHandler). + * Set in final constructor. 
Must be between 1 and 31 + */ + final int userDefinedWritabilityIndex; + + /** + * Default value for Channel UserDefinedWritability index + */ + static final int CHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX = 1; + + /** + * Default value for Global UserDefinedWritability index + */ + static final int GLOBAL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX = 2; + + /** + * Default value for GlobalChannel UserDefinedWritability index + */ + static final int GLOBALCHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX = 3; + + /** + * @param newTrafficCounter + * the TrafficCounter to set + */ + void setTrafficCounter(TrafficCounter newTrafficCounter) { + trafficCounter = newTrafficCounter; + } + + /** + * @return the index to be used by the TrafficShapingHandler to manage the user defined writability. + * For Channel TSH it is defined as {@value #CHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX}, + * for Global TSH it is defined as {@value #GLOBAL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX}, + * for GlobalChannel TSH it is defined as + * {@value #GLOBALCHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX}. + */ + protected int userDefinedWritabilityIndex() { + return CHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX; + } + + /** + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + * @param maxTime + * The maximum delay to wait in case of traffic excess. + * Must be positive. + */ + protected AbstractTrafficShapingHandler(long writeLimit, long readLimit, long checkInterval, long maxTime) { + this.maxTime = checkPositive(maxTime, "maxTime"); + + userDefinedWritabilityIndex = userDefinedWritabilityIndex(); + this.writeLimit = writeLimit; + this.readLimit = readLimit; + this.checkInterval = checkInterval; + } + + /** + * Constructor using default max time as delay allowed value of {@value #DEFAULT_MAX_TIME} ms. 
+ * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + */ + protected AbstractTrafficShapingHandler(long writeLimit, long readLimit, long checkInterval) { + this(writeLimit, readLimit, checkInterval, DEFAULT_MAX_TIME); + } + + /** + * Constructor using default Check Interval value of {@value #DEFAULT_CHECK_INTERVAL} ms and + * default max time as delay allowed value of {@value #DEFAULT_MAX_TIME} ms. + * + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + */ + protected AbstractTrafficShapingHandler(long writeLimit, long readLimit) { + this(writeLimit, readLimit, DEFAULT_CHECK_INTERVAL, DEFAULT_MAX_TIME); + } + + /** + * Constructor using NO LIMIT, default Check Interval value of {@value #DEFAULT_CHECK_INTERVAL} ms and + * default max time as delay allowed value of {@value #DEFAULT_MAX_TIME} ms. + */ + protected AbstractTrafficShapingHandler() { + this(0, 0, DEFAULT_CHECK_INTERVAL, DEFAULT_MAX_TIME); + } + + /** + * Constructor using NO LIMIT and + * default max time as delay allowed value of {@value #DEFAULT_MAX_TIME} ms. + * + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + */ + protected AbstractTrafficShapingHandler(long checkInterval) { + this(0, 0, checkInterval, DEFAULT_MAX_TIME); + } + + /** + * Change the underlying limitations and check interval. + *

Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

+ *

So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration.

+ * + * @param newWriteLimit The new write limit (in bytes) + * @param newReadLimit The new read limit (in bytes) + * @param newCheckInterval The new check interval (in milliseconds) + */ + public void configure(long newWriteLimit, long newReadLimit, + long newCheckInterval) { + configure(newWriteLimit, newReadLimit); + configure(newCheckInterval); + } + + /** + * Change the underlying limitations. + *

Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

+ *

So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration.

+ * + * @param newWriteLimit The new write limit (in bytes) + * @param newReadLimit The new read limit (in bytes) + */ + public void configure(long newWriteLimit, long newReadLimit) { + writeLimit = newWriteLimit; + readLimit = newReadLimit; + if (trafficCounter != null) { + trafficCounter.resetAccounting(TrafficCounter.milliSecondFromNano()); + } + } + + /** + * Change the check interval. + * + * @param newCheckInterval The new check interval (in milliseconds) + */ + public void configure(long newCheckInterval) { + checkInterval = newCheckInterval; + if (trafficCounter != null) { + trafficCounter.configure(checkInterval); + } + } + + /** + * @return the writeLimit + */ + public long getWriteLimit() { + return writeLimit; + } + + /** + *

Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

+ *

So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration.

+ * + * @param writeLimit the writeLimit to set + */ + public void setWriteLimit(long writeLimit) { + this.writeLimit = writeLimit; + if (trafficCounter != null) { + trafficCounter.resetAccounting(TrafficCounter.milliSecondFromNano()); + } + } + + /** + * @return the readLimit + */ + public long getReadLimit() { + return readLimit; + } + + /** + *

Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

+ *

So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration.

+ * + * @param readLimit the readLimit to set + */ + public void setReadLimit(long readLimit) { + this.readLimit = readLimit; + if (trafficCounter != null) { + trafficCounter.resetAccounting(TrafficCounter.milliSecondFromNano()); + } + } + + /** + * @return the checkInterval + */ + public long getCheckInterval() { + return checkInterval; + } + + /** + * @param checkInterval the interval in ms between each step check to set, default value being 1000 ms. + */ + public void setCheckInterval(long checkInterval) { + this.checkInterval = checkInterval; + if (trafficCounter != null) { + trafficCounter.configure(checkInterval); + } + } + + /** + *

Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

+ *

So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration.

+ * + * @param maxTime + * Max delay in wait, shall be less than TIME OUT in related protocol. + * Must be positive. + */ + public void setMaxTimeWait(long maxTime) { + this.maxTime = checkPositive(maxTime, "maxTime"); + } + + /** + * @return the max delay in wait to prevent TIME OUT + */ + public long getMaxTimeWait() { + return maxTime; + } + + /** + * @return the maxWriteDelay + */ + public long getMaxWriteDelay() { + return maxWriteDelay; + } + + /** + *

Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

+ *

So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration.

+ * + * @param maxWriteDelay the maximum Write Delay in ms in the buffer allowed before write suspension is set. + * Must be positive. + */ + public void setMaxWriteDelay(long maxWriteDelay) { + this.maxWriteDelay = checkPositive(maxWriteDelay, "maxWriteDelay"); + } + + /** + * @return the maxWriteSize default being {@value #DEFAULT_MAX_SIZE} bytes. + */ + public long getMaxWriteSize() { + return maxWriteSize; + } + + /** + *

Note that this limit is a best effort on memory limitation to prevent Out Of + * Memory Exception. To ensure it works, the handler generating the write should + * use one of the way provided by Netty to handle the capacity:

+ *

- the {@code Channel.isWritable()} property and the corresponding + * {@code channelWritabilityChanged()}

+ *

- the {@code ChannelFuture.addListener(new GenericFutureListener())}

+ * + * @param maxWriteSize the maximum Write Size allowed in the buffer + * per channel before write suspended is set, + * default being {@value #DEFAULT_MAX_SIZE} bytes. + */ + public void setMaxWriteSize(long maxWriteSize) { + this.maxWriteSize = maxWriteSize; + } + + /** + * Called each time the accounting is computed from the TrafficCounters. + * This method could be used for instance to implement almost real time accounting. + * + * @param counter + * the TrafficCounter that computes its performance + */ + protected void doAccounting(TrafficCounter counter) { + // NOOP by default + } + + /** + * Class to implement setReadable at fix time + */ + static final class ReopenReadTimerTask implements Runnable { + final ChannelHandlerContext ctx; + ReopenReadTimerTask(ChannelHandlerContext ctx) { + this.ctx = ctx; + } + + @Override + public void run() { + Channel channel = ctx.channel(); + ChannelConfig config = channel.config(); + if (!config.isAutoRead() && isHandlerActive(ctx)) { + // If AutoRead is False and Active is True, user make a direct setAutoRead(false) + // Then Just reset the status + if (logger.isDebugEnabled()) { + logger.debug("Not unsuspend: " + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } + channel.attr(READ_SUSPENDED).set(false); + } else { + // Anything else allows the handler to reset the AutoRead + if (logger.isDebugEnabled()) { + if (config.isAutoRead() && !isHandlerActive(ctx)) { + if (logger.isDebugEnabled()) { + logger.debug("Unsuspend: " + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } + } else { + if (logger.isDebugEnabled()) { + logger.debug("Normal unsuspend: " + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } + } + } + channel.attr(READ_SUSPENDED).set(false); + config.setAutoRead(true); + channel.read(); + } + if (logger.isDebugEnabled()) { + logger.debug("Unsuspend final status => " + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } + } + } + + /** + * Release the Read suspension + */ + void 
releaseReadSuspended(ChannelHandlerContext ctx) { + Channel channel = ctx.channel(); + channel.attr(READ_SUSPENDED).set(false); + channel.config().setAutoRead(true); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + long size = calculateSize(msg); + long now = TrafficCounter.milliSecondFromNano(); + if (size > 0) { + // compute the number of ms to wait before reopening the channel + long wait = trafficCounter.readTimeToWait(size, readLimit, maxTime, now); + wait = checkWaitReadTime(ctx, wait, now); + if (wait >= MINIMAL_WAIT) { // At least 10ms seems a minimal + // time in order to try to limit the traffic + // Only AutoRead AND HandlerActive True means Context Active + Channel channel = ctx.channel(); + ChannelConfig config = channel.config(); + if (logger.isDebugEnabled()) { + logger.debug("Read suspend: " + wait + ':' + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } + if (config.isAutoRead() && isHandlerActive(ctx)) { + config.setAutoRead(false); + channel.attr(READ_SUSPENDED).set(true); + // Create a Runnable to reactive the read if needed. 
If one was created before it will just be + // reused to limit object creation + Attribute attr = channel.attr(REOPEN_TASK); + Runnable reopenTask = attr.get(); + if (reopenTask == null) { + reopenTask = new ReopenReadTimerTask(ctx); + attr.set(reopenTask); + } + ctx.executor().schedule(reopenTask, wait, TimeUnit.MILLISECONDS); + if (logger.isDebugEnabled()) { + logger.debug("Suspend final status => " + config.isAutoRead() + ':' + + isHandlerActive(ctx) + " will reopened at: " + wait); + } + } + } + } + informReadOperation(ctx, now); + ctx.fireChannelRead(msg); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + Channel channel = ctx.channel(); + if (channel.hasAttr(REOPEN_TASK)) { + //release the reopen task + channel.attr(REOPEN_TASK).set(null); + } + super.handlerRemoved(ctx); + } + + /** + * Method overridden in GTSH to take into account specific timer for the channel. + * @param wait the wait delay computed in ms + * @param now the relative now time in ms + * @return the wait to use according to the context + */ + long checkWaitReadTime(final ChannelHandlerContext ctx, long wait, final long now) { + // no change by default + return wait; + } + + /** + * Method overridden in GTSH to take into account specific timer for the channel. 
+ * @param now the relative now time in ms + */ + void informReadOperation(final ChannelHandlerContext ctx, final long now) { + // default noop + } + + protected static boolean isHandlerActive(ChannelHandlerContext ctx) { + Boolean suspended = ctx.channel().attr(READ_SUSPENDED).get(); + return suspended == null || Boolean.FALSE.equals(suspended); + } + + @Override + public void read(ChannelHandlerContext ctx) { + if (isHandlerActive(ctx)) { + // For Global Traffic (and Read when using EventLoop in pipeline) : check if READ_SUSPENDED is False + ctx.read(); + } + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) + throws Exception { + long size = calculateSize(msg); + long now = TrafficCounter.milliSecondFromNano(); + if (size > 0) { + // compute the number of ms to wait before continue with the channel + long wait = trafficCounter.writeTimeToWait(size, writeLimit, maxTime, now); + if (wait >= MINIMAL_WAIT) { + if (logger.isDebugEnabled()) { + logger.debug("Write suspend: " + wait + ':' + ctx.channel().config().isAutoRead() + ':' + + isHandlerActive(ctx)); + } + submitWrite(ctx, msg, size, wait, now, promise); + return; + } + } + // to maintain order of write + submitWrite(ctx, msg, size, 0, now, promise); + } + + @Deprecated + protected void submitWrite(final ChannelHandlerContext ctx, final Object msg, + final long delay, final ChannelPromise promise) { + submitWrite(ctx, msg, calculateSize(msg), + delay, TrafficCounter.milliSecondFromNano(), promise); + } + + abstract void submitWrite( + ChannelHandlerContext ctx, Object msg, long size, long delay, long now, ChannelPromise promise); + + @Override + public void channelRegistered(ChannelHandlerContext ctx) throws Exception { + setUserDefinedWritability(ctx, true); + super.channelRegistered(ctx); + } + + void setUserDefinedWritability(ChannelHandlerContext ctx, boolean writable) { + ChannelOutboundBuffer cob = ctx.channel().unsafe().outboundBuffer(); 
+ if (cob != null) { + cob.setUserDefinedWritability(userDefinedWritabilityIndex, writable); + } + } + + /** + * Check the writability according to delay and size for the channel. + * Set if necessary setUserDefinedWritability status. + * @param delay the computed delay + * @param queueSize the current queueSize + */ + void checkWriteSuspend(ChannelHandlerContext ctx, long delay, long queueSize) { + if (queueSize > maxWriteSize || delay > maxWriteDelay) { + setUserDefinedWritability(ctx, false); + } + } + /** + * Explicitly release the Write suspended status. + */ + void releaseWriteSuspended(ChannelHandlerContext ctx) { + setUserDefinedWritability(ctx, true); + } + + /** + * @return the current TrafficCounter (if + * channel is still connected) + */ + public TrafficCounter trafficCounter() { + return trafficCounter; + } + + @Override + public String toString() { + StringBuilder builder = new StringBuilder(290) + .append("TrafficShaping with Write Limit: ").append(writeLimit) + .append(" Read Limit: ").append(readLimit) + .append(" CheckInterval: ").append(checkInterval) + .append(" maxDelay: ").append(maxWriteDelay) + .append(" maxSize: ").append(maxWriteSize) + .append(" and Counter: "); + if (trafficCounter != null) { + builder.append(trafficCounter); + } else { + builder.append("none"); + } + return builder.toString(); + } + + /** + * Calculate the size of the given {@link Object}. + * + * This implementation supports {@link ByteBuf}, {@link ByteBufHolder} and {@link FileRegion}. + * Sub-classes may override this. + * @param msg the msg for which the size should be calculated. + * @return size the size of the msg or {@code -1} if unknown. 
+ */ + protected long calculateSize(Object msg) { + if (msg instanceof ByteBuf) { + return ((ByteBuf) msg).readableBytes(); + } + if (msg instanceof ByteBufHolder) { + return ((ByteBufHolder) msg).content().readableBytes(); + } + if (msg instanceof FileRegion) { + return ((FileRegion) msg).count(); + } + return -1; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/traffic/ChannelTrafficShapingHandler.java b/netty-handler/src/main/java/io/netty/handler/traffic/ChannelTrafficShapingHandler.java new file mode 100644 index 0000000..7004cc5 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/traffic/ChannelTrafficShapingHandler.java @@ -0,0 +1,231 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.traffic; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; + +import java.util.ArrayDeque; +import java.util.concurrent.TimeUnit; + +/** + *

This implementation of the {@link AbstractTrafficShapingHandler} is for channel + * traffic shaping, that is to say a per channel limitation of the bandwidth.

+ *

Note the index used in {@code OutboundBuffer.setUserDefinedWritability(index, boolean)} is 1.

+ * + *

The general use should be as follow:

+ *
    + *
  • Add in your pipeline a new ChannelTrafficShapingHandler.

    + *

    ChannelTrafficShapingHandler myHandler = new ChannelTrafficShapingHandler();

    + *

    pipeline.addLast(myHandler);

    + * + *

Note that this handler has a Pipeline Coverage of "one" which means a new handler must be created + * for each new channel as the counter cannot be shared among all channels.

    + * + *

    Other arguments can be passed like write or read limitation (in bytes/s where 0 means no limitation) + * or the check interval (in millisecond) that represents the delay between two computations of the + * bandwidth and so the call back of the doAccounting method (0 means no accounting at all).

    + * + *

    A value of 0 means no accounting for checkInterval. If you need traffic shaping but no such accounting, + * it is recommended to set a positive value, even if it is high since the precision of the + * Traffic Shaping depends on the period where the traffic is computed. The highest the interval, + * the less precise the traffic shaping will be. It is suggested as higher value something close + * to 5 or 10 minutes.

    + * + *

    maxTimeToWait, by default set to 15s, allows to specify an upper bound of time shaping.

    + *
  • + *
  • In your handler, you should consider to use the {@code channel.isWritable()} and + * {@code channelWritabilityChanged(ctx)} to handle writability, or through + * {@code future.addListener(new GenericFutureListener())} on the future returned by + * {@code ctx.write()}.
  • + *
• You shall also consider to have object size in read or write operations relatively adapted to + * the bandwidth you required: for instance having 10 MB objects for 10KB/s will lead to burst effect, + * while having 100 KB objects for 1 MB/s should be smoothly handled by this TrafficShaping handler.

  • + *
  • Some configuration methods will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.

    + *

    So the expected usage of those methods are to be used not too often, + * accordingly to the traffic shaping configuration.

  • + *
+ */ +public class ChannelTrafficShapingHandler extends AbstractTrafficShapingHandler { + private final ArrayDeque messagesQueue = new ArrayDeque(); + private long queueSize; + + /** + * Create a new instance. + * + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + * @param maxTime + * The maximum delay to wait in case of traffic excess. + */ + public ChannelTrafficShapingHandler(long writeLimit, long readLimit, + long checkInterval, long maxTime) { + super(writeLimit, readLimit, checkInterval, maxTime); + } + + /** + * Create a new instance using default + * max time as delay allowed value of 15000 ms. + * + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + */ + public ChannelTrafficShapingHandler(long writeLimit, + long readLimit, long checkInterval) { + super(writeLimit, readLimit, checkInterval); + } + + /** + * Create a new instance using default Check Interval value of 1000 ms and + * max time as delay allowed value of 15000 ms. + * + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + */ + public ChannelTrafficShapingHandler(long writeLimit, + long readLimit) { + super(writeLimit, readLimit); + } + + /** + * Create a new instance using + * default max time as delay allowed value of 15000 ms and no limit. + * + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. 
+ */ + public ChannelTrafficShapingHandler(long checkInterval) { + super(checkInterval); + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + TrafficCounter trafficCounter = new TrafficCounter(this, ctx.executor(), "ChannelTC" + + ctx.channel().hashCode(), checkInterval); + setTrafficCounter(trafficCounter); + trafficCounter.start(); + super.handlerAdded(ctx); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + trafficCounter.stop(); + // write order control + synchronized (this) { + if (ctx.channel().isActive()) { + for (ToSend toSend : messagesQueue) { + long size = calculateSize(toSend.toSend); + trafficCounter.bytesRealWriteFlowControl(size); + queueSize -= size; + ctx.write(toSend.toSend, toSend.promise); + } + } else { + for (ToSend toSend : messagesQueue) { + if (toSend.toSend instanceof ByteBuf) { + ((ByteBuf) toSend.toSend).release(); + } + } + } + messagesQueue.clear(); + } + releaseWriteSuspended(ctx); + releaseReadSuspended(ctx); + super.handlerRemoved(ctx); + } + + private static final class ToSend { + final long relativeTimeAction; + final Object toSend; + final ChannelPromise promise; + + private ToSend(final long delay, final Object toSend, final ChannelPromise promise) { + relativeTimeAction = delay; + this.toSend = toSend; + this.promise = promise; + } + } + + @Override + void submitWrite(final ChannelHandlerContext ctx, final Object msg, + final long size, final long delay, final long now, + final ChannelPromise promise) { + final ToSend newToSend; + // write order control + synchronized (this) { + if (delay == 0 && messagesQueue.isEmpty()) { + trafficCounter.bytesRealWriteFlowControl(size); + ctx.write(msg, promise); + return; + } + newToSend = new ToSend(delay + now, msg, promise); + messagesQueue.addLast(newToSend); + queueSize += size; + checkWriteSuspend(ctx, delay, queueSize); + } + final long futureNow = newToSend.relativeTimeAction; + 
ctx.executor().schedule(new Runnable() { + @Override + public void run() { + sendAllValid(ctx, futureNow); + } + }, delay, TimeUnit.MILLISECONDS); + } + + private void sendAllValid(final ChannelHandlerContext ctx, final long now) { + // write order control + synchronized (this) { + ToSend newToSend = messagesQueue.pollFirst(); + for (; newToSend != null; newToSend = messagesQueue.pollFirst()) { + if (newToSend.relativeTimeAction <= now) { + long size = calculateSize(newToSend.toSend); + trafficCounter.bytesRealWriteFlowControl(size); + queueSize -= size; + ctx.write(newToSend.toSend, newToSend.promise); + } else { + messagesQueue.addFirst(newToSend); + break; + } + } + if (messagesQueue.isEmpty()) { + releaseWriteSuspended(ctx); + } + } + ctx.flush(); + } + + /** + * @return current size in bytes of the write buffer. + */ + public long queueSize() { + return queueSize; + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficCounter.java b/netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficCounter.java new file mode 100644 index 0000000..aa0ec57 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficCounter.java @@ -0,0 +1,127 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.traffic; + +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; + +import io.netty.handler.traffic.GlobalChannelTrafficShapingHandler.PerChannel; + +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * Version for {@link GlobalChannelTrafficShapingHandler}. + * This TrafficCounter is the Global one, and its special property is to directly handle + * other channel's TrafficCounters. In particular, there are no scheduler for those + * channel's TrafficCounters because it is managed by this one. + */ +public class GlobalChannelTrafficCounter extends TrafficCounter { + /** + * @param trafficShapingHandler the associated {@link GlobalChannelTrafficShapingHandler}. + * @param executor the underlying executor service for scheduling checks (both Global and per Channel). + * @param name the name given to this monitor. + * @param checkInterval the checkInterval in millisecond between two computations. + */ + public GlobalChannelTrafficCounter(GlobalChannelTrafficShapingHandler trafficShapingHandler, + ScheduledExecutorService executor, String name, long checkInterval) { + super(trafficShapingHandler, executor, name, checkInterval); + checkNotNullWithIAE(executor, "executor"); + } + + /** + * Class to implement monitoring at fix delay. + * This version is Mixed in the way it mixes Global and Channel counters. + */ + private static class MixedTrafficMonitoringTask implements Runnable { + /** + * The associated TrafficShapingHandler + */ + private final GlobalChannelTrafficShapingHandler trafficShapingHandler1; + + /** + * The associated TrafficCounter + */ + private final TrafficCounter counter; + + /** + * @param trafficShapingHandler The parent handler to which this task needs to callback to for accounting. + * @param counter The parent TrafficCounter that we need to reset the statistics for. 
+ */ + MixedTrafficMonitoringTask( + GlobalChannelTrafficShapingHandler trafficShapingHandler, + TrafficCounter counter) { + trafficShapingHandler1 = trafficShapingHandler; + this.counter = counter; + } + + @Override + public void run() { + if (!counter.monitorActive) { + return; + } + long newLastTime = milliSecondFromNano(); + counter.resetAccounting(newLastTime); + for (PerChannel perChannel : trafficShapingHandler1.channelQueues.values()) { + perChannel.channelTrafficCounter.resetAccounting(newLastTime); + } + trafficShapingHandler1.doAccounting(counter); + } + } + + /** + * Start the monitoring process. + */ + @Override + public synchronized void start() { + if (monitorActive) { + return; + } + lastTime.set(milliSecondFromNano()); + long localCheckInterval = checkInterval.get(); + if (localCheckInterval > 0) { + monitorActive = true; + monitor = new MixedTrafficMonitoringTask((GlobalChannelTrafficShapingHandler) trafficShapingHandler, this); + scheduledFuture = + executor.scheduleAtFixedRate(monitor, 0, localCheckInterval, TimeUnit.MILLISECONDS); + } + } + + /** + * Stop the monitoring process. 
+ */ + @Override + public synchronized void stop() { + if (!monitorActive) { + return; + } + monitorActive = false; + resetAccounting(milliSecondFromNano()); + trafficShapingHandler.doAccounting(this); + if (scheduledFuture != null) { + scheduledFuture.cancel(true); + } + } + + @Override + public void resetCumulativeTime() { + for (PerChannel perChannel : + ((GlobalChannelTrafficShapingHandler) trafficShapingHandler).channelQueues.values()) { + perChannel.channelTrafficCounter.resetCumulativeTime(); + } + super.resetCumulativeTime(); + } + +} diff --git a/netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficShapingHandler.java b/netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficShapingHandler.java new file mode 100644 index 0000000..9c35938 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/traffic/GlobalChannelTrafficShapingHandler.java @@ -0,0 +1,773 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.traffic; + +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.Attribute; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +import java.util.AbstractCollection; +import java.util.ArrayDeque; +import java.util.Collection; +import java.util.Iterator; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +/** + * This implementation of the {@link AbstractTrafficShapingHandler} is for global + * and per channel traffic shaping, that is to say a global limitation of the bandwidth, whatever + * the number of opened channels and a per channel limitation of the bandwidth.

+ * This version shall not be in the same pipeline than other TrafficShapingHandler.<br><br>
+ *
+ * The general use should be as follow:<br>
+ * <ul>
+ * <li><p>Create your unique GlobalChannelTrafficShapingHandler like:</p>
+ * <p><tt>GlobalChannelTrafficShapingHandler myHandler = new GlobalChannelTrafficShapingHandler(executor);</tt></p>
+ * <p>The executor could be the underlying IO worker pool</p>
+ * <p><tt>pipeline.addLast(myHandler);</tt></p>
+ *
+ * <p><b>Note that this handler has a Pipeline Coverage of "all" which means only one such handler must be created
+ * and shared among all channels as the counter must be shared among all channels.</b></p>
+ *
+ * <p>Other arguments can be passed like write or read limitation (in bytes/s where 0 means no limitation)
+ * or the check interval (in millisecond) that represents the delay between two computations of the
+ * bandwidth and so the call back of the doAccounting method (0 means no accounting at all).</p>
+ * <p>Note that as this is a fusion of both Global and Channel Traffic Shaping, limits are in 2 sets,
+ * respectively Global and Channel.</p>
+ *
+ * <p>A value of 0 means no accounting for checkInterval. If you need traffic shaping but no such accounting,
+ * it is recommended to set a positive value, even if it is high since the precision of the
+ * Traffic Shaping depends on the period where the traffic is computed. The highest the interval,
+ * the less precise the traffic shaping will be. It is suggested as higher value something close
+ * to 5 or 10 minutes.</p>
+ *
+ * <p>maxTimeToWait, by default set to 15s, allows to specify an upper bound of time shaping.</p>
+ * </li>
+ * <li><p>In your handler, you should consider to use the {@code channel.isWritable()} and
+ * {@code channelWritabilityChanged(ctx)} to handle writability, or through
+ * {@code future.addListener(new GenericFutureListener())} on the future returned by
+ * {@code ctx.write()}.</p></li>
+ * <li><p>You shall also consider to have object size in read or write operations relatively adapted to
+ * the bandwidth you required: for instance having 10 MB objects for 10KB/s will lead to burst effect,
+ * while having 100 KB objects for 1 MB/s should be smoothly handle by this TrafficShaping handler.</p></li>
+ * <li><p>Some configuration methods will be taken as best effort, meaning
+ * that all already scheduled traffics will not be
+ * changed, but only applied to new traffics.</p>
+ * So the expected usage of those methods are to be used not too often,
+ * accordingly to the traffic shaping configuration.</li>
+ * </ul><br>
+ * + * Be sure to call {@link #release()} once this handler is not needed anymore to release all internal resources. + * This will not shutdown the {@link EventExecutor} as it may be shared, so you need to do this by your own. + */ +@Sharable +public class GlobalChannelTrafficShapingHandler extends AbstractTrafficShapingHandler { + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(GlobalChannelTrafficShapingHandler.class); + /** + * All queues per channel + */ + final ConcurrentMap channelQueues = PlatformDependent.newConcurrentHashMap(); + + /** + * Global queues size + */ + private final AtomicLong queuesSize = new AtomicLong(); + + /** + * Maximum cumulative writing bytes for one channel among all (as long as channels stay the same) + */ + private final AtomicLong cumulativeWrittenBytes = new AtomicLong(); + + /** + * Maximum cumulative read bytes for one channel among all (as long as channels stay the same) + */ + private final AtomicLong cumulativeReadBytes = new AtomicLong(); + + /** + * Max size in the list before proposing to stop writing new objects from next handlers + * for all channel (global) + */ + volatile long maxGlobalWriteSize = DEFAULT_MAX_SIZE * 100; // default 400MB + + /** + * Limit in B/s to apply to write + */ + private volatile long writeChannelLimit; + + /** + * Limit in B/s to apply to read + */ + private volatile long readChannelLimit; + + private static final float DEFAULT_DEVIATION = 0.1F; + private static final float MAX_DEVIATION = 0.4F; + private static final float DEFAULT_SLOWDOWN = 0.4F; + private static final float DEFAULT_ACCELERATION = -0.1F; + private volatile float maxDeviation; + private volatile float accelerationFactor; + private volatile float slowDownFactor; + private volatile boolean readDeviationActive; + private volatile boolean writeDeviationActive; + + static final class PerChannel { + ArrayDeque messagesQueue; + TrafficCounter channelTrafficCounter; + long queueSize; + long 
lastWriteTimestamp; + long lastReadTimestamp; + } + + /** + * Create the global TrafficCounter + */ + void createGlobalTrafficCounter(ScheduledExecutorService executor) { + // Default + setMaxDeviation(DEFAULT_DEVIATION, DEFAULT_SLOWDOWN, DEFAULT_ACCELERATION); + checkNotNullWithIAE(executor, "executor"); + TrafficCounter tc = new GlobalChannelTrafficCounter(this, executor, "GlobalChannelTC", checkInterval); + setTrafficCounter(tc); + tc.start(); + } + + @Override + protected int userDefinedWritabilityIndex() { + return AbstractTrafficShapingHandler.GLOBALCHANNEL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX; + } + + /** + * Create a new instance. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param writeGlobalLimit + * 0 or a limit in bytes/s + * @param readGlobalLimit + * 0 or a limit in bytes/s + * @param writeChannelLimit + * 0 or a limit in bytes/s + * @param readChannelLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + * @param maxTime + * The maximum delay to wait in case of traffic excess. + */ + public GlobalChannelTrafficShapingHandler(ScheduledExecutorService executor, + long writeGlobalLimit, long readGlobalLimit, + long writeChannelLimit, long readChannelLimit, + long checkInterval, long maxTime) { + super(writeGlobalLimit, readGlobalLimit, checkInterval, maxTime); + createGlobalTrafficCounter(executor); + this.writeChannelLimit = writeChannelLimit; + this.readChannelLimit = readChannelLimit; + } + + /** + * Create a new instance. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. 
+ * @param writeGlobalLimit + * 0 or a limit in bytes/s + * @param readGlobalLimit + * 0 or a limit in bytes/s + * @param writeChannelLimit + * 0 or a limit in bytes/s + * @param readChannelLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + */ + public GlobalChannelTrafficShapingHandler(ScheduledExecutorService executor, + long writeGlobalLimit, long readGlobalLimit, + long writeChannelLimit, long readChannelLimit, + long checkInterval) { + super(writeGlobalLimit, readGlobalLimit, checkInterval); + this.writeChannelLimit = writeChannelLimit; + this.readChannelLimit = readChannelLimit; + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param writeGlobalLimit + * 0 or a limit in bytes/s + * @param readGlobalLimit + * 0 or a limit in bytes/s + * @param writeChannelLimit + * 0 or a limit in bytes/s + * @param readChannelLimit + * 0 or a limit in bytes/s + */ + public GlobalChannelTrafficShapingHandler(ScheduledExecutorService executor, + long writeGlobalLimit, long readGlobalLimit, + long writeChannelLimit, long readChannelLimit) { + super(writeGlobalLimit, readGlobalLimit); + this.writeChannelLimit = writeChannelLimit; + this.readChannelLimit = readChannelLimit; + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + */ + public GlobalChannelTrafficShapingHandler(ScheduledExecutorService executor, long checkInterval) { + super(checkInterval); + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance. 
+ * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + */ + public GlobalChannelTrafficShapingHandler(ScheduledExecutorService executor) { + createGlobalTrafficCounter(executor); + } + + /** + * @return the current max deviation + */ + public float maxDeviation() { + return maxDeviation; + } + + /** + * @return the current acceleration factor + */ + public float accelerationFactor() { + return accelerationFactor; + } + + /** + * @return the current slow down factor + */ + public float slowDownFactor() { + return slowDownFactor; + } + + /** + * @param maxDeviation + * the maximum deviation to allow during computation of average, default deviation + * being 0.1, so +/-10% of the desired bandwidth. Maximum being 0.4. + * @param slowDownFactor + * the factor set as +x% to the too fast client (minimal value being 0, meaning no + * slow down factor), default being 40% (0.4). + * @param accelerationFactor + * the factor set as -x% to the too slow client (maximal value being 0, meaning no + * acceleration factor), default being -10% (-0.1). 
+ */ + public void setMaxDeviation(float maxDeviation, float slowDownFactor, float accelerationFactor) { + if (maxDeviation > MAX_DEVIATION) { + throw new IllegalArgumentException("maxDeviation must be <= " + MAX_DEVIATION); + } + checkPositiveOrZero(slowDownFactor, "slowDownFactor"); + if (accelerationFactor > 0) { + throw new IllegalArgumentException("accelerationFactor must be <= 0"); + } + this.maxDeviation = maxDeviation; + this.accelerationFactor = 1 + accelerationFactor; + this.slowDownFactor = 1 + slowDownFactor; + } + + private void computeDeviationCumulativeBytes() { + // compute the maximum cumulativeXxxxBytes among still connected Channels + long maxWrittenBytes = 0; + long maxReadBytes = 0; + long minWrittenBytes = Long.MAX_VALUE; + long minReadBytes = Long.MAX_VALUE; + for (PerChannel perChannel : channelQueues.values()) { + long value = perChannel.channelTrafficCounter.cumulativeWrittenBytes(); + if (maxWrittenBytes < value) { + maxWrittenBytes = value; + } + if (minWrittenBytes > value) { + minWrittenBytes = value; + } + value = perChannel.channelTrafficCounter.cumulativeReadBytes(); + if (maxReadBytes < value) { + maxReadBytes = value; + } + if (minReadBytes > value) { + minReadBytes = value; + } + } + boolean multiple = channelQueues.size() > 1; + readDeviationActive = multiple && minReadBytes < maxReadBytes / 2; + writeDeviationActive = multiple && minWrittenBytes < maxWrittenBytes / 2; + cumulativeWrittenBytes.set(maxWrittenBytes); + cumulativeReadBytes.set(maxReadBytes); + } + + @Override + protected void doAccounting(TrafficCounter counter) { + computeDeviationCumulativeBytes(); + super.doAccounting(counter); + } + + private long computeBalancedWait(float maxLocal, float maxGlobal, long wait) { + if (maxGlobal == 0) { + // no change + return wait; + } + float ratio = maxLocal / maxGlobal; + // if in the boundaries, same value + if (ratio > maxDeviation) { + if (ratio < 1 - maxDeviation) { + return wait; + } else { + ratio = slowDownFactor; + 
if (wait < MINIMAL_WAIT) { + wait = MINIMAL_WAIT; + } + } + } else { + ratio = accelerationFactor; + } + return (long) (wait * ratio); + } + + /** + * @return the maxGlobalWriteSize + */ + public long getMaxGlobalWriteSize() { + return maxGlobalWriteSize; + } + + /** + * Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.
+ * So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration. + * + * @param maxGlobalWriteSize the maximum Global Write Size allowed in the buffer + * globally for all channels before write suspended is set. + */ + public void setMaxGlobalWriteSize(long maxGlobalWriteSize) { + this.maxGlobalWriteSize = checkPositive(maxGlobalWriteSize, "maxGlobalWriteSize"); + } + + /** + * @return the global size of the buffers for all queues. + */ + public long queuesSize() { + return queuesSize.get(); + } + + /** + * @param newWriteLimit Channel write limit + * @param newReadLimit Channel read limit + */ + public void configureChannel(long newWriteLimit, long newReadLimit) { + writeChannelLimit = newWriteLimit; + readChannelLimit = newReadLimit; + long now = TrafficCounter.milliSecondFromNano(); + for (PerChannel perChannel : channelQueues.values()) { + perChannel.channelTrafficCounter.resetAccounting(now); + } + } + + /** + * @return Channel write limit + */ + public long getWriteChannelLimit() { + return writeChannelLimit; + } + + /** + * @param writeLimit Channel write limit + */ + public void setWriteChannelLimit(long writeLimit) { + writeChannelLimit = writeLimit; + long now = TrafficCounter.milliSecondFromNano(); + for (PerChannel perChannel : channelQueues.values()) { + perChannel.channelTrafficCounter.resetAccounting(now); + } + } + + /** + * @return Channel read limit + */ + public long getReadChannelLimit() { + return readChannelLimit; + } + + /** + * @param readLimit Channel read limit + */ + public void setReadChannelLimit(long readLimit) { + readChannelLimit = readLimit; + long now = TrafficCounter.milliSecondFromNano(); + for (PerChannel perChannel : channelQueues.values()) { + perChannel.channelTrafficCounter.resetAccounting(now); + } + } + + /** + * Release all internal resources of this instance. 
+ */ + public final void release() { + trafficCounter.stop(); + } + + private PerChannel getOrSetPerChannel(ChannelHandlerContext ctx) { + // ensure creation is limited to one thread per channel + Channel channel = ctx.channel(); + Integer key = channel.hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel == null) { + perChannel = new PerChannel(); + perChannel.messagesQueue = new ArrayDeque(); + // Don't start it since managed through the Global one + perChannel.channelTrafficCounter = new TrafficCounter(this, null, "ChannelTC" + + ctx.channel().hashCode(), checkInterval); + perChannel.queueSize = 0L; + perChannel.lastReadTimestamp = TrafficCounter.milliSecondFromNano(); + perChannel.lastWriteTimestamp = perChannel.lastReadTimestamp; + channelQueues.put(key, perChannel); + } + return perChannel; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + getOrSetPerChannel(ctx); + trafficCounter.resetCumulativeTime(); + super.handlerAdded(ctx); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + trafficCounter.resetCumulativeTime(); + Channel channel = ctx.channel(); + Integer key = channel.hashCode(); + PerChannel perChannel = channelQueues.remove(key); + if (perChannel != null) { + // write operations need synchronization + synchronized (perChannel) { + if (channel.isActive()) { + for (ToSend toSend : perChannel.messagesQueue) { + long size = calculateSize(toSend.toSend); + trafficCounter.bytesRealWriteFlowControl(size); + perChannel.channelTrafficCounter.bytesRealWriteFlowControl(size); + perChannel.queueSize -= size; + queuesSize.addAndGet(-size); + ctx.write(toSend.toSend, toSend.promise); + } + } else { + queuesSize.addAndGet(-perChannel.queueSize); + for (ToSend toSend : perChannel.messagesQueue) { + if (toSend.toSend instanceof ByteBuf) { + ((ByteBuf) toSend.toSend).release(); + } + } + } + perChannel.messagesQueue.clear(); + } + } + 
releaseWriteSuspended(ctx); + releaseReadSuspended(ctx); + super.handlerRemoved(ctx); + } + + @Override + public void channelRead(final ChannelHandlerContext ctx, final Object msg) throws Exception { + long size = calculateSize(msg); + long now = TrafficCounter.milliSecondFromNano(); + if (size > 0) { + // compute the number of ms to wait before reopening the channel + long waitGlobal = trafficCounter.readTimeToWait(size, getReadLimit(), maxTime, now); + Integer key = ctx.channel().hashCode(); + PerChannel perChannel = channelQueues.get(key); + long wait = 0; + if (perChannel != null) { + wait = perChannel.channelTrafficCounter.readTimeToWait(size, readChannelLimit, maxTime, now); + if (readDeviationActive) { + // now try to balance between the channels + long maxLocalRead; + maxLocalRead = perChannel.channelTrafficCounter.cumulativeReadBytes(); + long maxGlobalRead = cumulativeReadBytes.get(); + if (maxLocalRead <= 0) { + maxLocalRead = 0; + } + if (maxGlobalRead < maxLocalRead) { + maxGlobalRead = maxLocalRead; + } + wait = computeBalancedWait(maxLocalRead, maxGlobalRead, wait); + } + } + if (wait < waitGlobal) { + wait = waitGlobal; + } + wait = checkWaitReadTime(ctx, wait, now); + if (wait >= MINIMAL_WAIT) { // At least 10ms seems a minimal + // time in order to try to limit the traffic + // Only AutoRead AND HandlerActive True means Context Active + Channel channel = ctx.channel(); + ChannelConfig config = channel.config(); + if (logger.isDebugEnabled()) { + logger.debug("Read Suspend: " + wait + ':' + config.isAutoRead() + ':' + + isHandlerActive(ctx)); + } + if (config.isAutoRead() && isHandlerActive(ctx)) { + config.setAutoRead(false); + channel.attr(READ_SUSPENDED).set(true); + // Create a Runnable to reactive the read if needed. 
If one was create before it will just be + // reused to limit object creation + Attribute attr = channel.attr(REOPEN_TASK); + Runnable reopenTask = attr.get(); + if (reopenTask == null) { + reopenTask = new ReopenReadTimerTask(ctx); + attr.set(reopenTask); + } + ctx.executor().schedule(reopenTask, wait, TimeUnit.MILLISECONDS); + if (logger.isDebugEnabled()) { + logger.debug("Suspend final status => " + config.isAutoRead() + ':' + + isHandlerActive(ctx) + " will reopened at: " + wait); + } + } + } + } + informReadOperation(ctx, now); + ctx.fireChannelRead(msg); + } + + @Override + protected long checkWaitReadTime(final ChannelHandlerContext ctx, long wait, final long now) { + Integer key = ctx.channel().hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel != null) { + if (wait > maxTime && now + wait - perChannel.lastReadTimestamp > maxTime) { + wait = maxTime; + } + } + return wait; + } + + @Override + protected void informReadOperation(final ChannelHandlerContext ctx, final long now) { + Integer key = ctx.channel().hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel != null) { + perChannel.lastReadTimestamp = now; + } + } + + private static final class ToSend { + final long relativeTimeAction; + final Object toSend; + final ChannelPromise promise; + final long size; + + private ToSend(final long delay, final Object toSend, final long size, final ChannelPromise promise) { + relativeTimeAction = delay; + this.toSend = toSend; + this.size = size; + this.promise = promise; + } + } + + protected long maximumCumulativeWrittenBytes() { + return cumulativeWrittenBytes.get(); + } + + protected long maximumCumulativeReadBytes() { + return cumulativeReadBytes.get(); + } + + /** + * To allow for instance doAccounting to use the TrafficCounter per channel. + * @return the list of TrafficCounters that exists at the time of the call. 
+ */ + public Collection channelTrafficCounters() { + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new Iterator() { + final Iterator iter = channelQueues.values().iterator(); + @Override + public boolean hasNext() { + return iter.hasNext(); + } + @Override + public TrafficCounter next() { + return iter.next().channelTrafficCounter; + } + @Override + public void remove() { + throw new UnsupportedOperationException(); + } + }; + } + @Override + public int size() { + return channelQueues.size(); + } + }; + } + + @Override + public void write(final ChannelHandlerContext ctx, final Object msg, final ChannelPromise promise) + throws Exception { + long size = calculateSize(msg); + long now = TrafficCounter.milliSecondFromNano(); + if (size > 0) { + // compute the number of ms to wait before continue with the channel + long waitGlobal = trafficCounter.writeTimeToWait(size, getWriteLimit(), maxTime, now); + Integer key = ctx.channel().hashCode(); + PerChannel perChannel = channelQueues.get(key); + long wait = 0; + if (perChannel != null) { + wait = perChannel.channelTrafficCounter.writeTimeToWait(size, writeChannelLimit, maxTime, now); + if (writeDeviationActive) { + // now try to balance between the channels + long maxLocalWrite; + maxLocalWrite = perChannel.channelTrafficCounter.cumulativeWrittenBytes(); + long maxGlobalWrite = cumulativeWrittenBytes.get(); + if (maxLocalWrite <= 0) { + maxLocalWrite = 0; + } + if (maxGlobalWrite < maxLocalWrite) { + maxGlobalWrite = maxLocalWrite; + } + wait = computeBalancedWait(maxLocalWrite, maxGlobalWrite, wait); + } + } + if (wait < waitGlobal) { + wait = waitGlobal; + } + if (wait >= MINIMAL_WAIT) { + if (logger.isDebugEnabled()) { + logger.debug("Write suspend: " + wait + ':' + ctx.channel().config().isAutoRead() + ':' + + isHandlerActive(ctx)); + } + submitWrite(ctx, msg, size, wait, now, promise); + return; + } + } + // to maintain order of write + submitWrite(ctx, msg, size, 0, now, 
promise); + } + + @Override + protected void submitWrite(final ChannelHandlerContext ctx, final Object msg, + final long size, final long writedelay, final long now, + final ChannelPromise promise) { + Channel channel = ctx.channel(); + Integer key = channel.hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel == null) { + // in case write occurs before handlerAdded is raised for this handler + // imply a synchronized only if needed + perChannel = getOrSetPerChannel(ctx); + } + final ToSend newToSend; + long delay = writedelay; + boolean globalSizeExceeded = false; + // write operations need synchronization + synchronized (perChannel) { + if (writedelay == 0 && perChannel.messagesQueue.isEmpty()) { + trafficCounter.bytesRealWriteFlowControl(size); + perChannel.channelTrafficCounter.bytesRealWriteFlowControl(size); + ctx.write(msg, promise); + perChannel.lastWriteTimestamp = now; + return; + } + if (delay > maxTime && now + delay - perChannel.lastWriteTimestamp > maxTime) { + delay = maxTime; + } + newToSend = new ToSend(delay + now, msg, size, promise); + perChannel.messagesQueue.addLast(newToSend); + perChannel.queueSize += size; + queuesSize.addAndGet(size); + checkWriteSuspend(ctx, delay, perChannel.queueSize); + if (queuesSize.get() > maxGlobalWriteSize) { + globalSizeExceeded = true; + } + } + if (globalSizeExceeded) { + setUserDefinedWritability(ctx, false); + } + final long futureNow = newToSend.relativeTimeAction; + final PerChannel forSchedule = perChannel; + ctx.executor().schedule(new Runnable() { + @Override + public void run() { + sendAllValid(ctx, forSchedule, futureNow); + } + }, delay, TimeUnit.MILLISECONDS); + } + + private void sendAllValid(final ChannelHandlerContext ctx, final PerChannel perChannel, final long now) { + // write operations need synchronization + synchronized (perChannel) { + ToSend newToSend = perChannel.messagesQueue.pollFirst(); + for (; newToSend != null; newToSend = 
perChannel.messagesQueue.pollFirst()) { + if (newToSend.relativeTimeAction <= now) { + long size = newToSend.size; + trafficCounter.bytesRealWriteFlowControl(size); + perChannel.channelTrafficCounter.bytesRealWriteFlowControl(size); + perChannel.queueSize -= size; + queuesSize.addAndGet(-size); + ctx.write(newToSend.toSend, newToSend.promise); + perChannel.lastWriteTimestamp = now; + } else { + perChannel.messagesQueue.addFirst(newToSend); + break; + } + } + if (perChannel.messagesQueue.isEmpty()) { + releaseWriteSuspended(ctx); + } + } + ctx.flush(); + } + + @Override + public String toString() { + return new StringBuilder(340).append(super.toString()) + .append(" Write Channel Limit: ").append(writeChannelLimit) + .append(" Read Channel Limit: ").append(readChannelLimit).toString(); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/traffic/GlobalTrafficShapingHandler.java b/netty-handler/src/main/java/io/netty/handler/traffic/GlobalTrafficShapingHandler.java new file mode 100644 index 0000000..99da696 --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/traffic/GlobalTrafficShapingHandler.java @@ -0,0 +1,401 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.traffic; + +import io.netty.buffer.ByteBuf; +import io.netty.channel.ChannelHandler.Sharable; +import io.netty.channel.Channel; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelPromise; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; + +import java.util.ArrayDeque; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; + +/** + *

This implementation of the {@link AbstractTrafficShapingHandler} is for global
+ * traffic shaping, that is to say a global limitation of the bandwidth, whatever
+ * the number of opened channels.<br>
+ * Note the index used in {@code OutboundBuffer.setUserDefinedWritability(index, boolean)} is 2.<br><br>
+ *
+ * The general use should be as follow:<br>
+ * <ul>
+ * <li><p>Create your unique GlobalTrafficShapingHandler like:</p>
+ * <p><tt>GlobalTrafficShapingHandler myHandler = new GlobalTrafficShapingHandler(executor);</tt></p>
+ * <p>The executor could be the underlying IO worker pool</p>
+ * <p><tt>pipeline.addLast(myHandler);</tt></p>
+ *
+ * <p><b>Note that this handler has a Pipeline Coverage of "all" which means only one such handler must be created
+ * and shared among all channels as the counter must be shared among all channels.</b></p>
+ *
+ * <p>Other arguments can be passed like write or read limitation (in bytes/s where 0 means no limitation)
+ * or the check interval (in millisecond) that represents the delay between two computations of the
+ * bandwidth and so the call back of the doAccounting method (0 means no accounting at all).</p>
+ *
+ * <p>A value of 0 means no accounting for checkInterval. If you need traffic shaping but no such accounting,
+ * it is recommended to set a positive value, even if it is high since the precision of the
+ * Traffic Shaping depends on the period where the traffic is computed. The highest the interval,
+ * the less precise the traffic shaping will be. It is suggested as higher value something close
+ * to 5 or 10 minutes.</p>
+ *
+ * <p>maxTimeToWait, by default set to 15s, allows to specify an upper bound of time shaping.</p>
+ * </li>
+ * <li><p>In your handler, you should consider to use the {@code channel.isWritable()} and
+ * {@code channelWritabilityChanged(ctx)} to handle writability, or through
+ * {@code future.addListener(new GenericFutureListener())} on the future returned by
+ * {@code ctx.write()}.</p></li>
+ * <li><p>You shall also consider to have object size in read or write operations relatively adapted to
+ * the bandwidth you required: for instance having 10 MB objects for 10KB/s will lead to burst effect,
+ * while having 100 KB objects for 1 MB/s should be smoothly handle by this TrafficShaping handler.</p></li>
+ * <li><p>Some configuration methods will be taken as best effort, meaning
+ * that all already scheduled traffics will not be
+ * changed, but only applied to new traffics.</p>
+ * So the expected usage of those methods are to be used not too often,
+ * accordingly to the traffic shaping configuration.</li>
+ * </ul>
+ * + * Be sure to call {@link #release()} once this handler is not needed anymore to release all internal resources. + * This will not shutdown the {@link EventExecutor} as it may be shared, so you need to do this by your own. + */ +@Sharable +public class GlobalTrafficShapingHandler extends AbstractTrafficShapingHandler { + /** + * All queues per channel + */ + private final ConcurrentMap channelQueues = PlatformDependent.newConcurrentHashMap(); + + /** + * Global queues size + */ + private final AtomicLong queuesSize = new AtomicLong(); + + /** + * Max size in the list before proposing to stop writing new objects from next handlers + * for all channel (global) + */ + long maxGlobalWriteSize = DEFAULT_MAX_SIZE * 100; // default 400MB + + private static final class PerChannel { + ArrayDeque messagesQueue; + long queueSize; + long lastWriteTimestamp; + long lastReadTimestamp; + } + + /** + * Create the global TrafficCounter. + */ + void createGlobalTrafficCounter(ScheduledExecutorService executor) { + TrafficCounter tc = new TrafficCounter(this, + ObjectUtil.checkNotNull(executor, "executor"), + "GlobalTC", + checkInterval); + + setTrafficCounter(tc); + tc.start(); + } + + @Override + protected int userDefinedWritabilityIndex() { + return AbstractTrafficShapingHandler.GLOBAL_DEFAULT_USER_DEFINED_WRITABILITY_INDEX; + } + + /** + * Create a new instance. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + * @param maxTime + * The maximum delay to wait in case of traffic excess. 
+ */ + public GlobalTrafficShapingHandler(ScheduledExecutorService executor, long writeLimit, long readLimit, + long checkInterval, long maxTime) { + super(writeLimit, readLimit, checkInterval, maxTime); + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance using + * default max time as delay allowed value of 15000 ms. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. + */ + public GlobalTrafficShapingHandler(ScheduledExecutorService executor, long writeLimit, + long readLimit, long checkInterval) { + super(writeLimit, readLimit, checkInterval); + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance using default Check Interval value of 1000 ms and + * default max time as delay allowed value of 15000 ms. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param writeLimit + * 0 or a limit in bytes/s + * @param readLimit + * 0 or a limit in bytes/s + */ + public GlobalTrafficShapingHandler(ScheduledExecutorService executor, long writeLimit, + long readLimit) { + super(writeLimit, readLimit); + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance using + * default max time as delay allowed value of 15000 ms and no limit. + * + * @param executor + * the {@link ScheduledExecutorService} to use for the {@link TrafficCounter}. + * @param checkInterval + * The delay between two computations of performances for + * channels or 0 if no stats are to be computed. 
+ */ + public GlobalTrafficShapingHandler(ScheduledExecutorService executor, long checkInterval) { + super(checkInterval); + createGlobalTrafficCounter(executor); + } + + /** + * Create a new instance using default Check Interval value of 1000 ms and + * default max time as delay allowed value of 15000 ms and no limit. + * + * @param executor + * the {@link EventExecutor} to use for the {@link TrafficCounter}. + */ + public GlobalTrafficShapingHandler(EventExecutor executor) { + createGlobalTrafficCounter(executor); + } + + /** + * @return the maxGlobalWriteSize default value being 400 MB. + */ + public long getMaxGlobalWriteSize() { + return maxGlobalWriteSize; + } + + /** + * Note the change will be taken as best effort, meaning + * that all already scheduled traffics will not be + * changed, but only applied to new traffics.
+ * So the expected usage of this method is to be used not too often, + * accordingly to the traffic shaping configuration. + * + * @param maxGlobalWriteSize the maximum Global Write Size allowed in the buffer + * globally for all channels before write suspended is set, + * default value being 400 MB. + */ + public void setMaxGlobalWriteSize(long maxGlobalWriteSize) { + this.maxGlobalWriteSize = maxGlobalWriteSize; + } + + /** + * @return the global size of the buffers for all queues. + */ + public long queuesSize() { + return queuesSize.get(); + } + + /** + * Release all internal resources of this instance. + */ + public final void release() { + trafficCounter.stop(); + } + + private PerChannel getOrSetPerChannel(ChannelHandlerContext ctx) { + // ensure creation is limited to one thread per channel + Channel channel = ctx.channel(); + Integer key = channel.hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel == null) { + perChannel = new PerChannel(); + perChannel.messagesQueue = new ArrayDeque(); + perChannel.queueSize = 0L; + perChannel.lastReadTimestamp = TrafficCounter.milliSecondFromNano(); + perChannel.lastWriteTimestamp = perChannel.lastReadTimestamp; + channelQueues.put(key, perChannel); + } + return perChannel; + } + + @Override + public void handlerAdded(ChannelHandlerContext ctx) throws Exception { + getOrSetPerChannel(ctx); + super.handlerAdded(ctx); + } + + @Override + public void handlerRemoved(ChannelHandlerContext ctx) throws Exception { + Channel channel = ctx.channel(); + Integer key = channel.hashCode(); + PerChannel perChannel = channelQueues.remove(key); + if (perChannel != null) { + // write operations need synchronization + synchronized (perChannel) { + if (channel.isActive()) { + for (ToSend toSend : perChannel.messagesQueue) { + long size = calculateSize(toSend.toSend); + trafficCounter.bytesRealWriteFlowControl(size); + perChannel.queueSize -= size; + queuesSize.addAndGet(-size); + ctx.write(toSend.toSend, 
toSend.promise); + } + } else { + queuesSize.addAndGet(-perChannel.queueSize); + for (ToSend toSend : perChannel.messagesQueue) { + if (toSend.toSend instanceof ByteBuf) { + ((ByteBuf) toSend.toSend).release(); + } + } + } + perChannel.messagesQueue.clear(); + } + } + releaseWriteSuspended(ctx); + releaseReadSuspended(ctx); + super.handlerRemoved(ctx); + } + + @Override + long checkWaitReadTime(final ChannelHandlerContext ctx, long wait, final long now) { + Integer key = ctx.channel().hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel != null) { + if (wait > maxTime && now + wait - perChannel.lastReadTimestamp > maxTime) { + wait = maxTime; + } + } + return wait; + } + + @Override + void informReadOperation(final ChannelHandlerContext ctx, final long now) { + Integer key = ctx.channel().hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel != null) { + perChannel.lastReadTimestamp = now; + } + } + + private static final class ToSend { + final long relativeTimeAction; + final Object toSend; + final long size; + final ChannelPromise promise; + + private ToSend(final long delay, final Object toSend, final long size, final ChannelPromise promise) { + relativeTimeAction = delay; + this.toSend = toSend; + this.size = size; + this.promise = promise; + } + } + + @Override + void submitWrite(final ChannelHandlerContext ctx, final Object msg, + final long size, final long writedelay, final long now, + final ChannelPromise promise) { + Channel channel = ctx.channel(); + Integer key = channel.hashCode(); + PerChannel perChannel = channelQueues.get(key); + if (perChannel == null) { + // in case write occurs before handlerAdded is raised for this handler + // imply a synchronized only if needed + perChannel = getOrSetPerChannel(ctx); + } + final ToSend newToSend; + long delay = writedelay; + boolean globalSizeExceeded = false; + // write operations need synchronization + synchronized (perChannel) { + if (writedelay == 0 && 
perChannel.messagesQueue.isEmpty()) { + trafficCounter.bytesRealWriteFlowControl(size); + ctx.write(msg, promise); + perChannel.lastWriteTimestamp = now; + return; + } + if (delay > maxTime && now + delay - perChannel.lastWriteTimestamp > maxTime) { + delay = maxTime; + } + newToSend = new ToSend(delay + now, msg, size, promise); + perChannel.messagesQueue.addLast(newToSend); + perChannel.queueSize += size; + queuesSize.addAndGet(size); + checkWriteSuspend(ctx, delay, perChannel.queueSize); + if (queuesSize.get() > maxGlobalWriteSize) { + globalSizeExceeded = true; + } + } + if (globalSizeExceeded) { + setUserDefinedWritability(ctx, false); + } + final long futureNow = newToSend.relativeTimeAction; + final PerChannel forSchedule = perChannel; + ctx.executor().schedule(new Runnable() { + @Override + public void run() { + sendAllValid(ctx, forSchedule, futureNow); + } + }, delay, TimeUnit.MILLISECONDS); + } + + private void sendAllValid(final ChannelHandlerContext ctx, final PerChannel perChannel, final long now) { + // write operations need synchronization + synchronized (perChannel) { + ToSend newToSend = perChannel.messagesQueue.pollFirst(); + for (; newToSend != null; newToSend = perChannel.messagesQueue.pollFirst()) { + if (newToSend.relativeTimeAction <= now) { + long size = newToSend.size; + trafficCounter.bytesRealWriteFlowControl(size); + perChannel.queueSize -= size; + queuesSize.addAndGet(-size); + ctx.write(newToSend.toSend, newToSend.promise); + perChannel.lastWriteTimestamp = now; + } else { + perChannel.messagesQueue.addFirst(newToSend); + break; + } + } + if (perChannel.messagesQueue.isEmpty()) { + releaseWriteSuspended(ctx); + } + } + ctx.flush(); + } +} diff --git a/netty-handler/src/main/java/io/netty/handler/traffic/TrafficCounter.java b/netty-handler/src/main/java/io/netty/handler/traffic/TrafficCounter.java new file mode 100644 index 0000000..cc4fd0b --- /dev/null +++ b/netty-handler/src/main/java/io/netty/handler/traffic/TrafficCounter.java @@ 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.traffic;

import static io.netty.util.internal.ObjectUtil.checkNotNull;
import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.ScheduledFuture;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Counts the number of read and written bytes for rate-limiting traffic.
 *
 * <p>It computes the statistics for both inbound and outbound traffic periodically at the given
 * {@code checkInterval}, and calls the {@link AbstractTrafficShapingHandler#doAccounting(TrafficCounter)} method back.
 * If the {@code checkInterval} is {@code 0}, no accounting will be done and statistics will only be computed at each
 * receive or write operation.</p>
 */
public class TrafficCounter {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(TrafficCounter.class);

    /**
     * @return the time in ms using nanoTime, so not real EPOCH time but elapsed time in ms.
     */
    public static long milliSecondFromNano() {
        return System.nanoTime() / 1000000;
    }

    /**
     * Current written bytes
     */
    private final AtomicLong currentWrittenBytes = new AtomicLong();

    /**
     * Current read bytes
     */
    private final AtomicLong currentReadBytes = new AtomicLong();

    /**
     * Last writing time during current check interval
     */
    private long writingTime;

    /**
     * Last reading delay during current check interval
     */
    private long readingTime;

    /**
     * Long life written bytes
     */
    private final AtomicLong cumulativeWrittenBytes = new AtomicLong();

    /**
     * Long life read bytes
     */
    private final AtomicLong cumulativeReadBytes = new AtomicLong();

    /**
     * Last Time where cumulative bytes where reset to zero: this time is a real EPOC time (informative only)
     */
    private long lastCumulativeTime;

    /**
     * Last writing bandwidth
     */
    private long lastWriteThroughput;

    /**
     * Last reading bandwidth
     */
    private long lastReadThroughput;

    /**
     * Last Time Check taken
     */
    final AtomicLong lastTime = new AtomicLong();

    /**
     * Last written bytes number during last check interval
     */
    private volatile long lastWrittenBytes;

    /**
     * Last read bytes number during last check interval
     */
    private volatile long lastReadBytes;

    /**
     * Last future writing time during last check interval
     */
    private volatile long lastWritingTime;

    /**
     * Last reading time during last check interval
     */
    private volatile long lastReadingTime;

    /**
     * Real written bytes
     */
    private final AtomicLong realWrittenBytes = new AtomicLong();

    /**
     * Real writing bandwidth
     */
    private long realWriteThroughput;

    /**
     * Delay between two captures (default 1 s)
     */
    final AtomicLong checkInterval = new AtomicLong(
            AbstractTrafficShapingHandler.DEFAULT_CHECK_INTERVAL);

    /**
     * Name of this Monitor
     */
    final String name;

    /**
     * The associated TrafficShapingHandler
     */
    final AbstractTrafficShapingHandler trafficShapingHandler;

    /**
     * Executor that will run the monitor
     */
    final ScheduledExecutorService executor;

    /**
     * Monitor created once in start()
     */
    Runnable monitor;

    /**
     * used in stop() to cancel the timer.
     * Parameterized as {@code ScheduledFuture<?>} (the original declaration used the raw type).
     */
    volatile ScheduledFuture<?> scheduledFuture;

    /**
     * Is Monitor active
     */
    volatile boolean monitorActive;

    /**
     * Class to implement monitoring at fix delay
     */
    private final class TrafficMonitoringTask implements Runnable {
        @Override
        public void run() {
            if (!monitorActive) {
                return;
            }
            resetAccounting(milliSecondFromNano());
            if (trafficShapingHandler != null) {
                trafficShapingHandler.doAccounting(TrafficCounter.this);
            }
        }
    }

    /**
     * Start the monitoring process.
     */
    public synchronized void start() {
        if (monitorActive) {
            return;
        }
        lastTime.set(milliSecondFromNano());
        long localCheckInterval = checkInterval.get();
        // if executor is null, it means it is piloted by a GlobalChannelTrafficCounter, so no executor
        if (localCheckInterval > 0 && executor != null) {
            monitorActive = true;
            monitor = new TrafficMonitoringTask();
            scheduledFuture =
                executor.scheduleAtFixedRate(monitor, 0, localCheckInterval, TimeUnit.MILLISECONDS);
        }
    }

    /**
     * Stop the monitoring process.
     */
    public synchronized void stop() {
        if (!monitorActive) {
            return;
        }
        monitorActive = false;
        resetAccounting(milliSecondFromNano());
        if (trafficShapingHandler != null) {
            trafficShapingHandler.doAccounting(this);
        }
        if (scheduledFuture != null) {
            scheduledFuture.cancel(true);
        }
    }

    /**
     * Reset the accounting on Read and Write.
     *
     * @param newLastTime the milliseconds unix timestamp that we should be considered up-to-date for.
     */
    synchronized void resetAccounting(long newLastTime) {
        long interval = newLastTime - lastTime.getAndSet(newLastTime);
        if (interval == 0) {
            // nothing to do
            return;
        }
        if (logger.isDebugEnabled() && interval > checkInterval() << 1) {
            logger.debug("Acct schedule not ok: " + interval + " > 2*" + checkInterval() + " from " + name);
        }
        lastReadBytes = currentReadBytes.getAndSet(0);
        lastWrittenBytes = currentWrittenBytes.getAndSet(0);
        // nb byte / checkInterval in ms * 1000 (1s)
        lastReadThroughput = lastReadBytes * 1000 / interval;
        lastWriteThroughput = lastWrittenBytes * 1000 / interval;
        realWriteThroughput = realWrittenBytes.getAndSet(0) * 1000 / interval;
        lastWritingTime = Math.max(lastWritingTime, writingTime);
        lastReadingTime = Math.max(lastReadingTime, readingTime);
    }

    /**
     * Constructor with the {@link ScheduledExecutorService} to use, its name, and the
     * checkInterval between two computations in milliseconds.
     *
     * @param executor
     *            the underlying executor service for scheduling checks, might be null when used
     *            from {@link GlobalChannelTrafficCounter}.
     * @param name
     *            the name given to this monitor.
     * @param checkInterval
     *            the checkInterval in millisecond between two computations.
     */
    public TrafficCounter(ScheduledExecutorService executor, String name, long checkInterval) {
        this.name = checkNotNull(name, "name");
        trafficShapingHandler = null;
        this.executor = executor;
        init(checkInterval);
    }

    /**
     * Constructor with the {@link AbstractTrafficShapingHandler} that hosts it, the executor to use, its
     * name, the checkInterval between two computations in millisecond.
     *
     * @param trafficShapingHandler
     *            the associated AbstractTrafficShapingHandler.
     * @param executor
     *            the underlying executor service for scheduling checks, might be null when used
     *            from {@link GlobalChannelTrafficCounter}.
     * @param name
     *            the name given to this monitor.
     * @param checkInterval
     *            the checkInterval in millisecond between two computations.
     */
    public TrafficCounter(
            AbstractTrafficShapingHandler trafficShapingHandler, ScheduledExecutorService executor,
            String name, long checkInterval) {
        this.name = checkNotNull(name, "name");
        this.trafficShapingHandler = checkNotNullWithIAE(trafficShapingHandler, "trafficShapingHandler");
        this.executor = executor;
        init(checkInterval);
    }

    private void init(long checkInterval) {
        // absolute time: informative only
        lastCumulativeTime = System.currentTimeMillis();
        writingTime = milliSecondFromNano();
        readingTime = writingTime;
        lastWritingTime = writingTime;
        lastReadingTime = writingTime;
        configure(checkInterval);
    }

    /**
     * Change checkInterval between two computations in millisecond.
     *
     * @param newCheckInterval The new check interval (in milliseconds)
     */
    public void configure(long newCheckInterval) {
        // round down to a multiple of 10 ms
        long newInterval = newCheckInterval / 10 * 10;
        if (checkInterval.getAndSet(newInterval) != newInterval) {
            if (newInterval <= 0) {
                stop();
                // No more active monitoring
                lastTime.set(milliSecondFromNano());
            } else {
                // Restart
                stop();
                start();
            }
        }
    }

    /**
     * Computes counters for Read.
     *
     * @param recv
     *            the size in bytes to read
     */
    void bytesRecvFlowControl(long recv) {
        currentReadBytes.addAndGet(recv);
        cumulativeReadBytes.addAndGet(recv);
    }

    /**
     * Computes counters for Write.
     *
     * @param write
     *            the size in bytes to write
     */
    void bytesWriteFlowControl(long write) {
        currentWrittenBytes.addAndGet(write);
        cumulativeWrittenBytes.addAndGet(write);
    }

    /**
     * Computes counters for Real Write.
     *
     * @param write
     *            the size in bytes to write
     */
    void bytesRealWriteFlowControl(long write) {
        realWrittenBytes.addAndGet(write);
    }

    /**
     * @return the current checkInterval between two computations of traffic counter
     *         in millisecond.
     */
    public long checkInterval() {
        return checkInterval.get();
    }

    /**
     * @return the Read Throughput in bytes/s computes in the last check interval.
     */
    public long lastReadThroughput() {
        return lastReadThroughput;
    }

    /**
     * @return the Write Throughput in bytes/s computes in the last check interval.
     */
    public long lastWriteThroughput() {
        return lastWriteThroughput;
    }

    /**
     * @return the number of bytes read during the last check Interval.
     */
    public long lastReadBytes() {
        return lastReadBytes;
    }

    /**
     * @return the number of bytes written during the last check Interval.
     */
    public long lastWrittenBytes() {
        return lastWrittenBytes;
    }

    /**
     * @return the current number of bytes read since the last checkInterval.
     */
    public long currentReadBytes() {
        return currentReadBytes.get();
    }

    /**
     * @return the current number of bytes written since the last check Interval.
     */
    public long currentWrittenBytes() {
        return currentWrittenBytes.get();
    }

    /**
     * @return the Time in millisecond of the last check as of System.currentTimeMillis().
     */
    public long lastTime() {
        return lastTime.get();
    }

    /**
     * @return the cumulativeWrittenBytes
     */
    public long cumulativeWrittenBytes() {
        return cumulativeWrittenBytes.get();
    }

    /**
     * @return the cumulativeReadBytes
     */
    public long cumulativeReadBytes() {
        return cumulativeReadBytes.get();
    }

    /**
     * @return the lastCumulativeTime in millisecond as of System.currentTimeMillis()
     *         when the cumulative counters were reset to 0.
     */
    public long lastCumulativeTime() {
        return lastCumulativeTime;
    }

    /**
     * @return the realWrittenBytes
     */
    public AtomicLong getRealWrittenBytes() {
        return realWrittenBytes;
    }

    /**
     * @return the realWriteThroughput
     */
    public long getRealWriteThroughput() {
        return realWriteThroughput;
    }

    /**
     * Reset both read and written cumulative bytes counters and the associated absolute time
     * from System.currentTimeMillis().
     */
    public void resetCumulativeTime() {
        lastCumulativeTime = System.currentTimeMillis();
        cumulativeReadBytes.set(0);
        cumulativeWrittenBytes.set(0);
    }

    /**
     * @return the name of this TrafficCounter.
     */
    public String name() {
        return name;
    }

    /**
     * Returns the time to wait (if any) for the given length message, using the given limitTraffic and the max wait
     * time.
     *
     * @param size
     *            the recv size
     * @param limitTraffic
     *            the traffic limit in bytes per second.
     * @param maxTime
     *            the max time in ms to wait in case of excess of traffic.
     * @return the current time to wait (in ms) if needed for Read operation.
     */
    @Deprecated
    public long readTimeToWait(final long size, final long limitTraffic, final long maxTime) {
        return readTimeToWait(size, limitTraffic, maxTime, milliSecondFromNano());
    }

    /**
     * Returns the time to wait (if any) for the given length message, using the given limitTraffic and the max wait
     * time.
     *
     * @param size
     *            the recv size
     * @param limitTraffic
     *            the traffic limit in bytes per second
     * @param maxTime
     *            the max time in ms to wait in case of excess of traffic.
     * @param now the current time
     * @return the current time to wait (in ms) if needed for Read operation.
     */
    public long readTimeToWait(final long size, final long limitTraffic, final long maxTime, final long now) {
        bytesRecvFlowControl(size);
        if (size == 0 || limitTraffic == 0) {
            return 0;
        }
        final long lastTimeCheck = lastTime.get();
        long sum = currentReadBytes.get();
        long localReadingTime = readingTime;
        long lastRB = lastReadBytes;
        final long interval = now - lastTimeCheck;
        long pastDelay = Math.max(lastReadingTime - lastTimeCheck, 0);
        if (interval > AbstractTrafficShapingHandler.MINIMAL_WAIT) {
            // Enough interval time to compute shaping
            long time = sum * 1000 / limitTraffic - interval + pastDelay;
            if (time > AbstractTrafficShapingHandler.MINIMAL_WAIT) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Time: " + time + ':' + sum + ':' + interval + ':' + pastDelay);
                }
                if (time > maxTime && now + time - localReadingTime > maxTime) {
                    time = maxTime;
                }
                readingTime = Math.max(localReadingTime, now + time);
                return time;
            }
            readingTime = Math.max(localReadingTime, now);
            return 0;
        }
        // take the last read interval check to get enough interval time
        long lastsum = sum + lastRB;
        long lastinterval = interval + checkInterval.get();
        long time = lastsum * 1000 / limitTraffic - lastinterval + pastDelay;
        if (time > AbstractTrafficShapingHandler.MINIMAL_WAIT) {
            if (logger.isDebugEnabled()) {
                logger.debug("Time: " + time + ':' + lastsum + ':' + lastinterval + ':' + pastDelay);
            }
            if (time > maxTime && now + time - localReadingTime > maxTime) {
                time = maxTime;
            }
            readingTime = Math.max(localReadingTime, now + time);
            return time;
        }
        readingTime = Math.max(localReadingTime, now);
        return 0;
    }

    /**
     * Returns the time to wait (if any) for the given length message, using the given limitTraffic and
     * the max wait time.
     *
     * @param size
     *            the write size
     * @param limitTraffic
     *            the traffic limit in bytes per second.
     * @param maxTime
     *            the max time in ms to wait in case of excess of traffic.
     * @return the current time to wait (in ms) if needed for Write operation.
     */
    @Deprecated
    public long writeTimeToWait(final long size, final long limitTraffic, final long maxTime) {
        return writeTimeToWait(size, limitTraffic, maxTime, milliSecondFromNano());
    }

    /**
     * Returns the time to wait (if any) for the given length message, using the given limitTraffic and
     * the max wait time.
     *
     * @param size
     *            the write size
     * @param limitTraffic
     *            the traffic limit in bytes per second.
     * @param maxTime
     *            the max time in ms to wait in case of excess of traffic.
     * @param now the current time
     * @return the current time to wait (in ms) if needed for Write operation.
     */
    public long writeTimeToWait(final long size, final long limitTraffic, final long maxTime, final long now) {
        bytesWriteFlowControl(size);
        if (size == 0 || limitTraffic == 0) {
            return 0;
        }
        final long lastTimeCheck = lastTime.get();
        long sum = currentWrittenBytes.get();
        long lastWB = lastWrittenBytes;
        long localWritingTime = writingTime;
        long pastDelay = Math.max(lastWritingTime - lastTimeCheck, 0);
        final long interval = now - lastTimeCheck;
        if (interval > AbstractTrafficShapingHandler.MINIMAL_WAIT) {
            // Enough interval time to compute shaping
            long time = sum * 1000 / limitTraffic - interval + pastDelay;
            if (time > AbstractTrafficShapingHandler.MINIMAL_WAIT) {
                if (logger.isDebugEnabled()) {
                    logger.debug("Time: " + time + ':' + sum + ':' + interval + ':' + pastDelay);
                }
                if (time > maxTime && now + time - localWritingTime > maxTime) {
                    time = maxTime;
                }
                writingTime = Math.max(localWritingTime, now + time);
                return time;
            }
            writingTime = Math.max(localWritingTime, now);
            return 0;
        }
        // take the last write interval check to get enough interval time
        long lastsum = sum + lastWB;
        long lastinterval = interval + checkInterval.get();
        long time = lastsum * 1000 / limitTraffic - lastinterval + pastDelay;
        if (time > AbstractTrafficShapingHandler.MINIMAL_WAIT) {
            if (logger.isDebugEnabled()) {
                logger.debug("Time: " + time + ':' + lastsum + ':' + lastinterval + ':' + pastDelay);
            }
            if (time > maxTime && now + time - localWritingTime > maxTime) {
                time = maxTime;
            }
            writingTime = Math.max(localWritingTime, now + time);
            return time;
        }
        writingTime = Math.max(localWritingTime, now);
        return 0;
    }

    @Override
    public String toString() {
        return new StringBuilder(165).append("Monitor ").append(name)
            .append(" Current Speed Read: ").append(lastReadThroughput >> 10).append(" KB/s, ")
            .append("Asked Write: ").append(lastWriteThroughput >> 10).append(" KB/s, ")
            .append("Real Write: ").append(realWriteThroughput >> 10).append(" KB/s, ")
            .append("Current Read: ").append(currentReadBytes.get() >> 10).append(" KB, ")
            .append("Current asked Write: ").append(currentWrittenBytes.get() >> 10).append(" KB, ")
            .append("Current real Write: ").append(realWrittenBytes.get() >> 10).append(" KB").toString();
    }
}
+ */ + +/** + * Implementation of a Traffic Shaping Handler and Dynamic Statistics. + * + *

 * <p>The main goal of this package is to allow you to shape the traffic (bandwidth limitation),
 * but also to get statistics on how many bytes are read or written. Both functions can
 * be active or inactive (traffic or statistics).</p>
 *
 * <p>Two classes implement this behavior:</p>
 * <ul>
 *   <li>{@link io.netty.handler.traffic.TrafficCounter}: this class implements the counters needed by the
 *   handlers. It can be accessed to get some extra information like the read or write bytes since last check,
 *   the read and write bandwidth from last check...</li>
 *
 *   <li>{@link io.netty.handler.traffic.AbstractTrafficShapingHandler}: this abstract class implements
 *   the kernel of traffic shaping. It could be extended to fit your needs. Two classes are proposed as default
 *   implementations: see {@link io.netty.handler.traffic.ChannelTrafficShapingHandler} and
 *   {@link io.netty.handler.traffic.GlobalTrafficShapingHandler} respectively for Channel traffic shaping and
 *   global traffic shaping.</li>
 * </ul>
 *
 * <p>Both inbound and outbound traffic can be shaped independently. This is done by either passing in
 * the desired limiting values to the constructors of both the Channel and Global traffic shaping handlers,
 * or by calling the {@code configure} method on the {@link io.netty.handler.traffic.AbstractTrafficShapingHandler}.
 * A value of 0 for either parameter indicates that there should be no limitation. This allows you to monitor the
 * incoming and outgoing traffic without shaping.</p>
 *
 * <p>To activate or deactivate the statistics, you can adjust the delay to a low (suggested not less than 200ms
 * for efficiency reasons) or a high value (let us say 24H in milliseconds is huge enough to not get the problem)
 * or even using 0 which means no computation will be done.</p>
 *
 * <p>If you want to do anything with these statistics, just override the {@code doAccounting} method.
 * This interval can be changed either from the method {@code configure}
 * in {@link io.netty.handler.traffic.AbstractTrafficShapingHandler} or directly using the method
 * {@code configure} of {@link io.netty.handler.traffic.TrafficCounter}.</p>
 *
 * <p>Note that a new {@link io.netty.handler.traffic.ChannelTrafficShapingHandler} must be created
 * for each new channel, but only one {@link io.netty.handler.traffic.GlobalTrafficShapingHandler} must be created
 * for all channels.</p>
 *
 * <p>Note also that you can create different GlobalTrafficShapingHandler if you want to separate classes of
 * channels (for instance either from business point of view or from bind address point of view).</p>

+ */ +package io.netty.handler.traffic; + diff --git a/netty-handler/src/main/java/module-info.java b/netty-handler/src/main/java/module-info.java new file mode 100644 index 0000000..a51c4f6 --- /dev/null +++ b/netty-handler/src/main/java/module-info.java @@ -0,0 +1,16 @@ +module org.xbib.io.netty.handler { + exports io.netty.handler.address; + exports io.netty.handler.flow; + exports io.netty.handler.flush; + exports io.netty.handler.ipfilter; + exports io.netty.handler.logging; + exports io.netty.handler.pcap; + exports io.netty.handler.stream; + exports io.netty.handler.timeout; + exports io.netty.handler.traffic; + requires org.xbib.io.netty.buffer; + requires org.xbib.io.netty.channel; + requires org.xbib.io.netty.handler.codec; + requires org.xbib.io.netty.resolver; + requires org.xbib.io.netty.util; +} diff --git a/netty-handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java new file mode 100644 index 0000000..df0352e --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/address/DynamicAddressConnectHandlerTest.java @@ -0,0 +1,107 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.address; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import org.junit.jupiter.api.Test; + +import java.net.SocketAddress; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; + +public class DynamicAddressConnectHandlerTest { + private static final SocketAddress LOCAL = new SocketAddress() { }; + private static final SocketAddress LOCAL_NEW = new SocketAddress() { }; + private static final SocketAddress REMOTE = new SocketAddress() { }; + private static final SocketAddress REMOTE_NEW = new SocketAddress() { }; + @Test + public void testReplaceAddresses() { + + EmbeddedChannel channel = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void connect(ChannelHandlerContext ctx, SocketAddress remoteAddress, + SocketAddress localAddress, ChannelPromise promise) { + try { + assertSame(REMOTE_NEW, remoteAddress); + assertSame(LOCAL_NEW, localAddress); + promise.setSuccess(); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }, new DynamicAddressConnectHandler() { + @Override + protected SocketAddress localAddress(SocketAddress remoteAddress, SocketAddress localAddress) { + assertSame(REMOTE, remoteAddress); + assertSame(LOCAL, localAddress); + return LOCAL_NEW; + } + + @Override + protected SocketAddress remoteAddress(SocketAddress remoteAddress, SocketAddress localAddress) { + assertSame(REMOTE, remoteAddress); + assertSame(LOCAL, localAddress); + return REMOTE_NEW; + } + }); + channel.connect(REMOTE, LOCAL).syncUninterruptibly(); + assertNull(channel.pipeline().get(DynamicAddressConnectHandler.class)); + assertFalse(channel.finish()); + } + + @Test + 
public void testLocalAddressThrows() { + testThrows0(true); + } + + @Test + public void testRemoteAddressThrows() { + testThrows0(false); + } + + private static void testThrows0(final boolean localThrows) { + final IllegalStateException exception = new IllegalStateException(); + + EmbeddedChannel channel = new EmbeddedChannel(new DynamicAddressConnectHandler() { + @Override + protected SocketAddress localAddress( + SocketAddress remoteAddress, SocketAddress localAddress) throws Exception { + if (localThrows) { + throw exception; + } + return super.localAddress(remoteAddress, localAddress); + } + + @Override + protected SocketAddress remoteAddress(SocketAddress remoteAddress, SocketAddress localAddress) + throws Exception { + if (!localThrows) { + throw exception; + } + return super.remoteAddress(remoteAddress, localAddress); + } + }); + assertSame(exception, channel.connect(REMOTE, LOCAL).cause()); + assertNotNull(channel.pipeline().get(DynamicAddressConnectHandler.class)); + assertFalse(channel.finish()); + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/address/ResolveAddressHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/address/ResolveAddressHandlerTest.java new file mode 100644 index 0000000..33786f7 --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/address/ResolveAddressHandlerTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.address; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.DefaultEventLoopGroup; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.local.LocalAddress; +import io.netty.channel.local.LocalChannel; +import io.netty.channel.local.LocalServerChannel; +import io.netty.resolver.AbstractAddressResolver; +import io.netty.resolver.AddressResolver; +import io.netty.resolver.AddressResolverGroup; +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Promise; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + + +import java.net.SocketAddress; +import java.net.UnknownHostException; +import java.util.List; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class ResolveAddressHandlerTest { + + private static final LocalAddress UNRESOLVED = new LocalAddress("unresolved-" + UUID.randomUUID()); + private static final LocalAddress RESOLVED = new LocalAddress("resolved-" + UUID.randomUUID()); + private static final Exception ERROR = new UnknownHostException(); + + private static EventLoopGroup group; + + @BeforeAll + public static void createEventLoop() { + group = new DefaultEventLoopGroup(); + } + + @AfterAll + public static void destroyEventLoop() { + if (group != null) { + group.shutdownGracefully(); + } + } + + @Test + public void testResolveSuccessful() { + testResolve(false); + } + + @Test + public void testResolveFails() { + testResolve(true); 
+ } + + private static void testResolve(boolean fail) { + AddressResolverGroup resolverGroup = new TestResolverGroup(fail); + Bootstrap cb = new Bootstrap(); + cb.group(group).channel(LocalChannel.class).handler(new ResolveAddressHandler(resolverGroup)); + + ServerBootstrap sb = new ServerBootstrap(); + sb.group(group) + .channel(LocalServerChannel.class) + .childHandler(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.close(); + } + }); + + // Start server + Channel sc = sb.bind(RESOLVED).syncUninterruptibly().channel(); + ChannelFuture future = cb.connect(UNRESOLVED).awaitUninterruptibly(); + try { + if (fail) { + assertSame(ERROR, future.cause()); + } else { + assertTrue(future.isSuccess()); + } + future.channel().close().syncUninterruptibly(); + } finally { + future.channel().close().syncUninterruptibly(); + sc.close().syncUninterruptibly(); + resolverGroup.close(); + } + } + + private static final class TestResolverGroup extends AddressResolverGroup { + private final boolean fail; + + TestResolverGroup(boolean fail) { + this.fail = fail; + } + + @Override + protected AddressResolver newResolver(EventExecutor executor) { + return new AbstractAddressResolver(executor) { + @Override + protected boolean doIsResolved(SocketAddress address) { + return address == RESOLVED; + } + + @Override + protected void doResolve(SocketAddress unresolvedAddress, Promise promise) { + assertSame(UNRESOLVED, unresolvedAddress); + if (fail) { + promise.setFailure(ERROR); + } else { + promise.setSuccess(RESOLVED); + } + } + + @Override + protected void doResolveAll(SocketAddress unresolvedAddress, Promise> promise) { + fail(); + } + }; + } + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/flow/FlowControlHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/flow/FlowControlHandlerTest.java new file mode 100644 index 0000000..dd45e64 --- /dev/null +++ 
b/netty-handler/src/test/java/io/netty/handler/flow/FlowControlHandlerTest.java @@ -0,0 +1,680 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version + * 2.0 (the "License"); you may not use this file except in compliance with the + * License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.handler.flow; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelConfig; +import io.netty.channel.ChannelDuplexHandler; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.handler.codec.ByteToMessageDecoder; +import io.netty.handler.timeout.IdleStateEvent; +import io.netty.handler.timeout.IdleStateHandler; +import io.netty.util.ReferenceCountUtil; +import org.junit.jupiter.api.AfterAll; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; + +import java.net.SocketAddress; +import java.util.List; +import java.util.Queue; +import 
java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Exchanger; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.atomic.AtomicReference; + +import static java.util.concurrent.TimeUnit.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class FlowControlHandlerTest { + private static EventLoopGroup GROUP; + + @BeforeAll + public static void init() { + GROUP = new NioEventLoopGroup(); + } + + @AfterAll + public static void destroy() { + GROUP.shutdownGracefully(); + } + + /** + * The {@link OneByteToThreeStringsDecoder} decodes this {@code byte[]} into three messages. + */ + private static ByteBuf newOneMessage() { + return Unpooled.wrappedBuffer(new byte[]{ 1 }); + } + + private static Channel newServer(final boolean autoRead, final ChannelHandler... 
handlers) { + assertTrue(handlers.length >= 1); + + ServerBootstrap serverBootstrap = new ServerBootstrap(); + serverBootstrap.group(GROUP) + .channel(NioServerSocketChannel.class) + .childOption(ChannelOption.AUTO_READ, autoRead) + .childHandler(new ChannelInitializer() { + @Override + protected void initChannel(Channel ch) { + ChannelPipeline pipeline = ch.pipeline(); + pipeline.addLast(new OneByteToThreeStringsDecoder()); + pipeline.addLast(handlers); + } + }); + + return serverBootstrap.bind(0) + .syncUninterruptibly() + .channel(); + } + + private static Channel newClient(SocketAddress server) { + Bootstrap bootstrap = new Bootstrap(); + + bootstrap.group(GROUP) + .channel(NioSocketChannel.class) + .option(ChannelOption.CONNECT_TIMEOUT_MILLIS, 1000) + .handler(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + fail("In this test the client is never receiving a message from the server."); + } + }); + + return bootstrap.connect(server) + .syncUninterruptibly() + .channel(); + } + + /** + * This test demonstrates the default behavior if auto reading + * is turned on from the get-go and you're trying to turn it off + * once you've received your first message. + * + * NOTE: This test waits for the client to disconnect which is + * interpreted as the signal that all {@code byte}s have been + * transferred to the server. + */ + @Test + public void testAutoReadingOn() throws Exception { + final CountDownLatch latch = new CountDownLatch(3); + + ChannelInboundHandlerAdapter handler = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.release(msg); + // We're turning off auto reading in the hope that no + // new messages are being sent but that is not true. 
+ ctx.channel().config().setAutoRead(false); + + latch.countDown(); + } + }; + + Channel server = newServer(true, handler); + Channel client = newClient(server.localAddress()); + + try { + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // We received three messages even through auto reading + // was turned off after we received the first message. + assertTrue(latch.await(1L, SECONDS)); + } finally { + client.close(); + server.close(); + } + } + + /** + * This test demonstrates the default behavior if auto reading + * is turned off from the get-go and you're calling read() in + * the hope that only one message will be returned. + * + * NOTE: This test waits for the client to disconnect which is + * interpreted as the signal that all {@code byte}s have been + * transferred to the server. + */ + @Test + public void testAutoReadingOff() throws Exception { + final Exchanger peerRef = new Exchanger(); + final CountDownLatch latch = new CountDownLatch(3); + + ChannelInboundHandlerAdapter handler = new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + peerRef.exchange(ctx.channel(), 1L, SECONDS); + ctx.fireChannelActive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.release(msg); + latch.countDown(); + } + }; + + Channel server = newServer(false, handler); + Channel client = newClient(server.localAddress()); + + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + // Write the message + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // Read the message + peer.read(); + + // We received all three messages but hoped that only one + // message was read because auto reading was off and we + // invoked the read() method only once. 
+ assertTrue(latch.await(1L, SECONDS)); + } finally { + client.close(); + server.close(); + } + } + + /** + * The {@link FlowControlHandler} will simply pass-through all messages + * if auto reading is on and remains on. + */ + @Test + public void testFlowAutoReadOn() throws Exception { + final CountDownLatch latch = new CountDownLatch(3); + final Exchanger peerRef = new Exchanger(); + + ChannelInboundHandlerAdapter handler = new ChannelDuplexHandler() { + + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + peerRef.exchange(ctx.channel(), 1L, SECONDS); + super.channelActive(ctx); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ReferenceCountUtil.release(msg); + latch.countDown(); + } + }; + + final FlowControlHandler flow = new FlowControlHandler(); + Channel server = newServer(true, flow, handler); + Channel client = newClient(server.localAddress()); + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + // Write the message + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // We should receive 3 messages + assertTrue(latch.await(1L, SECONDS)); + + assertTrue(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + } finally { + client.close(); + server.close(); + } + } + + /** + * The {@link FlowControlHandler} will pass down messages one by one + * if {@link ChannelConfig#setAutoRead(boolean)} is being toggled. 
+ */ + @Test + public void testFlowToggleAutoRead() throws Exception { + final Exchanger peerRef = new Exchanger(); + final CountDownLatch msgRcvLatch1 = new CountDownLatch(1); + final CountDownLatch msgRcvLatch2 = new CountDownLatch(1); + final CountDownLatch msgRcvLatch3 = new CountDownLatch(1); + final CountDownLatch setAutoReadLatch1 = new CountDownLatch(1); + final CountDownLatch setAutoReadLatch2 = new CountDownLatch(1); + + ChannelInboundHandlerAdapter handler = new ChannelInboundHandlerAdapter() { + private int msgRcvCount; + private int expectedMsgCount; + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + peerRef.exchange(ctx.channel(), 1L, SECONDS); + ctx.fireChannelActive(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws InterruptedException { + ReferenceCountUtil.release(msg); + + // Disable auto reading after each message + ctx.channel().config().setAutoRead(false); + + if (msgRcvCount++ != expectedMsgCount) { + return; + } + switch (msgRcvCount) { + case 1: + msgRcvLatch1.countDown(); + if (setAutoReadLatch1.await(1L, SECONDS)) { + ++expectedMsgCount; + } + break; + case 2: + msgRcvLatch2.countDown(); + if (setAutoReadLatch2.await(1L, SECONDS)) { + ++expectedMsgCount; + } + break; + default: + msgRcvLatch3.countDown(); + break; + } + } + }; + + final FlowControlHandler flow = new FlowControlHandler(); + Channel server = newServer(true, flow, handler); + Channel client = newClient(server.localAddress()); + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // channelRead(1) + assertTrue(msgRcvLatch1.await(1L, SECONDS)); + + // channelRead(2) + peer.config().setAutoRead(true); + setAutoReadLatch1.countDown(); + assertTrue(msgRcvLatch1.await(1L, SECONDS)); + + // channelRead(3) + peer.config().setAutoRead(true); + setAutoReadLatch2.countDown(); + 
assertTrue(msgRcvLatch3.await(1L, SECONDS)); + + assertTrue(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + } finally { + client.close(); + server.close(); + } + } + + /** + * The {@link FlowControlHandler} will pass down messages one by one + * if auto reading is off and the user is calling {@code read()} on + * their own. + */ + @Test + public void testFlowAutoReadOff() throws Exception { + final Exchanger peerRef = new Exchanger(); + final CountDownLatch msgRcvLatch1 = new CountDownLatch(1); + final CountDownLatch msgRcvLatch2 = new CountDownLatch(2); + final CountDownLatch msgRcvLatch3 = new CountDownLatch(3); + + ChannelInboundHandlerAdapter handler = new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelActive(); + peerRef.exchange(ctx.channel(), 1L, SECONDS); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + msgRcvLatch1.countDown(); + msgRcvLatch2.countDown(); + msgRcvLatch3.countDown(); + } + }; + + final FlowControlHandler flow = new FlowControlHandler(); + Channel server = newServer(false, flow, handler); + Channel client = newClient(server.localAddress()); + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + // Write the message + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // channelRead(1) + peer.read(); + assertTrue(msgRcvLatch1.await(1L, SECONDS)); + + // channelRead(2) + peer.read(); + assertTrue(msgRcvLatch2.await(1L, SECONDS)); + + // channelRead(3) + peer.read(); + assertTrue(msgRcvLatch3.await(1L, SECONDS)); + + assertTrue(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + } finally { + client.close(); + server.close(); + } + } + + /** + * The {@link FlowControlHandler} will not pass read events 
onto the + * pipeline when the user is calling {@code read()} on their own if the + * queue is not empty and auto-reading is turned off for the channel. + */ + @Test + public void testFlowAutoReadOffAndQueueNonEmpty() throws Exception { + final Exchanger peerRef = new Exchanger(); + final CountDownLatch msgRcvLatch1 = new CountDownLatch(1); + final CountDownLatch msgRcvLatch2 = new CountDownLatch(2); + final CountDownLatch msgRcvLatch3 = new CountDownLatch(3); + + ChannelInboundHandlerAdapter handler = new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelActive(); + peerRef.exchange(ctx.channel(), 1L, SECONDS); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + msgRcvLatch1.countDown(); + msgRcvLatch2.countDown(); + msgRcvLatch3.countDown(); + } + }; + + final FlowControlHandler flow = new FlowControlHandler(); + Channel server = newServer(false, flow, handler); + Channel client = newClient(server.localAddress()); + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + // Write the first message + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // channelRead(1) + peer.read(); + assertTrue(msgRcvLatch1.await(1L, SECONDS)); + assertFalse(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + + // Write the second message + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // channelRead(2) + peer.read(); + assertTrue(msgRcvLatch2.await(1L, SECONDS)); + + // channelRead(3) + peer.read(); + assertTrue(msgRcvLatch3.await(1L, SECONDS)); + + assertTrue(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + } finally { + client.close(); + server.close(); + } + } + + @Test + public void testReentranceNotCausesNPE() throws 
Throwable { + final Exchanger peerRef = new Exchanger(); + final CountDownLatch latch = new CountDownLatch(3); + final AtomicReference causeRef = new AtomicReference(); + ChannelInboundHandlerAdapter handler = new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + ctx.fireChannelActive(); + peerRef.exchange(ctx.channel(), 1L, SECONDS); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + latch.countDown(); + ctx.read(); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + causeRef.set(cause); + } + }; + + final FlowControlHandler flow = new FlowControlHandler(); + Channel server = newServer(false, flow, handler); + Channel client = newClient(server.localAddress()); + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + // Write the message + client.writeAndFlush(newOneMessage()) + .syncUninterruptibly(); + + // channelRead(1) + peer.read(); + assertTrue(latch.await(1L, SECONDS)); + + assertTrue(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + + Throwable cause = causeRef.get(); + if (cause != null) { + throw cause; + } + } finally { + client.close(); + server.close(); + } + } + + @Test + public void testSwallowedReadComplete() throws Exception { + final long delayMillis = 100; + final Queue userEvents = new LinkedBlockingQueue(); + final EmbeddedChannel channel = new EmbeddedChannel(false, false, + new FlowControlHandler(), + new IdleStateHandler(delayMillis, 0, 0, MILLISECONDS), + new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.fireChannelActive(); + ctx.read(); + } + + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.fireChannelRead(msg); + ctx.read(); + } + + @Override + public void 
channelReadComplete(ChannelHandlerContext ctx) { + ctx.fireChannelReadComplete(); + ctx.read(); + } + + @Override + public void userEventTriggered(ChannelHandlerContext ctx, Object evt) { + if (evt instanceof IdleStateEvent) { + userEvents.add((IdleStateEvent) evt); + } + ctx.fireUserEventTriggered(evt); + } + } + ); + + channel.config().setAutoRead(false); + assertFalse(channel.config().isAutoRead()); + + channel.register(); + + // Reset read timeout by some message + assertTrue(channel.writeInbound(Unpooled.EMPTY_BUFFER)); + channel.flushInbound(); + assertEquals(Unpooled.EMPTY_BUFFER, channel.readInbound()); + + // Emulate 'no more messages in NIO channel' on the next read attempt. + channel.flushInbound(); + assertNull(channel.readInbound()); + + Thread.sleep(delayMillis + 20L); + channel.runPendingTasks(); + assertEquals(IdleStateEvent.FIRST_READER_IDLE_STATE_EVENT, userEvents.poll()); + assertFalse(channel.finish()); + } + + @Test + public void testRemoveFlowControl() throws Exception { + final Exchanger peerRef = new Exchanger(); + + final CountDownLatch latch = new CountDownLatch(3); + + ChannelInboundHandlerAdapter handler = new ChannelDuplexHandler() { + @Override + public void channelActive(ChannelHandlerContext ctx) throws Exception { + peerRef.exchange(ctx.channel(), 1L, SECONDS); + //do the first read + ctx.read(); + super.channelActive(ctx); + } + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + latch.countDown(); + super.channelRead(ctx, msg); + } + }; + + final FlowControlHandler flow = new FlowControlHandler() { + private int num; + @Override + public void channelRead(final ChannelHandlerContext ctx, Object msg) throws Exception { + super.channelRead(ctx, msg); + ++num; + if (num >= 3) { + //We have received 3 messages. 
Remove myself later + final ChannelHandler handler = this; + ctx.channel().eventLoop().execute(new Runnable() { + @Override + public void run() { + ctx.pipeline().remove(handler); + } + }); + } + } + }; + ChannelInboundHandlerAdapter tail = new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + //consume this msg + ReferenceCountUtil.release(msg); + } + }; + + Channel server = newServer(false /* no auto read */, flow, handler, tail); + Channel client = newClient(server.localAddress()); + try { + // The client connection on the server side + Channel peer = peerRef.exchange(null, 1L, SECONDS); + + // Write one message + client.writeAndFlush(newOneMessage()).sync(); + + // We should receive 3 messages + assertTrue(latch.await(1L, SECONDS)); + assertTrue(peer.eventLoop().submit(new Callable() { + @Override + public Boolean call() { + return flow.isQueueEmpty(); + } + }).get()); + } finally { + client.close(); + server.close(); + } + } + + /** + * This is a fictional message decoder. It decodes each {@code byte} + * into three strings. 
+ */ + private static final class OneByteToThreeStringsDecoder extends ByteToMessageDecoder { + @Override + protected void decode(ChannelHandlerContext ctx, ByteBuf in, List out) { + for (int i = 0; i < in.readableBytes(); i++) { + out.add("1"); + out.add("2"); + out.add("3"); + } + in.readerIndex(in.readableBytes()); + } + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/flush/FlushConsolidationHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/flush/FlushConsolidationHandlerTest.java new file mode 100644 index 0000000..9d7cb3d --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/flush/FlushConsolidationHandlerTest.java @@ -0,0 +1,203 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.flush; + +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.GenericFutureListener; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class FlushConsolidationHandlerTest { + + private static final int EXPLICIT_FLUSH_AFTER_FLUSHES = 3; + + @Test + public void testFlushViaScheduledTask() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, true); + // Flushes should not go through immediately, as they're scheduled as an async task + channel.flush(); + assertEquals(0, flushCount.get()); + channel.flush(); + assertEquals(0, flushCount.get()); + // Trigger the execution of the async task + channel.runPendingTasks(); + assertEquals(1, flushCount.get()); + assertFalse(channel.finish()); + } + + @Test + public void testFlushViaThresholdOutsideOfReadLoop() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, true); + // After a given threshold, the async task should be bypassed and a flush should be triggered immediately + for (int i = 0; i < EXPLICIT_FLUSH_AFTER_FLUSHES; i++) { + channel.flush(); + } + assertEquals(1, flushCount.get()); + assertFalse(channel.finish()); + } + + @Test + public void testImmediateFlushOutsideOfReadLoop() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, false); + 
channel.flush(); + assertEquals(1, flushCount.get()); + assertFalse(channel.finish()); + } + + @Test + public void testFlushViaReadComplete() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, false); + // Flush should go through as there is no read loop in progress. + channel.flush(); + channel.runPendingTasks(); + assertEquals(1, flushCount.get()); + + // Simulate read loop; + channel.pipeline().fireChannelRead(1L); + assertEquals(1, flushCount.get()); + channel.pipeline().fireChannelRead(2L); + assertEquals(1, flushCount.get()); + assertNull(channel.readOutbound()); + channel.pipeline().fireChannelReadComplete(); + assertEquals(2, flushCount.get()); + // Now flush again as the read loop is complete. + channel.flush(); + channel.runPendingTasks(); + assertEquals(3, flushCount.get()); + assertEquals(1L, (Long) channel.readOutbound()); + assertEquals(2L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + assertFalse(channel.finish()); + } + + @Test + public void testFlushViaClose() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, false); + // Simulate read loop; + channel.pipeline().fireChannelRead(1L); + assertEquals(0, flushCount.get()); + assertNull(channel.readOutbound()); + channel.close(); + assertEquals(1, flushCount.get()); + assertEquals(1L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + assertFalse(channel.finish()); + } + + @Test + public void testFlushViaDisconnect() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, false); + // Simulate read loop; + channel.pipeline().fireChannelRead(1L); + assertEquals(0, flushCount.get()); + assertNull(channel.readOutbound()); + channel.disconnect(); + assertEquals(1, flushCount.get()); + assertEquals(1L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + 
assertFalse(channel.finish()); + } + + @Test + public void testFlushViaException() { + final AtomicInteger flushCount = new AtomicInteger(); + final EmbeddedChannel channel = newChannel(flushCount, false); + // Simulate read loop; + channel.pipeline().fireChannelRead(1L); + assertEquals(0, flushCount.get()); + assertNull(channel.readOutbound()); + channel.pipeline().fireExceptionCaught(new IllegalStateException()); + assertEquals(1, flushCount.get()); + assertEquals(1L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() throws Throwable { + channel.finish(); + } + }); + } + + @Test + public void testFlushViaRemoval() { + final AtomicInteger flushCount = new AtomicInteger(); + EmbeddedChannel channel = newChannel(flushCount, false); + // Simulate read loop; + channel.pipeline().fireChannelRead(1L); + assertEquals(0, flushCount.get()); + assertNull(channel.readOutbound()); + channel.pipeline().remove(FlushConsolidationHandler.class); + assertEquals(1, flushCount.get()); + assertEquals(1L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + assertFalse(channel.finish()); + } + + /** + * See https://github.com/netty/netty/issues/9923 + */ + @Test + public void testResend() throws Exception { + final AtomicInteger flushCount = new AtomicInteger(); + final EmbeddedChannel channel = newChannel(flushCount, true); + channel.writeAndFlush(1L).addListener(new GenericFutureListener>() { + @Override + public void operationComplete(Future future) throws Exception { + channel.writeAndFlush(1L); + } + }); + channel.flushOutbound(); + assertEquals(1L, (Long) channel.readOutbound()); + assertEquals(1L, (Long) channel.readOutbound()); + assertNull(channel.readOutbound()); + assertFalse(channel.finish()); + } + + private static EmbeddedChannel newChannel(final AtomicInteger flushCount, boolean consolidateWhenNoReadInProgress) { + return new 
EmbeddedChannel( + new ChannelOutboundHandlerAdapter() { + @Override + public void flush(ChannelHandlerContext ctx) throws Exception { + flushCount.incrementAndGet(); + ctx.flush(); + } + }, + new FlushConsolidationHandler(EXPLICIT_FLUSH_AFTER_FLUSHES, consolidateWhenNoReadInProgress), + new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception { + ctx.writeAndFlush(msg); + } + }); + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java b/netty-handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java new file mode 100644 index 0000000..6566c49 --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/ipfilter/IpSubnetFilterTest.java @@ -0,0 +1,220 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ipfilter; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandler; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.internal.SocketUtils; +import org.junit.jupiter.api.Test; + +import java.net.InetSocketAddress; +import java.net.SocketAddress; +import java.util.ArrayList; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class IpSubnetFilterTest { + + @Test + public void testIpv4DefaultRoute() { + IpSubnetFilterRule rule = new IpSubnetFilterRule("0.0.0.0", 0, IpFilterRuleType.ACCEPT); + assertTrue(rule.matches(newSockAddress("91.114.240.43"))); + assertTrue(rule.matches(newSockAddress("10.0.0.3"))); + assertTrue(rule.matches(newSockAddress("192.168.93.2"))); + } + + @Test + public void testIpv4SubnetMaskCorrectlyHandlesIpv6() { + IpSubnetFilterRule rule = new IpSubnetFilterRule("0.0.0.0", 0, IpFilterRuleType.ACCEPT); + assertFalse(rule.matches(newSockAddress("2001:db8:abcd:0000::1"))); + } + + @Test + public void testIpv6SubnetMaskCorrectlyHandlesIpv4() { + IpSubnetFilterRule rule = new IpSubnetFilterRule("::", 0, IpFilterRuleType.ACCEPT); + assertFalse(rule.matches(newSockAddress("91.114.240.43"))); + } + + @Test + public void testIp4SubnetFilterRule() throws Exception { + IpSubnetFilterRule rule = new IpSubnetFilterRule("192.168.56.1", 24, IpFilterRuleType.ACCEPT); + for (int i = 0; i <= 255; i++) { + assertTrue(rule.matches(newSockAddress(String.format("192.168.56.%d", i)))); + } + assertFalse(rule.matches(newSockAddress("192.168.57.1"))); + + rule = new IpSubnetFilterRule("91.114.240.1", 23, IpFilterRuleType.ACCEPT); + 
assertTrue(rule.matches(newSockAddress("91.114.240.43"))); + assertTrue(rule.matches(newSockAddress("91.114.240.255"))); + assertTrue(rule.matches(newSockAddress("91.114.241.193"))); + assertTrue(rule.matches(newSockAddress("91.114.241.254"))); + assertFalse(rule.matches(newSockAddress("91.115.241.2"))); + } + + @Test + public void testIp6SubnetFilterRule() { + IpSubnetFilterRule rule; + + rule = new IpSubnetFilterRule("2001:db8:abcd:0000::", 52, IpFilterRuleType.ACCEPT); + assertTrue(rule.matches(newSockAddress("2001:db8:abcd:0000::1"))); + assertTrue(rule.matches(newSockAddress("2001:db8:abcd:0fff:ffff:ffff:ffff:ffff"))); + assertFalse(rule.matches(newSockAddress("2001:db8:abcd:1000::"))); + } + + @Test + public void testIp6SubnetFilterDefaultRule() { + IpFilterRule rule = new IpSubnetFilterRule("::", 0, IpFilterRuleType.ACCEPT); + assertTrue(rule.matches(newSockAddress("7FFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF:FFFF"))); + assertTrue(rule.matches(newSockAddress("8000::"))); + } + + @Test + public void testIpFilterRuleHandler() throws Exception { + IpFilterRule filter0 = new IpFilterRule() { + @Override + public boolean matches(InetSocketAddress remoteAddress) { + return "192.168.57.1".equals(remoteAddress.getHostName()); + } + + @Override + public IpFilterRuleType ruleType() { + return IpFilterRuleType.REJECT; + } + }; + + RuleBasedIpFilter denyHandler = new RuleBasedIpFilter(filter0) { + private final byte[] message = {1, 2, 3, 4, 5, 6, 7}; + + @Override + protected ChannelFuture channelRejected(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) { + assertTrue(ctx.channel().isActive()); + assertTrue(ctx.channel().isWritable()); + assertEquals("192.168.57.1", remoteAddress.getHostName()); + + return ctx.writeAndFlush(Unpooled.wrappedBuffer(message)); + } + }; + EmbeddedChannel chDeny = newEmbeddedInetChannel("192.168.57.1", denyHandler); + ByteBuf out = chDeny.readOutbound(); + assertEquals(7, out.readableBytes()); + for (byte i = 1; i <= 7; i++) { + 
assertEquals(i, out.readByte()); + } + assertFalse(chDeny.isActive()); + assertFalse(chDeny.isOpen()); + + RuleBasedIpFilter allowHandler = new RuleBasedIpFilter(filter0) { + @Override + protected ChannelFuture channelRejected(ChannelHandlerContext ctx, InetSocketAddress remoteAddress) { + fail(); + return null; + } + }; + EmbeddedChannel chAllow = newEmbeddedInetChannel("192.168.57.2", allowHandler); + assertTrue(chAllow.isActive()); + assertTrue(chAllow.isOpen()); + } + + @Test + public void testUniqueIpFilterHandler() { + UniqueIpFilter handler = new UniqueIpFilter(); + + EmbeddedChannel ch1 = newEmbeddedInetChannel("91.92.93.1", handler); + assertTrue(ch1.isActive()); + EmbeddedChannel ch2 = newEmbeddedInetChannel("91.92.93.2", handler); + assertTrue(ch2.isActive()); + EmbeddedChannel ch3 = newEmbeddedInetChannel("91.92.93.1", handler); + assertFalse(ch3.isActive()); + + // false means that no data is left to read/write + assertFalse(ch1.finish()); + + EmbeddedChannel ch4 = newEmbeddedInetChannel("91.92.93.1", handler); + assertTrue(ch4.isActive()); + } + + @Test + public void testBinarySearch() { + List<IpSubnetFilterRule> ipSubnetFilterRuleList = new ArrayList<IpSubnetFilterRule>(); + ipSubnetFilterRuleList.add(buildRejectIP("1.2.3.4", 32)); + ipSubnetFilterRuleList.add(buildRejectIP("1.1.1.1", 8)); + ipSubnetFilterRuleList.add(buildRejectIP("200.200.200.200", 32)); + ipSubnetFilterRuleList.add(buildRejectIP("108.0.0.0", 4)); + ipSubnetFilterRuleList.add(buildRejectIP("10.10.10.10", 8)); + ipSubnetFilterRuleList.add(buildRejectIP("2001:db8:abcd:0000::", 52)); + + // 1.0.0.0/8 + EmbeddedChannel ch1 = newEmbeddedInetChannel("1.1.1.1", new IpSubnetFilter(ipSubnetFilterRuleList)); + assertFalse(ch1.isActive()); + assertTrue(ch1.close().isSuccess()); + + // Nothing applies here + EmbeddedChannel ch2 = newEmbeddedInetChannel("2.2.2.2", new IpSubnetFilter(ipSubnetFilterRuleList)); + assertTrue(ch2.isActive()); + assertTrue(ch2.close().isSuccess()); + + // 108.0.0.0/4 + EmbeddedChannel ch3 = 
newEmbeddedInetChannel("97.100.100.100", new IpSubnetFilter(ipSubnetFilterRuleList)); + assertFalse(ch3.isActive()); + assertTrue(ch3.close().isSuccess()); + + // 200.200.200.200/32 + EmbeddedChannel ch4 = newEmbeddedInetChannel("200.200.200.200", new IpSubnetFilter(ipSubnetFilterRuleList)); + assertFalse(ch4.isActive()); + assertTrue(ch4.close().isSuccess()); + + // Nothing applies here + EmbeddedChannel ch5 = newEmbeddedInetChannel("127.0.0.1", new IpSubnetFilter(ipSubnetFilterRuleList)); + assertTrue(ch5.isActive()); + assertTrue(ch5.close().isSuccess()); + + // 10.0.0.0/8 + EmbeddedChannel ch6 = newEmbeddedInetChannel("10.1.1.2", new IpSubnetFilter(ipSubnetFilterRuleList)); + assertFalse(ch6.isActive()); + assertTrue(ch6.close().isSuccess()); + + //2001:db8:abcd:0000::/52 + EmbeddedChannel ch7 = newEmbeddedInetChannel("2001:db8:abcd:1000::", + new IpSubnetFilter(ipSubnetFilterRuleList)); + assertFalse(ch7.isActive()); + assertTrue(ch7.close().isSuccess()); + } + + private static IpSubnetFilterRule buildRejectIP(String ipAddress, int mask) { + return new IpSubnetFilterRule(ipAddress, mask, IpFilterRuleType.REJECT); + } + + private static EmbeddedChannel newEmbeddedInetChannel(final String ipAddress, ChannelHandler... handlers) { + return new EmbeddedChannel(handlers) { + @Override + protected SocketAddress remoteAddress0() { + return isActive()? 
SocketUtils.socketAddress(ipAddress, 5421) : null; + } + }; + } + + private static InetSocketAddress newSockAddress(String ipAddress) { + return SocketUtils.socketAddress(ipAddress, 1234); + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java b/netty-handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java new file mode 100644 index 0000000..e0b753b --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/ipfilter/UniqueIpFilterTest.java @@ -0,0 +1,76 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.ipfilter; + +import io.netty.channel.ChannelHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.internal.SocketUtils; +import org.junit.jupiter.api.Test; + +import java.net.SocketAddress; +import java.util.concurrent.Callable; +import java.util.concurrent.CyclicBarrier; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UniqueIpFilterTest { + + @Test + public void testUniqueIpFilterHandler() throws ExecutionException, InterruptedException { + final CyclicBarrier barrier = new CyclicBarrier(2); + ExecutorService executorService = Executors.newFixedThreadPool(2); + try { + for (int round = 0; round < 10000; round++) { + final UniqueIpFilter ipFilter = new UniqueIpFilter(); + Future<EmbeddedChannel> future1 = newChannelAsync(barrier, executorService, ipFilter); + Future<EmbeddedChannel> future2 = newChannelAsync(barrier, executorService, ipFilter); + EmbeddedChannel ch1 = future1.get(); + EmbeddedChannel ch2 = future2.get(); + assertTrue(ch1.isActive() || ch2.isActive()); + assertFalse(ch1.isActive() && ch2.isActive()); + + barrier.reset(); + ch1.close().await(); + ch2.close().await(); + } + } finally { + executorService.shutdown(); + } + } + + private static Future<EmbeddedChannel> newChannelAsync(final CyclicBarrier barrier, + ExecutorService executorService, + final ChannelHandler... handler) { + return executorService.submit(new Callable<EmbeddedChannel>() { + @Override + public EmbeddedChannel call() throws Exception { + barrier.await(); + return new EmbeddedChannel(handler) { + @Override + protected SocketAddress remoteAddress0() { + return isActive() ? 
SocketUtils.socketAddress("91.92.93.1", 5421) : null; + } + }; + } + }); + } + +} diff --git a/netty-handler/src/test/java/io/netty/handler/pcap/CloseDetectingByteBufOutputStream.java b/netty-handler/src/test/java/io/netty/handler/pcap/CloseDetectingByteBufOutputStream.java new file mode 100644 index 0000000..3978672 --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/pcap/CloseDetectingByteBufOutputStream.java @@ -0,0 +1,46 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.pcap; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufOutputStream; + +import java.io.IOException; + +/** + * A {@link ByteBufOutputStream} which detects if {@link #close()} was called. + */ +final class CloseDetectingByteBufOutputStream extends ByteBufOutputStream { + + private boolean isCloseCalled; + + /** + * Creates a new stream which writes data to the specified {@code buffer}. 
+ */ + CloseDetectingByteBufOutputStream(ByteBuf buffer) { + super(buffer); + } + + public boolean closeCalled() { + return isCloseCalled; + } + + @Override + public void close() throws IOException { + super.close(); + isCloseCalled = true; + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/pcap/DiscardingStatsOutputStream.java b/netty-handler/src/test/java/io/netty/handler/pcap/DiscardingStatsOutputStream.java new file mode 100644 index 0000000..ebb736c --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/pcap/DiscardingStatsOutputStream.java @@ -0,0 +1,38 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.pcap; + +import java.io.IOException; +import java.io.OutputStream; + +final class DiscardingStatsOutputStream extends OutputStream { + + private int writesCount; + + @Override + public void write(int b) throws IOException { + // NO-OP + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + writesCount++; + } + + int writesCalled() { + return writesCount; + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/pcap/PcapWriteHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/pcap/PcapWriteHandlerTest.java new file mode 100644 index 0000000..b8f1bff --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/pcap/PcapWriteHandlerTest.java @@ -0,0 +1,680 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.handler.pcap; + +import io.netty.bootstrap.Bootstrap; +import io.netty.bootstrap.ServerBootstrap; +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufOutputStream; +import io.netty.buffer.ByteBufUtil; +import io.netty.buffer.Unpooled; +import io.netty.channel.Channel; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInboundHandlerAdapter; +import io.netty.channel.ChannelInitializer; +import io.netty.channel.ChannelOption; +import io.netty.channel.ChannelPipeline; +import io.netty.channel.ChannelPromise; +import io.netty.channel.DefaultChannelId; +import io.netty.channel.EventLoopGroup; +import io.netty.channel.SimpleChannelInboundHandler; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.channel.nio.NioEventLoopGroup; +import io.netty.channel.socket.DatagramChannel; +import io.netty.channel.socket.DatagramChannelConfig; +import io.netty.channel.socket.DatagramPacket; +import io.netty.channel.socket.DefaultDatagramChannelConfig; +import io.netty.channel.socket.SocketChannel; +import io.netty.channel.socket.nio.NioDatagramChannel; +import io.netty.channel.socket.nio.NioServerSocketChannel; +import io.netty.channel.socket.nio.NioSocketChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.NetUtil; +import io.netty.util.concurrent.Promise; + +import java.io.OutputStream; +import java.net.DatagramSocket; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.net.SocketAddress; +import java.net.SocketException; +import java.util.concurrent.TimeUnit; +import org.junit.jupiter.api.Test; + +import java.net.Inet4Address; +import java.net.InetSocketAddress; + +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static 
org.junit.jupiter.api.Assertions.assertTrue; + +public class PcapWriteHandlerTest { + + @Test + public void udpV4SharedOutputStreamTest() throws InterruptedException { + udpV4(true); + } + + @Test + public void udpV4NonOutputStream() throws InterruptedException { + udpV4(false); + } + + private static void udpV4(boolean sharedOutputStream) throws InterruptedException { + ByteBuf byteBuf = Unpooled.buffer(); + + InetSocketAddress serverAddr = new InetSocketAddress("127.0.0.1", 0); + InetSocketAddress clientAddr = new InetSocketAddress("127.0.0.1", 0); + + NioEventLoopGroup eventLoopGroup = new NioEventLoopGroup(2); + + // We'll bootstrap a UDP Server to avoid "Network Unreachable errors" when sending UDP Packet. + Bootstrap server = new Bootstrap() + .group(eventLoopGroup) + .channel(NioDatagramChannel.class) + .handler(new SimpleChannelInboundHandler<DatagramPacket>() { + @Override + protected void channelRead0(ChannelHandlerContext ctx, DatagramPacket msg) { + // Discard + } + }); + + ChannelFuture channelFutureServer = server.bind(serverAddr).sync(); + assertTrue(channelFutureServer.isSuccess()); + + CloseDetectingByteBufOutputStream outputStream = new CloseDetectingByteBufOutputStream(byteBuf); + + // We'll bootstrap a UDP Client for sending UDP Packets to UDP Server. + Bootstrap client = new Bootstrap() + .group(eventLoopGroup) + .channel(NioDatagramChannel.class) + .handler(PcapWriteHandler.builder() + .sharedOutputStream(sharedOutputStream) + .build(outputStream)); + + ChannelFuture channelFutureClient = + client.connect(channelFutureServer.channel().localAddress(), clientAddr).sync(); + assertTrue(channelFutureClient.isSuccess()); + + Channel clientChannel = channelFutureClient.channel(); + assertTrue(clientChannel.writeAndFlush(Unpooled.wrappedBuffer("Meow".getBytes())).sync().isSuccess()); + assertTrue(eventLoopGroup.shutdownGracefully().sync().isSuccess()); + + verifyUdpCapture(!sharedOutputStream, // if sharedOutputStream is true, we don't verify the global headers. 
+ byteBuf, + (InetSocketAddress) clientChannel.remoteAddress(), + (InetSocketAddress) clientChannel.localAddress() + ); + + // If sharedOutputStream is true, we don't close the outputStream. + // If sharedOutputStream is false, we close the outputStream. + assertEquals(!sharedOutputStream, outputStream.closeCalled()); + } + + @Test + public void embeddedUdp() { + final ByteBuf pcapBuffer = Unpooled.buffer(); + final ByteBuf payload = Unpooled.wrappedBuffer("Meow".getBytes()); + + InetSocketAddress serverAddr = new InetSocketAddress("1.1.1.1", 1234); + InetSocketAddress clientAddr = new InetSocketAddress("2.2.2.2", 3456); + + // We fake a client + EmbeddedChannel embeddedChannel = new EmbeddedChannel( + PcapWriteHandler.builder() + .forceUdpChannel(clientAddr, serverAddr) + .build(new ByteBufOutputStream(pcapBuffer)) + ); + + assertTrue(embeddedChannel.writeOutbound(payload)); + assertEquals(payload, embeddedChannel.readOutbound()); + + // Verify the capture data + verifyUdpCapture(true, pcapBuffer, serverAddr, clientAddr); + + assertFalse(embeddedChannel.finishAndReleaseAll()); + } + + @Test + public void udpMixedAddress() throws SocketException { + final ByteBuf pcapBuffer = Unpooled.buffer(); + final ByteBuf payload = Unpooled.wrappedBuffer("Meow".getBytes()); + + InetSocketAddress serverAddr = new InetSocketAddress("1.1.1.1", 1234); + // for ipv6 ::, it's allowed to connect to ipv4 on some systems + InetSocketAddress clientAddr = new InetSocketAddress("::", 3456); + + // We fake a client + EmbeddedChannel embeddedChannel = new EmbeddedDatagramChannel(clientAddr, serverAddr); + embeddedChannel.pipeline().addLast(PcapWriteHandler.builder() + .build(new ByteBufOutputStream(pcapBuffer))); + + assertTrue(embeddedChannel.writeOutbound(payload)); + assertEquals(payload, embeddedChannel.readOutbound()); + + // Verify the capture data + verifyUdpCapture(true, pcapBuffer, serverAddr, new InetSocketAddress("0.0.0.0", 3456)); + + 
assertFalse(embeddedChannel.finishAndReleaseAll()); + } + + @Test + public void tcpV4SharedOutputStreamTest() throws Exception { + tcpV4(true); + } + + @Test + public void tcpV4NonOutputStream() throws Exception { + tcpV4(false); + } + + private static void tcpV4(final boolean sharedOutputStream) throws Exception { + final ByteBuf byteBuf = Unpooled.buffer(); + + EventLoopGroup bossGroup = new NioEventLoopGroup(1); + EventLoopGroup clientGroup = new NioEventLoopGroup(); + + // Configure the echo server + ServerBootstrap sb = new ServerBootstrap(); + final Promise<Boolean> dataReadPromise = bossGroup.next().newPromise(); + sb.group(bossGroup) + .channel(NioServerSocketChannel.class) + .option(ChannelOption.SO_BACKLOG, 100) + .childHandler(new ChannelInitializer<SocketChannel>() { + @Override + public void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + p.addLast(PcapWriteHandler.builder().sharedOutputStream(sharedOutputStream) + .build(new ByteBufOutputStream(byteBuf))); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelRead(ChannelHandlerContext ctx, Object msg) { + ctx.write(msg); + } + + @Override + public void channelReadComplete(ChannelHandlerContext ctx) { + ctx.flush(); + dataReadPromise.setSuccess(true); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.close(); + } + }); + } + }); + + // Start the server. 
+ ChannelFuture serverChannelFuture = sb.bind(new InetSocketAddress("127.0.0.1", 0)).sync(); + assertTrue(serverChannelFuture.isSuccess()); + + // configure the client + Bootstrap cb = new Bootstrap(); + final Promise<Boolean> dataWrittenPromise = clientGroup.next().newPromise(); + cb.group(clientGroup) + .channel(NioSocketChannel.class) + .option(ChannelOption.TCP_NODELAY, true) + .handler(new ChannelInitializer<SocketChannel>() { + @Override + public void initChannel(SocketChannel ch) throws Exception { + ChannelPipeline p = ch.pipeline(); + p.addLast(new ChannelInboundHandlerAdapter() { + @Override + public void channelActive(ChannelHandlerContext ctx) { + ctx.writeAndFlush(Unpooled.wrappedBuffer("Meow".getBytes())); + dataWrittenPromise.setSuccess(true); + } + + @Override + public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) { + ctx.close(); + } + }); + } + }); + + // Start the client. + ChannelFuture clientChannelFuture = cb.connect(serverChannelFuture.channel().localAddress()).sync(); + assertTrue(clientChannelFuture.isSuccess()); + + assertTrue(dataWrittenPromise.await(5, TimeUnit.SECONDS)); + assertTrue(dataReadPromise.await(5, TimeUnit.SECONDS)); + + clientChannelFuture.channel().close().sync(); + serverChannelFuture.channel().close().sync(); + + // Shut down all event loops to terminate all threads. + assertTrue(clientGroup.shutdownGracefully().sync().isSuccess()); + assertTrue(bossGroup.shutdownGracefully().sync().isSuccess()); + + verifyTcpCapture( + !sharedOutputStream, // if sharedOutputStream is true, we don't verify the global headers. 
+ byteBuf, + (InetSocketAddress) serverChannelFuture.channel().localAddress(), + (InetSocketAddress) clientChannelFuture.channel().localAddress() + ); + } + + @Test + public void embeddedTcp() { + final ByteBuf pcapBuffer = Unpooled.buffer(); + final ByteBuf payload = Unpooled.wrappedBuffer("Meow".getBytes()); + + InetSocketAddress serverAddr = new InetSocketAddress("1.1.1.1", 1234); + InetSocketAddress clientAddr = new InetSocketAddress("2.2.2.2", 3456); + + EmbeddedChannel embeddedChannel = new EmbeddedChannel( + PcapWriteHandler.builder() + .forceTcpChannel(serverAddr, clientAddr, true) + .build(new ByteBufOutputStream(pcapBuffer)) + ); + + assertTrue(embeddedChannel.writeInbound(payload)); + assertEquals(payload, embeddedChannel.readInbound()); + + assertTrue(embeddedChannel.writeOutbound(payload)); + assertEquals(payload, embeddedChannel.readOutbound()); + + // Verify the capture data + verifyTcpCapture(true, pcapBuffer, serverAddr, clientAddr); + + assertFalse(embeddedChannel.finishAndReleaseAll()); + } + + @Test + public void writerStateTest() throws Exception { + final ByteBuf payload = Unpooled.wrappedBuffer("Meow".getBytes()); + final InetSocketAddress serverAddr = new InetSocketAddress("1.1.1.1", 1234); + final InetSocketAddress clientAddr = new InetSocketAddress("2.2.2.2", 3456); + + PcapWriteHandler pcapWriteHandler = PcapWriteHandler.builder() + .forceTcpChannel(serverAddr, clientAddr, true) + .build(new OutputStream() { + @Override + public void write(int b) { + // Discard everything + } + }); + + // State is INIT because we haven't written anything yet + // and 'channelActive' is not called yet as this Handler + // is yet to be attached to `EmbeddedChannel`. + assertEquals(State.INIT, pcapWriteHandler.state()); + + // Create a new 'EmbeddedChannel' and add the 'PcapWriteHandler' + EmbeddedChannel embeddedChannel = new EmbeddedChannel(pcapWriteHandler); + + // Write and read some data and verify it. 
+ assertTrue(embeddedChannel.writeInbound(payload)); + assertEquals(payload, embeddedChannel.readInbound()); + + assertTrue(embeddedChannel.writeOutbound(payload)); + assertEquals(payload, embeddedChannel.readOutbound()); + + // State is now WRITING because we attached Handler to 'EmbeddedChannel'. + assertEquals(State.WRITING, pcapWriteHandler.state()); + + // Close the PcapWriter. This should trigger closure of PcapWriteHandler too. + pcapWriteHandler.pCapWriter().close(); + + // State should be changed to closed by now + assertEquals(State.CLOSED, pcapWriteHandler.state()); + + // Close PcapWriteHandler again. This should be a no-op. + pcapWriteHandler.close(); + + // State should still be CLOSED. No change. + assertEquals(State.CLOSED, pcapWriteHandler.state()); + + // Close the 'EmbeddedChannel'. + assertFalse(embeddedChannel.finishAndReleaseAll()); + } + + @Test + public void pauseResumeTest() throws Exception { + final byte[] payload = "Meow".getBytes(); + final InetSocketAddress serverAddr = new InetSocketAddress("1.1.1.1", 1234); + final InetSocketAddress clientAddr = new InetSocketAddress("2.2.2.2", 3456); + + DiscardingStatsOutputStream discardingStatsOutputStream = new DiscardingStatsOutputStream(); + PcapWriteHandler pcapWriteHandler = PcapWriteHandler.builder() + .forceTcpChannel(serverAddr, clientAddr, true) + .build(discardingStatsOutputStream); + + // Verify that no writes have been called yet. + assertEquals(0, discardingStatsOutputStream.writesCalled()); + + // Create a new 'EmbeddedChannel' and add the 'PcapWriteHandler' + EmbeddedChannel embeddedChannel = new EmbeddedChannel(pcapWriteHandler); + + for (int i = 0; i < 10; i++) { + assertTrue(embeddedChannel.writeInbound(Unpooled.wrappedBuffer(payload))); + } + + // Since we have written 10 times, we should have a value greater than 0. + // We can't say it will be 10 exactly because there will be pcap headers also. 
+ final int initialWritesCalled = discardingStatsOutputStream.writesCalled(); + assertThat(initialWritesCalled).isGreaterThan(0); + + // Pause the Pcap + pcapWriteHandler.pause(); + assertEquals(State.PAUSED, pcapWriteHandler.state()); + + // Write 100 times. No writes should be called in OutputStream. + for (int i = 0; i < 100; i++) { + assertTrue(embeddedChannel.writeInbound(Unpooled.wrappedBuffer(payload))); + } + + // Current stats and previous stats should be same. + assertEquals(initialWritesCalled, discardingStatsOutputStream.writesCalled()); + + // Let's resume the Pcap now. + pcapWriteHandler.resume(); + assertEquals(State.WRITING, pcapWriteHandler.state()); + + // Write 100 times. Writes should be called in OutputStream now. + for (int i = 0; i < 100; i++) { + assertTrue(embeddedChannel.writeInbound(Unpooled.wrappedBuffer(payload))); + } + + // Verify we have written more than before. + assertThat(discardingStatsOutputStream.writesCalled()).isGreaterThan(initialWritesCalled); + + // Close PcapWriteHandler again. This should be a no-op. + pcapWriteHandler.close(); + + // State should still be CLOSED. No change. + assertEquals(State.CLOSED, pcapWriteHandler.state()); + + // Close the 'EmbeddedChannel'. + assertTrue(embeddedChannel.finishAndReleaseAll()); + } + + private static void verifyGlobalHeaders(ByteBuf byteBuf) { + assertEquals(0xa1b2c3d4, byteBuf.readInt()); // magic_number + assertEquals(2, byteBuf.readShort()); // version_major + assertEquals(4, byteBuf.readShort()); // version_minor + assertEquals(0, byteBuf.readInt()); // thiszone + assertEquals(0, byteBuf.readInt()); // sigfigs + assertEquals(0xffff, byteBuf.readInt()); // snaplen + assertEquals(1, byteBuf.readInt()); // network + } + + private static void verifyTcpCapture(boolean verifyGlobalHeaders, ByteBuf byteBuf, + InetSocketAddress serverAddr, InetSocketAddress clientAddr) { + // note: right now, this method only checks the first packet, which is part of the fake three-way handshake. 
+ + if (verifyGlobalHeaders) { + verifyGlobalHeaders(byteBuf); + } + + // Verify Pcap Packet Header + byteBuf.readInt(); // Just read, we don't care about timestamps for now + byteBuf.readInt(); // Just read, we don't care about timestamps for now + assertEquals(54, byteBuf.readInt()); // Length of Packet Saved In Pcap + assertEquals(54, byteBuf.readInt()); // Actual Length of Packet + + // -------------------------------------------- Verify Packet -------------------------------------------- + // Verify Ethernet Packet + ByteBuf ethernetPacket = byteBuf.readSlice(54); + ByteBuf dstMac = ethernetPacket.readSlice(6); + ByteBuf srcMac = ethernetPacket.readSlice(6); + assertArrayEquals(new byte[]{0, 0, 94, 0, 83, -1}, ByteBufUtil.getBytes(dstMac)); + assertArrayEquals(new byte[]{0, 0, 94, 0, 83, 0}, ByteBufUtil.getBytes(srcMac)); + assertEquals(0x0800, ethernetPacket.readShort()); + + // Verify IPv4 Packet + ByteBuf ipv4Packet = ethernetPacket.readSlice(32); + assertEquals(0x45, ipv4Packet.readByte()); // Version + IHL + assertEquals(0x00, ipv4Packet.readByte()); // DSCP + assertEquals(40, ipv4Packet.readShort()); // Length + assertEquals(0x0000, ipv4Packet.readShort()); // Identification + assertEquals(0x0000, ipv4Packet.readShort()); // Fragment + assertEquals((byte) 0xff, ipv4Packet.readByte()); // TTL + assertEquals((byte) 6, ipv4Packet.readByte()); // Protocol + assertEquals(0, ipv4Packet.readShort()); // Checksum + // Source IPv4 Address + ipv4Packet.readInt(); + // Destination IPv4 Address + assertEquals(NetUtil.ipv4AddressToInt((Inet4Address) serverAddr.getAddress()), ipv4Packet.readInt()); + + // Verify ports + ByteBuf tcpPacket = ipv4Packet.readSlice(12); + assertEquals(clientAddr.getPort() & 0xffff, tcpPacket.readUnsignedShort()); // Source Port + assertEquals(serverAddr.getPort() & 0xffff, tcpPacket.readUnsignedShort()); // Destination Port + } + + private static void verifyUdpCapture(boolean verifyGlobalHeaders, ByteBuf byteBuf, + InetSocketAddress 
remoteAddress, InetSocketAddress localAddress) { + if (verifyGlobalHeaders) { + verifyGlobalHeaders(byteBuf); + } + + // Verify Pcap Packet Header + byteBuf.readInt(); // Just read, we don't care about timestamps for now + byteBuf.readInt(); // Just read, we don't care about timestamps for now + assertEquals(46, byteBuf.readInt()); // Length of Packet Saved In Pcap + assertEquals(46, byteBuf.readInt()); // Actual Length of Packet + + // -------------------------------------------- Verify Packet -------------------------------------------- + // Verify Ethernet Packet + ByteBuf ethernetPacket = byteBuf.readBytes(46); + ByteBuf dstMac = ethernetPacket.readBytes(6); + ByteBuf srcMac = ethernetPacket.readBytes(6); + assertArrayEquals(new byte[]{0, 0, 94, 0, 83, -1}, ByteBufUtil.getBytes(dstMac)); + assertArrayEquals(new byte[]{0, 0, 94, 0, 83, 0}, ByteBufUtil.getBytes(srcMac)); + assertEquals(0x0800, ethernetPacket.readShort()); + + // Verify IPv4 Packet + ByteBuf ipv4Packet = ethernetPacket.readBytes(32); + assertEquals(0x45, ipv4Packet.readByte()); // Version + IHL + assertEquals(0x00, ipv4Packet.readByte()); // DSCP + assertEquals(32, ipv4Packet.readShort()); // Length + assertEquals(0x0000, ipv4Packet.readShort()); // Identification + assertEquals(0x0000, ipv4Packet.readShort()); // Fragment + assertEquals((byte) 0xff, ipv4Packet.readByte()); // TTL + assertEquals((byte) 17, ipv4Packet.readByte()); // Protocol + assertEquals(0, ipv4Packet.readShort()); // Checksum + + // Source IPv4 Address + assertEquals(NetUtil.ipv4AddressToInt((Inet4Address) localAddress.getAddress()), ipv4Packet.readInt()); + + // Destination IPv4 Address + assertEquals(NetUtil.ipv4AddressToInt((Inet4Address) remoteAddress.getAddress()), ipv4Packet.readInt()); + + // Verify UDP Packet + ByteBuf udpPacket = ipv4Packet.readBytes(12); + assertEquals(localAddress.getPort() & 0xffff, udpPacket.readUnsignedShort()); // Source Port + assertEquals(remoteAddress.getPort() & 0xffff, 
udpPacket.readUnsignedShort()); // Destination Port + assertEquals(12, udpPacket.readShort()); // Length + assertEquals(0x0001, udpPacket.readShort()); // Checksum + + // Verify Payload + ByteBuf payload = udpPacket.readBytes(4); + assertArrayEquals("Meow".getBytes(CharsetUtil.UTF_8), ByteBufUtil.getBytes(payload)); // Payload + + // Release all ByteBuf + assertTrue(dstMac.release()); + assertTrue(srcMac.release()); + assertTrue(payload.release()); + assertTrue(byteBuf.release()); + assertTrue(ethernetPacket.release()); + assertTrue(ipv4Packet.release()); + assertTrue(udpPacket.release()); + } + + private static class EmbeddedDatagramChannel extends EmbeddedChannel implements DatagramChannel { + private final InetSocketAddress local; + private final InetSocketAddress remote; + private DatagramChannelConfig config; + + EmbeddedDatagramChannel(InetSocketAddress local, InetSocketAddress remote) { + super(DefaultChannelId.newInstance(), false); + this.local = local; + this.remote = remote; + } + + @Override + public boolean isConnected() { + return true; + } + + @Override + public DatagramChannelConfig config() { + if (config == null) { + // ick! config() is called by the super constructor, so we need to do this. 
+ try { + config = new DefaultDatagramChannelConfig(this, new DatagramSocket()); + } catch (SocketException e) { + throw new RuntimeException(e); + } + } + return config; + } + + @Override + public InetSocketAddress localAddress() { + return (InetSocketAddress) super.localAddress(); + } + + @Override + public InetSocketAddress remoteAddress() { + return (InetSocketAddress) super.remoteAddress(); + } + + @Override + protected SocketAddress localAddress0() { + return local; + } + + @Override + protected SocketAddress remoteAddress0() { + return remote; + } + + @Override + public ChannelFuture joinGroup(InetAddress multicastAddress) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture joinGroup(InetAddress multicastAddress, ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture joinGroup(InetSocketAddress multicastAddress, NetworkInterface networkInterface) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture joinGroup( + InetSocketAddress multicastAddress, + NetworkInterface networkInterface, + ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture joinGroup( + InetAddress multicastAddress, + NetworkInterface networkInterface, + InetAddress source) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture joinGroup( + InetAddress multicastAddress, + NetworkInterface networkInterface, + InetAddress source, + ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture leaveGroup(InetAddress multicastAddress) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture leaveGroup(InetAddress multicastAddress, ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture leaveGroup(InetSocketAddress multicastAddress, NetworkInterface 
networkInterface) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture leaveGroup( + InetSocketAddress multicastAddress, + NetworkInterface networkInterface, + ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture leaveGroup( + InetAddress multicastAddress, + NetworkInterface networkInterface, + InetAddress source) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture leaveGroup( + InetAddress multicastAddress, + NetworkInterface networkInterface, + InetAddress source, + ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture block( + InetAddress multicastAddress, + NetworkInterface networkInterface, + InetAddress sourceToBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture block( + InetAddress multicastAddress, + NetworkInterface networkInterface, + InetAddress sourceToBlock, + ChannelPromise future) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture block(InetAddress multicastAddress, InetAddress sourceToBlock) { + throw new UnsupportedOperationException(); + } + + @Override + public ChannelFuture block(InetAddress multicastAddress, InetAddress sourceToBlock, ChannelPromise future) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/stream/ChunkedStreamTest.java b/netty-handler/src/test/java/io/netty/handler/stream/ChunkedStreamTest.java new file mode 100644 index 0000000..081f07c --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/stream/ChunkedStreamTest.java @@ -0,0 +1,51 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import io.netty.buffer.UnpooledByteBufAllocator; +import org.junit.jupiter.api.Test; + +import java.io.InputStream; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ChunkedStreamTest { + + @Test + public void writeTest() throws Exception { + ChunkedStream chunkedStream = new ChunkedStream(new InputStream() { + @Override + public int read() { + return -1; + } + + @Override + public int available() { + return 1; + } + }); + + assertFalse(chunkedStream.isEndOfInput()); + assertNull(chunkedStream.readChunk(UnpooledByteBufAllocator.DEFAULT)); + assertEquals(0, chunkedStream.progress()); + chunkedStream.close(); + assertTrue(chunkedStream.isEndOfInput()); + assertNull(chunkedStream.readChunk(UnpooledByteBufAllocator.DEFAULT)); + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java new file mode 100644 index 0000000..f4aeda3 --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/stream/ChunkedWriteHandlerTest.java @@ -0,0 +1,855 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.stream; + +import io.netty.buffer.ByteBuf; +import io.netty.buffer.ByteBufAllocator; +import io.netty.buffer.Unpooled; +import io.netty.channel.ChannelFuture; +import io.netty.channel.ChannelFutureListener; +import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelOutboundHandlerAdapter; +import io.netty.channel.ChannelPromise; +import io.netty.channel.embedded.EmbeddedChannel; +import io.netty.util.CharsetUtil; +import io.netty.util.ReferenceCountUtil; +import io.netty.util.internal.PlatformDependent; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.io.ByteArrayInputStream; +import java.io.File; +import java.io.FileOutputStream; +import java.io.IOException; +import java.io.RandomAccessFile; +import java.nio.channels.Channels; +import java.nio.channels.ClosedChannelException; +import java.nio.channels.FileChannel; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.util.concurrent.TimeUnit.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ChunkedWriteHandlerTest { + private static final byte[] BYTES = new byte[1024 * 64]; + private static final File TMP; + + 
static { + for (int i = 0; i < BYTES.length; i++) { + BYTES[i] = (byte) i; + } + + FileOutputStream out = null; + try { + TMP = PlatformDependent.createTempFile("netty-chunk-", ".tmp", null); + TMP.deleteOnExit(); + out = new FileOutputStream(TMP); + out.write(BYTES); + out.flush(); + } catch (IOException e) { + throw new RuntimeException(e); + } finally { + if (out != null) { + try { + out.close(); + } catch (IOException e) { + // ignore + } + } + } + } + + // See #310 + @Test + public void testChunkedStream() { + check(new ChunkedStream(new ByteArrayInputStream(BYTES))); + + check(new ChunkedStream(new ByteArrayInputStream(BYTES)), + new ChunkedStream(new ByteArrayInputStream(BYTES)), + new ChunkedStream(new ByteArrayInputStream(BYTES))); + } + + @Test + public void testChunkedNioStream() { + check(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES)))); + + check(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES))), + new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES))), + new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES)))); + } + + @Test + public void testChunkedFile() throws IOException { + check(new ChunkedFile(TMP)); + + check(new ChunkedFile(TMP), new ChunkedFile(TMP), new ChunkedFile(TMP)); + } + + @Test + public void testChunkedNioFile() throws IOException { + check(new ChunkedNioFile(TMP)); + + check(new ChunkedNioFile(TMP), new ChunkedNioFile(TMP), new ChunkedNioFile(TMP)); + } + + @Test + public void testChunkedNioFileLeftPositionUnchanged() throws IOException { + FileChannel in = null; + final long expectedPosition = 10; + try { + in = new RandomAccessFile(TMP, "r").getChannel(); + in.position(expectedPosition); + check(new ChunkedNioFile(in) { + @Override + public void close() throws Exception { + //no op + } + }); + assertTrue(in.isOpen()); + assertEquals(expectedPosition, in.position()); + } finally { + if (in != null) { + in.close(); + } + } + } + + @Test + 
public void testChunkedNioFileFailOnClosedFileChannel() throws IOException { + final FileChannel in = new RandomAccessFile(TMP, "r").getChannel(); + in.close(); + + assertThrows(ClosedChannelException.class, new Executable() { + @Override + public void execute() throws Throwable { + check(new ChunkedNioFile(in) { + @Override + public void close() throws Exception { + //no op + } + }); + } + }); + } + + @Test + public void testUnchunkedData() throws IOException { + check(Unpooled.wrappedBuffer(BYTES)); + + check(Unpooled.wrappedBuffer(BYTES), Unpooled.wrappedBuffer(BYTES), Unpooled.wrappedBuffer(BYTES)); + } + + // Test case which shows that there is not a bug like stated here: + // https://stackoverflow.com/a/10426305 + @Test + public void testListenerNotifiedWhenIsEnd() { + ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.ISO_8859_1); + + ChunkedInput input = new ChunkedInput() { + private boolean done; + private final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.ISO_8859_1); + + @Override + public boolean isEndOfInput() throws Exception { + return done; + } + + @Override + public void close() throws Exception { + buffer.release(); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (done) { + return null; + } + done = true; + return buffer.retainedDuplicate(); + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + final AtomicBoolean listenerNotified = new AtomicBoolean(false); + final ChannelFutureListener listener = new ChannelFutureListener() { + + @Override + public void operationComplete(ChannelFuture future) throws Exception { + listenerNotified.set(true); + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + 
ch.writeAndFlush(input).addListener(listener).syncUninterruptibly(); + assertTrue(ch.finish()); + + // the listener should have been notified + assertTrue(listenerNotified.get()); + + ByteBuf buffer2 = ch.readOutbound(); + assertEquals(buffer, buffer2); + assertNull(ch.readOutbound()); + + buffer.release(); + buffer2.release(); + } + + @Test + public void testChunkedMessageInput() { + + ChunkedInput input = new ChunkedInput() { + private boolean done; + + @Override + public boolean isEndOfInput() throws Exception { + return done; + } + + @Override + public void close() throws Exception { + // NOOP + } + + @Deprecated + @Override + public Object readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public Object readChunk(ByteBufAllocator ctx) throws Exception { + if (done) { + return false; + } + done = true; + return 0; + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + ch.writeAndFlush(input).syncUninterruptibly(); + assertTrue(ch.finish()); + + assertEquals(0, (Integer) ch.readOutbound()); + assertNull(ch.readOutbound()); + } + + @Test + public void testWriteFailureChunkedStream() throws IOException { + checkFirstFailed(new ChunkedStream(new ByteArrayInputStream(BYTES))); + } + + @Test + public void testWriteFailureChunkedNioStream() throws IOException { + checkFirstFailed(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES)))); + } + + @Test + public void testWriteFailureChunkedFile() throws IOException { + checkFirstFailed(new ChunkedFile(TMP)); + } + + @Test + public void testWriteFailureChunkedNioFile() throws IOException { + checkFirstFailed(new ChunkedNioFile(TMP)); + } + + @Test + public void testWriteFailureUnchunkedData() throws IOException { + checkFirstFailed(Unpooled.wrappedBuffer(BYTES)); + } + + @Test + public void 
testSkipAfterFailedChunkedStream() throws IOException { + checkSkipFailed(new ChunkedStream(new ByteArrayInputStream(BYTES)), + new ChunkedStream(new ByteArrayInputStream(BYTES))); + } + + @Test + public void testSkipAfterFailedChunkedNioStream() throws IOException { + checkSkipFailed(new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES))), + new ChunkedNioStream(Channels.newChannel(new ByteArrayInputStream(BYTES)))); + } + + @Test + public void testSkipAfterFailedChunkedFile() throws IOException { + checkSkipFailed(new ChunkedFile(TMP), new ChunkedFile(TMP)); + } + + @Test + public void testSkipAfterFailedChunkedNioFile() throws IOException { + checkSkipFailed(new ChunkedNioFile(TMP), new ChunkedFile(TMP)); + } + + // See https://github.com/netty/netty/issues/8700. + @Test + public void testFailureWhenLastChunkFailed() throws IOException { + ChannelOutboundHandlerAdapter failLast = new ChannelOutboundHandlerAdapter() { + private int passedWrites; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (++this.passedWrites < 4) { + ctx.write(msg, promise); + } else { + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(failLast, new ChunkedWriteHandler()); + ChannelFuture r = ch.writeAndFlush(new ChunkedFile(TMP, 1024 * 16)); // 4 chunks + assertTrue(ch.finish()); + + assertFalse(r.isSuccess()); + assertTrue(r.cause() instanceof RuntimeException); + + // 3 out of 4 chunks were already written + int read = 0; + for (;;) { + ByteBuf buffer = ch.readOutbound(); + if (buffer == null) { + break; + } + read += buffer.readableBytes(); + buffer.release(); + } + + assertEquals(1024 * 16 * 3, read); + } + + @Test + public void testDiscardPendingWritesOnInactive() throws IOException { + + final AtomicBoolean closeWasCalled = new AtomicBoolean(false); + + ChunkedInput notifiableInput = new ChunkedInput() { + private boolean 
done; + private final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.ISO_8859_1); + + @Override + public boolean isEndOfInput() throws Exception { + return done; + } + + @Override + public void close() throws Exception { + buffer.release(); + closeWasCalled.set(true); + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + if (done) { + return null; + } + done = true; + return buffer.retainedDuplicate(); + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + // Write 3 messages and close channel before flushing + ChannelFuture r1 = ch.write(new ChunkedFile(TMP)); + ChannelFuture r2 = ch.write(new ChunkedNioFile(TMP)); + ch.write(notifiableInput); + + // Should be `false` as we do not expect any messages to be written + assertFalse(ch.finish()); + + assertFalse(r1.isSuccess()); + assertFalse(r2.isSuccess()); + assertTrue(closeWasCalled.get()); + } + + // See https://github.com/netty/netty/issues/8700. 
+ @Test + public void testStopConsumingChunksWhenFailed() { + final ByteBuf buffer = Unpooled.copiedBuffer("Test", CharsetUtil.ISO_8859_1); + final AtomicInteger chunks = new AtomicInteger(0); + + ChunkedInput nonClosableInput = new ChunkedInput() { + @Override + public boolean isEndOfInput() throws Exception { + return chunks.get() >= 5; + } + + @Override + public void close() throws Exception { + // no-op + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + chunks.incrementAndGet(); + return buffer.retainedDuplicate(); + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + ChannelOutboundHandlerAdapter noOpWrites = new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(noOpWrites, new ChunkedWriteHandler()); + ch.writeAndFlush(nonClosableInput).awaitUninterruptibly(); + // Should be `false` as we do not expect any messages to be written + assertFalse(ch.finish()); + buffer.release(); + + // We should expect only single chunked being read from the input. + // It's possible to get a race condition here between resolving a promise and + // allocating a new chunk, but should be fine when working with embedded channels. 
+ assertEquals(1, chunks.get()); + } + + @Test + public void testCloseSuccessfulChunkedInput() { + int chunks = 10; + TestChunkedInput input = new TestChunkedInput(chunks); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + assertTrue(ch.writeOutbound(input)); + + for (int i = 0; i < chunks; i++) { + ByteBuf buf = ch.readOutbound(); + assertEquals(i, buf.readInt()); + buf.release(); + } + + assertTrue(input.isClosed()); + assertFalse(ch.finish()); + } + + @Test + public void testCloseFailedChunkedInput() { + Exception error = new Exception("Unable to produce a chunk"); + final ThrowingChunkedInput input = new ThrowingChunkedInput(error); + final EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + Exception e = assertThrows(Exception.class, new Executable() { + @Override + public void execute() throws Throwable { + ch.writeOutbound(input); + } + }); + assertEquals(error, e); + + assertTrue(input.isClosed()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterSuccessfulChunkedInputClosed() throws Exception { + final TestChunkedInput input = new TestChunkedInput(2); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertTrue(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertTrue(ch.finishAndReleaseAll()); + } + + @Test + public void testWriteListenerInvokedAfterFailedChunkedInputClosed() throws Exception { + final ThrowingChunkedInput input = new 
ThrowingChunkedInput(new RuntimeException()); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertFalse(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + @Test + public void testWriteListenerInvokedAfterChannelClosedAndInputFullyConsumed() throws Exception { + // use empty input which has endOfInput = true + final TestChunkedInput input = new TestChunkedInput(0); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.close(); // close channel to make handler discard the input on subsequent flush + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertTrue(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + @Test + public void testEndOfInputWhenChannelIsClosedwhenWrite() { + ChunkedInput input = new ChunkedInput() { + + @Override + public boolean isEndOfInput() { + return true; + } + + @Override + public void close() { + } + + @Deprecated + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) { 
+ return null; + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) { + return null; + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return 1; + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) throws Exception { + ReferenceCountUtil.release(msg); + // Calling close so we will drop all queued messages in the ChunkedWriteHandler. + ctx.close(); + promise.setSuccess(); + } + }, new ChunkedWriteHandler()); + + ch.writeAndFlush(input).syncUninterruptibly(); + assertFalse(ch.finishAndReleaseAll()); + } + + @Test + public void testWriteListenerInvokedAfterChannelClosedAndInputNotFullyConsumed() throws Exception { + // use non-empty input which has endOfInput = false + final TestChunkedInput input = new TestChunkedInput(42); + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + final AtomicBoolean inputClosedWhenListenerInvoked = new AtomicBoolean(); + final CountDownLatch listenerInvoked = new CountDownLatch(1); + + ChannelFuture writeFuture = ch.write(input); + writeFuture.addListener(new ChannelFutureListener() { + @Override + public void operationComplete(ChannelFuture future) { + inputClosedWhenListenerInvoked.set(input.isClosed()); + listenerInvoked.countDown(); + } + }); + ch.close(); // close channel to make handler discard the input on subsequent flush + ch.flush(); + + assertTrue(listenerInvoked.await(10, SECONDS)); + assertFalse(writeFuture.isSuccess()); + assertTrue(inputClosedWhenListenerInvoked.get()); + assertFalse(ch.finish()); + } + + private static void check(Object... 
inputs) { + EmbeddedChannel ch = new EmbeddedChannel(new ChunkedWriteHandler()); + + for (Object input: inputs) { + ch.writeOutbound(input); + } + + assertTrue(ch.finish()); + + int i = 0; + int read = 0; + for (;;) { + ByteBuf buffer = ch.readOutbound(); + if (buffer == null) { + break; + } + while (buffer.isReadable()) { + assertEquals(BYTES[i++], buffer.readByte()); + read++; + if (i == BYTES.length) { + i = 0; + } + } + buffer.release(); + } + + assertEquals(BYTES.length * inputs.length, read); + } + + private static void checkFirstFailed(Object input) { + ChannelOutboundHandlerAdapter noOpWrites = new ChannelOutboundHandlerAdapter() { + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(noOpWrites, new ChunkedWriteHandler()); + ChannelFuture r = ch.writeAndFlush(input); + + // Should be `false` as we do not expect any messages to be written + assertFalse(ch.finish()); + assertTrue(r.cause() instanceof RuntimeException); + } + + private static void checkSkipFailed(Object input1, Object input2) { + ChannelOutboundHandlerAdapter failFirst = new ChannelOutboundHandlerAdapter() { + private boolean alreadyFailed; + + @Override + public void write(ChannelHandlerContext ctx, Object msg, ChannelPromise promise) { + if (alreadyFailed) { + ctx.write(msg, promise); + } else { + this.alreadyFailed = true; + ReferenceCountUtil.release(msg); + promise.tryFailure(new RuntimeException()); + } + } + }; + + EmbeddedChannel ch = new EmbeddedChannel(failFirst, new ChunkedWriteHandler()); + ChannelFuture r1 = ch.write(input1); + ChannelFuture r2 = ch.writeAndFlush(input2).awaitUninterruptibly(); + assertTrue(ch.finish()); + + assertTrue(r1.cause() instanceof RuntimeException); + assertTrue(r2.isSuccess()); + + // note, that after we've "skipped" the first write, + // we expect to see the second message, 
chunk by chunk + int i = 0; + int read = 0; + for (;;) { + ByteBuf buffer = ch.readOutbound(); + if (buffer == null) { + break; + } + while (buffer.isReadable()) { + assertEquals(BYTES[i++], buffer.readByte()); + read++; + if (i == BYTES.length) { + i = 0; + } + } + buffer.release(); + } + + assertEquals(BYTES.length, read); + } + + private static final class TestChunkedInput implements ChunkedInput { + private final int chunksToProduce; + + private int chunksProduced; + private volatile boolean closed; + + TestChunkedInput(int chunksToProduce) { + this.chunksToProduce = chunksToProduce; + } + + @Override + public boolean isEndOfInput() { + return chunksProduced >= chunksToProduce; + } + + @Override + public void close() { + closed = true; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) { + ByteBuf buf = allocator.buffer(); + buf.writeInt(chunksProduced); + chunksProduced++; + return buf; + } + + @Override + public long length() { + return chunksToProduce; + } + + @Override + public long progress() { + return chunksProduced; + } + + boolean isClosed() { + return closed; + } + } + + private static final class ThrowingChunkedInput implements ChunkedInput { + private final Exception error; + + private volatile boolean closed; + + ThrowingChunkedInput(Exception error) { + this.error = error; + } + + @Override + public boolean isEndOfInput() { + return false; + } + + @Override + public void close() { + closed = true; + } + + @Override + public ByteBuf readChunk(ChannelHandlerContext ctx) throws Exception { + return readChunk(ctx.alloc()); + } + + @Override + public ByteBuf readChunk(ByteBufAllocator allocator) throws Exception { + throw error; + } + + @Override + public long length() { + return -1; + } + + @Override + public long progress() { + return -1; + } + + boolean isClosed() { + return closed; + } + } +} diff --git 
a/netty-handler/src/test/java/io/netty/handler/timeout/IdleStateEventTest.java b/netty-handler/src/test/java/io/netty/handler/timeout/IdleStateEventTest.java new file mode 100644 index 0000000..87511bc --- /dev/null +++ b/netty-handler/src/test/java/io/netty/handler/timeout/IdleStateEventTest.java @@ -0,0 +1,35 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.handler.timeout; + + +import org.junit.jupiter.api.Test; + +import static io.netty.handler.timeout.IdleStateEvent.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.hasToString; + +public class IdleStateEventTest { + @Test + public void testHumanReadableToString() { + assertThat(FIRST_READER_IDLE_STATE_EVENT, hasToString("IdleStateEvent(READER_IDLE, first)")); + assertThat(READER_IDLE_STATE_EVENT, hasToString("IdleStateEvent(READER_IDLE)")); + assertThat(FIRST_WRITER_IDLE_STATE_EVENT, hasToString("IdleStateEvent(WRITER_IDLE, first)")); + assertThat(WRITER_IDLE_STATE_EVENT, hasToString("IdleStateEvent(WRITER_IDLE)")); + assertThat(FIRST_ALL_IDLE_STATE_EVENT, hasToString("IdleStateEvent(ALL_IDLE, first)")); + assertThat(ALL_IDLE_STATE_EVENT, hasToString("IdleStateEvent(ALL_IDLE)")); + } +} diff --git a/netty-handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java b/netty-handler/src/test/java/io/netty/handler/timeout/IdleStateHandlerTest.java 
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.timeout;

import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandler;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundBuffer;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.ReferenceCountUtil;
import io.netty.util.concurrent.Future;
import org.junit.jupiter.api.Test;

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.assertSame;

/**
 * Tests {@link IdleStateHandler} against a {@link TestableIdleStateHandler}
 * subclass that replaces both the clock ({@code ticksInNanos()}) and the
 * scheduler ({@code schedule(...)}) with manually driven fakes, so idle
 * timeouts can be simulated deterministically without real delays.
 */
public class IdleStateHandlerTest {

    @Test
    public void testReaderIdle() throws Exception {
        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                false, 1L, 0L, 0L, TimeUnit.SECONDS);

        // We start with one FIRST_READER_IDLE_STATE_EVENT, followed by an
        // infinite number of READER_IDLE_STATE_EVENTs
        anyIdle(idleStateHandler, IdleStateEvent.FIRST_READER_IDLE_STATE_EVENT,
                IdleStateEvent.READER_IDLE_STATE_EVENT, IdleStateEvent.READER_IDLE_STATE_EVENT);
    }

    @Test
    public void testWriterIdle() throws Exception {
        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                false, 0L, 1L, 0L, TimeUnit.SECONDS);

        anyIdle(idleStateHandler, IdleStateEvent.FIRST_WRITER_IDLE_STATE_EVENT,
                IdleStateEvent.WRITER_IDLE_STATE_EVENT, IdleStateEvent.WRITER_IDLE_STATE_EVENT);
    }

    @Test
    public void testAllIdle() throws Exception {
        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                false, 0L, 0L, 1L, TimeUnit.SECONDS);

        anyIdle(idleStateHandler, IdleStateEvent.FIRST_ALL_IDLE_STATE_EVENT,
                IdleStateEvent.ALL_IDLE_STATE_EVENT, IdleStateEvent.ALL_IDLE_STATE_EVENT);
    }

    /**
     * Drives the fake clock once per expected event and asserts the exact
     * sequence of {@link IdleStateEvent}s fired on an otherwise silent channel.
     */
    private static void anyIdle(TestableIdleStateHandler idleStateHandler, Object... expected) throws Exception {
        assertThat(expected.length, greaterThanOrEqualTo(1));

        final List<Object> events = new ArrayList<Object>();
        ChannelInboundHandlerAdapter handler = new ChannelInboundHandlerAdapter() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                events.add(evt);
            }
        };

        EmbeddedChannel channel = new EmbeddedChannel(idleStateHandler, handler);
        try {
            // For each expected event advance the ticker and run() the task. Each
            // step should yield in an IdleStateEvent because we haven't written
            // or read anything from the channel.
            for (int i = 0; i < expected.length; i++) {
                idleStateHandler.tickRun();
            }

            assertEquals(expected.length, events.size());

            // Compare the expected with the actual IdleStateEvents
            for (int i = 0; i < expected.length; i++) {
                Object evt = events.get(i);
                assertSame(expected[i], evt, "Element " + i + " is not matching");
            }
        } finally {
            channel.finishAndReleaseAll();
        }
    }

    @Test
    public void testReaderNotIdle() throws Exception {
        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                false, 1L, 0L, 0L, TimeUnit.SECONDS);

        Action action = new Action() {
            @Override
            public void run(EmbeddedChannel channel) throws Exception {
                channel.writeInbound("Hello, World!");
            }
        };

        anyNotIdle(idleStateHandler, action, IdleStateEvent.FIRST_READER_IDLE_STATE_EVENT);
    }

    @Test
    public void testWriterNotIdle() throws Exception {
        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                false, 0L, 1L, 0L, TimeUnit.SECONDS);

        Action action = new Action() {
            @Override
            public void run(EmbeddedChannel channel) throws Exception {
                channel.writeAndFlush("Hello, World!");
            }
        };

        anyNotIdle(idleStateHandler, action, IdleStateEvent.FIRST_WRITER_IDLE_STATE_EVENT);
    }

    @Test
    public void testAllNotIdle() throws Exception {
        // Reader...
        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                false, 0L, 0L, 1L, TimeUnit.SECONDS);

        Action reader = new Action() {
            @Override
            public void run(EmbeddedChannel channel) throws Exception {
                channel.writeInbound("Hello, World!");
            }
        };

        anyNotIdle(idleStateHandler, reader, IdleStateEvent.FIRST_ALL_IDLE_STATE_EVENT);

        // Writer...
        idleStateHandler = new TestableIdleStateHandler(
                false, 0L, 0L, 1L, TimeUnit.SECONDS);

        Action writer = new Action() {
            @Override
            public void run(EmbeddedChannel channel) throws Exception {
                channel.writeAndFlush("Hello, World!");
            }
        };

        anyNotIdle(idleStateHandler, writer, IdleStateEvent.FIRST_ALL_IDLE_STATE_EVENT);
    }

    /**
     * Performs {@code action} on the channel, then checks that no idle event
     * fires before the full idle delay has elapsed — and that exactly the
     * {@code expected} event fires once it has.
     */
    private static void anyNotIdle(TestableIdleStateHandler idleStateHandler,
                                   Action action, Object expected) throws Exception {

        final List<Object> events = new ArrayList<Object>();
        ChannelInboundHandlerAdapter handler = new ChannelInboundHandlerAdapter() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                events.add(evt);
            }
        };

        EmbeddedChannel channel = new EmbeddedChannel(idleStateHandler, handler);
        try {
            idleStateHandler.tick(1L, TimeUnit.NANOSECONDS);
            action.run(channel);

            // Advance the ticker by some fraction and run() the task.
            // There shouldn't be an IdleStateEvent getting fired because
            // we've just performed an action on the channel that is meant
            // to reset the idle task.
            long delayInNanos = idleStateHandler.delay(TimeUnit.NANOSECONDS);
            assertNotEquals(0L, delayInNanos);

            idleStateHandler.tickRun(delayInNanos / 2L, TimeUnit.NANOSECONDS);
            assertEquals(0, events.size());

            // Advance the ticker by the full amount and it should yield
            // in an IdleStateEvent.
            idleStateHandler.tickRun();
            assertEquals(1, events.size());
            assertSame(expected, events.get(0));
        } finally {
            channel.finishAndReleaseAll();
        }
    }

    @Test
    public void testObserveWriterIdle() throws Exception {
        observeOutputIdle(true);
    }

    @Test
    public void testObserveAllIdle() throws Exception {
        observeOutputIdle(false);
    }

    /**
     * Exercises the observeOutput=true mode: slow but ongoing consumption of
     * the outbound buffer must suppress idle events, while a genuine stall of
     * the full idle time (5s) must still fire one.
     */
    private static void observeOutputIdle(boolean writer) throws Exception {

        long writerIdleTime = 0L;
        long allIdleTime = 0L;
        IdleStateEvent expected;

        if (writer) {
            writerIdleTime = 5L;
            expected = IdleStateEvent.FIRST_WRITER_IDLE_STATE_EVENT;
        } else {
            allIdleTime = 5L;
            expected = IdleStateEvent.FIRST_ALL_IDLE_STATE_EVENT;
        }

        TestableIdleStateHandler idleStateHandler = new TestableIdleStateHandler(
                true, 0L, writerIdleTime, allIdleTime, TimeUnit.SECONDS);

        final List<Object> events = new ArrayList<Object>();
        ChannelInboundHandlerAdapter handler = new ChannelInboundHandlerAdapter() {
            @Override
            public void userEventTriggered(ChannelHandlerContext ctx, Object evt) throws Exception {
                events.add(evt);
            }
        };

        ObservableChannel channel = new ObservableChannel(idleStateHandler, handler);
        try {
            // We're writing 3 messages that will be consumed at different rates!
            channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 1 }));
            channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 2 }));
            channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[] { 3 }));
            channel.writeAndFlush(Unpooled.wrappedBuffer(new byte[5 * 1024]));

            // Establish a baseline. We're not consuming anything and let it idle once.
            idleStateHandler.tickRun();
            assertEquals(1, events.size());
            assertSame(expected, events.get(0));
            events.clear();

            // Our ticker should be at second 5
            assertEquals(5L, idleStateHandler.tick(TimeUnit.SECONDS));

            // Consume one message in 4 seconds, then be idle for 2 seconds,
            // then run the task and we shouldn't get an IdleStateEvent because
            // we haven't been idle for long enough!
            idleStateHandler.tick(4L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consume());

            idleStateHandler.tickRun(2L, TimeUnit.SECONDS);
            assertEquals(0, events.size());
            assertEquals(11L, idleStateHandler.tick(TimeUnit.SECONDS)); // 5s + 4s + 2s

            // Consume one message in 3 seconds, then be idle for 4 seconds,
            // then run the task and we shouldn't get an IdleStateEvent because
            // we haven't been idle for long enough!
            idleStateHandler.tick(3L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consume());

            idleStateHandler.tickRun(4L, TimeUnit.SECONDS);
            assertEquals(0, events.size());
            assertEquals(18L, idleStateHandler.tick(TimeUnit.SECONDS)); // 11s + 3s + 4s

            // Don't consume a message and be idle for 5 seconds.
            // We should get an IdleStateEvent!
            idleStateHandler.tickRun(5L, TimeUnit.SECONDS);
            assertEquals(1, events.size());
            assertEquals(23L, idleStateHandler.tick(TimeUnit.SECONDS)); // 18s + 5s
            events.clear();

            // Consume one message in 2 seconds, then be idle for 1 seconds,
            // then run the task and we shouldn't get an IdleStateEvent because
            // we haven't been idle for long enough!
            idleStateHandler.tick(2L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consume());

            idleStateHandler.tickRun(1L, TimeUnit.SECONDS);
            assertEquals(0, events.size());
            assertEquals(26L, idleStateHandler.tick(TimeUnit.SECONDS)); // 23s + 2s + 1s

            // Consume part of the message every 2 seconds, then be idle for 1 seconds,
            // then run the task and we should get an IdleStateEvent because the first trigger
            idleStateHandler.tick(2L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consumePart(1024));
            idleStateHandler.tick(2L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consumePart(1024));
            idleStateHandler.tickRun(1L, TimeUnit.SECONDS);
            assertEquals(1, events.size());
            assertEquals(31L, idleStateHandler.tick(TimeUnit.SECONDS)); // 26s + 2s + 2s + 1s
            events.clear();

            // Consume part of the message every 2 seconds, then be idle for 1 seconds,
            // then consume all the rest of the message, then run the task and we shouldn't
            // get an IdleStateEvent because the data is flowing and we haven't been idle for long enough!
            idleStateHandler.tick(2L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consumePart(1024));
            idleStateHandler.tick(2L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consumePart(1024));
            idleStateHandler.tickRun(1L, TimeUnit.SECONDS);
            assertEquals(0, events.size());
            assertEquals(36L, idleStateHandler.tick(TimeUnit.SECONDS)); // 31s + 2s + 2s + 1s
            idleStateHandler.tick(2L, TimeUnit.SECONDS);
            assertNotNullAndRelease(channel.consumePart(1024));

            // There are no messages left! Advance the ticker by 3 seconds,
            // attempt a consume() but it will be null, then advance the
            // ticker by an another 2 seconds and we should get an IdleStateEvent
            // because we've been idle for 5 seconds.
            idleStateHandler.tick(3L, TimeUnit.SECONDS);
            assertNull(channel.consume());

            idleStateHandler.tickRun(2L, TimeUnit.SECONDS);
            assertEquals(1, events.size());
            assertEquals(43L, idleStateHandler.tick(TimeUnit.SECONDS)); // 36s + 2s + 3s + 2s

            // q.e.d.
        } finally {
            channel.finishAndReleaseAll();
        }
    }

    // Asserts the consumed message exists, then releases it to avoid leaks.
    private static void assertNotNullAndRelease(Object msg) {
        assertNotNull(msg);
        ReferenceCountUtil.release(msg);
    }

    /** A channel action whose effect on idleness is being tested. */
    private interface Action {
        void run(EmbeddedChannel channel) throws Exception;
    }

    /**
     * {@link IdleStateHandler} with a manually advanced clock and a captured
     * (never actually scheduled) idle task.
     */
    private static class TestableIdleStateHandler extends IdleStateHandler {

        // The last task passed to schedule(...); run via run()/tickRun().
        private Runnable task;

        // The delay requested by the last schedule(...) call, in nanos.
        private long delayInNanos;

        // The fake clock, advanced only by tick(delay, unit).
        private long ticksInNanos;

        TestableIdleStateHandler(boolean observeOutput,
                                 long readerIdleTime, long writerIdleTime, long allIdleTime,
                                 TimeUnit unit) {
            super(observeOutput, readerIdleTime, writerIdleTime, allIdleTime, unit);
        }

        public long delay(TimeUnit unit) {
            return unit.convert(delayInNanos, TimeUnit.NANOSECONDS);
        }

        public void run() {
            task.run();
        }

        public void tickRun() {
            tickRun(delayInNanos, TimeUnit.NANOSECONDS);
        }

        public void tickRun(long delay, TimeUnit unit) {
            tick(delay, unit);
            run();
        }

        /**
         * Advances the current ticker by the given amount.
         */
        public void tick(long delay, TimeUnit unit) {
            ticksInNanos += unit.toNanos(delay);
        }

        /**
         * Returns {@link #ticksInNanos()} in the given {@link TimeUnit}.
         */
        public long tick(TimeUnit unit) {
            return unit.convert(ticksInNanos(), TimeUnit.NANOSECONDS);
        }

        @Override
        long ticksInNanos() {
            return ticksInNanos;
        }

        @Override
        Future<?> schedule(ChannelHandlerContext ctx, Runnable task, long delay, TimeUnit unit) {
            // Capture instead of scheduling; the test drives the task by hand.
            this.task = task;
            this.delayInNanos = unit.toNanos(delay);
            return null;
        }
    }

    /**
     * {@link EmbeddedChannel} that never drains its outbound buffer on its own,
     * so the test can consume messages (fully or partially) at a chosen pace.
     */
    private static class ObservableChannel extends EmbeddedChannel {

        ObservableChannel(ChannelHandler... handlers) {
            super(handlers);
        }

        @Override
        protected void doWrite(ChannelOutboundBuffer in) throws Exception {
            // Overridden to change EmbeddedChannel's default behavior. We want to keep
            // the messages in the ChannelOutboundBuffer.
        }

        private Object consume() {
            ChannelOutboundBuffer buf = unsafe().outboundBuffer();
            if (buf != null) {
                Object msg = buf.current();
                if (msg != null) {
                    // Retain before remove(): remove() releases the message.
                    ReferenceCountUtil.retain(msg);
                    buf.remove();
                    return msg;
                }
            }
            return null;
        }

        /**
         * Consume the part of a message.
         *
         * @param byteCount count of byte to be consumed
         * @return the message currently being consumed
         */
        private Object consumePart(int byteCount) {
            ChannelOutboundBuffer buf = unsafe().outboundBuffer();
            if (buf != null) {
                Object msg = buf.current();
                if (msg != null) {
                    ReferenceCountUtil.retain(msg);
                    buf.removeBytes(byteCount);
                    return msg;
                }
            }
            return null;
        }
    }
}
/*
 * Copyright 2021 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.handler.timeout;

import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelFutureListener;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.concurrent.DefaultEventExecutorGroup;
import io.netty.util.concurrent.EventExecutorGroup;
import org.junit.jupiter.api.Test;

import java.util.concurrent.CountDownLatch;

import static org.junit.jupiter.api.Assertions.assertTrue;

public class WriteTimeoutHandlerTest {

    /**
     * Places the {@link WriteTimeoutHandler} and the writing handler on two
     * different {@link EventExecutorGroup}s and verifies that the write
     * promise still completes (the latch is counted down by its listener).
     */
    @Test
    public void testPromiseUseDifferentExecutor() throws Exception {
        EventExecutorGroup timeoutExecutor = new DefaultEventExecutorGroup(1);
        EventExecutorGroup writerExecutor = new DefaultEventExecutorGroup(1);
        EmbeddedChannel channel = new EmbeddedChannel(false, false);
        try {
            final CountDownLatch writeCompleted = new CountDownLatch(1);
            channel.pipeline().addLast(timeoutExecutor, new WriteTimeoutHandler(10000));
            channel.pipeline().addLast(writerExecutor, new ChannelInboundHandlerAdapter() {
                @Override
                public void channelActive(ChannelHandlerContext ctx) throws Exception {
                    ctx.writeAndFlush("something").addListener(new ChannelFutureListener() {
                        @Override
                        public void operationComplete(ChannelFuture future) throws Exception {
                            writeCompleted.countDown();
                        }
                    });
                }
            });

            channel.register();
            writeCompleted.await();
            assertTrue(channel.finishAndReleaseAll());
        } finally {
            timeoutExecutor.shutdownGracefully();
            writerExecutor.shutdownGracefully();
        }
    }
}
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License, version
 * 2.0 (the "License"); you may not use this file except in compliance with the
 * License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations under
 * the License.
 */
package io.netty.handler.traffic;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelFuture;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultFileRegion;
import io.netty.channel.EventLoopGroup;
import io.netty.channel.nio.NioEventLoopGroup;
import io.netty.channel.socket.SocketChannel;
import io.netty.channel.socket.nio.NioServerSocketChannel;
import io.netty.channel.socket.nio.NioSocketChannel;
import io.netty.handler.codec.LineBasedFrameDecoder;
import io.netty.util.CharsetUtil;
import io.netty.util.internal.PlatformDependent;
import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Disabled;
import org.junit.jupiter.api.Test;

import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.RandomAccessFile;
import java.net.SocketAddress;
import java.nio.charset.Charset;
import java.util.Random;
import java.util.concurrent.CountDownLatch;

import static org.junit.jupiter.api.Assertions.assertTrue;

/**
 * Verifies that a {@link DefaultFileRegion} written through a
 * {@link GlobalTrafficShapingHandler} is actually throttled: streaming a
 * 256 KiB file at a 64 KiB/s write limit must take noticeably longer than
 * an unthrottled transfer would.
 */
public class FileRegionThrottleTest {
    // 256 KiB of random payload, written once to a temp file in beforeClass().
    private static final byte[] BYTES = new byte[64 * 1024 * 4];
    // Global write limit in bytes per second.
    private static final long WRITE_LIMIT = 64 * 1024;
    private static File tmp;
    private EventLoopGroup group;

    @BeforeAll
    public static void beforeClass() throws IOException {
        final Random r = new Random();
        for (int i = 0; i < BYTES.length; i++) {
            BYTES[i] = (byte) r.nextInt(255);
        }

        tmp = PlatformDependent.createTempFile("netty-traffic", ".tmp", null);
        tmp.deleteOnExit();
        FileOutputStream out = null;
        try {
            out = new FileOutputStream(tmp);
            out.write(BYTES);
            out.flush();
        } catch (IOException e) {
            throw new RuntimeException(e);
        } finally {
            if (out != null) {
                try {
                    out.close();
                } catch (IOException e) {
                    // ignore
                }
            }
        }
    }

    @BeforeEach
    public void setUp() {
        group = new NioEventLoopGroup();
    }

    @AfterEach
    public void tearDown() {
        group.shutdownGracefully();
    }

    @Disabled("This test is flaky, need more investigation")
    @Test
    public void testGlobalWriteThrottle() throws Exception {
        final CountDownLatch latch = new CountDownLatch(1);
        final GlobalTrafficShapingHandler gtsh = new GlobalTrafficShapingHandler(group, WRITE_LIMIT, 0);
        ServerBootstrap bs = new ServerBootstrap();
        bs.group(group).channel(NioServerSocketChannel.class).childHandler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) {
                ch.pipeline().addLast(new LineBasedFrameDecoder(Integer.MAX_VALUE));
                ch.pipeline().addLast(new MessageDecoder());
                ch.pipeline().addLast(gtsh);
            }
        });
        Channel sc = bs.bind(0).sync().channel();
        Channel cc = clientConnect(sc.localAddress(), new ReadHandler(latch)).channel();

        // Ask the server to stream the file and measure how long it takes
        // until the client has received all of it.
        long start = TrafficCounter.milliSecondFromNano();
        cc.writeAndFlush(Unpooled.copiedBuffer("send-file\n", CharsetUtil.US_ASCII)).sync();
        latch.await();
        long timeTaken = TrafficCounter.milliSecondFromNano() - start;
        // 256 KiB at 64 KiB/s should take about 4s; anything under 3s means
        // the shaping handler did not throttle the file region.
        assertTrue(timeTaken > 3000, "Data streamed faster than expected");
        sc.close().sync();
        cc.close().sync();
    }

    // Connects a client whose pipeline contains only the given read handler.
    private ChannelFuture clientConnect(final SocketAddress server, final ReadHandler readHandler) throws Exception {
        Bootstrap bc = new Bootstrap();
        bc.group(group).channel(NioSocketChannel.class).handler(new ChannelInitializer<SocketChannel>() {
            @Override
            protected void initChannel(SocketChannel ch) {
                ch.pipeline().addLast(readHandler);
            }
        });
        return bc.connect(server).sync();
    }

    /**
     * Server-side handler: on receiving the "send-file" command it responds by
     * streaming the temp file as a {@link DefaultFileRegion}.
     */
    private static final class MessageDecoder extends ChannelInboundHandlerAdapter {
        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
            if (msg instanceof ByteBuf) {
                ByteBuf buf = (ByteBuf) msg;
                String message = buf.toString(Charset.defaultCharset());
                buf.release();
                if (message.equals("send-file")) {
                    RandomAccessFile raf = new RandomAccessFile(tmp, "r");
                    ctx.channel().writeAndFlush(new DefaultFileRegion(raf.getChannel(), 0, tmp.length()));
                }
            }
        }
    }

    /**
     * Client-side handler: counts received bytes and opens the latch once the
     * whole file has arrived.
     */
    private static final class ReadHandler extends ChannelInboundHandlerAdapter {
        private long bytesTransferred;
        private final CountDownLatch latch;

        ReadHandler(CountDownLatch latch) {
            this.latch = latch;
        }

        @Override
        public void channelRead(ChannelHandlerContext ctx, Object msg) {
            if (msg instanceof ByteBuf) {
                ByteBuf buf = (ByteBuf) msg;
                bytesTransferred += buf.readableBytes();
                buf.release();
                if (bytesTransferred == tmp.length()) {
                    latch.countDown();
                }
            }
        }
    }
}
/*
 * Copyright 2020 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.netty.handler.traffic;

import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;

import io.netty.bootstrap.Bootstrap;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.buffer.Unpooled;
import io.netty.channel.Channel;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelInitializer;
import io.netty.channel.DefaultEventLoopGroup;
import io.netty.channel.local.LocalAddress;
import io.netty.channel.local.LocalChannel;
import io.netty.channel.local.LocalServerChannel;
import io.netty.util.Attribute;
import io.netty.util.CharsetUtil;
import org.junit.jupiter.api.AfterAll;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;

/**
 * Verifies that removing a traffic shaping handler from the pipeline releases
 * its pending reopen task (stored as a channel attribute under
 * {@link AbstractTrafficShapingHandler#REOPEN_TASK}).
 */
public class TrafficShapingHandlerTest {

    // 1 byte/s: guarantees the second write gets throttled, so a reopen task
    // is installed on the channel attribute.
    private static final long READ_LIMIT_BYTES_PER_SECOND = 1;
    private static final ScheduledExecutorService SES = Executors.newSingleThreadScheduledExecutor();
    private static final DefaultEventLoopGroup GROUP = new DefaultEventLoopGroup(1);

    @AfterAll
    public static void destroy() {
        GROUP.shutdownGracefully();
        SES.shutdown();
    }

    @Test
    public void testHandlerRemove() throws Exception {
        // Exercise all three shaping-handler flavors; the global ones must be
        // released after use to stop their internal monitoring task.
        testHandlerRemove0(new ChannelTrafficShapingHandler(0, READ_LIMIT_BYTES_PER_SECOND));
        GlobalTrafficShapingHandler trafficHandler1 =
                new GlobalTrafficShapingHandler(SES, 0, READ_LIMIT_BYTES_PER_SECOND);
        try {
            testHandlerRemove0(trafficHandler1);
        } finally {
            trafficHandler1.release();
        }
        GlobalChannelTrafficShapingHandler trafficHandler2 =
                new GlobalChannelTrafficShapingHandler(SES, 0,
                        READ_LIMIT_BYTES_PER_SECOND, 0, READ_LIMIT_BYTES_PER_SECOND);
        try {
            testHandlerRemove0(trafficHandler2);
        } finally {
            trafficHandler2.release();
        }
    }

    /**
     * Connects a local echo server, forces the shaping handler to install a
     * reopen task by writing twice, then removes the handler and asserts the
     * task attribute has been cleared.
     */
    private void testHandlerRemove0(final AbstractTrafficShapingHandler trafficHandler)
            throws Exception {
        Channel svrChannel = null;
        Channel ch = null;
        try {
            ServerBootstrap serverBootstrap = new ServerBootstrap();
            serverBootstrap.channel(LocalServerChannel.class).group(GROUP, GROUP)
                    .childHandler(new ChannelInitializer<Channel>() {
                        @Override
                        protected void initChannel(Channel ch) throws Exception {
                            // Echo server: writes every received message back.
                            ch.pipeline().addLast(new ChannelInboundHandlerAdapter() {
                                @Override
                                public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
                                    ctx.writeAndFlush(msg);
                                }
                            });
                        }
                    });
            final LocalAddress svrAddr = new LocalAddress("foo");
            svrChannel = serverBootstrap.bind(svrAddr).sync().channel();
            Bootstrap bootstrap = new Bootstrap();
            bootstrap.channel(LocalChannel.class).group(GROUP).handler(new ChannelInitializer<Channel>() {
                @Override
                protected void initChannel(Channel ch) throws Exception {
                    ch.pipeline().addLast("traffic-shaping", trafficHandler);
                }
            });
            ch = bootstrap.connect(svrAddr).sync().channel();
            Attribute<Runnable> attr = ch.attr(AbstractTrafficShapingHandler.REOPEN_TASK);
            assertNull(attr.get());
            ch.writeAndFlush(Unpooled.wrappedBuffer("foo".getBytes(CharsetUtil.UTF_8)));
            ch.writeAndFlush(Unpooled.wrappedBuffer("bar".getBytes(CharsetUtil.UTF_8))).await();
            assertNotNull(attr.get());
            final Channel clientChannel = ch;
            ch.eventLoop().submit(new Runnable() {
                @Override
                public void run() {
                    clientChannel.pipeline().remove("traffic-shaping");
                }
            }).await();
            //the attribute--reopen task must be released.
            assertNull(attr.get());
        } finally {
            if (ch != null) {
                ch.close().sync();
            }
            if (svrChannel != null) {
                svrChannel.close().sync();
            }
        }
    }

}
/*
 * Copyright 2021 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.internal.tcnative;

/**
 * Allows to customize private key signing / decrypt (when using RSA) in an
 * asynchronous fashion: the result is delivered through a
 * {@link ResultCallback} instead of a return value.
 */
public interface AsyncSSLPrivateKeyMethod {
    // Signature algorithm identifiers, resolved once from native code.
    int SSL_SIGN_RSA_PKCS1_SHA1 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha1();
    int SSL_SIGN_RSA_PKCS1_SHA256 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha256();
    int SSL_SIGN_RSA_PKCS1_SHA384 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha384();
    int SSL_SIGN_RSA_PKCS1_SHA512 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha512();
    int SSL_SIGN_ECDSA_SHA1 = NativeStaticallyReferencedJniMethods.sslSignEcdsaPkcsSha1();
    int SSL_SIGN_ECDSA_SECP256R1_SHA256 = NativeStaticallyReferencedJniMethods.sslSignEcdsaSecp256r1Sha256();
    int SSL_SIGN_ECDSA_SECP384R1_SHA384 = NativeStaticallyReferencedJniMethods.sslSignEcdsaSecp384r1Sha384();
    int SSL_SIGN_ECDSA_SECP521R1_SHA512 = NativeStaticallyReferencedJniMethods.sslSignEcdsaSecp521r1Sha512();
    int SSL_SIGN_RSA_PSS_RSAE_SHA256 = NativeStaticallyReferencedJniMethods.sslSignRsaPssRsaeSha256();
    int SSL_SIGN_RSA_PSS_RSAE_SHA384 = NativeStaticallyReferencedJniMethods.sslSignRsaPssRsaeSha384();
    int SSL_SIGN_RSA_PSS_RSAE_SHA512 = NativeStaticallyReferencedJniMethods.sslSignRsaPssRsaeSha512();
    int SSL_SIGN_ED25519 = NativeStaticallyReferencedJniMethods.sslSignEd25519();
    int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcs1Md5Sha1();

    /**
     * Sign the input with the given key and notify the {@link ResultCallback}
     * with the signed bytes.
     *
     * @param ssl the SSL instance
     * @param signatureAlgorithm the algorithm to use for signing (one of the
     *        {@code SSL_SIGN_*} constants above)
     * @param input the input itself
     * @param resultCallback the callback that will be notified once the operation completes
     */
    void sign(long ssl, int signatureAlgorithm, byte[] input, ResultCallback<byte[]> resultCallback);

    /**
     * Decrypts the input with the given RSA key and notify the
     * {@link ResultCallback} with the decrypted bytes.
     *
     * @param ssl the SSL instance
     * @param input the input which should be decrypted
     * @param resultCallback the callback that will be notified once the operation completes
     */
    void decrypt(long ssl, byte[] input, ResultCallback<byte[]> resultCallback);
}
+ */ +package io.netty.internal.tcnative; + +final class AsyncSSLPrivateKeyMethodAdapter implements AsyncSSLPrivateKeyMethod { + private final SSLPrivateKeyMethod method; + + AsyncSSLPrivateKeyMethodAdapter(SSLPrivateKeyMethod method) { + if (method == null) { + throw new NullPointerException("method"); + } + this.method = method; + } + + @Override + public void sign(long ssl, int signatureAlgorithm, byte[] input, ResultCallback resultCallback) { + final byte[] result; + try { + result = method.sign(ssl, signatureAlgorithm, input); + } catch (Throwable cause) { + resultCallback.onError(ssl, cause); + return; + } + resultCallback.onSuccess(ssl, result); + } + + @Override + public void decrypt(long ssl, byte[] input, ResultCallback resultCallback) { + final byte[] result; + try { + result = method.decrypt(ssl, input); + } catch (Throwable cause) { + resultCallback.onError(ssl, cause); + return; + } + resultCallback.onSuccess(ssl, result); + } +} diff --git a/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/AsyncTask.java b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/AsyncTask.java new file mode 100644 index 0000000..caf8587 --- /dev/null +++ b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/AsyncTask.java @@ -0,0 +1,27 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.internal.tcnative;

/**
 * A {@link Runnable} that may alternatively be executed asynchronously.
 *
 * <p>Callers that cannot block may use {@link #runAsync(Runnable)} instead of {@link #run()}
 * and be notified of completion via the supplied callback.
 */
public interface AsyncTask extends Runnable {

    /**
     * Run this {@link AsyncTask} in an async fashion, which means it will be run and completed
     * at some point. Once it is done, the given {@link Runnable} is called.
     *
     * @param completeCallback the {@link Runnable} that is run once the task was run and completed
     */
    void runAsync(Runnable completeCallback);
}
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package io.netty.internal.tcnative;

import java.nio.ByteBuffer;

/**
 * JNI helpers for inspecting natively allocated {@link ByteBuffer}s.
 *
 * <p>NOTE(review): both methods presumably require a <em>direct</em> buffer (a heap buffer has no
 * stable native address) — confirm against the native implementation.
 */
public final class Buffer {

    // Static utility holder — not instantiable.
    private Buffer() { }

    /**
     * Returns the memory address of the ByteBuffer.
     *
     * @param buf Previously allocated ByteBuffer.
     * @return the memory address.
     */
    public static native long address(ByteBuffer buf);

    /**
     * Returns the allocated memory size of the ByteBuffer.
     *
     * @param buf Previously allocated ByteBuffer.
     * @return the allocated memory size
     */
    public static native long size(ByteBuffer buf);
}
 */
package io.netty.internal.tcnative;

/**
 * Certificate-selection callback, called during the handshake and hooked into OpenSSL via
 * {@code SSL_CTX_set_cert_cb}.
 *
 * <p>IMPORTANT: Implementations of this interface should be static as it is stored as a global
 * reference via JNI. This means if you use an inner / anonymous class to implement this and also
 * depend on the finalizer of the class to free up the SSLContext, the finalizer will never run as
 * the object is never GC'ed, due to the hard reference to the enclosing class. This will most
 * likely result in a memory leak.
 */
public interface CertificateCallback {

    /**
     * The types contained in the {@code keyTypeBytes} array.
     */
    // Extracted from https://github.com/openssl/openssl/blob/master/include/openssl/tls1.h
    byte TLS_CT_RSA_SIGN = 1;
    byte TLS_CT_DSS_SIGN = 2;
    byte TLS_CT_RSA_FIXED_DH = 3;
    byte TLS_CT_DSS_FIXED_DH = 4;
    byte TLS_CT_ECDSA_SIGN = 64;
    byte TLS_CT_RSA_FIXED_ECDH = 65;
    byte TLS_CT_ECDSA_FIXED_ECDH = 66;

    /**
     * Called during cert selection. If a certificate chain / key should be used,
     * {@link SSL#setKeyMaterial(long, long, long)} must be called from this callback after
     * all preparations / validations were completed.
     *
     * @param ssl the SSL instance
     * @param keyTypeBytes an array of the key types on client-mode or {@code null} on server-mode.
     * @param asn1DerEncodedPrincipals the principals or {@code null}.
     * @throws Exception if selection fails; the surrounding task then reports failure to the
     *         native layer and the handshake is aborted (see {@code CertificateCallbackTask})
     */
    void handle(long ssl, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals) throws Exception;
}
+ @Override + protected void runTask(long ssl, TaskCallback taskCallback) { + try { + callback.handle(ssl, keyTypeBytes, asn1DerEncodedPrincipals); + taskCallback.onResult(ssl, 1); + } catch (Exception e) { + // Just catch the exception and return 0 to fail the handshake. + // The problem is that rethrowing here is really "useless" as we will process it as part of an openssl + // c callback which needs to return 0 for an error to abort the handshake. + taskCallback.onResult(ssl, 0); + } + } +} diff --git a/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateCompressionAlgo.java b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateCompressionAlgo.java new file mode 100644 index 0000000..b115843 --- /dev/null +++ b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/CertificateCompressionAlgo.java @@ -0,0 +1,70 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.internal.tcnative; + +/** + * Provides compression/decompression implementations for TLS Certificate Compression + * (RFC 8879). 
 */
public interface CertificateCompressionAlgo {

    // Algorithm IDs from RFC 8879, resolved once from the native library.
    int TLS_EXT_CERT_COMPRESSION_ZLIB = NativeStaticallyReferencedJniMethods.tlsExtCertCompressionZlib();
    int TLS_EXT_CERT_COMPRESSION_BROTLI = NativeStaticallyReferencedJniMethods.tlsExtCertCompressionBrotli();
    int TLS_EXT_CERT_COMPRESSION_ZSTD = NativeStaticallyReferencedJniMethods.tlsExtCertCompressionZstd();

    /**
     * Compress the given input with the specified algorithm and return the compressed bytes.
     *
     * @param ssl the SSL instance
     * @param input the uncompressed form of the certificate
     * @return the compressed form of the certificate
     * @throws Exception thrown if an error occurs while compressing
     */
    byte[] compress(long ssl, byte[] input) throws Exception;

    /**
     * Decompress the given input with the specified algorithm and return the decompressed bytes.
     *
     * <h2>Implementation Security Considerations</h2>
     * <p>
     * Implementations SHOULD bound the memory usage when decompressing the
     * CompressedCertificate message.
     * <p>
     * Implementations MUST limit the size of the resulting decompressed chain to the specified
     * {@code uncompressedLen}, and they MUST abort the connection (throw an exception) if the
     * size of the output of the decompression function exceeds that limit.
     *
     * @param ssl the SSL instance
     * @param uncompressedLen the expected length of the uncompressed certificate
     * @param input the compressed form of the certificate
     * @return the decompressed form of the certificate
     * @throws Exception thrown if an error occurs while decompressing or the output
     *                   size exceeds {@code uncompressedLen}
     */
    byte[] decompress(long ssl, int uncompressedLen, byte[] input) throws Exception;

    /**
     * Return the ID for the compression algorithm provided for by a given implementation.
     *
     * @return compression algorithm ID as specified by RFC 8879:
     *         <ul>
     *           <li>{@link CertificateCompressionAlgo#TLS_EXT_CERT_COMPRESSION_ZLIB}</li>
     *           <li>{@link CertificateCompressionAlgo#TLS_EXT_CERT_COMPRESSION_BROTLI}</li>
     *           <li>{@link CertificateCompressionAlgo#TLS_EXT_CERT_COMPRESSION_ZSTD}</li>
     *         </ul>
     */
    int algorithmId();

}
     */
    // Extracted from https://github.com/openssl/openssl/blob/master/include/openssl/tls1.h
    // Aliases of the CertificateCallback constants, kept for binary compatibility of this
    // deprecated interface.
    byte TLS_CT_RSA_SIGN = CertificateCallback.TLS_CT_RSA_SIGN;
    byte TLS_CT_DSS_SIGN = CertificateCallback.TLS_CT_DSS_SIGN;
    byte TLS_CT_RSA_FIXED_DH = CertificateCallback.TLS_CT_RSA_FIXED_DH;
    byte TLS_CT_DSS_FIXED_DH = CertificateCallback.TLS_CT_DSS_FIXED_DH;
    byte TLS_CT_ECDSA_SIGN = CertificateCallback.TLS_CT_ECDSA_SIGN;
    byte TLS_CT_RSA_FIXED_ECDH = CertificateCallback.TLS_CT_RSA_FIXED_ECDH;
    byte TLS_CT_ECDSA_FIXED_ECDH = CertificateCallback.TLS_CT_ECDSA_FIXED_ECDH;

    /**
     * Called during cert selection. If a certificate chain / key should be used,
     * {@link SSL#setKeyMaterialClientSide(long, long, long, long, long)} must be called from this
     * callback after all preparations / validations were completed.
     *
     * @param ssl the SSL instance
     * @param certOut the pointer to the pointer of the certificate to use.
     * @param keyOut the pointer to the pointer of the private key to use.
     * @param keyTypeBytes an array of the key types.
     * @param asn1DerEncodedPrincipals the principals
     * @throws Exception if selection fails
     */
    void requested(long ssl, long certOut, long keyOut, byte[] keyTypeBytes, byte[][] asn1DerEncodedPrincipals)
            throws Exception;
}
 * You may obtain a copy of the License at:
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.internal.tcnative;

import java.util.Collections;
import java.util.HashSet;
import java.util.Set;

import static io.netty.internal.tcnative.NativeStaticallyReferencedJniMethods.*;

/**
 * Certificate-chain verifier, called during the handshake and hooked into OpenSSL via
 * {@code SSL_CTX_set_cert_verify_callback}.
 *
 * <p>IMPORTANT: Implementations of this interface should be static as it is stored as a global
 * reference via JNI. This means if you use an inner / anonymous class to implement this and also
 * depend on the finalizer of the class to free up the SSLContext, the finalizer will never run as
 * the object is never GC'ed, due to the hard reference to the enclosing class. This will most
 * likely result in a memory leak.
 */
public abstract class CertificateVerifier {

    // X509_V_* verification status codes, resolved once from the native library.
    // WARNING: If you add any new field here you also need to add it to the ERRORS set!
    public static final int X509_V_OK = x509vOK();
    public static final int X509_V_ERR_UNSPECIFIED = x509vErrUnspecified();
    public static final int X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT = x509vErrUnableToGetIssuerCert();
    public static final int X509_V_ERR_UNABLE_TO_GET_CRL = x509vErrUnableToGetCrl();
    public static final int X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE = x509vErrUnableToDecryptCertSignature();
    public static final int X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE = x509vErrUnableToDecryptCrlSignature();
    public static final int X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY = x509vErrUnableToDecodeIssuerPublicKey();
    public static final int X509_V_ERR_CERT_SIGNATURE_FAILURE = x509vErrCertSignatureFailure();
    public static final int X509_V_ERR_CRL_SIGNATURE_FAILURE = x509vErrCrlSignatureFailure();
    public static final int X509_V_ERR_CERT_NOT_YET_VALID = x509vErrCertNotYetValid();
    public static final int X509_V_ERR_CERT_HAS_EXPIRED = x509vErrCertHasExpired();
    public static final int X509_V_ERR_CRL_NOT_YET_VALID = x509vErrCrlNotYetValid();
    public static final int X509_V_ERR_CRL_HAS_EXPIRED = x509vErrCrlHasExpired();
    public static final int X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD = x509vErrErrorInCertNotBeforeField();
    public static final int X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD = x509vErrErrorInCertNotAfterField();
    public static final int X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD = x509vErrErrorInCrlLastUpdateField();
    public static final int X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD = x509vErrErrorInCrlNextUpdateField();
    public static final int X509_V_ERR_OUT_OF_MEM = x509vErrOutOfMem();
    public static final int X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT = x509vErrDepthZeroSelfSignedCert();
    public static final int X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN = x509vErrSelfSignedCertInChain();
    public static final int X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY = x509vErrUnableToGetIssuerCertLocally();
    public static final int X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE = x509vErrUnableToVerifyLeafSignature();
    public static final int X509_V_ERR_CERT_CHAIN_TOO_LONG = x509vErrCertChainTooLong();
    public static final int X509_V_ERR_CERT_REVOKED = x509vErrCertRevoked();
    public static final int X509_V_ERR_INVALID_CA = x509vErrInvalidCa();
    public static final int X509_V_ERR_PATH_LENGTH_EXCEEDED = x509vErrPathLengthExceeded();
    public static final int X509_V_ERR_INVALID_PURPOSE = x509vErrInvalidPurpose();
    public static final int X509_V_ERR_CERT_UNTRUSTED = x509vErrCertUntrusted();
    public static final int X509_V_ERR_CERT_REJECTED = x509vErrCertRejected();
    public static final int X509_V_ERR_SUBJECT_ISSUER_MISMATCH = x509vErrSubjectIssuerMismatch();
    public static final int X509_V_ERR_AKID_SKID_MISMATCH = x509vErrAkidSkidMismatch();
    public static final int X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH = x509vErrAkidIssuerSerialMismatch();
    public static final int X509_V_ERR_KEYUSAGE_NO_CERTSIGN = x509vErrKeyUsageNoCertSign();
    public static final int X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER = x509vErrUnableToGetCrlIssuer();
    public static final int X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION = x509vErrUnhandledCriticalExtension();
    public static final int X509_V_ERR_KEYUSAGE_NO_CRL_SIGN = x509vErrKeyUsageNoCrlSign();
    public static final int X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION = x509vErrUnhandledCriticalCrlExtension();
    public static final int X509_V_ERR_INVALID_NON_CA = x509vErrInvalidNonCa();
    public static final int X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED = x509vErrProxyPathLengthExceeded();
    public static final int X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE = x509vErrKeyUsageNoDigitalSignature();
    public static final int X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED = x509vErrProxyCertificatesNotAllowed();
    public static final int X509_V_ERR_INVALID_EXTENSION = x509vErrInvalidExtension();
    public static final int X509_V_ERR_INVALID_POLICY_EXTENSION = x509vErrInvalidPolicyExtension();
    public static final int X509_V_ERR_NO_EXPLICIT_POLICY = x509vErrNoExplicitPolicy();
    // NOTE: the native accessor name ("Differnt") carries a historical typo; kept as-is.
    public static final int X509_V_ERR_DIFFERENT_CRL_SCOPE = x509vErrDifferntCrlScope();
    public static final int X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE = x509vErrUnsupportedExtensionFeature();
    public static final int X509_V_ERR_UNNESTED_RESOURCE = x509vErrUnnestedResource();
    public static final int X509_V_ERR_PERMITTED_VIOLATION = x509vErrPermittedViolation();
    public static final int X509_V_ERR_EXCLUDED_VIOLATION = x509vErrExcludedViolation();
    public static final int X509_V_ERR_SUBTREE_MINMAX = x509vErrSubtreeMinMax();
    public static final int X509_V_ERR_APPLICATION_VERIFICATION = x509vErrApplicationVerification();
    public static final int X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE = x509vErrUnsupportedConstraintType();
    public static final int X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX = x509vErrUnsupportedConstraintSyntax();
    public static final int X509_V_ERR_UNSUPPORTED_NAME_SYNTAX = x509vErrUnsupportedNameSyntax();
    public static final int X509_V_ERR_CRL_PATH_VALIDATION_ERROR = x509vErrCrlPathValidationError();
    public static final int X509_V_ERR_PATH_LOOP = x509vErrPathLoop();
    public static final int X509_V_ERR_SUITE_B_INVALID_VERSION = x509vErrSuiteBInvalidVersion();
    public static final int X509_V_ERR_SUITE_B_INVALID_ALGORITHM = x509vErrSuiteBInvalidAlgorithm();
    public static final int X509_V_ERR_SUITE_B_INVALID_CURVE = x509vErrSuiteBInvalidCurve();
    public static final int X509_V_ERR_SUITE_B_INVALID_SIGNATURE_ALGORITHM = x509vErrSuiteBInvalidSignatureAlgorithm();
    public static final int X509_V_ERR_SUITE_B_LOS_NOT_ALLOWED = x509vErrSuiteBLosNotAllowed();
    public static final int X509_V_ERR_SUITE_B_CANNOT_SIGN_P_384_WITH_P_256 = x509vErrSuiteBCannotSignP384WithP256();
    public static final int X509_V_ERR_HOSTNAME_MISMATCH = x509vErrHostnameMismatch();
    public static final int X509_V_ERR_EMAIL_MISMATCH = x509vErrEmailMismatch();
    public static final int X509_V_ERR_IP_ADDRESS_MISMATCH = x509vErrIpAddressMismatch();
    public static final int X509_V_ERR_DANE_NO_MATCH = x509vErrDaneNoMatch();

    // The set of all valid status codes above; consulted by isValid(int).
    private static final Set ERRORS;

    static {
        Set errors = new HashSet();
        errors.add(X509_V_OK);
        errors.add(X509_V_ERR_UNSPECIFIED);
        errors.add(X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT);
        errors.add(X509_V_ERR_UNABLE_TO_GET_CRL);
        errors.add(X509_V_ERR_UNABLE_TO_DECRYPT_CERT_SIGNATURE);
        errors.add(X509_V_ERR_UNABLE_TO_DECRYPT_CRL_SIGNATURE);
        errors.add(X509_V_ERR_UNABLE_TO_DECODE_ISSUER_PUBLIC_KEY);
        errors.add(X509_V_ERR_CERT_SIGNATURE_FAILURE);
        errors.add(X509_V_ERR_CRL_SIGNATURE_FAILURE);
        errors.add(X509_V_ERR_CERT_NOT_YET_VALID);
        errors.add(X509_V_ERR_CERT_HAS_EXPIRED);
        errors.add(X509_V_ERR_CRL_NOT_YET_VALID);
        errors.add(X509_V_ERR_CRL_HAS_EXPIRED);
        errors.add(X509_V_ERR_ERROR_IN_CERT_NOT_BEFORE_FIELD);
        errors.add(X509_V_ERR_ERROR_IN_CERT_NOT_AFTER_FIELD);
        errors.add(X509_V_ERR_ERROR_IN_CRL_LAST_UPDATE_FIELD);
        errors.add(X509_V_ERR_ERROR_IN_CRL_NEXT_UPDATE_FIELD);
        errors.add(X509_V_ERR_OUT_OF_MEM);
        errors.add(X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT);
        errors.add(X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN);
        errors.add(X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY);
        errors.add(X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE);
        errors.add(X509_V_ERR_CERT_CHAIN_TOO_LONG);
        errors.add(X509_V_ERR_CERT_REVOKED);
        errors.add(X509_V_ERR_INVALID_CA);
        errors.add(X509_V_ERR_PATH_LENGTH_EXCEEDED);
        errors.add(X509_V_ERR_INVALID_PURPOSE);
        errors.add(X509_V_ERR_CERT_UNTRUSTED);
        errors.add(X509_V_ERR_CERT_REJECTED);
        errors.add(X509_V_ERR_SUBJECT_ISSUER_MISMATCH);
        errors.add(X509_V_ERR_AKID_SKID_MISMATCH);
        errors.add(X509_V_ERR_AKID_ISSUER_SERIAL_MISMATCH);
        errors.add(X509_V_ERR_KEYUSAGE_NO_CERTSIGN);
        errors.add(X509_V_ERR_UNABLE_TO_GET_CRL_ISSUER);
        errors.add(X509_V_ERR_UNHANDLED_CRITICAL_EXTENSION);
        errors.add(X509_V_ERR_KEYUSAGE_NO_CRL_SIGN);
        errors.add(X509_V_ERR_UNHANDLED_CRITICAL_CRL_EXTENSION);
        errors.add(X509_V_ERR_INVALID_NON_CA);
        errors.add(X509_V_ERR_PROXY_PATH_LENGTH_EXCEEDED);
        errors.add(X509_V_ERR_KEYUSAGE_NO_DIGITAL_SIGNATURE);
        errors.add(X509_V_ERR_PROXY_CERTIFICATES_NOT_ALLOWED);
        errors.add(X509_V_ERR_INVALID_EXTENSION);
        errors.add(X509_V_ERR_INVALID_POLICY_EXTENSION);
        errors.add(X509_V_ERR_NO_EXPLICIT_POLICY);
        errors.add(X509_V_ERR_DIFFERENT_CRL_SCOPE);
        errors.add(X509_V_ERR_UNSUPPORTED_EXTENSION_FEATURE);
        errors.add(X509_V_ERR_UNNESTED_RESOURCE);
        errors.add(X509_V_ERR_PERMITTED_VIOLATION);
        errors.add(X509_V_ERR_EXCLUDED_VIOLATION);
        errors.add(X509_V_ERR_SUBTREE_MINMAX);
        errors.add(X509_V_ERR_APPLICATION_VERIFICATION);
        errors.add(X509_V_ERR_UNSUPPORTED_CONSTRAINT_TYPE);
        errors.add(X509_V_ERR_UNSUPPORTED_CONSTRAINT_SYNTAX);
        errors.add(X509_V_ERR_UNSUPPORTED_NAME_SYNTAX);
        errors.add(X509_V_ERR_CRL_PATH_VALIDATION_ERROR);
        errors.add(X509_V_ERR_PATH_LOOP);
        errors.add(X509_V_ERR_SUITE_B_INVALID_VERSION);
        errors.add(X509_V_ERR_SUITE_B_INVALID_ALGORITHM);
        errors.add(X509_V_ERR_SUITE_B_INVALID_CURVE);
        errors.add(X509_V_ERR_SUITE_B_INVALID_SIGNATURE_ALGORITHM);
        errors.add(X509_V_ERR_SUITE_B_LOS_NOT_ALLOWED);
        errors.add(X509_V_ERR_SUITE_B_CANNOT_SIGN_P_384_WITH_P_256);
        errors.add(X509_V_ERR_HOSTNAME_MISMATCH);
        errors.add(X509_V_ERR_EMAIL_MISMATCH);
        errors.add(X509_V_ERR_IP_ADDRESS_MISMATCH);
        errors.add(X509_V_ERR_DANE_NO_MATCH);
        ERRORS = Collections.unmodifiableSet(errors);
    }

    /**
     * Returns {@code true} if the given {@code errorCode} is one of the known
     * {@code X509_V_*} status codes declared on this class, {@code false} otherwise.
     */
    public static boolean isValid(int errorCode) {
        return ERRORS.contains(errorCode);
    }

    /**
     * Verifies the passed in certificate chain.
     *
     * <p>NOTE(review): the original javadoc described a boolean return, but the method returns
     * an {@code int}. Based on {@link #isValid(int)} and the constants above, the return value
     * is presumably one of the {@code X509_V_*} status codes, with {@link #X509_V_OK} meaning
     * the chain verified successfully and the handshake should proceed — confirm against the
     * native {@code SSL_CTX_set_cert_verify_callback} glue.
     *
     * @param ssl the SSL instance
     * @param x509 the {@code X509} certificate chain
     * @param authAlgorithm the auth algorithm
     * @return an {@code X509_V_*} status code; {@link #X509_V_OK} on success
     */
    public abstract int verify(long ssl, byte[][] x509, String authAlgorithm);
}
+ */ +final class CertificateVerifierTask extends SSLTask { + private final byte[][] x509; + private final String authAlgorithm; + private final CertificateVerifier verifier; + + CertificateVerifierTask(long ssl, byte[][] x509, String authAlgorithm, CertificateVerifier verifier) { + super(ssl); + this.x509 = x509; + this.authAlgorithm = authAlgorithm; + this.verifier = verifier; + } + + @Override + protected void runTask(long ssl, TaskCallback callback) { + int result = verifier.verify(ssl, x509, authAlgorithm); + callback.onResult(ssl, result); + } +} diff --git a/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/Library.java b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/Library.java new file mode 100644 index 0000000..074162a --- /dev/null +++ b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/Library.java @@ -0,0 +1,202 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
package io.netty.internal.tcnative;

import java.io.File;

/**
 * Loads and initializes the native {@code netty_tcnative} library.
 *
 * <p>Callers either let this class locate the library on {@code java.library.path}
 * (via {@link #initialize()}) or pass an explicit library name / the marker
 * {@code "provided"} to {@link #initialize(String, String)}.
 */
public final class Library {

    /* Default library names probed on java.library.path, in order. */
    private static final String[] NAMES = {
        "netty_tcnative",
        "libnetty_tcnative"
    };

    /** Marker name meaning the caller has already loaded the native library itself. */
    private static final String PROVIDED = "provided";

    /*
     * A handle to the unique Library singleton instance.
     */
    private static Library instance = null;

    static {
        // Preload all classes that will be used in the OnLoad(...) function of JNI to eliminate the possibility
        // of a class-loader deadlock. This is a workaround for https://github.com/netty/netty/issues/11209.
        //
        // This needs to match all the classes that are loaded via NETTY_JNI_UTIL_LOAD_CLASS or looked up via
        // NETTY_JNI_UTIL_FIND_CLASS.
        tryLoadClasses(ClassLoader.getSystemClassLoader(),
                // error
                Exception.class, NullPointerException.class, IllegalArgumentException.class, OutOfMemoryError.class,

                // jnilib
                String.class, byte[].class,

                // sslcontext
                SSLTask.class, CertificateCallbackTask.class, CertificateCallback.class, SSLPrivateKeyMethodTask.class,
                SSLPrivateKeyMethodSignTask.class, SSLPrivateKeyMethodDecryptTask.class
        );
    }

    /**
     * Preload the given classes and so ensure the {@link ClassLoader} has these loaded after this method call.
     *
     * @param classLoader the {@link ClassLoader}
     * @param classes the classes to load.
     */
    private static void tryLoadClasses(ClassLoader classLoader, Class<?>... classes) {
        for (Class<?> clazz : classes) {
            tryLoadClass(classLoader, clazz.getName());
        }
    }

    private static void tryLoadClass(ClassLoader classLoader, String className) {
        try {
            // Load the class and also ensure we init it which means its linked etc.
            Class.forName(className, true, classLoader);
        } catch (ClassNotFoundException | SecurityException ignore) {
            // Ignore: preloading is best-effort; JNI OnLoad will fail later if a class is truly required.
        }
    }

    /**
     * Probes each default library name in turn until one loads.
     *
     * @throws RuntimeException if a candidate library file exists on java.library.path but failed to load
     * @throws UnsatisfiedLinkError if no candidate could be loaded at all
     */
    private Library() throws Exception {
        boolean loaded = false;
        String path = System.getProperty("java.library.path");
        String[] paths = path.split(File.pathSeparator);
        StringBuilder err = new StringBuilder();
        for (int i = 0; i < NAMES.length; i++) {
            try {
                loadLibrary(NAMES[i]);
                loaded = true;
            } catch (ThreadDeath | VirtualMachineError t) {
                // Never swallow JVM-fatal errors.
                throw t;
            } catch (Throwable t) {
                String name = System.mapLibraryName(NAMES[i]);
                for (int j = 0; j < paths.length; j++) {
                    File fd = new File(paths[j], name);
                    if (fd.exists()) {
                        // File exists but failed to load
                        throw new RuntimeException(t);
                    }
                }
                if (i > 0) {
                    err.append(", ");
                }
                err.append(t.getMessage());
            }
            if (loaded) {
                break;
            }
        }
        if (!loaded) {
            throw new UnsatisfiedLinkError(err.toString());
        }
    }

    /**
     * Loads the named library unless it is {@code "provided"}, in which case the
     * caller is responsible for having loaded the native code already.
     */
    private Library(String libraryName) {
        if (!PROVIDED.equals(libraryName)) {
            loadLibrary(libraryName);
        }
    }

    private static void loadLibrary(String libraryName) {
        System.loadLibrary(calculatePackagePrefix().replace('.', '_') + libraryName);
    }

    /**
     * The shading prefix added to this class's full name.
     *
     * @throws UnsatisfiedLinkError if the shader used something other than a prefix
     */
    private static String calculatePackagePrefix() {
        String maybeShaded = Library.class.getName();
        // Use ! instead of . to avoid shading utilities from modifying the string
        String expected = "io!netty!internal!tcnative!Library".replace('!', '.');
        if (!maybeShaded.endsWith(expected)) {
            throw new UnsatisfiedLinkError(String.format(
                    "Could not find prefix added to %s to get %s. When shading, only adding a "
                            + "package prefix is supported", expected, maybeShaded));
        }
        return maybeShaded.substring(0, maybeShaded.length() - expected.length());
    }

    /* create global TCN's APR pool
     * This has to be the first call to TCN library.
     */
    private static native boolean initialize0();

    private static native boolean aprHasThreads();

    private static native int aprMajorVersion();

    /* APR_VERSION_STRING */
    private static native String aprVersionString();

    /**
     * Calls {@link #initialize(String, String)} with {@code "provided"} and {@code null}.
     *
     * @return {@code true} if initialization was successful
     * @throws Exception if an error happens during initialization
     */
    public static boolean initialize() throws Exception {
        return initialize(PROVIDED, null);
    }

    /**
     * Setup native library. This is the first method that must be called!
     *
     * @param libraryName the name of the library to load
     * @param engine Support for external a Crypto Device ("engine"), usually
     * @return {@code true} if initialization was successful
     * @throws Exception if an error happens during initialization
     */
    public static boolean initialize(String libraryName, String engine) throws Exception {
        if (instance == null) {
            instance = libraryName == null ? new Library() : new Library(libraryName);

            if (aprMajorVersion() < 1) {
                throw new UnsatisfiedLinkError("Unsupported APR Version ("
                        + aprVersionString() + ")");
            }

            if (!aprHasThreads()) {
                throw new UnsatisfiedLinkError("Missing APR_HAS_THREADS");
            }
        }
        return initialize0() && SSL.initialize(engine) == 0;
    }
}
/**
 * This class is necessary to break the following cyclic dependency:
 * <ol>
 * <li>JNI_OnLoad</li>
 * <li>JNI Calls FindClass because RegisterNatives (used to register JNI methods) requires a class</li>
 * <li>FindClass loads the class, but static members variables of that class attempt to call a JNI method which has
 * not yet been registered.</li>
 * <li>{@link UnsatisfiedLinkError} is thrown because native method has not yet been registered.</li>
 * </ol>
 * Static members which call JNI methods must not be declared in this class!
 */
final class NativeStaticallyReferencedJniMethods {
    private NativeStaticallyReferencedJniMethods() {
    }

    /**
     * Options that may impact security and may be set by default, as defined in the
     * OpenSSL {@code SSL_CTX_set_options} documentation.
     */
    static native int sslOpCipherServerPreference();
    static native int sslOpNoSSLv2();
    static native int sslOpNoSSLv3();
    static native int sslOpNoTLSv1();
    static native int sslOpNoTLSv11();
    static native int sslOpNoTLSv12();
    static native int sslOpNoTLSv13();
    static native int sslOpNoTicket();
    static native int sslOpAllowUnsafeLegacyRenegotiation();
    static native int sslOpLegacyServerConnect();

    /**
     * Options not defined in the OpenSSL docs but may impact security.
     */
    static native int sslOpNoCompression();

    /* Session-cache mode flags (SSL_SESS_CACHE_*). */
    static native int sslSessCacheOff();
    static native int sslSessCacheServer();
    static native int sslSessCacheClient();
    static native int sslSessCacheNoInternalLookup();
    static native int sslSessCacheNoInternalStore();

    /* Handshake state constants. */
    static native int sslStConnect();
    static native int sslStAccept();

    /* SSL_MODE_* flags. */
    static native int sslModeEnablePartialWrite();
    static native int sslModeAcceptMovingWriteBuffer();
    static native int sslModeReleaseBuffers();
    static native int sslModeEnableFalseStart();

    /* Shutdown state and SSL_ERROR_* codes. */
    static native int sslSendShutdown();
    static native int sslReceivedShutdown();
    static native int sslErrorNone();
    static native int sslErrorSSL();
    static native int sslErrorWantRead();
    static native int sslErrorWantWrite();
    static native int sslErrorWantX509Lookup();
    static native int sslErrorSyscall();
    static native int sslErrorZeroReturn();
    static native int sslErrorWantConnect();
    static native int sslErrorWantAccept();

    /* Record-length limits. */
    static native int sslMaxPlaintextLength();
    static native int sslMaxEncryptedLength();
    static native int sslMaxRecordLength();

    /* X509_CHECK_FLAG_* values used for host/name checking. */
    static native int x509CheckFlagAlwaysCheckSubject();
    static native int x509CheckFlagDisableWildCards();
    static native int x509CheckFlagNoPartialWildCards();
    static native int x509CheckFlagMultiLabelWildCards();

    /* x509 certificate verification errors */
    static native int x509vOK();
    static native int x509vErrUnspecified();
    static native int x509vErrUnableToGetIssuerCert();
    static native int x509vErrUnableToGetCrl();
    static native int x509vErrUnableToDecryptCertSignature();
    static native int x509vErrUnableToDecryptCrlSignature();
    static native int x509vErrUnableToDecodeIssuerPublicKey();
    static native int x509vErrCertSignatureFailure();
    static native int x509vErrCrlSignatureFailure();
    static native int x509vErrCertNotYetValid();
    static native int x509vErrCertHasExpired();
    static native int x509vErrCrlNotYetValid();
    static native int x509vErrCrlHasExpired();
    static native int x509vErrErrorInCertNotBeforeField();
    static native int x509vErrErrorInCertNotAfterField();
    static native int x509vErrErrorInCrlLastUpdateField();
    static native int x509vErrErrorInCrlNextUpdateField();
    static native int x509vErrOutOfMem();
    static native int x509vErrDepthZeroSelfSignedCert();
    static native int x509vErrSelfSignedCertInChain();
    static native int x509vErrUnableToGetIssuerCertLocally();
    static native int x509vErrUnableToVerifyLeafSignature();
    static native int x509vErrCertChainTooLong();
    static native int x509vErrCertRevoked();
    static native int x509vErrInvalidCa();
    static native int x509vErrPathLengthExceeded();
    static native int x509vErrInvalidPurpose();
    static native int x509vErrCertUntrusted();
    static native int x509vErrCertRejected();
    static native int x509vErrSubjectIssuerMismatch();
    static native int x509vErrAkidSkidMismatch();
    static native int x509vErrAkidIssuerSerialMismatch();
    static native int x509vErrKeyUsageNoCertSign();
    static native int x509vErrUnableToGetCrlIssuer();
    static native int x509vErrUnhandledCriticalExtension();
    static native int x509vErrKeyUsageNoCrlSign();
    static native int x509vErrUnhandledCriticalCrlExtension();
    static native int x509vErrInvalidNonCa();
    static native int x509vErrProxyPathLengthExceeded();
    static native int x509vErrKeyUsageNoDigitalSignature();
    static native int x509vErrProxyCertificatesNotAllowed();
    static native int x509vErrInvalidExtension();
    static native int x509vErrInvalidPolicyExtension();
    static native int x509vErrNoExplicitPolicy();
    // NOTE: "Differnt" is misspelled but is the registered JNI name; do not rename.
    static native int x509vErrDifferntCrlScope();
    static native int x509vErrUnsupportedExtensionFeature();
    static native int x509vErrUnnestedResource();
    static native int x509vErrPermittedViolation();
    static native int x509vErrExcludedViolation();
    static native int x509vErrSubtreeMinMax();
    static native int x509vErrApplicationVerification();
    static native int x509vErrUnsupportedConstraintType();
    static native int x509vErrUnsupportedConstraintSyntax();
    static native int x509vErrUnsupportedNameSyntax();
    static native int x509vErrCrlPathValidationError();
    static native int x509vErrPathLoop();
    static native int x509vErrSuiteBInvalidVersion();
    static native int x509vErrSuiteBInvalidAlgorithm();
    static native int x509vErrSuiteBInvalidCurve();
    static native int x509vErrSuiteBInvalidSignatureAlgorithm();
    static native int x509vErrSuiteBLosNotAllowed();
    static native int x509vErrSuiteBCannotSignP384WithP256();
    static native int x509vErrHostnameMismatch();
    static native int x509vErrEmailMismatch();
    static native int x509vErrIpAddressMismatch();
    static native int x509vErrDaneNoMatch();

    // BoringSSL specific.
    static native int sslErrorWantCertificateVerify();
    static native int sslErrorWantPrivateKeyOperation();
    static native int sslSignRsaPkcsSha1();
    static native int sslSignRsaPkcsSha256();
    static native int sslSignRsaPkcsSha384();
    static native int sslSignRsaPkcsSha512();
    static native int sslSignEcdsaPkcsSha1();
    static native int sslSignEcdsaSecp256r1Sha256();
    static native int sslSignEcdsaSecp384r1Sha384();
    static native int sslSignEcdsaSecp521r1Sha512();
    static native int sslSignRsaPssRsaeSha256();
    static native int sslSignRsaPssRsaeSha384();
    static native int sslSignRsaPssRsaeSha512();
    static native int sslSignEd25519();
    static native int sslSignRsaPkcs1Md5Sha1();

    /* Renegotiation policy and certificate-compression constants. */
    static native int sslRenegotiateNever();
    static native int sslRenegotiateOnce();
    static native int sslRenegotiateFreely();
    static native int sslRenegotiateIgnore();
    static native int sslRenegotiateExplicit();
    static native int sslCertCompressionDirectionCompress();
    static native int sslCertCompressionDirectionDecompress();
    static native int sslCertCompressionDirectionBoth();
    static native int tlsExtCertCompressionZlib();
    static native int tlsExtCertCompressionBrotli();
    static native int tlsExtCertCompressionZstd();
}
/**
 * Callback that is called once an operation completed.
 *
 * @param <T> the result type.
 */
public interface ResultCallback<T> {

    /**
     * Called when the operation completes with the given result.
     *
     * @param ssl the SSL instance (SSL *)
     * @param result the result.
     */
    void onSuccess(long ssl, T result);

    /**
     * Called when the operation completes with an error.
     *
     * @param ssl the SSL instance (SSL *)
     * @param cause the error.
     */
    void onError(long ssl, Throwable cause);
}
+ */ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package io.netty.internal.tcnative; + +import java.nio.ByteBuffer; + +import javax.net.ssl.SSLEngine; + +import static io.netty.internal.tcnative.NativeStaticallyReferencedJniMethods.*; + +public final class SSL { + + private SSL() { } + + /* + * Define the SSL Protocol options + */ + public static final int SSL_PROTOCOL_NONE = 0; + public static final int SSL_PROTOCOL_SSLV2 = (1<<0); + public static final int SSL_PROTOCOL_SSLV3 = (1<<1); + public static final int SSL_PROTOCOL_TLSV1 = (1<<2); + public static final int SSL_PROTOCOL_TLSV1_1 = (1<<3); + public static final int SSL_PROTOCOL_TLSV1_2 = (1<<4); + public static final int SSL_PROTOCOL_TLSV1_3 = (1<<5); + + /** TLS_*method according to SSL_CTX_new */ + public static final int SSL_PROTOCOL_TLS = (SSL_PROTOCOL_SSLV3 | SSL_PROTOCOL_TLSV1 | SSL_PROTOCOL_TLSV1_1 | SSL_PROTOCOL_TLSV1_2 | SSL_PROTOCOL_TLSV1_3); + public static final int SSL_PROTOCOL_ALL = (SSL_PROTOCOL_SSLV2 | SSL_PROTOCOL_TLS); + + /* + * Define the SSL verify levels + */ + public static final int SSL_CVERIFY_IGNORED = -1; + public static final int SSL_CVERIFY_NONE = 0; + public static final int SSL_CVERIFY_OPTIONAL = 1; + public static final 
int SSL_CVERIFY_REQUIRED = 2; + + public static final int SSL_OP_CIPHER_SERVER_PREFERENCE = sslOpCipherServerPreference(); + public static final int SSL_OP_NO_SSLv2 = sslOpNoSSLv2(); + public static final int SSL_OP_NO_SSLv3 = sslOpNoSSLv3(); + public static final int SSL_OP_NO_TLSv1 = sslOpNoTLSv1(); + public static final int SSL_OP_NO_TLSv1_1 = sslOpNoTLSv11(); + public static final int SSL_OP_NO_TLSv1_2 = sslOpNoTLSv12(); + public static final int SSL_OP_NO_TLSv1_3 = sslOpNoTLSv13(); + public static final int SSL_OP_NO_TICKET = sslOpNoTicket(); + + public static final int SSL_OP_NO_COMPRESSION = sslOpNoCompression(); + public static final int SSL_OP_ALLOW_UNSAFE_LEGACY_RENEGOTIATION = sslOpAllowUnsafeLegacyRenegotiation(); + public static final int SSL_OP_LEGACY_SERVER_CONNECT = sslOpLegacyServerConnect(); + + public static final int SSL_MODE_CLIENT = 0; + public static final int SSL_MODE_SERVER = 1; + public static final int SSL_MODE_COMBINED = 2; + + public static final long SSL_SESS_CACHE_OFF = sslSessCacheOff(); + public static final long SSL_SESS_CACHE_SERVER = sslSessCacheServer(); + public static final long SSL_SESS_CACHE_CLIENT = sslSessCacheClient(); + public static final long SSL_SESS_CACHE_NO_INTERNAL_LOOKUP = sslSessCacheNoInternalLookup(); + public static final long SSL_SESS_CACHE_NO_INTERNAL_STORE = sslSessCacheNoInternalStore(); + + public static final int SSL_SELECTOR_FAILURE_NO_ADVERTISE = 0; + public static final int SSL_SELECTOR_FAILURE_CHOOSE_MY_LAST_PROTOCOL = 1; + + public static final int SSL_ST_CONNECT = sslStConnect(); + public static final int SSL_ST_ACCEPT = sslStAccept(); + + public static final int SSL_MODE_ENABLE_PARTIAL_WRITE = sslModeEnablePartialWrite(); + public static final int SSL_MODE_ACCEPT_MOVING_WRITE_BUFFER = sslModeAcceptMovingWriteBuffer(); + public static final int SSL_MODE_RELEASE_BUFFERS = sslModeReleaseBuffers(); + public static final int SSL_MODE_ENABLE_FALSE_START = sslModeEnableFalseStart(); + public static final 
int SSL_MAX_PLAINTEXT_LENGTH = sslMaxPlaintextLength(); + public static final int SSL_MAX_ENCRYPTED_LENGTH = sslMaxEncryptedLength(); + + /** + * The TLS 1.2 RFC defines the maximum length to be + * {@link #SSL_MAX_PLAINTEXT_LENGTH}, but there are some implementations such as + * OpenJDK's SSLEngineImpl + * that also allow sending larger packets. This can be used as a upper bound for data to support legacy systems. + */ + public static final int SSL_MAX_RECORD_LENGTH = sslMaxRecordLength(); + + // https://www.openssl.org/docs/man1.0.2/crypto/X509_check_host.html + public static final int X509_CHECK_FLAG_ALWAYS_CHECK_SUBJECT = x509CheckFlagAlwaysCheckSubject(); + public static final int X509_CHECK_FLAG_NO_WILD_CARDS = x509CheckFlagDisableWildCards(); + public static final int X509_CHECK_FLAG_NO_PARTIAL_WILD_CARDS = x509CheckFlagNoPartialWildCards(); + public static final int X509_CHECK_FLAG_MULTI_LABEL_WILDCARDS = x509CheckFlagMultiLabelWildCards(); + + public static final int SSL_RENEGOTIATE_NEVER = sslRenegotiateNever(); + public static final int SSL_RENEGOTIATE_ONCE = sslRenegotiateOnce(); + public static final int SSL_RENEGOTIATE_FREELY = sslRenegotiateFreely(); + public static final int SSL_RENEGOTIATE_IGNORE = sslRenegotiateIgnore(); + public static final int SSL_RENEGOTIATE_EXPLICIT = sslRenegotiateExplicit(); + + public static final int SSL_CERT_COMPRESSION_DIRECTION_COMPRESS = sslCertCompressionDirectionCompress(); + public static final int SSL_CERT_COMPRESSION_DIRECTION_DECOMPRESS = sslCertCompressionDirectionDecompress(); + public static final int SSL_CERT_COMPRESSION_DIRECTION_BOTH = sslCertCompressionDirectionBoth(); + + /* Return OpenSSL version number */ + public static native int version(); + + /* Return OpenSSL version string */ + public static native String versionString(); + + /** + * Initialize OpenSSL support. + * + * This function needs to be called once for the + * lifetime of JVM. 
See {@link Library#initialize(String, String)} + * + * @param engine Support for external a Crypto Device ("engine"), + * usually a hardware accelerator card for crypto operations. + * @return APR status code + */ + static native int initialize(String engine); + + /** + * Initialize new in-memory BIO that is located in the secure heap. + * + * @return New BIO handle + * @throws Exception if an error happened. + */ + public static native long newMemBIO() throws Exception; + + /** + * Return last SSL error string + * + * @return the last SSL error string. + */ + public static native String getLastError(); + + /* + * Begin Twitter API additions + */ + + public static final int SSL_SENT_SHUTDOWN = sslSendShutdown(); + public static final int SSL_RECEIVED_SHUTDOWN = sslReceivedShutdown(); + + public static final int SSL_ERROR_NONE = sslErrorNone(); + public static final int SSL_ERROR_SSL = sslErrorSSL(); + public static final int SSL_ERROR_WANT_READ = sslErrorWantRead(); + public static final int SSL_ERROR_WANT_WRITE = sslErrorWantWrite(); + public static final int SSL_ERROR_WANT_X509_LOOKUP = sslErrorWantX509Lookup(); + public static final int SSL_ERROR_SYSCALL = sslErrorSyscall(); /* look at error stack/return value/errno */ + public static final int SSL_ERROR_ZERO_RETURN = sslErrorZeroReturn(); + public static final int SSL_ERROR_WANT_CONNECT = sslErrorWantConnect(); + public static final int SSL_ERROR_WANT_ACCEPT = sslErrorWantAccept(); + // https://boringssl.googlesource.com/boringssl/+/chromium-stable/include/openssl/ssl.h#519 + public static final int SSL_ERROR_WANT_PRIVATE_KEY_OPERATION = sslErrorWantPrivateKeyOperation(); + + // BoringSSL specific + public static final int SSL_ERROR_WANT_CERTIFICATE_VERIFY = sslErrorWantCertificateVerify(); + + /** + * SSL_new + * @param ctx Server or Client context to use. 
+ * @param server if true configure SSL instance to use accept handshake routines + * if false configure SSL instance to use connect handshake routines + * @return pointer to SSL instance (SSL *) + */ + public static native long newSSL(long ctx, boolean server); + + /** + * SSL_get_error + * @param ssl SSL pointer (SSL *) + * @param ret TLS/SSL I/O return value + * @return the error code + */ + public static native int getError(long ssl, int ret); + + /** + * BIO_write + * @param bioAddress The address of a {@code BIO*}. + * @param wbufAddress The address of a native {@code char*}. + * @param wlen The length to write starting at {@code wbufAddress}. + * @return The number of bytes that were written. + * See BIO_write for exceptional return values. + */ + public static native int bioWrite(long bioAddress, long wbufAddress, int wlen); + + /** + * Initialize the BIO for the SSL instance. This is a custom BIO which is designed to play nicely with a direct + * {@link ByteBuffer}. Because it is a special BIO it requires special usage such that + * {@link #bioSetByteBuffer(long, long, int, boolean)} and {@link #bioClearByteBuffer(long)} are called in order to provide + * to supply data to SSL, and also to ensure the internal SSL buffering mechanism is expecting write at the appropriate times. + * + * @param ssl the SSL instance (SSL *) + * @param nonApplicationBufferSize The size of the internal buffer for write operations that are not + * initiated directly by the application attempting to encrypt data. + * Must be >{@code 0}. + * @return pointer to the Network BIO (BIO *). + * The memory is owned by {@code ssl} and will be cleaned up by {@link #freeSSL(long)}. 
+ */ + public static native long bioNewByteBuffer(long ssl, int nonApplicationBufferSize); + + /** + * Sets the socket file descriptor of the rbio field inside the SSL struct (ssl->rbio->num) + * + * @param ssl the SSL instance (SSL *) + * @param fd the file descriptor of the socket used for the given SSL connection + */ + public static native void bioSetFd(long ssl, int fd); + + /** + * Set the memory location which that OpenSSL's internal BIO will use to write encrypted data to, or read encrypted + * data from. + *

+ * After you are done buffering data you should call {@link #bioClearByteBuffer(long)}. + * @param bio {@code BIO*}. + * @param bufferAddress The memory address (typically from a direct {@link ByteBuffer}) which will be used + * to either write encrypted data to, or read encrypted data from by OpenSSL's internal BIO pair. + * @param maxUsableBytes The maximum usable length in bytes starting at {@code bufferAddress}. + * @param isSSLWriteSink {@code true} if this buffer is expected to buffer data as a result of calls to {@code SSL_write}. + * {@code false} if this buffer is expected to buffer data as a result of calls to {@code SSL_read}. + */ + public static native void bioSetByteBuffer(long bio, long bufferAddress, int maxUsableBytes, boolean isSSLWriteSink); + + /** + * After you are done buffering data from {@link #bioSetByteBuffer(long, long, int, boolean)}, this will ensure the + * internal SSL write buffers are ready to capture data which may unexpectedly happen (e.g. handshake, renegotiation, etc..). + * @param bio {@code BIO*}. + */ + public static native void bioClearByteBuffer(long bio); + + /** + * Flush any pending bytes in the internal SSL write buffer. + *

+ * This does the same thing as {@code BIO_flush} for a {@code BIO*} of type {@link #bioNewByteBuffer(long, int)} but + * returns the number of bytes that were flushed. + * @param bio {@code BIO*}. + * @return The number of bytes that were flushed. + */ + public static native int bioFlushByteBuffer(long bio); + + /** + * Get the remaining length of the {@link ByteBuffer} set by {@link #bioSetByteBuffer(long, long, int, boolean)}. + * @param bio {@code BIO*}. + * @return The remaining length of the {@link ByteBuffer} set by {@link #bioSetByteBuffer(long, long, int, boolean)}. + */ + public static native int bioLengthByteBuffer(long bio); + + /** + * Get the amount of data pending in buffer used for non-application writes. + * This value will not exceed the value configured in {@link #bioNewByteBuffer(long, int)}. + * @param bio {@code BIO*}. + * @return the amount of data pending in buffer used for non-application writes. + */ + public static native int bioLengthNonApplication(long bio); + + /** + * The number of bytes pending in SSL which can be read immediately. + * See SSL_pending. + * @param ssl the SSL instance (SSL *) + * @return The number of bytes pending in SSL which can be read immediately. 
+ */ + public static native int sslPending(long ssl); + + /** + * SSL_write + * @param ssl the SSL instance (SSL *) + * @param wbuf the memory address of the buffer + * @param wlen the length + * @return the number of written bytes + */ + public static native int writeToSSL(long ssl, long wbuf, int wlen); + + /** + * SSL_read + * @param ssl the SSL instance (SSL *) + * @param rbuf the memory address of the buffer + * @param rlen the length + * @return the number of read bytes + */ + public static native int readFromSSL(long ssl, long rbuf, int rlen); + + /** + * SSL_get_shutdown + * @param ssl the SSL instance (SSL *) + * @return the return code of {@code SSL_get_shutdown} + */ + public static native int getShutdown(long ssl); + + /** + * SSL_set_shutdown + * @param ssl the SSL instance (SSL *) + * @param mode the mode to use + */ + public static native void setShutdown(long ssl, int mode); + + /** + * SSL_free + * @param ssl the SSL instance (SSL *) + */ + public static native void freeSSL(long ssl); + + /** + * BIO_free + * @param bio the BIO + */ + public static native void freeBIO(long bio); + + /** + * SSL_shutdown + * @param ssl the SSL instance (SSL *) + * @return the return code of {@code SSL_shutdown} + */ + public static native int shutdownSSL(long ssl); + + /** + * Get the error number representing the last error OpenSSL encountered on this thread. + * @return the last error code for the calling thread. + */ + public static native int getLastErrorNumber(); + + /** + * SSL_get_cipher + * @param ssl the SSL instance (SSL *) + * @return the name of the current cipher. + */ + public static native String getCipherForSSL(long ssl); + + /** + * SSL_get_version + * @param ssl the SSL instance (SSL *) + * @return the version. + */ + public static native String getVersion(long ssl); + + /** + * SSL_do_handshake + * @param ssl the SSL instance (SSL *) + * @return the return code of {@code SSL_do_handshake}. 
+ */ + public static native int doHandshake(long ssl); + + /** + * SSL_in_init + * @param ssl the SSL instance (SSL *) + * @return the return code of {@code SSL_in_init}. + */ + public static native int isInInit(long ssl); + + /** + * SSL_get0_next_proto_negotiated + * @param ssl the SSL instance (SSL *) + * @return the name of the negotiated proto + */ + public static native String getNextProtoNegotiated(long ssl); + + /* + * End Twitter API Additions + */ + + /** + * SSL_get0_alpn_selected + * @param ssl the SSL instance (SSL *) + * @return the name of the selected ALPN protocol + */ + public static native String getAlpnSelected(long ssl); + + /** + * Get the peer certificate chain or {@code null} if none was send. + * @param ssl the SSL instance (SSL *) + * @return the chain or {@code null} if none was send + */ + public static native byte[][] getPeerCertChain(long ssl); + + /** + * Get the peer certificate or {@code null} if non was send. + * @param ssl the SSL instance (SSL *) + * @return the peer certificate or {@code null} if none was send + */ + public static native byte[] getPeerCertificate(long ssl); + + /** + * Get the error string representing for the given {@code errorNumber}. + * + * @param errorNumber the error number / code + * @return the error string + */ + public static native String getErrorString(long errorNumber); + + /** + * SSL_get_time + * @param ssl the SSL instance (SSL *) + * @return returns the time at which the session ssl was established. The time is given in seconds since the Epoch + */ + public static native long getTime(long ssl); + + /** + * SSL_get_timeout + * @param ssl the SSL instance (SSL *) + * @return returns the timeout for the session ssl The time is given in seconds since the Epoch + */ + public static native long getTimeout(long ssl); + + /** + * SSL_set_timeout + * @param ssl the SSL instance (SSL *) + * @param seconds timeout in seconds + * @return returns the timeout for the session ssl before this call. 
The time is given in seconds since the Epoch + */ + public static native long setTimeout(long ssl, long seconds); + + /** + * Set Type of Client Certificate verification and Maximum depth of CA Certificates + * in Client Certificate verification. + *

+ * This directive sets the Certificate verification level for the Client + * Authentication. Notice that this directive can be used both in per-server + * and per-directory context. In per-server context it applies to the client + * authentication process used in the standard SSL handshake when a connection + * is established. In per-directory context it forces a SSL renegotiation with + * the reconfigured client verification level after the HTTP request was read + * but before the HTTP response is sent. + *

+ * The following levels are available for level: + *

    + *
  • {@link #SSL_CVERIFY_IGNORED} - The level is ignored. Only depth will change.
  • + *
  • {@link #SSL_CVERIFY_NONE} - No client Certificate is required at all
  • + *
  • {@link #SSL_CVERIFY_OPTIONAL} - The client may present a valid Certificate
  • + *
  • {@link #SSL_CVERIFY_REQUIRED} - The client has to present a valid Certificate
  • + *
+ * The depth actually is the maximum number of intermediate certificate issuers, + * i.e. the number of CA certificates which are max allowed to be followed while + * verifying the client certificate. A depth of 0 means that self-signed client + * certificates are accepted only, the default depth of 1 means the client + * certificate can be self-signed or has to be signed by a CA which is directly + * known to the server (i.e. the CA's certificate is under + * {@code setCACertificatePath}, etc. + * + * @param ssl the SSL instance (SSL *) + * @param level Type of Client Certificate verification. + * @param depth Maximum depth of CA Certificates in Client Certificate + * verification. Ignored if value is {@code <0}. + */ + public static native void setVerify(long ssl, int level, int depth); + + /** + * Set OpenSSL Option. + * @param ssl the SSL instance (SSL *) + * @param options See SSL.SSL_OP_* for option flags. + */ + public static native void setOptions(long ssl, int options); + + /** + * Clear OpenSSL Option. + * @param ssl the SSL instance (SSL *) + * @param options See SSL.SSL_OP_* for option flags. + */ + public static native void clearOptions(long ssl, int options); + + /** + * Get OpenSSL Option. + * @param ssl the SSL instance (SSL *) + * @return options See SSL.SSL_OP_* for option flags. + */ + public static native int getOptions(long ssl); + + /** + * Call SSL_set_mode + * + * @param ssl the SSL instance (SSL *). + * @param mode the mode + * @return the set mode. + */ + public static native int setMode(long ssl, int mode); + + /** + * Call SSL_get_mode + * + * @param ssl the SSL instance (SSL *). + * @return the mode. + */ + public static native int getMode(long ssl); + + /** + * Get the maximum overhead, in bytes, of wrapping (a.k.a sealing) a record with ssl. + * See SSL_max_seal_overhead. + * @param ssl the SSL instance (SSL *). + * @return Maximum overhead, in bytes, of wrapping (a.k.a sealing) a record with ssl. 
+ */ + public static native int getMaxWrapOverhead(long ssl); + + /** + * Returns all the cipher suites that are available for negotiation in an SSL handshake. + * @param ssl the SSL instance (SSL *) + * @return ciphers + */ + public static native String[] getCiphers(long ssl); + + /** + * Returns the cipher suites available for negotiation in SSL handshake. + *

+ * This complex directive uses a colon-separated cipher-spec string consisting + * of OpenSSL cipher specifications to configure the Cipher Suite the client + * is permitted to negotiate in the SSL handshake phase. Notice that this + * directive can be used both in per-server and per-directory context. + * In per-server context it applies to the standard SSL handshake when a + * connection is established. In per-directory context it forces a SSL + * renegotiation with the reconfigured Cipher Suite after the HTTP request + * was read but before the HTTP response is sent. + * @param ssl the SSL instance (SSL *) + * @param ciphers an SSL cipher specification + * @return {@code true} if successful + * @throws Exception if an error happened + * @deprecated Use {@link #setCipherSuites(long, String, boolean)} + */ + @Deprecated + public static boolean setCipherSuites(long ssl, String ciphers) + throws Exception { + return setCipherSuites(ssl, ciphers, false); + } + + /** + * Returns the cipher suites available for negotiation in SSL handshake. + *

+ * This complex directive uses a colon-separated cipher-spec string consisting + * of OpenSSL cipher specifications to configure the Cipher Suite the client + * is permitted to negotiate in the SSL handshake phase. Notice that this + * directive can be used both in per-server and per-directory context. + * In per-server context it applies to the standard SSL handshake when a + * connection is established. In per-directory context it forces a SSL + * renegotiation with the reconfigured Cipher Suite after the HTTP request + * was read but before the HTTP response is sent. + * @param ssl the SSL instance (SSL *) + * @param ciphers an SSL cipher specification + * @param tlsv13 {@code true} if the ciphers are for TLSv1.3 + * @return {@code true} if successful + * @throws Exception if an error happened + */ + public static native boolean setCipherSuites(long ssl, String ciphers, boolean tlsv13) + throws Exception; + /** + * Returns the ID of the session as byte array representation. + * + * @param ssl the SSL instance (SSL *) + * @return the session as byte array representation obtained via SSL_SESSION_get_id. + */ + public static native byte[] getSessionId(long ssl); + + /** + * Returns the number of handshakes done for this SSL instance. This also includes renegotiations. + * + * @param ssl the SSL instance (SSL *) + * @return the number of handshakes done for this SSL instance. + */ + public static native int getHandshakeCount(long ssl); + + /** + * Clear all the errors from the error queue that OpenSSL encountered on this thread. + */ + public static native void clearError(); + + /** + * Call SSL_set_tlsext_host_name + * + * @param ssl the SSL instance (SSL *) + * @param hostname the hostname + */ + public static void setTlsExtHostName(long ssl, String hostname) { + if (hostname != null && hostname.endsWith(".")) { + // Strip trailing dot if included.
+ // See https://github.com/netty/netty-tcnative/issues/400 + hostname = hostname.substring(0, hostname.length() - 1); + } + setTlsExtHostName0(ssl, hostname); + } + + private static native void setTlsExtHostName0(long ssl, String hostname); + + /** + * Explicitly control hostname validation + * see X509_check_host for X509_CHECK_FLAG* definitions. + * Values are defined as a bitmask of {@code X509_CHECK_FLAG*} values. + * @param ssl the SSL instance (SSL*). + * @param flags a bitmask of {@code X509_CHECK_FLAG*} values. + * @param hostname the hostname which is expected for validation. + */ + public static native void setHostNameValidation(long ssl, int flags, String hostname); + + /** + * Return the methods used for authentication. + * + * @param ssl the SSL instance (SSL*) + * @return the methods + */ + public static native String[] authenticationMethods(long ssl); + + /** + * Set BIO of PEM-encoded Server CA Certificates + *

+ * This directive sets the optional all-in-one file where you can assemble the + * certificates of Certification Authorities (CA) which form the certificate + * chain of the server certificate. This starts with the issuing CA certificate + * of the server certificate and can range up to the root CA certificate. + * Such a file is simply the concatenation of the various PEM-encoded CA + * Certificate files, usually in certificate chain order. + *

+ * But be careful: Providing the certificate chain works only if you are using + * a single (either RSA or DSA) based server certificate. If you are using a + * coupled RSA+DSA certificate pair, this will work only if actually both + * certificates use the same certificate chain. Otherwise the browsers will be + * confused in this situation. + * @param ssl Server or Client to use. + * @param bio BIO of PEM-encoded Server CA Certificates. + * @param skipfirst Skip first certificate if chain file is inside + * certificate file. + * + * @deprecated use {@link #setKeyMaterial(long, long, long)} + */ + @Deprecated + public static native void setCertificateChainBio(long ssl, long bio, boolean skipfirst); + + /** + * Set Certificate + *
+ * Point setCertificate at a PEM encoded certificate stored in a BIO. If + * the certificate is encrypted, then you will be prompted for a + * pass phrase. Note that a kill -HUP will prompt again. A test + * certificate can be generated with `make certificate' under + * built time. Keep in mind that if you've both a RSA and a DSA + * certificate you can configure both in parallel (to also allow + * the use of DSA ciphers, etc.) + *
+ * If the key is not combined with the certificate, use key param + * to point at the key file. Keep in mind that if + * you've both a RSA and a DSA private key you can configure + * both in parallel (to also allow the use of DSA ciphers, etc.) + * @param ssl Server or Client to use. + * @param certBio Certificate BIO. + * @param keyBio Private Key BIO to use if not in cert. + * @param password Certificate password. If null and certificate + * is encrypted. + * @throws Exception if an error happened + * + * @deprecated use {@link #setKeyMaterial(long, long, long)} + */ + @Deprecated + public static native void setCertificateBio( + long ssl, long certBio, long keyBio, String password) throws Exception; + + /** + * Load a private key from the used OpenSSL ENGINE via the + * ENGINE_load_private_key + * function. + * + *

Be sure you understand how OpenSsl will behave with respect to reference counting! + * + * If the ownership is not transferred you need to call {@link #freePrivateKey(long)} once the key is not used + * anymore to prevent memory leaks. + * + * @param keyId the id of the key. + * @param password the password to use or {@code null} if none. + * @return {@code EVP_PKEY} pointer + * @throws Exception if an error happened + */ + public static native long loadPrivateKeyFromEngine(String keyId, String password) throws Exception; + + /** + * Parse private key from BIO and return {@code EVP_PKEY} pointer. + * + *

Be sure you understand how OpenSsl will behave with respect to reference counting! + * + * If the {@code EVP_PKEY} pointer is used with the client certificate callback + * {@link CertificateRequestedCallback} the ownership goes over to OpenSsl / Tcnative and so calling + * {@link #freePrivateKey(long)} should NOT be done in this case. Otherwise you may + * need to call {@link #freePrivateKey(long)} to decrement the reference count and free memory. + * + * @param privateKeyBio the pointer to the {@code BIO} that contains the private key + * @param password the password or {@code null} if no password is needed + * @return {@code EVP_PKEY} pointer + * @throws Exception if an error happened + */ + public static native long parsePrivateKey(long privateKeyBio, String password) throws Exception; + + /** + * Free private key ({@code EVP_PKEY} pointer). + * + * @param privateKey {@code EVP_PKEY} pointer + */ + public static native void freePrivateKey(long privateKey); + + /** + * Parse X509 chain from BIO and return ({@code STACK_OF(X509)} pointer). + * + *

Be sure you understand how OpenSsl will behave with respect to reference counting! + * + * If the {@code STACK_OF(X509)} pointer is used with the client certificate callback + * {@link CertificateRequestedCallback} the ownership goes over to OpenSsl / Tcnative and so calling + * {@link #freeX509Chain(long)} should NOT be done in this case. Otherwise you may + * need to call {@link #freeX509Chain(long)} to decrement the reference count and free memory. + * + * @param x509ChainBio the pointer to the {@code BIO} that contains the X509 chain + * @return {@code STACK_OF(X509)} pointer + * @throws Exception if an error happened + */ + public static native long parseX509Chain(long x509ChainBio) throws Exception; + + /** + * Free x509 chain ({@code STACK_OF(X509)} pointer). + * + * @param x509Chain {@code STACK_OF(X509)} pointer + */ + public static native void freeX509Chain(long x509Chain); + + /** + * Enables OCSP stapling for the given {@link SSLEngine} or throws an + * exception if OCSP stapling is not supported. + * + *

NOTE: This needs to happen before the SSL handshake. + * + *

SSL_set_tlsext_status_type + *

Search for OCSP + */ + public static native void enableOcsp(long ssl); + + /** + * Sets the keymaterial to be used for the server side. The passed in chain and key needs to be generated via + * {@link #parseX509Chain(long)} and {@link #parsePrivateKey(long, String)}. It's important to note that the caller + * of the method is responsible to free the passed in chain and key in any case as this method will increment the + * reference count of the chain and key. + * + * @deprecated use {@link #setKeyMaterial(long, long, long)} + */ + @Deprecated + public static void setKeyMaterialServerSide(long ssl, long chain, long key) throws Exception { + setKeyMaterial(ssl, chain, key); + } + + /** + * Sets the keymaterial to be used. The passed in chain and key needs to be generated via + * {@link #parseX509Chain(long)} and {@link #parsePrivateKey(long, String)}. It's important to note that the caller + * of the method is responsible to free the passed in chain and key in any case as this method will increment the + * reference count of the chain and key. + */ + public static native void setKeyMaterial(long ssl, long chain, long key) throws Exception; + + /** + * Sets the keymaterial to be used for the client side. The passed in chain and key needs to be generated via + * {@link #parseX509Chain(long)} and {@link #parsePrivateKey(long, String)}. It's important to note that the caller + * of the method is responsible to free the passed in chain and key in any case as this method will increment the + * reference count of the chain and key. + * + * @deprecated use {@link #setKeyMaterial(long, long, long)} + */ + @Deprecated + public static native void setKeyMaterialClientSide(long ssl, long x509Out, long pkeyOut, long chain, long key) throws Exception; + + /** + * Sets the OCSP response for the given {@link SSLEngine} or throws an + * exception in case of an error. + * + *

NOTE: This is only meant to be called for server {@link SSLEngine}s. + * + *

SSL_set_tlsext_status_type + *

Search for OCSP + * + * @param ssl the SSL instance (SSL *) + */ + public static native void setOcspResponse(long ssl, byte[] response); + + /** + * Returns the OCSP response for the given {@link SSLEngine} or {@code null} + * if the server didn't provide a stapled OCSP response. + * + *

NOTE: This is only meant to be called for client {@link SSLEngine}s. + * + *

SSL_set_tlsext_status_type + *

Search for OCSP + * + * @param ssl the SSL instance (SSL *) + */ + public static native byte[] getOcspResponse(long ssl); + + /** + * Set the FIPS mode to use. See man FIPS_mode_set. + * + * @param mode the mode to use. + * @throws Exception throws if setting the fips mode failed. + */ + public static native void fipsModeSet(int mode) throws Exception; + + /** + * Return the SNI hostname that was sent as part of the SSL Hello. + * @param ssl the SSL instance (SSL *) + * @return the SNI hostname or {@code null} if none was used. + */ + public static native String getSniHostname(long ssl); + + /** + * Return the signature algorithms that the remote peer supports or {@code null} if none are supported. + * See man SSL_get_sigalgs for more details. + * The returned names are generated using {@code OBJ_nid2ln} with the {@code psignhash} as parameter. + * + * @param ssl the SSL instance (SSL *) + * @return the signature algorithms or {@code null}. + */ + public static native String[] getSigAlgs(long ssl); + + /** + * Returns the master key used for the current ssl session. + * This should be used extremely sparingly as leaking this key defeats the whole purpose of encryption + * especially forward secrecy. This exists here strictly for debugging purposes. + * + * @param ssl the SSL instance (SSL *) + * @return the master key used for the ssl session + */ + public static native byte[] getMasterKey(long ssl); + + /** + * Extracts the random value sent from the server to the client during the initial SSL/TLS handshake. + * This is needed to extract the HMAC & keys from the master key according to the TLS PRF. + * This is not a random number generator. + * + * @param ssl the SSL instance (SSL *) + * @return the random server value used for the ssl session + */ + public static native byte[] getServerRandom(long ssl); + + /** + * Extracts the random value sent from the client to the server during the initial SSL/TLS handshake. 
+ * This is needed to extract the HMAC & keys from the master key according to the TLS PRF. + * This is not a random number generator. + * + * @param ssl the SSL instance (SSL *) + * @return the random client value used for the ssl session + */ + public static native byte[] getClientRandom(long ssl); + + /** + * Return the {@link Runnable} that needs to be run as an operation did signal that a task needs to be completed + * before we can retry the previous action. + * + * After the task was run we should retry the operation that did signal back that a task needed to be run. + * + * + * The {@link Runnable} may also implement {@link AsyncTask} which allows for fully asynchronous execution if + * {@link AsyncTask#runAsync(Runnable)} is used. + * + * @param ssl the SSL instance (SSL *) + * @return the task to run. + */ + public static native Runnable getTask(long ssl); + + /** + * Return the {@link AsyncTask} that needs to be run as an operation did signal that a task needs to be completed + * before we can retry it. + * + * After the task was run we should retry the operation that did signal back that a task needed to be run. + * + * @param ssl the SSL instance (SSL *) + * @return the task to run. + */ + public static AsyncTask getAsyncTask(long ssl) { + return (AsyncTask) getTask(ssl); + } + + /** + * Return {@code true} if the SSL_SESSION was reused. + * See SSL_session_reused. + * + * @param ssl the SSL instance (SSL *) + * @return {@code true} if the SSL_SESSION was reused, {@code false} otherwise. + */ + public static native boolean isSessionReused(long ssl); + + /** + * Sets the {@code SSL_SESSION} that should be used for {@code SSL}. + * @param ssl the SSL instance (SSL *) + * @param session the SSL_SESSION instance (SSL_SESSION *) + * @return {@code true} if successful, {@code false} otherwise. + */ + public static native boolean setSession(long ssl, long session); + + /** + * Returns the {@code SSL_SESSION} that is used for {@code SSL}. 
+ * See SSL_get_session. + * + * @param ssl the SSL instance (SSL *) + * @return the SSL_SESSION instance (SSL_SESSION *) used + */ + public static native long getSession(long ssl); + + /** + * Allow to set the renegotiation mode that is used. This is only support by {@code BoringSSL}. + * + * See + * SSL_set_renegotiate_mode.. + * @param ssl the SSL instance (SSL *) + * @param mode the mode. + * @throws Exception thrown if some error happens. + */ + public static native void setRenegotiateMode(long ssl, int mode) throws Exception; +} diff --git a/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLContext.java b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLContext.java new file mode 100644 index 0000000..e83adbc --- /dev/null +++ b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLContext.java @@ -0,0 +1,763 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package io.netty.internal.tcnative; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.nio.charset.StandardCharsets; + +public final class SSLContext { + private static final int MAX_ALPN_NPN_PROTO_SIZE = 65535; + + private SSLContext() { } + + /** + * Initialize new SSL context + * @param protocol The SSL protocol to use. It can be any combination of + * the following: + *

+     * {@link SSL#SSL_PROTOCOL_SSLV2}
+     * {@link SSL#SSL_PROTOCOL_SSLV3}
+     * {@link SSL#SSL_PROTOCOL_TLSV1}
+     * {@link SSL#SSL_PROTOCOL_TLSV1_1}
+     * {@link SSL#SSL_PROTOCOL_TLSV1_2}
+     * {@link SSL#SSL_PROTOCOL_ALL} ( == all TLS versions, no SSL)
+     * 
+ * @param mode SSL mode to use + *
+     * SSL_MODE_CLIENT
+     * SSL_MODE_SERVER
+     * SSL_MODE_COMBINED
+     * 
+ * @return the SSLContext struct + * @throws Exception if an error happened + */ + public static native long make(int protocol, int mode) + throws Exception; + + /** + * Free the resources used by the Context + * @param ctx Server or Client context to free. + * @return APR Status code. + */ + public static native int free(long ctx); + + /** + * Set Session context id. Usually host:port combination. + * @param ctx Context to use. + * @param id String that uniquely identifies this context. + */ + public static native void setContextId(long ctx, String id); + + /** + * Set OpenSSL Option. + * @param ctx Server or Client context to use. + * @param options See SSL.SSL_OP_* for option flags. + */ + public static native void setOptions(long ctx, int options); + + /** + * Get OpenSSL Option. + * @param ctx Server or Client context to use. + * @return options See SSL.SSL_OP_* for option flags. + */ + public static native int getOptions(long ctx); + + /** + * Clears OpenSSL Options. + * @param ctx Server or Client context to use. + * @param options See SSL.SSL_OP_* for option flags. + */ + public static native void clearOptions(long ctx, int options); + + /** + * Cipher Suite available for negotiation in SSL handshake. + *
+ * This complex directive uses a colon-separated cipher-spec string consisting + * of OpenSSL cipher specifications to configure the Cipher Suite the client + * is permitted to negotiate in the SSL handshake phase. Notice that this + * directive can be used both in per-server and per-directory context. + * In per-server context it applies to the standard SSL handshake when a + * connection is established. In per-directory context it forces a SSL + * renegotiation with the reconfigured Cipher Suite after the HTTP request + * was read but before the HTTP response is sent. + * @param ctx Server or Client context to use. + * @param ciphers An SSL cipher specification. + * @return {@code true} if successful + * @throws Exception if an error happened + * @deprecated Use {@link #setCipherSuite(long, String, boolean)}. + */ + @Deprecated + public static boolean setCipherSuite(long ctx, String ciphers) throws Exception { + return setCipherSuite(ctx, ciphers, false); + } + + /** + * Cipher Suite available for negotiation in SSL handshake. + *
+ * This complex directive uses a colon-separated cipher-spec string consisting + * of OpenSSL cipher specifications to configure the Cipher Suite the client + * is permitted to negotiate in the SSL handshake phase. Notice that this + * directive can be used both in per-server and per-directory context. + * In per-server context it applies to the standard SSL handshake when a + * connection is established. In per-directory context it forces a SSL + * renegotiation with the reconfigured Cipher Suite after the HTTP request + * was read but before the HTTP response is sent. + * @param ctx Server or Client context to use. + * @param ciphers An SSL cipher specification. + * @param tlsv13 {@code true} if the ciphers are for TLSv1.3 + * @return {@code true} if successful + * @throws Exception if an error happened + */ + public static native boolean setCipherSuite(long ctx, String ciphers, boolean tlsv13) throws Exception; + + /** + * Set File of PEM-encoded Server CA Certificates + *
+ * This directive sets the optional all-in-one file where you can assemble the + * certificates of Certification Authorities (CA) which form the certificate + * chain of the server certificate. This starts with the issuing CA certificate + * of the server certificate and can range up to the root CA certificate. + * Such a file is simply the concatenation of the various PEM-encoded CA + * Certificate files, usually in certificate chain order. + *
+ * But be careful: Providing the certificate chain works only if you are using + * a single (either RSA or DSA) based server certificate. If you are using a + * coupled RSA+DSA certificate pair, this will work only if actually both + * certificates use the same certificate chain. Else the browsers will be + * confused in this situation. + * @param ctx Server or Client context to use. + * @param file File of PEM-encoded Server CA Certificates. + * @param skipfirst Skip first certificate if chain file is inside + * certificate file. + * @return {@code true} if successful + */ + public static native boolean setCertificateChainFile(long ctx, String file, boolean skipfirst); + /** + * Set BIO of PEM-encoded Server CA Certificates + *

+ * This directive sets the optional all-in-one file where you can assemble the + * certificates of Certification Authorities (CA) which form the certificate + * chain of the server certificate. This starts with the issuing CA certificate + * of the server certificate and can range up to the root CA certificate. + * Such a file is simply the concatenation of the various PEM-encoded CA + * Certificate files, usually in certificate chain order. + *

+ * But be careful: Providing the certificate chain works only if you are using + * a single (either RSA or DSA) based server certificate. If you are using a + * coupled RSA+DSA certificate pair, this will work only if actually both + * certificates use the same certificate chain. Otherwise the browsers will be + * confused in this situation. + * @param ctx Server or Client context to use. + * @param bio BIO of PEM-encoded Server CA Certificates. + * @param skipfirst Skip first certificate if chain file is inside + * certificate file. + * @return {@code true} if successful + */ + public static native boolean setCertificateChainBio(long ctx, long bio, boolean skipfirst); + + /** + * Set Certificate + *

+ * Point setCertificateFile at a PEM encoded certificate. If + * the certificate is encrypted, then you will be prompted for a + * pass phrase. Note that a kill -HUP will prompt again. A test + * certificate can be generated with `make certificate' under + * built time. Keep in mind that if you've both a RSA and a DSA + * certificate you can configure both in parallel (to also allow + * the use of DSA ciphers, etc.) + *

+ * If the key is not combined with the certificate, use key param + * to point at the key file. Keep in mind that if + * you've both a RSA and a DSA private key you can configure + * both in parallel (to also allow the use of DSA ciphers, etc.) + * @param ctx Server or Client context to use. + * @param cert Certificate file. + * @param key Private Key file to use if not in cert. + * @param password Certificate password. If null and certificate + * is encrypted, password prompt will be displayed. + * @return {@code true} if successful + * @throws Exception if an error happened + */ + public static native boolean setCertificate(long ctx, String cert, String key, String password) throws Exception; + + /** + * Set Certificate + *

+ * Point setCertificate at a PEM encoded certificate stored in a BIO. If + * the certificate is encrypted, then you will be prompted for a + * pass phrase. Note that a kill -HUP will prompt again. A test + * certificate can be generated with `make certificate' under + * built time. Keep in mind that if you've both a RSA and a DSA + * certificate you can configure both in parallel (to also allow + * the use of DSA ciphers, etc.) + *

+ * If the key is not combined with the certificate, use key param + * to point at the key file. Keep in mind that if + * you've both a RSA and a DSA private key you can configure + * both in parallel (to also allow the use of DSA ciphers, etc.) + * @param ctx Server or Client context to use. + * @param certBio Certificate BIO. + * @param keyBio Private Key BIO to use if not in cert. + * @param password Certificate password. If null and certificate + * is encrypted, password prompt will be displayed. + * @return {@code true} if successful + * @throws Exception if an error happened + */ + public static native boolean setCertificateBio(long ctx, long certBio, long keyBio, String password) throws Exception; + + /** + * Set the size of the internal session cache. + * See man SSL_CTX_sess_set_cache_size + * @param ctx Server or Client context to use. + * @param size the size of the cache + * @return the previous set value + */ + public static native long setSessionCacheSize(long ctx, long size); + + /** + * Get the size of the internal session cache. + * See man SSL_CTX_sess_get_cache_size + * @param ctx Server or Client context to use. + * @return the current value + */ + public static native long getSessionCacheSize(long ctx); + + /** + * Set the timeout for the internal session cache in seconds. + * See man SSL_CTX_set_timeout + * @param ctx Server or Client context to use. + * @param timeoutSeconds the timeout of the cache + * @return the previous set value + */ + public static native long setSessionCacheTimeout(long ctx, long timeoutSeconds); + + /** + * Get the timeout for the internal session cache in seconds. + * See man SSL_CTX_get_timeout + * @param ctx Server or Client context to use + * @return the current value + */ + public static native long getSessionCacheTimeout(long ctx); + + /** + * Set the mode of the internal session cache and return the previous used mode. 
     * See {@code man SSL_CTX_set_session_cache_mode}.
     * @param ctx Server or Client context to use
     * @param mode the mode of the cache
     * @return the previously set value
     */
    public static native long setSessionCacheMode(long ctx, long mode);

    /**
     * Get the mode of the currently used internal session cache.
     *
     * @param ctx Server or Client context to use
     * @return the current mode
     */
    public static native long getSessionCacheMode(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}, which documents all {@code SSL_CTX_sess_*} counters.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionAccept(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionAcceptGood(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionAcceptRenegotiate(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionCacheFull(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionCbHits(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionConnect(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionConnectGood(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionConnectRenegotiate(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionHits(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionMisses(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionNumber(long ctx);

    /**
     * Session resumption statistics method.
     * See {@code man SSL_CTX_sess_number}.
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionTimeouts(long ctx);

    /**
     * TLS session ticket key resumption statistics.
     *
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionTicketKeyNew(long ctx);

    /**
     * TLS session ticket key resumption statistics.
     *
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionTicketKeyResume(long ctx);

    /**
     * TLS session ticket key resumption statistics.
     *
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionTicketKeyRenew(long ctx);

    /**
     * TLS session ticket key resumption statistics.
     *
     * @param ctx Server or Client context to use
     * @return the current number
     */
    public static native long sessionTicketKeyFail(long ctx);

    /**
     * Set TLS session ticket keys.
     *
     * <p>
The first key in the list is the primary key. Tickets dervied from the other keys + * in the list will be accepted but updated to a new ticket using the primary key. This + * is useful for implementing ticket key rotation. + * See RFC 5077 + * + * @param ctx Server or Client context to use + * @param keys the {@link SessionTicketKey}s + */ + public static void setSessionTicketKeys(long ctx, SessionTicketKey[] keys) { + if (keys == null || keys.length == 0) { + throw new IllegalArgumentException("Length of the keys should be longer than 0."); + } + byte[] binaryKeys = new byte[keys.length * SessionTicketKey.TICKET_KEY_SIZE]; + for (int i = 0; i < keys.length; i++) { + SessionTicketKey key = keys[i]; + int dstCurPos = SessionTicketKey.TICKET_KEY_SIZE * i; + System.arraycopy(key.name, 0, binaryKeys, dstCurPos, SessionTicketKey.NAME_SIZE); + dstCurPos += SessionTicketKey.NAME_SIZE; + System.arraycopy(key.hmacKey, 0, binaryKeys, dstCurPos, SessionTicketKey.HMAC_KEY_SIZE); + dstCurPos += SessionTicketKey.HMAC_KEY_SIZE; + System.arraycopy(key.aesKey, 0, binaryKeys, dstCurPos, SessionTicketKey.AES_KEY_SIZE); + } + setSessionTicketKeys0(ctx, binaryKeys); + } + + /** + * Set TLS session keys. + */ + private static native void setSessionTicketKeys0(long ctx, byte[] keys); + + /** + * Set concatenated PEM-encoded CA Certificates for Client Auth + *
+ * This directive sets the all-in-one BIO where you can assemble the + * Certificates of Certification Authorities (CA) whose clients you deal with. + * These are used for Client Authentication. Such a BIO is simply the + * concatenation of the various PEM-encoded Certificate files, in order of + * preference. This can be used alternatively and/or additionally to + * path. + *
+ * @param ctx Server context to use. + * @param certBio Directory of PEM-encoded CA Certificates for Client Auth. + * @return {@code true} if successful, {@code false} otherwise. + */ + public static native boolean setCACertificateBio(long ctx, long certBio); + + /** + * Set Type of Client Certificate verification and Maximum depth of CA Certificates + * in Client Certificate verification. + *
     * This directive sets the Certificate verification level for the Client
     * Authentication. Notice that this directive can be used both in per-server
     * and per-directory context. In per-server context it applies to the client
     * authentication process used in the standard SSL handshake when a connection
     * is established. In per-directory context it forces a SSL renegotiation with
     * the reconfigured client verification level after the HTTP request was read
     * but before the HTTP response is sent.
     * <p>
     * The following levels are available for level:
     * <ul>
     * <li>{@link SSL#SSL_CVERIFY_IGNORED} - The level is ignored. Only depth will change.</li>
     * <li>{@link SSL#SSL_CVERIFY_NONE} - No client Certificate is required at all.</li>
     * <li>{@link SSL#SSL_CVERIFY_OPTIONAL} - The client may present a valid Certificate.</li>
     * <li>{@link SSL#SSL_CVERIFY_REQUIRED} - The client has to present a valid Certificate.</li>
     * </ul>
     * The depth actually is the maximum number of intermediate certificate issuers,
     * i.e. the number of CA certificates which are max allowed to be followed while
     * verifying the client certificate. A depth of 0 means that self-signed client
     * certificates are accepted only, the default depth of 1 means the client
     * certificate can be self-signed or has to be signed by a CA which is directly
     * known to the server (i.e. the CA's certificate is under
     * {@code setCACertificatePath}), etc.
     * @param ctx Server or Client context to use.
     * @param level Type of Client Certificate verification.
     * @param depth Maximum depth of CA Certificates in Client Certificate
     *              verification.
     */
    public static native void setVerify(long ctx, int level, int depth);

    /**
     * Allow to hook {@link CertificateVerifier} into the handshake processing.
     * This will call {@code SSL_CTX_set_cert_verify_callback} and so replace the default verification
     * callback used by openssl.
     * @param ctx Server or Client context to use.
     * @param verifier the verifier to call during handshake.
     */
    public static native void setCertVerifyCallback(long ctx, CertificateVerifier verifier);

    /**
     * Allow to hook {@link CertificateRequestedCallback} into the certificate choosing process.
     * This will call {@code SSL_CTX_set_client_cert_cb} and so replace the default verification
     * callback used by openssl.
     * @param ctx Server or Client context to use.
     * @param callback the callback to call during certificate selection.
     * @deprecated use {@link #setCertificateCallback(long, CertificateCallback)}
     */
    @Deprecated
    public static native void setCertRequestedCallback(long ctx, CertificateRequestedCallback callback);

    /**
     * Allow to hook {@link CertificateCallback} into the certificate choosing process.
     * This will call {@code SSL_CTX_set_cert_cb} and so replace the default verification
     * callback used by openssl.
     * @param ctx Server or Client context to use.
+ * @param callback the callback to call during certificate selection. + */ + public static native void setCertificateCallback(long ctx, CertificateCallback callback); + + /** + * Allow to hook {@link SniHostNameMatcher} into the sni processing. + * This will call {@code SSL_CTX_set_tlsext_servername_callback} and so replace the default + * callback used by openssl + * @param ctx Server or Client context to use. + * @param matcher the matcher to call during sni hostname matching. + */ + public static native void setSniHostnameMatcher(long ctx, SniHostNameMatcher matcher); + + private static byte[] protocolsToWireFormat(String[] protocols) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + try { + for (String p : protocols) { + byte[] bytes = p.getBytes(StandardCharsets.US_ASCII); + if (bytes.length <= MAX_ALPN_NPN_PROTO_SIZE) { + out.write(bytes.length); + out.write(bytes); + } + } + return out.toByteArray(); + } catch (IOException e) { + throw new IllegalStateException(e); + } finally { + try { + out.close(); + } catch (IOException ignore) { + // ignore + } + } + } + + /** + * Set next protocol for next protocol negotiation extension + * @param ctx Server context to use. + * @param nextProtos protocols in priority order + * @param selectorFailureBehavior see {@link SSL#SSL_SELECTOR_FAILURE_NO_ADVERTISE} + * and {@link SSL#SSL_SELECTOR_FAILURE_CHOOSE_MY_LAST_PROTOCOL} + */ + public static void setNpnProtos(long ctx, String[] nextProtos, int selectorFailureBehavior) { + setNpnProtos0(ctx, protocolsToWireFormat(nextProtos), selectorFailureBehavior); + } + + private static native void setNpnProtos0(long ctx, byte[] nextProtos, int selectorFailureBehavior); + + /** + * Set application layer protocol for application layer protocol negotiation extension + * @param ctx Server context to use. 
+ * @param alpnProtos protocols in priority order + * @param selectorFailureBehavior see {@link SSL#SSL_SELECTOR_FAILURE_NO_ADVERTISE} + * and {@link SSL#SSL_SELECTOR_FAILURE_CHOOSE_MY_LAST_PROTOCOL} + */ + public static void setAlpnProtos(long ctx, String[] alpnProtos, int selectorFailureBehavior) { + setAlpnProtos0(ctx, protocolsToWireFormat(alpnProtos), selectorFailureBehavior); + } + + private static native void setAlpnProtos0(long ctx, byte[] alpnProtos, int selectorFailureBehavior); + + /** + * Set length of the DH to use. + * + * @param ctx Server context to use. + * @param length the length. + */ + public static native void setTmpDHLength(long ctx, int length); + + /** + * Set the context within which session be reused (server side only). + * See man SSL_CTX_set_session_id_context + * + * @param ctx Server context to use. + * @param sidCtx can be any kind of binary data, it is therefore possible to use e.g. the name + * of the application and/or the hostname and/or service name + * @return {@code true} if success, {@code false} otherwise. + */ + public static native boolean setSessionIdContext(long ctx, byte[] sidCtx); + + /** + * Call SSL_CTX_set_mode + * + * @param ctx context to use + * @param mode the mode + * @return the set mode. + */ + public static native int setMode(long ctx, int mode); + + /** + * Call SSL_CTX_get_mode + * + * @param ctx context to use + * @return the mode. + */ + public static native int getMode(long ctx); + + /** + * Enables OCSP stapling for the given {@link SSLContext} or throws an + * exception if OCSP stapling is not supported. + * + *

SSL_set_tlsext_status_type + *

Search for OCSP + */ + public static native void enableOcsp(long ctx, boolean client); + + /** + * Disables OCSP stapling on the given {@link SSLContext}. + * + *

     * See the OpenSSL docs for {@code SSL_set_tlsext_status_type} (search for OCSP).
     */
    public static native void disableOcsp(long ctx);

    /**
     * Returns the {@code SSL_CTX}.
     */
    public static native long getSslCtx(long ctx);

    /**
     * Enable or disable producing of tasks that should be obtained via {@link SSL#getTask(long)} and run.
     *
     * @param ctx context to use
     * @param useTasks {@code true} to enable, {@code false} to disable.
     */
    public static native void setUseTasks(long ctx, boolean useTasks);

    /**
     * Adds a certificate compression algorithm to the given {@link SSLContext} or throws an
     * exception if certificate compression is not supported or the algorithm not recognized.
     * For servers, algorithm preference order is dictated by the order of algorithm registration.
     * Most preferred algorithm should be registered first.
     * <p>
     * This method is currently only supported when {@code BoringSSL} is used.
     * <p>
     * See {@code SSL_CTX_add_cert_compression_alg} and RFC 8879.
     *
     * @param ctx context, to which, the algorithm should be added.
     * @param direction indicates whether decompression support should be advertized, compression should be applied for
     *                  peers which support it, or both. This allows the caller to support one way compression only.
     *                  <ul>
     *                  <li>{@link SSL#SSL_CERT_COMPRESSION_DIRECTION_COMPRESS}</li>
     *                  <li>{@link SSL#SSL_CERT_COMPRESSION_DIRECTION_DECOMPRESS}</li>
     *                  <li>{@link SSL#SSL_CERT_COMPRESSION_DIRECTION_BOTH}</li>
     *                  </ul>
     * @param algorithm implementation of the compression and or decompression algorithm as a
     *                  {@link CertificateCompressionAlgo}
     * @return one on success or zero on error
     */
    public static int addCertificateCompressionAlgorithm(long ctx, int direction, final CertificateCompressionAlgo algorithm) {
        return addCertificateCompressionAlgorithm0(ctx, direction, algorithm.algorithmId(), algorithm);
    }

    private static native int addCertificateCompressionAlgorithm0(long ctx, int direction, int algorithmId, final CertificateCompressionAlgo algorithm);

    /**
     * Set the {@link SSLPrivateKeyMethod} to use for the given {@link SSLContext}.
     * This allows to offload private key operations if needed.
     * <p>
     * This method is currently only supported when {@code BoringSSL} is used.
     *
     * @param ctx context to use
     * @param method method to use for the given context.
     */
    public static void setPrivateKeyMethod(long ctx, final SSLPrivateKeyMethod method) {
        // Wrap the synchronous method in the async adapter; the native side only deals
        // with the async variant.
        setPrivateKeyMethod(ctx, new AsyncSSLPrivateKeyMethodAdapter(method));
    }

    /**
     * Sets the {@link AsyncSSLPrivateKeyMethod} to use for the given {@link SSLContext}.
     * This allows to offload private key operations if needed.
     * <p>
     * This method is currently only supported when {@code BoringSSL} is used.
     *
     * @param ctx context to use
     * @param method method to use for the given context.
     */
    public static void setPrivateKeyMethod(long ctx, AsyncSSLPrivateKeyMethod method) {
        setPrivateKeyMethod0(ctx, method);
    }

    private static native void setPrivateKeyMethod0(long ctx, AsyncSSLPrivateKeyMethod method);

    /**
     * Set the {@link SSLSessionCache} that will be used if session caching is enabled.
     *
     * @param ctx context to use.
     * @param cache cache to use for the given context.
     */
    public static native void setSSLSessionCache(long ctx, SSLSessionCache cache);

    /**
     * Set the number of TLSv1.3 session tickets that will be sent to the client after a full handshake.
+ * + * See SSL_CTX_set_num_tickets for more details. + * @param ctx context to use + * @param tickets the number of tickets + * @return {@code true} if successful, {@code false} otherwise. + */ + public static native boolean setNumTickets(long ctx, int tickets); + + /** + * Sets the curves to use. + * + * See SSL_CTX_set1_curves_list. + * @param ctx context to use + * @param curves the curves to use. + * @return {@code true} if successful, {@code false} otherwise. + */ + public static boolean setCurvesList(long ctx, String... curves) { + if (curves == null) { + throw new NullPointerException("curves"); + } + if (curves.length == 0) { + throw new IllegalArgumentException(); + } + StringBuilder sb = new StringBuilder(); + for (String curve: curves) { + sb.append(curve); + // Curves are separated by : as explained in the manpage. + sb.append(':'); + } + sb.setLength(sb.length() - 1); + return setCurvesList0(ctx, sb.toString()); + } + + private static native boolean setCurvesList0(long ctx, String curves); + + /** + * Set the maximum number of bytes for the certificate chain during handshake. + * See + * SSL_CTX_set_max_cert_list + * for more details. + * @param ctx context to use + * @param size the maximum number of bytes + */ + public static native void setMaxCertList(long ctx, int size); +} diff --git a/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethod.java b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethod.java new file mode 100644 index 0000000..445ab6a --- /dev/null +++ b/netty-internal-tcnative/src/main/java/io/netty/internal/tcnative/SSLPrivateKeyMethod.java @@ -0,0 +1,56 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/**
 * Allows to customize private key signing / decrypt (when using RSA).
 */
public interface SSLPrivateKeyMethod {
    int SSL_SIGN_RSA_PKCS1_SHA1 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha1();
    int SSL_SIGN_RSA_PKCS1_SHA256 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha256();
    int SSL_SIGN_RSA_PKCS1_SHA384 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha384();
    int SSL_SIGN_RSA_PKCS1_SHA512 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcsSha512();
    int SSL_SIGN_ECDSA_SHA1 = NativeStaticallyReferencedJniMethods.sslSignEcdsaPkcsSha1();
    int SSL_SIGN_ECDSA_SECP256R1_SHA256 = NativeStaticallyReferencedJniMethods.sslSignEcdsaSecp256r1Sha256();
    int SSL_SIGN_ECDSA_SECP384R1_SHA384 = NativeStaticallyReferencedJniMethods.sslSignEcdsaSecp384r1Sha384();
    int SSL_SIGN_ECDSA_SECP521R1_SHA512 = NativeStaticallyReferencedJniMethods.sslSignEcdsaSecp521r1Sha512();
    int SSL_SIGN_RSA_PSS_RSAE_SHA256 = NativeStaticallyReferencedJniMethods.sslSignRsaPssRsaeSha256();
    int SSL_SIGN_RSA_PSS_RSAE_SHA384 = NativeStaticallyReferencedJniMethods.sslSignRsaPssRsaeSha384();
    int SSL_SIGN_RSA_PSS_RSAE_SHA512 = NativeStaticallyReferencedJniMethods.sslSignRsaPssRsaeSha512();
    int SSL_SIGN_ED25519 = NativeStaticallyReferencedJniMethods.sslSignEd25519();
    int SSL_SIGN_RSA_PKCS1_MD5_SHA1 = NativeStaticallyReferencedJniMethods.sslSignRsaPkcs1Md5Sha1();

    /**
     * Sign the input with the given EC key and return the signed bytes.
     *
     * @param ssl the SSL instance
     * @param signatureAlgorithm the algorithm to use for signing
     * @param input the input itself
     * @return the signature
     * @throws Exception thrown if an error occurs while signing.
     */
    byte[] sign(long ssl, int signatureAlgorithm, byte[] input) throws Exception;

    /**
     * Decrypts the input with the given RSA key and returns the decrypted bytes.
     *
     * @param ssl the SSL instance
     * @param input the input which should be decrypted
     * @return the decrypted data
     * @throws Exception thrown if an error occurs while decrypting.
     */
    byte[] decrypt(long ssl, byte[] input) throws Exception;
}
/**
 * Task that asynchronously performs a private-key decrypt operation via the
 * configured {@link AsyncSSLPrivateKeyMethod}.
 */
final class SSLPrivateKeyMethodDecryptTask extends SSLPrivateKeyMethodTask {

    private final byte[] input;

    SSLPrivateKeyMethodDecryptTask(long ssl, byte[] input, AsyncSSLPrivateKeyMethod method) {
        super(ssl, method);
        // It's OK to not clone the arrays as we create these in JNI and not reuse.
        this.input = input;
    }

    @Override
    protected void runTask(long ssl, AsyncSSLPrivateKeyMethod method,
                           ResultCallback resultCallback) {
        method.decrypt(ssl, input, resultCallback);
    }
}

/**
 * Task that asynchronously performs a private-key sign operation via the
 * configured {@link AsyncSSLPrivateKeyMethod}.
 */
final class SSLPrivateKeyMethodSignTask extends SSLPrivateKeyMethodTask {
    private final int signatureAlgorithm;
    private final byte[] digest;

    SSLPrivateKeyMethodSignTask(long ssl, int signatureAlgorithm, byte[] digest, AsyncSSLPrivateKeyMethod method) {
        super(ssl, method);
        this.signatureAlgorithm = signatureAlgorithm;
        // It's OK to not clone the arrays as we create these in JNI and not reuse.
        this.digest = digest;
    }

    @Override
    protected void runTask(long ssl, AsyncSSLPrivateKeyMethod method,
                           ResultCallback resultCallback) {
        method.sign(ssl, signatureAlgorithm, digest, resultCallback);
    }
}

/**
 * Base class for private-key method tasks. Bridges the async
 * {@link AsyncSSLPrivateKeyMethod} callbacks back to the native task machinery.
 */
abstract class SSLPrivateKeyMethodTask extends SSLTask implements AsyncTask {
    private static final byte[] EMPTY = new byte[0];
    private final AsyncSSLPrivateKeyMethod method;

    // Will be accessed via JNI — do not rename or remove.
    private byte[] resultBytes;

    SSLPrivateKeyMethodTask(long ssl, AsyncSSLPrivateKeyMethod method) {
        super(ssl);
        this.method = method;
    }


    @Override
    public final void runAsync(final Runnable completeCallback) {
        run(completeCallback);
    }

    @Override
    protected final void runTask(final long ssl, final TaskCallback callback) {
        runTask(ssl, method, new ResultCallback() {
            @Override
            public void onSuccess(long ssl, byte[] result) {
                resultBytes = result;
                // 1 signals success to the native side.
                callback.onResult(ssl, 1);
            }

            @Override
            public void onError(long ssl, Throwable cause) {
                // Return 0 as this signals back that the operation failed.
                resultBytes = EMPTY;
                callback.onResult(ssl, 0);
            }
        });
    }

    /**
     * Run the concrete operation (sign / decrypt) and report the outcome via
     * {@code resultCallback}.
     */
    protected abstract void runTask(long ssl, AsyncSSLPrivateKeyMethod method,
                                    ResultCallback resultCallback);
}
/**
 * Methods to operate on a {@code SSL_SESSION}.
 */
public final class SSLSession {

    private SSLSession() { }

    /**
     * See {@code SSL_SESSION_get_time}.
     *
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     * @return returns the time at which the session was established. The time is given in seconds since the Epoch
     */
    public static native long getTime(long session);

    /**
     * See {@code SSL_SESSION_get_timeout}.
     *
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     * @return returns the timeout for the session, in seconds
     */
    public static native long getTimeout(long session);

    /**
     * See {@code SSL_SESSION_set_timeout}.
     *
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     * @param seconds timeout in seconds
     * @return returns the timeout that was set for the session before this call, in seconds
     */
    public static native long setTimeout(long session, long seconds);

    /**
     * See {@code SSL_SESSION_get_id}.
     *
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     * @return the session id as byte array representation obtained via SSL_SESSION_get_id.
     */
    public static native byte[] getSessionId(long session);

    /**
     * See {@code SSL_SESSION_up_ref}.
     *
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     * @return {@code true} if successful, {@code false} otherwise.
     */
    public static native boolean upRef(long session);

    /**
     * See {@code SSL_SESSION_free}.
     *
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     */
    public static native void free(long session);

    /**
     * Will return {@code true} if the session should only be re-used once.
     * See {@code SSL_SESSION_should_be_single_use}.
     * @param session the SSL_SESSION instance (SSL_SESSION *)
     * @return {@code true} if the session should be re-used once only, {@code false} otherwise.
     */
    public static native boolean shouldBeSingleUse(long session);
}

/**
 * Allows to implement a custom external {@code SSL_SESSION} cache.
 * <p>
 * See <a href="https://www.openssl.org/docs/man1.1.0/man3/SSL_CTX_sess_set_get_cb.html">SSL_CTX_sess_set_get_cb</a>
 * and <a href="https://www.openssl.org/docs/man1.1.0/man3/SSL_CTX_set_session_cache_mode.html">SSL_CTX_set_session_cache_mode</a>.
 */
public interface SSLSessionCache {

    /**
     * Returns {@code true} if the cache takes ownership of the {@code SSL_SESSION} and will call
     * {@code SSL_SESSION_free} once it should be destroyed, {@code false} otherwise.
     * <p>
     * See {@code SSL_CTX_sess_set_new_cb}.
     *
     * @param ssl {@code SSL*}
     * @param sslSession {@code SSL_SESSION*}
     * @return {@code true} if session ownership was transferred, {@code false} if not.
     */
    boolean sessionCreated(long ssl, long sslSession);

    /**
     * Called once a {@code SSL_SESSION} should be retrieved for the given {@code SSL} and with the given session ID.
     * See {@code SSL_CTX_sess_set_get_cb}.
     * If the session is shared you need to call {@link SSLSession#upRef(long)} explicitly in this callback and
     * explicitly free all {@code SSL_SESSION}s once the cache is destroyed via {@link SSLSession#free(long)}.
     *
     * @param sslCtx {@code SSL_CTX*}
     * @param sessionId the session id
     * @return the {@code SSL_SESSION} pointer or {@code -1} if none was found in the cache.
     */
    long getSession(long sslCtx, byte[] sessionId);
}

/**
 * A SSL related task that will be returned by {@link SSL#getTask(long)}.
 */
abstract class SSLTask implements Runnable {
    private static final Runnable NOOP = new Runnable() {
        @Override
        public void run() {
            // NOOP
        }
    };
    private final long ssl;

    // These fields are accessed via JNI — do not rename or remove.
    private int returnValue;
    private boolean complete;
    protected boolean didRun;

    protected SSLTask(long ssl) {
        // It is important that this constructor never throws. Be sure to not change this!
        this.ssl = ssl;
    }

    @Override
    public final void run() {
        run(NOOP);
    }

    protected final void run(final Runnable completeCallback) {
        // Guard against double execution: the task must only run its work once,
        // but the completion callback is always invoked.
        if (!didRun) {
            didRun = true;
            runTask(ssl, new TaskCallback() {
                @Override
                public void onResult(long ssl, int result) {
                    returnValue = result;
                    complete = true;
                    completeCallback.run();
                }
            });
        } else {
            completeCallback.run();
        }
    }

    /**
     * Run the task and return the return value that should be passed back to OpenSSL.
     */
    protected abstract void runTask(long ssl, TaskCallback callback);

    interface TaskCallback {
        void onResult(long ssl, int result);
    }
}
/**
 * An OpenSSL session ticket key: a 16-byte name plus the HMAC and AES keys used
 * to protect issued tickets.
 */
public final class SessionTicketKey {

    /** Size in bytes of the session ticket key name. */
    public static final int NAME_SIZE = 16;

    /** Size in bytes of the session ticket HMAC key. */
    public static final int HMAC_KEY_SIZE = 16;

    /** Size in bytes of the session ticket AES key. */
    public static final int AES_KEY_SIZE = 16;

    /** Total size in bytes of a serialized session ticket key. */
    public static final int TICKET_KEY_SIZE = NAME_SIZE + HMAC_KEY_SIZE + AES_KEY_SIZE;

    // Package-private so SSLContext can read these directly without the
    // defensive clone() performed by the public getters.
    final byte[] name;
    final byte[] hmacKey;
    final byte[] aesKey;

    /**
     * Construct a SessionTicketKey.
     *
     * @param name    the name of the session ticket key; exactly {@link #NAME_SIZE} bytes
     * @param hmacKey the HMAC key; exactly {@link #HMAC_KEY_SIZE} bytes
     * @param aesKey  the AES key; exactly {@link #AES_KEY_SIZE} bytes
     * @throws IllegalArgumentException if any part is {@code null} or has the wrong length
     */
    public SessionTicketKey(byte[] name, byte[] hmacKey, byte[] aesKey) {
        this.name = checkPart(name, NAME_SIZE, "name");
        this.hmacKey = checkPart(hmacKey, HMAC_KEY_SIZE, "hmacKey");
        this.aesKey = checkPart(aesKey, AES_KEY_SIZE, "aesKey");
    }

    // Validates that a key part is non-null and exactly expectedLength bytes long.
    private static byte[] checkPart(byte[] part, int expectedLength, String partName) {
        if (part == null || part.length != expectedLength) {
            throw new IllegalArgumentException("Length of " + partName + " should be " + expectedLength);
        }
        return part;
    }

    /**
     * Get the name of the session ticket key.
     *
     * @return a defensive copy of the name bytes
     */
    public byte[] getName() {
        return name.clone();
    }

    /**
     * Get the HMAC key of the session ticket key.
     *
     * @return a defensive copy of the HMAC key bytes
     */
    public byte[] getHmacKey() {
        return hmacKey.clone();
    }

    /**
     * Get the AES key of the session ticket key.
     *
     * @return a defensive copy of the AES key bytes
     */
    public byte[] getAesKey() {
        return aesKey.clone();
    }
}
/**
 * Callback used to decide, per handshake, whether an SNI hostname is acceptable.
 */
public interface SniHostNameMatcher {

    /**
     * Returns {@code true} if the hostname was matched and so SNI should be allowed.
     *
     * @param ssl      the SSL instance ({@code SSL*} pointer)
     * @param hostname the hostname to match
     * @return {@code true} if the hostname was matched
     */
    boolean match(long ssl, String hostname);
}
+ * + * @author Tolstopyatov Vsevolod + */ +public interface Counter { + + void inc(); + + void inc(long delta); + + long get(); + + long getAndReset(); +} diff --git a/netty-jctools/src/main/java/org/jctools/counters/CountersFactory.java b/netty-jctools/src/main/java/org/jctools/counters/CountersFactory.java new file mode 100644 index 0000000..05bc5a8 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/counters/CountersFactory.java @@ -0,0 +1,28 @@ +package org.jctools.counters; + +import org.jctools.util.UnsafeAccess; + +/** + * @author Tolstopyatov Vsevolod + */ +public final class CountersFactory { + + private CountersFactory() { + } + + public static FixedSizeStripedLongCounter createFixedSizeStripedCounter(int stripesCount) { + if (UnsafeAccess.SUPPORTS_GET_AND_ADD_LONG) { + return new FixedSizeStripedLongCounterV8(stripesCount); + } else { + return new FixedSizeStripedLongCounterV6(stripesCount); + } + } + + public static FixedSizeStripedLongCounter createFixedSizeStripedCounterV6(int stripesCount) { + return new FixedSizeStripedLongCounterV6(stripesCount); + } + + public static FixedSizeStripedLongCounter createFixedSizeStripedCounterV8(int stripesCount) { + return new FixedSizeStripedLongCounterV8(stripesCount); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounter.java b/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounter.java new file mode 100644 index 0000000..9e4facf --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounter.java @@ -0,0 +1,163 @@ +package org.jctools.counters; + +import static org.jctools.util.UnsafeAccess.UNSAFE; + +import java.util.concurrent.ThreadLocalRandom; + +import org.jctools.util.PortableJvmInfo; +import org.jctools.util.Pow2; + +/** + * Basic class representing static striped long counter with + * common mechanics for implementors. 
+ * + * @author Tolstopyatov Vsevolod + */ +abstract class FixedSizeStripedLongCounterPrePad { + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} +abstract class FixedSizeStripedLongCounterFields extends FixedSizeStripedLongCounterPrePad { + protected static final int CACHE_LINE_IN_LONGS = PortableJvmInfo.CACHE_LINE_SIZE / 8; + // place first element at the end of the cache line of the array object + protected static final long COUNTER_ARRAY_BASE = Math.max(UNSAFE.arrayBaseOffset(long[].class), PortableJvmInfo.CACHE_LINE_SIZE - 8); + // element shift is enlarged to include the padding, still aligned to long + protected static final long ELEMENT_SHIFT = Integer.numberOfTrailingZeros(PortableJvmInfo.CACHE_LINE_SIZE); + + // we pad each element in the array to effectively write a counter in each cache line + protected final long[] cells; + protected final int mask; + protected FixedSizeStripedLongCounterFields(int stripesCount) { + if (stripesCount <= 0) { + throw new IllegalArgumentException("Expecting a stripesCount that is larger than 0"); + } + int size = Pow2.roundToPowerOfTwo(stripesCount); + cells = new long[CACHE_LINE_IN_LONGS * size]; + mask = (size - 1); + 
} +} + +public abstract class FixedSizeStripedLongCounter extends FixedSizeStripedLongCounterFields implements Counter { + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + //byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + private static final long PROBE = getProbeOffset(); + + private static long getProbeOffset() { + try { + return UNSAFE.objectFieldOffset(Thread.class.getDeclaredField("threadLocalRandomProbe")); + + } catch (NoSuchFieldException e) { + return -1L; + } + } + + public FixedSizeStripedLongCounter(int stripesCount) { + super(stripesCount); + } + + @Override + public void inc() { + inc(1L); + } + + @Override + public void inc(long delta) { + inc(cells, counterOffset(index()), delta); + } + + @Override + public long get() { + long result = 0L; + long[] cells = this.cells; + int length = mask + 1; + for (int i = 0; i < length; i++) { + result += UNSAFE.getLongVolatile(cells, counterOffset(i)); + } + return result; + } + + private long counterOffset(long i) { + return COUNTER_ARRAY_BASE + (i << ELEMENT_SHIFT); + } + + @Override + public long getAndReset() { + long result = 0L; + long[] cells = this.cells; + int length = mask + 1; + for (int i = 0; i < length; i++) { + result += getAndReset(cells, 
counterOffset(i)); + } + return result; + } + + protected abstract void inc(long[] cells, long offset, long value); + + protected abstract long getAndReset(long[] cells, long offset); + + private int index() { + return probe() & mask; + } + + + /** + * Returns the probe value for the current thread. + * If target JDK version is 7 or higher, than ThreadLocalRandom-specific + * value will be used, xorshift with thread id otherwise. + */ + private int probe() { + // Fast path for reliable well-distributed probe, available from JDK 7+. + // As long as PROBE is final static this branch will be constant folded + // (i.e removed). + if (PROBE != -1) { + int probe; + if ((probe = UNSAFE.getInt(Thread.currentThread(), PROBE)) == 0) { + ThreadLocalRandom.current(); // force initialization + probe = UNSAFE.getInt(Thread.currentThread(), PROBE); + } + return probe; + } + + /* + * Else use much worse (for values distribution) method: + * Mix thread id with golden ratio and then xorshift it + * to spread consecutive ids (see Knuth multiplicative method as reference). + */ + int probe = (int) ((Thread.currentThread().getId() * 0x9e3779b9) & Integer.MAX_VALUE); + // xorshift + probe ^= probe << 13; + probe ^= probe >>> 17; + probe ^= probe << 5; + return probe; + } +} + diff --git a/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV6.java b/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV6.java new file mode 100644 index 0000000..2693ff7 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV6.java @@ -0,0 +1,34 @@ +package org.jctools.counters; + +import static org.jctools.util.UnsafeAccess.UNSAFE; + +/** + * Lock-free implementation of striped counter using + * CAS primitives. 
+ * + * @author Tolstopyatov Vsevolod + */ +class FixedSizeStripedLongCounterV6 extends FixedSizeStripedLongCounter { + + public FixedSizeStripedLongCounterV6(int stripesCount) { + super(stripesCount); + } + + @Override + protected void inc(long[] cells, long offset, long delta) { + long v; + do { + v = UNSAFE.getLongVolatile(cells, offset); + } while (!UNSAFE.compareAndSwapLong(cells, offset, v, v + delta)); + } + + @Override + protected long getAndReset(long[] cells, long offset) { + long v; + do { + v = UNSAFE.getLongVolatile(cells, offset); + } while (!UNSAFE.compareAndSwapLong(cells, offset, v, 0L)); + + return v; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV8.java b/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV8.java new file mode 100644 index 0000000..3dc85ad --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/counters/FixedSizeStripedLongCounterV8.java @@ -0,0 +1,26 @@ +package org.jctools.counters; + +import static org.jctools.util.UnsafeAccess.UNSAFE; + +/** + * Wait-free implementation of striped counter using + * Java 8 Unsafe intrinsics (lock addq and lock xchg). 
+ * + * @author Tolstopyatov Vsevolod + */ +class FixedSizeStripedLongCounterV8 extends FixedSizeStripedLongCounter { + + public FixedSizeStripedLongCounterV8(int stripesCount) { + super(stripesCount); + } + + @Override + protected void inc(long[] cells, long offset, long delta) { + UNSAFE.getAndAddLong(cells, offset, delta); + } + + @Override + protected long getAndReset(long[] cells, long offset) { + return UNSAFE.getAndSetLong(cells, offset, 0L); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/counters/package-info.java b/netty-jctools/src/main/java/org/jctools/counters/package-info.java new file mode 100644 index 0000000..56a1f47 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/counters/package-info.java @@ -0,0 +1,14 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.counters; diff --git a/netty-jctools/src/main/java/org/jctools/maps/AbstractEntry.java b/netty-jctools/src/main/java/org/jctools/maps/AbstractEntry.java new file mode 100644 index 0000000..b1cff8d --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/maps/AbstractEntry.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
/**
 * A simple implementation of {@link java.util.Map.Entry}.
 * Does not implement {@link java.util.Map.Entry#setValue}; that is done by users of the class.
 *
 * <p>NOTE: the generic parameters {@code <TypeK, TypeV>} were lost in the visible
 * text (angle brackets stripped); they are restored here -- without them the
 * references to {@code TypeK}/{@code TypeV} do not compile.
 *
 * @since 1.5
 * @author Cliff Click
 * @param <TypeK> the type of keys maintained by this map
 * @param <TypeV> the type of mapped values
 */
abstract class AbstractEntry<TypeK, TypeV> implements Map.Entry<TypeK, TypeV> {
    /** Strongly typed key. */
    protected final TypeK _key;
    /** Strongly typed value. */
    protected TypeV _val;

    public AbstractEntry(final TypeK key, final TypeV val) { _key = key; _val = val; }
    public AbstractEntry(final Map.Entry<TypeK, TypeV> e) { _key = e.getKey(); _val = e.getValue(); }

    /** Return "key=val" string. */
    @Override
    public String toString() { return _key + "=" + _val; }
    /** Return key. */
    @Override
    public TypeK getKey() { return _key; }
    /** Return val. */
    @Override
    public TypeV getValue() { return _val; }

    /** Equal if the underlying key &amp; value are equal. */
    @Override
    public boolean equals(final Object o) {
        if (!(o instanceof Map.Entry)) return false;
        final Map.Entry<?, ?> e = (Map.Entry<?, ?>) o;
        return eq(_key, e.getKey()) && eq(_val, e.getValue());
    }

    /** Compute "key.hashCode() ^ val.hashCode()" (the Map.Entry contract). */
    @Override
    public int hashCode() {
        return ((_key == null) ? 0 : _key.hashCode())
             ^ ((_val == null) ? 0 : _val.hashCode());
    }

    // Null-safe equality.
    private static boolean eq(final Object o1, final Object o2) {
        return (o1 == null ? o2 == null : o1.equals(o2));
    }
}
o2 == null : o1.equals(o2)); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/maps/ConcurrentAutoTable.java b/netty-jctools/src/main/java/org/jctools/maps/ConcurrentAutoTable.java new file mode 100644 index 0000000..93a574b --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/maps/ConcurrentAutoTable.java @@ -0,0 +1,219 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.maps; +import static org.jctools.util.UnsafeAccess.UNSAFE; + +import java.io.Serializable; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + + +/** + * An auto-resizing table of {@code longs}, supporting low-contention CAS + * operations. Updates are done with CAS's to no particular table element. + * The intent is to support highly scalable counters, r/w locks, and other + * structures where the updates are associative, loss-free (no-brainer), and + * otherwise happen at such a high volume that the cache contention for + * CAS'ing a single word is unacceptable. + * + * @since 1.5 + * @author Cliff Click + */ +public class ConcurrentAutoTable implements Serializable { + + // --- public interface --- + + /** + * Add the given value to current counter value. Concurrent updates will + * not be lost, but addAndGet or getAndAdd are not implemented because the + * total counter value (i.e., {@link #get}) is not atomically updated. 
+ * Updates are striped across an array of counters to avoid cache contention + * and has been tested with performance scaling linearly up to 768 CPUs. + */ + public void add( long x ) { add_if( x); } + /** {@link #add} with -1 */ + public void decrement() { add_if(-1L); } + /** {@link #add} with +1 */ + public void increment() { add_if( 1L); } + + /** Atomically set the sum of the striped counters to specified value. + * Rather more expensive than a simple store, in order to remain atomic. + */ + public void set( long x ) { + CAT newcat = new CAT(null,4,x); + // Spin until CAS works + while( !CAS_cat(_cat,newcat) ) {/*empty*/} + } + + /** + * Current value of the counter. Since other threads are updating furiously + * the value is only approximate, but it includes all counts made by the + * current thread. Requires a pass over the internally striped counters. + */ + public long get() { return _cat.sum(); } + /** Same as {@link #get}, included for completeness. */ + public int intValue() { return (int)_cat.sum(); } + /** Same as {@link #get}, included for completeness. */ + public long longValue() { return _cat.sum(); } + + /** + * A cheaper {@link #get}. Updated only once/millisecond, but as fast as a + * simple load instruction when not updating. + */ + public long estimate_get( ) { return _cat.estimate_sum(); } + + /** + * Return the counter's {@code long} value converted to a string. + */ + public String toString() { return _cat.toString(); } + + /** + * A more verbose print than {@link #toString}, showing internal structure. + * Useful for debugging. + */ + public void print() { _cat.print(); } + + /** + * Return the internal counter striping factor. Useful for diagnosing + * performance problems. + */ + public int internal_size() { return _cat._t.length; } + + // Only add 'x' to some slot in table, hinted at by 'hash'. The sum can + // overflow. Value is CAS'd so no counts are lost. The CAS is retried until + // it succeeds. Returned value is the old value. 
+ private long add_if( long x ) { return _cat.add_if(x,hash(),this); } + + // The underlying array of concurrently updated long counters + private volatile CAT _cat = new CAT(null,16/*Start Small, Think Big!*/,0L); + private static AtomicReferenceFieldUpdater _catUpdater = + AtomicReferenceFieldUpdater.newUpdater(ConcurrentAutoTable.class,CAT.class, "_cat"); + private boolean CAS_cat( CAT oldcat, CAT newcat ) { return _catUpdater.compareAndSet(this,oldcat,newcat); } + + // Hash spreader + private static int hash() { + //int h = (int)Thread.currentThread().getId(); + int h = System.identityHashCode(Thread.currentThread()); + return h<<3; // Pad out cache lines. The goal is to avoid cache-line contention + } + + // --- CAT ----------------------------------------------------------------- + private static class CAT implements Serializable { + + // Unsafe crud: get a function which will CAS arrays + private static final int _Lbase = UNSAFE.arrayBaseOffset(long[].class); + private static final int _Lscale = UNSAFE.arrayIndexScale(long[].class); + private static long rawIndex(long[] ary, int i) { + assert i >= 0 && i < ary.length; + return _Lbase + (i * (long)_Lscale); + } + private static boolean CAS( long[] A, int idx, long old, long nnn ) { + return UNSAFE.compareAndSwapLong( A, rawIndex(A,idx), old, nnn ); + } + + //volatile long _resizers; // count of threads attempting a resize + //static private final AtomicLongFieldUpdater _resizerUpdater = + // AtomicLongFieldUpdater.newUpdater(CAT.class, "_resizers"); + + private final CAT _next; + private volatile long _fuzzy_sum_cache; + private volatile long _fuzzy_time; + private static final int MAX_SPIN=1; + private final long[] _t; // Power-of-2 array of longs + + CAT( CAT next, int sz, long init ) { + _next = next; + _t = new long[sz]; + _t[0] = init; + } + + // Only add 'x' to some slot in table, hinted at by 'hash'. The sum can + // overflow. Value is CAS'd so no counts are lost. The CAS is attempted + // ONCE. 
+ public long add_if( long x, int hash, ConcurrentAutoTable master ) { + final long[] t = _t; + final int idx = hash & (t.length-1); + // Peel loop; try once fast + long old = t[idx]; + final boolean ok = CAS( t, idx, old, old+x ); + if( ok ) return old; // Got it + // Try harder + int cnt=0; + while( true ) { + old = t[idx]; + if( CAS( t, idx, old, old+x ) ) break; // Got it! + cnt++; + } + if( cnt < MAX_SPIN ) return old; // Allowable spin loop count + if( t.length >= 1024*1024 ) return old; // too big already + + // Too much contention; double array size in an effort to reduce contention + //long r = _resizers; + //final int newbytes = (t.length<<1)<<3/*word to bytes*/; + //while( !_resizerUpdater.compareAndSet(this,r,r+newbytes) ) + // r = _resizers; + //r += newbytes; + if( master._cat != this ) return old; // Already doubled, don't bother + //if( (r>>17) != 0 ) { // Already too much allocation attempts? + // // We could use a wait with timeout, so we'll wakeup as soon as the new + // // table is ready, or after the timeout in any case. Annoyingly, this + // // breaks the non-blocking property - so for now we just briefly sleep. + // //synchronized( this ) { wait(8*megs); } // Timeout - we always wakeup + // try { Thread.sleep(r>>17); } catch( InterruptedException e ) { } + // if( master._cat != this ) return old; + //} + + CAT newcat = new CAT(this,t.length*2,0); + // Take 1 stab at updating the CAT with the new larger size. If this + // fails, we assume some other thread already expanded the CAT - so we + // do not need to retry until it succeeds. + while( master._cat == this && !master.CAS_cat(this,newcat) ) {/*empty*/} + return old; + } + + + // Return the current sum of all things in the table. Writers can be + // updating the table furiously, so the sum is only locally accurate. + public long sum( ) { + long sum = _next == null ? 
0 : _next.sum(); // Recursively get cached sum + final long[] t = _t; + for( long cnt : t ) sum += cnt; + return sum; + } + + // Fast fuzzy version. Used a cached value until it gets old, then re-up + // the cache. + public long estimate_sum( ) { + // For short tables, just do the work + if( _t.length <= 64 ) return sum(); + // For bigger tables, periodically freshen a cached value + long millis = System.currentTimeMillis(); + if( _fuzzy_time != millis ) { // Time marches on? + _fuzzy_sum_cache = sum(); // Get sum the hard way + _fuzzy_time = millis; // Indicate freshness of cached value + } + return _fuzzy_sum_cache; // Return cached sum + } + + public String toString( ) { return Long.toString(sum()); } + + public void print() { + long[] t = _t; + System.out.print("["+t[0]); + for( int i=1; iHashtable. However, even though all operations are + * thread-safe, operations do not entail locking and there is + * not any support for locking the entire table in a way that + * prevents all access. This class is fully interoperable with + * Hashtable in programs that rely on its thread safety but not on + * its synchronization details. + * + *

Operations (including put) generally do not block, so may + * overlap with other update operations (including other puts and + * removes). Retrievals reflect the results of the most recently + * completed update operations holding upon their onset. For + * aggregate operations such as putAll, concurrent retrievals may + * reflect insertion or removal of only some entries. Similarly, Iterators + * and Enumerations return elements reflecting the state of the hash table at + * some point at or since the creation of the iterator/enumeration. They do + * not throw {@link ConcurrentModificationException}. However, + * iterators are designed to be used by only one thread at a time. + * + *

Very full tables, or tables with high re-probe rates may trigger an + * internal resize operation to move into a larger table. Resizing is not + * terribly expensive, but it is not free either; during resize operations + * table throughput may drop somewhat. All threads that visit the table + * during a resize will 'help' the resizing but will still be allowed to + * complete their operation before the resize is finished (i.e., a simple + * 'get' operation on a million-entry table undergoing resizing will not need + * to block until the entire million entries are copied). + * + *

This class and its views and iterators implement all of the + * optional methods of the {@link Map} and {@link Iterator} + * interfaces. + * + *

Like {@link Hashtable} but unlike {@link HashMap}, this class + * does not allow null to be used as a key or value. + * + * + * @since 1.5 + * @author Cliff Click + * @param the type of keys maintained by this map + * @param the type of mapped values + */ + +public class NonBlockingHashMap + extends AbstractMap + implements ConcurrentMap, Cloneable, Serializable { + + private static final long serialVersionUID = 1234123412341234123L; + + private static final int REPROBE_LIMIT=10; // Too many reprobes then force a table-resize + + // --- Bits to allow Unsafe access to arrays + private static final int _Obase = UNSAFE.arrayBaseOffset(Object[].class); + private static final int _Oscale = UNSAFE.arrayIndexScale(Object[].class); + private static final int _Olog = _Oscale==4?2:(_Oscale==8?3:9999); + private static long rawIndex(final Object[] ary, final int idx) { + assert idx >= 0 && idx < ary.length; + // Note the long-math requirement, to handle arrays of more than 2^31 bytes + // - or 2^28 - or about 268M - 8-byte pointer elements. + return _Obase + ((long)idx << _Olog); + } + + // --- Setup to use Unsafe + private static final long _kvs_offset = fieldOffset(NonBlockingHashMap.class, "_kvs"); + + private final boolean CAS_kvs( final Object[] oldkvs, final Object[] newkvs ) { + return UNSAFE.compareAndSwapObject(this, _kvs_offset, oldkvs, newkvs ); + } + + // --- Adding a 'prime' bit onto Values via wrapping with a junk wrapper class + private static final class Prime { + final Object _V; + Prime( Object V ) { _V = V; } + static Object unbox( Object V ) { return V instanceof Prime ? ((Prime)V)._V : V; } + } + + // --- hash ---------------------------------------------------------------- + // Helper function to spread lousy hashCodes. Throws NPE for null Key, on + // purpose - as the first place to conveniently toss the required NPE for a + // null Key. 
+ private static final int hash(final Object key) { + int h = key.hashCode(); // The real hashCode call + h ^= (h>>>20) ^ (h>>>12); + h ^= (h>>> 7) ^ (h>>> 4); + h += h<<7; // smear low bits up high, for hashcodes that only differ by 1 + return h; + } + + // --- The Hash Table -------------------- + // Slot 0 is always used for a 'CHM' entry below to hold the interesting + // bits of the hash table. Slot 1 holds full hashes as an array of ints. + // Slots {2,3}, {4,5}, etc hold {Key,Value} pairs. The entire hash table + // can be atomically replaced by CASing the _kvs field. + // + // Why is CHM buried inside the _kvs Object array, instead of the other way + // around? The CHM info is used during resize events and updates, but not + // during standard 'get' operations. I assume 'get' is much more frequent + // than 'put'. 'get' can skip the extra indirection of skipping through the + // CHM to reach the _kvs array. + private transient Object[] _kvs; + private static final CHM chm (Object[] kvs) { return (CHM )kvs[0]; } + private static final int[] hashes(Object[] kvs) { return (int[])kvs[1]; } + // Number of K,V pairs in the table + private static final int len(Object[] kvs) { return (kvs.length-2)>>1; } + + // Time since last resize + private transient long _last_resize_milli; + + // --- Minimum table size ---------------- + // Pick size 8 K/V pairs, which turns into (8*2+2)*4+12 = 84 bytes on a + // standard 32-bit HotSpot, and (8*2+2)*8+12 = 156 bytes on 64-bit Azul. + private static final int MIN_SIZE_LOG=3; // + private static final int MIN_SIZE=(1<>4); + } + + // --- NonBlockingHashMap -------------------------------------------------- + // Constructors + + /** Create a new NonBlockingHashMap with default minimum size (currently set + * to 8 K/V pairs or roughly 84 bytes on a standard 32-bit JVM). 
*/ + public NonBlockingHashMap( ) { this(MIN_SIZE); } + + /** Create a new NonBlockingHashMap with initial room for the given number of + * elements, thus avoiding internal resizing operations to reach an + * appropriate size. Large numbers here when used with a small count of + * elements will sacrifice space for a small amount of time gained. The + * initial size will be rounded up internally to the next larger power of 2. */ + public NonBlockingHashMap( final int initial_sz ) { initialize(initial_sz); } + private final void initialize( int initial_sz ) { + RangeUtil.checkPositiveOrZero(initial_sz, "initial_sz"); + int i; // Convert to next largest power-of-2 + if( initial_sz > 1024*1024 ) initial_sz = 1024*1024; + for( i=MIN_SIZE_LOG; (1<size() == 0. + * @return size() == 0 */ + @Override + public boolean isEmpty ( ) { return size() == 0; } + + /** Tests if the key in the table using the equals method. + * @return true if the key is in the table using the equals method + * @throws NullPointerException if the specified key is null */ + @Override + public boolean containsKey( Object key ) { return get(key) != null; } + + /** Legacy method testing if some key maps into the specified value in this + * table. This method is identical in functionality to {@link + * #containsValue}, and exists solely to ensure full compatibility with + * class {@link java.util.Hashtable}, which supported this method prior to + * introduction of the Java Collections framework. + * @param val a value to search for + * @return true if this map maps one or more keys to the specified value + * @throws NullPointerException if the specified value is null */ + public boolean contains ( Object val ) { return containsValue(val); } + + /** Maps the specified key to the specified value in the table. Neither key + * nor value can be null. + *

The value can be retrieved by calling {@link #get} with a key that is + * equal to the original key. + * @param key key with which the specified value is to be associated + * @param val value to be associated with the specified key + * @return the previous value associated with key, or + * null if there was no mapping for key + * @throws NullPointerException if the specified key or value is null */ + @Override + public TypeV put ( TypeK key, TypeV val ) { return putIfMatch( key, val, NO_MATCH_OLD); } + + /** Atomically, do a {@link #put} if-and-only-if the key is not mapped. + * Useful to ensure that only a single mapping for the key exists, even if + * many threads are trying to create the mapping in parallel. + * @return the previous value associated with the specified key, + * or null if there was no mapping for the key + * @throws NullPointerException if the specified key or value is null */ + @Override + public TypeV putIfAbsent( TypeK key, TypeV val ) { return putIfMatch( key, val, TOMBSTONE ); } + + /** Removes the key (and its corresponding value) from this map. + * This method does nothing if the key is not in the map. + * @return the previous value associated with key, or + * null if there was no mapping for key + * @throws NullPointerException if the specified key is null */ + @Override + public TypeV remove ( Object key ) { return putIfMatch( key,TOMBSTONE, NO_MATCH_OLD); } + + /** Atomically do a {@link #remove(Object)} if-and-only-if the key is mapped + * to a value which is equals to the given value. + * @throws NullPointerException if the specified key or value is null */ + public boolean remove ( Object key,Object val ) { + return objectsEquals(putIfMatch( key,TOMBSTONE, val ), val); + } + + /** Atomically do a put(key,val) if-and-only-if the key is + * mapped to some value already. 
+ * @throws NullPointerException if the specified key or value is null */ + @Override + public TypeV replace ( TypeK key, TypeV val ) { return putIfMatch( key, val,MATCH_ANY ); } + + /** Atomically do a put(key,newValue) if-and-only-if the key is + * mapped a value which is equals to oldValue. + * @throws NullPointerException if the specified key or value is null */ + @Override + public boolean replace ( TypeK key, TypeV oldValue, TypeV newValue ) { + return objectsEquals(putIfMatch( key, newValue, oldValue ), oldValue); + } + private static boolean objectsEquals(Object a, Object b) { + return (a == b) || (a != null && a.equals(b)); + } + + // Atomically replace newVal for oldVal, returning the value that existed + // there before. If the oldVal matches the returned value, then newVal was + // inserted, otherwise not. A null oldVal means the key does not exist (only + // insert if missing); a null newVal means to remove the key. + public final TypeV putIfMatchAllowNull( Object key, Object newVal, Object oldVal ) { + if( oldVal == null ) oldVal = TOMBSTONE; + if( newVal == null ) newVal = TOMBSTONE; + final TypeV res = (TypeV) putIfMatch0(this, _kvs, key, newVal, oldVal ); + assert !(res instanceof Prime); + //assert res != null; + return res == TOMBSTONE ? null : res; + } + + /** Atomically replace newVal for oldVal, returning the value that existed + * there before. If the oldVal matches the returned value, then newVal was + * inserted, otherwise not. + * @return the previous value associated with the specified key, + * or null if there was no mapping for the key + * @throws NullPointerException if the key or either value is null + */ + public final TypeV putIfMatch( Object key, Object newVal, Object oldVal ) { + if (oldVal == null || newVal == null) throw new NullPointerException(); + final Object res = putIfMatch0(this, _kvs, key, newVal, oldVal ); + assert !(res instanceof Prime); + assert res != null; + return res == TOMBSTONE ? 
null : (TypeV)res; + } + + + /** Copies all of the mappings from the specified map to this one, replacing + * any existing mappings. + * @param m mappings to be stored in this map */ + @Override + public void putAll(Map m) { + for (Map.Entry e : m.entrySet()) + put(e.getKey(), e.getValue()); + } + + /** Removes all of the mappings from this map. */ + @Override + public void clear() { // Smack a new empty table down + Object[] newkvs = new NonBlockingHashMap(MIN_SIZE)._kvs; + while( !CAS_kvs(_kvs,newkvs) ) // Spin until the clear works + ; + } + + /** Returns true if this Map maps one or more keys to the specified + * value. Note: This method requires a full internal traversal of the + * hash table and is much slower than {@link #containsKey}. + * @param val value whose presence in this map is to be tested + * @return true if this map maps one or more keys to the specified value + * @throws NullPointerException if the specified value is null */ + @Override + public boolean containsValue( final Object val ) { + if( val == null ) throw new NullPointerException(); + for( TypeV V : values() ) + if( V == val || V.equals(val) ) + return true; + return false; + } + + // This function is supposed to do something for Hashtable, and the JCK + // tests hang until it gets called... by somebody ... for some reason, + // any reason.... + protected void rehash() { + } + + /** + * Creates a shallow copy of this hashtable. All the structure of the + * hashtable itself is copied, but the keys and values are not cloned. + * This is a relatively expensive operation. + * + * @return a clone of the hashtable. + */ + @Override + public Object clone() { + try { + // Must clone, to get the class right; NBHM might have been + // extended so it would be wrong to just make a new NBHM. + NonBlockingHashMap t = (NonBlockingHashMap) super.clone(); + // But I don't have an atomic clone operation - the underlying _kvs + // structure is undergoing rapid change. 
If I just clone the _kvs + // field, the CHM in _kvs[0] won't be in sync. + // + // Wipe out the cloned array (it was shallow anyways). + t.clear(); + // Now copy sanely + for( TypeK K : keySet() ) { + final TypeV V = get(K); // Do an official 'get' + t.put(K,V); + } + return t; + } catch (CloneNotSupportedException e) { + // this shouldn't happen, since we are Cloneable + throw new InternalError(); + } + } + + /** + * Returns a string representation of this map. The string representation + * consists of a list of key-value mappings in the order returned by the + * map's entrySet view's iterator, enclosed in braces + * ("{}"). Adjacent mappings are separated by the characters + * ", " (comma and space). Each key-value mapping is rendered as + * the key followed by an equals sign ("=") followed by the + * associated value. Keys and values are converted to strings as by + * {@link String#valueOf(Object)}. + * + * @return a string representation of this map + */ + @Override + public String toString() { + Iterator> i = entrySet().iterator(); + if( !i.hasNext()) + return "{}"; + + StringBuilder sb = new StringBuilder(); + sb.append('{'); + for (;;) { + Entry e = i.next(); + TypeK key = e.getKey(); + TypeV value = e.getValue(); + sb.append(key == this ? "(this Map)" : key); + sb.append('='); + sb.append(value == this ? "(this Map)" : value); + if( !i.hasNext()) + return sb.append('}').toString(); + sb.append(", "); + } + } + + // --- keyeq --------------------------------------------------------------- + // Check for key equality. Try direct pointer compare first, then see if + // the hashes are unequal (fast negative test) and finally do the full-on + // 'equals' v-call. + private static boolean keyeq( Object K, Object key, int[] hashes, int hash, int fullhash ) { + return + K==key || // Either keys match exactly OR + // hash exists and matches? hash can be zero during the install of a + // new key/value pair. 
+ ((hashes[hash] == 0 || hashes[hash] == fullhash) && + // Do not call the users' "equals()" call with a Tombstone, as this can + // surprise poorly written "equals()" calls that throw exceptions + // instead of simply returning false. + K != TOMBSTONE && // Do not call users' equals call with a Tombstone + // Do the match the hard way - with the users' key being the loop- + // invariant "this" pointer. I could have flipped the order of + // operands (since equals is commutative), but I'm making mega-morphic + // v-calls in a re-probing loop and nailing down the 'this' argument + // gives both the JIT and the hardware a chance to prefetch the call target. + key.equals(K)); // Finally do the hard match + } + + // --- get ----------------------------------------------------------------- + /** Returns the value to which the specified key is mapped, or {@code null} + * if this map contains no mapping for the key. + *

More formally, if this map contains a mapping from a key {@code k} to + * a value {@code v} such that {@code key.equals(k)}, then this method + * returns {@code v}; otherwise it returns {@code null}. (There can be at + * most one such mapping.) + * @throws NullPointerException if the specified key is null */ + // Never returns a Prime nor a Tombstone. + @Override + public TypeV get( Object key ) { + final Object V = get_impl(this,_kvs,key); + assert !(V instanceof Prime); // Never return a Prime + assert V != TOMBSTONE; + return (TypeV)V; + } + + private static final Object get_impl( final NonBlockingHashMap topmap, final Object[] kvs, final Object key ) { + final int fullhash= hash (key); // throws NullPointerException if key is null + final int len = len (kvs); // Count of key/value pairs, reads kvs.length + final CHM chm = chm (kvs); // The CHM, for a volatile read below; reads slot 0 of kvs + final int[] hashes=hashes(kvs); // The memoized hashes; reads slot 1 of kvs + + int idx = fullhash & (len-1); // First key hash + + // Main spin/reprobe loop, looking for a Key hit + int reprobe_cnt=0; + while( true ) { + // Probe table. Each read of 'val' probably misses in cache in a big + // table; hopefully the read of 'key' then hits in cache. + final Object K = key(kvs,idx); // Get key before volatile read, could be null + final Object V = val(kvs,idx); // Get value before volatile read, could be null or Tombstone or Prime + if( K == null ) return null; // A clear miss + + // We need a volatile-read here to preserve happens-before semantics on + // newly inserted Keys. If the Key body was written just before inserting + // into the table a Key-compare here might read the uninitialized Key body. + // Annoyingly this means we have to volatile-read before EACH key compare. + // . + // We also need a volatile-read between reading a newly inserted Value + // and returning the Value (so the user might end up reading the stale + // Value contents). 
Same problem as with keys - and the one volatile + // read covers both. + final Object[] newkvs = chm._newkvs; // VOLATILE READ before key compare + + // Key-compare + if( keyeq(K,key,hashes,idx,fullhash) ) { + // Key hit! Check for no table-copy-in-progress + if( !(V instanceof Prime) ) // No copy? + return (V == TOMBSTONE) ? null : V; // Return the value + // Key hit - but slot is (possibly partially) copied to the new table. + // Finish the copy & retry in the new table. + return get_impl(topmap,chm.copy_slot_and_check(topmap,kvs,idx,key),key); // Retry in the new table + } + // get and put must have the same key lookup logic! But only 'put' + // needs to force a table-resize for a too-long key-reprobe sequence. + // Check for too-many-reprobes on get - and flip to the new table. + if( ++reprobe_cnt >= reprobe_limit(len) || // too many probes + K == TOMBSTONE ) // found a TOMBSTONE key, means no more keys in this table + return newkvs == null ? null : get_impl(topmap,topmap.help_copy(newkvs),key); // Retry in the new table + + idx = (idx+1)&(len-1); // Reprobe by 1! (could now prefetch) + } + } + + // --- getk ----------------------------------------------------------------- + /** Returns the Key to which the specified key is mapped, or {@code null} + * if this map contains no mapping for the key. + * @throws NullPointerException if the specified key is null */ + // Never returns a Prime nor a Tombstone. 
+ public TypeK getk( TypeK key ) { + return (TypeK)getk_impl(this,_kvs,key); + } + + private static final Object getk_impl( final NonBlockingHashMap topmap, final Object[] kvs, final Object key ) { + final int fullhash= hash (key); // throws NullPointerException if key is null + final int len = len (kvs); // Count of key/value pairs, reads kvs.length + final CHM chm = chm (kvs); // The CHM, for a volatile read below; reads slot 0 of kvs + final int[] hashes=hashes(kvs); // The memoized hashes; reads slot 1 of kvs + + int idx = fullhash & (len-1); // First key hash + + // Main spin/reprobe loop, looking for a Key hit + int reprobe_cnt=0; + while( true ) { + // Probe table. + final Object K = key(kvs,idx); // Get key before volatile read, could be null + if( K == null ) return null; // A clear miss + + // We need a volatile-read here to preserve happens-before semantics on + // newly inserted Keys. If the Key body was written just before inserting + // into the table a Key-compare here might read the uninitialized Key body. + // Annoyingly this means we have to volatile-read before EACH key compare. + // . + // We also need a volatile-read between reading a newly inserted Value + // and returning the Value (so the user might end up reading the stale + // Value contents). Same problem as with keys - and the one volatile + // read covers both. + final Object[] newkvs = chm._newkvs; // VOLATILE READ before key compare + + // Key-compare + if( keyeq(K,key,hashes,idx,fullhash) ) + return K; // Return existing Key! + + // get and put must have the same key lookup logic! But only 'put' + // needs to force a table-resize for a too-long key-reprobe sequence. + // Check for too-many-reprobes on get - and flip to the new table. + if( ++reprobe_cnt >= reprobe_limit(len) || // too many probes + K == TOMBSTONE ) { // found a TOMBSTONE key, means no more keys in this table + return newkvs == null ? 
null : getk_impl(topmap,topmap.help_copy(newkvs),key); // Retry in the new table + } + + idx = (idx+1)&(len-1); // Reprobe by 1! (could now prefetch) + } + } + + static volatile int DUMMY_VOLATILE; + /** + * Put, Remove, PutIfAbsent, etc. Return the old value. If the returned value is equal to expVal (or expVal is + * {@link #NO_MATCH_OLD}) then the put can be assumed to work (although might have been immediately overwritten). + * Only the path through copy_slot passes in an expected value of null, and putIfMatch only returns a null if passed + * in an expected null. + * + * @param topmap the map to act on + * @param kvs the KV table snapshot we act on + * @param key not null (will result in {@link NullPointerException}) + * @param putval the new value to use. Not null. {@link #TOMBSTONE} will result in deleting the entry. + * @param expVal expected old value. Can be null. {@link #NO_MATCH_OLD} for an unconditional put/remove. + * {@link #TOMBSTONE} if we expect old entry to not exist(null/{@link #TOMBSTONE} value). + * {@link #MATCH_ANY} will ignore the current value, but only if an entry exists. A null expVal is used + * internally to perform a strict insert-if-never-been-seen-before operation. + * @return {@link #TOMBSTONE} if key does not exist or match has failed. null if expVal is + * null AND old value was null. Otherwise the old entry value (not null). 
+ */ + private static final Object putIfMatch0( + final NonBlockingHashMap topmap, + final Object[] kvs, + final Object key, + final Object putval, + final Object expVal) + { + assert putval != null; + assert !(putval instanceof Prime); + assert !(expVal instanceof Prime); + final int fullhash = hash (key); // throws NullPointerException if key null + final int len = len (kvs); // Count of key/value pairs, reads kvs.length + final CHM chm = chm (kvs); // Reads kvs[0] + final int[] hashes = hashes(kvs); // Reads kvs[1], read before kvs[0] + int idx = fullhash & (len-1); + + // --- + // Key-Claim stanza: spin till we can claim a Key (or force a resizing). + int reprobe_cnt=0; + Object K=null, V=null; + Object[] newkvs=null; + while( true ) { // Spin till we get a Key slot + V = val(kvs,idx); // Get old value (before volatile read below!) + K = key(kvs,idx); // Get current key + if( K == null ) { // Slot is free? + // Found an empty Key slot - which means this Key has never been in + // this table. No need to put a Tombstone - the Key is not here! + if( putval == TOMBSTONE ) return TOMBSTONE; // Not-now & never-been in this table + if( expVal == MATCH_ANY ) return TOMBSTONE; // Will not match, even after K inserts + // Claim the null key-slot + if( CAS_key(kvs,idx, null, key ) ) { // Claim slot for Key + chm._slots.add(1); // Raise key-slots-used count + hashes[idx] = fullhash; // Memoize fullhash + break; // Got it! + } + // CAS to claim the key-slot failed. + // + // This re-read of the Key points out an annoying short-coming of Java + // CAS. Most hardware CAS's report back the existing value - so that + // if you fail you have a *witness* - the value which caused the CAS to + // fail. The Java API turns this into a boolean destroying the + // witness. Re-reading does not recover the witness because another + // thread can write over the memory after the CAS. 
Hence we can be in + // the unfortunate situation of having a CAS fail *for cause* but + // having that cause removed by a later store. This turns a + // non-spurious-failure CAS (such as Azul has) into one that can + // apparently spuriously fail - and we avoid apparent spurious failure + // by not allowing Keys to ever change. + + // Volatile read, to force loads of K to retry despite JIT, otherwise + // it is legal to e.g. haul the load of "K = key(kvs,idx);" outside of + // this loop (since failed CAS ops have no memory ordering semantics). + int dummy = DUMMY_VOLATILE; + continue; + } + // Key slot was not null, there exists a Key here + + // We need a volatile-read here to preserve happens-before semantics on + // newly inserted Keys. If the Key body was written just before inserting + // into the table a Key-compare here might read the uninitialized Key body. + // Annoyingly this means we have to volatile-read before EACH key compare. + newkvs = chm._newkvs; // VOLATILE READ before key compare + + if( keyeq(K,key,hashes,idx,fullhash) ) + break; // Got it! + + // get and put must have the same key lookup logic! Lest 'get' give + // up looking too soon. + //topmap._reprobes.add(1); + if( ++reprobe_cnt >= reprobe_limit(len) || // too many probes or + K == TOMBSTONE ) { // found a TOMBSTONE key, means no more keys + // We simply must have a new table to do a 'put'. At this point a + // 'get' will also go to the new table (if any). We do not need + // to claim a key slot (indeed, we cannot find a free one to claim!). + newkvs = chm.resize(topmap,kvs); + if( expVal != null ) topmap.help_copy(newkvs); // help along an existing copy + return putIfMatch0(topmap, newkvs, key, putval, expVal); + } + + idx = (idx+1)&(len-1); // Reprobe! + } // End of spinning till we get a Key slot + + while ( true ) { // Spin till we insert a value + // --- + // Found the proper Key slot, now update the matching Value slot. 
We + // never put a null, so Value slots monotonically move from null to + // not-null (deleted Values use Tombstone). Thus if 'V' is null we + // fail this fast cutout and fall into the check for table-full. + if( putval == V ) return V; // Fast cutout for no-change + + // See if we want to move to a new table (to avoid high average re-probe + // counts). We only check on the initial set of a Value from null to + // not-null (i.e., once per key-insert). Of course we got a 'free' check + // of newkvs once per key-compare (not really free, but paid-for by the + // time we get here). + if( newkvs == null && // New table-copy already spotted? + // Once per fresh key-insert check the hard way + ((V == null && chm.tableFull(reprobe_cnt,len)) || + // Or we found a Prime, but the JMM allowed reordering such that we + // did not spot the new table (very rare race here: the writing + // thread did a CAS of _newkvs then a store of a Prime. This thread + // reads the Prime, then reads _newkvs - but the read of Prime was so + // delayed (or the read of _newkvs was so accelerated) that they + // swapped and we still read a null _newkvs. The resize call below + // will do a CAS on _newkvs forcing the read. + V instanceof Prime) ) { + newkvs = chm.resize(topmap, kvs); // Force the new table copy to start + } + // See if we are moving to a new table. + // If so, copy our slot and retry in the new table. + if( newkvs != null ) { + return putIfMatch0(topmap, chm.copy_slot_and_check(topmap, kvs, idx, expVal), key, putval, expVal); + } + // --- + // We are finally prepared to update the existing table + assert !(V instanceof Prime); + + // Must match old, and we do not? Then bail out now. Note that either V + // or expVal might be TOMBSTONE. Also V can be null, if we've never + // inserted a value before. expVal can be null if we are called from + // copy_slot. + if( expVal != NO_MATCH_OLD && // Do we care about expected-Value at all? + V != expVal && // No instant match already? 
+ (expVal != MATCH_ANY || V == TOMBSTONE || V == null) && + !(V==null && expVal == TOMBSTONE) && // Match on null/TOMBSTONE combo + (expVal == null || !expVal.equals(V)) ) { // Expensive equals check at the last + return (V == null) ? TOMBSTONE : V; // Do not update! + } + + // Actually change the Value in the Key,Value pair + if( CAS_val(kvs, idx, V, putval ) ) break; + + // CAS failed + // Because we have no witness, we do not know why it failed. + // Indeed, by the time we look again the value under test might have flipped + // a thousand times and now be the expected value (despite the CAS failing). + // Check for the never-succeed condition of a Prime value and jump to any + // nested table, or else just re-run. + + // We would not need this load at all if CAS returned the value on which + // the CAS failed (AKA witness). The new CAS semantics are supported via + // VarHandle in JDK9. + V = val(kvs,idx); // Get new value + + // If a Prime'd value got installed, we need to re-run the put on the + // new table. Otherwise we lost the CAS to another racing put. + if( V instanceof Prime ) + return putIfMatch0(topmap, chm.copy_slot_and_check(topmap, kvs, idx, expVal), key, putval, expVal); + + // Simply retry from the start. + // NOTE: need the fence, since otherwise 'val(kvs,idx)' load could be hoisted + // out of loop. + int dummy = DUMMY_VOLATILE; + } + + // CAS succeeded - we did the update! + // Both normal put's and table-copy calls putIfMatch, but table-copy + // does not (effectively) increase the number of live k/v pairs. + if( expVal != null ) { + // Adjust sizes - a striped counter + if( (V == null || V == TOMBSTONE) && putval != TOMBSTONE ) chm._size.add( 1); + if( !(V == null || V == TOMBSTONE) && putval == TOMBSTONE ) chm._size.add(-1); + } + + // We won; we know the update happened as expected. + return (V==null && expVal!=null) ? 
TOMBSTONE : V; + } + + // --- help_copy --------------------------------------------------------- + // Help along an existing resize operation. This is just a fast cut-out + // wrapper, to encourage inlining for the fast no-copy-in-progress case. We + // always help the top-most table copy, even if there are nested table + // copies in progress. + private final Object[] help_copy( Object[] helper ) { + // Read the top-level KVS only once. We'll try to help this copy along, + // even if it gets promoted out from under us (i.e., the copy completes + // and another KVS becomes the top-level copy). + Object[] topkvs = _kvs; + CHM topchm = chm(topkvs); + if( topchm._newkvs == null ) return helper; // No copy in-progress + topchm.help_copy_impl(this,topkvs,false); + return helper; + } + + + // --- CHM ----------------------------------------------------------------- + // The control structure for the NonBlockingHashMap + private static final class CHM { + // Size in active K,V pairs + private final ConcurrentAutoTable _size; + public int size () { return (int)_size.get(); } + + // --- + // These next 2 fields are used in the resizing heuristics, to judge when + // it is time to resize or copy the table. Slots is a count of used-up + // key slots, and when it nears a large fraction of the table we probably + // end up reprobing too much. Last-resize-milli is the time since the + // last resize; if we are running back-to-back resizes without growing + // (because there are only a few live keys but many slots full of dead + // keys) then we need a larger table to cut down on the churn. + + // Count of used slots, to tell when table is full of dead unusable slots + private final ConcurrentAutoTable _slots; + public int slots() { return (int)_slots.get(); } + + // --- + // New mappings, used during resizing. + // The 'new KVs' array - created during a resize operation. This + // represents the new table being copied from the old one. 
It's the + // volatile variable that is read as we cross from one table to the next, + // to get the required memory orderings. It monotonically transits from + // null to set (once). + volatile Object[] _newkvs; + private static final AtomicReferenceFieldUpdater _newkvsUpdater = + AtomicReferenceFieldUpdater.newUpdater(CHM.class,Object[].class, "_newkvs"); + // Set the _next field if we can. + boolean CAS_newkvs( Object[] newkvs ) { + while( _newkvs == null ) + if( _newkvsUpdater.compareAndSet(this,null,newkvs) ) + return true; + return false; + } + + // Sometimes many threads race to create a new very large table. Only 1 + // wins the race, but the losers all allocate a junk large table with + // hefty allocation costs. Attempt to control the overkill here by + // throttling attempts to create a new table. I cannot really block here + // (lest I lose the non-blocking property) but late-arriving threads can + // give the initial resizing thread a little time to allocate the initial + // new table. The Right Long Term Fix here is to use array-lets and + // incrementally create the new very large array. In C I'd make the array + // with malloc (which would mmap under the hood) which would only eat + // virtual-address and not real memory - and after Somebody wins then we + // could in parallel initialize the array. Java does not allow + // un-initialized array creation (especially of ref arrays!). + volatile long _resizers; // count of threads attempting an initial resize + private static final AtomicLongFieldUpdater _resizerUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_resizers"); + + // --- + // Simple constructor + CHM( ConcurrentAutoTable size ) { + _size = size; + _slots= new ConcurrentAutoTable(); + } + + // --- tableFull --------------------------------------------------------- + // Heuristic to decide if this table is too full, and we should start a + // new table. 
Note that if a 'get' call has reprobed too many times and + // decided the table must be full, then always the estimate_sum must be + // high and we must report the table is full. If we do not, then we might + // end up deciding that the table is not full and inserting into the + // current table, while a 'get' has decided the same key cannot be in this + // table because of too many reprobes. The invariant is: + // slots.estimate_sum >= max_reprobe_cnt >= reprobe_limit(len) + private final boolean tableFull( int reprobe_cnt, int len ) { + return + // Do the cheap check first: we allow some number of reprobes always + reprobe_cnt >= REPROBE_LIMIT && + (reprobe_cnt >= reprobe_limit(len) || + // More expensive check: see if the table is > 1/2 full. + _slots.estimate_get() >= (len>>1)); + } + + // --- resize ------------------------------------------------------------ + // Resizing after too many probes. "How Big???" heuristics are here. + // Callers will (not this routine) will 'help_copy' any in-progress copy. + // Since this routine has a fast cutout for copy-already-started, callers + // MUST 'help_copy' lest we have a path which forever runs through + // 'resize' only to discover a copy-in-progress which never progresses. + private final Object[] resize( NonBlockingHashMap topmap, Object[] kvs) { + assert chm(kvs) == this; + + // Check for resize already in progress, probably triggered by another thread + Object[] newkvs = _newkvs; // VOLATILE READ + if( newkvs != null ) // See if resize is already in progress + return newkvs; // Use the new table already + + // No copy in-progress, so start one. First up: compute new table size. + int oldlen = len(kvs); // Old count of K,V pairs allowed + int sz = size(); // Get current table count of active K,V pairs + int newsz = sz; // First size estimate + + // Heuristic to determine new size. We expect plenty of dead-slots-with-keys + // and we need some decent padding to avoid endless reprobing. 
+ if( sz >= (oldlen>>2) ) { // If we are >25% full of keys then... + newsz = oldlen<<1; // Double size, so new table will be between 12.5% and 25% full + // For tables less than 1M entries, if >50% full of keys then... + // For tables more than 1M entries, if >75% full of keys then... + if( 4L*sz >= ((oldlen>>20)!=0?3L:2L)*oldlen ) + newsz = oldlen<<2; // Double double size, so new table will be between %12.5 (18.75%) and 25% (25%) + } + // This heuristic in the next 2 lines leads to a much denser table + // with a higher reprobe rate + //if( sz >= (oldlen>>1) ) // If we are >50% full of keys then... + // newsz = oldlen<<1; // Double size + + // Last (re)size operation was very recent? Then double again + // despite having few live keys; slows down resize operations + // for tables subject to a high key churn rate - but do not + // forever grow the table. If there is a high key churn rate + // the table needs a steady state of rare same-size resize + // operations to clean out the dead keys. + long tm = System.currentTimeMillis(); + if( newsz <= oldlen && // New table would shrink or hold steady? + tm <= topmap._last_resize_milli+10000) // Recent resize (less than 10 sec ago) + newsz = oldlen<<1; // Double the existing size + + // Do not shrink, ever. If we hit this size once, assume we + // will again. + if( newsz < oldlen ) newsz = oldlen; + + // Convert to power-of-2 + int log2; + for( log2=MIN_SIZE_LOG; (1< ((len >> 2) + (len >> 1))) throw new RuntimeException("Table is full."); + } + + // Now limit the number of threads actually allocating memory to a + // handful - lest we have 750 threads all trying to allocate a giant + // resized array. + long r = _resizers; + while( !_resizerUpdater.compareAndSet(this,r,r+1) ) + r = _resizers; + // Size calculation: 2 words (K+V) per table entry, plus a handful. We + // guess at 64-bit pointers; 32-bit pointers screws up the size calc by + // 2x but does not screw up the heuristic very much. 
+ long megs = ((((1L<>20/*megs*/; + if( r >= 2 && megs > 0 ) { // Already 2 guys trying; wait and see + newkvs = _newkvs; // Between dorking around, another thread did it + if( newkvs != null ) // See if resize is already in progress + return newkvs; // Use the new table already + // TODO - use a wait with timeout, so we'll wakeup as soon as the new table + // is ready, or after the timeout in any case. + //synchronized( this ) { wait(8*megs); } // Timeout - we always wakeup + // For now, sleep a tad and see if the 2 guys already trying to make + // the table actually get around to making it happen. + try { Thread.sleep(megs); } catch( Exception e ) { } + } + // Last check, since the 'new' below is expensive and there is a chance + // that another thread slipped in a new thread while we ran the heuristic. + newkvs = _newkvs; + if( newkvs != null ) // See if resize is already in progress + return newkvs; // Use the new table already + + // Double size for K,V pairs, add 1 for CHM + newkvs = new Object[(int)len]; // This can get expensive for big arrays + newkvs[0] = new CHM(_size); // CHM in slot 0 + newkvs[1] = new int[1< _copyIdxUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_copyIdx"); + + // Work-done reporting. Used to efficiently signal when we can move to + // the new table. From 0 to len(oldkvs) refers to copying from the old + // table to the new. + volatile long _copyDone= 0; + static private final AtomicLongFieldUpdater _copyDoneUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_copyDone"); + + // --- help_copy_impl ---------------------------------------------------- + // Help along an existing resize operation. We hope its the top-level + // copy (it was when we started) but this CHM might have been promoted out + // of the top position. 
+ private final void help_copy_impl( NonBlockingHashMap topmap, Object[] oldkvs, boolean copy_all ) { + assert chm(oldkvs) == this; + Object[] newkvs = _newkvs; + assert newkvs != null; // Already checked by caller + int oldlen = len(oldkvs); // Total amount to copy + final int MIN_COPY_WORK = Math.min(oldlen,1024); // Limit per-thread work + + // --- + int panic_start = -1; + int copyidx=-9999; // Fool javac to think it's initialized + while( _copyDone < oldlen ) { // Still needing to copy? + // Carve out a chunk of work. The counter wraps around so every + // thread eventually tries to copy every slot repeatedly. + + // We "panic" if we have tried TWICE to copy every slot - and it still + // has not happened. i.e., twice some thread somewhere claimed they + // would copy 'slot X' (by bumping _copyIdx) but they never claimed to + // have finished (by bumping _copyDone). Our choices become limited: + // we can wait for the work-claimers to finish (and become a blocking + // algorithm) or do the copy work ourselves. Tiny tables with huge + // thread counts trying to copy the table often 'panic'. + if( panic_start == -1 ) { // No panic? + copyidx = (int)_copyIdx; + while( !_copyIdxUpdater.compareAndSet(this,copyidx,copyidx+MIN_COPY_WORK) ) + copyidx = (int)_copyIdx; // Re-read + if( !(copyidx < (oldlen<<1)) ) // Panic! + panic_start = copyidx; // Record where we started to panic-copy + } + + // We now know what to copy. Try to copy. + int workdone = 0; + for( int i=0; i 0 ) // Report work-done occasionally + copy_check_and_promote( topmap, oldkvs, workdone );// See if we can promote + //for( int i=0; i 0 ) { + while( !_copyDoneUpdater.compareAndSet(this,copyDone,copyDone+workdone) ) { + copyDone = _copyDone; // Reload, retry + assert (copyDone+workdone) <= oldlen; + } + } + + // Check for copy being ALL done, and promote. Note that we might have + // nested in-progress copies and manage to finish a nested copy before + // finishing the top-level copy. 
We only promote top-level copies. + if( copyDone+workdone == oldlen && // Ready to promote this table? + topmap._kvs == oldkvs && // Looking at the top-level table? + // Attempt to promote + topmap.CAS_kvs(oldkvs,_newkvs) ) { + topmap._last_resize_milli = System.currentTimeMillis(); // Record resize time for next check + } + } + + // --- copy_slot --------------------------------------------------------- + // Copy one K/V pair from oldkvs[i] to newkvs. Returns true if we can + // confirm that we set an old-table slot to TOMBPRIME, and only returns after + // updating the new table. We need an accurate confirmed-copy count so + // that we know when we can promote (if we promote the new table too soon, + // other threads may 'miss' on values not-yet-copied from the old table). + // We don't allow any direct updates on the new table, unless they first + // happened to the old table - so that any transition in the new table from + // null to not-null must have been from a copy_slot (or other old-table + // overwrite) and not from a thread directly writing in the new table. + private boolean copy_slot( NonBlockingHashMap topmap, int idx, Object[] oldkvs, Object[] newkvs ) { + // Blindly set the key slot from null to TOMBSTONE, to eagerly stop + // fresh put's from inserting new values in the old table when the old + // table is mid-resize. We don't need to act on the results here, + // because our correctness stems from box'ing the Value field. Slamming + // the Key field is a minor speed optimization. + Object key; + while( (key=key(oldkvs,idx)) == null ) + CAS_key(oldkvs,idx, null, TOMBSTONE); + + // --- + // Prevent new values from appearing in the old table. + // Box what we see in the old table, to prevent further updates. + Object oldval = val(oldkvs,idx); // Read OLD table + while( !(oldval instanceof Prime) ) { + final Prime box = (oldval == null || oldval == TOMBSTONE) ? 
TOMBPRIME : new Prime(oldval); + if( CAS_val(oldkvs,idx,oldval,box) ) { // CAS down a box'd version of oldval + // If we made the Value slot hold a TOMBPRIME, then we both + // prevented further updates here but also the (absent) + // oldval is vacuously available in the new table. We + // return with true here: any thread looking for a value for + // this key can correctly go straight to the new table and + // skip looking in the old table. + if( box == TOMBPRIME ) + return true; + // Otherwise we boxed something, but it still needs to be + // copied into the new table. + oldval = box; // Record updated oldval + break; // Break loop; oldval is now boxed by us + } + oldval = val(oldkvs,idx); // Else try, try again + } + if( oldval == TOMBPRIME ) return false; // Copy already complete here! + + // --- + // Copy the value into the new table, but only if we overwrite a null. + // If another value is already in the new table, then somebody else + // wrote something there and that write is happens-after any value that + // appears in the old table. + Object old_unboxed = ((Prime)oldval)._V; + assert old_unboxed != TOMBSTONE; + putIfMatch0(topmap, newkvs, key, old_unboxed, null); + + // --- + // Finally, now that any old value is exposed in the new table, we can + // forever hide the old-table value by slapping a TOMBPRIME down. This + // will stop other threads from uselessly attempting to copy this slot + // (i.e., it's a speed optimization not a correctness issue). + while( oldval != TOMBPRIME && !CAS_val(oldkvs,idx,oldval,TOMBPRIME) ) + oldval = val(oldkvs,idx); + + return oldval != TOMBPRIME; // True if we slammed the TOMBPRIME down + } // end copy_slot + } // End of CHM + + + // --- Snapshot ------------------------------------------------------------ + // The main class for iterating over the NBHM. It "snapshots" a clean + // view of the K/V array. 
+ private class SnapshotV implements Iterator, Enumeration { + final Object[] _sskvs; + public SnapshotV() { + while( true ) { // Verify no table-copy-in-progress + Object[] topkvs = _kvs; + CHM topchm = chm(topkvs); + if( topchm._newkvs == null ) { // No table-copy-in-progress + // The "linearization point" for the iteration. Every key in this + // table will be visited, but keys added later might be skipped or + // even be added to a following table (also not iterated over). + _sskvs = topkvs; + break; + } + // Table copy in-progress - so we cannot get a clean iteration. We + // must help finish the table copy before we can start iterating. + topchm.help_copy_impl(NonBlockingHashMap.this,topkvs,true); + } + // Warm-up the iterator + next(); + } + int length() { return len(_sskvs); } + Object key(int idx) { return NonBlockingHashMap.key(_sskvs,idx); } + private int _idx; // Varies from 0-keys.length + private Object _nextK, _prevK; // Last 2 keys found + private TypeV _nextV, _prevV; // Last 2 values found + public boolean hasNext() { return _nextV != null; } + public TypeV next() { + // 'next' actually knows what the next value will be - it had to + // figure that out last go-around lest 'hasNext' report true and + // some other thread deleted the last value. Instead, 'next' + // spends all its effort finding the key that comes after the + // 'next' key. + if( _idx != 0 && _nextV == null ) throw new NoSuchElementException(); + _prevK = _nextK; // This will become the previous key + _prevV = _nextV; // This will become the previous value + _nextV = null; // We have no more next-key + // Attempt to set <_nextK,_nextV> to the next K,V pair. + // _nextV is the trigger: stop searching when it is != null + while( _idx, but the JDK always removes by key, even when the value has changed. 
+ removeKey(); + } + + public TypeV nextElement() { return next(); } + public boolean hasMoreElements() { return hasNext(); } + } + public Object[] raw_array() { return new SnapshotV()._sskvs; } + + /** Returns an enumeration of the values in this table. + * @return an enumeration of the values in this table + * @see #values() */ + public Enumeration elements() { return new SnapshotV(); } + + // --- values -------------------------------------------------------------- + /** Returns a {@link Collection} view of the values contained in this map. + * The collection is backed by the map, so changes to the map are reflected + * in the collection, and vice-versa. The collection supports element + * removal, which removes the corresponding mapping from this map, via the + * Iterator.remove, Collection.remove, + * removeAll, retainAll, and clear operations. + * It does not support the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator that + * will never throw {@link ConcurrentModificationException}, and guarantees + * to traverse elements as they existed upon construction of the iterator, + * and may (but is not guaranteed to) reflect any modifications subsequent + * to construction. */ + @Override + public Collection values() { + return new AbstractCollection() { + @Override public void clear ( ) { NonBlockingHashMap.this.clear ( ); } + @Override public int size ( ) { return NonBlockingHashMap.this.size ( ); } + @Override public boolean contains( Object v ) { return NonBlockingHashMap.this.containsValue(v); } + @Override public Iterator iterator() { return new SnapshotV(); } + }; + } + + // --- keySet -------------------------------------------------------------- + private class SnapshotK implements Iterator, Enumeration { + final SnapshotV _ss; + public SnapshotK() { _ss = new SnapshotV(); } + public void remove() { _ss.removeKey(); } + public TypeK next() { _ss.next(); return (TypeK)_ss._prevK; } + public boolean hasNext() { return _ss.hasNext(); } + public TypeK nextElement() { return next(); } + public boolean hasMoreElements() { return hasNext(); } + } + + /** Returns an enumeration of the keys in this table. + * @return an enumeration of the keys in this table + * @see #keySet() */ + public Enumeration keys() { return new SnapshotK(); } + + /** Returns a {@link Set} view of the keys contained in this map. The set + * is backed by the map, so changes to the map are reflected in the set, + * and vice-versa. The set supports element removal, which removes the + * corresponding mapping from this map, via the Iterator.remove, + * Set.remove, removeAll, retainAll, and + * clear operations. It does not support the add or + * addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator that + * will never throw {@link ConcurrentModificationException}, and guarantees + * to traverse elements as they existed upon construction of the iterator, + * and may (but is not guaranteed to) reflect any modifications subsequent + * to construction. */ + @Override + public Set keySet() { + return new AbstractSet () { + @Override public void clear ( ) { NonBlockingHashMap.this.clear ( ); } + @Override public int size ( ) { return NonBlockingHashMap.this.size ( ); } + @Override public boolean contains( Object k ) { return NonBlockingHashMap.this.containsKey(k); } + @Override public boolean remove ( Object k ) { return NonBlockingHashMap.this.remove (k) != null; } + @Override public Iterator iterator() { return new SnapshotK(); } + // This is an efficient implementation of toArray instead of the standard + // one. In particular it uses a smart iteration over the NBHM. + @Override public T[] toArray(T[] a) { + Object[] kvs = raw_array(); + // Estimate size of array; be prepared to see more or fewer elements + int sz = size(); + T[] r = a.length >= sz ? a : + (T[])java.lang.reflect.Array.newInstance(a.getClass().getComponentType(), sz); + // Fast efficient element walk. + int j=0; + for( int i=0; i= r.length ) { + int sz2 = (int)Math.min(Integer.MAX_VALUE-8,((long)j)<<1); + if( sz2<=r.length ) throw new OutOfMemoryError("Required array size too large"); + r = Arrays.copyOf(r,sz2); + } + r[j++] = (T)K; + } + } + if( j <= a.length ) { // Fit in the original array? 
+ if( a!=r ) System.arraycopy(r,0,a,0,j); + if( j { + NBHMEntry( final TypeK k, final TypeV v ) { super(k,v); } + public TypeV setValue(final TypeV val) { + if( val == null ) throw new NullPointerException(); + _val = val; + return put(_key, val); + } + } + + private class SnapshotE implements Iterator> { + final SnapshotV _ss; + public SnapshotE() { _ss = new SnapshotV(); } + public void remove() { + // NOTE: it would seem logical that entry removal will semantically mean removing the matching pair , but + // the JDK always removes by key, even when the value has changed. + _ss.removeKey(); + } + public Map.Entry next() { _ss.next(); return new NBHMEntry((TypeK)_ss._prevK,_ss._prevV); } + public boolean hasNext() { return _ss.hasNext(); } + } + + /** Returns a {@link Set} view of the mappings contained in this map. The + * set is backed by the map, so changes to the map are reflected in the + * set, and vice-versa. The set supports element removal, which removes + * the corresponding mapping from the map, via the + * Iterator.remove, Set.remove, removeAll, + * retainAll, and clear operations. It does not support + * the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator + * that will never throw {@link ConcurrentModificationException}, + * and guarantees to traverse elements as they existed upon + * construction of the iterator, and may (but is not guaranteed to) + * reflect any modifications subsequent to construction. + * + *

Warning: the iterator associated with this Set + * requires the creation of {@link java.util.Map.Entry} objects with each + * iteration. The {@link NonBlockingHashMap} does not normally create or + * using {@link java.util.Map.Entry} objects so they will be created soley + * to support this iteration. Iterating using {@link Map#keySet} or {@link + * Map#values} will be more efficient. + */ + @Override + public Set> entrySet() { + return new AbstractSet>() { + @Override public void clear ( ) { NonBlockingHashMap.this.clear( ); } + @Override public int size ( ) { return NonBlockingHashMap.this.size ( ); } + @Override public boolean remove( final Object o ) { + if( !(o instanceof Map.Entry)) return false; + final Map.Entry e = (Map.Entry)o; + return NonBlockingHashMap.this.remove(e.getKey(), e.getValue()); + } + @Override public boolean contains(final Object o) { + if( !(o instanceof Map.Entry)) return false; + final Map.Entry e = (Map.Entry)o; + TypeV v = get(e.getKey()); + return v != null && v.equals(e.getValue()); + } + @Override public Iterator> iterator() { return new SnapshotE(); } + }; + } + + // --- writeObject ------------------------------------------------------- + // Write a NBHM to a stream + private void writeObject(java.io.ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); // Nothing to write + for( Object K : keySet() ) { + final Object V = get(K); // Do an official 'get' + s.writeObject(K); // Write the pair + s.writeObject(V); + } + s.writeObject(null); // Sentinel to indicate end-of-data + s.writeObject(null); + } + + // --- readObject -------------------------------------------------------- + // Read a NBHM from a stream + private void readObject(java.io.ObjectInputStream s) throws IOException, ClassNotFoundException { + s.defaultReadObject(); // Read nothing + initialize(MIN_SIZE); + for(;;) { + final TypeK K = (TypeK) s.readObject(); + final TypeV V = (TypeV) s.readObject(); + if( K == null ) break; + put(K,V); // Insert with 
+ an official put + } + } + +} // End NonBlockingHashMap class diff --git a/netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashMapLong.java b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashMapLong.java new file mode 100644 index 0000000..db9ade4 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashMapLong.java @@ -0,0 +1,1313 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.maps; + +import org.jctools.util.RangeUtil; + +import java.io.IOException; +import java.io.Serializable; +import java.util.*; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static org.jctools.maps.NonBlockingHashMap.DUMMY_VOLATILE; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; + + +/** + * A lock-free alternate implementation of {@link java.util.concurrent.ConcurrentHashMap} + * with primitive long keys, better scaling properties and + * generally lower costs. The use of {@code long} keys allows for faster + * compares and lower memory costs. The Map provides identical correctness + * properties as ConcurrentHashMap. All operations are non-blocking and + * multi-thread safe, including all update operations.
{@link + * NonBlockingHashMapLong} scales substantially better than {@link + * java.util.concurrent.ConcurrentHashMap} for high update rates, even with a large + * concurrency factor. Scaling is linear up to 768 CPUs on a 768-CPU Azul + * box, even with 100% updates or 100% reads or any fraction in-between. + * Linear scaling up to all cpus has been observed on a 32-way Sun US2 box, + * 32-way Sun Niagara box, 8-way Intel box and a 4-way Power box. + * + *

The main benefit of this class over using plain {@link + * org.jctools.maps.NonBlockingHashMap} with {@link Long} keys is + * that it avoids the auto-boxing and unboxing costs. Since auto-boxing is + * automatic, it is easy to accidentally cause auto-boxing and negate + * the space and speed benefits. + * + *

This class obeys the same functional specification as {@link + * java.util.Hashtable}, and includes versions of methods corresponding to + * each method of Hashtable. However, even though all operations are + * thread-safe, operations do not entail locking and there is + * not any support for locking the entire table in a way that + * prevents all access. This class is fully interoperable with + * Hashtable in programs that rely on its thread safety but not on + * its synchronization details. + * + *

Operations (including put) generally do not block, so may + * overlap with other update operations (including other puts and + * removes). Retrievals reflect the results of the most recently + * completed update operations holding upon their onset. For + * aggregate operations such as putAll, concurrent retrievals may + * reflect insertion or removal of only some entries. Similarly, Iterators + * and Enumerations return elements reflecting the state of the hash table at + * some point at or since the creation of the iterator/enumeration. They do + * not throw {@link ConcurrentModificationException}. However, + * iterators are designed to be used by only one thread at a time. + * + *

Very full tables, or tables with high re-probe rates may trigger an + * internal resize operation to move into a larger table. Resizing is not + * terribly expensive, but it is not free either; during resize operations + * table throughput may drop somewhat. All threads that visit the table + * during a resize will 'help' the resizing but will still be allowed to + * complete their operation before the resize is finished (i.e., a simple + * 'get' operation on a million-entry table undergoing resizing will not need + * to block until the entire million entries are copied). + * + *

This class and its views and iterators implement all of the + * optional methods of the {@link Map} and {@link Iterator} + * interfaces. + * + *

Like {@link Hashtable} but unlike {@link HashMap}, this class + * does not allow null to be used as a value. + * + * + * @since 1.5 + * @author Cliff Click + * @param the type of mapped values + */ + +public class NonBlockingHashMapLong + extends AbstractMap + implements ConcurrentMap, Cloneable, Serializable { + + private static final long serialVersionUID = 1234123412341234124L; + + private static final int REPROBE_LIMIT=10; // Too many reprobes then force a table-resize + + // --- Bits to allow Unsafe access to arrays + private static final int _Obase = UNSAFE.arrayBaseOffset(Object[].class); + private static final int _Oscale = UNSAFE.arrayIndexScale(Object[].class); + private static long rawIndex(final Object[] ary, final int idx) { + assert idx >= 0 && idx < ary.length; + // Note the long-math requirement, to handle arrays of more than 2^31 bytes + // - or 2^28 - or about 268M - 8-byte pointer elements. + return _Obase + ((long)idx * _Oscale); + } + private static final int _Lbase = UNSAFE.arrayBaseOffset(long[].class); + private static final int _Lscale = UNSAFE.arrayIndexScale(long[].class); + private static long rawIndex(final long[] ary, final int idx) { + assert idx >= 0 && idx < ary.length; + // Note the long-math requirement, to handle arrays of more than 2^31 bytes + // - or 2^28 - or about 268M - 8-byte pointer elements. 
+ return _Lbase + ((long)idx * _Lscale); + } + + // --- Bits to allow Unsafe CAS'ing of the CHM field + private static final long _chm_offset = fieldOffset(NonBlockingHashMapLong.class, "_chm"); + private static final long _val_1_offset = fieldOffset(NonBlockingHashMapLong.class, "_val_1"); + + private final boolean CAS( final long offset, final Object old, final Object nnn ) { + return UNSAFE.compareAndSwapObject(this, offset, old, nnn ); + } + + // --- Adding a 'prime' bit onto Values via wrapping with a junk wrapper class + private static final class Prime { + final Object _V; + Prime( Object V ) { _V = V; } + static Object unbox( Object V ) { return V instanceof Prime ? ((Prime)V)._V : V; } + } + + // --- The Hash Table -------------------- + private transient CHM _chm; + // This next field holds the value for Key 0 - the special key value which + // is the initial array value, and also means: no-key-inserted-yet. + private transient Object _val_1; // Value for Key: NO_KEY + + // Time since last resize + private transient long _last_resize_milli; + + // Optimize for space: use a 1/2-sized table and allow more re-probes + private final boolean _opt_for_space; + + // --- Minimum table size ---------------- + // Pick size 16 K/V pairs, which turns into (16*2)*4+12 = 140 bytes on a + // standard 32-bit HotSpot, and (16*2)*8+12 = 268 bytes on 64-bit Azul. + private static final int MIN_SIZE_LOG=4; // + private static final int MIN_SIZE=(1<>4); + } + + // --- NonBlockingHashMapLong ---------------------------------------------- + // Constructors + + /** Create a new NonBlockingHashMapLong with default minimum size (currently set + * to 8 K/V pairs or roughly 84 bytes on a standard 32-bit JVM). */ + public NonBlockingHashMapLong( ) { this(MIN_SIZE,true); } + + /** Create a new NonBlockingHashMapLong with initial room for the given + * number of elements, thus avoiding internal resizing operations to reach + * an appropriate size. 
Large numbers here when used with a small count of + * elements will sacrifice space for a small amount of time gained. The + * initial size will be rounded up internally to the next larger power of 2. */ + public NonBlockingHashMapLong( final int initial_sz ) { this(initial_sz,true); } + + /** Create a new NonBlockingHashMapLong, setting the space-for-speed + * tradeoff. {@code true} optimizes for space and is the default. {@code + * false} optimizes for speed and doubles space costs for roughly a 10% + * speed improvement. */ + public NonBlockingHashMapLong( final boolean opt_for_space ) { this(1,opt_for_space); } + + /** Create a new NonBlockingHashMapLong, setting both the initial size and + * the space-for-speed tradeoff. {@code true} optimizes for space and is + * the default. {@code false} optimizes for speed and doubles space costs + * for roughly a 10% speed improvement. */ + public NonBlockingHashMapLong( final int initial_sz, final boolean opt_for_space ) { + _opt_for_space = opt_for_space; + initialize(initial_sz); + } + private void initialize( final int initial_sz ) { + RangeUtil.checkPositiveOrZero(initial_sz, "initial_sz"); + int i; // Convert to next largest power-of-2 + for( i=MIN_SIZE_LOG; (1<true if the key is in the table */ + public boolean containsKey( long key ) { return get(key) != null; } + + /** Legacy method testing if some key maps into the specified value in this + * table. This method is identical in functionality to {@link + * #containsValue}, and exists solely to ensure full compatibility with + * class {@link java.util.Hashtable}, which supported this method prior to + * introduction of the Java Collections framework. + * @param val a value to search for + * @return true if this map maps one or more keys to the specified value + * @throws NullPointerException if the specified value is null */ + public boolean contains ( Object val ) { return containsValue(val); } + + /** Maps the specified key to the specified value in the table. 
The value + * cannot be null.

The value can be retrieved by calling {@link #get} + * with a key that is equal to the original key. + * @param key key with which the specified value is to be associated + * @param val value to be associated with the specified key + * @return the previous value associated with key, or + * null if there was no mapping for key + * @throws NullPointerException if the specified value is null */ + public TypeV put ( long key, TypeV val ) { return putIfMatch( key, val,NO_MATCH_OLD);} + + /** Atomically, do a {@link #put} if-and-only-if the key is not mapped. + * Useful to ensure that only a single mapping for the key exists, even if + * many threads are trying to create the mapping in parallel. + * @return the previous value associated with the specified key, + * or null if there was no mapping for the key + * @throws NullPointerException if the specified is value is null */ + public TypeV putIfAbsent( long key, TypeV val ) { return putIfMatch( key, val,TOMBSTONE );} + + /** Removes the key (and its corresponding value) from this map. + * This method does nothing if the key is not in the map. + * @return the previous value associated with key, or + * null if there was no mapping for key*/ + public TypeV remove ( long key ) { return putIfMatch( key,TOMBSTONE,NO_MATCH_OLD);} + + /** Atomically do a {@link #remove(long)} if-and-only-if the key is mapped + * to a value which is equals to the given value. + * @throws NullPointerException if the specified value is null */ + public boolean remove ( long key,Object val ) { return putIfMatch( key,TOMBSTONE,val ) == val ;} + + /** Atomically do a put(key,val) if-and-only-if the key is + * mapped to some value already. + * @throws NullPointerException if the specified value is null */ + public TypeV replace ( long key, TypeV val ) { return putIfMatch( key, val,MATCH_ANY );} + + /** Atomically do a put(key,newValue) if-and-only-if the key is + * mapped a value which is equals to oldValue. 
+ * @throws NullPointerException if the specified value is null */ + public boolean replace ( long key, TypeV oldValue, TypeV newValue ) { + return putIfMatch( key, newValue, oldValue ) == oldValue; + } + + @SuppressWarnings("unchecked") + private TypeV putIfMatch( long key, Object newVal, Object oldVal ) { + if (oldVal == null || newVal == null) throw new NullPointerException(); + if( key == NO_KEY ) { + Object curVal = _val_1; + if( oldVal == NO_MATCH_OLD || // Do we care about expected-Value at all? + curVal == oldVal || // No instant match already? + (oldVal == MATCH_ANY && curVal != TOMBSTONE) || + oldVal.equals(curVal) ) { // Expensive equals check + if( !CAS(_val_1_offset,curVal,newVal) ) // One shot CAS update attempt + curVal = _val_1; // Failed; get failing witness + } + return curVal == TOMBSTONE ? null : (TypeV)curVal; // Return the last value present + } + final Object res = _chm.putIfMatch( key, newVal, oldVal ); + assert !(res instanceof Prime); + assert res != null; + return res == TOMBSTONE ? null : (TypeV)res; + } + + /** Removes all of the mappings from this map. */ + public void clear() { // Smack a new empty table down + CHM newchm = new CHM(this,new ConcurrentAutoTable(),MIN_SIZE_LOG); + while( !CAS(_chm_offset,_chm,newchm) ) { /*Spin until the clear works*/} + CAS(_val_1_offset,_val_1,TOMBSTONE); + } + // Non-atomic clear, preserving existing large arrays + public void clear(boolean large) { // Smack a new empty table down + _chm.clear(); + CAS(_val_1_offset,_val_1,TOMBSTONE); + } + + /** Returns true if this Map maps one or more keys to the specified + * value. Note: This method requires a full internal traversal of the + * hash table and is much slower than {@link #containsKey}. 
+ * @param val value whose presence in this map is to be tested + * @return true if this Map maps one or more keys to the specified value + * @throws NullPointerException if the specified value is null */ + public boolean containsValue( Object val ) { + if( val == null ) return false; + if( val == _val_1 ) return true; // Key 0 + for( TypeV V : values() ) + if( V == val || V.equals(val) ) + return true; + return false; + } + + // --- get ----------------------------------------------------------------- + /** Returns the value to which the specified key is mapped, or {@code null} + * if this map contains no mapping for the key. + *

More formally, if this map contains a mapping from a key {@code k} to + * a value {@code v} such that {@code key==k}, then this method + * returns {@code v}; otherwise it returns {@code null}. (There can be at + * most one such mapping.) + * @throws NullPointerException if the specified key is null */ + // Never returns a Prime nor a Tombstone. + @SuppressWarnings("unchecked") + public final TypeV get( long key ) { + if( key == NO_KEY ) { + final Object V = _val_1; + return V == TOMBSTONE ? null : (TypeV)V; + } + final Object V = _chm.get_impl(key); + assert !(V instanceof Prime); // Never return a Prime + assert V != TOMBSTONE; + return (TypeV)V; + } + + /** Auto-boxing version of {@link #get(long)}. */ + public TypeV get ( Object key ) { return (key instanceof Long) ? get (((Long)key).longValue()) : null; } + /** Auto-boxing version of {@link #remove(long)}. */ + public TypeV remove ( Object key ) { return (key instanceof Long) ? remove (((Long)key).longValue()) : null; } + /** Auto-boxing version of {@link #remove(long,Object)}. */ + public boolean remove ( Object key, Object Val ) { return (key instanceof Long) && remove(((Long) key).longValue(), Val); } + /** Auto-boxing version of {@link #containsKey(long)}. */ + public boolean containsKey( Object key ) { return (key instanceof Long) && containsKey(((Long) key).longValue()); } + /** Auto-boxing version of {@link #putIfAbsent}. */ + public TypeV putIfAbsent( Long key, TypeV val ) { return putIfAbsent( key.longValue(), val ); } + /** Auto-boxing version of {@link #replace}. */ + public TypeV replace( Long key, TypeV Val ) { return replace(key.longValue(), Val); } + /** Auto-boxing version of {@link #put}. */ + public TypeV put ( Long key, TypeV val ) { return put(key.longValue(),val); } + /** Auto-boxing version of {@link #replace}. 
*/ + public boolean replace( Long key, TypeV oldValue, TypeV newValue ) { + return replace(key.longValue(), oldValue, newValue); + } + + // --- help_copy ----------------------------------------------------------- + // Help along an existing resize operation. This is just a fast cut-out + // wrapper, to encourage inlining for the fast no-copy-in-progress case. We + // always help the top-most table copy, even if there are nested table + // copies in progress. + private void help_copy( ) { + // Read the top-level CHM only once. We'll try to help this copy along, + // even if it gets promoted out from under us (i.e., the copy completes + // and another KVS becomes the top-level copy). + CHM topchm = _chm; + if( topchm._newchm == null ) return; // No copy in-progress + topchm.help_copy_impl(false); + } + + // --- hash ---------------------------------------------------------------- + // Helper function to spread lousy hashCodes Throws NPE for null Key, on + // purpose - as the first place to conveniently toss the required NPE for a + // null Key. + private static final int hash(long h) { + h ^= (h>>>20) ^ (h>>>12); + h ^= (h>>> 7) ^ (h>>> 4); + h += h<<7; // smear low bits up high, for hashcodes that only differ by 1 + return (int)h; + } + + + // --- CHM ----------------------------------------------------------------- + // The control structure for the NonBlockingHashMapLong + private static final class CHM implements Serializable { + // Back-pointer to top-level structure + final NonBlockingHashMapLong _nbhml; + + // Size in active K,V pairs + private ConcurrentAutoTable _size; + public int size () { return (int)_size.get(); } + + // --- + // These next 2 fields are used in the resizing heuristics, to judge when + // it is time to resize or copy the table. Slots is a count of used-up + // key slots, and when it nears a large fraction of the table we probably + // end up reprobing too much. 
Last-resize-milli is the time since the + // last resize; if we are running back-to-back resizes without growing + // (because there are only a few live keys but many slots full of dead + // keys) then we need a larger table to cut down on the churn. + + // Count of used slots, to tell when table is full of dead unusable slots + private ConcurrentAutoTable _slots; + public int slots() { return (int)_slots.get(); } + + // --- + // New mappings, used during resizing. + // The 'next' CHM - created during a resize operation. This represents + // the new table being copied from the old one. It's the volatile + // variable that is read as we cross from one table to the next, to get + // the required memory orderings. It monotonically transits from null to + // set (once). + volatile CHM _newchm; + private static final AtomicReferenceFieldUpdater _newchmUpdater = + AtomicReferenceFieldUpdater.newUpdater(CHM.class,CHM.class, "_newchm"); + // Set the _newchm field if we can. AtomicUpdaters do not fail spuriously. + boolean CAS_newchm( CHM newchm ) { + return _newchmUpdater.compareAndSet(this,null,newchm); + } + // Sometimes many threads race to create a new very large table. Only 1 + // wins the race, but the losers all allocate a junk large table with + // hefty allocation costs. Attempt to control the overkill here by + // throttling attempts to create a new table. I cannot really block here + // (lest I lose the non-blocking property) but late-arriving threads can + // give the initial resizing thread a little time to allocate the initial + // new table. The Right Long Term Fix here is to use array-lets and + // incrementally create the new very large array. In C I'd make the array + // with malloc (which would mmap under the hood) which would only eat + // virtual-address and not real memory - and after Somebody wins then we + // could in parallel initialize the array. Java does not allow + // un-initialized array creation (especially of ref arrays!). 
+ volatile long _resizers; // count of threads attempting an initial resize + private static final AtomicLongFieldUpdater _resizerUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_resizers"); + + // --- key,val ------------------------------------------------------------- + // Access K,V for a given idx + private boolean CAS_key( int idx, long old, long key ) { + return UNSAFE.compareAndSwapLong ( _keys, rawIndex(_keys, idx), old, key ); + } + private boolean CAS_val( int idx, Object old, Object val ) { + return UNSAFE.compareAndSwapObject( _vals, rawIndex(_vals, idx), old, val ); + } + + final long [] _keys; + final Object [] _vals; + + // Simple constructor + CHM( final NonBlockingHashMapLong nbhml, ConcurrentAutoTable size, final int logsize ) { + _nbhml = nbhml; + _size = size; + _slots= new ConcurrentAutoTable(); + _keys = new long [1<= reprobe_limit(len) ) // too many probes + return _newchm == null // Table copy in progress? + ? null // Nope! A clear miss + : copy_slot_and_check(idx,key).get_impl(key); // Retry in the new table + + idx = (idx+1)&(len-1); // Reprobe by 1! (could now prefetch) + } + } + + // --- putIfMatch --------------------------------------------------------- + // Put, Remove, PutIfAbsent, etc. Return the old value. If the returned + // value is equal to expVal (or expVal is NO_MATCH_OLD) then the put can + // be assumed to work (although might have been immediately overwritten). + // Only the path through copy_slot passes in an expected value of null, + // and putIfMatch only returns a null if passed in an expected null. + private Object putIfMatch( final long key, final Object putval, final Object expVal ) { + final int hash = hash(key); + assert putval != null; + assert !(putval instanceof Prime); + assert !(expVal instanceof Prime); + final int len = _keys.length; + int idx = (hash & (len-1)); // The first key + + // --- + // Key-Claim stanza: spin till we can claim a Key (or force a resizing). 
+ int reprobe_cnt=0; + long K; + Object V; + while( true ) { // Spin till we get a Key slot + V = _vals[idx]; // Get old value + K = _keys[idx]; // Get current key + if( K == NO_KEY ) { // Slot is free? + // Found an empty Key slot - which means this Key has never been in + // this table. No need to put a Tombstone - the Key is not here! + if( putval == TOMBSTONE ) return TOMBSTONE; // Not-now & never-been in this table + if( expVal == MATCH_ANY ) return TOMBSTONE; // Will not match, even after K inserts + // Claim the zero key-slot + if( CAS_key(idx, NO_KEY, key) ) { // Claim slot for Key + _slots.add(1); // Raise key-slots-used count + break; // Got it! + } + // CAS to claim the key-slot failed. + // + // This re-read of the Key points out an annoying short-coming of Java + // CAS. Most hardware CAS's report back the existing value - so that + // if you fail you have a *witness* - the value which caused the CAS + // to fail. The Java API turns this into a boolean destroying the + // witness. Re-reading does not recover the witness because another + // thread can write over the memory after the CAS. Hence we can be in + // the unfortunate situation of having a CAS fail *for cause* but + // having that cause removed by a later store. This turns a + // non-spurious-failure CAS (such as Azul has) into one that can + // apparently spuriously fail - and we avoid apparent spurious failure + // by not allowing Keys to ever change. + K = _keys[idx]; // CAS failed, get updated value + assert K != NO_KEY ; // If keys[idx] is NO_KEY, CAS shoulda worked + } + // Key slot was not null, there exists a Key here + if( K == key ) + break; // Got it! + + // get and put must have the same key lookup logic! Lest 'get' give + // up looking too soon. + //topmap._reprobes.add(1); + if( ++reprobe_cnt >= reprobe_limit(len) ) { + // We simply must have a new table to do a 'put'. At this point a + // 'get' will also go to the new table (if any). 
We do not need + // to claim a key slot (indeed, we cannot find a free one to claim!). + final CHM newchm = resize(); + if( expVal != null ) _nbhml.help_copy(); // help along an existing copy + return newchm.putIfMatch(key,putval,expVal); + } + + idx = (idx+1)&(len-1); // Reprobe! + } // End of spinning till we get a Key slot + + while ( true ) { // Spin till we insert a value + // --- + // Found the proper Key slot, now update the matching Value slot. We + // never put a null, so Value slots monotonically move from null to + // not-null (deleted Values use Tombstone). Thus if 'V' is null we + // fail this fast cutout and fall into the check for table-full. + if( putval == V ) return V; // Fast cutout for no-change + + // See if we want to move to a new table (to avoid high average re-probe + // counts). We only check on the initial set of a Value from null to + // not-null (i.e., once per key-insert). + if( (V == null && tableFull(reprobe_cnt,len)) || + // Or we found a Prime: resize is already in progress. The resize + // call below will do a CAS on _newchm forcing the read. + V instanceof Prime) { + resize(); // Force the new table copy to start + return copy_slot_and_check(idx,expVal).putIfMatch(key,putval,expVal); + } + + // --- + // We are finally prepared to update the existing table + //assert !(V instanceof Prime); // always true, so IDE warnings if uncommented + + // Must match old, and we do not? Then bail out now. Note that either V + // or expVal might be TOMBSTONE. Also V can be null, if we've never + // inserted a value before. expVal can be null if we are called from + // copy_slot. + + if( expVal != NO_MATCH_OLD && // Do we care about expected-Value at all? + V != expVal && // No instant match already? + (expVal != MATCH_ANY || V == TOMBSTONE || V == null) && + !(V==null && expVal == TOMBSTONE) && // Match on null/TOMBSTONE combo + (expVal == null || !expVal.equals(V)) ) // Expensive equals check at the last + return (V==null) ? 
TOMBSTONE : V; // Do not update! + + // Actually change the Value in the Key,Value pair + if( CAS_val(idx, V, putval ) ) break; + + // CAS failed + // Because we have no witness, we do not know why it failed. + // Indeed, by the time we look again the value under test might have flipped + // a thousand times and now be the expected value (despite the CAS failing). + // Check for the never-succeed condition of a Prime value and jump to any + // nested table, or else just re-run. + + // We would not need this load at all if CAS returned the value on which + // the CAS failed (AKA witness). The new CAS semantics are supported via + // VarHandle in JDK9. + V = _vals[idx]; // Get new value + + // If a Prime'd value got installed, we need to re-run the put on the + // new table. Otherwise we lost the CAS to another racing put. + // Simply retry from the start. + if( V instanceof Prime ) + return copy_slot_and_check(idx,expVal).putIfMatch(key,putval,expVal); + + // Simply retry from the start. + // NOTE: need the fence, since otherwise '_vals[idx]' load could be hoisted + // out of loop. + int dummy = DUMMY_VOLATILE; + } + + // CAS succeeded - we did the update! + // Both normal put's and table-copy calls putIfMatch, but table-copy + // does not (effectively) increase the number of live k/v pairs. + if( expVal != null ) { + // Adjust sizes - a striped counter + if( (V == null || V == TOMBSTONE) && putval != TOMBSTONE ) _size.add( 1); + if( !(V == null || V == TOMBSTONE) && putval == TOMBSTONE ) _size.add(-1); + } + + // We won; we know the update happened as expected. + return (V==null && expVal!=null) ? TOMBSTONE : V; + } + + // --- tableFull --------------------------------------------------------- + // Heuristic to decide if this table is too full, and we should start a + // new table. Note that if a 'get' call has reprobed too many times and + // decided the table must be full, then always the estimate_sum must be + // high and we must report the table is full. 
If we do not, then we might + // end up deciding that the table is not full and inserting into the + // current table, while a 'get' has decided the same key cannot be in this + // table because of too many reprobes. The invariant is: + // slots.estimate_sum >= max_reprobe_cnt >= reprobe_limit(len) + private boolean tableFull( int reprobe_cnt, int len ) { + return + // Do the cheap check first: we allow some number of reprobes always + reprobe_cnt >= REPROBE_LIMIT && + (reprobe_cnt >= reprobe_limit(len) || + // More expensive check: see if the table is > 1/2 full. + _slots.estimate_get() >= (len>>1)); + } + + // --- resize ------------------------------------------------------------ + // Resizing after too many probes. "How Big???" heuristics are here. + // Callers will (not this routine) will 'help_copy' any in-progress copy. + // Since this routine has a fast cutout for copy-already-started, callers + // MUST 'help_copy' lest we have a path which forever runs through + // 'resize' only to discover a copy-in-progress which never progresses. + private CHM resize() { + // Check for resize already in progress, probably triggered by another thread + CHM newchm = _newchm; // VOLATILE READ + if( newchm != null ) // See if resize is already in progress + return newchm; // Use the new table already + + // No copy in-progress, so start one. First up: compute new table size. + int oldlen = _keys.length; // Old count of K,V pairs allowed + int sz = size(); // Get current table count of active K,V pairs + int newsz = sz; // First size estimate + + // Heuristic to determine new size. We expect plenty of dead-slots-with-keys + // and we need some decent padding to avoid endless reprobing. + if( _nbhml._opt_for_space ) { + // This heuristic leads to a much denser table with a higher reprobe rate + if( sz >= (oldlen>>1) ) // If we are >50% full of keys then... + newsz = oldlen<<1; // Double size + } else { + if( sz >= (oldlen>>2) ) { // If we are >25% full of keys then... 
+ newsz = oldlen<<1; // Double size + if( sz >= (oldlen>>1) ) // If we are >50% full of keys then... + newsz = oldlen<<2; // Double double size + } + } + + // Last (re)size operation was very recent? Then double again + // despite having few live keys; slows down resize operations + // for tables subject to a high key churn rate - but do not + // forever grow the table. If there is a high key churn rate + // the table needs a steady state of rare same-size resize + // operations to clean out the dead keys. + long tm = System.currentTimeMillis(); + if( newsz <= oldlen && // New table would shrink or hold steady? + tm <= _nbhml._last_resize_milli+10000) // Recent resize (less than 10 sec ago) + newsz = oldlen<<1; // Double the existing size + + // Do not shrink, ever. If we hit this size once, assume we + // will again. + if( newsz < oldlen ) newsz = oldlen; + + // Convert to power-of-2 + int log2; + for( log2=MIN_SIZE_LOG; (1< ((len >> 2) + (len >> 1))) throw new RuntimeException("Table is full."); + } + + // Now limit the number of threads actually allocating memory to a + // handful - lest we have 750 threads all trying to allocate a giant + // resized array. + long r = _resizers; + while( !_resizerUpdater.compareAndSet(this,r,r+1) ) + r = _resizers; + // Size calculation: 2 words (K+V) per table entry, plus a handful. We + // guess at 64-bit pointers; 32-bit pointers screws up the size calc by + // 2x but does not screw up the heuristic very much. + long megs = ((((1L<>20/*megs*/; + if( r >= 2 && megs > 0 ) { // Already 2 guys trying; wait and see + newchm = _newchm; // Between dorking around, another thread did it + if( newchm != null ) // See if resize is already in progress + return newchm; // Use the new table already + // We could use a wait with timeout, so we'll wakeup as soon as the new table + // is ready, or after the timeout in any case. 
+ //synchronized( this ) { wait(8*megs); } // Timeout - we always wakeup + // For now, sleep a tad and see if the 2 guys already trying to make + // the table actually get around to making it happen. + try { Thread.sleep(megs); } catch( Exception e ) { /*empty*/} + } + // Last check, since the 'new' below is expensive and there is a chance + // that another thread slipped in a new thread while we ran the heuristic. + newchm = _newchm; + if( newchm != null ) // See if resize is already in progress + return newchm; // Use the new table already + + // New CHM - actually allocate the big arrays + newchm = new CHM(_nbhml,_size,log2); + + // Another check after the slow allocation + if( _newchm != null ) // See if resize is already in progress + return _newchm; // Use the new table already + + // The new table must be CAS'd in so only 1 winner amongst duplicate + // racing resizing threads. Extra CHM's will be GC'd. + if( CAS_newchm( newchm ) ) { // NOW a resize-is-in-progress! + //notifyAll(); // Wake up any sleepers + //long nano = System.nanoTime(); + //System.out.println(" "+nano+" Resize from "+oldlen+" to "+(1< _copyIdxUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_copyIdx"); + + // Work-done reporting. Used to efficiently signal when we can move to + // the new table. From 0 to len(oldkvs) refers to copying from the old + // table to the new. + volatile long _copyDone= 0; + static private final AtomicLongFieldUpdater _copyDoneUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_copyDone"); + + // --- help_copy_impl ---------------------------------------------------- + // Help along an existing resize operation. We hope its the top-level + // copy (it was when we started) but this CHM might have been promoted out + // of the top position. 
+ private void help_copy_impl( final boolean copy_all ) { + final CHM newchm = _newchm; + assert newchm != null; // Already checked by caller + int oldlen = _keys.length; // Total amount to copy + final int MIN_COPY_WORK = Math.min(oldlen,1024); // Limit per-thread work + + // --- + int panic_start = -1; + int copyidx=-9999; // Fool javac to think it's initialized + while( _copyDone < oldlen ) { // Still needing to copy? + // Carve out a chunk of work. The counter wraps around so every + // thread eventually tries to copy every slot repeatedly. + + // We "panic" if we have tried TWICE to copy every slot - and it still + // has not happened. i.e., twice some thread somewhere claimed they + // would copy 'slot X' (by bumping _copyIdx) but they never claimed to + // have finished (by bumping _copyDone). Our choices become limited: + // we can wait for the work-claimers to finish (and become a blocking + // algorithm) or do the copy work ourselves. Tiny tables with huge + // thread counts trying to copy the table often 'panic'. + if( panic_start == -1 ) { // No panic? + copyidx = (int)_copyIdx; + while( copyidx < (oldlen<<1) && // 'panic' check + !_copyIdxUpdater.compareAndSet(this,copyidx,copyidx+MIN_COPY_WORK) ) + copyidx = (int)_copyIdx; // Re-read + if( !(copyidx < (oldlen<<1)) ) // Panic! + panic_start = copyidx; // Record where we started to panic-copy + } + + // We now know what to copy. Try to copy. + int workdone = 0; + for( int i=0; i 0 ) // Report work-done occasionally + copy_check_and_promote( workdone );// See if we can promote + //for( int i=0; i 0 ) { + while( !_copyDoneUpdater.compareAndSet(this,copyDone,nowDone) ) { + copyDone = _copyDone; // Reload, retry + nowDone = copyDone+workdone; + assert nowDone <= oldlen; + } + } + + // Check for copy being ALL done, and promote. Note that we might have + // nested in-progress copies and manage to finish a nested copy before + // finishing the top-level copy. We only promote top-level copies. 
+ if( nowDone == oldlen && // Ready to promote this table? + _nbhml._chm == this && // Looking at the top-level table? + // Attempt to promote + _nbhml.CAS(_chm_offset,this,_newchm) ) { + _nbhml._last_resize_milli = System.currentTimeMillis(); // Record resize time for next check + } + } + + // --- copy_slot --------------------------------------------------------- + // Copy one K/V pair from oldkvs[i] to newkvs. Returns true if we can + // confirm that we set an old-table slot to TOMBPRIME, and only returns after + // updating the new table. We need an accurate confirmed-copy count so + // that we know when we can promote (if we promote the new table too soon, + // other threads may 'miss' on values not-yet-copied from the old table). + // We don't allow any direct updates on the new table, unless they first + // happened to the old table - so that any transition in the new table from + // null to not-null must have been from a copy_slot (or other old-table + // overwrite) and not from a thread directly writing in the new table. + private boolean copy_slot( int idx ) { + // Blindly set the key slot from NO_KEY to some key which hashes here, + // to eagerly stop fresh put's from inserting new values in the old + // table when the old table is mid-resize. We don't need to act on the + // results here, because our correctness stems from box'ing the Value + // field. Slamming the Key field is a minor speed optimization. + long key; + while( (key=_keys[idx]) == NO_KEY ) + CAS_key(idx, NO_KEY, (idx+_keys.length)/*a non-zero key which hashes here*/); + + // --- + // Prevent new values from appearing in the old table. + // Box what we see in the old table, to prevent further updates. + Object oldval = _vals[idx]; // Read OLD table + while( !(oldval instanceof Prime) ) { + final Prime box = (oldval == null || oldval == TOMBSTONE) ? 
TOMBPRIME : new Prime(oldval); + if( CAS_val(idx,oldval,box) ) { // CAS down a box'd version of oldval + // If we made the Value slot hold a TOMBPRIME, then we both + // prevented further updates here but also the (absent) oldval is + // vaccuously available in the new table. We return with true here: + // any thread looking for a value for this key can correctly go + // straight to the new table and skip looking in the old table. + if( box == TOMBPRIME ) + return true; + // Otherwise we boxed something, but it still needs to be + // copied into the new table. + oldval = box; // Record updated oldval + break; // Break loop; oldval is now boxed by us + } + oldval = _vals[idx]; // Else try, try again + } + if( oldval == TOMBPRIME ) return false; // Copy already complete here! + + // --- + // Copy the value into the new table, but only if we overwrite a null. + // If another value is already in the new table, then somebody else + // wrote something there and that write is happens-after any value that + // appears in the old table. + Object old_unboxed = ((Prime)oldval)._V; + assert old_unboxed != TOMBSTONE; + boolean copied_into_new = (_newchm.putIfMatch(key, old_unboxed, null) == null); + + // --- + // Finally, now that any old value is exposed in the new table, we can + // forever hide the old-table value by slapping a TOMBPRIME down. This + // will stop other threads from uselessly attempting to copy this slot + // (i.e., it's a speed optimization not a correctness issue). + while( oldval != TOMBPRIME && !CAS_val(idx,oldval,TOMBPRIME) ) + oldval = _vals[idx]; + + return copied_into_new; + } // end copy_slot + } // End of CHM + + + // --- Snapshot ------------------------------------------------------------ + // The main class for iterating over the NBHM. It "snapshots" a clean + // view of the K/V array. 
+ private class SnapshotV implements Iterator, Enumeration { + final CHM _sschm; + public SnapshotV() { + CHM topchm; + while( true ) { // Verify no table-copy-in-progress + topchm = _chm; + if( topchm._newchm == null ) // No table-copy-in-progress + break; + // Table copy in-progress - so we cannot get a clean iteration. We + // must help finish the table copy before we can start iterating. + topchm.help_copy_impl(true); + } + // The "linearization point" for the iteration. Every key in this table + // will be visited, but keys added later might be skipped or even be + // added to a following table (also not iterated over). + _sschm = topchm; + // Warm-up the iterator + _idx = -1; + next(); + } + int length() { return _sschm._keys.length; } + long key(final int idx) { return _sschm._keys[idx]; } + private int _idx; // -2 for NO_KEY, -1 for CHECK_NEW_TABLE_LONG, 0-keys.length + private long _nextK, _prevK; // Last 2 keys found + private TypeV _nextV, _prevV; // Last 2 values found + public boolean hasNext() { return _nextV != null; } + public TypeV next() { + // 'next' actually knows what the next value will be - it had to + // figure that out last go 'round lest 'hasNext' report true and + // some other thread deleted the last value. Instead, 'next' + // spends all its effort finding the key that comes after the + // 'next' key. + if( _idx != -1 && _nextV == null ) throw new NoSuchElementException(); + _prevK = _nextK; // This will become the previous key + _prevV = _nextV; // This will become the previous value + _nextV = null; // We have no more next-key + // Attempt to set <_nextK,_nextV> to the next K,V pair. + // _nextV is the trigger: stop searching when it is != null + if( _idx == -1 ) { // Check for NO_KEY + _idx = 0; // Setup for next phase of search + _nextK = NO_KEY; + if( (_nextV=get(_nextK)) != null ) return _prevV; + } + while( _idx, but the JDK always + // removes by key, even when the value has changed. 
+ removeKey(); + } + + public TypeV nextElement() { return next(); } + public boolean hasMoreElements() { return hasNext(); } + } + + /** Returns an enumeration of the values in this table. + * @return an enumeration of the values in this table + * @see #values() */ + public Enumeration elements() { return new SnapshotV(); } + + // --- values -------------------------------------------------------------- + /** Returns a {@link Collection} view of the values contained in this map. + * The collection is backed by the map, so changes to the map are reflected + * in the collection, and vice-versa. The collection supports element + * removal, which removes the corresponding mapping from this map, via the + * Iterator.remove, Collection.remove, + * removeAll, retainAll, and clear operations. + * It does not support the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator that + * will never throw {@link ConcurrentModificationException}, and guarantees + * to traverse elements as they existed upon construction of the iterator, + * and may (but is not guaranteed to) reflect any modifications subsequent + * to construction. */ + public Collection values() { + return new AbstractCollection() { + public void clear ( ) { NonBlockingHashMapLong.this.clear ( ); } + public int size ( ) { return NonBlockingHashMapLong.this.size ( ); } + public boolean contains( Object v ) { return NonBlockingHashMapLong.this.containsValue(v); } + public Iterator iterator() { return new SnapshotV(); } + }; + } + + // --- keySet -------------------------------------------------------------- + /** A class which implements the {@link Iterator} and {@link Enumeration} + * interfaces, generified to the {@link Long} class and supporting a + * non-auto-boxing {@link #nextLong} function. */ + public class IteratorLong implements Iterator, Enumeration { + private final SnapshotV _ss; + /** A new IteratorLong */ + public IteratorLong() { _ss = new SnapshotV(); } + /** Remove last key returned by {@link #next} or {@link #nextLong}. */ + public void remove() { _ss.removeKey(); } + /** Auto-box and return the next key. */ + public Long next () { _ss.next(); return _ss._prevK; } + /** Return the next key as a primitive {@code long}. */ + public long nextLong() { _ss.next(); return _ss._prevK; } + /** True if there are more keys to iterate over. */ + public boolean hasNext() { return _ss.hasNext(); } + /** Auto-box and return the next key. */ + public Long nextElement() { return next(); } + /** True if there are more keys to iterate over. */ + public boolean hasMoreElements() { return hasNext(); } + } + /** Returns an enumeration of the auto-boxed keys in this table. + * Warning: this version will auto-box all returned keys. 
+ * @return an enumeration of the auto-boxed keys in this table + * @see #keySet() */ + public Enumeration keys() { return new IteratorLong(); } + + /** Returns a {@link Set} view of the keys contained in this map; with care + * the keys may be iterated over without auto-boxing. The + * set is backed by the map, so changes to the map are reflected in the + * set, and vice-versa. The set supports element removal, which removes + * the corresponding mapping from this map, via the + * Iterator.remove, Set.remove, removeAll, + * retainAll, and clear operations. It does not support + * the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator that + * will never throw {@link ConcurrentModificationException}, and guarantees + * to traverse elements as they existed upon construction of the iterator, + * and may (but is not guaranteed to) reflect any modifications subsequent + * to construction. */ + public Set keySet() { + return new AbstractSet () { + public void clear ( ) { NonBlockingHashMapLong.this.clear ( ); } + public int size ( ) { return NonBlockingHashMapLong.this.size ( ); } + public boolean contains( Object k ) { return NonBlockingHashMapLong.this.containsKey(k); } + public boolean remove ( Object k ) { return NonBlockingHashMapLong.this.remove (k) != null; } + public IteratorLong iterator() { return new IteratorLong(); } + }; + } + + /** Keys as a long array. Array may be zero-padded if keys are concurrently deleted. */ + @SuppressWarnings("unchecked") + public long[] keySetLong() { + long[] dom = new long[size()]; + IteratorLong i=(IteratorLong)keySet().iterator(); + int j=0; + while( j < dom.length && i.hasNext() ) + dom[j++] = i.nextLong(); + return dom; + } + + // --- entrySet ------------------------------------------------------------ + // Warning: Each call to 'next' in this iterator constructs a new Long and a + // new NBHMLEntry. + private class NBHMLEntry extends AbstractEntry { + NBHMLEntry( final Long k, final TypeV v ) { super(k,v); } + public TypeV setValue(final TypeV val) { + if (val == null) throw new NullPointerException(); + _val = val; + return put(_key, val); + } + } + private class SnapshotE implements Iterator> { + final SnapshotV _ss; + public SnapshotE() { _ss = new SnapshotV(); } + public void remove() { + // NOTE: it would seem logical that entry removal will semantically mean + // removing the matching pair , but the JDK always removes by key, + // even when the value has changed. 
+ _ss.removeKey(); + } + public Map.Entry next() { _ss.next(); return new NBHMLEntry(_ss._prevK,_ss._prevV); } + public boolean hasNext() { return _ss.hasNext(); } + } + + /** Returns a {@link Set} view of the mappings contained in this map. The + * set is backed by the map, so changes to the map are reflected in the + * set, and vice-versa. The set supports element removal, which removes + * the corresponding mapping from the map, via the + * Iterator.remove, Set.remove, removeAll, + * retainAll, and clear operations. It does not support + * the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator + * that will never throw {@link ConcurrentModificationException}, + * and guarantees to traverse elements as they existed upon + * construction of the iterator, and may (but is not guaranteed to) + * reflect any modifications subsequent to construction. + * + *

Warning: the iterator associated with this Set + * requires the creation of {@link java.util.Map.Entry} objects with each + * iteration. The {@link org.jctools.maps.NonBlockingHashMap} + * does not normally create or using {@link java.util.Map.Entry} objects so + * they will be created soley to support this iteration. Iterating using + * {@link Map#keySet} or {@link Map#values} will be more efficient. In addition, + * this version requires auto-boxing the keys. + */ + public Set> entrySet() { + return new AbstractSet>() { + public void clear ( ) { NonBlockingHashMapLong.this.clear( ); } + public int size ( ) { return NonBlockingHashMapLong.this.size ( ); } + public boolean remove( final Object o ) { + if (!(o instanceof Map.Entry)) return false; + final Map.Entry e = (Map.Entry)o; + return NonBlockingHashMapLong.this.remove(e.getKey(), e.getValue()); + } + public boolean contains(final Object o) { + if (!(o instanceof Map.Entry)) return false; + final Map.Entry e = (Map.Entry)o; + TypeV v = get(e.getKey()); + return v != null && v.equals(e.getValue()); + } + public Iterator> iterator() { return new SnapshotE(); } + }; + } + + // --- writeObject ------------------------------------------------------- + // Write a NBHML to a stream + private void writeObject(java.io.ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); // Write nothing + for( long K : keySet() ) { + final Object V = get(K); // Do an official 'get' + s.writeLong (K); // Write the pair + s.writeObject(V); + } + s.writeLong(NO_KEY); // Sentinel to indicate end-of-data + s.writeObject(null); + } + + // --- readObject -------------------------------------------------------- + // Read a NBHML from a stream + @SuppressWarnings("unchecked") + private void readObject(java.io.ObjectInputStream s) throws IOException, ClassNotFoundException { + s.defaultReadObject(); // Read nothing + initialize(MIN_SIZE); + for (;;) { + final long K = s.readLong(); + final TypeV V = (TypeV) s.readObject(); + if( 
K == NO_KEY && V == null ) break; + put(K,V); // Insert with an offical put + } + } + + /** + * Creates a shallow copy of this hashtable. All the structure of the + * hashtable itself is copied, but the keys and values are not cloned. + * This is a relatively expensive operation. + * + * @return a clone of the hashtable. + */ + @SuppressWarnings("unchecked") + @Override + public NonBlockingHashMapLong clone() { + try { + // Must clone, to get the class right; NBHML might have been + // extended so it would be wrong to just make a new NBHML. + NonBlockingHashMapLong t = (NonBlockingHashMapLong) super.clone(); + // But I don't have an atomic clone operation - the underlying _kvs + // structure is undergoing rapid change. If I just clone the _kvs + // field, the CHM in _kvs[0] won't be in sync. + // + // Wipe out the cloned array (it was shallow anyways). + t.clear(); + // Now copy sanely + for( long K : keySetLong() ) + t.put(K,get(K)); + return t; + } catch (CloneNotSupportedException e) { + // this shouldn't happen, since we are Cloneable + throw new InternalError(); + } + } + +} // End NonBlockingHashMapLong class diff --git a/netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashSet.java b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashSet.java new file mode 100644 index 0000000..2d75a0f --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingHashSet.java @@ -0,0 +1,60 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.maps; +import java.io.Serializable; +import java.util.AbstractSet; +import java.util.Iterator; +import java.util.Set; + +/** + * A simple wrapper around {@link NonBlockingHashMap} making it implement the + * {@link Set} interface. All operations are Non-Blocking and multi-thread safe. + * + * @since 1.5 + * @author Cliff Click + */ +public class NonBlockingHashSet extends AbstractSet implements Serializable { + private static final Object V = ""; + + private final NonBlockingHashMap _map; + + /** Make a new empty {@link NonBlockingHashSet}. */ + public NonBlockingHashSet() { super(); _map = new NonBlockingHashMap(); } + + /** Add {@code o} to the set. + * @return true if {@code o} was added to the set, false + * if {@code o} was already in the set. */ + public boolean add( final E o ) { return _map.putIfAbsent(o,V) == null; } + + /** @return true if {@code o} is in the set. */ + public boolean contains ( final Object o ) { return _map.containsKey(o); } + + /** @return Returns the match for {@code o} if {@code o} is in the set. */ + public E get( final E o ) { return _map.getk(o); } + + /** Remove {@code o} from the set. + * @return true if {@code o} was removed to the set, false + * if {@code o} was not in the set. + */ + public boolean remove( final Object o ) { return _map.remove(o) == V; } + /** Current count of elements in the set. Due to concurrent racing updates, + * the size is only ever approximate. Updates due to the calling thread are + * immediately visible to calling thread. + * @return count of elements. */ + public int size( ) { return _map.size(); } + /** Empty the set. 
*/ + public void clear( ) { _map.clear(); } + + public Iteratoriterator( ) { return _map.keySet().iterator(); } +} diff --git a/netty-jctools/src/main/java/org/jctools/maps/NonBlockingIdentityHashMap.java b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingIdentityHashMap.java new file mode 100644 index 0000000..2f684c3 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingIdentityHashMap.java @@ -0,0 +1,1307 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.maps; + +import org.jctools.util.RangeUtil; + +import java.io.IOException; +import java.io.Serializable; +import java.util.*; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicLongFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +import static org.jctools.maps.NonBlockingHashMap.DUMMY_VOLATILE; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; + +/** + * A lock-free alternate implementation of {@link java.util.concurrent.ConcurrentHashMap} + * with better scaling properties and generally lower costs to mutate the Map. + * It provides identical correctness properties as ConcurrentHashMap. All + * operations are non-blocking and multi-thread safe, including all update + * operations. 
{@link NonBlockingHashMap} scales substatially better than + * {@link java.util.concurrent.ConcurrentHashMap} for high update rates, even with a + * large concurrency factor. Scaling is linear up to 768 CPUs on a 768-CPU + * Azul box, even with 100% updates or 100% reads or any fraction in-between. + * Linear scaling up to all cpus has been observed on a 32-way Sun US2 box, + * 32-way Sun Niagra box, 8-way Intel box and a 4-way Power box. + * + * This class obeys the same functional specification as {@link + * java.util.Hashtable}, and includes versions of methods corresponding to + * each method of Hashtable. However, even though all operations are + * thread-safe, operations do not entail locking and there is + * not any support for locking the entire table in a way that + * prevents all access. This class is fully interoperable with + * Hashtable in programs that rely on its thread safety but not on + * its synchronization details. + * + *

Operations (including put) generally do not block, so may + * overlap with other update operations (including other puts and + * removes). Retrievals reflect the results of the most recently + * completed update operations holding upon their onset. For + * aggregate operations such as putAll, concurrent retrievals may + * reflect insertion or removal of only some entries. Similarly, Iterators + * and Enumerations return elements reflecting the state of the hash table at + * some point at or since the creation of the iterator/enumeration. They do + * not throw {@link ConcurrentModificationException}. However, + * iterators are designed to be used by only one thread at a time. + * + *

Very full tables, or tables with high reprobe rates may trigger an + * internal resize operation to move into a larger table. Resizing is not + * terribly expensive, but it is not free either; during resize operations + * table throughput may drop somewhat. All threads that visit the table + * during a resize will 'help' the resizing but will still be allowed to + * complete their operation before the resize is finished (i.e., a simple + * 'get' operation on a million-entry table undergoing resizing will not need + * to block until the entire million entries are copied). + * + *

This class and its views and iterators implement all of the + * optional methods of the {@link Map} and {@link Iterator} + * interfaces. + * + *

Like {@link Hashtable} but unlike {@link HashMap}, this class + * does not allow null to be used as a key or value. + * + * + * @since 1.5 + * @author Cliff Click + * @param the type of keys maintained by this map + * @param the type of mapped values + * + * @author Prashant Deva + * Modified from original NonBlockingHashMap to use identity equality. + * Uses System.identityHashCode() to calculate hashMap. + * Key equality is compared using '=='. + */ +public class NonBlockingIdentityHashMap + extends AbstractMap + implements ConcurrentMap, Cloneable, Serializable { + + private static final long serialVersionUID = 1234123412341234123L; + + private static final int REPROBE_LIMIT=10; // Too many reprobes then force a table-resize + + // --- Bits to allow Unsafe access to arrays + private static final int _Obase = UNSAFE.arrayBaseOffset(Object[].class); + private static final int _Oscale = UNSAFE.arrayIndexScale(Object[].class); + private static long rawIndex(final Object[] ary, final int idx) { + assert idx >= 0 && idx < ary.length; + return _Obase + (idx * (long)_Oscale); + } + + // --- Setup to use Unsafe + private static final long _kvs_offset = fieldOffset(NonBlockingHashMap.class, "_kvs"); + + private final boolean CAS_kvs( final Object[] oldkvs, final Object[] newkvs ) { + return UNSAFE.compareAndSwapObject(this, _kvs_offset, oldkvs, newkvs ); + } + + // --- Adding a 'prime' bit onto Values via wrapping with a junk wrapper class + private static final class Prime { + final Object _V; + Prime( Object V ) { _V = V; } + static Object unbox( Object V ) { return V instanceof Prime ? 
((Prime)V)._V : V; } + } + + // --- hash ---------------------------------------------------------------- + // Helper function to spread lousy hashCodes + private static final int hash(final Object key) { + if (key == null) throw new NullPointerException(); + int h = System.identityHashCode(key); // The real hashCode call + // I assume that System.identityHashCode is well implemented with a good + // spreader, and a second bit-spreader is redundant. + return h; + } + + // --- The Hash Table -------------------- + // Slot 0 is always used for a 'CHM' entry below to hold the interesting + // bits of the hash table. Slot 1 holds full hashes as an array of ints. + // Slots {2,3}, {4,5}, etc hold {Key,Value} pairs. The entire hash table + // can be atomically replaced by CASing the _kvs field. + // + // Why is CHM buried inside the _kvs Object array, instead of the other way + // around? The CHM info is used during resize events and updates, but not + // during standard 'get' operations. I assume 'get' is much more frequent + // than 'put'. 'get' can skip the extra indirection of skipping through the + // CHM to reach the _kvs array. + private transient Object[] _kvs; + private static final CHM chm (Object[] kvs) { return (CHM )kvs[0]; } + private static final int[] hashes(Object[] kvs) { return (int[])kvs[1]; } + // Number of K,V pairs in the table + private static final int len(Object[] kvs) { return (kvs.length-2)>>1; } + + // Time since last resize + private transient long _last_resize_milli; + + // --- Minimum table size ---------------- + // Pick size 8 K/V pairs, which turns into (8*2+2)*4+12 = 84 bytes on a + // standard 32-bit HotSpot, and (8*2+2)*8+12 = 156 bytes on 64-bit Azul. 
+ private static final int MIN_SIZE_LOG=3; // + private static final int MIN_SIZE=(1<>2); + } + + // --- NonBlockingHashMap -------------------------------------------------- + // Constructors + + /** Create a new NonBlockingHashMap with default minimum size (currently set + * to 8 K/V pairs or roughly 84 bytes on a standard 32-bit JVM). */ + public NonBlockingIdentityHashMap( ) { this(MIN_SIZE); } + + /** Create a new NonBlockingHashMap with initial room for the given number of + * elements, thus avoiding internal resizing operations to reach an + * appropriate size. Large numbers here when used with a small count of + * elements will sacrifice space for a small amount of time gained. The + * initial size will be rounded up internally to the next larger power of 2. */ + public NonBlockingIdentityHashMap( final int initial_sz ) { initialize(initial_sz); } + private final void initialize( int initial_sz ) { + RangeUtil.checkPositiveOrZero(initial_sz, "initial_sz"); + int i; // Convert to next largest power-of-2 + if( initial_sz > 1024*1024 ) initial_sz = 1024*1024; + for( i=MIN_SIZE_LOG; (1<size() == 0. + * @return size() == 0 */ + @Override + public boolean isEmpty ( ) { return size() == 0; } + + /** Tests if the key in the table using the equals method. + * @return true if the key is in the table using the equals method + * @throws NullPointerException if the specified key is null */ + @Override + public boolean containsKey( Object key ) { return get(key) != null; } + + /** Legacy method testing if some key maps into the specified value in this + * table. This method is identical in functionality to {@link + * #containsValue}, and exists solely to ensure full compatibility with + * class {@link java.util.Hashtable}, which supported this method prior to + * introduction of the Java Collections framework. 
+ * @param val a value to search for + * @return true if this map maps one or more keys to the specified value + * @throws NullPointerException if the specified value is null */ + public boolean contains ( Object val ) { return containsValue(val); } + + /** Maps the specified key to the specified value in the table. Neither key + * nor value can be null. + *

The value can be retrieved by calling {@link #get} with a key that is + * equal to the original key. + * @param key key with which the specified value is to be associated + * @param val value to be associated with the specified key + * @return the previous value associated with key, or + * null if there was no mapping for key + * @throws NullPointerException if the specified key or value is null */ + @Override + public TypeV put ( TypeK key, TypeV val ) { return putIfMatch( key, val, NO_MATCH_OLD); } + + /** Atomically, do a {@link #put} if-and-only-if the key is not mapped. + * Useful to ensure that only a single mapping for the key exists, even if + * many threads are trying to create the mapping in parallel. + * @return the previous value associated with the specified key, + * or null if there was no mapping for the key + * @throws NullPointerException if the specified key or value is null */ + @Override + public TypeV putIfAbsent( TypeK key, TypeV val ) { return putIfMatch( key, val, TOMBSTONE ); } + + /** Removes the key (and its corresponding value) from this map. + * This method does nothing if the key is not in the map. + * @return the previous value associated with key, or + * null if there was no mapping for key + * @throws NullPointerException if the specified key is null */ + @Override + public TypeV remove ( Object key ) { return putIfMatch( key,TOMBSTONE, NO_MATCH_OLD); } + + /** Atomically do a {@link #remove(Object)} if-and-only-if the key is mapped + * to a value which is equals to the given value. + * @throws NullPointerException if the specified key or value is null */ + public boolean remove ( Object key,Object val ) { + return objectsEquals(putIfMatch( key,TOMBSTONE, val ), val); + } + + /** Atomically do a put(key,val) if-and-only-if the key is + * mapped to some value already. 
+ * @throws NullPointerException if the specified key or value is null */ + @Override + public TypeV replace ( TypeK key, TypeV val ) { return putIfMatch( key, val,MATCH_ANY ); } + + /** Atomically do a put(key,newValue) if-and-only-if the key is + * mapped a value which is equals to oldValue. + * @throws NullPointerException if the specified key or value is null */ + @Override + public boolean replace ( TypeK key, TypeV oldValue, TypeV newValue ) { + return objectsEquals(putIfMatch( key, newValue, oldValue ), oldValue); + } + private static boolean objectsEquals(Object a, Object b) { + return (a == b) || (a != null && a.equals(b)); + } + + private final TypeV putIfMatch( Object key, Object newVal, Object oldVal ) { + if (oldVal == null || newVal == null) throw new NullPointerException(); + final Object res = putIfMatch0( this, _kvs, key, newVal, oldVal ); + assert !(res instanceof Prime); + assert res != null; + return res == TOMBSTONE ? null : (TypeV)res; + } + + + /** Copies all of the mappings from the specified map to this one, replacing + * any existing mappings. + * @param m mappings to be stored in this map */ + @Override + public void putAll(Map m) { + for (Map.Entry e : m.entrySet()) + put(e.getKey(), e.getValue()); + } + + /** Removes all of the mappings from this map. */ + @Override + public void clear() { // Smack a new empty table down + Object[] newkvs = new NonBlockingIdentityHashMap(MIN_SIZE)._kvs; + while( !CAS_kvs(_kvs,newkvs) ) // Spin until the clear works + ; + } + + /** Returns true if this Map maps one or more keys to the specified + * value. Note: This method requires a full internal traversal of the + * hash table and is much slower than {@link #containsKey}. 
+ * @param val value whose presence in this map is to be tested + * @return true if this map maps one or more keys to the specified value + * @throws NullPointerException if the specified value is null */ + @Override + public boolean containsValue( final Object val ) { + if( val == null ) throw new NullPointerException(); + for( TypeV V : values() ) + if( V == val || V.equals(val) ) + return true; + return false; + } + + // This function is supposed to do something for Hashtable, and the JCK + // tests hang until it gets called... by somebody ... for some reason, + // any reason.... + protected void rehash() { + } + + /** + * Creates a shallow copy of this hashtable. All the structure of the + * hashtable itself is copied, but the keys and values are not cloned. + * This is a relatively expensive operation. + * + * @return a clone of the hashtable. + */ + @Override + public Object clone() { + try { + // Must clone, to get the class right; NBHM might have been + // extended so it would be wrong to just make a new NBHM. + NonBlockingIdentityHashMap t = (NonBlockingIdentityHashMap) super.clone(); + // But I don't have an atomic clone operation - the underlying _kvs + // structure is undergoing rapid change. If I just clone the _kvs + // field, the CHM in _kvs[0] won't be in sync. + // + // Wipe out the cloned array (it was shallow anyways). + t.clear(); + // Now copy sanely + for( TypeK K : keySet() ) { + final TypeV V = get(K); // Do an official 'get' + t.put(K,V); + } + return t; + } catch (CloneNotSupportedException e) { + // this shouldn't happen, since we are Cloneable + throw new InternalError(); + } + } + + /** + * Returns a string representation of this map. The string representation + * consists of a list of key-value mappings in the order returned by the + * map's entrySet view's iterator, enclosed in braces + * ("{}"). Adjacent mappings are separated by the characters + * ", " (comma and space). 
Each key-value mapping is rendered as + * the key followed by an equals sign ("=") followed by the + * associated value. Keys and values are converted to strings as by + * {@link String#valueOf(Object)}. + * + * @return a string representation of this map + */ + @Override + public String toString() { + Iterator> i = entrySet().iterator(); + if( !i.hasNext()) + return "{}"; + + StringBuilder sb = new StringBuilder(); + sb.append('{'); + for (;;) { + Entry e = i.next(); + TypeK key = e.getKey(); + TypeV value = e.getValue(); + sb.append(key == this ? "(this Map)" : key); + sb.append('='); + sb.append(value == this ? "(this Map)" : value); + if( !i.hasNext()) + return sb.append('}').toString(); + sb.append(", "); + } + } + + // --- get ----------------------------------------------------------------- + /** Returns the value to which the specified key is mapped, or {@code null} + * if this map contains no mapping for the key. + *

More formally, if this map contains a mapping from a key {@code k} to + * a value {@code v} such that {@code key.equals(k)}, then this method + * returns {@code v}; otherwise it returns {@code null}. (There can be at + * most one such mapping.) + * @throws NullPointerException if the specified key is null */ + // Never returns a Prime nor a Tombstone. + @Override + public TypeV get( Object key ) { + final Object V = get_impl(this,_kvs,key); + assert !(V instanceof Prime); // Never return a Prime + assert V != TOMBSTONE; + return (TypeV)V; + } + + private static final Object get_impl( final NonBlockingIdentityHashMap topmap, final Object[] kvs, final Object key ) { + final int fullhash= hash (key); // throws NullPointerException if key is null + final int len = len (kvs); // Count of key/value pairs, reads kvs.length + final CHM chm = chm (kvs); // The CHM, for a volatile read below; reads slot 0 of kvs + + int idx = fullhash & (len-1); // First key hash + + // Main spin/reprobe loop, looking for a Key hit + int reprobe_cnt=0; + while( true ) { + // Probe table. Each read of 'val' probably misses in cache in a big + // table; hopefully the read of 'key' then hits in cache. + final Object K = key(kvs,idx); // Get key before volatile read, could be null + final Object V = val(kvs,idx); // Get value before volatile read, could be null or Tombstone or Prime + if( K == null ) return null; // A clear miss + + // We need a volatile-read here to preserve happens-before semantics on + // newly inserted Keys. If the Key body was written just before inserting + // into the table a Key-compare here might read the uninitialized Key body. + // Annoyingly this means we have to volatile-read before EACH key compare. + // . + // We also need a volatile-read between reading a newly inserted Value + // and returning the Value (so the user might end up reading the stale + // Value contents). Same problem as with keys - and the one volatile + // read covers both. 
+ final Object[] newkvs = chm._newkvs; // VOLATILE READ before key compare + + // Key-compare + if( K == key ) { + // Key hit! Check for no table-copy-in-progress + if( !(V instanceof Prime) ) // No copy? + return (V == TOMBSTONE) ? null : V; // Return the value + // Key hit - but slot is (possibly partially) copied to the new table. + // Finish the copy & retry in the new table. + return get_impl(topmap,chm.copy_slot_and_check(topmap,kvs,idx,key),key); // Retry in the new table + } + // get and put must have the same key lookup logic! But only 'put' + // needs to force a table-resize for a too-long key-reprobe sequence. + // Check for too-many-reprobes on get - and flip to the new table. + if( ++reprobe_cnt >= reprobe_limit(len) || // too many probes + K == TOMBSTONE ) // found a TOMBSTONE key, means no more keys in this table + return newkvs == null ? null : get_impl(topmap,topmap.help_copy(newkvs),key); // Retry in the new table + + idx = (idx+1)&(len-1); // Reprobe by 1! (could now prefetch) + } + } + + // --- putIfMatch --------------------------------------------------------- + // Put, Remove, PutIfAbsent, etc. Return the old value. If the returned + // value is equal to expVal (or expVal is NO_MATCH_OLD) then the put can be + // assumed to work (although might have been immediately overwritten). Only + // the path through copy_slot passes in an expected value of null, and + // putIfMatch only returns a null if passed in an expected null. 
+ private static final Object putIfMatch0(final NonBlockingIdentityHashMap topmap, final Object[] kvs, final Object key, final Object putval, final Object expVal ) { + assert putval != null; + assert !(putval instanceof Prime); + assert !(expVal instanceof Prime); + final int fullhash = hash (key); // throws NullPointerException if key null + final int len = len (kvs); // Count of key/value pairs, reads kvs.length + final CHM chm = chm (kvs); // Reads kvs[0] + int idx = fullhash & (len-1); + + // --- + // Key-Claim stanza: spin till we can claim a Key (or force a resizing). + int reprobe_cnt=0; + Object K=null, V=null; + Object[] newkvs=null; + while( true ) { // Spin till we get a Key slot + V = val(kvs,idx); // Get old value (before volatile read below!) + K = key(kvs,idx); // Get current key + if( K == null ) { // Slot is free? + // Found an empty Key slot - which means this Key has never been in + // this table. No need to put a Tombstone - the Key is not here! + if( putval == TOMBSTONE ) return TOMBSTONE; // Not-now & never-been in this table + if( expVal == MATCH_ANY ) return TOMBSTONE; // Will not match, even after K inserts + // Claim the null key-slot + if( CAS_key(kvs,idx, null, key ) ) { // Claim slot for Key + chm._slots.add(1); // Raise key-slots-used count + break; // Got it! + } + // CAS to claim the key-slot failed. + // + // This re-read of the Key points out an annoying short-coming of Java + // CAS. Most hardware CAS's report back the existing value - so that + // if you fail you have a *witness* - the value which caused the CAS to + // fail. The Java API turns this into a boolean destroying the + // witness. Re-reading does not recover the witness because another + // thread can write over the memory after the CAS. Hence we can be in + // the unfortunate situation of having a CAS fail *for cause* but + // having that cause removed by a later store. 
This turns a + // non-spurious-failure CAS (such as Azul has) into one that can + // apparently spuriously fail - and we avoid apparent spurious failure + // by not allowing Keys to ever change. + + // Volatile read, to force loads of K to retry despite JIT, otherwise + // it is legal to e.g. haul the load of "K = key(kvs,idx);" outside of + // this loop (since failed CAS ops have no memory ordering semantics). + int dummy = DUMMY_VOLATILE; + continue; + } + // Key slot was not null, there exists a Key here + + // We need a volatile-read here to preserve happens-before semantics on + // newly inserted Keys. If the Key body was written just before inserting + // into the table a Key-compare here might read the uninitialized Key body. + // Annoyingly this means we have to volatile-read before EACH key compare. + newkvs = chm._newkvs; // VOLATILE READ before key compare + + if( K == key ) + break; // Got it! + + // get and put must have the same key lookup logic! Lest 'get' give + // up looking too soon. + //topmap._reprobes.add(1); + if( ++reprobe_cnt >= reprobe_limit(len) || // too many probes or + K == TOMBSTONE ) { // found a TOMBSTONE key, means no more keys + // We simply must have a new table to do a 'put'. At this point a + // 'get' will also go to the new table (if any). We do not need + // to claim a key slot (indeed, we cannot find a free one to claim!). + newkvs = chm.resize(topmap,kvs); + if( expVal != null ) topmap.help_copy(newkvs); // help along an existing copy + return putIfMatch0(topmap, newkvs, key, putval, expVal); + } + + idx = (idx+1)&(len-1); // Reprobe! + } // End of spinning till we get a Key slot + + while ( true ) { // Spin till we insert a value + // --- + // Found the proper Key slot, now update the matching Value slot. We + // never put a null, so Value slots monotonically move from null to + // not-null (deleted Values use Tombstone). Thus if 'V' is null we + // fail this fast cutout and fall into the check for table-full. 
+ if( putval == V ) return V; // Fast cutout for no-change + + // See if we want to move to a new table (to avoid high average re-probe + // counts). We only check on the initial set of a Value from null to + // not-null (i.e., once per key-insert). Of course we got a 'free' check + // of newkvs once per key-compare (not really free, but paid-for by the + // time we get here). + if( newkvs == null && // New table-copy already spotted? + // Once per fresh key-insert check the hard way + ((V == null && chm.tableFull(reprobe_cnt,len)) || + // Or we found a Prime, but the JMM allowed reordering such that we + // did not spot the new table (very rare race here: the writing + // thread did a CAS of _newkvs then a store of a Prime. This thread + // reads the Prime, then reads _newkvs - but the read of Prime was so + // delayed (or the read of _newkvs was so accelerated) that they + // swapped and we still read a null _newkvs. The resize call below + // will do a CAS on _newkvs forcing the read. + V instanceof Prime) ) { + newkvs = chm.resize(topmap, kvs); // Force the new table copy to start + } + // See if we are moving to a new table. + // If so, copy our slot and retry in the new table. + if( newkvs != null ) { + return putIfMatch0(topmap, chm.copy_slot_and_check(topmap, kvs, idx, expVal), key, putval, expVal); + } + // --- + // We are finally prepared to update the existing table + assert !(V instanceof Prime); + + // Must match old, and we do not? Then bail out now. Note that either V + // or expVal might be TOMBSTONE. Also V can be null, if we've never + // inserted a value before. expVal can be null if we are called from + // copy_slot. + if( expVal != NO_MATCH_OLD && // Do we care about expected-Value at all? + V != expVal && // No instant match already? 
+ (expVal != MATCH_ANY || V == TOMBSTONE || V == null) && + !(V==null && expVal == TOMBSTONE) && // Match on null/TOMBSTONE combo + (expVal == null || !expVal.equals(V)) ) { // Expensive equals check at the last + return (V == null) ? TOMBSTONE : V; // Do not update! + } + + // Actually change the Value in the Key,Value pair + if( CAS_val(kvs, idx, V, putval ) ) break; + + // CAS failed + // Because we have no witness, we do not know why it failed. + // Indeed, by the time we look again the value under test might have flipped + // a thousand times and now be the expected value (despite the CAS failing). + // Check for the never-succeed condition of a Prime value and jump to any + // nested table, or else just re-run. + + // We would not need this load at all if CAS returned the value on which + // the CAS failed (AKA witness). The new CAS semantics are supported via + // VarHandle in JDK9. + V = val(kvs,idx); // Get new value + + // If a Prime'd value got installed, we need to re-run the put on the + // new table. Otherwise we lost the CAS to another racing put. + if( V instanceof Prime ) + return putIfMatch0(topmap, chm.copy_slot_and_check(topmap, kvs, idx, expVal), key, putval, expVal); + + // Simply retry from the start. + // NOTE: need the fence, since otherwise 'val(kvs,idx)' load could be hoisted + // out of loop. + int dummy = DUMMY_VOLATILE; + } + + // CAS succeeded - we did the update! + // Both normal put's and table-copy calls putIfMatch, but table-copy + // does not (effectively) increase the number of live k/v pairs. + if( expVal != null ) { + // Adjust sizes - a striped counter + if( (V == null || V == TOMBSTONE) && putval != TOMBSTONE ) chm._size.add( 1); + if( !(V == null || V == TOMBSTONE) && putval == TOMBSTONE ) chm._size.add(-1); + } + + // We won; we know the update happened as expected. + return (V==null && expVal!=null) ? 
TOMBSTONE : V; + } + + // --- help_copy --------------------------------------------------------- + // Help along an existing resize operation. This is just a fast cut-out + // wrapper, to encourage inlining for the fast no-copy-in-progress case. We + // always help the top-most table copy, even if there are nested table + // copies in progress. + private final Object[] help_copy( Object[] helper ) { + // Read the top-level KVS only once. We'll try to help this copy along, + // even if it gets promoted out from under us (i.e., the copy completes + // and another KVS becomes the top-level copy). + Object[] topkvs = _kvs; + CHM topchm = chm(topkvs); + if( topchm._newkvs == null ) return helper; // No copy in-progress + topchm.help_copy_impl(this,topkvs,false); + return helper; + } + + + // --- CHM ----------------------------------------------------------------- + // The control structure for the NonBlockingIdentityHashMap + private static final class CHM { + // Size in active K,V pairs + private final ConcurrentAutoTable _size; + public int size () { return (int)_size.get(); } + + // --- + // These next 2 fields are used in the resizing heuristics, to judge when + // it is time to resize or copy the table. Slots is a count of used-up + // key slots, and when it nears a large fraction of the table we probably + // end up reprobing too much. Last-resize-milli is the time since the + // last resize; if we are running back-to-back resizes without growing + // (because there are only a few live keys but many slots full of dead + // keys) then we need a larger table to cut down on the churn. + + // Count of used slots, to tell when table is full of dead unusable slots + private final ConcurrentAutoTable _slots; + public int slots() { return (int)_slots.get(); } + + // --- + // New mappings, used during resizing. + // The 'new KVs' array - created during a resize operation. This + // represents the new table being copied from the old one. 
It's the + // volatile variable that is read as we cross from one table to the next, + // to get the required memory orderings. It monotonically transits from + // null to set (once). + volatile Object[] _newkvs; + private static final AtomicReferenceFieldUpdater _newkvsUpdater = + AtomicReferenceFieldUpdater.newUpdater(CHM.class,Object[].class, "_newkvs"); + // Set the _next field if we can. + boolean CAS_newkvs( Object[] newkvs ) { + while( _newkvs == null ) + if( _newkvsUpdater.compareAndSet(this,null,newkvs) ) + return true; + return false; + } + + // Sometimes many threads race to create a new very large table. Only 1 + // wins the race, but the losers all allocate a junk large table with + // hefty allocation costs. Attempt to control the overkill here by + // throttling attempts to create a new table. I cannot really block here + // (lest I lose the non-blocking property) but late-arriving threads can + // give the initial resizing thread a little time to allocate the initial + // new table. The Right Long Term Fix here is to use array-lets and + // incrementally create the new very large array. In C I'd make the array + // with malloc (which would mmap under the hood) which would only eat + // virtual-address and not real memory - and after Somebody wins then we + // could in parallel initialize the array. Java does not allow + // un-initialized array creation (especially of ref arrays!). + volatile long _resizers; // count of threads attempting an initial resize + private static final AtomicLongFieldUpdater _resizerUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_resizers"); + + // --- + // Simple constructor + CHM( ConcurrentAutoTable size ) { + _size = size; + _slots= new ConcurrentAutoTable(); + } + + // --- tableFull --------------------------------------------------------- + // Heuristic to decide if this table is too full, and we should start a + // new table. 
Note that if a 'get' call has reprobed too many times and + // decided the table must be full, then always the estimate_sum must be + // high and we must report the table is full. If we do not, then we might + // end up deciding that the table is not full and inserting into the + // current table, while a 'get' has decided the same key cannot be in this + // table because of too many reprobes. The invariant is: + // slots.estimate_sum >= max_reprobe_cnt >= reprobe_limit(len) + private final boolean tableFull( int reprobe_cnt, int len ) { + return + // Do the cheap check first: we allow some number of reprobes always + reprobe_cnt >= REPROBE_LIMIT && + // More expensive check: see if the table is > 1/4 full. + _slots.estimate_get() >= reprobe_limit(len); + } + + // --- resize ------------------------------------------------------------ + // Resizing after too many probes. "How Big???" heuristics are here. + // Callers will (not this routine) will 'help_copy' any in-progress copy. + // Since this routine has a fast cutout for copy-already-started, callers + // MUST 'help_copy' lest we have a path which forever runs through + // 'resize' only to discover a copy-in-progress which never progresses. + private final Object[] resize( NonBlockingIdentityHashMap topmap, Object[] kvs) { + assert chm(kvs) == this; + + // Check for resize already in progress, probably triggered by another thread + Object[] newkvs = _newkvs; // VOLATILE READ + if( newkvs != null ) // See if resize is already in progress + return newkvs; // Use the new table already + + // No copy in-progress, so start one. First up: compute new table size. + int oldlen = len(kvs); // Old count of K,V pairs allowed + int sz = size(); // Get current table count of active K,V pairs + int newsz = sz; // First size estimate + + // Heuristic to determine new size. We expect plenty of dead-slots-with-keys + // and we need some decent padding to avoid endless reprobing. 
+ if( sz >= (oldlen>>2) ) { // If we are >25% full of keys then... + newsz = oldlen<<1; // Double size + if( sz >= (oldlen>>1) ) // If we are >50% full of keys then... + newsz = oldlen<<2; // Double double size + } + // This heuristic in the next 2 lines leads to a much denser table + // with a higher reprobe rate + //if( sz >= (oldlen>>1) ) // If we are >50% full of keys then... + // newsz = oldlen<<1; // Double size + + // Last (re)size operation was very recent? Then double again; slows + // down resize operations for tables subject to a high key churn rate. + long tm = System.currentTimeMillis(); + long q=0; + if( newsz <= oldlen && // New table would shrink or hold steady? + tm <= topmap._last_resize_milli+10000 && // Recent resize (less than 1 sec ago) + (q=_slots.estimate_get()) >= (sz<<1) ) // 1/2 of keys are dead? + newsz = oldlen<<1; // Double the existing size + + // Do not shrink, ever + if( newsz < oldlen ) newsz = oldlen; + + // Convert to power-of-2 + int log2; + for( log2=MIN_SIZE_LOG; (1<>20/*megs*/; + if( r >= 2 && megs > 0 ) { // Already 2 guys trying; wait and see + newkvs = _newkvs; // Between dorking around, another thread did it + if( newkvs != null ) // See if resize is already in progress + return newkvs; // Use the new table already + // TODO - use a wait with timeout, so we'll wakeup as soon as the new table + // is ready, or after the timeout in any case. + //synchronized( this ) { wait(8*megs); } // Timeout - we always wakeup + // For now, sleep a tad and see if the 2 guys already trying to make + // the table actually get around to making it happen. + try { Thread.sleep(8l*megs); } catch( Exception e ) { } + } + // Last check, since the 'new' below is expensive and there is a chance + // that another thread slipped in a new thread while we ran the heuristic. 
+ newkvs = _newkvs; + if( newkvs != null ) // See if resize is already in progress + return newkvs; // Use the new table already + + // Double size for K,V pairs, add 1 for CHM + newkvs = new Object[((1< _copyIdxUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_copyIdx"); + + // Work-done reporting. Used to efficiently signal when we can move to + // the new table. From 0 to len(oldkvs) refers to copying from the old + // table to the new. + volatile long _copyDone= 0; + static private final AtomicLongFieldUpdater _copyDoneUpdater = + AtomicLongFieldUpdater.newUpdater(CHM.class, "_copyDone"); + + // --- help_copy_impl ---------------------------------------------------- + // Help along an existing resize operation. We hope its the top-level + // copy (it was when we started) but this CHM might have been promoted out + // of the top position. + private final void help_copy_impl( NonBlockingIdentityHashMap topmap, Object[] oldkvs, boolean copy_all ) { + assert chm(oldkvs) == this; + Object[] newkvs = _newkvs; + assert newkvs != null; // Already checked by caller + int oldlen = len(oldkvs); // Total amount to copy + final int MIN_COPY_WORK = Math.min(oldlen,1024); // Limit per-thread work + + // --- + int panic_start = -1; + int copyidx=-9999; // Fool javac to think it's initialized + while( _copyDone < oldlen ) { // Still needing to copy? + // Carve out a chunk of work. The counter wraps around so every + // thread eventually tries to copy every slot repeatedly. + + // We "panic" if we have tried TWICE to copy every slot - and it still + // has not happened. i.e., twice some thread somewhere claimed they + // would copy 'slot X' (by bumping _copyIdx) but they never claimed to + // have finished (by bumping _copyDone). Our choices become limited: + // we can wait for the work-claimers to finish (and become a blocking + // algorithm) or do the copy work ourselves. Tiny tables with huge + // thread counts trying to copy the table often 'panic'. 
+ if( panic_start == -1 ) { // No panic? + copyidx = (int)_copyIdx; + while( copyidx < (oldlen<<1) && // 'panic' check + !_copyIdxUpdater.compareAndSet(this,copyidx,copyidx+MIN_COPY_WORK) ) + copyidx = (int)_copyIdx; // Re-read + if( !(copyidx < (oldlen<<1)) ) // Panic! + panic_start = copyidx; // Record where we started to panic-copy + } + + // We now know what to copy. Try to copy. + int workdone = 0; + for( int i=0; i 0 ) // Report work-done occasionally + copy_check_and_promote( topmap, oldkvs, workdone );// See if we can promote + //for( int i=0; i 0 ) { + while( !_copyDoneUpdater.compareAndSet(this,copyDone,copyDone+workdone) ) { + copyDone = _copyDone; // Reload, retry + assert (copyDone+workdone) <= oldlen; + } + } + + // Check for copy being ALL done, and promote. Note that we might have + // nested in-progress copies and manage to finish a nested copy before + // finishing the top-level copy. We only promote top-level copies. + if( copyDone+workdone == oldlen && // Ready to promote this table? + topmap._kvs == oldkvs && // Looking at the top-level table? + // Attempt to promote + topmap.CAS_kvs(oldkvs,_newkvs) ) { + topmap._last_resize_milli = System.currentTimeMillis(); // Record resize time for next check + } + } + + // --- copy_slot --------------------------------------------------------- + // Copy one K/V pair from oldkvs[i] to newkvs. Returns true if we can + // confirm that we set an old-table slot to TOMBPRIME, and only returns after + // updating the new table. We need an accurate confirmed-copy count so + // that we know when we can promote (if we promote the new table too soon, + // other threads may 'miss' on values not-yet-copied from the old table). 
+ // We don't allow any direct updates on the new table, unless they first + // happened to the old table - so that any transition in the new table from + // null to not-null must have been from a copy_slot (or other old-table + // overwrite) and not from a thread directly writing in the new table. + private boolean copy_slot( NonBlockingIdentityHashMap topmap, int idx, Object[] oldkvs, Object[] newkvs ) { + // Blindly set the key slot from null to TOMBSTONE, to eagerly stop + // fresh put's from inserting new values in the old table when the old + // table is mid-resize. We don't need to act on the results here, + // because our correctness stems from box'ing the Value field. Slamming + // the Key field is a minor speed optimization. + Object key; + while( (key=key(oldkvs,idx)) == null ) + CAS_key(oldkvs,idx, null, TOMBSTONE); + + // --- + // Prevent new values from appearing in the old table. + // Box what we see in the old table, to prevent further updates. + Object oldval = val(oldkvs,idx); // Read OLD table + while( !(oldval instanceof Prime) ) { + final Prime box = (oldval == null || oldval == TOMBSTONE) ? TOMBPRIME : new Prime(oldval); + if( CAS_val(oldkvs,idx,oldval,box) ) { // CAS down a box'd version of oldval + // If we made the Value slot hold a TOMBPRIME, then we both + // prevented further updates here but also the (absent) + // oldval is vacuously available in the new table. We + // return with true here: any thread looking for a value for + // this key can correctly go straight to the new table and + // skip looking in the old table. + if( box == TOMBPRIME ) + return true; + // Otherwise we boxed something, but it still needs to be + // copied into the new table. + oldval = box; // Record updated oldval + break; // Break loop; oldval is now boxed by us + } + oldval = val(oldkvs,idx); // Else try, try again + } + if( oldval == TOMBPRIME ) return false; // Copy already complete here! 
+ + // --- + // Copy the value into the new table, but only if we overwrite a null. + // If another value is already in the new table, then somebody else + // wrote something there and that write is happens-after any value that + // appears in the old table. + Object old_unboxed = ((Prime)oldval)._V; + assert old_unboxed != TOMBSTONE; + putIfMatch0(topmap, newkvs, key, old_unboxed, null); + + // --- + // Finally, now that any old value is exposed in the new table, we can + // forever hide the old-table value by slapping a TOMBPRIME down. This + // will stop other threads from uselessly attempting to copy this slot + // (i.e., it's a speed optimization not a correctness issue). + while( oldval != TOMBPRIME && !CAS_val(oldkvs,idx,oldval,TOMBPRIME) ) + oldval = val(oldkvs,idx); + + return oldval != TOMBPRIME; // True if we slammed the TOMBPRIME down + } // end copy_slot + } // End of CHM + + + // --- Snapshot ------------------------------------------------------------ + // The main class for iterating over the NBHM. It "snapshots" a clean + // view of the K/V array. + private class SnapshotV implements Iterator, Enumeration { + final Object[] _sskvs; + public SnapshotV() { + while( true ) { // Verify no table-copy-in-progress + Object[] topkvs = _kvs; + CHM topchm = chm(topkvs); + if( topchm._newkvs == null ) { // No table-copy-in-progress + // The "linearization point" for the iteration. Every key in this + // table will be visited, but keys added later might be skipped or + // even be added to a following table (also not iterated over). + _sskvs = topkvs; + break; + } + // Table copy in-progress - so we cannot get a clean iteration. We + // must help finish the table copy before we can start iterating. 
+ topchm.help_copy_impl(NonBlockingIdentityHashMap.this,topkvs,true); + } + // Warm-up the iterator + next(); + } + int length() { return len(_sskvs); } + Object key(int idx) { return NonBlockingIdentityHashMap.key(_sskvs,idx); } + private int _idx; // Varies from 0-keys.length + private Object _nextK, _prevK; // Last 2 keys found + private TypeV _nextV, _prevV; // Last 2 values found + public boolean hasNext() { return _nextV != null; } + public TypeV next() { + // 'next' actually knows what the next value will be - it had to + // figure that out last go-around lest 'hasNext' report true and + // some other thread deleted the last value. Instead, 'next' + // spends all its effort finding the key that comes after the + // 'next' key. + if( _idx != 0 && _nextV == null ) throw new NoSuchElementException(); + _prevK = _nextK; // This will become the previous key + _prevV = _nextV; // This will become the previous value + _nextV = null; // We have no more next-key + // Attempt to set <_nextK,_nextV> to the next K,V pair. + // _nextV is the trigger: stop searching when it is != null + while( _idx, but the JDK always removes by key, even when the value has changed. + removeKey(); + } + + public TypeV nextElement() { return next(); } + public boolean hasMoreElements() { return hasNext(); } + } + + /** Returns an enumeration of the values in this table. + * @return an enumeration of the values in this table + * @see #values() */ + public Enumeration elements() { return new SnapshotV(); } + + // --- values -------------------------------------------------------------- + /** Returns a {@link Collection} view of the values contained in this map. + * The collection is backed by the map, so changes to the map are reflected + * in the collection, and vice-versa. The collection supports element + * removal, which removes the corresponding mapping from this map, via the + * Iterator.remove, Collection.remove, + * removeAll, retainAll, and clear operations. 
+ * It does not support the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator that + * will never throw {@link ConcurrentModificationException}, and guarantees + * to traverse elements as they existed upon construction of the iterator, + * and may (but is not guaranteed to) reflect any modifications subsequent + * to construction. */ + @Override + public Collection values() { + return new AbstractCollection() { + @Override public void clear ( ) { NonBlockingIdentityHashMap.this.clear ( ); } + @Override public int size ( ) { return NonBlockingIdentityHashMap.this.size ( ); } + @Override public boolean contains( Object v ) { return NonBlockingIdentityHashMap.this.containsValue(v); } + @Override public Iterator iterator() { return new SnapshotV(); } + }; + } + + // --- keySet -------------------------------------------------------------- + private class SnapshotK implements Iterator, Enumeration { + final SnapshotV _ss; + public SnapshotK() { _ss = new SnapshotV(); } + public void remove() { _ss.removeKey(); } + public TypeK next() { _ss.next(); return (TypeK)_ss._prevK; } + public boolean hasNext() { return _ss.hasNext(); } + public TypeK nextElement() { return next(); } + public boolean hasMoreElements() { return hasNext(); } + } + + /** Returns an enumeration of the keys in this table. + * @return an enumeration of the keys in this table + * @see #keySet() */ + public Enumeration keys() { return new SnapshotK(); } + + /** Returns a {@link Set} view of the keys contained in this map. The set + * is backed by the map, so changes to the map are reflected in the set, + * and vice-versa. The set supports element removal, which removes the + * corresponding mapping from this map, via the Iterator.remove, + * Set.remove, removeAll, retainAll, and + * clear operations. It does not support the add or + * addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator that + * will never throw {@link ConcurrentModificationException}, and guarantees + * to traverse elements as they existed upon construction of the iterator, + * and may (but is not guaranteed to) reflect any modifications subsequent + * to construction. */ + @Override + public Set keySet() { + return new AbstractSet () { + @Override public void clear ( ) { NonBlockingIdentityHashMap.this.clear ( ); } + @Override public int size ( ) { return NonBlockingIdentityHashMap.this.size ( ); } + @Override public boolean contains( Object k ) { return NonBlockingIdentityHashMap.this.containsKey(k); } + @Override public boolean remove ( Object k ) { return NonBlockingIdentityHashMap.this.remove (k) != null; } + @Override public Iterator iterator() { return new SnapshotK(); } + }; + } + + + // --- entrySet ------------------------------------------------------------ + // Warning: Each call to 'next' in this iterator constructs a new NBHMEntry. + private class NBHMEntry extends AbstractEntry { + NBHMEntry( final TypeK k, final TypeV v ) { super(k,v); } + public TypeV setValue(final TypeV val) { + if( val == null ) throw new NullPointerException(); + _val = val; + return put(_key, val); + } + } + + private class SnapshotE implements Iterator> { + final SnapshotV _ss; + public SnapshotE() { _ss = new SnapshotV(); } + public void remove() { + // NOTE: it would seem logical that entry removal will semantically mean removing the matching pair , but + // the JDK always removes by key, even when the value has changed. + _ss.removeKey(); + } + public Map.Entry next() { _ss.next(); return new NBHMEntry((TypeK)_ss._prevK,_ss._prevV); } + public boolean hasNext() { return _ss.hasNext(); } + } + + /** Returns a {@link Set} view of the mappings contained in this map. The + * set is backed by the map, so changes to the map are reflected in the + * set, and vice-versa. 
The set supports element removal, which removes + * the corresponding mapping from the map, via the + * Iterator.remove, Set.remove, removeAll, + * retainAll, and clear operations. It does not support + * the add or addAll operations. + * + *

The view's iterator is a "weakly consistent" iterator + * that will never throw {@link ConcurrentModificationException}, + * and guarantees to traverse elements as they existed upon + * construction of the iterator, and may (but is not guaranteed to) + * reflect any modifications subsequent to construction. + * + *

Warning: the iterator associated with this Set + * requires the creation of {@link java.util.Map.Entry} objects with each + * iteration. The {@link NonBlockingIdentityHashMap} does not normally create or + * using {@link java.util.Map.Entry} objects so they will be created soley + * to support this iteration. Iterating using {@link Map#keySet} or {@link + * Map#values} will be more efficient. + */ + @Override + public Set> entrySet() { + return new AbstractSet>() { + @Override public void clear ( ) { NonBlockingIdentityHashMap.this.clear( ); } + @Override public int size ( ) { return NonBlockingIdentityHashMap.this.size ( ); } + @Override public boolean remove( final Object o ) { + if( !(o instanceof Map.Entry)) return false; + final Map.Entry e = (Map.Entry)o; + return NonBlockingIdentityHashMap.this.remove(e.getKey(), e.getValue()); + } + @Override public boolean contains(final Object o) { + if( !(o instanceof Map.Entry)) return false; + final Map.Entry e = (Map.Entry)o; + TypeV v = get(e.getKey()); + return v != null && v.equals(e.getValue()); + } + @Override public Iterator> iterator() { return new SnapshotE(); } + }; + } + + // --- writeObject ------------------------------------------------------- + // Write a NBIHM to a stream + private void writeObject(java.io.ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); // Nothing to write + for( Object K : keySet() ) { + final Object V = get(K); // Do an official 'get' + s.writeObject(K); // Write the pair + s.writeObject(V); + } + s.writeObject(null); // Sentinel to indicate end-of-data + s.writeObject(null); + } + + // --- readObject -------------------------------------------------------- + // Read a NBIHM from a stream + private void readObject(java.io.ObjectInputStream s) throws IOException, ClassNotFoundException { + s.defaultReadObject(); // Read nothing + initialize(MIN_SIZE); + for(;;) { + final TypeK K = (TypeK) s.readObject(); + final TypeV V = (TypeV) s.readObject(); + if( K == null ) 
break; + put(K,V); // Insert with an offical put + } + } + +} // End NonBlockingIdentityHashMap class diff --git a/netty-jctools/src/main/java/org/jctools/maps/NonBlockingSetInt.java b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingSetInt.java new file mode 100644 index 0000000..2924a49 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/maps/NonBlockingSetInt.java @@ -0,0 +1,476 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.maps; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; + +import java.io.IOException; +import java.io.Serializable; +import java.lang.reflect.Field; +import java.util.AbstractSet; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicInteger; +import org.jctools.util.RangeUtil; + +/** + * A multi-threaded bit-vector set, implemented as an array of primitive + * {@code longs}. All operations are non-blocking and multi-threaded safe. + * {@link #contains(int)} calls are roughly the same speed as a {load, mask} + * sequence. {@link #add(int)} and {@link #remove(int)} calls are a tad more + * expensive than a {load, mask, store} sequence because they must use a CAS. + * The bit-vector is auto-sizing. + * + *

General note of caution: The Set API allows the use of {@link Integer} + * with silent autoboxing - which can be very expensive if many calls are + * being made. Since autoboxing is silent you may not be aware that this is + * going on. The built-in API takes lower-case {@code ints} and is much more + * efficient. + * + *

Space: space is used in proportion to the largest element, as opposed to + * the number of elements (as is the case with hash-table based Set + * implementations). Space is approximately (largest_element/8 + 64) bytes. + * + * The implementation is a simple bit-vector using CAS for update. + * + * @since 1.5 + * @author Cliff Click + */ +public class NonBlockingSetInt extends AbstractSet implements Serializable { + private static final long serialVersionUID = 1234123412341234123L; + + // --- Bits to allow atomic update of the NBSI + private static final long _nbsi_offset = fieldOffset(NonBlockingSetInt.class, "_nbsi"); + + private final boolean CAS_nbsi( NBSI old, NBSI nnn ) { + return UNSAFE.compareAndSwapObject(this, _nbsi_offset, old, nnn ); + } + + // The actual Set of Joy, which changes during a resize event. The + // Only Field for this class, so I can atomically change the entire + // set implementation with a single CAS. + private transient NBSI _nbsi; + + /** Create a new empty bit-vector */ + public NonBlockingSetInt( ) { + _nbsi = new NBSI(63, new ConcurrentAutoTable(), this); // The initial 1-word set + } + + /** + * Add {@code i} to the set. Uppercase {@link Integer} version of add, + * requires auto-unboxing. When possible use the {@code int} version of + * {@link #add(int)} for efficiency. + * @throws IllegalArgumentException if i is negative. + * @return true if i was added to the set. + */ + public boolean add ( final Integer i ) { + return add(i.intValue()); + } + /** + * Test if {@code o} is in the set. This is the uppercase {@link Integer} + * version of contains, requires a type-check and auto-unboxing. When + * possible use the {@code int} version of {@link #contains(int)} for + * efficiency. + * @return true if i was in the set. + */ + public boolean contains( final Object o ) { + return o instanceof Integer && contains(((Integer) o).intValue()); + } + /** + * Remove {@code o} from the set. 
This is the uppercase {@link Integer} + * version of remove, requires a type-check and auto-unboxing. When + * possible use the {@code int} version of {@link #remove(int)} for + * efficiency. + * @return true if i was removed to the set. + */ + public boolean remove( final Object o ) { + return o instanceof Integer && remove(((Integer) o).intValue()); + } + + /** + * Add {@code i} to the set. This is the lower-case '{@code int}' version + * of {@link #add} - no autoboxing. Negative values throw + * IllegalArgumentException. + * @throws IllegalArgumentException if i is negative. + * @return true if i was added to the set. + */ + public boolean add( final int i ) { + RangeUtil.checkPositiveOrZero(i, "i"); + return _nbsi.add(i); + } + /** + * Test if {@code i} is in the set. This is the lower-case '{@code int}' + * version of {@link #contains} - no autoboxing. + * @return true if i was int the set. + */ + public boolean contains( final int i ) { return i >= 0 && _nbsi.contains(i); } + /** + * Remove {@code i} from the set. This is the fast lower-case '{@code int}' + * version of {@link #remove} - no autoboxing. + * @return true if i was added to the set. + */ + public boolean remove ( final int i ) { return i >= 0 && _nbsi.remove(i); } + + /** + * Current count of elements in the set. Due to concurrent racing updates, + * the size is only ever approximate. Updates due to the calling thread are + * immediately visible to calling thread. + * @return count of elements. + */ + public int size ( ) { return _nbsi.size( ); } + /** Approx largest element in set; at least as big (but max might be smaller). */ + public int length() { return _nbsi._bits.length<<6; } + /** Empty the bitvector. */ + public void clear ( ) { + NBSI cleared = new NBSI(63, new ConcurrentAutoTable(), this); // An empty initial NBSI + while( !CAS_nbsi( _nbsi, cleared ) ) // Spin until clear works + ; + } + + /** Verbose printout of internal structure for debugging. 
*/ + public void print() { _nbsi.print(0); } + + /** + * Standard Java {@link Iterator}. Not very efficient because it + * auto-boxes the returned values. + */ + public Iterator iterator( ) { return new iter(); } + + private class iter implements Iterator { + NBSI _nbsi2; + int _idx = -1; + int _prev = -1; + iter() { _nbsi2 = _nbsi; advance(); } + public boolean hasNext() { return _idx != -2; } + private void advance() { + while( true ) { + _idx++; // Next index + while( (_idx>>6) >= _nbsi2._bits.length ) { // Index out of range? + if( _nbsi2._new == null ) { // New table? + _idx = -2; // No, so must be all done + return; // + } + _nbsi2 = _nbsi2._new; // Carry on, in the new table + } + if( _nbsi2.contains(_idx) ) return; + } + } + public Integer next() { + if( _idx == -1 ) throw new NoSuchElementException(); + _prev = _idx; + advance(); + return _prev; + } + public void remove() { + if( _prev == -1 ) throw new IllegalStateException(); + _nbsi2.remove(_prev); + _prev = -1; + } + } + + // --- writeObject ------------------------------------------------------- + // Write a NBSI to a stream + private void writeObject(java.io.ObjectOutputStream s) throws IOException { + s.defaultWriteObject(); // Nothing to write + final NBSI nbsi = _nbsi; // The One Field is transient + final int len = _nbsi._bits.length<<6; + s.writeInt(len); // Write max element + for( int i=0; i= 0 && idx < ary.length; + return _Lbase + (idx * (long)_Lscale); + } + private final boolean CAS( int idx, long old, long nnn ) { + return UNSAFE.compareAndSwapLong( _bits, rawIndex(_bits, idx), old, nnn ); + } + + // --- Resize + // The New Table, only set once to non-zero during a resize. + // Must be atomically set. 
+ private NBSI _new; + private static final long _new_offset = fieldOffset(NBSI.class, "_new"); + + private final boolean CAS_new( NBSI nnn ) { + return UNSAFE.compareAndSwapObject(this, _new_offset, null, nnn ); + } + + private transient final AtomicInteger _copyIdx; // Used to count bits started copying + private transient final AtomicInteger _copyDone; // Used to count words copied in a resize operation + private transient final int _sum_bits_length; // Sum of all nested _bits.lengths + + private static final long mask( int i ) { return 1L<<(i&63); } + + // I need 1 free bit out of 64 to allow for resize. I do this by stealing + // the high order bit - but then I need to do something with adding element + // number 63 (and friends). I could use a mod63 function but it's more + // efficient to handle the mod-64 case as an exception. + // + // Every 64th bit is put in it's own recursive bitvector. If the low 6 bits + // are all set, we shift them off and recursively operate on the _nbsi64 set. + private final NBSI _nbsi64; + + private NBSI( int max_elem, ConcurrentAutoTable ctr, NonBlockingSetInt nonb ) { + super(); + _non_blocking_set_int = nonb; + _size = ctr; + _copyIdx = ctr == null ? null : new AtomicInteger(); + _copyDone = ctr == null ? null : new AtomicInteger(); + // The main array of bits + _bits = new long[(int)(((long)max_elem+63)>>>6)]; + // Every 64th bit is moved off to it's own subarray, so that the + // sign-bit is free for other purposes + _nbsi64 = ((max_elem+1)>>>6) == 0 ? null : new NBSI((max_elem+1)>>>6, null, null); + _sum_bits_length = _bits.length + (_nbsi64==null ? 0 : _nbsi64._sum_bits_length); + } + + // Lower-case 'int' versions - no autoboxing, very fast. + // 'i' is known positive. + public boolean add( final int i ) { + // Check for out-of-range for the current size bit vector. + // If so we need to grow the bit vector. + if( (i>>6) >= _bits.length ) + return install_larger_new_bits(i). 
// Install larger pile-o-bits (duh) + help_copy().add(i); // Finally, add to the new table + + // Handle every 64th bit via using a nested array + NBSI nbsi = this; // The bit array being added into + int j = i; // The bit index being added + while( (j&63) == 63 ) { // Bit 64? (low 6 bits are all set) + nbsi = nbsi._nbsi64; // Recurse + j = j>>6; // Strip off low 6 bits (all set) + } + + final long mask = mask(j); + long old; + do { + old = nbsi._bits[j>>6]; // Read old bits + if( old < 0 ) // Not mutable? + // Not mutable: finish copy of word, and retry on copied word + return help_copy_impl(i).help_copy().add(i); + if( (old & mask) != 0 ) return false; // Bit is already set? + } while( !nbsi.CAS( j>>6, old, old | mask ) ); + _size.add(1); + return true; + } + + public boolean remove( final int i ) { + if( (i>>6) >= _bits.length ) // Out of bounds? Not in this array! + return _new != null && help_copy().remove(i); + + // Handle every 64th bit via using a nested array + NBSI nbsi = this; // The bit array being added into + int j = i; // The bit index being added + while( (j&63) == 63 ) { // Bit 64? (low 6 bits are all set) + nbsi = nbsi._nbsi64; // Recurse + j = j>>6; // Strip off low 6 bits (all set) + } + + final long mask = mask(j); + long old; + do { + old = nbsi._bits[j>>6]; // Read old bits + if( old < 0 ) // Not mutable? + // Not mutable: finish copy of word, and retry on copied word + return help_copy_impl(i).help_copy().remove(i); + if( (old & mask) == 0 ) return false; // Bit is already clear? + } while( !nbsi.CAS( j>>6, old, old & ~mask ) ); + _size.add(-1); + return true; + } + + public boolean contains( final int i ) { + if( (i>>6) >= _bits.length ) // Out of bounds? Not in this array! + return _new != null && help_copy().contains(i); + + // Handle every 64th bit via using a nested array + NBSI nbsi = this; // The bit array being added into + int j = i; // The bit index being added + while( (j&63) == 63 ) { // Bit 64? 
(low 6 bits are all set) + nbsi = nbsi._nbsi64; // Recurse + j = j>>6; // Strip off low 6 bits (all set) + } + + final long mask = mask(j); + long old = nbsi._bits[j>>6]; // Read old bits + if( old < 0 ) // Not mutable? + // Not mutable: finish copy of word, and retry on copied word + return help_copy_impl(i).help_copy().contains(i); + // Yes mutable: test & return bit + return (old & mask) != 0; + } + + public int size() { return (int)_size.get(); } + + // Must grow the current array to hold an element of size i + private NBSI install_larger_new_bits( final int i ) { + if( _new == null ) { + // Grow by powers of 2, to avoid minor grow-by-1's. + // Note: must grow by exact powers-of-2 or the by-64-bit trick doesn't work right + int sz = (_bits.length<<6)<<1; + // CAS to install a new larger size. Did it work? Did it fail? We + // don't know and don't care. Only One can be installed, so if + // another thread installed a too-small size, we can't help it - we + // must simply install our new larger size as a nested-resize table. + CAS_new(new NBSI(sz, _size, _non_blocking_set_int)); + } + // Return self for 'fluid' programming style + return this; + } + + // Help any top-level NBSI to copy until completed. + // Always return the _new version of *this* NBSI, in case we're nested. + private NBSI help_copy() { + // Pick some words to help with - but only help copy the top-level NBSI. + // Nested NBSI waits until the top is done before we start helping. + NBSI top_nbsi = _non_blocking_set_int._nbsi; + final int HELP = 8; // Tuning number: how much copy pain are we willing to inflict? + // We "help" by forcing individual bit indices to copy. However, bits + // come in lumps of 64 per word, so we just advance the bit counter by 64's. + int idx = top_nbsi._copyIdx.getAndAdd(64*HELP); + for( int i=0; i>6; // Strip off low 6 bits (all set) + } + + // Transit from state 1: word is not immutable yet + // Immutable is in bit 63, the sign bit. 
+ long bits = old._bits[j>>6]; + while( bits >= 0 ) { // Still in state (1)? + long oldbits = bits; + bits |= mask(63); // Target state of bits: sign-bit means immutable + if( old.CAS( j>>6, oldbits, bits ) ) { + if( oldbits == 0 ) _copyDone.addAndGet(1); + break; // Success - old array word is now immutable + } + bits = old._bits[j>>6]; // Retry if CAS failed + } + + // Transit from state 2: non-zero in old and zero in new + if( bits != mask(63) ) { // Non-zero in old? + long new_bits = nnn._bits[j>>6]; + if( new_bits == 0 ) { // New array is still zero + new_bits = bits & ~mask(63); // Desired new value: a mutable copy of bits + // One-shot CAS attempt, no loop, from 0 to non-zero. + // If it fails, somebody else did the copy for us + if( !nnn.CAS( j>>6, 0, new_bits ) ) + new_bits = nnn._bits[j>>6]; // Since it failed, get the new value + assert new_bits != 0; + } + + // Transit from state 3: non-zero in old and non-zero in new + // One-shot CAS attempt, no loop, from non-zero to 0 (but immutable) + if( old.CAS( j>>6, bits, mask(63) ) ) + _copyDone.addAndGet(1); // One more word finished copying + } + + // Now in state 4: zero (and immutable) in old + + // Return the self bitvector for 'fluid' programming style + return this; + } + + private void print( int d, String msg ) { + for( int i=0; i extends AbstractQueue implements MessagePassingQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte 
b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + // * drop 8b as object header acts as padding and is >= 8b * +} + +// $gen:ordered-fields +abstract class BaseLinkedQueueProducerNodeRef extends BaseLinkedQueuePad0 +{ + final static long P_NODE_OFFSET = fieldOffset(BaseLinkedQueueProducerNodeRef.class, "producerNode"); + + private volatile LinkedQueueNode producerNode; + + final void spProducerNode(LinkedQueueNode newValue) + { + UNSAFE.putObject(this, P_NODE_OFFSET, newValue); + } + + final void soProducerNode(LinkedQueueNode newValue) + { + UNSAFE.putOrderedObject(this, P_NODE_OFFSET, newValue); + } + + final LinkedQueueNode lvProducerNode() + { + return producerNode; + } + + final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) + { + return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); + } + + final LinkedQueueNode lpProducerNode() + { + return producerNode; + } +} + +abstract class BaseLinkedQueuePad1 extends BaseLinkedQueueProducerNodeRef +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte 
b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +//$gen:ordered-fields +abstract class BaseLinkedQueueConsumerNodeRef extends BaseLinkedQueuePad1 +{ + private final static long C_NODE_OFFSET = fieldOffset(BaseLinkedQueueConsumerNodeRef.class,"consumerNode"); + + private LinkedQueueNode consumerNode; + + final void spConsumerNode(LinkedQueueNode newValue) + { + consumerNode = newValue; + } + + @SuppressWarnings("unchecked") + final LinkedQueueNode lvConsumerNode() + { + return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, C_NODE_OFFSET); + } + + final LinkedQueueNode lpConsumerNode() + { + return consumerNode; + } +} + +abstract class BaseLinkedQueuePad2 extends BaseLinkedQueueConsumerNodeRef +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +/** + * A base data structure for concurrent linked queues. For convenience also pulled in common single consumer + * methods since at this time there's no plan to implement MC. 
+ */ +abstract class BaseLinkedQueue extends BaseLinkedQueuePad2 +{ + + @Override + public final Iterator iterator() + { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() + { + return this.getClass().getName(); + } + + protected final LinkedQueueNode newNode() + { + return new LinkedQueueNode(); + } + + protected final LinkedQueueNode newNode(E e) + { + return new LinkedQueueNode(e); + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * This is an O(n) operation as we run through all the nodes and count them.
+ * The accuracy of the value returned by this method is subject to races with producer/consumer threads. In + * particular when racing with the consumer thread this method may under estimate the size.
+ * + * @see java.util.Queue#size() + */ + @Override + public final int size() + { + // Read consumer first, this is important because if the producer is node is 'older' than the consumer + // the consumer may overtake it (consume past it) invalidating the 'snapshot' notion of size. + LinkedQueueNode chaserNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + int size = 0; + // must chase the nodes all the way to the producer node, but there's no need to count beyond expected head. + while (chaserNode != producerNode && // don't go passed producer node + chaserNode != null && // stop at last node + size < Integer.MAX_VALUE) // stop at max int + { + LinkedQueueNode next; + next = chaserNode.lvNext(); + // check if this node has been consumed, if so return what we have + if (next == chaserNode) + { + return size; + } + chaserNode = next; + size++; + } + return size; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Queue is empty when producerNode is the same as consumerNode. An alternative implementation would be to + * observe the producerNode.value is null, which also means an empty queue because only the + * consumerNode.value is allowed to be null. + * + * @see MessagePassingQueue#isEmpty() + */ + @Override + public boolean isEmpty() + { + LinkedQueueNode consumerNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + return consumerNode == producerNode; + } + + protected E getSingleConsumerNodeValue(LinkedQueueNode currConsumerNode, LinkedQueueNode nextNode) + { + // we have to null out the value because we are going to hang on to the node + final E nextValue = nextNode.getAndNullValue(); + + // Fix up the next ref of currConsumerNode to prevent promoted nodes from keeping new ones alive. + // We use a reference to self instead of null because null is already a meaningful value (the next of + // producer node is null). + currConsumerNode.soNext(currConsumerNode); + spConsumerNode(nextNode); + // currConsumerNode is now no longer referenced and can be collected + return nextValue; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Poll is allowed from a SINGLE thread.
+ * Poll is potentially blocking here as the {@link Queue#poll()} does not allow returning {@code null} if the queue is not + * empty. This is very different from the original Vyukov guarantees. See {@link #relaxedPoll()} for the original + * semantics.
+ * Poll reads {@code consumerNode.next} and: + *

    + *
  1. If it is {@code null} AND the queue is empty return {@code null}, if queue is not empty spin wait for + * value to become visible. + *
  2. If it is not {@code null} set it as the consumer node and return it's now evacuated value. + *
+ * This means the consumerNode.value is always {@code null}, which is also the starting point for the queue. + * Because {@code null} values are not allowed to be offered this is the only node with it's value set to + * {@code null} at any one time. + * + * @see MessagePassingQueue#poll() + * @see java.util.Queue#poll() + */ + @Override + public E poll() + { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) + { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + else if (currConsumerNode != lvProducerNode()) + { + nextNode = spinWaitForNextNode(currConsumerNode); + // got the next node... + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Peek is allowed from a SINGLE thread.
+ * Peek is potentially blocking here as the {@link Queue#peek()} does not allow returning {@code null} if the queue is not + * empty. This is very different from the original Vyukov guarantees. See {@link #relaxedPeek()} for the original + * semantics.
+ * Poll reads the next node from the consumerNode and: + *

    + *
  1. If it is {@code null} AND the queue is empty return {@code null}, if queue is not empty spin wait for + * value to become visible. + *
  2. If it is not {@code null} return it's value. + *
+ * + * @see MessagePassingQueue#peek() + * @see java.util.Queue#peek() + */ + @Override + public E peek() + { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) + { + return nextNode.lpValue(); + } + else if (currConsumerNode != lvProducerNode()) + { + nextNode = spinWaitForNextNode(currConsumerNode); + // got the next node... + return nextNode.lpValue(); + } + return null; + } + + LinkedQueueNode spinWaitForNextNode(LinkedQueueNode currNode) + { + LinkedQueueNode nextNode; + while ((nextNode = currNode.lvNext()) == null) + { + // spin, we are no longer wait free + } + return nextNode; + } + + @Override + public E relaxedPoll() + { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + final LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) + { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + @Override + public E relaxedPeek() + { + final LinkedQueueNode nextNode = lpConsumerNode().lvNext(); + if (nextNode != null) + { + return nextNode.lpValue(); + } + return null; + } + + @Override + public boolean relaxedOffer(E e) + { + return offer(e); + } + + @Override + public int drain(Consumer c, int limit) + { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + + LinkedQueueNode chaserNode = this.lpConsumerNode(); + for (int i = 0; i < limit; i++) + { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + + if (nextNode == null) + { + return i; + } + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + return limit; + } + + @Override + public int drain(Consumer c) + { + return MessagePassingQueueUtil.drain(this, c); + } + + 
@Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + @Override + public int capacity() + { + return UNBOUNDED_CAPACITY; + } + +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/BaseMpscLinkedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/BaseMpscLinkedArrayQueue.java new file mode 100644 index 0000000..26ea831 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/BaseMpscLinkedArrayQueue.java @@ -0,0 +1,781 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.PortableJvmInfo; +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; + +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import static org.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; + + +abstract class BaseMpscLinkedArrayQueuePad1 extends AbstractQueue implements IndexedQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueProducerFields extends BaseMpscLinkedArrayQueuePad1 +{ + private final static long P_INDEX_OFFSET = fieldOffset(BaseMpscLinkedArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final void 
soProducerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad2 extends BaseMpscLinkedArrayQueueProducerFields +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueConsumerFields extends BaseMpscLinkedArrayQueuePad2 +{ + private final static long C_INDEX_OFFSET = fieldOffset(BaseMpscLinkedArrayQueueConsumerFields.class,"consumerIndex"); + + private volatile long consumerIndex; + protected long consumerMask; + protected E[] consumerBuffer; + + @Override + public final long lvConsumerIndex() + { + return consumerIndex; + } + + final long lpConsumerIndex() + { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +abstract class BaseMpscLinkedArrayQueuePad3 extends BaseMpscLinkedArrayQueueConsumerFields +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte 
b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +// $gen:ordered-fields +abstract class BaseMpscLinkedArrayQueueColdProducerFields extends BaseMpscLinkedArrayQueuePad3 +{ + private final static long P_LIMIT_OFFSET = fieldOffset(BaseMpscLinkedArrayQueueColdProducerFields.class,"producerLimit"); + + private volatile long producerLimit; + protected long producerMask; + protected E[] producerBuffer; + + final long lvProducerLimit() + { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, P_LIMIT_OFFSET, expect, newValue); + } + + final void soProducerLimit(long newValue) + { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + + +/** + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks + * of the initial size. The queue grows only when the current buffer is full and elements are not copied on + * resize, instead a link to the new buffer is stored in the old buffer for the consumer to follow. 
+ */ +abstract class BaseMpscLinkedArrayQueue extends BaseMpscLinkedArrayQueueColdProducerFields + implements MessagePassingQueue, QueueProgressIndicators +{ + // No post padding here, subclasses must add + private static final Object JUMP = new Object(); + private static final Object BUFFER_CONSUMED = new Object(); + private static final int CONTINUE_TO_P_INDEX_CAS = 0; + private static final int RETRY = 1; + private static final int QUEUE_FULL = 2; + private static final int QUEUE_RESIZE = 3; + + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. + * Must be 2 or more. + */ + public BaseMpscLinkedArrayQueue(final int initialCapacity) + { + RangeUtil.checkGreaterThanOrEqual(initialCapacity, 2, "initialCapacity"); + + int p2capacity = Pow2.roundToPowerOfTwo(initialCapacity); + // leave lower bit of mask clear + long mask = (p2capacity - 1) << 1; + // need extra element to point at next array + E[] buffer = allocateRefArray(p2capacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + soProducerLimit(mask); // we know it's all empty to start with + } + + @Override + public int size() + { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.IGNORE_PARITY_DIVISOR); + } + + @Override + public boolean isEmpty() + { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there is + // nothing we can do to make this an exact method. 
+ return ((lvConsumerIndex() - lvProducerIndex()) / 2 == 0); + } + + @Override + public String toString() + { + return this.getClass().getName(); + } + + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + + long mask; + E[] buffer; + long pIndex; + + while (true) + { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) + { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // mask/buffer may get changed by resizing -> only use for array access after successful CAS. + mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) - [mask/buffer] -> cas(pIndex) + + // assumption behind this optimization is that queue is almost always empty or near empty + if (producerLimit <= pIndex) + { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) + { + case CONTINUE_TO_P_INDEX_CAS: + break; + case RETRY: + continue; + case QUEUE_FULL: + return false; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, e, null); + return true; + } + } + + if (casProducerIndex(pIndex, pIndex + 2)) + { + break; + } + } + // INDEX visible before ELEMENT + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); // release element e + return true; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() + { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) + { + long pIndex = lvProducerIndex(); + // isEmpty? + if ((cIndex - pIndex) / 2 == 0) + { + return null; + } + // poll() == null iff queue is empty, null element is not strong enough indicator, so we must + // spin until element is visible. + do + { + e = lvRefElement(buffer, offset); + } + while (e == null); + } + + if (e == JUMP) + { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, cIndex); + } + + soRefElement(buffer, offset, null); // release element null + soConsumerIndex(cIndex + 2); // release cIndex + return (E) e; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() + { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) + { + long pIndex = lvProducerIndex(); + // isEmpty? + if ((cIndex - pIndex) / 2 == 0) + { + return null; + } + // peek() == null iff queue is empty, null element is not strong enough indicator, so we must + // spin until element is visible. + do + { + e = lvRefElement(buffer, offset); + } + while (e == null); + } + if (e == JUMP) + { + return newBufferPeek(nextBuffer(buffer, mask), cIndex); + } + return (E) e; + } + + /** + * We do not inline resize into this method because we do not resize on fill. + */ + private int offerSlowPath(long mask, long pIndex, long producerLimit) + { + final long cIndex = lvConsumerIndex(); + long bufferCapacity = getCurrentBufferCapacity(mask); + + if (cIndex + bufferCapacity > pIndex) + { + if (!casProducerLimit(producerLimit, cIndex + bufferCapacity)) + { + // retry from top + return RETRY; + } + else + { + // continue to pIndex CAS + return CONTINUE_TO_P_INDEX_CAS; + } + } + // full and cannot grow + else if (availableInQueue(pIndex, cIndex) <= 0) + { + // offer should return false; + return QUEUE_FULL; + } + // grab index for resize -> set lower bit + else if (casProducerIndex(pIndex, pIndex + 1)) + { + // trigger a resize + return QUEUE_RESIZE; + } + else + { + // failed resize attempt, retry from top + return RETRY; + } + } + + /** + * @return available elements in queue * 2 + */ + protected abstract long availableInQueue(long pIndex, long cIndex); + + @SuppressWarnings("unchecked") + private E[] nextBuffer(final E[] buffer, final long mask) + { + final long offset = nextArrayOffset(mask); + final E[] nextBuffer = (E[]) lvRefElement(buffer, 
offset); + consumerBuffer = nextBuffer; + consumerMask = (length(nextBuffer) - 2) << 1; + soRefElement(buffer, offset, BUFFER_CONSUMED); + return nextBuffer; + } + + private static long nextArrayOffset(long mask) + { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); + } + + private E newBufferPoll(E[] nextBuffer, long cIndex) + { + final long offset = modifiedCalcCircularRefElementOffset(cIndex, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (n == null) + { + throw new IllegalStateException("new buffer must have at least one element"); + } + soRefElement(nextBuffer, offset, null); + soConsumerIndex(cIndex + 2); + return n; + } + + private E newBufferPeek(E[] nextBuffer, long cIndex) + { + final long offset = modifiedCalcCircularRefElementOffset(cIndex, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) + { + throw new IllegalStateException("new buffer must have at least one element"); + } + return n; + } + + @Override + public long currentProducerIndex() + { + return lvProducerIndex() / 2; + } + + @Override + public long currentConsumerIndex() + { + return lvConsumerIndex() / 2; + } + + @Override + public abstract int capacity(); + + @Override + public boolean relaxedOffer(E e) + { + return offer(e); + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPoll() + { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) + { + return null; + } + if (e == JUMP) + { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, cIndex); + } + soRefElement(buffer, offset, null); + soConsumerIndex(cIndex + 2); + return (E) e; + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPeek() + { + final E[] buffer = consumerBuffer; + final long cIndex = 
lpConsumerIndex(); + final long mask = consumerMask; + + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == JUMP) + { + return newBufferPeek(nextBuffer(buffer, mask), cIndex); + } + return (E) e; + } + + @Override + public int fill(Supplier s) + { + long result = 0;// result is a long because we want to have a safepoint check at regular intervals + final int capacity = capacity(); + do + { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) + { + return (int) result; + } + result += filled; + } + while (result <= capacity); + return (int) result; + } + + @Override + public int fill(Supplier s, int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + long mask; + E[] buffer; + long pIndex; + int claimedSlots; + while (true) + { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) + { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + + // NOTE: mask/buffer may get changed by resizing -> only use for array access after successful CAS. + // Only by virtue offloading them between the lvProducerIndex and a successful casProducerIndex are they + // safe to use. 
+ mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) + + // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' + long batchIndex = Math.min(producerLimit, pIndex + 2l * limit); // -> producerLimit >= batchIndex + + if (pIndex >= producerLimit) + { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch (result) + { + case CONTINUE_TO_P_INDEX_CAS: + // offer slow path verifies only one slot ahead, we cannot rely on indication here + case RETRY: + continue; + case QUEUE_FULL: + return 0; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, null, s); + return 1; + } + } + + // claim limit slots at once + if (casProducerIndex(pIndex, batchIndex)) + { + claimedSlots = (int) ((batchIndex - pIndex) / 2); + break; + } + } + + for (int i = 0; i < claimedSlots; i++) + { + final long offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); + } + return claimedSlots; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + @Override + public int drain(Consumer c) + { + return drain(c, capacity()); + } + + @Override + public int drain(Consumer c, int limit) + { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + /** + * Get an iterator for this queue. This method is thread safe. + *

+ * The iterator provides a best-effort snapshot of the elements in the queue. + * The returned iterator is not guaranteed to return elements in queue order, + * and races with the consumer thread may cause gaps in the sequence of returned elements. + * Like {link #relaxedPoll}, the iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); + } + + private static class WeakIterator implements Iterator + { + private final long pIndex; + private long nextIndex; + private E nextElement; + private E[] currentBuffer; + private int mask; + + WeakIterator(E[] currentBuffer, long cIndex, long pIndex) + { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() + { + return nextElement != null; + } + + @Override + public E next() + { + final E e = nextElement; + if (e == null) + { + throw new NoSuchElementException(); + } + nextElement = getNext(); + return e; + } + + private void setBuffer(E[] buffer) + { + this.currentBuffer = buffer; + this.mask = length(buffer) - 2; + } + + private E getNext() + { + while (nextIndex < pIndex) + { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) + { + continue; + } + + // not null && not JUMP -> found next element + if (e != JUMP) + { + return e; + } + + // need to jump to the next buffer + int nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, + calcRefElementOffset(nextBufferIndex)); + + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) + { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early + 
return null; + } + + setBuffer((E[]) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) + { + continue; + } + else + { + return e; + } + + } + return null; + } + } + + private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s) + { + assert (e != null && s == null) || (e == null || s != null); + int newBufferLength = getNextBufferSize(oldBuffer); + final E[] newBuffer; + try + { + newBuffer = allocateRefArray(newBufferLength); + } + catch (OutOfMemoryError oom) + { + assert lvProducerIndex() == pIndex + 1; + soProducerIndex(pIndex); + throw oom; + } + + producerBuffer = newBuffer; + final int newMask = (newBufferLength - 2) << 1; + producerMask = newMask; + + final long offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); + + soRefElement(newBuffer, offsetInNew, e == null ? 
s.get() : e);// element in new array + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer);// buffer linked + + // ASSERT code + final long cIndex = lvConsumerIndex(); + final long availableInQueue = availableInQueue(pIndex, cIndex); + RangeUtil.checkPositive(availableInQueue, "availableInQueue"); + + // Invalidate racing CASs + // We never set the limit beyond the bounds of a buffer + soProducerLimit(pIndex + Math.min(newMask, availableInQueue)); + + // make resize visible to the other producers + soProducerIndex(pIndex + 2); + + // INDEX visible before ELEMENT, consistent with consumer expectation + + // make resize visible to consumer + soRefElement(oldBuffer, offsetInOld, JUMP); + } + + /** + * @return next buffer size(inclusive of next array pointer) + */ + protected abstract int getNextBufferSize(E[] buffer); + + /** + * @return current buffer capacity for elements (excluding next pointer and jump entry) * 2 + */ + protected abstract long getCurrentBufferCapacity(long mask); +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/BaseSpscLinkedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/BaseSpscLinkedArrayQueue.java new file mode 100644 index 0000000..8f65b52 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/BaseSpscLinkedArrayQueue.java @@ -0,0 +1,420 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.PortableJvmInfo; + +import java.util.AbstractQueue; +import java.util.Iterator; + +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import static org.jctools.queues.LinkedArrayQueueUtil.nextArrayOffset; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +abstract class BaseSpscLinkedArrayQueuePrePad extends AbstractQueue implements IndexedQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + //byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + //byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + // * drop 16b , the cold fields act as buffer * +} + +abstract class BaseSpscLinkedArrayQueueConsumerColdFields extends BaseSpscLinkedArrayQueuePrePad +{ + protected long consumerMask; + protected E[] consumerBuffer; +} + +// $gen:ordered-fields +abstract class BaseSpscLinkedArrayQueueConsumerField extends BaseSpscLinkedArrayQueueConsumerColdFields +{ + private final static long C_INDEX_OFFSET = fieldOffset(BaseSpscLinkedArrayQueueConsumerField.class, "consumerIndex"); + + private volatile long consumerIndex; + + 
@Override + public final long lvConsumerIndex() + { + return consumerIndex; + } + + final long lpConsumerIndex() + { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } + +} + +abstract class BaseSpscLinkedArrayQueueL2Pad extends BaseSpscLinkedArrayQueueConsumerField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +// $gen:ordered-fields +abstract class BaseSpscLinkedArrayQueueProducerFields extends BaseSpscLinkedArrayQueueL2Pad +{ + private final static long P_INDEX_OFFSET = fieldOffset(BaseSpscLinkedArrayQueueProducerFields.class,"producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final void soProducerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final long lpProducerIndex() + { + return UNSAFE.getLong(this, P_INDEX_OFFSET); + } + +} + +abstract class BaseSpscLinkedArrayQueueProducerColdFields extends BaseSpscLinkedArrayQueueProducerFields +{ + protected long producerBufferLimit; + protected long producerMask; 
// fixed for chunked and unbounded + + protected E[] producerBuffer; +} + +abstract class BaseSpscLinkedArrayQueue extends BaseSpscLinkedArrayQueueProducerColdFields + implements MessagePassingQueue, QueueProgressIndicators +{ + + private static final Object JUMP = new Object(); + + @Override + public final Iterator iterator() + { + throw new UnsupportedOperationException(); + } + + @Override + public final int size() + { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR); + } + + @Override + public final boolean isEmpty() + { + return IndexedQueueSizeUtil.isEmpty(this); + } + + @Override + public String toString() + { + return this.getClass().getName(); + } + + @Override + public long currentProducerIndex() + { + return lvProducerIndex(); + } + + @Override + public long currentConsumerIndex() + { + return lvConsumerIndex(); + } + + protected final void soNext(E[] curr, E[] next) + { + long offset = nextArrayOffset(curr); + soRefElement(curr, offset, next); + } + + @SuppressWarnings("unchecked") + protected final E[] lvNextArrayAndUnlink(E[] curr) + { + final long offset = nextArrayOffset(curr); + final E[] nextBuffer = (E[]) lvRefElement(curr, offset); + // prevent GC nepotism + soRefElement(curr, offset, null); + return nextBuffer; + } + + @Override + public boolean relaxedOffer(E e) + { + return offer(e); + } + + @Override + public E relaxedPoll() + { + return poll(); + } + + @Override + public E relaxedPeek() + { + return peek(); + } + + @Override + public int drain(Consumer c) + { + return MessagePassingQueueUtil.drain(this, c); + } + + @Override + public int fill(Supplier s) + { + long result = 0;// result is a long because we want to have a safepoint check at regular intervals + final int capacity = capacity(); + do + { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) + { + return (int) result; + } + result += filled; + } + while (result <= capacity); + return (int) result; + } + + @Override + 
public int drain(Consumer c, int limit) + { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public int fill(Supplier s, int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + for (int i = 0; i < limit; i++) + { + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = producerBuffer; + final long index = lpProducerIndex(); + final long mask = producerMask; + final long offset = calcCircularRefElementOffset(index, mask); + // expected hot path + if (index < producerBufferLimit) + { + writeToQueue(buffer, s.get(), index, offset); + } + else + { + if (!offerColdPath(buffer, mask, index, offset, null, s)) + { + return i; + } + } + } + return limit; + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single producer thread use only. + */ + @Override + public boolean offer(final E e) + { + // Objects.requireNonNull(e); + if (null == e) + { + throw new NullPointerException(); + } + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = producerBuffer; + final long index = lpProducerIndex(); + final long mask = producerMask; + final long offset = calcCircularRefElementOffset(index, mask); + // expected hot path + if (index < producerBufferLimit) + { + writeToQueue(buffer, e, index, offset); + return true; + } + return offerColdPath(buffer, mask, index, offset, e, null); + } + + abstract boolean offerColdPath( + E[] buffer, + long mask, + long pIndex, + long offset, + E v, + Supplier s); + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() + { + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = calcCircularRefElementOffset(index, mask); + final Object e = lvRefElement(buffer, offset); + boolean isNextBuffer = e == JUMP; + if (null != e && !isNextBuffer) + { + soConsumerIndex(index + 1);// this ensures correctness on 32bit platforms + soRefElement(buffer, offset, null); + return (E) e; + } + else if (isNextBuffer) + { + return newBufferPoll(buffer, index); + } + + return null; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() + { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = calcCircularRefElementOffset(index, mask); + final Object e = lvRefElement(buffer, offset); + if (e == JUMP) + { + return newBufferPeek(buffer, index); + } + + return (E) e; + } + + final void linkOldToNew( + final long currIndex, + final E[] oldBuffer, final long offset, + final E[] newBuffer, final long offsetInNew, + final E e) + { + soRefElement(newBuffer, offsetInNew, e); + // link to next buffer and add next indicator as element of old buffer + soNext(oldBuffer, newBuffer); + soRefElement(oldBuffer, offset, JUMP); + // index is visible after elements (isEmpty/poll ordering) + soProducerIndex(currIndex + 1);// this ensures atomic write of long on 32bit platforms + } + + final void writeToQueue(final E[] buffer, final E e, final long index, final long offset) + { + soRefElement(buffer, offset, e); + soProducerIndex(index + 1);// this ensures atomic write of long on 32bit platforms + } + + private E newBufferPeek(final E[] buffer, final long index) + { + E[] nextBuffer = lvNextArrayAndUnlink(buffer); + consumerBuffer = nextBuffer; + final long mask = length(nextBuffer) - 2; + consumerMask = mask; + final long offset = calcCircularRefElementOffset(index, mask); + return lvRefElement(nextBuffer, offset); + } + + private E newBufferPoll(final E[] buffer, final long index) + { + E[] nextBuffer = lvNextArrayAndUnlink(buffer); + consumerBuffer = nextBuffer; + final long mask = length(nextBuffer) - 2; + consumerMask = mask; + final long offset = calcCircularRefElementOffset(index, mask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) + { + throw new IllegalStateException("new buffer must have at least one element"); + } + else + { + soConsumerIndex(index + 1);// this ensures 
correctness on 32bit platforms + soRefElement(nextBuffer, offset, null); + return n; + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/ConcurrentCircularArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/ConcurrentCircularArrayQueue.java new file mode 100644 index 0000000..54aa06f --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/ConcurrentCircularArrayQueue.java @@ -0,0 +1,170 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.Pow2; + +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; + +import static org.jctools.util.UnsafeRefArrayAccess.*; + +abstract class ConcurrentCircularArrayQueueL0Pad extends AbstractQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} + +/** + * Common functionality for array backed queues. The class is pre-padded and the array is padded on either side to help + * with False Sharing prevention. It is expected that subclasses handle post padding. 
+ */ +abstract class ConcurrentCircularArrayQueue extends ConcurrentCircularArrayQueueL0Pad + implements MessagePassingQueue, IndexedQueue, QueueProgressIndicators, SupportsIterator +{ + protected final long mask; + protected final E[] buffer; + + ConcurrentCircularArrayQueue(int capacity) + { + int actualCapacity = Pow2.roundToPowerOfTwo(capacity); + mask = actualCapacity - 1; + buffer = allocateRefArray(actualCapacity); + } + + @Override + public int size() + { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR); + } + + @Override + public boolean isEmpty() + { + return IndexedQueueSizeUtil.isEmpty(this); + } + + @Override + public String toString() + { + return this.getClass().getName(); + } + + @Override + public void clear() + { + while (poll() != null) + { + // if you stare into the void + } + } + + @Override + public int capacity() + { + return (int) (mask + 1); + } + + @Override + public long currentProducerIndex() + { + return lvProducerIndex(); + } + + @Override + public long currentConsumerIndex() + { + return lvConsumerIndex(); + } + + /** + * Get an iterator for this queue. This method is thread safe. + *

+ * The iterator provides a best-effort snapshot of the elements in the queue. + * The returned iterator is not guaranteed to return elements in queue order, + * and races with the consumer thread may cause gaps in the sequence of returned elements. + * Like {link #relaxedPoll}, the iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + final long cIndex = lvConsumerIndex(); + final long pIndex = lvProducerIndex(); + + return new WeakIterator(cIndex, pIndex, mask, buffer); + } + + private static class WeakIterator implements Iterator { + private final long pIndex; + private final long mask; + private final E[] buffer; + private long nextIndex; + private E nextElement; + + WeakIterator(long cIndex, long pIndex, long mask, E[] buffer) { + this.nextIndex = cIndex; + this.pIndex = pIndex; + this.mask = mask; + this.buffer = buffer; + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) + throw new NoSuchElementException(); + nextElement = getNext(); + return e; + } + + private E getNext() { + while (nextIndex < pIndex) { + long offset = calcCircularRefElementOffset(nextIndex++, mask); + E e = lvRefElement(buffer, offset); + if (e != null) { + return e; + } + } + return null; + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/ConcurrentSequencedCircularArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/ConcurrentSequencedCircularArrayQueue.java new file mode 100644 index 0000000..5eaafca --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/ConcurrentSequencedCircularArrayQueue.java @@ -0,0 +1,33 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in 
compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import static org.jctools.util.UnsafeLongArrayAccess.*; + +public abstract class ConcurrentSequencedCircularArrayQueue extends ConcurrentCircularArrayQueue +{ + protected final long[] sequenceBuffer; + + public ConcurrentSequencedCircularArrayQueue(int capacity) + { + super(capacity); + int actualCapacity = (int) (this.mask + 1); + // pad data on either end with some empty slots. Note that actualCapacity is <= MAX_POW2_INT + sequenceBuffer = allocateLongArray(actualCapacity); + for (long i = 0; i < actualCapacity; i++) + { + soLongElement(sequenceBuffer, calcCircularLongElementOffset(i, mask), i); + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/IndexedQueueSizeUtil.java b/netty-jctools/src/main/java/org/jctools/queues/IndexedQueueSizeUtil.java new file mode 100644 index 0000000..c1da275 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/IndexedQueueSizeUtil.java @@ -0,0 +1,101 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.util.InternalAPI; + +/** + * A note to maintainers on index assumptions: in a single threaded world it would seem intuitive to assume: + *

+ * producerIndex >= consumerIndex
+ * 
+ * As an invariant, but in a concurrent, long running settings all of the following need to be considered: + *
    + *
  • consumerIndex > producerIndex : due to counter overflow (unlikely with longs, but easy to reason) + *
  • consumerIndex > producerIndex : due to consumer FastFlow like implementation discovering the + * element before the counter is updated. + *
  • producerIndex - consumerIndex < 0 : due to above. + *
  • producerIndex - consumerIndex > Integer.MAX_VALUE : as linked buffers allow constructing queues + * with more than Integer.MAX_VALUE elements. + * + *
+ */ +@InternalAPI +public final class IndexedQueueSizeUtil +{ + + public static final int PLAIN_DIVISOR = 1; + public static final int IGNORE_PARITY_DIVISOR = 2; + + public static int size(IndexedQueue iq, int divisor) { + /* + * It is possible for a thread to be interrupted or reschedule between the reads of the producer and + * consumer indices. It is also for the indices to be updated in a `weakly` visible way. It follows that + * the size value needs to be sanitized to match a valid range. + */ + long after = iq.lvConsumerIndex(); + long size; + while (true) + { + final long before = after; + // pIndex read is "sandwiched" between 2 cIndex reads + final long currentProducerIndex = iq.lvProducerIndex(); + after = iq.lvConsumerIndex(); + if (before == after) + { + size = (currentProducerIndex - after) / divisor; + break; + } + } + return sanitizedSize(iq.capacity(), size); + } + + public static int sanitizedSize(int capacity, long size) { + // Concurrent updates to cIndex and pIndex may lag behind other progress enablers (e.g. FastFlow), so we need + // to check bounds [0,capacity] + if (size < 0) + { + return 0; + } + if (capacity != MessagePassingQueue.UNBOUNDED_CAPACITY && size > capacity) + { + return capacity; + } + // Integer overflow is possible for the unbounded indexed queues. + if (size > Integer.MAX_VALUE) + { + return Integer.MAX_VALUE; + } + return (int) size; + } + + public static boolean isEmpty(IndexedQueue iq) + { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there is + // nothing we can do to make this an exact method. 
+ return (iq.lvConsumerIndex() >= iq.lvProducerIndex()); + } + + @InternalAPI + public interface IndexedQueue + { + long lvConsumerIndex(); + + long lvProducerIndex(); + + int capacity(); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/LinkedArrayQueueUtil.java b/netty-jctools/src/main/java/org/jctools/queues/LinkedArrayQueueUtil.java new file mode 100644 index 0000000..6b00d61 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/LinkedArrayQueueUtil.java @@ -0,0 +1,33 @@ +package org.jctools.queues; + +import org.jctools.util.InternalAPI; + +import static org.jctools.util.UnsafeRefArrayAccess.REF_ARRAY_BASE; +import static org.jctools.util.UnsafeRefArrayAccess.REF_ELEMENT_SHIFT; + +/** + * This is used for method substitution in the LinkedArray classes code generation. + */ +@InternalAPI +public final class LinkedArrayQueueUtil +{ + public static int length(Object[] buf) + { + return buf.length; + } + + /** + * This method assumes index is actually (index << 1) because lower bit is + * used for resize. This is compensated for by reducing the element shift. + * The computation is constant folded, so there's no cost. + */ + public static long modifiedCalcCircularRefElementOffset(long index, long mask) + { + return REF_ARRAY_BASE + ((index & mask) << (REF_ELEMENT_SHIFT - 1)); + } + + public static long nextArrayOffset(Object[] curr) + { + return REF_ARRAY_BASE + ((long) (length(curr) - 1) << REF_ELEMENT_SHIFT); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/LinkedQueueNode.java b/netty-jctools/src/main/java/org/jctools/queues/LinkedQueueNode.java new file mode 100644 index 0000000..a0d3008 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/LinkedQueueNode.java @@ -0,0 +1,75 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.InternalAPI; + +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; + +@InternalAPI +public final class LinkedQueueNode +{ + private final static long NEXT_OFFSET = fieldOffset(LinkedQueueNode.class,"next"); + + private E value; + private volatile LinkedQueueNode next; + + public LinkedQueueNode() + { + this(null); + } + + public LinkedQueueNode(E val) + { + spValue(val); + } + + /** + * Gets the current value and nulls out the reference to it from this node. + * + * @return value + */ + public E getAndNullValue() + { + E temp = lpValue(); + spValue(null); + return temp; + } + + public E lpValue() + { + return value; + } + + public void spValue(E newValue) + { + value = newValue; + } + + public void soNext(LinkedQueueNode n) + { + UNSAFE.putOrderedObject(this, NEXT_OFFSET, n); + } + + public void spNext(LinkedQueueNode n) + { + UNSAFE.putObject(this, NEXT_OFFSET, n); + } + + public LinkedQueueNode lvNext() + { + return next; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueue.java new file mode 100644 index 0000000..5a3edf6 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueue.java @@ -0,0 +1,316 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import java.util.Queue; + +/** + * Message passing queues are intended for concurrent method passing. A subset of {@link Queue} methods are provided + * with the same semantics, while further functionality which accommodates the concurrent use case is also on offer. + *

+ * Message passing queues provide happens before semantics to messages passed through, namely that writes made + * by the producer before offering the message are visible to the consuming thread after the message has been + * polled out of the queue. + * + * @param the event/message type + */ +public interface MessagePassingQueue +{ + int UNBOUNDED_CAPACITY = -1; + + interface Supplier + { + /** + * This method will return the next value to be written to the queue. As such the queue + * implementations are commited to insert the value once the call is made. + *

+ * Users should be aware that underlying queue implementations may upfront claim parts of the queue + * for batch operations and this will effect the view on the queue from the supplier method. In + * particular size and any offer methods may take the view that the full batch has already happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead to a broken queue. + *

WARNING: this method is assumed to never return {@code null}. Breaking this assumption can lead to a broken queue. + * + * @return new element, NEVER {@code null} + */ + T get(); + } + + interface Consumer + { + /** + * This method will process an element already removed from the queue. This method is expected to + * never throw an exception. + *

+ * Users should be aware that underlying queue implementations may upfront claim parts of the queue + * for batch operations and this will effect the view on the queue from the accept method. In + * particular size and any poll/peek methods may take the view that the full batch has already + * happened. + * + *

WARNING: this method is assumed to never throw. Breaking this assumption can lead to a broken queue. + * @param e not {@code null} + */ + void accept(T e); + } + + interface WaitStrategy + { + /** + * This method can implement static or dynamic backoff. Dynamic backoff will rely on the counter for + * estimating how long the caller has been idling. The expected usage is: + *

+ *

+         * 
+         * int ic = 0;
+         * while(true) {
+         *   if(!isGodotArrived()) {
+         *     ic = w.idle(ic);
+         *     continue;
+         *   }
+         *   ic = 0;
+         *   // party with Godot until he goes again
+         * }
+         * 
+         * 
+ * + * @param idleCounter idle calls counter, managed by the idle method until reset + * @return new counter value to be used on subsequent idle cycle + */ + int idle(int idleCounter); + } + + interface ExitCondition + { + + /** + * This method should be implemented such that the flag read or determination cannot be hoisted out of + * a loop which notmally means a volatile load, but with JDK9 VarHandles may mean getOpaque. + * + * @return true as long as we should keep running + */ + boolean keepRunning(); + } + + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#offer(Object)} interface. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false iff full + */ + boolean offer(T e); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#poll()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ + T poll(); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation and + * according to the {@link Queue#peek()} interface. + * + * @return a message from the queue if one is available, {@code null} iff empty + */ + T peek(); + + /** + * This method's accuracy is subject to concurrent modifications happening as the size is estimated and as + * such is a best effort rather than absolute value. For some implementations this method may be O(n) + * rather than O(1). + * + * @return number of messages in the queue, between 0 and {@link Integer#MAX_VALUE} but less or equals to + * capacity (if bounded). + */ + int size(); + + /** + * Removes all items from the queue. Called from the consumer thread subject to the restrictions + * appropriate to the implementation and according to the {@link Queue#clear()} interface. 
+ */ + void clear(); + + /** + * This method's accuracy is subject to concurrent modifications happening as the observation is carried + * out. + * + * @return true if empty, false otherwise + */ + boolean isEmpty(); + + /** + * @return the capacity of this queue or {@link MessagePassingQueue#UNBOUNDED_CAPACITY} if not bounded + */ + int capacity(); + + /** + * Called from a producer thread subject to the restrictions appropriate to the implementation. As opposed + * to {@link Queue#offer(Object)} this method may return false without the queue being full. + * + * @param e not {@code null}, will throw NPE if it is + * @return true if element was inserted into the queue, false if unable to offer + */ + boolean relaxedOffer(T e); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. As + * opposed to {@link Queue#poll()} this method may return {@code null} without the queue being empty. + * + * @return a message from the queue if one is available, {@code null} if unable to poll + */ + T relaxedPoll(); + + /** + * Called from the consumer thread subject to the restrictions appropriate to the implementation. As + * opposed to {@link Queue#peek()} this method may return {@code null} without the queue being empty. + * + * @return a message from the queue if one is available, {@code null} if unable to peek + */ + T relaxedPeek(); + + /** + * Remove up to limit elements from the queue and hand to consume. This should be semantically + * similar to: + *

+ *

{@code
+     *   M m;
+     *   int i = 0;
+     *   for(;i < limit && (m = relaxedPoll()) != null; i++){
+     *     c.accept(m);
+     *   }
+     *   return i;
+     * }
+ *

+ * There's no strong commitment to the queue being empty at the end of a drain. Called from a consumer + * thread subject to the restrictions appropriate to the implementation. + *

+ * WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make sure you have read + * and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + * @throws IllegalArgumentException if limit is negative + */ + int drain(Consumer c, int limit); + + /** + * Stuff the queue with up to limit elements from the supplier. Semantically similar to: + *

+ *

{@code
+     *   for(int i=0; i < limit && relaxedOffer(s.get()); i++);
+     * }
+ *

+ * There's no strong commitment to the queue being full at the end of a fill. Called from a producer + * thread subject to the restrictions appropriate to the implementation. + * + * WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure you have read + * and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + * @throws IllegalArgumentException if limit is negative + */ + int fill(Supplier s, int limit); + + /** + * Remove all available item from the queue and hand to consume. This should be semantically similar to: + *

+     * M m;
+     * while((m = relaxedPoll()) != null){
+     * c.accept(m);
+     * }
+     * 
+ * There's no strong commitment to the queue being empty at the end of a drain. Called from a + * consumer thread subject to the restrictions appropriate to the implementation. + *

+ * WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make sure you have read + * and understood these before using this method. + * + * @return the number of polled elements + * @throws IllegalArgumentException c is {@code null} + */ + int drain(Consumer c); + + /** + * Stuff the queue with elements from the supplier. Semantically similar to: + *

+     * while(relaxedOffer(s.get());
+     * 
+ * There's no strong commitment to the queue being full at the end of a fill. Called from a + * producer thread subject to the restrictions appropriate to the implementation. + *

+ * Unbounded queues will fill up the queue with a fixed amount rather than fill up to oblivion. + * + * WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure you have read + * and understood these before using this method. + * + * @return the number of offered elements + * @throws IllegalArgumentException s is {@code null} + */ + int fill(Supplier s); + + /** + * Remove elements from the queue and hand to consume forever. Semantically similar to: + *

+ *

+     *  int idleCounter = 0;
+     *  while (exit.keepRunning()) {
+     *      E e = relaxedPoll();
+     *      if(e==null){
+     *          idleCounter = wait.idle(idleCounter);
+     *          continue;
+     *      }
+     *      idleCounter = 0;
+     *      c.accept(e);
+     *  }
+     * 
+ *

+ * Called from a consumer thread subject to the restrictions appropriate to the implementation. + *

+ * WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make sure you have read + * and understood these before using this method. + * + * @throws IllegalArgumentException c OR wait OR exit are {@code null} + */ + void drain(Consumer c, WaitStrategy wait, ExitCondition exit); + + /** + * Stuff the queue with elements from the supplier forever. Semantically similar to: + *

+ *

+     * 
+     *  int idleCounter = 0;
+     *  while (exit.keepRunning()) {
+     *      E e = s.get();
+     *      while (!relaxedOffer(e)) {
+     *          idleCounter = wait.idle(idleCounter);
+     *          continue;
+     *      }
+     *      idleCounter = 0;
+     *  }
+     * 
+     * 
+ *

+ * Called from a producer thread subject to the restrictions appropriate to the implementation. The main difference + * being that implementors MUST assure room in the queue is available BEFORE calling {@link Supplier#get}. + * + * WARNING: Explicit assumptions are made with regards to {@link Supplier#get} make sure you have read + * and understood these before using this method. + * + * @throws IllegalArgumentException s OR wait OR exit are {@code null} + */ + void fill(Supplier s, WaitStrategy wait, ExitCondition exit); +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueueUtil.java b/netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueueUtil.java new file mode 100644 index 0000000..dfb5391 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MessagePassingQueueUtil.java @@ -0,0 +1,125 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */
+
+package org.jctools.queues;
+
+import org.jctools.queues.MessagePassingQueue.Consumer;
+import org.jctools.queues.MessagePassingQueue.ExitCondition;
+import org.jctools.queues.MessagePassingQueue.Supplier;
+import org.jctools.queues.MessagePassingQueue.WaitStrategy;
+import org.jctools.util.InternalAPI;
+import org.jctools.util.PortableJvmInfo;
+
+/**
+ * Shared implementations of the {@link MessagePassingQueue} bulk drain/fill operations, used by the
+ * concrete queue classes to satisfy the interface contract.
+ */
+@InternalAPI
+public final class MessagePassingQueueUtil
+{
+    private MessagePassingQueueUtil()
+    {
+        // utility class, no instances
+    }
+
+    public static <E> int drain(MessagePassingQueue<? extends E> queue, Consumer<? super E> c, int limit)
+    {
+        if (null == c)
+            throw new IllegalArgumentException("c is null");
+        if (limit < 0)
+            throw new IllegalArgumentException("limit is negative: " + limit);
+        if (limit == 0)
+            return 0;
+        E e;
+        int i = 0;
+        for (; i < limit && (e = queue.relaxedPoll()) != null; i++)
+        {
+            c.accept(e);
+        }
+        return i;
+    }
+
+    public static <E> int drain(MessagePassingQueue<? extends E> queue, Consumer<? super E> c)
+    {
+        if (null == c)
+            throw new IllegalArgumentException("c is null");
+        E e;
+        int i = 0;
+        while ((e = queue.relaxedPoll()) != null)
+        {
+            i++;
+            c.accept(e);
+        }
+        return i;
+    }
+
+    public static <E> void drain(MessagePassingQueue<? extends E> queue, Consumer<? super E> c, WaitStrategy wait, ExitCondition exit)
+    {
+        if (null == c)
+            throw new IllegalArgumentException("c is null");
+        if (null == wait)
+            throw new IllegalArgumentException("wait is null");
+        if (null == exit)
+            throw new IllegalArgumentException("exit condition is null");
+
+        int idleCounter = 0;
+        while (exit.keepRunning())
+        {
+            final E e = queue.relaxedPoll();
+            if (e == null)
+            {
+                idleCounter = wait.idle(idleCounter);
+                continue;
+            }
+            idleCounter = 0;
+            c.accept(e);
+        }
+    }
+
+    public static <E> void fill(MessagePassingQueue<E> q, Supplier<E> s, WaitStrategy wait, ExitCondition exit)
+    {
+        if (null == wait)
+            throw new IllegalArgumentException("waiter is null");
+        if (null == exit)
+            throw new IllegalArgumentException("exit condition is null");
+
+        int idleCounter = 0;
+        while (exit.keepRunning())
+        {
+            if (q.fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH) == 0)
+            {
+                idleCounter = wait.idle(idleCounter);
+                continue;
+            }
+            idleCounter = 0;
+        }
+    }
+
+    public static <E> int fillBounded(MessagePassingQueue<E> q, Supplier<E> s)
+    {
+        return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, q.capacity());
+    }
+
+    public static <E> int fillInBatchesToLimit(MessagePassingQueue<E> q, Supplier<E> s, int batch, int limit)
+    {
+        long result = 0;// result is a long because we want to have a safepoint check at regular intervals
+        do
+        {
+            final int filled = q.fill(s, batch);
+            if (filled == 0)
+            {
+                return (int) result;
+            }
+            result += filled;
+        }
+        while (result <= limit);
+        return (int) result;
+    }
+
+    public static <E> int fillUnbounded(MessagePassingQueue<E> q, Supplier<E> s)
+    {
+        // fill a fixed amount for unbounded queues rather than filling "to oblivion"
+        return fillInBatchesToLimit(q, s, PortableJvmInfo.RECOMENDED_OFFER_BATCH, 4096);
+    }
+}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddArrayQueue.java
new file mode 100644
index 0000000..fb514fe
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddArrayQueue.java
@@ -0,0 +1,469 @@
+package org.jctools.queues;
+
+import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue;
+import org.jctools.util.PortableJvmInfo;
+import org.jctools.util.Pow2;
+import org.jctools.util.UnsafeAccess;
+
+import java.util.AbstractQueue;
+import java.util.Iterator;
+
+import static org.jctools.queues.MpUnboundedXaddChunk.NOT_USED;
+import static org.jctools.util.UnsafeAccess.UNSAFE;
+import static org.jctools.util.UnsafeAccess.fieldOffset;
+
+abstract class MpUnboundedXaddArrayQueuePad1<E> extends AbstractQueue<E> implements IndexedQueue
+{
+    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
+    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
+    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
+    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
+    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
+    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
+    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
+    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
+    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
+    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
+    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
+    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
+    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
+    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
+    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
+    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
+}
+
+// $gen:ordered-fields
+abstract class MpUnboundedXaddArrayQueueProducerFields<E> extends MpUnboundedXaddArrayQueuePad1<E>
+{
+    private final static long P_INDEX_OFFSET =
+        fieldOffset(MpUnboundedXaddArrayQueueProducerFields.class, "producerIndex");
+    private volatile long producerIndex;
+
+    @Override
+    public final long lvProducerIndex()
+    {
+        return producerIndex;
+    }
+
+    final long getAndIncrementProducerIndex()
+    {
+        return UNSAFE.getAndAddLong(this, P_INDEX_OFFSET, 1);
+    }
+
+    final long getAndAddProducerIndex(long delta)
+    {
+        return UNSAFE.getAndAddLong(this, P_INDEX_OFFSET, delta);
+    }
+}
+
+abstract class MpUnboundedXaddArrayQueuePad2<E> extends MpUnboundedXaddArrayQueueProducerFields<E>
+{
+    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
+    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
+    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
+    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
+    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
+    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
+    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
+    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
+    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
+    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
+    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
+    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
+    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
+    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
+    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
+    // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
+}
+
+// $gen:ordered-fields
+abstract class MpUnboundedXaddArrayQueueProducerChunk<R extends MpUnboundedXaddChunk<R, E>, E>
+    extends MpUnboundedXaddArrayQueuePad2<E>
+{
+    private static final long P_CHUNK_OFFSET =
+        fieldOffset(MpUnboundedXaddArrayQueueProducerChunk.class, "producerChunk");
+    private static final long P_CHUNK_INDEX_OFFSET =
+        fieldOffset(MpUnboundedXaddArrayQueueProducerChunk.class, "producerChunkIndex");
+
+    private volatile R producerChunk;
+    private volatile long producerChunkIndex;
+
+
+    final long lvProducerChunkIndex()
+    {
+        return producerChunkIndex;
+    }
+
+    final boolean casProducerChunkIndex(long expected, long value)
+    {
+        return UNSAFE.compareAndSwapLong(this, P_CHUNK_INDEX_OFFSET, expected, value);
+    }
+
+    final void soProducerChunkIndex(long value)
+    {
+        UNSAFE.putOrderedLong(this, P_CHUNK_INDEX_OFFSET, value);
+    }
+
+    final R lvProducerChunk()
+    {
+        return this.producerChunk;
+    }
+
+    final void soProducerChunk(R chunk)
+    {
+        UNSAFE.putOrderedObject(this, P_CHUNK_OFFSET, chunk);
+    }
+}
+
+abstract class MpUnboundedXaddArrayQueuePad3<R extends MpUnboundedXaddChunk<R, E>, E>
+    extends MpUnboundedXaddArrayQueueProducerChunk<R, E>
+{
+    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
+    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
+    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
+    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
+    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
+    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
+    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
+    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
+    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
+    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
+    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
+    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
+    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
+    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
+    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
+    // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
+}
+
+// $gen:ordered-fields
+abstract class MpUnboundedXaddArrayQueueConsumerFields<R extends MpUnboundedXaddChunk<R, E>, E>
+    extends MpUnboundedXaddArrayQueuePad3<R, E>
+{
+    private final static long C_INDEX_OFFSET =
+        fieldOffset(MpUnboundedXaddArrayQueueConsumerFields.class, "consumerIndex");
+    private final static long C_CHUNK_OFFSET =
+        fieldOffset(MpUnboundedXaddArrayQueueConsumerFields.class, "consumerChunk");
+
+    private volatile long consumerIndex;
+    private volatile R consumerChunk;
+
+    @Override
+    public final long lvConsumerIndex()
+    {
+        return consumerIndex;
+    }
+
+    final boolean casConsumerIndex(long expect, long newValue)
+    {
+        return UNSAFE.compareAndSwapLong(this, C_INDEX_OFFSET, expect, newValue);
+    }
+
+    @SuppressWarnings("unchecked")
+    final R lpConsumerChunk()
+    {
+        return (R) UNSAFE.getObject(this, C_CHUNK_OFFSET);
+    }
+
+    final R lvConsumerChunk()
+    {
+        return this.consumerChunk;
+    }
+
+    final void soConsumerChunk(R newValue)
+    {
+        UNSAFE.putOrderedObject(this, C_CHUNK_OFFSET, newValue);
+    }
+
+    final long lpConsumerIndex()
+    {
+        return UNSAFE.getLong(this, C_INDEX_OFFSET);
+    }
+
+    final void soConsumerIndex(long newValue)
+    {
+        UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue);
+    }
+}
+
+abstract class MpUnboundedXaddArrayQueuePad5<R extends MpUnboundedXaddChunk<R, E>, E>
+    extends MpUnboundedXaddArrayQueueConsumerFields<R, E>
+{
+    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
+    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
+    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
+    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
+    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
+    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
+    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
+    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
+    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
+    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
+    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
+    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
+    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
+    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
+    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
+    // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
+}
+
+/**
+ * Common infrastructure for the XADD queues.
+ *
+ * @author https://github.com/franz1981
+ */
+abstract class MpUnboundedXaddArrayQueue<R extends MpUnboundedXaddChunk<R, E>, E>
+    extends MpUnboundedXaddArrayQueuePad5<R, E>
+    implements MessagePassingQueue<E>, QueueProgressIndicators
+{
+    // it must be != MpUnboundedXaddChunk.NOT_USED
+    private static final long ROTATION = -2;
+    final int chunkMask;
+    final int chunkShift;
+    final int maxPooledChunks;
+    final SpscArrayQueue<R> freeChunksPool;
+
+    /**
+     * @param chunkSize The buffer size to be used in each chunk of this queue
+     * @param maxPooledChunks The maximum number of reused chunks kept around to avoid allocation, chunks are pre-allocated
+     */
+    MpUnboundedXaddArrayQueue(int chunkSize, int maxPooledChunks)
+    {
+        if (!UnsafeAccess.SUPPORTS_GET_AND_ADD_LONG)
+        {
+            throw new IllegalStateException("Unsafe::getAndAddLong support (JDK 8+) is required for this queue to work");
+        }
+        if (maxPooledChunks < 0)
+        {
+            throw new IllegalArgumentException("Expecting a positive maxPooledChunks, but got:"+maxPooledChunks);
+        }
+        chunkSize = Pow2.roundToPowerOfTwo(chunkSize);
+
+        this.chunkMask = chunkSize - 1;
+        this.chunkShift = Integer.numberOfTrailingZeros(chunkSize);
+        freeChunksPool = new SpscArrayQueue<R>(maxPooledChunks);
+
+        final R first = newChunk(0, null, chunkSize, maxPooledChunks > 0);
+        soProducerChunk(first);
+        soProducerChunkIndex(0);
+        soConsumerChunk(first);
+        for (int i = 1; i < maxPooledChunks; i++)
+        {
+            freeChunksPool.offer(newChunk(NOT_USED, null, chunkSize, true));
+        }
+        this.maxPooledChunks = maxPooledChunks;
+    }
+
+    public final int chunkSize()
+    {
+        return chunkMask + 1;
+    }
+
+    public final int maxPooledChunks()
+    {
+        return maxPooledChunks;
+    }
+
+    abstract R newChunk(long index, R prev, int chunkSize, boolean pooled);
+
+    @Override
+    public long currentProducerIndex()
+    {
+        return lvProducerIndex();
+    }
+
+    @Override
+    public long currentConsumerIndex()
+    {
+        return lvConsumerIndex();
+    }
+
+    /**
+     * We're here because currentChunk.index doesn't match the expectedChunkIndex. To resolve we must now chase the linked
+     * chunks to the appropriate chunk. More than one producer may end up racing to add or discover new chunks.
+     *
+     * @param initialChunk the starting point chunk, which does not match the required chunk index
+     * @param requiredChunkIndex the chunk index we need
+     * @return the chunk matching the required index
+     */
+    final R producerChunkForIndex(
+        final R initialChunk,
+        final long requiredChunkIndex)
+    {
+        R currentChunk = initialChunk;
+        long jumpBackward;
+        while (true)
+        {
+            if (currentChunk == null)
+            {
+                currentChunk = lvProducerChunk();
+            }
+            final long currentChunkIndex = currentChunk.lvIndex();
+            assert currentChunkIndex != NOT_USED;
+            // if the required chunk index is less than the current chunk index then we need to walk the linked list of
+            // chunks back to the required index
+            jumpBackward = currentChunkIndex - requiredChunkIndex;
+            if (jumpBackward >= 0)
+            {
+                break;
+            }
+            // try validate against the last producer chunk index
+            if (lvProducerChunkIndex() == currentChunkIndex)
+            {
+                currentChunk = appendNextChunks(currentChunk, currentChunkIndex, -jumpBackward);
+            }
+            else
+            {
+                currentChunk = null;
+            }
+        }
+        for (long i = 0; i < jumpBackward; i++)
+        {
+            // prev cannot be null, because the consumer cannot null it without consuming the element for which we are
+            // trying to get the chunk.
+            currentChunk = currentChunk.lvPrev();
+            assert currentChunk != null;
+        }
+        assert currentChunk.lvIndex() == requiredChunkIndex;
+        return currentChunk;
+    }
+
+    protected final R appendNextChunks(
+        R currentChunk,
+        long currentChunkIndex,
+        long chunksToAppend)
+    {
+        assert currentChunkIndex != NOT_USED;
+        // prevent other concurrent attempts on appendNextChunk
+        if (!casProducerChunkIndex(currentChunkIndex, ROTATION))
+        {
+            return null;
+        }
+        /* LOCKED FOR APPEND */
+        {
+            // it is valid for the currentChunk to be consumed while appending is in flight, but it's not valid for the
+            // current chunk ordering to change otherwise.
+            assert currentChunkIndex == currentChunk.lvIndex();
+
+            for (long i = 1; i <= chunksToAppend; i++)
+            {
+                R newChunk = newOrPooledChunk(currentChunk, currentChunkIndex + i);
+                soProducerChunk(newChunk);
+                //link the next chunk only when finished
+                currentChunk.soNext(newChunk);
+                currentChunk = newChunk;
+            }
+
+            // release appending
+            soProducerChunkIndex(currentChunkIndex + chunksToAppend);
+        }
+        /* UNLOCKED FOR APPEND */
+        return currentChunk;
+    }
+
+    private R newOrPooledChunk(R prevChunk, long nextChunkIndex)
+    {
+        R newChunk = freeChunksPool.poll();
+        if (newChunk != null)
+        {
+            // single-writer: prevChunk::index == nextChunkIndex is protecting it
+            assert newChunk.lvIndex() < prevChunk.lvIndex();
+            newChunk.soPrev(prevChunk);
+            // index set is releasing prev, allowing other pending offers to continue
+            newChunk.soIndex(nextChunkIndex);
+        }
+        else
+        {
+            newChunk = newChunk(nextChunkIndex, prevChunk, chunkMask + 1, false);
+        }
+        return newChunk;
+    }
+
+
+    /**
+     * Does not null out the first element of `next`, callers must do that
+     */
+    final void moveToNextConsumerChunk(R cChunk, R next)
+    {
+        // avoid GC nepotism
+        cChunk.soNext(null);
+        next.soPrev(null);
+        // no need to cChunk.soIndex(NOT_USED)
+        if (cChunk.isPooled())
+        {
+            final boolean pooled = freeChunksPool.offer(cChunk);
+            assert pooled;
+        }
+        this.soConsumerChunk(next);
+        // MC case:
+        // from now on the code is not single-threaded anymore and
+        // other consumers can move forward consumerIndex
+    }
+
+    @Override
+    public Iterator<E> iterator()
+    {
+        throw new UnsupportedOperationException();
+    }
+
+    @Override
+    public int size()
+    {
+        return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR);
+    }
+
+    @Override
+    public boolean isEmpty()
+    {
+        return IndexedQueueSizeUtil.isEmpty(this);
+    }
+
+    @Override
+    public int capacity()
+    {
+        return MessagePassingQueue.UNBOUNDED_CAPACITY;
+    }
+
+    @Override
+    public boolean relaxedOffer(E e)
+    {
+        return offer(e);
+    }
+
+    @Override
+    public int drain(Consumer<E> c)
+    {
+        return MessagePassingQueueUtil.drain(this, c);
+    }
+
+    @Override
+    public int fill(Supplier<E> s)
+    {
+        final int chunkCapacity = chunkMask + 1;
+        final int offerBatch = Math.min(PortableJvmInfo.RECOMENDED_OFFER_BATCH, chunkCapacity);
+        return MessagePassingQueueUtil.fillInBatchesToLimit(this, s, offerBatch, chunkCapacity);
+    }
+
+    @Override
+    public int drain(Consumer<E> c, int limit)
+    {
+        return MessagePassingQueueUtil.drain(this, c, limit);
+    }
+
+    @Override
+    public void drain(Consumer<E> c, WaitStrategy wait, ExitCondition exit)
+    {
+        MessagePassingQueueUtil.drain(this, c, wait, exit);
+    }
+
+    @Override
+    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit)
+    {
+        MessagePassingQueueUtil.fill(this, s, wait, exit);
+    }
+
+    @Override
+    public String toString()
+    {
+        return this.getClass().getName();
+    }
+}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddChunk.java b/netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddChunk.java
new file mode 100644
index 0000000..eb0411d
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpUnboundedXaddChunk.java
@@ -0,0 +1,95 @@
+package org.jctools.queues;
+
+import org.jctools.util.InternalAPI;
+
+import static org.jctools.util.UnsafeAccess.UNSAFE;
+import static org.jctools.util.UnsafeAccess.fieldOffset;
+import static org.jctools.util.UnsafeRefArrayAccess.*;
+
+/**
+ * A chunk of the XADD queues' backing storage: a reference array plus prev/next links and the
+ * chunk's logical index. Accessor prefixes follow the JCTools lv/lp/so/sp convention.
+ *
+ * @param <R> the concrete chunk type (self type)
+ * @param <E> the element type
+ */
+@InternalAPI
+public class MpUnboundedXaddChunk<R extends MpUnboundedXaddChunk<R, E>, E>
+{
+    public final static int NOT_USED = -1;
+
+    private static final long PREV_OFFSET = fieldOffset(MpUnboundedXaddChunk.class, "prev");
+    private static final long NEXT_OFFSET = fieldOffset(MpUnboundedXaddChunk.class, "next");
+    private static final long INDEX_OFFSET = fieldOffset(MpUnboundedXaddChunk.class, "index");
+
+    private final boolean pooled;
+    private final E[] buffer;
+
+    private volatile R prev;
+    private volatile long index;
+    private volatile R next;
+
+    protected MpUnboundedXaddChunk(long index, R prev, int size, boolean pooled)
+    {
+        buffer = allocateRefArray(size);
+        // next is null
+        soPrev(prev);
+        spIndex(index);
+        this.pooled = pooled;
+    }
+
+    public final boolean isPooled()
+    {
+        return pooled;
+    }
+
+    public final long lvIndex()
+    {
+        return index;
+    }
+
+    public final void soIndex(long index)
+    {
+        UNSAFE.putOrderedLong(this, INDEX_OFFSET, index);
+    }
+
+    final void spIndex(long index)
+    {
+        UNSAFE.putLong(this, INDEX_OFFSET, index);
+    }
+
+    public final R lvNext()
+    {
+        return next;
+    }
+
+    public final void soNext(R value)
+    {
+        UNSAFE.putOrderedObject(this, NEXT_OFFSET, value);
+    }
+
+    public final R lvPrev()
+    {
+        return prev;
+    }
+
+    // NOTE(review): named soPrev but uses a plain putObject — matches upstream JCTools; ordering is
+    // provided by the subsequent soIndex release store. Confirm before changing.
+    public final void soPrev(R value)
+    {
+        UNSAFE.putObject(this, PREV_OFFSET, value);
+    }
+
+    public final void soElement(int index, E e)
+    {
+        soRefElement(buffer, calcRefElementOffset(index), e);
+    }
+
+    public final E lvElement(int index)
+    {
+        return lvRefElement(buffer, calcRefElementOffset(index));
+    }
+
+    /**
+     * Spin until the element at {@code index} is observed as null/non-null (per {@code isNull}).
+     */
+    public final E spinForElement(int index, boolean isNull)
+    {
+        E[] buffer = this.buffer;
+        long offset = calcRefElementOffset(index);
+        E e;
+        do
+        {
+            e = lvRefElement(buffer, offset);
+        }
+        while (isNull != (e == null));
+        return e;
+    }
+}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpmcArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpmcArrayQueue.java
new file mode 100755
index 0000000..02562e1 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpmcArrayQueue.java @@ -0,0 +1,624 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.RangeUtil; + +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeLongArrayAccess.*; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +abstract class MpmcArrayQueueL1Pad extends ConcurrentSequencedCircularArrayQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpmcArrayQueueL1Pad(int capacity) + { + 
super(capacity); + } +} + +//$gen:ordered-fields +abstract class MpmcArrayQueueProducerIndexField extends MpmcArrayQueueL1Pad +{ + private final static long P_INDEX_OFFSET = fieldOffset(MpmcArrayQueueProducerIndexField.class, "producerIndex"); + + private volatile long producerIndex; + + MpmcArrayQueueProducerIndexField(int capacity) + { + super(capacity); + } + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final boolean casProducerIndex(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +abstract class MpmcArrayQueueL2Pad extends MpmcArrayQueueProducerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpmcArrayQueueL2Pad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class MpmcArrayQueueConsumerIndexField extends MpmcArrayQueueL2Pad +{ + private final static long C_INDEX_OFFSET = fieldOffset(MpmcArrayQueueConsumerIndexField.class, "consumerIndex"); + + private volatile long consumerIndex; + + MpmcArrayQueueConsumerIndexField(int capacity) + { + super(capacity); + } + + @Override + public final long 
lvConsumerIndex() + { + return consumerIndex; + } + + final boolean casConsumerIndex(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, C_INDEX_OFFSET, expect, newValue); + } +} + +abstract class MpmcArrayQueueL3Pad extends MpmcArrayQueueConsumerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpmcArrayQueueL3Pad(int capacity) + { + super(capacity); + } +} + +/** + * A Multi-Producer-Multi-Consumer queue based on a {@link org.jctools.queues.ConcurrentCircularArrayQueue}. This + * implies that any and all threads may call the offer/poll/peek methods and correctness is maintained.
 * <p>
 * This implementation follows patterns documented on the package level for False Sharing protection.
 * <p>
 * The algorithm for offer/poll is an adaptation of the one put forward by D. Vyukov (See <a
 * href="https://www.1024cores.net/home/lock-free-algorithms/queues/bounded-mpmc-queue">here</a>). The original
 * algorithm uses an array of structs which should offer nice locality properties but is sadly not possible in
 * Java (waiting on Value Types or similar). The alternative explored here utilizes 2 arrays, one for each
 * field of the struct. There is a further alternative in the experimental project which uses iteration phase
 * markers to achieve the same algo and is closer structurally to the original, but sadly does not perform as
 * well as this implementation.
 * <p>
 * Tradeoffs to keep in mind:
 * <ol>
 * <li>Padding for false sharing: counter fields and queue fields are all padded as well as either side of
 * both arrays. We are trading memory to avoid false sharing(active and passive).
 * <li>2 arrays instead of one: The algorithm requires an extra array of longs matching the size of the
 * elements array. This is doubling/tripling the memory allocated for the buffer.
 * <li>Power of 2 capacity: Actual elements buffer (and sequence buffer) is the closest power of 2 larger or
 * equal to the requested capacity.
 * </ol>
+ */ +public class MpmcArrayQueue extends MpmcArrayQueueL3Pad +{ + public static final int MAX_LOOK_AHEAD_STEP = Integer.getInteger("jctools.mpmc.max.lookahead.step", 4096); + private final int lookAheadStep; + + public MpmcArrayQueue(final int capacity) + { + super(RangeUtil.checkGreaterThanOrEqual(capacity, 2, "capacity")); + lookAheadStep = Math.max(2, Math.min(capacity() / 4, MAX_LOOK_AHEAD_STEP)); + } + + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + final long mask = this.mask; + final long capacity = mask + 1; + final long[] sBuffer = sequenceBuffer; + + long pIndex; + long seqOffset; + long seq; + long cIndex = Long.MIN_VALUE;// start with bogus value, hope we don't need it + do + { + pIndex = lvProducerIndex(); + seqOffset = calcCircularLongElementOffset(pIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + // consumer has not moved this seq forward, it's as last producer left + if (seq < pIndex) + { + // Extra check required to ensure [Queue.offer == false iff queue is full] + if (pIndex - capacity >= cIndex && // test against cached cIndex + pIndex - capacity >= (cIndex = lvConsumerIndex())) // test against latest cIndex + { + return false; + } + else + { + seq = pIndex + 1; // (+) hack to make it go around again without CAS + } + } + } + while (seq > pIndex || // another producer has moved the sequence(or +) + !casProducerIndex(pIndex, pIndex + 1)); // failed to increment + + // casProducerIndex ensures correct construction + spRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), e); + // seq++; + soLongElement(sBuffer, seqOffset, pIndex + 1); + return true; + } + + /** + * {@inheritDoc} + *

+ * Because return null indicates queue is empty we cannot simply rely on next element visibility for poll + * and must test producer index when next element is not visible. + */ + @Override + public E poll() + { + // local load of field to avoid repeated loads after volatile reads + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + + long cIndex; + long seq; + long seqOffset; + long expectedSeq; + long pIndex = -1; // start with bogus value, hope we don't need it + do + { + cIndex = lvConsumerIndex(); + seqOffset = calcCircularLongElementOffset(cIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + expectedSeq = cIndex + 1; + if (seq < expectedSeq) + { + // slot has not been moved by producer + if (cIndex >= pIndex && // test against cached pIndex + cIndex == (pIndex = lvProducerIndex())) // update pIndex if we must + { + // strict empty check, this ensures [Queue.poll() == null iff isEmpty()] + return null; + } + else + { + seq = expectedSeq + 1; // trip another go around + } + } + } + while (seq > expectedSeq || // another consumer beat us to it + !casConsumerIndex(cIndex, cIndex + 1)); // failed the CAS + + final long offset = calcCircularRefElementOffset(cIndex, mask); + final E e = lpRefElement(buffer, offset); + spRefElement(buffer, offset, null); + // i.e. 
seq += capacity + soLongElement(sBuffer, seqOffset, cIndex + mask + 1); + return e; + } + + @Override + public E peek() + { + // local load of field to avoid repeated loads after volatile reads + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + + long cIndex; + long seq; + long seqOffset; + long expectedSeq; + long pIndex = -1; // start with bogus value, hope we don't need it + E e; + while (true) + { + cIndex = lvConsumerIndex(); + seqOffset = calcCircularLongElementOffset(cIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + expectedSeq = cIndex + 1; + if (seq < expectedSeq) + { + // slot has not been moved by producer + if (cIndex >= pIndex && // test against cached pIndex + cIndex == (pIndex = lvProducerIndex())) // update pIndex if we must + { + // strict empty check, this ensures [Queue.poll() == null iff isEmpty()] + return null; + } + } + else if (seq == expectedSeq) + { + final long offset = calcCircularRefElementOffset(cIndex, mask); + e = lvRefElement(buffer, offset); + if (lvConsumerIndex() == cIndex) + return e; + } + } + } + + @Override + public boolean relaxedOffer(E e) + { + if (null == e) + { + throw new NullPointerException(); + } + final long mask = this.mask; + final long[] sBuffer = sequenceBuffer; + + long pIndex; + long seqOffset; + long seq; + do + { + pIndex = lvProducerIndex(); + seqOffset = calcCircularLongElementOffset(pIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + if (seq < pIndex) + { // slot not cleared by consumer yet + return false; + } + } + while (seq > pIndex || // another producer has moved the sequence + !casProducerIndex(pIndex, pIndex + 1)); // failed to increment + + // casProducerIndex ensures correct construction + spRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), e); + soLongElement(sBuffer, seqOffset, pIndex + 1); + return true; + } + + @Override + public E relaxedPoll() + { + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + + long cIndex; + 
long seqOffset; + long seq; + long expectedSeq; + do + { + cIndex = lvConsumerIndex(); + seqOffset = calcCircularLongElementOffset(cIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + expectedSeq = cIndex + 1; + if (seq < expectedSeq) + { + return null; + } + } + while (seq > expectedSeq || // another consumer beat us to it + !casConsumerIndex(cIndex, cIndex + 1)); // failed the CAS + + final long offset = calcCircularRefElementOffset(cIndex, mask); + final E e = lpRefElement(buffer, offset); + spRefElement(buffer, offset, null); + soLongElement(sBuffer, seqOffset, cIndex + mask + 1); + return e; + } + + @Override + public E relaxedPeek() + { + // local load of field to avoid repeated loads after volatile reads + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + + long cIndex; + long seq; + long seqOffset; + long expectedSeq; + E e; + do + { + cIndex = lvConsumerIndex(); + seqOffset = calcCircularLongElementOffset(cIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + expectedSeq = cIndex + 1; + if (seq < expectedSeq) + { + return null; + } + else if (seq == expectedSeq) + { + final long offset = calcCircularRefElementOffset(cIndex, mask); + e = lvRefElement(buffer, offset); + if (lvConsumerIndex() == cIndex) + return e; + } + } + while (true); + } + + @Override + public int drain(Consumer c, int limit) + { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + final E[] buffer = this.buffer; + final int maxLookAheadStep = Math.min(this.lookAheadStep, limit); + int consumed = 0; + + while (consumed < limit) + { + final int remaining = limit - consumed; + final int lookAheadStep = Math.min(remaining, maxLookAheadStep); + final long cIndex = lvConsumerIndex(); + final long lookAheadIndex = cIndex + lookAheadStep - 1; + final long 
lookAheadSeqOffset = calcCircularLongElementOffset(lookAheadIndex, mask); + final long lookAheadSeq = lvLongElement(sBuffer, lookAheadSeqOffset); + final long expectedLookAheadSeq = lookAheadIndex + 1; + if (lookAheadSeq == expectedLookAheadSeq && casConsumerIndex(cIndex, expectedLookAheadSeq)) + { + for (int i = 0; i < lookAheadStep; i++) + { + final long index = cIndex + i; + final long seqOffset = calcCircularLongElementOffset(index, mask); + final long offset = calcCircularRefElementOffset(index, mask); + final long expectedSeq = index + 1; + while (lvLongElement(sBuffer, seqOffset) != expectedSeq) + { + + } + final E e = lpRefElement(buffer, offset); + spRefElement(buffer, offset, null); + soLongElement(sBuffer, seqOffset, index + mask + 1); + c.accept(e); + } + consumed += lookAheadStep; + } + else + { + if (lookAheadSeq < expectedLookAheadSeq) + { + if (notAvailable(cIndex, mask, sBuffer, cIndex + 1)) + { + return consumed; + } + } + return consumed + drainOneByOne(c, remaining); + } + } + return limit; + } + + private int drainOneByOne(Consumer c, int limit) + { + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + final E[] buffer = this.buffer; + + long cIndex; + long seqOffset; + long seq; + long expectedSeq; + for (int i = 0; i < limit; i++) + { + do + { + cIndex = lvConsumerIndex(); + seqOffset = calcCircularLongElementOffset(cIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + expectedSeq = cIndex + 1; + if (seq < expectedSeq) + { + return i; + } + } + while (seq > expectedSeq || // another consumer beat us to it + !casConsumerIndex(cIndex, cIndex + 1)); // failed the CAS + + final long offset = calcCircularRefElementOffset(cIndex, mask); + final E e = lpRefElement(buffer, offset); + spRefElement(buffer, offset, null); + soLongElement(sBuffer, seqOffset, cIndex + mask + 1); + c.accept(e); + } + return limit; + } + + @Override + public int fill(Supplier s, int limit) + { + if (null == s) + throw new 
IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + final E[] buffer = this.buffer; + final int maxLookAheadStep = Math.min(this.lookAheadStep, limit); + int produced = 0; + + while (produced < limit) + { + final int remaining = limit - produced; + final int lookAheadStep = Math.min(remaining, maxLookAheadStep); + final long pIndex = lvProducerIndex(); + final long lookAheadIndex = pIndex + lookAheadStep - 1; + final long lookAheadSeqOffset = calcCircularLongElementOffset(lookAheadIndex, mask); + final long lookAheadSeq = lvLongElement(sBuffer, lookAheadSeqOffset); + final long expectedLookAheadSeq = lookAheadIndex; + if (lookAheadSeq == expectedLookAheadSeq && casProducerIndex(pIndex, expectedLookAheadSeq + 1)) + { + for (int i = 0; i < lookAheadStep; i++) + { + final long index = pIndex + i; + final long seqOffset = calcCircularLongElementOffset(index, mask); + final long offset = calcCircularRefElementOffset(index, mask); + while (lvLongElement(sBuffer, seqOffset) != index) + { + + } + // Ordered store ensures correct construction + soRefElement(buffer, offset, s.get()); + soLongElement(sBuffer, seqOffset, index + 1); + } + produced += lookAheadStep; + } + else + { + if (lookAheadSeq < expectedLookAheadSeq) + { + if (notAvailable(pIndex, mask, sBuffer, pIndex)) + { + return produced; + } + } + return produced + fillOneByOne(s, remaining); + } + } + return limit; + } + + private boolean notAvailable(long index, long mask, long[] sBuffer, long expectedSeq) + { + final long seqOffset = calcCircularLongElementOffset(index, mask); + final long seq = lvLongElement(sBuffer, seqOffset); + if (seq < expectedSeq) + { + return true; + } + return false; + } + + private int fillOneByOne(Supplier s, int limit) + { + final long[] sBuffer = sequenceBuffer; + final long mask = this.mask; + final 
E[] buffer = this.buffer; + + long pIndex; + long seqOffset; + long seq; + for (int i = 0; i < limit; i++) + { + do + { + pIndex = lvProducerIndex(); + seqOffset = calcCircularLongElementOffset(pIndex, mask); + seq = lvLongElement(sBuffer, seqOffset); + if (seq < pIndex) + { // slot not cleared by consumer yet + return i; + } + } + while (seq > pIndex || // another producer has moved the sequence + !casProducerIndex(pIndex, pIndex + 1)); // failed to increment + // Ordered store ensures correct construction + soRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), s.get()); + soLongElement(sBuffer, seqOffset, pIndex + 1); + } + return limit; + } + + @Override + public int drain(Consumer c) + { + return MessagePassingQueueUtil.drain(this, c); + } + + @Override + public int fill(Supplier s) + { + return MessagePassingQueueUtil.fillBounded(this, s); + } + + @Override + public void drain(Consumer c, WaitStrategy w, ExitCondition exit) + { + MessagePassingQueueUtil.drain(this, c, w, exit); + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddArrayQueue.java new file mode 100644 index 0000000..eabd2d2 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddArrayQueue.java @@ -0,0 +1,475 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + + + +/** + * An MPMC array queue which grows unbounded in linked chunks.
+ * Differently from {@link MpmcArrayQueue} it is designed to provide a better scaling when more + * producers are concurrently offering.
+ * Users should be aware that {@link #poll()} could spin while awaiting a new element to be available: + * to avoid this behaviour {@link #relaxedPoll()} should be used instead, accounting for the semantic differences + * between the twos. + * + * @author https://github.com/franz1981 + */ +public class MpmcUnboundedXaddArrayQueue extends MpUnboundedXaddArrayQueue, E> +{ + + /** + * @param chunkSize The buffer size to be used in each chunk of this queue + * @param maxPooledChunks The maximum number of reused chunks kept around to avoid allocation, chunks are pre-allocated + */ + public MpmcUnboundedXaddArrayQueue(int chunkSize, int maxPooledChunks) + { + super(chunkSize, maxPooledChunks); + } + + public MpmcUnboundedXaddArrayQueue(int chunkSize) + { + this(chunkSize, 2); + } + + @Override + final MpmcUnboundedXaddChunk newChunk(long index, MpmcUnboundedXaddChunk prev, int chunkSize, boolean pooled) + { + return new MpmcUnboundedXaddChunk(index, prev, chunkSize, pooled); + } + + @Override + public boolean offer(E e) + { + if (null == e) + { + throw new NullPointerException(); + } + final int chunkMask = this.chunkMask; + final int chunkShift = this.chunkShift; + + final long pIndex = getAndIncrementProducerIndex(); + + final int piChunkOffset = (int) (pIndex & chunkMask); + final long piChunkIndex = pIndex >> chunkShift; + + MpmcUnboundedXaddChunk pChunk = lvProducerChunk(); + if (pChunk.lvIndex() != piChunkIndex) + { + // Other producers may have advanced the producer chunk as we claimed a slot in a prev chunk, or we may have + // now stepped into a brand new chunk which needs appending. 
+ pChunk = producerChunkForIndex(pChunk, piChunkIndex); + } + + final boolean isPooled = pChunk.isPooled(); + + if (isPooled) + { + // wait any previous consumer to finish its job + pChunk.spinForElement(piChunkOffset, true); + } + pChunk.soElement(piChunkOffset, e); + if (isPooled) + { + pChunk.soSequence(piChunkOffset, piChunkIndex); + } + return true; + } + + @Override + public E poll() + { + final int chunkMask = this.chunkMask; + final int chunkShift = this.chunkShift; + long cIndex; + MpmcUnboundedXaddChunk cChunk; + int ciChunkOffset; + boolean isFirstElementOfNewChunk; + boolean pooled = false; + E e = null; + MpmcUnboundedXaddChunk next = null; + long pIndex = -1; // start with bogus value, hope we don't need it + long ciChunkIndex; + while (true) + { + isFirstElementOfNewChunk = false; + cIndex = this.lvConsumerIndex(); + // chunk is in sync with the index, and is safe to mutate after CAS of index (because we pre-verify it + // matched the indicate ciChunkIndex) + cChunk = this.lvConsumerChunk(); + + ciChunkOffset = (int) (cIndex & chunkMask); + ciChunkIndex = cIndex >> chunkShift; + + final long ccChunkIndex = cChunk.lvIndex(); + if (ciChunkOffset == 0 && cIndex != 0) { + if (ciChunkIndex - ccChunkIndex != 1) + { + continue; + } + isFirstElementOfNewChunk = true; + next = cChunk.lvNext(); + // next could have been modified by another racing consumer, but: + // - if null: it still needs to check q empty + casConsumerIndex + // - if !null: it will fail on casConsumerIndex + if (next == null) + { + if (cIndex >= pIndex && // test against cached pIndex + cIndex == (pIndex = lvProducerIndex())) // update pIndex if we must + { + // strict empty check, this ensures [Queue.poll() == null iff isEmpty()] + return null; + } + // we will go ahead with the CAS and have the winning consumer spin for the next buffer + } + // not empty: can attempt the cas (and transition to next chunk if successful) + if (casConsumerIndex(cIndex, cIndex + 1)) + { + break; + } + 
continue; + } + if (ccChunkIndex > ciChunkIndex) + { + //stale view of the world + continue; + } + // mid chunk elements + assert !isFirstElementOfNewChunk && ccChunkIndex <= ciChunkIndex; + pooled = cChunk.isPooled(); + if (ccChunkIndex == ciChunkIndex) + { + if (pooled) + { + // Pooled chunks need a stronger guarantee than just element null checking in case of a stale view + // on a reused entry where a racing consumer has grabbed the slot but not yet null-ed it out and a + // producer has not yet set it to the new value. + final long sequence = cChunk.lvSequence(ciChunkOffset); + if (sequence == ciChunkIndex) + { + if (casConsumerIndex(cIndex, cIndex + 1)) + { + break; + } + continue; + } + if (sequence > ciChunkIndex) + { + //stale view of the world + continue; + } + // sequence < ciChunkIndex: element yet to be set? + } + else + { + e = cChunk.lvElement(ciChunkOffset); + if (e != null) + { + if (casConsumerIndex(cIndex, cIndex + 1)) + { + break; + } + continue; + } + // e == null: element yet to be set? 
+ } + } + // ccChunkIndex < ciChunkIndex || e == null || sequence < ciChunkIndex: + if (cIndex >= pIndex && // test against cached pIndex + cIndex == (pIndex = lvProducerIndex())) // update pIndex if we must + { + // strict empty check, this ensures [Queue.poll() == null iff isEmpty()] + return null; + } + } + + // if we are the isFirstElementOfNewChunk we need to get the consumer chunk + if (isFirstElementOfNewChunk) + { + e = switchToNextConsumerChunkAndPoll(cChunk, next, ciChunkIndex); + } + else + { + if (pooled) + { + e = cChunk.lvElement(ciChunkOffset); + } + assert !cChunk.isPooled() || (cChunk.isPooled() && cChunk.lvSequence(ciChunkOffset) == ciChunkIndex); + + cChunk.soElement(ciChunkOffset, null); + } + return e; + } + + private E switchToNextConsumerChunkAndPoll( + MpmcUnboundedXaddChunk cChunk, + MpmcUnboundedXaddChunk next, + long expectedChunkIndex) + { + if (next == null) { + final long ccChunkIndex = expectedChunkIndex - 1; + assert cChunk.lvIndex() == ccChunkIndex; + if (lvProducerChunkIndex() == ccChunkIndex) { + // no need to help too much here or the consumer latency will be hurt + next = appendNextChunks(cChunk, ccChunkIndex, 1); + } + } + while (next == null) + { + next = cChunk.lvNext(); + } + // we can freely spin awaiting producer, because we are the only one in charge to + // rotate the consumer buffer and use next + final E e = next.spinForElement(0, false); + + final boolean pooled = next.isPooled(); + if (pooled) + { + next.spinForSequence(0, expectedChunkIndex); + } + + next.soElement(0, null); + moveToNextConsumerChunk(cChunk, next); + return e; + } + + @Override + public E peek() + { + final int chunkMask = this.chunkMask; + final int chunkShift = this.chunkShift; + long cIndex; + E e; + do + { + e = null; + cIndex = this.lvConsumerIndex(); + MpmcUnboundedXaddChunk cChunk = this.lvConsumerChunk(); + final int ciChunkOffset = (int) (cIndex & chunkMask); + final long ciChunkIndex = cIndex >> chunkShift; + final boolean 
firstElementOfNewChunk = ciChunkOffset == 0 && cIndex != 0; + if (firstElementOfNewChunk) + { + final long expectedChunkIndex = ciChunkIndex - 1; + if (expectedChunkIndex != cChunk.lvIndex()) + { + continue; + } + final MpmcUnboundedXaddChunk next = cChunk.lvNext(); + if (next == null) + { + continue; + } + cChunk = next; + } + if (cChunk.isPooled()) + { + if (cChunk.lvSequence(ciChunkOffset) != ciChunkIndex) + { + continue; + } + } else { + if (cChunk.lvIndex() != ciChunkIndex) + { + continue; + } + } + e = cChunk.lvElement(ciChunkOffset); + } + // checking again vs consumerIndex changes is necessary to verify that e is still valid + while ((e == null && cIndex != lvProducerIndex()) || + (e != null && cIndex != lvConsumerIndex())); + return e; + } + + @Override + public E relaxedPoll() + { + final int chunkMask = this.chunkMask; + final int chunkShift = this.chunkShift; + final long cIndex = this.lvConsumerIndex(); + final MpmcUnboundedXaddChunk cChunk = this.lvConsumerChunk(); + + final int ciChunkOffset = (int) (cIndex & chunkMask); + final long ciChunkIndex = cIndex >> chunkShift; + + final boolean firstElementOfNewChunk = ciChunkOffset == 0 && cIndex != 0; + if (firstElementOfNewChunk) + { + final long expectedChunkIndex = ciChunkIndex - 1; + final MpmcUnboundedXaddChunk next; + final long ccChunkIndex = cChunk.lvIndex(); + if (expectedChunkIndex != ccChunkIndex || (next = cChunk.lvNext()) == null) + { + return null; + } + E e = null; + final boolean pooled = next.isPooled(); + if (pooled) + { + if (next.lvSequence(0) != ciChunkIndex) + { + return null; + } + } + else + { + e = next.lvElement(0); + if (e == null) + { + return null; + } + } + if (!casConsumerIndex(cIndex, cIndex + 1)) + { + return null; + } + if (pooled) + { + e = next.lvElement(0); + } + assert e != null; + + next.soElement(0, null); + moveToNextConsumerChunk(cChunk, next); + return e; + } + else + { + final boolean pooled = cChunk.isPooled(); + E e = null; + if (pooled) + { + final long 
sequence = cChunk.lvSequence(ciChunkOffset); + if (sequence != ciChunkIndex) + { + return null; + } + } + else + { + final long ccChunkIndex = cChunk.lvIndex(); + if (ccChunkIndex != ciChunkIndex || (e = cChunk.lvElement(ciChunkOffset)) == null) + { + return null; + } + } + if (!casConsumerIndex(cIndex, cIndex + 1)) + { + return null; + } + if (pooled) + { + e = cChunk.lvElement(ciChunkOffset); + assert e != null; + } + assert !pooled || (pooled && cChunk.lvSequence(ciChunkOffset) == ciChunkIndex); + cChunk.soElement(ciChunkOffset, null); + return e; + } + } + + @Override + public E relaxedPeek() + { + final int chunkMask = this.chunkMask; + final int chunkShift = this.chunkShift; + final long cIndex = this.lvConsumerIndex(); + final int ciChunkOffset = (int) (cIndex & chunkMask); + final long ciChunkIndex = cIndex >> chunkShift; + + MpmcUnboundedXaddChunk consumerBuffer = this.lvConsumerChunk(); + + final int chunkSize = chunkMask + 1; + final boolean firstElementOfNewChunk = ciChunkOffset == 0 && cIndex >= chunkSize; + if (firstElementOfNewChunk) + { + final long expectedChunkIndex = ciChunkIndex - 1; + if (expectedChunkIndex != consumerBuffer.lvIndex()) + { + return null; + } + final MpmcUnboundedXaddChunk next = consumerBuffer.lvNext(); + if (next == null) + { + return null; + } + consumerBuffer = next; + } + if (consumerBuffer.isPooled()) + { + if (consumerBuffer.lvSequence(ciChunkOffset) != ciChunkIndex) + { + return null; + } + } + else + { + if (consumerBuffer.lvIndex() != ciChunkIndex) + { + return null; + } + } + final E e = consumerBuffer.lvElement(ciChunkOffset); + // checking again vs consumerIndex changes is necessary to verify that e is still valid + if (cIndex != lvConsumerIndex()) + { + return null; + } + return e; + } + + @Override + public int fill(Supplier s, int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 
0) + return 0; + + final int chunkShift = this.chunkShift; + final int chunkMask = this.chunkMask; + long producerSeq = getAndAddProducerIndex(limit); + MpmcUnboundedXaddChunk producerBuffer = null; + for (int i = 0; i < limit; i++) + { + final int pOffset = (int) (producerSeq & chunkMask); + long chunkIndex = producerSeq >> chunkShift; + if (producerBuffer == null || producerBuffer.lvIndex() != chunkIndex) + { + producerBuffer = producerChunkForIndex(producerBuffer, chunkIndex); + if (producerBuffer.isPooled()) + { + chunkIndex = producerBuffer.lvIndex(); + } + } + if (producerBuffer.isPooled()) + { + while (producerBuffer.lvElement(pOffset) != null) + { + + } + } + producerBuffer.soElement(pOffset, s.get()); + if (producerBuffer.isPooled()) + { + producerBuffer.soSequence(pOffset, chunkIndex); + } + producerSeq++; + } + return limit; + } + +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddChunk.java b/netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddChunk.java new file mode 100644 index 0000000..8ed0d8f --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpmcUnboundedXaddChunk.java @@ -0,0 +1,66 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.util.InternalAPI; + +import java.util.Arrays; + +import static org.jctools.util.UnsafeLongArrayAccess.*; + +@InternalAPI +public final class MpmcUnboundedXaddChunk extends MpUnboundedXaddChunk, E> +{ + private final long[] sequence; + + public MpmcUnboundedXaddChunk(long index, MpmcUnboundedXaddChunk prev, int size, boolean pooled) + { + super(index, prev, size, pooled); + if (pooled) + { + sequence = allocateLongArray(size); + Arrays.fill(sequence, MpmcUnboundedXaddChunk.NOT_USED); + } + else + { + sequence = null; + } + } + + public void soSequence(int index, long e) + { + assert isPooled(); + soLongElement(sequence, calcLongElementOffset(index), e); + } + + public long lvSequence(int index) + { + assert isPooled(); + return lvLongElement(sequence, calcLongElementOffset(index)); + } + + public void spinForSequence(int index, long e) + { + assert isPooled(); + final long[] sequence = this.sequence; + final long offset = calcLongElementOffset(index); + while (true) + { + if (lvLongElement(sequence, offset) == e) + { + break; + } + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscArrayQueue.java new file mode 100755 index 0000000..d338991 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpscArrayQueue.java @@ -0,0 +1,588 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +abstract class MpscArrayQueueL1Pad extends ConcurrentCircularArrayQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpscArrayQueueL1Pad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class MpscArrayQueueProducerIndexField extends MpscArrayQueueL1Pad +{ + private final static long P_INDEX_OFFSET = fieldOffset(MpscArrayQueueProducerIndexField.class, "producerIndex"); + + private volatile long producerIndex; + + MpscArrayQueueProducerIndexField(int capacity) + { + super(capacity); + } + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final boolean casProducerIndex(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +abstract class MpscArrayQueueMidPad extends MpscArrayQueueProducerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte 
b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpscArrayQueueMidPad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class MpscArrayQueueProducerLimitField extends MpscArrayQueueMidPad +{ + private final static long P_LIMIT_OFFSET = fieldOffset(MpscArrayQueueProducerLimitField.class, "producerLimit"); + + // First unavailable index the producer may claim up to before rereading the consumer index + private volatile long producerLimit; + + MpscArrayQueueProducerLimitField(int capacity) + { + super(capacity); + this.producerLimit = capacity; + } + + final long lvProducerLimit() + { + return producerLimit; + } + + final void soProducerLimit(long newValue) + { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + +abstract class MpscArrayQueueL2Pad extends MpscArrayQueueProducerLimitField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte 
b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpscArrayQueueL2Pad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class MpscArrayQueueConsumerIndexField extends MpscArrayQueueL2Pad +{ + private final static long C_INDEX_OFFSET = fieldOffset(MpscArrayQueueConsumerIndexField.class, "consumerIndex"); + + private volatile long consumerIndex; + + MpscArrayQueueConsumerIndexField(int capacity) + { + super(capacity); + } + + @Override + public final long lvConsumerIndex() + { + return consumerIndex; + } + + final long lpConsumerIndex() + { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +abstract class MpscArrayQueueL3Pad extends MpscArrayQueueConsumerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte 
b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpscArrayQueueL3Pad(int capacity) + { + super(capacity); + } +} + +/** + * A Multi-Producer-Single-Consumer queue based on a {@link org.jctools.queues.ConcurrentCircularArrayQueue}. This + * implies that any thread may call the offer method, but only a single thread may call poll/peek for correctness to + * maintained.
+ * This implementation follows patterns documented on the package level for False Sharing protection.
+ * This implementation is using the Fast Flow + * method for polling from the queue (with minor change to correctly publish the index) and an extension of + * the Leslie Lamport concurrent queue algorithm (originated by Martin Thompson) on the producer side. + */ +public class MpscArrayQueue extends MpscArrayQueueL3Pad +{ + + public MpscArrayQueue(final int capacity) + { + super(capacity); + } + + /** + * {@link #offer}} if {@link #size()} is less than threshold. + * + * @param e the object to offer onto the queue, not null + * @param threshold the maximum allowable size + * @return true if the offer is successful, false if queue size exceeds threshold + * @since 1.0.1 + */ + public boolean offerIfBelowThreshold(final E e, int threshold) + { + if (null == e) + { + throw new NullPointerException(); + } + + final long mask = this.mask; + final long capacity = mask + 1; + + long producerLimit = lvProducerLimit(); + long pIndex; + do + { + pIndex = lvProducerIndex(); + long available = producerLimit - pIndex; + long size = capacity - available; + if (size >= threshold) + { + final long cIndex = lvConsumerIndex(); + size = pIndex - cIndex; + if (size >= threshold) + { + return false; // the size exceeds threshold + } + else + { + // update producer limit to the next index that we must recheck the consumer index + producerLimit = cIndex + capacity; + + // this is racy, but the race is benign + soProducerLimit(producerLimit); + } + } + } + while (!casProducerIndex(pIndex, pIndex + 1)); + /* + * NOTE: the new producer index value is made visible BEFORE the element in the array. If we relied on + * the index visibility to poll() we would need to handle the case where the element is not visible. + */ + + // Won CAS, move on to storing + final long offset = calcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); + return true; // AWESOME :) + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Lock free offer using a single CAS. As class name suggests access is permitted to many threads + * concurrently. + * + * @see java.util.Queue#offer + * @see org.jctools.queues.MessagePassingQueue#offer + */ + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + + // use a cached view on consumer index (potentially updated in loop) + final long mask = this.mask; + long producerLimit = lvProducerLimit(); + long pIndex; + do + { + pIndex = lvProducerIndex(); + if (pIndex >= producerLimit) + { + final long cIndex = lvConsumerIndex(); + producerLimit = cIndex + mask + 1; + + if (pIndex >= producerLimit) + { + return false; // FULL :( + } + else + { + // update producer limit to the next index that we must recheck the consumer index + // this is racy, but the race is benign + soProducerLimit(producerLimit); + } + } + } + while (!casProducerIndex(pIndex, pIndex + 1)); + /* + * NOTE: the new producer index value is made visible BEFORE the element in the array. If we relied on + * the index visibility to poll() we would need to handle the case where the element is not visible. + */ + + // Won CAS, move on to storing + final long offset = calcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); + return true; // AWESOME :) + } + + /** + * A wait free alternative to offer which fails on CAS failure. 
+ * + * @param e new element, not null + * @return 1 if next element cannot be filled, -1 if CAS failed, 0 if successful + */ + public final int failFastOffer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + final long mask = this.mask; + final long capacity = mask + 1; + final long pIndex = lvProducerIndex(); + long producerLimit = lvProducerLimit(); + if (pIndex >= producerLimit) + { + final long cIndex = lvConsumerIndex(); + producerLimit = cIndex + capacity; + if (pIndex >= producerLimit) + { + return 1; // FULL :( + } + else + { + // update producer limit to the next index that we must recheck the consumer index + soProducerLimit(producerLimit); + } + } + + // look Ma, no loop! + if (!casProducerIndex(pIndex, pIndex + 1)) + { + return -1; // CAS FAIL :( + } + + // Won CAS, move on to storing + final long offset = calcCircularRefElementOffset(pIndex, mask); + soRefElement(buffer, offset, e); + return 0; // AWESOME :) + } + + /** + * {@inheritDoc} + *

+ * IMPLEMENTATION NOTES:
+ * Lock free poll using ordered loads/stores. As class name suggests access is limited to a single thread. + * + * @see java.util.Queue#poll + * @see org.jctools.queues.MessagePassingQueue#poll + */ + @Override + public E poll() + { + final long cIndex = lpConsumerIndex(); + final long offset = calcCircularRefElementOffset(cIndex, mask); + // Copy field to avoid re-reading after volatile load + final E[] buffer = this.buffer; + + // If we can't see the next available element we can't poll + E e = lvRefElement(buffer, offset); + if (null == e) + { + /* + * NOTE: Queue may not actually be empty in the case of a producer (P1) being interrupted after + * winning the CAS on offer but before storing the element in the queue. Other producers may go on + * to fill up the queue after this element. + */ + if (cIndex != lvProducerIndex()) + { + do + { + e = lvRefElement(buffer, offset); + } + while (e == null); + } + else + { + return null; + } + } + + spRefElement(buffer, offset, null); + soConsumerIndex(cIndex + 1); + return e; + } + + /** + * {@inheritDoc} + *

+ * IMPLEMENTATION NOTES:
+ * Lock free peek using ordered loads. As class name suggests access is limited to a single thread. + * + * @see java.util.Queue#poll + * @see org.jctools.queues.MessagePassingQueue#poll + */ + @Override + public E peek() + { + // Copy field to avoid re-reading after volatile load + final E[] buffer = this.buffer; + + final long cIndex = lpConsumerIndex(); + final long offset = calcCircularRefElementOffset(cIndex, mask); + E e = lvRefElement(buffer, offset); + if (null == e) + { + /* + * NOTE: Queue may not actually be empty in the case of a producer (P1) being interrupted after + * winning the CAS on offer but before storing the element in the queue. Other producers may go on + * to fill up the queue after this element. + */ + if (cIndex != lvProducerIndex()) + { + do + { + e = lvRefElement(buffer, offset); + } + while (e == null); + } + else + { + return null; + } + } + return e; + } + + @Override + public boolean relaxedOffer(E e) + { + return offer(e); + } + + @Override + public E relaxedPoll() + { + final E[] buffer = this.buffer; + final long cIndex = lpConsumerIndex(); + final long offset = calcCircularRefElementOffset(cIndex, mask); + + // If we can't see the next available element we can't poll + E e = lvRefElement(buffer, offset); + if (null == e) + { + return null; + } + + spRefElement(buffer, offset, null); + soConsumerIndex(cIndex + 1); + return e; + } + + @Override + public E relaxedPeek() + { + final E[] buffer = this.buffer; + final long mask = this.mask; + final long cIndex = lpConsumerIndex(); + return lvRefElement(buffer, calcCircularRefElementOffset(cIndex, mask)); + } + + @Override + public int drain(final Consumer c, final int limit) + { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + + final E[] buffer = this.buffer; + final long mask = this.mask; + final long cIndex = lpConsumerIndex(); + + for (int i = 
0; i < limit; i++) + { + final long index = cIndex + i; + final long offset = calcCircularRefElementOffset(index, mask); + final E e = lvRefElement(buffer, offset); + if (null == e) + { + return i; + } + spRefElement(buffer, offset, null); + soConsumerIndex(index + 1); // ordered store -> atomic and ordered for size() + c.accept(e); + } + return limit; + } + + @Override + public int fill(Supplier s, int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + final long mask = this.mask; + final long capacity = mask + 1; + long producerLimit = lvProducerLimit(); + long pIndex; + int actualLimit; + do + { + pIndex = lvProducerIndex(); + long available = producerLimit - pIndex; + if (available <= 0) + { + final long cIndex = lvConsumerIndex(); + producerLimit = cIndex + capacity; + available = producerLimit - pIndex; + if (available <= 0) + { + return 0; // FULL :( + } + else + { + // update producer limit to the next index that we must recheck the consumer index + soProducerLimit(producerLimit); + } + } + actualLimit = Math.min((int) available, limit); + } + while (!casProducerIndex(pIndex, pIndex + actualLimit)); + // right, now we claimed a few slots and can fill them with goodness + final E[] buffer = this.buffer; + for (int i = 0; i < actualLimit; i++) + { + // Won CAS, move on to storing + final long offset = calcCircularRefElementOffset(pIndex + i, mask); + soRefElement(buffer, offset, s.get()); + } + return actualLimit; + } + + @Override + public int drain(Consumer c) + { + return drain(c, capacity()); + } + + @Override + public int fill(Supplier s) + { + return MessagePassingQueueUtil.fillBounded(this, s); + } + + @Override + public void drain(Consumer c, WaitStrategy w, ExitCondition exit) + { + MessagePassingQueueUtil.drain(this, c, w, exit); + } + + @Override + public void fill(Supplier s, WaitStrategy wait, 
ExitCondition exit) + { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscBlockingConsumerArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscBlockingConsumerArrayQueue.java new file mode 100644 index 0000000..e3799fd --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpscBlockingConsumerArrayQueue.java @@ -0,0 +1,828 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import java.util.AbstractQueue; +import java.util.Collection; +import java.util.Iterator; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.locks.LockSupport; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; + +import static org.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +@SuppressWarnings("unused") +abstract class MpscBlockingConsumerArrayQueuePad1 extends AbstractQueue implements IndexedQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte 
b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +} +// $gen:ordered-fields +abstract class MpscBlockingConsumerArrayQueueColdProducerFields extends MpscBlockingConsumerArrayQueuePad1 +{ + private final static long P_LIMIT_OFFSET = fieldOffset(MpscBlockingConsumerArrayQueueColdProducerFields.class,"producerLimit"); + + private volatile long producerLimit; + protected final long producerMask; + protected final E[] producerBuffer; + + MpscBlockingConsumerArrayQueueColdProducerFields(long producerMask, E[] producerBuffer) + { + this.producerMask = producerMask; + this.producerBuffer = producerBuffer; + } + + final long lvProducerLimit() + { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, P_LIMIT_OFFSET, expect, newValue); + } + + final void soProducerLimit(long newValue) + { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + +@SuppressWarnings("unused") +abstract class MpscBlockingConsumerArrayQueuePad2 extends MpscBlockingConsumerArrayQueueColdProducerFields +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte 
b060,b061,b062,b063,b064,b065,b066,b067;// 56b + // byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + + MpscBlockingConsumerArrayQueuePad2(long mask, E[] buffer) + { + super(mask, buffer); + } +} + +// $gen:ordered-fields +abstract class MpscBlockingConsumerArrayQueueProducerFields extends MpscBlockingConsumerArrayQueuePad2 +{ + private final static long P_INDEX_OFFSET = fieldOffset(MpscBlockingConsumerArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + MpscBlockingConsumerArrayQueueProducerFields(long mask, E[] buffer) + { + super(mask, buffer); + } + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final void soProducerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +@SuppressWarnings("unused") +abstract class MpscBlockingConsumerArrayQueuePad3 extends MpscBlockingConsumerArrayQueueProducerFields +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + MpscBlockingConsumerArrayQueuePad3(long mask, 
E[] buffer) + { + super(mask, buffer); + } +} + +// $gen:ordered-fields +abstract class MpscBlockingConsumerArrayQueueConsumerFields extends MpscBlockingConsumerArrayQueuePad3 +{ + private final static long C_INDEX_OFFSET = fieldOffset(MpscBlockingConsumerArrayQueueConsumerFields.class,"consumerIndex"); + private final static long BLOCKED_OFFSET = fieldOffset(MpscBlockingConsumerArrayQueueConsumerFields.class,"blocked"); + + private volatile long consumerIndex; + protected final long consumerMask; + private volatile Thread blocked; + protected final E[] consumerBuffer; + + MpscBlockingConsumerArrayQueueConsumerFields(long mask, E[] buffer) + { + super(mask, buffer); + consumerMask = mask; + consumerBuffer = buffer; + } + + @Override + public final long lvConsumerIndex() + { + return consumerIndex; + } + + final long lpConsumerIndex() + { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } + + final Thread lvBlocked() + { + return blocked; + } + + /** + * This field should only be written to from the consumer thread. It is set before parking the consumer and nulled + * when the consumer is unblocked. The value is read by producer thread to unpark the consumer. + * + * @param thread the consumer thread which is blocked waiting for the producers + */ + final void soBlocked(Thread thread) + { + UNSAFE.putOrderedObject(this, BLOCKED_OFFSET, thread); + } +} + + + +/** + * This is a partial implementation of the {@link java.util.concurrent.BlockingQueue} on the consumer side only on top + * of the mechanics described in {@link BaseMpscLinkedArrayQueue}, but with the reservation bit used for blocking rather + * than resizing in this instance. 
+ */ +@SuppressWarnings("unused") +public class MpscBlockingConsumerArrayQueue extends MpscBlockingConsumerArrayQueueConsumerFields + implements MessagePassingQueue, QueueProgressIndicators, BlockingQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + public MpscBlockingConsumerArrayQueue(final int capacity) + { + // leave lower bit of mask clear + super((Pow2.roundToPowerOfTwo(capacity) - 1) << 1, (E[]) allocateRefArray(Pow2.roundToPowerOfTwo(capacity))); + + RangeUtil.checkGreaterThanOrEqual(capacity, 1, "capacity"); + soProducerLimit((Pow2.roundToPowerOfTwo(capacity) - 1) << 1); // we know it's all empty to start with + } + + @Override + public final Iterator iterator() + { + throw new UnsupportedOperationException(); + } + + @Override + public final int size() + { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.IGNORE_PARITY_DIVISOR); + } + + @Override + public final boolean isEmpty() + { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. 
Note that as this is an MPMC there is + // nothing we can do to make this an exact method. + return ((this.lvConsumerIndex()/2) == (this.lvProducerIndex()/2)); + } + + @Override + public String toString() + { + return this.getClass().getName(); + } + + /** + * {@link #offer} if {@link #size()} is less than threshold. + * + * @param e the object to offer onto the queue, not null + * @param threshold the maximum allowable size + * @return true if the offer is successful, false if queue size exceeds threshold + * @since 3.0.1 + */ + public boolean offerIfBelowThreshold(final E e, int threshold) + { + if (null == e) + { + throw new NullPointerException(); + } + + final long mask = this.producerMask; + final long capacity = mask + 2; + threshold = threshold << 1; + final E[] buffer = this.producerBuffer; + long pIndex; + while (true) + { + pIndex = lvProducerIndex(); + // lower bit is indicative of blocked consumer + if ((pIndex & 1) == 1) + { + if (offerAndWakeup(buffer, mask, pIndex, e)) { + return true; + } + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1), consumer is awake + final long producerLimit = lvProducerLimit(); + + // Use producer limit to save a read of the more rapidly mutated consumer index. 
+ // Assumption: queue is usually empty or near empty + + // available is also << 1 + final long available = producerLimit - pIndex; + // sizeEstimate <= size + final long sizeEstimate = capacity - available; + + if (sizeEstimate >= threshold || + // producerLimit check allows for threshold >= capacity + producerLimit <= pIndex) + { + if (!recalculateProducerLimit(pIndex, producerLimit, lvConsumerIndex(), capacity, threshold)) + { + return false; + } + } + + // Claim the index + if (casProducerIndex(pIndex, pIndex + 2)) + { + break; + } + } + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + // INDEX visible before ELEMENT + soRefElement(buffer, offset, e); // release element e + return true; + } + + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + + final long mask = this.producerMask; + final E[] buffer = this.producerBuffer; + long pIndex; + while (true) + { + pIndex = lvProducerIndex(); + // lower bit is indicative of blocked consumer + if ((pIndex & 1) == 1) + { + if (offerAndWakeup(buffer, mask, pIndex, e)) + return true; + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1), consumer is awake + final long producerLimit = lvProducerLimit(); + + // Use producer limit to save a read of the more rapidly mutated consumer index. 
+ // Assumption: queue is usually empty or near empty + if (producerLimit <= pIndex) + { + if (!recalculateProducerLimit(mask, pIndex, producerLimit)) + { + return false; + } + } + + // Claim the index + if (casProducerIndex(pIndex, pIndex + 2)) + { + break; + } + } + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + // INDEX visible before ELEMENT + soRefElement(buffer, offset, e); // release element e + return true; + } + + @Override + public void put(E e) throws InterruptedException + { + if (!offer(e)) + throw new UnsupportedOperationException(); + } + + @Override + public boolean offer(E e, long timeout, TimeUnit unit) throws InterruptedException + { + if (offer(e)) + return true; + throw new UnsupportedOperationException(); + } + + private boolean offerAndWakeup(E[] buffer, long mask, long pIndex, E e) + { + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + final Thread consumerThread = lvBlocked(); + + // We could see a null here through a race with the consumer not yet storing the reference. Just retry. 
+ if (consumerThread == null) + { + return false; + } + + // Claim the slot and the responsibility of unparking + if(!casProducerIndex(pIndex, pIndex + 1)) + { + return false; + } + + soRefElement(buffer, offset, e); + LockSupport.unpark(consumerThread); + return true; + } + + private boolean recalculateProducerLimit(long mask, long pIndex, long producerLimit) + { + return recalculateProducerLimit(pIndex, producerLimit, lvConsumerIndex(), mask + 2, mask + 2); + } + + private boolean recalculateProducerLimit(long pIndex, long producerLimit, long cIndex, long bufferCapacity, long threshold) + { + // try to update the limit with our new found knowledge on cIndex + if (cIndex + bufferCapacity > pIndex) + { + casProducerLimit(producerLimit, cIndex + bufferCapacity); + } + // full and cannot grow, or hit threshold + long size = pIndex - cIndex; + return size < threshold && size < bufferCapacity; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @Override + public E take() throws InterruptedException + { + final E[] buffer = consumerBuffer; + final long mask = consumerMask; + + final long cIndex = lpConsumerIndex(); + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + E e = lvRefElement(buffer, offset); + if (e == null) + { + return parkUntilNext(buffer, cIndex, offset, Long.MAX_VALUE); + } + + soRefElement(buffer, offset, null); // release element null + soConsumerIndex(cIndex + 2); // release cIndex + + return e; + } + + /** + * {@inheritDoc} + *

 * This implementation is correct for single consumer thread use only.
 */
@Override
public E poll(long timeout, TimeUnit unit) throws InterruptedException
{
    final E[] buffer = consumerBuffer;
    final long mask = consumerMask;

    final long cIndex = lpConsumerIndex();
    final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask);
    E e = lvRefElement(buffer, offset);
    if (e == null)
    {
        long timeoutNs = unit.toNanos(timeout);
        if (timeoutNs <= 0)
        {
            // Non-positive timeout behaves like an immediate poll that found nothing.
            return null;
        }
        return parkUntilNext(buffer, cIndex, offset, timeoutNs);
    }

    soRefElement(buffer, offset, null); // release element null
    soConsumerIndex(cIndex + 2); // release cIndex

    return e;
}

/**
 * Blocks the (single) consumer until an element becomes visible at {@code offset}, the
 * timeout elapses, or the thread is interrupted.
 * <p>
 * Protocol: if the queue is observed empty, the consumer CAS-es the producer index from even
 * to odd ("parked" parity) and publishes its thread so producers can unpark it.
 *
 * @param timeoutNs nanoseconds to wait; {@code Long.MAX_VALUE} means wait forever (no deadline)
 * @return the next element, or {@code null} if the timeout elapsed
 * @throws InterruptedException if interrupted while parked (parked parity is rolled back first)
 */
private E parkUntilNext(E[] buffer, long cIndex, long offset, long timeoutNs) throws InterruptedException {
    E e;
    final long pIndex = lvProducerIndex();
    if (cIndex == pIndex && // queue is empty
        casProducerIndex(pIndex, pIndex + 1)) // we announce ourselves as parked by setting parity
    {
        // producers only try a wakeup when both the index and the blocked thread are visible, otherwise they spin
        soBlocked(Thread.currentThread());
        // ignore deadline when it's forever
        final long deadlineNs = timeoutNs == Long.MAX_VALUE ? 0 : System.nanoTime() + timeoutNs;

        try
        {
            while (true)
            {
                LockSupport.parkNanos(this, timeoutNs);
                if (Thread.interrupted())
                {
                    // Roll back the parked-parity announcement before propagating interruption.
                    casProducerIndex(pIndex + 1, pIndex);
                    throw new InterruptedException();
                }
                if ((lvProducerIndex() & 1) == 0) {
                    // A producer cleared the parity bit: an element is (or is about to be) visible.
                    break;
                }
                // ignore deadline when it's forever
                timeoutNs = timeoutNs == Long.MAX_VALUE ?
                    Long.MAX_VALUE : deadlineNs - System.nanoTime();
                if (timeoutNs <= 0)
                {
                    if (casProducerIndex(pIndex + 1, pIndex))
                    {
                        // ran out of time and the producer has not moved the index
                        return null;
                    }

                    break; // just in the nick of time
                }
            }
        }
        finally
        {
            soBlocked(null);
        }
    }
    // producer index is visible before element, so if we wake up between the index moving and the element
    // store we could see a null.
    e = spinWaitForElement(buffer, offset);

    soRefElement(buffer, offset, null); // release element null
    soConsumerIndex(cIndex + 2); // release cIndex

    return e;
}

@Override
public int remainingCapacity()
{
    return capacity() - size();
}

// NOTE(review): BlockingQueue bulk transfer is deliberately unsupported here;
// callers are expected to use the drain(...) family instead.
@Override
public int drainTo(Collection c)
{
    throw new UnsupportedOperationException();
}

@Override
public int drainTo(Collection c, int maxElements)
{
    throw new UnsupportedOperationException();
}

/**
 * {@inheritDoc}
 * <p>
 * This implementation is correct for single consumer thread use only.
 */
@Override
public E poll()
{
    final E[] buffer = consumerBuffer;
    final long mask = consumerMask;

    final long index = lpConsumerIndex();
    final long offset = modifiedCalcCircularRefElementOffset(index, mask);
    E e = lvRefElement(buffer, offset);
    if (e == null)
    {
        // consumer can't see the odd producer index
        if (index != lvProducerIndex())
        {
            // poll() == null iff queue is empty, null element is not strong enough indicator, so we must
            // check the producer index. If the queue is indeed not empty we spin until element is
            // visible.
            e = spinWaitForElement(buffer, offset);
        }
        else
        {
            return null;
        }
    }

    soRefElement(buffer, offset, null); // release element null
    soConsumerIndex(index + 2); // release cIndex
    return e;
}

// Busy-spins until the producer's element store becomes visible at `offset`. Only safe to
// call when the producer index proves an element is in flight for this slot, otherwise this
// spins forever.
// NOTE(review): the generic type parameter appears stripped by text extraction here;
// the expected signature is `private static <E> E spinWaitForElement(E[] buffer, long offset)`.
private static E spinWaitForElement(E[] buffer, long offset)
{
    E e;
    do
    {
        e = lvRefElement(buffer, offset);
    }
    while (e == null);
    return e;
}

/**
 * {@inheritDoc}
 * <p>
 * This implementation is correct for single consumer thread use only.
 */
@Override
public E peek()
{
    final E[] buffer = consumerBuffer;
    final long mask = consumerMask;

    final long index = lpConsumerIndex();
    final long offset = modifiedCalcCircularRefElementOffset(index, mask);
    E e = lvRefElement(buffer, offset);
    if (e == null && index != lvProducerIndex())
    {
        // peek() == null iff queue is empty, null element is not strong enough indicator, so we must
        // check the producer index. If the queue is indeed not empty we spin until element is visible.
        e = spinWaitForElement(buffer, offset);
    }

    return e;
}

@Override
public long currentProducerIndex()
{
    // Raw indices advance in steps of 2 (bit 0 is the parked-consumer parity flag),
    // so the logical element index is the raw index halved.
    return lvProducerIndex() / 2;
}

@Override
public long currentConsumerIndex()
{
    return lvConsumerIndex() / 2;
}

@Override
public int capacity()
{
    // mask covers (capacity * 2 - 2) raw-index units; undo the doubling.
    return (int) ((consumerMask + 2) >> 1);
}

@Override
public boolean relaxedOffer(E e)
{
    return offer(e);
}

@Override
public E relaxedPoll()
{
    final E[] buffer = consumerBuffer;
    final long index = lpConsumerIndex();
    final long mask = consumerMask;

    final long offset = modifiedCalcCircularRefElementOffset(index, mask);
    E e = lvRefElement(buffer, offset);
    if (e == null)
    {
        // Relaxed semantics: a null slot is reported as empty without consulting the
        // producer index (no spin for in-flight elements).
        return null;
    }
    soRefElement(buffer, offset, null);
    soConsumerIndex(index + 2);
    return e;
}

@Override
public E relaxedPeek()
{
    final E[] buffer = consumerBuffer;
    final long index = lpConsumerIndex();
    final long mask = consumerMask;

    final long offset = modifiedCalcCircularRefElementOffset(index, mask);
    return lvRefElement(buffer, offset);
}

@Override
public int fill(Supplier s, int limit)
{
    if (null == s)
        throw new IllegalArgumentException("supplier is null");
    if (limit < 0)
        throw new IllegalArgumentException("limit is negative:" + limit);
    if (limit == 0)
        return 0;

    final long mask = this.producerMask;

    long pIndex;
    int claimedSlots;
    Thread blockedConsumer = null;
    long batchLimit = 0;
    // limit elements == 2 * limit raw-index units.
    final long shiftedBatchSize = 2L * limit;

    while (true)
    {
        pIndex = lvProducerIndex();
        long producerLimit = lvProducerLimit();

        // lower bit is indicative of blocked consumer
        if ((pIndex & 1) == 1)
        {
            // observe the blocked thread for the pIndex
            blockedConsumer = lvBlocked();
            if (blockedConsumer == null)
                continue;// racing, retry
            if(!casProducerIndex(pIndex, pIndex + 1))
            {
                blockedConsumer = null;
                continue;
            }
            // We have observed the blocked thread for the pIndex(lv index, lv thread, cas index).
            // We've claimed pIndex, now we need to wake up consumer and set the element
            batchLimit = pIndex + 1;
            pIndex = pIndex - 1;
            break;
        }
        // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1), consumer is awake

        // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit'
        batchLimit = Math.min(producerLimit, pIndex + shiftedBatchSize); // -> producerLimit >= batchLimit

        // Use producer limit to save a read of the more rapidly mutated consumer index.
        // Assumption: queue is usually empty or near empty
        if (pIndex >= producerLimit)
        {
            if (!recalculateProducerLimit(mask, pIndex, producerLimit))
            {
                return 0;
            }
            batchLimit = Math.min(lvProducerLimit(), pIndex + shiftedBatchSize);
        }

        // Claim the index
        if (casProducerIndex(pIndex, batchLimit))
        {
            break;
        }
    }
    claimedSlots = (int) ((batchLimit - pIndex) / 2);

    final E[] buffer = this.producerBuffer;
    // first element offset might be a wakeup, so peeled from loop
    // NOTE(review): no peeled iteration is visible here — confirm the comment still matches the code.
    for (int i = 0; i < claimedSlots; i++)
    {
        long offset = modifiedCalcCircularRefElementOffset(pIndex + 2L * i, mask);
        soRefElement(buffer, offset, s.get());
    }

    if (blockedConsumer != null)
    {
        // no point unblocking an unrelated blocked thread, things have obviously moved on
        if (lvBlocked() == blockedConsumer) {
            LockSupport.unpark(blockedConsumer);
        }
    }

    return claimedSlots;
}

/**
 * Remove up to limit elements from the queue and hand to consume, waiting up to the specified wait time if
 * necessary for an element to become available.
 * <p>
 * There's no strong commitment to the queue being empty at the end of it.
 * This implementation is correct for single consumer thread use only.
 * <p>
 * WARNING: Explicit assumptions are made with regards to {@link Consumer#accept} make sure you have read
 * and understood these before using this method.
 *
 * @return the number of polled elements
 * @throws InterruptedException if interrupted while waiting
 * @throws IllegalArgumentException c is {@code null}
 * @throws IllegalArgumentException if limit is negative
 */
public int drain(Consumer c, final int limit, long timeout, TimeUnit unit) throws InterruptedException {
    if (limit == 0) {
        return 0;
    }
    // Fast path: sweep whatever is immediately available without blocking.
    final int drained = drain(c, limit);
    if (drained != 0) {
        return drained;
    }
    // Nothing available: block (up to the timeout) for a single element, then sweep again
    // non-blockingly for the remainder of the limit.
    final E e = poll(timeout, unit);
    if (e == null)
        return 0;
    c.accept(e);
    return 1 + drain(c, limit - 1);
}

@Override
public int fill(Supplier s)
{
    // Bounded fill: delegates to the shared utility, capped by this queue's capacity.
    return MessagePassingQueueUtil.fillBounded(this, s);
}

@Override
public void fill(Supplier s, WaitStrategy wait, ExitCondition exit)
{
    MessagePassingQueueUtil.fill(this, s, wait, exit);
}

@Override
public int drain(Consumer c)
{
    return drain(c, capacity());
}

@Override
public int drain(final Consumer c, final int limit)
{
    return MessagePassingQueueUtil.drain(this, c, limit);
}

@Override
public void drain(Consumer c, WaitStrategy w, ExitCondition exit)
{
    MessagePassingQueueUtil.drain(this, c, w, exit);
}
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscChunkedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscChunkedArrayQueue.java
new file mode 100644
index 0000000..bfbfb8f
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpscChunkedArrayQueue.java
@@ -0,0 +1,102 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues;

import org.jctools.util.Pow2;
import org.jctools.util.RangeUtil;

import static java.lang.Math.max;
import static java.lang.Math.min;
import static org.jctools.queues.LinkedArrayQueueUtil.length;
import static org.jctools.util.Pow2.roundToPowerOfTwo;

// Holds the rarely-mutated ("cold") producer field separately from the hot fields of the parent.
abstract class MpscChunkedArrayQueueColdProducerFields extends BaseMpscLinkedArrayQueue
{
    // Maximum capacity expressed in raw index units: stored doubled (<< 1) to match the
    // parent's index granularity (see capacity(), which halves it back).
    protected final long maxQueueCapacity;

    MpscChunkedArrayQueueColdProducerFields(int initialCapacity, int maxCapacity)
    {
        super(initialCapacity);
        RangeUtil.checkGreaterThanOrEqual(maxCapacity, 4, "maxCapacity");
        // maxCapacity must round to a strictly larger power of 2 than initialCapacity.
        RangeUtil.checkLessThan(roundToPowerOfTwo(initialCapacity), roundToPowerOfTwo(maxCapacity),
            "initialCapacity");
        maxQueueCapacity = ((long) Pow2.roundToPowerOfTwo(maxCapacity)) << 1;
    }
}

/**
 * An MPSC array queue which starts at <i>initialCapacity</i> and grows to <i>maxCapacity</i> in linked chunks
 * of the initial size. The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 */
public class MpscChunkedArrayQueue extends MpscChunkedArrayQueueColdProducerFields
{
    // 128 bytes of byte-field padding to keep this class's (inherited) fields from
    // sharing cache lines with neighbouring objects.
    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b

    public MpscChunkedArrayQueue(int maxCapacity)
    {
        // Default chunk size: maxCapacity / 8, clamped into [2, 1024], rounded to a power of 2.
        super(max(2, min(1024, roundToPowerOfTwo(maxCapacity / 8))), maxCapacity);
    }

    /**
     * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size.
     *                        Must be 2 or more.
     * @param maxCapacity     the maximum capacity will be rounded up to the closest power of 2 and will be the
     *                        upper limit of number of elements in this queue. Must be 4 or more and round up to a larger
     *                        power of 2 than initialCapacity.
     */
    public MpscChunkedArrayQueue(int initialCapacity, int maxCapacity)
    {
        super(initialCapacity, maxCapacity);
    }

    @Override
    protected long availableInQueue(long pIndex, long cIndex)
    {
        // Raw-index units remaining before the bounded capacity is reached.
        return maxQueueCapacity - (pIndex - cIndex);
    }

    @Override
    public int capacity()
    {
        return (int) (maxQueueCapacity / 2);
    }

    @Override
    protected int getNextBufferSize(E[] buffer)
    {
        // Fixed chunk size: every new chunk matches the current buffer's length.
        return length(buffer);
    }

    @Override
    protected long getCurrentBufferCapacity(long mask)
    {
        return mask;
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscCompoundQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscCompoundQueue.java
new file mode 100644
index 0000000..b34b5a3
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpscCompoundQueue.java
@@ -0,0 +1,374 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues;

import org.jctools.util.RangeUtil;

import java.util.AbstractQueue;
import java.util.Iterator;

import static org.jctools.util.PortableJvmInfo.CPUs;
import static org.jctools.util.Pow2.isPowerOfTwo;
import static org.jctools.util.Pow2.roundToPowerOfTwo;

/**
 * Use a set number of parallel MPSC queues to diffuse the contention on tail.
 */
abstract class MpscCompoundQueueL0Pad extends AbstractQueue implements MessagePassingQueue
{
    // 128 bytes of byte-field padding (cache-line isolation for the fields that follow).
    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
}

abstract class MpscCompoundQueueColdFields extends MpscCompoundQueueL0Pad
{
    // must be power of 2
    protected final int parallelQueues;
    protected final int parallelQueuesMask;
    protected final MpscArrayQueue[] queues;

    @SuppressWarnings("unchecked")
    MpscCompoundQueueColdFields(int capacity, int queueParallelism)
    {
        // Round the lane count DOWN to a power of 2: roundToPowerOfTwo rounds up, so the
        // non-power-of-2 case halves the rounded-up value.
        parallelQueues = isPowerOfTwo(queueParallelism) ?
            queueParallelism
            : roundToPowerOfTwo(queueParallelism) / 2;
        parallelQueuesMask = parallelQueues - 1;
        queues = new MpscArrayQueue[parallelQueues];
        int fullCapacity = roundToPowerOfTwo(capacity);
        // Each lane must get at least one slot.
        RangeUtil.checkGreaterThanOrEqual(fullCapacity, parallelQueues, "fullCapacity");
        for (int i = 0; i < parallelQueues; i++)
        {
            queues[i] = new MpscArrayQueue(fullCapacity / parallelQueues);
        }
    }
}

abstract class MpscCompoundQueueMidPad extends MpscCompoundQueueColdFields
{
    // 128 bytes of byte-field padding between the cold fields and the consumer cursor.
    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b

    public MpscCompoundQueueMidPad(int capacity, int queueParallelism)
    {
        super(capacity, queueParallelism);
    }
}

abstract class MpscCompoundQueueConsumerQueueIndex extends MpscCompoundQueueMidPad
{
    // Plain field: the lane the single consumer last drew from (single-consumer only).
    int consumerQueueIndex;

    MpscCompoundQueueConsumerQueueIndex(int capacity, int queueParallelism)
    {
        super(capacity, queueParallelism);
    }
}

public class MpscCompoundQueue extends MpscCompoundQueueConsumerQueueIndex
{
    // Trailing 128 bytes of byte-field padding.
    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b

    public MpscCompoundQueue(int capacity)
    {
        this(capacity, CPUs);
    }

    public MpscCompoundQueue(int capacity, int queueParallelism)
    {
        super(capacity, queueParallelism);
    }

    @Override
    public boolean offer(final E e)
    {
        if (null == e)
        {
            throw new NullPointerException();
        }
        final int parallelQueuesMask = this.parallelQueuesMask;
        // Spread producers across lanes by thread id to diffuse tail contention.
        int start = (int) (Thread.currentThread().getId() & parallelQueuesMask);
        final MpscArrayQueue[] queues = this.queues;
        if (queues[start].offer(e))
        {
            return true;
        }
        else
        {
            return slowOffer(queues, parallelQueuesMask, start + 1, e);
        }
    }

    // Contended path: cycle through all lanes until one accepts the element, or every lane
    // reports full in the same pass.
    private boolean slowOffer(MpscArrayQueue[] queues, int parallelQueuesMask, int start, E e)
    {
        final int queueCount = parallelQueuesMask + 1;
        final int end = start + queueCount;
        while (true)
        {
            int status = 0;
            for (int i = start; i < end; i++)
            {
                int s = queues[i & parallelQueuesMask].failFastOffer(e);
                if (s == 0)
                {
                    return true;
                }
                status += s;
            }
            // NOTE(review): assumes failFastOffer returns 1 for "full"; a sum equal to the lane
            // count therefore means every lane was full — confirm against MpscArrayQueue.
            if (status == queueCount)
            {
                return false;
            }
        }
    }

    @Override
    public E poll()
    {
        // Round-robin over the lanes, resuming from where the consumer last stopped.
        int qIndex = consumerQueueIndex & parallelQueuesMask;
        int limit = qIndex + parallelQueues;
        E e = null;
        for (; qIndex < limit; qIndex++)
        {
            e = queues[qIndex & parallelQueuesMask].poll();
            if (e != null)
            {
                break;
            }
        }
        consumerQueueIndex = qIndex;
        return e;
    }

    @Override
    public E peek()
    {
        int qIndex = consumerQueueIndex & parallelQueuesMask;
        int limit = qIndex + parallelQueues;
        E e = null;
        for (; qIndex < limit; qIndex++)
        {
            e = queues[qIndex & parallelQueuesMask].peek();
            if (e != null)
            {
                break;
            }
        }
        consumerQueueIndex = qIndex;
        return e;
    }

    @Override
    public int size()
    {
        // Sum of lane sizes; each lane size is itself a point-in-time estimate.
        int size = 0;
        for (MpscArrayQueue lane : queues)
        {
            size += lane.size();
        }
        return size;
    }

    @Override
    public Iterator iterator()
    {
        throw new UnsupportedOperationException();
    }

    @Override
    public String toString()
    {
        return this.getClass().getName();
    }

    @Override
    public boolean relaxedOffer(E e)
    {
        if (null == e)
        {
            throw new NullPointerException();
        }
        final int parallelQueuesMask = this.parallelQueuesMask;
        int start = (int) (Thread.currentThread().getId() & parallelQueuesMask);
        final MpscArrayQueue[] queues = this.queues;
        if (queues[start].failFastOffer(e) == 0)
        {
            return true;
        }
        else
        {
            // we already offered to first queue, try the rest
            for (int i = start + 1; i < start + parallelQueuesMask + 1; i++)
            {
                if (queues[i & parallelQueuesMask].failFastOffer(e) == 0)
                {
                    return true;
                }
            }
            // this is a relaxed offer, we can fail for any reason we like
            return false;
        }
    }

    @Override
    public E relaxedPoll()
    {
        int qIndex = consumerQueueIndex & parallelQueuesMask;
        int limit = qIndex + parallelQueues;
        E e = null;
        for (; qIndex < limit; qIndex++)
        {
            e = queues[qIndex & parallelQueuesMask].relaxedPoll();
            if (e != null)
            {
                break;
            }
        }
        consumerQueueIndex = qIndex;
        return e;
    }

    @Override
    public E relaxedPeek()
    {
        int qIndex = consumerQueueIndex & parallelQueuesMask;
        int limit = qIndex + parallelQueues;
        E e = null;
        for (; qIndex < limit; qIndex++)
        {
            e = queues[qIndex & parallelQueuesMask].relaxedPeek();
            if (e != null)
            {
                break;
            }
        }
        consumerQueueIndex = qIndex;
        return e;
    }

    @Override
    public int capacity()
    {
        return queues.length * queues[0].capacity();
    }


    @Override
    public int drain(Consumer c)
    {
        final int limit = capacity();
        return drain(c, limit);
    }

    @Override
    public int fill(Supplier s)
    {

        return MessagePassingQueueUtil.fillBounded(this, s);
    }

    @Override
    public int drain(Consumer c, int limit)
    {
        return MessagePassingQueueUtil.drain(this, c, limit);
    }

    @Override
    public int fill(Supplier s, int limit)
    {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;

        final int parallelQueuesMask = this.parallelQueuesMask;
        int start = (int) (Thread.currentThread().getId() & parallelQueuesMask);
        final MpscArrayQueue[] queues = this.queues;
        // Fill the thread's home lane first, then spill the remainder into the other lanes.
        int filled = queues[start].fill(s, limit);
        if (filled == limit)
        {
            return limit;
        }
        else
        {
            // we already offered to first queue, try the rest
            for (int i = start + 1; i < start + parallelQueuesMask + 1; i++)
            {
                filled += queues[i & parallelQueuesMask].fill(s, limit - filled);
                if (filled == limit)
                {
                    return limit;
                }
            }
            // this is a relaxed offer, we can fail for any reason we like
            return filled;
        }
    }

    @Override
    public void drain(Consumer c, WaitStrategy wait, ExitCondition exit)
    {
        MessagePassingQueueUtil.drain(this, c, wait, exit);
    }

    @Override
    public void fill(Supplier s, WaitStrategy wait, ExitCondition exit)
    {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscGrowableArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscGrowableArrayQueue.java
new file mode 100644
index 0000000..14c1825
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpscGrowableArrayQueue.java
@@ -0,0 +1,63 @@
/*
 * Licensed under the Apache
License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; + +import static org.jctools.queues.LinkedArrayQueueUtil.length; + + +/** + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks, + * doubling theirs size every time until the full blown backing array is used. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. + */ +public class MpscGrowableArrayQueue extends MpscChunkedArrayQueue +{ + + public MpscGrowableArrayQueue(int maxCapacity) + { + super(Math.max(2, Pow2.roundToPowerOfTwo(maxCapacity / 8)), maxCapacity); + } + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. + * Must be 2 or more. + * @param maxCapacity the maximum capacity will be rounded up to the closest power of 2 and will be the + * upper limit of number of elements in this queue. Must be 4 or more and round up to a larger + * power of 2 than initialCapacity. 
+ */ + public MpscGrowableArrayQueue(int initialCapacity, int maxCapacity) + { + super(initialCapacity, maxCapacity); + } + + + @Override + protected int getNextBufferSize(E[] buffer) + { + final long maxSize = maxQueueCapacity / 2; + RangeUtil.checkLessThanOrEqual(length(buffer), maxSize, "buffer.length"); + final int newSize = 2 * (length(buffer) - 1); + return newSize + 1; + } + + @Override + protected long getCurrentBufferCapacity(long mask) + { + return (mask + 2 == maxQueueCapacity) ? maxQueueCapacity : mask; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscLinkedQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscLinkedQueue.java new file mode 100644 index 0000000..a12957b --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpscLinkedQueue.java @@ -0,0 +1,196 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.UnsafeAccess; + +import java.util.Queue; + +import static org.jctools.util.UnsafeAccess.UNSAFE; + +/** + * This is a Java port of the MPSC algorithm as presented + * on + * 1024 Cores by D. Vyukov. The original has been adapted to Java and it's quirks with regards to memory + * model and layout: + *
 * <ol>
 * <li>Use inheritance to ensure no false sharing occurs between producer/consumer node reference fields.</li>
 * <li>Use XCHG functionality to the best of the JDK ability (see differences in JDK7/8 impls).</li>
 * <li>Conform to {@link java.util.Queue} contract on poll. The original semantics are available via relaxedPoll.</li>
 * </ol>
 * The queue is initialized with a stub node which is set to both the producer and consumer node references.
 * From this point follow the notes on offer/poll.
 */
public class MpscLinkedQueue extends BaseLinkedQueue
{
    public MpscLinkedQueue()
    {
        // Stub node: both producer and consumer start pointing at it.
        LinkedQueueNode node = newNode();
        spConsumerNode(node);
        xchgProducerNode(node);
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Offer is allowed from multiple threads.<br>
     * Offer allocates a new node and:
     * <ol>
     * <li>Swaps it atomically with current producer node (only one producer 'wins')</li>
     * <li>Sets the new node as the node following from the swapped producer node</li>
     * </ol>
     * This works because each producer is guaranteed to 'plant' a new node and link the old node. No 2
     * producers can get the same producer node as part of XCHG guarantee.
     *
     * @see MessagePassingQueue#offer(Object)
     * @see java.util.Queue#offer(java.lang.Object)
     */
    @Override
    public boolean offer(final E e)
    {
        if (null == e)
        {
            throw new NullPointerException();
        }
        final LinkedQueueNode nextNode = newNode(e);
        final LinkedQueueNode prevProducerNode = xchgProducerNode(nextNode);
        // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed
        // and completes the store in prev.next. This is a "bubble".
        prevProducerNode.soNext(nextNode);
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * This method is only safe to call from the (single) consumer thread, and is subject to best effort when racing
     * with producers. This method is potentially blocking when "bubble"s in the queue are visible.
     */
    @Override
    public boolean remove(Object o)
    {
        if (null == o)
        {
            return false; // Null elements are not permitted, so null will never be removed.
        }

        final LinkedQueueNode originalConsumerNode = lpConsumerNode();
        LinkedQueueNode prevConsumerNode = originalConsumerNode;
        LinkedQueueNode currConsumerNode = getNextConsumerNode(originalConsumerNode);
        while (currConsumerNode != null)
        {
            if (o.equals(currConsumerNode.lpValue()))
            {
                LinkedQueueNode nextNode = getNextConsumerNode(currConsumerNode);
                // e.g.: consumerNode -> node0 -> node1(o==v) -> node2 ... => consumerNode -> node0 -> node2
                if (nextNode != null)
                {
                    // We are removing an interior node.
                    prevConsumerNode.soNext(nextNode);
                }
                // This case reflects: prevConsumerNode != originalConsumerNode && nextNode == null
                // At rest, this would be the producerNode, but we must contend with racing. Changes to subclassed
                // queues need to consider remove() when implementing offer().
                else
                {
                    // producerNode is currConsumerNode, try to atomically update the reference to move it to the
                    // previous node.
                    prevConsumerNode.soNext(null);
                    if (!casProducerNode(currConsumerNode, prevConsumerNode))
                    {
                        // If the producer(s) have offered more items we need to remove the currConsumerNode link.
                        nextNode = spinWaitForNextNode(currConsumerNode);
                        prevConsumerNode.soNext(nextNode);
                    }
                }

                // Avoid GC nepotism because we are discarding the current node.
                currConsumerNode.soNext(null);
                currConsumerNode.spValue(null);

                return true;
            }
            prevConsumerNode = currConsumerNode;
            currConsumerNode = getNextConsumerNode(currConsumerNode);
        }
        return false;
    }

    @Override
    public int fill(Supplier s)
    {
        return MessagePassingQueueUtil.fillUnbounded(this, s);
    }

    @Override
    public int fill(Supplier s, int limit)
    {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;

        // Build a private chain of `limit` nodes first, then publish the whole chain with a
        // single XCHG + link, so producers contend only once per batch.
        LinkedQueueNode tail = newNode(s.get());
        final LinkedQueueNode head = tail;
        for (int i = 1; i < limit; i++)
        {
            final LinkedQueueNode temp = newNode(s.get());
            // spNext: xchgProducerNode ensures correct construction
            tail.spNext(temp);
            tail = temp;
        }
        final LinkedQueueNode oldPNode = xchgProducerNode(tail);
        oldPNode.soNext(head);
        return limit;
    }

    @Override
    public void fill(Supplier s, WaitStrategy wait, ExitCondition exit)
    {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }

    // $gen:ignore
    // Atomically swaps the producer node reference: uses getAndSet where the JVM supports it,
    // otherwise falls back to a CAS retry loop with identical semantics.
    private LinkedQueueNode xchgProducerNode(LinkedQueueNode newVal)
    {
        if (UnsafeAccess.SUPPORTS_GET_AND_SET_REF)
        {
            return (LinkedQueueNode) UNSAFE.getAndSetObject(this, P_NODE_OFFSET, newVal);
        }
        else
        {
            LinkedQueueNode oldVal;
            do
            {
                oldVal = lvProducerNode();
            }
            while (!UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, oldVal, newVal));
            return oldVal;
        }
    }

    // Follows the next pointer; if it is null while the producer node proves more elements
    // exist, a producer "bubble" is in flight and we spin until the link becomes visible.
    private LinkedQueueNode getNextConsumerNode(LinkedQueueNode currConsumerNode)
    {
        LinkedQueueNode nextNode = currConsumerNode.lvNext();
        if (nextNode == null && currConsumerNode != lvProducerNode())
        {
            nextNode = spinWaitForNextNode(currConsumerNode);
        }
        return nextNode;
    }

}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedArrayQueue.java
new file mode 100644
index 0000000..99dd4e7
---
/dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedArrayQueue.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import static org.jctools.queues.LinkedArrayQueueUtil.length; + +/** + * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked chunks of the initial size. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. 
 */
public class MpscUnboundedArrayQueue<E> extends BaseMpscLinkedArrayQueue<E>
{
    // 128 bytes of padding to keep this class's hot fields (inherited producer/consumer
    // indices) from false-sharing a cache line with neighbouring objects on the heap.
    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b

    /**
     * @param chunkSize size of each linked chunk; the queue grows by appending chunks of this size
     */
    public MpscUnboundedArrayQueue(int chunkSize)
    {
        super(chunkSize);
    }


    @Override
    protected long availableInQueue(long pIndex, long cIndex)
    {
        // Unbounded: producers always see "room available", so offer never reports full.
        return Integer.MAX_VALUE;
    }

    @Override
    public int capacity()
    {
        return MessagePassingQueue.UNBOUNDED_CAPACITY;
    }

    @Override
    public int drain(Consumer<E> c)
    {
        // Bounded batch (4096) rather than "drain forever" so a busy producer cannot
        // keep the consumer trapped in this call indefinitely.
        return drain(c, 4096);
    }

    @Override
    public int fill(Supplier<E> s)
    {
        return MessagePassingQueueUtil.fillUnbounded(this, s);
    }

    @Override
    protected int getNextBufferSize(E[] buffer)
    {
        // Chunks never grow: the next chunk is the same size as the current one.
        return length(buffer);
    }

    @Override
    protected long getCurrentBufferCapacity(long mask)
    {
        return mask;
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddArrayQueue.java
new file mode 100644
index 0000000..9a6d80c
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddArrayQueue.java
@@ -0,0 +1,346 @@
/*
 * Licensed under the Apache License, Version 2.0 (the
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.PortableJvmInfo; + +/** + * An MPSC array queue which grows unbounded in linked chunks.
+ * Differently from {@link MpscUnboundedArrayQueue} it is designed to provide a better scaling when more + * producers are concurrently offering.
 * Users should be aware that {@link #poll()} could spin while awaiting a new element to be available:
 * to avoid this behaviour {@link #relaxedPoll()} should be used instead, accounting for the semantic differences
 * between the two.
 *
 * @author https://github.com/franz1981
 */
public class MpscUnboundedXaddArrayQueue<E> extends MpUnboundedXaddArrayQueue<MpscUnboundedXaddChunk<E>, E>
{
    /**
     * @param chunkSize The buffer size to be used in each chunk of this queue
     * @param maxPooledChunks The maximum number of reused chunks kept around to avoid allocation, chunks are pre-allocated
     */
    public MpscUnboundedXaddArrayQueue(int chunkSize, int maxPooledChunks)
    {
        super(chunkSize, maxPooledChunks);
    }

    /** Convenience constructor: pools up to 2 chunks for reuse. */
    public MpscUnboundedXaddArrayQueue(int chunkSize)
    {
        this(chunkSize, 2);
    }

    // Factory hook used by the base class whenever a new chunk must be appended.
    @Override
    final MpscUnboundedXaddChunk<E> newChunk(long index, MpscUnboundedXaddChunk<E> prev, int chunkSize, boolean pooled)
    {
        return new MpscUnboundedXaddChunk<E>(index, prev, chunkSize, pooled);
    }

    /**
     * Claims a slot via an atomic index increment (XADD — no CAS retry loop), locates
     * the chunk that owns that slot, and publishes the element with an ordered store.
     * Never fails: the queue is unbounded, so this always returns {@code true}.
     *
     * @throws NullPointerException if {@code e} is null
     */
    @Override
    public boolean offer(E e)
    {
        if (null == e)
        {
            throw new NullPointerException();
        }
        final int chunkMask = this.chunkMask;
        final int chunkShift = this.chunkShift;

        final long pIndex = getAndIncrementProducerIndex();

        final int piChunkOffset = (int) (pIndex & chunkMask);
        final long piChunkIndex = pIndex >> chunkShift;

        MpscUnboundedXaddChunk<E> pChunk = lvProducerChunk();
        if (pChunk.lvIndex() != piChunkIndex)
        {
            // Other producers may have advanced the producer chunk as we claimed a slot in a prev chunk, or we may have
            // now stepped into a brand new chunk which needs appending.
            pChunk = producerChunkForIndex(pChunk, piChunkIndex);
        }
        pChunk.soElement(piChunkOffset, e);
        return true;
    }

    // Advances the consumer to the next chunk (consumer-side only), releasing/recycling
    // the exhausted one. Returns null only when the queue is empty.
    private MpscUnboundedXaddChunk<E> pollNextBuffer(MpscUnboundedXaddChunk<E> cChunk, long cIndex)
    {
        final MpscUnboundedXaddChunk<E> next = spinForNextIfNotEmpty(cChunk, cIndex);

        if (next == null)
        {
            return null;
        }

        moveToNextConsumerChunk(cChunk, next);
        assert next.lvIndex() == cIndex >> chunkShift;
        return next;
    }

    // If the queue is non-empty but the next chunk is not linked yet, optionally helps
    // append it (only when no producer is ahead of us) and then spins until it appears.
    private MpscUnboundedXaddChunk<E> spinForNextIfNotEmpty(MpscUnboundedXaddChunk<E> cChunk, long cIndex)
    {
        MpscUnboundedXaddChunk<E> next = cChunk.lvNext();
        if (next == null)
        {
            if (lvProducerIndex() == cIndex)
            {
                // Empty: producer has not claimed past our position.
                return null;
            }
            final long ccChunkIndex = cChunk.lvIndex();
            if (lvProducerChunkIndex() == ccChunkIndex)
            {
                // no need to help too much here or the consumer latency will be hurt
                next = appendNextChunks(cChunk, ccChunkIndex, 1);
            }
            while (next == null)
            {
                next = cChunk.lvNext();
            }
        }
        return next;
    }

    /**
     * May spin waiting for an in-flight producer to publish its element (a slot can be
     * claimed before its element is visible). Returns null only when the queue is empty.
     */
    @Override
    public E poll()
    {
        final int chunkMask = this.chunkMask;
        final long cIndex = this.lpConsumerIndex();
        final int ciChunkOffset = (int) (cIndex & chunkMask);

        MpscUnboundedXaddChunk<E> cChunk = this.lvConsumerChunk();
        // start of new chunk?
        if (ciChunkOffset == 0 && cIndex != 0)
        {
            // pollNextBuffer will verify emptiness check
            cChunk = pollNextBuffer(cChunk, cIndex);
            if (cChunk == null)
            {
                return null;
            }
        }

        E e = cChunk.lvElement(ciChunkOffset);
        if (e == null)
        {
            if (lvProducerIndex() == cIndex)
            {
                return null;
            }
            else
            {
                // Slot claimed but element not yet visible: wait for the producer.
                e = cChunk.spinForElement(ciChunkOffset, false);
            }
        }
        cChunk.soElement(ciChunkOffset, null);
        soConsumerIndex(cIndex + 1);
        return e;
    }

    /**
     * Like {@link #poll()} but does not remove the element nor advance the consumer index.
     * May spin for an in-flight producer.
     */
    @Override
    public E peek()
    {
        final int chunkMask = this.chunkMask;
        final long cIndex = this.lpConsumerIndex();
        final int ciChunkOffset = (int) (cIndex & chunkMask);

        MpscUnboundedXaddChunk<E> cChunk = this.lpConsumerChunk();
        // start of new chunk?
        if (ciChunkOffset == 0 && cIndex != 0)
        {
            cChunk = spinForNextIfNotEmpty(cChunk, cIndex);
            if (cChunk == null)
            {
                return null;
            }
        }

        E e = cChunk.lvElement(ciChunkOffset);
        if (e == null)
        {
            if (lvProducerIndex() == cIndex)
            {
                return null;
            }
            else
            {
                e = cChunk.spinForElement(ciChunkOffset, false);
            }
        }
        return e;
    }

    /**
     * Non-spinning variant of {@link #poll()}: returns null as soon as the next element
     * (or the next chunk's link) is not immediately visible, even if the queue is not empty.
     */
    @Override
    public E relaxedPoll()
    {
        final int chunkMask = this.chunkMask;
        final long cIndex = this.lpConsumerIndex();
        final int ciChunkOffset = (int) (cIndex & chunkMask);

        MpscUnboundedXaddChunk<E> cChunk = this.lpConsumerChunk();
        E e;
        // start of new chunk?
        if (ciChunkOffset == 0 && cIndex != 0)
        {
            final MpscUnboundedXaddChunk<E> next = cChunk.lvNext();
            if (next == null)
            {
                return null;
            }
            e = next.lvElement(0);

            // if the next chunk doesn't have the first element set we give up
            if (e == null)
            {
                return null;
            }
            moveToNextConsumerChunk(cChunk, next);

            cChunk = next;
        }
        else
        {
            e = cChunk.lvElement(ciChunkOffset);
            if (e == null)
            {
                return null;
            }
        }

        cChunk.soElement(ciChunkOffset, null);
        soConsumerIndex(cIndex + 1);
        return e;
    }

    /** Non-spinning peek: may return null even when the queue is not empty. */
    @Override
    public E relaxedPeek()
    {
        final int chunkMask = this.chunkMask;
        final long cIndex = this.lpConsumerIndex();
        final int cChunkOffset = (int) (cIndex & chunkMask);

        MpscUnboundedXaddChunk<E> cChunk = this.lpConsumerChunk();

        // start of new chunk?
        if (cChunkOffset == 0 && cIndex != 0)
        {
            cChunk = cChunk.lvNext();
            if (cChunk == null)
            {
                return null;
            }
        }
        return cChunk.lvElement(cChunkOffset);
    }

    /**
     * Fills roughly one chunk's worth of elements, in batches, stopping early if a
     * batch makes no progress.
     */
    @Override
    public int fill(Supplier<E> s)
    {
        long result = 0;// result is a long because we want to have a safepoint check at regular intervals
        final int capacity = chunkMask + 1;
        final int offerBatch = Math.min(PortableJvmInfo.RECOMENDED_OFFER_BATCH, capacity);
        do
        {
            final int filled = fill(s, offerBatch);
            if (filled == 0)
            {
                return (int) result;
            }
            result += filled;
        }
        while (result <= capacity);
        return (int) result;
    }

    /**
     * Consumes up to {@code limit} elements, giving up (rather than spinning) at the
     * first slot or chunk link that is not immediately visible.
     *
     * @return the number of elements actually drained, in [0, limit]
     */
    @Override
    public int drain(Consumer<E> c, int limit)
    {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;

        final int chunkMask = this.chunkMask;

        long cIndex = this.lpConsumerIndex();

        MpscUnboundedXaddChunk<E> cChunk = this.lpConsumerChunk();

        for (int i = 0; i < limit; i++)
        {
            final int consumerOffset = (int) (cIndex & chunkMask);
            E e;
            if (consumerOffset == 0 && cIndex != 0)
            {
                final MpscUnboundedXaddChunk<E> next = cChunk.lvNext();
                if (next == null)
                {
                    return i;
                }
                e = next.lvElement(0);

                // if the next chunk doesn't have the first element set we give up
                if (e == null)
                {
                    return i;
                }
                moveToNextConsumerChunk(cChunk, next);

                cChunk = next;
            }
            else
            {
                e = cChunk.lvElement(consumerOffset);
                if (e == null)
                {
                    return i;
                }
            }
            cChunk.soElement(consumerOffset, null);
            final long nextConsumerIndex = cIndex + 1;
            // Publish consumption before invoking the (alien) consumer callback.
            soConsumerIndex(nextConsumerIndex);
            c.accept(e);
            cIndex = nextConsumerIndex;
        }
        return limit;
    }

    /**
     * Claims {@code limit} slots in one XADD and publishes an element into each,
     * walking chunks as the claimed range crosses chunk boundaries. Always fills
     * exactly {@code limit} elements (the queue is unbounded).
     */
    @Override
    public int fill(Supplier<E> s, int limit)
    {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;

        final int chunkShift = this.chunkShift;
        final int chunkMask = this.chunkMask;

        long pIndex = getAndAddProducerIndex(limit);
        MpscUnboundedXaddChunk<E> pChunk = null;
        for (int i = 0; i < limit; i++)
        {
            final int pChunkOffset = (int) (pIndex & chunkMask);
            final long chunkIndex = pIndex >> chunkShift;
            if (pChunk == null || pChunk.lvIndex() != chunkIndex)
            {
                pChunk = producerChunkForIndex(pChunk, chunkIndex);
            }
            pChunk.soElement(pChunkOffset, s.get());
            pIndex++;
        }
        return limit;
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddChunk.java b/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddChunk.java
new file mode 100644
index 0000000..aff9ee4
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/MpscUnboundedXaddChunk.java
@@ -0,0 +1,26 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues;

import org.jctools.util.InternalAPI;

/**
 * Concrete chunk type used by {@link MpscUnboundedXaddArrayQueue}; all behaviour lives
 * in the {@code MpUnboundedXaddChunk} base class — this subtype only closes the
 * self-referential generic (chunks link to chunks of the same type).
 */
@InternalAPI
public final class MpscUnboundedXaddChunk<E> extends MpUnboundedXaddChunk<MpscUnboundedXaddChunk<E>, E>
{
    // pooled indicates the chunk may be recycled instead of dropped — semantics are
    // defined by the base class (NOTE(review): confirm against MpUnboundedXaddChunk).
    public MpscUnboundedXaddChunk(long index, MpscUnboundedXaddChunk<E> prev, int size, boolean pooled)
    {
        super(index, prev, size, pooled);

    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/QueueProgressIndicators.java b/netty-jctools/src/main/java/org/jctools/queues/QueueProgressIndicators.java
new file mode 100644
index 0000000..f8196b8
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/QueueProgressIndicators.java
@@ -0,0 +1,36 @@
package org.jctools.queues;

/**
 * This interface is provided for monitoring purposes only and is only available on queues where it is easy to
 * provide it. The producer/consumer progress indicators usually correspond with the number of elements
 * offered/polled, but they are not guaranteed to maintain that semantic.
 */
public interface QueueProgressIndicators
{

    /**
     * This method has no concurrent visibility semantics. The value returned may be negative. Under normal
     * circumstances 2 consecutive calls to this method can offer an idea of progress made by producer threads
     * by subtracting the 2 results though in extreme cases (if producers have progressed by more than 2^64)
     * this may also fail.
     * This value will normally indicate number of elements passed into the queue, but may under some
     * circumstances be a derivative of that figure. This method should not be used to derive size or
     * emptiness.
     *
     * @return the current value of the producer progress index
     */
    long currentProducerIndex();

    /**
     * This method has no concurrent visibility semantics. The value returned may be negative. Under normal
     * circumstances 2 consecutive calls to this method can offer an idea of progress made by consumer threads
     * by subtracting the 2 results though in extreme cases (if consumers have progressed by more than 2^64)
     * this may also fail.
     * This value will normally indicate number of elements taken out of the queue, but may under some
     * circumstances be a derivative of that figure. This method should not be used to derive size or
     * emptiness.
     *
     * @return the current value of the consumer progress index
     */
    long currentConsumerIndex();
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/SpmcArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/SpmcArrayQueue.java
new file mode 100644
index 0000000..ef8b23d
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/SpmcArrayQueue.java
@@ -0,0 +1,456 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
+ */ +package org.jctools.queues; + +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +abstract class SpmcArrayQueueL1Pad extends ConcurrentCircularArrayQueue +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpmcArrayQueueL1Pad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class SpmcArrayQueueProducerIndexField extends SpmcArrayQueueL1Pad +{ + protected final static long P_INDEX_OFFSET = fieldOffset(SpmcArrayQueueProducerIndexField.class,"producerIndex"); + + private volatile long producerIndex; + + SpmcArrayQueueProducerIndexField(int capacity) + { + super(capacity); + } + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final long lpProducerIndex() + { + return UNSAFE.getLong(this, P_INDEX_OFFSET); + } + + final void soProducerIndex(long newValue) + { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + +} + +abstract class SpmcArrayQueueL2Pad extends SpmcArrayQueueProducerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte 
b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpmcArrayQueueL2Pad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class SpmcArrayQueueConsumerIndexField extends SpmcArrayQueueL2Pad +{ + protected final static long C_INDEX_OFFSET = fieldOffset(SpmcArrayQueueConsumerIndexField.class, "consumerIndex"); + + private volatile long consumerIndex; + + SpmcArrayQueueConsumerIndexField(int capacity) + { + super(capacity); + } + + @Override + public final long lvConsumerIndex() + { + return consumerIndex; + } + + final boolean casConsumerIndex(long expect, long newValue) + { + return UNSAFE.compareAndSwapLong(this, C_INDEX_OFFSET, expect, newValue); + } +} + +abstract class SpmcArrayQueueMidPad extends SpmcArrayQueueConsumerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte 
b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpmcArrayQueueMidPad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class SpmcArrayQueueProducerIndexCacheField extends SpmcArrayQueueMidPad +{ + // This is separated from the consumerIndex which will be highly contended in the hope that this value spends most + // of it's time in a cache line that is Shared(and rarely invalidated) + private volatile long producerIndexCache; + + SpmcArrayQueueProducerIndexCacheField(int capacity) + { + super(capacity); + } + + protected final long lvProducerIndexCache() + { + return producerIndexCache; + } + + protected final void svProducerIndexCache(long newValue) + { + producerIndexCache = newValue; + } +} + +abstract class SpmcArrayQueueL3Pad extends SpmcArrayQueueProducerIndexCacheField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte 
b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpmcArrayQueueL3Pad(int capacity) + { + super(capacity); + } +} + +public class SpmcArrayQueue extends SpmcArrayQueueL3Pad +{ + + public SpmcArrayQueue(final int capacity) + { + super(capacity); + } + + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + final E[] buffer = this.buffer; + final long mask = this.mask; + final long currProducerIndex = lvProducerIndex(); + final long offset = calcCircularRefElementOffset(currProducerIndex, mask); + if (null != lvRefElement(buffer, offset)) + { + long size = currProducerIndex - lvConsumerIndex(); + + if (size > mask) + { + return false; + } + else + { + // Bubble: This can happen because `poll` moves index before placing element. + // spin wait for slot to clear, buggers wait freedom + while (null != lvRefElement(buffer, offset)) + { + // BURN + } + } + } + soRefElement(buffer, offset, e); + // single producer, so store ordered is valid. It is also required to correctly publish the element + // and for the consumers to pick up the tail value. + soProducerIndex(currProducerIndex + 1); + return true; + } + + @Override + public E poll() + { + long currentConsumerIndex; + long currProducerIndexCache = lvProducerIndexCache(); + do + { + currentConsumerIndex = lvConsumerIndex(); + if (currentConsumerIndex >= currProducerIndexCache) + { + long currProducerIndex = lvProducerIndex(); + if (currentConsumerIndex >= currProducerIndex) + { + return null; + } + else + { + currProducerIndexCache = currProducerIndex; + svProducerIndexCache(currProducerIndex); + } + } + } + while (!casConsumerIndex(currentConsumerIndex, currentConsumerIndex + 1)); + // consumers are gated on latest visible tail, and so can't see a null value in the queue or overtake + // and wrap to hit same location. 
+ return removeElement(buffer, currentConsumerIndex, mask); + } + + private E removeElement(final E[] buffer, long index, final long mask) + { + final long offset = calcCircularRefElementOffset(index, mask); + // load plain, element happens before it's index becomes visible + final E e = lpRefElement(buffer, offset); + // store ordered, make sure nulling out is visible. Producer is waiting for this value. + soRefElement(buffer, offset, null); + return e; + } + + @Override + public E peek() + { + final E[] buffer = this.buffer; + final long mask = this.mask; + long currProducerIndexCache = lvProducerIndexCache(); + long currentConsumerIndex; + long nextConsumerIndex = lvConsumerIndex(); + E e; + do + { + currentConsumerIndex = nextConsumerIndex; + if (currentConsumerIndex >= currProducerIndexCache) + { + long currProducerIndex = lvProducerIndex(); + if (currentConsumerIndex >= currProducerIndex) + { + return null; + } + else + { + currProducerIndexCache = currProducerIndex; + svProducerIndexCache(currProducerIndex); + } + } + e = lvRefElement(buffer, calcCircularRefElementOffset(currentConsumerIndex, mask)); + // sandwich the element load between 2 consumer index loads + nextConsumerIndex = lvConsumerIndex(); + } + while (null == e || nextConsumerIndex != currentConsumerIndex); + return e; + } + + @Override + public boolean relaxedOffer(E e) + { + if (null == e) + { + throw new NullPointerException("Null is not a valid element"); + } + final E[] buffer = this.buffer; + final long mask = this.mask; + final long producerIndex = lpProducerIndex(); + final long offset = calcCircularRefElementOffset(producerIndex, mask); + if (null != lvRefElement(buffer, offset)) + { + return false; + } + soRefElement(buffer, offset, e); + // single producer, so store ordered is valid. It is also required to correctly publish the element + // and for the consumers to pick up the tail value. 
+ soProducerIndex(producerIndex + 1); + return true; + } + + @Override + public E relaxedPoll() + { + return poll(); + } + + @Override + public E relaxedPeek() + { + final E[] buffer = this.buffer; + final long mask = this.mask; + long currentConsumerIndex; + long nextConsumerIndex = lvConsumerIndex(); + E e; + do + { + currentConsumerIndex = nextConsumerIndex; + e = lvRefElement(buffer, calcCircularRefElementOffset(currentConsumerIndex, mask)); + // sandwich the element load between 2 consumer index loads + nextConsumerIndex = lvConsumerIndex(); + } + while (nextConsumerIndex != currentConsumerIndex); + return e; + } + + @Override + public int drain(final Consumer c, final int limit) + { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + + final E[] buffer = this.buffer; + final long mask = this.mask; + long currProducerIndexCache = lvProducerIndexCache(); + int adjustedLimit = 0; + long currentConsumerIndex; + do + { + currentConsumerIndex = lvConsumerIndex(); + // is there any space in the queue? 
+ if (currentConsumerIndex >= currProducerIndexCache) + { + long currProducerIndex = lvProducerIndex(); + if (currentConsumerIndex >= currProducerIndex) + { + return 0; + } + else + { + currProducerIndexCache = currProducerIndex; + svProducerIndexCache(currProducerIndex); + } + } + // try and claim up to 'limit' elements in one go + int remaining = (int) (currProducerIndexCache - currentConsumerIndex); + adjustedLimit = Math.min(remaining, limit); + } + while (!casConsumerIndex(currentConsumerIndex, currentConsumerIndex + adjustedLimit)); + + for (int i = 0; i < adjustedLimit; i++) + { + c.accept(removeElement(buffer, currentConsumerIndex + i, mask)); + } + return adjustedLimit; + } + + + @Override + public int fill(final Supplier s, final int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + final E[] buffer = this.buffer; + final long mask = this.mask; + long producerIndex = this.lpProducerIndex(); + + for (int i = 0; i < limit; i++) + { + final long offset = calcCircularRefElementOffset(producerIndex, mask); + if (null != lvRefElement(buffer, offset)) + { + return i; + } + producerIndex++; + soRefElement(buffer, offset, s.get()); + soProducerIndex(producerIndex); // ordered store -> atomic and ordered for size() + } + return limit; + } + + @Override + public int drain(final Consumer c) + { + return MessagePassingQueueUtil.drain(this, c); + } + + @Override + public int fill(final Supplier s) + { + return fill(s, capacity()); + } + + @Override + public void drain(final Consumer c, final WaitStrategy w, final ExitCondition exit) + { + MessagePassingQueueUtil.drain(this, c, w, exit); + } + + @Override + public void fill(final Supplier s, final WaitStrategy w, final ExitCondition e) + { + MessagePassingQueueUtil.fill(this, s, w, e); + } +} diff --git 
a/netty-jctools/src/main/java/org/jctools/queues/SpscArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/SpscArrayQueue.java new file mode 100644 index 0000000..a21eb11 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/SpscArrayQueue.java @@ -0,0 +1,458 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.SpscLookAheadUtil; + +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +abstract class SpscArrayQueueColdField extends ConcurrentCircularArrayQueue +{ + final int lookAheadStep; + + SpscArrayQueueColdField(int capacity) + { + super(capacity); + int actualCapacity = capacity(); + lookAheadStep = SpscLookAheadUtil.computeLookAheadStep(actualCapacity); + } + +} + +abstract class SpscArrayQueueL1Pad extends SpscArrayQueueColdField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte 
b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpscArrayQueueL1Pad(int capacity) + { + super(capacity); + } +} + +// $gen:ordered-fields +abstract class SpscArrayQueueProducerIndexFields extends SpscArrayQueueL1Pad +{ + private final static long P_INDEX_OFFSET = fieldOffset(SpscArrayQueueProducerIndexFields.class, "producerIndex"); + + private volatile long producerIndex; + protected long producerLimit; + + SpscArrayQueueProducerIndexFields(int capacity) + { + super(capacity); + } + + @Override + public final long lvProducerIndex() + { + return producerIndex; + } + + final long lpProducerIndex() + { + return UNSAFE.getLong(this, P_INDEX_OFFSET); + } + + final void soProducerIndex(final long newValue) + { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + +} + +abstract class SpscArrayQueueL2Pad extends SpscArrayQueueProducerIndexFields +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b 
+ byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpscArrayQueueL2Pad(int capacity) + { + super(capacity); + } +} + +//$gen:ordered-fields +abstract class SpscArrayQueueConsumerIndexField extends SpscArrayQueueL2Pad +{ + private final static long C_INDEX_OFFSET = fieldOffset(SpscArrayQueueConsumerIndexField.class, "consumerIndex"); + + private volatile long consumerIndex; + + SpscArrayQueueConsumerIndexField(int capacity) + { + super(capacity); + } + + public final long lvConsumerIndex() + { + return UNSAFE.getLongVolatile(this, C_INDEX_OFFSET); + } + + final long lpConsumerIndex() + { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(final long newValue) + { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +abstract class SpscArrayQueueL3Pad extends SpscArrayQueueConsumerIndexField +{ + byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b + byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b + byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b + byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b + byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b + byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b + byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b + byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b + byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b + byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b + byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b + byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b + byte b140,b141,b142,b143,b144,b145,b146,b147;//104b + byte b150,b151,b152,b153,b154,b155,b156,b157;//112b + byte b160,b161,b162,b163,b164,b165,b166,b167;//120b + byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + + SpscArrayQueueL3Pad(int capacity) + { + super(capacity); + } +} + + +/** + * A Single-Producer-Single-Consumer queue backed by a pre-allocated buffer. + *

+ * This implementation is a mashup of the Fast Flow + * algorithm with an optimization of the offer method taken from the BQueue algorithm (a variation on Fast + * Flow), and adjusted to comply with Queue.offer semantics with regards to capacity.
+ * For convenience the relevant papers are available in the `resources` folder:
+ * + * 2010 - Pisa - SPSC Queues on Shared Cache Multi-Core Systems.pdf
+ * 2012 - Junchang- BQueue- Efficient and Practical Queuing.pdf
+ *
+ * This implementation is wait free. + */ +public class SpscArrayQueue extends SpscArrayQueueL3Pad +{ + + public SpscArrayQueue(final int capacity) + { + super(Math.max(capacity, 4)); + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single producer thread use only. + */ + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = this.buffer; + final long mask = this.mask; + final long producerIndex = this.lpProducerIndex(); + + if (producerIndex >= producerLimit && + !offerSlowPath(buffer, mask, producerIndex)) + { + return false; + } + final long offset = calcCircularRefElementOffset(producerIndex, mask); + + soRefElement(buffer, offset, e); + soProducerIndex(producerIndex + 1); // ordered store -> atomic and ordered for size() + return true; + } + + private boolean offerSlowPath(final E[] buffer, final long mask, final long producerIndex) + { + final int lookAheadStep = this.lookAheadStep; + if (null == lvRefElement(buffer, + calcCircularRefElementOffset(producerIndex + lookAheadStep, mask))) + { + producerLimit = producerIndex + lookAheadStep; + } + else + { + final long offset = calcCircularRefElementOffset(producerIndex, mask); + if (null != lvRefElement(buffer, offset)) + { + return false; + } + } + return true; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @Override + public E poll() + { + final long consumerIndex = this.lpConsumerIndex(); + final long offset = calcCircularRefElementOffset(consumerIndex, mask); + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = this.buffer; + final E e = lvRefElement(buffer, offset); + if (null == e) + { + return null; + } + soRefElement(buffer, offset, null); + soConsumerIndex(consumerIndex + 1); // ordered store -> atomic and ordered for size() + return e; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @Override + public E peek() + { + return lvRefElement(buffer, calcCircularRefElementOffset(lpConsumerIndex(), mask)); + } + + @Override + public boolean relaxedOffer(final E message) + { + return offer(message); + } + + @Override + public E relaxedPoll() + { + return poll(); + } + + @Override + public E relaxedPeek() + { + return peek(); + } + + @Override + public int drain(final Consumer c) + { + return drain(c, capacity()); + } + + @Override + public int fill(final Supplier s) + { + return fill(s, capacity()); + } + + @Override + public int drain(final Consumer c, final int limit) + { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + + final E[] buffer = this.buffer; + final long mask = this.mask; + final long consumerIndex = this.lpConsumerIndex(); + + for (int i = 0; i < limit; i++) + { + final long index = consumerIndex + i; + final long offset = calcCircularRefElementOffset(index, mask); + final E e = lvRefElement(buffer, offset); + if (null == e) + { + return i; + } + soRefElement(buffer, offset, null); + soConsumerIndex(index + 1); // ordered store -> atomic and ordered for size() + c.accept(e); + } + return limit; + } + + @Override + public int fill(final Supplier s, final int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + final E[] buffer = this.buffer; + final long mask = this.mask; + final int lookAheadStep = this.lookAheadStep; + final long producerIndex = this.lpProducerIndex(); + + for (int i = 0; i < limit; i++) + { + final long index = producerIndex + i; + final long lookAheadElementOffset = + calcCircularRefElementOffset(index + lookAheadStep, mask); + if (null == 
lvRefElement(buffer, lookAheadElementOffset)) + { + int lookAheadLimit = Math.min(lookAheadStep, limit - i); + for (int j = 0; j < lookAheadLimit; j++) + { + final long offset = calcCircularRefElementOffset(index + j, mask); + soRefElement(buffer, offset, s.get()); + soProducerIndex(index + j + 1); // ordered store -> atomic and ordered for size() + } + i += lookAheadLimit - 1; + } + else + { + final long offset = calcCircularRefElementOffset(index, mask); + if (null != lvRefElement(buffer, offset)) + { + return i; + } + soRefElement(buffer, offset, s.get()); + soProducerIndex(index + 1); // ordered store -> atomic and ordered for size() + } + + } + return limit; + } + + @Override + public void drain(final Consumer c, final WaitStrategy w, final ExitCondition exit) + { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (null == w) + throw new IllegalArgumentException("wait is null"); + if (null == exit) + throw new IllegalArgumentException("exit condition is null"); + + final E[] buffer = this.buffer; + final long mask = this.mask; + long consumerIndex = this.lpConsumerIndex(); + + int counter = 0; + while (exit.keepRunning()) + { + for (int i = 0; i < 4096; i++) + { + final long offset = calcCircularRefElementOffset(consumerIndex, mask); + final E e = lvRefElement(buffer, offset); + if (null == e) + { + counter = w.idle(counter); + continue; + } + consumerIndex++; + counter = 0; + soRefElement(buffer, offset, null); + soConsumerIndex(consumerIndex); // ordered store -> atomic and ordered for size() + c.accept(e); + } + } + } + + @Override + public void fill(final Supplier s, final WaitStrategy w, final ExitCondition e) + { + if (null == w) + throw new IllegalArgumentException("waiter is null"); + if (null == e) + throw new IllegalArgumentException("exit condition is null"); + if (null == s) + throw new IllegalArgumentException("supplier is null"); + + final E[] buffer = this.buffer; + final long mask = this.mask; + final int lookAheadStep = 
this.lookAheadStep; + long producerIndex = this.lpProducerIndex(); + int counter = 0; + while (e.keepRunning()) + { + final long lookAheadElementOffset = + calcCircularRefElementOffset(producerIndex + lookAheadStep, mask); + if (null == lvRefElement(buffer, lookAheadElementOffset)) + { + for (int j = 0; j < lookAheadStep; j++) + { + final long offset = calcCircularRefElementOffset(producerIndex, mask); + producerIndex++; + soRefElement(buffer, offset, s.get()); + soProducerIndex(producerIndex); // ordered store -> atomic and ordered for size() + } + } + else + { + final long offset = calcCircularRefElementOffset(producerIndex, mask); + if (null != lvRefElement(buffer, offset)) + { + counter = w.idle(counter); + continue; + } + producerIndex++; + counter = 0; + soRefElement(buffer, offset, s.get()); + soProducerIndex(producerIndex); // ordered store -> atomic and ordered for size() + } + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/SpscChunkedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/SpscChunkedArrayQueue.java new file mode 100644 index 0000000..50a902d --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/SpscChunkedArrayQueue.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; + +import static org.jctools.util.UnsafeRefArrayAccess.*; + +/** + * An SPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks + * of the initial size. The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
+ * + * @param + */ +public class SpscChunkedArrayQueue extends BaseSpscLinkedArrayQueue +{ + private final int maxQueueCapacity; + private long producerQueueLimit; + + public SpscChunkedArrayQueue(int capacity) + { + this(Math.max(8, Pow2.roundToPowerOfTwo(capacity / 8)), capacity); + } + + public SpscChunkedArrayQueue(int chunkSize, int capacity) + { + RangeUtil.checkGreaterThanOrEqual(capacity, 16, "capacity"); + // minimal chunk size of eight makes sure minimal lookahead step is 2 + RangeUtil.checkGreaterThanOrEqual(chunkSize, 8, "chunkSize"); + + maxQueueCapacity = Pow2.roundToPowerOfTwo(capacity); + int chunkCapacity = Pow2.roundToPowerOfTwo(chunkSize); + RangeUtil.checkLessThan(chunkCapacity, maxQueueCapacity, "chunkCapacity"); + + long mask = chunkCapacity - 1; + // need extra element to point at next array + E[] buffer = allocateRefArray(chunkCapacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + producerBufferLimit = mask - 1; // we know it's all empty to start with + producerQueueLimit = maxQueueCapacity; + } + + @Override + final boolean offerColdPath(E[] buffer, long mask, long pIndex, long offset, E v, Supplier s) + { + // use a fixed lookahead step based on buffer capacity + final long lookAheadStep = (mask + 1) / 4; + long pBufferLimit = pIndex + lookAheadStep; + + long pQueueLimit = producerQueueLimit; + + if (pIndex >= pQueueLimit) + { + // we tested against a potentially out of date queue limit, refresh it + final long cIndex = lvConsumerIndex(); + producerQueueLimit = pQueueLimit = cIndex + maxQueueCapacity; + // if we're full we're full + if (pIndex >= pQueueLimit) + { + return false; + } + } + // if buffer limit is after queue limit we use queue limit. 
We need to handle overflow so + // cannot use Math.min + if (pBufferLimit - pQueueLimit > 0) + { + pBufferLimit = pQueueLimit; + } + + // go around the buffer or add a new buffer + if (pBufferLimit > pIndex + 1 && // there's sufficient room in buffer/queue to use pBufferLimit + null == lvRefElement(buffer, calcCircularRefElementOffset(pBufferLimit, mask))) + { + producerBufferLimit = pBufferLimit - 1; // joy, there's plenty of room + writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset); + } + else if (null == lvRefElement(buffer, calcCircularRefElementOffset(pIndex + 1, mask))) + { // buffer is not full + writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset); + } + else + { + // we got one slot left to write into, and we are not full. Need to link new buffer. + // allocate new buffer of same length + final E[] newBuffer = allocateRefArray((int) (mask + 2)); + producerBuffer = newBuffer; + + linkOldToNew(pIndex, buffer, offset, newBuffer, offset, v == null ? s.get() : v); + } + return true; + } + + @Override + public int capacity() + { + return maxQueueCapacity; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/SpscGrowableArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/SpscGrowableArrayQueue.java new file mode 100644 index 0000000..6b34340 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/SpscGrowableArrayQueue.java @@ -0,0 +1,170 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; +import org.jctools.util.SpscLookAheadUtil; + +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import static org.jctools.util.UnsafeRefArrayAccess.*; + +/** + * An SPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks, + * doubling theirs size every time until the full blown backing array is used. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
+ * + * @param + */ +public class SpscGrowableArrayQueue extends BaseSpscLinkedArrayQueue +{ + private final int maxQueueCapacity; + private long lookAheadStep; + + public SpscGrowableArrayQueue(final int capacity) + { + this(Math.max(8, Pow2.roundToPowerOfTwo(capacity / 8)), capacity); + } + + public SpscGrowableArrayQueue(final int chunkSize, final int capacity) + { + RangeUtil.checkGreaterThanOrEqual(capacity, 16, "capacity"); + // minimal chunk size of eight makes sure minimal lookahead step is 2 + RangeUtil.checkGreaterThanOrEqual(chunkSize, 8, "chunkSize"); + + maxQueueCapacity = Pow2.roundToPowerOfTwo(capacity); + int chunkCapacity = Pow2.roundToPowerOfTwo(chunkSize); + RangeUtil.checkLessThan(chunkCapacity, maxQueueCapacity, "chunkCapacity"); + + long mask = chunkCapacity - 1; + // need extra element to point at next array + E[] buffer = allocateRefArray(chunkCapacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + producerBufferLimit = mask - 1; // we know it's all empty to start with + adjustLookAheadStep(chunkCapacity); + } + + @Override + final boolean offerColdPath( + final E[] buffer, + final long mask, + final long index, + final long offset, + final E v, + final Supplier s) + { + final long lookAheadStep = this.lookAheadStep; + // normal case, go around the buffer or resize if full (unless we hit max capacity) + if (lookAheadStep > 0) + { + long lookAheadElementOffset = calcCircularRefElementOffset(index + lookAheadStep, mask); + // Try and look ahead a number of elements so we don't have to do this all the time + if (null == lvRefElement(buffer, lookAheadElementOffset)) + { + producerBufferLimit = index + lookAheadStep - 1; // joy, there's plenty of room + writeToQueue(buffer, v == null ? 
s.get() : v, index, offset); + return true; + } + // we're at max capacity, can use up last element + final int maxCapacity = maxQueueCapacity; + if (mask + 1 == maxCapacity) + { + if (null == lvRefElement(buffer, offset)) + { + writeToQueue(buffer, v == null ? s.get() : v, index, offset); + return true; + } + // we're full and can't grow + return false; + } + // not at max capacity, so must allow extra slot for next buffer pointer + if (null == lvRefElement(buffer, calcCircularRefElementOffset(index + 1, mask))) + { // buffer is not full + writeToQueue(buffer, v == null ? s.get() : v, index, offset); + } + else + { + // allocate new buffer of same length + final E[] newBuffer = allocateRefArray((int) (2 * (mask + 1) + 1)); + + producerBuffer = newBuffer; + producerMask = length(newBuffer) - 2; + + final long offsetInNew = calcCircularRefElementOffset(index, producerMask); + linkOldToNew(index, buffer, offset, newBuffer, offsetInNew, v == null ? s.get() : v); + int newCapacity = (int) (producerMask + 1); + if (newCapacity == maxCapacity) + { + long currConsumerIndex = lvConsumerIndex(); + // use lookAheadStep to store the consumer distance from final buffer + this.lookAheadStep = -(index - currConsumerIndex); + producerBufferLimit = currConsumerIndex + maxCapacity; + } + else + { + producerBufferLimit = index + producerMask - 1; + adjustLookAheadStep(newCapacity); + } + } + return true; + } + // the step is negative (or zero) in the period between allocating the max sized buffer and the + // consumer starting on it + else + { + final long prevElementsInOtherBuffers = -lookAheadStep; + // until the consumer starts using the current buffer we need to check consumer index to + // verify size + long currConsumerIndex = lvConsumerIndex(); + int size = (int) (index - currConsumerIndex); + int maxCapacity = (int) mask + 1; // we're on max capacity or we wouldn't be here + if (size == maxCapacity) + { + // consumer index has not changed since adjusting the lookAhead index, 
we're full + return false; + } + // if consumerIndex progressed enough so that current size indicates it is on same buffer + long firstIndexInCurrentBuffer = producerBufferLimit - maxCapacity + prevElementsInOtherBuffers; + if (currConsumerIndex >= firstIndexInCurrentBuffer) + { + // job done, we've now settled into our final state + adjustLookAheadStep(maxCapacity); + } + // consumer is still on some other buffer + else + { + // how many elements out of buffer? + this.lookAheadStep = (int) (currConsumerIndex - firstIndexInCurrentBuffer); + } + producerBufferLimit = currConsumerIndex + maxCapacity; + writeToQueue(buffer, v == null ? s.get() : v, index, offset); + return true; + } + } + + private void adjustLookAheadStep(int capacity) + { + lookAheadStep = SpscLookAheadUtil.computeLookAheadStep(capacity); + } + + @Override + public int capacity() + { + return maxQueueCapacity; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/SpscLinkedQueue.java b/netty-jctools/src/main/java/org/jctools/queues/SpscLinkedQueue.java new file mode 100644 index 0000000..ac8a925 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/SpscLinkedQueue.java @@ -0,0 +1,111 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +/** + * This is a weakened version of the MPSC algorithm as presented + * on + * 1024 Cores by D. Vyukov. 
The original has been adapted to Java and it's quirks with regards to memory + * model and layout: + *

    + *
  1. Use inheritance to ensure no false sharing occurs between producer/consumer node reference fields. + *
  2. As this is an SPSC we have no need for XCHG, an ordered store is enough. + *
+ * The queue is initialized with a stub node which is set to both the producer and consumer node references. + * From this point follow the notes on offer/poll. + * + * @param + * @author nitsanw + */ +public class SpscLinkedQueue extends BaseLinkedQueue +{ + + public SpscLinkedQueue() + { + LinkedQueueNode node = newNode(); + spProducerNode(node); + spConsumerNode(node); + node.soNext(null); // this ensures correct construction: StoreStore + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Offer is allowed from a SINGLE thread.
+ * Offer allocates a new node (holding the offered value) and: + *

    + *
  1. Sets the new node as the producerNode + *
  2. Sets that node as the lastProducerNode.next + *
+ * From this follows that producerNode.next is always null and for all other nodes node.next is not null. + * + * @see MessagePassingQueue#offer(Object) + * @see java.util.Queue#offer(java.lang.Object) + */ + @Override + public boolean offer(final E e) + { + if (null == e) + { + throw new NullPointerException(); + } + final LinkedQueueNode nextNode = newNode(e); + LinkedQueueNode oldNode = lpProducerNode(); + soProducerNode(nextNode); + // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed + // and completes the store in prev.next. This is a "bubble". + // Inverting the order here will break the `isEmpty` invariant, and will require matching adjustments elsewhere. + oldNode.soNext(nextNode); + return true; + } + + @Override + public int fill(Supplier s) + { + return MessagePassingQueueUtil.fillUnbounded(this, s); + } + + @Override + public int fill(Supplier s, int limit) + { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + + LinkedQueueNode tail = newNode(s.get()); + final LinkedQueueNode head = tail; + for (int i = 1; i < limit; i++) + { + final LinkedQueueNode temp = newNode(s.get()); + // spNext : soProducerNode ensures correct construction + tail.spNext(temp); + tail = temp; + } + final LinkedQueueNode oldPNode = lpProducerNode(); + soProducerNode(tail); + // same bubble as offer, and for the same reasons. 
+ oldPNode.soNext(head); + return limit; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) + { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/SpscUnboundedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/SpscUnboundedArrayQueue.java new file mode 100644 index 0000000..d34f288 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/SpscUnboundedArrayQueue.java @@ -0,0 +1,83 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues; + +import org.jctools.util.Pow2; + +import static org.jctools.util.UnsafeRefArrayAccess.*; + +/** + * An SPSC array queue which starts at initialCapacity and grows indefinitely in linked chunks of the initial size. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
+ * + * @param + */ +public class SpscUnboundedArrayQueue extends BaseSpscLinkedArrayQueue +{ + + public SpscUnboundedArrayQueue(int chunkSize) + { + int chunkCapacity = Math.max(Pow2.roundToPowerOfTwo(chunkSize), 16); + long mask = chunkCapacity - 1; + E[] buffer = allocateRefArray(chunkCapacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + producerBufferLimit = mask - 1; // we know it's all empty to start with + } + + @Override + final boolean offerColdPath(E[] buffer, long mask, long pIndex, long offset, E v, Supplier s) + { + // use a fixed lookahead step based on buffer capacity + final long lookAheadStep = (mask + 1) / 4; + long pBufferLimit = pIndex + lookAheadStep; + + // go around the buffer or add a new buffer + if (null == lvRefElement(buffer, calcCircularRefElementOffset(pBufferLimit, mask))) + { + producerBufferLimit = pBufferLimit - 1; // joy, there's plenty of room + writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset); + } + else if (null == lvRefElement(buffer, calcCircularRefElementOffset(pIndex + 1, mask))) + { // buffer is not full + writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset); + } + else + { + // we got one slot left to write into, and we are not full. Need to link new buffer. + // allocate new buffer of same length + final E[] newBuffer = allocateRefArray((int) (mask + 2)); + producerBuffer = newBuffer; + producerBufferLimit = pIndex + mask - 1; + + linkOldToNew(pIndex, buffer, offset, newBuffer, offset, v == null ? 
s.get() : v); + } + return true; + } + + @Override + public int fill(Supplier s) + { + return fill(s, (int) this.producerMask); + } + + @Override + public int capacity() + { + return UNBOUNDED_CAPACITY; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/SupportsIterator.java b/netty-jctools/src/main/java/org/jctools/queues/SupportsIterator.java new file mode 100644 index 0000000..4da8771 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/SupportsIterator.java @@ -0,0 +1,24 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues; + +import org.jctools.util.InternalAPI; + +/** + * Tagging interface to help testing + */ +@InternalAPI +public interface SupportsIterator +{ +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicQueueUtil.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicQueueUtil.java new file mode 100644 index 0000000..63d559f --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicQueueUtil.java @@ -0,0 +1,101 @@ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.AtomicLongArray; +import java.util.concurrent.atomic.AtomicReferenceArray; + +final class AtomicQueueUtil +{ + static E lvRefElement(AtomicReferenceArray buffer, int offset) + { + return buffer.get(offset); + } + + static E lpRefElement(AtomicReferenceArray buffer, int offset) + { + return buffer.get(offset); // no weaker form available + } + + static void spRefElement(AtomicReferenceArray buffer, int offset, E value) + { + buffer.lazySet(offset, value); // no weaker form available + } + + static void soRefElement(AtomicReferenceArray buffer, int offset, Object value) + { + buffer.lazySet(offset, value); + } + + static void svRefElement(AtomicReferenceArray buffer, int offset, E value) + { + buffer.set(offset, value); + } + + static int calcRefElementOffset(long index) + { + return (int) index; + } + + static int calcCircularRefElementOffset(long index, long mask) + { + return (int) (index & mask); + } + + static AtomicReferenceArray allocateRefArray(int capacity) + { + return new AtomicReferenceArray(capacity); + } + + static void spLongElement(AtomicLongArray buffer, int offset, long e) + { + buffer.lazySet(offset, e); + } + + static void soLongElement(AtomicLongArray buffer, int offset, long e) + { + buffer.lazySet(offset, e); + } + + static long lpLongElement(AtomicLongArray buffer, int offset) + { + return buffer.get(offset); + } + + static long lvLongElement(AtomicLongArray buffer, int offset) + 
{ + return buffer.get(offset); + } + + static int calcLongElementOffset(long index) + { + return (int) index; + } + + static int calcCircularLongElementOffset(long index, int mask) + { + return (int) (index & mask); + } + + static AtomicLongArray allocateLongArray(int capacity) + { + return new AtomicLongArray(capacity); + } + + static int length(AtomicReferenceArray buf) + { + return buf.length(); + } + + /** + * This method assumes index is actually (index << 1) because lower bit is used for resize hence the >> 1 + */ + static int modifiedCalcCircularRefElementOffset(long index, long mask) + { + return (int) (index & mask) >> 1; + } + + static int nextArrayOffset(AtomicReferenceArray curr) + { + return length(curr) - 1; + } + +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicReferenceArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicReferenceArrayQueue.java new file mode 100644 index 0000000..5753eda --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/AtomicReferenceArrayQueue.java @@ -0,0 +1,157 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues.atomic; + +import org.jctools.queues.IndexedQueueSizeUtil; +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.queues.MessagePassingQueue; +import org.jctools.queues.QueueProgressIndicators; +import org.jctools.queues.SupportsIterator; +import org.jctools.util.Pow2; + +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.AtomicReferenceArray; + +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +abstract class AtomicReferenceArrayQueue extends AbstractQueue implements IndexedQueue, QueueProgressIndicators, MessagePassingQueue, SupportsIterator +{ + protected final AtomicReferenceArray buffer; + protected final int mask; + + public AtomicReferenceArrayQueue(int capacity) + { + int actualCapacity = Pow2.roundToPowerOfTwo(capacity); + this.mask = actualCapacity - 1; + this.buffer = new AtomicReferenceArray(actualCapacity); + } + + @Override + public String toString() + { + return this.getClass().getName(); + } + + @Override + public void clear() + { + while (poll() != null) + { + // toss it away + } + } + + @Override + public final int capacity() + { + return (int) (mask + 1); + } + + /** + * {@inheritDoc} + *

+ */ + @Override + public final int size() + { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR); + } + + @Override + public final boolean isEmpty() + { + return IndexedQueueSizeUtil.isEmpty(this); + } + + @Override + public final long currentProducerIndex() + { + return lvProducerIndex(); + } + + @Override + public final long currentConsumerIndex() + { + return lvConsumerIndex(); + } + + /** + * Get an iterator for this queue. This method is thread safe. + *

+ * The iterator provides a best-effort snapshot of the elements in the queue. + * The returned iterator is not guaranteed to return elements in queue order, + * and races with the consumer thread may cause gaps in the sequence of returned elements. + * Like {link #relaxedPoll}, the iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public final Iterator iterator() { + final long cIndex = lvConsumerIndex(); + final long pIndex = lvProducerIndex(); + + return new WeakIterator(cIndex, pIndex, mask, buffer); + } + + private static class WeakIterator implements Iterator { + + private final long pIndex; + private final int mask; + private final AtomicReferenceArray buffer; + private long nextIndex; + private E nextElement; + + WeakIterator(long cIndex, long pIndex, int mask, AtomicReferenceArray buffer) { + this.nextIndex = cIndex; + this.pIndex = pIndex; + this.mask = mask; + this.buffer = buffer; + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) + throw new NoSuchElementException(); + nextElement = getNext(); + return e; + } + + private E getNext() { + final int mask = this.mask; + final AtomicReferenceArray buffer = this.buffer; + while (nextIndex < pIndex) { + int offset = calcCircularRefElementOffset(nextIndex++, mask); + E e = lvRefElement(buffer, offset); + if (e != null) { + return e; + } + } + return null; + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseLinkedAtomicQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseLinkedAtomicQueue.java new file mode 100644 index 0000000..a8cf18c --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseLinkedAtomicQueue.java @@ -0,0 +1,471 @@ +/* + * Licensed under 
the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.jctools.queues.atomic;
+
+import java.util.AbstractQueue;
+import java.util.Iterator;
+import java.util.Queue;
+import java.util.concurrent.atomic.*;
+import org.jctools.queues.*;
+import static org.jctools.queues.atomic.AtomicQueueUtil.*;
+
+/**
+ * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
+ * which can be found in the jctools-build module. The original source file is BaseLinkedQueue.java.
+ */ +abstract class // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +// * drop 8b as object header acts as padding and is >= 8b * +BaseLinkedAtomicQueuePad0 extends AbstractQueue implements MessagePassingQueue { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. 
+ */ +abstract class BaseLinkedAtomicQueueProducerNodeRef extends BaseLinkedAtomicQueuePad0 { + + private static final AtomicReferenceFieldUpdater P_NODE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(BaseLinkedAtomicQueueProducerNodeRef.class, LinkedQueueAtomicNode.class, "producerNode"); + + private volatile LinkedQueueAtomicNode producerNode; + + final void spProducerNode(LinkedQueueAtomicNode newValue) { + P_NODE_UPDATER.lazySet(this, newValue); + } + + final void soProducerNode(LinkedQueueAtomicNode newValue) { + P_NODE_UPDATER.lazySet(this, newValue); + } + + final LinkedQueueAtomicNode lvProducerNode() { + return producerNode; + } + + final boolean casProducerNode(LinkedQueueAtomicNode expect, LinkedQueueAtomicNode newValue) { + return P_NODE_UPDATER.compareAndSet(this, expect, newValue); + } + + final LinkedQueueAtomicNode lpProducerNode() { + return producerNode; + } + + protected final LinkedQueueAtomicNode xchgProducerNode(LinkedQueueAtomicNode newValue) { + return P_NODE_UPDATER.getAndSet(this, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. 
+ */ +abstract class BaseLinkedAtomicQueuePad1 extends BaseLinkedAtomicQueueProducerNodeRef { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, b173, b174, b175, b176, b177; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. 
+ */ +abstract class BaseLinkedAtomicQueueConsumerNodeRef extends BaseLinkedAtomicQueuePad1 { + + private static final AtomicReferenceFieldUpdater C_NODE_UPDATER = AtomicReferenceFieldUpdater.newUpdater(BaseLinkedAtomicQueueConsumerNodeRef.class, LinkedQueueAtomicNode.class, "consumerNode"); + + private volatile LinkedQueueAtomicNode consumerNode; + + final void spConsumerNode(LinkedQueueAtomicNode newValue) { + C_NODE_UPDATER.lazySet(this, newValue); + } + + @SuppressWarnings("unchecked") + final LinkedQueueAtomicNode lvConsumerNode() { + return consumerNode; + } + + final LinkedQueueAtomicNode lpConsumerNode() { + return consumerNode; + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. + */ +abstract class BaseLinkedAtomicQueuePad2 extends BaseLinkedAtomicQueueConsumerNodeRef { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, 
b173, b174, b175, b176, b177; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. + * + * A base data structure for concurrent linked queues. For convenience also pulled in common single consumer + * methods since at this time there's no plan to implement MC. + */ +abstract class BaseLinkedAtomicQueue extends BaseLinkedAtomicQueuePad2 { + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + protected final LinkedQueueAtomicNode newNode() { + return new LinkedQueueAtomicNode(); + } + + protected final LinkedQueueAtomicNode newNode(E e) { + return new LinkedQueueAtomicNode(e); + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * This is an O(n) operation as we run through all the nodes and count them.
+     * The accuracy of the value returned by this method is subject to races with producer/consumer threads. In
+     * particular when racing with the consumer thread this method may underestimate the size.
+     *
+     * @see java.util.Queue#size()
+     */
+    @Override
+    public final int size() {
+        // Read consumer first, this is important because if the producer node is 'older' than the consumer
+        // the consumer may overtake it (consume past it) invalidating the 'snapshot' notion of size.
+        LinkedQueueAtomicNode<E> chaserNode = lvConsumerNode();
+        LinkedQueueAtomicNode<E> producerNode = lvProducerNode();
+        int size = 0;
+        // must chase the nodes all the way to the producer node, but there's no need to count beyond expected head.
+        while (// don't go past the producer node
+        chaserNode != producerNode && // stop at last node
+        chaserNode != null && // stop at max int
+        size < Integer.MAX_VALUE) {
+            LinkedQueueAtomicNode<E> next;
+            next = chaserNode.lvNext();
+            // check if this node has been consumed, if so return what we have
+            if (next == chaserNode) {
+                return size;
+            }
+            chaserNode = next;
+            size++;
+        }
+        return size;
+    }
+
+    /**
+     * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+     * Queue is empty when producerNode is the same as consumerNode. An alternative implementation would be to
+     * observe that producerNode.value is null, which also means an empty queue, because only the
+     * consumerNode.value is allowed to be null.
+     *
+     * @see MessagePassingQueue#isEmpty()
+     */
+    @Override
+    public boolean isEmpty() {
+        LinkedQueueAtomicNode<E> consumerNode = lvConsumerNode();
+        LinkedQueueAtomicNode<E> producerNode = lvProducerNode();
+        return consumerNode == producerNode;
+    }
+
+    protected E getSingleConsumerNodeValue(LinkedQueueAtomicNode<E> currConsumerNode, LinkedQueueAtomicNode<E> nextNode) {
+        // we have to null out the value because we are going to hang on to the node
+        final E nextValue = nextNode.getAndNullValue();
+        // Fix up the next ref of currConsumerNode to prevent promoted nodes from keeping new ones alive.
+        // We use a reference to self instead of null because null is already a meaningful value (the next of
+        // producer node is null).
+        currConsumerNode.soNext(currConsumerNode);
+        spConsumerNode(nextNode);
+        // currConsumerNode is now no longer referenced and can be collected
+        return nextValue;
+    }
+
+    /**
+     * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Poll is allowed from a SINGLE thread.
+ * Poll is potentially blocking here as the {@link Queue#poll()} does not allow returning {@code null} if the queue is not + * empty. This is very different from the original Vyukov guarantees. See {@link #relaxedPoll()} for the original + * semantics.
+ * Poll reads {@code consumerNode.next} and: + *

    + *
  1. If it is {@code null} AND the queue is empty return {@code null}, if queue is not empty spin wait for + * value to become visible. + *
  2. If it is not {@code null} set it as the consumer node and return it's now evacuated value. + *
+ * This means the consumerNode.value is always {@code null}, which is also the starting point for the queue. + * Because {@code null} values are not allowed to be offered this is the only node with it's value set to + * {@code null} at any one time. + * + * @see MessagePassingQueue#poll() + * @see java.util.Queue#poll() + */ + @Override + public E poll() { + final LinkedQueueAtomicNode currConsumerNode = lpConsumerNode(); + LinkedQueueAtomicNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } else if (currConsumerNode != lvProducerNode()) { + nextNode = spinWaitForNextNode(currConsumerNode); + // got the next node... + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Peek is allowed from a SINGLE thread.
+ * Peek is potentially blocking here as the {@link Queue#peek()} does not allow returning {@code null} if the queue is not + * empty. This is very different from the original Vyukov guarantees. See {@link #relaxedPeek()} for the original + * semantics.
+ * Poll reads the next node from the consumerNode and: + *

    + *
  1. If it is {@code null} AND the queue is empty return {@code null}, if queue is not empty spin wait for + * value to become visible. + *
  2. If it is not {@code null} return it's value. + *
+ * + * @see MessagePassingQueue#peek() + * @see java.util.Queue#peek() + */ + @Override + public E peek() { + final LinkedQueueAtomicNode currConsumerNode = lpConsumerNode(); + LinkedQueueAtomicNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } else if (currConsumerNode != lvProducerNode()) { + nextNode = spinWaitForNextNode(currConsumerNode); + // got the next node... + return nextNode.lpValue(); + } + return null; + } + + LinkedQueueAtomicNode spinWaitForNextNode(LinkedQueueAtomicNode currNode) { + LinkedQueueAtomicNode nextNode; + while ((nextNode = currNode.lvNext()) == null) { + // spin, we are no longer wait free + } + return nextNode; + } + + @Override + public E relaxedPoll() { + final LinkedQueueAtomicNode currConsumerNode = lpConsumerNode(); + final LinkedQueueAtomicNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + @Override + public E relaxedPeek() { + final LinkedQueueAtomicNode nextNode = lpConsumerNode().lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } + return null; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public int drain(Consumer c, int limit) { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + LinkedQueueAtomicNode chaserNode = this.lpConsumerNode(); + for (int i = 0; i < limit; i++) { + final LinkedQueueAtomicNode nextNode = chaserNode.lvNext(); + if (nextNode == null) { + return i; + } + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + return limit; + } + + @Override + public int drain(Consumer c) { + return 
MessagePassingQueueUtil.drain(this, c); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + @Override + public int capacity() { + return UNBOUNDED_CAPACITY; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseMpscLinkedAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseMpscLinkedAtomicArrayQueue.java new file mode 100644 index 0000000..6000645 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseMpscLinkedAtomicArrayQueue.java @@ -0,0 +1,794 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.PortableJvmInfo; +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. 
+ */ +abstract class BaseMpscLinkedAtomicArrayQueuePad1 extends AbstractQueue implements IndexedQueue { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, b173, b174, b175, b176, b177; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. 
+ */ +abstract class BaseMpscLinkedAtomicArrayQueueProducerFields extends BaseMpscLinkedAtomicArrayQueuePad1 { + + private static final AtomicLongFieldUpdater P_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(BaseMpscLinkedAtomicArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + P_INDEX_UPDATER.lazySet(this, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) { + return P_INDEX_UPDATER.compareAndSet(this, expect, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + */ +abstract class BaseMpscLinkedAtomicArrayQueuePad2 extends BaseMpscLinkedAtomicArrayQueueProducerFields { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, 
b172, b173, b174, b175, b176, b177; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + */ +abstract class BaseMpscLinkedAtomicArrayQueueConsumerFields extends BaseMpscLinkedAtomicArrayQueuePad2 { + + private static final AtomicLongFieldUpdater C_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(BaseMpscLinkedAtomicArrayQueueConsumerFields.class, "consumerIndex"); + + private volatile long consumerIndex; + + protected long consumerMask; + + protected AtomicReferenceArray consumerBuffer; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return consumerIndex; + } + + final void soConsumerIndex(long newValue) { + C_INDEX_UPDATER.lazySet(this, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. 
+ */ +abstract class BaseMpscLinkedAtomicArrayQueuePad3 extends BaseMpscLinkedAtomicArrayQueueConsumerFields { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, b173, b174, b175, b176, b177; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. 
+ */ +abstract class BaseMpscLinkedAtomicArrayQueueColdProducerFields extends BaseMpscLinkedAtomicArrayQueuePad3 { + + private static final AtomicLongFieldUpdater P_LIMIT_UPDATER = AtomicLongFieldUpdater.newUpdater(BaseMpscLinkedAtomicArrayQueueColdProducerFields.class, "producerLimit"); + + private volatile long producerLimit; + + protected long producerMask; + + protected AtomicReferenceArray producerBuffer; + + final long lvProducerLimit() { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) { + return P_LIMIT_UPDATER.compareAndSet(this, expect, newValue); + } + + final void soProducerLimit(long newValue) { + P_LIMIT_UPDATER.lazySet(this, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks + * of the initial size. The queue grows only when the current buffer is full and elements are not copied on + * resize, instead a link to the new buffer is stored in the old buffer for the consumer to follow. + */ +abstract class BaseMpscLinkedAtomicArrayQueue extends BaseMpscLinkedAtomicArrayQueueColdProducerFields implements MessagePassingQueue, QueueProgressIndicators { + + // No post padding here, subclasses must add + private static final Object JUMP = new Object(); + + private static final Object BUFFER_CONSUMED = new Object(); + + private static final int CONTINUE_TO_P_INDEX_CAS = 0; + + private static final int RETRY = 1; + + private static final int QUEUE_FULL = 2; + + private static final int QUEUE_RESIZE = 3; + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. + * Must be 2 or more. 
+ */ + public BaseMpscLinkedAtomicArrayQueue(final int initialCapacity) { + RangeUtil.checkGreaterThanOrEqual(initialCapacity, 2, "initialCapacity"); + int p2capacity = Pow2.roundToPowerOfTwo(initialCapacity); + // leave lower bit of mask clear + long mask = (p2capacity - 1) << 1; + // need extra element to point at next array + AtomicReferenceArray buffer = allocateRefArray(p2capacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + // we know it's all empty to start with + soProducerLimit(mask); + } + + @Override + public int size() { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.IGNORE_PARITY_DIVISOR); + } + + @Override + public boolean isEmpty() { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there is + // nothing we can do to make this an exact method. + return ((lvConsumerIndex() - lvProducerIndex()) / 2 == 0); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + long mask; + AtomicReferenceArray buffer; + long pIndex; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + // mask/buffer may get changed by resizing -> only use for array access after successful CAS. 
+ mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) - [mask/buffer] -> cas(pIndex) + // assumption behind this optimization is that queue is almost always empty or near empty + if (producerLimit <= pIndex) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch(result) { + case CONTINUE_TO_P_INDEX_CAS: + break; + case RETRY: + continue; + case QUEUE_FULL: + return false; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, e, null); + return true; + } + } + if (casProducerIndex(pIndex, pIndex + 2)) { + break; + } + } + // INDEX visible before ELEMENT + final int offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + // release element e + soRefElement(buffer, offset, e); + return true; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + final AtomicReferenceArray buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final int offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + long pIndex = lvProducerIndex(); + // isEmpty? + if ((cIndex - pIndex) / 2 == 0) { + return null; + } + // poll() == null iff queue is empty, null element is not strong enough indicator, so we must + // spin until element is visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + final AtomicReferenceArray nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, cIndex); + } + // release element null + soRefElement(buffer, offset, null); + // release cIndex + soConsumerIndex(cIndex + 2); + return (E) e; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final AtomicReferenceArray buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final int offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + long pIndex = lvProducerIndex(); + // isEmpty? + if ((cIndex - pIndex) / 2 == 0) { + return null; + } + // peek() == null iff queue is empty, null element is not strong enough indicator, so we must + // spin until element is visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), cIndex); + } + return (E) e; + } + + /** + * We do not inline resize into this method because we do not resize on fill. + */ + private int offerSlowPath(long mask, long pIndex, long producerLimit) { + final long cIndex = lvConsumerIndex(); + long bufferCapacity = getCurrentBufferCapacity(mask); + if (cIndex + bufferCapacity > pIndex) { + if (!casProducerLimit(producerLimit, cIndex + bufferCapacity)) { + // retry from top + return RETRY; + } else { + // continue to pIndex CAS + return CONTINUE_TO_P_INDEX_CAS; + } + } else // full and cannot grow + if (availableInQueue(pIndex, cIndex) <= 0) { + // offer should return false; + return QUEUE_FULL; + } else // grab index for resize -> set lower bit + if (casProducerIndex(pIndex, pIndex + 1)) { + // trigger a resize + return QUEUE_RESIZE; + } else { + // failed resize attempt, retry from top + return RETRY; + } + } + + /** + * @return available elements in queue * 2 + */ + protected abstract long availableInQueue(long pIndex, long cIndex); + + @SuppressWarnings("unchecked") + private AtomicReferenceArray nextBuffer(final AtomicReferenceArray buffer, final long mask) { + final int offset = nextArrayOffset(mask); + final AtomicReferenceArray nextBuffer = 
(AtomicReferenceArray) lvRefElement(buffer, offset); + consumerBuffer = nextBuffer; + consumerMask = (length(nextBuffer) - 2) << 1; + soRefElement(buffer, offset, BUFFER_CONSUMED); + return nextBuffer; + } + + private static int nextArrayOffset(long mask) { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); + } + + private E newBufferPoll(AtomicReferenceArray nextBuffer, long cIndex) { + final int offset = modifiedCalcCircularRefElementOffset(cIndex, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (n == null) { + throw new IllegalStateException("new buffer must have at least one element"); + } + soRefElement(nextBuffer, offset, null); + soConsumerIndex(cIndex + 2); + return n; + } + + private E newBufferPeek(AtomicReferenceArray nextBuffer, long cIndex) { + final int offset = modifiedCalcCircularRefElementOffset(cIndex, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) { + throw new IllegalStateException("new buffer must have at least one element"); + } + return n; + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex() / 2; + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex() / 2; + } + + @Override + public abstract int capacity(); + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPoll() { + final AtomicReferenceArray buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final int offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + return null; + } + if (e == JUMP) { + final AtomicReferenceArray nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, cIndex); + } + soRefElement(buffer, offset, null); + soConsumerIndex(cIndex + 2); + return (E) e; + } + + @SuppressWarnings("unchecked") + @Override + public E 
relaxedPeek() { + final AtomicReferenceArray buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final int offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), cIndex); + } + return (E) e; + } + + @Override + public int fill(Supplier s) { + // result is a long because we want to have a safepoint check at regular intervals + long result = 0; + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + long mask; + AtomicReferenceArray buffer; + long pIndex; + int claimedSlots; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + // NOTE: mask/buffer may get changed by resizing -> only use for array access after successful CAS. + // Only by virtue offloading them between the lvProducerIndex and a successful casProducerIndex are they + // safe to use. 
+ mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) + // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' + // -> producerLimit >= batchIndex + long batchIndex = Math.min(producerLimit, pIndex + 2l * limit); + if (pIndex >= producerLimit) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch(result) { + case CONTINUE_TO_P_INDEX_CAS: + // offer slow path verifies only one slot ahead, we cannot rely on indication here + case RETRY: + continue; + case QUEUE_FULL: + return 0; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, null, s); + return 1; + } + } + // claim limit slots at once + if (casProducerIndex(pIndex, batchIndex)) { + claimedSlots = (int) ((batchIndex - pIndex) / 2); + break; + } + } + for (int i = 0; i < claimedSlots; i++) { + final int offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); + } + return claimedSlots; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + @Override + public int drain(Consumer c) { + return drain(c, capacity()); + } + + @Override + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + /** + * Get an iterator for this queue. This method is thread safe. + *

+ * The iterator provides a best-effort snapshot of the elements in the queue. + * The returned iterator is not guaranteed to return elements in queue order, + * and races with the consumer thread may cause gaps in the sequence of returned elements. + * Like {link #relaxedPoll}, the iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); + } + + /** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + */ + private static class WeakIterator implements Iterator { + + private final long pIndex; + + private long nextIndex; + + private E nextElement; + + private AtomicReferenceArray currentBuffer; + + private int mask; + + WeakIterator(AtomicReferenceArray currentBuffer, long cIndex, long pIndex) { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) { + throw new NoSuchElementException(); + } + nextElement = getNext(); + return e; + } + + private void setBuffer(AtomicReferenceArray buffer) { + this.currentBuffer = buffer; + this.mask = length(buffer) - 2; + } + + private E getNext() { + while (nextIndex < pIndex) { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } + // not null && not JUMP -> found next element + if (e != JUMP) { + return e; + } + // need to jump to the next buffer + int 
nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, calcRefElementOffset(nextBufferIndex)); + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early + return null; + } + setBuffer((AtomicReferenceArray) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } else { + return e; + } + } + return null; + } + } + + private void resize(long oldMask, AtomicReferenceArray oldBuffer, long pIndex, E e, Supplier s) { + assert (e != null && s == null) || (e == null || s != null); + int newBufferLength = getNextBufferSize(oldBuffer); + final AtomicReferenceArray newBuffer; + try { + newBuffer = allocateRefArray(newBufferLength); + } catch (OutOfMemoryError oom) { + assert lvProducerIndex() == pIndex + 1; + soProducerIndex(pIndex); + throw oom; + } + producerBuffer = newBuffer; + final int newMask = (newBufferLength - 2) << 1; + producerMask = newMask; + final int offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final int offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); + // element in new array + soRefElement(newBuffer, offsetInNew, e == null ? 
s.get() : e); + // buffer linked + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); + // ASSERT code + final long cIndex = lvConsumerIndex(); + final long availableInQueue = availableInQueue(pIndex, cIndex); + RangeUtil.checkPositive(availableInQueue, "availableInQueue"); + // Invalidate racing CASs + // We never set the limit beyond the bounds of a buffer + soProducerLimit(pIndex + Math.min(newMask, availableInQueue)); + // make resize visible to the other producers + soProducerIndex(pIndex + 2); + // INDEX visible before ELEMENT, consistent with consumer expectation + // make resize visible to consumer + soRefElement(oldBuffer, offsetInOld, JUMP); + } + + /** + * @return next buffer size(inclusive of next array pointer) + */ + protected abstract int getNextBufferSize(AtomicReferenceArray buffer); + + /** + * @return current buffer capacity for elements (excluding next pointer and jump entry) * 2 + */ + protected abstract long getCurrentBufferCapacity(long mask); +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseSpscLinkedAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseSpscLinkedAtomicArrayQueue.java new file mode 100644 index 0000000..ac921c9 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/BaseSpscLinkedAtomicArrayQueue.java @@ -0,0 +1,444 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues.atomic; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.PortableJvmInfo; +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class // byte b160,b161,b162,b163,b164,b165,b166,b167;//120b +// byte b170,b171,b172,b173,b174,b175,b176,b177;//128b +// * drop 16b , the cold fields act as buffer * +BaseSpscLinkedAtomicArrayQueuePrePad extends AbstractQueue implements IndexedQueue { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. 
+ */ +abstract class BaseSpscLinkedAtomicArrayQueueConsumerColdFields extends BaseSpscLinkedAtomicArrayQueuePrePad { + + protected long consumerMask; + + protected AtomicReferenceArray consumerBuffer; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedAtomicArrayQueueConsumerField extends BaseSpscLinkedAtomicArrayQueueConsumerColdFields { + + private static final AtomicLongFieldUpdater C_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(BaseSpscLinkedAtomicArrayQueueConsumerField.class, "consumerIndex"); + + private volatile long consumerIndex; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return consumerIndex; + } + + final void soConsumerIndex(long newValue) { + C_INDEX_UPDATER.lazySet(this, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. 
+ */ +abstract class BaseSpscLinkedAtomicArrayQueueL2Pad extends BaseSpscLinkedAtomicArrayQueueConsumerField { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, b173, b174, b175, b176, b177; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. 
+ */ +abstract class BaseSpscLinkedAtomicArrayQueueProducerFields extends BaseSpscLinkedAtomicArrayQueueL2Pad { + + private static final AtomicLongFieldUpdater P_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(BaseSpscLinkedAtomicArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + P_INDEX_UPDATER.lazySet(this, newValue); + } + + final long lpProducerIndex() { + return producerIndex; + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedAtomicArrayQueueProducerColdFields extends BaseSpscLinkedAtomicArrayQueueProducerFields { + + protected long producerBufferLimit; + + // fixed for chunked and unbounded + protected long producerMask; + + protected AtomicReferenceArray producerBuffer; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. 
+ */ +abstract class BaseSpscLinkedAtomicArrayQueue extends BaseSpscLinkedAtomicArrayQueueProducerColdFields implements MessagePassingQueue, QueueProgressIndicators { + + private static final Object JUMP = new Object(); + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public final int size() { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR); + } + + @Override + public final boolean isEmpty() { + return IndexedQueueSizeUtil.isEmpty(this); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex(); + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex(); + } + + protected final void soNext(AtomicReferenceArray curr, AtomicReferenceArray next) { + int offset = nextArrayOffset(curr); + soRefElement(curr, offset, next); + } + + @SuppressWarnings("unchecked") + protected final AtomicReferenceArray lvNextArrayAndUnlink(AtomicReferenceArray curr) { + final int offset = nextArrayOffset(curr); + final AtomicReferenceArray nextBuffer = (AtomicReferenceArray) lvRefElement(curr, offset); + // prevent GC nepotism + soRefElement(curr, offset, null); + return nextBuffer; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public E relaxedPoll() { + return poll(); + } + + @Override + public E relaxedPeek() { + return peek(); + } + + @Override + public int drain(Consumer c) { + return MessagePassingQueueUtil.drain(this, c); + } + + @Override + public int fill(Supplier s) { + // result is a long because we want to have a safepoint check at regular intervals + long result = 0; + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + 
@Override + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + for (int i = 0; i < limit; i++) { + // local load of field to avoid repeated loads after volatile reads + final AtomicReferenceArray buffer = producerBuffer; + final long index = lpProducerIndex(); + final long mask = producerMask; + final int offset = calcCircularRefElementOffset(index, mask); + // expected hot path + if (index < producerBufferLimit) { + writeToQueue(buffer, s.get(), index, offset); + } else { + if (!offerColdPath(buffer, mask, index, offset, null, s)) { + return i; + } + } + } + return limit; + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single producer thread use only. + */ + @Override + public boolean offer(final E e) { + // Objects.requireNonNull(e); + if (null == e) { + throw new NullPointerException(); + } + // local load of field to avoid repeated loads after volatile reads + final AtomicReferenceArray buffer = producerBuffer; + final long index = lpProducerIndex(); + final long mask = producerMask; + final int offset = calcCircularRefElementOffset(index, mask); + // expected hot path + if (index < producerBufferLimit) { + writeToQueue(buffer, e, index, offset); + return true; + } + return offerColdPath(buffer, mask, index, offset, e, null); + } + + abstract boolean offerColdPath(AtomicReferenceArray buffer, long mask, long pIndex, int offset, E v, Supplier s); + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + // local load of field to avoid repeated loads after volatile reads + final AtomicReferenceArray buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + final int offset = calcCircularRefElementOffset(index, mask); + final Object e = lvRefElement(buffer, offset); + boolean isNextBuffer = e == JUMP; + if (null != e && !isNextBuffer) { + // this ensures correctness on 32bit platforms + soConsumerIndex(index + 1); + soRefElement(buffer, offset, null); + return (E) e; + } else if (isNextBuffer) { + return newBufferPoll(buffer, index); + } + return null; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final AtomicReferenceArray buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + final int offset = calcCircularRefElementOffset(index, mask); + final Object e = lvRefElement(buffer, offset); + if (e == JUMP) { + return newBufferPeek(buffer, index); + } + return (E) e; + } + + final void linkOldToNew(final long currIndex, final AtomicReferenceArray oldBuffer, final int offset, final AtomicReferenceArray newBuffer, final int offsetInNew, final E e) { + soRefElement(newBuffer, offsetInNew, e); + // link to next buffer and add next indicator as element of old buffer + soNext(oldBuffer, newBuffer); + soRefElement(oldBuffer, offset, JUMP); + // index is visible after elements (isEmpty/poll ordering) + // this ensures atomic write of long on 32bit platforms + soProducerIndex(currIndex + 1); + } + + final void writeToQueue(final AtomicReferenceArray buffer, final E e, final long index, final int offset) { + soRefElement(buffer, offset, e); + // this ensures atomic write of long on 32bit platforms + soProducerIndex(index + 1); + } + + private E newBufferPeek(final AtomicReferenceArray buffer, final long index) { + AtomicReferenceArray nextBuffer = lvNextArrayAndUnlink(buffer); + consumerBuffer = nextBuffer; + final long mask = length(nextBuffer) - 2; + consumerMask = mask; + final int offset = calcCircularRefElementOffset(index, mask); + return lvRefElement(nextBuffer, offset); + } + + private E newBufferPoll(final AtomicReferenceArray buffer, final long index) { + AtomicReferenceArray nextBuffer = lvNextArrayAndUnlink(buffer); + consumerBuffer = nextBuffer; + final long mask = length(nextBuffer) - 2; + consumerMask = mask; + final int offset = calcCircularRefElementOffset(index, mask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) { + throw new 
IllegalStateException("new buffer must have at least one element"); + } else { + // this ensures correctness on 32bit platforms + soConsumerIndex(index + 1); + soRefElement(nextBuffer, offset, null); + return n; + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/LinkedQueueAtomicNode.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/LinkedQueueAtomicNode.java new file mode 100644 index 0000000..3338683 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/LinkedQueueAtomicNode.java @@ -0,0 +1,69 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.AtomicReference; + +public final class LinkedQueueAtomicNode extends AtomicReference> +{ + /** */ + private static final long serialVersionUID = 2404266111789071508L; + private E value; + + LinkedQueueAtomicNode() + { + } + + LinkedQueueAtomicNode(E val) + { + spValue(val); + } + + /** + * Gets the current value and nulls out the reference to it from this node. 
+ * + * @return value + */ + public E getAndNullValue() + { + E temp = lpValue(); + spValue(null); + return temp; + } + + public E lpValue() + { + return value; + } + + public void spValue(E newValue) + { + value = newValue; + } + + public void soNext(LinkedQueueAtomicNode n) + { + lazySet(n); + } + + public void spNext(LinkedQueueAtomicNode n) + { + lazySet(n); + } + + public LinkedQueueAtomicNode lvNext() + { + return get(); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/MpmcAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpmcAtomicArrayQueue.java new file mode 100644 index 0000000..d9d57a4 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpmcAtomicArrayQueue.java @@ -0,0 +1,652 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import org.jctools.util.RangeUtil; +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. 
+ */ +abstract class MpmcAtomicArrayQueueL1Pad extends SequencedAtomicReferenceArrayQueue { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + MpmcAtomicArrayQueueL1Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. 
+ */ +abstract class MpmcAtomicArrayQueueProducerIndexField extends MpmcAtomicArrayQueueL1Pad { + + private static final AtomicLongFieldUpdater P_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(MpmcAtomicArrayQueueProducerIndexField.class, "producerIndex"); + + private volatile long producerIndex; + + MpmcAtomicArrayQueueProducerIndexField(int capacity) { + super(capacity); + } + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final boolean casProducerIndex(long expect, long newValue) { + return P_INDEX_UPDATER.compareAndSet(this, expect, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + */ +abstract class MpmcAtomicArrayQueueL2Pad extends MpmcAtomicArrayQueueProducerIndexField { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, b173, b174, b175, b176, b177; + + 
MpmcAtomicArrayQueueL2Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + */ +abstract class MpmcAtomicArrayQueueConsumerIndexField extends MpmcAtomicArrayQueueL2Pad { + + private static final AtomicLongFieldUpdater C_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(MpmcAtomicArrayQueueConsumerIndexField.class, "consumerIndex"); + + private volatile long consumerIndex; + + MpmcAtomicArrayQueueConsumerIndexField(int capacity) { + super(capacity); + } + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final boolean casConsumerIndex(long expect, long newValue) { + return C_INDEX_UPDATER.compareAndSet(this, expect, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. 
+ */ +abstract class MpmcAtomicArrayQueueL3Pad extends MpmcAtomicArrayQueueConsumerIndexField { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // 128b + byte b170, b171, b172, b173, b174, b175, b176, b177; + + MpmcAtomicArrayQueueL3Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + * + * A Multi-Producer-Multi-Consumer queue based on a {@link org.jctools.queues.ConcurrentCircularArrayQueue}. This + * implies that any and all threads may call the offer/poll/peek methods and correctness is maintained.
+ * This implementation follows patterns documented on the package level for False Sharing protection.
+ * The algorithm for offer/poll is an adaptation of the one put forward by D. Vyukov (See here). The original + * algorithm uses an array of structs which should offer nice locality properties but is sadly not possible in + * Java (waiting on Value Types or similar). The alternative explored here utilizes 2 arrays, one for each + * field of the struct. There is a further alternative in the experimental project which uses iteration phase + * markers to achieve the same algo and is closer structurally to the original, but sadly does not perform as + * well as this implementation.
+ *

+ * Tradeoffs to keep in mind: + *

    + *
  1. Padding for false sharing: counter fields and queue fields are all padded as well as either side of + * both arrays. We are trading memory to avoid false sharing(active and passive). + *
  2. 2 arrays instead of one: The algorithm requires an extra array of longs matching the size of the + * elements array. This is doubling/tripling the memory allocated for the buffer. + *
  3. Power of 2 capacity: Actual elements buffer (and sequence buffer) is the closest power of 2 larger or + * equal to the requested capacity. + *
 */
public class MpmcAtomicArrayQueue<E> extends MpmcAtomicArrayQueueL3Pad<E> {

    // Upper bound on the batch look-ahead used by drain/fill; tunable via system property.
    public static final int MAX_LOOK_AHEAD_STEP = Integer.getInteger("jctools.mpmc.max.lookahead.step", 4096);

    // Batch step actually used by drain/fill; fixed at construction from capacity.
    private final int lookAheadStep;

    /**
     * @param capacity requested capacity, must be at least 2 (rounded up to a power of 2 by the parent class)
     */
    public MpmcAtomicArrayQueue(final int capacity) {
        super(RangeUtil.checkGreaterThanOrEqual(capacity, 2, "capacity"));
        // a quarter of the actual capacity, clamped to [2, MAX_LOOK_AHEAD_STEP]
        lookAheadStep = Math.max(2, Math.min(capacity() / 4, MAX_LOOK_AHEAD_STEP));
    }

    /**
     * {@inheritDoc}
     * <p>
     * Lock free offer: claims a slot by CAS on the producer index, then publishes the element and
     * advances the slot's sequence so consumers can see it.
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final int mask = this.mask;
        final long capacity = mask + 1;
        final AtomicLongArray sBuffer = sequenceBuffer;
        long pIndex;
        int seqOffset;
        long seq;
        // start with bogus value, hope we don't need it
        long cIndex = Long.MIN_VALUE;
        do {
            pIndex = lvProducerIndex();
            seqOffset = calcCircularLongElementOffset(pIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            // consumer has not moved this seq forward, it's as last producer left
            if (seq < pIndex) {
                // Extra check required to ensure [Queue.offer == false iff queue is full]
                if (// test against cached cIndex
                pIndex - capacity >= cIndex && // test against latest cIndex
                pIndex - capacity >= (cIndex = lvConsumerIndex())) {
                    return false;
                } else {
                    // (+) hack to make it go around again without CAS
                    seq = pIndex + 1;
                }
            }
        } while (// another producer has moved the sequence(or +)
        seq > pIndex || // failed to increment
        !casProducerIndex(pIndex, pIndex + 1));
        // casProducerIndex ensures correct construction
        spRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), e);
        // seq++;
        soLongElement(sBuffer, seqOffset, pIndex + 1);
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Because return null indicates queue is empty we cannot simply rely on next element visibility for poll
     * and must test producer index when next element is not visible.
     */
    @Override
    public E poll() {
        // local load of field to avoid repeated loads after volatile reads
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        long cIndex;
        long seq;
        int seqOffset;
        long expectedSeq;
        // start with bogus value, hope we don't need it
        long pIndex = -1;
        do {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                // slot has not been moved by producer
                if (// test against cached pIndex
                cIndex >= pIndex && // update pIndex if we must
                cIndex == (pIndex = lvProducerIndex())) {
                    // strict empty check, this ensures [Queue.poll() == null iff isEmpty()]
                    return null;
                } else {
                    // trip another go around
                    seq = expectedSeq + 1;
                }
            }
        } while (// another consumer beat us to it
        seq > expectedSeq || // failed the CAS
        !casConsumerIndex(cIndex, cIndex + 1));
        final int offset = calcCircularRefElementOffset(cIndex, mask);
        final E e = lpRefElement(buffer, offset);
        spRefElement(buffer, offset, null);
        // i.e. seq += capacity
        soLongElement(sBuffer, seqOffset, cIndex + mask + 1);
        return e;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Returns the element at the consumer index without claiming it; re-reads the consumer index to
     * make sure the observed element still belongs to the slot we looked at.
     */
    @Override
    public E peek() {
        // local load of field to avoid repeated loads after volatile reads
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        long cIndex;
        long seq;
        int seqOffset;
        long expectedSeq;
        // start with bogus value, hope we don't need it
        long pIndex = -1;
        E e;
        while (true) {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                // slot has not been moved by producer
                if (// test against cached pIndex
                cIndex >= pIndex && // update pIndex if we must
                cIndex == (pIndex = lvProducerIndex())) {
                    // strict empty check, this ensures [Queue.poll() == null iff isEmpty()]
                    return null;
                }
            } else if (seq == expectedSeq) {
                final int offset = calcCircularRefElementOffset(cIndex, mask);
                e = lvRefElement(buffer, offset);
                // only return if the consumer index did not move under us
                if (lvConsumerIndex() == cIndex)
                    return e;
            }
        }
    }

    /**
     * Like {@link #offer} but gives up (returns false) as soon as the target slot is not available,
     * without the strict full-queue recheck against the consumer index.
     */
    @Override
    public boolean relaxedOffer(E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final int mask = this.mask;
        final AtomicLongArray sBuffer = sequenceBuffer;
        long pIndex;
        int seqOffset;
        long seq;
        do {
            pIndex = lvProducerIndex();
            seqOffset = calcCircularLongElementOffset(pIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            if (seq < pIndex) {
                // slot not cleared by consumer yet
                return false;
            }
        } while (// another producer has moved the sequence
        seq > pIndex || // failed to increment
        !casProducerIndex(pIndex, pIndex + 1));
        // casProducerIndex ensures correct construction
        spRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), e);
        soLongElement(sBuffer, seqOffset, pIndex + 1);
        return true;
    }

    /**
     * Like {@link #poll} but returns null as soon as the slot is not visibly filled,
     * without the strict empty check against the producer index.
     */
    @Override
    public E relaxedPoll() {
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        long cIndex;
        int seqOffset;
        long seq;
        long expectedSeq;
        do {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                return null;
            }
        } while (// another consumer beat us to it
        seq > expectedSeq || // failed the CAS
        !casConsumerIndex(cIndex, cIndex + 1));
        final int offset = calcCircularRefElementOffset(cIndex, mask);
        final E e = lpRefElement(buffer, offset);
        spRefElement(buffer, offset, null);
        // hand the slot back to producers: seq += capacity
        soLongElement(sBuffer, seqOffset, cIndex + mask + 1);
        return e;
    }

    /**
     * Like {@link #peek} but returns null as soon as the slot is not visibly filled.
     */
    @Override
    public E relaxedPeek() {
        // local load of field to avoid repeated loads after volatile reads
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        long cIndex;
        long seq;
        int seqOffset;
        long expectedSeq;
        E e;
        do {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                return null;
            } else if (seq == expectedSeq) {
                final int offset = calcCircularRefElementOffset(cIndex, mask);
                e = lvRefElement(buffer, offset);
                // only return if the consumer index did not move under us
                if (lvConsumerIndex() == cIndex)
                    return e;
            }
        } while (true);
    }

    /**
     * Batch-consumes up to {@code limit} elements. Tries to claim {@code lookAheadStep} slots in one
     * CAS by probing the sequence of the last slot in the batch; falls back to one-by-one draining
     * when the batch claim fails.
     *
     * @param c     callback for each consumed element, not null
     * @param limit maximum number of elements to consume, non-negative
     * @return number of elements actually consumed
     */
    @Override
    public int drain(Consumer<E> c, int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int maxLookAheadStep = Math.min(this.lookAheadStep, limit);
        int consumed = 0;
        while (consumed < limit) {
            final int remaining = limit - consumed;
            final int lookAheadStep = Math.min(remaining, maxLookAheadStep);
            final long cIndex = lvConsumerIndex();
            final long lookAheadIndex = cIndex + lookAheadStep - 1;
            final int lookAheadSeqOffset = calcCircularLongElementOffset(lookAheadIndex, mask);
            final long lookAheadSeq = lvLongElement(sBuffer, lookAheadSeqOffset);
            final long expectedLookAheadSeq = lookAheadIndex + 1;
            if (lookAheadSeq == expectedLookAheadSeq && casConsumerIndex(cIndex, expectedLookAheadSeq)) {
                // claimed the whole batch in one CAS; now consume each slot in order
                for (int i = 0; i < lookAheadStep; i++) {
                    final long index = cIndex + i;
                    final int seqOffset = calcCircularLongElementOffset(index, mask);
                    final int offset = calcCircularRefElementOffset(index, mask);
                    final long expectedSeq = index + 1;
                    // spin until the producer has published this slot
                    while (lvLongElement(sBuffer, seqOffset) != expectedSeq) {
                    }
                    final E e = lpRefElement(buffer, offset);
                    spRefElement(buffer, offset, null);
                    soLongElement(sBuffer, seqOffset, index + mask + 1);
                    c.accept(e);
                }
                consumed += lookAheadStep;
            } else {
                if (lookAheadSeq < expectedLookAheadSeq) {
                    // if even the head slot is not available we are done
                    if (notAvailable(cIndex, mask, sBuffer, cIndex + 1)) {
                        return consumed;
                    }
                }
                return consumed + drainOneByOne(c, remaining);
            }
        }
        return limit;
    }

    // Slow path for drain: claim and consume one slot at a time, same protocol as poll().
    private int drainOneByOne(Consumer<E> c, int limit) {
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        final AtomicReferenceArray<E> buffer = this.buffer;
        long cIndex;
        int seqOffset;
        long seq;
        long expectedSeq;
        for (int i = 0; i < limit; i++) {
            do {
                cIndex = lvConsumerIndex();
                seqOffset = calcCircularLongElementOffset(cIndex, mask);
                seq = lvLongElement(sBuffer, seqOffset);
                expectedSeq = cIndex + 1;
                if (seq < expectedSeq) {
                    return i;
                }
            } while (// another consumer beat us to it
            seq > expectedSeq || // failed the CAS
            !casConsumerIndex(cIndex, cIndex + 1));
            final int offset = calcCircularRefElementOffset(cIndex, mask);
            final E e = lpRefElement(buffer, offset);
            spRefElement(buffer, offset, null);
            soLongElement(sBuffer, seqOffset, cIndex + mask + 1);
            c.accept(e);
        }
        return limit;
    }

    /**
     * Batch-produces up to {@code limit} elements, mirror image of {@link #drain(Consumer, int)}:
     * claims a batch of slots with one CAS when possible, otherwise fills one-by-one.
     *
     * @param s     supplier of elements, not null
     * @param limit maximum number of elements to produce, non-negative
     * @return number of elements actually produced
     */
    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int maxLookAheadStep = Math.min(this.lookAheadStep, limit);
        int produced = 0;
        while (produced < limit) {
            final int remaining = limit - produced;
            final int lookAheadStep = Math.min(remaining, maxLookAheadStep);
            final long pIndex = lvProducerIndex();
            final long lookAheadIndex = pIndex + lookAheadStep - 1;
            final int lookAheadSeqOffset = calcCircularLongElementOffset(lookAheadIndex, mask);
            final long lookAheadSeq = lvLongElement(sBuffer, lookAheadSeqOffset);
            final long expectedLookAheadSeq = lookAheadIndex;
            if (lookAheadSeq == expectedLookAheadSeq && casProducerIndex(pIndex, expectedLookAheadSeq + 1)) {
                // claimed the whole batch in one CAS; now fill each slot in order
                for (int i = 0; i < lookAheadStep; i++) {
                    final long index = pIndex + i;
                    final int seqOffset = calcCircularLongElementOffset(index, mask);
                    final int offset = calcCircularRefElementOffset(index, mask);
                    // spin until the consumer has released this slot
                    while (lvLongElement(sBuffer, seqOffset) != index) {
                    }
                    // Ordered store ensures correct construction
                    soRefElement(buffer, offset, s.get());
                    soLongElement(sBuffer, seqOffset, index + 1);
                }
                produced += lookAheadStep;
            } else {
                if (lookAheadSeq < expectedLookAheadSeq) {
                    // if even the head slot is not available we are done
                    if (notAvailable(pIndex, mask, sBuffer, pIndex)) {
                        return produced;
                    }
                }
                return produced + fillOneByOne(s, remaining);
            }
        }
        return limit;
    }

    // True if the sequence at index is behind expectedSeq, i.e. the slot is not yet usable.
    private boolean notAvailable(long index, int mask, AtomicLongArray sBuffer, long expectedSeq) {
        final int seqOffset = calcCircularLongElementOffset(index, mask);
        final long seq = lvLongElement(sBuffer, seqOffset);
        if (seq < expectedSeq) {
            return true;
        }
        return false;
    }

    // Slow path for fill: claim and fill one slot at a time, same protocol as relaxedOffer().
    private int fillOneByOne(Supplier<E> s, int limit) {
        final AtomicLongArray sBuffer = sequenceBuffer;
        final int mask = this.mask;
        final AtomicReferenceArray<E> buffer = this.buffer;
        long pIndex;
        int seqOffset;
        long seq;
        for (int i = 0; i < limit; i++) {
            do {
                pIndex = lvProducerIndex();
                seqOffset = calcCircularLongElementOffset(pIndex, mask);
                seq = lvLongElement(sBuffer, seqOffset);
                if (seq < pIndex) {
                    // slot not cleared by consumer yet
                    return i;
                }
            } while (// another producer has moved the sequence
            seq > pIndex || // failed to increment
            !casProducerIndex(pIndex, pIndex + 1));
            // Ordered store ensures correct construction
            soRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), s.get());
            soLongElement(sBuffer, seqOffset, pIndex + 1);
        }
        return limit;
    }

    @Override
    public int drain(Consumer<E> c) {
        return MessagePassingQueueUtil.drain(this, c);
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillBounded(this, s);
    }

    @Override
    public void drain(Consumer<E> c, WaitStrategy w, ExitCondition exit) {
        MessagePassingQueueUtil.drain(this, c, w, exit);
    }

    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscAtomicArrayQueue.java
new file mode 100644
index 0000000..ecaf089
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscAtomicArrayQueue.java
@@ -0,0 +1,664 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpscArrayQueue.java. + */ +abstract class MpscAtomicArrayQueueL1Pad extends AtomicReferenceArrayQueue { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + MpscAtomicArrayQueueL1Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpscArrayQueue.java. 
 */
abstract class MpscAtomicArrayQueueProducerIndexField<E> extends MpscAtomicArrayQueueL1Pad<E> {

    private static final AtomicLongFieldUpdater<MpscAtomicArrayQueueProducerIndexField> P_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(MpscAtomicArrayQueueProducerIndexField.class, "producerIndex");

    // Next index a producer will claim; advanced only via CAS (multiple producers).
    private volatile long producerIndex;

    MpscAtomicArrayQueueProducerIndexField(int capacity) {
        super(capacity);
    }

    // volatile load of the producer index
    @Override
    public final long lvProducerIndex() {
        return producerIndex;
    }

    // CAS advance of the producer index; the only way producers claim slots
    final boolean casProducerIndex(long expect, long newValue) {
        return P_INDEX_UPDATER.compareAndSet(this, expect, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 * <p>
 * Pure padding between the producer index and the producer limit fields (false sharing protection).
 */
abstract class MpscAtomicArrayQueueMidPad<E> extends MpscAtomicArrayQueueProducerIndexField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    MpscAtomicArrayQueueMidPad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscAtomicArrayQueueProducerLimitField<E> extends MpscAtomicArrayQueueMidPad<E> {

    private static final AtomicLongFieldUpdater<MpscAtomicArrayQueueProducerLimitField> P_LIMIT_UPDATER = AtomicLongFieldUpdater.newUpdater(MpscAtomicArrayQueueProducerLimitField.class, "producerLimit");

    // First unavailable index the producer may claim up to before rereading the consumer index
    private volatile long producerLimit;

    MpscAtomicArrayQueueProducerLimitField(int capacity) {
        super(capacity);
        this.producerLimit = capacity;
    }

    // volatile load of the cached producer limit
    final long lvProducerLimit() {
        return producerLimit;
    }

    // ordered (lazy) store of the cached producer limit; races here are benign
    final void soProducerLimit(long newValue) {
        P_LIMIT_UPDATER.lazySet(this, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
+ */ +abstract class MpscAtomicArrayQueueL2Pad extends MpscAtomicArrayQueueProducerLimitField { + + // 8b + byte b000, b001, b002, b003, b004, b005, b006, b007; + + // 16b + byte b010, b011, b012, b013, b014, b015, b016, b017; + + // 24b + byte b020, b021, b022, b023, b024, b025, b026, b027; + + // 32b + byte b030, b031, b032, b033, b034, b035, b036, b037; + + // 40b + byte b040, b041, b042, b043, b044, b045, b046, b047; + + // 48b + byte b050, b051, b052, b053, b054, b055, b056, b057; + + // 56b + byte b060, b061, b062, b063, b064, b065, b066, b067; + + // 64b + byte b070, b071, b072, b073, b074, b075, b076, b077; + + // 72b + byte b100, b101, b102, b103, b104, b105, b106, b107; + + // 80b + byte b110, b111, b112, b113, b114, b115, b116, b117; + + // 88b + byte b120, b121, b122, b123, b124, b125, b126, b127; + + // 96b + byte b130, b131, b132, b133, b134, b135, b136, b137; + + // 104b + byte b140, b141, b142, b143, b144, b145, b146, b147; + + // 112b + byte b150, b151, b152, b153, b154, b155, b156, b157; + + // 120b + byte b160, b161, b162, b163, b164, b165, b166, b167; + + // byte b170,b171,b172,b173,b174,b175,b176,b177;//128b + MpscAtomicArrayQueueL2Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is MpscArrayQueue.java. 
 */
abstract class MpscAtomicArrayQueueConsumerIndexField<E> extends MpscAtomicArrayQueueL2Pad<E> {

    private static final AtomicLongFieldUpdater<MpscAtomicArrayQueueConsumerIndexField> C_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(MpscAtomicArrayQueueConsumerIndexField.class, "consumerIndex");

    // Next index the single consumer will read from.
    private volatile long consumerIndex;

    MpscAtomicArrayQueueConsumerIndexField(int capacity) {
        super(capacity);
    }

    // volatile load of the consumer index
    @Override
    public final long lvConsumerIndex() {
        return consumerIndex;
    }

    // plain load; valid from the single consumer thread only
    final long lpConsumerIndex() {
        return consumerIndex;
    }

    // ordered (lazy) store of the consumer index
    final void soConsumerIndex(long newValue) {
        C_INDEX_UPDATER.lazySet(this, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 * <p>
 * Pure padding after the consumer index (false sharing protection).
 */
abstract class MpscAtomicArrayQueueL3Pad<E> extends MpscAtomicArrayQueueConsumerIndexField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    MpscAtomicArrayQueueL3Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 * <p>
 * A Multi-Producer-Single-Consumer queue based on a {@link org.jctools.queues.ConcurrentCircularArrayQueue}. This
 * implies that any thread may call the offer method, but only a single thread may call poll/peek for correctness to
 * maintained.
 * <p>
 * This implementation follows patterns documented on the package level for False Sharing protection.
 * This implementation is using the Fast Flow
 * method for polling from the queue (with minor change to correctly publish the index) and an extension of
 * the Leslie Lamport concurrent queue algorithm (originated by Martin Thompson) on the producer side.
 */
public class MpscAtomicArrayQueue<E> extends MpscAtomicArrayQueueL3Pad<E> {

    public MpscAtomicArrayQueue(final int capacity) {
        super(capacity);
    }

    /**
     * {@link #offer}} if {@link #size()} is less than threshold.
     *
     * @param e the object to offer onto the queue, not null
     * @param threshold the maximum allowable size
     * @return true if the offer is successful, false if queue size exceeds threshold
     * @since 1.0.1
     */
    public boolean offerIfBelowThreshold(final E e, int threshold) {
        if (null == e) {
            throw new NullPointerException();
        }
        final int mask = this.mask;
        final long capacity = mask + 1;
        long producerLimit = lvProducerLimit();
        long pIndex;
        do {
            pIndex = lvProducerIndex();
            long available = producerLimit - pIndex;
            long size = capacity - available;
            if (size >= threshold) {
                final long cIndex = lvConsumerIndex();
                size = pIndex - cIndex;
                if (size >= threshold) {
                    // the size exceeds threshold
                    return false;
                } else {
                    // update producer limit to the next index that we must recheck the consumer index
                    producerLimit = cIndex + capacity;
                    // this is racy, but the race is benign
                    soProducerLimit(producerLimit);
                }
            }
        } while (!casProducerIndex(pIndex, pIndex + 1));
        /*
         * NOTE: the new producer index value is made visible BEFORE the element in the array. If we relied on
         * the index visibility to poll() we would need to handle the case where the element is not visible.
         */
        // Won CAS, move on to storing
        final int offset = calcCircularRefElementOffset(pIndex, mask);
        soRefElement(buffer, offset, e);
        // AWESOME :)
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:
     * <p>
     * Lock free offer using a single CAS. As class name suggests access is permitted to many threads
     * concurrently.
     *
     * @see java.util.Queue#offer
     * @see org.jctools.queues.MessagePassingQueue#offer
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        // use a cached view on consumer index (potentially updated in loop)
        final int mask = this.mask;
        long producerLimit = lvProducerLimit();
        long pIndex;
        do {
            pIndex = lvProducerIndex();
            if (pIndex >= producerLimit) {
                final long cIndex = lvConsumerIndex();
                producerLimit = cIndex + mask + 1;
                if (pIndex >= producerLimit) {
                    // FULL :(
                    return false;
                } else {
                    // update producer limit to the next index that we must recheck the consumer index
                    // this is racy, but the race is benign
                    soProducerLimit(producerLimit);
                }
            }
        } while (!casProducerIndex(pIndex, pIndex + 1));
        /*
         * NOTE: the new producer index value is made visible BEFORE the element in the array. If we relied on
         * the index visibility to poll() we would need to handle the case where the element is not visible.
         */
        // Won CAS, move on to storing
        final int offset = calcCircularRefElementOffset(pIndex, mask);
        soRefElement(buffer, offset, e);
        // AWESOME :)
        return true;
    }

    /**
     * A wait free alternative to offer which fails on CAS failure.
     *
     * @param e new element, not null
     * @return 1 if next element cannot be filled, -1 if CAS failed, 0 if successful
     */
    public final int failFastOffer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final int mask = this.mask;
        final long capacity = mask + 1;
        final long pIndex = lvProducerIndex();
        long producerLimit = lvProducerLimit();
        if (pIndex >= producerLimit) {
            final long cIndex = lvConsumerIndex();
            producerLimit = cIndex + capacity;
            if (pIndex >= producerLimit) {
                // FULL :(
                return 1;
            } else {
                // update producer limit to the next index that we must recheck the consumer index
                soProducerLimit(producerLimit);
            }
        }
        // look Ma, no loop!
        if (!casProducerIndex(pIndex, pIndex + 1)) {
            // CAS FAIL :(
            return -1;
        }
        // Won CAS, move on to storing
        final int offset = calcCircularRefElementOffset(pIndex, mask);
        soRefElement(buffer, offset, e);
        // AWESOME :)
        return 0;
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:
     * <p>
     * Lock free poll using ordered loads/stores. As class name suggests access is limited to a single thread.
     *
     * @see java.util.Queue#poll
     * @see org.jctools.queues.MessagePassingQueue#poll
     */
    @Override
    public E poll() {
        final long cIndex = lpConsumerIndex();
        final int offset = calcCircularRefElementOffset(cIndex, mask);
        // Copy field to avoid re-reading after volatile load
        final AtomicReferenceArray<E> buffer = this.buffer;
        // If we can't see the next available element we can't poll
        E e = lvRefElement(buffer, offset);
        if (null == e) {
            /*
             * NOTE: Queue may not actually be empty in the case of a producer (P1) being interrupted after
             * winning the CAS on offer but before storing the element in the queue. Other producers may go on
             * to fill up the queue after this element.
             */
            if (cIndex != lvProducerIndex()) {
                // an element is on its way: spin until the in-flight producer publishes it
                do {
                    e = lvRefElement(buffer, offset);
                } while (e == null);
            } else {
                return null;
            }
        }
        spRefElement(buffer, offset, null);
        soConsumerIndex(cIndex + 1);
        return e;
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:
     * <p>
     * Lock free peek using ordered loads. As class name suggests access is limited to a single thread.
     *
     * @see java.util.Queue#poll
     * @see org.jctools.queues.MessagePassingQueue#poll
     */
    @Override
    public E peek() {
        // Copy field to avoid re-reading after volatile load
        final AtomicReferenceArray<E> buffer = this.buffer;
        final long cIndex = lpConsumerIndex();
        final int offset = calcCircularRefElementOffset(cIndex, mask);
        E e = lvRefElement(buffer, offset);
        if (null == e) {
            /*
             * NOTE: Queue may not actually be empty in the case of a producer (P1) being interrupted after
             * winning the CAS on offer but before storing the element in the queue. Other producers may go on
             * to fill up the queue after this element.
             */
            if (cIndex != lvProducerIndex()) {
                // an element is on its way: spin until the in-flight producer publishes it
                do {
                    e = lvRefElement(buffer, offset);
                } while (e == null);
            } else {
                return null;
            }
        }
        return e;
    }

    @Override
    public boolean relaxedOffer(E e) {
        return offer(e);
    }

    /** Like {@link #poll} but returns null immediately if the next element is not yet visible. */
    @Override
    public E relaxedPoll() {
        final AtomicReferenceArray<E> buffer = this.buffer;
        final long cIndex = lpConsumerIndex();
        final int offset = calcCircularRefElementOffset(cIndex, mask);
        // If we can't see the next available element we can't poll
        E e = lvRefElement(buffer, offset);
        if (null == e) {
            return null;
        }
        spRefElement(buffer, offset, null);
        soConsumerIndex(cIndex + 1);
        return e;
    }

    /** Like {@link #peek} but returns whatever is visible at the consumer index (possibly null). */
    @Override
    public E relaxedPeek() {
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        final long cIndex = lpConsumerIndex();
        return lvRefElement(buffer, calcCircularRefElementOffset(cIndex, mask));
    }

    /**
     * Consumes up to {@code limit} visible elements, advancing the consumer index one element at a time.
     *
     * @param c     callback for each consumed element, not null
     * @param limit maximum number of elements to consume, non-negative
     * @return number of elements actually consumed
     */
    @Override
    public int drain(final Consumer<E> c, final int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        final long cIndex = lpConsumerIndex();
        for (int i = 0; i < limit; i++) {
            final long index = cIndex + i;
            final int offset = calcCircularRefElementOffset(index, mask);
            final E e = lvRefElement(buffer, offset);
            if (null == e) {
                return i;
            }
            spRefElement(buffer, offset, null);
            // ordered store -> atomic and ordered for size()
            soConsumerIndex(index + 1);
            c.accept(e);
        }
        return limit;
    }

    /**
     * Claims up to {@code limit} slots with a single CAS, then fills them from the supplier.
     *
     * @param s     supplier of elements, not null
     * @param limit maximum number of elements to produce, non-negative
     * @return number of elements actually produced
     */
    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final int mask = this.mask;
        final long capacity = mask + 1;
        long producerLimit = lvProducerLimit();
        long pIndex;
        int actualLimit;
        do {
            pIndex = lvProducerIndex();
            long available = producerLimit - pIndex;
            if (available <= 0) {
                final long cIndex = lvConsumerIndex();
                producerLimit = cIndex + capacity;
                available = producerLimit - pIndex;
                if (available <= 0) {
                    // FULL :(
                    return 0;
                } else {
                    // update producer limit to the next index that we must recheck the consumer index
                    soProducerLimit(producerLimit);
                }
            }
            actualLimit = Math.min((int) available, limit);
        } while (!casProducerIndex(pIndex, pIndex + actualLimit));
        // right, now we claimed a few slots and can fill them with goodness
        final AtomicReferenceArray<E> buffer = this.buffer;
        for (int i = 0; i < actualLimit; i++) {
            // Won CAS, move on to storing
            final int offset = calcCircularRefElementOffset(pIndex + i, mask);
            soRefElement(buffer, offset, s.get());
        }
        return actualLimit;
    }

    @Override
    public int drain(Consumer<E> c) {
        return drain(c, capacity());
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillBounded(this, s);
    }

    @Override
    public void drain(Consumer<E> c, WaitStrategy w, ExitCondition exit) {
        MessagePassingQueueUtil.drain(this, c, w, exit);
    }

    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }

    /**
     * @deprecated This was renamed to failFastOffer please migrate
     */
    @Deprecated
    public int weakOffer(E e) {
        return this.failFastOffer(e);
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscChunkedAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscChunkedAtomicArrayQueue.java
new file mode 100644
index 0000000..a34da79
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscChunkedAtomicArrayQueue.java
@@ -0,0 +1,133 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues.atomic;

import org.jctools.util.Pow2;
import org.jctools.util.RangeUtil;
import static java.lang.Math.max;
import static java.lang.Math.min;
import static org.jctools.util.Pow2.roundToPowerOfTwo;
import java.util.concurrent.atomic.*;
import org.jctools.queues.*;
import static org.jctools.queues.atomic.AtomicQueueUtil.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscChunkedArrayQueue.java.
+ */ +abstract class MpscChunkedAtomicArrayQueueColdProducerFields extends BaseMpscLinkedAtomicArrayQueue { + + protected final long maxQueueCapacity; + + MpscChunkedAtomicArrayQueueColdProducerFields(int initialCapacity, int maxCapacity) { + super(initialCapacity); + RangeUtil.checkGreaterThanOrEqual(maxCapacity, 4, "maxCapacity"); + RangeUtil.checkLessThan(roundToPowerOfTwo(initialCapacity), roundToPowerOfTwo(maxCapacity), "initialCapacity"); + maxQueueCapacity = ((long) Pow2.roundToPowerOfTwo(maxCapacity)) << 1; + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscChunkedArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks + * of the initial size. The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. 
/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
 * which can be found in the jctools-build module. The original source file is MpscChunkedArrayQueue.java.
 *
 * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks
 * of the initial size. The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 */
public class MpscChunkedAtomicArrayQueue<E> extends MpscChunkedAtomicArrayQueueColdProducerFields<E> {

    // 128 bytes of byte-field padding to isolate the cold fields above from whatever
    // follows in memory (false-sharing protection).
    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    /**
     * Creates a queue whose chunk size is {@code maxCapacity / 8}, rounded to a power of 2
     * and clamped to the range [2, 1024].
     */
    public MpscChunkedAtomicArrayQueue(int maxCapacity) {
        super(max(2, min(1024, roundToPowerOfTwo(maxCapacity / 8))), maxCapacity);
    }

    /**
     * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size.
     *        Must be 2 or more.
     * @param maxCapacity the maximum capacity will be rounded up to the closest power of 2 and will be the
     *        upper limit of number of elements in this queue. Must be 4 or more and round up to a larger
     *        power of 2 than initialCapacity.
     */
    public MpscChunkedAtomicArrayQueue(int initialCapacity, int maxCapacity) {
        super(initialCapacity, maxCapacity);
    }

    @Override
    protected long availableInQueue(long pIndex, long cIndex) {
        // Remaining headroom in the same pre-shifted units as the indices.
        return maxQueueCapacity - (pIndex - cIndex);
    }

    @Override
    public int capacity() {
        // maxQueueCapacity is stored pre-shifted by 1; undo the shift for callers.
        return (int) (maxQueueCapacity / 2);
    }

    @Override
    protected int getNextBufferSize(AtomicReferenceArray<E> buffer) {
        // Fixed chunk size: every new chunk matches the current buffer's length.
        return length(buffer);
    }

    @Override
    protected long getCurrentBufferCapacity(long mask) {
        return mask;
    }
}
+ * + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks, + * doubling theirs size every time until the full blown backing array is used. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. + */ +public class MpscGrowableAtomicArrayQueue extends MpscChunkedAtomicArrayQueue { + + public MpscGrowableAtomicArrayQueue(int maxCapacity) { + super(Math.max(2, Pow2.roundToPowerOfTwo(maxCapacity / 8)), maxCapacity); + } + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. + * Must be 2 or more. + * @param maxCapacity the maximum capacity will be rounded up to the closest power of 2 and will be the + * upper limit of number of elements in this queue. Must be 4 or more and round up to a larger + * power of 2 than initialCapacity. + */ + public MpscGrowableAtomicArrayQueue(int initialCapacity, int maxCapacity) { + super(initialCapacity, maxCapacity); + } + + @Override + protected int getNextBufferSize(AtomicReferenceArray buffer) { + final long maxSize = maxQueueCapacity / 2; + RangeUtil.checkLessThanOrEqual(length(buffer), maxSize, "buffer.length"); + final int newSize = 2 * (length(buffer) - 1); + return newSize + 1; + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return (mask + 2 == maxQueueCapacity) ? maxQueueCapacity : mask; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscLinkedAtomicQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscLinkedAtomicQueue.java new file mode 100644 index 0000000..1ac12f3 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/MpscLinkedAtomicQueue.java @@ -0,0 +1,158 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import java.util.Queue; +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscLinkedQueue.java. + * + * This is a Java port of the MPSC algorithm as presented + * on + * 1024 Cores by D. Vyukov. The original has been adapted to Java and it's quirks with regards to memory + * model and layout: + *

    + *
+ * <ol>
+ * <li>Use inheritance to ensure no false sharing occurs between producer/consumer node reference fields.</li>
+ * <li>Use XCHG functionality to the best of the JDK ability (see differences in JDK7/8 impls).</li>
+ * <li>Conform to {@link java.util.Queue} contract on poll. The original semantics are available via relaxedPoll.</li>
+ * </ol>
/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
 * which can be found in the jctools-build module. The original source file is MpscLinkedQueue.java.
 *
 * This is a Java port of the MPSC algorithm as presented on 1024 Cores by D. Vyukov. The original
 * has been adapted to Java and its quirks with regards to memory model and layout.
 * The queue is initialized with a stub node which is set to both the producer and consumer node
 * references. From this point follow the notes on offer/poll.
 */
public class MpscLinkedAtomicQueue<E> extends BaseLinkedAtomicQueue<E> {

    public MpscLinkedAtomicQueue() {
        // Stub node shared by both ends of the queue until the first offer.
        LinkedQueueAtomicNode<E> node = newNode();
        spConsumerNode(node);
        xchgProducerNode(node);
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Offer is allowed from multiple threads.<br>
     * Offer allocates a new node and:
     * <ol>
     * <li>Swaps it atomically with current producer node (only one producer 'wins')</li>
     * <li>Sets the new node as the node following from the swapped producer node</li>
     * </ol>
     * This works because each producer is guaranteed to 'plant' a new node and link the old node. No 2
     * producers can get the same producer node as part of XCHG guarantee.
     *
     * @see MessagePassingQueue#offer(Object)
     * @see java.util.Queue#offer(java.lang.Object)
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final LinkedQueueAtomicNode<E> nextNode = newNode(e);
        final LinkedQueueAtomicNode<E> prevProducerNode = xchgProducerNode(nextNode);
        // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed
        // and completes the store in prev.next. This is a "bubble".
        prevProducerNode.soNext(nextNode);
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * This method is only safe to call from the (single) consumer thread, and is subject to best effort when racing
     * with producers. This method is potentially blocking when "bubble"s in the queue are visible.
     */
    @Override
    public boolean remove(Object o) {
        if (null == o) {
            // Null elements are not permitted, so null will never be removed.
            return false;
        }
        final LinkedQueueAtomicNode<E> originalConsumerNode = lpConsumerNode();
        LinkedQueueAtomicNode<E> prevConsumerNode = originalConsumerNode;
        LinkedQueueAtomicNode<E> currConsumerNode = getNextConsumerNode(originalConsumerNode);
        while (currConsumerNode != null) {
            if (o.equals(currConsumerNode.lpValue())) {
                LinkedQueueAtomicNode<E> nextNode = getNextConsumerNode(currConsumerNode);
                // e.g.: consumerNode -> node0 -> node1(o==v) -> node2 ... => consumerNode -> node0 -> node2
                if (nextNode != null) {
                    // We are removing an interior node.
                    prevConsumerNode.soNext(nextNode);
                } else // This case reflects: prevConsumerNode != originalConsumerNode && nextNode == null
                // At rest, this would be the producerNode, but we must contend with racing. Changes to subclassed
                // queues need to consider remove() when implementing offer().
                {
                    // producerNode is currConsumerNode, try to atomically update the reference to move it to the
                    // previous node.
                    prevConsumerNode.soNext(null);
                    if (!casProducerNode(currConsumerNode, prevConsumerNode)) {
                        // If the producer(s) have offered more items we need to remove the currConsumerNode link.
                        nextNode = spinWaitForNextNode(currConsumerNode);
                        prevConsumerNode.soNext(nextNode);
                    }
                }
                // Avoid GC nepotism because we are discarding the current node.
                currConsumerNode.soNext(null);
                currConsumerNode.spValue(null);
                return true;
            }
            prevConsumerNode = currConsumerNode;
            currConsumerNode = getNextConsumerNode(currConsumerNode);
        }
        return false;
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillUnbounded(this, s);
    }

    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        // Build a private chain of 'limit' nodes, then splice it in with a single XCHG.
        LinkedQueueAtomicNode<E> tail = newNode(s.get());
        final LinkedQueueAtomicNode<E> head = tail;
        for (int i = 1; i < limit; i++) {
            final LinkedQueueAtomicNode<E> temp = newNode(s.get());
            // spNext: xchgProducerNode ensures correct construction
            tail.spNext(temp);
            tail = temp;
        }
        final LinkedQueueAtomicNode<E> oldPNode = xchgProducerNode(tail);
        oldPNode.soNext(head);
        return limit;
    }

    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }

    // Next node after currConsumerNode, spinning through producer "bubbles" when the
    // node is known to have a successor (i.e. it is not the producer node).
    private LinkedQueueAtomicNode<E> getNextConsumerNode(LinkedQueueAtomicNode<E> currConsumerNode) {
        LinkedQueueAtomicNode<E> nextNode = currConsumerNode.lvNext();
        if (nextNode == null && currConsumerNode != lvProducerNode()) {
            nextNode = spinWaitForNextNode(currConsumerNode);
        }
        return nextNode;
    }
}
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscUnboundedArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked chunks of the initial size. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. 
/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
 * which can be found in the jctools-build module. The original source file is MpscUnboundedArrayQueue.java.
 *
 * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked chunks of the initial size.
 * The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 */
public class MpscUnboundedAtomicArrayQueue<E> extends BaseMpscLinkedAtomicArrayQueue<E> {

    // 128 bytes of byte-field padding against false sharing with neighbouring objects.
    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    public MpscUnboundedAtomicArrayQueue(int chunkSize) {
        super(chunkSize);
    }

    @Override
    protected long availableInQueue(long pIndex, long cIndex) {
        // Unbounded: producers never observe the queue as full.
        return Integer.MAX_VALUE;
    }

    @Override
    public int capacity() {
        return MessagePassingQueue.UNBOUNDED_CAPACITY;
    }

    @Override
    public int drain(Consumer<E> c) {
        // Bound the batch so a concurrently-filled unbounded queue cannot pin the drainer forever.
        return drain(c, 4096);
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillUnbounded(this, s);
    }

    @Override
    protected int getNextBufferSize(AtomicReferenceArray<E> buffer) {
        // Fixed chunk size: every new chunk matches the current buffer's length.
        return length(buffer);
    }

    @Override
    protected long getCurrentBufferCapacity(long mask) {
        return mask;
    }
}
b/netty-jctools/src/main/java/org/jctools/queues/atomic/SequencedAtomicReferenceArrayQueue.java new file mode 100644 index 0000000..c54eed2 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/SequencedAtomicReferenceArrayQueue.java @@ -0,0 +1,55 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.AtomicLongArray; + +abstract class SequencedAtomicReferenceArrayQueue extends + AtomicReferenceArrayQueue +{ + protected final AtomicLongArray sequenceBuffer; + + public SequencedAtomicReferenceArrayQueue(int capacity) + { + super(capacity); + int actualCapacity = this.mask + 1; + // pad data on either end with some empty slots. 
+ sequenceBuffer = new AtomicLongArray(actualCapacity); + for (int i = 0; i < actualCapacity; i++) + { + soSequence(sequenceBuffer, i, i); + } + } + + protected final long calcSequenceOffset(long index) + { + return calcSequenceOffset(index, mask); + } + + protected static int calcSequenceOffset(long index, int mask) + { + return (int) index & mask; + } + + protected final void soSequence(AtomicLongArray buffer, int offset, long e) + { + buffer.lazySet(offset, e); + } + + protected final long lvSequence(AtomicLongArray buffer, int offset) + { + return buffer.get(offset); + } + +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/SpmcAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/SpmcAtomicArrayQueue.java new file mode 100644 index 0000000..098e94e --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/SpmcAtomicArrayQueue.java @@ -0,0 +1,543 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator + * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java. 
/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * 128 bytes of byte-field padding preceding the producer index (false-sharing protection).
 */
abstract class SpmcAtomicArrayQueueL1Pad<E> extends AtomicReferenceArrayQueue<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpmcAtomicArrayQueueL1Pad(int capacity) {
        super(capacity);
    }
}
/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * Holds the producer index, padded on both sides by the L1/L2 pad classes.
 */
abstract class SpmcAtomicArrayQueueProducerIndexField<E> extends SpmcAtomicArrayQueueL1Pad<E> {

    private static final AtomicLongFieldUpdater<SpmcAtomicArrayQueueProducerIndexField> P_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(SpmcAtomicArrayQueueProducerIndexField.class, "producerIndex");

    private volatile long producerIndex;

    SpmcAtomicArrayQueueProducerIndexField(int capacity) {
        super(capacity);
    }

    @Override
    public final long lvProducerIndex() {
        // Volatile load ("lv" per the jctools naming convention).
        return producerIndex;
    }

    final long lpProducerIndex() {
        // "lp" (plain-load) per the generator naming convention; the field itself is volatile.
        return producerIndex;
    }

    final void soProducerIndex(long newValue) {
        // Ordered ("so") store via lazySet: publishes without a full volatile-write fence.
        P_INDEX_UPDATER.lazySet(this, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * 128 bytes of byte-field padding separating the producer index from the consumer index.
 */
abstract class SpmcAtomicArrayQueueL2Pad<E> extends SpmcAtomicArrayQueueProducerIndexField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpmcAtomicArrayQueueL2Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * Holds the consumer index, contended by the multiple consumers via CAS.
 */
abstract class SpmcAtomicArrayQueueConsumerIndexField<E> extends SpmcAtomicArrayQueueL2Pad<E> {

    private static final AtomicLongFieldUpdater<SpmcAtomicArrayQueueConsumerIndexField> C_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(SpmcAtomicArrayQueueConsumerIndexField.class, "consumerIndex");

    private volatile long consumerIndex;

    SpmcAtomicArrayQueueConsumerIndexField(int capacity) {
        super(capacity);
    }

    @Override
    public final long lvConsumerIndex() {
        // Volatile load.
        return consumerIndex;
    }

    final boolean casConsumerIndex(long expect, long newValue) {
        // CAS: consumers race to claim slots.
        return C_INDEX_UPDATER.compareAndSet(this, expect, newValue);
    }
}
/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * 128 bytes of byte-field padding separating the consumer index from the producer index cache.
 */
abstract class SpmcAtomicArrayQueueMidPad<E> extends SpmcAtomicArrayQueueConsumerIndexField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpmcAtomicArrayQueueMidPad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 */
abstract class SpmcAtomicArrayQueueProducerIndexCacheField<E> extends SpmcAtomicArrayQueueMidPad<E> {

    // This is separated from the consumerIndex which will be highly contended in the hope that this value spends most
    // of its time in a cache line that is Shared (and rarely invalidated).
    private volatile long producerIndexCache;

    SpmcAtomicArrayQueueProducerIndexCacheField(int capacity) {
        super(capacity);
    }

    protected final long lvProducerIndexCache() {
        return producerIndexCache;
    }

    protected final void svProducerIndexCache(long newValue) {
        // Plain volatile store ("sv" per the generator naming convention).
        producerIndexCache = newValue;
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * 128 bytes of byte-field padding following the producer index cache.
 */
abstract class SpmcAtomicArrayQueueL3Pad<E> extends SpmcAtomicArrayQueueProducerIndexCacheField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpmcAtomicArrayQueueL3Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 *
 * Single-producer, multi-consumer bounded array queue.
 */
public class SpmcAtomicArrayQueue<E> extends SpmcAtomicArrayQueueL3Pad<E> {

    public SpmcAtomicArrayQueue(final int capacity) {
        super(capacity);
    }

    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        final long currProducerIndex = lvProducerIndex();
        final int offset = calcCircularRefElementOffset(currProducerIndex, mask);
        if (null != lvRefElement(buffer, offset)) {
            long size = currProducerIndex - lvConsumerIndex();
            if (size > mask) {
                return false;
            } else {
                // Bubble: this can happen because `poll` moves the consumer index before nulling the element.
                // Busy-spin until the slot clears; this sacrifices wait-freedom.
                while (null != lvRefElement(buffer, offset)) {
                    // BURN
                }
            }
        }
        soRefElement(buffer, offset, e);
        // single producer, so store ordered is valid. It is also required to correctly publish the element
        // and for the consumers to pick up the tail value.
        soProducerIndex(currProducerIndex + 1);
        return true;
    }

    @Override
    public E poll() {
        long currentConsumerIndex;
        long currProducerIndexCache = lvProducerIndexCache();
        do {
            currentConsumerIndex = lvConsumerIndex();
            if (currentConsumerIndex >= currProducerIndexCache) {
                // Cache looks exhausted; refresh against the real producer index.
                long currProducerIndex = lvProducerIndex();
                if (currentConsumerIndex >= currProducerIndex) {
                    // Queue is empty.
                    return null;
                } else {
                    currProducerIndexCache = currProducerIndex;
                    svProducerIndexCache(currProducerIndex);
                }
            }
        } while (!casConsumerIndex(currentConsumerIndex, currentConsumerIndex + 1));
        // consumers are gated on latest visible tail, and so can't see a null value in the queue or overtake
        // and wrap to hit same location.
        return removeElement(buffer, currentConsumerIndex, mask);
    }

    // Loads the claimed element and nulls its slot so the producer can reuse it.
    private E removeElement(final AtomicReferenceArray<E> buffer, long index, final int mask) {
        final int offset = calcCircularRefElementOffset(index, mask);
        // load plain, element happens before its index becomes visible
        final E e = lpRefElement(buffer, offset);
        // store ordered, make sure nulling out is visible. Producer is waiting for this value.
        soRefElement(buffer, offset, null);
        return e;
    }

    @Override
    public E peek() {
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        long currProducerIndexCache = lvProducerIndexCache();
        long currentConsumerIndex;
        long nextConsumerIndex = lvConsumerIndex();
        E e;
        do {
            currentConsumerIndex = nextConsumerIndex;
            if (currentConsumerIndex >= currProducerIndexCache) {
                long currProducerIndex = lvProducerIndex();
                if (currentConsumerIndex >= currProducerIndex) {
                    return null;
                } else {
                    currProducerIndexCache = currProducerIndex;
                    svProducerIndexCache(currProducerIndex);
                }
            }
            e = lvRefElement(buffer, calcCircularRefElementOffset(currentConsumerIndex, mask));
            // sandwich the element load between 2 consumer index loads
            nextConsumerIndex = lvConsumerIndex();
        } while (null == e || nextConsumerIndex != currentConsumerIndex);
        return e;
    }

    @Override
    public boolean relaxedOffer(E e) {
        if (null == e) {
            throw new NullPointerException("Null is not a valid element");
        }
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        final long producerIndex = lpProducerIndex();
        final int offset = calcCircularRefElementOffset(producerIndex, mask);
        if (null != lvRefElement(buffer, offset)) {
            // Relaxed: give up immediately instead of distinguishing full from bubble.
            return false;
        }
        soRefElement(buffer, offset, e);
        // single producer, so store ordered is valid. It is also required to correctly publish the element
        // and for the consumers to pick up the tail value.
        soProducerIndex(producerIndex + 1);
        return true;
    }

    @Override
    public E relaxedPoll() {
        return poll();
    }

    @Override
    public E relaxedPeek() {
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        long currentConsumerIndex;
        long nextConsumerIndex = lvConsumerIndex();
        E e;
        do {
            currentConsumerIndex = nextConsumerIndex;
            e = lvRefElement(buffer, calcCircularRefElementOffset(currentConsumerIndex, mask));
            // sandwich the element load between 2 consumer index loads
            nextConsumerIndex = lvConsumerIndex();
        } while (nextConsumerIndex != currentConsumerIndex);
        return e;
    }

    @Override
    public int drain(final Consumer<E> c, final int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        long currProducerIndexCache = lvProducerIndexCache();
        int adjustedLimit = 0;
        long currentConsumerIndex;
        do {
            currentConsumerIndex = lvConsumerIndex();
            // have we caught up with the (cached) producer index, i.e. does the queue look empty?
            if (currentConsumerIndex >= currProducerIndexCache) {
                long currProducerIndex = lvProducerIndex();
                if (currentConsumerIndex >= currProducerIndex) {
                    return 0;
                } else {
                    currProducerIndexCache = currProducerIndex;
                    svProducerIndexCache(currProducerIndex);
                }
            }
            // try and claim up to 'limit' elements in one go
            int remaining = (int) (currProducerIndexCache - currentConsumerIndex);
            adjustedLimit = Math.min(remaining, limit);
        } while (!casConsumerIndex(currentConsumerIndex, currentConsumerIndex + adjustedLimit));
        for (int i = 0; i < adjustedLimit; i++) {
            c.accept(removeElement(buffer, currentConsumerIndex + i, mask));
        }
        return adjustedLimit;
    }

    @Override
    public int fill(final Supplier<E> s, final int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final AtomicReferenceArray<E> buffer = this.buffer;
        final int mask = this.mask;
        long producerIndex = this.lpProducerIndex();
        for (int i = 0; i < limit; i++) {
            final int offset = calcCircularRefElementOffset(producerIndex, mask);
            if (null != lvRefElement(buffer, offset)) {
                // Stop at the first occupied slot; i elements were filled.
                return i;
            }
            producerIndex++;
            soRefElement(buffer, offset, s.get());
            // ordered store -> atomic and ordered for size()
            soProducerIndex(producerIndex);
        }
        return limit;
    }

    @Override
    public int drain(final Consumer<E> c) {
        return MessagePassingQueueUtil.drain(this, c);
    }

    @Override
    public int fill(final Supplier<E> s) {
        return fill(s, capacity());
    }

    @Override
    public void drain(final Consumer<E> c, final WaitStrategy w, final ExitCondition exit) {
        MessagePassingQueueUtil.drain(this, c, w, exit);
    }

    @Override
    public void fill(final Supplier<E> s, final WaitStrategy w, final ExitCondition e) {
        MessagePassingQueueUtil.fill(this, s, w, e);
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues.atomic;

import org.jctools.util.SpscLookAheadUtil;
import java.util.concurrent.atomic.*;
import org.jctools.queues.*;
import static org.jctools.queues.atomic.AtomicQueueUtil.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscArrayQueue.java.
 * <p>
 * Cold (rarely mutated) state for the SPSC array queue: the producer look-ahead step, computed once
 * from the actual (power-of-two rounded) capacity.
 */
abstract class SpscAtomicArrayQueueColdField<E> extends AtomicReferenceArrayQueue<E> {

    // Number of slots the producer probes ahead of its index to amortize the cost of the
    // availability check (BQueue-style optimization).
    final int lookAheadStep;

    SpscAtomicArrayQueueColdField(int capacity) {
        super(capacity);
        int actualCapacity = capacity();
        lookAheadStep = SpscLookAheadUtil.computeLookAheadStep(actualCapacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscArrayQueue.java.
 * <p>
 * 128 bytes of padding to keep the producer index on its own cache line (false-sharing avoidance).
 */
abstract class SpscAtomicArrayQueueL1Pad<E> extends SpscAtomicArrayQueueColdField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpscAtomicArrayQueueL1Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscArrayQueue.java.
 * <p>
 * Producer-side hot fields: the volatile producer index (published with an ordered/lazy store) and
 * the single-threaded producer limit cache.
 */
abstract class SpscAtomicArrayQueueProducerIndexFields<E> extends SpscAtomicArrayQueueL1Pad<E> {

    private static final AtomicLongFieldUpdater<SpscAtomicArrayQueueProducerIndexFields> P_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(SpscAtomicArrayQueueProducerIndexFields.class, "producerIndex");

    private volatile long producerIndex;

    // Producer-only cached limit up to which offers may proceed without re-checking the buffer.
    protected long producerLimit;

    SpscAtomicArrayQueueProducerIndexFields(int capacity) {
        super(capacity);
    }

    /** Volatile load of the producer index (visible to the consumer thread). */
    @Override
    public final long lvProducerIndex() {
        return producerIndex;
    }

    /** Plain load of the producer index; correct from the producer thread only. */
    final long lpProducerIndex() {
        return producerIndex;
    }

    /** Ordered (lazySet) store of the producer index; paired with the element store for publication. */
    final void soProducerIndex(final long newValue) {
        P_INDEX_UPDATER.lazySet(this, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscArrayQueue.java.
 * <p>
 * 128 bytes of padding separating the producer index fields from the consumer index field.
 */
abstract class SpscAtomicArrayQueueL2Pad<E> extends SpscAtomicArrayQueueProducerIndexFields<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpscAtomicArrayQueueL2Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscArrayQueue.java.
 * <p>
 * Consumer-side hot field: the volatile consumer index, published with an ordered/lazy store.
 */
abstract class SpscAtomicArrayQueueConsumerIndexField<E> extends SpscAtomicArrayQueueL2Pad<E> {

    private static final AtomicLongFieldUpdater<SpscAtomicArrayQueueConsumerIndexField> C_INDEX_UPDATER = AtomicLongFieldUpdater.newUpdater(SpscAtomicArrayQueueConsumerIndexField.class, "consumerIndex");

    private volatile long consumerIndex;

    SpscAtomicArrayQueueConsumerIndexField(int capacity) {
        super(capacity);
    }

    /** Volatile load of the consumer index (visible to the producer thread). */
    public final long lvConsumerIndex() {
        return consumerIndex;
    }

    /** Plain load of the consumer index; correct from the consumer thread only. */
    final long lpConsumerIndex() {
        return consumerIndex;
    }

    /** Ordered (lazySet) store of the consumer index. */
    final void soConsumerIndex(final long newValue) {
        C_INDEX_UPDATER.lazySet(this, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicArrayQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscArrayQueue.java.
 * <p>
 * Trailing 128 bytes of padding protecting the consumer index from adjacent allocations.
 */
abstract class SpscAtomicArrayQueueL3Pad<E> extends SpscAtomicArrayQueueConsumerIndexField<E> {

    // 8b
    byte b000, b001, b002, b003, b004, b005, b006, b007;

    // 16b
    byte b010, b011, b012, b013, b014, b015, b016, b017;

    // 24b
    byte b020, b021, b022, b023, b024, b025, b026, b027;

    // 32b
    byte b030, b031, b032, b033, b034, b035, b036, b037;

    // 40b
    byte b040, b041, b042, b043, b044, b045, b046, b047;

    // 48b
    byte b050, b051, b052, b053, b054, b055, b056, b057;

    // 56b
    byte b060, b061, b062, b063, b064, b065, b066, b067;

    // 64b
    byte b070, b071, b072, b073, b074, b075, b076, b077;

    // 72b
    byte b100, b101, b102, b103, b104, b105, b106, b107;

    // 80b
    byte b110, b111, b112, b113, b114, b115, b116, b117;

    // 88b
    byte b120, b121, b122, b123, b124, b125, b126, b127;

    // 96b
    byte b130, b131, b132, b133, b134, b135, b136, b137;

    // 104b
    byte b140, b141, b142, b143, b144, b145, b146, b147;

    // 112b
    byte b150, b151, b152, b153, b154, b155, b156, b157;

    // 120b
    byte b160, b161, b162, b163, b164, b165, b166, b167;

    // 128b
    byte b170, b171, b172, b173, b174, b175, b176, b177;

    SpscAtomicArrayQueueL3Pad(int capacity) {
        super(capacity);
    }
}

+ * This implementation is a mashup of the Fast Flow + * algorithm with an optimization of the offer method taken from the BQueue algorithm (a variation on Fast + * Flow), and adjusted to comply with Queue.offer semantics with regards to capacity.
+ * For convenience the relevant papers are available in the `resources` folder:
+ * + * 2010 - Pisa - SPSC Queues on Shared Cache Multi-Core Systems.pdf
+ * 2012 - Junchang- BQueue- Efficient and Practical Queuing.pdf
+ *
+ * This implementation is wait free. + */ +public class SpscAtomicArrayQueue extends SpscAtomicArrayQueueL3Pad { + + public SpscAtomicArrayQueue(final int capacity) { + super(Math.max(capacity, 4)); + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single producer thread use only. + */ + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + // local load of field to avoid repeated loads after volatile reads + final AtomicReferenceArray buffer = this.buffer; + final int mask = this.mask; + final long producerIndex = this.lpProducerIndex(); + if (producerIndex >= producerLimit && !offerSlowPath(buffer, mask, producerIndex)) { + return false; + } + final int offset = calcCircularRefElementOffset(producerIndex, mask); + soRefElement(buffer, offset, e); + // ordered store -> atomic and ordered for size() + soProducerIndex(producerIndex + 1); + return true; + } + + private boolean offerSlowPath(final AtomicReferenceArray buffer, final int mask, final long producerIndex) { + final int lookAheadStep = this.lookAheadStep; + if (null == lvRefElement(buffer, calcCircularRefElementOffset(producerIndex + lookAheadStep, mask))) { + producerLimit = producerIndex + lookAheadStep; + } else { + final int offset = calcCircularRefElementOffset(producerIndex, mask); + if (null != lvRefElement(buffer, offset)) { + return false; + } + } + return true; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @Override + public E poll() { + final long consumerIndex = this.lpConsumerIndex(); + final int offset = calcCircularRefElementOffset(consumerIndex, mask); + // local load of field to avoid repeated loads after volatile reads + final AtomicReferenceArray buffer = this.buffer; + final E e = lvRefElement(buffer, offset); + if (null == e) { + return null; + } + soRefElement(buffer, offset, null); + // ordered store -> atomic and ordered for size() + soConsumerIndex(consumerIndex + 1); + return e; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @Override + public E peek() { + return lvRefElement(buffer, calcCircularRefElementOffset(lpConsumerIndex(), mask)); + } + + @Override + public boolean relaxedOffer(final E message) { + return offer(message); + } + + @Override + public E relaxedPoll() { + return poll(); + } + + @Override + public E relaxedPeek() { + return peek(); + } + + @Override + public int drain(final Consumer c) { + return drain(c, capacity()); + } + + @Override + public int fill(final Supplier s) { + return fill(s, capacity()); + } + + @Override + public int drain(final Consumer c, final int limit) { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + final AtomicReferenceArray buffer = this.buffer; + final int mask = this.mask; + final long consumerIndex = this.lpConsumerIndex(); + for (int i = 0; i < limit; i++) { + final long index = consumerIndex + i; + final int offset = calcCircularRefElementOffset(index, mask); + final E e = lvRefElement(buffer, offset); + if (null == e) { + return i; + } + soRefElement(buffer, offset, null); + // ordered store -> atomic and ordered for size() + soConsumerIndex(index + 1); + c.accept(e); + } + return limit; + } + + @Override + public int fill(final Supplier s, final int limit) { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + final AtomicReferenceArray buffer = this.buffer; + final int mask = this.mask; + final int lookAheadStep = this.lookAheadStep; + final long producerIndex = this.lpProducerIndex(); + for (int i = 0; i < limit; i++) { + final long index = producerIndex + i; + final int lookAheadElementOffset = calcCircularRefElementOffset(index + lookAheadStep, mask); + if (null == 
lvRefElement(buffer, lookAheadElementOffset)) { + int lookAheadLimit = Math.min(lookAheadStep, limit - i); + for (int j = 0; j < lookAheadLimit; j++) { + final int offset = calcCircularRefElementOffset(index + j, mask); + soRefElement(buffer, offset, s.get()); + // ordered store -> atomic and ordered for size() + soProducerIndex(index + j + 1); + } + i += lookAheadLimit - 1; + } else { + final int offset = calcCircularRefElementOffset(index, mask); + if (null != lvRefElement(buffer, offset)) { + return i; + } + soRefElement(buffer, offset, s.get()); + // ordered store -> atomic and ordered for size() + soProducerIndex(index + 1); + } + } + return limit; + } + + @Override + public void drain(final Consumer c, final WaitStrategy w, final ExitCondition exit) { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (null == w) + throw new IllegalArgumentException("wait is null"); + if (null == exit) + throw new IllegalArgumentException("exit condition is null"); + final AtomicReferenceArray buffer = this.buffer; + final int mask = this.mask; + long consumerIndex = this.lpConsumerIndex(); + int counter = 0; + while (exit.keepRunning()) { + for (int i = 0; i < 4096; i++) { + final int offset = calcCircularRefElementOffset(consumerIndex, mask); + final E e = lvRefElement(buffer, offset); + if (null == e) { + counter = w.idle(counter); + continue; + } + consumerIndex++; + counter = 0; + soRefElement(buffer, offset, null); + // ordered store -> atomic and ordered for size() + soConsumerIndex(consumerIndex); + c.accept(e); + } + } + } + + @Override + public void fill(final Supplier s, final WaitStrategy w, final ExitCondition e) { + if (null == w) + throw new IllegalArgumentException("waiter is null"); + if (null == e) + throw new IllegalArgumentException("exit condition is null"); + if (null == s) + throw new IllegalArgumentException("supplier is null"); + final AtomicReferenceArray buffer = this.buffer; + final int mask = this.mask; + final int 
lookAheadStep = this.lookAheadStep; + long producerIndex = this.lpProducerIndex(); + int counter = 0; + while (e.keepRunning()) { + final int lookAheadElementOffset = calcCircularRefElementOffset(producerIndex + lookAheadStep, mask); + if (null == lvRefElement(buffer, lookAheadElementOffset)) { + for (int j = 0; j < lookAheadStep; j++) { + final int offset = calcCircularRefElementOffset(producerIndex, mask); + producerIndex++; + soRefElement(buffer, offset, s.get()); + // ordered store -> atomic and ordered for size() + soProducerIndex(producerIndex); + } + } else { + final int offset = calcCircularRefElementOffset(producerIndex, mask); + if (null != lvRefElement(buffer, offset)) { + counter = w.idle(counter); + continue; + } + producerIndex++; + counter = 0; + soRefElement(buffer, offset, s.get()); + // ordered store -> atomic and ordered for size() + soProducerIndex(producerIndex); + } + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/SpscChunkedAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/SpscChunkedAtomicArrayQueue.java new file mode 100644 index 0000000..cbac582 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/SpscChunkedAtomicArrayQueue.java @@ -0,0 +1,104 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.jctools.queues.atomic;

import org.jctools.util.Pow2;
import org.jctools.util.RangeUtil;
import java.util.concurrent.atomic.*;
import org.jctools.queues.*;
import static org.jctools.queues.atomic.AtomicQueueUtil.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscChunkedArrayQueue.java.
 * <p>
 * An SPSC array queue which starts at <i>initialCapacity</i> and grows to <i>maxCapacity</i> in linked chunks
 * of the initial size. The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 *
 * @param <E> the element type
 */
public class SpscChunkedAtomicArrayQueue<E> extends BaseSpscLinkedAtomicArrayQueue<E> {

    // Upper bound on the total number of elements the queue may hold across all chunks.
    private final int maxQueueCapacity;

    // Producer-only cached index limit derived from the consumer index + max capacity.
    private long producerQueueLimit;

    public SpscChunkedAtomicArrayQueue(int capacity) {
        this(Math.max(8, Pow2.roundToPowerOfTwo(capacity / 8)), capacity);
    }

    public SpscChunkedAtomicArrayQueue(int chunkSize, int capacity) {
        RangeUtil.checkGreaterThanOrEqual(capacity, 16, "capacity");
        // minimal chunk size of eight makes sure minimal lookahead step is 2
        RangeUtil.checkGreaterThanOrEqual(chunkSize, 8, "chunkSize");
        maxQueueCapacity = Pow2.roundToPowerOfTwo(capacity);
        int chunkCapacity = Pow2.roundToPowerOfTwo(chunkSize);
        RangeUtil.checkLessThan(chunkCapacity, maxQueueCapacity, "chunkCapacity");
        long mask = chunkCapacity - 1;
        // need extra element to point at next array
        AtomicReferenceArray<E> buffer = allocateRefArray(chunkCapacity + 1);
        producerBuffer = buffer;
        producerMask = mask;
        consumerBuffer = buffer;
        consumerMask = mask;
        // we know it's all empty to start with
        producerBufferLimit = mask - 1;
        producerQueueLimit = maxQueueCapacity;
    }

    @Override
    final boolean offerColdPath(AtomicReferenceArray<E> buffer, long mask, long pIndex, int offset, E v, Supplier<? extends E> s) {
        // use a fixed lookahead step based on buffer capacity
        final long lookAheadStep = (mask + 1) / 4;
        long pBufferLimit = pIndex + lookAheadStep;
        long pQueueLimit = producerQueueLimit;
        if (pIndex >= pQueueLimit) {
            // we tested against a potentially out of date queue limit, refresh it
            final long cIndex = lvConsumerIndex();
            producerQueueLimit = pQueueLimit = cIndex + maxQueueCapacity;
            // if we're full we're full
            if (pIndex >= pQueueLimit) {
                return false;
            }
        }
        // if buffer limit is after queue limit we use queue limit. We need to handle overflow so
        // cannot use Math.min
        if (pBufferLimit - pQueueLimit > 0) {
            pBufferLimit = pQueueLimit;
        }
        // go around the buffer or add a new buffer
        if (// there's sufficient room in buffer/queue to use pBufferLimit
        pBufferLimit > pIndex + 1 && null == lvRefElement(buffer, calcCircularRefElementOffset(pBufferLimit, mask))) {
            // joy, there's plenty of room
            producerBufferLimit = pBufferLimit - 1;
            writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset);
        } else if (null == lvRefElement(buffer, calcCircularRefElementOffset(pIndex + 1, mask))) {
            // buffer is not full
            writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset);
        } else {
            // we got one slot left to write into, and we are not full. Need to link new buffer.
            // allocate new buffer of same length
            final AtomicReferenceArray<E> newBuffer = allocateRefArray((int) (mask + 2));
            producerBuffer = newBuffer;
            linkOldToNew(pIndex, buffer, offset, newBuffer, offset, v == null ? s.get() : v);
        }
        return true;
    }

    @Override
    public int capacity() {
        return maxQueueCapacity;
    }
}
package org.jctools.queues.atomic;

import org.jctools.util.Pow2;
import org.jctools.util.RangeUtil;
import org.jctools.util.SpscLookAheadUtil;
import java.util.concurrent.atomic.*;
import org.jctools.queues.*;
import static org.jctools.queues.atomic.AtomicQueueUtil.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscGrowableArrayQueue.java.
 * <p>
 * An SPSC array queue which starts at <i>initialCapacity</i> and grows to <i>maxCapacity</i> in linked chunks,
 * doubling theirs size every time until the full blown backing array is used.
 * The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 *
 * @param <E> the element type
 */
public class SpscGrowableAtomicArrayQueue<E> extends BaseSpscLinkedAtomicArrayQueue<E> {

    private final int maxQueueCapacity;

    // Positive: producer look-ahead step. Zero or negative: the max-sized buffer has been linked
    // but the consumer has not reached it yet; the value then encodes (negated) the number of
    // elements still sitting in previous buffers.
    private long lookAheadStep;

    public SpscGrowableAtomicArrayQueue(final int capacity) {
        this(Math.max(8, Pow2.roundToPowerOfTwo(capacity / 8)), capacity);
    }

    public SpscGrowableAtomicArrayQueue(final int chunkSize, final int capacity) {
        RangeUtil.checkGreaterThanOrEqual(capacity, 16, "capacity");
        // minimal chunk size of eight makes sure minimal lookahead step is 2
        RangeUtil.checkGreaterThanOrEqual(chunkSize, 8, "chunkSize");
        maxQueueCapacity = Pow2.roundToPowerOfTwo(capacity);
        int chunkCapacity = Pow2.roundToPowerOfTwo(chunkSize);
        RangeUtil.checkLessThan(chunkCapacity, maxQueueCapacity, "chunkCapacity");
        long mask = chunkCapacity - 1;
        // need extra element to point at next array
        AtomicReferenceArray<E> buffer = allocateRefArray(chunkCapacity + 1);
        producerBuffer = buffer;
        producerMask = mask;
        consumerBuffer = buffer;
        consumerMask = mask;
        // we know it's all empty to start with
        producerBufferLimit = mask - 1;
        adjustLookAheadStep(chunkCapacity);
    }

    @Override
    final boolean offerColdPath(final AtomicReferenceArray<E> buffer, final long mask, final long index, final int offset, final E v, final Supplier<? extends E> s) {
        final long lookAheadStep = this.lookAheadStep;
        // normal case, go around the buffer or resize if full (unless we hit max capacity)
        if (lookAheadStep > 0) {
            int lookAheadElementOffset = calcCircularRefElementOffset(index + lookAheadStep, mask);
            // Try and look ahead a number of elements so we don't have to do this all the time
            if (null == lvRefElement(buffer, lookAheadElementOffset)) {
                // joy, there's plenty of room
                producerBufferLimit = index + lookAheadStep - 1;
                writeToQueue(buffer, v == null ? s.get() : v, index, offset);
                return true;
            }
            // we're at max capacity, can use up last element
            final int maxCapacity = maxQueueCapacity;
            if (mask + 1 == maxCapacity) {
                if (null == lvRefElement(buffer, offset)) {
                    writeToQueue(buffer, v == null ? s.get() : v, index, offset);
                    return true;
                }
                // we're full and can't grow
                return false;
            }
            // not at max capacity, so must allow extra slot for next buffer pointer
            if (null == lvRefElement(buffer, calcCircularRefElementOffset(index + 1, mask))) {
                // buffer is not full
                writeToQueue(buffer, v == null ? s.get() : v, index, offset);
            } else {
                // allocate new buffer of same length
                final AtomicReferenceArray<E> newBuffer = allocateRefArray((int) (2 * (mask + 1) + 1));
                producerBuffer = newBuffer;
                producerMask = length(newBuffer) - 2;
                final int offsetInNew = calcCircularRefElementOffset(index, producerMask);
                linkOldToNew(index, buffer, offset, newBuffer, offsetInNew, v == null ? s.get() : v);
                int newCapacity = (int) (producerMask + 1);
                if (newCapacity == maxCapacity) {
                    long currConsumerIndex = lvConsumerIndex();
                    // use lookAheadStep to store the consumer distance from final buffer
                    this.lookAheadStep = -(index - currConsumerIndex);
                    producerBufferLimit = currConsumerIndex + maxCapacity;
                } else {
                    producerBufferLimit = index + producerMask - 1;
                    adjustLookAheadStep(newCapacity);
                }
            }
            return true;
        } else // the step is negative (or zero) in the period between allocating the max sized buffer and the
        // consumer starting on it
        {
            final long prevElementsInOtherBuffers = -lookAheadStep;
            // until the consumer starts using the current buffer we need to check consumer index to
            // verify size
            long currConsumerIndex = lvConsumerIndex();
            int size = (int) (index - currConsumerIndex);
            // we're on max capacity or we wouldn't be here
            int maxCapacity = (int) mask + 1;
            if (size == maxCapacity) {
                // consumer index has not changed since adjusting the lookAhead index, we're full
                return false;
            }
            // if consumerIndex progressed enough so that current size indicates it is on same buffer
            long firstIndexInCurrentBuffer = producerBufferLimit - maxCapacity + prevElementsInOtherBuffers;
            if (currConsumerIndex >= firstIndexInCurrentBuffer) {
                // job done, we've now settled into our final state
                adjustLookAheadStep(maxCapacity);
            } else // consumer is still on some other buffer
            {
                // how many elements out of buffer?
                this.lookAheadStep = (int) (currConsumerIndex - firstIndexInCurrentBuffer);
            }
            producerBufferLimit = currConsumerIndex + maxCapacity;
            writeToQueue(buffer, v == null ? s.get() : v, index, offset);
            return true;
        }
    }

    // Recompute the look-ahead step for the given (power-of-two) buffer capacity.
    private void adjustLookAheadStep(int capacity) {
        lookAheadStep = SpscLookAheadUtil.computeLookAheadStep(capacity);
    }

    @Override
    public int capacity() {
        return maxQueueCapacity;
    }
}
+ */ +package org.jctools.queues.atomic; + +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is SpscLinkedQueue.java. + * + * This is a weakened version of the MPSC algorithm as presented + * on + * 1024 Cores by D. Vyukov. The original has been adapted to Java and it's quirks with regards to memory + * model and layout: + *

    + *
  1. Use inheritance to ensure no false sharing occurs between producer/consumer node reference fields. + *
  2. As this is an SPSC we have no need for XCHG, an ordered store is enough. + *
+ * The queue is initialized with a stub node which is set to both the producer and consumer node references. + * From this point follow the notes on offer/poll. + * + * @param + * @author nitsanw + */ +public class SpscLinkedAtomicQueue extends BaseLinkedAtomicQueue { + + public SpscLinkedAtomicQueue() { + LinkedQueueAtomicNode node = newNode(); + spProducerNode(node); + spConsumerNode(node); + // this ensures correct construction: StoreStore + node.soNext(null); + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Offer is allowed from a SINGLE thread.
+ * Offer allocates a new node (holding the offered value) and: + *

    + *
  1. Sets the new node as the producerNode + *
  2. Sets that node as the lastProducerNode.next + *
+ * From this follows that producerNode.next is always null and for all other nodes node.next is not null. + * + * @see MessagePassingQueue#offer(Object) + * @see java.util.Queue#offer(java.lang.Object) + */ + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + final LinkedQueueAtomicNode nextNode = newNode(e); + LinkedQueueAtomicNode oldNode = lpProducerNode(); + soProducerNode(nextNode); + // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed + // and completes the store in prev.next. This is a "bubble". + // Inverting the order here will break the `isEmpty` invariant, and will require matching adjustments elsewhere. + oldNode.soNext(nextNode); + return true; + } + + @Override + public int fill(Supplier s) { + return MessagePassingQueueUtil.fillUnbounded(this, s); + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + LinkedQueueAtomicNode tail = newNode(s.get()); + final LinkedQueueAtomicNode head = tail; + for (int i = 1; i < limit; i++) { + final LinkedQueueAtomicNode temp = newNode(s.get()); + // spNext : soProducerNode ensures correct construction + tail.spNext(temp); + tail = temp; + } + final LinkedQueueAtomicNode oldPNode = lpProducerNode(); + soProducerNode(tail); + // same bubble as offer, and for the same reasons. 
+ oldPNode.soNext(head); + return limit; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/SpscUnboundedAtomicArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/SpscUnboundedAtomicArrayQueue.java new file mode 100644 index 0000000..3d68201 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/SpscUnboundedAtomicArrayQueue.java @@ -0,0 +1,78 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; + +import org.jctools.util.Pow2; +import java.util.concurrent.atomic.*; +import org.jctools.queues.*; +import static org.jctools.queues.atomic.AtomicQueueUtil.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.atomic.JavaParsingAtomicLinkedQueueGenerator + * which can found in the jctools-build module. The original source file is SpscUnboundedArrayQueue.java. + * + * An SPSC array queue which starts at initialCapacity and grows indefinitely in linked chunks of the initial size. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
+ * + * @param + */ +public class SpscUnboundedAtomicArrayQueue extends BaseSpscLinkedAtomicArrayQueue { + + public SpscUnboundedAtomicArrayQueue(int chunkSize) { + int chunkCapacity = Math.max(Pow2.roundToPowerOfTwo(chunkSize), 16); + long mask = chunkCapacity - 1; + AtomicReferenceArray buffer = allocateRefArray(chunkCapacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + // we know it's all empty to start with + producerBufferLimit = mask - 1; + } + + @Override + final boolean offerColdPath(AtomicReferenceArray buffer, long mask, long pIndex, int offset, E v, Supplier s) { + // use a fixed lookahead step based on buffer capacity + final long lookAheadStep = (mask + 1) / 4; + long pBufferLimit = pIndex + lookAheadStep; + // go around the buffer or add a new buffer + if (null == lvRefElement(buffer, calcCircularRefElementOffset(pBufferLimit, mask))) { + // joy, there's plenty of room + producerBufferLimit = pBufferLimit - 1; + writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset); + } else if (null == lvRefElement(buffer, calcCircularRefElementOffset(pIndex + 1, mask))) { + // buffer is not full + writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset); + } else { + // we got one slot left to write into, and we are not full. Need to link new buffer. + // allocate new buffer of same length + final AtomicReferenceArray newBuffer = allocateRefArray((int) (mask + 2)); + producerBuffer = newBuffer; + producerBufferLimit = pIndex + mask - 1; + linkOldToNew(pIndex, buffer, offset, newBuffer, offset, v == null ? 
s.get() : v); + } + return true; + } + + @Override + public int fill(Supplier s) { + return fill(s, (int) this.producerMask); + } + + @Override + public int capacity() { + return UNBOUNDED_CAPACITY; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/atomic/package-info.java b/netty-jctools/src/main/java/org/jctools/queues/atomic/package-info.java new file mode 100644 index 0000000..8436b19 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/atomic/package-info.java @@ -0,0 +1,14 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.atomic; diff --git a/netty-jctools/src/main/java/org/jctools/queues/package-info.java b/netty-jctools/src/main/java/org/jctools/queues/package-info.java new file mode 100644 index 0000000..37d34aa --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/package-info.java @@ -0,0 +1,98 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +/** + * This package aims to fill a gap in current JDK implementations in offering lock free (wait free where possible) + * queues for inter-thread message passing with finer grained guarantees and an emphasis on performance.
+ * At the time of writing the only lock free queue available in the JDK is + * {@link java.util.concurrent.ConcurrentLinkedQueue} which is an unbounded multi-producer, multi-consumer queue which + * is further encumbered by the need to implement the full range of {@link java.util.Queue} methods. In this package we + * offer a range of implementations: + *
+ * <ol>
+ * <li>Bounded/Unbounded SPSC queues - Serving the Single Producer Single Consumer use case.</li>
+ * <li>Bounded/Unbounded MPSC queues - The Multi Producer Single Consumer case also has a multi-lane implementation on
+ * offer which trades the FIFO ordering(re-ordering is not limited) for reduced contention and increased throughput
+ * under contention.</li>
+ * <li>Bounded SPMC/MPMC queues</li>
+ * </ol>
+ * <p>
+ *
+ * <b>Limited Queue methods support:</b><br>
+ * The queues implement a subset of the {@link java.util.Queue} interface which is documented under the
+ * {@link org.jctools.queues.MessagePassingQueue} interface. In particular {@link java.util.Queue#iterator()} is usually not
+ * supported and dependent methods from {@link java.util.AbstractQueue} are also not supported such as:
+ *

+ * <ol>
+ * <li>{@link java.util.Queue#remove(Object)}</li>
+ * <li>{@link java.util.Queue#removeAll(java.util.Collection)}</li>
+ * <li>{@link java.util.Queue#removeIf(java.util.function.Predicate)}</li>
+ * <li>{@link java.util.Queue#contains(Object)}</li>
+ * <li>{@link java.util.Queue#containsAll(java.util.Collection)}</li>
+ * </ol>
+ * A few queues do support a limited form of iteration. This support is documented in the Javadoc of the relevant queues. + *

+ *
+ * <b>Memory layout controls and False Sharing:</b><br>
+ * The classes in this package use what is considered at the moment the most reliable method of controlling class field + * layout, namely inheritance. The method is described in this + * post which also covers + * why other methods are currently suspect.
+ * Note that we attempt to tackle both active (write/write) and passive(read/write) false sharing case: + *

+ * <ol>
+ * <li>Hot counters (or write locations) are padded.</li>
+ * <li>Read-Only shared fields are padded.</li>
+ * <li>Array edges are NOT padded (though doing so is entirely legitimate).</li>
+ * </ol>
+ * <p>

+ *
+ * <b>Use of sun.misc.Unsafe:</b><br>
+ * A choice is made in this library to utilize sun.misc.Unsafe for performance reasons. In this package we have two use + * cases: + *

+ * <ol>
+ * <li>The queue counters in the queues are all inlined (i.e. are primitive fields of the queue classes). To allow
+ * lazySet/CAS semantics to these fields we could use {@link java.util.concurrent.atomic.AtomicLongFieldUpdater} but
+ * choose not to for performance reasons. On newer OpenJDKs where AFU is made more performant the difference is small.</li>
+ * <li>We use Unsafe to gain volatile/lazySet access to array elements. We could use
+ * {@link java.util.concurrent.atomic.AtomicReferenceArray} but choose not to for performance reasons(extra reference
+ * chase and redundant boundary checks).</li>
+ * </ol>
+ * Both use cases should be made obsolete by VarHandles at some point. + *

+ *
+ * <b>Avoiding redundant loads of fields:</b><br>
+ * Because a volatile load will force any following field access to reload the field value an effort is made to cache + * field values in local variables where possible and expose interfaces which allow the code to capitalize on such + * caching. As a convention the local variable name will be the field name and will be final. + *

+ *
+ * <b>Method naming conventions:</b><br>
+ * The following convention is followed in method naming to highlight volatile/ordered/plain access to fields: + *

+ * <ol>
+ * <li>lpFoo/spFoo: these will be plain load or stores to the field. No memory ordering is needed or expected.</li>
+ * <li>soFoo: this is an ordered store to the field (like
+ * {@link java.util.concurrent.atomic.AtomicInteger#lazySet(int)}). Implies an ordering of stores (StoreStore barrier
+ * before the store).</li>
+ * <li>lv/svFoo: these are volatile load/store. A store implies a StoreLoad barrier, a load implies LoadLoad barrier
+ * before and LoadStore after.</li>
+ * <li>casFoo: compare and swap the field. StoreLoad if successful. See
+ * {@link java.util.concurrent.atomic.AtomicInteger#compareAndSet(int, int)}</li>
+ * <li>xchgFoo: atomically get and set the field. Effectively a StoreLoad. See
+ * {@link java.util.concurrent.atomic.AtomicInteger#getAndSet(int)}</li>
+ * <li>xaddFoo: atomically get and add to the field. Effectively a StoreLoad. See
+ * {@link java.util.concurrent.atomic.AtomicInteger#getAndAdd(int)}</li>
+ * </ol>
+ * It is generally expected that a volatile load signals the acquire of a field previously released by a non-plain + * store. + * + * @author nitsanw + */ +package org.jctools.queues; diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseLinkedUnpaddedQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseLinkedUnpaddedQueue.java new file mode 100644 index 0000000..8456efd --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseLinkedUnpaddedQueue.java @@ -0,0 +1,324 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.Queue; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. + */ +abstract class BaseLinkedUnpaddedQueuePad0 extends AbstractQueue implements MessagePassingQueue { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. 
+ */ +abstract class BaseLinkedUnpaddedQueueProducerNodeRef extends BaseLinkedUnpaddedQueuePad0 { + + final static long P_NODE_OFFSET = fieldOffset(BaseLinkedUnpaddedQueueProducerNodeRef.class, "producerNode"); + + private volatile LinkedQueueNode producerNode; + + final void spProducerNode(LinkedQueueNode newValue) { + UNSAFE.putObject(this, P_NODE_OFFSET, newValue); + } + + final void soProducerNode(LinkedQueueNode newValue) { + UNSAFE.putOrderedObject(this, P_NODE_OFFSET, newValue); + } + + final LinkedQueueNode lvProducerNode() { + return producerNode; + } + + final boolean casProducerNode(LinkedQueueNode expect, LinkedQueueNode newValue) { + return UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, expect, newValue); + } + + final LinkedQueueNode lpProducerNode() { + return producerNode; + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. + */ +abstract class BaseLinkedUnpaddedQueuePad1 extends BaseLinkedUnpaddedQueueProducerNodeRef { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. 
+ */ +abstract class BaseLinkedUnpaddedQueueConsumerNodeRef extends BaseLinkedUnpaddedQueuePad1 { + + private final static long C_NODE_OFFSET = fieldOffset(BaseLinkedUnpaddedQueueConsumerNodeRef.class, "consumerNode"); + + private LinkedQueueNode consumerNode; + + final void spConsumerNode(LinkedQueueNode newValue) { + consumerNode = newValue; + } + + @SuppressWarnings("unchecked") + final LinkedQueueNode lvConsumerNode() { + return (LinkedQueueNode) UNSAFE.getObjectVolatile(this, C_NODE_OFFSET); + } + + final LinkedQueueNode lpConsumerNode() { + return consumerNode; + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. + */ +abstract class BaseLinkedUnpaddedQueuePad2 extends BaseLinkedUnpaddedQueueConsumerNodeRef { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseLinkedQueue.java. + * + * A base data structure for concurrent linked queues. For convenience also pulled in common single consumer + * methods since at this time there's no plan to implement MC. + */ +abstract class BaseLinkedUnpaddedQueue extends BaseLinkedUnpaddedQueuePad2 { + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + protected final LinkedQueueNode newNode() { + return new LinkedQueueNode(); + } + + protected final LinkedQueueNode newNode(E e) { + return new LinkedQueueNode(e); + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * This is an O(n) operation as we run through all the nodes and count them.
+ * The accuracy of the value returned by this method is subject to races with producer/consumer threads. In + * particular when racing with the consumer thread this method may under estimate the size.
+ * + * @see java.util.Queue#size() + */ + @Override + public final int size() { + // Read consumer first, this is important because if the producer is node is 'older' than the consumer + // the consumer may overtake it (consume past it) invalidating the 'snapshot' notion of size. + LinkedQueueNode chaserNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + int size = 0; + // must chase the nodes all the way to the producer node, but there's no need to count beyond expected head. + while (// don't go passed producer node + chaserNode != producerNode && // stop at last node + chaserNode != null && // stop at max int + size < Integer.MAX_VALUE) { + LinkedQueueNode next; + next = chaserNode.lvNext(); + // check if this node has been consumed, if so return what we have + if (next == chaserNode) { + return size; + } + chaserNode = next; + size++; + } + return size; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Queue is empty when producerNode is the same as consumerNode. An alternative implementation would be to + * observe the producerNode.value is null, which also means an empty queue because only the + * consumerNode.value is allowed to be null. + * + * @see MessagePassingQueue#isEmpty() + */ + @Override + public boolean isEmpty() { + LinkedQueueNode consumerNode = lvConsumerNode(); + LinkedQueueNode producerNode = lvProducerNode(); + return consumerNode == producerNode; + } + + protected E getSingleConsumerNodeValue(LinkedQueueNode currConsumerNode, LinkedQueueNode nextNode) { + // we have to null out the value because we are going to hang on to the node + final E nextValue = nextNode.getAndNullValue(); + // Fix up the next ref of currConsumerNode to prevent promoted nodes from keeping new ones alive. + // We use a reference to self instead of null because null is already a meaningful value (the next of + // producer node is null). + currConsumerNode.soNext(currConsumerNode); + spConsumerNode(nextNode); + // currConsumerNode is now no longer referenced and can be collected + return nextValue; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Poll is allowed from a SINGLE thread.
+ * Poll is potentially blocking here as the {@link Queue#poll()} does not allow returning {@code null} if the queue is not + * empty. This is very different from the original Vyukov guarantees. See {@link #relaxedPoll()} for the original + * semantics.
+ * Poll reads {@code consumerNode.next} and: + *

+ * <ol>
+ * <li>If it is {@code null} AND the queue is empty return {@code null}, if queue is not empty spin wait for
+ * value to become visible.</li>
+ * <li>If it is not {@code null} set it as the consumer node and return its now evacuated value.</li>
+ * </ol>
+ * This means the consumerNode.value is always {@code null}, which is also the starting point for the queue. + * Because {@code null} values are not allowed to be offered this is the only node with it's value set to + * {@code null} at any one time. + * + * @see MessagePassingQueue#poll() + * @see java.util.Queue#poll() + */ + @Override + public E poll() { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } else if (currConsumerNode != lvProducerNode()) { + nextNode = spinWaitForNextNode(currConsumerNode); + // got the next node... + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + /** + * {@inheritDoc}
+ *

+ * IMPLEMENTATION NOTES:
+ * Peek is allowed from a SINGLE thread.
+ * Peek is potentially blocking here as the {@link Queue#peek()} does not allow returning {@code null} if the queue is not + * empty. This is very different from the original Vyukov guarantees. See {@link #relaxedPeek()} for the original + * semantics.
+ * Poll reads the next node from the consumerNode and: + *

+ * <ol>
+ * <li>If it is {@code null} AND the queue is empty return {@code null}, if queue is not empty spin wait for
+ * value to become visible.</li>
+ * <li>If it is not {@code null} return its value.</li>
+ * </ol>
+ * + * @see MessagePassingQueue#peek() + * @see java.util.Queue#peek() + */ + @Override + public E peek() { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } else if (currConsumerNode != lvProducerNode()) { + nextNode = spinWaitForNextNode(currConsumerNode); + // got the next node... + return nextNode.lpValue(); + } + return null; + } + + LinkedQueueNode spinWaitForNextNode(LinkedQueueNode currNode) { + LinkedQueueNode nextNode; + while ((nextNode = currNode.lvNext()) == null) { + // spin, we are no longer wait free + } + return nextNode; + } + + @Override + public E relaxedPoll() { + final LinkedQueueNode currConsumerNode = lpConsumerNode(); + final LinkedQueueNode nextNode = currConsumerNode.lvNext(); + if (nextNode != null) { + return getSingleConsumerNodeValue(currConsumerNode, nextNode); + } + return null; + } + + @Override + public E relaxedPeek() { + final LinkedQueueNode nextNode = lpConsumerNode().lvNext(); + if (nextNode != null) { + return nextNode.lpValue(); + } + return null; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public int drain(Consumer c, int limit) { + if (null == c) + throw new IllegalArgumentException("c is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative: " + limit); + if (limit == 0) + return 0; + LinkedQueueNode chaserNode = this.lpConsumerNode(); + for (int i = 0; i < limit; i++) { + final LinkedQueueNode nextNode = chaserNode.lvNext(); + if (nextNode == null) { + return i; + } + // we have to null out the value because we are going to hang on to the node + final E nextValue = getSingleConsumerNodeValue(chaserNode, nextNode); + chaserNode = nextNode; + c.accept(nextValue); + } + return limit; + } + + @Override + public int drain(Consumer c) { + return MessagePassingQueueUtil.drain(this, c); + } + + @Override + public void 
drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + @Override + public int capacity() { + return UNBOUNDED_CAPACITY; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseMpscLinkedUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseMpscLinkedUnpaddedArrayQueue.java new file mode 100644 index 0000000..66ff539 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseMpscLinkedUnpaddedArrayQueue.java @@ -0,0 +1,649 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.PortableJvmInfo; +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import static org.jctools.queues.LinkedArrayQueueUtil.modifiedCalcCircularRefElementOffset; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. 
The original source file is BaseMpscLinkedArrayQueue.java. + */ +abstract class BaseMpscLinkedUnpaddedArrayQueuePad1 extends AbstractQueue implements IndexedQueue { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + */ +abstract class BaseMpscLinkedUnpaddedArrayQueueProducerFields extends BaseMpscLinkedUnpaddedArrayQueuePad1 { + + private final static long P_INDEX_OFFSET = fieldOffset(BaseMpscLinkedUnpaddedArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final boolean casProducerIndex(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + */ +abstract class BaseMpscLinkedUnpaddedArrayQueuePad2 extends BaseMpscLinkedUnpaddedArrayQueueProducerFields { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. 
+ */ +abstract class BaseMpscLinkedUnpaddedArrayQueueConsumerFields extends BaseMpscLinkedUnpaddedArrayQueuePad2 { + + private final static long C_INDEX_OFFSET = fieldOffset(BaseMpscLinkedUnpaddedArrayQueueConsumerFields.class, "consumerIndex"); + + private volatile long consumerIndex; + + protected long consumerMask; + + protected E[] consumerBuffer; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + */ +abstract class BaseMpscLinkedUnpaddedArrayQueuePad3 extends BaseMpscLinkedUnpaddedArrayQueueConsumerFields { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. 
+ */ +abstract class BaseMpscLinkedUnpaddedArrayQueueColdProducerFields extends BaseMpscLinkedUnpaddedArrayQueuePad3 { + + private final static long P_LIMIT_OFFSET = fieldOffset(BaseMpscLinkedUnpaddedArrayQueueColdProducerFields.class, "producerLimit"); + + private volatile long producerLimit; + + protected long producerMask; + + protected E[] producerBuffer; + + final long lvProducerLimit() { + return producerLimit; + } + + final boolean casProducerLimit(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_LIMIT_OFFSET, expect, newValue); + } + + final void soProducerLimit(long newValue) { + UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseMpscLinkedArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks + * of the initial size. The queue grows only when the current buffer is full and elements are not copied on + * resize, instead a link to the new buffer is stored in the old buffer for the consumer to follow. + */ +abstract class BaseMpscLinkedUnpaddedArrayQueue extends BaseMpscLinkedUnpaddedArrayQueueColdProducerFields implements MessagePassingQueue, QueueProgressIndicators { + + // No post padding here, subclasses must add + private static final Object JUMP = new Object(); + + private static final Object BUFFER_CONSUMED = new Object(); + + private static final int CONTINUE_TO_P_INDEX_CAS = 0; + + private static final int RETRY = 1; + + private static final int QUEUE_FULL = 2; + + private static final int QUEUE_RESIZE = 3; + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. + * Must be 2 or more. 
+ */ + public BaseMpscLinkedUnpaddedArrayQueue(final int initialCapacity) { + RangeUtil.checkGreaterThanOrEqual(initialCapacity, 2, "initialCapacity"); + int p2capacity = Pow2.roundToPowerOfTwo(initialCapacity); + // leave lower bit of mask clear + long mask = (p2capacity - 1) << 1; + // need extra element to point at next array + E[] buffer = allocateRefArray(p2capacity + 1); + producerBuffer = buffer; + producerMask = mask; + consumerBuffer = buffer; + consumerMask = mask; + // we know it's all empty to start with + soProducerLimit(mask); + } + + @Override + public int size() { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.IGNORE_PARITY_DIVISOR); + } + + @Override + public boolean isEmpty() { + // Order matters! + // Loading consumer before producer allows for producer increments after consumer index is read. + // This ensures this method is conservative in it's estimate. Note that as this is an MPMC there is + // nothing we can do to make this an exact method. + return ((lvConsumerIndex() - lvProducerIndex()) / 2 == 0); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public boolean offer(final E e) { + if (null == e) { + throw new NullPointerException(); + } + long mask; + E[] buffer; + long pIndex; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + // mask/buffer may get changed by resizing -> only use for array access after successful CAS. 
+ mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) - [mask/buffer] -> cas(pIndex) + // assumption behind this optimization is that queue is almost always empty or near empty + if (producerLimit <= pIndex) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch(result) { + case CONTINUE_TO_P_INDEX_CAS: + break; + case RETRY: + continue; + case QUEUE_FULL: + return false; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, e, null); + return true; + } + } + if (casProducerIndex(pIndex, pIndex + 2)) { + break; + } + } + // INDEX visible before ELEMENT + final long offset = modifiedCalcCircularRefElementOffset(pIndex, mask); + // release element e + soRefElement(buffer, offset, e); + return true; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + long pIndex = lvProducerIndex(); + // isEmpty? + if ((cIndex - pIndex) / 2 == 0) { + return null; + } + // poll() == null iff queue is empty, null element is not strong enough indicator, so we must + // spin until element is visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, cIndex); + } + // release element null + soRefElement(buffer, offset, null); + // release cIndex + soConsumerIndex(cIndex + 2); + return (E) e; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + long pIndex = lvProducerIndex(); + // isEmpty? + if ((cIndex - pIndex) / 2 == 0) { + return null; + } + // peek() == null iff queue is empty, null element is not strong enough indicator, so we must + // spin until element is visible. + do { + e = lvRefElement(buffer, offset); + } while (e == null); + } + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), cIndex); + } + return (E) e; + } + + /** + * We do not inline resize into this method because we do not resize on fill. + */ + private int offerSlowPath(long mask, long pIndex, long producerLimit) { + final long cIndex = lvConsumerIndex(); + long bufferCapacity = getCurrentBufferCapacity(mask); + if (cIndex + bufferCapacity > pIndex) { + if (!casProducerLimit(producerLimit, cIndex + bufferCapacity)) { + // retry from top + return RETRY; + } else { + // continue to pIndex CAS + return CONTINUE_TO_P_INDEX_CAS; + } + } else // full and cannot grow + if (availableInQueue(pIndex, cIndex) <= 0) { + // offer should return false; + return QUEUE_FULL; + } else // grab index for resize -> set lower bit + if (casProducerIndex(pIndex, pIndex + 1)) { + // trigger a resize + return QUEUE_RESIZE; + } else { + // failed resize attempt, retry from top + return RETRY; + } + } + + /** + * @return available elements in queue * 2 + */ + protected abstract long availableInQueue(long pIndex, long cIndex); + + @SuppressWarnings("unchecked") + private E[] nextBuffer(final E[] buffer, final long mask) { + final long offset = nextArrayOffset(mask); + final E[] nextBuffer = (E[]) lvRefElement(buffer, offset); + consumerBuffer = nextBuffer; + 
consumerMask = (length(nextBuffer) - 2) << 1; + soRefElement(buffer, offset, BUFFER_CONSUMED); + return nextBuffer; + } + + private static long nextArrayOffset(long mask) { + return modifiedCalcCircularRefElementOffset(mask + 2, Long.MAX_VALUE); + } + + private E newBufferPoll(E[] nextBuffer, long cIndex) { + final long offset = modifiedCalcCircularRefElementOffset(cIndex, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (n == null) { + throw new IllegalStateException("new buffer must have at least one element"); + } + soRefElement(nextBuffer, offset, null); + soConsumerIndex(cIndex + 2); + return n; + } + + private E newBufferPeek(E[] nextBuffer, long cIndex) { + final long offset = modifiedCalcCircularRefElementOffset(cIndex, consumerMask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) { + throw new IllegalStateException("new buffer must have at least one element"); + } + return n; + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex() / 2; + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex() / 2; + } + + @Override + public abstract int capacity(); + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPoll() { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == null) { + return null; + } + if (e == JUMP) { + final E[] nextBuffer = nextBuffer(buffer, mask); + return newBufferPoll(nextBuffer, cIndex); + } + soRefElement(buffer, offset, null); + soConsumerIndex(cIndex + 2); + return (E) e; + } + + @SuppressWarnings("unchecked") + @Override + public E relaxedPeek() { + final E[] buffer = consumerBuffer; + final long cIndex = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = 
modifiedCalcCircularRefElementOffset(cIndex, mask); + Object e = lvRefElement(buffer, offset); + if (e == JUMP) { + return newBufferPeek(nextBuffer(buffer, mask), cIndex); + } + return (E) e; + } + + @Override + public int fill(Supplier s) { + // result is a long because we want to have a safepoint check at regular intervals + long result = 0; + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + long mask; + E[] buffer; + long pIndex; + int claimedSlots; + while (true) { + long producerLimit = lvProducerLimit(); + pIndex = lvProducerIndex(); + // lower bit is indicative of resize, if we see it we spin until it's cleared + if ((pIndex & 1) == 1) { + continue; + } + // pIndex is even (lower bit is 0) -> actual index is (pIndex >> 1) + // NOTE: mask/buffer may get changed by resizing -> only use for array access after successful CAS. + // Only by virtue offloading them between the lvProducerIndex and a successful casProducerIndex are they + // safe to use. 
+ mask = this.producerMask; + buffer = this.producerBuffer; + // a successful CAS ties the ordering, lv(pIndex) -> [mask/buffer] -> cas(pIndex) + // we want 'limit' slots, but will settle for whatever is visible to 'producerLimit' + // -> producerLimit >= batchIndex + long batchIndex = Math.min(producerLimit, pIndex + 2l * limit); + if (pIndex >= producerLimit) { + int result = offerSlowPath(mask, pIndex, producerLimit); + switch(result) { + case CONTINUE_TO_P_INDEX_CAS: + // offer slow path verifies only one slot ahead, we cannot rely on indication here + case RETRY: + continue; + case QUEUE_FULL: + return 0; + case QUEUE_RESIZE: + resize(mask, buffer, pIndex, null, s); + return 1; + } + } + // claim limit slots at once + if (casProducerIndex(pIndex, batchIndex)) { + claimedSlots = (int) ((batchIndex - pIndex) / 2); + break; + } + } + for (int i = 0; i < claimedSlots; i++) { + final long offset = modifiedCalcCircularRefElementOffset(pIndex + 2l * i, mask); + soRefElement(buffer, offset, s.get()); + } + return claimedSlots; + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + @Override + public int drain(Consumer c) { + return drain(c, capacity()); + } + + @Override + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + /** + * Get an iterator for this queue. This method is thread safe. + *

+ * The iterator provides a best-effort snapshot of the elements in the queue. + * The returned iterator is not guaranteed to return elements in queue order, + * and races with the consumer thread may cause gaps in the sequence of returned elements. + * Like {link #relaxedPoll}, the iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + return new WeakIterator(consumerBuffer, lvConsumerIndex(), lvProducerIndex()); + } + + private static class WeakIterator implements Iterator { + + private final long pIndex; + + private long nextIndex; + + private E nextElement; + + private E[] currentBuffer; + + private int mask; + + WeakIterator(E[] currentBuffer, long cIndex, long pIndex) { + this.pIndex = pIndex >> 1; + this.nextIndex = cIndex >> 1; + setBuffer(currentBuffer); + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) { + throw new NoSuchElementException(); + } + nextElement = getNext(); + return e; + } + + private void setBuffer(E[] buffer) { + this.currentBuffer = buffer; + this.mask = length(buffer) - 2; + } + + private E getNext() { + while (nextIndex < pIndex) { + long index = nextIndex++; + E e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } + // not null && not JUMP -> found next element + if (e != JUMP) { + return e; + } + // need to jump to the next buffer + int nextBufferIndex = mask + 1; + Object nextBuffer = lvRefElement(currentBuffer, calcRefElementOffset(nextBufferIndex)); + if (nextBuffer == BUFFER_CONSUMED || nextBuffer == null) { + // Consumer may have passed us, or the next buffer is not visible yet: drop out early + return null; + } + 
setBuffer((E[]) nextBuffer); + // now with the new array retry the load, it can't be a JUMP, but we need to repeat same index + e = lvRefElement(currentBuffer, calcCircularRefElementOffset(index, mask)); + // skip removed/not yet visible elements + if (e == null) { + continue; + } else { + return e; + } + } + return null; + } + } + + private void resize(long oldMask, E[] oldBuffer, long pIndex, E e, Supplier s) { + assert (e != null && s == null) || (e == null || s != null); + int newBufferLength = getNextBufferSize(oldBuffer); + final E[] newBuffer; + try { + newBuffer = allocateRefArray(newBufferLength); + } catch (OutOfMemoryError oom) { + assert lvProducerIndex() == pIndex + 1; + soProducerIndex(pIndex); + throw oom; + } + producerBuffer = newBuffer; + final int newMask = (newBufferLength - 2) << 1; + producerMask = newMask; + final long offsetInOld = modifiedCalcCircularRefElementOffset(pIndex, oldMask); + final long offsetInNew = modifiedCalcCircularRefElementOffset(pIndex, newMask); + // element in new array + soRefElement(newBuffer, offsetInNew, e == null ? 
s.get() : e); + // buffer linked + soRefElement(oldBuffer, nextArrayOffset(oldMask), newBuffer); + // ASSERT code + final long cIndex = lvConsumerIndex(); + final long availableInQueue = availableInQueue(pIndex, cIndex); + RangeUtil.checkPositive(availableInQueue, "availableInQueue"); + // Invalidate racing CASs + // We never set the limit beyond the bounds of a buffer + soProducerLimit(pIndex + Math.min(newMask, availableInQueue)); + // make resize visible to the other producers + soProducerIndex(pIndex + 2); + // INDEX visible before ELEMENT, consistent with consumer expectation + // make resize visible to consumer + soRefElement(oldBuffer, offsetInOld, JUMP); + } + + /** + * @return next buffer size(inclusive of next array pointer) + */ + protected abstract int getNextBufferSize(E[] buffer); + + /** + * @return current buffer capacity for elements (excluding next pointer and jump entry) * 2 + */ + protected abstract long getCurrentBufferCapacity(long mask); +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseSpscLinkedUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseSpscLinkedUnpaddedArrayQueue.java new file mode 100644 index 0000000..173a33d --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/BaseSpscLinkedUnpaddedArrayQueue.java @@ -0,0 +1,354 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues.unpadded; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.PortableJvmInfo; +import java.util.AbstractQueue; +import java.util.Iterator; +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import static org.jctools.queues.LinkedArrayQueueUtil.nextArrayOffset; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeRefArrayAccess.*; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedUnpaddedArrayQueuePrePad extends AbstractQueue implements IndexedQueue { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedUnpaddedArrayQueueConsumerColdFields extends BaseSpscLinkedUnpaddedArrayQueuePrePad { + + protected long consumerMask; + + protected E[] consumerBuffer; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. 
+ */ +abstract class BaseSpscLinkedUnpaddedArrayQueueConsumerField extends BaseSpscLinkedUnpaddedArrayQueueConsumerColdFields { + + private final static long C_INDEX_OFFSET = fieldOffset(BaseSpscLinkedUnpaddedArrayQueueConsumerField.class, "consumerIndex"); + + private volatile long consumerIndex; + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final long lpConsumerIndex() { + return UNSAFE.getLong(this, C_INDEX_OFFSET); + } + + final void soConsumerIndex(long newValue) { + UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedUnpaddedArrayQueueL2Pad extends BaseSpscLinkedUnpaddedArrayQueueConsumerField { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedUnpaddedArrayQueueProducerFields extends BaseSpscLinkedUnpaddedArrayQueueL2Pad { + + private final static long P_INDEX_OFFSET = fieldOffset(BaseSpscLinkedUnpaddedArrayQueueProducerFields.class, "producerIndex"); + + private volatile long producerIndex; + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final void soProducerIndex(long newValue) { + UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue); + } + + final long lpProducerIndex() { + return UNSAFE.getLong(this, P_INDEX_OFFSET); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. 
+ */ +abstract class BaseSpscLinkedUnpaddedArrayQueueProducerColdFields extends BaseSpscLinkedUnpaddedArrayQueueProducerFields { + + protected long producerBufferLimit; + + // fixed for chunked and unbounded + protected long producerMask; + + protected E[] producerBuffer; +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is BaseSpscLinkedArrayQueue.java. + */ +abstract class BaseSpscLinkedUnpaddedArrayQueue extends BaseSpscLinkedUnpaddedArrayQueueProducerColdFields implements MessagePassingQueue, QueueProgressIndicators { + + private static final Object JUMP = new Object(); + + @Override + public final Iterator iterator() { + throw new UnsupportedOperationException(); + } + + @Override + public final int size() { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR); + } + + @Override + public final boolean isEmpty() { + return IndexedQueueSizeUtil.isEmpty(this); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex(); + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex(); + } + + protected final void soNext(E[] curr, E[] next) { + long offset = nextArrayOffset(curr); + soRefElement(curr, offset, next); + } + + @SuppressWarnings("unchecked") + protected final E[] lvNextArrayAndUnlink(E[] curr) { + final long offset = nextArrayOffset(curr); + final E[] nextBuffer = (E[]) lvRefElement(curr, offset); + // prevent GC nepotism + soRefElement(curr, offset, null); + return nextBuffer; + } + + @Override + public boolean relaxedOffer(E e) { + return offer(e); + } + + @Override + public E relaxedPoll() { + return poll(); + } + + @Override + public E relaxedPeek() { + return peek(); + } + + @Override + public int drain(Consumer c) { + return 
MessagePassingQueueUtil.drain(this, c); + } + + @Override + public int fill(Supplier s) { + // result is a long because we want to have a safepoint check at regular intervals + long result = 0; + final int capacity = capacity(); + do { + final int filled = fill(s, PortableJvmInfo.RECOMENDED_OFFER_BATCH); + if (filled == 0) { + return (int) result; + } + result += filled; + } while (result <= capacity); + return (int) result; + } + + @Override + public int drain(Consumer c, int limit) { + return MessagePassingQueueUtil.drain(this, c, limit); + } + + @Override + public int fill(Supplier s, int limit) { + if (null == s) + throw new IllegalArgumentException("supplier is null"); + if (limit < 0) + throw new IllegalArgumentException("limit is negative:" + limit); + if (limit == 0) + return 0; + for (int i = 0; i < limit; i++) { + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = producerBuffer; + final long index = lpProducerIndex(); + final long mask = producerMask; + final long offset = calcCircularRefElementOffset(index, mask); + // expected hot path + if (index < producerBufferLimit) { + writeToQueue(buffer, s.get(), index, offset); + } else { + if (!offerColdPath(buffer, mask, index, offset, null, s)) { + return i; + } + } + } + return limit; + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.drain(this, c, wait, exit); + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + MessagePassingQueueUtil.fill(this, s, wait, exit); + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single producer thread use only. + */ + @Override + public boolean offer(final E e) { + // Objects.requireNonNull(e); + if (null == e) { + throw new NullPointerException(); + } + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = producerBuffer; + final long index = lpProducerIndex(); + final long mask = producerMask; + final long offset = calcCircularRefElementOffset(index, mask); + // expected hot path + if (index < producerBufferLimit) { + writeToQueue(buffer, e, index, offset); + return true; + } + return offerColdPath(buffer, mask, index, offset, e, null); + } + + abstract boolean offerColdPath(E[] buffer, long mask, long pIndex, long offset, E v, Supplier s); + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E poll() { + // local load of field to avoid repeated loads after volatile reads + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = calcCircularRefElementOffset(index, mask); + final Object e = lvRefElement(buffer, offset); + boolean isNextBuffer = e == JUMP; + if (null != e && !isNextBuffer) { + // this ensures correctness on 32bit platforms + soConsumerIndex(index + 1); + soRefElement(buffer, offset, null); + return (E) e; + } else if (isNextBuffer) { + return newBufferPoll(buffer, index); + } + return null; + } + + /** + * {@inheritDoc} + *

+ * This implementation is correct for single consumer thread use only. + */ + @SuppressWarnings("unchecked") + @Override + public E peek() { + final E[] buffer = consumerBuffer; + final long index = lpConsumerIndex(); + final long mask = consumerMask; + final long offset = calcCircularRefElementOffset(index, mask); + final Object e = lvRefElement(buffer, offset); + if (e == JUMP) { + return newBufferPeek(buffer, index); + } + return (E) e; + } + + final void linkOldToNew(final long currIndex, final E[] oldBuffer, final long offset, final E[] newBuffer, final long offsetInNew, final E e) { + soRefElement(newBuffer, offsetInNew, e); + // link to next buffer and add next indicator as element of old buffer + soNext(oldBuffer, newBuffer); + soRefElement(oldBuffer, offset, JUMP); + // index is visible after elements (isEmpty/poll ordering) + // this ensures atomic write of long on 32bit platforms + soProducerIndex(currIndex + 1); + } + + final void writeToQueue(final E[] buffer, final E e, final long index, final long offset) { + soRefElement(buffer, offset, e); + // this ensures atomic write of long on 32bit platforms + soProducerIndex(index + 1); + } + + private E newBufferPeek(final E[] buffer, final long index) { + E[] nextBuffer = lvNextArrayAndUnlink(buffer); + consumerBuffer = nextBuffer; + final long mask = length(nextBuffer) - 2; + consumerMask = mask; + final long offset = calcCircularRefElementOffset(index, mask); + return lvRefElement(nextBuffer, offset); + } + + private E newBufferPoll(final E[] buffer, final long index) { + E[] nextBuffer = lvNextArrayAndUnlink(buffer); + consumerBuffer = nextBuffer; + final long mask = length(nextBuffer) - 2; + consumerMask = mask; + final long offset = calcCircularRefElementOffset(index, mask); + final E n = lvRefElement(nextBuffer, offset); + if (null == n) { + throw new IllegalStateException("new buffer must have at least one element"); + } else { + // this ensures correctness on 32bit platforms + soConsumerIndex(index 
+ 1); + soRefElement(nextBuffer, offset, null); + return n; + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentCircularUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentCircularUnpaddedArrayQueue.java new file mode 100644 index 0000000..35d3bd1 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentCircularUnpaddedArrayQueue.java @@ -0,0 +1,154 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.util.Pow2; +import java.util.AbstractQueue; +import java.util.Iterator; +import java.util.NoSuchElementException; +import static org.jctools.util.UnsafeRefArrayAccess.*; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is ConcurrentCircularArrayQueue.java. + */ +abstract class ConcurrentCircularUnpaddedArrayQueueL0Pad extends AbstractQueue { +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is ConcurrentCircularArrayQueue.java. + * + * Common functionality for array backed queues. 
The class is pre-padded and the array is padded on either side to help + * with False Sharing prevention. It is expected that subclasses handle post padding. + */ +abstract class ConcurrentCircularUnpaddedArrayQueue extends ConcurrentCircularUnpaddedArrayQueueL0Pad implements MessagePassingQueue, IndexedQueue, QueueProgressIndicators, SupportsIterator { + + protected final long mask; + + protected final E[] buffer; + + ConcurrentCircularUnpaddedArrayQueue(int capacity) { + int actualCapacity = Pow2.roundToPowerOfTwo(capacity); + mask = actualCapacity - 1; + buffer = allocateRefArray(actualCapacity); + } + + @Override + public int size() { + return IndexedQueueSizeUtil.size(this, IndexedQueueSizeUtil.PLAIN_DIVISOR); + } + + @Override + public boolean isEmpty() { + return IndexedQueueSizeUtil.isEmpty(this); + } + + @Override + public String toString() { + return this.getClass().getName(); + } + + @Override + public void clear() { + while (poll() != null) { + // if you stare into the void + } + } + + @Override + public int capacity() { + return (int) (mask + 1); + } + + @Override + public long currentProducerIndex() { + return lvProducerIndex(); + } + + @Override + public long currentConsumerIndex() { + return lvConsumerIndex(); + } + + /** + * Get an iterator for this queue. This method is thread safe. + *

+ * The iterator provides a best-effort snapshot of the elements in the queue. + * The returned iterator is not guaranteed to return elements in queue order, + * and races with the consumer thread may cause gaps in the sequence of returned elements. + * Like {link #relaxedPoll}, the iterator may not immediately return newly inserted elements. + * + * @return The iterator. + */ + @Override + public Iterator iterator() { + final long cIndex = lvConsumerIndex(); + final long pIndex = lvProducerIndex(); + return new WeakIterator(cIndex, pIndex, mask, buffer); + } + + private static class WeakIterator implements Iterator { + + private final long pIndex; + + private final long mask; + + private final E[] buffer; + + private long nextIndex; + + private E nextElement; + + WeakIterator(long cIndex, long pIndex, long mask, E[] buffer) { + this.nextIndex = cIndex; + this.pIndex = pIndex; + this.mask = mask; + this.buffer = buffer; + nextElement = getNext(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + + @Override + public boolean hasNext() { + return nextElement != null; + } + + @Override + public E next() { + final E e = nextElement; + if (e == null) + throw new NoSuchElementException(); + nextElement = getNext(); + return e; + } + + private E getNext() { + while (nextIndex < pIndex) { + long offset = calcCircularRefElementOffset(nextIndex++, mask); + E e = lvRefElement(buffer, offset); + if (e != null) { + return e; + } + } + return null; + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentSequencedCircularUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentSequencedCircularUnpaddedArrayQueue.java new file mode 100644 index 0000000..7de58dd --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/ConcurrentSequencedCircularUnpaddedArrayQueue.java @@ -0,0 +1,36 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
"License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import static org.jctools.util.UnsafeLongArrayAccess.*; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is ConcurrentSequencedCircularArrayQueue.java. + */ +public abstract class ConcurrentSequencedCircularUnpaddedArrayQueue extends ConcurrentCircularUnpaddedArrayQueue { + + protected final long[] sequenceBuffer; + + public ConcurrentSequencedCircularUnpaddedArrayQueue(int capacity) { + super(capacity); + int actualCapacity = (int) (this.mask + 1); + // pad data on either end with some empty slots. Note that actualCapacity is <= MAX_POW2_INT + sequenceBuffer = allocateLongArray(actualCapacity); + for (long i = 0; i < actualCapacity; i++) { + soLongElement(sequenceBuffer, calcCircularLongElementOffset(i, mask), i); + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpmcUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpmcUnpaddedArrayQueue.java new file mode 100644 index 0000000..7d54dbb --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpmcUnpaddedArrayQueue.java @@ -0,0 +1,512 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import org.jctools.util.RangeUtil; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import static org.jctools.util.UnsafeAccess.fieldOffset; +import static org.jctools.util.UnsafeLongArrayAccess.*; +import static org.jctools.util.UnsafeRefArrayAccess.*; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + */ +abstract class MpmcUnpaddedArrayQueueL1Pad extends ConcurrentSequencedCircularUnpaddedArrayQueue { + + MpmcUnpaddedArrayQueueL1Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. 
+ */ +abstract class MpmcUnpaddedArrayQueueProducerIndexField extends MpmcUnpaddedArrayQueueL1Pad { + + private final static long P_INDEX_OFFSET = fieldOffset(MpmcUnpaddedArrayQueueProducerIndexField.class, "producerIndex"); + + private volatile long producerIndex; + + MpmcUnpaddedArrayQueueProducerIndexField(int capacity) { + super(capacity); + } + + @Override + public final long lvProducerIndex() { + return producerIndex; + } + + final boolean casProducerIndex(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + */ +abstract class MpmcUnpaddedArrayQueueL2Pad extends MpmcUnpaddedArrayQueueProducerIndexField { + + MpmcUnpaddedArrayQueueL2Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + */ +abstract class MpmcUnpaddedArrayQueueConsumerIndexField extends MpmcUnpaddedArrayQueueL2Pad { + + private final static long C_INDEX_OFFSET = fieldOffset(MpmcUnpaddedArrayQueueConsumerIndexField.class, "consumerIndex"); + + private volatile long consumerIndex; + + MpmcUnpaddedArrayQueueConsumerIndexField(int capacity) { + super(capacity); + } + + @Override + public final long lvConsumerIndex() { + return consumerIndex; + } + + final boolean casConsumerIndex(long expect, long newValue) { + return UNSAFE.compareAndSwapLong(this, C_INDEX_OFFSET, expect, newValue); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. 
+ */ +abstract class MpmcUnpaddedArrayQueueL3Pad extends MpmcUnpaddedArrayQueueConsumerIndexField { + + MpmcUnpaddedArrayQueueL3Pad(int capacity) { + super(capacity); + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpmcArrayQueue.java. + * + * A Multi-Producer-Multi-Consumer queue based on a {@link org.jctools.queues.ConcurrentCircularArrayQueue}. This + * implies that any and all threads may call the offer/poll/peek methods and correctness is maintained.
+ * This implementation follows patterns documented on the package level for False Sharing protection.
+ * The algorithm for offer/poll is an adaptation of the one put forward by D. Vyukov (See here). The original + * algorithm uses an array of structs which should offer nice locality properties but is sadly not possible in + * Java (waiting on Value Types or similar). The alternative explored here utilizes 2 arrays, one for each + * field of the struct. There is a further alternative in the experimental project which uses iteration phase + * markers to achieve the same algo and is closer structurally to the original, but sadly does not perform as + * well as this implementation.
+ *

+ * Tradeoffs to keep in mind: + *

    + *
  1. Padding for false sharing: counter fields and queue fields are all padded as well as either side of + * both arrays. We are trading memory to avoid false sharing(active and passive). + *
  2. 2 arrays instead of one: The algorithm requires an extra array of longs matching the size of the + * elements array. This is doubling/tripling the memory allocated for the buffer. + *
  3. Power of 2 capacity: Actual elements buffer (and sequence buffer) is the closest power of 2 larger or + * equal to the requested capacity. + *
 */
public class MpmcUnpaddedArrayQueue<E> extends MpmcUnpaddedArrayQueueL3Pad<E> {

    /**
     * Upper bound on the batch ("look ahead") step used by {@link #drain(Consumer, int)} and
     * {@link #fill(Supplier, int)}; tunable via the {@code jctools.mpmc.max.lookahead.step} system property.
     */
    public static final int MAX_LOOK_AHEAD_STEP = Integer.getInteger("jctools.mpmc.max.lookahead.step", 4096);

    // Batch step for drain/fill: capacity/4 clamped to [2, MAX_LOOK_AHEAD_STEP].
    private final int lookAheadStep;

    /**
     * @param capacity requested capacity; must be at least 2 (checked via RangeUtil).
     */
    public MpmcUnpaddedArrayQueue(final int capacity) {
        super(RangeUtil.checkGreaterThanOrEqual(capacity, 2, "capacity"));
        lookAheadStep = Math.max(2, Math.min(capacity() / 4, MAX_LOOK_AHEAD_STEP));
    }

    /**
     * Lock-free offer using the per-slot sequence array: a producer claims slot {@code pIndex} by CAS-ing the
     * producer index, then publishes the element and bumps the slot sequence to {@code pIndex + 1}.
     *
     * @param e element to add, never null
     * @return true if added, false only if the queue is full (strict Queue.offer contract)
     * @throws NullPointerException if e is null
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final long mask = this.mask;
        final long capacity = mask + 1;
        final long[] sBuffer = sequenceBuffer;
        long pIndex;
        long seqOffset;
        long seq;
        // start with bogus value, hope we don't need it
        long cIndex = Long.MIN_VALUE;
        do {
            pIndex = lvProducerIndex();
            seqOffset = calcCircularLongElementOffset(pIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            // consumer has not moved this seq forward, it's as last producer left
            if (seq < pIndex) {
                // Extra check required to ensure [Queue.offer == false iff queue is full]
                if (// test against cached cIndex
                    pIndex - capacity >= cIndex &&
                    // test against latest cIndex
                    pIndex - capacity >= (cIndex = lvConsumerIndex())) {
                    return false;
                } else {
                    // (+) hack to make it go around again without CAS
                    seq = pIndex + 1;
                }
            }
        } while (// another producer has moved the sequence(or +)
                seq > pIndex ||
                // failed to increment
                !casProducerIndex(pIndex, pIndex + 1));
        // casProducerIndex ensures correct construction
        spRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), e);
        // seq++;
        soLongElement(sBuffer, seqOffset, pIndex + 1);
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * Because return null indicates queue is empty we cannot simply rely on next element visibility for poll
     * and must test producer index when next element is not visible.
     */
    @Override
    public E poll() {
        // local load of field to avoid repeated loads after volatile reads
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        long cIndex;
        long seq;
        long seqOffset;
        long expectedSeq;
        // start with bogus value, hope we don't need it
        long pIndex = -1;
        do {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                // slot has not been moved by producer
                if (// test against cached pIndex
                    cIndex >= pIndex &&
                    // update pIndex if we must
                    cIndex == (pIndex = lvProducerIndex())) {
                    // strict empty check, this ensures [Queue.poll() == null iff isEmpty()]
                    return null;
                } else {
                    // trip another go around
                    seq = expectedSeq + 1;
                }
            }
        } while (// another consumer beat us to it
                seq > expectedSeq ||
                // failed the CAS
                !casConsumerIndex(cIndex, cIndex + 1));
        final long offset = calcCircularRefElementOffset(cIndex, mask);
        final E e = lpRefElement(buffer, offset);
        spRefElement(buffer, offset, null);
        // i.e. seq += capacity — marks the slot reusable one full lap later
        soLongElement(sBuffer, seqOffset, cIndex + mask + 1);
        return e;
    }

    /**
     * Reads the element at the consumer index without removing it. Re-reads the consumer index after the
     * element load and retries if another consumer moved it, so the returned element matches a stable index.
     *
     * @return the head element, or null if the queue is empty (strict check against the producer index)
     */
    @Override
    public E peek() {
        // local load of field to avoid repeated loads after volatile reads
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        long cIndex;
        long seq;
        long seqOffset;
        long expectedSeq;
        // start with bogus value, hope we don't need it
        long pIndex = -1;
        E e;
        while (true) {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                // slot has not been moved by producer
                if (// test against cached pIndex
                    cIndex >= pIndex &&
                    // update pIndex if we must
                    cIndex == (pIndex = lvProducerIndex())) {
                    // strict empty check, this ensures [Queue.poll() == null iff isEmpty()]
                    return null;
                }
            } else if (seq == expectedSeq) {
                final long offset = calcCircularRefElementOffset(cIndex, mask);
                e = lvRefElement(buffer, offset);
                // only return if the consumer index did not move under us
                if (lvConsumerIndex() == cIndex)
                    return e;
            }
        }
    }

    /**
     * Like {@link #offer} but gives up as soon as the claimed slot looks occupied — no strict full check
     * against the latest consumer index.
     *
     * @param e element to add, never null
     * @return true if added, false if the slot was not yet cleared by a consumer
     */
    @Override
    public boolean relaxedOffer(E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final long mask = this.mask;
        final long[] sBuffer = sequenceBuffer;
        long pIndex;
        long seqOffset;
        long seq;
        do {
            pIndex = lvProducerIndex();
            seqOffset = calcCircularLongElementOffset(pIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            if (seq < pIndex) {
                // slot not cleared by consumer yet
                return false;
            }
        } while (// another producer has moved the sequence
                seq > pIndex ||
                // failed to increment
                !casProducerIndex(pIndex, pIndex + 1));
        // casProducerIndex ensures correct construction
        spRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), e);
        soLongElement(sBuffer, seqOffset, pIndex + 1);
        return true;
    }

    /**
     * Like {@link #poll} but returns null as soon as the head slot looks unpublished — no strict empty check
     * against the producer index.
     */
    @Override
    public E relaxedPoll() {
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        long cIndex;
        long seqOffset;
        long seq;
        long expectedSeq;
        do {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                return null;
            }
        } while (// another consumer beat us to it
                seq > expectedSeq ||
                // failed the CAS
                !casConsumerIndex(cIndex, cIndex + 1));
        final long offset = calcCircularRefElementOffset(cIndex, mask);
        final E e = lpRefElement(buffer, offset);
        spRefElement(buffer, offset, null);
        soLongElement(sBuffer, seqOffset, cIndex + mask + 1);
        return e;
    }

    /**
     * Like {@link #peek} but returns null as soon as the head slot looks unpublished — no strict empty check.
     */
    @Override
    public E relaxedPeek() {
        // local load of field to avoid repeated loads after volatile reads
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        long cIndex;
        long seq;
        long seqOffset;
        long expectedSeq;
        E e;
        do {
            cIndex = lvConsumerIndex();
            seqOffset = calcCircularLongElementOffset(cIndex, mask);
            seq = lvLongElement(sBuffer, seqOffset);
            expectedSeq = cIndex + 1;
            if (seq < expectedSeq) {
                return null;
            } else if (seq == expectedSeq) {
                final long offset = calcCircularRefElementOffset(cIndex, mask);
                e = lvRefElement(buffer, offset);
                // only return if the consumer index did not move under us
                if (lvConsumerIndex() == cIndex)
                    return e;
            }
        } while (true);
    }

    /**
     * Batch drain: claims up to {@code lookAheadStep} slots with a single consumer-index CAS, then spins per
     * slot for element visibility. Falls back to {@link #drainOneByOne} when the batch claim fails.
     *
     * @param c consumer invoked for each drained element, never null
     * @param limit maximum number of elements to drain; must be non-negative
     * @return number of elements drained
     */
    @Override
    public int drain(Consumer<E> c, int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        final E[] buffer = this.buffer;
        final int maxLookAheadStep = Math.min(this.lookAheadStep, limit);
        int consumed = 0;
        while (consumed < limit) {
            final int remaining = limit - consumed;
            final int lookAheadStep = Math.min(remaining, maxLookAheadStep);
            final long cIndex = lvConsumerIndex();
            final long lookAheadIndex = cIndex + lookAheadStep - 1;
            final long lookAheadSeqOffset = calcCircularLongElementOffset(lookAheadIndex, mask);
            final long lookAheadSeq = lvLongElement(sBuffer, lookAheadSeqOffset);
            final long expectedLookAheadSeq = lookAheadIndex + 1;
            if (lookAheadSeq == expectedLookAheadSeq && casConsumerIndex(cIndex, expectedLookAheadSeq)) {
                for (int i = 0; i < lookAheadStep; i++) {
                    final long index = cIndex + i;
                    final long seqOffset = calcCircularLongElementOffset(index, mask);
                    final long offset = calcCircularRefElementOffset(index, mask);
                    final long expectedSeq = index + 1;
                    // spin until the producer has published this slot
                    while (lvLongElement(sBuffer, seqOffset) != expectedSeq) {
                    }
                    final E e = lpRefElement(buffer, offset);
                    spRefElement(buffer, offset, null);
                    soLongElement(sBuffer, seqOffset, index + mask + 1);
                    c.accept(e);
                }
                consumed += lookAheadStep;
            } else {
                if (lookAheadSeq < expectedLookAheadSeq) {
                    // head slot itself not published yet -> nothing more to drain
                    if (notAvailable(cIndex, mask, sBuffer, cIndex + 1)) {
                        return consumed;
                    }
                }
                return consumed + drainOneByOne(c, remaining);
            }
        }
        return limit;
    }

    // Slow path for drain: claims one slot at a time, stopping at the first unpublished slot.
    private int drainOneByOne(Consumer<E> c, int limit) {
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        final E[] buffer = this.buffer;
        long cIndex;
        long seqOffset;
        long seq;
        long expectedSeq;
        for (int i = 0; i < limit; i++) {
            do {
                cIndex = lvConsumerIndex();
                seqOffset = calcCircularLongElementOffset(cIndex, mask);
                seq = lvLongElement(sBuffer, seqOffset);
                expectedSeq = cIndex + 1;
                if (seq < expectedSeq) {
                    return i;
                }
            } while (// another consumer beat us to it
                    seq > expectedSeq ||
                    // failed the CAS
                    !casConsumerIndex(cIndex, cIndex + 1));
            final long offset = calcCircularRefElementOffset(cIndex, mask);
            final E e = lpRefElement(buffer, offset);
            spRefElement(buffer, offset, null);
            soLongElement(sBuffer, seqOffset, cIndex + mask + 1);
            c.accept(e);
        }
        return limit;
    }

    /**
     * Batch fill: claims up to {@code lookAheadStep} slots with a single producer-index CAS, then spins per
     * slot until the consumer has cleared it. Falls back to {@link #fillOneByOne} when the batch claim fails.
     *
     * @param s supplier of elements, never null
     * @param limit maximum number of elements to fill; must be non-negative
     * @return number of elements filled
     */
    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        final E[] buffer = this.buffer;
        final int maxLookAheadStep = Math.min(this.lookAheadStep, limit);
        int produced = 0;
        while (produced < limit) {
            final int remaining = limit - produced;
            final int lookAheadStep = Math.min(remaining, maxLookAheadStep);
            final long pIndex = lvProducerIndex();
            final long lookAheadIndex = pIndex + lookAheadStep - 1;
            final long lookAheadSeqOffset = calcCircularLongElementOffset(lookAheadIndex, mask);
            final long lookAheadSeq = lvLongElement(sBuffer, lookAheadSeqOffset);
            final long expectedLookAheadSeq = lookAheadIndex;
            if (lookAheadSeq == expectedLookAheadSeq && casProducerIndex(pIndex, expectedLookAheadSeq + 1)) {
                for (int i = 0; i < lookAheadStep; i++) {
                    final long index = pIndex + i;
                    final long seqOffset = calcCircularLongElementOffset(index, mask);
                    final long offset = calcCircularRefElementOffset(index, mask);
                    // spin until the consumer has released this slot
                    while (lvLongElement(sBuffer, seqOffset) != index) {
                    }
                    // Ordered store ensures correct construction
                    soRefElement(buffer, offset, s.get());
                    soLongElement(sBuffer, seqOffset, index + 1);
                }
                produced += lookAheadStep;
            } else {
                if (lookAheadSeq < expectedLookAheadSeq) {
                    // head slot itself not released yet -> queue full for now
                    if (notAvailable(pIndex, mask, sBuffer, pIndex)) {
                        return produced;
                    }
                }
                return produced + fillOneByOne(s, remaining);
            }
        }
        return limit;
    }

    // True when the slot at index has not yet reached expectedSeq, i.e. it is not available to this side.
    private boolean notAvailable(long index, long mask, long[] sBuffer, long expectedSeq) {
        final long seqOffset = calcCircularLongElementOffset(index, mask);
        final long seq = lvLongElement(sBuffer, seqOffset);
        if (seq < expectedSeq) {
            return true;
        }
        return false;
    }

    // Slow path for fill: claims one slot at a time, stopping at the first slot not yet cleared by a consumer.
    private int fillOneByOne(Supplier<E> s, int limit) {
        final long[] sBuffer = sequenceBuffer;
        final long mask = this.mask;
        final E[] buffer = this.buffer;
        long pIndex;
        long seqOffset;
        long seq;
        for (int i = 0; i < limit; i++) {
            do {
                pIndex = lvProducerIndex();
                seqOffset = calcCircularLongElementOffset(pIndex, mask);
                seq = lvLongElement(sBuffer, seqOffset);
                if (seq < pIndex) {
                    // slot not cleared by consumer yet
                    return i;
                }
            } while (// another producer has moved the sequence
                    seq > pIndex ||
                    // failed to increment
                    !casProducerIndex(pIndex, pIndex + 1));
            // Ordered store ensures correct construction
            soRefElement(buffer, calcCircularRefElementOffset(pIndex, mask), s.get());
            soLongElement(sBuffer, seqOffset, pIndex + 1);
        }
        return limit;
    }

    @Override
    public int drain(Consumer<E> c) {
        return MessagePassingQueueUtil.drain(this, c);
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillBounded(this, s);
    }

    @Override
    public void drain(Consumer<E> c, WaitStrategy w, ExitCondition exit) {
        MessagePassingQueueUtil.drain(this, c, w, exit);
    }

    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }
}
+ */ +package org.jctools.queues.unpadded; + +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; +import static java.lang.Math.max; +import static java.lang.Math.min; +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import static org.jctools.util.Pow2.roundToPowerOfTwo; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscChunkedArrayQueue.java. + */ +abstract class MpscChunkedUnpaddedArrayQueueColdProducerFields extends BaseMpscLinkedUnpaddedArrayQueue { + + protected final long maxQueueCapacity; + + MpscChunkedUnpaddedArrayQueueColdProducerFields(int initialCapacity, int maxCapacity) { + super(initialCapacity); + RangeUtil.checkGreaterThanOrEqual(maxCapacity, 4, "maxCapacity"); + RangeUtil.checkLessThan(roundToPowerOfTwo(initialCapacity), roundToPowerOfTwo(maxCapacity), "initialCapacity"); + maxQueueCapacity = ((long) Pow2.roundToPowerOfTwo(maxCapacity)) << 1; + } +} + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscChunkedArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks + * of the initial size. The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. + */ +public class MpscChunkedUnpaddedArrayQueue extends MpscChunkedUnpaddedArrayQueueColdProducerFields { + + public MpscChunkedUnpaddedArrayQueue(int maxCapacity) { + super(max(2, min(1024, roundToPowerOfTwo(maxCapacity / 8))), maxCapacity); + } + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. 
+ * Must be 2 or more. + * @param maxCapacity the maximum capacity will be rounded up to the closest power of 2 and will be the + * upper limit of number of elements in this queue. Must be 4 or more and round up to a larger + * power of 2 than initialCapacity. + */ + public MpscChunkedUnpaddedArrayQueue(int initialCapacity, int maxCapacity) { + super(initialCapacity, maxCapacity); + } + + @Override + protected long availableInQueue(long pIndex, long cIndex) { + return maxQueueCapacity - (pIndex - cIndex); + } + + @Override + public int capacity() { + return (int) (maxQueueCapacity / 2); + } + + @Override + protected int getNextBufferSize(E[] buffer) { + return length(buffer); + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return mask; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscGrowableUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscGrowableUnpaddedArrayQueue.java new file mode 100644 index 0000000..f8ac608 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscGrowableUnpaddedArrayQueue.java @@ -0,0 +1,59 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues.unpadded; + +import org.jctools.util.Pow2; +import org.jctools.util.RangeUtil; +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscGrowableArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks, + * doubling theirs size every time until the full blown backing array is used. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. + */ +public class MpscGrowableUnpaddedArrayQueue extends MpscChunkedUnpaddedArrayQueue { + + public MpscGrowableUnpaddedArrayQueue(int maxCapacity) { + super(Math.max(2, Pow2.roundToPowerOfTwo(maxCapacity / 8)), maxCapacity); + } + + /** + * @param initialCapacity the queue initial capacity. If chunk size is fixed this will be the chunk size. + * Must be 2 or more. + * @param maxCapacity the maximum capacity will be rounded up to the closest power of 2 and will be the + * upper limit of number of elements in this queue. Must be 4 or more and round up to a larger + * power of 2 than initialCapacity. + */ + public MpscGrowableUnpaddedArrayQueue(int initialCapacity, int maxCapacity) { + super(initialCapacity, maxCapacity); + } + + @Override + protected int getNextBufferSize(E[] buffer) { + final long maxSize = maxQueueCapacity / 2; + RangeUtil.checkLessThanOrEqual(length(buffer), maxSize, "buffer.length"); + final int newSize = 2 * (length(buffer) - 1); + return newSize + 1; + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return (mask + 2 == maxQueueCapacity) ? 
maxQueueCapacity : mask; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscLinkedUnpaddedQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscLinkedUnpaddedQueue.java new file mode 100644 index 0000000..1dbb816 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscLinkedUnpaddedQueue.java @@ -0,0 +1,171 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import org.jctools.util.UnsafeAccess; +import java.util.Queue; +import static org.jctools.util.UnsafeAccess.UNSAFE; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscLinkedQueue.java. + * + * This is a Java port of the MPSC algorithm as presented + * on + * 1024 Cores by D. Vyukov. The original has been adapted to Java and it's quirks with regards to memory + * model and layout: + *

 *   <ol>
 *   <li>Use inheritance to ensure no false sharing occurs between producer/consumer node reference fields.
 *   <li>Use XCHG functionality to the best of the JDK ability (see differences in JDK7/8 impls).
 *   <li>Conform to {@link java.util.Queue} contract on poll. The original semantics are available via relaxedPoll.
 *   </ol>
 * The queue is initialized with a stub node which is set to both the producer and consumer node references.
 * From this point follow the notes on offer/poll.
 */
public class MpscLinkedUnpaddedQueue<E> extends BaseLinkedUnpaddedQueue<E> {

    public MpscLinkedUnpaddedQueue() {
        // Stub node shared by both ends until the first offer.
        LinkedQueueNode<E> node = newNode();
        spConsumerNode(node);
        xchgProducerNode(node);
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Offer is allowed from multiple threads.<br>
     * Offer allocates a new node and:
     * <ol>
     * <li>Swaps it atomically with current producer node (only one producer 'wins')
     * <li>Sets the new node as the node following from the swapped producer node
     * </ol>
     * This works because each producer is guaranteed to 'plant' a new node and link the old node. No 2
     * producers can get the same producer node as part of XCHG guarantee.
     *
     * @see MessagePassingQueue#offer(Object)
     * @see java.util.Queue#offer(java.lang.Object)
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final LinkedQueueNode<E> nextNode = newNode(e);
        final LinkedQueueNode<E> prevProducerNode = xchgProducerNode(nextNode);
        // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed
        // and completes the store in prev.next. This is a "bubble".
        prevProducerNode.soNext(nextNode);
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * This method is only safe to call from the (single) consumer thread, and is subject to best effort when racing
     * with producers. This method is potentially blocking when "bubble"s in the queue are visible.
     */
    @Override
    public boolean remove(Object o) {
        if (null == o) {
            // Null elements are not permitted, so null will never be removed.
            return false;
        }
        final LinkedQueueNode<E> originalConsumerNode = lpConsumerNode();
        LinkedQueueNode<E> prevConsumerNode = originalConsumerNode;
        LinkedQueueNode<E> currConsumerNode = getNextConsumerNode(originalConsumerNode);
        while (currConsumerNode != null) {
            if (o.equals(currConsumerNode.lpValue())) {
                LinkedQueueNode<E> nextNode = getNextConsumerNode(currConsumerNode);
                // e.g.: consumerNode -> node0 -> node1(o==v) -> node2 ... => consumerNode -> node0 -> node2
                if (nextNode != null) {
                    // We are removing an interior node.
                    prevConsumerNode.soNext(nextNode);
                } else // This case reflects: prevConsumerNode != originalConsumerNode && nextNode == null
                // At rest, this would be the producerNode, but we must contend with racing. Changes to subclassed
                // queues need to consider remove() when implementing offer().
                {
                    // producerNode is currConsumerNode, try to atomically update the reference to move it to the
                    // previous node.
                    prevConsumerNode.soNext(null);
                    if (!casProducerNode(currConsumerNode, prevConsumerNode)) {
                        // If the producer(s) have offered more items we need to remove the currConsumerNode link.
                        nextNode = spinWaitForNextNode(currConsumerNode);
                        prevConsumerNode.soNext(nextNode);
                    }
                }
                // Avoid GC nepotism because we are discarding the current node.
                currConsumerNode.soNext(null);
                currConsumerNode.spValue(null);
                return true;
            }
            prevConsumerNode = currConsumerNode;
            currConsumerNode = getNextConsumerNode(currConsumerNode);
        }
        return false;
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillUnbounded(this, s);
    }

    /**
     * Builds a private chain of {@code limit} nodes, then splices it in with a single XCHG of the producer node.
     *
     * @param s supplier of elements, never null
     * @param limit number of elements to fill; must be non-negative
     * @return the number of elements filled (always {@code limit} for non-zero limits — the queue is unbounded)
     */
    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        LinkedQueueNode<E> tail = newNode(s.get());
        final LinkedQueueNode<E> head = tail;
        for (int i = 1; i < limit; i++) {
            final LinkedQueueNode<E> temp = newNode(s.get());
            // spNext: xchgProducerNode ensures correct construction
            tail.spNext(temp);
            tail = temp;
        }
        final LinkedQueueNode<E> oldPNode = xchgProducerNode(tail);
        oldPNode.soNext(head);
        return limit;
    }

    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }

    // $gen:ignore
    // Atomically swaps the producer node: native getAndSet when available, otherwise a CAS loop.
    private LinkedQueueNode<E> xchgProducerNode(LinkedQueueNode<E> newVal) {
        if (UnsafeAccess.SUPPORTS_GET_AND_SET_REF) {
            return (LinkedQueueNode<E>) UNSAFE.getAndSetObject(this, P_NODE_OFFSET, newVal);
        } else {
            LinkedQueueNode<E> oldVal;
            do {
                oldVal = lvProducerNode();
            } while (!UNSAFE.compareAndSwapObject(this, P_NODE_OFFSET, oldVal, newVal));
            return oldVal;
        }
    }

    // Next node after currConsumerNode; spins across a producer "bubble" when one is visible.
    private LinkedQueueNode<E> getNextConsumerNode(LinkedQueueNode<E> currConsumerNode) {
        LinkedQueueNode<E> nextNode = currConsumerNode.lvNext();
        if (nextNode == null && currConsumerNode != lvProducerNode()) {
            nextNode = spinWaitForNextNode(currConsumerNode);
        }
        return nextNode;
    }
}
--- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscUnboundedUnpaddedArrayQueue.java @@ -0,0 +1,62 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.unpadded; + +import static org.jctools.queues.LinkedArrayQueueUtil.length; +import org.jctools.queues.*; + +/** + * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator + * which can found in the jctools-build module. The original source file is MpscUnboundedArrayQueue.java. + * + * An MPSC array queue which starts at initialCapacity and grows indefinitely in linked chunks of the initial size. + * The queue grows only when the current chunk is full and elements are not copied on + * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow. 
+ */ +public class MpscUnboundedUnpaddedArrayQueue extends BaseMpscLinkedUnpaddedArrayQueue { + + public MpscUnboundedUnpaddedArrayQueue(int chunkSize) { + super(chunkSize); + } + + @Override + protected long availableInQueue(long pIndex, long cIndex) { + return Integer.MAX_VALUE; + } + + @Override + public int capacity() { + return MessagePassingQueue.UNBOUNDED_CAPACITY; + } + + @Override + public int drain(Consumer c) { + return drain(c, 4096); + } + + @Override + public int fill(Supplier s) { + return MessagePassingQueueUtil.fillUnbounded(this, s); + } + + @Override + protected int getNextBufferSize(E[] buffer) { + return length(buffer); + } + + @Override + protected long getCurrentBufferCapacity(long mask) { + return mask; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscUnpaddedArrayQueue.java new file mode 100644 index 0000000..a12903c --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/MpscUnpaddedArrayQueue.java @@ -0,0 +1,469 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
 */
package org.jctools.queues.unpadded;

import static org.jctools.util.UnsafeAccess.UNSAFE;
import static org.jctools.util.UnsafeAccess.fieldOffset;
import static org.jctools.util.UnsafeRefArrayAccess.*;
import org.jctools.queues.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueL1Pad<E> extends ConcurrentCircularUnpaddedArrayQueue<E> {

    MpscUnpaddedArrayQueueL1Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueProducerIndexField<E> extends MpscUnpaddedArrayQueueL1Pad<E> {

    // Unsafe offset of the producerIndex field, used for the CAS below.
    private final static long P_INDEX_OFFSET = fieldOffset(MpscUnpaddedArrayQueueProducerIndexField.class, "producerIndex");

    private volatile long producerIndex;

    MpscUnpaddedArrayQueueProducerIndexField(int capacity) {
        super(capacity);
    }

    /** Volatile load of the producer index. */
    @Override
    public final long lvProducerIndex() {
        return producerIndex;
    }

    /** CAS on the producer index; producers claim a slot by winning this. */
    final boolean casProducerIndex(long expect, long newValue) {
        return UNSAFE.compareAndSwapLong(this, P_INDEX_OFFSET, expect, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueMidPad<E> extends MpscUnpaddedArrayQueueProducerIndexField<E> {

    MpscUnpaddedArrayQueueMidPad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueProducerLimitField<E> extends MpscUnpaddedArrayQueueMidPad<E> {

    // Unsafe offset of the producerLimit field, used for the ordered store below.
    private final static long P_LIMIT_OFFSET = fieldOffset(MpscUnpaddedArrayQueueProducerLimitField.class, "producerLimit");

    // First unavailable index the producer may claim up to before rereading the consumer index
    private volatile long producerLimit;

    MpscUnpaddedArrayQueueProducerLimitField(int capacity) {
        super(capacity);
        this.producerLimit = capacity;
    }

    /** Volatile load of the producer limit. */
    final long lvProducerLimit() {
        return producerLimit;
    }

    /** Ordered (store-store fenced) write of the producer limit; racy updates are tolerated by callers. */
    final void soProducerLimit(long newValue) {
        UNSAFE.putOrderedLong(this, P_LIMIT_OFFSET, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueL2Pad<E> extends MpscUnpaddedArrayQueueProducerLimitField<E> {

    MpscUnpaddedArrayQueueL2Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueConsumerIndexField<E> extends MpscUnpaddedArrayQueueL2Pad<E> {

    // Unsafe offset of the consumerIndex field, used for the plain load and ordered store below.
    private final static long C_INDEX_OFFSET = fieldOffset(MpscUnpaddedArrayQueueConsumerIndexField.class, "consumerIndex");

    private volatile long consumerIndex;

    MpscUnpaddedArrayQueueConsumerIndexField(int capacity) {
        super(capacity);
    }

    /** Volatile load of the consumer index. */
    @Override
    public final long lvConsumerIndex() {
        return consumerIndex;
    }

    /** Plain (non-volatile) load of the consumer index — single-consumer only. */
    final long lpConsumerIndex() {
        return UNSAFE.getLong(this, C_INDEX_OFFSET);
    }

    /** Ordered (store-store fenced) write of the consumer index. */
    final void soConsumerIndex(long newValue) {
        UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is MpscArrayQueue.java.
 */
abstract class MpscUnpaddedArrayQueueL3Pad<E> extends MpscUnpaddedArrayQueueConsumerIndexField<E> {

    MpscUnpaddedArrayQueueL3Pad(int capacity) {
        super(capacity);
    }
}
 * This implementation follows patterns documented on the package level for False Sharing protection.
 * <p>
 * This implementation is using the Fast Flow
 * method for polling from the queue (with minor change to correctly publish the index) and an extension of
 * the Leslie Lamport concurrent queue algorithm (originated by Martin Thompson) on the producer side.
 * <p>
 * Multi-Producer/Single-Consumer bounded array queue: producers race via CAS on the producer index,
 * the single consumer advances its index with plain/ordered stores only.
 */
public class MpscUnpaddedArrayQueue<E> extends MpscUnpaddedArrayQueueL3Pad<E> {

    public MpscUnpaddedArrayQueue(final int capacity) {
        super(capacity);
    }

    /**
     * {@link #offer(Object)} if {@link #size()} is less than threshold.
     *
     * @param e the object to offer onto the queue, not null
     * @param threshold the maximum allowable size
     * @return true if the offer is successful, false if queue size exceeds threshold
     * @since 1.0.1
     */
    public boolean offerIfBelowThreshold(final E e, int threshold) {
        if (null == e) {
            throw new NullPointerException();
        }
        final long mask = this.mask;
        final long capacity = mask + 1;
        long producerLimit = lvProducerLimit();
        long pIndex;
        do {
            pIndex = lvProducerIndex();
            // size estimated from the cached producer limit first; only fall back to a volatile
            // read of the consumer index when the estimate says we may be over threshold
            long available = producerLimit - pIndex;
            long size = capacity - available;
            if (size >= threshold) {
                final long cIndex = lvConsumerIndex();
                size = pIndex - cIndex;
                if (size >= threshold) {
                    // the size exceeds threshold
                    return false;
                } else {
                    // update producer limit to the next index that we must recheck the consumer index
                    producerLimit = cIndex + capacity;
                    // this is racy, but the race is benign
                    soProducerLimit(producerLimit);
                }
            }
        } while (!casProducerIndex(pIndex, pIndex + 1));
        /*
         * NOTE: the new producer index value is made visible BEFORE the element in the array. If we relied on
         * the index visibility to poll() we would need to handle the case where the element is not visible.
         */
        // Won CAS, move on to storing
        final long offset = calcCircularRefElementOffset(pIndex, mask);
        soRefElement(buffer, offset, e);
        // AWESOME :)
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Lock free offer using a single CAS. As class name suggests access is permitted to many threads
     * concurrently.
     *
     * @see java.util.Queue#offer
     * @see org.jctools.queues.MessagePassingQueue#offer
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        // use a cached view on consumer index (potentially updated in loop)
        final long mask = this.mask;
        long producerLimit = lvProducerLimit();
        long pIndex;
        do {
            pIndex = lvProducerIndex();
            if (pIndex >= producerLimit) {
                final long cIndex = lvConsumerIndex();
                producerLimit = cIndex + mask + 1;
                if (pIndex >= producerLimit) {
                    // FULL :(
                    return false;
                } else {
                    // update producer limit to the next index that we must recheck the consumer index
                    // this is racy, but the race is benign
                    soProducerLimit(producerLimit);
                }
            }
        } while (!casProducerIndex(pIndex, pIndex + 1));
        /*
         * NOTE: the new producer index value is made visible BEFORE the element in the array. If we relied on
         * the index visibility to poll() we would need to handle the case where the element is not visible.
         */
        // Won CAS, move on to storing
        final long offset = calcCircularRefElementOffset(pIndex, mask);
        soRefElement(buffer, offset, e);
        // AWESOME :)
        return true;
    }

    /**
     * A wait free alternative to offer which fails on CAS failure.
     *
     * @param e new element, not null
     * @return 1 if next element cannot be filled, -1 if CAS failed, 0 if successful
     */
    public final int failFastOffer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final long mask = this.mask;
        final long capacity = mask + 1;
        final long pIndex = lvProducerIndex();
        long producerLimit = lvProducerLimit();
        if (pIndex >= producerLimit) {
            final long cIndex = lvConsumerIndex();
            producerLimit = cIndex + capacity;
            if (pIndex >= producerLimit) {
                // FULL :(
                return 1;
            } else {
                // update producer limit to the next index that we must recheck the consumer index
                soProducerLimit(producerLimit);
            }
        }
        // look Ma, no loop!
        if (!casProducerIndex(pIndex, pIndex + 1)) {
            // CAS FAIL :(
            return -1;
        }
        // Won CAS, move on to storing
        final long offset = calcCircularRefElementOffset(pIndex, mask);
        soRefElement(buffer, offset, e);
        // AWESOME :)
        return 0;
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Lock free poll using ordered loads/stores. As class name suggests access is limited to a single thread.
     *
     * @see java.util.Queue#poll
     * @see org.jctools.queues.MessagePassingQueue#poll
     */
    @Override
    public E poll() {
        final long cIndex = lpConsumerIndex();
        final long offset = calcCircularRefElementOffset(cIndex, mask);
        // Copy field to avoid re-reading after volatile load
        final E[] buffer = this.buffer;
        // If we can't see the next available element we can't poll
        E e = lvRefElement(buffer, offset);
        if (null == e) {
            /*
             * NOTE: Queue may not actually be empty in the case of a producer (P1) being interrupted after
             * winning the CAS on offer but before storing the element in the queue. Other producers may go on
             * to fill up the queue after this element.
             */
            if (cIndex != lvProducerIndex()) {
                // a producer claimed this slot; spin until its element store becomes visible
                do {
                    e = lvRefElement(buffer, offset);
                } while (e == null);
            } else {
                return null;
            }
        }
        // plain store is fine: only this (single) consumer thread reads consumed slots before reuse
        spRefElement(buffer, offset, null);
        soConsumerIndex(cIndex + 1);
        return e;
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Lock free peek using ordered loads. As class name suggests access is limited to a single thread.
     *
     * @see java.util.Queue#peek
     * @see org.jctools.queues.MessagePassingQueue#peek
     */
    @Override
    public E peek() {
        // Copy field to avoid re-reading after volatile load
        final E[] buffer = this.buffer;
        final long cIndex = lpConsumerIndex();
        final long offset = calcCircularRefElementOffset(cIndex, mask);
        E e = lvRefElement(buffer, offset);
        if (null == e) {
            /*
             * NOTE: Queue may not actually be empty in the case of a producer (P1) being interrupted after
             * winning the CAS on offer but before storing the element in the queue. Other producers may go on
             * to fill up the queue after this element.
             */
            if (cIndex != lvProducerIndex()) {
                do {
                    e = lvRefElement(buffer, offset);
                } while (e == null);
            } else {
                return null;
            }
        }
        return e;
    }

    @Override
    public boolean relaxedOffer(E e) {
        // offer() already has relaxed semantics on the producer side; no weaker variant exists
        return offer(e);
    }

    @Override
    public E relaxedPoll() {
        final E[] buffer = this.buffer;
        final long cIndex = lpConsumerIndex();
        final long offset = calcCircularRefElementOffset(cIndex, mask);
        // If we can't see the next available element we can't poll
        E e = lvRefElement(buffer, offset);
        if (null == e) {
            // unlike poll(), do not wait out an in-flight producer "bubble" — just report empty
            return null;
        }
        spRefElement(buffer, offset, null);
        soConsumerIndex(cIndex + 1);
        return e;
    }

    @Override
    public E relaxedPeek() {
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final long cIndex = lpConsumerIndex();
        return lvRefElement(buffer, calcCircularRefElementOffset(cIndex, mask));
    }

    @Override
    public int drain(final Consumer<E> c, final int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final long cIndex = lpConsumerIndex();
        for (int i = 0; i < limit; i++) {
            final long index = cIndex + i;
            final long offset = calcCircularRefElementOffset(index, mask);
            final E e = lvRefElement(buffer, offset);
            if (null == e) {
                // ran out of visible elements; report how many were drained
                return i;
            }
            spRefElement(buffer, offset, null);
            // ordered store -> atomic and ordered for size()
            soConsumerIndex(index + 1);
            c.accept(e);
        }
        return limit;
    }

    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final long mask = this.mask;
        final long capacity = mask + 1;
        long producerLimit = lvProducerLimit();
        long pIndex;
        int actualLimit;
        do {
            pIndex = lvProducerIndex();
            long available = producerLimit - pIndex;
            if (available <= 0) {
                final long cIndex = lvConsumerIndex();
                producerLimit = cIndex + capacity;
                available = producerLimit - pIndex;
                if (available <= 0) {
                    // FULL :(
                    return 0;
                } else {
                    // update producer limit to the next index that we must recheck the consumer index
                    soProducerLimit(producerLimit);
                }
            }
            // claim min(available, limit) slots in a single CAS
            actualLimit = Math.min((int) available, limit);
        } while (!casProducerIndex(pIndex, pIndex + actualLimit));
        // right, now we claimed a few slots and can fill them with goodness
        final E[] buffer = this.buffer;
        for (int i = 0; i < actualLimit; i++) {
            // Won CAS, move on to storing
            final long offset = calcCircularRefElementOffset(pIndex + i, mask);
            soRefElement(buffer, offset, s.get());
        }
        return actualLimit;
    }

    @Override
    public int drain(Consumer<E> c) {
        return drain(c, capacity());
    }

    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillBounded(this, s);
    }

    @Override
    public void drain(Consumer<E> c, WaitStrategy w, ExitCondition exit) {
        MessagePassingQueueUtil.drain(this, c, w, exit);
    }

    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpmcUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpmcUnpaddedArrayQueue.java
new file mode 100644
index 0000000..cd045b3
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpmcUnpaddedArrayQueue.java
@@ -0,0 +1,352 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues.unpadded;

import static org.jctools.util.UnsafeAccess.UNSAFE;
import static org.jctools.util.UnsafeAccess.fieldOffset;
import static org.jctools.util.UnsafeRefArrayAccess.*;
import org.jctools.queues.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 * <p>
 * Layer retained from the padded original for structural parity; carries no padding fields in this
 * "unpadded" variant.
 */
abstract class SpmcUnpaddedArrayQueueL1Pad<E> extends ConcurrentCircularUnpaddedArrayQueue<E> {

    SpmcUnpaddedArrayQueueL1Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 */
abstract class SpmcUnpaddedArrayQueueProducerIndexField<E> extends SpmcUnpaddedArrayQueueL1Pad<E> {

    // Unsafe field offset of 'producerIndex', used for the ordered/plain accessors below.
    protected final static long P_INDEX_OFFSET = fieldOffset(SpmcUnpaddedArrayQueueProducerIndexField.class, "producerIndex");

    private volatile long producerIndex;

    SpmcUnpaddedArrayQueueProducerIndexField(int capacity) {
        super(capacity);
    }

    // lv = load volatile: volatile read of the producer index.
    @Override
    public final long lvProducerIndex() {
        return producerIndex;
    }

    // lp = load plain: non-volatile read via Unsafe; valid only on the single producer thread.
    final long lpProducerIndex() {
        return UNSAFE.getLong(this, P_INDEX_OFFSET);
    }

    // so = store ordered: ordered (lazySet-style) write, publishes prior element stores.
    final void soProducerIndex(long newValue) {
        UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 * <p>
 * Layer retained from the padded original for structural parity; carries no padding fields in this
 * "unpadded" variant.
 */
abstract class SpmcUnpaddedArrayQueueL2Pad<E> extends SpmcUnpaddedArrayQueueProducerIndexField<E> {

    SpmcUnpaddedArrayQueueL2Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 */
abstract class SpmcUnpaddedArrayQueueConsumerIndexField<E> extends SpmcUnpaddedArrayQueueL2Pad<E> {

    // Unsafe field offset of 'consumerIndex', used by the CAS accessor below.
    protected final static long C_INDEX_OFFSET = fieldOffset(SpmcUnpaddedArrayQueueConsumerIndexField.class, "consumerIndex");

    private volatile long consumerIndex;

    SpmcUnpaddedArrayQueueConsumerIndexField(int capacity) {
        super(capacity);
    }

    @Override
    public final long lvConsumerIndex() {
        return consumerIndex;
    }

    // Consumers race to claim slots; CAS is required because this queue is multi-consumer.
    final boolean casConsumerIndex(long expect, long newValue) {
        return UNSAFE.compareAndSwapLong(this, C_INDEX_OFFSET, expect, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 * <p>
 * Layer retained from the padded original for structural parity; carries no padding fields in this
 * "unpadded" variant.
 */
abstract class SpmcUnpaddedArrayQueueMidPad<E> extends SpmcUnpaddedArrayQueueConsumerIndexField<E> {

    SpmcUnpaddedArrayQueueMidPad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 */
abstract class SpmcUnpaddedArrayQueueProducerIndexCacheField<E> extends SpmcUnpaddedArrayQueueMidPad<E> {

    // This is separated from the consumerIndex which will be highly contended in the hope that this value spends most
    // of it's time in a cache line that is Shared(and rarely invalidated)
    private volatile long producerIndexCache;

    SpmcUnpaddedArrayQueueProducerIndexCacheField(int capacity) {
        super(capacity);
    }

    // Cached (possibly stale) view of the producer index; consumers refresh it lazily.
    protected final long lvProducerIndexCache() {
        return producerIndexCache;
    }

    // sv = store volatile: full volatile write of the cache.
    protected final void svProducerIndexCache(long newValue) {
        producerIndexCache = newValue;
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 * <p>
 * Layer retained from the padded original for structural parity; carries no padding fields in this
 * "unpadded" variant.
 */
abstract class SpmcUnpaddedArrayQueueL3Pad<E> extends SpmcUnpaddedArrayQueueProducerIndexCacheField<E> {

    SpmcUnpaddedArrayQueueL3Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpmcArrayQueue.java.
 * <p>
 * Single-Producer/Multi-Consumer bounded array queue: the single producer publishes with ordered
 * stores only, consumers race via CAS on the consumer index.
 */
public class SpmcUnpaddedArrayQueue<E> extends SpmcUnpaddedArrayQueueL3Pad<E> {

    public SpmcUnpaddedArrayQueue(final int capacity) {
        super(capacity);
    }

    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final long currProducerIndex = lvProducerIndex();
        final long offset = calcCircularRefElementOffset(currProducerIndex, mask);
        // slot still occupied: either the queue is full, or a consumer claimed the slot but has
        // not yet nulled it out
        if (null != lvRefElement(buffer, offset)) {
            long size = currProducerIndex - lvConsumerIndex();
            if (size > mask) {
                return false;
            } else {
                // Bubble: This can happen because `poll` moves index before placing element.
                // spin wait for slot to clear, buggers wait freedom
                while (null != lvRefElement(buffer, offset)) {
                    // BURN
                }
            }
        }
        soRefElement(buffer, offset, e);
        // single producer, so store ordered is valid. It is also required to correctly publish the element
        // and for the consumers to pick up the tail value.
        soProducerIndex(currProducerIndex + 1);
        return true;
    }

    @Override
    public E poll() {
        long currentConsumerIndex;
        long currProducerIndexCache = lvProducerIndexCache();
        do {
            currentConsumerIndex = lvConsumerIndex();
            // cached producer index says empty; refresh against the real index before giving up
            if (currentConsumerIndex >= currProducerIndexCache) {
                long currProducerIndex = lvProducerIndex();
                if (currentConsumerIndex >= currProducerIndex) {
                    return null;
                } else {
                    currProducerIndexCache = currProducerIndex;
                    svProducerIndexCache(currProducerIndex);
                }
            }
        } while (!casConsumerIndex(currentConsumerIndex, currentConsumerIndex + 1));
        // consumers are gated on latest visible tail, and so can't see a null value in the queue or overtake
        // and wrap to hit same location.
        return removeElement(buffer, currentConsumerIndex, mask);
    }

    private E removeElement(final E[] buffer, long index, final long mask) {
        final long offset = calcCircularRefElementOffset(index, mask);
        // load plain, element happens before it's index becomes visible
        final E e = lpRefElement(buffer, offset);
        // store ordered, make sure nulling out is visible. Producer is waiting for this value.
        soRefElement(buffer, offset, null);
        return e;
    }

    @Override
    public E peek() {
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        long currProducerIndexCache = lvProducerIndexCache();
        long currentConsumerIndex;
        long nextConsumerIndex = lvConsumerIndex();
        E e;
        do {
            currentConsumerIndex = nextConsumerIndex;
            if (currentConsumerIndex >= currProducerIndexCache) {
                long currProducerIndex = lvProducerIndex();
                if (currentConsumerIndex >= currProducerIndex) {
                    return null;
                } else {
                    currProducerIndexCache = currProducerIndex;
                    svProducerIndexCache(currProducerIndex);
                }
            }
            e = lvRefElement(buffer, calcCircularRefElementOffset(currentConsumerIndex, mask));
            // sandwich the element load between 2 consumer index loads
            nextConsumerIndex = lvConsumerIndex();
        } while (null == e || nextConsumerIndex != currentConsumerIndex);
        return e;
    }

    @Override
    public boolean relaxedOffer(E e) {
        if (null == e) {
            throw new NullPointerException("Null is not a valid element");
        }
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final long producerIndex = lpProducerIndex();
        final long offset = calcCircularRefElementOffset(producerIndex, mask);
        // unlike offer(), do not spin through a consumer bubble — just fail
        if (null != lvRefElement(buffer, offset)) {
            return false;
        }
        soRefElement(buffer, offset, e);
        // single producer, so store ordered is valid. It is also required to correctly publish the element
        // and for the consumers to pick up the tail value.
        soProducerIndex(producerIndex + 1);
        return true;
    }

    @Override
    public E relaxedPoll() {
        // poll() already has relaxed semantics on the consumer side; no weaker variant exists
        return poll();
    }

    @Override
    public E relaxedPeek() {
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        long currentConsumerIndex;
        long nextConsumerIndex = lvConsumerIndex();
        E e;
        do {
            currentConsumerIndex = nextConsumerIndex;
            e = lvRefElement(buffer, calcCircularRefElementOffset(currentConsumerIndex, mask));
            // sandwich the element load between 2 consumer index loads
            nextConsumerIndex = lvConsumerIndex();
        } while (nextConsumerIndex != currentConsumerIndex);
        return e;
    }

    @Override
    public int drain(final Consumer<E> c, final int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        long currProducerIndexCache = lvProducerIndexCache();
        int adjustedLimit = 0;
        long currentConsumerIndex;
        do {
            currentConsumerIndex = lvConsumerIndex();
            // are there any elements visible to consume? (refresh cached producer index if stale)
            if (currentConsumerIndex >= currProducerIndexCache) {
                long currProducerIndex = lvProducerIndex();
                if (currentConsumerIndex >= currProducerIndex) {
                    return 0;
                } else {
                    currProducerIndexCache = currProducerIndex;
                    svProducerIndexCache(currProducerIndex);
                }
            }
            // try and claim up to 'limit' elements in one go
            int remaining = (int) (currProducerIndexCache - currentConsumerIndex);
            adjustedLimit = Math.min(remaining, limit);
        } while (!casConsumerIndex(currentConsumerIndex, currentConsumerIndex + adjustedLimit));
        for (int i = 0; i < adjustedLimit; i++) {
            c.accept(removeElement(buffer, currentConsumerIndex + i, mask));
        }
        return adjustedLimit;
    }

    @Override
    public int fill(final Supplier<E> s, final int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        long producerIndex = this.lpProducerIndex();
        for (int i = 0; i < limit; i++) {
            final long offset = calcCircularRefElementOffset(producerIndex, mask);
            // stop at the first occupied slot; queue is full (or a consumer bubble is in the way)
            if (null != lvRefElement(buffer, offset)) {
                return i;
            }
            producerIndex++;
            soRefElement(buffer, offset, s.get());
            // ordered store -> atomic and ordered for size()
            soProducerIndex(producerIndex);
        }
        return limit;
    }

    @Override
    public int drain(final Consumer<E> c) {
        return MessagePassingQueueUtil.drain(this, c);
    }

    @Override
    public int fill(final Supplier<E> s) {
        return fill(s, capacity());
    }

    @Override
    public void drain(final Consumer<E> c, final WaitStrategy w, final ExitCondition exit) {
        MessagePassingQueueUtil.drain(this, c, w, exit);
    }

    @Override
    public void fill(final Supplier<E> s, final WaitStrategy w, final ExitCondition e) {
        MessagePassingQueueUtil.fill(this, s, w, e);
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscChunkedUnpaddedArrayQueue.java
b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscChunkedUnpaddedArrayQueue.java
new file mode 100644
index 0000000..1793d7a
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscChunkedUnpaddedArrayQueue.java
@@ -0,0 +1,103 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues.unpadded;

import org.jctools.util.Pow2;
import org.jctools.util.RangeUtil;
import static org.jctools.util.UnsafeRefArrayAccess.*;
import org.jctools.queues.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscChunkedArrayQueue.java.
 * <p>
 * An SPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks
 * of the initial size. The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 *
 * @param <E> the type of elements held in this queue
 */
public class SpscChunkedUnpaddedArrayQueue<E> extends BaseSpscLinkedUnpaddedArrayQueue<E> {

    // Hard upper bound on the total number of elements across all chunks (a power of 2).
    private final int maxQueueCapacity;

    // Producer-local limit derived from consumer index + maxQueueCapacity; only read/written by
    // the single producer thread, so a plain field is sufficient.
    private long producerQueueLimit;

    public SpscChunkedUnpaddedArrayQueue(int capacity) {
        // default chunk is 1/8 of capacity, but never below the minimum chunk size of 8
        this(Math.max(8, Pow2.roundToPowerOfTwo(capacity / 8)), capacity);
    }

    public SpscChunkedUnpaddedArrayQueue(int chunkSize, int capacity) {
        RangeUtil.checkGreaterThanOrEqual(capacity, 16, "capacity");
        // minimal chunk size of eight makes sure minimal lookahead step is 2
        RangeUtil.checkGreaterThanOrEqual(chunkSize, 8, "chunkSize");
        maxQueueCapacity = Pow2.roundToPowerOfTwo(capacity);
        int chunkCapacity = Pow2.roundToPowerOfTwo(chunkSize);
        RangeUtil.checkLessThan(chunkCapacity, maxQueueCapacity, "chunkCapacity");
        long mask = chunkCapacity - 1;
        // need extra element to point at next array
        E[] buffer = allocateRefArray(chunkCapacity + 1);
        producerBuffer = buffer;
        producerMask = mask;
        consumerBuffer = buffer;
        consumerMask = mask;
        // we know it's all empty to start with
        producerBufferLimit = mask - 1;
        producerQueueLimit = maxQueueCapacity;
    }

    /**
     * Slow path taken when the producer's cached buffer limit is exhausted: either advance within
     * the current chunk, or link a fresh chunk for the consumer to follow.
     */
    @Override
    final boolean offerColdPath(E[] buffer, long mask, long pIndex, long offset, E v, Supplier<? extends E> s) {
        // use a fixed lookahead step based on buffer capacity
        final long lookAheadStep = (mask + 1) / 4;
        long pBufferLimit = pIndex + lookAheadStep;
        long pQueueLimit = producerQueueLimit;
        if (pIndex >= pQueueLimit) {
            // we tested against a potentially out of date queue limit, refresh it
            final long cIndex = lvConsumerIndex();
            producerQueueLimit = pQueueLimit = cIndex + maxQueueCapacity;
            // if we're full we're full
            if (pIndex >= pQueueLimit) {
                return false;
            }
        }
        // if buffer limit is after queue limit we use queue limit. We need to handle overflow so
        // cannot use Math.min
        if (pBufferLimit - pQueueLimit > 0) {
            pBufferLimit = pQueueLimit;
        }
        // go around the buffer or add a new buffer
        if (// there's sufficient room in buffer/queue to use pBufferLimit
        pBufferLimit > pIndex + 1 && null == lvRefElement(buffer, calcCircularRefElementOffset(pBufferLimit, mask))) {
            // joy, there's plenty of room
            producerBufferLimit = pBufferLimit - 1;
            writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset);
        } else if (null == lvRefElement(buffer, calcCircularRefElementOffset(pIndex + 1, mask))) {
            // buffer is not full
            writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset);
        } else {
            // we got one slot left to write into, and we are not full. Need to link new buffer.
            // allocate new buffer of same length
            final E[] newBuffer = allocateRefArray((int) (mask + 2));
            producerBuffer = newBuffer;
            linkOldToNew(pIndex, buffer, offset, newBuffer, offset, v == null ? s.get() : v);
        }
        return true;
    }

    @Override
    public int capacity() {
        return maxQueueCapacity;
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscGrowableUnpaddedArrayQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscGrowableUnpaddedArrayQueue.java
new file mode 100644
index 0000000..ef2a782
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscGrowableUnpaddedArrayQueue.java
@@ -0,0 +1,147 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.queues.unpadded;

import org.jctools.util.Pow2;
import org.jctools.util.RangeUtil;
import org.jctools.util.SpscLookAheadUtil;
import static org.jctools.queues.LinkedArrayQueueUtil.length;
import static org.jctools.util.UnsafeRefArrayAccess.*;
import org.jctools.queues.*;

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can found in the jctools-build module. The original source file is SpscGrowableArrayQueue.java.
 * <p>
 * An SPSC array queue which starts at initialCapacity and grows to maxCapacity in linked chunks,
 * doubling theirs size every time until the full blown backing array is used.
 * The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 *
 * @param <E> the type of elements held in this queue
 */
public class SpscGrowableUnpaddedArrayQueue<E> extends BaseSpscLinkedUnpaddedArrayQueue<E> {

    // Hard upper bound on the total number of elements (a power of 2).
    private final int maxQueueCapacity;

    // Dual-purpose field (producer-only): while positive it is the producer's lookahead distance;
    // after the final (max-sized) buffer is allocated it temporarily holds a NEGATIVE value
    // encoding the consumer's distance from that final buffer — see offerColdPath below.
    private long lookAheadStep;

    public SpscGrowableUnpaddedArrayQueue(final int capacity) {
        // default chunk is 1/8 of capacity, but never below the minimum chunk size of 8
        this(Math.max(8, Pow2.roundToPowerOfTwo(capacity / 8)), capacity);
    }

    public SpscGrowableUnpaddedArrayQueue(final int chunkSize, final int capacity) {
        RangeUtil.checkGreaterThanOrEqual(capacity, 16, "capacity");
        // minimal chunk size of eight makes sure minimal lookahead step is 2
        RangeUtil.checkGreaterThanOrEqual(chunkSize, 8, "chunkSize");
        maxQueueCapacity = Pow2.roundToPowerOfTwo(capacity);
        int chunkCapacity = Pow2.roundToPowerOfTwo(chunkSize);
        RangeUtil.checkLessThan(chunkCapacity, maxQueueCapacity, "chunkCapacity");
        long mask = chunkCapacity - 1;
        // need extra element to point at next array
        E[] buffer = allocateRefArray(chunkCapacity + 1);
        producerBuffer = buffer;
        producerMask = mask;
        consumerBuffer = buffer;
        consumerMask = mask;
        // we know it's all empty to start with
        producerBufferLimit = mask - 1;
        adjustLookAheadStep(chunkCapacity);
    }

    /**
     * Slow path taken when the producer's cached buffer limit is exhausted: advance within the
     * current chunk, grow into a doubled chunk, or — once at max size — fall into the
     * negative-lookAheadStep regime until the consumer catches up to the final buffer.
     */
    @Override
    final boolean offerColdPath(final E[] buffer, final long mask, final long index, final long offset, final E v, final Supplier<? extends E> s) {
        final long lookAheadStep = this.lookAheadStep;
        // normal case, go around the buffer or resize if full (unless we hit max capacity)
        if (lookAheadStep > 0) {
            long lookAheadElementOffset = calcCircularRefElementOffset(index + lookAheadStep, mask);
            // Try and look ahead a number of elements so we don't have to do this all the time
            if (null == lvRefElement(buffer, lookAheadElementOffset)) {
                // joy, there's plenty of room
                producerBufferLimit = index + lookAheadStep - 1;
                writeToQueue(buffer, v == null ? s.get() : v, index, offset);
                return true;
            }
            // we're at max capacity, can use up last element
            final int maxCapacity = maxQueueCapacity;
            if (mask + 1 == maxCapacity) {
                if (null == lvRefElement(buffer, offset)) {
                    writeToQueue(buffer, v == null ? s.get() : v, index, offset);
                    return true;
                }
                // we're full and can't grow
                return false;
            }
            // not at max capacity, so must allow extra slot for next buffer pointer
            if (null == lvRefElement(buffer, calcCircularRefElementOffset(index + 1, mask))) {
                // buffer is not full
                writeToQueue(buffer, v == null ? s.get() : v, index, offset);
            } else {
                // allocate new buffer of doubled capacity (+1 slot for the next-buffer pointer)
                final E[] newBuffer = allocateRefArray((int) (2 * (mask + 1) + 1));
                producerBuffer = newBuffer;
                producerMask = length(newBuffer) - 2;
                final long offsetInNew = calcCircularRefElementOffset(index, producerMask);
                linkOldToNew(index, buffer, offset, newBuffer, offsetInNew, v == null ? s.get() : v);
                int newCapacity = (int) (producerMask + 1);
                if (newCapacity == maxCapacity) {
                    long currConsumerIndex = lvConsumerIndex();
                    // use lookAheadStep to store the consumer distance from final buffer
                    this.lookAheadStep = -(index - currConsumerIndex);
                    producerBufferLimit = currConsumerIndex + maxCapacity;
                } else {
                    producerBufferLimit = index + producerMask - 1;
                    adjustLookAheadStep(newCapacity);
                }
            }
            return true;
        } else // the step is negative (or zero) in the period between allocating the max sized buffer and the
        // consumer starting on it
        {
            final long prevElementsInOtherBuffers = -lookAheadStep;
            // until the consumer starts using the current buffer we need to check consumer index to
            // verify size
            long currConsumerIndex = lvConsumerIndex();
            int size = (int) (index - currConsumerIndex);
            // we're on max capacity or we wouldn't be here
            int maxCapacity = (int) mask + 1;
            if (size == maxCapacity) {
                // consumer index has not changed since adjusting the lookAhead index, we're full
                return false;
            }
            // if consumerIndex progressed enough so that current size indicates it is on same buffer
            long firstIndexInCurrentBuffer = producerBufferLimit - maxCapacity + prevElementsInOtherBuffers;
            if (currConsumerIndex >= firstIndexInCurrentBuffer) {
                // job done, we've now settled into our final state
                adjustLookAheadStep(maxCapacity);
            } else // consumer is still on some other buffer
            {
                // how many elements out of buffer?
                this.lookAheadStep = (int) (currConsumerIndex - firstIndexInCurrentBuffer);
            }
            producerBufferLimit = currConsumerIndex + maxCapacity;
            writeToQueue(buffer, v == null ? s.get() : v, index, offset);
            return true;
        }
    }

    // Recompute the (positive) lookahead distance for the given chunk capacity.
    private void adjustLookAheadStep(int capacity) {
        lookAheadStep = SpscLookAheadUtil.computeLookAheadStep(capacity);
    }

    @Override
    public int capacity() {
        return maxQueueCapacity;
    }
}
diff --git a/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscLinkedUnpaddedQueue.java b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscLinkedUnpaddedQueue.java
new file mode 100644
index 0000000..e757531
--- /dev/null
+++ b/netty-jctools/src/main/java/org/jctools/queues/unpadded/SpscLinkedUnpaddedQueue.java
@@ -0,0 +1,108 @@
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscLinkedQueue.java.
 *
 * This is a weakened version of the MPSC algorithm as presented on
 * <a href="http://www.1024cores.net">1024 Cores</a> by D. Vyukov. The original has been adapted to Java and
 * its quirks with regards to memory model and layout:
 * <ol>
 * <li>Use inheritance to ensure no false sharing occurs between producer/consumer node reference fields.
 * <li>As this is an SPSC we have no need for XCHG, an ordered store is enough.
 * </ol>
 * The queue is initialized with a stub node which is set to both the producer and consumer node references.
 * From this point follow the notes on offer/poll.
 *
 * @param <E> element type
 * @author nitsanw
 */
public class SpscLinkedUnpaddedQueue<E> extends BaseLinkedUnpaddedQueue<E> {

    public SpscLinkedUnpaddedQueue() {
        // stub node: producer and consumer both start pointing at it
        LinkedQueueNode<E> node = newNode();
        spProducerNode(node);
        spConsumerNode(node);
        // this ensures correct construction: StoreStore
        node.soNext(null);
    }

    /**
     * {@inheritDoc}
     * <p>
     * IMPLEMENTATION NOTES:<br>
     * Offer is allowed from a SINGLE thread.<br>
     * Offer allocates a new node (holding the offered value) and:
     * <ol>
     * <li>Sets the new node as the producerNode
     * <li>Sets that node as the lastProducerNode.next
     * </ol>
     * From this follows that producerNode.next is always null and for all other nodes node.next is not null.
     *
     * @see MessagePassingQueue#offer(Object)
     * @see java.util.Queue#offer(java.lang.Object)
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        final LinkedQueueNode<E> nextNode = newNode(e);
        LinkedQueueNode<E> oldNode = lpProducerNode();
        soProducerNode(nextNode);
        // Should a producer thread get interrupted here the chain WILL be broken until that thread is resumed
        // and completes the store in prev.next. This is a "bubble".
        // Inverting the order here will break the `isEmpty` invariant, and will require matching adjustments elsewhere.
        oldNode.soNext(nextNode);
        return true;
    }

    /** Unbounded fill: delegates to the shared utility which loops offer until the supplier is exhausted. */
    @Override
    public int fill(Supplier<E> s) {
        return MessagePassingQueueUtil.fillUnbounded(this, s);
    }

    /**
     * Bulk fill: builds a detached chain of {@code limit} nodes with plain stores, then publishes the whole
     * chain with a single ordered producer-node store — one publication instead of {@code limit}.
     */
    @Override
    public int fill(Supplier<E> s, int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        LinkedQueueNode<E> tail = newNode(s.get());
        final LinkedQueueNode<E> head = tail;
        for (int i = 1; i < limit; i++) {
            final LinkedQueueNode<E> temp = newNode(s.get());
            // spNext : soProducerNode ensures correct construction
            tail.spNext(temp);
            tail = temp;
        }
        final LinkedQueueNode<E> oldPNode = lpProducerNode();
        soProducerNode(tail);
        // same bubble as offer, and for the same reasons.
        oldPNode.soNext(head);
        return limit;
    }

    /** Continuous fill: delegates to the shared utility, honoring the wait strategy and exit condition. */
    @Override
    public void fill(Supplier<E> s, WaitStrategy wait, ExitCondition exit) {
        MessagePassingQueueUtil.fill(this, s, wait, exit);
    }
}
/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscUnboundedArrayQueue.java.
 *
 * An SPSC array queue which starts at <i>initialCapacity</i> and grows indefinitely in linked chunks of the
 * initial size. The queue grows only when the current chunk is full and elements are not copied on
 * resize, instead a link to the new chunk is stored in the old chunk for the consumer to follow.
 *
 * @param <E> element type
 */
public class SpscUnboundedUnpaddedArrayQueue<E> extends BaseSpscLinkedUnpaddedArrayQueue<E> {

    public SpscUnboundedUnpaddedArrayQueue(int chunkSize) {
        // chunk capacity is at least 16 and a power of two so a bit mask can be used for index wrapping
        int chunkCapacity = Math.max(Pow2.roundToPowerOfTwo(chunkSize), 16);
        long mask = chunkCapacity - 1;
        // one extra slot is allocated to hold the link to the next buffer
        E[] buffer = allocateRefArray(chunkCapacity + 1);
        producerBuffer = buffer;
        producerMask = mask;
        consumerBuffer = buffer;
        consumerMask = mask;
        // we know it's all empty to start with
        producerBufferLimit = mask - 1;
    }

    /**
     * Cold path of offer: the cached producer limit has been reached, so probe the buffer and either
     * advance the limit, write into the remaining space, or link a new chunk.
     */
    @Override
    final boolean offerColdPath(E[] buffer, long mask, long pIndex, long offset, E v, Supplier<? extends E> s) {
        // use a fixed lookahead step based on buffer capacity
        final long lookAheadStep = (mask + 1) / 4;
        long pBufferLimit = pIndex + lookAheadStep;
        // go around the buffer or add a new buffer
        if (null == lvRefElement(buffer, calcCircularRefElementOffset(pBufferLimit, mask))) {
            // joy, there's plenty of room
            producerBufferLimit = pBufferLimit - 1;
            writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset);
        } else if (null == lvRefElement(buffer, calcCircularRefElementOffset(pIndex + 1, mask))) {
            // buffer is not full
            writeToQueue(buffer, v == null ? s.get() : v, pIndex, offset);
        } else {
            // we got one slot left to write into, and we are not full. Need to link new buffer.
            // allocate new buffer of same length
            final E[] newBuffer = allocateRefArray((int) (mask + 2));
            producerBuffer = newBuffer;
            producerBufferLimit = pIndex + mask - 1;
            linkOldToNew(pIndex, buffer, offset, newBuffer, offset, v == null ? s.get() : v);
        }
        return true;
    }

    /** Bounded by current chunk size per call; producerMask + 1 is the chunk capacity. */
    @Override
    public int fill(Supplier<E> s) {
        return fill(s, (int) this.producerMask);
    }

    /** The queue grows without bound, so report the sentinel unbounded capacity. */
    @Override
    public int capacity() {
        return UNBOUNDED_CAPACITY;
    }
}
/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * Cold fields: read rarely relative to the hot index fields below.
 */
abstract class SpscUnpaddedArrayQueueColdField<E> extends ConcurrentCircularUnpaddedArrayQueue<E> {

    // how far ahead of the current producer index a free slot is probed for; computed once from capacity
    final int lookAheadStep;

    SpscUnpaddedArrayQueueColdField(int capacity) {
        super(capacity);
        int actualCapacity = capacity();
        lookAheadStep = SpscLookAheadUtil.computeLookAheadStep(actualCapacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * Pad class kept empty in this "unpadded" variant; it only preserves the hierarchy of the padded original.
 */
abstract class SpscUnpaddedArrayQueueL1Pad<E> extends SpscUnpaddedArrayQueueColdField<E> {

    SpscUnpaddedArrayQueueL1Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * Producer-side index fields. Only the single producer thread writes these.
 */
abstract class SpscUnpaddedArrayQueueProducerIndexFields<E> extends SpscUnpaddedArrayQueueL1Pad<E> {

    private final static long P_INDEX_OFFSET = fieldOffset(SpscUnpaddedArrayQueueProducerIndexFields.class, "producerIndex");

    private volatile long producerIndex;

    // cached index up to which the producer may write without re-probing the buffer for free slots
    protected long producerLimit;

    SpscUnpaddedArrayQueueProducerIndexFields(int capacity) {
        super(capacity);
    }

    /** Volatile load of the producer index. */
    @Override
    public final long lvProducerIndex() {
        return producerIndex;
    }

    /** Plain load of the producer index; safe only on the producer thread (single writer). */
    final long lpProducerIndex() {
        return UNSAFE.getLong(this, P_INDEX_OFFSET);
    }

    /** Ordered store of the producer index; publishes preceding element stores to the consumer. */
    final void soProducerIndex(final long newValue) {
        UNSAFE.putOrderedLong(this, P_INDEX_OFFSET, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * Pad class kept empty in this "unpadded" variant; it only preserves the hierarchy of the padded original.
 */
abstract class SpscUnpaddedArrayQueueL2Pad<E> extends SpscUnpaddedArrayQueueProducerIndexFields<E> {

    SpscUnpaddedArrayQueueL2Pad(int capacity) {
        super(capacity);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * Consumer-side index field. Only the single consumer thread writes this.
 */
abstract class SpscUnpaddedArrayQueueConsumerIndexField<E> extends SpscUnpaddedArrayQueueL2Pad<E> {

    private final static long C_INDEX_OFFSET = fieldOffset(SpscUnpaddedArrayQueueConsumerIndexField.class, "consumerIndex");

    private volatile long consumerIndex;

    SpscUnpaddedArrayQueueConsumerIndexField(int capacity) {
        super(capacity);
    }

    /** Volatile load of the consumer index. */
    public final long lvConsumerIndex() {
        return UNSAFE.getLongVolatile(this, C_INDEX_OFFSET);
    }

    /** Plain load of the consumer index; safe only on the consumer thread (single writer). */
    final long lpConsumerIndex() {
        return UNSAFE.getLong(this, C_INDEX_OFFSET);
    }

    /** Ordered store of the consumer index; publishes the slot release to the producer. */
    final void soConsumerIndex(final long newValue) {
        UNSAFE.putOrderedLong(this, C_INDEX_OFFSET, newValue);
    }
}

/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * Pad class kept empty in this "unpadded" variant; it only preserves the hierarchy of the padded original.
 */
abstract class SpscUnpaddedArrayQueueL3Pad<E> extends SpscUnpaddedArrayQueueConsumerIndexField<E> {

    SpscUnpaddedArrayQueueL3Pad(int capacity) {
        super(capacity);
    }
}
/**
 * NOTE: This class was automatically generated by org.jctools.queues.unpadded.JavaParsingUnpaddedQueueGenerator
 * which can be found in the jctools-build module. The original source file is SpscArrayQueue.java.
 *
 * A Single-Producer-Single-Consumer queue backed by a pre-allocated buffer.
 * <p>
 * This implementation is a mashup of the Fast Flow
 * algorithm with an optimization of the offer method taken from the BQueue algorithm (a variation on Fast
 * Flow), and adjusted to comply with Queue.offer semantics with regards to capacity.<br>
 * For convenience the relevant papers are available in the `resources` folder:<br>
 * <i>2010 - Pisa - SPSC Queues on Shared Cache Multi-Core Systems.pdf<br>
 * 2012 - Junchang- BQueue- Efficient and Practical Queuing.pdf<br>
 * </i>
 * <p>
 * This implementation is wait free.
 *
 * @param <E> element type
 */
public class SpscUnpaddedArrayQueue<E> extends SpscUnpaddedArrayQueueL3Pad<E> {

    public SpscUnpaddedArrayQueue(final int capacity) {
        // enforce a minimum capacity of 4 so the look-ahead machinery remains valid for tiny queues
        super(Math.max(capacity, 4));
    }

    /**
     * {@inheritDoc}
     * <p>
     * This implementation is correct for single producer thread use only.
     */
    @Override
    public boolean offer(final E e) {
        if (null == e) {
            throw new NullPointerException();
        }
        // local load of field to avoid repeated loads after volatile reads
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final long producerIndex = this.lpProducerIndex();
        // fast path: below the cached limit the slot is known to be free; otherwise probe (slow path)
        if (producerIndex >= producerLimit && !offerSlowPath(buffer, mask, producerIndex)) {
            return false;
        }
        final long offset = calcCircularRefElementOffset(producerIndex, mask);
        soRefElement(buffer, offset, e);
        // ordered store -> atomic and ordered for size()
        soProducerIndex(producerIndex + 1);
        return true;
    }

    /**
     * Probes {@code lookAheadStep} slots ahead: the consumer clears slots in order, so if that slot is
     * free the whole stretch up to it is free and the cached limit can be advanced in one step.
     * Otherwise fall back to checking the immediate next slot; if that too is occupied the queue is full.
     */
    private boolean offerSlowPath(final E[] buffer, final long mask, final long producerIndex) {
        final int lookAheadStep = this.lookAheadStep;
        if (null == lvRefElement(buffer, calcCircularRefElementOffset(producerIndex + lookAheadStep, mask))) {
            producerLimit = producerIndex + lookAheadStep;
        } else {
            final long offset = calcCircularRefElementOffset(producerIndex, mask);
            if (null != lvRefElement(buffer, offset)) {
                return false;
            }
        }
        return true;
    }

    /**
     * {@inheritDoc}
     * <p>
     * This implementation is correct for single consumer thread use only.
     */
    @Override
    public E poll() {
        final long consumerIndex = this.lpConsumerIndex();
        final long offset = calcCircularRefElementOffset(consumerIndex, mask);
        // local load of field to avoid repeated loads after volatile reads
        final E[] buffer = this.buffer;
        final E e = lvRefElement(buffer, offset);
        if (null == e) {
            // empty: a null slot at the consumer index means nothing published yet
            return null;
        }
        // null out the slot to release it back to the producer
        soRefElement(buffer, offset, null);
        // ordered store -> atomic and ordered for size()
        soConsumerIndex(consumerIndex + 1);
        return e;
    }

    /**
     * {@inheritDoc}
     * <p>
     * This implementation is correct for single consumer thread use only.
     */
    @Override
    public E peek() {
        return lvRefElement(buffer, calcCircularRefElementOffset(lpConsumerIndex(), mask));
    }

    /** Same as {@link #offer(Object)}: offer is already non-blocking and may fail only when full. */
    @Override
    public boolean relaxedOffer(final E message) {
        return offer(message);
    }

    /** Same as {@link #poll()}: poll already returns null on empty without waiting. */
    @Override
    public E relaxedPoll() {
        return poll();
    }

    /** Same as {@link #peek()}. */
    @Override
    public E relaxedPeek() {
        return peek();
    }

    @Override
    public int drain(final Consumer<E> c) {
        return drain(c, capacity());
    }

    @Override
    public int fill(final Supplier<E> s) {
        return fill(s, capacity());
    }

    /**
     * Drains up to {@code limit} elements to the consumer, releasing each slot and publishing the
     * consumer index after every element. Returns the number of elements actually drained.
     */
    @Override
    public int drain(final Consumer<E> c, final int limit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative: " + limit);
        if (limit == 0)
            return 0;
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final long consumerIndex = this.lpConsumerIndex();
        for (int i = 0; i < limit; i++) {
            final long index = consumerIndex + i;
            final long offset = calcCircularRefElementOffset(index, mask);
            final E e = lvRefElement(buffer, offset);
            if (null == e) {
                // queue drained before reaching limit
                return i;
            }
            soRefElement(buffer, offset, null);
            // ordered store -> atomic and ordered for size()
            soConsumerIndex(index + 1);
            c.accept(e);
        }
        return limit;
    }

    /**
     * Fills up to {@code limit} elements from the supplier. When the look-ahead probe shows a free
     * stretch, elements are written in a batch without further probing; otherwise each slot is checked
     * individually and the method returns early when the queue is full.
     */
    @Override
    public int fill(final Supplier<E> s, final int limit) {
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        if (limit < 0)
            throw new IllegalArgumentException("limit is negative:" + limit);
        if (limit == 0)
            return 0;
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final int lookAheadStep = this.lookAheadStep;
        final long producerIndex = this.lpProducerIndex();
        for (int i = 0; i < limit; i++) {
            final long index = producerIndex + i;
            final long lookAheadElementOffset = calcCircularRefElementOffset(index + lookAheadStep, mask);
            if (null == lvRefElement(buffer, lookAheadElementOffset)) {
                // the stretch up to the probe is free: write a batch of up to lookAheadStep elements
                int lookAheadLimit = Math.min(lookAheadStep, limit - i);
                for (int j = 0; j < lookAheadLimit; j++) {
                    final long offset = calcCircularRefElementOffset(index + j, mask);
                    soRefElement(buffer, offset, s.get());
                    // ordered store -> atomic and ordered for size()
                    soProducerIndex(index + j + 1);
                }
                // outer loop's i++ accounts for the remaining element of the batch
                i += lookAheadLimit - 1;
            } else {
                final long offset = calcCircularRefElementOffset(index, mask);
                if (null != lvRefElement(buffer, offset)) {
                    // queue is full
                    return i;
                }
                soRefElement(buffer, offset, s.get());
                // ordered store -> atomic and ordered for size()
                soProducerIndex(index + 1);
            }
        }
        return limit;
    }

    /**
     * Continuously drains elements to the consumer, idling via the wait strategy while empty.
     * The exit condition is only re-checked between batches of 4096 iterations.
     */
    @Override
    public void drain(final Consumer<E> c, final WaitStrategy w, final ExitCondition exit) {
        if (null == c)
            throw new IllegalArgumentException("c is null");
        if (null == w)
            throw new IllegalArgumentException("wait is null");
        if (null == exit)
            throw new IllegalArgumentException("exit condition is null");
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        long consumerIndex = this.lpConsumerIndex();
        int counter = 0;
        while (exit.keepRunning()) {
            for (int i = 0; i < 4096; i++) {
                final long offset = calcCircularRefElementOffset(consumerIndex, mask);
                final E e = lvRefElement(buffer, offset);
                if (null == e) {
                    // empty: idle per the wait strategy, carrying the idle counter forward
                    counter = w.idle(counter);
                    continue;
                }
                consumerIndex++;
                counter = 0;
                soRefElement(buffer, offset, null);
                // ordered store -> atomic and ordered for size()
                soConsumerIndex(consumerIndex);
                c.accept(e);
            }
        }
    }

    /**
     * Continuously fills from the supplier, idling via the wait strategy while full.
     * Uses the same look-ahead batching as {@link #fill(Supplier, int)}.
     */
    @Override
    public void fill(final Supplier<E> s, final WaitStrategy w, final ExitCondition e) {
        if (null == w)
            throw new IllegalArgumentException("waiter is null");
        if (null == e)
            throw new IllegalArgumentException("exit condition is null");
        if (null == s)
            throw new IllegalArgumentException("supplier is null");
        final E[] buffer = this.buffer;
        final long mask = this.mask;
        final int lookAheadStep = this.lookAheadStep;
        long producerIndex = this.lpProducerIndex();
        int counter = 0;
        while (e.keepRunning()) {
            final long lookAheadElementOffset = calcCircularRefElementOffset(producerIndex + lookAheadStep, mask);
            if (null == lvRefElement(buffer, lookAheadElementOffset)) {
                // free stretch ahead: write a batch of lookAheadStep elements
                for (int j = 0; j < lookAheadStep; j++) {
                    final long offset = calcCircularRefElementOffset(producerIndex, mask);
                    producerIndex++;
                    soRefElement(buffer, offset, s.get());
                    // ordered store -> atomic and ordered for size()
                    soProducerIndex(producerIndex);
                }
            } else {
                final long offset = calcCircularRefElementOffset(producerIndex, mask);
                if (null != lvRefElement(buffer, offset)) {
                    // full: idle per the wait strategy, carrying the idle counter forward
                    counter = w.idle(counter);
                    continue;
                }
                producerIndex++;
                counter = 0;
                soRefElement(buffer, offset, s.get());
                // ordered store -> atomic and ordered for size()
                soProducerIndex(producerIndex);
            }
        }
    }
}
/**
 * Marks types and members that are public only for practical reasons (to support better testing or to
 * reduce code duplication) but are not intended as public API. Anything carrying this annotation may
 * change between releases without the change being considered a breaking API change (a major release).
 */
@Target({ElementType.TYPE, ElementType.METHOD, ElementType.FIELD, ElementType.CONSTRUCTOR})
@Retention(RetentionPolicy.SOURCE)
public @interface InternalAPI {
}
/**
 * Leading padding: 128 bytes of dead fields placed before the hot value field so that writers to
 * preceding memory cannot cause false sharing with it.
 */
abstract class PaddedAtomicLongL1Pad extends Number implements java.io.Serializable {
    private static final long serialVersionUID = 1;

    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
}

/**
 * The actual value field plus Unsafe-based accessors. Method-name prefixes follow the jctools
 * convention visible in the implementations below: sp = store plain, so = store ordered,
 * sv = store volatile, lv = load volatile, lp = load plain.
 */
abstract class PaddedAtomicLongL1Field extends PaddedAtomicLongL1Pad {
    private final static long VALUE_OFFSET = fieldOffset(PaddedAtomicLongL1Field.class, "value");
    private volatile long value;

    /** Plain (unordered, non-volatile) store. */
    public void spVal(long v) {
        UNSAFE.putLong(this, VALUE_OFFSET, v);
    }

    /** Ordered store: StoreStore barrier, cheaper than a full volatile write. */
    public void soVal(long v) {
        UNSAFE.putOrderedLong(this, VALUE_OFFSET, v);
    }

    /** Volatile store. */
    public void svVal(long v) {
        value = v;
    }

    /** Volatile load. */
    public long lvVal() {
        return value;
    }

    /** Plain load. */
    public long lpVal() {
        return UNSAFE.getLong(this, VALUE_OFFSET);
    }

    /** Atomic compare-and-swap; returns true when the swap succeeded. */
    public boolean casVal(long expectedV, long newV) {
        return UNSAFE.compareAndSwapLong(this, VALUE_OFFSET, expectedV, newV);
    }

    /** Atomic get-and-set, using the intrinsic when available, else a CAS loop. */
    public long getAndSetVal(long v) {
        if (UnsafeAccess.SUPPORTS_GET_AND_ADD_LONG) {
            return UNSAFE.getAndSetLong(this, VALUE_OFFSET, v);
        }
        else {
            long currV;
            do {
                currV = lvVal();
            } while (!casVal(currV, v));
            return currV;
        }
    }

    /** Atomic get-and-add, using the intrinsic when available, else a CAS loop. */
    public long getAndAddVal(long delta) {
        if (UnsafeAccess.SUPPORTS_GET_AND_ADD_LONG) {
            return UNSAFE.getAndAddLong(this, VALUE_OFFSET, delta);
        }
        else {
            long currV;
            do {
                currV = lvVal();
            } while (!casVal(currV, currV + delta));
            return currV;
        }
    }
}

/**
 * Trailing padding: 128 bytes of dead fields placed after the hot value field, completing the
 * isolation of the value on its own cache line(s).
 */
abstract class PaddedAtomicLongL2Pad extends PaddedAtomicLongL1Field {

    byte b000,b001,b002,b003,b004,b005,b006,b007;// 8b
    byte b010,b011,b012,b013,b014,b015,b016,b017;// 16b
    byte b020,b021,b022,b023,b024,b025,b026,b027;// 24b
    byte b030,b031,b032,b033,b034,b035,b036,b037;// 32b
    byte b040,b041,b042,b043,b044,b045,b046,b047;// 40b
    byte b050,b051,b052,b053,b054,b055,b056,b057;// 48b
    byte b060,b061,b062,b063,b064,b065,b066,b067;// 56b
    byte b070,b071,b072,b073,b074,b075,b076,b077;// 64b
    byte b100,b101,b102,b103,b104,b105,b106,b107;// 72b
    byte b110,b111,b112,b113,b114,b115,b116,b117;// 80b
    byte b120,b121,b122,b123,b124,b125,b126,b127;// 88b
    byte b130,b131,b132,b133,b134,b135,b136,b137;// 96b
    byte b140,b141,b142,b143,b144,b145,b146,b147;//104b
    byte b150,b151,b152,b153,b154,b155,b156,b157;//112b
    byte b160,b161,b162,b163,b164,b165,b166,b167;//120b
    byte b170,b171,b172,b173,b174,b175,b176,b177;//128b
}
/**
 * A padded version of the {@link java.util.concurrent.atomic.AtomicLong}: the value field is isolated
 * by the padding in the superclasses so that concurrent writers to neighboring fields cannot cause
 * false sharing with it.
 */
public class PaddedAtomicLong extends PaddedAtomicLongL2Pad {
    /**
     * Creates a new PaddedAtomicLong with initial value {@code 0}.
     */
    public PaddedAtomicLong() {
    }

    /**
     * Creates a new PaddedAtomicLong with the given initial value.
     *
     * @param initialValue the initial value
     */
    public PaddedAtomicLong(long initialValue) {
        svVal(initialValue);
    }

    /**
     * Gets the current value.
     *
     * @return the current value
     * @see java.util.concurrent.atomic.AtomicLong#get()
     */
    public long get() {
        return lvVal();
    }

    /**
     * Sets to the given value.
     *
     * @param newValue the new value
     * @see java.util.concurrent.atomic.AtomicLong#set(long)
     */
    public void set(long newValue) {
        svVal(newValue);
    }

    /**
     * Eventually sets to the given value.
     *
     * @param newValue the new value
     * @see java.util.concurrent.atomic.AtomicLong#lazySet(long)
     */
    public void lazySet(long newValue) {
        soVal(newValue);
    }

    /**
     * Atomically sets to the given value and returns the old value.
     *
     * @param newValue the new value
     * @return the previous value
     * @see java.util.concurrent.atomic.AtomicLong#getAndSet(long)
     */
    public long getAndSet(long newValue) {
        return getAndSetVal(newValue);
    }

    /**
     * Atomically sets the value to the given updated value
     * if the current value {@code ==} the expected value.
     *
     * @param expect the expected value
     * @param update the new value
     * @return {@code true} if successful. False return indicates that
     * the actual value was not equal to the expected value.
     * @see java.util.concurrent.atomic.AtomicLong#compareAndSet(long, long)
     */
    public boolean compareAndSet(long expect, long update) {
        return casVal(expect, update);
    }

    /**
     * Atomically sets the value to the given updated value
     * if the current value {@code ==} the expected value.
     *
     * <p>May fail
     * spuriously and does not provide ordering guarantees, so is
     * only rarely an appropriate alternative to {@code compareAndSet}.
     *
     * @param expect the expected value
     * @param update the new value
     * @return {@code true} if successful
     * @see java.util.concurrent.atomic.AtomicLong#weakCompareAndSet(long, long)
     */
    public boolean weakCompareAndSet(long expect, long update) {
        // delegates to the strong CAS, which satisfies the weak contract
        return casVal(expect, update);
    }

    /**
     * Atomically increments the current value by 1.
     *
     * @return the previous value
     * @see java.util.concurrent.atomic.AtomicLong#getAndIncrement()
     */
    public long getAndIncrement() {
        return getAndAddVal(1L);
    }

    /**
     * Atomically decrements the current value by 1.
     *
     * @return the previous value
     * @see java.util.concurrent.atomic.AtomicLong#getAndDecrement()
     */
    public long getAndDecrement() {
        return getAndAddVal(-1L);
    }

    /**
     * Atomically adds to the current value the given value.
     *
     * @param delta the value to add
     * @return the previous value
     * @see java.util.concurrent.atomic.AtomicLong#getAndAdd(long)
     */
    public long getAndAdd(long delta) {
        return getAndAddVal(delta);
    }

    /**
     * Atomically increments the current value by one.
     *
     * @return the updated value
     * @see java.util.concurrent.atomic.AtomicLong#incrementAndGet()
     */
    public long incrementAndGet() {
        return getAndAddVal(1L) + 1L;
    }

    /**
     * Atomically decrements the current value by one.
     *
     * @return the updated value
     * @see java.util.concurrent.atomic.AtomicLong#decrementAndGet()
     */
    public long decrementAndGet() {
        return getAndAddVal(-1L) - 1L;
    }

    /**
     * Atomically adds to the current value the given value.
     *
     * @param delta the value to add
     * @return the updated value
     * @see java.util.concurrent.atomic.AtomicLong#addAndGet(long)
     */
    public long addAndGet(long delta) {
        return getAndAddVal(delta) + delta;
    }

    /**
     * Atomically updates the current value with the results of
     * applying the given function, returning the previous value. The
     * function should be side-effect-free, since it may be re-applied
     * when attempted updates fail due to contention among threads.
     *
     * @param updateFunction a side-effect-free function
     * @return the previous value
     * @see java.util.concurrent.atomic.AtomicLong#getAndUpdate(LongUnaryOperator)
     */
    public long getAndUpdate(LongUnaryOperator updateFunction) {
        long prev, next;
        do {
            prev = lvVal();
            next = updateFunction.applyAsLong(prev);
        } while (!casVal(prev, next));
        return prev;
    }

    /**
     * Atomically updates the current value with the results of
     * applying the given function, returning the updated value. The
     * function should be side-effect-free, since it may be re-applied
     * when attempted updates fail due to contention among threads.
     *
     * @param updateFunction a side-effect-free function
     * @return the updated value
     * @see java.util.concurrent.atomic.AtomicLong#updateAndGet(LongUnaryOperator)
     */
    public long updateAndGet(LongUnaryOperator updateFunction) {
        long prev, next;
        do {
            prev = lvVal();
            next = updateFunction.applyAsLong(prev);
        } while (!casVal(prev, next));
        return next;
    }

    /**
     * Atomically updates the current value with the results of
     * applying the given function to the current and given values,
     * returning the previous value. The function should be
     * side-effect-free, since it may be re-applied when attempted
     * updates fail due to contention among threads. The function
     * is applied with the current value as its first argument,
     * and the given update as the second argument.
     *
     * @param v the update value
     * @param f a side-effect-free function of two arguments
     * @return the previous value
     * @see java.util.concurrent.atomic.AtomicLong#getAndAccumulate(long, LongBinaryOperator)
     */
    public long getAndAccumulate(long v, LongBinaryOperator f) {
        long prev, next;
        do {
            prev = lvVal();
            next = f.applyAsLong(prev, v);
        } while (!casVal(prev, next));
        return prev;
    }

    /**
     * Atomically updates the current value with the results of
     * applying the given function to the current and given values,
     * returning the updated value. The function should be
     * side-effect-free, since it may be re-applied when attempted
     * updates fail due to contention among threads.
     *
     * @param x the update value
     * @param f a side-effect-free function of two arguments
     * @return the updated value
     * @see java.util.concurrent.atomic.AtomicLong#accumulateAndGet(long, LongBinaryOperator)
     */
    public long accumulateAndGet(long x, LongBinaryOperator f) {
        long prev, next;
        do {
            prev = lvVal();
            next = f.applyAsLong(prev, x);
        } while (!casVal(prev, next));
        return next;
    }

    /**
     * Returns the String representation of the current value.
     *
     * @return the String representation of the current value
     */
    @Override
    public String toString() {
        return Long.toString(lvVal());
    }

    /**
     * Returns the value as an {@code int}.
     *
     * @see java.util.concurrent.atomic.AtomicLong#intValue()
     */
    @Override
    public int intValue() {
        return (int) lvVal();
    }

    /**
     * Returns the value as a {@code long}.
     *
     * @see java.util.concurrent.atomic.AtomicLong#longValue()
     */
    @Override
    public long longValue() {
        return lvVal();
    }

    /**
     * Returns the value of a {@code float}.
     *
     * @see java.util.concurrent.atomic.AtomicLong#floatValue()
     */
    @Override
    public float floatValue() {
        return (float) lvVal();
    }

    /**
     * Returns the value of a {@code double}.
     *
     * @see java.util.concurrent.atomic.AtomicLong#doubleValue()
     */
    @Override
    public double doubleValue() {
        return (double) lvVal();
    }
}
+ * + * @see java.util.concurrent.atomic.AtomicLong#doubleValue() + */ + @Override + public double doubleValue() { + return (double) lvVal(); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/PortableJvmInfo.java b/netty-jctools/src/main/java/org/jctools/util/PortableJvmInfo.java new file mode 100644 index 0000000..7396fe7 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/PortableJvmInfo.java @@ -0,0 +1,25 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; + +/** + * JVM Information that is standard and available on all JVMs (i.e. does not use unsafe) + */ +@InternalAPI +public interface PortableJvmInfo { + int CACHE_LINE_SIZE = Integer.getInteger("jctools.cacheLineSize", 64); + int CPUs = Runtime.getRuntime().availableProcessors(); + int RECOMENDED_OFFER_BATCH = CPUs * 4; + int RECOMENDED_POLL_BATCH = CPUs * 4; +} diff --git a/netty-jctools/src/main/java/org/jctools/util/Pow2.java b/netty-jctools/src/main/java/org/jctools/util/Pow2.java new file mode 100644 index 0000000..90aa759 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/Pow2.java @@ -0,0 +1,61 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. 
+ * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; + +/** + * Power of 2 utility functions. + */ +@InternalAPI +public final class Pow2 { + public static final int MAX_POW2 = 1 << 30; + + /** + * @param value from which next positive power of two will be found. + * @return the next positive power of 2, this value if it is a power of 2. + * @throws IllegalArgumentException if value is more than MAX_POW2 or less than 0 + */ + public static int roundToPowerOfTwo(final int value) { + if (value > MAX_POW2) { + throw new IllegalArgumentException("There is no larger power of 2 int for value:"+value+" since it exceeds 2^31."); + } + if (value < 0) { + throw new IllegalArgumentException("Given value:"+value+". Expecting value >= 0."); + } + final int nextPow2 = 1 << (32 - Integer.numberOfLeadingZeros(value - 1)); + return nextPow2; + } + + /** + * @param value to be tested to see if it is a power of two. + * @return true if the value is a power of 2 otherwise false. + */ + public static boolean isPowerOfTwo(final int value) { + return (value & (value - 1)) == 0; + } + + /** + * Align a value to the next multiple up of alignment. If the value equals an alignment multiple then it + * is returned unchanged. + * + * @param value to be aligned up. + * @param alignment to be used, must be a power of 2. + * @return the value aligned to the next boundary. 
+ */ + public static long align(final long value, final int alignment) { + if (!isPowerOfTwo(alignment)) { + throw new IllegalArgumentException("alignment must be a power of 2:" + alignment); + } + return (value + (alignment - 1)) & ~(alignment - 1); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/RangeUtil.java b/netty-jctools/src/main/java/org/jctools/util/RangeUtil.java new file mode 100644 index 0000000..37033cd --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/RangeUtil.java @@ -0,0 +1,68 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.util; + +@InternalAPI +public final class RangeUtil +{ + public static long checkPositive(long n, String name) + { + if (n <= 0) + { + throw new IllegalArgumentException(name + ": " + n + " (expected: > 0)"); + } + + return n; + } + + public static int checkPositiveOrZero(int n, String name) + { + if (n < 0) + { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= 0)"); + } + + return n; + } + + public static int checkLessThan(int n, int expected, String name) + { + if (n >= expected) + { + throw new IllegalArgumentException(name + ": " + n + " (expected: < " + expected + ')'); + } + + return n; + } + + public static int checkLessThanOrEqual(int n, long expected, String name) + { + if (n > expected) + { + throw new IllegalArgumentException(name + ": " + n + " (expected: <= " + expected + ')'); + } + + return n; + } + + public static int checkGreaterThanOrEqual(int n, int expected, String name) + { + if (n < expected) + { + throw new IllegalArgumentException(name + ": " + n + " (expected: >= " + expected + ')'); + } + + return n; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/SpscLookAheadUtil.java b/netty-jctools/src/main/java/org/jctools/util/SpscLookAheadUtil.java new file mode 100644 index 0000000..712de9a --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/SpscLookAheadUtil.java @@ -0,0 +1,12 @@ +package org.jctools.util; + +@InternalAPI +public class SpscLookAheadUtil +{ + public static final int MAX_LOOK_AHEAD_STEP = Integer.getInteger("jctools.spsc.max.lookahead.step", 4096); + + public static int computeLookAheadStep(int actualCapacity) + { + return Math.min(actualCapacity / 4, MAX_LOOK_AHEAD_STEP); + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/UnsafeAccess.java b/netty-jctools/src/main/java/org/jctools/util/UnsafeAccess.java new file mode 100755 index 0000000..adea7db --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/UnsafeAccess.java @@ -0,0 +1,114 
@@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; + +import sun.misc.Unsafe; + +import java.lang.reflect.Constructor; +import java.lang.reflect.Field; +import java.util.concurrent.atomic.AtomicReferenceArray; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +/** + * Why should we resort to using Unsafe?
+ *

    + *
  1. To construct class fields which allow volatile/ordered/plain access: This requirement is covered by + * {@link AtomicReferenceFieldUpdater} and similar but their performance is arguably worse than the DIY approach + * (depending on JVM version) while Unsafe intrinsification is a far lesser challenge for JIT compilers. + *
  2. To construct flavors of {@link AtomicReferenceArray}. + *
  3. Other use cases exist but are not present in this library yet. + *
+ * + * @author nitsanw + */ +@InternalAPI +public class UnsafeAccess +{ + public static final boolean SUPPORTS_GET_AND_SET_REF; + public static final boolean SUPPORTS_GET_AND_ADD_LONG; + public static final Unsafe UNSAFE; + + static + { + UNSAFE = getUnsafe(); + SUPPORTS_GET_AND_SET_REF = hasGetAndSetSupport(); + SUPPORTS_GET_AND_ADD_LONG = hasGetAndAddLongSupport(); + } + + private static Unsafe getUnsafe() + { + Unsafe instance; + try + { + final Field field = Unsafe.class.getDeclaredField("theUnsafe"); + field.setAccessible(true); + instance = (Unsafe) field.get(null); + } + catch (Exception ignored) + { + // Some platforms, notably Android, might not have a sun.misc.Unsafe implementation with a private + // `theUnsafe` static instance. In this case we can try to call the default constructor, which is sufficient + // for Android usage. + try + { + Constructor c = Unsafe.class.getDeclaredConstructor(); + c.setAccessible(true); + instance = c.newInstance(); + } + catch (Exception e) + { + throw new RuntimeException(e); + } + } + return instance; + } + + private static boolean hasGetAndSetSupport() + { + try + { + Unsafe.class.getMethod("getAndSetObject", Object.class, Long.TYPE, Object.class); + return true; + } + catch (Exception ignored) + { + } + return false; + } + + private static boolean hasGetAndAddLongSupport() + { + try + { + Unsafe.class.getMethod("getAndAddLong", Object.class, Long.TYPE, Long.TYPE); + return true; + } + catch (Exception ignored) + { + } + return false; + } + + public static long fieldOffset(Class clz, String fieldName) throws RuntimeException + { + try + { + return UNSAFE.objectFieldOffset(clz.getDeclaredField(fieldName)); + } + catch (NoSuchFieldException e) + { + throw new RuntimeException(e); + } + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/UnsafeJvmInfo.java b/netty-jctools/src/main/java/org/jctools/util/UnsafeJvmInfo.java new file mode 100644 index 0000000..8296104 --- /dev/null +++ 
b/netty-jctools/src/main/java/org/jctools/util/UnsafeJvmInfo.java @@ -0,0 +1,19 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; + +@InternalAPI +public interface UnsafeJvmInfo { + int PAGE_SIZE = UnsafeAccess.UNSAFE.pageSize(); +} diff --git a/netty-jctools/src/main/java/org/jctools/util/UnsafeLongArrayAccess.java b/netty-jctools/src/main/java/org/jctools/util/UnsafeLongArrayAccess.java new file mode 100644 index 0000000..d29b373 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/UnsafeLongArrayAccess.java @@ -0,0 +1,114 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.util; + +import static org.jctools.util.UnsafeAccess.UNSAFE; + +@InternalAPI +public final class UnsafeLongArrayAccess +{ + public static final long LONG_ARRAY_BASE; + public static final int LONG_ELEMENT_SHIFT; + + static + { + final int scale = UnsafeAccess.UNSAFE.arrayIndexScale(long[].class); + if (8 == scale) + { + LONG_ELEMENT_SHIFT = 3; + } + else + { + throw new IllegalStateException("Unknown pointer size: " + scale); + } + LONG_ARRAY_BASE = UnsafeAccess.UNSAFE.arrayBaseOffset(long[].class); + } + + /** + * A plain store (no ordering/fences) of an element to a given offset + * + * @param buffer le buffer + * @param offset computed via {@link UnsafeLongArrayAccess#calcLongElementOffset(long)} + * @param e an orderly kitty + */ + public static void spLongElement(long[] buffer, long offset, long e) + { + UNSAFE.putLong(buffer, offset, e); + } + + /** + * An ordered store of an element to a given offset + * + * @param buffer le buffer + * @param offset computed via {@link UnsafeLongArrayAccess#calcCircularLongElementOffset} + * @param e an orderly kitty + */ + public static void soLongElement(long[] buffer, long offset, long e) + { + UNSAFE.putOrderedLong(buffer, offset, e); + } + + /** + * A plain load (no ordering/fences) of an element from a given offset. + * + * @param buffer le buffer + * @param offset computed via {@link UnsafeLongArrayAccess#calcLongElementOffset(long)} + * @return the element at the offset + */ + public static long lpLongElement(long[] buffer, long offset) + { + return UNSAFE.getLong(buffer, offset); + } + + /** + * A volatile load of an element from a given offset. 
+ * + * @param buffer le buffer + * @param offset computed via {@link UnsafeLongArrayAccess#calcCircularLongElementOffset} + * @return the element at the offset + */ + public static long lvLongElement(long[] buffer, long offset) + { + return UNSAFE.getLongVolatile(buffer, offset); + } + + /** + * @param index desirable element index + * @return the offset in bytes within the array for a given index + */ + public static long calcLongElementOffset(long index) + { + return LONG_ARRAY_BASE + (index << LONG_ELEMENT_SHIFT); + } + + /** + * Note: circular arrays are assumed a power of 2 in length and the `mask` is (length - 1). + * + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the circular array for a given index + */ + public static long calcCircularLongElementOffset(long index, long mask) + { + return LONG_ARRAY_BASE + ((index & mask) << LONG_ELEMENT_SHIFT); + } + + /** + * This makes for an easier time generating the atomic queues, and removes some warnings. + */ + public static long[] allocateLongArray(int capacity) + { + return new long[capacity]; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/UnsafeRefArrayAccess.java b/netty-jctools/src/main/java/org/jctools/util/UnsafeRefArrayAccess.java new file mode 100644 index 0000000..598fad3 --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/UnsafeRefArrayAccess.java @@ -0,0 +1,121 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; + +import static org.jctools.util.UnsafeAccess.UNSAFE; + +@InternalAPI +public final class UnsafeRefArrayAccess +{ + public static final long REF_ARRAY_BASE; + public static final int REF_ELEMENT_SHIFT; + + static + { + final int scale = UnsafeAccess.UNSAFE.arrayIndexScale(Object[].class); + if (4 == scale) + { + REF_ELEMENT_SHIFT = 2; + } + else if (8 == scale) + { + REF_ELEMENT_SHIFT = 3; + } + else + { + throw new IllegalStateException("Unknown pointer size: " + scale); + } + REF_ARRAY_BASE = UnsafeAccess.UNSAFE.arrayBaseOffset(Object[].class); + } + + /** + * A plain store (no ordering/fences) of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @param e an orderly kitty + */ + public static void spRefElement(E[] buffer, long offset, E e) + { + UNSAFE.putObject(buffer, offset, e); + } + + /** + * An ordered store of an element to a given offset + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcCircularRefElementOffset} + * @param e an orderly kitty + */ + public static void soRefElement(E[] buffer, long offset, E e) + { + UNSAFE.putOrderedObject(buffer, offset, e); + } + + /** + * A plain load (no ordering/fences) of an element from a given offset. + * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcRefElementOffset(long)} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lpRefElement(E[] buffer, long offset) + { + return (E) UNSAFE.getObject(buffer, offset); + } + + /** + * A volatile load of an element from a given offset. 
+ * + * @param buffer this.buffer + * @param offset computed via {@link UnsafeRefArrayAccess#calcCircularRefElementOffset} + * @return the element at the offset + */ + @SuppressWarnings("unchecked") + public static E lvRefElement(E[] buffer, long offset) + { + return (E) UNSAFE.getObjectVolatile(buffer, offset); + } + + /** + * @param index desirable element index + * @return the offset in bytes within the array for a given index + */ + public static long calcRefElementOffset(long index) + { + return REF_ARRAY_BASE + (index << REF_ELEMENT_SHIFT); + } + + /** + * Note: circular arrays are assumed a power of 2 in length and the `mask` is (length - 1). + * + * @param index desirable element index + * @param mask (length - 1) + * @return the offset in bytes within the circular array for a given index + */ + public static long calcCircularRefElementOffset(long index, long mask) + { + return REF_ARRAY_BASE + ((index & mask) << REF_ELEMENT_SHIFT); + } + + /** + * This makes for an easier time generating the atomic queues, and removes some warnings. + */ + @SuppressWarnings("unchecked") + public static E[] allocateRefArray(int capacity) + { + return (E[]) new Object[capacity]; + } +} diff --git a/netty-jctools/src/main/java/org/jctools/util/package-info.java b/netty-jctools/src/main/java/org/jctools/util/package-info.java new file mode 100644 index 0000000..a2df00e --- /dev/null +++ b/netty-jctools/src/main/java/org/jctools/util/package-info.java @@ -0,0 +1,14 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; diff --git a/netty-jctools/src/test/java/org/jctools/counters/FixedSizeStripedLongCounterTest.java b/netty-jctools/src/test/java/org/jctools/counters/FixedSizeStripedLongCounterTest.java new file mode 100644 index 0000000..78710af --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/counters/FixedSizeStripedLongCounterTest.java @@ -0,0 +1,88 @@ +package org.jctools.counters; + +import static org.junit.Assert.*; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicLong; + +import org.jctools.util.PortableJvmInfo; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +/** + * @author Tolstopyatov Vsevolod + */ +@RunWith(Parameterized.class) +public class FixedSizeStripedLongCounterTest { + + @Parameterized.Parameters + public static Collection parameters() { + int stripesCount = PortableJvmInfo.CPUs * 2; + ArrayList list = new ArrayList<>(); + list.add(new Counter[]{new FixedSizeStripedLongCounterV6(stripesCount)}); + list.add(new Counter[]{new FixedSizeStripedLongCounterV8(stripesCount)}); + return list; + } + + private final Counter counter; + + public FixedSizeStripedLongCounterTest(Counter counter) { + this.counter = counter; + } + + @Test + public void testCounterSanity() { + long expected = 1000L; + for (int i = 0; i < expected; i++) { + counter.inc(); + } + + assertSanity(expected); + } + + @Test + public void testMultipleThreadsCounterSanity() throws Exception { + int threadsCount = PortableJvmInfo.CPUs; + AtomicLong summary = new AtomicLong(); + AtomicBoolean running = new AtomicBoolean(true); + CountDownLatch startLatch = new CountDownLatch(1); + CountDownLatch finishLatch = new CountDownLatch(threadsCount); + AtomicBoolean 
fail = new AtomicBoolean(false); + for (int i = 0; i < threadsCount; i++) { + new Thread(() -> { + try { + Counter c = counter; + startLatch.await(); + long local = 0; + while (running.get()) { + c.inc(); + local++; + } + summary.addAndGet(local); + } catch (Exception e) { + fail.set(true); + } + finally { + finishLatch.countDown(); + } + }).start(); + } + + startLatch.countDown(); + Thread.sleep(1000); + running.set(false); + finishLatch.await(); + assertFalse(fail.get()); + assertSanity(summary.get()); + } + + private void assertSanity(long expected) { + assertEquals(expected, counter.get()); + assertEquals(expected, counter.getAndReset()); + assertEquals(0L, counter.get()); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/KeyAtomicityTest.java b/netty-jctools/src/test/java/org/jctools/maps/KeyAtomicityTest.java new file mode 100644 index 0000000..f018631 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/KeyAtomicityTest.java @@ -0,0 +1,103 @@ +package org.jctools.maps; + +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; + +public class KeyAtomicityTest +{ + static final int THREAD_SEGMENT = 1000000; + + @Test + public void putReturnValuesAreDistinct() throws Exception + { + Map map = new NonBlockingHashMap<>(); + map.put("K", -1l); + int processors = Runtime.getRuntime().availableProcessors(); + CountDownLatch ready = new CountDownLatch(processors); + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(processors); + AtomicBoolean keepRunning = new AtomicBoolean(true); + PutKey[] putKeys = new PutKey[processors]; + for (int i = 0; i < processors; i++) + { + putKeys[i] = new PutKey(map, "K", keepRunning, ready, start, done, i * THREAD_SEGMENT); + Thread t = new Thread(putKeys[i]); + t.setName("Putty McPutkey-"+i); + t.start(); + } + ready.await(); + 
start.countDown(); + Thread.sleep(1000); + keepRunning.set(false); + done.await(); + Set values = new HashSet((int)(processors*THREAD_SEGMENT)); + long totalKeys = 0; + for (PutKey putKey : putKeys) + { + values.addAll(putKey.values); + totalKeys += putKey.endIndex - putKey.startIndex; + } + assertEquals(totalKeys, values.size()); + } + + static class PutKey implements Runnable + { + final Map map; + final String key; + final AtomicBoolean keepRunning; + final CountDownLatch ready; + final CountDownLatch start; + final CountDownLatch done; + final int startIndex; + int endIndex; + + List values = new ArrayList<>(THREAD_SEGMENT); + + PutKey( + Map map, + String key, + AtomicBoolean keepRunning, CountDownLatch ready, + CountDownLatch start, + CountDownLatch done, + int startIndex) + { + this.map = map; + this.key = key; + this.keepRunning = keepRunning; + this.ready = ready; + this.start = start; + this.done = done; + this.startIndex = startIndex; + assert startIndex >= 0 && startIndex + THREAD_SEGMENT > 0; + } + + @Override + public void run() + { + ready.countDown(); + try + { + start.await(); + } + catch (InterruptedException e) + { + e.printStackTrace(); + return; + } + long limit = startIndex + THREAD_SEGMENT; + long v = startIndex; + String k = key; + for (; v < limit && keepRunning.get(); v++) + { + values.add(map.put(k, v)); + } + endIndex = (int) v; + done.countDown(); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/NBHMIdentityKeyAtomicityTest.java b/netty-jctools/src/test/java/org/jctools/maps/NBHMIdentityKeyAtomicityTest.java new file mode 100644 index 0000000..9e1a035 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/NBHMIdentityKeyAtomicityTest.java @@ -0,0 +1,103 @@ +package org.jctools.maps; + +import org.junit.Test; + +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.Assert.assertEquals; + +public class 
NBHMIdentityKeyAtomicityTest +{ + static final int THREAD_SEGMENT = 1000000; + + @Test + public void putReturnValuesAreDistinct() throws Exception + { + Map map = new NonBlockingIdentityHashMap<>(); + map.put("K", -1l); + int processors = Runtime.getRuntime().availableProcessors(); + CountDownLatch ready = new CountDownLatch(processors); + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(processors); + AtomicBoolean keepRunning = new AtomicBoolean(true); + PutKey[] putKeys = new PutKey[processors]; + for (int i = 0; i < processors; i++) + { + putKeys[i] = new PutKey(map, "K", keepRunning, ready, start, done, i * THREAD_SEGMENT); + Thread t = new Thread(putKeys[i]); + t.setName("Putty McPutkey-"+i); + t.start(); + } + ready.await(); + start.countDown(); + Thread.sleep(1000); + keepRunning.set(false); + done.await(); + Set values = new HashSet((int)(processors*THREAD_SEGMENT)); + long totalKeys = 0; + for (PutKey putKey : putKeys) + { + values.addAll(putKey.values); + totalKeys += putKey.endIndex - putKey.startIndex; + } + assertEquals(totalKeys, values.size()); + } + + static class PutKey implements Runnable + { + final Map map; + final String key; + final AtomicBoolean keepRunning; + final CountDownLatch ready; + final CountDownLatch start; + final CountDownLatch done; + final int startIndex; + int endIndex; + + List values = new ArrayList<>(THREAD_SEGMENT); + + PutKey( + Map map, + String key, + AtomicBoolean keepRunning, CountDownLatch ready, + CountDownLatch start, + CountDownLatch done, + int startIndex) + { + this.map = map; + this.key = key; + this.keepRunning = keepRunning; + this.ready = ready; + this.start = start; + this.done = done; + this.startIndex = startIndex; + assert startIndex >= 0 && startIndex + THREAD_SEGMENT > 0; + } + + @Override + public void run() + { + ready.countDown(); + try + { + start.await(); + } + catch (InterruptedException e) + { + e.printStackTrace(); + return; + } + long limit = startIndex 
+ THREAD_SEGMENT; + long v = startIndex; + String k = key; + for (; v < limit && keepRunning.get(); v++) + { + values.add(map.put(k, v)); + } + endIndex = (int) v; + done.countDown(); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/NBHMLongKeyAtomicityTest.java b/netty-jctools/src/test/java/org/jctools/maps/NBHMLongKeyAtomicityTest.java new file mode 100644 index 0000000..5a4c005 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/NBHMLongKeyAtomicityTest.java @@ -0,0 +1,103 @@ +package org.jctools.maps; + +import org.junit.Test; + +import java.util.*; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.Assert.assertEquals; + +public class NBHMLongKeyAtomicityTest +{ + static final int THREAD_SEGMENT = 1000000; + static final long K = 1; + @Test + public void putReturnValuesAreDistinct() throws Exception + { + Map map = new NonBlockingHashMapLong<>(); + map.put(K, -1l); + int processors = Runtime.getRuntime().availableProcessors(); + CountDownLatch ready = new CountDownLatch(processors); + CountDownLatch start = new CountDownLatch(1); + CountDownLatch done = new CountDownLatch(processors); + AtomicBoolean keepRunning = new AtomicBoolean(true); + PutKey[] putKeys = new PutKey[processors]; + for (int i = 0; i < processors; i++) + { + putKeys[i] = new PutKey(map, K, keepRunning, ready, start, done, i * THREAD_SEGMENT); + Thread t = new Thread(putKeys[i]); + t.setName("Putty McPutkey-"+i); + t.start(); + } + ready.await(); + start.countDown(); + Thread.sleep(1000); + keepRunning.set(false); + done.await(); + Set values = new HashSet((int)(processors*THREAD_SEGMENT)); + long totalKeys = 0; + for (PutKey putKey : putKeys) + { + values.addAll(putKey.values); + totalKeys += putKey.endIndex - putKey.startIndex; + } + assertEquals(totalKeys, values.size()); + } + + static class PutKey implements Runnable + { + final Map map; + final long key; + final AtomicBoolean 
keepRunning; + final CountDownLatch ready; + final CountDownLatch start; + final CountDownLatch done; + final int startIndex; + int endIndex; + + List values = new ArrayList<>(THREAD_SEGMENT); + + PutKey( + Map map, + long key, + AtomicBoolean keepRunning, CountDownLatch ready, + CountDownLatch start, + CountDownLatch done, + int startIndex) + { + this.map = map; + this.key = key; + this.keepRunning = keepRunning; + this.ready = ready; + this.start = start; + this.done = done; + this.startIndex = startIndex; + assert startIndex >= 0 && startIndex + THREAD_SEGMENT > 0; + } + + @Override + public void run() + { + ready.countDown(); + try + { + start.await(); + } + catch (InterruptedException e) + { + e.printStackTrace(); + return; + } + long limit = startIndex + THREAD_SEGMENT; + long v = startIndex; + long k = key; + for (; v < limit && keepRunning.get(); v++) + { + values.add(map.put(k, v)); + } + endIndex = (int) v; + done.countDown(); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/NBHMRemoveTest.java b/netty-jctools/src/test/java/org/jctools/maps/NBHMRemoveTest.java new file mode 100644 index 0000000..d0fb0a6 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/NBHMRemoveTest.java @@ -0,0 +1,250 @@ +package org.jctools.maps; + +import org.junit.After; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.*; + +import static org.hamcrest.Matchers.instanceOf; +import static org.hamcrest.Matchers.is; +import static org.junit.Assert.*; +import static org.junit.Assume.assumeThat; + +@RunWith(Parameterized.class) +public class NBHMRemoveTest { + + public static final Long TEST_KEY_OTHER = 123777L; + public static final Long TEST_KEY_0 = 0L; + + @Parameterized.Parameters + public static Collection parameters() { + ArrayList list = new ArrayList<>(); + // Verify the test assumptions against JDK reference implementations, useful for debugging +// list.add(new Object[]{new 
HashMap<>(), TEST_KEY_0, "0", "1"}); +// list.add(new Object[]{new ConcurrentHashMap<>(), TEST_KEY_0, "0", "1"}); +// list.add(new Object[]{new Hashtable<>(), TEST_KEY_0, "0", "1"}); + + // Test with special key + list.add(new Object[]{new NonBlockingHashMap<>(), TEST_KEY_0, "0", "1"}); + list.add(new Object[]{new NonBlockingHashMapLong<>(), TEST_KEY_0, "0", "1"}); + list.add(new Object[]{new NonBlockingIdentityHashMap<>(), TEST_KEY_0, "0", "1"}); + + // Test with some other key + list.add(new Object[]{new NonBlockingHashMap<>(), TEST_KEY_OTHER, "0", "1"}); + list.add(new Object[]{new NonBlockingHashMapLong<>(), TEST_KEY_OTHER, "0", "1"}); + list.add(new Object[]{new NonBlockingIdentityHashMap<>(), TEST_KEY_OTHER, "0", "1"}); + + return list; + } + + final Map map; + final Long key; + final String v1; + final String v2; + + public NBHMRemoveTest(Map map, Long key, String v1, String v2) { + this.map = map; + this.key = key; + this.v1 = v1; + this.v2 = v2; + } + + @After + public void clear() { + map.clear(); + } + + /** + * This test demonstrates a retention issue in the NBHM implementation which keeps old keys around until a resize + * event. 
+ * See https://github.com/JCTools/JCTools/issues/354 + */ + @Test + public void removeRetainsKey() { + assumeThat(map, is(instanceOf(NonBlockingHashMap.class))); + assumeThat(key, is(TEST_KEY_0)); + + Long key1 = Long.valueOf(42424242); + Long key2 = Long.valueOf(42424242); + assertEquals(key1, key2); + assertFalse(key1 == key2); + // key1 and key2 are different instances with same hash/equals + map.put(key1, "a"); + map.remove(key1); + if (map instanceof NonBlockingHashMap) { + assertTrue(contains(((NonBlockingHashMap) map).raw_array(), key1)); + } + map.put(key2, "a"); + if (map instanceof NonBlockingHashMap) { + assertFalse(contains(((NonBlockingHashMap) map).raw_array(), key2)); + } + // key1 remains in the map + Set keySet = map.keySet(); + assertEquals(keySet.size(), 1); + if (map instanceof NonBlockingHashMap) { + assertTrue(keySet.toArray()[0] == key1); + } + } + + /** + * This test demonstrates a retention issue in the NBHM implementation which keeps old keys around until a resize + * event. 
+ * See https://github.com/JCTools/JCTools/issues/354 + */ + @Test + public void removeRetainsKey2() { + assumeThat(map, is(instanceOf(NonBlockingHashMap.class))); + assumeThat(key, is(TEST_KEY_0)); + + String key1 = "Aa"; + String key2 = "BB"; + NonBlockingHashMap map = new NonBlockingHashMap<>(); + assertEquals(key1.hashCode(), key2.hashCode()); + assertNotEquals(key1, key2); + assertFalse(key1 == key2); + // key1 and key2 are different instances with same hash, but different equals + map.put(key1, "a"); + map.remove(key1); + map.put(key2, "a"); + // key1 remains in the map + Set keySet = map.keySet(); + assertEquals(keySet.size(), 1); + assertEquals(keySet.toArray()[0], key2); + Object[] raw_array = map.raw_array(); + assertTrue(contains(raw_array, key1)); + assertTrue(contains(raw_array, key2)); + } + + private boolean contains(Object[] raw_array, Object v) { + for (int i = 0; i < raw_array.length; i++) { + if (raw_array[i] == v) return true; + } + return false; + } + + @Test + public void directRemoveKey() { + installValue(map, key, v1); + assertEquals(v1, map.remove(key)); + postRemoveAsserts(map, key); + assertFalse(map.containsValue(v1)); + } + + @Test + public void keySetIteratorRemoveKey() { + installValue(map, key, v1); + Iterator iterator = map.keySet().iterator(); + while (iterator.hasNext()) { + if (key.equals(iterator.next())) { + iterator.remove(); + break; + } + } + postRemoveAsserts(map, key); + assertFalse(map.containsValue(v1)); + } + + @Test + public void keySetIteratorRemoveKeyAfterValChange() { + installValue(map, key, v1); + Iterator iterator = map.keySet().iterator(); + map.put(key, v2); + while (iterator.hasNext()) { + if (key.equals(iterator.next())) { + iterator.remove(); + break; + } + } + postRemoveAsserts(map, key); + assertFalse(map.containsValue(v1)); + assertFalse(map.containsValue(v2)); + } + + @Test + public void entriesIteratorRemoveKey() { + installValue(map, key, v1); + Iterator> iterator = map.entrySet().iterator(); + while 
(iterator.hasNext()) { + Map.Entry entry = iterator.next(); + if (key.equals(entry.getKey())) { + iterator.remove(); + break; + } + } + postRemoveAsserts(map, key); + assertFalse(map.containsValue(v1)); + } + + @Test + public void entriesIteratorRemoveKeyAfterValChange() { + installValue(map, key, v1); + Iterator> iterator = map.entrySet().iterator(); + assertTrue(iterator.hasNext()); + Map.Entry entry = iterator.next(); + assertEquals(key, entry.getKey()); + // change the value for the key + map.put(key, v2); + + iterator.remove(); + // This is weird, since the entry has in fact changed, so should not be removed, but JDK implementations + // all remove based on the key. + postRemoveAsserts(map, key); + assertFalse(map.containsValue(v1)); + assertFalse(map.containsValue(v2)); + } + + @Test + public void valuesIteratorRemove() { + installValue(map, key, v1); + Iterator iterator = map.values().iterator(); + while (iterator.hasNext()) { + if (v1.equals(iterator.next())) { + iterator.remove(); + break; + } + } + postRemoveAsserts(map, key); + assertFalse(map.containsValue(v1)); + } + + @Test + public void valuesIteratorRemoveAfterValChange() { + installValue(map, key, v1); + Iterator iterator = map.values().iterator(); + assertTrue(iterator.hasNext()); + String value = iterator.next(); + assertEquals(v1, value); + + // change the value for the key + map.put(key, v2); + + iterator.remove(); + // This is weird, since the entry has in fact changed, so should not be removed, but JDK implementations + // all remove based on the key. 
/**
 * Regression tests: {@code Map.replace(key, value)} on an empty map must return
 * {@code null} and must not install a mapping, for each non-blocking map variant.
 */
public class NBHMReplaceTest {
    // replace() on an empty NonBlockingHashMap must be a no-op returning null
    @Test
    public void replaceOnEmptyMap() {
        assertEquals(null, new NonBlockingHashMap().replace("k", "v"));
    }
    // same contract for the identity-comparing variant
    @Test
    public void replaceOnEmptyIdentityMap() {
        assertEquals(null, new NonBlockingIdentityHashMap().replace("k", "v"));
    }
    // same contract for the primitive-long-keyed variant
    @Test
    public void replaceOnEmptyLongMap() {
        assertEquals(null, new NonBlockingHashMapLong().replace(1, "v"));
    }
}
com.google.common.collect.testing.*; +import com.google.common.collect.testing.features.CollectionFeature; +import com.google.common.collect.testing.features.CollectionSize; +import com.google.common.collect.testing.features.MapFeature; +import com.google.common.collect.testing.testers.MapReplaceEntryTester; +import com.google.common.collect.testing.testers.MapReplaceTester; + +import junit.framework.Test; +import junit.framework.TestCase; +import junit.framework.TestSuite; + +/** + * @author Tolstopyatov Vsevolod + * @since 04/06/17 + */ +@SuppressWarnings("unchecked") +public class NonBlockingHashMapGuavaTestSuite extends TestCase +{ + + public static Test suite() throws Exception + { + TestSuite suite = new TestSuite(); + + TestSuite mapSuite = mapTestSuite(new TestStringMapGenerator() + { + @Override + protected Map create(Map.Entry[] entries) + { + Map map = new NonBlockingHashMap<>(); + for (Map.Entry entry : entries) + { + map.put(entry.getKey(), entry.getValue()); + } + return map; + } + }, NonBlockingHashMap.class.getSimpleName()); + + TestSuite idMapSuite = mapTestSuite(new TestStringMapGenerator() + { + @Override + protected Map create(Map.Entry[] entries) + { + Map map = new NonBlockingIdentityHashMap<>(); + for (Map.Entry entry : entries) + { + map.put(entry.getKey(), entry.getValue()); + } + return map; + } + }, NonBlockingIdentityHashMap.class.getSimpleName()); + + TestSuite longMapSuite = mapTestSuite(new TestMapGenerator() + { + @Override + public Long[] createKeyArray(int length) + { + return new Long[length]; + } + + @Override + public Long[] createValueArray(int length) + { + return new Long[length]; + } + + @Override + public SampleElements> samples() + { + return new SampleElements<>( + Helpers.mapEntry(1L, 1L), + Helpers.mapEntry(2L, 2L), + Helpers.mapEntry(3L, 3L), + Helpers.mapEntry(4L, 4L), + Helpers.mapEntry(5L, 5L)); + } + + @Override + public Map create(Object... 
elements) + { + Map map = new NonBlockingHashMapLong<>(); + for (Object o : elements) + { + Map.Entry e = (Map.Entry) o; + map.put(e.getKey(), e.getValue()); + } + return map; + } + + @Override + public Map.Entry[] createArray(int length) + { + return new Map.Entry[length]; + } + + @Override + public Iterable> order(List> insertionOrder) + { + return insertionOrder; + } + }, NonBlockingHashMapLong.class.getSimpleName()); + + suite.addTest(mapSuite); + suite.addTest(idMapSuite); + suite.addTest(longMapSuite); + return suite; + } + + private static TestSuite mapTestSuite(TestMapGenerator testMapGenerator, String name) + { + return new MapTestSuiteBuilder() + { + { + usingGenerator(testMapGenerator); + } + + @Override + protected List> getTesters() + { + List> testers = new ArrayList<>(super.getTesters()); + // NonBlockingHashMap doesn't support null in putIfAbsent and provides putIfAbsentAllowsNull instead + testers.remove(MapReplaceEntryTester.class); + testers.remove(MapReplaceTester.class); + return testers; + } + }.withFeatures( + MapFeature.GENERAL_PURPOSE, + CollectionSize.ANY, + CollectionFeature.SUPPORTS_ITERATOR_REMOVE) + .named(name) + .createTestSuite(); + } + +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckMapTest.java b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckMapTest.java new file mode 100644 index 0000000..add62e7 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckMapTest.java @@ -0,0 +1,109 @@ +package org.jctools.maps.linearizability_test; + +import org.jetbrains.annotations.NotNull; +import org.jetbrains.kotlinx.lincheck.*; +import org.jetbrains.kotlinx.lincheck.annotations.*; +import org.jetbrains.kotlinx.lincheck.paramgen.*; +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions; +import org.jetbrains.kotlinx.lincheck.strategy.stress.StressOptions; +import 
org.jetbrains.kotlinx.lincheck.verifier.VerifierState; +import org.junit.Test; + +import java.util.Map; + +@Param(name = "key", gen = LongGen.class, conf = "1:5") // keys are longs in 1..5 range +@Param(name = "value", gen = IntGen.class, conf = "1:10") // values are longs in 1..10 range +public abstract class LincheckMapTest extends VerifierState +{ + private final Map map; + + public LincheckMapTest(Map map) + { + this.map = map; + } + + @Operation + public Integer get(@Param(name = "key") long key) + { + return map.get(key); + } + + @Operation + public Integer put(@Param(name = "key") long key, @Param(name = "value") int value) + { + return map.put(key, value); + } + + @Operation + public boolean replace(@Param(name = "key") long key, @Param(name = "value") int previousValue, @Param(name = "value") int nextValue) + { + return map.replace(key, previousValue, nextValue); + } + + @Operation + public Integer remove(@Param(name = "key") long key) + { + return map.remove(key); + } + + @Operation + public boolean containsKey(@Param(name = "key") long key) + { + return map.containsKey(key); + } + + @Operation + public boolean containsValue(@Param(name = "value") int value) + { + return map.containsValue(value); + } + + @Operation + public void clear() + { + map.clear(); + } + + /** + * This test checks that the concurrent map is linearizable with bounded model checking. + * Unlike stress testing, this approach can also provide a trace of an incorrect execution. + * However, it uses sequential consistency model, so it can not find any low-level bugs (e.g., missing 'volatile'), + * and thus, it it recommended to have both test modes. + */ + @Test + public void modelCheckingTest() + { + ModelCheckingOptions options = new ModelCheckingOptions(); + // The size of the test can be changed with 'options.iterations' or `options.invocationsPerIteration`. 
+ // The first one defines the number of different scenarios generated, + // while the second one determines how deeply each scenario is tested. + new LinChecker(this.getClass(), options).check(); + } + + /** + * This test checks that the concurrent map is linearizable with stress testing. + */ + @Test + public void stressTest() + { + StressOptions options = new StressOptions(); + // The size of the test can be changed with 'options.iterations' or `options.invocationsPerIteration`. + // The first one defines the number of different scenarios generated, + // while the second one determines how deeply each scenario is tested. + new LinChecker(this.getClass(), options).check(); + } + + /** + * Provides something with correct equals and hashCode methods + * that can be interpreted as an internal data structure state for faster verification. + * The only limitation is that it should be different for different data structure states. + * For {@link Map} it itself is used. + * @return object representing internal state + */ + @NotNull + @Override + protected Object extractState() + { + return map; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckSetTest.java b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckSetTest.java new file mode 100644 index 0000000..763e571 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/LincheckSetTest.java @@ -0,0 +1,90 @@ +package org.jctools.maps.linearizability_test; + +import org.jetbrains.annotations.NotNull; +import org.jetbrains.kotlinx.lincheck.*; +import org.jetbrains.kotlinx.lincheck.annotations.*; +import org.jetbrains.kotlinx.lincheck.paramgen.*; +import org.jetbrains.kotlinx.lincheck.strategy.managed.modelchecking.ModelCheckingOptions; +import org.jetbrains.kotlinx.lincheck.strategy.stress.StressOptions; +import org.jetbrains.kotlinx.lincheck.verifier.VerifierState; +import org.junit.Test; + +import java.util.Set; + +@Param(name 
= "key", gen = IntGen.class, conf = "1:5") // keys are longs in 1..5 range +public abstract class LincheckSetTest extends VerifierState +{ + private final Set set; + + public LincheckSetTest(Set set) + { + this.set = set; + } + + @Operation + public boolean contains(@Param(name = "key") int key) + { + return set.contains(key); + } + + @Operation + public boolean add(@Param(name = "key") int key) + { + return set.add(key); + } + + @Operation + public boolean remove(@Param(name = "key") int key) + { + return set.remove(key); + } + + @Operation + public void clear() + { + set.clear(); + } + + /** + * This test checks that the concurrent set is linearizable with bounded model checking. + * Unlike stress testing, this approach can also provide a trace of an incorrect execution. + * However, it uses sequential consistency model, so it can not find any low-level bugs (e.g., missing 'volatile'), + * and thus, it it recommended to have both test modes. + */ + @Test + public void modelCheckingTest() + { + ModelCheckingOptions options = new ModelCheckingOptions(); + // The size of the test can be changed with 'options.iterations' or `options.invocationsPerIteration`. + // The first one defines the number of different scenarios generated, + // while the second one determines how deeply each scenario is tested. + new LinChecker(this.getClass(), options).check(); + } + + /** + * This test checks that the concurrent set is linearizable with stress testing. + */ + @Test + public void stressTest() + { + StressOptions options = new StressOptions(); + // The size of the test can be changed with 'options.iterations' or `options.invocationsPerIteration`. + // The first one defines the number of different scenarios generated, + // while the second one determines how deeply each scenario is tested. 
+ new LinChecker(this.getClass(), options).check(); + } + + /** + * Provides something with correct equals and hashCode methods + * that can be interpreted as an internal data structure state for faster verification. + * The only limitation is that it should be different for different data structure states. + * For {@link Set} it itself is used. + * @return object representing internal state + */ + @NotNull + @Override + protected Object extractState() + { + return set; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLinearizabilityTest.java b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLinearizabilityTest.java new file mode 100644 index 0000000..99ecbc1 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLinearizabilityTest.java @@ -0,0 +1,11 @@ +package org.jctools.maps.linearizability_test; + +import org.jctools.maps.NonBlockingHashMap; + +public class NonBlockingHashMapLinearizabilityTest extends LincheckMapTest +{ + public NonBlockingHashMapLinearizabilityTest() + { + super(new NonBlockingHashMap<>()); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLongLinearizabilityTest.java b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLongLinearizabilityTest.java new file mode 100644 index 0000000..ff482d6 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/linearizability_test/NonBlockingHashMapLongLinearizabilityTest.java @@ -0,0 +1,11 @@ +package org.jctools.maps.linearizability_test; + +import org.jctools.maps.NonBlockingHashMapLong; + +public class NonBlockingHashMapLongLinearizabilityTest extends LincheckMapTest +{ + public NonBlockingHashMapLongLinearizabilityTest() + { + super(new NonBlockingHashMapLong<>()); + } +} diff --git 
/**
 * Lincheck linearizability test for {@link NonBlockingHashSet}; all operations are
 * inherited from {@code LincheckSetTest}, this class only supplies the set under test.
 */
public class NonBlockingHashSetLinearizabilityTest extends LincheckSetTest
{
    public NonBlockingHashSetLinearizabilityTest()
    {
        super(new NonBlockingHashSet<>());
    }
}
/**
 * Lincheck linearizability test for {@link NonBlockingSetInt}; all operations are
 * inherited from {@code LincheckSetTest}, this class only supplies the set under test.
 */
public class NonBlockingSetIntLinearizabilityTest extends LincheckSetTest
{
    public NonBlockingSetIntLinearizabilityTest()
    {
        // NonBlockingSetInt is not generic, so no diamond here
        super(new NonBlockingSetInt());
    }
}
+ */ +package org.jctools.maps.nbhm_test; + +import java.io.*; +import java.util.*; +import java.util.concurrent.*; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.jctools.maps.NonBlockingIdentityHashMap; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +// Test NonBlockingHashMap via JUnit +public class NBHMID_Tester2 +{ + static private NonBlockingIdentityHashMap _nbhm; + + @BeforeClass + public static void setUp() + { + _nbhm = new NonBlockingIdentityHashMap<>(); + } + + @AfterClass + public static void tearDown() + { + _nbhm = null; + } + + // Test some basic stuff; add a few keys, remove a few keys + @Test + public void testBasic() + { + assertTrue(_nbhm.isEmpty()); + assertThat(_nbhm.putIfAbsent("k1", "v1"), nullValue()); + checkSizes(1); + assertThat(_nbhm.putIfAbsent("k2", "v2"), nullValue()); + checkSizes(2); + assertTrue(_nbhm.containsKey("k2")); + assertThat(_nbhm.put("k1", "v1a"), is("v1")); + assertThat(_nbhm.put("k2", "v2a"), is("v2")); + checkSizes(2); + assertThat(_nbhm.putIfAbsent("k2", "v2b"), is("v2a")); + assertThat(_nbhm.remove("k1"), is("v1a")); + assertFalse(_nbhm.containsKey("k1")); + checkSizes(1); + assertThat(_nbhm.remove("k1"), nullValue()); + assertThat(_nbhm.remove("k2"), is("v2a")); + checkSizes(0); + assertThat(_nbhm.remove("k2"), nullValue()); + assertThat(_nbhm.remove("k3"), nullValue()); + assertTrue(_nbhm.isEmpty()); + + assertThat(_nbhm.put("k0", "v0"), nullValue()); + assertTrue(_nbhm.containsKey("k0")); + checkSizes(1); + assertThat(_nbhm.remove("k0"), is("v0")); + assertFalse(_nbhm.containsKey("k0")); + checkSizes(0); + + assertThat(_nbhm.replace("k0", "v0"), nullValue()); + assertFalse(_nbhm.containsKey("k0")); + assertThat(_nbhm.put("k0", "v0"), nullValue()); + assertEquals(_nbhm.replace("k0", "v0a"), "v0"); + assertEquals(_nbhm.get("k0"), "v0a"); + assertThat(_nbhm.remove("k0"), is("v0a")); + assertFalse(_nbhm.containsKey("k0")); + 
checkSizes(0); + + assertThat(_nbhm.replace("k1", "v1"), nullValue()); + assertFalse(_nbhm.containsKey("k1")); + assertThat(_nbhm.put("k1", "v1"), nullValue()); + assertEquals(_nbhm.replace("k1", "v1a"), "v1"); + assertEquals(_nbhm.get("k1"), "v1a"); + assertThat(_nbhm.remove("k1"), is("v1a")); + assertFalse(_nbhm.containsKey("k1")); + checkSizes(0); + + // Insert & Remove KeyBonks until the table resizes and we start + // finding Tombstone keys- and KeyBonk's equals-call with throw a + // ClassCastException if it sees a non-KeyBonk. + NonBlockingIdentityHashMap dumb = new NonBlockingIdentityHashMap<>(); + for (int i = 0; i < 10000; i++) + { + final KeyBonk happy1 = new KeyBonk(i); + assertThat(dumb.put(happy1, "and"), nullValue()); + if ((i & 1) == 0) + { + dumb.remove(happy1); + } + final KeyBonk happy2 = new KeyBonk(i); // 'equals' but not '==' + dumb.get(happy2); + } + } + + // Check all iterators for correct size counts + private void checkSizes(int expectedSize) + { + assertEquals("size()", _nbhm.size(), expectedSize); + Collection vals = _nbhm.values(); + checkSizes("values()", vals.size(), vals.iterator(), expectedSize); + Set keys = _nbhm.keySet(); + checkSizes("keySet()", keys.size(), keys.iterator(), expectedSize); + Set> ents = _nbhm.entrySet(); + checkSizes("entrySet()", ents.size(), ents.iterator(), expectedSize); + } + + // Check that the iterator iterates the correct number of times + private void checkSizes(String msg, int sz, Iterator it, int expectedSize) + { + assertEquals(msg, expectedSize, sz); + int result = 0; + while (it.hasNext()) + { + result++; + it.next(); + } + assertEquals(msg, expectedSize, result); + } + + @Test + public void testIteration() + { + assertTrue(_nbhm.isEmpty()); + assertThat(_nbhm.put("k1", "v1"), nullValue()); + assertThat(_nbhm.put("k2", "v2"), nullValue()); + + String str1 = ""; + for (Map.Entry e : _nbhm.entrySet()) + { + str1 += e.getKey(); + } + assertThat("found all entries", str1, anyOf(is("k1k2"), 
is("k2k1"))); + + String str2 = ""; + for (String key : _nbhm.keySet()) + { + str2 += key; + } + assertThat("found all keys", str2, anyOf(is("k1k2"), is("k2k1"))); + + String str3 = ""; + for (String val : _nbhm.values()) + { + str3 += val; + } + assertThat("found all vals", str3, anyOf(is("v1v2"), is("v2v1"))); + + assertThat("toString works", _nbhm.toString(), anyOf(is("{k1=v1, k2=v2}"), is("{k2=v2, k1=v1}"))); + _nbhm.clear(); + } + + @Test + public void testSerial() + { + assertTrue(_nbhm.isEmpty()); + assertThat(_nbhm.put("k1", "v1"), nullValue()); + assertThat(_nbhm.put("k2", "v2"), nullValue()); + + // Serialize it out + try + { + FileOutputStream fos = new FileOutputStream("NBHM_test.txt"); + ObjectOutputStream out = new ObjectOutputStream(fos); + out.writeObject(_nbhm); + out.close(); + } + catch (IOException ex) + { + ex.printStackTrace(); + } + + // Read it back + try + { + File f = new File("NBHM_test.txt"); + FileInputStream fis = new FileInputStream(f); + ObjectInputStream in = new ObjectInputStream(fis); + NonBlockingIdentityHashMap nbhm = (NonBlockingIdentityHashMap) in.readObject(); + in.close(); + + assertEquals(_nbhm.size(), nbhm.size()); + Object[] keys = nbhm.keySet().toArray(); + if (keys[0].equals("k1")) + { + assertEquals(nbhm.get(keys[0]), "v1"); + assertEquals(nbhm.get(keys[1]), "v2"); + } + else + { + assertEquals(nbhm.get(keys[1]), "v1"); + assertEquals(nbhm.get(keys[0]), "v2"); + } + + if (!f.delete()) + { + throw new IOException("delete failed"); + } + } + catch (IOException | ClassNotFoundException ex) + { + ex.printStackTrace(); + } + _nbhm.clear(); + } + + @Test + public void testIterationBig2() + { + final int CNT = 10000; + NonBlockingIdentityHashMap nbhm = new NonBlockingIdentityHashMap<>(); + final String v = "v"; + for (int i = 0; i < CNT; i++) + { + final Integer z = i; + String s0 = nbhm.get(z); + assertThat(s0, nullValue()); + nbhm.put(z, v); + String s1 = nbhm.get(z); + assertThat(s1, is(v)); + } + assertThat(nbhm.size(), 
is(CNT)); + _nbhm.clear(); + } + + @Test + public void testIterationBig() + { + final int CNT = 10000; + String[] keys = new String[CNT]; + String[] vals = new String[CNT]; + assertThat(_nbhm.size(), is(0)); + for (int i = 0; i < CNT; i++) + { + _nbhm.put(keys[i] = ("k" + i), vals[i] = ("v" + i)); + } + assertThat(_nbhm.size(), is(CNT)); + + int sz = 0; + int sum = 0; + for (String s : _nbhm.keySet()) + { + sz++; + assertThat("", s.charAt(0), is('k')); + int x = Integer.parseInt(s.substring(1)); + sum += x; + assertTrue(x >= 0 && x <= (CNT - 1)); + } + assertThat("Found 10000 ints", sz, is(CNT)); + assertThat("Found all integers in list", sum, is(CNT * (CNT - 1) / 2)); + + assertThat("can remove 3", _nbhm.remove(keys[3]), is(vals[3])); + assertThat("can remove 4", _nbhm.remove(keys[4]), is(vals[4])); + sz = 0; + sum = 0; + for (String s : _nbhm.keySet()) + { + sz++; + assertThat("", s.charAt(0), is('k')); + int x = Integer.parseInt(s.substring(1)); + sum += x; + assertTrue(x >= 0 && x <= (CNT - 1)); + String v = _nbhm.get(s); + assertThat("", v.charAt(0), is('v')); + assertThat("", s.substring(1), is(v.substring(1))); + } + assertThat("Found " + (CNT - 2) + " ints", sz, is(CNT - 2)); + assertThat("Found all integers in list", sum, is(CNT * (CNT - 1) / 2 - (3 + 4))); + _nbhm.clear(); + } + + // Do some simple concurrent testing + @Test + public void testConcurrentSimple() throws InterruptedException + { + final NonBlockingIdentityHashMap nbhm = new NonBlockingIdentityHashMap<>(); + final String[] keys = new String[20000]; + for (int i = 0; i < 20000; i++) + { + keys[i] = "k" + i; + } + + // In 2 threads, add & remove even & odd elements concurrently + Thread t1 = new Thread() + { + public void run() + { + work_helper(nbhm, "T1", 1, keys); + } + }; + t1.start(); + work_helper(nbhm, "T0", 0, keys); + t1.join(); + + // In the end, all members should be removed + StringBuilder buf = new StringBuilder(); + buf.append("Should be emptyset but has these elements: {"); + 
boolean found = false; + for (String x : nbhm.keySet()) + { + buf.append(" ").append(x); + found = true; + } + if (found) + { + System.out.println(buf + " }"); + } + assertThat("concurrent size=0", nbhm.size(), is(0)); + assertThat("keySet size=0", nbhm.keySet().size(), is(0)); + } + + void work_helper(NonBlockingIdentityHashMap nbhm, String thrd, int d, String[] keys) + { + final int ITERS = 20000; + for (int j = 0; j < 10; j++) + { + //long start = System.nanoTime(); + for (int i = d; i < ITERS; i += 2) + { + assertThat("this key not in there, so putIfAbsent must work", + nbhm.putIfAbsent(keys[i], thrd), is((String) null)); + } + for (int i = d; i < ITERS; i += 2) + { + assertTrue(nbhm.remove(keys[i], thrd)); + } + //double delta_nanos = System.nanoTime()-start; + //double delta_secs = delta_nanos/1000000000.0; + //double ops = ITERS*2; + //System.out.println("Thrd"+thrd+" "+(ops/delta_secs)+" ops/sec size="+nbhm.size()); + } + } + + @Test + public final void testNonBlockingIdentityHashMapSize() + { + NonBlockingIdentityHashMap items = new NonBlockingIdentityHashMap<>(); + items.put(100L, "100"); + items.put(101L, "101"); + + assertEquals("keySet().size()", 2, items.keySet().size()); + assertTrue("keySet().contains(100)", items.keySet().contains(100L)); + assertTrue("keySet().contains(101)", items.keySet().contains(101L)); + + assertEquals("values().size()", 2, items.values().size()); + assertTrue("values().contains(\"100\")", items.values().contains("100")); + assertTrue("values().contains(\"101\")", items.values().contains("101")); + + assertEquals("entrySet().size()", 2, items.entrySet().size()); + boolean found100 = false; + boolean found101 = false; + for (Map.Entry entry : items.entrySet()) + { + if (entry.getKey().equals(100L)) + { + assertEquals("entry[100].getValue()==\"100\"", "100", entry.getValue()); + found100 = true; + } + else if (entry.getKey().equals(101L)) + { + assertEquals("entry[101].getValue()==\"101\"", "101", entry.getValue()); + found101 
= true; + } + } + assertTrue("entrySet().contains([100])", found100); + assertTrue("entrySet().contains([101])", found101); + } + + // Concurrent insertion & then iterator test. + @Test + public void testNonBlockingIdentityHashMapIterator() throws InterruptedException + { + final int ITEM_COUNT1 = 1000; + final int THREAD_COUNT = 5; + final int PER_CNT = ITEM_COUNT1 / THREAD_COUNT; + final int ITEM_COUNT = PER_CNT * THREAD_COUNT; // fix roundoff for odd thread counts + + NonBlockingIdentityHashMap nbhml = new NonBlockingIdentityHashMap<>(); + // use a barrier to open the gate for all threads at once to avoid rolling + // start and no actual concurrency + final CyclicBarrier barrier = new CyclicBarrier(THREAD_COUNT); + final ExecutorService ex = Executors.newFixedThreadPool(THREAD_COUNT); + final CompletionService co = new ExecutorCompletionService<>(ex); + for (int i = 0; i < THREAD_COUNT; i++) + { + co.submit(new NBHMLFeeder(nbhml, PER_CNT, barrier, i * PER_CNT)); + } + for (int retCount = 0; retCount < THREAD_COUNT; retCount++) + { + co.take(); + } + ex.shutdown(); + + assertEquals("values().size()", ITEM_COUNT, nbhml.values().size()); + assertEquals("entrySet().size()", ITEM_COUNT, nbhml.entrySet().size()); + int itemCount = 0; + for (TestKey K : nbhml.values()) + { + itemCount++; + } + assertEquals("values().iterator() count", ITEM_COUNT, itemCount); + } + + // --- Customer Test Case 3 ------------------------------------------------ + private TestKeyFeeder getTestKeyFeeder() + { + final TestKeyFeeder feeder = new TestKeyFeeder(); + feeder.checkedPut(10401000001844L, 657829272, 680293140); // section 12 + feeder.checkedPut(10401000000614L, 657829272, 401326994); // section 12 + feeder.checkedPut(10400345749304L, 2095121916, -9852212); // section 12 + feeder.checkedPut(10401000002204L, 657829272, 14438460); // section 12 + feeder.checkedPut(10400345749234L, 1186831289, -894006017); // section 12 + feeder.checkedPut(10401000500234L, 969314784, -2112018706); // 
section 12 + feeder.checkedPut(10401000000284L, 657829272, 521425852); // section 12 + feeder.checkedPut(10401000002134L, 657829272, 208406306); // section 12 + feeder.checkedPut(10400345749254L, 2095121916, -341939818); // section 12 + feeder.checkedPut(10401000500384L, 969314784, -2136811544); // section 12 + feeder.checkedPut(10401000001944L, 657829272, 935194952); // section 12 + feeder.checkedPut(10400345749224L, 1186831289, -828214183); // section 12 + feeder.checkedPut(10400345749244L, 2095121916, -351234120); // section 12 + feeder.checkedPut(10400333128994L, 2095121916, -496909430); // section 12 + feeder.checkedPut(10400333197934L, 2095121916, 2147144926); // section 12 + feeder.checkedPut(10400333197944L, 2095121916, -2082366964); // section 12 + feeder.checkedPut(10400336947684L, 2095121916, -1404212288); // section 12 + feeder.checkedPut(10401000000594L, 657829272, 124369790); // section 12 + feeder.checkedPut(10400331896264L, 2095121916, -1028383492); // section 12 + feeder.checkedPut(10400332415044L, 2095121916, 1629436704); // section 12 + feeder.checkedPut(10400345749614L, 1186831289, 1027996827); // section 12 + feeder.checkedPut(10401000500424L, 969314784, -1871616544); // section 12 + feeder.checkedPut(10400336947694L, 2095121916, -1468802722); // section 12 + feeder.checkedPut(10410002672481L, 2154973, 1515288586); // section 12 + feeder.checkedPut(10410345749171L, 2154973, 2084791828); // section 12 + feeder.checkedPut(10400004960671L, 2154973, 1554754674); // section 12 + feeder.checkedPut(10410009983601L, 2154973, -2049707334); // section 12 + feeder.checkedPut(10410335811601L, 2154973, 1547385114); // section 12 + feeder.checkedPut(10410000005951L, 2154973, -1136117016); // section 12 + feeder.checkedPut(10400004938331L, 2154973, -1361373018); // section 12 + feeder.checkedPut(10410001490421L, 2154973, -818792874); // section 12 + feeder.checkedPut(10400001187131L, 2154973, 649763142); // section 12 + feeder.checkedPut(10410000409071L, 
2154973, -614460616); // section 12 + feeder.checkedPut(10410333717391L, 2154973, 1343531416); // section 12 + feeder.checkedPut(10410336680071L, 2154973, -914544144); // section 12 + feeder.checkedPut(10410002068511L, 2154973, -746995576); // section 12 + feeder.checkedPut(10410336207851L, 2154973, 863146156); // section 12 + feeder.checkedPut(10410002365251L, 2154973, 542724164); // section 12 + feeder.checkedPut(10400335812581L, 2154973, 2146284796); // section 12 + feeder.checkedPut(10410337345361L, 2154973, -384625318); // section 12 + feeder.checkedPut(10410000409091L, 2154973, -528258556); // section 12 + return feeder; + } + + // --- + @Test + public void testNonBlockingIdentityHashMapIteratorMultithreaded() throws InterruptedException, ExecutionException + { + TestKeyFeeder feeder = getTestKeyFeeder(); + final int itemCount = feeder.size(); + + // validate results + final NonBlockingIdentityHashMap items = feeder.getMapMultithreaded(); + assertEquals("size()", itemCount, items.size()); + + assertEquals("values().size()", itemCount, items.values().size()); + + assertEquals("entrySet().size()", itemCount, items.entrySet().size()); + + int iteratorCount = 0; + for (TestKey m : items.values()) + { + iteratorCount++; + } + // sometimes a different result comes back the second time + int iteratorCount2 = 0; + for (TestKey m : items.values()) + { + iteratorCount2++; + } + assertEquals("iterator counts differ", iteratorCount, iteratorCount2); + assertEquals("values().iterator() count", itemCount, iteratorCount); + } + + // Throw a ClassCastException if I see a tombstone during key-compares + private static class KeyBonk + { + final int _x; + + KeyBonk(int i) + { + _x = i; + } + + public int hashCode() + { + return (_x >> 2); + } + + public boolean equals(Object o) + { + return o != null && ((KeyBonk) o)._x // Throw CCE here + == this._x; + } + + + public String toString() + { + return "Bonk_" + Integer.toString(_x); + } + } + + // --- NBHMLFeeder --- + // Class to 
be called from another thread, to get concurrent installs into + // the table. + static private class NBHMLFeeder implements Callable + { + static private final Random _rand = new Random(System.currentTimeMillis()); + private final NonBlockingIdentityHashMap _map; + private final int _count; + private final CyclicBarrier _barrier; + private final long _offset; + + public NBHMLFeeder( + final NonBlockingIdentityHashMap map, + final int count, + final CyclicBarrier barrier, + final long offset) + { + _map = map; + _count = count; + _barrier = barrier; + _offset = offset; + } + + public Object call() throws Exception + { + _barrier.await(); // barrier, to force racing start + for (long j = 0; j < _count; j++) + { + _map.put( + j + _offset, + new TestKey(_rand.nextLong(), _rand.nextInt(), (short) _rand.nextInt(Short.MAX_VALUE))); + } + return null; + } + } + + // --- TestKey --- + // Funny key tests all sorts of things, has a pre-wired hashCode & equals. + static private final class TestKey + { + public final int _type; + public final long _id; + public final int _hash; + + public TestKey(final long id, final int type, int hash) + { + _id = id; + _type = type; + _hash = hash; + } + + public int hashCode() + { + return _hash; + } + + public boolean equals(Object object) + { + if (null == object) + { + return false; + } + if (object == this) + { + return true; + } + if (object.getClass() != this.getClass()) + { + return false; + } + final TestKey other = (TestKey) object; + return (this._type == other._type && this._id == other._id); + } + + public String toString() + { + return String.format("%s:%d,%d,%d", getClass().getSimpleName(), _id, _type, _hash); + } + } + + // --- + static private class TestKeyFeeder + { + private final Hashtable> _items = new Hashtable<>(); + private int _size = 0; + + public int size() + { + return _size; + } + + // Put items into the hashtable, sorted by 'type' into LinkedLists. 
+ public void checkedPut(final long id, final int type, final int hash) + { + _size++; + final TestKey item = new TestKey(id, type, hash); + if (!_items.containsKey(type)) + { + _items.put(type, new LinkedList<>()); + } + _items.get(type).add(item); + } + + public NonBlockingIdentityHashMap getMapMultithreaded() + throws InterruptedException, ExecutionException + { + final int threadCount = _items.keySet().size(); + final NonBlockingIdentityHashMap map = new NonBlockingIdentityHashMap<>(); + + // use a barrier to open the gate for all threads at once to avoid rolling start and no actual concurrency + final CyclicBarrier barrier = new CyclicBarrier(threadCount); + final ExecutorService ex = Executors.newFixedThreadPool(threadCount); + final CompletionService co = new ExecutorCompletionService<>(ex); + for (Integer type : _items.keySet()) + { + // A linked-list of things to insert + List items = _items.get(type); + TestKeyFeederThread feeder = new TestKeyFeederThread(items, map, barrier); + co.submit(feeder); + } + + // wait for all threads to return + int itemCount = 0; + for (int retCount = 0; retCount < threadCount; retCount++) + { + final Future result = co.take(); + itemCount += result.get(); + } + ex.shutdown(); + return map; + } + } + + // --- TestKeyFeederThread + static private class TestKeyFeederThread implements Callable + { + private final NonBlockingIdentityHashMap _map; + private final List _items; + private final CyclicBarrier _barrier; + + public TestKeyFeederThread( + final List items, + final NonBlockingIdentityHashMap map, + final CyclicBarrier barrier) + { + _map = map; + _items = items; + _barrier = barrier; + } + + public Integer call() throws Exception + { + _barrier.await(); + int count = 0; + for (TestKey item : _items) + { + if (_map.contains(item._id)) + { + System.err.printf("COLLISION DETECTED: %s exists\n", item.toString()); + } + final TestKey exists = _map.putIfAbsent(item._id, item); + if (exists == null) + { + count++; + } + else + { 
+ System.err.printf("COLLISION DETECTED: %s exists as %s\n", item.toString(), exists.toString()); + } + } + return count; + } + } + + // This test is a copy of the JCK test Hashtable2027, which is incorrect. + // The test requires a particular order of values to appear in the esa + // array - but this is not part of the spec. A different implementation + // might put the same values into the array but in a different order. + //public void testToArray() { + // NonBlockingIdentityHashMap ht = new NonBlockingIdentityHashMap(); + // + // ht.put("Nine", new Integer(9)); + // ht.put("Ten", new Integer(10)); + // ht.put("Ten1", new Integer(100)); + // + // Collection es = ht.values(); + // + // Object [] esa = es.toArray(); + // + // ht.remove("Ten1"); + // + // assertEquals( "size check", es.size(), 2 ); + // assertEquals( "iterator_order[0]", new Integer( 9), esa[0] ); + // assertEquals( "iterator_order[1]", new Integer(10), esa[1] ); + //} +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHML_Tester2.java b/netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHML_Tester2.java new file mode 100644 index 0000000..657e0f7 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHML_Tester2.java @@ -0,0 +1,663 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.maps.nbhm_test; + +import java.io.*; +import java.util.*; +import java.util.concurrent.*; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.jctools.maps.NonBlockingHashMapLong; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + +// Test NonBlockingHashMapLong via JUnit +public class NBHML_Tester2 +{ + + static private NonBlockingHashMapLong _nbhml; + + @BeforeClass + public static void setUp() + { + _nbhml = new NonBlockingHashMapLong<>(); + } + + @AfterClass + public static void tearDown() + { + _nbhml = null; + } + + // Test some basic stuff; add a few keys, remove a few keys + @Test + public void testBasic() + { + assertTrue(_nbhml.isEmpty()); + assertThat(_nbhml.put(1, "v1"), nullValue()); + checkSizes(1); + assertThat(_nbhml.putIfAbsent(2, "v2"), nullValue()); + checkSizes(2); + assertTrue(_nbhml.containsKey(2)); + assertThat(_nbhml.put(1, "v1a"), is("v1")); + assertThat(_nbhml.put(2, "v2a"), is("v2")); + checkSizes(2); + assertThat(_nbhml.putIfAbsent(2, "v2b"), is("v2a")); + assertThat(_nbhml.remove(1), is("v1a")); + assertFalse(_nbhml.containsKey(1)); + checkSizes(1); + assertThat(_nbhml.remove(1), nullValue()); + assertThat(_nbhml.remove(2), is("v2a")); + checkSizes(0); + assertThat(_nbhml.remove(2), nullValue()); + assertThat(_nbhml.remove("k3"), nullValue()); + assertTrue(_nbhml.isEmpty()); + + assertThat(_nbhml.put(0, "v0"), nullValue()); + assertTrue(_nbhml.containsKey(0)); + checkSizes(1); + assertThat(_nbhml.remove(0), is("v0")); + assertFalse(_nbhml.containsKey(0)); + checkSizes(0); + + assertThat(_nbhml.replace(0, "v0"), nullValue()); + assertFalse(_nbhml.containsKey(0)); + assertThat(_nbhml.put(0, "v0"), nullValue()); + assertEquals(_nbhml.replace(0, "v0a"), "v0"); + assertEquals(_nbhml.get(0), "v0a"); + assertThat(_nbhml.remove(0), is("v0a")); + assertFalse(_nbhml.containsKey(0)); + checkSizes(0); + + assertThat(_nbhml.replace(1, "v1"), 
nullValue()); + assertFalse(_nbhml.containsKey(1)); + assertThat(_nbhml.put(1, "v1"), nullValue()); + assertEquals(_nbhml.replace(1, "v1a"), "v1"); + assertEquals(_nbhml.get(1), "v1a"); + assertThat(_nbhml.remove(1), is("v1a")); + assertFalse(_nbhml.containsKey(1)); + checkSizes(0); + + // Simple insert of simple keys, with no reprobing on insert until the + // table gets full exactly. Then do a 'get' on the totally full table. + NonBlockingHashMapLong map = new NonBlockingHashMapLong(32); + for (int i = 1; i < 32; i++) + { + map.put(i, new Object()); + } + map.get(33); // this causes a NPE + } + + // Check all iterators for correct size counts + private void checkSizes(int expectedSize) + { + assertEquals("size()", _nbhml.size(), expectedSize); + Collection vals = _nbhml.values(); + checkSizes("values()", vals.size(), vals.iterator(), expectedSize); + Set keys = _nbhml.keySet(); + checkSizes("keySet()", keys.size(), keys.iterator(), expectedSize); + Set> ents = _nbhml.entrySet(); + checkSizes("entrySet()", ents.size(), ents.iterator(), expectedSize); + } + + // Check that the iterator iterates the correct number of times + private void checkSizes(String msg, int sz, Iterator it, int expectedSize) + { + assertEquals(msg, expectedSize, sz); + int result = 0; + while (it.hasNext()) + { + result++; + it.next(); + } + assertEquals(msg, expectedSize, result); + } + + + @Test + public void replaceMissingValue() { + NonBlockingHashMapLong map = new NonBlockingHashMapLong<>(); + assertNull(map.replace(1, 2)); + assertFalse(map.replace(1, 2, 3)); + } + + @Test + public void testIterationBig2() + { + final int CNT = 10000; + assertThat(_nbhml.size(), is(0)); + final String v = "v"; + for (int i = 0; i < CNT; i++) + { + _nbhml.put(i, v); + String s = _nbhml.get(i); + assertThat(s, is(v)); + } + assertThat(_nbhml.size(), is(CNT)); + _nbhml.clear(); + } + + + @Test + public void testIteration() + { + assertTrue(_nbhml.isEmpty()); + assertThat(_nbhml.put(1, "v1"), nullValue()); 
+ assertThat(_nbhml.put(2, "v2"), nullValue()); + + String str1 = ""; + for (Map.Entry e : _nbhml.entrySet()) + { + str1 += e.getKey(); + } + assertThat("found all entries", str1, anyOf(is("12"), is("21"))); + + String str2 = ""; + for (Long key : _nbhml.keySet()) + { + str2 += key; + } + assertThat("found all keys", str2, anyOf(is("12"), is("21"))); + + String str3 = ""; + for (String val : _nbhml.values()) + { + str3 += val; + } + assertThat("found all vals", str3, anyOf(is("v1v2"), is("v2v1"))); + + assertThat("toString works", _nbhml.toString(), anyOf(is("{1=v1, 2=v2}"), is("{2=v2, 1=v1}"))); + _nbhml.clear(); + } + + @Test + public void testSerial() + { + assertTrue(_nbhml.isEmpty()); + assertThat(_nbhml.put(0x12345678L, "v1"), nullValue()); + assertThat(_nbhml.put(0x87654321L, "v2"), nullValue()); + + // Serialize it out + try + { + FileOutputStream fos = new FileOutputStream("NBHML_test.txt"); + ObjectOutputStream out = new ObjectOutputStream(fos); + out.writeObject(_nbhml); + out.close(); + } + catch (IOException ex) + { + ex.printStackTrace(); + } + + // Read it back + try + { + File f = new File("NBHML_test.txt"); + FileInputStream fis = new FileInputStream(f); + ObjectInputStream in = new ObjectInputStream(fis); + NonBlockingHashMapLong nbhml = (NonBlockingHashMapLong) in.readObject(); + in.close(); + assertEquals(_nbhml.toString(), nbhml.toString()); + if (!f.delete()) + { + throw new IOException("delete failed"); + } + } + catch (IOException | ClassNotFoundException ex) + { + ex.printStackTrace(); + } + + } + + @Test + public void testIterationBig() + { + final int CNT = 10000; + assertThat(_nbhml.size(), is(0)); + for (int i = 0; i < CNT; i++) + { + _nbhml.put(i, "v" + i); + } + assertThat(_nbhml.size(), is(CNT)); + + int sz = 0; + int sum = 0; + for (long x : _nbhml.keySet()) + { + sz++; + sum += x; + assertTrue(x >= 0 && x <= (CNT - 1)); + } + assertThat("Found 10000 ints", sz, is(CNT)); + assertThat("Found all integers in list", sum, is(CNT * (CNT 
- 1) / 2)); + + assertThat("can remove 3", _nbhml.remove(3), is("v3")); + assertThat("can remove 4", _nbhml.remove(4), is("v4")); + sz = 0; + sum = 0; + for (long x : _nbhml.keySet()) + { + sz++; + sum += x; + assertTrue(x >= 0 && x <= (CNT - 1)); + String v = _nbhml.get(x); + assertThat("", v.charAt(0), is('v')); + assertThat("", x, is(Long.parseLong(v.substring(1)))); + } + assertThat("Found " + (CNT - 2) + " ints", sz, is(CNT - 2)); + assertThat("Found all integers in list", sum, is(CNT * (CNT - 1) / 2 - (3 + 4))); + _nbhml.clear(); + } + + // Do some simple concurrent testing + @Test + public void testConcurrentSimple() throws InterruptedException + { + final NonBlockingHashMapLong nbhml = new NonBlockingHashMapLong<>(); + + // In 2 threads, add & remove even & odd elements concurrently + final int num_thrds = 2; + Thread ts[] = new Thread[num_thrds]; + for (int i = 1; i < num_thrds; i++) + { + final int x = i; + ts[i] = new Thread() + { + public void run() + { + work_helper(nbhml, x, num_thrds); + } + }; + } + for (int i = 1; i < num_thrds; i++) + { + ts[i].start(); + } + work_helper(nbhml, 0, num_thrds); + for (int i = 1; i < num_thrds; i++) + { + ts[i].join(); + } + + // In the end, all members should be removed + StringBuilder buf = new StringBuilder(); + buf.append("Should be emptyset but has these elements: {"); + boolean found = false; + for (long x : nbhml.keySet()) + { + buf.append(" ").append(x); + found = true; + } + if (found) + { + System.out.println(buf + " }"); + } + assertThat("concurrent size=0", nbhml.size(), is(0)); + assertThat("keySet size==0", nbhml.keySet().size(), is(0)); + } + + void work_helper(NonBlockingHashMapLong nbhml, int d, int num_thrds) + { + String thrd = "T" + d; + final int ITERS = 20000; + for (int j = 0; j < 10; j++) + { + //long start = System.nanoTime(); + for (int i = d; i < ITERS; i += num_thrds) + { + assertThat("key " + i + " not in there, so putIfAbsent must work", + nbhml.putIfAbsent((long) i, thrd), is((String) 
null)); + } + for (int i = d; i < ITERS; i += num_thrds) + { + assertTrue(nbhml.remove((long) i, thrd)); + } + //double delta_nanos = System.nanoTime()-start; + //double delta_secs = delta_nanos/1000000000.0; + //double ops = ITERS*2; + //System.out.println("Thrd"+thrd+" "+(ops/delta_secs)+" ops/sec size="+nbhml.size()); + } + } + + + // --- Customer Test Case 1 ------------------------------------------------ + @Test + public final void testNonBlockingHashMapSize() + { + NonBlockingHashMapLong items = new NonBlockingHashMapLong<>(); + items.put(100L, "100"); + items.put(101L, "101"); + + assertEquals("keySet().size()", 2, items.keySet().size()); + assertTrue("keySet().contains(100)", items.keySet().contains(100L)); + assertTrue("keySet().contains(101)", items.keySet().contains(101L)); + + assertEquals("values().size()", 2, items.values().size()); + assertTrue("values().contains(\"100\")", items.values().contains("100")); + assertTrue("values().contains(\"101\")", items.values().contains("101")); + + assertEquals("entrySet().size()", 2, items.entrySet().size()); + boolean found100 = false; + boolean found101 = false; + for (Map.Entry entry : items.entrySet()) + { + if (entry.getKey().equals(100L)) + { + assertEquals("entry[100].getValue()==\"100\"", "100", entry.getValue()); + found100 = true; + } + else if (entry.getKey().equals(101L)) + { + assertEquals("entry[101].getValue()==\"101\"", "101", entry.getValue()); + found101 = true; + } + } + assertTrue("entrySet().contains([100])", found100); + assertTrue("entrySet().contains([101])", found101); + } + + // --- Customer Test Case 2 ------------------------------------------------ + // Concurrent insertion & then iterator test. 
+ @Test + public void testNonBlockingHashMapIterator() throws InterruptedException + { + final int ITEM_COUNT1 = 1000; + final int THREAD_COUNT = 5; + final int PER_CNT = ITEM_COUNT1 / THREAD_COUNT; + final int ITEM_COUNT = PER_CNT * THREAD_COUNT; // fix roundoff for odd thread counts + + NonBlockingHashMapLong nbhml = new NonBlockingHashMapLong<>(); + // use a barrier to open the gate for all threads at once to avoid rolling + // start and no actual concurrency + final CyclicBarrier barrier = new CyclicBarrier(THREAD_COUNT); + final ExecutorService ex = Executors.newFixedThreadPool(THREAD_COUNT); + final CompletionService co = new ExecutorCompletionService<>(ex); + for (int i = 0; i < THREAD_COUNT; i++) + { + co.submit(new NBHMLFeeder(nbhml, PER_CNT, barrier, i * PER_CNT)); + } + for (int retCount = 0; retCount < THREAD_COUNT; retCount++) + { + co.take(); + } + ex.shutdown(); + + assertEquals("values().size()", ITEM_COUNT, nbhml.values().size()); + assertEquals("entrySet().size()", ITEM_COUNT, nbhml.entrySet().size()); + int itemCount = 0; + for (TestKey K : nbhml.values()) + { + itemCount++; + } + assertEquals("values().iterator() count", ITEM_COUNT, itemCount); + } + + // --- Customer Test Case 3 ------------------------------------------------ + private TestKeyFeeder getTestKeyFeeder() + { + final TestKeyFeeder feeder = new TestKeyFeeder(); + feeder.checkedPut(10401000001844L, 657829272, 680293140); // section 12 + feeder.checkedPut(10401000000614L, 657829272, 401326994); // section 12 + feeder.checkedPut(10400345749304L, 2095121916, -9852212); // section 12 + feeder.checkedPut(10401000002204L, 657829272, 14438460); // section 12 + feeder.checkedPut(10400345749234L, 1186831289, -894006017); // section 12 + feeder.checkedPut(10401000500234L, 969314784, -2112018706); // section 12 + feeder.checkedPut(10401000000284L, 657829272, 521425852); // section 12 + feeder.checkedPut(10401000002134L, 657829272, 208406306); // section 12 + feeder.checkedPut(10400345749254L, 
2095121916, -341939818); // section 12 + feeder.checkedPut(10401000500384L, 969314784, -2136811544); // section 12 + feeder.checkedPut(10401000001944L, 657829272, 935194952); // section 12 + feeder.checkedPut(10400345749224L, 1186831289, -828214183); // section 12 + feeder.checkedPut(10400345749244L, 2095121916, -351234120); // section 12 + feeder.checkedPut(10400333128994L, 2095121916, -496909430); // section 12 + feeder.checkedPut(10400333197934L, 2095121916, 2147144926); // section 12 + feeder.checkedPut(10400333197944L, 2095121916, -2082366964); // section 12 + feeder.checkedPut(10400336947684L, 2095121916, -1404212288); // section 12 + feeder.checkedPut(10401000000594L, 657829272, 124369790); // section 12 + feeder.checkedPut(10400331896264L, 2095121916, -1028383492); // section 12 + feeder.checkedPut(10400332415044L, 2095121916, 1629436704); // section 12 + feeder.checkedPut(10400345749614L, 1186831289, 1027996827); // section 12 + feeder.checkedPut(10401000500424L, 969314784, -1871616544); // section 12 + feeder.checkedPut(10400336947694L, 2095121916, -1468802722); // section 12 + feeder.checkedPut(10410002672481L, 2154973, 1515288586); // section 12 + feeder.checkedPut(10410345749171L, 2154973, 2084791828); // section 12 + feeder.checkedPut(10400004960671L, 2154973, 1554754674); // section 12 + feeder.checkedPut(10410009983601L, 2154973, -2049707334); // section 12 + feeder.checkedPut(10410335811601L, 2154973, 1547385114); // section 12 + feeder.checkedPut(10410000005951L, 2154973, -1136117016); // section 12 + feeder.checkedPut(10400004938331L, 2154973, -1361373018); // section 12 + feeder.checkedPut(10410001490421L, 2154973, -818792874); // section 12 + feeder.checkedPut(10400001187131L, 2154973, 649763142); // section 12 + feeder.checkedPut(10410000409071L, 2154973, -614460616); // section 12 + feeder.checkedPut(10410333717391L, 2154973, 1343531416); // section 12 + feeder.checkedPut(10410336680071L, 2154973, -914544144); // section 12 + 
feeder.checkedPut(10410002068511L, 2154973, -746995576); // section 12 + feeder.checkedPut(10410336207851L, 2154973, 863146156); // section 12 + feeder.checkedPut(10410002365251L, 2154973, 542724164); // section 12 + feeder.checkedPut(10400335812581L, 2154973, 2146284796); // section 12 + feeder.checkedPut(10410337345361L, 2154973, -384625318); // section 12 + feeder.checkedPut(10410000409091L, 2154973, -528258556); // section 12 + return feeder; + } + + // --- + @Test + public void testNonBlockingHashMapIteratorMultithreaded() throws InterruptedException, ExecutionException + { + TestKeyFeeder feeder = getTestKeyFeeder(); + final int itemCount = feeder.size(); + + // validate results + final NonBlockingHashMapLong items = feeder.getMapMultithreaded(); + assertEquals("size()", itemCount, items.size()); + + assertEquals("values().size()", itemCount, items.values().size()); + + assertEquals("entrySet().size()", itemCount, items.entrySet().size()); + + int iteratorCount = 0; + for (TestKey m : items.values()) + { + iteratorCount++; + } + // sometimes a different result comes back the second time + int iteratorCount2 = 0; + for (TestKey m2 : items.values()) + { + iteratorCount2++; + } + assertEquals("iterator counts differ", iteratorCount, iteratorCount2); + assertEquals("values().iterator() count", itemCount, iteratorCount); + } + + // --- NBHMLFeeder --- + // Class to be called from another thread, to get concurrent installs into + // the table. 
+ static private class NBHMLFeeder implements Callable + { + static private final Random _rand = new Random(System.currentTimeMillis()); + private final NonBlockingHashMapLong _map; + private final int _count; + private final CyclicBarrier _barrier; + private final long _offset; + + public NBHMLFeeder( + final NonBlockingHashMapLong map, + final int count, + final CyclicBarrier barrier, + final long offset) + { + _map = map; + _count = count; + _barrier = barrier; + _offset = offset; + } + + public Object call() throws Exception + { + _barrier.await(); // barrier, to force racing start + for (long j = 0; j < _count; j++) + { + _map.put( + j + _offset, + new TestKey(_rand.nextLong(), _rand.nextInt(), (short) _rand.nextInt(Short.MAX_VALUE))); + } + return null; + } + } + + // --- TestKey --- + // Funny key tests all sorts of things, has a pre-wired hashCode & equals. + static private final class TestKey + { + public final int _type; + public final long _id; + public final int _hash; + + public TestKey(final long id, final int type, int hash) + { + _id = id; + _type = type; + _hash = hash; + } + + public int hashCode() + { + return _hash; + } + + public boolean equals(Object object) + { + if (null == object) + { + return false; + } + if (object == this) + { + return true; + } + if (object.getClass() != this.getClass()) + { + return false; + } + final TestKey other = (TestKey) object; + return (this._type == other._type && this._id == other._id); + } + + public String toString() + { + return String.format("%s:%d,%d,%d", getClass().getSimpleName(), _id, _type, _hash); + } + } + + // --- + static private class TestKeyFeeder + { + private final Hashtable> _items = new Hashtable<>(); + private int _size = 0; + + public int size() + { + return _size; + } + + // Put items into the hashtable, sorted by 'type' into LinkedLists. 
+ public void checkedPut(final long id, final int type, final int hash) + { + _size++; + final TestKey item = new TestKey(id, type, hash); + if (!_items.containsKey(type)) + { + _items.put(type, new LinkedList<>()); + } + _items.get(type).add(item); + } + + public NonBlockingHashMapLong getMapMultithreaded() throws InterruptedException, ExecutionException + { + final int threadCount = _items.keySet().size(); + final NonBlockingHashMapLong map = new NonBlockingHashMapLong<>(); + + // use a barrier to open the gate for all threads at once to avoid rolling start and no actual concurrency + final CyclicBarrier barrier = new CyclicBarrier(threadCount); + final ExecutorService ex = Executors.newFixedThreadPool(threadCount); + final CompletionService co = new ExecutorCompletionService<>(ex); + for (Integer type : _items.keySet()) + { + // A linked-list of things to insert + List items = _items.get(type); + TestKeyFeederThread feeder = new TestKeyFeederThread(items, map, barrier); + co.submit(feeder); + } + + // wait for all threads to return + int itemCount = 0; + for (int retCount = 0; retCount < threadCount; retCount++) + { + final Future result = co.take(); + itemCount += result.get(); + } + ex.shutdown(); + return map; + } + } + + // --- TestKeyFeederThread + static private class TestKeyFeederThread implements Callable + { + private final NonBlockingHashMapLong _map; + private final List _items; + private final CyclicBarrier _barrier; + + public TestKeyFeederThread( + final List items, + final NonBlockingHashMapLong map, + final CyclicBarrier barrier) + { + _map = map; + _items = items; + _barrier = barrier; + } + + public Integer call() throws Exception + { + _barrier.await(); + int count = 0; + for (TestKey item : _items) + { + if (_map.contains(item._id)) + { + System.err.printf("COLLISION DETECTED: %s exists\n", item.toString()); + } + final TestKey exists = _map.putIfAbsent(item._id, item); + if (exists == null) + { + count++; + } + else + { + 
System.err.printf("COLLISION DETECTED: %s exists as %s\n", item.toString(), exists.toString()); + } + } + return count; + } + } + +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHM_Tester2.java b/netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHM_Tester2.java new file mode 100644 index 0000000..922d60f --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/nbhm_test/NBHM_Tester2.java @@ -0,0 +1,728 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.maps.nbhm_test; + +import java.io.*; +import java.util.*; +import java.util.concurrent.*; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.jctools.maps.NonBlockingHashMap; + +import static org.hamcrest.CoreMatchers.*; +import static org.junit.Assert.*; + + +// Test NonBlockingHashMap via JUnit +public class NBHM_Tester2 +{ + + static private NonBlockingHashMap _nbhm; + + @BeforeClass + public static void setUp() + { + _nbhm = new NonBlockingHashMap<>(); + } + + @AfterClass + public static void tearDown() + { + _nbhm = null; + } + + // Test some basic stuff; add a few keys, remove a few keys + @Test + public void testBasic() + { + assertTrue(_nbhm.isEmpty()); + assertThat(_nbhm.putIfAbsent("k1", "v1"), nullValue()); + checkSizes(1); + assertThat(_nbhm.putIfAbsent("k2", "v2"), nullValue()); + checkSizes(2); + assertTrue(_nbhm.containsKey("k2")); + assertThat(_nbhm.put("k1", "v1a"), is("v1")); + assertThat(_nbhm.put("k2", "v2a"), is("v2")); + checkSizes(2); + assertThat(_nbhm.putIfAbsent("k2", "v2b"), is("v2a")); + assertThat(_nbhm.remove("k1"), is("v1a")); + assertFalse(_nbhm.containsKey("k1")); + checkSizes(1); + assertThat(_nbhm.remove("k1"), nullValue()); + assertThat(_nbhm.remove("k2"), is("v2a")); + checkSizes(0); + assertThat(_nbhm.remove("k2"), nullValue()); + assertThat(_nbhm.remove("k3"), nullValue()); + assertTrue(_nbhm.isEmpty()); + + assertThat(_nbhm.put("k0", "v0"), nullValue()); + assertTrue(_nbhm.containsKey("k0")); + checkSizes(1); + assertThat(_nbhm.remove("k0"), is("v0")); + assertFalse(_nbhm.containsKey("k0")); + checkSizes(0); + + assertThat(_nbhm.replace("k0", "v0"), nullValue()); + assertFalse(_nbhm.containsKey("k0")); + assertThat(_nbhm.put("k0", "v0"), nullValue()); + assertEquals(_nbhm.replace("k0", "v0a"), "v0"); + assertEquals(_nbhm.get("k0"), "v0a"); + assertThat(_nbhm.remove("k0"), is("v0a")); + assertFalse(_nbhm.containsKey("k0")); + checkSizes(0); + + 
assertThat(_nbhm.replace("k1", "v1"), nullValue()); + assertFalse(_nbhm.containsKey("k1")); + assertThat(_nbhm.put("k1", "v1"), nullValue()); + assertEquals(_nbhm.replace("k1", "v1a"), "v1"); + assertEquals(_nbhm.get("k1"), "v1a"); + assertThat(_nbhm.remove("k1"), is("v1a")); + assertFalse(_nbhm.containsKey("k1")); + checkSizes(0); + + // Insert & Remove KeyBonks until the table resizes and we start + // finding Tombstone keys- and KeyBonk's equals-call with throw a + // ClassCastException if it sees a non-KeyBonk. + NonBlockingHashMap dumb = new NonBlockingHashMap<>(); + for (int i = 0; i < 10000; i++) + { + final KeyBonk happy1 = new KeyBonk(i); + assertThat(dumb.put(happy1, "and"), nullValue()); + if ((i & 1) == 0) + { + dumb.remove(happy1); + } + final KeyBonk happy2 = new KeyBonk(i); // 'equals' but not '==' + dumb.get(happy2); + } + + // Simple insert of simple keys, with no reprobing on insert until the + // table gets full exactly. Then do a 'get' on the totally full table. + NonBlockingHashMap map = new NonBlockingHashMap<>(32); + for (int i = 1; i < 32; i++) + { + map.put(i, new Object()); + } + map.get(33); // this returns null, but tested a crash edge case for expansion + } + + // Check all iterators for correct size counts + private void checkSizes(int expectedSize) + { + assertEquals("size()", _nbhm.size(), expectedSize); + Collection vals = _nbhm.values(); + checkSizes("values()", vals.size(), vals.iterator(), expectedSize); + Set keys = _nbhm.keySet(); + checkSizes("keySet()", keys.size(), keys.iterator(), expectedSize); + Set> ents = _nbhm.entrySet(); + checkSizes("entrySet()", ents.size(), ents.iterator(), expectedSize); + } + + // Check that the iterator iterates the correct number of times + private void checkSizes(String msg, int sz, Iterator it, int expectedSize) + { + assertEquals(msg, expectedSize, sz); + int result = 0; + while (it.hasNext()) + { + result++; + it.next(); + } + assertEquals(msg, expectedSize, result); + } + + @Test + public 
void testIteration() + { + assertTrue(_nbhm.isEmpty()); + assertThat(_nbhm.put("k1", "v1"), nullValue()); + assertThat(_nbhm.put("k2", "v2"), nullValue()); + + String str1 = ""; + for (Map.Entry e : _nbhm.entrySet()) + { + str1 += e.getKey(); + } + assertThat("found all entries", str1, anyOf(is("k1k2"), is("k2k1"))); + + String str2 = ""; + for (String key : _nbhm.keySet()) + { + str2 += key; + } + assertThat("found all keys", str2, anyOf(is("k1k2"), is("k2k1"))); + + String str3 = ""; + for (String val : _nbhm.values()) + { + str3 += val; + } + assertThat("found all vals", str3, anyOf(is("v1v2"), is("v2v1"))); + + assertThat("toString works", _nbhm.toString(), anyOf(is("{k1=v1, k2=v2}"), is("{k2=v2, k1=v1}"))); + _nbhm.clear(); + } + + @Test + public void testSerial() + { + assertTrue(_nbhm.isEmpty()); + assertThat(_nbhm.put("k1", "v1"), nullValue()); + assertThat(_nbhm.put("k2", "v2"), nullValue()); + + // Serialize it out + try + { + FileOutputStream fos = new FileOutputStream("NBHM_test.txt"); + ObjectOutputStream out = new ObjectOutputStream(fos); + out.writeObject(_nbhm); + out.close(); + } + catch (IOException ex) + { + ex.printStackTrace(); + } + + // Read it back + try + { + File f = new File("NBHM_test.txt"); + FileInputStream fis = new FileInputStream(f); + ObjectInputStream in = new ObjectInputStream(fis); + NonBlockingHashMap nbhm = (NonBlockingHashMap) in.readObject(); + in.close(); + assertEquals(_nbhm.toString(), nbhm.toString()); + if (!f.delete()) + { + throw new IOException("delete failed"); + } + } + catch (IOException | ClassNotFoundException ex) + { + ex.printStackTrace(); + } + } + + @Test + public void testIterationBig2() + { + final int CNT = 10000; + NonBlockingHashMap nbhm = new NonBlockingHashMap<>(); + final String v = "v"; + for (int i = 0; i < CNT; i++) + { + final Integer z = i; + String s0 = nbhm.get(z); + assertThat(s0, nullValue()); + nbhm.put(z, v); + String s1 = nbhm.get(z); + assertThat(s1, is(v)); + } + assertThat(nbhm.size(), 
is(CNT)); + } + + @Test + public void testIterationBig() + { + final int CNT = 10000; + assertThat(_nbhm.size(), is(0)); + for (int i = 0; i < CNT; i++) + { + _nbhm.put("k" + i, "v" + i); + } + assertThat(_nbhm.size(), is(CNT)); + + int sz = 0; + int sum = 0; + for (String s : _nbhm.keySet()) + { + sz++; + assertThat("", s.charAt(0), is('k')); + int x = Integer.parseInt(s.substring(1)); + sum += x; + assertTrue(x >= 0 && x <= (CNT - 1)); + } + assertThat("Found 10000 ints", sz, is(CNT)); + assertThat("Found all integers in list", sum, is(CNT * (CNT - 1) / 2)); + + assertThat("can remove 3", _nbhm.remove("k3"), is("v3")); + assertThat("can remove 4", _nbhm.remove("k4"), is("v4")); + sz = 0; + sum = 0; + for (String s : _nbhm.keySet()) + { + sz++; + assertThat("", s.charAt(0), is('k')); + int x = Integer.parseInt(s.substring(1)); + sum += x; + assertTrue(x >= 0 && x <= (CNT - 1)); + String v = _nbhm.get(s); + assertThat("", v.charAt(0), is('v')); + assertThat("", s.substring(1), is(v.substring(1))); + } + assertThat("Found " + (CNT - 2) + " ints", sz, is(CNT - 2)); + assertThat("Found all integers in list", sum, is(CNT * (CNT - 1) / 2 - (3 + 4))); + _nbhm.clear(); + } + + // Do some simple concurrent testing + @Test + public void testConcurrentSimple() throws InterruptedException + { + final NonBlockingHashMap nbhm = new NonBlockingHashMap<>(); + + // In 2 threads, add & remove even & odd elements concurrently + Thread t1 = new Thread() + { + public void run() + { + work_helper(nbhm, "T1", 1); + } + }; + t1.start(); + work_helper(nbhm, "T0", 0); + t1.join(); + + // In the end, all members should be removed + StringBuilder buf = new StringBuilder(); + buf.append("Should be emptyset but has these elements: {"); + boolean found = false; + for (String x : nbhm.keySet()) + { + buf.append(" ").append(x); + found = true; + } + if (found) + { + System.out.println(buf + " }"); + } + assertThat("concurrent size=0", nbhm.size(), is(0)); + assertThat("keyset size=0", 
nbhm.keySet().size(), is(0)); + } + + void work_helper(NonBlockingHashMap nbhm, String thrd, int d) + { + final int ITERS = 20000; + for (int j = 0; j < 10; j++) + { + //long start = System.nanoTime(); + for (int i = d; i < ITERS; i += 2) + { + assertThat("this key not in there, so putIfAbsent must work", + nbhm.putIfAbsent("k" + i, thrd), is((String) null)); + } + for (int i = d; i < ITERS; i += 2) + { + assertTrue(nbhm.remove("k" + i, thrd)); + } + //double delta_nanos = System.nanoTime()-start; + //double delta_secs = delta_nanos/1000000000.0; + //double ops = ITERS*2; + //System.out.println("Thrd"+thrd+" "+(ops/delta_secs)+" ops/sec size="+nbhm.size()); + } + } + + @Test + public final void testNonBlockingHashMapSize() + { + NonBlockingHashMap items = new NonBlockingHashMap<>(); + items.put(100L, "100"); + items.put(101L, "101"); + + assertEquals("keySet().size()", 2, items.keySet().size()); + assertTrue("keySet().contains(100)", items.keySet().contains(100L)); + assertTrue("keySet().contains(101)", items.keySet().contains(101L)); + + assertEquals("values().size()", 2, items.values().size()); + assertTrue("values().contains(\"100\")", items.values().contains("100")); + assertTrue("values().contains(\"101\")", items.values().contains("101")); + + assertEquals("entrySet().size()", 2, items.entrySet().size()); + boolean found100 = false; + boolean found101 = false; + for (Map.Entry entry : items.entrySet()) + { + if (entry.getKey().equals(100L)) + { + assertEquals("entry[100].getValue()==\"100\"", "100", entry.getValue()); + found100 = true; + } + else if (entry.getKey().equals(101L)) + { + assertEquals("entry[101].getValue()==\"101\"", "101", entry.getValue()); + found101 = true; + } + } + assertTrue("entrySet().contains([100])", found100); + assertTrue("entrySet().contains([101])", found101); + } + + // Concurrent insertion & then iterator test. 
+ @Test + public void testNonBlockingHashMapIterator() throws InterruptedException + { + final int ITEM_COUNT1 = 1000; + final int THREAD_COUNT = 5; + final int PER_CNT = ITEM_COUNT1 / THREAD_COUNT; + final int ITEM_COUNT = PER_CNT * THREAD_COUNT; // fix roundoff for odd thread counts + + NonBlockingHashMap nbhml = new NonBlockingHashMap<>(); + // use a barrier to open the gate for all threads at once to avoid rolling + // start and no actual concurrency + final CyclicBarrier barrier = new CyclicBarrier(THREAD_COUNT); + final ExecutorService ex = Executors.newFixedThreadPool(THREAD_COUNT); + final CompletionService co = new ExecutorCompletionService<>(ex); + for (int i = 0; i < THREAD_COUNT; i++) + { + co.submit(new NBHMLFeeder(nbhml, PER_CNT, barrier, i * PER_CNT)); + } + for (int retCount = 0; retCount < THREAD_COUNT; retCount++) + { + co.take(); + } + ex.shutdown(); + + assertEquals("values().size()", ITEM_COUNT, nbhml.values().size()); + assertEquals("entrySet().size()", ITEM_COUNT, nbhml.entrySet().size()); + int itemCount = 0; + for (TestKey K : nbhml.values()) + { + itemCount++; + } + assertEquals("values().iterator() count", ITEM_COUNT, itemCount); + } + + // --- Customer Test Case 3 ------------------------------------------------ + private TestKeyFeeder getTestKeyFeeder() + { + final TestKeyFeeder feeder = new TestKeyFeeder(); + feeder.checkedPut(10401000001844L, 657829272, 680293140); // section 12 + feeder.checkedPut(10401000000614L, 657829272, 401326994); // section 12 + feeder.checkedPut(10400345749304L, 2095121916, -9852212); // section 12 + feeder.checkedPut(10401000002204L, 657829272, 14438460); // section 12 + feeder.checkedPut(10400345749234L, 1186831289, -894006017); // section 12 + feeder.checkedPut(10401000500234L, 969314784, -2112018706); // section 12 + feeder.checkedPut(10401000000284L, 657829272, 521425852); // section 12 + feeder.checkedPut(10401000002134L, 657829272, 208406306); // section 12 + feeder.checkedPut(10400345749254L, 
2095121916, -341939818); // section 12 + feeder.checkedPut(10401000500384L, 969314784, -2136811544); // section 12 + feeder.checkedPut(10401000001944L, 657829272, 935194952); // section 12 + feeder.checkedPut(10400345749224L, 1186831289, -828214183); // section 12 + feeder.checkedPut(10400345749244L, 2095121916, -351234120); // section 12 + feeder.checkedPut(10400333128994L, 2095121916, -496909430); // section 12 + feeder.checkedPut(10400333197934L, 2095121916, 2147144926); // section 12 + feeder.checkedPut(10400333197944L, 2095121916, -2082366964); // section 12 + feeder.checkedPut(10400336947684L, 2095121916, -1404212288); // section 12 + feeder.checkedPut(10401000000594L, 657829272, 124369790); // section 12 + feeder.checkedPut(10400331896264L, 2095121916, -1028383492); // section 12 + feeder.checkedPut(10400332415044L, 2095121916, 1629436704); // section 12 + feeder.checkedPut(10400345749614L, 1186831289, 1027996827); // section 12 + feeder.checkedPut(10401000500424L, 969314784, -1871616544); // section 12 + feeder.checkedPut(10400336947694L, 2095121916, -1468802722); // section 12 + feeder.checkedPut(10410002672481L, 2154973, 1515288586); // section 12 + feeder.checkedPut(10410345749171L, 2154973, 2084791828); // section 12 + feeder.checkedPut(10400004960671L, 2154973, 1554754674); // section 12 + feeder.checkedPut(10410009983601L, 2154973, -2049707334); // section 12 + feeder.checkedPut(10410335811601L, 2154973, 1547385114); // section 12 + feeder.checkedPut(10410000005951L, 2154973, -1136117016); // section 12 + feeder.checkedPut(10400004938331L, 2154973, -1361373018); // section 12 + feeder.checkedPut(10410001490421L, 2154973, -818792874); // section 12 + feeder.checkedPut(10400001187131L, 2154973, 649763142); // section 12 + feeder.checkedPut(10410000409071L, 2154973, -614460616); // section 12 + feeder.checkedPut(10410333717391L, 2154973, 1343531416); // section 12 + feeder.checkedPut(10410336680071L, 2154973, -914544144); // section 12 + 
feeder.checkedPut(10410002068511L, 2154973, -746995576); // section 12 + feeder.checkedPut(10410336207851L, 2154973, 863146156); // section 12 + feeder.checkedPut(10410002365251L, 2154973, 542724164); // section 12 + feeder.checkedPut(10400335812581L, 2154973, 2146284796); // section 12 + feeder.checkedPut(10410337345361L, 2154973, -384625318); // section 12 + feeder.checkedPut(10410000409091L, 2154973, -528258556); // section 12 + return feeder; + } + + // --- + @Test + public void testNonBlockingHashMapIteratorMultithreaded() throws InterruptedException, ExecutionException + { + TestKeyFeeder feeder = getTestKeyFeeder(); + final int itemCount = feeder.size(); + + // validate results + final NonBlockingHashMap items = feeder.getMapMultithreaded(); + assertEquals("size()", itemCount, items.size()); + + assertEquals("values().size()", itemCount, items.values().size()); + + assertEquals("entrySet().size()", itemCount, items.entrySet().size()); + + int iteratorCount = 0; + for (TestKey m : items.values()) + { + iteratorCount++; + } + // sometimes a different result comes back the second time + int iteratorCount2 = 0; + for (TestKey m2 : items.values()) + { + iteratorCount2++; + } + assertEquals("iterator counts differ", iteratorCount, iteratorCount2); + assertEquals("values().iterator() count", itemCount, iteratorCount); + } + + // --- Tests on equality of values + @Test + public void replaceResultIsBasedOnEquality() { + NonBlockingHashMap map = new NonBlockingHashMap<>(); + Integer initialValue = new Integer(10); + map.put(1, initialValue); + assertTrue(map.replace(1, initialValue, 20)); + assertTrue(map.replace(1, new Integer(20), 30)); + } + + @Test + public void removeResultIsBasedOnEquality() { + NonBlockingHashMap map = new NonBlockingHashMap<>(); + Integer initialValue = new Integer(10); + map.put(1, initialValue); + assertTrue(map.remove(1, initialValue)); + map.put(1, initialValue); + assertTrue(map.remove(1, new Integer(10))); + } + + // Throw a 
ClassCastException if I see a tombstone during key-compares + private static class KeyBonk + { + final int _x; + + KeyBonk(int i) + { + _x = i; + } + + public int hashCode() + { + return (_x >> 2); + } public boolean equals(Object o) + { + return o != null && ((KeyBonk) o)._x // Throw CCE here + == this._x; + } + + + + public String toString() + { + return "Bonk_" + Integer.toString(_x); + } + } + + // --- NBHMLFeeder --- + // Class to be called from another thread, to get concurrent installs into + // the table. + static private class NBHMLFeeder implements Callable + { + static private final Random _rand = new Random(System.currentTimeMillis()); + private final NonBlockingHashMap _map; + private final int _count; + private final CyclicBarrier _barrier; + private final long _offset; + + public NBHMLFeeder( + final NonBlockingHashMap map, + final int count, + final CyclicBarrier barrier, + final long offset) + { + _map = map; + _count = count; + _barrier = barrier; + _offset = offset; + } + + public Object call() throws Exception + { + _barrier.await(); // barrier, to force racing start + for (long j = 0; j < _count; j++) + { + _map.put( + j + _offset, + new TestKey(_rand.nextLong(), _rand.nextInt(), (short) _rand.nextInt(Short.MAX_VALUE))); + } + return null; + } + } + + // --- TestKey --- + // Funny key tests all sorts of things, has a pre-wired hashCode & equals. 
+ static private final class TestKey + { + public final int _type; + public final long _id; + public final int _hash; + + public TestKey(final long id, final int type, int hash) + { + _id = id; + _type = type; + _hash = hash; + } + + public int hashCode() + { + return _hash; + } + + public boolean equals(Object object) + { + if (null == object) + { + return false; + } + if (object == this) + { + return true; + } + if (object.getClass() != this.getClass()) + { + return false; + } + final TestKey other = (TestKey) object; + return (this._type == other._type && this._id == other._id); + } + + public String toString() + { + return String.format("%s:%d,%d,%d", getClass().getSimpleName(), _id, _type, _hash); + } + } + + // --- + static private class TestKeyFeeder + { + private final Hashtable> _items = new Hashtable<>(); + private int _size = 0; + + public int size() + { + return _size; + } + + // Put items into the hashtable, sorted by 'type' into LinkedLists. + public void checkedPut(final long id, final int type, final int hash) + { + _size++; + final TestKey item = new TestKey(id, type, hash); + if (!_items.containsKey(type)) + { + _items.put(type, new LinkedList<>()); + } + _items.get(type).add(item); + } + + public NonBlockingHashMap getMapMultithreaded() throws InterruptedException, ExecutionException + { + final int threadCount = _items.keySet().size(); + final NonBlockingHashMap map = new NonBlockingHashMap<>(); + + // use a barrier to open the gate for all threads at once to avoid rolling start and no actual concurrency + final CyclicBarrier barrier = new CyclicBarrier(threadCount); + final ExecutorService ex = Executors.newFixedThreadPool(threadCount); + final CompletionService co = new ExecutorCompletionService<>(ex); + for (Integer type : _items.keySet()) + { + // A linked-list of things to insert + List items = _items.get(type); + TestKeyFeederThread feeder = new TestKeyFeederThread(items, map, barrier); + co.submit(feeder); + } + + // wait for all threads 
to return + int itemCount = 0; + for (int retCount = 0; retCount < threadCount; retCount++) + { + final Future result = co.take(); + itemCount += result.get(); + } + ex.shutdown(); + return map; + } + } + + // --- TestKeyFeederThread + static private class TestKeyFeederThread implements Callable + { + private final NonBlockingHashMap _map; + private final List _items; + private final CyclicBarrier _barrier; + + public TestKeyFeederThread( + final List items, + final NonBlockingHashMap map, + final CyclicBarrier barrier) + { + _map = map; + _items = items; + _barrier = barrier; + } + + public Integer call() throws Exception + { + _barrier.await(); + int count = 0; + for (TestKey item : _items) + { + if (_map.contains(item._id)) + { + System.err.printf("COLLISION DETECTED: %s exists\n", item.toString()); + } + final TestKey exists = _map.putIfAbsent(item._id, item); + if (exists == null) + { + count++; + } + else + { + System.err.printf("COLLISION DETECTED: %s exists as %s\n", item.toString(), exists.toString()); + } + } + return count; + } + } + + // This test is a copy of the JCK test Hashtable2027, which is incorrect. + // The test requires a particular order of values to appear in the esa + // array - but this is not part of the spec. A different implementation + // might put the same values into the array but in a different order. 
+ //public void testToArray() { + // NonBlockingHashMap ht = new NonBlockingHashMap(); + // + // ht.put("Nine", new Integer(9)); + // ht.put("Ten", new Integer(10)); + // ht.put("Ten1", new Integer(100)); + // + // Collection es = ht.values(); + // + // Object [] esa = es.toArray(); + // + // ht.remove("Ten1"); + // + // assertEquals( "size check", es.size(), 2 ); + // assertEquals( "iterator_order[0]", new Integer( 9), esa[0] ); + // assertEquals( "iterator_order[1]", new Integer(10), esa[1] ); + //} +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbhs_tester.java b/netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbhs_tester.java new file mode 100644 index 0000000..08a4654 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbhs_tester.java @@ -0,0 +1,175 @@ +package org.jctools.maps.nbhs_test; + +import java.io.*; +import java.util.Iterator; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.jctools.maps.NonBlockingHashSet; + +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.*; + +/* + * Written by Cliff Click and released to the public domain, as explained at + * http://creativecommons.org/licenses/publicdomain + */ + +// Test NonBlockingHashSet via JUnit +public class nbhs_tester +{ + + static private NonBlockingHashSet _nbhs; + + @BeforeClass + public static void setUp() + { + _nbhs = new NonBlockingHashSet(); + } + + @AfterClass + public static void tearDown() + { + _nbhs = null; + } + + // Test some basic stuff; add a few keys, remove a few keys + @Test + public void testBasic() + { + assertTrue(_nbhs.isEmpty()); + assertTrue(_nbhs.add("k1")); + checkSizes(1); + assertTrue(_nbhs.add("k2")); + checkSizes(2); + assertFalse(_nbhs.add("k1")); + assertFalse(_nbhs.add("k2")); + checkSizes(2); + assertThat(_nbhs.remove("k1"), is(true)); + checkSizes(1); + assertThat(_nbhs.remove("k1"), 
is(false)); + assertTrue(_nbhs.remove("k2")); + checkSizes(0); + assertFalse(_nbhs.remove("k2")); + assertFalse(_nbhs.remove("k3")); + assertTrue(_nbhs.isEmpty()); + } + + // Check all iterators for correct size counts + private void checkSizes(int expectedSize) + { + assertEquals("size()", _nbhs.size(), expectedSize); + Iterator it = _nbhs.iterator(); + int result = 0; + while (it.hasNext()) + { + result++; + it.next(); + } + assertEquals("iterator missed", expectedSize, result); + } + + @Test + public void testIteration() + { + assertTrue(_nbhs.isEmpty()); + assertTrue(_nbhs.add("k1")); + assertTrue(_nbhs.add("k2")); + + StringBuilder buf = new StringBuilder(); + for (String val : _nbhs) + { + buf.append(val); + } + assertThat("found all vals", buf.toString(), anyOf(is("k1k2"), is("k2k1"))); + + assertThat("toString works", _nbhs.toString(), anyOf(is("[k1, k2]"), is("[k2, k1]"))); + _nbhs.clear(); + } + + @Test + public void testIterationBig() + { + for (int i = 0; i < 100; i++) + { + _nbhs.add("a" + i); + } + assertThat(_nbhs.size(), is(100)); + + int sz = 0; + int sum = 0; + for (String s : _nbhs) + { + sz++; + assertThat("", s.charAt(0), is('a')); + int x = Integer.parseInt(s.substring(1)); + sum += x; + assertTrue(x >= 0 && x <= 99); + } + assertThat("Found 100 ints", sz, is(100)); + assertThat("Found all integers in list", sum, is(100 * 99 / 2)); + + assertThat("can remove 3", _nbhs.remove("a3"), is(true)); + assertThat("can remove 4", _nbhs.remove("a4"), is(true)); + sz = 0; + sum = 0; + for (String s : _nbhs) + { + sz++; + assertThat("", s.charAt(0), is('a')); + int x = Integer.parseInt(s.substring(1)); + sum += x; + assertTrue(x >= 0 && x <= 99); + } + assertThat("Found 98 ints", sz, is(98)); + assertThat("Found all integers in list", sum, is(100 * 99 / 2 - (3 + 4))); + _nbhs.clear(); + } + + @Test + public void testSerial() + { + assertTrue(_nbhs.isEmpty()); + assertTrue(_nbhs.add("k1")); + assertTrue(_nbhs.add("k2")); + + // Serialize it out + try + { + 
FileOutputStream fos = new FileOutputStream("NBHS_test.txt"); + ObjectOutputStream out = new ObjectOutputStream(fos); + out.writeObject(_nbhs); + out.close(); + } + catch (IOException ex) + { + ex.printStackTrace(); + } + + // Read it back + try + { + File f = new File("NBHS_test.txt"); + FileInputStream fis = new FileInputStream(f); + ObjectInputStream in = new ObjectInputStream(fis); + NonBlockingHashSet nbhs = (NonBlockingHashSet) in.readObject(); + in.close(); + assertEquals(_nbhs.toString(), nbhs.toString()); + if (!f.delete()) + { + throw new IOException("delete failed"); + } + } + catch (IOException ex) + { + ex.printStackTrace(); + } + catch (ClassNotFoundException ex) + { + ex.printStackTrace(); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbsi_tester.java b/netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbsi_tester.java new file mode 100644 index 0000000..2f5a3ca --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/maps/nbhs_test/nbsi_tester.java @@ -0,0 +1,250 @@ +package org.jctools.maps.nbhs_test; + +import java.io.*; +import java.util.Iterator; + +import org.junit.AfterClass; +import org.junit.BeforeClass; +import org.junit.Test; + +import org.jctools.maps.NonBlockingSetInt; + +import static org.hamcrest.CoreMatchers.anyOf; +import static org.hamcrest.CoreMatchers.is; +import static org.junit.Assert.*; + +/* + * Written by Cliff Click and released to the public domain, as explained at + * http://creativecommons.org/licenses/publicdomain + */ + +// Test NonBlockingSetInt via JUnit +public class nbsi_tester +{ + + static private NonBlockingSetInt _nbsi; + + @BeforeClass + public static void setUp() + { + _nbsi = new NonBlockingSetInt(); + } + + @AfterClass + public static void tearDown() + { + _nbsi = null; + } + + // Test some basic stuff; add a few keys, remove a few keys + @Test + public void testBasic() + { + assertTrue(_nbsi.isEmpty()); + assertTrue(_nbsi.add(1)); + checkSizes(1); + 
assertTrue(_nbsi.add(2)); + checkSizes(2); + assertFalse(_nbsi.add(1)); + assertFalse(_nbsi.add(2)); + checkSizes(2); + assertThat(_nbsi.remove(1), is(true)); + checkSizes(1); + assertThat(_nbsi.remove(1), is(false)); + assertTrue(_nbsi.remove(2)); + checkSizes(0); + assertFalse(_nbsi.remove(2)); + assertFalse(_nbsi.remove(3)); + assertTrue(_nbsi.isEmpty()); + assertTrue(_nbsi.add(63)); + checkSizes(1); + assertTrue(_nbsi.remove(63)); + assertFalse(_nbsi.remove(63)); + + + assertTrue(_nbsi.isEmpty()); + assertTrue(_nbsi.add(10000)); + checkSizes(1); + assertTrue(_nbsi.add(20000)); + checkSizes(2); + assertFalse(_nbsi.add(10000)); + assertFalse(_nbsi.add(20000)); + checkSizes(2); + assertThat(_nbsi.remove(10000), is(true)); + checkSizes(1); + assertThat(_nbsi.remove(10000), is(false)); + assertTrue(_nbsi.remove(20000)); + checkSizes(0); + assertFalse(_nbsi.remove(20000)); + _nbsi.clear(); + } + + // Check all iterators for correct size counts + private void checkSizes(int expectedSize) + { + assertEquals("size()", _nbsi.size(), expectedSize); + Iterator it = _nbsi.iterator(); + int result = 0; + while (it.hasNext()) + { + result++; + it.next(); + } + assertEquals("iterator missed", expectedSize, result); + } + + + @Test + public void testIteration() + { + assertTrue(_nbsi.isEmpty()); + assertTrue(_nbsi.add(1)); + assertTrue(_nbsi.add(2)); + + StringBuilder buf = new StringBuilder(); + for (Integer val : _nbsi) + { + buf.append(val); + } + assertThat("found all vals", buf.toString(), anyOf(is("12"), is("21"))); + + assertThat("toString works", _nbsi.toString(), anyOf(is("[1, 2]"), is("[2, 1]"))); + _nbsi.clear(); + } + + @Test + public void testIterationBig() + { + for (int i = 0; i < 100; i++) + { + _nbsi.add(i); + } + assertThat(_nbsi.size(), is(100)); + + int sz = 0; + int sum = 0; + for (Integer x : _nbsi) + { + sz++; + sum += x; + assertTrue(x >= 0 && x <= 99); + } + assertThat("Found 100 ints", sz, is(100)); + assertThat("Found all integers in list", sum, 
is(100 * 99 / 2)); + + assertThat("can remove 3", _nbsi.remove(3), is(true)); + assertThat("can remove 4", _nbsi.remove(4), is(true)); + sz = 0; + sum = 0; + for (Integer x : _nbsi) + { + sz++; + sum += x; + assertTrue(x >= 0 && x <= 99); + } + assertThat("Found 98 ints", sz, is(98)); + assertThat("Found all integers in list", sum, is(100 * 99 / 2 - (3 + 4))); + _nbsi.clear(); + } + + @Test + public void testSerial() + { + assertTrue(_nbsi.isEmpty()); + assertTrue(_nbsi.add(1)); + assertTrue(_nbsi.add(2)); + + // Serialize it out + try + { + FileOutputStream fos = new FileOutputStream("NBSI_test.txt"); + ObjectOutputStream out = new ObjectOutputStream(fos); + out.writeObject(_nbsi); + out.close(); + } + catch (IOException ex) + { + ex.printStackTrace(); + } + + // Read it back + try + { + File f = new File("NBSI_test.txt"); + FileInputStream fis = new FileInputStream(f); + ObjectInputStream in = new ObjectInputStream(fis); + NonBlockingSetInt nbsi = (NonBlockingSetInt) in.readObject(); + in.close(); + assertEquals(_nbsi.toString(), nbsi.toString()); + if (!f.delete()) + { + throw new IOException("delete failed"); + } + } + catch (IOException | ClassNotFoundException ex) + { + ex.printStackTrace(); + } + _nbsi.clear(); + } + + // Do some simple concurrent testing + @Test + public void testConcurrentSimple() throws InterruptedException + { + final NonBlockingSetInt nbsi = new NonBlockingSetInt(); + + // In 2 threads, add & remove even & odd elements concurrently + Thread t1 = new Thread() + { + public void run() + { + work_helper(nbsi, "T1", 1); + } + }; + t1.start(); + work_helper(nbsi, "T0", 1); + t1.join(); + + // In the end, all members should be removed + StringBuffer buf = new StringBuffer(); + buf.append("Should be emptyset but has these elements: {"); + boolean found = false; + for (Integer x : nbsi) + { + buf.append(" ").append(x); + found = true; + } + if (found) + { + System.out.println(buf); + } + assertThat("concurrent size=0", nbsi.size(), is(0)); + for 
(Integer x : nbsi) + { + assertTrue("No elements so never get here", false); + } + _nbsi.clear(); + } + + void work_helper(NonBlockingSetInt nbsi, String thrd, int d) + { + final int ITERS = 100000; + for (int j = 0; j < 10; j++) + { + //long start = System.nanoTime(); + for (int i = d; i < ITERS; i += 2) + { + nbsi.add(i); + } + for (int i = d; i < ITERS; i += 2) + { + nbsi.remove(i); + } + //double delta_nanos = System.nanoTime()-start; + //double delta_secs = delta_nanos/1000000000.0; + //double ops = ITERS*2; + //System.out.println("Thrd"+thrd+" "+(ops/delta_secs)+" ops/sec size="+nbsi.size()); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTest.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTest.java new file mode 100644 index 0000000..776e114 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTest.java @@ -0,0 +1,1018 @@ +package org.jctools.queues; + +import org.jctools.queues.IndexedQueueSizeUtil.IndexedQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.util.Pow2; +import org.junit.After; +import org.junit.Test; + +import java.util.ArrayList; +import java.util.List; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.LockSupport; + +import static org.hamcrest.Matchers.is; +import static org.jctools.util.TestUtil.*; +import static org.junit.Assert.*; +import static org.junit.Assume.assumeThat; + +public abstract class MpqSanityTest +{ + + public static final int SIZE = 8192 * 2; + + protected final MessagePassingQueue queue; + private final ConcurrentQueueSpec spec; + int count = 0; + Integer p; + public static final Integer DUMMY_ELEMENT = 1; + + public MpqSanityTest(ConcurrentQueueSpec spec, MessagePassingQueue queue) + { + this.queue = queue; + this.spec = spec; + } + + @After + public void clear() throws InterruptedException + { + queue.clear(); + assertTrue(queue.isEmpty()); + 
assertTrue(queue.size() == 0); + } + + @Test(expected = NullPointerException.class) + public void relaxedOfferNullResultsInNPE() + { + queue.relaxedOffer(null); + } + + + @Test + public void capacityWorks() + { + if (spec.isBounded()) + { + assertEquals(Pow2.roundToPowerOfTwo(spec.capacity), queue.capacity()); + } + else + { + assertEquals(MessagePassingQueue.UNBOUNDED_CAPACITY, queue.capacity()); + } + } + + @Test + public void fillToCapacityOnBounded() + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + + queue.fill(() -> DUMMY_ELEMENT); + assertEquals(queue.capacity(), queue.size()); + } + + @Test + public void fillOnUnbounded() + { + assumeThat(spec.isBounded(), is(Boolean.FALSE)); + + queue.fill(() -> DUMMY_ELEMENT); + assertTrue(!queue.isEmpty()); + } + @Test + public void fillToCapacityInBatches() + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + Integer element = 1; + + int filled = 0; + for (int i = 0; i < SIZE; i++) + { + filled += queue.fill(() -> DUMMY_ELEMENT, 16); + assertEquals(filled, queue.size()); + if (filled == queue.capacity()) + break; + } + assertEquals(queue.capacity(), queue.size()); + } + + @Test(expected = IllegalArgumentException.class) + public void fillNullSupplier() + { + queue.fill(null); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void fillNullSupplierLimit() + { + queue.fill(null, 10); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void fillNegativeLimit() + { + queue.fill(() -> DUMMY_ELEMENT,-1); + fail(); + } + + @Test + public void fill0() + { + assertEquals(0, queue.fill(() -> {fail(); return 1;},0)); + assertTrue(queue.isEmpty()); + } + + @Test(expected = IllegalArgumentException.class) + public void fillNullSupplierWaiterExit() + { + queue.fill(null, i -> i++, () -> true); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void fillSupplierNullWaiterExit() + { + queue.fill(() -> DUMMY_ELEMENT, null, () -> true); + fail(); + } + + 
@Test(expected = IllegalArgumentException.class) + public void fillSupplierWaiterNullExit() + { + queue.fill(() -> DUMMY_ELEMENT, i -> i++, null); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void drainNullConsumer() + { + queue.drain(null); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void drainNullConsumerLimit() + { + queue.drain(null, 10); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void drainNegativeLimit() + { + queue.drain(e -> {},-1); + fail(); + } + + @Test + public void drain0() + { + queue.offer(DUMMY_ELEMENT); + assertEquals(0, queue.drain(e -> fail(),0)); + assertEquals(1, queue.size()); + } + + @Test(expected = IllegalArgumentException.class) + public void drainNullConsumerWaiterExit() + { + queue.drain(null, i -> i++, () -> true); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void drainSupplierNullWaiterExit() + { + queue.drain(e -> fail(), null, () -> true); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void drainSupplierWaiterNullExit() + { + queue.drain(e -> fail(), i -> i++, null); + fail(); + } + + @Test + public void sanity() + { + for (int i = 0; i < SIZE; i++) + { + assertNull(queue.relaxedPoll()); + assertTrue(queue.isEmpty()); + assertTrue(queue.size() == 0); + } + int i = 0; + while (i < SIZE && queue.relaxedOffer(i)) + { + i++; + } + int size = i; + assertEquals(size, queue.size()); + if (spec.ordering == Ordering.FIFO) + { + // expect FIFO + i = 0; + Integer p; + Integer e; + while ((p = queue.relaxedPeek()) != null) + { + e = queue.relaxedPoll(); + assertEquals(p, e); + assertEquals(size - (i + 1), queue.size()); + assertEquals(i++, e.intValue()); + } + assertEquals(size, i); + } + else + { + // expect sum of elements is (size - 1) * size / 2 = 0 + 1 + .... 
+ (size - 1) + int sum = (size - 1) * size / 2; + i = 0; + Integer e; + while ((e = queue.relaxedPoll()) != null) + { + assertEquals(--size, queue.size()); + sum -= e; + } + assertEquals(0, sum); + } + } + + int sum; + @Test + public void sanityDrainBatch() + { + assertEquals(0, queue.drain(e -> + { + }, SIZE)); + assertTrue(queue.isEmpty()); + assertTrue(queue.size() == 0); + count = 0; + sum = 0; + int i = queue.fill(() -> + { + final int val = count++; + sum += val; + return val; + }, SIZE); + final int size = i; + assertEquals(size, queue.size()); + if (spec.ordering == Ordering.FIFO) + { + // expect FIFO + count = 0; + int drainCount = 0; + i = 0; + do + { + i += drainCount = queue.drain(e -> + { + assertEquals(count++, e.intValue()); + }); + } + while (drainCount != 0); + assertEquals(size, i); + + assertTrue(queue.isEmpty()); + assertTrue(queue.size() == 0); + } + else + { + int drainCount = 0; + i = 0; + do + { + i += drainCount = queue.drain(e -> + { + sum -= e.intValue(); + }); + } + while (drainCount != 0); + assertEquals(size, i); + + assertTrue(queue.isEmpty()); + assertTrue(queue.size() == 0); + assertEquals(0, sum); + } + } + + @Test + public void testSizeIsTheNumberOfOffers() + { + int currentSize = 0; + while (currentSize < SIZE && queue.relaxedOffer(currentSize)) + { + currentSize++; + assertFalse(queue.isEmpty()); + assertTrue(queue.size() == currentSize); + } + if (spec.isBounded()) + { + assertEquals(spec.capacity, currentSize); + } + else + { + assertEquals(SIZE, currentSize); + } + } + + @Test + public void supplyMessageUntilFull() + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + final Val instances = new Val(); + instances.value = 0; + final MessagePassingQueue.Supplier messageFactory = () -> instances.value++; + final int capacity = queue.capacity(); + int filled = 0; + while (filled < capacity) + { + filled += queue.fill(messageFactory, capacity - filled); + } + assertEquals(instances.value, capacity); + final int noItems = 
queue.fill(messageFactory, 1); + assertEquals(noItems, 0); + assertEquals(instances.value, capacity); + } + + @Test + public void whenFirstInThenFirstOut() + { + assumeThat(spec.ordering, is(Ordering.FIFO)); + + // Arrange + int i = 0; + while (i < SIZE && queue.relaxedOffer(i)) + { + i++; + } + final int size = queue.size(); + + // Act + i = 0; + Integer prev; + while ((prev = queue.relaxedPeek()) != null) + { + final Integer item = queue.relaxedPoll(); + + assertThat(item, is(prev)); + assertEquals((size - (i + 1)), queue.size()); + assertThat(item, is(i)); + i++; + } + + // Assert + assertThat(i, is(size)); + } + + @Test + public void test_FIFO_PRODUCER_Ordering() throws Exception + { + assumeThat(spec.ordering, is((Ordering.FIFO))); + + // Arrange + int i = 0; + while (i < SIZE && queue.relaxedOffer(i)) + { + i++; + } + int size = queue.size(); + + // Act + // expect sum of elements is (size - 1) * size / 2 = 0 + 1 + .... + (size - 1) + int sum = (size - 1) * size / 2; + Integer e; + while ((e = queue.relaxedPoll()) != null) + { + size--; + assertEquals(size, queue.size()); + sum -= e; + } + + // Assert + assertThat(sum, is(0)); + } + + @Test + public void whenOfferItemAndPollItemThenSameInstanceReturnedAndQueueIsEmpty() + { + assertTrue(queue.isEmpty()); + assertTrue(queue.size() == 0); + + // Act + final Integer e = new Integer(1876876); + queue.relaxedOffer(e); + assertFalse(queue.isEmpty()); + assertEquals(1, queue.size()); + + final Integer oh = queue.relaxedPoll(); + assertEquals(e, oh); + + // Assert + assertTrue(queue.isEmpty()); + assertTrue(queue.size() == 0); + } + + @Test + public void testPowerOf2Capacity() + { + assumeThat(spec.isBounded(), is(true)); + int n = Pow2.roundToPowerOfTwo(spec.capacity); + + for (int i = 0; i < n; i++) + { + assertTrue("Failed to insert:" + i, queue.relaxedOffer(i)); + } + assertFalse(queue.relaxedOffer(n)); + } + + @Test(timeout = TEST_TIMEOUT) + public void testHappensBefore() throws Exception + { + final 
AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + while (!stop.get()) + { + for (int i = 1; i <= 10; i++) + { + Val v = new Val(); + v.value = i; + q.relaxedOffer(v); + } + // slow down the producer, this will make the queue mostly empty encouraging visibility + // issues. + Thread.yield(); + } + }, spec.producers, threads); + + threads(() -> { + while (!stop.get()) + { + for (int i = 0; i < 10; i++) + { + Val v1 = (Val) q.relaxedPeek(); + if (v1 != null && v1.value == 0) + { + fail.value = 1; + stop.set(true); + } + else + { + continue; + } + Val v2 = (Val) q.relaxedPoll(); + if (v2 == null || v1 != v2) + { + fail.value = 2; + stop.set(true); + } + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + assertEquals("reordering detected", 0, fail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testHappensBeforePerpetualDrain() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + while (!stop.get()) + { + for (int i = 1; i <= 10; i++) + { + Val v = new Val(); + v.value = i; + q.relaxedOffer(v); + } + // slow down the producer, this will make the queue mostly empty encouraging visibility + // issues. 
+ Thread.yield(); + } + }, spec.producers, threads); + + threads(() -> { + while (!stop.get()) + { + q.drain(e -> + { + Val v = (Val) e; + if (v != null && v.value == 0) + { + fail.value = 1; + stop.set(true); + } + if (v == null) + { + fail.value = 1; + stop.set(true); + System.out.println("Unexpected: v == null"); + } + }, idle -> + { + return idle; + }, () -> + { + return !stop.get(); + }); + } + }, 1 , threads); + + startWaitJoin(stop, threads); + assertEquals("reordering detected", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testHappensBeforePerpetualFill() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + Val counter = new Val(); + counter.value = 1; + q.fill(() -> + { + Val v = new Val(); + int c = counter.value++ % 10; + v.value = 1 + c; + if (c == 0) + Thread.yield(); + return v; + }, e -> + { + return e; + }, () -> + { + // slow down the producer, this will make the queue mostly empty encouraging visibility + // issues. 
+ Thread.yield(); + return !stop.get(); + }); + }, spec.producers, threads); + threads(() -> { + while (!stop.get()) + { + for (int i = 0; i < 10; i++) + { + Val v1 = (Val) q.relaxedPeek(); + int r; + if (v1 != null && (r = v1.value) == 0) + { + fail.value = 1; + stop.set(true); + } + else + { + continue; + } + + Val v2 = (Val) q.relaxedPoll(); + if (v2 == null || v1 != v2) + { + fail.value = 1; + stop.set(true); + } + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + assertEquals("reordering detected", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testHappensBeforePerpetualFillDrain() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + Val counter = new Val(); + counter.value = 1; + q.fill(() -> + { + Val v = new Val(); + v.value = 1 + (counter.value++ % 10); + return v; + }, e -> + { + return e; + }, () -> + { // slow down the producer, this will make the queue mostly empty encouraging + // visibility issues. + Thread.yield(); + return !stop.get(); + }); + }, spec.producers, threads); + + threads(() -> { + while (!stop.get()) + { + q.drain(e -> + { + Val v = (Val) e; + if (v != null && v.value == 0) + { + fail.value = 1; + stop.set(true); + } + if (v == null) + { + fail.value = 1; + stop.set(true); + System.out.println("Unexpected: v == null"); + } + }, idle -> + { + return idle; + }, () -> + { + return !stop.get(); + }); + } + }, 1, threads); + + startWaitJoin(stop, threads); + assertEquals("reordering detected", 0, fail.value); + queue.clear(); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testRelaxedOfferPollObservedSize() throws Exception + { + final int capacity = !spec.isBounded() ? 
Integer.MAX_VALUE : queue.capacity(); + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + while (!stop.get()) + { + if(q.relaxedOffer(1)) + while (q.relaxedPoll() == null); + } + }, !spec.isMpmc()? 1: 0, threads); + + int threadCount = threads.size(); + threads(() -> { + final int max = Math.min(threadCount, capacity); + while (!stop.get()) + { + int size = q.size(); + if (size < 0 || size > max) + { + fail.value++; + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + assertEquals("Unexpected size observed", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekAfterIsEmpty1() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + if (!q.isEmpty() && q.peek() == null) + { + fail.value++; + } + q.poll(); + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekAfterIsEmpty2() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `isEmpty`? + q.poll(); + if (!q.isEmpty() && q.peek() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekAfterIsEmpty3() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `size`? 
+ q.poll(); + if (q.size() != 0 && q.peek() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollAfterIsEmpty1() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + if (!q.isEmpty() && q.poll() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollAfterIsEmpty2() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `isEmpty`? + q.poll(); + if (!q.isEmpty() && q.poll() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollAfterIsEmpty3() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + final Val fail = new Val(); + boolean slowSize = !spec.isBounded() && !(queue instanceof IndexedQueue); + + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `size`? + q.poll(); + int size = q.size(); + if (size != 0 && q.poll() == null) + { + fail.value++; + } + if (slowSize) { + q.clear(); + } + } + }); + } + + private void testIsEmptyInvariant(AtomicBoolean stop, MessagePassingQueue q, Val fail, Runnable consumerLoop) + throws InterruptedException + { + testIsEmptyInvariant(stop, fail, consumerLoop, () -> { + while (!stop.get()) + { + q.relaxedOffer(1); + // slow down the producer, this will make the queue mostly empty encouraging visibility issues. 
+ Thread.yield(); + } + }); + + testIsEmptyInvariant(stop, fail, consumerLoop, () -> { + while (!stop.get()) + { + int items = q.fill(() -> 1); + // slow down the producer, this will make the queue mostly empty encouraging visibility issues. + LockSupport.parkNanos(items); + } + }); + + testIsEmptyInvariant(stop, fail, consumerLoop, () -> { + while (!stop.get()) + { + q.fill(() -> 1, 1); + // slow down the producer, this will make the queue mostly empty encouraging visibility issues. + Thread.yield(); + } + }); + + testIsEmptyInvariant(stop, fail, consumerLoop, () -> { + q.fill(() -> 1, i -> {Thread.yield(); return i;}, () -> !stop.get()); + }); + + int capacity = q.capacity(); + if (capacity == MessagePassingQueue.UNBOUNDED_CAPACITY || capacity == 1) + return; + int limit = Math.max(capacity/8, 2); + testIsEmptyInvariant(stop, fail, consumerLoop, () -> { + while (!stop.get()) + { + int items = q.fill(() -> 1, limit); + // slow down the producer, this will make the queue mostly empty encouraging visibility issues. 
+ LockSupport.parkNanos(items); + } + }); + + } + + private void testIsEmptyInvariant( + AtomicBoolean stop, + Val fail, + Runnable consumerLoop, + Runnable producerLoop) + throws InterruptedException + { + List threads = new ArrayList<>(); + threads(producerLoop, spec.producers, threads); + threads(consumerLoop, 1, threads); + startWaitJoin(stop, threads, 4); + + assertEquals("Observed no element in non-empty queue", 0, fail.value); + clear(); + } + + @Test(timeout = TEST_TIMEOUT) + public void testSizeLtZero() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + + // producer check size and offer + final Val pFail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + while (!stop.get()) { + if (q.size() < 0) { + pFail.value++; + } + + q.offer(1); + Thread.yield(); + } + }, spec.producers, threads); + + // consumer poll and check size + final Val cFail = new Val(); + threads(() -> { + while (!stop.get()) + { + q.poll(); + + if (q.size() < 0) { + cFail.value++; + } + } + }, spec.consumers, threads); + + // observer check size + final Val oFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if (q.size() < 0) { + oFail.value++; + } + Thread.yield(); + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed producer size < 0", 0, pFail.value); + assertEquals("Observed consumer size < 0", 0, cFail.value); + assertEquals("Observed observer size < 0", 0, oFail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testSizeGtCapacity() throws Exception + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + + final int capacity = spec.capacity; + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + + // producer offer and check size + final Val pFail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + while (!stop.get()) { + q.offer(1); + + if (q.size() > capacity) { + pFail.value++; + } + } + }, 
spec.producers, threads); + + // consumer check size and poll + final Val cFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if (q.size() > capacity) { + cFail.value++; + } + + q.poll(); + sleepQuietly(1); + } + }, 1, threads); + + // observer check size + final Val oFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if (q.size() > capacity) { + oFail.value++; + } + Thread.yield(); + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed producer size > capacity", 0, pFail.value); + assertEquals("Observed consumer size > capacity", 0, cFail.value); + assertEquals("Observed observer size > capacity", 0, oFail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekEqualsPoll() throws InterruptedException { + final AtomicBoolean stop = new AtomicBoolean(); + final MessagePassingQueue q = queue; + + List threads = new ArrayList<>(); + threads(() -> { + int sequence = 0; + while (!stop.get()) { + if (q.offer(sequence)) { + sequence++; + } + Thread.yield(); + } + }, spec.producers, threads); + + final Val fail = new Val(); + threads(() -> { + while (!stop.get()) { + final Integer peekedSequence = q.peek(); + if (peekedSequence == null) { + continue; + } + if (!peekedSequence.equals(q.poll())) { + fail.value++; + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed peekedSequence is not equal to polledSequence", 0, fail.value); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcArray.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcArray.java new file mode 100644 index 0000000..3cdd720 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcArray.java @@ -0,0 +1,33 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import 
java.util.ArrayList; +import java.util.Collection; + +import static org.jctools.util.TestUtil.*; + +@RunWith(Parameterized.class) +public class MpqSanityTestMpmcArray extends MpqSanityTest +{ + public MpqSanityTestMpmcArray(ConcurrentQueueSpec spec, MessagePassingQueue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(0, 0, 2, Ordering.FIFO)); + list.add(makeMpq(0, 0, SIZE, Ordering.FIFO)); + list.add(makeAtomic(0, 0, 2, Ordering.FIFO)); + list.add(makeAtomic(0, 0, SIZE, Ordering.FIFO)); + list.add(makeUnpadded(0, 0, 2, Ordering.FIFO)); + list.add(makeUnpadded(0, 0, SIZE, Ordering.FIFO)); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcUnboundedXadd.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcUnboundedXadd.java new file mode 100644 index 0000000..bbe6b09 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpmcUnboundedXadd.java @@ -0,0 +1,214 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.Assert; +import org.junit.Assume; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class MpqSanityTestMpmcUnboundedXadd extends MpqSanityTest +{ + public MpqSanityTestMpmcUnboundedXadd(ConcurrentQueueSpec spec, MessagePassingQueue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 0))); + 
list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 0))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 1))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 1))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 2))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 2))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 3))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 3))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 4))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 4))); + return list; + } + + @Test + public void peekShouldNotSeeFutureElements() throws InterruptedException + { + MpmcUnboundedXaddArrayQueue xaddQueue = (MpmcUnboundedXaddArrayQueue) this.queue; + Assume.assumeTrue("The queue need to pool some chunk to run this test", xaddQueue.maxPooledChunks() > 0); + CountDownLatch stop = new CountDownLatch(1); + Producer producer = new Producer(xaddQueue, stop); + producer.start(); + Consumer consumer = new Consumer(xaddQueue, stop); + consumer.start(); + Peeker peeker = new Peeker(xaddQueue, stop, false); + peeker.start(); + try + { + stop.await(2, TimeUnit.SECONDS); + } + finally + { + stop.countDown(); + } + final String error = peeker.error; + if (error != null) + { + Assert.fail(error); + } + producer.join(); + consumer.join(); + peeker.join(); + } + + @Test + public void relaxedPeekShouldNotSeeFutureElements() throws InterruptedException + { + MpmcUnboundedXaddArrayQueue xaddQueue = (MpmcUnboundedXaddArrayQueue) this.queue; + Assume.assumeTrue("The queue need to pool some chunk to run this test", xaddQueue.maxPooledChunks() > 0); + CountDownLatch stop = new CountDownLatch(1); + Producer producer = new Producer(xaddQueue, stop); + 
producer.start(); + Consumer consumer = new Consumer(xaddQueue, stop); + consumer.start(); + Peeker peeker = new Peeker(xaddQueue, stop, true); + peeker.start(); + try + { + stop.await(2, TimeUnit.SECONDS); + } + finally + { + stop.countDown(); + } + final String error = peeker.error; + if (error != null) + { + Assert.fail(error); + } + producer.join(); + consumer.join(); + peeker.join(); + } + + private static class Producer extends Thread + { + final CountDownLatch stop; + final MpUnboundedXaddArrayQueue messageQueue; + long sequence = 0; + + Producer(MpUnboundedXaddArrayQueue messageQueue, CountDownLatch stop) + { + this.messageQueue = messageQueue; + this.stop = stop; + } + + @Override + public void run() + { + final int chunkSize = messageQueue.chunkSize(); + final int capacity = chunkSize * messageQueue.maxPooledChunks(); + + while (stop.getCount() > 0) + { + if (messageQueue.offer(sequence)) + { + sequence++; + } + + while (messageQueue.size() >= capacity - chunkSize) + { + if (stop.getCount() == 0) { + return; + } + Thread.yield(); + } + } + } + } + + private static class Consumer extends Thread + { + + final MpUnboundedXaddArrayQueue messageQueue; + final CountDownLatch stop; + + private Consumer(MpUnboundedXaddArrayQueue messageQueue, CountDownLatch stop) + { + this.messageQueue = messageQueue; + this.stop = stop; + + } + + @Override + public void run() + { + final int chunkSize = messageQueue.chunkSize(); + + while (stop.getCount() > 0) + { + messageQueue.poll(); + + while (messageQueue.size() < chunkSize) + { + if (stop.getCount() == 0) { + return; + } + Thread.yield(); + } + } + } + + } + + private static class Peeker extends Thread + { + + final MessagePassingQueue messageQueue; + final CountDownLatch stop; + long lastPeekedSequence; + volatile String error; + final boolean relaxed; + + private Peeker(MessagePassingQueue messageQueue, CountDownLatch stop, boolean relaxed) + { + this.messageQueue = messageQueue; + this.stop = stop; + this.relaxed = 
relaxed; + setPriority(MIN_PRIORITY); + error = null; + } + + @Override + public void run() + { + final boolean relaxed = this.relaxed; + while (stop.getCount() > 0) + { + final Long peekedSequence = relaxed ? messageQueue.relaxedPeek() : messageQueue.peek(); + if (peekedSequence == null) + { + continue; + } + + if (peekedSequence < lastPeekedSequence) + { + error = + String.format("peekedSequence %s, lastPeekedSequence %s", peekedSequence, lastPeekedSequence); + stop.countDown(); + } + + lastPeekedSequence = peekedSequence; + } + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscArray.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscArray.java new file mode 100644 index 0000000..ddf33e3 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscArray.java @@ -0,0 +1,33 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; + +import static org.jctools.util.TestUtil.*; + +@RunWith(Parameterized.class) +public class MpqSanityTestMpscArray extends MpqSanityTest +{ + public MpqSanityTestMpscArray(ConcurrentQueueSpec spec, MessagePassingQueue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(0, 1, 1, Ordering.FIFO));// MPSC size 1 + list.add(makeMpq(0, 1, SIZE, Ordering.FIFO));// MPSC size SIZE + list.add(makeAtomic(0, 1, 1, Ordering.FIFO));// MPSC size 1 + list.add(makeAtomic(0, 1, SIZE, Ordering.FIFO));// MPSC size SIZE + list.add(makeUnpadded(0, 1, 1, Ordering.FIFO));// MPSC size 1 + list.add(makeUnpadded(0, 1, SIZE, Ordering.FIFO));// MPSC size SIZE + return list; + } +} diff --git 
a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumer.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumer.java new file mode 100644 index 0000000..d4c70a3 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumer.java @@ -0,0 +1,88 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Collection; + +import static java.util.concurrent.TimeUnit.NANOSECONDS; +import static org.jctools.util.TestUtil.TEST_TIMEOUT; +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class MpqSanityTestMpscBlockingConsumer extends MpqSanityTest +{ + public MpqSanityTestMpscBlockingConsumer(ConcurrentQueueSpec spec, MessagePassingQueue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 1, 1, Ordering.FIFO, new MpscBlockingConsumerArrayQueue<>(1)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscBlockingConsumerArrayQueue<>(SIZE)));// MPSC size SIZE + return list; + } + + @Test(timeout = TEST_TIMEOUT) + public void testSpinWaitForUnblockDrainForever() throws InterruptedException { + + class Echo implements Runnable{ + private MpscBlockingConsumerArrayQueue source; + private MpscBlockingConsumerArrayQueue sink; + private int interations; + + Echo( + MpscBlockingConsumerArrayQueue source, + MpscBlockingConsumerArrayQueue sink, + int interations) { + this.source = source; + this.sink = sink; + this.interations = interations; + } + + public void run() { + ArrayDeque ints = new ArrayDeque<>(); + try { + for (int i = 0; i < interations; ++i) { + T t; + do { + 
source.drain(ints::offer, 1, 1, NANOSECONDS); + t = ints.poll(); + } + while (t == null); + + sink.put(t); + } + } + catch (InterruptedException e) { + throw new AssertionError(e); + } + } + } + + final MpscBlockingConsumerArrayQueue q1 = + new MpscBlockingConsumerArrayQueue<>(1024); + final MpscBlockingConsumerArrayQueue q2 = + new MpscBlockingConsumerArrayQueue<>(1024); + + final Thread t1 = new Thread(new Echo<>(q1, q2, 100000)); + final Thread t2 = new Thread(new Echo<>(q2, q1, 100000)); + + t1.start(); + t2.start(); + + q1.put("x"); + + t1.join(); + t2.join(); + } + +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumerExtended.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumerExtended.java new file mode 100644 index 0000000..6a99402 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscBlockingConsumerExtended.java @@ -0,0 +1,74 @@ +package org.jctools.queues; + +import org.junit.Test; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.TimeUnit; + +public class MpqSanityTestMpscBlockingConsumerExtended +{ + /** + * This test demonstrates a race described here: https://github.com/JCTools/JCTools/issues/339 + * You will need to debug to observe the spin. 
+ */ + @Test + public void testSpinWaitForUnblockForeverFill() throws InterruptedException { + + class Echo implements Runnable { + final MpscBlockingConsumerArrayQueue source; + final MpscBlockingConsumerArrayQueue sink; + final int interations; + final int batch; + + Echo( + MpscBlockingConsumerArrayQueue source, + MpscBlockingConsumerArrayQueue sink, + int iterations, + int batch) { + this.source = source; + this.sink = sink; + this.interations = iterations; + this.batch = batch; + } + + public void run() { + Queue batchContainer = new ArrayDeque<>(batch); + try { + for (int i = 0; i < interations; ++i) { + for (int j = 0; j < batch; j++) { + T t; + do { + t = source.poll(1, TimeUnit.NANOSECONDS); + } + while (t == null); + batchContainer.add(t); + } + do { + sink.fill(() -> batchContainer.poll(), batchContainer.size()); + } while (!batchContainer.isEmpty()); + } + } + catch (InterruptedException e) { + throw new AssertionError(e); + } + } + } + + final MpscBlockingConsumerArrayQueue q1 = + new MpscBlockingConsumerArrayQueue<>(1024); + final MpscBlockingConsumerArrayQueue q2 = + new MpscBlockingConsumerArrayQueue<>(1024); + + final Thread t1 = new Thread(new Echo<>(q1, q2, 100000, 10)); + final Thread t2 = new Thread(new Echo<>(q2, q1, 100000, 10)); + + t1.start(); + t2.start(); + + for (int j = 0; j < 10; j++) q1.put("x"); + + t1.join(); + t2.join(); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscChunked.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscChunked.java new file mode 100644 index 0000000..4009d5e --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscChunked.java @@ -0,0 +1,46 @@ +package org.jctools.queues; + +import org.jctools.queues.atomic.MpscChunkedAtomicArrayQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.queues.unpadded.MpscChunkedUnpaddedArrayQueue; +import org.junit.Test; +import 
org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class MpqSanityTestMpscChunked extends MpqSanityTest +{ + public MpqSanityTestMpscChunked(ConcurrentQueueSpec spec, MessagePassingQueue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscChunkedArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscChunkedArrayQueue<>(8, SIZE)));// MPSC size SIZE + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscChunkedAtomicArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscChunkedAtomicArrayQueue<>(8, SIZE)));// MPSC size SIZE + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscChunkedUnpaddedArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscChunkedUnpaddedArrayQueue<>(8, SIZE)));// MPSC size SIZE + return list; + } + + @Test + public void testMaxSizeQueue() + { + MpscChunkedArrayQueue queue = new MpscChunkedArrayQueue(1024, 1000 * 1024 * 1024); + for (int i = 0; i < 400001; i++) + { + queue.offer(i); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscCompound.java b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscCompound.java new file mode 100644 index 0000000..c425fad --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpqSanityTestMpscCompound.java @@ -0,0 +1,31 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.util.Pow2; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; + +import 
package org.jctools.queues;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.jctools.util.Pow2;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.PortableJvmInfo.CPUs;
import static org.jctools.util.TestUtil.makeMpq;

/**
 * Sanity tests for the compound MPSC queue (one sub-queue per CPU).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestMpscCompound extends MpqSanityTest
{
    public MpqSanityTestMpscCompound(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs; capacity rounded to a power of two per CPU count */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeMpq(0, 1, Pow2.roundToPowerOfTwo(CPUs), Ordering.NONE));// MPSC size 1
        list.add(makeMpq(0, 1, SIZE, Ordering.NONE));// MPSC size SIZE
        return list;
    }
}
package org.jctools.queues;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.*;

/**
 * Sanity tests for the unbounded linked MPSC queue variants (unsafe, atomic, unpadded).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestMpscLinked extends MpqSanityTest
{
    public MpqSanityTestMpscLinked(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs; capacity 0 selects the unbounded linked queue */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeMpq(0, 1, 0, Ordering.FIFO));// unbounded MPSC
        list.add(makeAtomic(0, 1, 0, Ordering.FIFO));// unbounded MPSC
        list.add(makeUnpadded(0, 1, 0, Ordering.FIFO));// unbounded MPSC
        return list;
    }
}
package org.jctools.queues;

import org.jctools.queues.atomic.MpscUnboundedAtomicArrayQueue;
import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.jctools.queues.unpadded.MpscUnboundedUnpaddedArrayQueue;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.makeParams;

/**
 * Sanity tests for the unbounded array-backed MPSC queue variants
 * (padded, atomic, unpadded), each with small and larger chunk sizes.
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestMpscUnbounded extends MpqSanityTest
{
    public MpqSanityTestMpscUnbounded(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs; chunk sizes 2 and 64 exercise resize behaviour */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedArrayQueue<>(2)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedArrayQueue<>(64)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedAtomicArrayQueue<>(2)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedAtomicArrayQueue<>(64)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedUnpaddedArrayQueue<>(2)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedUnpaddedArrayQueue<>(64)));
        return list;
    }
}
package org.jctools.queues;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.makeParams;

/**
 * Sanity tests for the XADD-based unbounded MPSC queue, sweeping chunk size
 * (1, 64) against pooled-chunk counts (0..3).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestMpscUnboundedXadd extends MpqSanityTest
{

    public MpqSanityTestMpscUnboundedXadd(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs covering the chunk-size x pool-size matrix */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 0)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 0)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 1)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 1)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 2)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 2)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 3)));
        list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 3)));
        return list;
    }
}
package org.jctools.queues;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.*;

/**
 * Sanity tests for the bounded SPSC array queue variants (unsafe, atomic, unpadded).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestSpscArray extends MpqSanityTest
{
    public MpqSanityTestSpscArray(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs covering small (4) and SIZE capacities */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeMpq(1, 1, 4, Ordering.FIFO));// SPSC size 4
        list.add(makeMpq(1, 1, SIZE, Ordering.FIFO));// SPSC size SIZE
        list.add(makeAtomic(1, 1, 4, Ordering.FIFO));// SPSC size 4
        list.add(makeAtomic(1, 1, SIZE, Ordering.FIFO));// SPSC size SIZE
        list.add(makeUnpadded(1, 1, 4, Ordering.FIFO));// SPSC size 4
        list.add(makeUnpadded(1, 1, SIZE, Ordering.FIFO));// SPSC size SIZE
        return list;
    }
}
package org.jctools.queues;

import org.jctools.queues.atomic.SpscChunkedAtomicArrayQueue;
import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.jctools.queues.unpadded.SpscChunkedUnpaddedArrayQueue;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.makeParams;

/**
 * Sanity tests for the chunked SPSC queue variants (padded, atomic, unpadded).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestSpscChunked extends MpqSanityTest
{
    public MpqSanityTestSpscChunked(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs covering small (16) and SIZE capacities */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscChunkedArrayQueue<>(8, 16)));// MPSC size 1
        list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscChunkedArrayQueue<>(8, SIZE)));// MPSC size SIZE
        list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscChunkedAtomicArrayQueue<>(8, 16)));// MPSC size 1
        list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscChunkedAtomicArrayQueue<>(8, SIZE)));// MPSC size SIZE
        list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscChunkedUnpaddedArrayQueue<>(8, 16)));// MPSC size 1
        list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscChunkedUnpaddedArrayQueue<>(8, SIZE)));// MPSC size SIZE
        return list;
    }

    /** Offering far beyond capacity must degrade gracefully (offers fail, no throw). */
    @Test
    public void testMaxSizeQueue()
    {
        SpscChunkedArrayQueue<Integer> queue = new SpscChunkedArrayQueue<>(1024, 1000 * 1024 * 1024);
        for (int i = 0; i < 400001; i++)
        {
            queue.offer(i);
        }
    }
}
package org.jctools.queues;

import org.jctools.queues.atomic.SpscGrowableAtomicArrayQueue;
import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.jctools.queues.unpadded.SpscGrowableUnpaddedArrayQueue;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.makeParams;

/**
 * Sanity tests for the growable SPSC queue variants (padded, atomic, unpadded).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestSpscGrowable extends MpqSanityTest
{

    public MpqSanityTestSpscGrowable(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs covering small (16) and SIZE capacities */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscGrowableArrayQueue<>(8, 16)));
        list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscGrowableArrayQueue<>(8, SIZE)));
        list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscGrowableAtomicArrayQueue<>(8, 16)));
        list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscGrowableAtomicArrayQueue<>(8, SIZE)));
        list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscGrowableUnpaddedArrayQueue<>(8, 16)));
        list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscGrowableUnpaddedArrayQueue<>(8, SIZE)));
        return list;
    }
}
package org.jctools.queues;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.junit.runner.RunWith;
import org.junit.runners.Parameterized;

import java.util.ArrayList;
import java.util.Collection;

import static org.jctools.util.TestUtil.*;

/**
 * Sanity tests for the unbounded linked SPSC queue variants (unsafe, atomic, unpadded).
 * Generic type arguments restored where raw types were used.
 */
@RunWith(Parameterized.class)
public class MpqSanityTestSpscLinked extends MpqSanityTest
{
    public MpqSanityTestSpscLinked(ConcurrentQueueSpec spec, MessagePassingQueue<Integer> queue)
    {
        super(spec, queue);
    }

    /** @return {spec, queue} pairs; capacity 0 selects the unbounded linked queue */
    @Parameterized.Parameters
    public static Collection<Object[]> parameters()
    {
        ArrayList<Object[]> list = new ArrayList<>();
        list.add(makeMpq(1, 1, 0, Ordering.FIFO));// unbounded SPSC
        list.add(makeAtomic(1, 1, 0, Ordering.FIFO));// unbounded SPSC
        list.add(makeUnpadded(1, 1, 0, Ordering.FIFO));// unbounded SPSC
        return list;
    }
}
list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedUnpaddedArrayQueue<>(2))); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedUnpaddedArrayQueue<>(64))); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/MpscArrayQueueSnapshotTest.java b/netty-jctools/src/test/java/org/jctools/queues/MpscArrayQueueSnapshotTest.java new file mode 100644 index 0000000..d404fd1 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/MpscArrayQueueSnapshotTest.java @@ -0,0 +1,79 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
package org.jctools.queues;

// org.junit.Assert.assertThat is deprecated since JUnit 4.13; the Hamcrest
// MatcherAssert version is the drop-in, non-deprecated replacement.
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.is;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.junit.Before;
import org.junit.Test;

/**
 * Verifies the weakly-consistent snapshot iterator of {@link MpscArrayQueue}:
 * it may miss concurrently-consumed elements, but hasNext()/next() must stay
 * consistent with each other.
 */
public class MpscArrayQueueSnapshotTest {

    // Restored the <Integer> type argument (field was declared raw).
    private MpscArrayQueue<Integer> queue;

    @Before
    public void setUp() throws Exception {
        this.queue = new MpscArrayQueue<>(4);
    }

    @Test
    public void testIterator() {
        queue.offer(0);
        assertThat(iteratorToList(), contains(0));
        for (int i = 1; i < queue.capacity(); i++) {
            queue.offer(i);
        }
        assertThat(iteratorToList(), containsInAnyOrder(0, 1, 2, 3));
        // Wrap the ring buffer: consume one slot, produce into it, consume again.
        queue.poll();
        queue.offer(4);
        queue.poll();
        assertThat(iteratorToList(), containsInAnyOrder(2, 3, 4));
    }

    @Test
    public void testIteratorHasNextConcurrentModification() {
        // There may be gaps in the elements returned by the iterator,
        // but hasNext needs to be reliable even if the elements are consumed
        // between hasNext() and next().
        queue.offer(0);
        queue.offer(1);
        Iterator<Integer> iter = queue.iterator();
        assertThat(iter.hasNext(), is(true));
        queue.poll();
        queue.poll();
        assertThat(queue.isEmpty(), is(true));
        // The iterator's snapshot still holds the first element it promised.
        assertThat(iter.hasNext(), is(true));
        assertThat(iter.next(), is(0));
        assertThat(iter.hasNext(), is(false));
    }

    /** Drains the queue's iterator into a list for order/content assertions. */
    private List<Integer> iteratorToList() {
        List<Integer> list = new ArrayList<>();
        Iterator<Integer> iter = queue.iterator();
        iter.forEachRemaining(list::add);
        return list;
    }

}
package org.jctools.queues;

// org.junit.Assert.assertThat is deprecated since JUnit 4.13; the Hamcrest
// MatcherAssert version is the drop-in, non-deprecated replacement.
import static org.hamcrest.CoreMatchers.hasItems;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.contains;
import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.hamcrest.Matchers.is;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import org.junit.Before;
import org.junit.Test;

/**
 * Verifies the weakly-consistent snapshot iterator of
 * {@link MpscUnboundedArrayQueue} across chunk boundaries: the iterator must
 * follow the consumer onto new backing buffers and keep hasNext()/next()
 * consistent under concurrent consumption.
 */
public class MpscUnboundedArrayQueueSnapshotTest {

    private static final int CHUNK_SIZE = 4;
    // Restored the <Integer> type argument (field was declared raw).
    private MpscUnboundedArrayQueue<Integer> queue;

    @Before
    public void setUp() throws Exception {
        this.queue = new MpscUnboundedArrayQueue<>(CHUNK_SIZE + 1); //Account for extra slot for JUMP
    }

    @Test
    public void testIterator() {
        queue.offer(0);
        assertThat(iteratorToList(), contains(0));
        for (int i = 1; i < CHUNK_SIZE; i++) {
            queue.offer(i);
        }
        assertThat(iteratorToList(), containsInAnyOrder(0, 1, 2, 3));
        // Overflow the first chunk so the queue links in a second buffer.
        queue.offer(4);
        queue.offer(5);
        assertThat(iteratorToList(), containsInAnyOrder(0, 1, 2, 3, 4, 5));
        queue.poll();
        assertThat(iteratorToList(), containsInAnyOrder(1, 2, 3, 4, 5));
        for (int i = 1; i < CHUNK_SIZE; i++) {
            queue.poll();
        }
        assertThat(iteratorToList(), containsInAnyOrder(4, 5));
    }

    @Test
    public void testIteratorOutpacedByConsumer() {
        int slotsToForceMultipleBuffers = CHUNK_SIZE + 1;
        for (int i = 0; i < slotsToForceMultipleBuffers; i++) {
            queue.offer(i);
        }
        Iterator<Integer> iter = queue.iterator();
        List<Integer> entries = new ArrayList<>();
        entries.add(iter.next());
        for (int i = 0; i < CHUNK_SIZE; i++) {
            queue.poll();
        }
        // Now that the consumer has discarded the first buffer, the iterator
        // needs to follow it to the new buffer.
        iter.forEachRemaining(entries::add);
        assertThat(entries, containsInAnyOrder(0, 1, 4));
    }

    @Test
    public void testIteratorHasNextConcurrentModification() {
        /*
         * There may be gaps in the elements returned by the iterator, but hasNext needs to be reliable even if the elements are consumed
         * between hasNext() and next(), and even if the consumer buffer changes.
         */
        int slotsToForceMultipleBuffers = CHUNK_SIZE + 1;
        for (int i = 0; i < slotsToForceMultipleBuffers; i++) {
            queue.offer(i);
        }
        Iterator<Integer> iter = queue.iterator();
        assertThat(iter.hasNext(), is(true));
        for (int i = 0; i < slotsToForceMultipleBuffers; i++) {
            queue.poll();
        }
        assertThat(queue.isEmpty(), is(true));
        // The iterator's snapshot still holds the first element it promised.
        assertThat(iter.hasNext(), is(true));
        assertThat(iter.next(), is(0));
        assertThat(iter.hasNext(), is(false));
    }

    /** Drains the queue's iterator into a list for order/content assertions. */
    private List<Integer> iteratorToList() {
        List<Integer> list = new ArrayList<>();
        Iterator<Integer> iter = queue.iterator();
        iter.forEachRemaining(list::add);
        return list;
    }

}
org.junit.Assert.*; +import static org.junit.Assume.assumeThat; + +public abstract class QueueSanityTest +{ + public static final int SIZE = 8192 * 2; + + protected final Queue queue; + protected final ConcurrentQueueSpec spec; + + public QueueSanityTest(ConcurrentQueueSpec spec, Queue queue) + { + this.queue = queue; + this.spec = spec; + } + + @After + public void clear() + { + queue.clear(); + assertThat(queue, emptyAndZeroSize()); + } + + @Test + public void toStringWorks() + { + assertNotNull(queue.toString()); + } + + @Test + public void sanity() + { + for (int i = 0; i < SIZE; i++) + { + assertNull(queue.poll()); + assertThat(queue, emptyAndZeroSize()); + } + int i = 0; + while (i < SIZE && queue.offer(i)) + { + i++; + } + int size = i; + assertEquals(size, queue.size()); + if (spec.ordering == Ordering.FIFO) + { + // expect FIFO + i = 0; + Integer p; + Integer e; + while ((p = queue.peek()) != null) + { + e = queue.poll(); + assertEquals(p, e); + assertEquals(size - (i + 1), queue.size()); + assertEquals(i++, e.intValue()); + } + assertEquals(size, i); + } + else + { + // expect sum of elements is (size - 1) * size / 2 = 0 + 1 + .... 
+ (size - 1) + int sum = (size - 1) * size / 2; + i = 0; + Integer e; + while ((e = queue.poll()) != null) + { + assertEquals(--size, queue.size()); + sum -= e; + } + assertEquals(0, sum); + } + assertNull(queue.poll()); + assertThat(queue, emptyAndZeroSize()); + } + + @Test + public void testSizeIsTheNumberOfOffers() + { + int currentSize = 0; + while (currentSize < SIZE && queue.offer(currentSize)) + { + currentSize++; + assertThat(queue, hasSize(currentSize)); + } + } + + @Test + public void whenFirstInThenFirstOut() + { + assumeThat(spec.ordering, is(Ordering.FIFO)); + + // Arrange + int i = 0; + while (i < SIZE && queue.offer(i)) + { + i++; + } + final int size = queue.size(); + + // Act + i = 0; + Integer prev; + while ((prev = queue.peek()) != null) + { + final Integer item = queue.poll(); + + assertThat(item, is(prev)); + assertThat(queue, hasSize(size - (i + 1))); + assertThat(item, is(i)); + i++; + } + + // Assert + assertThat(i, is(size)); + } + + @Test + public void test_FIFO_PRODUCER_Ordering() throws Exception + { + assumeThat(spec.ordering, is((Ordering.FIFO))); + + // Arrange + int i = 0; + while (i < SIZE && queue.offer(i)) + { + i++; + } + int size = queue.size(); + + // Act + // expect sum of elements is (size - 1) * size / 2 = 0 + 1 + .... 
+ (size - 1) + int sum = (size - 1) * size / 2; + Integer e; + while ((e = queue.poll()) != null) + { + size--; + assertThat(queue, hasSize(size)); + sum -= e; + } + + // Assert + assertThat(sum, is(0)); + } + + @Test(expected = NullPointerException.class) + public void offerNullResultsInNPE() + { + queue.offer(null); + } + + @Test + public void whenOfferItemAndPollItemThenSameInstanceReturnedAndQueueIsEmpty() + { + assertThat(queue, emptyAndZeroSize()); + + // Act + final Integer e = new Integer(1876876); + queue.offer(e); + assertFalse(queue.isEmpty()); + assertEquals(1, queue.size()); + + final Integer oh = queue.poll(); + assertEquals(e, oh); + + // Assert + assertThat(oh, sameInstance(e)); + assertThat(queue, emptyAndZeroSize()); + } + + @Test + public void testPowerOf2Capacity() + { + assumeThat(spec.isBounded(), is(true)); + int n = Pow2.roundToPowerOfTwo(spec.capacity); + + for (int i = 0; i < n; i++) + { + assertTrue("Failed to insert:" + i, queue.offer(i)); + } + assertFalse(queue.offer(n)); + } + + @Test + public void testQueueProgressIndicators() + { + assumeThat(queue, is(instanceOf(QueueProgressIndicators.class))); + QueueProgressIndicators q = (QueueProgressIndicators) queue; + // queue is empty + assertEquals(q.currentConsumerIndex(), q.currentProducerIndex()); + queue.offer(1); + assertEquals(q.currentConsumerIndex() + 1, q.currentProducerIndex()); + queue.poll(); + assertEquals(q.currentConsumerIndex(), q.currentProducerIndex()); + } + + @Test(timeout = TEST_TIMEOUT) + public void testHappensBeforePeek() throws Exception + { + testHappensBefore0(true); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testHappensBeforePoll() throws Exception + { + testHappensBefore0(false); + + } + + private void testHappensBefore0(boolean peek) throws InterruptedException + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + threads(() -> { + while (!stop.get()) + { + 
for (int i = 1; i <= 10; i++) + { + Val v = new Val(); + v.value = i; + q.offer(v); + } + // slow down the producer, this will make the queue mostly empty encouraging visibility issues. + Thread.yield(); + } + }, spec.producers, threads); + + threads(() -> { + while (!stop.get()) + { + for (int i = 0; i < 10; i++) + { + Val v = peek ? (Val) q.peek() : (Val) q.poll(); + if (v != null && v.value == 0) + { + // assert peek/poll visible values are never uninitialized + fail.value = 1; + stop.set(true); + break; + } + else if (peek && v != null && v != q.poll()) + { + // assert peek visible values are same as poll + fail.value = 2; + stop.set(true); + break; + } + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + assertEquals("reordering detected", 0, fail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testSize() throws Exception + { + final int capacity = !spec.isBounded() ? Integer.MAX_VALUE : spec.capacity; + + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + // each thread is adding and removing 1 element + threads(() -> { + try + { + while (!stop.get()) + { + // conditional is required when threads > capacity + if (q.offer(1)) ; + while (q.poll() == null && !stop.get()) ; + } + } + catch (Throwable t) + { + t.printStackTrace(); + fail.value++; + } + }, spec.isMpmc() ? 
0 : 1, threads); + int producersConsumers = threads.size(); + // observer + threads(() -> { + final int max = Math.min(producersConsumers, capacity); + while (!stop.get()) + { + int size = q.size(); + if (size < 0 || size > max) + { + fail.value++; + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + assertEquals("Unexpected size observed", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testSizeContendedFull() throws Exception + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + // this is fragile + int capacity = spec.capacity; + final Val fail = new Val(); + List threads = new ArrayList<>(); + + threads(() -> { + while (!stop.get()) + { + q.offer(1); + } + }, spec.producers, threads); + + threads(() -> { + while (!stop.get()) + { + q.poll(); + // slow down the consumer, this will make the queue mostly full + sleepQuietly(1); + } + }, spec.consumers, threads); + // observer + threads(() -> { + while (!stop.get()) + { + int size = q.size(); + if (size > capacity) + { + fail.value++; + } + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed no element in non-empty queue", 0, fail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekAfterIsEmpty1() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + if (!q.isEmpty() && q.peek() == null) + { + fail.value++; + } + q.poll(); + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekAfterIsEmpty2() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `isEmpty`? 
+ q.poll(); + if (!q.isEmpty() && q.peek() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekAfterIsEmpty3() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `size`? + q.poll(); + if (q.size() != 0 && q.peek() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollAfterIsEmpty1() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + if (!q.isEmpty() && q.poll() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollAfterIsEmpty2() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `isEmpty`? + q.poll(); + if (!q.isEmpty() && q.poll() == null) + { + fail.value++; + } + } + }); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollAfterIsEmpty3() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + testIsEmptyInvariant(stop, q, fail, () -> { + while (!stop.get()) + { + // can the consumer progress "passed" the producer and confuse `size`? 
+ q.poll(); + if (q.size() != 0 && q.poll() == null) + { + fail.value++; + } + } + }); + } + + private void testIsEmptyInvariant(AtomicBoolean stop, Queue q, Val fail, Runnable consumerLoop) + throws InterruptedException + { + List threads = new ArrayList<>(); + + threads(() -> { + while (!stop.get()) + { + q.offer(1); + // slow down the producer, this will make the queue mostly empty encouraging visibility issues. + Thread.yield(); + } + }, spec.producers, threads); + threads(consumerLoop, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed no element in non-empty queue", 0, fail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollOrderContendedFull() throws Exception + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + assumeThat(spec.ordering, is(Ordering.FIFO)); + + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + + final AtomicInteger pThreadId = new AtomicInteger(); + threads(() -> { + // store the thread id in the top 8 bits + int pId = pThreadId.getAndIncrement() << 24; + + int i = 0; + while (!stop.get() && i < Integer.MAX_VALUE >>> 8) + { + // clear the top 8 bits + int nextVal = (i << 8) >>> 8; + // set the pid in the top 8 bits + if (q.offer(nextVal ^ pId)) + { + i++; + } + } + }, spec.producers, threads); + int producers = threads.size(); + assertThat("The thread ID scheme above doesn't work for more than 256 threads", producers, lessThan(256)); + threads(() -> { + Integer[] lastPolledSequence = new Integer[producers]; + while (!stop.get()) + { + sleepQuietly(1); + final Integer polledSequenceAndTid = q.poll(); + if (polledSequenceAndTid == null) + { + continue; + } + int pTid = polledSequenceAndTid >>> 24; + int polledSequence = (polledSequenceAndTid << 8) >>> 8; + if (lastPolledSequence[pTid] != null && polledSequence - lastPolledSequence[pTid] < 0) + { + fail.value++; + } + + lastPolledSequence[pTid] = polledSequence; 
+ } + }, spec.consumers, threads); + + startWaitJoin(stop, threads); + + assertEquals("Polled elements out of order", 0, fail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testPeekOrderContendedFull() throws Exception + { + // only for multi consumers as we need a separate peek/poll thread here + assumeThat(spec.isBounded() && (spec.isMpmc() || spec.isSpmc()), is(Boolean.TRUE)); + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final Val fail = new Val(); + List threads = new ArrayList<>(); + + final AtomicInteger pThreadId = new AtomicInteger(); + threads(() -> { + // store the thread id in the top 8 bits + int pId = pThreadId.getAndIncrement() << 24; + + int i = 0; + while (!stop.get() && i < Integer.MAX_VALUE >>> 8) + { + // clear the top 8 bits + int nextVal = (i << 8) >>> 8; + // set the pid in the top 8 bits + if (q.offer(nextVal ^ pId)) + { + i++; + } + } + }, spec.producers, threads); + int producers = threads.size(); + assertThat("The thread ID scheme above doesn't work for more than 256 threads", producers, lessThan(256)); + + threads(() -> { + while (!stop.get()) + { + q.poll(); + // slow down the consumer, this will make the queue mostly full + sleepQuietly(1); + } + }, spec.consumers, threads); + // observer + threads(() -> { + Integer[] lastPeekedSequence = new Integer[producers]; + while (!stop.get()) + { + final Integer peekedSequenceAndTid = q.peek(); + if (peekedSequenceAndTid == null) + { + continue; + } + int pTid = peekedSequenceAndTid >>> 24; + int peekedSequence = (peekedSequenceAndTid << 8) >>> 8; + + if (lastPeekedSequence[pTid] != null && peekedSequence - lastPeekedSequence[pTid] < 0) + { + fail.value++; + } + + lastPeekedSequence[pTid] = peekedSequence; + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Peeked elements out of order", 0, fail.value); + } + + @Test + public void testIterator() + { + assumeThat(queue, instanceOf(SupportsIterator.class)); + assumeThat(queue, 
instanceOf(MessagePassingQueue.class)); + + int capacity = ((MessagePassingQueue) queue).capacity(); + int insertLimit = (capacity == UNBOUNDED_CAPACITY) ? 128 : capacity; + + for (int i = 0; i < insertLimit; i++) + { + queue.offer(i); + } + + Iterator iterator = queue.iterator(); + for (int i = 0; i < insertLimit; i++) + { + assertEquals(Integer.valueOf(i), iterator.next()); + } + assertTrue((capacity == UNBOUNDED_CAPACITY) || !iterator.hasNext()); + + queue.poll(); // drop 0 + queue.offer(insertLimit); // add capacity + iterator = queue.iterator(); + for (int i = 1; i <= insertLimit; i++) + { + assertEquals(Integer.valueOf(i), iterator.next()); + } + assertTrue((capacity == UNBOUNDED_CAPACITY) || !iterator.hasNext()); + } + + @Test + public void testIteratorHasNextConcurrentModification() + { + assumeThat(queue, instanceOf(SupportsIterator.class)); + assumeThat(queue, instanceOf(MessagePassingQueue.class)); + int capacity = ((MessagePassingQueue) queue).capacity(); + if (capacity != UNBOUNDED_CAPACITY) + { + assumeThat(capacity, greaterThanOrEqualTo(2)); + } + //There may be gaps in the elements returned by the iterator, + //but hasNext needs to be reliable even if the elements are consumed between hasNext() and next(). 
+ queue.offer(0); + queue.offer(1); + Iterator iter = queue.iterator(); + assertThat(iter.hasNext(), is(true)); + queue.poll(); + queue.poll(); + assertThat(queue.isEmpty(), is(true)); + assertThat(iter.hasNext(), is(true)); + assertThat(iter.next(), is(0)); + assertThat(iter.hasNext(), is(false)); + } + + @Test(timeout = TEST_TIMEOUT) + public void testSizeLtZero() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final List threads = new ArrayList<>(); + + // producer check size and offer + final Val pFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if (q.size() < 0) + { + pFail.value++; + } + + q.offer(1); + TestUtil.sleepQuietly(1); + } + }, spec.producers, threads); + + // consumer poll and check size + final Val cFail = new Val(); + threads(() -> { + while (!stop.get()) + { + q.poll(); + + if (q.size() < 0) + { + cFail.value++; + } + } + }, spec.consumers, threads); + + // observer check size + final Val oFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if (q.size() < 0) + { + oFail.value++; + } + TestUtil.sleepQuietly(1); + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed producer size < 0", 0, pFail.value); + assertEquals("Observed consumer size < 0", 0, cFail.value); + assertEquals("Observed observer size < 0", 0, oFail.value); + } + + @Test(timeout = TEST_TIMEOUT) + public void testSizeGtCapacity() throws Exception + { + assumeThat(spec.isBounded(), is(Boolean.TRUE)); + + final int capacity = spec.capacity; + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + final List threads = new ArrayList<>(); + + // producer offer and check size + final Val pFail = new Val(); + threads(() -> { + while (!stop.get()) + { + q.offer(1); + + if (q.size() > capacity) + { + pFail.value++; + } + } + }, spec.producers, threads); + + // consumer check size and poll + final Val cFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if 
(q.size() > capacity) + { + cFail.value++; + } + + q.poll(); + TestUtil.sleepQuietly(1); + } + }, spec.consumers, threads); + + // observer check size + final Val oFail = new Val(); + threads(() -> { + while (!stop.get()) + { + if (q.size() > capacity) + { + oFail.value++; + } + TestUtil.sleepQuietly(1); + } + }, 1, threads); + + startWaitJoin(stop, threads); + + assertEquals("Observed producer size > capacity", 0, pFail.value); + assertEquals("Observed consumer size > capacity", 0, cFail.value); + assertEquals("Observed observer size > capacity", 0, oFail.value); + } + +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcArray.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcArray.java new file mode 100644 index 0000000..59068be --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcArray.java @@ -0,0 +1,100 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.util.TestUtil.*; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.jctools.util.TestUtil.*; +import static org.junit.Assert.assertEquals; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpmcArray extends QueueSanityTest +{ + public QueueSanityTestMpmcArray(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + // Mpmc minimal size is 2 + list.add(makeMpq(0, 0, 2, Ordering.FIFO)); + list.add(makeMpq(0, 0, SIZE, Ordering.FIFO)); + list.add(makeAtomic(0, 0, 2, Ordering.FIFO)); + list.add(makeAtomic(0, 0, SIZE, Ordering.FIFO)); + list.add(makeUnpadded(0, 0, 2, Ordering.FIFO)); + 
list.add(makeUnpadded(0, 0, SIZE, Ordering.FIFO)); + return list; + } + + @Test + public void testOfferPollSemantics() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final Queue q = queue; + // fill up the queue + while (q.offer(1)) + { + ; + } + // queue has 2 empty slots + q.poll(); + q.poll(); + + final Val fail = new Val(); + Thread t1 = new Thread(new Runnable() + { + @Override + public void run() + { + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + if (q.poll() == null) + { + fail.value++; + } + } + } + }); + Thread t2 = new Thread(new Runnable() + { + @Override + public void run() + { + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + if (q.poll() == null) + { + fail.value++; + } + } + } + }); + + t1.start(); + t2.start(); + Thread.sleep(1000); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected offer/poll observed", 0, fail.value); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcUnboundedXadd.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcUnboundedXadd.java new file mode 100644 index 0000000..7a3071a --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpmcUnboundedXadd.java @@ -0,0 +1,38 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpmcUnboundedXadd extends QueueSanityTest +{ + public QueueSanityTestMpmcUnboundedXadd(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 0, 0, 
Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 0))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 0))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 1))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 1))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 2))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 2))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 3))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 3))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(1, 4))); + list.add(makeParams(0, 0, 0, Ordering.FIFO, new MpmcUnboundedXaddArrayQueue<>(16, 4))); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArray.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArray.java new file mode 100644 index 0000000..22ba92e --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArray.java @@ -0,0 +1,98 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.util.TestUtil.*; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.hamcrest.Matchers.is; +import static org.jctools.util.TestUtil.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assume.assumeThat; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscArray extends QueueSanityTest +{ + public QueueSanityTestMpscArray(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + 
@Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(0, 1, 1, Ordering.FIFO)); + list.add(makeMpq(0, 1, 1, Ordering.FIFO)); + list.add(makeMpq(0, 1, SIZE, Ordering.FIFO)); + list.add(makeAtomic(0, 1, 1, Ordering.FIFO)); + list.add(makeAtomic(0, 1, 2, Ordering.FIFO)); + list.add(makeAtomic(0, 1, SIZE, Ordering.FIFO)); + list.add(makeUnpadded(0, 1, 1, Ordering.FIFO)); + list.add(makeUnpadded(0, 1, 2, Ordering.FIFO)); + list.add(makeUnpadded(0, 1, SIZE, Ordering.FIFO)); + + return list; + } + + @Test + public void testOfferPollSemantics() throws Exception + { + assumeThat(spec.capacity, is(2)); + final AtomicBoolean stop = new AtomicBoolean(); + final AtomicBoolean consumerLock = new AtomicBoolean(true); + final Queue q = this.queue; + // fill up the queue + while (q.offer(1)) + { + ; + } + // queue has 2 empty slots + q.poll(); + q.poll(); + + final Val fail = new Val(); + final Runnable runnable = new Runnable() + { + @Override + public void run() + { + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + + while (!consumerLock.compareAndSet(true, false)) + { + ; + } + if (q.poll() == null) + { + fail.value++; + } + consumerLock.lazySet(true); + } + } + }; + Thread t1 = new Thread(runnable); + Thread t2 = new Thread(runnable); + + t1.start(); + t2.start(); + Thread.sleep(1000); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected offer/poll observed", 0, fail.value); + + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArrayExtended.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArrayExtended.java new file mode 100644 index 0000000..69ebeb0 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscArrayExtended.java @@ -0,0 +1,87 @@ +package org.jctools.queues; + +import org.jctools.util.TestUtil.Val; +import org.junit.Assert; +import org.junit.Test; + +import 
java.util.Queue; +import java.util.concurrent.atomic.AtomicBoolean; + +import static org.junit.Assert.assertEquals; + +public class QueueSanityTestMpscArrayExtended +{ + @Test + public void testOfferWithThreshold() + { + MpscArrayQueue queue = new MpscArrayQueue(16); + int i; + for (i = 0; i < 8; ++i) + { + //Offers succeed because current size is below the HWM. + Assert.assertTrue(queue.offerIfBelowThreshold(i, 8)); + } + //Not anymore, our offer got rejected. + Assert.assertFalse(queue.offerIfBelowThreshold(i, 8)); + Assert.assertFalse(queue.offerIfBelowThreshold(i, 7)); + Assert.assertFalse(queue.offerIfBelowThreshold(i, 1)); + Assert.assertFalse(queue.offerIfBelowThreshold(i, 0)); + + //Also, the threshold is dynamic and different levels can be set for + //different task priorities. + Assert.assertTrue(queue.offerIfBelowThreshold(i, 9)); + Assert.assertTrue(queue.offerIfBelowThreshold(i, 16)); + } + + @Test + public void testOfferPollSemantics() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final AtomicBoolean consumerLock = new AtomicBoolean(true); + final Queue q = new MpscArrayQueue(2); + // fill up the queue + while (q.offer(1)) + { + ; + } + // queue has 2 empty slots + q.poll(); + q.poll(); + + final Val fail = new Val(); + final Runnable runnable = new Runnable() + { + @Override + public void run() + { + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + + while (!consumerLock.compareAndSet(true, false)) + { + ; + } + if (q.poll() == null) + { + fail.value++; + } + consumerLock.lazySet(true); + } + } + }; + Thread t1 = new Thread(runnable); + Thread t2 = new Thread(runnable); + + t1.start(); + t2.start(); + Thread.sleep(1000); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected offer/poll observed", 0, fail.value); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumer.java 
b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumer.java new file mode 100644 index 0000000..bfd69f5 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumer.java @@ -0,0 +1,30 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscBlockingConsumer extends QueueSanityTest +{ + public QueueSanityTestMpscBlockingConsumer(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList<>(); + list.add(makeParams(0, 1, 2, Ordering.FIFO, new MpscBlockingConsumerArrayQueue<>(2))); + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscBlockingConsumerArrayQueue<>(SIZE))); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerArrayExtended.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerArrayExtended.java new file mode 100644 index 0000000..e7427e2 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerArrayExtended.java @@ -0,0 +1,513 @@ +package org.jctools.queues; + +import java.lang.Thread.State; +import java.util.ArrayDeque; +import java.util.ArrayList; +import java.util.Queue; +import java.util.concurrent.ThreadLocalRandom; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.LockSupport; + +import org.jctools.util.TestUtil.Val; +import org.junit.Assert; 
+import org.junit.Test; + +import static java.util.concurrent.TimeUnit.*; +import static org.jctools.util.TestUtil.CONCURRENT_TEST_DURATION; +import static org.jctools.util.TestUtil.TEST_TIMEOUT; +import static org.junit.Assert.*; + + +public class QueueSanityTestMpscBlockingConsumerArrayExtended +{ + @Test + public void testOfferPollSemantics() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final AtomicBoolean consumerLock = new AtomicBoolean(true); + final Queue q = new MpscBlockingConsumerArrayQueue<>(2); + // fill up the queue + while (q.offer(1)); + + // queue has 2 empty slots + q.poll(); + q.poll(); + + final Val fail = new Val(); + final Runnable runnable = () -> { + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + + while (!consumerLock.compareAndSet(true, false)); + + + if (q.poll() == null) + { + fail.value++; + } + consumerLock.lazySet(true); + } + }; + + Thread t1 = new Thread(runnable); + Thread t2 = new Thread(runnable); + + t1.start(); + t2.start(); + Thread.sleep(1000); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected offer/poll observed", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollTimeout() throws InterruptedException { + + final MpscBlockingConsumerArrayQueue queue = + new MpscBlockingConsumerArrayQueue<>(128000); + + final Thread consumerThread = new Thread(() -> { + try { + while (true) { + queue.poll(100, TimeUnit.NANOSECONDS); + } + } + catch (InterruptedException e) { + } + }); + + consumerThread.start(); + + final Thread producerThread = new Thread(() -> { + while (!Thread.interrupted()) { + for (int i = 0; i < 10; ++i) { + queue.offer("x"); + } + LockSupport.parkNanos(100000); + } + }); + + producerThread.start(); + + Thread.sleep(CONCURRENT_TEST_DURATION); + + consumerThread.interrupt(); + consumerThread.join(); + + producerThread.interrupt(); + producerThread.join(); + } + + @Test(timeout = TEST_TIMEOUT) + public void 
testOfferTakeSemantics() throws Exception + { + testOfferBlockSemantics(false); + } + + @Test(timeout = TEST_TIMEOUT) + public void testOfferPollWithTimeoutSemantics() throws Exception + { + testOfferBlockSemantics(true); + } + + @Test(timeout = TEST_TIMEOUT) + public void testOfferBlockingDrainSemantics() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final AtomicBoolean consumerLock = new AtomicBoolean(true); + final MpscBlockingConsumerArrayQueue q = new MpscBlockingConsumerArrayQueue<>(2); + // fill up the queue + while (q.offer(1)); + + // queue has 2 empty slots + q.poll(); + q.poll(); + + final Val fail = new Val(); + final Runnable runnable = () -> { + ArrayDeque ints = new ArrayDeque<>(1); + + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + + while (!consumerLock.compareAndSet(true, false)); + + try + { + int howMany = q.drain(ints::offer, 1, 1L, DAYS); + if (howMany == 0 || howMany != ints.size()) + { + fail.value++; + } + ints.clear(); + } + catch (InterruptedException e) + { + fail.value++; + } + consumerLock.lazySet(true); + } + }; + Thread t1 = new Thread(runnable); + Thread t2 = new Thread(runnable); + + t1.start(); + t2.start(); + Thread.sleep(CONCURRENT_TEST_DURATION); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected offer/poll observed", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testBlockingDrainSemantics() throws Exception + { + final MpscBlockingConsumerArrayQueue q = new MpscBlockingConsumerArrayQueue<>(2); + ArrayDeque ints = new ArrayDeque<>(); + assertEquals(0, q.drain(ints::add, 0, 0, NANOSECONDS)); + assertEquals(0, ints.size()); + + q.offer(1); + + assertEquals(0, q.drain(ints::add, 0, 0, NANOSECONDS)); + assertEquals(0, ints.size()); + assertEquals(1, q.drain(ints::add, 1, 0, NANOSECONDS)); + assertEquals((Integer) 1, ints.poll()); + + long beforeNanos = System.nanoTime(); + assertEquals(0, q.drain(ints::add, 1, 250L, MILLISECONDS)); + long 
tookMillis = MILLISECONDS.convert(System.nanoTime() - beforeNanos, NANOSECONDS); + assertEquals(0, ints.size()); + assertTrue("took " + tookMillis + "ms", 200L < tookMillis && tookMillis < 300L); + } + + private void testOfferBlockSemantics(boolean withTimeout) throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final AtomicBoolean consumerLock = new AtomicBoolean(true); + final MpscBlockingConsumerArrayQueue q = new MpscBlockingConsumerArrayQueue<>(2); + // fill up the queue + while (q.offer(1)); + + // queue has 2 empty slots + q.poll(); + q.poll(); + + final Val fail = new Val(); + final Runnable runnable = () -> { + while (!stop.get()) + { + if (!q.offer(1)) + { + fail.value++; + } + + while (!consumerLock.compareAndSet(true, false)); + + try + { + Integer take = withTimeout ? q.poll(1L, DAYS) : q.take(); + if (take == null) + { + fail.value++; + } + } + catch (InterruptedException e) + { + fail.value++; + } + consumerLock.lazySet(true); + } + }; + Thread t1 = new Thread(runnable); + Thread t2 = new Thread(runnable); + + t1.start(); + t2.start(); + Thread.sleep(CONCURRENT_TEST_DURATION); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected offer/poll observed", 0, fail.value); + + } + + @Test(timeout = TEST_TIMEOUT) + public void testPollTimeoutSemantics() throws Exception + { + final MpscBlockingConsumerArrayQueue q = new MpscBlockingConsumerArrayQueue<>(2); + + assertNull(q.poll(0, NANOSECONDS)); + + q.offer(1); + assertEquals((Integer) 1, q.poll(0, NANOSECONDS)); + + long beforeNanos = System.nanoTime(); + assertNull(q.poll(250L, MILLISECONDS)); + long tookMillis = MILLISECONDS.convert(System.nanoTime() - beforeNanos, NANOSECONDS); + + assertTrue("took " + tookMillis + "ms", 200L < tookMillis && tookMillis < 300L); + } + + @Test(timeout = TEST_TIMEOUT) + public void testTakeBlocksAndIsInterrupted() throws Exception + { + testTakeBlocksAndIsInterrupted(PollType.Take); + } + + @Test(timeout = TEST_TIMEOUT) + public void 
testPollWithTimeoutBlocksAndIsInterrupted() throws Exception + { + testTakeBlocksAndIsInterrupted(PollType.BlockingPoll); + } + + @Test(timeout = TEST_TIMEOUT) + public void testBlockingDrainWithTimeoutBlocksAndIsInterrupted() throws Exception + { + testTakeBlocksAndIsInterrupted(PollType.BlockingDrain); + } + + private enum PollType { + BlockingPoll, + Take, + BlockingDrain + } + + private void testTakeBlocksAndIsInterrupted(PollType pollType) throws Exception + { + final AtomicBoolean wasInterrupted = new AtomicBoolean(); + final AtomicBoolean interruptedStatusAfter = new AtomicBoolean(); + final MpscBlockingConsumerArrayQueue q = new MpscBlockingConsumerArrayQueue<>(1024); + Thread consumer = new Thread(() -> { + try + { + switch (pollType) { + + case BlockingPoll: + q.poll(1L, DAYS); + break; + case Take: + q.take(); + break; + case BlockingDrain: + q.drain(ignored -> {}, 1, 1L, DAYS); + break; + } + } + catch (InterruptedException e) + { + wasInterrupted.set(true); + } + interruptedStatusAfter.set(Thread.currentThread().isInterrupted()); + }); + consumer.setDaemon(true); + consumer.start(); + while(consumer.getState() != State.TIMED_WAITING) + { + Thread.yield(); + } + // If we got here -> thread got to the waiting state -> parked + consumer.interrupt(); + consumer.join(); + assertTrue(wasInterrupted.get()); + assertFalse(interruptedStatusAfter.get()); + + // Queue should remain in original state (empty) + assertNull(q.poll()); + } + + @Test(timeout = TEST_TIMEOUT) + public void testTakeSomeElementsThenBlocksAndIsInterrupted() throws Exception + { + testTakeSomeElementsThenBlocksAndIsInterrupted(false); + } + + @Test(timeout = TEST_TIMEOUT) + public void testTakeSomeElementsThenPollWithTimeoutAndIsInterrupted() throws Exception + { + testTakeSomeElementsThenBlocksAndIsInterrupted(true); + } + + private void testTakeSomeElementsThenBlocksAndIsInterrupted(boolean withTimeout) throws Exception + { + Val v = new Val(); + final AtomicBoolean wasInterrupted = new 
AtomicBoolean(); + final MpscBlockingConsumerArrayQueue q = new MpscBlockingConsumerArrayQueue<>(1024); + Thread consumer = new Thread(() -> { + try + { + while (true) + { + Integer take = withTimeout ? q.poll(1L, DAYS) : q.take(); + assertNotNull(take); // take never returns null + assertEquals(take.intValue(), v.value); + v.value++; + } + } + catch (InterruptedException e) + { + wasInterrupted.set(true); + } + }); + consumer.setDaemon(true); + consumer.start(); + while(consumer.getState() != State.TIMED_WAITING) + { + Thread.yield(); + } + // If we got here -> thread got to the waiting state -> parked + int someElements = ThreadLocalRandom.current().nextInt(10000); + for (int i=0;i < someElements; i++) + while (!q.offer(i)); + + while(!q.isEmpty()) + { + Thread.yield(); + } + // Eventually queue is drained + + while(consumer.getState() != State.TIMED_WAITING) + { + Thread.yield(); + } + // If we got here -> thread got to the waiting state -> parked + + consumer.interrupt(); + consumer.join(); + assertTrue(wasInterrupted.get()); + assertEquals(someElements, v.value); + } + + @Test + public void testOfferIfBelowThresholdSemantics() throws Exception + { + final AtomicBoolean stop = new AtomicBoolean(); + final MpscBlockingConsumerArrayQueue q = + new MpscBlockingConsumerArrayQueue<>(8); + + final Val fail = new Val(); + + Thread t1 = new Thread(() -> { + while (!stop.get()) + { + q.poll(); + + if (q.size() > 5) + { + fail.value++; + } + } + }); + + Thread t2 = new Thread(() -> { + while (!stop.get()) + { + q.offerIfBelowThreshold(1, 5); + } + }); + + t1.start(); + t2.start(); + Thread.sleep(1000); + stop.set(true); + t1.join(); + t2.join(); + assertEquals("Unexpected size observed", 0, fail.value); + } + + @Test + public void testOfferWithThreshold() + { + MpscBlockingConsumerArrayQueue queue = new MpscBlockingConsumerArrayQueue(16); + int i; + for (i = 0; i < 8; ++i) + { + //Offers succeed because current size is below the HWM. 
+ Assert.assertTrue(queue.offerIfBelowThreshold(i, 8)); + } + //Not anymore, our offer got rejected. + Assert.assertFalse(queue.offerIfBelowThreshold(i, 8)); + Assert.assertFalse(queue.offerIfBelowThreshold(i, 7)); + Assert.assertFalse(queue.offerIfBelowThreshold(i, 1)); + Assert.assertFalse(queue.offerIfBelowThreshold(i, 0)); + + //Also, the threshold is dynamic and different levels can be set for + //different task priorities. + Assert.assertTrue(queue.offerIfBelowThreshold(i, 9)); + Assert.assertTrue(queue.offerIfBelowThreshold(i, 16)); + } + + /** + * This test demonstrates a race where a producer wins the CAS and writes + * null to blocked, only to have the consumer overwrite its change. To have + * it fail consistently, add a 1ms sleep (or a breakpoint) to + * parkUntilNext(), just before the soBlocked(Thread.currentThread()). If it + * hits the race writing to the blocked field, one of the threads will spin + * forever in spinWaitForUnblock(). + */ + @Test(timeout = TEST_TIMEOUT) + public void testSpinWaitForUnblockForever() throws InterruptedException { + + class Echo implements Runnable{ + private MpscBlockingConsumerArrayQueue source; + private MpscBlockingConsumerArrayQueue sink; + private int interations; + + Echo( + MpscBlockingConsumerArrayQueue source, + MpscBlockingConsumerArrayQueue sink, + int interations) { + this.source = source; + this.sink = sink; + this.interations = interations; + } + + public void run() { + try { + for (int i = 0; i < interations; ++i) { + T t; + do { + t = source.poll(1, TimeUnit.NANOSECONDS); + } + while (t == null); + + sink.put(t); + } + } + catch (InterruptedException e) { + throw new AssertionError(e); + } + } + } + + final MpscBlockingConsumerArrayQueue q1 = + new MpscBlockingConsumerArrayQueue<>(1024); + final MpscBlockingConsumerArrayQueue q2 = + new MpscBlockingConsumerArrayQueue<>(1024); + + final Thread t1 = new Thread(new Echo<>(q1, q2, 100000)); + final Thread t2 = new Thread(new Echo<>(q2, q1, 100000)); + + 
t1.start(); + t2.start(); + + q1.put("x"); + + t1.join(); + t2.join(); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerOfferBelowThreshold.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerOfferBelowThreshold.java new file mode 100644 index 0000000..d29d709 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscBlockingConsumerOfferBelowThreshold.java @@ -0,0 +1,68 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.Ignore; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscBlockingConsumerOfferBelowThreshold extends QueueSanityTest +{ + public QueueSanityTestMpscBlockingConsumerOfferBelowThreshold(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + MpscBlockingConsumerArrayQueueOverride q = new MpscBlockingConsumerArrayQueueOverride(16); + list.add(makeParams(0, 1, 8, Ordering.FIFO, q)); + q = new MpscBlockingConsumerArrayQueueOverride(16); + q.threshold = 12; + list.add(makeParams(0, 1, 12, Ordering.FIFO, q)); + q = new MpscBlockingConsumerArrayQueueOverride(16); + q.threshold = 4; + list.add(makeParams(0, 1, 4, Ordering.FIFO, q)); + return list; + } + + @Ignore + public void testPowerOf2Capacity() + { + } + + @Ignore + public void testIterator() + { + } + + /** + * This allows us to test the offersIfBelowThreshold through all the offer utilizing threads. The effect should be + * as if the queue capacity is halved. 
+ */ + static class MpscBlockingConsumerArrayQueueOverride extends MpscBlockingConsumerArrayQueue + { + int threshold; + + public MpscBlockingConsumerArrayQueueOverride(int capacity) + { + super(capacity); + threshold = capacity() / 2; + } + + @Override + public boolean offer(E e) + { + return super.offerIfBelowThreshold(e, threshold); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunked.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunked.java new file mode 100644 index 0000000..5593156 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunked.java @@ -0,0 +1,39 @@ +package org.jctools.queues; + +import org.jctools.queues.atomic.MpscChunkedAtomicArrayQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.queues.unpadded.MpscChunkedUnpaddedArrayQueue; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscChunked extends QueueSanityTestMpscArray +{ + public QueueSanityTestMpscChunked(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscChunkedArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscChunkedArrayQueue<>(8, SIZE)));// MPSC size SIZE + list.add(makeParams(0, 1, 4096, Ordering.FIFO, new MpscChunkedArrayQueue<>(32, 4096)));// Netty recycler defaults + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscChunkedAtomicArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscChunkedAtomicArrayQueue<>(8, SIZE)));// 
MPSC size SIZE + list.add(makeParams(0, 1, 4096, Ordering.FIFO, new MpscChunkedAtomicArrayQueue<>(32, 4096)));// Netty recycler defaults + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscChunkedUnpaddedArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscChunkedUnpaddedArrayQueue<>(8, SIZE)));// MPSC size SIZE + list.add(makeParams(0, 1, 4096, Ordering.FIFO, new MpscChunkedUnpaddedArrayQueue<>(32, 4096)));// Netty recycler defaults + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunkedExtended.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunkedExtended.java new file mode 100644 index 0000000..f882634 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscChunkedExtended.java @@ -0,0 +1,16 @@ +package org.jctools.queues; + +import org.junit.Test; + +public class QueueSanityTestMpscChunkedExtended +{ + @Test + public void testMaxSizeQueue() + { + MpscChunkedArrayQueue queue = new MpscChunkedArrayQueue(1024, 1000 * 1024 * 1024); + for (int i = 0; i < 400001; i++) + { + queue.offer(i); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscCompound.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscCompound.java new file mode 100644 index 0000000..4c3b037 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscCompound.java @@ -0,0 +1,32 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.util.Pow2; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.PortableJvmInfo.CPUs; +import static org.jctools.util.TestUtil.makeMpq; + +@RunWith(Parameterized.class) +public class 
QueueSanityTestMpscCompound extends QueueSanityTest +{ + public QueueSanityTestMpscCompound(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(0, 1, Pow2.roundToPowerOfTwo(CPUs), Ordering.NONE)); + list.add(makeMpq(0, 1, SIZE, Ordering.NONE)); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscGrowable.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscGrowable.java new file mode 100644 index 0000000..bccff5f --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscGrowable.java @@ -0,0 +1,30 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscGrowable extends QueueSanityTestMpscArray +{ + public QueueSanityTestMpscGrowable(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 1, 4, Ordering.FIFO, new MpscGrowableArrayQueue<>(2, 4)));// MPSC size 1 + list.add(makeParams(0, 1, SIZE, Ordering.FIFO, new MpscGrowableArrayQueue<>(8, SIZE)));// MPSC size SIZE + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscLinked.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscLinked.java new file mode 100644 index 0000000..01f6c28 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscLinked.java @@ -0,0 +1,31 @@ +package 
org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.*; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscLinked extends QueueSanityTest +{ + public QueueSanityTestMpscLinked(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(0, 1, 0, Ordering.FIFO)); + list.add(makeAtomic(0, 1, 0, Ordering.FIFO)); + list.add(makeUnpadded(0, 1, 0, Ordering.FIFO)); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscOfferBelowThreshold.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscOfferBelowThreshold.java new file mode 100644 index 0000000..8f791e1 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscOfferBelowThreshold.java @@ -0,0 +1,68 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.Ignore; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscOfferBelowThreshold extends QueueSanityTest +{ + public QueueSanityTestMpscOfferBelowThreshold(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + MpscArrayQueueOverride q = new MpscArrayQueueOverride(16); + list.add(makeParams(0, 1, 8, Ordering.FIFO, q)); + q = 
new MpscArrayQueueOverride(16); + q.threshold = 12; + list.add(makeParams(0, 1, 12, Ordering.FIFO, q)); + q = new MpscArrayQueueOverride(16); + q.threshold = 4; + list.add(makeParams(0, 1, 4, Ordering.FIFO, q)); + return list; + } + + @Ignore + public void testPowerOf2Capacity() + { + } + + @Ignore + public void testIterator() + { + } + + /** + * This allows us to test the offerIfBelowThreshold method through all the offer-utilizing threads. The effect should be + * as if the queue capacity is halved. + */ + static class MpscArrayQueueOverride extends MpscArrayQueue + { + int threshold; + + public MpscArrayQueueOverride(int capacity) + { + super(capacity); + threshold = capacity() / 2; + } + + @Override + public boolean offer(E e) + { + return super.offerIfBelowThreshold(e, threshold); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedArray.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedArray.java new file mode 100644 index 0000000..04b20a5 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedArray.java @@ -0,0 +1,36 @@ +package org.jctools.queues; + +import org.jctools.queues.atomic.MpscUnboundedAtomicArrayQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.queues.unpadded.MpscUnboundedUnpaddedArrayQueue; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscUnboundedArray extends QueueSanityTest +{ + public QueueSanityTestMpscUnboundedArray(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 1, 0, 
Ordering.FIFO, new MpscUnboundedArrayQueue<>(2)));// MPSC size 1 + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedArrayQueue<>(64)));// MPSC size SIZE + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedAtomicArrayQueue<>(2)));// MPSC size 1 + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedAtomicArrayQueue<>(64)));// MPSC size SIZE + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedUnpaddedArrayQueue<>(2)));// MPSC size 1 + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedUnpaddedArrayQueue<>(64)));// MPSC size SIZE + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedXadd.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedXadd.java new file mode 100644 index 0000000..4b394b5 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestMpscUnboundedXadd.java @@ -0,0 +1,36 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestMpscUnboundedXadd extends QueueSanityTest +{ + public QueueSanityTestMpscUnboundedXadd(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 0))); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 0))); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 1))); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 1))); + 
list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 2))); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 2))); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(1, 3))); + list.add(makeParams(0, 1, 0, Ordering.FIFO, new MpscUnboundedXaddArrayQueue<>(64, 3))); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpmcArray.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpmcArray.java new file mode 100644 index 0000000..b79617e --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpmcArray.java @@ -0,0 +1,36 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.*; + +@RunWith(Parameterized.class) +public class QueueSanityTestSpmcArray extends QueueSanityTest +{ + public QueueSanityTestSpmcArray(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + // need at least size 2 for this test + list.add(makeMpq(1, 0, 1, Ordering.FIFO)); + list.add(makeMpq(1, 0, SIZE, Ordering.FIFO)); + list.add(makeAtomic(1, 0, 1, Ordering.FIFO)); + list.add(makeAtomic(1, 0, SIZE, Ordering.FIFO)); + list.add(makeUnpadded(1, 0, 1, Ordering.FIFO)); + list.add(makeUnpadded(1, 0, SIZE, Ordering.FIFO)); + + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArray.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArray.java new file mode 100644 index 0000000..48b94ae --- /dev/null +++ 
b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArray.java @@ -0,0 +1,34 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.*; + +@RunWith(Parameterized.class) +public class QueueSanityTestSpscArray extends QueueSanityTest +{ + public QueueSanityTestSpscArray(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(1, 1, 4, Ordering.FIFO)); + list.add(makeMpq(1, 1, SIZE, Ordering.FIFO)); + list.add(makeAtomic(1, 1, 4, Ordering.FIFO)); + list.add(makeAtomic(1, 1, SIZE, Ordering.FIFO)); + list.add(makeUnpadded(1, 1, 4, Ordering.FIFO)); + list.add(makeUnpadded(1, 1, SIZE, Ordering.FIFO)); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArrayExtended.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArrayExtended.java new file mode 100644 index 0000000..91b2321 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscArrayExtended.java @@ -0,0 +1,45 @@ +package org.jctools.queues; + +import org.junit.Test; + +import static org.hamcrest.Matchers.*; +import static org.jctools.queues.matchers.Matchers.emptyAndZeroSize; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +public class QueueSanityTestSpscArrayExtended +{ + @Test + public void shouldWorkAfterWrap() + { + // Arrange + final SpscArrayQueue q = new SpscArrayQueue(1024); + // starting point for empty queue at max long, next offer will wrap the producerIndex + q.soConsumerIndex(Long.MAX_VALUE); + q.soProducerIndex(Long.MAX_VALUE); + 
q.producerLimit = Long.MAX_VALUE; + // valid starting point + assertThat(q, emptyAndZeroSize()); + + // Act + // assert offer is successful + final Object e = new Object(); + assertTrue(q.offer(e)); + // size is computed correctly after wrap + assertThat(q, not(emptyAndZeroSize())); + assertThat(q, hasSize(1)); + + // now consumer index wraps + final Object poll = q.poll(); + assertThat(poll, sameInstance(e)); + assertThat(q, emptyAndZeroSize()); + + // let's go again + assertTrue(q.offer(e)); + assertThat(q, not(emptyAndZeroSize())); + + final Object poll2 = q.poll(); + assertThat(poll2, sameInstance(e)); + assertThat(q, emptyAndZeroSize()); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunked.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunked.java new file mode 100644 index 0000000..0cc8011 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunked.java @@ -0,0 +1,36 @@ +package org.jctools.queues; + +import org.jctools.queues.atomic.SpscChunkedAtomicArrayQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.queues.unpadded.SpscChunkedUnpaddedArrayQueue; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestSpscChunked extends QueueSanityTest +{ + public QueueSanityTestSpscChunked(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscChunkedArrayQueue<>(8, 16)));// MPSC size 1 + list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscChunkedArrayQueue<>(8, SIZE)));// MPSC size SIZE + 
list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscChunkedAtomicArrayQueue<>(8, 16)));// MPSC size 1 + list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscChunkedAtomicArrayQueue<>(8, SIZE)));// MPSC size SIZE + list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscChunkedUnpaddedArrayQueue<>(8, 16)));// MPSC size 1 + list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscChunkedUnpaddedArrayQueue<>(8, SIZE)));// MPSC size SIZE + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunkedExtended.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunkedExtended.java new file mode 100644 index 0000000..37dac85 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscChunkedExtended.java @@ -0,0 +1,17 @@ +package org.jctools.queues; + +import org.junit.Test; + +public class QueueSanityTestSpscChunkedExtended +{ + + @Test + public void testMaxSizeQueue() + { + SpscChunkedArrayQueue queue = new SpscChunkedArrayQueue(1024, 1000 * 1024 * 1024); + for (int i = 0; i < 400001; i++) + { + queue.offer(i); + } + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowable.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowable.java new file mode 100644 index 0000000..9420244 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowable.java @@ -0,0 +1,66 @@ +package org.jctools.queues; + +import org.jctools.queues.atomic.SpscGrowableAtomicArrayQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.queues.unpadded.SpscGrowableUnpaddedArrayQueue; +import org.junit.Assert; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.hamcrest.Matchers.is; +import static 
org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestSpscGrowable extends QueueSanityTest +{ + + public QueueSanityTestSpscGrowable(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscGrowableArrayQueue<>(8, 16))); + list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscGrowableArrayQueue<>(8, SIZE))); + list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscGrowableAtomicArrayQueue<>(8, 16))); + list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscGrowableAtomicArrayQueue<>(8, SIZE))); + list.add(makeParams(1, 1, 16, Ordering.FIFO, new SpscGrowableUnpaddedArrayQueue<>(8, 16))); + list.add(makeParams(1, 1, SIZE, Ordering.FIFO, new SpscGrowableUnpaddedArrayQueue<>(8, SIZE))); + return list; + } + + @Test + public void testSizeNeverExceedCapacity() + { + final SpscGrowableArrayQueue q = new SpscGrowableArrayQueue<>(8, 16); + final Integer v = 0; + final int capacity = q.capacity(); + for (int i = 0; i < capacity; i++) + { + Assert.assertTrue(q.offer(v)); + } + Assert.assertFalse(q.offer(v)); + Assert.assertThat(q.size(), is(capacity)); + for (int i = 0; i < 6; i++) + { + Assert.assertEquals(v, q.poll()); + } + //the consumer is left in the chunk previous the last and biggest one + Assert.assertThat(q.size(), is(capacity - 6)); + for (int i = 0; i < 6; i++) + { + q.offer(v); + } + Assert.assertThat(q.size(), is(capacity)); + Assert.assertFalse(q.offer(v)); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowableExtended.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowableExtended.java new file mode 100644 index 0000000..08dffce --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscGrowableExtended.java @@ -0,0 +1,35 @@ +package 
org.jctools.queues; + +import org.junit.Assert; +import org.junit.Test; + +import static org.hamcrest.Matchers.is; + +public class QueueSanityTestSpscGrowableExtended +{ + @Test + public void testSizeNeverExceedCapacity() + { + final SpscGrowableArrayQueue q = new SpscGrowableArrayQueue<>(8, 16); + final Integer v = 0; + final int capacity = q.capacity(); + for (int i = 0; i < capacity; i++) + { + Assert.assertTrue(q.offer(v)); + } + Assert.assertFalse(q.offer(v)); + Assert.assertThat(q.size(), is(capacity)); + for (int i = 0; i < 6; i++) + { + Assert.assertEquals(v, q.poll()); + } + //the consumer is left in the chunk previous the last and biggest one + Assert.assertThat(q.size(), is(capacity - 6)); + for (int i = 0; i < 6; i++) + { + q.offer(v); + } + Assert.assertThat(q.size(), is(capacity)); + Assert.assertFalse(q.offer(v)); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscLinked.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscLinked.java new file mode 100644 index 0000000..4929e31 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscLinked.java @@ -0,0 +1,31 @@ +package org.jctools.queues; + +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.*; + +@RunWith(Parameterized.class) +public class QueueSanityTestSpscLinked extends QueueSanityTest +{ + public QueueSanityTestSpscLinked(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeMpq(1, 1, 0, Ordering.FIFO)); + list.add(makeAtomic(1, 1, 0, Ordering.FIFO)); + list.add(makeUnpadded(1, 1, 0, Ordering.FIFO)); + return list; + } +} diff 
--git a/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscUnbounded.java b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscUnbounded.java new file mode 100644 index 0000000..fc5260b --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/QueueSanityTestSpscUnbounded.java @@ -0,0 +1,37 @@ +package org.jctools.queues; + +import org.jctools.queues.atomic.SpscUnboundedAtomicArrayQueue; +import org.jctools.queues.spec.ConcurrentQueueSpec; +import org.jctools.queues.spec.Ordering; +import org.jctools.queues.unpadded.SpscUnboundedUnpaddedArrayQueue; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.Queue; + +import static org.jctools.util.TestUtil.makeParams; + +@RunWith(Parameterized.class) +public class QueueSanityTestSpscUnbounded extends QueueSanityTest +{ + + public QueueSanityTestSpscUnbounded(ConcurrentQueueSpec spec, Queue queue) + { + super(spec, queue); + } + + @Parameterized.Parameters + public static Collection parameters() + { + ArrayList list = new ArrayList(); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedArrayQueue<>(2))); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedArrayQueue<>(64))); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedAtomicArrayQueue<>(2))); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedAtomicArrayQueue<>(64))); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedUnpaddedArrayQueue<>(2))); + list.add(makeParams(1, 1, 0, Ordering.FIFO, new SpscUnboundedUnpaddedArrayQueue<>(64))); + return list; + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTest.java b/netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTest.java new file mode 100644 index 0000000..52709e8 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTest.java @@ -0,0 +1,188 @@ +package 
org.jctools.queues; + +import java.util.Queue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.locks.LockSupport; + +import org.junit.Test; + +import static org.junit.Assert.*; + +public abstract class ScQueueRemoveTest +{ + private static void assertQueueEmpty(Queue queue) + { + assertNull(queue.peek()); + assertNull(queue.poll()); + assertTrue(queue.isEmpty()); + assertEquals(0, queue.size()); + } + + protected abstract Queue newQueue(); + + private void removeSimple(int removeValue, int expectedFirst, int expectedSecond) throws InterruptedException + { + final Queue queue = newQueue(); + Thread t = new Thread() + { + @Override + public void run() + { + queue.offer(1); + queue.offer(2); + queue.offer(3); + } + }; + + assertQueueEmpty(queue); + + t.start(); + + while (queue.size() < 3) + { + Thread.yield(); + } + + assertTrue(queue.remove(removeValue)); + // Try to remove again, just to ensure pointers are updated as expected. + assertFalse(queue.remove(removeValue)); + assertFalse(queue.isEmpty()); + assertEquals(2, queue.size()); + + assertEquals(expectedFirst, queue.poll().intValue()); + assertEquals(expectedSecond, queue.poll().intValue()); + assertQueueEmpty(queue); + + t.join(); + } + + @Test + public void removeConsumerNode() throws InterruptedException + { + removeSimple(1, 2, 3); + } + + @Test + public void removeInteriorNode() throws InterruptedException + { + removeSimple(2, 1, 3); + } + + @Test + public void removeProducerNode() throws InterruptedException + { + removeSimple(3, 1, 2); + } + + @Test + public void removeFailsWhenExpected() throws InterruptedException + { + final Queue queue = newQueue(); + Thread t = new Thread() + { + @Override + public void run() + { + queue.offer(1); + queue.offer(2); + queue.offer(3); + } + }; + + assertQueueEmpty(queue); + + t.start(); + + while (queue.size() < 3) + { + Thread.yield(); + } + + // Remove an element which doesn't exist. 
+ assertFalse(queue.remove(4)); + assertFalse(queue.remove(4)); + assertFalse(queue.isEmpty()); + assertEquals(3, queue.size()); + + // Verify that none of the links have been modified. + assertEquals(1, queue.poll().intValue()); + assertEquals(2, queue.poll().intValue()); + assertEquals(3, queue.poll().intValue()); + assertQueueEmpty(queue); + + t.join(); + } + + @Test(timeout = 1000) + public void removeStressTest() throws InterruptedException + { + final AtomicBoolean running = new AtomicBoolean(true); + final AtomicBoolean failed = new AtomicBoolean(false); + final Queue queue = newQueue(); + + Thread p = new Thread() + { + @Override + public void run() + { + int i = 0; + try + { + while (running.get()) + { + if (queue.isEmpty()) + { + queue.offer(i++); + queue.offer(i++); + queue.offer(i++); + } + } + } + catch (Exception e) + { + e.printStackTrace(); + failed.set(true); + running.set(false); + } + } + }; + + Thread c = new Thread() + { + @Override + public void run() + { + int i = 0; + try + { + while (running.get()) + { + if (!queue.isEmpty()) + { + if (!queue.remove(i)) + { + failed.set(true); + running.set(false); + } + i++; + } + } + } + catch (Exception e) + { + e.printStackTrace(); + failed.set(true); + running.set(false); + } + } + }; + p.start(); + c.start(); + LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(250)); + running.set(false); + p.join(); + c.join(); + assertFalse(failed.get()); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTestMpscLinked.java b/netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTestMpscLinked.java new file mode 100644 index 0000000..6046b48 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/ScQueueRemoveTestMpscLinked.java @@ -0,0 +1,10 @@ +package org.jctools.queues; + +import java.util.Queue; + +public class ScQueueRemoveTestMpscLinked extends ScQueueRemoveTest { + @Override + protected Queue newQueue() { + return new MpscLinkedQueue(); + } +} diff --git 
a/netty-jctools/src/test/java/org/jctools/queues/atomic/MpscAtomicArrayQueueOfferWithThresholdTest.java b/netty-jctools/src/test/java/org/jctools/queues/atomic/MpscAtomicArrayQueueOfferWithThresholdTest.java new file mode 100644 index 0000000..cdafda1 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/atomic/MpscAtomicArrayQueueOfferWithThresholdTest.java @@ -0,0 +1,38 @@ +package org.jctools.queues.atomic; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +public class MpscAtomicArrayQueueOfferWithThresholdTest +{ + + private MpscAtomicArrayQueue queue; + + @Before + public void setUp() throws Exception + { + this.queue = new MpscAtomicArrayQueue(16); + } + + @Test + public void testOfferWithThreshold() + { + int i; + for (i = 0; i < 8; ++i) + { + //Offers succeed because current size is below the HWM. + Assert.assertTrue(this.queue.offerIfBelowThreshold(i, 8)); + } + //Not anymore, our offer got rejected. + Assert.assertFalse(this.queue.offerIfBelowThreshold(i, 8)); + Assert.assertFalse(this.queue.offerIfBelowThreshold(i, 7)); + Assert.assertFalse(this.queue.offerIfBelowThreshold(i, 1)); + Assert.assertFalse(this.queue.offerIfBelowThreshold(i, 0)); + + //Also, the threshold is dynamic and different levels can be set for + //different task priorities. 
+ Assert.assertTrue(this.queue.offerIfBelowThreshold(i, 9)); + Assert.assertTrue(this.queue.offerIfBelowThreshold(i, 16)); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/atomic/MpscLinkedAtomicQueueRemoveTest.java b/netty-jctools/src/test/java/org/jctools/queues/atomic/MpscLinkedAtomicQueueRemoveTest.java new file mode 100644 index 0000000..89e74f8 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/atomic/MpscLinkedAtomicQueueRemoveTest.java @@ -0,0 +1,12 @@ +package org.jctools.queues.atomic; + +import java.util.Queue; + +import org.jctools.queues.ScQueueRemoveTest; + +public class MpscLinkedAtomicQueueRemoveTest extends ScQueueRemoveTest { + @Override + protected Queue newQueue() { + return new MpscLinkedAtomicQueue<>(); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/atomic/SpscAtomicArrayQueueTest.java b/netty-jctools/src/test/java/org/jctools/queues/atomic/SpscAtomicArrayQueueTest.java new file mode 100644 index 0000000..600d735 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/atomic/SpscAtomicArrayQueueTest.java @@ -0,0 +1,58 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues.atomic; + +import org.junit.Test; + +import static org.hamcrest.Matchers.*; +import static org.jctools.queues.matchers.Matchers.emptyAndZeroSize; +import static org.junit.Assert.assertThat; +import static org.junit.Assert.assertTrue; + +public class SpscAtomicArrayQueueTest +{ + @Test + public void shouldWorkAfterWrap() + { + // Arrange + final SpscAtomicArrayQueue q = new SpscAtomicArrayQueue(1024); + // starting point for empty queue at max long, next offer will wrap the producerIndex + q.soConsumerIndex(Long.MAX_VALUE); + q.soProducerIndex(Long.MAX_VALUE); + q.producerLimit = Long.MAX_VALUE; + // valid starting point + assertThat(q, emptyAndZeroSize()); + + // Act + // assert offer is successful + final Object e = new Object(); + assertTrue(q.offer(e)); + // size is computed correctly after wrap + assertThat(q, not(emptyAndZeroSize())); + assertThat(q, hasSize(1)); + + // now consumer index wraps + final Object poll = q.poll(); + assertThat(poll, sameInstance(e)); + assertThat(q, emptyAndZeroSize()); + + // let's go again + assertTrue(q.offer(e)); + assertThat(q, not(emptyAndZeroSize())); + + final Object poll2 = q.poll(); + assertThat(poll2, sameInstance(e)); + assertThat(q, emptyAndZeroSize()); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/matchers/Matchers.java b/netty-jctools/src/test/java/org/jctools/queues/matchers/Matchers.java new file mode 100644 index 0000000..84780ab --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/matchers/Matchers.java @@ -0,0 +1,22 @@ +package org.jctools.queues.matchers; + +import java.util.Collection; + +import org.hamcrest.Matcher; + +import static org.hamcrest.Matchers.*; + +/** + * @author Andrey Satarin (https://github.com/asatarin) + */ +public class Matchers +{ + private Matchers() + { + } + + public static Matcher> emptyAndZeroSize() + { + return allOf(hasSize(0), empty()); + } +} diff --git 
a/netty-jctools/src/test/java/org/jctools/queues/spec/ConcurrentQueueSpec.java b/netty-jctools/src/test/java/org/jctools/queues/spec/ConcurrentQueueSpec.java new file mode 100644 index 0000000..457c774 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/spec/ConcurrentQueueSpec.java @@ -0,0 +1,93 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.spec; + +import org.jctools.queues.MessagePassingQueue; + +@Deprecated//(since = "3.0.0") +public final class ConcurrentQueueSpec +{ + public final int producers; + public final int consumers; + public final int capacity; + public final Ordering ordering; + public final Preference preference; + + public static ConcurrentQueueSpec createBoundedSpsc(int capacity) + { + return new ConcurrentQueueSpec(1, 1, capacity, Ordering.FIFO, Preference.NONE); + } + + public static ConcurrentQueueSpec createBoundedMpsc(int capacity) + { + return new ConcurrentQueueSpec(0, 1, capacity, Ordering.FIFO, Preference.NONE); + } + + public static ConcurrentQueueSpec createBoundedSpmc(int capacity) + { + return new ConcurrentQueueSpec(1, 0, capacity, Ordering.FIFO, Preference.NONE); + } + + public static ConcurrentQueueSpec createBoundedMpmc(int capacity) + { + return new ConcurrentQueueSpec(0, 0, capacity, Ordering.FIFO, Preference.NONE); + } + + public ConcurrentQueueSpec(int producers, int consumers, int capacity, Ordering ordering, Preference preference) + { + super(); + this.producers = 
producers; + this.consumers = consumers; + this.capacity = capacity < 1 ? MessagePassingQueue.UNBOUNDED_CAPACITY : capacity; + this.ordering = ordering; + this.preference = preference; + } + + public boolean isSpsc() + { + return consumers == 1 && producers == 1; + } + + public boolean isMpsc() + { + return consumers == 1 && producers != 1; + } + + public boolean isSpmc() + { + return consumers != 1 && producers == 1; + } + + public boolean isMpmc() + { + return consumers != 1 && producers != 1; + } + + public boolean isBounded() + { + return capacity != MessagePassingQueue.UNBOUNDED_CAPACITY; + } + + @Override + public String toString() + { + return "ConcurrentQueueSpec{" + + "producers=" + producers + + ", consumers=" + consumers + + ", capacity=" + capacity + + ", ordering=" + ordering + + ", preference=" + preference + + '}'; + } +} \ No newline at end of file diff --git a/netty-jctools/src/test/java/org/jctools/queues/spec/Ordering.java b/netty-jctools/src/test/java/org/jctools/queues/spec/Ordering.java new file mode 100644 index 0000000..9331081 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/spec/Ordering.java @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ +package org.jctools.queues.spec; + +@Deprecated//(since = "3.0.0") +public enum Ordering +{ + FIFO, KFIFO, PRODUCER_FIFO, NONE +} diff --git a/netty-jctools/src/test/java/org/jctools/queues/spec/Preference.java b/netty-jctools/src/test/java/org/jctools/queues/spec/Preference.java new file mode 100644 index 0000000..1413515 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/queues/spec/Preference.java @@ -0,0 +1,20 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.queues.spec; + +@Deprecated//(since = "3.0.0") +public enum Preference +{ + LATENCY, THROUGHPUT, NONE +} diff --git a/netty-jctools/src/test/java/org/jctools/util/AtomicQueueFactory.java b/netty-jctools/src/test/java/org/jctools/util/AtomicQueueFactory.java new file mode 100644 index 0000000..cbf074e --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/util/AtomicQueueFactory.java @@ -0,0 +1,72 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+ * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.jctools.util; + +import org.jctools.queues.atomic.*; +import org.jctools.queues.spec.ConcurrentQueueSpec; + +import java.util.Queue; + +/** + * The queue factory produces {@link Queue} instances based on a best fit to the {@link ConcurrentQueueSpec}. + * This allows minimal dependencies between user code and the queue implementations and gives users a way to express + * their requirements on a higher level. + * + * @author nitsanw + * @author akarnokd + */ +@Deprecated//(since = "4.0.0") +public class AtomicQueueFactory +{ + public static Queue newAtomicQueue(ConcurrentQueueSpec qs) + { + if (qs.isBounded()) + { + // SPSC + if (qs.isSpsc()) + { + return new SpscAtomicArrayQueue(qs.capacity); + } + // MPSC + else if (qs.isMpsc()) + { + return new MpscAtomicArrayQueue(qs.capacity); + } + // SPMC + else if (qs.isSpmc()) + { + return new SpmcAtomicArrayQueue(qs.capacity); + } + // MPMC + else + { + return new MpmcAtomicArrayQueue(qs.capacity); + } + } + else + { + // SPSC + if (qs.isSpsc()) + { + return new SpscLinkedAtomicQueue(); + } + // MPSC + else if (qs.isMpsc()) + { + return new MpscLinkedAtomicQueue(); + } + } + throw new IllegalArgumentException("Cannot match queue for spec:" + qs); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/util/PaddedAtomicLongTest.java b/netty-jctools/src/test/java/org/jctools/util/PaddedAtomicLongTest.java new file mode 100644 index 0000000..c060e7c --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/util/PaddedAtomicLongTest.java @@ -0,0 +1,214 @@ +package org.jctools.util; + +import org.junit.Test; + +import java.util.function.LongBinaryOperator; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; + +public class PaddedAtomicLongTest { + + @Test + public void testDefaultConstructor() { + PaddedAtomicLong 
counter = new PaddedAtomicLong(); + + assertEquals(0L, counter.get()); + } + + @Test + public void testConstructor_withValue() { + PaddedAtomicLong counter = new PaddedAtomicLong(20); + + assertEquals(20, counter.get()); + } + + + @Test + public void testSet() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + counter.set(10); + + assertEquals(10L, counter.get()); + } + + @Test + public void lazySet() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + counter.lazySet(10); + + assertEquals(10L, counter.get()); + } + + @Test + public void testGetAndSet() { + PaddedAtomicLong counter = new PaddedAtomicLong(1); + + long result = counter.getAndSet(2); + + assertEquals(1, result); + assertEquals(2, counter.get()); + } + + @Test + public void testCompareAndSet_whenSuccess() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + assertTrue(counter.compareAndSet(0, 1)); + assertEquals(1, counter.get()); + } + + @Test + public void testCompareAndSet_whenFailure() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + assertFalse(counter.compareAndSet(1, 2)); + assertEquals(0, counter.get()); + } + + @Test + public void testWeakCompareAndSet_whenSuccess() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + assertTrue(counter.weakCompareAndSet(0, 1)); + assertEquals(1, counter.get()); + } + + @Test + public void testWeakCompareAndSet_whenFailure() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + assertFalse(counter.weakCompareAndSet(1, 2)); + assertEquals(0, counter.get()); + } + + @Test + public void testGetAndIncrement() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.getAndIncrement(); + assertEquals(0L, value); + assertEquals(1, counter.get()); + } + + @Test + public void testGetAndDecrement() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.getAndDecrement(); + assertEquals(0L, value); + assertEquals(-1, counter.get()); + } + + @Test + public void testGetAndAdd() { + 
PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.getAndAdd(10); + assertEquals(0L, value); + assertEquals(10, counter.get()); + } + + @Test + public void testIncrementAndGet() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.incrementAndGet(); + assertEquals(1L, value); + assertEquals(1L, counter.get()); + } + + @Test + public void testDecrementAndGet() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.decrementAndGet(); + assertEquals(-1, value); + assertEquals(-1, counter.get()); + } + + @Test + public void testAddAndGet() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.addAndGet(1); + assertEquals(1, value); + assertEquals(1, counter.get()); + } + + @Test + public void testGetAndUpdate() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.getAndUpdate(operand -> operand + 2); + assertEquals(0, value); + assertEquals(2, counter.get()); + } + + @Test + public void testUpdateAndGet() { + PaddedAtomicLong counter = new PaddedAtomicLong(); + + long value = counter.updateAndGet(operand -> operand + 2); + assertEquals(2, value); + assertEquals(2, counter.get()); + } + + @Test + public void testGetAndAccumulate() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + long value = counter.getAndAccumulate(1, (left, right) -> left+right); + + assertEquals(value, 10); + assertEquals(11, counter.get()); + } + + @Test + public void testAccumulateAndGet() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + long value = counter.accumulateAndGet(1, (left, right) -> left+right); + + assertEquals(value, 11); + assertEquals(11, counter.get()); + } + + @Test + public void testIntValue() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + assertEquals(10, counter.intValue()); + } + + @Test + public void testLongValue() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + assertEquals(10, 
counter.longValue()); + } + + @Test + public void testFloatValue() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + assertEquals(10f, counter.floatValue(), 0.01); + } + + @Test + public void testDoubleValue() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + assertEquals(10d, counter.doubleValue(), 0.01); + } + + @Test + public void testToString() { + PaddedAtomicLong counter = new PaddedAtomicLong(10); + + assertEquals("10", counter.toString()); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/util/Pow2Test.java b/netty-jctools/src/test/java/org/jctools/util/Pow2Test.java new file mode 100755 index 0000000..7cd6996 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/util/Pow2Test.java @@ -0,0 +1,41 @@ +package org.jctools.util; + +import org.junit.Test; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.fail; + +public class Pow2Test +{ + static final int MAX_POSITIVE_POW2 = 1 << 30; + + @Test + public void testAlign() + { + assertEquals(4, Pow2.align(2, 4)); + assertEquals(4, Pow2.align(4, 4)); + } + + @Test + public void testRound() + { + assertEquals(4, Pow2.roundToPowerOfTwo(4)); + assertEquals(4, Pow2.roundToPowerOfTwo(3)); + assertEquals(1, Pow2.roundToPowerOfTwo(0)); + assertEquals(MAX_POSITIVE_POW2, Pow2.roundToPowerOfTwo(MAX_POSITIVE_POW2)); + } + + @Test(expected = IllegalArgumentException.class) + public void testMaxRoundException() + { + Pow2.roundToPowerOfTwo(MAX_POSITIVE_POW2 + 1); + fail(); + } + + @Test(expected = IllegalArgumentException.class) + public void testNegativeRoundException() + { + Pow2.roundToPowerOfTwo(-1); + fail(); + } +} diff --git a/netty-jctools/src/test/java/org/jctools/util/QueueFactory.java b/netty-jctools/src/test/java/org/jctools/util/QueueFactory.java new file mode 100644 index 0000000..2565d12 --- /dev/null +++ b/netty-jctools/src/test/java/org/jctools/util/QueueFactory.java @@ -0,0 +1,81 @@ +/* + * Licensed under the Apache License, Version 2.0 (the 
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.util;

import org.jctools.queues.*;
import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

/**
 * The queue factory produces {@link java.util.Queue} instances based on a best fit to the {@link ConcurrentQueueSpec}.
 * This allows minimal dependencies between user code and the queue implementations and gives users a way to express
 * their requirements on a higher level.
 *
 * @author nitsanw
 */
@Deprecated//(since = "3.0.0")
public class QueueFactory
{

    /**
     * Picks the queue implementation that best matches the spec, falling back
     * to {@link ConcurrentLinkedQueue} when no specialized match exists.
     */
    public static <E> Queue<E> newQueue(ConcurrentQueueSpec qs)
    {
        if (qs.isBounded())
        {
            // SPSC
            if (qs.isSpsc())
            {
                return new SpscArrayQueue<E>(qs.capacity);
            }
            // MPSC
            else if (qs.isMpsc())
            {
                if (qs.ordering != Ordering.NONE)
                {
                    return new MpscArrayQueue<E>(qs.capacity);
                }
                else
                {
                    // relaxed ordering allows the compound (striped) variant
                    return new MpscCompoundQueue<E>(qs.capacity);
                }
            }
            // SPMC
            else if (qs.isSpmc())
            {
                return new SpmcArrayQueue<E>(qs.capacity);
            }
            // MPMC
            else
            {
                return new MpmcArrayQueue<E>(qs.capacity);
            }
        }
        else
        {
            // SPSC
            if (qs.isSpsc())
            {
                return new SpscLinkedQueue<E>();
            }
            // MPSC
            else if (qs.isMpsc())
            {
                return new MpscLinkedQueue<E>();
            }
        }
        return new ConcurrentLinkedQueue<E>();
    }
}
} + + @Test + public void checkPositiveOrZeroMustPassIfArgumentIsZero() + { + final int n = 0; + final int actual = RangeUtil.checkPositiveOrZero(n, "var"); + + assertThat(actual, is(equalTo(n))); + } + + @Test + public void checkPositiveOrZeroMustPassIfArgumentIsGreaterThanZero() + { + final int n = 1; + final int actual = RangeUtil.checkPositiveOrZero(n, "var"); + + assertThat(actual, is(equalTo(n))); + } + + @Test(expected = IllegalArgumentException.class) + public void checkLessThanMustFailIfArgumentIsGreaterThanExpected() + { + RangeUtil.checkLessThan(1, 0, "var"); + } + + @Test(expected = IllegalArgumentException.class) + public void checkLessThanMustFailIfArgumentIsEqualToExpected() + { + final int n = 1; + final int actual = RangeUtil.checkLessThan(1, 1, "var"); + + assertThat(actual, is(equalTo(n))); + } + + @Test + public void checkLessThanMustPassIfArgumentIsLessThanExpected() + { + final int n = 0; + final int actual = RangeUtil.checkLessThan(n, 1, "var"); + + assertThat(actual, is(equalTo(n))); + } + + @Test(expected = IllegalArgumentException.class) + public void checkLessThanOrEqualMustFailIfArgumentIsGreaterThanExpected() + { + RangeUtil.checkLessThanOrEqual(1, 0, "var"); + } + + @Test + public void checkLessThanOrEqualMustPassIfArgumentIsEqualToExpected() + { + final int n = 1; + final int actual = RangeUtil.checkLessThanOrEqual(n, 1, "var"); + + assertThat(actual, is(equalTo(n))); + } + + @Test + public void checkLessThanOrEqualMustPassIfArgumentIsLessThanExpected() + { + final int n = 0; + final int actual = RangeUtil.checkLessThanOrEqual(n, 1, "var"); + + assertThat(actual, is(equalTo(n))); + } + + @Test(expected = IllegalArgumentException.class) + public void checkGreaterThanOrEqualMustFailIfArgumentIsLessThanExpected() + { + RangeUtil.checkGreaterThanOrEqual(0, 1, "var"); + } + + @Test + public void checkGreaterThanOrEqualMustPassIfArgumentIsEqualToExpected() + { + final int n = 1; + final int actual = RangeUtil.checkGreaterThanOrEqual(n, 1, 
package org.jctools.util;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.spec.Ordering;
import org.jctools.queues.spec.Preference;
import org.junit.Assert;

import java.util.List;
import java.util.Queue;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.locks.LockSupport;

import static org.jctools.util.AtomicQueueFactory.newAtomicQueue;
import static org.jctools.util.QueueFactory.newQueue;
import static org.jctools.util.UnpaddedQueueFactory.newUnpaddedQueue;

/**
 * Shared helpers for concurrent queue tests: thread orchestration
 * (start/wait/join against a stop flag) and parameterized-test fixture
 * construction from {@link ConcurrentQueueSpec}s.
 */
public class TestUtil {
    // Both knobs are overridable via system properties for slow/fast CI hosts.
    public static final int CONCURRENT_TEST_DURATION = Integer.getInteger("org.jctools.concTestDurationMs", 500);
    public static final int CONCURRENT_TEST_THREADS = Integer.getInteger("org.jctools.concTestThreads", Math.min(4, Runtime.getRuntime().availableProcessors()));
    public static final int TEST_TIMEOUT = 30000;
    private static final AtomicInteger threadIndex = new AtomicInteger();

    /** Sleeps without throwing InterruptedException (parkNanos swallows interrupts). */
    public static void sleepQuietly(long timeMs) {
        LockSupport.parkNanos(TimeUnit.MILLISECONDS.toNanos(timeMs));
    }

    public static void startWaitJoin(AtomicBoolean stop, List<Thread> threads) throws InterruptedException
    {
        startWaitJoin(stop, threads, 1);
    }

    /**
     * Starts all threads, lets them run for CONCURRENT_TEST_DURATION/waitDivisor ms,
     * raises the stop flag, joins, then clears the flag for reuse.
     */
    public static void startWaitJoin(
        AtomicBoolean stop,
        List<Thread> threads,
        int waitDivisor) throws InterruptedException
    {
        int waitMs = CONCURRENT_TEST_DURATION / waitDivisor;

        for (Thread t : threads) t.start();
        Thread.sleep(waitMs);
        stop.set(true);
        for (Thread t : threads) t.join();
        stop.set(false);
    }

    /** Adds {@code count} named (not started) threads running {@code runnable}; count &lt;= 0 uses the default. */
    public static void threads(Runnable runnable, int count, List<Thread> threads)
    {
        if (count <= 0)
            count = CONCURRENT_TEST_THREADS - 1;

        for (int i = 0; i < count; i++)
        {
            Thread thread = new Thread(runnable);
            thread.setName("JCTools test thread-" + threadIndex.getAndIncrement());
            threads.add(thread);
        }
    }

    public static Object[] makeParams(int producers, int consumers, int capacity, Ordering ordering, Queue<Integer> q)
    {
        Assert.assertNotNull(q);
        return new Object[] {makeSpec(producers, consumers, capacity, ordering), q};
    }

    public static Object[] makeMpq(int producers, int consumers, int capacity, Ordering ordering)
    {
        ConcurrentQueueSpec spec = makeSpec(producers, consumers, capacity, ordering);
        return new Object[] {spec, newQueue(spec)};
    }

    public static Object[] makeAtomic(int producers, int consumers, int capacity, Ordering ordering)
    {
        ConcurrentQueueSpec spec = makeSpec(producers, consumers, capacity, ordering);
        return new Object[] {spec, newAtomicQueue(spec)};
    }

    public static Object[] makeUnpadded(int producers, int consumers, int capacity, Ordering ordering)
    {
        ConcurrentQueueSpec spec = makeSpec(producers, consumers, capacity, ordering);
        return new Object[] {spec, newUnpaddedQueue(spec)};
    }

    static ConcurrentQueueSpec makeSpec(int producers, int consumers, int capacity, Ordering ordering)
    {
        return new ConcurrentQueueSpec(producers, consumers, capacity, ordering, Preference.NONE);
    }

    /** Mutable int holder for passing counters in and out of test lambdas. */
    public static final class Val
    {
        public int value;
    }
}
/*
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.jctools.util;

import org.jctools.queues.spec.ConcurrentQueueSpec;
import org.jctools.queues.unpadded.*;

import java.util.Queue;
import java.util.concurrent.ConcurrentLinkedQueue;

/**
 * The queue factory produces {@link java.util.Queue} instances based on a best fit to the {@link ConcurrentQueueSpec}.
 * This allows minimal dependencies between user code and the queue implementations and gives users a way to express
 * their requirements on a higher level.
 *
 * @author nitsanw
 * @author akarnokd
 */
@Deprecated//(since = "4.0.0")
public class UnpaddedQueueFactory
{
    /**
     * Picks the unpadded (cache-line-padding-free) implementation that best
     * matches the spec, falling back to {@link ConcurrentLinkedQueue}.
     */
    public static <T> Queue<T> newUnpaddedQueue(ConcurrentQueueSpec qs)
    {
        if (qs.isBounded())
        {
            // SPSC
            if (qs.isSpsc())
            {
                return new SpscUnpaddedArrayQueue<T>(qs.capacity);
            }
            // MPSC
            else if (qs.isMpsc())
            {
                return new MpscUnpaddedArrayQueue<T>(qs.capacity);
            }
            // SPMC
            else if (qs.isSpmc())
            {
                return new SpmcUnpaddedArrayQueue<T>(qs.capacity);
            }
            // MPMC
            else
            {
                return new MpmcUnpaddedArrayQueue<T>(qs.capacity);
            }
        }
        else
        {
            // SPSC
            if (qs.isSpsc())
            {
                return new SpscLinkedUnpaddedQueue<T>();
            }
            // MPSC
            else if (qs.isMpsc())
            {
                return new MpscLinkedUnpaddedQueue<T>();
            }
        }
        return new ConcurrentLinkedQueue<T>();
    }
}
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.netty.resolver;

import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.TypeParameterMatcher;

import java.net.SocketAddress;
import java.nio.channels.UnsupportedAddressTypeException;
import java.util.Collections;
import java.util.List;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * A skeletal {@link AddressResolver} implementation.
 *
 * @param <T> the type of the {@link SocketAddress} this resolver supports
 */
public abstract class AbstractAddressResolver<T extends SocketAddress> implements AddressResolver<T> {

    private final EventExecutor executor;
    private final TypeParameterMatcher matcher;

    /**
     * @param executor the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned
     *                 by {@link #resolve(SocketAddress)}
     */
    protected AbstractAddressResolver(EventExecutor executor) {
        this.executor = checkNotNull(executor, "executor");
        // Reflectively resolves the concrete type bound to T for isSupported() checks.
        this.matcher = TypeParameterMatcher.find(this, AbstractAddressResolver.class, "T");
    }

    /**
     * @param executor    the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned
     *                    by {@link #resolve(SocketAddress)}
     * @param addressType the type of the {@link SocketAddress} supported by this resolver
     */
    protected AbstractAddressResolver(EventExecutor executor, Class<? extends T> addressType) {
        this.executor = checkNotNull(executor, "executor");
        this.matcher = TypeParameterMatcher.get(addressType);
    }

    /**
     * Returns the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned
     * by {@link #resolve(SocketAddress)}.
     */
    protected EventExecutor executor() {
        return executor;
    }

    @Override
    public boolean isSupported(SocketAddress address) {
        return matcher.match(address);
    }

    @Override
    public final boolean isResolved(SocketAddress address) {
        if (!isSupported(address)) {
            throw new UnsupportedAddressTypeException();
        }

        @SuppressWarnings("unchecked")
        final T castAddress = (T) address;
        return doIsResolved(castAddress);
    }

    /**
     * Invoked by {@link #isResolved(SocketAddress)} to check if the specified {@code address} has been resolved
     * already.
     */
    protected abstract boolean doIsResolved(T address);

    @Override
    public final Future<T> resolve(SocketAddress address) {
        if (!isSupported(checkNotNull(address, "address"))) {
            // Address type not supported by the resolver
            return executor().newFailedFuture(new UnsupportedAddressTypeException());
        }

        if (isResolved(address)) {
            // Resolved already; no need to perform a lookup
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            return executor.newSucceededFuture(cast);
        }

        try {
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            final Promise<T> promise = executor().newPromise();
            doResolve(cast, promise);
            return promise;
        } catch (Exception e) {
            // Synchronous failure from doResolve is reported through the future, not thrown.
            return executor().newFailedFuture(e);
        }
    }

    @Override
    public final Future<T> resolve(SocketAddress address, Promise<T> promise) {
        checkNotNull(address, "address");
        checkNotNull(promise, "promise");

        if (!isSupported(address)) {
            // Address type not supported by the resolver
            return promise.setFailure(new UnsupportedAddressTypeException());
        }

        if (isResolved(address)) {
            // Resolved already; no need to perform a lookup
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            return promise.setSuccess(cast);
        }

        try {
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            doResolve(cast, promise);
            return promise;
        } catch (Exception e) {
            return promise.setFailure(e);
        }
    }

    @Override
    public final Future<List<T>> resolveAll(SocketAddress address) {
        if (!isSupported(checkNotNull(address, "address"))) {
            // Address type not supported by the resolver
            return executor().newFailedFuture(new UnsupportedAddressTypeException());
        }

        if (isResolved(address)) {
            // Resolved already; no need to perform a lookup
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            return executor.newSucceededFuture(Collections.singletonList(cast));
        }

        try {
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            final Promise<List<T>> promise = executor().newPromise();
            doResolveAll(cast, promise);
            return promise;
        } catch (Exception e) {
            return executor().newFailedFuture(e);
        }
    }

    @Override
    public final Future<List<T>> resolveAll(SocketAddress address, Promise<List<T>> promise) {
        checkNotNull(address, "address");
        checkNotNull(promise, "promise");

        if (!isSupported(address)) {
            // Address type not supported by the resolver
            return promise.setFailure(new UnsupportedAddressTypeException());
        }

        if (isResolved(address)) {
            // Resolved already; no need to perform a lookup
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            return promise.setSuccess(Collections.singletonList(cast));
        }

        try {
            @SuppressWarnings("unchecked")
            final T cast = (T) address;
            doResolveAll(cast, promise);
            return promise;
        } catch (Exception e) {
            return promise.setFailure(e);
        }
    }

    /**
     * Invoked by {@link #resolve(SocketAddress)} to perform the actual name
     * resolution.
     */
    protected abstract void doResolve(T unresolvedAddress, Promise<T> promise) throws Exception;

    /**
     * Invoked by {@link #resolveAll(SocketAddress)} to perform the actual name
     * resolution.
     */
    protected abstract void doResolveAll(T unresolvedAddress, Promise<List<T>> promise) throws Exception;

    @Override
    public void close() { }
}
package io.netty.resolver;

import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.Promise;

import java.io.Closeable;
import java.net.SocketAddress;
import java.nio.channels.UnsupportedAddressTypeException;
import java.util.List;

/**
 * Resolves a possibly unresolved {@link SocketAddress}.
 */
public interface AddressResolver<T extends SocketAddress> extends Closeable {

    /**
     * Returns {@code true} if and only if the specified address is supported by this resolver.
     */
    boolean isSupported(SocketAddress address);

    /**
     * Returns {@code true} if and only if the specified address has been resolved.
     *
     * @throws UnsupportedAddressTypeException if the specified address is not supported by this resolver
     */
    boolean isResolved(SocketAddress address);

    /**
     * Resolves the specified address. If the specified address is resolved already, this method does nothing
     * but returning the original address.
     *
     * @param address the address to resolve
     *
     * @return the {@link SocketAddress} as the result of the resolution
     */
    Future<T> resolve(SocketAddress address);

    /**
     * Resolves the specified address. If the specified address is resolved already, this method does nothing
     * but returning the original address.
     *
     * @param address the address to resolve
     * @param promise the {@link Promise} which will be fulfilled when the name resolution is finished
     *
     * @return the {@link SocketAddress} as the result of the resolution
     */
    Future<T> resolve(SocketAddress address, Promise<T> promise);

    /**
     * Resolves the specified address. If the specified address is resolved already, this method does nothing
     * but returning the original address.
     *
     * @param address the address to resolve
     *
     * @return the list of the {@link SocketAddress}es as the result of the resolution
     */
    Future<List<T>> resolveAll(SocketAddress address);

    /**
     * Resolves the specified address. If the specified address is resolved already, this method does nothing
     * but returning the original address.
     *
     * @param address the address to resolve
     * @param promise the {@link Promise} which will be fulfilled when the name resolution is finished
     *
     * @return the list of the {@link SocketAddress}es as the result of the resolution
     */
    Future<List<T>> resolveAll(SocketAddress address, Promise<List<T>> promise);

    /**
     * Closes all the resources allocated and used by this resolver.
     */
    @Override
    void close();
}
package io.netty.resolver;

import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.GenericFutureListener;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.io.Closeable;
import java.net.SocketAddress;
import java.util.IdentityHashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;

/**
 * Creates and manages {@link NameResolver}s so that each {@link EventExecutor} has its own resolver instance.
 */
public abstract class AddressResolverGroup<T extends SocketAddress> implements Closeable {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(AddressResolverGroup.class);

    /**
     * Note that we do not use a {@link ConcurrentMap} here because it is usually expensive to instantiate a resolver.
     */
    private final Map<EventExecutor, AddressResolver<T>> resolvers =
            new IdentityHashMap<EventExecutor, AddressResolver<T>>();

    // Listeners registered on each executor's termination future so the matching resolver is
    // removed and closed when its executor terminates.
    private final Map<EventExecutor, GenericFutureListener<Future<Object>>> executorTerminationListeners =
            new IdentityHashMap<EventExecutor, GenericFutureListener<Future<Object>>>();

    protected AddressResolverGroup() { }

    /**
     * Returns the {@link AddressResolver} associated with the specified {@link EventExecutor}. If there's no associated
     * resolver found, this method creates and returns a new resolver instance created by
     * {@link #newResolver(EventExecutor)} so that the new resolver is reused on another
     * {@code #getResolver(EventExecutor)} call with the same {@link EventExecutor}.
     */
    public AddressResolver<T> getResolver(final EventExecutor executor) {
        ObjectUtil.checkNotNull(executor, "executor");

        if (executor.isShuttingDown()) {
            throw new IllegalStateException("executor not accepting a task");
        }

        AddressResolver<T> r;
        synchronized (resolvers) {
            r = resolvers.get(executor);
            if (r == null) {
                final AddressResolver<T> newResolver;
                try {
                    newResolver = newResolver(executor);
                } catch (Exception e) {
                    throw new IllegalStateException("failed to create a new resolver", e);
                }

                resolvers.put(executor, newResolver);

                final FutureListener<Object> terminationListener = new FutureListener<Object>() {
                    @Override
                    public void operationComplete(Future<Object> future) {
                        synchronized (resolvers) {
                            resolvers.remove(executor);
                            executorTerminationListeners.remove(executor);
                        }
                        newResolver.close();
                    }
                };

                executorTerminationListeners.put(executor, terminationListener);
                executor.terminationFuture().addListener(terminationListener);

                r = newResolver;
            }
        }

        return r;
    }

    /**
     * Invoked by {@link #getResolver(EventExecutor)} to create a new {@link AddressResolver}.
     */
    protected abstract AddressResolver<T> newResolver(EventExecutor executor) throws Exception;

    /**
     * Closes all {@link NameResolver}s created by this group.
     */
    @Override
    @SuppressWarnings({ "unchecked", "SuspiciousToArrayCall" })
    public void close() {
        final AddressResolver<T>[] rArray;
        final Map.Entry<EventExecutor, GenericFutureListener<Future<Object>>>[] listeners;

        synchronized (resolvers) {
            rArray = (AddressResolver<T>[]) resolvers.values().toArray(new AddressResolver[0]);
            resolvers.clear();
            listeners = executorTerminationListeners.entrySet().toArray(new Map.Entry[0]);
            executorTerminationListeners.clear();
        }

        // Unhook the termination listeners first so close() below is not invoked a second time
        // when an executor terminates concurrently.
        for (final Map.Entry<EventExecutor, GenericFutureListener<Future<Object>>> entry : listeners) {
            entry.getKey().terminationFuture().removeListener(entry.getValue());
        }

        for (final AddressResolver<T> r : rArray) {
            try {
                r.close();
            } catch (Throwable t) {
                logger.warn("Failed to close a resolver:", t);
            }
        }
    }
}
package io.netty.resolver;

import io.netty.util.concurrent.EventExecutor;
import io.netty.util.concurrent.Future;
import io.netty.util.concurrent.FutureListener;
import io.netty.util.concurrent.Promise;
import io.netty.util.internal.ObjectUtil;

import java.util.Arrays;
import java.util.List;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * A composite {@link SimpleNameResolver} that resolves a host name against a sequence of {@link NameResolver}s.
 *
 * In case of a failure, only the last one will be reported.
 */
public final class CompositeNameResolver<T> extends SimpleNameResolver<T> {

    private final NameResolver<T>[] resolvers;

    /**
     * @param executor the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned
     *                 by {@link #resolve(String)}
     * @param resolvers the {@link NameResolver}s to be tried sequentially
     */
    public CompositeNameResolver(EventExecutor executor, NameResolver<T>... resolvers) {
        super(executor);
        checkNotNull(resolvers, "resolvers");
        for (int i = 0; i < resolvers.length; i++) {
            ObjectUtil.checkNotNull(resolvers[i], "resolvers[" + i + ']');
        }
        if (resolvers.length < 2) {
            throw new IllegalArgumentException("resolvers: " + Arrays.asList(resolvers) +
                    " (expected: at least 2 resolvers)");
        }
        // Defensive copy: callers must not be able to mutate the array after construction.
        this.resolvers = resolvers.clone();
    }

    @Override
    protected void doResolve(String inetHost, Promise<T> promise) throws Exception {
        doResolveRec(inetHost, promise, 0, null);
    }

    /**
     * Tries the resolver at {@code resolverIndex}; on failure recurses to the next one,
     * carrying the failure so the last cause can be reported if all resolvers fail.
     */
    private void doResolveRec(final String inetHost,
                              final Promise<T> promise,
                              final int resolverIndex,
                              Throwable lastFailure) throws Exception {
        if (resolverIndex >= resolvers.length) {
            promise.setFailure(lastFailure);
        } else {
            NameResolver<T> resolver = resolvers[resolverIndex];
            resolver.resolve(inetHost).addListener(new FutureListener<T>() {
                @Override
                public void operationComplete(Future<T> future) throws Exception {
                    if (future.isSuccess()) {
                        promise.setSuccess(future.getNow());
                    } else {
                        doResolveRec(inetHost, promise, resolverIndex + 1, future.cause());
                    }
                }
            });
        }
    }

    @Override
    protected void doResolveAll(String inetHost, Promise<List<T>> promise) throws Exception {
        doResolveAllRec(inetHost, promise, 0, null);
    }

    /**
     * Same sequential-fallback strategy as {@link #doResolveRec}, but for resolving all addresses.
     */
    private void doResolveAllRec(final String inetHost,
                                 final Promise<List<T>> promise,
                                 final int resolverIndex,
                                 Throwable lastFailure) throws Exception {
        if (resolverIndex >= resolvers.length) {
            promise.setFailure(lastFailure);
        } else {
            NameResolver<T> resolver = resolvers[resolverIndex];
            resolver.resolveAll(inetHost).addListener(new FutureListener<List<T>>() {
                @Override
                public void operationComplete(Future<List<T>> future) throws Exception {
                    if (future.isSuccess()) {
                        promise.setSuccess(future.getNow());
                    } else {
                        doResolveAllRec(inetHost, promise, resolverIndex + 1, future.cause());
                    }
                }
            });
        }
    }
}
+ */ + +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; + +import java.net.InetSocketAddress; + +/** + * A {@link AddressResolverGroup} of {@link DefaultNameResolver}s. + */ +public final class DefaultAddressResolverGroup extends AddressResolverGroup { + + public static final DefaultAddressResolverGroup INSTANCE = new DefaultAddressResolverGroup(); + + private DefaultAddressResolverGroup() { } + + @Override + protected AddressResolver newResolver(EventExecutor executor) throws Exception { + return new DefaultNameResolver(executor).asAddressResolver(); + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java b/netty-resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java new file mode 100644 index 0000000..e3ab452 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/DefaultHostsFileEntriesResolver.java @@ -0,0 +1,148 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.resolver;

import io.netty.util.CharsetUtil;
import io.netty.util.internal.ObjectUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.SystemPropertyUtil;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.net.InetAddress;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Default {@link HostsFileEntriesResolver} that resolves hosts file entries only once.
 */
public final class DefaultHostsFileEntriesResolver implements HostsFileEntriesResolver {

    private static final InternalLogger logger =
            InternalLoggerFactory.getInstance(DefaultHostsFileEntriesResolver.class);
    private static final long DEFAULT_REFRESH_INTERVAL;

    // 0 means "parse once and never refresh".
    private final long refreshInterval;
    private final AtomicLong lastRefresh = new AtomicLong(System.nanoTime());
    private final HostsFileEntriesProvider.Parser hostsFileParser;
    // volatile: replaced wholesale on refresh, read without locking.
    private volatile Map<String, List<InetAddress>> inet4Entries;
    private volatile Map<String, List<InetAddress>> inet6Entries;

    static {
        DEFAULT_REFRESH_INTERVAL = SystemPropertyUtil.getLong(
                "io.netty.hostsFileRefreshInterval", /*nanos*/0);

        if (logger.isDebugEnabled()) {
            logger.debug("-Dio.netty.hostsFileRefreshInterval: {}", DEFAULT_REFRESH_INTERVAL);
        }
    }

    public DefaultHostsFileEntriesResolver() {
        this(HostsFileEntriesProvider.parser(), DEFAULT_REFRESH_INTERVAL);
    }

    // for testing purpose only
    DefaultHostsFileEntriesResolver(HostsFileEntriesProvider.Parser hostsFileParser, long refreshInterval) {
        this.hostsFileParser = hostsFileParser;
        this.refreshInterval = ObjectUtil.checkPositiveOrZero(refreshInterval, "refreshInterval");
        HostsFileEntriesProvider entries = parseEntries(hostsFileParser);
        inet4Entries = entries.ipv4Entries();
        inet6Entries = entries.ipv6Entries();
    }

    @Override
    public InetAddress address(String inetHost, ResolvedAddressTypes resolvedAddressTypes) {
        return firstAddress(addresses(inetHost, resolvedAddressTypes));
    }

    /**
     * Resolves all addresses of a hostname against the entries in a hosts file, depending on the specified
     * {@link ResolvedAddressTypes}.
     *
     * @param inetHost the hostname to resolve
     * @param resolvedAddressTypes the address types to resolve
     * @return all matching addresses or {@code null} in case the hostname cannot be resolved
     */
    public List<InetAddress> addresses(String inetHost, ResolvedAddressTypes resolvedAddressTypes) {
        String normalized = normalize(inetHost);
        ensureHostsFileEntriesAreFresh();

        switch (resolvedAddressTypes) {
        case IPV4_ONLY:
            return inet4Entries.get(normalized);
        case IPV6_ONLY:
            return inet6Entries.get(normalized);
        case IPV4_PREFERRED:
            List<InetAddress> allInet4Addresses = inet4Entries.get(normalized);
            return allInet4Addresses != null ? allAddresses(allInet4Addresses, inet6Entries.get(normalized)) :
                    inet6Entries.get(normalized);
        case IPV6_PREFERRED:
            List<InetAddress> allInet6Addresses = inet6Entries.get(normalized);
            return allInet6Addresses != null ? allAddresses(allInet6Addresses, inet4Entries.get(normalized)) :
                    inet4Entries.get(normalized);
        default:
            throw new IllegalArgumentException("Unknown ResolvedAddressTypes " + resolvedAddressTypes);
        }
    }

    /**
     * Re-parses the hosts file when the configured refresh interval has elapsed.
     * The CAS on {@code lastRefresh} ensures only one thread performs the refresh.
     */
    private void ensureHostsFileEntriesAreFresh() {
        long interval = refreshInterval;
        if (interval == 0) {
            return;
        }
        long last = lastRefresh.get();
        long currentTime = System.nanoTime();
        if (currentTime - last > interval) {
            if (lastRefresh.compareAndSet(last, currentTime)) {
                HostsFileEntriesProvider entries = parseEntries(hostsFileParser);
                inet4Entries = entries.ipv4Entries();
                inet6Entries = entries.ipv6Entries();
            }
        }
    }

    // package-private for testing purposes
    String normalize(String inetHost) {
        return inetHost.toLowerCase(Locale.ENGLISH);
    }

    private static List<InetAddress> allAddresses(List<InetAddress> a, List<InetAddress> b) {
        List<InetAddress> result = new ArrayList<InetAddress>(a.size() + (b == null ? 0 : b.size()));
        result.addAll(a);
        if (b != null) {
            result.addAll(b);
        }
        return result;
    }

    private static InetAddress firstAddress(List<InetAddress> addresses) {
        return addresses != null && !addresses.isEmpty() ? addresses.get(0) : null;
    }

    private static HostsFileEntriesProvider parseEntries(HostsFileEntriesProvider.Parser parser) {
        if (PlatformDependent.isWindows()) {
            // On Windows there seems to be no standard for the encoding used for the hosts file, so let us
            // try multiple until we either were able to parse it or there is none left and so we return an
            // empty instance.
            return parser.parseSilently(Charset.defaultCharset(), CharsetUtil.UTF_16, CharsetUtil.UTF_8);
        }
        return parser.parseSilently();
    }
}
+ */ +public class DefaultNameResolver extends InetNameResolver { + + public DefaultNameResolver(EventExecutor executor) { + super(executor); + } + + @Override + protected void doResolve(String inetHost, Promise promise) throws Exception { + try { + promise.setSuccess(SocketUtils.addressByName(inetHost)); + } catch (UnknownHostException e) { + promise.setFailure(e); + } + } + + @Override + protected void doResolveAll(String inetHost, Promise> promise) throws Exception { + try { + promise.setSuccess(Arrays.asList(SocketUtils.allAddressesByName(inetHost))); + } catch (UnknownHostException e) { + promise.setFailure(e); + } + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/HostsFileEntries.java b/netty-resolver/src/main/java/io/netty/resolver/HostsFileEntries.java new file mode 100644 index 0000000..7c79522 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/HostsFileEntries.java @@ -0,0 +1,62 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.util.Collections; +import java.util.HashMap; +import java.util.Map; + +/** + * A container of hosts file entries. + * The mappings contain only the first entry per hostname. + * Consider using {@link HostsFileEntriesProvider} when mappings with all entries per hostname are needed. 
+ */ +public final class HostsFileEntries { + + /** + * Empty entries + */ + static final HostsFileEntries EMPTY = + new HostsFileEntries( + Collections.emptyMap(), + Collections.emptyMap()); + + private final Map inet4Entries; + private final Map inet6Entries; + + public HostsFileEntries(Map inet4Entries, Map inet6Entries) { + this.inet4Entries = Collections.unmodifiableMap(new HashMap(inet4Entries)); + this.inet6Entries = Collections.unmodifiableMap(new HashMap(inet6Entries)); + } + + /** + * The IPv4 entries + * @return the IPv4 entries + */ + public Map inet4Entries() { + return inet4Entries; + } + + /** + * The IPv6 entries + * @return the IPv6 entries + */ + public Map inet6Entries() { + return inet6Entries; + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/HostsFileEntriesProvider.java b/netty-resolver/src/main/java/io/netty/resolver/HostsFileEntriesProvider.java new file mode 100644 index 0000000..aaf2920 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/HostsFileEntriesProvider.java @@ -0,0 +1,317 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.resolver;

import io.netty.util.NetUtil;
import io.netty.util.internal.PlatformDependent;
import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.io.Reader;
import java.net.Inet4Address;
import java.net.InetAddress;
import java.nio.charset.Charset;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.regex.Pattern;

import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * A container of hosts file entries
 */
public final class HostsFileEntriesProvider {

    public interface Parser {

        /**
         * Parses the hosts file at standard OS location using the system default {@link Charset} for decoding.
         *
         * @return a new {@link HostsFileEntriesProvider}
         * @throws IOException file could not be read
         */
        HostsFileEntriesProvider parse() throws IOException;

        /**
         * Parses the hosts file at standard OS location using the given {@link Charset}s one after another until
         * parse something or none is left.
         *
         * @param charsets the {@link Charset}s to try as file encodings when parsing
         * @return a new {@link HostsFileEntriesProvider}
         * @throws IOException file could not be read
         */
        HostsFileEntriesProvider parse(Charset... charsets) throws IOException;

        /**
         * Parses the provided hosts file using the given {@link Charset}s one after another until
         * parse something or none is left. In case {@link Charset}s are not provided,
         * the system default {@link Charset} is used for decoding.
         *
         * @param file the file to be parsed
         * @param charsets the {@link Charset}s to try as file encodings when parsing, in case {@link Charset}s
         *                 are not provided, the system default {@link Charset} is used for decoding
         * @return a new {@link HostsFileEntriesProvider}
         * @throws IOException file could not be read
         */
        HostsFileEntriesProvider parse(File file, Charset... charsets) throws IOException;

        /**
         * Performs the parsing operation using the provided reader of hosts file format.
         *
         * @param reader the reader of hosts file format
         * @return a new {@link HostsFileEntriesProvider}
         */
        HostsFileEntriesProvider parse(Reader reader) throws IOException;

        /**
         * Parses the hosts file at standard OS location using the system default {@link Charset} for decoding.
         *
         * @return a new {@link HostsFileEntriesProvider}
         */
        HostsFileEntriesProvider parseSilently();

        /**
         * Parses the hosts file at standard OS location using the given {@link Charset}s one after another until
         * parse something or none is left.
         *
         * @param charsets the {@link Charset}s to try as file encodings when parsing
         * @return a new {@link HostsFileEntriesProvider}
         */
        HostsFileEntriesProvider parseSilently(Charset... charsets);

        /**
         * Parses the provided hosts file using the given {@link Charset}s one after another until
         * parse something or none is left. In case {@link Charset}s are not provided,
         * the system default {@link Charset} is used for decoding.
         *
         * @param file the file to be parsed
         * @param charsets the {@link Charset}s to try as file encodings when parsing, in case {@link Charset}s
         *                 are not provided, the system default {@link Charset} is used for decoding
         * @return a new {@link HostsFileEntriesProvider}
         */
        HostsFileEntriesProvider parseSilently(File file, Charset... charsets);
    }

    /**
     * Creates a parser for {@link HostsFileEntriesProvider}.
     *
     * @return a new {@link HostsFileEntriesProvider.Parser}
     */
    public static Parser parser() {
        return ParserImpl.INSTANCE;
    }

    static final HostsFileEntriesProvider EMPTY =
            new HostsFileEntriesProvider(
                    Collections.<String, List<InetAddress>>emptyMap(),
                    Collections.<String, List<InetAddress>>emptyMap());

    private final Map<String, List<InetAddress>> ipv4Entries;
    private final Map<String, List<InetAddress>> ipv6Entries;

    HostsFileEntriesProvider(Map<String, List<InetAddress>> ipv4Entries,
                             Map<String, List<InetAddress>> ipv6Entries) {
        // Defensive copies wrapped unmodifiable: this container is immutable.
        this.ipv4Entries = Collections.unmodifiableMap(new HashMap<String, List<InetAddress>>(ipv4Entries));
        this.ipv6Entries = Collections.unmodifiableMap(new HashMap<String, List<InetAddress>>(ipv6Entries));
    }

    /**
     * The IPv4 entries.
     *
     * @return the IPv4 entries
     */
    public Map<String, List<InetAddress>> ipv4Entries() {
        return ipv4Entries;
    }

    /**
     * The IPv6 entries.
     *
     * @return the IPv6 entries
     */
    public Map<String, List<InetAddress>> ipv6Entries() {
        return ipv6Entries;
    }

    private static final class ParserImpl implements Parser {

        private static final String WINDOWS_DEFAULT_SYSTEM_ROOT = "C:\\Windows";
        private static final String WINDOWS_HOSTS_FILE_RELATIVE_PATH = "\\system32\\drivers\\etc\\hosts";
        private static final String X_PLATFORMS_HOSTS_FILE_PATH = "/etc/hosts";

        private static final Pattern WHITESPACES = Pattern.compile("[ \t]+");

        private static final InternalLogger logger = InternalLoggerFactory.getInstance(Parser.class);

        static final ParserImpl INSTANCE = new ParserImpl();

        private ParserImpl() {
            // singleton
        }

        @Override
        public HostsFileEntriesProvider parse() throws IOException {
            return parse(locateHostsFile(), Charset.defaultCharset());
        }

        @Override
        public HostsFileEntriesProvider parse(Charset... charsets) throws IOException {
            return parse(locateHostsFile(), charsets);
        }

        @Override
        public HostsFileEntriesProvider parse(File file, Charset... charsets) throws IOException {
            checkNotNull(file, "file");
            checkNotNull(charsets, "charsets");
            if (charsets.length == 0) {
                charsets = new Charset[]{Charset.defaultCharset()};
            }
            if (file.exists() && file.isFile()) {
                // Try each charset in turn; the first one that yields any entry wins.
                for (Charset charset : charsets) {
                    BufferedReader reader = new BufferedReader(
                            new InputStreamReader(new FileInputStream(file), charset));
                    try {
                        HostsFileEntriesProvider entries = parse(reader);
                        if (entries != HostsFileEntriesProvider.EMPTY) {
                            return entries;
                        }
                    } finally {
                        reader.close();
                    }
                }
            }
            return HostsFileEntriesProvider.EMPTY;
        }

        @Override
        public HostsFileEntriesProvider parse(Reader reader) throws IOException {
            checkNotNull(reader, "reader");
            BufferedReader buff = new BufferedReader(reader);
            try {
                Map<String, List<InetAddress>> ipv4Entries = new HashMap<String, List<InetAddress>>();
                Map<String, List<InetAddress>> ipv6Entries = new HashMap<String, List<InetAddress>>();
                String line;
                while ((line = buff.readLine()) != null) {
                    // remove comment
                    int commentPosition = line.indexOf('#');
                    if (commentPosition != -1) {
                        line = line.substring(0, commentPosition);
                    }
                    // skip empty lines
                    line = line.trim();
                    if (line.isEmpty()) {
                        continue;
                    }

                    // split
                    List<String> lineParts = new ArrayList<String>();
                    for (String s : WHITESPACES.split(line)) {
                        if (!s.isEmpty()) {
                            lineParts.add(s);
                        }
                    }

                    // a valid line should be [IP, hostname, alias*]
                    if (lineParts.size() < 2) {
                        // skip invalid line
                        continue;
                    }

                    byte[] ipBytes = NetUtil.createByteArrayFromIpAddressString(lineParts.get(0));

                    if (ipBytes == null) {
                        // skip invalid IP
                        continue;
                    }

                    // loop over hostname and aliases
                    for (int i = 1; i < lineParts.size(); i++) {
                        String hostname = lineParts.get(i);
                        String hostnameLower = hostname.toLowerCase(Locale.ENGLISH);
                        InetAddress address = InetAddress.getByAddress(hostname, ipBytes);
                        List<InetAddress> addresses;
                        if (address instanceof Inet4Address) {
                            addresses = ipv4Entries.get(hostnameLower);
                            if (addresses == null) {
                                addresses = new ArrayList<InetAddress>();
                                ipv4Entries.put(hostnameLower, addresses);
                            }
                        } else {
                            addresses = ipv6Entries.get(hostnameLower);
                            if (addresses == null) {
                                addresses = new ArrayList<InetAddress>();
                                ipv6Entries.put(hostnameLower, addresses);
                            }
                        }
                        addresses.add(address);
                    }
                }
                return ipv4Entries.isEmpty() && ipv6Entries.isEmpty() ?
                        HostsFileEntriesProvider.EMPTY :
                        new HostsFileEntriesProvider(ipv4Entries, ipv6Entries);
            } finally {
                try {
                    buff.close();
                } catch (IOException e) {
                    logger.warn("Failed to close a reader", e);
                }
            }
        }

        @Override
        public HostsFileEntriesProvider parseSilently() {
            return parseSilently(locateHostsFile(), Charset.defaultCharset());
        }

        @Override
        public HostsFileEntriesProvider parseSilently(Charset... charsets) {
            return parseSilently(locateHostsFile(), charsets);
        }

        @Override
        public HostsFileEntriesProvider parseSilently(File file, Charset... charsets) {
            try {
                return parse(file, charsets);
            } catch (IOException e) {
                if (logger.isWarnEnabled()) {
                    logger.warn("Failed to load and parse hosts file at " + file.getPath(), e);
                }
                return HostsFileEntriesProvider.EMPTY;
            }
        }

        private static File locateHostsFile() {
            File hostsFile;
            if (PlatformDependent.isWindows()) {
                hostsFile = new File(System.getenv("SystemRoot") + WINDOWS_HOSTS_FILE_RELATIVE_PATH);
                if (!hostsFile.exists()) {
                    hostsFile = new File(WINDOWS_DEFAULT_SYSTEM_ROOT + WINDOWS_HOSTS_FILE_RELATIVE_PATH);
                }
            } else {
                hostsFile = new File(X_PLATFORMS_HOSTS_FILE_PATH);
            }
            return hostsFile;
        }
    }
}
2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver; + +import java.net.InetAddress; + +/** + * Resolves a hostname against the hosts file entries. + */ +public interface HostsFileEntriesResolver { + + /** + * Default instance: a {@link DefaultHostsFileEntriesResolver}. + */ + HostsFileEntriesResolver DEFAULT = new DefaultHostsFileEntriesResolver(); + + /** + * Resolve the address of a hostname against the entries in a hosts file, depending on some address types. + * @param inetHost the hostname to resolve + * @param resolvedAddressTypes the address types to resolve + * @return the first matching address + */ + InetAddress address(String inetHost, ResolvedAddressTypes resolvedAddressTypes); +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/HostsFileParser.java b/netty-resolver/src/main/java/io/netty/resolver/HostsFileParser.java new file mode 100644 index 0000000..8a0b5e9 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/HostsFileParser.java @@ -0,0 +1,123 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver; + +import java.io.File; +import java.io.IOException; +import java.io.Reader; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.nio.charset.Charset; +import java.util.List; +import java.util.HashMap; +import java.util.Map; + +/** + * A parser for hosts files. + * The produced mappings contain only the first entry per hostname. + * Consider using {@link HostsFileEntriesProvider} when mappings with all entries per hostname are needed. + */ +public final class HostsFileParser { + + /** + * Parse hosts file at standard OS location using the systems default {@link Charset} for decoding. + * + * @return a {@link HostsFileEntries} + */ + public static HostsFileEntries parseSilently() { + return hostsFileEntries(HostsFileEntriesProvider.parser().parseSilently()); + } + + /** + * Parse hosts file at standard OS location using the given {@link Charset}s one after each other until + * we were able to parse something or none is left. + * + * @param charsets the {@link Charset}s to try as file encodings when parsing. + * @return a {@link HostsFileEntries} + */ + public static HostsFileEntries parseSilently(Charset... charsets) { + return hostsFileEntries(HostsFileEntriesProvider.parser().parseSilently(charsets)); + } + + /** + * Parse hosts file at standard OS location using the system default {@link Charset} for decoding. 
+ * + * @return a {@link HostsFileEntries} + * @throws IOException file could not be read + */ + public static HostsFileEntries parse() throws IOException { + return hostsFileEntries(HostsFileEntriesProvider.parser().parse()); + } + + /** + * Parse a hosts file using the system default {@link Charset} for decoding. + * + * @param file the file to be parsed + * @return a {@link HostsFileEntries} + * @throws IOException file could not be read + */ + public static HostsFileEntries parse(File file) throws IOException { + return hostsFileEntries(HostsFileEntriesProvider.parser().parse(file)); + } + + /** + * Parse a hosts file. + * + * @param file the file to be parsed + * @param charsets the {@link Charset}s to try as file encodings when parsing. + * @return a {@link HostsFileEntries} + * @throws IOException file could not be read + */ + public static HostsFileEntries parse(File file, Charset... charsets) throws IOException { + return hostsFileEntries(HostsFileEntriesProvider.parser().parse(file, charsets)); + } + + /** + * Parse a reader of hosts file format. + * + * @param reader the file to be parsed + * @return a {@link HostsFileEntries} + * @throws IOException file could not be read + */ + public static HostsFileEntries parse(Reader reader) throws IOException { + return hostsFileEntries(HostsFileEntriesProvider.parser().parse(reader)); + } + + /** + * Can't be instantiated. + */ + private HostsFileParser() { + } + + @SuppressWarnings("unchecked") + private static HostsFileEntries hostsFileEntries(HostsFileEntriesProvider provider) { + return provider == HostsFileEntriesProvider.EMPTY ? 
HostsFileEntries.EMPTY : + new HostsFileEntries((Map) toMapWithSingleValue(provider.ipv4Entries()), + (Map) toMapWithSingleValue(provider.ipv6Entries())); + } + + private static Map toMapWithSingleValue(Map> fromMapWithListValue) { + Map result = new HashMap(fromMapWithListValue.size()); + for (Map.Entry> entry : fromMapWithListValue.entrySet()) { + List value = entry.getValue(); + if (!value.isEmpty()) { + result.put(entry.getKey(), value.get(0)); + } + } + return result; + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/InetNameResolver.java b/netty-resolver/src/main/java/io/netty/resolver/InetNameResolver.java new file mode 100644 index 0000000..55e53a7 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/InetNameResolver.java @@ -0,0 +1,54 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; + +import java.net.InetAddress; +import java.net.InetSocketAddress; + +/** + * A skeletal {@link NameResolver} implementation that resolves {@link InetAddress}. 
+ */ +public abstract class InetNameResolver extends SimpleNameResolver { + private volatile AddressResolver addressResolver; + + /** + * @param executor the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned + * by {@link #resolve(String)} + */ + protected InetNameResolver(EventExecutor executor) { + super(executor); + } + + /** + * Return a {@link AddressResolver} that will use this name resolver underneath. + * It's cached internally, so the same instance is always returned. + */ + public AddressResolver asAddressResolver() { + AddressResolver result = addressResolver; + if (result == null) { + synchronized (this) { + result = addressResolver; + if (result == null) { + addressResolver = result = new InetSocketAddressResolver(executor(), this); + } + } + } + return result; + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/InetSocketAddressResolver.java b/netty-resolver/src/main/java/io/netty/resolver/InetSocketAddressResolver.java new file mode 100644 index 0000000..765e6ba --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/InetSocketAddressResolver.java @@ -0,0 +1,96 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.Promise; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.util.ArrayList; +import java.util.List; + +/** + * A {@link AbstractAddressResolver} that resolves {@link InetSocketAddress}. + */ +public class InetSocketAddressResolver extends AbstractAddressResolver { + + final NameResolver nameResolver; + + /** + * @param executor the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned + * by {@link #resolve(java.net.SocketAddress)} + * @param nameResolver the {@link NameResolver} used for name resolution + */ + public InetSocketAddressResolver(EventExecutor executor, NameResolver nameResolver) { + super(executor, InetSocketAddress.class); + this.nameResolver = nameResolver; + } + + @Override + protected boolean doIsResolved(InetSocketAddress address) { + return !address.isUnresolved(); + } + + @Override + protected void doResolve(final InetSocketAddress unresolvedAddress, final Promise promise) + throws Exception { + // Note that InetSocketAddress.getHostName() will never incur a reverse lookup here, + // because an unresolved address always has a host name. + nameResolver.resolve(unresolvedAddress.getHostName()) + .addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (future.isSuccess()) { + promise.setSuccess(new InetSocketAddress(future.getNow(), unresolvedAddress.getPort())); + } else { + promise.setFailure(future.cause()); + } + } + }); + } + + @Override + protected void doResolveAll(final InetSocketAddress unresolvedAddress, + final Promise> promise) throws Exception { + // Note that InetSocketAddress.getHostName() will never incur a reverse lookup here, + // because an unresolved address always has a host name. 
+ nameResolver.resolveAll(unresolvedAddress.getHostName()) + .addListener(new FutureListener>() { + @Override + public void operationComplete(Future> future) throws Exception { + if (future.isSuccess()) { + List inetAddresses = future.getNow(); + List socketAddresses = + new ArrayList(inetAddresses.size()); + for (InetAddress inetAddress : inetAddresses) { + socketAddresses.add(new InetSocketAddress(inetAddress, unresolvedAddress.getPort())); + } + promise.setSuccess(socketAddresses); + } else { + promise.setFailure(future.cause()); + } + } + }); + } + + @Override + public void close() { + nameResolver.close(); + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/NameResolver.java b/netty-resolver/src/main/java/io/netty/resolver/NameResolver.java new file mode 100644 index 0000000..844cdb9 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/NameResolver.java @@ -0,0 +1,73 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.resolver; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; + +import java.io.Closeable; +import java.util.List; + +/** + * Resolves an arbitrary string that represents the name of an endpoint into an address. + */ +public interface NameResolver extends Closeable { + + /** + * Resolves the specified name into an address. 
+ * + * @param inetHost the name to resolve + * + * @return the address as the result of the resolution + */ + Future resolve(String inetHost); + + /** + * Resolves the specified name into an address. + * + * @param inetHost the name to resolve + * @param promise the {@link Promise} which will be fulfilled when the name resolution is finished + * + * @return the address as the result of the resolution + */ + Future resolve(String inetHost, Promise promise); + + /** + * Resolves the specified host name and port into a list of address. + * + * @param inetHost the name to resolve + * + * @return the list of the address as the result of the resolution + */ + Future> resolveAll(String inetHost); + + /** + * Resolves the specified host name and port into a list of address. + * + * @param inetHost the name to resolve + * @param promise the {@link Promise} which will be fulfilled when the name resolution is finished + * + * @return the list of the address as the result of the resolution + */ + Future> resolveAll(String inetHost, Promise> promise); + + /** + * Closes all the resources allocated and used by this resolver. + */ + @Override + void close(); +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolver.java b/netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolver.java new file mode 100644 index 0000000..d4e645c --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolver.java @@ -0,0 +1,51 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Promise; + +import java.net.SocketAddress; +import java.util.Collections; +import java.util.List; + +/** + * A {@link AddressResolver} that does not perform any resolution but always reports successful resolution. + * This resolver is useful when name resolution is performed by a handler in a pipeline, such as a proxy handler. + */ +public class NoopAddressResolver extends AbstractAddressResolver { + + public NoopAddressResolver(EventExecutor executor) { + super(executor); + } + + @Override + protected boolean doIsResolved(SocketAddress address) { + return true; + } + + @Override + protected void doResolve(SocketAddress unresolvedAddress, Promise promise) throws Exception { + promise.setSuccess(unresolvedAddress); + } + + @Override + protected void doResolveAll( + SocketAddress unresolvedAddress, Promise> promise) throws Exception { + promise.setSuccess(Collections.singletonList(unresolvedAddress)); + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolverGroup.java b/netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolverGroup.java new file mode 100644 index 0000000..e2135fb --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/NoopAddressResolverGroup.java @@ -0,0 +1,36 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the 
License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; + +import java.net.SocketAddress; + +/** + * A {@link AddressResolverGroup} of {@link NoopAddressResolver}s. + */ +public final class NoopAddressResolverGroup extends AddressResolverGroup { + + public static final NoopAddressResolverGroup INSTANCE = new NoopAddressResolverGroup(); + + private NoopAddressResolverGroup() { } + + @Override + protected AddressResolver newResolver(EventExecutor executor) throws Exception { + return new NoopAddressResolver(executor); + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/ResolvedAddressTypes.java b/netty-resolver/src/main/java/io/netty/resolver/ResolvedAddressTypes.java new file mode 100644 index 0000000..cb2d384 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/ResolvedAddressTypes.java @@ -0,0 +1,38 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.resolver; + +/** + * Defined resolved address types. + */ +public enum ResolvedAddressTypes { + /** + * Only resolve IPv4 addresses + */ + IPV4_ONLY, + /** + * Only resolve IPv6 addresses + */ + IPV6_ONLY, + /** + * Prefer IPv4 addresses over IPv6 ones + */ + IPV4_PREFERRED, + /** + * Prefer IPv6 addresses over IPv4 ones + */ + IPV6_PREFERRED +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/RoundRobinInetAddressResolver.java b/netty-resolver/src/main/java/io/netty/resolver/RoundRobinInetAddressResolver.java new file mode 100644 index 0000000..e997b97 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/RoundRobinInetAddressResolver.java @@ -0,0 +1,106 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.FutureListener; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.PlatformDependent; + +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; + +/** + * A {@link NameResolver} that resolves {@link InetAddress} and force Round Robin by choosing a single address + * randomly in {@link #resolve(String)} and {@link #resolve(String, Promise)} + * if multiple are returned by the {@link NameResolver}. + * Use {@link #asAddressResolver()} to create a {@link InetSocketAddress} resolver + */ +public class RoundRobinInetAddressResolver extends InetNameResolver { + private final NameResolver nameResolver; + + /** + * @param executor the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned by + * {@link #resolve(String)} + * @param nameResolver the {@link NameResolver} used for name resolution + */ + public RoundRobinInetAddressResolver(EventExecutor executor, NameResolver nameResolver) { + super(executor); + this.nameResolver = nameResolver; + } + + @Override + protected void doResolve(final String inetHost, final Promise promise) throws Exception { + // hijack the doResolve request, but do a doResolveAll request under the hood. + // Note that InetSocketAddress.getHostName() will never incur a reverse lookup here, + // because an unresolved address always has a host name. 
+ nameResolver.resolveAll(inetHost).addListener(new FutureListener>() { + @Override + public void operationComplete(Future> future) throws Exception { + if (future.isSuccess()) { + List inetAddresses = future.getNow(); + int numAddresses = inetAddresses.size(); + if (numAddresses > 0) { + // if there are multiple addresses: we shall pick one by one + // to support the round robin distribution + promise.setSuccess(inetAddresses.get(randomIndex(numAddresses))); + } else { + promise.setFailure(new UnknownHostException(inetHost)); + } + } else { + promise.setFailure(future.cause()); + } + } + }); + } + + @Override + protected void doResolveAll(String inetHost, final Promise> promise) throws Exception { + nameResolver.resolveAll(inetHost).addListener(new FutureListener>() { + @Override + public void operationComplete(Future> future) throws Exception { + if (future.isSuccess()) { + List inetAddresses = future.getNow(); + if (!inetAddresses.isEmpty()) { + // create a copy to make sure that it's modifiable random access collection + List result = new ArrayList(inetAddresses); + // rotate by different distance each time to force round robin distribution + Collections.rotate(result, randomIndex(inetAddresses.size())); + promise.setSuccess(result); + } else { + promise.setSuccess(inetAddresses); + } + } else { + promise.setFailure(future.cause()); + } + } + }); + } + + private static int randomIndex(int numAddresses) { + return numAddresses == 1 ? 
0 : PlatformDependent.threadLocalRandom().nextInt(numAddresses); + } + + @Override + public void close() { + nameResolver.close(); + } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java b/netty-resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java new file mode 100644 index 0000000..cc458cf --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/SimpleNameResolver.java @@ -0,0 +1,98 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.resolver; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; + +import java.util.List; + +import static io.netty.util.internal.ObjectUtil.*; + +/** + * A skeletal {@link NameResolver} implementation. + */ +public abstract class SimpleNameResolver implements NameResolver { + + private final EventExecutor executor; + + /** + * @param executor the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned + * by {@link #resolve(String)} + */ + protected SimpleNameResolver(EventExecutor executor) { + this.executor = checkNotNull(executor, "executor"); + } + + /** + * Returns the {@link EventExecutor} which is used to notify the listeners of the {@link Future} returned + * by {@link #resolve(String)}. 
+ */ + protected EventExecutor executor() { + return executor; + } + + @Override + public final Future resolve(String inetHost) { + final Promise promise = executor().newPromise(); + return resolve(inetHost, promise); + } + + @Override + public Future resolve(String inetHost, Promise promise) { + checkNotNull(promise, "promise"); + + try { + doResolve(inetHost, promise); + return promise; + } catch (Exception e) { + return promise.setFailure(e); + } + } + + @Override + public final Future> resolveAll(String inetHost) { + final Promise> promise = executor().newPromise(); + return resolveAll(inetHost, promise); + } + + @Override + public Future> resolveAll(String inetHost, Promise> promise) { + checkNotNull(promise, "promise"); + + try { + doResolveAll(inetHost, promise); + return promise; + } catch (Exception e) { + return promise.setFailure(e); + } + } + + /** + * Invoked by {@link #resolve(String)} to perform the actual name resolution. + */ + protected abstract void doResolve(String inetHost, Promise promise) throws Exception; + + /** + * Invoked by {@link #resolveAll(String)} to perform the actual name resolution. + */ + protected abstract void doResolveAll(String inetHost, Promise> promise) throws Exception; + + @Override + public void close() { } +} diff --git a/netty-resolver/src/main/java/io/netty/resolver/package-info.java b/netty-resolver/src/main/java/io/netty/resolver/package-info.java new file mode 100644 index 0000000..3695a59 --- /dev/null +++ b/netty-resolver/src/main/java/io/netty/resolver/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Resolves an arbitrary string that represents the name of an endpoint into an address. + */ +package io.netty.resolver; diff --git a/netty-resolver/src/main/java/module-info.java b/netty-resolver/src/main/java/module-info.java new file mode 100644 index 0000000..192235c --- /dev/null +++ b/netty-resolver/src/main/java/module-info.java @@ -0,0 +1,4 @@ +module org.xbib.io.netty.resolver { + exports io.netty.resolver; + requires org.xbib.io.netty.util; +} diff --git a/netty-resolver/src/test/java/io/netty/resolver/DefaultHostsFileEntriesResolverTest.java b/netty-resolver/src/test/java/io/netty/resolver/DefaultHostsFileEntriesResolverTest.java new file mode 100644 index 0000000..985e415 --- /dev/null +++ b/netty-resolver/src/test/java/io/netty/resolver/DefaultHostsFileEntriesResolverTest.java @@ -0,0 +1,195 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.resolver;

import io.netty.util.NetUtil;
import org.junit.jupiter.api.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;

import java.net.Inet4Address;
import java.net.Inet6Address;
import java.net.InetAddress;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.concurrent.TimeUnit;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;

/**
 * Tests for {@link DefaultHostsFileEntriesResolver}: case-insensitive lookups,
 * address-family filtering/ordering, and the TTL-based refresh of the parsed
 * hosts-file content.
 */
public class DefaultHostsFileEntriesResolverTest {

    // Canonical "localhost" entries for each address family, used to stub the parser.
    private static final Map<String, List<InetAddress>> LOCALHOST_V4_ADDRESSES =
            Collections.singletonMap("localhost", Collections.<InetAddress>singletonList(NetUtil.LOCALHOST4));
    private static final Map<String, List<InetAddress>> LOCALHOST_V6_ADDRESSES =
            Collections.singletonMap("localhost", Collections.<InetAddress>singletonList(NetUtil.LOCALHOST6));

    // Long TTL so cached entries never refresh within a test run (unless a test opts in).
    private static final long ENTRIES_TTL = TimeUnit.MINUTES.toNanos(1);

    /**
     * Shows issue https://github.com/netty/netty/issues/5182:
     * HostsFileParser tried to resolve hostnames case-sensitively.
     */
    @Test
    public void testCaseInsensitivity() {
        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver();
        // Both spellings must normalize to the same key.
        assertEquals(resolver.normalize("localhost"), resolver.normalize("LOCALHOST"));
    }

    @Test
    public void shouldntFindWhenAddressTypeDoesntMatch() {
        HostsFileEntriesProvider.Parser parser = givenHostsParserWith(
                LOCALHOST_V4_ADDRESSES,
                Collections.<String, List<InetAddress>>emptyMap()
        );

        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver(parser, ENTRIES_TTL);

        InetAddress address = resolver.address("localhost", ResolvedAddressTypes.IPV6_ONLY);
        assertNull(address, "Should not resolve an address when only IPv4 entries exist and IPv6-only is requested");
    }

    @Test
    public void shouldPickIpv4WhenBothAreDefinedButIpv4IsPreferred() {
        HostsFileEntriesProvider.Parser parser = givenHostsParserWith(
                LOCALHOST_V4_ADDRESSES,
                LOCALHOST_V6_ADDRESSES
        );

        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver(parser, ENTRIES_TTL);

        InetAddress address = resolver.address("localhost", ResolvedAddressTypes.IPV4_PREFERRED);
        assertThat("Should pick an IPv4 address", address, instanceOf(Inet4Address.class));
    }

    @Test
    public void shouldPickIpv6WhenBothAreDefinedButIpv6IsPreferred() {
        HostsFileEntriesProvider.Parser parser = givenHostsParserWith(
                LOCALHOST_V4_ADDRESSES,
                LOCALHOST_V6_ADDRESSES
        );

        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver(parser, ENTRIES_TTL);

        InetAddress address = resolver.address("localhost", ResolvedAddressTypes.IPV6_PREFERRED);
        assertThat("Should pick an IPv6 address", address, instanceOf(Inet6Address.class));
    }

    @Test
    public void shouldntFindWhenAddressesTypeDoesntMatch() {
        HostsFileEntriesProvider.Parser parser = givenHostsParserWith(
                LOCALHOST_V4_ADDRESSES,
                Collections.<String, List<InetAddress>>emptyMap()
        );

        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver(parser, ENTRIES_TTL);

        List<InetAddress> addresses = resolver.addresses("localhost", ResolvedAddressTypes.IPV6_ONLY);
        assertNull(addresses, "Should not resolve addresses when only IPv4 entries exist and IPv6-only is requested");
    }

    @Test
    public void shouldPickIpv4FirstWhenBothAreDefinedButIpv4IsPreferred() {
        HostsFileEntriesProvider.Parser parser = givenHostsParserWith(
                LOCALHOST_V4_ADDRESSES,
                LOCALHOST_V6_ADDRESSES
        );

        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver(parser, ENTRIES_TTL);

        List<InetAddress> addresses = resolver.addresses("localhost", ResolvedAddressTypes.IPV4_PREFERRED);
        assertNotNull(addresses);
        assertEquals(2, addresses.size());
        assertThat("Should pick an IPv4 address", addresses.get(0), instanceOf(Inet4Address.class));
        assertThat("Should pick an IPv6 address", addresses.get(1), instanceOf(Inet6Address.class));
    }

    @Test
    public void shouldPickIpv6FirstWhenBothAreDefinedButIpv6IsPreferred() {
        HostsFileEntriesProvider.Parser parser = givenHostsParserWith(
                LOCALHOST_V4_ADDRESSES,
                LOCALHOST_V6_ADDRESSES
        );

        DefaultHostsFileEntriesResolver resolver = new DefaultHostsFileEntriesResolver(parser, ENTRIES_TTL);

        List<InetAddress> addresses = resolver.addresses("localhost", ResolvedAddressTypes.IPV6_PREFERRED);
        assertNotNull(addresses);
        assertEquals(2, addresses.size());
        assertThat("Should pick an IPv6 address", addresses.get(0), instanceOf(Inet6Address.class));
        assertThat("Should pick an IPv4 address", addresses.get(1), instanceOf(Inet4Address.class));
    }

    @Test
    public void shouldNotRefreshHostsFileContentBeforeRefreshIntervalElapsed() {
        // Mutable maps so the "hosts file" can change underneath the resolver.
        Map<String, List<InetAddress>> v4Addresses = new HashMap<>(LOCALHOST_V4_ADDRESSES);
        Map<String, List<InetAddress>> v6Addresses = new HashMap<>(LOCALHOST_V6_ADDRESSES);
        DefaultHostsFileEntriesResolver resolver =
                new DefaultHostsFileEntriesResolver(givenHostsParserWith(v4Addresses, v6Addresses), ENTRIES_TTL);
        String newHost = UUID.randomUUID().toString();

        // Entries added after the initial parse must NOT be visible before the TTL elapses.
        v4Addresses.put(newHost, Collections.<InetAddress>singletonList(NetUtil.LOCALHOST4));
        v6Addresses.put(newHost, Collections.<InetAddress>singletonList(NetUtil.LOCALHOST6));

        assertNull(resolver.address(newHost, ResolvedAddressTypes.IPV4_ONLY));
        assertNull(resolver.address(newHost, ResolvedAddressTypes.IPV6_ONLY));
    }

    @Test
    public void shouldRefreshHostsFileContentAfterRefreshInterval() throws Exception {
        Map<String, List<InetAddress>> v4Addresses = new HashMap<>(LOCALHOST_V4_ADDRESSES);
        Map<String, List<InetAddress>> v6Addresses = new HashMap<>(LOCALHOST_V6_ADDRESSES);
        // TTL of a single nanosecond so any sleep guarantees a refresh.
        DefaultHostsFileEntriesResolver resolver =
                new DefaultHostsFileEntriesResolver(givenHostsParserWith(v4Addresses, v6Addresses), /*nanos*/ 1);
        String newHost = UUID.randomUUID().toString();

        InetAddress address = resolver.address(newHost, ResolvedAddressTypes.IPV6_ONLY);
        assertNull(address);
        // Let refreshIntervalNanos = 1 elapse.
        Thread.sleep(1);
        v4Addresses.put(newHost, Collections.<InetAddress>singletonList(NetUtil.LOCALHOST4));
        v6Addresses.put(newHost, Collections.<InetAddress>singletonList(NetUtil.LOCALHOST6));

        assertEquals(NetUtil.LOCALHOST4, resolver.address(newHost, ResolvedAddressTypes.IPV4_ONLY));
        assertEquals(NetUtil.LOCALHOST6, resolver.address(newHost, ResolvedAddressTypes.IPV6_ONLY));
    }

    /**
     * Builds a mocked hosts-file parser whose every {@code parseSilently(...)} call
     * re-reads the supplied (possibly mutated) entry maps, emulating a hosts file
     * whose content can change between parses.
     */
    private HostsFileEntriesProvider.Parser givenHostsParserWith(final Map<String, List<InetAddress>> inet4Entries,
                                                                 final Map<String, List<InetAddress>> inet6Entries) {
        HostsFileEntriesProvider.Parser mockParser = mock(HostsFileEntriesProvider.Parser.class);

        Answer<HostsFileEntriesProvider> mockedAnswer = new Answer<HostsFileEntriesProvider>() {
            @Override
            public HostsFileEntriesProvider answer(InvocationOnMock invocation) {
                // Build a fresh snapshot on every parse so later map mutations are observed.
                return new HostsFileEntriesProvider(inet4Entries, inet6Entries);
            }
        };

        when(mockParser.parseSilently()).thenAnswer(mockedAnswer);
        when(mockParser.parseSilently(any(Charset.class))).thenAnswer(mockedAnswer);

        return mockParser;
    }
}
/*
 * Copyright 2021 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.resolver;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.function.Executable;

import java.io.BufferedReader;
import java.io.File;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.net.InetAddress;
import java.nio.charset.Charset;
import java.util.List;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertThrows;

/**
 * Tests for {@link HostsFileEntriesProvider}: hosts-file parsing semantics
 * (comments, tabs, aliases, duplicate hosts, case-insensitivity) and
 * null-argument validation of the parser entry points.
 */
class HostsFileEntriesProviderTest {

    @Test
    void testParse() throws IOException {
        String hostsString = new StringBuilder()
                .append("127.0.0.1 host1").append("\n") // single hostname, separated with blanks
                .append("::1 host1").append("\n") // same as above, but IPv6
                .append("\n") // empty line
                .append("192.168.0.1\thost2").append("\n") // single hostname, separated with tabs
                .append("#comment").append("\n") // comment at the beginning of the line
                .append(" #comment  ").append("\n") // comment in the middle of the line
                .append("192.168.0.2  host3  #comment").append("\n") // comment after hostname
                .append("192.168.0.3  host4  host5 host6").append("\n") // multiple aliases
                .append("192.168.0.4  host4").append("\n") // host mapped to a second address, must be considered
                .append("192.168.0.5  HOST7").append("\n") // uppercase host, should match lowercase host
                .append("192.168.0.6  host7").append("\n") // must be considered
                .toString();

        HostsFileEntriesProvider entries = HostsFileEntriesProvider.parser()
                .parse(new BufferedReader(new StringReader(hostsString)));
        Map<String, List<InetAddress>> inet4Entries = entries.ipv4Entries();
        Map<String, List<InetAddress>> inet6Entries = entries.ipv6Entries();

        assertEquals(7, inet4Entries.size(), "Expected 7 IPv4 entries");
        assertEquals(1, inet6Entries.size(), "Expected 1 IPv6 entries");

        assertEquals(1, inet4Entries.get("host1").size());
        assertEquals("127.0.0.1", inet4Entries.get("host1").get(0).getHostAddress());

        assertEquals(1, inet4Entries.get("host2").size());
        assertEquals("192.168.0.1", inet4Entries.get("host2").get(0).getHostAddress());

        assertEquals(1, inet4Entries.get("host3").size());
        assertEquals("192.168.0.2", inet4Entries.get("host3").get(0).getHostAddress());

        // host4 appears on two lines: both addresses are retained, in file order.
        assertEquals(2, inet4Entries.get("host4").size());
        assertEquals("192.168.0.3", inet4Entries.get("host4").get(0).getHostAddress());
        assertEquals("192.168.0.4", inet4Entries.get("host4").get(1).getHostAddress());

        assertEquals(1, inet4Entries.get("host5").size());
        assertEquals("192.168.0.3", inet4Entries.get("host5").get(0).getHostAddress());

        assertEquals(1, inet4Entries.get("host6").size());
        assertEquals("192.168.0.3", inet4Entries.get("host6").get(0).getHostAddress());

        // HOST7 and host7 collapse onto the same (case-insensitive) key.
        assertNotNull(inet4Entries.get("host7"), "Uppercase host doesn't resolve");
        assertEquals(2, inet4Entries.get("host7").size());
        assertEquals("192.168.0.5", inet4Entries.get("host7").get(0).getHostAddress());
        assertEquals("192.168.0.6", inet4Entries.get("host7").get(1).getHostAddress());

        assertEquals(1, inet6Entries.get("host1").size());
        assertEquals("0:0:0:0:0:0:0:1", inet6Entries.get("host1").get(0).getHostAddress());
    }

    @Test
    void testCharsetInputValidation() {
        // Null charset arrays must be rejected by every parse variant.
        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() throws IOException {
                HostsFileEntriesProvider.parser().parse((Charset[]) null);
            }
        });

        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() throws IOException {
                HostsFileEntriesProvider.parser().parse(new File(""), (Charset[]) null);
            }
        });

        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() {
                HostsFileEntriesProvider.parser().parseSilently((Charset[]) null);
            }
        });

        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() {
                HostsFileEntriesProvider.parser().parseSilently(new File(""), (Charset[]) null);
            }
        });
    }

    @Test
    void testFileInputValidation() {
        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() throws IOException {
                HostsFileEntriesProvider.parser().parse((File) null);
            }
        });

        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() {
                HostsFileEntriesProvider.parser().parseSilently((File) null);
            }
        });
    }

    @Test
    void testReaderInputValidation() {
        assertThrows(NullPointerException.class, new Executable() {
            @Override
            public void execute() throws IOException {
                HostsFileEntriesProvider.parser().parse((Reader) null);
            }
        });
    }
}
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.resolver;

import io.netty.util.CharsetUtil;
import io.netty.util.internal.ResourcesUtil;
import org.junit.jupiter.api.Test;

import java.io.BufferedReader;
import java.io.IOException;
import java.io.StringReader;
import java.net.Inet4Address;
import java.net.Inet6Address;
import java.nio.charset.Charset;
import java.nio.charset.UnsupportedCharsetException;
import java.util.Map;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertNotNull;

/**
 * Tests for the legacy {@link HostsFileParser}, which keeps only a single
 * address per host (first wins), unlike {@link HostsFileEntriesProvider}.
 * Also covers parsing files in non-default charsets.
 */
public class HostsFileParserTest {

    @Test
    public void testParse() throws IOException {
        String hostsString = new StringBuilder()
                .append("127.0.0.1 host1").append("\n") // single hostname, separated with blanks
                .append("::1 host1").append("\n") // same as above, but IPv6
                .append("\n") // empty line
                .append("192.168.0.1\thost2").append("\n") // single hostname, separated with tabs
                .append("#comment").append("\n") // comment at the beginning of the line
                .append(" #comment  ").append("\n") // comment in the middle of the line
                .append("192.168.0.2  host3  #comment").append("\n") // comment after hostname
                .append("192.168.0.3  host4  host5 host6").append("\n") // multiple aliases
                .append("192.168.0.4  host4").append("\n") // host mapped to a second address, must be ignored
                .append("192.168.0.5  HOST7").append("\n") // uppercase host, should match lowercase host
                .append("192.168.0.6  host7").append("\n") // should be ignored since we have the uppercase host already
                .toString();

        HostsFileEntries entries = HostsFileParser.parse(new BufferedReader(new StringReader(hostsString)));
        Map<String, Inet4Address> inet4Entries = entries.inet4Entries();
        Map<String, Inet6Address> inet6Entries = entries.inet6Entries();

        assertEquals(7, inet4Entries.size(), "Expected 7 IPv4 entries");
        assertEquals(1, inet6Entries.size(), "Expected 1 IPv6 entries");
        assertEquals("127.0.0.1", inet4Entries.get("host1").getHostAddress());
        assertEquals("192.168.0.1", inet4Entries.get("host2").getHostAddress());
        assertEquals("192.168.0.2", inet4Entries.get("host3").getHostAddress());
        // First mapping wins; the later 192.168.0.4 line for host4 is ignored.
        assertEquals("192.168.0.3", inet4Entries.get("host4").getHostAddress());
        assertEquals("192.168.0.3", inet4Entries.get("host5").getHostAddress());
        assertEquals("192.168.0.3", inet4Entries.get("host6").getHostAddress());
        assertNotNull(inet4Entries.get("host7"), "uppercase host doesn't resolve");
        assertEquals("192.168.0.5", inet4Entries.get("host7").getHostAddress());
        assertEquals("0:0:0:0:0:0:0:1", inet6Entries.get("host1").getHostAddress());
    }

    @Test
    public void testParseUnicode() throws IOException {
        final Charset unicodeCharset;
        try {
            unicodeCharset = Charset.forName("unicode");
        } catch (UnsupportedCharsetException e) {
            // JVM without the "unicode" charset: nothing to test.
            return;
        }
        testParseFile(HostsFileParser.parse(
                ResourcesUtil.getFile(getClass(), "hosts-unicode"), unicodeCharset));
    }

    @Test
    public void testParseMultipleCharsets() throws IOException {
        final Charset unicodeCharset;
        try {
            unicodeCharset = Charset.forName("unicode");
        } catch (UnsupportedCharsetException e) {
            // JVM without the "unicode" charset: nothing to test.
            return;
        }
        // Parser must fall through unsuitable charsets until one succeeds.
        testParseFile(HostsFileParser.parse(ResourcesUtil.getFile(getClass(), "hosts-unicode"),
                CharsetUtil.UTF_8, CharsetUtil.ISO_8859_1, unicodeCharset));
    }

    /** Shared assertions for the "hosts-unicode" fixture file. */
    private static void testParseFile(HostsFileEntries entries) throws IOException {
        Map<String, Inet4Address> inet4Entries = entries.inet4Entries();
        Map<String, Inet6Address> inet6Entries = entries.inet6Entries();

        assertEquals(2, inet4Entries.size(), "Expected 2 IPv4 entries");
        assertEquals(1, inet6Entries.size(), "Expected 1 IPv6 entries");
        assertEquals("127.0.0.1", inet4Entries.get("localhost").getHostAddress());
        assertEquals("255.255.255.255", inet4Entries.get("broadcasthost").getHostAddress());
        assertEquals("0:0:0:0:0:0:0:1", inet6Entries.get("localhost").getHostAddress());
    }
}
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.resolver;

import io.netty.util.concurrent.ImmediateEventExecutor;
import org.junit.jupiter.api.Test;

import java.net.InetAddress;

import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;

/**
 * Tests for {@link InetSocketAddressResolver}: closing the resolver must
 * close the wrapped {@link NameResolver} exactly once.
 */
public class InetSocketAddressResolverTest {

    @Test
    public void testCloseDelegates() {
        @SuppressWarnings("unchecked")
        NameResolver<InetAddress> nameResolver = mock(NameResolver.class);
        InetSocketAddressResolver resolver = new InetSocketAddressResolver(
                ImmediateEventExecutor.INSTANCE, nameResolver);
        resolver.close();
        verify(nameResolver, times(1)).close();
    }
}
ziH*z!RyfwCo(4Cq7~g7V+*M_Zv(`!PO8KI)c$EB+8AQd%Qzoj(gg64=lz^JCu&10|HO={0;?Pv-I#3lwX+;MnpK^m_1k{*djN#5LLC4A literal 0 HcmV?d00001 diff --git a/netty-resolver/src/test/resources/logging.properties b/netty-resolver/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-resolver/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-util/build.gradle b/netty-util/build.gradle new file mode 100644 index 0000000..a3eda71 --- /dev/null +++ b/netty-util/build.gradle @@ -0,0 +1,5 @@ +dependencies { + api project(':netty-jctools') + testImplementation testLibs.mockito.core + testImplementation testLibs.assertj +} diff --git a/netty-util/src/main/java/io/netty/util/AbstractConstant.java b/netty-util/src/main/java/io/netty/util/AbstractConstant.java new file mode 100644 index 0000000..de16653 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/AbstractConstant.java @@ -0,0 +1,89 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

import java.util.concurrent.atomic.AtomicLong;

/**
 * Base implementation of {@link Constant}.
 *
 * <p>Constants are compared by identity: {@link #equals(Object)} and
 * {@link #hashCode()} are final and delegate to {@link Object}, and
 * {@link #compareTo(AbstractConstant)} imposes a total order over distinct
 * instances using the identity hash code, tie-broken by a process-wide
 * monotonically increasing {@code uniquifier}.
 */
public abstract class AbstractConstant<T extends AbstractConstant<T>> implements Constant<T> {

    // Process-wide source of unique tie-breaker values for compareTo().
    private static final AtomicLong uniqueIdGenerator = new AtomicLong();
    private final int id;
    private final String name;
    private final long uniquifier;

    /**
     * Creates a new instance.
     *
     * @param id   numeric identifier assigned by the owning constant pool
     * @param name human-readable name of the constant
     */
    protected AbstractConstant(int id, String name) {
        this.id = id;
        this.name = name;
        this.uniquifier = uniqueIdGenerator.getAndIncrement();
    }

    @Override
    public final String name() {
        return name;
    }

    @Override
    public final int id() {
        return id;
    }

    @Override
    public final String toString() {
        return name();
    }

    @Override
    public final int hashCode() {
        // Identity hash: two constants are equal only if they are the same object.
        return super.hashCode();
    }

    @Override
    public final boolean equals(Object obj) {
        return super.equals(obj);
    }

    @Override
    public final int compareTo(T o) {
        if (this == o) {
            return 0;
        }

        @SuppressWarnings("UnnecessaryLocalVariable")
        AbstractConstant<T> other = o;
        int returnCode;

        returnCode = hashCode() - other.hashCode();
        if (returnCode != 0) {
            return returnCode;
        }

        // Identity hash collision: fall back to the unique creation sequence number.
        if (uniquifier < other.uniquifier) {
            return -1;
        }
        if (uniquifier > other.uniquifier) {
            return 1;
        }

        // Distinct instances can never share a uniquifier; reaching here is a JVM-level bug.
        throw new Error("failed to compare two different constants");
    }

}
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

import io.netty.util.internal.ReferenceCountUpdater;
import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;

/**
 * Abstract base class for classes that want to implement {@link ReferenceCounted}.
 *
 * <p>All reference-count mutation is funneled through a shared
 * {@link ReferenceCountUpdater}, which manipulates the {@code refCnt} field
 * either via {@code Unsafe} (when available) or via an
 * {@link AtomicIntegerFieldUpdater} fallback. The stored field value is an
 * encoded form; never read or write it directly.
 */
public abstract class AbstractReferenceCounted implements ReferenceCounted {
    // Unsafe offset of the refCnt field, or -1 when Unsafe is unavailable.
    private static final long REFCNT_FIELD_OFFSET =
            ReferenceCountUpdater.getUnsafeOffset(AbstractReferenceCounted.class, "refCnt");
    private static final AtomicIntegerFieldUpdater<AbstractReferenceCounted> AIF_UPDATER =
            AtomicIntegerFieldUpdater.newUpdater(AbstractReferenceCounted.class, "refCnt");

    private static final ReferenceCountUpdater<AbstractReferenceCounted> updater =
            new ReferenceCountUpdater<AbstractReferenceCounted>() {
                @Override
                protected AtomicIntegerFieldUpdater<AbstractReferenceCounted> updater() {
                    return AIF_UPDATER;
                }

                @Override
                protected long unsafeOffset() {
                    return REFCNT_FIELD_OFFSET;
                }
            };

    // Value might not equal the "real" reference count; all access must go via the updater.
    @SuppressWarnings({"unused", "FieldMayBeFinal"})
    private volatile int refCnt = updater.initialValue();

    @Override
    public int refCnt() {
        return updater.refCnt(this);
    }

    /**
     * An unsafe operation intended for use by a subclass that sets the reference count of the buffer directly.
     */
    protected final void setRefCnt(int refCnt) {
        updater.setRefCnt(this, refCnt);
    }

    @Override
    public ReferenceCounted retain() {
        return updater.retain(this);
    }

    @Override
    public ReferenceCounted retain(int increment) {
        return updater.retain(this, increment);
    }

    @Override
    public ReferenceCounted touch() {
        return touch(null);
    }

    @Override
    public boolean release() {
        return handleRelease(updater.release(this));
    }

    @Override
    public boolean release(int decrement) {
        return handleRelease(updater.release(this, decrement));
    }

    /** Deallocates exactly once, when the release dropped the count to zero. */
    private boolean handleRelease(boolean result) {
        if (result) {
            deallocate();
        }
        return result;
    }

    /**
     * Called once {@link #refCnt()} is equal to 0.
     */
    protected abstract void deallocate();
}
+ */ +package io.netty.util; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import java.nio.ByteBuffer; +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.nio.charset.CharsetEncoder; +import java.util.Arrays; +import java.util.Collection; +import java.util.List; +import java.util.regex.Pattern; +import java.util.regex.PatternSyntaxException; +import static io.netty.util.internal.MathUtil.isOutOfBounds; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A string which has been encoded into a character encoding whose character always takes a single byte, similarly to + * ASCII. It internally keeps its content in a byte array unlike {@link String}, which uses a character array, for + * reduced memory footprint and faster data transfer from/to byte-based data structures such as a byte array and + * {@link ByteBuffer}. It is often used in conjunction with {@code Headers} that require a {@link CharSequence}. + *

+ * This class was designed to provide an immutable array of bytes, and caches some internal state based upon the value + * of this array. However underlying access to this byte array is provided via not copying the array on construction or + * {@link #array()}. If any changes are made to the underlying byte array it is the user's responsibility to call + * {@link #arrayChanged()} so the state of this class can be reset. + */ +public final class AsciiString implements CharSequence, Comparable { + public static final AsciiString EMPTY_STRING = cached(""); + private static final char MAX_CHAR_VALUE = 255; + + public static final int INDEX_NOT_FOUND = -1; + + /** + * If this value is modified outside the constructor then call {@link #arrayChanged()}. + */ + private final byte[] value; + /** + * Offset into {@link #value} that all operations should use when acting upon {@link #value}. + */ + private final int offset; + /** + * Length in bytes for {@link #value} that we care about. This is independent from {@code value.length} + * because we may be looking at a subsection of the array. + */ + private final int length; + /** + * The hash code is cached after it is first computed. It can be reset with {@link #arrayChanged()}. + */ + private int hash; + /** + * Used to cache the {@link #toString()} value. + */ + private String string; + + /** + * Initialize this byte string based upon a byte array. A copy will be made. + */ + public AsciiString(byte[] value) { + this(value, true); + } + + /** + * Initialize this byte string based upon a byte array. + * {@code copy} determines if a copy is made or the array is shared. + */ + public AsciiString(byte[] value, boolean copy) { + this(value, 0, value.length, copy); + } + + /** + * Construct a new instance from a {@code byte[]} array. + * + * @param copy {@code true} then a copy of the memory will be made. {@code false} the underlying memory + * will be shared. 
+ */ + public AsciiString(byte[] value, int start, int length, boolean copy) { + if (copy) { + final byte[] rangedCopy = new byte[length]; + System.arraycopy(value, start, rangedCopy, 0, rangedCopy.length); + this.value = rangedCopy; + this.offset = 0; + } else { + if (isOutOfBounds(start, length, value.length)) { + throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= start + length(" + + length + ") <= " + "value.length(" + value.length + ')'); + } + this.value = value; + this.offset = start; + } + this.length = length; + } + + /** + * Create a copy of the underlying storage from {@code value}. + * The copy will start at {@link ByteBuffer#position()} and copy {@link ByteBuffer#remaining()} bytes. + */ + public AsciiString(ByteBuffer value) { + this(value, true); + } + + /** + * Initialize an instance based upon the underlying storage from {@code value}. + * There is a potential to share the underlying array storage if {@link ByteBuffer#hasArray()} is {@code true}. + * if {@code copy} is {@code true} a copy will be made of the memory. + * if {@code copy} is {@code false} the underlying storage will be shared, if possible. + */ + public AsciiString(ByteBuffer value, boolean copy) { + this(value, value.position(), value.remaining(), copy); + } + + /** + * Initialize an {@link AsciiString} based upon the underlying storage from {@code value}. + * There is a potential to share the underlying array storage if {@link ByteBuffer#hasArray()} is {@code true}. + * if {@code copy} is {@code true} a copy will be made of the memory. + * if {@code copy} is {@code false} the underlying storage will be shared, if possible. 
+ */ + public AsciiString(ByteBuffer value, int start, int length, boolean copy) { + if (isOutOfBounds(start, length, value.capacity())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= start + length(" + length + + ") <= " + "value.capacity(" + value.capacity() + ')'); + } + + if (value.hasArray()) { + if (copy) { + final int bufferOffset = value.arrayOffset() + start; + this.value = Arrays.copyOfRange(value.array(), bufferOffset, bufferOffset + length); + offset = 0; + } else { + this.value = value.array(); + this.offset = start; + } + } else { + this.value = PlatformDependent.allocateUninitializedArray(length); + int oldPos = value.position(); + value.get(this.value, 0, length); + value.position(oldPos); + this.offset = 0; + } + this.length = length; + } + + /** + * Create a copy of {@code value} into this instance assuming ASCII encoding. + */ + public AsciiString(char[] value) { + this(value, 0, value.length); + } + + /** + * Create a copy of {@code value} into this instance assuming ASCII encoding. + * The copy will start at index {@code start} and copy {@code length} bytes. + */ + public AsciiString(char[] value, int start, int length) { + if (isOutOfBounds(start, length, value.length)) { + throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= start + length(" + length + + ") <= " + "value.length(" + value.length + ')'); + } + + this.value = PlatformDependent.allocateUninitializedArray(length); + for (int i = 0, j = start; i < length; i++, j++) { + this.value[i] = c2b(value[j]); + } + this.offset = 0; + this.length = length; + } + + /** + * Create a copy of {@code value} into this instance using the encoding type of {@code charset}. + */ + public AsciiString(char[] value, Charset charset) { + this(value, charset, 0, value.length); + } + + /** + * Create a copy of {@code value} into a this instance using the encoding type of {@code charset}. 
+ * The copy will start at index {@code start} and copy {@code length} bytes. + */ + public AsciiString(char[] value, Charset charset, int start, int length) { + CharBuffer cbuf = CharBuffer.wrap(value, start, length); + CharsetEncoder encoder = CharsetUtil.encoder(charset); + ByteBuffer nativeBuffer = ByteBuffer.allocate((int) (encoder.maxBytesPerChar() * length)); + encoder.encode(cbuf, nativeBuffer, true); + final int bufferOffset = nativeBuffer.arrayOffset(); + this.value = Arrays.copyOfRange(nativeBuffer.array(), bufferOffset, bufferOffset + nativeBuffer.position()); + this.offset = 0; + this.length = this.value.length; + } + + /** + * Create a copy of {@code value} into this instance assuming ASCII encoding. + */ + public AsciiString(CharSequence value) { + this(value, 0, value.length()); + } + + /** + * Create a copy of {@code value} into this instance assuming ASCII encoding. + * The copy will start at index {@code start} and copy {@code length} bytes. + */ + public AsciiString(CharSequence value, int start, int length) { + if (isOutOfBounds(start, length, value.length())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= start + length(" + length + + ") <= " + "value.length(" + value.length() + ')'); + } + + this.value = PlatformDependent.allocateUninitializedArray(length); + for (int i = 0, j = start; i < length; i++, j++) { + this.value[i] = c2b(value.charAt(j)); + } + this.offset = 0; + this.length = length; + } + + /** + * Create a copy of {@code value} into this instance using the encoding type of {@code charset}. + */ + public AsciiString(CharSequence value, Charset charset) { + this(value, charset, 0, value.length()); + } + + /** + * Create a copy of {@code value} into this instance using the encoding type of {@code charset}. + * The copy will start at index {@code start} and copy {@code length} bytes. 
+ */ + public AsciiString(CharSequence value, Charset charset, int start, int length) { + CharBuffer cbuf = CharBuffer.wrap(value, start, start + length); + CharsetEncoder encoder = CharsetUtil.encoder(charset); + ByteBuffer nativeBuffer = ByteBuffer.allocate((int) (encoder.maxBytesPerChar() * length)); + encoder.encode(cbuf, nativeBuffer, true); + final int offset = nativeBuffer.arrayOffset(); + this.value = Arrays.copyOfRange(nativeBuffer.array(), offset, offset + nativeBuffer.position()); + this.offset = 0; + this.length = this.value.length; + } + + /** + * Iterates over the readable bytes of this buffer with the specified {@code processor} in ascending order. + * + * @return {@code -1} if the processor iterated to or beyond the end of the readable bytes. + * The last-visited index If the {@link ByteProcessor#process(byte)} returned {@code false}. + */ + public int forEachByte(ByteProcessor visitor) throws Exception { + return forEachByte0(0, length(), visitor); + } + + /** + * Iterates over the specified area of this buffer with the specified {@code processor} in ascending order. + * (i.e. {@code index}, {@code (index + 1)}, .. {@code (index + length - 1)}). + * + * @return {@code -1} if the processor iterated to or beyond the end of the specified area. + * The last-visited index If the {@link ByteProcessor#process(byte)} returned {@code false}. 
+     */
+    public int forEachByte(int index, int length, ByteProcessor visitor) throws Exception {
+        if (isOutOfBounds(index, length, length())) {
+            throw new IndexOutOfBoundsException("expected: " + "0 <= index(" + index + ") <= start + length(" + length
+                    + ") <= " + "length(" + length() + ')');
+        }
+        return forEachByte0(index, length, visitor);
+    }
+
+    // Bounds already validated by the callers; walks [offset + index, offset + index + length)
+    // and translates the absolute array index back to a logical index on early exit.
+    private int forEachByte0(int index, int length, ByteProcessor visitor) throws Exception {
+        final int len = offset + index + length;
+        for (int i = offset + index; i < len; ++i) {
+            if (!visitor.process(value[i])) {
+                return i - offset;
+            }
+        }
+        return -1;
+    }
+
+    /**
+     * Iterates over the readable bytes of this buffer with the specified {@code processor} in descending order.
+     *
+     * @return {@code -1} if the processor iterated to or beyond the beginning of the readable bytes.
+     *         The last-visited index if the {@link ByteProcessor#process(byte)} returned {@code false}.
+     */
+    public int forEachByteDesc(ByteProcessor visitor) throws Exception {
+        return forEachByteDesc0(0, length(), visitor);
+    }
+
+    /**
+     * Iterates over the specified area of this buffer with the specified {@code processor} in descending order.
+     * (i.e. {@code (index + length - 1)}, {@code (index + length - 2)}, ... {@code index}).
+     *
+     * @return {@code -1} if the processor iterated to or beyond the beginning of the specified area.
+     *         The last-visited index if the {@link ByteProcessor#process(byte)} returned {@code false}.
+ */ + public int forEachByteDesc(int index, int length, ByteProcessor visitor) throws Exception { + if (isOutOfBounds(index, length, length())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= index(" + index + ") <= start + length(" + length + + ") <= " + "length(" + length() + ')'); + } + return forEachByteDesc0(index, length, visitor); + } + + private int forEachByteDesc0(int index, int length, ByteProcessor visitor) throws Exception { + final int end = offset + index; + for (int i = offset + index + length - 1; i >= end; --i) { + if (!visitor.process(value[i])) { + return i - offset; + } + } + return -1; + } + + public byte byteAt(int index) { + // We must do a range check here to enforce the access does not go outside our sub region of the array. + // We rely on the array access itself to pick up the array out of bounds conditions + if (index < 0 || index >= length) { + throw new IndexOutOfBoundsException("index: " + index + " must be in the range [0," + length + ")"); + } + // Try to use unsafe to avoid double checking the index bounds + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.getByte(value, index + offset); + } + return value[index + offset]; + } + + /** + * Determine if this instance has 0 length. + */ + public boolean isEmpty() { + return length == 0; + } + + /** + * The length in bytes of this instance. + */ + @Override + public int length() { + return length; + } + + /** + * During normal use cases the {@link AsciiString} should be immutable, but if the underlying array is shared, + * and changes then this needs to be called. + */ + public void arrayChanged() { + string = null; + hash = 0; + } + + /** + * This gives direct access to the underlying storage array. + * The {@link #toByteArray()} should be preferred over this method. + * If the return value is changed then {@link #arrayChanged()} must be called. 
+ * + * @see #arrayOffset() + * @see #isEntireArrayUsed() + */ + public byte[] array() { + return value; + } + + /** + * The offset into {@link #array()} for which data for this ByteString begins. + * + * @see #array() + * @see #isEntireArrayUsed() + */ + public int arrayOffset() { + return offset; + } + + /** + * Determine if the storage represented by {@link #array()} is entirely used. + * + * @see #array() + */ + public boolean isEntireArrayUsed() { + return offset == 0 && length == value.length; + } + + /** + * Converts this string to a byte array. + */ + public byte[] toByteArray() { + return toByteArray(0, length()); + } + + /** + * Converts a subset of this string to a byte array. + * The subset is defined by the range [{@code start}, {@code end}). + */ + public byte[] toByteArray(int start, int end) { + return Arrays.copyOfRange(value, start + offset, end + offset); + } + + /** + * Copies the content of this string to a byte array. + * + * @param srcIdx the starting offset of characters to copy. + * @param dst the destination byte array. + * @param dstIdx the starting offset in the destination byte array. + * @param length the number of characters to copy. + */ + public void copy(int srcIdx, byte[] dst, int dstIdx, int length) { + if (isOutOfBounds(srcIdx, length, length())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= srcIdx(" + srcIdx + ") <= srcIdx + length(" + + length + ") <= srcLen(" + length() + ')'); + } + + System.arraycopy(value, srcIdx + offset, checkNotNull(dst, "dst"), dstIdx, length); + } + + @Override + public char charAt(int index) { + return b2c(byteAt(index)); + } + + /** + * Determines if this {@code String} contains the sequence of characters in the {@code CharSequence} passed. + * + * @param cs the character sequence to search for. + * @return {@code true} if the sequence of characters are contained in this string, otherwise {@code false}. 
+ */ + public boolean contains(CharSequence cs) { + return indexOf(cs) >= 0; + } + + /** + * Compares the specified string to this string using the ASCII values of the characters. Returns 0 if the strings + * contain the same characters in the same order. Returns a negative integer if the first non-equal character in + * this string has an ASCII value which is less than the ASCII value of the character at the same position in the + * specified string, or if this string is a prefix of the specified string. Returns a positive integer if the first + * non-equal character in this string has a ASCII value which is greater than the ASCII value of the character at + * the same position in the specified string, or if the specified string is a prefix of this string. + * + * @param string the string to compare. + * @return 0 if the strings are equal, a negative integer if this string is before the specified string, or a + * positive integer if this string is after the specified string. + * @throws NullPointerException if {@code string} is {@code null}. + */ + @Override + public int compareTo(CharSequence string) { + if (this == string) { + return 0; + } + + int result; + int length1 = length(); + int length2 = string.length(); + int minLength = Math.min(length1, length2); + for (int i = 0, j = arrayOffset(); i < minLength; i++, j++) { + result = b2c(value[j]) - string.charAt(i); + if (result != 0) { + return result; + } + } + + return length1 - length2; + } + + /** + * Concatenates this string and the specified string. + * + * @param string the string to concatenate + * @return a new string which is the concatenation of this string and the specified string. 
+ */ + public AsciiString concat(CharSequence string) { + int thisLen = length(); + int thatLen = string.length(); + if (thatLen == 0) { + return this; + } + + if (string instanceof AsciiString that) { + if (isEmpty()) { + return that; + } + + byte[] newValue = PlatformDependent.allocateUninitializedArray(thisLen + thatLen); + System.arraycopy(value, arrayOffset(), newValue, 0, thisLen); + System.arraycopy(that.value, that.arrayOffset(), newValue, thisLen, thatLen); + return new AsciiString(newValue, false); + } + + if (isEmpty()) { + return new AsciiString(string); + } + + byte[] newValue = PlatformDependent.allocateUninitializedArray(thisLen + thatLen); + System.arraycopy(value, arrayOffset(), newValue, 0, thisLen); + for (int i = thisLen, j = 0; i < newValue.length; i++, j++) { + newValue[i] = c2b(string.charAt(j)); + } + + return new AsciiString(newValue, false); + } + + /** + * Compares the specified string to this string to determine if the specified string is a suffix. + * + * @param suffix the suffix to look for. + * @return {@code true} if the specified string is a suffix of this string, {@code false} otherwise. + * @throws NullPointerException if {@code suffix} is {@code null}. + */ + public boolean endsWith(CharSequence suffix) { + int suffixLen = suffix.length(); + return regionMatches(length() - suffixLen, suffix, 0, suffixLen); + } + + /** + * Compares the specified string to this string ignoring the case of the characters and returns true if they are + * equal. + * + * @param string the string to compare. + * @return {@code true} if the specified string is equal to this string, {@code false} otherwise. 
+     */
+    public boolean contentEqualsIgnoreCase(CharSequence string) {
+        if (this == string) {
+            return true;
+        }
+
+        if (string == null || string.length() != length()) {
+            return false;
+        }
+
+        if (string instanceof AsciiString rhs) {
+            // Fast path: compare the two backing byte arrays directly, byte by byte.
+            for (int i = arrayOffset(), j = rhs.arrayOffset(), end = i + length(); i < end; ++i, ++j) {
+                if (!equalsIgnoreCase(value[i], rhs.value[j])) {
+                    return false;
+                }
+            }
+            return true;
+        }
+
+        // Generic CharSequence: widen our bytes to chars before the case-insensitive compare.
+        for (int i = arrayOffset(), j = 0, end = length(); j < end; ++i, ++j) {
+            if (!equalsIgnoreCase(b2c(value[i]), string.charAt(j))) {
+                return false;
+            }
+        }
+        return true;
+    }
+
+    /**
+     * Copies the characters in this string to a character array.
+     *
+     * @return a character array containing the characters of this string.
+     */
+    public char[] toCharArray() {
+        return toCharArray(0, length());
+    }
+
+    /**
+     * Copies the characters in this string to a character array.
+     *
+     * @param start the starting offset (inclusive).
+     * @param end the ending offset (exclusive).
+     * @return a character array containing the characters of this string.
+     */
+    public char[] toCharArray(int start, int end) {
+        int length = end - start;
+        if (length == 0) {
+            return EmptyArrays.EMPTY_CHARS;
+        }
+
+        if (isOutOfBounds(start, length, length())) {
+            throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= srcIdx + length("
+                    + length + ") <= srcLen(" + length() + ')');
+        }
+
+        final char[] buffer = new char[length];
+        for (int i = 0, j = start + arrayOffset(); i < length; i++, j++) {
+            buffer[i] = b2c(value[j]);
+        }
+        return buffer;
+    }
+
+    /**
+     * Copies the content of this string to a character array.
+     *
+     * @param srcIdx the starting offset of characters to copy.
+     * @param dst the destination character array.
+     * @param dstIdx the starting offset in the destination byte array.
+     * @param length the number of characters to copy.
+ */ + public void copy(int srcIdx, char[] dst, int dstIdx, int length) { + ObjectUtil.checkNotNull(dst, "dst"); + + if (isOutOfBounds(srcIdx, length, length())) { + throw new IndexOutOfBoundsException("expected: " + "0 <= srcIdx(" + srcIdx + ") <= srcIdx + length(" + + length + ") <= srcLen(" + length() + ')'); + } + + final int dstEnd = dstIdx + length; + for (int i = dstIdx, j = srcIdx + arrayOffset(); i < dstEnd; i++, j++) { + dst[i] = b2c(value[j]); + } + } + + /** + * Copies a range of characters into a new string. + * + * @param start the offset of the first character (inclusive). + * @return a new string containing the characters from start to the end of the string. + * @throws IndexOutOfBoundsException if {@code start < 0} or {@code start > length()}. + */ + public AsciiString subSequence(int start) { + return subSequence(start, length()); + } + + /** + * Copies a range of characters into a new string. + * + * @param start the offset of the first character (inclusive). + * @param end The index to stop at (exclusive). + * @return a new string containing the characters from start to the end of the string. + * @throws IndexOutOfBoundsException if {@code start < 0} or {@code start > length()}. + */ + @Override + public AsciiString subSequence(int start, int end) { + return subSequence(start, end, true); + } + + /** + * Either copy or share a subset of underlying sub-sequence of bytes. + * + * @param start the offset of the first character (inclusive). + * @param end The index to stop at (exclusive). + * @param copy If {@code true} then a copy of the underlying storage will be made. + * If {@code false} then the underlying storage will be shared. + * @return a new string containing the characters from start to the end of the string. + * @throws IndexOutOfBoundsException if {@code start < 0} or {@code start > length()}. 
+ */ + public AsciiString subSequence(int start, int end, boolean copy) { + if (isOutOfBounds(start, end - start, length())) { + throw new IndexOutOfBoundsException("expected: 0 <= start(" + start + ") <= end (" + end + ") <= length(" + + length() + ')'); + } + + if (start == 0 && end == length()) { + return this; + } + + if (end == start) { + return EMPTY_STRING; + } + + return new AsciiString(value, start + offset, end - start, copy); + } + + /** + * Searches in this string for the first index of the specified string. The search for the string starts at the + * beginning and moves towards the end of this string. + * + * @param string the string to find. + * @return the index of the first character of the specified string in this string, -1 if the specified string is + * not a substring. + * @throws NullPointerException if {@code string} is {@code null}. + */ + public int indexOf(CharSequence string) { + return indexOf(string, 0); + } + + /** + * Searches in this string for the index of the specified string. The search for the string starts at the specified + * offset and moves towards the end of this string. + * + * @param subString the string to find. + * @param start the starting offset. + * @return the index of the first character of the specified string in this string, -1 if the specified string is + * not a substring. + * @throws NullPointerException if {@code subString} is {@code null}. + */ + public int indexOf(CharSequence subString, int start) { + final int subCount = subString.length(); + if (start < 0) { + start = 0; + } + if (subCount <= 0) { + return start < length ? 
start : length; + } + if (subCount > length - start) { + return INDEX_NOT_FOUND; + } + + final char firstChar = subString.charAt(0); + if (firstChar > MAX_CHAR_VALUE) { + return INDEX_NOT_FOUND; + } + final byte firstCharAsByte = c2b0(firstChar); + final int len = offset + length - subCount; + for (int i = start + offset; i <= len; ++i) { + if (value[i] == firstCharAsByte) { + int o1 = i, o2 = 0; + while (++o2 < subCount && b2c(value[++o1]) == subString.charAt(o2)) { + // Intentionally empty + } + if (o2 == subCount) { + return i - offset; + } + } + } + return INDEX_NOT_FOUND; + } + + /** + * Searches in this string for the index of the specified char {@code ch}. + * The search for the char starts at the specified offset {@code start} and moves towards the end of this string. + * + * @param ch the char to find. + * @param start the starting offset. + * @return the index of the first occurrence of the specified char {@code ch} in this string, + * -1 if found no occurrence. + */ + public int indexOf(char ch, int start) { + if (ch > MAX_CHAR_VALUE) { + return INDEX_NOT_FOUND; + } + + if (start < 0) { + start = 0; + } + + final byte chAsByte = c2b0(ch); + final int len = offset + length; + for (int i = start + offset; i < len; ++i) { + if (value[i] == chAsByte) { + return i - offset; + } + } + return INDEX_NOT_FOUND; + } + + /** + * Searches in this string for the last index of the specified string. The search for the string starts at the end + * and moves towards the beginning of this string. + * + * @param string the string to find. + * @return the index of the first character of the specified string in this string, -1 if the specified string is + * not a substring. + * @throws NullPointerException if {@code string} is {@code null}. + */ + public int lastIndexOf(CharSequence string) { + // Use count instead of count - 1 so lastIndexOf("") answers count + return lastIndexOf(string, length); + } + + /** + * Searches in this string for the index of the specified string. 
The search for the string starts at the specified + * offset and moves towards the beginning of this string. + * + * @param subString the string to find. + * @param start the starting offset. + * @return the index of the first character of the specified string in this string , -1 if the specified string is + * not a substring. + * @throws NullPointerException if {@code subString} is {@code null}. + */ + public int lastIndexOf(CharSequence subString, int start) { + final int subCount = subString.length(); + start = Math.min(start, length - subCount); + if (start < 0) { + return INDEX_NOT_FOUND; + } + if (subCount == 0) { + return start; + } + + final char firstChar = subString.charAt(0); + if (firstChar > MAX_CHAR_VALUE) { + return INDEX_NOT_FOUND; + } + final byte firstCharAsByte = c2b0(firstChar); + for (int i = offset + start; i >= offset; --i) { + if (value[i] == firstCharAsByte) { + int o1 = i, o2 = 0; + while (++o2 < subCount && b2c(value[++o1]) == subString.charAt(o2)) { + // Intentionally empty + } + if (o2 == subCount) { + return i - offset; + } + } + } + return INDEX_NOT_FOUND; + } + + /** + * Compares the specified string to this string and compares the specified range of characters to determine if they + * are the same. + * + * @param thisStart the starting offset in this string. + * @param string the string to compare. + * @param start the starting offset in the specified string. + * @param length the number of characters to compare. + * @return {@code true} if the ranges of characters are equal, {@code false} otherwise + * @throws NullPointerException if {@code string} is {@code null}. 
+ */ + public boolean regionMatches(int thisStart, CharSequence string, int start, int length) { + ObjectUtil.checkNotNull(string, "string"); + + if (start < 0 || string.length() - start < length) { + return false; + } + + final int thisLen = length(); + if (thisStart < 0 || thisLen - thisStart < length) { + return false; + } + + if (length <= 0) { + return true; + } + + final int thatEnd = start + length; + for (int i = start, j = thisStart + arrayOffset(); i < thatEnd; i++, j++) { + if (b2c(value[j]) != string.charAt(i)) { + return false; + } + } + return true; + } + + /** + * Compares the specified string to this string and compares the specified range of characters to determine if they + * are the same. When ignoreCase is true, the case of the characters is ignored during the comparison. + * + * @param ignoreCase specifies if case should be ignored. + * @param thisStart the starting offset in this string. + * @param string the string to compare. + * @param start the starting offset in the specified string. + * @param length the number of characters to compare. + * @return {@code true} if the ranges of characters are equal, {@code false} otherwise. + * @throws NullPointerException if {@code string} is {@code null}. + */ + public boolean regionMatches(boolean ignoreCase, int thisStart, CharSequence string, int start, int length) { + if (!ignoreCase) { + return regionMatches(thisStart, string, start, length); + } + + ObjectUtil.checkNotNull(string, "string"); + + final int thisLen = length(); + if (thisStart < 0 || length > thisLen - thisStart) { + return false; + } + if (start < 0 || length > string.length() - start) { + return false; + } + + thisStart += arrayOffset(); + final int thisEnd = thisStart + length; + while (thisStart < thisEnd) { + if (!equalsIgnoreCase(b2c(value[thisStart++]), string.charAt(start++))) { + return false; + } + } + return true; + } + + /** + * Copies this string replacing occurrences of the specified character with another character. 
+ * + * @param oldChar the character to replace. + * @param newChar the replacement character. + * @return a new string with occurrences of oldChar replaced by newChar. + */ + public AsciiString replace(char oldChar, char newChar) { + if (oldChar > MAX_CHAR_VALUE) { + return this; + } + + final byte oldCharAsByte = c2b0(oldChar); + final byte newCharAsByte = c2b(newChar); + final int len = offset + length; + for (int i = offset; i < len; ++i) { + if (value[i] == oldCharAsByte) { + byte[] buffer = PlatformDependent.allocateUninitializedArray(length()); + System.arraycopy(value, offset, buffer, 0, i - offset); + buffer[i - offset] = newCharAsByte; + ++i; + for (; i < len; ++i) { + byte oldValue = value[i]; + buffer[i - offset] = oldValue != oldCharAsByte ? oldValue : newCharAsByte; + } + return new AsciiString(buffer, false); + } + } + return this; + } + + /** + * Compares the specified string to this string to determine if the specified string is a prefix. + * + * @param prefix the string to look for. + * @return {@code true} if the specified string is a prefix of this string, {@code false} otherwise + * @throws NullPointerException if {@code prefix} is {@code null}. + */ + public boolean startsWith(CharSequence prefix) { + return startsWith(prefix, 0); + } + + /** + * Compares the specified string to this string, starting at the specified offset, to determine if the specified + * string is a prefix. + * + * @param prefix the string to look for. + * @param start the starting offset. + * @return {@code true} if the specified string occurs in this string at the specified offset, {@code false} + * otherwise. + * @throws NullPointerException if {@code prefix} is {@code null}. + */ + public boolean startsWith(CharSequence prefix, int start) { + return regionMatches(start, prefix, 0, prefix.length()); + } + + /** + * Converts the characters in this string to lowercase, using the default Locale. 
+ * + * @return a new string containing the lowercase characters equivalent to the characters in this string. + */ + public AsciiString toLowerCase() { + boolean lowercased = true; + int i, j; + final int len = length() + arrayOffset(); + for (i = arrayOffset(); i < len; ++i) { + byte b = value[i]; + if (b >= 'A' && b <= 'Z') { + lowercased = false; + break; + } + } + + // Check if this string does not contain any uppercase characters. + if (lowercased) { + return this; + } + + final byte[] newValue = PlatformDependent.allocateUninitializedArray(length()); + for (i = 0, j = arrayOffset(); i < newValue.length; ++i, ++j) { + newValue[i] = toLowerCase(value[j]); + } + + return new AsciiString(newValue, false); + } + + /** + * Converts the characters in this string to uppercase, using the default Locale. + * + * @return a new string containing the uppercase characters equivalent to the characters in this string. + */ + public AsciiString toUpperCase() { + boolean uppercased = true; + int i, j; + final int len = length() + arrayOffset(); + for (i = arrayOffset(); i < len; ++i) { + byte b = value[i]; + if (b >= 'a' && b <= 'z') { + uppercased = false; + break; + } + } + + // Check if this string does not contain any lowercase characters. + if (uppercased) { + return this; + } + + final byte[] newValue = PlatformDependent.allocateUninitializedArray(length()); + for (i = 0, j = arrayOffset(); i < newValue.length; ++i, ++j) { + newValue[i] = toUpperCase(value[j]); + } + + return new AsciiString(newValue, false); + } + + /** + * Copies this string removing white space characters from the beginning and end of the string, and tries not to + * copy if possible. + * + * @param c The {@link CharSequence} to trim. + * @return a new string with characters {@code <= \\u0020} removed from the beginning and the end. 
+ */ + public static CharSequence trim(CharSequence c) { + if (c instanceof AsciiString) { + return ((AsciiString) c).trim(); + } + if (c instanceof String) { + return ((String) c).trim(); + } + int start = 0, last = c.length() - 1; + int end = last; + while (start <= end && c.charAt(start) <= ' ') { + start++; + } + while (end >= start && c.charAt(end) <= ' ') { + end--; + } + if (start == 0 && end == last) { + return c; + } + return c.subSequence(start, end); + } + + /** + * Duplicates this string removing white space characters from the beginning and end of the + * string, without copying. + * + * @return a new string with characters {@code <= \\u0020} removed from the beginning and the end. + */ + public AsciiString trim() { + int start = arrayOffset(), last = arrayOffset() + length() - 1; + int end = last; + while (start <= end && value[start] <= ' ') { + start++; + } + while (end >= start && value[end] <= ' ') { + end--; + } + if (start == 0 && end == last) { + return this; + } + return new AsciiString(value, start, end - start + 1, false); + } + + /** + * Compares a {@code CharSequence} to this {@code String} to determine if their contents are equal. + * + * @param a the character sequence to compare to. + * @return {@code true} if equal, otherwise {@code false} + */ + public boolean contentEquals(CharSequence a) { + if (this == a) { + return true; + } + + if (a == null || a.length() != length()) { + return false; + } + if (a instanceof AsciiString) { + return equals(a); + } + + for (int i = arrayOffset(), j = 0; j < a.length(); ++i, ++j) { + if (b2c(value[i]) != a.charAt(j)) { + return false; + } + } + return true; + } + + /** + * Determines whether this string matches a given regular expression. + * + * @param expr the regular expression to be matched. + * @return {@code true} if the expression matches, otherwise {@code false}. + * @throws PatternSyntaxException if the syntax of the supplied regular expression is not valid. 
+ * @throws NullPointerException if {@code expr} is {@code null}. + */ + public boolean matches(String expr) { + return Pattern.matches(expr, this); + } + + /** + * Splits this string using the supplied regular expression {@code expr}. The parameter {@code max} controls the + * behavior how many times the pattern is applied to the string. + * + * @param expr the regular expression used to divide the string. + * @param max the number of entries in the resulting array. + * @return an array of Strings created by separating the string along matches of the regular expression. + * @throws NullPointerException if {@code expr} is {@code null}. + * @throws PatternSyntaxException if the syntax of the supplied regular expression is not valid. + * @see Pattern#split(CharSequence, int) + */ + public AsciiString[] split(String expr, int max) { + return toAsciiStringArray(Pattern.compile(expr).split(this, max)); + } + + /** + * Splits the specified {@link String} with the specified delimiter.. + */ + public AsciiString[] split(char delim) { + final List res = InternalThreadLocalMap.get().arrayList(); + + int start = 0; + final int length = length(); + for (int i = start; i < length; i++) { + if (charAt(i) == delim) { + if (start == i) { + res.add(EMPTY_STRING); + } else { + res.add(new AsciiString(value, start + arrayOffset(), i - start, false)); + } + start = i + 1; + } + } + + if (start == 0) { // If no delimiter was found in the value + res.add(this); + } else { + if (start != length) { + // Add the last element if it's not empty. + res.add(new AsciiString(value, start + arrayOffset(), length - start, false)); + } else { + // Truncate trailing empty elements. + for (int i = res.size() - 1; i >= 0; i--) { + if (res.get(i).isEmpty()) { + res.remove(i); + } else { + break; + } + } + } + } + + return res.toArray(EmptyArrays.EMPTY_ASCII_STRINGS); + } + + /** + * {@inheritDoc} + *

     * Provides a case-insensitive hash code for Ascii like byte strings.
     */
    @Override
    public int hashCode() {
        // Lazily computed and cached; 0 doubles as the "not yet computed" sentinel.
        int h = hash;
        if (h == 0) {
            h = PlatformDependent.hashCodeAscii(value, offset, length);
            hash = h;
        }
        return h;
    }

    @Override
    public boolean equals(Object obj) {
        // Exact-class check (not instanceof): only another AsciiString can be equal.
        if (obj == null || obj.getClass() != AsciiString.class) {
            return false;
        }
        if (this == obj) {
            return true;
        }

        AsciiString other = (AsciiString) obj;
        // Cheap rejections first (length, cached hash) before the byte-wise comparison.
        return length() == other.length() &&
               hashCode() == other.hashCode() &&
               PlatformDependent.equals(array(), arrayOffset(), other.array(), other.arrayOffset(), length());
    }

    /**
     * Translates the entire byte string to a {@link String}.
     *
     * @see #toString(int)
     */
    @Override
    public String toString() {
        // The decoded form is cached in 'string' after the first call.
        String cache = string;
        if (cache == null) {
            cache = toString(0);
            string = cache;
        }
        return cache;
    }

    /**
     * Translates the entire byte string to a {@link String} using the {@code charset} encoding.
     *
     * @see #toString(int, int)
     */
    public String toString(int start) {
        return toString(start, length());
    }

    /**
     * Translates the [{@code start}, {@code end}) range of this byte string to a {@link String}.
     */
    public String toString(int start, int end) {
        int length = end - start;
        if (length == 0) {
            return "";
        }

        if (isOutOfBounds(start, length, length())) {
            throw new IndexOutOfBoundsException("expected: " + "0 <= start(" + start + ") <= srcIdx + length("
                    + length + ") <= srcLen(" + length() + ')');
        }

        // Deliberately uses the deprecated String(byte[], int, int, int) ctor: each byte becomes
        // the low 8 bits of a char, which is exactly the 8-bit ASCII semantic wanted here.
        @SuppressWarnings("deprecation") final String str = new String(value, 0, start + offset, length);
        return str;
    }

    // Any non-zero first byte is treated as true.
    public boolean parseBoolean() {
        return length >= 1 && value[offset] != 0;
    }

    public char parseChar() {
        return parseChar(0);
    }

    // Builds a char from two consecutive bytes (big-endian).
    public char parseChar(int start) {
        if (start + 1 >= length()) {
            throw new IndexOutOfBoundsException("2 bytes required to convert to character. index " +
                    start + " would go out of bounds.");
        }
        final int startWithOffset = start + offset;
        return (char) ((b2c(value[startWithOffset]) << 8) | b2c(value[startWithOffset + 1]));
    }

    public short parseShort() {
        return parseShort(0, length(), 10);
    }

    public short parseShort(int radix) {
        return parseShort(0, length(), radix);
    }

    public short parseShort(int start, int end) {
        return parseShort(start, end, 10);
    }

    // Parses via parseInt and rejects values that do not survive the narrowing cast.
    public short parseShort(int start, int end, int radix) {
        int intValue = parseInt(start, end, radix);
        short result = (short) intValue;
        if (result != intValue) {
            throw new NumberFormatException(subSequence(start, end, false).toString());
        }
        return result;
    }

    public int parseInt() {
        return parseInt(0, length(), 10);
    }

    public int parseInt(int radix) {
        return parseInt(0, length(), radix);
    }

    public int parseInt(int start, int end) {
        return parseInt(start, end, 10);
    }

    public int parseInt(int start, int end, int radix) {
        if (radix < Character.MIN_RADIX || radix > Character.MAX_RADIX) {
            throw new NumberFormatException();
        }

        if (start == end) {
            throw new NumberFormatException();
        }

        int i = start;
        boolean negative = byteAt(i) == '-';
        // A lone '-' with no digits is invalid.
        if (negative && ++i == end) {
            throw new NumberFormatException(subSequence(start, end, false).toString());
        }

        return parseInt(i, end, radix, negative);
    }

    // Accumulates NEGATIVELY so that Integer.MIN_VALUE (whose magnitude exceeds MAX_VALUE)
    // can be represented; the two range checks detect overflow before/after each step.
    private int parseInt(int start, int end, int radix, boolean negative) {
        int max = Integer.MIN_VALUE / radix;
        int result = 0;
        int currOffset = start;
        while (currOffset < end) {
            int digit = Character.digit((char) (value[currOffset++ + offset] & 0xFF), radix);
            if (digit == -1) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
            if (max > result) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
            int next = result * radix - digit;
            if (next > result) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
            result = next;
        }
        // Negate back for positive inputs; a still-negative result means overflow.
        if (!negative) {
            result = -result;
            if (result < 0) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
        }
        return result;
    }

    public long parseLong() {
        return parseLong(0, length(), 10);
    }

    public long parseLong(int radix) {
        return parseLong(0, length(), radix);
    }

    public long parseLong(int start, int end) {
        return parseLong(start, end, 10);
    }

    public long parseLong(int start, int end, int radix) {
        if (radix < Character.MIN_RADIX || radix > Character.MAX_RADIX) {
            throw new NumberFormatException();
        }

        if (start == end) {
            throw new NumberFormatException();
        }

        int i = start;
        boolean negative = byteAt(i) == '-';
        // A lone '-' with no digits is invalid.
        if (negative && ++i == end) {
            throw new NumberFormatException(subSequence(start, end, false).toString());
        }

        return parseLong(i, end, radix, negative);
    }

    // Same negative-accumulation overflow strategy as parseInt(int, int, int, boolean), widened to long.
    private long parseLong(int start, int end, int radix, boolean negative) {
        long max = Long.MIN_VALUE / radix;
        long result = 0;
        int currOffset = start;
        while (currOffset < end) {
            int digit = Character.digit((char) (value[currOffset++ + offset] & 0xFF), radix);
            if (digit == -1) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
            if (max > result) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
            long next = result * radix - digit;
            if (next > result) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
            result = next;
        }
        if (!negative) {
            result = -result;
            if (result < 0) {
                throw new NumberFormatException(subSequence(start, end, false).toString());
            }
        }
        return result;
    }

    public float parseFloat() {
        return parseFloat(0, length());
    }

    // Delegates to the JDK after decoding the range to a String.
    public float parseFloat(int start, int end) {
        return Float.parseFloat(toString(start, end));
    }

    public double parseDouble() {
        return parseDouble(0, length());
    }

    public double parseDouble(int
start, int end) { + return Double.parseDouble(toString(start, end)); + } + + public static final HashingStrategy CASE_INSENSITIVE_HASHER = + new HashingStrategy() { + @Override + public int hashCode(CharSequence o) { + return AsciiString.hashCode(o); + } + + @Override + public boolean equals(CharSequence a, CharSequence b) { + return AsciiString.contentEqualsIgnoreCase(a, b); + } + }; + + public static final HashingStrategy CASE_SENSITIVE_HASHER = + new HashingStrategy() { + @Override + public int hashCode(CharSequence o) { + return AsciiString.hashCode(o); + } + + @Override + public boolean equals(CharSequence a, CharSequence b) { + return AsciiString.contentEquals(a, b); + } + }; + + /** + * Returns an {@link AsciiString} containing the given character sequence. If the given string is already a + * {@link AsciiString}, just returns the same instance. + */ + public static AsciiString of(CharSequence string) { + return string instanceof AsciiString ? (AsciiString) string : new AsciiString(string); + } + + /** + * Returns an {@link AsciiString} containing the given string and retains/caches the input + * string for later use in {@link #toString()}. + * Used for the constants (which already stored in the JVM's string table) and in cases + * where the guaranteed use of the {@link #toString()} method. + */ + public static AsciiString cached(String string) { + AsciiString asciiString = new AsciiString(string); + asciiString.string = string; + return asciiString; + } + + /** + * Returns the case-insensitive hash code of the specified string. Note that this method uses the same hashing + * algorithm with {@link #hashCode()} so that you can put both {@link AsciiString}s and arbitrary + * {@link CharSequence}s into the same headers. 
+ */ + public static int hashCode(CharSequence value) { + if (value == null) { + return 0; + } + if (value instanceof AsciiString) { + return value.hashCode(); + } + + return PlatformDependent.hashCodeAscii(value); + } + + /** + * Determine if {@code a} contains {@code b} in a case sensitive manner. + */ + public static boolean contains(CharSequence a, CharSequence b) { + return contains(a, b, DefaultCharEqualityComparator.INSTANCE); + } + + /** + * Determine if {@code a} contains {@code b} in a case insensitive manner. + */ + public static boolean containsIgnoreCase(CharSequence a, CharSequence b) { + return contains(a, b, AsciiCaseInsensitiveCharEqualityComparator.INSTANCE); + } + + /** + * Returns {@code true} if both {@link CharSequence}'s are equals when ignore the case. This only supports 8-bit + * ASCII. + */ + public static boolean contentEqualsIgnoreCase(CharSequence a, CharSequence b) { + if (a == null || b == null) { + return a == b; + } + + if (a instanceof AsciiString) { + return ((AsciiString) a).contentEqualsIgnoreCase(b); + } + if (b instanceof AsciiString) { + return ((AsciiString) b).contentEqualsIgnoreCase(a); + } + + if (a.length() != b.length()) { + return false; + } + for (int i = 0; i < a.length(); ++i) { + if (!equalsIgnoreCase(a.charAt(i), b.charAt(i))) { + return false; + } + } + return true; + } + + /** + * Determine if {@code collection} contains {@code value} and using + * {@link #contentEqualsIgnoreCase(CharSequence, CharSequence)} to compare values. + * + * @param collection The collection to look for and equivalent element as {@code value}. + * @param value The value to look for in {@code collection}. + * @return {@code true} if {@code collection} contains {@code value} according to + * {@link #contentEqualsIgnoreCase(CharSequence, CharSequence)}. {@code false} otherwise. 
+ * @see #contentEqualsIgnoreCase(CharSequence, CharSequence) + */ + public static boolean containsContentEqualsIgnoreCase(Collection collection, CharSequence value) { + for (CharSequence v : collection) { + if (contentEqualsIgnoreCase(value, v)) { + return true; + } + } + return false; + } + + /** + * Determine if {@code a} contains all of the values in {@code b} using + * {@link #contentEqualsIgnoreCase(CharSequence, CharSequence)} to compare values. + * + * @param a The collection under test. + * @param b The values to test for. + * @return {@code true} if {@code a} contains all of the values in {@code b} using + * {@link #contentEqualsIgnoreCase(CharSequence, CharSequence)} to compare values. {@code false} otherwise. + * @see #contentEqualsIgnoreCase(CharSequence, CharSequence) + */ + public static boolean containsAllContentEqualsIgnoreCase(Collection a, Collection b) { + for (CharSequence v : b) { + if (!containsContentEqualsIgnoreCase(a, v)) { + return false; + } + } + return true; + } + + /** + * Returns {@code true} if the content of both {@link CharSequence}'s are equals. This only supports 8-bit ASCII. 
     */
    public static boolean contentEquals(CharSequence a, CharSequence b) {
        // Both null compare equal; a single null does not.
        if (a == null || b == null) {
            return a == b;
        }

        // Prefer the optimized AsciiString implementation when either side provides one.
        if (a instanceof AsciiString) {
            return ((AsciiString) a).contentEquals(b);
        }

        if (b instanceof AsciiString) {
            return ((AsciiString) b).contentEquals(a);
        }

        if (a.length() != b.length()) {
            return false;
        }
        for (int i = 0; i < a.length(); ++i) {
            if (a.charAt(i) != b.charAt(i)) {
                return false;
            }
        }
        return true;
    }

    // Wraps each String produced by a JDK API (e.g. String#split) into an AsciiString.
    private static AsciiString[] toAsciiStringArray(String[] jdkResult) {
        AsciiString[] res = new AsciiString[jdkResult.length];
        for (int i = 0; i < jdkResult.length; i++) {
            res[i] = new AsciiString(jdkResult[i]);
        }
        return res;
    }

    // Pluggable per-character equality used by the contains()/regionMatches() helpers below.
    private interface CharEqualityComparator {
        boolean equals(char a, char b);
    }

    // Exact (case-sensitive) character comparison.
    private static final class DefaultCharEqualityComparator implements CharEqualityComparator {
        static final DefaultCharEqualityComparator INSTANCE = new DefaultCharEqualityComparator();

        private DefaultCharEqualityComparator() {
        }

        @Override
        public boolean equals(char a, char b) {
            return a == b;
        }
    }

    // Case-insensitive comparison that is only correct for 8-bit ASCII characters.
    private static final class AsciiCaseInsensitiveCharEqualityComparator implements CharEqualityComparator {
        static final AsciiCaseInsensitiveCharEqualityComparator
                INSTANCE = new AsciiCaseInsensitiveCharEqualityComparator();

        private AsciiCaseInsensitiveCharEqualityComparator() {
        }

        @Override
        public boolean equals(char a, char b) {
            return equalsIgnoreCase(a, b);
        }
    }

    // Case-insensitive comparison for arbitrary characters, mirroring String#regionMatches semantics.
    private static final class GeneralCaseInsensitiveCharEqualityComparator implements CharEqualityComparator {
        static final GeneralCaseInsensitiveCharEqualityComparator
                INSTANCE = new GeneralCaseInsensitiveCharEqualityComparator();

        private GeneralCaseInsensitiveCharEqualityComparator() {
        }

        @Override
        public boolean equals(char a, char b) {
            //For motivation, why we need two checks, see comment in String#regionMatches
            return Character.toUpperCase(a) ==
Character.toUpperCase(b) || + Character.toLowerCase(a) == Character.toLowerCase(b); + } + } + + private static boolean contains(CharSequence a, CharSequence b, CharEqualityComparator cmp) { + if (a == null || b == null || a.length() < b.length()) { + return false; + } + if (b.length() == 0) { + return true; + } + int bStart = 0; + for (int i = 0; i < a.length(); ++i) { + if (cmp.equals(b.charAt(bStart), a.charAt(i))) { + // If b is consumed then true. + if (++bStart == b.length()) { + return true; + } + } else if (a.length() - i < b.length()) { + // If there are not enough characters left in a for b to be contained, then false. + return false; + } else { + bStart = 0; + } + } + return false; + } + + private static boolean regionMatchesCharSequences(final CharSequence cs, final int csStart, + final CharSequence string, final int start, final int length, + CharEqualityComparator charEqualityComparator) { + //general purpose implementation for CharSequences + if (csStart < 0 || length > cs.length() - csStart) { + return false; + } + if (start < 0 || length > string.length() - start) { + return false; + } + + int csIndex = csStart; + int csEnd = csIndex + length; + int stringIndex = start; + + while (csIndex < csEnd) { + char c1 = cs.charAt(csIndex++); + char c2 = string.charAt(stringIndex++); + + if (!charEqualityComparator.equals(c1, c2)) { + return false; + } + } + return true; + } + + /** + * This methods make regionMatches operation correctly for any chars in strings + * + * @param cs the {@code CharSequence} to be processed + * @param ignoreCase specifies if case should be ignored. + * @param csStart the starting offset in the {@code cs} CharSequence + * @param string the {@code CharSequence} to compare. + * @param start the starting offset in the specified {@code string}. + * @param length the number of characters to compare. + * @return {@code true} if the ranges of characters are equal, {@code false} otherwise. 
+ */ + public static boolean regionMatches(final CharSequence cs, final boolean ignoreCase, final int csStart, + final CharSequence string, final int start, final int length) { + if (cs == null || string == null) { + return false; + } + + if (cs instanceof String && string instanceof String) { + return ((String) cs).regionMatches(ignoreCase, csStart, (String) string, start, length); + } + + if (cs instanceof AsciiString) { + return ((AsciiString) cs).regionMatches(ignoreCase, csStart, string, start, length); + } + + return regionMatchesCharSequences(cs, csStart, string, start, length, + ignoreCase ? GeneralCaseInsensitiveCharEqualityComparator.INSTANCE : + DefaultCharEqualityComparator.INSTANCE); + } + + /** + * This is optimized version of regionMatches for string with ASCII chars only + * + * @param cs the {@code CharSequence} to be processed + * @param ignoreCase specifies if case should be ignored. + * @param csStart the starting offset in the {@code cs} CharSequence + * @param string the {@code CharSequence} to compare. + * @param start the starting offset in the specified {@code string}. + * @param length the number of characters to compare. + * @return {@code true} if the ranges of characters are equal, {@code false} otherwise. + */ + public static boolean regionMatchesAscii(final CharSequence cs, final boolean ignoreCase, final int csStart, + final CharSequence string, final int start, final int length) { + if (cs == null || string == null) { + return false; + } + + if (!ignoreCase && cs instanceof String && string instanceof String) { + //we don't call regionMatches from String for ignoreCase==true. It's a general purpose method, + //which make complex comparison in case of ignoreCase==true, which is useless for ASCII-only strings. 
+ //To avoid applying this complex ignore-case comparison, we will use regionMatchesCharSequences + return ((String) cs).regionMatches(false, csStart, (String) string, start, length); + } + + if (cs instanceof AsciiString) { + return ((AsciiString) cs).regionMatches(ignoreCase, csStart, string, start, length); + } + + return regionMatchesCharSequences(cs, csStart, string, start, length, + ignoreCase ? AsciiCaseInsensitiveCharEqualityComparator.INSTANCE : + DefaultCharEqualityComparator.INSTANCE); + } + + /** + *

<p>Case in-sensitive find of the first index within a CharSequence
     * from the specified position.</p>
     *
     * <p>A {@code null} CharSequence will return {@code -1}.
     * A negative start position is treated as zero.
     * An empty ("") search CharSequence always matches.
     * A start position greater than the string length only matches
     * an empty search CharSequence.</p>
     *
     * <pre>
     * AsciiString.indexOfIgnoreCase(null, *, *)          = -1
     * AsciiString.indexOfIgnoreCase(*, null, *)          = -1
     * AsciiString.indexOfIgnoreCase("", "", 0)           = 0
     * AsciiString.indexOfIgnoreCase("aabaabaa", "A", 0)  = 0
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", 0)  = 2
     * AsciiString.indexOfIgnoreCase("aabaabaa", "AB", 0) = 1
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", 3)  = 5
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", 9)  = -1
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", -1) = 2
     * AsciiString.indexOfIgnoreCase("aabaabaa", "", 2)   = 2
     * AsciiString.indexOfIgnoreCase("abc", "", 9)        = -1
     * </pre>
+ * + * @param str the CharSequence to check, may be null + * @param searchStr the CharSequence to find, may be null + * @param startPos the start position, negative treated as zero + * @return the first index of the search CharSequence (always ≥ startPos), + * -1 if no match or {@code null} string input + */ + public static int indexOfIgnoreCase(final CharSequence str, final CharSequence searchStr, int startPos) { + if (str == null || searchStr == null) { + return INDEX_NOT_FOUND; + } + if (startPos < 0) { + startPos = 0; + } + int searchStrLen = searchStr.length(); + final int endLimit = str.length() - searchStrLen + 1; + if (startPos > endLimit) { + return INDEX_NOT_FOUND; + } + if (searchStrLen == 0) { + return startPos; + } + for (int i = startPos; i < endLimit; i++) { + if (regionMatches(str, true, i, searchStr, 0, searchStrLen)) { + return i; + } + } + return INDEX_NOT_FOUND; + } + + /** + *

<p>Case in-sensitive find of the first index within a CharSequence
     * from the specified position. This method optimized and works correctly for ASCII CharSequences only</p>
     *
     * <p>A {@code null} CharSequence will return {@code -1}.
     * A negative start position is treated as zero.
     * An empty ("") search CharSequence always matches.
     * A start position greater than the string length only matches
     * an empty search CharSequence.</p>
     *
     * <pre>
     * AsciiString.indexOfIgnoreCase(null, *, *)          = -1
     * AsciiString.indexOfIgnoreCase(*, null, *)          = -1
     * AsciiString.indexOfIgnoreCase("", "", 0)           = 0
     * AsciiString.indexOfIgnoreCase("aabaabaa", "A", 0)  = 0
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", 0)  = 2
     * AsciiString.indexOfIgnoreCase("aabaabaa", "AB", 0) = 1
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", 3)  = 5
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", 9)  = -1
     * AsciiString.indexOfIgnoreCase("aabaabaa", "B", -1) = 2
     * AsciiString.indexOfIgnoreCase("aabaabaa", "", 2)   = 2
     * AsciiString.indexOfIgnoreCase("abc", "", 9)        = -1
     * </pre>
+ * + * @param str the CharSequence to check, may be null + * @param searchStr the CharSequence to find, may be null + * @param startPos the start position, negative treated as zero + * @return the first index of the search CharSequence (always ≥ startPos), + * -1 if no match or {@code null} string input + */ + public static int indexOfIgnoreCaseAscii(final CharSequence str, final CharSequence searchStr, int startPos) { + if (str == null || searchStr == null) { + return INDEX_NOT_FOUND; + } + if (startPos < 0) { + startPos = 0; + } + int searchStrLen = searchStr.length(); + final int endLimit = str.length() - searchStrLen + 1; + if (startPos > endLimit) { + return INDEX_NOT_FOUND; + } + if (searchStrLen == 0) { + return startPos; + } + for (int i = startPos; i < endLimit; i++) { + if (regionMatchesAscii(str, true, i, searchStr, 0, searchStrLen)) { + return i; + } + } + return INDEX_NOT_FOUND; + } + + /** + *

<p>Finds the first index in the {@code CharSequence} that matches the
     * specified character.</p>

     *
     * @param cs the {@code CharSequence} to be processed, not null
     * @param searchChar the char to be searched for
     * @param start the start index, negative starts at the string start
     * @return the index where the search char was found,
     * -1 if char {@code searchChar} is not found or {@code cs == null}
     */
    //-----------------------------------------------------------------------
    public static int indexOf(final CharSequence cs, final char searchChar, int start) {
        // instanceof is null-safe, so the explicit null check below is only reached for
        // non-String, non-AsciiString sequences.
        if (cs instanceof String) {
            return ((String) cs).indexOf(searchChar, start);
        } else if (cs instanceof AsciiString) {
            return ((AsciiString) cs).indexOf(searchChar, start);
        }
        if (cs == null) {
            return INDEX_NOT_FOUND;
        }
        final int sz = cs.length();
        for (int i = start < 0 ? 0 : start; i < sz; i++) {
            if (cs.charAt(i) == searchChar) {
                return i;
            }
        }
        return INDEX_NOT_FOUND;
    }

    // Case-insensitive equality of two ASCII bytes.
    private static boolean equalsIgnoreCase(byte a, byte b) {
        return a == b || toLowerCase(a) == toLowerCase(b);
    }

    // Case-insensitive equality of two ASCII chars.
    private static boolean equalsIgnoreCase(char a, char b) {
        return a == b || toLowerCase(a) == toLowerCase(b);
    }

    // ASCII-only lower-casing: 'A'..'Z' shifted by +32, everything else unchanged.
    private static byte toLowerCase(byte b) {
        return isUpperCase(b) ? (byte) (b + 32) : b;
    }

    /**
     * If the character is uppercase - converts the character to lowercase,
     * otherwise returns the character as it is. Only for ASCII characters.
     *
     * @return lowercase ASCII character equivalent
     */
    public static char toLowerCase(char c) {
        return isUpperCase(c) ? (char) (c + 32) : c;
    }

    // ASCII-only upper-casing: 'a'..'z' shifted by -32, everything else unchanged.
    private static byte toUpperCase(byte b) {
        return isLowerCase(b) ? (byte) (b - 32) : b;
    }

    private static boolean isLowerCase(byte value) {
        return value >= 'a' && value <= 'z';
    }

    public static boolean isUpperCase(byte value) {
        return value >= 'A' && value <= 'Z';
    }

    public static boolean isUpperCase(char value) {
        return value >= 'A' && value <= 'Z';
    }

    // Narrows a char to a byte, replacing anything above MAX_CHAR_VALUE with '?'.
    public static byte c2b(char c) {
        return (byte) ((c > MAX_CHAR_VALUE) ? '?'
: c); + } + + private static byte c2b0(char c) { + return (byte) c; + } + + public static char b2c(byte b) { + return (char) (b & 0xFF); + } +} diff --git a/netty-util/src/main/java/io/netty/util/AsyncMapping.java b/netty-util/src/main/java/io/netty/util/AsyncMapping.java new file mode 100644 index 0000000..b3114c7 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/AsyncMapping.java @@ -0,0 +1,28 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.concurrent.Future; +import io.netty.util.concurrent.Promise; + +public interface AsyncMapping { + + /** + * Returns the {@link Future} that will provide the result of the mapping. The given {@link Promise} will + * be fulfilled when the result is available. + */ + Future map(IN input, Promise promise); +} diff --git a/netty-util/src/main/java/io/netty/util/Attribute.java b/netty-util/src/main/java/io/netty/util/Attribute.java new file mode 100644 index 0000000..8024f12 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/Attribute.java @@ -0,0 +1,93 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * An attribute which allows to store a value reference. It may be updated atomically and so is thread-safe. + * + * @param the type of the value it holds. + */ +public interface Attribute { + + /** + * Returns the key of this attribute. + */ + AttributeKey key(); + + /** + * Returns the current value, which may be {@code null} + */ + T get(); + + /** + * Sets the value + */ + void set(T value); + + /** + * Atomically sets to the given value and returns the old value which may be {@code null} if non was set before. + */ + T getAndSet(T value); + + /** + * Atomically sets to the given value if this {@link Attribute}'s value is {@code null}. + * If it was not possible to set the value as it contains a value it will just return the current value. + */ + T setIfAbsent(T value); + + /** + * Removes this attribute from the {@link AttributeMap} and returns the old value. Subsequent {@link #get()} + * calls will return {@code null}. + *

+ * If you only want to return the old value and clear the {@link Attribute} while still keep it in the + * {@link AttributeMap} use {@link #getAndSet(Object)} with a value of {@code null}. + * + *

+ * Be aware that even if you call this method another thread that has obtained a reference to this {@link Attribute} + * via {@link AttributeMap#attr(AttributeKey)} will still operate on the same instance. That said if now another + * thread or even the same thread later will call {@link AttributeMap#attr(AttributeKey)} again, a new + * {@link Attribute} instance is created and so is not the same as the previous one that was removed. Because of + * this special caution should be taken when you call {@link #remove()} or {@link #getAndRemove()}. + * + * @deprecated please consider using {@link #getAndSet(Object)} (with value of {@code null}). + */ + @Deprecated + T getAndRemove(); + + /** + * Atomically sets the value to the given updated value if the current value == the expected value. + * If it the set was successful it returns {@code true} otherwise {@code false}. + */ + boolean compareAndSet(T oldValue, T newValue); + + /** + * Removes this attribute from the {@link AttributeMap}. Subsequent {@link #get()} calls will return @{code null}. + *

+ * If you only want to remove the value and clear the {@link Attribute} while still keep it in + * {@link AttributeMap} use {@link #set(Object)} with a value of {@code null}. + * + *

+ * Be aware that even if you call this method another thread that has obtained a reference to this {@link Attribute} + * via {@link AttributeMap#attr(AttributeKey)} will still operate on the same instance. That said if now another + * thread or even the same thread later will call {@link AttributeMap#attr(AttributeKey)} again, a new + * {@link Attribute} instance is created and so is not the same as the previous one that was removed. Because of + * this special caution should be taken when you call {@link #remove()} or {@link #getAndRemove()}. + * + * @deprecated please consider using {@link #set(Object)} (with value of {@code null}). + */ + @Deprecated + void remove(); +} diff --git a/netty-util/src/main/java/io/netty/util/AttributeKey.java b/netty-util/src/main/java/io/netty/util/AttributeKey.java new file mode 100644 index 0000000..ae2f482 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/AttributeKey.java @@ -0,0 +1,66 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * Key which can be used to access {@link Attribute} out of the {@link AttributeMap}. Be aware that it is not be + * possible to have multiple keys with the same name. + * + * @param the type of the {@link Attribute} which can be accessed via this {@link AttributeKey}. 
+ */ +@SuppressWarnings("UnusedDeclaration") // 'T' is used only at compile time +public final class AttributeKey extends AbstractConstant> { + + private static final ConstantPool> pool = new ConstantPool>() { + @Override + protected AttributeKey newConstant(int id, String name) { + return new AttributeKey(id, name); + } + }; + + /** + * Returns the singleton instance of the {@link AttributeKey} which has the specified {@code name}. + */ + @SuppressWarnings("unchecked") + public static AttributeKey valueOf(String name) { + return (AttributeKey) pool.valueOf(name); + } + + /** + * Returns {@code true} if a {@link AttributeKey} exists for the given {@code name}. + */ + public static boolean exists(String name) { + return pool.exists(name); + } + + /** + * Creates a new {@link AttributeKey} for the given {@code name} or fail with an + * {@link IllegalArgumentException} if a {@link AttributeKey} for the given {@code name} exists. + */ + @SuppressWarnings("unchecked") + public static AttributeKey newInstance(String name) { + return (AttributeKey) pool.newInstance(name); + } + + @SuppressWarnings("unchecked") + public static AttributeKey valueOf(Class firstNameComponent, String secondNameComponent) { + return (AttributeKey) pool.valueOf(firstNameComponent, secondNameComponent); + } + + private AttributeKey(int id, String name) { + super(id, name); + } +} diff --git a/netty-util/src/main/java/io/netty/util/AttributeMap.java b/netty-util/src/main/java/io/netty/util/AttributeMap.java new file mode 100644 index 0000000..6ae0caf --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/AttributeMap.java @@ -0,0 +1,34 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * Holds {@link Attribute}s which can be accessed via {@link AttributeKey}. + *

+ * Implementations must be Thread-safe. + */ +public interface AttributeMap { + /** + * Get the {@link Attribute} for the given {@link AttributeKey}. This method will never return null, but may return + * an {@link Attribute} which does not have a value set yet. + */ + Attribute attr(AttributeKey key); + + /** + * Returns {@code true} if and only if the given {@link Attribute} exists in this {@link AttributeMap}. + */ + boolean hasAttr(AttributeKey key); +} diff --git a/netty-util/src/main/java/io/netty/util/BooleanSupplier.java b/netty-util/src/main/java/io/netty/util/BooleanSupplier.java new file mode 100644 index 0000000..b8222c8 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/BooleanSupplier.java @@ -0,0 +1,49 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * Represents a supplier of {@code boolean}-valued results. + */ +public interface BooleanSupplier { + /** + * Gets a boolean value. + * + * @return a boolean value. + * @throws Exception If an exception occurs. + */ + boolean get() throws Exception; + + /** + * A supplier which always returns {@code false} and never throws. + */ + BooleanSupplier FALSE_SUPPLIER = new BooleanSupplier() { + @Override + public boolean get() { + return false; + } + }; + + /** + * A supplier which always returns {@code true} and never throws. 
/**
 * Provides a mechanism to iterate over a collection of bytes.
 */
public interface ByteProcessor {
    /**
     * A {@link ByteProcessor} which finds the first appearance of a specific byte.
     */
    class IndexOfProcessor implements ByteProcessor {
        private final byte target;

        public IndexOfProcessor(byte byteToFind) {
            target = byteToFind;
        }

        @Override
        public boolean process(byte value) {
            // Keep iterating while the wanted byte has not been seen yet.
            return value != target;
        }
    }

    /**
     * A {@link ByteProcessor} which finds the first appearance which is not of a specific byte.
     */
    class IndexNotOfProcessor implements ByteProcessor {
        private final byte target;

        public IndexNotOfProcessor(byte byteToNotFind) {
            target = byteToNotFind;
        }

        @Override
        public boolean process(byte value) {
            // Keep iterating while only the excluded byte is seen.
            return value == target;
        }
    }

    /**
     * Aborts on a {@code NUL (0x00)}.
     */
    ByteProcessor FIND_NUL = new IndexOfProcessor((byte) 0);

    /**
     * Aborts on a non-{@code NUL (0x00)}.
     */
    ByteProcessor FIND_NON_NUL = new IndexNotOfProcessor((byte) 0);

    /**
     * Aborts on a {@code CR ('\r')}.
     */
    ByteProcessor FIND_CR = new IndexOfProcessor((byte) '\r');

    /**
     * Aborts on a non-{@code CR ('\r')}.
     */
    ByteProcessor FIND_NON_CR = new IndexNotOfProcessor((byte) '\r');

    /**
     * Aborts on a {@code LF ('\n')}.
     */
    ByteProcessor FIND_LF = new IndexOfProcessor((byte) '\n');

    /**
     * Aborts on a non-{@code LF ('\n')}.
     */
    ByteProcessor FIND_NON_LF = new IndexNotOfProcessor((byte) '\n');

    /**
     * Aborts on a semicolon {@code (';')}.
     */
    ByteProcessor FIND_SEMI_COLON = new IndexOfProcessor((byte) ';');

    /**
     * Aborts on a comma {@code (',')}.
     */
    ByteProcessor FIND_COMMA = new IndexOfProcessor((byte) ',');

    /**
     * Aborts on an ASCII space character ({@code ' '}).
     */
    ByteProcessor FIND_ASCII_SPACE = new IndexOfProcessor((byte) ' ');

    /**
     * Aborts on a {@code CR ('\r')} or a {@code LF ('\n')}.
     */
    ByteProcessor FIND_CRLF = value -> value != '\r' && value != '\n';

    /**
     * Aborts on a byte which is neither a {@code CR ('\r')} nor a {@code LF ('\n')}.
     */
    ByteProcessor FIND_NON_CRLF = value -> value == '\r' || value == '\n';

    /**
     * Aborts on a linear whitespace (a {@code ' '} or a {@code '\t'}).
     */
    ByteProcessor FIND_LINEAR_WHITESPACE = value -> value != ' ' && value != '\t';

    /**
     * Aborts on a byte which is not a linear whitespace (neither {@code ' '} nor {@code '\t'}).
     */
    ByteProcessor FIND_NON_LINEAR_WHITESPACE = value -> value == ' ' || value == '\t';

    /**
     * @return {@code true} if the processor wants to continue the loop and handle the next byte in the buffer.
     *         {@code false} if the processor wants to stop handling bytes and abort the loop.
     */
    boolean process(byte value) throws Exception;
}
/**
 * Byte values of the ASCII control and whitespace characters used by the
 * {@link ByteProcessor} search constants.
 */
final class ByteProcessorUtils {
    // ASCII code points; each fits in 7 bits, so narrowing a char constant to byte is safe.
    static final byte CARRIAGE_RETURN = '\r';
    static final byte LINE_FEED = '\n';
    static final byte SPACE = ' ';
    static final byte HTAB = '\t';

    private ByteProcessorUtils() {
        // Constant holder; never instantiated.
    }
}
+ */ +public final class CharsetUtil { + + /** + * 16-bit UTF (UCS Transformation Format) whose byte order is identified by + * an optional byte-order mark + */ + public static final Charset UTF_16 = StandardCharsets.UTF_16; + + /** + * 16-bit UTF (UCS Transformation Format) whose byte order is big-endian + */ + public static final Charset UTF_16BE = StandardCharsets.UTF_16BE; + + /** + * 16-bit UTF (UCS Transformation Format) whose byte order is little-endian + */ + public static final Charset UTF_16LE = StandardCharsets.UTF_16LE; + + /** + * 8-bit UTF (UCS Transformation Format) + */ + public static final Charset UTF_8 = StandardCharsets.UTF_8; + + /** + * ISO Latin Alphabet No. 1, as known as ISO-LATIN-1 + */ + public static final Charset ISO_8859_1 = StandardCharsets.ISO_8859_1; + + /** + * 7-bit ASCII, as known as ISO646-US or the Basic Latin block of the + * Unicode character set + */ + public static final Charset US_ASCII = StandardCharsets.US_ASCII; + + private static final Charset[] CHARSETS = new Charset[] + {UTF_16, UTF_16BE, UTF_16LE, UTF_8, ISO_8859_1, US_ASCII}; + + public static Charset[] values() { + return CHARSETS; + } + + /** + * @deprecated Use {@link #encoder(Charset)}. + */ + @Deprecated + public static CharsetEncoder getEncoder(Charset charset) { + return encoder(charset); + } + + /** + * Returns a new {@link CharsetEncoder} for the {@link Charset} with specified error actions. 
+ * + * @param charset The specified charset + * @param malformedInputAction The encoder's action for malformed-input errors + * @param unmappableCharacterAction The encoder's action for unmappable-character errors + * @return The encoder for the specified {@code charset} + */ + public static CharsetEncoder encoder(Charset charset, CodingErrorAction malformedInputAction, + CodingErrorAction unmappableCharacterAction) { + checkNotNull(charset, "charset"); + CharsetEncoder e = charset.newEncoder(); + e.onMalformedInput(malformedInputAction).onUnmappableCharacter(unmappableCharacterAction); + return e; + } + + /** + * Returns a new {@link CharsetEncoder} for the {@link Charset} with the specified error action. + * + * @param charset The specified charset + * @param codingErrorAction The encoder's action for malformed-input and unmappable-character errors + * @return The encoder for the specified {@code charset} + */ + public static CharsetEncoder encoder(Charset charset, CodingErrorAction codingErrorAction) { + return encoder(charset, codingErrorAction, codingErrorAction); + } + + /** + * Returns a cached thread-local {@link CharsetEncoder} for the specified {@link Charset}. + * + * @param charset The specified charset + * @return The encoder for the specified {@code charset} + */ + public static CharsetEncoder encoder(Charset charset) { + checkNotNull(charset, "charset"); + + Map map = InternalThreadLocalMap.get().charsetEncoderCache(); + CharsetEncoder e = map.get(charset); + if (e != null) { + e.reset().onMalformedInput(CodingErrorAction.REPLACE).onUnmappableCharacter(CodingErrorAction.REPLACE); + return e; + } + + e = encoder(charset, CodingErrorAction.REPLACE, CodingErrorAction.REPLACE); + map.put(charset, e); + return e; + } + + /** + * @deprecated Use {@link #decoder(Charset)}. 
+ */ + @Deprecated + public static CharsetDecoder getDecoder(Charset charset) { + return decoder(charset); + } + + /** + * Returns a new {@link CharsetDecoder} for the {@link Charset} with specified error actions. + * + * @param charset The specified charset + * @param malformedInputAction The decoder's action for malformed-input errors + * @param unmappableCharacterAction The decoder's action for unmappable-character errors + * @return The decoder for the specified {@code charset} + */ + public static CharsetDecoder decoder(Charset charset, CodingErrorAction malformedInputAction, + CodingErrorAction unmappableCharacterAction) { + checkNotNull(charset, "charset"); + CharsetDecoder d = charset.newDecoder(); + d.onMalformedInput(malformedInputAction).onUnmappableCharacter(unmappableCharacterAction); + return d; + } + + /** + * Returns a new {@link CharsetDecoder} for the {@link Charset} with the specified error action. + * + * @param charset The specified charset + * @param codingErrorAction The decoder's action for malformed-input and unmappable-character errors + * @return The decoder for the specified {@code charset} + */ + public static CharsetDecoder decoder(Charset charset, CodingErrorAction codingErrorAction) { + return decoder(charset, codingErrorAction, codingErrorAction); + } + + /** + * Returns a cached thread-local {@link CharsetDecoder} for the specified {@link Charset}. 
+ * + * @param charset The specified charset + * @return The decoder for the specified {@code charset} + */ + public static CharsetDecoder decoder(Charset charset) { + checkNotNull(charset, "charset"); + + Map map = InternalThreadLocalMap.get().charsetDecoderCache(); + CharsetDecoder d = map.get(charset); + if (d != null) { + d.reset().onMalformedInput(CodingErrorAction.REPLACE).onUnmappableCharacter(CodingErrorAction.REPLACE); + return d; + } + + d = decoder(charset, CodingErrorAction.REPLACE, CodingErrorAction.REPLACE); + map.put(charset, d); + return d; + } + + private CharsetUtil() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/Constant.java b/netty-util/src/main/java/io/netty/util/Constant.java new file mode 100644 index 0000000..3ada9d8 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/Constant.java @@ -0,0 +1,32 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * A singleton which is safe to compare via the {@code ==} operator. Created and managed by {@link ConstantPool}. + */ +public interface Constant> extends Comparable { + + /** + * Returns the unique number assigned to this {@link Constant}. + */ + int id(); + + /** + * Returns the name of this {@link Constant}. 
+ */ + String name(); +} diff --git a/netty-util/src/main/java/io/netty/util/ConstantPool.java b/netty-util/src/main/java/io/netty/util/ConstantPool.java new file mode 100644 index 0000000..b0156f8 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ConstantPool.java @@ -0,0 +1,115 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +import io.netty.util.internal.PlatformDependent; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.atomic.AtomicInteger; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * A pool of {@link Constant}s. + * + * @param the type of the constant + */ +public abstract class ConstantPool> { + + private final ConcurrentMap constants = PlatformDependent.newConcurrentHashMap(); + + private final AtomicInteger nextId = new AtomicInteger(1); + + /** + * Shortcut of {@link #valueOf(String) valueOf(firstNameComponent.getName() + "#" + secondNameComponent)}. + */ + public T valueOf(Class firstNameComponent, String secondNameComponent) { + return valueOf( + checkNotNull(firstNameComponent, "firstNameComponent").getName() + + '#' + + checkNotNull(secondNameComponent, "secondNameComponent")); + } + + /** + * Returns the {@link Constant} which is assigned to the specified {@code name}. 
+ * If there's no such {@link Constant}, a new one will be created and returned. + * Once created, the subsequent calls with the same {@code name} will always return the previously created one + * (i.e. singleton.) + * + * @param name the name of the {@link Constant} + */ + public T valueOf(String name) { + return getOrCreate(checkNonEmpty(name, "name")); + } + + /** + * Get existing constant by name or creates new one if not exists. Threadsafe + * + * @param name the name of the {@link Constant} + */ + private T getOrCreate(String name) { + T constant = constants.get(name); + if (constant == null) { + final T tempConstant = newConstant(nextId(), name); + constant = constants.putIfAbsent(name, tempConstant); + if (constant == null) { + return tempConstant; + } + } + + return constant; + } + + /** + * Returns {@code true} if a {@link AttributeKey} exists for the given {@code name}. + */ + public boolean exists(String name) { + return constants.containsKey(checkNonEmpty(name, "name")); + } + + /** + * Creates a new {@link Constant} for the given {@code name} or fail with an + * {@link IllegalArgumentException} if a {@link Constant} for the given {@code name} exists. + */ + public T newInstance(String name) { + return createOrThrow(checkNonEmpty(name, "name")); + } + + /** + * Creates constant by name or throws exception. 
Threadsafe + * + * @param name the name of the {@link Constant} + */ + private T createOrThrow(String name) { + T constant = constants.get(name); + if (constant == null) { + final T tempConstant = newConstant(nextId(), name); + constant = constants.putIfAbsent(name, tempConstant); + if (constant == null) { + return tempConstant; + } + } + + throw new IllegalArgumentException(String.format("'%s' is already in use", name)); + } + + protected abstract T newConstant(int id, String name); + + @Deprecated + public final int nextId() { + return nextId.getAndIncrement(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/DefaultAttributeMap.java b/netty-util/src/main/java/io/netty/util/DefaultAttributeMap.java new file mode 100644 index 0000000..5c55d49 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/DefaultAttributeMap.java @@ -0,0 +1,212 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.internal.ObjectUtil; +import java.util.Arrays; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +/** + * Default {@link AttributeMap} implementation which not exibit any blocking behaviour on attribute lookup while using a + * copy-on-write approach on the modify path.
Attributes lookup and remove exibit {@code O(logn)} time worst-case + * complexity, hence {@code attribute::set(null)} is to be preferred to {@code remove}. + */ +public class DefaultAttributeMap implements AttributeMap { + + private static final AtomicReferenceFieldUpdater ATTRIBUTES_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(DefaultAttributeMap.class, DefaultAttribute[].class, "attributes"); + private static final DefaultAttribute[] EMPTY_ATTRIBUTES = new DefaultAttribute[0]; + + /** + * Similarly to {@code Arrays::binarySearch} it perform a binary search optimized for this use case, in order to + * save polymorphic calls (on comparator side) and unnecessary class checks. + */ + private static int searchAttributeByKey(DefaultAttribute[] sortedAttributes, AttributeKey key) { + int low = 0; + int high = sortedAttributes.length - 1; + + while (low <= high) { + int mid = low + high >>> 1; + DefaultAttribute midVal = sortedAttributes[mid]; + AttributeKey midValKey = midVal.key; + if (midValKey == key) { + return mid; + } + int midValKeyId = midValKey.id(); + int keyId = key.id(); + assert midValKeyId != keyId; + boolean searchRight = midValKeyId < keyId; + if (searchRight) { + low = mid + 1; + } else { + high = mid - 1; + } + } + + return -(low + 1); + } + + private static void orderedCopyOnInsert(DefaultAttribute[] sortedSrc, int srcLength, DefaultAttribute[] copy, + DefaultAttribute toInsert) { + // let's walk backward, because as a rule of thumb, toInsert.key.id() tends to be higher for new keys + final int id = toInsert.key.id(); + int i; + for (i = srcLength - 1; i >= 0; i--) { + DefaultAttribute attribute = sortedSrc[i]; + assert attribute.key.id() != id; + if (attribute.key.id() < id) { + break; + } + copy[i + 1] = sortedSrc[i]; + } + copy[i + 1] = toInsert; + final int toCopy = i + 1; + if (toCopy > 0) { + System.arraycopy(sortedSrc, 0, copy, 0, toCopy); + } + } + + private volatile DefaultAttribute[] attributes = EMPTY_ATTRIBUTES; + + 
@SuppressWarnings("unchecked") + @Override + public Attribute attr(AttributeKey key) { + ObjectUtil.checkNotNull(key, "key"); + DefaultAttribute newAttribute = null; + for (; ; ) { + final DefaultAttribute[] attributes = this.attributes; + final int index = searchAttributeByKey(attributes, key); + final DefaultAttribute[] newAttributes; + if (index >= 0) { + final DefaultAttribute attribute = attributes[index]; + assert attribute.key() == key; + if (!attribute.isRemoved()) { + return attribute; + } + // let's try replace the removed attribute with a new one + if (newAttribute == null) { + newAttribute = new DefaultAttribute(this, key); + } + final int count = attributes.length; + newAttributes = Arrays.copyOf(attributes, count); + newAttributes[index] = newAttribute; + } else { + if (newAttribute == null) { + newAttribute = new DefaultAttribute(this, key); + } + final int count = attributes.length; + newAttributes = new DefaultAttribute[count + 1]; + orderedCopyOnInsert(attributes, count, newAttributes, newAttribute); + } + if (ATTRIBUTES_UPDATER.compareAndSet(this, attributes, newAttributes)) { + return newAttribute; + } + } + } + + @Override + public boolean hasAttr(AttributeKey key) { + ObjectUtil.checkNotNull(key, "key"); + return searchAttributeByKey(attributes, key) >= 0; + } + + private void removeAttributeIfMatch(AttributeKey key, DefaultAttribute value) { + for (; ; ) { + final DefaultAttribute[] attributes = this.attributes; + final int index = searchAttributeByKey(attributes, key); + if (index < 0) { + return; + } + final DefaultAttribute attribute = attributes[index]; + assert attribute.key() == key; + if (attribute != value) { + return; + } + final int count = attributes.length; + final int newCount = count - 1; + final DefaultAttribute[] newAttributes = + newCount == 0 ? 
EMPTY_ATTRIBUTES : new DefaultAttribute[newCount]; + // perform 2 bulk copies + System.arraycopy(attributes, 0, newAttributes, 0, index); + final int remaining = count - index - 1; + if (remaining > 0) { + System.arraycopy(attributes, index + 1, newAttributes, index, remaining); + } + if (ATTRIBUTES_UPDATER.compareAndSet(this, attributes, newAttributes)) { + return; + } + } + } + + @SuppressWarnings("serial") + private static final class DefaultAttribute extends AtomicReference implements Attribute { + + private static final AtomicReferenceFieldUpdater MAP_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(DefaultAttribute.class, + DefaultAttributeMap.class, "attributeMap"); + private static final long serialVersionUID = -2661411462200283011L; + + private volatile DefaultAttributeMap attributeMap; + private final AttributeKey key; + + DefaultAttribute(DefaultAttributeMap attributeMap, AttributeKey key) { + this.attributeMap = attributeMap; + this.key = key; + } + + @Override + public AttributeKey key() { + return key; + } + + private boolean isRemoved() { + return attributeMap == null; + } + + @Override + public T setIfAbsent(T value) { + while (!compareAndSet(null, value)) { + T old = get(); + if (old != null) { + return old; + } + } + return null; + } + + @Override + public T getAndRemove() { + final DefaultAttributeMap attributeMap = this.attributeMap; + final boolean removed = attributeMap != null && MAP_UPDATER.compareAndSet(this, attributeMap, null); + T oldValue = getAndSet(null); + if (removed) { + attributeMap.removeAttributeIfMatch(key, this); + } + return oldValue; + } + + @Override + public void remove() { + final DefaultAttributeMap attributeMap = this.attributeMap; + final boolean removed = attributeMap != null && MAP_UPDATER.compareAndSet(this, attributeMap, null); + set(null); + if (removed) { + attributeMap.removeAttributeIfMatch(key, this); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/DomainMappingBuilder.java 
b/netty-util/src/main/java/io/netty/util/DomainMappingBuilder.java new file mode 100644 index 0000000..46358ba --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/DomainMappingBuilder.java @@ -0,0 +1,77 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +/** + * Builder for immutable {@link DomainNameMapping} instances. + * + * @param concrete type of value objects + * @deprecated Use {@link DomainWildcardMappingBuilder} instead. 
+ */ +@Deprecated +public final class DomainMappingBuilder { + + private final DomainNameMappingBuilder builder; + + /** + * Constructor with default initial capacity of the map holding the mappings + * + * @param defaultValue the default value for {@link DomainNameMapping#map(String)} to return + * when nothing matches the input + */ + public DomainMappingBuilder(V defaultValue) { + builder = new DomainNameMappingBuilder(defaultValue); + } + + /** + * Constructor with initial capacity of the map holding the mappings + * + * @param initialCapacity initial capacity for the internal map + * @param defaultValue the default value for {@link DomainNameMapping#map(String)} to return + * when nothing matches the input + */ + public DomainMappingBuilder(int initialCapacity, V defaultValue) { + builder = new DomainNameMappingBuilder(initialCapacity, defaultValue); + } + + /** + * Adds a mapping that maps the specified (optionally wildcard) host name to the specified output value. + * Null values are forbidden for both hostnames and values. + *

+ * DNS wildcard is supported as hostname. + * For example, you can use {@code *.netty.io} to match {@code netty.io} and {@code downloads.netty.io}. + *

+ * + * @param hostname the host name (optionally wildcard) + * @param output the output value that will be returned by {@link DomainNameMapping#map(String)} + * when the specified host name matches the specified input host name + */ + public DomainMappingBuilder add(String hostname, V output) { + builder.add(hostname, output); + return this; + } + + /** + * Creates a new instance of immutable {@link DomainNameMapping} + * Attempts to add new mappings to the result object will cause {@link UnsupportedOperationException} to be thrown + * + * @return new {@link DomainNameMapping} instance + */ + public DomainNameMapping build() { + return builder.build(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/DomainNameMapping.java b/netty-util/src/main/java/io/netty/util/DomainNameMapping.java new file mode 100644 index 0000000..1013a5c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/DomainNameMapping.java @@ -0,0 +1,151 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util; + +import io.netty.util.internal.StringUtil; +import java.net.IDN; +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Locale; +import java.util.Map; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.StringUtil.commonSuffixOfLength; + +/** + * Maps a domain name to its associated value object. + *

+ * DNS wildcard is supported as hostname, so you can use {@code *.netty.io} to match both {@code netty.io} + * and {@code downloads.netty.io}. + *

+ * + * @deprecated Use {@link DomainWildcardMappingBuilder}} + */ +@Deprecated +public class DomainNameMapping implements Mapping { + + final V defaultValue; + private final Map map; + private final Map unmodifiableMap; + + /** + * Creates a default, order-sensitive mapping. If your hostnames are in conflict, the mapping + * will choose the one you add first. + * + * @param defaultValue the default value for {@link #map(String)} to return when nothing matches the input + * @deprecated use {@link DomainNameMappingBuilder} to create and fill the mapping instead + */ + @Deprecated + public DomainNameMapping(V defaultValue) { + this(4, defaultValue); + } + + /** + * Creates a default, order-sensitive mapping. If your hostnames are in conflict, the mapping + * will choose the one you add first. + * + * @param initialCapacity initial capacity for the internal map + * @param defaultValue the default value for {@link #map(String)} to return when nothing matches the input + * @deprecated use {@link DomainNameMappingBuilder} to create and fill the mapping instead + */ + @Deprecated + public DomainNameMapping(int initialCapacity, V defaultValue) { + this(new LinkedHashMap(initialCapacity), defaultValue); + } + + DomainNameMapping(Map map, V defaultValue) { + this.defaultValue = checkNotNull(defaultValue, "defaultValue"); + this.map = map; + unmodifiableMap = map != null ? Collections.unmodifiableMap(map) + : null; + } + + /** + * Adds a mapping that maps the specified (optionally wildcard) host name to the specified output value. + *

+ * DNS wildcard is supported as hostname. + * For example, you can use {@code *.netty.io} to match {@code netty.io} and {@code downloads.netty.io}. + *

+ * + * @param hostname the host name (optionally wildcard) + * @param output the output value that will be returned by {@link #map(String)} when the specified host name + * matches the specified input host name + * @deprecated use {@link DomainNameMappingBuilder} to create and fill the mapping instead + */ + @Deprecated + public DomainNameMapping add(String hostname, V output) { + map.put(normalizeHostname(checkNotNull(hostname, "hostname")), checkNotNull(output, "output")); + return this; + } + + /** + * Simple function to match DNS wildcard. + */ + static boolean matches(String template, String hostName) { + if (template.startsWith("*.")) { + return template.regionMatches(2, hostName, 0, hostName.length()) + || commonSuffixOfLength(hostName, template, template.length() - 1); + } + return template.equals(hostName); + } + + /** + * IDNA ASCII conversion and case normalization + */ + static String normalizeHostname(String hostname) { + if (needsNormalization(hostname)) { + hostname = IDN.toASCII(hostname, IDN.ALLOW_UNASSIGNED); + } + return hostname.toLowerCase(Locale.US); + } + + private static boolean needsNormalization(String hostname) { + final int length = hostname.length(); + for (int i = 0; i < length; i++) { + int c = hostname.charAt(i); + if (c > 0x7F) { + return true; + } + } + return false; + } + + @Override + public V map(String hostname) { + if (hostname != null) { + hostname = normalizeHostname(hostname); + + for (Map.Entry entry : map.entrySet()) { + if (matches(entry.getKey(), hostname)) { + return entry.getValue(); + } + } + } + return defaultValue; + } + + /** + * Returns a read-only {@link Map} of the domain mapping patterns and their associated value objects. 
+ */ + public Map asMap() { + return unmodifiableMap; + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + "(default: " + defaultValue + ", map: " + map + ')'; + } +} diff --git a/netty-util/src/main/java/io/netty/util/DomainNameMappingBuilder.java b/netty-util/src/main/java/io/netty/util/DomainNameMappingBuilder.java new file mode 100644 index 0000000..38e19ff --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/DomainNameMappingBuilder.java @@ -0,0 +1,205 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +import java.util.Collections; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Set; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Builder for immutable {@link DomainNameMapping} instances. 
+ * + * @param concrete type of value objects + * @deprecated Use {@link DomainWildcardMappingBuilder} + */ +@Deprecated +public final class DomainNameMappingBuilder { + + private final V defaultValue; + private final Map map; + + /** + * Constructor with default initial capacity of the map holding the mappings + * + * @param defaultValue the default value for {@link DomainNameMapping#map(String)} to return + * when nothing matches the input + */ + public DomainNameMappingBuilder(V defaultValue) { + this(4, defaultValue); + } + + /** + * Constructor with initial capacity of the map holding the mappings + * + * @param initialCapacity initial capacity for the internal map + * @param defaultValue the default value for {@link DomainNameMapping#map(String)} to return + * when nothing matches the input + */ + public DomainNameMappingBuilder(int initialCapacity, V defaultValue) { + this.defaultValue = checkNotNull(defaultValue, "defaultValue"); + map = new LinkedHashMap(initialCapacity); + } + + /** + * Adds a mapping that maps the specified (optionally wildcard) host name to the specified output value. + * Null values are forbidden for both hostnames and values. + *

+ * DNS wildcard is supported as hostname. + * For example, you can use {@code *.netty.io} to match {@code netty.io} and {@code downloads.netty.io}. + *

+ * + * @param hostname the host name (optionally wildcard) + * @param output the output value that will be returned by {@link DomainNameMapping#map(String)} + * when the specified host name matches the specified input host name + */ + public DomainNameMappingBuilder add(String hostname, V output) { + map.put(checkNotNull(hostname, "hostname"), checkNotNull(output, "output")); + return this; + } + + /** + * Creates a new instance of immutable {@link DomainNameMapping} + * Attempts to add new mappings to the result object will cause {@link UnsupportedOperationException} to be thrown + * + * @return new {@link DomainNameMapping} instance + */ + public DomainNameMapping build() { + return new ImmutableDomainNameMapping(defaultValue, map); + } + + /** + * Immutable mapping from domain name pattern to its associated value object. + * Mapping is represented by two arrays: keys and values. Key domainNamePatterns[i] is associated with values[i]. + * + * @param concrete type of value objects + */ + private static final class ImmutableDomainNameMapping extends DomainNameMapping { + private static final String REPR_HEADER = "ImmutableDomainNameMapping(default: "; + private static final String REPR_MAP_OPENING = ", map: {"; + private static final String REPR_MAP_CLOSING = "})"; + private static final int REPR_CONST_PART_LENGTH = + REPR_HEADER.length() + REPR_MAP_OPENING.length() + REPR_MAP_CLOSING.length(); + + private final String[] domainNamePatterns; + private final V[] values; + private final Map map; + + @SuppressWarnings("unchecked") + private ImmutableDomainNameMapping(V defaultValue, Map map) { + super(null, defaultValue); + + Set> mappings = map.entrySet(); + int numberOfMappings = mappings.size(); + domainNamePatterns = new String[numberOfMappings]; + values = (V[]) new Object[numberOfMappings]; + + final Map mapCopy = new LinkedHashMap(map.size()); + int index = 0; + for (Map.Entry mapping : mappings) { + final String hostname = normalizeHostname(mapping.getKey()); 
+ final V value = mapping.getValue(); + domainNamePatterns[index] = hostname; + values[index] = value; + mapCopy.put(hostname, value); + ++index; + } + + this.map = Collections.unmodifiableMap(mapCopy); + } + + @Override + @Deprecated + public DomainNameMapping add(String hostname, V output) { + throw new UnsupportedOperationException( + "Immutable DomainNameMapping does not support modification after initial creation"); + } + + @Override + public V map(String hostname) { + if (hostname != null) { + hostname = normalizeHostname(hostname); + + int length = domainNamePatterns.length; + for (int index = 0; index < length; ++index) { + if (matches(domainNamePatterns[index], hostname)) { + return values[index]; + } + } + } + + return defaultValue; + } + + @Override + public Map asMap() { + return map; + } + + @Override + public String toString() { + String defaultValueStr = defaultValue.toString(); + + int numberOfMappings = domainNamePatterns.length; + if (numberOfMappings == 0) { + return REPR_HEADER + defaultValueStr + REPR_MAP_OPENING + REPR_MAP_CLOSING; + } + + String pattern0 = domainNamePatterns[0]; + String value0 = values[0].toString(); + int oneMappingLength = pattern0.length() + value0.length() + 3; // 2 for separator ", " and 1 for '=' + int estimatedBufferSize = estimateBufferSize(defaultValueStr.length(), numberOfMappings, oneMappingLength); + + StringBuilder sb = new StringBuilder(estimatedBufferSize) + .append(REPR_HEADER).append(defaultValueStr).append(REPR_MAP_OPENING); + + appendMapping(sb, pattern0, value0); + for (int index = 1; index < numberOfMappings; ++index) { + sb.append(", "); + appendMapping(sb, index); + } + + return sb.append(REPR_MAP_CLOSING).toString(); + } + + /** + * Estimates the length of string representation of the given instance: + * est = lengthOfConstantComponents + defaultValueLength + (estimatedMappingLength * numOfMappings) * 1.10 + * + * @param defaultValueLength length of string representation of {@link #defaultValue} + * 
@param numberOfMappings number of mappings the given instance holds, + * e.g. {@link #domainNamePatterns#length} + * @param estimatedMappingLength estimated size taken by one mapping + * @return estimated length of string returned by {@link #toString()} + */ + private static int estimateBufferSize(int defaultValueLength, + int numberOfMappings, + int estimatedMappingLength) { + return REPR_CONST_PART_LENGTH + defaultValueLength + + (int) (estimatedMappingLength * numberOfMappings * 1.10); + } + + private StringBuilder appendMapping(StringBuilder sb, int mappingIndex) { + return appendMapping(sb, domainNamePatterns[mappingIndex], values[mappingIndex].toString()); + } + + private static StringBuilder appendMapping(StringBuilder sb, String domainNamePattern, String value) { + return sb.append(domainNamePattern).append('=').append(value); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/DomainWildcardMappingBuilder.java b/netty-util/src/main/java/io/netty/util/DomainWildcardMappingBuilder.java new file mode 100644 index 0000000..b433a8c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/DomainWildcardMappingBuilder.java @@ -0,0 +1,160 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import java.util.LinkedHashMap; +import java.util.Map; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * Builder that allows to build {@link Mapping}s that support + * DNS wildcard matching. + * + * @param the type of the value that we map to. + */ +public class DomainWildcardMappingBuilder { + + private final V defaultValue; + private final Map map; + + /** + * Constructor with default initial capacity of the map holding the mappings + * + * @param defaultValue the default value for {@link Mapping#map(Object)} )} to return + * when nothing matches the input + */ + public DomainWildcardMappingBuilder(V defaultValue) { + this(4, defaultValue); + } + + /** + * Constructor with initial capacity of the map holding the mappings + * + * @param initialCapacity initial capacity for the internal map + * @param defaultValue the default value for {@link Mapping#map(Object)} to return + * when nothing matches the input + */ + public DomainWildcardMappingBuilder(int initialCapacity, V defaultValue) { + this.defaultValue = checkNotNull(defaultValue, "defaultValue"); + map = new LinkedHashMap(initialCapacity); + } + + /** + * Adds a mapping that maps the specified (optionally wildcard) host name to the specified output value. + * {@code null} values are forbidden for both hostnames and values. + *

+ * DNS wildcard is supported as hostname. The + * wildcard will only match one sub-domain deep and only when wildcard is used as the most-left label. + *

+ * For example: + * + *

+ * *.netty.io will match xyz.netty.io but NOT abc.xyz.netty.io + *

+ * + * @param hostname the host name (optionally wildcard) + * @param output the output value that will be returned by {@link Mapping#map(Object)} + * when the specified host name matches the specified input host name + */ + public DomainWildcardMappingBuilder add(String hostname, V output) { + map.put(normalizeHostName(hostname), + checkNotNull(output, "output")); + return this; + } + + private String normalizeHostName(String hostname) { + checkNotNull(hostname, "hostname"); + if (hostname.isEmpty() || hostname.charAt(0) == '.') { + throw new IllegalArgumentException("Hostname '" + hostname + "' not valid"); + } + hostname = ImmutableDomainWildcardMapping.normalize(checkNotNull(hostname, "hostname")); + if (hostname.charAt(0) == '*') { + if (hostname.length() < 3 || hostname.charAt(1) != '.') { + throw new IllegalArgumentException("Wildcard Hostname '" + hostname + "'not valid"); + } + return hostname.substring(1); + } + return hostname; + } + + /** + * Creates a new instance of an immutable {@link Mapping}. + * + * @return new {@link Mapping} instance + */ + public Mapping build() { + return new ImmutableDomainWildcardMapping(defaultValue, map); + } + + private static final class ImmutableDomainWildcardMapping implements Mapping { + private static final String REPR_HEADER = "ImmutableDomainWildcardMapping(default: "; + private static final String REPR_MAP_OPENING = ", map: "; + private static final String REPR_MAP_CLOSING = ")"; + + private final V defaultValue; + private final Map map; + + ImmutableDomainWildcardMapping(V defaultValue, Map map) { + this.defaultValue = defaultValue; + this.map = new LinkedHashMap(map); + } + + @Override + public V map(String hostname) { + if (hostname != null) { + hostname = normalize(hostname); + + // Let's try an exact match first + V value = map.get(hostname); + if (value != null) { + return value; + } + + // No exact match, let's try a wildcard match. 
+ int idx = hostname.indexOf('.'); + if (idx != -1) { + value = map.get(hostname.substring(idx)); + if (value != null) { + return value; + } + } + } + + return defaultValue; + } + + @SuppressWarnings("deprecation") + static String normalize(String hostname) { + return DomainNameMapping.normalizeHostname(hostname); + } + + @Override + public String toString() { + StringBuilder sb = new StringBuilder(); + sb.append(REPR_HEADER).append(defaultValue).append(REPR_MAP_OPENING).append('{'); + + for (Map.Entry entry : map.entrySet()) { + String hostname = entry.getKey(); + if (hostname.charAt(0) == '.') { + hostname = '*' + hostname; + } + sb.append(hostname).append('=').append(entry.getValue()).append(", "); + } + sb.setLength(sb.length() - 2); + return sb.append('}').append(REPR_MAP_CLOSING).toString(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/HashedWheelTimer.java b/netty-util/src/main/java/io/netty/util/HashedWheelTimer.java new file mode 100644 index 0000000..30b6062 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/HashedWheelTimer.java @@ -0,0 +1,870 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import io.netty.util.concurrent.ImmediateExecutor; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.Collections; +import java.util.HashSet; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicLong; +import static io.netty.util.internal.ObjectUtil.checkInRange; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkPositive; +import static io.netty.util.internal.StringUtil.simpleClassName; + +/** + * A {@link Timer} optimized for approximated I/O timeout scheduling. + * + *

Tick Duration

+ *

+ * As described with 'approximated', this timer does not execute the scheduled + * {@link TimerTask} on time. {@link HashedWheelTimer}, on every tick, will + * check if there are any {@link TimerTask}s behind the schedule and execute + * them. + *

+ * You can increase or decrease the accuracy of the execution timing by + * specifying smaller or larger tick duration in the constructor. In most + * network applications, I/O timeout does not need to be accurate. Therefore, + * the default tick duration is 100 milliseconds and you will not need to try + * different configurations in most cases. + * + *

Ticks per Wheel (Wheel Size)

+ *

+ * {@link HashedWheelTimer} maintains a data structure called 'wheel'. + * To put simply, a wheel is a hash table of {@link TimerTask}s whose hash + * function is 'dead line of the task'. The default number of ticks per wheel + * (i.e. the size of the wheel) is 512. You could specify a larger value + * if you are going to schedule a lot of timeouts. + * + *

Do not create many instances.

+ *

+ * {@link HashedWheelTimer} creates a new thread whenever it is instantiated and + * started. Therefore, you should make sure to create only one instance and + * share it across your application. One of the common mistakes, that makes + * your application unresponsive, is to create a new instance for every connection. + * + *

Implementation Details

+ *

+ * {@link HashedWheelTimer} is based on + * George Varghese and + * Tony Lauck's paper, + * 'Hashed + * and Hierarchical Timing Wheels: data structures to efficiently implement a + * timer facility'. More comprehensive slides are located + * here. + */ +public class HashedWheelTimer implements Timer { + + static final InternalLogger logger = + InternalLoggerFactory.getInstance(HashedWheelTimer.class); + + private static final AtomicInteger INSTANCE_COUNTER = new AtomicInteger(); + private static final AtomicBoolean WARNED_TOO_MANY_INSTANCES = new AtomicBoolean(); + private static final int INSTANCE_COUNT_LIMIT = 64; + private static final long MILLISECOND_NANOS = TimeUnit.MILLISECONDS.toNanos(1); + private static final ResourceLeakDetector leakDetector = ResourceLeakDetectorFactory.instance() + .newResourceLeakDetector(HashedWheelTimer.class, 1); + + private static final AtomicIntegerFieldUpdater WORKER_STATE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimer.class, "workerState"); + + private final ResourceLeakTracker leak; + private final Worker worker = new Worker(); + private final Thread workerThread; + + public static final int WORKER_STATE_INIT = 0; + public static final int WORKER_STATE_STARTED = 1; + public static final int WORKER_STATE_SHUTDOWN = 2; + @SuppressWarnings({"unused", "FieldMayBeFinal"}) + private volatile int workerState; // 0 - init, 1 - started, 2 - shut down + + private final long tickDuration; + private final HashedWheelBucket[] wheel; + private final int mask; + private final CountDownLatch startTimeInitialized = new CountDownLatch(1); + private final Queue timeouts = PlatformDependent.newMpscQueue(); + private final Queue cancelledTimeouts = PlatformDependent.newMpscQueue(); + private final AtomicLong pendingTimeouts = new AtomicLong(0); + private final long maxPendingTimeouts; + private final Executor taskExecutor; + + private volatile long startTime; + + /** + * Creates a new timer with the default thread factory + * 
({@link Executors#defaultThreadFactory()}), default tick duration, and + * default number of ticks per wheel. + */ + public HashedWheelTimer() { + this(Executors.defaultThreadFactory()); + } + + /** + * Creates a new timer with the default thread factory + * ({@link Executors#defaultThreadFactory()}) and default number of ticks + * per wheel. + * + * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @throws NullPointerException if {@code unit} is {@code null} + * @throws IllegalArgumentException if {@code tickDuration} is <= 0 + */ + public HashedWheelTimer(long tickDuration, TimeUnit unit) { + this(Executors.defaultThreadFactory(), tickDuration, unit); + } + + /** + * Creates a new timer with the default thread factory + * ({@link Executors#defaultThreadFactory()}). + * + * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @param ticksPerWheel the size of the wheel + * @throws NullPointerException if {@code unit} is {@code null} + * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 + */ + public HashedWheelTimer(long tickDuration, TimeUnit unit, int ticksPerWheel) { + this(Executors.defaultThreadFactory(), tickDuration, unit, ticksPerWheel); + } + + /** + * Creates a new timer with the default tick duration and default number of + * ticks per wheel. + * + * @param threadFactory a {@link ThreadFactory} that creates a + * background {@link Thread} which is dedicated to + * {@link TimerTask} execution. + * @throws NullPointerException if {@code threadFactory} is {@code null} + */ + public HashedWheelTimer(ThreadFactory threadFactory) { + this(threadFactory, 100, TimeUnit.MILLISECONDS); + } + + /** + * Creates a new timer with the default number of ticks per wheel. 
+ * + * @param threadFactory a {@link ThreadFactory} that creates a + * background {@link Thread} which is dedicated to + * {@link TimerTask} execution. + * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} + * @throws IllegalArgumentException if {@code tickDuration} is <= 0 + */ + public HashedWheelTimer( + ThreadFactory threadFactory, long tickDuration, TimeUnit unit) { + this(threadFactory, tickDuration, unit, 512); + } + + /** + * Creates a new timer. + * + * @param threadFactory a {@link ThreadFactory} that creates a + * background {@link Thread} which is dedicated to + * {@link TimerTask} execution. + * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @param ticksPerWheel the size of the wheel + * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} + * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 + */ + public HashedWheelTimer( + ThreadFactory threadFactory, + long tickDuration, TimeUnit unit, int ticksPerWheel) { + this(threadFactory, tickDuration, unit, ticksPerWheel, true); + } + + /** + * Creates a new timer. + * + * @param threadFactory a {@link ThreadFactory} that creates a + * background {@link Thread} which is dedicated to + * {@link TimerTask} execution. + * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @param ticksPerWheel the size of the wheel + * @param leakDetection {@code true} if leak detection should be enabled always, + * if false it will only be enabled if the worker thread is not + * a daemon thread. 
+ * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} + * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 + */ + public HashedWheelTimer( + ThreadFactory threadFactory, + long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection) { + this(threadFactory, tickDuration, unit, ticksPerWheel, leakDetection, -1); + } + + /** + * Creates a new timer. + * + * @param threadFactory a {@link ThreadFactory} that creates a + * background {@link Thread} which is dedicated to + * {@link TimerTask} execution. + * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @param ticksPerWheel the size of the wheel + * @param leakDetection {@code true} if leak detection should be enabled always, + * if false it will only be enabled if the worker thread is not + * a daemon thread. + * @param maxPendingTimeouts The maximum number of pending timeouts after which call to + * {@code newTimeout} will result in + * {@link java.util.concurrent.RejectedExecutionException} + * being thrown. No maximum pending timeouts limit is assumed if + * this value is 0 or negative. + * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} + * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 + */ + public HashedWheelTimer( + ThreadFactory threadFactory, + long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection, + long maxPendingTimeouts) { + this(threadFactory, tickDuration, unit, ticksPerWheel, leakDetection, + maxPendingTimeouts, ImmediateExecutor.INSTANCE); + } + + /** + * Creates a new timer. + * + * @param threadFactory a {@link ThreadFactory} that creates a + * background {@link Thread} which is dedicated to + * {@link TimerTask} execution. 
+ * @param tickDuration the duration between tick + * @param unit the time unit of the {@code tickDuration} + * @param ticksPerWheel the size of the wheel + * @param leakDetection {@code true} if leak detection should be enabled always, + * if false it will only be enabled if the worker thread is not + * a daemon thread. + * @param maxPendingTimeouts The maximum number of pending timeouts after which call to + * {@code newTimeout} will result in + * {@link java.util.concurrent.RejectedExecutionException} + * being thrown. No maximum pending timeouts limit is assumed if + * this value is 0 or negative. + * @param taskExecutor The {@link Executor} that is used to execute the submitted {@link TimerTask}s. + * The caller is responsible to shutdown the {@link Executor} once it is not needed + * anymore. + * @throws NullPointerException if either of {@code threadFactory} and {@code unit} is {@code null} + * @throws IllegalArgumentException if either of {@code tickDuration} and {@code ticksPerWheel} is <= 0 + */ + public HashedWheelTimer( + ThreadFactory threadFactory, + long tickDuration, TimeUnit unit, int ticksPerWheel, boolean leakDetection, + long maxPendingTimeouts, Executor taskExecutor) { + + checkNotNull(threadFactory, "threadFactory"); + checkNotNull(unit, "unit"); + checkPositive(tickDuration, "tickDuration"); + checkPositive(ticksPerWheel, "ticksPerWheel"); + this.taskExecutor = checkNotNull(taskExecutor, "taskExecutor"); + + // Normalize ticksPerWheel to power of two and initialize the wheel. + wheel = createWheel(ticksPerWheel); + mask = wheel.length - 1; + + // Convert tickDuration to nanos. + long duration = unit.toNanos(tickDuration); + + // Prevent overflow. 
+ if (duration >= Long.MAX_VALUE / wheel.length) { + throw new IllegalArgumentException(String.format( + "tickDuration: %d (expected: 0 < tickDuration in nanos < %d", + tickDuration, Long.MAX_VALUE / wheel.length)); + } + + if (duration < MILLISECOND_NANOS) { + logger.warn("Configured tickDuration {} smaller than {}, using 1ms.", + tickDuration, MILLISECOND_NANOS); + this.tickDuration = MILLISECOND_NANOS; + } else { + this.tickDuration = duration; + } + + workerThread = threadFactory.newThread(worker); + + leak = leakDetection || !workerThread.isDaemon() ? leakDetector.track(this) : null; + + this.maxPendingTimeouts = maxPendingTimeouts; + + if (INSTANCE_COUNTER.incrementAndGet() > INSTANCE_COUNT_LIMIT && + WARNED_TOO_MANY_INSTANCES.compareAndSet(false, true)) { + reportTooManyInstances(); + } + } + + @Override + protected void finalize() throws Throwable { + try { + super.finalize(); + } finally { + // This object is going to be GCed and it is assumed the ship has sailed to do a proper shutdown. If + // we have not yet shutdown then we want to make sure we decrement the active instance count. + if (WORKER_STATE_UPDATER.getAndSet(this, WORKER_STATE_SHUTDOWN) != WORKER_STATE_SHUTDOWN) { + INSTANCE_COUNTER.decrementAndGet(); + } + } + } + + private static HashedWheelBucket[] createWheel(int ticksPerWheel) { + //ticksPerWheel may not be greater than 2^30 + checkInRange(ticksPerWheel, 1, 1073741824, "ticksPerWheel"); + + ticksPerWheel = normalizeTicksPerWheel(ticksPerWheel); + HashedWheelBucket[] wheel = new HashedWheelBucket[ticksPerWheel]; + for (int i = 0; i < wheel.length; i++) { + wheel[i] = new HashedWheelBucket(); + } + return wheel; + } + + private static int normalizeTicksPerWheel(int ticksPerWheel) { + int normalizedTicksPerWheel = 1; + while (normalizedTicksPerWheel < ticksPerWheel) { + normalizedTicksPerWheel <<= 1; + } + return normalizedTicksPerWheel; + } + + /** + * Starts the background thread explicitly. 
The background thread will + * start automatically on demand even if you did not call this method. + * + * @throws IllegalStateException if this timer has been + * {@linkplain #stop() stopped} already + */ + public void start() { + switch (WORKER_STATE_UPDATER.get(this)) { + case WORKER_STATE_INIT: + if (WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_INIT, WORKER_STATE_STARTED)) { + workerThread.start(); + } + break; + case WORKER_STATE_STARTED: + break; + case WORKER_STATE_SHUTDOWN: + throw new IllegalStateException("cannot be started once stopped"); + default: + throw new Error("Invalid WorkerState"); + } + + // Wait until the startTime is initialized by the worker. + while (startTime == 0) { + try { + startTimeInitialized.await(); + } catch (InterruptedException ignore) { + // Ignore - it will be ready very soon. + } + } + } + + @Override + public Set stop() { + if (Thread.currentThread() == workerThread) { + throw new IllegalStateException( + HashedWheelTimer.class.getSimpleName() + + ".stop() cannot be called from " + + TimerTask.class.getSimpleName()); + } + + if (!WORKER_STATE_UPDATER.compareAndSet(this, WORKER_STATE_STARTED, WORKER_STATE_SHUTDOWN)) { + // workerState can be 0 or 2 at this moment - let it always be 2. 
+ if (WORKER_STATE_UPDATER.getAndSet(this, WORKER_STATE_SHUTDOWN) != WORKER_STATE_SHUTDOWN) { + INSTANCE_COUNTER.decrementAndGet(); + if (leak != null) { + boolean closed = leak.close(this); + assert closed; + } + } + + return Collections.emptySet(); + } + + try { + boolean interrupted = false; + while (workerThread.isAlive()) { + workerThread.interrupt(); + try { + workerThread.join(100); + } catch (InterruptedException ignored) { + interrupted = true; + } + } + + if (interrupted) { + Thread.currentThread().interrupt(); + } + } finally { + INSTANCE_COUNTER.decrementAndGet(); + if (leak != null) { + boolean closed = leak.close(this); + assert closed; + } + } + return worker.unprocessedTimeouts(); + } + + @Override + public Timeout newTimeout(TimerTask task, long delay, TimeUnit unit) { + checkNotNull(task, "task"); + checkNotNull(unit, "unit"); + + long pendingTimeoutsCount = pendingTimeouts.incrementAndGet(); + + if (maxPendingTimeouts > 0 && pendingTimeoutsCount > maxPendingTimeouts) { + pendingTimeouts.decrementAndGet(); + throw new RejectedExecutionException("Number of pending timeouts (" + + pendingTimeoutsCount + ") is greater than or equal to maximum allowed pending " + + "timeouts (" + maxPendingTimeouts + ")"); + } + + start(); + + // Add the timeout to the timeout queue which will be processed on the next tick. + // During processing all the queued HashedWheelTimeouts will be added to the correct HashedWheelBucket. + long deadline = System.nanoTime() + unit.toNanos(delay) - startTime; + + // Guard against overflow. + if (delay > 0 && deadline < 0) { + deadline = Long.MAX_VALUE; + } + HashedWheelTimeout timeout = new HashedWheelTimeout(this, task, deadline); + timeouts.add(timeout); + return timeout; + } + + /** + * Returns the number of pending timeouts of this {@link Timer}. 
+ */ + public long pendingTimeouts() { + return pendingTimeouts.get(); + } + + private static void reportTooManyInstances() { + if (logger.isErrorEnabled()) { + String resourceType = simpleClassName(HashedWheelTimer.class); + logger.error("You are creating too many " + resourceType + " instances. " + + resourceType + " is a shared resource that must be reused across the JVM, " + + "so that only a few instances are created."); + } + } + + private final class Worker implements Runnable { + private final Set unprocessedTimeouts = new HashSet(); + + private long tick; + + @Override + public void run() { + // Initialize the startTime. + startTime = System.nanoTime(); + if (startTime == 0) { + // We use 0 as an indicator for the uninitialized value here, so make sure it's not 0 when initialized. + startTime = 1; + } + + // Notify the other threads waiting for the initialization at start(). + startTimeInitialized.countDown(); + + do { + final long deadline = waitForNextTick(); + if (deadline > 0) { + int idx = (int) (tick & mask); + processCancelledTasks(); + HashedWheelBucket bucket = + wheel[idx]; + transferTimeoutsToBuckets(); + bucket.expireTimeouts(deadline); + tick++; + } + } while (WORKER_STATE_UPDATER.get(HashedWheelTimer.this) == WORKER_STATE_STARTED); + + // Fill the unprocessedTimeouts so we can return them from stop() method. + for (HashedWheelBucket bucket : wheel) { + bucket.clearTimeouts(unprocessedTimeouts); + } + for (; ; ) { + HashedWheelTimeout timeout = timeouts.poll(); + if (timeout == null) { + break; + } + if (!timeout.isCancelled()) { + unprocessedTimeouts.add(timeout); + } + } + processCancelledTasks(); + } + + private void transferTimeoutsToBuckets() { + // transfer only max. 100000 timeouts per tick to prevent a thread to stale the workerThread when it just + // adds new timeouts in a loop. 
+ for (int i = 0; i < 100000; i++) { + HashedWheelTimeout timeout = timeouts.poll(); + if (timeout == null) { + // all processed + break; + } + if (timeout.state() == HashedWheelTimeout.ST_CANCELLED) { + // Was cancelled in the meantime. + continue; + } + + long calculated = timeout.deadline / tickDuration; + timeout.remainingRounds = (calculated - tick) / wheel.length; + + final long ticks = Math.max(calculated, tick); // Ensure we don't schedule for past. + int stopIndex = (int) (ticks & mask); + + HashedWheelBucket bucket = wheel[stopIndex]; + bucket.addTimeout(timeout); + } + } + + private void processCancelledTasks() { + for (; ; ) { + HashedWheelTimeout timeout = cancelledTimeouts.poll(); + if (timeout == null) { + // all processed + break; + } + try { + timeout.remove(); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown while process a cancellation task", t); + } + } + } + } + + /** + * calculate goal nanoTime from startTime and current tick number, + * then wait until that goal has been reached. + * + * @return Long.MIN_VALUE if received a shutdown request, + * current time otherwise (with Long.MIN_VALUE changed by +1) + */ + private long waitForNextTick() { + long deadline = tickDuration * (tick + 1); + + for (; ; ) { + final long currentTime = System.nanoTime() - startTime; + long sleepTimeMs = (deadline - currentTime + 999999) / 1000000; + + if (sleepTimeMs <= 0) { + if (currentTime == Long.MIN_VALUE) { + return -Long.MAX_VALUE; + } else { + return currentTime; + } + } + + // Check if we run on windows, as if thats the case we will need + // to round the sleepTime as workaround for a bug that only affect + // the JVM if it runs on windows. 
+ // + // See https://github.com/netty/netty/issues/356 + if (PlatformDependent.isWindows()) { + sleepTimeMs = sleepTimeMs / 10 * 10; + if (sleepTimeMs == 0) { + sleepTimeMs = 1; + } + } + + try { + Thread.sleep(sleepTimeMs); + } catch (InterruptedException ignored) { + if (WORKER_STATE_UPDATER.get(HashedWheelTimer.this) == WORKER_STATE_SHUTDOWN) { + return Long.MIN_VALUE; + } + } + } + } + + public Set unprocessedTimeouts() { + return Collections.unmodifiableSet(unprocessedTimeouts); + } + } + + private static final class HashedWheelTimeout implements Timeout, Runnable { + + private static final int ST_INIT = 0; + private static final int ST_CANCELLED = 1; + private static final int ST_EXPIRED = 2; + private static final AtomicIntegerFieldUpdater STATE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(HashedWheelTimeout.class, "state"); + + private final HashedWheelTimer timer; + private final TimerTask task; + private final long deadline; + + @SuppressWarnings({"unused", "FieldMayBeFinal", "RedundantFieldInitialization"}) + private volatile int state = ST_INIT; + + // remainingRounds will be calculated and set by Worker.transferTimeoutsToBuckets() before the + // HashedWheelTimeout will be added to the correct HashedWheelBucket. + long remainingRounds; + + // This will be used to chain timeouts in HashedWheelTimerBucket via a double-linked-list. + // As only the workerThread will act on it there is no need for synchronization / volatile. + HashedWheelTimeout next; + HashedWheelTimeout prev; + + // The bucket to which the timeout was added + HashedWheelBucket bucket; + + HashedWheelTimeout(HashedWheelTimer timer, TimerTask task, long deadline) { + this.timer = timer; + this.task = task; + this.deadline = deadline; + } + + @Override + public Timer timer() { + return timer; + } + + @Override + public TimerTask task() { + return task; + } + + @Override + public boolean cancel() { + // only update the state it will be removed from HashedWheelBucket on next tick. 
+ if (!compareAndSetState(ST_INIT, ST_CANCELLED)) { + return false; + } + // If a task should be canceled we put this to another queue which will be processed on each tick. + // So this means that we will have a GC latency of max. 1 tick duration which is good enough. This way + // we can make again use of our MpscLinkedQueue and so minimize the locking / overhead as much as possible. + timer.cancelledTimeouts.add(this); + return true; + } + + void remove() { + HashedWheelBucket bucket = this.bucket; + if (bucket != null) { + bucket.remove(this); + } else { + timer.pendingTimeouts.decrementAndGet(); + } + } + + public boolean compareAndSetState(int expected, int state) { + return STATE_UPDATER.compareAndSet(this, expected, state); + } + + public int state() { + return state; + } + + @Override + public boolean isCancelled() { + return state() == ST_CANCELLED; + } + + @Override + public boolean isExpired() { + return state() == ST_EXPIRED; + } + + public void expire() { + if (!compareAndSetState(ST_INIT, ST_EXPIRED)) { + return; + } + + try { + timer.taskExecutor.execute(this); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown while submit " + TimerTask.class.getSimpleName() + + " for execution.", t); + } + } + } + + @Override + public void run() { + try { + task.run(this); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown by " + TimerTask.class.getSimpleName() + '.', t); + } + } + } + + @Override + public String toString() { + final long currentTime = System.nanoTime(); + long remaining = deadline - currentTime + timer.startTime; + + StringBuilder buf = new StringBuilder(192) + .append(simpleClassName(this)) + .append('(') + .append("deadline: "); + if (remaining > 0) { + buf.append(remaining) + .append(" ns later"); + } else if (remaining < 0) { + buf.append(-remaining) + .append(" ns ago"); + } else { + buf.append("now"); + } + + if (isCancelled()) { + buf.append(", 
cancelled"); + } + + return buf.append(", task: ") + .append(task()) + .append(')') + .toString(); + } + } + + /** + * Bucket that stores HashedWheelTimeouts. These are stored in a linked-list like datastructure to allow easy + * removal of HashedWheelTimeouts in the middle. Also the HashedWheelTimeout act as nodes themself and so no + * extra object creation is needed. + */ + private static final class HashedWheelBucket { + // Used for the linked-list datastructure + private HashedWheelTimeout head; + private HashedWheelTimeout tail; + + /** + * Add {@link HashedWheelTimeout} to this bucket. + */ + public void addTimeout(HashedWheelTimeout timeout) { + assert timeout.bucket == null; + timeout.bucket = this; + if (head == null) { + head = tail = timeout; + } else { + tail.next = timeout; + timeout.prev = tail; + tail = timeout; + } + } + + /** + * Expire all {@link HashedWheelTimeout}s for the given {@code deadline}. + */ + public void expireTimeouts(long deadline) { + HashedWheelTimeout timeout = head; + + // process all timeouts + while (timeout != null) { + HashedWheelTimeout next = timeout.next; + if (timeout.remainingRounds <= 0) { + next = remove(timeout); + if (timeout.deadline <= deadline) { + timeout.expire(); + } else { + // The timeout was placed into a wrong slot. This should never happen. 
+ throw new IllegalStateException(String.format( + "timeout.deadline (%d) > deadline (%d)", timeout.deadline, deadline)); + } + } else if (timeout.isCancelled()) { + next = remove(timeout); + } else { + timeout.remainingRounds--; + } + timeout = next; + } + } + + public HashedWheelTimeout remove(HashedWheelTimeout timeout) { + HashedWheelTimeout next = timeout.next; + // remove timeout that was either processed or cancelled by updating the linked-list + if (timeout.prev != null) { + timeout.prev.next = next; + } + if (timeout.next != null) { + timeout.next.prev = timeout.prev; + } + + if (timeout == head) { + // if timeout is also the tail we need to adjust the entry too + if (timeout == tail) { + tail = null; + head = null; + } else { + head = next; + } + } else if (timeout == tail) { + // if the timeout is the tail modify the tail to be the prev node. + tail = timeout.prev; + } + // null out prev, next and bucket to allow for GC. + timeout.prev = null; + timeout.next = null; + timeout.bucket = null; + timeout.timer.pendingTimeouts.decrementAndGet(); + return next; + } + + /** + * Clear this bucket and return all not expired / cancelled {@link Timeout}s. + */ + public void clearTimeouts(Set set) { + for (; ; ) { + HashedWheelTimeout timeout = pollTimeout(); + if (timeout == null) { + return; + } + if (timeout.isExpired() || timeout.isCancelled()) { + continue; + } + set.add(timeout); + } + } + + private HashedWheelTimeout pollTimeout() { + HashedWheelTimeout head = this.head; + if (head == null) { + return null; + } + HashedWheelTimeout next = head.next; + if (next == null) { + tail = this.head = null; + } else { + this.head = next; + next.prev = null; + } + + // null out prev and next to allow for GC. 
+ head.next = null; + head.prev = null; + head.bucket = null; + return head; + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/HashingStrategy.java b/netty-util/src/main/java/io/netty/util/HashingStrategy.java new file mode 100644 index 0000000..46a3b97 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/HashingStrategy.java @@ -0,0 +1,75 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import java.util.Objects; + +/** + * Abstraction for hash code generation and equality comparison. + */ +public interface HashingStrategy { + /** + * Generate a hash code for {@code obj}. + *

+ * This method must obey the same relationship that {@link java.lang.Object#hashCode()} has with + * {@link java.lang.Object#equals(Object)}: + *

+ * <ul>
+ * <li>Calling this method multiple times with the same {@code obj} should return the same result</li>
+ * <li>If {@link #equals(Object, Object)} with parameters {@code a} and {@code b} returns {@code true}
+ * then the return value for this method for parameters {@code a} and {@code b} must return the same result</li>
+ * <li>If {@link #equals(Object, Object)} with parameters {@code a} and {@code b} returns {@code false}
+ * then the return value for this method for parameters {@code a} and {@code b} does not have to
+ * return different results. However, this property is desirable.</li>
+ * <li>If {@code obj} is {@code null} then this method returns {@code 0}</li>
+ * </ul>
+ */ + int hashCode(T obj); + + /** + * Returns {@code true} if the arguments are equal to each other and {@code false} otherwise. + * This method has the following restrictions: + *
+ * <ul>
+ * <li>reflexive - {@code equals(a, a)} should return true</li>
+ * <li>symmetric - {@code equals(a, b)} returns {@code true} if {@code equals(b, a)} returns
+ * {@code true}</li>
+ * <li>transitive - if {@code equals(a, b)} returns {@code true} and {@code equals(a, c)} returns
+ * {@code true} then {@code equals(b, c)} should also return {@code true}</li>
+ * <li>consistent - {@code equals(a, b)} should return the same result when called multiple times
+ * assuming {@code a} and {@code b} remain unchanged relative to the comparison criteria</li>
+ * <li>if {@code a} and {@code b} are both {@code null} then this method returns {@code true}</li>
+ * <li>if {@code a} is {@code null} and {@code b} is non-{@code null}, or {@code a} is non-{@code null} and
+ * {@code b} is {@code null} then this method returns {@code false}</li>
+ * </ul>
+ */ + boolean equals(T a, T b); + + /** + * A {@link HashingStrategy} which delegates to java's {@link Object#hashCode()} + * and {@link Object#equals(Object)}. + */ + @SuppressWarnings("rawtypes") + HashingStrategy JAVA_HASHER = new HashingStrategy() { + @Override + public int hashCode(Object obj) { + return obj != null ? obj.hashCode() : 0; + } + + @Override + public boolean equals(Object a, Object b) { + return Objects.equals(a, b); + } + }; +} diff --git a/netty-util/src/main/java/io/netty/util/IllegalReferenceCountException.java b/netty-util/src/main/java/io/netty/util/IllegalReferenceCountException.java new file mode 100644 index 0000000..22deb7c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/IllegalReferenceCountException.java @@ -0,0 +1,49 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +/** + * An {@link IllegalStateException} which is raised when a user attempts to access a {@link ReferenceCounted} whose + * reference count has been decreased to 0 (and consequently freed). 
+ */ +public class IllegalReferenceCountException extends IllegalStateException { + + private static final long serialVersionUID = -2507492394288153468L; + + public IllegalReferenceCountException() { + } + + public IllegalReferenceCountException(int refCnt) { + this("refCnt: " + refCnt); + } + + public IllegalReferenceCountException(int refCnt, int increment) { + this("refCnt: " + refCnt + ", " + (increment > 0 ? "increment: " + increment : "decrement: " + -increment)); + } + + public IllegalReferenceCountException(String message) { + super(message); + } + + public IllegalReferenceCountException(String message, Throwable cause) { + super(message, cause); + } + + public IllegalReferenceCountException(Throwable cause) { + super(cause); + } +} diff --git a/netty-util/src/main/java/io/netty/util/IntSupplier.java b/netty-util/src/main/java/io/netty/util/IntSupplier.java new file mode 100644 index 0000000..2840e05 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/IntSupplier.java @@ -0,0 +1,29 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * Represents a supplier of {@code int}-valued results. + */ +public interface IntSupplier { + + /** + * Gets a result. 
+ * + * @return a result + */ + int get() throws Exception; +} diff --git a/netty-util/src/main/java/io/netty/util/Mapping.java b/netty-util/src/main/java/io/netty/util/Mapping.java new file mode 100644 index 0000000..b57c746 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/Mapping.java @@ -0,0 +1,27 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * Maintains the mapping from the objects of one type to the objects of the other type. + */ +public interface Mapping { + + /** + * Returns mapped value of the specified input. + */ + OUT map(IN input); +} diff --git a/netty-util/src/main/java/io/netty/util/NetUtil.java b/netty-util/src/main/java/io/netty/util/NetUtil.java new file mode 100644 index 0000000..9c014f5 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/NetUtil.java @@ -0,0 +1,1102 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.NetUtilInitializations.NetworkIfaceAndInetAddress; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileReader; +import java.io.IOException; +import java.io.InputStream; +import java.io.InputStreamReader; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.NetworkInterface; +import java.net.UnknownHostException; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Arrays; +import java.util.Collection; +import static io.netty.util.AsciiString.indexOf; + +/** + * A class that holds a number of network-related constants. + *

+ * This class borrowed some of its methods from a modified fork of the + * Inet6Util class which was part of Apache Harmony. + */ +public final class NetUtil { + + /** + * The {@link Inet4Address} that represents the IPv4 loopback address '127.0.0.1' + */ + public static final Inet4Address LOCALHOST4; + + /** + * The {@link Inet6Address} that represents the IPv6 loopback address '::1' + */ + public static final Inet6Address LOCALHOST6; + + /** + * The {@link InetAddress} that represents the loopback address. If IPv6 stack is available, it will refer to + * {@link #LOCALHOST6}. Otherwise, {@link #LOCALHOST4}. + */ + public static final InetAddress LOCALHOST; + + /** + * The loopback {@link NetworkInterface} of the current machine + */ + public static final NetworkInterface LOOPBACK_IF; + + /** + * An unmodifiable Collection of all the interfaces on this machine. + */ + public static final Collection NETWORK_INTERFACES; + + /** + * The SOMAXCONN value of the current machine. If failed to get the value, {@code 200} is used as a + * default value for Windows and {@code 128} for others. 
+ */ + public static final int SOMAXCONN; + + /** + * This defines how many words (represented as ints) are needed to represent an IPv6 address + */ + private static final int IPV6_WORD_COUNT = 8; + + /** + * The maximum number of characters for an IPV6 string with no scope + */ + private static final int IPV6_MAX_CHAR_COUNT = 39; + + /** + * Number of bytes needed to represent an IPV6 value + */ + private static final int IPV6_BYTE_COUNT = 16; + + /** + * Maximum amount of value adding characters in between IPV6 separators + */ + private static final int IPV6_MAX_CHAR_BETWEEN_SEPARATOR = 4; + + /** + * Minimum number of separators that must be present in an IPv6 string + */ + private static final int IPV6_MIN_SEPARATORS = 2; + + /** + * Maximum number of separators that must be present in an IPv6 string + */ + private static final int IPV6_MAX_SEPARATORS = 8; + + /** + * Maximum amount of value adding characters in between IPV4 separators + */ + private static final int IPV4_MAX_CHAR_BETWEEN_SEPARATOR = 3; + + /** + * Number of separators that must be present in an IPv4 string + */ + private static final int IPV4_SEPARATORS = 3; + + /** + * {@code true} if IPv4 should be used even if the system supports both IPv4 and IPv6. + */ + private static final boolean IPV4_PREFERRED = SystemPropertyUtil.getBoolean("java.net.preferIPv4Stack", false); + + /** + * {@code true} if an IPv6 address should be preferred when a host has both an IPv4 address and an IPv6 address. + */ + private static final boolean IPV6_ADDRESSES_PREFERRED; + + /** + * The logger being used by this class + */ + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NetUtil.class); + + static { + String prefer = SystemPropertyUtil.get("java.net.preferIPv6Addresses", "false"); + // Let's just use false in this case as only true is "forcing" ipv6. 
+ IPV6_ADDRESSES_PREFERRED = "true".equalsIgnoreCase(prefer.trim()); + logger.debug("-Djava.net.preferIPv4Stack: {}", IPV4_PREFERRED); + logger.debug("-Djava.net.preferIPv6Addresses: {}", prefer); + + NETWORK_INTERFACES = NetUtilInitializations.networkInterfaces(); + + // Create IPv4 loopback address. + LOCALHOST4 = NetUtilInitializations.createLocalhost4(); + + // Create IPv6 loopback address. + LOCALHOST6 = NetUtilInitializations.createLocalhost6(); + + NetworkIfaceAndInetAddress loopback = + NetUtilInitializations.determineLoopback(NETWORK_INTERFACES, LOCALHOST4, LOCALHOST6); + LOOPBACK_IF = loopback.iface(); + LOCALHOST = loopback.address(); + + // As a SecurityManager may prevent reading the somaxconn file we wrap this in a privileged block. + // + // See https://github.com/netty/netty/issues/3680 + SOMAXCONN = AccessController.doPrivileged(new SoMaxConnAction()); + } + + private static final class SoMaxConnAction implements PrivilegedAction { + @Override + public Integer run() { + // Determine the default somaxconn (server socket backlog) value of the platform. + // The known defaults: + // - Windows NT Server 4.0+: 200 + // - Linux and Mac OS X: 128 + int somaxconn = PlatformDependent.isWindows() ? 200 : 128; + File file = new File("/proc/sys/net/core/somaxconn"); + BufferedReader in = null; + try { + // file.exists() may throw a SecurityException if a SecurityManager is used, so execute it in the + // try / catch block. 
+ // See https://github.com/netty/netty/issues/4936 + if (file.exists()) { + in = new BufferedReader(new FileReader(file)); + somaxconn = Integer.parseInt(in.readLine()); + if (logger.isDebugEnabled()) { + logger.debug("{}: {}", file, somaxconn); + } + } else { + // Try to get from sysctl + Integer tmp = null; + if (SystemPropertyUtil.getBoolean("io.netty.net.somaxconn.trySysctl", false)) { + tmp = sysctlGetInt("kern.ipc.somaxconn"); + if (tmp == null) { + tmp = sysctlGetInt("kern.ipc.soacceptqueue"); + if (tmp != null) { + somaxconn = tmp; + } + } else { + somaxconn = tmp; + } + } + + if (tmp == null) { + logger.debug("Failed to get SOMAXCONN from sysctl and file {}. Default: {}", file, + somaxconn); + } + } + } catch (Exception e) { + if (logger.isDebugEnabled()) { + logger.debug("Failed to get SOMAXCONN from sysctl and file {}. Default: {}", + file, somaxconn, e); + } + } finally { + if (in != null) { + try { + in.close(); + } catch (Exception e) { + // Ignored. + } + } + } + return somaxconn; + } + } + + /** + * This will execute sysctl with the {@code sysctlKey} + * which is expected to return the numeric value for for {@code sysctlKey}. + * + * @param sysctlKey The key which the return value corresponds to. + * @return The sysctl value for {@code sysctlKey}. 
+ */ + private static Integer sysctlGetInt(String sysctlKey) throws IOException { + Process process = new ProcessBuilder("sysctl", sysctlKey).start(); + try { + // Suppress warnings about resource leaks since the buffered reader is closed below + InputStream is = process.getInputStream(); + InputStreamReader isr = new InputStreamReader(is); + BufferedReader br = new BufferedReader(isr); + try { + String line = br.readLine(); + if (line != null && line.startsWith(sysctlKey)) { + for (int i = line.length() - 1; i > sysctlKey.length(); --i) { + if (!Character.isDigit(line.charAt(i))) { + return Integer.valueOf(line.substring(i + 1)); + } + } + } + return null; + } finally { + br.close(); + } + } finally { + // No need of 'null' check because we're initializing + // the Process instance in first line. Any exception + // raised will directly lead to throwable. + process.destroy(); + } + } + + /** + * Returns {@code true} if IPv4 should be used even if the system supports both IPv4 and IPv6. Setting this + * property to {@code true} will disable IPv6 support. The default value of this property is {@code false}. + * + * @see Java SE + * networking properties + */ + public static boolean isIpV4StackPreferred() { + return IPV4_PREFERRED; + } + + /** + * Returns {@code true} if an IPv6 address should be preferred when a host has both an IPv4 address and an IPv6 + * address. The default value of this property is {@code false}. + * + * @see Java SE + * networking properties + */ + public static boolean isIpV6AddressesPreferred() { + return IPV6_ADDRESSES_PREFERRED; + } + + /** + * Creates an byte[] based on an ipAddressString. No error handling is performed here. 
+ */ + public static byte[] createByteArrayFromIpAddressString(String ipAddressString) { + + if (isValidIpV4Address(ipAddressString)) { + return validIpV4ToBytes(ipAddressString); + } + + if (isValidIpV6Address(ipAddressString)) { + if (ipAddressString.charAt(0) == '[') { + ipAddressString = ipAddressString.substring(1, ipAddressString.length() - 1); + } + + int percentPos = ipAddressString.indexOf('%'); + if (percentPos >= 0) { + ipAddressString = ipAddressString.substring(0, percentPos); + } + + return getIPv6ByName(ipAddressString, true); + } + return null; + } + + /** + * Creates an {@link InetAddress} based on an ipAddressString or might return null if it can't be parsed. + * No error handling is performed here. + */ + public static InetAddress createInetAddressFromIpAddressString(String ipAddressString) { + if (isValidIpV4Address(ipAddressString)) { + byte[] bytes = validIpV4ToBytes(ipAddressString); + try { + return InetAddress.getByAddress(bytes); + } catch (UnknownHostException e) { + // Should never happen! + throw new IllegalStateException(e); + } + } + + if (isValidIpV6Address(ipAddressString)) { + if (ipAddressString.charAt(0) == '[') { + ipAddressString = ipAddressString.substring(1, ipAddressString.length() - 1); + } + + int percentPos = ipAddressString.indexOf('%'); + if (percentPos >= 0) { + try { + int scopeId = Integer.parseInt(ipAddressString.substring(percentPos + 1)); + ipAddressString = ipAddressString.substring(0, percentPos); + byte[] bytes = getIPv6ByName(ipAddressString, true); + if (bytes == null) { + return null; + } + try { + return Inet6Address.getByAddress(null, bytes, scopeId); + } catch (UnknownHostException e) { + // Should never happen! 
+ throw new IllegalStateException(e); + } + } catch (NumberFormatException e) { + return null; + } + } + byte[] bytes = getIPv6ByName(ipAddressString, true); + if (bytes == null) { + return null; + } + try { + return InetAddress.getByAddress(bytes); + } catch (UnknownHostException e) { + // Should never happen! + throw new IllegalStateException(e); + } + } + return null; + } + + private static int decimalDigit(String str, int pos) { + return str.charAt(pos) - '0'; + } + + private static byte ipv4WordToByte(String ip, int from, int toExclusive) { + int ret = decimalDigit(ip, from); + from++; + if (from == toExclusive) { + return (byte) ret; + } + ret = ret * 10 + decimalDigit(ip, from); + from++; + if (from == toExclusive) { + return (byte) ret; + } + return (byte) (ret * 10 + decimalDigit(ip, from)); + } + + // visible for tests + static byte[] validIpV4ToBytes(String ip) { + int i; + return new byte[]{ + ipv4WordToByte(ip, 0, i = ip.indexOf('.', 1)), + ipv4WordToByte(ip, i + 1, i = ip.indexOf('.', i + 2)), + ipv4WordToByte(ip, i + 1, i = ip.indexOf('.', i + 2)), + ipv4WordToByte(ip, i + 1, ip.length()) + }; + } + + /** + * Convert {@link Inet4Address} into {@code int} + */ + public static int ipv4AddressToInt(Inet4Address ipAddress) { + byte[] octets = ipAddress.getAddress(); + + return (octets[0] & 0xff) << 24 | + (octets[1] & 0xff) << 16 | + (octets[2] & 0xff) << 8 | + octets[3] & 0xff; + } + + /** + * Converts a 32-bit integer into an IPv4 address. + */ + public static String intToIpAddress(int i) { + String buf = String.valueOf(i >> 24 & 0xff) + + '.' + + (i >> 16 & 0xff) + + '.' + + (i >> 8 & 0xff) + + '.' + + (i & 0xff); + return buf; + } + + /** + * Converts 4-byte or 16-byte data into an IPv4 or IPv6 string respectively. 
+ * + * @throws IllegalArgumentException if {@code length} is not {@code 4} nor {@code 16} + */ + public static String bytesToIpAddress(byte[] bytes) { + return bytesToIpAddress(bytes, 0, bytes.length); + } + + /** + * Converts 4-byte or 16-byte data into an IPv4 or IPv6 string respectively. + * + * @throws IllegalArgumentException if {@code length} is not {@code 4} nor {@code 16} + */ + public static String bytesToIpAddress(byte[] bytes, int offset, int length) { + switch (length) { + case 4: { + return String.valueOf(bytes[offset] & 0xff) + + '.' + + (bytes[offset + 1] & 0xff) + + '.' + + (bytes[offset + 2] & 0xff) + + '.' + + (bytes[offset + 3] & 0xff); + } + case 16: + return toAddressString(bytes, offset, false); + default: + throw new IllegalArgumentException("length: " + length + " (expected: 4 or 16)"); + } + } + + public static boolean isValidIpV6Address(String ip) { + return isValidIpV6Address((CharSequence) ip); + } + + public static boolean isValidIpV6Address(CharSequence ip) { + int end = ip.length(); + if (end < 2) { + return false; + } + + // strip "[]" + int start; + char c = ip.charAt(0); + if (c == '[') { + end--; + if (ip.charAt(end) != ']') { + // must have a close ] + return false; + } + start = 1; + c = ip.charAt(1); + } else { + start = 0; + } + + int colons; + int compressBegin; + if (c == ':') { + // an IPv6 address can start with "::" or with a number + if (ip.charAt(start + 1) != ':') { + return false; + } + colons = 2; + compressBegin = start; + start += 2; + } else { + colons = 0; + compressBegin = -1; + } + + int wordLen = 0; + loop: + for (int i = start; i < end; i++) { + c = ip.charAt(i); + if (isValidHexChar(c)) { + if (wordLen < 4) { + wordLen++; + continue; + } + return false; + } + + switch (c) { + case ':': + if (colons > 7) { + return false; + } + if (ip.charAt(i - 1) == ':') { + if (compressBegin >= 0) { + return false; + } + compressBegin = i - 1; + } else { + wordLen = 0; + } + colons++; + break; + case '.': + // case for 
the last 32-bits represented as IPv4 x:x:x:x:x:x:d.d.d.d + + // check a normal case (6 single colons) + if (compressBegin < 0 && colons != 6 || + // a special case ::1:2:3:4:5:d.d.d.d allows 7 colons with an + // IPv4 ending, otherwise 7 :'s is bad + (colons == 7 && compressBegin >= start || colons > 7)) { + return false; + } + + // Verify this address is of the correct structure to contain an IPv4 address. + // It must be IPv4-Mapped or IPv4-Compatible + // (see https://tools.ietf.org/html/rfc4291#section-2.5.5). + int ipv4Start = i - wordLen; + int j = ipv4Start - 2; // index of character before the previous ':'. + if (isValidIPv4MappedChar(ip.charAt(j))) { + if (!isValidIPv4MappedChar(ip.charAt(j - 1)) || + !isValidIPv4MappedChar(ip.charAt(j - 2)) || + !isValidIPv4MappedChar(ip.charAt(j - 3))) { + return false; + } + j -= 5; + } + + for (; j >= start; --j) { + char tmpChar = ip.charAt(j); + if (tmpChar != '0' && tmpChar != ':') { + return false; + } + } + + // 7 - is minimum IPv4 address length + int ipv4End = indexOf(ip, '%', ipv4Start + 7); + if (ipv4End < 0) { + ipv4End = end; + } + return isValidIpV4Address(ip, ipv4Start, ipv4End); + case '%': + // strip the interface name/index after the percent sign + end = i; + break loop; + default: + return false; + } + } + + // normal case without compression + if (compressBegin < 0) { + return colons == 7 && wordLen > 0; + } + + return compressBegin + 2 == end || + // 8 colons is valid only if compression in start or end + wordLen > 0 && (colons < 8 || compressBegin <= start); + } + + private static boolean isValidIpV4Word(CharSequence word, int from, int toExclusive) { + int len = toExclusive - from; + char c0, c1, c2; + if (len < 1 || len > 3 || (c0 = word.charAt(from)) < '0') { + return false; + } + if (len == 3) { + return (c1 = word.charAt(from + 1)) >= '0' && + (c2 = word.charAt(from + 2)) >= '0' && + (c0 <= '1' && c1 <= '9' && c2 <= '9' || + c0 == '2' && c1 <= '5' && (c2 <= '5' || c1 < '5' && c2 <= '9')); + } + 
return c0 <= '9' && (len == 1 || isValidNumericChar(word.charAt(from + 1))); + } + + private static boolean isValidHexChar(char c) { + return c >= '0' && c <= '9' || c >= 'A' && c <= 'F' || c >= 'a' && c <= 'f'; + } + + private static boolean isValidNumericChar(char c) { + return c >= '0' && c <= '9'; + } + + private static boolean isValidIPv4MappedChar(char c) { + return c == 'f' || c == 'F'; + } + + private static boolean isValidIPv4MappedSeparators(byte b0, byte b1, boolean mustBeZero) { + // We allow IPv4 Mapped (https://tools.ietf.org/html/rfc4291#section-2.5.5.1) + // and IPv4 compatible (https://tools.ietf.org/html/rfc4291#section-2.5.5.1). + // The IPv4 compatible is deprecated, but it allows parsing of plain IPv4 addressed into IPv6-Mapped addresses. + return b0 == b1 && (b0 == 0 || !mustBeZero && b1 == -1); + } + + private static boolean isValidIPv4Mapped(byte[] bytes, int currentIndex, int compressBegin, int compressLength) { + final boolean mustBeZero = compressBegin + compressLength >= 14; + return currentIndex <= 12 && currentIndex >= 2 && (!mustBeZero || compressBegin < 12) && + isValidIPv4MappedSeparators(bytes[currentIndex - 1], bytes[currentIndex - 2], mustBeZero) && + PlatformDependent.isZero(bytes, 0, currentIndex - 3); + } + + /** + * Takes a {@link CharSequence} and parses it to see if it is a valid IPV4 address. + * + * @return true, if the string represents an IPV4 address in dotted + * notation, false otherwise + */ + public static boolean isValidIpV4Address(CharSequence ip) { + return isValidIpV4Address(ip, 0, ip.length()); + } + + /** + * Takes a {@link String} and parses it to see if it is a valid IPV4 address. 
+ * + * @return true, if the string represents an IPV4 address in dotted + * notation, false otherwise + */ + public static boolean isValidIpV4Address(String ip) { + return isValidIpV4Address(ip, 0, ip.length()); + } + + private static boolean isValidIpV4Address(CharSequence ip, int from, int toExcluded) { + return ip instanceof String ? isValidIpV4Address((String) ip, from, toExcluded) : + ip instanceof AsciiString ? isValidIpV4Address((AsciiString) ip, from, toExcluded) : + isValidIpV4Address0(ip, from, toExcluded); + } + + @SuppressWarnings("DuplicateBooleanBranch") + private static boolean isValidIpV4Address(String ip, int from, int toExcluded) { + int len = toExcluded - from; + int i; + return len <= 15 && len >= 7 && + (i = ip.indexOf('.', from + 1)) > 0 && isValidIpV4Word(ip, from, i) && + (i = ip.indexOf('.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) && + (i = ip.indexOf('.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) && + isValidIpV4Word(ip, i + 1, toExcluded); + } + + @SuppressWarnings("DuplicateBooleanBranch") + private static boolean isValidIpV4Address(AsciiString ip, int from, int toExcluded) { + int len = toExcluded - from; + int i; + return len <= 15 && len >= 7 && + (i = ip.indexOf('.', from + 1)) > 0 && isValidIpV4Word(ip, from, i) && + (i = ip.indexOf('.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) && + (i = ip.indexOf('.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) && + isValidIpV4Word(ip, i + 1, toExcluded); + } + + @SuppressWarnings("DuplicateBooleanBranch") + private static boolean isValidIpV4Address0(CharSequence ip, int from, int toExcluded) { + int len = toExcluded - from; + int i; + return len <= 15 && len >= 7 && + (i = indexOf(ip, '.', from + 1)) > 0 && isValidIpV4Word(ip, from, i) && + (i = indexOf(ip, '.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) && + (i = indexOf(ip, '.', from = i + 2)) > 0 && isValidIpV4Word(ip, from - 1, i) && + isValidIpV4Word(ip, i + 1, toExcluded); + 
} + + /** + * Returns the {@link Inet6Address} representation of a {@link CharSequence} IP address. + *

+ * This method will treat all IPv4 type addresses as "IPv4 mapped" (see {@link #getByName(CharSequence, boolean)}) + * + * @param ip {@link CharSequence} IP address to be converted to a {@link Inet6Address} + * @return {@link Inet6Address} representation of the {@code ip} or {@code null} if not a valid IP address. + */ + public static Inet6Address getByName(CharSequence ip) { + return getByName(ip, true); + } + + /** + * Returns the {@link Inet6Address} representation of a {@link CharSequence} IP address. + *

+ * The {@code ipv4Mapped} parameter specifies how IPv4 addresses should be treated. + * "IPv4 mapped" format as + * defined in rfc 4291 section 2 is supported. + * + * @param ip {@link CharSequence} IP address to be converted to a {@link Inet6Address} + * @param ipv4Mapped

+ * <ul>
+ * <li>{@code true} To allow IPv4 mapped inputs to be translated into {@link Inet6Address}</li>
+ * <li>{@code false} Consider IPv4 mapped addresses as invalid.</li>
+ * </ul>
+ * @return {@link Inet6Address} representation of the {@code ip} or {@code null} if not a valid IP address. + */ + public static Inet6Address getByName(CharSequence ip, boolean ipv4Mapped) { + byte[] bytes = getIPv6ByName(ip, ipv4Mapped); + if (bytes == null) { + return null; + } + try { + return Inet6Address.getByAddress(null, bytes, -1); + } catch (UnknownHostException e) { + throw new RuntimeException(e); // Should never happen + } + } + + /** + * Returns the byte array representation of a {@link CharSequence} IP address. + *

+ * The {@code ipv4Mapped} parameter specifies how IPv4 addresses should be treated. + * "IPv4 mapped" format as + * defined in rfc 4291 section 2 is supported. + * + * @param ip {@link CharSequence} IP address to be converted to a {@link Inet6Address} + * @param ipv4Mapped

+ * <ul>
+ * <li>{@code true} To allow IPv4 mapped inputs to be translated into {@link Inet6Address}</li>
+ * <li>{@code false} Consider IPv4 mapped addresses as invalid.</li>
+ * </ul>
+ * @return byte array representation of the {@code ip} or {@code null} if not a valid IP address. + */ + // visible for test + static byte[] getIPv6ByName(CharSequence ip, boolean ipv4Mapped) { + final byte[] bytes = new byte[IPV6_BYTE_COUNT]; + final int ipLength = ip.length(); + int compressBegin = 0; + int compressLength = 0; + int currentIndex = 0; + int value = 0; + int begin = -1; + int i = 0; + int ipv6Separators = 0; + int ipv4Separators = 0; + int tmp; + for (; i < ipLength; ++i) { + final char c = ip.charAt(i); + switch (c) { + case ':': + ++ipv6Separators; + if (i - begin > IPV6_MAX_CHAR_BETWEEN_SEPARATOR || + ipv4Separators > 0 || ipv6Separators > IPV6_MAX_SEPARATORS || + currentIndex + 1 >= bytes.length) { + return null; + } + value <<= (IPV6_MAX_CHAR_BETWEEN_SEPARATOR - (i - begin)) << 2; + + if (compressLength > 0) { + compressLength -= 2; + } + + // The value integer holds at most 4 bytes from right (most significant) to left (least significant). + // The following bit shifting is used to extract and re-order the individual bytes to achieve a + // left (most significant) to right (least significant) ordering. + bytes[currentIndex++] = (byte) (((value & 0xf) << 4) | ((value >> 4) & 0xf)); + bytes[currentIndex++] = (byte) ((((value >> 8) & 0xf) << 4) | ((value >> 12) & 0xf)); + tmp = i + 1; + if (tmp < ipLength && ip.charAt(tmp) == ':') { + ++tmp; + if (compressBegin != 0 || (tmp < ipLength && ip.charAt(tmp) == ':')) { + return null; + } + ++ipv6Separators; + compressBegin = currentIndex; + compressLength = bytes.length - compressBegin - 2; + ++i; + } + value = 0; + begin = -1; + break; + case '.': + ++ipv4Separators; + tmp = i - begin; // tmp is the length of the current segment. 
+ if (tmp > IPV4_MAX_CHAR_BETWEEN_SEPARATOR + || begin < 0 + || ipv4Separators > IPV4_SEPARATORS + || (ipv6Separators > 0 && (currentIndex + compressLength < 12)) + || i + 1 >= ipLength + || currentIndex >= bytes.length + || ipv4Separators == 1 && + // We also parse pure IPv4 addresses as IPv4-Mapped for ease of use. + ((!ipv4Mapped || currentIndex != 0 && !isValidIPv4Mapped(bytes, currentIndex, + compressBegin, compressLength)) || + (tmp == 3 && (!isValidNumericChar(ip.charAt(i - 1)) || + !isValidNumericChar(ip.charAt(i - 2)) || + !isValidNumericChar(ip.charAt(i - 3))) || + tmp == 2 && (!isValidNumericChar(ip.charAt(i - 1)) || + !isValidNumericChar(ip.charAt(i - 2))) || + tmp == 1 && !isValidNumericChar(ip.charAt(i - 1))))) { + return null; + } + value <<= (IPV4_MAX_CHAR_BETWEEN_SEPARATOR - tmp) << 2; + + // The value integer holds at most 3 bytes from right (most significant) to left (least significant). + // The following bit shifting is to restructure the bytes to be left (most significant) to + // right (least significant) while also accounting for each IPv4 digit is base 10. + begin = (value & 0xf) * 100 + ((value >> 4) & 0xf) * 10 + ((value >> 8) & 0xf); + if (begin > 255) { + return null; + } + bytes[currentIndex++] = (byte) begin; + value = 0; + begin = -1; + break; + default: + if (!isValidHexChar(c) || (ipv4Separators > 0 && !isValidNumericChar(c))) { + return null; + } + if (begin < 0) { + begin = i; + } else if (i - begin > IPV6_MAX_CHAR_BETWEEN_SEPARATOR) { + return null; + } + // The value is treated as a sort of array of numbers because we are dealing with + // at most 4 consecutive bytes we can use bit shifting to accomplish this. 
+ // The most significant byte will be encountered first, and reside in the right most + // position of the following integer + value += StringUtil.decodeHexNibble(c) << ((i - begin) << 2); + break; + } + } + + final boolean isCompressed = compressBegin > 0; + // Finish up last set of data that was accumulated in the loop (or before the loop) + if (ipv4Separators > 0) { + if (begin > 0 && i - begin > IPV4_MAX_CHAR_BETWEEN_SEPARATOR || + ipv4Separators != IPV4_SEPARATORS || + currentIndex >= bytes.length) { + return null; + } + if (!(ipv6Separators == 0 || ipv6Separators >= IPV6_MIN_SEPARATORS && + (!isCompressed && (ipv6Separators == 6 && ip.charAt(0) != ':') || + isCompressed && (ipv6Separators < IPV6_MAX_SEPARATORS && + (ip.charAt(0) != ':' || compressBegin <= 2))))) { + return null; + } + value <<= (IPV4_MAX_CHAR_BETWEEN_SEPARATOR - (i - begin)) << 2; + + // The value integer holds at most 3 bytes from right (most significant) to left (least significant). + // The following bit shifting is to restructure the bytes to be left (most significant) to + // right (least significant) while also accounting for each IPv4 digit is base 10. 
+ begin = (value & 0xf) * 100 + ((value >> 4) & 0xf) * 10 + ((value >> 8) & 0xf); + if (begin > 255) { + return null; + } + bytes[currentIndex++] = (byte) begin; + } else { + tmp = ipLength - 1; + if (begin > 0 && i - begin > IPV6_MAX_CHAR_BETWEEN_SEPARATOR || + ipv6Separators < IPV6_MIN_SEPARATORS || + !isCompressed && (ipv6Separators + 1 != IPV6_MAX_SEPARATORS || + ip.charAt(0) == ':' || ip.charAt(tmp) == ':') || + isCompressed && (ipv6Separators > IPV6_MAX_SEPARATORS || + (ipv6Separators == IPV6_MAX_SEPARATORS && + (compressBegin <= 2 && ip.charAt(0) != ':' || + compressBegin >= 14 && ip.charAt(tmp) != ':'))) || + currentIndex + 1 >= bytes.length || + begin < 0 && ip.charAt(tmp - 1) != ':' || + compressBegin > 2 && ip.charAt(0) == ':') { + return null; + } + if (begin >= 0 && i - begin <= IPV6_MAX_CHAR_BETWEEN_SEPARATOR) { + value <<= (IPV6_MAX_CHAR_BETWEEN_SEPARATOR - (i - begin)) << 2; + } + // The value integer holds at most 4 bytes from right (most significant) to left (least significant). + // The following bit shifting is used to extract and re-order the individual bytes to achieve a + // left (most significant) to right (least significant) ordering. + bytes[currentIndex++] = (byte) (((value & 0xf) << 4) | ((value >> 4) & 0xf)); + bytes[currentIndex++] = (byte) ((((value >> 8) & 0xf) << 4) | ((value >> 12) & 0xf)); + } + + if (currentIndex < bytes.length) { + int toBeCopiedLength = currentIndex - compressBegin; + int targetIndex = bytes.length - toBeCopiedLength; + System.arraycopy(bytes, compressBegin, bytes, targetIndex, toBeCopiedLength); + // targetIndex is also the `toIndex` to fill 0 + Arrays.fill(bytes, compressBegin, targetIndex, (byte) 0); + } + + if (ipv4Separators > 0) { + // We only support IPv4-Mapped addresses [1] because IPv4-Compatible addresses are deprecated [2]. 
+ // [1] https://tools.ietf.org/html/rfc4291#section-2.5.5.2 + // [2] https://tools.ietf.org/html/rfc4291#section-2.5.5.1 + bytes[10] = bytes[11] = (byte) 0xff; + } + + return bytes; + } + + /** + * Returns the {@link String} representation of an {@link InetSocketAddress}. + *

+     * The output does not include Scope ID.
+     *
+     * @param addr {@link InetSocketAddress} to be converted to an address string
+     * @return {@code String} containing the text-formatted IP address
+     */
+    public static String toSocketAddressString(InetSocketAddress addr) {
+        String port = String.valueOf(addr.getPort());
+        final StringBuilder sb;
+
+        if (addr.isUnresolved()) {
+            // Unresolved: no InetAddress available, format the hostname as given.
+            String hostname = getHostname(addr);
+            sb = newSocketAddressStringBuilder(hostname, port, !isValidIpV6Address(hostname));
+        } else {
+            InetAddress address = addr.getAddress();
+            String hostString = toAddressString(address);
+            sb = newSocketAddressStringBuilder(hostString, port, address instanceof Inet4Address);
+        }
+        return sb.append(':').append(port).toString();
+    }
+
+    /**
+     * Returns the {@link String} representation of a host port combo.
+     */
+    public static String toSocketAddressString(String host, int port) {
+        String portStr = String.valueOf(port);
+        return newSocketAddressStringBuilder(
+                host, portStr, !isValidIpV6Address(host)).append(':').append(portStr).toString();
+    }
+
+    // Pre-sizes the builder and appends the host, bracketing IPv6 hosts if not already bracketed.
+    // The caller appends ':' and the port afterwards.
+    private static StringBuilder newSocketAddressStringBuilder(String host, String port, boolean ipv4) {
+        int hostLen = host.length();
+        if (ipv4) {
+            // Need to include enough space for hostString:port.
+            return new StringBuilder(hostLen + 1 + port.length()).append(host);
+        }
+        // Need to include enough space for [hostString]:port.
+        StringBuilder stringBuilder = new StringBuilder(hostLen + 3 + port.length());
+        if (hostLen > 1 && host.charAt(0) == '[' && host.charAt(hostLen - 1) == ']') {
+            return stringBuilder.append(host);
+        }
+        return stringBuilder.append('[').append(host).append(']');
+    }
+
+    /**
+     * Returns the {@link String} representation of an {@link InetAddress}.
+     * <ul>

+     * <li>Inet4Address results are identical to {@link InetAddress#getHostAddress()}</li>
+     * <li>Inet6Address results adhere to
+     * <a href="https://tools.ietf.org/html/rfc5952#section-4">rfc 5952 section 4</a></li>
+     * </ul>
+     * <p>
+     * The output does not include Scope ID.
+     *
+     * @param ip {@link InetAddress} to be converted to an address string
+     * @return {@code String} containing the text-formatted IP address
+     */
+    public static String toAddressString(InetAddress ip) {
+        return toAddressString(ip, false);
+    }
+
+    /**
+     * Returns the {@link String} representation of an {@link InetAddress}.
+     * <ul>

+     * <li>Inet4Address results are identical to {@link InetAddress#getHostAddress()}</li>
+     * <li>Inet6Address results adhere to
+     * <a href="https://tools.ietf.org/html/rfc5952#section-4">rfc 5952 section 4</a> if
+     * {@code ipv4Mapped} is false. If {@code ipv4Mapped} is true then "IPv4 mapped" format
+     * from <a href="https://tools.ietf.org/html/rfc4291#section-2">rfc 4291 section 2</a> will be supported.
+     * The compressed result will always obey the compression rules defined in
+     * <a href="https://tools.ietf.org/html/rfc5952#section-4">rfc 5952 section 4</a></li>
+     * </ul>
+     * <p>
+     * The output does not include Scope ID.
+     *
+     * @param ip {@link InetAddress} to be converted to an address string
+     * @param ipv4Mapped
+     * <ul>
+     * <li>{@code true} to stray from strict rfc 5952 and support the "IPv4 mapped" format
+     * defined in <a href="https://tools.ietf.org/html/rfc4291#section-2">rfc 4291 section 2</a> while still
+     * following the updated guidelines in
+     * <a href="https://tools.ietf.org/html/rfc5952#section-4">rfc 5952 section 4</a></li>
+     * <li>{@code false} to strictly follow rfc 5952</li>
+     * </ul>
+     * @return {@code String} containing the text-formatted IP address
+     */
+    public static String toAddressString(InetAddress ip, boolean ipv4Mapped) {
+        if (ip instanceof Inet4Address) {
+            return ip.getHostAddress();
+        }
+        if (!(ip instanceof Inet6Address)) {
+            throw new IllegalArgumentException("Unhandled type: " + ip);
+        }
+
+        return toAddressString(ip.getAddress(), 0, ipv4Mapped);
+    }
+
+    // Formats 16 bytes starting at `offset` as an IPv6 string, compressing the longest zero run
+    // per rfc 5952 and optionally emitting the trailing 32 bits in dotted-quad form.
+    private static String toAddressString(byte[] bytes, int offset, boolean ipv4Mapped) {
+        final int[] words = new int[IPV6_WORD_COUNT];
+        for (int i = 0; i < words.length; ++i) {
+            int idx = (i << 1) + offset;
+            words[i] = ((bytes[idx] & 0xff) << 8) | (bytes[idx + 1] & 0xff);
+        }
+
+        // Find longest run of 0s, tie goes to first found instance
+        // NOTE: despite the "shortest" names, these variables track the LONGEST zero run seen so
+        // far (the `>` comparisons below keep the larger length).
+        int currentStart = -1;
+        int currentLength;
+        int shortestStart = -1;
+        int shortestLength = 0;
+        for (int i = 0; i < words.length; ++i) {
+            if (words[i] == 0) {
+                if (currentStart < 0) {
+                    currentStart = i;
+                }
+            } else if (currentStart >= 0) {
+                currentLength = i - currentStart;
+                if (currentLength > shortestLength) {
+                    shortestStart = currentStart;
+                    shortestLength = currentLength;
+                }
+                currentStart = -1;
+            }
+        }
+        // If the array ends on a streak of zeros, make sure we account for it
+        if (currentStart >= 0) {
+            currentLength = words.length - currentStart;
+            if (currentLength > shortestLength) {
+                shortestStart = currentStart;
+                shortestLength = currentLength;
+            }
+        }
+        // Ignore the longest streak if it is only 1 long
+        if (shortestLength == 1) {
+            shortestLength = 0;
+            shortestStart = -1;
+        }
+
+        // Translate to string taking into account longest consecutive 0s
+        final int shortestEnd = shortestStart + shortestLength;
+        final StringBuilder b = new StringBuilder(IPV6_MAX_CHAR_COUNT);
+        if (shortestEnd < 0) { // Optimization when there is no compressing needed
+            b.append(Integer.toHexString(words[0]));
+            for (int i = 1; i < words.length; ++i) {
+                b.append(':');
+                b.append(Integer.toHexString(words[i]));
+            }
+        } else { // General case that can handle compressing (and not compressing)
+            // Loop unroll the first index (so we don't constantly check i==0 cases in loop)
+            final boolean isIpv4Mapped;
+            if (inRangeEndExclusive(0, shortestStart, shortestEnd)) {
+                b.append("::");
+                isIpv4Mapped = ipv4Mapped && (shortestEnd == 5 && words[5] == 0xffff);
+            } else {
+                b.append(Integer.toHexString(words[0]));
+                isIpv4Mapped = false;
+            }
+            for (int i = 1; i < words.length; ++i) {
+                if (!inRangeEndExclusive(i, shortestStart, shortestEnd)) {
+                    if (!inRangeEndExclusive(i - 1, shortestStart, shortestEnd)) {
+                        // If the last index was not part of the shortened sequence
+                        if (!isIpv4Mapped || i == 6) {
+                            b.append(':');
+                        } else {
+                            b.append('.');
+                        }
+                    }
+                    if (isIpv4Mapped && i > 5) {
+                        // Trailing 32 bits of an IPv4-mapped address are printed as dotted decimal.
+                        b.append(words[i] >> 8);
+                        b.append('.');
+                        b.append(words[i] & 0xff);
+                    } else {
+                        b.append(Integer.toHexString(words[i]));
+                    }
+                } else if (!inRangeEndExclusive(i - 1, shortestStart, shortestEnd)) {
+                    // If we are in the shortened sequence and the last index was not
+                    b.append("::");
+                }
+            }
+        }
+
+        return b.toString();
+    }
+
+    /**
+     * Returns {@link InetSocketAddress#getHostString()} if Java >= 7,
+     * or {@link InetSocketAddress#getHostName()} otherwise.
+     *
+     * @param addr The address
+     * @return the host string
+     */
+    public static String getHostname(InetSocketAddress addr) {
+        return PlatformDependent.javaVersion() >= 7 ? addr.getHostString() : addr.getHostName();
+    }
+
+    /**
+     * Does a range check on {@code value} to see if it is within {@code start} (inclusive) and
+     * {@code end} (exclusive).
+     *
+     * @param value The value to be checked if it is within {@code start} (inclusive) and {@code end} (exclusive)
+     * @param start The start of the range (inclusive)
+     * @param end The end of the range (exclusive)
+     * @return
+     * <ul>
+     * <li>{@code true} if {@code value} is within {@code start} (inclusive) and {@code end} (exclusive)</li>
+     * <li>{@code false} otherwise</li>
+     * </ul>
+ */ + private static boolean inRangeEndExclusive(int value, int start, int end) { + return value >= start && value < end; + } + + /** + * A constructor to stop this class being constructed. + */ + private NetUtil() { + // Unused + } +} diff --git a/netty-util/src/main/java/io/netty/util/NetUtilInitializations.java b/netty-util/src/main/java/io/netty/util/NetUtilInitializations.java new file mode 100644 index 0000000..9368543 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/NetUtilInitializations.java @@ -0,0 +1,188 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SocketUtils; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.net.SocketException; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; +import java.util.Enumeration; +import java.util.List; + +final class NetUtilInitializations { + /** + * The logger being used by this class + */ + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NetUtilInitializations.class); + + private NetUtilInitializations() { + } + + static Inet4Address createLocalhost4() { + byte[] LOCALHOST4_BYTES = {127, 0, 0, 1}; + + Inet4Address localhost4 = null; + try { + localhost4 = (Inet4Address) InetAddress.getByAddress("localhost", LOCALHOST4_BYTES); + } catch (Exception e) { + // We should not get here as long as the length of the address is correct. + PlatformDependent.throwException(e); + } + + return localhost4; + } + + static Inet6Address createLocalhost6() { + byte[] LOCALHOST6_BYTES = {0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1}; + + Inet6Address localhost6 = null; + try { + localhost6 = (Inet6Address) InetAddress.getByAddress("localhost", LOCALHOST6_BYTES); + } catch (Exception e) { + // We should not get here as long as the length of the address is correct. 
+ PlatformDependent.throwException(e); + } + + return localhost6; + } + + static Collection networkInterfaces() { + List networkInterfaces = new ArrayList(); + try { + Enumeration interfaces = NetworkInterface.getNetworkInterfaces(); + if (interfaces != null) { + while (interfaces.hasMoreElements()) { + networkInterfaces.add(interfaces.nextElement()); + } + } + } catch (SocketException e) { + logger.warn("Failed to retrieve the list of available network interfaces", e); + } catch (NullPointerException e) { + if (!PlatformDependent.isAndroid()) { + throw e; + } + // Might happen on earlier version of Android. + // See https://developer.android.com/reference/java/net/NetworkInterface#getNetworkInterfaces() + } + return Collections.unmodifiableList(networkInterfaces); + } + + static NetworkIfaceAndInetAddress determineLoopback( + Collection networkInterfaces, Inet4Address localhost4, Inet6Address localhost6) { + // Retrieve the list of available network interfaces. + List ifaces = new ArrayList(); + for (NetworkInterface iface : networkInterfaces) { + // Use the interface with proper INET addresses only. + if (SocketUtils.addressesFromNetworkInterface(iface).hasMoreElements()) { + ifaces.add(iface); + } + } + + // Find the first loopback interface available from its INET address (127.0.0.1 or ::1) + // Note that we do not use NetworkInterface.isLoopback() in the first place because it takes long time + // on a certain environment. (e.g. Windows with -Djava.net.preferIPv4Stack=true) + NetworkInterface loopbackIface = null; + InetAddress loopbackAddr = null; + loop: + for (NetworkInterface iface : ifaces) { + for (Enumeration i = SocketUtils.addressesFromNetworkInterface(iface); i.hasMoreElements(); ) { + InetAddress addr = i.nextElement(); + if (addr.isLoopbackAddress()) { + // Found + loopbackIface = iface; + loopbackAddr = addr; + break loop; + } + } + } + + // If failed to find the loopback interface from its INET address, fall back to isLoopback(). 
+ if (loopbackIface == null) { + try { + for (NetworkInterface iface : ifaces) { + if (iface.isLoopback()) { + Enumeration i = SocketUtils.addressesFromNetworkInterface(iface); + if (i.hasMoreElements()) { + // Found the one with INET address. + loopbackIface = iface; + loopbackAddr = i.nextElement(); + break; + } + } + } + + if (loopbackIface == null) { + logger.warn("Failed to find the loopback interface"); + } + } catch (SocketException e) { + logger.warn("Failed to find the loopback interface", e); + } + } + + if (loopbackIface != null) { + // Found the loopback interface with an INET address. + logger.debug( + "Loopback interface: {} ({}, {})", + loopbackIface.getName(), loopbackIface.getDisplayName(), loopbackAddr.getHostAddress()); + } else { + // Could not find the loopback interface, but we can't leave LOCALHOST as null. + // Use LOCALHOST6 or LOCALHOST4, preferably the IPv6 one. + if (loopbackAddr == null) { + try { + if (NetworkInterface.getByInetAddress(localhost6) != null) { + logger.debug("Using hard-coded IPv6 localhost address: {}", localhost6); + loopbackAddr = localhost6; + } + } catch (Exception e) { + // Ignore + } finally { + if (loopbackAddr == null) { + logger.debug("Using hard-coded IPv4 localhost address: {}", localhost4); + loopbackAddr = localhost4; + } + } + } + } + + return new NetworkIfaceAndInetAddress(loopbackIface, loopbackAddr); + } + + static final class NetworkIfaceAndInetAddress { + private final NetworkInterface iface; + private final InetAddress address; + + NetworkIfaceAndInetAddress(NetworkInterface iface, InetAddress address) { + this.iface = iface; + this.address = address; + } + + public NetworkInterface iface() { + return iface; + } + + public InetAddress address() { + return address; + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/NettyRuntime.java b/netty-util/src/main/java/io/netty/util/NettyRuntime.java new file mode 100644 index 0000000..3991115 --- /dev/null +++ 
b/netty-util/src/main/java/io/netty/util/NettyRuntime.java
@@ -0,0 +1,105 @@
+/*
+ * Copyright 2017 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+
+package io.netty.util;
+
+import io.netty.util.internal.ObjectUtil;
+import io.netty.util.internal.SystemPropertyUtil;
+import java.util.Locale;
+
+/**
+ * A utility class for wrapping calls to {@link Runtime}.
+ */
+public final class NettyRuntime {
+
+    /**
+     * Holder class for available processors to enable testing.
+     */
+    static class AvailableProcessorsHolder {
+
+        // 0 means "not yet configured"; lazily initialized on first read.
+        private int availableProcessors;
+
+        /**
+         * Set the number of available processors.
+         *
+         * @param availableProcessors the number of available processors
+         * @throws IllegalArgumentException if the specified number of available processors is non-positive
+         * @throws IllegalStateException if the number of available processors is already configured
+         */
+        synchronized void setAvailableProcessors(final int availableProcessors) {
+            ObjectUtil.checkPositive(availableProcessors, "availableProcessors");
+            if (this.availableProcessors != 0) {
+                final String message = String.format(
+                        Locale.ROOT,
+                        "availableProcessors is already set to [%d], rejecting [%d]",
+                        this.availableProcessors,
+                        availableProcessors);
+                throw new IllegalStateException(message);
+            }
+            this.availableProcessors = availableProcessors;
+        }
+
+        /**
+         * Get the configured number of available processors. The default is {@link Runtime#availableProcessors()}.
+         * This can be overridden by setting the system property "io.netty.availableProcessors" or by invoking
+         * {@link #setAvailableProcessors(int)} before any calls to this method.
+         *
+         * @return the configured number of available processors
+         */
+        @SuppressForbidden(reason = "to obtain default number of available processors")
+        synchronized int availableProcessors() {
+            if (this.availableProcessors == 0) {
+                final int availableProcessors =
+                        SystemPropertyUtil.getInt(
+                                "io.netty.availableProcessors",
+                                Runtime.getRuntime().availableProcessors());
+                setAvailableProcessors(availableProcessors);
+            }
+            return this.availableProcessors;
+        }
+    }
+
+    private static final AvailableProcessorsHolder holder = new AvailableProcessorsHolder();
+
+    /**
+     * Set the number of available processors.
+     *
+     * @param availableProcessors the number of available processors
+     * @throws IllegalArgumentException if the specified number of available processors is non-positive
+     * @throws IllegalStateException if the number of available processors is already configured
+     */
+    @SuppressWarnings("unused,WeakerAccess") // this method is part of the public API
+    public static void setAvailableProcessors(final int availableProcessors) {
+        holder.setAvailableProcessors(availableProcessors);
+    }
+
+    /**
+     * Get the configured number of available processors. The default is {@link Runtime#availableProcessors()}. This
+     * can be overridden by setting the system property "io.netty.availableProcessors" or by invoking
+     * {@link #setAvailableProcessors(int)} before any calls to this method.
+     *
+     * @return the configured number of available processors
+     */
+    public static int availableProcessors() {
+        return holder.availableProcessors();
+    }
+
+    /**
+     * No public constructor to prevent instances from being created.
+     */
+    private NettyRuntime() {
+    }
+}
diff --git a/netty-util/src/main/java/io/netty/util/Recycler.java b/netty-util/src/main/java/io/netty/util/Recycler.java
new file mode 100644
index 0000000..f5f8d2e
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/Recycler.java
@@ -0,0 +1,485 @@
+/*
+ * Copyright 2013 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ */
+package io.netty.util;
+
+import io.netty.util.concurrent.FastThreadLocal;
+import io.netty.util.concurrent.FastThreadLocalThread;
+import io.netty.util.internal.ObjectPool;
+import io.netty.util.internal.PlatformDependent;
+import io.netty.util.internal.SystemPropertyUtil;
+import io.netty.util.internal.UnstableApi;
+import io.netty.util.internal.logging.InternalLogger;
+import io.netty.util.internal.logging.InternalLoggerFactory;
+import java.util.ArrayDeque;
+import java.util.Queue;
+import java.util.concurrent.atomic.AtomicIntegerFieldUpdater;
+import org.jctools.queues.MessagePassingQueue;
+import static io.netty.util.internal.PlatformDependent.newMpscQueue;
+import static java.lang.Math.max;
+import static java.lang.Math.min;
+
+/**
+ * Light-weight object pool based on a thread-local stack.
+ * + * @param the type of the pooled object + */ +public abstract class Recycler { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(Recycler.class); + private static final EnhancedHandle NOOP_HANDLE = new EnhancedHandle() { + @Override + public void recycle(Object object) { + // NOOP + } + + @Override + public void unguardedRecycle(final Object object) { + // NOOP + } + + @Override + public String toString() { + return "NOOP_HANDLE"; + } + }; + private static final int DEFAULT_INITIAL_MAX_CAPACITY_PER_THREAD = 4 * 1024; // Use 4k instances as default. + private static final int DEFAULT_MAX_CAPACITY_PER_THREAD; + private static final int RATIO; + private static final int DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD; + private static final boolean BLOCKING_POOL; + private static final boolean BATCH_FAST_TL_ONLY; + + static { + // In the future, we might have different maxCapacity for different object types. + // e.g. io.netty.recycler.maxCapacity.writeTask + // io.netty.recycler.maxCapacity.outboundBuffer + int maxCapacityPerThread = SystemPropertyUtil.getInt("io.netty.recycler.maxCapacityPerThread", + SystemPropertyUtil.getInt("io.netty.recycler.maxCapacity", DEFAULT_INITIAL_MAX_CAPACITY_PER_THREAD)); + if (maxCapacityPerThread < 0) { + maxCapacityPerThread = DEFAULT_INITIAL_MAX_CAPACITY_PER_THREAD; + } + + DEFAULT_MAX_CAPACITY_PER_THREAD = maxCapacityPerThread; + DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD = SystemPropertyUtil.getInt("io.netty.recycler.chunkSize", 32); + + // By default, we allow one push to a Recycler for each 8th try on handles that were never recycled before. + // This should help to slowly increase the capacity of the recycler while not be too sensitive to allocation + // bursts. 
+ RATIO = max(0, SystemPropertyUtil.getInt("io.netty.recycler.ratio", 8)); + + BLOCKING_POOL = SystemPropertyUtil.getBoolean("io.netty.recycler.blocking", false); + BATCH_FAST_TL_ONLY = SystemPropertyUtil.getBoolean("io.netty.recycler.batchFastThreadLocalOnly", true); + + if (logger.isDebugEnabled()) { + if (DEFAULT_MAX_CAPACITY_PER_THREAD == 0) { + logger.debug("-Dio.netty.recycler.maxCapacityPerThread: disabled"); + logger.debug("-Dio.netty.recycler.ratio: disabled"); + logger.debug("-Dio.netty.recycler.chunkSize: disabled"); + logger.debug("-Dio.netty.recycler.blocking: disabled"); + logger.debug("-Dio.netty.recycler.batchFastThreadLocalOnly: disabled"); + } else { + logger.debug("-Dio.netty.recycler.maxCapacityPerThread: {}", DEFAULT_MAX_CAPACITY_PER_THREAD); + logger.debug("-Dio.netty.recycler.ratio: {}", RATIO); + logger.debug("-Dio.netty.recycler.chunkSize: {}", DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD); + logger.debug("-Dio.netty.recycler.blocking: {}", BLOCKING_POOL); + logger.debug("-Dio.netty.recycler.batchFastThreadLocalOnly: {}", BATCH_FAST_TL_ONLY); + } + } + } + + private final int maxCapacityPerThread; + private final int interval; + private final int chunkSize; + private final FastThreadLocal> threadLocal = new FastThreadLocal>() { + @Override + protected LocalPool initialValue() { + return new LocalPool(maxCapacityPerThread, interval, chunkSize); + } + + @Override + protected void onRemoval(LocalPool value) throws Exception { + super.onRemoval(value); + MessagePassingQueue> handles = value.pooledHandles; + value.pooledHandles = null; + value.owner = null; + handles.clear(); + } + }; + + protected Recycler() { + this(DEFAULT_MAX_CAPACITY_PER_THREAD); + } + + protected Recycler(int maxCapacityPerThread) { + this(maxCapacityPerThread, RATIO, DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD); + } + + /** + * @deprecated Use one of the following instead: + * {@link #Recycler()}, {@link #Recycler(int)}, {@link #Recycler(int, int, int)}. 
+ */ + @Deprecated + @SuppressWarnings("unused") // Parameters we can't remove due to compatibility. + protected Recycler(int maxCapacityPerThread, int maxSharedCapacityFactor) { + this(maxCapacityPerThread, RATIO, DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD); + } + + /** + * @deprecated Use one of the following instead: + * {@link #Recycler()}, {@link #Recycler(int)}, {@link #Recycler(int, int, int)}. + */ + @Deprecated + @SuppressWarnings("unused") // Parameters we can't remove due to compatibility. + protected Recycler(int maxCapacityPerThread, int maxSharedCapacityFactor, + int ratio, int maxDelayedQueuesPerThread) { + this(maxCapacityPerThread, ratio, DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD); + } + + /** + * @deprecated Use one of the following instead: + * {@link #Recycler()}, {@link #Recycler(int)}, {@link #Recycler(int, int, int)}. + */ + @Deprecated + @SuppressWarnings("unused") // Parameters we can't remove due to compatibility. + protected Recycler(int maxCapacityPerThread, int maxSharedCapacityFactor, + int ratio, int maxDelayedQueuesPerThread, int delayedQueueRatio) { + this(maxCapacityPerThread, ratio, DEFAULT_QUEUE_CHUNK_SIZE_PER_THREAD); + } + + protected Recycler(int maxCapacityPerThread, int ratio, int chunkSize) { + interval = max(0, ratio); + if (maxCapacityPerThread <= 0) { + this.maxCapacityPerThread = 0; + this.chunkSize = 0; + } else { + this.maxCapacityPerThread = max(4, maxCapacityPerThread); + this.chunkSize = max(2, min(chunkSize, this.maxCapacityPerThread >> 1)); + } + } + + @SuppressWarnings("unchecked") + public final T get() { + if (maxCapacityPerThread == 0) { + return newObject((Handle) NOOP_HANDLE); + } + LocalPool localPool = threadLocal.get(); + DefaultHandle handle = localPool.claim(); + T obj; + if (handle == null) { + handle = localPool.newHandle(); + if (handle != null) { + obj = newObject(handle); + handle.set(obj); + } else { + obj = newObject((Handle) NOOP_HANDLE); + } + } else { + obj = handle.get(); + } + + return obj; + } + + /** 
+ * @deprecated use {@link Handle#recycle(Object)}. + */ + @Deprecated + public final boolean recycle(T o, Handle handle) { + if (handle == NOOP_HANDLE) { + return false; + } + + handle.recycle(o); + return true; + } + + final int threadLocalSize() { + LocalPool localPool = threadLocal.getIfExists(); + return localPool == null ? 0 : localPool.pooledHandles.size() + localPool.batch.size(); + } + + /** + * @param handle can NOT be null. + */ + protected abstract T newObject(Handle handle); + + @SuppressWarnings("ClassNameSameAsAncestorName") // Can't change this due to compatibility. + public interface Handle extends ObjectPool.Handle { + } + + @UnstableApi + public abstract static class EnhancedHandle implements Handle { + + public abstract void unguardedRecycle(Object object); + + private EnhancedHandle() { + } + } + + private static final class DefaultHandle extends EnhancedHandle { + private static final int STATE_CLAIMED = 0; + private static final int STATE_AVAILABLE = 1; + private static final AtomicIntegerFieldUpdater> STATE_UPDATER; + + static { + AtomicIntegerFieldUpdater updater = AtomicIntegerFieldUpdater.newUpdater(DefaultHandle.class, "state"); + //noinspection unchecked + STATE_UPDATER = (AtomicIntegerFieldUpdater>) updater; + } + + private volatile int state; // State is initialised to STATE_CLAIMED (aka. 0) so they can be released. 
+ private final LocalPool localPool; + private T value; + + DefaultHandle(LocalPool localPool) { + this.localPool = localPool; + } + + @Override + public void recycle(Object object) { + if (object != value) { + throw new IllegalArgumentException("object does not belong to handle"); + } + localPool.release(this, true); + } + + @Override + public void unguardedRecycle(Object object) { + if (object != value) { + throw new IllegalArgumentException("object does not belong to handle"); + } + localPool.release(this, false); + } + + T get() { + return value; + } + + void set(T value) { + this.value = value; + } + + void toClaimed() { + assert state == STATE_AVAILABLE; + STATE_UPDATER.lazySet(this, STATE_CLAIMED); + } + + void toAvailable() { + int prev = STATE_UPDATER.getAndSet(this, STATE_AVAILABLE); + if (prev == STATE_AVAILABLE) { + throw new IllegalStateException("Object has been recycled already."); + } + } + + void unguardedToAvailable() { + int prev = state; + if (prev == STATE_AVAILABLE) { + throw new IllegalStateException("Object has been recycled already."); + } + STATE_UPDATER.lazySet(this, STATE_AVAILABLE); + } + } + + private static final class LocalPool implements MessagePassingQueue.Consumer> { + private final int ratioInterval; + private final int chunkSize; + private final ArrayDeque> batch; + private volatile Thread owner; + private volatile MessagePassingQueue> pooledHandles; + private int ratioCounter; + + @SuppressWarnings("unchecked") + LocalPool(int maxCapacity, int ratioInterval, int chunkSize) { + this.ratioInterval = ratioInterval; + this.chunkSize = chunkSize; + batch = new ArrayDeque>(chunkSize); + Thread currentThread = Thread.currentThread(); + owner = !BATCH_FAST_TL_ONLY || currentThread instanceof FastThreadLocalThread ? 
currentThread : null; + if (BLOCKING_POOL) { + pooledHandles = new BlockingMessageQueue>(maxCapacity); + } else { + pooledHandles = (MessagePassingQueue>) newMpscQueue(chunkSize, maxCapacity); + } + ratioCounter = ratioInterval; // Start at interval so the first one will be recycled. + } + + DefaultHandle claim() { + MessagePassingQueue> handles = pooledHandles; + if (handles == null) { + return null; + } + if (batch.isEmpty()) { + handles.drain(this, chunkSize); + } + DefaultHandle handle = batch.pollFirst(); + if (null != handle) { + handle.toClaimed(); + } + return handle; + } + + void release(DefaultHandle handle, boolean guarded) { + if (guarded) { + handle.toAvailable(); + } else { + handle.unguardedToAvailable(); + } + Thread owner = this.owner; + if (owner != null && Thread.currentThread() == owner && batch.size() < chunkSize) { + accept(handle); + } else if (owner != null && isTerminated(owner)) { + this.owner = null; + pooledHandles = null; + } else { + MessagePassingQueue> handles = pooledHandles; + if (handles != null) { + handles.relaxedOffer(handle); + } + } + } + + private static boolean isTerminated(Thread owner) { + // Do not use `Thread.getState()` in J9 JVM because it's known to have a performance issue. + // See: https://github.com/netty/netty/issues/13347#issuecomment-1518537895 + return PlatformDependent.isJ9Jvm() ? !owner.isAlive() : owner.getState() == Thread.State.TERMINATED; + } + + DefaultHandle newHandle() { + if (++ratioCounter >= ratioInterval) { + ratioCounter = 0; + return new DefaultHandle(this); + } + return null; + } + + @Override + public void accept(DefaultHandle e) { + batch.addLast(e); + } + } + + /** + * This is an implementation of {@link MessagePassingQueue}, similar to what might be returned from + * {@link PlatformDependent#newMpscQueue(int)}, but intended to be used for debugging purpose. + * The implementation relies on synchronised monitor locks for thread-safety. 
+ * The {@code fill} bulk operation is not supported by this implementation. + */ + private static final class BlockingMessageQueue implements MessagePassingQueue { + private final Queue deque; + private final int maxCapacity; + + BlockingMessageQueue(int maxCapacity) { + this.maxCapacity = maxCapacity; + // This message passing queue is backed by an ArrayDeque instance, + // made thread-safe by synchronising on `this` BlockingMessageQueue instance. + // Why ArrayDeque? + // We use ArrayDeque instead of LinkedList or LinkedBlockingQueue because it's more space efficient. + // We use ArrayDeque instead of ArrayList because we need the queue APIs. + // We use ArrayDeque instead of ConcurrentLinkedQueue because CLQ is unbounded and has O(n) size(). + // We use ArrayDeque instead of ArrayBlockingQueue because ABQ allocates its max capacity up-front, + // and these queues will usually have large capacities, in potentially great numbers (one per thread), + // but often only have comparatively few items in them. 
+ deque = new ArrayDeque(); + } + + @Override + public synchronized boolean offer(T e) { + if (deque.size() == maxCapacity) { + return false; + } + return deque.offer(e); + } + + @Override + public synchronized T poll() { + return deque.poll(); + } + + @Override + public synchronized T peek() { + return deque.peek(); + } + + @Override + public synchronized int size() { + return deque.size(); + } + + @Override + public synchronized void clear() { + deque.clear(); + } + + @Override + public synchronized boolean isEmpty() { + return deque.isEmpty(); + } + + @Override + public int capacity() { + return maxCapacity; + } + + @Override + public boolean relaxedOffer(T e) { + return offer(e); + } + + @Override + public T relaxedPoll() { + return poll(); + } + + @Override + public T relaxedPeek() { + return peek(); + } + + @Override + public int drain(Consumer c, int limit) { + T obj; + int i = 0; + for (; i < limit && (obj = poll()) != null; i++) { + c.accept(obj); + } + return i; + } + + @Override + public int fill(Supplier s, int limit) { + throw new UnsupportedOperationException(); + } + + @Override + public int drain(Consumer c) { + throw new UnsupportedOperationException(); + } + + @Override + public int fill(Supplier s) { + throw new UnsupportedOperationException(); + } + + @Override + public void drain(Consumer c, WaitStrategy wait, ExitCondition exit) { + throw new UnsupportedOperationException(); + } + + @Override + public void fill(Supplier s, WaitStrategy wait, ExitCondition exit) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/ReferenceCountUtil.java b/netty-util/src/main/java/io/netty/util/ReferenceCountUtil.java new file mode 100644 index 0000000..bc31834 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ReferenceCountUtil.java @@ -0,0 +1,210 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the 
"License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * Collection of method to handle objects that may implement {@link ReferenceCounted}. + */ +public final class ReferenceCountUtil { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ReferenceCountUtil.class); + + static { + ResourceLeakDetector.addExclusions(ReferenceCountUtil.class, "touch"); + } + + /** + * Try to call {@link ReferenceCounted#retain()} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. + */ + @SuppressWarnings("unchecked") + public static T retain(T msg) { + if (msg instanceof ReferenceCounted) { + return (T) ((ReferenceCounted) msg).retain(); + } + return msg; + } + + /** + * Try to call {@link ReferenceCounted#retain(int)} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. 
+ */ + @SuppressWarnings("unchecked") + public static T retain(T msg, int increment) { + ObjectUtil.checkPositive(increment, "increment"); + if (msg instanceof ReferenceCounted) { + return (T) ((ReferenceCounted) msg).retain(increment); + } + return msg; + } + + /** + * Tries to call {@link ReferenceCounted#touch()} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. + */ + @SuppressWarnings("unchecked") + public static T touch(T msg) { + if (msg instanceof ReferenceCounted) { + return (T) ((ReferenceCounted) msg).touch(); + } + return msg; + } + + /** + * Tries to call {@link ReferenceCounted#touch(Object)} if the specified message implements + * {@link ReferenceCounted}. If the specified message doesn't implement {@link ReferenceCounted}, + * this method does nothing. + */ + @SuppressWarnings("unchecked") + public static T touch(T msg, Object hint) { + if (msg instanceof ReferenceCounted) { + return (T) ((ReferenceCounted) msg).touch(hint); + } + return msg; + } + + /** + * Try to call {@link ReferenceCounted#release()} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. + */ + public static boolean release(Object msg) { + if (msg instanceof ReferenceCounted) { + return ((ReferenceCounted) msg).release(); + } + return false; + } + + /** + * Try to call {@link ReferenceCounted#release(int)} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. 
+ */ + public static boolean release(Object msg, int decrement) { + ObjectUtil.checkPositive(decrement, "decrement"); + if (msg instanceof ReferenceCounted) { + return ((ReferenceCounted) msg).release(decrement); + } + return false; + } + + /** + * Try to call {@link ReferenceCounted#release()} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. + * Unlike {@link #release(Object)} this method catches an exception raised by {@link ReferenceCounted#release()} + * and logs it, rather than rethrowing it to the caller. It is usually recommended to use {@link #release(Object)} + * instead, unless you absolutely need to swallow an exception. + */ + public static void safeRelease(Object msg) { + try { + release(msg); + } catch (Throwable t) { + logger.warn("Failed to release a message: {}", msg, t); + } + } + + /** + * Try to call {@link ReferenceCounted#release(int)} if the specified message implements {@link ReferenceCounted}. + * If the specified message doesn't implement {@link ReferenceCounted}, this method does nothing. + * Unlike {@link #release(Object)} this method catches an exception raised by {@link ReferenceCounted#release(int)} + * and logs it, rather than rethrowing it to the caller. It is usually recommended to use + * {@link #release(Object, int)} instead, unless you absolutely need to swallow an exception. + */ + public static void safeRelease(Object msg, int decrement) { + try { + ObjectUtil.checkPositive(decrement, "decrement"); + release(msg, decrement); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("Failed to release a message: {} (decrement: {})", msg, decrement, t); + } + } + } + + /** + * Schedules the specified object to be released when the caller thread terminates. Note that this operation is + * intended to simplify reference counting of ephemeral objects during unit tests. 
Do not use it beyond the + * intended use case. + * + * @deprecated this may introduce a lot of memory usage so it is generally preferable to manually release objects. + */ + @Deprecated + public static T releaseLater(T msg) { + return releaseLater(msg, 1); + } + + /** + * Schedules the specified object to be released when the caller thread terminates. Note that this operation is + * intended to simplify reference counting of ephemeral objects during unit tests. Do not use it beyond the + * intended use case. + * + * @deprecated this may introduce a lot of memory usage so it is generally preferable to manually release objects. + */ + @Deprecated + public static T releaseLater(T msg, int decrement) { + ObjectUtil.checkPositive(decrement, "decrement"); + if (msg instanceof ReferenceCounted) { + ThreadDeathWatcher.watch(Thread.currentThread(), new ReleasingTask((ReferenceCounted) msg, decrement)); + } + return msg; + } + + /** + * Returns reference count of a {@link ReferenceCounted} object. If object is not type of + * {@link ReferenceCounted}, {@code -1} is returned. + */ + public static int refCnt(Object msg) { + return msg instanceof ReferenceCounted ? ((ReferenceCounted) msg).refCnt() : -1; + } + + /** + * Releases the objects when the thread that called {@link #releaseLater(Object)} has been terminated. 
+ */ + private static final class ReleasingTask implements Runnable { + + private final ReferenceCounted obj; + private final int decrement; + + ReleasingTask(ReferenceCounted obj, int decrement) { + this.obj = obj; + this.decrement = decrement; + } + + @Override + public void run() { + try { + if (!obj.release(decrement)) { + logger.warn("Non-zero refCnt: {}", this); + } else { + logger.debug("Released: {}", this); + } + } catch (Exception ex) { + logger.warn("Failed to release an object: {}", obj, ex); + } + } + + @Override + public String toString() { + return StringUtil.simpleClassName(obj) + ".release(" + decrement + ") refCnt: " + obj.refCnt(); + } + } + + private ReferenceCountUtil() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/ReferenceCounted.java b/netty-util/src/main/java/io/netty/util/ReferenceCounted.java new file mode 100644 index 0000000..c510bef --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ReferenceCounted.java @@ -0,0 +1,77 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +/** + * A reference-counted object that requires explicit deallocation. + *
<p>
+ * When a new {@link ReferenceCounted} is instantiated, it starts with the reference count of {@code 1}. + * {@link #retain()} increases the reference count, and {@link #release()} decreases the reference count. + * If the reference count is decreased to {@code 0}, the object will be deallocated explicitly, and accessing + * the deallocated object will usually result in an access violation. + *
</p>
+ *
<p>
+ * If an object that implements {@link ReferenceCounted} is a container of other objects that implement + * {@link ReferenceCounted}, the contained objects will also be released via {@link #release()} when the container's + * reference count becomes 0. + *
</p>
+ */ +public interface ReferenceCounted { + /** + * Returns the reference count of this object. If {@code 0}, it means this object has been deallocated. + */ + int refCnt(); + + /** + * Increases the reference count by {@code 1}. + */ + ReferenceCounted retain(); + + /** + * Increases the reference count by the specified {@code increment}. + */ + ReferenceCounted retain(int increment); + + /** + * Records the current access location of this object for debugging purposes. + * If this object is determined to be leaked, the information recorded by this operation will be provided to you + * via {@link ResourceLeakDetector}. This method is a shortcut to {@link #touch(Object) touch(null)}. + */ + ReferenceCounted touch(); + + /** + * Records the current access location of this object with an additional arbitrary information for debugging + * purposes. If this object is determined to be leaked, the information recorded by this operation will be + * provided to you via {@link ResourceLeakDetector}. + */ + ReferenceCounted touch(Object hint); + + /** + * Decreases the reference count by {@code 1} and deallocates this object if the reference count reaches at + * {@code 0}. + * + * @return {@code true} if and only if the reference count became {@code 0} and this object has been deallocated + */ + boolean release(); + + /** + * Decreases the reference count by the specified {@code decrement} and deallocates this object if the reference + * count reaches at {@code 0}. 
+ * + * @return {@code true} if and only if the reference count became {@code 0} and this object has been deallocated + */ + boolean release(int decrement); +} diff --git a/netty-util/src/main/java/io/netty/util/ResourceLeak.java b/netty-util/src/main/java/io/netty/util/ResourceLeak.java new file mode 100644 index 0000000..3443f14 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ResourceLeak.java @@ -0,0 +1,42 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +/** + * @deprecated please use {@link ResourceLeakTracker} as it may lead to false-positives. + */ +@Deprecated +public interface ResourceLeak { + /** + * Records the caller's current stack trace so that the {@link ResourceLeakDetector} can tell where the leaked + * resource was accessed lastly. This method is a shortcut to {@link #record(Object) record(null)}. + */ + void record(); + + /** + * Records the caller's current stack trace and the specified additional arbitrary information + * so that the {@link ResourceLeakDetector} can tell where the leaked resource was accessed lastly. + */ + void record(Object hint); + + /** + * Close the leak so that {@link ResourceLeakDetector} does not warn about leaked resources. 
+ * + * @return {@code true} if called first time, {@code false} if called already + */ + boolean close(); +} diff --git a/netty-util/src/main/java/io/netty/util/ResourceLeakDetector.java b/netty-util/src/main/java/io/netty/util/ResourceLeakDetector.java new file mode 100644 index 0000000..e0c425e --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ResourceLeakDetector.java @@ -0,0 +1,712 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util; + +import io.netty.util.internal.EmptyArrays; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import static io.netty.util.internal.StringUtil.EMPTY_STRING; +import static io.netty.util.internal.StringUtil.NEWLINE; +import static io.netty.util.internal.StringUtil.simpleClassName; + +public class ResourceLeakDetector { + + private static final String PROP_LEVEL_OLD = "io.netty.leakDetectionLevel"; + private static final String PROP_LEVEL = "io.netty.leakDetection.level"; + private static final Level DEFAULT_LEVEL = Level.SIMPLE; + + private static final String PROP_TARGET_RECORDS = "io.netty.leakDetection.targetRecords"; + private static final int DEFAULT_TARGET_RECORDS = 4; + + private static final String PROP_SAMPLING_INTERVAL = "io.netty.leakDetection.samplingInterval"; + // There is a minor performance benefit in TLR if this is a power of 2. + private static final int DEFAULT_SAMPLING_INTERVAL = 128; + + private static final int TARGET_RECORDS; + static final int SAMPLING_INTERVAL; + + /** + * Represents the level of resource leak detection. + */ + public enum Level { + /** + * Disables resource leak detection. + */ + DISABLED, + /** + * Enables simplistic sampling resource leak detection which reports there is a leak or not, + * at the cost of small overhead (default). 
+ */ + SIMPLE, + /** + * Enables advanced sampling resource leak detection which reports where the leaked object was accessed + * recently at the cost of high overhead. + */ + ADVANCED, + /** + * Enables paranoid resource leak detection which reports where the leaked object was accessed recently, + * at the cost of the highest possible overhead (for testing purposes only). + */ + PARANOID; + + /** + * Returns level based on string value. Accepts also string that represents ordinal number of enum. + * + * @param levelStr - level string : DISABLED, SIMPLE, ADVANCED, PARANOID. Ignores case. + * @return corresponding level or SIMPLE level in case of no match. + */ + static Level parseLevel(String levelStr) { + String trimmedLevelStr = levelStr.trim(); + for (Level l : values()) { + if (trimmedLevelStr.equalsIgnoreCase(l.name()) || trimmedLevelStr.equals(String.valueOf(l.ordinal()))) { + return l; + } + } + return DEFAULT_LEVEL; + } + } + + private static Level level; + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ResourceLeakDetector.class); + + static { + final boolean disabled; + if (SystemPropertyUtil.get("io.netty.noResourceLeakDetection") != null) { + disabled = SystemPropertyUtil.getBoolean("io.netty.noResourceLeakDetection", false); + logger.debug("-Dio.netty.noResourceLeakDetection: {}", disabled); + logger.warn( + "-Dio.netty.noResourceLeakDetection is deprecated. Use '-D{}={}' instead.", + PROP_LEVEL, Level.DISABLED.name().toLowerCase()); + } else { + disabled = false; + } + + Level defaultLevel = disabled ? 
Level.DISABLED : DEFAULT_LEVEL; + + // First read old property name + String levelStr = SystemPropertyUtil.get(PROP_LEVEL_OLD, defaultLevel.name()); + + // If new property name is present, use it + levelStr = SystemPropertyUtil.get(PROP_LEVEL, levelStr); + Level level = Level.parseLevel(levelStr); + + TARGET_RECORDS = SystemPropertyUtil.getInt(PROP_TARGET_RECORDS, DEFAULT_TARGET_RECORDS); + SAMPLING_INTERVAL = SystemPropertyUtil.getInt(PROP_SAMPLING_INTERVAL, DEFAULT_SAMPLING_INTERVAL); + + ResourceLeakDetector.level = level; + if (logger.isDebugEnabled()) { + logger.debug("-D{}: {}", PROP_LEVEL, level.name().toLowerCase()); + logger.debug("-D{}: {}", PROP_TARGET_RECORDS, TARGET_RECORDS); + } + } + + /** + * @deprecated Use {@link #setLevel(Level)} instead. + */ + @Deprecated + public static void setEnabled(boolean enabled) { + setLevel(enabled ? Level.SIMPLE : Level.DISABLED); + } + + /** + * Returns {@code true} if resource leak detection is enabled. + */ + public static boolean isEnabled() { + return getLevel().ordinal() > Level.DISABLED.ordinal(); + } + + /** + * Sets the resource leak detection level. + */ + public static void setLevel(Level level) { + ResourceLeakDetector.level = ObjectUtil.checkNotNull(level, "level"); + } + + /** + * Returns the current resource leak detection level. + */ + public static Level getLevel() { + return level; + } + + /** + * the collection of active resources + */ + private final Set> allLeaks = + Collections.newSetFromMap(new ConcurrentHashMap, Boolean>()); + + private final ReferenceQueue refQueue = new ReferenceQueue(); + private final Set reportedLeaks = + Collections.newSetFromMap(new ConcurrentHashMap()); + + private final String resourceType; + private final int samplingInterval; + + /** + * Will be notified once a leak is detected. + */ + private volatile LeakListener leakListener; + + /** + * @deprecated use {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class, int, long)}. 
+ */ + @Deprecated + public ResourceLeakDetector(Class resourceType) { + this(simpleClassName(resourceType)); + } + + /** + * @deprecated use {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class, int, long)}. + */ + @Deprecated + public ResourceLeakDetector(String resourceType) { + this(resourceType, DEFAULT_SAMPLING_INTERVAL, Long.MAX_VALUE); + } + + /** + * @param maxActive This is deprecated and will be ignored. + * @deprecated Use {@link ResourceLeakDetector#ResourceLeakDetector(Class, int)}. + *
<p>
+ * This should not be used directly by users of {@link ResourceLeakDetector}. + * Please use {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class)} + * or {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class, int, long)} + */ + @Deprecated + public ResourceLeakDetector(Class resourceType, int samplingInterval, long maxActive) { + this(resourceType, samplingInterval); + } + + /** + * This should not be used directly by users of {@link ResourceLeakDetector}. + * Please use {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class)} + * or {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class, int, long)} + */ + @SuppressWarnings("deprecation") + public ResourceLeakDetector(Class resourceType, int samplingInterval) { + this(simpleClassName(resourceType), samplingInterval, Long.MAX_VALUE); + } + + /** + * @param maxActive This is deprecated and will be ignored. + * @deprecated use {@link ResourceLeakDetectorFactory#newResourceLeakDetector(Class, int, long)}. + *

+ */ + @Deprecated + public ResourceLeakDetector(String resourceType, int samplingInterval, long maxActive) { + this.resourceType = ObjectUtil.checkNotNull(resourceType, "resourceType"); + this.samplingInterval = samplingInterval; + } + + /** + * Creates a new {@link ResourceLeak} which is expected to be closed via {@link ResourceLeak#close()} when the + * related resource is deallocated. + * + * @return the {@link ResourceLeak} or {@code null} + * @deprecated use {@link #track(Object)} + */ + @Deprecated + public final ResourceLeak open(T obj) { + return track0(obj, false); + } + + /** + * Creates a new {@link ResourceLeakTracker} which is expected to be closed via + * {@link ResourceLeakTracker#close(Object)} when the related resource is deallocated. + * + * @return the {@link ResourceLeakTracker} or {@code null} + */ + @SuppressWarnings("unchecked") + public final ResourceLeakTracker track(T obj) { + return track0(obj, false); + } + + /** + * Creates a new {@link ResourceLeakTracker} which is expected to be closed via + * {@link ResourceLeakTracker#close(Object)} when the related resource is deallocated. + *
<p>
+ * Unlike {@link #track(Object)}, this method always returns a tracker, regardless + * of the detection settings. + * + * @return the {@link ResourceLeakTracker} + */ + @SuppressWarnings("unchecked") + public ResourceLeakTracker trackForcibly(T obj) { + return track0(obj, true); + } + + @SuppressWarnings("unchecked") + private DefaultResourceLeak track0(T obj, boolean force) { + Level level = ResourceLeakDetector.level; + if (force || + level == Level.PARANOID || + (level != Level.DISABLED && PlatformDependent.threadLocalRandom().nextInt(samplingInterval) == 0)) { + reportLeak(); + return new DefaultResourceLeak(obj, refQueue, allLeaks, getInitialHint(resourceType)); + } + return null; + } + + private void clearRefQueue() { + for (; ; ) { + DefaultResourceLeak ref = (DefaultResourceLeak) refQueue.poll(); + if (ref == null) { + break; + } + ref.dispose(); + } + } + + /** + * When the return value is {@code true}, {@link #reportTracedLeak} and {@link #reportUntracedLeak} + * will be called once a leak is detected, otherwise not. + * + * @return {@code true} to enable leak reporting. + */ + protected boolean needReport() { + return logger.isErrorEnabled(); + } + + private void reportLeak() { + if (!needReport()) { + clearRefQueue(); + return; + } + + // Detect and report previous leaks. + for (; ; ) { + DefaultResourceLeak ref = (DefaultResourceLeak) refQueue.poll(); + if (ref == null) { + break; + } + + if (!ref.dispose()) { + continue; + } + + String records = ref.getReportAndClearRecords(); + if (reportedLeaks.add(records)) { + if (records.isEmpty()) { + reportUntracedLeak(resourceType); + } else { + reportTracedLeak(resourceType, records); + } + + LeakListener listener = leakListener; + if (listener != null) { + listener.onLeak(resourceType, records); + } + } + } + } + + /** + * This method is called when a traced leak is detected. It can be overridden for tracking how many times leaks + * have been detected. 
+ */ + protected void reportTracedLeak(String resourceType, String records) { + logger.error( + "LEAK: {}.release() was not called before it's garbage-collected. " + + "See https://netty.io/wiki/reference-counted-objects.html for more information.{}", + resourceType, records); + } + + /** + * This method is called when an untraced leak is detected. It can be overridden for tracking how many times leaks + * have been detected. + */ + protected void reportUntracedLeak(String resourceType) { + logger.error("LEAK: {}.release() was not called before it's garbage-collected. " + + "Enable advanced leak reporting to find out where the leak occurred. " + + "To enable advanced leak reporting, " + + "specify the JVM option '-D{}={}' or call {}.setLevel() " + + "See https://netty.io/wiki/reference-counted-objects.html for more information.", + resourceType, PROP_LEVEL, Level.ADVANCED.name().toLowerCase(), simpleClassName(this)); + } + + /** + * @deprecated This method will no longer be invoked by {@link ResourceLeakDetector}. + */ + @Deprecated + protected void reportInstancesLeak(String resourceType) { + } + + /** + * Create a hint object to be attached to an object tracked by this record. Similar to the additional information + * supplied to {@link ResourceLeakTracker#record(Object)}, will be printed alongside the stack trace of the + * creation of the resource. + */ + protected Object getInitialHint(String resourceType) { + return null; + } + + /** + * Set leak listener. Previous listener will be replaced. + */ + public void setLeakListener(LeakListener leakListener) { + this.leakListener = leakListener; + } + + public interface LeakListener { + + /** + * Will be called once a leak is detected. 
+ */ + void onLeak(String resourceType, String records); + } + + @SuppressWarnings("deprecation") + private static final class DefaultResourceLeak + extends WeakReference implements ResourceLeakTracker, ResourceLeak { + + @SuppressWarnings("unchecked") // generics and updaters do not mix. + private static final AtomicReferenceFieldUpdater, TraceRecord> headUpdater = + (AtomicReferenceFieldUpdater) + AtomicReferenceFieldUpdater.newUpdater(DefaultResourceLeak.class, TraceRecord.class, "head"); + + @SuppressWarnings("unchecked") // generics and updaters do not mix. + private static final AtomicIntegerFieldUpdater> droppedRecordsUpdater = + (AtomicIntegerFieldUpdater) + AtomicIntegerFieldUpdater.newUpdater(DefaultResourceLeak.class, "droppedRecords"); + + @SuppressWarnings("unused") + private volatile TraceRecord head; + @SuppressWarnings("unused") + private volatile int droppedRecords; + + private final Set> allLeaks; + private final int trackedHash; + + DefaultResourceLeak( + Object referent, + ReferenceQueue refQueue, + Set> allLeaks, + Object initialHint) { + super(referent, refQueue); + + assert referent != null; + + // Store the hash of the tracked object to later assert it in the close(...) method. + // It's important that we not store a reference to the referent as this would disallow it from + // be collected via the WeakReference. + trackedHash = System.identityHashCode(referent); + allLeaks.add(this); + // Create a new Record so we always have the creation stacktrace included. + headUpdater.set(this, initialHint == null ? + new TraceRecord(TraceRecord.BOTTOM) : new TraceRecord(TraceRecord.BOTTOM, initialHint)); + this.allLeaks = allLeaks; + } + + @Override + public void record() { + record0(null); + } + + @Override + public void record(Object hint) { + record0(hint); + } + + /** + * This method works by exponentially backing off as more records are present in the stack. 
Each record has a + * 1 / 2^n chance of dropping the top most record and replacing it with itself. This has a number of convenient + * properties: + * + *
    + *
  1. The current record is always recorded. This is due to the compare and swap dropping the top most + * record, rather than the to-be-pushed record. + *
  2. The very last access will always be recorded. This comes as a property of 1. + *
  3. It is possible to retain more records than the target, based upon the probability distribution. + *
  4. It is easy to keep a precise record of the number of elements in the stack, since each element has to + * know how tall the stack is. + *
+ *

+ * In this particular implementation, there are also some advantages. A thread local random is used to decide + * if something should be recorded. This means that if there is a deterministic access pattern, it is now + * possible to see what other accesses occur, rather than always dropping them. Second, after + * {@link #TARGET_RECORDS} accesses, backoff occurs. This matches typical access patterns, + * where there are either a high number of accesses (i.e. a cached buffer), or low (an ephemeral buffer), but + * not many in between. + *

+ * The use of atomics avoids serializing a high number of accesses, when most of the records will be thrown + * away. High contention only happens when there are very few existing records, which is only likely when the + * object isn't shared! If this is a problem, the loop can be aborted and the record dropped, because another + * thread won the race. + */ + private void record0(Object hint) { + // Check TARGET_RECORDS > 0 here to avoid similar check before remove from and add to lastRecords + if (TARGET_RECORDS > 0) { + TraceRecord oldHead; + TraceRecord prevHead; + TraceRecord newHead; + boolean dropped; + do { + if ((prevHead = oldHead = headUpdater.get(this)) == null) { + // already closed. + return; + } + final int numElements = oldHead.pos + 1; + if (numElements >= TARGET_RECORDS) { + final int backOffFactor = Math.min(numElements - TARGET_RECORDS, 30); + if (dropped = PlatformDependent.threadLocalRandom().nextInt(1 << backOffFactor) != 0) { + prevHead = oldHead.next; + } + } else { + dropped = false; + } + newHead = hint != null ? new TraceRecord(prevHead, hint) : new TraceRecord(prevHead); + } while (!headUpdater.compareAndSet(this, oldHead, newHead)); + if (dropped) { + droppedRecordsUpdater.incrementAndGet(this); + } + } + } + + boolean dispose() { + clear(); + return allLeaks.remove(this); + } + + @Override + public boolean close() { + if (allLeaks.remove(this)) { + // Call clear so the reference is not even enqueued. + clear(); + headUpdater.set(this, null); + return true; + } + return false; + } + + @Override + public boolean close(T trackedObject) { + // Ensure that the object that was tracked is the same as the one that was passed to close(...). + assert trackedHash == System.identityHashCode(trackedObject); + + try { + return close(); + } finally { + // This method will do `synchronized(trackedObject)` and we should be sure this will not cause deadlock. 
+ // It should not, because somewhere up the callstack should be a (successful) `trackedObject.release`, + // therefore it is unreasonable that anyone else, anywhere, is holding a lock on the trackedObject. + // (Unreasonable but possible, unfortunately.) + reachabilityFence0(trackedObject); + } + } + + /** + * Ensures that the object referenced by the given reference remains + * strongly reachable, + * regardless of any prior actions of the program that might otherwise cause + * the object to become unreachable; thus, the referenced object is not + * reclaimable by garbage collection at least until after the invocation of + * this method. + * + *

Recent versions of the JDK have a nasty habit of prematurely deciding objects are unreachable. + * see: https://stackoverflow.com/questions/26642153/finalize-called-on-strongly-reachable-object-in-java-8 + * The Java 9 method Reference.reachabilityFence offers a solution to this problem. + * + *

This method is always implemented as a synchronization on {@code ref}, not as + * {@code Reference.reachabilityFence} for consistency across platforms and to allow building on JDK 6-8. + * It is the caller's responsibility to ensure that this synchronization will not cause deadlock. + * + * @param ref the reference. If {@code null}, this method has no effect. + * @see java.lang.ref.Reference#reachabilityFence + */ + private static void reachabilityFence0(Object ref) { + if (ref != null) { + synchronized (ref) { + // Empty synchronized is ok: https://stackoverflow.com/a/31933260/1151521 + } + } + } + + @Override + public String toString() { + TraceRecord oldHead = headUpdater.get(this); + return generateReport(oldHead); + } + + String getReportAndClearRecords() { + TraceRecord oldHead = headUpdater.getAndSet(this, null); + return generateReport(oldHead); + } + + private String generateReport(TraceRecord oldHead) { + if (oldHead == null) { + // Already closed + return EMPTY_STRING; + } + + final int dropped = droppedRecordsUpdater.get(this); + int duped = 0; + + int present = oldHead.pos + 1; + // Guess about 2 kilobytes per stack trace + StringBuilder buf = new StringBuilder(present * 2048).append(NEWLINE); + buf.append("Recent access records: ").append(NEWLINE); + + int i = 1; + Set seen = new HashSet(present); + for (; oldHead != TraceRecord.BOTTOM; oldHead = oldHead.next) { + String s = oldHead.toString(); + if (seen.add(s)) { + if (oldHead.next == TraceRecord.BOTTOM) { + buf.append("Created at:").append(NEWLINE).append(s); + } else { + buf.append('#').append(i++).append(':').append(NEWLINE).append(s); + } + } else { + duped++; + } + } + + if (duped > 0) { + buf.append(": ") + .append(duped) + .append(" leak records were discarded because they were duplicates") + .append(NEWLINE); + } + + if (dropped > 0) { + buf.append(": ") + .append(dropped) + .append(" leak records were discarded because the leak record count is targeted to ") + .append(TARGET_RECORDS) + 
.append(". Use system property ") + .append(PROP_TARGET_RECORDS) + .append(" to increase the limit.") + .append(NEWLINE); + } + + buf.setLength(buf.length() - NEWLINE.length()); + return buf.toString(); + } + } + + private static final AtomicReference excludedMethods = + new AtomicReference(EmptyArrays.EMPTY_STRINGS); + + public static void addExclusions(Class clz, String... methodNames) { + Set nameSet = new HashSet(Arrays.asList(methodNames)); + // Use loop rather than lookup. This avoids knowing the parameters, and doesn't have to handle + // NoSuchMethodException. + for (Method method : clz.getDeclaredMethods()) { + if (nameSet.remove(method.getName()) && nameSet.isEmpty()) { + break; + } + } + if (!nameSet.isEmpty()) { + throw new IllegalArgumentException("Can't find '" + nameSet + "' in " + clz.getName()); + } + String[] oldMethods; + String[] newMethods; + do { + oldMethods = excludedMethods.get(); + newMethods = Arrays.copyOf(oldMethods, oldMethods.length + 2 * methodNames.length); + for (int i = 0; i < methodNames.length; i++) { + newMethods[oldMethods.length + i * 2] = clz.getName(); + newMethods[oldMethods.length + i * 2 + 1] = methodNames[i]; + } + } while (!excludedMethods.compareAndSet(oldMethods, newMethods)); + } + + private static class TraceRecord extends Throwable { + private static final long serialVersionUID = 6065153674892850720L; + + private static final TraceRecord BOTTOM = new TraceRecord() { + private static final long serialVersionUID = 7396077602074694571L; + + // Override fillInStackTrace() so we not populate the backtrace via a native call and so leak the + // Classloader. + // See https://github.com/netty/netty/pull/10691 + @Override + public Throwable fillInStackTrace() { + return this; + } + }; + + private final String hintString; + private final TraceRecord next; + private final int pos; + + TraceRecord(TraceRecord next, Object hint) { + // This needs to be generated even if toString() is never called as it may change later on. 
+ hintString = hint instanceof ResourceLeakHint ? ((ResourceLeakHint) hint).toHintString() : hint.toString(); + this.next = next; + this.pos = next.pos + 1; + } + + TraceRecord(TraceRecord next) { + hintString = null; + this.next = next; + this.pos = next.pos + 1; + } + + // Used to terminate the stack + private TraceRecord() { + hintString = null; + next = null; + pos = -1; + } + + @Override + public String toString() { + StringBuilder buf = new StringBuilder(2048); + if (hintString != null) { + buf.append("\tHint: ").append(hintString).append(NEWLINE); + } + + // Append the stack trace. + StackTraceElement[] array = getStackTrace(); + // Skip the first three elements. + out: + for (int i = 3; i < array.length; i++) { + StackTraceElement element = array[i]; + // Strip the noisy stack trace elements. + String[] exclusions = excludedMethods.get(); + for (int k = 0; k < exclusions.length; k += 2) { + // Suppress a warning about out of bounds access + // since the length of excludedMethods is always even, see addExclusions() + if (exclusions[k].equals(element.getClassName()) + && exclusions[k + 1].equals(element.getMethodName())) { + continue out; + } + } + + buf.append('\t'); + buf.append(element.toString()); + buf.append(NEWLINE); + } + return buf.toString(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java b/netty-util/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java new file mode 100644 index 0000000..12383a7 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ResourceLeakDetectorFactory.java @@ -0,0 +1,198 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.lang.reflect.Constructor; + +/** + * This static factory should be used to load {@link ResourceLeakDetector}s as needed + */ +public abstract class ResourceLeakDetectorFactory { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ResourceLeakDetectorFactory.class); + + private static volatile ResourceLeakDetectorFactory factoryInstance = new DefaultResourceLeakDetectorFactory(); + + /** + * Get the singleton instance of this factory class. + * + * @return the current {@link ResourceLeakDetectorFactory} + */ + public static ResourceLeakDetectorFactory instance() { + return factoryInstance; + } + + /** + * Set the factory's singleton instance. This has to be called before the static initializer of the + * {@link ResourceLeakDetector} is called by all the callers of this factory. That is, before initializing a + * Netty Bootstrap. + * + * @param factory the instance that will become the current {@link ResourceLeakDetectorFactory}'s singleton + */ + public static void setResourceLeakDetectorFactory(ResourceLeakDetectorFactory factory) { + factoryInstance = ObjectUtil.checkNotNull(factory, "factory"); + } + + /** + * Returns a new instance of a {@link ResourceLeakDetector} with the given resource class. 
+ * + * @param resource the resource class used to initialize the {@link ResourceLeakDetector} + * @param the type of the resource class + * @return a new instance of {@link ResourceLeakDetector} + */ + public final ResourceLeakDetector newResourceLeakDetector(Class resource) { + return newResourceLeakDetector(resource, ResourceLeakDetector.SAMPLING_INTERVAL); + } + + /** + * @param resource the resource class used to initialize the {@link ResourceLeakDetector} + * @param samplingInterval the interval on which sampling takes place + * @param maxActive This is deprecated and will be ignored. + * @param the type of the resource class + * @return a new instance of {@link ResourceLeakDetector} + * @deprecated Use {@link #newResourceLeakDetector(Class, int)} instead. + *

+ * Returns a new instance of a {@link ResourceLeakDetector} with the given resource class. + */ + @Deprecated + public abstract ResourceLeakDetector newResourceLeakDetector( + Class resource, int samplingInterval, long maxActive); + + /** + * Returns a new instance of a {@link ResourceLeakDetector} with the given resource class. + * + * @param resource the resource class used to initialize the {@link ResourceLeakDetector} + * @param samplingInterval the interval on which sampling takes place + * @param the type of the resource class + * @return a new instance of {@link ResourceLeakDetector} + */ + @SuppressWarnings("deprecation") + public ResourceLeakDetector newResourceLeakDetector(Class resource, int samplingInterval) { + ObjectUtil.checkPositive(samplingInterval, "samplingInterval"); + return newResourceLeakDetector(resource, samplingInterval, Long.MAX_VALUE); + } + + /** + * Default implementation that loads custom leak detector via system property + */ + private static final class DefaultResourceLeakDetectorFactory extends ResourceLeakDetectorFactory { + private final Constructor obsoleteCustomClassConstructor; + private final Constructor customClassConstructor; + + DefaultResourceLeakDetectorFactory() { + String customLeakDetector; + try { + customLeakDetector = SystemPropertyUtil.get("io.netty.customResourceLeakDetector"); + } catch (Throwable cause) { + logger.error("Could not access System property: io.netty.customResourceLeakDetector", cause); + customLeakDetector = null; + } + if (customLeakDetector == null) { + obsoleteCustomClassConstructor = customClassConstructor = null; + } else { + obsoleteCustomClassConstructor = obsoleteCustomClassConstructor(customLeakDetector); + customClassConstructor = customClassConstructor(customLeakDetector); + } + } + + private static Constructor obsoleteCustomClassConstructor(String customLeakDetector) { + try { + final Class detectorClass = Class.forName(customLeakDetector, true, + 
PlatformDependent.getSystemClassLoader()); + + if (ResourceLeakDetector.class.isAssignableFrom(detectorClass)) { + return detectorClass.getConstructor(Class.class, int.class, long.class); + } else { + logger.error("Class {} does not inherit from ResourceLeakDetector.", customLeakDetector); + } + } catch (Throwable t) { + logger.error("Could not load custom resource leak detector class provided: {}", + customLeakDetector, t); + } + return null; + } + + private static Constructor customClassConstructor(String customLeakDetector) { + try { + final Class detectorClass = Class.forName(customLeakDetector, true, + PlatformDependent.getSystemClassLoader()); + + if (ResourceLeakDetector.class.isAssignableFrom(detectorClass)) { + return detectorClass.getConstructor(Class.class, int.class); + } else { + logger.error("Class {} does not inherit from ResourceLeakDetector.", customLeakDetector); + } + } catch (Throwable t) { + logger.error("Could not load custom resource leak detector class provided: {}", + customLeakDetector, t); + } + return null; + } + + @SuppressWarnings("deprecation") + @Override + public ResourceLeakDetector newResourceLeakDetector(Class resource, int samplingInterval, + long maxActive) { + if (obsoleteCustomClassConstructor != null) { + try { + @SuppressWarnings("unchecked") + ResourceLeakDetector leakDetector = + (ResourceLeakDetector) obsoleteCustomClassConstructor.newInstance( + resource, samplingInterval, maxActive); + logger.debug("Loaded custom ResourceLeakDetector: {}", + obsoleteCustomClassConstructor.getDeclaringClass().getName()); + return leakDetector; + } catch (Throwable t) { + logger.error( + "Could not load custom resource leak detector provided: {} with the given resource: {}", + obsoleteCustomClassConstructor.getDeclaringClass().getName(), resource, t); + } + } + + ResourceLeakDetector resourceLeakDetector = new ResourceLeakDetector(resource, samplingInterval, + maxActive); + logger.debug("Loaded default ResourceLeakDetector: {}", 
resourceLeakDetector); + return resourceLeakDetector; + } + + @Override + public ResourceLeakDetector newResourceLeakDetector(Class resource, int samplingInterval) { + if (customClassConstructor != null) { + try { + @SuppressWarnings("unchecked") + ResourceLeakDetector leakDetector = + (ResourceLeakDetector) customClassConstructor.newInstance(resource, samplingInterval); + logger.debug("Loaded custom ResourceLeakDetector: {}", + customClassConstructor.getDeclaringClass().getName()); + return leakDetector; + } catch (Throwable t) { + logger.error( + "Could not load custom resource leak detector provided: {} with the given resource: {}", + customClassConstructor.getDeclaringClass().getName(), resource, t); + } + } + + ResourceLeakDetector resourceLeakDetector = new ResourceLeakDetector(resource, samplingInterval); + logger.debug("Loaded default ResourceLeakDetector: {}", resourceLeakDetector); + return resourceLeakDetector; + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/ResourceLeakException.java b/netty-util/src/main/java/io/netty/util/ResourceLeakException.java new file mode 100644 index 0000000..e4bb81a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ResourceLeakException.java @@ -0,0 +1,70 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util; + +import java.util.Arrays; + +/** + * @deprecated This class will be removed in the future version. + */ +@Deprecated +public class ResourceLeakException extends RuntimeException { + + private static final long serialVersionUID = 7186453858343358280L; + + private final StackTraceElement[] cachedStackTrace; + + public ResourceLeakException() { + cachedStackTrace = getStackTrace(); + } + + public ResourceLeakException(String message) { + super(message); + cachedStackTrace = getStackTrace(); + } + + public ResourceLeakException(String message, Throwable cause) { + super(message, cause); + cachedStackTrace = getStackTrace(); + } + + public ResourceLeakException(Throwable cause) { + super(cause); + cachedStackTrace = getStackTrace(); + } + + @Override + public int hashCode() { + int hashCode = 0; + for (StackTraceElement e : cachedStackTrace) { + hashCode = hashCode * 31 + e.hashCode(); + } + return hashCode; + } + + @Override + public boolean equals(Object o) { + if (!(o instanceof ResourceLeakException)) { + return false; + } + if (o == this) { + return true; + } + + return Arrays.equals(cachedStackTrace, ((ResourceLeakException) o).cachedStackTrace); + } +} diff --git a/netty-util/src/main/java/io/netty/util/ResourceLeakHint.java b/netty-util/src/main/java/io/netty/util/ResourceLeakHint.java new file mode 100644 index 0000000..57811db --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ResourceLeakHint.java @@ -0,0 +1,27 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +/** + * A hint object that provides human-readable message for easier resource leak tracking. + */ +public interface ResourceLeakHint { + /** + * Returns a human-readable message that potentially enables easier resource leak tracking. + */ + String toHintString(); +} diff --git a/netty-util/src/main/java/io/netty/util/ResourceLeakTracker.java b/netty-util/src/main/java/io/netty/util/ResourceLeakTracker.java new file mode 100644 index 0000000..e4d59f9 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ResourceLeakTracker.java @@ -0,0 +1,39 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +public interface ResourceLeakTracker { + + /** + * Records the caller's current stack trace so that the {@link ResourceLeakDetector} can tell where the leaked + * resource was accessed lastly. This method is a shortcut to {@link #record(Object) record(null)}. 
+ */ + void record(); + + /** + * Records the caller's current stack trace and the specified additional arbitrary information + * so that the {@link ResourceLeakDetector} can tell where the leaked resource was accessed lastly. + */ + void record(Object hint); + + /** + * Close the leak so that {@link ResourceLeakTracker} does not warn about leaked resources. + * After this method is called a leak associated with this ResourceLeakTracker should not be reported. + * + * @return {@code true} if called first time, {@code false} if called already + */ + boolean close(T trackedObject); +} diff --git a/netty-util/src/main/java/io/netty/util/Signal.java b/netty-util/src/main/java/io/netty/util/Signal.java new file mode 100644 index 0000000..d6d1aed --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/Signal.java @@ -0,0 +1,118 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + + +/** + * A special {@link Error} which is used to signal some state or request by throwing it. + * {@link Signal} has an empty stack trace and has no cause to save the instantiation overhead. 
+ */ +public final class Signal extends Error implements Constant { + + private static final long serialVersionUID = -221145131122459977L; + + private static final ConstantPool pool = new ConstantPool() { + @Override + protected Signal newConstant(int id, String name) { + return new Signal(id, name); + } + }; + + /** + * Returns the {@link Signal} of the specified name. + */ + public static Signal valueOf(String name) { + return pool.valueOf(name); + } + + /** + * Shortcut of {@link #valueOf(String) valueOf(firstNameComponent.getName() + "#" + secondNameComponent)}. + */ + public static Signal valueOf(Class firstNameComponent, String secondNameComponent) { + return pool.valueOf(firstNameComponent, secondNameComponent); + } + + private final SignalConstant constant; + + /** + * Creates a new {@link Signal} with the specified {@code name}. + */ + private Signal(int id, String name) { + constant = new SignalConstant(id, name); + } + + /** + * Check if the given {@link Signal} is the same as this instance. If not an {@link IllegalStateException} will + * be thrown. 
+ */ + public void expect(Signal signal) { + if (this != signal) { + throw new IllegalStateException("unexpected signal: " + signal); + } + } + + // Suppress a warning since the method doesn't need synchronization + @Override + public Throwable initCause(Throwable cause) { + return this; + } + + // Suppress a warning since the method doesn't need synchronization + @Override + public Throwable fillInStackTrace() { + return this; + } + + @Override + public int id() { + return constant.id(); + } + + @Override + public String name() { + return constant.name(); + } + + @Override + public boolean equals(Object obj) { + return this == obj; + } + + @Override + public int hashCode() { + return System.identityHashCode(this); + } + + @Override + public int compareTo(Signal other) { + if (this == other) { + return 0; + } + + return constant.compareTo(other.constant); + } + + @Override + public String toString() { + return name(); + } + + private static final class SignalConstant extends AbstractConstant { + SignalConstant(int id, String name) { + super(id, name); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/SuppressForbidden.java b/netty-util/src/main/java/io/netty/util/SuppressForbidden.java new file mode 100644 index 0000000..5fbe698 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/SuppressForbidden.java @@ -0,0 +1,32 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to suppress forbidden-apis errors inside a whole class, a method, or a field. + */ +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.CONSTRUCTOR, ElementType.FIELD, ElementType.METHOD, ElementType.TYPE}) +public @interface SuppressForbidden { + + String reason(); +} diff --git a/netty-util/src/main/java/io/netty/util/ThreadDeathWatcher.java b/netty-util/src/main/java/io/netty/util/ThreadDeathWatcher.java new file mode 100644 index 0000000..8da83bb --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/ThreadDeathWatcher.java @@ -0,0 +1,258 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util; + +import io.netty.util.concurrent.DefaultThreadFactory; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.ArrayList; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.ConcurrentLinkedQueue; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Checks if a thread is alive periodically and runs a task when a thread dies. + *

+ * This thread starts a daemon thread to check the state of the threads being watched and to invoke their + * associated {@link Runnable}s. When there is no thread to watch (i.e. all threads are dead), the daemon thread + * will terminate itself, and a new daemon thread will be started again when a new watch is added. + *

+ * + * @deprecated will be removed in the next major release + */ +@Deprecated +public final class ThreadDeathWatcher { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ThreadDeathWatcher.class); + // visible for testing + static final ThreadFactory threadFactory; + + // Use a MPMC queue as we may end up checking isEmpty() from multiple threads which may not be allowed to do + // concurrently depending on the implementation of it in a MPSC queue. + private static final Queue pendingEntries = new ConcurrentLinkedQueue(); + private static final Watcher watcher = new Watcher(); + private static final AtomicBoolean started = new AtomicBoolean(); + private static volatile Thread watcherThread; + + static { + String poolName = "threadDeathWatcher"; + String serviceThreadPrefix = SystemPropertyUtil.get("io.netty.serviceThreadPrefix"); + if (!StringUtil.isNullOrEmpty(serviceThreadPrefix)) { + poolName = serviceThreadPrefix + poolName; + } + // because the ThreadDeathWatcher is a singleton, tasks submitted to it can come from arbitrary threads and + // this can trigger the creation of a thread from arbitrary thread groups; for this reason, the thread factory + // must not be sticky about its thread group + threadFactory = new DefaultThreadFactory(poolName, true, Thread.MIN_PRIORITY, null); + } + + /** + * Schedules the specified {@code task} to run when the specified {@code thread} dies. 
+ * + * @param thread the {@link Thread} to watch + * @param task the {@link Runnable} to run when the {@code thread} dies + * @throws IllegalArgumentException if the specified {@code thread} is not alive + */ + public static void watch(Thread thread, Runnable task) { + ObjectUtil.checkNotNull(thread, "thread"); + ObjectUtil.checkNotNull(task, "task"); + + if (!thread.isAlive()) { + throw new IllegalArgumentException("thread must be alive."); + } + + schedule(thread, task, true); + } + + /** + * Cancels the task scheduled via {@link #watch(Thread, Runnable)}. + */ + public static void unwatch(Thread thread, Runnable task) { + schedule(ObjectUtil.checkNotNull(thread, "thread"), + ObjectUtil.checkNotNull(task, "task"), + false); + } + + private static void schedule(Thread thread, Runnable task, boolean isWatch) { + pendingEntries.add(new Entry(thread, task, isWatch)); + + if (started.compareAndSet(false, true)) { + final Thread watcherThread = threadFactory.newThread(watcher); + // Set to null to ensure we not create classloader leaks by holds a strong reference to the inherited + // classloader. + // See: + // - https://github.com/netty/netty/issues/7290 + // - https://bugs.openjdk.java.net/browse/JDK-7008595 + AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Void run() { + watcherThread.setContextClassLoader(null); + return null; + } + }); + + watcherThread.start(); + ThreadDeathWatcher.watcherThread = watcherThread; + } + } + + /** + * Waits until the thread of this watcher has no threads to watch and terminates itself. + * Because a new watcher thread will be started again on {@link #watch(Thread, Runnable)}, + * this operation is only useful when you want to ensure that the watcher thread is terminated + * after your application is shut down and there's no chance of calling + * {@link #watch(Thread, Runnable)} afterwards. 
+ * + * @return {@code true} if and only if the watcher thread has been terminated + */ + public static boolean awaitInactivity(long timeout, TimeUnit unit) throws InterruptedException { + ObjectUtil.checkNotNull(unit, "unit"); + + Thread watcherThread = ThreadDeathWatcher.watcherThread; + if (watcherThread != null) { + watcherThread.join(unit.toMillis(timeout)); + return !watcherThread.isAlive(); + } else { + return true; + } + } + + private ThreadDeathWatcher() { + } + + private static final class Watcher implements Runnable { + + private final List watchees = new ArrayList(); + + @Override + public void run() { + for (; ; ) { + fetchWatchees(); + notifyWatchees(); + + // Try once again just in case notifyWatchees() triggered watch() or unwatch(). + fetchWatchees(); + notifyWatchees(); + + try { + Thread.sleep(1000); + } catch (InterruptedException ignore) { + // Ignore the interrupt; do not terminate until all tasks are run. + } + + if (watchees.isEmpty() && pendingEntries.isEmpty()) { + + // Mark the current worker thread as stopped. + // The following CAS must always success and must be uncontended, + // because only one watcher thread should be running at the same time. + boolean stopped = started.compareAndSet(true, false); + assert stopped; + + // Check if there are pending entries added by watch() while we do CAS above. + if (pendingEntries.isEmpty()) { + // A) watch() was not invoked and thus there's nothing to handle + // -> safe to terminate because there's nothing left to do + // B) a new watcher thread started and handled them all + // -> safe to terminate the new watcher thread will take care the rest + break; + } + + // There are pending entries again, added by watch() + if (!started.compareAndSet(false, true)) { + // watch() started a new watcher thread and set 'started' to true. + // -> terminate this thread so that the new watcher reads from pendingEntries exclusively. 
+ break; + } + + // watch() added an entry, but this worker was faster to set 'started' to true. + // i.e. a new watcher thread was not started + // -> keep this thread alive to handle the newly added entries. + } + } + } + + private void fetchWatchees() { + for (; ; ) { + Entry e = pendingEntries.poll(); + if (e == null) { + break; + } + + if (e.isWatch) { + watchees.add(e); + } else { + watchees.remove(e); + } + } + } + + private void notifyWatchees() { + List watchees = this.watchees; + for (int i = 0; i < watchees.size(); ) { + Entry e = watchees.get(i); + if (!e.thread.isAlive()) { + watchees.remove(i); + try { + e.task.run(); + } catch (Throwable t) { + logger.warn("Thread death watcher task raised an exception:", t); + } + } else { + i++; + } + } + } + } + + private static final class Entry { + final Thread thread; + final Runnable task; + final boolean isWatch; + + Entry(Thread thread, Runnable task, boolean isWatch) { + this.thread = thread; + this.task = task; + this.isWatch = isWatch; + } + + @Override + public int hashCode() { + return thread.hashCode() ^ task.hashCode(); + } + + @Override + public boolean equals(Object obj) { + if (obj == this) { + return true; + } + + if (!(obj instanceof Entry that)) { + return false; + } + + return thread == that.thread && task == that.task; + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/Timeout.java b/netty-util/src/main/java/io/netty/util/Timeout.java new file mode 100644 index 0000000..7e0df58 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/Timeout.java @@ -0,0 +1,54 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
 You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

/**
 * A handle associated with a {@link TimerTask} that is returned by a
 * {@link Timer}.
 */
public interface Timeout {

    /**
     * Returns the {@link Timer} that created this handle.
     */
    Timer timer();

    /**
     * Returns the {@link TimerTask} which is associated with this handle.
     */
    TimerTask task();

    /**
     * Returns {@code true} if and only if the {@link TimerTask} associated
     * with this handle has expired (i.e. its delay elapsed and it was handed
     * to the timer for execution).
     */
    boolean isExpired();

    /**
     * Returns {@code true} if and only if the {@link TimerTask} associated
     * with this handle has been cancelled.
     */
    boolean isCancelled();

    /**
     * Attempts to cancel the {@link TimerTask} associated with this handle.
     * If the task has been executed or cancelled already, it will return with
     * no side effect.
     *
     * @return {@code true} if the cancellation completed successfully, otherwise {@code false}
     */
    boolean cancel();
}
diff --git a/netty-util/src/main/java/io/netty/util/Timer.java b/netty-util/src/main/java/io/netty/util/Timer.java
new file mode 100644
index 0000000..838b5ed
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/Timer.java
@@ -0,0 +1,47 @@
/*
 * Copyright 2012 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License.
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import java.util.Set; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +/** + * Schedules {@link TimerTask}s for one-time future execution in a background + * thread. + */ +public interface Timer { + + /** + * Schedules the specified {@link TimerTask} for one-time execution after + * the specified delay. + * + * @return a handle which is associated with the specified task + * @throws IllegalStateException if this timer has been {@linkplain #stop() stopped} already + * @throws RejectedExecutionException if the pending timeouts are too many and creating new timeout + * can cause instability in the system. + */ + Timeout newTimeout(TimerTask task, long delay, TimeUnit unit); + + /** + * Releases all resources acquired by this {@link Timer} and cancels all + * tasks which were scheduled but not executed yet. + * + * @return the handles associated with the tasks which were canceled by + * this method + */ + Set stop(); +} diff --git a/netty-util/src/main/java/io/netty/util/TimerTask.java b/netty-util/src/main/java/io/netty/util/TimerTask.java new file mode 100644 index 0000000..b11ef61 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/TimerTask.java @@ -0,0 +1,33 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
 You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

import java.util.concurrent.TimeUnit;

/**
 * A task which is executed after the delay specified with
 * {@link Timer#newTimeout(TimerTask, long, TimeUnit)}.
 */
public interface TimerTask {

    /**
     * Executed after the delay specified with
     * {@link Timer#newTimeout(TimerTask, long, TimeUnit)}.
     *
     * @param timeout a handle which is associated with this task
     * @throws Exception implementations may propagate any failure; the timer is
     *                   expected to handle (e.g. log) it
     */
    void run(Timeout timeout) throws Exception;
}
diff --git a/netty-util/src/main/java/io/netty/util/UncheckedBooleanSupplier.java b/netty-util/src/main/java/io/netty/util/UncheckedBooleanSupplier.java
new file mode 100644
index 0000000..c0a2d80
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/UncheckedBooleanSupplier.java
@@ -0,0 +1,49 @@
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

/**
 * Represents a supplier of {@code boolean}-valued results which doesn't throw any checked exceptions.
 */
public interface UncheckedBooleanSupplier extends BooleanSupplier {
    /**
     * Gets a boolean value without throwing any checked exception.
     *
     * @return a boolean value.
     */
    @Override
    boolean get();

    /**
     * A supplier which always returns {@code false} and never throws.
     */
    UncheckedBooleanSupplier FALSE_SUPPLIER = new UncheckedBooleanSupplier() {
        @Override
        public boolean get() {
            return false;
        }
    };

    /**
     * A supplier which always returns {@code true} and never throws.
     */
    UncheckedBooleanSupplier TRUE_SUPPLIER = new UncheckedBooleanSupplier() {
        @Override
        public boolean get() {
            return true;
        }
    };
}
diff --git a/netty-util/src/main/java/io/netty/util/Version.java b/netty-util/src/main/java/io/netty/util/Version.java
new file mode 100644
index 0000000..5677b29
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/Version.java
@@ -0,0 +1,202 @@
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */

package io.netty.util;

import io.netty.util.internal.PlatformDependent;
import java.io.InputStream;
import java.net.URL;
import java.text.ParseException;
import java.text.SimpleDateFormat;
import java.util.Enumeration;
import java.util.HashSet;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TreeMap;

/**
 * Retrieves the version information of available Netty artifacts.

+ * This class retrieves the version information from {@code META-INF/io.netty.versions.properties}, which is + * generated in build time. Note that it may not be possible to retrieve the information completely, depending on + * your environment, such as the specified {@link ClassLoader}, the current {@link SecurityManager}. + *

+ */ +public final class Version { + + private static final String PROP_VERSION = ".version"; + private static final String PROP_BUILD_DATE = ".buildDate"; + private static final String PROP_COMMIT_DATE = ".commitDate"; + private static final String PROP_SHORT_COMMIT_HASH = ".shortCommitHash"; + private static final String PROP_LONG_COMMIT_HASH = ".longCommitHash"; + private static final String PROP_REPO_STATUS = ".repoStatus"; + + /** + * Retrieves the version information of Netty artifacts using the current + * {@linkplain Thread#getContextClassLoader() context class loader}. + * + * @return A {@link Map} whose keys are Maven artifact IDs and whose values are {@link Version}s + */ + public static Map identify() { + return identify(null); + } + + /** + * Retrieves the version information of Netty artifacts using the specified {@link ClassLoader}. + * + * @return A {@link Map} whose keys are Maven artifact IDs and whose values are {@link Version}s + */ + public static Map identify(ClassLoader classLoader) { + if (classLoader == null) { + classLoader = PlatformDependent.getContextClassLoader(); + } + + // Collect all properties. + Properties props = new Properties(); + try { + Enumeration resources = classLoader.getResources("META-INF/io.netty.versions.properties"); + while (resources.hasMoreElements()) { + URL url = resources.nextElement(); + InputStream in = url.openStream(); + try { + props.load(in); + } finally { + try { + in.close(); + } catch (Exception ignore) { + // Ignore. + } + } + } + } catch (Exception ignore) { + // Not critical. Just ignore. + } + + // Collect all artifactIds. + Set artifactIds = new HashSet(); + for (Object o : props.keySet()) { + String k = (String) o; + + int dotIndex = k.indexOf('.'); + if (dotIndex <= 0) { + continue; + } + + String artifactId = k.substring(0, dotIndex); + + // Skip the entries without required information. 
+ if (!props.containsKey(artifactId + PROP_VERSION) || + !props.containsKey(artifactId + PROP_BUILD_DATE) || + !props.containsKey(artifactId + PROP_COMMIT_DATE) || + !props.containsKey(artifactId + PROP_SHORT_COMMIT_HASH) || + !props.containsKey(artifactId + PROP_LONG_COMMIT_HASH) || + !props.containsKey(artifactId + PROP_REPO_STATUS)) { + continue; + } + + artifactIds.add(artifactId); + } + + Map versions = new TreeMap(); + for (String artifactId : artifactIds) { + versions.put( + artifactId, + new Version( + artifactId, + props.getProperty(artifactId + PROP_VERSION), + parseIso8601(props.getProperty(artifactId + PROP_BUILD_DATE)), + parseIso8601(props.getProperty(artifactId + PROP_COMMIT_DATE)), + props.getProperty(artifactId + PROP_SHORT_COMMIT_HASH), + props.getProperty(artifactId + PROP_LONG_COMMIT_HASH), + props.getProperty(artifactId + PROP_REPO_STATUS))); + } + + return versions; + } + + private static long parseIso8601(String value) { + try { + return new SimpleDateFormat("yyyy-MM-dd HH:mm:ss Z").parse(value).getTime(); + } catch (ParseException ignored) { + return 0; + } + } + + /** + * Prints the version information to {@link System#err}. 
+ */ + public static void main(String[] args) { + for (Version v : identify().values()) { + System.err.println(v); + } + } + + private final String artifactId; + private final String artifactVersion; + private final long buildTimeMillis; + private final long commitTimeMillis; + private final String shortCommitHash; + private final String longCommitHash; + private final String repositoryStatus; + + private Version( + String artifactId, String artifactVersion, + long buildTimeMillis, long commitTimeMillis, + String shortCommitHash, String longCommitHash, String repositoryStatus) { + this.artifactId = artifactId; + this.artifactVersion = artifactVersion; + this.buildTimeMillis = buildTimeMillis; + this.commitTimeMillis = commitTimeMillis; + this.shortCommitHash = shortCommitHash; + this.longCommitHash = longCommitHash; + this.repositoryStatus = repositoryStatus; + } + + public String artifactId() { + return artifactId; + } + + public String artifactVersion() { + return artifactVersion; + } + + public long buildTimeMillis() { + return buildTimeMillis; + } + + public long commitTimeMillis() { + return commitTimeMillis; + } + + public String shortCommitHash() { + return shortCommitHash; + } + + public String longCommitHash() { + return longCommitHash; + } + + public String repositoryStatus() { + return repositoryStatus; + } + + @Override + public String toString() { + return artifactId + '-' + artifactVersion + '.' + shortCommitHash + + ("clean".equals(repositoryStatus) ? 
"" : " (repository: " + repositoryStatus + ')'); + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/ByteCollections.java b/netty-util/src/main/java/io/netty/util/collection/ByteCollections.java new file mode 100644 index 0000000..ebf0302 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/ByteCollections.java @@ -0,0 +1,313 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * Utilities for byte-based primitive collections. + */ +public final class ByteCollections { + + private static final ByteObjectMap EMPTY_MAP = new EmptyMap(); + + private ByteCollections() { + } + + /** + * Returns an unmodifiable empty {@link ByteObjectMap}. + */ + @SuppressWarnings("unchecked") + public static ByteObjectMap emptyMap() { + return (ByteObjectMap) EMPTY_MAP; + } + + /** + * Creates an unmodifiable wrapper around the given map. + */ + public static ByteObjectMap unmodifiableMap(final ByteObjectMap map) { + return new UnmodifiableMap(map); + } + + /** + * An empty map. All operations that attempt to modify the map are unsupported. 
+ */ + private static final class EmptyMap implements ByteObjectMap { + @Override + public Object get(byte key) { + return null; + } + + @Override + public Object put(byte key, Object value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public Object remove(byte key) { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public boolean containsKey(Object key) { + return false; + } + + @Override + public void clear() { + // Do nothing. + } + + @Override + public Set keySet() { + return Collections.emptySet(); + } + + @Override + public boolean containsKey(byte key) { + return false; + } + + @Override + public boolean containsValue(Object value) { + return false; + } + + @Override + public Iterable> entries() { + return Collections.emptySet(); + } + + @Override + public Object get(Object key) { + return null; + } + + @Override + public Object put(Byte key, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public Object remove(Object key) { + return null; + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException(); + } + + @Override + public Collection values() { + return Collections.emptyList(); + } + + @Override + public Set> entrySet() { + return Collections.emptySet(); + } + } + + /** + * An unmodifiable wrapper around a {@link ByteObjectMap}. + * + * @param the value type stored in the map. 
+ */ + private static final class UnmodifiableMap implements ByteObjectMap { + private final ByteObjectMap map; + private Set keySet; + private Set> entrySet; + private Collection values; + private Iterable> entries; + + UnmodifiableMap(ByteObjectMap map) { + this.map = map; + } + + @Override + public V get(byte key) { + return map.get(key); + } + + @Override + public V put(byte key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(byte key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean isEmpty() { + return map.isEmpty(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("clear"); + } + + @Override + public boolean containsKey(byte key) { + return map.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return map.containsValue(value); + } + + @Override + public boolean containsKey(Object key) { + return map.containsKey(key); + } + + @Override + public V get(Object key) { + return map.get(key); + } + + @Override + public V put(Byte key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(Object key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("putAll"); + } + + @Override + public Iterable> entries() { + if (entries == null) { + entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new IteratorImpl(map.entries().iterator()); + } + }; + } + + return entries; + } + + @Override + public Set keySet() { + if (keySet == null) { + keySet = Collections.unmodifiableSet(map.keySet()); + } + return keySet; + } + + @Override + public Set> entrySet() { + if (entrySet == null) { + entrySet = Collections.unmodifiableSet(map.entrySet()); + } + return entrySet; + } + + @Override + public Collection 
values() { + if (values == null) { + values = Collections.unmodifiableCollection(map.values()); + } + return values; + } + + /** + * Unmodifiable wrapper for an iterator. + */ + private class IteratorImpl implements Iterator> { + final Iterator> iter; + + IteratorImpl(Iterator> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return new EntryImpl(iter.next()); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + } + + /** + * Unmodifiable wrapper for an entry. + */ + private class EntryImpl implements PrimitiveEntry { + private final PrimitiveEntry entry; + + EntryImpl(PrimitiveEntry entry) { + this.entry = entry; + } + + @Override + public byte key() { + return entry.key(); + } + + @Override + public V value() { + return entry.value(); + } + + @Override + public void setValue(V value) { + throw new UnsupportedOperationException("setValue"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/ByteObjectHashMap.java b/netty-util/src/main/java/io/netty/util/collection/ByteObjectHashMap.java new file mode 100644 index 0000000..1342baa --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/ByteObjectHashMap.java @@ -0,0 +1,723 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. 
See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.util.collection; + +import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; + +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * A hash map implementation of {@link ByteObjectMap} that uses open addressing for keys. + * To minimize the memory footprint, this class uses open addressing rather than chaining. + * Collisions are resolved using linear probing. Deletions implement compaction, so cost of + * remove can approach O(N) for full maps, which makes a small loadFactor recommended. + * + * @param The value type stored in the map. + */ +public class ByteObjectHashMap implements ByteObjectMap { + + /** Default initial capacity. Used if not specified in the constructor */ + public static final int DEFAULT_CAPACITY = 8; + + /** Default load factor. Used if not specified in the constructor */ + public static final float DEFAULT_LOAD_FACTOR = 0.5f; + + /** + * Placeholder for null values, so we can use the actual null to mean available. + * (Better than using a placeholder for available: less references for GC processing.) + */ + private static final Object NULL_VALUE = new Object(); + + /** The maximum number of elements allowed without allocating more space. */ + private int maxSize; + + /** The load factor for the map. Used to calculate {@link #maxSize}. 
*/ + private final float loadFactor; + + private byte[] keys; + private V[] values; + private int size; + private int mask; + + private final Set keySet = new KeySet(); + private final Set> entrySet = new EntrySet(); + private final Iterable> entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new PrimitiveIterator(); + } + }; + + public ByteObjectHashMap() { + this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); + } + + public ByteObjectHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + public ByteObjectHashMap(int initialCapacity, float loadFactor) { + if (loadFactor <= 0.0f || loadFactor > 1.0f) { + // Cannot exceed 1 because we can never store more than capacity elements; + // using a bigger loadFactor would trigger rehashing before the desired load is reached. + throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); + } + + this.loadFactor = loadFactor; + + // Adjust the initial capacity if necessary. + int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); + mask = capacity - 1; + + // Allocate the arrays. + keys = new byte[capacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[capacity]; + values = temp; + + // Initialize the maximum size value. + maxSize = calcMaxSize(capacity); + } + + private static T toExternal(T value) { + assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; + return value == NULL_VALUE ? null : value; + } + + @SuppressWarnings("unchecked") + private static T toInternal(T value) { + return value == null ? (T) NULL_VALUE : value; + } + + @Override + public V get(byte key) { + int index = indexOf(key); + return index == -1 ? null : toExternal(values[index]); + } + + @Override + public V put(byte key, V value) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // Found empty slot, use it. 
+ keys[index] = key; + values[index] = toInternal(value); + growSize(); + return null; + } + if (keys[index] == key) { + // Found existing entry with this key, just replace the value. + V previousValue = values[index]; + values[index] = toInternal(value); + return toExternal(previousValue); + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. + throw new IllegalStateException("Unable to insert"); + } + } + } + + @Override + public void putAll(Map sourceMap) { + if (sourceMap instanceof ByteObjectHashMap) { + // Optimization - iterate through the arrays. + @SuppressWarnings("unchecked") + ByteObjectHashMap source = (ByteObjectHashMap) sourceMap; + for (int i = 0; i < source.values.length; ++i) { + V sourceValue = source.values[i]; + if (sourceValue != null) { + put(source.keys[i], sourceValue); + } + } + return; + } + + // Otherwise, just add each entry. + for (Entry entry : sourceMap.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public V remove(byte key) { + int index = indexOf(key); + if (index == -1) { + return null; + } + + V prev = values[index]; + removeAt(index); + return toExternal(prev); + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public void clear() { + Arrays.fill(keys, (byte) 0); + Arrays.fill(values, null); + size = 0; + } + + @Override + public boolean containsKey(byte key) { + return indexOf(key) >= 0; + } + + @Override + public boolean containsValue(Object value) { + @SuppressWarnings("unchecked") + V v1 = toInternal((V) value); + for (V v2 : values) { + // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). 
+ if (v2 != null && v2.equals(v1)) { + return true; + } + } + return false; + } + + @Override + public Iterable> entries() { + return entries; + } + + @Override + public Collection values() { + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new Iterator() { + final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public V next() { + return iter.next().value(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + + @Override + public int size() { + return size; + } + }; + } + + @Override + public int hashCode() { + // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys + // array, which may have different lengths for two maps of same size(), so the + // capacity cannot be used as input for hashing but the size can. + int hash = size; + for (byte key : keys) { + // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. + // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, + // or the devastatingly bad memory locality of visiting value objects. + // Also, it's important to use a hash function that does not depend on the ordering + // of terms, only their values; since the map is an unordered collection and + // entries can end up in different positions in different maps that have the same + // elements, but with different history of puts/removes, due to conflicts. 
+ hash ^= hashCode(key); + } + return hash; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ByteObjectMap)) { + return false; + } + @SuppressWarnings("rawtypes") + ByteObjectMap other = (ByteObjectMap) obj; + if (size != other.size()) { + return false; + } + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + byte key = keys[i]; + Object otherValue = other.get(key); + if (value == NULL_VALUE) { + if (otherValue != null) { + return false; + } + } else if (!value.equals(otherValue)) { + return false; + } + } + } + return true; + } + + @Override + public boolean containsKey(Object key) { + return containsKey(objectToKey(key)); + } + + @Override + public V get(Object key) { + return get(objectToKey(key)); + } + + @Override + public V put(Byte key, V value) { + return put(objectToKey(key), value); + } + + @Override + public V remove(Object key) { + return remove(objectToKey(key)); + } + + @Override + public Set keySet() { + return keySet; + } + + @Override + public Set> entrySet() { + return entrySet; + } + + private byte objectToKey(Object key) { + return (byte) ((Byte) key).byteValue(); + } + + /** + * Locates the index for the given key. This method probes using double hashing. + * + * @param key the key for an entry in the map. + * @return the index where the key was found, or {@code -1} if no entry is found for that key. + */ + private int indexOf(byte key) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // It's available, so no chance that this value exists anywhere in the map. + return -1; + } + if (key == keys[index]) { + return index; + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + return -1; + } + } + } + + /** + * Returns the hashed index for the given key. 
+ */ + private int hashIndex(byte key) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return hashCode(key) & mask; + } + + /** + * Returns the hash code for the key. + */ + private static int hashCode(byte key) { + return (int) key; + } + + /** + * Get the next sequential index after {@code index} and wraps if necessary. + */ + private int probeNext(int index) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return (index + 1) & mask; + } + + /** + * Grows the map size after an insertion. If necessary, performs a rehash of the map. + */ + private void growSize() { + size++; + + if (size > maxSize) { + if(keys.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Max capacity reached at size=" + size); + } + + // Double the capacity. + rehash(keys.length << 1); + } + } + + /** + * Removes entry at the given index position. Also performs opportunistic, incremental rehashing + * if necessary to not break conflict chains. + * + * @param index the index position of the element to remove. + * @return {@code true} if the next item was moved back. {@code false} otherwise. + */ + private boolean removeAt(final int index) { + --size; + // Clearing the key is not strictly necessary (for GC like in a regular collection), + // but recommended for security. The memory location is still fresh in the cache anyway. + keys[index] = 0; + values[index] = null; + + // In the interval from index to the next available entry, the arrays may have entries + // that are displaced from their base position due to prior conflicts. Iterate these + // entries and move them back if possible, optimizing future lookups. + // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap. 
+ + int nextFree = index; + int i = probeNext(index); + for (V value = values[i]; value != null; value = values[i = probeNext(i)]) { + byte key = keys[i]; + int bucket = hashIndex(key); + if (i < bucket && (bucket <= nextFree || nextFree <= i) || + bucket <= nextFree && nextFree <= i) { + // Move the displaced entry "back" to the first available position. + keys[nextFree] = key; + values[nextFree] = value; + // Put the first entry after the displaced entry + keys[i] = 0; + values[i] = null; + nextFree = i; + } + } + return nextFree != index; + } + + /** + * Calculates the maximum size allowed before rehashing. + */ + private int calcMaxSize(int capacity) { + // Clip the upper bound so that there will always be at least one available slot. + int upperBound = capacity - 1; + return Math.min(upperBound, (int) (capacity * loadFactor)); + } + + /** + * Rehashes the map for the given capacity. + * + * @param newCapacity the new capacity for the map. + */ + private void rehash(int newCapacity) { + byte[] oldKeys = keys; + V[] oldVals = values; + + keys = new byte[newCapacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[newCapacity]; + values = temp; + + maxSize = calcMaxSize(newCapacity); + mask = newCapacity - 1; + + // Insert to the new arrays. + for (int i = 0; i < oldVals.length; ++i) { + V oldVal = oldVals[i]; + if (oldVal != null) { + // Inlined put(), but much simpler: we don't need to worry about + // duplicated keys, growing/rehashing, or failing to insert. + byte oldKey = oldKeys[i]; + int index = hashIndex(oldKey); + + for (;;) { + if (values[index] == null) { + keys[index] = oldKey; + values[index] = oldVal; + break; + } + + // Conflict, keep probing. Can wrap around, but never reaches startIndex again. 
+ index = probeNext(index); + } + } + } + } + + @Override + public String toString() { + if (isEmpty()) { + return "{}"; + } + StringBuilder sb = new StringBuilder(4 * size); + sb.append('{'); + boolean first = true; + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + if (!first) { + sb.append(", "); + } + sb.append(keyToString(keys[i])).append('=').append(value == this ? "(this Map)" : + toExternal(value)); + first = false; + } + } + return sb.append('}').toString(); + } + + /** + * Helper method called by {@link #toString()} in order to convert a single map key into a string. + * This is protected to allow subclasses to override the appearance of a given key. + */ + protected String keyToString(byte key) { + return Byte.toString(key); + } + + /** + * Set implementation for iterating over the entries of the map. + */ + private final class EntrySet extends AbstractSet> { + @Override + public Iterator> iterator() { + return new MapIterator(); + } + + @Override + public int size() { + return ByteObjectHashMap.this.size(); + } + } + + /** + * Set implementation for iterating over the keys. 
+ */ + private final class KeySet extends AbstractSet { + @Override + public int size() { + return ByteObjectHashMap.this.size(); + } + + @Override + public boolean contains(Object o) { + return ByteObjectHashMap.this.containsKey(o); + } + + @Override + public boolean remove(Object o) { + return ByteObjectHashMap.this.remove(o) != null; + } + + @Override + public boolean retainAll(Collection retainedKeys) { + boolean changed = false; + for(Iterator> iter = entries().iterator(); iter.hasNext(); ) { + PrimitiveEntry entry = iter.next(); + if (!retainedKeys.contains(entry.key())) { + changed = true; + iter.remove(); + } + } + return changed; + } + + @Override + public void clear() { + ByteObjectHashMap.this.clear(); + } + + @Override + public Iterator iterator() { + return new Iterator() { + private final Iterator> iter = entrySet.iterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Byte next() { + return iter.next().getKey(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + } + + /** + * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link #next()}. + */ + private final class PrimitiveIterator implements Iterator>, PrimitiveEntry { + private int prevIndex = -1; + private int nextIndex = -1; + private int entryIndex = -1; + + private void scanNext() { + while (++nextIndex != values.length && values[nextIndex] == null) { + } + } + + @Override + public boolean hasNext() { + if (nextIndex == -1) { + scanNext(); + } + return nextIndex != values.length; + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + prevIndex = nextIndex; + scanNext(); + + // Always return the same Entry object, just change its index each time. 
+ entryIndex = prevIndex; + return this; + } + + @Override + public void remove() { + if (prevIndex == -1) { + throw new IllegalStateException("next must be called before each remove."); + } + if (removeAt(prevIndex)) { + // removeAt may move elements "back" in the array if they have been displaced because their spot in the + // array was occupied when they were inserted. If this occurs then the nextIndex is now invalid and + // should instead point to the prevIndex which now holds an element which was "moved back". + nextIndex = prevIndex; + } + prevIndex = -1; + } + + // Entry implementation. Since this implementation uses a single Entry, we coalesce that + // into the Iterator object (potentially making loop optimization much easier). + + @Override + public byte key() { + return keys[entryIndex]; + } + + @Override + public V value() { + return toExternal(values[entryIndex]); + } + + @Override + public void setValue(V value) { + values[entryIndex] = toInternal(value); + } + } + + /** + * Iterator used by the {@link Map} interface. + */ + private final class MapIterator implements Iterator> { + private final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Entry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + iter.next(); + + return new MapEntry(iter.entryIndex); + } + + @Override + public void remove() { + iter.remove(); + } + } + + /** + * A single entry in the map. 
+ */ + final class MapEntry implements Entry { + private final int entryIndex; + + MapEntry(int entryIndex) { + this.entryIndex = entryIndex; + } + + @Override + public Byte getKey() { + verifyExists(); + return keys[entryIndex]; + } + + @Override + public V getValue() { + verifyExists(); + return toExternal(values[entryIndex]); + } + + @Override + public V setValue(V value) { + verifyExists(); + V prevValue = toExternal(values[entryIndex]); + values[entryIndex] = toInternal(value); + return prevValue; + } + + private void verifyExists() { + if (values[entryIndex] == null) { + throw new IllegalStateException("The map entry has been removed"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/ByteObjectMap.java b/netty-util/src/main/java/io/netty/util/collection/ByteObjectMap.java new file mode 100644 index 0000000..66290ac --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/ByteObjectMap.java @@ -0,0 +1,84 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Map; + +/** + * Interface for a primitive map that uses {@code byte}s as keys. + * + * @param the value type stored in the map. + */ +public interface ByteObjectMap extends Map { + + /** + * A primitive entry in the map, provided by the iterator from {@link #entries()} + * + * @param the value type stored in the map. 
+ */ + interface PrimitiveEntry { + /** + * Gets the key for this entry. + */ + byte key(); + + /** + * Gets the value for this entry. + */ + V value(); + + /** + * Sets the value for this entry. + */ + void setValue(V value); + } + + /** + * Gets the value in the map with the specified key. + * + * @param key the key whose associated value is to be returned. + * @return the value or {@code null} if the key was not found in the map. + */ + V get(byte key); + + /** + * Puts the given entry into the map. + * + * @param key the key of the entry. + * @param value the value of the entry. + * @return the previous value for this key or {@code null} if there was no previous mapping. + */ + V put(byte key, V value); + + /** + * Removes the entry with the specified key. + * + * @param key the key for the entry to be removed from this map. + * @return the previous value for the key, or {@code null} if there was no mapping. + */ + V remove(byte key); + + /** + * Gets an iterable to traverse over the primitive entries contained in this map. As an optimization, + * the {@link PrimitiveEntry}s returned by the {@link java.util.Iterator} may change as the {@link java.util.Iterator} + * progresses. The caller should not rely on {@link PrimitiveEntry} key/value stability. + */ + Iterable> entries(); + + /** + * Indicates whether or not this map contains a value for the specified key. + */ + boolean containsKey(byte key); +} diff --git a/netty-util/src/main/java/io/netty/util/collection/CharCollections.java b/netty-util/src/main/java/io/netty/util/collection/CharCollections.java new file mode 100644 index 0000000..8090aa4 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/CharCollections.java @@ -0,0 +1,313 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * Utilities for char-based primitive collections. + */ +public final class CharCollections { + + private static final CharObjectMap EMPTY_MAP = new EmptyMap(); + + private CharCollections() { + } + + /** + * Returns an unmodifiable empty {@link CharObjectMap}. + */ + @SuppressWarnings("unchecked") + public static CharObjectMap emptyMap() { + return (CharObjectMap) EMPTY_MAP; + } + + /** + * Creates an unmodifiable wrapper around the given map. + */ + public static CharObjectMap unmodifiableMap(final CharObjectMap map) { + return new UnmodifiableMap(map); + } + + /** + * An empty map. All operations that attempt to modify the map are unsupported. + */ + private static final class EmptyMap implements CharObjectMap { + @Override + public Object get(char key) { + return null; + } + + @Override + public Object put(char key, Object value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public Object remove(char key) { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public boolean containsKey(Object key) { + return false; + } + + @Override + public void clear() { + // Do nothing. 
+ } + + @Override + public Set keySet() { + return Collections.emptySet(); + } + + @Override + public boolean containsKey(char key) { + return false; + } + + @Override + public boolean containsValue(Object value) { + return false; + } + + @Override + public Iterable> entries() { + return Collections.emptySet(); + } + + @Override + public Object get(Object key) { + return null; + } + + @Override + public Object put(Character key, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public Object remove(Object key) { + return null; + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException(); + } + + @Override + public Collection values() { + return Collections.emptyList(); + } + + @Override + public Set> entrySet() { + return Collections.emptySet(); + } + } + + /** + * An unmodifiable wrapper around a {@link CharObjectMap}. + * + * @param the value type stored in the map. + */ + private static final class UnmodifiableMap implements CharObjectMap { + private final CharObjectMap map; + private Set keySet; + private Set> entrySet; + private Collection values; + private Iterable> entries; + + UnmodifiableMap(CharObjectMap map) { + this.map = map; + } + + @Override + public V get(char key) { + return map.get(key); + } + + @Override + public V put(char key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(char key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean isEmpty() { + return map.isEmpty(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("clear"); + } + + @Override + public boolean containsKey(char key) { + return map.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return map.containsValue(value); + } + + @Override + public boolean containsKey(Object key) { + return map.containsKey(key); 
+ } + + @Override + public V get(Object key) { + return map.get(key); + } + + @Override + public V put(Character key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(Object key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("putAll"); + } + + @Override + public Iterable> entries() { + if (entries == null) { + entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new IteratorImpl(map.entries().iterator()); + } + }; + } + + return entries; + } + + @Override + public Set keySet() { + if (keySet == null) { + keySet = Collections.unmodifiableSet(map.keySet()); + } + return keySet; + } + + @Override + public Set> entrySet() { + if (entrySet == null) { + entrySet = Collections.unmodifiableSet(map.entrySet()); + } + return entrySet; + } + + @Override + public Collection values() { + if (values == null) { + values = Collections.unmodifiableCollection(map.values()); + } + return values; + } + + /** + * Unmodifiable wrapper for an iterator. + */ + private class IteratorImpl implements Iterator> { + final Iterator> iter; + + IteratorImpl(Iterator> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return new EntryImpl(iter.next()); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + } + + /** + * Unmodifiable wrapper for an entry. 
+ */ + private class EntryImpl implements PrimitiveEntry { + private final PrimitiveEntry entry; + + EntryImpl(PrimitiveEntry entry) { + this.entry = entry; + } + + @Override + public char key() { + return entry.key(); + } + + @Override + public V value() { + return entry.value(); + } + + @Override + public void setValue(V value) { + throw new UnsupportedOperationException("setValue"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/CharObjectHashMap.java b/netty-util/src/main/java/io/netty/util/collection/CharObjectHashMap.java new file mode 100644 index 0000000..d16d885 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/CharObjectHashMap.java @@ -0,0 +1,723 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.util.collection; + +import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; + +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * A hash map implementation of {@link CharObjectMap} that uses open addressing for keys. + * To minimize the memory footprint, this class uses open addressing rather than chaining. + * Collisions are resolved using linear probing. 
Deletions implement compaction, so cost of + * remove can approach O(N) for full maps, which makes a small loadFactor recommended. + * + * @param The value type stored in the map. + */ +public class CharObjectHashMap implements CharObjectMap { + + /** Default initial capacity. Used if not specified in the constructor */ + public static final int DEFAULT_CAPACITY = 8; + + /** Default load factor. Used if not specified in the constructor */ + public static final float DEFAULT_LOAD_FACTOR = 0.5f; + + /** + * Placeholder for null values, so we can use the actual null to mean available. + * (Better than using a placeholder for available: less references for GC processing.) + */ + private static final Object NULL_VALUE = new Object(); + + /** The maximum number of elements allowed without allocating more space. */ + private int maxSize; + + /** The load factor for the map. Used to calculate {@link #maxSize}. */ + private final float loadFactor; + + private char[] keys; + private V[] values; + private int size; + private int mask; + + private final Set keySet = new KeySet(); + private final Set> entrySet = new EntrySet(); + private final Iterable> entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new PrimitiveIterator(); + } + }; + + public CharObjectHashMap() { + this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); + } + + public CharObjectHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + public CharObjectHashMap(int initialCapacity, float loadFactor) { + if (loadFactor <= 0.0f || loadFactor > 1.0f) { + // Cannot exceed 1 because we can never store more than capacity elements; + // using a bigger loadFactor would trigger rehashing before the desired load is reached. + throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); + } + + this.loadFactor = loadFactor; + + // Adjust the initial capacity if necessary. 
+ int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); + mask = capacity - 1; + + // Allocate the arrays. + keys = new char[capacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[capacity]; + values = temp; + + // Initialize the maximum size value. + maxSize = calcMaxSize(capacity); + } + + private static T toExternal(T value) { + assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; + return value == NULL_VALUE ? null : value; + } + + @SuppressWarnings("unchecked") + private static T toInternal(T value) { + return value == null ? (T) NULL_VALUE : value; + } + + @Override + public V get(char key) { + int index = indexOf(key); + return index == -1 ? null : toExternal(values[index]); + } + + @Override + public V put(char key, V value) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // Found empty slot, use it. + keys[index] = key; + values[index] = toInternal(value); + growSize(); + return null; + } + if (keys[index] == key) { + // Found existing entry with this key, just replace the value. + V previousValue = values[index]; + values[index] = toInternal(value); + return toExternal(previousValue); + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. + throw new IllegalStateException("Unable to insert"); + } + } + } + + @Override + public void putAll(Map sourceMap) { + if (sourceMap instanceof CharObjectHashMap) { + // Optimization - iterate through the arrays. + @SuppressWarnings("unchecked") + CharObjectHashMap source = (CharObjectHashMap) sourceMap; + for (int i = 0; i < source.values.length; ++i) { + V sourceValue = source.values[i]; + if (sourceValue != null) { + put(source.keys[i], sourceValue); + } + } + return; + } + + // Otherwise, just add each entry. 
+ for (Entry entry : sourceMap.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public V remove(char key) { + int index = indexOf(key); + if (index == -1) { + return null; + } + + V prev = values[index]; + removeAt(index); + return toExternal(prev); + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public void clear() { + Arrays.fill(keys, (char) 0); + Arrays.fill(values, null); + size = 0; + } + + @Override + public boolean containsKey(char key) { + return indexOf(key) >= 0; + } + + @Override + public boolean containsValue(Object value) { + @SuppressWarnings("unchecked") + V v1 = toInternal((V) value); + for (V v2 : values) { + // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). + if (v2 != null && v2.equals(v1)) { + return true; + } + } + return false; + } + + @Override + public Iterable> entries() { + return entries; + } + + @Override + public Collection values() { + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new Iterator() { + final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public V next() { + return iter.next().value(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + + @Override + public int size() { + return size; + } + }; + } + + @Override + public int hashCode() { + // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys + // array, which may have different lengths for two maps of same size(), so the + // capacity cannot be used as input for hashing but the size can. + int hash = size; + for (char key : keys) { + // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. 
+ // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, + // or the devastatingly bad memory locality of visiting value objects. + // Also, it's important to use a hash function that does not depend on the ordering + // of terms, only their values; since the map is an unordered collection and + // entries can end up in different positions in different maps that have the same + // elements, but with different history of puts/removes, due to conflicts. + hash ^= hashCode(key); + } + return hash; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof CharObjectMap)) { + return false; + } + @SuppressWarnings("rawtypes") + CharObjectMap other = (CharObjectMap) obj; + if (size != other.size()) { + return false; + } + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + char key = keys[i]; + Object otherValue = other.get(key); + if (value == NULL_VALUE) { + if (otherValue != null) { + return false; + } + } else if (!value.equals(otherValue)) { + return false; + } + } + } + return true; + } + + @Override + public boolean containsKey(Object key) { + return containsKey(objectToKey(key)); + } + + @Override + public V get(Object key) { + return get(objectToKey(key)); + } + + @Override + public V put(Character key, V value) { + return put(objectToKey(key), value); + } + + @Override + public V remove(Object key) { + return remove(objectToKey(key)); + } + + @Override + public Set keySet() { + return keySet; + } + + @Override + public Set> entrySet() { + return entrySet; + } + + private char objectToKey(Object key) { + return (char) ((Character) key).charValue(); + } + + /** + * Locates the index for the given key. This method probes using double hashing. + * + * @param key the key for an entry in the map. + * @return the index where the key was found, or {@code -1} if no entry is found for that key. 
+ */ + private int indexOf(char key) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // It's available, so no chance that this value exists anywhere in the map. + return -1; + } + if (key == keys[index]) { + return index; + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + return -1; + } + } + } + + /** + * Returns the hashed index for the given key. + */ + private int hashIndex(char key) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return hashCode(key) & mask; + } + + /** + * Returns the hash code for the key. + */ + private static int hashCode(char key) { + return (int) key; + } + + /** + * Get the next sequential index after {@code index} and wraps if necessary. + */ + private int probeNext(int index) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return (index + 1) & mask; + } + + /** + * Grows the map size after an insertion. If necessary, performs a rehash of the map. + */ + private void growSize() { + size++; + + if (size > maxSize) { + if(keys.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Max capacity reached at size=" + size); + } + + // Double the capacity. + rehash(keys.length << 1); + } + } + + /** + * Removes entry at the given index position. Also performs opportunistic, incremental rehashing + * if necessary to not break conflict chains. + * + * @param index the index position of the element to remove. + * @return {@code true} if the next item was moved back. {@code false} otherwise. + */ + private boolean removeAt(final int index) { + --size; + // Clearing the key is not strictly necessary (for GC like in a regular collection), + // but recommended for security. The memory location is still fresh in the cache anyway. 
+        keys[index] = 0;
+        values[index] = null;
+
+        // In the interval from index to the next available entry, the arrays may have entries
+        // that are displaced from their base position due to prior conflicts. Iterate these
+        // entries and move them back if possible, optimizing future lookups.
+        // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap.
+
+        int nextFree = index;
+        int i = probeNext(index);
+        for (V value = values[i]; value != null; value = values[i = probeNext(i)]) {
+            char key = keys[i];
+            int bucket = hashIndex(key);
+            // Move the entry only if doing so does not lift it before its home bucket,
+            // taking wrap-around of the probe sequence into account.
+            if (i < bucket && (bucket <= nextFree || nextFree <= i) ||
+                bucket <= nextFree && nextFree <= i) {
+                // Move the displaced entry "back" to the first available position.
+                keys[nextFree] = key;
+                values[nextFree] = value;
+                // Put the first entry after the displaced entry
+                keys[i] = 0;
+                values[i] = null;
+                nextFree = i;
+            }
+        }
+        return nextFree != index;
+    }
+
+    /**
+     * Calculates the maximum size allowed before rehashing.
+     */
+    private int calcMaxSize(int capacity) {
+        // Clip the upper bound so that there will always be at least one available slot.
+        int upperBound = capacity - 1;
+        return Math.min(upperBound, (int) (capacity * loadFactor));
+    }
+
+    /**
+     * Rehashes the map for the given capacity.
+     *
+     * @param newCapacity the new capacity for the map.
+     */
+    private void rehash(int newCapacity) {
+        char[] oldKeys = keys;
+        V[] oldVals = values;
+
+        keys = new char[newCapacity];
+        @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" })
+        V[] temp = (V[]) new Object[newCapacity];
+        values = temp;
+
+        maxSize = calcMaxSize(newCapacity);
+        mask = newCapacity - 1;
+
+        // Insert to the new arrays.
+        for (int i = 0; i < oldVals.length; ++i) {
+            V oldVal = oldVals[i];
+            if (oldVal != null) {
+                // Inlined put(), but much simpler: we don't need to worry about
+                // duplicated keys, growing/rehashing, or failing to insert.
+                char oldKey = oldKeys[i];
+                int index = hashIndex(oldKey);
+
+                for (;;) {
+                    if (values[index] == null) {
+                        keys[index] = oldKey;
+                        values[index] = oldVal;
+                        break;
+                    }
+
+                    // Conflict, keep probing. Can wrap around, but never reaches startIndex again.
+                    index = probeNext(index);
+                }
+            }
+        }
+    }
+
+    @Override
+    public String toString() {
+        if (isEmpty()) {
+            return "{}";
+        }
+        StringBuilder sb = new StringBuilder(4 * size);
+        sb.append('{');
+        boolean first = true;
+        for (int i = 0; i < values.length; ++i) {
+            V value = values[i];
+            if (value != null) {
+                if (!first) {
+                    sb.append(", ");
+                }
+                sb.append(keyToString(keys[i])).append('=').append(value == this ? "(this Map)" :
+                        toExternal(value));
+                first = false;
+            }
+        }
+        return sb.append('}').toString();
+    }
+
+    /**
+     * Helper method called by {@link #toString()} in order to convert a single map key into a string.
+     * This is protected to allow subclasses to override the appearance of a given key.
+     */
+    protected String keyToString(char key) {
+        return Character.toString(key);
+    }
+
+    /**
+     * Set implementation for iterating over the entries of the map.
+     */
+    private final class EntrySet extends AbstractSet<Entry<Character, V>> {
+        @Override
+        public Iterator<Entry<Character, V>> iterator() {
+            return new MapIterator();
+        }
+
+        @Override
+        public int size() {
+            return CharObjectHashMap.this.size();
+        }
+    }
+
+    /**
+     * Set implementation for iterating over the keys.
+     */
+    private final class KeySet extends AbstractSet<Character> {
+        @Override
+        public int size() {
+            return CharObjectHashMap.this.size();
+        }
+
+        @Override
+        public boolean contains(Object o) {
+            return CharObjectHashMap.this.containsKey(o);
+        }
+
+        @Override
+        public boolean remove(Object o) {
+            return CharObjectHashMap.this.remove(o) != null;
+        }
+
+        @Override
+        public boolean retainAll(Collection<?> retainedKeys) {
+            boolean changed = false;
+            for (Iterator<PrimitiveEntry<V>> iter = entries().iterator(); iter.hasNext();) {
+                PrimitiveEntry<V> entry = iter.next();
+                if (!retainedKeys.contains(entry.key())) {
+                    changed = true;
+                    iter.remove();
+                }
+            }
+            return changed;
+        }
+
+        @Override
+        public void clear() {
+            CharObjectHashMap.this.clear();
+        }
+
+        @Override
+        public Iterator<Character> iterator() {
+            return new Iterator<Character>() {
+                private final Iterator<Entry<Character, V>> iter = entrySet.iterator();
+
+                @Override
+                public boolean hasNext() {
+                    return iter.hasNext();
+                }
+
+                @Override
+                public Character next() {
+                    return iter.next().getKey();
+                }
+
+                @Override
+                public void remove() {
+                    iter.remove();
+                }
+            };
+        }
+    }
+
+    /**
+     * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link #next()}.
+     */
+    private final class PrimitiveIterator implements Iterator<PrimitiveEntry<V>>, PrimitiveEntry<V> {
+        private int prevIndex = -1;
+        private int nextIndex = -1;
+        private int entryIndex = -1;
+
+        private void scanNext() {
+            while (++nextIndex != values.length && values[nextIndex] == null) {
+            }
+        }
+
+        @Override
+        public boolean hasNext() {
+            if (nextIndex == -1) {
+                scanNext();
+            }
+            return nextIndex != values.length;
+        }
+
+        @Override
+        public PrimitiveEntry<V> next() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+
+            prevIndex = nextIndex;
+            scanNext();
+
+            // Always return the same Entry object, just change its index each time.
+            entryIndex = prevIndex;
+            return this;
+        }
+
+        @Override
+        public void remove() {
+            if (prevIndex == -1) {
+                throw new IllegalStateException("next must be called before each remove.");
+            }
+            if (removeAt(prevIndex)) {
+                // removeAt may move elements "back" in the array if they have been displaced because their spot in the
+                // array was occupied when they were inserted. If this occurs then the nextIndex is now invalid and
+                // should instead point to the prevIndex which now holds an element which was "moved back".
+                nextIndex = prevIndex;
+            }
+            prevIndex = -1;
+        }
+
+        // Entry implementation. Since this implementation uses a single Entry, we coalesce that
+        // into the Iterator object (potentially making loop optimization much easier).
+
+        @Override
+        public char key() {
+            return keys[entryIndex];
+        }
+
+        @Override
+        public V value() {
+            return toExternal(values[entryIndex]);
+        }
+
+        @Override
+        public void setValue(V value) {
+            values[entryIndex] = toInternal(value);
+        }
+    }
+
+    /**
+     * Iterator used by the {@link Map} interface.
+     */
+    private final class MapIterator implements Iterator<Entry<Character, V>> {
+        private final PrimitiveIterator iter = new PrimitiveIterator();
+
+        @Override
+        public boolean hasNext() {
+            return iter.hasNext();
+        }
+
+        @Override
+        public Entry<Character, V> next() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+
+            iter.next();
+
+            return new MapEntry(iter.entryIndex);
+        }
+
+        @Override
+        public void remove() {
+            iter.remove();
+        }
+    }
+
+    /**
+     * A single entry in the map.
+     */
+    final class MapEntry implements Entry<Character, V> {
+        private final int entryIndex;
+
+        MapEntry(int entryIndex) {
+            this.entryIndex = entryIndex;
+        }
+
+        @Override
+        public Character getKey() {
+            verifyExists();
+            return keys[entryIndex];
+        }
+
+        @Override
+        public V getValue() {
+            verifyExists();
+            return toExternal(values[entryIndex]);
+        }
+
+        @Override
+        public V setValue(V value) {
+            verifyExists();
+            V prevValue = toExternal(values[entryIndex]);
+            values[entryIndex] = toInternal(value);
+            return prevValue;
+        }
+
+        // Guards against access through a stale MapEntry after its slot was removed.
+        private void verifyExists() {
+            if (values[entryIndex] == null) {
+                throw new IllegalStateException("The map entry has been removed");
+            }
+        }
+    }
+}
diff --git a/netty-util/src/main/java/io/netty/util/collection/CharObjectMap.java b/netty-util/src/main/java/io/netty/util/collection/CharObjectMap.java
new file mode 100644
index 0000000..3bbe640
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/collection/CharObjectMap.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.util.collection;
+
+import java.util.Map;
+
+/**
+ * Interface for a primitive map that uses {@code char}s as keys.
+ *
+ * @param <V> the value type stored in the map.
+ */
+public interface CharObjectMap<V> extends Map<Character, V> {
+
+    /**
+     * A primitive entry in the map, provided by the iterator from {@link #entries()}
+     *
+     * @param <V> the value type stored in the map.
+     */
+    interface PrimitiveEntry<V> {
+        /**
+         * Gets the key for this entry.
+         */
+        char key();
+
+        /**
+         * Gets the value for this entry.
+         */
+        V value();
+
+        /**
+         * Sets the value for this entry.
+         */
+        void setValue(V value);
+    }
+
+    /**
+     * Gets the value in the map with the specified key.
+     *
+     * @param key the key whose associated value is to be returned.
+     * @return the value or {@code null} if the key was not found in the map.
+     */
+    V get(char key);
+
+    /**
+     * Puts the given entry into the map.
+     *
+     * @param key the key of the entry.
+     * @param value the value of the entry.
+     * @return the previous value for this key or {@code null} if there was no previous mapping.
+     */
+    V put(char key, V value);
+
+    /**
+     * Removes the entry with the specified key.
+     *
+     * @param key the key for the entry to be removed from this map.
+     * @return the previous value for the key, or {@code null} if there was no mapping.
+     */
+    V remove(char key);
+
+    /**
+     * Gets an iterable to traverse over the primitive entries contained in this map. As an optimization,
+     * the {@link PrimitiveEntry}s returned by the {@link java.util.Iterator} may change as the
+     * {@link java.util.Iterator} progresses. The caller should not rely on {@link PrimitiveEntry} key/value stability.
+     */
+    Iterable<PrimitiveEntry<V>> entries();
+
+    /**
+     * Indicates whether or not this map contains a value for the specified key.
+     */
+    boolean containsKey(char key);
+}
diff --git a/netty-util/src/main/java/io/netty/util/collection/IntCollections.java b/netty-util/src/main/java/io/netty/util/collection/IntCollections.java
new file mode 100644
index 0000000..4a4b457
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/collection/IntCollections.java
@@ -0,0 +1,313 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.util.collection;
+
+import java.util.Collection;
+import java.util.Collections;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+/**
+ * Utilities for int-based primitive collections.
+ */
+public final class IntCollections {
+
+    private static final IntObjectMap<Object> EMPTY_MAP = new EmptyMap();
+
+    private IntCollections() {
+    }
+
+    /**
+     * Returns an unmodifiable empty {@link IntObjectMap}.
+     */
+    @SuppressWarnings("unchecked")
+    public static <V> IntObjectMap<V> emptyMap() {
+        return (IntObjectMap<V>) EMPTY_MAP;
+    }
+
+    /**
+     * Creates an unmodifiable wrapper around the given map.
+     */
+    public static <V> IntObjectMap<V> unmodifiableMap(final IntObjectMap<V> map) {
+        return new UnmodifiableMap<V>(map);
+    }
+
+    /**
+     * An empty map. All operations that attempt to modify the map are unsupported.
+     */
+    private static final class EmptyMap implements IntObjectMap<Object> {
+        @Override
+        public Object get(int key) {
+            return null;
+        }
+
+        @Override
+        public Object put(int key, Object value) {
+            throw new UnsupportedOperationException("put");
+        }
+
+        @Override
+        public Object remove(int key) {
+            return null;
+        }
+
+        @Override
+        public int size() {
+            return 0;
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return true;
+        }
+
+        @Override
+        public boolean containsKey(Object key) {
+            return false;
+        }
+
+        @Override
+        public void clear() {
+            // Do nothing.
+        }
+
+        @Override
+        public Set<Integer> keySet() {
+            return Collections.emptySet();
+        }
+
+        @Override
+        public boolean containsKey(int key) {
+            return false;
+        }
+
+        @Override
+        public boolean containsValue(Object value) {
+            return false;
+        }
+
+        @Override
+        public Iterable<PrimitiveEntry<Object>> entries() {
+            return Collections.emptySet();
+        }
+
+        @Override
+        public Object get(Object key) {
+            return null;
+        }
+
+        @Override
+        public Object put(Integer key, Object value) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Object remove(Object key) {
+            return null;
+        }
+
+        @Override
+        public void putAll(Map<? extends Integer, ?> m) {
+            throw new UnsupportedOperationException();
+        }
+
+        @Override
+        public Collection<Object> values() {
+            return Collections.emptyList();
+        }
+
+        @Override
+        public Set<Entry<Integer, Object>> entrySet() {
+            return Collections.emptySet();
+        }
+    }
+
+    /**
+     * An unmodifiable wrapper around a {@link IntObjectMap}.
+     *
+     * @param <V> the value type stored in the map.
+     */
+    private static final class UnmodifiableMap<V> implements IntObjectMap<V> {
+        private final IntObjectMap<V> map;
+        // Lazily-created, cached views of the underlying map.
+        private Set<Integer> keySet;
+        private Set<Entry<Integer, V>> entrySet;
+        private Collection<V> values;
+        private Iterable<PrimitiveEntry<V>> entries;
+
+        UnmodifiableMap(IntObjectMap<V> map) {
+            this.map = map;
+        }
+
+        @Override
+        public V get(int key) {
+            return map.get(key);
+        }
+
+        @Override
+        public V put(int key, V value) {
+            throw new UnsupportedOperationException("put");
+        }
+
+        @Override
+        public V remove(int key) {
+            throw new UnsupportedOperationException("remove");
+        }
+
+        @Override
+        public int size() {
+            return map.size();
+        }
+
+        @Override
+        public boolean isEmpty() {
+            return map.isEmpty();
+        }
+
+        @Override
+        public void clear() {
+            throw new UnsupportedOperationException("clear");
+        }
+
+        @Override
+        public boolean containsKey(int key) {
+            return map.containsKey(key);
+        }
+
+        @Override
+        public boolean containsValue(Object value) {
+            return map.containsValue(value);
+        }
+
+        @Override
+        public boolean containsKey(Object key) {
+            return map.containsKey(key);
+        }
+
+        @Override
+        public V get(Object key) {
+            return map.get(key);
+        }
+
+        @Override
+        public V put(Integer key, V value) {
+            throw new UnsupportedOperationException("put");
+        }
+
+        @Override
+        public V remove(Object key) {
+            throw new UnsupportedOperationException("remove");
+        }
+
+        @Override
+        public void putAll(Map<? extends Integer, ? extends V> m) {
+            throw new UnsupportedOperationException("putAll");
+        }
+
+        @Override
+        public Iterable<PrimitiveEntry<V>> entries() {
+            if (entries == null) {
+                entries = new Iterable<PrimitiveEntry<V>>() {
+                    @Override
+                    public Iterator<PrimitiveEntry<V>> iterator() {
+                        return new IteratorImpl(map.entries().iterator());
+                    }
+                };
+            }
+
+            return entries;
+        }
+
+        @Override
+        public Set<Integer> keySet() {
+            if (keySet == null) {
+                keySet = Collections.unmodifiableSet(map.keySet());
+            }
+            return keySet;
+        }
+
+        @Override
+        public Set<Entry<Integer, V>> entrySet() {
+            if (entrySet == null) {
+                entrySet = Collections.unmodifiableSet(map.entrySet());
+            }
+            return entrySet;
+        }
+
+        @Override
+        public Collection<V> values() {
+            if (values == null) {
+                values = Collections.unmodifiableCollection(map.values());
+            }
+            return values;
+        }
+
+        /**
+         * Unmodifiable wrapper for an iterator.
+         */
+        private class IteratorImpl implements Iterator<PrimitiveEntry<V>> {
+            final Iterator<PrimitiveEntry<V>> iter;
+
+            IteratorImpl(Iterator<PrimitiveEntry<V>> iter) {
+                this.iter = iter;
+            }
+
+            @Override
+            public boolean hasNext() {
+                return iter.hasNext();
+            }
+
+            @Override
+            public PrimitiveEntry<V> next() {
+                if (!hasNext()) {
+                    throw new NoSuchElementException();
+                }
+                return new EntryImpl(iter.next());
+            }
+
+            @Override
+            public void remove() {
+                throw new UnsupportedOperationException("remove");
+            }
+        }
+
+        /**
+         * Unmodifiable wrapper for an entry.
+         */
+        private class EntryImpl implements PrimitiveEntry<V> {
+            private final PrimitiveEntry<V> entry;
+
+            EntryImpl(PrimitiveEntry<V> entry) {
+                this.entry = entry;
+            }
+
+            @Override
+            public int key() {
+                return entry.key();
+            }
+
+            @Override
+            public V value() {
+                return entry.value();
+            }
+
+            @Override
+            public void setValue(V value) {
+                throw new UnsupportedOperationException("setValue");
+            }
+        }
+    }
+}
diff --git a/netty-util/src/main/java/io/netty/util/collection/IntObjectHashMap.java b/netty-util/src/main/java/io/netty/util/collection/IntObjectHashMap.java
new file mode 100644
index 0000000..ac90d87
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/collection/IntObjectHashMap.java
@@ -0,0 +1,723 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+
+package io.netty.util.collection;
+
+import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo;
+
+import java.util.AbstractCollection;
+import java.util.AbstractSet;
+import java.util.Arrays;
+import java.util.Collection;
+import java.util.Iterator;
+import java.util.Map;
+import java.util.NoSuchElementException;
+import java.util.Set;
+
+/**
+ * A hash map implementation of {@link IntObjectMap} that uses open addressing for keys.
+ * To minimize the memory footprint, this class uses open addressing rather than chaining.
+ * Collisions are resolved using linear probing.
Deletions implement compaction, so cost of
+ * remove can approach O(N) for full maps, which makes a small loadFactor recommended.
+ *
+ * @param <V> The value type stored in the map.
+ */
+public class IntObjectHashMap<V> implements IntObjectMap<V> {
+
+    /** Default initial capacity. Used if not specified in the constructor */
+    public static final int DEFAULT_CAPACITY = 8;
+
+    /** Default load factor. Used if not specified in the constructor */
+    public static final float DEFAULT_LOAD_FACTOR = 0.5f;
+
+    /**
+     * Placeholder for null values, so we can use the actual null to mean available.
+     * (Better than using a placeholder for available: less references for GC processing.)
+     */
+    private static final Object NULL_VALUE = new Object();
+
+    /** The maximum number of elements allowed without allocating more space. */
+    private int maxSize;
+
+    /** The load factor for the map. Used to calculate {@link #maxSize}. */
+    private final float loadFactor;
+
+    private int[] keys;
+    private V[] values;
+    private int size;
+    private int mask;
+
+    private final Set<Integer> keySet = new KeySet();
+    private final Set<Entry<Integer, V>> entrySet = new EntrySet();
+    private final Iterable<PrimitiveEntry<V>> entries = new Iterable<PrimitiveEntry<V>>() {
+        @Override
+        public Iterator<PrimitiveEntry<V>> iterator() {
+            return new PrimitiveIterator();
+        }
+    };
+
+    public IntObjectHashMap() {
+        this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR);
+    }
+
+    public IntObjectHashMap(int initialCapacity) {
+        this(initialCapacity, DEFAULT_LOAD_FACTOR);
+    }
+
+    public IntObjectHashMap(int initialCapacity, float loadFactor) {
+        if (loadFactor <= 0.0f || loadFactor > 1.0f) {
+            // Cannot exceed 1 because we can never store more than capacity elements;
+            // using a bigger loadFactor would trigger rehashing before the desired load is reached.
+            throw new IllegalArgumentException("loadFactor must be > 0 and <= 1");
+        }
+
+        this.loadFactor = loadFactor;
+
+        // Adjust the initial capacity if necessary.
+        int capacity = safeFindNextPositivePowerOfTwo(initialCapacity);
+        mask = capacity - 1;
+
+        // Allocate the arrays.
+        keys = new int[capacity];
+        @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" })
+        V[] temp = (V[]) new Object[capacity];
+        values = temp;
+
+        // Initialize the maximum size value.
+        maxSize = calcMaxSize(capacity);
+    }
+
+    private static <T> T toExternal(T value) {
+        assert value != null : "null is not a legitimate internal value. Concurrent Modification?";
+        return value == NULL_VALUE ? null : value;
+    }
+
+    @SuppressWarnings("unchecked")
+    private static <T> T toInternal(T value) {
+        return value == null ? (T) NULL_VALUE : value;
+    }
+
+    @Override
+    public V get(int key) {
+        int index = indexOf(key);
+        return index == -1 ? null : toExternal(values[index]);
+    }
+
+    @Override
+    public V put(int key, V value) {
+        int startIndex = hashIndex(key);
+        int index = startIndex;
+
+        for (;;) {
+            if (values[index] == null) {
+                // Found empty slot, use it.
+                keys[index] = key;
+                values[index] = toInternal(value);
+                growSize();
+                return null;
+            }
+            if (keys[index] == key) {
+                // Found existing entry with this key, just replace the value.
+                V previousValue = values[index];
+                values[index] = toInternal(value);
+                return toExternal(previousValue);
+            }
+
+            // Conflict, keep probing ...
+            if ((index = probeNext(index)) == startIndex) {
+                // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow.
+                throw new IllegalStateException("Unable to insert");
+            }
+        }
+    }
+
+    @Override
+    public void putAll(Map<? extends Integer, ? extends V> sourceMap) {
+        if (sourceMap instanceof IntObjectHashMap) {
+            // Optimization - iterate through the arrays.
+            @SuppressWarnings("unchecked")
+            IntObjectHashMap<? extends V> source = (IntObjectHashMap<? extends V>) sourceMap;
+            for (int i = 0; i < source.values.length; ++i) {
+                V sourceValue = source.values[i];
+                if (sourceValue != null) {
+                    put(source.keys[i], sourceValue);
+                }
+            }
+            return;
+        }
+
+        // Otherwise, just add each entry.
+        for (Entry<? extends Integer, ? extends V> entry : sourceMap.entrySet()) {
+            put(entry.getKey(), entry.getValue());
+        }
+    }
+
+    @Override
+    public V remove(int key) {
+        int index = indexOf(key);
+        if (index == -1) {
+            return null;
+        }
+
+        V prev = values[index];
+        removeAt(index);
+        return toExternal(prev);
+    }
+
+    @Override
+    public int size() {
+        return size;
+    }
+
+    @Override
+    public boolean isEmpty() {
+        return size == 0;
+    }
+
+    @Override
+    public void clear() {
+        Arrays.fill(keys, (int) 0);
+        Arrays.fill(values, null);
+        size = 0;
+    }
+
+    @Override
+    public boolean containsKey(int key) {
+        return indexOf(key) >= 0;
+    }
+
+    @Override
+    public boolean containsValue(Object value) {
+        @SuppressWarnings("unchecked")
+        V v1 = toInternal((V) value);
+        for (V v2 : values) {
+            // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE).
+            if (v2 != null && v2.equals(v1)) {
+                return true;
+            }
+        }
+        return false;
+    }
+
+    @Override
+    public Iterable<PrimitiveEntry<V>> entries() {
+        return entries;
+    }
+
+    @Override
+    public Collection<V> values() {
+        return new AbstractCollection<V>() {
+            @Override
+            public Iterator<V> iterator() {
+                return new Iterator<V>() {
+                    final PrimitiveIterator iter = new PrimitiveIterator();
+
+                    @Override
+                    public boolean hasNext() {
+                        return iter.hasNext();
+                    }
+
+                    @Override
+                    public V next() {
+                        return iter.next().value();
+                    }
+
+                    @Override
+                    public void remove() {
+                        iter.remove();
+                    }
+                };
+            }
+
+            @Override
+            public int size() {
+                return size;
+            }
+        };
+    }
+
+    @Override
+    public int hashCode() {
+        // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys
+        // array, which may have different lengths for two maps of same size(), so the
+        // capacity cannot be used as input for hashing but the size can.
+        int hash = size;
+        for (int key : keys) {
+            // 0 can be a valid key or unused slot, but won't impact the hashcode in either case.
+            // This way we can use a cheap loop without conditionals, or hard-to-unroll operations,
+            // or the devastatingly bad memory locality of visiting value objects.
+            // Also, it's important to use a hash function that does not depend on the ordering
+            // of terms, only their values; since the map is an unordered collection and
+            // entries can end up in different positions in different maps that have the same
+            // elements, but with different history of puts/removes, due to conflicts.
+            hash ^= hashCode(key);
+        }
+        return hash;
+    }
+
+    @Override
+    public boolean equals(Object obj) {
+        if (this == obj) {
+            return true;
+        }
+        if (!(obj instanceof IntObjectMap)) {
+            return false;
+        }
+        @SuppressWarnings("rawtypes")
+        IntObjectMap other = (IntObjectMap) obj;
+        if (size != other.size()) {
+            return false;
+        }
+        for (int i = 0; i < values.length; ++i) {
+            V value = values[i];
+            if (value != null) {
+                int key = keys[i];
+                Object otherValue = other.get(key);
+                if (value == NULL_VALUE) {
+                    if (otherValue != null) {
+                        return false;
+                    }
+                } else if (!value.equals(otherValue)) {
+                    return false;
+                }
+            }
+        }
+        return true;
+    }
+
+    @Override
+    public boolean containsKey(Object key) {
+        return containsKey(objectToKey(key));
+    }
+
+    @Override
+    public V get(Object key) {
+        return get(objectToKey(key));
+    }
+
+    @Override
+    public V put(Integer key, V value) {
+        return put(objectToKey(key), value);
+    }
+
+    @Override
+    public V remove(Object key) {
+        return remove(objectToKey(key));
+    }
+
+    @Override
+    public Set<Integer> keySet() {
+        return keySet;
+    }
+
+    @Override
+    public Set<Entry<Integer, V>> entrySet() {
+        return entrySet;
+    }
+
+    private int objectToKey(Object key) {
+        return (int) ((Integer) key).intValue();
+    }
+
+    /**
+     * Locates the index for the given key. This method probes using linear probing.
+     *
+     * @param key the key for an entry in the map.
+     * @return the index where the key was found, or {@code -1} if no entry is found for that key.
+     */
+    private int indexOf(int key) {
+        int startIndex = hashIndex(key);
+        int index = startIndex;
+
+        for (;;) {
+            if (values[index] == null) {
+                // It's available, so no chance that this value exists anywhere in the map.
+                return -1;
+            }
+            if (key == keys[index]) {
+                return index;
+            }
+
+            // Conflict, keep probing ...
+            if ((index = probeNext(index)) == startIndex) {
+                return -1;
+            }
+        }
+    }
+
+    /**
+     * Returns the hashed index for the given key.
+     */
+    private int hashIndex(int key) {
+        // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds.
+        return hashCode(key) & mask;
+    }
+
+    /**
+     * Returns the hash code for the key.
+     */
+    private static int hashCode(int key) {
+        return (int) key;
+    }
+
+    /**
+     * Get the next sequential index after {@code index} and wraps if necessary.
+     */
+    private int probeNext(int index) {
+        // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds.
+        return (index + 1) & mask;
+    }
+
+    /**
+     * Grows the map size after an insertion. If necessary, performs a rehash of the map.
+     */
+    private void growSize() {
+        size++;
+
+        if (size > maxSize) {
+            if (keys.length == Integer.MAX_VALUE) {
+                throw new IllegalStateException("Max capacity reached at size=" + size);
+            }
+
+            // Double the capacity.
+            rehash(keys.length << 1);
+        }
+    }
+
+    /**
+     * Removes entry at the given index position. Also performs opportunistic, incremental rehashing
+     * if necessary to not break conflict chains.
+     *
+     * @param index the index position of the element to remove.
+     * @return {@code true} if the next item was moved back. {@code false} otherwise.
+     */
+    private boolean removeAt(final int index) {
+        --size;
+        // Clearing the key is not strictly necessary (for GC like in a regular collection),
+        // but recommended for security. The memory location is still fresh in the cache anyway.
+        keys[index] = 0;
+        values[index] = null;
+
+        // In the interval from index to the next available entry, the arrays may have entries
+        // that are displaced from their base position due to prior conflicts. Iterate these
+        // entries and move them back if possible, optimizing future lookups.
+        // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap.
+
+        int nextFree = index;
+        int i = probeNext(index);
+        for (V value = values[i]; value != null; value = values[i = probeNext(i)]) {
+            int key = keys[i];
+            int bucket = hashIndex(key);
+            if (i < bucket && (bucket <= nextFree || nextFree <= i) ||
+                bucket <= nextFree && nextFree <= i) {
+                // Move the displaced entry "back" to the first available position.
+                keys[nextFree] = key;
+                values[nextFree] = value;
+                // Put the first entry after the displaced entry
+                keys[i] = 0;
+                values[i] = null;
+                nextFree = i;
+            }
+        }
+        return nextFree != index;
+    }
+
+    /**
+     * Calculates the maximum size allowed before rehashing.
+     */
+    private int calcMaxSize(int capacity) {
+        // Clip the upper bound so that there will always be at least one available slot.
+        int upperBound = capacity - 1;
+        return Math.min(upperBound, (int) (capacity * loadFactor));
+    }
+
+    /**
+     * Rehashes the map for the given capacity.
+     *
+     * @param newCapacity the new capacity for the map.
+     */
+    private void rehash(int newCapacity) {
+        int[] oldKeys = keys;
+        V[] oldVals = values;
+
+        keys = new int[newCapacity];
+        @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" })
+        V[] temp = (V[]) new Object[newCapacity];
+        values = temp;
+
+        maxSize = calcMaxSize(newCapacity);
+        mask = newCapacity - 1;
+
+        // Insert to the new arrays.
+        for (int i = 0; i < oldVals.length; ++i) {
+            V oldVal = oldVals[i];
+            if (oldVal != null) {
+                // Inlined put(), but much simpler: we don't need to worry about
+                // duplicated keys, growing/rehashing, or failing to insert.
+                int oldKey = oldKeys[i];
+                int index = hashIndex(oldKey);
+
+                for (;;) {
+                    if (values[index] == null) {
+                        keys[index] = oldKey;
+                        values[index] = oldVal;
+                        break;
+                    }
+
+                    // Conflict, keep probing. Can wrap around, but never reaches startIndex again.
+                    index = probeNext(index);
+                }
+            }
+        }
+    }
+
+    @Override
+    public String toString() {
+        if (isEmpty()) {
+            return "{}";
+        }
+        StringBuilder sb = new StringBuilder(4 * size);
+        sb.append('{');
+        boolean first = true;
+        for (int i = 0; i < values.length; ++i) {
+            V value = values[i];
+            if (value != null) {
+                if (!first) {
+                    sb.append(", ");
+                }
+                sb.append(keyToString(keys[i])).append('=').append(value == this ? "(this Map)" :
+                        toExternal(value));
+                first = false;
+            }
+        }
+        return sb.append('}').toString();
+    }
+
+    /**
+     * Helper method called by {@link #toString()} in order to convert a single map key into a string.
+     * This is protected to allow subclasses to override the appearance of a given key.
+     */
+    protected String keyToString(int key) {
+        return Integer.toString(key);
+    }
+
+    /**
+     * Set implementation for iterating over the entries of the map.
+     */
+    private final class EntrySet extends AbstractSet<Entry<Integer, V>> {
+        @Override
+        public Iterator<Entry<Integer, V>> iterator() {
+            return new MapIterator();
+        }
+
+        @Override
+        public int size() {
+            return IntObjectHashMap.this.size();
+        }
+    }
+
+    /**
+     * Set implementation for iterating over the keys.
+     */
+    private final class KeySet extends AbstractSet<Integer> {
+        @Override
+        public int size() {
+            return IntObjectHashMap.this.size();
+        }
+
+        @Override
+        public boolean contains(Object o) {
+            return IntObjectHashMap.this.containsKey(o);
+        }
+
+        @Override
+        public boolean remove(Object o) {
+            return IntObjectHashMap.this.remove(o) != null;
+        }
+
+        @Override
+        public boolean retainAll(Collection<?> retainedKeys) {
+            boolean changed = false;
+            for (Iterator<PrimitiveEntry<V>> iter = entries().iterator(); iter.hasNext();) {
+                PrimitiveEntry<V> entry = iter.next();
+                if (!retainedKeys.contains(entry.key())) {
+                    changed = true;
+                    iter.remove();
+                }
+            }
+            return changed;
+        }
+
+        @Override
+        public void clear() {
+            IntObjectHashMap.this.clear();
+        }
+
+        @Override
+        public Iterator<Integer> iterator() {
+            return new Iterator<Integer>() {
+                private final Iterator<Entry<Integer, V>> iter = entrySet.iterator();
+
+                @Override
+                public boolean hasNext() {
+                    return iter.hasNext();
+                }
+
+                @Override
+                public Integer next() {
+                    return iter.next().getKey();
+                }
+
+                @Override
+                public void remove() {
+                    iter.remove();
+                }
+            };
+        }
+    }
+
+    /**
+     * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link #next()}.
+     */
+    private final class PrimitiveIterator implements Iterator<PrimitiveEntry<V>>, PrimitiveEntry<V> {
+        private int prevIndex = -1;
+        private int nextIndex = -1;
+        private int entryIndex = -1;
+
+        private void scanNext() {
+            while (++nextIndex != values.length && values[nextIndex] == null) {
+            }
+        }
+
+        @Override
+        public boolean hasNext() {
+            if (nextIndex == -1) {
+                scanNext();
+            }
+            return nextIndex != values.length;
+        }
+
+        @Override
+        public PrimitiveEntry<V> next() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+
+            prevIndex = nextIndex;
+            scanNext();
+
+            // Always return the same Entry object, just change its index each time.
+            entryIndex = prevIndex;
+            return this;
+        }
+
+        @Override
+        public void remove() {
+            if (prevIndex == -1) {
+                throw new IllegalStateException("next must be called before each remove.");
+            }
+            if (removeAt(prevIndex)) {
+                // removeAt may move elements "back" in the array if they have been displaced because their spot in the
+                // array was occupied when they were inserted. If this occurs then the nextIndex is now invalid and
+                // should instead point to the prevIndex which now holds an element which was "moved back".
+                nextIndex = prevIndex;
+            }
+            prevIndex = -1;
+        }
+
+        // Entry implementation. Since this implementation uses a single Entry, we coalesce that
+        // into the Iterator object (potentially making loop optimization much easier).
+
+        @Override
+        public int key() {
+            return keys[entryIndex];
+        }
+
+        @Override
+        public V value() {
+            return toExternal(values[entryIndex]);
+        }
+
+        @Override
+        public void setValue(V value) {
+            values[entryIndex] = toInternal(value);
+        }
+    }
+
+    /**
+     * Iterator used by the {@link Map} interface.
+     */
+    private final class MapIterator implements Iterator<Entry<Integer, V>> {
+        private final PrimitiveIterator iter = new PrimitiveIterator();
+
+        @Override
+        public boolean hasNext() {
+            return iter.hasNext();
+        }
+
+        @Override
+        public Entry<Integer, V> next() {
+            if (!hasNext()) {
+                throw new NoSuchElementException();
+            }
+
+            iter.next();
+
+            return new MapEntry(iter.entryIndex);
+        }
+
+        @Override
+        public void remove() {
+            iter.remove();
+        }
+    }
+
+    /**
+     * A single entry in the map.
+     */
+    final class MapEntry implements Entry<Integer, V> {
+        private final int entryIndex;
+
+        MapEntry(int entryIndex) {
+            this.entryIndex = entryIndex;
+        }
+
+        @Override
+        public Integer getKey() {
+            verifyExists();
+            return keys[entryIndex];
+        }
+
+        @Override
+        public V getValue() {
+            verifyExists();
+            return toExternal(values[entryIndex]);
+        }
+
+        @Override
+        public V setValue(V value) {
+            verifyExists();
+            V prevValue = toExternal(values[entryIndex]);
+            values[entryIndex] = toInternal(value);
+            return prevValue;
+        }
+
+        // Guards against access through a stale MapEntry after its slot was removed.
+        private void verifyExists() {
+            if (values[entryIndex] == null) {
+                throw new IllegalStateException("The map entry has been removed");
+            }
+        }
+    }
+}
diff --git a/netty-util/src/main/java/io/netty/util/collection/IntObjectMap.java b/netty-util/src/main/java/io/netty/util/collection/IntObjectMap.java
new file mode 100644
index 0000000..7da2840
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/collection/IntObjectMap.java
@@ -0,0 +1,84 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License. You may obtain a
+ * copy of the License at:
+ *
+ * https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software distributed under the License
+ * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express
+ * or implied. See the License for the specific language governing permissions and limitations under
+ * the License.
+ */
+package io.netty.util.collection;
+
+import java.util.Map;
+
+/**
+ * Interface for a primitive map that uses {@code int}s as keys.
+ *
+ * @param <V> the value type stored in the map.
+ */
+public interface IntObjectMap<V> extends Map<Integer, V> {
+
+    /**
+     * A primitive entry in the map, provided by the iterator from {@link #entries()}
+     *
+     * @param <V> the value type stored in the map.
+     */
+    interface PrimitiveEntry<V> {
+        /**
+         * Gets the key for this entry.
+         */
+        int key();
+
+        /**
+         * Gets the value for this entry.
+         */
+        V value();
+
+        /**
+         * Sets the value for this entry.
+         */
+        void setValue(V value);
+    }
+
+    /**
+     * Gets the value in the map with the specified key.
+     *
+     * @param key the key whose associated value is to be returned.
+     * @return the value or {@code null} if the key was not found in the map.
+     */
+    V get(int key);
+
+    /**
+     * Puts the given entry into the map.
+     *
+     * @param key the key of the entry.
+     * @param value the value of the entry.
+     * @return the previous value for this key or {@code null} if there was no previous mapping.
+     */
+    V put(int key, V value);
+
+    /**
+     * Removes the entry with the specified key.
+     *
+     * @param key the key for the entry to be removed from this map.
+     * @return the previous value for the key, or {@code null} if there was no mapping.
+     */
+    V remove(int key);
+
+    /**
+     * Gets an iterable to traverse over the primitive entries contained in this map. As an optimization,
+     * the {@link PrimitiveEntry}s returned by the {@link java.util.Iterator} may change as the
+     * {@link java.util.Iterator} progresses. The caller should not rely on {@link PrimitiveEntry} key/value stability.
+     */
+    Iterable<PrimitiveEntry<V>> entries();
+
+    /**
+     * Indicates whether or not this map contains a value for the specified key.
+     */
+    boolean containsKey(int key);
+}
diff --git a/netty-util/src/main/java/io/netty/util/collection/LongCollections.java b/netty-util/src/main/java/io/netty/util/collection/LongCollections.java
new file mode 100644
index 0000000..2815aaa
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/collection/LongCollections.java
@@ -0,0 +1,313 @@
+/*
+ * Copyright 2014 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License, version 2.0 (the
+ * "License"); you may not use this file except in compliance with the License.
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * Utilities for long-based primitive collections. + */ +public final class LongCollections { + + private static final LongObjectMap EMPTY_MAP = new EmptyMap(); + + private LongCollections() { + } + + /** + * Returns an unmodifiable empty {@link LongObjectMap}. + */ + @SuppressWarnings("unchecked") + public static LongObjectMap emptyMap() { + return (LongObjectMap) EMPTY_MAP; + } + + /** + * Creates an unmodifiable wrapper around the given map. + */ + public static LongObjectMap unmodifiableMap(final LongObjectMap map) { + return new UnmodifiableMap(map); + } + + /** + * An empty map. All operations that attempt to modify the map are unsupported. + */ + private static final class EmptyMap implements LongObjectMap { + @Override + public Object get(long key) { + return null; + } + + @Override + public Object put(long key, Object value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public Object remove(long key) { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public boolean containsKey(Object key) { + return false; + } + + @Override + public void clear() { + // Do nothing. 
+ } + + @Override + public Set keySet() { + return Collections.emptySet(); + } + + @Override + public boolean containsKey(long key) { + return false; + } + + @Override + public boolean containsValue(Object value) { + return false; + } + + @Override + public Iterable> entries() { + return Collections.emptySet(); + } + + @Override + public Object get(Object key) { + return null; + } + + @Override + public Object put(Long key, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public Object remove(Object key) { + return null; + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException(); + } + + @Override + public Collection values() { + return Collections.emptyList(); + } + + @Override + public Set> entrySet() { + return Collections.emptySet(); + } + } + + /** + * An unmodifiable wrapper around a {@link LongObjectMap}. + * + * @param the value type stored in the map. + */ + private static final class UnmodifiableMap implements LongObjectMap { + private final LongObjectMap map; + private Set keySet; + private Set> entrySet; + private Collection values; + private Iterable> entries; + + UnmodifiableMap(LongObjectMap map) { + this.map = map; + } + + @Override + public V get(long key) { + return map.get(key); + } + + @Override + public V put(long key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(long key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean isEmpty() { + return map.isEmpty(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("clear"); + } + + @Override + public boolean containsKey(long key) { + return map.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return map.containsValue(value); + } + + @Override + public boolean containsKey(Object key) { + return map.containsKey(key); + } 
+ + @Override + public V get(Object key) { + return map.get(key); + } + + @Override + public V put(Long key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(Object key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("putAll"); + } + + @Override + public Iterable> entries() { + if (entries == null) { + entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new IteratorImpl(map.entries().iterator()); + } + }; + } + + return entries; + } + + @Override + public Set keySet() { + if (keySet == null) { + keySet = Collections.unmodifiableSet(map.keySet()); + } + return keySet; + } + + @Override + public Set> entrySet() { + if (entrySet == null) { + entrySet = Collections.unmodifiableSet(map.entrySet()); + } + return entrySet; + } + + @Override + public Collection values() { + if (values == null) { + values = Collections.unmodifiableCollection(map.values()); + } + return values; + } + + /** + * Unmodifiable wrapper for an iterator. + */ + private class IteratorImpl implements Iterator> { + final Iterator> iter; + + IteratorImpl(Iterator> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return new EntryImpl(iter.next()); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + } + + /** + * Unmodifiable wrapper for an entry. 
+ */ + private class EntryImpl implements PrimitiveEntry { + private final PrimitiveEntry entry; + + EntryImpl(PrimitiveEntry entry) { + this.entry = entry; + } + + @Override + public long key() { + return entry.key(); + } + + @Override + public V value() { + return entry.value(); + } + + @Override + public void setValue(V value) { + throw new UnsupportedOperationException("setValue"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/LongObjectHashMap.java b/netty-util/src/main/java/io/netty/util/collection/LongObjectHashMap.java new file mode 100644 index 0000000..6ae8696 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/LongObjectHashMap.java @@ -0,0 +1,723 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.util.collection; + +import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; + +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * A hash map implementation of {@link LongObjectMap} that uses open addressing for keys. + * To minimize the memory footprint, this class uses open addressing rather than chaining. + * Collisions are resolved using linear probing. 
Deletions implement compaction, so cost of + * remove can approach O(N) for full maps, which makes a small loadFactor recommended. + * + * @param The value type stored in the map. + */ +public class LongObjectHashMap implements LongObjectMap { + + /** Default initial capacity. Used if not specified in the constructor */ + public static final int DEFAULT_CAPACITY = 8; + + /** Default load factor. Used if not specified in the constructor */ + public static final float DEFAULT_LOAD_FACTOR = 0.5f; + + /** + * Placeholder for null values, so we can use the actual null to mean available. + * (Better than using a placeholder for available: less references for GC processing.) + */ + private static final Object NULL_VALUE = new Object(); + + /** The maximum number of elements allowed without allocating more space. */ + private int maxSize; + + /** The load factor for the map. Used to calculate {@link #maxSize}. */ + private final float loadFactor; + + private long[] keys; + private V[] values; + private int size; + private int mask; + + private final Set keySet = new KeySet(); + private final Set> entrySet = new EntrySet(); + private final Iterable> entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new PrimitiveIterator(); + } + }; + + public LongObjectHashMap() { + this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); + } + + public LongObjectHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + public LongObjectHashMap(int initialCapacity, float loadFactor) { + if (loadFactor <= 0.0f || loadFactor > 1.0f) { + // Cannot exceed 1 because we can never store more than capacity elements; + // using a bigger loadFactor would trigger rehashing before the desired load is reached. + throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); + } + + this.loadFactor = loadFactor; + + // Adjust the initial capacity if necessary. 
+ int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); + mask = capacity - 1; + + // Allocate the arrays. + keys = new long[capacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[capacity]; + values = temp; + + // Initialize the maximum size value. + maxSize = calcMaxSize(capacity); + } + + private static T toExternal(T value) { + assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; + return value == NULL_VALUE ? null : value; + } + + @SuppressWarnings("unchecked") + private static T toInternal(T value) { + return value == null ? (T) NULL_VALUE : value; + } + + @Override + public V get(long key) { + int index = indexOf(key); + return index == -1 ? null : toExternal(values[index]); + } + + @Override + public V put(long key, V value) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // Found empty slot, use it. + keys[index] = key; + values[index] = toInternal(value); + growSize(); + return null; + } + if (keys[index] == key) { + // Found existing entry with this key, just replace the value. + V previousValue = values[index]; + values[index] = toInternal(value); + return toExternal(previousValue); + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. + throw new IllegalStateException("Unable to insert"); + } + } + } + + @Override + public void putAll(Map sourceMap) { + if (sourceMap instanceof LongObjectHashMap) { + // Optimization - iterate through the arrays. + @SuppressWarnings("unchecked") + LongObjectHashMap source = (LongObjectHashMap) sourceMap; + for (int i = 0; i < source.values.length; ++i) { + V sourceValue = source.values[i]; + if (sourceValue != null) { + put(source.keys[i], sourceValue); + } + } + return; + } + + // Otherwise, just add each entry. 
+ for (Entry entry : sourceMap.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public V remove(long key) { + int index = indexOf(key); + if (index == -1) { + return null; + } + + V prev = values[index]; + removeAt(index); + return toExternal(prev); + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public void clear() { + Arrays.fill(keys, (long) 0); + Arrays.fill(values, null); + size = 0; + } + + @Override + public boolean containsKey(long key) { + return indexOf(key) >= 0; + } + + @Override + public boolean containsValue(Object value) { + @SuppressWarnings("unchecked") + V v1 = toInternal((V) value); + for (V v2 : values) { + // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). + if (v2 != null && v2.equals(v1)) { + return true; + } + } + return false; + } + + @Override + public Iterable> entries() { + return entries; + } + + @Override + public Collection values() { + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new Iterator() { + final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public V next() { + return iter.next().value(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + + @Override + public int size() { + return size; + } + }; + } + + @Override + public int hashCode() { + // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys + // array, which may have different lengths for two maps of same size(), so the + // capacity cannot be used as input for hashing but the size can. + int hash = size; + for (long key : keys) { + // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. 
+ // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, + // or the devastatingly bad memory locality of visiting value objects. + // Also, it's important to use a hash function that does not depend on the ordering + // of terms, only their values; since the map is an unordered collection and + // entries can end up in different positions in different maps that have the same + // elements, but with different history of puts/removes, due to conflicts. + hash ^= hashCode(key); + } + return hash; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof LongObjectMap)) { + return false; + } + @SuppressWarnings("rawtypes") + LongObjectMap other = (LongObjectMap) obj; + if (size != other.size()) { + return false; + } + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + long key = keys[i]; + Object otherValue = other.get(key); + if (value == NULL_VALUE) { + if (otherValue != null) { + return false; + } + } else if (!value.equals(otherValue)) { + return false; + } + } + } + return true; + } + + @Override + public boolean containsKey(Object key) { + return containsKey(objectToKey(key)); + } + + @Override + public V get(Object key) { + return get(objectToKey(key)); + } + + @Override + public V put(Long key, V value) { + return put(objectToKey(key), value); + } + + @Override + public V remove(Object key) { + return remove(objectToKey(key)); + } + + @Override + public Set keySet() { + return keySet; + } + + @Override + public Set> entrySet() { + return entrySet; + } + + private long objectToKey(Object key) { + return (long) ((Long) key).longValue(); + } + + /** + * Locates the index for the given key. This method probes using double hashing. + * + * @param key the key for an entry in the map. + * @return the index where the key was found, or {@code -1} if no entry is found for that key. 
+ */ + private int indexOf(long key) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // It's available, so no chance that this value exists anywhere in the map. + return -1; + } + if (key == keys[index]) { + return index; + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + return -1; + } + } + } + + /** + * Returns the hashed index for the given key. + */ + private int hashIndex(long key) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return hashCode(key) & mask; + } + + /** + * Returns the hash code for the key. + */ + private static int hashCode(long key) { + return (int) (key ^ (key >>> 32)); + } + + /** + * Get the next sequential index after {@code index} and wraps if necessary. + */ + private int probeNext(int index) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return (index + 1) & mask; + } + + /** + * Grows the map size after an insertion. If necessary, performs a rehash of the map. + */ + private void growSize() { + size++; + + if (size > maxSize) { + if(keys.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Max capacity reached at size=" + size); + } + + // Double the capacity. + rehash(keys.length << 1); + } + } + + /** + * Removes entry at the given index position. Also performs opportunistic, incremental rehashing + * if necessary to not break conflict chains. + * + * @param index the index position of the element to remove. + * @return {@code true} if the next item was moved back. {@code false} otherwise. + */ + private boolean removeAt(final int index) { + --size; + // Clearing the key is not strictly necessary (for GC like in a regular collection), + // but recommended for security. The memory location is still fresh in the cache anyway. 
+ keys[index] = 0; + values[index] = null; + + // In the interval from index to the next available entry, the arrays may have entries + // that are displaced from their base position due to prior conflicts. Iterate these + // entries and move them back if possible, optimizing future lookups. + // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap. + + int nextFree = index; + int i = probeNext(index); + for (V value = values[i]; value != null; value = values[i = probeNext(i)]) { + long key = keys[i]; + int bucket = hashIndex(key); + if (i < bucket && (bucket <= nextFree || nextFree <= i) || + bucket <= nextFree && nextFree <= i) { + // Move the displaced entry "back" to the first available position. + keys[nextFree] = key; + values[nextFree] = value; + // Put the first entry after the displaced entry + keys[i] = 0; + values[i] = null; + nextFree = i; + } + } + return nextFree != index; + } + + /** + * Calculates the maximum size allowed before rehashing. + */ + private int calcMaxSize(int capacity) { + // Clip the upper bound so that there will always be at least one available slot. + int upperBound = capacity - 1; + return Math.min(upperBound, (int) (capacity * loadFactor)); + } + + /** + * Rehashes the map for the given capacity. + * + * @param newCapacity the new capacity for the map. + */ + private void rehash(int newCapacity) { + long[] oldKeys = keys; + V[] oldVals = values; + + keys = new long[newCapacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[newCapacity]; + values = temp; + + maxSize = calcMaxSize(newCapacity); + mask = newCapacity - 1; + + // Insert to the new arrays. + for (int i = 0; i < oldVals.length; ++i) { + V oldVal = oldVals[i]; + if (oldVal != null) { + // Inlined put(), but much simpler: we don't need to worry about + // duplicated keys, growing/rehashing, or failing to insert. 
+ long oldKey = oldKeys[i]; + int index = hashIndex(oldKey); + + for (;;) { + if (values[index] == null) { + keys[index] = oldKey; + values[index] = oldVal; + break; + } + + // Conflict, keep probing. Can wrap around, but never reaches startIndex again. + index = probeNext(index); + } + } + } + } + + @Override + public String toString() { + if (isEmpty()) { + return "{}"; + } + StringBuilder sb = new StringBuilder(4 * size); + sb.append('{'); + boolean first = true; + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + if (!first) { + sb.append(", "); + } + sb.append(keyToString(keys[i])).append('=').append(value == this ? "(this Map)" : + toExternal(value)); + first = false; + } + } + return sb.append('}').toString(); + } + + /** + * Helper method called by {@link #toString()} in order to convert a single map key into a string. + * This is protected to allow subclasses to override the appearance of a given key. + */ + protected String keyToString(long key) { + return Long.toString(key); + } + + /** + * Set implementation for iterating over the entries of the map. + */ + private final class EntrySet extends AbstractSet> { + @Override + public Iterator> iterator() { + return new MapIterator(); + } + + @Override + public int size() { + return LongObjectHashMap.this.size(); + } + } + + /** + * Set implementation for iterating over the keys. 
+ */ + private final class KeySet extends AbstractSet { + @Override + public int size() { + return LongObjectHashMap.this.size(); + } + + @Override + public boolean contains(Object o) { + return LongObjectHashMap.this.containsKey(o); + } + + @Override + public boolean remove(Object o) { + return LongObjectHashMap.this.remove(o) != null; + } + + @Override + public boolean retainAll(Collection retainedKeys) { + boolean changed = false; + for(Iterator> iter = entries().iterator(); iter.hasNext(); ) { + PrimitiveEntry entry = iter.next(); + if (!retainedKeys.contains(entry.key())) { + changed = true; + iter.remove(); + } + } + return changed; + } + + @Override + public void clear() { + LongObjectHashMap.this.clear(); + } + + @Override + public Iterator iterator() { + return new Iterator() { + private final Iterator> iter = entrySet.iterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Long next() { + return iter.next().getKey(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + } + + /** + * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link #next()}. + */ + private final class PrimitiveIterator implements Iterator>, PrimitiveEntry { + private int prevIndex = -1; + private int nextIndex = -1; + private int entryIndex = -1; + + private void scanNext() { + while (++nextIndex != values.length && values[nextIndex] == null) { + } + } + + @Override + public boolean hasNext() { + if (nextIndex == -1) { + scanNext(); + } + return nextIndex != values.length; + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + prevIndex = nextIndex; + scanNext(); + + // Always return the same Entry object, just change its index each time. 
+ entryIndex = prevIndex; + return this; + } + + @Override + public void remove() { + if (prevIndex == -1) { + throw new IllegalStateException("next must be called before each remove."); + } + if (removeAt(prevIndex)) { + // removeAt may move elements "back" in the array if they have been displaced because their spot in the + // array was occupied when they were inserted. If this occurs then the nextIndex is now invalid and + // should instead point to the prevIndex which now holds an element which was "moved back". + nextIndex = prevIndex; + } + prevIndex = -1; + } + + // Entry implementation. Since this implementation uses a single Entry, we coalesce that + // into the Iterator object (potentially making loop optimization much easier). + + @Override + public long key() { + return keys[entryIndex]; + } + + @Override + public V value() { + return toExternal(values[entryIndex]); + } + + @Override + public void setValue(V value) { + values[entryIndex] = toInternal(value); + } + } + + /** + * Iterator used by the {@link Map} interface. + */ + private final class MapIterator implements Iterator> { + private final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Entry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + iter.next(); + + return new MapEntry(iter.entryIndex); + } + + @Override + public void remove() { + iter.remove(); + } + } + + /** + * A single entry in the map. 
+ */ + final class MapEntry implements Entry { + private final int entryIndex; + + MapEntry(int entryIndex) { + this.entryIndex = entryIndex; + } + + @Override + public Long getKey() { + verifyExists(); + return keys[entryIndex]; + } + + @Override + public V getValue() { + verifyExists(); + return toExternal(values[entryIndex]); + } + + @Override + public V setValue(V value) { + verifyExists(); + V prevValue = toExternal(values[entryIndex]); + values[entryIndex] = toInternal(value); + return prevValue; + } + + private void verifyExists() { + if (values[entryIndex] == null) { + throw new IllegalStateException("The map entry has been removed"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/LongObjectMap.java b/netty-util/src/main/java/io/netty/util/collection/LongObjectMap.java new file mode 100644 index 0000000..6c23583 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/LongObjectMap.java @@ -0,0 +1,84 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Map; + +/** + * Interface for a primitive map that uses {@code long}s as keys. + * + * @param the value type stored in the map. + */ +public interface LongObjectMap extends Map { + + /** + * A primitive entry in the map, provided by the iterator from {@link #entries()} + * + * @param the value type stored in the map. 
+ */ + interface PrimitiveEntry { + /** + * Gets the key for this entry. + */ + long key(); + + /** + * Gets the value for this entry. + */ + V value(); + + /** + * Sets the value for this entry. + */ + void setValue(V value); + } + + /** + * Gets the value in the map with the specified key. + * + * @param key the key whose associated value is to be returned. + * @return the value or {@code null} if the key was not found in the map. + */ + V get(long key); + + /** + * Puts the given entry into the map. + * + * @param key the key of the entry. + * @param value the value of the entry. + * @return the previous value for this key or {@code null} if there was no previous mapping. + */ + V put(long key, V value); + + /** + * Removes the entry with the specified key. + * + * @param key the key for the entry to be removed from this map. + * @return the previous value for the key, or {@code null} if there was no mapping. + */ + V remove(long key); + + /** + * Gets an iterable to traverse over the primitive entries contained in this map. As an optimization, + * the {@link PrimitiveEntry}s returned by the {@link java.util.Iterator} may change as the {@link java.util.Iterator} + * progresses. The caller should not rely on {@link PrimitiveEntry} key/value stability. + */ + Iterable> entries(); + + /** + * Indicates whether or not this map contains a value for the specified key. + */ + boolean containsKey(long key); +} diff --git a/netty-util/src/main/java/io/netty/util/collection/ShortCollections.java b/netty-util/src/main/java/io/netty/util/collection/ShortCollections.java new file mode 100644 index 0000000..f4ec3fb --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/ShortCollections.java @@ -0,0 +1,313 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * Utilities for short-based primitive collections. + */ +public final class ShortCollections { + + private static final ShortObjectMap EMPTY_MAP = new EmptyMap(); + + private ShortCollections() { + } + + /** + * Returns an unmodifiable empty {@link ShortObjectMap}. + */ + @SuppressWarnings("unchecked") + public static ShortObjectMap emptyMap() { + return (ShortObjectMap) EMPTY_MAP; + } + + /** + * Creates an unmodifiable wrapper around the given map. + */ + public static ShortObjectMap unmodifiableMap(final ShortObjectMap map) { + return new UnmodifiableMap(map); + } + + /** + * An empty map. All operations that attempt to modify the map are unsupported. + */ + private static final class EmptyMap implements ShortObjectMap { + @Override + public Object get(short key) { + return null; + } + + @Override + public Object put(short key, Object value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public Object remove(short key) { + return null; + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public boolean containsKey(Object key) { + return false; + } + + @Override + public void clear() { + // Do nothing. 
+ } + + @Override + public Set keySet() { + return Collections.emptySet(); + } + + @Override + public boolean containsKey(short key) { + return false; + } + + @Override + public boolean containsValue(Object value) { + return false; + } + + @Override + public Iterable> entries() { + return Collections.emptySet(); + } + + @Override + public Object get(Object key) { + return null; + } + + @Override + public Object put(Short key, Object value) { + throw new UnsupportedOperationException(); + } + + @Override + public Object remove(Object key) { + return null; + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException(); + } + + @Override + public Collection values() { + return Collections.emptyList(); + } + + @Override + public Set> entrySet() { + return Collections.emptySet(); + } + } + + /** + * An unmodifiable wrapper around a {@link ShortObjectMap}. + * + * @param the value type stored in the map. + */ + private static final class UnmodifiableMap implements ShortObjectMap { + private final ShortObjectMap map; + private Set keySet; + private Set> entrySet; + private Collection values; + private Iterable> entries; + + UnmodifiableMap(ShortObjectMap map) { + this.map = map; + } + + @Override + public V get(short key) { + return map.get(key); + } + + @Override + public V put(short key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(short key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean isEmpty() { + return map.isEmpty(); + } + + @Override + public void clear() { + throw new UnsupportedOperationException("clear"); + } + + @Override + public boolean containsKey(short key) { + return map.containsKey(key); + } + + @Override + public boolean containsValue(Object value) { + return map.containsValue(value); + } + + @Override + public boolean containsKey(Object key) { + return 
map.containsKey(key); + } + + @Override + public V get(Object key) { + return map.get(key); + } + + @Override + public V put(Short key, V value) { + throw new UnsupportedOperationException("put"); + } + + @Override + public V remove(Object key) { + throw new UnsupportedOperationException("remove"); + } + + @Override + public void putAll(Map m) { + throw new UnsupportedOperationException("putAll"); + } + + @Override + public Iterable> entries() { + if (entries == null) { + entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new IteratorImpl(map.entries().iterator()); + } + }; + } + + return entries; + } + + @Override + public Set keySet() { + if (keySet == null) { + keySet = Collections.unmodifiableSet(map.keySet()); + } + return keySet; + } + + @Override + public Set> entrySet() { + if (entrySet == null) { + entrySet = Collections.unmodifiableSet(map.entrySet()); + } + return entrySet; + } + + @Override + public Collection values() { + if (values == null) { + values = Collections.unmodifiableCollection(map.values()); + } + return values; + } + + /** + * Unmodifiable wrapper for an iterator. + */ + private class IteratorImpl implements Iterator> { + final Iterator> iter; + + IteratorImpl(Iterator> iter) { + this.iter = iter; + } + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + return new EntryImpl(iter.next()); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + } + + /** + * Unmodifiable wrapper for an entry. 
+ */ + private class EntryImpl implements PrimitiveEntry { + private final PrimitiveEntry entry; + + EntryImpl(PrimitiveEntry entry) { + this.entry = entry; + } + + @Override + public short key() { + return entry.key(); + } + + @Override + public V value() { + return entry.value(); + } + + @Override + public void setValue(V value) { + throw new UnsupportedOperationException("setValue"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/ShortObjectHashMap.java b/netty-util/src/main/java/io/netty/util/collection/ShortObjectHashMap.java new file mode 100644 index 0000000..dc53e6d --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/ShortObjectHashMap.java @@ -0,0 +1,723 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ + +package io.netty.util.collection; + +import static io.netty.util.internal.MathUtil.safeFindNextPositivePowerOfTwo; + +import java.util.AbstractCollection; +import java.util.AbstractSet; +import java.util.Arrays; +import java.util.Collection; +import java.util.Iterator; +import java.util.Map; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * A hash map implementation of {@link ShortObjectMap} that uses open addressing for keys. + * To minimize the memory footprint, this class uses open addressing rather than chaining. + * Collisions are resolved using linear probing. 
Deletions implement compaction, so cost of + * remove can approach O(N) for full maps, which makes a small loadFactor recommended. + * + * @param The value type stored in the map. + */ +public class ShortObjectHashMap implements ShortObjectMap { + + /** Default initial capacity. Used if not specified in the constructor */ + public static final int DEFAULT_CAPACITY = 8; + + /** Default load factor. Used if not specified in the constructor */ + public static final float DEFAULT_LOAD_FACTOR = 0.5f; + + /** + * Placeholder for null values, so we can use the actual null to mean available. + * (Better than using a placeholder for available: less references for GC processing.) + */ + private static final Object NULL_VALUE = new Object(); + + /** The maximum number of elements allowed without allocating more space. */ + private int maxSize; + + /** The load factor for the map. Used to calculate {@link #maxSize}. */ + private final float loadFactor; + + private short[] keys; + private V[] values; + private int size; + private int mask; + + private final Set keySet = new KeySet(); + private final Set> entrySet = new EntrySet(); + private final Iterable> entries = new Iterable>() { + @Override + public Iterator> iterator() { + return new PrimitiveIterator(); + } + }; + + public ShortObjectHashMap() { + this(DEFAULT_CAPACITY, DEFAULT_LOAD_FACTOR); + } + + public ShortObjectHashMap(int initialCapacity) { + this(initialCapacity, DEFAULT_LOAD_FACTOR); + } + + public ShortObjectHashMap(int initialCapacity, float loadFactor) { + if (loadFactor <= 0.0f || loadFactor > 1.0f) { + // Cannot exceed 1 because we can never store more than capacity elements; + // using a bigger loadFactor would trigger rehashing before the desired load is reached. + throw new IllegalArgumentException("loadFactor must be > 0 and <= 1"); + } + + this.loadFactor = loadFactor; + + // Adjust the initial capacity if necessary. 
+ int capacity = safeFindNextPositivePowerOfTwo(initialCapacity); + mask = capacity - 1; + + // Allocate the arrays. + keys = new short[capacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[capacity]; + values = temp; + + // Initialize the maximum size value. + maxSize = calcMaxSize(capacity); + } + + private static T toExternal(T value) { + assert value != null : "null is not a legitimate internal value. Concurrent Modification?"; + return value == NULL_VALUE ? null : value; + } + + @SuppressWarnings("unchecked") + private static T toInternal(T value) { + return value == null ? (T) NULL_VALUE : value; + } + + @Override + public V get(short key) { + int index = indexOf(key); + return index == -1 ? null : toExternal(values[index]); + } + + @Override + public V put(short key, V value) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // Found empty slot, use it. + keys[index] = key; + values[index] = toInternal(value); + growSize(); + return null; + } + if (keys[index] == key) { + // Found existing entry with this key, just replace the value. + V previousValue = values[index]; + values[index] = toInternal(value); + return toExternal(previousValue); + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + // Can only happen if the map was full at MAX_ARRAY_SIZE and couldn't grow. + throw new IllegalStateException("Unable to insert"); + } + } + } + + @Override + public void putAll(Map sourceMap) { + if (sourceMap instanceof ShortObjectHashMap) { + // Optimization - iterate through the arrays. + @SuppressWarnings("unchecked") + ShortObjectHashMap source = (ShortObjectHashMap) sourceMap; + for (int i = 0; i < source.values.length; ++i) { + V sourceValue = source.values[i]; + if (sourceValue != null) { + put(source.keys[i], sourceValue); + } + } + return; + } + + // Otherwise, just add each entry. 
+ for (Entry entry : sourceMap.entrySet()) { + put(entry.getKey(), entry.getValue()); + } + } + + @Override + public V remove(short key) { + int index = indexOf(key); + if (index == -1) { + return null; + } + + V prev = values[index]; + removeAt(index); + return toExternal(prev); + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public void clear() { + Arrays.fill(keys, (short) 0); + Arrays.fill(values, null); + size = 0; + } + + @Override + public boolean containsKey(short key) { + return indexOf(key) >= 0; + } + + @Override + public boolean containsValue(Object value) { + @SuppressWarnings("unchecked") + V v1 = toInternal((V) value); + for (V v2 : values) { + // The map supports null values; this will be matched as NULL_VALUE.equals(NULL_VALUE). + if (v2 != null && v2.equals(v1)) { + return true; + } + } + return false; + } + + @Override + public Iterable> entries() { + return entries; + } + + @Override + public Collection values() { + return new AbstractCollection() { + @Override + public Iterator iterator() { + return new Iterator() { + final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public V next() { + return iter.next().value(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + + @Override + public int size() { + return size; + } + }; + } + + @Override + public int hashCode() { + // Hashcode is based on all non-zero, valid keys. We have to scan the whole keys + // array, which may have different lengths for two maps of same size(), so the + // capacity cannot be used as input for hashing but the size can. + int hash = size; + for (short key : keys) { + // 0 can be a valid key or unused slot, but won't impact the hashcode in either case. 
+ // This way we can use a cheap loop without conditionals, or hard-to-unroll operations, + // or the devastatingly bad memory locality of visiting value objects. + // Also, it's important to use a hash function that does not depend on the ordering + // of terms, only their values; since the map is an unordered collection and + // entries can end up in different positions in different maps that have the same + // elements, but with different history of puts/removes, due to conflicts. + hash ^= hashCode(key); + } + return hash; + } + + @Override + public boolean equals(Object obj) { + if (this == obj) { + return true; + } + if (!(obj instanceof ShortObjectMap)) { + return false; + } + @SuppressWarnings("rawtypes") + ShortObjectMap other = (ShortObjectMap) obj; + if (size != other.size()) { + return false; + } + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + short key = keys[i]; + Object otherValue = other.get(key); + if (value == NULL_VALUE) { + if (otherValue != null) { + return false; + } + } else if (!value.equals(otherValue)) { + return false; + } + } + } + return true; + } + + @Override + public boolean containsKey(Object key) { + return containsKey(objectToKey(key)); + } + + @Override + public V get(Object key) { + return get(objectToKey(key)); + } + + @Override + public V put(Short key, V value) { + return put(objectToKey(key), value); + } + + @Override + public V remove(Object key) { + return remove(objectToKey(key)); + } + + @Override + public Set keySet() { + return keySet; + } + + @Override + public Set> entrySet() { + return entrySet; + } + + private short objectToKey(Object key) { + return (short) ((Short) key).shortValue(); + } + + /** + * Locates the index for the given key. This method probes using double hashing. + * + * @param key the key for an entry in the map. + * @return the index where the key was found, or {@code -1} if no entry is found for that key. 
+ */ + private int indexOf(short key) { + int startIndex = hashIndex(key); + int index = startIndex; + + for (;;) { + if (values[index] == null) { + // It's available, so no chance that this value exists anywhere in the map. + return -1; + } + if (key == keys[index]) { + return index; + } + + // Conflict, keep probing ... + if ((index = probeNext(index)) == startIndex) { + return -1; + } + } + } + + /** + * Returns the hashed index for the given key. + */ + private int hashIndex(short key) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return hashCode(key) & mask; + } + + /** + * Returns the hash code for the key. + */ + private static int hashCode(short key) { + return (int) key; + } + + /** + * Get the next sequential index after {@code index} and wraps if necessary. + */ + private int probeNext(int index) { + // The array lengths are always a power of two, so we can use a bitmask to stay inside the array bounds. + return (index + 1) & mask; + } + + /** + * Grows the map size after an insertion. If necessary, performs a rehash of the map. + */ + private void growSize() { + size++; + + if (size > maxSize) { + if(keys.length == Integer.MAX_VALUE) { + throw new IllegalStateException("Max capacity reached at size=" + size); + } + + // Double the capacity. + rehash(keys.length << 1); + } + } + + /** + * Removes entry at the given index position. Also performs opportunistic, incremental rehashing + * if necessary to not break conflict chains. + * + * @param index the index position of the element to remove. + * @return {@code true} if the next item was moved back. {@code false} otherwise. + */ + private boolean removeAt(final int index) { + --size; + // Clearing the key is not strictly necessary (for GC like in a regular collection), + // but recommended for security. The memory location is still fresh in the cache anyway. 
+ keys[index] = 0; + values[index] = null; + + // In the interval from index to the next available entry, the arrays may have entries + // that are displaced from their base position due to prior conflicts. Iterate these + // entries and move them back if possible, optimizing future lookups. + // Knuth Section 6.4 Algorithm R, also used by the JDK's IdentityHashMap. + + int nextFree = index; + int i = probeNext(index); + for (V value = values[i]; value != null; value = values[i = probeNext(i)]) { + short key = keys[i]; + int bucket = hashIndex(key); + if (i < bucket && (bucket <= nextFree || nextFree <= i) || + bucket <= nextFree && nextFree <= i) { + // Move the displaced entry "back" to the first available position. + keys[nextFree] = key; + values[nextFree] = value; + // Put the first entry after the displaced entry + keys[i] = 0; + values[i] = null; + nextFree = i; + } + } + return nextFree != index; + } + + /** + * Calculates the maximum size allowed before rehashing. + */ + private int calcMaxSize(int capacity) { + // Clip the upper bound so that there will always be at least one available slot. + int upperBound = capacity - 1; + return Math.min(upperBound, (int) (capacity * loadFactor)); + } + + /** + * Rehashes the map for the given capacity. + * + * @param newCapacity the new capacity for the map. + */ + private void rehash(int newCapacity) { + short[] oldKeys = keys; + V[] oldVals = values; + + keys = new short[newCapacity]; + @SuppressWarnings({ "unchecked", "SuspiciousArrayCast" }) + V[] temp = (V[]) new Object[newCapacity]; + values = temp; + + maxSize = calcMaxSize(newCapacity); + mask = newCapacity - 1; + + // Insert to the new arrays. + for (int i = 0; i < oldVals.length; ++i) { + V oldVal = oldVals[i]; + if (oldVal != null) { + // Inlined put(), but much simpler: we don't need to worry about + // duplicated keys, growing/rehashing, or failing to insert. 
+ short oldKey = oldKeys[i]; + int index = hashIndex(oldKey); + + for (;;) { + if (values[index] == null) { + keys[index] = oldKey; + values[index] = oldVal; + break; + } + + // Conflict, keep probing. Can wrap around, but never reaches startIndex again. + index = probeNext(index); + } + } + } + } + + @Override + public String toString() { + if (isEmpty()) { + return "{}"; + } + StringBuilder sb = new StringBuilder(4 * size); + sb.append('{'); + boolean first = true; + for (int i = 0; i < values.length; ++i) { + V value = values[i]; + if (value != null) { + if (!first) { + sb.append(", "); + } + sb.append(keyToString(keys[i])).append('=').append(value == this ? "(this Map)" : + toExternal(value)); + first = false; + } + } + return sb.append('}').toString(); + } + + /** + * Helper method called by {@link #toString()} in order to convert a single map key into a string. + * This is protected to allow subclasses to override the appearance of a given key. + */ + protected String keyToString(short key) { + return Short.toString(key); + } + + /** + * Set implementation for iterating over the entries of the map. + */ + private final class EntrySet extends AbstractSet> { + @Override + public Iterator> iterator() { + return new MapIterator(); + } + + @Override + public int size() { + return ShortObjectHashMap.this.size(); + } + } + + /** + * Set implementation for iterating over the keys. 
+ */ + private final class KeySet extends AbstractSet { + @Override + public int size() { + return ShortObjectHashMap.this.size(); + } + + @Override + public boolean contains(Object o) { + return ShortObjectHashMap.this.containsKey(o); + } + + @Override + public boolean remove(Object o) { + return ShortObjectHashMap.this.remove(o) != null; + } + + @Override + public boolean retainAll(Collection retainedKeys) { + boolean changed = false; + for(Iterator> iter = entries().iterator(); iter.hasNext(); ) { + PrimitiveEntry entry = iter.next(); + if (!retainedKeys.contains(entry.key())) { + changed = true; + iter.remove(); + } + } + return changed; + } + + @Override + public void clear() { + ShortObjectHashMap.this.clear(); + } + + @Override + public Iterator iterator() { + return new Iterator() { + private final Iterator> iter = entrySet.iterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Short next() { + return iter.next().getKey(); + } + + @Override + public void remove() { + iter.remove(); + } + }; + } + } + + /** + * Iterator over primitive entries. Entry key/values are overwritten by each call to {@link #next()}. + */ + private final class PrimitiveIterator implements Iterator>, PrimitiveEntry { + private int prevIndex = -1; + private int nextIndex = -1; + private int entryIndex = -1; + + private void scanNext() { + while (++nextIndex != values.length && values[nextIndex] == null) { + } + } + + @Override + public boolean hasNext() { + if (nextIndex == -1) { + scanNext(); + } + return nextIndex != values.length; + } + + @Override + public PrimitiveEntry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + prevIndex = nextIndex; + scanNext(); + + // Always return the same Entry object, just change its index each time. 
+ entryIndex = prevIndex; + return this; + } + + @Override + public void remove() { + if (prevIndex == -1) { + throw new IllegalStateException("next must be called before each remove."); + } + if (removeAt(prevIndex)) { + // removeAt may move elements "back" in the array if they have been displaced because their spot in the + // array was occupied when they were inserted. If this occurs then the nextIndex is now invalid and + // should instead point to the prevIndex which now holds an element which was "moved back". + nextIndex = prevIndex; + } + prevIndex = -1; + } + + // Entry implementation. Since this implementation uses a single Entry, we coalesce that + // into the Iterator object (potentially making loop optimization much easier). + + @Override + public short key() { + return keys[entryIndex]; + } + + @Override + public V value() { + return toExternal(values[entryIndex]); + } + + @Override + public void setValue(V value) { + values[entryIndex] = toInternal(value); + } + } + + /** + * Iterator used by the {@link Map} interface. + */ + private final class MapIterator implements Iterator> { + private final PrimitiveIterator iter = new PrimitiveIterator(); + + @Override + public boolean hasNext() { + return iter.hasNext(); + } + + @Override + public Entry next() { + if (!hasNext()) { + throw new NoSuchElementException(); + } + + iter.next(); + + return new MapEntry(iter.entryIndex); + } + + @Override + public void remove() { + iter.remove(); + } + } + + /** + * A single entry in the map. 
+ */ + final class MapEntry implements Entry { + private final int entryIndex; + + MapEntry(int entryIndex) { + this.entryIndex = entryIndex; + } + + @Override + public Short getKey() { + verifyExists(); + return keys[entryIndex]; + } + + @Override + public V getValue() { + verifyExists(); + return toExternal(values[entryIndex]); + } + + @Override + public V setValue(V value) { + verifyExists(); + V prevValue = toExternal(values[entryIndex]); + values[entryIndex] = toInternal(value); + return prevValue; + } + + private void verifyExists() { + if (values[entryIndex] == null) { + throw new IllegalStateException("The map entry has been removed"); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/collection/ShortObjectMap.java b/netty-util/src/main/java/io/netty/util/collection/ShortObjectMap.java new file mode 100644 index 0000000..d34ffc4 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/collection/ShortObjectMap.java @@ -0,0 +1,84 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.collection; + +import java.util.Map; + +/** + * Interface for a primitive map that uses {@code short}s as keys. + * + * @param the value type stored in the map. + */ +public interface ShortObjectMap extends Map { + + /** + * A primitive entry in the map, provided by the iterator from {@link #entries()} + * + * @param the value type stored in the map. 
+ */ + interface PrimitiveEntry { + /** + * Gets the key for this entry. + */ + short key(); + + /** + * Gets the value for this entry. + */ + V value(); + + /** + * Sets the value for this entry. + */ + void setValue(V value); + } + + /** + * Gets the value in the map with the specified key. + * + * @param key the key whose associated value is to be returned. + * @return the value or {@code null} if the key was not found in the map. + */ + V get(short key); + + /** + * Puts the given entry into the map. + * + * @param key the key of the entry. + * @param value the value of the entry. + * @return the previous value for this key or {@code null} if there was no previous mapping. + */ + V put(short key, V value); + + /** + * Removes the entry with the specified key. + * + * @param key the key for the entry to be removed from this map. + * @return the previous value for the key, or {@code null} if there was no mapping. + */ + V remove(short key); + + /** + * Gets an iterable to traverse over the primitive entries contained in this map. As an optimization, + * the {@link PrimitiveEntry}s returned by the {@link java.util.Iterator} may change as the {@link java.util.Iterator} + * progresses. The caller should not rely on {@link PrimitiveEntry} key/value stability. + */ + Iterable> entries(); + + /** + * Indicates whether or not this map contains a value for the specified key. + */ + boolean containsKey(short key); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutor.java new file mode 100644 index 0000000..96efbc3 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutor.java @@ -0,0 +1,191 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.AbstractExecutorService; +import java.util.concurrent.Callable; +import java.util.concurrent.RunnableFuture; +import java.util.concurrent.TimeUnit; + +/** + * Abstract base class for {@link EventExecutor} implementations. + */ +public abstract class AbstractEventExecutor extends AbstractExecutorService implements EventExecutor { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(AbstractEventExecutor.class); + + static final long DEFAULT_SHUTDOWN_QUIET_PERIOD = 2; + static final long DEFAULT_SHUTDOWN_TIMEOUT = 15; + + private final EventExecutorGroup parent; + private final Collection selfCollection = Collections.singleton(this); + + protected AbstractEventExecutor() { + this(null); + } + + protected AbstractEventExecutor(EventExecutorGroup parent) { + this.parent = parent; + } + + @Override + public EventExecutorGroup parent() { + return parent; + } + + @Override + public EventExecutor next() { + return this; + } + + @Override + public boolean inEventLoop() { + return inEventLoop(Thread.currentThread()); + } + + @Override + public Iterator iterator() { + return selfCollection.iterator(); + } + + @Override + public Future shutdownGracefully() { + return 
shutdownGracefully(DEFAULT_SHUTDOWN_QUIET_PERIOD, DEFAULT_SHUTDOWN_TIMEOUT, TimeUnit.SECONDS); + } + + /** + * @deprecated {@link #shutdownGracefully(long, long, TimeUnit)} or {@link #shutdownGracefully()} instead. + */ + @Override + @Deprecated + public abstract void shutdown(); + + /** + * @deprecated {@link #shutdownGracefully(long, long, TimeUnit)} or {@link #shutdownGracefully()} instead. + */ + @Override + @Deprecated + public List shutdownNow() { + shutdown(); + return Collections.emptyList(); + } + + @Override + public Promise newPromise() { + return new DefaultPromise(this); + } + + @Override + public ProgressivePromise newProgressivePromise() { + return new DefaultProgressivePromise(this); + } + + @Override + public Future newSucceededFuture(V result) { + return new SucceededFuture(this, result); + } + + @Override + public Future newFailedFuture(Throwable cause) { + return new FailedFuture(this, cause); + } + + @Override + public Future submit(Runnable task) { + return (Future) super.submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + return (Future) super.submit(task, result); + } + + @Override + public Future submit(Callable task) { + return (Future) super.submit(task); + } + + @Override + protected final RunnableFuture newTaskFor(Runnable runnable, T value) { + return new PromiseTask(this, runnable, value); + } + + @Override + protected final RunnableFuture newTaskFor(Callable callable) { + return new PromiseTask(this, callable); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, + TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public 
ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + /** + * Try to execute the given {@link Runnable} and just log if it throws a {@link Throwable}. + */ + protected static void safeExecute(Runnable task) { + try { + runTask(task); + } catch (Throwable t) { + logger.warn("A task raised an exception. Task: {}", task, t); + } + } + + protected static void runTask(Runnable task) { + task.run(); + } + + /** + * Like {@link #execute(Runnable)} but does not guarantee the task will be run until either + * a non-lazy task is executed or the executor is shut down. + *

+ * The default implementation just delegates to {@link #execute(Runnable)}. + *

+ */ + @UnstableApi + public void lazyExecute(Runnable task) { + execute(task); + } + + /** + * @deprecated override {@link SingleThreadEventExecutor#wakesUpForTask} to re-create this behaviour + */ + @Deprecated + public interface LazyRunnable extends Runnable { + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutorGroup.java b/netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutorGroup.java new file mode 100644 index 0000000..38c3fc8 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/AbstractEventExecutorGroup.java @@ -0,0 +1,117 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import static io.netty.util.concurrent.AbstractEventExecutor.DEFAULT_SHUTDOWN_QUIET_PERIOD; +import static io.netty.util.concurrent.AbstractEventExecutor.DEFAULT_SHUTDOWN_TIMEOUT; + + +/** + * Abstract base class for {@link EventExecutorGroup} implementations. 
+ */ +public abstract class AbstractEventExecutorGroup implements EventExecutorGroup { + @Override + public Future submit(Runnable task) { + return next().submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + return next().submit(task, result); + } + + @Override + public Future submit(Callable task) { + return next().submit(task); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return next().schedule(command, delay, unit); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return next().schedule(callable, delay, unit); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + return next().scheduleAtFixedRate(command, initialDelay, period, unit); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { + return next().scheduleWithFixedDelay(command, initialDelay, delay, unit); + } + + @Override + public Future shutdownGracefully() { + return shutdownGracefully(DEFAULT_SHUTDOWN_QUIET_PERIOD, DEFAULT_SHUTDOWN_TIMEOUT, TimeUnit.SECONDS); + } + + /** + * @deprecated {@link #shutdownGracefully(long, long, TimeUnit)} or {@link #shutdownGracefully()} instead. + */ + @Override + @Deprecated + public abstract void shutdown(); + + /** + * @deprecated {@link #shutdownGracefully(long, long, TimeUnit)} or {@link #shutdownGracefully()} instead. 
+ */ + @Override + @Deprecated + public List shutdownNow() { + shutdown(); + return Collections.emptyList(); + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + return next().invokeAll(tasks); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { + return next().invokeAll(tasks, timeout, unit); + } + + @Override + public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + return next().invokeAny(tasks); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return next().invokeAny(tasks, timeout, unit); + } + + @Override + public void execute(Runnable command) { + next().execute(command); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/AbstractFuture.java b/netty-util/src/main/java/io/netty/util/concurrent/AbstractFuture.java new file mode 100644 index 0000000..82d6c14 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/AbstractFuture.java @@ -0,0 +1,58 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */
+package io.netty.util.concurrent;
+
+import java.util.concurrent.CancellationException;
+import java.util.concurrent.ExecutionException;
+import java.util.concurrent.TimeUnit;
+import java.util.concurrent.TimeoutException;
+
+/**
+ * Abstract {@link Future} implementation which does not allow for cancellation.
+ *
+ * @param <V>
+ */
+public abstract class AbstractFuture<V> implements Future<V> {
+
+    @Override
+    public V get() throws InterruptedException, ExecutionException {
+        await();
+
+        Throwable cause = cause();
+        if (cause == null) {
+            return getNow();
+        }
+        if (cause instanceof CancellationException) {
+            throw (CancellationException) cause;
+        }
+        throw new ExecutionException(cause);
+    }
+
+    @Override
+    public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException {
+        if (await(timeout, unit)) {
+            Throwable cause = cause();
+            if (cause == null) {
+                return getNow();
+            }
+            if (cause instanceof CancellationException) {
+                throw (CancellationException) cause;
+            }
+            throw new ExecutionException(cause);
+        }
+        throw new TimeoutException();
+    }
+}
diff --git a/netty-util/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java
new file mode 100644
index 0000000..ae9f50a
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/concurrent/AbstractScheduledEventExecutor.java
@@ -0,0 +1,341 @@
+/*
+ * Copyright 2015 The Netty Project
+ *
+ * The Netty Project licenses this file to you under the Apache License,
+ * version 2.0 (the "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at:
+ *
+ *   https://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.DefaultPriorityQueue; +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PriorityQueue; +import java.util.Comparator; +import java.util.Queue; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; + +/** + * Abstract base class for {@link EventExecutor}s that want to support scheduling. + */ +public abstract class AbstractScheduledEventExecutor extends AbstractEventExecutor { + private static final Comparator> SCHEDULED_FUTURE_TASK_COMPARATOR = + new Comparator>() { + @Override + public int compare(ScheduledFutureTask o1, ScheduledFutureTask o2) { + return o1.compareTo(o2); + } + }; + + private static final long START_TIME = System.nanoTime(); + + static final Runnable WAKEUP_TASK = new Runnable() { + @Override + public void run() { + } // Do nothing + }; + + PriorityQueue> scheduledTaskQueue; + + long nextTaskId; + + protected AbstractScheduledEventExecutor() { + } + + protected AbstractScheduledEventExecutor(EventExecutorGroup parent) { + super(parent); + } + + /** + * Get the current time in nanoseconds by this executor's clock. This is not the same as {@link System#nanoTime()} + * for two reasons: + * + *
+ * <ul>
+ *     <li>We apply a fixed offset to the {@link System#nanoTime() nanoTime}</li>
+ *     <li>Implementations (in particular EmbeddedEventLoop) may use their own time source so they can control time
+ *     for testing purposes.</li>
+ * </ul>
+ */ + protected long getCurrentTimeNanos() { + return defaultCurrentTimeNanos(); + } + + /** + * @deprecated Use the non-static {@link #getCurrentTimeNanos()} instead. + */ + @Deprecated + protected static long nanoTime() { + return defaultCurrentTimeNanos(); + } + + static long defaultCurrentTimeNanos() { + return System.nanoTime() - START_TIME; + } + + static long deadlineNanos(long nanoTime, long delay) { + long deadlineNanos = nanoTime + delay; + // Guard against overflow + return deadlineNanos < 0 ? Long.MAX_VALUE : deadlineNanos; + } + + /** + * Given an arbitrary deadline {@code deadlineNanos}, calculate the number of nano seconds from now + * {@code deadlineNanos} would expire. + * + * @param deadlineNanos An arbitrary deadline in nano seconds. + * @return the number of nano seconds from now {@code deadlineNanos} would expire. + */ + protected static long deadlineToDelayNanos(long deadlineNanos) { + return ScheduledFutureTask.deadlineToDelayNanos(defaultCurrentTimeNanos(), deadlineNanos); + } + + /** + * The initial value used for delay and computations based upon a monatomic time source. + * + * @return initial value used for delay and computations based upon a monatomic time source. + */ + protected static long initialNanoTime() { + return START_TIME; + } + + PriorityQueue> scheduledTaskQueue() { + if (scheduledTaskQueue == null) { + scheduledTaskQueue = new DefaultPriorityQueue>( + SCHEDULED_FUTURE_TASK_COMPARATOR, + // Use same initial capacity as java.util.PriorityQueue + 11); + } + return scheduledTaskQueue; + } + + private static boolean isNullOrEmpty(Queue> queue) { + return queue == null || queue.isEmpty(); + } + + /** + * Cancel all scheduled tasks. + *
<p>
+ * This method MUST be called only when {@link #inEventLoop()} is {@code true}. + */ + protected void cancelScheduledTasks() { + assert inEventLoop(); + PriorityQueue> scheduledTaskQueue = this.scheduledTaskQueue; + if (isNullOrEmpty(scheduledTaskQueue)) { + return; + } + + final ScheduledFutureTask[] scheduledTasks = + scheduledTaskQueue.toArray(new ScheduledFutureTask[0]); + + for (ScheduledFutureTask task : scheduledTasks) { + task.cancelWithoutRemove(false); + } + + scheduledTaskQueue.clearIgnoringIndexes(); + } + + /** + * @see #pollScheduledTask(long) + */ + protected final Runnable pollScheduledTask() { + return pollScheduledTask(getCurrentTimeNanos()); + } + + /** + * Return the {@link Runnable} which is ready to be executed with the given {@code nanoTime}. + * You should use {@link #getCurrentTimeNanos()} to retrieve the correct {@code nanoTime}. + */ + protected final Runnable pollScheduledTask(long nanoTime) { + assert inEventLoop(); + + ScheduledFutureTask scheduledTask = peekScheduledTask(); + if (scheduledTask == null || scheduledTask.deadlineNanos() - nanoTime > 0) { + return null; + } + scheduledTaskQueue.remove(); + scheduledTask.setConsumed(); + return scheduledTask; + } + + /** + * Return the nanoseconds until the next scheduled task is ready to be run or {@code -1} if no task is scheduled. + */ + protected final long nextScheduledTaskNano() { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + return scheduledTask != null ? scheduledTask.delayNanos() : -1; + } + + /** + * Return the deadline (in nanoseconds) when the next scheduled task is ready to be run or {@code -1} + * if no task is scheduled. + */ + protected final long nextScheduledTaskDeadlineNanos() { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + return scheduledTask != null ? scheduledTask.deadlineNanos() : -1; + } + + final ScheduledFutureTask peekScheduledTask() { + Queue> scheduledTaskQueue = this.scheduledTaskQueue; + return scheduledTaskQueue != null ? 
scheduledTaskQueue.peek() : null; + } + + /** + * Returns {@code true} if a scheduled task is ready for processing. + */ + protected final boolean hasScheduledTasks() { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + return scheduledTask != null && scheduledTask.deadlineNanos() <= getCurrentTimeNanos(); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + ObjectUtil.checkNotNull(command, "command"); + ObjectUtil.checkNotNull(unit, "unit"); + if (delay < 0) { + delay = 0; + } + validateScheduled0(delay, unit); + + return schedule(new ScheduledFutureTask( + this, + command, + deadlineNanos(getCurrentTimeNanos(), unit.toNanos(delay)))); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + ObjectUtil.checkNotNull(callable, "callable"); + ObjectUtil.checkNotNull(unit, "unit"); + if (delay < 0) { + delay = 0; + } + validateScheduled0(delay, unit); + + return schedule(new ScheduledFutureTask( + this, callable, deadlineNanos(getCurrentTimeNanos(), unit.toNanos(delay)))); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + ObjectUtil.checkNotNull(command, "command"); + ObjectUtil.checkNotNull(unit, "unit"); + if (initialDelay < 0) { + throw new IllegalArgumentException( + String.format("initialDelay: %d (expected: >= 0)", initialDelay)); + } + if (period <= 0) { + throw new IllegalArgumentException( + String.format("period: %d (expected: > 0)", period)); + } + validateScheduled0(initialDelay, unit); + validateScheduled0(period, unit); + + return schedule(new ScheduledFutureTask( + this, command, deadlineNanos(getCurrentTimeNanos(), unit.toNanos(initialDelay)), unit.toNanos(period))); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { + ObjectUtil.checkNotNull(command, "command"); + 
ObjectUtil.checkNotNull(unit, "unit"); + if (initialDelay < 0) { + throw new IllegalArgumentException( + String.format("initialDelay: %d (expected: >= 0)", initialDelay)); + } + if (delay <= 0) { + throw new IllegalArgumentException( + String.format("delay: %d (expected: > 0)", delay)); + } + + validateScheduled0(initialDelay, unit); + validateScheduled0(delay, unit); + + return schedule(new ScheduledFutureTask( + this, command, deadlineNanos(getCurrentTimeNanos(), unit.toNanos(initialDelay)), -unit.toNanos(delay))); + } + + @SuppressWarnings("deprecation") + private void validateScheduled0(long amount, TimeUnit unit) { + validateScheduled(amount, unit); + } + + /** + * Sub-classes may override this to restrict the maximal amount of time someone can use to schedule a task. + * + * @deprecated will be removed in the future. + */ + @Deprecated + protected void validateScheduled(long amount, TimeUnit unit) { + // NOOP + } + + final void scheduleFromEventLoop(final ScheduledFutureTask task) { + // nextTaskId a long and so there is no chance it will overflow back to 0 + scheduledTaskQueue().add(task.setId(++nextTaskId)); + } + + private ScheduledFuture schedule(final ScheduledFutureTask task) { + if (inEventLoop()) { + scheduleFromEventLoop(task); + } else { + final long deadlineNanos = task.deadlineNanos(); + // task will add itself to scheduled task queue when run if not expired + if (beforeScheduledTaskSubmitted(deadlineNanos)) { + execute(task); + } else { + lazyExecute(task); + // Second hook after scheduling to facilitate race-avoidance + if (afterScheduledTaskSubmitted(deadlineNanos)) { + execute(WAKEUP_TASK); + } + } + } + + return task; + } + + final void removeScheduled(final ScheduledFutureTask task) { + assert task.isCancelled(); + if (inEventLoop()) { + scheduledTaskQueue().removeTyped(task); + } else { + // task will remove itself from scheduled task queue when it runs + lazyExecute(task); + } + } + + /** + * Called from arbitrary non-{@link EventExecutor} 
threads prior to scheduled task submission. + * Returns {@code true} if the {@link EventExecutor} thread should be woken immediately to + * process the scheduled task (if not already awake). + *
<p>
+ * If {@code false} is returned, {@link #afterScheduledTaskSubmitted(long)} will be called with + * the same value after the scheduled task is enqueued, providing another opportunity + * to wake the {@link EventExecutor} thread if required. + * + * @param deadlineNanos deadline of the to-be-scheduled task + * relative to {@link AbstractScheduledEventExecutor#getCurrentTimeNanos()} + * @return {@code true} if the {@link EventExecutor} thread should be woken, {@code false} otherwise + */ + protected boolean beforeScheduledTaskSubmitted(long deadlineNanos) { + return true; + } + + /** + * See {@link #beforeScheduledTaskSubmitted(long)}. Called only after that method returns false. + * + * @param deadlineNanos relative to {@link AbstractScheduledEventExecutor#getCurrentTimeNanos()} + * @return {@code true} if the {@link EventExecutor} thread should be woken, {@code false} otherwise + */ + protected boolean afterScheduledTaskSubmitted(long deadlineNanos) { + return true; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/BlockingOperationException.java b/netty-util/src/main/java/io/netty/util/concurrent/BlockingOperationException.java new file mode 100644 index 0000000..84e1be0 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/BlockingOperationException.java @@ -0,0 +1,42 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +/** + * An {@link IllegalStateException} which is raised when a user performed a blocking operation + * when the user is in an event loop thread. If a blocking operation is performed in an event loop + * thread, the blocking operation will most likely enter a dead lock state, hence throwing this + * exception. + */ +public class BlockingOperationException extends IllegalStateException { + + private static final long serialVersionUID = 2462223247762460301L; + + public BlockingOperationException() { + } + + public BlockingOperationException(String s) { + super(s); + } + + public BlockingOperationException(Throwable cause) { + super(cause); + } + + public BlockingOperationException(String message, Throwable cause) { + super(message, cause); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/CompleteFuture.java b/netty-util/src/main/java/io/netty/util/concurrent/CompleteFuture.java new file mode 100644 index 0000000..34206e2 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/CompleteFuture.java @@ -0,0 +1,149 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import java.util.concurrent.TimeUnit; + +/** + * A skeletal {@link Future} implementation which represents a {@link Future} which has been completed already. 
+ */ +public abstract class CompleteFuture extends AbstractFuture { + + private final EventExecutor executor; + + /** + * Creates a new instance. + * + * @param executor the {@link EventExecutor} associated with this future + */ + protected CompleteFuture(EventExecutor executor) { + this.executor = executor; + } + + /** + * Return the {@link EventExecutor} which is used by this {@link CompleteFuture}. + */ + protected EventExecutor executor() { + return executor; + } + + @Override + public Future addListener(GenericFutureListener> listener) { + DefaultPromise.notifyListener(executor(), this, ObjectUtil.checkNotNull(listener, "listener")); + return this; + } + + @Override + public Future addListeners(GenericFutureListener>... listeners) { + for (GenericFutureListener> l : + ObjectUtil.checkNotNull(listeners, "listeners")) { + + if (l == null) { + break; + } + DefaultPromise.notifyListener(executor(), this, l); + } + return this; + } + + @Override + public Future removeListener(GenericFutureListener> listener) { + // NOOP + return this; + } + + @Override + public Future removeListeners(GenericFutureListener>... 
listeners) { + // NOOP + return this; + } + + @Override + public Future await() throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + return this; + } + + @Override + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + return true; + } + + @Override + public Future sync() throws InterruptedException { + return this; + } + + @Override + public Future syncUninterruptibly() { + return this; + } + + @Override + public boolean await(long timeoutMillis) throws InterruptedException { + if (Thread.interrupted()) { + throw new InterruptedException(); + } + return true; + } + + @Override + public Future awaitUninterruptibly() { + return this; + } + + @Override + public boolean awaitUninterruptibly(long timeout, TimeUnit unit) { + return true; + } + + @Override + public boolean awaitUninterruptibly(long timeoutMillis) { + return true; + } + + @Override + public boolean isDone() { + return true; + } + + @Override + public boolean isCancellable() { + return false; + } + + @Override + public boolean isCancelled() { + return false; + } + + /** + * {@inheritDoc} + * + * @param mayInterruptIfRunning this value has no effect in this implementation. + */ + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + return false; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutor.java new file mode 100644 index 0000000..2923cac --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutor.java @@ -0,0 +1,75 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; + +/** + * Default {@link SingleThreadEventExecutor} implementation which just execute all submitted task in a + * serial fashion. + */ +public final class DefaultEventExecutor extends SingleThreadEventExecutor { + + public DefaultEventExecutor() { + this((EventExecutorGroup) null); + } + + public DefaultEventExecutor(ThreadFactory threadFactory) { + this(null, threadFactory); + } + + public DefaultEventExecutor(Executor executor) { + this(null, executor); + } + + public DefaultEventExecutor(EventExecutorGroup parent) { + this(parent, new DefaultThreadFactory(DefaultEventExecutor.class)); + } + + public DefaultEventExecutor(EventExecutorGroup parent, ThreadFactory threadFactory) { + super(parent, threadFactory, true); + } + + public DefaultEventExecutor(EventExecutorGroup parent, Executor executor) { + super(parent, executor, true); + } + + public DefaultEventExecutor(EventExecutorGroup parent, ThreadFactory threadFactory, int maxPendingTasks, + RejectedExecutionHandler rejectedExecutionHandler) { + super(parent, threadFactory, true, maxPendingTasks, rejectedExecutionHandler); + } + + public DefaultEventExecutor(EventExecutorGroup parent, Executor executor, int maxPendingTasks, + RejectedExecutionHandler rejectedExecutionHandler) { + super(parent, executor, true, maxPendingTasks, rejectedExecutionHandler); + } + + @Override + protected void run() { + for (; ; ) { + Runnable task = takeTask(); + if (task != null) { + 
runTask(task); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorChooserFactory.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorChooserFactory.java new file mode 100644 index 0000000..a771096 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorChooserFactory.java @@ -0,0 +1,76 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.UnstableApi; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicLong; + +/** + * Default implementation which uses simple round-robin to choose next {@link EventExecutor}. 
+ */ +@UnstableApi +public final class DefaultEventExecutorChooserFactory implements EventExecutorChooserFactory { + + public static final DefaultEventExecutorChooserFactory INSTANCE = new DefaultEventExecutorChooserFactory(); + + private DefaultEventExecutorChooserFactory() { + } + + @Override + public EventExecutorChooser newChooser(EventExecutor[] executors) { + if (isPowerOfTwo(executors.length)) { + return new PowerOfTwoEventExecutorChooser(executors); + } else { + return new GenericEventExecutorChooser(executors); + } + } + + private static boolean isPowerOfTwo(int val) { + return (val & -val) == val; + } + + private static final class PowerOfTwoEventExecutorChooser implements EventExecutorChooser { + private final AtomicInteger idx = new AtomicInteger(); + private final EventExecutor[] executors; + + PowerOfTwoEventExecutorChooser(EventExecutor[] executors) { + this.executors = executors; + } + + @Override + public EventExecutor next() { + return executors[idx.getAndIncrement() & executors.length - 1]; + } + } + + private static final class GenericEventExecutorChooser implements EventExecutorChooser { + // Use a 'long' counter to avoid non-round-robin behaviour at the 32-bit overflow boundary. + // The 64-bit long solves this by placing the overflow so far into the future, that no system + // will encounter this in practice. 
+ private final AtomicLong idx = new AtomicLong(); + private final EventExecutor[] executors; + + GenericEventExecutorChooser(EventExecutor[] executors) { + this.executors = executors; + } + + @Override + public EventExecutor next() { + return executors[(int) Math.abs(idx.getAndIncrement() % executors.length)]; + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorGroup.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorGroup.java new file mode 100644 index 0000000..3874fe1 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultEventExecutorGroup.java @@ -0,0 +1,61 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; + +/** + * Default implementation of {@link MultithreadEventExecutorGroup} which will use {@link DefaultEventExecutor} instances + * to handle the tasks. + */ +public class DefaultEventExecutorGroup extends MultithreadEventExecutorGroup { + /** + * @see #DefaultEventExecutorGroup(int, ThreadFactory) + */ + public DefaultEventExecutorGroup(int nThreads) { + this(nThreads, null); + } + + /** + * Create a new instance. + * + * @param nThreads the number of threads that will be used by this instance. 
+ * @param threadFactory the ThreadFactory to use, or {@code null} if the default should be used. + */ + public DefaultEventExecutorGroup(int nThreads, ThreadFactory threadFactory) { + this(nThreads, threadFactory, SingleThreadEventExecutor.DEFAULT_MAX_PENDING_EXECUTOR_TASKS, + RejectedExecutionHandlers.reject()); + } + + /** + * Create a new instance. + * + * @param nThreads the number of threads that will be used by this instance. + * @param threadFactory the ThreadFactory to use, or {@code null} if the default should be used. + * @param maxPendingTasks the maximum number of pending tasks before new tasks will be rejected. + * @param rejectedHandler the {@link RejectedExecutionHandler} to use. + */ + public DefaultEventExecutorGroup(int nThreads, ThreadFactory threadFactory, int maxPendingTasks, + RejectedExecutionHandler rejectedHandler) { + super(nThreads, threadFactory, maxPendingTasks, rejectedHandler); + } + + @Override + protected EventExecutor newChild(Executor executor, Object... args) throws Exception { + return new DefaultEventExecutor(this, executor, (Integer) args[0], (RejectedExecutionHandler) args[1]); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultFutureListeners.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultFutureListeners.java new file mode 100644 index 0000000..0294965 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultFutureListeners.java @@ -0,0 +1,86 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.Arrays; + +final class DefaultFutureListeners { + + private GenericFutureListener>[] listeners; + private int size; + private int progressiveSize; // the number of progressive listeners + + @SuppressWarnings("unchecked") + DefaultFutureListeners( + GenericFutureListener> first, GenericFutureListener> second) { + listeners = new GenericFutureListener[2]; + listeners[0] = first; + listeners[1] = second; + size = 2; + if (first instanceof GenericProgressiveFutureListener) { + progressiveSize++; + } + if (second instanceof GenericProgressiveFutureListener) { + progressiveSize++; + } + } + + public void add(GenericFutureListener> l) { + GenericFutureListener>[] listeners = this.listeners; + final int size = this.size; + if (size == listeners.length) { + this.listeners = listeners = Arrays.copyOf(listeners, size << 1); + } + listeners[size] = l; + this.size = size + 1; + + if (l instanceof GenericProgressiveFutureListener) { + progressiveSize++; + } + } + + public void remove(GenericFutureListener> l) { + final GenericFutureListener>[] listeners = this.listeners; + int size = this.size; + for (int i = 0; i < size; i++) { + if (listeners[i] == l) { + int listenersToMove = size - i - 1; + if (listenersToMove > 0) { + System.arraycopy(listeners, i + 1, listeners, i, listenersToMove); + } + listeners[--size] = null; + this.size = size; + + if (l instanceof GenericProgressiveFutureListener) { + progressiveSize--; + } + return; + } + } + } + + public GenericFutureListener>[] listeners() { + return 
listeners; + } + + public int size() { + return size; + } + + public int progressiveSize() { + return progressiveSize; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultProgressivePromise.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultProgressivePromise.java new file mode 100644 index 0000000..bd8b06a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultProgressivePromise.java @@ -0,0 +1,129 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +import static io.netty.util.internal.ObjectUtil.checkPositiveOrZero; + +public class DefaultProgressivePromise extends DefaultPromise implements ProgressivePromise { + + /** + * Creates a new instance. + *
<p>
+ * It is preferable to use {@link EventExecutor#newProgressivePromise()} to create a new progressive promise + * + * @param executor the {@link EventExecutor} which is used to notify the promise when it progresses or it is complete + */ + public DefaultProgressivePromise(EventExecutor executor) { + super(executor); + } + + protected DefaultProgressivePromise() { /* only for subclasses */ } + + @Override + public ProgressivePromise setProgress(long progress, long total) { + if (total < 0) { + // total unknown + total = -1; // normalize + checkPositiveOrZero(progress, "progress"); + } else if (progress < 0 || progress > total) { + throw new IllegalArgumentException( + "progress: " + progress + " (expected: 0 <= progress <= total (" + total + "))"); + } + + if (isDone()) { + throw new IllegalStateException("complete already"); + } + + notifyProgressiveListeners(progress, total); + return this; + } + + @Override + public boolean tryProgress(long progress, long total) { + if (total < 0) { + total = -1; + if (progress < 0 || isDone()) { + return false; + } + } else if (progress < 0 || progress > total || isDone()) { + return false; + } + + notifyProgressiveListeners(progress, total); + return true; + } + + @Override + public ProgressivePromise addListener(GenericFutureListener> listener) { + super.addListener(listener); + return this; + } + + @Override + public ProgressivePromise addListeners(GenericFutureListener>... listeners) { + super.addListeners(listeners); + return this; + } + + @Override + public ProgressivePromise removeListener(GenericFutureListener> listener) { + super.removeListener(listener); + return this; + } + + @Override + public ProgressivePromise removeListeners(GenericFutureListener>... 
listeners) { + super.removeListeners(listeners); + return this; + } + + @Override + public ProgressivePromise sync() throws InterruptedException { + super.sync(); + return this; + } + + @Override + public ProgressivePromise syncUninterruptibly() { + super.syncUninterruptibly(); + return this; + } + + @Override + public ProgressivePromise await() throws InterruptedException { + super.await(); + return this; + } + + @Override + public ProgressivePromise awaitUninterruptibly() { + super.awaitUninterruptibly(); + return this; + } + + @Override + public ProgressivePromise setSuccess(V result) { + super.setSuccess(result); + return this; + } + + @Override + public ProgressivePromise setFailure(Throwable cause) { + super.setFailure(cause); + return this; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultPromise.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultPromise.java new file mode 100644 index 0000000..c5b1e3a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultPromise.java @@ -0,0 +1,887 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.StringUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThrowableUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.concurrent.CancellationException; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static java.util.concurrent.TimeUnit.MILLISECONDS; + +public class DefaultPromise extends AbstractFuture implements Promise { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultPromise.class); + private static final InternalLogger rejectedExecutionLogger = + InternalLoggerFactory.getInstance(DefaultPromise.class.getName() + ".rejectedExecution"); + private static final int MAX_LISTENER_STACK_DEPTH = Math.min(8, + SystemPropertyUtil.getInt("io.netty.defaultPromise.maxListenerStackDepth", 8)); + @SuppressWarnings("rawtypes") + private static final AtomicReferenceFieldUpdater RESULT_UPDATER = + AtomicReferenceFieldUpdater.newUpdater(DefaultPromise.class, Object.class, "result"); + private static final Object SUCCESS = new Object(); + private static final Object UNCANCELLABLE = new Object(); + private static final CauseHolder CANCELLATION_CAUSE_HOLDER = new CauseHolder( + StacklessCancellationException.newInstance(DefaultPromise.class, "cancel(...)")); + private static final StackTraceElement[] CANCELLATION_STACK = CANCELLATION_CAUSE_HOLDER.cause.getStackTrace(); + + private volatile Object result; + private final EventExecutor executor; + /** + * One or more listeners. 
Can be a {@link GenericFutureListener} or a {@link DefaultFutureListeners}. + * If {@code null}, it means either 1) no listeners were added yet or 2) all listeners were notified. + *

+ * Threading - synchronized(this). We must support adding listeners when there is no EventExecutor. + */ + private GenericFutureListener> listener; + private DefaultFutureListeners listeners; + /** + * Threading - synchronized(this). We are required to hold the monitor to use Java's underlying wait()/notifyAll(). + */ + private short waiters; + + /** + * Threading - synchronized(this). We must prevent concurrent notification and FIFO listener notification if the + * executor changes. + */ + private boolean notifyingListeners; + + /** + * Creates a new instance. + *

+ * It is preferable to use {@link EventExecutor#newPromise()} to create a new promise + * + * @param executor the {@link EventExecutor} which is used to notify the promise once it is complete. + * It is assumed this executor will protect against {@link StackOverflowError} exceptions. + * The executor may be used to avoid {@link StackOverflowError} by executing a {@link Runnable} if the stack + * depth exceeds a threshold. + */ + public DefaultPromise(EventExecutor executor) { + this.executor = checkNotNull(executor, "executor"); + } + + /** + * See {@link #executor()} for expectations of the executor. + */ + protected DefaultPromise() { + // only for subclasses + executor = null; + } + + @Override + public Promise setSuccess(V result) { + if (setSuccess0(result)) { + return this; + } + throw new IllegalStateException("complete already: " + this); + } + + @Override + public boolean trySuccess(V result) { + return setSuccess0(result); + } + + @Override + public Promise setFailure(Throwable cause) { + if (setFailure0(cause)) { + return this; + } + throw new IllegalStateException("complete already: " + this, cause); + } + + @Override + public boolean tryFailure(Throwable cause) { + return setFailure0(cause); + } + + @Override + public boolean setUncancellable() { + if (RESULT_UPDATER.compareAndSet(this, null, UNCANCELLABLE)) { + return true; + } + Object result = this.result; + return !isDone0(result) || !isCancelled0(result); + } + + @Override + public boolean isSuccess() { + Object result = this.result; + return result != null && result != UNCANCELLABLE && !(result instanceof CauseHolder); + } + + @Override + public boolean isCancellable() { + return result == null; + } + + private static final class LeanCancellationException extends CancellationException { + private static final long serialVersionUID = 2794674970981187807L; + + // Suppress a warning since the method doesn't need synchronization + @Override + public Throwable fillInStackTrace() { + 
setStackTrace(CANCELLATION_STACK); + return this; + } + + @Override + public String toString() { + return CancellationException.class.getName(); + } + } + + @Override + public Throwable cause() { + return cause0(result); + } + + private Throwable cause0(Object result) { + if (!(result instanceof CauseHolder)) { + return null; + } + if (result == CANCELLATION_CAUSE_HOLDER) { + CancellationException ce = new LeanCancellationException(); + if (RESULT_UPDATER.compareAndSet(this, CANCELLATION_CAUSE_HOLDER, new CauseHolder(ce))) { + return ce; + } + result = this.result; + } + return ((CauseHolder) result).cause; + } + + @Override + public Promise addListener(GenericFutureListener> listener) { + checkNotNull(listener, "listener"); + + synchronized (this) { + addListener0(listener); + } + + if (isDone()) { + notifyListeners(); + } + + return this; + } + + @Override + public Promise addListeners(GenericFutureListener>... listeners) { + checkNotNull(listeners, "listeners"); + + synchronized (this) { + for (GenericFutureListener> listener : listeners) { + if (listener == null) { + break; + } + addListener0(listener); + } + } + + if (isDone()) { + notifyListeners(); + } + + return this; + } + + @Override + public Promise removeListener(final GenericFutureListener> listener) { + checkNotNull(listener, "listener"); + + synchronized (this) { + removeListener0(listener); + } + + return this; + } + + @Override + public Promise removeListeners(final GenericFutureListener>... 
listeners) { + checkNotNull(listeners, "listeners"); + + synchronized (this) { + for (GenericFutureListener> listener : listeners) { + if (listener == null) { + break; + } + removeListener0(listener); + } + } + + return this; + } + + @Override + public Promise await() throws InterruptedException { + if (isDone()) { + return this; + } + + if (Thread.interrupted()) { + throw new InterruptedException(toString()); + } + + checkDeadLock(); + + synchronized (this) { + while (!isDone()) { + incWaiters(); + try { + wait(); + } finally { + decWaiters(); + } + } + } + return this; + } + + @Override + public Promise awaitUninterruptibly() { + if (isDone()) { + return this; + } + + checkDeadLock(); + + boolean interrupted = false; + synchronized (this) { + while (!isDone()) { + incWaiters(); + try { + wait(); + } catch (InterruptedException e) { + // Interrupted while waiting. + interrupted = true; + } finally { + decWaiters(); + } + } + } + + if (interrupted) { + Thread.currentThread().interrupt(); + } + + return this; + } + + @Override + public boolean await(long timeout, TimeUnit unit) throws InterruptedException { + return await0(unit.toNanos(timeout), true); + } + + @Override + public boolean await(long timeoutMillis) throws InterruptedException { + return await0(MILLISECONDS.toNanos(timeoutMillis), true); + } + + @Override + public boolean awaitUninterruptibly(long timeout, TimeUnit unit) { + try { + return await0(unit.toNanos(timeout), false); + } catch (InterruptedException e) { + // Should not be raised at all. + throw new InternalError(); + } + } + + @Override + public boolean awaitUninterruptibly(long timeoutMillis) { + try { + return await0(MILLISECONDS.toNanos(timeoutMillis), false); + } catch (InterruptedException e) { + // Should not be raised at all. 
+ throw new InternalError(); + } + } + + @SuppressWarnings("unchecked") + @Override + public V getNow() { + Object result = this.result; + if (result instanceof CauseHolder || result == SUCCESS || result == UNCANCELLABLE) { + return null; + } + return (V) result; + } + + @SuppressWarnings("unchecked") + @Override + public V get() throws InterruptedException, ExecutionException { + Object result = this.result; + if (!isDone0(result)) { + await(); + result = this.result; + } + if (result == SUCCESS || result == UNCANCELLABLE) { + return null; + } + Throwable cause = cause0(result); + if (cause == null) { + return (V) result; + } + if (cause instanceof CancellationException) { + throw (CancellationException) cause; + } + throw new ExecutionException(cause); + } + + @SuppressWarnings("unchecked") + @Override + public V get(long timeout, TimeUnit unit) throws InterruptedException, ExecutionException, TimeoutException { + Object result = this.result; + if (!isDone0(result)) { + if (!await(timeout, unit)) { + throw new TimeoutException(); + } + result = this.result; + } + if (result == SUCCESS || result == UNCANCELLABLE) { + return null; + } + Throwable cause = cause0(result); + if (cause == null) { + return (V) result; + } + if (cause instanceof CancellationException) { + throw (CancellationException) cause; + } + throw new ExecutionException(cause); + } + + /** + * {@inheritDoc} + * + * @param mayInterruptIfRunning this value has no effect in this implementation. 
+ */ + @Override + public boolean cancel(boolean mayInterruptIfRunning) { + if (RESULT_UPDATER.compareAndSet(this, null, CANCELLATION_CAUSE_HOLDER)) { + if (checkNotifyWaiters()) { + notifyListeners(); + } + return true; + } + return false; + } + + @Override + public boolean isCancelled() { + return isCancelled0(result); + } + + @Override + public boolean isDone() { + return isDone0(result); + } + + @Override + public Promise sync() throws InterruptedException { + await(); + rethrowIfFailed(); + return this; + } + + @Override + public Promise syncUninterruptibly() { + awaitUninterruptibly(); + rethrowIfFailed(); + return this; + } + + @Override + public String toString() { + return toStringBuilder().toString(); + } + + protected StringBuilder toStringBuilder() { + StringBuilder buf = new StringBuilder(64) + .append(StringUtil.simpleClassName(this)) + .append('@') + .append(Integer.toHexString(hashCode())); + + Object result = this.result; + if (result == SUCCESS) { + buf.append("(success)"); + } else if (result == UNCANCELLABLE) { + buf.append("(uncancellable)"); + } else if (result instanceof CauseHolder) { + buf.append("(failure: ") + .append(((CauseHolder) result).cause) + .append(')'); + } else if (result != null) { + buf.append("(success: ") + .append(result) + .append(')'); + } else { + buf.append("(incomplete)"); + } + + return buf; + } + + /** + * Get the executor used to notify listeners when this promise is complete. + *

+ * It is assumed this executor will protect against {@link StackOverflowError} exceptions. + * The executor may be used to avoid {@link StackOverflowError} by executing a {@link Runnable} if the stack + * depth exceeds a threshold. + * + * @return The executor used to notify listeners when this promise is complete. + */ + protected EventExecutor executor() { + return executor; + } + + protected void checkDeadLock() { + EventExecutor e = executor(); + if (e != null && e.inEventLoop()) { + throw new BlockingOperationException(toString()); + } + } + + /** + * Notify a listener that a future has completed. + *

+ * This method has a fixed depth of {@link #MAX_LISTENER_STACK_DEPTH} that will limit recursion to prevent + * {@link StackOverflowError} and will stop notifying listeners added after this threshold is exceeded. + * + * @param eventExecutor the executor to use to notify the listener {@code listener}. + * @param future the future that is complete. + * @param listener the listener to notify. + */ + protected static void notifyListener( + EventExecutor eventExecutor, final Future future, final GenericFutureListener listener) { + notifyListenerWithStackOverFlowProtection( + checkNotNull(eventExecutor, "eventExecutor"), + checkNotNull(future, "future"), + checkNotNull(listener, "listener")); + } + + private void notifyListeners() { + EventExecutor executor = executor(); + if (executor.inEventLoop()) { + final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get(); + final int stackDepth = threadLocals.futureListenerStackDepth(); + if (stackDepth < MAX_LISTENER_STACK_DEPTH) { + threadLocals.setFutureListenerStackDepth(stackDepth + 1); + try { + notifyListenersNow(); + } finally { + threadLocals.setFutureListenerStackDepth(stackDepth); + } + return; + } + } + + safeExecute(executor, new Runnable() { + @Override + public void run() { + notifyListenersNow(); + } + }); + } + + /** + * The logic in this method should be identical to {@link #notifyListeners()} but + * cannot share code because the listener(s) cannot be cached for an instance of {@link DefaultPromise} since the + * listener(s) may be changed and is protected by a synchronized operation. 
+ */ + private static void notifyListenerWithStackOverFlowProtection(final EventExecutor executor, + final Future future, + final GenericFutureListener listener) { + if (executor.inEventLoop()) { + final InternalThreadLocalMap threadLocals = InternalThreadLocalMap.get(); + final int stackDepth = threadLocals.futureListenerStackDepth(); + if (stackDepth < MAX_LISTENER_STACK_DEPTH) { + threadLocals.setFutureListenerStackDepth(stackDepth + 1); + try { + notifyListener0(future, listener); + } finally { + threadLocals.setFutureListenerStackDepth(stackDepth); + } + return; + } + } + + safeExecute(executor, new Runnable() { + @Override + public void run() { + notifyListener0(future, listener); + } + }); + } + + private void notifyListenersNow() { + GenericFutureListener listener; + DefaultFutureListeners listeners; + synchronized (this) { + listener = this.listener; + listeners = this.listeners; + // Only proceed if there are listeners to notify and we are not already notifying listeners. + if (notifyingListeners || (listener == null && listeners == null)) { + return; + } + notifyingListeners = true; + if (listener != null) { + this.listener = null; + } else { + this.listeners = null; + } + } + for (; ; ) { + if (listener != null) { + notifyListener0(this, listener); + } else { + notifyListeners0(listeners); + } + synchronized (this) { + if (this.listener == null && this.listeners == null) { + // Nothing can throw from within this method, so setting notifyingListeners back to false does not + // need to be in a finally block. 
+ notifyingListeners = false; + return; + } + listener = this.listener; + listeners = this.listeners; + if (listener != null) { + this.listener = null; + } else { + this.listeners = null; + } + } + } + } + + private void notifyListeners0(DefaultFutureListeners listeners) { + GenericFutureListener[] a = listeners.listeners(); + int size = listeners.size(); + for (int i = 0; i < size; i++) { + notifyListener0(this, a[i]); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static void notifyListener0(Future future, GenericFutureListener l) { + try { + l.operationComplete(future); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown by " + l.getClass().getName() + ".operationComplete()", t); + } + } + } + + private void addListener0(GenericFutureListener> listener) { + if (this.listener == null) { + if (listeners == null) { + this.listener = listener; + } else { + listeners.add(listener); + } + } else { + assert listeners == null; + listeners = new DefaultFutureListeners(this.listener, listener); + this.listener = null; + } + } + + private void removeListener0(GenericFutureListener> toRemove) { + if (listener == toRemove) { + listener = null; + } else if (listeners != null) { + listeners.remove(toRemove); + // Removal is rare, no need for compaction + if (listeners.size() == 0) { + listeners = null; + } + } + } + + private boolean setSuccess0(V result) { + return setValue0(result == null ? SUCCESS : result); + } + + private boolean setFailure0(Throwable cause) { + return setValue0(new CauseHolder(checkNotNull(cause, "cause"))); + } + + private boolean setValue0(Object objResult) { + if (RESULT_UPDATER.compareAndSet(this, null, objResult) || + RESULT_UPDATER.compareAndSet(this, UNCANCELLABLE, objResult)) { + if (checkNotifyWaiters()) { + notifyListeners(); + } + return true; + } + return false; + } + + /** + * Check if there are any waiters and if so notify these. 
+ * + * @return {@code true} if there are any listeners attached to the promise, {@code false} otherwise. + */ + private synchronized boolean checkNotifyWaiters() { + if (waiters > 0) { + notifyAll(); + } + return listener != null || listeners != null; + } + + private void incWaiters() { + if (waiters == Short.MAX_VALUE) { + throw new IllegalStateException("too many waiters: " + this); + } + ++waiters; + } + + private void decWaiters() { + --waiters; + } + + private void rethrowIfFailed() { + Throwable cause = cause(); + if (cause == null) { + return; + } + + PlatformDependent.throwException(cause); + } + + private boolean await0(long timeoutNanos, boolean interruptable) throws InterruptedException { + if (isDone()) { + return true; + } + + if (timeoutNanos <= 0) { + return isDone(); + } + + if (interruptable && Thread.interrupted()) { + throw new InterruptedException(toString()); + } + + checkDeadLock(); + + // Start counting time from here instead of the first line of this method, + // to avoid/postpone performance cost of System.nanoTime(). + final long startTime = System.nanoTime(); + synchronized (this) { + boolean interrupted = false; + try { + long waitTime = timeoutNanos; + while (!isDone() && waitTime > 0) { + incWaiters(); + try { + wait(waitTime / 1000000, (int) (waitTime % 1000000)); + } catch (InterruptedException e) { + if (interruptable) { + throw e; + } else { + interrupted = true; + } + } finally { + decWaiters(); + } + // Check isDone() in advance, try to avoid calculating the elapsed time later. + if (isDone()) { + return true; + } + // Calculate the elapsed time here instead of in the while condition, + // try to avoid performance cost of System.nanoTime() in the first loop of while. + waitTime = timeoutNanos - (System.nanoTime() - startTime); + } + return isDone(); + } finally { + if (interrupted) { + Thread.currentThread().interrupt(); + } + } + } + } + + /** + * Notify all progressive listeners. + *

+ * No attempt is made to ensure notification order if multiple calls are made to this method before + * the original invocation completes. + *

+ * This will do an iteration over all listeners to get all of type {@link GenericProgressiveFutureListener}s. + * + * @param progress the new progress. + * @param total the total progress. + */ + @SuppressWarnings("unchecked") + void notifyProgressiveListeners(final long progress, final long total) { + final Object listeners = progressiveListeners(); + if (listeners == null) { + return; + } + + final ProgressiveFuture self = (ProgressiveFuture) this; + + EventExecutor executor = executor(); + if (executor.inEventLoop()) { + if (listeners instanceof GenericProgressiveFutureListener[]) { + notifyProgressiveListeners0( + self, (GenericProgressiveFutureListener[]) listeners, progress, total); + } else { + notifyProgressiveListener0( + self, (GenericProgressiveFutureListener>) listeners, progress, total); + } + } else { + if (listeners instanceof GenericProgressiveFutureListener[]) { + final GenericProgressiveFutureListener[] array = + (GenericProgressiveFutureListener[]) listeners; + safeExecute(executor, new Runnable() { + @Override + public void run() { + notifyProgressiveListeners0(self, array, progress, total); + } + }); + } else { + final GenericProgressiveFutureListener> l = + (GenericProgressiveFutureListener>) listeners; + safeExecute(executor, new Runnable() { + @Override + public void run() { + notifyProgressiveListener0(self, l, progress, total); + } + }); + } + } + } + + /** + * Returns a {@link GenericProgressiveFutureListener}, an array of {@link GenericProgressiveFutureListener}, or + * {@code null}. + */ + private synchronized Object progressiveListeners() { + final GenericFutureListener listener = this.listener; + final DefaultFutureListeners listeners = this.listeners; + if (listener == null && listeners == null) { + // No listeners added + return null; + } + + if (listeners != null) { + // Copy DefaultFutureListeners into an array of listeners. 
+ DefaultFutureListeners dfl = listeners; + int progressiveSize = dfl.progressiveSize(); + switch (progressiveSize) { + case 0: + return null; + case 1: + for (GenericFutureListener l : dfl.listeners()) { + if (l instanceof GenericProgressiveFutureListener) { + return l; + } + } + return null; + } + + GenericFutureListener[] array = dfl.listeners(); + GenericProgressiveFutureListener[] copy = new GenericProgressiveFutureListener[progressiveSize]; + for (int i = 0, j = 0; j < progressiveSize; i++) { + GenericFutureListener l = array[i]; + if (l instanceof GenericProgressiveFutureListener) { + copy[j++] = (GenericProgressiveFutureListener) l; + } + } + + return copy; + } else if (listener instanceof GenericProgressiveFutureListener) { + return listener; + } else { + // Only one listener was added and it's not a progressive listener. + return null; + } + } + + private static void notifyProgressiveListeners0( + ProgressiveFuture future, GenericProgressiveFutureListener[] listeners, long progress, long total) { + for (GenericProgressiveFutureListener l : listeners) { + if (l == null) { + break; + } + notifyProgressiveListener0(future, l, progress, total); + } + } + + @SuppressWarnings({"unchecked", "rawtypes"}) + private static void notifyProgressiveListener0( + ProgressiveFuture future, GenericProgressiveFutureListener l, long progress, long total) { + try { + l.operationProgressed(future, progress, total); + } catch (Throwable t) { + if (logger.isWarnEnabled()) { + logger.warn("An exception was thrown by " + l.getClass().getName() + ".operationProgressed()", t); + } + } + } + + private static boolean isCancelled0(Object result) { + return result instanceof CauseHolder && ((CauseHolder) result).cause instanceof CancellationException; + } + + private static boolean isDone0(Object result) { + return result != null && result != UNCANCELLABLE; + } + + private static final class CauseHolder { + final Throwable cause; + + CauseHolder(Throwable cause) { + this.cause = cause; 
+ } + } + + private static void safeExecute(EventExecutor executor, Runnable task) { + try { + executor.execute(task); + } catch (Throwable t) { + rejectedExecutionLogger.error("Failed to submit a listener notification task. Event loop shut down?", t); + } + } + + private static final class StacklessCancellationException extends CancellationException { + + private static final long serialVersionUID = -2974906711413716191L; + + private StacklessCancellationException() { + } + + // Override fillInStackTrace() so we not populate the backtrace via a native call and so leak the + // Classloader. + @Override + public Throwable fillInStackTrace() { + return this; + } + + static StacklessCancellationException newInstance(Class clazz, String method) { + return ThrowableUtil.unknownStackTrace(new StacklessCancellationException(), clazz, method); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/DefaultThreadFactory.java b/netty-util/src/main/java/io/netty/util/concurrent/DefaultThreadFactory.java new file mode 100644 index 0000000..1e62aa7 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/DefaultThreadFactory.java @@ -0,0 +1,122 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.StringUtil; +import java.util.Locale; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.atomic.AtomicInteger; + +/** + * A {@link ThreadFactory} implementation with a simple naming rule. + */ +public class DefaultThreadFactory implements ThreadFactory { + + private static final AtomicInteger poolId = new AtomicInteger(); + + private final AtomicInteger nextId = new AtomicInteger(); + private final String prefix; + private final boolean daemon; + private final int priority; + protected final ThreadGroup threadGroup; + + public DefaultThreadFactory(Class poolType) { + this(poolType, false, Thread.NORM_PRIORITY); + } + + public DefaultThreadFactory(String poolName) { + this(poolName, false, Thread.NORM_PRIORITY); + } + + public DefaultThreadFactory(Class poolType, boolean daemon) { + this(poolType, daemon, Thread.NORM_PRIORITY); + } + + public DefaultThreadFactory(String poolName, boolean daemon) { + this(poolName, daemon, Thread.NORM_PRIORITY); + } + + public DefaultThreadFactory(Class poolType, int priority) { + this(poolType, false, priority); + } + + public DefaultThreadFactory(String poolName, int priority) { + this(poolName, false, priority); + } + + public DefaultThreadFactory(Class poolType, boolean daemon, int priority) { + this(toPoolName(poolType), daemon, priority); + } + + public static String toPoolName(Class poolType) { + ObjectUtil.checkNotNull(poolType, "poolType"); + + String poolName = StringUtil.simpleClassName(poolType); + switch (poolName.length()) { + case 0: + return "unknown"; + case 1: + return poolName.toLowerCase(Locale.US); + default: + if (Character.isUpperCase(poolName.charAt(0)) && Character.isLowerCase(poolName.charAt(1))) { + return Character.toLowerCase(poolName.charAt(0)) + poolName.substring(1); + } else { + return poolName; + } + } + } + + public DefaultThreadFactory(String poolName, boolean 
daemon, int priority, ThreadGroup threadGroup) { + ObjectUtil.checkNotNull(poolName, "poolName"); + + if (priority < Thread.MIN_PRIORITY || priority > Thread.MAX_PRIORITY) { + throw new IllegalArgumentException( + "priority: " + priority + " (expected: Thread.MIN_PRIORITY <= priority <= Thread.MAX_PRIORITY)"); + } + + prefix = poolName + '-' + poolId.incrementAndGet() + '-'; + this.daemon = daemon; + this.priority = priority; + this.threadGroup = threadGroup; + } + + public DefaultThreadFactory(String poolName, boolean daemon, int priority) { + this(poolName, daemon, priority, null); + } + + @Override + public Thread newThread(Runnable r) { + Thread t = newThread(FastThreadLocalRunnable.wrap(r), prefix + nextId.incrementAndGet()); + try { + if (t.isDaemon() != daemon) { + t.setDaemon(daemon); + } + + if (t.getPriority() != priority) { + t.setPriority(priority); + } + } catch (Exception ignored) { + // Doesn't matter even if failed to set. + } + return t; + } + + protected Thread newThread(Runnable r, String name) { + return new FastThreadLocalThread(threadGroup, r, name); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/EventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/EventExecutor.java new file mode 100644 index 0000000..2e106eb --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/EventExecutor.java @@ -0,0 +1,71 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * The {@link EventExecutor} is a special {@link EventExecutorGroup} which comes + * with some handy methods to see if a {@link Thread} is executed in a event loop. + * Besides this, it also extends the {@link EventExecutorGroup} to allow for a generic + * way to access methods. + */ +public interface EventExecutor extends EventExecutorGroup { + + /** + * Returns a reference to itself. + */ + @Override + EventExecutor next(); + + /** + * Return the {@link EventExecutorGroup} which is the parent of this {@link EventExecutor}, + */ + EventExecutorGroup parent(); + + /** + * Calls {@link #inEventLoop(Thread)} with {@link Thread#currentThread()} as argument + */ + boolean inEventLoop(); + + /** + * Return {@code true} if the given {@link Thread} is executed in the event loop, + * {@code false} otherwise. + */ + boolean inEventLoop(Thread thread); + + /** + * Return a new {@link Promise}. + */ + Promise newPromise(); + + /** + * Create a new {@link ProgressivePromise}. + */ + ProgressivePromise newProgressivePromise(); + + /** + * Create a new {@link Future} which is marked as succeeded already. So {@link Future#isSuccess()} + * will return {@code true}. All {@link FutureListener} added to it will be notified directly. Also + * every call of blocking methods will just return without blocking. + */ + Future newSucceededFuture(V result); + + /** + * Create a new {@link Future} which is marked as failed already. So {@link Future#isSuccess()} + * will return {@code false}. All {@link FutureListener} added to it will be notified directly. Also + * every call of blocking methods will just return without blocking. 
+ */ + Future newFailedFuture(Throwable cause); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/EventExecutorChooserFactory.java b/netty-util/src/main/java/io/netty/util/concurrent/EventExecutorChooserFactory.java new file mode 100644 index 0000000..b6a145a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/EventExecutorChooserFactory.java @@ -0,0 +1,42 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.UnstableApi; + +/** + * Factory that creates new {@link EventExecutorChooser}s. + */ +@UnstableApi +public interface EventExecutorChooserFactory { + + /** + * Returns a new {@link EventExecutorChooser}. + */ + EventExecutorChooser newChooser(EventExecutor[] executors); + + /** + * Chooses the next {@link EventExecutor} to use. + */ + @UnstableApi + interface EventExecutorChooser { + + /** + * Returns the new {@link EventExecutor} to use. 
+ */ + EventExecutor next(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/EventExecutorGroup.java b/netty-util/src/main/java/io/netty/util/concurrent/EventExecutorGroup.java new file mode 100644 index 0000000..253096b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/EventExecutorGroup.java @@ -0,0 +1,107 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.Iterator; +import java.util.List; +import java.util.concurrent.Callable; +import java.util.concurrent.ScheduledExecutorService; +import java.util.concurrent.TimeUnit; + +/** + * The {@link EventExecutorGroup} is responsible for providing the {@link EventExecutor}'s to use + * via its {@link #next()} method. Besides this, it is also responsible for handling their + * life-cycle and allows shutting them down in a global fashion. + */ +public interface EventExecutorGroup extends ScheduledExecutorService, Iterable { + + /** + * Returns {@code true} if and only if all {@link EventExecutor}s managed by this {@link EventExecutorGroup} + * are being {@linkplain #shutdownGracefully() shut down gracefully} or was {@linkplain #isShutdown() shut down}. + */ + boolean isShuttingDown(); + + /** + * Shortcut method for {@link #shutdownGracefully(long, long, TimeUnit)} with sensible default values. 
+ * + * @return the {@link #terminationFuture()} + */ + Future shutdownGracefully(); + + /** + * Signals this executor that the caller wants the executor to be shut down. Once this method is called, + * {@link #isShuttingDown()} starts to return {@code true}, and the executor prepares to shut itself down. + * Unlike {@link #shutdown()}, graceful shutdown ensures that no tasks are submitted for 'the quiet period' + * (usually a couple seconds) before it shuts itself down. If a task is submitted during the quiet period, + * it is guaranteed to be accepted and the quiet period will start over. + * + * @param quietPeriod the quiet period as described in the documentation + * @param timeout the maximum amount of time to wait until the executor is {@linkplain #shutdown()} + * regardless if a task was submitted during the quiet period + * @param unit the unit of {@code quietPeriod} and {@code timeout} + * @return the {@link #terminationFuture()} + */ + Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit); + + /** + * Returns the {@link Future} which is notified when all {@link EventExecutor}s managed by this + * {@link EventExecutorGroup} have been terminated. + */ + Future terminationFuture(); + + /** + * @deprecated {@link #shutdownGracefully(long, long, TimeUnit)} or {@link #shutdownGracefully()} instead. + */ + @Override + @Deprecated + void shutdown(); + + /** + * @deprecated {@link #shutdownGracefully(long, long, TimeUnit)} or {@link #shutdownGracefully()} instead. + */ + @Override + @Deprecated + List shutdownNow(); + + /** + * Returns one of the {@link EventExecutor}s managed by this {@link EventExecutorGroup}. 
+ */ + EventExecutor next(); + + @Override + Iterator iterator(); + + @Override + Future submit(Runnable task); + + @Override + Future submit(Runnable task, T result); + + @Override + Future submit(Callable task); + + @Override + ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit); + + @Override + ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit); + + @Override + ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit); + + @Override + ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/FailedFuture.java b/netty-util/src/main/java/io/netty/util/concurrent/FailedFuture.java new file mode 100644 index 0000000..f003f4c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/FailedFuture.java @@ -0,0 +1,67 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; + +/** + * The {@link CompleteFuture} which is failed already. It is + * recommended to use {@link EventExecutor#newFailedFuture(Throwable)} + * instead of calling the constructor of this future. 
+ */ +public final class FailedFuture extends CompleteFuture { + + private final Throwable cause; + + /** + * Creates a new instance. + * + * @param executor the {@link EventExecutor} associated with this future + * @param cause the cause of failure + */ + public FailedFuture(EventExecutor executor, Throwable cause) { + super(executor); + this.cause = ObjectUtil.checkNotNull(cause, "cause"); + } + + @Override + public Throwable cause() { + return cause; + } + + @Override + public boolean isSuccess() { + return false; + } + + @Override + public Future sync() { + PlatformDependent.throwException(cause); + return this; + } + + @Override + public Future syncUninterruptibly() { + PlatformDependent.throwException(cause); + return this; + } + + @Override + public V getNow() { + return null; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocal.java b/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocal.java new file mode 100644 index 0000000..74511a7 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocal.java @@ -0,0 +1,279 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.PlatformDependent; +import java.util.Collections; +import java.util.IdentityHashMap; +import java.util.Set; +import static io.netty.util.internal.InternalThreadLocalMap.VARIABLES_TO_REMOVE_INDEX; + +/** + * A special variant of {@link ThreadLocal} that yields higher access performance when accessed from a + * {@link FastThreadLocalThread}. + *

+ * Internally, a {@link FastThreadLocal} uses a constant index in an array, instead of using hash code and hash table, + * to look for a variable. Although seemingly very subtle, it yields slight performance advantage over using a hash + * table, and it is useful when accessed frequently. + *

+ * To take advantage of this thread-local variable, your thread must be a {@link FastThreadLocalThread} or its subtype. + * By default, all threads created by {@link DefaultThreadFactory} are {@link FastThreadLocalThread} due to this reason. + *

+ * Note that the fast path is only possible on threads that extend {@link FastThreadLocalThread}, because it requires + * a special field to store the necessary state. An access by any other kind of thread falls back to a regular + * {@link ThreadLocal}. + *

+ * + * @param the type of the thread-local variable + * @see ThreadLocal + */ +public class FastThreadLocal { + + /** + * Removes all {@link FastThreadLocal} variables bound to the current thread. This operation is useful when you + * are in a container environment, and you don't want to leave the thread local variables in the threads you do not + * manage. + */ + public static void removeAll() { + InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.getIfSet(); + if (threadLocalMap == null) { + return; + } + + try { + Object v = threadLocalMap.indexedVariable(VARIABLES_TO_REMOVE_INDEX); + if (v != null && v != InternalThreadLocalMap.UNSET) { + @SuppressWarnings("unchecked") + Set> variablesToRemove = (Set>) v; + FastThreadLocal[] variablesToRemoveArray = + variablesToRemove.toArray(new FastThreadLocal[0]); + for (FastThreadLocal tlv : variablesToRemoveArray) { + tlv.remove(threadLocalMap); + } + } + } finally { + InternalThreadLocalMap.remove(); + } + } + + /** + * Returns the number of thread local variables bound to the current thread. + */ + public static int size() { + InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.getIfSet(); + if (threadLocalMap == null) { + return 0; + } else { + return threadLocalMap.size(); + } + } + + /** + * Destroys the data structure that keeps all {@link FastThreadLocal} variables accessed from + * non-{@link FastThreadLocalThread}s. This operation is useful when you are in a container environment, and you + * do not want to leave the thread local variables in the threads you do not manage. Call this method when your + * application is being unloaded from the container. 
+ */ + public static void destroy() { + InternalThreadLocalMap.destroy(); + } + + @SuppressWarnings("unchecked") + private static void addToVariablesToRemove(InternalThreadLocalMap threadLocalMap, FastThreadLocal variable) { + Object v = threadLocalMap.indexedVariable(VARIABLES_TO_REMOVE_INDEX); + Set> variablesToRemove; + if (v == InternalThreadLocalMap.UNSET || v == null) { + variablesToRemove = Collections.newSetFromMap(new IdentityHashMap, Boolean>()); + threadLocalMap.setIndexedVariable(VARIABLES_TO_REMOVE_INDEX, variablesToRemove); + } else { + variablesToRemove = (Set>) v; + } + + variablesToRemove.add(variable); + } + + private static void removeFromVariablesToRemove( + InternalThreadLocalMap threadLocalMap, FastThreadLocal variable) { + + Object v = threadLocalMap.indexedVariable(VARIABLES_TO_REMOVE_INDEX); + + if (v == InternalThreadLocalMap.UNSET || v == null) { + return; + } + + @SuppressWarnings("unchecked") + Set> variablesToRemove = (Set>) v; + variablesToRemove.remove(variable); + } + + private final int index; + + public FastThreadLocal() { + index = InternalThreadLocalMap.nextVariableIndex(); + } + + /** + * Returns the current value for the current thread + */ + @SuppressWarnings("unchecked") + public final V get() { + InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get(); + Object v = threadLocalMap.indexedVariable(index); + if (v != InternalThreadLocalMap.UNSET) { + return (V) v; + } + + return initialize(threadLocalMap); + } + + /** + * Returns the current value for the current thread if it exists, {@code null} otherwise. + */ + @SuppressWarnings("unchecked") + public final V getIfExists() { + InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.getIfSet(); + if (threadLocalMap != null) { + Object v = threadLocalMap.indexedVariable(index); + if (v != InternalThreadLocalMap.UNSET) { + return (V) v; + } + } + return null; + } + + /** + * Returns the current value for the specified thread local map. 
+ * The specified thread local map must be for the current thread. + */ + @SuppressWarnings("unchecked") + public final V get(InternalThreadLocalMap threadLocalMap) { + Object v = threadLocalMap.indexedVariable(index); + if (v != InternalThreadLocalMap.UNSET) { + return (V) v; + } + + return initialize(threadLocalMap); + } + + private V initialize(InternalThreadLocalMap threadLocalMap) { + V v = null; + try { + v = initialValue(); + if (v == InternalThreadLocalMap.UNSET) { + throw new IllegalArgumentException("InternalThreadLocalMap.UNSET can not be initial value."); + } + } catch (Exception e) { + PlatformDependent.throwException(e); + } + + threadLocalMap.setIndexedVariable(index, v); + addToVariablesToRemove(threadLocalMap, this); + return v; + } + + /** + * Set the value for the current thread. + */ + public final void set(V value) { + if (value != InternalThreadLocalMap.UNSET) { + InternalThreadLocalMap threadLocalMap = InternalThreadLocalMap.get(); + setKnownNotUnset(threadLocalMap, value); + } else { + remove(); + } + } + + /** + * Set the value for the specified thread local map. The specified thread local map must be for the current thread. + */ + public final void set(InternalThreadLocalMap threadLocalMap, V value) { + if (value != InternalThreadLocalMap.UNSET) { + setKnownNotUnset(threadLocalMap, value); + } else { + remove(threadLocalMap); + } + } + + /** + * @see InternalThreadLocalMap#setIndexedVariable(int, Object). + */ + private void setKnownNotUnset(InternalThreadLocalMap threadLocalMap, V value) { + if (threadLocalMap.setIndexedVariable(index, value)) { + addToVariablesToRemove(threadLocalMap, this); + } + } + + /** + * Returns {@code true} if and only if this thread-local variable is set. + */ + public final boolean isSet() { + return isSet(InternalThreadLocalMap.getIfSet()); + } + + /** + * Returns {@code true} if and only if this thread-local variable is set. + * The specified thread local map must be for the current thread. 
+ */ + public final boolean isSet(InternalThreadLocalMap threadLocalMap) { + return threadLocalMap != null && threadLocalMap.isIndexedVariableSet(index); + } + + /** + * Sets the value to uninitialized for the specified thread local map. + * After this, any subsequent call to get() will trigger a new call to initialValue(). + */ + public final void remove() { + remove(InternalThreadLocalMap.getIfSet()); + } + + /** + * Sets the value to uninitialized for the specified thread local map. + * After this, any subsequent call to get() will trigger a new call to initialValue(). + * The specified thread local map must be for the current thread. + */ + @SuppressWarnings("unchecked") + public final void remove(InternalThreadLocalMap threadLocalMap) { + if (threadLocalMap == null) { + return; + } + + Object v = threadLocalMap.removeIndexedVariable(index); + if (v != InternalThreadLocalMap.UNSET) { + removeFromVariablesToRemove(threadLocalMap, this); + try { + onRemoval((V) v); + } catch (Exception e) { + PlatformDependent.throwException(e); + } + } + } + + /** + * Returns the initial value for this thread-local variable. + */ + protected V initialValue() throws Exception { + return null; + } + + /** + * Invoked when this thread local variable is removed by {@link #remove()}. Be aware that {@link #remove()} + * is not guaranteed to be called when the `Thread` completes which means you can not depend on this for + * cleanup of the resources in the case of `Thread` completion. 
+ */ + protected void onRemoval(@SuppressWarnings("UnusedParameters") V value) throws Exception { + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalRunnable.java b/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalRunnable.java new file mode 100644 index 0000000..218c62c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalRunnable.java @@ -0,0 +1,39 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; + +final class FastThreadLocalRunnable implements Runnable { + private final Runnable runnable; + + private FastThreadLocalRunnable(Runnable runnable) { + this.runnable = ObjectUtil.checkNotNull(runnable, "runnable"); + } + + @Override + public void run() { + try { + runnable.run(); + } finally { + FastThreadLocal.removeAll(); + } + } + + static Runnable wrap(Runnable runnable) { + return runnable instanceof FastThreadLocalRunnable ? 
runnable : new FastThreadLocalRunnable(runnable); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalThread.java b/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalThread.java new file mode 100644 index 0000000..4f0fc7b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/FastThreadLocalThread.java @@ -0,0 +1,128 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * A special {@link Thread} that provides fast access to {@link FastThreadLocal} variables. + */ +public class FastThreadLocalThread extends Thread { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(FastThreadLocalThread.class); + + // This will be set to true if we have a chance to wrap the Runnable. 
+ private final boolean cleanupFastThreadLocals; + + private InternalThreadLocalMap threadLocalMap; + + public FastThreadLocalThread() { + cleanupFastThreadLocals = false; + } + + public FastThreadLocalThread(Runnable target) { + super(FastThreadLocalRunnable.wrap(target)); + cleanupFastThreadLocals = true; + } + + public FastThreadLocalThread(ThreadGroup group, Runnable target) { + super(group, FastThreadLocalRunnable.wrap(target)); + cleanupFastThreadLocals = true; + } + + public FastThreadLocalThread(String name) { + super(name); + cleanupFastThreadLocals = false; + } + + public FastThreadLocalThread(ThreadGroup group, String name) { + super(group, name); + cleanupFastThreadLocals = false; + } + + public FastThreadLocalThread(Runnable target, String name) { + super(FastThreadLocalRunnable.wrap(target), name); + cleanupFastThreadLocals = true; + } + + public FastThreadLocalThread(ThreadGroup group, Runnable target, String name) { + super(group, FastThreadLocalRunnable.wrap(target), name); + cleanupFastThreadLocals = true; + } + + public FastThreadLocalThread(ThreadGroup group, Runnable target, String name, long stackSize) { + super(group, FastThreadLocalRunnable.wrap(target), name, stackSize); + cleanupFastThreadLocals = true; + } + + /** + * Returns the internal data structure that keeps the thread-local variables bound to this thread. + * Note that this method is for internal use only, and thus is subject to change at any time. + */ + public final InternalThreadLocalMap threadLocalMap() { + if (this != Thread.currentThread() && logger.isWarnEnabled()) { + logger.warn(new RuntimeException("It's not thread-safe to get 'threadLocalMap' " + + "which doesn't belong to the caller thread")); + } + return threadLocalMap; + } + + /** + * Sets the internal data structure that keeps the thread-local variables bound to this thread. + * Note that this method is for internal use only, and thus is subject to change at any time. 
+ */ + public final void setThreadLocalMap(InternalThreadLocalMap threadLocalMap) { + if (this != Thread.currentThread() && logger.isWarnEnabled()) { + logger.warn(new RuntimeException("It's not thread-safe to set 'threadLocalMap' " + + "which doesn't belong to the caller thread")); + } + this.threadLocalMap = threadLocalMap; + } + + /** + * Returns {@code true} if {@link FastThreadLocal#removeAll()} will be called once {@link #run()} completes. + */ + @UnstableApi + public boolean willCleanupFastThreadLocals() { + return cleanupFastThreadLocals; + } + + /** + * Returns {@code true} if {@link FastThreadLocal#removeAll()} will be called once {@link Thread#run()} completes. + */ + @UnstableApi + public static boolean willCleanupFastThreadLocals(Thread thread) { + return thread instanceof FastThreadLocalThread && + ((FastThreadLocalThread) thread).willCleanupFastThreadLocals(); + } + + /** + * Query whether this thread is allowed to perform blocking calls or not. + * {@link FastThreadLocalThread}s are often used in event-loops, where blocking calls are forbidden in order to + * prevent event-loop stalls, so this method returns {@code false} by default. + *

+ * Subclasses of {@link FastThreadLocalThread} can override this method if they are not meant to be used for + * running event-loops. + * + * @return {@code false}, unless overriden by a subclass. + */ + public boolean permitBlockingCalls() { + return false; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/Future.java b/netty-util/src/main/java/io/netty/util/concurrent/Future.java new file mode 100644 index 0000000..723256a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/Future.java @@ -0,0 +1,164 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.concurrent.CancellationException; +import java.util.concurrent.TimeUnit; + + +/** + * The result of an asynchronous operation. + */ +@SuppressWarnings("ClassNameSameAsAncestorName") +public interface Future extends java.util.concurrent.Future { + + /** + * Returns {@code true} if and only if the I/O operation was completed + * successfully. + */ + boolean isSuccess(); + + /** + * returns {@code true} if and only if the operation can be cancelled via {@link #cancel(boolean)}. + */ + boolean isCancellable(); + + /** + * Returns the cause of the failed I/O operation if the I/O operation has + * failed. + * + * @return the cause of the failure. + * {@code null} if succeeded or this future is not + * completed yet. 
+ */ + Throwable cause(); + + /** + * Adds the specified listener to this future. The + * specified listener is notified when this future is + * {@linkplain #isDone() done}. If this future is already + * completed, the specified listener is notified immediately. + */ + Future addListener(GenericFutureListener> listener); + + /** + * Adds the specified listeners to this future. The + * specified listeners are notified when this future is + * {@linkplain #isDone() done}. If this future is already + * completed, the specified listeners are notified immediately. + */ + Future addListeners(GenericFutureListener>... listeners); + + /** + * Removes the first occurrence of the specified listener from this future. + * The specified listener is no longer notified when this + * future is {@linkplain #isDone() done}. If the specified + * listener is not associated with this future, this method + * does nothing and returns silently. + */ + Future removeListener(GenericFutureListener> listener); + + /** + * Removes the first occurrence for each of the listeners from this future. + * The specified listeners are no longer notified when this + * future is {@linkplain #isDone() done}. If the specified + * listeners are not associated with this future, this method + * does nothing and returns silently. + */ + Future removeListeners(GenericFutureListener>... listeners); + + /** + * Waits for this future until it is done, and rethrows the cause of the failure if this future + * failed. + */ + Future sync() throws InterruptedException; + + /** + * Waits for this future until it is done, and rethrows the cause of the failure if this future + * failed. + */ + Future syncUninterruptibly(); + + /** + * Waits for this future to be completed. + * + * @throws InterruptedException if the current thread was interrupted + */ + Future await() throws InterruptedException; + + /** + * Waits for this future to be completed without + * interruption. 
This method catches an {@link InterruptedException} and + * discards it silently. + */ + Future awaitUninterruptibly(); + + /** + * Waits for this future to be completed within the + * specified time limit. + * + * @return {@code true} if and only if the future was completed within + * the specified time limit + * @throws InterruptedException if the current thread was interrupted + */ + boolean await(long timeout, TimeUnit unit) throws InterruptedException; + + /** + * Waits for this future to be completed within the + * specified time limit. + * + * @return {@code true} if and only if the future was completed within + * the specified time limit + * @throws InterruptedException if the current thread was interrupted + */ + boolean await(long timeoutMillis) throws InterruptedException; + + /** + * Waits for this future to be completed within the + * specified time limit without interruption. This method catches an + * {@link InterruptedException} and discards it silently. + * + * @return {@code true} if and only if the future was completed within + * the specified time limit + */ + boolean awaitUninterruptibly(long timeout, TimeUnit unit); + + /** + * Waits for this future to be completed within the + * specified time limit without interruption. This method catches an + * {@link InterruptedException} and discards it silently. + * + * @return {@code true} if and only if the future was completed within + * the specified time limit + */ + boolean awaitUninterruptibly(long timeoutMillis); + + /** + * Return the result without blocking. If the future is not done yet this will return {@code null}. + *

+ * As it is possible that a {@code null} value is used to mark the future as successful you also need to check + * if the future is really done with {@link #isDone()} and not rely on the returned {@code null} value. + */ + V getNow(); + + /** + * {@inheritDoc} + *

+ * If the cancellation was successful it will fail the future with a {@link CancellationException}. + */ + @Override + boolean cancel(boolean mayInterruptIfRunning); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/FutureListener.java b/netty-util/src/main/java/io/netty/util/concurrent/FutureListener.java new file mode 100644 index 0000000..080d50a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/FutureListener.java @@ -0,0 +1,29 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +/** + * A subtype of {@link GenericFutureListener} that hides type parameter for convenience. + *

+ * Future f = new DefaultPromise(..);
+ * f.addListener(new FutureListener() {
+ *     public void operationComplete(Future f) { .. }
+ * });
+ * 
+ */ +public interface FutureListener extends GenericFutureListener> { +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/GenericFutureListener.java b/netty-util/src/main/java/io/netty/util/concurrent/GenericFutureListener.java new file mode 100644 index 0000000..e9869cf --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/GenericFutureListener.java @@ -0,0 +1,32 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import java.util.EventListener; + +/** + * Listens to the result of a {@link Future}. The result of the asynchronous operation is notified once this listener + * is added by calling {@link Future#addListener(GenericFutureListener)}. + */ +public interface GenericFutureListener> extends EventListener { + + /** + * Invoked when the operation associated with the {@link Future} has been completed. 
+ * + * @param future the source {@link Future} which called this callback + */ + void operationComplete(F future) throws Exception; +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/GenericProgressiveFutureListener.java b/netty-util/src/main/java/io/netty/util/concurrent/GenericProgressiveFutureListener.java new file mode 100644 index 0000000..4ff04ce --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/GenericProgressiveFutureListener.java @@ -0,0 +1,28 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +public interface GenericProgressiveFutureListener> extends GenericFutureListener { + /** + * Invoked when the operation has progressed. + * + * @param progress the progress of the operation so far (cumulative) + * @param total the number that signifies the end of the operation when {@code progress} reaches at it. + * {@code -1} if the end of operation is unknown. 
+ */ + void operationProgressed(F future, long progress, long total) throws Exception; +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/GlobalEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/GlobalEventExecutor.java new file mode 100644 index 0000000..5d41801 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/GlobalEventExecutor.java @@ -0,0 +1,303 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThreadExecutorMap; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.security.AccessController; +import java.security.PrivilegedAction; +import java.util.Queue; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Single-thread singleton {@link EventExecutor}. 
It starts the thread automatically and stops it when there is no + * task pending in the task queue for {@code io.netty.globalEventExecutor.quietPeriodSeconds} second + * (default is 1 second). Please note it is not scalable to schedule large number of tasks to this executor; + * use a dedicated executor. + */ +public final class GlobalEventExecutor extends AbstractScheduledEventExecutor implements OrderedEventExecutor { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(GlobalEventExecutor.class); + + private static final long SCHEDULE_QUIET_PERIOD_INTERVAL; + + static { + int quietPeriod = SystemPropertyUtil.getInt("io.netty.globalEventExecutor.quietPeriodSeconds", 1); + if (quietPeriod <= 0) { + quietPeriod = 1; + } + logger.debug("-Dio.netty.globalEventExecutor.quietPeriodSeconds: {}", quietPeriod); + + SCHEDULE_QUIET_PERIOD_INTERVAL = TimeUnit.SECONDS.toNanos(quietPeriod); + } + + public static final GlobalEventExecutor INSTANCE = new GlobalEventExecutor(); + + final BlockingQueue taskQueue = new LinkedBlockingQueue(); + final ScheduledFutureTask quietPeriodTask = new ScheduledFutureTask( + this, Executors.callable(new Runnable() { + @Override + public void run() { + // NOOP + } + }, null), + // note: the getCurrentTimeNanos() call here only works because this is a final class, otherwise the method + // could be overridden leading to unsafe initialization here! 
+ deadlineNanos(getCurrentTimeNanos(), SCHEDULE_QUIET_PERIOD_INTERVAL), + -SCHEDULE_QUIET_PERIOD_INTERVAL + ); + + // because the GlobalEventExecutor is a singleton, tasks submitted to it can come from arbitrary threads and this + // can trigger the creation of a thread from arbitrary thread groups; for this reason, the thread factory must not + // be sticky about its thread group + // visible for testing + final ThreadFactory threadFactory; + private final TaskRunner taskRunner = new TaskRunner(); + private final AtomicBoolean started = new AtomicBoolean(); + volatile Thread thread; + + private final Future terminationFuture = new FailedFuture(this, new UnsupportedOperationException()); + + private GlobalEventExecutor() { + scheduledTaskQueue().add(quietPeriodTask); + threadFactory = ThreadExecutorMap.apply(new DefaultThreadFactory( + DefaultThreadFactory.toPoolName(getClass()), false, Thread.NORM_PRIORITY, null), this); + } + + /** + * Take the next {@link Runnable} from the task queue and so will block if no task is currently present. + * + * @return {@code null} if the executor thread has been interrupted or waken up. + */ + Runnable takeTask() { + BlockingQueue taskQueue = this.taskQueue; + for (; ; ) { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + if (scheduledTask == null) { + Runnable task = null; + try { + task = taskQueue.take(); + } catch (InterruptedException e) { + // Ignore + } + return task; + } else { + long delayNanos = scheduledTask.delayNanos(); + Runnable task = null; + if (delayNanos > 0) { + try { + task = taskQueue.poll(delayNanos, TimeUnit.NANOSECONDS); + } catch (InterruptedException e) { + // Waken up. + return null; + } + } + if (task == null) { + // We need to fetch the scheduled tasks now as otherwise there may be a chance that + // scheduled tasks are never executed if there is always one task in the taskQueue. 
+ // This is for example true for the read task of OIO Transport + // See https://github.com/netty/netty/issues/1614 + fetchFromScheduledTaskQueue(); + task = taskQueue.poll(); + } + + if (task != null) { + return task; + } + } + } + } + + private void fetchFromScheduledTaskQueue() { + long nanoTime = getCurrentTimeNanos(); + Runnable scheduledTask = pollScheduledTask(nanoTime); + while (scheduledTask != null) { + taskQueue.add(scheduledTask); + scheduledTask = pollScheduledTask(nanoTime); + } + } + + /** + * Return the number of tasks that are pending for processing. + */ + public int pendingTasks() { + return taskQueue.size(); + } + + /** + * Add a task to the task queue, or throws a {@link RejectedExecutionException} if this instance was shutdown + * before. + */ + private void addTask(Runnable task) { + taskQueue.add(ObjectUtil.checkNotNull(task, "task")); + } + + @Override + public boolean inEventLoop(Thread thread) { + return thread == this.thread; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return terminationFuture(); + } + + @Override + public Future terminationFuture() { + return terminationFuture; + } + + @Override + @Deprecated + public void shutdown() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return false; + } + + /** + * Waits until the worker thread of this executor has no tasks left in its task queue and terminates itself. + * Because a new worker thread will be started again when a new task is submitted, this operation is only useful + * when you want to ensure that the worker thread is terminated after your application is shut + * down and there's no chance of submitting a new task afterwards. 
+ * + * @return {@code true} if and only if the worker thread has been terminated + */ + public boolean awaitInactivity(long timeout, TimeUnit unit) throws InterruptedException { + ObjectUtil.checkNotNull(unit, "unit"); + + final Thread thread = this.thread; + if (thread == null) { + throw new IllegalStateException("thread was not started"); + } + thread.join(unit.toMillis(timeout)); + return !thread.isAlive(); + } + + @Override + public void execute(Runnable task) { + execute0(task); + } + + private void execute0(Runnable task) { + addTask(ObjectUtil.checkNotNull(task, "task")); + if (!inEventLoop()) { + startThread(); + } + } + + private void startThread() { + if (started.compareAndSet(false, true)) { + final Thread t = threadFactory.newThread(taskRunner); + // Set to null to ensure we not create classloader leaks by holds a strong reference to the inherited + // classloader. + // See: + // - https://github.com/netty/netty/issues/7290 + // - https://bugs.openjdk.java.net/browse/JDK-7008595 + AccessController.doPrivileged(new PrivilegedAction() { + @Override + public Void run() { + t.setContextClassLoader(null); + return null; + } + }); + + // Set the thread before starting it as otherwise inEventLoop() may return false and so produce + // an assert error. + // See https://github.com/netty/netty/issues/4357 + thread = t; + t.start(); + } + } + + final class TaskRunner implements Runnable { + @Override + public void run() { + for (; ; ) { + Runnable task = takeTask(); + if (task != null) { + try { + runTask(task); + } catch (Throwable t) { + logger.warn("Unexpected exception from the global event executor: ", t); + } + + if (task != quietPeriodTask) { + continue; + } + } + + Queue> scheduledTaskQueue = GlobalEventExecutor.this.scheduledTaskQueue; + // Terminate if there is no task in the queue (except the noop task). + if (taskQueue.isEmpty() && (scheduledTaskQueue == null || scheduledTaskQueue.size() == 1)) { + // Mark the current thread as stopped. 
+ // The following CAS must always success and must be uncontended, + // because only one thread should be running at the same time. + boolean stopped = started.compareAndSet(true, false); + assert stopped; + + // Check if there are pending entries added by execute() or schedule*() while we do CAS above. + // Do not check scheduledTaskQueue because it is not thread-safe and can only be mutated from a + // TaskRunner actively running tasks. + if (taskQueue.isEmpty()) { + // A) No new task was added and thus there's nothing to handle + // -> safe to terminate because there's nothing left to do + // B) A new thread started and handled all the new tasks. + // -> safe to terminate the new thread will take care the rest + break; + } + + // There are pending tasks added again. + if (!started.compareAndSet(false, true)) { + // startThread() started a new thread and set 'started' to true. + // -> terminate this thread so that the new thread reads from taskQueue exclusively. + break; + } + + // New tasks were added, but this worker was faster to set 'started' to true. + // i.e. a new worker thread was not started by startThread(). + // -> keep this thread alive to handle the newly added entries. + } + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ImmediateEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/ImmediateEventExecutor.java new file mode 100644 index 0000000..caafda5 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ImmediateEventExecutor.java @@ -0,0 +1,162 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.TimeUnit; + +/** + * Executes {@link Runnable} objects in the caller's thread. If the {@link #execute(Runnable)} is reentrant it will be + * queued until the original {@link Runnable} finishes execution. + *

+ * All {@link Throwable} objects thrown from {@link #execute(Runnable)} will be swallowed and logged. This is to ensure + * that all queued {@link Runnable} objects have the chance to be run. + */ +public final class ImmediateEventExecutor extends AbstractEventExecutor { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ImmediateEventExecutor.class); + public static final ImmediateEventExecutor INSTANCE = new ImmediateEventExecutor(); + /** + * A Runnable will be queued if we are executing a Runnable. This is to prevent a {@link StackOverflowError}. + */ + private static final FastThreadLocal> DELAYED_RUNNABLES = new FastThreadLocal>() { + @Override + protected Queue initialValue() throws Exception { + return new ArrayDeque(); + } + }; + /** + * Set to {@code true} if we are executing a runnable. + */ + private static final FastThreadLocal RUNNING = new FastThreadLocal() { + @Override + protected Boolean initialValue() throws Exception { + return false; + } + }; + + private final Future terminationFuture = new FailedFuture( + GlobalEventExecutor.INSTANCE, new UnsupportedOperationException()); + + private ImmediateEventExecutor() { + } + + @Override + public boolean inEventLoop() { + return true; + } + + @Override + public boolean inEventLoop(Thread thread) { + return true; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return terminationFuture(); + } + + @Override + public Future terminationFuture() { + return terminationFuture; + } + + @Override + @Deprecated + public void shutdown() { + } + + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return false; + } + + @Override + public void execute(Runnable command) { + ObjectUtil.checkNotNull(command, 
"command"); + if (!RUNNING.get()) { + RUNNING.set(true); + try { + command.run(); + } catch (Throwable cause) { + logger.info("Throwable caught while executing Runnable {}", command, cause); + } finally { + Queue delayedRunnables = DELAYED_RUNNABLES.get(); + Runnable runnable; + while ((runnable = delayedRunnables.poll()) != null) { + try { + runnable.run(); + } catch (Throwable cause) { + logger.info("Throwable caught while executing Runnable {}", runnable, cause); + } + } + RUNNING.set(false); + } + } else { + DELAYED_RUNNABLES.get().add(command); + } + } + + @Override + public Promise newPromise() { + return new ImmediatePromise(this); + } + + @Override + public ProgressivePromise newProgressivePromise() { + return new ImmediateProgressivePromise(this); + } + + static class ImmediatePromise extends DefaultPromise { + ImmediatePromise(EventExecutor executor) { + super(executor); + } + + @Override + protected void checkDeadLock() { + // No check + } + } + + static class ImmediateProgressivePromise extends DefaultProgressivePromise { + ImmediateProgressivePromise(EventExecutor executor) { + super(executor); + } + + @Override + protected void checkDeadLock() { + // No check + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ImmediateExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/ImmediateExecutor.java new file mode 100644 index 0000000..ab644a0 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ImmediateExecutor.java @@ -0,0 +1,35 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import java.util.concurrent.Executor; + +/** + * {@link Executor} which execute tasks in the callers thread. + */ +public final class ImmediateExecutor implements Executor { + public static final ImmediateExecutor INSTANCE = new ImmediateExecutor(); + + private ImmediateExecutor() { + // use static instance + } + + @Override + public void execute(Runnable command) { + ObjectUtil.checkNotNull(command, "command").run(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/MultithreadEventExecutorGroup.java b/netty-util/src/main/java/io/netty/util/concurrent/MultithreadEventExecutorGroup.java new file mode 100644 index 0000000..591bb88 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/MultithreadEventExecutorGroup.java @@ -0,0 +1,227 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import java.util.Collections; +import java.util.Iterator; +import java.util.LinkedHashSet; +import java.util.Set; +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * Abstract base class for {@link EventExecutorGroup} implementations that handles their tasks with multiple threads at + * the same time. + */ +public abstract class MultithreadEventExecutorGroup extends AbstractEventExecutorGroup { + + private final EventExecutor[] children; + private final Set readonlyChildren; + private final AtomicInteger terminatedChildren = new AtomicInteger(); + private final Promise terminationFuture = new DefaultPromise(GlobalEventExecutor.INSTANCE); + private final EventExecutorChooserFactory.EventExecutorChooser chooser; + + /** + * Create a new instance. + * + * @param nThreads the number of threads that will be used by this instance. + * @param threadFactory the ThreadFactory to use, or {@code null} if the default should be used. + * @param args arguments which will passed to each {@link #newChild(Executor, Object...)} call + */ + protected MultithreadEventExecutorGroup(int nThreads, ThreadFactory threadFactory, Object... args) { + this(nThreads, threadFactory == null ? null : new ThreadPerTaskExecutor(threadFactory), args); + } + + /** + * Create a new instance. + * + * @param nThreads the number of threads that will be used by this instance. + * @param executor the Executor to use, or {@code null} if the default should be used. + * @param args arguments which will passed to each {@link #newChild(Executor, Object...)} call + */ + protected MultithreadEventExecutorGroup(int nThreads, Executor executor, Object... args) { + this(nThreads, executor, DefaultEventExecutorChooserFactory.INSTANCE, args); + } + + /** + * Create a new instance. 
+ * + * @param nThreads the number of threads that will be used by this instance. + * @param executor the Executor to use, or {@code null} if the default should be used. + * @param chooserFactory the {@link EventExecutorChooserFactory} to use. + * @param args arguments which will passed to each {@link #newChild(Executor, Object...)} call + */ + protected MultithreadEventExecutorGroup(int nThreads, Executor executor, + EventExecutorChooserFactory chooserFactory, Object... args) { + checkPositive(nThreads, "nThreads"); + + if (executor == null) { + executor = new ThreadPerTaskExecutor(newDefaultThreadFactory()); + } + + children = new EventExecutor[nThreads]; + + for (int i = 0; i < nThreads; i++) { + boolean success = false; + try { + children[i] = newChild(executor, args); + success = true; + } catch (Exception e) { + // TODO: Think about if this is a good exception type + throw new IllegalStateException("failed to create a child event loop", e); + } finally { + if (!success) { + for (int j = 0; j < i; j++) { + children[j].shutdownGracefully(); + } + + for (int j = 0; j < i; j++) { + EventExecutor e = children[j]; + try { + while (!e.isTerminated()) { + e.awaitTermination(Integer.MAX_VALUE, TimeUnit.SECONDS); + } + } catch (InterruptedException interrupted) { + // Let the caller handle the interruption. 
+ Thread.currentThread().interrupt(); + break; + } + } + } + } + } + + chooser = chooserFactory.newChooser(children); + + final FutureListener terminationListener = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (terminatedChildren.incrementAndGet() == children.length) { + terminationFuture.setSuccess(null); + } + } + }; + + for (EventExecutor e : children) { + e.terminationFuture().addListener(terminationListener); + } + + Set childrenSet = new LinkedHashSet(children.length); + Collections.addAll(childrenSet, children); + readonlyChildren = Collections.unmodifiableSet(childrenSet); + } + + protected ThreadFactory newDefaultThreadFactory() { + return new DefaultThreadFactory(getClass()); + } + + @Override + public EventExecutor next() { + return chooser.next(); + } + + @Override + public Iterator iterator() { + return readonlyChildren.iterator(); + } + + /** + * Return the number of {@link EventExecutor} this implementation uses. This number is the maps + * 1:1 to the threads it use. + */ + public final int executorCount() { + return children.length; + } + + /** + * Create a new EventExecutor which will later then accessible via the {@link #next()} method. This method will be + * called for each thread that will serve this {@link MultithreadEventExecutorGroup}. + */ + protected abstract EventExecutor newChild(Executor executor, Object... 
args) throws Exception; + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + for (EventExecutor l : children) { + l.shutdownGracefully(quietPeriod, timeout, unit); + } + return terminationFuture(); + } + + @Override + public Future terminationFuture() { + return terminationFuture; + } + + @Override + @Deprecated + public void shutdown() { + for (EventExecutor l : children) { + l.shutdown(); + } + } + + @Override + public boolean isShuttingDown() { + for (EventExecutor l : children) { + if (!l.isShuttingDown()) { + return false; + } + } + return true; + } + + @Override + public boolean isShutdown() { + for (EventExecutor l : children) { + if (!l.isShutdown()) { + return false; + } + } + return true; + } + + @Override + public boolean isTerminated() { + for (EventExecutor l : children) { + if (!l.isTerminated()) { + return false; + } + } + return true; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) + throws InterruptedException { + long deadline = System.nanoTime() + unit.toNanos(timeout); + loop: + for (EventExecutor l : children) { + for (; ; ) { + long timeLeft = deadline - System.nanoTime(); + if (timeLeft <= 0) { + break loop; + } + if (l.awaitTermination(timeLeft, TimeUnit.NANOSECONDS)) { + break; + } + } + } + return isTerminated(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java b/netty-util/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java new file mode 100644 index 0000000..8606d1b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/NonStickyEventExecutorGroup.java @@ -0,0 +1,345 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.UnstableApi; +import java.util.Collection; +import java.util.Iterator; +import java.util.List; +import java.util.Queue; +import java.util.concurrent.Callable; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +/** + * {@link EventExecutorGroup} which will preserve {@link Runnable} execution order but makes no guarantees about what + * {@link EventExecutor} (and therefore {@link Thread}) will be used to execute the {@link Runnable}s. + * + *

The {@link EventExecutorGroup#next()} for the wrapped {@link EventExecutorGroup} must NOT return + * executors of type {@link OrderedEventExecutor}. + */ +@UnstableApi +public final class NonStickyEventExecutorGroup implements EventExecutorGroup { + private final EventExecutorGroup group; + private final int maxTaskExecutePerRun; + + /** + * Creates a new instance. Be aware that the given {@link EventExecutorGroup} MUST NOT contain + * any {@link OrderedEventExecutor}s. + */ + public NonStickyEventExecutorGroup(EventExecutorGroup group) { + this(group, 1024); + } + + /** + * Creates a new instance. Be aware that the given {@link EventExecutorGroup} MUST NOT contain + * any {@link OrderedEventExecutor}s. + */ + public NonStickyEventExecutorGroup(EventExecutorGroup group, int maxTaskExecutePerRun) { + this.group = verify(group); + this.maxTaskExecutePerRun = ObjectUtil.checkPositive(maxTaskExecutePerRun, "maxTaskExecutePerRun"); + } + + private static EventExecutorGroup verify(EventExecutorGroup group) { + Iterator executors = ObjectUtil.checkNotNull(group, "group").iterator(); + while (executors.hasNext()) { + EventExecutor executor = executors.next(); + if (executor instanceof OrderedEventExecutor) { + throw new IllegalArgumentException("EventExecutorGroup " + group + + " contains OrderedEventExecutors: " + executor); + } + } + return group; + } + + private NonStickyOrderedEventExecutor newExecutor(EventExecutor executor) { + return new NonStickyOrderedEventExecutor(executor, maxTaskExecutePerRun); + } + + @Override + public boolean isShuttingDown() { + return group.isShuttingDown(); + } + + @Override + public Future shutdownGracefully() { + return group.shutdownGracefully(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return group.shutdownGracefully(quietPeriod, timeout, unit); + } + + @Override + public Future terminationFuture() { + return group.terminationFuture(); + } + + 
@SuppressWarnings("deprecation") + @Override + public void shutdown() { + group.shutdown(); + } + + @SuppressWarnings("deprecation") + @Override + public List shutdownNow() { + return group.shutdownNow(); + } + + @Override + public EventExecutor next() { + return newExecutor(group.next()); + } + + @Override + public Iterator iterator() { + final Iterator itr = group.iterator(); + return new Iterator() { + @Override + public boolean hasNext() { + return itr.hasNext(); + } + + @Override + public EventExecutor next() { + return newExecutor(itr.next()); + } + + @Override + public void remove() { + itr.remove(); + } + }; + } + + @Override + public Future submit(Runnable task) { + return group.submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + return group.submit(task, result); + } + + @Override + public Future submit(Callable task) { + return group.submit(task); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return group.schedule(command, delay, unit); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return group.schedule(callable, delay, unit); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + return group.scheduleAtFixedRate(command, initialDelay, period, unit); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { + return group.scheduleWithFixedDelay(command, initialDelay, delay, unit); + } + + @Override + public boolean isShutdown() { + return group.isShutdown(); + } + + @Override + public boolean isTerminated() { + return group.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return group.awaitTermination(timeout, unit); + } + + @Override + public List> invokeAll( + Collection> tasks) throws 
InterruptedException { + return group.invokeAll(tasks); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { + return group.invokeAll(tasks, timeout, unit); + } + + @Override + public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + return group.invokeAny(tasks); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + return group.invokeAny(tasks, timeout, unit); + } + + @Override + public void execute(Runnable command) { + group.execute(command); + } + + private static final class NonStickyOrderedEventExecutor extends AbstractEventExecutor + implements Runnable, OrderedEventExecutor { + private final EventExecutor executor; + private final Queue tasks = PlatformDependent.newMpscQueue(); + + private static final int NONE = 0; + private static final int SUBMITTED = 1; + private static final int RUNNING = 2; + + private final AtomicInteger state = new AtomicInteger(); + private final int maxTaskExecutePerRun; + + private final AtomicReference executingThread = new AtomicReference(); + + NonStickyOrderedEventExecutor(EventExecutor executor, int maxTaskExecutePerRun) { + super(executor); + this.executor = executor; + this.maxTaskExecutePerRun = maxTaskExecutePerRun; + } + + @Override + public void run() { + if (!state.compareAndSet(SUBMITTED, RUNNING)) { + return; + } + Thread current = Thread.currentThread(); + executingThread.set(current); + for (; ; ) { + int i = 0; + try { + for (; i < maxTaskExecutePerRun; i++) { + Runnable task = tasks.poll(); + if (task == null) { + break; + } + safeExecute(task); + } + } finally { + if (i == maxTaskExecutePerRun) { + try { + state.set(SUBMITTED); + // Only set executingThread to null if no other thread did update it yet. 
+ executingThread.compareAndSet(current, null); + executor.execute(this); + return; // done + } catch (Throwable ignore) { + // Reset the state back to running as we will keep on executing tasks. + state.set(RUNNING); + // if an error happened we should just ignore it and let the loop run again as there is not + // much else we can do. Most likely this was triggered by a full task queue. In this case + // we just will run more tasks and try again later. + } + } else { + state.set(NONE); + // After setting the state to NONE, look at the tasks queue one more time. + // If it is empty, then we can return from this method. + // Otherwise, it means the producer thread has called execute(Runnable) + // and enqueued a task in between the tasks.poll() above and the state.set(NONE) here. + // There are two possible scenarios when this happens + // + // 1. The producer thread sees state == NONE, hence the compareAndSet(NONE, SUBMITTED) + // is successfully setting the state to SUBMITTED. This mean the producer + // will call / has called executor.execute(this). In this case, we can just return. + // 2. The producer thread don't see the state change, hence the compareAndSet(NONE, SUBMITTED) + // returns false. In this case, the producer thread won't call executor.execute. + // In this case, we need to change the state to RUNNING and keeps running. + // + // The above cases can be distinguished by performing a + // compareAndSet(NONE, RUNNING). If it returns "false", it is case 1; otherwise it is case 2. + if (tasks.isEmpty() || !state.compareAndSet(NONE, RUNNING)) { + // Only set executingThread to null if no other thread did update it yet. 
+ executingThread.compareAndSet(current, null); + return; // done + } + } + } + } + } + + @Override + public boolean inEventLoop(Thread thread) { + return executingThread.get() == thread; + } + + @Override + public boolean isShuttingDown() { + return executor.isShutdown(); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return executor.shutdownGracefully(quietPeriod, timeout, unit); + } + + @Override + public Future terminationFuture() { + return executor.terminationFuture(); + } + + @Override + public void shutdown() { + executor.shutdown(); + } + + @Override + public boolean isShutdown() { + return executor.isShutdown(); + } + + @Override + public boolean isTerminated() { + return executor.isTerminated(); + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return executor.awaitTermination(timeout, unit); + } + + @Override + public void execute(Runnable command) { + if (!tasks.offer(command)) { + throw new RejectedExecutionException(); + } + if (state.compareAndSet(NONE, SUBMITTED)) { + // Actually it could happen that the runnable was picked up in between but we not care to much and just + // execute ourself. At worst this will be a NOOP when run() is called. + executor.execute(this); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/OrderedEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/OrderedEventExecutor.java new file mode 100644 index 0000000..b77a5ad --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/OrderedEventExecutor.java @@ -0,0 +1,22 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * Marker interface for {@link EventExecutor}s that will process all submitted tasks in an ordered / serial fashion. + */ +public interface OrderedEventExecutor extends EventExecutor { +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ProgressiveFuture.java b/netty-util/src/main/java/io/netty/util/concurrent/ProgressiveFuture.java new file mode 100644 index 0000000..d9f792a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ProgressiveFuture.java @@ -0,0 +1,47 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +/** + * A {@link Future} which is used to indicate the progress of an operation. + */ +public interface ProgressiveFuture extends Future { + + @Override + ProgressiveFuture addListener(GenericFutureListener> listener); + + @Override + ProgressiveFuture addListeners(GenericFutureListener>... 
listeners); + + @Override + ProgressiveFuture removeListener(GenericFutureListener> listener); + + @Override + ProgressiveFuture removeListeners(GenericFutureListener>... listeners); + + @Override + ProgressiveFuture sync() throws InterruptedException; + + @Override + ProgressiveFuture syncUninterruptibly(); + + @Override + ProgressiveFuture await() throws InterruptedException; + + @Override + ProgressiveFuture awaitUninterruptibly(); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ProgressivePromise.java b/netty-util/src/main/java/io/netty/util/concurrent/ProgressivePromise.java new file mode 100644 index 0000000..b4ead0b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ProgressivePromise.java @@ -0,0 +1,65 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * Special {@link ProgressiveFuture} which is writable. + */ +public interface ProgressivePromise extends Promise, ProgressiveFuture { + + /** + * Sets the current progress of the operation and notifies the listeners that implement + * {@link GenericProgressiveFutureListener}. + */ + ProgressivePromise setProgress(long progress, long total); + + /** + * Tries to set the current progress of the operation and notifies the listeners that implement + * {@link GenericProgressiveFutureListener}. 
If the operation is already complete or the progress is out of range, + * this method does nothing but returning {@code false}. + */ + boolean tryProgress(long progress, long total); + + @Override + ProgressivePromise setSuccess(V result); + + @Override + ProgressivePromise setFailure(Throwable cause); + + @Override + ProgressivePromise addListener(GenericFutureListener> listener); + + @Override + ProgressivePromise addListeners(GenericFutureListener>... listeners); + + @Override + ProgressivePromise removeListener(GenericFutureListener> listener); + + @Override + ProgressivePromise removeListeners(GenericFutureListener>... listeners); + + @Override + ProgressivePromise await() throws InterruptedException; + + @Override + ProgressivePromise awaitUninterruptibly(); + + @Override + ProgressivePromise sync() throws InterruptedException; + + @Override + ProgressivePromise syncUninterruptibly(); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/Promise.java b/netty-util/src/main/java/io/netty/util/concurrent/Promise.java new file mode 100644 index 0000000..ab5b3cd --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/Promise.java @@ -0,0 +1,90 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * Special {@link Future} which is writable. 
+ */ +public interface Promise extends Future { + + /** + * Marks this future as a success and notifies all + * listeners. + *

+ * If it is success or failed already it will throw an {@link IllegalStateException}. + */ + Promise setSuccess(V result); + + /** + * Marks this future as a success and notifies all + * listeners. + * + * @return {@code true} if and only if successfully marked this future as + * a success. Otherwise {@code false} because this future is + * already marked as either a success or a failure. + */ + boolean trySuccess(V result); + + /** + * Marks this future as a failure and notifies all + * listeners. + *

+ * If it is success or failed already it will throw an {@link IllegalStateException}. + */ + Promise setFailure(Throwable cause); + + /** + * Marks this future as a failure and notifies all + * listeners. + * + * @return {@code true} if and only if successfully marked this future as + * a failure. Otherwise {@code false} because this future is + * already marked as either a success or a failure. + */ + boolean tryFailure(Throwable cause); + + /** + * Make this future impossible to cancel. + * + * @return {@code true} if and only if successfully marked this future as uncancellable or it is already done + * without being cancelled. {@code false} if this future has been cancelled already. + */ + boolean setUncancellable(); + + @Override + Promise addListener(GenericFutureListener> listener); + + @Override + Promise addListeners(GenericFutureListener>... listeners); + + @Override + Promise removeListener(GenericFutureListener> listener); + + @Override + Promise removeListeners(GenericFutureListener>... listeners); + + @Override + Promise await() throws InterruptedException; + + @Override + Promise awaitUninterruptibly(); + + @Override + Promise sync() throws InterruptedException; + + @Override + Promise syncUninterruptibly(); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/PromiseAggregator.java b/netty-util/src/main/java/io/netty/util/concurrent/PromiseAggregator.java new file mode 100644 index 0000000..2b9aadd --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/PromiseAggregator.java @@ -0,0 +1,110 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import java.util.LinkedHashSet; +import java.util.Set; + +/** + * @param the type of value returned by the {@link Future} + * @param the type of {@link Future} + * @deprecated Use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. + *

+ * {@link GenericFutureListener} implementation which consolidates multiple {@link Future}s + * into one, by listening to individual {@link Future}s and producing an aggregated result + * (success/failure) when all {@link Future}s have completed. + */ +@Deprecated +public class PromiseAggregator> implements GenericFutureListener { + + private final Promise aggregatePromise; + private final boolean failPending; + private Set> pendingPromises; + + /** + * Creates a new instance. + * + * @param aggregatePromise the {@link Promise} to notify + * @param failPending {@code true} to fail pending promises, false to leave them unaffected + */ + public PromiseAggregator(Promise aggregatePromise, boolean failPending) { + this.aggregatePromise = ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); + this.failPending = failPending; + } + + /** + * See {@link PromiseAggregator#PromiseAggregator(Promise, boolean)}. + * Defaults {@code failPending} to true. + */ + public PromiseAggregator(Promise aggregatePromise) { + this(aggregatePromise, true); + } + + /** + * Add the given {@link Promise}s to the aggregator. + */ + @SafeVarargs + public final PromiseAggregator add(Promise... 
promises) { + ObjectUtil.checkNotNull(promises, "promises"); + if (promises.length == 0) { + return this; + } + synchronized (this) { + if (pendingPromises == null) { + int size; + if (promises.length > 1) { + size = promises.length; + } else { + size = 2; + } + pendingPromises = new LinkedHashSet>(size); + } + for (Promise p : promises) { + if (p == null) { + continue; + } + pendingPromises.add(p); + p.addListener(this); + } + } + return this; + } + + @Override + public synchronized void operationComplete(F future) throws Exception { + if (pendingPromises == null) { + aggregatePromise.setSuccess(null); + } else { + pendingPromises.remove(future); + if (!future.isSuccess()) { + Throwable cause = future.cause(); + aggregatePromise.setFailure(cause); + if (failPending) { + for (Promise pendingFuture : pendingPromises) { + pendingFuture.setFailure(cause); + } + } + } else { + if (pendingPromises.isEmpty()) { + aggregatePromise.setSuccess(null); + } + } + } + } + +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/PromiseCombiner.java b/netty-util/src/main/java/io/netty/util/concurrent/PromiseCombiner.java new file mode 100644 index 0000000..40845fa --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/PromiseCombiner.java @@ -0,0 +1,176 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; + +/** + *

A promise combiner monitors the outcome of a number of discrete futures, then notifies a final, aggregate promise + * when all of the combined futures are finished. The aggregate promise will succeed if and only if all of the combined + * futures succeed. If any of the combined futures fail, the aggregate promise will fail. The cause failure for the + * aggregate promise will be the failure for one of the failed combined futures; if more than one of the combined + * futures fails, exactly which cause of failure will be assigned to the aggregate promise is undefined.

+ * + *

Callers may populate a promise combiner with any number of futures to be combined via the + * {@link PromiseCombiner#add(Future)} and {@link PromiseCombiner#addAll(Future[])} methods. When all futures to be + * combined have been added, callers must provide an aggregate promise to be notified when all combined promises have + * finished via the {@link PromiseCombiner#finish(Promise)} method.

+ * + *

This implementation is NOT thread-safe and all methods must be called + * from the {@link EventExecutor} thread.

+ */ +public final class PromiseCombiner { + private int expectedCount; + private int doneCount; + private Promise aggregatePromise; + private Throwable cause; + private final GenericFutureListener> listener = new GenericFutureListener>() { + @Override + public void operationComplete(final Future future) { + if (executor.inEventLoop()) { + operationComplete0(future); + } else { + executor.execute(new Runnable() { + @Override + public void run() { + operationComplete0(future); + } + }); + } + } + + private void operationComplete0(Future future) { + assert executor.inEventLoop(); + ++doneCount; + if (!future.isSuccess() && cause == null) { + cause = future.cause(); + } + if (doneCount == expectedCount && aggregatePromise != null) { + tryPromise(); + } + } + }; + + private final EventExecutor executor; + + /** + * Deprecated use {@link PromiseCombiner#PromiseCombiner(EventExecutor)}. + */ + @Deprecated + public PromiseCombiner() { + this(ImmediateEventExecutor.INSTANCE); + } + + /** + * The {@link EventExecutor} to use for notifications. You must call {@link #add(Future)}, {@link #addAll(Future[])} + * and {@link #finish(Promise)} from within the {@link EventExecutor} thread. + * + * @param executor the {@link EventExecutor} to use for notifications. + */ + public PromiseCombiner(EventExecutor executor) { + this.executor = ObjectUtil.checkNotNull(executor, "executor"); + } + + /** + * Adds a new promise to be combined. New promises may be added until an aggregate promise is added via the + * {@link PromiseCombiner#finish(Promise)} method. + * + * @param promise the promise to add to this promise combiner + * @deprecated Replaced by {@link PromiseCombiner#add(Future)}. + */ + @Deprecated + public void add(Promise promise) { + add((Future) promise); + } + + /** + * Adds a new future to be combined. New futures may be added until an aggregate promise is added via the + * {@link PromiseCombiner#finish(Promise)} method. 
+ * + * @param future the future to add to this promise combiner + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public void add(Future future) { + checkAddAllowed(); + checkInEventLoop(); + ++expectedCount; + future.addListener(listener); + } + + /** + * Adds new promises to be combined. New promises may be added until an aggregate promise is added via the + * {@link PromiseCombiner#finish(Promise)} method. + * + * @param promises the promises to add to this promise combiner + * @deprecated Replaced by {@link PromiseCombiner#addAll(Future[])} + */ + @Deprecated + public void addAll(Promise... promises) { + addAll((Future[]) promises); + } + + /** + * Adds new futures to be combined. New futures may be added until an aggregate promise is added via the + * {@link PromiseCombiner#finish(Promise)} method. + * + * @param futures the futures to add to this promise combiner + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public void addAll(Future... futures) { + for (Future future : futures) { + this.add(future); + } + } + + /** + *

Sets the promise to be notified when all combined futures have finished. If all combined futures succeed, + * then the aggregate promise will succeed. If one or more combined futures fails, then the aggregate promise will + * fail with the cause of one of the failed futures. If more than one combined future fails, then exactly which + * failure will be assigned to the aggregate promise is undefined.

+ * + *

After this method is called, no more futures may be added via the {@link PromiseCombiner#add(Future)} or + * {@link PromiseCombiner#addAll(Future[])} methods.

+ * + * @param aggregatePromise the promise to notify when all combined futures have finished + */ + public void finish(Promise aggregatePromise) { + ObjectUtil.checkNotNull(aggregatePromise, "aggregatePromise"); + checkInEventLoop(); + if (this.aggregatePromise != null) { + throw new IllegalStateException("Already finished"); + } + this.aggregatePromise = aggregatePromise; + if (doneCount == expectedCount) { + tryPromise(); + } + } + + private void checkInEventLoop() { + if (!executor.inEventLoop()) { + throw new IllegalStateException("Must be called from EventExecutor thread"); + } + } + + private boolean tryPromise() { + return (cause == null) ? aggregatePromise.trySuccess(null) : aggregatePromise.tryFailure(cause); + } + + private void checkAddAllowed() { + if (aggregatePromise != null) { + throw new IllegalStateException("Adding promises is not allowed after finished adding"); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/PromiseNotifier.java b/netty-util/src/main/java/io/netty/util/concurrent/PromiseNotifier.java new file mode 100644 index 0000000..9e99c4d --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/PromiseNotifier.java @@ -0,0 +1,133 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.internal.PromiseNotificationUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import static io.netty.util.internal.ObjectUtil.checkNotNull; +import static io.netty.util.internal.ObjectUtil.checkNotNullWithIAE; + +/** + * {@link GenericFutureListener} implementation which takes other {@link Promise}s + * and notifies them on completion. + * + * @param the type of value returned by the future + * @param the type of future + */ +public class PromiseNotifier> implements GenericFutureListener { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PromiseNotifier.class); + private final Promise[] promises; + private final boolean logNotifyFailure; + + /** + * Create a new instance. + * + * @param promises the {@link Promise}s to notify once this {@link GenericFutureListener} is notified. + */ + @SafeVarargs + public PromiseNotifier(Promise... promises) { + this(true, promises); + } + + /** + * Create a new instance. + * + * @param logNotifyFailure {@code true} if logging should be done in case notification fails. + * @param promises the {@link Promise}s to notify once this {@link GenericFutureListener} is notified. + */ + @SafeVarargs + public PromiseNotifier(boolean logNotifyFailure, Promise... promises) { + checkNotNull(promises, "promises"); + for (Promise promise : promises) { + checkNotNullWithIAE(promise, "promise"); + } + this.promises = promises.clone(); + this.logNotifyFailure = logNotifyFailure; + } + + /** + * Link the {@link Future} and {@link Promise} such that if the {@link Future} completes the {@link Promise} + * will be notified. Cancellation is propagated both ways such that if the {@link Future} is cancelled + * the {@link Promise} is cancelled and vise-versa. + * + * @param future the {@link Future} which will be used to listen to for notifying the {@link Promise}. 
+ * @param promise the {@link Promise} which will be notified + * @param the type of the value. + * @param the type of the {@link Future} + * @return the passed in {@link Future} + */ + public static > F cascade(final F future, final Promise promise) { + return cascade(true, future, promise); + } + + /** + * Link the {@link Future} and {@link Promise} such that if the {@link Future} completes the {@link Promise} + * will be notified. Cancellation is propagated both ways such that if the {@link Future} is cancelled + * the {@link Promise} is cancelled and vise-versa. + * + * @param logNotifyFailure {@code true} if logging should be done in case notification fails. + * @param future the {@link Future} which will be used to listen to for notifying the {@link Promise}. + * @param promise the {@link Promise} which will be notified + * @param the type of the value. + * @param the type of the {@link Future} + * @return the passed in {@link Future} + */ + @SuppressWarnings({"unchecked", "rawtypes"}) + public static > F cascade(boolean logNotifyFailure, final F future, + final Promise promise) { + promise.addListener(new FutureListener() { + @Override + public void operationComplete(Future f) { + if (f.isCancelled()) { + future.cancel(false); + } + } + }); + future.addListener(new PromiseNotifier(logNotifyFailure, promise) { + @Override + public void operationComplete(Future f) throws Exception { + if (promise.isCancelled() && f.isCancelled()) { + // Just return if we propagate a cancel from the promise to the future and both are notified already + return; + } + super.operationComplete(future); + } + }); + return future; + } + + @Override + public void operationComplete(F future) throws Exception { + InternalLogger internalLogger = logNotifyFailure ? 
logger : null; + if (future.isSuccess()) { + V result = future.get(); + for (Promise p : promises) { + PromiseNotificationUtil.trySuccess(p, result, internalLogger); + } + } else if (future.isCancelled()) { + for (Promise p : promises) { + PromiseNotificationUtil.tryCancel(p, internalLogger); + } + } else { + Throwable cause = future.cause(); + for (Promise p : promises) { + PromiseNotificationUtil.tryFailure(p, cause, internalLogger); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/PromiseTask.java b/netty-util/src/main/java/io/netty/util/concurrent/PromiseTask.java new file mode 100644 index 0000000..56bf747 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/PromiseTask.java @@ -0,0 +1,189 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.util.concurrent;

import java.util.concurrent.Callable;
import java.util.concurrent.RunnableFuture;

/**
 * A {@link DefaultPromise} that is also a {@link RunnableFuture}: it wraps a {@link Runnable}
 * or {@link Callable} and completes itself with the task's outcome when {@link #run()} is
 * invoked.
 * <p>
 * External completion is forbidden: the public mutators ({@link #setSuccess(Object)},
 * {@link #setFailure(Throwable)}, {@link #setUncancellable()}) throw
 * {@link IllegalStateException} and the try-variants return {@code false}; only the
 * package-private {@code *Internal} methods used by {@link #run()} (and subclasses such as
 * scheduled tasks) may complete the promise.
 */
class PromiseTask<V> extends DefaultPromise<V> implements RunnableFuture<V> {

    /** Adapts a {@link Runnable} plus a fixed result value into a {@link Callable}. */
    private static final class RunnableAdapter<T> implements Callable<T> {
        final Runnable task;
        final T result;

        RunnableAdapter(Runnable task, T result) {
            this.task = task;
            this.result = result;
        }

        @Override
        public T call() {
            task.run();
            return result;
        }

        @Override
        public String toString() {
            return "Callable(task: " + task + ", result: " + result + ')';
        }
    }

    // Sentinels stored in place of the real task once this promise completes, so the
    // (possibly large) task object becomes garbage-collectable while toString() stays
    // informative about how the task ended.
    private static final Runnable COMPLETED = new SentinelRunnable("COMPLETED");
    private static final Runnable CANCELLED = new SentinelRunnable("CANCELLED");
    private static final Runnable FAILED = new SentinelRunnable("FAILED");

    /** No-op runnable whose {@link #toString()} names the terminal state it represents. */
    private static class SentinelRunnable implements Runnable {
        private final String name;

        SentinelRunnable(String name) {
            this.name = name;
        }

        @Override
        public void run() {
        } // no-op

        @Override
        public String toString() {
            return name;
        }
    }

    // Strictly of type Callable or Runnable
    private Object task;

    PromiseTask(EventExecutor executor, Runnable runnable, V result) {
        super(executor);
        // Avoid the adapter allocation when there is no result to report.
        task = result == null ? runnable : new RunnableAdapter<V>(runnable, result);
    }

    PromiseTask(EventExecutor executor, Runnable runnable) {
        super(executor);
        task = runnable;
    }

    PromiseTask(EventExecutor executor, Callable<V> callable) {
        super(executor);
        task = callable;
    }

    @Override
    public final int hashCode() {
        // Identity semantics: each task instance is unique regardless of promise state.
        return System.identityHashCode(this);
    }

    @Override
    public final boolean equals(Object obj) {
        return this == obj;
    }

    /**
     * Executes the wrapped task and returns its result ({@code null} for a plain
     * {@link Runnable}).
     */
    @SuppressWarnings("unchecked")
    V runTask() throws Throwable {
        final Object task = this.task;
        if (task instanceof Callable) {
            return ((Callable<V>) task).call();
        }
        ((Runnable) task).run();
        return null;
    }

    @Override
    public void run() {
        try {
            if (setUncancellableInternal()) {
                V result = runTask();
                setSuccessInternal(result);
            }
        } catch (Throwable e) {
            setFailureInternal(e);
        }
    }

    /**
     * Replaces {@link #task} with the given sentinel when {@code done} is {@code true},
     * releasing the reference to the real task; returns {@code done} unchanged.
     */
    private boolean clearTaskAfterCompletion(boolean done, Runnable result) {
        if (done) {
            // The only time where it might be possible for the sentinel task
            // to be called is in the case of a periodic ScheduledFutureTask,
            // in which case it's a benign race with cancellation and the (null)
            // return value is not used.
            task = result;
        }
        return done;
    }

    @Override
    public final Promise<V> setFailure(Throwable cause) {
        throw new IllegalStateException();
    }

    protected final Promise<V> setFailureInternal(Throwable cause) {
        super.setFailure(cause);
        clearTaskAfterCompletion(true, FAILED);
        return this;
    }

    @Override
    public final boolean tryFailure(Throwable cause) {
        return false;
    }

    protected final boolean tryFailureInternal(Throwable cause) {
        return clearTaskAfterCompletion(super.tryFailure(cause), FAILED);
    }

    @Override
    public final Promise<V> setSuccess(V result) {
        throw new IllegalStateException();
    }

    protected final Promise<V> setSuccessInternal(V result) {
        super.setSuccess(result);
        clearTaskAfterCompletion(true, COMPLETED);
        return this;
    }

    @Override
    public final boolean trySuccess(V result) {
        return false;
    }

    protected final boolean trySuccessInternal(V result) {
        return clearTaskAfterCompletion(super.trySuccess(result), COMPLETED);
    }

    @Override
    public final boolean setUncancellable() {
        throw new IllegalStateException();
    }

    protected final boolean setUncancellableInternal() {
        return super.setUncancellable();
    }

    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        return clearTaskAfterCompletion(super.cancel(mayInterruptIfRunning), CANCELLED);
    }

    @Override
    protected StringBuilder toStringBuilder() {
        StringBuilder buf = super.toStringBuilder();
        // Replace the trailing ')' of the superclass representation so the task can be
        // appended inside the same parenthesized description.
        buf.setCharAt(buf.length() - 1, ',');

        return buf.append(" task: ")
                  .append(task)
                  .append(')');
    }
}
under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * Similar to {@link java.util.concurrent.RejectedExecutionHandler} but specific to {@link SingleThreadEventExecutor}. + */ +public interface RejectedExecutionHandler { + + /** + * Called when someone tried to add a task to {@link SingleThreadEventExecutor} but this failed due capacity + * restrictions. + */ + void rejected(Runnable task, SingleThreadEventExecutor executor); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/RejectedExecutionHandlers.java b/netty-util/src/main/java/io/netty/util/concurrent/RejectedExecutionHandlers.java new file mode 100644 index 0000000..7ccd1d3 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/RejectedExecutionHandlers.java @@ -0,0 +1,72 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
package io.netty.util.concurrent;

import io.netty.util.internal.ObjectUtil;
import java.util.concurrent.RejectedExecutionException;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.LockSupport;

/**
 * Expose helper methods which create different {@link RejectedExecutionHandler}s.
 */
public final class RejectedExecutionHandlers {
    private static final RejectedExecutionHandler REJECT = new RejectedExecutionHandler() {
        @Override
        public void rejected(Runnable task, SingleThreadEventExecutor executor) {
            throw new RejectedExecutionException();
        }
    };

    private RejectedExecutionHandlers() {
    }

    /**
     * Returns a {@link RejectedExecutionHandler} that will always just throw a
     * {@link RejectedExecutionException}.
     */
    public static RejectedExecutionHandler reject() {
        return REJECT;
    }

    /**
     * Tries to back off when the task can not be added due to restrictions for a configured
     * amount of time. This is only done if the task was added from outside of the event loop,
     * which means {@link EventExecutor#inEventLoop()} returns {@code false}.
     *
     * @param retries       number of times the handler re-attempts to enqueue the task
     * @param backoffAmount how long to park between attempts
     * @param unit          the {@link TimeUnit} of {@code backoffAmount}
     */
    public static RejectedExecutionHandler backoff(final int retries, long backoffAmount, TimeUnit unit) {
        ObjectUtil.checkPositive(retries, "retries");
        final long backOffNanos = unit.toNanos(backoffAmount);
        return new RejectedExecutionHandler() {
            @Override
            public void rejected(Runnable task, SingleThreadEventExecutor executor) {
                if (!executor.inEventLoop()) {
                    for (int i = 0; i < retries; i++) {
                        // Try to wake up the executor so it will empty its task queue.
                        executor.wakeup(false);

                        LockSupport.parkNanos(backOffNanos);
                        if (executor.offerTask(task)) {
                            return;
                        }
                    }
                }
                // Either we tried to add the task from within the EventLoop or we were not
                // able to add it even with backoff.
                throw new RejectedExecutionException();
            }
        };
    }
}
+ throw new RejectedExecutionException(); + } + }; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ScheduledFuture.java b/netty-util/src/main/java/io/netty/util/concurrent/ScheduledFuture.java new file mode 100644 index 0000000..dcc1d4a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ScheduledFuture.java @@ -0,0 +1,23 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * The result of a scheduled asynchronous operation. + */ +@SuppressWarnings("ClassNameSameAsAncestorName") +public interface ScheduledFuture extends Future, java.util.concurrent.ScheduledFuture { +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java b/netty-util/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java new file mode 100644 index 0000000..c3250c9 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ScheduledFutureTask.java @@ -0,0 +1,219 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
package io.netty.util.concurrent;

import io.netty.util.internal.DefaultPriorityQueue;
import io.netty.util.internal.PriorityQueueNode;
import java.util.concurrent.Callable;
import java.util.concurrent.Delayed;
import java.util.concurrent.TimeUnit;

/**
 * A {@link PromiseTask} scheduled for one-shot or periodic execution on an
 * {@link AbstractScheduledEventExecutor}. Ordering in the executor's scheduled-task priority
 * queue is by deadline, with the insertion-order {@code id} breaking ties.
 */
@SuppressWarnings("ComparableImplementedButEqualsNotOverridden")
final class ScheduledFutureTask<V> extends PromiseTask<V> implements ScheduledFuture<V>, PriorityQueueNode {
    // set once when added to priority queue; breaks deadline ties in compareTo()
    private long id;

    private long deadlineNanos;
    /* 0 - no repeat, >0 - repeat at fixed rate, <0 - repeat with fixed delay */
    private final long periodNanos;

    private int queueIndex = INDEX_NOT_IN_QUEUE;

    ScheduledFutureTask(AbstractScheduledEventExecutor executor,
                        Runnable runnable, long nanoTime) {

        super(executor, runnable);
        deadlineNanos = nanoTime;
        periodNanos = 0;
    }

    ScheduledFutureTask(AbstractScheduledEventExecutor executor,
                        Runnable runnable, long nanoTime, long period) {

        super(executor, runnable);
        deadlineNanos = nanoTime;
        periodNanos = validatePeriod(period);
    }

    ScheduledFutureTask(AbstractScheduledEventExecutor executor,
                        Callable<V> callable, long nanoTime, long period) {

        super(executor, callable);
        deadlineNanos = nanoTime;
        periodNanos = validatePeriod(period);
    }

    ScheduledFutureTask(AbstractScheduledEventExecutor executor,
                        Callable<V> callable, long nanoTime) {

        super(executor, callable);
        deadlineNanos = nanoTime;
        periodNanos = 0;
    }

    /** Rejects a zero period; positive and negative values encode the repeat mode. */
    private static long validatePeriod(long period) {
        if (period == 0) {
            throw new IllegalArgumentException("period: 0 (expected: != 0)");
        }
        return period;
    }

    /** Assigns the tie-breaking id exactly once (first caller wins). */
    ScheduledFutureTask<V> setId(long id) {
        if (this.id == 0L) {
            this.id = id;
        }
        return this;
    }

    @Override
    protected EventExecutor executor() {
        // Widens visibility for use within this package.
        return super.executor();
    }

    public long deadlineNanos() {
        return deadlineNanos;
    }

    void setConsumed() {
        // Optimization to avoid checking system clock again
        // after deadline has passed and task has been dequeued
        if (periodNanos == 0) {
            assert scheduledExecutor().getCurrentTimeNanos() >= deadlineNanos;
            deadlineNanos = 0L;
        }
    }

    public long delayNanos() {
        return delayNanos(scheduledExecutor().getCurrentTimeNanos());
    }

    /** A consumed deadline of 0 (see {@link #setConsumed()}) always maps to zero delay. */
    static long deadlineToDelayNanos(long currentTimeNanos, long deadlineNanos) {
        return deadlineNanos == 0L ? 0L : Math.max(0L, deadlineNanos - currentTimeNanos);
    }

    public long delayNanos(long currentTimeNanos) {
        return deadlineToDelayNanos(currentTimeNanos, deadlineNanos);
    }

    @Override
    public long getDelay(TimeUnit unit) {
        return unit.convert(delayNanos(), TimeUnit.NANOSECONDS);
    }

    @Override
    public int compareTo(Delayed o) {
        if (this == o) {
            return 0;
        }

        ScheduledFutureTask<?> that = (ScheduledFutureTask<?>) o;
        long d = deadlineNanos() - that.deadlineNanos();
        if (d < 0) {
            return -1;
        } else if (d > 0) {
            return 1;
        } else if (id < that.id) {
            return -1;
        } else {
            // Distinct tasks never share an id, so equal deadlines are ordered by id.
            assert id != that.id;
            return 1;
        }
    }

    @Override
    public void run() {
        assert executor().inEventLoop();
        try {
            if (delayNanos() > 0L) {
                // Not yet expired, need to add or remove from queue
                if (isCancelled()) {
                    scheduledExecutor().scheduledTaskQueue().removeTyped(this);
                } else {
                    scheduledExecutor().scheduleFromEventLoop(this);
                }
                return;
            }
            if (periodNanos == 0) {
                if (setUncancellableInternal()) {
                    V result = runTask();
                    setSuccessInternal(result);
                }
            } else {
                // Periodic task: check if cancelled before running.
                if (!isCancelled()) {
                    runTask();
                    if (!executor().isShutdown()) {
                        if (periodNanos > 0) {
                            // Fixed rate: next deadline relative to the previous deadline.
                            deadlineNanos += periodNanos;
                        } else {
                            // Fixed delay: next deadline relative to now (periodNanos < 0).
                            deadlineNanos = scheduledExecutor().getCurrentTimeNanos() - periodNanos;
                        }
                        if (!isCancelled()) {
                            scheduledExecutor().scheduledTaskQueue().add(this);
                        }
                    }
                }
            }
        } catch (Throwable cause) {
            setFailureInternal(cause);
        }
    }

    private AbstractScheduledEventExecutor scheduledExecutor() {
        return (AbstractScheduledEventExecutor) executor();
    }

    /**
     * {@inheritDoc}
     *
     * @param mayInterruptIfRunning this value has no effect in this implementation.
     */
    @Override
    public boolean cancel(boolean mayInterruptIfRunning) {
        boolean canceled = super.cancel(mayInterruptIfRunning);
        if (canceled) {
            scheduledExecutor().removeScheduled(this);
        }
        return canceled;
    }

    /** Cancels without unscheduling; used when the caller already owns queue removal. */
    boolean cancelWithoutRemove(boolean mayInterruptIfRunning) {
        return super.cancel(mayInterruptIfRunning);
    }

    @Override
    protected StringBuilder toStringBuilder() {
        StringBuilder buf = super.toStringBuilder();
        // Reopen the superclass representation so scheduling details fit inside it.
        buf.setCharAt(buf.length() - 1, ',');

        return buf.append(" deadline: ")
                  .append(deadlineNanos)
                  .append(", period: ")
                  .append(periodNanos)
                  .append(')');
    }

    @Override
    public int priorityQueueIndex(DefaultPriorityQueue<?> queue) {
        return queueIndex;
    }

    @Override
    public void priorityQueueIndex(DefaultPriorityQueue<?> queue, int i) {
        queueIndex = i;
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.PlatformDependent; +import io.netty.util.internal.SystemPropertyUtil; +import io.netty.util.internal.ThreadExecutorMap; +import io.netty.util.internal.UnstableApi; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.lang.Thread.State; +import java.util.ArrayList; +import java.util.Collection; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Queue; +import java.util.Set; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.Executor; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.TimeoutException; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import java.util.concurrent.atomic.AtomicReferenceFieldUpdater; + +/** + * Abstract base class for {@link OrderedEventExecutor}'s that execute all its submitted tasks in a single thread. 
+ */ +public abstract class SingleThreadEventExecutor extends AbstractScheduledEventExecutor implements OrderedEventExecutor { + + static final int DEFAULT_MAX_PENDING_EXECUTOR_TASKS = Math.max(16, + SystemPropertyUtil.getInt("io.netty.eventexecutor.maxPendingTasks", Integer.MAX_VALUE)); + + private static final InternalLogger logger = + InternalLoggerFactory.getInstance(SingleThreadEventExecutor.class); + + private static final int ST_NOT_STARTED = 1; + private static final int ST_STARTED = 2; + private static final int ST_SHUTTING_DOWN = 3; + private static final int ST_SHUTDOWN = 4; + private static final int ST_TERMINATED = 5; + + private static final Runnable NOOP_TASK = new Runnable() { + @Override + public void run() { + // Do nothing. + } + }; + + private static final AtomicIntegerFieldUpdater STATE_UPDATER = + AtomicIntegerFieldUpdater.newUpdater(SingleThreadEventExecutor.class, "state"); + private static final AtomicReferenceFieldUpdater PROPERTIES_UPDATER = + AtomicReferenceFieldUpdater.newUpdater( + SingleThreadEventExecutor.class, ThreadProperties.class, "threadProperties"); + + private final Queue taskQueue; + + private volatile Thread thread; + @SuppressWarnings("unused") + private volatile ThreadProperties threadProperties; + private final Executor executor; + private volatile boolean interrupted; + + private final CountDownLatch threadLock = new CountDownLatch(1); + private final Set shutdownHooks = new LinkedHashSet(); + private final boolean addTaskWakesUp; + private final int maxPendingTasks; + private final RejectedExecutionHandler rejectedExecutionHandler; + + private long lastExecutionTime; + + @SuppressWarnings({"FieldMayBeFinal", "unused"}) + private volatile int state = ST_NOT_STARTED; + + private volatile long gracefulShutdownQuietPeriod; + private volatile long gracefulShutdownTimeout; + private long gracefulShutdownStartTime; + + private final Promise terminationFuture = new DefaultPromise(GlobalEventExecutor.INSTANCE); + + /** + * 
Create a new instance + * + * @param parent the {@link EventExecutorGroup} which is the parent of this instance and belongs to it + * @param threadFactory the {@link ThreadFactory} which will be used for the used {@link Thread} + * @param addTaskWakesUp {@code true} if and only if invocation of {@link #addTask(Runnable)} will wake up the + * executor thread + */ + protected SingleThreadEventExecutor( + EventExecutorGroup parent, ThreadFactory threadFactory, boolean addTaskWakesUp) { + this(parent, new ThreadPerTaskExecutor(threadFactory), addTaskWakesUp); + } + + /** + * Create a new instance + * + * @param parent the {@link EventExecutorGroup} which is the parent of this instance and belongs to it + * @param threadFactory the {@link ThreadFactory} which will be used for the used {@link Thread} + * @param addTaskWakesUp {@code true} if and only if invocation of {@link #addTask(Runnable)} will wake up the + * executor thread + * @param maxPendingTasks the maximum number of pending tasks before new tasks will be rejected. + * @param rejectedHandler the {@link RejectedExecutionHandler} to use. 
+ */ + protected SingleThreadEventExecutor( + EventExecutorGroup parent, ThreadFactory threadFactory, + boolean addTaskWakesUp, int maxPendingTasks, RejectedExecutionHandler rejectedHandler) { + this(parent, new ThreadPerTaskExecutor(threadFactory), addTaskWakesUp, maxPendingTasks, rejectedHandler); + } + + /** + * Create a new instance + * + * @param parent the {@link EventExecutorGroup} which is the parent of this instance and belongs to it + * @param executor the {@link Executor} which will be used for executing + * @param addTaskWakesUp {@code true} if and only if invocation of {@link #addTask(Runnable)} will wake up the + * executor thread + */ + protected SingleThreadEventExecutor(EventExecutorGroup parent, Executor executor, boolean addTaskWakesUp) { + this(parent, executor, addTaskWakesUp, DEFAULT_MAX_PENDING_EXECUTOR_TASKS, RejectedExecutionHandlers.reject()); + } + + /** + * Create a new instance + * + * @param parent the {@link EventExecutorGroup} which is the parent of this instance and belongs to it + * @param executor the {@link Executor} which will be used for executing + * @param addTaskWakesUp {@code true} if and only if invocation of {@link #addTask(Runnable)} will wake up the + * executor thread + * @param maxPendingTasks the maximum number of pending tasks before new tasks will be rejected. + * @param rejectedHandler the {@link RejectedExecutionHandler} to use. 
+ */ + protected SingleThreadEventExecutor(EventExecutorGroup parent, Executor executor, + boolean addTaskWakesUp, int maxPendingTasks, + RejectedExecutionHandler rejectedHandler) { + super(parent); + this.addTaskWakesUp = addTaskWakesUp; + this.maxPendingTasks = Math.max(16, maxPendingTasks); + this.executor = ThreadExecutorMap.apply(executor, this); + taskQueue = newTaskQueue(this.maxPendingTasks); + rejectedExecutionHandler = ObjectUtil.checkNotNull(rejectedHandler, "rejectedHandler"); + } + + protected SingleThreadEventExecutor(EventExecutorGroup parent, Executor executor, + boolean addTaskWakesUp, Queue taskQueue, + RejectedExecutionHandler rejectedHandler) { + super(parent); + this.addTaskWakesUp = addTaskWakesUp; + this.maxPendingTasks = DEFAULT_MAX_PENDING_EXECUTOR_TASKS; + this.executor = ThreadExecutorMap.apply(executor, this); + this.taskQueue = ObjectUtil.checkNotNull(taskQueue, "taskQueue"); + this.rejectedExecutionHandler = ObjectUtil.checkNotNull(rejectedHandler, "rejectedHandler"); + } + + /** + * @deprecated Please use and override {@link #newTaskQueue(int)}. + */ + @Deprecated + protected Queue newTaskQueue() { + return newTaskQueue(maxPendingTasks); + } + + /** + * Create a new {@link Queue} which will holds the tasks to execute. This default implementation will return a + * {@link LinkedBlockingQueue} but if your sub-class of {@link SingleThreadEventExecutor} will not do any blocking + * calls on the this {@link Queue} it may make sense to {@code @Override} this and return some more performant + * implementation that does not support blocking operations at all. + */ + protected Queue newTaskQueue(int maxPendingTasks) { + return new LinkedBlockingQueue(maxPendingTasks); + } + + /** + * Interrupt the current running {@link Thread}. 
+ */ + protected void interruptThread() { + Thread currentThread = thread; + if (currentThread == null) { + interrupted = true; + } else { + currentThread.interrupt(); + } + } + + /** + * @see Queue#poll() + */ + protected Runnable pollTask() { + assert inEventLoop(); + return pollTaskFrom(taskQueue); + } + + protected static Runnable pollTaskFrom(Queue taskQueue) { + for (; ; ) { + Runnable task = taskQueue.poll(); + if (task != WAKEUP_TASK) { + return task; + } + } + } + + /** + * Take the next {@link Runnable} from the task queue and so will block if no task is currently present. + *

+ * Be aware that this method will throw an {@link UnsupportedOperationException} if the task queue, which was + * created via {@link #newTaskQueue()}, does not implement {@link BlockingQueue}. + *

+ * + * @return {@code null} if the executor thread has been interrupted or waken up. + */ + protected Runnable takeTask() { + assert inEventLoop(); + if (!(taskQueue instanceof BlockingQueue)) { + throw new UnsupportedOperationException(); + } + + BlockingQueue taskQueue = (BlockingQueue) this.taskQueue; + for (; ; ) { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + if (scheduledTask == null) { + Runnable task = null; + try { + task = taskQueue.take(); + if (task == WAKEUP_TASK) { + task = null; + } + } catch (InterruptedException e) { + // Ignore + } + return task; + } else { + long delayNanos = scheduledTask.delayNanos(); + Runnable task = null; + if (delayNanos > 0) { + try { + task = taskQueue.poll(delayNanos, TimeUnit.NANOSECONDS); + } catch (InterruptedException e) { + // Waken up. + return null; + } + } + if (task == null) { + // We need to fetch the scheduled tasks now as otherwise there may be a chance that + // scheduled tasks are never executed if there is always one task in the taskQueue. + // This is for example true for the read task of OIO Transport + // See https://github.com/netty/netty/issues/1614 + fetchFromScheduledTaskQueue(); + task = taskQueue.poll(); + } + + if (task != null) { + return task; + } + } + } + } + + private boolean fetchFromScheduledTaskQueue() { + if (scheduledTaskQueue == null || scheduledTaskQueue.isEmpty()) { + return true; + } + long nanoTime = getCurrentTimeNanos(); + for (; ; ) { + Runnable scheduledTask = pollScheduledTask(nanoTime); + if (scheduledTask == null) { + return true; + } + if (!taskQueue.offer(scheduledTask)) { + // No space left in the task queue add it back to the scheduledTaskQueue so we pick it up again. + scheduledTaskQueue.add((ScheduledFutureTask) scheduledTask); + return false; + } + } + } + + /** + * @return {@code true} if at least one scheduled task was executed. 
+ */ + private boolean executeExpiredScheduledTasks() { + if (scheduledTaskQueue == null || scheduledTaskQueue.isEmpty()) { + return false; + } + long nanoTime = getCurrentTimeNanos(); + Runnable scheduledTask = pollScheduledTask(nanoTime); + if (scheduledTask == null) { + return false; + } + do { + safeExecute(scheduledTask); + } while ((scheduledTask = pollScheduledTask(nanoTime)) != null); + return true; + } + + /** + * @see Queue#peek() + */ + protected Runnable peekTask() { + assert inEventLoop(); + return taskQueue.peek(); + } + + /** + * @see Queue#isEmpty() + */ + protected boolean hasTasks() { + assert inEventLoop(); + return !taskQueue.isEmpty(); + } + + /** + * Return the number of tasks that are pending for processing. + */ + public int pendingTasks() { + return taskQueue.size(); + } + + /** + * Add a task to the task queue, or throws a {@link RejectedExecutionException} if this instance was shutdown + * before. + */ + protected void addTask(Runnable task) { + ObjectUtil.checkNotNull(task, "task"); + if (!offerTask(task)) { + reject(task); + } + } + + final boolean offerTask(Runnable task) { + if (isShutdown()) { + reject(); + } + return taskQueue.offer(task); + } + + /** + * @see Queue#remove(Object) + */ + protected boolean removeTask(Runnable task) { + return taskQueue.remove(ObjectUtil.checkNotNull(task, "task")); + } + + /** + * Poll all tasks from the task queue and run them via {@link Runnable#run()} method. + * + * @return {@code true} if and only if at least one task was run + */ + protected boolean runAllTasks() { + assert inEventLoop(); + boolean fetchedAll; + boolean ranAtLeastOne = false; + + do { + fetchedAll = fetchFromScheduledTaskQueue(); + if (runAllTasksFrom(taskQueue)) { + ranAtLeastOne = true; + } + } while (!fetchedAll); // keep on processing until we fetched all scheduled tasks. 
+ + if (ranAtLeastOne) { + lastExecutionTime = getCurrentTimeNanos(); + } + afterRunningAllTasks(); + return ranAtLeastOne; + } + + /** + * Execute all expired scheduled tasks and all current tasks in the executor queue until both queues are empty, + * or {@code maxDrainAttempts} has been exceeded. + * + * @param maxDrainAttempts The maximum amount of times this method attempts to drain from queues. This is to prevent + * continuous task execution and scheduling from preventing the EventExecutor thread to + * make progress and return to the selector mechanism to process inbound I/O events. + * @return {@code true} if at least one task was run. + */ + protected final boolean runScheduledAndExecutorTasks(final int maxDrainAttempts) { + assert inEventLoop(); + boolean ranAtLeastOneTask; + int drainAttempt = 0; + do { + // We must run the taskQueue tasks first, because the scheduled tasks from outside the EventLoop are queued + // here because the taskQueue is thread safe and the scheduledTaskQueue is not thread safe. + ranAtLeastOneTask = runExistingTasksFrom(taskQueue) | executeExpiredScheduledTasks(); + } while (ranAtLeastOneTask && ++drainAttempt < maxDrainAttempts); + + if (drainAttempt > 0) { + lastExecutionTime = getCurrentTimeNanos(); + } + afterRunningAllTasks(); + + return drainAttempt > 0; + } + + /** + * Runs all tasks from the passed {@code taskQueue}. + * + * @param taskQueue To poll and execute all tasks. + * @return {@code true} if at least one task was executed. + */ + protected final boolean runAllTasksFrom(Queue taskQueue) { + Runnable task = pollTaskFrom(taskQueue); + if (task == null) { + return false; + } + for (; ; ) { + safeExecute(task); + task = pollTaskFrom(taskQueue); + if (task == null) { + return true; + } + } + } + + /** + * What ever tasks are present in {@code taskQueue} when this method is invoked will be {@link Runnable#run()}. + * + * @param taskQueue the task queue to drain. 
+ * @return {@code true} if at least {@link Runnable#run()} was called. + */ + private boolean runExistingTasksFrom(Queue taskQueue) { + Runnable task = pollTaskFrom(taskQueue); + if (task == null) { + return false; + } + int remaining = Math.min(maxPendingTasks, taskQueue.size()); + safeExecute(task); + // Use taskQueue.poll() directly rather than pollTaskFrom() since the latter may + // silently consume more than one item from the queue (skips over WAKEUP_TASK instances) + while (remaining-- > 0 && (task = taskQueue.poll()) != null) { + safeExecute(task); + } + return true; + } + + /** + * Poll all tasks from the task queue and run them via {@link Runnable#run()} method. This method stops running + * the tasks in the task queue and returns if it ran longer than {@code timeoutNanos}. + */ + protected boolean runAllTasks(long timeoutNanos) { + fetchFromScheduledTaskQueue(); + Runnable task = pollTask(); + if (task == null) { + afterRunningAllTasks(); + return false; + } + + final long deadline = timeoutNanos > 0 ? getCurrentTimeNanos() + timeoutNanos : 0; + long runTasks = 0; + long lastExecutionTime; + for (; ; ) { + safeExecute(task); + + runTasks++; + + // Check timeout every 64 tasks because nanoTime() is relatively expensive. + // XXX: Hard-coded value - will make it configurable if it is really a problem. + if ((runTasks & 0x3F) == 0) { + lastExecutionTime = getCurrentTimeNanos(); + if (lastExecutionTime >= deadline) { + break; + } + } + + task = pollTask(); + if (task == null) { + lastExecutionTime = getCurrentTimeNanos(); + break; + } + } + + afterRunningAllTasks(); + this.lastExecutionTime = lastExecutionTime; + return true; + } + + /** + * Invoked before returning from {@link #runAllTasks()} and {@link #runAllTasks(long)}. + */ + @UnstableApi + protected void afterRunningAllTasks() { + } + + /** + * Returns the amount of time left until the scheduled task with the closest dead line is executed. 
+ */ + protected long delayNanos(long currentTimeNanos) { + currentTimeNanos -= initialNanoTime(); + + ScheduledFutureTask scheduledTask = peekScheduledTask(); + if (scheduledTask == null) { + return SCHEDULE_PURGE_INTERVAL; + } + + return scheduledTask.delayNanos(currentTimeNanos); + } + + /** + * Returns the absolute point in time (relative to {@link #getCurrentTimeNanos()}) at which the next + * closest scheduled task should run. + */ + @UnstableApi + protected long deadlineNanos() { + ScheduledFutureTask scheduledTask = peekScheduledTask(); + if (scheduledTask == null) { + return getCurrentTimeNanos() + SCHEDULE_PURGE_INTERVAL; + } + return scheduledTask.deadlineNanos(); + } + + /** + * Updates the internal timestamp that tells when a submitted task was executed most recently. + * {@link #runAllTasks()} and {@link #runAllTasks(long)} updates this timestamp automatically, and thus there's + * usually no need to call this method. However, if you take the tasks manually using {@link #takeTask()} or + * {@link #pollTask()}, you have to call this method at the end of task execution loop for accurate quiet period + * checks. + */ + protected void updateLastExecutionTime() { + lastExecutionTime = getCurrentTimeNanos(); + } + + /** + * Run the tasks in the {@link #taskQueue} + */ + protected abstract void run(); + + /** + * Do nothing, sub-classes may override + */ + protected void cleanup() { + // NOOP + } + + protected void wakeup(boolean inEventLoop) { + if (!inEventLoop) { + // Use offer as we actually only need this to unblock the thread and if offer fails we do not care as there + // is already something in the queue. 
+ taskQueue.offer(WAKEUP_TASK); + } + } + + @Override + public boolean inEventLoop(Thread thread) { + return thread == this.thread; + } + + /** + * Add a {@link Runnable} which will be executed on shutdown of this instance + */ + public void addShutdownHook(final Runnable task) { + if (inEventLoop()) { + shutdownHooks.add(task); + } else { + execute(new Runnable() { + @Override + public void run() { + shutdownHooks.add(task); + } + }); + } + } + + /** + * Remove a previous added {@link Runnable} as a shutdown hook + */ + public void removeShutdownHook(final Runnable task) { + if (inEventLoop()) { + shutdownHooks.remove(task); + } else { + execute(new Runnable() { + @Override + public void run() { + shutdownHooks.remove(task); + } + }); + } + } + + private boolean runShutdownHooks() { + boolean ran = false; + // Note shutdown hooks can add / remove shutdown hooks. + while (!shutdownHooks.isEmpty()) { + List copy = new ArrayList(shutdownHooks); + shutdownHooks.clear(); + for (Runnable task : copy) { + try { + runTask(task); + } catch (Throwable t) { + logger.warn("Shutdown hook raised an exception.", t); + } finally { + ran = true; + } + } + } + + if (ran) { + lastExecutionTime = getCurrentTimeNanos(); + } + + return ran; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + ObjectUtil.checkPositiveOrZero(quietPeriod, "quietPeriod"); + if (timeout < quietPeriod) { + throw new IllegalArgumentException( + "timeout: " + timeout + " (expected >= quietPeriod (" + quietPeriod + "))"); + } + ObjectUtil.checkNotNull(unit, "unit"); + + if (isShuttingDown()) { + return terminationFuture(); + } + + boolean inEventLoop = inEventLoop(); + boolean wakeup; + int oldState; + for (; ; ) { + if (isShuttingDown()) { + return terminationFuture(); + } + int newState; + wakeup = true; + oldState = state; + if (inEventLoop) { + newState = ST_SHUTTING_DOWN; + } else { + switch (oldState) { + case ST_NOT_STARTED: + case ST_STARTED: + newState 
= ST_SHUTTING_DOWN; + break; + default: + newState = oldState; + wakeup = false; + } + } + if (STATE_UPDATER.compareAndSet(this, oldState, newState)) { + break; + } + } + gracefulShutdownQuietPeriod = unit.toNanos(quietPeriod); + gracefulShutdownTimeout = unit.toNanos(timeout); + + if (ensureThreadStarted(oldState)) { + return terminationFuture; + } + + if (wakeup) { + taskQueue.offer(WAKEUP_TASK); + if (!addTaskWakesUp) { + wakeup(inEventLoop); + } + } + + return terminationFuture(); + } + + @Override + public Future terminationFuture() { + return terminationFuture; + } + + @Override + @Deprecated + public void shutdown() { + if (isShutdown()) { + return; + } + + boolean inEventLoop = inEventLoop(); + boolean wakeup; + int oldState; + for (; ; ) { + if (isShuttingDown()) { + return; + } + int newState; + wakeup = true; + oldState = state; + if (inEventLoop) { + newState = ST_SHUTDOWN; + } else { + switch (oldState) { + case ST_NOT_STARTED: + case ST_STARTED: + case ST_SHUTTING_DOWN: + newState = ST_SHUTDOWN; + break; + default: + newState = oldState; + wakeup = false; + } + } + if (STATE_UPDATER.compareAndSet(this, oldState, newState)) { + break; + } + } + + if (ensureThreadStarted(oldState)) { + return; + } + + if (wakeup) { + taskQueue.offer(WAKEUP_TASK); + if (!addTaskWakesUp) { + wakeup(inEventLoop); + } + } + } + + @Override + public boolean isShuttingDown() { + return state >= ST_SHUTTING_DOWN; + } + + @Override + public boolean isShutdown() { + return state >= ST_SHUTDOWN; + } + + @Override + public boolean isTerminated() { + return state == ST_TERMINATED; + } + + /** + * Confirm that the shutdown if the instance should be done now! 
 */
protected boolean confirmShutdown() {
    if (!isShuttingDown()) {
        return false;
    }

    // The quiet-period bookkeeping below relies on single-threaded access.
    if (!inEventLoop()) {
        throw new IllegalStateException("must be invoked from an event loop");
    }

    cancelScheduledTasks();

    if (gracefulShutdownStartTime == 0) {
        // First confirmShutdown() call after shutdown was requested: start the grace timer.
        gracefulShutdownStartTime = getCurrentTimeNanos();
    }

    if (runAllTasks() || runShutdownHooks()) {
        if (isShutdown()) {
            // Executor shut down - no new tasks anymore.
            return true;
        }

        // There were tasks in the queue. Wait a little bit more until no tasks are queued for the quiet period or
        // terminate if the quiet period is 0.
        // See https://github.com/netty/netty/issues/4241
        if (gracefulShutdownQuietPeriod == 0) {
            return true;
        }
        taskQueue.offer(WAKEUP_TASK);
        return false;
    }

    final long nanoTime = getCurrentTimeNanos();

    // Hard timeout elapsed, or a hard shutdown() was requested: stop regardless of pending activity.
    if (isShutdown() || nanoTime - gracefulShutdownStartTime > gracefulShutdownTimeout) {
        return true;
    }

    if (nanoTime - lastExecutionTime <= gracefulShutdownQuietPeriod) {
        // Check if any tasks were added to the queue every 100ms.
        // TODO: Change the behavior of takeTask() so that it returns on timeout.
        taskQueue.offer(WAKEUP_TASK);
        try {
            Thread.sleep(100);
        } catch (InterruptedException e) {
            // Ignore
        }

        return false;
    }

    // No tasks were added for last quiet period - hopefully safe to shut down.
    // (Hopefully because we really cannot make a guarantee that there will be no execute() calls by a user.)
+ return true; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + ObjectUtil.checkNotNull(unit, "unit"); + if (inEventLoop()) { + throw new IllegalStateException("cannot await termination of the current thread"); + } + + threadLock.await(timeout, unit); + + return isTerminated(); + } + + @Override + public void execute(Runnable task) { + execute0(task); + } + + @Override + public void lazyExecute(Runnable task) { + lazyExecute0(task); + } + + private void execute0(Runnable task) { + ObjectUtil.checkNotNull(task, "task"); + execute(task, wakesUpForTask(task)); + } + + private void lazyExecute0(Runnable task) { + execute(ObjectUtil.checkNotNull(task, "task"), false); + } + + private void execute(Runnable task, boolean immediate) { + boolean inEventLoop = inEventLoop(); + addTask(task); + if (!inEventLoop) { + startThread(); + if (isShutdown()) { + boolean reject = false; + try { + if (removeTask(task)) { + reject = true; + } + } catch (UnsupportedOperationException e) { + // The task queue does not support removal so the best thing we can do is to just move on and + // hope we will be able to pick-up the task before its completely terminated. + // In worst case we will log on termination. 
+ } + if (reject) { + reject(); + } + } + } + + if (!addTaskWakesUp && immediate) { + wakeup(inEventLoop); + } + } + + @Override + public T invokeAny(Collection> tasks) throws InterruptedException, ExecutionException { + throwIfInEventLoop("invokeAny"); + return super.invokeAny(tasks); + } + + @Override + public T invokeAny(Collection> tasks, long timeout, TimeUnit unit) + throws InterruptedException, ExecutionException, TimeoutException { + throwIfInEventLoop("invokeAny"); + return super.invokeAny(tasks, timeout, unit); + } + + @Override + public List> invokeAll(Collection> tasks) + throws InterruptedException { + throwIfInEventLoop("invokeAll"); + return super.invokeAll(tasks); + } + + @Override + public List> invokeAll( + Collection> tasks, long timeout, TimeUnit unit) throws InterruptedException { + throwIfInEventLoop("invokeAll"); + return super.invokeAll(tasks, timeout, unit); + } + + private void throwIfInEventLoop(String method) { + if (inEventLoop()) { + throw new RejectedExecutionException("Calling " + method + " from within the EventLoop is not allowed"); + } + } + + /** + * Returns the {@link ThreadProperties} of the {@link Thread} that powers the {@link SingleThreadEventExecutor}. + * If the {@link SingleThreadEventExecutor} is not started yet, this operation will start it and block until + * it is fully started. 
+ */ + public final ThreadProperties threadProperties() { + ThreadProperties threadProperties = this.threadProperties; + if (threadProperties == null) { + Thread thread = this.thread; + if (thread == null) { + assert !inEventLoop(); + submit(NOOP_TASK).syncUninterruptibly(); + thread = this.thread; + assert thread != null; + } + + threadProperties = new DefaultThreadProperties(thread); + if (!PROPERTIES_UPDATER.compareAndSet(this, null, threadProperties)) { + threadProperties = this.threadProperties; + } + } + + return threadProperties; + } + + /** + * @deprecated override {@link SingleThreadEventExecutor#wakesUpForTask} to re-create this behaviour + */ + @Deprecated + protected interface NonWakeupRunnable extends LazyRunnable { + } + + /** + * Can be overridden to control which tasks require waking the {@link EventExecutor} thread + * if it is waiting so that they can be run immediately. + */ + protected boolean wakesUpForTask(Runnable task) { + return true; + } + + protected static void reject() { + throw new RejectedExecutionException("event executor terminated"); + } + + /** + * Offers the task to the associated {@link RejectedExecutionHandler}. + * + * @param task to reject. 
+ */ + protected final void reject(Runnable task) { + rejectedExecutionHandler.rejected(task, this); + } + + // ScheduledExecutorService implementation + + private static final long SCHEDULE_PURGE_INTERVAL = TimeUnit.SECONDS.toNanos(1); + + private void startThread() { + if (state == ST_NOT_STARTED) { + if (STATE_UPDATER.compareAndSet(this, ST_NOT_STARTED, ST_STARTED)) { + boolean success = false; + try { + doStartThread(); + success = true; + } finally { + if (!success) { + STATE_UPDATER.compareAndSet(this, ST_STARTED, ST_NOT_STARTED); + } + } + } + } + } + + private boolean ensureThreadStarted(int oldState) { + if (oldState == ST_NOT_STARTED) { + try { + doStartThread(); + } catch (Throwable cause) { + STATE_UPDATER.set(this, ST_TERMINATED); + terminationFuture.tryFailure(cause); + + if (!(cause instanceof Exception)) { + // Also rethrow as it may be an OOME for example + PlatformDependent.throwException(cause); + } + return true; + } + } + return false; + } + + private void doStartThread() { + assert thread == null; + executor.execute(new Runnable() { + @Override + public void run() { + thread = Thread.currentThread(); + if (interrupted) { + thread.interrupt(); + } + + boolean success = false; + updateLastExecutionTime(); + try { + SingleThreadEventExecutor.this.run(); + success = true; + } catch (Throwable t) { + logger.warn("Unexpected exception from an event executor: ", t); + } finally { + for (; ; ) { + int oldState = state; + if (oldState >= ST_SHUTTING_DOWN || STATE_UPDATER.compareAndSet( + SingleThreadEventExecutor.this, oldState, ST_SHUTTING_DOWN)) { + break; + } + } + + // Check if confirmShutdown() was called at the end of the loop. 
+ if (success && gracefulShutdownStartTime == 0) { + if (logger.isErrorEnabled()) { + logger.error("Buggy " + EventExecutor.class.getSimpleName() + " implementation; " + + SingleThreadEventExecutor.class.getSimpleName() + ".confirmShutdown() must " + + "be called before run() implementation terminates."); + } + } + + try { + // Run all remaining tasks and shutdown hooks. At this point the event loop + // is in ST_SHUTTING_DOWN state still accepting tasks which is needed for + // graceful shutdown with quietPeriod. + for (; ; ) { + if (confirmShutdown()) { + break; + } + } + + // Now we want to make sure no more tasks can be added from this point. This is + // achieved by switching the state. Any new tasks beyond this point will be rejected. + for (; ; ) { + int oldState = state; + if (oldState >= ST_SHUTDOWN || STATE_UPDATER.compareAndSet( + SingleThreadEventExecutor.this, oldState, ST_SHUTDOWN)) { + break; + } + } + + // We have the final set of tasks in the queue now, no more can be added, run all remaining. + // No need to loop here, this is the final pass. + confirmShutdown(); + } finally { + try { + cleanup(); + } finally { + // Lets remove all FastThreadLocals for the Thread as we are about to terminate and notify + // the future. The user may block on the future and once it unblocks the JVM may terminate + // and start unloading classes. + // See https://github.com/netty/netty/issues/6596. 
+ FastThreadLocal.removeAll(); + + STATE_UPDATER.set(SingleThreadEventExecutor.this, ST_TERMINATED); + threadLock.countDown(); + int numUserTasks = drainTasks(); + if (numUserTasks > 0 && logger.isWarnEnabled()) { + logger.warn("An event executor terminated with " + + "non-empty task queue (" + numUserTasks + ')'); + } + terminationFuture.setSuccess(null); + } + } + } + } + }); + } + + final int drainTasks() { + int numTasks = 0; + for (; ; ) { + Runnable runnable = taskQueue.poll(); + if (runnable == null) { + break; + } + // WAKEUP_TASK should be just discarded as these are added internally. + // The important bit is that we not have any user tasks left. + if (WAKEUP_TASK != runnable) { + numTasks++; + } + } + return numTasks; + } + + private static final class DefaultThreadProperties implements ThreadProperties { + private final Thread t; + + DefaultThreadProperties(Thread t) { + this.t = t; + } + + @Override + public State state() { + return t.getState(); + } + + @Override + public int priority() { + return t.getPriority(); + } + + @Override + public boolean isInterrupted() { + return t.isInterrupted(); + } + + @Override + public boolean isDaemon() { + return t.isDaemon(); + } + + @Override + public String name() { + return t.getName(); + } + + @Override + public long id() { + return t.getId(); + } + + @Override + public StackTraceElement[] stackTrace() { + return t.getStackTrace(); + } + + @Override + public boolean isAlive() { + return t.isAlive(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/SucceededFuture.java b/netty-util/src/main/java/io/netty/util/concurrent/SucceededFuture.java new file mode 100644 index 0000000..b8a2007 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/SucceededFuture.java @@ -0,0 +1,50 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + 
* with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +/** + * The {@link CompleteFuture} which is succeeded already. It is + * recommended to use {@link EventExecutor#newSucceededFuture(Object)} instead of + * calling the constructor of this future. + */ +public final class SucceededFuture extends CompleteFuture { + private final V result; + + /** + * Creates a new instance. + * + * @param executor the {@link EventExecutor} associated with this future + */ + public SucceededFuture(EventExecutor executor, V result) { + super(executor); + this.result = result; + } + + @Override + public Throwable cause() { + return null; + } + + @Override + public boolean isSuccess() { + return true; + } + + @Override + public V getNow() { + return result; + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ThreadPerTaskExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/ThreadPerTaskExecutor.java new file mode 100644 index 0000000..c38413c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/ThreadPerTaskExecutor.java @@ -0,0 +1,33 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util.concurrent;

import io.netty.util.internal.ObjectUtil;
import java.util.concurrent.Executor;
import java.util.concurrent.ThreadFactory;

/**
 * {@link Executor} that starts a brand-new thread from the configured {@link ThreadFactory}
 * for every task submitted via {@link #execute(Runnable)}. There is no pooling or queueing;
 * each thread terminates when its task completes.
 */
public final class ThreadPerTaskExecutor implements Executor {
    // Source of the per-task threads; rejected as null by the constructor.
    private final ThreadFactory threadFactory;

    /**
     * @param threadFactory factory used to create one thread per submitted task
     *                      (null is rejected via ObjectUtil.checkNotNull)
     */
    public ThreadPerTaskExecutor(ThreadFactory threadFactory) {
        this.threadFactory = ObjectUtil.checkNotNull(threadFactory, "threadFactory");
    }

    /**
     * Creates and immediately starts a new thread running {@code command}.
     */
    @Override
    public void execute(Runnable command) {
        threadFactory.newThread(command).start();
    }
}
diff --git a/netty-util/src/main/java/io/netty/util/concurrent/ThreadProperties.java b/netty-util/src/main/java/io/netty/util/concurrent/ThreadProperties.java
new file mode 100644
index 0000000..650ef7e
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/concurrent/ThreadProperties.java
@@ -0,0 +1,61 @@
/*
 * Copyright 2015 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
+ */ +package io.netty.util.concurrent; + +/** + * Expose details for a {@link Thread}. + */ +public interface ThreadProperties { + /** + * @see Thread#getState() + */ + Thread.State state(); + + /** + * @see Thread#getPriority() + */ + int priority(); + + /** + * @see Thread#isInterrupted() + */ + boolean isInterrupted(); + + /** + * @see Thread#isDaemon() + */ + boolean isDaemon(); + + /** + * @see Thread#getName() + */ + String name(); + + /** + * @see Thread#getId() + */ + long id(); + + /** + * @see Thread#getStackTrace() + */ + StackTraceElement[] stackTrace(); + + /** + * @see Thread#isAlive() + */ + boolean isAlive(); +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/UnaryPromiseNotifier.java b/netty-util/src/main/java/io/netty/util/concurrent/UnaryPromiseNotifier.java new file mode 100644 index 0000000..99e0ed4 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/UnaryPromiseNotifier.java @@ -0,0 +1,55 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.ObjectUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; + +/** + * @deprecated use {@link PromiseNotifier#cascade(boolean, Future, Promise)}. 
+ */ +@Deprecated +public final class UnaryPromiseNotifier implements FutureListener { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(UnaryPromiseNotifier.class); + private final Promise promise; + + public UnaryPromiseNotifier(Promise promise) { + this.promise = ObjectUtil.checkNotNull(promise, "promise"); + } + + @Override + public void operationComplete(Future future) throws Exception { + cascadeTo(future, promise); + } + + public static void cascadeTo(Future completedFuture, Promise promise) { + if (completedFuture.isSuccess()) { + if (!promise.trySuccess(completedFuture.getNow())) { + logger.warn("Failed to mark a promise as success because it is done already: {}", promise); + } + } else if (completedFuture.isCancelled()) { + if (!promise.cancel(false)) { + logger.warn("Failed to cancel a promise because it is done already: {}", promise); + } + } else { + if (!promise.tryFailure(completedFuture.cause())) { + logger.warn("Failed to mark a promise as failure because it's done already: {}", promise, + completedFuture.cause()); + } + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutor.java b/netty-util/src/main/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutor.java new file mode 100644 index 0000000..c6e8cef --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutor.java @@ -0,0 +1,293 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.util.Collections; +import java.util.Iterator; +import java.util.List; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.Delayed; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.RejectedExecutionHandler; +import java.util.concurrent.RunnableScheduledFuture; +import java.util.concurrent.ScheduledThreadPoolExecutor; +import java.util.concurrent.ThreadFactory; +import java.util.concurrent.TimeUnit; +import static java.util.concurrent.TimeUnit.NANOSECONDS; + +/** + * {@link EventExecutor} implementation which makes no guarantees about the ordering of task execution that + * are submitted because there may be multiple threads executing these tasks. + * This implementation is most useful for protocols that do not need strict ordering. + * + * Because it provides no ordering care should be taken when using it! + */ +public final class UnorderedThreadPoolEventExecutor extends ScheduledThreadPoolExecutor implements EventExecutor { + private static final InternalLogger logger = InternalLoggerFactory.getInstance( + UnorderedThreadPoolEventExecutor.class); + + private final Promise terminationFuture = GlobalEventExecutor.INSTANCE.newPromise(); + private final Set executorSet = Collections.singleton(this); + + /** + * Calls {@link UnorderedThreadPoolEventExecutor#UnorderedThreadPoolEventExecutor(int, ThreadFactory)} + * using {@link DefaultThreadFactory}. 
+ */ + public UnorderedThreadPoolEventExecutor(int corePoolSize) { + this(corePoolSize, new DefaultThreadFactory(UnorderedThreadPoolEventExecutor.class)); + } + + /** + * See {@link ScheduledThreadPoolExecutor#ScheduledThreadPoolExecutor(int, ThreadFactory)} + */ + public UnorderedThreadPoolEventExecutor(int corePoolSize, ThreadFactory threadFactory) { + super(corePoolSize, threadFactory); + } + + /** + * Calls {@link UnorderedThreadPoolEventExecutor#UnorderedThreadPoolEventExecutor(int, + * ThreadFactory, java.util.concurrent.RejectedExecutionHandler)} using {@link DefaultThreadFactory}. + */ + public UnorderedThreadPoolEventExecutor(int corePoolSize, RejectedExecutionHandler handler) { + this(corePoolSize, new DefaultThreadFactory(UnorderedThreadPoolEventExecutor.class), handler); + } + + /** + * See {@link ScheduledThreadPoolExecutor#ScheduledThreadPoolExecutor(int, ThreadFactory, RejectedExecutionHandler)} + */ + public UnorderedThreadPoolEventExecutor(int corePoolSize, ThreadFactory threadFactory, + RejectedExecutionHandler handler) { + super(corePoolSize, threadFactory, handler); + } + + @Override + public EventExecutor next() { + return this; + } + + @Override + public EventExecutorGroup parent() { + return this; + } + + @Override + public boolean inEventLoop() { + return false; + } + + @Override + public boolean inEventLoop(Thread thread) { + return false; + } + + @Override + public Promise newPromise() { + return new DefaultPromise(this); + } + + @Override + public ProgressivePromise newProgressivePromise() { + return new DefaultProgressivePromise(this); + } + + @Override + public Future newSucceededFuture(V result) { + return new SucceededFuture(this, result); + } + + @Override + public Future newFailedFuture(Throwable cause) { + return new FailedFuture(this, cause); + } + + @Override + public boolean isShuttingDown() { + return isShutdown(); + } + + @Override + public List shutdownNow() { + List tasks = super.shutdownNow(); + 
terminationFuture.trySuccess(null); + return tasks; + } + + @Override + public void shutdown() { + super.shutdown(); + terminationFuture.trySuccess(null); + } + + @Override + public Future shutdownGracefully() { + return shutdownGracefully(2, 15, TimeUnit.SECONDS); + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + // TODO: At the moment this just calls shutdown but we may be able to do something more smart here which + // respects the quietPeriod and timeout. + shutdown(); + return terminationFuture(); + } + + @Override + public Future terminationFuture() { + return terminationFuture; + } + + @Override + public Iterator iterator() { + return executorSet.iterator(); + } + + @Override + protected RunnableScheduledFuture decorateTask(Runnable runnable, RunnableScheduledFuture task) { + return runnable instanceof NonNotifyRunnable ? + task : new RunnableScheduledFutureTask(this, task, false); + } + + @Override + protected RunnableScheduledFuture decorateTask(Callable callable, RunnableScheduledFuture task) { + return new RunnableScheduledFutureTask(this, task, true); + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return (ScheduledFuture) super.schedule(command, delay, unit); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return (ScheduledFuture) super.schedule(callable, delay, unit); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + return (ScheduledFuture) super.scheduleAtFixedRate(command, initialDelay, period, unit); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, TimeUnit unit) { + return (ScheduledFuture) super.scheduleWithFixedDelay(command, initialDelay, delay, unit); + } + + @Override + public Future submit(Runnable task) { + return (Future) 
super.submit(task); + } + + @Override + public Future submit(Runnable task, T result) { + return (Future) super.submit(task, result); + } + + @Override + public Future submit(Callable task) { + return (Future) super.submit(task); + } + + @Override + public void execute(Runnable command) { + super.schedule(new NonNotifyRunnable(command), 0, NANOSECONDS); + } + + private static final class RunnableScheduledFutureTask extends PromiseTask + implements RunnableScheduledFuture, ScheduledFuture { + private final RunnableScheduledFuture future; + private final boolean wasCallable; + + RunnableScheduledFutureTask(EventExecutor executor, RunnableScheduledFuture future, boolean wasCallable) { + super(executor, future); + this.future = future; + this.wasCallable = wasCallable; + } + + @Override + V runTask() throws Throwable { + V result = super.runTask(); + if (result == null && wasCallable) { + // If this RunnableScheduledFutureTask wraps a RunnableScheduledFuture that wraps a Callable we need + // to ensure that we return the correct result by calling future.get(). + // + // See https://github.com/netty/netty/issues/11072 + assert future.isDone(); + try { + return future.get(); + } catch (ExecutionException e) { + // unwrap exception. + throw e.getCause(); + } + } + return result; + } + + @Override + public void run() { + if (!isPeriodic()) { + super.run(); + } else if (!isDone()) { + try { + // Its a periodic task so we need to ignore the return value + runTask(); + } catch (Throwable cause) { + if (!tryFailureInternal(cause)) { + logger.warn("Failure during execution of task", cause); + } + } + } + } + + @Override + public boolean isPeriodic() { + return future.isPeriodic(); + } + + @Override + public long getDelay(TimeUnit unit) { + return future.getDelay(unit); + } + + @Override + public int compareTo(Delayed o) { + return future.compareTo(o); + } + } + + // This is a special wrapper which we will be used in execute(...) to wrap the submitted Runnable. 
This is needed as + // ScheduledThreadPoolExecutor.execute(...) will delegate to submit(...) which will then use decorateTask(...). + // The problem with this is that decorateTask(...) needs to ensure we only do our own decoration if we not call + // from execute(...) as otherwise we may end up creating an endless loop because DefaultPromise will call + // EventExecutor.execute(...) when notify the listeners of the promise. + // + // See https://github.com/netty/netty/issues/6507 + private static final class NonNotifyRunnable implements Runnable { + + private final Runnable task; + + NonNotifyRunnable(Runnable task) { + this.task = task; + } + + @Override + public void run() { + task.run(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/concurrent/package-info.java b/netty-util/src/main/java/io/netty/util/concurrent/package-info.java new file mode 100644 index 0000000..ac6077b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/concurrent/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Utility classes for concurrent / async tasks. 
+ */ +package io.netty.util.concurrent; diff --git a/netty-util/src/main/java/io/netty/util/internal/AppendableCharSequence.java b/netty-util/src/main/java/io/netty/util/internal/AppendableCharSequence.java new file mode 100644 index 0000000..82d2fed --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/AppendableCharSequence.java @@ -0,0 +1,169 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
 */
package io.netty.util.internal;

import java.util.Arrays;
import static io.netty.util.internal.ObjectUtil.checkNonEmpty;
import static io.netty.util.internal.ObjectUtil.checkPositive;

/**
 * A resizable character buffer implementing both {@link CharSequence} and {@link Appendable}.
 * Backed by a plain char array that doubles in size as needed; not thread-safe.
 */
public final class AppendableCharSequence implements CharSequence, Appendable {
    // Backing storage; grows by doubling, never shrinks.
    private char[] chars;
    // Logical length: number of valid chars in 'chars'.
    private int pos;

    public AppendableCharSequence(int length) {
        chars = new char[checkPositive(length, "length")];
    }

    // Takes ownership of 'chars'; used by subSequence() for copied ranges.
    private AppendableCharSequence(char[] chars) {
        this.chars = checkNonEmpty(chars, "chars");
        pos = chars.length;
    }

    /**
     * Truncates the logical length to {@code length}; cannot grow the sequence.
     */
    public void setLength(int length) {
        if (length < 0 || length > pos) {
            throw new IllegalArgumentException("length: " + length + " (length: >= 0, <= " + pos + ')');
        }
        this.pos = length;
    }

    @Override
    public int length() {
        return pos;
    }

    @Override
    public char charAt(int index) {
        // NOTE(review): this accepts index == pos (one past the logical end), which the strict
        // CharSequence contract would reject - confirm whether callers rely on the relaxed check.
        if (index > pos) {
            throw new IndexOutOfBoundsException();
        }
        return chars[index];
    }

    /**
     * Access a value in this {@link CharSequence}.
     * This method is considered unsafe as index values are assumed to be legitimate.
     * Only underlying array bounds checking is done.
     *
     * @param index The index to access the underlying array at.
     * @return The value at {@code index}.
     */
    public char charAtUnsafe(int index) {
        return chars[index];
    }

    @Override
    public AppendableCharSequence subSequence(int start, int end) {
        if (start == end) {
            // If start and end index is the same we need to return an empty sequence to conform to the interface.
            // As our expanding logic depends on the fact that we have a char[] with length > 0 we need to construct
            // an instance for which this is true.
+ return new AppendableCharSequence(Math.min(16, chars.length)); + } + return new AppendableCharSequence(Arrays.copyOfRange(chars, start, end)); + } + + @Override + public AppendableCharSequence append(char c) { + if (pos == chars.length) { + char[] old = chars; + chars = new char[old.length << 1]; + System.arraycopy(old, 0, chars, 0, old.length); + } + chars[pos++] = c; + return this; + } + + @Override + public AppendableCharSequence append(CharSequence csq) { + return append(csq, 0, csq.length()); + } + + @Override + public AppendableCharSequence append(CharSequence csq, int start, int end) { + if (csq.length() < end) { + throw new IndexOutOfBoundsException("expected: csq.length() >= (" + + end + "),but actual is (" + csq.length() + ")"); + } + int length = end - start; + if (length > chars.length - pos) { + chars = expand(chars, pos + length, pos); + } + if (csq instanceof AppendableCharSequence seq) { + // Optimize append operations via array copy + char[] src = seq.chars; + System.arraycopy(src, start, chars, pos, length); + pos += length; + return this; + } + for (int i = start; i < end; i++) { + chars[pos++] = csq.charAt(i); + } + + return this; + } + + /** + * Reset the {@link AppendableCharSequence}. Be aware this will only reset the current internal position and not + * shrink the internal char array. + */ + public void reset() { + pos = 0; + } + + @Override + public String toString() { + return new String(chars, 0, pos); + } + + /** + * Create a new {@link String} from the given start to end. + */ + public String substring(int start, int end) { + int length = end - start; + if (start > pos || length > pos) { + throw new IndexOutOfBoundsException("expected: start and length <= (" + + pos + ")"); + } + return new String(chars, start, length); + } + + /** + * Create a new {@link String} from the given start to end. + * This method is considered unsafe as index values are assumed to be legitimate. + * Only underlying array bounds checking is done. 
+ */ + public String subStringUnsafe(int start, int end) { + return new String(chars, start, end - start); + } + + private static char[] expand(char[] array, int neededSpace, int size) { + int newCapacity = array.length; + do { + // double capacity until it is big enough + newCapacity <<= 1; + + if (newCapacity < 0) { + throw new IllegalStateException(); + } + + } while (neededSpace > newCapacity); + + char[] newArray = new char[newCapacity]; + System.arraycopy(array, 0, newArray, 0, size); + + return newArray; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ClassInitializerUtil.java b/netty-util/src/main/java/io/netty/util/internal/ClassInitializerUtil.java new file mode 100644 index 0000000..624b20a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ClassInitializerUtil.java @@ -0,0 +1,49 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +/** + * Utility which ensures that classes are loaded by the {@link ClassLoader}. + */ +public final class ClassInitializerUtil { + + private ClassInitializerUtil() { + } + + /** + * Preload the given classes and so ensure the {@link ClassLoader} has these loaded after this method call. + * + * @param loadingClass the {@link Class} that wants to load the classes. + * @param classes the classes to load. 
+ */ + public static void tryLoadClasses(Class loadingClass, Class... classes) { + ClassLoader loader = PlatformDependent.getClassLoader(loadingClass); + for (Class clazz : classes) { + tryLoadClass(loader, clazz.getName()); + } + } + + private static void tryLoadClass(ClassLoader classLoader, String className) { + try { + // Load the class and also ensure we init it which means its linked etc. + Class.forName(className, true, classLoader); + } catch (ClassNotFoundException ignore) { + // Ignore + } catch (SecurityException ignore) { + // Ignore + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/Cleaner.java b/netty-util/src/main/java/io/netty/util/internal/Cleaner.java new file mode 100644 index 0000000..24ab1fe --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/Cleaner.java @@ -0,0 +1,29 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import java.nio.ByteBuffer; + +/** + * Allows to free direct {@link ByteBuffer}s. 
+ */ +interface Cleaner { + + /** + * Free a direct {@link ByteBuffer} if possible + */ + void freeDirectBuffer(ByteBuffer buffer); +} diff --git a/netty-util/src/main/java/io/netty/util/internal/CleanerJava9.java b/netty-util/src/main/java/io/netty/util/internal/CleanerJava9.java new file mode 100644 index 0000000..5bd0297 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/CleanerJava9.java @@ -0,0 +1,93 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; + +/** + * Provide a way to clean a ByteBuffer on Java9+. 
 */
final class CleanerJava9 implements Cleaner {
    private static final InternalLogger logger = InternalLoggerFactory.getInstance(CleanerJava9.class);

    // sun.misc.Unsafe#invokeCleaner(ByteBuffer) when available, otherwise null.
    private static final Method INVOKE_CLEANER;

    static {
        final Method method;
        final Throwable error;
        if (PlatformDependent0.hasUnsafe()) {
            // Probe with a throw-away 1-byte buffer so any failure surfaces here,
            // once, instead of on the first real free.
            final ByteBuffer buffer = ByteBuffer.allocateDirect(1);
            Object maybeInvokeMethod;
            try {
                // See https://bugs.openjdk.java.net/browse/JDK-8171377
                Method m = PlatformDependent0.UNSAFE.getClass().getDeclaredMethod(
                        "invokeCleaner", ByteBuffer.class);
                m.invoke(PlatformDependent0.UNSAFE, buffer);
                maybeInvokeMethod = m;
            } catch (NoSuchMethodException | InvocationTargetException | IllegalAccessException e) {
                maybeInvokeMethod = e;
            }
            // maybeInvokeMethod carries either the working Method or the failure cause.
            if (maybeInvokeMethod instanceof Throwable) {
                method = null;
                error = (Throwable) maybeInvokeMethod;
            } else {
                method = (Method) maybeInvokeMethod;
                error = null;
            }
        } else {
            method = null;
            error = new UnsupportedOperationException("sun.misc.Unsafe unavailable");
        }
        if (error == null) {
            logger.debug("java.nio.ByteBuffer.cleaner(): available");
        } else {
            logger.debug("java.nio.ByteBuffer.cleaner(): unavailable", error);
        }
        INVOKE_CLEANER = method;
    }

    // Whether this Cleaner implementation can actually free direct buffers on this JVM.
    static boolean isSupported() {
        return INVOKE_CLEANER != null;
    }

    @Override
    public void freeDirectBuffer(ByteBuffer buffer) {
        // Try to minimize overhead when there is no SecurityManager present.
        // See https://bugs.openjdk.java.net/browse/JDK-8191053.
        try {
            INVOKE_CLEANER.invoke(PlatformDependent0.UNSAFE, buffer);
        } catch (Throwable cause) {
            PlatformDependent0.throwException(cause);
        }
    }

    // NOTE(review): this method is never called anywhere in this fork — it looks like a
    // leftover from the upstream SecurityManager/doPrivileged code path. Candidate for
    // removal; confirm against the rest of the fork before deleting.
    private static void freeDirectBufferPrivileged(final ByteBuffer buffer) {
        Exception error;
        try {
            INVOKE_CLEANER.invoke(PlatformDependent0.UNSAFE, buffer);
            error = null;
        } catch (InvocationTargetException | IllegalAccessException e) {
            error = e;
        }
        if (error != null) {
            PlatformDependent0.throwException(error);
        }
    }
}
diff --git a/netty-util/src/main/java/io/netty/util/internal/ConcurrentSet.java b/netty-util/src/main/java/io/netty/util/internal/ConcurrentSet.java
new file mode 100644
index 0000000..c597254
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/internal/ConcurrentSet.java
@@ -0,0 +1,69 @@
/*
 * Copyright 2013 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 *   https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util.internal;

import java.io.Serializable;
import java.util.AbstractSet;
import java.util.Iterator;
import java.util.concurrent.ConcurrentMap;

/**
 * @deprecated For removal in Netty 4.2.
Please use {@link java.util.concurrent.ConcurrentHashMap#newKeySet()} instead + */ +@Deprecated +public final class ConcurrentSet extends AbstractSet implements Serializable { + + private static final long serialVersionUID = -6761513279741915432L; + + private final ConcurrentMap map; + + /** + * Creates a new instance which wraps the specified {@code map}. + */ + public ConcurrentSet() { + map = PlatformDependent.newConcurrentHashMap(); + } + + @Override + public int size() { + return map.size(); + } + + @Override + public boolean contains(Object o) { + return map.containsKey(o); + } + + @Override + public boolean add(E o) { + return map.putIfAbsent(o, Boolean.TRUE) == null; + } + + @Override + public boolean remove(Object o) { + return map.remove(o) != null; + } + + @Override + public void clear() { + map.clear(); + } + + @Override + public Iterator iterator() { + return map.keySet().iterator(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ConstantTimeUtils.java b/netty-util/src/main/java/io/netty/util/internal/ConstantTimeUtils.java new file mode 100644 index 0000000..2c3c9b1 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ConstantTimeUtils.java @@ -0,0 +1,136 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +public final class ConstantTimeUtils { + private ConstantTimeUtils() { + } + + /** + * Compare two {@code int}s without leaking timing information. + *

+ * The {@code int} return type is intentional and is designed to allow cascading of constant time operations: + *

+     *     int i1 = 1;
+     *     int i2 = 1;
+     *     int i3 = 1;
+     *     int i4 = 500;
+     *     boolean equals = (equalsConstantTime(i1, i2) & equalsConstantTime(i3, i4)) != 0;
+     * 
+ * + * @param x the first value. + * @param y the second value. + * @return {@code 0} if not equal. {@code 1} if equal. + */ + public static int equalsConstantTime(int x, int y) { + int z = ~(x ^ y); + z &= z >> 16; + z &= z >> 8; + z &= z >> 4; + z &= z >> 2; + z &= z >> 1; + return z & 1; + } + + /** + * Compare two {@code longs}s without leaking timing information. + *

+ * The {@code int} return type is intentional and is designed to allow cascading of constant time operations: + *

+     *     long l1 = 1;
+     *     long l2 = 1;
+     *     long l3 = 1;
+     *     long l4 = 500;
+     *     boolean equals = (equalsConstantTime(l1, l2) & equalsConstantTime(l3, l4)) != 0;
+     * 
+ * + * @param x the first value. + * @param y the second value. + * @return {@code 0} if not equal. {@code 1} if equal. + */ + public static int equalsConstantTime(long x, long y) { + long z = ~(x ^ y); + z &= z >> 32; + z &= z >> 16; + z &= z >> 8; + z &= z >> 4; + z &= z >> 2; + z &= z >> 1; + return (int) (z & 1); + } + + /** + * Compare two {@code byte} arrays for equality without leaking timing information. + * For performance reasons no bounds checking on the parameters is performed. + *

+ * The {@code int} return type is intentional and is designed to allow cascading of constant time operations: + *

+     *     byte[] s1 = {1, 2, 3};
+     *     byte[] s2 = {1, 2, 3};
+     *     byte[] s3 = {1, 2, 3};
+     *     byte[] s4 = {4, 5, 6};
+     *     boolean equals = (equalsConstantTime(s1, 0, s2, 0, s1.length) &
+     *                       equalsConstantTime(s3, 0, s4, 0, s3.length)) != 0;
+     * 
+ * + * @param bytes1 the first byte array. + * @param startPos1 the position (inclusive) to start comparing in {@code bytes1}. + * @param bytes2 the second byte array. + * @param startPos2 the position (inclusive) to start comparing in {@code bytes2}. + * @param length the amount of bytes to compare. This is assumed to be validated as not going out of bounds + * by the caller. + * @return {@code 0} if not equal. {@code 1} if equal. + */ + public static int equalsConstantTime(byte[] bytes1, int startPos1, + byte[] bytes2, int startPos2, int length) { + // Benchmarking demonstrates that using an int to accumulate is faster than other data types. + int b = 0; + final int end = startPos1 + length; + for (; startPos1 < end; ++startPos1, ++startPos2) { + b |= bytes1[startPos1] ^ bytes2[startPos2]; + } + return equalsConstantTime(b, 0); + } + + /** + * Compare two {@link CharSequence} objects without leaking timing information. + *

+ * The {@code int} return type is intentional and is designed to allow cascading of constant time operations: + *

+     *     String s1 = "foo";
+     *     String s2 = "foo";
+     *     String s3 = "foo";
+     *     String s4 = "goo";
+     *     boolean equals = (equalsConstantTime(s1, s2) & equalsConstantTime(s3, s4)) != 0;
+     * 
+ * + * @param s1 the first value. + * @param s2 the second value. + * @return {@code 0} if not equal. {@code 1} if equal. + */ + public static int equalsConstantTime(CharSequence s1, CharSequence s2) { + if (s1.length() != s2.length()) { + return 0; + } + + // Benchmarking demonstrates that using an int to accumulate is faster than other data types. + int c = 0; + for (int i = 0; i < s1.length(); ++i) { + c |= s1.charAt(i) ^ s2.charAt(i); + } + return equalsConstantTime(c, 0); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/DefaultPriorityQueue.java b/netty-util/src/main/java/io/netty/util/internal/DefaultPriorityQueue.java new file mode 100644 index 0000000..bf179ba --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/DefaultPriorityQueue.java @@ -0,0 +1,295 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import java.util.AbstractQueue; +import java.util.Arrays; +import java.util.Comparator; +import java.util.Iterator; +import java.util.NoSuchElementException; +import static io.netty.util.internal.PriorityQueueNode.INDEX_NOT_IN_QUEUE; + +/** + * A priority queue which uses natural ordering of elements. Elements are also required to be of type + * {@link PriorityQueueNode} for the purpose of maintaining the index in the priority queue. + * + * @param The object that is maintained in the queue. 
+ */ +public final class DefaultPriorityQueue extends AbstractQueue + implements PriorityQueue { + private static final PriorityQueueNode[] EMPTY_ARRAY = new PriorityQueueNode[0]; + private final Comparator comparator; + private T[] queue; + private int size; + + @SuppressWarnings("unchecked") + public DefaultPriorityQueue(Comparator comparator, int initialSize) { + this.comparator = ObjectUtil.checkNotNull(comparator, "comparator"); + queue = (T[]) (initialSize != 0 ? new PriorityQueueNode[initialSize] : EMPTY_ARRAY); + } + + @Override + public int size() { + return size; + } + + @Override + public boolean isEmpty() { + return size == 0; + } + + @Override + public boolean contains(Object o) { + if (!(o instanceof PriorityQueueNode node)) { + return false; + } + return contains(node, node.priorityQueueIndex(this)); + } + + @Override + public boolean containsTyped(T node) { + return contains(node, node.priorityQueueIndex(this)); + } + + @Override + public void clear() { + for (int i = 0; i < size; ++i) { + T node = queue[i]; + if (node != null) { + node.priorityQueueIndex(this, INDEX_NOT_IN_QUEUE); + queue[i] = null; + } + } + size = 0; + } + + @Override + public void clearIgnoringIndexes() { + size = 0; + } + + @Override + public boolean offer(T e) { + if (e.priorityQueueIndex(this) != INDEX_NOT_IN_QUEUE) { + throw new IllegalArgumentException("e.priorityQueueIndex(): " + e.priorityQueueIndex(this) + + " (expected: " + INDEX_NOT_IN_QUEUE + ") + e: " + e); + } + + // Check that the array capacity is enough to hold values by doubling capacity. + if (size >= queue.length) { + // Use a policy which allows for a 0 initial capacity. Same policy as JDK's priority queue, double when + // "small", then grow by 50% when "large". + queue = Arrays.copyOf(queue, queue.length + ((queue.length < 64) ? 
+ (queue.length + 2) : + (queue.length >>> 1))); + } + + bubbleUp(size++, e); + return true; + } + + @Override + public T poll() { + if (size == 0) { + return null; + } + T result = queue[0]; + result.priorityQueueIndex(this, INDEX_NOT_IN_QUEUE); + + T last = queue[--size]; + queue[size] = null; + if (size != 0) { // Make sure we don't add the last element back. + bubbleDown(0, last); + } + + return result; + } + + @Override + public T peek() { + return (size == 0) ? null : queue[0]; + } + + @SuppressWarnings("unchecked") + @Override + public boolean remove(Object o) { + final T node; + try { + node = (T) o; + } catch (ClassCastException e) { + return false; + } + return removeTyped(node); + } + + @Override + public boolean removeTyped(T node) { + int i = node.priorityQueueIndex(this); + if (!contains(node, i)) { + return false; + } + + node.priorityQueueIndex(this, INDEX_NOT_IN_QUEUE); + if (--size == 0 || size == i) { + // If there are no node left, or this is the last node in the array just remove and return. + queue[i] = null; + return true; + } + + // Move the last element where node currently lives in the array. + T moved = queue[i] = queue[size]; + queue[size] = null; + // priorityQueueIndex will be updated below in bubbleUp or bubbleDown + + // Make sure the moved node still preserves the min-heap properties. + if (comparator.compare(node, moved) < 0) { + bubbleDown(i, moved); + } else { + bubbleUp(i, moved); + } + return true; + } + + @Override + public void priorityChanged(T node) { + int i = node.priorityQueueIndex(this); + if (!contains(node, i)) { + return; + } + + // Preserve the min-heap property by comparing the new priority with parents/children in the heap. + if (i == 0) { + bubbleDown(i, node); + } else { + // Get the parent to see if min-heap properties are violated. 
+ int iParent = (i - 1) >>> 1; + T parent = queue[iParent]; + if (comparator.compare(node, parent) < 0) { + bubbleUp(i, node); + } else { + bubbleDown(i, node); + } + } + } + + @Override + public Object[] toArray() { + return Arrays.copyOf(queue, size); + } + + @SuppressWarnings("unchecked") + @Override + public X[] toArray(X[] a) { + if (a.length < size) { + return (X[]) Arrays.copyOf(queue, size, a.getClass()); + } + System.arraycopy(queue, 0, a, 0, size); + if (a.length > size) { + a[size] = null; + } + return a; + } + + /** + * This iterator does not return elements in any particular order. + */ + @Override + public Iterator iterator() { + return new PriorityQueueIterator(); + } + + private final class PriorityQueueIterator implements Iterator { + private int index; + + @Override + public boolean hasNext() { + return index < size; + } + + @Override + public T next() { + if (index >= size) { + throw new NoSuchElementException(); + } + + return queue[index++]; + } + + @Override + public void remove() { + throw new UnsupportedOperationException("remove"); + } + } + + private boolean contains(PriorityQueueNode node, int i) { + return i >= 0 && i < size && node.equals(queue[i]); + } + + private void bubbleDown(int k, T node) { + final int half = size >>> 1; + while (k < half) { + // Compare node to the children of index k. + int iChild = (k << 1) + 1; + T child = queue[iChild]; + + // Make sure we get the smallest child to compare against. + int rightChild = iChild + 1; + if (rightChild < size && comparator.compare(child, queue[rightChild]) > 0) { + child = queue[iChild = rightChild]; + } + // If the bubbleDown node is less than or equal to the smallest child then we will preserve the min-heap + // property by inserting the bubbleDown node here. + if (comparator.compare(node, child) <= 0) { + break; + } + + // Bubble the child up. + queue[k] = child; + child.priorityQueueIndex(this, k); + + // Move down k down the tree for the next iteration. 
+ k = iChild; + } + + // We have found where node should live and still satisfy the min-heap property, so put it in the queue. + queue[k] = node; + node.priorityQueueIndex(this, k); + } + + private void bubbleUp(int k, T node) { + while (k > 0) { + int iParent = (k - 1) >>> 1; + T parent = queue[iParent]; + + // If the bubbleUp node is less than the parent, then we have found a spot to insert and still maintain + // min-heap properties. + if (comparator.compare(node, parent) >= 0) { + break; + } + + // Bubble the parent down. + queue[k] = parent; + parent.priorityQueueIndex(this, k); + + // Move k up the tree for the next iteration. + k = iParent; + } + + // We have found where node should live and still satisfy the min-heap property, so put it in the queue. + queue[k] = node; + node.priorityQueueIndex(this, k); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/EmptyArrays.java b/netty-util/src/main/java/io/netty/util/internal/EmptyArrays.java new file mode 100644 index 0000000..582edc3 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/EmptyArrays.java @@ -0,0 +1,43 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.internal; + +import io.netty.util.AsciiString; +import java.nio.ByteBuffer; +import java.security.cert.Certificate; +import java.security.cert.X509Certificate; + +public final class EmptyArrays { + + public static final int[] EMPTY_INTS = {}; + public static final byte[] EMPTY_BYTES = {}; + public static final char[] EMPTY_CHARS = {}; + public static final Object[] EMPTY_OBJECTS = {}; + public static final Class[] EMPTY_CLASSES = {}; + public static final String[] EMPTY_STRINGS = {}; + public static final AsciiString[] EMPTY_ASCII_STRINGS = {}; + public static final StackTraceElement[] EMPTY_STACK_TRACE = {}; + public static final ByteBuffer[] EMPTY_BYTE_BUFFERS = {}; + public static final Certificate[] EMPTY_CERTIFICATES = {}; + public static final X509Certificate[] EMPTY_X509_CERTIFICATES = {}; + public static final Certificate[] EMPTY_JAVAX_X509_CERTIFICATES = {}; + + public static final Throwable[] EMPTY_THROWABLES = {}; + + private EmptyArrays() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/EmptyPriorityQueue.java b/netty-util/src/main/java/io/netty/util/internal/EmptyPriorityQueue.java new file mode 100644 index 0000000..cdf8c97 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/EmptyPriorityQueue.java @@ -0,0 +1,161 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +import java.util.Collection; +import java.util.Collections; +import java.util.Iterator; +import java.util.NoSuchElementException; + +public final class EmptyPriorityQueue implements PriorityQueue { + private static final PriorityQueue INSTANCE = new EmptyPriorityQueue(); + + private EmptyPriorityQueue() { + } + + /** + * Returns an unmodifiable empty {@link PriorityQueue}. + */ + @SuppressWarnings("unchecked") + public static EmptyPriorityQueue instance() { + return (EmptyPriorityQueue) INSTANCE; + } + + @Override + public boolean removeTyped(T node) { + return false; + } + + @Override + public boolean containsTyped(T node) { + return false; + } + + @Override + public void priorityChanged(T node) { + } + + @Override + public int size() { + return 0; + } + + @Override + public boolean isEmpty() { + return true; + } + + @Override + public boolean contains(Object o) { + return false; + } + + @Override + public Iterator iterator() { + return Collections.emptyIterator(); + } + + @Override + public Object[] toArray() { + return EmptyArrays.EMPTY_OBJECTS; + } + + @Override + public T1[] toArray(T1[] a) { + if (a.length > 0) { + a[0] = null; + } + return a; + } + + @Override + public boolean add(T t) { + return false; + } + + @Override + public boolean remove(Object o) { + return false; + } + + @Override + public boolean containsAll(Collection c) { + return false; + } + + @Override + public boolean addAll(Collection c) { + return false; + } + + @Override + public boolean removeAll(Collection c) { + return false; + } + + @Override + public boolean retainAll(Collection c) { + return false; + } + + @Override + public void clear() { + } + + @Override + public void clearIgnoringIndexes() { + } + + @Override + public boolean equals(Object o) { + return o instanceof PriorityQueue && ((PriorityQueue) o).isEmpty(); + } + + @Override + public int hashCode() { + return 0; + } + + @Override + public boolean offer(T t) { + return false; + } + + 
@Override + public T remove() { + throw new NoSuchElementException(); + } + + @Override + public T poll() { + return null; + } + + @Override + public T element() { + throw new NoSuchElementException(); + } + + @Override + public T peek() { + return null; + } + + @Override + public String toString() { + return EmptyPriorityQueue.class.getSimpleName(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/IntegerHolder.java b/netty-util/src/main/java/io/netty/util/internal/IntegerHolder.java new file mode 100644 index 0000000..2c69f8b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/IntegerHolder.java @@ -0,0 +1,25 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/**
 * Mutable box around a primitive {@code int}, historically used to avoid
 * autoboxing when an updatable counter had to be stored in a map or
 * thread-local slot.
 *
 * @deprecated For removal in netty 4.2
 */
@Deprecated
public final class IntegerHolder {
    // Intentionally public and mutable: callers read and write it directly.
    public int value;
}
/**
 * The internal data structure that stores the thread-local variables for Netty and all
 * {@link FastThreadLocal}s. Note that this class is for internal use only and is subject to change
 * at any time. Use {@link FastThreadLocal} unless you know what you are doing.
 */
public final class InternalThreadLocalMap extends UnpaddedInternalThreadLocalMap {

    // Fallback storage for threads that are NOT FastThreadLocalThreads; those threads pay
    // the cost of a regular JDK ThreadLocal lookup instead of a direct field read.
    private static final ThreadLocal<InternalThreadLocalMap> slowThreadLocalMap =
            new ThreadLocal<InternalThreadLocalMap>();
    // Monotonically increasing allocator for slots in 'indexedVariables'. Shared by every
    // FastThreadLocal ever created in this JVM.
    private static final AtomicInteger nextIndex = new AtomicInteger();
    // Internal use only.
    public static final int VARIABLES_TO_REMOVE_INDEX = nextVariableIndex();

    private static final int DEFAULT_ARRAY_LIST_INITIAL_CAPACITY = 8;
    private static final int ARRAY_LIST_CAPACITY_EXPAND_THRESHOLD = 1 << 30;
    // Reference: https://hg.openjdk.java.net/jdk8/jdk8/jdk/file/tip/src/share/classes/java/util/ArrayList.java#l229
    private static final int ARRAY_LIST_CAPACITY_MAX_SIZE = Integer.MAX_VALUE - 8;

    private static final int HANDLER_SHARABLE_CACHE_INITIAL_CAPACITY = 4;
    private static final int INDEXED_VARIABLE_TABLE_INITIAL_SIZE = 32;

    private static final int STRING_BUILDER_INITIAL_SIZE;
    private static final int STRING_BUILDER_MAX_SIZE;

    private static final InternalLogger logger;
    /**
     * Internal use only. Sentinel marking an unused slot in {@link #indexedVariables};
     * distinguishes "never set" from a stored {@code null}.
     */
    public static final Object UNSET = new Object();

    /**
     * Used by {@link FastThreadLocal}
     */
    private Object[] indexedVariables;

    // Core thread-locals
    private int futureListenerStackDepth;
    private int localChannelReaderStackDepth;
    private Map<Class<?>, Boolean> handlerSharableCache;
    private IntegerHolder counterHashCode;
    private ThreadLocalRandom random;
    private Map<Class<?>, TypeParameterMatcher> typeParameterMatcherGetCache;
    private Map<Class<?>, Map<String, TypeParameterMatcher>> typeParameterMatcherFindCache;

    // String-related thread-locals
    private StringBuilder stringBuilder;
    private Map<Charset, CharsetEncoder> charsetEncoderCache;
    private Map<Charset, CharsetDecoder> charsetDecoderCache;

    // ArrayList-related thread-locals
    private ArrayList<Object> arrayList;

    // Lazily created; tracks which ByteBuf cleaner flags have been set for this thread.
    private BitSet cleanerFlags;

    /**
     * @deprecated These padding fields will be removed in the future.
     */
    public long rp1, rp2, rp3, rp4, rp5, rp6, rp7, rp8;

    static {
        STRING_BUILDER_INITIAL_SIZE =
                SystemPropertyUtil.getInt("io.netty.threadLocalMap.stringBuilder.initialSize", 1024);
        STRING_BUILDER_MAX_SIZE =
                SystemPropertyUtil.getInt("io.netty.threadLocalMap.stringBuilder.maxSize", 1024 * 4);

        // Ensure the InternalLogger is initialized as last field in this class as InternalThreadLocalMap might be used
        // by the InternalLogger itself. For this its important that all the other static fields are correctly
        // initialized.
        //
        // See https://github.com/netty/netty/issues/12931.
        logger = InternalLoggerFactory.getInstance(InternalThreadLocalMap.class);
        logger.debug("-Dio.netty.threadLocalMap.stringBuilder.initialSize: {}", STRING_BUILDER_INITIAL_SIZE);
        logger.debug("-Dio.netty.threadLocalMap.stringBuilder.maxSize: {}", STRING_BUILDER_MAX_SIZE);
    }

    /**
     * Returns the map for the current thread, or {@code null} if none has been created yet.
     * Never allocates.
     */
    public static InternalThreadLocalMap getIfSet() {
        Thread thread = Thread.currentThread();
        if (thread instanceof FastThreadLocalThread) {
            return ((FastThreadLocalThread) thread).threadLocalMap();
        }
        return slowThreadLocalMap.get();
    }

    /**
     * Returns the map for the current thread, creating it on first access.
     * FastThreadLocalThreads use a direct field; other threads fall back to a JDK ThreadLocal.
     */
    public static InternalThreadLocalMap get() {
        Thread thread = Thread.currentThread();
        if (thread instanceof FastThreadLocalThread) {
            return fastGet((FastThreadLocalThread) thread);
        } else {
            return slowGet();
        }
    }

    private static InternalThreadLocalMap fastGet(FastThreadLocalThread thread) {
        InternalThreadLocalMap threadLocalMap = thread.threadLocalMap();
        if (threadLocalMap == null) {
            thread.setThreadLocalMap(threadLocalMap = new InternalThreadLocalMap());
        }
        return threadLocalMap;
    }

    private static InternalThreadLocalMap slowGet() {
        InternalThreadLocalMap ret = slowThreadLocalMap.get();
        if (ret == null) {
            ret = new InternalThreadLocalMap();
            slowThreadLocalMap.set(ret);
        }
        return ret;
    }

    /**
     * Detaches the current thread's map so it can be garbage-collected.
     */
    public static void remove() {
        Thread thread = Thread.currentThread();
        if (thread instanceof FastThreadLocalThread) {
            ((FastThreadLocalThread) thread).setThreadLocalMap(null);
        } else {
            slowThreadLocalMap.remove();
        }
    }

    // NOTE(review): only clears the slow-path ThreadLocal; FastThreadLocalThread maps are
    // unaffected — presumably intentional, confirm against callers.
    public static void destroy() {
        slowThreadLocalMap.remove();
    }

    /**
     * Allocates the next global slot index for a new {@link FastThreadLocal}.
     *
     * @throws IllegalStateException if the index space is exhausted
     */
    public static int nextVariableIndex() {
        int index = nextIndex.getAndIncrement();
        if (index >= ARRAY_LIST_CAPACITY_MAX_SIZE || index < 0) {
            // Pin the counter at the max so subsequent calls keep failing instead of wrapping.
            nextIndex.set(ARRAY_LIST_CAPACITY_MAX_SIZE);
            throw new IllegalStateException("too many thread-local indexed variables");
        }
        return index;
    }

    /** Returns the most recently allocated slot index, or -1 if none was allocated yet. */
    public static int lastVariableIndex() {
        return nextIndex.get() - 1;
    }

    private InternalThreadLocalMap() {
        indexedVariables = newIndexedVariableTable();
    }

    private static Object[] newIndexedVariableTable() {
        Object[] array = new Object[INDEXED_VARIABLE_TABLE_INITIAL_SIZE];
        // Every slot starts as UNSET so indexedVariable() can tell "never set" apart from null.
        Arrays.fill(array, UNSET);
        return array;
    }

    /**
     * Returns an approximate count of thread-local values stored in this map,
     * including both the named core fields and all registered FastThreadLocals.
     */
    public int size() {
        int count = 0;

        if (futureListenerStackDepth != 0) {
            count++;
        }
        if (localChannelReaderStackDepth != 0) {
            count++;
        }
        if (handlerSharableCache != null) {
            count++;
        }
        if (counterHashCode != null) {
            count++;
        }
        if (random != null) {
            count++;
        }
        if (typeParameterMatcherGetCache != null) {
            count++;
        }
        if (typeParameterMatcherFindCache != null) {
            count++;
        }
        if (stringBuilder != null) {
            count++;
        }
        if (charsetEncoderCache != null) {
            count++;
        }
        if (charsetDecoderCache != null) {
            count++;
        }
        if (arrayList != null) {
            count++;
        }

        Object v = indexedVariable(VARIABLES_TO_REMOVE_INDEX);
        if (v != null && v != InternalThreadLocalMap.UNSET) {
            @SuppressWarnings("unchecked")
            Set<FastThreadLocal<?>> variablesToRemove = (Set<FastThreadLocal<?>>) v;
            count += variablesToRemove.size();
        }

        return count;
    }

    /**
     * Returns a reusable, zero-length StringBuilder. Shrinks the builder back to the
     * initial size first if it grew past the configured maximum, to bound memory.
     */
    public StringBuilder stringBuilder() {
        StringBuilder sb = stringBuilder;
        if (sb == null) {
            return stringBuilder = new StringBuilder(STRING_BUILDER_INITIAL_SIZE);
        }
        if (sb.capacity() > STRING_BUILDER_MAX_SIZE) {
            sb.setLength(STRING_BUILDER_INITIAL_SIZE);
            sb.trimToSize();
        }
        sb.setLength(0);
        return sb;
    }

    /** Lazily created per-thread cache of CharsetEncoders, keyed by identity. */
    public Map<Charset, CharsetEncoder> charsetEncoderCache() {
        Map<Charset, CharsetEncoder> cache = charsetEncoderCache;
        if (cache == null) {
            charsetEncoderCache = cache = new IdentityHashMap<Charset, CharsetEncoder>();
        }
        return cache;
    }

    /** Lazily created per-thread cache of CharsetDecoders, keyed by identity. */
    public Map<Charset, CharsetDecoder> charsetDecoderCache() {
        Map<Charset, CharsetDecoder> cache = charsetDecoderCache;
        if (cache == null) {
            charsetDecoderCache = cache = new IdentityHashMap<Charset, CharsetDecoder>();
        }
        return cache;
    }

    public <E> ArrayList<E> arrayList() {
        return arrayList(DEFAULT_ARRAY_LIST_INITIAL_CAPACITY);
    }

    /**
     * Returns a reusable, cleared per-thread ArrayList with at least {@code minCapacity}.
     * Callers must not retain the list across calls — it is shared thread-local scratch space.
     */
    @SuppressWarnings("unchecked")
    public <E> ArrayList<E> arrayList(int minCapacity) {
        ArrayList<E> list = (ArrayList<E>) arrayList;
        if (list == null) {
            arrayList = new ArrayList<Object>(minCapacity);
            return (ArrayList<E>) arrayList;
        }
        list.clear();
        list.ensureCapacity(minCapacity);
        return list;
    }

    public int futureListenerStackDepth() {
        return futureListenerStackDepth;
    }

    public void setFutureListenerStackDepth(int futureListenerStackDepth) {
        this.futureListenerStackDepth = futureListenerStackDepth;
    }

    /** Lazily created per-thread random; avoids contention on a shared instance. */
    public ThreadLocalRandom random() {
        ThreadLocalRandom r = random;
        if (r == null) {
            random = r = new ThreadLocalRandom();
        }
        return r;
    }

    public Map<Class<?>, TypeParameterMatcher> typeParameterMatcherGetCache() {
        Map<Class<?>, TypeParameterMatcher> cache = typeParameterMatcherGetCache;
        if (cache == null) {
            typeParameterMatcherGetCache = cache = new IdentityHashMap<Class<?>, TypeParameterMatcher>();
        }
        return cache;
    }

    public Map<Class<?>, Map<String, TypeParameterMatcher>> typeParameterMatcherFindCache() {
        Map<Class<?>, Map<String, TypeParameterMatcher>> cache = typeParameterMatcherFindCache;
        if (cache == null) {
            typeParameterMatcherFindCache = cache =
                    new IdentityHashMap<Class<?>, Map<String, TypeParameterMatcher>>();
        }
        return cache;
    }

    @Deprecated
    public IntegerHolder counterHashCode() {
        return counterHashCode;
    }

    @Deprecated
    public void setCounterHashCode(IntegerHolder counterHashCode) {
        this.counterHashCode = counterHashCode;
    }

    /** Lazily created cache of handler classes annotated as sharable; weak keys so classes can unload. */
    public Map<Class<?>, Boolean> handlerSharableCache() {
        Map<Class<?>, Boolean> cache = handlerSharableCache;
        if (cache == null) {
            // Start with small capacity to keep memory overhead as low as possible.
            handlerSharableCache = cache = new WeakHashMap<Class<?>, Boolean>(HANDLER_SHARABLE_CACHE_INITIAL_CAPACITY);
        }
        return cache;
    }

    public int localChannelReaderStackDepth() {
        return localChannelReaderStackDepth;
    }

    public void setLocalChannelReaderStackDepth(int localChannelReaderStackDepth) {
        this.localChannelReaderStackDepth = localChannelReaderStackDepth;
    }

    /** Returns the value stored at {@code index}, or {@link #UNSET} if out of range / never set. */
    public Object indexedVariable(int index) {
        Object[] lookup = indexedVariables;
        return index < lookup.length ? lookup[index] : UNSET;
    }

    /**
     * @return {@code true} if and only if a new thread-local variable has been created
     */
    public boolean setIndexedVariable(int index, Object value) {
        Object[] lookup = indexedVariables;
        if (index < lookup.length) {
            Object oldValue = lookup[index];
            lookup[index] = value;
            return oldValue == UNSET;
        } else {
            expandIndexedVariableTableAndSet(index, value);
            return true;
        }
    }

    private void expandIndexedVariableTableAndSet(int index, Object value) {
        Object[] oldArray = indexedVariables;
        final int oldCapacity = oldArray.length;
        int newCapacity;
        if (index < ARRAY_LIST_CAPACITY_EXPAND_THRESHOLD) {
            // Round index up to the next power of two via bit-smearing, then +1.
            newCapacity = index;
            newCapacity |= newCapacity >>> 1;
            newCapacity |= newCapacity >>> 2;
            newCapacity |= newCapacity >>> 4;
            newCapacity |= newCapacity >>> 8;
            newCapacity |= newCapacity >>> 16;
            newCapacity++;
        } else {
            newCapacity = ARRAY_LIST_CAPACITY_MAX_SIZE;
        }

        Object[] newArray = Arrays.copyOf(oldArray, newCapacity);
        // Newly added slots must read as UNSET, not null.
        Arrays.fill(newArray, oldCapacity, newArray.length, UNSET);
        newArray[index] = value;
        indexedVariables = newArray;
    }

    /** Clears slot {@code index} back to {@link #UNSET} and returns its previous value. */
    public Object removeIndexedVariable(int index) {
        Object[] lookup = indexedVariables;
        if (index < lookup.length) {
            Object v = lookup[index];
            lookup[index] = UNSET;
            return v;
        } else {
            return UNSET;
        }
    }

    public boolean isIndexedVariableSet(int index) {
        Object[] lookup = indexedVariables;
        return index < lookup.length && lookup[index] != UNSET;
    }

    public boolean isCleanerFlagSet(int index) {
        return cleanerFlags != null && cleanerFlags.get(index);
    }

    public void setCleanerFlag(int index) {
        if (cleanerFlags == null) {
            cleanerFlags = new BitSet();
        }
        cleanerFlags.set(index);
    }
}
+ */ +package io.netty.util.internal; + +import java.util.concurrent.atomic.LongAdder; + +@SuppressJava6Requirement(reason = "Usage guarded by java version check") +final class LongAdderCounter extends LongAdder implements LongCounter { + + @Override + public long value() { + return longValue(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/LongCounter.java b/netty-util/src/main/java/io/netty/util/internal/LongCounter.java new file mode 100644 index 0000000..1dd39fd --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/LongCounter.java @@ -0,0 +1,29 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +/** + * Counter for long. + */ +public interface LongCounter { + void add(long delta); + + void increment(); + + void decrement(); + + long value(); +} diff --git a/netty-util/src/main/java/io/netty/util/internal/MacAddressUtil.java b/netty-util/src/main/java/io/netty/util/internal/MacAddressUtil.java new file mode 100644 index 0000000..8b780c1 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/MacAddressUtil.java @@ -0,0 +1,269 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.internal; + +import io.netty.util.NetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.net.InetAddress; +import java.net.NetworkInterface; +import java.net.SocketException; +import java.util.Arrays; +import java.util.Enumeration; +import java.util.LinkedHashMap; +import java.util.Map; +import java.util.Map.Entry; +import static io.netty.util.internal.EmptyArrays.EMPTY_BYTES; + +public final class MacAddressUtil { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(MacAddressUtil.class); + + private static final int EUI64_MAC_ADDRESS_LENGTH = 8; + private static final int EUI48_MAC_ADDRESS_LENGTH = 6; + + /** + * Obtains the best MAC address found on local network interfaces. + * Generally speaking, an active network interface used on public + * networks is better than a local network interface. + * + * @return byte array containing a MAC. null if no MAC can be found. + */ + public static byte[] bestAvailableMac() { + // Find the best MAC address available. + byte[] bestMacAddr = EMPTY_BYTES; + InetAddress bestInetAddr = NetUtil.LOCALHOST4; + + // Retrieve the list of available network interfaces. + Map ifaces = new LinkedHashMap(); + for (NetworkInterface iface : NetUtil.NETWORK_INTERFACES) { + // Use the interface with proper INET addresses only. 
+ Enumeration addrs = SocketUtils.addressesFromNetworkInterface(iface); + if (addrs.hasMoreElements()) { + InetAddress a = addrs.nextElement(); + if (!a.isLoopbackAddress()) { + ifaces.put(iface, a); + } + } + } + + for (Entry entry : ifaces.entrySet()) { + NetworkInterface iface = entry.getKey(); + InetAddress inetAddr = entry.getValue(); + if (iface.isVirtual()) { + continue; + } + + byte[] macAddr; + try { + macAddr = SocketUtils.hardwareAddressFromNetworkInterface(iface); + } catch (SocketException e) { + logger.debug("Failed to get the hardware address of a network interface: {}", iface, e); + continue; + } + + boolean replace = false; + int res = compareAddresses(bestMacAddr, macAddr); + if (res < 0) { + // Found a better MAC address. + replace = true; + } else if (res == 0) { + // Two MAC addresses are of pretty much same quality. + res = compareAddresses(bestInetAddr, inetAddr); + if (res < 0) { + // Found a MAC address with better INET address. + replace = true; + } else if (res == 0) { + // Cannot tell the difference. Choose the longer one. + if (bestMacAddr.length < macAddr.length) { + replace = true; + } + } + } + + if (replace) { + bestMacAddr = macAddr; + bestInetAddr = inetAddr; + } + } + + if (bestMacAddr == EMPTY_BYTES) { + return null; + } + + if (bestMacAddr.length == EUI48_MAC_ADDRESS_LENGTH) { // EUI-48 - convert to EUI-64 + byte[] newAddr = new byte[EUI64_MAC_ADDRESS_LENGTH]; + System.arraycopy(bestMacAddr, 0, newAddr, 0, 3); + newAddr[3] = (byte) 0xFF; + newAddr[4] = (byte) 0xFE; + System.arraycopy(bestMacAddr, 3, newAddr, 5, 3); + bestMacAddr = newAddr; + } else { + // Unknown + bestMacAddr = Arrays.copyOf(bestMacAddr, EUI64_MAC_ADDRESS_LENGTH); + } + + return bestMacAddr; + } + + /** + * Returns the result of {@link #bestAvailableMac()} if non-{@code null} otherwise returns a random EUI-64 MAC + * address. 
+ */ + public static byte[] defaultMachineId() { + byte[] bestMacAddr = bestAvailableMac(); + if (bestMacAddr == null) { + bestMacAddr = new byte[EUI64_MAC_ADDRESS_LENGTH]; + PlatformDependent.threadLocalRandom().nextBytes(bestMacAddr); + logger.warn( + "Failed to find a usable hardware address from the network interfaces; using random bytes: {}", + formatAddress(bestMacAddr)); + } + return bestMacAddr; + } + + /** + * Parse a EUI-48, MAC-48, or EUI-64 MAC address from a {@link String} and return it as a {@code byte[]}. + * + * @param value The string representation of the MAC address. + * @return The byte representation of the MAC address. + */ + public static byte[] parseMAC(String value) { + final byte[] machineId; + final char separator; + switch (value.length()) { + case 17: + separator = value.charAt(2); + validateMacSeparator(separator); + machineId = new byte[EUI48_MAC_ADDRESS_LENGTH]; + break; + case 23: + separator = value.charAt(2); + validateMacSeparator(separator); + machineId = new byte[EUI64_MAC_ADDRESS_LENGTH]; + break; + default: + throw new IllegalArgumentException("value is not supported [MAC-48, EUI-48, EUI-64]"); + } + + final int end = machineId.length - 1; + int j = 0; + for (int i = 0; i < end; ++i, j += 3) { + final int sIndex = j + 2; + machineId[i] = StringUtil.decodeHexByte(value, j); + if (value.charAt(sIndex) != separator) { + throw new IllegalArgumentException("expected separator '" + separator + " but got '" + + value.charAt(sIndex) + "' at index: " + sIndex); + } + } + + machineId[end] = StringUtil.decodeHexByte(value, j); + + return machineId; + } + + private static void validateMacSeparator(char separator) { + if (separator != ':' && separator != '-') { + throw new IllegalArgumentException("unsupported separator: " + separator + " (expected: [:-])"); + } + } + + /** + * @param addr byte array of a MAC address. + * @return hex formatted MAC address. 
+ */ + public static String formatAddress(byte[] addr) { + StringBuilder buf = new StringBuilder(24); + for (byte b : addr) { + buf.append(String.format("%02x:", b & 0xff)); + } + return buf.substring(0, buf.length() - 1); + } + + /** + * @return positive - current is better, 0 - cannot tell from MAC addr, negative - candidate is better. + */ + // visible for testing + static int compareAddresses(byte[] current, byte[] candidate) { + if (candidate == null || candidate.length < EUI48_MAC_ADDRESS_LENGTH) { + return 1; + } + + // Must not be filled with only 0 and 1. + boolean onlyZeroAndOne = true; + for (byte b : candidate) { + if (b != 0 && b != 1) { + onlyZeroAndOne = false; + break; + } + } + + if (onlyZeroAndOne) { + return 1; + } + + // Must not be a multicast address + if ((candidate[0] & 1) != 0) { + return 1; + } + + // Prefer globally unique address. + if ((candidate[0] & 2) == 0) { + if (current.length != 0 && (current[0] & 2) == 0) { + // Both current and candidate are globally unique addresses. + return 0; + } else { + // Only candidate is globally unique. + return -1; + } + } else { + if (current.length != 0 && (current[0] & 2) == 0) { + // Only current is globally unique. + return 1; + } else { + // Both current and candidate are non-unique. 
+ return 0; + } + } + } + + /** + * @return positive - current is better, 0 - cannot tell, negative - candidate is better + */ + private static int compareAddresses(InetAddress current, InetAddress candidate) { + return scoreAddress(current) - scoreAddress(candidate); + } + + private static int scoreAddress(InetAddress addr) { + if (addr.isAnyLocalAddress() || addr.isLoopbackAddress()) { + return 0; + } + if (addr.isMulticastAddress()) { + return 1; + } + if (addr.isLinkLocalAddress()) { + return 2; + } + if (addr.isSiteLocalAddress()) { + return 3; + } + + return 4; + } + + private MacAddressUtil() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/MathUtil.java b/netty-util/src/main/java/io/netty/util/internal/MathUtil.java new file mode 100644 index 0000000..dfaebf9 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/MathUtil.java @@ -0,0 +1,97 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.internal; + +/** + * Math utility methods. + */ +public final class MathUtil { + + private MathUtil() { + } + + /** + * Fast method of finding the next power of 2 greater than or equal to the supplied value. + * + *

/**
 * Math utility methods.
 */
public final class MathUtil {

    private MathUtil() {
    }

    /**
     * Fast method of finding the next power of 2 greater than or equal to the supplied value.
     *
     * <p>If the value is {@code <= 0} then 1 will be returned.
     * This method is not suitable for {@link Integer#MIN_VALUE} or numbers greater than 2^30.
     *
     * @param value from which to search for next power of 2
     * @return The next power of 2 or the value itself if it is a power of 2
     */
    public static int findNextPositivePowerOfTwo(final int value) {
        assert value > Integer.MIN_VALUE && value < 0x40000000;
        // numberOfLeadingZeros(value - 1) counts the free high bits; shifting 1 into the
        // first occupied position rounds up to the next power of two (exact powers map to themselves).
        return 1 << (32 - Integer.numberOfLeadingZeros(value - 1));
    }

    /**
     * Fast method of finding the next power of 2 greater than or equal to the supplied value.
     *
     * <p>This method will do runtime bounds checking and call {@link #findNextPositivePowerOfTwo(int)} if within a
     * valid range.
     *
     * @param value from which to search for next power of 2
     * @return The next power of 2 or the value itself if it is a power of 2.
     *         Special cases: any value {@code <= 0} yields 1; any value {@code >= 2^30} yields 2^30.
     */
    public static int safeFindNextPositivePowerOfTwo(final int value) {
        if (value <= 0) {
            return 1;
        }
        if (value >= 0x40000000) {
            return 0x40000000;
        }
        return findNextPositivePowerOfTwo(value);
    }

    /**
     * Determine if the requested {@code index} and {@code length} will fit within {@code capacity}.
     *
     * @param index The starting index.
     * @param length The length which will be utilized (starting from {@code index}).
     * @param capacity The capacity that {@code index + length} is allowed to be within.
     * @return {@code false} if the requested {@code index} and {@code length} will fit within {@code capacity}.
     *         {@code true} if this would result in an index out of bounds exception.
     */
    public static boolean isOutOfBounds(int index, int length, int capacity) {
        // Branch-free check: if any operand is negative, or index + length overflows,
        // or exceeds capacity, the sign bit of the OR-ed expression is set.
        return (index | length | capacity | (index + length) | (capacity - (index + length))) < 0;
    }

    /**
     * Compares two {@code int} values.
     *
     * @param x the first {@code int} to compare
     * @param y the second {@code int} to compare
     * @return the value {@code 0} if {@code x == y};
     *         {@code -1} if {@code x < y}; and
     *         {@code 1} if {@code x > y}
     */
    public static int compare(final int x, final int y) {
        // Do not subtract for comparison: the difference could overflow.
        if (x == y) {
            return 0;
        }
        return x < y ? -1 : 1;
    }

    /**
     * Compare two {@code long} values.
     *
     * @param x the first {@code long} to compare.
     * @param y the second {@code long} to compare.
     * @return {@code 0} if {@code x == y}, {@code 1} if {@code x > y}, {@code -1} if {@code x < y}
     */
    public static int compare(long x, long y) {
        if (x == y) {
            return 0;
        }
        return x < y ? -1 : 1;
    }
}
+ */ +public final class NativeLibraryLoader { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(NativeLibraryLoader.class); + + private static final String NATIVE_RESOURCE_HOME = "META-INF/native/"; + private static final File WORKDIR; + private static final boolean DELETE_NATIVE_LIB_AFTER_LOADING; + private static final boolean TRY_TO_PATCH_SHADED_ID; + private static final boolean DETECT_NATIVE_LIBRARY_DUPLICATES; + + // Just use a-Z and numbers as valid ID bytes. + private static final byte[] UNIQUE_ID_BYTES = + "0123456789abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ".getBytes(CharsetUtil.US_ASCII); + + static { + String workdir = SystemPropertyUtil.get("io.netty.native.workdir"); + if (workdir != null) { + File f = new File(workdir); + f.mkdirs(); + + try { + f = f.getAbsoluteFile(); + } catch (Exception ignored) { + // Good to have an absolute path, but it's OK. + } + + WORKDIR = f; + logger.debug("-Dio.netty.native.workdir: " + WORKDIR); + } else { + WORKDIR = PlatformDependent.tmpdir(); + logger.debug("-Dio.netty.native.workdir: " + WORKDIR + " (io.netty.tmpdir)"); + } + + DELETE_NATIVE_LIB_AFTER_LOADING = SystemPropertyUtil.getBoolean( + "io.netty.native.deleteLibAfterLoading", true); + logger.debug("-Dio.netty.native.deleteLibAfterLoading: {}", DELETE_NATIVE_LIB_AFTER_LOADING); + + TRY_TO_PATCH_SHADED_ID = SystemPropertyUtil.getBoolean( + "io.netty.native.tryPatchShadedId", true); + logger.debug("-Dio.netty.native.tryPatchShadedId: {}", TRY_TO_PATCH_SHADED_ID); + + DETECT_NATIVE_LIBRARY_DUPLICATES = SystemPropertyUtil.getBoolean( + "io.netty.native.detectNativeLibraryDuplicates", true); + logger.debug("-Dio.netty.native.detectNativeLibraryDuplicates: {}", DETECT_NATIVE_LIBRARY_DUPLICATES); + } + + /** + * Loads the first available library in the collection with the specified + * {@link ClassLoader}. + * + * @throws IllegalArgumentException if none of the given libraries load successfully. 
+ */ + public static void loadFirstAvailable(ClassLoader loader, String... names) { + List suppressed = new ArrayList(); + for (String name : names) { + try { + load(name, loader); + logger.debug("Loaded library with name '{}'", name); + return; + } catch (Throwable t) { + suppressed.add(t); + } + } + + IllegalArgumentException iae = + new IllegalArgumentException("Failed to load any of the given libraries: " + Arrays.toString(names)); + ThrowableUtil.addSuppressedAndClear(iae, suppressed); + throw iae; + } + + /** + * Calculates the mangled shading prefix added to this class's full name. + * + *

This method mangles the package name as follows, so we can unmangle it back later: + *

    + *
  • {@code _} to {@code _1}
  • + *
  • {@code .} to {@code _}
  • + *
+ * + *

Note that we don't mangle non-ASCII characters here because it's extremely unlikely to have + * a non-ASCII character in a package name. For more information, see: + *

+ * + * @throws UnsatisfiedLinkError if the shader used something other than a prefix + */ + private static String calculateMangledPackagePrefix() { + String maybeShaded = NativeLibraryLoader.class.getName(); + // Use ! instead of . to avoid shading utilities from modifying the string + String expected = "io!netty!util!internal!NativeLibraryLoader".replace('!', '.'); + if (!maybeShaded.endsWith(expected)) { + throw new UnsatisfiedLinkError(String.format( + "Could not find prefix added to %s to get %s. When shading, only adding a " + + "package prefix is supported", expected, maybeShaded)); + } + return maybeShaded.substring(0, maybeShaded.length() - expected.length()) + .replace("_", "_1") + .replace('.', '_'); + } + + /** + * Load the given library with the specified {@link ClassLoader} + */ + public static void load(String originalName, ClassLoader loader) { + String mangledPackagePrefix = calculateMangledPackagePrefix(); + String name = mangledPackagePrefix + originalName; + List suppressed = new ArrayList(); + try { + // first try to load from java.library.path + loadLibrary(loader, name, false); + return; + } catch (Throwable ex) { + suppressed.add(ex); + } + + String libname = System.mapLibraryName(name); + String path = NATIVE_RESOURCE_HOME + libname; + + InputStream in = null; + OutputStream out = null; + File tmpFile = null; + URL url = getResource(path, loader); + try { + if (url == null) { + if (PlatformDependent.isOsx()) { + String fileName = path.endsWith(".jnilib") ? 
NATIVE_RESOURCE_HOME + "lib" + name + ".dynlib" : + NATIVE_RESOURCE_HOME + "lib" + name + ".jnilib"; + url = getResource(fileName, loader); + if (url == null) { + FileNotFoundException fnf = new FileNotFoundException(fileName); + ThrowableUtil.addSuppressedAndClear(fnf, suppressed); + throw fnf; + } + } else { + FileNotFoundException fnf = new FileNotFoundException(path); + ThrowableUtil.addSuppressedAndClear(fnf, suppressed); + throw fnf; + } + } + + int index = libname.lastIndexOf('.'); + String prefix = libname.substring(0, index); + String suffix = libname.substring(index); + + tmpFile = PlatformDependent.createTempFile(prefix, suffix, WORKDIR); + in = url.openStream(); + out = new FileOutputStream(tmpFile); + + byte[] buffer = new byte[8192]; + int length; + while ((length = in.read(buffer)) > 0) { + out.write(buffer, 0, length); + } + out.flush(); + + if (shouldShadedLibraryIdBePatched(mangledPackagePrefix)) { + // Let's try to patch the id and re-sign it. This is a best-effort and might fail if a + // SecurityManager is setup or the right executables are not installed :/ + tryPatchShadedLibraryIdAndSign(tmpFile, originalName); + } + + // Close the output stream before loading the unpacked library, + // because otherwise Windows will refuse to load it when it's in use by other process. + closeQuietly(out); + out = null; + + loadLibrary(loader, tmpFile.getPath(), true); + } catch (UnsatisfiedLinkError e) { + try { + if (tmpFile != null && tmpFile.isFile() && tmpFile.canRead() && + !NoexecVolumeDetector.canExecuteExecutable(tmpFile)) { + // Pass "io.netty.native.workdir" as an argument to allow shading tools to see + // the string. Since this is printed out to users to tell them what to do next, + // we want the value to be correct even when shading. 
+ logger.info("{} exists but cannot be executed even when execute permissions set; " + + "check volume for \"noexec\" flag; use -D{}=[path] " + + "to set native working directory separately.", + tmpFile.getPath(), "io.netty.native.workdir"); + } + } catch (Throwable t) { + suppressed.add(t); + logger.debug("Error checking if {} is on a file store mounted with noexec", tmpFile, t); + } + // Re-throw to fail the load + ThrowableUtil.addSuppressedAndClear(e, suppressed); + throw e; + } catch (Exception e) { + UnsatisfiedLinkError ule = new UnsatisfiedLinkError("could not load a native library: " + name); + ule.initCause(e); + ThrowableUtil.addSuppressedAndClear(ule, suppressed); + throw ule; + } finally { + closeQuietly(in); + closeQuietly(out); + // After we load the library it is safe to delete the file. + // We delete the file immediately to free up resources as soon as possible, + // and if this fails fallback to deleting on JVM exit. + if (tmpFile != null && (!DELETE_NATIVE_LIB_AFTER_LOADING || !tmpFile.delete())) { + tmpFile.deleteOnExit(); + } + } + } + + private static URL getResource(String path, ClassLoader loader) { + final Enumeration urls; + try { + if (loader == null) { + urls = ClassLoader.getSystemResources(path); + } else { + urls = loader.getResources(path); + } + } catch (IOException iox) { + throw new RuntimeException("An error occurred while getting the resources for " + path, iox); + } + + List urlsList = Collections.list(urls); + int size = urlsList.size(); + switch (size) { + case 0: + return null; + case 1: + return urlsList.get(0); + default: + if (DETECT_NATIVE_LIBRARY_DUPLICATES) { + try { + MessageDigest md = MessageDigest.getInstance("SHA-256"); + // We found more than 1 resource with the same name. Let's check if the content of the file is + // the same as in this case it will not have any bad effect. 
+ URL url = urlsList.get(0); + byte[] digest = digest(md, url); + boolean allSame = true; + if (digest != null) { + for (int i = 1; i < size; i++) { + byte[] digest2 = digest(md, urlsList.get(i)); + if (digest2 == null || !Arrays.equals(digest, digest2)) { + allSame = false; + break; + } + } + } else { + allSame = false; + } + if (allSame) { + return url; + } + } catch (NoSuchAlgorithmException e) { + logger.debug("Don't support SHA-256, can't check if resources have same content.", e); + } + + throw new IllegalStateException( + "Multiple resources found for '" + path + "' with different content: " + urlsList); + } else { + logger.warn("Multiple resources found for '" + path + "' with different content: " + + urlsList + ". Please fix your dependency graph."); + return urlsList.get(0); + } + } + } + + private static byte[] digest(MessageDigest digest, URL url) { + InputStream in = null; + try { + in = url.openStream(); + byte[] bytes = new byte[8192]; + int i; + while ((i = in.read(bytes)) != -1) { + digest.update(bytes, 0, i); + } + return digest.digest(); + } catch (IOException e) { + logger.debug("Can't read resource.", e); + return null; + } finally { + closeQuietly(in); + } + } + + static void tryPatchShadedLibraryIdAndSign(File libraryFile, String originalName) { + if (!new File("/Library/Developer/CommandLineTools").exists()) { + logger.debug("Can't patch shaded library id as CommandLineTools are not installed." 
+ + " Consider installing CommandLineTools with 'xcode-select --install'"); + return; + } + String newId = new String(generateUniqueId(originalName.length()), CharsetUtil.UTF_8); + if (!tryExec("install_name_tool -id " + newId + " " + libraryFile.getAbsolutePath())) { + return; + } + + tryExec("codesign -s - " + libraryFile.getAbsolutePath()); + } + + private static boolean tryExec(String cmd) { + try { + int exitValue = Runtime.getRuntime().exec(cmd).waitFor(); + if (exitValue != 0) { + logger.debug("Execution of '{}' failed: {}", cmd, exitValue); + return false; + } + logger.debug("Execution of '{}' succeed: {}", cmd, exitValue); + return true; + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } catch (IOException e) { + logger.info("Execution of '{}' failed.", cmd, e); + } catch (SecurityException e) { + logger.error("Execution of '{}' failed.", cmd, e); + } + return false; + } + + private static boolean shouldShadedLibraryIdBePatched(String packagePrefix) { + return TRY_TO_PATCH_SHADED_ID && PlatformDependent.isOsx() && !packagePrefix.isEmpty(); + } + + private static byte[] generateUniqueId(int length) { + byte[] idBytes = new byte[length]; + for (int i = 0; i < idBytes.length; i++) { + // We should only use bytes as replacement that are in our UNIQUE_ID_BYTES array. + idBytes[i] = UNIQUE_ID_BYTES[PlatformDependent.threadLocalRandom() + .nextInt(UNIQUE_ID_BYTES.length)]; + } + return idBytes; + } + + /** + * Loading the native library into the specified {@link ClassLoader}. + * + * @param loader - The {@link ClassLoader} where the native library will be loaded into + * @param name - The native library path or name + * @param absolute - Whether the native library will be loaded by path or by name + */ + private static void loadLibrary(final ClassLoader loader, final String name, final boolean absolute) { + Throwable suppressed = null; + try { + try { + // Make sure the helper belongs to the target ClassLoader. 
+ final Class newHelper = tryToLoadClass(loader, NativeLibraryUtil.class); + loadLibraryByHelper(newHelper, name, absolute); + logger.debug("Successfully loaded the library {}", name); + return; + } catch (UnsatisfiedLinkError e) { // Should by pass the UnsatisfiedLinkError here! + suppressed = e; + } catch (Exception e) { + suppressed = e; + } + NativeLibraryUtil.loadLibrary(name, absolute); // Fallback to local helper class. + logger.debug("Successfully loaded the library {}", name); + } catch (NoSuchMethodError nsme) { + if (suppressed != null) { + ThrowableUtil.addSuppressed(nsme, suppressed); + } + rethrowWithMoreDetailsIfPossible(name, nsme); + } catch (UnsatisfiedLinkError ule) { + if (suppressed != null) { + ThrowableUtil.addSuppressed(ule, suppressed); + } + throw ule; + } + } + + @SuppressJava6Requirement(reason = "Guarded by version check") + private static void rethrowWithMoreDetailsIfPossible(String name, NoSuchMethodError error) { + if (PlatformDependent.javaVersion() >= 7) { + throw new LinkageError( + "Possible multiple incompatible native libraries on the classpath for '" + name + "'?", error); + } + throw error; + } + + private static void loadLibraryByHelper(final Class helper, final String name, final boolean absolute) + throws UnsatisfiedLinkError { + Object ret; + try { + // Invoke the helper to load the native library, if succeed, then the native + // library belong to the specified ClassLoader. 
+ Method method = helper.getMethod("loadLibrary", String.class, boolean.class); + method.setAccessible(true); + ret = method.invoke(null, name, absolute); + } catch (Exception e) { + ret = e; + } + if (ret instanceof Throwable t) { + assert !(t instanceof UnsatisfiedLinkError) : t + " should be a wrapper throwable"; + Throwable cause = t.getCause(); + if (cause instanceof UnsatisfiedLinkError) { + throw (UnsatisfiedLinkError) cause; + } + UnsatisfiedLinkError ule = new UnsatisfiedLinkError(t.getMessage()); + ule.initCause(t); + throw ule; + } + } + + /** + * Try to load the helper {@link Class} into specified {@link ClassLoader}. + * + * @param loader - The {@link ClassLoader} where to load the helper {@link Class} + * @param helper - The helper {@link Class} + * @return A new helper Class defined in the specified ClassLoader. + * @throws ClassNotFoundException Helper class not found or loading failed + */ + private static Class tryToLoadClass(final ClassLoader loader, final Class helper) + throws ClassNotFoundException { + try { + return Class.forName(helper.getName(), false, loader); + } catch (ClassNotFoundException e1) { + if (loader == null) { + // cannot defineClass inside bootstrap class loader + throw e1; + } + try { + // The helper class is NOT found in target ClassLoader, we have to define the helper class. + final byte[] classBinary = classToByteArray(helper); + try { + // Define the helper class in the target ClassLoader, + // then we can call the helper to load the native library. 
+ Method defineClass = ClassLoader.class.getDeclaredMethod("defineClass", String.class, + byte[].class, int.class, int.class); + defineClass.setAccessible(true); + return (Class) defineClass.invoke(loader, helper.getName(), classBinary, 0, + classBinary.length); + } catch (Exception e) { + throw new IllegalStateException("Define class failed!", e); + } + } catch (ClassNotFoundException | RuntimeException | Error e2) { + ThrowableUtil.addSuppressed(e2, e1); + throw e2; + } + } + } + + /** + * Load the helper {@link Class} as a byte array, to be redefined in specified {@link ClassLoader}. + * + * @param clazz - The helper {@link Class} provided by this bundle + * @return The binary content of helper {@link Class}. + * @throws ClassNotFoundException Helper class not found or loading failed + */ + private static byte[] classToByteArray(Class clazz) throws ClassNotFoundException { + String fileName = clazz.getName(); + int lastDot = fileName.lastIndexOf('.'); + if (lastDot > 0) { + fileName = fileName.substring(lastDot + 1); + } + URL classUrl = clazz.getResource(fileName + ".class"); + if (classUrl == null) { + throw new ClassNotFoundException(clazz.getName()); + } + byte[] buf = new byte[1024]; + ByteArrayOutputStream out = new ByteArrayOutputStream(4096); + InputStream in = null; + try { + in = classUrl.openStream(); + for (int r; (r = in.read(buf)) != -1; ) { + out.write(buf, 0, r); + } + return out.toByteArray(); + } catch (IOException ex) { + throw new ClassNotFoundException(clazz.getName(), ex); + } finally { + closeQuietly(in); + closeQuietly(out); + } + } + + private static void closeQuietly(Closeable c) { + if (c != null) { + try { + c.close(); + } catch (IOException ignore) { + // ignore + } + } + } + + private NativeLibraryLoader() { + // Utility + } + + private static final class NoexecVolumeDetector { + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + private static boolean canExecuteExecutable(File file) throws IOException { + 
if (PlatformDependent.javaVersion() < 7) { + // Pre-JDK7, the Java API did not directly support POSIX permissions; instead of implementing a custom + // work-around, assume true, which disables the check. + return true; + } + + // If we can already execute, there is nothing to do. + if (file.canExecute()) { + return true; + } + + // On volumes, with noexec set, even files with the executable POSIX permissions will fail to execute. + // The File#canExecute() method honors this behavior, probaby via parsing the noexec flag when initializing + // the UnixFileStore, though the flag is not exposed via a public API. To find out if library is being + // loaded off a volume with noexec, confirm or add executalbe permissions, then check File#canExecute(). + + // Note: We use FQCN to not break when netty is used in java6 + Set existingFilePermissions = + java.nio.file.Files.getPosixFilePermissions(file.toPath()); + Set executePermissions = + EnumSet.of(java.nio.file.attribute.PosixFilePermission.OWNER_EXECUTE, + java.nio.file.attribute.PosixFilePermission.GROUP_EXECUTE, + java.nio.file.attribute.PosixFilePermission.OTHERS_EXECUTE); + if (existingFilePermissions.containsAll(executePermissions)) { + return false; + } + + Set newPermissions = EnumSet.copyOf(existingFilePermissions); + newPermissions.addAll(executePermissions); + java.nio.file.Files.setPosixFilePermissions(file.toPath(), newPermissions); + return file.canExecute(); + } + + private NoexecVolumeDetector() { + // Utility + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/NativeLibraryUtil.java b/netty-util/src/main/java/io/netty/util/internal/NativeLibraryUtil.java new file mode 100644 index 0000000..1e28679 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/NativeLibraryUtil.java @@ -0,0 +1,46 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file 
except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +/** + * A Utility to Call the {@link System#load(String)} or {@link System#loadLibrary(String)}. + * Because the {@link System#load(String)} and {@link System#loadLibrary(String)} are both + * CallerSensitive, it will load the native library into its caller's {@link ClassLoader}. + * In OSGi environment, we need this helper to delegate the calling to {@link System#load(String)} + * and it should be as simple as possible. It will be injected into the native library's + * ClassLoader when it is undefined. And therefore, when the defined new helper is invoked, + * the native library would be loaded into the native library's ClassLoader, not the + * caller's ClassLoader. + */ +final class NativeLibraryUtil { + /** + * Delegate the calling to {@link System#load(String)} or {@link System#loadLibrary(String)}. 
+ * + * @param libName - The native library path or name + * @param absolute - Whether the native library will be loaded by path or by name + */ + public static void loadLibrary(String libName, boolean absolute) { + if (absolute) { + System.load(libName); + } else { + System.loadLibrary(libName); + } + } + + private NativeLibraryUtil() { + // Utility + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/NoOpTypeParameterMatcher.java b/netty-util/src/main/java/io/netty/util/internal/NoOpTypeParameterMatcher.java new file mode 100644 index 0000000..6044848 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/NoOpTypeParameterMatcher.java @@ -0,0 +1,24 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.internal; + +public final class NoOpTypeParameterMatcher extends TypeParameterMatcher { + @Override + public boolean match(Object msg) { + return true; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ObjectCleaner.java b/netty-util/src/main/java/io/netty/util/internal/ObjectCleaner.java new file mode 100644 index 0000000..9e2c44c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ObjectCleaner.java @@ -0,0 +1,147 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.concurrent.FastThreadLocalThread; +import java.lang.ref.ReferenceQueue; +import java.lang.ref.WeakReference; +import java.util.Set; +import java.util.concurrent.atomic.AtomicBoolean; +import static io.netty.util.internal.SystemPropertyUtil.getInt; +import static java.lang.Math.max; + +/** + * Allows a way to register some {@link Runnable} that will executed once there are no references to an {@link Object} + * anymore. 
+ */ +public final class ObjectCleaner { + private static final int REFERENCE_QUEUE_POLL_TIMEOUT_MS = + max(500, getInt("io.netty.util.internal.ObjectCleaner.refQueuePollTimeout", 10000)); + + // Package-private for testing + static final String CLEANER_THREAD_NAME = ObjectCleaner.class.getSimpleName() + "Thread"; + // This will hold a reference to the AutomaticCleanerReference which will be removed once we called cleanup() + private static final Set LIVE_SET = new ConcurrentSet(); + private static final ReferenceQueue REFERENCE_QUEUE = new ReferenceQueue(); + private static final AtomicBoolean CLEANER_RUNNING = new AtomicBoolean(false); + private static final Runnable CLEANER_TASK = new Runnable() { + @Override + public void run() { + boolean interrupted = false; + for (; ; ) { + // Keep on processing as long as the LIVE_SET is not empty and once it becomes empty + // See if we can let this thread complete. + while (!LIVE_SET.isEmpty()) { + final AutomaticCleanerReference reference; + try { + reference = (AutomaticCleanerReference) REFERENCE_QUEUE.remove(REFERENCE_QUEUE_POLL_TIMEOUT_MS); + } catch (InterruptedException ex) { + // Just consume and move on + interrupted = true; + continue; + } + if (reference != null) { + try { + reference.cleanup(); + } catch (Throwable ignored) { + // ignore exceptions, and don't log in case the logger throws an exception, blocks, or has + // other unexpected side effects. + } + LIVE_SET.remove(reference); + } + } + CLEANER_RUNNING.set(false); + + // Its important to first access the LIVE_SET and then CLEANER_RUNNING to ensure correct + // behavior in multi-threaded environments. + if (LIVE_SET.isEmpty() || !CLEANER_RUNNING.compareAndSet(false, true)) { + // There was nothing added after we set STARTED to false or some other cleanup Thread + // was started already so its safe to let this Thread complete now. 
+ break; + } + } + if (interrupted) { + // As we caught the InterruptedException above we should mark the Thread as interrupted. + Thread.currentThread().interrupt(); + } + } + }; + + /** + * Register the given {@link Object} for which the {@link Runnable} will be executed once there are no references + * to the object anymore. + *

+ * This should only be used if there are no other ways to execute some cleanup once the Object is not reachable + * anymore because it is not a cheap way to handle the cleanup. + */ + public static void register(Object object, Runnable cleanupTask) { + AutomaticCleanerReference reference = new AutomaticCleanerReference(object, + ObjectUtil.checkNotNull(cleanupTask, "cleanupTask")); + // Its important to add the reference to the LIVE_SET before we access CLEANER_RUNNING to ensure correct + // behavior in multi-threaded environments. + LIVE_SET.add(reference); + + // Check if there is already a cleaner running. + if (CLEANER_RUNNING.compareAndSet(false, true)) { + final Thread cleanupThread = new FastThreadLocalThread(CLEANER_TASK); + cleanupThread.setPriority(Thread.MIN_PRIORITY); + // Set to null to ensure we not create classloader leaks by holding a strong reference to the inherited + // classloader. + // See: + // - https://github.com/netty/netty/issues/7290 + // - https://bugs.openjdk.java.net/browse/JDK-7008595 + cleanupThread.setContextClassLoader(null); + cleanupThread.setName(CLEANER_THREAD_NAME); + + // Mark this as a daemon thread to ensure that we the JVM can exit if this is the only thread that is + // running. + cleanupThread.setDaemon(true); + cleanupThread.start(); + } + } + + public static int getLiveSetCount() { + return LIVE_SET.size(); + } + + private ObjectCleaner() { + // Only contains a static method. 
+ } + + private static final class AutomaticCleanerReference extends WeakReference { + private final Runnable cleanupTask; + + AutomaticCleanerReference(Object referent, Runnable cleanupTask) { + super(referent, REFERENCE_QUEUE); + this.cleanupTask = cleanupTask; + } + + void cleanup() { + cleanupTask.run(); + } + + @Override + public Thread get() { + return null; + } + + @Override + public void clear() { + LIVE_SET.remove(this); + super.clear(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ObjectPool.java b/netty-util/src/main/java/io/netty/util/internal/ObjectPool.java new file mode 100644 index 0000000..0fb900a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ObjectPool.java @@ -0,0 +1,91 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.Recycler; + +/** + * Light-weight object pool. + * + * @param the type of the pooled object + */ +public abstract class ObjectPool { + + ObjectPool() { + } + + /** + * Get a {@link Object} from the {@link ObjectPool}. The returned {@link Object} may be created via + * {@link ObjectCreator#newObject(Handle)} if no pooled {@link Object} is ready to be reused. 
+ */ + public abstract T get(); + + /** + * Handle for an pooled {@link Object} that will be used to notify the {@link ObjectPool} once it can + * reuse the pooled {@link Object} again. + * + * @param + */ + public interface Handle { + /** + * Recycle the {@link Object} if possible and so make it ready to be reused. + */ + void recycle(T self); + } + + /** + * Creates a new Object which references the given {@link Handle} and calls {@link Handle#recycle(Object)} once + * it can be re-used. + * + * @param the type of the pooled object + */ + public interface ObjectCreator { + + /** + * Creates an returns a new {@link Object} that can be used and later recycled via + * {@link Handle#recycle(Object)}. + * + * @param handle can NOT be null. + */ + T newObject(Handle handle); + } + + /** + * Creates a new {@link ObjectPool} which will use the given {@link ObjectCreator} to create the {@link Object} + * that should be pooled. + */ + public static ObjectPool newPool(final ObjectCreator creator) { + return new RecyclerObjectPool(ObjectUtil.checkNotNull(creator, "creator")); + } + + private static final class RecyclerObjectPool extends ObjectPool { + private final Recycler recycler; + + RecyclerObjectPool(final ObjectCreator creator) { + recycler = new Recycler() { + @Override + protected T newObject(Handle handle) { + return creator.newObject(handle); + } + }; + } + + @Override + public T get() { + return recycler.get(); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ObjectUtil.java b/netty-util/src/main/java/io/netty/util/internal/ObjectUtil.java new file mode 100644 index 0000000..3263f49 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ObjectUtil.java @@ -0,0 +1,329 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. 
You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.internal; + +import java.util.Collection; +import java.util.Map; + +/** + * A grab-bag of useful utility methods. + */ +public final class ObjectUtil { + + private static final float FLOAT_ZERO = 0.0F; + private static final double DOUBLE_ZERO = 0.0D; + private static final long LONG_ZERO = 0L; + private static final int INT_ZERO = 0; + + private ObjectUtil() { + } + + /** + * Checks that the given argument is not null. If it is, throws {@link NullPointerException}. + * Otherwise, returns the argument. + */ + public static T checkNotNull(T arg, String text) { + if (arg == null) { + throw new NullPointerException(text); + } + return arg; + } + + /** + * Check that the given varargs is not null and does not contain elements + * null elements. + *

+ * If it is, throws {@link NullPointerException}. + * Otherwise, returns the argument. + */ + public static T[] deepCheckNotNull(String text, T... varargs) { + if (varargs == null) { + throw new NullPointerException(text); + } + + for (T element : varargs) { + if (element == null) { + throw new NullPointerException(text); + } + } + return varargs; + } + + /** + * Checks that the given argument is not null. If it is, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static T checkNotNullWithIAE(final T arg, final String paramName) throws IllegalArgumentException { + if (arg == null) { + throw new IllegalArgumentException("Param '" + paramName + "' must not be null"); + } + return arg; + } + + /** + * Checks that the given argument is not null. If it is, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + * + * @param type of the given argument value. + * @param name of the parameter, belongs to the exception message. + * @param index of the array, belongs to the exception message. + * @param value to check. + * @return the given argument value. + * @throws IllegalArgumentException if value is null. + */ + public static T checkNotNullArrayParam(T value, int index, String name) throws IllegalArgumentException { + if (value == null) { + throw new IllegalArgumentException( + "Array index " + index + " of parameter '" + name + "' must not be null"); + } + return value; + } + + /** + * Checks that the given argument is strictly positive. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static int checkPositive(int i, String name) { + if (i <= INT_ZERO) { + throw new IllegalArgumentException(name + " : " + i + " (expected: > 0)"); + } + return i; + } + + /** + * Checks that the given argument is strictly positive. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. 
+ */ + public static long checkPositive(long l, String name) { + if (l <= LONG_ZERO) { + throw new IllegalArgumentException(name + " : " + l + " (expected: > 0)"); + } + return l; + } + + /** + * Checks that the given argument is strictly positive. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static double checkPositive(final double d, final String name) { + if (d <= DOUBLE_ZERO) { + throw new IllegalArgumentException(name + " : " + d + " (expected: > 0)"); + } + return d; + } + + /** + * Checks that the given argument is strictly positive. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static float checkPositive(final float f, final String name) { + if (f <= FLOAT_ZERO) { + throw new IllegalArgumentException(name + " : " + f + " (expected: > 0)"); + } + return f; + } + + /** + * Checks that the given argument is positive or zero. If it is not , throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static int checkPositiveOrZero(int i, String name) { + if (i < INT_ZERO) { + throw new IllegalArgumentException(name + " : " + i + " (expected: >= 0)"); + } + return i; + } + + /** + * Checks that the given argument is positive or zero. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static long checkPositiveOrZero(long l, String name) { + if (l < LONG_ZERO) { + throw new IllegalArgumentException(name + " : " + l + " (expected: >= 0)"); + } + return l; + } + + /** + * Checks that the given argument is positive or zero. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. 
+ */ + public static double checkPositiveOrZero(final double d, final String name) { + if (d < DOUBLE_ZERO) { + throw new IllegalArgumentException(name + " : " + d + " (expected: >= 0)"); + } + return d; + } + + /** + * Checks that the given argument is positive or zero. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static float checkPositiveOrZero(final float f, final String name) { + if (f < FLOAT_ZERO) { + throw new IllegalArgumentException(name + " : " + f + " (expected: >= 0)"); + } + return f; + } + + /** + * Checks that the given argument is in range. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static int checkInRange(int i, int start, int end, String name) { + if (i < start || i > end) { + throw new IllegalArgumentException(name + ": " + i + " (expected: " + start + "-" + end + ")"); + } + return i; + } + + /** + * Checks that the given argument is in range. If it is not, throws {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static long checkInRange(long l, long start, long end, String name) { + if (l < start || l > end) { + throw new IllegalArgumentException(name + ": " + l + " (expected: " + start + "-" + end + ")"); + } + return l; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static T[] checkNonEmpty(T[] array, String name) { + //No String concatenation for check + if (checkNotNull(array, name).length == 0) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return array; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. 
+ */ + public static byte[] checkNonEmpty(byte[] array, String name) { + //No String concatenation for check + if (checkNotNull(array, name).length == 0) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return array; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static char[] checkNonEmpty(char[] array, String name) { + //No String concatenation for check + if (checkNotNull(array, name).length == 0) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return array; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static > T checkNonEmpty(T collection, String name) { + //No String concatenation for check + if (checkNotNull(collection, name).isEmpty()) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return collection; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static String checkNonEmpty(final String value, final String name) { + if (checkNotNull(value, name).isEmpty()) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return value; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. 
+ */ + public static > T checkNonEmpty(T value, String name) { + if (checkNotNull(value, name).isEmpty()) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return value; + } + + /** + * Checks that the given argument is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the argument. + */ + public static CharSequence checkNonEmpty(final CharSequence value, final String name) { + if (checkNotNull(value, name).length() == 0) { + throw new IllegalArgumentException("Param '" + name + "' must not be empty"); + } + return value; + } + + /** + * Trims the given argument and checks whether it is neither null nor empty. + * If it is, throws {@link NullPointerException} or {@link IllegalArgumentException}. + * Otherwise, returns the trimmed argument. + * + * @param value to trim and check. + * @param name of the parameter. + * @return the trimmed (not the original) value. + * @throws NullPointerException if value is null. + * @throws IllegalArgumentException if the trimmed value is empty. + */ + public static String checkNonEmptyAfterTrim(final String value, final String name) { + String trimmed = checkNotNull(value, name).trim(); + return checkNonEmpty(trimmed, name); + } + + /** + * Resolves a possibly null Integer to a primitive int, using a default value. + * + * @param wrapper the wrapper + * @param defaultValue the default value + * @return the primitive value + */ + public static int intValue(Integer wrapper, int defaultValue) { + return wrapper != null ? wrapper : defaultValue; + } + + /** + * Resolves a possibly null Long to a primitive long, using a default value. + * + * @param wrapper the wrapper + * @param defaultValue the default value + * @return the primitive value + */ + public static long longValue(Long wrapper, long defaultValue) { + return wrapper != null ? 
/**
 * {@link OutOfMemoryError} that is thrown if {@code PlatformDependent#allocateDirectNoCleaner(int)} cannot
 * allocate a new direct {@code java.nio.ByteBuffer} due to memory restrictions.
 */
public final class OutOfDirectMemoryError extends OutOfMemoryError {

    private static final long serialVersionUID = 4228264016184011555L;

    // Package-private: only PlatformDependent is expected to raise this error.
    OutOfDirectMemoryError(String s) {
        super(s);
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.ReferenceCountUtil; +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; + +/** + * Some pending write which should be picked up later. + */ +public final class PendingWrite { + private static final ObjectPool RECYCLER = ObjectPool.newPool(new ObjectCreator() { + @Override + public PendingWrite newObject(Handle handle) { + return new PendingWrite(handle); + } + }); + + /** + * Create a new empty {@link RecyclableArrayList} instance + */ + public static PendingWrite newInstance(Object msg, Promise promise) { + PendingWrite pending = RECYCLER.get(); + pending.msg = msg; + pending.promise = promise; + return pending; + } + + private final Handle handle; + private Object msg; + private Promise promise; + + private PendingWrite(Handle handle) { + this.handle = handle; + } + + /** + * Clear and recycle this instance. + */ + public boolean recycle() { + msg = null; + promise = null; + handle.recycle(this); + return true; + } + + /** + * Fails the underlying {@link Promise} with the given cause and recycle this instance. + */ + public boolean failAndRecycle(Throwable cause) { + ReferenceCountUtil.release(msg); + if (promise != null) { + promise.setFailure(cause); + } + return recycle(); + } + + /** + * Mark the underlying {@link Promise} successfully and recycle this instance. 
+ */ + public boolean successAndRecycle() { + if (promise != null) { + promise.setSuccess(null); + } + return recycle(); + } + + public Object msg() { + return msg; + } + + public Promise promise() { + return promise; + } + + /** + * Recycle this instance and return the {@link Promise}. + */ + public Promise recycleAndGet() { + Promise promise = this.promise; + recycle(); + return promise; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/PlatformDependent.java b/netty-util/src/main/java/io/netty/util/internal/PlatformDependent.java new file mode 100644 index 0000000..4d517d2 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/PlatformDependent.java @@ -0,0 +1,1635 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +import io.netty.util.CharsetUtil; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.io.BufferedReader; +import java.io.File; +import java.io.FileInputStream; +import java.io.IOException; +import java.io.InputStreamReader; +import java.lang.reflect.Field; +import java.lang.reflect.Method; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; +import java.nio.file.Files; +import java.util.Arrays; +import java.util.Collections; +import java.util.Deque; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.List; +import java.util.Locale; +import java.util.Map; +import java.util.Queue; +import java.util.Random; +import java.util.Set; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentLinkedDeque; +import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.LinkedBlockingDeque; +import java.util.concurrent.atomic.AtomicLong; +import java.util.regex.Matcher; +import java.util.regex.Pattern; +import org.jctools.queues.MpscArrayQueue; +import org.jctools.queues.MpscChunkedArrayQueue; +import org.jctools.queues.MpscUnboundedArrayQueue; +import org.jctools.queues.SpscLinkedQueue; +import org.jctools.queues.atomic.MpscAtomicArrayQueue; +import org.jctools.queues.atomic.MpscChunkedAtomicArrayQueue; +import org.jctools.queues.atomic.MpscUnboundedAtomicArrayQueue; +import org.jctools.queues.atomic.SpscLinkedAtomicQueue; +import org.jctools.util.Pow2; +import org.jctools.util.UnsafeAccess; +import static io.netty.util.internal.PlatformDependent0.HASH_CODE_ASCII_SEED; +import static io.netty.util.internal.PlatformDependent0.HASH_CODE_C1; +import static io.netty.util.internal.PlatformDependent0.HASH_CODE_C2; +import static io.netty.util.internal.PlatformDependent0.hashCodeAsciiSanitize; +import static io.netty.util.internal.PlatformDependent0.unalignedAccess; +import static 
java.lang.Math.max; +import static java.lang.Math.min; + +/** + * Utility that detects various properties specific to the current runtime + * environment, such as Java version and the availability of the + * {@code sun.misc.Unsafe} object. + *

+ * You can disable the use of {@code sun.misc.Unsafe} if you specify + * the system property io.netty.noUnsafe. + */ +public final class PlatformDependent { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(PlatformDependent.class); + + private static Pattern MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN; + private static final boolean MAYBE_SUPER_USER; + + private static final boolean CAN_ENABLE_TCP_NODELAY_BY_DEFAULT = !isAndroid(); + + private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause0(); + private static final boolean DIRECT_BUFFER_PREFERRED; + private static final long MAX_DIRECT_MEMORY = estimateMaxDirectMemory(); + + private static final int MPSC_CHUNK_SIZE = 1024; + private static final int MIN_MAX_MPSC_CAPACITY = MPSC_CHUNK_SIZE * 2; + private static final int MAX_ALLOWED_MPSC_CAPACITY = Pow2.MAX_POW2; + + private static final long BYTE_ARRAY_BASE_OFFSET = byteArrayBaseOffset0(); + + private static final File TMPDIR = tmpdir0(); + + private static final int BIT_MODE = bitMode0(); + private static final String NORMALIZED_ARCH = normalizeArch(SystemPropertyUtil.get("os.arch", "")); + private static final String NORMALIZED_OS = normalizeOs(SystemPropertyUtil.get("os.name", "")); + + // keep in sync with maven's pom.xml via os.detection.classifierWithLikes! 
+ private static final String[] ALLOWED_LINUX_OS_CLASSIFIERS = {"fedora", "suse", "arch"}; + private static final Set LINUX_OS_CLASSIFIERS; + + private static final boolean IS_WINDOWS = isWindows0(); + private static final boolean IS_OSX = isOsx0(); + private static final boolean IS_J9_JVM = isJ9Jvm0(); + private static final boolean IS_IVKVM_DOT_NET = isIkvmDotNet0(); + + private static final int ADDRESS_SIZE = addressSize0(); + private static final boolean USE_DIRECT_BUFFER_NO_CLEANER; + private static final AtomicLong DIRECT_MEMORY_COUNTER; + private static final long DIRECT_MEMORY_LIMIT; + private static final ThreadLocalRandomProvider RANDOM_PROVIDER; + private static final Cleaner CLEANER; + private static final int UNINITIALIZED_ARRAY_ALLOCATION_THRESHOLD; + // For specifications, see https://www.freedesktop.org/software/systemd/man/os-release.html + private static final String[] OS_RELEASE_FILES = {"/etc/os-release", "/usr/lib/os-release"}; + private static final String LINUX_ID_PREFIX = "ID="; + private static final String LINUX_ID_LIKE_PREFIX = "ID_LIKE="; + public static final boolean BIG_ENDIAN_NATIVE_ORDER = ByteOrder.nativeOrder() == ByteOrder.BIG_ENDIAN; + + private static final Cleaner NOOP = new Cleaner() { + @Override + public void freeDirectBuffer(ByteBuffer buffer) { + // NOOP + } + }; + + static { + if (javaVersion() >= 7) { + RANDOM_PROVIDER = new ThreadLocalRandomProvider() { + @Override + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + public Random current() { + return java.util.concurrent.ThreadLocalRandom.current(); + } + }; + } else { + RANDOM_PROVIDER = new ThreadLocalRandomProvider() { + @Override + public Random current() { + return ThreadLocalRandom.current(); + } + }; + } + + // Here is how the system property is used: + // + // * < 0 - Don't use cleaner, and inherit max direct memory from java. In this case the + // "practical max direct memory" would be 2 * max memory as defined by the JDK. 
+ // * == 0 - Use cleaner, Netty will not enforce max memory, and instead will defer to JDK. + // * > 0 - Don't use cleaner. This will limit Netty's total direct memory + // (note: that JDK's direct memory limit is independent of this). + long maxDirectMemory = SystemPropertyUtil.getLong("io.netty.maxDirectMemory", -1); + + if (maxDirectMemory == 0 || !hasUnsafe() || !PlatformDependent0.hasDirectBufferNoCleanerConstructor()) { + USE_DIRECT_BUFFER_NO_CLEANER = false; + DIRECT_MEMORY_COUNTER = null; + } else { + USE_DIRECT_BUFFER_NO_CLEANER = true; + if (maxDirectMemory < 0) { + maxDirectMemory = MAX_DIRECT_MEMORY; + if (maxDirectMemory <= 0) { + DIRECT_MEMORY_COUNTER = null; + } else { + DIRECT_MEMORY_COUNTER = new AtomicLong(); + } + } else { + DIRECT_MEMORY_COUNTER = new AtomicLong(); + } + } + logger.debug("-Dio.netty.maxDirectMemory: {} bytes", maxDirectMemory); + DIRECT_MEMORY_LIMIT = maxDirectMemory >= 1 ? maxDirectMemory : MAX_DIRECT_MEMORY; + + int tryAllocateUninitializedArray = + SystemPropertyUtil.getInt("io.netty.uninitializedArrayAllocationThreshold", 1024); + UNINITIALIZED_ARRAY_ALLOCATION_THRESHOLD = javaVersion() >= 9 && PlatformDependent0.hasAllocateArrayMethod() ? + tryAllocateUninitializedArray : -1; + logger.debug("-Dio.netty.uninitializedArrayAllocationThreshold: {}", UNINITIALIZED_ARRAY_ALLOCATION_THRESHOLD); + + MAYBE_SUPER_USER = maybeSuperUser0(); + + if (!isAndroid()) { + // only direct to method if we are not running on android. + // See https://github.com/netty/netty/issues/2604 + if (javaVersion() >= 9) { + CLEANER = CleanerJava9.isSupported() ? new CleanerJava9() : NOOP; + } else { + CLEANER = NOOP; + } + } else { + CLEANER = NOOP; + } + + // We should always prefer direct buffers by default if we can use a Cleaner to release direct buffers. 
+ DIRECT_BUFFER_PREFERRED = CLEANER != NOOP + && !SystemPropertyUtil.getBoolean("io.netty.noPreferDirect", false); + if (logger.isDebugEnabled()) { + logger.debug("-Dio.netty.noPreferDirect: {}", !DIRECT_BUFFER_PREFERRED); + } + + /* + * We do not want to log this message if unsafe is explicitly disabled. Do not remove the explicit no unsafe + * guard. + */ + if (CLEANER == NOOP && !PlatformDependent0.isExplicitNoUnsafe()) { + logger.info( + "Your platform does not provide complete low-level API for accessing direct buffers reliably. " + + "Unless explicitly requested, heap buffer will always be preferred to avoid potential system " + + "instability."); + } + + final Set allowedClassifiers = Collections.unmodifiableSet( + new HashSet(Arrays.asList(ALLOWED_LINUX_OS_CLASSIFIERS))); + final Set availableClassifiers = new LinkedHashSet(); + + if (!addPropertyOsClassifiers(allowedClassifiers, availableClassifiers)) { + addFilesystemOsClassifiers(allowedClassifiers, availableClassifiers); + } + LINUX_OS_CLASSIFIERS = Collections.unmodifiableSet(availableClassifiers); + } + + static void addFilesystemOsClassifiers(final Set allowedClassifiers, + final Set availableClassifiers) { + for (final String osReleaseFileName : OS_RELEASE_FILES) { + final File file = new File(osReleaseFileName); + boolean found; + if (file.exists()) { + try (BufferedReader reader = new BufferedReader( + new InputStreamReader(new FileInputStream(file), CharsetUtil.UTF_8))) { + String line; + while ((line = reader.readLine()) != null) { + if (line.startsWith(LINUX_ID_PREFIX)) { + String id = normalizeOsReleaseVariableValue( + line.substring(LINUX_ID_PREFIX.length())); + addClassifier(allowedClassifiers, availableClassifiers, id); + } else if (line.startsWith(LINUX_ID_LIKE_PREFIX)) { + line = normalizeOsReleaseVariableValue( + line.substring(LINUX_ID_LIKE_PREFIX.length())); + addClassifier(allowedClassifiers, availableClassifiers, line.split("[ ]+")); + } + } + } catch (IOException e) { + 
logger.debug("Error while reading content of {}", osReleaseFileName, e); + } + // Ignore + // specification states we should only fall back if /etc/os-release does not exist + found = true; + } else { + found = false; + } + if (found) { + break; + } + } + } + + static boolean addPropertyOsClassifiers(Set allowedClassifiers, Set availableClassifiers) { + // empty: -Dio.netty.osClassifiers (no distro specific classifiers for native libs) + // single ID: -Dio.netty.osClassifiers=ubuntu + // pair ID, ID_LIKE: -Dio.netty.osClassifiers=ubuntu,debian + // illegal otherwise + String osClassifiersPropertyName = "io.netty.osClassifiers"; + String osClassifiers = SystemPropertyUtil.get(osClassifiersPropertyName); + if (osClassifiers == null) { + return false; + } + if (osClassifiers.isEmpty()) { + // let users omit classifiers with just -Dio.netty.osClassifiers + return true; + } + String[] classifiers = osClassifiers.split(","); + if (classifiers.length == 0) { + throw new IllegalArgumentException( + osClassifiersPropertyName + " property is not empty, but contains no classifiers: " + + osClassifiers); + } + // at most ID, ID_LIKE classifiers + if (classifiers.length > 2) { + throw new IllegalArgumentException( + osClassifiersPropertyName + " property contains more than 2 classifiers: " + osClassifiers); + } + for (String classifier : classifiers) { + addClassifier(allowedClassifiers, availableClassifiers, classifier); + } + return true; + } + + public static long byteArrayBaseOffset() { + return BYTE_ARRAY_BASE_OFFSET; + } + + public static boolean hasDirectBufferNoCleanerConstructor() { + return PlatformDependent0.hasDirectBufferNoCleanerConstructor(); + } + + public static byte[] allocateUninitializedArray(int size) { + return UNINITIALIZED_ARRAY_ALLOCATION_THRESHOLD < 0 || UNINITIALIZED_ARRAY_ALLOCATION_THRESHOLD > size ? 
+ new byte[size] : PlatformDependent0.allocateUninitializedArray(size); + } + + /** + * Returns {@code true} if and only if the current platform is Android + */ + public static boolean isAndroid() { + return PlatformDependent0.isAndroid(); + } + + /** + * Return {@code true} if the JVM is running on Windows + */ + public static boolean isWindows() { + return IS_WINDOWS; + } + + /** + * Return {@code true} if the JVM is running on OSX / MacOS + */ + public static boolean isOsx() { + return IS_OSX; + } + + /** + * Return {@code true} if the current user may be a super-user. Be aware that this is just an hint and so it may + * return false-positives. + */ + public static boolean maybeSuperUser() { + return MAYBE_SUPER_USER; + } + + /** + * Return the version of Java under which this library is used. + */ + public static int javaVersion() { + return PlatformDependent0.javaVersion(); + } + + /** + * Returns {@code true} if and only if it is fine to enable TCP_NODELAY socket option by default. + */ + public static boolean canEnableTcpNoDelayByDefault() { + return CAN_ENABLE_TCP_NODELAY_BY_DEFAULT; + } + + /** + * Return {@code true} if {@code sun.misc.Unsafe} was found on the classpath and can be used for accelerated + * direct memory access. + */ + public static boolean hasUnsafe() { + return UNSAFE_UNAVAILABILITY_CAUSE == null; + } + + /** + * Return the reason (if any) why {@code sun.misc.Unsafe} was not available. + */ + public static Throwable getUnsafeUnavailabilityCause() { + return UNSAFE_UNAVAILABILITY_CAUSE; + } + + /** + * {@code true} if and only if the platform supports unaligned access. + * + * @see Wikipedia on segfault + */ + public static boolean isUnaligned() { + return PlatformDependent0.isUnaligned(); + } + + /** + * Returns {@code true} if the platform has reliable low-level direct buffer access API and a user has not specified + * {@code -Dio.netty.noPreferDirect} option. 
+ */ + public static boolean directBufferPreferred() { + return DIRECT_BUFFER_PREFERRED; + } + + /** + * Returns the maximum memory reserved for direct buffer allocation. + */ + public static long maxDirectMemory() { + return DIRECT_MEMORY_LIMIT; + } + + /** + * Returns the current memory reserved for direct buffer allocation. + * This method returns -1 in case that a value is not available. + * + * @see #maxDirectMemory() + */ + public static long usedDirectMemory() { + return DIRECT_MEMORY_COUNTER != null ? DIRECT_MEMORY_COUNTER.get() : -1; + } + + /** + * Returns the temporary directory. + */ + public static File tmpdir() { + return TMPDIR; + } + + /** + * Returns the bit mode of the current VM (usually 32 or 64.) + */ + public static int bitMode() { + return BIT_MODE; + } + + /** + * Return the address size of the OS. + * 4 (for 32 bits systems ) and 8 (for 64 bits systems). + */ + public static int addressSize() { + return ADDRESS_SIZE; + } + + public static long allocateMemory(long size) { + return PlatformDependent0.allocateMemory(size); + } + + public static void freeMemory(long address) { + PlatformDependent0.freeMemory(address); + } + + public static long reallocateMemory(long address, long newSize) { + return PlatformDependent0.reallocateMemory(address, newSize); + } + + /** + * Raises an exception bypassing compiler checks for checked exceptions. + */ + public static void throwException(Throwable t) { + if (hasUnsafe()) { + PlatformDependent0.throwException(t); + } else { + PlatformDependent.throwException0(t); + } + } + + @SuppressWarnings("unchecked") + private static void throwException0(Throwable t) throws E { + throw (E) t; + } + + /** + * Creates a new fastest {@link ConcurrentMap} implementation for the current platform. + */ + public static ConcurrentMap newConcurrentHashMap() { + return new ConcurrentHashMap(); + } + + /** + * Creates a new fastest {@link LongCounter} implementation for the current platform. 
+ */ + public static LongCounter newLongCounter() { + if (javaVersion() >= 8) { + return new LongAdderCounter(); + } else { + return new AtomicLongCounter(); + } + } + + /** + * Creates a new fastest {@link ConcurrentMap} implementation for the current platform. + */ + public static ConcurrentMap newConcurrentHashMap(int initialCapacity) { + return new ConcurrentHashMap(initialCapacity); + } + + /** + * Creates a new fastest {@link ConcurrentMap} implementation for the current platform. + */ + public static ConcurrentMap newConcurrentHashMap(int initialCapacity, float loadFactor) { + return new ConcurrentHashMap(initialCapacity, loadFactor); + } + + /** + * Creates a new fastest {@link ConcurrentMap} implementation for the current platform. + */ + public static ConcurrentMap newConcurrentHashMap( + int initialCapacity, float loadFactor, int concurrencyLevel) { + return new ConcurrentHashMap(initialCapacity, loadFactor, concurrencyLevel); + } + + /** + * Creates a new fastest {@link ConcurrentMap} implementation for the current platform. + */ + public static ConcurrentMap newConcurrentHashMap(Map map) { + return new ConcurrentHashMap(map); + } + + /** + * Try to deallocate the specified direct {@link ByteBuffer}. Please note this method does nothing if + * the current platform does not support this operation or the specified buffer is not a direct buffer. 
+ */ + public static void freeDirectBuffer(ByteBuffer buffer) { + CLEANER.freeDirectBuffer(buffer); + } + + public static long directBufferAddress(ByteBuffer buffer) { + return PlatformDependent0.directBufferAddress(buffer); + } + + public static ByteBuffer directBuffer(long memoryAddress, int size) { + if (PlatformDependent0.hasDirectBufferNoCleanerConstructor()) { + return PlatformDependent0.newDirectBuffer(memoryAddress, size); + } + throw new UnsupportedOperationException( + "sun.misc.Unsafe or java.nio.DirectByteBuffer.(long, int) not available"); + } + + public static Object getObject(Object object, long fieldOffset) { + return PlatformDependent0.getObject(object, fieldOffset); + } + + public static int getInt(Object object, long fieldOffset) { + return PlatformDependent0.getInt(object, fieldOffset); + } + + static void safeConstructPutInt(Object object, long fieldOffset, int value) { + PlatformDependent0.safeConstructPutInt(object, fieldOffset, value); + } + + public static int getIntVolatile(long address) { + return PlatformDependent0.getIntVolatile(address); + } + + public static byte getByte(long address) { + return PlatformDependent0.getByte(address); + } + + public static short getShort(long address) { + return PlatformDependent0.getShort(address); + } + + public static int getInt(long address) { + return PlatformDependent0.getInt(address); + } + + public static long getLong(long address) { + return PlatformDependent0.getLong(address); + } + + public static byte getByte(byte[] data, int index) { + return PlatformDependent0.getByte(data, index); + } + + public static byte getByte(byte[] data, long index) { + return PlatformDependent0.getByte(data, index); + } + + public static short getShort(byte[] data, int index) { + return PlatformDependent0.getShort(data, index); + } + + public static int getInt(byte[] data, int index) { + return PlatformDependent0.getInt(data, index); + } + + public static int getInt(int[] data, long index) { + return 
PlatformDependent0.getInt(data, index); + } + + public static long getLong(byte[] data, int index) { + return PlatformDependent0.getLong(data, index); + } + + public static long getLong(long[] data, long index) { + return PlatformDependent0.getLong(data, index); + } + + private static long getLongSafe(byte[] bytes, int offset) { + if (BIG_ENDIAN_NATIVE_ORDER) { + return (long) bytes[offset] << 56 | + ((long) bytes[offset + 1] & 0xff) << 48 | + ((long) bytes[offset + 2] & 0xff) << 40 | + ((long) bytes[offset + 3] & 0xff) << 32 | + ((long) bytes[offset + 4] & 0xff) << 24 | + ((long) bytes[offset + 5] & 0xff) << 16 | + ((long) bytes[offset + 6] & 0xff) << 8 | + (long) bytes[offset + 7] & 0xff; + } + return (long) bytes[offset] & 0xff | + ((long) bytes[offset + 1] & 0xff) << 8 | + ((long) bytes[offset + 2] & 0xff) << 16 | + ((long) bytes[offset + 3] & 0xff) << 24 | + ((long) bytes[offset + 4] & 0xff) << 32 | + ((long) bytes[offset + 5] & 0xff) << 40 | + ((long) bytes[offset + 6] & 0xff) << 48 | + (long) bytes[offset + 7] << 56; + } + + private static int getIntSafe(byte[] bytes, int offset) { + if (BIG_ENDIAN_NATIVE_ORDER) { + return bytes[offset] << 24 | + (bytes[offset + 1] & 0xff) << 16 | + (bytes[offset + 2] & 0xff) << 8 | + bytes[offset + 3] & 0xff; + } + return bytes[offset] & 0xff | + (bytes[offset + 1] & 0xff) << 8 | + (bytes[offset + 2] & 0xff) << 16 | + bytes[offset + 3] << 24; + } + + private static short getShortSafe(byte[] bytes, int offset) { + if (BIG_ENDIAN_NATIVE_ORDER) { + return (short) (bytes[offset] << 8 | (bytes[offset + 1] & 0xff)); + } + return (short) (bytes[offset] & 0xff | (bytes[offset + 1] << 8)); + } + + /** + * Identical to {@link PlatformDependent0#hashCodeAsciiCompute(long, int)} but for {@link CharSequence}. 
+ */ + private static int hashCodeAsciiCompute(CharSequence value, int offset, int hash) { + if (BIG_ENDIAN_NATIVE_ORDER) { + return hash * HASH_CODE_C1 + + // Low order int + hashCodeAsciiSanitizeInt(value, offset + 4) * HASH_CODE_C2 + + // High order int + hashCodeAsciiSanitizeInt(value, offset); + } + return hash * HASH_CODE_C1 + + // Low order int + hashCodeAsciiSanitizeInt(value, offset) * HASH_CODE_C2 + + // High order int + hashCodeAsciiSanitizeInt(value, offset + 4); + } + + /** + * Identical to {@link PlatformDependent0#hashCodeAsciiSanitize(int)} but for {@link CharSequence}. + */ + private static int hashCodeAsciiSanitizeInt(CharSequence value, int offset) { + if (BIG_ENDIAN_NATIVE_ORDER) { + // mimic a unsafe.getInt call on a big endian machine + return (value.charAt(offset + 3) & 0x1f) | + (value.charAt(offset + 2) & 0x1f) << 8 | + (value.charAt(offset + 1) & 0x1f) << 16 | + (value.charAt(offset) & 0x1f) << 24; + } + return (value.charAt(offset + 3) & 0x1f) << 24 | + (value.charAt(offset + 2) & 0x1f) << 16 | + (value.charAt(offset + 1) & 0x1f) << 8 | + (value.charAt(offset) & 0x1f); + } + + /** + * Identical to {@link PlatformDependent0#hashCodeAsciiSanitize(short)} but for {@link CharSequence}. + */ + private static int hashCodeAsciiSanitizeShort(CharSequence value, int offset) { + if (BIG_ENDIAN_NATIVE_ORDER) { + // mimic a unsafe.getShort call on a big endian machine + return (value.charAt(offset + 1) & 0x1f) | + (value.charAt(offset) & 0x1f) << 8; + } + return (value.charAt(offset + 1) & 0x1f) << 8 | + (value.charAt(offset) & 0x1f); + } + + /** + * Identical to {@link PlatformDependent0#hashCodeAsciiSanitize(byte)} but for {@link CharSequence}. 
 */
    private static int hashCodeAsciiSanitizeByte(char value) {
        // Keep only the low 5 bits; for ASCII letters this folds upper and lower case together.
        return value & 0x1f;
    }

    // ------------------------------------------------------------------------------------------
    // Thin wrappers delegating raw memory writes to PlatformDependent0 (sun.misc.Unsafe).
    // No bounds checks are performed here; callers must ensure hasUnsafe() and valid ranges.
    // ------------------------------------------------------------------------------------------

    public static void putByte(long address, byte value) {
        PlatformDependent0.putByte(address, value);
    }

    public static void putShort(long address, short value) {
        PlatformDependent0.putShort(address, value);
    }

    public static void putInt(long address, int value) {
        PlatformDependent0.putInt(address, value);
    }

    public static void putLong(long address, long value) {
        PlatformDependent0.putLong(address, value);
    }

    public static void putByte(byte[] data, int index, byte value) {
        PlatformDependent0.putByte(data, index, value);
    }

    public static void putByte(Object data, long offset, byte value) {
        PlatformDependent0.putByte(data, offset, value);
    }

    public static void putShort(byte[] data, int index, short value) {
        PlatformDependent0.putShort(data, index, value);
    }

    public static void putInt(byte[] data, int index, int value) {
        PlatformDependent0.putInt(data, index, value);
    }

    public static void putLong(byte[] data, int index, long value) {
        PlatformDependent0.putLong(data, index, value);
    }

    public static void putObject(Object o, long offset, Object x) {
        PlatformDependent0.putObject(o, offset, x);
    }

    public static long objectFieldOffset(Field field) {
        return PlatformDependent0.objectFieldOffset(field);
    }

    public static void copyMemory(long srcAddr, long dstAddr, long length) {
        PlatformDependent0.copyMemory(srcAddr, dstAddr, length);
    }

    // Array-based copy variants translate array indices into raw Unsafe offsets by adding
    // BYTE_ARRAY_BASE_OFFSET (the offset of element 0 within the array object header).
    public static void copyMemory(byte[] src, int srcIndex, long dstAddr, long length) {
        PlatformDependent0.copyMemory(src, BYTE_ARRAY_BASE_OFFSET + srcIndex, null, dstAddr, length);
    }

    public static void copyMemory(byte[] src, int srcIndex, byte[] dst, int dstIndex, long length) {
        PlatformDependent0.copyMemory(src, BYTE_ARRAY_BASE_OFFSET + srcIndex,
                dst, BYTE_ARRAY_BASE_OFFSET + dstIndex, length);
    }

    public static void copyMemory(long srcAddr, byte[] dst, int dstIndex, long length) {
        PlatformDependent0.copyMemory(null, srcAddr, dst, BYTE_ARRAY_BASE_OFFSET + dstIndex, length);
    }

    public static void setMemory(byte[] dst, int dstIndex, long bytes, byte value) {
        PlatformDependent0.setMemory(dst, BYTE_ARRAY_BASE_OFFSET + dstIndex, bytes, value);
    }

    public static void setMemory(long address, long bytes, byte value) {
        PlatformDependent0.setMemory(address, bytes, value);
    }

    /**
     * Allocate a new {@link ByteBuffer} with the given {@code capacity}. {@link ByteBuffer}s allocated with
     * this method MUST be deallocated via {@link #freeDirectNoCleaner(ByteBuffer)}.
     */
    public static ByteBuffer allocateDirectNoCleaner(int capacity) {
        assert USE_DIRECT_BUFFER_NO_CLEANER;

        // Account for the allocation first; roll the counter back if the allocation itself fails.
        incrementMemoryCounter(capacity);
        try {
            return PlatformDependent0.allocateDirectNoCleaner(capacity);
        } catch (Throwable e) {
            decrementMemoryCounter(capacity);
            throwException(e);
            return null; // unreachable: throwException always throws
        }
    }

    /**
     * Reallocate a new {@link ByteBuffer} with the given {@code capacity}. {@link ByteBuffer}s reallocated with
     * this method MUST be deallocated via {@link #freeDirectNoCleaner(ByteBuffer)}.
     */
    public static ByteBuffer reallocateDirectNoCleaner(ByteBuffer buffer, int capacity) {
        assert USE_DIRECT_BUFFER_NO_CLEANER;

        // Only the delta between old and new capacity is accounted for.
        int len = capacity - buffer.capacity();
        incrementMemoryCounter(len);
        try {
            return PlatformDependent0.reallocateDirectNoCleaner(buffer, capacity);
        } catch (Throwable e) {
            decrementMemoryCounter(len);
            throwException(e);
            return null; // unreachable: throwException always throws
        }
    }

    /**
     * This method MUST only be called for {@link ByteBuffer}s that were allocated via
     * {@link #allocateDirectNoCleaner(int)}.
 */
    public static void freeDirectNoCleaner(ByteBuffer buffer) {
        assert USE_DIRECT_BUFFER_NO_CLEANER;

        int capacity = buffer.capacity();
        PlatformDependent0.freeMemory(PlatformDependent0.directBufferAddress(buffer));
        decrementMemoryCounter(capacity);
    }

    /**
     * Returns {@code true} if either Unsafe-based manual alignment or
     * {@code ByteBuffer.alignedSlice} is available on this platform.
     */
    public static boolean hasAlignDirectByteBuffer() {
        return hasUnsafe() || PlatformDependent0.hasAlignSliceMethod();
    }

    /**
     * Returns a slice of {@code buffer} whose start address is aligned to {@code alignment} bytes.
     *
     * @throws IllegalArgumentException if {@code buffer} is not direct.
     * @throws UnsupportedOperationException if neither Unsafe nor {@code alignSlice} is available.
     */
    public static ByteBuffer alignDirectBuffer(ByteBuffer buffer, int alignment) {
        if (!buffer.isDirect()) {
            throw new IllegalArgumentException("Cannot get aligned slice of non-direct byte buffer.");
        }
        if (PlatformDependent0.hasAlignSliceMethod()) {
            return PlatformDependent0.alignSlice(buffer, alignment);
        }
        if (hasUnsafe()) {
            // Manual alignment: advance the position to the first aligned address, then slice.
            long address = directBufferAddress(buffer);
            long aligned = align(address, alignment);
            buffer.position((int) (aligned - address));
            return buffer.slice();
        }
        // We don't have enough information to be able to align any buffers.
        throw new UnsupportedOperationException("Cannot align direct buffer. " +
                "Needs either Unsafe or ByteBuffer.alignSlice method available.");
    }

    /**
     * Rounds {@code value} up to the next multiple of {@code alignment} (a power of two).
     */
    public static long align(long value, int alignment) {
        return Pow2.align(value, alignment);
    }

    // Adds capacity to the tracked direct-memory usage. If the new total would exceed
    // DIRECT_MEMORY_LIMIT the addition is rolled back and OutOfDirectMemoryError is thrown.
    // No-op when the counter is disabled (DIRECT_MEMORY_COUNTER == null).
    private static void incrementMemoryCounter(int capacity) {
        if (DIRECT_MEMORY_COUNTER != null) {
            long newUsedMemory = DIRECT_MEMORY_COUNTER.addAndGet(capacity);
            if (newUsedMemory > DIRECT_MEMORY_LIMIT) {
                DIRECT_MEMORY_COUNTER.addAndGet(-capacity);
                throw new OutOfDirectMemoryError("failed to allocate " + capacity
                        + " byte(s) of direct memory (used: " + (newUsedMemory - capacity)
                        + ", max: " + DIRECT_MEMORY_LIMIT + ')');
            }
        }
    }

    // Subtracts capacity from the tracked direct-memory usage (no-op when tracking is disabled).
    private static void decrementMemoryCounter(int capacity) {
        if (DIRECT_MEMORY_COUNTER != null) {
            long usedMemory = DIRECT_MEMORY_COUNTER.addAndGet(-capacity);
            assert usedMemory >= 0;
        }
    }

    public static boolean useDirectBufferNoCleaner() {
        return USE_DIRECT_BUFFER_NO_CLEANER;
    }

    /**
     * Compare two {@code byte} arrays for equality. For performance reasons no bounds checking on the
     * parameters is performed.
     *
     * @param bytes1 the first byte array.
     * @param startPos1 the position (inclusive) to start comparing in {@code bytes1}.
     * @param bytes2 the second byte array.
     * @param startPos2 the position (inclusive) to start comparing in {@code bytes2}.
     * @param length the amount of bytes to compare. This is assumed to be validated as not going out of bounds
     * by the caller.
     */
    public static boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) {
        // Fast path: when both arrays are compared whole (start offsets 0 and lengths equal to
        // 'length'), Arrays.equals is intrinsified and vectorized on Java 9+.
        if (javaVersion() > 8 && (startPos2 | startPos1 | (bytes1.length - length) | bytes2.length - length) == 0) {
            return Arrays.equals(bytes1, bytes2);
        }
        return !hasUnsafe() || !unalignedAccess() ?
                equalsSafe(bytes1, startPos1, bytes2, startPos2, length) :
                PlatformDependent0.equals(bytes1, startPos1, bytes2, startPos2, length);
    }

    /**
     * Determine if a subsection of an array is zero.
     *
     * @param bytes The byte array.
     * @param startPos The starting index (inclusive) in {@code bytes}.
     * @param length The amount of bytes to check for zero.
     * @return {@code false} if {@code bytes[startPos:startPos+length)} contains a value other than zero.
     */
    public static boolean isZero(byte[] bytes, int startPos, int length) {
        // Word-at-a-time Unsafe scan only when unaligned access is supported; byte loop otherwise.
        return !hasUnsafe() || !unalignedAccess() ?
                isZeroSafe(bytes, startPos, length) :
                PlatformDependent0.isZero(bytes, startPos, length);
    }

    /**
     * Compare two {@code byte} arrays for equality without leaking timing information.
     * For performance reasons no bounds checking on the parameters is performed.
     * <p>
     * The {@code int} return type is intentional and is designed to allow cascading of constant time operations:
     * <pre>
     *     byte[] s1 = new {1, 2, 3};
     *     byte[] s2 = new {1, 2, 3};
     *     byte[] s3 = new {1, 2, 3};
     *     byte[] s4 = new {4, 5, 6};
     *     boolean equals = (equalsConstantTime(s1, 0, s2, 0, s1.length) &
     *                       equalsConstantTime(s3, 0, s4, 0, s3.length)) != 0;
     * </pre>
     *
     * @param bytes1 the first byte array.
     * @param startPos1 the position (inclusive) to start comparing in {@code bytes1}.
     * @param bytes2 the second byte array.
     * @param startPos2 the position (inclusive) to start comparing in {@code bytes2}.
     * @param length the amount of bytes to compare. This is assumed to be validated as not going out of bounds
     * by the caller.
     * @return {@code 0} if not equal. {@code 1} if equal.
     */
    public static int equalsConstantTime(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) {
        return !hasUnsafe() || !unalignedAccess() ?
                ConstantTimeUtils.equalsConstantTime(bytes1, startPos1, bytes2, startPos2, length) :
                PlatformDependent0.equalsConstantTime(bytes1, startPos1, bytes2, startPos2, length);
    }

    /**
     * Calculate a hash code of a byte array assuming ASCII character encoding.
     * The resulting hash code will be case insensitive.
     *
     * @param bytes The array which contains the data to hash.
     * @param startPos What index to start generating a hash code in {@code bytes}
     * @param length The amount of bytes that should be accounted for in the computation.
     * @return The hash code of {@code bytes} assuming ASCII character encoding.
     * The resulting hash code will be case insensitive.
     */
    public static int hashCodeAscii(byte[] bytes, int startPos, int length) {
        return !hasUnsafe() || !unalignedAccess() ?
                hashCodeAsciiSafe(bytes, startPos, length) :
                PlatformDependent0.hashCodeAscii(bytes, startPos, length);
    }

    /**
     * Calculate a hash code of a byte array assuming ASCII character encoding.
     * The resulting hash code will be case insensitive.
     *
     * <p>
     * This method assumes that {@code bytes} is equivalent to a {@code byte[]} but just using {@link CharSequence}
     * for storage. The upper most byte of each {@code char} from {@code bytes} is ignored.
     *
     * @param bytes The array which contains the data to hash (assumed to be equivalent to a {@code byte[]}).
     * @return The hash code of {@code bytes} assuming ASCII character encoding.
     * The resulting hash code will be case insensitive.
     */
    public static int hashCodeAscii(CharSequence bytes) {
        final int length = bytes.length();
        final int remainingBytes = length & 7;
        int hash = HASH_CODE_ASCII_SEED;
        // Benchmarking shows that by just naively looping for inputs 8~31 bytes long we incur a relatively large
        // performance penalty (only achieve about 60% performance of loop which iterates over each char). So because
        // of this we take special provisions to unroll the looping for these conditions.
        if (length >= 32) {
            for (int i = length - 8; i >= remainingBytes; i -= 8) {
                hash = hashCodeAsciiCompute(bytes, i, hash);
            }
        } else if (length >= 8) {
            hash = hashCodeAsciiCompute(bytes, length - 8, hash);
            if (length >= 16) {
                hash = hashCodeAsciiCompute(bytes, length - 16, hash);
                if (length >= 24) {
                    hash = hashCodeAsciiCompute(bytes, length - 24, hash);
                }
            }
        }
        if (remainingBytes == 0) {
            return hash;
        }
        // Consume the 1-7 leftover chars in up to three steps (1 byte, 2 bytes, 4 bytes).
        // The non-short-circuit '&' / '|' operators below are used as written in the original.
        int offset = 0;
        if (remainingBytes != 2 & remainingBytes != 4 & remainingBytes != 6) { // 1, 3, 5, 7
            hash = hash * HASH_CODE_C1 + hashCodeAsciiSanitizeByte(bytes.charAt(0));
            offset = 1;
        }
        if (remainingBytes != 1 & remainingBytes != 4 & remainingBytes != 5) { // 2, 3, 6, 7
            hash = hash * (offset == 0 ? HASH_CODE_C1 : HASH_CODE_C2)
                    + hashCodeAsciiSanitize(hashCodeAsciiSanitizeShort(bytes, offset));
            offset += 2;
        }
        if (remainingBytes >= 4) { // 4, 5, 6, 7
            return hash * ((offset == 0 | offset == 3) ? HASH_CODE_C1 : HASH_CODE_C2)
                    + hashCodeAsciiSanitizeInt(bytes, offset);
        }
        return hash;
    }

    // Chooses between the Unsafe-backed and the atomic-field-updater-backed JCTools MPSC queue
    // implementations, decided once at class load.
    // NOTE(review): generic type parameters appear stripped in this copy (raw Queue); the upstream
    // Netty source declares these as <T> Queue<T> — confirm against upstream before relying on this.
    private static final class Mpsc {
        private static final boolean USE_MPSC_CHUNKED_ARRAY_QUEUE;

        private Mpsc() {
        }

        static {
            Object unsafe = null;
            if (hasUnsafe()) {
                unsafe = UnsafeAccess.UNSAFE;
            }
            if (unsafe == null) {
                logger.debug("MpscChunkedArrayQueue: unavailable");
                USE_MPSC_CHUNKED_ARRAY_QUEUE = false;
            } else {
                logger.debug("MpscChunkedArrayQueue: available");
                USE_MPSC_CHUNKED_ARRAY_QUEUE = true;
            }
        }

        static Queue newMpscQueue(final int maxCapacity) {
            // Calculate the max capacity which can not be bigger than MAX_ALLOWED_MPSC_CAPACITY.
            // This is forced by the MpscChunkedArrayQueue implementation as will try to round it
            // up to the next power of two and so will overflow otherwise.
            final int capacity = max(min(maxCapacity, MAX_ALLOWED_MPSC_CAPACITY), MIN_MAX_MPSC_CAPACITY);
            return newChunkedMpscQueue(MPSC_CHUNK_SIZE, capacity);
        }

        static Queue newChunkedMpscQueue(final int chunkSize, final int capacity) {
            return USE_MPSC_CHUNKED_ARRAY_QUEUE ? new MpscChunkedArrayQueue(chunkSize, capacity)
                    : new MpscChunkedAtomicArrayQueue(chunkSize, capacity);
        }

        static Queue newMpscQueue() {
            return USE_MPSC_CHUNKED_ARRAY_QUEUE ? new MpscUnboundedArrayQueue(MPSC_CHUNK_SIZE)
                    : new MpscUnboundedAtomicArrayQueue(MPSC_CHUNK_SIZE);
        }
    }

    /**
     * Create a new {@link Queue} which is safe to use for multiple producers (different threads) and a single
     * consumer (one thread!).
     *
     * @return A MPSC queue which may be unbounded.
     */
    public static Queue newMpscQueue() {
        return Mpsc.newMpscQueue();
    }

    /**
     * Create a new {@link Queue} which is safe to use for multiple producers (different threads) and a single
     * consumer (one thread!).
 */
    public static Queue newMpscQueue(final int maxCapacity) {
        return Mpsc.newMpscQueue(maxCapacity);
    }

    /**
     * Create a new {@link Queue} which is safe to use for multiple producers (different threads) and a single
     * consumer (one thread!).
     * The queue will grow and shrink its capacity in units of the given chunk size.
     */
    public static Queue newMpscQueue(final int chunkSize, final int maxCapacity) {
        return Mpsc.newChunkedMpscQueue(chunkSize, maxCapacity);
    }

    /**
     * Create a new {@link Queue} which is safe to use for single producer (one thread!) and a single
     * consumer (one thread!).
     */
    public static Queue newSpscQueue() {
        return hasUnsafe() ? new SpscLinkedQueue() : new SpscLinkedAtomicQueue();
    }

    /**
     * Create a new {@link Queue} which is safe to use for multiple producers (different threads) and a single
     * consumer (one thread!) with the given fixed {@code capacity}.
     */
    public static Queue newFixedMpscQueue(int capacity) {
        return hasUnsafe() ? new MpscArrayQueue(capacity) : new MpscAtomicArrayQueue(capacity);
    }

    /**
     * Return the {@link ClassLoader} for the given {@link Class}.
     */
    public static ClassLoader getClassLoader(final Class clazz) {
        return PlatformDependent0.getClassLoader(clazz);
    }

    /**
     * Return the context {@link ClassLoader} for the current {@link Thread}.
     */
    public static ClassLoader getContextClassLoader() {
        return PlatformDependent0.getContextClassLoader();
    }

    /**
     * Return the system {@link ClassLoader}.
     */
    public static ClassLoader getSystemClassLoader() {
        return PlatformDependent0.getSystemClassLoader();
    }

    /**
     * Returns a new concurrent {@link Deque}.
+ */ + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + public static Deque newConcurrentDeque() { + if (javaVersion() < 7) { + return new LinkedBlockingDeque(); + } else { + return new ConcurrentLinkedDeque(); + } + } + + /** + * Return a {@link Random} which is not-threadsafe and so can only be used from the same thread. + */ + public static Random threadLocalRandom() { + return RANDOM_PROVIDER.current(); + } + + private static boolean isWindows0() { + boolean windows = "windows".equals(NORMALIZED_OS); + if (windows) { + logger.debug("Platform: Windows"); + } + return windows; + } + + private static boolean isOsx0() { + boolean osx = "osx".equals(NORMALIZED_OS); + if (osx) { + logger.debug("Platform: MacOS"); + } + return osx; + } + + private static boolean maybeSuperUser0() { + String username = SystemPropertyUtil.get("user.name"); + if (isWindows()) { + return "Administrator".equals(username); + } + // Check for root and toor as some BSDs have a toor user that is basically the same as root. + return "root".equals(username) || "toor".equals(username); + } + + private static Throwable unsafeUnavailabilityCause0() { + if (isAndroid()) { + logger.debug("sun.misc.Unsafe: unavailable (Android)"); + return new UnsupportedOperationException("sun.misc.Unsafe: unavailable (Android)"); + } + + if (isIkvmDotNet()) { + logger.debug("sun.misc.Unsafe: unavailable (IKVM.NET)"); + return new UnsupportedOperationException("sun.misc.Unsafe: unavailable (IKVM.NET)"); + } + + Throwable cause = PlatformDependent0.getUnsafeUnavailabilityCause(); + if (cause != null) { + return cause; + } + + try { + boolean hasUnsafe = PlatformDependent0.hasUnsafe(); + logger.debug("sun.misc.Unsafe: {}", hasUnsafe ? "available" : "unavailable"); + return hasUnsafe ? null : PlatformDependent0.getUnsafeUnavailabilityCause(); + } catch (Throwable t) { + logger.trace("Could not determine if Unsafe is available", t); + // Probably failed to initialize PlatformDependent0. 
+ return new UnsupportedOperationException("Could not determine if Unsafe is available", t); + } + } + + /** + * Returns {@code true} if the running JVM is either IBM J9 or + * Eclipse OpenJ9, {@code false} otherwise. + */ + public static boolean isJ9Jvm() { + return IS_J9_JVM; + } + + private static boolean isJ9Jvm0() { + String vmName = SystemPropertyUtil.get("java.vm.name", "").toLowerCase(); + return vmName.startsWith("ibm j9") || vmName.startsWith("eclipse openj9"); + } + + /** + * Returns {@code true} if the running JVM is IKVM.NET, {@code false} otherwise. + */ + public static boolean isIkvmDotNet() { + return IS_IVKVM_DOT_NET; + } + + private static boolean isIkvmDotNet0() { + String vmName = SystemPropertyUtil.get("java.vm.name", "").toUpperCase(Locale.US); + return vmName.equals("IKVM.NET"); + } + + private static Pattern getMaxDirectMemorySizeArgPattern() { + // Pattern's is immutable so it's always safe published + Pattern pattern = MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN; + if (pattern == null) { + pattern = Pattern.compile("\\s*-XX:MaxDirectMemorySize\\s*=\\s*([0-9]+)\\s*([kKmMgG]?)\\s*$"); + MAX_DIRECT_MEMORY_SIZE_ARG_PATTERN = pattern; + } + return pattern; + } + + /** + * Compute an estimate of the maximum amount of direct memory available to this JVM. + *

+ * The computation is not cached, so you probably want to use {@link #maxDirectMemory()} instead. + *

+ * This will produce debug log output when called. + * + * @return The estimated max direct memory, in bytes. + */ + public static long estimateMaxDirectMemory() { + long maxDirectMemory = PlatformDependent0.bitsMaxDirectMemory(); + if (maxDirectMemory > 0) { + return maxDirectMemory; + } + + ClassLoader systemClassLoader = null; + try { + systemClassLoader = getSystemClassLoader(); + + // When using IBM J9 / Eclipse OpenJ9 we should not use VM.maxDirectMemory() as it not reflects the + // correct value. + // See: + // - https://github.com/netty/netty/issues/7654 + String vmName = SystemPropertyUtil.get("java.vm.name", "").toLowerCase(); + if (!vmName.startsWith("ibm j9") && + // https://github.com/eclipse/openj9/blob/openj9-0.8.0/runtime/include/vendor_version.h#L53 + !vmName.startsWith("eclipse openj9")) { + // Try to get from sun.misc.VM.maxDirectMemory() which should be most accurate. + Class vmClass = Class.forName("sun.misc.VM", true, systemClassLoader); + Method m = vmClass.getDeclaredMethod("maxDirectMemory"); + maxDirectMemory = ((Number) m.invoke(null)).longValue(); + } + } catch (Throwable ignored) { + // Ignore + } + + if (maxDirectMemory > 0) { + return maxDirectMemory; + } + + try { + // Now try to get the JVM option (-XX:MaxDirectMemorySize) and parse it. + // Note that we are using reflection because Android doesn't have these classes. 
+ Class mgmtFactoryClass = Class.forName( + "java.lang.management.ManagementFactory", true, systemClassLoader); + Class runtimeClass = Class.forName( + "java.lang.management.RuntimeMXBean", true, systemClassLoader); + + Object runtime = mgmtFactoryClass.getDeclaredMethod("getRuntimeMXBean").invoke(null); + + @SuppressWarnings("unchecked") + List vmArgs = (List) runtimeClass.getDeclaredMethod("getInputArguments").invoke(runtime); + + Pattern maxDirectMemorySizeArgPattern = getMaxDirectMemorySizeArgPattern(); + + for (int i = vmArgs.size() - 1; i >= 0; i--) { + Matcher m = maxDirectMemorySizeArgPattern.matcher(vmArgs.get(i)); + if (!m.matches()) { + continue; + } + + maxDirectMemory = Long.parseLong(m.group(1)); + switch (m.group(2).charAt(0)) { + case 'k': + case 'K': + maxDirectMemory *= 1024; + break; + case 'm': + case 'M': + maxDirectMemory *= 1024 * 1024; + break; + case 'g': + case 'G': + maxDirectMemory *= 1024 * 1024 * 1024; + break; + default: + break; + } + break; + } + } catch (Throwable ignored) { + // Ignore + } + + if (maxDirectMemory <= 0) { + maxDirectMemory = Runtime.getRuntime().maxMemory(); + logger.debug("maxDirectMemory: {} bytes (maybe)", maxDirectMemory); + } else { + logger.debug("maxDirectMemory: {} bytes", maxDirectMemory); + } + + return maxDirectMemory; + } + + private static File tmpdir0() { + File f; + try { + f = toDirectory(SystemPropertyUtil.get("io.netty.tmpdir")); + if (f != null) { + logger.debug("-Dio.netty.tmpdir: {}", f); + return f; + } + + f = toDirectory(SystemPropertyUtil.get("java.io.tmpdir")); + if (f != null) { + logger.debug("-Dio.netty.tmpdir: {} (java.io.tmpdir)", f); + return f; + } + + // This shouldn't happen, but just in case .. 
+ if (isWindows()) { + f = toDirectory(System.getenv("TEMP")); + if (f != null) { + logger.debug("-Dio.netty.tmpdir: {} (%TEMP%)", f); + return f; + } + + String userprofile = System.getenv("USERPROFILE"); + if (userprofile != null) { + f = toDirectory(userprofile + "\\AppData\\Local\\Temp"); + if (f != null) { + logger.debug("-Dio.netty.tmpdir: {} (%USERPROFILE%\\AppData\\Local\\Temp)", f); + return f; + } + + f = toDirectory(userprofile + "\\Local Settings\\Temp"); + if (f != null) { + logger.debug("-Dio.netty.tmpdir: {} (%USERPROFILE%\\Local Settings\\Temp)", f); + return f; + } + } + } else { + f = toDirectory(System.getenv("TMPDIR")); + if (f != null) { + logger.debug("-Dio.netty.tmpdir: {} ($TMPDIR)", f); + return f; + } + } + } catch (Throwable ignored) { + // Environment variable inaccessible + } + + // Last resort. + if (isWindows()) { + f = new File("C:\\Windows\\Temp"); + } else { + f = new File("/tmp"); + } + + logger.warn("Failed to get the temporary directory; falling back to: {}", f); + return f; + } + + @SuppressWarnings("ResultOfMethodCallIgnored") + private static File toDirectory(String path) { + if (path == null) { + return null; + } + + File f = new File(path); + f.mkdirs(); + + if (!f.isDirectory()) { + return null; + } + + try { + return f.getAbsoluteFile(); + } catch (Exception ignored) { + return f; + } + } + + private static int bitMode0() { + // Check user-specified bit mode first. + int bitMode = SystemPropertyUtil.getInt("io.netty.bitMode", 0); + if (bitMode > 0) { + logger.debug("-Dio.netty.bitMode: {}", bitMode); + return bitMode; + } + + // And then the vendor specific ones which is probably most reliable. 
+ bitMode = SystemPropertyUtil.getInt("sun.arch.data.model", 0); + if (bitMode > 0) { + logger.debug("-Dio.netty.bitMode: {} (sun.arch.data.model)", bitMode); + return bitMode; + } + bitMode = SystemPropertyUtil.getInt("com.ibm.vm.bitmode", 0); + if (bitMode > 0) { + logger.debug("-Dio.netty.bitMode: {} (com.ibm.vm.bitmode)", bitMode); + return bitMode; + } + + // os.arch also gives us a good hint. + String arch = SystemPropertyUtil.get("os.arch", "").toLowerCase(Locale.US).trim(); + if ("amd64".equals(arch) || "x86_64".equals(arch)) { + bitMode = 64; + } else if ("i386".equals(arch) || "i486".equals(arch) || "i586".equals(arch) || "i686".equals(arch)) { + bitMode = 32; + } + + if (bitMode > 0) { + logger.debug("-Dio.netty.bitMode: {} (os.arch: {})", bitMode, arch); + } + + // Last resort: guess from VM name and then fall back to most common 64-bit mode. + String vm = SystemPropertyUtil.get("java.vm.name", "").toLowerCase(Locale.US); + Pattern bitPattern = Pattern.compile("([1-9][0-9]+)-?bit"); + Matcher m = bitPattern.matcher(vm); + if (m.find()) { + return Integer.parseInt(m.group(1)); + } else { + return 64; + } + } + + private static int addressSize0() { + if (!hasUnsafe()) { + return -1; + } + return PlatformDependent0.addressSize(); + } + + private static long byteArrayBaseOffset0() { + if (!hasUnsafe()) { + return -1; + } + return PlatformDependent0.byteArrayBaseOffset(); + } + + private static boolean equalsSafe(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) { + final int end = startPos1 + length; + for (; startPos1 < end; ++startPos1, ++startPos2) { + if (bytes1[startPos1] != bytes2[startPos2]) { + return false; + } + } + return true; + } + + private static boolean isZeroSafe(byte[] bytes, int startPos, int length) { + final int end = startPos + length; + for (; startPos < end; ++startPos) { + if (bytes[startPos] != 0) { + return false; + } + } + return true; + } + + /** + * Package private for testing purposes only! 
 */
    static int hashCodeAsciiSafe(byte[] bytes, int startPos, int length) {
        // Non-Unsafe fallback for hashCodeAscii: process 8 bytes per iteration from the tail,
        // then fold the 0-7 leftover bytes via the unrolled switch below. Must produce the same
        // value as the Unsafe-backed PlatformDependent0.hashCodeAscii.
        int hash = HASH_CODE_ASCII_SEED;
        final int remainingBytes = length & 7;
        final int end = startPos + remainingBytes;
        for (int i = startPos - 8 + length; i >= end; i -= 8) {
            hash = PlatformDependent0.hashCodeAsciiCompute(getLongSafe(bytes, i), hash);
        }
        switch (remainingBytes) {
        case 7:
            return ((hash * HASH_CODE_C1 + hashCodeAsciiSanitize(bytes[startPos]))
                    * HASH_CODE_C2 + hashCodeAsciiSanitize(getShortSafe(bytes, startPos + 1)))
                    * HASH_CODE_C1 + hashCodeAsciiSanitize(getIntSafe(bytes, startPos + 3));
        case 6:
            return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(getShortSafe(bytes, startPos)))
                    * HASH_CODE_C2 + hashCodeAsciiSanitize(getIntSafe(bytes, startPos + 2));
        case 5:
            return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(bytes[startPos]))
                    * HASH_CODE_C2 + hashCodeAsciiSanitize(getIntSafe(bytes, startPos + 1));
        case 4:
            return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(getIntSafe(bytes, startPos));
        case 3:
            return (hash * HASH_CODE_C1 + hashCodeAsciiSanitize(bytes[startPos]))
                    * HASH_CODE_C2 + hashCodeAsciiSanitize(getShortSafe(bytes, startPos + 1));
        case 2:
            return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(getShortSafe(bytes, startPos));
        case 1:
            return hash * HASH_CODE_C1 + hashCodeAsciiSanitize(bytes[startPos]);
        default:
            return hash;
        }
    }

    public static String normalizedArch() {
        return NORMALIZED_ARCH;
    }

    public static String normalizedOs() {
        return NORMALIZED_OS;
    }

    public static Set normalizedLinuxClassifiers() {
        return LINUX_OS_CLASSIFIERS;
    }

    /**
     * Creates a temporary file readable only by the owner. On Java 7+ this relies on
     * {@link java.nio.file.Files#createTempFile} (owner-restricted by default); on Java 6 the
     * permissions are tightened manually after creation.
     */
    @SuppressJava6Requirement(reason = "Guarded by version check")
    public static File createTempFile(String prefix, String suffix, File directory) throws IOException {
        if (javaVersion() >= 7) {
            if (directory == null) {
                return Files.createTempFile(prefix, suffix).toFile();
            }
            return Files.createTempFile(directory.toPath(), prefix, suffix).toFile();
        }
        final File file;
        if (directory == null) {
            file = File.createTempFile(prefix, suffix);
        } else {
            file = File.createTempFile(prefix, suffix, directory);
        }

        // Try to adjust the perms, if this fails there is not much else we can do...
        // First revoke read for everyone, then re-grant it for the owner only.
        if (!file.setReadable(false, false)) {
            throw new IOException("Failed to set permissions on temporary file " + file);
        }
        if (!file.setReadable(true, true)) {
            throw new IOException("Failed to set permissions on temporary file " + file);
        }
        return file;
    }

    /**
     * Adds only those classifier strings to dest which are present in allowed.
     *
     * @param allowed allowed classifiers
     * @param dest destination set
     * @param maybeClassifiers potential classifiers to add
     */
    private static void addClassifier(Set allowed, Set dest, String... maybeClassifiers) {
        for (String id : maybeClassifiers) {
            if (allowed.contains(id)) {
                dest.add(id);
            }
        }
    }

    private static String normalizeOsReleaseVariableValue(String value) {
        // Variable assignment values may be enclosed in double or single quotes.
        return value.trim().replaceAll("[\"']", "");
    }

    // Lower-cases and strips all non-alphanumerics, producing a canonical token for matching.
    private static String normalize(String value) {
        return value.toLowerCase(Locale.US).replaceAll("[^a-z0-9]+", "");
    }

    // Maps the many os.arch spellings onto the canonical os-maven-plugin style classifiers.
    private static String normalizeArch(String value) {
        value = normalize(value);
        if (value.matches("^(x8664|amd64|ia32e|em64t|x64)$")) {
            return "x86_64";
        }
        if (value.matches("^(x8632|x86|i[3-6]86|ia32|x32)$")) {
            return "x86_32";
        }
        if (value.matches("^(ia64|itanium64)$")) {
            return "itanium_64";
        }
        if (value.matches("^(sparc|sparc32)$")) {
            return "sparc_32";
        }
        if (value.matches("^(sparcv9|sparc64)$")) {
            return "sparc_64";
        }
        if (value.matches("^(arm|arm32)$")) {
            return "arm_32";
        }
        if ("aarch64".equals(value)) {
            return "aarch_64";
        }
        if ("riscv64".equals(value)) {
            // os.detected.arch is riscv64 for RISC-V, no underscore
            return "riscv64";
        }
        if (value.matches("^(ppc|ppc32)$")) {
            return "ppc_32";
        }
        if ("ppc64".equals(value)) {
            return "ppc_64";
        }
        if ("ppc64le".equals(value)) {
            return "ppcle_64";
        }
        if ("s390".equals(value)) {
            return "s390_32";
        }
        if ("s390x".equals(value)) {
            return "s390_64";
        }
        if ("loongarch64".equals(value)) {
            return "loongarch_64";
        }

        return "unknown";
    }

    // Maps the many os.name spellings onto canonical OS tokens.
    private static String normalizeOs(String value) {
        value = normalize(value);
        if (value.startsWith("aix")) {
            return "aix";
        }
        if (value.startsWith("hpux")) {
            return "hpux";
        }
        if (value.startsWith("os400")) {
            // Avoid the names such as os4000
            if (value.length() <= 5 || !Character.isDigit(value.charAt(5))) {
                return "os400";
            }
        }
        if (value.startsWith("linux")) {
            return "linux";
        }
        if (value.startsWith("macosx") || value.startsWith("osx") || value.startsWith("darwin")) {
            return "osx";
        }
        if (value.startsWith("freebsd")) {
            return "freebsd";
        }
        if (value.startsWith("openbsd")) {
            return "openbsd";
        }
        if (value.startsWith("netbsd")) {
            return "netbsd";
        }
        if (value.startsWith("solaris") ||
value.startsWith("sunos")) {
            return "sunos";
        }
        if (value.startsWith("windows")) {
            return "windows";
        }

        return "unknown";
    }

    /**
     * {@link LongCounter} backed by a plain {@link AtomicLong}; used when a striped/adder-based
     * counter is not available or not needed.
     */
    private static final class AtomicLongCounter extends AtomicLong implements LongCounter {
        private static final long serialVersionUID = 4074772784610639305L;

        @Override
        public void add(long delta) {
            addAndGet(delta);
        }

        @Override
        public void increment() {
            incrementAndGet();
        }

        @Override
        public void decrement() {
            decrementAndGet();
        }

        @Override
        public long value() {
            return get();
        }
    }

    // Abstraction over the source of per-thread Random instances (see threadLocalRandom()).
    private interface ThreadLocalRandomProvider {
        Random current();
    }

    private PlatformDependent() {
        // only static method supported
    }
}
diff --git a/netty-util/src/main/java/io/netty/util/internal/PlatformDependent0.java b/netty-util/src/main/java/io/netty/util/internal/PlatformDependent0.java
new file mode 100644
index 0000000..ef39829
--- /dev/null
+++ b/netty-util/src/main/java/io/netty/util/internal/PlatformDependent0.java
@@ -0,0 +1,931 @@
package io.netty.util.internal;

import io.netty.util.internal.logging.InternalLogger;
import io.netty.util.internal.logging.InternalLoggerFactory;
import java.lang.reflect.Constructor;
import java.lang.reflect.Field;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.nio.Buffer;
import java.nio.ByteBuffer;
import java.util.Objects;
import java.util.concurrent.atomic.AtomicLong;
import sun.misc.Unsafe;
import static io.netty.util.internal.ObjectUtil.checkNotNull;

/**
 * The {@link PlatformDependent} operations which require access to {@code sun.misc.*}.
 */
@SuppressJava6Requirement(reason = "Unsafe access is guarded")
final class PlatformDependent0 {

    private static final InternalLogger logger = InternalLoggerFactory.getInstance(PlatformDependent0.class);
    // Unsafe offset of java.nio.Buffer's 'address' field (direct buffer native address).
    private static final long ADDRESS_FIELD_OFFSET;
    // Base offsets / index scales for raw array access via Unsafe.
    private static final long BYTE_ARRAY_BASE_OFFSET;
    private static final long INT_ARRAY_BASE_OFFSET;
    private static final long INT_ARRAY_INDEX_SCALE;
    private static final long LONG_ARRAY_BASE_OFFSET;
    private static final long LONG_ARRAY_INDEX_SCALE;
    // DirectByteBuffer(long, int) constructor, used to wrap raw memory without a Cleaner.
    private static final Constructor DIRECT_BUFFER_CONSTRUCTOR;
    // Non-null when the user explicitly disabled Unsafe (e.g. -Dio.netty.noUnsafe=true).
    private static final Throwable EXPLICIT_NO_UNSAFE_CAUSE = explicitNoUnsafeCause0();
    private static final Method ALLOCATE_ARRAY_METHOD;
    private static final Method ALIGN_SLICE;
    private static final int JAVA_VERSION = javaVersion0();
    private static final boolean IS_ANDROID = isAndroid0();
    private static final boolean STORE_FENCE_AVAILABLE;

    private static final Throwable UNSAFE_UNAVAILABILITY_CAUSE;
    // jdk.internal.misc.Unsafe instance on Java 9+, when reachable.
    private static final Object INTERNAL_UNSAFE;

    // See https://github.com/oracle/graal/blob/master/sdk/src/org.graalvm.nativeimage/src/org/graalvm/nativeimage/
    // ImageInfo.java
    private static final boolean RUNNING_IN_NATIVE_IMAGE = SystemPropertyUtil.contains(
            "org.graalvm.nativeimage.imagecode");

    private static final boolean IS_EXPLICIT_TRY_REFLECTION_SET_ACCESSIBLE = explicitTryReflectionSetAccessible0();

    static final Unsafe UNSAFE;

    // constants borrowed from murmur3
    static final int HASH_CODE_ASCII_SEED = 0xc2b2ae35;
    static final int HASH_CODE_C1 = 0xcc9e2d51;
    static final int HASH_CODE_C2 = 0x1b873593;

    /**
     * Limits the number of bytes to copy per {@link Unsafe#copyMemory(long, long, long)} to allow safepoint polling
     * during a large copy.
+ */ + private static final long UNSAFE_COPY_THRESHOLD = 1024L * 1024L; + + private static final boolean UNALIGNED; + + private static final long BITS_MAX_DIRECT_MEMORY; + + static { + final ByteBuffer direct; + Field addressField = null; + Method allocateArrayMethod = null; + Throwable unsafeUnavailabilityCause = null; + Unsafe unsafe; + Object internalUnsafe = null; + boolean storeFenceAvailable = false; + if ((unsafeUnavailabilityCause = EXPLICIT_NO_UNSAFE_CAUSE) != null) { + direct = null; + addressField = null; + unsafe = null; + internalUnsafe = null; + } else { + direct = ByteBuffer.allocateDirect(1); + + // attempt to access field Unsafe#theUnsafe + Object maybeUnsafe; + try { + final Field unsafeField = Unsafe.class.getDeclaredField("theUnsafe"); + // We always want to try using Unsafe as the access still works on java9 as well and + // we need it for out native-transports and many optimizations. + Throwable cause = ReflectionUtil.trySetAccessible(unsafeField, false); + if (cause != null) { + maybeUnsafe = cause; + } else { + // the unsafe instance + maybeUnsafe = unsafeField.get(null); + } + } catch (NoSuchFieldException | IllegalAccessException e) { + maybeUnsafe = e; + } catch (NoClassDefFoundError e) { + // Also catch NoClassDefFoundError in case someone uses for example OSGI and it made + // Unsafe unloadable. 
+ maybeUnsafe = e; + } + // the conditional check here can not be replaced with checking that maybeUnsafe + // is an instanceof Unsafe and reversing the if and else blocks; this is because an + // instanceof check against Unsafe will trigger a class load and we might not have + // the runtime permission accessClassInPackage.sun.misc + if (maybeUnsafe instanceof Throwable) { + unsafe = null; + unsafeUnavailabilityCause = (Throwable) maybeUnsafe; + if (logger.isTraceEnabled()) { + logger.debug("sun.misc.Unsafe.theUnsafe: unavailable", (Throwable) maybeUnsafe); + } else { + logger.debug("sun.misc.Unsafe.theUnsafe: unavailable: {}", ((Throwable) maybeUnsafe).getMessage()); + } + } else { + unsafe = (Unsafe) maybeUnsafe; + logger.debug("sun.misc.Unsafe.theUnsafe: available"); + } + + // ensure the unsafe supports all necessary methods to work around the mistake in the latest OpenJDK + // https://github.com/netty/netty/issues/1061 + // https://www.mail-archive.com/jdk6-dev@openjdk.java.net/msg00698.html + if (unsafe != null) { + final Unsafe finalUnsafe = unsafe; + Object maybeException; + try { + finalUnsafe.getClass().getDeclaredMethod( + "copyMemory", Object.class, long.class, Object.class, long.class, long.class); + maybeException = null; + } catch (NoSuchMethodException e) { + maybeException = e; + } + if (maybeException == null) { + logger.debug("sun.misc.Unsafe.copyMemory: available"); + } else { + // Unsafe.copyMemory(Object, long, Object, long, long) unavailable. 
+ unsafe = null; + unsafeUnavailabilityCause = (Throwable) maybeException; + if (logger.isTraceEnabled()) { + logger.debug("sun.misc.Unsafe.copyMemory: unavailable", (Throwable) maybeException); + } else { + logger.debug("sun.misc.Unsafe.copyMemory: unavailable: {}", + ((Throwable) maybeException).getMessage()); + } + } + } + + // ensure Unsafe::storeFence to be available: jdk < 8 shouldn't have it + if (unsafe != null) { + final Unsafe finalUnsafe = unsafe; + Object maybeException; + try { + finalUnsafe.getClass().getDeclaredMethod("storeFence"); + maybeException = null; + } catch (NoSuchMethodException e) { + maybeException = e; + } + if (maybeException == null) { + logger.debug("sun.misc.Unsafe.storeFence: available"); + storeFenceAvailable = true; + } else { + storeFenceAvailable = false; + // Unsafe.storeFence unavailable. + if (logger.isTraceEnabled()) { + logger.debug("sun.misc.Unsafe.storeFence: unavailable", (Throwable) maybeException); + } else { + logger.debug("sun.misc.Unsafe.storeFence: unavailable: {}", + ((Throwable) maybeException).getMessage()); + } + } + } + + if (unsafe != null) { + final Unsafe finalUnsafe = unsafe; + + // attempt to access field Buffer#address + Object maybeAddressField; + try { + final Field field = Buffer.class.getDeclaredField("address"); + // Use Unsafe to read value of the address field. This way it will not fail on JDK9+ which + // will forbid changing the access level via reflection. 
+ final long offset = finalUnsafe.objectFieldOffset(field); + final long address = finalUnsafe.getLong(direct, offset); + + // if direct really is a direct buffer, address will be non-zero + if (address == 0) { + maybeAddressField = null; + } else { + maybeAddressField = field; + } + } catch (NoSuchFieldException e) { + maybeAddressField = e; + } + if (maybeAddressField instanceof Field) { + addressField = (Field) maybeAddressField; + logger.debug("java.nio.Buffer.address: available"); + } else { + unsafeUnavailabilityCause = (Throwable) maybeAddressField; + if (logger.isTraceEnabled()) { + logger.debug("java.nio.Buffer.address: unavailable", (Throwable) maybeAddressField); + } else { + logger.debug("java.nio.Buffer.address: unavailable: {}", + ((Throwable) maybeAddressField).getMessage()); + } + + // If we cannot access the address of a direct buffer, there's no point of using unsafe. + // Let's just pretend unsafe is unavailable for overall simplicity. + unsafe = null; + } + } + + if (unsafe != null) { + // There are assumptions made where ever BYTE_ARRAY_BASE_OFFSET is used (equals, hashCodeAscii, and + // primitive accessors) that arrayIndexScale == 1, and results are undefined if this is not the case. + long byteArrayIndexScale = unsafe.arrayIndexScale(byte[].class); + if (byteArrayIndexScale != 1) { + logger.debug("unsafe.arrayIndexScale is {} (expected: 1). 
Not using unsafe.", byteArrayIndexScale); + unsafeUnavailabilityCause = new UnsupportedOperationException("Unexpected unsafe.arrayIndexScale"); + unsafe = null; + } + } + } + UNSAFE_UNAVAILABILITY_CAUSE = unsafeUnavailabilityCause; + UNSAFE = unsafe; + + if (unsafe == null) { + ADDRESS_FIELD_OFFSET = -1; + BYTE_ARRAY_BASE_OFFSET = -1; + LONG_ARRAY_BASE_OFFSET = -1; + LONG_ARRAY_INDEX_SCALE = -1; + INT_ARRAY_BASE_OFFSET = -1; + INT_ARRAY_INDEX_SCALE = -1; + UNALIGNED = false; + BITS_MAX_DIRECT_MEMORY = -1; + DIRECT_BUFFER_CONSTRUCTOR = null; + ALLOCATE_ARRAY_METHOD = null; + STORE_FENCE_AVAILABLE = false; + } else { + Constructor directBufferConstructor; + long address = -1; + try { + Object maybeDirectBufferConstructor; + try { + final Constructor constructor = javaVersion() >= 21 ? + direct.getClass().getDeclaredConstructor(long.class, long.class) : + direct.getClass().getDeclaredConstructor(long.class, int.class); + Throwable cause = ReflectionUtil.trySetAccessible(constructor, true); + maybeDirectBufferConstructor = Objects.requireNonNullElse(cause, constructor); + } catch (NoSuchMethodException e) { + maybeDirectBufferConstructor = e; + } + if (maybeDirectBufferConstructor instanceof Constructor) { + address = UNSAFE.allocateMemory(1); + // try to use the constructor now + try { + ((Constructor) maybeDirectBufferConstructor).newInstance(address, 1); + directBufferConstructor = (Constructor) maybeDirectBufferConstructor; + logger.debug("direct buffer constructor: available"); + } catch (InstantiationException | IllegalAccessException | InvocationTargetException e) { + directBufferConstructor = null; + } + } else { + if (logger.isTraceEnabled()) { + logger.debug("direct buffer constructor: unavailable", + (Throwable) maybeDirectBufferConstructor); + } else { + logger.debug("direct buffer constructor: unavailable: {}", + ((Throwable) maybeDirectBufferConstructor).getMessage()); + } + directBufferConstructor = null; + } + } finally { + if (address != -1) { + 
UNSAFE.freeMemory(address); + } + } + DIRECT_BUFFER_CONSTRUCTOR = directBufferConstructor; + ADDRESS_FIELD_OFFSET = objectFieldOffset(addressField); + BYTE_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(byte[].class); + INT_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(int[].class); + INT_ARRAY_INDEX_SCALE = UNSAFE.arrayIndexScale(int[].class); + LONG_ARRAY_BASE_OFFSET = UNSAFE.arrayBaseOffset(long[].class); + LONG_ARRAY_INDEX_SCALE = UNSAFE.arrayIndexScale(long[].class); + final boolean unaligned; + // using a known type to avoid loading new classes + final AtomicLong maybeMaxMemory = new AtomicLong(-1); + Object maybeUnaligned; + try { + Class bitsClass = + Class.forName("java.nio.Bits", false, getSystemClassLoader()); + int version = javaVersion(); + if (unsafeStaticFieldOffsetSupported() && version >= 9) { + // Java9/10 use all lowercase and later versions all uppercase. + String fieldName = version >= 11 ? "MAX_MEMORY" : "maxMemory"; + // On Java9 and later we try to directly access the field as we can do this without + // adjust the accessible levels. + try { + Field maxMemoryField = bitsClass.getDeclaredField(fieldName); + if (maxMemoryField.getType() == long.class) { + long offset = UNSAFE.staticFieldOffset(maxMemoryField); + Object object = UNSAFE.staticFieldBase(maxMemoryField); + maybeMaxMemory.lazySet(UNSAFE.getLong(object, offset)); + } + } catch (Throwable ignore) { + // ignore if can't access + } + fieldName = version >= 11 ? "UNALIGNED" : "unaligned"; + try { + Field unalignedField = bitsClass.getDeclaredField(fieldName); + if (unalignedField.getType() == boolean.class) { + long offset = UNSAFE.staticFieldOffset(unalignedField); + Object object = UNSAFE.staticFieldBase(unalignedField); + maybeUnaligned = UNSAFE.getBoolean(object, offset); + } + // There is something unexpected stored in the field, + // let us fall-back and try to use a reflective method call as last resort. 
+ } catch (NoSuchFieldException ignore) { + // We did not find the field we expected, move on. + } + } + Method unalignedMethod = bitsClass.getDeclaredMethod("unaligned"); + Throwable cause = ReflectionUtil.trySetAccessible(unalignedMethod, true); + if (cause != null) { + maybeUnaligned = cause; + } + maybeUnaligned = unalignedMethod.invoke(null); + } catch (NoSuchMethodException | IllegalAccessException | ClassNotFoundException | + InvocationTargetException e) { + maybeUnaligned = e; + } + if (maybeUnaligned instanceof Boolean) { + unaligned = (Boolean) maybeUnaligned; + logger.debug("java.nio.Bits.unaligned: available, {}", unaligned); + } else { + String arch = SystemPropertyUtil.get("os.arch", ""); + //noinspection DynamicRegexReplaceableByCompiledPattern + unaligned = arch.matches("^(i[3-6]86|x86(_64)?|x64|amd64)$"); + Throwable t = (Throwable) maybeUnaligned; + if (logger.isTraceEnabled()) { + logger.debug("java.nio.Bits.unaligned: unavailable, {}", unaligned, t); + } else { + logger.debug("java.nio.Bits.unaligned: unavailable, {}, {}", unaligned, t.getMessage()); + } + } + + UNALIGNED = unaligned; + BITS_MAX_DIRECT_MEMORY = maybeMaxMemory.get() >= 0 ? 
maybeMaxMemory.get() : -1; + + if (javaVersion() >= 9) { + Object maybeException; + try { + // Java9 has jdk.internal.misc.Unsafe and not all methods are propagated to + // sun.misc.Unsafe + Class internalUnsafeClass = getClassLoader(PlatformDependent0.class) + .loadClass("jdk.internal.misc.Unsafe"); + Method method = internalUnsafeClass.getDeclaredMethod("getUnsafe"); + maybeException = method.invoke(null); + } catch (Throwable e) { + maybeException = e; + } + if (!(maybeException instanceof Throwable)) { + internalUnsafe = maybeException; + final Object finalInternalUnsafe = internalUnsafe; + try { + maybeException = finalInternalUnsafe.getClass().getDeclaredMethod( + "allocateUninitializedArray", Class.class, int.class); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + if (maybeException instanceof Method) { + try { + Method m = (Method) maybeException; + byte[] bytes = (byte[]) m.invoke(finalInternalUnsafe, byte.class, 8); + assert bytes.length == 8; + allocateArrayMethod = m; + } catch (IllegalAccessException | InvocationTargetException e) { + maybeException = e; + } + } + } + + if (maybeException instanceof Throwable) { + if (logger.isTraceEnabled()) { + logger.debug("jdk.internal.misc.Unsafe.allocateUninitializedArray(int): unavailable", + (Throwable) maybeException); + } else { + logger.debug("jdk.internal.misc.Unsafe.allocateUninitializedArray(int): unavailable: {}", + ((Throwable) maybeException).getMessage()); + } + } else { + logger.debug("jdk.internal.misc.Unsafe.allocateUninitializedArray(int): available"); + } + } else { + logger.debug("jdk.internal.misc.Unsafe.allocateUninitializedArray(int): unavailable prior to Java9"); + } + ALLOCATE_ARRAY_METHOD = allocateArrayMethod; + STORE_FENCE_AVAILABLE = storeFenceAvailable; + } + + if (javaVersion() > 9) { + try { + ALIGN_SLICE = ByteBuffer.class.getDeclaredMethod("alignedSlice", int.class); + } catch (NoSuchMethodException e) { + throw new RuntimeException(e); + } + } else { + 
ALIGN_SLICE = null; + } + + INTERNAL_UNSAFE = internalUnsafe; + + logger.debug("java.nio.DirectByteBuffer.(long, {int,long}): {}", + DIRECT_BUFFER_CONSTRUCTOR != null ? "available" : "unavailable"); + } + + private static boolean unsafeStaticFieldOffsetSupported() { + return !RUNNING_IN_NATIVE_IMAGE; + } + + static boolean isExplicitNoUnsafe() { + return EXPLICIT_NO_UNSAFE_CAUSE != null; + } + + private static Throwable explicitNoUnsafeCause0() { + final boolean noUnsafe = SystemPropertyUtil.getBoolean("io.netty.noUnsafe", false); + logger.debug("-Dio.netty.noUnsafe: {}", noUnsafe); + + if (noUnsafe) { + logger.debug("sun.misc.Unsafe: unavailable (io.netty.noUnsafe)"); + return new UnsupportedOperationException("sun.misc.Unsafe: unavailable (io.netty.noUnsafe)"); + } + + // Legacy properties + String unsafePropName; + if (SystemPropertyUtil.contains("io.netty.tryUnsafe")) { + unsafePropName = "io.netty.tryUnsafe"; + } else { + unsafePropName = "org.jboss.netty.tryUnsafe"; + } + + if (!SystemPropertyUtil.getBoolean(unsafePropName, true)) { + String msg = "sun.misc.Unsafe: unavailable (" + unsafePropName + ")"; + logger.debug(msg); + return new UnsupportedOperationException(msg); + } + + return null; + } + + static boolean isUnaligned() { + return UNALIGNED; + } + + /** + * Any value >= 0 should be considered as a valid max direct memory value. + */ + static long bitsMaxDirectMemory() { + return BITS_MAX_DIRECT_MEMORY; + } + + static boolean hasUnsafe() { + return UNSAFE != null; + } + + static Throwable getUnsafeUnavailabilityCause() { + return UNSAFE_UNAVAILABILITY_CAUSE; + } + + static boolean unalignedAccess() { + return UNALIGNED; + } + + static void throwException(Throwable cause) { + // JVM has been observed to crash when passing a null argument. See https://github.com/netty/netty/issues/4131. 
+ UNSAFE.throwException(checkNotNull(cause, "cause")); + } + + static boolean hasDirectBufferNoCleanerConstructor() { + return DIRECT_BUFFER_CONSTRUCTOR != null; + } + + static ByteBuffer reallocateDirectNoCleaner(ByteBuffer buffer, int capacity) { + return newDirectBuffer(UNSAFE.reallocateMemory(directBufferAddress(buffer), capacity), capacity); + } + + static ByteBuffer allocateDirectNoCleaner(int capacity) { + // Calling malloc with capacity of 0 may return a null ptr or a memory address that can be used. + // Just use 1 to make it safe to use in all cases: + // See: https://pubs.opengroup.org/onlinepubs/009695399/functions/malloc.html + return newDirectBuffer(UNSAFE.allocateMemory(Math.max(1, capacity)), capacity); + } + + static boolean hasAlignSliceMethod() { + return ALIGN_SLICE != null; + } + + static ByteBuffer alignSlice(ByteBuffer buffer, int alignment) { + try { + return (ByteBuffer) ALIGN_SLICE.invoke(buffer, alignment); + } catch (IllegalAccessException e) { + throw new Error(e); + } catch (InvocationTargetException e) { + throw new Error(e); + } + } + + static boolean hasAllocateArrayMethod() { + return ALLOCATE_ARRAY_METHOD != null; + } + + static byte[] allocateUninitializedArray(int size) { + try { + return (byte[]) ALLOCATE_ARRAY_METHOD.invoke(INTERNAL_UNSAFE, byte.class, size); + } catch (IllegalAccessException e) { + throw new Error(e); + } catch (InvocationTargetException e) { + throw new Error(e); + } + } + + static ByteBuffer newDirectBuffer(long address, int capacity) { + ObjectUtil.checkPositiveOrZero(capacity, "capacity"); + + try { + return (ByteBuffer) DIRECT_BUFFER_CONSTRUCTOR.newInstance(address, capacity); + } catch (Throwable cause) { + // Not expected to ever throw! 
+ if (cause instanceof Error) { + throw (Error) cause; + } + throw new Error(cause); + } + } + + static long directBufferAddress(ByteBuffer buffer) { + return getLong(buffer, ADDRESS_FIELD_OFFSET); + } + + static long byteArrayBaseOffset() { + return BYTE_ARRAY_BASE_OFFSET; + } + + static Object getObject(Object object, long fieldOffset) { + return UNSAFE.getObject(object, fieldOffset); + } + + static int getInt(Object object, long fieldOffset) { + return UNSAFE.getInt(object, fieldOffset); + } + + static void safeConstructPutInt(Object object, long fieldOffset, int value) { + if (STORE_FENCE_AVAILABLE) { + UNSAFE.putInt(object, fieldOffset, value); + UNSAFE.storeFence(); + } else { + UNSAFE.putIntVolatile(object, fieldOffset, value); + } + } + + private static long getLong(Object object, long fieldOffset) { + return UNSAFE.getLong(object, fieldOffset); + } + + static long objectFieldOffset(Field field) { + return UNSAFE.objectFieldOffset(field); + } + + static byte getByte(long address) { + return UNSAFE.getByte(address); + } + + static short getShort(long address) { + return UNSAFE.getShort(address); + } + + static int getInt(long address) { + return UNSAFE.getInt(address); + } + + static long getLong(long address) { + return UNSAFE.getLong(address); + } + + static byte getByte(byte[] data, int index) { + return UNSAFE.getByte(data, BYTE_ARRAY_BASE_OFFSET + index); + } + + static byte getByte(byte[] data, long index) { + return UNSAFE.getByte(data, BYTE_ARRAY_BASE_OFFSET + index); + } + + static short getShort(byte[] data, int index) { + return UNSAFE.getShort(data, BYTE_ARRAY_BASE_OFFSET + index); + } + + static int getInt(byte[] data, int index) { + return UNSAFE.getInt(data, BYTE_ARRAY_BASE_OFFSET + index); + } + + static int getInt(int[] data, long index) { + return UNSAFE.getInt(data, INT_ARRAY_BASE_OFFSET + INT_ARRAY_INDEX_SCALE * index); + } + + static int getIntVolatile(long address) { + return UNSAFE.getIntVolatile(null, address); + } + + static long 
getLong(byte[] data, int index) { + return UNSAFE.getLong(data, BYTE_ARRAY_BASE_OFFSET + index); + } + + static long getLong(long[] data, long index) { + return UNSAFE.getLong(data, LONG_ARRAY_BASE_OFFSET + LONG_ARRAY_INDEX_SCALE * index); + } + + static void putByte(long address, byte value) { + UNSAFE.putByte(address, value); + } + + static void putShort(long address, short value) { + UNSAFE.putShort(address, value); + } + + static void putInt(long address, int value) { + UNSAFE.putInt(address, value); + } + + static void putLong(long address, long value) { + UNSAFE.putLong(address, value); + } + + static void putByte(byte[] data, int index, byte value) { + UNSAFE.putByte(data, BYTE_ARRAY_BASE_OFFSET + index, value); + } + + static void putByte(Object data, long offset, byte value) { + UNSAFE.putByte(data, offset, value); + } + + static void putShort(byte[] data, int index, short value) { + UNSAFE.putShort(data, BYTE_ARRAY_BASE_OFFSET + index, value); + } + + static void putInt(byte[] data, int index, int value) { + UNSAFE.putInt(data, BYTE_ARRAY_BASE_OFFSET + index, value); + } + + static void putLong(byte[] data, int index, long value) { + UNSAFE.putLong(data, BYTE_ARRAY_BASE_OFFSET + index, value); + } + + static void putObject(Object o, long offset, Object x) { + UNSAFE.putObject(o, offset, x); + } + + static void copyMemory(long srcAddr, long dstAddr, long length) { + // Manual safe-point polling is only needed prior Java9: + // See https://bugs.openjdk.java.net/browse/JDK-8149596 + if (javaVersion() <= 8) { + copyMemoryWithSafePointPolling(srcAddr, dstAddr, length); + } else { + UNSAFE.copyMemory(srcAddr, dstAddr, length); + } + } + + private static void copyMemoryWithSafePointPolling(long srcAddr, long dstAddr, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(srcAddr, dstAddr, size); + length -= size; + srcAddr += size; + dstAddr += size; + } + } + + static void copyMemory(Object src, long 
srcOffset, Object dst, long dstOffset, long length) { + // Manual safe-point polling is only needed prior Java9: + // See https://bugs.openjdk.java.net/browse/JDK-8149596 + if (javaVersion() <= 8) { + copyMemoryWithSafePointPolling(src, srcOffset, dst, dstOffset, length); + } else { + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, length); + } + } + + private static void copyMemoryWithSafePointPolling( + Object src, long srcOffset, Object dst, long dstOffset, long length) { + while (length > 0) { + long size = Math.min(length, UNSAFE_COPY_THRESHOLD); + UNSAFE.copyMemory(src, srcOffset, dst, dstOffset, size); + length -= size; + srcOffset += size; + dstOffset += size; + } + } + + static void setMemory(long address, long bytes, byte value) { + UNSAFE.setMemory(address, bytes, value); + } + + static void setMemory(Object o, long offset, long bytes, byte value) { + UNSAFE.setMemory(o, offset, bytes, value); + } + + static boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) { + int remainingBytes = length & 7; + final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1; + final long diff = startPos2 - startPos1; + if (length >= 8) { + final long end = baseOffset1 + remainingBytes; + for (long i = baseOffset1 - 8 + length; i >= end; i -= 8) { + if (UNSAFE.getLong(bytes1, i) != UNSAFE.getLong(bytes2, i + diff)) { + return false; + } + } + } + if (remainingBytes >= 4) { + remainingBytes -= 4; + long pos = baseOffset1 + remainingBytes; + if (UNSAFE.getInt(bytes1, pos) != UNSAFE.getInt(bytes2, pos + diff)) { + return false; + } + } + final long baseOffset2 = baseOffset1 + diff; + if (remainingBytes >= 2) { + return UNSAFE.getChar(bytes1, baseOffset1) == UNSAFE.getChar(bytes2, baseOffset2) && + (remainingBytes == 2 || + UNSAFE.getByte(bytes1, baseOffset1 + 2) == UNSAFE.getByte(bytes2, baseOffset2 + 2)); + } + return remainingBytes == 0 || + UNSAFE.getByte(bytes1, baseOffset1) == UNSAFE.getByte(bytes2, baseOffset2); + } + + static int 
equalsConstantTime(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) { + long result = 0; + long remainingBytes = length & 7; + final long baseOffset1 = BYTE_ARRAY_BASE_OFFSET + startPos1; + final long end = baseOffset1 + remainingBytes; + final long diff = startPos2 - startPos1; + for (long i = baseOffset1 - 8 + length; i >= end; i -= 8) { + result |= UNSAFE.getLong(bytes1, i) ^ UNSAFE.getLong(bytes2, i + diff); + } + if (remainingBytes >= 4) { + result |= UNSAFE.getInt(bytes1, baseOffset1) ^ UNSAFE.getInt(bytes2, baseOffset1 + diff); + remainingBytes -= 4; + } + if (remainingBytes >= 2) { + long pos = end - remainingBytes; + result |= UNSAFE.getChar(bytes1, pos) ^ UNSAFE.getChar(bytes2, pos + diff); + remainingBytes -= 2; + } + if (remainingBytes == 1) { + long pos = end - 1; + result |= UNSAFE.getByte(bytes1, pos) ^ UNSAFE.getByte(bytes2, pos + diff); + } + return ConstantTimeUtils.equalsConstantTime(result, 0); + } + + static boolean isZero(byte[] bytes, int startPos, int length) { + if (length <= 0) { + return true; + } + final long baseOffset = BYTE_ARRAY_BASE_OFFSET + startPos; + int remainingBytes = length & 7; + final long end = baseOffset + remainingBytes; + for (long i = baseOffset - 8 + length; i >= end; i -= 8) { + if (UNSAFE.getLong(bytes, i) != 0) { + return false; + } + } + + if (remainingBytes >= 4) { + remainingBytes -= 4; + if (UNSAFE.getInt(bytes, baseOffset + remainingBytes) != 0) { + return false; + } + } + if (remainingBytes >= 2) { + return UNSAFE.getChar(bytes, baseOffset) == 0 && + (remainingBytes == 2 || bytes[startPos + 2] == 0); + } + return bytes[startPos] == 0; + } + + static int hashCodeAscii(byte[] bytes, int startPos, int length) { + int hash = HASH_CODE_ASCII_SEED; + long baseOffset = BYTE_ARRAY_BASE_OFFSET + startPos; + final int remainingBytes = length & 7; + final long end = baseOffset + remainingBytes; + for (long i = baseOffset - 8 + length; i >= end; i -= 8) { + hash = 
hashCodeAsciiCompute(UNSAFE.getLong(bytes, i), hash); + } + if (remainingBytes == 0) { + return hash; + } + int hcConst = HASH_CODE_C1; + if (remainingBytes != 2 & remainingBytes != 4 & remainingBytes != 6) { // 1, 3, 5, 7 + hash = hash * HASH_CODE_C1 + hashCodeAsciiSanitize(UNSAFE.getByte(bytes, baseOffset)); + hcConst = HASH_CODE_C2; + baseOffset++; + } + if (remainingBytes != 1 & remainingBytes != 4 & remainingBytes != 5) { // 2, 3, 6, 7 + hash = hash * hcConst + hashCodeAsciiSanitize(UNSAFE.getShort(bytes, baseOffset)); + hcConst = hcConst == HASH_CODE_C1 ? HASH_CODE_C2 : HASH_CODE_C1; + baseOffset += 2; + } + if (remainingBytes >= 4) { // 4, 5, 6, 7 + return hash * hcConst + hashCodeAsciiSanitize(UNSAFE.getInt(bytes, baseOffset)); + } + return hash; + } + + static int hashCodeAsciiCompute(long value, int hash) { + // masking with 0x1f reduces the number of overall bits that impact the hash code but makes the hash + // code the same regardless of character case (upper case or lower case hash is the same). 
+ return hash * HASH_CODE_C1 + + // Low order int + hashCodeAsciiSanitize((int) value) * HASH_CODE_C2 + + // High order int + (int) ((value & 0x1f1f1f1f00000000L) >>> 32); + } + + static int hashCodeAsciiSanitize(int value) { + return value & 0x1f1f1f1f; + } + + static int hashCodeAsciiSanitize(short value) { + return value & 0x1f1f; + } + + static int hashCodeAsciiSanitize(byte value) { + return value & 0x1f; + } + + static ClassLoader getClassLoader(final Class clazz) { + return clazz.getClassLoader(); + } + + static ClassLoader getContextClassLoader() { + return Thread.currentThread().getContextClassLoader(); + } + + static ClassLoader getSystemClassLoader() { + return ClassLoader.getSystemClassLoader(); + } + + static int addressSize() { + return UNSAFE.addressSize(); + } + + static long allocateMemory(long size) { + return UNSAFE.allocateMemory(size); + } + + static void freeMemory(long address) { + UNSAFE.freeMemory(address); + } + + static long reallocateMemory(long address, long newSize) { + return UNSAFE.reallocateMemory(address, newSize); + } + + static boolean isAndroid() { + return IS_ANDROID; + } + + private static boolean isAndroid0() { + // Idea: Sometimes java binaries include Android classes on the classpath, even if it isn't actually Android. + // Rather than check if certain classes are present, just check the VM, which is tied to the JDK. + + // Optional improvement: check if `android.os.Build.VERSION` is >= 24. On later versions of Android, the + // OpenJDK is used, which means `Unsafe` will actually work as expected. + + // Android sets this property to Dalvik, regardless of whether it actually is. 
+ String vmName = SystemPropertyUtil.get("java.vm.name"); + boolean isAndroid = "Dalvik".equals(vmName); + if (isAndroid) { + logger.debug("Platform: Android"); + } + return isAndroid; + } + + private static boolean explicitTryReflectionSetAccessible0() { + // we disable reflective access + return SystemPropertyUtil.getBoolean("io.netty.tryReflectionSetAccessible", + javaVersion() < 9 || RUNNING_IN_NATIVE_IMAGE); + } + + static boolean isExplicitTryReflectionSetAccessible() { + return IS_EXPLICIT_TRY_REFLECTION_SET_ACCESSIBLE; + } + + static int javaVersion() { + return JAVA_VERSION; + } + + private static int javaVersion0() { + final int majorVersion; + + if (isAndroid0()) { + majorVersion = 6; + } else { + majorVersion = majorVersionFromJavaSpecificationVersion(); + } + + logger.debug("Java version: {}", majorVersion); + + return majorVersion; + } + + // Package-private for testing only + static int majorVersionFromJavaSpecificationVersion() { + return majorVersion(SystemPropertyUtil.get("java.specification.version", "1.6")); + } + + // Package-private for testing only + static int majorVersion(final String javaSpecVersion) { + final String[] components = javaSpecVersion.split("\\."); + final int[] version = new int[components.length]; + for (int i = 0; i < components.length; i++) { + version[i] = Integer.parseInt(components[i]); + } + + if (version[0] == 1) { + assert version[1] >= 6; + return version[1]; + } else { + return version[0]; + } + } + + private PlatformDependent0() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/PriorityQueue.java b/netty-util/src/main/java/io/netty/util/internal/PriorityQueue.java new file mode 100644 index 0000000..6535cf3 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/PriorityQueue.java @@ -0,0 +1,47 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in 
compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import java.util.Queue; + +public interface PriorityQueue extends Queue { + /** + * Same as {@link #remove(Object)} but typed using generics. + */ + boolean removeTyped(T node); + + /** + * Same as {@link #contains(Object)} but typed using generics. + */ + boolean containsTyped(T node); + + /** + * Notify the queue that the priority for {@code node} has changed. The queue will adjust to ensure the priority + * queue properties are maintained. + * + * @param node An object which is in this queue and the priority may have changed. + */ + void priorityChanged(T node); + + /** + * Removes all of the elements from this {@link PriorityQueue} without calling + * {@link PriorityQueueNode#priorityQueueIndex(DefaultPriorityQueue)} or explicitly removing references to them to + * allow them to be garbage collected. This should only be used when it is certain that the nodes will not be + * re-inserted into this or any other {@link PriorityQueue} and it is known that the {@link PriorityQueue} itself + * will be garbage collected after this call. 
+ */ + void clearIgnoringIndexes(); +} diff --git a/netty-util/src/main/java/io/netty/util/internal/PriorityQueueNode.java b/netty-util/src/main/java/io/netty/util/internal/PriorityQueueNode.java new file mode 100644 index 0000000..c1ecc5a --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/PriorityQueueNode.java @@ -0,0 +1,45 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +/** + * Provides methods for {@link DefaultPriorityQueue} to maintain internal state. These methods should generally not be + * used outside the scope of {@link DefaultPriorityQueue}. + */ +public interface PriorityQueueNode { + /** + * This should be used to initialize the storage returned by {@link #priorityQueueIndex(DefaultPriorityQueue)}. + */ + int INDEX_NOT_IN_QUEUE = -1; + + /** + * Get the last value set by {@link #priorityQueueIndex(DefaultPriorityQueue, int)} for the value corresponding to + * {@code queue}. + *

+ * Throwing exceptions from this method will result in undefined behavior. + */ + int priorityQueueIndex(DefaultPriorityQueue queue); + + /** + * Used by {@link DefaultPriorityQueue} to maintain state for an element in the queue. + *

+ * Throwing exceptions from this method will result in undefined behavior. + * + * @param queue The queue for which the index is being set. + * @param i The index as used by {@link DefaultPriorityQueue}. + */ + void priorityQueueIndex(DefaultPriorityQueue queue, int i); +} diff --git a/netty-util/src/main/java/io/netty/util/internal/PromiseNotificationUtil.java b/netty-util/src/main/java/io/netty/util/internal/PromiseNotificationUtil.java new file mode 100644 index 0000000..dfd1c18 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/PromiseNotificationUtil.java @@ -0,0 +1,77 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.concurrent.Promise; +import io.netty.util.internal.logging.InternalLogger; + +/** + * Internal utilities to notify {@link Promise}s. + */ +public final class PromiseNotificationUtil { + + private PromiseNotificationUtil() { + } + + /** + * Try to cancel the {@link Promise} and log if {@code logger} is not {@code null} in case this fails. 
+ */ + public static void tryCancel(Promise p, InternalLogger logger) { + if (!p.cancel(false) && logger != null) { + Throwable err = p.cause(); + if (err == null) { + logger.warn("Failed to cancel promise because it has succeeded already: {}", p); + } else { + logger.warn( + "Failed to cancel promise because it has failed already: {}, unnotified cause:", + p, err); + } + } + } + + /** + * Try to mark the {@link Promise} as success and log if {@code logger} is not {@code null} in case this fails. + */ + public static void trySuccess(Promise p, V result, InternalLogger logger) { + if (!p.trySuccess(result) && logger != null) { + Throwable err = p.cause(); + if (err == null) { + logger.warn("Failed to mark a promise as success because it has succeeded already: {}", p); + } else { + logger.warn( + "Failed to mark a promise as success because it has failed already: {}, unnotified cause:", + p, err); + } + } + } + + /** + * Try to mark the {@link Promise} as failure and log if {@code logger} is not {@code null} in case this fails. 
+ */ + public static void tryFailure(Promise p, Throwable cause, InternalLogger logger) { + if (!p.tryFailure(cause) && logger != null) { + Throwable err = p.cause(); + if (err == null) { + logger.warn("Failed to mark a promise as failure because it has succeeded already: {}", p, cause); + } else if (logger.isWarnEnabled()) { + logger.warn( + "Failed to mark a promise as failure because it has failed already: {}, unnotified cause: {}", + p, ThrowableUtil.stackTraceToString(err), cause); + } + } + } + +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ReadOnlyIterator.java b/netty-util/src/main/java/io/netty/util/internal/ReadOnlyIterator.java new file mode 100644 index 0000000..93d01ca --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ReadOnlyIterator.java @@ -0,0 +1,42 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.internal; + +import java.util.Iterator; + +public final class ReadOnlyIterator implements Iterator { + private final Iterator iterator; + + public ReadOnlyIterator(Iterator iterator) { + this.iterator = ObjectUtil.checkNotNull(iterator, "iterator"); + } + + @Override + public boolean hasNext() { + return iterator.hasNext(); + } + + @Override + public T next() { + return iterator.next(); + } + + @Override + public void remove() { + throw new UnsupportedOperationException("read-only"); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/RecyclableArrayList.java b/netty-util/src/main/java/io/netty/util/internal/RecyclableArrayList.java new file mode 100644 index 0000000..804fec7 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/RecyclableArrayList.java @@ -0,0 +1,148 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.internal; + +import io.netty.util.internal.ObjectPool.Handle; +import io.netty.util.internal.ObjectPool.ObjectCreator; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.RandomAccess; + +/** + * A simple list which is recyclable. This implementation does not allow {@code null} elements to be added. 
+ */ +public final class RecyclableArrayList extends ArrayList { + + private static final long serialVersionUID = -8605125654176467947L; + + private static final int DEFAULT_INITIAL_CAPACITY = 8; + + private static final ObjectPool RECYCLER = ObjectPool.newPool( + new ObjectCreator() { + @Override + public RecyclableArrayList newObject(Handle handle) { + return new RecyclableArrayList(handle); + } + }); + + private boolean insertSinceRecycled; + + /** + * Create a new empty {@link RecyclableArrayList} instance + */ + public static RecyclableArrayList newInstance() { + return newInstance(DEFAULT_INITIAL_CAPACITY); + } + + /** + * Create a new empty {@link RecyclableArrayList} instance with the given capacity. + */ + public static RecyclableArrayList newInstance(int minCapacity) { + RecyclableArrayList ret = RECYCLER.get(); + ret.ensureCapacity(minCapacity); + return ret; + } + + private final Handle handle; + + private RecyclableArrayList(Handle handle) { + this(handle, DEFAULT_INITIAL_CAPACITY); + } + + private RecyclableArrayList(Handle handle, int initialCapacity) { + super(initialCapacity); + this.handle = handle; + } + + @Override + public boolean addAll(Collection c) { + checkNullElements(c); + if (super.addAll(c)) { + insertSinceRecycled = true; + return true; + } + return false; + } + + @Override + public boolean addAll(int index, Collection c) { + checkNullElements(c); + if (super.addAll(index, c)) { + insertSinceRecycled = true; + return true; + } + return false; + } + + private static void checkNullElements(Collection c) { + if (c instanceof RandomAccess && c instanceof List list) { + // produce less garbage + int size = list.size(); + for (int i = 0; i < size; i++) { + if (list.get(i) == null) { + throw new IllegalArgumentException("c contains null values"); + } + } + } else { + for (Object element : c) { + if (element == null) { + throw new IllegalArgumentException("c contains null values"); + } + } + } + } + + @Override + public boolean add(Object 
element) { + if (super.add(ObjectUtil.checkNotNull(element, "element"))) { + insertSinceRecycled = true; + return true; + } + return false; + } + + @Override + public void add(int index, Object element) { + super.add(index, ObjectUtil.checkNotNull(element, "element")); + insertSinceRecycled = true; + } + + @Override + public Object set(int index, Object element) { + Object old = super.set(index, ObjectUtil.checkNotNull(element, "element")); + insertSinceRecycled = true; + return old; + } + + /** + * Returns {@code true} if any elements where added or set. This will be reset once {@link #recycle()} was called. + */ + public boolean insertSinceRecycled() { + return insertSinceRecycled; + } + + /** + * Clear and recycle this instance. + */ + public boolean recycle() { + clear(); + insertSinceRecycled = false; + handle.recycle(this); + return true; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ReferenceCountUpdater.java b/netty-util/src/main/java/io/netty/util/internal/ReferenceCountUpdater.java new file mode 100644 index 0000000..89dc430 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ReferenceCountUpdater.java @@ -0,0 +1,188 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +import io.netty.util.IllegalReferenceCountException; +import io.netty.util.ReferenceCounted; +import java.util.concurrent.atomic.AtomicIntegerFieldUpdater; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * Common logic for {@link ReferenceCounted} implementations + */ +public abstract class ReferenceCountUpdater { + /* + * Implementation notes: + * + * For the updated int field: + * Even => "real" refcount is (refCnt >>> 1) + * Odd => "real" refcount is 0 + * + * (x & y) appears to be surprisingly expensive relative to (x == y). Thus this class uses + * a fast-path in some places for most common low values when checking for live (even) refcounts, + * for example: if (rawCnt == 2 || rawCnt == 4 || (rawCnt & 1) == 0) { ... + */ + + protected ReferenceCountUpdater() { + } + + public static long getUnsafeOffset(Class clz, String fieldName) { + try { + if (PlatformDependent.hasUnsafe()) { + return PlatformDependent.objectFieldOffset(clz.getDeclaredField(fieldName)); + } + } catch (Throwable ignore) { + // fall-back + } + return -1; + } + + protected abstract AtomicIntegerFieldUpdater updater(); + + protected abstract long unsafeOffset(); + + public final int initialValue() { + return 2; + } + + public void setInitialValue(T instance) { + final long offset = unsafeOffset(); + if (offset == -1) { + updater().set(instance, initialValue()); + } else { + PlatformDependent.safeConstructPutInt(instance, offset, initialValue()); + } + } + + private static int realRefCnt(int rawCnt) { + return rawCnt != 2 && rawCnt != 4 && (rawCnt & 1) != 0 ? 
0 : rawCnt >>> 1; + } + + /** + * Like {@link #realRefCnt(int)} but throws if refCnt == 0 + */ + private static int toLiveRealRefCnt(int rawCnt, int decrement) { + if (rawCnt == 2 || rawCnt == 4 || (rawCnt & 1) == 0) { + return rawCnt >>> 1; + } + // odd rawCnt => already deallocated + throw new IllegalReferenceCountException(0, -decrement); + } + + private int nonVolatileRawCnt(T instance) { + // TODO: Once we compile against later versions of Java we can replace the Unsafe usage here by varhandles. + final long offset = unsafeOffset(); + return offset != -1 ? PlatformDependent.getInt(instance, offset) : updater().get(instance); + } + + public final int refCnt(T instance) { + return realRefCnt(updater().get(instance)); + } + + public final boolean isLiveNonVolatile(T instance) { + final long offset = unsafeOffset(); + final int rawCnt = offset != -1 ? PlatformDependent.getInt(instance, offset) : updater().get(instance); + + // The "real" ref count is > 0 if the rawCnt is even. + return rawCnt == 2 || rawCnt == 4 || rawCnt == 6 || rawCnt == 8 || (rawCnt & 1) == 0; + } + + /** + * An unsafe operation that sets the reference count directly + */ + public final void setRefCnt(T instance, int refCnt) { + updater().set(instance, refCnt > 0 ? 
refCnt << 1 : 1); // overflow OK here + } + + /** + * Resets the reference count to 1 + */ + public final void resetRefCnt(T instance) { + // no need of a volatile set, it should happen in a quiescent state + updater().lazySet(instance, initialValue()); + } + + public final T retain(T instance) { + return retain0(instance, 1, 2); + } + + public final T retain(T instance, int increment) { + // all changes to the raw count are 2x the "real" change - overflow is OK + int rawIncrement = checkPositive(increment, "increment") << 1; + return retain0(instance, increment, rawIncrement); + } + + // rawIncrement == increment << 1 + private T retain0(T instance, final int increment, final int rawIncrement) { + int oldRef = updater().getAndAdd(instance, rawIncrement); + if (oldRef != 2 && oldRef != 4 && (oldRef & 1) != 0) { + throw new IllegalReferenceCountException(0, increment); + } + // don't pass 0! + if ((oldRef <= 0 && oldRef + rawIncrement >= 0) + || (oldRef >= 0 && oldRef + rawIncrement < oldRef)) { + // overflow case + updater().getAndAdd(instance, -rawIncrement); + throw new IllegalReferenceCountException(realRefCnt(oldRef), increment); + } + return instance; + } + + public final boolean release(T instance) { + int rawCnt = nonVolatileRawCnt(instance); + return rawCnt == 2 ? tryFinalRelease0(instance, 2) || retryRelease0(instance, 1) + : nonFinalRelease0(instance, 1, rawCnt, toLiveRealRefCnt(rawCnt, 1)); + } + + public final boolean release(T instance, int decrement) { + int rawCnt = nonVolatileRawCnt(instance); + int realCnt = toLiveRealRefCnt(rawCnt, checkPositive(decrement, "decrement")); + return decrement == realCnt ? 
tryFinalRelease0(instance, rawCnt) || retryRelease0(instance, decrement) + : nonFinalRelease0(instance, decrement, rawCnt, realCnt); + } + + private boolean tryFinalRelease0(T instance, int expectRawCnt) { + return updater().compareAndSet(instance, expectRawCnt, 1); // any odd number will work + } + + private boolean nonFinalRelease0(T instance, int decrement, int rawCnt, int realCnt) { + if (decrement < realCnt + // all changes to the raw count are 2x the "real" change - overflow is OK + && updater().compareAndSet(instance, rawCnt, rawCnt - (decrement << 1))) { + return false; + } + return retryRelease0(instance, decrement); + } + + private boolean retryRelease0(T instance, int decrement) { + for (; ; ) { + int rawCnt = updater().get(instance), realCnt = toLiveRealRefCnt(rawCnt, decrement); + if (decrement == realCnt) { + if (tryFinalRelease0(instance, rawCnt)) { + return true; + } + } else if (decrement < realCnt) { + // all changes to the raw count are 2x the "real" change + if (updater().compareAndSet(instance, rawCnt, rawCnt - (decrement << 1))) { + return false; + } + } else { + throw new IllegalReferenceCountException(realCnt, -decrement); + } + Thread.yield(); // this benefits throughput under high contention + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ReflectionUtil.java b/netty-util/src/main/java/io/netty/util/internal/ReflectionUtil.java new file mode 100644 index 0000000..ece493b --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ReflectionUtil.java @@ -0,0 +1,53 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import java.lang.reflect.AccessibleObject; + +public final class ReflectionUtil { + + private ReflectionUtil() { + } + + /** + * Try to call {@link AccessibleObject#setAccessible(boolean)} but will catch any {@link SecurityException} and + * {@link java.lang.reflect.InaccessibleObjectException} and return it. + * The caller must check if it returns {@code null} and if not handle the returned exception. + */ + public static Throwable trySetAccessible(AccessibleObject object, boolean checkAccessible) { + if (checkAccessible && !PlatformDependent0.isExplicitTryReflectionSetAccessible()) { + return new UnsupportedOperationException("Reflective setAccessible(true) disabled"); + } + try { + object.setAccessible(true); + return null; + } catch (SecurityException e) { + return e; + } catch (RuntimeException e) { + return handleInaccessibleObjectException(e); + } + } + + private static RuntimeException handleInaccessibleObjectException(RuntimeException e) { + // JDK 9 can throw an inaccessible object exception here; since Netty compiles + // against JDK 7 and this exception was only added in JDK 9, we have to weakly + // check the type + if ("java.lang.reflect.InaccessibleObjectException".equals(e.getClass().getName())) { + return e; + } + throw e; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ResourcesUtil.java b/netty-util/src/main/java/io/netty/util/internal/ResourcesUtil.java new file mode 100644 index 0000000..cf9ac3e --- /dev/null +++ 
/**
 * A utility class that provides various common operations and constants
 * related to loading resources
 */
public final class ResourcesUtil {

    private ResourcesUtil() {
    }

    /**
     * Returns a {@link File} named {@code fileName} associated with {@link Class} {@code resourceClass} .
     *
     * @param resourceClass The associated class
     * @param fileName The file name
     * @return The file named {@code fileName} associated with {@link Class} {@code resourceClass} .
     */
    public static File getFile(Class<?> resourceClass, String fileName) {
        // Resource URLs percent-encode special characters (e.g. spaces); decode to get a usable path.
        return new File(URLDecoder.decode(resourceClass.getResource(fileName).getFile(), StandardCharsets.UTF_8));
    }
}
+ */ +public final class SocketUtils { + + private static final Enumeration EMPTY = Collections.enumeration(Collections.emptyList()); + + private SocketUtils() { + } + + @SuppressWarnings("unchecked") + private static Enumeration empty() { + return (Enumeration) EMPTY; + } + + public static void connect(final Socket socket, final SocketAddress remoteAddress, final int timeout) + throws IOException { + socket.connect(remoteAddress, timeout); + } + + public static void bind(final Socket socket, final SocketAddress bindpoint) throws IOException { + socket.bind(bindpoint); + } + + public static boolean connect(final SocketChannel socketChannel, final SocketAddress remoteAddress) + throws IOException { + return socketChannel.connect(remoteAddress); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + public static void bind(final SocketChannel socketChannel, final SocketAddress address) throws IOException { + socketChannel.bind(address); + } + + public static SocketChannel accept(final ServerSocketChannel serverSocketChannel) throws IOException { + return serverSocketChannel.accept(); + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + public static void bind(final DatagramChannel networkChannel, final SocketAddress address) throws IOException { + networkChannel.bind(address); + } + + public static SocketAddress localSocketAddress(final ServerSocket socket) { + return socket.getLocalSocketAddress(); + } + + public static InetAddress addressByName(final String hostname) throws UnknownHostException { + return InetAddress.getByName(hostname); + } + + public static InetAddress[] allAddressesByName(final String hostname) throws UnknownHostException { + return InetAddress.getAllByName(hostname); + } + + public static InetSocketAddress socketAddress(final String hostname, final int port) { + return new InetSocketAddress(hostname, port); + } + + public static Enumeration addressesFromNetworkInterface(final 
NetworkInterface intf) { + Enumeration addresses = intf.getInetAddresses(); + // Android seems to sometimes return null even if this is not a valid return value by the api docs. + // Just return an empty Enumeration in this case. + // See https://github.com/netty/netty/issues/10045 + if (addresses == null) { + return empty(); + } + return addresses; + } + + @SuppressJava6Requirement(reason = "Usage guarded by java version check") + public static InetAddress loopbackAddress() { + if (PlatformDependent.javaVersion() >= 7) { + return InetAddress.getLoopbackAddress(); + } + try { + return InetAddress.getByName(null); + } catch (UnknownHostException e) { + throw new IllegalStateException(e); + } + } + + public static byte[] hardwareAddressFromNetworkInterface(final NetworkInterface intf) throws SocketException { + return intf.getHardwareAddress(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/StringUtil.java b/netty-util/src/main/java/io/netty/util/internal/StringUtil.java new file mode 100644 index 0000000..1a2bcd0 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/StringUtil.java @@ -0,0 +1,719 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Iterator; +import java.util.List; +import static io.netty.util.internal.ObjectUtil.checkNotNull; + +/** + * String utility class. + */ +public final class StringUtil { + + public static final String EMPTY_STRING = ""; + public static final String NEWLINE = SystemPropertyUtil.get("line.separator", "\n"); + + public static final char DOUBLE_QUOTE = '\"'; + public static final char COMMA = ','; + public static final char LINE_FEED = '\n'; + public static final char CARRIAGE_RETURN = '\r'; + public static final char TAB = '\t'; + public static final char SPACE = 0x20; + + private static final String[] BYTE2HEX_PAD = new String[256]; + private static final String[] BYTE2HEX_NOPAD = new String[256]; + private static final byte[] HEX2B; + + /** + * 2 - Quote character at beginning and end. + * 5 - Extra allowance for anticipated escape characters that may be added. + */ + private static final int CSV_NUMBER_ESCAPE_CHARACTERS = 2 + 5; + private static final char PACKAGE_SEPARATOR_CHAR = '.'; + + static { + // Generate the lookup table that converts a byte into a 2-digit hexadecimal integer. + for (int i = 0; i < BYTE2HEX_PAD.length; i++) { + String str = Integer.toHexString(i); + BYTE2HEX_PAD[i] = i > 0xf ? str : ('0' + str); + BYTE2HEX_NOPAD[i] = str; + } + // Generate the lookup table that converts an hex char into its decimal value: + // the size of the table is such that the JVM is capable of save any bounds-check + // if a char type is used as an index. 
+ HEX2B = new byte[Character.MAX_VALUE + 1]; + Arrays.fill(HEX2B, (byte) -1); + HEX2B['0'] = 0; + HEX2B['1'] = 1; + HEX2B['2'] = 2; + HEX2B['3'] = 3; + HEX2B['4'] = 4; + HEX2B['5'] = 5; + HEX2B['6'] = 6; + HEX2B['7'] = 7; + HEX2B['8'] = 8; + HEX2B['9'] = 9; + HEX2B['A'] = 10; + HEX2B['B'] = 11; + HEX2B['C'] = 12; + HEX2B['D'] = 13; + HEX2B['E'] = 14; + HEX2B['F'] = 15; + HEX2B['a'] = 10; + HEX2B['b'] = 11; + HEX2B['c'] = 12; + HEX2B['d'] = 13; + HEX2B['e'] = 14; + HEX2B['f'] = 15; + } + + private StringUtil() { + // Unused. + } + + /** + * Get the item after one char delim if the delim is found (else null). + * This operation is a simplified and optimized + * version of {@link String#split(String, int)}. + */ + public static String substringAfter(String value, char delim) { + int pos = value.indexOf(delim); + if (pos >= 0) { + return value.substring(pos + 1); + } + return null; + } + + /** + * Get the item before one char delim if the delim is found (else null). + * This operation is a simplified and optimized + * version of {@link String#split(String, int)}. + */ + public static String substringBefore(String value, char delim) { + int pos = value.indexOf(delim); + if (pos >= 0) { + return value.substring(0, pos); + } + return null; + } + + /** + * Checks if two strings have the same suffix of specified length + * + * @param s string + * @param p string + * @param len length of the common suffix + * @return true if both s and p are not null and both have the same suffix. Otherwise - false + */ + public static boolean commonSuffixOfLength(String s, String p, int len) { + return s != null && p != null && len >= 0 && s.regionMatches(s.length() - len, p, p.length() - len, len); + } + + /** + * Converts the specified byte value into a 2-digit hexadecimal integer. 
+ */ + public static String byteToHexStringPadded(int value) { + return BYTE2HEX_PAD[value & 0xff]; + } + + /** + * Converts the specified byte value into a 2-digit hexadecimal integer and appends it to the specified buffer. + */ + public static T byteToHexStringPadded(T buf, int value) { + try { + buf.append(byteToHexStringPadded(value)); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + return buf; + } + + /** + * Converts the specified byte array into a hexadecimal value. + */ + public static String toHexStringPadded(byte[] src) { + return toHexStringPadded(src, 0, src.length); + } + + /** + * Converts the specified byte array into a hexadecimal value. + */ + public static String toHexStringPadded(byte[] src, int offset, int length) { + return toHexStringPadded(new StringBuilder(length << 1), src, offset, length).toString(); + } + + /** + * Converts the specified byte array into a hexadecimal value and appends it to the specified buffer. + */ + public static T toHexStringPadded(T dst, byte[] src) { + return toHexStringPadded(dst, src, 0, src.length); + } + + /** + * Converts the specified byte array into a hexadecimal value and appends it to the specified buffer. + */ + public static T toHexStringPadded(T dst, byte[] src, int offset, int length) { + final int end = offset + length; + for (int i = offset; i < end; i++) { + byteToHexStringPadded(dst, src[i]); + } + return dst; + } + + /** + * Converts the specified byte value into a hexadecimal integer. + */ + public static String byteToHexString(int value) { + return BYTE2HEX_NOPAD[value & 0xff]; + } + + /** + * Converts the specified byte value into a hexadecimal integer and appends it to the specified buffer. + */ + public static T byteToHexString(T buf, int value) { + try { + buf.append(byteToHexString(value)); + } catch (IOException e) { + PlatformDependent.throwException(e); + } + return buf; + } + + /** + * Converts the specified byte array into a hexadecimal value. 
+ */ + public static String toHexString(byte[] src) { + return toHexString(src, 0, src.length); + } + + /** + * Converts the specified byte array into a hexadecimal value. + */ + public static String toHexString(byte[] src, int offset, int length) { + return toHexString(new StringBuilder(length << 1), src, offset, length).toString(); + } + + /** + * Converts the specified byte array into a hexadecimal value and appends it to the specified buffer. + */ + public static T toHexString(T dst, byte[] src) { + return toHexString(dst, src, 0, src.length); + } + + /** + * Converts the specified byte array into a hexadecimal value and appends it to the specified buffer. + */ + public static T toHexString(T dst, byte[] src, int offset, int length) { + assert length >= 0; + if (length == 0) { + return dst; + } + + final int end = offset + length; + final int endMinusOne = end - 1; + int i; + + // Skip preceding zeroes. + for (i = offset; i < endMinusOne; i++) { + if (src[i] != 0) { + break; + } + } + + byteToHexString(dst, src[i++]); + int remaining = end - i; + toHexStringPadded(dst, src, i, remaining); + + return dst; + } + + /** + * Helper to decode half of a hexadecimal number from a string. + * + * @param c The ASCII character of the hexadecimal number to decode. + * Must be in the range {@code [0-9a-fA-F]}. + * @return The hexadecimal value represented in the ASCII character + * given, or {@code -1} if the character is invalid. + */ + public static int decodeHexNibble(final char c) { + // Character.digit() is not used here, as it addresses a larger + // set of characters (both ASCII and full-width latin letters). + return HEX2B[c]; + } + + /** + * Helper to decode half of a hexadecimal number from a string. + * + * @param b The ASCII character of the hexadecimal number to decode. + * Must be in the range {@code [0-9a-fA-F]}. + * @return The hexadecimal value represented in the ASCII character + * given, or {@code -1} if the character is invalid. 
+ */ + public static int decodeHexNibble(final byte b) { + // Character.digit() is not used here, as it addresses a larger + // set of characters (both ASCII and full-width latin letters). + return HEX2B[b]; + } + + /** + * Decode a 2-digit hex byte from within a string. + */ + public static byte decodeHexByte(CharSequence s, int pos) { + int hi = decodeHexNibble(s.charAt(pos)); + int lo = decodeHexNibble(s.charAt(pos + 1)); + if (hi == -1 || lo == -1) { + throw new IllegalArgumentException(String.format( + "invalid hex byte '%s' at index %d of '%s'", s.subSequence(pos, pos + 2), pos, s)); + } + return (byte) ((hi << 4) + lo); + } + + /** + * Decodes part of a string with hex dump + * + * @param hexDump a {@link CharSequence} which contains the hex dump + * @param fromIndex start of hex dump in {@code hexDump} + * @param length hex string length + */ + public static byte[] decodeHexDump(CharSequence hexDump, int fromIndex, int length) { + if (length < 0 || (length & 1) != 0) { + throw new IllegalArgumentException("length: " + length); + } + if (length == 0) { + return EmptyArrays.EMPTY_BYTES; + } + byte[] bytes = new byte[length >>> 1]; + for (int i = 0; i < length; i += 2) { + bytes[i >>> 1] = decodeHexByte(hexDump, fromIndex + i); + } + return bytes; + } + + /** + * Decodes a hex dump + */ + public static byte[] decodeHexDump(CharSequence hexDump) { + return decodeHexDump(hexDump, 0, hexDump.length()); + } + + /** + * The shortcut to {@link #simpleClassName(Class) simpleClassName(o.getClass())}. + */ + public static String simpleClassName(Object o) { + if (o == null) { + return "null_object"; + } else { + return simpleClassName(o.getClass()); + } + } + + /** + * Generates a simplified name from a {@link Class}. Similar to {@link Class#getSimpleName()}, but it works fine + * with anonymous classes. 
+ */ + public static String simpleClassName(Class clazz) { + String className = checkNotNull(clazz, "clazz").getName(); + final int lastDotIdx = className.lastIndexOf(PACKAGE_SEPARATOR_CHAR); + if (lastDotIdx > -1) { + return className.substring(lastDotIdx + 1); + } + return className; + } + + /** + * Escapes the specified value, if necessary according to + * RFC-4180. + * + * @param value The value which will be escaped according to + * RFC-4180 + * @return {@link CharSequence} the escaped value if necessary, or the value unchanged + */ + public static CharSequence escapeCsv(CharSequence value) { + return escapeCsv(value, false); + } + + /** + * Escapes the specified value, if necessary according to + * RFC-4180. + * + * @param value The value which will be escaped according to + * RFC-4180 + * @param trimWhiteSpace The value will first be trimmed of its optional white-space characters, + * according to RFC-7230 + * @return {@link CharSequence} the escaped value if necessary, or the value unchanged + */ + public static CharSequence escapeCsv(CharSequence value, boolean trimWhiteSpace) { + int length = checkNotNull(value, "value").length(); + int start; + int last; + if (trimWhiteSpace) { + start = indexOfFirstNonOwsChar(value, length); + last = indexOfLastNonOwsChar(value, start, length); + } else { + start = 0; + last = length - 1; + } + if (start > last) { + return EMPTY_STRING; + } + + int firstUnescapedSpecial = -1; + boolean quoted = false; + if (isDoubleQuote(value.charAt(start))) { + quoted = isDoubleQuote(value.charAt(last)) && last > start; + if (quoted) { + start++; + last--; + } else { + firstUnescapedSpecial = start; + } + } + + if (firstUnescapedSpecial < 0) { + if (quoted) { + for (int i = start; i <= last; i++) { + if (isDoubleQuote(value.charAt(i))) { + if (i == last || !isDoubleQuote(value.charAt(i + 1))) { + firstUnescapedSpecial = i; + break; + } + i++; + } + } + } else { + for (int i = start; i <= last; i++) { + char c = value.charAt(i); + if 
(c == LINE_FEED || c == CARRIAGE_RETURN || c == COMMA) { + firstUnescapedSpecial = i; + break; + } + if (isDoubleQuote(c)) { + if (i == last || !isDoubleQuote(value.charAt(i + 1))) { + firstUnescapedSpecial = i; + break; + } + i++; + } + } + } + + if (firstUnescapedSpecial < 0) { + // Special characters is not found or all of them already escaped. + // In the most cases returns a same string. New string will be instantiated (via StringBuilder) + // only if it really needed. It's important to prevent GC extra load. + return quoted ? value.subSequence(start - 1, last + 2) : value.subSequence(start, last + 1); + } + } + + StringBuilder result = new StringBuilder(last - start + 1 + CSV_NUMBER_ESCAPE_CHARACTERS); + result.append(DOUBLE_QUOTE).append(value, start, firstUnescapedSpecial); + for (int i = firstUnescapedSpecial; i <= last; i++) { + char c = value.charAt(i); + if (isDoubleQuote(c)) { + result.append(DOUBLE_QUOTE); + if (i < last && isDoubleQuote(value.charAt(i + 1))) { + i++; + } + } + result.append(c); + } + return result.append(DOUBLE_QUOTE); + } + + /** + * Unescapes the specified escaped CSV field, if necessary according to + * RFC-4180. 
+ * + * @param value The escaped CSV field which will be unescaped according to + * RFC-4180 + * @return {@link CharSequence} the unescaped value if necessary, or the value unchanged + */ + public static CharSequence unescapeCsv(CharSequence value) { + int length = checkNotNull(value, "value").length(); + if (length == 0) { + return value; + } + int last = length - 1; + boolean quoted = isDoubleQuote(value.charAt(0)) && isDoubleQuote(value.charAt(last)) && length != 1; + if (!quoted) { + validateCsvFormat(value); + return value; + } + StringBuilder unescaped = InternalThreadLocalMap.get().stringBuilder(); + for (int i = 1; i < last; i++) { + char current = value.charAt(i); + if (current == DOUBLE_QUOTE) { + if (isDoubleQuote(value.charAt(i + 1)) && (i + 1) != last) { + // Followed by a double-quote but not the last character + // Just skip the next double-quote + i++; + } else { + // Not followed by a double-quote or the following double-quote is the last character + throw newInvalidEscapedCsvFieldException(value, i); + } + } + unescaped.append(current); + } + return unescaped.toString(); + } + + /** + * Unescapes the specified escaped CSV fields according to + * RFC-4180. 
+ * + * @param value A string with multiple CSV escaped fields which will be unescaped according to + * RFC-4180 + * @return {@link List} the list of unescaped fields + */ + public static List unescapeCsvFields(CharSequence value) { + List unescaped = new ArrayList(2); + StringBuilder current = InternalThreadLocalMap.get().stringBuilder(); + boolean quoted = false; + int last = value.length() - 1; + for (int i = 0; i <= last; i++) { + char c = value.charAt(i); + if (quoted) { + switch (c) { + case DOUBLE_QUOTE: + if (i == last) { + // Add the last field and return + unescaped.add(current.toString()); + return unescaped; + } + char next = value.charAt(++i); + if (next == DOUBLE_QUOTE) { + // 2 double-quotes should be unescaped to one + current.append(DOUBLE_QUOTE); + break; + } + if (next == COMMA) { + // This is the end of a field. Let's start to parse the next field. + quoted = false; + unescaped.add(current.toString()); + current.setLength(0); + break; + } + // double-quote followed by other character is invalid + throw newInvalidEscapedCsvFieldException(value, i - 1); + default: + current.append(c); + } + } else { + switch (c) { + case COMMA: + // Start to parse the next field + unescaped.add(current.toString()); + current.setLength(0); + break; + case DOUBLE_QUOTE: + if (current.length() == 0) { + quoted = true; + break; + } + // double-quote appears without being enclosed with double-quotes + // fall through + case LINE_FEED: + // fall through + case CARRIAGE_RETURN: + // special characters appears without being enclosed with double-quotes + throw newInvalidEscapedCsvFieldException(value, i); + default: + current.append(c); + } + } + } + if (quoted) { + throw newInvalidEscapedCsvFieldException(value, last); + } + unescaped.add(current.toString()); + return unescaped; + } + + /** + * Validate if {@code value} is a valid csv field without double-quotes. + * + * @throws IllegalArgumentException if {@code value} needs to be encoded with double-quotes. 
+ */ + private static void validateCsvFormat(CharSequence value) { + int length = value.length(); + for (int i = 0; i < length; i++) { + switch (value.charAt(i)) { + case DOUBLE_QUOTE: + case LINE_FEED: + case CARRIAGE_RETURN: + case COMMA: + // If value contains any special character, it should be enclosed with double-quotes + throw newInvalidEscapedCsvFieldException(value, i); + default: + } + } + } + + private static IllegalArgumentException newInvalidEscapedCsvFieldException(CharSequence value, int index) { + return new IllegalArgumentException("invalid escaped CSV field: " + value + " index: " + index); + } + + /** + * Get the length of a string, {@code null} input is considered {@code 0} length. + */ + public static int length(String s) { + return s == null ? 0 : s.length(); + } + + /** + * Determine if a string is {@code null} or {@link String#isEmpty()} returns {@code true}. + */ + public static boolean isNullOrEmpty(String s) { + return s == null || s.isEmpty(); + } + + /** + * Find the index of the first non-white space character in {@code s} starting at {@code offset}. + * + * @param seq The string to search. + * @param offset The offset to start searching at. + * @return the index of the first non-white space character or <{@code -1} if none was found. + */ + public static int indexOfNonWhiteSpace(CharSequence seq, int offset) { + for (; offset < seq.length(); ++offset) { + if (!Character.isWhitespace(seq.charAt(offset))) { + return offset; + } + } + return -1; + } + + /** + * Find the index of the first white space character in {@code s} starting at {@code offset}. + * + * @param seq The string to search. + * @param offset The offset to start searching at. + * @return the index of the first white space character or <{@code -1} if none was found. 
+ */ + public static int indexOfWhiteSpace(CharSequence seq, int offset) { + for (; offset < seq.length(); ++offset) { + if (Character.isWhitespace(seq.charAt(offset))) { + return offset; + } + } + return -1; + } + + /** + * Determine if {@code c} lies within the range of values defined for + * Surrogate Code Point. + * + * @param c the character to check. + * @return {@code true} if {@code c} lies within the range of values defined for + * Surrogate Code Point. {@code false} otherwise. + */ + public static boolean isSurrogate(char c) { + return c >= '\uD800' && c <= '\uDFFF'; + } + + private static boolean isDoubleQuote(char c) { + return c == DOUBLE_QUOTE; + } + + /** + * Determine if the string {@code s} ends with the char {@code c}. + * + * @param s the string to test + * @param c the tested char + * @return true if {@code s} ends with the char {@code c} + */ + public static boolean endsWith(CharSequence s, char c) { + int len = s.length(); + return len > 0 && s.charAt(len - 1) == c; + } + + /** + * Trim optional white-space characters from the specified value, + * according to RFC-7230. + * + * @param value the value to trim + * @return {@link CharSequence} the trimmed value if necessary, or the value unchanged + */ + public static CharSequence trimOws(CharSequence value) { + final int length = value.length(); + if (length == 0) { + return value; + } + int start = indexOfFirstNonOwsChar(value, length); + int end = indexOfLastNonOwsChar(value, start, length); + return start == 0 && end == length - 1 ? value : value.subSequence(start, end + 1); + } + + /** + * Returns a char sequence that contains all {@code elements} joined by a given separator. + * + * @param separator for each element + * @param elements to join together + * @return a char sequence joined by a given separator. 
+ */ + public static CharSequence join(CharSequence separator, Iterable elements) { + ObjectUtil.checkNotNull(separator, "separator"); + ObjectUtil.checkNotNull(elements, "elements"); + + Iterator iterator = elements.iterator(); + if (!iterator.hasNext()) { + return EMPTY_STRING; + } + + CharSequence firstElement = iterator.next(); + if (!iterator.hasNext()) { + return firstElement; + } + + StringBuilder builder = new StringBuilder(firstElement); + do { + builder.append(separator).append(iterator.next()); + } while (iterator.hasNext()); + + return builder; + } + + /** + * @return {@code length} if no OWS is found. + */ + private static int indexOfFirstNonOwsChar(CharSequence value, int length) { + int i = 0; + while (i < length && isOws(value.charAt(i))) { + i++; + } + return i; + } + + /** + * @return {@code start} if no OWS is found. + */ + private static int indexOfLastNonOwsChar(CharSequence value, int start, int length) { + int i = length - 1; + while (i > start && isOws(value.charAt(i))) { + i--; + } + return i; + } + + private static boolean isOws(char c) { + return c == SPACE || c == TAB; + } + +} diff --git a/netty-util/src/main/java/io/netty/util/internal/SuppressJava6Requirement.java b/netty-util/src/main/java/io/netty/util/internal/SuppressJava6Requirement.java new file mode 100644 index 0000000..bdb3d11 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/SuppressJava6Requirement.java @@ -0,0 +1,32 @@ +/* + * Copyright 2018 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.internal; + +import java.lang.annotation.ElementType; +import java.lang.annotation.Retention; +import java.lang.annotation.RetentionPolicy; +import java.lang.annotation.Target; + +/** + * Annotation to suppress the Java 6 source code requirement checks for a method. + */ +@Retention(RetentionPolicy.CLASS) +@Target({ElementType.METHOD, ElementType.CONSTRUCTOR, ElementType.TYPE}) +public @interface SuppressJava6Requirement { + + String reason(); +} diff --git a/netty-util/src/main/java/io/netty/util/internal/SystemPropertyUtil.java b/netty-util/src/main/java/io/netty/util/internal/SystemPropertyUtil.java new file mode 100644 index 0000000..c363911 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/SystemPropertyUtil.java @@ -0,0 +1,164 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import static io.netty.util.internal.ObjectUtil.checkNonEmpty; + +/** + * A collection of utility methods to retrieve and parse the values of the Java system properties. 
+ */ +public final class SystemPropertyUtil { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(SystemPropertyUtil.class); + + /** + * Returns {@code true} if and only if the system property with the specified {@code key} + * exists. + */ + public static boolean contains(String key) { + return get(key) != null; + } + + /** + * Returns the value of the Java system property with the specified + * {@code key}, while falling back to {@code null} if the property access fails. + * + * @return the property value or {@code null} + */ + public static String get(String key) { + return get(key, null); + } + + /** + * Returns the value of the Java system property with the specified + * {@code key}, while falling back to the specified default value if + * the property access fails. + * + * @return the property value. + * {@code def} if there's no such property or if an access to the + * specified property is not allowed. + */ + public static String get(final String key, String def) { + checkNonEmpty(key, "key"); + String value = System.getProperty(key); + if (value == null) { + return def; + } + return value; + } + + /** + * Returns the value of the Java system property with the specified + * {@code key}, while falling back to the specified default value if + * the property access fails. + * + * @return the property value. + * {@code def} if there's no such property or if an access to the + * specified property is not allowed. 
+ */ + public static boolean getBoolean(String key, boolean def) { + String value = get(key); + if (value == null) { + return def; + } + + value = value.trim().toLowerCase(); + if (value.isEmpty()) { + return def; + } + + if ("true".equals(value) || "yes".equals(value) || "1".equals(value)) { + return true; + } + + if ("false".equals(value) || "no".equals(value) || "0".equals(value)) { + return false; + } + + logger.warn( + "Unable to parse the boolean system property '{}':{} - using the default value: {}", + key, value, def + ); + + return def; + } + + /** + * Returns the value of the Java system property with the specified + * {@code key}, while falling back to the specified default value if + * the property access fails. + * + * @return the property value. + * {@code def} if there's no such property or if an access to the + * specified property is not allowed. + */ + public static int getInt(String key, int def) { + String value = get(key); + if (value == null) { + return def; + } + + value = value.trim(); + try { + return Integer.parseInt(value); + } catch (Exception e) { + // Ignore + } + + logger.warn( + "Unable to parse the integer system property '{}':{} - using the default value: {}", + key, value, def + ); + + return def; + } + + /** + * Returns the value of the Java system property with the specified + * {@code key}, while falling back to the specified default value if + * the property access fails. + * + * @return the property value. + * {@code def} if there's no such property or if an access to the + * specified property is not allowed. 
+ */ + public static long getLong(String key, long def) { + String value = get(key); + if (value == null) { + return def; + } + + value = value.trim(); + try { + return Long.parseLong(value); + } catch (Exception e) { + // Ignore + } + + logger.warn( + "Unable to parse the long integer system property '{}':{} - using the default value: {}", + key, value, def + ); + + return def; + } + + private SystemPropertyUtil() { + // Unused + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ThreadExecutorMap.java b/netty-util/src/main/java/io/netty/util/internal/ThreadExecutorMap.java new file mode 100644 index 0000000..e6b182d --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ThreadExecutorMap.java @@ -0,0 +1,96 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.concurrent.EventExecutor; +import io.netty.util.concurrent.FastThreadLocal; +import java.util.concurrent.Executor; +import java.util.concurrent.ThreadFactory; + +/** + * Allow to retrieve the {@link EventExecutor} for the calling {@link Thread}. + */ +public final class ThreadExecutorMap { + + private static final FastThreadLocal mappings = new FastThreadLocal(); + + private ThreadExecutorMap() { + } + + /** + * Returns the current {@link EventExecutor} that uses the {@link Thread}, or {@code null} if none / unknown. 
+ */ + public static EventExecutor currentExecutor() { + return mappings.get(); + } + + /** + * Set the current {@link EventExecutor} that is used by the {@link Thread}. + */ + private static void setCurrentEventExecutor(EventExecutor executor) { + mappings.set(executor); + } + + /** + * Decorate the given {@link Executor} and ensure {@link #currentExecutor()} will return {@code eventExecutor} + * when called from within the {@link Runnable} during execution. + */ + public static Executor apply(final Executor executor, final EventExecutor eventExecutor) { + ObjectUtil.checkNotNull(executor, "executor"); + ObjectUtil.checkNotNull(eventExecutor, "eventExecutor"); + return new Executor() { + @Override + public void execute(final Runnable command) { + executor.execute(apply(command, eventExecutor)); + } + }; + } + + /** + * Decorate the given {@link Runnable} and ensure {@link #currentExecutor()} will return {@code eventExecutor} + * when called from within the {@link Runnable} during execution. + */ + public static Runnable apply(final Runnable command, final EventExecutor eventExecutor) { + ObjectUtil.checkNotNull(command, "command"); + ObjectUtil.checkNotNull(eventExecutor, "eventExecutor"); + return new Runnable() { + @Override + public void run() { + setCurrentEventExecutor(eventExecutor); + try { + command.run(); + } finally { + setCurrentEventExecutor(null); + } + } + }; + } + + /** + * Decorate the given {@link ThreadFactory} and ensure {@link #currentExecutor()} will return {@code eventExecutor} + * when called from within the {@link Runnable} during execution. 
+ */ + public static ThreadFactory apply(final ThreadFactory threadFactory, final EventExecutor eventExecutor) { + ObjectUtil.checkNotNull(threadFactory, "threadFactory"); + ObjectUtil.checkNotNull(eventExecutor, "eventExecutor"); + return new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + return threadFactory.newThread(apply(r, eventExecutor)); + } + }; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ThreadLocalRandom.java b/netty-util/src/main/java/io/netty/util/internal/ThreadLocalRandom.java new file mode 100644 index 0000000..2127669 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ThreadLocalRandom.java @@ -0,0 +1,384 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +/* + * Written by Doug Lea with assistance from members of JCP JSR-166 + * Expert Group and released to the public domain, as explained at + * https://creativecommons.org/publicdomain/zero/1.0/ + */ + +package io.netty.util.internal; + +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import java.lang.Thread.UncaughtExceptionHandler; +import java.security.SecureRandom; +import java.util.Random; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicLong; +import static io.netty.util.internal.ObjectUtil.checkPositive; + +/** + * A random number generator isolated to the current thread. Like the + * global {@link java.util.Random} generator used by the {@link + * java.lang.Math} class, a {@code ThreadLocalRandom} is initialized + * with an internally generated seed that may not otherwise be + * modified. When applicable, use of {@code ThreadLocalRandom} rather + * than shared {@code Random} objects in concurrent programs will + * typically encounter much less overhead and contention. Use of + * {@code ThreadLocalRandom} is particularly appropriate when multiple + * tasks (for example, each ForkJoinTask use random numbers + * in parallel in thread pools. + * + *

<p>Usages of this class should typically be of the form: + * {@code ThreadLocalRandom.current().nextX(...)} (where + * {@code X} is {@code Int}, {@code Long}, etc). + * When all usages are of this form, it is never possible to + * accidentally share a {@code ThreadLocalRandom} across multiple threads. + * + *

<p>This class also provides additional commonly used bounded random + * generation methods. + *

+ * //since 1.7 + * //author Doug Lea + */ +@SuppressWarnings("all") +public final class ThreadLocalRandom extends Random { + + private static final InternalLogger logger = InternalLoggerFactory.getInstance(ThreadLocalRandom.class); + + private static final AtomicLong seedUniquifier = new AtomicLong(); + + private static volatile long initialSeedUniquifier; + + private static final Thread seedGeneratorThread; + private static final BlockingQueue seedQueue; + private static final long seedGeneratorStartTime; + private static volatile long seedGeneratorEndTime; + + static { + initialSeedUniquifier = SystemPropertyUtil.getLong("io.netty.initialSeedUniquifier", 0); + if (initialSeedUniquifier == 0) { + boolean secureRandom = SystemPropertyUtil.getBoolean("java.util.secureRandomSeed", false); + if (secureRandom) { + seedQueue = new LinkedBlockingQueue(); + seedGeneratorStartTime = System.nanoTime(); + + // Try to generate a real random number from /dev/random. + // Get from a different thread to avoid blocking indefinitely on a machine without much entropy. 
+ seedGeneratorThread = new Thread("initialSeedUniquifierGenerator") { + @Override + public void run() { + final SecureRandom random = new SecureRandom(); // Get the real random seed from /dev/random + final byte[] seed = random.generateSeed(8); + seedGeneratorEndTime = System.nanoTime(); + long s = ((long) seed[0] & 0xff) << 56 | + ((long) seed[1] & 0xff) << 48 | + ((long) seed[2] & 0xff) << 40 | + ((long) seed[3] & 0xff) << 32 | + ((long) seed[4] & 0xff) << 24 | + ((long) seed[5] & 0xff) << 16 | + ((long) seed[6] & 0xff) << 8 | + (long) seed[7] & 0xff; + seedQueue.add(s); + } + }; + seedGeneratorThread.setDaemon(true); + seedGeneratorThread.setUncaughtExceptionHandler(new UncaughtExceptionHandler() { + @Override + public void uncaughtException(Thread t, Throwable e) { + logger.debug("An exception has been raised by {}", t.getName(), e); + } + }); + seedGeneratorThread.start(); + } else { + initialSeedUniquifier = mix64(System.currentTimeMillis()) ^ mix64(System.nanoTime()); + seedGeneratorThread = null; + seedQueue = null; + seedGeneratorStartTime = 0L; + } + } else { + seedGeneratorThread = null; + seedQueue = null; + seedGeneratorStartTime = 0L; + } + } + + public static void setInitialSeedUniquifier(long initialSeedUniquifier) { + ThreadLocalRandom.initialSeedUniquifier = initialSeedUniquifier; + } + + public static long getInitialSeedUniquifier() { + // Use the value set via the setter. + long initialSeedUniquifier = ThreadLocalRandom.initialSeedUniquifier; + if (initialSeedUniquifier != 0) { + return initialSeedUniquifier; + } + + synchronized (ThreadLocalRandom.class) { + initialSeedUniquifier = ThreadLocalRandom.initialSeedUniquifier; + if (initialSeedUniquifier != 0) { + return initialSeedUniquifier; + } + + // Get the random seed from the generator thread with timeout. 
+ final long timeoutSeconds = 3; + final long deadLine = seedGeneratorStartTime + TimeUnit.SECONDS.toNanos(timeoutSeconds); + boolean interrupted = false; + for (; ; ) { + final long waitTime = deadLine - System.nanoTime(); + try { + final Long seed; + if (waitTime <= 0) { + seed = seedQueue.poll(); + } else { + seed = seedQueue.poll(waitTime, TimeUnit.NANOSECONDS); + } + + if (seed != null) { + initialSeedUniquifier = seed; + break; + } + } catch (InterruptedException e) { + interrupted = true; + logger.warn("Failed to generate a seed from SecureRandom due to an InterruptedException."); + break; + } + + if (waitTime <= 0) { + seedGeneratorThread.interrupt(); + logger.warn( + "Failed to generate a seed from SecureRandom within {} seconds. " + + "Not enough entropy?", timeoutSeconds + ); + break; + } + } + + // Just in case the initialSeedUniquifier is zero or some other constant + initialSeedUniquifier ^= 0x3255ecdc33bae119L; // just a meaningless random number + initialSeedUniquifier ^= Long.reverse(System.nanoTime()); + + ThreadLocalRandom.initialSeedUniquifier = initialSeedUniquifier; + + if (interrupted) { + // Restore the interrupt status because we don't know how to/don't need to handle it here. + Thread.currentThread().interrupt(); + + // Interrupt the generator thread if it's still running, + // in the hope that the SecureRandom provider raises an exception on interruption. + seedGeneratorThread.interrupt(); + } + + if (seedGeneratorEndTime == 0) { + seedGeneratorEndTime = System.nanoTime(); + } + + return initialSeedUniquifier; + } + } + + private static long newSeed() { + for (; ; ) { + final long current = seedUniquifier.get(); + final long actualCurrent = current != 0 ? 
current : getInitialSeedUniquifier(); + + // L'Ecuyer, "Tables of Linear Congruential Generators of Different Sizes and Good Lattice Structure", 1999 + final long next = actualCurrent * 181783497276652981L; + + if (seedUniquifier.compareAndSet(current, next)) { + if (current == 0 && logger.isDebugEnabled()) { + if (seedGeneratorEndTime != 0) { + logger.debug(String.format( + "-Dio.netty.initialSeedUniquifier: 0x%016x (took %d ms)", + actualCurrent, + TimeUnit.NANOSECONDS.toMillis(seedGeneratorEndTime - seedGeneratorStartTime))); + } else { + logger.debug(String.format("-Dio.netty.initialSeedUniquifier: 0x%016x", actualCurrent)); + } + } + return next ^ System.nanoTime(); + } + } + } + + // Borrowed from + // http://gee.cs.oswego.edu/cgi-bin/viewcvs.cgi/jsr166/src/main/java/util/concurrent/ThreadLocalRandom.java + private static long mix64(long z) { + z = (z ^ (z >>> 33)) * 0xff51afd7ed558ccdL; + z = (z ^ (z >>> 33)) * 0xc4ceb9fe1a85ec53L; + return z ^ (z >>> 33); + } + + // same constants as Random, but must be redeclared because private + private static final long multiplier = 0x5DEECE66DL; + private static final long addend = 0xBL; + private static final long mask = (1L << 48) - 1; + + /** + * The random seed. We can't use super.seed. + */ + private long rnd; + + /** + * Initialization flag to permit calls to setSeed to succeed only + * while executing the Random constructor. We can't allow others + * since it would cause setting seed in one part of a program to + * unintentionally impact other usages by the thread. + */ + boolean initialized; + + // Padding to help avoid memory contention among seed updates in + // different TLRs in the common case that they are located near + // each other. + private long pad0, pad1, pad2, pad3, pad4, pad5, pad6, pad7; + + /** + * Constructor called only by localRandom.initialValue. + */ + ThreadLocalRandom() { + super(newSeed()); + initialized = true; + } + + /** + * Returns the current thread's {@code ThreadLocalRandom}. 
+ * + * @return the current thread's {@code ThreadLocalRandom} + */ + public static ThreadLocalRandom current() { + return InternalThreadLocalMap.get().random(); + } + + /** + * Throws {@code UnsupportedOperationException}. Setting seeds in + * this generator is not supported. + * + * @throws UnsupportedOperationException always + */ + @Override + public void setSeed(long seed) { + if (initialized) { + throw new UnsupportedOperationException(); + } + rnd = (seed ^ multiplier) & mask; + } + + @Override + protected int next(int bits) { + rnd = (rnd * multiplier + addend) & mask; + return (int) (rnd >>> (48 - bits)); + } + + /** + * Returns a pseudorandom, uniformly distributed value between the + * given least value (inclusive) and bound (exclusive). + * + * @param least the least value returned + * @param bound the upper bound (exclusive) + * @return the next value + * @throws IllegalArgumentException if least greater than or equal + * to bound + */ + public int nextInt(int least, int bound) { + if (least >= bound) { + throw new IllegalArgumentException(); + } + return nextInt(bound - least) + least; + } + + /** + * Returns a pseudorandom, uniformly distributed value + * between 0 (inclusive) and the specified value (exclusive). + * + * @param n the bound on the random number to be returned. Must be + * positive. + * @return the next value + * @throws IllegalArgumentException if n is not positive + */ + public long nextLong(long n) { + checkPositive(n, "n"); + + // Divide n by two until small enough for nextInt. On each + // iteration (at most 31 of them but usually much less), + // randomly choose both whether to include high bit in result + // (offset) and whether to continue with the lower vs upper + // half (which makes a difference only if odd). + long offset = 0; + while (n >= Integer.MAX_VALUE) { + int bits = next(2); + long half = n >>> 1; + long nextn = ((bits & 2) == 0) ? 
half : n - half; + if ((bits & 1) == 0) { + offset += n - nextn; + } + n = nextn; + } + return offset + nextInt((int) n); + } + + /** + * Returns a pseudorandom, uniformly distributed value between the + * given least value (inclusive) and bound (exclusive). + * + * @param least the least value returned + * @param bound the upper bound (exclusive) + * @return the next value + * @throws IllegalArgumentException if least greater than or equal + * to bound + */ + public long nextLong(long least, long bound) { + if (least >= bound) { + throw new IllegalArgumentException(); + } + return nextLong(bound - least) + least; + } + + /** + * Returns a pseudorandom, uniformly distributed {@code double} value + * between 0 (inclusive) and the specified value (exclusive). + * + * @param n the bound on the random number to be returned. Must be + * positive. + * @return the next value + * @throws IllegalArgumentException if n is not positive + */ + public double nextDouble(double n) { + checkPositive(n, "n"); + return nextDouble() * n; + } + + /** + * Returns a pseudorandom, uniformly distributed value between the + * given least value (inclusive) and bound (exclusive). 
+ * + * @param least the least value returned + * @param bound the upper bound (exclusive) + * @return the next value + * @throws IllegalArgumentException if least greater than or equal + * to bound + */ + public double nextDouble(double least, double bound) { + if (least >= bound) { + throw new IllegalArgumentException(); + } + return nextDouble() * (bound - least) + least; + } + + private static final long serialVersionUID = -5851777807851030925L; +} diff --git a/netty-util/src/main/java/io/netty/util/internal/ThrowableUtil.java b/netty-util/src/main/java/io/netty/util/internal/ThrowableUtil.java new file mode 100644 index 0000000..00ce1ab --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/ThrowableUtil.java @@ -0,0 +1,88 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import java.io.ByteArrayOutputStream; +import java.io.IOException; +import java.io.PrintStream; +import java.util.List; + +public final class ThrowableUtil { + + private ThrowableUtil() { + } + + /** + * Set the {@link StackTraceElement} for the given {@link Throwable}, using the {@link Class} and method name. 
+ */ + public static T unknownStackTrace(T cause, Class clazz, String method) { + cause.setStackTrace(new StackTraceElement[]{new StackTraceElement(clazz.getName(), method, null, -1)}); + return cause; + } + + /** + * Gets the stack trace from a Throwable as a String. + * + * @param cause the {@link Throwable} to be examined + * @return the stack trace as generated by {@link Throwable#printStackTrace(java.io.PrintWriter)} method. + */ + public static String stackTraceToString(Throwable cause) { + ByteArrayOutputStream out = new ByteArrayOutputStream(); + PrintStream pout = new PrintStream(out); + cause.printStackTrace(pout); + pout.flush(); + try { + return out.toString(); + } finally { + try { + out.close(); + } catch (IOException ignore) { + // ignore as should never happen + } + } + } + + public static boolean haveSuppressed() { + return PlatformDependent.javaVersion() >= 7; + } + + @SuppressJava6Requirement(reason = "Throwable addSuppressed is only available for >= 7. Has check for < 7.") + public static void addSuppressed(Throwable target, Throwable suppressed) { + if (!haveSuppressed()) { + return; + } + target.addSuppressed(suppressed); + } + + public static void addSuppressedAndClear(Throwable target, List suppressed) { + addSuppressed(target, suppressed); + suppressed.clear(); + } + + public static void addSuppressed(Throwable target, List suppressed) { + for (Throwable t : suppressed) { + addSuppressed(target, t); + } + } + + @SuppressJava6Requirement(reason = "Throwable getSuppressed is only available for >= 7. 
Has check for < 7.") + public static Throwable[] getSuppressed(Throwable source) { + if (!haveSuppressed()) { + return EmptyArrays.EMPTY_THROWABLES; + } + return source.getSuppressed(); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/TypeParameterMatcher.java b/netty-util/src/main/java/io/netty/util/internal/TypeParameterMatcher.java new file mode 100644 index 0000000..d7bfd21 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/TypeParameterMatcher.java @@ -0,0 +1,165 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.internal; + +import java.lang.reflect.Array; +import java.lang.reflect.GenericArrayType; +import java.lang.reflect.ParameterizedType; +import java.lang.reflect.Type; +import java.lang.reflect.TypeVariable; +import java.util.HashMap; +import java.util.Map; + +public abstract class TypeParameterMatcher { + + private static final TypeParameterMatcher NOOP = new TypeParameterMatcher() { + @Override + public boolean match(Object msg) { + return true; + } + }; + + public static TypeParameterMatcher get(final Class parameterType) { + final Map, TypeParameterMatcher> getCache = + InternalThreadLocalMap.get().typeParameterMatcherGetCache(); + + TypeParameterMatcher matcher = getCache.get(parameterType); + if (matcher == null) { + if (parameterType == Object.class) { + matcher = NOOP; + } else { + matcher = new ReflectiveMatcher(parameterType); + } + getCache.put(parameterType, matcher); + } + + return matcher; + } + + public static TypeParameterMatcher find( + final Object object, final Class parametrizedSuperclass, final String typeParamName) { + + final Map, Map> findCache = + InternalThreadLocalMap.get().typeParameterMatcherFindCache(); + final Class thisClass = object.getClass(); + + Map map = findCache.get(thisClass); + if (map == null) { + map = new HashMap(); + findCache.put(thisClass, map); + } + + TypeParameterMatcher matcher = map.get(typeParamName); + if (matcher == null) { + matcher = get(find0(object, parametrizedSuperclass, typeParamName)); + map.put(typeParamName, matcher); + } + + return matcher; + } + + private static Class find0( + final Object object, Class parametrizedSuperclass, String typeParamName) { + + final Class thisClass = object.getClass(); + Class currentClass = thisClass; + for (; ; ) { + if (currentClass.getSuperclass() == parametrizedSuperclass) { + int typeParamIndex = -1; + TypeVariable[] typeParams = currentClass.getSuperclass().getTypeParameters(); + for (int i = 0; i < typeParams.length; i++) { + if 
(typeParamName.equals(typeParams[i].getName())) { + typeParamIndex = i; + break; + } + } + + if (typeParamIndex < 0) { + throw new IllegalStateException( + "unknown type parameter '" + typeParamName + "': " + parametrizedSuperclass); + } + + Type genericSuperType = currentClass.getGenericSuperclass(); + if (!(genericSuperType instanceof ParameterizedType)) { + return Object.class; + } + + Type[] actualTypeParams = ((ParameterizedType) genericSuperType).getActualTypeArguments(); + + Type actualTypeParam = actualTypeParams[typeParamIndex]; + if (actualTypeParam instanceof ParameterizedType) { + actualTypeParam = ((ParameterizedType) actualTypeParam).getRawType(); + } + if (actualTypeParam instanceof Class) { + return (Class) actualTypeParam; + } + if (actualTypeParam instanceof GenericArrayType) { + Type componentType = ((GenericArrayType) actualTypeParam).getGenericComponentType(); + if (componentType instanceof ParameterizedType) { + componentType = ((ParameterizedType) componentType).getRawType(); + } + if (componentType instanceof Class) { + return Array.newInstance((Class) componentType, 0).getClass(); + } + } + if (actualTypeParam instanceof TypeVariable v) { + // Resolved type parameter points to another type parameter. 
+ if (!(v.getGenericDeclaration() instanceof Class)) { + return Object.class; + } + + currentClass = thisClass; + parametrizedSuperclass = (Class) v.getGenericDeclaration(); + typeParamName = v.getName(); + if (parametrizedSuperclass.isAssignableFrom(thisClass)) { + continue; + } + return Object.class; + } + + return fail(thisClass, typeParamName); + } + currentClass = currentClass.getSuperclass(); + if (currentClass == null) { + return fail(thisClass, typeParamName); + } + } + } + + private static Class fail(Class type, String typeParamName) { + throw new IllegalStateException( + "cannot determine the type of the type parameter '" + typeParamName + "': " + type); + } + + public abstract boolean match(Object msg); + + private static final class ReflectiveMatcher extends TypeParameterMatcher { + private final Class type; + + ReflectiveMatcher(Class type) { + this.type = type; + } + + @Override + public boolean match(Object msg) { + return type.isInstance(msg); + } + } + + TypeParameterMatcher() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/UnpaddedInternalThreadLocalMap.java b/netty-util/src/main/java/io/netty/util/internal/UnpaddedInternalThreadLocalMap.java new file mode 100644 index 0000000..bfa8df3 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/UnpaddedInternalThreadLocalMap.java @@ -0,0 +1,24 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
/**
 * @deprecated This class will be removed in the future.
 */
@Deprecated
class UnpaddedInternalThreadLocalMap {
    // We cannot remove this in 4.1 because it could break compatibility.
}

/**
 * Indicates a public API that can change at any time (even in minor/bugfix releases).
 * <p>
 * Usage guidelines:
 * <ol>
 * <li>Is not needed for things located in *.internal.* packages</li>
 * <li>Only public accessible classes/interfaces must be annotated</li>
 * <li>If this annotation is not present the API is considered stable and so no backward compatibility can be
 * broken in a non-major release!</li>
 * </ol>
 */
@Retention(RetentionPolicy.SOURCE) // TODO Retention policy needs to be CLASS in Netty 5.
@Target({
        ElementType.ANNOTATION_TYPE,
        ElementType.CONSTRUCTOR,
        ElementType.FIELD,
        ElementType.METHOD,
        ElementType.PACKAGE,
        ElementType.TYPE
})
@Documented
public @interface UnstableApi {
}
+ */ +public abstract class AbstractInternalLogger implements InternalLogger, Serializable { + + private static final long serialVersionUID = -6382972526573193470L; + + static final String EXCEPTION_MESSAGE = "Unexpected exception:"; + + private final String name; + + /** + * Creates a new instance. + */ + protected AbstractInternalLogger(String name) { + this.name = ObjectUtil.checkNotNull(name, "name"); + } + + @Override + public String name() { + return name; + } + + @Override + public boolean isEnabled(InternalLogLevel level) { + switch (level) { + case TRACE: + return isTraceEnabled(); + case DEBUG: + return isDebugEnabled(); + case INFO: + return isInfoEnabled(); + case WARN: + return isWarnEnabled(); + case ERROR: + return isErrorEnabled(); + default: + throw new Error(); + } + } + + @Override + public void trace(Throwable t) { + trace(EXCEPTION_MESSAGE, t); + } + + @Override + public void debug(Throwable t) { + debug(EXCEPTION_MESSAGE, t); + } + + @Override + public void info(Throwable t) { + info(EXCEPTION_MESSAGE, t); + } + + @Override + public void warn(Throwable t) { + warn(EXCEPTION_MESSAGE, t); + } + + @Override + public void error(Throwable t) { + error(EXCEPTION_MESSAGE, t); + } + + @Override + public void log(InternalLogLevel level, String msg, Throwable cause) { + switch (level) { + case TRACE: + trace(msg, cause); + break; + case DEBUG: + debug(msg, cause); + break; + case INFO: + info(msg, cause); + break; + case WARN: + warn(msg, cause); + break; + case ERROR: + error(msg, cause); + break; + default: + throw new Error(); + } + } + + @Override + public void log(InternalLogLevel level, Throwable cause) { + switch (level) { + case TRACE: + trace(cause); + break; + case DEBUG: + debug(cause); + break; + case INFO: + info(cause); + break; + case WARN: + warn(cause); + break; + case ERROR: + error(cause); + break; + default: + throw new Error(); + } + } + + @Override + public void log(InternalLogLevel level, String msg) { + switch (level) { + case 
TRACE: + trace(msg); + break; + case DEBUG: + debug(msg); + break; + case INFO: + info(msg); + break; + case WARN: + warn(msg); + break; + case ERROR: + error(msg); + break; + default: + throw new Error(); + } + } + + @Override + public void log(InternalLogLevel level, String format, Object arg) { + switch (level) { + case TRACE: + trace(format, arg); + break; + case DEBUG: + debug(format, arg); + break; + case INFO: + info(format, arg); + break; + case WARN: + warn(format, arg); + break; + case ERROR: + error(format, arg); + break; + default: + throw new Error(); + } + } + + @Override + public void log(InternalLogLevel level, String format, Object argA, Object argB) { + switch (level) { + case TRACE: + trace(format, argA, argB); + break; + case DEBUG: + debug(format, argA, argB); + break; + case INFO: + info(format, argA, argB); + break; + case WARN: + warn(format, argA, argB); + break; + case ERROR: + error(format, argA, argB); + break; + default: + throw new Error(); + } + } + + @Override + public void log(InternalLogLevel level, String format, Object... 
arguments) { + switch (level) { + case TRACE: + trace(format, arguments); + break; + case DEBUG: + debug(format, arguments); + break; + case INFO: + info(format, arguments); + break; + case WARN: + warn(format, arguments); + break; + case ERROR: + error(format, arguments); + break; + default: + throw new Error(); + } + } + + protected Object readResolve() throws ObjectStreamException { + return InternalLoggerFactory.getInstance(name()); + } + + @Override + public String toString() { + return StringUtil.simpleClassName(this) + '(' + name() + ')'; + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/logging/FormattingTuple.java b/netty-util/src/main/java/io/netty/util/internal/logging/FormattingTuple.java new file mode 100644 index 0000000..efaa536 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/logging/FormattingTuple.java @@ -0,0 +1,61 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/** + * Copyright (c) 2004-2011 QOS.ch + * All rights reserved. + *
/**
 * Holds the results of formatting done by {@link MessageFormatter}: the
 * rendered message plus the throwable (if any) pulled from the argument array.
 */
public final class FormattingTuple {

    private final String msg;
    private final Throwable cause;

    public FormattingTuple(String message, Throwable throwable) {
        msg = message;
        cause = throwable;
    }

    /**
     * Returns the formatted message.
     */
    public String getMessage() {
        return msg;
    }

    /**
     * Returns the associated throwable, possibly {@code null}.
     */
    public Throwable getThrowable() {
        return cause;
    }
}
/**
 * The log level that {@link InternalLogger} can log at.
 */
public enum InternalLogLevel {
    /** 'TRACE' log level. */
    TRACE,
    /** 'DEBUG' log level. */
    DEBUG,
    /** 'INFO' log level. */
    INFO,
    /** 'WARN' log level. */
    WARN,
    /** 'ERROR' log level. */
    ERROR
}

+ * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + *

+ * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + *

+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +package io.netty.util.internal.logging; + +/** + * Internal-use-only logger used by Netty. DO NOT + * access this class outside of Netty. + */ +public interface InternalLogger { + + /** + * Return the name of this {@link InternalLogger} instance. + * + * @return name of this logger instance + */ + String name(); + + /** + * Is the logger instance enabled for the TRACE level? + * + * @return True if this Logger is enabled for the TRACE level, + * false otherwise. + */ + boolean isTraceEnabled(); + + /** + * Log a message at the TRACE level. + * + * @param msg the message string to be logged + */ + void trace(String msg); + + /** + * Log a message at the TRACE level according to the specified format + * and argument. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the TRACE level.

+ * + * @param format the format string + * @param arg the argument + */ + void trace(String format, Object arg); + + /** + * Log a message at the TRACE level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the TRACE level.

+ * + * @param format the format string + * @param argA the first argument + * @param argB the second argument + */ + void trace(String format, Object argA, Object argB); + + /** + * Log a message at the TRACE level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous string concatenation when the logger + * is disabled for the TRACE level. However, this variant incurs the hidden + * (and relatively small) cost of creating an {@code Object[]} before invoking the method, + * even if this logger is disabled for TRACE. The variants taking {@link #trace(String, Object) one} and + * {@link #trace(String, Object, Object) two} arguments exist solely in order to avoid this hidden cost.

+ * + * @param format the format string + * @param arguments a list of 3 or more arguments + */ + void trace(String format, Object... arguments); + + /** + * Log an exception (throwable) at the TRACE level with an + * accompanying message. + * + * @param msg the message accompanying the exception + * @param t the exception (throwable) to log + */ + void trace(String msg, Throwable t); + + /** + * Log an exception (throwable) at the TRACE level. + * + * @param t the exception (throwable) to log + */ + void trace(Throwable t); + + /** + * Is the logger instance enabled for the DEBUG level? + * + * @return True if this Logger is enabled for the DEBUG level, + * false otherwise. + */ + boolean isDebugEnabled(); + + /** + * Log a message at the DEBUG level. + * + * @param msg the message string to be logged + */ + void debug(String msg); + + /** + * Log a message at the DEBUG level according to the specified format + * and argument. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the DEBUG level.

+ * + * @param format the format string + * @param arg the argument + */ + void debug(String format, Object arg); + + /** + * Log a message at the DEBUG level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the DEBUG level.

+ * + * @param format the format string + * @param argA the first argument + * @param argB the second argument + */ + void debug(String format, Object argA, Object argB); + + /** + * Log a message at the DEBUG level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous string concatenation when the logger + * is disabled for the DEBUG level. However, this variant incurs the hidden + * (and relatively small) cost of creating an {@code Object[]} before invoking the method, + * even if this logger is disabled for DEBUG. The variants taking + * {@link #debug(String, Object) one} and {@link #debug(String, Object, Object) two} + * arguments exist solely in order to avoid this hidden cost.

+ * + * @param format the format string + * @param arguments a list of 3 or more arguments + */ + void debug(String format, Object... arguments); + + /** + * Log an exception (throwable) at the DEBUG level with an + * accompanying message. + * + * @param msg the message accompanying the exception + * @param t the exception (throwable) to log + */ + void debug(String msg, Throwable t); + + /** + * Log an exception (throwable) at the DEBUG level. + * + * @param t the exception (throwable) to log + */ + void debug(Throwable t); + + /** + * Is the logger instance enabled for the INFO level? + * + * @return True if this Logger is enabled for the INFO level, + * false otherwise. + */ + boolean isInfoEnabled(); + + /** + * Log a message at the INFO level. + * + * @param msg the message string to be logged + */ + void info(String msg); + + /** + * Log a message at the INFO level according to the specified format + * and argument. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the INFO level.

+ * + * @param format the format string + * @param arg the argument + */ + void info(String format, Object arg); + + /** + * Log a message at the INFO level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the INFO level.

+ * + * @param format the format string + * @param argA the first argument + * @param argB the second argument + */ + void info(String format, Object argA, Object argB); + + /** + * Log a message at the INFO level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous string concatenation when the logger + * is disabled for the INFO level. However, this variant incurs the hidden + * (and relatively small) cost of creating an {@code Object[]} before invoking the method, + * even if this logger is disabled for INFO. The variants taking + * {@link #info(String, Object) one} and {@link #info(String, Object, Object) two} + * arguments exist solely in order to avoid this hidden cost.

+ * + * @param format the format string + * @param arguments a list of 3 or more arguments + */ + void info(String format, Object... arguments); + + /** + * Log an exception (throwable) at the INFO level with an + * accompanying message. + * + * @param msg the message accompanying the exception + * @param t the exception (throwable) to log + */ + void info(String msg, Throwable t); + + /** + * Log an exception (throwable) at the INFO level. + * + * @param t the exception (throwable) to log + */ + void info(Throwable t); + + /** + * Is the logger instance enabled for the WARN level? + * + * @return True if this Logger is enabled for the WARN level, + * false otherwise. + */ + boolean isWarnEnabled(); + + /** + * Log a message at the WARN level. + * + * @param msg the message string to be logged + */ + void warn(String msg); + + /** + * Log a message at the WARN level according to the specified format + * and argument. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the WARN level.

+ * + * @param format the format string + * @param arg the argument + */ + void warn(String format, Object arg); + + /** + * Log a message at the WARN level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous string concatenation when the logger + * is disabled for the WARN level. However, this variant incurs the hidden + * (and relatively small) cost of creating an {@code Object[]} before invoking the method, + * even if this logger is disabled for WARN. The variants taking + * {@link #warn(String, Object) one} and {@link #warn(String, Object, Object) two} + * arguments exist solely in order to avoid this hidden cost.

+ * + * @param format the format string + * @param arguments a list of 3 or more arguments + */ + void warn(String format, Object... arguments); + + /** + * Log a message at the WARN level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the WARN level.

+ * + * @param format the format string + * @param argA the first argument + * @param argB the second argument + */ + void warn(String format, Object argA, Object argB); + + /** + * Log an exception (throwable) at the WARN level with an + * accompanying message. + * + * @param msg the message accompanying the exception + * @param t the exception (throwable) to log + */ + void warn(String msg, Throwable t); + + /** + * Log an exception (throwable) at the WARN level. + * + * @param t the exception (throwable) to log + */ + void warn(Throwable t); + + /** + * Is the logger instance enabled for the ERROR level? + * + * @return True if this Logger is enabled for the ERROR level, + * false otherwise. + */ + boolean isErrorEnabled(); + + /** + * Log a message at the ERROR level. + * + * @param msg the message string to be logged + */ + void error(String msg); + + /** + * Log a message at the ERROR level according to the specified format + * and argument. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the ERROR level.

+ * + * @param format the format string + * @param arg the argument + */ + void error(String format, Object arg); + + /** + * Log a message at the ERROR level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the ERROR level.

+ * + * @param format the format string + * @param argA the first argument + * @param argB the second argument + */ + void error(String format, Object argA, Object argB); + + /** + * Log a message at the ERROR level according to the specified format + * and arguments. + *

+ *

This form avoids superfluous string concatenation when the logger + * is disabled for the ERROR level. However, this variant incurs the hidden + * (and relatively small) cost of creating an {@code Object[]} before invoking the method, + * even if this logger is disabled for ERROR. The variants taking + * {@link #error(String, Object) one} and {@link #error(String, Object, Object) two} + * arguments exist solely in order to avoid this hidden cost.

+ * + * @param format the format string + * @param arguments a list of 3 or more arguments + */ + void error(String format, Object... arguments); + + /** + * Log an exception (throwable) at the ERROR level with an + * accompanying message. + * + * @param msg the message accompanying the exception + * @param t the exception (throwable) to log + */ + void error(String msg, Throwable t); + + /** + * Log an exception (throwable) at the ERROR level. + * + * @param t the exception (throwable) to log + */ + void error(Throwable t); + + /** + * Is the logger instance enabled for the specified {@code level}? + * + * @return True if this Logger is enabled for the specified {@code level}, + * false otherwise. + */ + boolean isEnabled(InternalLogLevel level); + + /** + * Log a message at the specified {@code level}. + * + * @param msg the message string to be logged + */ + void log(InternalLogLevel level, String msg); + + /** + * Log a message at the specified {@code level} according to the specified format + * and argument. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the specified {@code level}.

+ * + * @param format the format string + * @param arg the argument + */ + void log(InternalLogLevel level, String format, Object arg); + + /** + * Log a message at the specified {@code level} according to the specified format + * and arguments. + *

+ *

This form avoids superfluous object creation when the logger + * is disabled for the specified {@code level}.

+ * + * @param format the format string + * @param argA the first argument + * @param argB the second argument + */ + void log(InternalLogLevel level, String format, Object argA, Object argB); + + /** + * Log a message at the specified {@code level} according to the specified format + * and arguments. + *

+ *

This form avoids superfluous string concatenation when the logger + * is disabled for the specified {@code level}. However, this variant incurs the hidden + * (and relatively small) cost of creating an {@code Object[]} before invoking the method, + * even if this logger is disabled for the specified {@code level}. The variants taking + * {@link #log(InternalLogLevel, String, Object) one} and + * {@link #log(InternalLogLevel, String, Object, Object) two} arguments exist solely + * in order to avoid this hidden cost.

+ * + * @param format the format string + * @param arguments a list of 3 or more arguments + */ + void log(InternalLogLevel level, String format, Object... arguments); + + /** + * Log an exception (throwable) at the specified {@code level} with an + * accompanying message. + * + * @param msg the message accompanying the exception + * @param t the exception (throwable) to log + */ + void log(InternalLogLevel level, String msg, Throwable t); + + /** + * Log an exception (throwable) at the specified {@code level}. + * + * @param t the exception (throwable) to log + */ + void log(InternalLogLevel level, Throwable t); +} diff --git a/netty-util/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java b/netty-util/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java new file mode 100644 index 0000000..ffdc2f3 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/logging/InternalLoggerFactory.java @@ -0,0 +1,82 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.internal.logging; + +import io.netty.util.internal.ObjectUtil; + +/** + * Creates an {@link InternalLogger} or changes the default factory + * implementation. This factory allows you to choose what logging framework + * Netty should use. 
+ * Please note that the new default factory is effective only for the classes + * which were loaded after the default factory is changed. Therefore, + * {@link #setDefaultFactory(InternalLoggerFactory)} should be called as early + * as possible and shouldn't be called more than once. + */ +public abstract class InternalLoggerFactory { + + private static volatile InternalLoggerFactory defaultFactory; + + @SuppressWarnings("UnusedCatchParameter") + private static InternalLoggerFactory newDefaultFactory(String name) { + return useJdkLoggerFactory(name); + } + + private static InternalLoggerFactory useJdkLoggerFactory(String name) { + InternalLoggerFactory f = JdkLoggerFactory.INSTANCE; + f.newInstance(name).debug("Using java.util.logging as the default logging framework"); + return f; + } + + /** + * Returns the default factory. The initial default factory is + * {@link JdkLoggerFactory}. + */ + public static InternalLoggerFactory getDefaultFactory() { + if (defaultFactory == null) { + defaultFactory = newDefaultFactory(InternalLoggerFactory.class.getName()); + } + return defaultFactory; + } + + /** + * Changes the default factory. + */ + public static void setDefaultFactory(InternalLoggerFactory defaultFactory) { + InternalLoggerFactory.defaultFactory = ObjectUtil.checkNotNull(defaultFactory, "defaultFactory"); + } + + /** + * Creates a new logger instance with the name of the specified class. + */ + public static InternalLogger getInstance(Class clazz) { + return getInstance(clazz.getName()); + } + + /** + * Creates a new logger instance with the specified name. + */ + public static InternalLogger getInstance(String name) { + return getDefaultFactory().newInstance(name); + } + + /** + * Creates a new logger instance with the specified name. 
+ */ + protected abstract InternalLogger newInstance(String name); + +} diff --git a/netty-util/src/main/java/io/netty/util/internal/logging/JdkLogger.java b/netty-util/src/main/java/io/netty/util/internal/logging/JdkLogger.java new file mode 100644 index 0000000..bdd7ebf --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/logging/JdkLogger.java @@ -0,0 +1,646 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/** + * Copyright (c) 2004-2011 QOS.ch + * All rights reserved. + *

+ * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + *

+ * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + *

+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +package io.netty.util.internal.logging; + +import java.util.logging.Level; +import java.util.logging.LogRecord; +import java.util.logging.Logger; + +/** + * java.util.logging + * logger. + */ +class JdkLogger extends AbstractInternalLogger { + + private static final long serialVersionUID = -1767272577989225979L; + + final transient Logger logger; + + JdkLogger(Logger logger) { + super(logger.getName()); + this.logger = logger; + } + + /** + * Is this logger instance enabled for the FINEST level? + * + * @return True if this Logger is enabled for level FINEST, false otherwise. + */ + @Override + public boolean isTraceEnabled() { + return logger.isLoggable(Level.FINEST); + } + + /** + * Log a message object at level FINEST. + * + * @param msg + * - the message object to be logged + */ + @Override + public void trace(String msg) { + if (logger.isLoggable(Level.FINEST)) { + log(SELF, Level.FINEST, msg, null); + } + } + + /** + * Log a message at level FINEST according to the specified format and + * argument. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for level FINEST. + *

+ * + * @param format + * the format string + * @param arg + * the argument + */ + @Override + public void trace(String format, Object arg) { + if (logger.isLoggable(Level.FINEST)) { + FormattingTuple ft = MessageFormatter.format(format, arg); + log(SELF, Level.FINEST, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level FINEST according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the FINEST level. + *

+ * + * @param format + * the format string + * @param argA + * the first argument + * @param argB + * the second argument + */ + @Override + public void trace(String format, Object argA, Object argB) { + if (logger.isLoggable(Level.FINEST)) { + FormattingTuple ft = MessageFormatter.format(format, argA, argB); + log(SELF, Level.FINEST, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level FINEST according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the FINEST level. + *

+ * + * @param format + * the format string + * @param argArray + * an array of arguments + */ + @Override + public void trace(String format, Object... argArray) { + if (logger.isLoggable(Level.FINEST)) { + FormattingTuple ft = MessageFormatter.arrayFormat(format, argArray); + log(SELF, Level.FINEST, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log an exception (throwable) at level FINEST with an accompanying message. + * + * @param msg + * the message accompanying the exception + * @param t + * the exception (throwable) to log + */ + @Override + public void trace(String msg, Throwable t) { + if (logger.isLoggable(Level.FINEST)) { + log(SELF, Level.FINEST, msg, t); + } + } + + /** + * Is this logger instance enabled for the FINE level? + * + * @return True if this Logger is enabled for level FINE, false otherwise. + */ + @Override + public boolean isDebugEnabled() { + return logger.isLoggable(Level.FINE); + } + + /** + * Log a message object at level FINE. + * + * @param msg + * - the message object to be logged + */ + @Override + public void debug(String msg) { + if (logger.isLoggable(Level.FINE)) { + log(SELF, Level.FINE, msg, null); + } + } + + /** + * Log a message at level FINE according to the specified format and argument. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for level FINE. + *

+ * + * @param format + * the format string + * @param arg + * the argument + */ + @Override + public void debug(String format, Object arg) { + if (logger.isLoggable(Level.FINE)) { + FormattingTuple ft = MessageFormatter.format(format, arg); + log(SELF, Level.FINE, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level FINE according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the FINE level. + *

+ * + * @param format + * the format string + * @param argA + * the first argument + * @param argB + * the second argument + */ + @Override + public void debug(String format, Object argA, Object argB) { + if (logger.isLoggable(Level.FINE)) { + FormattingTuple ft = MessageFormatter.format(format, argA, argB); + log(SELF, Level.FINE, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level FINE according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the FINE level. + *

+ * + * @param format + * the format string + * @param argArray + * an array of arguments + */ + @Override + public void debug(String format, Object... argArray) { + if (logger.isLoggable(Level.FINE)) { + FormattingTuple ft = MessageFormatter.arrayFormat(format, argArray); + log(SELF, Level.FINE, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log an exception (throwable) at level FINE with an accompanying message. + * + * @param msg + * the message accompanying the exception + * @param t + * the exception (throwable) to log + */ + @Override + public void debug(String msg, Throwable t) { + if (logger.isLoggable(Level.FINE)) { + log(SELF, Level.FINE, msg, t); + } + } + + /** + * Is this logger instance enabled for the INFO level? + * + * @return True if this Logger is enabled for the INFO level, false otherwise. + */ + @Override + public boolean isInfoEnabled() { + return logger.isLoggable(Level.INFO); + } + + /** + * Log a message object at the INFO level. + * + * @param msg + * - the message object to be logged + */ + @Override + public void info(String msg) { + if (logger.isLoggable(Level.INFO)) { + log(SELF, Level.INFO, msg, null); + } + } + + /** + * Log a message at level INFO according to the specified format and argument. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the INFO level. + *

+ * + * @param format + * the format string + * @param arg + * the argument + */ + @Override + public void info(String format, Object arg) { + if (logger.isLoggable(Level.INFO)) { + FormattingTuple ft = MessageFormatter.format(format, arg); + log(SELF, Level.INFO, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at the INFO level according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the INFO level. + *

+ * + * @param format + * the format string + * @param argA + * the first argument + * @param argB + * the second argument + */ + @Override + public void info(String format, Object argA, Object argB) { + if (logger.isLoggable(Level.INFO)) { + FormattingTuple ft = MessageFormatter.format(format, argA, argB); + log(SELF, Level.INFO, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level INFO according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the INFO level. + *

+ * + * @param format + * the format string + * @param argArray + * an array of arguments + */ + @Override + public void info(String format, Object... argArray) { + if (logger.isLoggable(Level.INFO)) { + FormattingTuple ft = MessageFormatter.arrayFormat(format, argArray); + log(SELF, Level.INFO, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log an exception (throwable) at the INFO level with an accompanying + * message. + * + * @param msg + * the message accompanying the exception + * @param t + * the exception (throwable) to log + */ + @Override + public void info(String msg, Throwable t) { + if (logger.isLoggable(Level.INFO)) { + log(SELF, Level.INFO, msg, t); + } + } + + /** + * Is this logger instance enabled for the WARNING level? + * + * @return True if this Logger is enabled for the WARNING level, false + * otherwise. + */ + @Override + public boolean isWarnEnabled() { + return logger.isLoggable(Level.WARNING); + } + + /** + * Log a message object at the WARNING level. + * + * @param msg + * - the message object to be logged + */ + @Override + public void warn(String msg) { + if (logger.isLoggable(Level.WARNING)) { + log(SELF, Level.WARNING, msg, null); + } + } + + /** + * Log a message at the WARNING level according to the specified format and + * argument. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the WARNING level. + *

+ * + * @param format + * the format string + * @param arg + * the argument + */ + @Override + public void warn(String format, Object arg) { + if (logger.isLoggable(Level.WARNING)) { + FormattingTuple ft = MessageFormatter.format(format, arg); + log(SELF, Level.WARNING, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at the WARNING level according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the WARNING level. + *

+ * + * @param format + * the format string + * @param argA + * the first argument + * @param argB + * the second argument + */ + @Override + public void warn(String format, Object argA, Object argB) { + if (logger.isLoggable(Level.WARNING)) { + FormattingTuple ft = MessageFormatter.format(format, argA, argB); + log(SELF, Level.WARNING, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level WARNING according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the WARNING level. + *

+ * + * @param format + * the format string + * @param argArray + * an array of arguments + */ + @Override + public void warn(String format, Object... argArray) { + if (logger.isLoggable(Level.WARNING)) { + FormattingTuple ft = MessageFormatter.arrayFormat(format, argArray); + log(SELF, Level.WARNING, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log an exception (throwable) at the WARNING level with an accompanying + * message. + * + * @param msg + * the message accompanying the exception + * @param t + * the exception (throwable) to log + */ + @Override + public void warn(String msg, Throwable t) { + if (logger.isLoggable(Level.WARNING)) { + log(SELF, Level.WARNING, msg, t); + } + } + + /** + * Is this logger instance enabled for level SEVERE? + * + * @return True if this Logger is enabled for level SEVERE, false otherwise. + */ + @Override + public boolean isErrorEnabled() { + return logger.isLoggable(Level.SEVERE); + } + + /** + * Log a message object at the SEVERE level. + * + * @param msg + * - the message object to be logged + */ + @Override + public void error(String msg) { + if (logger.isLoggable(Level.SEVERE)) { + log(SELF, Level.SEVERE, msg, null); + } + } + + /** + * Log a message at the SEVERE level according to the specified format and + * argument. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the SEVERE level. + *

+ * + * @param format + * the format string + * @param arg + * the argument + */ + @Override + public void error(String format, Object arg) { + if (logger.isLoggable(Level.SEVERE)) { + FormattingTuple ft = MessageFormatter.format(format, arg); + log(SELF, Level.SEVERE, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at the SEVERE level according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the SEVERE level. + *

+ * + * @param format + * the format string + * @param argA + * the first argument + * @param argB + * the second argument + */ + @Override + public void error(String format, Object argA, Object argB) { + if (logger.isLoggable(Level.SEVERE)) { + FormattingTuple ft = MessageFormatter.format(format, argA, argB); + log(SELF, Level.SEVERE, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log a message at level SEVERE according to the specified format and + * arguments. + * + *

+ * This form avoids superfluous object creation when the logger is disabled + * for the SEVERE level. + *

+ * + * @param format + * the format string + * @param arguments + * an array of arguments + */ + @Override + public void error(String format, Object... arguments) { + if (logger.isLoggable(Level.SEVERE)) { + FormattingTuple ft = MessageFormatter.arrayFormat(format, arguments); + log(SELF, Level.SEVERE, ft.getMessage(), ft.getThrowable()); + } + } + + /** + * Log an exception (throwable) at the SEVERE level with an accompanying + * message. + * + * @param msg + * the message accompanying the exception + * @param t + * the exception (throwable) to log + */ + @Override + public void error(String msg, Throwable t) { + if (logger.isLoggable(Level.SEVERE)) { + log(SELF, Level.SEVERE, msg, t); + } + } + + /** + * Log the message at the specified level with the specified throwable if any. + * This method creates a LogRecord and fills in caller date before calling + * this instance's JDK14 logger. + * + * See bug report #13 for more details. + */ + private void log(String callerFQCN, Level level, String msg, Throwable t) { + // millis and thread are filled by the constructor + LogRecord record = new LogRecord(level, msg); + record.setLoggerName(name()); + record.setThrown(t); + fillCallerData(callerFQCN, record); + logger.log(record); + } + + static final String SELF = JdkLogger.class.getName(); + static final String SUPER = AbstractInternalLogger.class.getName(); + + /** + * Fill in caller data if possible. 
+ * + * @param record + * The record to update + */ + private static void fillCallerData(String callerFQCN, LogRecord record) { + StackTraceElement[] steArray = new Throwable().getStackTrace(); + + int selfIndex = -1; + for (int i = 0; i < steArray.length; i++) { + final String className = steArray[i].getClassName(); + if (className.equals(callerFQCN) || className.equals(SUPER)) { + selfIndex = i; + break; + } + } + + int found = -1; + for (int i = selfIndex + 1; i < steArray.length; i++) { + final String className = steArray[i].getClassName(); + if (!(className.equals(callerFQCN) || className.equals(SUPER))) { + found = i; + break; + } + } + + if (found != -1) { + StackTraceElement ste = steArray[found]; + // setting the class name has the side effect of setting + // the needToInferCaller variable to false. + record.setSourceClassName(ste.getClassName()); + record.setSourceMethodName(ste.getMethodName()); + } + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/logging/JdkLoggerFactory.java b/netty-util/src/main/java/io/netty/util/internal/logging/JdkLoggerFactory.java new file mode 100644 index 0000000..df2c2ed --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/logging/JdkLoggerFactory.java @@ -0,0 +1,41 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal.logging; + + +import java.util.logging.Logger; + +/** + * Logger factory which creates a + * java.util.logging + * logger. + */ +public class JdkLoggerFactory extends InternalLoggerFactory { + + public static final InternalLoggerFactory INSTANCE = new JdkLoggerFactory(); + + /** + * @deprecated Use {@link #INSTANCE} instead. + */ + @Deprecated + public JdkLoggerFactory() { + } + + @Override + public InternalLogger newInstance(String name) { + return new JdkLogger(Logger.getLogger(name)); + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/logging/MessageFormatter.java b/netty-util/src/main/java/io/netty/util/internal/logging/MessageFormatter.java new file mode 100644 index 0000000..bc22567 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/logging/MessageFormatter.java @@ -0,0 +1,396 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/** + * Copyright (c) 2004-2011 QOS.ch + * All rights reserved. + *

+ * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + *

+ * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + *

+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
+ * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
+ * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
+ * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
+ * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
+ * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
+ * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
+ */
+package io.netty.util.internal.logging;
+
+import java.text.MessageFormat;
+import java.util.HashSet;
+import java.util.Set;
+
+// contributors: lizongbo: proposed special treatment of array parameter values
+// Joern Huxhorn: pointed out double[] omission, suggested deep array copy
+
+/**
+ * Formats messages according to very simple substitution rules. Substitutions
+ * can be made with 1, 2 or more arguments.
+ *

+ *

+ * For example, + *

+ *

+ * MessageFormatter.format("Hi {}.", "there")
+ * 
+ *

+ * will return the string "Hi there.". + *

+ * The {} pair is called the formatting anchor. It serves to designate + * the location where arguments need to be substituted within the message + * pattern. + *

+ * In case your message contains the '{' or the '}' character, you do not have + * to do anything special unless the '}' character immediately follows '{'. For + * example, + *

+ *

+ * MessageFormatter.format("Set {1,2,3} is not equal to {}.", "1,2");
+ * 
+ *

+ * will return the string "Set {1,2,3} is not equal to 1,2.". + *

+ *

+ * If for whatever reason you need to place the string "{}" in the message + * without its formatting anchor meaning, then you need to escape the + * '{' character with '\', that is the backslash character. Only the '{' + * character should be escaped. There is no need to escape the '}' character. + * For example, + *

+ *

+ * MessageFormatter.format("Set \\{} is not equal to {}.", "1,2");
+ * 
+ *

+ * will return the string "Set {} is not equal to 1,2.". + *

+ *

+ * The escaping behavior just described can be overridden by escaping the escape + * character '\'. Calling + *

+ *

+ * MessageFormatter.format("File name is C:\\\\{}.", "file.zip");
+ * 
+ *

+ * will return the string "File name is C:\file.zip". + *

+ *

+ * The formatting conventions are different than those of {@link MessageFormat} + * which ships with the Java platform. This is justified by the fact that + * SLF4J's implementation is 10 times faster than that of {@link MessageFormat}. + * This local performance difference is both measurable and significant in the + * larger context of the complete logging processing chain. + *

+ *

+ * See also {@link #format(String, Object)}, + * {@link #format(String, Object, Object)} and + * {@link #arrayFormat(String, Object[])} methods for more details. + */ +public final class MessageFormatter { + private static final String DELIM_STR = "{}"; + private static final char ESCAPE_CHAR = '\\'; + + /** + * Performs single argument substitution for the 'messagePattern' passed as + * parameter. + *

+ * For example, + *

+ *

+     * MessageFormatter.format("Hi {}.", "there");
+     * 
+ *

+ * will return the string "Hi there.". + *

+ * + * @param messagePattern The message pattern which will be parsed and formatted + * @param arg The argument to be substituted in place of the formatting anchor + * @return The formatted message + */ + public static FormattingTuple format(String messagePattern, Object arg) { + return arrayFormat(messagePattern, new Object[]{arg}); + } + + /** + * Performs a two argument substitution for the 'messagePattern' passed as + * parameter. + *

+ * For example, + *

+ *

+     * MessageFormatter.format("Hi {}. My name is {}.", "Alice", "Bob");
+     * 
+ *

+ * will return the string "Hi Alice. My name is Bob.". + * + * @param messagePattern The message pattern which will be parsed and formatted + * @param argA The argument to be substituted in place of the first formatting + * anchor + * @param argB The argument to be substituted in place of the second formatting + * anchor + * @return The formatted message + */ + public static FormattingTuple format(final String messagePattern, + Object argA, Object argB) { + return arrayFormat(messagePattern, new Object[]{argA, argB}); + } + + /** + * Same principle as the {@link #format(String, Object)} and + * {@link #format(String, Object, Object)} methods except that any number of + * arguments can be passed in an array. + * + * @param messagePattern The message pattern which will be parsed and formatted + * @param argArray An array of arguments to be substituted in place of formatting + * anchors + * @return The formatted message + */ + public static FormattingTuple arrayFormat(final String messagePattern, + final Object[] argArray) { + if (argArray == null || argArray.length == 0) { + return new FormattingTuple(messagePattern, null); + } + + int lastArrIdx = argArray.length - 1; + Object lastEntry = argArray[lastArrIdx]; + Throwable throwable = lastEntry instanceof Throwable ? 
(Throwable) lastEntry : null; + + if (messagePattern == null) { + return new FormattingTuple(null, throwable); + } + + int j = messagePattern.indexOf(DELIM_STR); + if (j == -1) { + // this is a simple string + return new FormattingTuple(messagePattern, throwable); + } + + StringBuilder sbuf = new StringBuilder(messagePattern.length() + 50); + int i = 0; + int L = 0; + do { + boolean notEscaped = j == 0 || messagePattern.charAt(j - 1) != ESCAPE_CHAR; + if (notEscaped) { + // normal case + sbuf.append(messagePattern, i, j); + } else { + sbuf.append(messagePattern, i, j - 1); + // check that escape char is not is escaped: "abc x:\\{}" + notEscaped = j >= 2 && messagePattern.charAt(j - 2) == ESCAPE_CHAR; + } + + i = j + 2; + if (notEscaped) { + deeplyAppendParameter(sbuf, argArray[L], null); + L++; + if (L > lastArrIdx) { + break; + } + } else { + sbuf.append(DELIM_STR); + } + j = messagePattern.indexOf(DELIM_STR, i); + } while (j != -1); + + // append the characters following the last {} pair. + sbuf.append(messagePattern, i, messagePattern.length()); + return new FormattingTuple(sbuf.toString(), L <= lastArrIdx ? 
throwable : null); + } + + // special treatment of array values was suggested by 'lizongbo' + private static void deeplyAppendParameter(StringBuilder sbuf, Object o, + Set seenSet) { + if (o == null) { + sbuf.append("null"); + return; + } + Class objClass = o.getClass(); + if (!objClass.isArray()) { + if (Number.class.isAssignableFrom(objClass)) { + // Prevent String instantiation for some number types + if (objClass == Long.class) { + sbuf.append(((Long) o).longValue()); + } else if (objClass == Integer.class || objClass == Short.class || objClass == Byte.class) { + sbuf.append(((Number) o).intValue()); + } else if (objClass == Double.class) { + sbuf.append(((Double) o).doubleValue()); + } else if (objClass == Float.class) { + sbuf.append(((Float) o).floatValue()); + } else { + safeObjectAppend(sbuf, o); + } + } else { + safeObjectAppend(sbuf, o); + } + } else { + // check for primitive array types because they + // unfortunately cannot be cast to Object[] + sbuf.append('['); + if (objClass == boolean[].class) { + booleanArrayAppend(sbuf, (boolean[]) o); + } else if (objClass == byte[].class) { + byteArrayAppend(sbuf, (byte[]) o); + } else if (objClass == char[].class) { + charArrayAppend(sbuf, (char[]) o); + } else if (objClass == short[].class) { + shortArrayAppend(sbuf, (short[]) o); + } else if (objClass == int[].class) { + intArrayAppend(sbuf, (int[]) o); + } else if (objClass == long[].class) { + longArrayAppend(sbuf, (long[]) o); + } else if (objClass == float[].class) { + floatArrayAppend(sbuf, (float[]) o); + } else if (objClass == double[].class) { + doubleArrayAppend(sbuf, (double[]) o); + } else { + objectArrayAppend(sbuf, (Object[]) o, seenSet); + } + sbuf.append(']'); + } + } + + private static void safeObjectAppend(StringBuilder sbuf, Object o) { + try { + String oAsString = o.toString(); + sbuf.append(oAsString); + } catch (Throwable t) { + System.err + .println("SLF4J: Failed toString() invocation on an object of type [" + + o.getClass().getName() 
+ ']'); + t.printStackTrace(); + sbuf.append("[FAILED toString()]"); + } + } + + private static void objectArrayAppend(StringBuilder sbuf, Object[] a, Set seenSet) { + if (a.length == 0) { + return; + } + if (seenSet == null) { + seenSet = new HashSet(a.length); + } + if (seenSet.add(a)) { + deeplyAppendParameter(sbuf, a[0], seenSet); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + deeplyAppendParameter(sbuf, a[i], seenSet); + } + // allow repeats in siblings + seenSet.remove(a); + } else { + sbuf.append("..."); + } + } + + private static void booleanArrayAppend(StringBuilder sbuf, boolean[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void byteArrayAppend(StringBuilder sbuf, byte[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void charArrayAppend(StringBuilder sbuf, char[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void shortArrayAppend(StringBuilder sbuf, short[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void intArrayAppend(StringBuilder sbuf, int[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void longArrayAppend(StringBuilder sbuf, long[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void floatArrayAppend(StringBuilder sbuf, float[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for 
(int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private static void doubleArrayAppend(StringBuilder sbuf, double[] a) { + if (a.length == 0) { + return; + } + sbuf.append(a[0]); + for (int i = 1; i < a.length; i++) { + sbuf.append(", "); + sbuf.append(a[i]); + } + } + + private MessageFormatter() { + } +} diff --git a/netty-util/src/main/java/io/netty/util/internal/logging/package-info.java b/netty-util/src/main/java/io/netty/util/internal/logging/package-info.java new file mode 100644 index 0000000..38c5c55 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/logging/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Internal-use-only logging API which is not allowed to be used outside Netty. + */ +package io.netty.util.internal.logging; diff --git a/netty-util/src/main/java/io/netty/util/internal/package-info.java b/netty-util/src/main/java/io/netty/util/internal/package-info.java new file mode 100644 index 0000000..890a99c --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/internal/package-info.java @@ -0,0 +1,21 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Internal-use-only utilities which is not allowed to be used + * outside Netty. + */ +package io.netty.util.internal; diff --git a/netty-util/src/main/java/io/netty/util/package-info.java b/netty-util/src/main/java/io/netty/util/package-info.java new file mode 100644 index 0000000..598ea40 --- /dev/null +++ b/netty-util/src/main/java/io/netty/util/package-info.java @@ -0,0 +1,20 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +/** + * Utility classes used across multiple packages. 
+ */ +package io.netty.util; diff --git a/netty-util/src/main/java/module-info.java b/netty-util/src/main/java/module-info.java new file mode 100644 index 0000000..486d833 --- /dev/null +++ b/netty-util/src/main/java/module-info.java @@ -0,0 +1,32 @@ +module org.xbib.io.netty.util { + exports io.netty.util; + exports io.netty.util.collection; + exports io.netty.util.concurrent; + exports io.netty.util.internal to + org.xbib.io.netty.buffer, + org.xbib.io.netty.channel, + org.xbib.io.netty.channel.unix, + org.xbib.io.netty.handler, + org.xbib.io.netty.handler.codec, + org.xbib.io.netty.handler.codec.compression, + org.xbib.io.netty.handler.codec.http, + org.xbib.io.netty.handler.codec.httptwo, + org.xbib.io.netty.handler.codec.protobuf, + org.xbib.io.netty.handler.ssl, + org.xbib.io.netty.resolver; + exports io.netty.util.internal.logging to + org.xbib.io.netty.buffer, + org.xbib.io.netty.channel, + org.xbib.io.netty.channel.unix, + org.xbib.io.netty.handler, + org.xbib.io.netty.handler.codec, + org.xbib.io.netty.handler.codec.compression, + org.xbib.io.netty.handler.codec.http, + org.xbib.io.netty.handler.codec.httptwo, + org.xbib.io.netty.handler.codec.protobuf, + org.xbib.io.netty.handler.ssl, + org.xbib.io.netty.resolver; + requires org.xbib.io.netty.jctools; + requires java.logging; + requires jdk.unsupported; +} diff --git a/netty-util/src/test/java/io/netty/util/AbstractReferenceCountedTest.java b/netty-util/src/test/java/io/netty/util/AbstractReferenceCountedTest.java new file mode 100644 index 0000000..5030c3e --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/AbstractReferenceCountedTest.java @@ -0,0 +1,231 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.internal.ThreadLocalRandom; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.util.ArrayDeque; +import java.util.Queue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.Future; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class AbstractReferenceCountedTest { + + @Test + public void testRetainOverflow() { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(Integer.MAX_VALUE); + assertEquals(Integer.MAX_VALUE, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(); + } + }); + } + + @Test + public void testRetainOverflow2() { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertEquals(1, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(Integer.MAX_VALUE); + } + }); + } + + @Test + public void 
testReleaseOverflow() { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + referenceCounted.setRefCnt(0); + assertEquals(0, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.release(Integer.MAX_VALUE); + } + }); + } + + @Test + public void testReleaseErrorMessage() { + AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + try { + referenceCounted.release(1); + fail("IllegalReferenceCountException didn't occur"); + } catch (IllegalReferenceCountException e) { + assertEquals("refCnt: 0, decrement: 1", e.getMessage()); + } + } + + @Test + public void testRetainResurrect() { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(); + } + }); + } + + @Test + public void testRetainResurrect2() { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + assertTrue(referenceCounted.release()); + assertEquals(0, referenceCounted.refCnt()); + assertThrows(IllegalReferenceCountException.class, new Executable() { + @Override + public void execute() { + referenceCounted.retain(2); + } + }); + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testRetainFromMultipleThreadsThrowsReferenceCountException() throws Exception { + int threads = 4; + Queue> futures = new ArrayDeque>(threads); + ExecutorService service = Executors.newFixedThreadPool(threads); + final AtomicInteger refCountExceptions = new AtomicInteger(); + + try { + for (int i = 0; i < 10000; i++) { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + final CountDownLatch retainLatch = new CountDownLatch(1); + 
assertTrue(referenceCounted.release()); + + for (int a = 0; a < threads; a++) { + final int retainCnt = ThreadLocalRandom.current().nextInt(1, Integer.MAX_VALUE); + futures.add(service.submit(new Runnable() { + @Override + public void run() { + try { + retainLatch.await(); + try { + referenceCounted.retain(retainCnt); + } catch (IllegalReferenceCountException e) { + refCountExceptions.incrementAndGet(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + })); + } + retainLatch.countDown(); + + for (;;) { + Future f = futures.poll(); + if (f == null) { + break; + } + f.get(); + } + assertEquals(4, refCountExceptions.get()); + refCountExceptions.set(0); + } + } finally { + service.shutdown(); + } + } + + @Test + @Timeout(value = 30000, unit = TimeUnit.MILLISECONDS) + public void testReleaseFromMultipleThreadsThrowsReferenceCountException() throws Exception { + int threads = 4; + Queue> futures = new ArrayDeque>(threads); + ExecutorService service = Executors.newFixedThreadPool(threads); + final AtomicInteger refCountExceptions = new AtomicInteger(); + + try { + for (int i = 0; i < 10000; i++) { + final AbstractReferenceCounted referenceCounted = newReferenceCounted(); + final CountDownLatch releaseLatch = new CountDownLatch(1); + final AtomicInteger releasedCount = new AtomicInteger(); + + for (int a = 0; a < threads; a++) { + final AtomicInteger releaseCnt = new AtomicInteger(0); + + futures.add(service.submit(new Runnable() { + @Override + public void run() { + try { + releaseLatch.await(); + try { + if (referenceCounted.release(releaseCnt.incrementAndGet())) { + releasedCount.incrementAndGet(); + } + } catch (IllegalReferenceCountException e) { + refCountExceptions.incrementAndGet(); + } + } catch (InterruptedException e) { + Thread.currentThread().interrupt(); + } + } + })); + } + releaseLatch.countDown(); + + for (;;) { + Future f = futures.poll(); + if (f == null) { + break; + } + f.get(); + } + assertEquals(3, 
refCountExceptions.get()); + assertEquals(1, releasedCount.get()); + + refCountExceptions.set(0); + } + } finally { + service.shutdown(); + } + } + + private static AbstractReferenceCounted newReferenceCounted() { + return new AbstractReferenceCounted() { + @Override + protected void deallocate() { + // NOOP + } + + @Override + public ReferenceCounted touch(Object hint) { + return this; + } + }; + } +} diff --git a/netty-util/src/test/java/io/netty/util/AsciiStringCharacterTest.java b/netty-util/src/test/java/io/netty/util/AsciiStringCharacterTest.java new file mode 100644 index 0000000..b2a9214 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/AsciiStringCharacterTest.java @@ -0,0 +1,436 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import org.junit.jupiter.api.Test; + +import java.nio.CharBuffer; +import java.nio.charset.Charset; +import java.util.Random; + +import static io.netty.util.AsciiString.contains; +import static io.netty.util.AsciiString.containsIgnoreCase; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Test character encoding and case insensitivity for the {@link AsciiString} class + */ +public class AsciiStringCharacterTest { + private static final Random r = new Random(); + + @Test + public void testContentEqualsIgnoreCase() { + byte[] bytes = { 32, 'a' }; + AsciiString asciiString = new AsciiString(bytes, 1, 1, false); + // https://github.com/netty/netty/issues/9475 + assertFalse(asciiString.contentEqualsIgnoreCase("b")); + assertFalse(asciiString.contentEqualsIgnoreCase(AsciiString.of("b"))); + } + + @Test + public void testGetBytesStringBuilder() { + final StringBuilder b = new StringBuilder(); + for (int i = 0; i < 1 << 16; ++i) { + b.append("eéaà"); + } + final String bString = b.toString(); + final Charset[] charsets = CharsetUtil.values(); + for (int i = 0; i < charsets.length; ++i) { + final Charset charset = charsets[i]; + byte[] expected = bString.getBytes(charset); + byte[] actual = new AsciiString(b, charset).toByteArray(); + assertArrayEquals(expected, actual, "failure for " + charset); + } + } + + @Test + public void testGetBytesString() { + final StringBuilder b = new StringBuilder(); + for (int i = 0; i < 1 << 16; ++i) { + b.append("eéaà"); + } + final String bString = b.toString(); + final Charset[] charsets = CharsetUtil.values(); + for (int i = 0; i < 
charsets.length; ++i) { + final Charset charset = charsets[i]; + byte[] expected = bString.getBytes(charset); + byte[] actual = new AsciiString(bString, charset).toByteArray(); + assertArrayEquals(expected, actual, "failure for " + charset); + } + } + + @Test + public void testGetBytesAsciiString() { + final StringBuilder b = new StringBuilder(); + for (int i = 0; i < 1 << 16; ++i) { + b.append("eéaà"); + } + final String bString = b.toString(); + // The AsciiString class actually limits the Charset to ISO_8859_1 + byte[] expected = bString.getBytes(CharsetUtil.ISO_8859_1); + byte[] actual = new AsciiString(bString).toByteArray(); + assertArrayEquals(expected, actual); + } + + @Test + public void testComparisonWithString() { + String string = "shouldn't fail"; + AsciiString ascii = new AsciiString(string.toCharArray()); + assertEquals(string, ascii.toString()); + } + + @Test + public void subSequenceTest() { + byte[] init = {'t', 'h', 'i', 's', ' ', 'i', 's', ' ', 'a', ' ', 't', 'e', 's', 't' }; + AsciiString ascii = new AsciiString(init); + final int start = 2; + final int end = init.length; + AsciiString sub1 = ascii.subSequence(start, end, false); + AsciiString sub2 = ascii.subSequence(start, end, true); + assertEquals(sub1.hashCode(), sub2.hashCode()); + assertEquals(sub1, sub2); + for (int i = start; i < end; ++i) { + assertEquals(init[i], sub1.byteAt(i - start)); + } + } + + @Test + public void testContains() { + String[] falseLhs = {null, "a", "aa", "aaa" }; + String[] falseRhs = {null, "b", "ba", "baa" }; + for (int i = 0; i < falseLhs.length; ++i) { + for (int j = 0; j < falseRhs.length; ++j) { + assertContains(falseLhs[i], falseRhs[i], false, false); + } + } + + assertContains("", "", true, true); + assertContains("AsfdsF", "", true, true); + assertContains("", "b", false, false); + assertContains("a", "a", true, true); + assertContains("a", "b", false, false); + assertContains("a", "A", false, true); + String b = "xyz"; + String a = b; + 
assertContains(a, b, true, true); + + a = "a" + b; + assertContains(a, b, true, true); + + a = b + "a"; + assertContains(a, b, true, true); + + a = "a" + b + "a"; + assertContains(a, b, true, true); + + b = "xYz"; + a = "xyz"; + assertContains(a, b, false, true); + + b = "xYz"; + a = "xyzxxxXyZ" + b + "aaa"; + assertContains(a, b, true, true); + + b = "foOo"; + a = "fooofoO"; + assertContains(a, b, false, true); + + b = "Content-Equals: 10000"; + a = "content-equals: 1000"; + assertContains(a, b, false, false); + a += "0"; + assertContains(a, b, false, true); + } + + private static void assertContains(String a, String b, boolean caseSensitiveEquals, boolean caseInsenstaiveEquals) { + assertEquals(caseSensitiveEquals, contains(a, b)); + assertEquals(caseInsenstaiveEquals, containsIgnoreCase(a, b)); + } + + @Test + public void testCaseSensitivity() { + int i = 0; + for (; i < 32; i++) { + doCaseSensitivity(i); + } + final int min = i; + final int max = 4000; + final int len = r.nextInt((max - min) + 1) + min; + doCaseSensitivity(len); + } + + private static void doCaseSensitivity(int len) { + // Build an upper case and lower case string + final int upperA = 'A'; + final int upperZ = 'Z'; + final int upperToLower = (int) 'a' - upperA; + byte[] lowerCaseBytes = new byte[len]; + StringBuilder upperCaseBuilder = new StringBuilder(len); + for (int i = 0; i < len; ++i) { + char upper = (char) (r.nextInt((upperZ - upperA) + 1) + upperA); + upperCaseBuilder.append(upper); + lowerCaseBytes[i] = (byte) (upper + upperToLower); + } + String upperCaseString = upperCaseBuilder.toString(); + String lowerCaseString = new String(lowerCaseBytes); + AsciiString lowerCaseAscii = new AsciiString(lowerCaseBytes, false); + AsciiString upperCaseAscii = new AsciiString(upperCaseString); + final String errorString = "len: " + len; + // Test upper case hash codes are equal + final int upperCaseExpected = upperCaseAscii.hashCode(); + assertEquals(upperCaseExpected, 
AsciiString.hashCode(upperCaseBuilder), errorString); + assertEquals(upperCaseExpected, AsciiString.hashCode(upperCaseString), errorString); + assertEquals(upperCaseExpected, upperCaseAscii.hashCode(), errorString); + + // Test lower case hash codes are equal + final int lowerCaseExpected = lowerCaseAscii.hashCode(); + assertEquals(lowerCaseExpected, AsciiString.hashCode(lowerCaseAscii), errorString); + assertEquals(lowerCaseExpected, AsciiString.hashCode(lowerCaseString), errorString); + assertEquals(lowerCaseExpected, lowerCaseAscii.hashCode(), errorString); + + // Test case insensitive hash codes are equal + final int expectedCaseInsensitive = lowerCaseAscii.hashCode(); + assertEquals(expectedCaseInsensitive, AsciiString.hashCode(upperCaseBuilder), errorString); + assertEquals(expectedCaseInsensitive, AsciiString.hashCode(upperCaseString), errorString); + assertEquals(expectedCaseInsensitive, AsciiString.hashCode(lowerCaseString), errorString); + assertEquals(expectedCaseInsensitive, AsciiString.hashCode(lowerCaseAscii), errorString); + assertEquals(expectedCaseInsensitive, AsciiString.hashCode(upperCaseAscii), errorString); + assertEquals(expectedCaseInsensitive, lowerCaseAscii.hashCode(), errorString); + assertEquals(expectedCaseInsensitive, upperCaseAscii.hashCode(), errorString); + + // Test that opposite cases are equal + assertEquals(lowerCaseAscii.hashCode(), AsciiString.hashCode(upperCaseString), errorString); + assertEquals(upperCaseAscii.hashCode(), AsciiString.hashCode(lowerCaseString), errorString); + } + + @Test + public void caseInsensitiveHasherCharBuffer() { + String s1 = new String("TRANSFER-ENCODING"); + char[] array = new char[128]; + final int offset = 100; + for (int i = 0; i < s1.length(); ++i) { + array[offset + i] = s1.charAt(i); + } + CharBuffer buffer = CharBuffer.wrap(array, offset, s1.length()); + assertEquals(AsciiString.hashCode(s1), AsciiString.hashCode(buffer)); + } + + @Test + public void testBooleanUtilityMethods() { + 
assertTrue(new AsciiString(new byte[] { 1 }).parseBoolean()); + assertFalse(AsciiString.EMPTY_STRING.parseBoolean()); + assertFalse(new AsciiString(new byte[] { 0 }).parseBoolean()); + assertTrue(new AsciiString(new byte[] { 5 }).parseBoolean()); + assertTrue(new AsciiString(new byte[] { 2, 0 }).parseBoolean()); + } + + @Test + public void testEqualsIgnoreCase() { + assertThat(AsciiString.contentEqualsIgnoreCase(null, null), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase(null, "foo"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("bar", null), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", "fOo"), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", "bar"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("Foo", "foobar"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("foobar", "Foo"), is(false)); + + // Test variations (Ascii + String, Ascii + Ascii, String + Ascii) + assertThat(AsciiString.contentEqualsIgnoreCase(new AsciiString("FoO"), "fOo"), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase(new AsciiString("FoO"), new AsciiString("fOo")), is(true)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", new AsciiString("fOo")), is(true)); + + // Test variations (Ascii + String, Ascii + Ascii, String + Ascii) + assertThat(AsciiString.contentEqualsIgnoreCase(new AsciiString("FoO"), "bAr"), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase(new AsciiString("FoO"), new AsciiString("bAr")), is(false)); + assertThat(AsciiString.contentEqualsIgnoreCase("FoO", new AsciiString("bAr")), is(false)); + } + + @Test + public void testIndexOfIgnoreCase() { + assertEquals(-1, AsciiString.indexOfIgnoreCase(null, "abc", 1)); + assertEquals(-1, AsciiString.indexOfIgnoreCase("abc", null, 1)); + assertEquals(0, AsciiString.indexOfIgnoreCase("", "", 0)); + assertEquals(0, AsciiString.indexOfIgnoreCase("aabaabaa", "A", 0)); + assertEquals(2, 
AsciiString.indexOfIgnoreCase("aabaabaa", "B", 0)); + assertEquals(1, AsciiString.indexOfIgnoreCase("aabaabaa", "AB", 0)); + assertEquals(5, AsciiString.indexOfIgnoreCase("aabaabaa", "B", 3)); + assertEquals(-1, AsciiString.indexOfIgnoreCase("aabaabaa", "B", 9)); + assertEquals(2, AsciiString.indexOfIgnoreCase("aabaabaa", "B", -1)); + assertEquals(2, AsciiString.indexOfIgnoreCase("aabaabaa", "", 2)); + assertEquals(-1, AsciiString.indexOfIgnoreCase("abc", "", 9)); + assertEquals(0, AsciiString.indexOfIgnoreCase("ãabaabaa", "Ã", 0)); + } + + @Test + public void testIndexOfIgnoreCaseAscii() { + assertEquals(-1, AsciiString.indexOfIgnoreCaseAscii(null, "abc", 1)); + assertEquals(-1, AsciiString.indexOfIgnoreCaseAscii("abc", null, 1)); + assertEquals(0, AsciiString.indexOfIgnoreCaseAscii("", "", 0)); + assertEquals(0, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "A", 0)); + assertEquals(2, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "B", 0)); + assertEquals(1, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "AB", 0)); + assertEquals(5, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "B", 3)); + assertEquals(-1, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "B", 9)); + assertEquals(2, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "B", -1)); + assertEquals(2, AsciiString.indexOfIgnoreCaseAscii("aabaabaa", "", 2)); + assertEquals(-1, AsciiString.indexOfIgnoreCaseAscii("abc", "", 9)); + } + + @Test + public void testTrim() { + assertEquals("", AsciiString.EMPTY_STRING.trim().toString()); + assertEquals("abc", new AsciiString(" abc").trim().toString()); + assertEquals("abc", new AsciiString("abc ").trim().toString()); + assertEquals("abc", new AsciiString(" abc ").trim().toString()); + } + + @Test + public void testIndexOfChar() { + assertEquals(-1, AsciiString.indexOf(null, 'a', 0)); + assertEquals(-1, AsciiString.of("").indexOf('a', 0)); + assertEquals(-1, AsciiString.of("abc").indexOf('d', 0)); + assertEquals(-1, AsciiString.of("aabaabaa").indexOf('A', 0)); + 
assertEquals(0, AsciiString.of("aabaabaa").indexOf('a', 0)); + assertEquals(1, AsciiString.of("aabaabaa").indexOf('a', 1)); + assertEquals(3, AsciiString.of("aabaabaa").indexOf('a', 2)); + assertEquals(3, AsciiString.of("aabdabaa").indexOf('d', 1)); + assertEquals(1, new AsciiString("abcd", 1, 2).indexOf('c', 0)); + assertEquals(2, new AsciiString("abcd", 1, 3).indexOf('d', 2)); + assertEquals(0, new AsciiString("abcd", 1, 2).indexOf('b', 0)); + assertEquals(-1, new AsciiString("abcd", 0, 2).indexOf('c', 0)); + assertEquals(-1, new AsciiString("abcd", 1, 3).indexOf('a', 0)); + } + + @Test + public void testIndexOfCharSequence() { + assertEquals(0, new AsciiString("abcd").indexOf("abcd", 0)); + assertEquals(0, new AsciiString("abcd").indexOf("abc", 0)); + assertEquals(1, new AsciiString("abcd").indexOf("bcd", 0)); + assertEquals(1, new AsciiString("abcd").indexOf("bc", 0)); + assertEquals(1, new AsciiString("abcdabcd").indexOf("bcd", 0)); + assertEquals(0, new AsciiString("abcd", 1, 2).indexOf("bc", 0)); + assertEquals(0, new AsciiString("abcd", 1, 3).indexOf("bcd", 0)); + assertEquals(1, new AsciiString("abcdabcd", 4, 4).indexOf("bcd", 0)); + assertEquals(3, new AsciiString("012345").indexOf("345", 3)); + assertEquals(3, new AsciiString("012345").indexOf("345", 0)); + + // Test with empty string + assertEquals(0, new AsciiString("abcd").indexOf("", 0)); + assertEquals(1, new AsciiString("abcd").indexOf("", 1)); + assertEquals(3, new AsciiString("abcd", 1, 3).indexOf("", 4)); + + // Test not found + assertEquals(-1, new AsciiString("abcd").indexOf("abcde", 0)); + assertEquals(-1, new AsciiString("abcdbc").indexOf("bce", 0)); + assertEquals(-1, new AsciiString("abcd", 1, 3).indexOf("abc", 0)); + assertEquals(-1, new AsciiString("abcd", 1, 2).indexOf("bd", 0)); + assertEquals(-1, new AsciiString("012345").indexOf("345", 4)); + assertEquals(-1, new AsciiString("012345").indexOf("abc", 3)); + assertEquals(-1, new AsciiString("012345").indexOf("abc", 0)); + 
assertEquals(-1, new AsciiString("012345").indexOf("abcdefghi", 0)); + assertEquals(-1, new AsciiString("012345").indexOf("abcdefghi", 4)); + } + + @Test + public void testStaticIndexOfChar() { + assertEquals(-1, AsciiString.indexOf(null, 'a', 0)); + assertEquals(-1, AsciiString.indexOf("", 'a', 0)); + assertEquals(-1, AsciiString.indexOf("abc", 'd', 0)); + assertEquals(-1, AsciiString.indexOf("aabaabaa", 'A', 0)); + assertEquals(0, AsciiString.indexOf("aabaabaa", 'a', 0)); + assertEquals(1, AsciiString.indexOf("aabaabaa", 'a', 1)); + assertEquals(3, AsciiString.indexOf("aabaabaa", 'a', 2)); + assertEquals(3, AsciiString.indexOf("aabdabaa", 'd', 1)); + } + + @Test + public void testLastIndexOfCharSequence() { + final byte[] bytes = {'a', 'b', 'c', 'd', 'e'}; + final AsciiString ascii = new AsciiString(bytes, 2, 3, false); + + assertEquals(0, new AsciiString("abcd").lastIndexOf("abcd", 0)); + assertEquals(0, new AsciiString("abcd").lastIndexOf("abc", 4)); + assertEquals(1, new AsciiString("abcd").lastIndexOf("bcd", 4)); + assertEquals(1, new AsciiString("abcd").lastIndexOf("bc", 4)); + assertEquals(5, new AsciiString("abcdabcd").lastIndexOf("bcd", 10)); + assertEquals(0, new AsciiString("abcd", 1, 2).lastIndexOf("bc", 2)); + assertEquals(0, new AsciiString("abcd", 1, 3).lastIndexOf("bcd", 3)); + assertEquals(1, new AsciiString("abcdabcd", 4, 4).lastIndexOf("bcd", 4)); + assertEquals(3, new AsciiString("012345").lastIndexOf("345", 3)); + assertEquals(3, new AsciiString("012345").lastIndexOf("345", 6)); + assertEquals(1, ascii.lastIndexOf("de", 3)); + assertEquals(0, ascii.lastIndexOf("cde", 3)); + + // Test with empty string + assertEquals(0, new AsciiString("abcd").lastIndexOf("", 0)); + assertEquals(1, new AsciiString("abcd").lastIndexOf("", 1)); + assertEquals(3, new AsciiString("abcd", 1, 3).lastIndexOf("", 4)); + assertEquals(3, ascii.lastIndexOf("", 3)); + + // Test not found + assertEquals(-1, new AsciiString("abcd").lastIndexOf("abcde", 0)); + 
assertEquals(-1, new AsciiString("abcdbc").lastIndexOf("bce", 0)); + assertEquals(-1, new AsciiString("abcd", 1, 3).lastIndexOf("abc", 0)); + assertEquals(-1, new AsciiString("abcd", 1, 2).lastIndexOf("bd", 0)); + assertEquals(-1, new AsciiString("012345").lastIndexOf("345", 2)); + assertEquals(-1, new AsciiString("012345").lastIndexOf("abc", 3)); + assertEquals(-1, new AsciiString("012345").lastIndexOf("abc", 0)); + assertEquals(-1, new AsciiString("012345").lastIndexOf("abcdefghi", 0)); + assertEquals(-1, new AsciiString("012345").lastIndexOf("abcdefghi", 4)); + assertEquals(-1, ascii.lastIndexOf("a", 3)); + assertEquals(-1, ascii.lastIndexOf("abc", 3)); + assertEquals(-1, ascii.lastIndexOf("ce", 3)); + } + + @Test + public void testReplace() { + AsciiString abcd = new AsciiString("abcd"); + assertEquals(new AsciiString("adcd"), abcd.replace('b', 'd')); + assertEquals(new AsciiString("dbcd"), abcd.replace('a', 'd')); + assertEquals(new AsciiString("abca"), abcd.replace('d', 'a')); + assertSame(abcd, abcd.replace('x', 'a')); + assertEquals(new AsciiString("cc"), new AsciiString("abcd", 1, 2).replace('b', 'c')); + assertEquals(new AsciiString("bb"), new AsciiString("abcd", 1, 2).replace('c', 'b')); + assertEquals(new AsciiString("bddd"), new AsciiString("abcdc", 1, 4).replace('c', 'd')); + assertEquals(new AsciiString("xbcxd"), new AsciiString("abcada", 0, 5).replace('a', 'x')); + } + + @Test + public void testSubStringHashCode() { + //two "123"s + assertEquals(AsciiString.hashCode("123"), AsciiString.hashCode("a123".substring(1))); + } + + @Test + public void testIndexOf() { + AsciiString foo = AsciiString.of("This is a test"); + int i1 = foo.indexOf(' ', 0); + assertEquals(4, i1); + int i2 = foo.indexOf(' ', i1 + 1); + assertEquals(7, i2); + int i3 = foo.indexOf(' ', i2 + 1); + assertEquals(9, i3); + assertTrue(i3 + 1 < foo.length()); + int i4 = foo.indexOf(' ', i3 + 1); + assertEquals(i4, -1); + } +} diff --git 
a/netty-util/src/test/java/io/netty/util/AsciiStringMemoryTest.java b/netty-util/src/test/java/io/netty/util/AsciiStringMemoryTest.java new file mode 100644 index 0000000..efd7600 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/AsciiStringMemoryTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotEquals; +import io.netty.util.ByteProcessor.IndexOfProcessor; + +import java.util.Random; +import java.util.concurrent.atomic.AtomicReference; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +/** + * Test the underlying memory methods for the {@link AsciiString} class. 
+ */ +public class AsciiStringMemoryTest { + private byte[] a; + private byte[] b; + private int aOffset = 22; + private int bOffset = 53; + private int length = 100; + private AsciiString aAsciiString; + private AsciiString bAsciiString; + private final Random r = new Random(); + + @BeforeEach + public void setup() { + a = new byte[128]; + b = new byte[256]; + r.nextBytes(a); + r.nextBytes(b); + aOffset = 22; + bOffset = 53; + length = 100; + System.arraycopy(a, aOffset, b, bOffset, length); + aAsciiString = new AsciiString(a, aOffset, length, false); + bAsciiString = new AsciiString(b, bOffset, length, false); + } + + @Test + public void testSharedMemory() { + ++a[aOffset]; + AsciiString aAsciiString1 = new AsciiString(a, aOffset, length, true); + AsciiString aAsciiString2 = new AsciiString(a, aOffset, length, false); + assertEquals(aAsciiString, aAsciiString1); + assertEquals(aAsciiString, aAsciiString2); + for (int i = aOffset; i < length; ++i) { + assertEquals(a[i], aAsciiString.byteAt(i - aOffset)); + } + } + + @Test + public void testNotSharedMemory() { + AsciiString aAsciiString1 = new AsciiString(a, aOffset, length, true); + ++a[aOffset]; + assertNotEquals(aAsciiString, aAsciiString1); + int i = aOffset; + assertNotEquals(a[i], aAsciiString1.byteAt(i - aOffset)); + ++i; + for (; i < length; ++i) { + assertEquals(a[i], aAsciiString1.byteAt(i - aOffset)); + } + } + + @Test + public void forEachTest() throws Exception { + final AtomicReference aCount = new AtomicReference(0); + final AtomicReference bCount = new AtomicReference(0); + aAsciiString.forEachByte(new ByteProcessor() { + int i; + @Override + public boolean process(byte value) { + assertEquals(value, bAsciiString.byteAt(i++), "failed at index: " + i); + aCount.set(aCount.get() + 1); + return true; + } + }); + bAsciiString.forEachByte(new ByteProcessor() { + int i; + @Override + public boolean process(byte value) { + assertEquals(value, aAsciiString.byteAt(i++), "failed at index: " + i); + 
bCount.set(bCount.get() + 1); + return true; + } + }); + assertEquals(aAsciiString.length(), aCount.get().intValue()); + assertEquals(bAsciiString.length(), bCount.get().intValue()); + } + + @Test + public void forEachWithIndexEndTest() throws Exception { + assertNotEquals(-1, aAsciiString.forEachByte(aAsciiString.length() - 1, + 1, new IndexOfProcessor(aAsciiString.byteAt(aAsciiString.length() - 1)))); + } + + @Test + public void forEachWithIndexBeginTest() throws Exception { + assertNotEquals(-1, aAsciiString.forEachByte(0, + 1, new IndexOfProcessor(aAsciiString.byteAt(0)))); + } + + @Test + public void forEachDescTest() throws Exception { + final AtomicReference aCount = new AtomicReference(0); + final AtomicReference bCount = new AtomicReference(0); + aAsciiString.forEachByteDesc(new ByteProcessor() { + int i = 1; + @Override + public boolean process(byte value) { + assertEquals(value, bAsciiString.byteAt(bAsciiString.length() - (i++)), "failed at index: " + i); + aCount.set(aCount.get() + 1); + return true; + } + }); + bAsciiString.forEachByteDesc(new ByteProcessor() { + int i = 1; + @Override + public boolean process(byte value) { + assertEquals(value, aAsciiString.byteAt(aAsciiString.length() - (i++)), "failed at index: " + i); + bCount.set(bCount.get() + 1); + return true; + } + }); + assertEquals(aAsciiString.length(), aCount.get().intValue()); + assertEquals(bAsciiString.length(), bCount.get().intValue()); + } + + @Test + public void forEachDescWithIndexEndTest() throws Exception { + assertNotEquals(-1, bAsciiString.forEachByteDesc(bAsciiString.length() - 1, + 1, new IndexOfProcessor(bAsciiString.byteAt(bAsciiString.length() - 1)))); + } + + @Test + public void forEachDescWithIndexBeginTest() throws Exception { + assertNotEquals(-1, bAsciiString.forEachByteDesc(0, + 1, new IndexOfProcessor(bAsciiString.byteAt(0)))); + } + + @Test + public void subSequenceTest() { + final int start = 12; + final int end = aAsciiString.length(); + AsciiString aSubSequence = 
aAsciiString.subSequence(start, end, false); + AsciiString bSubSequence = bAsciiString.subSequence(start, end, true); + assertEquals(aSubSequence, bSubSequence); + assertEquals(aSubSequence.hashCode(), bSubSequence.hashCode()); + } + + @Test + public void copyTest() { + byte[] aCopy = new byte[aAsciiString.length()]; + aAsciiString.copy(0, aCopy, 0, aCopy.length); + AsciiString aAsciiStringCopy = new AsciiString(aCopy, false); + assertEquals(aAsciiString, aAsciiStringCopy); + } +} diff --git a/netty-util/src/test/java/io/netty/util/AttributeKeyTest.java b/netty-util/src/test/java/io/netty/util/AttributeKeyTest.java new file mode 100644 index 0000000..0024671 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/AttributeKeyTest.java @@ -0,0 +1,63 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class AttributeKeyTest { + + @Test + public void testExists() { + String name = "test"; + assertFalse(AttributeKey.exists(name)); + AttributeKey attr = AttributeKey.valueOf(name); + + assertTrue(AttributeKey.exists(name)); + assertNotNull(attr); + } + + @Test + public void testValueOf() { + String name = "test1"; + assertFalse(AttributeKey.exists(name)); + AttributeKey attr = AttributeKey.valueOf(name); + AttributeKey attr2 = AttributeKey.valueOf(name); + + assertSame(attr, attr2); + } + + @Test + public void testNewInstance() { + String name = "test2"; + assertFalse(AttributeKey.exists(name)); + AttributeKey attr = AttributeKey.newInstance(name); + assertTrue(AttributeKey.exists(name)); + assertNotNull(attr); + + try { + AttributeKey.newInstance(name); + fail(); + } catch (IllegalArgumentException e) { + // expected + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/ConstantPoolTest.java b/netty-util/src/test/java/io/netty/util/ConstantPoolTest.java new file mode 100644 index 0000000..5bcafb0 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/ConstantPoolTest.java @@ -0,0 +1,108 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Arrays; +import java.util.Comparator; +import java.util.Set; +import java.util.TreeSet; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class ConstantPoolTest { + + static final class TestConstant extends AbstractConstant { + TestConstant(int id, String name) { + super(id, name); + } + } + + private static final ConstantPool pool = new ConstantPool() { + @Override + protected TestConstant newConstant(int id, String name) { + return new TestConstant(id, name); + } + }; + + @Test + public void testCannotProvideNullName() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + pool.valueOf(null); + } + }); + } + + @Test + @SuppressWarnings("RedundantStringConstructorCall") + public void testUniqueness() { + TestConstant a = pool.valueOf(new String("Leroy")); + TestConstant b = pool.valueOf(new String("Leroy")); + assertThat(a, is(sameInstance(b))); + } + + @Test + public void testIdUniqueness() { + TestConstant one = pool.valueOf("one"); + TestConstant two = pool.valueOf("two"); + assertThat(one.id(), is(not(two.id()))); + } + + @Test + public void testCompare() { + TestConstant a = pool.valueOf("a_alpha"); + TestConstant b = pool.valueOf("b_beta"); + TestConstant c = pool.valueOf("c_gamma"); + TestConstant d = pool.valueOf("d_delta"); + 
TestConstant e = pool.valueOf("e_epsilon"); + + Set set = new TreeSet(); + set.add(b); + set.add(c); + set.add(e); + set.add(d); + set.add(a); + + TestConstant[] array = set.toArray(new TestConstant[0]); + assertThat(array.length, is(5)); + + // Sort by name + Arrays.sort(array, new Comparator() { + @Override + public int compare(TestConstant o1, TestConstant o2) { + return o1.name().compareTo(o2.name()); + } + }); + + assertThat(array[0], is(sameInstance(a))); + assertThat(array[1], is(sameInstance(b))); + assertThat(array[2], is(sameInstance(c))); + assertThat(array[3], is(sameInstance(d))); + assertThat(array[4], is(sameInstance(e))); + } + + @Test + public void testComposedName() { + TestConstant a = pool.valueOf(Object.class, "A"); + assertThat(a.name(), is("java.lang.Object#A")); + } +} diff --git a/netty-util/src/test/java/io/netty/util/DefaultAttributeMapTest.java b/netty-util/src/test/java/io/netty/util/DefaultAttributeMapTest.java new file mode 100644 index 0000000..cb70c40 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/DefaultAttributeMapTest.java @@ -0,0 +1,131 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class DefaultAttributeMapTest { + + private DefaultAttributeMap map; + + @BeforeEach + public void setup() { + map = new DefaultAttributeMap(); + } + + @Test + public void testMapExists() { + assertNotNull(map); + } + + @Test + public void testGetSetString() { + AttributeKey key = AttributeKey.valueOf("Nothing"); + Attribute one = map.attr(key); + + assertSame(one, map.attr(key)); + + one.setIfAbsent("Whoohoo"); + assertSame("Whoohoo", one.get()); + + one.setIfAbsent("What"); + assertNotSame("What", one.get()); + + one.remove(); + assertNull(one.get()); + } + + @Test + public void testGetSetInt() { + AttributeKey key = AttributeKey.valueOf("Nada"); + Attribute one = map.attr(key); + + assertSame(one, map.attr(key)); + + one.setIfAbsent(3653); + assertEquals(Integer.valueOf(3653), one.get()); + + one.setIfAbsent(1); + assertNotSame(1, one.get()); + + one.remove(); + assertNull(one.get()); + } + + // See https://github.com/netty/netty/issues/2523 + @Test + public void testSetRemove() { + AttributeKey key = AttributeKey.valueOf("key"); + + Attribute attr = map.attr(key); + attr.set(1); + assertSame(1, attr.getAndRemove()); + + Attribute attr2 = map.attr(key); + attr2.set(2); + assertSame(2, attr2.get()); + assertNotSame(attr, attr2); + } + + @Test + public void testHasAttrRemoved() { + AttributeKey[] keys = new AttributeKey[20]; + for (int i = 0; i < 20; i++) { + keys[i] = AttributeKey.valueOf(Integer.toString(i)); + } + for 
(int i = 10; i < 20; i++) { + map.attr(keys[i]); + } + for (int i = 0; i < 10; i++) { + map.attr(keys[i]); + } + for (int i = 10; i < 20; i++) { + AttributeKey key = AttributeKey.valueOf(Integer.toString(i)); + assertTrue(map.hasAttr(key)); + map.attr(key).remove(); + assertFalse(map.hasAttr(key)); + } + for (int i = 0; i < 10; i++) { + AttributeKey key = AttributeKey.valueOf(Integer.toString(i)); + assertTrue(map.hasAttr(key)); + map.attr(key).remove(); + assertFalse(map.hasAttr(key)); + } + } + + @Test + public void testGetAndSetWithNull() { + AttributeKey key = AttributeKey.valueOf("key"); + + Attribute attr = map.attr(key); + attr.set(1); + assertSame(1, attr.getAndSet(null)); + + Attribute attr2 = map.attr(key); + attr2.set(2); + assertSame(2, attr2.get()); + assertSame(attr, attr2); + } +} diff --git a/netty-util/src/test/java/io/netty/util/DomainNameMappingTest.java b/netty-util/src/test/java/io/netty/util/DomainNameMappingTest.java new file mode 100644 index 0000000..f5a85ca --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/DomainNameMappingTest.java @@ -0,0 +1,246 @@ +/* +* Copyright 2015 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ + +package io.netty.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Map; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +@SuppressWarnings("deprecation") +public class DomainNameMappingTest { + + // Deprecated API + + @Test + public void testNullDefaultValueInDeprecatedApi() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainNameMapping(null); + } + }); + } + + @Test + public void testNullDomainNamePatternsAreForbiddenInDeprecatedApi() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainNameMapping("NotFound").add(null, "Some value"); + } + }); + } + + @Test + public void testNullValuesAreForbiddenInDeprecatedApi() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainNameMapping("NotFound").add("Some key", null); + } + }); + } + + @Test + public void testDefaultValueInDeprecatedApi() { + DomainNameMapping mapping = new DomainNameMapping("NotFound"); + + assertEquals("NotFound", mapping.map("not-existing")); + + mapping.add("*.netty.io", "Netty"); + + assertEquals("NotFound", mapping.map("not-existing")); + } + + @Test + public void testStrictEqualityInDeprecatedApi() { + DomainNameMapping mapping = new DomainNameMapping("NotFound") + .add("netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads"); + + assertEquals("Netty", mapping.map("netty.io")); + assertEquals("Netty-Downloads", mapping.map("downloads.netty.io")); + + assertEquals("NotFound", mapping.map("x.y.z.netty.io")); + } + + @Test + public void testWildcardMatchesAnyPrefixInDeprecatedApi() { + DomainNameMapping mapping = new DomainNameMapping("NotFound") + .add("*.netty.io", "Netty"); + + assertEquals("Netty", mapping.map("netty.io")); + assertEquals("Netty", 
mapping.map("downloads.netty.io")); + assertEquals("Netty", mapping.map("x.y.z.netty.io")); + + assertEquals("NotFound", mapping.map("netty.io.x")); + } + + @Test + public void testFirstMatchWinsInDeprecatedApi() { + assertEquals("Netty", + new DomainNameMapping("NotFound") + .add("*.netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads") + .map("downloads.netty.io")); + + assertEquals("Netty-Downloads", + new DomainNameMapping("NotFound") + .add("downloads.netty.io", "Netty-Downloads") + .add("*.netty.io", "Netty") + .map("downloads.netty.io")); + } + + @Test + public void testToStringInDeprecatedApi() { + DomainNameMapping mapping = new DomainNameMapping("NotFound") + .add("*.netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads"); + + assertEquals( + "DomainNameMapping(default: NotFound, map: {*.netty.io=Netty, downloads.netty.io=Netty-Downloads})", + mapping.toString()); + } + + // Immutable DomainNameMapping Builder API + + @Test + public void testNullDefaultValue() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainNameMappingBuilder(null); + } + }); + } + + @Test + public void testNullDomainNamePatternsAreForbidden() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainNameMappingBuilder("NotFound").add(null, "Some value"); + } + }); + } + + @Test + public void testNullValuesAreForbidden() { + + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainNameMappingBuilder("NotFound").add("Some key", null); + } + }); + } + + @Test + public void testDefaultValue() { + DomainNameMapping mapping = new DomainNameMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .build(); + + assertEquals("NotFound", mapping.map("not-existing")); + } + + @Test + public void testStrictEquality() { + DomainNameMapping mapping = new DomainNameMappingBuilder("NotFound") + 
.add("netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads") + .build(); + + assertEquals("Netty", mapping.map("netty.io")); + assertEquals("Netty-Downloads", mapping.map("downloads.netty.io")); + + assertEquals("NotFound", mapping.map("x.y.z.netty.io")); + } + + @Test + public void testWildcardMatchesAnyPrefix() { + DomainNameMapping mapping = new DomainNameMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .build(); + + assertEquals("Netty", mapping.map("netty.io")); + assertEquals("Netty", mapping.map("downloads.netty.io")); + assertEquals("Netty", mapping.map("x.y.z.netty.io")); + + assertEquals("NotFound", mapping.map("netty.io.x")); + } + + @Test + public void testFirstMatchWins() { + assertEquals("Netty", + new DomainNameMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads") + .build() + .map("downloads.netty.io")); + + assertEquals("Netty-Downloads", + new DomainNameMappingBuilder("NotFound") + .add("downloads.netty.io", "Netty-Downloads") + .add("*.netty.io", "Netty") + .build() + .map("downloads.netty.io")); + } + + @Test + public void testToString() { + DomainNameMapping mapping = new DomainNameMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .add("downloads.netty.io", "Netty-Download") + .build(); + + assertEquals( + "ImmutableDomainNameMapping(default: NotFound, map: {*.netty.io=Netty, downloads.netty.io=Netty-Download})", + mapping.toString()); + } + + @Test + public void testAsMap() { + DomainNameMapping mapping = new DomainNameMapping("NotFound") + .add("netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads"); + + Map entries = mapping.asMap(); + + assertEquals(2, entries.size()); + assertEquals("Netty", entries.get("netty.io")); + assertEquals("Netty-Downloads", entries.get("downloads.netty.io")); + } + + @Test + public void testAsMapWithImmutableDomainNameMapping() { + DomainNameMapping mapping = new DomainNameMappingBuilder("NotFound") + .add("netty.io", 
"Netty") + .add("downloads.netty.io", "Netty-Downloads") + .build(); + + Map entries = mapping.asMap(); + + assertEquals(2, entries.size()); + assertEquals("Netty", entries.get("netty.io")); + assertEquals("Netty-Downloads", entries.get("downloads.netty.io")); + } +} diff --git a/netty-util/src/test/java/io/netty/util/DomainWildcardMappingBuilderTest.java b/netty-util/src/test/java/io/netty/util/DomainWildcardMappingBuilderTest.java new file mode 100644 index 0000000..e5cf9ac --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/DomainWildcardMappingBuilderTest.java @@ -0,0 +1,121 @@ +/* +* Copyright 2015 The Netty Project +* +* The Netty Project licenses this file to you under the Apache License, +* version 2.0 (the "License"); you may not use this file except in compliance +* with the License. You may obtain a copy of the License at: +* +* https://www.apache.org/licenses/LICENSE-2.0 +* +* Unless required by applicable law or agreed to in writing, software +* distributed under the License is distributed on an "AS IS" BASIS, WITHOUT +* WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the +* License for the specific language governing permissions and limitations +* under the License. 
+*/ + +package io.netty.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class DomainWildcardMappingBuilderTest { + + @Test + public void testNullDefaultValue() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainWildcardMappingBuilder(null); + } + }); + } + + @Test + public void testNullDomainNamePatternsAreForbidden() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainWildcardMappingBuilder("NotFound").add(null, "Some value"); + } + }); + } + + @Test + public void testNullValuesAreForbidden() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new DomainWildcardMappingBuilder("NotFound").add("Some key", null); + } + }); + } + + @Test + public void testDefaultValue() { + Mapping mapping = new DomainWildcardMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .build(); + + assertEquals("NotFound", mapping.map("not-existing")); + } + + @Test + public void testStrictEquality() { + Mapping mapping = new DomainWildcardMappingBuilder("NotFound") + .add("netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads") + .build(); + + assertEquals("Netty", mapping.map("netty.io")); + assertEquals("Netty-Downloads", mapping.map("downloads.netty.io")); + + assertEquals("NotFound", mapping.map("x.y.z.netty.io")); + } + + @Test + public void testWildcardMatchesNotAnyPrefix() { + Mapping mapping = new DomainWildcardMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .build(); + + assertEquals("NotFound", mapping.map("netty.io")); + assertEquals("Netty", mapping.map("downloads.netty.io")); + assertEquals("NotFound", mapping.map("x.y.z.netty.io")); + + assertEquals("NotFound", mapping.map("netty.io.x")); + } + 
+ @Test + public void testExactMatchWins() { + assertEquals("Netty-Downloads", + new DomainWildcardMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .add("downloads.netty.io", "Netty-Downloads") + .build() + .map("downloads.netty.io")); + + assertEquals("Netty-Downloads", + new DomainWildcardMappingBuilder("NotFound") + .add("downloads.netty.io", "Netty-Downloads") + .add("*.netty.io", "Netty") + .build() + .map("downloads.netty.io")); + } + + @Test + public void testToString() { + Mapping mapping = new DomainWildcardMappingBuilder("NotFound") + .add("*.netty.io", "Netty") + .add("downloads.netty.io", "Netty-Download") + .build(); + + assertEquals( + "ImmutableDomainWildcardMapping(default: NotFound, map: " + + "{*.netty.io=Netty, downloads.netty.io=Netty-Download})", + mapping.toString()); + } +} diff --git a/netty-util/src/test/java/io/netty/util/HashedWheelTimerTest.java b/netty-util/src/test/java/io/netty/util/HashedWheelTimerTest.java new file mode 100644 index 0000000..62ebfde --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/HashedWheelTimerTest.java @@ -0,0 +1,299 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util; + +import org.junit.jupiter.api.Test; + +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class HashedWheelTimerTest { + + @Test + public void testScheduleTimeoutShouldNotRunBeforeDelay() throws InterruptedException { + final Timer timer = new HashedWheelTimer(); + final CountDownLatch barrier = new CountDownLatch(1); + final Timeout timeout = timer.newTimeout(new TimerTask() { + @Override + public void run(Timeout timeout) throws Exception { + fail("This should not have run"); + barrier.countDown(); + } + }, 10, TimeUnit.SECONDS); + assertFalse(barrier.await(3, TimeUnit.SECONDS)); + assertFalse(timeout.isExpired(), "timer should not expire"); + timer.stop(); + } + + @Test + public void testScheduleTimeoutShouldRunAfterDelay() throws InterruptedException { + final Timer timer = new HashedWheelTimer(); + final CountDownLatch barrier = new CountDownLatch(1); + final Timeout timeout = timer.newTimeout(new TimerTask() { + @Override + public void run(Timeout timeout) throws Exception { + barrier.countDown(); + } + }, 2, TimeUnit.SECONDS); + assertTrue(barrier.await(3, TimeUnit.SECONDS)); + assertTrue(timeout.isExpired(), "timer should expire"); + timer.stop(); + } + + @Test + @org.junit.jupiter.api.Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testStopTimer() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(3); + final Timer timerProcessed = new HashedWheelTimer(); + for (int i = 0; i < 
3; i ++) { + timerProcessed.newTimeout(new TimerTask() { + @Override + public void run(final Timeout timeout) throws Exception { + latch.countDown(); + } + }, 1, TimeUnit.MILLISECONDS); + } + + latch.await(); + assertEquals(0, timerProcessed.stop().size(), "Number of unprocessed timeouts should be 0"); + + final Timer timerUnprocessed = new HashedWheelTimer(); + for (int i = 0; i < 5; i ++) { + timerUnprocessed.newTimeout(new TimerTask() { + @Override + public void run(Timeout timeout) throws Exception { + } + }, 5, TimeUnit.SECONDS); + } + Thread.sleep(1000L); // sleep for a second + assertFalse(timerUnprocessed.stop().isEmpty(), "Number of unprocessed timeouts should be greater than 0"); + } + + @Test + @org.junit.jupiter.api.Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testTimerShouldThrowExceptionAfterShutdownForNewTimeouts() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(3); + final Timer timer = new HashedWheelTimer(); + for (int i = 0; i < 3; i ++) { + timer.newTimeout(new TimerTask() { + @Override + public void run(Timeout timeout) throws Exception { + latch.countDown(); + } + }, 1, TimeUnit.MILLISECONDS); + } + + latch.await(); + timer.stop(); + + try { + timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.MILLISECONDS); + fail("Expected exception didn't occur."); + } catch (IllegalStateException ignored) { + // expected + } + } + + @Test + @org.junit.jupiter.api.Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testTimerOverflowWheelLength() throws InterruptedException { + final HashedWheelTimer timer = new HashedWheelTimer( + Executors.defaultThreadFactory(), 100, TimeUnit.MILLISECONDS, 32); + final CountDownLatch latch = new CountDownLatch(3); + + timer.newTimeout(new TimerTask() { + @Override + public void run(final Timeout timeout) throws Exception { + timer.newTimeout(this, 100, TimeUnit.MILLISECONDS); + latch.countDown(); + } + }, 100, TimeUnit.MILLISECONDS); + + latch.await(); + 
assertFalse(timer.stop().isEmpty()); + } + + @Test + public void testExecutionOnTime() throws InterruptedException { + int tickDuration = 200; + int timeout = 125; + int maxTimeout = 2 * (tickDuration + timeout); + final HashedWheelTimer timer = new HashedWheelTimer(tickDuration, TimeUnit.MILLISECONDS); + final BlockingQueue queue = new LinkedBlockingQueue(); + + int scheduledTasks = 100000; + for (int i = 0; i < scheduledTasks; i++) { + final long start = System.nanoTime(); + timer.newTimeout(new TimerTask() { + @Override + public void run(final Timeout timeout) throws Exception { + queue.add(TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - start)); + } + }, timeout, TimeUnit.MILLISECONDS); + } + + for (int i = 0; i < scheduledTasks; i++) { + long delay = queue.take(); + assertTrue(delay >= timeout && delay < maxTimeout, + "Timeout + " + scheduledTasks + " delay " + delay + " must be " + timeout + " < " + maxTimeout); + } + + timer.stop(); + } + + @Test + public void testExecutionOnTaskExecutor() throws InterruptedException { + int timeout = 10; + + final CountDownLatch latch = new CountDownLatch(1); + final CountDownLatch timeoutLatch = new CountDownLatch(1); + Executor executor = new Executor() { + @Override + public void execute(Runnable command) { + try { + command.run(); + } finally { + latch.countDown(); + } + } + }; + final HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), 100, + TimeUnit.MILLISECONDS, 32, true, 2, executor); + timer.newTimeout(new TimerTask() { + @Override + public void run(final Timeout timeout) throws Exception { + timeoutLatch.countDown(); + } + }, timeout, TimeUnit.MILLISECONDS); + + latch.await(); + timeoutLatch.await(); + timer.stop(); + } + + @Test + public void testRejectedExecutionExceptionWhenTooManyTimeoutsAreAddedBackToBack() { + HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), 100, + TimeUnit.MILLISECONDS, 32, true, 2); + timer.newTimeout(createNoOpTimerTask(), 5, 
TimeUnit.SECONDS); + timer.newTimeout(createNoOpTimerTask(), 5, TimeUnit.SECONDS); + try { + timer.newTimeout(createNoOpTimerTask(), 1, TimeUnit.MILLISECONDS); + fail("Timer allowed adding 3 timeouts when maxPendingTimeouts was 2"); + } catch (RejectedExecutionException e) { + // Expected + } finally { + timer.stop(); + } + } + + @Test + public void testNewTimeoutShouldStopThrowingRejectedExecutionExceptionWhenExistingTimeoutIsCancelled() + throws InterruptedException { + final int tickDurationMs = 100; + final HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), tickDurationMs, + TimeUnit.MILLISECONDS, 32, true, 2); + timer.newTimeout(createNoOpTimerTask(), 5, TimeUnit.SECONDS); + Timeout timeoutToCancel = timer.newTimeout(createNoOpTimerTask(), 5, TimeUnit.SECONDS); + assertTrue(timeoutToCancel.cancel()); + + Thread.sleep(tickDurationMs * 5); + + final CountDownLatch secondLatch = new CountDownLatch(1); + timer.newTimeout(createCountDownLatchTimerTask(secondLatch), 90, TimeUnit.MILLISECONDS); + + secondLatch.await(); + timer.stop(); + } + + @Test + @org.junit.jupiter.api.Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testNewTimeoutShouldStopThrowingRejectedExecutionExceptionWhenExistingTimeoutIsExecuted() + throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final HashedWheelTimer timer = new HashedWheelTimer(Executors.defaultThreadFactory(), 25, + TimeUnit.MILLISECONDS, 4, true, 2); + timer.newTimeout(createNoOpTimerTask(), 5, TimeUnit.SECONDS); + timer.newTimeout(createCountDownLatchTimerTask(latch), 90, TimeUnit.MILLISECONDS); + + latch.await(); + + final CountDownLatch secondLatch = new CountDownLatch(1); + timer.newTimeout(createCountDownLatchTimerTask(secondLatch), 90, TimeUnit.MILLISECONDS); + + secondLatch.await(); + timer.stop(); + } + + @Test() + public void reportPendingTimeouts() throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final 
HashedWheelTimer timer = new HashedWheelTimer(); + final Timeout t1 = timer.newTimeout(createNoOpTimerTask(), 100, TimeUnit.MINUTES); + final Timeout t2 = timer.newTimeout(createNoOpTimerTask(), 100, TimeUnit.MINUTES); + timer.newTimeout(createCountDownLatchTimerTask(latch), 90, TimeUnit.MILLISECONDS); + + assertEquals(3, timer.pendingTimeouts()); + t1.cancel(); + t2.cancel(); + latch.await(); + + assertEquals(0, timer.pendingTimeouts()); + timer.stop(); + } + + @Test + public void testOverflow() throws InterruptedException { + final HashedWheelTimer timer = new HashedWheelTimer(); + final CountDownLatch latch = new CountDownLatch(1); + Timeout timeout = timer.newTimeout(new TimerTask() { + @Override + public void run(Timeout timeout) { + latch.countDown(); + } + }, Long.MAX_VALUE, TimeUnit.MILLISECONDS); + assertFalse(latch.await(1, TimeUnit.SECONDS)); + timeout.cancel(); + timer.stop(); + } + + private static TimerTask createNoOpTimerTask() { + return new TimerTask() { + @Override + public void run(final Timeout timeout) throws Exception { + } + }; + } + + private static TimerTask createCountDownLatchTimerTask(final CountDownLatch latch) { + return new TimerTask() { + @Override + public void run(final Timeout timeout) throws Exception { + latch.countDown(); + } + }; + } +} diff --git a/netty-util/src/test/java/io/netty/util/NetUtilTest.java b/netty-util/src/test/java/io/netty/util/NetUtilTest.java new file mode 100644 index 0000000..9b0b105 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/NetUtilTest.java @@ -0,0 +1,820 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.internal.StringUtil; +import org.junit.jupiter.api.Test; + +import java.net.Inet4Address; +import java.net.Inet6Address; +import java.net.InetAddress; +import java.net.InetSocketAddress; +import java.net.UnknownHostException; +import java.util.HashMap; +import java.util.Map; +import java.util.Map.Entry; + +import static io.netty.util.NetUtil.*; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class NetUtilTest { + + private static final class TestMap extends HashMap { + private static final long serialVersionUID = -298642816998608473L; + + TestMap(String... 
values) { + for (int i = 0; i < values.length; i += 2) { + String key = values[i]; + String value = values[i + 1]; + put(key, value); + } + } + } + + private static final Map validIpV4Hosts = new TestMap( + "192.168.1.0", "c0a80100", + "10.255.255.254", "0afffffe", + "172.18.5.4", "ac120504", + "0.0.0.0", "00000000", + "127.0.0.1", "7f000001", + "255.255.255.255", "ffffffff", + "1.2.3.4", "01020304"); + + private static final Map invalidIpV4Hosts = new TestMap( + "1.256.3.4", null, + "256.0.0.1", null, + "1.1.1.1.1", null, + "x.255.255.255", null, + "0.1:0.0", null, + "0.1.0.0:", null, + "127.0.0.", null, + "1.2..4", null, + "192.0.1", null, + "192.0.1.1.1", null, + "192.0.1.a", null, + "19a.0.1.1", null, + "a.0.1.1", null, + ".0.1.1", null, + "127.0.0", null, + "192.0.1.256", null, + "0.0.200.259", null, + "1.1.-1.1", null, + "1.1. 1.1", null, + "1.1.1.1 ", null, + "1.1.+1.1", null, + "0.0x1.0.255", null, + "0.01x.0.255", null, + "0.x01.0.255", null, + "0.-.0.0", null, + "0..0.0", null, + "0.A.0.0", null, + "0.1111.0.0", null, + "...", null); + + private static final Map validIpV6Hosts = new TestMap( + "::ffff:5.6.7.8", "00000000000000000000ffff05060708", + "fdf8:f53b:82e4::53", "fdf8f53b82e400000000000000000053", + "fe80::200:5aee:feaa:20a2", "fe8000000000000002005aeefeaa20a2", + "2001::1", "20010000000000000000000000000001", + "2001:0000:4136:e378:8000:63bf:3fff:fdd2", "200100004136e378800063bf3ffffdd2", + "2001:0002:6c::430", "20010002006c00000000000000000430", + "2001:10:240:ab::a", "20010010024000ab000000000000000a", + "2002:cb0a:3cdd:1::1", "2002cb0a3cdd00010000000000000001", + "2001:db8:8:4::2", "20010db8000800040000000000000002", + "ff01:0:0:0:0:0:0:2", "ff010000000000000000000000000002", + "[fdf8:f53b:82e4::53]", "fdf8f53b82e400000000000000000053", + "[fe80::200:5aee:feaa:20a2]", "fe8000000000000002005aeefeaa20a2", + "[2001::1]", "20010000000000000000000000000001", + "[2001:0000:4136:e378:8000:63bf:3fff:fdd2]", "200100004136e378800063bf3ffffdd2", + 
"0:1:2:3:4:5:6:789a", "0000000100020003000400050006789a", + "0:1:2:3::f", "0000000100020003000000000000000f", + "0:0:0:0:0:0:10.0.0.1", "00000000000000000000ffff0a000001", + "0:0:0:0:0::10.0.0.1", "00000000000000000000ffff0a000001", + "0:0:0:0::10.0.0.1", "00000000000000000000ffff0a000001", + "::0:0:0:0:0:10.0.0.1", "00000000000000000000ffff0a000001", + "0::0:0:0:0:10.0.0.1", "00000000000000000000ffff0a000001", + "0:0::0:0:0:10.0.0.1", "00000000000000000000ffff0a000001", + "0:0:0::0:0:10.0.0.1", "00000000000000000000ffff0a000001", + "0:0:0:0::0:10.0.0.1", "00000000000000000000ffff0a000001", + "0:0:0:0:0:ffff:10.0.0.1", "00000000000000000000ffff0a000001", + "::ffff:192.168.0.1", "00000000000000000000ffffc0a80001", + // Test if various interface names after the percent sign are recognized. + "[::1%1]", "00000000000000000000000000000001", + "[::1%eth0]", "00000000000000000000000000000001", + "[::1%%]", "00000000000000000000000000000001", + "0:0:0:0:0:ffff:10.0.0.1%", "00000000000000000000ffff0a000001", + "0:0:0:0:0:ffff:10.0.0.1%1", "00000000000000000000ffff0a000001", + "[0:0:0:0:0:ffff:10.0.0.1%1]", "00000000000000000000ffff0a000001", + "[0:0:0:0:0::10.0.0.1%1]", "00000000000000000000ffff0a000001", + "[::0:0:0:0:ffff:10.0.0.1%1]", "00000000000000000000ffff0a000001", + "::0:0:0:0:ffff:10.0.0.1%1", "00000000000000000000ffff0a000001", + "::1%1", "00000000000000000000000000000001", + "::1%eth0", "00000000000000000000000000000001", + "::1%%", "00000000000000000000000000000001", + // Tests with leading or trailing compression + "0:0:0:0:0:0:0::", "00000000000000000000000000000000", + "0:0:0:0:0:0::", "00000000000000000000000000000000", + "0:0:0:0:0::", "00000000000000000000000000000000", + "0:0:0:0::", "00000000000000000000000000000000", + "0:0:0::", "00000000000000000000000000000000", + "0:0::", "00000000000000000000000000000000", + "0::", "00000000000000000000000000000000", + "::", "00000000000000000000000000000000", + "::0", "00000000000000000000000000000000", + 
"::0:0", "00000000000000000000000000000000", + "::0:0:0", "00000000000000000000000000000000", + "::0:0:0:0", "00000000000000000000000000000000", + "::0:0:0:0:0", "00000000000000000000000000000000", + "::0:0:0:0:0:0", "00000000000000000000000000000000", + "::0:0:0:0:0:0:0", "00000000000000000000000000000000"); + + private static final Map invalidIpV6Hosts = new TestMap( + // Test method with garbage. + "Obvious Garbage", null, + // Test method with preferred style, too many : + "0:1:2:3:4:5:6:7:8", null, + // Test method with preferred style, not enough : + "0:1:2:3:4:5:6", null, + // Test method with preferred style, bad digits. + "0:1:2:3:4:5:6:x", null, + // Test method with preferred style, adjacent : + "0:1:2:3:4:5:6::7", null, + // Too many : separators trailing + "0:1:2:3:4:5:6:7::", null, + // Too many : separators leading + "::0:1:2:3:4:5:6:7", null, + // Too many : separators trailing + "1:2:3:4:5:6:7:", null, + // Too many : separators leading + ":1:2:3:4:5:6:7", null, + // Compression with : separators trailing + "0:1:2:3:4:5::7:", null, + "0:1:2:3:4::7:", null, + "0:1:2:3::7:", null, + "0:1:2::7:", null, + "0:1::7:", null, + "0::7:", null, + // Compression at start with : separators trailing + "::0:1:2:3:4:5:7:", null, + "::0:1:2:3:4:7:", null, + "::0:1:2:3:7:", null, + "::0:1:2:7:", null, + "::0:1:7:", null, + "::7:", null, + // The : separators leading and trailing + ":1:2:3:4:5:6:7:", null, + ":1:2:3:4:5:6:", null, + ":1:2:3:4:5:", null, + ":1:2:3:4:", null, + ":1:2:3:", null, + ":1:2:", null, + ":1:", null, + // Compression with : separators leading + ":1::2:3:4:5:6:7", null, + ":1::3:4:5:6:7", null, + ":1::4:5:6:7", null, + ":1::5:6:7", null, + ":1::6:7", null, + ":1::7", null, + ":1:2:3:4:5:6::7", null, + ":1:3:4:5:6::7", null, + ":1:4:5:6::7", null, + ":1:5:6::7", null, + ":1:6::7", null, + ":1::", null, + // Compression trailing with : separators leading + ":1:2:3:4:5:6:7::", null, + ":1:3:4:5:6:7::", null, + ":1:4:5:6:7::", null, + 
":1:5:6:7::", null, + ":1:6:7::", null, + ":1:7::", null, + // Double compression + "1::2:3:4:5:6::", null, + "::1:2:3:4:5::6", null, + "::1:2:3:4:5:6::", null, + "::1:2:3:4:5::", null, + "::1:2:3:4::", null, + "::1:2:3::", null, + "::1:2::", null, + "::0::", null, + "12::0::12", null, + // Too many : separators leading 0 + "0::1:2:3:4:5:6:7", null, + // Test method with preferred style, too many digits. + "0:1:2:3:4:5:6:789abcdef", null, + // Test method with compressed style, bad digits. + "0:1:2:3::x", null, + // Test method with compressed style, too many adjacent : + "0:1:2:::3", null, + // Test method with compressed style, too many digits. + "0:1:2:3::abcde", null, + // Test method with compressed style, not enough : + "0:1", null, + // Test method with ipv4 style, bad ipv6 digits. + "0:0:0:0:0:x:10.0.0.1", null, + // Test method with ipv4 style, bad ipv4 digits. + "0:0:0:0:0:0:10.0.0.x", null, + // Test method with ipv4 style, too many ipv6 digits. + "0:0:0:0:0:00000:10.0.0.1", null, + // Test method with ipv4 style, too many : + "0:0:0:0:0:0:0:10.0.0.1", null, + // Test method with ipv4 style, not enough : + "0:0:0:0:0:10.0.0.1", null, + // Test method with ipv4 style, too many . + "0:0:0:0:0:0:10.0.0.0.1", null, + // Test method with ipv4 style, not enough . + "0:0:0:0:0:0:10.0.1", null, + // Test method with ipv4 style, adjacent . + "0:0:0:0:0:0:10..0.0.1", null, + // Test method with ipv4 style, leading . + "0:0:0:0:0:0:.0.0.1", null, + // Test method with ipv4 style, leading . + "0:0:0:0:0:0:.10.0.0.1", null, + // Test method with ipv4 style, trailing . + "0:0:0:0:0:0:10.0.0.", null, + // Test method with ipv4 style, trailing . + "0:0:0:0:0:0:10.0.0.1.", null, + // Test method with compressed ipv4 style, bad ipv6 digits. + "::fffx:192.168.0.1", null, + // Test method with compressed ipv4 style, bad ipv4 digits. 
+ "::ffff:192.168.0.x", null, + // Test method with compressed ipv4 style, too many adjacent : + ":::ffff:192.168.0.1", null, + // Test method with compressed ipv4 style, too many ipv6 digits. + "::fffff:192.168.0.1", null, + // Test method with compressed ipv4 style, too many ipv4 digits. + "::ffff:1923.168.0.1", null, + // Test method with compressed ipv4 style, not enough : + ":ffff:192.168.0.1", null, + // Test method with compressed ipv4 style, too many . + "::ffff:192.168.0.1.2", null, + // Test method with compressed ipv4 style, not enough . + "::ffff:192.168.0", null, + // Test method with compressed ipv4 style, adjacent . + "::ffff:192.168..0.1", null, + // Test method, bad ipv6 digits. + "x:0:0:0:0:0:10.0.0.1", null, + // Test method, bad ipv4 digits. + "0:0:0:0:0:0:x.0.0.1", null, + // Test method, too many ipv6 digits. + "00000:0:0:0:0:0:10.0.0.1", null, + // Test method, too many ipv4 digits. + "0:0:0:0:0:0:10.0.0.1000", null, + // Test method, too many : + "0:0:0:0:0:0:0:10.0.0.1", null, + // Test method, not enough : + "0:0:0:0:0:10.0.0.1", null, + // Test method, out of order trailing : + "0:0:0:0:0:10.0.0.1:", null, + // Test method, out of order leading : + ":0:0:0:0:0:10.0.0.1", null, + // Test method, out of order leading : + "0:0:0:0::10.0.0.1:", null, + // Test method, out of order trailing : + ":0:0:0:0::10.0.0.1", null, + // Test method, too many . + "0:0:0:0:0:0:10.0.0.0.1", null, + // Test method, not enough . + "0:0:0:0:0:0:10.0.1", null, + // Test method, adjacent . + "0:0:0:0:0:0:10.0.0..1", null, + // Empty contents + "", null, + // Invalid single compression + ":", null, + ":::", null, + // Trailing : (max number of : = 8) + "2001:0:4136:e378:8000:63bf:3fff:fdd2:", null, + // Leading : (max number of : = 8) + ":aaaa:bbbb:cccc:dddd:eeee:ffff:1111:2222", null, + // Invalid character + "1234:2345:3456:4567:5678:6789::X890", null, + // Trailing . 
in IPv4 + "::ffff:255.255.255.255.", null, + // To many characters in IPv4 + "::ffff:0.0.1111.0", null, + // Test method, adjacent . + "::ffff:0.0..0", null, + // Not enough IPv4 entries trailing . + "::ffff:127.0.0.", null, + // Invalid trailing IPv4 character + "::ffff:127.0.0.a", null, + // Invalid leading IPv4 character + "::ffff:a.0.0.1", null, + // Invalid middle IPv4 character + "::ffff:127.a.0.1", null, + // Invalid middle IPv4 character + "::ffff:127.0.a.1", null, + // Not enough IPv4 entries no trailing . + "::ffff:1.2.4", null, + // Extra IPv4 entry + "::ffff:192.168.0.1.255", null, + // Not enough IPv6 content + ":ffff:192.168.0.1.255", null, + // Intermixed IPv4 and IPv6 symbols + "::ffff:255.255:255.255.", null, + // Invalid IPv4 mapped address - invalid ipv4 separator + "0:0:0::0:0:00f.0.0.1", null, + // Invalid IPv4 mapped address - not enough f's + "0:0:0:0:0:fff:1.0.0.1", null, + // Invalid IPv4 mapped address - not IPv4 mapped, not IPv4 compatible + "0:0:0:0:0:ff00:1.0.0.1", null, + // Invalid IPv4 mapped address - not IPv4 mapped, not IPv4 compatible + "0:0:0:0:0:ff:1.0.0.1", null, + // Invalid IPv4 mapped address - too many f's + "0:0:0:0:0:fffff:1.0.0.1", null, + // Invalid IPv4 mapped address - too many bytes (too many 0's) + "0:0:0:0:0:0:ffff:1.0.0.1", null, + // Invalid IPv4 mapped address - too many bytes (too many 0's) + "::0:0:0:0:0:ffff:1.0.0.1", null, + // Invalid IPv4 mapped address - too many bytes (too many 0's) + "0:0:0:0:0:0::1.0.0.1", null, + // Invalid IPv4 mapped address - too many bytes (too many 0's) + "0:0:0:0:0:00000:1.0.0.1", null, + // Invalid IPv4 mapped address - too few bytes (not enough 0's) + "0:0:0:0:ffff:1.0.0.1", null, + // Invalid IPv4 mapped address - too few bytes (not enough 0's) + "ffff:192.168.0.1", null, + // Invalid IPv4 mapped address - 0's after the mapped ffff indicator + "0:0:0:0:0:ffff::10.0.0.1", null, + // Invalid IPv4 mapped address - 0's after the mapped ffff indicator + "0:0:0:0:ffff::10.0.0.1", 
null, + // Invalid IPv4 mapped address - 0's after the mapped ffff indicator + "0:0:0:ffff::10.0.0.1", null, + // Invalid IPv4 mapped address - 0's after the mapped ffff indicator + "0:0:ffff::10.0.0.1", null, + // Invalid IPv4 mapped address - 0's after the mapped ffff indicator + "0:ffff::10.0.0.1", null, + // Invalid IPv4 mapped address - 0's after the mapped ffff indicator + "ffff::10.0.0.1", null, + // Invalid IPv4 mapped address - not all 0's before the mapped separator + "1:0:0:0:0:ffff:10.0.0.1", null, + // Address that is similar to IPv4 mapped, but is invalid + "0:0:0:0:ffff:ffff:1.0.0.1", null, + // Valid number of separators, but invalid IPv4 format + "::1:2:3:4:5:6.7.8.9", null, + // Too many digits + "0:0:0:0:0:0:ffff:10.0.0.1", null, + // Invalid IPv4 format + ":1.2.3.4", null, + // Invalid IPv4 format + "::.2.3.4", null, + // Invalid IPv4 format + "::ffff:0.1.2.", null); + + private static final Map ipv6ToAddressStrings = new HashMap() { + private static final long serialVersionUID = 2999763170377573184L; + { + // From the RFC 5952 https://tools.ietf.org/html/rfc5952#section-4 + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 1 + }, + "2001:db8::1"); + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 2, 0, 1 + }, + "2001:db8::2:1"); + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 1, + 0, 1, 0, 1, + 0, 1, 0, 1 + }, + "2001:db8:0:1:1:1:1:1"); + + // Other examples + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 2, 0, 1 + }, + "2001:db8::2:1"); + put(new byte[] { + 32, 1, 0, 0, + 0, 0, 0, 1, + 0, 0, 0, 0, + 0, 0, 0, 1 + }, + "2001:0:0:1::1"); + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 0, 1 + }, + "2001:db8::1:0:0:1"); + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 0, + 0, 1, 0, 0, + 0, 0, 0, 0 + }, + "2001:db8:0:0:1::"); + put(new byte[] { + 32, 1, 13, -72, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 2, 0, 0 + }, + "2001:db8::2:0"); + put(new byte[] { + 0, 0, 0, 0, + 
0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 1 + }, + "::1"); + put(new byte[] { + 0, 0, 0, 0, + 0, 0, 0, 1, + 0, 0, 0, 0, + 0, 0, 0, 1 + }, + "::1:0:0:0:1"); + put(new byte[] { + 0, 0, 0, 0, + 1, 0, 0, 1, + 0, 0, 0, 0, + 1, 0, 0, 0 + }, + "::100:1:0:0:100:0"); + put(new byte[] { + 32, 1, 0, 0, + 65, 54, -29, 120, + -128, 0, 99, -65, + 63, -1, -3, -46 + }, + "2001:0:4136:e378:8000:63bf:3fff:fdd2"); + put(new byte[] { + -86, -86, -69, -69, + -52, -52, -35, -35, + -18, -18, -1, -1, + 17, 17, 34, 34 + }, + "aaaa:bbbb:cccc:dddd:eeee:ffff:1111:2222"); + put(new byte[] { + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0, + 0, 0, 0, 0 + }, + "::"); + } + }; + + private static final Map ipv4MappedToIPv6AddressStrings = new TestMap( + // IPv4 addresses + "255.255.255.255", "::ffff:255.255.255.255", + "0.0.0.0", "::ffff:0.0.0.0", + "127.0.0.1", "::ffff:127.0.0.1", + "1.2.3.4", "::ffff:1.2.3.4", + "192.168.0.1", "::ffff:192.168.0.1", + + // IPv4 compatible addresses are deprecated [1], so we don't support outputting them, but we do support + // parsing them into IPv4 mapped addresses. These values are treated the same as a plain IPv4 address above. 
+ // [1] https://tools.ietf.org/html/rfc4291#section-2.5.5.1 + "0:0:0:0:0:0:255.254.253.252", "::ffff:255.254.253.252", + "0:0:0:0:0::1.2.3.4", "::ffff:1.2.3.4", + "0:0:0:0::1.2.3.4", "::ffff:1.2.3.4", + "::0:0:0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0::0:0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0::0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0::0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0:0::0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0:0:0::1.2.3.4", "::ffff:1.2.3.4", + "::0:0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0::0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0::0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0::0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0:0::1.2.3.4", "::ffff:1.2.3.4", + "::0:0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0::0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0::0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0::0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0:0::1.2.3.4", "::ffff:1.2.3.4", + "::0:0:0:1.2.3.4", "::ffff:1.2.3.4", + "0::0:0:1.2.3.4", "::ffff:1.2.3.4", + "0:0::0:1.2.3.4", "::ffff:1.2.3.4", + "0:0:0::1.2.3.4", "::ffff:1.2.3.4", + "::0:0:1.2.3.4", "::ffff:1.2.3.4", + "0::0:1.2.3.4", "::ffff:1.2.3.4", + "0:0::1.2.3.4", "::ffff:1.2.3.4", + "::0:1.2.3.4", "::ffff:1.2.3.4", + "::1.2.3.4", "::ffff:1.2.3.4", + + // IPv4 mapped (fully specified) + "0:0:0:0:0:ffff:1.2.3.4", "::ffff:1.2.3.4", + + // IPv6 addresses + // Fully specified + "2001:0:4136:e378:8000:63bf:3fff:fdd2", "2001:0:4136:e378:8000:63bf:3fff:fdd2", + "aaaa:bbbb:cccc:dddd:eeee:ffff:1111:2222", "aaaa:bbbb:cccc:dddd:eeee:ffff:1111:2222", + "0:0:0:0:0:0:0:0", "::", + "0:0:0:0:0:0:0:1", "::1", + + // Compressing at the beginning + "::1:0:0:0:1", "::1:0:0:0:1", + "::1:ffff:ffff", "::1:ffff:ffff", + "::", "::", + "::1", "::1", + "::ffff", "::ffff", + "::ffff:0", "::ffff:0", + "::ffff:ffff", "::ffff:ffff", + "::0987:9876:8765", "::987:9876:8765", + "::0987:9876:8765:7654", "::987:9876:8765:7654", + "::0987:9876:8765:7654:6543", "::987:9876:8765:7654:6543", + "::0987:9876:8765:7654:6543:5432", "::987:9876:8765:7654:6543:5432", + // Note the 
compression is removed (rfc 5952 section 4.2.2) + "::0987:9876:8765:7654:6543:5432:3210", "0:987:9876:8765:7654:6543:5432:3210", + + // Compressing at the end + // Note the compression is removed (rfc 5952 section 4.2.2) + "2001:db8:abcd:bcde:cdef:def1:ef12::", "2001:db8:abcd:bcde:cdef:def1:ef12:0", + "2001:db8:abcd:bcde:cdef:def1::", "2001:db8:abcd:bcde:cdef:def1::", + "2001:db8:abcd:bcde:cdef::", "2001:db8:abcd:bcde:cdef::", + "2001:db8:abcd:bcde::", "2001:db8:abcd:bcde::", + "2001:db8:abcd::", "2001:db8:abcd::", + "2001:1234::", "2001:1234::", + "2001::", "2001::", + "0::", "::", + + // Compressing in the middle + "1234:2345::7890", "1234:2345::7890", + "1234::2345:7890", "1234::2345:7890", + "1234:2345:3456::7890", "1234:2345:3456::7890", + "1234:2345::3456:7890", "1234:2345::3456:7890", + "1234::2345:3456:7890", "1234::2345:3456:7890", + "1234:2345:3456:4567::7890", "1234:2345:3456:4567::7890", + "1234:2345:3456::4567:7890", "1234:2345:3456::4567:7890", + "1234:2345::3456:4567:7890", "1234:2345::3456:4567:7890", + "1234::2345:3456:4567:7890", "1234::2345:3456:4567:7890", + "1234:2345:3456:4567:5678::7890", "1234:2345:3456:4567:5678::7890", + "1234:2345:3456:4567::5678:7890", "1234:2345:3456:4567::5678:7890", + "1234:2345:3456::4567:5678:7890", "1234:2345:3456::4567:5678:7890", + "1234:2345::3456:4567:5678:7890", "1234:2345::3456:4567:5678:7890", + "1234::2345:3456:4567:5678:7890", "1234::2345:3456:4567:5678:7890", + // Note the compression is removed (rfc 5952 section 4.2.2) + "1234:2345:3456:4567:5678:6789::7890", "1234:2345:3456:4567:5678:6789:0:7890", + // Note the compression is removed (rfc 5952 section 4.2.2) + "1234:2345:3456:4567:5678::6789:7890", "1234:2345:3456:4567:5678:0:6789:7890", + // Note the compression is removed (rfc 5952 section 4.2.2) + "1234:2345:3456:4567::5678:6789:7890", "1234:2345:3456:4567:0:5678:6789:7890", + // Note the compression is removed (rfc 5952 section 4.2.2) + "1234:2345:3456::4567:5678:6789:7890", 
"1234:2345:3456:0:4567:5678:6789:7890", + // Note the compression is removed (rfc 5952 section 4.2.2) + "1234:2345::3456:4567:5678:6789:7890", "1234:2345:0:3456:4567:5678:6789:7890", + // Note the compression is removed (rfc 5952 section 4.2.2) + "1234::2345:3456:4567:5678:6789:7890", "1234:0:2345:3456:4567:5678:6789:7890", + + // IPv4 mapped addresses + "::ffff:255.255.255.255", "::ffff:255.255.255.255", + "::ffff:0.0.0.0", "::ffff:0.0.0.0", + "::ffff:127.0.0.1", "::ffff:127.0.0.1", + "::ffff:1.2.3.4", "::ffff:1.2.3.4", + "::ffff:192.168.0.1", "::ffff:192.168.0.1"); + + @Test + public void testLocalhost() { + assertNotNull(LOCALHOST); + } + + @Test + public void testLoopback() { + assertNotNull(LOOPBACK_IF); + } + + @Test + public void testIsValidIpV4Address() { + for (String host : validIpV4Hosts.keySet()) { + assertTrue(isValidIpV4Address(host), host); + } + for (String host : invalidIpV4Hosts.keySet()) { + assertFalse(isValidIpV4Address(host), host); + } + } + + @Test + public void testIsValidIpV6Address() { + for (String host : validIpV6Hosts.keySet()) { + assertTrue(isValidIpV6Address(host), host); + if (host.charAt(0) != '[' && !host.contains("%")) { + assertNotNull(getByName(host, true), host); + + String hostMod = '[' + host + ']'; + assertTrue(isValidIpV6Address(hostMod), hostMod); + + hostMod = host + '%'; + assertTrue(isValidIpV6Address(hostMod), hostMod); + + hostMod = host + "%eth1"; + assertTrue(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "%]"; + assertTrue(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "%1]"; + assertTrue(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "]%"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "]%1"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + } + } + for (String host : invalidIpV6Hosts.keySet()) { + assertFalse(isValidIpV6Address(host), host); + assertNull(getByName(host), host); + + String hostMod = '[' + host + ']'; + 
assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = host + '%'; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = host + "%eth1"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "%]"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "%1]"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "]%"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host + "]%1"; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = host + ']'; + assertFalse(isValidIpV6Address(hostMod), hostMod); + + hostMod = '[' + host; + assertFalse(isValidIpV6Address(hostMod), hostMod); + } + } + + @Test + public void testCreateByteArrayFromIpAddressString() { + for (Entry e : validIpV4Hosts.entrySet()) { + String ip = e.getKey(); + assertHexDumpEquals(e.getValue(), createByteArrayFromIpAddressString(ip), ip); + } + for (Entry e : invalidIpV4Hosts.entrySet()) { + String ip = e.getKey(); + assertHexDumpEquals(e.getValue(), createByteArrayFromIpAddressString(ip), ip); + } + for (Entry e : validIpV6Hosts.entrySet()) { + String ip = e.getKey(); + assertHexDumpEquals(e.getValue(), createByteArrayFromIpAddressString(ip), ip); + } + for (Entry e : invalidIpV6Hosts.entrySet()) { + String ip = e.getKey(); + assertHexDumpEquals(e.getValue(), createByteArrayFromIpAddressString(ip), ip); + } + } + + @Test + public void testBytesToIpAddress() { + for (Entry e : validIpV4Hosts.entrySet()) { + assertEquals(e.getKey(), bytesToIpAddress(createByteArrayFromIpAddressString(e.getKey()))); + assertEquals(e.getKey(), bytesToIpAddress(validIpV4ToBytes(e.getKey()))); + } + for (Entry testEntry : ipv6ToAddressStrings.entrySet()) { + assertEquals(testEntry.getValue(), bytesToIpAddress(testEntry.getKey())); + } + } + + @Test + public void testBytesToIpAddressWithOffset() { + for (Entry e : validIpV4Hosts.entrySet()) { + byte[] bytes = 
copyWithOffset(createByteArrayFromIpAddressString(e.getKey())); + assertEquals(e.getKey(), bytesToIpAddress(bytes, 1, bytes.length - 2)); + + byte[] bytes2 = copyWithOffset(createByteArrayFromIpAddressString(e.getKey())); + assertEquals(e.getKey(), bytesToIpAddress(bytes2, 1, bytes2.length - 2)); + } + + for (Entry testEntry : ipv6ToAddressStrings.entrySet()) { + byte[] bytes = copyWithOffset(testEntry.getKey()); + assertEquals(testEntry.getValue(), bytesToIpAddress(bytes, 1, bytes.length - 2)); + } + } + + private static byte[] copyWithOffset(byte[] bytes) { + if (bytes == null) { + return null; + } + byte[] array = new byte[bytes.length + 2]; + System.arraycopy(bytes, 0, array, 1, bytes.length); + return array; + } + + @Test + public void testIp6AddressToString() throws UnknownHostException { + for (Entry testEntry : ipv6ToAddressStrings.entrySet()) { + assertEquals(testEntry.getValue(), toAddressString(InetAddress.getByAddress(testEntry.getKey()))); + } + } + + @Test + public void testIp4AddressToString() throws UnknownHostException { + for (Entry e : validIpV4Hosts.entrySet()) { + assertEquals(e.getKey(), toAddressString(InetAddress.getByAddress(unhex(e.getValue())))); + } + } + + @Test + public void testIPv4ToInt() throws UnknownHostException { + assertEquals(2130706433, ipv4AddressToInt((Inet4Address) InetAddress.getByName("127.0.0.1"))); + assertEquals(-1062731519, ipv4AddressToInt((Inet4Address) InetAddress.getByName("192.168.1.1"))); + } + + @Test + public void testIpv4MappedIp6GetByName() { + for (Entry testEntry : ipv4MappedToIPv6AddressStrings.entrySet()) { + String srcIp = testEntry.getKey(); + String dstIp = testEntry.getValue(); + Inet6Address inet6Address = getByName(srcIp, true); + assertNotNull(inet6Address, srcIp + ", " + dstIp); + assertEquals(dstIp, toAddressString(inet6Address, true), srcIp); + } + } + + @Test + public void testInvalidIpv4MappedIp6GetByName() { + for (String host : invalidIpV4Hosts.keySet()) { + assertNull(getByName(host, 
true), host); + } + + for (String host : invalidIpV6Hosts.keySet()) { + assertNull(getByName(host, true), host); + } + } + + @Test + public void testIp6InetSocketAddressToString() throws UnknownHostException { + for (Entry testEntry : ipv6ToAddressStrings.entrySet()) { + assertEquals('[' + testEntry.getValue() + "]:9999", + toSocketAddressString(new InetSocketAddress(InetAddress.getByAddress(testEntry.getKey()), 9999))); + } + } + + @Test + public void testIp4SocketAddressToString() throws UnknownHostException { + for (Entry e : validIpV4Hosts.entrySet()) { + assertEquals(e.getKey() + ":9999", + toSocketAddressString(new InetSocketAddress(InetAddress.getByAddress(unhex(e.getValue())), 9999))); + } + } + + private static void assertHexDumpEquals(String expected, byte[] actual, String message) { + assertEquals(expected, hex(actual), message); + } + + private static String hex(byte[] value) { + if (value == null) { + return null; + } + + StringBuilder buf = new StringBuilder(value.length << 1); + for (byte b: value) { + String hex = StringUtil.byteToHexString(b); + if (hex.length() == 1) { + buf.append('0'); + } + buf.append(hex); + } + return buf.toString(); + } + + private static byte[] unhex(String value) { + return value != null ? StringUtil.decodeHexDump(value) : null; + } +} diff --git a/netty-util/src/test/java/io/netty/util/NettyRuntimeTests.java b/netty-util/src/test/java/io/netty/util/NettyRuntimeTests.java new file mode 100644 index 0000000..9024c19 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/NettyRuntimeTests.java @@ -0,0 +1,206 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
/*
 * Copyright 2017 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

import io.netty.util.internal.SystemPropertyUtil;
import org.junit.jupiter.api.Test;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.atomic.AtomicReference;

import static org.hamcrest.MatcherAssert.assertThat;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasToString;
import static org.junit.jupiter.api.Assertions.assertNotNull;
import static org.junit.jupiter.api.Assertions.assertNull;
import static org.junit.jupiter.api.Assertions.fail;

/**
 * Tests for {@link NettyRuntime.AvailableProcessorsHolder}: argument validation,
 * set-after-set / set-after-get rejection, racing readers and writers, and the
 * {@code io.netty.availableProcessors} system-property override.
 */
public class NettyRuntimeTests {

    /** Non-positive processor counts must be rejected with a descriptive message. */
    @Test
    public void testIllegalSet() {
        final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
        for (final int i : new int[] { -1, 0 }) {
            try {
                holder.setAvailableProcessors(i);
                fail();
            } catch (final IllegalArgumentException e) {
                assertThat(e, hasToString(containsString("(expected: > 0)")));
            }
        }
    }

    /** The processor count may be set at most once. */
    @Test
    public void testMultipleSets() {
        final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
        holder.setAvailableProcessors(1);
        try {
            holder.setAvailableProcessors(2);
            fail();
        } catch (final IllegalStateException e) {
            assertThat(e, hasToString(containsString("availableProcessors is already set to [1], rejecting [2]")));
        }
    }

    /** Reading the value pins it; a later explicit set must fail. */
    @Test
    public void testSetAfterGet() {
        final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
        holder.availableProcessors();
        try {
            holder.setAvailableProcessors(1);
            fail();
        } catch (final IllegalStateException e) {
            assertThat(e, hasToString(containsString("availableProcessors is already set")));
        }
    }

    /** Two concurrent readers must both succeed (neither may observe an IllegalStateException). */
    @Test
    public void testRacingGetAndGet() throws InterruptedException {
        final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
        final CyclicBarrier barrier = new CyclicBarrier(3);

        final AtomicReference<IllegalStateException> firstReference =
                new AtomicReference<IllegalStateException>();
        final Runnable firstTarget = getRunnable(holder, barrier, firstReference);
        final Thread firstGet = new Thread(firstTarget);
        firstGet.start();

        final AtomicReference<IllegalStateException> secondReference =
                new AtomicReference<IllegalStateException>();
        final Runnable secondTarget = getRunnable(holder, barrier, secondReference);
        final Thread secondGet = new Thread(secondTarget);
        secondGet.start();

        // release the hounds
        await(barrier);

        // wait for the hounds
        await(barrier);

        firstGet.join();
        secondGet.join();

        assertNull(firstReference.get());
        assertNull(secondReference.get());
    }

    /**
     * Builds a runnable that reads the processor count between two barrier passes,
     * recording any {@link IllegalStateException} into {@code reference}.
     */
    private static Runnable getRunnable(
            final NettyRuntime.AvailableProcessorsHolder holder,
            final CyclicBarrier barrier,
            final AtomicReference<IllegalStateException> reference) {
        return new Runnable() {
            @Override
            public void run() {
                await(barrier);
                try {
                    holder.availableProcessors();
                } catch (final IllegalStateException e) {
                    reference.set(e);
                }
                await(barrier);
            }
        };
    }

    /**
     * A racing get and set: either the set wins (the holder reports 2048) or it loses
     * and throws — both outcomes are legal, but exactly one must occur.
     */
    @Test
    public void testRacingGetAndSet() throws InterruptedException {
        final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
        final CyclicBarrier barrier = new CyclicBarrier(3);
        final Thread get = new Thread(new Runnable() {
            @Override
            public void run() {
                await(barrier);
                holder.availableProcessors();
                await(barrier);
            }
        });
        get.start();

        final AtomicReference<IllegalStateException> setException =
                new AtomicReference<IllegalStateException>();
        final Thread set = new Thread(new Runnable() {
            @Override
            public void run() {
                await(barrier);
                try {
                    holder.setAvailableProcessors(2048);
                } catch (final IllegalStateException e) {
                    setException.set(e);
                }
                await(barrier);
            }
        });
        set.start();

        // release the hounds
        await(barrier);

        // wait for the hounds
        await(barrier);

        get.join();
        set.join();

        if (setException.get() == null) {
            assertThat(holder.availableProcessors(), equalTo(2048));
        } else {
            assertNotNull(setException.get());
        }
    }

    /** The system property must take precedence; the original value is restored afterwards. */
    @Test
    public void testGetWithSystemProperty() {
        final String availableProcessorsSystemProperty = SystemPropertyUtil.get("io.netty.availableProcessors");
        try {
            System.setProperty("io.netty.availableProcessors", "2048");
            final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
            assertThat(holder.availableProcessors(), equalTo(2048));
        } finally {
            if (availableProcessorsSystemProperty != null) {
                System.setProperty("io.netty.availableProcessors", availableProcessorsSystemProperty);
            } else {
                System.clearProperty("io.netty.availableProcessors");
            }
        }
    }

    /** Without the system property, the holder must fall back to the JVM's processor count. */
    @Test
    @SuppressForbidden(reason = "testing fallback to Runtime#availableProcessors")
    public void testGet() {
        final String availableProcessorsSystemProperty = SystemPropertyUtil.get("io.netty.availableProcessors");
        try {
            System.clearProperty("io.netty.availableProcessors");
            final NettyRuntime.AvailableProcessorsHolder holder = new NettyRuntime.AvailableProcessorsHolder();
            assertThat(holder.availableProcessors(), equalTo(Runtime.getRuntime().availableProcessors()));
        } finally {
            if (availableProcessorsSystemProperty != null) {
                System.setProperty("io.netty.availableProcessors", availableProcessorsSystemProperty);
            } else {
                System.clearProperty("io.netty.availableProcessors");
            }
        }
    }

    /** Barrier await that converts interruption/breakage into a test failure. */
    private static void await(final CyclicBarrier barrier) {
        try {
            barrier.await();
        } catch (final InterruptedException e) {
            fail(e.toString());
        } catch (final BrokenBarrierException e) {
            fail(e.toString());
        }
    }
}
/*
 * Copyright 2023 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

import io.netty.util.concurrent.FastThreadLocalThread;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;
import org.junit.jupiter.api.extension.ExtendWith;

import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicReference;

import static org.junit.jupiter.api.Assertions.assertFalse;

/**
 * Runs the whole {@link RecyclerTest} suite on {@link FastThreadLocalThread}s to
 * exercise the {@code FastThreadLocal}-backed code paths of {@link Recycler}.
 */
@ExtendWith(RunInFastThreadLocalThreadExtension.class)
public class RecyclerFastThreadLocalTest extends RecyclerTest {

    /** All worker threads in this variant are {@link FastThreadLocalThread}s. */
    @Override
    protected Thread newThread(Runnable runnable) {
        return new FastThreadLocalThread(runnable);
    }

    /**
     * Same scenario as the superclass test, but with a {@link FastThreadLocalThread}:
     * the thread must remain collectable even while a pooled object it handed out is
     * still strongly referenced, and recycling that object afterwards must not fail.
     */
    @Override
    @Test
    @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS)
    public void testThreadCanBeCollectedEvenIfHandledObjectIsReferenced() throws Exception {
        final Recycler<HandledObject> recycler = newRecycler(1024);
        final AtomicBoolean collected = new AtomicBoolean();
        final AtomicReference<HandledObject> reference = new AtomicReference<HandledObject>();
        Thread thread = new FastThreadLocalThread(new Runnable() {
            @Override
            public void run() {
                HandledObject object = recycler.get();
                // Store a reference to the HandledObject to ensure it is not collected when the run method finish.
                reference.set(object);
            }
        }) {
            @Override
            protected void finalize() throws Throwable {
                super.finalize();
                collected.set(true);
            }
        };
        assertFalse(collected.get());
        thread.start();
        thread.join();

        // Null out so it can be collected.
        thread = null;

        // Loop until the Thread was collected. If we can not collect it the Test will fail due of a timeout.
        while (!collected.get()) {
            System.gc();
            System.runFinalization();
            Thread.sleep(50);
        }

        // Now call recycle after the Thread was collected to ensure this still works...
        reference.getAndSet(null).recycle();
    }
}
class RecyclerTest { + + protected static Recycler newRecycler(int maxCapacityPerThread) { + return newRecycler(maxCapacityPerThread, 8, maxCapacityPerThread >> 1); + } + + protected static Recycler newRecycler(int maxCapacityPerThread, int ratio, int chunkSize) { + return new Recycler(maxCapacityPerThread, ratio, chunkSize) { + @Override + protected HandledObject newObject( + Recycler.Handle handle) { + return new HandledObject(handle); + } + }; + } + + protected Thread newThread(Runnable runnable) { + return new Thread(runnable); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testThreadCanBeCollectedEvenIfHandledObjectIsReferenced() throws Exception { + final Recycler recycler = newRecycler(1024); + final AtomicBoolean collected = new AtomicBoolean(); + final AtomicReference reference = new AtomicReference(); + Thread thread = new Thread(new Runnable() { + @Override + public void run() { + HandledObject object = recycler.get(); + // Store a reference to the HandledObject to ensure it is not collected when the run method finish. + reference.set(object); + } + }) { + @Override + protected void finalize() throws Throwable { + super.finalize(); + collected.set(true); + } + }; + assertFalse(collected.get()); + thread.start(); + thread.join(); + + // Null out so it can be collected. + thread = null; + + // Loop until the Thread was collected. If we can not collect it the Test will fail due of a timeout. + while (!collected.get()) { + System.gc(); + System.runFinalization(); + Thread.sleep(50); + } + + // Now call recycle after the Thread was collected to ensure this still works... 
+ reference.getAndSet(null).recycle(); + } + + @Test + public void verySmallRecycer() { + newRecycler(2, 0, 1).get(); + } + + @Test + public void testMultipleRecycle() { + Recycler recycler = newRecycler(1024); + final HandledObject object = recycler.get(); + object.recycle(); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + object.recycle(); + } + }); + } + + @Test + public void testMultipleRecycleAtDifferentThread() throws InterruptedException { + Recycler recycler = newRecycler(1024); + final HandledObject object = recycler.get(); + final AtomicReference exceptionStore = new AtomicReference(); + final Thread thread1 = newThread(new Runnable() { + @Override + public void run() { + object.recycle(); + } + }); + thread1.start(); + thread1.join(); + + final Thread thread2 = newThread(new Runnable() { + @Override + public void run() { + try { + object.recycle(); + } catch (IllegalStateException e) { + exceptionStore.set(e); + } + } + }); + thread2.start(); + thread2.join(); + HandledObject a = recycler.get(); + HandledObject b = recycler.get(); + assertNotSame(a, b); + IllegalStateException exception = exceptionStore.get(); + assertNotNull(exception); + } + + @Test + public void testMultipleRecycleAtDifferentThreadRacing() throws InterruptedException { + Recycler recycler = newRecycler(1024); + final HandledObject object = recycler.get(); + final AtomicReference exceptionStore = new AtomicReference(); + + final CountDownLatch countDownLatch = new CountDownLatch(2); + final Thread thread1 = newThread(new Runnable() { + @Override + public void run() { + try { + object.recycle(); + } catch (IllegalStateException e) { + Exception x = exceptionStore.getAndSet(e); + if (x != null) { + e.addSuppressed(x); + } + } finally { + countDownLatch.countDown(); + } + } + }); + thread1.start(); + + final Thread thread2 = newThread(new Runnable() { + @Override + public void run() { + try { + object.recycle(); + } catch 
(IllegalStateException e) { + Exception x = exceptionStore.getAndSet(e); + if (x != null) { + e.addSuppressed(x); + } + } finally { + countDownLatch.countDown(); + } + } + }); + thread2.start(); + + try { + countDownLatch.await(); + HandledObject a = recycler.get(); + HandledObject b = recycler.get(); + assertNotSame(a, b); + IllegalStateException exception = exceptionStore.get(); + if (exception != null) { + assertThat(exception).hasMessageContaining("recycled already"); + assertEquals(0, exception.getSuppressed().length); + } + } finally { + thread1.join(1000); + thread2.join(1000); + } + } + + @Test + public void testMultipleRecycleRacing() throws InterruptedException { + Recycler recycler = newRecycler(1024); + final HandledObject object = recycler.get(); + final AtomicReference exceptionStore = new AtomicReference(); + + final CountDownLatch countDownLatch = new CountDownLatch(1); + final Thread thread1 = newThread(new Runnable() { + @Override + public void run() { + try { + object.recycle(); + } catch (IllegalStateException e) { + Exception x = exceptionStore.getAndSet(e); + if (x != null) { + e.addSuppressed(x); + } + } finally { + countDownLatch.countDown(); + } + } + }); + thread1.start(); + + try { + object.recycle(); + } catch (IllegalStateException e) { + Exception x = exceptionStore.getAndSet(e); + if (x != null) { + e.addSuppressed(x); + } + } + + try { + countDownLatch.await(); + HandledObject a = recycler.get(); + HandledObject b = recycler.get(); + assertNotSame(a, b); + IllegalStateException exception = exceptionStore.get(); + assertNotNull(exception); // Object got recycled twice, so at least one of the calls must throw. 
+ } finally { + thread1.join(1000); + } + } + + @Test + public void testRecycle() { + Recycler recycler = newRecycler(1024); + HandledObject object = recycler.get(); + object.recycle(); + HandledObject object2 = recycler.get(); + assertSame(object, object2); + object2.recycle(); + } + + @Test + public void testRecycleDisable() { + Recycler recycler = newRecycler(-1); + HandledObject object = recycler.get(); + object.recycle(); + HandledObject object2 = recycler.get(); + assertNotSame(object, object2); + object2.recycle(); + } + + @Test + public void testRecycleDisableDrop() { + Recycler recycler = newRecycler(1024, 0, 16); + HandledObject object = recycler.get(); + object.recycle(); + HandledObject object2 = recycler.get(); + assertSame(object, object2); + object2.recycle(); + HandledObject object3 = recycler.get(); + assertSame(object, object3); + object3.recycle(); + } + + /** + * Test to make sure bug #2848 never happens again + * https://github.com/netty/netty/issues/2848 + */ + @Test + public void testMaxCapacity() { + testMaxCapacity(300); + Random rand = new Random(); + for (int i = 0; i < 50; i++) { + testMaxCapacity(rand.nextInt(1000) + 256); // 256 - 1256 + } + } + + private static void testMaxCapacity(int maxCapacity) { + Recycler recycler = newRecycler(maxCapacity); + HandledObject[] objects = new HandledObject[maxCapacity * 3]; + for (int i = 0; i < objects.length; i++) { + objects[i] = recycler.get(); + } + + for (int i = 0; i < objects.length; i++) { + objects[i].recycle(); + objects[i] = null; + } + + assertTrue(maxCapacity >= recycler.threadLocalSize(), + "The threadLocalSize (" + recycler.threadLocalSize() + ") must be <= maxCapacity (" + + maxCapacity + ") as we not pool all new handles internally"); + } + + @Test + public void testRecycleAtDifferentThread() throws Exception { + final Recycler recycler = newRecycler(256, 2, 16); + final HandledObject o = recycler.get(); + final HandledObject o2 = recycler.get(); + + final Thread thread = 
newThread(new Runnable() { + @Override + public void run() { + o.recycle(); + o2.recycle(); + } + }); + thread.start(); + thread.join(); + + assertSame(recycler.get(), o); + assertNotSame(recycler.get(), o2); + } + + @Test + public void testRecycleAtTwoThreadsMulti() throws Exception { + final Recycler recycler = newRecycler(256); + final HandledObject o = recycler.get(); + + ExecutorService single = Executors.newSingleThreadExecutor(new ThreadFactory() { + @Override + public Thread newThread(Runnable r) { + return RecyclerTest.this.newThread(r); + } + }); + + final CountDownLatch latch1 = new CountDownLatch(1); + single.execute(new Runnable() { + @Override + public void run() { + o.recycle(); + latch1.countDown(); + } + }); + assertTrue(latch1.await(100, TimeUnit.MILLISECONDS)); + final HandledObject o2 = recycler.get(); + // Always recycler the first object, that is Ok + assertSame(o2, o); + + final CountDownLatch latch2 = new CountDownLatch(1); + single.execute(new Runnable() { + @Override + public void run() { + //The object should be recycled + o2.recycle(); + latch2.countDown(); + } + }); + assertTrue(latch2.await(100, TimeUnit.MILLISECONDS)); + + // It should be the same object, right? + final HandledObject o3 = recycler.get(); + assertSame(o3, o); + single.shutdown(); + } + + @Test + public void testMaxCapacityWithRecycleAtDifferentThread() throws Exception { + final int maxCapacity = 4; + final Recycler recycler = newRecycler(maxCapacity, 4, 4); + + // Borrow 2 * maxCapacity objects. + // Return the half from the same thread. + // Return the other half from the different thread. 
+ + final HandledObject[] array = new HandledObject[maxCapacity * 3]; + for (int i = 0; i < array.length; i ++) { + array[i] = recycler.get(); + } + + for (int i = 0; i < maxCapacity; i ++) { + array[i].recycle(); + } + + final Thread thread = newThread(new Runnable() { + @Override + public void run() { + for (int i1 = maxCapacity; i1 < array.length; i1++) { + array[i1].recycle(); + } + } + }); + thread.start(); + thread.join(); + + assertEquals(maxCapacity * 3 / 4, recycler.threadLocalSize()); + + for (int i = 0; i < array.length; i ++) { + recycler.get(); + } + + assertEquals(0, recycler.threadLocalSize()); + } + + @Test + public void testDiscardingExceedingElementsWithRecycleAtDifferentThread() throws Exception { + final int maxCapacity = 32; + final AtomicInteger instancesCount = new AtomicInteger(0); + + final Recycler recycler = new Recycler(maxCapacity) { + @Override + protected HandledObject newObject(Recycler.Handle handle) { + instancesCount.incrementAndGet(); + return new HandledObject(handle); + } + }; + + // Borrow 2 * maxCapacity objects. + final HandledObject[] array = new HandledObject[maxCapacity * 2]; + for (int i = 0; i < array.length; i++) { + array[i] = recycler.get(); + } + + assertEquals(array.length, instancesCount.get()); + // Reset counter. + instancesCount.set(0); + + // Recycle from other thread. + final Thread thread = newThread(new Runnable() { + @Override + public void run() { + for (HandledObject object: array) { + object.recycle(); + } + } + }); + thread.start(); + thread.join(); + + assertEquals(0, instancesCount.get()); + + // Borrow 2 * maxCapacity objects. Half of them should come from + // the recycler queue, the other half should be freshly allocated. 
+ for (int i = 0; i < array.length; i++) { + recycler.get(); + } + + // The implementation uses maxCapacity / 2 as limit per WeakOrderQueue + assertTrue(array.length - maxCapacity / 2 <= instancesCount.get(), + "The instances count (" + instancesCount.get() + ") must be <= array.length (" + array.length + + ") - maxCapacity (" + maxCapacity + ") / 2 as we not pool all new handles" + + " internally"); + } + + static final class HandledObject { + Recycler.Handle handle; + + HandledObject(Recycler.Handle handle) { + this.handle = handle; + } + + void recycle() { + handle.recycle(this); + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/ResourceLeakDetectorTest.java b/netty-util/src/test/java/io/netty/util/ResourceLeakDetectorTest.java new file mode 100644 index 0000000..3947d60 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/ResourceLeakDetectorTest.java @@ -0,0 +1,263 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
/*
 * Copyright 2016 The Netty Project
 *
 * The Netty Project licenses this file to you under the Apache License,
 * version 2.0 (the "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at:
 *
 * https://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
 * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
 * License for the specific language governing permissions and limitations
 * under the License.
 */
package io.netty.util;

import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.Timeout;

import java.util.ArrayDeque;
import java.util.Queue;
import java.util.UUID;
import java.util.concurrent.CyclicBarrier;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;

import static org.assertj.core.api.Assertions.assertThat;

/**
 * Tests for {@link ResourceLeakDetector}: no false leak reports under heavy concurrent
 * track/close usage (issue #6087), and creation hints appearing in leak records.
 */
public class ResourceLeakDetectorTest {
    // Written to so the JIT cannot eliminate the allocations that feed the GC in testLeakSetupHints.
    @SuppressWarnings("unused")
    private static volatile int sink;

    /** Hammers track()/close() from 50 threads; no leak may be reported for properly closed resources. */
    @Test
    @Timeout(value = 60000, unit = TimeUnit.MILLISECONDS)
    public void testConcurrentUsage() throws Throwable {
        final AtomicBoolean finished = new AtomicBoolean();
        final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
        // With 50 threads issue #6087 is reproducible on every run.
        Thread[] threads = new Thread[50];
        final CyclicBarrier barrier = new CyclicBarrier(threads.length);
        for (int i = 0; i < threads.length; i++) {
            Thread t = new Thread(new Runnable() {
                final Queue<LeakAwareResource> resources = new ArrayDeque<LeakAwareResource>(100);

                @Override
                public void run() {
                    try {
                        barrier.await();

                        // Run up to 1000 iterations or until the test is marked as finished.
                        for (int b = 0; b < 1000 && !finished.get(); b++) {

                            // Allocate 100 LeakAwareResource per run and close them after it.
                            for (int a = 0; a < 100; a++) {
                                DefaultResource resource = new DefaultResource();
                                ResourceLeakTracker<Resource> leak = DefaultResource.detector.track(resource);
                                LeakAwareResource leakAwareResource = new LeakAwareResource(resource, leak);
                                resources.add(leakAwareResource);
                            }
                            if (closeResources(true)) {
                                finished.set(true);
                            }
                        }
                    } catch (InterruptedException e) {
                        Thread.currentThread().interrupt();
                    } catch (Throwable e) {
                        error.compareAndSet(null, e);
                    } finally {
                        // Just close all resource now without assert it to eliminate more reports.
                        closeResources(false);
                    }
                }

                private boolean closeResources(boolean checkClosed) {
                    for (;;) {
                        LeakAwareResource r = resources.poll();
                        if (r == null) {
                            return false;
                        }
                        boolean closed = r.close();
                        if (checkClosed && !closed) {
                            error.compareAndSet(null,
                                    new AssertionError("ResourceLeak.close() returned 'false' but expected 'true'"));
                            return true;
                        }
                    }
                }
            });
            threads[i] = t;
            t.start();
        }

        // Just wait until all threads are done.
        for (Thread t: threads) {
            t.join();
        }

        // Check if we had any leak reports in the ResourceLeakDetector itself
        DefaultResource.detector.assertNoErrors();

        assertNoErrors(error);
    }

    /** A deliberately leaked resource must eventually be reported, with the creation hint in its record. */
    @Timeout(10)
    @Test
    public void testLeakSetupHints() throws Throwable {
        DefaultResource.detectorWithSetupHint.initialise();
        leakResource();

        do {
            // Trigger GC.
            System.gc();
            // Track another resource to trigger refqueue visiting.
            Resource resource2 = new DefaultResource();
            DefaultResource.detectorWithSetupHint.track(resource2).close(resource2);
            // Give the GC something to work on.
            for (int i = 0; i < 1000; i++) {
                sink = System.identityHashCode(new byte[10000]);
            }
        } while (DefaultResource.detectorWithSetupHint.getLeaksFound() < 1 && !Thread.interrupted());

        assertThat(DefaultResource.detectorWithSetupHint.getLeaksFound()).isOne();
        DefaultResource.detectorWithSetupHint.assertNoErrors();
    }

    /** Tracks a resource whose tracker is intentionally never closed, producing a leak. */
    private static void leakResource() {
        Resource resource = new DefaultResource();
        // We'll never close this ResourceLeakTracker.
        DefaultResource.detectorWithSetupHint.track(resource);
    }

    // Mimic the way how we implement our classes that should help with leak detection
    private static final class LeakAwareResource implements Resource {
        private final Resource resource;
        private final ResourceLeakTracker<Resource> leak;

        LeakAwareResource(Resource resource, ResourceLeakTracker<Resource> leak) {
            this.resource = resource;
            this.leak = leak;
        }

        @Override
        public boolean close() {
            // Using ResourceLeakDetector.close(...) to prove this fixes the leak problem reported
            // in https://github.com/netty/netty/issues/6034 .
            //
            // The following implementation would produce a leak:
            //     return leak.close();
            return leak.close(resource);
        }
    }

    private static final class DefaultResource implements Resource {
        // Sample every allocation
        static final TestResourceLeakDetector detector = new TestResourceLeakDetector(
                Resource.class, 1, Integer.MAX_VALUE);
        static final CreationRecordLeakDetector detectorWithSetupHint =
                new CreationRecordLeakDetector(Resource.class, 1);

        @Override
        public boolean close() {
            return true;
        }
    }

    private interface Resource {
        boolean close();
    }

    /** Rethrows the first error captured by any worker thread, failing the test. */
    private static void assertNoErrors(AtomicReference<Throwable> ref) throws Throwable {
        Throwable error = ref.get();
        if (error != null) {
            throw error;
        }
    }

    /** Detector that converts every leak report into a stored AssertionError instead of logging. */
    private static final class TestResourceLeakDetector extends ResourceLeakDetector<Resource> {

        private final AtomicReference<Throwable> error = new AtomicReference<Throwable>();

        TestResourceLeakDetector(Class<?> resourceType, int samplingInterval, long maxActive) {
            super(resourceType, samplingInterval, maxActive);
        }

        @Override
        protected void reportTracedLeak(String resourceType, String records) {
            reportError(new AssertionError("Leak reported for '" + resourceType + "':\n" + records));
        }

        @Override
        protected void reportUntracedLeak(String resourceType) {
            reportError(new AssertionError("Leak reported for '" + resourceType + '\''));
        }

        @Override
        protected void reportInstancesLeak(String resourceType) {
            reportError(new AssertionError("Leak reported for '" + resourceType + '\''));
        }

        private void reportError(AssertionError cause) {
            error.compareAndSet(null, cause);
        }

        void assertNoErrors() throws Throwable {
            ResourceLeakDetectorTest.assertNoErrors(error);
        }
    }

    /**
     * Detector that supplies a per-run canary string as the creation hint and verifies
     * that any traced leak record contains it.
     */
    private static final class CreationRecordLeakDetector extends ResourceLeakDetector<Resource> {
        private String canaryString;

        private final AtomicReference<Throwable> error = new AtomicReference<Throwable>();
        private final AtomicInteger leaksFound = new AtomicInteger(0);

        CreationRecordLeakDetector(Class<?> resourceType, int samplingInterval) {
            super(resourceType, samplingInterval);
        }

        public void initialise() {
            canaryString = "creation-canary-" + UUID.randomUUID();
            leaksFound.set(0);
        }

        @Override
        protected boolean needReport() {
            // Always report so the canary check below runs for every leak.
            return true;
        }

        @Override
        protected void reportTracedLeak(String resourceType, String records) {
            if (!records.contains(canaryString)) {
                reportError(new AssertionError("Leak records did not contain canary string"));
            }
            leaksFound.incrementAndGet();
        }

        @Override
        protected void reportUntracedLeak(String resourceType) {
            reportError(new AssertionError("Got untraced leak w/o canary string"));
            leaksFound.incrementAndGet();
        }

        private void reportError(AssertionError cause) {
            error.compareAndSet(null, cause);
        }

        @Override
        protected Object getInitialHint(String resourceType) {
            return canaryString;
        }

        int getLeaksFound() {
            return leaksFound.get();
        }

        void assertNoErrors() throws Throwable {
            ResourceLeakDetectorTest.assertNoErrors(error);
        }
    }
}
-0,0 +1,58 @@ +/* + * Copyright 2023 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util; + +import io.netty.util.concurrent.FastThreadLocalThread; +import org.junit.jupiter.api.extension.ExtensionContext; +import org.junit.jupiter.api.extension.InvocationInterceptor; +import org.junit.jupiter.api.extension.ReflectiveInvocationContext; + +import java.lang.reflect.Method; +import java.util.concurrent.atomic.AtomicReference; + +/** + * Annotate your test class with {@code @ExtendWith(RunInFastThreadLocalThreadExtension.class)} to have all test methods + * run in a {@link io.netty.util.concurrent.FastThreadLocalThread}. + *

+ * This extension implementation is modified from the JUnit 5 + * + * intercepting invocations example. + */ +public class RunInFastThreadLocalThreadExtension implements InvocationInterceptor { + @Override + public void interceptTestMethod( + final Invocation invocation, + final ReflectiveInvocationContext invocationContext, + final ExtensionContext extensionContext) throws Throwable { + final AtomicReference throwable = new AtomicReference(); + Thread thread = new FastThreadLocalThread(new Runnable() { + @Override + public void run() { + try { + invocation.proceed(); + } catch (Throwable t) { + throwable.set(t); + } + } + }); + thread.start(); + thread.join(); + Throwable t = throwable.get(); + if (t != null) { + throw t; + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/ThreadDeathWatcherTest.java b/netty-util/src/test/java/io/netty/util/ThreadDeathWatcherTest.java new file mode 100644 index 0000000..9a0609a --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/ThreadDeathWatcherTest.java @@ -0,0 +1,144 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.fail; + +public class ThreadDeathWatcherTest { + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testWatch() throws Exception { + final CountDownLatch latch = new CountDownLatch(1); + final Thread t = new Thread() { + @Override + public void run() { + for (;;) { + try { + Thread.sleep(1000); + } catch (InterruptedException ignore) { + break; + } + } + } + }; + + final Runnable task = new Runnable() { + @Override + public void run() { + if (!t.isAlive()) { + latch.countDown(); + } + } + }; + + try { + ThreadDeathWatcher.watch(t, task); + fail("must reject to watch a non-alive thread."); + } catch (IllegalArgumentException e) { + // expected + } + + t.start(); + ThreadDeathWatcher.watch(t, task); + + // As long as the thread is alive, the task should not run. + assertThat(latch.await(750, TimeUnit.MILLISECONDS), is(false)); + + // Interrupt the thread to terminate it. + t.interrupt(); + + // The task must be run on termination. + latch.await(); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testUnwatch() throws Exception { + final AtomicBoolean run = new AtomicBoolean(); + final Thread t = new Thread() { + @Override + public void run() { + for (;;) { + try { + Thread.sleep(1000); + } catch (InterruptedException ignore) { + break; + } + } + } + }; + + final Runnable task = new Runnable() { + @Override + public void run() { + run.set(true); + } + }; + + t.start(); + + // Watch and then unwatch. 
+ ThreadDeathWatcher.watch(t, task); + ThreadDeathWatcher.unwatch(t, task); + + // Interrupt the thread to terminate it. + t.interrupt(); + + // Wait until the thread dies. + t.join(); + + // Wait until the watcher thread terminates itself. + assertThat(ThreadDeathWatcher.awaitInactivity(Long.MAX_VALUE, TimeUnit.SECONDS), is(true)); + + // And the task should not run. + assertThat(run.get(), is(false)); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testThreadGroup() throws InterruptedException { + final ThreadGroup group = new ThreadGroup("group"); + final AtomicReference capturedGroup = new AtomicReference(); + final Thread thread = new Thread(group, new Runnable() { + @Override + public void run() { + final Thread t = ThreadDeathWatcher.threadFactory.newThread(new Runnable() { + @Override + public void run() { + } + }); + capturedGroup.set(t.getThreadGroup()); + } + }); + thread.start(); + thread.join(); + + assertEquals(group, capturedGroup.get()); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/AbstractScheduledEventExecutorTest.java b/netty-util/src/test/java/io/netty/util/concurrent/AbstractScheduledEventExecutorTest.java new file mode 100644 index 0000000..13c4447 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/AbstractScheduledEventExecutorTest.java @@ -0,0 +1,173 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.concurrent.Callable; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class AbstractScheduledEventExecutorTest { + private static final Runnable TEST_RUNNABLE = new Runnable() { + + @Override + public void run() { + } + }; + + private static final Callable TEST_CALLABLE = Executors.callable(TEST_RUNNABLE); + + @Test + public void testScheduleRunnableZero() { + TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + ScheduledFuture future = executor.schedule(TEST_RUNNABLE, 0, TimeUnit.NANOSECONDS); + assertEquals(0, future.getDelay(TimeUnit.NANOSECONDS)); + assertNotNull(executor.pollScheduledTask()); + assertNull(executor.pollScheduledTask()); + } + + @Test + public void testScheduleRunnableNegative() { + TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + ScheduledFuture future = executor.schedule(TEST_RUNNABLE, -1, TimeUnit.NANOSECONDS); + assertEquals(0, future.getDelay(TimeUnit.NANOSECONDS)); + assertNotNull(executor.pollScheduledTask()); + assertNull(executor.pollScheduledTask()); + } + + @Test + public void testScheduleCallableZero() { + TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + ScheduledFuture future = executor.schedule(TEST_CALLABLE, 0, TimeUnit.NANOSECONDS); + assertEquals(0, future.getDelay(TimeUnit.NANOSECONDS)); + assertNotNull(executor.pollScheduledTask()); + assertNull(executor.pollScheduledTask()); + } + 
+ @Test + public void testScheduleCallableNegative() { + TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + ScheduledFuture future = executor.schedule(TEST_CALLABLE, -1, TimeUnit.NANOSECONDS); + assertEquals(0, future.getDelay(TimeUnit.NANOSECONDS)); + assertNotNull(executor.pollScheduledTask()); + assertNull(executor.pollScheduledTask()); + } + + @Test + public void testScheduleAtFixedRateRunnableZero() { + final TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + executor.scheduleAtFixedRate(TEST_RUNNABLE, 0, 0, TimeUnit.DAYS); + } + }); + } + + @Test + public void testScheduleAtFixedRateRunnableNegative() { + final TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + executor.scheduleAtFixedRate(TEST_RUNNABLE, 0, -1, TimeUnit.DAYS); + } + }); + } + + @Test + public void testScheduleWithFixedDelayZero() { + final TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + executor.scheduleWithFixedDelay(TEST_RUNNABLE, 0, -1, TimeUnit.DAYS); + } + }); + } + + @Test + public void testScheduleWithFixedDelayNegative() { + final TestScheduledEventExecutor executor = new TestScheduledEventExecutor(); + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + executor.scheduleWithFixedDelay(TEST_RUNNABLE, 0, -1, TimeUnit.DAYS); + } + }); + } + + @Test + public void testDeadlineNanosNotOverflow() { + Assertions.assertEquals(Long.MAX_VALUE, AbstractScheduledEventExecutor.deadlineNanos( + AbstractScheduledEventExecutor.defaultCurrentTimeNanos(), Long.MAX_VALUE)); + } + + private static final class TestScheduledEventExecutor extends 
AbstractScheduledEventExecutor { + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public boolean inEventLoop(Thread thread) { + return true; + } + + @Override + public void shutdown() { + // NOOP + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + throw new UnsupportedOperationException(); + } + + @Override + public Future terminationFuture() { + throw new UnsupportedOperationException(); + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) { + return false; + } + + @Override + public void execute(Runnable command) { + throw new UnsupportedOperationException(); + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java b/netty-util/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java new file mode 100644 index 0000000..0ea415d --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/DefaultPromiseTest.java @@ -0,0 +1,643 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import io.netty.util.Signal; +import io.netty.util.internal.logging.InternalLogger; +import io.netty.util.internal.logging.InternalLoggerFactory; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.util.HashMap; +import java.util.Map; +import java.util.concurrent.BlockingQueue; +import java.util.concurrent.Callable; +import java.util.concurrent.CancellationException; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; + +import static java.lang.Math.max; +import static org.assertj.core.api.Assertions.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class DefaultPromiseTest { + private static final InternalLogger logger = InternalLoggerFactory.getInstance(DefaultPromiseTest.class); + private static int stackOverflowDepth; + + @BeforeAll + public static void beforeClass() { + try { + findStackOverflowDepth(); + throw new IllegalStateException("Expected StackOverflowError but didn't get it?!"); + } catch (StackOverflowError e) { + logger.debug("StackOverflowError depth: {}", stackOverflowDepth); + } + } + + @SuppressWarnings("InfiniteRecursion") + private static void findStackOverflowDepth() { + ++stackOverflowDepth; + findStackOverflowDepth(); + } + + private static int stackOverflowTestDepth() { 
+ return max(stackOverflowDepth << 1, stackOverflowDepth); + } + + private static class RejectingEventExecutor extends AbstractEventExecutor { + @Override + public boolean isShuttingDown() { + return false; + } + + @Override + public Future shutdownGracefully(long quietPeriod, long timeout, TimeUnit unit) { + return null; + } + + @Override + public Future terminationFuture() { + return null; + } + + @Override + public void shutdown() { + } + + @Override + public boolean isShutdown() { + return false; + } + + @Override + public boolean isTerminated() { + return false; + } + + @Override + public boolean awaitTermination(long timeout, TimeUnit unit) throws InterruptedException { + return false; + } + + @Override + public ScheduledFuture schedule(Runnable command, long delay, TimeUnit unit) { + return fail("Cannot schedule commands"); + } + + @Override + public ScheduledFuture schedule(Callable callable, long delay, TimeUnit unit) { + return fail("Cannot schedule commands"); + } + + @Override + public ScheduledFuture scheduleAtFixedRate(Runnable command, long initialDelay, long period, TimeUnit unit) { + return fail("Cannot schedule commands"); + } + + @Override + public ScheduledFuture scheduleWithFixedDelay(Runnable command, long initialDelay, long delay, + TimeUnit unit) { + return fail("Cannot schedule commands"); + } + + @Override + public boolean inEventLoop(Thread thread) { + return false; + } + + @Override + public void execute(Runnable command) { + fail("Cannot schedule commands"); + } + } + + @Test + public void testCancelDoesNotScheduleWhenNoListeners() { + EventExecutor executor = new RejectingEventExecutor(); + + Promise promise = new DefaultPromise(executor); + assertTrue(promise.cancel(false)); + assertTrue(promise.isCancelled()); + } + + @Test + public void testSuccessDoesNotScheduleWhenNoListeners() { + EventExecutor executor = new RejectingEventExecutor(); + + Object value = new Object(); + Promise promise = new DefaultPromise(executor); + 
promise.setSuccess(value); + assertSame(value, promise.getNow()); + } + + @Test + public void testFailureDoesNotScheduleWhenNoListeners() { + EventExecutor executor = new RejectingEventExecutor(); + + Exception cause = new Exception(); + Promise promise = new DefaultPromise(executor); + promise.setFailure(cause); + assertSame(cause, promise.cause()); + } + + @Test + public void testCancellationExceptionIsThrownWhenBlockingGet() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + assertTrue(promise.cancel(false)); + assertThrows(CancellationException.class, new Executable() { + @Override + public void execute() throws Throwable { + promise.get(); + } + }); + } + + @Test + public void testCancellationExceptionIsThrownWhenBlockingGetWithTimeout() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + assertTrue(promise.cancel(false)); + assertThrows(CancellationException.class, new Executable() { + @Override + public void execute() throws Throwable { + promise.get(1, TimeUnit.SECONDS); + } + }); + } + + @Test + public void testCancellationExceptionIsReturnedAsCause() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + assertTrue(promise.cancel(false)); + assertThat(promise.cause()).isInstanceOf(CancellationException.class); + } + + @Test + public void testStackOverflowWithImmediateEventExecutorA() throws Exception { + testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, true); + testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, false); + } + + @Test + public void testNoStackOverflowWithDefaultEventExecutorA() throws Exception { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + try { + EventExecutor executor = new DefaultEventExecutor(executorService); + try { + testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), executor, true); + 
testStackOverFlowChainedFuturesA(stackOverflowTestDepth(), executor, false); + } finally { + executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + } finally { + executorService.shutdown(); + } + } + + @Test + public void testNoStackOverflowWithImmediateEventExecutorB() throws Exception { + testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, true); + testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), ImmediateEventExecutor.INSTANCE, false); + } + + @Test + public void testNoStackOverflowWithDefaultEventExecutorB() throws Exception { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + try { + EventExecutor executor = new DefaultEventExecutor(executorService); + try { + testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), executor, true); + testStackOverFlowChainedFuturesB(stackOverflowTestDepth(), executor, false); + } finally { + executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + } finally { + executorService.shutdown(); + } + } + + @Test + public void testListenerNotifyOrder() throws Exception { + EventExecutor executor = new TestEventExecutor(); + try { + final BlockingQueue> listeners = new LinkedBlockingQueue>(); + int runs = 100000; + + for (int i = 0; i < runs; i++) { + final Promise promise = new DefaultPromise(executor); + final FutureListener listener1 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + listeners.add(this); + } + }; + final FutureListener listener2 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + listeners.add(this); + } + }; + final FutureListener listener4 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + listeners.add(this); + } + }; + final FutureListener listener3 = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + 
listeners.add(this); + future.addListener(listener4); + } + }; + + GlobalEventExecutor.INSTANCE.execute(new Runnable() { + @Override + public void run() { + promise.setSuccess(null); + } + }); + + promise.addListener(listener1).addListener(listener2).addListener(listener3); + + assertSame(listener1, listeners.take(), "Fail 1 during run " + i + " / " + runs); + assertSame(listener2, listeners.take(), "Fail 2 during run " + i + " / " + runs); + assertSame(listener3, listeners.take(), "Fail 3 during run " + i + " / " + runs); + assertSame(listener4, listeners.take(), "Fail 4 during run " + i + " / " + runs); + assertTrue(listeners.isEmpty(), "Fail during run " + i + " / " + runs); + } + } finally { + executor.shutdownGracefully(0, 0, TimeUnit.SECONDS).sync(); + } + } + + @Test + public void testListenerNotifyLater() throws Exception { + // Testing first execution path in DefaultPromise + testListenerNotifyLater(1); + + // Testing second execution path in DefaultPromise + testListenerNotifyLater(2); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testPromiseListenerAddWhenCompleteFailure() throws Exception { + testPromiseListenerAddWhenComplete(fakeException()); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testPromiseListenerAddWhenCompleteSuccess() throws Exception { + testPromiseListenerAddWhenComplete(null); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testLateListenerIsOrderedCorrectlySuccess() throws InterruptedException { + testLateListenerIsOrderedCorrectly(null); + } + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testLateListenerIsOrderedCorrectlyFailure() throws InterruptedException { + testLateListenerIsOrderedCorrectly(fakeException()); + } + + @Test + public void testSignalRace() { + final long wait = TimeUnit.NANOSECONDS.convert(10, TimeUnit.SECONDS); + EventExecutor executor = null; + try { + executor = new 
TestEventExecutor(); + + final int numberOfAttempts = 4096; + final Map> promises = new HashMap>(); + for (int i = 0; i < numberOfAttempts; i++) { + final DefaultPromise promise = new DefaultPromise(executor); + final Thread thread = new Thread(new Runnable() { + @Override + public void run() { + promise.setSuccess(null); + } + }); + promises.put(thread, promise); + } + + for (final Map.Entry> promise : promises.entrySet()) { + promise.getKey().start(); + final long start = System.nanoTime(); + promise.getValue().awaitUninterruptibly(wait, TimeUnit.NANOSECONDS); + assertThat(System.nanoTime() - start).isLessThan(wait); + } + } finally { + if (executor != null) { + executor.shutdownGracefully(); + } + } + } + + @Test + public void signalUncancellableCompletionValue() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + promise.setSuccess(Signal.valueOf(DefaultPromise.class, "UNCANCELLABLE")); + assertTrue(promise.isDone()); + assertTrue(promise.isSuccess()); + } + + @Test + public void signalSuccessCompletionValue() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + promise.setSuccess(Signal.valueOf(DefaultPromise.class, "SUCCESS")); + assertTrue(promise.isDone()); + assertTrue(promise.isSuccess()); + } + + @Test + public void setUncancellableGetNow() { + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + assertNull(promise.getNow()); + assertTrue(promise.setUncancellable()); + assertNull(promise.getNow()); + assertFalse(promise.isDone()); + assertFalse(promise.isSuccess()); + + promise.setSuccess("success"); + + assertTrue(promise.isDone()); + assertTrue(promise.isSuccess()); + assertEquals("success", promise.getNow()); + } + + private static void testStackOverFlowChainedFuturesA(int promiseChainLength, final EventExecutor executor, + boolean runTestInExecutorThread) + throws InterruptedException { + final Promise[] p = new DefaultPromise[promiseChainLength]; + final 
CountDownLatch latch = new CountDownLatch(promiseChainLength); + + if (runTestInExecutorThread) { + executor.execute(new Runnable() { + @Override + public void run() { + testStackOverFlowChainedFuturesA(executor, p, latch); + } + }); + } else { + testStackOverFlowChainedFuturesA(executor, p, latch); + } + + assertTrue(latch.await(2, TimeUnit.SECONDS)); + for (int i = 0; i < p.length; ++i) { + assertTrue(p[i].isSuccess(), "index " + i); + } + } + + private static void testStackOverFlowChainedFuturesA(EventExecutor executor, final Promise[] p, + final CountDownLatch latch) { + for (int i = 0; i < p.length; i ++) { + final int finalI = i; + p[i] = new DefaultPromise(executor); + p[i].addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (finalI + 1 < p.length) { + p[finalI + 1].setSuccess(null); + } + latch.countDown(); + } + }); + } + + p[0].setSuccess(null); + } + + private static void testStackOverFlowChainedFuturesB(int promiseChainLength, final EventExecutor executor, + boolean runTestInExecutorThread) + throws InterruptedException { + final Promise[] p = new DefaultPromise[promiseChainLength]; + final CountDownLatch latch = new CountDownLatch(promiseChainLength); + + if (runTestInExecutorThread) { + executor.execute(new Runnable() { + @Override + public void run() { + testStackOverFlowChainedFuturesB(executor, p, latch); + } + }); + } else { + testStackOverFlowChainedFuturesB(executor, p, latch); + } + + assertTrue(latch.await(2, TimeUnit.SECONDS)); + for (int i = 0; i < p.length; ++i) { + assertTrue(p[i].isSuccess(), "index " + i); + } + } + + private static void testStackOverFlowChainedFuturesB(EventExecutor executor, final Promise[] p, + final CountDownLatch latch) { + for (int i = 0; i < p.length; i ++) { + final int finalI = i; + p[i] = new DefaultPromise(executor); + p[i].addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + 
future.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + if (finalI + 1 < p.length) { + p[finalI + 1].setSuccess(null); + } + latch.countDown(); + } + }); + } + }); + } + + p[0].setSuccess(null); + } + + /** + * This test is mean to simulate the following sequence of events, which all take place on the I/O thread: + *
+     * <ol>
+     * <li>A write is done</li>
+     * <li>The write operation completes, and the promise state is changed to done</li>
+     * <li>A listener is added to the return from the write. The {@link FutureListener#operationComplete(Future)}
+     * updates state which must be invoked before the response to the previous write is read.</li>
+     * <li>The write operation</li>
+     * </ol>
+ */ + private static void testLateListenerIsOrderedCorrectly(Throwable cause) throws InterruptedException { + final EventExecutor executor = new TestEventExecutor(); + try { + final AtomicInteger state = new AtomicInteger(); + final CountDownLatch latch1 = new CountDownLatch(1); + final CountDownLatch latch2 = new CountDownLatch(2); + final Promise promise = new DefaultPromise(executor); + + // Add a listener before completion so "lateListener" is used next time we add a listener. + promise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + assertTrue(state.compareAndSet(0, 1)); + } + }); + + // Simulate write operation completing, which will execute listeners in another thread. + if (cause == null) { + promise.setSuccess(null); + } else { + promise.setFailure(cause); + } + + // Add a "late listener" + promise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + assertTrue(state.compareAndSet(1, 2)); + latch1.countDown(); + } + }); + + // Wait for the listeners and late listeners to be completed. + latch1.await(); + assertEquals(2, state.get()); + + // This is the important listener. A late listener that is added after all late listeners + // have completed, and needs to update state before a read operation (on the same executor). + executor.execute(new Runnable() { + @Override + public void run() { + promise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + assertTrue(state.compareAndSet(2, 3)); + latch2.countDown(); + } + }); + } + }); + + // Simulate a read operation being queued up in the executor. + executor.execute(new Runnable() { + @Override + public void run() { + // This is the key, we depend upon the state being set in the next listener. 
+ assertEquals(3, state.get()); + latch2.countDown(); + } + }); + + latch2.await(); + } finally { + executor.shutdownGracefully(0, 0, TimeUnit.SECONDS).sync(); + } + } + + private static void testPromiseListenerAddWhenComplete(Throwable cause) throws InterruptedException { + final CountDownLatch latch = new CountDownLatch(1); + final Promise promise = new DefaultPromise(ImmediateEventExecutor.INSTANCE); + promise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + promise.addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + latch.countDown(); + } + }); + } + }); + if (cause == null) { + promise.setSuccess(null); + } else { + promise.setFailure(cause); + } + latch.await(); + } + + private static void testListenerNotifyLater(final int numListenersBefore) throws Exception { + EventExecutor executor = new TestEventExecutor(); + int expectedCount = numListenersBefore + 2; + final CountDownLatch latch = new CountDownLatch(expectedCount); + final FutureListener listener = new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + latch.countDown(); + } + }; + final Promise promise = new DefaultPromise(executor); + executor.execute(new Runnable() { + @Override + public void run() { + for (int i = 0; i < numListenersBefore; i++) { + promise.addListener(listener); + } + promise.setSuccess(null); + + GlobalEventExecutor.INSTANCE.execute(new Runnable() { + @Override + public void run() { + promise.addListener(listener); + } + }); + promise.addListener(listener); + } + }); + + assertTrue(latch.await(5, TimeUnit.SECONDS), + "Should have notified " + expectedCount + " listeners"); + executor.shutdownGracefully().sync(); + } + + private static final class TestEventExecutor extends SingleThreadEventExecutor { + TestEventExecutor() { + super(null, Executors.defaultThreadFactory(), true); + } + + @Override + protected 
void run() { + for (;;) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + updateLastExecutionTime(); + } + + if (confirmShutdown()) { + break; + } + } + } + } + + private static RuntimeException fakeException() { + return new RuntimeException("fake exception"); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/DefaultThreadFactoryTest.java b/netty-util/src/test/java/io/netty/util/concurrent/DefaultThreadFactoryTest.java new file mode 100644 index 0000000..3d2ae60 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/DefaultThreadFactoryTest.java @@ -0,0 +1,297 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import org.junit.jupiter.api.Assumptions; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.security.Permission; +import java.util.concurrent.Callable; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class DefaultThreadFactoryTest { + + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testDescendantThreadGroups() throws InterruptedException { + final SecurityManager current = System.getSecurityManager(); + + boolean securityManagerSet = false; + try { + try { + // install security manager that only allows parent thread groups to mess with descendant thread groups + System.setSecurityManager(new SecurityManager() { + @Override + public void checkAccess(ThreadGroup g) { + final ThreadGroup source = Thread.currentThread().getThreadGroup(); + + if (source != null) { + if (!source.parentOf(g)) { + throw new SecurityException("source group is not an ancestor of the target group"); + } + super.checkAccess(g); + } + } + + // so we can restore the security manager at the end of the test + @Override + public void checkPermission(Permission perm) { + } + }); + } catch (UnsupportedOperationException e) { + Assumptions.assumeFalse(true, "Setting SecurityManager not supported"); + } + securityManagerSet = true; + + // holder for the thread factory, plays the role of a global singleton + final AtomicReference factory = new AtomicReference(); + final AtomicInteger counter = new AtomicInteger(); + final Runnable task = new Runnable() { + @Override + public void run() { + counter.incrementAndGet(); + } + }; + + final AtomicReference interrupted = new AtomicReference(); + + // create the thread factory, since we are running the thread group brother, 
the thread + // factory will now forever be tied to that group + // we then create a thread from the factory to run a "task" for us + final Thread first = new Thread(new ThreadGroup("brother"), new Runnable() { + @Override + public void run() { + factory.set(new DefaultThreadFactory("test", false, Thread.NORM_PRIORITY, null)); + final Thread t = factory.get().newThread(task); + t.start(); + try { + t.join(); + } catch (InterruptedException e) { + interrupted.set(e); + Thread.currentThread().interrupt(); + } + } + }); + first.start(); + first.join(); + + assertNull(interrupted.get()); + + // now we will use factory again, this time from a sibling thread group sister + // if DefaultThreadFactory is "sticky" about thread groups, a security manager + // that forbids sibling thread groups from messing with each other will strike this down + final Thread second = new Thread(new ThreadGroup("sister"), new Runnable() { + @Override + public void run() { + final Thread t = factory.get().newThread(task); + t.start(); + try { + t.join(); + } catch (InterruptedException e) { + interrupted.set(e); + Thread.currentThread().interrupt(); + } + } + }); + second.start(); + second.join(); + + assertNull(interrupted.get()); + + assertEquals(2, counter.get()); + } finally { + if (securityManagerSet) { + System.setSecurityManager(current); + } + } + } + + // test that when DefaultThreadFactory is constructed with a sticky thread group, threads + // created by it have the sticky thread group + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testDefaultThreadFactoryStickyThreadGroupConstructor() throws InterruptedException { + final ThreadGroup sticky = new ThreadGroup("sticky"); + runStickyThreadGroupTest( + new Callable() { + @Override + public DefaultThreadFactory call() throws Exception { + return new DefaultThreadFactory("test", false, Thread.NORM_PRIORITY, sticky); + } + }, + sticky); + } + + // test that when a security manager is installed that provides a 
ThreadGroup, DefaultThreadFactory inherits from + // the security manager + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testDefaultThreadFactoryInheritsThreadGroupFromSecurityManager() throws InterruptedException { + final SecurityManager current = System.getSecurityManager(); + + boolean securityManagerSet = false; + try { + final ThreadGroup sticky = new ThreadGroup("sticky"); + try { + System.setSecurityManager(new SecurityManager() { + @Override + public ThreadGroup getThreadGroup() { + return sticky; + } + + // so we can restore the security manager at the end of the test + @Override + public void checkPermission(Permission perm) { + } + }); + } catch (UnsupportedOperationException e) { + Assumptions.assumeFalse(true, "Setting SecurityManager not supported"); + } + securityManagerSet = true; + + runStickyThreadGroupTest( + new Callable() { + @Override + public DefaultThreadFactory call() throws Exception { + return new DefaultThreadFactory("test"); + } + }, + sticky); + } finally { + if (securityManagerSet) { + System.setSecurityManager(current); + } + } + } + + private static void runStickyThreadGroupTest( + final Callable callable, + final ThreadGroup expected) throws InterruptedException { + final AtomicReference captured = new AtomicReference(); + final AtomicReference exception = new AtomicReference(); + + final Thread first = new Thread(new ThreadGroup("wrong"), new Runnable() { + @Override + public void run() { + final DefaultThreadFactory factory; + try { + factory = callable.call(); + } catch (Exception e) { + exception.set(e); + throw new RuntimeException(e); + } + final Thread t = factory.newThread(new Runnable() { + @Override + public void run() { + } + }); + captured.set(t.getThreadGroup()); + } + }); + first.start(); + first.join(); + + assertNull(exception.get()); + + assertEquals(expected, captured.get()); + } + + // test that when DefaultThreadFactory is constructed without a sticky thread group, threads + // 
created by it inherit the correct thread group + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testDefaultThreadFactoryNonStickyThreadGroupConstructor() throws InterruptedException { + + final AtomicReference factory = new AtomicReference(); + final AtomicReference firstCaptured = new AtomicReference(); + + final ThreadGroup firstGroup = new ThreadGroup("first"); + final Thread first = new Thread(firstGroup, new Runnable() { + @Override + public void run() { + factory.set(new DefaultThreadFactory("sticky", false, Thread.NORM_PRIORITY, null)); + final Thread t = factory.get().newThread(new Runnable() { + @Override + public void run() { + } + }); + firstCaptured.set(t.getThreadGroup()); + } + }); + first.start(); + first.join(); + + assertEquals(firstGroup, firstCaptured.get()); + + final AtomicReference secondCaptured = new AtomicReference(); + + final ThreadGroup secondGroup = new ThreadGroup("second"); + final Thread second = new Thread(secondGroup, new Runnable() { + @Override + public void run() { + final Thread t = factory.get().newThread(new Runnable() { + @Override + public void run() { + } + }); + secondCaptured.set(t.getThreadGroup()); + } + }); + second.start(); + second.join(); + + assertEquals(secondGroup, secondCaptured.get()); + } + + // test that when DefaultThreadFactory is constructed without a sticky thread group, threads + // created by it inherit the correct thread group + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testCurrentThreadGroupIsUsed() throws InterruptedException { + final AtomicReference factory = new AtomicReference(); + final AtomicReference firstCaptured = new AtomicReference(); + + final ThreadGroup group = new ThreadGroup("first"); + final Thread first = new Thread(group, new Runnable() { + @Override + public void run() { + final Thread current = Thread.currentThread(); + firstCaptured.set(current.getThreadGroup()); + factory.set(new DefaultThreadFactory("sticky", 
false)); + } + }); + first.start(); + first.join(); + assertEquals(group, firstCaptured.get()); + + ThreadGroup currentThreadGroup = Thread.currentThread().getThreadGroup(); + Thread second = factory.get().newThread(new Runnable() { + @Override + public void run() { + // NOOP. + } + }); + // Assert the group before starting the thread: getThreadGroup() returns null once a thread has terminated. + assertEquals(currentThreadGroup, second.getThreadGroup()); + second.start(); + second.join(); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java b/netty-util/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java new file mode 100644 index 0000000..8c6a054 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/FastThreadLocalTest.java @@ -0,0 +1,363 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import io.netty.util.internal.InternalThreadLocalMap; +import io.netty.util.internal.ObjectCleaner; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Disabled; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.condition.EnabledIfEnvironmentVariable; + +import java.lang.reflect.Field; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.instanceOf; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.nullValue; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class FastThreadLocalTest { + @BeforeEach + public void setUp() { + FastThreadLocal.removeAll(); + assertThat(FastThreadLocal.size(), is(0)); + } + + @Test + public void testGetIfExists() { + FastThreadLocal threadLocal = new FastThreadLocal() { + @Override + protected Boolean initialValue() { + return Boolean.TRUE; + } + }; + + assertNull(threadLocal.getIfExists()); + assertTrue(threadLocal.get()); + assertTrue(threadLocal.getIfExists()); + + FastThreadLocal.removeAll(); + assertNull(threadLocal.getIfExists()); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testRemoveAll() throws Exception { + final AtomicBoolean removed = new AtomicBoolean(); + final FastThreadLocal var = new FastThreadLocal() { + @Override + protected void onRemoval(Boolean value) { + removed.set(true); + } + }; + + // Initialize a thread-local variable. 
+ assertThat(var.get(), is(nullValue())); + assertThat(FastThreadLocal.size(), is(1)); + + // And then remove it. + FastThreadLocal.removeAll(); + assertThat(removed.get(), is(true)); + assertThat(FastThreadLocal.size(), is(0)); + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testRemoveAllFromFTLThread() throws Throwable { + final AtomicReference throwable = new AtomicReference(); + final Thread thread = new FastThreadLocalThread() { + @Override + public void run() { + try { + testRemoveAll(); + } catch (Throwable t) { + throwable.set(t); + } + } + }; + + thread.start(); + thread.join(); + + Throwable t = throwable.get(); + if (t != null) { + throw t; + } + } + + @Test + public void testMultipleSetRemove() throws Exception { + final FastThreadLocal threadLocal = new FastThreadLocal(); + final Runnable runnable = new Runnable() { + @Override + public void run() { + threadLocal.set("1"); + threadLocal.remove(); + threadLocal.set("2"); + threadLocal.remove(); + } + }; + + final int sizeWhenStart = ObjectCleaner.getLiveSetCount(); + Thread thread = new Thread(runnable); + thread.start(); + thread.join(); + + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + + Thread thread2 = new Thread(runnable); + thread2.start(); + thread2.join(); + + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + } + + @Test + public void testMultipleSetRemove_multipleThreadLocal() throws Exception { + final FastThreadLocal threadLocal = new FastThreadLocal(); + final FastThreadLocal threadLocal2 = new FastThreadLocal(); + final Runnable runnable = new Runnable() { + @Override + public void run() { + threadLocal.set("1"); + threadLocal.remove(); + threadLocal.set("2"); + threadLocal.remove(); + threadLocal2.set("1"); + threadLocal2.remove(); + threadLocal2.set("2"); + threadLocal2.remove(); + } + }; + + final int sizeWhenStart = ObjectCleaner.getLiveSetCount(); + Thread thread = new Thread(runnable); + thread.start(); + 
thread.join(); + + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + + Thread thread2 = new Thread(runnable); + thread2.start(); + thread2.join(); + + assertEquals(0, ObjectCleaner.getLiveSetCount() - sizeWhenStart); + } + + @Test + @Timeout(value = 4000, unit = TimeUnit.MILLISECONDS) + public void testOnRemoveCalledForFastThreadLocalGet() throws Exception { + testOnRemoveCalled(true, true); + } + + @Disabled("onRemoval(...) not called with non FastThreadLocal") + @Test + @Timeout(value = 4000, unit = TimeUnit.MILLISECONDS) + public void testOnRemoveCalledForNonFastThreadLocalGet() throws Exception { + testOnRemoveCalled(false, true); + } + + @Test + @Timeout(value = 4000, unit = TimeUnit.MILLISECONDS) + public void testOnRemoveCalledForFastThreadLocalSet() throws Exception { + testOnRemoveCalled(true, false); + } + + @Disabled("onRemoval(...) not called with non FastThreadLocal") + @Test + @Timeout(value = 4000, unit = TimeUnit.MILLISECONDS) + public void testOnRemoveCalledForNonFastThreadLocalSet() throws Exception { + testOnRemoveCalled(false, false); + } + + private static void testOnRemoveCalled(boolean fastThreadLocal, final boolean callGet) throws Exception { + + final TestFastThreadLocal threadLocal = new TestFastThreadLocal(); + final TestFastThreadLocal threadLocal2 = new TestFastThreadLocal(); + + Runnable runnable = new Runnable() { + @Override + public void run() { + if (callGet) { + assertEquals(Thread.currentThread().getName(), threadLocal.get()); + assertEquals(Thread.currentThread().getName(), threadLocal2.get()); + } else { + threadLocal.set(Thread.currentThread().getName()); + threadLocal2.set(Thread.currentThread().getName()); + } + } + }; + Thread thread = fastThreadLocal ? new FastThreadLocalThread(runnable) : new Thread(runnable); + thread.start(); + thread.join(); + + String threadName = thread.getName(); + + // Null this out so it can be collected + thread = null; + + // Loop until onRemoval(...) was called. 
This will fail the test if this not works due a timeout. + while (threadLocal.onRemovalCalled.get() == null || threadLocal2.onRemovalCalled.get() == null) { + System.gc(); + System.runFinalization(); + Thread.sleep(50); + } + + assertEquals(threadName, threadLocal.onRemovalCalled.get()); + assertEquals(threadName, threadLocal2.onRemovalCalled.get()); + } + + private static final class TestFastThreadLocal extends FastThreadLocal { + + final AtomicReference onRemovalCalled = new AtomicReference(); + + @Override + protected String initialValue() throws Exception { + return Thread.currentThread().getName(); + } + + @Override + protected void onRemoval(String value) throws Exception { + onRemovalCalled.set(value); + } + } + + @Test + public void testConstructionWithIndex() throws Exception { + int ARRAY_LIST_CAPACITY_MAX_SIZE = Integer.MAX_VALUE - 8; + Field nextIndexField = + InternalThreadLocalMap.class.getDeclaredField("nextIndex"); + nextIndexField.setAccessible(true); + AtomicInteger nextIndex = (AtomicInteger) nextIndexField.get(AtomicInteger.class); + int nextIndex_before = nextIndex.get(); + final AtomicReference throwable = new AtomicReference(); + try { + while (nextIndex.get() < ARRAY_LIST_CAPACITY_MAX_SIZE) { + new FastThreadLocal(); + } + assertEquals(ARRAY_LIST_CAPACITY_MAX_SIZE - 1, InternalThreadLocalMap.lastVariableIndex()); + try { + new FastThreadLocal(); + } catch (Throwable t) { + throwable.set(t); + } finally { + // Assert the max index cannot greater than (ARRAY_LIST_CAPACITY_MAX_SIZE - 1). + assertThat(throwable.get(), is(instanceOf(IllegalStateException.class))); + // Assert the index was reset to ARRAY_LIST_CAPACITY_MAX_SIZE + // after it reaches ARRAY_LIST_CAPACITY_MAX_SIZE. + assertEquals(ARRAY_LIST_CAPACITY_MAX_SIZE - 1, InternalThreadLocalMap.lastVariableIndex()); + } + } finally { + // Restore the index. 
+ nextIndex.set(nextIndex_before); + } + } + + @EnabledIfEnvironmentVariable(named = "CI", matches = "true", disabledReason = "" + + "This deliberately causes OutOfMemoryErrors, for which heap dumps are automatically generated. " + + "To avoid confusion, wasted time investigating heap dumps, and to avoid heap dumps accidentally " + + "getting committed to the Git repository, we should only enable this test when running in a CI " + + "environment. We make this check by assuming a 'CI' environment variable. " + + "This matches what Github Actions is doing for us currently.") + @Test + public void testInternalThreadLocalMapExpand() throws Exception { + final AtomicReference throwable = new AtomicReference(); + Runnable runnable = new Runnable() { + @Override + public void run() { + int expand_threshold = 1 << 30; + try { + InternalThreadLocalMap.get().setIndexedVariable(expand_threshold, null); + } catch (Throwable t) { + throwable.set(t); + } + } + }; + FastThreadLocalThread fastThreadLocalThread = new FastThreadLocalThread(runnable); + fastThreadLocalThread.start(); + fastThreadLocalThread.join(); + // assert the expanded size is not overflowed to negative value + assertThat(throwable.get(), is(not(instanceOf(NegativeArraySizeException.class)))); + } + + @Test + public void testFastThreadLocalSize() throws Exception { + int originSize = FastThreadLocal.size(); + assertTrue(originSize >= 0); + + InternalThreadLocalMap.get(); + assertEquals(originSize, FastThreadLocal.size()); + + new FastThreadLocal(); + assertEquals(originSize, FastThreadLocal.size()); + + FastThreadLocal fst2 = new FastThreadLocal(); + fst2.get(); + assertEquals(1 + originSize, FastThreadLocal.size()); + + FastThreadLocal fst3 = new FastThreadLocal(); + fst3.set(null); + assertEquals(2 + originSize, FastThreadLocal.size()); + + FastThreadLocal fst4 = new FastThreadLocal(); + fst4.set(Boolean.TRUE); + assertEquals(3 + originSize, FastThreadLocal.size()); + + fst4.set(Boolean.TRUE); + assertEquals(3 
+ originSize, FastThreadLocal.size()); + + fst4.remove(); + assertEquals(2 + originSize, FastThreadLocal.size()); + + FastThreadLocal.removeAll(); + assertEquals(0, FastThreadLocal.size()); + } + + @Test + public void testFastThreadLocalInitialValueWithUnset() throws Exception { + final AtomicReference throwable = new AtomicReference(); + final FastThreadLocal fst = new FastThreadLocal() { + @Override + protected Object initialValue() throws Exception { + return InternalThreadLocalMap.UNSET; + } + }; + Runnable runnable = new Runnable() { + @Override + public void run() { + try { + fst.get(); + } catch (Throwable t) { + throwable.set(t); + } + } + }; + FastThreadLocalThread fastThreadLocalThread = new FastThreadLocalThread(runnable); + fastThreadLocalThread.start(); + fastThreadLocalThread.join(); + assertThat(throwable.get(), is(instanceOf(IllegalArgumentException.class))); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/GlobalEventExecutorTest.java b/netty-util/src/test/java/io/netty/util/concurrent/GlobalEventExecutorTest.java new file mode 100644 index 0000000..19e63d2 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/GlobalEventExecutorTest.java @@ -0,0 +1,179 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.CoreMatchers.not; +import static org.hamcrest.CoreMatchers.nullValue; +import static org.hamcrest.CoreMatchers.sameInstance; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class GlobalEventExecutorTest { + + private static final GlobalEventExecutor e = GlobalEventExecutor.INSTANCE; + + @BeforeEach + public void setUp() throws Exception { + // Wait until the global executor is stopped (just in case there is a task running due to previous test cases) + for (;;) { + if (e.thread == null || !e.thread.isAlive()) { + break; + } + + Thread.sleep(50); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testAutomaticStartStop() throws Exception { + final TestRunnable task = new TestRunnable(500); + e.execute(task); + + // Ensure the new thread has started. + Thread thread = e.thread; + assertThat(thread, is(not(nullValue()))); + assertThat(thread.isAlive(), is(true)); + + thread.join(); + assertThat(task.ran.get(), is(true)); + + // Ensure another new thread starts again. + task.ran.set(false); + e.execute(task); + assertThat(e.thread, not(sameInstance(thread))); + thread = e.thread; + + thread.join(); + + assertThat(task.ran.get(), is(true)); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testScheduledTasks() throws Exception { + TestRunnable task = new TestRunnable(0); + ScheduledFuture f = e.schedule(task, 1500, TimeUnit.MILLISECONDS); + f.sync(); + assertThat(task.ran.get(), is(true)); + + // Ensure the thread is still running. 
+ Thread thread = e.thread; + assertThat(thread, is(not(nullValue()))); + assertThat(thread.isAlive(), is(true)); + + thread.join(); + } + + // ensure that when a task submission causes a new thread to be created, the thread inherits the thread group of the + // submitting thread + @Test + @Timeout(value = 2000, unit = TimeUnit.MILLISECONDS) + public void testThreadGroup() throws InterruptedException { + final ThreadGroup group = new ThreadGroup("group"); + final AtomicReference capturedGroup = new AtomicReference(); + final Thread thread = new Thread(group, new Runnable() { + @Override + public void run() { + final Thread t = e.threadFactory.newThread(new Runnable() { + @Override + public void run() { + } + }); + capturedGroup.set(t.getThreadGroup()); + } + }); + thread.start(); + thread.join(); + + assertEquals(group, capturedGroup.get()); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testTakeTask() throws Exception { + //add task + TestRunnable beforeTask = new TestRunnable(0); + e.execute(beforeTask); + + //add scheduled task + TestRunnable scheduledTask = new TestRunnable(0); + ScheduledFuture f = e.schedule(scheduledTask , 1500, TimeUnit.MILLISECONDS); + + //add task + TestRunnable afterTask = new TestRunnable(0); + e.execute(afterTask); + + f.sync(); + + assertThat(beforeTask.ran.get(), is(true)); + assertThat(scheduledTask.ran.get(), is(true)); + assertThat(afterTask.ran.get(), is(true)); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testTakeTaskAlwaysHasTask() throws Exception { + //for https://github.com/netty/netty/issues/1614 + //add scheduled task + TestRunnable t = new TestRunnable(0); + final ScheduledFuture f = e.schedule(t, 1500, TimeUnit.MILLISECONDS); + + //ensure always has at least one task in taskQueue + //check if scheduled tasks are triggered + e.execute(new Runnable() { + @Override + public void run() { + if (!f.isDone()) { + e.execute(this); + } + } + }); + + 
f.sync(); + + assertThat(t.ran.get(), is(true)); + } + + private static final class TestRunnable implements Runnable { + final AtomicBoolean ran = new AtomicBoolean(); + final long delay; + + TestRunnable(long delay) { + this.delay = delay; + } + + @Override + public void run() { + try { + Thread.sleep(delay); + ran.set(true); + } catch (InterruptedException ignored) { + // Ignore + } + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/ImmediateExecutorTest.java b/netty-util/src/test/java/io/netty/util/concurrent/ImmediateExecutorTest.java new file mode 100644 index 0000000..8b9eb30 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/ImmediateExecutorTest.java @@ -0,0 +1,54 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.util.concurrent.FutureTask; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +public class ImmediateExecutorTest { + + @Test + public void testExecuteNullRunnable() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + ImmediateExecutor.INSTANCE.execute(null); + } + }); + } + + @Test + public void testExecuteNonNullRunnable() throws Exception { + FutureTask task = new FutureTask(new Runnable() { + @Override + public void run() { + // NOOP + } + }, null); + ImmediateExecutor.INSTANCE.execute(task); + assertTrue(task.isDone()); + assertFalse(task.isCancelled()); + assertNull(task.get()); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java b/netty-util/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java new file mode 100644 index 0000000..aedd77e --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/NonStickyEventExecutorGroupTest.java @@ -0,0 +1,178 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import io.netty.util.NettyRuntime; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; +import org.junit.jupiter.params.ParameterizedTest; +import org.junit.jupiter.params.provider.MethodSource; + +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class NonStickyEventExecutorGroupTest { + private static final String PARAMETERIZED_NAME = "{index}: maxTaskExecutePerRun = {0}"; + + @Test + public void testInvalidGroup() { + final EventExecutorGroup group = new DefaultEventExecutorGroup(1); + try { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new NonStickyEventExecutorGroup(group); + } + }); + } finally { + group.shutdownGracefully(); + } + } + + public static Collection data() throws Exception { + List params = new ArrayList(); + params.add(new Object[] {64}); + params.add(new Object[] {256}); + params.add(new Object[] {1024}); + params.add(new Object[] {Integer.MAX_VALUE}); + return params; + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void testOrdering(int maxTaskExecutePerRun) throws Throwable { + final int threads = NettyRuntime.availableProcessors() * 2; + final EventExecutorGroup group = new UnorderedThreadPoolEventExecutor(threads); + final NonStickyEventExecutorGroup nonStickyGroup = new NonStickyEventExecutorGroup(group, maxTaskExecutePerRun); + try { + final CountDownLatch 
startLatch = new CountDownLatch(1); + final AtomicReference error = new AtomicReference(); + List threadList = new ArrayList(threads); + for (int i = 0 ; i < threads; i++) { + Thread thread = new Thread(new Runnable() { + @Override + public void run() { + try { + execute(nonStickyGroup, startLatch); + } catch (Throwable cause) { + error.compareAndSet(null, cause); + } + } + }); + threadList.add(thread); + thread.start(); + } + startLatch.countDown(); + for (Thread t: threadList) { + t.join(); + } + Throwable cause = error.get(); + if (cause != null) { + throw cause; + } + } finally { + nonStickyGroup.shutdownGracefully(); + } + } + + @ParameterizedTest(name = PARAMETERIZED_NAME) + @MethodSource("data") + public void testRaceCondition(int maxTaskExecutePerRun) throws InterruptedException { + EventExecutorGroup group = new UnorderedThreadPoolEventExecutor(1); + NonStickyEventExecutorGroup nonStickyGroup = new NonStickyEventExecutorGroup(group, maxTaskExecutePerRun); + + try { + EventExecutor executor = nonStickyGroup.next(); + + for (int j = 0; j < 5000; j++) { + final CountDownLatch firstCompleted = new CountDownLatch(1); + final CountDownLatch latch = new CountDownLatch(2); + for (int i = 0; i < 2; i++) { + executor.execute(new Runnable() { + @Override + public void run() { + firstCompleted.countDown(); + latch.countDown(); + } + }); + assertTrue(firstCompleted.await(1, TimeUnit.SECONDS)); + } + + assertTrue(latch.await(5, TimeUnit.SECONDS)); + } + } finally { + nonStickyGroup.shutdownGracefully(); + } + } + + private static void execute(EventExecutorGroup group, CountDownLatch startLatch) throws Throwable { + final EventExecutor executor = group.next(); + assertTrue(executor instanceof OrderedEventExecutor); + final AtomicReference cause = new AtomicReference(); + final AtomicInteger last = new AtomicInteger(); + int tasks = 10000; + List> futures = new ArrayList>(tasks); + final CountDownLatch latch = new CountDownLatch(tasks); + startLatch.await(); + + for (int 
i = 1 ; i <= tasks; i++) { + final int id = i; + assertFalse(executor.inEventLoop()); + assertFalse(executor.inEventLoop(Thread.currentThread())); + futures.add(executor.submit(new Runnable() { + @Override + public void run() { + try { + assertTrue(executor.inEventLoop(Thread.currentThread())); + assertTrue(executor.inEventLoop()); + + if (cause.get() == null) { + int lastId = last.get(); + if (lastId >= id) { + cause.compareAndSet(null, new AssertionError( + "Out of order execution id(" + id + ") >= lastId(" + lastId + ')')); + } + if (!last.compareAndSet(lastId, id)) { + cause.compareAndSet(null, new AssertionError("Concurrent execution of tasks")); + } + } + } finally { + latch.countDown(); + } + } + })); + } + latch.await(); + for (Future future: futures) { + future.syncUninterruptibly(); + } + Throwable error = cause.get(); + if (error != null) { + throw error; + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/PromiseAggregatorTest.java b/netty-util/src/test/java/io/netty/util/concurrent/PromiseAggregatorTest.java new file mode 100644 index 0000000..532d9c2 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/PromiseAggregatorTest.java @@ -0,0 +1,150 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ + +package io.netty.util.concurrent; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.mockito.Mockito.*; + +public class PromiseAggregatorTest { + + @Test + public void testNullAggregatePromise() { + assertThrows(NullPointerException.class, new Executable() { + @SuppressWarnings("deprecation") + @Override + public void execute() { + new PromiseAggregator>(null); + } + }); + } + + @Test + public void testAddNullFuture() { + @SuppressWarnings("unchecked") + Promise p = mock(Promise.class); + @SuppressWarnings("deprecation") + final PromiseAggregator> a = + new PromiseAggregator>(p); + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + a.add((Promise[]) null); + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void testSuccessfulNoPending() throws Exception { + Promise p = mock(Promise.class); + @SuppressWarnings("deprecation") + PromiseAggregator> a = + new PromiseAggregator>(p); + + Future future = mock(Future.class); + when(p.setSuccess(null)).thenReturn(p); + + a.add(); + a.operationComplete(future); + verifyNoMoreInteractions(future); + verify(p).setSuccess(null); + } + + @SuppressWarnings("unchecked") + @Test + public void testSuccessfulPending() throws Exception { + Promise p = mock(Promise.class); + PromiseAggregator> a = + new PromiseAggregator>(p); + Promise p1 = mock(Promise.class); + Promise p2 = mock(Promise.class); + + when(p1.addListener(a)).thenReturn(p1); + when(p2.addListener(a)).thenReturn(p2); + when(p1.isSuccess()).thenReturn(true); + when(p2.isSuccess()).thenReturn(true); + when(p.setSuccess(null)).thenReturn(p); + + assertThat(a.add(p1, null, p2), is(a)); + a.operationComplete(p1); + a.operationComplete(p2); + + verify(p1).addListener(a); + 
verify(p2).addListener(a); + verify(p1).isSuccess(); + verify(p2).isSuccess(); + verify(p).setSuccess(null); + } + + @SuppressWarnings("unchecked") + @Test + public void testFailedFutureFailPending() throws Exception { + Promise p = mock(Promise.class); + PromiseAggregator> a = + new PromiseAggregator>(p); + Promise p1 = mock(Promise.class); + Promise p2 = mock(Promise.class); + Throwable t = mock(Throwable.class); + + when(p1.addListener(a)).thenReturn(p1); + when(p2.addListener(a)).thenReturn(p2); + when(p1.isSuccess()).thenReturn(false); + when(p1.cause()).thenReturn(t); + when(p.setFailure(t)).thenReturn(p); + when(p2.setFailure(t)).thenReturn(p2); + + a.add(p1, p2); + a.operationComplete(p1); + + verify(p1).addListener(a); + verify(p2).addListener(a); + verify(p1).cause(); + verify(p).setFailure(t); + verify(p2).setFailure(t); + } + + @SuppressWarnings("unchecked") + @Test + public void testFailedFutureNoFailPending() throws Exception { + Promise p = mock(Promise.class); + PromiseAggregator> a = + new PromiseAggregator>(p, false); + Promise p1 = mock(Promise.class); + Promise p2 = mock(Promise.class); + Throwable t = mock(Throwable.class); + + when(p1.addListener(a)).thenReturn(p1); + when(p2.addListener(a)).thenReturn(p2); + when(p1.isSuccess()).thenReturn(false); + when(p1.cause()).thenReturn(t); + when(p.setFailure(t)).thenReturn(p); + + a.add(p1, p2); + a.operationComplete(p1); + + verify(p1).addListener(a); + verify(p2).addListener(a); + verify(p1).isSuccess(); + verify(p1).cause(); + verify(p).setFailure(t); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java b/netty-util/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java new file mode 100644 index 0000000..eb86af4 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/PromiseCombinerTest.java @@ -0,0 +1,267 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * 
version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.concurrent; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; +import org.mockito.Mock; +import org.mockito.MockitoAnnotations; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.fail; +import static org.mockito.Mockito.any; +import static org.mockito.Mockito.eq; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +public class PromiseCombinerTest { + @Mock + private Promise p1; + private GenericFutureListener> l1; + private final GenericFutureListenerConsumer l1Consumer = new GenericFutureListenerConsumer() { + @Override + public void accept(GenericFutureListener> listener) { + l1 = listener; + } + }; + @Mock + private Promise p2; + private GenericFutureListener> l2; + private final GenericFutureListenerConsumer l2Consumer = new GenericFutureListenerConsumer() { + @Override + public void accept(GenericFutureListener> listener) { + l2 = listener; + } + }; + @Mock + private Promise p3; + private PromiseCombiner combiner; + + @BeforeEach + public void setup() { + MockitoAnnotations.initMocks(this); + combiner = new 
PromiseCombiner(ImmediateEventExecutor.INSTANCE); + } + + @Test + public void testNullArgument() { + try { + combiner.finish(null); + fail(); + } catch (NullPointerException expected) { + // expected + } + combiner.finish(p1); + verify(p1).trySuccess(null); + } + + @Test + public void testNullAggregatePromise() { + combiner.finish(p1); + verify(p1).trySuccess(null); + } + + @Test + public void testAddNullPromise() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + combiner.add(null); + } + }); + } + + @Test + public void testAddAllNullPromise() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + combiner.addAll(null); + } + }); + } + + @Test + public void testAddAfterFinish() { + combiner.finish(p1); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + combiner.add(p2); + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void testAddAllAfterFinish() { + combiner.finish(p1); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + combiner.addAll(p2); + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void testFinishCalledTwiceThrows() { + combiner.finish(p1); + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + combiner.finish(p1); + } + }); + } + + @Test + public void testAddAllSuccess() throws Exception { + mockSuccessPromise(p1, l1Consumer); + mockSuccessPromise(p2, l2Consumer); + combiner.addAll(p1, p2); + combiner.finish(p3); + l1.operationComplete(p1); + verifyNotCompleted(p3); + l2.operationComplete(p2); + verifySuccess(p3); + } + + @Test + public void testAddSuccess() throws Exception { + mockSuccessPromise(p1, l1Consumer); + mockSuccessPromise(p2, l2Consumer); + combiner.add(p1); + l1.operationComplete(p1); + combiner.add(p2); + l2.operationComplete(p2); + verifyNotCompleted(p3); + 
combiner.finish(p3); + verifySuccess(p3); + } + + @Test + public void testAddAllFail() throws Exception { + RuntimeException e1 = new RuntimeException("fake exception 1"); + RuntimeException e2 = new RuntimeException("fake exception 2"); + mockFailedPromise(p1, e1, l1Consumer); + mockFailedPromise(p2, e2, l2Consumer); + combiner.addAll(p1, p2); + combiner.finish(p3); + l1.operationComplete(p1); + verifyNotCompleted(p3); + l2.operationComplete(p2); + verifyFail(p3, e1); + } + + @Test + public void testAddFail() throws Exception { + RuntimeException e1 = new RuntimeException("fake exception 1"); + RuntimeException e2 = new RuntimeException("fake exception 2"); + mockFailedPromise(p1, e1, l1Consumer); + mockFailedPromise(p2, e2, l2Consumer); + combiner.add(p1); + l1.operationComplete(p1); + combiner.add(p2); + l2.operationComplete(p2); + verifyNotCompleted(p3); + combiner.finish(p3); + verifyFail(p3, e1); + } + + @Test + public void testEventExecutor() { + EventExecutor executor = mock(EventExecutor.class); + when(executor.inEventLoop()).thenReturn(false); + combiner = new PromiseCombiner(executor); + + Future future = mock(Future.class); + + try { + combiner.add(future); + fail(); + } catch (IllegalStateException expected) { + // expected + } + + try { + combiner.addAll(future); + fail(); + } catch (IllegalStateException expected) { + // expected + } + + @SuppressWarnings("unchecked") + Promise promise = (Promise) mock(Promise.class); + try { + combiner.finish(promise); + fail(); + } catch (IllegalStateException expected) { + // expected + } + } + + private static void verifyFail(Promise p, Throwable cause) { + verify(p).tryFailure(eq(cause)); + } + + private static void verifySuccess(Promise p) { + verify(p).trySuccess(null); + } + + private static void verifyNotCompleted(Promise p) { + verify(p, never()).trySuccess(any(Void.class)); + verify(p, never()).tryFailure(any(Throwable.class)); + verify(p, never()).setSuccess(any(Void.class)); + verify(p, 
never()).setFailure(any(Throwable.class)); + } + + private static void mockSuccessPromise(Promise p, GenericFutureListenerConsumer consumer) { + when(p.isDone()).thenReturn(true); + when(p.isSuccess()).thenReturn(true); + mockListener(p, consumer); + } + + private static void mockFailedPromise(Promise p, Throwable cause, GenericFutureListenerConsumer consumer) { + when(p.isDone()).thenReturn(true); + when(p.isSuccess()).thenReturn(false); + when(p.cause()).thenReturn(cause); + mockListener(p, consumer); + } + + @SuppressWarnings("unchecked") + private static void mockListener(final Promise p, final GenericFutureListenerConsumer consumer) { + doAnswer(new Answer>() { + @SuppressWarnings({ "unchecked", "raw-types" }) + @Override + public Promise answer(InvocationOnMock invocation) throws Throwable { + consumer.accept((GenericFutureListener) invocation.getArgument(0)); + return p; + } + }).when(p).addListener(any(GenericFutureListener.class)); + } + + interface GenericFutureListenerConsumer { + void accept(GenericFutureListener> listener); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/PromiseNotifierTest.java b/netty-util/src/test/java/io/netty/util/concurrent/PromiseNotifierTest.java new file mode 100644 index 0000000..6eb4e7e --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/PromiseNotifierTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.concurrent; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +public class PromiseNotifierTest { + + @Test + public void testNullPromisesArray() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + new PromiseNotifier>((Promise[]) null); + } + }); + } + + @SuppressWarnings("unchecked") + @Test + public void testNullPromiseInArray() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + new PromiseNotifier>((Promise) null); + } + }); + } + + @Test + public void testListenerSuccess() throws Exception { + @SuppressWarnings("unchecked") + Promise p1 = mock(Promise.class); + @SuppressWarnings("unchecked") + Promise p2 = mock(Promise.class); + + @SuppressWarnings("unchecked") + PromiseNotifier> notifier = + new PromiseNotifier>(p1, p2); + + @SuppressWarnings("unchecked") + Future future = mock(Future.class); + when(future.isSuccess()).thenReturn(true); + when(future.get()).thenReturn(null); + when(p1.trySuccess(null)).thenReturn(true); + when(p2.trySuccess(null)).thenReturn(true); + + notifier.operationComplete(future); + verify(p1).trySuccess(null); + verify(p2).trySuccess(null); + } + + @Test + public void testListenerFailure() throws Exception { + @SuppressWarnings("unchecked") + Promise p1 = mock(Promise.class); + @SuppressWarnings("unchecked") + Promise p2 = mock(Promise.class); + + @SuppressWarnings("unchecked") + PromiseNotifier> notifier = + new PromiseNotifier>(p1, p2); + + @SuppressWarnings("unchecked") + Future future = mock(Future.class); + Throwable t = 
mock(Throwable.class); + when(future.isSuccess()).thenReturn(false); + when(future.isCancelled()).thenReturn(false); + when(future.cause()).thenReturn(t); + when(p1.tryFailure(t)).thenReturn(true); + when(p2.tryFailure(t)).thenReturn(true); + + notifier.operationComplete(future); + verify(p1).tryFailure(t); + verify(p2).tryFailure(t); + } + + @Test + public void testCancelPropagationWhenFusedFromFuture() { + Promise p1 = ImmediateEventExecutor.INSTANCE.newPromise(); + Promise p2 = ImmediateEventExecutor.INSTANCE.newPromise(); + + Promise returned = PromiseNotifier.cascade(p1, p2); + assertSame(p1, returned); + + assertTrue(returned.cancel(false)); + assertTrue(returned.isCancelled()); + assertTrue(p2.isCancelled()); + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java b/netty-util/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java new file mode 100644 index 0000000..dc12034 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/SingleThreadEventExecutorTest.java @@ -0,0 +1,430 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import org.junit.jupiter.api.Test; + +import org.junit.jupiter.api.Timeout; +import org.junit.jupiter.api.function.Executable; + +import java.util.Collections; +import java.util.Set; +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; +import java.util.concurrent.ExecutorService; +import java.util.concurrent.Executors; +import java.util.concurrent.LinkedBlockingQueue; +import java.util.concurrent.RejectedExecutionException; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; +import java.util.concurrent.atomic.AtomicReference; + +import static org.hamcrest.CoreMatchers.*; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class SingleThreadEventExecutorTest { + + @Test + public void testWrappedExecutorIsShutdown() { + ExecutorService executorService = Executors.newSingleThreadExecutor(); + + final SingleThreadEventExecutor executor = + new SingleThreadEventExecutor(null, executorService, false) { + @Override + protected void run() { + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + }; + + executorService.shutdownNow(); + executeShouldFail(executor); + executeShouldFail(executor); + assertThrows(RejectedExecutionException.class, new Executable() { + @Override + public void execute() { + executor.shutdownGracefully().syncUninterruptibly(); + } + }); + assertTrue(executor.isShutdown()); + } + + private static void executeShouldFail(final Executor executor) { + 
assertThrows(RejectedExecutionException.class, new Executable() { + @Override + public void execute() { + executor.execute(new Runnable() { + @Override + public void run() { + // Noop. + } + }); + } + }); + } + + @Test + public void testThreadProperties() { + final AtomicReference threadRef = new AtomicReference(); + SingleThreadEventExecutor executor = new SingleThreadEventExecutor( + null, new DefaultThreadFactory("test"), false) { + @Override + protected void run() { + threadRef.set(Thread.currentThread()); + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + }; + ThreadProperties threadProperties = executor.threadProperties(); + + Thread thread = threadRef.get(); + assertEquals(thread.getId(), threadProperties.id()); + assertEquals(thread.getName(), threadProperties.name()); + assertEquals(thread.getPriority(), threadProperties.priority()); + assertEquals(thread.isAlive(), threadProperties.isAlive()); + assertEquals(thread.isDaemon(), threadProperties.isDaemon()); + assertTrue(threadProperties.stackTrace().length > 0); + executor.shutdownGracefully(); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testInvokeAnyInEventLoop() { + testInvokeInEventLoop(true, false); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testInvokeAnyInEventLoopWithTimeout() { + testInvokeInEventLoop(true, true); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testInvokeAllInEventLoop() { + testInvokeInEventLoop(false, false); + } + + @Test + @Timeout(value = 3000, unit = TimeUnit.MILLISECONDS) + public void testInvokeAllInEventLoopWithTimeout() { + testInvokeInEventLoop(false, true); + } + + private static void testInvokeInEventLoop(final boolean any, final boolean timeout) { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor(null, + Executors.defaultThreadFactory(), true) { + @Override + protected void 
run() { + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + }; + try { + assertThrows(RejectedExecutionException.class, new Executable() { + @Override + public void execute() throws Throwable { + final Promise promise = executor.newPromise(); + executor.execute(new Runnable() { + @Override + public void run() { + try { + Set> set = Collections.>singleton( + new Callable() { + @Override + public Boolean call() throws Exception { + promise.setFailure(new AssertionError("Should never execute the Callable")); + return Boolean.TRUE; + } + }); + if (any) { + if (timeout) { + executor.invokeAny(set, 10, TimeUnit.SECONDS); + } else { + executor.invokeAny(set); + } + } else { + if (timeout) { + executor.invokeAll(set, 10, TimeUnit.SECONDS); + } else { + executor.invokeAll(set); + } + } + promise.setFailure(new AssertionError("Should never reach here")); + } catch (Throwable cause) { + promise.setFailure(cause); + } + } + }); + promise.syncUninterruptibly(); + } + }); + } finally { + executor.shutdownGracefully(0, 0, TimeUnit.MILLISECONDS); + } + } + + static class LatchTask extends CountDownLatch implements Runnable { + LatchTask() { + super(1); + } + + @Override + public void run() { + countDown(); + } + } + + static class LazyLatchTask extends LatchTask { } + + @Test + public void testLazyExecution() throws Exception { + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor(null, + Executors.defaultThreadFactory(), false) { + + @Override + protected boolean wakesUpForTask(final Runnable task) { + return !(task instanceof LazyLatchTask); + } + + @Override + protected void run() { + while (!confirmShutdown()) { + try { + synchronized (this) { + if (!hasTasks()) { + wait(); + } + } + runAllTasks(); + } catch (Exception e) { + e.printStackTrace(); + fail(e.toString()); + } + } + } + + @Override + protected void wakeup(boolean inEventLoop) { + if (!inEventLoop) { + synchronized (this) { + 
notifyAll(); + } + } + } + }; + + // Ensure event loop is started + LatchTask latch0 = new LatchTask(); + executor.execute(latch0); + assertTrue(latch0.await(100, TimeUnit.MILLISECONDS)); + // Pause to ensure it enters waiting state + Thread.sleep(100L); + + // Submit task via lazyExecute + LatchTask latch1 = new LatchTask(); + executor.lazyExecute(latch1); + // Sumbit lazy task via regular execute + LatchTask latch2 = new LazyLatchTask(); + executor.execute(latch2); + + // Neither should run yet + assertFalse(latch1.await(100, TimeUnit.MILLISECONDS)); + assertFalse(latch2.await(100, TimeUnit.MILLISECONDS)); + + // Submit regular task via regular execute + LatchTask latch3 = new LatchTask(); + executor.execute(latch3); + + // Should flush latch1 and latch2 and then run latch3 immediately + assertTrue(latch3.await(100, TimeUnit.MILLISECONDS)); + assertEquals(0, latch1.getCount()); + assertEquals(0, latch2.getCount()); + } + + @Test + public void testTaskAddedAfterShutdownNotAbandoned() throws Exception { + + // A queue that doesn't support remove, so tasks once added cannot be rejected anymore + LinkedBlockingQueue taskQueue = new LinkedBlockingQueue() { + @Override + public boolean remove(Object o) { + throw new UnsupportedOperationException(); + } + }; + + final Runnable dummyTask = new Runnable() { + @Override + public void run() { + } + }; + + final LinkedBlockingQueue> submittedTasks = new LinkedBlockingQueue>(); + final AtomicInteger attempts = new AtomicInteger(); + final AtomicInteger rejects = new AtomicInteger(); + + ExecutorService executorService = Executors.newSingleThreadExecutor(); + final SingleThreadEventExecutor executor = new SingleThreadEventExecutor(null, executorService, false, + taskQueue, RejectedExecutionHandlers.reject()) { + @Override + protected void run() { + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + + @Override + protected boolean confirmShutdown() { + boolean result = 
super.confirmShutdown(); + // After shutdown is confirmed, scheduled one more task and record it + if (result) { + attempts.incrementAndGet(); + try { + submittedTasks.add(submit(dummyTask)); + } catch (RejectedExecutionException e) { + // ignore, tasks are either accepted or rejected + rejects.incrementAndGet(); + } + } + return result; + } + }; + + // Start the loop + executor.submit(dummyTask).sync(); + + // Shutdown without any quiet period + executor.shutdownGracefully(0, 100, TimeUnit.MILLISECONDS).sync(); + + // Ensure there are no user-tasks left. + assertEquals(0, executor.drainTasks()); + + // Verify that queue is empty and all attempts either succeeded or were rejected + assertTrue(taskQueue.isEmpty()); + assertTrue(attempts.get() > 0); + assertEquals(attempts.get(), submittedTasks.size() + rejects.get()); + for (Future f : submittedTasks) { + assertTrue(f.isSuccess()); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testTakeTask() throws Exception { + final SingleThreadEventExecutor executor = + new SingleThreadEventExecutor(null, Executors.defaultThreadFactory(), true) { + @Override + protected void run() { + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + }; + + //add task + TestRunnable beforeTask = new TestRunnable(); + executor.execute(beforeTask); + + //add scheduled task + TestRunnable scheduledTask = new TestRunnable(); + ScheduledFuture f = executor.schedule(scheduledTask , 1500, TimeUnit.MILLISECONDS); + + //add task + TestRunnable afterTask = new TestRunnable(); + executor.execute(afterTask); + + f.sync(); + + assertThat(beforeTask.ran.get(), is(true)); + assertThat(scheduledTask.ran.get(), is(true)); + assertThat(afterTask.ran.get(), is(true)); + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testTakeTaskAlwaysHasTask() throws Exception { + //for https://github.com/netty/netty/issues/1614 + + final 
SingleThreadEventExecutor executor = + new SingleThreadEventExecutor(null, Executors.defaultThreadFactory(), true) { + @Override + protected void run() { + while (!confirmShutdown()) { + Runnable task = takeTask(); + if (task != null) { + task.run(); + } + } + } + }; + + //add scheduled task + TestRunnable t = new TestRunnable(); + final ScheduledFuture f = executor.schedule(t, 1500, TimeUnit.MILLISECONDS); + + //ensure always has at least one task in taskQueue + //check if scheduled tasks are triggered + executor.execute(new Runnable() { + @Override + public void run() { + if (!f.isDone()) { + executor.execute(this); + } + } + }); + + f.sync(); + + assertThat(t.ran.get(), is(true)); + } + + private static final class TestRunnable implements Runnable { + final AtomicBoolean ran = new AtomicBoolean(); + + TestRunnable() { + } + + @Override + public void run() { + ran.set(true); + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutorTest.java b/netty-util/src/test/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutorTest.java new file mode 100644 index 0000000..7d57b9c --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/concurrent/UnorderedThreadPoolEventExecutorTest.java @@ -0,0 +1,135 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.concurrent; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.Callable; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Exchanger; +import java.util.concurrent.TimeUnit; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class UnorderedThreadPoolEventExecutorTest { + + // See https://github.com/netty/netty/issues/6507 + @Test + public void testNotEndlessExecute() throws Exception { + UnorderedThreadPoolEventExecutor executor = new UnorderedThreadPoolEventExecutor(1); + + try { + // Having the first task wait on an exchanger allow us to make sure that the lister on the second task + // is not added *after* the promise completes. We need to do this to prevent a race where the second task + // and listener are completed before the DefaultPromise.NotifyListeners task get to run, which means our + // queue inspection might observe this task after the CountDownLatch opens. + final Exchanger exchanger = new Exchanger(); + final CountDownLatch latch = new CountDownLatch(3); + Runnable task = new Runnable() { + @Override + public void run() { + try { + exchanger.exchange(null); + } catch (InterruptedException e) { + throw new RuntimeException(e); + } + latch.countDown(); + } + }; + executor.execute(task); + Future future = executor.submit(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }).addListener(new FutureListener() { + @Override + public void operationComplete(Future future) throws Exception { + latch.countDown(); + } + }); + exchanger.exchange(null); + latch.await(); + future.syncUninterruptibly(); + + // Now just check if the queue stays empty multiple times. This is needed as the submit to execute(...) 
+ // by DefaultPromise may happen in an async fashion + for (int i = 0; i < 10000; i++) { + assertTrue(executor.getQueue().isEmpty()); + } + } finally { + executor.shutdownGracefully(); + } + } + + @Test + @Timeout(value = 10000, unit = TimeUnit.MILLISECONDS) + public void scheduledAtFixedRateMustRunTaskRepeatedly() throws InterruptedException { + UnorderedThreadPoolEventExecutor executor = new UnorderedThreadPoolEventExecutor(1); + final CountDownLatch latch = new CountDownLatch(3); + Future future = executor.scheduleAtFixedRate(new Runnable() { + @Override + public void run() { + latch.countDown(); + } + }, 1, 1, TimeUnit.MILLISECONDS); + try { + latch.await(); + } finally { + future.cancel(true); + executor.shutdownGracefully(); + } + } + + @Test + public void testGetReturnsCorrectValueOnSuccess() throws Exception { + UnorderedThreadPoolEventExecutor executor = new UnorderedThreadPoolEventExecutor(1); + try { + final String expected = "expected"; + Future f = executor.submit(new Callable() { + @Override + public String call() { + return expected; + } + }); + + assertEquals(expected, f.get()); + } finally { + executor.shutdownGracefully(); + } + } + + @Test + public void testGetReturnsCorrectValueOnFailure() throws Exception { + UnorderedThreadPoolEventExecutor executor = new UnorderedThreadPoolEventExecutor(1); + try { + final RuntimeException cause = new RuntimeException(); + Future f = executor.submit(new Callable() { + @Override + public String call() { + throw cause; + } + }); + + assertSame(cause, f.await().cause()); + } finally { + executor.shutdownGracefully(); + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java b/netty-util/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java new file mode 100644 index 0000000..b735a8a --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/AppendableCharSequenceTest.java @@ -0,0 +1,110 @@ +/* + * Copyright 2013 The Netty Project + * + * The 
Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; + +public class AppendableCharSequenceTest { + + @Test + public void testSimpleAppend() { + testSimpleAppend0(new AppendableCharSequence(128)); + } + + @Test + public void testAppendString() { + testAppendString0(new AppendableCharSequence(128)); + } + + @Test + public void testAppendAppendableCharSequence() { + AppendableCharSequence seq = new AppendableCharSequence(128); + + String text = "testdata"; + AppendableCharSequence seq2 = new AppendableCharSequence(128); + seq2.append(text); + seq.append(seq2); + + assertEquals(text, seq.toString()); + assertEquals(text.substring(1, text.length() - 2), seq.substring(1, text.length() - 2)); + + assertEqualsChars(text, seq); + } + + @Test + public void testSimpleAppendWithExpand() { + testSimpleAppend0(new AppendableCharSequence(2)); + } + + @Test + public void testAppendStringWithExpand() { + testAppendString0(new AppendableCharSequence(2)); + } + + @Test + public void testSubSequence() { + AppendableCharSequence master = new AppendableCharSequence(26); + master.append("abcdefghijlkmonpqrstuvwxyz"); + assertEquals("abcdefghij", master.subSequence(0, 10).toString()); + } + + @Test + public void testEmptySubSequence() { + AppendableCharSequence master = new AppendableCharSequence(26); + 
master.append("abcdefghijlkmonpqrstuvwxyz"); + AppendableCharSequence sub = master.subSequence(0, 0); + assertEquals(0, sub.length()); + sub.append('b'); + assertEquals('b', sub.charAt(0)); + } + + private static void testSimpleAppend0(AppendableCharSequence seq) { + String text = "testdata"; + for (int i = 0; i < text.length(); i++) { + seq.append(text.charAt(i)); + } + + assertEquals(text, seq.toString()); + assertEquals(text.substring(1, text.length() - 2), seq.substring(1, text.length() - 2)); + + assertEqualsChars(text, seq); + + seq.reset(); + assertEquals(0, seq.length()); + } + + private static void testAppendString0(AppendableCharSequence seq) { + String text = "testdata"; + seq.append(text); + + assertEquals(text, seq.toString()); + assertEquals(text.substring(1, text.length() - 2), seq.substring(1, text.length() - 2)); + + assertEqualsChars(text, seq); + + seq.reset(); + assertEquals(0, seq.length()); + } + + private static void assertEqualsChars(CharSequence seq1, CharSequence seq2) { + assertEquals(seq1.length(), seq2.length()); + for (int i = 0; i < seq1.length(); i++) { + assertEquals(seq1.charAt(i), seq2.charAt(i)); + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/DefaultPriorityQueueTest.java b/netty-util/src/test/java/io/netty/util/internal/DefaultPriorityQueueTest.java new file mode 100644 index 0000000..132e471 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/DefaultPriorityQueueTest.java @@ -0,0 +1,322 @@ +/* + * Copyright 2015 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; + +import java.io.Serializable; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.Comparator; +import java.util.Iterator; +import java.util.List; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; + +public class DefaultPriorityQueueTest { + @Test + public void testPoll() { + PriorityQueue queue = new DefaultPriorityQueue(TestElementComparator.INSTANCE, 0); + assertEmptyQueue(queue); + + TestElement a = new TestElement(5); + TestElement b = new TestElement(10); + TestElement c = new TestElement(2); + TestElement d = new TestElement(7); + TestElement e = new TestElement(6); + + assertOffer(queue, a); + assertOffer(queue, b); + assertOffer(queue, c); + assertOffer(queue, d); + + // Remove the first element + assertSame(c, queue.peek()); + assertSame(c, queue.poll()); + assertEquals(3, queue.size()); + + // Test that offering another element preserves the priority queue semantics. 
+ assertOffer(queue, e); + assertEquals(4, queue.size()); + assertSame(a, queue.peek()); + assertSame(a, queue.poll()); + assertEquals(3, queue.size()); + + // Keep removing the remaining elements + assertSame(e, queue.peek()); + assertSame(e, queue.poll()); + assertEquals(2, queue.size()); + + assertSame(d, queue.peek()); + assertSame(d, queue.poll()); + assertEquals(1, queue.size()); + + assertSame(b, queue.peek()); + assertSame(b, queue.poll()); + assertEmptyQueue(queue); + } + + @Test + public void testClear() { + PriorityQueue queue = new DefaultPriorityQueue(TestElementComparator.INSTANCE, 0); + assertEmptyQueue(queue); + + TestElement a = new TestElement(5); + TestElement b = new TestElement(10); + TestElement c = new TestElement(2); + TestElement d = new TestElement(6); + + assertOffer(queue, a); + assertOffer(queue, b); + assertOffer(queue, c); + assertOffer(queue, d); + + queue.clear(); + assertEmptyQueue(queue); + + // Test that elements can be re-inserted after the clear operation + assertOffer(queue, a); + assertSame(a, queue.peek()); + + assertOffer(queue, b); + assertSame(a, queue.peek()); + + assertOffer(queue, c); + assertSame(c, queue.peek()); + + assertOffer(queue, d); + assertSame(c, queue.peek()); + } + + @Test + public void testClearIgnoringIndexes() { + PriorityQueue queue = new DefaultPriorityQueue(TestElementComparator.INSTANCE, 0); + assertEmptyQueue(queue); + + TestElement a = new TestElement(5); + TestElement b = new TestElement(10); + TestElement c = new TestElement(2); + TestElement d = new TestElement(6); + TestElement e = new TestElement(11); + + assertOffer(queue, a); + assertOffer(queue, b); + assertOffer(queue, c); + assertOffer(queue, d); + + queue.clearIgnoringIndexes(); + assertEmptyQueue(queue); + + // Elements cannot be re-inserted but new ones can. 
+ try { + queue.offer(a); + fail(); + } catch (IllegalArgumentException t) { + // expected + } + + assertOffer(queue, e); + assertSame(e, queue.peek()); + } + + @Test + public void testRemoval() { + testRemoval(false); + } + + @Test + public void testRemovalTyped() { + testRemoval(true); + } + + private static void testRemoval(boolean typed) { + PriorityQueue queue = new DefaultPriorityQueue(TestElementComparator.INSTANCE, 4); + assertEmptyQueue(queue); + + TestElement a = new TestElement(5); + TestElement b = new TestElement(10); + TestElement c = new TestElement(2); + TestElement d = new TestElement(6); + TestElement notInQueue = new TestElement(-1); + + assertOffer(queue, a); + assertOffer(queue, b); + assertOffer(queue, c); + assertOffer(queue, d); + + // Remove an element that isn't in the queue. + assertFalse(typed ? queue.removeTyped(notInQueue) : queue.remove(notInQueue)); + assertSame(c, queue.peek()); + assertEquals(4, queue.size()); + + // Remove the last element in the array, when the array is non-empty. + assertTrue(typed ? queue.removeTyped(b) : queue.remove(b)); + assertSame(c, queue.peek()); + assertEquals(3, queue.size()); + + // Re-insert the element after removal + assertOffer(queue, b); + assertSame(c, queue.peek()); + assertEquals(4, queue.size()); + + // Repeat remove the last element in the array, when the array is non-empty. + assertTrue(typed ? queue.removeTyped(b) : queue.remove(b)); + assertSame(c, queue.peek()); + assertEquals(3, queue.size()); + + // Remove the head of the queue. + assertTrue(typed ? queue.removeTyped(c) : queue.remove(c)); + assertSame(a, queue.peek()); + assertEquals(2, queue.size()); + + assertTrue(typed ? queue.removeTyped(a) : queue.remove(a)); + assertSame(d, queue.peek()); + assertEquals(1, queue.size()); + + assertTrue(typed ? 
queue.removeTyped(d) : queue.remove(d)); + assertEmptyQueue(queue); + } + + @Test + public void testZeroInitialSize() { + PriorityQueue queue = new DefaultPriorityQueue(TestElementComparator.INSTANCE, 0); + assertEmptyQueue(queue); + TestElement e = new TestElement(1); + assertOffer(queue, e); + assertSame(e, queue.peek()); + assertEquals(1, queue.size()); + assertFalse(queue.isEmpty()); + assertSame(e, queue.poll()); + assertEmptyQueue(queue); + } + + @Test + public void testPriorityChange() { + PriorityQueue queue = new DefaultPriorityQueue(TestElementComparator.INSTANCE, 0); + assertEmptyQueue(queue); + TestElement a = new TestElement(10); + TestElement b = new TestElement(20); + TestElement c = new TestElement(30); + TestElement d = new TestElement(25); + TestElement e = new TestElement(23); + TestElement f = new TestElement(15); + queue.add(a); + queue.add(b); + queue.add(c); + queue.add(d); + queue.add(e); + queue.add(f); + + e.value = 35; + queue.priorityChanged(e); + + a.value = 40; + queue.priorityChanged(a); + + a.value = 31; + queue.priorityChanged(a); + + d.value = 10; + queue.priorityChanged(d); + + f.value = 5; + queue.priorityChanged(f); + + List expectedOrderList = new ArrayList(queue.size()); + expectedOrderList.addAll(Arrays.asList(a, b, c, d, e, f)); + Collections.sort(expectedOrderList, TestElementComparator.INSTANCE); + + assertEquals(expectedOrderList.size(), queue.size()); + assertEquals(expectedOrderList.isEmpty(), queue.isEmpty()); + Iterator itr = expectedOrderList.iterator(); + while (itr.hasNext()) { + TestElement next = itr.next(); + TestElement poll = queue.poll(); + assertEquals(next, poll); + itr.remove(); + assertEquals(expectedOrderList.size(), queue.size()); + assertEquals(expectedOrderList.isEmpty(), queue.isEmpty()); + } + } + + private static void assertOffer(PriorityQueue queue, TestElement a) { + assertTrue(queue.offer(a)); + assertTrue(queue.contains(a)); + assertTrue(queue.containsTyped(a)); + try { // An element can not be 
inserted more than 1 time. + queue.offer(a); + fail(); + } catch (IllegalArgumentException ignored) { + // ignored + } + } + + private static void assertEmptyQueue(PriorityQueue queue) { + assertNull(queue.peek()); + assertNull(queue.poll()); + assertEquals(0, queue.size()); + assertTrue(queue.isEmpty()); + } + + private static final class TestElementComparator implements Comparator, Serializable { + private static final long serialVersionUID = 7930368853384760103L; + + static final TestElementComparator INSTANCE = new TestElementComparator(); + + private TestElementComparator() { + } + + @Override + public int compare(TestElement o1, TestElement o2) { + return o1.value - o2.value; + } + } + + private static final class TestElement implements PriorityQueueNode { + int value; + private int priorityQueueIndex = INDEX_NOT_IN_QUEUE; + + TestElement(int value) { + this.value = value; + } + + @Override + public boolean equals(Object o) { + return o instanceof TestElement && ((TestElement) o).value == value; + } + + @Override + public int hashCode() { + return value; + } + + @Override + public int priorityQueueIndex(DefaultPriorityQueue queue) { + return priorityQueueIndex; + } + + @Override + public void priorityQueueIndex(DefaultPriorityQueue queue, int i) { + priorityQueueIndex = i; + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/MacAddressUtilTest.java b/netty-util/src/test/java/io/netty/util/internal/MacAddressUtilTest.java new file mode 100644 index 0000000..f7a2541 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/MacAddressUtilTest.java @@ -0,0 +1,198 @@ +/* + * Copyright 2016 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static io.netty.util.internal.EmptyArrays.EMPTY_BYTES; +import static io.netty.util.internal.MacAddressUtil.parseMAC; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertThrows; + +public class MacAddressUtilTest { + @Test + public void testCompareAddresses() { + // should not prefer empty address when candidate is not globally unique + assertEquals( + 0, + MacAddressUtil.compareAddresses( + EMPTY_BYTES, + new byte[]{(byte) 0x52, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd})); + + // only candidate is globally unique + assertEquals( + -1, + MacAddressUtil.compareAddresses( + EMPTY_BYTES, + new byte[]{(byte) 0x50, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd})); + + // only candidate is globally unique + assertEquals( + -1, + MacAddressUtil.compareAddresses( + new byte[]{(byte) 0x52, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd}, + new byte[]{(byte) 0x50, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd})); + + // only current is globally unique + assertEquals( + 1, + MacAddressUtil.compareAddresses( + new byte[]{(byte) 0x52, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd}, + EMPTY_BYTES)); + + // only current is globally unique + assertEquals( + 1, + MacAddressUtil.compareAddresses( + new 
byte[]{(byte) 0x50, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd}, + new byte[]{(byte) 0x52, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd})); + + // both are globally unique + assertEquals( + 0, + MacAddressUtil.compareAddresses( + new byte[]{(byte) 0x50, (byte) 0x54, (byte) 0x00, (byte) 0xf9, (byte) 0x32, (byte) 0xbd}, + new byte[]{(byte) 0x50, (byte) 0x55, (byte) 0x01, (byte) 0xfa, (byte) 0x33, (byte) 0xbe})); + } + + @Test + public void testParseMacEUI48() { + assertArrayEquals(new byte[]{0, (byte) 0xaa, 0x11, (byte) 0xbb, 0x22, (byte) 0xcc}, + parseMAC("00-AA-11-BB-22-CC")); + assertArrayEquals(new byte[]{0, (byte) 0xaa, 0x11, (byte) 0xbb, 0x22, (byte) 0xcc}, + parseMAC("00:AA:11:BB:22:CC")); + } + + @Test + public void testParseMacMAC48ToEUI64() { + // MAC-48 into an EUI-64 + assertArrayEquals(new byte[]{0, (byte) 0xaa, 0x11, (byte) 0xff, (byte) 0xff, (byte) 0xbb, 0x22, (byte) 0xcc}, + parseMAC("00-AA-11-FF-FF-BB-22-CC")); + assertArrayEquals(new byte[]{0, (byte) 0xaa, 0x11, (byte) 0xff, (byte) 0xff, (byte) 0xbb, 0x22, (byte) 0xcc}, + parseMAC("00:AA:11:FF:FF:BB:22:CC")); + } + + @Test + public void testParseMacEUI48ToEUI64() { + // EUI-48 into an EUI-64 + assertArrayEquals(new byte[]{0, (byte) 0xaa, 0x11, (byte) 0xff, (byte) 0xfe, (byte) 0xbb, 0x22, (byte) 0xcc}, + parseMAC("00-AA-11-FF-FE-BB-22-CC")); + assertArrayEquals(new byte[]{0, (byte) 0xaa, 0x11, (byte) 0xff, (byte) 0xfe, (byte) 0xbb, 0x22, (byte) 0xcc}, + parseMAC("00:AA:11:FF:FE:BB:22:CC")); + } + + @Test + public void testParseMacInvalid7HexGroupsA() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00-AA-11-BB-22-CC-FF"); + } + }); + } + + @Test + public void testParseMacInvalid7HexGroupsB() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00:AA:11:BB:22:CC:FF"); + } + }); + } + + @Test + public void 
testParseMacInvalidEUI48MixedSeparatorA() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00-AA:11-BB-22-CC"); + } + }); + } + + @Test + public void testParseMacInvalidEUI48MixedSeparatorB() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00:AA-11:BB:22:CC"); + } + }); + } + + @Test + public void testParseMacInvalidEUI64MixedSeparatorA() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00-AA-11-FF-FE-BB-22:CC"); + } + }); + } + + @Test + public void testParseMacInvalidEUI64MixedSeparatorB() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00:AA:11:FF:FE:BB:22-CC"); + } + }); + } + + @Test + public void testParseMacInvalidEUI48TrailingSeparatorA() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00-AA-11-BB-22-CC-"); + } + }); + } + + @Test + public void testParseMacInvalidEUI48TrailingSeparatorB() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00:AA:11:BB:22:CC:"); + } + }); + } + + @Test + public void testParseMacInvalidEUI64TrailingSeparatorA() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00-AA-11-FF-FE-BB-22-CC-"); + } + }); + } + + @Test + public void testParseMacInvalidEUI64TrailingSeparatorB() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + parseMAC("00:AA:11:FF:FE:BB:22:CC:"); + } + }); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/MathUtilTest.java b/netty-util/src/test/java/io/netty/util/internal/MathUtilTest.java new file mode 100644 index 0000000..9b6370d --- /dev/null +++ 
b/netty-util/src/test/java/io/netty/util/internal/MathUtilTest.java @@ -0,0 +1,91 @@ +/* + * Copyright 2020 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static io.netty.util.internal.MathUtil.*; + +import org.junit.jupiter.api.Test; + +public class MathUtilTest { + + @Test + public void testFindNextPositivePowerOfTwo() { + assertEquals(1, findNextPositivePowerOfTwo(0)); + assertEquals(1, findNextPositivePowerOfTwo(1)); + assertEquals(1024, findNextPositivePowerOfTwo(1000)); + assertEquals(1024, findNextPositivePowerOfTwo(1023)); + assertEquals(2048, findNextPositivePowerOfTwo(2048)); + assertEquals(1 << 30, findNextPositivePowerOfTwo((1 << 30) - 1)); + assertEquals(1, findNextPositivePowerOfTwo(-1)); + assertEquals(1, findNextPositivePowerOfTwo(-10000)); + } + + @Test + public void testSafeFindNextPositivePowerOfTwo() { + assertEquals(1, safeFindNextPositivePowerOfTwo(0)); + assertEquals(1, safeFindNextPositivePowerOfTwo(1)); + assertEquals(1024, safeFindNextPositivePowerOfTwo(1000)); + assertEquals(1024, safeFindNextPositivePowerOfTwo(1023)); + assertEquals(2048, safeFindNextPositivePowerOfTwo(2048)); + assertEquals(1 << 30, safeFindNextPositivePowerOfTwo((1 << 30) - 1)); 
+ assertEquals(1, safeFindNextPositivePowerOfTwo(-1)); + assertEquals(1, safeFindNextPositivePowerOfTwo(-10000)); + assertEquals(1 << 30, safeFindNextPositivePowerOfTwo(Integer.MAX_VALUE)); + assertEquals(1 << 30, safeFindNextPositivePowerOfTwo((1 << 30) + 1)); + assertEquals(1, safeFindNextPositivePowerOfTwo(Integer.MIN_VALUE)); + assertEquals(1, safeFindNextPositivePowerOfTwo(Integer.MIN_VALUE + 1)); + } + + @Test + public void testIsOutOfBounds() { + assertFalse(isOutOfBounds(0, 0, 0)); + assertFalse(isOutOfBounds(0, 0, 1)); + assertFalse(isOutOfBounds(0, 1, 1)); + assertTrue(isOutOfBounds(1, 1, 1)); + assertTrue(isOutOfBounds(Integer.MAX_VALUE, 1, 1)); + assertTrue(isOutOfBounds(Integer.MAX_VALUE, Integer.MAX_VALUE, 1)); + assertTrue(isOutOfBounds(Integer.MAX_VALUE, Integer.MAX_VALUE, Integer.MAX_VALUE)); + assertFalse(isOutOfBounds(0, Integer.MAX_VALUE, Integer.MAX_VALUE)); + assertFalse(isOutOfBounds(0, Integer.MAX_VALUE - 1, Integer.MAX_VALUE)); + assertTrue(isOutOfBounds(0, Integer.MAX_VALUE, Integer.MAX_VALUE - 1)); + assertFalse(isOutOfBounds(Integer.MAX_VALUE - 1, 1, Integer.MAX_VALUE)); + assertTrue(isOutOfBounds(Integer.MAX_VALUE - 1, 1, Integer.MAX_VALUE - 1)); + assertTrue(isOutOfBounds(Integer.MAX_VALUE - 1, 2, Integer.MAX_VALUE)); + assertTrue(isOutOfBounds(1, Integer.MAX_VALUE, Integer.MAX_VALUE)); + assertTrue(isOutOfBounds(0, 1, Integer.MIN_VALUE)); + assertTrue(isOutOfBounds(0, 1, -1)); + assertTrue(isOutOfBounds(0, Integer.MAX_VALUE, 0)); + } + + @Test + public void testCompare() { + assertEquals(-1, compare(0, 1)); + assertEquals(-1, compare(0L, 1L)); + assertEquals(-1, compare(0, Integer.MAX_VALUE)); + assertEquals(-1, compare(0L, Long.MAX_VALUE)); + assertEquals(0, compare(0, 0)); + assertEquals(0, compare(0L, 0L)); + assertEquals(0, compare(Integer.MIN_VALUE, Integer.MIN_VALUE)); + assertEquals(0, compare(Long.MIN_VALUE, Long.MIN_VALUE)); + assertEquals(1, compare(Integer.MAX_VALUE, 0)); + assertEquals(1, compare(Integer.MAX_VALUE, 
Integer.MAX_VALUE - 1)); + assertEquals(1, compare(Long.MAX_VALUE, 0L)); + assertEquals(1, compare(Long.MAX_VALUE, Long.MAX_VALUE - 1)); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java b/netty-util/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java new file mode 100644 index 0000000..03b66f6 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/NativeLibraryLoaderTest.java @@ -0,0 +1,127 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.condition.EnabledIf; +import org.junit.jupiter.api.condition.EnabledOnOs; +import org.junit.jupiter.api.function.Executable; + +import java.io.File; +import java.io.FileNotFoundException; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; +import java.util.UUID; + +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.junit.jupiter.api.Assertions.fail; +import static org.junit.jupiter.api.condition.OS.LINUX; + +class NativeLibraryLoaderTest { + + private static final String OS_ARCH = System.getProperty("os.arch"); + private boolean is_x86_64() { + return "x86_64".equals(OS_ARCH) || "amd64".equals(OS_ARCH); + } + + @Test + void testFileNotFound() { + try { + NativeLibraryLoader.load(UUID.randomUUID().toString(), NativeLibraryLoaderTest.class.getClassLoader()); + fail(); + } catch (UnsatisfiedLinkError error) { + assertTrue(error.getCause() instanceof FileNotFoundException); + if (PlatformDependent.javaVersion() >= 7) { + verifySuppressedException(error, UnsatisfiedLinkError.class); + } + } + } + + @Test + void testFileNotFoundWithNullClassLoader() { + try { + NativeLibraryLoader.load(UUID.randomUUID().toString(), null); + fail(); + } catch (UnsatisfiedLinkError error) { + assertTrue(error.getCause() instanceof FileNotFoundException); + if (PlatformDependent.javaVersion() >= 7) { + verifySuppressedException(error, ClassNotFoundException.class); + } + } + } + + @Test + @EnabledOnOs(LINUX) + @EnabledIf("is_x86_64") + void testMultipleResourcesWithSameContentInTheClassLoader() throws MalformedURLException { + URL url1 = new File("src/test/data/NativeLibraryLoader/1").toURI().toURL(); + URL url2 = new File("src/test/data/NativeLibraryLoader/2").toURI().toURL(); + final URLClassLoader loader = new URLClassLoader(new URL[] {url1, url2}); + 
final String resourceName = "test3"; + + NativeLibraryLoader.load(resourceName, loader); + assertTrue(true); + } + + @Test + @EnabledOnOs(LINUX) + @EnabledIf("is_x86_64") + void testMultipleResourcesInTheClassLoader() throws MalformedURLException { + URL url1 = new File("src/test/data/NativeLibraryLoader/1").toURI().toURL(); + URL url2 = new File("src/test/data/NativeLibraryLoader/2").toURI().toURL(); + final URLClassLoader loader = new URLClassLoader(new URL[] {url1, url2}); + final String resourceName = "test1"; + + Exception ise = assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + NativeLibraryLoader.load(resourceName, loader); + } + }); + assertTrue(ise.getMessage() + .contains("Multiple resources found for 'META-INF/native/lib" + resourceName + ".so'")); + } + + @Test + @EnabledOnOs(LINUX) + @EnabledIf("is_x86_64") + void testSingleResourceInTheClassLoader() throws MalformedURLException { + URL url1 = new File("src/test/data/NativeLibraryLoader/1").toURI().toURL(); + URL url2 = new File("src/test/data/NativeLibraryLoader/2").toURI().toURL(); + URLClassLoader loader = new URLClassLoader(new URL[] {url1, url2}); + String resourceName = "test2"; + + NativeLibraryLoader.load(resourceName, loader); + assertTrue(true); + } + + @SuppressJava6Requirement(reason = "uses Java 7+ Throwable#getSuppressed but is guarded by version checks") + private static void verifySuppressedException(UnsatisfiedLinkError error, + Class expectedSuppressedExceptionClass) { + try { + Throwable[] suppressed = error.getCause().getSuppressed(); + assertTrue(suppressed.length == 1); + assertTrue(suppressed[0] instanceof UnsatisfiedLinkError); + suppressed = (suppressed[0]).getSuppressed(); + assertTrue(suppressed.length == 1); + assertTrue(expectedSuppressedExceptionClass.isInstance(suppressed[0])); + } catch (Exception e) { + throw new RuntimeException(e); + } + } +} diff --git 
a/netty-util/src/test/java/io/netty/util/internal/ObjectCleanerTest.java b/netty-util/src/test/java/io/netty/util/internal/ObjectCleanerTest.java new file mode 100644 index 0000000..4c56230 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/ObjectCleanerTest.java @@ -0,0 +1,142 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.Timeout; + +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ObjectCleanerTest { + + private Thread temporaryThread; + private Object temporaryObject; + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testCleanup() throws Exception { + final AtomicBoolean freeCalled = new AtomicBoolean(); + final CountDownLatch latch = new CountDownLatch(1); + temporaryThread = new Thread(new Runnable() { + @Override + public void run() { + try { + latch.await(); + } catch (InterruptedException ignore) { + // just ignore + } + } + 
}); + temporaryThread.start(); + ObjectCleaner.register(temporaryThread, new Runnable() { + @Override + public void run() { + freeCalled.set(true); + } + }); + + latch.countDown(); + temporaryThread.join(); + assertFalse(freeCalled.get()); + + // Null out the temporary object to ensure it is enqueued for GC. + temporaryThread = null; + + while (!freeCalled.get()) { + System.gc(); + System.runFinalization(); + Thread.sleep(100); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testCleanupContinuesDespiteThrowing() throws InterruptedException { + final AtomicInteger freeCalledCount = new AtomicInteger(); + final CountDownLatch latch = new CountDownLatch(1); + temporaryThread = new Thread(new Runnable() { + @Override + public void run() { + try { + latch.await(); + } catch (InterruptedException ignore) { + // just ignore + } + } + }); + temporaryThread.start(); + temporaryObject = new Object(); + ObjectCleaner.register(temporaryThread, new Runnable() { + @Override + public void run() { + freeCalledCount.incrementAndGet(); + throw new RuntimeException("expected"); + } + }); + ObjectCleaner.register(temporaryObject, new Runnable() { + @Override + public void run() { + freeCalledCount.incrementAndGet(); + throw new RuntimeException("expected"); + } + }); + + latch.countDown(); + temporaryThread.join(); + assertEquals(0, freeCalledCount.get()); + + // Null out the temporary object to ensure it is enqueued for GC. 
+ temporaryThread = null; + temporaryObject = null; + + while (freeCalledCount.get() != 2) { + System.gc(); + System.runFinalization(); + Thread.sleep(100); + } + } + + @Test + @Timeout(value = 5000, unit = TimeUnit.MILLISECONDS) + public void testCleanerThreadIsDaemon() throws Exception { + temporaryObject = new Object(); + ObjectCleaner.register(temporaryObject, new Runnable() { + @Override + public void run() { + // NOOP + } + }); + + Thread cleanerThread = null; + + for (Thread thread : Thread.getAllStackTraces().keySet()) { + if (thread.getName().equals(ObjectCleaner.CLEANER_THREAD_NAME)) { + cleanerThread = thread; + break; + } + } + assertNotNull(cleanerThread); + assertTrue(cleanerThread.isDaemon()); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/ObjectUtilTest.java b/netty-util/src/test/java/io/netty/util/internal/ObjectUtilTest.java new file mode 100644 index 0000000..8e0dd22 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/ObjectUtilTest.java @@ -0,0 +1,595 @@ +/* + * Copyright 2021 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNull; +import static org.junit.jupiter.api.Assertions.assertTrue; + +/** + * Testcases for io.netty.util.internal.ObjectUtil. 
+ * + * The tests for exceptions do not use a fail mimic. The tests evaluate the + * presence and type, to have really regression character. + * + */ +public class ObjectUtilTest { + + private static final Object NULL_OBJECT = null; + + private static final Object NON_NULL_OBJECT = "Object is not null"; + private static final String NON_NULL_EMPTY_STRING = ""; + private static final String NON_NULL_WHITESPACE_STRING = " "; + private static final Object[] NON_NULL_EMPTY_OBJECT_ARRAY = {}; + private static final Object[] NON_NULL_FILLED_OBJECT_ARRAY = { NON_NULL_OBJECT }; + private static final CharSequence NULL_CHARSEQUENCE = (CharSequence) NULL_OBJECT; + private static final CharSequence NON_NULL_CHARSEQUENCE = (CharSequence) NON_NULL_OBJECT; + private static final CharSequence NON_NULL_EMPTY_CHARSEQUENCE = (CharSequence) NON_NULL_EMPTY_STRING; + private static final byte[] NON_NULL_EMPTY_BYTE_ARRAY = {}; + private static final byte[] NON_NULL_FILLED_BYTE_ARRAY = { (byte) 0xa }; + private static final char[] NON_NULL_EMPTY_CHAR_ARRAY = {}; + private static final char[] NON_NULL_FILLED_CHAR_ARRAY = { 'A' }; + + private static final String NULL_NAME = "IS_NULL"; + private static final String NON_NULL_NAME = "NOT_NULL"; + private static final String NON_NULL_EMPTY_NAME = "NOT_NULL_BUT_EMPTY"; + + private static final String TEST_RESULT_NULLEX_OK = "Expected a NPE/IAE"; + private static final String TEST_RESULT_NULLEX_NOK = "Expected no exception"; + private static final String TEST_RESULT_EXTYPE_NOK = "Expected type not found"; + + private static final int ZERO_INT = 0; + private static final long ZERO_LONG = 0; + private static final double ZERO_DOUBLE = 0.0d; + private static final float ZERO_FLOAT = 0.0f; + + private static final int POS_ONE_INT = 1; + private static final long POS_ONE_LONG = 1; + private static final double POS_ONE_DOUBLE = 1.0d; + private static final float POS_ONE_FLOAT = 1.0f; + + private static final int NEG_ONE_INT = -1; + private static 
final long NEG_ONE_LONG = -1; + private static final double NEG_ONE_DOUBLE = -1.0d; + private static final float NEG_ONE_FLOAT = -1.0f; + + private static final String NUM_POS_NAME = "NUMBER_POSITIVE"; + private static final String NUM_ZERO_NAME = "NUMBER_ZERO"; + private static final String NUM_NEG_NAME = "NUMBER_NEGATIVE"; + + @Test + public void testCheckNotNull() { + Exception actualEx = null; + try { + ObjectUtil.checkNotNull(NON_NULL_OBJECT, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNotNull(NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNotNullWithIAE() { + Exception actualEx = null; + try { + ObjectUtil.checkNotNullWithIAE(NON_NULL_OBJECT, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNotNullWithIAE(NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNotNullArrayParam() { + Exception actualEx = null; + try { + ObjectUtil.checkNotNullArrayParam(NON_NULL_OBJECT, 1, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNotNullArrayParam(NULL_OBJECT, 1, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveIntString() { + Exception actualEx = null; + try { + 
ObjectUtil.checkPositive(POS_ONE_INT, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(ZERO_INT, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(NEG_ONE_INT, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveLongString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositive(POS_ONE_LONG, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(ZERO_LONG, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(NEG_ONE_LONG, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveDoubleString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositive(POS_ONE_DOUBLE, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(ZERO_DOUBLE, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + 
ObjectUtil.checkPositive(NEG_ONE_DOUBLE, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveFloatString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositive(POS_ONE_FLOAT, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(ZERO_FLOAT, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositive(NEG_ONE_FLOAT, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveOrZeroIntString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(POS_ONE_INT, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(ZERO_INT, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(NEG_ONE_INT, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveOrZeroLongString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(POS_ONE_LONG, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = 
null; + try { + ObjectUtil.checkPositiveOrZero(ZERO_LONG, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(NEG_ONE_LONG, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveOrZeroDoubleString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(POS_ONE_DOUBLE, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(ZERO_DOUBLE, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(NEG_ONE_DOUBLE, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckPositiveOrZeroFloatString() { + Exception actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(POS_ONE_FLOAT, NUM_POS_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(ZERO_FLOAT, NUM_ZERO_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkPositiveOrZero(NEG_ONE_FLOAT, NUM_NEG_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNonEmptyTArrayString() { + Exception actualEx = null; + + try { + 
ObjectUtil.checkNonEmpty((Object[]) NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((Object[]) NON_NULL_FILLED_OBJECT_ARRAY, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((Object[]) NON_NULL_EMPTY_OBJECT_ARRAY, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNonEmptyByteArrayString() { + Exception actualEx = null; + + try { + ObjectUtil.checkNonEmpty((byte[]) NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((byte[]) NON_NULL_FILLED_BYTE_ARRAY, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((byte[]) NON_NULL_EMPTY_BYTE_ARRAY, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNonEmptyCharArrayString() { + Exception actualEx = null; + + try { + ObjectUtil.checkNonEmpty((char[]) NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((char[]) 
NON_NULL_FILLED_CHAR_ARRAY, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((char[]) NON_NULL_EMPTY_CHAR_ARRAY, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNonEmptyTString() { + Exception actualEx = null; + try { + ObjectUtil.checkNonEmpty((Object[]) NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((Object[]) NON_NULL_FILLED_OBJECT_ARRAY, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((Object[]) NON_NULL_EMPTY_OBJECT_ARRAY, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } + + @Test + public void testCheckNonEmptyStringString() { + Exception actualEx = null; + + try { + ObjectUtil.checkNonEmpty((String) NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((String) NON_NULL_OBJECT, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((String) NON_NULL_EMPTY_STRING, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, 
TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((String) NON_NULL_WHITESPACE_STRING, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + } + + @Test + public void testCheckNonEmptyCharSequenceString() { + Exception actualEx = null; + + try { + ObjectUtil.checkNonEmpty((CharSequence) NULL_CHARSEQUENCE, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((CharSequence) NON_NULL_CHARSEQUENCE, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((CharSequence) NON_NULL_EMPTY_CHARSEQUENCE, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmpty((CharSequence) NON_NULL_WHITESPACE_STRING, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + } + + @Test + public void testCheckNonEmptyAfterTrim() { + Exception actualEx = null; + + try { + ObjectUtil.checkNonEmptyAfterTrim((String) NULL_OBJECT, NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof NullPointerException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmptyAfterTrim((String) NON_NULL_OBJECT, NON_NULL_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNull(actualEx, TEST_RESULT_NULLEX_NOK); + + actualEx = null; + try { + 
ObjectUtil.checkNonEmptyAfterTrim(NON_NULL_EMPTY_STRING, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + + actualEx = null; + try { + ObjectUtil.checkNonEmptyAfterTrim(NON_NULL_WHITESPACE_STRING, NON_NULL_EMPTY_NAME); + } catch (Exception e) { + actualEx = e; + } + assertNotNull(actualEx, TEST_RESULT_NULLEX_OK); + assertTrue(actualEx instanceof IllegalArgumentException, TEST_RESULT_EXTYPE_NOK); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/OsClassifiersTest.java b/netty-util/src/test/java/io/netty/util/internal/OsClassifiersTest.java new file mode 100644 index 0000000..903f5cb --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/OsClassifiersTest.java @@ -0,0 +1,121 @@ +/* + * Copyright 2022 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, version 2.0 (the + * "License"); you may not use this file except in compliance with the License. You may obtain a + * copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License + * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express + * or implied. See the License for the specific language governing permissions and limitations under + * the License. 
+ */ + +package io.netty.util.internal; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Arrays; +import java.util.Collections; +import java.util.HashSet; +import java.util.LinkedHashSet; +import java.util.Properties; +import java.util.Set; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class OsClassifiersTest { + private static final String OS_CLASSIFIERS_PROPERTY = "io.netty.osClassifiers"; + + private Properties systemProperties; + + @BeforeEach + void setUp() { + systemProperties = System.getProperties(); + } + + @AfterEach + void tearDown() { + systemProperties.remove(OS_CLASSIFIERS_PROPERTY); + } + + @Test + void testOsClassifiersPropertyAbsent() { + Set allowed = new HashSet(Arrays.asList("fedora", "suse", "arch")); + Set available = new LinkedHashSet(2); + boolean added = PlatformDependent.addPropertyOsClassifiers(allowed, available); + assertFalse(added); + assertTrue(available.isEmpty()); + } + + @Test + void testOsClassifiersPropertyEmpty() { + // empty property -Dio.netty.osClassifiers + systemProperties.setProperty(OS_CLASSIFIERS_PROPERTY, ""); + Set allowed = Collections.singleton("fedora"); + Set available = new LinkedHashSet(2); + boolean added = PlatformDependent.addPropertyOsClassifiers(allowed, available); + assertTrue(added); + assertTrue(available.isEmpty()); + } + + @Test + void testOsClassifiersPropertyNotEmptyNoClassifiers() { + // ID + systemProperties.setProperty(OS_CLASSIFIERS_PROPERTY, ","); + final Set allowed = new HashSet(Arrays.asList("fedora", "suse", "arch")); + final Set available = new LinkedHashSet(2); + Assertions.assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public 
void execute() throws Throwable { + PlatformDependent.addPropertyOsClassifiers(allowed, available); + } + }); + } + + @Test + void testOsClassifiersPropertySingle() { + // ID + systemProperties.setProperty(OS_CLASSIFIERS_PROPERTY, "fedora"); + Set allowed = Collections.singleton("fedora"); + Set available = new LinkedHashSet(2); + boolean added = PlatformDependent.addPropertyOsClassifiers(allowed, available); + assertTrue(added); + assertEquals(1, available.size()); + assertEquals("fedora", available.iterator().next()); + } + + @Test + void testOsClassifiersPropertyPair() { + // ID, ID_LIKE + systemProperties.setProperty(OS_CLASSIFIERS_PROPERTY, "manjaro,arch"); + Set allowed = new HashSet(Arrays.asList("fedora", "suse", "arch")); + Set available = new LinkedHashSet(2); + boolean added = PlatformDependent.addPropertyOsClassifiers(allowed, available); + assertTrue(added); + assertEquals(1, available.size()); + assertEquals("arch", available.iterator().next()); + } + + @Test + void testOsClassifiersPropertyExcessive() { + // ID, ID_LIKE, excessive + systemProperties.setProperty(OS_CLASSIFIERS_PROPERTY, "manjaro,arch,slackware"); + final Set allowed = new HashSet(Arrays.asList("fedora", "suse", "arch")); + final Set available = new LinkedHashSet(2); + Assertions.assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() throws Throwable { + PlatformDependent.addPropertyOsClassifiers(allowed, available); + } + }); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/PlatformDependent0Test.java b/netty-util/src/test/java/io/netty/util/internal/PlatformDependent0Test.java new file mode 100644 index 0000000..8899fa0 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/PlatformDependent0Test.java @@ -0,0 +1,94 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in 
/**
 * Tests for {@code PlatformDependent0}, the low-level Unsafe access layer.
 * The whole class is skipped (JUnit assumption) when the JVM exposes neither
 * {@code sun.misc.Unsafe} nor the no-cleaner {@code DirectByteBuffer(long, int)}
 * constructor.
 */
public class PlatformDependent0Test {

    @BeforeAll
    public static void assumeUnsafe() {
        // Skip all tests on JVMs where the internals under test are unavailable.
        assumeTrue(PlatformDependent0.hasUnsafe());
        assumeTrue(PlatformDependent0.hasDirectBufferNoCleanerConstructor());
    }

    @Test
    public void testNewDirectBufferNegativeMemoryAddress() {
        testNewDirectBufferMemoryAddress(-1);
    }

    @Test
    public void testNewDirectBufferNonNegativeMemoryAddress() {
        testNewDirectBufferMemoryAddress(10);
    }

    @Test
    public void testNewDirectBufferZeroMemoryAddress() {
        // Only verifies that wrapping address 0 does not throw; there is no
        // assertion on the resulting buffer.
        PlatformDependent0.newDirectBuffer(0, 10);
    }

    /**
     * Wraps an arbitrary memory address and verifies the resulting buffer
     * reports the same address and capacity back. The buffer contents are
     * never accessed, so bogus addresses are safe here.
     */
    private static void testNewDirectBufferMemoryAddress(long address) {
        assumeTrue(PlatformDependent0.hasDirectBufferNoCleanerConstructor());

        int capacity = 10;
        ByteBuffer buffer = PlatformDependent0.newDirectBuffer(address, capacity);
        assertEquals(address, PlatformDependent0.directBufferAddress(buffer));
        assertEquals(capacity, buffer.capacity());
    }

    @Test
    public void testMajorVersionFromJavaSpecificationVersion() {
        // NOTE(review): SecurityManager is deprecated for removal since JDK 17;
        // this test may need a rewrite when running on newer JVMs — confirm.
        final SecurityManager current = System.getSecurityManager();

        try {
            // Install a manager that denies reading java.specification.version
            // but permits everything else (so it can be restored afterwards).
            System.setSecurityManager(new SecurityManager() {
                @Override
                public void checkPropertyAccess(String key) {
                    if (key.equals("java.specification.version")) {
                        // deny
                        throw new SecurityException(key);
                    }
                }

                // so we can restore the security manager
                @Override
                public void checkPermission(Permission perm) {
                }
            });

            // When the version property cannot be read, the implementation is
            // expected to fall back to reporting Java 6.
            assertEquals(6, PlatformDependent0.majorVersionFromJavaSpecificationVersion());
        } finally {
            System.setSecurityManager(current);
        }
    }

    @Test
    public void testMajorVersion() {
        // Both the legacy "1.x" scheme and the post-JEP-223 plain "x" scheme
        // must map to the same major version number.
        assertEquals(6, PlatformDependent0.majorVersion("1.6"));
        assertEquals(7, PlatformDependent0.majorVersion("1.7"));
        assertEquals(8, PlatformDependent0.majorVersion("1.8"));
        assertEquals(8, PlatformDependent0.majorVersion("8"));
        assertEquals(9, PlatformDependent0.majorVersion("1.9")); // early version of JDK 9 before Project Verona
        assertEquals(9, PlatformDependent0.majorVersion("9"));
    }
}
/**
 * Tests for {@code PlatformDependent} byte-array helpers: plain equality,
 * constant-time equality, zero-range detection, ASCII hashing, and
 * zero-capacity no-cleaner direct buffer allocation.
 */
public class PlatformDependentTest {
    // Shared RNG for randomized fixtures; no fixed seed, cases below cover
    // the deterministic boundaries explicitly.
    private static final Random r = new Random();
    @Test
    public void testEqualsConsistentTime() {
        // equalsConstantTime returns an int (0 == not equal); adapt it to the
        // boolean contract so it can share the fixture with testEquals.
        testEquals(new EqualityChecker() {
            @Override
            public boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) {
                return PlatformDependent.equalsConstantTime(bytes1, startPos1, bytes2, startPos2, length) != 0;
            }
        });
    }

    @Test
    public void testEquals() {
        testEquals(new EqualityChecker() {
            @Override
            public boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length) {
                return PlatformDependent.equals(bytes1, startPos1, bytes2, startPos2, length);
            }
        });
    }

    @Test
    public void testIsZero() {
        byte[] bytes = new byte[100];
        // zero-length and negative-length ranges are vacuously all-zero
        assertTrue(PlatformDependent.isZero(bytes, 0, 0));
        assertTrue(PlatformDependent.isZero(bytes, 0, -1));
        assertTrue(PlatformDependent.isZero(bytes, 0, 100));
        assertTrue(PlatformDependent.isZero(bytes, 10, 90));
        bytes[10] = 1;
        // ranges ending before / starting after index 10 remain zero
        assertTrue(PlatformDependent.isZero(bytes, 0, 10));
        assertFalse(PlatformDependent.isZero(bytes, 0, 11));
        assertFalse(PlatformDependent.isZero(bytes, 10, 1));
        assertTrue(PlatformDependent.isZero(bytes, 11, 89));
    }

    /** Abstraction over the two equality implementations under test. */
    private interface EqualityChecker {
        boolean equals(byte[] bytes1, int startPos1, byte[] bytes2, int startPos2, int length);
    }

    /**
     * Shared equality fixture: identical arrays, different lengths, offset
     * comparisons, single-byte mismatches, empty arrays, a 100-byte array
     * with one cleared byte, and random round trips of every length < 1000.
     */
    private static void testEquals(EqualityChecker equalsChecker) {
        byte[] bytes1 = {'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd'};
        byte[] bytes2 = {'H', 'e', 'l', 'l', 'o', ' ', 'W', 'o', 'r', 'l', 'd'};
        assertNotSame(bytes1, bytes2);
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes1.length));
        assertTrue(equalsChecker.equals(bytes1, 2, bytes2, 2, bytes1.length - 2));

        // mismatched offsets vs. shared prefix of different-length arrays
        bytes1 = new byte[] {1, 2, 3, 4, 5, 6};
        bytes2 = new byte[] {1, 2, 3, 4, 5, 6, 7};
        assertNotSame(bytes1, bytes2);
        assertFalse(equalsChecker.equals(bytes1, 0, bytes2, 1, bytes1.length));
        assertTrue(equalsChecker.equals(bytes2, 0, bytes1, 0, bytes1.length));

        // mismatch in the last byte only
        bytes1 = new byte[] {1, 2, 3, 4};
        bytes2 = new byte[] {1, 2, 3, 5};
        assertFalse(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes1.length));
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, 3));

        // mismatch in the second byte only
        bytes1 = new byte[] {1, 2, 3, 4};
        bytes2 = new byte[] {1, 3, 3, 4};
        assertFalse(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes1.length));
        assertTrue(equalsChecker.equals(bytes1, 2, bytes2, 2, bytes1.length - 2));

        // two distinct empty arrays compare equal
        bytes1 = new byte[0];
        bytes2 = new byte[0];
        assertNotSame(bytes1, bytes2);
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, 0));

        // large arrays with one byte cleared in the middle
        bytes1 = new byte[100];
        bytes2 = new byte[100];
        for (int i = 0; i < 100; i++) {
            bytes1[i] = (byte) i;
            bytes2[i] = (byte) i;
        }
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes1.length));
        bytes1[50] = 0;
        assertFalse(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes1.length));
        assertTrue(equalsChecker.equals(bytes1, 51, bytes2, 51, bytes1.length - 51));
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, 50));

        // suffix of one array compared against a shorter array
        bytes1 = new byte[]{1, 2, 3, 4, 5};
        bytes2 = new byte[]{3, 4, 5};
        assertFalse(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes2.length));
        assertTrue(equalsChecker.equals(bytes1, 2, bytes2, 0, bytes2.length));
        assertTrue(equalsChecker.equals(bytes2, 0, bytes1, 2, bytes2.length));

        // random contents round-trip for every length in [0, 1000)
        for (int i = 0; i < 1000; ++i) {
            bytes1 = new byte[i];
            r.nextBytes(bytes1);
            bytes2 = bytes1.clone();
            assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, bytes1.length));
        }

        // zero/negative length ranges are equal regardless of contents
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, 0));
        assertTrue(equalsChecker.equals(bytes1, 0, bytes2, 0, -1));
    }

    // Random char in [0, 255] so it round-trips through a single byte.
    private static char randomCharInByteRange() {
        return (char) r.nextInt(255 + 1);
    }

    @Test
    public void testHashCodeAscii() {
        for (int i = 0; i < 1000; ++i) {
            // byte[] and char[] need to be initialized such that there values are within valid "ascii" range
            byte[] bytes = new byte[i];
            char[] bytesChar = new char[i];
            for (int j = 0; j < bytesChar.length; ++j) {
                bytesChar[j] = randomCharInByteRange();
                bytes[j] = (byte) (bytesChar[j] & 0xff);
            }
            String string = new String(bytesChar);
            // the (possibly Unsafe-accelerated) byte[] hash must agree with the
            // safe fallback, and the CharSequence overload must agree with both
            assertEquals(hashCodeAsciiSafe(bytes, 0, bytes.length),
                    hashCodeAscii(bytes, 0, bytes.length),
                    "length=" + i);
            assertEquals(hashCodeAscii(bytes, 0, bytes.length),
                    hashCodeAscii(string),
                    "length=" + i);
        }
    }

    @Test
    public void testAllocateWithCapacity0() {
        assumeTrue(PlatformDependent.hasDirectBufferNoCleanerConstructor());
        ByteBuffer buffer = PlatformDependent.allocateDirectNoCleaner(0);
        // even a zero-capacity buffer must carry a real (non-zero) address
        assertNotEquals(0, PlatformDependent.directBufferAddress(buffer));
        assertEquals(0, buffer.capacity());
        PlatformDependent.freeDirectNoCleaner(buffer);
    }
}
You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Arrays; +import java.util.Collections; + +import static io.netty.util.internal.StringUtil.NEWLINE; +import static io.netty.util.internal.StringUtil.commonSuffixOfLength; +import static io.netty.util.internal.StringUtil.indexOfWhiteSpace; +import static io.netty.util.internal.StringUtil.indexOfNonWhiteSpace; +import static io.netty.util.internal.StringUtil.isNullOrEmpty; +import static io.netty.util.internal.StringUtil.simpleClassName; +import static io.netty.util.internal.StringUtil.substringAfter; +import static io.netty.util.internal.StringUtil.toHexString; +import static io.netty.util.internal.StringUtil.toHexStringPadded; +import static io.netty.util.internal.StringUtil.unescapeCsv; +import static io.netty.util.internal.StringUtil.unescapeCsvFields; +import static org.hamcrest.CoreMatchers.is; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.junit.jupiter.api.Assertions.assertArrayEquals; +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class StringUtilTest { + + @Test + public void ensureNewlineExists() { + assertNotNull(NEWLINE); + } + + @Test + public 
void testToHexString() { + assertThat(toHexString(new byte[] { 0 }), is("0")); + assertThat(toHexString(new byte[] { 1 }), is("1")); + assertThat(toHexString(new byte[] { 0, 0 }), is("0")); + assertThat(toHexString(new byte[] { 1, 0 }), is("100")); + assertThat(toHexString(EmptyArrays.EMPTY_BYTES), is("")); + } + + @Test + public void testToHexStringPadded() { + assertThat(toHexStringPadded(new byte[]{0}), is("00")); + assertThat(toHexStringPadded(new byte[]{1}), is("01")); + assertThat(toHexStringPadded(new byte[]{0, 0}), is("0000")); + assertThat(toHexStringPadded(new byte[]{1, 0}), is("0100")); + assertThat(toHexStringPadded(EmptyArrays.EMPTY_BYTES), is("")); + } + + @Test + public void splitSimple() { + assertArrayEquals(new String[] { "foo", "bar" }, "foo:bar".split(":")); + } + + @Test + public void splitWithTrailingDelimiter() { + assertArrayEquals(new String[] { "foo", "bar" }, "foo,bar,".split(",")); + } + + @Test + public void splitWithTrailingDelimiters() { + assertArrayEquals(new String[] { "foo", "bar" }, "foo!bar!!".split("!")); + } + + @Test + public void splitWithTrailingDelimitersDot() { + assertArrayEquals(new String[] { "foo", "bar" }, "foo.bar..".split("\\.")); + } + + @Test + public void splitWithTrailingDelimitersEq() { + assertArrayEquals(new String[] { "foo", "bar" }, "foo=bar==".split("=")); + } + + @Test + public void splitWithTrailingDelimitersSpace() { + assertArrayEquals(new String[] { "foo", "bar" }, "foo bar ".split(" ")); + } + + @Test + public void splitWithConsecutiveDelimiters() { + assertArrayEquals(new String[] { "foo", "", "bar" }, "foo$$bar".split("\\$")); + } + + @Test + public void splitWithDelimiterAtBeginning() { + assertArrayEquals(new String[] { "", "foo", "bar" }, "#foo#bar".split("#")); + } + + @Test + public void splitMaxPart() { + assertArrayEquals(new String[] { "foo", "bar:bar2" }, "foo:bar:bar2".split(":", 2)); + assertArrayEquals(new String[] { "foo", "bar", "bar2" }, "foo:bar:bar2".split(":", 3)); + } + + @Test 
+ public void substringAfterTest() { + assertEquals("bar:bar2", substringAfter("foo:bar:bar2", ':')); + } + + @Test + public void commonSuffixOfLengthTest() { + // negative length suffixes are never common + checkNotCommonSuffix("abc", "abc", -1); + + // null has no suffix + checkNotCommonSuffix("abc", null, 0); + checkNotCommonSuffix(null, null, 0); + + // any non-null string has 0-length suffix + checkCommonSuffix("abc", "xx", 0); + + checkCommonSuffix("abc", "abc", 0); + checkCommonSuffix("abc", "abc", 1); + checkCommonSuffix("abc", "abc", 2); + checkCommonSuffix("abc", "abc", 3); + checkNotCommonSuffix("abc", "abc", 4); + + checkCommonSuffix("abcd", "cd", 1); + checkCommonSuffix("abcd", "cd", 2); + checkNotCommonSuffix("abcd", "cd", 3); + + checkCommonSuffix("abcd", "axcd", 1); + checkCommonSuffix("abcd", "axcd", 2); + checkNotCommonSuffix("abcd", "axcd", 3); + + checkNotCommonSuffix("abcx", "abcy", 1); + } + + private static void checkNotCommonSuffix(String s, String p, int len) { + assertFalse(checkCommonSuffixSymmetric(s, p, len)); + } + + private static void checkCommonSuffix(String s, String p, int len) { + assertTrue(checkCommonSuffixSymmetric(s, p, len)); + } + + private static boolean checkCommonSuffixSymmetric(String s, String p, int len) { + boolean sp = commonSuffixOfLength(s, p, len); + boolean ps = commonSuffixOfLength(p, s, len); + assertEquals(sp, ps); + return sp; + } + + @Test + public void escapeCsvNull() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + StringUtil.escapeCsv(null); + } + }); + } + + @Test + public void escapeCsvEmpty() { + CharSequence value = ""; + escapeCsv(value, value); + } + + @Test + public void escapeCsvUnquoted() { + CharSequence value = "something"; + escapeCsv(value, value); + } + + @Test + public void escapeCsvAlreadyQuoted() { + CharSequence value = "\"something\""; + CharSequence expected = "\"something\""; + escapeCsv(value, expected); + } + + @Test + public 
void escapeCsvWithQuote() { + CharSequence value = "s\""; + CharSequence expected = "\"s\"\"\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithQuoteInMiddle() { + CharSequence value = "some text\"and more text"; + CharSequence expected = "\"some text\"\"and more text\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithQuoteInMiddleAlreadyQuoted() { + CharSequence value = "\"some text\"and more text\""; + CharSequence expected = "\"some text\"\"and more text\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithQuotedWords() { + CharSequence value = "\"foo\"\"goo\""; + CharSequence expected = "\"foo\"\"goo\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithAlreadyEscapedQuote() { + CharSequence value = "foo\"\"goo"; + CharSequence expected = "foo\"\"goo"; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvEndingWithQuote() { + CharSequence value = "some\""; + CharSequence expected = "\"some\"\"\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithSingleQuote() { + CharSequence value = "\""; + CharSequence expected = "\"\"\"\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithSingleQuoteAndCharacter() { + CharSequence value = "\"f"; + CharSequence expected = "\"\"\"f\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvAlreadyEscapedQuote() { + CharSequence value = "\"some\"\""; + CharSequence expected = "\"some\"\"\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvQuoted() { + CharSequence value = "\"foo,goo\""; + escapeCsv(value, value); + } + + @Test + public void escapeCsvWithLineFeed() { + CharSequence value = "some text\n more text"; + CharSequence expected = "\"some text\n more text\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithSingleLineFeedCharacter() { + CharSequence value = "\n"; + CharSequence expected = "\"\n\""; + escapeCsv(value, 
expected); + } + + @Test + public void escapeCsvWithMultipleLineFeedCharacter() { + CharSequence value = "\n\n"; + CharSequence expected = "\"\n\n\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithQuotedAndLineFeedCharacter() { + CharSequence value = " \" \n "; + CharSequence expected = "\" \"\" \n \""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithLineFeedAtEnd() { + CharSequence value = "testing\n"; + CharSequence expected = "\"testing\n\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithComma() { + CharSequence value = "test,ing"; + CharSequence expected = "\"test,ing\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithSingleComma() { + CharSequence value = ","; + CharSequence expected = "\",\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithSingleCarriageReturn() { + CharSequence value = "\r"; + CharSequence expected = "\"\r\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithMultipleCarriageReturn() { + CharSequence value = "\r\r"; + CharSequence expected = "\"\r\r\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithCarriageReturn() { + CharSequence value = "some text\r more text"; + CharSequence expected = "\"some text\r more text\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithQuotedAndCarriageReturnCharacter() { + CharSequence value = "\"\r"; + CharSequence expected = "\"\"\"\r\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithCarriageReturnAtEnd() { + CharSequence value = "testing\r"; + CharSequence expected = "\"testing\r\""; + escapeCsv(value, expected); + } + + @Test + public void escapeCsvWithCRLFCharacter() { + CharSequence value = "\r\n"; + CharSequence expected = "\"\r\n\""; + escapeCsv(value, expected); + } + + private static void escapeCsv(CharSequence value, CharSequence expected) { + escapeCsv(value, expected, false); + } + + 
private static void escapeCsvWithTrimming(CharSequence value, CharSequence expected) { + escapeCsv(value, expected, true); + } + + private static void escapeCsv(CharSequence value, CharSequence expected, boolean trimOws) { + CharSequence escapedValue = value; + for (int i = 0; i < 10; ++i) { + escapedValue = StringUtil.escapeCsv(escapedValue, trimOws); + assertEquals(expected, escapedValue.toString()); + } + } + + @Test + public void escapeCsvWithTrimming() { + assertSame("", StringUtil.escapeCsv("", true)); + assertSame("ab", StringUtil.escapeCsv("ab", true)); + + escapeCsvWithTrimming("", ""); + escapeCsvWithTrimming(" \t ", ""); + escapeCsvWithTrimming("ab", "ab"); + escapeCsvWithTrimming("a b", "a b"); + escapeCsvWithTrimming(" \ta \tb", "a \tb"); + escapeCsvWithTrimming("a \tb \t", "a \tb"); + escapeCsvWithTrimming("\t a \tb \t", "a \tb"); + escapeCsvWithTrimming("\"\t a b \"", "\"\t a b \""); + escapeCsvWithTrimming(" \"\t a b \"\t", "\"\t a b \""); + escapeCsvWithTrimming(" testing\t\n ", "\"testing\t\n\""); + escapeCsvWithTrimming("\ttest,ing ", "\"test,ing\""); + } + + @Test + public void escapeCsvGarbageFree() { + // 'StringUtil#escapeCsv()' should return the same string object if the string did not change. 
+ assertSame("1", StringUtil.escapeCsv("1", true)); + assertSame(" 123 ", StringUtil.escapeCsv(" 123 ", false)); + assertSame("\" 123 \"", StringUtil.escapeCsv("\" 123 \"", true)); + assertSame("\"\"", StringUtil.escapeCsv("\"\"", true)); + assertSame("123 \"\"", StringUtil.escapeCsv("123 \"\"", true)); + assertSame("123\"\"321", StringUtil.escapeCsv("123\"\"321", true)); + assertSame("\"123\"\"321\"", StringUtil.escapeCsv("\"123\"\"321\"", true)); + } + + @Test + public void testUnescapeCsv() { + assertEquals("", unescapeCsv("")); + assertEquals("\"", unescapeCsv("\"\"\"\"")); + assertEquals("\"\"", unescapeCsv("\"\"\"\"\"\"")); + assertEquals("\"\"\"", unescapeCsv("\"\"\"\"\"\"\"\"")); + assertEquals("\"netty\"", unescapeCsv("\"\"\"netty\"\"\"")); + assertEquals("netty", unescapeCsv("netty")); + assertEquals("netty", unescapeCsv("\"netty\"")); + assertEquals("\r", unescapeCsv("\"\r\"")); + assertEquals("\n", unescapeCsv("\"\n\"")); + assertEquals("hello,netty", unescapeCsv("\"hello,netty\"")); + } + + @Test + public void unescapeCsvWithSingleQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsv("\""); + } + }); + } + + @Test + public void unescapeCsvWithOddQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsv("\"\"\""); + } + }); + } + + @Test + public void unescapeCsvWithCRAndWithoutQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsv("\r"); + } + }); + } + + @Test + public void unescapeCsvWithLFAndWithoutQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsv("\n"); + } + }); + } + + @Test + public void unescapeCsvWithCommaAndWithoutQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsv(","); + } + 
}); + } + + @Test + public void escapeCsvAndUnEscapeCsv() { + assertEscapeCsvAndUnEscapeCsv(""); + assertEscapeCsvAndUnEscapeCsv("netty"); + assertEscapeCsvAndUnEscapeCsv("hello,netty"); + assertEscapeCsvAndUnEscapeCsv("hello,\"netty\""); + assertEscapeCsvAndUnEscapeCsv("\""); + assertEscapeCsvAndUnEscapeCsv(","); + assertEscapeCsvAndUnEscapeCsv("\r"); + assertEscapeCsvAndUnEscapeCsv("\n"); + } + + private static void assertEscapeCsvAndUnEscapeCsv(String value) { + assertEquals(value, unescapeCsv(StringUtil.escapeCsv(value))); + } + + @Test + public void testUnescapeCsvFields() { + assertEquals(Collections.singletonList(""), unescapeCsvFields("")); + assertEquals(Arrays.asList("", ""), unescapeCsvFields(",")); + assertEquals(Arrays.asList("a", ""), unescapeCsvFields("a,")); + assertEquals(Arrays.asList("", "a"), unescapeCsvFields(",a")); + assertEquals(Collections.singletonList("\""), unescapeCsvFields("\"\"\"\"")); + assertEquals(Arrays.asList("\"", "\""), unescapeCsvFields("\"\"\"\",\"\"\"\"")); + assertEquals(Collections.singletonList("netty"), unescapeCsvFields("netty")); + assertEquals(Arrays.asList("hello", "netty"), unescapeCsvFields("hello,netty")); + assertEquals(Collections.singletonList("hello,netty"), unescapeCsvFields("\"hello,netty\"")); + assertEquals(Arrays.asList("hello", "netty"), unescapeCsvFields("\"hello\",\"netty\"")); + assertEquals(Arrays.asList("a\"b", "c\"d"), unescapeCsvFields("\"a\"\"b\",\"c\"\"d\"")); + assertEquals(Arrays.asList("a\rb", "c\nd"), unescapeCsvFields("\"a\rb\",\"c\nd\"")); + } + + @Test + public void unescapeCsvFieldsWithCRWithoutQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsvFields("a,\r"); + } + }); + } + + @Test + public void unescapeCsvFieldsWithLFWithoutQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + // Fixed copy-paste bug: this test is for LF ("\n"); it previously passed "a,\r" (CR), + // duplicating unescapeCsvFieldsWithCRWithoutQuote and never exercising the LF path.
 + unescapeCsvFields("a,\n"); + } + }); + } + + @Test + public void 
unescapeCsvFieldsWithQuote() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsvFields("a,\""); + } + }); + } + + @Test + public void unescapeCsvFieldsWithQuote2() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsvFields("\",a"); + } + }); + } + + @Test + public void unescapeCsvFieldsWithQuote3() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + unescapeCsvFields("a\"b,a"); + } + }); + } + + @Test + public void testSimpleClassName() throws Exception { + testSimpleClassName(String.class); + } + + @Test + public void testSimpleInnerClassName() throws Exception { + testSimpleClassName(TestClass.class); + } + + private static void testSimpleClassName(Class clazz) throws Exception { + Package pkg = clazz.getPackage(); + String name; + if (pkg != null) { + name = clazz.getName().substring(pkg.getName().length() + 1); + } else { + name = clazz.getName(); + } + assertEquals(name, simpleClassName(clazz)); + } + + private static final class TestClass { } + + @Test + public void testEndsWith() { + assertFalse(StringUtil.endsWith("", 'u')); + assertTrue(StringUtil.endsWith("u", 'u')); + assertTrue(StringUtil.endsWith("-u", 'u')); + assertFalse(StringUtil.endsWith("-", 'u')); + assertFalse(StringUtil.endsWith("u-", 'u')); + } + + @Test + public void trimOws() { + assertSame("", StringUtil.trimOws("")); + assertEquals("", StringUtil.trimOws(" \t ")); + assertSame("a", StringUtil.trimOws("a")); + assertEquals("a", StringUtil.trimOws(" a")); + assertEquals("a", StringUtil.trimOws("a ")); + assertEquals("a", StringUtil.trimOws(" a ")); + assertSame("abc", StringUtil.trimOws("abc")); + assertEquals("abc", StringUtil.trimOws("\tabc")); + assertEquals("abc", StringUtil.trimOws("abc\t")); + assertEquals("abc", StringUtil.trimOws("\tabc\t")); + assertSame("a\t b", StringUtil.trimOws("a\t b")); + 
assertEquals("", StringUtil.trimOws("\t ").toString()); + assertEquals("a b", StringUtil.trimOws("\ta b \t").toString()); + } + + @Test + public void testJoin() { + assertEquals("", + StringUtil.join(",", Collections.emptyList()).toString()); + assertEquals("a", + StringUtil.join(",", Collections.singletonList("a")).toString()); + assertEquals("a,b", + StringUtil.join(",", Arrays.asList("a", "b")).toString()); + assertEquals("a,b,c", + StringUtil.join(",", Arrays.asList("a", "b", "c")).toString()); + assertEquals("a,b,c,null,d", + StringUtil.join(",", Arrays.asList("a", "b", "c", null, "d")).toString()); + } + + @Test + public void testIsNullOrEmpty() { + assertTrue(isNullOrEmpty(null)); + assertTrue(isNullOrEmpty("")); + assertTrue(isNullOrEmpty(StringUtil.EMPTY_STRING)); + assertFalse(isNullOrEmpty(" ")); + assertFalse(isNullOrEmpty("\t")); + assertFalse(isNullOrEmpty("\n")); + assertFalse(isNullOrEmpty("foo")); + assertFalse(isNullOrEmpty(NEWLINE)); + } + + @Test + public void testIndexOfWhiteSpace() { + assertEquals(-1, indexOfWhiteSpace("", 0)); + assertEquals(0, indexOfWhiteSpace(" ", 0)); + assertEquals(-1, indexOfWhiteSpace(" ", 1)); + assertEquals(0, indexOfWhiteSpace("\n", 0)); + assertEquals(-1, indexOfWhiteSpace("\n", 1)); + assertEquals(0, indexOfWhiteSpace("\t", 0)); + assertEquals(-1, indexOfWhiteSpace("\t", 1)); + assertEquals(3, indexOfWhiteSpace("foo\r\nbar", 1)); + assertEquals(-1, indexOfWhiteSpace("foo\r\nbar", 10)); + assertEquals(7, indexOfWhiteSpace("foo\tbar\r\n", 6)); + assertEquals(-1, indexOfWhiteSpace("foo\tbar\r\n", Integer.MAX_VALUE)); + } + + @Test + public void testIndexOfNonWhiteSpace() { + assertEquals(-1, indexOfNonWhiteSpace("", 0)); + assertEquals(-1, indexOfNonWhiteSpace(" ", 0)); + assertEquals(-1, indexOfNonWhiteSpace(" \t", 0)); + assertEquals(-1, indexOfNonWhiteSpace(" \t\r\n", 0)); + assertEquals(2, indexOfNonWhiteSpace(" \tfoo\r\n", 0)); + assertEquals(2, indexOfNonWhiteSpace(" \tfoo\r\n", 1)); + assertEquals(4, 
indexOfNonWhiteSpace(" \tfoo\r\n", 4)); + assertEquals(-1, indexOfNonWhiteSpace(" \tfoo\r\n", 10)); + assertEquals(-1, indexOfNonWhiteSpace(" \tfoo\r\n", Integer.MAX_VALUE)); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/SystemPropertyUtilTest.java b/netty-util/src/test/java/io/netty/util/internal/SystemPropertyUtilTest.java new file mode 100644 index 0000000..1c2df10 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/SystemPropertyUtilTest.java @@ -0,0 +1,140 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal; + +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class SystemPropertyUtilTest { + + @BeforeEach + public void clearSystemPropertyBeforeEach() { + System.clearProperty("key"); + } + + @Test + public void testGetWithKeyNull() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + SystemPropertyUtil.get(null, null); + } + }); + } + + @Test + public void testGetWithKeyEmpty() { + assertThrows(IllegalArgumentException.class, new Executable() { + @Override + public void execute() { + SystemPropertyUtil.get("", null); + } + }); + } + + @Test + public void testGetDefaultValueWithPropertyNull() { + assertEquals("default", SystemPropertyUtil.get("key", "default")); + } + + @Test + public void testGetPropertyValue() { + System.setProperty("key", "value"); + assertEquals("value", SystemPropertyUtil.get("key")); + } + + @Test + public void testGetBooleanDefaultValueWithPropertyNull() { + assertTrue(SystemPropertyUtil.getBoolean("key", true)); + assertFalse(SystemPropertyUtil.getBoolean("key", false)); + } + + @Test + public void testGetBooleanDefaultValueWithEmptyString() { + System.setProperty("key", ""); + assertTrue(SystemPropertyUtil.getBoolean("key", true)); + assertFalse(SystemPropertyUtil.getBoolean("key", false)); + } + + @Test + public void testGetBooleanWithTrueValue() { + System.setProperty("key", "true"); + assertTrue(SystemPropertyUtil.getBoolean("key", false)); + System.setProperty("key", "yes"); + assertTrue(SystemPropertyUtil.getBoolean("key", false)); + System.setProperty("key", "1"); + assertTrue(SystemPropertyUtil.getBoolean("key", true)); + 
} + + @Test + public void testGetBooleanWithFalseValue() { + System.setProperty("key", "false"); + assertFalse(SystemPropertyUtil.getBoolean("key", true)); + System.setProperty("key", "no"); + assertFalse(SystemPropertyUtil.getBoolean("key", false)); + System.setProperty("key", "0"); + assertFalse(SystemPropertyUtil.getBoolean("key", true)); + } + + @Test + public void testGetBooleanDefaultValueWithWrongValue() { + System.setProperty("key", "abc"); + assertTrue(SystemPropertyUtil.getBoolean("key", true)); + System.setProperty("key", "123"); + assertFalse(SystemPropertyUtil.getBoolean("key", false)); + } + + @Test + public void getIntDefaultValueWithPropertyNull() { + assertEquals(1, SystemPropertyUtil.getInt("key", 1)); + } + + @Test + public void getIntWithPropertValueIsInt() { + System.setProperty("key", "123"); + assertEquals(123, SystemPropertyUtil.getInt("key", 1)); + } + + @Test + public void getIntDefaultValueWithPropertValueIsNotInt() { + System.setProperty("key", "NotInt"); + assertEquals(1, SystemPropertyUtil.getInt("key", 1)); + } + + @Test + public void getLongDefaultValueWithPropertyNull() { + assertEquals(1, SystemPropertyUtil.getLong("key", 1)); + } + + @Test + public void getLongWithPropertValueIsLong() { + System.setProperty("key", "123"); + assertEquals(123, SystemPropertyUtil.getLong("key", 1)); + } + + @Test + public void getLongDefaultValueWithPropertValueIsNotLong() { + System.setProperty("key", "NotInt"); + assertEquals(1, SystemPropertyUtil.getLong("key", 1)); + } + +} diff --git a/netty-util/src/test/java/io/netty/util/internal/ThreadExecutorMapTest.java b/netty-util/src/test/java/io/netty/util/internal/ThreadExecutorMapTest.java new file mode 100644 index 0000000..6b1ff00 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/ThreadExecutorMapTest.java @@ -0,0 +1,65 @@ +/* + * Copyright 2019 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may 
not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import io.netty.util.concurrent.ImmediateEventExecutor; +import io.netty.util.concurrent.ImmediateExecutor; +import org.junit.jupiter.api.Test; + +import java.util.concurrent.Executor; +import java.util.concurrent.Executors; +import java.util.concurrent.ThreadFactory; + +import static org.junit.jupiter.api.Assertions.assertSame; + +public class ThreadExecutorMapTest { + + @Test + public void testDecorateExecutor() { + Executor executor = ThreadExecutorMap.apply(ImmediateExecutor.INSTANCE, ImmediateEventExecutor.INSTANCE); + executor.execute(new Runnable() { + @Override + public void run() { + assertSame(ImmediateEventExecutor.INSTANCE, ThreadExecutorMap.currentExecutor()); + } + }); + } + + @Test + public void testDecorateRunnable() { + ThreadExecutorMap.apply(new Runnable() { + @Override + public void run() { + assertSame(ImmediateEventExecutor.INSTANCE, + ThreadExecutorMap.currentExecutor()); + } + }, ImmediateEventExecutor.INSTANCE).run(); + } + + @Test + public void testDecorateThreadFactory() throws InterruptedException { + ThreadFactory threadFactory = + ThreadExecutorMap.apply(Executors.defaultThreadFactory(), ImmediateEventExecutor.INSTANCE); + Thread thread = threadFactory.newThread(new Runnable() { + @Override + public void run() { + assertSame(ImmediateEventExecutor.INSTANCE, ThreadExecutorMap.currentExecutor()); + } + }); + thread.start(); + thread.join(); + } +} diff --git 
a/netty-util/src/test/java/io/netty/util/internal/ThreadLocalRandomTest.java b/netty-util/src/test/java/io/netty/util/internal/ThreadLocalRandomTest.java new file mode 100644 index 0000000..33259fb --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/ThreadLocalRandomTest.java @@ -0,0 +1,37 @@ +/* + * Copyright 2014 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class ThreadLocalRandomTest { + + @Test + public void getInitialSeedUniquifierPreservesInterrupt() { + try { + Thread.currentThread().interrupt(); + assertTrue(Thread.currentThread().isInterrupted(), + "Assert that thread is interrupted before invocation of getInitialSeedUniquifier()"); + ThreadLocalRandom.getInitialSeedUniquifier(); + assertTrue(Thread.currentThread().isInterrupted(), + "Assert that thread is interrupted after invocation of getInitialSeedUniquifier()"); + } finally { + Thread.interrupted(); // clear interrupted status in order to not affect other tests + } + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/TypeParameterMatcherTest.java b/netty-util/src/test/java/io/netty/util/internal/TypeParameterMatcherTest.java new file mode 100644 index 0000000..1b6e017 --- /dev/null +++ 
b/netty-util/src/test/java/io/netty/util/internal/TypeParameterMatcherTest.java @@ -0,0 +1,157 @@ +/* + * Copyright 2013 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ + +package io.netty.util.internal; + +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import java.util.Date; + +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class TypeParameterMatcherTest { + + @Test + public void testConcreteClass() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new TypeQ(), TypeX.class, "A"); + assertFalse(m.match(new Object())); + assertFalse(m.match(new A())); + assertFalse(m.match(new AA())); + assertTrue(m.match(new AAA())); + assertFalse(m.match(new B())); + assertFalse(m.match(new BB())); + assertFalse(m.match(new BBB())); + assertFalse(m.match(new C())); + assertFalse(m.match(new CC())); + } + + @Test + public void testUnsolvedParameter() throws Exception { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + TypeParameterMatcher.find(new TypeQ(), TypeX.class, "B"); + } + }); + } + + @Test + public void testAnonymousClass() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new TypeQ() { }, TypeX.class, "B"); + 
assertFalse(m.match(new Object())); + assertFalse(m.match(new A())); + assertFalse(m.match(new AA())); + assertFalse(m.match(new AAA())); + assertFalse(m.match(new B())); + assertFalse(m.match(new BB())); + assertTrue(m.match(new BBB())); + assertFalse(m.match(new C())); + assertFalse(m.match(new CC())); + } + + @Test + public void testAbstractClass() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new TypeQ(), TypeX.class, "C"); + assertFalse(m.match(new Object())); + assertFalse(m.match(new A())); + assertFalse(m.match(new AA())); + assertFalse(m.match(new AAA())); + assertFalse(m.match(new B())); + assertFalse(m.match(new BB())); + assertFalse(m.match(new BBB())); + assertFalse(m.match(new C())); + assertTrue(m.match(new CC())); + } + + public static class TypeX { + A a; + B b; + C c; + } + + public static class TypeY extends TypeX { } + + public abstract static class TypeZ extends TypeY { } + + public static class TypeQ extends TypeZ { } + + public static class A { } + public static class AA extends A { } + public static class AAA extends AA { } + + public static class B { } + public static class BB extends B { } + public static class BBB extends BB { } + + public static class C { } + public static class CC extends C { } + + @Test + public void testInaccessibleClass() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new U() { }, U.class, "E"); + assertFalse(m.match(new Object())); + assertTrue(m.match(new T())); + } + + private static class T { } + private static class U { E a; } + + @Test + public void testArrayAsTypeParam() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new U() { }, U.class, "E"); + assertFalse(m.match(new Object())); + assertTrue(m.match(new byte[1])); + } + + @Test + public void testRawType() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new U() { }, U.class, "E"); + assertTrue(m.match(new Object())); + } + + private static class V { + U u = 
new U() { }; + } + + @Test + public void testInnerClass() throws Exception { + TypeParameterMatcher m = TypeParameterMatcher.find(new V().u, U.class, "E"); + assertTrue(m.match(new Object())); + } + + private abstract static class W { + E e; + } + + private static class X extends W { + T t; + } + + @Test + public void testErasure() throws Exception { + assertThrows(IllegalStateException.class, new Executable() { + @Override + public void execute() { + TypeParameterMatcher m = TypeParameterMatcher.find(new X(), W.class, "E"); + assertTrue(m.match(new Date())); + assertFalse(m.match(new Object())); + } + }); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/logging/AbstractInternalLoggerTest.java b/netty-util/src/test/java/io/netty/util/internal/logging/AbstractInternalLoggerTest.java new file mode 100644 index 0000000..4b51d88 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/logging/AbstractInternalLoggerTest.java @@ -0,0 +1,150 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal.logging; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertFalse; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import java.lang.reflect.Method; +import java.util.HashMap; +import java.util.Map; + +import org.junit.jupiter.api.Test; + +/** + * We only need to test methods defined by {@link InternalLogger}. + */ +public abstract class AbstractInternalLoggerTest { +    protected String loggerName = "foo"; +    protected T mockLog; +    protected InternalLogger logger; +    protected final Map result = new HashMap(); + +    @SuppressWarnings("unchecked") +    protected V getResult(String key) { +        return (V) result.get(key); +    } + +    @Test +    public void testName() { +        assertEquals(loggerName, logger.name()); +    } + +    @Test +    public void testAllLevel() throws Exception { +        testLevel(InternalLogLevel.TRACE); +        testLevel(InternalLogLevel.DEBUG); +        testLevel(InternalLogLevel.INFO); +        testLevel(InternalLogLevel.WARN); +        testLevel(InternalLogLevel.ERROR); +    } + +    protected void testLevel(InternalLogLevel level) throws Exception { +        result.clear(); + +        String format1 = "a={}", format2 = "a={}, b= {}", format3 = "a={}, b= {}, c= {}"; +        String msg = "a test message from Junit"; +        Exception ex = new Exception("a test Exception from Junit"); + +        Class clazz = InternalLogger.class; +        String levelName = level.name(), logMethod = levelName.toLowerCase(); +        Method isXXEnabled = clazz +                .getMethod("is" + levelName.charAt(0) + levelName.substring(1).toLowerCase() + "Enabled"); + +        // when level log is disabled +        setLevelEnable(level, false); +        assertFalse((Boolean) isXXEnabled.invoke(logger)); + +        // test xx(msg) +        clazz.getMethod(logMethod, String.class).invoke(logger, msg); +        assertTrue(result.isEmpty()); + +        // test xx(format, arg) +        clazz.getMethod(logMethod, String.class, Object.class).invoke(logger, format1, msg); +        assertTrue(result.isEmpty()); + +        // test xx(format, argA, argB) 
+ clazz.getMethod(logMethod, String.class, Object.class, Object.class).invoke(logger, format2, msg, msg); + assertTrue(result.isEmpty()); + + // test xx(format, ...arguments) + clazz.getMethod(logMethod, String.class, Object[].class).invoke(logger, format3, + new Object[] { msg, msg, msg }); + assertTrue(result.isEmpty()); + + // test xx(format, ...arguments), the last argument is Throwable + clazz.getMethod(logMethod, String.class, Object[].class).invoke(logger, format3, + new Object[] { msg, msg, msg, ex }); + assertTrue(result.isEmpty()); + + // test xx(msg, Throwable) + clazz.getMethod(logMethod, String.class, Throwable.class).invoke(logger, msg, ex); + assertTrue(result.isEmpty()); + + // test xx(Throwable) + clazz.getMethod(logMethod, Throwable.class).invoke(logger, ex); + assertTrue(result.isEmpty()); + + // when level log is enabled + setLevelEnable(level, true); + assertTrue((Boolean) isXXEnabled.invoke(logger)); + + // test xx(msg) + result.clear(); + clazz.getMethod(logMethod, String.class).invoke(logger, msg); + assertResult(level, null, null, msg); + + // test xx(format, arg) + result.clear(); + clazz.getMethod(logMethod, String.class, Object.class).invoke(logger, format1, msg); + assertResult(level, format1, null, msg); + + // test xx(format, argA, argB) + result.clear(); + clazz.getMethod(logMethod, String.class, Object.class, Object.class).invoke(logger, format2, msg, msg); + assertResult(level, format2, null, msg, msg); + + // test xx(format, ...arguments) + result.clear(); + clazz.getMethod(logMethod, String.class, Object[].class).invoke(logger, format3, + new Object[] { msg, msg, msg }); + assertResult(level, format3, null, msg, msg, msg); + + // test xx(format, ...arguments), the last argument is Throwable + result.clear(); + clazz.getMethod(logMethod, String.class, Object[].class).invoke(logger, format3, + new Object[] { msg, msg, msg, ex }); + assertResult(level, format3, ex, msg, msg, msg, ex); + + // test xx(msg, Throwable) + result.clear(); 
+ clazz.getMethod(logMethod, String.class, Throwable.class).invoke(logger, msg, ex); + assertResult(level, null, ex, msg); + + // test xx(Throwable) + result.clear(); + clazz.getMethod(logMethod, Throwable.class).invoke(logger, ex); + assertResult(level, null, ex); + } + + /** Default implementation; subclasses may override it to check the call recorded in {@linkplain #mockLog}. */ + protected void assertResult(InternalLogLevel level, String format, Throwable t, Object... args) { + assertFalse(result.isEmpty()); + } + + protected abstract void setLevelEnable(InternalLogLevel level, boolean enable) throws Exception; +} diff --git a/netty-util/src/test/java/io/netty/util/internal/logging/InternalLoggerFactoryTest.java b/netty-util/src/test/java/io/netty/util/internal/logging/InternalLoggerFactoryTest.java new file mode 100644 index 0000000..69dc594 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/logging/InternalLoggerFactoryTest.java @@ -0,0 +1,188 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal.logging; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.function.Executable; + +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertNotSame; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; +import static org.mockito.Mockito.*; + +public class InternalLoggerFactoryTest { + private static final Exception e = new Exception(); + private InternalLoggerFactory oldLoggerFactory; + private InternalLogger mockLogger; + + @BeforeEach + public void init() { + oldLoggerFactory = InternalLoggerFactory.getDefaultFactory(); + + final InternalLoggerFactory mockFactory = mock(InternalLoggerFactory.class); + mockLogger = mock(InternalLogger.class); + when(mockFactory.newInstance("mock")).thenReturn(mockLogger); + InternalLoggerFactory.setDefaultFactory(mockFactory); + } + + @AfterEach + public void destroy() { + reset(mockLogger); + InternalLoggerFactory.setDefaultFactory(oldLoggerFactory); + } + + @Test + public void shouldNotAllowNullDefaultFactory() { + assertThrows(NullPointerException.class, new Executable() { + @Override + public void execute() { + InternalLoggerFactory.setDefaultFactory(null); + } + }); + } + + @Test + public void shouldGetInstance() { + InternalLoggerFactory.setDefaultFactory(oldLoggerFactory); + + String helloWorld = "Hello, world!"; + + InternalLogger one = InternalLoggerFactory.getInstance("helloWorld"); + InternalLogger two = InternalLoggerFactory.getInstance(helloWorld.getClass()); + + assertNotNull(one); + assertNotNull(two); + assertNotSame(one, two); + } + + @Test + public void testIsTraceEnabled() { + when(mockLogger.isTraceEnabled()).thenReturn(true); + + InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + assertTrue(logger.isTraceEnabled()); + 
verify(mockLogger).isTraceEnabled(); + } + + @Test + public void testIsDebugEnabled() { + when(mockLogger.isDebugEnabled()).thenReturn(true); + + InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + assertTrue(logger.isDebugEnabled()); + verify(mockLogger).isDebugEnabled(); + } + + @Test + public void testIsInfoEnabled() { + when(mockLogger.isInfoEnabled()).thenReturn(true); + + InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + assertTrue(logger.isInfoEnabled()); + verify(mockLogger).isInfoEnabled(); + } + + @Test + public void testIsWarnEnabled() { + when(mockLogger.isWarnEnabled()).thenReturn(true); + + InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + assertTrue(logger.isWarnEnabled()); + verify(mockLogger).isWarnEnabled(); + } + + @Test + public void testIsErrorEnabled() { + when(mockLogger.isErrorEnabled()).thenReturn(true); + + InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + assertTrue(logger.isErrorEnabled()); + verify(mockLogger).isErrorEnabled(); + } + + @Test + public void testTrace() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.trace("a"); + verify(mockLogger).trace("a"); + } + + @Test + public void testTraceWithException() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.trace("a", e); + verify(mockLogger).trace("a", e); + } + + @Test + public void testDebug() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.debug("a"); + verify(mockLogger).debug("a"); + } + + @Test + public void testDebugWithException() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.debug("a", e); + verify(mockLogger).debug("a", e); + } + + @Test + public void testInfo() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.info("a"); + verify(mockLogger).info("a"); + } + + @Test + public void testInfoWithException() { + final 
InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.info("a", e); + verify(mockLogger).info("a", e); + } + + @Test + public void testWarn() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.warn("a"); + verify(mockLogger).warn("a"); + } + + @Test + public void testWarnWithException() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.warn("a", e); + verify(mockLogger).warn("a", e); + } + + @Test + public void testError() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.error("a"); + verify(mockLogger).error("a"); + } + + @Test + public void testErrorWithException() { + final InternalLogger logger = InternalLoggerFactory.getInstance("mock"); + logger.error("a", e); + verify(mockLogger).error("a", e); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/logging/JdkLoggerFactoryTest.java b/netty-util/src/test/java/io/netty/util/internal/logging/JdkLoggerFactoryTest.java new file mode 100644 index 0000000..96278d2 --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/logging/JdkLoggerFactoryTest.java @@ -0,0 +1,31 @@ +/* + * Copyright 2012 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ */ +package io.netty.util.internal.logging; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertTrue; + +public class JdkLoggerFactoryTest { + + @Test + public void testCreation() { + InternalLogger logger = JdkLoggerFactory.INSTANCE.newInstance("foo"); + assertTrue(logger instanceof JdkLogger); + assertEquals("foo", logger.name()); + } +} diff --git a/netty-util/src/test/java/io/netty/util/internal/logging/MessageFormatterTest.java b/netty-util/src/test/java/io/netty/util/internal/logging/MessageFormatterTest.java new file mode 100644 index 0000000..60a44fc --- /dev/null +++ b/netty-util/src/test/java/io/netty/util/internal/logging/MessageFormatterTest.java @@ -0,0 +1,325 @@ +/* + * Copyright 2017 The Netty Project + * + * The Netty Project licenses this file to you under the Apache License, + * version 2.0 (the "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at: + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + */ +/** + * Copyright (c) 2004-2011 QOS.ch + * All rights reserved. + *

+ * Permission is hereby granted, free of charge, to any person obtaining + * a copy of this software and associated documentation files (the + * "Software"), to deal in the Software without restriction, including + * without limitation the rights to use, copy, modify, merge, publish, + * distribute, sublicense, and/or sell copies of the Software, and to + * permit persons to whom the Software is furnished to do so, subject to + * the following conditions: + *

+ * The above copyright notice and this permission notice shall be + * included in all copies or substantial portions of the Software. + *

+ * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, + * EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF + * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND + * NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE + * LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION + * OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION + * WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. + */ +package io.netty.util.internal.logging; + +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNull; + +public class MessageFormatterTest { + + @Test + public void testNull() { + String result = MessageFormatter.format(null, 1).getMessage(); + assertNull(result); + } + + @Test + public void nullParametersShouldBeHandledWithoutBarfing() { + String result = MessageFormatter.format("Value is {}.", null).getMessage(); + assertEquals("Value is null.", result); + + result = MessageFormatter.format("Val1 is {}, val2 is {}.", null, null).getMessage(); + assertEquals("Val1 is null, val2 is null.", result); + + result = MessageFormatter.format("Val1 is {}, val2 is {}.", 1, null).getMessage(); + assertEquals("Val1 is 1, val2 is null.", result); + + result = MessageFormatter.format("Val1 is {}, val2 is {}.", null, 2).getMessage(); + assertEquals("Val1 is null, val2 is 2.", result); + + result = MessageFormatter.arrayFormat( + "Val1 is {}, val2 is {}, val3 is {}", new Integer[] { null, null, null }).getMessage(); + assertEquals("Val1 is null, val2 is null, val3 is null", result); + + result = MessageFormatter.arrayFormat( + "Val1 is {}, val2 is {}, val3 is {}", new Integer[] { null, 2, 3 }).getMessage(); + assertEquals("Val1 is null, val2 is 2, val3 is 3", result); + + result = MessageFormatter.arrayFormat( + "Val1 is {}, val2 is {}, val3 is {}", new Integer[] { null, null, 3 }).getMessage(); 
+ assertEquals("Val1 is null, val2 is null, val3 is 3", result); + } + + @Test + public void verifyOneParameterIsHandledCorrectly() { + String result = MessageFormatter.format("Value is {}.", 3).getMessage(); + assertEquals("Value is 3.", result); + + result = MessageFormatter.format("Value is {", 3).getMessage(); + assertEquals("Value is {", result); + + result = MessageFormatter.format("{} is larger than 2.", 3).getMessage(); + assertEquals("3 is larger than 2.", result); + + result = MessageFormatter.format("No subst", 3).getMessage(); + assertEquals("No subst", result); + + result = MessageFormatter.format("Incorrect {subst", 3).getMessage(); + assertEquals("Incorrect {subst", result); + + result = MessageFormatter.format("Value is {bla} {}", 3).getMessage(); + assertEquals("Value is {bla} 3", result); + + result = MessageFormatter.format("Escaped \\{} subst", 3).getMessage(); + assertEquals("Escaped {} subst", result); + + result = MessageFormatter.format("{Escaped", 3).getMessage(); + assertEquals("{Escaped", result); + + result = MessageFormatter.format("\\{}Escaped", 3).getMessage(); + assertEquals("{}Escaped", result); + + result = MessageFormatter.format("File name is {{}}.", "App folder.zip").getMessage(); + assertEquals("File name is {App folder.zip}.", result); + + // escaping the escape character + result = MessageFormatter.format("File name is C:\\\\{}.", "App folder.zip").getMessage(); + assertEquals("File name is C:\\App folder.zip.", result); + } + + @Test + public void testTwoParameters() { + String result = MessageFormatter.format("Value {} is smaller than {}.", 1, 2).getMessage(); + assertEquals("Value 1 is smaller than 2.", result); + + result = MessageFormatter.format("Value {} is smaller than {}", 1, 2).getMessage(); + assertEquals("Value 1 is smaller than 2", result); + + result = MessageFormatter.format("{}{}", 1, 2).getMessage(); + assertEquals("12", result); + + result = MessageFormatter.format("Val1={}, Val2={", 1, 2).getMessage(); + 
assertEquals("Val1=1, Val2={", result); + + result = MessageFormatter.format("Value {} is smaller than \\{}", 1, 2).getMessage(); + assertEquals("Value 1 is smaller than {}", result); + + result = MessageFormatter.format("Value {} is smaller than \\{} tail", 1, 2).getMessage(); + assertEquals("Value 1 is smaller than {} tail", result); + + result = MessageFormatter.format("Value {} is smaller than \\{", 1, 2).getMessage(); + assertEquals("Value 1 is smaller than \\{", result); + + result = MessageFormatter.format("Value {} is smaller than {tail", 1, 2).getMessage(); + assertEquals("Value 1 is smaller than {tail", result); + + result = MessageFormatter.format("Value \\{} is smaller than {}", 1, 2).getMessage(); + assertEquals("Value {} is smaller than 1", result); + } + + @Test + public void testExceptionIn_toString() { + Object o = new Object() { + @Override + public String toString() { + throw new IllegalStateException("a"); + } + }; + String result = MessageFormatter.format("Troublesome object {}", o).getMessage(); + assertEquals("Troublesome object [FAILED toString()]", result); + } + + @Test + public void testNullArray() { + String msg0 = "msg0"; + String msg1 = "msg1 {}"; + String msg2 = "msg2 {} {}"; + String msg3 = "msg3 {} {} {}"; + + Object[] args = null; + + String result = MessageFormatter.arrayFormat(msg0, args).getMessage(); + assertEquals(msg0, result); + + result = MessageFormatter.arrayFormat(msg1, args).getMessage(); + assertEquals(msg1, result); + + result = MessageFormatter.arrayFormat(msg2, args).getMessage(); + assertEquals(msg2, result); + + result = MessageFormatter.arrayFormat(msg3, args).getMessage(); + assertEquals(msg3, result); + } + + // tests the case when the parameters are supplied in a single array + @Test + public void testArrayFormat() { + Integer[] ia0 = { 1, 2, 3 }; + + String result = MessageFormatter.arrayFormat("Value {} is smaller than {} and {}.", ia0).getMessage(); + assertEquals("Value 1 is smaller than 2 and 3.", 
result); + + result = MessageFormatter.arrayFormat("{}{}{}", ia0).getMessage(); + assertEquals("123", result); + + result = MessageFormatter.arrayFormat("Value {} is smaller than {}.", ia0).getMessage(); + assertEquals("Value 1 is smaller than 2.", result); + + result = MessageFormatter.arrayFormat("Value {} is smaller than {}", ia0).getMessage(); + assertEquals("Value 1 is smaller than 2", result); + + result = MessageFormatter.arrayFormat("Val={}, {, Val={}", ia0).getMessage(); + assertEquals("Val=1, {, Val=2", result); + + result = MessageFormatter.arrayFormat("Val={}, {, Val={}", ia0).getMessage(); + assertEquals("Val=1, {, Val=2", result); + + result = MessageFormatter.arrayFormat("Val1={}, Val2={", ia0).getMessage(); + assertEquals("Val1=1, Val2={", result); + } + + @Test + public void testArrayValues() { + Integer[] p1 = { 2, 3 }; + + String result = MessageFormatter.format("{}{}", 1, p1).getMessage(); + assertEquals("1[2, 3]", result); + + // Integer[] + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", p1 }).getMessage(); + assertEquals("a[2, 3]", result); + + // byte[] + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", new byte[] { 1, 2 } }).getMessage(); + assertEquals("a[1, 2]", result); + + // int[] + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", new int[] { 1, 2 } }).getMessage(); + assertEquals("a[1, 2]", result); + + // float[] + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", new float[] { 1, 2 } }).getMessage(); + assertEquals("a[1.0, 2.0]", result); + + // double[] + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", new double[] { 1, 2 } }).getMessage(); + assertEquals("a[1.0, 2.0]", result); + } + + @Test + public void testMultiDimensionalArrayValues() { + Integer[] ia0 = { 1, 2, 3 }; + Integer[] ia1 = { 10, 20, 30 }; + + Integer[][] multiIntegerA = { ia0, ia1 }; + String result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", multiIntegerA 
}).getMessage(); + assertEquals("a[[1, 2, 3], [10, 20, 30]]", result); + + int[][] multiIntA = { { 1, 2 }, { 10, 20 } }; + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", multiIntA }).getMessage(); + assertEquals("a[[1, 2], [10, 20]]", result); + + float[][] multiFloatA = { { 1, 2 }, { 10, 20 } }; + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", multiFloatA }).getMessage(); + assertEquals("a[[1.0, 2.0], [10.0, 20.0]]", result); + + Object[][] multiOA = { ia0, ia1 }; + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", multiOA }).getMessage(); + assertEquals("a[[1, 2, 3], [10, 20, 30]]", result); + + Object[][][] _3DOA = { multiOA, multiOA }; + result = MessageFormatter.arrayFormat("{}{}", new Object[] { "a", _3DOA }).getMessage(); + assertEquals("a[[[1, 2, 3], [10, 20, 30]], [[1, 2, 3], [10, 20, 30]]]", result); + + Byte[] ba0 = { 0, Byte.MAX_VALUE, Byte.MIN_VALUE }; + Short[] sa0 = { 0, Short.MIN_VALUE, Short.MAX_VALUE }; + result = MessageFormatter.arrayFormat("{}\\{}{}", new Object[] { new Object[] { ba0, sa0 }, ia1 }).getMessage(); + assertEquals("[[0, 127, -128], [0, -32768, 32767]]{}[10, 20, 30]", result); + } + + @Test + public void testCyclicArrays() { + Object[] cyclicA = new Object[1]; + cyclicA[0] = cyclicA; + assertEquals("[[...]]", MessageFormatter.arrayFormat("{}", cyclicA).getMessage()); + + Object[] a = new Object[2]; + a[0] = 1; + Object[] c = { 3, a }; + Object[] b = { 2, c }; + a[1] = b; + assertEquals("1[2, [3, [1, [...]]]]", + MessageFormatter.arrayFormat("{}{}", a).getMessage()); + } + + @Test + public void testArrayThrowable() { + FormattingTuple ft; + Throwable t = new Throwable(); + Object[] ia = { 1, 2, 3, t }; + + ft = MessageFormatter.arrayFormat("Value {} is smaller than {} and {}.", ia); + assertEquals("Value 1 is smaller than 2 and 3.", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("{}{}{}", ia); + assertEquals("123", ft.getMessage()); 
+ assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("Value {} is smaller than {}.", ia); + assertEquals("Value 1 is smaller than 2.", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("Value {} is smaller than {}", ia); + assertEquals("Value 1 is smaller than 2", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("Val={}, {, Val={}", ia); + assertEquals("Val=1, {, Val=2", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("Val={}, \\{, Val={}", ia); + assertEquals("Val=1, \\{, Val=2", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("Val1={}, Val2={", ia); + assertEquals("Val1=1, Val2={", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("Value {} is smaller than {} and {}.", ia); + assertEquals("Value 1 is smaller than 2 and 3.", ft.getMessage()); + assertEquals(t, ft.getThrowable()); + + ft = MessageFormatter.arrayFormat("{}{}{}{}", ia); + assertEquals("123java.lang.Throwable", ft.getMessage()); + assertNull(ft.getThrowable()); + } +} diff --git a/netty-util/src/test/resources/logging.properties b/netty-util/src/test/resources/logging.properties new file mode 100644 index 0000000..3cd7309 --- /dev/null +++ b/netty-util/src/test/resources/logging.properties @@ -0,0 +1,7 @@ +handlers=java.util.logging.ConsoleHandler +.level=ALL +java.util.logging.SimpleFormatter.format=%1$tY-%1$tm-%1$td %1$tH:%1$tM:%1$tS.%1$tL %4$-7s [%3$s] %5$s %6$s%n +java.util.logging.ConsoleHandler.level=ALL +java.util.logging.ConsoleHandler.formatter=java.util.logging.SimpleFormatter +jdk.event.security.level=INFO +org.junit.jupiter.engine.execution.ConditionEvaluator.level=OFF diff --git a/netty-zlib/NOTICE.txt b/netty-zlib/NOTICE.txt new file mode 100644 index 0000000..ab999b7 --- /dev/null +++ b/netty-zlib/NOTICE.txt @@ -0,0 +1,3 @@ +This program is 
based on zlib-1.1.3, so all credit should go authors +Jean-loup Gailly(jloup@gzip.org) and Mark Adler(madler@alumni.caltech.edu) +and contributors of zlib. \ No newline at end of file diff --git a/netty-zlib/src/main/java/io/netty/zlib/Adler32.java b/netty-zlib/src/main/java/io/netty/zlib/Adler32.java new file mode 100644 index 0000000..b7e2f42 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/Adler32.java @@ -0,0 +1,108 @@ +package io.netty.zlib; + +final public class Adler32 implements Checksum { + + // largest prime smaller than 65536 + static final private int BASE = 65521; + // NMAX is the largest n such that 255n(n+1)/2 + (n+1)(BASE-1) <= 2^32-1 + static final private int NMAX = 5552; + + private long s1 = 1L; + private long s2 = 0L; + + public void reset(long init) { + s1 = init & 0xffff; + s2 = (init >> 16) & 0xffff; + } + + public void reset() { + s1 = 1L; + s2 = 0L; + } + + public long getValue() { + return ((s2 << 16) | s1); + } + + public void update(byte[] buf, int index, int len) { + + if (len == 1) { + s1 += buf[index++] & 0xff; + s2 += s1; + s1 %= BASE; + s2 %= BASE; + return; + } + + int len1 = len / NMAX; + int len2 = len % NMAX; + while (len1-- > 0) { + int k = NMAX; + len -= k; + while (k-- > 0) { + s1 += buf[index++] & 0xff; + s2 += s1; + } + s1 %= BASE; + s2 %= BASE; + } + + int k = len2; + len -= k; + while (k-- > 0) { + s1 += buf[index++] & 0xff; + s2 += s1; + } + s1 %= BASE; + s2 %= BASE; + } + + public Adler32 copy() { + Adler32 foo = new Adler32(); + foo.s1 = this.s1; + foo.s2 = this.s2; + return foo; + } + + // The following logic has come from zlib.1.2. 
+ static long combine(long adler1, long adler2, long len2) { + long BASEL = BASE; + long sum1; + long sum2; + long rem; // unsigned int + + rem = len2 % BASEL; + sum1 = adler1 & 0xffffL; + sum2 = rem * sum1; + sum2 %= BASEL; // MOD(sum2); + sum1 += (adler2 & 0xffffL) + BASEL - 1; + sum2 += ((adler1 >> 16) & 0xffffL) + ((adler2 >> 16) & 0xffffL) + BASEL - rem; + if (sum1 >= BASEL) sum1 -= BASEL; + if (sum1 >= BASEL) sum1 -= BASEL; + if (sum2 >= (BASEL << 1)) sum2 -= (BASEL << 1); + if (sum2 >= BASEL) sum2 -= BASEL; + return sum1 | (sum2 << 16); + } + +/* + private java.util.zip.Adler32 adler=new java.util.zip.Adler32(); + public void update(byte[] buf, int index, int len){ + if(buf==null) {adler.reset();} + else{adler.update(buf, index, len);} + } + public void reset(){ + adler.reset(); + } + public void reset(long init){ + if(init==1L){ + adler.reset(); + } + else{ + System.err.println("unsupported operation"); + } + } + public long getValue(){ + return adler.getValue(); + } +*/ +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/CRC32.java b/netty-zlib/src/main/java/io/netty/zlib/CRC32.java new file mode 100644 index 0000000..a642b51 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/CRC32.java @@ -0,0 +1,147 @@ +package io.netty.zlib; + +final public class CRC32 implements Checksum { + + /* + * The following logic has come from RFC1952. 
+ */ + private int v = 0; + private static int[] crc_table = null; + + static { + crc_table = new int[256]; + for (int n = 0; n < 256; n++) { + int c = n; + for (int k = 8; --k >= 0; ) { + if ((c & 1) != 0) + c = 0xedb88320 ^ (c >>> 1); + else + c = c >>> 1; + } + crc_table[n] = c; + } + } + + public void update(byte[] buf, int index, int len) { + int c = ~v; + while (--len >= 0) + c = crc_table[(c ^ buf[index++]) & 0xff] ^ (c >>> 8); + v = ~c; + } + + public void reset() { + v = 0; + } + + public void reset(long vv) { + v = (int) (vv & 0xffffffffL); + } + + public long getValue() { + return v & 0xffffffffL; + } + + // The following logic has come from zlib.1.2. + private static final int GF2_DIM = 32; + + static long combine(long crc1, long crc2, long len2) { + long row; + long[] even = new long[GF2_DIM]; + long[] odd = new long[GF2_DIM]; + + // degenerate case (also disallow negative lengths) + if (len2 <= 0) + return crc1; + + // put operator for one zero bit in odd + odd[0] = 0xedb88320L; // CRC-32 polynomial + row = 1; + for (int n = 1; n < GF2_DIM; n++) { + odd[n] = row; + row <<= 1; + } + + // put operator for two zero bits in even + gf2_matrix_square(even, odd); + + // put operator for four zero bits in odd + gf2_matrix_square(odd, even); + + // apply len2 zeros to crc1 (first square will put the operator for one + // zero byte, eight zero bits, in even) + do { + // apply zeros operator for this bit of len2 + gf2_matrix_square(even, odd); + if ((len2 & 1) != 0) + crc1 = gf2_matrix_times(even, crc1); + len2 >>= 1; + + // if no more bits set, then done + if (len2 == 0) + break; + + // another iteration of the loop with odd and even swapped + gf2_matrix_square(odd, even); + if ((len2 & 1) != 0) + crc1 = gf2_matrix_times(odd, crc1); + len2 >>= 1; + + // if no more bits set, then done + } while (len2 != 0); + + /* return combined crc */ + crc1 ^= crc2; + return crc1; + } + + private static long gf2_matrix_times(long[] mat, long vec) { + long sum = 0; + int index 
= 0; + while (vec != 0) { + if ((vec & 1) != 0) + sum ^= mat[index]; + vec >>= 1; + index++; + } + return sum; + } + + static void gf2_matrix_square(long[] square, long[] mat) { + for (int n = 0; n < GF2_DIM; n++) + square[n] = gf2_matrix_times(mat, mat[n]); + } + + /* + private java.util.zip.CRC32 crc32 = new java.util.zip.CRC32(); + + public void update(byte[] buf, int index, int len){ + if(buf==null) {crc32.reset();} + else{crc32.update(buf, index, len);} + } + public void reset(){ + crc32.reset(); + } + public void reset(long init){ + if(init==0L){ + crc32.reset(); + } + else{ + System.err.println("unsupported operation"); + } + } + public long getValue(){ + return crc32.getValue(); + } + */ + public CRC32 copy() { + CRC32 foo = new CRC32(); + foo.v = this.v; + return foo; + } + + public static int[] getCRC32Table() { + int[] tmp = new int[crc_table.length]; + System.arraycopy(crc_table, 0, tmp, 0, tmp.length); + return tmp; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/Checksum.java b/netty-zlib/src/main/java/io/netty/zlib/Checksum.java new file mode 100644 index 0000000..39a5c0b --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/Checksum.java @@ -0,0 +1,13 @@ +package io.netty.zlib; + +interface Checksum { + void update(byte[] buf, int index, int len); + + void reset(); + + void reset(long init); + + long getValue(); + + Checksum copy(); +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/Deflate.java b/netty-zlib/src/main/java/io/netty/zlib/Deflate.java new file mode 100644 index 0000000..5bd89ec --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/Deflate.java @@ -0,0 +1,1736 @@ +package io.netty.zlib; + +public +final class Deflate implements Cloneable { + + static final private int MAX_MEM_LEVEL = 9; + + static final private int Z_DEFAULT_COMPRESSION = -1; + + static final private int MAX_WBITS = 15; // 32K LZ77 window + static final private int DEF_MEM_LEVEL = 8; + + static class Config { + int good_length; // reduce lazy 
search above this match length + int max_lazy; // do not perform lazy search above this match length + int nice_length; // quit search above this match length + int max_chain; + int func; + + Config(int good_length, int max_lazy, + int nice_length, int max_chain, int func) { + this.good_length = good_length; + this.max_lazy = max_lazy; + this.nice_length = nice_length; + this.max_chain = max_chain; + this.func = func; + } + } + + static final private int STORED = 0; + static final private int FAST = 1; + static final private int SLOW = 2; + static final private Config[] config_table; + + static { + config_table = new Config[10]; + // good lazy nice chain + config_table[0] = new Config(0, 0, 0, 0, STORED); + config_table[1] = new Config(4, 4, 8, 4, FAST); + config_table[2] = new Config(4, 5, 16, 8, FAST); + config_table[3] = new Config(4, 6, 32, 32, FAST); + + config_table[4] = new Config(4, 4, 16, 16, SLOW); + config_table[5] = new Config(8, 16, 32, 32, SLOW); + config_table[6] = new Config(8, 16, 128, 128, SLOW); + config_table[7] = new Config(8, 32, 128, 256, SLOW); + config_table[8] = new Config(32, 128, 258, 1024, SLOW); + config_table[9] = new Config(32, 258, 258, 4096, SLOW); + } + + static final private String[] z_errmsg = { + "need dictionary", // Z_NEED_DICT 2 + "stream end", // Z_STREAM_END 1 + "", // Z_OK 0 + "file error", // Z_ERRNO (-1) + "stream error", // Z_STREAM_ERROR (-2) + "data error", // Z_DATA_ERROR (-3) + "insufficient memory", // Z_MEM_ERROR (-4) + "buffer error", // Z_BUF_ERROR (-5) + "incompatible version",// Z_VERSION_ERROR (-6) + "" + }; + + // block not completed, need more input or more output + static final private int NeedMore = 0; + + // block flush performed + static final private int BlockDone = 1; + + // finish started, need only more output at next deflate + static final private int FinishStarted = 2; + + // finish done, accept no more input or output + static final private int FinishDone = 3; + + // preset dictionary flag in 
zlib header + static final private int PRESET_DICT = 0x20; + + static final private int Z_FILTERED = 1; + static final private int Z_HUFFMAN_ONLY = 2; + static final private int Z_DEFAULT_STRATEGY = 0; + + static final private int Z_NO_FLUSH = 0; + static final private int Z_PARTIAL_FLUSH = 1; + static final private int Z_SYNC_FLUSH = 2; + static final private int Z_FULL_FLUSH = 3; + static final private int Z_FINISH = 4; + + static final private int Z_OK = 0; + static final private int Z_STREAM_END = 1; + static final private int Z_NEED_DICT = 2; + static final private int Z_ERRNO = -1; + static final private int Z_STREAM_ERROR = -2; + static final private int Z_DATA_ERROR = -3; + static final private int Z_MEM_ERROR = -4; + static final private int Z_BUF_ERROR = -5; + static final private int Z_VERSION_ERROR = -6; + + static final private int INIT_STATE = 42; + static final private int BUSY_STATE = 113; + static final private int FINISH_STATE = 666; + + // The deflate compression method + static final private int Z_DEFLATED = 8; + + static final private int STORED_BLOCK = 0; + static final private int STATIC_TREES = 1; + static final private int DYN_TREES = 2; + + // The three kinds of block type + static final private int Z_BINARY = 0; + static final private int Z_ASCII = 1; + static final private int Z_UNKNOWN = 2; + + static final private int Buf_size = 8 * 2; + + // repeat previous bit length 3-6 times (2 bits of repeat count) + static final private int REP_3_6 = 16; + + // repeat a zero length 3-10 times (3 bits of repeat count) + static final private int REPZ_3_10 = 17; + + // repeat a zero length 11-138 times (7 bits of repeat count) + static final private int REPZ_11_138 = 18; + + static final private int MIN_MATCH = 3; + static final private int MAX_MATCH = 258; + static final private int MIN_LOOKAHEAD = (MAX_MATCH + MIN_MATCH + 1); + + static final private int MAX_BITS = 15; + static final private int D_CODES = 30; + static final private int BL_CODES = 
19;
    static final private int LENGTH_CODES = 29;
    static final private int LITERALS = 256;
    static final private int L_CODES = (LITERALS + 1 + LENGTH_CODES);
    static final private int HEAP_SIZE = (2 * L_CODES + 1);

    // Code index of the end-of-block symbol in the literal/length alphabet.
    static final private int END_BLOCK = 256;

    // ---- per-stream deflate state ----

    ZStream strm; // pointer back to this zlib stream
    int status; // as the name implies
    byte[] pending_buf; // output still pending
    int pending_buf_size; // size of pending_buf
    int pending_out; // next pending byte to output to the stream
    int pending; // nb of bytes in the pending buffer
    int wrap = 1; // NOTE(review): framing selector; default 1 appears to mean zlib wrapper — confirm against deflateInit
    byte data_type; // UNKNOWN, BINARY or ASCII
    byte method; // STORED (for zip only) or DEFLATED
    int last_flush; // value of flush param for previous deflate call

    int w_size; // LZ77 window size (32K by default)
    int w_bits; // log2(w_size) (8..16)
    int w_mask; // w_size - 1

    byte[] window;
    // Sliding window. Input bytes are read into the second half of the window,
    // and move to the first half later to keep a dictionary of at least wSize
    // bytes. With this organization, matches are limited to a distance of
    // wSize-MAX_MATCH bytes, but this ensures that IO is always
    // performed with a length multiple of the block size. Also, it limits
    // the window size to 64K, which is quite useful on MSDOS.
    // To do: use the user input buffer as sliding window.

    int window_size;
    // Actual size of window: 2*wSize, except when the user input buffer
    // is directly used as sliding window.

    short[] prev;
    // Link to older string with same hash index. To limit the size of this
    // array to 64K, this link is maintained only for the last 32K strings.
    // An index in this array is thus a window index modulo 32K.

    short[] head; // Heads of the hash chains or NIL.

    int ins_h; // hash index of string to be inserted
    int hash_size; // number of elements in hash table
    int hash_bits; // log2(hash_size)
    int hash_mask; // hash_size-1

    // Number of bits by which ins_h must be shifted at each input
    // step. It must be such that after MIN_MATCH steps, the oldest
    // byte no longer takes part in the hash key, that is:
    // hash_shift * MIN_MATCH >= hash_bits
    int hash_shift;

    // Window position at the beginning of the current output block. Gets
    // negative when the window is moved backwards.
    int block_start;

    int match_length; // length of best match
    int prev_match; // previous match
    int match_available; // set if previous match exists
    int strstart; // start of string to insert
    int match_start; // start of matching string
    int lookahead; // number of valid bytes ahead in window

    // Length of the best match at previous step. Matches not greater than this
    // are discarded. This is used in the lazy match evaluation.
    int prev_length;

    // To speed up deflation, hash chains are never searched beyond this
    // length. A higher limit improves compression ratio but degrades the speed.
    int max_chain_length;

    // Attempt to find a better match only when the current match is strictly
    // smaller than this value. This mechanism is used only for compression
    // levels >= 4.
    int max_lazy_match;

    // Insert new strings in the hash table only if the match length is not
    // greater than this length. This saves time but degrades compression.
    // max_insert_length is used only for compression levels <= 3.

    int level; // compression level (1..9)
    int strategy; // favor or force Huffman coding

    // Use a faster search when the previous match is longer than this
    int good_match;

    // Stop searching when current match exceeds this
    int nice_match;

    // Dynamic Huffman trees; each node occupies two shorts (freq/code, dad/len).
    short[] dyn_ltree; // literal and length tree
    short[] dyn_dtree; // distance tree
    short[] bl_tree; // Huffman tree for bit lengths

    Tree l_desc = new Tree(); // desc for literal tree
    Tree d_desc = new Tree(); // desc for distance tree
    Tree bl_desc = new Tree(); // desc for bit length tree

    // number of codes at each bit length for an optimal tree
    short[] bl_count = new short[MAX_BITS + 1];
    // working area to be used in Tree#gen_codes()
    short[] next_code = new short[MAX_BITS + 1];

    // heap used to build the Huffman trees
    int[] heap = new int[2 * L_CODES + 1];

    int heap_len; // number of elements in the heap
    int heap_max; // element of largest frequency
    // The sons of heap[n] are heap[2*n] and heap[2*n+1]. heap[0] is not used.
    // The same heap array is used to build all trees.

    // Depth of each subtree used as tie breaker for trees of equal frequency
    byte[] depth = new byte[2 * L_CODES + 1];

    byte[] l_buf; // index for literals or lengths

    // Size of match buffer for literals/lengths. There are 4 reasons for
    // limiting lit_bufsize to 64K:
    // - frequencies can be kept in 16 bit counters
    // - if compression is not successful for the first block, all input
    //   data is still in the window so we can still emit a stored block even
    //   when input comes from standard input. (This can also be done for
    //   all blocks if lit_bufsize is not greater than 32K.)
    // - if compression is not successful for a file smaller than 64K, we can
    //   even emit a stored file instead of a stored block (saving 5 bytes).
    //   This is applicable only for zip (not gzip or zlib).
    // - creating new Huffman trees less frequently may not provide fast
    //   adaptation to changes in the input data statistics. (Take for
    //   example a binary file with poorly compressible code followed by
    //   a highly compressible string table.) Smaller buffer sizes give
    //   fast adaptation but have of course the overhead of transmitting
    //   trees more frequently.
    // - I can't count above 4
    int lit_bufsize;

    int last_lit; // running index in l_buf

    // Buffer for distances. To simplify the code, d_buf and l_buf have
    // the same number of elements. To use different lengths, an extra flag
    // array would be necessary.
    int d_buf; // index into pending_buf

    int opt_len; // bit length of current block with optimal trees
    int static_len; // bit length of current block with static trees
    int matches; // number of string matches in current block
    int last_eob_len; // bit length of EOB code for last block

    // Output buffer. bits are inserted starting at the bottom (least
    // significant bits).
    short bi_buf;

    // Number of valid bits in bi_buf. All bits above the last valid bit
    // are always zero.
    int bi_valid;

    GZIPHeader gheader = null; // optional gzip header; remains null unless set elsewhere

    // Binds this deflate state to its owning stream and allocates the three
    // dynamic Huffman tree arrays (two shorts per node: freq/code and dad/len).
    Deflate(ZStream strm) {
        this.strm = strm;
        dyn_ltree = new short[HEAP_SIZE * 2];
        dyn_dtree = new short[(2 * D_CODES + 1) * 2]; // distance tree
        bl_tree = new short[(2 * BL_CODES + 1) * 2]; // Huffman tree for bit lengths
    }

    // Initializes the LZ77 match-finder state: clears the hash head table and
    // loads the per-level tuning parameters from config_table.
    void lm_init() {
        window_size = 2 * w_size;

        // Clear the whole hash head table (last slot, then the rest).
        head[hash_size - 1] = 0;
        for (int i = 0; i < hash_size - 1; i++) {
            head[i] = 0;
        }

        // Set the default configuration parameters:
        max_lazy_match = Deflate.config_table[level].max_lazy;
        good_match = Deflate.config_table[level].good_length;
        nice_match = Deflate.config_table[level].nice_length;
        max_chain_length = Deflate.config_table[level].max_chain;

        strstart = 0;
        block_start = 0;
        lookahead = 0;
        match_length = prev_length = MIN_MATCH - 1;
        match_available = 0;
        ins_h = 0;
    }

    // Initialize the tree data structures for a new zlib stream.
    void tr_init() {

        l_desc.dyn_tree = dyn_ltree;
        l_desc.stat_desc = StaticTree.static_l_desc;

        d_desc.dyn_tree = dyn_dtree;
        d_desc.stat_desc = StaticTree.static_d_desc;

        bl_desc.dyn_tree = bl_tree;
        bl_desc.stat_desc = StaticTree.static_bl_desc;

        bi_buf = 0;
        bi_valid = 0;
        last_eob_len = 8; // enough lookahead for inflate

        // Initialize the first block of the first file:
        init_block();
    }

    // Resets per-block statistics: zeroes all code frequencies and counters.
    void init_block() {
        // Initialize the trees.
        for (int i = 0; i < L_CODES; i++) dyn_ltree[i * 2] = 0;
        for (int i = 0; i < D_CODES; i++) dyn_dtree[i * 2] = 0;
        for (int i = 0; i < BL_CODES; i++) bl_tree[i * 2] = 0;

        dyn_ltree[END_BLOCK * 2] = 1; // every block ends with exactly one EOB
        opt_len = static_len = 0;
        last_lit = matches = 0;
    }

    // Restore the heap property by moving down the tree starting at node k,
    // exchanging a node with the smallest of its two sons if necessary, stopping
    // when the heap property is re-established (each father smaller than its
    // two sons).
    void pqdownheap(short[] tree, // the tree to restore
                    int k // node to move down
    ) {
        int v = heap[k];
        int j = k << 1; // left son of k
        while (j <= heap_len) {
            // Set j to the smallest of the two sons:
            if (j < heap_len &&
                    smaller(tree, heap[j + 1], heap[j], depth)) {
                j++;
            }
            // Exit if v is smaller than both sons
            if (smaller(tree, v, heap[j], depth)) break;

            // Exchange v with the smallest son
            heap[k] = heap[j];
            k = j;
            // And continue down the tree, setting j to the left son of k
            j <<= 1;
        }
        heap[k] = v;
    }

    // Heap ordering predicate: compares node frequencies (freq is at index 2*n),
    // breaking ties by subtree depth so tree construction is deterministic.
    static boolean smaller(short[] tree, int n, int m, byte[] depth) {
        short tn2 = tree[n * 2];
        short tm2 = tree[m * 2];
        return (tn2 < tm2 ||
                (tn2 == tm2 && depth[n] <= depth[m]));
    }

    // Scan a literal or distance tree to determine the frequencies of the codes
    // in the bit length tree.
    void scan_tree(short[] tree,// the tree to be scanned
                   int max_code // and its largest code of non zero frequency
    ) {
        int n; // iterates over all tree elements
        int prevlen = -1; // last emitted length
        int curlen; // length of current code
        int nextlen = tree[1]; // length of next code
        int count = 0; // repeat count of the current code
        int max_count = 7; // max repeat count
        int min_count = 4; // min repeat count

        if (nextlen == 0) {
            max_count = 138;
            min_count = 3;
        }
        tree[(max_code + 1) * 2 + 1] = (short) 0xffff; // guard

        for (n = 0; n <= max_code; n++) {
            curlen = nextlen;
            nextlen = tree[(n + 1) * 2 + 1];
            if (++count < max_count && curlen == nextlen) {
                continue; // keep extending the current run
            } else if (count < min_count) {
                // run too short for a repeat code: count each length directly
                bl_tree[curlen * 2] += count;
            } else if (curlen != 0) {
                if (curlen != prevlen) bl_tree[curlen * 2]++;
                bl_tree[REP_3_6 * 2]++;
            } else if (count <= 10) {
                bl_tree[REPZ_3_10 * 2]++;
            } else {
                bl_tree[REPZ_11_138 * 2]++;
            }
            count = 0;
            prevlen = curlen;
            // Select the repeat-count window for the next run (RFC 1951 codes 16/17/18).
            if (nextlen == 0) {
                max_count = 138;
                min_count = 3;
            } else if (curlen == nextlen) {
                max_count = 6;
                min_count = 3;
            } else {
                max_count = 7;
                min_count = 4;
            }
        }
    }

    // Construct the Huffman tree for the bit lengths and return the index in
    // bl_order of the last bit length code to send.
    int build_bl_tree() {
        int max_blindex; // index of last bit length code of non zero freq

        // Determine the bit length frequencies for literal and distance trees
        scan_tree(dyn_ltree, l_desc.max_code);
        scan_tree(dyn_dtree, d_desc.max_code);

        // Build the bit length tree:
        bl_desc.build_tree(this);
        // opt_len now includes the length of the tree representations, except
        // the lengths of the bit lengths codes and the 5+5+4 bits for the counts.

        // Determine the number of bit length codes to send. The pkzip format
        // requires that at least 4 bit length codes be sent. (appnote.txt says
        // 3 but the actual value used is 4.)
        for (max_blindex = BL_CODES - 1; max_blindex >= 3; max_blindex--) {
            if (bl_tree[Tree.bl_order[max_blindex] * 2 + 1] != 0) break;
        }
        // Update opt_len to include the bit length tree and counts
        opt_len += 3 * (max_blindex + 1) + 5 + 5 + 4;

        return max_blindex;
    }


    // Send the header for a block using dynamic Huffman trees: the counts, the
    // lengths of the bit length codes, the literal tree and the distance tree.
    // IN assertion: lcodes >= 257, dcodes >= 1, blcodes >= 4.
    void send_all_trees(int lcodes, int dcodes, int blcodes) {
        int rank; // index in bl_order
        send_bits(lcodes - 257, 5); // not +255 as stated in appnote.txt
        send_bits(dcodes - 1, 5);
        send_bits(blcodes - 4, 4); // not -3 as stated in appnote.txt
        for (rank = 0; rank < blcodes; rank++) {
            send_bits(bl_tree[Tree.bl_order[rank] * 2 + 1], 3);
        }
        send_tree(dyn_ltree, lcodes - 1); // literal tree
        send_tree(dyn_dtree, dcodes - 1); // distance tree
    }

    // Send a literal or distance tree in compressed form, using the codes in
    // bl_tree.
+ void send_tree(short[] tree,// the tree to be sent + int max_code // and its largest code of non zero frequency + ) { + int n; // iterates over all tree elements + int prevlen = -1; // last emitted length + int curlen; // length of current code + int nextlen = tree[1]; // length of next code + int count = 0; // repeat count of the current code + int max_count = 7; // max repeat count + int min_count = 4; // min repeat count + + if (nextlen == 0) { + max_count = 138; + min_count = 3; + } + + for (n = 0; n <= max_code; n++) { + curlen = nextlen; + nextlen = tree[(n + 1) * 2 + 1]; + if (++count < max_count && curlen == nextlen) { + continue; + } else if (count < min_count) { + do { + send_code(curlen, bl_tree); + } while (--count != 0); + } else if (curlen != 0) { + if (curlen != prevlen) { + send_code(curlen, bl_tree); + count--; + } + send_code(REP_3_6, bl_tree); + send_bits(count - 3, 2); + } else if (count <= 10) { + send_code(REPZ_3_10, bl_tree); + send_bits(count - 3, 3); + } else { + send_code(REPZ_11_138, bl_tree); + send_bits(count - 11, 7); + } + count = 0; + prevlen = curlen; + if (nextlen == 0) { + max_count = 138; + min_count = 3; + } else if (curlen == nextlen) { + max_count = 6; + min_count = 3; + } else { + max_count = 7; + min_count = 4; + } + } + } + + // Output a byte on the stream. + // IN assertion: there is enough room in pending_buf. 
    // Appends len bytes from p[start..] to the pending output buffer.
    void put_byte(byte[] p, int start, int len) {
        System.arraycopy(p, start, pending_buf, pending, len);
        pending += len;
    }

    void put_byte(byte c) {
        pending_buf[pending++] = c;
    }

    // Little-endian 16-bit write (low byte first).
    void put_short(int w) {
        put_byte((byte) (w/*&0xff*/));
        put_byte((byte) (w >>> 8));
    }

    // Big-endian 16-bit write (high byte first), used for zlib/gzip headers.
    void putShortMSB(int b) {
        put_byte((byte) (b >> 8));
        put_byte((byte) (b/*&0xff*/));
    }

    // Emits the Huffman code for symbol c from the given tree
    // (code bits at tree[2*c], code length at tree[2*c+1]).
    void send_code(int c, short[] tree) {
        int c2 = c * 2;
        send_bits((tree[c2] & 0xffff), (tree[c2 + 1] & 0xffff));
    }

    // Inserts `length` bits of `value` into the bit buffer, flushing 16 bits
    // to pending_buf whenever the buffer would overflow Buf_size bits.
    void send_bits(int value, int length) {
        int len = length;
        if (bi_valid > Buf_size - len) {
            int val = value;
//      bi_buf |= (val << bi_valid);
            bi_buf |= ((val << bi_valid) & 0xffff);
            put_short(bi_buf);
            // keep the bits that did not fit; compound |= truncates back to short
            bi_buf = (short) (val >>> (Buf_size - bi_valid));
            bi_valid += len - Buf_size;
        } else {
//      bi_buf |= (value) << bi_valid;
            bi_buf |= (((value) << bi_valid) & 0xffff);
            bi_valid += len;
        }
    }

    // Send one empty static block to give enough lookahead for inflate.
    // This takes 10 bits, of which 7 may remain in the bit buffer.
    // The current inflate code requires 9 bits of lookahead. If the
    // last two codes for the previous block (real code plus EOB) were coded
    // on 5 bits or less, inflate may have only 5+3 bits of lookahead to decode
    // the last real code. In this case we send two empty static blocks instead
    // of one. (There are no problems if the previous block is stored or fixed.)
    // To simplify the code, we assume the worst case of last real code encoded
    // on one bit only.
    void _tr_align() {
        send_bits(STATIC_TREES << 1, 3);
        send_code(END_BLOCK, StaticTree.static_ltree);

        bi_flush();

        // Of the 10 bits for the empty block, we have already sent
        // (10 - bi_valid) bits. The lookahead for the last real code (before
        // the EOB of the previous block) was thus at least one plus the length
        // of the EOB plus what we have just sent of the empty static block.
        if (1 + last_eob_len + 10 - bi_valid < 9) {
            send_bits(STATIC_TREES << 1, 3);
            send_code(END_BLOCK, StaticTree.static_ltree);
            bi_flush();
        }
        last_eob_len = 7;
    }


    // Save the match info and tally the frequency counts. Return true if
    // the current block must be flushed.
    boolean _tr_tally(int dist, // distance of matched string
                      int lc // match length-MIN_MATCH or unmatched char (if dist==0)
    ) {

        // distance is stored big-endian in pending_buf starting at d_buf
        pending_buf[d_buf + last_lit * 2] = (byte) (dist >>> 8);
        pending_buf[d_buf + last_lit * 2 + 1] = (byte) dist;

        l_buf[last_lit] = (byte) lc;
        last_lit++;

        if (dist == 0) {
            // lc is the unmatched char
            dyn_ltree[lc * 2]++;
        } else {
            matches++;
            // Here, lc is the match length - MIN_MATCH
            dist--; // dist = match distance - 1
            dyn_ltree[(Tree._length_code[lc] + LITERALS + 1) * 2]++;
            dyn_dtree[Tree.d_code(dist) * 2]++;
        }

        if ((last_lit & 0x1fff) == 0 && level > 2) {
            // Compute an upper bound for the compressed length
            int out_length = last_lit * 8;
            int in_length = strstart - block_start;
            int dcode;
            for (dcode = 0; dcode < D_CODES; dcode++) {
                out_length += (int) dyn_dtree[dcode * 2] *
                        (5L + Tree.extra_dbits[dcode]);
            }
            out_length >>>= 3;
            // flush early when matches are rare and compression looks poor
            if ((matches < (last_lit / 2)) && out_length < in_length / 2) return true;
        }

        return (last_lit == lit_bufsize - 1);
        // We avoid equality with lit_bufsize because of wraparound at 64K
        // on 16 bit machines and because stored blocks are restricted to
        // 64K-1 bytes.
    }

    // Send the block data compressed using the given Huffman trees
    void compress_block(short[] ltree, short[] dtree) {
        int dist; // distance of matched string
        int lc; // match length or unmatched char (if dist == 0)
        int lx = 0; // running index in l_buf
        int code; // the code to send
        int extra; // number of extra bits to send

        if (last_lit != 0) {
            do {
                // reassemble the big-endian distance recorded by _tr_tally
                dist = ((pending_buf[d_buf + lx * 2] << 8) & 0xff00) |
                        (pending_buf[d_buf + lx * 2 + 1] & 0xff);
                lc = (l_buf[lx]) & 0xff;
                lx++;

                if (dist == 0) {
                    send_code(lc, ltree); // send a literal byte
                } else {
                    // Here, lc is the match length - MIN_MATCH
                    code = Tree._length_code[lc];

                    send_code(code + LITERALS + 1, ltree); // send the length code
                    extra = Tree.extra_lbits[code];
                    if (extra != 0) {
                        lc -= Tree.base_length[code];
                        send_bits(lc, extra); // send the extra length bits
                    }
                    dist--; // dist is now the match distance - 1
                    code = Tree.d_code(dist);

                    send_code(code, dtree); // send the distance code
                    extra = Tree.extra_dbits[code];
                    if (extra != 0) {
                        dist -= Tree.base_dist[code];
                        send_bits(dist, extra); // send the extra distance bits
                    }
                } // literal or match pair ?

                // Check that the overlay between pending_buf and d_buf+l_buf is ok:
            }
            while (lx < last_lit);
        }

        send_code(END_BLOCK, ltree);
        last_eob_len = ltree[END_BLOCK * 2 + 1];
    }

    // Set the data type to ASCII or BINARY, using a crude approximation:
    // binary if more than 20% of the bytes are <= 6 or >= 128, ascii otherwise.
    // IN assertion: the fields freq of dyn_ltree are set and the total of all
    // frequencies does not exceed 64K (to fit in an int on 16 bit machines).
+ void set_data_type() { + int n = 0; + int ascii_freq = 0; + int bin_freq = 0; + while (n < 7) { + bin_freq += dyn_ltree[n * 2]; + n++; + } + while (n < 128) { + ascii_freq += dyn_ltree[n * 2]; + n++; + } + while (n < LITERALS) { + bin_freq += dyn_ltree[n * 2]; + n++; + } + data_type = (byte) (bin_freq > (ascii_freq >>> 2) ? Z_BINARY : Z_ASCII); + } + + // Flush the bit buffer, keeping at most 7 bits in it. + void bi_flush() { + if (bi_valid == 16) { + put_short(bi_buf); + bi_buf = 0; + bi_valid = 0; + } else if (bi_valid >= 8) { + put_byte((byte) bi_buf); + bi_buf >>>= 8; + bi_valid -= 8; + } + } + + // Flush the bit buffer and align the output on a byte boundary + void bi_windup() { + if (bi_valid > 8) { + put_short(bi_buf); + } else if (bi_valid > 0) { + put_byte((byte) bi_buf); + } + bi_buf = 0; + bi_valid = 0; + } + + // Copy a stored block, storing first the length and its + // one's complement if requested. + void copy_block(int buf, // the input data + int len, // its length + boolean header // true if block header must be written + ) { + int index = 0; + bi_windup(); // align on byte boundary + last_eob_len = 8; // enough lookahead for inflate + + if (header) { + put_short((short) len); + put_short((short) ~len); + } + + // while(len--!=0) { + // put_byte(window[buf+index]); + // index++; + // } + put_byte(window, buf, len); + } + + void flush_block_only(boolean eof) { + _tr_flush_block(block_start >= 0 ? block_start : -1, + strstart - block_start, + eof); + block_start = strstart; + strm.flush_pending(); + } + + // Copy without compression as much as possible from the input stream, return + // the current block state. + // This function does not insert new strings in the dictionary since + // uncompressible data is probably not useful. This function is used + // only for the level=0 compression option. + // NOTE: this function should be optimized to avoid extra copying from + // window to pending_buf. 
    // Copies input to output without compression (level 0), returning one of
    // NeedMore / BlockDone / FinishStarted / FinishDone.
    int deflate_stored(int flush) {
        // Stored blocks are limited to 0xffff bytes, pending_buf is limited
        // to pending_buf_size, and each stored block has a 5 byte header:

        int max_block_size = 0xffff;
        int max_start;

        if (max_block_size > pending_buf_size - 5) {
            max_block_size = pending_buf_size - 5;
        }

        // Copy as much as possible from input to output:
        while (true) {
            // Fill the window as much as possible:
            if (lookahead <= 1) {
                fill_window();
                if (lookahead == 0 && flush == Z_NO_FLUSH) return NeedMore;
                if (lookahead == 0) break; // flush the current block
            }

            strstart += lookahead;
            lookahead = 0;

            // Emit a stored block if pending_buf will be full:
            max_start = block_start + max_block_size;
            if (strstart == 0 || strstart >= max_start) {
                // strstart == 0 is possible when wraparound on 16-bit machine
                lookahead = strstart - max_start;
                strstart = max_start;

                flush_block_only(false);
                if (strm.avail_out == 0) return NeedMore;

            }

            // Flush if we may have to slide, otherwise block_start may become
            // negative and the data will be gone:
            if (strstart - block_start >= w_size - MIN_LOOKAHEAD) {
                flush_block_only(false);
                if (strm.avail_out == 0) return NeedMore;
            }
        }

        flush_block_only(flush == Z_FINISH);
        if (strm.avail_out == 0)
            return (flush == Z_FINISH) ? FinishStarted : NeedMore;

        return flush == Z_FINISH ? FinishDone : BlockDone;
    }

    // Send a stored block
    void _tr_stored_block(int buf, // input block
                          int stored_len, // length of input block
                          boolean eof // true if this is the last block for a file
    ) {
        // block type (STORED_BLOCK << 1) + eof; STORED_BLOCK is 0
        send_bits((eof ? 1 : 0), 3); // send block type
        copy_block(buf, stored_len, true); // with header
    }

    // Determine the best encoding for the current block: dynamic trees, static
    // trees or store, and output the encoded block to the zip file.
    void _tr_flush_block(int buf, // input block, or NULL if too old
                         int stored_len, // length of input block
                         boolean eof // true if this is the last block for a file
    ) {
        int opt_lenb, static_lenb;// opt_len and static_len in bytes
        int max_blindex = 0; // index of last bit length code of non zero freq

        // Build the Huffman trees unless a stored block is forced
        if (level > 0) {
            // Check if the file is ascii or binary
            if (data_type == Z_UNKNOWN) set_data_type();

            // Construct the literal and distance trees
            l_desc.build_tree(this);

            d_desc.build_tree(this);

            // At this point, opt_len and static_len are the total bit lengths of
            // the compressed block data, excluding the tree representations.

            // Build the bit length tree for the above two trees, and get the index
            // in bl_order of the last bit length code to send.
            max_blindex = build_bl_tree();

            // Determine the best encoding. Compute first the block length in bytes
            opt_lenb = (opt_len + 3 + 7) >>> 3;
            static_lenb = (static_len + 3 + 7) >>> 3;

            if (static_lenb <= opt_lenb) opt_lenb = static_lenb;
        } else {
            opt_lenb = static_lenb = stored_len + 5; // force a stored block
        }

        if (stored_len + 4 <= opt_lenb && buf != -1) {
            // 4: two words for the lengths
            // The test buf != NULL is only necessary if LIT_BUFSIZE > WSIZE.
            // Otherwise we can't have processed more than WSIZE input bytes since
            // the last block flush, because compression would have been
            // successful. If LIT_BUFSIZE <= WSIZE, it is never too late to
            // transform a block into a stored block.
            _tr_stored_block(buf, stored_len, eof);
        } else if (static_lenb == opt_lenb) {
            send_bits((STATIC_TREES << 1) + (eof ? 1 : 0), 3);
            compress_block(StaticTree.static_ltree, StaticTree.static_dtree);
        } else {
            send_bits((DYN_TREES << 1) + (eof ? 1 : 0), 3);
            send_all_trees(l_desc.max_code + 1, d_desc.max_code + 1, max_blindex + 1);
            compress_block(dyn_ltree, dyn_dtree);
        }

        // The above check is made mod 2^32, for files larger than 512 MB
        // and uLong implemented on 32 bits.

        init_block();

        if (eof) {
            bi_windup();
        }
    }

    // Fill the window when the lookahead becomes insufficient.
    // Updates strstart and lookahead.
    //
    // IN assertion: lookahead < MIN_LOOKAHEAD
    // OUT assertions: strstart <= window_size-MIN_LOOKAHEAD
    // At least one byte has been read, or avail_in == 0; reads are
    // performed for at least two bytes (required for the zip translate_eol
    // option -- not supported here).
    void fill_window() {
        int n, m;
        int p;
        int more; // Amount of free space at the end of the window.

        do {
            more = (window_size - lookahead - strstart);

            // Deal with !@#$% 64K limit:
            if (more == 0 && strstart == 0 && lookahead == 0) {
                more = w_size;
            } else if (more == -1) {
                // Very unlikely, but possible on 16 bit machine if strstart == 0
                // and lookahead == 1 (input done one byte at time)
                more--;

                // If the window is almost full and there is insufficient lookahead,
                // move the upper half to the lower one to make room in the upper half.
            } else if (strstart >= w_size + w_size - MIN_LOOKAHEAD) {
                System.arraycopy(window, w_size, window, 0, w_size);
                match_start -= w_size;
                strstart -= w_size; // we now have strstart >= MAX_DIST
                block_start -= w_size;

                // Slide the hash table (could be avoided with 32 bit values
                // at the expense of memory usage). We slide even when level == 0
                // to keep the hash table consistent if we switch back to level > 0
                // later. (Using level 0 permanently is not an optimal usage of
                // zlib, so we don't care about this pathological case.)

                // Rebase every head entry by w_size; entries now out of the
                // window become 0 (NIL).
                n = hash_size;
                p = n;
                do {
                    m = (head[--p] & 0xffff);
                    head[p] = (m >= w_size ? (short) (m - w_size) : 0);
                }
                while (--n != 0);

                // Same rebase for the prev chain links.
                n = w_size;
                p = n;
                do {
                    m = (prev[--p] & 0xffff);
                    prev[p] = (m >= w_size ? (short) (m - w_size) : 0);
                    // If n is not on any hash chain, prev[n] is garbage but
                    // its value will never be used.
                }
                while (--n != 0);
                more += w_size;
            }

            if (strm.avail_in == 0) return;

            // If there was no sliding:
            // strstart <= WSIZE+MAX_DIST-1 && lookahead <= MIN_LOOKAHEAD - 1 &&
            // more == window_size - lookahead - strstart
            // => more >= window_size - (MIN_LOOKAHEAD-1 + WSIZE + MAX_DIST-1)
            // => more >= window_size - 2*WSIZE + 2
            // In the BIG_MEM or MMAP case (not yet supported),
            // window_size == input_size + MIN_LOOKAHEAD &&
            // strstart + s->lookahead <= input_size => more >= MIN_LOOKAHEAD.
            // Otherwise, window_size == 2*WSIZE so more >= 2.
            // If there was sliding, more >= WSIZE. So in all cases, more >= 2.

            n = strm.read_buf(window, strstart + lookahead, more);
            lookahead += n;

            // Initialize the hash value now that we have some input:
            if (lookahead >= MIN_MATCH) {
                ins_h = window[strstart] & 0xff;
                ins_h = (((ins_h) << hash_shift) ^ (window[strstart + 1] & 0xff)) & hash_mask;
            }
            // If the whole input has less than MIN_MATCH bytes, ins_h is garbage,
            // but this is not important since only literal bytes will be emitted.
        }
        while (lookahead < MIN_LOOKAHEAD && strm.avail_in != 0);
    }

    // Compress as much as possible from the input stream, return the current
    // block state.
    // This function does not perform lazy evaluation of matches and inserts
    // new strings in the dictionary only for unmatched strings or for short
    // matches. It is used only for the fast compression options.
    int deflate_fast(int flush) {
//    short hash_head = 0; // head of the hash chain
        int hash_head = 0; // head of the hash chain
        boolean bflush; // set if current block must be flushed

        while (true) {
            // Make sure that we always have enough lookahead, except
            // at the end of the input file. We need MAX_MATCH bytes
            // for the next match, plus MIN_MATCH bytes to insert the
            // string following the next match.
            if (lookahead < MIN_LOOKAHEAD) {
                fill_window();
                if (lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) {
                    return NeedMore;
                }
                if (lookahead == 0) break; // flush the current block
            }

            // Insert the string window[strstart .. strstart+2] in the
            // dictionary, and set hash_head to the head of the hash chain:
            if (lookahead >= MIN_MATCH) {
                ins_h = (((ins_h) << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask;

//        prev[strstart&w_mask]=hash_head=head[ins_h];
                hash_head = (head[ins_h] & 0xffff);
                prev[strstart & w_mask] = head[ins_h];
                head[ins_h] = (short) strstart;
            }

            // Find the longest match, discarding those <= prev_length.
            // At this point we have always match_length < MIN_MATCH

            if (hash_head != 0L &&
                    ((strstart - hash_head) & 0xffff) <= w_size - MIN_LOOKAHEAD
            ) {
                // To simplify the code, we prevent matches with the string
                // of window index 0 (in particular we have to avoid a match
                // of the string with itself at the start of the input file).
                if (strategy != Z_HUFFMAN_ONLY) {
                    match_length = longest_match(hash_head);
                }
                // longest_match() sets match_start
            }
            if (match_length >= MIN_MATCH) {
                // check_match(strstart, match_start, match_length);

                bflush = _tr_tally(strstart - match_start, match_length - MIN_MATCH);

                lookahead -= match_length;

                // Insert new strings in the hash table only if the match length
                // is not too large. This saves time but degrades compression.
                if (match_length <= max_lazy_match &&
                        lookahead >= MIN_MATCH) {
                    match_length--; // string at strstart already in hash table
                    do {
                        strstart++;

                        ins_h = ((ins_h << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask;
//            prev[strstart&w_mask]=hash_head=head[ins_h];
                        hash_head = (head[ins_h] & 0xffff);
                        prev[strstart & w_mask] = head[ins_h];
                        head[ins_h] = (short) strstart;

                        // strstart never exceeds WSIZE-MAX_MATCH, so there are
                        // always MIN_MATCH bytes ahead.
                    }
                    while (--match_length != 0);
                    strstart++;
                } else {
                    strstart += match_length;
                    match_length = 0;
                    ins_h = window[strstart] & 0xff;

                    ins_h = (((ins_h) << hash_shift) ^ (window[strstart + 1] & 0xff)) & hash_mask;
                    // If lookahead < MIN_MATCH, ins_h is garbage, but it does not
                    // matter since it will be recomputed at next deflate call.
                }
            } else {
                // No match, output a literal byte

                bflush = _tr_tally(0, window[strstart] & 0xff);
                lookahead--;
                strstart++;
            }
            if (bflush) {

                flush_block_only(false);
                if (strm.avail_out == 0) return NeedMore;
            }
        }

        flush_block_only(flush == Z_FINISH);
        if (strm.avail_out == 0) {
            if (flush == Z_FINISH) return FinishStarted;
            else return NeedMore;
        }
        return flush == Z_FINISH ? FinishDone : BlockDone;
    }

    // Same as above, but achieves better compression. We use a lazy
    // evaluation for matches: a match is finally adopted only if there is
    // no better match at the next window position.
    int deflate_slow(int flush) {
//    short hash_head = 0; // head of hash chain
        int hash_head = 0; // head of hash chain
        boolean bflush; // set if current block must be flushed

        // Process the input block.
        while (true) {
            // Make sure that we always have enough lookahead, except
            // at the end of the input file. We need MAX_MATCH bytes
            // for the next match, plus MIN_MATCH bytes to insert the
            // string following the next match.

            if (lookahead < MIN_LOOKAHEAD) {
                fill_window();
                if (lookahead < MIN_LOOKAHEAD && flush == Z_NO_FLUSH) {
                    return NeedMore;
                }
                if (lookahead == 0) break; // flush the current block
            }

            // Insert the string window[strstart .. strstart+2] in the
            // dictionary, and set hash_head to the head of the hash chain:

            if (lookahead >= MIN_MATCH) {
                ins_h = (((ins_h) << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask;
//        prev[strstart&w_mask]=hash_head=head[ins_h];
                hash_head = (head[ins_h] & 0xffff);
                prev[strstart & w_mask] = head[ins_h];
                head[ins_h] = (short) strstart;
            }

            // Find the longest match, discarding those <= prev_length.
            prev_length = match_length;
            prev_match = match_start;
            match_length = MIN_MATCH - 1;

            if (hash_head != 0 && prev_length < max_lazy_match &&
                    ((strstart - hash_head) & 0xffff) <= w_size - MIN_LOOKAHEAD
            ) {
                // To simplify the code, we prevent matches with the string
                // of window index 0 (in particular we have to avoid a match
                // of the string with itself at the start of the input file).

                if (strategy != Z_HUFFMAN_ONLY) {
                    match_length = longest_match(hash_head);
                }
                // longest_match() sets match_start

                if (match_length <= 5 && (strategy == Z_FILTERED ||
                        (match_length == MIN_MATCH &&
                                strstart - match_start > 4096))) {

                    // If prev_match is also MIN_MATCH, match_start is garbage
                    // but we will ignore the current match anyway.
                    match_length = MIN_MATCH - 1;
                }
            }

            // If there was a match at the previous step and the current
            // match is not better, output the previous match:
            if (prev_length >= MIN_MATCH && match_length <= prev_length) {
                int max_insert = strstart + lookahead - MIN_MATCH;
                // Do not insert strings in hash table beyond this.

                // check_match(strstart-1, prev_match, prev_length);

                bflush = _tr_tally(strstart - 1 - prev_match, prev_length - MIN_MATCH);

                // Insert in hash table all strings up to the end of the match.
                // strstart-1 and strstart are already inserted. If there is not
                // enough lookahead, the last two strings are not inserted in
                // the hash table.
                lookahead -= prev_length - 1;
                prev_length -= 2;
                do {
                    if (++strstart <= max_insert) {
                        ins_h = (((ins_h) << hash_shift) ^ (window[(strstart) + (MIN_MATCH - 1)] & 0xff)) & hash_mask;
                        //prev[strstart&w_mask]=hash_head=head[ins_h];
                        hash_head = (head[ins_h] & 0xffff);
                        prev[strstart & w_mask] = head[ins_h];
                        head[ins_h] = (short) strstart;
                    }
                }
                while (--prev_length != 0);
                match_available = 0;
                match_length = MIN_MATCH - 1;
                strstart++;

                if (bflush) {
                    flush_block_only(false);
                    if (strm.avail_out == 0) return NeedMore;
                }
            } else if (match_available != 0) {

                // If there was no match at the previous position, output a
                // single literal. If there was a match but the current match
                // is longer, truncate the previous match to a single literal.

                bflush = _tr_tally(0, window[strstart - 1] & 0xff);

                if (bflush) {
                    flush_block_only(false);
                }
                strstart++;
                lookahead--;
                if (strm.avail_out == 0) return NeedMore;
            } else {
                // There is no previous match to compare with, wait for
                // the next step to decide.

                match_available = 1;
                strstart++;
                lookahead--;
            }
        }

        // Emit any literal still deferred by the lazy evaluation above.
        if (match_available != 0) {
            bflush = _tr_tally(0, window[strstart - 1] & 0xff);
            match_available = 0;
        }
        flush_block_only(flush == Z_FINISH);

        if (strm.avail_out == 0) {
            if (flush == Z_FINISH) return FinishStarted;
            else return NeedMore;
        }

        return flush == Z_FINISH ? FinishDone : BlockDone;
    }

    int longest_match(int cur_match) {
        int chain_length = max_chain_length; // max hash chain length
        int scan = strstart; // current string
        int match; // matched string
        int len; // length of current match
        int best_len = prev_length; // best match length so far
        int limit = strstart > (w_size - MIN_LOOKAHEAD) ?
+ strstart - (w_size - MIN_LOOKAHEAD) : 0; + int nice_match = this.nice_match; + + // Stop when cur_match becomes <= limit. To simplify the code, + // we prevent matches with the string of window index 0. + + int wmask = w_mask; + + int strend = strstart + MAX_MATCH; + byte scan_end1 = window[scan + best_len - 1]; + byte scan_end = window[scan + best_len]; + + // The code is optimized for HASH_BITS >= 8 and MAX_MATCH-2 multiple of 16. + // It is easy to get rid of this optimization if necessary. + + // Do not waste too much time if we already have a good match: + if (prev_length >= good_match) { + chain_length >>= 2; + } + + // Do not look for matches beyond the end of the input. This is necessary + // to make deflate deterministic. + if (nice_match > lookahead) nice_match = lookahead; + + do { + match = cur_match; + + // Skip to next match if the match length cannot increase + // or if the match length is less than 2: + if (window[match + best_len] != scan_end || + window[match + best_len - 1] != scan_end1 || + window[match] != window[scan] || + window[++match] != window[scan + 1]) continue; + + // The check at best_len-1 can be removed because it will be made + // again later. (This heuristic is not always a win.) + // It is not necessary to compare scan[2] and match[2] since they + // are always equal when the other bytes match, given that + // the hash keys are equal and that HASH_BITS >= 8. + scan += 2; + match++; + + // We check for insufficient lookahead only every 8th comparison; + // the 256th check will be made at strstart+258. 
+ do { + } while (window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + window[++scan] == window[++match] && + scan < strend); + + len = MAX_MATCH - (strend - scan); + scan = strend - MAX_MATCH; + + if (len > best_len) { + match_start = cur_match; + best_len = len; + if (len >= nice_match) break; + scan_end1 = window[scan + best_len - 1]; + scan_end = window[scan + best_len]; + } + + } while ((cur_match = (prev[cur_match & wmask] & 0xffff)) > limit + && --chain_length != 0); + + if (best_len <= lookahead) return best_len; + return lookahead; + } + + int deflateInit(int level, int bits, int memlevel) { + return deflateInit(level, Z_DEFLATED, bits, memlevel, + Z_DEFAULT_STRATEGY); + } + + int deflateInit(int level, int bits) { + return deflateInit(level, Z_DEFLATED, bits, DEF_MEM_LEVEL, + Z_DEFAULT_STRATEGY); + } + + int deflateInit(int level) { + return deflateInit(level, MAX_WBITS); + } + + private int deflateInit(int level, int method, int windowBits, + int memLevel, int strategy) { + int wrap = 1; + // byte[] my_version=ZLIB_VERSION; + + // + // if (version == null || version[0] != my_version[0] + // || stream_size != sizeof(z_stream)) { + // return Z_VERSION_ERROR; + // } + + strm.msg = null; + + if (level == Z_DEFAULT_COMPRESSION) level = 6; + + if (windowBits < 0) { // undocumented feature: suppress zlib header + wrap = 0; + windowBits = -windowBits; + } else if (windowBits > 15) { + wrap = 2; + windowBits -= 16; + strm.adler = new CRC32(); + } + + if (memLevel < 1 || memLevel > MAX_MEM_LEVEL || + method != Z_DEFLATED || + windowBits < 9 || windowBits > 15 || level < 0 || level > 9 || + strategy < 0 || strategy > Z_HUFFMAN_ONLY) { + return Z_STREAM_ERROR; + } + + strm.dstate = this; + + this.wrap = wrap; + w_bits = windowBits; + w_size = 1 << 
w_bits; + w_mask = w_size - 1; + + hash_bits = memLevel + 7; + hash_size = 1 << hash_bits; + hash_mask = hash_size - 1; + hash_shift = ((hash_bits + MIN_MATCH - 1) / MIN_MATCH); + + window = new byte[w_size * 2]; + prev = new short[w_size]; + head = new short[hash_size]; + + lit_bufsize = 1 << (memLevel + 6); // 16K elements by default + + // We overlay pending_buf and d_buf+l_buf. This works since the average + // output size for (length,distance) codes is <= 24 bits. + pending_buf = new byte[lit_bufsize * 3]; + pending_buf_size = lit_bufsize * 3; + + d_buf = lit_bufsize; + l_buf = new byte[lit_bufsize]; + + this.level = level; + + this.strategy = strategy; + this.method = (byte) method; + + return deflateReset(); + } + + int deflateReset() { + strm.total_in = strm.total_out = 0; + strm.msg = null; // + strm.data_type = Z_UNKNOWN; + + pending = 0; + pending_out = 0; + + if (wrap < 0) { + wrap = -wrap; + } + status = (wrap == 0) ? BUSY_STATE : INIT_STATE; + strm.adler.reset(); + + last_flush = Z_NO_FLUSH; + + tr_init(); + lm_init(); + return Z_OK; + } + + int deflateEnd() { + if (status != INIT_STATE && status != BUSY_STATE && status != FINISH_STATE) { + return Z_STREAM_ERROR; + } + // Deallocate in reverse order of allocations: + pending_buf = null; + l_buf = null; + head = null; + prev = null; + window = null; + // free + // dstate=null; + return status == BUSY_STATE ? 
Z_DATA_ERROR : Z_OK; + } + + int deflateParams(int _level, int _strategy) { + int err = Z_OK; + + if (_level == Z_DEFAULT_COMPRESSION) { + _level = 6; + } + if (_level < 0 || _level > 9 || + _strategy < 0 || _strategy > Z_HUFFMAN_ONLY) { + return Z_STREAM_ERROR; + } + + if (config_table[level].func != config_table[_level].func && + strm.total_in != 0) { + // Flush the last buffer: + err = strm.deflate(Z_PARTIAL_FLUSH); + } + + if (level != _level) { + level = _level; + max_lazy_match = config_table[level].max_lazy; + good_match = config_table[level].good_length; + nice_match = config_table[level].nice_length; + max_chain_length = config_table[level].max_chain; + } + strategy = _strategy; + return err; + } + + int deflateSetDictionary(byte[] dictionary, int dictLength) { + int length = dictLength; + int index = 0; + + if (dictionary == null || status != INIT_STATE) + return Z_STREAM_ERROR; + + strm.adler.update(dictionary, 0, dictLength); + + if (length < MIN_MATCH) return Z_OK; + if (length > w_size - MIN_LOOKAHEAD) { + length = w_size - MIN_LOOKAHEAD; + index = dictLength - length; // use the tail of the dictionary + } + System.arraycopy(dictionary, index, window, 0, length); + strstart = length; + block_start = length; + + // Insert all strings in the hash table (except for the last two bytes). + // s->lookahead stays null, so s->ins_h will be recomputed at the next + // call of fill_window. 
+ + ins_h = window[0] & 0xff; + ins_h = (((ins_h) << hash_shift) ^ (window[1] & 0xff)) & hash_mask; + + for (int n = 0; n <= length - MIN_MATCH; n++) { + ins_h = (((ins_h) << hash_shift) ^ (window[(n) + (MIN_MATCH - 1)] & 0xff)) & hash_mask; + prev[n & w_mask] = head[ins_h]; + head[ins_h] = (short) n; + } + return Z_OK; + } + + int deflate(int flush) { + int old_flush; + + if (flush > Z_FINISH || flush < 0) { + return Z_STREAM_ERROR; + } + + if (strm.next_out == null || + (strm.next_in == null && strm.avail_in != 0) || + (status == FINISH_STATE && flush != Z_FINISH)) { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_STREAM_ERROR)]; + return Z_STREAM_ERROR; + } + if (strm.avail_out == 0) { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_BUF_ERROR)]; + return Z_BUF_ERROR; + } + + old_flush = last_flush; + last_flush = flush; + + // Write the zlib header + if (status == INIT_STATE) { + if (wrap == 2) { + getGZIPHeader().put(this); + status = BUSY_STATE; + strm.adler.reset(); + } else { + int header = (Z_DEFLATED + ((w_bits - 8) << 4)) << 8; + int level_flags = ((level - 1) & 0xff) >> 1; + + if (level_flags > 3) level_flags = 3; + header |= (level_flags << 6); + if (strstart != 0) header |= PRESET_DICT; + header += 31 - (header % 31); + + status = BUSY_STATE; + putShortMSB(header); + + + // Save the adler32 of the preset dictionary: + if (strstart != 0) { + long adler = strm.adler.getValue(); + putShortMSB((int) (adler >>> 16)); + putShortMSB((int) (adler & 0xffff)); + } + strm.adler.reset(); + } + } + + // Flush as much pending output as possible + if (pending != 0) { + strm.flush_pending(); + if (strm.avail_out == 0) { + // Since avail_out is 0, deflate will be called again with + // more output space, but possibly with both pending and + // avail_in equal to zero. 
There won't be anything to do, + // but this is not an error situation so make sure we + // return OK instead of BUF_ERROR at next call of deflate: + last_flush = -1; + return Z_OK; + } + + // Make sure there is something to do and avoid duplicate consecutive + // flushes. For repeated and useless calls with Z_FINISH, we keep + // returning Z_STREAM_END instead of Z_BUFF_ERROR. + } else if (strm.avail_in == 0 && flush <= old_flush && + flush != Z_FINISH) { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_BUF_ERROR)]; + return Z_BUF_ERROR; + } + + // User must not provide more input after the first FINISH: + if (status == FINISH_STATE && strm.avail_in != 0) { + strm.msg = z_errmsg[Z_NEED_DICT - (Z_BUF_ERROR)]; + return Z_BUF_ERROR; + } + + // Start a new block or continue the current one. + if (strm.avail_in != 0 || lookahead != 0 || + (flush != Z_NO_FLUSH && status != FINISH_STATE)) { + int bstate = -1; + switch (config_table[level].func) { + case STORED: + bstate = deflate_stored(flush); + break; + case FAST: + bstate = deflate_fast(flush); + break; + case SLOW: + bstate = deflate_slow(flush); + break; + default: + } + + if (bstate == FinishStarted || bstate == FinishDone) { + status = FINISH_STATE; + } + if (bstate == NeedMore || bstate == FinishStarted) { + if (strm.avail_out == 0) { + last_flush = -1; // avoid BUF_ERROR next call, see above + } + return Z_OK; + // If flush != Z_NO_FLUSH && avail_out == 0, the next call + // of deflate should use the same flush parameter to make sure + // that the flush is complete. So we don't have to output an + // empty block here, this will be done at next call. This also + // ensures that for a very small output buffer, we emit at most + // one empty block. + } + + if (bstate == BlockDone) { + if (flush == Z_PARTIAL_FLUSH) { + _tr_align(); + } else { // FULL_FLUSH or SYNC_FLUSH + _tr_stored_block(0, 0, false); + // For a full flush, this empty block will be recognized + // as a special marker by inflate_sync(). 
+ if (flush == Z_FULL_FLUSH) { + //state.head[s.hash_size-1]=0; + for (int i = 0; i < hash_size/*-1*/; i++) // forget history + head[i] = 0; + } + } + strm.flush_pending(); + if (strm.avail_out == 0) { + last_flush = -1; // avoid BUF_ERROR at next call, see above + return Z_OK; + } + } + } + + if (flush != Z_FINISH) return Z_OK; + if (wrap <= 0) return Z_STREAM_END; + + if (wrap == 2) { + long adler = strm.adler.getValue(); + put_byte((byte) (adler & 0xff)); + put_byte((byte) ((adler >> 8) & 0xff)); + put_byte((byte) ((adler >> 16) & 0xff)); + put_byte((byte) ((adler >> 24) & 0xff)); + put_byte((byte) (strm.total_in & 0xff)); + put_byte((byte) ((strm.total_in >> 8) & 0xff)); + put_byte((byte) ((strm.total_in >> 16) & 0xff)); + put_byte((byte) ((strm.total_in >> 24) & 0xff)); + + getGZIPHeader().setCRC(adler); + } else { + // Write the zlib trailer (adler32) + long adler = strm.adler.getValue(); + putShortMSB((int) (adler >>> 16)); + putShortMSB((int) (adler & 0xffff)); + } + + strm.flush_pending(); + + // If avail_out is zero, the application will call deflate again + // to flush the rest. + + if (wrap > 0) wrap = -wrap; // write the trailer only once! + return pending != 0 ? 
Z_OK : Z_STREAM_END; + } + + static int deflateCopy(ZStream dest, ZStream src) { + + if (src.dstate == null) { + return Z_STREAM_ERROR; + } + + if (src.next_in != null) { + dest.next_in = new byte[src.next_in.length]; + System.arraycopy(src.next_in, 0, dest.next_in, 0, src.next_in.length); + } + dest.next_in_index = src.next_in_index; + dest.avail_in = src.avail_in; + dest.total_in = src.total_in; + + if (src.next_out != null) { + dest.next_out = new byte[src.next_out.length]; + System.arraycopy(src.next_out, 0, dest.next_out, 0, src.next_out.length); + } + + dest.next_out_index = src.next_out_index; + dest.avail_out = src.avail_out; + dest.total_out = src.total_out; + + dest.msg = src.msg; + dest.data_type = src.data_type; + dest.adler = src.adler.copy(); + + try { + dest.dstate = (Deflate) src.dstate.clone(); + dest.dstate.strm = dest; + } catch (CloneNotSupportedException e) { + // + } + return Z_OK; + } + + public Object clone() throws CloneNotSupportedException { + Deflate dest = (Deflate) super.clone(); + + dest.pending_buf = dup(dest.pending_buf); + dest.d_buf = dest.d_buf; + dest.l_buf = dup(dest.l_buf); + dest.window = dup(dest.window); + + dest.prev = dup(dest.prev); + dest.head = dup(dest.head); + dest.dyn_ltree = dup(dest.dyn_ltree); + dest.dyn_dtree = dup(dest.dyn_dtree); + dest.bl_tree = dup(dest.bl_tree); + + dest.bl_count = dup(dest.bl_count); + dest.next_code = dup(dest.next_code); + dest.heap = dup(dest.heap); + dest.depth = dup(dest.depth); + + dest.l_desc.dyn_tree = dest.dyn_ltree; + dest.d_desc.dyn_tree = dest.dyn_dtree; + dest.bl_desc.dyn_tree = dest.bl_tree; + + /* + dest.l_desc.stat_desc = StaticTree.static_l_desc; + dest.d_desc.stat_desc = StaticTree.static_d_desc; + dest.bl_desc.stat_desc = StaticTree.static_bl_desc; + */ + + if (dest.gheader != null) { + dest.gheader = (GZIPHeader) dest.gheader.clone(); + } + + return dest; + } + + private byte[] dup(byte[] buf) { + byte[] foo = new byte[buf.length]; + System.arraycopy(buf, 0, foo, 0, 
foo.length); + return foo; + } + + private short[] dup(short[] buf) { + short[] foo = new short[buf.length]; + System.arraycopy(buf, 0, foo, 0, foo.length); + return foo; + } + + private int[] dup(int[] buf) { + int[] foo = new int[buf.length]; + System.arraycopy(buf, 0, foo, 0, foo.length); + return foo; + } + + synchronized GZIPHeader getGZIPHeader() { + if (gheader == null) { + gheader = new GZIPHeader(); + } + return gheader; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/Deflater.java b/netty-zlib/src/main/java/io/netty/zlib/Deflater.java new file mode 100644 index 0000000..f3aa30e --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/Deflater.java @@ -0,0 +1,142 @@ +package io.netty.zlib; + +final public class Deflater extends ZStream { + + static final private int MAX_WBITS = 15; // 32K LZ77 window + static final private int DEF_WBITS = MAX_WBITS; + + static final private int Z_NO_FLUSH = 0; + static final private int Z_PARTIAL_FLUSH = 1; + static final private int Z_SYNC_FLUSH = 2; + static final private int Z_FULL_FLUSH = 3; + static final private int Z_FINISH = 4; + + static final private int MAX_MEM_LEVEL = 9; + + static final private int Z_OK = 0; + static final private int Z_STREAM_END = 1; + static final private int Z_NEED_DICT = 2; + static final private int Z_ERRNO = -1; + static final private int Z_STREAM_ERROR = -2; + static final private int Z_DATA_ERROR = -3; + static final private int Z_MEM_ERROR = -4; + static final private int Z_BUF_ERROR = -5; + static final private int Z_VERSION_ERROR = -6; + + private boolean finished = false; + + public Deflater() { + super(); + } + + public Deflater(int level) throws GZIPException { + this(level, MAX_WBITS); + } + + public Deflater(int level, boolean nowrap) throws GZIPException { + this(level, MAX_WBITS, nowrap); + } + + public Deflater(int level, int bits) throws GZIPException { + this(level, bits, false); + } + + public Deflater(int level, int bits, boolean nowrap) throws 
GZIPException { + super(); + int ret = init(level, bits, nowrap); + if (ret != Z_OK) + throw new GZIPException(ret + ": " + msg); + } + + public Deflater(int level, int bits, int memlevel, JZlib.WrapperType wrapperType) throws GZIPException { + super(); + int ret = init(level, bits, memlevel, wrapperType); + if (ret != Z_OK) + throw new GZIPException(ret + ": " + msg); + } + + public Deflater(int level, int bits, int memlevel) throws GZIPException { + super(); + int ret = init(level, bits, memlevel); + if (ret != Z_OK) + throw new GZIPException(ret + ": " + msg); + } + + public int init(int level) { + return init(level, MAX_WBITS); + } + + public int init(int level, boolean nowrap) { + return init(level, MAX_WBITS, nowrap); + } + + public int init(int level, int bits) { + return init(level, bits, false); + } + + public int init(int level, int bits, int memlevel, JZlib.WrapperType wrapperType) { + if (bits < 9 || bits > 15) { + return Z_STREAM_ERROR; + } + if (wrapperType == JZlib.W_NONE) { + bits *= -1; + } else if (wrapperType == JZlib.W_GZIP) { + bits += 16; + } else if (wrapperType == JZlib.W_ANY) { + return Z_STREAM_ERROR; + } else if (wrapperType == JZlib.W_ZLIB) { + } + return init(level, bits, memlevel); + } + + public int init(int level, int bits, int memlevel) { + finished = false; + dstate = new Deflate(this); + return dstate.deflateInit(level, bits, memlevel); + } + + public int init(int level, int bits, boolean nowrap) { + finished = false; + dstate = new Deflate(this); + return dstate.deflateInit(level, nowrap ? 
-bits : bits); + } + + public int deflate(int flush) { + if (dstate == null) { + return Z_STREAM_ERROR; + } + int ret = dstate.deflate(flush); + if (ret == Z_STREAM_END) + finished = true; + return ret; + } + + public int end() { + finished = true; + if (dstate == null) return Z_STREAM_ERROR; + int ret = dstate.deflateEnd(); + dstate = null; + free(); + return ret; + } + + public int params(int level, int strategy) { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.deflateParams(level, strategy); + } + + public int setDictionary(byte[] dictionary, int dictLength) { + if (dstate == null) + return Z_STREAM_ERROR; + return dstate.deflateSetDictionary(dictionary, dictLength); + } + + public boolean finished() { + return finished; + } + + public int copy(Deflater src) { + this.finished = src.finished; + return Deflate.deflateCopy(this, src); + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/DeflaterOutputStream.java b/netty-zlib/src/main/java/io/netty/zlib/DeflaterOutputStream.java new file mode 100644 index 0000000..b0022c9 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/DeflaterOutputStream.java @@ -0,0 +1,151 @@ +package io.netty.zlib; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +public class DeflaterOutputStream extends FilterOutputStream { + + protected final Deflater deflater; + + protected byte[] buffer; + + private boolean closed = false; + + private boolean syncFlush = false; + + private final byte[] buf1 = new byte[1]; + + protected boolean mydeflater = false; + + private boolean close_out = true; + + protected static final int DEFAULT_BUFSIZE = 512; + + public DeflaterOutputStream(OutputStream out) throws IOException { + this(out, + new Deflater(JZlib.Z_DEFAULT_COMPRESSION), + DEFAULT_BUFSIZE, true); + mydeflater = true; + } + + public DeflaterOutputStream(OutputStream out, Deflater def) throws IOException { + this(out, def, DEFAULT_BUFSIZE, true); + } + + public 
DeflaterOutputStream(OutputStream out, + Deflater deflater, + int size) throws IOException { + this(out, deflater, size, true); + } + + public DeflaterOutputStream(OutputStream out, + Deflater deflater, + int size, + boolean close_out) throws IOException { + super(out); + if (out == null || deflater == null) { + throw new NullPointerException(); + } else if (size <= 0) { + throw new IllegalArgumentException("buffer size must be greater than 0"); + } + this.deflater = deflater; + buffer = new byte[size]; + this.close_out = close_out; + } + + public void write(int b) throws IOException { + buf1[0] = (byte) (b & 0xff); + write(buf1, 0, 1); + } + + public void write(byte[] b, int off, int len) throws IOException { + if (deflater.finished()) { + throw new IOException("finished"); + } else if (off < 0 | len < 0 | off + len > b.length) { + throw new IndexOutOfBoundsException(); + } else if (len == 0) { + } else { + int flush = syncFlush ? JZlib.Z_SYNC_FLUSH : JZlib.Z_NO_FLUSH; + deflater.setInput(b, off, len, true); + while (deflater.avail_in > 0) { + int err = deflate(flush); + if (err == JZlib.Z_STREAM_END) + break; + } + } + } + + public void finish() throws IOException { + while (!deflater.finished()) { + deflate(JZlib.Z_FINISH); + } + } + + public void close() throws IOException { + if (!closed) { + finish(); + if (mydeflater) { + deflater.end(); + } + if (close_out) + out.close(); + closed = true; + } + } + + protected int deflate(int flush) throws IOException { + deflater.setOutput(buffer, 0, buffer.length); + int err = deflater.deflate(flush); + switch (err) { + case JZlib.Z_OK: + case JZlib.Z_STREAM_END: + break; + case JZlib.Z_BUF_ERROR: + if (deflater.avail_in <= 0 && flush != JZlib.Z_FINISH) { + // flush() without any data + break; + } + default: + throw new IOException("failed to deflate: error=" + err + " avail_out=" + deflater.avail_out); + } + int len = deflater.next_out_index; + if (len > 0) { + out.write(buffer, 0, len); + } + return err; + } + + public 
void flush() throws IOException { + if (syncFlush && !deflater.finished()) { + while (true) { + int err = deflate(JZlib.Z_SYNC_FLUSH); + if (deflater.next_out_index < buffer.length) + break; + if (err == JZlib.Z_STREAM_END) + break; + } + } + out.flush(); + } + + public long getTotalIn() { + return deflater.getTotalIn(); + } + + public long getTotalOut() { + return deflater.getTotalOut(); + } + + public void setSyncFlush(boolean syncFlush) { + this.syncFlush = syncFlush; + } + + public boolean getSyncFlush() { + return this.syncFlush; + } + + public Deflater getDeflater() { + return deflater; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/GZIPException.java b/netty-zlib/src/main/java/io/netty/zlib/GZIPException.java new file mode 100644 index 0000000..f528e8e --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/GZIPException.java @@ -0,0 +1,11 @@ +package io.netty.zlib; + +public class GZIPException extends java.io.IOException { + public GZIPException() { + super(); + } + + public GZIPException(String s) { + super(s); + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/GZIPHeader.java b/netty-zlib/src/main/java/io/netty/zlib/GZIPHeader.java new file mode 100644 index 0000000..0f4b29f --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/GZIPHeader.java @@ -0,0 +1,160 @@ +package io.netty.zlib; + +import java.io.UnsupportedEncodingException; +import java.nio.charset.StandardCharsets; + +/** + * @see "http://www.ietf.org/rfc/rfc1952.txt" + */ +public class GZIPHeader implements Cloneable { + + public static final byte OS_MSDOS = (byte) 0x00; + public static final byte OS_AMIGA = (byte) 0x01; + public static final byte OS_VMS = (byte) 0x02; + public static final byte OS_UNIX = (byte) 0x03; + public static final byte OS_ATARI = (byte) 0x05; + public static final byte OS_OS2 = (byte) 0x06; + public static final byte OS_MACOS = (byte) 0x07; + public static final byte OS_TOPS20 = (byte) 0x0a; + public static final byte OS_WIN32 = (byte) 0x0b; 
+ public static final byte OS_VMCMS = (byte) 0x04; + public static final byte OS_ZSYSTEM = (byte) 0x08; + public static final byte OS_CPM = (byte) 0x09; + public static final byte OS_QDOS = (byte) 0x0c; + public static final byte OS_RISCOS = (byte) 0x0d; + public static final byte OS_UNKNOWN = (byte) 0xff; + + boolean text = false; + private final boolean fhcrc = false; + long time; + int xflags; + int os = 255; + byte[] extra; + byte[] name; + byte[] comment; + int hcrc; + long crc; + boolean done = false; + long mtime = 0; + + public void setModifiedTime(long mtime) { + this.mtime = mtime; + } + + public long getModifiedTime() { + return mtime; + } + + public void setOS(int os) { + if ((0 <= os && os <= 13) || os == 255) + this.os = os; + else + throw new IllegalArgumentException("os: " + os); + } + + public int getOS() { + return os; + } + + public void setName(String name) { + this.name = name.getBytes(StandardCharsets.ISO_8859_1); + } + + public String getName() { + if (name == null) return ""; + return new String(name, StandardCharsets.ISO_8859_1); + } + + public void setComment(String comment) { + this.comment = comment.getBytes(StandardCharsets.ISO_8859_1); + } + + public String getComment() { + if (comment == null) return ""; + return new String(comment, StandardCharsets.ISO_8859_1); + } + + public void setCRC(long crc) { + this.crc = crc; + } + + public long getCRC() { + return crc; + } + + void put(Deflate d) { + int flag = 0; + if (text) { + flag |= 1; // FTEXT + } + if (fhcrc) { + flag |= 2; // FHCRC + } + if (extra != null) { + flag |= 4; // FEXTRA + } + if (name != null) { + flag |= 8; // FNAME + } + if (comment != null) { + flag |= 16; // FCOMMENT + } + int xfl = 0; + if (d.level == JZlib.Z_BEST_SPEED) { + xfl |= 4; + } else if (d.level == JZlib.Z_BEST_COMPRESSION) { + xfl |= 2; + } + + d.put_short((short) 0x8b1f); // ID1 ID2 + d.put_byte((byte) 8); // CM(Compression Method) + d.put_byte((byte) flag); + d.put_byte((byte) mtime); + d.put_byte((byte) 
(mtime >> 8)); + d.put_byte((byte) (mtime >> 16)); + d.put_byte((byte) (mtime >> 24)); + d.put_byte((byte) xfl); + d.put_byte((byte) os); + + if (extra != null) { + d.put_byte((byte) extra.length); + d.put_byte((byte) (extra.length >> 8)); + d.put_byte(extra, 0, extra.length); + } + + if (name != null) { + d.put_byte(name, 0, name.length); + d.put_byte((byte) 0); + } + + if (comment != null) { + d.put_byte(comment, 0, comment.length); + d.put_byte((byte) 0); + } + } + + @Override + public Object clone() throws CloneNotSupportedException { + GZIPHeader gheader = (GZIPHeader) super.clone(); + byte[] tmp; + if (gheader.extra != null) { + tmp = new byte[gheader.extra.length]; + System.arraycopy(gheader.extra, 0, tmp, 0, tmp.length); + gheader.extra = tmp; + } + + if (gheader.name != null) { + tmp = new byte[gheader.name.length]; + System.arraycopy(gheader.name, 0, tmp, 0, tmp.length); + gheader.name = tmp; + } + + if (gheader.comment != null) { + tmp = new byte[gheader.comment.length]; + System.arraycopy(gheader.comment, 0, tmp, 0, tmp.length); + gheader.comment = tmp; + } + + return gheader; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/GZIPInputStream.java b/netty-zlib/src/main/java/io/netty/zlib/GZIPInputStream.java new file mode 100644 index 0000000..39f7753 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/GZIPInputStream.java @@ -0,0 +1,117 @@ +package io.netty.zlib; + +import java.io.IOException; +import java.io.InputStream; + +public class GZIPInputStream extends InflaterInputStream { + + public GZIPInputStream(InputStream in) throws IOException { + this(in, DEFAULT_BUFSIZE, true); + } + + public GZIPInputStream(InputStream in, + int size, + boolean close_in) throws IOException { + this(in, new Inflater(15 + 16), size, close_in); + myinflater = true; + } + + public GZIPInputStream(InputStream in, + Inflater inflater, + int size, + boolean close_in) throws IOException { + super(in, inflater, size, close_in); + } + + public long 
getModifiedtime() { + return inflater.istate.getGZIPHeader().getModifiedTime(); + } + + public int getOS() { + return inflater.istate.getGZIPHeader().getOS(); + } + + public String getName() { + return inflater.istate.getGZIPHeader().getName(); + } + + public String getComment() { + return inflater.istate.getGZIPHeader().getComment(); + } + + public long getCRC() throws GZIPException { + if (inflater.istate.mode != 12 /*DONE*/) + throw new GZIPException("checksum is not calculated yet."); + return inflater.istate.getGZIPHeader().getCRC(); + } + + public void readHeader() throws IOException { + + byte[] empty = "".getBytes(); + inflater.setOutput(empty, 0, 0); + inflater.setInput(empty, 0, 0, false); + + byte[] b = new byte[10]; + + int n = fill(b); + if (n != 10) { + if (n > 0) { + inflater.setInput(b, 0, n, false); + //inflater.next_in_index = n; + inflater.next_in_index = 0; + inflater.avail_in = n; + } + throw new IOException("no input"); + } + + inflater.setInput(b, 0, n, false); + + byte[] b1 = new byte[1]; + do { + if (inflater.avail_in <= 0) { + int i = in.read(b1); + if (i <= 0) + throw new IOException("no input"); + inflater.setInput(b1, 0, 1, true); + } + + int err = inflater.inflate(JZlib.Z_NO_FLUSH); + + if (err != 0/*Z_OK*/) { + int len = 2048 - inflater.next_in.length; + if (len > 0) { + byte[] tmp = new byte[len]; + n = fill(tmp); + if (n > 0) { + inflater.avail_in += inflater.next_in_index; + inflater.next_in_index = 0; + inflater.setInput(tmp, 0, n, true); + } + } + //inflater.next_in_index = inflater.next_in.length; + inflater.avail_in += inflater.next_in_index; + inflater.next_in_index = 0; + throw new IOException(inflater.msg); + } + } + while (inflater.istate.inParsingHeader()); + } + + private int fill(byte[] buf) { + int len = buf.length; + int n = 0; + do { + int i = -1; + try { + i = in.read(buf, n, buf.length - n); + } catch (IOException e) { + } + if (i == -1) { + break; + } + n += i; + } + while (n < len); + return n; + } +} \ No newline 
at end of file
diff --git a/netty-zlib/src/main/java/io/netty/zlib/GZIPOutputStream.java b/netty-zlib/src/main/java/io/netty/zlib/GZIPOutputStream.java
new file mode 100644
index 0000000..4a09d41
--- /dev/null
+++ b/netty-zlib/src/main/java/io/netty/zlib/GZIPOutputStream.java
@@ -0,0 +1,63 @@
package io.netty.zlib;

import java.io.IOException;
import java.io.OutputStream;

/**
 * Deflater stream that writes a GZIP (RFC 1952) framed stream.
 */
public class GZIPOutputStream extends DeflaterOutputStream {

    public GZIPOutputStream(OutputStream out) throws IOException {
        this(out, DEFAULT_BUFSIZE);
    }

    public GZIPOutputStream(OutputStream out, int size) throws IOException {
        this(out, size, true);
    }

    public GZIPOutputStream(OutputStream out,
                            int size,
                            boolean close_out) throws IOException {
        // windowBits 15 + 16 selects gzip (not zlib) framing with a 32K window.
        this(out,
             new Deflater(JZlib.Z_DEFAULT_COMPRESSION, 15 + 16),
             size, close_out);
        mydeflater = true;
    }

    public GZIPOutputStream(OutputStream out,
                            Deflater deflater,
                            int size,
                            boolean close_out) throws IOException {
        super(out, deflater, size, close_out);
    }

    /** Header fields may only be changed before any payload has been written. */
    private void check() throws GZIPException {
        if (deflater.dstate.status != 42 /*INIT_STATUS*/) {
            throw new GZIPException("header is already written.");
        }
    }

    /** Sets the MTIME header field; only valid before data is written. */
    public void setModifiedTime(long mtime) throws GZIPException {
        check();
        deflater.dstate.getGZIPHeader().setModifiedTime(mtime);
    }

    /** Sets the OS header field; only valid before data is written. */
    public void setOS(int os) throws GZIPException {
        check();
        deflater.dstate.getGZIPHeader().setOS(os);
    }

    /** Sets the FNAME header field; only valid before data is written. */
    public void setName(String name) throws GZIPException {
        check();
        deflater.dstate.getGZIPHeader().setName(name);
    }

    /** Sets the FCOMMENT header field; only valid before data is written. */
    public void setComment(String comment) throws GZIPException {
        check();
        deflater.dstate.getGZIPHeader().setComment(comment);
    }

    /**
     * @return the CRC32 of the uncompressed data
     * @throws GZIPException unless {@code finish()} has completed compression
     */
    public long getCRC() throws GZIPException {
        if (deflater.dstate.status != 666 /*FINISH_STATE*/) {
            throw new GZIPException("checksum is not calculated yet.");
        }
        return deflater.dstate.getGZIPHeader().getCRC();
    }
}
diff --git a/netty-zlib/src/main/java/io/netty/zlib/InfBlocks.java
b/netty-zlib/src/main/java/io/netty/zlib/InfBlocks.java new file mode 100644 index 0000000..b24e8f9 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/InfBlocks.java @@ -0,0 +1,666 @@ +package io.netty.zlib; + +final class InfBlocks { + static final private int MANY = 1440; + + // And'ing with mask[n] masks the lower n bits + static final private int[] inflate_mask = { + 0x00000000, 0x00000001, 0x00000003, 0x00000007, 0x0000000f, + 0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff, 0x000001ff, + 0x000003ff, 0x000007ff, 0x00000fff, 0x00001fff, 0x00003fff, + 0x00007fff, 0x0000ffff + }; + + // Table for deflate from PKZIP's appnote.txt. + static final int[] border = { // Order of the bit length code lengths + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15 + }; + + static final private int Z_OK = 0; + static final private int Z_STREAM_END = 1; + static final private int Z_NEED_DICT = 2; + static final private int Z_ERRNO = -1; + static final private int Z_STREAM_ERROR = -2; + static final private int Z_DATA_ERROR = -3; + static final private int Z_MEM_ERROR = -4; + static final private int Z_BUF_ERROR = -5; + static final private int Z_VERSION_ERROR = -6; + + static final private int TYPE = 0; // get type bits (3, including end bit) + static final private int LENS = 1; // get lengths for stored + static final private int STORED = 2;// processing stored block + static final private int TABLE = 3; // get table lengths + static final private int BTREE = 4; // get bit lengths tree for a dynamic block + static final private int DTREE = 5; // get length, distance trees for a dynamic block + static final private int CODES = 6; // processing fixed or dynamic block + static final private int DRY = 7; // output remaining window bytes + static final private int DONE = 8; // finished last block, done + static final private int BAD = 9; // ot a data error--stuck here + + int mode; // current inflate_block mode + + int left; // if STORED, bytes left to copy + 
+ int table; // table lengths (14 bits) + int index; // index into blens (or border) + int[] blens; // bit lengths of codes + int[] bb = new int[1]; // bit length tree depth + int[] tb = new int[1]; // bit length decoding tree + + int[] bl = new int[1]; + int[] bd = new int[1]; + + int[][] tl = new int[1][]; + int[][] td = new int[1][]; + int[] tli = new int[1]; // tl_index + int[] tdi = new int[1]; // td_index + + private final InfCodes codes; // if CODES, current state + + int last; // true if this block is the last block + + // mode independent information + int bitk; // bits in bit buffer + int bitb; // bit buffer + int[] hufts; // single malloc for tree space + byte[] window; // sliding window + int end; // one byte after sliding window + int read; // window read pointer + int write; // window write pointer + private final boolean check; + + private final InfTree inftree = new InfTree(); + + private final ZStream z; + + InfBlocks(ZStream z, int w) { + this.z = z; + this.codes = new InfCodes(this.z, this); + hufts = new int[MANY * 3]; + window = new byte[w]; + end = w; + this.check = z.istate.wrap != 0; + mode = TYPE; + reset(); + } + + void reset() { + if (mode == BTREE || mode == DTREE) { + } + if (mode == CODES) { + codes.free(z); + } + mode = TYPE; + bitk = 0; + bitb = 0; + read = write = 0; + if (check) { + z.adler.reset(); + } + } + + int proc(int r) { + int t; // temporary storage + int b; // bit buffer + int k; // bits in bit buffer + int p; // input data pointer + int n; // bytes available there + int q; // output window write pointer + int m; // bytes to end of window or read pointer + + // copy input/output information to locals (UPDATE macro restores) + { + p = z.next_in_index; + n = z.avail_in; + b = bitb; + k = bitk; + } + { + q = write; + m = q < read ? 
read - q - 1 : end - q; + } + + // process input based on current state + while (true) { + switch (mode) { + case TYPE: + + while (k < (3)) { + if (n != 0) { + r = Z_OK; + } else { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + t = b & 7; + last = t & 1; + + switch (t >>> 1) { + case 0: // stored + { + b >>>= (3); + k -= (3); + } + t = k & 7; // go to byte boundary + + { + b >>>= (t); + k -= (t); + } + mode = LENS; // get length of stored block + break; + case 1: // fixed + InfTree.inflate_trees_fixed(bl, bd, tl, td, z); + codes.init(bl[0], bd[0], tl[0], 0, td[0], 0); + + { + b >>>= (3); + k -= (3); + } + + mode = CODES; + break; + case 2: // dynamic + + { + b >>>= (3); + k -= (3); + } + + mode = TABLE; + break; + case 3: // illegal + + { + b >>>= (3); + k -= (3); + } + mode = BAD; + z.msg = "invalid block type"; + r = Z_DATA_ERROR; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + break; + case LENS: + + while (k < (32)) { + if (n != 0) { + r = Z_OK; + } else { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + if ((((~b) >>> 16) & 0xffff) != (b & 0xffff)) { + mode = BAD; + z.msg = "invalid stored block lengths"; + r = Z_DATA_ERROR; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + left = (b & 0xffff); + b = k = 0; // dump bits + mode = left != 0 ? STORED : (last != 0 ? 
DRY : TYPE); + break; + case STORED: + if (n == 0) { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + + if (m == 0) { + if (q == end && read != 0) { + q = 0; + m = q < read ? read - q - 1 : end - q; + } + if (m == 0) { + write = q; + r = inflate_flush(r); + q = write; + m = q < read ? read - q - 1 : end - q; + if (q == end && read != 0) { + q = 0; + m = q < read ? read - q - 1 : end - q; + } + if (m == 0) { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + } + } + r = Z_OK; + + t = left; + if (t > n) t = n; + if (t > m) t = m; + System.arraycopy(z.next_in, p, window, q, t); + p += t; + n -= t; + q += t; + m -= t; + if ((left -= t) != 0) + break; + mode = last != 0 ? DRY : TYPE; + break; + case TABLE: + + while (k < (14)) { + if (n != 0) { + r = Z_OK; + } else { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + table = t = (b & 0x3fff); + if ((t & 0x1f) > 29 || ((t >> 5) & 0x1f) > 29) { + mode = BAD; + z.msg = "too many length or distance symbols"; + r = Z_DATA_ERROR; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + t = 258 + (t & 0x1f) + ((t >> 5) & 0x1f); + if (blens == null || blens.length < t) { + blens = new int[t]; + } else { + for (int i = 0; i < t; i++) { + blens[i] = 0; + } + } + + { + b >>>= (14); + k -= (14); + } + + index = 0; + mode = BTREE; + case BTREE: + while (index < 4 + (table >>> 10)) { + while (k < (3)) { + if (n != 0) { + r = Z_OK; + } else { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + n--; + b 
|= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + blens[border[index++]] = b & 7; + + { + b >>>= (3); + k -= (3); + } + } + + while (index < 19) { + blens[border[index++]] = 0; + } + + bb[0] = 7; + t = inftree.inflate_trees_bits(blens, bb, tb, hufts, z); + if (t != Z_OK) { + r = t; + if (r == Z_DATA_ERROR) { + blens = null; + mode = BAD; + } + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + + index = 0; + mode = DTREE; + case DTREE: + while (true) { + t = table; + if (!(index < 258 + (t & 0x1f) + ((t >> 5) & 0x1f))) { + break; + } + + int[] h; + int i, j, c; + + t = bb[0]; + + while (k < (t)) { + if (n != 0) { + r = Z_OK; + } else { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + if (tb[0] == -1) { + //System.err.println("null..."); + } + + t = hufts[(tb[0] + (b & inflate_mask[t])) * 3 + 1]; + c = hufts[(tb[0] + (b & inflate_mask[t])) * 3 + 2]; + + if (c < 16) { + b >>>= (t); + k -= (t); + blens[index++] = c; + } else { // c == 16..18 + i = c == 18 ? 7 : c - 14; + j = c == 18 ? 11 : 3; + + while (k < (t + i)) { + if (n != 0) { + r = Z_OK; + } else { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + b >>>= (t); + k -= (t); + + j += (b & inflate_mask[i]); + + b >>>= (i); + k -= (i); + + i = index; + t = table; + if (i + j > 258 + (t & 0x1f) + ((t >> 5) & 0x1f) || + (c == 16 && i < 1)) { + blens = null; + mode = BAD; + z.msg = "invalid bit length repeat"; + r = Z_DATA_ERROR; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + + c = c == 16 ? 
blens[i - 1] : 0; + do { + blens[i++] = c; + } + while (--j != 0); + index = i; + } + } + + tb[0] = -1; + { + bl[0] = 9; // must be <= 9 for lookahead assumptions + bd[0] = 6; // must be <= 9 for lookahead assumptions + t = table; + t = inftree.inflate_trees_dynamic(257 + (t & 0x1f), + 1 + ((t >> 5) & 0x1f), + blens, bl, bd, tli, tdi, hufts, z); + + if (t != Z_OK) { + if (t == Z_DATA_ERROR) { + blens = null; + mode = BAD; + } + r = t; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + codes.init(bl[0], bd[0], hufts, tli[0], hufts, tdi[0]); + } + mode = CODES; + case CODES: + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + + if ((r = codes.proc(r)) != Z_STREAM_END) { + return inflate_flush(r); + } + r = Z_OK; + codes.free(z); + + p = z.next_in_index; + n = z.avail_in; + b = bitb; + k = bitk; + q = write; + m = q < read ? read - q - 1 : end - q; + + if (last == 0) { + mode = TYPE; + break; + } + mode = DRY; + case DRY: + write = q; + r = inflate_flush(r); + q = write; + m = q < read ? 
read - q - 1 : end - q; + if (read != write) { + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + mode = DONE; + case DONE: + r = Z_STREAM_END; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + case BAD: + r = Z_DATA_ERROR; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + + default: + r = Z_STREAM_ERROR; + + bitb = b; + bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + write = q; + return inflate_flush(r); + } + } + } + + void free() { + reset(); + window = null; + hufts = null; + //ZFREE(z, s); + } + + void set_dictionary(byte[] d, int start, int n) { + System.arraycopy(d, start, window, 0, n); + read = write = n; + } + + // Returns true if inflate is currently at the end of a block generated + // by Z_SYNC_FLUSH or Z_FULL_FLUSH. + int sync_point() { + return mode == LENS ? 1 : 0; + } + + // copy as much as possible from the sliding window to the output area + int inflate_flush(int r) { + int n; + int p; + int q; + + // local copies of source and destination pointers + p = z.next_out_index; + q = read; + + // compute number of bytes to copy as far as end of window + n = (q <= write ? 
write : end) - q; + if (n > z.avail_out) n = z.avail_out; + if (n != 0 && r == Z_BUF_ERROR) r = Z_OK; + + // update counters + z.avail_out -= n; + z.total_out += n; + + // update check information + if (check && n > 0) { + z.adler.update(window, q, n); + } + + // copy as far as end of window + System.arraycopy(window, q, z.next_out, p, n); + p += n; + q += n; + + // see if more to copy at beginning of window + if (q == end) { + // wrap pointers + q = 0; + if (write == end) + write = 0; + + // compute bytes to copy + n = write - q; + if (n > z.avail_out) n = z.avail_out; + if (n != 0 && r == Z_BUF_ERROR) r = Z_OK; + + // update counters + z.avail_out -= n; + z.total_out += n; + + // update check information + if (check && n > 0) { + z.adler.update(window, q, n); + } + + // copy + System.arraycopy(window, q, z.next_out, p, n); + p += n; + q += n; + } + + // update pointers + z.next_out_index = p; + read = q; + + // done + return r; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/InfCodes.java b/netty-zlib/src/main/java/io/netty/zlib/InfCodes.java new file mode 100644 index 0000000..53ca04f --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/InfCodes.java @@ -0,0 +1,690 @@ +package io.netty.zlib; + +final class InfCodes { + + static final private int[] inflate_mask = { + 0x00000000, 0x00000001, 0x00000003, 0x00000007, 0x0000000f, + 0x0000001f, 0x0000003f, 0x0000007f, 0x000000ff, 0x000001ff, + 0x000003ff, 0x000007ff, 0x00000fff, 0x00001fff, 0x00003fff, + 0x00007fff, 0x0000ffff + }; + + static final private int Z_OK = 0; + static final private int Z_STREAM_END = 1; + static final private int Z_NEED_DICT = 2; + static final private int Z_ERRNO = -1; + static final private int Z_STREAM_ERROR = -2; + static final private int Z_DATA_ERROR = -3; + static final private int Z_MEM_ERROR = -4; + static final private int Z_BUF_ERROR = -5; + static final private int Z_VERSION_ERROR = -6; + + // waiting for "i:"=input, + // "o:"=output, + // "x:"=nothing + static 
final private int START = 0; // x: set up for LEN + static final private int LEN = 1; // i: get length/literal/eob next + static final private int LENEXT = 2; // i: getting length extra (have base) + static final private int DIST = 3; // i: get distance next + static final private int DISTEXT = 4;// i: getting distance extra + static final private int COPY = 5; // o: copying bytes in window, waiting for space + static final private int LIT = 6; // o: got literal, waiting for output space + static final private int WASH = 7; // o: got eob, possibly still output waiting + static final private int END = 8; // x: got eob and all data flushed + static final private int BADCODE = 9;// x: got error + + int mode; // current inflate_codes mode + + // mode dependent information + int len; + + int[] tree; // pointer into tree + int tree_index = 0; + int need; // bits needed + + int lit; + + // if EXT or COPY, where and how much + int get; // bits to get for extra + int dist; // distance back to copy from + + byte lbits; // ltree bits decoded per branch + byte dbits; // dtree bits decoder per branch + int[] ltree; // literal/length/eob tree + int ltree_index; // literal/length/eob tree + int[] dtree; // distance tree + int dtree_index; // distance tree + + private final ZStream z; + private final InfBlocks s; + + InfCodes(ZStream z, InfBlocks s) { + this.z = z; + this.s = s; + } + + void init(int bl, int bd, + int[] tl, int tl_index, + int[] td, int td_index) { + mode = START; + lbits = (byte) bl; + dbits = (byte) bd; + ltree = tl; + ltree_index = tl_index; + dtree = td; + dtree_index = td_index; + tree = null; + } + + int proc(int r) { + int j; // temporary storage + int[] t; // temporary pointer + int tindex; // temporary pointer + int e; // extra bits or operation + int b = 0; // bit buffer + int k = 0; // bits in bit buffer + int p = 0; // input data pointer + int n; // bytes available there + int q; // output window write pointer + int m; // bytes to end of window or read 
pointer + int f; // pointer to copy strings from + + // copy input/output information to locals (UPDATE macro restores) + p = z.next_in_index; + n = z.avail_in; + b = s.bitb; + k = s.bitk; + q = s.write; + m = q < s.read ? s.read - q - 1 : s.end - q; + + // process input and output based on current state + while (true) { + switch (mode) { + // waiting for "i:"=input, "o:"=output, "x:"=nothing + case START: // x: set up for LEN + if (m >= 258 && n >= 10) { + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + r = inflate_fast(lbits, dbits, + ltree, ltree_index, + dtree, dtree_index, + s, z); + + p = z.next_in_index; + n = z.avail_in; + b = s.bitb; + k = s.bitk; + q = s.write; + m = q < s.read ? s.read - q - 1 : s.end - q; + + if (r != Z_OK) { + mode = r == Z_STREAM_END ? WASH : BADCODE; + break; + } + } + need = lbits; + tree = ltree; + tree_index = ltree_index; + + mode = LEN; + case LEN: // i: get length/literal/eob next + j = need; + + while (k < (j)) { + if (n != 0) r = Z_OK; + else { + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + tindex = (tree_index + (b & inflate_mask[j])) * 3; + + b >>>= (tree[tindex + 1]); + k -= (tree[tindex + 1]); + + e = tree[tindex]; + + if (e == 0) { // literal + lit = tree[tindex + 2]; + mode = LIT; + break; + } + if ((e & 16) != 0) { // length + get = e & 15; + len = tree[tindex + 2]; + mode = LENEXT; + break; + } + if ((e & 64) == 0) { // next table + need = e; + tree_index = tindex / 3 + tree[tindex + 2]; + break; + } + if ((e & 32) != 0) { // end of block + mode = WASH; + break; + } + mode = BADCODE; // invalid code + z.msg = "invalid literal/length code"; + r = Z_DATA_ERROR; + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = 
q; + return s.inflate_flush(r); + + case LENEXT: // i: getting length extra (have base) + j = get; + + while (k < (j)) { + if (n != 0) r = Z_OK; + else { + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + len += (b & inflate_mask[j]); + + b >>= j; + k -= j; + + need = dbits; + tree = dtree; + tree_index = dtree_index; + mode = DIST; + case DIST: // i: get distance next + j = need; + + while (k < (j)) { + if (n != 0) r = Z_OK; + else { + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + tindex = (tree_index + (b & inflate_mask[j])) * 3; + + b >>= tree[tindex + 1]; + k -= tree[tindex + 1]; + + e = (tree[tindex]); + if ((e & 16) != 0) { // distance + get = e & 15; + dist = tree[tindex + 2]; + mode = DISTEXT; + break; + } + if ((e & 64) == 0) { // next table + need = e; + tree_index = tindex / 3 + tree[tindex + 2]; + break; + } + mode = BADCODE; // invalid code + z.msg = "invalid distance code"; + r = Z_DATA_ERROR; + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + + case DISTEXT: // i: getting distance extra + j = get; + + while (k < (j)) { + if (n != 0) r = Z_OK; + else { + + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + } + n--; + b |= (z.next_in[p++] & 0xff) << k; + k += 8; + } + + dist += (b & inflate_mask[j]); + + b >>= j; + k -= j; + + mode = COPY; + case COPY: // o: copying bytes in window, waiting for space + f = q - dist; + while (f < 0) { // modulo window size-"while" instead + f += s.end; // of "if" handles invalid distances + 
} + while (len != 0) { + + if (m == 0) { + if (q == s.end && s.read != 0) { + q = 0; + m = q < s.read ? s.read - q - 1 : s.end - q; + } + if (m == 0) { + s.write = q; + r = s.inflate_flush(r); + q = s.write; + m = q < s.read ? s.read - q - 1 : s.end - q; + + if (q == s.end && s.read != 0) { + q = 0; + m = q < s.read ? s.read - q - 1 : s.end - q; + } + + if (m == 0) { + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + } + } + } + + s.window[q++] = s.window[f++]; + m--; + + if (f == s.end) + f = 0; + len--; + } + mode = START; + break; + case LIT: // o: got literal, waiting for output space + if (m == 0) { + if (q == s.end && s.read != 0) { + q = 0; + m = q < s.read ? s.read - q - 1 : s.end - q; + } + if (m == 0) { + s.write = q; + r = s.inflate_flush(r); + q = s.write; + m = q < s.read ? s.read - q - 1 : s.end - q; + + if (q == s.end && s.read != 0) { + q = 0; + m = q < s.read ? s.read - q - 1 : s.end - q; + } + if (m == 0) { + s.bitb = b; + s.bitk = k; + z.avail_in = n; + z.total_in += p - z.next_in_index; + z.next_in_index = p; + s.write = q; + return s.inflate_flush(r); + } + } + } + r = Z_OK; + + s.window[q++] = (byte) lit; + m--; + + mode = START; + break; + case WASH: // o: got eob, possibly more output + if (k > 7) { // return unused byte, if any + k -= 8; + n++; + p--; // can always return one + } + + s.write = q; + r = s.inflate_flush(r); + q = s.write; + m = q < s.read ? 
s.read - q - 1 : s.end - q;

                    // WASH: flush whatever remains in the window before signalling end.
                    if (s.read != s.write) {
                        s.bitb = b;
                        s.bitk = k;
                        z.avail_in = n;
                        z.total_in += p - z.next_in_index;
                        z.next_in_index = p;
                        s.write = q;
                        return s.inflate_flush(r);
                    }
                    mode = END;
                    // fall through: window fully drained, report stream end
                case END:
                    r = Z_STREAM_END;
                    // save bit buffer / input state back into the shared structures
                    s.bitb = b;
                    s.bitk = k;
                    z.avail_in = n;
                    z.total_in += p - z.next_in_index;
                    z.next_in_index = p;
                    s.write = q;
                    return s.inflate_flush(r);

                case BADCODE: // x: got error
                    r = Z_DATA_ERROR;
                    s.bitb = b;
                    s.bitk = k;
                    z.avail_in = n;
                    z.total_in += p - z.next_in_index;
                    z.next_in_index = p;
                    s.write = q;
                    return s.inflate_flush(r);

                default:
                    // unknown mode value: internal state corruption
                    r = Z_STREAM_ERROR;
                    s.bitb = b;
                    s.bitk = k;
                    z.avail_in = n;
                    z.total_in += p - z.next_in_index;
                    z.next_in_index = p;
                    s.write = q;
                    return s.inflate_flush(r);
            }
        }
    }

    /**
     * Frees resources held by this decoder. A no-op in the Java port: the C
     * original released the code table here (ZFREE), but in Java the arrays
     * are garbage collected.
     */
    void free(ZStream z) {
        // ZFREE(z, c);
    }

    // Called with number of bytes left to write in window at least 258
    // (the maximum string length) and number of input bytes available
    // at least ten. The ten bytes are six bytes for the longest length/
    // distance pair plus four bytes for overloading the bit buffer.

    /**
     * Fast path of the inflate inner loop (port of zlib's inffast.c).
     * Decodes literal/length and distance codes straight from the bit buffer
     * while the caller guarantees >= 258 bytes of window space and >= 10
     * bytes of input, avoiding the per-state bookkeeping of the slow path.
     *
     * @param bl bits in the root literal/length table lookup
     * @param bd bits in the root distance table lookup
     * @param tl literal/length decode table (triples: op, bits, value)
     * @param tl_index offset of the literal/length table inside {@code tl}
     * @param td distance decode table (same triple layout)
     * @param td_index offset of the distance table inside {@code td}
     * @param s shared block state (window, bit buffer)
     * @param z stream holding input buffer and counters
     * @return Z_OK when input/output ran low, Z_STREAM_END on end-of-block,
     *         Z_DATA_ERROR on an invalid code
     */
    int inflate_fast(int bl, int bd,
                     int[] tl, int tl_index,
                     int[] td, int td_index,
                     InfBlocks s, ZStream z) {
        int t;             // temporary pointer
        int[] tp;          // temporary pointer
        int tp_index;      // temporary pointer
        int e;             // extra bits or operation
        int b;             // bit buffer
        int k;             // bits in bit buffer
        int p;             // input data pointer
        int n;             // bytes available there
        int q;             // output window write pointer
        int m;             // bytes to end of window or read pointer
        int ml;            // mask for literal/length tree
        int md;            // mask for distance tree
        int c;             // bytes to copy
        int d;             // distance back to copy from
        int r;             // copy source pointer

        int tp_index_t_3;  // (tp_index+t)*3

        // load input, output, bit values into locals for speed
        p = z.next_in_index;
        n = z.avail_in;
        b = s.bitb;
        k = s.bitk;
        q = s.write;
        m = q < s.read ? s.read - q - 1 : s.end - q;

        // initialize masks
        ml = inflate_mask[bl];
        md = inflate_mask[bd];

        // do until not enough input or output space for fast loop
        do { // assume called with m >= 258 && n >= 10
            // get literal/length code
            while (k < (20)) { // max bits for literal/length code
                n--;
                b |= (z.next_in[p++] & 0xff) << k;
                k += 8;
            }

            t = b & ml;
            tp = tl;
            tp_index = tl_index;
            tp_index_t_3 = (tp_index + t) * 3;
            if ((e = tp[tp_index_t_3]) == 0) {
                // op == 0: direct literal, emit and continue the outer loop
                b >>= (tp[tp_index_t_3 + 1]);
                k -= (tp[tp_index_t_3 + 1]);

                s.window[q++] = (byte) tp[tp_index_t_3 + 2];
                m--;
                continue;
            }
            do {
                // consume the bits of the current table entry
                b >>= (tp[tp_index_t_3 + 1]);
                k -= (tp[tp_index_t_3 + 1]);

                if ((e & 16) != 0) {
                    // length base + extra bits
                    e &= 15;
                    c = tp[tp_index_t_3 + 2] + (b & inflate_mask[e]);

                    b >>= e;
                    k -= e;

                    // decode distance base of block to copy
                    while (k < (15)) { // max bits for distance code
                        n--;
                        b |= (z.next_in[p++] & 0xff) << k;
                        k += 8;
                    }

                    t = b & md;
                    tp = td;
                    tp_index = td_index;
                    tp_index_t_3 = (tp_index + t) * 3;
                    e = tp[tp_index_t_3];

                    do {
                        b >>= (tp[tp_index_t_3 + 1]);
                        k -= (tp[tp_index_t_3 + 1]);

                        if ((e & 16) != 0) {
                            // get extra bits to add to distance base
                            e &= 15;
                            while (k < (e)) { // get extra bits (up to 13)
                                n--;
                                b |= (z.next_in[p++] & 0xff) << k;
                                k += 8;
                            }

                            d = tp[tp_index_t_3 + 2] + (b & inflate_mask[e]);

                            b >>= (e);
                            k -= (e);

                            // do the copy
                            m -= c;
                            if (q >= d) { // offset before dest
                                // just copy
                                r = q - d;
                                if (q - r > 0 && 2 > (q - r)) {
                                    s.window[q++] = s.window[r++]; // minimum count is three,
                                    s.window[q++] = s.window[r++]; // so unroll loop a little
                                    c -= 2;
                                } else {
                                    System.arraycopy(s.window, r, s.window, q, 2);
                                    q += 2;
                                    r += 2;
                                    c -= 2;
                                }
                            } else { // else offset after destination
                                r = q - d;
                                do {
                                    r += s.end; // force pointer in window
                                } while (r < 0); // covers invalid distances
                                e = s.end - r;
                                if (c > e) { // if source crosses,
                                    c -= e; // wrapped copy
                                    if (q - r > 0 && e > (q - r)) {
                                        // overlapping region: must copy byte by byte
                                        do {
                                            s.window[q++] = s.window[r++];
                                        }
                                        while (--e != 0);
                                    } else {
                                        System.arraycopy(s.window, r, s.window, q, e);
                                        q += e;
                                        r += e;
                                        e = 0;
                                    }
                                    r = 0; // copy rest from start of window
                                }

                            }

                            // copy all or what's left
                            if (q - r > 0 && c > (q - r)) {
                                do {
                                    s.window[q++] = s.window[r++];
                                }
                                while (--c != 0);
                            } else {
                                System.arraycopy(s.window, r, s.window, q, c);
                                q += c;
                                r += c;
                                c = 0;
                            }
                            break;
                        } else if ((e & 64) == 0) {
                            // second-level distance table lookup
                            t += tp[tp_index_t_3 + 2];
                            t += (b & inflate_mask[e]);
                            tp_index_t_3 = (tp_index + t) * 3;
                            e = tp[tp_index_t_3];
                        } else {
                            z.msg = "invalid distance code";

                            // return unused input bytes (k>>3 whole bytes in bit buffer)
                            c = z.avail_in - n;
                            c = (k >> 3) < c ? k >> 3 : c;
                            n += c;
                            p -= c;
                            k -= c << 3;

                            s.bitb = b;
                            s.bitk = k;
                            z.avail_in = n;
                            z.total_in += p - z.next_in_index;
                            z.next_in_index = p;
                            s.write = q;

                            return Z_DATA_ERROR;
                        }
                    }
                    while (true);
                    break;
                }

                if ((e & 64) == 0) {
                    // second-level literal/length table lookup
                    t += tp[tp_index_t_3 + 2];
                    t += (b & inflate_mask[e]);
                    tp_index_t_3 = (tp_index + t) * 3;
                    if ((e = tp[tp_index_t_3]) == 0) {
                        b >>= (tp[tp_index_t_3 + 1]);
                        k -= (tp[tp_index_t_3 + 1]);

                        s.window[q++] = (byte) tp[tp_index_t_3 + 2];
                        m--;
                        break;
                    }
                } else if ((e & 32) != 0) {
                    // end-of-block code: rewind unused input and report stream end
                    c = z.avail_in - n;
                    c = (k >> 3) < c ? k >> 3 : c;
                    n += c;
                    p -= c;
                    k -= c << 3;

                    s.bitb = b;
                    s.bitk = k;
                    z.avail_in = n;
                    z.total_in += p - z.next_in_index;
                    z.next_in_index = p;
                    s.write = q;

                    return Z_STREAM_END;
                } else {
                    z.msg = "invalid literal/length code";

                    c = z.avail_in - n;
                    c = (k >> 3) < c ? k >> 3 : c;
                    n += c;
                    p -= c;
                    k -= c << 3;

                    s.bitb = b;
                    s.bitk = k;
                    z.avail_in = n;
                    z.total_in += p - z.next_in_index;
                    z.next_in_index = p;
                    s.write = q;

                    return Z_DATA_ERROR;
                }
            }
            while (true);
        }
        while (m >= 258 && n >= 10);

        // not enough input or output--restore pointers and return
        c = z.avail_in - n;
        c = (k >> 3) < c ? k >> 3 : c;
        n += c;
        p -= c;
        k -= c << 3;

        s.bitb = b;
        s.bitk = k;
        z.avail_in = n;
        z.total_in += p - z.next_in_index;
        z.next_in_index = p;
        s.write = q;

        return Z_OK;
    }
}
diff --git a/netty-zlib/src/main/java/io/netty/zlib/InfTree.java b/netty-zlib/src/main/java/io/netty/zlib/InfTree.java
new file mode 100644
index 0000000..6992676
--- /dev/null
+++ b/netty-zlib/src/main/java/io/netty/zlib/InfTree.java
@@ -0,0 +1,490 @@
package io.netty.zlib;

/**
 * Builds the Huffman decoding tables used by inflate (port of zlib's
 * inftrees.c). Tables are stored as flat int arrays of (op, bits, value)
 * triples; MANY bounds the total number of triples per tree.
 */
final class InfTree {

    static final private int MANY = 1440;

    // zlib return codes (duplicated per-class in this port)
    static final private int Z_OK = 0;
    static final private int Z_STREAM_END = 1;
    static final private int Z_NEED_DICT = 2;
    static final private int Z_ERRNO = -1;
    static final private int Z_STREAM_ERROR = -2;
    static final private int Z_DATA_ERROR = -3;
    static final private int Z_MEM_ERROR = -4;
    static final private int Z_BUF_ERROR = -5;
    static final private int Z_VERSION_ERROR = -6;

    // root table bits for the fixed (static) literal/length and distance trees
    static final int fixed_bl = 9;
    static final int fixed_bd = 5;

    // Precomputed fixed literal/length decode table, (op, bits, value) triples.
    static final int[] fixed_tl = {
            96, 7, 256, 0, 8, 80, 0, 8, 16, 84, 8, 115,
            82, 7, 31, 0, 8, 112, 0, 8, 48, 0, 9, 192,
            80, 7, 10, 0, 8, 96, 0, 8, 32, 0, 9, 160,
            0, 8, 0, 0, 8, 128, 0, 8, 64, 0, 9, 224,
            80, 7, 6, 0, 8, 88, 0, 8, 24, 0, 9, 144,
            83, 7, 59, 0, 8, 120, 0, 8, 56, 0, 9, 208,
            81, 7, 17, 0, 8, 104, 0, 8, 40, 0, 9, 176,
            0, 8, 8, 0, 8, 136, 0, 8, 72, 0, 9, 240,
            80, 7, 4, 0, 8, 84, 0, 8, 20, 85, 8, 227,
            83, 7, 43, 0, 8, 116, 0, 8, 52, 0, 9, 200,
            81, 7, 13, 0, 8, 100, 0, 8, 36, 0, 9, 168,
            0, 8, 4, 0, 8, 132, 0, 8, 68, 0, 9, 232,
            80, 7, 8, 0, 8, 92, 0, 8, 28, 0, 9, 152,
            84, 7, 83, 0, 8, 124, 0, 8, 60, 0, 9, 216,
            82, 7, 23, 0, 8, 108, 0, 8, 44, 0, 9, 184,
            0, 8, 12, 0, 8, 140, 0, 8, 76, 0, 9, 248,
            80, 7, 3, 0, 8, 82, 0, 8, 18, 85, 8, 163,
            83, 7, 35, 0, 8, 114, 0, 8, 50, 0, 9, 196,
            81, 7, 11, 0, 8, 98, 0, 8, 34, 0, 9, 164,
            0, 8, 2, 0, 8, 130, 0, 8, 66, 0, 9, 228,
            80, 7, 7, 0, 8, 90, 0, 8, 26, 0, 9, 148,
            84, 7, 67, 0, 8, 122, 0,
8, 58, 0, 9, 212, + 82, 7, 19, 0, 8, 106, 0, 8, 42, 0, 9, 180, + 0, 8, 10, 0, 8, 138, 0, 8, 74, 0, 9, 244, + 80, 7, 5, 0, 8, 86, 0, 8, 22, 192, 8, 0, + 83, 7, 51, 0, 8, 118, 0, 8, 54, 0, 9, 204, + 81, 7, 15, 0, 8, 102, 0, 8, 38, 0, 9, 172, + 0, 8, 6, 0, 8, 134, 0, 8, 70, 0, 9, 236, + 80, 7, 9, 0, 8, 94, 0, 8, 30, 0, 9, 156, + 84, 7, 99, 0, 8, 126, 0, 8, 62, 0, 9, 220, + 82, 7, 27, 0, 8, 110, 0, 8, 46, 0, 9, 188, + 0, 8, 14, 0, 8, 142, 0, 8, 78, 0, 9, 252, + 96, 7, 256, 0, 8, 81, 0, 8, 17, 85, 8, 131, + 82, 7, 31, 0, 8, 113, 0, 8, 49, 0, 9, 194, + 80, 7, 10, 0, 8, 97, 0, 8, 33, 0, 9, 162, + 0, 8, 1, 0, 8, 129, 0, 8, 65, 0, 9, 226, + 80, 7, 6, 0, 8, 89, 0, 8, 25, 0, 9, 146, + 83, 7, 59, 0, 8, 121, 0, 8, 57, 0, 9, 210, + 81, 7, 17, 0, 8, 105, 0, 8, 41, 0, 9, 178, + 0, 8, 9, 0, 8, 137, 0, 8, 73, 0, 9, 242, + 80, 7, 4, 0, 8, 85, 0, 8, 21, 80, 8, 258, + 83, 7, 43, 0, 8, 117, 0, 8, 53, 0, 9, 202, + 81, 7, 13, 0, 8, 101, 0, 8, 37, 0, 9, 170, + 0, 8, 5, 0, 8, 133, 0, 8, 69, 0, 9, 234, + 80, 7, 8, 0, 8, 93, 0, 8, 29, 0, 9, 154, + 84, 7, 83, 0, 8, 125, 0, 8, 61, 0, 9, 218, + 82, 7, 23, 0, 8, 109, 0, 8, 45, 0, 9, 186, + 0, 8, 13, 0, 8, 141, 0, 8, 77, 0, 9, 250, + 80, 7, 3, 0, 8, 83, 0, 8, 19, 85, 8, 195, + 83, 7, 35, 0, 8, 115, 0, 8, 51, 0, 9, 198, + 81, 7, 11, 0, 8, 99, 0, 8, 35, 0, 9, 166, + 0, 8, 3, 0, 8, 131, 0, 8, 67, 0, 9, 230, + 80, 7, 7, 0, 8, 91, 0, 8, 27, 0, 9, 150, + 84, 7, 67, 0, 8, 123, 0, 8, 59, 0, 9, 214, + 82, 7, 19, 0, 8, 107, 0, 8, 43, 0, 9, 182, + 0, 8, 11, 0, 8, 139, 0, 8, 75, 0, 9, 246, + 80, 7, 5, 0, 8, 87, 0, 8, 23, 192, 8, 0, + 83, 7, 51, 0, 8, 119, 0, 8, 55, 0, 9, 206, + 81, 7, 15, 0, 8, 103, 0, 8, 39, 0, 9, 174, + 0, 8, 7, 0, 8, 135, 0, 8, 71, 0, 9, 238, + 80, 7, 9, 0, 8, 95, 0, 8, 31, 0, 9, 158, + 84, 7, 99, 0, 8, 127, 0, 8, 63, 0, 9, 222, + 82, 7, 27, 0, 8, 111, 0, 8, 47, 0, 9, 190, + 0, 8, 15, 0, 8, 143, 0, 8, 79, 0, 9, 254, + 96, 7, 256, 0, 8, 80, 0, 8, 16, 84, 8, 115, + 82, 7, 31, 0, 8, 112, 0, 8, 48, 0, 9, 193, + + 80, 7, 10, 0, 8, 96, 0, 8, 32, 
0, 9, 161, + 0, 8, 0, 0, 8, 128, 0, 8, 64, 0, 9, 225, + 80, 7, 6, 0, 8, 88, 0, 8, 24, 0, 9, 145, + 83, 7, 59, 0, 8, 120, 0, 8, 56, 0, 9, 209, + 81, 7, 17, 0, 8, 104, 0, 8, 40, 0, 9, 177, + 0, 8, 8, 0, 8, 136, 0, 8, 72, 0, 9, 241, + 80, 7, 4, 0, 8, 84, 0, 8, 20, 85, 8, 227, + 83, 7, 43, 0, 8, 116, 0, 8, 52, 0, 9, 201, + 81, 7, 13, 0, 8, 100, 0, 8, 36, 0, 9, 169, + 0, 8, 4, 0, 8, 132, 0, 8, 68, 0, 9, 233, + 80, 7, 8, 0, 8, 92, 0, 8, 28, 0, 9, 153, + 84, 7, 83, 0, 8, 124, 0, 8, 60, 0, 9, 217, + 82, 7, 23, 0, 8, 108, 0, 8, 44, 0, 9, 185, + 0, 8, 12, 0, 8, 140, 0, 8, 76, 0, 9, 249, + 80, 7, 3, 0, 8, 82, 0, 8, 18, 85, 8, 163, + 83, 7, 35, 0, 8, 114, 0, 8, 50, 0, 9, 197, + 81, 7, 11, 0, 8, 98, 0, 8, 34, 0, 9, 165, + 0, 8, 2, 0, 8, 130, 0, 8, 66, 0, 9, 229, + 80, 7, 7, 0, 8, 90, 0, 8, 26, 0, 9, 149, + 84, 7, 67, 0, 8, 122, 0, 8, 58, 0, 9, 213, + 82, 7, 19, 0, 8, 106, 0, 8, 42, 0, 9, 181, + 0, 8, 10, 0, 8, 138, 0, 8, 74, 0, 9, 245, + 80, 7, 5, 0, 8, 86, 0, 8, 22, 192, 8, 0, + 83, 7, 51, 0, 8, 118, 0, 8, 54, 0, 9, 205, + 81, 7, 15, 0, 8, 102, 0, 8, 38, 0, 9, 173, + 0, 8, 6, 0, 8, 134, 0, 8, 70, 0, 9, 237, + 80, 7, 9, 0, 8, 94, 0, 8, 30, 0, 9, 157, + 84, 7, 99, 0, 8, 126, 0, 8, 62, 0, 9, 221, + 82, 7, 27, 0, 8, 110, 0, 8, 46, 0, 9, 189, + 0, 8, 14, 0, 8, 142, 0, 8, 78, 0, 9, 253, + 96, 7, 256, 0, 8, 81, 0, 8, 17, 85, 8, 131, + 82, 7, 31, 0, 8, 113, 0, 8, 49, 0, 9, 195, + 80, 7, 10, 0, 8, 97, 0, 8, 33, 0, 9, 163, + 0, 8, 1, 0, 8, 129, 0, 8, 65, 0, 9, 227, + 80, 7, 6, 0, 8, 89, 0, 8, 25, 0, 9, 147, + 83, 7, 59, 0, 8, 121, 0, 8, 57, 0, 9, 211, + 81, 7, 17, 0, 8, 105, 0, 8, 41, 0, 9, 179, + 0, 8, 9, 0, 8, 137, 0, 8, 73, 0, 9, 243, + 80, 7, 4, 0, 8, 85, 0, 8, 21, 80, 8, 258, + 83, 7, 43, 0, 8, 117, 0, 8, 53, 0, 9, 203, + 81, 7, 13, 0, 8, 101, 0, 8, 37, 0, 9, 171, + 0, 8, 5, 0, 8, 133, 0, 8, 69, 0, 9, 235, + 80, 7, 8, 0, 8, 93, 0, 8, 29, 0, 9, 155, + 84, 7, 83, 0, 8, 125, 0, 8, 61, 0, 9, 219, + 82, 7, 23, 0, 8, 109, 0, 8, 45, 0, 9, 187, + 0, 8, 13, 0, 8, 141, 0, 8, 77, 0, 9, 251, + 
80, 7, 3, 0, 8, 83, 0, 8, 19, 85, 8, 195, + 83, 7, 35, 0, 8, 115, 0, 8, 51, 0, 9, 199, + 81, 7, 11, 0, 8, 99, 0, 8, 35, 0, 9, 167, + 0, 8, 3, 0, 8, 131, 0, 8, 67, 0, 9, 231, + 80, 7, 7, 0, 8, 91, 0, 8, 27, 0, 9, 151, + 84, 7, 67, 0, 8, 123, 0, 8, 59, 0, 9, 215, + 82, 7, 19, 0, 8, 107, 0, 8, 43, 0, 9, 183, + 0, 8, 11, 0, 8, 139, 0, 8, 75, 0, 9, 247, + 80, 7, 5, 0, 8, 87, 0, 8, 23, 192, 8, 0, + 83, 7, 51, 0, 8, 119, 0, 8, 55, 0, 9, 207, + 81, 7, 15, 0, 8, 103, 0, 8, 39, 0, 9, 175, + 0, 8, 7, 0, 8, 135, 0, 8, 71, 0, 9, 239, + 80, 7, 9, 0, 8, 95, 0, 8, 31, 0, 9, 159, + 84, 7, 99, 0, 8, 127, 0, 8, 63, 0, 9, 223, + 82, 7, 27, 0, 8, 111, 0, 8, 47, 0, 9, 191, + 0, 8, 15, 0, 8, 143, 0, 8, 79, 0, 9, 255 + }; + static final int[] fixed_td = { + 80, 5, 1, 87, 5, 257, 83, 5, 17, 91, 5, 4097, + 81, 5, 5, 89, 5, 1025, 85, 5, 65, 93, 5, 16385, + 80, 5, 3, 88, 5, 513, 84, 5, 33, 92, 5, 8193, + 82, 5, 9, 90, 5, 2049, 86, 5, 129, 192, 5, 24577, + 80, 5, 2, 87, 5, 385, 83, 5, 25, 91, 5, 6145, + 81, 5, 7, 89, 5, 1537, 85, 5, 97, 93, 5, 24577, + 80, 5, 4, 88, 5, 769, 84, 5, 49, 92, 5, 12289, + 82, 5, 13, 90, 5, 3073, 86, 5, 193, 192, 5, 24577 + }; + + // Tables for deflate from PKZIP's appnote.txt. 
    static final int[] cplens = { // Copy lengths for literal codes 257..285
            3, 4, 5, 6, 7, 8, 9, 10, 11, 13, 15, 17, 19, 23, 27, 31,
            35, 43, 51, 59, 67, 83, 99, 115, 131, 163, 195, 227, 258, 0, 0
    };

    // see note #13 above about 258
    static final int[] cplext = { // Extra bits for literal codes 257..285
            0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2,
            3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0, 112, 112 // 112==invalid
    };

    static final int[] cpdist = { // Copy offsets for distance codes 0..29
            1, 2, 3, 4, 5, 7, 9, 13, 17, 25, 33, 49, 65, 97, 129, 193,
            257, 385, 513, 769, 1025, 1537, 2049, 3073, 4097, 6145,
            8193, 12289, 16385, 24577
    };

    static final int[] cpdext = { // Extra bits for distance codes
            0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6,
            7, 7, 8, 8, 9, 9, 10, 10, 11, 11,
            12, 12, 13, 13};

    // If BMAX needs to be larger than 16, then h and x[] should be uLong.
    static final int BMAX = 15; // maximum bit length of any code

    // Work areas reused across huft_build calls; lazily allocated by
    // initWorkArea(). NOTE: this makes an InfTree instance non-thread-safe.
    int[] hn = null; // hufts used in space
    int[] v = null;  // work area for huft_build
    int[] c = null;  // bit length count table
    int[] r = null;  // table entry for structure assignment
    int[] u = null;  // table stack
    int[] x = null;  // bit offsets, then code stack

    /**
     * Core Huffman table builder (port of zlib's huft_build in inftrees.c).
     * Builds a multi-level lookup table from a list of code lengths.
     *
     * @param b      code lengths in bits (all assumed &lt;= BMAX)
     * @param bindex offset of the first length inside {@code b}
     * @param n      number of codes (assumed &lt;= 288)
     * @param s      number of simple-valued codes (0..s-1)
     * @param d      list of base values for non-simple codes (null for bit-length tree)
     * @param e      list of extra bits for non-simple codes (null for bit-length tree)
     * @param t      result: starting table index (single-element out param)
     * @param m      maximum lookup bits, returns actual (in/out param)
     * @param hp     space for trees, (op, bits, value) triples
     * @param hn     hufts used in space so far (single-element in/out param)
     * @param v      working area: values in order of bit length
     * @return Z_OK, Z_BUF_ERROR (incomplete but usable), or Z_DATA_ERROR
     */
    private int huft_build(int[] b, // code lengths in bits (all assumed <= BMAX)
                           int bindex,
                           int n, // number of codes (assumed <= 288)
                           int s, // number of simple-valued codes (0..s-1)
                           int[] d, // list of base values for non-simple codes
                           int[] e, // list of extra bits for non-simple codes
                           int[] t, // result: starting table
                           int[] m, // maximum lookup bits, returns actual
                           int[] hp,// space for trees
                           int[] hn,// hufts used in space
                           int[] v // working area: values in order of bit length
    ) {
        // Given a list of code lengths and a maximum table size, make a set of
        // tables to decode that set of codes. Return Z_OK on success, Z_BUF_ERROR
        // if the given code set is incomplete (the tables are still built in this
        // case), Z_DATA_ERROR if the input is invalid (an over-subscribed set of
        // lengths), or Z_MEM_ERROR if not enough memory.

        int a;    // counter for codes of length k
        int f;    // i repeats in table every f entries
        int g;    // maximum code length
        int h;    // table level
        int i;    // counter, current code
        int j;    // counter
        int k;    // number of bits in current code
        int l;    // bits per table (returned in m)
        int mask; // (1 << w) - 1, to avoid cc -O bug on HP
        int p;    // pointer into c[], b[], or v[]
        int q;    // points to current table
        int w;    // bits before this table == (l * h)
        int xp;   // pointer into x
        int y;    // number of dummy codes added
        int z;    // number of entries in current table

        // Generate counts for each bit length

        p = 0;
        i = n;
        do {
            c[b[bindex + p]]++;
            p++;
            i--; // assume all entries <= BMAX
        } while (i != 0);

        if (c[0] == n) { // null input--all zero length codes
            t[0] = -1;
            m[0] = 0;
            return Z_OK;
        }

        // Find minimum and maximum length, bound *m by those
        l = m[0];
        for (j = 1; j <= BMAX; j++)
            if (c[j] != 0) break;
        k = j; // minimum code length
        if (l < j) {
            l = j;
        }
        for (i = BMAX; i != 0; i--) {
            if (c[i] != 0) break;
        }
        g = i; // maximum code length
        if (l > i) {
            l = i;
        }
        m[0] = l;

        // Adjust last length count to fill out codes, if needed
        for (y = 1 << j; j < i; j++, y <<= 1) {
            if ((y -= c[j]) < 0) {
                return Z_DATA_ERROR; // over-subscribed set of lengths
            }
        }
        if ((y -= c[i]) < 0) {
            return Z_DATA_ERROR;
        }
        c[i] += y;

        // Generate starting offsets into the value table for each length
        x[1] = j = 0;
        p = 1;
        xp = 2;
        while (--i != 0) { // note that i == g from above
            x[xp] = (j += c[p]);
            xp++;
            p++;
        }

        // Make a table of values in order of bit lengths
        i = 0;
        p = 0;
        do {
            if ((j = b[bindex + p]) != 0) {
                v[x[j]++] = i;
            }
            p++;
        }
        while (++i < n);
        n = x[g]; // set n to length of v

        // Generate the Huffman codes and for each, make the table entries
        x[0] = i = 0; // first Huffman code is zero
        p = 0; // grab values in bit order
        h = -1; // no tables yet--level -1
        w = -l; // bits decoded == (l * h)
        u[0] = 0; // just to keep compilers happy
        q = 0; // ditto
        z = 0; // ditto

        // go through the bit lengths (k already is bits in shortest code)
        for (; k <= g; k++) {
            a = c[k];
            while (a-- != 0) {
                // here i is the Huffman code of length k bits for value *p
                // make tables up to required level
                while (k > w + l) {
                    h++;
                    w += l; // previous table always l bits
                    // compute minimum size table less than or equal to l bits
                    z = g - w;
                    z = (z > l) ? l : z; // table size upper limit
                    if ((f = 1 << (j = k - w)) > a + 1) { // try a k-w bit table
                        // too few codes for k-w bit table
                        f -= a + 1; // deduct codes from patterns left
                        xp = k;
                        if (j < z) {
                            while (++j < z) { // try smaller tables up to z bits
                                if ((f <<= 1) <= c[++xp])
                                    break; // enough codes to use up j bits
                                f -= c[xp]; // else deduct codes from patterns
                            }
                        }
                    }
                    z = 1 << j; // table entries for j-bit table

                    // allocate new table
                    if (hn[0] + z > MANY) { // (note: doesn't matter for fixed)
                        return Z_DATA_ERROR; // overflow of MANY
                    }
                    u[h] = q = /*hp+*/ hn[0]; // DEBUG
                    hn[0] += z;

                    // connect to last table, if there is one
                    if (h != 0) {
                        x[h] = i; // save pattern for backing up
                        r[0] = (byte) j; // bits in this table
                        r[1] = (byte) l; // bits to dump before this table
                        j = i >>> (w - l);
                        r[2] = q - u[h - 1] - j; // offset to this table
                        System.arraycopy(r, 0, hp, (u[h - 1] + j) * 3, 3); // connect to last table
                    } else {
                        t[0] = q; // first table is returned result
                    }
                }

                // set up table entry in r
                r[1] = (byte) (k - w);
                if (p >= n) {
                    r[0] = 128 + 64; // out of values--invalid code
                } else if (v[p] < s) {
                    r[0] = (byte) (v[p] < 256 ? 0 : 32 + 64); // 256 is end-of-block
                    r[2] = v[p++]; // simple code is just the value
                } else {
                    r[0] = (byte) (e[v[p] - s] + 16 + 64); // non-simple--look up in lists
                    r[2] = d[v[p++] - s];
                }

                // fill code-like entries with r
                f = 1 << (k - w);
                for (j = i >>> w; j < z; j += f) {
                    System.arraycopy(r, 0, hp, (q + j) * 3, 3);
                }

                // backwards increment the k-bit code i
                for (j = 1 << (k - 1); (i & j) != 0; j >>>= 1) {
                    i ^= j;
                }
                i ^= j;

                // backup over finished tables
                mask = (1 << w) - 1; // needed on HP, cc -O bug
                while ((i & mask) != x[h]) {
                    h--; // don't need to update q
                    w -= l;
                    mask = (1 << w) - 1;
                }
            }
        }
        // Return Z_BUF_ERROR if we were given an incomplete table
        return y != 0 && g != 1 ? Z_BUF_ERROR : Z_OK;
    }

    /**
     * Builds the decode table for the code-length (bit-length) tree of a
     * dynamic block. Sets {@code z.msg} and maps failures to Z_DATA_ERROR.
     */
    int inflate_trees_bits(int[] c, // 19 code lengths
                           int[] bb, // bits tree desired/actual depth
                           int[] tb, // bits tree result
                           int[] hp, // space for trees
                           ZStream z // for messages
    ) {
        int result;
        initWorkArea(19);
        hn[0] = 0;
        result = huft_build(c, 0, 19, 19, null, null, tb, bb, hp, hn, v);

        if (result == Z_DATA_ERROR) {
            z.msg = "oversubscribed dynamic bit lengths tree";
        } else if (result == Z_BUF_ERROR || bb[0] == 0) {
            // an incomplete (or empty) bit-length tree is unusable
            z.msg = "incomplete dynamic bit lengths tree";
            result = Z_DATA_ERROR;
        }
        return result;
    }

    /**
     * Builds the literal/length and distance decode tables for a dynamic
     * block. Note the asymmetric error handling: an incomplete distance tree
     * is tolerated when there are no length codes (nl == 257), matching
     * zlib's handling of degenerate but legal streams.
     */
    int inflate_trees_dynamic(int nl, // number of literal/length codes
                              int nd, // number of distance codes
                              int[] c, // that many (total) code lengths
                              int[] bl, // literal desired/actual bit depth
                              int[] bd, // distance desired/actual bit depth
                              int[] tl, // literal/length tree result
                              int[] td, // distance tree result
                              int[] hp, // space for trees
                              ZStream z // for messages
    ) {
        int result;

        // build literal/length tree
        initWorkArea(288);
        hn[0] = 0;
        result = huft_build(c, 0, nl, 257, cplens, cplext, tl, bl, hp, hn, v);
        if (result != Z_OK || bl[0] == 0) {
            if (result == Z_DATA_ERROR) {
                z.msg = "oversubscribed literal/length tree";
            } else if (result != Z_MEM_ERROR) {
                z.msg = "incomplete literal/length tree";
                result = Z_DATA_ERROR;
            }
            return result;
        }

        // build distance tree
        initWorkArea(288);
        result = huft_build(c, nl, nd, 0, cpdist, cpdext, td, bd, hp, hn, v);

        if (result != Z_OK || (bd[0] == 0 && nl > 257)) {
            if (result == Z_DATA_ERROR) {
                z.msg = "oversubscribed distance tree";
            } else if (result == Z_BUF_ERROR) {
                z.msg = "incomplete distance tree";
                result = Z_DATA_ERROR;
            } else if (result != Z_MEM_ERROR) {
                z.msg = "empty distance tree with lengths";
                result = Z_DATA_ERROR;
            }
            return result;
        }

        return Z_OK;
    }

    /**
     * Returns the precomputed fixed-block tables; never fails.
     * The {@code z} parameter is unused in the Java port (kept for
     * signature parity with the C original, which could allocate here).
     */
    static int inflate_trees_fixed(int[] bl, //literal desired/actual bit depth
                                   int[] bd, //distance desired/actual bit depth
                                   int[][] tl,//literal/length tree result
                                   int[][] td,//distance tree result
                                   ZStream z //for memory allocation
    ) {
        bl[0] = fixed_bl;
        bd[0] = fixed_bd;
        tl[0] = fixed_tl;
        td[0] = fixed_td;
        return Z_OK;
    }

    /**
     * Lazily allocates and zeroes the shared work arrays used by
     * huft_build(); grows {@code v} when a larger value area is needed.
     * u and x are cleared via arraycopy from the freshly zeroed c.
     */
    private void initWorkArea(int vsize) {
        if (hn == null) {
            hn = new int[1];
            v = new int[vsize];
            c = new int[BMAX + 1];
            r = new int[3];
            u = new int[BMAX];
            x = new int[BMAX + 1];
        }
        if (v.length < vsize) {
            v = new int[vsize];
        }
        for (int i = 0; i < vsize; i++) {
            v[i] = 0;
        }
        for (int i = 0; i < BMAX + 1; i++) {
            c[i] = 0;
        }
        for (int i = 0; i < 3; i++) {
            r[i] = 0;
        }
        System.arraycopy(c, 0, u, 0, BMAX);
        System.arraycopy(c, 0, x, 0, BMAX + 1);
    }
}
diff --git a/netty-zlib/src/main/java/io/netty/zlib/Inflate.java b/netty-zlib/src/main/java/io/netty/zlib/Inflate.java
new file mode 100644
index 0000000..8b28c9d
--- /dev/null
+++ b/netty-zlib/src/main/java/io/netty/zlib/Inflate.java
@@ -0,0 +1,764 @@
package io.netty.zlib;

/**
 * Per-stream inflate state machine (port of zlib's inflate.c), handling the
 * zlib/gzip wrappers around the block decoder in InfBlocks.
 */
final class Inflate {

    static final private int MAX_WBITS = 15; // 32K LZ77 window

    // preset dictionary flag in zlib header
    static final private int PRESET_DICT = 0x20;

    // flush values accepted by inflate(int)
    static final int Z_NO_FLUSH = 0;
    static final int Z_PARTIAL_FLUSH = 1;
    // Flush values accepted by inflate(int f); mirror zlib's Z_* flush constants.
    static final int Z_SYNC_FLUSH = 2;
    static final int Z_FULL_FLUSH = 3;
    static final int Z_FINISH = 4;

    // Only compression method defined by RFC 1950/1951.
    static final private int Z_DEFLATED = 8;

    // zlib return codes (negative values are errors).
    static final private int Z_OK = 0;
    static final private int Z_STREAM_END = 1;
    static final private int Z_NEED_DICT = 2;
    static final private int Z_ERRNO = -1;
    static final private int Z_STREAM_ERROR = -2;
    static final private int Z_DATA_ERROR = -3;
    static final private int Z_MEM_ERROR = -4;
    static final private int Z_BUF_ERROR = -5;
    static final private int Z_VERSION_ERROR = -6;

    // States of the inflate state machine (field 'mode').
    static final private int METHOD = 0; // waiting for method byte
    static final private int FLAG = 1; // waiting for flag byte
    static final private int DICT4 = 2; // four dictionary check bytes to go
    static final private int DICT3 = 3; // three dictionary check bytes to go
    static final private int DICT2 = 4; // two dictionary check bytes to go
    static final private int DICT1 = 5; // one dictionary check byte to go
    static final private int DICT0 = 6; // waiting for inflateSetDictionary
    static final private int BLOCKS = 7; // decompressing blocks
    static final private int CHECK4 = 8; // four check bytes to go
    static final private int CHECK3 = 9; // three check bytes to go
    static final private int CHECK2 = 10; // two check bytes to go
    static final private int CHECK1 = 11; // one check byte to go
    static final private int DONE = 12; // finished check, done
    static final private int BAD = 13; // got an error--stay here

    // Additional states used for gzip (RFC 1952) header/trailer parsing.
    static final private int HEAD = 14;
    static final private int LENGTH = 15;
    static final private int TIME = 16;
    static final private int OS = 17;
    static final private int EXLEN = 18;
    static final private int EXTRA = 19;
    static final private int NAME = 20;
    static final private int COMMENT = 21;
    static final private int HCRC = 22;
    static final private int FLAGS = 23;

    // Window-bits flag requesting auto-detection of zlib vs. gzip wrapper.
    static final int INFLATE_ANY = 0x40000000;

    int mode; // current inflate mode

    // mode dependent
information + int method; // if FLAGS, method byte + + // if CHECK, check values to compare + long was = -1; // computed check value + long need; // stream check value + + // if BAD, inflateSync's marker bytes count + int marker; + + // mode independent information + int wrap; // flag for no wrapper + // 0: no wrapper + // 1: zlib header + // 2: gzip header + // 4: auto detection + + int wbits; // log2(window size) (8..15, defaults to 15) + + InfBlocks blocks; // current inflate_blocks state + + private final ZStream z; + + private int flags; + + private int need_bytes = -1; + private final byte[] crcbuf = new byte[4]; + + GZIPHeader gheader = null; + + int inflateReset() { + if (z == null) return Z_STREAM_ERROR; + + z.total_in = z.total_out = 0; + z.msg = null; + this.mode = HEAD; + this.need_bytes = -1; + this.blocks.reset(); + return Z_OK; + } + + int inflateEnd() { + if (blocks != null) { + blocks.free(); + } + return Z_OK; + } + + Inflate(ZStream z) { + this.z = z; + } + + int inflateInit(int w) { + z.msg = null; + blocks = null; + + // handle undocumented wrap option (no zlib header or check) + wrap = 0; + if (w < 0) { + w = -w; + } else if ((w & INFLATE_ANY) != 0) { + wrap = 4; + w &= ~INFLATE_ANY; + if (w < 48) + w &= 15; + } else if ((w & ~31) != 0) { // for example, DEF_WBITS + 32 + wrap = 4; // zlib and gzip wrapped data should be accepted. + w &= 15; + } else { + wrap = (w >> 4) + 1; + if (w < 48) + w &= 15; + } + + if (w < 8 || w > 15) { + inflateEnd(); + return Z_STREAM_ERROR; + } + if (blocks != null && wbits != w) { + blocks.free(); + blocks = null; + } + + // set window size + wbits = w; + + this.blocks = new InfBlocks(z, 1 << w); + + // reset state + inflateReset(); + + return Z_OK; + } + + int inflate(int f) { + int hold = 0; + + int r; + int b; + + if (z == null || z.next_in == null) { + if (f == Z_FINISH && this.mode == HEAD) + return Z_OK; + return Z_STREAM_ERROR; + } + + f = f == Z_FINISH ? 
                Z_BUF_ERROR : Z_OK;
        r = Z_BUF_ERROR;
        // State machine: each case may fall through to the next state.
        while (true) {

            switch (this.mode) {
                case HEAD:
                    if (wrap == 0) {
                        // raw deflate, no wrapper to parse
                        this.mode = BLOCKS;
                        break;
                    }

                    try {
                        r = readBytes(2, r, f);
                    } catch (Return e) {
                        return e.r;
                    }

                    if ((wrap == 4 || (wrap & 2) != 0) &&
                            this.need == 0x8b1fL) { // gzip header
                        if (wrap == 4) {
                            wrap = 2;
                        }
                        z.adler = new CRC32();
                        checksum(2, this.need);

                        if (gheader == null)
                            gheader = new GZIPHeader();

                        this.mode = FLAGS;
                        break;
                    }

                    if ((wrap & 2) != 0) {
                        // gzip-only mode but magic bytes did not match
                        this.mode = BAD;
                        z.msg = "incorrect header check";
                        break;
                    }

                    flags = 0;

                    this.method = ((int) this.need) & 0xff;
                    b = ((int) (this.need >> 8)) & 0xff;

                    if (((wrap & 1) == 0 || // check if zlib header allowed
                            (((this.method << 8) + b) % 31) != 0) &&
                            (this.method & 0xf) != Z_DEFLATED) {
                        if (wrap == 4) {
                            // auto-detect failed: push the two bytes back and
                            // fall back to raw deflate
                            z.next_in_index -= 2;
                            z.avail_in += 2;
                            z.total_in -= 2;
                            wrap = 0;
                            this.mode = BLOCKS;
                            break;
                        }
                        this.mode = BAD;
                        z.msg = "incorrect header check";
                        // since zlib 1.2, it is allowed to inflateSync for this case.
                        /*
                        this.marker = 5; // can't try inflateSync
                        */
                        break;
                    }

                    if ((this.method & 0xf) != Z_DEFLATED) {
                        this.mode = BAD;
                        z.msg = "unknown compression method";
                        // since zlib 1.2, it is allowed to inflateSync for this case.
                        /*
                        this.marker = 5; // can't try inflateSync
                        */
                        break;
                    }

                    if (wrap == 4) {
                        wrap = 1;
                    }

                    if ((this.method >> 4) + 8 > this.wbits) {
                        this.mode = BAD;
                        z.msg = "invalid window size";
                        // since zlib 1.2, it is allowed to inflateSync for this case.
                        /*
                        this.marker = 5; // can't try inflateSync
                        */
                        break;
                    }

                    z.adler = new Adler32();

                    if ((b & PRESET_DICT) == 0) {
                        this.mode = BLOCKS;
                        break;
                    }
                    this.mode = DICT4;
                    // fall through: read the 4-byte dictionary Adler-32
                case DICT4:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need = ((long) (z.next_in[z.next_in_index++] & 0xff) << 24) & 0xff000000L;
                    this.mode = DICT3;
                    // fall through
                case DICT3:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need += ((z.next_in[z.next_in_index++] & 0xff) << 16) & 0xff0000L;
                    this.mode = DICT2;
                    // fall through
                case DICT2:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need += ((z.next_in[z.next_in_index++] & 0xff) << 8) & 0xff00L;
                    this.mode = DICT1;
                    // fall through
                case DICT1:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need += (z.next_in[z.next_in_index++] & 0xffL);
                    z.adler.reset(this.need);
                    this.mode = DICT0;
                    // caller must now supply the dictionary via inflateSetDictionary
                    return Z_NEED_DICT;
                case DICT0:
                    this.mode = BAD;
                    z.msg = "need dictionary";
                    this.marker = 0; // can try inflateSync
                    return Z_STREAM_ERROR;
                case BLOCKS:
                    r = this.blocks.proc(r);
                    if (r == Z_DATA_ERROR) {
                        this.mode = BAD;
                        this.marker = 0; // can try inflateSync
                        break;
                    }
                    if (r == Z_OK) {
                        r = f;
                    }
                    if (r != Z_STREAM_END) {
                        return r;
                    }
                    r = f;
                    this.was = z.adler.getValue();
                    this.blocks.reset();
                    if (this.wrap == 0) {
                        this.mode = DONE;
                        break;
                    }
                    this.mode = CHECK4;
                    // fall through: read the 4-byte trailer checksum
                case CHECK4:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need = ((long) (z.next_in[z.next_in_index++] & 0xff) << 24) & 0xff000000L;
                    this.mode = CHECK3;
                    // fall through
                case CHECK3:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need += ((z.next_in[z.next_in_index++] & 0xff) << 16) & 0xff0000L;
                    this.mode = CHECK2;
                    // fall through
                case CHECK2:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need += ((z.next_in[z.next_in_index++] & 0xff) << 8) & 0xff00L;
                    this.mode = CHECK1;
                    // fall through
                case CHECK1:

                    if (z.avail_in == 0) return r;
                    r = f;

                    z.avail_in--;
                    z.total_in++;
                    this.need += (z.next_in[z.next_in_index++] & 0xffL);

                    if (flags != 0) { // gzip
                        // gzip CRC32 trailer is little-endian: byte-swap before
                        // comparing (the final mask truncates to 32 bits, so the
                        // 0x0000ffff term contributes only its low byte).
                        this.need = ((this.need & 0xff000000) >> 24 |
                                (this.need & 0x00ff0000) >> 8 |
                                (this.need & 0x0000ff00) << 8 |
                                (this.need & 0x0000ffff) << 24) & 0xffffffffL;
                    }

                    if (((int) (this.was)) != ((int) (this.need))) {
                        z.msg = "incorrect data check";
                        // check is delayed
                        /*
                        this.mode = BAD;
                        this.marker = 5; // can't try inflateSync
                        */
                    } else if (flags != 0 && gheader != null) {
                        gheader.crc = this.need;
                    }

                    this.mode = LENGTH;
                    // fall through
                case LENGTH:
                    if (wrap != 0 && flags != 0) {

                        try {
                            r = readBytes(4, r, f);
                        } catch (Return e) {
                            return e.r;
                        }

                        if (z.msg != null && z.msg.equals("incorrect data check")) {
                            this.mode = BAD;
                            this.marker = 5; // can't try inflateSync
                            break;
                        }

                        // gzip ISIZE trailer: uncompressed length mod 2^32
                        if (this.need != (z.total_out & 0xffffffffL)) {
                            z.msg = "incorrect length check";
                            this.mode = BAD;
                            break;
                        }
                        z.msg = null;
                    } else {
                        if (z.msg != null && z.msg.equals("incorrect data check")) {
                            this.mode = BAD;
                            this.marker = 5; // can't try inflateSync
                            break;
                        }
                    }

                    this.mode = DONE;
                    // fall through
                case DONE:
                    return Z_STREAM_END;
                case BAD:
                    return Z_DATA_ERROR;

                case FLAGS:
                    // gzip CM + FLG bytes (RFC 1952)

                    try {
                        r = readBytes(2, r, f);
                    } catch (Return e) {
                        return e.r;
                    }

                    flags = ((int) this.need) & 0xffff;

                    if ((flags & 0xff) != Z_DEFLATED) {
                        z.msg = "unknown compression method";
                        this.mode = BAD;
                        break;
                    }
                    if ((flags & 0xe000) != 0) {
                        z.msg = "unknown header flags set";
                        this.mode = BAD;
                        break;
                    }

                    if ((flags & 0x0200) != 0) { // FHCRC set: keep hashing header bytes
                        checksum(2, this.need);
                    }

                    this.mode = TIME;
                    // fall through

                case TIME:
                    // gzip MTIME field
                    try {
                        r = readBytes(4, r, f);
                    } catch (Return e) {
                        return e.r;
                    }
                    if (gheader != null)
                        gheader.time = this.need;
                    if ((flags & 0x0200) != 0) {
                        checksum(4, this.need);
                    }
                    this.mode = OS;
                    // fall through
                case OS:
                    // gzip XFL + OS bytes
                    try {
                        r = readBytes(2, r, f);
                    } catch (Return e) {
                        return e.r;
                    }
                    if (gheader != null) {
                        gheader.xflags = ((int) this.need) & 0xff;
                        gheader.os = (((int) this.need) >> 8) & 0xff;
                    }
                    if ((flags & 0x0200) != 0) {
                        checksum(2, this.need);
                    }
                    this.mode = EXLEN;
                    // fall through
                case EXLEN:
                    // optional FEXTRA length
                    if ((flags & 0x0400) != 0) {
                        try {
                            r = readBytes(2, r, f);
                        } catch (Return e) {
                            return e.r;
                        }
                        if (gheader != null) {
                            gheader.extra = new byte[((int) this.need) & 0xffff];
                        }
                        if ((flags & 0x0200) != 0) {
                            checksum(2, this.need);
                        }
                    } else if (gheader != null) {
                        gheader.extra = null;
                    }
                    this.mode = EXTRA;
                    // fall through

                case EXTRA:
                    // optional FEXTRA payload
                    if ((flags & 0x0400) != 0) {
                        try {
                            r = readBytes(r, f);
                            if (gheader != null) {
                                byte[] foo = tmp_string.toByteArray();
                                tmp_string = null;
                                if (foo.length == gheader.extra.length) {
                                    System.arraycopy(foo, 0, gheader.extra, 0, foo.length);
                                } else {
                                    z.msg = "bad extra field length";
                                    this.mode = BAD;
                                    break;
                                }
                            }
                        } catch (Return e) {
                            return e.r;
                        }
                    } else if (gheader != null) {
                        gheader.extra = null;
                    }
                    this.mode = NAME;
                    // fall through
                case NAME:
                    // optional zero-terminated original file name (FNAME)
                    if ((flags & 0x0800) != 0) {
                        try {
                            r = readString(r, f);
                            if (gheader != null) {
                                gheader.name = tmp_string.toByteArray();
                            }
                            tmp_string = null;
                        } catch (Return e) {
                            return e.r;
                        }
                    } else if (gheader != null) {
                        gheader.name = null;
                    }
                    this.mode = COMMENT;
                    // fall through
                case COMMENT:
                    // optional zero-terminated comment (FCOMMENT)
                    if ((flags & 0x1000) != 0) {
                        try {
                            r = readString(r, f);
                            if (gheader != null) {
                                gheader.comment = tmp_string.toByteArray();
                            }
                            tmp_string = null;
                        } catch (Return e) {
                            return e.r;
                        }
                    } else if (gheader != null) {
                        gheader.comment = null;
                    }
                    this.mode = HCRC;
                    // fall through
                case HCRC:
                    // optional 16-bit header CRC (FHCRC)
                    if ((flags & 0x0200) != 0) {
                        try {
                            r = readBytes(2, r, f);
                        } catch (Return e) {
                            return e.r;
                        }
                        if (gheader != null) {
                            gheader.hcrc = (int) (this.need & 0xffff);
                        }
                        if (this.need != (z.adler.getValue() & 0xffffL)) {
                            this.mode = BAD;
                            z.msg = "header crc mismatch";
                            this.marker = 5; // can't try inflateSync
                            break;
                        }
                    }
                    // header done: switch to CRC32 for the deflate payload
                    z.adler = new CRC32();

                    this.mode = BLOCKS;
                    break;
                default:
                    return Z_STREAM_ERROR;
            }
        }
    }

    /**
     * Supplies the preset dictionary requested via Z_NEED_DICT, or seeds the
     * window before decompression when no wrapper is used.
     */
    int inflateSetDictionary(byte[] dictionary, int dictLength) {
        if (z == null || (this.mode != DICT0 && this.wrap != 0)) {
            return Z_STREAM_ERROR;
        }

        int index = 0;
        int length = dictLength;

        if (this.mode == DICT0) {
            // verify the dictionary's Adler-32 against the value from DICT4..DICT1
            long adler_need = z.adler.getValue();
            z.adler.reset();
            z.adler.update(dictionary, 0, dictLength);
            if (z.adler.getValue() != adler_need) {
                return Z_DATA_ERROR;
            }
        }

        z.adler.reset();

        if (length >= (1 << this.wbits)) {
            // only the last window-size bytes of the dictionary are useful
            length = (1 << this.wbits) - 1;
            index = dictLength - length;
        }
        this.blocks.set_dictionary(dictionary, index, length);
        this.mode = BLOCKS;
        return Z_OK;
    }

    // Marker pattern of an empty stored block (what Z_SYNC_FLUSH emits).
    static private final byte[] mark = {(byte) 0, (byte) 0, (byte) 0xff, (byte) 0xff};

    /**
     * Skips input until a full flush point (00 00 FF FF) is found, then
     * restarts decompression from there. Returns Z_DATA_ERROR if no marker
     * was found in the available input.
     */
    int inflateSync() {
        int n; // number of bytes to look at
        int p; // pointer to bytes
        int m; // number of marker bytes found in a row
        long r, w; // temporaries to save total_in and total_out

        // set up
        if (z == null)
            return Z_STREAM_ERROR;
        if (this.mode != BAD) {
            this.mode = BAD;
            this.marker = 0;
        }
        if ((n = z.avail_in) == 0)
            return Z_BUF_ERROR;

        p = z.next_in_index;
        m = this.marker;
        // search
        while (n != 0 && m < 4) {
            if (z.next_in[p] == mark[m]) {
                m++;
            } else if (z.next_in[p] != 0) {
                m = 0;
            } else {
                // a zero can restart a partial match at position 4 - m
                m = 4 - m;
            }
            p++;
            n--;
        }

        // restore
        z.total_in += p - z.next_in_index;
        z.next_in_index = p;
        z.avail_in = n;
        this.marker = m;

        // return no joy or set up to restart on a new block
        if (m != 4) {
            return Z_DATA_ERROR;
        }
        r = z.total_in;
        w = z.total_out;
        inflateReset();
        z.total_in = r;
        z.total_out = w;
        this.mode = BLOCKS;

        return Z_OK;
    }

    // Returns true if inflate is currently at the end of a block generated
    // by Z_SYNC_FLUSH or Z_FULL_FLUSH.
This function is used by one PPP + // implementation to provide an additional safety check. PPP uses Z_SYNC_FLUSH + // but removes the length bytes of the resulting empty stored block. When + // decompressing, PPP checks that at the end of input packet, inflate is + // waiting for these length bytes. + int inflateSyncPoint() { + if (z == null || this.blocks == null) + return Z_STREAM_ERROR; + return this.blocks.sync_point(); + } + + private int readBytes(int n, int r, int f) throws Return { + if (need_bytes == -1) { + need_bytes = n; + this.need = 0; + } + while (need_bytes > 0) { + if (z.avail_in == 0) { + throw new Return(r); + } + r = f; + z.avail_in--; + z.total_in++; + this.need = this.need | + ((long) (z.next_in[z.next_in_index++] & 0xff) << ((n - need_bytes) * 8)); + need_bytes--; + } + if (n == 2) { + this.need &= 0xffffL; + } else if (n == 4) { + this.need &= 0xffffffffL; + } + need_bytes = -1; + return r; + } + + class Return extends Exception { + int r; + + Return(int r) { + this.r = r; + } + } + + private java.io.ByteArrayOutputStream tmp_string = null; + + private int readString(int r, int f) throws Return { + if (tmp_string == null) { + tmp_string = new java.io.ByteArrayOutputStream(); + } + int b = 0; + do { + if (z.avail_in == 0) { + throw new Return(r); + } + r = f; + z.avail_in--; + z.total_in++; + b = z.next_in[z.next_in_index]; + if (b != 0) tmp_string.write(z.next_in, z.next_in_index, 1); + z.adler.update(z.next_in, z.next_in_index, 1); + z.next_in_index++; + } while (b != 0); + return r; + } + + private int readBytes(int r, int f) throws Return { + if (tmp_string == null) { + tmp_string = new java.io.ByteArrayOutputStream(); + } + int b = 0; + while (this.need > 0) { + if (z.avail_in == 0) { + throw new Return(r); + } + r = f; + z.avail_in--; + z.total_in++; + b = z.next_in[z.next_in_index]; + tmp_string.write(z.next_in, z.next_in_index, 1); + z.adler.update(z.next_in, z.next_in_index, 1); + z.next_in_index++; + this.need--; + } + return r; + 
    }

    // Feeds the low n bytes of v (little-endian) into the running header checksum.
    private void checksum(int n, long v) {
        for (int i = 0; i < n; i++) {
            crcbuf[i] = (byte) (v & 0xff);
            v >>= 8;
        }
        z.adler.update(crcbuf, 0, n);
    }

    // Returns the gzip header parsed so far, or null for plain zlib/raw streams.
    public GZIPHeader getGZIPHeader() {
        return gheader;
    }

    // True while the state machine is still inside the zlib/gzip header.
    boolean inParsingHeader() {
        switch (mode) {
            case HEAD:
            case DICT4:
            case DICT3:
            case DICT2:
            case DICT1:
            case FLAGS:
            case TIME:
            case OS:
            case EXLEN:
            case EXTRA:
            case NAME:
            case COMMENT:
            case HCRC:
                return true;
            default:
                return false;
        }
    }
}
diff --git a/netty-zlib/src/main/java/io/netty/zlib/Inflater.java b/netty-zlib/src/main/java/io/netty/zlib/Inflater.java
new file mode 100644
index 0000000..f8a0b02
--- /dev/null
+++ b/netty-zlib/src/main/java/io/netty/zlib/Inflater.java
@@ -0,0 +1,131 @@
package io.netty.zlib;

/**
 * Public decompression facade: a ZStream pre-wired with an Inflate state
 * machine. Mirrors zlib's inflate* API.
 */
final public class Inflater extends ZStream {

    static final private int MAX_WBITS = 15; // 32K LZ77 window
    static final private int DEF_WBITS = MAX_WBITS;

    static final private int Z_NO_FLUSH = 0;
    static final private int Z_PARTIAL_FLUSH = 1;
    static final private int Z_SYNC_FLUSH = 2;
    static final private int Z_FULL_FLUSH = 3;
    static final private int Z_FINISH = 4;

    static final private int MAX_MEM_LEVEL = 9;

    // zlib return codes (negative values are errors).
    static final private int Z_OK = 0;
    static final private int Z_STREAM_END = 1;
    static final private int Z_NEED_DICT = 2;
    static final private int Z_ERRNO = -1;
    static final private int Z_STREAM_ERROR = -2;
    static final private int Z_DATA_ERROR = -3;
    static final private int Z_MEM_ERROR = -4;
    static final private int Z_BUF_ERROR = -5;
    static final private int Z_VERSION_ERROR = -6;

    /** Creates an inflater expecting a zlib wrapper with the default window. */
    public Inflater() {
        super();
        init();
    }

    public Inflater(JZlib.WrapperType wrapperType) throws GZIPException {
        this(DEF_WBITS, wrapperType);
    }

    public Inflater(int w, JZlib.WrapperType wrapperType) throws GZIPException {
        super();
        int ret = init(w, wrapperType);
        if (ret != Z_OK)
            throw new GZIPException(ret + ": " + msg);
    }

    public Inflater(int w)
throws GZIPException { + this(w, false); + } + + public Inflater(boolean nowrap) throws GZIPException { + this(DEF_WBITS, nowrap); + } + + public Inflater(int w, boolean nowrap) throws GZIPException { + super(); + int ret = init(w, nowrap); + if (ret != Z_OK) + throw new GZIPException(ret + ": " + msg); + } + + private boolean finished = false; + + public int init() { + return init(DEF_WBITS); + } + + public int init(JZlib.WrapperType wrapperType) { + return init(DEF_WBITS, wrapperType); + } + + public int init(int w, JZlib.WrapperType wrapperType) { + boolean nowrap = false; + if (wrapperType == JZlib.W_NONE) { + nowrap = true; + } else if (wrapperType == JZlib.W_GZIP) { + w += 16; + } else if (wrapperType == JZlib.W_ANY) { + w |= Inflate.INFLATE_ANY; + } else if (wrapperType == JZlib.W_ZLIB) { + } + return init(w, nowrap); + } + + public int init(boolean nowrap) { + return init(DEF_WBITS, nowrap); + } + + public int init(int w) { + return init(w, false); + } + + public int init(int w, boolean nowrap) { + finished = false; + istate = new Inflate(this); + return istate.inflateInit(nowrap ? 
                -w : w);
    }

    /** Decompresses available input; see Inflate.inflate for result codes. */
    public int inflate(int f) {
        if (istate == null) return Z_STREAM_ERROR;
        int ret = istate.inflate(f);
        if (ret == Z_STREAM_END)
            finished = true;
        return ret;
    }

    /** Releases the decompression state. */
    public int end() {
        finished = true;
        if (istate == null) return Z_STREAM_ERROR;
        int ret = istate.inflateEnd();
//    istate = null;
        return ret;
    }

    /** Skips to the next full-flush point; see Inflate.inflateSync. */
    public int sync() {
        if (istate == null)
            return Z_STREAM_ERROR;
        return istate.inflateSync();
    }

    public int syncPoint() {
        if (istate == null)
            return Z_STREAM_ERROR;
        return istate.inflateSyncPoint();
    }

    /** Supplies the preset dictionary after inflate returned Z_NEED_DICT. */
    public int setDictionary(byte[] dictionary, int dictLength) {
        if (istate == null)
            return Z_STREAM_ERROR;
        return istate.inflateSetDictionary(dictionary, dictLength);
    }

    // NOTE(review): ignores the 'finished' field, relies on the magic mode
    // value 12 (DONE), and will NPE if called before init — confirm intended.
    public boolean finished() {
        return istate.mode == 12 /*DONE*/;
    }
}
diff --git a/netty-zlib/src/main/java/io/netty/zlib/InflaterInputStream.java b/netty-zlib/src/main/java/io/netty/zlib/InflaterInputStream.java
new file mode 100644
index 0000000..b7a9be6
--- /dev/null
+++ b/netty-zlib/src/main/java/io/netty/zlib/InflaterInputStream.java
@@ -0,0 +1,225 @@
package io.netty.zlib;

import java.io.EOFException;
import java.io.FilterInputStream;
import java.io.IOException;
import java.io.InputStream;

/**
 * An InputStream that decompresses data from an underlying stream using an
 * {@link Inflater}; analogous to java.util.zip.InflaterInputStream.
 */
public class InflaterInputStream extends FilterInputStream {
    protected final Inflater inflater;
    protected byte[] buf; // compressed-input staging buffer

    private boolean closed = false;

    private boolean eof = false;

    // whether close() should also close the wrapped stream
    private boolean close_in = true;

    protected static final int DEFAULT_BUFSIZE = 512;

    public InflaterInputStream(InputStream in) throws IOException {
        this(in, false);
    }

    public InflaterInputStream(InputStream in, boolean nowrap) throws IOException {
        this(in, new Inflater(nowrap));
        myinflater = true; // we own this inflater; close() must end() it
    }

    public InflaterInputStream(InputStream in, Inflater inflater) throws IOException {
        this(in, inflater, DEFAULT_BUFSIZE);
    }

    public InflaterInputStream(InputStream in,
                               Inflater inflater, int size) throws
            IOException {
        this(in, inflater, size, true);
    }

    /**
     * @param in source of compressed bytes
     * @param inflater caller-supplied decompressor (not ended on close)
     * @param size staging buffer size, must be positive
     * @param close_in whether close() also closes {@code in}
     */
    public InflaterInputStream(InputStream in,
                               Inflater inflater,
                               int size, boolean close_in) throws IOException {
        super(in);
        if (in == null || inflater == null) {
            throw new NullPointerException();
        } else if (size <= 0) {
            throw new IllegalArgumentException("buffer size must be greater than 0");
        }
        this.inflater = inflater;
        buf = new byte[size];
        this.close_in = close_in;
    }

    // true when this stream created (and therefore owns) the inflater
    protected boolean myinflater = false;

    // single-byte scratch for read()
    private final byte[] byte1 = new byte[1];

    public int read() throws IOException {
        if (closed) {
            throw new IOException("Stream closed");
        }
        return read(byte1, 0, 1) == -1 ? -1 : byte1[0] & 0xff;
    }

    /**
     * Reads decompressed bytes into b[off..off+len), pulling and inflating
     * compressed input as needed. Returns the byte count, or -1 at stream end.
     */
    public int read(byte[] b, int off, int len) throws IOException {
        if (closed) {
            throw new IOException("Stream closed");
        }
        if (b == null) {
            throw new NullPointerException();
        } else if (off < 0 || len < 0 || len > b.length - off) {
            throw new IndexOutOfBoundsException();
        } else if (len == 0) {
            return 0;
        } else if (eof) {
            return -1;
        }

        int n = 0;
        inflater.setOutput(b, off, len);
        while (!eof) {
            if (inflater.avail_in == 0)
                fill();
            int err = inflater.inflate(JZlib.Z_NO_FLUSH);
            // count bytes produced into the caller's buffer this round
            n += inflater.next_out_index - off;
            off = inflater.next_out_index;
            switch (err) {
                case JZlib.Z_DATA_ERROR:
                    throw new IOException(inflater.msg);
                case JZlib.Z_STREAM_END:
                case JZlib.Z_NEED_DICT:
                    eof = true;
                    if (err == JZlib.Z_NEED_DICT)
                        return -1;
                    break;
                default:
            }
            if (inflater.avail_out == 0)
                break;
        }
        return n;
    }

    // Conservative estimate: 1 if not at EOF, 0 otherwise.
    public int available() throws IOException {
        if (closed) {
            throw new IOException("Stream closed");
        }
        if (eof) {
            return 0;
        } else {
            return 1;
        }
    }

    // scratch sink for skip()
    private final byte[] b = new byte[512];

    /** Skips up to n decompressed bytes by reading them into a scratch buffer. */
    public long skip(long n) throws IOException {
        if (n < 0) {
            throw new IllegalArgumentException("negative skip length");
        }

        if (closed) {
            throw new IOException("Stream closed");
        }

        int max = (int)
                Math.min(n, Integer.MAX_VALUE);
        int total = 0;
        while (total < max) {
            int len = max - total;
            if (len > b.length) {
                len = b.length;
            }
            len = read(b, 0, len);
            if (len == -1) {
                eof = true;
                break;
            }
            total += len;
        }
        return total;
    }

    /** Ends the inflater if owned, closes the source if configured to. */
    public void close() throws IOException {
        if (!closed) {
            if (myinflater)
                inflater.end();
            if (close_in)
                in.close();
            closed = true;
        }
    }

    /**
     * Refills the compressed-input buffer from the underlying stream. On EOF
     * of the source: feeds a single zero byte if the stream is raw deflate
     * and not yet finished (lets inflate detect truncation itself), otherwise
     * reports a missing footer or an unexpected EOF.
     */
    protected void fill() throws IOException {
        if (closed) {
            throw new IOException("Stream closed");
        }
        int len = in.read(buf, 0, buf.length);
        if (len == -1) {
            if (inflater.istate.wrap == 0 &&
                    !inflater.finished()) {
                buf[0] = 0;
                len = 1;
            } else if (inflater.istate.was != -1) { // in reading trailer
                throw new IOException("footer is not found");
            } else {
                throw new EOFException("Unexpected end of ZLIB input stream");
            }
        }
        inflater.setInput(buf, 0, len, true);
    }

    public boolean markSupported() {
        return false;
    }

    public synchronized void mark(int readlimit) {
    }

    public synchronized void reset() throws IOException {
        throw new IOException("mark/reset not supported");
    }

    public long getTotalIn() {
        return inflater.getTotalIn();
    }

    public long getTotalOut() {
        return inflater.getTotalOut();
    }

    /** Returns a copy of the not-yet-consumed compressed input, or null. */
    public byte[] getAvailIn() {
        if (inflater.avail_in <= 0)
            return null;
        byte[] tmp = new byte[inflater.avail_in];
        System.arraycopy(inflater.next_in, inflater.next_in_index,
                tmp, 0, inflater.avail_in);
        return tmp;
    }

    /**
     * Drives the inflater byte-by-byte until the zlib/gzip header has been
     * fully parsed, without producing any output.
     */
    public void readHeader() throws IOException {

        byte[] empty = "".getBytes();
        inflater.setInput(empty, 0, 0, false);
        inflater.setOutput(empty, 0, 0);

        int err = inflater.inflate(JZlib.Z_NO_FLUSH);
        if (!inflater.istate.inParsingHeader()) {
            return;
        }

        byte[] b1 = new byte[1];
        do {
            int i = in.read(b1);
            if (i <= 0)
                throw new IOException("no input");
            inflater.setInput(b1);
            err = inflater.inflate(JZlib.Z_NO_FLUSH);
            if (err != 0/*Z_OK*/)
                throw new
                        IOException(inflater.msg);
        }
        while (inflater.istate.inParsingHeader());
    }

    public Inflater getInflater() {
        return inflater;
    }
}
\ No newline at end of file
diff --git a/netty-zlib/src/main/java/io/netty/zlib/JZlib.java b/netty-zlib/src/main/java/io/netty/zlib/JZlib.java
new file mode 100644
index 0000000..1f0cf27
--- /dev/null
+++ b/netty-zlib/src/main/java/io/netty/zlib/JZlib.java
@@ -0,0 +1,61 @@
package io.netty.zlib;

/**
 * Shared public constants and helpers for the zlib port (wrapper types,
 * flush modes, return codes, checksum combination).
 */
final public class JZlib {
    private static final String version = "1.1.0";

    public static String version() {
        return version;
    }

    static final public int MAX_WBITS = 15; // 32K LZ77 window
    static final public int DEF_WBITS = MAX_WBITS;

    // Stream framing: raw deflate, zlib, gzip, or auto-detect.
    public enum WrapperType {
        NONE, ZLIB, GZIP, ANY
    }

    public static final WrapperType W_NONE = WrapperType.NONE;
    public static final WrapperType W_ZLIB = WrapperType.ZLIB;
    public static final WrapperType W_GZIP = WrapperType.GZIP;
    public static final WrapperType W_ANY = WrapperType.ANY;

    // compression levels
    static final public int Z_NO_COMPRESSION = 0;
    static final public int Z_BEST_SPEED = 1;
    static final public int Z_BEST_COMPRESSION = 9;
    static final public int Z_DEFAULT_COMPRESSION = (-1);

    // compression strategy
    static final public int Z_FILTERED = 1;
    static final public int Z_HUFFMAN_ONLY = 2;
    static final public int Z_DEFAULT_STRATEGY = 0;

    // flush modes
    static final public int Z_NO_FLUSH = 0;
    static final public int Z_PARTIAL_FLUSH = 1;
    static final public int Z_SYNC_FLUSH = 2;
    static final public int Z_FULL_FLUSH = 3;
    static final public int Z_FINISH = 4;

    // return codes (negative values are errors)
    static final public int Z_OK = 0;
    static final public int Z_STREAM_END = 1;
    static final public int Z_NEED_DICT = 2;
    static final public int Z_ERRNO = -1;
    static final public int Z_STREAM_ERROR = -2;
    static final public int Z_DATA_ERROR = -3;
    static final public int Z_MEM_ERROR = -4;
    static final public int Z_BUF_ERROR = -5;
    static final public int Z_VERSION_ERROR = -6;

// The three kinds of block type + static final public byte Z_BINARY = 0; + static final public byte Z_ASCII = 1; + static final public byte Z_UNKNOWN = 2; + + public static long adler32_combine(long adler1, long adler2, long len2) { + return Adler32.combine(adler1, adler2, len2); + } + + public static long crc32_combine(long crc1, long crc2, long len2) { + return CRC32.combine(crc1, crc2, len2); + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/StaticTree.java b/netty-zlib/src/main/java/io/netty/zlib/StaticTree.java new file mode 100644 index 0000000..0c0e71d --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/StaticTree.java @@ -0,0 +1,114 @@ +package io.netty.zlib; + +final class StaticTree { + static final private int MAX_BITS = 15; + + static final private int BL_CODES = 19; + static final private int D_CODES = 30; + static final private int LITERALS = 256; + static final private int LENGTH_CODES = 29; + static final private int L_CODES = (LITERALS + 1 + LENGTH_CODES); + + // Bit length codes must not exceed MAX_BL_BITS bits + static final int MAX_BL_BITS = 7; + + static final short[] static_ltree = { + 12, 8, 140, 8, 76, 8, 204, 8, 44, 8, + 172, 8, 108, 8, 236, 8, 28, 8, 156, 8, + 92, 8, 220, 8, 60, 8, 188, 8, 124, 8, + 252, 8, 2, 8, 130, 8, 66, 8, 194, 8, + 34, 8, 162, 8, 98, 8, 226, 8, 18, 8, + 146, 8, 82, 8, 210, 8, 50, 8, 178, 8, + 114, 8, 242, 8, 10, 8, 138, 8, 74, 8, + 202, 8, 42, 8, 170, 8, 106, 8, 234, 8, + 26, 8, 154, 8, 90, 8, 218, 8, 58, 8, + 186, 8, 122, 8, 250, 8, 6, 8, 134, 8, + 70, 8, 198, 8, 38, 8, 166, 8, 102, 8, + 230, 8, 22, 8, 150, 8, 86, 8, 214, 8, + 54, 8, 182, 8, 118, 8, 246, 8, 14, 8, + 142, 8, 78, 8, 206, 8, 46, 8, 174, 8, + 110, 8, 238, 8, 30, 8, 158, 8, 94, 8, + 222, 8, 62, 8, 190, 8, 126, 8, 254, 8, + 1, 8, 129, 8, 65, 8, 193, 8, 33, 8, + 161, 8, 97, 8, 225, 8, 17, 8, 145, 8, + 81, 8, 209, 8, 49, 8, 177, 8, 113, 8, + 241, 8, 9, 8, 137, 8, 73, 8, 201, 8, + 41, 8, 169, 8, 105, 8, 233, 8, 25, 8, + 153, 8, 89, 8, 217, 
8, 57, 8, 185, 8, + 121, 8, 249, 8, 5, 8, 133, 8, 69, 8, + 197, 8, 37, 8, 165, 8, 101, 8, 229, 8, + 21, 8, 149, 8, 85, 8, 213, 8, 53, 8, + 181, 8, 117, 8, 245, 8, 13, 8, 141, 8, + 77, 8, 205, 8, 45, 8, 173, 8, 109, 8, + 237, 8, 29, 8, 157, 8, 93, 8, 221, 8, + 61, 8, 189, 8, 125, 8, 253, 8, 19, 9, + 275, 9, 147, 9, 403, 9, 83, 9, 339, 9, + 211, 9, 467, 9, 51, 9, 307, 9, 179, 9, + 435, 9, 115, 9, 371, 9, 243, 9, 499, 9, + 11, 9, 267, 9, 139, 9, 395, 9, 75, 9, + 331, 9, 203, 9, 459, 9, 43, 9, 299, 9, + 171, 9, 427, 9, 107, 9, 363, 9, 235, 9, + 491, 9, 27, 9, 283, 9, 155, 9, 411, 9, + 91, 9, 347, 9, 219, 9, 475, 9, 59, 9, + 315, 9, 187, 9, 443, 9, 123, 9, 379, 9, + 251, 9, 507, 9, 7, 9, 263, 9, 135, 9, + 391, 9, 71, 9, 327, 9, 199, 9, 455, 9, + 39, 9, 295, 9, 167, 9, 423, 9, 103, 9, + 359, 9, 231, 9, 487, 9, 23, 9, 279, 9, + 151, 9, 407, 9, 87, 9, 343, 9, 215, 9, + 471, 9, 55, 9, 311, 9, 183, 9, 439, 9, + 119, 9, 375, 9, 247, 9, 503, 9, 15, 9, + 271, 9, 143, 9, 399, 9, 79, 9, 335, 9, + 207, 9, 463, 9, 47, 9, 303, 9, 175, 9, + 431, 9, 111, 9, 367, 9, 239, 9, 495, 9, + 31, 9, 287, 9, 159, 9, 415, 9, 95, 9, + 351, 9, 223, 9, 479, 9, 63, 9, 319, 9, + 191, 9, 447, 9, 127, 9, 383, 9, 255, 9, + 511, 9, 0, 7, 64, 7, 32, 7, 96, 7, + 16, 7, 80, 7, 48, 7, 112, 7, 8, 7, + 72, 7, 40, 7, 104, 7, 24, 7, 88, 7, + 56, 7, 120, 7, 4, 7, 68, 7, 36, 7, + 100, 7, 20, 7, 84, 7, 52, 7, 116, 7, + 3, 8, 131, 8, 67, 8, 195, 8, 35, 8, + 163, 8, 99, 8, 227, 8 + }; + + static final short[] static_dtree = { + 0, 5, 16, 5, 8, 5, 24, 5, 4, 5, + 20, 5, 12, 5, 28, 5, 2, 5, 18, 5, + 10, 5, 26, 5, 6, 5, 22, 5, 14, 5, + 30, 5, 1, 5, 17, 5, 9, 5, 25, 5, + 5, 5, 21, 5, 13, 5, 29, 5, 3, 5, + 19, 5, 11, 5, 27, 5, 7, 5, 23, 5 + }; + + static StaticTree static_l_desc = + new StaticTree(static_ltree, Tree.extra_lbits, + LITERALS + 1, L_CODES, MAX_BITS); + + static StaticTree static_d_desc = + new StaticTree(static_dtree, Tree.extra_dbits, + 0, D_CODES, MAX_BITS); + + static StaticTree static_bl_desc = + new 
StaticTree(null, Tree.extra_blbits, + 0, BL_CODES, MAX_BL_BITS); + + short[] static_tree; // static tree or null + int[] extra_bits; // extra bits for each code or null + int extra_base; // base index for extra_bits + int elems; // max number of elements in the tree + int max_length; // max bit length for the codes + + private StaticTree(short[] static_tree, + int[] extra_bits, + int extra_base, + int elems, + int max_length) { + this.static_tree = static_tree; + this.extra_bits = extra_bits; + this.extra_base = extra_base; + this.elems = elems; + this.max_length = max_length; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/Tree.java b/netty-zlib/src/main/java/io/netty/zlib/Tree.java new file mode 100644 index 0000000..ef4a700 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/Tree.java @@ -0,0 +1,336 @@ +package io.netty.zlib; + +final class Tree { + static final private int MAX_BITS = 15; + static final private int BL_CODES = 19; + static final private int D_CODES = 30; + static final private int LITERALS = 256; + static final private int LENGTH_CODES = 29; + static final private int L_CODES = (LITERALS + 1 + LENGTH_CODES); + static final private int HEAP_SIZE = (2 * L_CODES + 1); + + // Bit length codes must not exceed MAX_BL_BITS bits + static final int MAX_BL_BITS = 7; + + // end of block literal code + static final int END_BLOCK = 256; + + // repeat previous bit length 3-6 times (2 bits of repeat count) + static final int REP_3_6 = 16; + + // repeat a zero length 3-10 times (3 bits of repeat count) + static final int REPZ_3_10 = 17; + + // repeat a zero length 11-138 times (7 bits of repeat count) + static final int REPZ_11_138 = 18; + + // extra bits for each length code + static final int[] extra_lbits = { + 0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3, 4, 4, 4, 4, 5, 5, 5, 5, 0 + }; + + // extra bits for each distance code + static final int[] extra_dbits = { + 0, 0, 0, 0, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8, 9, 
9, 10, 10, 11, 11, 12, 12, 13, 13 + }; + + // extra bits for each bit length code + static final int[] extra_blbits = { + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, 3, 7 + }; + + static final byte[] bl_order = { + 16, 17, 18, 0, 8, 7, 9, 6, 10, 5, 11, 4, 12, 3, 13, 2, 14, 1, 15}; + + + // The lengths of the bit length codes are sent in order of decreasing + // probability, to avoid transmitting the lengths for unused bit + // length codes. + + static final int Buf_size = 8 * 2; + + // see definition of array dist_code below + static final int DIST_CODE_LEN = 512; + + static final byte[] _dist_code = { + 0, 1, 2, 3, 4, 4, 5, 5, 6, 6, 6, 6, 7, 7, 7, 7, 8, 8, 8, 8, + 8, 8, 8, 8, 9, 9, 9, 9, 9, 9, 9, 9, 10, 10, 10, 10, 10, 10, 10, 10, + 10, 10, 10, 10, 10, 10, 10, 10, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, 11, + 11, 11, 11, 11, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, + 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 12, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, 13, + 13, 13, 13, 13, 13, 13, 13, 13, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, + 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 14, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, + 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 15, 0, 0, 16, 17, + 18, 18, 19, 19, 20, 20, 20, 20, 21, 21, 21, 21, 22, 22, 22, 22, 22, 22, 22, 22, + 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 26, 
26, 26, 26, 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, + 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, + 27, 27, 27, 27, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, 28, + 28, 28, 28, 28, 28, 28, 28, 28, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, + 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29, 29 + }; + + static final byte[] _length_code = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 8, 9, 9, 10, 10, 11, 11, 12, 12, 12, 12, + 13, 13, 13, 13, 14, 14, 14, 14, 15, 15, 15, 15, 16, 16, 16, 16, 16, 16, 16, 16, + 17, 17, 17, 17, 17, 17, 17, 17, 18, 18, 18, 18, 18, 18, 18, 18, 19, 19, 19, 19, + 19, 19, 19, 19, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, 20, + 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 21, 22, 22, 22, 22, + 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 22, 23, 23, 23, 23, 23, 23, 23, 23, + 23, 23, 23, 23, 23, 23, 23, 23, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, 24, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, + 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 25, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, 26, + 26, 26, 26, 26, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, + 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 27, 28 + }; + + static final int[] base_length = { + 0, 1, 2, 3, 4, 5, 6, 7, 8, 10, 12, 14, 16, 20, 24, 28, 32, 40, 48, 56, + 64, 80, 96, 112, 128, 160, 192, 224, 0 + }; + + static final int[] base_dist = { + 0, 1, 2, 3, 4, 6, 8, 12, 16, 
24, + 32, 48, 64, 96, 128, 192, 256, 384, 512, 768, + 1024, 1536, 2048, 3072, 4096, 6144, 8192, 12288, 16384, 24576 + }; + + // Mapping from a distance to a distance code. dist is the distance - 1 and + // must not have side effects. _dist_code[256] and _dist_code[257] are never + // used. + static int d_code(int dist) { + return ((dist) < 256 ? _dist_code[dist] : _dist_code[256 + ((dist) >>> 7)]); + } + + short[] dyn_tree; // the dynamic tree + int max_code; // largest code with non zero frequency + StaticTree stat_desc; // the corresponding static tree + + // Compute the optimal bit lengths for a tree and update the total bit length + // for the current block. + // IN assertion: the fields freq and dad are set, heap[heap_max] and + // above are the tree nodes sorted by increasing frequency. + // OUT assertions: the field len is set to the optimal bit length, the + // array bl_count contains the frequencies for each bit length. + // The length opt_len is updated; static_len is also updated if stree is + // not null. + void gen_bitlen(Deflate s) { + short[] tree = dyn_tree; + short[] stree = stat_desc.static_tree; + int[] extra = stat_desc.extra_bits; + int base = stat_desc.extra_base; + int max_length = stat_desc.max_length; + int h; // heap index + int n, m; // iterate over the tree elements + int bits; // bit length + int xbits; // extra bits + short f; // frequency + int overflow = 0; // number of elements with bit length too large + + for (bits = 0; bits <= MAX_BITS; bits++) s.bl_count[bits] = 0; + + // In a first pass, compute the optimal bit lengths (which may + // overflow in the case of the bit length tree). 
+ tree[s.heap[s.heap_max] * 2 + 1] = 0; // root of the heap + + for (h = s.heap_max + 1; h < HEAP_SIZE; h++) { + n = s.heap[h]; + bits = tree[tree[n * 2 + 1] * 2 + 1] + 1; + if (bits > max_length) { + bits = max_length; + overflow++; + } + tree[n * 2 + 1] = (short) bits; + // We overwrite tree[n*2+1] which is no longer needed + + if (n > max_code) continue; // not a leaf node + + s.bl_count[bits]++; + xbits = 0; + if (n >= base) xbits = extra[n - base]; + f = tree[n * 2]; + s.opt_len += f * (bits + xbits); + if (stree != null) s.static_len += f * (stree[n * 2 + 1] + xbits); + } + if (overflow == 0) return; + + // This happens for example on obj2 and pic of the Calgary corpus + // Find the first bit length which could increase: + do { + bits = max_length - 1; + while (s.bl_count[bits] == 0) bits--; + s.bl_count[bits]--; // move one leaf down the tree + s.bl_count[bits + 1] += 2; // move one overflow item as its brother + s.bl_count[max_length]--; + // The brother of the overflow item also moves one step up, + // but this does not affect bl_count[max_length] + overflow -= 2; + } + while (overflow > 0); + + for (bits = max_length; bits != 0; bits--) { + n = s.bl_count[bits]; + while (n != 0) { + m = s.heap[--h]; + if (m > max_code) continue; + if (tree[m * 2 + 1] != bits) { + s.opt_len += ((long) bits - (long) tree[m * 2 + 1]) * (long) tree[m * 2]; + tree[m * 2 + 1] = (short) bits; + } + n--; + } + } + } + + // Construct one Huffman tree and assigns the code bit strings and lengths. + // Update the total bit length for the current block. + // IN assertion: the field freq is set for all tree elements. + // OUT assertions: the fields len and code are set to the optimal bit length + // and corresponding code. The length opt_len is updated; static_len is + // also updated if stree is not null. The field max_code is set. 
    // Construct one Huffman tree and assigns the code bit strings and lengths.
    // Update the total bit length for the current block.
    // IN assertion: the field freq is set for all tree elements.
    // OUT assertions: the fields len and code are set to the optimal bit length
    // and corresponding code. The length opt_len is updated; static_len is
    // also updated if stree is not null. The field max_code is set.
    void build_tree(Deflate s) {
        short[] tree = dyn_tree;
        short[] stree = stat_desc.static_tree;
        int elems = stat_desc.elems;
        int n, m; // iterate over heap elements
        int max_code = -1; // largest code with non zero frequency
        int node; // new node being created

        // Construct the initial heap, with least frequent element in
        // heap[1]. The sons of heap[n] are heap[2*n] and heap[2*n+1].
        // heap[0] is not used.
        s.heap_len = 0;
        s.heap_max = HEAP_SIZE;

        // Seed the heap with every symbol whose frequency (tree[n*2]) is
        // non-zero; zero-frequency symbols get code length 0.
        for (n = 0; n < elems; n++) {
            if (tree[n * 2] != 0) {
                s.heap[++s.heap_len] = max_code = n;
                s.depth[n] = 0;
            } else {
                tree[n * 2 + 1] = 0;
            }
        }

        // The pkzip format requires that at least one distance code exists,
        // and that at least one bit should be sent even if there is only one
        // possible code. So to avoid special checks later on we force at least
        // two codes of non zero frequency.
        while (s.heap_len < 2) {
            node = s.heap[++s.heap_len] = (max_code < 2 ? ++max_code : 0);
            tree[node * 2] = 1;
            s.depth[node] = 0;
            s.opt_len--;
            if (stree != null) s.static_len -= stree[node * 2 + 1];
            // node is 0 or 1 so it does not have extra bits
        }
        this.max_code = max_code;

        // The elements heap[heap_len/2+1 .. heap_len] are leaves of the tree,
        // establish sub-heaps of increasing lengths:
        for (n = s.heap_len / 2; n >= 1; n--)
            s.pqdownheap(tree, n);

        // Construct the Huffman tree by repeatedly combining the least two
        // frequent nodes.
        node = elems; // next internal node of the tree
        do {
            // n = node of least frequency
            n = s.heap[1];
            s.heap[1] = s.heap[s.heap_len--];
            s.pqdownheap(tree, 1);
            m = s.heap[1]; // m = node of next least frequency

            s.heap[--s.heap_max] = n; // keep the nodes sorted by frequency
            s.heap[--s.heap_max] = m;

            // Create a new node father of n and m
            tree[node * 2] = (short) (tree[n * 2] + tree[m * 2]);
            s.depth[node] = (byte) (Math.max(s.depth[n], s.depth[m]) + 1);
            tree[n * 2 + 1] = tree[m * 2 + 1] = (short) node;

            // and insert the new node in the heap
            s.heap[1] = node++;
            s.pqdownheap(tree, 1);
        }
        while (s.heap_len >= 2);

        s.heap[--s.heap_max] = s.heap[1];

        // At this point, the fields freq and dad are set. We can now
        // generate the bit lengths.
        gen_bitlen(s);

        // The field len is now set, we can generate the bit codes
        gen_codes(tree, max_code, s.bl_count, s.next_code);
    }

    // Generate the codes for a given tree and bit counts (which need not be
    // optimal).
    // IN assertion: the array bl_count contains the bit length statistics for
    // the given tree and the field len is set for all tree elements.
    // OUT assertion: the field code is set for all tree elements of non
    // zero code length.
    private static void gen_codes(
            short[] tree, // the tree to decorate
            int max_code, // largest code with non zero frequency
            short[] bl_count, // number of codes at each bit length
            short[] next_code) {
        short code = 0; // running code value
        int bits; // bit index
        int n; // code index

        // The distribution counts are first used to generate the code values
        // without bit reversal.
        next_code[0] = 0;
        for (bits = 1; bits <= MAX_BITS; bits++) {
            next_code[bits] = code = (short) ((code + bl_count[bits - 1]) << 1);
        }

        // Check that the bit counts in bl_count are consistent. The last code
        // must be all ones.
+ //Assert (code + bl_count[MAX_BITS]-1 == (1<>>= 1; + res <<= 1; + } + while (--len > 0); + return res >>> 1; + } +} + diff --git a/netty-zlib/src/main/java/io/netty/zlib/ZInputStream.java b/netty-zlib/src/main/java/io/netty/zlib/ZInputStream.java new file mode 100644 index 0000000..f3fcf65 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/ZInputStream.java @@ -0,0 +1,101 @@ +package io.netty.zlib; + +import java.io.FilterInputStream; +import java.io.IOException; +import java.io.InputStream; + +/** + * ZInputStream + * + * @deprecated use DeflaterOutputStream or InflaterInputStream + */ +@Deprecated +public class ZInputStream extends FilterInputStream { + + protected int flush = JZlib.Z_NO_FLUSH; + protected boolean compress; + protected InputStream in = null; + + protected Deflater deflater; + protected InflaterInputStream iis; + + public ZInputStream(InputStream in) throws IOException { + this(in, false); + } + + public ZInputStream(InputStream in, boolean nowrap) throws IOException { + super(in); + iis = new InflaterInputStream(in, nowrap); + compress = false; + } + + public ZInputStream(InputStream in, int level) throws IOException { + super(in); + this.in = in; + deflater = new Deflater(); + deflater.init(level); + compress = true; + } + + private final byte[] buf1 = new byte[1]; + + public int read() throws IOException { + if (read(buf1, 0, 1) == -1) return -1; + return (buf1[0] & 0xFF); + } + + private final byte[] buf = new byte[512]; + + public int read(byte[] b, int off, int len) throws IOException { + if (compress) { + deflater.setOutput(b, off, len); + while (true) { + int datalen = in.read(buf, 0, buf.length); + if (datalen == -1) return -1; + deflater.setInput(buf, 0, datalen, true); + int err = deflater.deflate(flush); + if (deflater.next_out_index > 0) + return deflater.next_out_index; + if (err == JZlib.Z_STREAM_END) + return 0; + if (err == JZlib.Z_STREAM_ERROR || + err == JZlib.Z_DATA_ERROR) { + throw new ZStreamException("deflating: " 
+ deflater.msg); + } + } + } else { + return iis.read(b, off, len); + } + } + + public long skip(long n) throws IOException { + int len = 512; + if (n < len) + len = (int) n; + byte[] tmp = new byte[len]; + return read(tmp); + } + + public int getFlushMode() { + return flush; + } + + public void setFlushMode(int flush) { + this.flush = flush; + } + + public long getTotalIn() { + if (compress) return deflater.total_in; + else return iis.getTotalIn(); + } + + public long getTotalOut() { + if (compress) return deflater.total_out; + else return iis.getTotalOut(); + } + + public void close() throws IOException { + if (compress) deflater.end(); + else iis.close(); + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/ZOutputStream.java b/netty-zlib/src/main/java/io/netty/zlib/ZOutputStream.java new file mode 100644 index 0000000..5c027f0 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/ZOutputStream.java @@ -0,0 +1,137 @@ +package io.netty.zlib; + +import java.io.FilterOutputStream; +import java.io.IOException; +import java.io.OutputStream; + +/** + * ZOutputStream + * + * @deprecated use DeflaterOutputStream or InflaterInputStream + */ +@Deprecated +public class ZOutputStream extends FilterOutputStream { + + protected int bufsize = 512; + protected int flush = JZlib.Z_NO_FLUSH; + protected byte[] buf = new byte[bufsize]; + protected boolean compress; + + protected OutputStream out; + private boolean end = false; + + private DeflaterOutputStream dos; + private Inflater inflater; + + public ZOutputStream(OutputStream out) throws IOException { + super(out); + this.out = out; + inflater = new Inflater(); + inflater.init(); + compress = false; + } + + public ZOutputStream(OutputStream out, int level) throws IOException { + this(out, level, false); + } + + public ZOutputStream(OutputStream out, int level, boolean nowrap) throws IOException { + super(out); + this.out = out; + Deflater deflater = new Deflater(level, nowrap); + dos = new DeflaterOutputStream(out, 
deflater); + compress = true; + } + + private final byte[] buf1 = new byte[1]; + + public void write(int b) throws IOException { + buf1[0] = (byte) b; + write(buf1, 0, 1); + } + + public void write(byte[] b, int off, int len) throws IOException { + if (len == 0) return; + if (compress) { + dos.write(b, off, len); + } else { + inflater.setInput(b, off, len, true); + int err = JZlib.Z_OK; + while (inflater.avail_in > 0) { + inflater.setOutput(buf, 0, buf.length); + err = inflater.inflate(flush); + if (inflater.next_out_index > 0) + out.write(buf, 0, inflater.next_out_index); + if (err != JZlib.Z_OK) + break; + } + if (err != JZlib.Z_OK) + throw new ZStreamException("inflating: " + inflater.msg); + } + } + + public int getFlushMode() { + return flush; + } + + public void setFlushMode(int flush) { + this.flush = flush; + } + + public void finish() throws IOException { + int err; + if (compress) { + int tmp = flush; + int flush = JZlib.Z_FINISH; + try { + write("".getBytes(), 0, 0); + } finally { + flush = tmp; + } + } else { + dos.finish(); + } + flush(); + } + + public synchronized void end() { + if (end) return; + if (compress) { + try { + dos.finish(); + } catch (Exception e) { + } + } else { + inflater.end(); + } + end = true; + } + + public void close() throws IOException { + try { + try { + finish(); + } catch (IOException ignored) { + } + } finally { + end(); + out.close(); + out = null; + } + } + + public long getTotalIn() { + if (compress) return dos.getTotalIn(); + else return inflater.total_in; + } + + public long getTotalOut() { + if (compress) return dos.getTotalOut(); + else return inflater.total_out; + } + + public void flush() throws IOException { + out.flush(); + } + +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/ZStream.java b/netty-zlib/src/main/java/io/netty/zlib/ZStream.java new file mode 100644 index 0000000..edf7bb6 --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/ZStream.java @@ -0,0 +1,360 @@ +package io.netty.zlib; + +/** + * 
ZStream + * + * @deprecated Not for public use in the future. + */ +@Deprecated +public class ZStream { + + static final private int MAX_WBITS = 15; // 32K LZ77 window + static final private int DEF_WBITS = MAX_WBITS; + + static final private int Z_NO_FLUSH = 0; + static final private int Z_PARTIAL_FLUSH = 1; + static final private int Z_SYNC_FLUSH = 2; + static final private int Z_FULL_FLUSH = 3; + static final private int Z_FINISH = 4; + + static final private int MAX_MEM_LEVEL = 9; + + static final private int Z_OK = 0; + static final private int Z_STREAM_END = 1; + static final private int Z_NEED_DICT = 2; + static final private int Z_ERRNO = -1; + static final private int Z_STREAM_ERROR = -2; + static final private int Z_DATA_ERROR = -3; + static final private int Z_MEM_ERROR = -4; + static final private int Z_BUF_ERROR = -5; + static final private int Z_VERSION_ERROR = -6; + + public byte[] next_in; // next input byte + public int next_in_index; + public int avail_in; // number of bytes available at next_in + public long total_in; // total nb of input bytes read so far + + public byte[] next_out; // next output byte should be put there + public int next_out_index; + public int avail_out; // remaining free space at next_out + public long total_out; // total nb of bytes output so far + + public String msg; + + Deflate dstate; + Inflate istate; + + int data_type; // best guess about the data type: ascii or binary + + Checksum adler; + + public ZStream() { + this(new Adler32()); + } + + public ZStream(Checksum adler) { + this.adler = adler; + } + + public int inflateInit() { + return inflateInit(DEF_WBITS); + } + + public int inflateInit(boolean nowrap) { + return inflateInit(DEF_WBITS, nowrap); + } + + public int inflateInit(int w) { + return inflateInit(w, false); + } + + public int inflateInit(JZlib.WrapperType wrapperType) { + return inflateInit(DEF_WBITS, wrapperType); + } + + public int inflateInit(int w, JZlib.WrapperType wrapperType) { + boolean nowrap = 
false; + if (wrapperType == JZlib.W_NONE) { + nowrap = true; + } else if (wrapperType == JZlib.W_GZIP) { + w += 16; + } else if (wrapperType == JZlib.W_ANY) { + w |= Inflate.INFLATE_ANY; + } else if (wrapperType == JZlib.W_ZLIB) { + } + return inflateInit(w, nowrap); + } + + public int inflateInit(int w, boolean nowrap) { + istate = new Inflate(this); + return istate.inflateInit(nowrap ? -w : w); + } + + public int inflate(int f) { + if (istate == null) return Z_STREAM_ERROR; + return istate.inflate(f); + } + + public int inflateEnd() { + if (istate == null) return Z_STREAM_ERROR; + int ret = istate.inflateEnd(); +// istate = null; + return ret; + } + + public int inflateSync() { + if (istate == null) + return Z_STREAM_ERROR; + return istate.inflateSync(); + } + + public int inflateSyncPoint() { + if (istate == null) + return Z_STREAM_ERROR; + return istate.inflateSyncPoint(); + } + + public int inflateSetDictionary(byte[] dictionary, int dictLength) { + if (istate == null) + return Z_STREAM_ERROR; + return istate.inflateSetDictionary(dictionary, dictLength); + } + + public boolean inflateFinished() { + return istate.mode == 12 /*DONE*/; + } + + public int deflateInit(int level) { + return deflateInit(level, MAX_WBITS); + } + + public int deflateInit(int level, boolean nowrap) { + return deflateInit(level, MAX_WBITS, nowrap); + } + + public int deflateInit(int level, int bits) { + return deflateInit(level, bits, false); + } + + public int deflateInit(int level, int bits, int memlevel, JZlib.WrapperType wrapperType) { + if (bits < 9 || bits > 15) { + return Z_STREAM_ERROR; + } + if (wrapperType == JZlib.W_NONE) { + bits *= -1; + } else if (wrapperType == JZlib.W_GZIP) { + bits += 16; + } else if (wrapperType == JZlib.W_ANY) { + return Z_STREAM_ERROR; + } else if (wrapperType == JZlib.W_ZLIB) { + } + return this.deflateInit(level, bits, memlevel); + } + + public int deflateInit(int level, int bits, int memlevel) { + dstate = new Deflate(this); + return 
dstate.deflateInit(level, bits, memlevel); + } + + public int deflateInit(int level, int bits, boolean nowrap) { + dstate = new Deflate(this); + return dstate.deflateInit(level, nowrap ? -bits : bits); + } + + public int deflate(int flush) { + if (dstate == null) { + return Z_STREAM_ERROR; + } + return dstate.deflate(flush); + } + + public int deflateEnd() { + if (dstate == null) return Z_STREAM_ERROR; + int ret = dstate.deflateEnd(); + dstate = null; + return ret; + } + + public int deflateParams(int level, int strategy) { + if (dstate == null) return Z_STREAM_ERROR; + return dstate.deflateParams(level, strategy); + } + + public int deflateSetDictionary(byte[] dictionary, int dictLength) { + if (dstate == null) + return Z_STREAM_ERROR; + return dstate.deflateSetDictionary(dictionary, dictLength); + } + + // Flush as much pending output as possible. All deflate() output goes + // through this function so some applications may wish to modify it + // to avoid allocating a large strm->next_out buffer and copying into it. + // (See also read_buf()). + void flush_pending() { + int len = dstate.pending; + + if (len > avail_out) len = avail_out; + if (len == 0) return; + + if (dstate.pending_buf.length <= dstate.pending_out || + next_out.length <= next_out_index || + dstate.pending_buf.length < (dstate.pending_out + len) || + next_out.length < (next_out_index + len)) { + //System.out.println(dstate.pending_buf.length+", "+dstate.pending_out+ + // ", "+next_out.length+", "+next_out_index+", "+len); + //System.out.println("avail_out="+avail_out); + } + + System.arraycopy(dstate.pending_buf, dstate.pending_out, + next_out, next_out_index, len); + + next_out_index += len; + dstate.pending_out += len; + total_out += len; + avail_out -= len; + dstate.pending -= len; + if (dstate.pending == 0) { + dstate.pending_out = 0; + } + } + + // Read a new buffer from the current input stream, update the adler32 + // and total number of bytes read. 
All deflate() input goes through + // this function so some applications may wish to modify it to avoid + // allocating a large strm->next_in buffer and copying from it. + // (See also flush_pending()). + int read_buf(byte[] buf, int start, int size) { + int len = avail_in; + + if (len > size) len = size; + if (len == 0) return 0; + + avail_in -= len; + + if (dstate.wrap != 0) { + adler.update(next_in, next_in_index, len); + } + System.arraycopy(next_in, next_in_index, buf, start, len); + next_in_index += len; + total_in += len; + return len; + } + + public long getAdler() { + return adler.getValue(); + } + + public void free() { + next_in = null; + next_out = null; + msg = null; + } + + public void setOutput(byte[] buf) { + setOutput(buf, 0, buf.length); + } + + public void setOutput(byte[] buf, int off, int len) { + next_out = buf; + next_out_index = off; + avail_out = len; + } + + public void setInput(byte[] buf) { + setInput(buf, 0, buf.length, false); + } + + public void setInput(byte[] buf, boolean append) { + setInput(buf, 0, buf.length, append); + } + + public void setInput(byte[] buf, int off, int len, boolean append) { + if (len <= 0 && append && next_in != null) return; + + if (avail_in > 0 && append) { + byte[] tmp = new byte[avail_in + len]; + System.arraycopy(next_in, next_in_index, tmp, 0, avail_in); + System.arraycopy(buf, off, tmp, avail_in, len); + next_in = tmp; + next_in_index = 0; + avail_in += len; + } else { + next_in = buf; + next_in_index = off; + avail_in = len; + } + } + + public byte[] getNextIn() { + return next_in; + } + + public void setNextIn(byte[] next_in) { + this.next_in = next_in; + } + + public int getNextInIndex() { + return next_in_index; + } + + public void setNextInIndex(int next_in_index) { + this.next_in_index = next_in_index; + } + + public int getAvailIn() { + return avail_in; + } + + public void setAvailIn(int avail_in) { + this.avail_in = avail_in; + } + + public byte[] getNextOut() { + return next_out; + } + + public 
void setNextOut(byte[] next_out) { + this.next_out = next_out; + } + + public int getNextOutIndex() { + return next_out_index; + } + + public void setNextOutIndex(int next_out_index) { + this.next_out_index = next_out_index; + } + + public int getAvailOut() { + return avail_out; + + } + + public void setAvailOut(int avail_out) { + this.avail_out = avail_out; + } + + public long getTotalOut() { + return total_out; + } + + public long getTotalIn() { + return total_in; + } + + public String getMessage() { + return msg; + } + + /** + * Those methods are expected to be override by Inflater and Deflater. + * In the future, they will become abstract methods. + */ + public int end() { + return Z_OK; + } + + public boolean finished() { + return false; + } +} diff --git a/netty-zlib/src/main/java/io/netty/zlib/ZStreamException.java b/netty-zlib/src/main/java/io/netty/zlib/ZStreamException.java new file mode 100644 index 0000000..77c655e --- /dev/null +++ b/netty-zlib/src/main/java/io/netty/zlib/ZStreamException.java @@ -0,0 +1,11 @@ +package io.netty.zlib; + +public class ZStreamException extends java.io.IOException { + public ZStreamException() { + super(); + } + + public ZStreamException(String s) { + super(s); + } +} diff --git a/netty-zlib/src/main/java/module-info.java b/netty-zlib/src/main/java/module-info.java new file mode 100644 index 0000000..7b461e9 --- /dev/null +++ b/netty-zlib/src/main/java/module-info.java @@ -0,0 +1,3 @@ +module org.xbib.io.netty.zlib { + exports io.netty.zlib; +} diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 0000000..4724f8e --- /dev/null +++ b/settings.gradle @@ -0,0 +1,75 @@ +pluginManagement { + repositories { + mavenLocal() + mavenCentral { + metadataSources { + mavenPom() + artifact() + ignoreGradleMetadataRedirection() + } + } + gradlePluginPortal() + } +} + +dependencyResolutionManagement { + versionCatalogs { + libs { + version('gradle', '8.5') + version('brotli4j', '1.15.0') + library('bouncycastle', 
'org.bouncycastle', 'bcpkix-jdk18on').version('1.77')
            library('conscrypt', 'org.conscrypt', 'conscrypt-openjdk-uber').version('2.5.2')
            library('brotli4j', 'com.aayushatharva.brotli4j', 'brotli4j').versionRef('brotli4j')
            library('jzlib', 'com.jcraft', 'jzlib').version('1.1.3')
            library('lz4', 'org.lz4', 'lz4-java').version('1.8.0')
            library('lzf', 'com.ning', 'compress-lzf').version('1.1.2')
            library('zstd', 'com.github.luben', 'zstd-jni').version('1.5.5-11')
            library('protobuf', 'com.google.protobuf', 'protobuf-java').version('4.0.0-rc-2')
        }
        testLibs {
            version('junit', '5.10.1')
            version('brotli4j', '1.15.0')
            library('junit-jupiter-api', 'org.junit.jupiter', 'junit-jupiter-api').versionRef('junit')
            library('junit-jupiter-params', 'org.junit.jupiter', 'junit-jupiter-params').versionRef('junit')
            library('junit-jupiter-engine', 'org.junit.jupiter', 'junit-jupiter-engine').versionRef('junit')
            library('junit-vintage-engine', 'org.junit.vintage', 'junit-vintage-engine').versionRef('junit')
            library('junit-jupiter-platform-launcher', 'org.junit.platform', 'junit-platform-launcher').version('1.10.1')
            library('hamcrest', 'org.hamcrest', 'hamcrest-library').version('2.2')
            library('junit4', 'junit', 'junit').version('4.13.2')
            library('mockito-core', 'org.mockito', 'mockito-core').version('3.12.4')
            library('assertj', 'org.assertj', 'assertj-core').version('3.22.0')
            library('testlibs', 'com.google.guava', 'guava-testlib').version('33.0.0-jre')
            library('lincheck', 'org.jetbrains.kotlinx', 'lincheck-jvm').version('2.23')
            library('asm-commons', 'org.ow2.asm', 'asm-commons').version('9.6')
            library('asm-util', 'org.ow2.asm', 'asm-util').version('9.6')
            library('gson', 'com.google.code.gson', 'gson').version('2.10.1')
            library('reflections', 'org.reflections', 'reflections').version('0.10.2')
            library('amazonCorrettoCrypt', 'software.amazon.cryptools', 'AmazonCorrettoCryptoProvider').version('2.3.2')
            library('commons-compress',
'org.apache.commons', 'commons-compress').version('1.25.0')
            library('brotli4j-native-linux-x8664', 'com.aayushatharva.brotli4j', 'native-linux-x86_64').versionRef('brotli4j')
            library('brotli4j-native-linux-aarch64', 'com.aayushatharva.brotli4j', 'native-linux-aarch64').versionRef('brotli4j')
            library('brotli4j-native-linux-riscv64', 'com.aayushatharva.brotli4j', 'native-linux-riscv64').versionRef('brotli4j')
            library('brotli4j-native-osx-x8664', 'com.aayushatharva.brotli4j', 'native-osx-x86_64').versionRef('brotli4j')
            library('brotli4j-native-osx-aarch64', 'com.aayushatharva.brotli4j', 'native-osx-aarch64').versionRef('brotli4j')
            library('brotli4j-native-windows-x8664', 'com.aayushatharva.brotli4j', 'native-windows-x86_64').versionRef('brotli4j')
            library('netty.tcnative.boringssl.static', 'io.netty', 'netty-tcnative-boringssl-static').version('2.0.62.Final')
        }
    }
}

include 'netty-buffer'
include 'netty-bzip2'
include 'netty-channel'
include 'netty-channel-unix'
include 'netty-jctools'
include 'netty-handler'
include 'netty-handler-codec'
include 'netty-handler-codec-compression'
include 'netty-handler-codec-http'
include 'netty-handler-codec-http2'
include 'netty-handler-codec-protobuf'
include 'netty-handler-ssl'
include 'netty-internal-tcnative'
include 'netty-resolver'
include 'netty-util'
include 'netty-zlib'